Update Context for store
This commit is contained in:
@@ -2,6 +2,7 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"gemini-balancer/internal/channel"
|
||||
@@ -62,20 +63,18 @@ func (s *KeyValidationService) ValidateSingleKey(key *models.APIKey, timeout tim
|
||||
return fmt.Errorf("failed to create request: %w", err)
|
||||
}
|
||||
|
||||
s.channel.ModifyRequest(req, key) // Use the injected channel to modify the request
|
||||
s.channel.ModifyRequest(req, key)
|
||||
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
// This is a network-level error (e.g., timeout, DNS issue)
|
||||
return fmt.Errorf("request failed: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode == http.StatusOK {
|
||||
return nil // Success
|
||||
return nil
|
||||
}
|
||||
|
||||
// Read the body for more error details
|
||||
bodyBytes, readErr := io.ReadAll(resp.Body)
|
||||
var errorMsg string
|
||||
if readErr != nil {
|
||||
@@ -84,7 +83,6 @@ func (s *KeyValidationService) ValidateSingleKey(key *models.APIKey, timeout tim
|
||||
errorMsg = string(bodyBytes)
|
||||
}
|
||||
|
||||
// This is a validation failure with a specific HTTP status code
|
||||
return &CustomErrors.APIError{
|
||||
HTTPStatus: resp.StatusCode,
|
||||
Message: fmt.Sprintf("Validation failed with status %d: %s", resp.StatusCode, errorMsg),
|
||||
@@ -92,8 +90,7 @@ func (s *KeyValidationService) ValidateSingleKey(key *models.APIKey, timeout tim
|
||||
}
|
||||
}
|
||||
|
||||
// --- 异步任务方法 (全面适配新task包) ---
|
||||
func (s *KeyValidationService) StartTestKeysTask(groupID uint, keysText string) (*task.Status, error) {
|
||||
func (s *KeyValidationService) StartTestKeysTask(ctx context.Context, groupID uint, keysText string) (*task.Status, error) {
|
||||
keyStrings := utils.ParseKeysFromText(keysText)
|
||||
if len(keyStrings) == 0 {
|
||||
return nil, CustomErrors.NewAPIError(CustomErrors.ErrBadRequest, "no valid keys found in input text")
|
||||
@@ -111,7 +108,6 @@ func (s *KeyValidationService) StartTestKeysTask(groupID uint, keysText string)
|
||||
}
|
||||
group, ok := s.groupManager.GetGroupByID(groupID)
|
||||
if !ok {
|
||||
// [FIX] Correctly use the NewAPIError constructor for a missing group.
|
||||
return nil, CustomErrors.NewAPIError(CustomErrors.ErrGroupNotFound, fmt.Sprintf("group with id %d not found", groupID))
|
||||
}
|
||||
opConfig, err := s.groupManager.BuildOperationalConfig(group)
|
||||
@@ -119,15 +115,15 @@ func (s *KeyValidationService) StartTestKeysTask(groupID uint, keysText string)
|
||||
return nil, CustomErrors.NewAPIError(CustomErrors.ErrInternalServer, fmt.Sprintf("failed to build operational config: %v", err))
|
||||
}
|
||||
resourceID := fmt.Sprintf("group-%d", groupID)
|
||||
taskStatus, err := s.taskService.StartTask(groupID, TaskTypeTestKeys, resourceID, len(apiKeyModels), time.Hour)
|
||||
taskStatus, err := s.taskService.StartTask(ctx, groupID, TaskTypeTestKeys, resourceID, len(apiKeyModels), time.Hour)
|
||||
if err != nil {
|
||||
return nil, err // Pass up the error from task service (e.g., "task already running")
|
||||
return nil, err
|
||||
}
|
||||
settings := s.SettingsManager.GetSettings()
|
||||
timeout := time.Duration(settings.KeyCheckTimeoutSeconds) * time.Second
|
||||
endpoint, err := s.groupManager.BuildKeyCheckEndpoint(groupID)
|
||||
if err != nil {
|
||||
s.taskService.EndTaskByID(taskStatus.ID, resourceID, nil, err) // End task with error if endpoint fails
|
||||
s.taskService.EndTaskByID(ctx, taskStatus.ID, resourceID, nil, err)
|
||||
return nil, err
|
||||
}
|
||||
var concurrency int
|
||||
@@ -136,11 +132,11 @@ func (s *KeyValidationService) StartTestKeysTask(groupID uint, keysText string)
|
||||
} else {
|
||||
concurrency = settings.BaseKeyCheckConcurrency
|
||||
}
|
||||
go s.runTestKeysTask(taskStatus.ID, resourceID, groupID, apiKeyModels, timeout, endpoint, concurrency)
|
||||
go s.runTestKeysTask(context.Background(), taskStatus.ID, resourceID, groupID, apiKeyModels, timeout, endpoint, concurrency)
|
||||
return taskStatus, nil
|
||||
}
|
||||
|
||||
func (s *KeyValidationService) runTestKeysTask(taskID string, resourceID string, groupID uint, keys []models.APIKey, timeout time.Duration, endpoint string, concurrency int) {
|
||||
func (s *KeyValidationService) runTestKeysTask(ctx context.Context, taskID string, resourceID string, groupID uint, keys []models.APIKey, timeout time.Duration, endpoint string, concurrency int) {
|
||||
var wg sync.WaitGroup
|
||||
var mu sync.Mutex
|
||||
finalResults := make([]models.KeyTestResult, len(keys))
|
||||
@@ -165,7 +161,6 @@ func (s *KeyValidationService) runTestKeysTask(taskID string, resourceID string,
|
||||
var currentResult models.KeyTestResult
|
||||
event := models.RequestFinishedEvent{
|
||||
RequestLog: models.RequestLog{
|
||||
// GroupID 和 KeyID 在 RequestLog 模型中是指针,需要取地址
|
||||
GroupID: &groupID,
|
||||
KeyID: &apiKeyModel.ID,
|
||||
},
|
||||
@@ -185,14 +180,15 @@ func (s *KeyValidationService) runTestKeysTask(taskID string, resourceID string,
|
||||
event.RequestLog.IsSuccess = false
|
||||
}
|
||||
eventData, _ := json.Marshal(event)
|
||||
if err := s.store.Publish(models.TopicRequestFinished, eventData); err != nil {
|
||||
|
||||
if err := s.store.Publish(ctx, models.TopicRequestFinished, eventData); err != nil {
|
||||
s.logger.WithError(err).Errorf("Failed to publish RequestFinishedEvent for validation of key ID %d", apiKeyModel.ID)
|
||||
}
|
||||
|
||||
mu.Lock()
|
||||
finalResults[j.Index] = currentResult
|
||||
processedCount++
|
||||
_ = s.taskService.UpdateProgressByID(taskID, processedCount)
|
||||
_ = s.taskService.UpdateProgressByID(ctx, taskID, processedCount)
|
||||
mu.Unlock()
|
||||
}
|
||||
}()
|
||||
@@ -202,10 +198,10 @@ func (s *KeyValidationService) runTestKeysTask(taskID string, resourceID string,
|
||||
}
|
||||
close(jobs)
|
||||
wg.Wait()
|
||||
s.taskService.EndTaskByID(taskID, resourceID, gin.H{"results": finalResults}, nil)
|
||||
s.taskService.EndTaskByID(ctx, taskID, resourceID, gin.H{"results": finalResults}, nil)
|
||||
}
|
||||
|
||||
func (s *KeyValidationService) StartTestKeysByFilterTask(groupID uint, statuses []string) (*task.Status, error) {
|
||||
func (s *KeyValidationService) StartTestKeysByFilterTask(ctx context.Context, groupID uint, statuses []string) (*task.Status, error) {
|
||||
s.logger.Infof("Starting test task for group %d with status filter: %v", groupID, statuses)
|
||||
keyValues, err := s.keyRepo.FindKeyValuesByStatus(groupID, statuses)
|
||||
if err != nil {
|
||||
@@ -216,5 +212,5 @@ func (s *KeyValidationService) StartTestKeysByFilterTask(groupID uint, statuses
|
||||
}
|
||||
keysAsText := strings.Join(keyValues, "\n")
|
||||
s.logger.Infof("Found %d keys to validate for group %d.", len(keyValues), groupID)
|
||||
return s.StartTestKeysTask(groupID, keysAsText)
|
||||
return s.StartTestKeysTask(ctx, groupID, keysAsText)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user