Files
gemini-banlancer/internal/service/key_validation_service.go
2025-11-21 19:33:05 +08:00

221 lines
7.8 KiB
Go

// Filename: internal/service/key_validation_service.go
package service
import (
"encoding/json"
"fmt"
"gemini-balancer/internal/channel"
CustomErrors "gemini-balancer/internal/errors"
"gemini-balancer/internal/models"
"gemini-balancer/internal/repository"
"gemini-balancer/internal/settings"
"gemini-balancer/internal/store"
"gemini-balancer/internal/task"
"gemini-balancer/internal/utils"
"io"
"net/http"
"strings"
"sync"
"time"
"github.com/gin-gonic/gin"
"github.com/sirupsen/logrus"
"gorm.io/gorm"
)
// TaskTypeTestKeys is the task-type identifier handed to the task reporter
// when a key validation run is started.
const TaskTypeTestKeys = "test_keys"
// KeyValidationService validates stored API keys against an upstream endpoint,
// either one key at a time or in bulk as a background task.
type KeyValidationService struct {
	// taskService starts validation tasks, records their progress and ends them.
	taskService task.Reporter
	// channel adjusts the outgoing validation request for a given key.
	channel channel.ChannelProxy
	// db is stored by the constructor; the methods visible in this file do not use it.
	db *gorm.DB
	// SettingsManager supplies the key-check timeout and the default concurrency.
	SettingsManager *settings.SettingsManager
	// groupManager resolves groups, their operational config and the key-check endpoint.
	groupManager *GroupManager
	// store is used to publish a request-finished event for every validated key.
	store store.Store
	// keyRepo looks up keys by value or status and decrypts them.
	keyRepo repository.KeyRepository
	// logger is tagged with this service's component name by the constructor.
	logger *logrus.Entry
}
// NewKeyValidationService wires the given collaborators into a new
// KeyValidationService. The service logger is derived from logger and carries
// the field component="KeyValidationService🧐".
func NewKeyValidationService(ts task.Reporter, ch channel.ChannelProxy, db *gorm.DB, ss *settings.SettingsManager, gm *GroupManager, st store.Store, kr repository.KeyRepository, logger *logrus.Logger) *KeyValidationService {
	svc := new(KeyValidationService)
	svc.taskService = ts
	svc.channel = ch
	svc.db = db
	svc.SettingsManager = ss
	svc.groupManager = gm
	svc.store = st
	svc.keyRepo = kr
	svc.logger = logger.WithField("component", "KeyValidationService🧐")
	return svc
}
// ValidateSingleKey sends one GET request to endpoint on behalf of key and
// reports whether the upstream accepted it.
//
// The key is decrypted in place through the key repository, and the injected
// channel is given the chance to modify the request for that key. timeout is
// set as the HTTP client timeout, which covers the whole exchange including
// reading the response body; a zero timeout means no limit.
//
// It returns nil when the response status is 200 OK. Any other status yields a
// *CustomErrors.APIError with code "VALIDATION_FAILED" that carries the status
// and at most maxBodyBytes of the response body. Decryption, request
// construction and transport failures are returned as wrapped plain errors.
func (s *KeyValidationService) ValidateSingleKey(key *models.APIKey, timeout time.Duration, endpoint string) error {
	// Upper bound on how much of a response body is read. The body comes from
	// an external service and ends up in an error message, so it must not be
	// allowed to grow without limit.
	const maxBodyBytes = 64 << 10

	if err := s.keyRepo.Decrypt(key); err != nil {
		return fmt.Errorf("failed to decrypt key %d for validation: %w", key.ID, err)
	}

	req, err := http.NewRequest(http.MethodGet, endpoint, nil)
	if err != nil {
		s.logger.Errorf("Failed to create request for key validation (ID: %d): %v", key.ID, err)
		return fmt.Errorf("failed to create request: %w", err)
	}
	// Use the injected channel to modify the request for this key.
	s.channel.ModifyRequest(req, key)

	client := &http.Client{Timeout: timeout}
	resp, err := client.Do(req)
	if err != nil {
		// Network-level failure (e.g. timeout, DNS issue), not a verdict on the key.
		return fmt.Errorf("request failed: %w", err)
	}
	defer resp.Body.Close()

	if resp.StatusCode == http.StatusOK {
		// Drain a bounded amount of the body so the transport can reuse the
		// connection. This is best effort; the key is valid either way.
		_, _ = io.Copy(io.Discard, io.LimitReader(resp.Body, maxBodyBytes))
		return nil
	}

	// Read a bounded prefix of the body for error details.
	errorMsg := "Failed to read error response body"
	if bodyBytes, readErr := io.ReadAll(io.LimitReader(resp.Body, maxBodyBytes)); readErr == nil {
		errorMsg = string(bodyBytes)
	}

	// Validation failure with a specific HTTP status code.
	return &CustomErrors.APIError{
		HTTPStatus: resp.StatusCode,
		Message:    fmt.Sprintf("Validation failed with status %d: %s", resp.StatusCode, errorMsg),
		Code:       "VALIDATION_FAILED",
	}
}
// --- Asynchronous task methods (fully adapted to the new task package) ---
// StartTestKeysTask parses keysText into key values, loads and decrypts the
// matching stored keys, and launches a background validation task for groupID.
//
// It returns the status of the newly started task. An error is returned, and no
// validation goroutine is launched, when:
//   - keysText contains no keys,
//   - none of the keys exist in the repository, or the lookup fails,
//   - batch decryption fails,
//   - the group is unknown or its operational config cannot be built,
//   - the task service refuses to start the task,
//   - the key-check endpoint cannot be built (the started task is then ended
//     with that error).
//
// Worker concurrency comes from the group's operational config when set there,
// otherwise from the global settings.
func (s *KeyValidationService) StartTestKeysTask(groupID uint, keysText string) (*task.Status, error) {
	parsed := utils.ParseKeysFromText(keysText)
	if len(parsed) == 0 {
		return nil, CustomErrors.NewAPIError(CustomErrors.ErrBadRequest, "no valid keys found in input text")
	}

	keys, err := s.keyRepo.GetKeysByValues(parsed)
	if err != nil {
		return nil, CustomErrors.ParseDBError(err)
	}
	if len(keys) == 0 {
		return nil, CustomErrors.NewAPIError(CustomErrors.ErrResourceNotFound, "none of the provided keys were found in the system")
	}

	if decryptErr := s.keyRepo.DecryptBatch(keys); decryptErr != nil {
		s.logger.WithError(decryptErr).Error("Failed to batch decrypt keys for validation task.")
		return nil, CustomErrors.NewAPIError(CustomErrors.ErrInternalServer, "Failed to decrypt keys for validation")
	}

	grp, found := s.groupManager.GetGroupByID(groupID)
	if !found {
		notFound := fmt.Sprintf("group with id %d not found", groupID)
		return nil, CustomErrors.NewAPIError(CustomErrors.ErrGroupNotFound, notFound)
	}

	opCfg, err := s.groupManager.BuildOperationalConfig(grp)
	if err != nil {
		detail := fmt.Sprintf("failed to build operational config: %v", err)
		return nil, CustomErrors.NewAPIError(CustomErrors.ErrInternalServer, detail)
	}

	resource := fmt.Sprintf("group-%d", groupID)
	status, err := s.taskService.StartTask(groupID, TaskTypeTestKeys, resource, len(keys), time.Hour)
	if err != nil {
		// Pass the task service's own error up unchanged (e.g. "task already running").
		return nil, err
	}

	cfg := s.SettingsManager.GetSettings()
	timeout := time.Duration(cfg.KeyCheckTimeoutSeconds) * time.Second

	endpoint, err := s.groupManager.BuildKeyCheckEndpoint(groupID)
	if err != nil {
		// The task is already registered, so close it out with the failure.
		s.taskService.EndTaskByID(status.ID, resource, nil, err)
		return nil, err
	}

	// Group-level override wins over the global default.
	workers := cfg.BaseKeyCheckConcurrency
	if override := opCfg.KeyCheckConcurrency; override != nil {
		workers = *override
	}

	go s.runTestKeysTask(status.ID, resource, groupID, keys, timeout, endpoint, workers)
	return status, nil
}
// runTestKeysTask validates every key in keys with a pool of worker goroutines
// and ends the task identified by taskID once all keys have been processed.
//
// Each key is validated against endpoint with the given timeout. After each key
// the task's progress counter is updated; when all workers have finished, the
// task is ended with a payload of the form {"results": [...]}, where the
// results are in the same order as keys. A non-positive concurrency falls back
// to 10, and the worker count never exceeds the number of keys.
func (s *KeyValidationService) runTestKeysTask(taskID string, resourceID string, groupID uint, keys []models.APIKey, timeout time.Duration, endpoint string, concurrency int) {
	if concurrency <= 0 {
		concurrency = 10
	}
	// Extra workers beyond the number of keys would start only to exit at once.
	if concurrency > len(keys) {
		concurrency = len(keys)
	}

	type job struct {
		index int
		key   models.APIKey
	}

	// The channel is buffered for every key, so it can be filled and closed
	// up front without blocking.
	jobs := make(chan job, len(keys))
	for i, k := range keys {
		jobs <- job{index: i, key: k}
	}
	close(jobs)

	var (
		wg        sync.WaitGroup
		mu        sync.Mutex // guards results, processed and the progress update
		processed int
	)
	results := make([]models.KeyTestResult, len(keys))

	for w := 0; w < concurrency; w++ {
		wg.Add(1)
		go func() {
			defer wg.Done()
			for j := range jobs {
				res := s.validateAndPublish(groupID, j.key, timeout, endpoint)

				mu.Lock()
				results[j.index] = res
				processed++
				// Progress reporting is best effort; a failed update must not
				// abort the validation run.
				_ = s.taskService.UpdateProgressByID(taskID, processed)
				mu.Unlock()
			}
		}()
	}

	wg.Wait()
	s.taskService.EndTaskByID(taskID, resourceID, gin.H{"results": results}, nil)
}

