Files
gemini-banlancer/internal/service/key_validation_service.go

464 lines
13 KiB
Go

// Filename: internal/service/key_validation_service.go
package service
import (
"context"
"encoding/json"
"fmt"
"gemini-balancer/internal/channel"
CustomErrors "gemini-balancer/internal/errors"
"gemini-balancer/internal/models"
"gemini-balancer/internal/repository"
"gemini-balancer/internal/settings"
"gemini-balancer/internal/store"
"gemini-balancer/internal/task"
"gemini-balancer/internal/utils"
"io"
"net/http"
"strings"
"sync"
"time"
"github.com/gin-gonic/gin"
"github.com/sirupsen/logrus"
"gorm.io/gorm"
)
// Task identifiers and tuning limits for key validation.
const (
	// TaskTypeTestKeys is the task type passed to taskService.StartTask for
	// batch key-validation tasks.
	TaskTypeTestKeys = "test_keys"

	// validationTaskTimeout is the timeout passed to taskService.StartTask for
	// each batch validation task.
	validationTaskTimeout = 1 * time.Hour

	// defaultConcurrency is the worker count used when neither the group nor
	// the global settings provide a positive value.
	defaultConcurrency = 10
	// maxValidationConcurrency caps the worker count regardless of configuration.
	maxValidationConcurrency = 100
)
type KeyValidationService struct {
taskService task.Reporter
channel channel.ChannelProxy
db *gorm.DB
settingsManager *settings.SettingsManager
groupManager *GroupManager
store store.Store
keyRepo repository.KeyRepository
logger *logrus.Entry
}
// NewKeyValidationService wires a KeyValidationService from its dependencies.
// Every argument is stored as-is. The logger is wrapped in an entry tagged
// with component "KeyValidationService🧐" so this service's log lines are
// attributable.
func NewKeyValidationService(
	ts task.Reporter,
	ch channel.ChannelProxy,
	db *gorm.DB,
	ss *settings.SettingsManager,
	gm *GroupManager,
	st store.Store,
	kr repository.KeyRepository,
	logger *logrus.Logger,
) *KeyValidationService {
	componentLogger := logger.WithField("component", "KeyValidationService🧐")

	svc := new(KeyValidationService)
	svc.taskService = ts
	svc.channel = ch
	svc.db = db
	svc.settingsManager = ss
	svc.groupManager = gm
	svc.store = st
	svc.keyRepo = kr
	svc.logger = componentLogger
	return svc
}
// ==================== 公开接口 ====================
// ValidateSingleKey checks one API key by issuing an authenticated GET request
// against endpoint.
//
// keyRepo.Decrypt is called on key (a pointer, so key may be modified) before
// the request is built. s.channel.ModifyRequest then attaches the key's
// credentials to the request.
//
// Returns nil on 200 OK. Any other status is converted into an
// *CustomErrors.APIError by buildValidationError. Decryption, request-building
// and transport failures are returned wrapped with %w.
//
// timeout is used as http.Client.Timeout, so it bounds the whole exchange,
// including reading the response body.
func (s *KeyValidationService) ValidateSingleKey(key *models.APIKey, timeout time.Duration, endpoint string) error {
	// Upper bound on how much of a successful response body we read just to
	// let the connection be reused.
	const maxDrainBytes = 64 << 10

	if err := s.keyRepo.Decrypt(key); err != nil {
		return fmt.Errorf("failed to decrypt key %d for validation: %w", key.ID, err)
	}

	client := &http.Client{Timeout: timeout}
	req, err := http.NewRequest(http.MethodGet, endpoint, nil)
	if err != nil {
		s.logger.WithFields(logrus.Fields{
			"key_id":   key.ID,
			"endpoint": endpoint,
		}).Error("Failed to create validation request")
		return fmt.Errorf("failed to create request: %w", err)
	}

	// Attach the key's authentication to the request.
	s.channel.ModifyRequest(req, key)

	resp, err := client.Do(req)
	if err != nil {
		return fmt.Errorf("request failed: %w", err)
	}
	defer resp.Body.Close()

	if resp.StatusCode == http.StatusOK {
		// Closing an unread body prevents the transport from reusing the
		// keep-alive connection. Drain a bounded amount so batch validation
		// does not pay a new TCP/TLS handshake per key. Draining is
		// best-effort; the key is already known to be valid.
		_, _ = io.Copy(io.Discard, io.LimitReader(resp.Body, maxDrainBytes))
		return nil
	}

	return s.buildValidationError(resp)
}
// StartTestKeysTask parses keysText into individual keys and resolves them to
// stored key records. It then launches an asynchronous task that validates
// them against groupID's key-check endpoint.
//
// Steps run in this order: parse, DB lookup, batch decrypt, group/config
// lookup, endpoint resolution, task registration. The first failing step
// aborts, returning an APIError (or the task service's own error from
// StartTask).
//
// On success the returned status describes the freshly started task. The
// validation itself continues in a background goroutine that runs with
// context.Background() rather than ctx, presumably so it outlives the calling
// request.
//
// NOTE(review): keys are batch-decrypted here, and ValidateSingleKey calls
// keyRepo.Decrypt again for each key. Confirm that Decrypt is idempotent.
func (s *KeyValidationService) StartTestKeysTask(ctx context.Context, groupID uint, keysText string) (*task.Status, error) {
	rawKeys := utils.ParseKeysFromText(keysText)
	if len(rawKeys) == 0 {
		return nil, CustomErrors.NewAPIError(CustomErrors.ErrBadRequest, "no valid keys found in input text")
	}

	keys, err := s.keyRepo.GetKeysByValues(rawKeys)
	switch {
	case err != nil:
		return nil, CustomErrors.ParseDBError(err)
	case len(keys) == 0:
		return nil, CustomErrors.NewAPIError(CustomErrors.ErrResourceNotFound, "none of the provided keys were found in the system")
	}

	if decryptErr := s.keyRepo.DecryptBatch(keys); decryptErr != nil {
		s.logger.WithError(decryptErr).Error("Failed to batch decrypt keys for validation task")
		return nil, CustomErrors.NewAPIError(CustomErrors.ErrInternalServer, "Failed to decrypt keys for validation")
	}

	group, found := s.groupManager.GetGroupByID(groupID)
	if !found {
		return nil, CustomErrors.NewAPIError(CustomErrors.ErrGroupNotFound, fmt.Sprintf("group with id %d not found", groupID))
	}
	opConfig, err := s.groupManager.BuildOperationalConfig(group)
	if err != nil {
		return nil, CustomErrors.NewAPIError(CustomErrors.ErrInternalServer, fmt.Sprintf("failed to build operational config: %v", err))
	}

	endpoint, err := s.groupManager.BuildKeyCheckEndpoint(groupID)
	if err != nil {
		return nil, CustomErrors.NewAPIError(CustomErrors.ErrInternalServer, fmt.Sprintf("failed to build endpoint: %v", err))
	}

	resourceID := fmt.Sprintf("group-%d", groupID)
	status, err := s.taskService.StartTask(ctx, groupID, TaskTypeTestKeys, resourceID, len(keys), validationTaskTimeout)
	if err != nil {
		return nil, err
	}

	params := s.buildValidationParams(opConfig)
	go s.runTestKeysTaskWithRecovery(context.Background(), status.ID, resourceID, groupID, keys, params, endpoint)
	return status, nil
}
// StartTestKeysByFilterTask launches a validation task for the keys in
// groupID whose status matches one of statuses.
//
// The matching key values are joined newline-separated and handed to
// StartTestKeysTask, so all of that method's parsing, lookup and error
// behavior applies. If no key matches, it returns an ErrNotFound APIError.
func (s *KeyValidationService) StartTestKeysByFilterTask(ctx context.Context, groupID uint, statuses []string) (*task.Status, error) {
	s.logger.WithFields(logrus.Fields{
		"statuses": statuses,
		"group_id": groupID,
	}).Info("Starting test task with status filter")

	matched, err := s.keyRepo.FindKeyValuesByStatus(groupID, statuses)
	if err != nil {
		return nil, CustomErrors.ParseDBError(err)
	}

	count := len(matched)
	if count == 0 {
		return nil, CustomErrors.NewAPIError(CustomErrors.ErrNotFound, "No keys found matching the filter criteria")
	}

	s.logger.Infof("Found %d keys to validate for group %d", count, groupID)
	return s.StartTestKeysTask(ctx, groupID, strings.Join(matched, "\n"))
}
// ==================== 核心任务执行逻辑 ====================
// validationParams bundles the per-task knobs resolved by
// buildValidationParams.
type validationParams struct {
	// concurrency is the number of worker goroutines used by the dispatcher.
	concurrency int
	// timeout is the per-key HTTP timeout passed to ValidateSingleKey.
	timeout time.Duration
}
// buildValidationParams resolves the per-key timeout and the worker count for
// a validation task.
//
// Timeout: the global KeyCheckTimeoutSeconds setting, or 30s when that value
// is not positive.
//
// Concurrency: the first positive value from the group's KeyCheckConcurrency
// override, then the global BaseKeyCheckConcurrency, then defaultConcurrency.
// The result is capped at maxValidationConcurrency.
//
// opConfig may be nil, in which case the group override is skipped instead of
// panicking. This method runs after the task has already been registered, so
// a panic here would leave a task started with no worker.
func (s *KeyValidationService) buildValidationParams(opConfig *models.KeyGroupSettings) validationParams {
	// Named cfg rather than settings so the imported settings package is not
	// shadowed inside this function.
	cfg := s.settingsManager.GetSettings()

	timeout := time.Duration(cfg.KeyCheckTimeoutSeconds) * time.Second
	if timeout <= 0 {
		timeout = 30 * time.Second // fallback only when the setting is invalid
	}

	var concurrency int
	switch {
	case opConfig != nil && opConfig.KeyCheckConcurrency != nil && *opConfig.KeyCheckConcurrency > 0:
		concurrency = *opConfig.KeyCheckConcurrency
	case cfg.BaseKeyCheckConcurrency > 0:
		concurrency = cfg.BaseKeyCheckConcurrency
	default:
		concurrency = defaultConcurrency
	}

	// Hard cap as a safety net against misconfiguration.
	if concurrency > maxValidationConcurrency {
		concurrency = maxValidationConcurrency
	}

	return validationParams{
		timeout:     timeout,
		concurrency: concurrency,
	}
}
// runTestKeysTaskWithRecovery runs runTestKeysTask. If that call panics, it
// logs the panic and ends the task with an error so the task is not left
// open.
//
// recover only intercepts panics raised on this goroutine. Panics inside the
// dispatcher's worker goroutines (started with `go` in run) are not caught
// here.
func (s *KeyValidationService) runTestKeysTaskWithRecovery(
	ctx context.Context,
	taskID string,
	resourceID string,
	groupID uint,
	keys []models.APIKey,
	params validationParams,
	endpoint string,
) {
	defer func() {
		recovered := recover()
		if recovered == nil {
			return
		}
		panicErr := fmt.Errorf("panic recovered in validation task %s: %v", taskID, recovered)
		s.logger.WithField("task_id", taskID).Error(panicErr)
		s.taskService.EndTaskByID(ctx, taskID, resourceID, nil, panicErr)
	}()

	s.runTestKeysTask(ctx, taskID, resourceID, groupID, keys, params, endpoint)
}
// runTestKeysTask validates keys concurrently, then ends the task. The task
// payload is {"results": [...]}, where results[i] is the outcome for keys[i].
func (s *KeyValidationService) runTestKeysTask(
	ctx context.Context,
	taskID string,
	resourceID string,
	groupID uint,
	keys []models.APIKey,
	params validationParams,
	endpoint string,
) {
	taskFields := logrus.Fields{
		"task_id":  taskID,
		"group_id": groupID,
	}

	s.logger.WithFields(taskFields).WithFields(logrus.Fields{
		"key_count":   len(keys),
		"concurrency": params.concurrency,
		"timeout":     params.timeout,
	}).Info("Starting validation task")

	// One slot per key; the dispatcher writes each outcome at the key's index.
	results := make([]models.KeyTestResult, len(keys))
	newValidationDispatcher(
		keys,
		params.concurrency,
		s,
		ctx,
		taskID,
		groupID,
		endpoint,
		params.timeout,
	).run(results)

	s.taskService.EndTaskByID(ctx, taskID, resourceID, gin.H{"results": results}, nil)

	s.logger.WithFields(taskFields).WithField("processed", len(results)).Info("Validation task completed")
}
// ==================== 验证调度器 ====================
// validationJob is one unit of work for a dispatcher worker. It pairs a key
// with the index of the results slot its outcome is written to.
type validationJob struct {
	key   models.APIKey
	index int
}
// validationDispatcher 验证任务分发器
type validationDispatcher struct {
keys []models.APIKey
concurrency int
service *KeyValidationService
ctx context.Context
taskID string
groupID uint
endpoint string
timeout time.Duration
mu sync.Mutex
processedCount int
}
// newValidationDispatcher builds a dispatcher for keys. It only records its
// arguments; no goroutines start until run is called.
//
// The parameter order (ctx not first) is preserved for existing callers.
func newValidationDispatcher(
	keys []models.APIKey,
	concurrency int,
	service *KeyValidationService,
	ctx context.Context,
	taskID string,
	groupID uint,
	endpoint string,
	timeout time.Duration,
) *validationDispatcher {
	d := &validationDispatcher{}
	d.keys, d.concurrency = keys, concurrency
	d.service, d.ctx = service, ctx
	d.taskID, d.groupID = taskID, groupID
	d.endpoint, d.timeout = endpoint, timeout
	return d
}
// run validates every key in d.keys on a bounded pool of worker goroutines
// and blocks until all workers have finished.
//
// results must have len(results) >= len(d.keys). The outcome for d.keys[i] is
// written to results[i].
//
// The pool size is d.concurrency clamped to [1, len(d.keys)]. A non-positive
// concurrency would otherwise start no workers and silently leave every result
// at its zero value. Workers beyond the number of keys would only sit idle.
func (d *validationDispatcher) run(results []models.KeyTestResult) {
	if len(d.keys) == 0 {
		return
	}

	workers := d.concurrency
	if workers < 1 {
		workers = 1
	}
	if workers > len(d.keys) {
		workers = len(d.keys)
	}

	// Buffered to len(d.keys) so enqueueing never blocks.
	jobs := make(chan validationJob, len(d.keys))
	for i, key := range d.keys {
		jobs <- validationJob{index: i, key: key}
	}
	close(jobs)

	var wg sync.WaitGroup
	wg.Add(workers)
	for i := 0; i < workers; i++ {
		go d.worker(&wg, jobs, results)
	}
	wg.Wait()
}
// worker consumes jobs until the channel is closed. It stores each outcome at
// results[job.index] and reports progress after every key.
//
// Each key goes through validateKeySafely. A panic while checking one key is
// therefore turned into an "error" result for that key instead of crashing
// the process. This goroutine is separate from the one running
// runTestKeysTaskWithRecovery, whose recover cannot see panics raised here.
//
// mu is held while writing the result, bumping processedCount and calling
// UpdateProgressByID. As a result, progress values are reported in
// increasing order.
func (d *validationDispatcher) worker(wg *sync.WaitGroup, jobs <-chan validationJob, results []models.KeyTestResult) {
	defer wg.Done()
	for job := range jobs {
		result := d.validateKeySafely(job.key)

		d.mu.Lock()
		results[job.index] = result
		d.processedCount++
		if err := d.service.taskService.UpdateProgressByID(d.ctx, d.taskID, d.processedCount); err != nil {
			// Progress reporting is best-effort: record the failure and keep validating.
			d.service.logger.WithFields(logrus.Fields{
				"task_id":   d.taskID,
				"processed": d.processedCount,
			}).WithError(err).Warn("Failed to update validation task progress")
		}
		d.mu.Unlock()
	}
}

