217 lines
7.4 KiB
Go
217 lines
7.4 KiB
Go
// Filename: internal/service/key_validation_service.go
|
|
package service
|
|
|
|
import (
|
|
"context"
|
|
"encoding/json"
|
|
"fmt"
|
|
"gemini-balancer/internal/channel"
|
|
CustomErrors "gemini-balancer/internal/errors"
|
|
"gemini-balancer/internal/models"
|
|
"gemini-balancer/internal/repository"
|
|
"gemini-balancer/internal/settings"
|
|
"gemini-balancer/internal/store"
|
|
"gemini-balancer/internal/task"
|
|
"gemini-balancer/internal/utils"
|
|
"io"
|
|
"net/http"
|
|
"strings"
|
|
"sync"
|
|
"time"
|
|
|
|
"github.com/gin-gonic/gin"
|
|
"github.com/sirupsen/logrus"
|
|
"gorm.io/gorm"
|
|
)
|
|
|
|
// TaskTypeTestKeys is the task type passed to the task service when a
// key-validation task is started.
const TaskTypeTestKeys = "test_keys"
|
|
|
|
type KeyValidationService struct {
|
|
taskService task.Reporter
|
|
channel channel.ChannelProxy
|
|
db *gorm.DB
|
|
SettingsManager *settings.SettingsManager
|
|
groupManager *GroupManager
|
|
store store.Store
|
|
keyRepo repository.KeyRepository
|
|
logger *logrus.Entry
|
|
}
|
|
|
|
func NewKeyValidationService(ts task.Reporter, ch channel.ChannelProxy, db *gorm.DB, ss *settings.SettingsManager, gm *GroupManager, st store.Store, kr repository.KeyRepository, logger *logrus.Logger) *KeyValidationService {
|
|
return &KeyValidationService{
|
|
taskService: ts,
|
|
channel: ch,
|
|
db: db,
|
|
SettingsManager: ss,
|
|
groupManager: gm,
|
|
store: st,
|
|
keyRepo: kr,
|
|
logger: logger.WithField("component", "KeyValidationService🧐"),
|
|
}
|
|
}
|
|
|
|
func (s *KeyValidationService) ValidateSingleKey(key *models.APIKey, timeout time.Duration, endpoint string) error {
|
|
if err := s.keyRepo.Decrypt(key); err != nil {
|
|
return fmt.Errorf("failed to decrypt key %d for validation: %w", key.ID, err)
|
|
}
|
|
client := &http.Client{Timeout: timeout}
|
|
req, err := http.NewRequest("GET", endpoint, nil)
|
|
if err != nil {
|
|
s.logger.Errorf("Failed to create request for key validation (ID: %d): %v", key.ID, err)
|
|
return fmt.Errorf("failed to create request: %w", err)
|
|
}
|
|
|
|
s.channel.ModifyRequest(req, key)
|
|
|
|
resp, err := client.Do(req)
|
|
if err != nil {
|
|
return fmt.Errorf("request failed: %w", err)
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
if resp.StatusCode == http.StatusOK {
|
|
return nil
|
|
}
|
|
|
|
bodyBytes, readErr := io.ReadAll(resp.Body)
|
|
var errorMsg string
|
|
if readErr != nil {
|
|
errorMsg = "Failed to read error response body"
|
|
} else {
|
|
errorMsg = string(bodyBytes)
|
|
}
|
|
|
|
return &CustomErrors.APIError{
|
|
HTTPStatus: resp.StatusCode,
|
|
Message: fmt.Sprintf("Validation failed with status %d: %s", resp.StatusCode, errorMsg),
|
|
Code: "VALIDATION_FAILED",
|
|
}
|
|
}
|
|
|
|
func (s *KeyValidationService) StartTestKeysTask(ctx context.Context, groupID uint, keysText string) (*task.Status, error) {
|
|
keyStrings := utils.ParseKeysFromText(keysText)
|
|
if len(keyStrings) == 0 {
|
|
return nil, CustomErrors.NewAPIError(CustomErrors.ErrBadRequest, "no valid keys found in input text")
|
|
}
|
|
apiKeyModels, err := s.keyRepo.GetKeysByValues(keyStrings)
|
|
if err != nil {
|
|
return nil, CustomErrors.ParseDBError(err)
|
|
}
|
|
if len(apiKeyModels) == 0 {
|
|
return nil, CustomErrors.NewAPIError(CustomErrors.ErrResourceNotFound, "none of the provided keys were found in the system")
|
|
}
|
|
if err := s.keyRepo.DecryptBatch(apiKeyModels); err != nil {
|
|
s.logger.WithError(err).Error("Failed to batch decrypt keys for validation task.")
|
|
return nil, CustomErrors.NewAPIError(CustomErrors.ErrInternalServer, "Failed to decrypt keys for validation")
|
|
}
|
|
group, ok := s.groupManager.GetGroupByID(groupID)
|
|
if !ok {
|
|
return nil, CustomErrors.NewAPIError(CustomErrors.ErrGroupNotFound, fmt.Sprintf("group with id %d not found", groupID))
|
|
}
|
|
opConfig, err := s.groupManager.BuildOperationalConfig(group)
|
|
if err != nil {
|
|
return nil, CustomErrors.NewAPIError(CustomErrors.ErrInternalServer, fmt.Sprintf("failed to build operational config: %v", err))
|
|
}
|
|
resourceID := fmt.Sprintf("group-%d", groupID)
|
|
taskStatus, err := s.taskService.StartTask(ctx, groupID, TaskTypeTestKeys, resourceID, len(apiKeyModels), time.Hour)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
settings := s.SettingsManager.GetSettings()
|
|
timeout := time.Duration(settings.KeyCheckTimeoutSeconds) * time.Second
|
|
endpoint, err := s.groupManager.BuildKeyCheckEndpoint(groupID)
|
|
if err != nil {
|
|
s.taskService.EndTaskByID(ctx, taskStatus.ID, resourceID, nil, err)
|
|
return nil, err
|
|
}
|
|
var concurrency int
|
|
if opConfig.KeyCheckConcurrency != nil {
|
|
concurrency = *opConfig.KeyCheckConcurrency
|
|
} else {
|
|
concurrency = settings.BaseKeyCheckConcurrency
|
|
}
|
|
go s.runTestKeysTask(context.Background(), taskStatus.ID, resourceID, groupID, apiKeyModels, timeout, endpoint, concurrency)
|
|
return taskStatus, nil
|
|
}
|
|
|
|
func (s *KeyValidationService) runTestKeysTask(ctx context.Context, taskID string, resourceID string, groupID uint, keys []models.APIKey, timeout time.Duration, endpoint string, concurrency int) {
|
|
var wg sync.WaitGroup
|
|
var mu sync.Mutex
|
|
finalResults := make([]models.KeyTestResult, len(keys))
|
|
processedCount := 0
|
|
if concurrency <= 0 {
|
|
concurrency = 10
|
|
}
|
|
type job struct {
|
|
Index int
|
|
Value models.APIKey
|
|
}
|
|
jobs := make(chan job, len(keys))
|
|
for i := 0; i < concurrency; i++ {
|
|
wg.Add(1)
|
|
go func() {
|
|
defer wg.Done()
|
|
for j := range jobs {
|
|
apiKeyModel := j.Value
|
|
keyToValidate := apiKeyModel
|
|
validationErr := s.ValidateSingleKey(&keyToValidate, timeout, endpoint)
|
|
|
|
var currentResult models.KeyTestResult
|
|
event := models.RequestFinishedEvent{
|
|
RequestLog: models.RequestLog{
|
|
GroupID: &groupID,
|
|
KeyID: &apiKeyModel.ID,
|
|
},
|
|
}
|
|
if validationErr == nil {
|
|
currentResult = models.KeyTestResult{Key: apiKeyModel.APIKey, Status: "valid", Message: "Validation successful."}
|
|
event.RequestLog.IsSuccess = true
|
|
} else {
|
|
var apiErr *CustomErrors.APIError
|
|
if CustomErrors.As(validationErr, &apiErr) {
|
|
currentResult = models.KeyTestResult{Key: apiKeyModel.APIKey, Status: "invalid", Message: fmt.Sprintf("Invalid key (HTTP %d): %s", apiErr.HTTPStatus, apiErr.Message)}
|
|
event.Error = apiErr
|
|
} else {
|
|
currentResult = models.KeyTestResult{Key: apiKeyModel.APIKey, Status: "error", Message: "Validation check failed: " + validationErr.Error()}
|
|
event.Error = &CustomErrors.APIError{Message: validationErr.Error()}
|
|
}
|
|
event.RequestLog.IsSuccess = false
|
|
}
|
|
eventData, _ := json.Marshal(event)
|
|
|
|
if err := s.store.Publish(ctx, models.TopicRequestFinished, eventData); err != nil {
|
|
s.logger.WithError(err).Errorf("Failed to publish RequestFinishedEvent for validation of key ID %d", apiKeyModel.ID)
|
|
}
|
|
|
|
mu.Lock()
|
|
finalResults[j.Index] = currentResult
|
|
processedCount++
|
|
_ = s.taskService.UpdateProgressByID(ctx, taskID, processedCount)
|
|
mu.Unlock()
|
|
}
|
|
}()
|
|
}
|
|
for i, k := range keys {
|
|
jobs <- job{Index: i, Value: k}
|
|
}
|
|
close(jobs)
|
|
wg.Wait()
|
|
s.taskService.EndTaskByID(ctx, taskID, resourceID, gin.H{"results": finalResults}, nil)
|
|
}
|
|
|
|
func (s *KeyValidationService) StartTestKeysByFilterTask(ctx context.Context, groupID uint, statuses []string) (*task.Status, error) {
|
|
s.logger.Infof("Starting test task for group %d with status filter: %v", groupID, statuses)
|
|
keyValues, err := s.keyRepo.FindKeyValuesByStatus(groupID, statuses)
|
|
if err != nil {
|
|
return nil, CustomErrors.ParseDBError(err)
|
|
}
|
|
if len(keyValues) == 0 {
|
|
return nil, CustomErrors.NewAPIError(CustomErrors.ErrNotFound, "No keys found matching the filter criteria.")
|
|
}
|
|
keysAsText := strings.Join(keyValues, "\n")
|
|
s.logger.Infof("Found %d keys to validate for group %d.", len(keyValues), groupID)
|
|
return s.StartTestKeysTask(ctx, groupID, keysAsText)
|
|
}
|