// validateAndPublish validates a single key, publishes a RequestFinishedEvent
// describing the attempt, and returns the per-key result for the task report.
func (s *KeyValidationService) validateAndPublish(groupID uint, key models.APIKey, timeout time.Duration, endpoint string) models.KeyTestResult {
	// Validate a copy so that whatever validation does to the key does not
	// affect the value reported in the result.
	candidate := key
	validationErr := s.ValidateSingleKey(&candidate, timeout, endpoint)

	event := models.RequestFinishedEvent{
		RequestLog: models.RequestLog{
			// GroupID and KeyID are pointers in the RequestLog model.
			GroupID: &groupID,
			KeyID:   &key.ID,
		},
	}

	var (
		result models.KeyTestResult
		apiErr *CustomErrors.APIError
	)
	switch {
	case validationErr == nil:
		result = models.KeyTestResult{Key: key.APIKey, Status: "valid", Message: "Validation successful."}
		event.RequestLog.IsSuccess = true
	case CustomErrors.As(validationErr, &apiErr):
		// The upstream answered with a non-200 status: the key itself was rejected.
		result = models.KeyTestResult{Key: key.APIKey, Status: "invalid", Message: fmt.Sprintf("Invalid key (HTTP %d): %s", apiErr.HTTPStatus, apiErr.Message)}
		event.Error = apiErr
		event.RequestLog.IsSuccess = false
	default:
		// The check could not be completed (decryption, request or transport failure).
		result = models.KeyTestResult{Key: key.APIKey, Status: "error", Message: "Validation check failed: " + validationErr.Error()}
		event.Error = &CustomErrors.APIError{Message: validationErr.Error()}
		event.RequestLog.IsSuccess = false
	}

	// Do not publish when the event cannot be serialized; an empty payload
	// would be worse than a missing event.
	payload, err := json.Marshal(event)
	if err != nil {
		s.logger.WithError(err).Errorf("Failed to marshal RequestFinishedEvent for validation of key ID %d", key.ID)
		return result
	}
	if err := s.store.Publish(models.TopicRequestFinished, payload); err != nil {
		s.logger.WithError(err).Errorf("Failed to publish RequestFinishedEvent for validation of key ID %d", key.ID)
	}
	return result
}
// StartTestKeysByFilterTask starts a validation task for the key values the
// repository returns for groupID and the given statuses.
//
// The matching key values are joined with newlines and handed to
// StartTestKeysTask, whose task status and error are returned unchanged. A
// repository failure is translated with ParseDBError; an empty match yields an
// ErrNotFound API error.
func (s *KeyValidationService) StartTestKeysByFilterTask(groupID uint, statuses []string) (*task.Status, error) {
	s.logger.Infof("Starting test task for group %d with status filter: %v", groupID, statuses)

	matched, err := s.keyRepo.FindKeyValuesByStatus(groupID, statuses)
	switch {
	case err != nil:
		return nil, CustomErrors.ParseDBError(err)
	case len(matched) == 0:
		return nil, CustomErrors.NewAPIError(CustomErrors.ErrNotFound, "No keys found matching the filter criteria.")
	}

	s.logger.Infof("Found %d keys to validate for group %d.", len(matched), groupID)
	return s.StartTestKeysTask(groupID, strings.Join(matched, "\n"))
}