// validateKeySafely calls validateKey and converts a panic into an "error"
// result for that key, logging the panic value. One bad key therefore cannot
// take down the worker goroutine or the process.
func (d *validationDispatcher) validateKeySafely(key models.APIKey) (result models.KeyTestResult) {
	defer func() {
		if r := recover(); r != nil {
			d.service.logger.WithFields(logrus.Fields{
				"task_id":  d.taskID,
				"group_id": d.groupID,
				"key_id":   key.ID,
			}).Errorf("Panic recovered while validating key: %v", r)
			result = models.KeyTestResult{
				Key:     key.APIKey,
				Status:  "error",
				Message: fmt.Sprintf("Validation check failed: panic: %v", r),
			}
		}
	}()
	return d.validateKey(key)
}
// validateKey checks a single key, publishes the resulting request-finished
// event, and returns the per-key test result.
//
// key is received by value, and ValidateSingleKey is given a pointer to this
// local copy. Changes it makes (it calls keyRepo.Decrypt) are therefore seen
// by the result builder below but not by d.keys. The two calls are kept as
// separate statements so the copy is read only after validation has run.
func (d *validationDispatcher) validateKey(key models.APIKey) models.KeyTestResult {
	checkErr := d.service.ValidateSingleKey(&key, d.timeout, d.endpoint)
	result, event := d.buildResultAndEvent(key, checkErr)
	d.publishValidationEvent(key.ID, event)
	return result
}
// buildResultAndEvent converts a validation outcome into two values: the
// per-key result stored on the task, and the RequestFinishedEvent published
// for it.
//
// Outcomes:
//   - nil error: status "valid".
//   - an error matching *CustomErrors.APIError (e.g. a non-200 upstream
//     response): status "invalid", and the APIError is attached to the event.
//   - any other error: status "error", attached to the event as an APIError
//     that carries only the message.
func (d *validationDispatcher) buildResultAndEvent(key models.APIKey, validationErr error) (models.KeyTestResult, models.RequestFinishedEvent) {
	var event models.RequestFinishedEvent
	event.RequestLog.GroupID = &d.groupID
	event.RequestLog.KeyID = &key.ID
	event.RequestLog.IsSuccess = validationErr == nil

	result := models.KeyTestResult{Key: key.APIKey}

	var apiErr *CustomErrors.APIError
	switch {
	case validationErr == nil:
		result.Status = "valid"
		result.Message = "Validation successful"
	case CustomErrors.As(validationErr, &apiErr):
		event.Error = apiErr
		result.Status = "invalid"
		result.Message = fmt.Sprintf("Invalid key (HTTP %d): %s", apiErr.HTTPStatus, apiErr.Message)
	default:
		event.Error = &CustomErrors.APIError{Message: validationErr.Error()}
		result.Status = "error"
		result.Message = "Validation check failed: " + validationErr.Error()
	}
	return result, event
}
// publishValidationEvent JSON-encodes event and publishes it on
// models.TopicRequestFinished. Encoding and publishing failures are logged
// and otherwise ignored, so they never change the validation outcome.
func (d *validationDispatcher) publishValidationEvent(keyID uint, event models.RequestFinishedEvent) {
	logFields := logrus.Fields{
		"key_id":   keyID,
		"group_id": d.groupID,
	}

	payload, marshalErr := json.Marshal(event)
	if marshalErr != nil {
		d.service.logger.WithFields(logFields).WithError(marshalErr).Error("Failed to marshal validation event")
		return
	}

	if pubErr := d.service.store.Publish(d.ctx, models.TopicRequestFinished, payload); pubErr != nil {
		d.service.logger.WithFields(logFields).WithError(pubErr).Error("Failed to publish validation event")
	}
}
// ==================== 辅助方法 ====================
// buildValidationError turns a non-200 validation response into an
// *CustomErrors.APIError. The error carries the upstream status code and a
// bounded prefix of the response body.
//
// The body comes from a remote service and ends up in task results and
// published events (via buildResultAndEvent). At most maxErrorBodyBytes are
// therefore read. Longer bodies are truncated, trimmed to valid UTF-8, and
// suffixed with "...(truncated)". This prevents a misbehaving upstream from
// forcing an unbounded allocation.
func (s *KeyValidationService) buildValidationError(resp *http.Response) error {
	const maxErrorBodyBytes = 8 << 10

	// Read one byte past the limit so truncation can be detected.
	bodyBytes, readErr := io.ReadAll(io.LimitReader(resp.Body, maxErrorBodyBytes+1))

	var errorMsg string
	switch {
	case readErr != nil:
		errorMsg = "Failed to read error response body"
		s.logger.WithError(readErr).Warn("Failed to read validation error response")
	case len(bodyBytes) > maxErrorBodyBytes:
		// The cut may split a multi-byte rune; drop any invalid tail.
		errorMsg = strings.ToValidUTF8(string(bodyBytes[:maxErrorBodyBytes]), "") + "...(truncated)"
	default:
		errorMsg = string(bodyBytes)
	}

	return &CustomErrors.APIError{
		HTTPStatus: resp.StatusCode,
		Message:    fmt.Sprintf("Validation failed with status %d: %s", resp.StatusCode, errorMsg),
		Code:       "VALIDATION_FAILED",
	}
}