// Filename: internal/service/key_validation_service.go package service import ( "context" "encoding/json" "fmt" "gemini-balancer/internal/channel" CustomErrors "gemini-balancer/internal/errors" "gemini-balancer/internal/models" "gemini-balancer/internal/repository" "gemini-balancer/internal/settings" "gemini-balancer/internal/store" "gemini-balancer/internal/task" "gemini-balancer/internal/utils" "io" "net/http" "strings" "sync" "time" "github.com/gin-gonic/gin" "github.com/sirupsen/logrus" "gorm.io/gorm" ) const ( TaskTypeTestKeys = "test_keys" defaultConcurrency = 10 maxValidationConcurrency = 100 validationTaskTimeout = time.Hour ) type KeyValidationService struct { taskService task.Reporter channel channel.ChannelProxy db *gorm.DB settingsManager *settings.SettingsManager groupManager *GroupManager store store.Store keyRepo repository.KeyRepository logger *logrus.Entry } func NewKeyValidationService( ts task.Reporter, ch channel.ChannelProxy, db *gorm.DB, ss *settings.SettingsManager, gm *GroupManager, st store.Store, kr repository.KeyRepository, logger *logrus.Logger, ) *KeyValidationService { return &KeyValidationService{ taskService: ts, channel: ch, db: db, settingsManager: ss, groupManager: gm, store: st, keyRepo: kr, logger: logger.WithField("component", "KeyValidationService🧐"), } } // ==================== 公开接口 ==================== // ValidateSingleKey 验证单个密钥 func (s *KeyValidationService) ValidateSingleKey(key *models.APIKey, timeout time.Duration, endpoint string) error { // 1. 解密密钥 if err := s.keyRepo.Decrypt(key); err != nil { return fmt.Errorf("failed to decrypt key %d for validation: %w", key.ID, err) } // 2. 创建 HTTP 客户端和请求 client := &http.Client{Timeout: timeout} req, err := http.NewRequest("GET", endpoint, nil) if err != nil { s.logger.WithFields(logrus.Fields{ "key_id": key.ID, "endpoint": endpoint, }).Error("Failed to create validation request") return fmt.Errorf("failed to create request: %w", err) } // 3. 修改请求(添加密钥认证头) s.channel.ModifyRequest(req, key) // 4. 执行请求 resp, err := client.Do(req) if err != nil { return fmt.Errorf("request failed: %w", err) } defer resp.Body.Close() // 5. 检查响应状态 if resp.StatusCode == http.StatusOK { return nil } // 6. 处理错误响应 return s.buildValidationError(resp) } // StartTestKeysTask 启动批量密钥测试任务 func (s *KeyValidationService) StartTestKeysTask(ctx context.Context, groupID uint, keysText string) (*task.Status, error) { // 1. 解析和验证输入 keyStrings := utils.ParseKeysFromText(keysText) if len(keyStrings) == 0 { return nil, CustomErrors.NewAPIError(CustomErrors.ErrBadRequest, "no valid keys found in input text") } // 2. 查询密钥模型 apiKeyModels, err := s.keyRepo.GetKeysByValues(keyStrings) if err != nil { return nil, CustomErrors.ParseDBError(err) } if len(apiKeyModels) == 0 { return nil, CustomErrors.NewAPIError(CustomErrors.ErrResourceNotFound, "none of the provided keys were found in the system") } // 3. 批量解密密钥 if err := s.keyRepo.DecryptBatch(apiKeyModels); err != nil { s.logger.WithError(err).Error("Failed to batch decrypt keys for validation task") return nil, CustomErrors.NewAPIError(CustomErrors.ErrInternalServer, "Failed to decrypt keys for validation") } // 4. 获取组配置 group, ok := s.groupManager.GetGroupByID(groupID) if !ok { return nil, CustomErrors.NewAPIError(CustomErrors.ErrGroupNotFound, fmt.Sprintf("group with id %d not found", groupID)) } opConfig, err := s.groupManager.BuildOperationalConfig(group) if err != nil { return nil, CustomErrors.NewAPIError(CustomErrors.ErrInternalServer, fmt.Sprintf("failed to build operational config: %v", err)) } // 5. 构建验证端点 endpoint, err := s.groupManager.BuildKeyCheckEndpoint(groupID) if err != nil { return nil, CustomErrors.NewAPIError(CustomErrors.ErrInternalServer, fmt.Sprintf("failed to build endpoint: %v", err)) } // 6. 创建任务 resourceID := fmt.Sprintf("group-%d", groupID) taskStatus, err := s.taskService.StartTask(ctx, groupID, TaskTypeTestKeys, resourceID, len(apiKeyModels), validationTaskTimeout) if err != nil { return nil, err } // 7. 准备任务参数 params := s.buildValidationParams(opConfig) // 8. 启动异步验证任务 go s.runTestKeysTaskWithRecovery(context.Background(), taskStatus.ID, resourceID, groupID, apiKeyModels, params, endpoint) return taskStatus, nil } // StartTestKeysByFilterTask 根据状态过滤启动批量测试任务 func (s *KeyValidationService) StartTestKeysByFilterTask(ctx context.Context, groupID uint, statuses []string) (*task.Status, error) { s.logger.WithFields(logrus.Fields{ "group_id": groupID, "statuses": statuses, }).Info("Starting test task with status filter") // 1. 根据过滤条件查询密钥 keyValues, err := s.keyRepo.FindKeyValuesByStatus(groupID, statuses) if err != nil { return nil, CustomErrors.ParseDBError(err) } if len(keyValues) == 0 { return nil, CustomErrors.NewAPIError(CustomErrors.ErrNotFound, "No keys found matching the filter criteria") } // 2. 转换为文本格式并启动任务 keysAsText := strings.Join(keyValues, "\n") s.logger.Infof("Found %d keys to validate for group %d", len(keyValues), groupID) return s.StartTestKeysTask(ctx, groupID, keysAsText) } // ==================== 核心任务执行逻辑 ==================== // validationParams 验证参数封装 type validationParams struct { timeout time.Duration concurrency int } // buildValidationParams 构建验证参数 func (s *KeyValidationService) buildValidationParams(opConfig *models.KeyGroupSettings) validationParams { settings := s.settingsManager.GetSettings() // 从配置读取超时时间(而非硬编码) timeout := time.Duration(settings.KeyCheckTimeoutSeconds) * time.Second if timeout <= 0 { timeout = 30 * time.Second // 仅在配置无效时使用默认值 } // 从配置读取并发数(优先级:组配置 > 全局配置 > 兜底默认值) var concurrency int if opConfig.KeyCheckConcurrency != nil && *opConfig.KeyCheckConcurrency > 0 { concurrency = *opConfig.KeyCheckConcurrency } else if settings.BaseKeyCheckConcurrency > 0 { concurrency = settings.BaseKeyCheckConcurrency } else { concurrency = defaultConcurrency // 兜底默认值 } // 限制最大并发数(防护措施) if concurrency > maxValidationConcurrency { concurrency = maxValidationConcurrency } return validationParams{ timeout: timeout, concurrency: concurrency, } } // runTestKeysTaskWithRecovery 带恢复机制的任务执行包装器 func (s *KeyValidationService) runTestKeysTaskWithRecovery( ctx context.Context, taskID string, resourceID string, groupID uint, keys []models.APIKey, params validationParams, endpoint string, ) { defer func() { if r := recover(); r != nil { err := fmt.Errorf("panic recovered in validation task %s: %v", taskID, r) s.logger.WithField("task_id", taskID).Error(err) s.taskService.EndTaskByID(ctx, taskID, resourceID, nil, err) } }() s.runTestKeysTask(ctx, taskID, resourceID, groupID, keys, params, endpoint) } // runTestKeysTask 执行批量密钥验证任务 func (s *KeyValidationService) runTestKeysTask( ctx context.Context, taskID string, resourceID string, groupID uint, keys []models.APIKey, params validationParams, endpoint string, ) { s.logger.WithFields(logrus.Fields{ "task_id": taskID, "group_id": groupID, "key_count": len(keys), "concurrency": params.concurrency, "timeout": params.timeout, }).Info("Starting validation task") // 1. 初始化结果收集 results := make([]models.KeyTestResult, len(keys)) // 2. 创建任务分发器 dispatcher := newValidationDispatcher( keys, params.concurrency, s, ctx, taskID, groupID, endpoint, params.timeout, ) // 3. 执行并发验证 dispatcher.run(results) // 4. 完成任务 s.taskService.EndTaskByID(ctx, taskID, resourceID, gin.H{"results": results}, nil) s.logger.WithFields(logrus.Fields{ "task_id": taskID, "group_id": groupID, "processed": len(results), }).Info("Validation task completed") } // ==================== 验证调度器 ==================== // validationJob 验证作业 type validationJob struct { index int key models.APIKey } // validationDispatcher 验证任务分发器 type validationDispatcher struct { keys []models.APIKey concurrency int service *KeyValidationService ctx context.Context taskID string groupID uint endpoint string timeout time.Duration mu sync.Mutex processedCount int } // newValidationDispatcher 创建验证分发器 func newValidationDispatcher( keys []models.APIKey, concurrency int, service *KeyValidationService, ctx context.Context, taskID string, groupID uint, endpoint string, timeout time.Duration, ) *validationDispatcher { return &validationDispatcher{ keys: keys, concurrency: concurrency, service: service, ctx: ctx, taskID: taskID, groupID: groupID, endpoint: endpoint, timeout: timeout, } } // run 执行并发验证 func (d *validationDispatcher) run(results []models.KeyTestResult) { var wg sync.WaitGroup jobs := make(chan validationJob, len(d.keys)) // 启动 worker pool for i := 0; i < d.concurrency; i++ { wg.Add(1) go d.worker(&wg, jobs, results) } // 分发任务 for i, key := range d.keys { jobs <- validationJob{index: i, key: key} } close(jobs) // 等待所有 worker 完成 wg.Wait() } // worker 验证工作协程 func (d *validationDispatcher) worker(wg *sync.WaitGroup, jobs <-chan validationJob, results []models.KeyTestResult) { defer wg.Done() for job := range jobs { result := d.validateKey(job.key) d.mu.Lock() results[job.index] = result d.processedCount++ _ = d.service.taskService.UpdateProgressByID(d.ctx, d.taskID, d.processedCount) d.mu.Unlock() } } // validateKey 验证单个密钥并返回结果 func (d *validationDispatcher) validateKey(key models.APIKey) models.KeyTestResult { // 1. 执行验证 validationErr := d.service.ValidateSingleKey(&key, d.timeout, d.endpoint) // 2. 构建结果和事件 result, event := d.buildResultAndEvent(key, validationErr) // 3. 发布验证事件 d.publishValidationEvent(key.ID, event) return result } // buildResultAndEvent 构建验证结果和事件 func (d *validationDispatcher) buildResultAndEvent(key models.APIKey, validationErr error) (models.KeyTestResult, models.RequestFinishedEvent) { event := models.RequestFinishedEvent{ RequestLog: models.RequestLog{ GroupID: &d.groupID, KeyID: &key.ID, }, } if validationErr == nil { // 验证成功 event.RequestLog.IsSuccess = true return models.KeyTestResult{ Key: key.APIKey, Status: "valid", Message: "Validation successful", }, event } // 验证失败 event.RequestLog.IsSuccess = false var apiErr *CustomErrors.APIError if CustomErrors.As(validationErr, &apiErr) { event.Error = apiErr return models.KeyTestResult{ Key: key.APIKey, Status: "invalid", Message: fmt.Sprintf("Invalid key (HTTP %d): %s", apiErr.HTTPStatus, apiErr.Message), }, event } // 其他错误 event.Error = &CustomErrors.APIError{Message: validationErr.Error()} return models.KeyTestResult{ Key: key.APIKey, Status: "error", Message: "Validation check failed: " + validationErr.Error(), }, event } // publishValidationEvent 发布验证事件 func (d *validationDispatcher) publishValidationEvent(keyID uint, event models.RequestFinishedEvent) { eventData, err := json.Marshal(event) if err != nil { d.service.logger.WithFields(logrus.Fields{ "key_id": keyID, "group_id": d.groupID, }).WithError(err).Error("Failed to marshal validation event") return } if err := d.service.store.Publish(d.ctx, models.TopicRequestFinished, eventData); err != nil { d.service.logger.WithFields(logrus.Fields{ "key_id": keyID, "group_id": d.groupID, }).WithError(err).Error("Failed to publish validation event") } } // ==================== 辅助方法 ==================== // buildValidationError 构建验证错误 func (s *KeyValidationService) buildValidationError(resp *http.Response) error { bodyBytes, readErr := io.ReadAll(resp.Body) var errorMsg string if readErr != nil { errorMsg = "Failed to read error response body" s.logger.WithError(readErr).Warn("Failed to read validation error response") } else { errorMsg = string(bodyBytes) } return &CustomErrors.APIError{ HTTPStatus: resp.StatusCode, Message: fmt.Sprintf("Validation failed with status %d: %s", resp.StatusCode, errorMsg), Code: "VALIDATION_FAILED", } }