// Filename: internal/service/key_validation_service.go package service import ( "encoding/json" "fmt" "gemini-balancer/internal/channel" CustomErrors "gemini-balancer/internal/errors" "gemini-balancer/internal/models" "gemini-balancer/internal/repository" "gemini-balancer/internal/settings" "gemini-balancer/internal/store" "gemini-balancer/internal/task" "gemini-balancer/internal/utils" "io" "net/http" "strings" "sync" "time" "github.com/gin-gonic/gin" "github.com/sirupsen/logrus" "gorm.io/gorm" ) const ( TaskTypeTestKeys = "test_keys" ) type KeyValidationService struct { taskService task.Reporter channel channel.ChannelProxy db *gorm.DB SettingsManager *settings.SettingsManager groupManager *GroupManager store store.Store keyRepo repository.KeyRepository logger *logrus.Entry } func NewKeyValidationService(ts task.Reporter, ch channel.ChannelProxy, db *gorm.DB, ss *settings.SettingsManager, gm *GroupManager, st store.Store, kr repository.KeyRepository, logger *logrus.Logger) *KeyValidationService { return &KeyValidationService{ taskService: ts, channel: ch, db: db, SettingsManager: ss, groupManager: gm, store: st, keyRepo: kr, logger: logger.WithField("component", "KeyValidationService🧐"), } } func (s *KeyValidationService) ValidateSingleKey(key *models.APIKey, timeout time.Duration, endpoint string) error { if err := s.keyRepo.Decrypt(key); err != nil { return fmt.Errorf("failed to decrypt key %d for validation: %w", key.ID, err) } client := &http.Client{Timeout: timeout} req, err := http.NewRequest("GET", endpoint, nil) if err != nil { s.logger.Errorf("Failed to create request for key validation (ID: %d): %v", key.ID, err) return fmt.Errorf("failed to create request: %w", err) } s.channel.ModifyRequest(req, key) // Use the injected channel to modify the request resp, err := client.Do(req) if err != nil { // This is a network-level error (e.g., timeout, DNS issue) return fmt.Errorf("request failed: %w", err) } defer resp.Body.Close() if resp.StatusCode == http.StatusOK { return nil // Success } // Read the body for more error details bodyBytes, readErr := io.ReadAll(resp.Body) var errorMsg string if readErr != nil { errorMsg = "Failed to read error response body" } else { errorMsg = string(bodyBytes) } // This is a validation failure with a specific HTTP status code return &CustomErrors.APIError{ HTTPStatus: resp.StatusCode, Message: fmt.Sprintf("Validation failed with status %d: %s", resp.StatusCode, errorMsg), Code: "VALIDATION_FAILED", } } // --- 异步任务方法 (全面适配新task包) --- func (s *KeyValidationService) StartTestKeysTask(groupID uint, keysText string) (*task.Status, error) { keyStrings := utils.ParseKeysFromText(keysText) if len(keyStrings) == 0 { return nil, CustomErrors.NewAPIError(CustomErrors.ErrBadRequest, "no valid keys found in input text") } apiKeyModels, err := s.keyRepo.GetKeysByValues(keyStrings) if err != nil { return nil, CustomErrors.ParseDBError(err) } if len(apiKeyModels) == 0 { return nil, CustomErrors.NewAPIError(CustomErrors.ErrResourceNotFound, "none of the provided keys were found in the system") } if err := s.keyRepo.DecryptBatch(apiKeyModels); err != nil { s.logger.WithError(err).Error("Failed to batch decrypt keys for validation task.") return nil, CustomErrors.NewAPIError(CustomErrors.ErrInternalServer, "Failed to decrypt keys for validation") } group, ok := s.groupManager.GetGroupByID(groupID) if !ok { // [FIX] Correctly use the NewAPIError constructor for a missing group. return nil, CustomErrors.NewAPIError(CustomErrors.ErrGroupNotFound, fmt.Sprintf("group with id %d not found", groupID)) } opConfig, err := s.groupManager.BuildOperationalConfig(group) if err != nil { return nil, CustomErrors.NewAPIError(CustomErrors.ErrInternalServer, fmt.Sprintf("failed to build operational config: %v", err)) } resourceID := fmt.Sprintf("group-%d", groupID) taskStatus, err := s.taskService.StartTask(groupID, TaskTypeTestKeys, resourceID, len(apiKeyModels), time.Hour) if err != nil { return nil, err // Pass up the error from task service (e.g., "task already running") } settings := s.SettingsManager.GetSettings() timeout := time.Duration(settings.KeyCheckTimeoutSeconds) * time.Second endpoint, err := s.groupManager.BuildKeyCheckEndpoint(groupID) if err != nil { s.taskService.EndTaskByID(taskStatus.ID, resourceID, nil, err) // End task with error if endpoint fails return nil, err } var concurrency int if opConfig.KeyCheckConcurrency != nil { concurrency = *opConfig.KeyCheckConcurrency } else { concurrency = settings.BaseKeyCheckConcurrency } go s.runTestKeysTask(taskStatus.ID, resourceID, groupID, apiKeyModels, timeout, endpoint, concurrency) return taskStatus, nil } func (s *KeyValidationService) runTestKeysTask(taskID string, resourceID string, groupID uint, keys []models.APIKey, timeout time.Duration, endpoint string, concurrency int) { var wg sync.WaitGroup var mu sync.Mutex finalResults := make([]models.KeyTestResult, len(keys)) processedCount := 0 if concurrency <= 0 { concurrency = 10 } type job struct { Index int Value models.APIKey } jobs := make(chan job, len(keys)) for i := 0; i < concurrency; i++ { wg.Add(1) go func() { defer wg.Done() for j := range jobs { apiKeyModel := j.Value keyToValidate := apiKeyModel validationErr := s.ValidateSingleKey(&keyToValidate, timeout, endpoint) var currentResult models.KeyTestResult event := models.RequestFinishedEvent{ RequestLog: models.RequestLog{ // GroupID 和 KeyID 在 RequestLog 模型中是指针,需要取地址 GroupID: &groupID, KeyID: &apiKeyModel.ID, }, } if validationErr == nil { currentResult = models.KeyTestResult{Key: apiKeyModel.APIKey, Status: "valid", Message: "Validation successful."} event.RequestLog.IsSuccess = true } else { var apiErr *CustomErrors.APIError if CustomErrors.As(validationErr, &apiErr) { currentResult = models.KeyTestResult{Key: apiKeyModel.APIKey, Status: "invalid", Message: fmt.Sprintf("Invalid key (HTTP %d): %s", apiErr.HTTPStatus, apiErr.Message)} event.Error = apiErr } else { currentResult = models.KeyTestResult{Key: apiKeyModel.APIKey, Status: "error", Message: "Validation check failed: " + validationErr.Error()} event.Error = &CustomErrors.APIError{Message: validationErr.Error()} } event.RequestLog.IsSuccess = false } eventData, _ := json.Marshal(event) if err := s.store.Publish(models.TopicRequestFinished, eventData); err != nil { s.logger.WithError(err).Errorf("Failed to publish RequestFinishedEvent for validation of key ID %d", apiKeyModel.ID) } mu.Lock() finalResults[j.Index] = currentResult processedCount++ _ = s.taskService.UpdateProgressByID(taskID, processedCount) mu.Unlock() } }() } for i, k := range keys { jobs <- job{Index: i, Value: k} } close(jobs) wg.Wait() s.taskService.EndTaskByID(taskID, resourceID, gin.H{"results": finalResults}, nil) } func (s *KeyValidationService) StartTestKeysByFilterTask(groupID uint, statuses []string) (*task.Status, error) { s.logger.Infof("Starting test task for group %d with status filter: %v", groupID, statuses) keyValues, err := s.keyRepo.FindKeyValuesByStatus(groupID, statuses) if err != nil { return nil, CustomErrors.ParseDBError(err) } if len(keyValues) == 0 { return nil, CustomErrors.NewAPIError(CustomErrors.ErrNotFound, "No keys found matching the filter criteria.") } keysAsText := strings.Join(keyValues, "\n") s.logger.Infof("Found %d keys to validate for group %d.", len(keyValues), groupID) return s.StartTestKeysTask(groupID, keysAsText) }