Update Context for store

XOF
2025-11-22 14:20:05 +08:00
parent ac0e0a8275
commit 2b0b9b67dc
31 changed files with 817 additions and 1016 deletions

View File

@@ -1,8 +1,8 @@
// Filename: internal/service/analytics_service.go
package service
import (
"context"
"encoding/json"
"fmt"
"gemini-balancer/internal/db/dialect"
@@ -43,7 +43,7 @@ func NewAnalyticsService(db *gorm.DB, s store.Store, logger *logrus.Logger, d di
}
func (s *AnalyticsService) Start() {
s.wg.Add(2) // 2 (flushLoop, eventListener)
s.wg.Add(2)
go s.flushLoop()
go s.eventListener()
s.logger.Info("AnalyticsService (Command Side) started.")
@@ -53,13 +53,13 @@ func (s *AnalyticsService) Stop() {
close(s.stopChan)
s.wg.Wait()
s.logger.Info("AnalyticsService stopped. Performing final data flush...")
s.flushToDB() // flush remaining data before stopping
s.flushToDB()
s.logger.Info("AnalyticsService final data flush completed.")
}
func (s *AnalyticsService) eventListener() {
defer s.wg.Done()
sub, err := s.store.Subscribe(models.TopicRequestFinished)
sub, err := s.store.Subscribe(context.Background(), models.TopicRequestFinished)
if err != nil {
s.logger.Fatalf("Failed to subscribe to topic %s: %v", models.TopicRequestFinished, err)
return
@@ -87,9 +87,10 @@ func (s *AnalyticsService) handleAnalyticsEvent(event *models.RequestFinishedEve
if event.RequestLog.GroupID == nil {
return
}
ctx := context.Background()
key := fmt.Sprintf("analytics:hourly:%s", time.Now().UTC().Format("2006-01-02T15"))
fieldPrefix := fmt.Sprintf("%d:%s", *event.RequestLog.GroupID, event.RequestLog.ModelName)
pipe := s.store.Pipeline()
pipe := s.store.Pipeline(ctx)
pipe.HIncrBy(key, fieldPrefix+":requests", 1)
if event.RequestLog.IsSuccess {
pipe.HIncrBy(key, fieldPrefix+":success", 1)
@@ -120,6 +121,7 @@ func (s *AnalyticsService) flushLoop() {
}
func (s *AnalyticsService) flushToDB() {
ctx := context.Background()
now := time.Now().UTC()
keysToFlush := []string{
fmt.Sprintf("analytics:hourly:%s", now.Add(-1*time.Hour).Format("2006-01-02T15")),
@@ -127,7 +129,7 @@ func (s *AnalyticsService) flushToDB() {
}
for _, key := range keysToFlush {
data, err := s.store.HGetAll(key)
data, err := s.store.HGetAll(ctx, key)
if err != nil || len(data) == 0 {
continue
}
@@ -136,15 +138,15 @@ func (s *AnalyticsService) flushToDB() {
if len(statsToFlush) > 0 {
upsertClause := s.dialect.OnConflictUpdateAll(
[]string{"time", "group_id", "model_name"}, // conflict columns
[]string{"request_count", "success_count", "prompt_tokens", "completion_tokens"}, // update columns
[]string{"time", "group_id", "model_name"},
[]string{"request_count", "success_count", "prompt_tokens", "completion_tokens"},
)
err := s.db.Clauses(upsertClause).Create(&statsToFlush).Error
err := s.db.WithContext(ctx).Clauses(upsertClause).Create(&statsToFlush).Error
if err != nil {
s.logger.Errorf("Failed to flush analytics data for key %s: %v", key, err)
} else {
s.logger.Infof("Successfully flushed %d records from key %s.", len(statsToFlush), key)
_ = s.store.HDel(key, parsedFields...)
_ = s.store.HDel(ctx, key, parsedFields...)
}
}
}
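Taken together, the new call sites in this file suggest a context-aware Store interface along the following lines. This is only a sketch inferred from the diff, not the actual definition in internal/store; the Subscription and Pipeliner shapes and the Exec method are assumptions.

// Sketch of the Store surface implied by the call sites in this commit.
// Method names match the diff; the auxiliary types are assumptions.
package store

import "context"

type Message struct {
	Payload []byte
}

type Subscription interface {
	Channel() <-chan *Message // assumed accessor used by the event loops
	Close() error
}

type Pipeliner interface {
	HIncrBy(key, field string, incr int64) // context is bound at Pipeline(ctx)
	Exec() error                           // assumed flush method
}

type Store interface {
	Subscribe(ctx context.Context, topic string) (Subscription, error)
	Publish(ctx context.Context, topic string, payload []byte) error
	Pipeline(ctx context.Context) Pipeliner
	HGetAll(ctx context.Context, key string) (map[string]string, error)
	HDel(ctx context.Context, key string, fields ...string) error
}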

View File

@@ -1,8 +1,8 @@
// Filename: internal/service/apikey_service.go
package service
import (
"context"
"encoding/json"
"errors"
"fmt"
@@ -29,7 +29,6 @@ const (
TaskTypeUpdateStatusByFilter = "update_status_by_filter"
)
// DTOs & Constants
const (
TestEndpoint = "https://generativelanguage.googleapis.com/v1beta/models"
)
@@ -83,7 +82,6 @@ func NewAPIKeyService(
gm *GroupManager,
logger *logrus.Logger,
) *APIKeyService {
logger.Debugf("[FORENSIC PROBE | INJECTION | APIKeyService] Received a KeyRepository. Fingerprint: %p", repo)
return &APIKeyService{
db: db,
keyRepo: repo,
@@ -99,22 +97,22 @@ func NewAPIKeyService(
}
func (s *APIKeyService) Start() {
requestSub, err := s.store.Subscribe(models.TopicRequestFinished)
requestSub, err := s.store.Subscribe(context.Background(), models.TopicRequestFinished)
if err != nil {
s.logger.Fatalf("Failed to subscribe to topic %s: %v", models.TopicRequestFinished, err)
return
}
masterKeySub, err := s.store.Subscribe(models.TopicMasterKeyStatusChanged)
masterKeySub, err := s.store.Subscribe(context.Background(), models.TopicMasterKeyStatusChanged)
if err != nil {
s.logger.Fatalf("Failed to subscribe to topic %s: %v", models.TopicMasterKeyStatusChanged, err)
return
}
keyStatusSub, err := s.store.Subscribe(models.TopicKeyStatusChanged)
keyStatusSub, err := s.store.Subscribe(context.Background(), models.TopicKeyStatusChanged)
if err != nil {
s.logger.Fatalf("Failed to subscribe to topic %s: %v", models.TopicKeyStatusChanged, err)
return
}
importSub, err := s.store.Subscribe(models.TopicImportGroupCompleted)
importSub, err := s.store.Subscribe(context.Background(), models.TopicImportGroupCompleted)
if err != nil {
s.logger.Fatalf("Failed to subscribe to topic %s: %v", models.TopicImportGroupCompleted, err)
return
@@ -177,6 +175,7 @@ func (s *APIKeyService) handleKeyUsageEvent(event *models.RequestFinishedEvent)
if event.RequestLog.KeyID == nil || event.RequestLog.GroupID == nil {
return
}
ctx := context.Background()
if event.RequestLog.IsSuccess {
mapping, err := s.keyRepo.GetMapping(*event.RequestLog.GroupID, *event.RequestLog.KeyID)
if err != nil {
@@ -194,17 +193,18 @@ func (s *APIKeyService) handleKeyUsageEvent(event *models.RequestFinishedEvent)
now := time.Now()
mapping.LastUsedAt = &now
if err := s.keyRepo.UpdateMapping(mapping); err != nil {
if err := s.keyRepo.UpdateMapping(ctx, mapping); err != nil {
s.logger.Errorf("[%s] Failed to update mapping for G:%d K:%d after successful request: %v", event.CorrelationID, *event.RequestLog.GroupID, *event.RequestLog.KeyID, err)
return
}
if statusChanged {
go s.publishStatusChangeEvent(*event.RequestLog.GroupID, *event.RequestLog.KeyID, oldStatus, models.StatusActive, "key_recovered_after_use")
go s.publishStatusChangeEvent(ctx, *event.RequestLog.GroupID, *event.RequestLog.KeyID, oldStatus, models.StatusActive, "key_recovered_after_use")
}
return
}
if event.Error != nil {
s.judgeKeyErrors(
ctx,
event.CorrelationID,
*event.RequestLog.GroupID,
*event.RequestLog.KeyID,
@@ -215,6 +215,7 @@ func (s *APIKeyService) handleKeyUsageEvent(event *models.RequestFinishedEvent)
}
func (s *APIKeyService) handleKeyStatusChangeEvent(event *models.KeyStatusChangedEvent) {
ctx := context.Background()
log := s.logger.WithFields(logrus.Fields{
"group_id": event.GroupID,
"key_id": event.KeyID,
@@ -222,11 +223,11 @@ func (s *APIKeyService) handleKeyStatusChangeEvent(event *models.KeyStatusChange
"reason": event.ChangeReason,
})
log.Info("Received KeyStatusChangedEvent, will update polling caches.")
s.keyRepo.HandleCacheUpdateEvent(event.GroupID, event.KeyID, event.NewStatus)
s.keyRepo.HandleCacheUpdateEvent(ctx, event.GroupID, event.KeyID, event.NewStatus)
log.Info("Polling caches updated based on health check event.")
}
func (s *APIKeyService) publishStatusChangeEvent(groupID, keyID uint, oldStatus, newStatus models.APIKeyStatus, reason string) {
func (s *APIKeyService) publishStatusChangeEvent(ctx context.Context, groupID, keyID uint, oldStatus, newStatus models.APIKeyStatus, reason string) {
changeEvent := models.KeyStatusChangedEvent{
KeyID: keyID,
GroupID: groupID,
@@ -236,13 +237,12 @@ func (s *APIKeyService) publishStatusChangeEvent(groupID, keyID uint, oldStatus,
ChangedAt: time.Now(),
}
eventData, _ := json.Marshal(changeEvent)
if err := s.store.Publish(models.TopicKeyStatusChanged, eventData); err != nil {
if err := s.store.Publish(ctx, models.TopicKeyStatusChanged, eventData); err != nil {
s.logger.Errorf("Failed to publish key status changed event for group %d: %v", groupID, err)
}
}
func (s *APIKeyService) ListAPIKeys(params *models.APIKeyQueryParams) (*PaginatedAPIKeys, error) {
// --- Path 1: High-performance DB pagination (no keyword) ---
func (s *APIKeyService) ListAPIKeys(ctx context.Context, params *models.APIKeyQueryParams) (*PaginatedAPIKeys, error) {
if params.Keyword == "" {
items, total, err := s.keyRepo.GetPaginatedKeysAndMappingsByGroup(params)
if err != nil {
@@ -260,14 +260,12 @@ func (s *APIKeyService) ListAPIKeys(params *models.APIKeyQueryParams) (*Paginate
TotalPages: totalPages,
}, nil
}
// --- Path 2: In-memory search (keyword present) ---
s.logger.Infof("Performing heavy in-memory search for group %d with keyword '%s'", params.KeyGroupID, params.Keyword)
// To get all keys, we fetch all IDs first, then get their full details.
var statusesToFilter []string
if params.Status != "" {
statusesToFilter = append(statusesToFilter, params.Status)
} else {
statusesToFilter = append(statusesToFilter, "all") // "all" gets every status
statusesToFilter = append(statusesToFilter, "all")
}
allKeyIDs, err := s.keyRepo.FindKeyIDsByStatus(params.KeyGroupID, statusesToFilter)
if err != nil {
@@ -277,14 +275,12 @@ func (s *APIKeyService) ListAPIKeys(params *models.APIKeyQueryParams) (*Paginate
return &PaginatedAPIKeys{Items: []*models.APIKeyDetails{}, Total: 0, Page: 1, PageSize: params.PageSize, TotalPages: 0}, nil
}
// This is the heavy operation: getting all keys and decrypting them.
allKeys, err := s.keyRepo.GetKeysByIDs(allKeyIDs)
if err != nil {
return nil, fmt.Errorf("failed to fetch all key details for in-memory search: %w", err)
}
// We also need mappings to build the final `APIKeyDetails`.
var allMappings []models.GroupAPIKeyMapping
err = s.db.Where("key_group_id = ? AND api_key_id IN ?", params.KeyGroupID, allKeyIDs).Find(&allMappings).Error
err = s.db.WithContext(ctx).Where("key_group_id = ? AND api_key_id IN ?", params.KeyGroupID, allKeyIDs).Find(&allMappings).Error
if err != nil {
return nil, fmt.Errorf("failed to fetch mappings for in-memory search: %w", err)
}
@@ -292,7 +288,6 @@ func (s *APIKeyService) ListAPIKeys(params *models.APIKeyQueryParams) (*Paginate
for i := range allMappings {
mappingMap[allMappings[i].APIKeyID] = &allMappings[i]
}
// Filter the results in memory.
var filteredItems []*models.APIKeyDetails
for _, key := range allKeys {
if strings.Contains(key.APIKey, params.Keyword) {
@@ -312,11 +307,9 @@ func (s *APIKeyService) ListAPIKeys(params *models.APIKeyQueryParams) (*Paginate
}
}
}
// Sort the filtered results to ensure consistent pagination (by ID descending).
sort.Slice(filteredItems, func(i, j int) bool {
return filteredItems[i].ID > filteredItems[j].ID
})
// Manually paginate the filtered results.
total := int64(len(filteredItems))
start := (params.Page - 1) * params.PageSize
end := start + params.PageSize
@@ -345,14 +338,15 @@ func (s *APIKeyService) ListAPIKeys(params *models.APIKeyQueryParams) (*Paginate
}, nil
}
func (s *APIKeyService) GetKeysByIds(ids []uint) ([]models.APIKey, error) {
func (s *APIKeyService) GetKeysByIds(ctx context.Context, ids []uint) ([]models.APIKey, error) {
return s.keyRepo.GetKeysByIDs(ids)
}
func (s *APIKeyService) UpdateAPIKey(key *models.APIKey) error {
func (s *APIKeyService) UpdateAPIKey(ctx context.Context, key *models.APIKey) error {
go func() {
bgCtx := context.Background()
var oldKey models.APIKey
if err := s.db.First(&oldKey, key.ID).Error; err != nil {
if err := s.db.WithContext(bgCtx).First(&oldKey, key.ID).Error; err != nil {
s.logger.Errorf("Failed to find old key state for ID %d before update: %v", key.ID, err)
return
}
@@ -364,16 +358,14 @@ func (s *APIKeyService) UpdateAPIKey(key *models.APIKey) error {
return nil
}
func (s *APIKeyService) HardDeleteAPIKeyByID(id uint) error {
// Get all associated groups before deletion to publish correct events
groups, err := s.keyRepo.GetGroupsForKey(id)
func (s *APIKeyService) HardDeleteAPIKeyByID(ctx context.Context, id uint) error {
groups, err := s.keyRepo.GetGroupsForKey(ctx, id)
if err != nil {
s.logger.Warnf("Could not get groups for key ID %d before hard deletion (key might be orphaned): %v", id, err)
}
err = s.keyRepo.HardDeleteByID(id)
if err == nil {
// Publish events for each group the key was a part of
for _, groupID := range groups {
event := models.KeyStatusChangedEvent{
KeyID: id,
@@ -381,13 +373,13 @@ func (s *APIKeyService) HardDeleteAPIKeyByID(id uint) error {
ChangeReason: "key_hard_deleted",
}
eventData, _ := json.Marshal(event)
go s.store.Publish(models.TopicKeyStatusChanged, eventData)
go s.store.Publish(context.Background(), models.TopicKeyStatusChanged, eventData)
}
}
return err
}
func (s *APIKeyService) UpdateMappingStatus(groupID, keyID uint, newStatus models.APIKeyStatus) (*models.GroupAPIKeyMapping, error) {
func (s *APIKeyService) UpdateMappingStatus(ctx context.Context, groupID, keyID uint, newStatus models.APIKeyStatus) (*models.GroupAPIKeyMapping, error) {
key, err := s.keyRepo.GetKeyByID(keyID)
if err != nil {
return nil, CustomErrors.ParseDBError(err)
@@ -409,19 +401,20 @@ func (s *APIKeyService) UpdateMappingStatus(groupID, keyID uint, newStatus model
mapping.ConsecutiveErrorCount = 0
mapping.LastError = ""
}
if err := s.keyRepo.UpdateMapping(mapping); err != nil {
if err := s.keyRepo.UpdateMapping(ctx, mapping); err != nil {
return nil, err
}
go s.publishStatusChangeEvent(groupID, keyID, oldStatus, newStatus, "manual_update")
go s.publishStatusChangeEvent(ctx, groupID, keyID, oldStatus, newStatus, "manual_update")
return mapping, nil
}
func (s *APIKeyService) handleMasterKeyStatusChangeEvent(event *models.MasterKeyStatusChangedEvent) {
ctx := context.Background()
s.logger.Infof("Received MasterKeyStatusChangedEvent for Key ID %d: %s -> %s", event.KeyID, event.OldMasterStatus, event.NewMasterStatus)
if event.NewMasterStatus != models.MasterStatusRevoked {
return
}
affectedGroupIDs, err := s.keyRepo.GetGroupsForKey(event.KeyID)
affectedGroupIDs, err := s.keyRepo.GetGroupsForKey(ctx, event.KeyID)
if err != nil {
s.logger.WithError(err).Errorf("Failed to get groups for key ID %d to propagate revoked status.", event.KeyID)
return
@@ -432,7 +425,7 @@ func (s *APIKeyService) handleMasterKeyStatusChangeEvent(event *models.MasterKey
}
s.logger.Infof("Propagating REVOKED astatus for Key ID %d to %d groups.", event.KeyID, len(affectedGroupIDs))
for _, groupID := range affectedGroupIDs {
_, err := s.UpdateMappingStatus(groupID, event.KeyID, models.StatusBanned)
_, err := s.UpdateMappingStatus(ctx, groupID, event.KeyID, models.StatusBanned)
if err != nil {
if !errors.Is(err, gorm.ErrRecordNotFound) {
s.logger.WithError(err).Errorf("Failed to update mapping status to BANNED for Key ID %d in Group ID %d.", event.KeyID, groupID)
@@ -441,32 +434,32 @@ func (s *APIKeyService) handleMasterKeyStatusChangeEvent(event *models.MasterKey
}
}
func (s *APIKeyService) StartRestoreKeysTask(groupID uint, keyIDs []uint) (*task.Status, error) {
func (s *APIKeyService) StartRestoreKeysTask(ctx context.Context, groupID uint, keyIDs []uint) (*task.Status, error) {
if len(keyIDs) == 0 {
return nil, CustomErrors.NewAPIError(CustomErrors.ErrBadRequest, "No key IDs provided for restoration.")
}
resourceID := fmt.Sprintf("group-%d", groupID)
taskStatus, err := s.taskService.StartTask(groupID, TaskTypeRestoreSpecificKeys, resourceID, len(keyIDs), time.Hour)
taskStatus, err := s.taskService.StartTask(ctx, groupID, TaskTypeRestoreSpecificKeys, resourceID, len(keyIDs), time.Hour)
if err != nil {
return nil, err
}
go s.runRestoreKeysTask(taskStatus.ID, resourceID, groupID, keyIDs)
go s.runRestoreKeysTask(ctx, taskStatus.ID, resourceID, groupID, keyIDs)
return taskStatus, nil
}
func (s *APIKeyService) runRestoreKeysTask(taskID string, resourceID string, groupID uint, keyIDs []uint) {
func (s *APIKeyService) runRestoreKeysTask(ctx context.Context, taskID string, resourceID string, groupID uint, keyIDs []uint) {
defer func() {
if r := recover(); r != nil {
s.logger.Errorf("Panic recovered in runRestoreKeysTask: %v", r)
s.taskService.EndTaskByID(taskID, resourceID, nil, fmt.Errorf("panic in restore keys task: %v", r))
s.taskService.EndTaskByID(ctx, taskID, resourceID, nil, fmt.Errorf("panic in restore keys task: %v", r))
}
}()
var mappingsToProcess []models.GroupAPIKeyMapping
err := s.db.Preload("APIKey").
err := s.db.WithContext(ctx).Preload("APIKey").
Where("key_group_id = ? AND api_key_id IN ?", groupID, keyIDs).
Find(&mappingsToProcess).Error
if err != nil {
s.taskService.EndTaskByID(taskID, resourceID, nil, err)
s.taskService.EndTaskByID(ctx, taskID, resourceID, nil, err)
return
}
result := &BatchRestoreResult{
@@ -476,7 +469,7 @@ func (s *APIKeyService) runRestoreKeysTask(taskID string, resourceID string, gro
processedCount := 0
for _, mapping := range mappingsToProcess {
processedCount++
_ = s.taskService.UpdateProgressByID(taskID, processedCount)
_ = s.taskService.UpdateProgressByID(ctx, taskID, processedCount)
if mapping.APIKey == nil {
result.SkippedCount++
result.SkippedKeys = append(result.SkippedKeys, SkippedKeyInfo{KeyID: mapping.APIKeyID, Reason: "Associated APIKey entity not found."})
@@ -492,33 +485,29 @@ func (s *APIKeyService) runRestoreKeysTask(taskID string, resourceID string, gro
mapping.Status = models.StatusActive
mapping.ConsecutiveErrorCount = 0
mapping.LastError = ""
// Use the version that doesn't trigger individual cache updates.
if err := s.keyRepo.UpdateMappingWithoutCache(&mapping); err != nil {
s.logger.WithError(err).Warn("Failed to update mapping in DB during restore task.")
result.SkippedCount++
result.SkippedKeys = append(result.SkippedKeys, SkippedKeyInfo{KeyID: mapping.APIKeyID, Reason: "DB update failed."})
} else {
result.RestoredCount++
successfulMappings = append(successfulMappings, &mapping) // Collect for batch cache update.
go s.publishStatusChangeEvent(groupID, mapping.APIKeyID, oldStatus, models.StatusActive, "batch_restore")
successfulMappings = append(successfulMappings, &mapping)
go s.publishStatusChangeEvent(ctx, groupID, mapping.APIKeyID, oldStatus, models.StatusActive, "batch_restore")
}
} else {
result.RestoredCount++ // Already active, count as success.
result.RestoredCount++
}
}
// After the loop, perform one single, efficient cache update.
if err := s.keyRepo.HandleCacheUpdateEventBatch(successfulMappings); err != nil {
if err := s.keyRepo.HandleCacheUpdateEventBatch(ctx, successfulMappings); err != nil {
s.logger.WithError(err).Error("Failed to perform batch cache update after restore task.")
// This is not a task-fatal error, so we just log it and continue.
}
// Account for keys that were requested but not found in the initial DB query.
result.SkippedCount += (len(keyIDs) - len(mappingsToProcess))
s.taskService.EndTaskByID(taskID, resourceID, result, nil)
s.taskService.EndTaskByID(ctx, taskID, resourceID, result, nil)
}
func (s *APIKeyService) StartRestoreAllBannedTask(groupID uint) (*task.Status, error) {
func (s *APIKeyService) StartRestoreAllBannedTask(ctx context.Context, groupID uint) (*task.Status, error) {
var bannedKeyIDs []uint
err := s.db.Model(&models.GroupAPIKeyMapping{}).
err := s.db.WithContext(ctx).Model(&models.GroupAPIKeyMapping{}).
Where("key_group_id = ? AND status = ?", groupID, models.StatusBanned).
Pluck("api_key_id", &bannedKeyIDs).Error
if err != nil {
@@ -527,10 +516,11 @@ func (s *APIKeyService) StartRestoreAllBannedTask(groupID uint) (*task.Status, e
if len(bannedKeyIDs) == 0 {
return nil, CustomErrors.NewAPIError(CustomErrors.ErrBadRequest, "No banned keys to restore in this group.")
}
return s.StartRestoreKeysTask(groupID, bannedKeyIDs)
return s.StartRestoreKeysTask(ctx, groupID, bannedKeyIDs)
}
func (s *APIKeyService) handlePostImportValidation(event *models.ImportGroupCompletedEvent) {
ctx := context.Background()
group, ok := s.groupManager.GetGroupByID(event.GroupID)
if !ok {
s.logger.Errorf("Group with id %d not found during post-import validation, aborting.", event.GroupID)
@@ -552,7 +542,7 @@ func (s *APIKeyService) handlePostImportValidation(event *models.ImportGroupComp
concurrency = *opConfig.KeyCheckConcurrency
}
if concurrency <= 0 {
concurrency = 10 // Safety fallback
concurrency = 10
}
timeout := time.Duration(globalSettings.KeyCheckTimeoutSeconds) * time.Second
keysToValidate, err := s.keyRepo.GetKeysByIDs(event.KeyIDs)
@@ -571,7 +561,7 @@ func (s *APIKeyService) handlePostImportValidation(event *models.ImportGroupComp
validationErr := s.validationService.ValidateSingleKey(&key, timeout, endpoint)
if validationErr == nil {
s.logger.Infof("Key ID %d PASSED validation. Status -> ACTIVE.", key.ID)
if _, err := s.UpdateMappingStatus(event.GroupID, key.ID, models.StatusActive); err != nil {
if _, err := s.UpdateMappingStatus(ctx, event.GroupID, key.ID, models.StatusActive); err != nil {
s.logger.Errorf("Failed to update status to ACTIVE for Key ID %d in group %d: %v", key.ID, event.GroupID, err)
}
} else {
@@ -579,7 +569,7 @@ func (s *APIKeyService) handlePostImportValidation(event *models.ImportGroupComp
if !CustomErrors.As(validationErr, &apiErr) {
apiErr = &CustomErrors.APIError{Message: validationErr.Error()}
}
s.judgeKeyErrors("", event.GroupID, key.ID, apiErr, false)
s.judgeKeyErrors(ctx, "", event.GroupID, key.ID, apiErr, false)
}
}
}()
@@ -592,12 +582,9 @@ func (s *APIKeyService) handlePostImportValidation(event *models.ImportGroupComp
s.logger.Infof("Finished post-import validation for group %d.", event.GroupID)
}
// [NEW] StartUpdateStatusByFilterTask starts a background job to update the status of keys
// that match a specific set of source statuses within a group.
func (s *APIKeyService) StartUpdateStatusByFilterTask(groupID uint, sourceStatuses []string, newStatus models.APIKeyStatus) (*task.Status, error) {
func (s *APIKeyService) StartUpdateStatusByFilterTask(ctx context.Context, groupID uint, sourceStatuses []string, newStatus models.APIKeyStatus) (*task.Status, error) {
s.logger.Infof("Starting task to update status to '%s' for keys in group %d matching statuses: %v", newStatus, groupID, sourceStatuses)
// 1. Find key IDs using the new repository method. Using IDs is more efficient for updates.
keyIDs, err := s.keyRepo.FindKeyIDsByStatus(groupID, sourceStatuses)
if err != nil {
return nil, CustomErrors.ParseDBError(err)
@@ -605,35 +592,32 @@ func (s *APIKeyService) StartUpdateStatusByFilterTask(groupID uint, sourceStatus
if len(keyIDs) == 0 {
now := time.Now()
return &task.Status{
IsRunning: false, // The "task" is not running.
IsRunning: false,
Processed: 0,
Total: 0,
Result: map[string]string{ // We use the flexible Result field to pass the message.
Result: map[string]string{
"message": "没有找到任何符合当前过滤条件的Key可供操作。",
},
Error: "", // There is no error.
Error: "",
StartedAt: now,
FinishedAt: &now, // It started and finished at the same time.
}, nil // Return nil for the error, signaling a 200 OK.
FinishedAt: &now,
}, nil
}
// 2. Start a new task using the TaskService, following existing patterns.
resourceID := fmt.Sprintf("group-%d-status-update", groupID)
taskStatus, err := s.taskService.StartTask(groupID, TaskTypeUpdateStatusByFilter, resourceID, len(keyIDs), 30*time.Minute)
taskStatus, err := s.taskService.StartTask(ctx, groupID, TaskTypeUpdateStatusByFilter, resourceID, len(keyIDs), 30*time.Minute)
if err != nil {
return nil, err // Pass up errors like "task already in progress".
return nil, err
}
// 3. Run the core logic in a separate goroutine.
go s.runUpdateStatusByFilterTask(taskStatus.ID, resourceID, groupID, keyIDs, newStatus)
go s.runUpdateStatusByFilterTask(ctx, taskStatus.ID, resourceID, groupID, keyIDs, newStatus)
return taskStatus, nil
}
// [NEW] runUpdateStatusByFilterTask is the private worker function for the above task.
func (s *APIKeyService) runUpdateStatusByFilterTask(taskID, resourceID string, groupID uint, keyIDs []uint, newStatus models.APIKeyStatus) {
func (s *APIKeyService) runUpdateStatusByFilterTask(ctx context.Context, taskID, resourceID string, groupID uint, keyIDs []uint, newStatus models.APIKeyStatus) {
defer func() {
if r := recover(); r != nil {
s.logger.Errorf("Panic recovered in runUpdateStatusByFilterTask: %v", r)
s.taskService.EndTaskByID(taskID, resourceID, nil, fmt.Errorf("panic in update status task: %v", r))
s.taskService.EndTaskByID(ctx, taskID, resourceID, nil, fmt.Errorf("panic in update status task: %v", r))
}
}()
type BatchUpdateResult struct {
@@ -642,31 +626,27 @@ func (s *APIKeyService) runUpdateStatusByFilterTask(taskID, resourceID string, g
}
result := &BatchUpdateResult{}
var successfulMappings []*models.GroupAPIKeyMapping
// 1. Fetch all key master statuses in one go. This is efficient.
keys, err := s.keyRepo.GetKeysByIDs(keyIDs)
if err != nil {
s.logger.WithError(err).Error("Failed to fetch keys for bulk status update, aborting task.")
s.taskService.EndTaskByID(taskID, resourceID, nil, err)
s.taskService.EndTaskByID(ctx, taskID, resourceID, nil, err)
return
}
masterStatusMap := make(map[uint]models.MasterAPIKeyStatus)
for _, key := range keys {
masterStatusMap[key.ID] = key.MasterStatus
}
// 2. [THE REFINEMENT] Fetch all relevant mappings directly using s.db,
// avoiding the need for a new repository method. This pattern is
// already used in other parts of this service.
var mappings []*models.GroupAPIKeyMapping
if err := s.db.Where("key_group_id = ? AND api_key_id IN ?", groupID, keyIDs).Find(&mappings).Error; err != nil {
if err := s.db.WithContext(ctx).Where("key_group_id = ? AND api_key_id IN ?", groupID, keyIDs).Find(&mappings).Error; err != nil {
s.logger.WithError(err).Error("Failed to fetch mappings for bulk status update, aborting task.")
s.taskService.EndTaskByID(taskID, resourceID, nil, err)
s.taskService.EndTaskByID(ctx, taskID, resourceID, nil, err)
return
}
processedCount := 0
for _, mapping := range mappings {
processedCount++
// The progress update should reflect the number of items *being processed*, not the final count.
_ = s.taskService.UpdateProgressByID(taskID, processedCount)
_ = s.taskService.UpdateProgressByID(ctx, taskID, processedCount)
masterStatus, ok := masterStatusMap[mapping.APIKeyID]
if !ok {
result.SkippedCount++
@@ -688,24 +668,25 @@ func (s *APIKeyService) runUpdateStatusByFilterTask(taskID, resourceID string, g
} else {
result.UpdatedCount++
successfulMappings = append(successfulMappings, mapping)
go s.publishStatusChangeEvent(groupID, mapping.APIKeyID, oldStatus, newStatus, "batch_status_update")
go s.publishStatusChangeEvent(ctx, groupID, mapping.APIKeyID, oldStatus, newStatus, "batch_status_update")
}
} else {
result.UpdatedCount++ // Already in desired state, count as success.
result.UpdatedCount++
}
}
result.SkippedCount += (len(keyIDs) - len(mappings))
if err := s.keyRepo.HandleCacheUpdateEventBatch(successfulMappings); err != nil {
if err := s.keyRepo.HandleCacheUpdateEventBatch(ctx, successfulMappings); err != nil {
s.logger.WithError(err).Error("Failed to perform batch cache update after status update task.")
}
s.logger.Infof("Finished bulk status update task '%s' for group %d. Updated: %d, Skipped: %d.", taskID, groupID, result.UpdatedCount, result.SkippedCount)
s.taskService.EndTaskByID(taskID, resourceID, result, nil)
s.taskService.EndTaskByID(ctx, taskID, resourceID, result, nil)
}
func (s *APIKeyService) HandleRequestResult(group *models.KeyGroup, key *models.APIKey, success bool, apiErr *CustomErrors.APIError) {
ctx := context.Background()
if success {
if group.PollingStrategy == models.StrategyWeighted {
go s.keyRepo.UpdateKeyUsageTimestamp(group.ID, key.ID)
go s.keyRepo.UpdateKeyUsageTimestamp(ctx, group.ID, key.ID)
}
return
}
@@ -716,26 +697,20 @@ func (s *APIKeyService) HandleRequestResult(group *models.KeyGroup, key *models.
errMsg := apiErr.Message
if CustomErrors.IsPermanentUpstreamError(errMsg) || CustomErrors.IsTemporaryUpstreamError(errMsg) {
s.logger.Warnf("Request for KeyID %d in GroupID %d failed with a key-specific error: %s. Issuing order to move to cooldown.", key.ID, group.ID, errMsg)
go s.keyRepo.SyncKeyStatusInPollingCaches(group.ID, key.ID, models.StatusCooldown)
go s.keyRepo.SyncKeyStatusInPollingCaches(ctx, group.ID, key.ID, models.StatusCooldown)
} else {
s.logger.Infof("Request for KeyID %d in GroupID %d failed with a non-key-specific error. No punitive action taken. Error: %s", key.ID, group.ID, errMsg)
}
}
// sanitizeForLog truncates long error messages and removes JSON-like content for cleaner Info/Error logs.
func sanitizeForLog(errMsg string) string {
// Find the start of any potential JSON blob or detailed structure.
jsonStartIndex := strings.Index(errMsg, "{")
var cleanMsg string
if jsonStartIndex != -1 {
// If a '{' is found, take everything before it as the summary
// and append a simple placeholder.
cleanMsg = strings.TrimSpace(errMsg[:jsonStartIndex]) + " {...}"
} else {
// If no JSON-like structure is found, use the original message.
cleanMsg = errMsg
}
// Always apply a final length truncation as a safeguard against extremely long non-JSON errors.
const maxLen = 250
if len(cleanMsg) > maxLen {
return cleanMsg[:maxLen] + "..."
@@ -744,6 +719,7 @@ func sanitizeForLog(errMsg string) string {
}
func (s *APIKeyService) judgeKeyErrors(
ctx context.Context,
correlationID string,
groupID, keyID uint,
apiErr *CustomErrors.APIError,
@@ -765,11 +741,11 @@ func (s *APIKeyService) judgeKeyErrors(
oldStatus := mapping.Status
mapping.Status = models.StatusBanned
mapping.LastError = errorMessage
if err := s.keyRepo.UpdateMapping(mapping); err != nil {
if err := s.keyRepo.UpdateMapping(ctx, mapping); err != nil {
logger.WithError(err).Error("Failed to update mapping status to BANNED.")
} else {
go s.publishStatusChangeEvent(groupID, keyID, oldStatus, mapping.Status, "permanent_error_banned")
go s.revokeMasterKey(keyID, "permanent_upstream_error")
go s.publishStatusChangeEvent(ctx, groupID, keyID, oldStatus, mapping.Status, "permanent_error_banned")
go s.revokeMasterKey(ctx, keyID, "permanent_upstream_error")
}
}
return
@@ -801,23 +777,23 @@ func (s *APIKeyService) judgeKeyErrors(
if oldStatus != newStatus {
mapping.Status = newStatus
}
if err := s.keyRepo.UpdateMapping(mapping); err != nil {
if err := s.keyRepo.UpdateMapping(ctx, mapping); err != nil {
logger.WithError(err).Error("Failed to update mapping after temporary error.")
return
}
if oldStatus != newStatus {
go s.publishStatusChangeEvent(groupID, keyID, oldStatus, newStatus, "error_threshold_reached")
go s.publishStatusChangeEvent(ctx, groupID, keyID, oldStatus, newStatus, "error_threshold_reached")
}
return
}
logger.Infof("Ignored truly ignorable upstream error. Only updating LastUsedAt. Reason: %s", sanitizeForLog(errorMessage))
logger.WithField("full_error_details", errorMessage).Debug("Full details of the ignorable error.")
if err := s.keyRepo.UpdateMapping(mapping); err != nil {
if err := s.keyRepo.UpdateMapping(ctx, mapping); err != nil {
logger.WithError(err).Error("Failed to update LastUsedAt for ignorable error.")
}
}
func (s *APIKeyService) revokeMasterKey(keyID uint, reason string) {
func (s *APIKeyService) revokeMasterKey(ctx context.Context, keyID uint, reason string) {
key, err := s.keyRepo.GetKeyByID(keyID)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
@@ -832,7 +808,7 @@ func (s *APIKeyService) revokeMasterKey(keyID uint, reason string) {
}
oldMasterStatus := key.MasterStatus
newMasterStatus := models.MasterStatusRevoked
if err := s.keyRepo.UpdateMasterStatusByID(keyID, newMasterStatus); err != nil {
if err := s.keyRepo.UpdateAPIKeyStatus(ctx, keyID, newMasterStatus); err != nil {
s.logger.Errorf("Failed to update master status to REVOKED for key ID %d: %v", keyID, err)
return
}
@@ -844,9 +820,9 @@ func (s *APIKeyService) revokeMasterKey(keyID uint, reason string) {
ChangedAt: time.Now(),
}
eventData, _ := json.Marshal(masterKeyEvent)
_ = s.store.Publish(models.TopicMasterKeyStatusChanged, eventData)
_ = s.store.Publish(ctx, models.TopicMasterKeyStatusChanged, eventData)
}
func (s *APIKeyService) GetAPIKeyStringsForExport(groupID uint, statuses []string) ([]string, error) {
func (s *APIKeyService) GetAPIKeyStringsForExport(ctx context.Context, groupID uint, statuses []string) ([]string, error) {
return s.keyRepo.GetKeyStringsByGroupAndStatus(groupID, statuses)
}
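Since ListAPIKeys, UpdateAPIKey, the task starters and the other service methods now take a context.Context as their first argument, HTTP handlers are expected to hand down the request-scoped context. A hypothetical Gin handler sketch (the handler type, route binding and error shape are illustrative, not part of this commit):

package handler

import (
	"net/http"

	"gemini-balancer/internal/models"
	"gemini-balancer/internal/service"

	"github.com/gin-gonic/gin"
)

type APIKeyHandler struct {
	apiKeyService *service.APIKeyService
}

func (h *APIKeyHandler) List(c *gin.Context) {
	var params models.APIKeyQueryParams
	if err := c.ShouldBindQuery(&params); err != nil {
		c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
		return
	}
	// The request-scoped context now flows into the store and GORM calls,
	// carrying cancellation and deadlines with it.
	result, err := h.apiKeyService.ListAPIKeys(c.Request.Context(), &params)
	if err != nil {
		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
		return
	}
	c.JSON(http.StatusOK, result)
}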

View File

@@ -1,8 +1,8 @@
// Filename: internal/service/dashboard_query_service.go
package service
import (
"context"
"fmt"
"gemini-balancer/internal/models"
"gemini-balancer/internal/store"
@@ -17,8 +17,6 @@ import (
const overviewCacheChannel = "syncer:cache:dashboard_overview"
// DashboardQueryService handles all dashboard data queries for the frontend.
type DashboardQueryService struct {
db *gorm.DB
store store.Store
@@ -54,9 +52,9 @@ func (s *DashboardQueryService) Stop() {
s.logger.Info("DashboardQueryService and its CacheSyncer have been stopped.")
}
func (s *DashboardQueryService) GetGroupStats(groupID uint) (map[string]any, error) {
func (s *DashboardQueryService) GetGroupStats(ctx context.Context, groupID uint) (map[string]any, error) {
statsKey := fmt.Sprintf("stats:group:%d", groupID)
keyStatsMap, err := s.store.HGetAll(statsKey)
keyStatsMap, err := s.store.HGetAll(ctx, statsKey)
if err != nil {
s.logger.WithError(err).Errorf("Failed to get key stats from cache for group %d", groupID)
return nil, fmt.Errorf("failed to get key stats from cache: %w", err)
@@ -74,11 +72,11 @@ func (s *DashboardQueryService) GetGroupStats(groupID uint) (map[string]any, err
SuccessRequests int64
}
var last1Hour, last24Hours requestStatsResult
s.db.Model(&models.StatsHourly{}).
s.db.WithContext(ctx).Model(&models.StatsHourly{}).
Select("SUM(request_count) as total_requests, SUM(success_count) as success_requests").
Where("group_id = ? AND time >= ?", groupID, oneHourAgo).
Scan(&last1Hour)
s.db.Model(&models.StatsHourly{}).
s.db.WithContext(ctx).Model(&models.StatsHourly{}).
Select("SUM(request_count) as total_requests, SUM(success_count) as success_requests").
Where("group_id = ? AND time >= ?", groupID, twentyFourHoursAgo).
Scan(&last24Hours)
@@ -109,8 +107,9 @@ func (s *DashboardQueryService) GetGroupStats(groupID uint) (map[string]any, err
}
func (s *DashboardQueryService) eventListener() {
keyStatusSub, _ := s.store.Subscribe(models.TopicKeyStatusChanged)
upstreamStatusSub, _ := s.store.Subscribe(models.TopicUpstreamHealthChanged)
ctx := context.Background()
keyStatusSub, _ := s.store.Subscribe(ctx, models.TopicKeyStatusChanged)
upstreamStatusSub, _ := s.store.Subscribe(ctx, models.TopicUpstreamHealthChanged)
defer keyStatusSub.Close()
defer upstreamStatusSub.Close()
for {
@@ -128,7 +127,6 @@ func (s *DashboardQueryService) eventListener() {
}
}
// GetDashboardOverviewData fetches dashboard overview data at high speed from the Syncer cache.
func (s *DashboardQueryService) GetDashboardOverviewData() (*models.DashboardStatsResponse, error) {
cachedDataPtr := s.overviewSyncer.Get()
if cachedDataPtr == nil {
@@ -141,8 +139,7 @@ func (s *DashboardQueryService) InvalidateOverviewCache() error {
return s.overviewSyncer.Invalidate()
}
// QueryHistoricalChart queries historical chart data.
func (s *DashboardQueryService) QueryHistoricalChart(groupID *uint) (*models.ChartData, error) {
func (s *DashboardQueryService) QueryHistoricalChart(ctx context.Context, groupID *uint) (*models.ChartData, error) {
type ChartPoint struct {
TimeLabel string `gorm:"column:time_label"`
ModelName string `gorm:"column:model_name"`
@@ -151,7 +148,7 @@ func (s *DashboardQueryService) QueryHistoricalChart(groupID *uint) (*models.Cha
sevenDaysAgo := time.Now().Add(-24 * 7 * time.Hour).Truncate(time.Hour)
sqlFormat, goFormat := s.buildTimeFormatSelectClause()
selectClause := fmt.Sprintf("%s as time_label, model_name, SUM(request_count) as total_requests", sqlFormat)
query := s.db.Model(&models.StatsHourly{}).Select(selectClause).Where("time >= ?", sevenDaysAgo).Group("time_label, model_name").Order("time_label ASC")
query := s.db.WithContext(ctx).Model(&models.StatsHourly{}).Select(selectClause).Where("time >= ?", sevenDaysAgo).Group("time_label, model_name").Order("time_label ASC")
if groupID != nil && *groupID > 0 {
query = query.Where("group_id = ?", *groupID)
}
@@ -189,38 +186,38 @@ func (s *DashboardQueryService) QueryHistoricalChart(groupID *uint) (*models.Cha
}
func (s *DashboardQueryService) loadOverviewData() (*models.DashboardStatsResponse, error) {
ctx := context.Background()
s.logger.Info("[CacheSyncer] Starting to load overview data from database...")
startTime := time.Now()
resp := &models.DashboardStatsResponse{
KeyStatusCount: make(map[models.APIKeyStatus]int64),
MasterStatusCount: make(map[models.MasterAPIKeyStatus]int64),
KeyCount: models.StatCard{}, // ensure KeyCount is an empty struct rather than nil
RequestCount24h: models.StatCard{}, // same as above
KeyCount: models.StatCard{},
RequestCount24h: models.StatCard{},
TokenCount: make(map[string]any),
UpstreamHealthStatus: make(map[string]string),
RPM: models.StatCard{},
RequestCounts: make(map[string]int64),
}
// --- 1. Aggregate Operational Status from Mappings ---
type MappingStatusResult struct {
Status models.APIKeyStatus
Count int64
}
var mappingStatusResults []MappingStatusResult
if err := s.db.Model(&models.GroupAPIKeyMapping{}).Select("status, count(*) as count").Group("status").Find(&mappingStatusResults).Error; err != nil {
if err := s.db.WithContext(ctx).Model(&models.GroupAPIKeyMapping{}).Select("status, count(*) as count").Group("status").Find(&mappingStatusResults).Error; err != nil {
return nil, fmt.Errorf("failed to query mapping status stats: %w", err)
}
for _, res := range mappingStatusResults {
resp.KeyStatusCount[res.Status] = res.Count
}
// --- 2. Aggregate Master Status from APIKeys ---
type MasterStatusResult struct {
Status models.MasterAPIKeyStatus
Count int64
}
var masterStatusResults []MasterStatusResult
if err := s.db.Model(&models.APIKey{}).Select("master_status as status, count(*) as count").Group("master_status").Find(&masterStatusResults).Error; err != nil {
if err := s.db.WithContext(ctx).Model(&models.APIKey{}).Select("master_status as status, count(*) as count").Group("master_status").Find(&masterStatusResults).Error; err != nil {
return nil, fmt.Errorf("failed to query master status stats: %w", err)
}
var totalKeys, invalidKeys int64
@@ -235,20 +232,15 @@ func (s *DashboardQueryService) loadOverviewData() (*models.DashboardStatsRespon
now := time.Now()
// 1. RPM (1 minute), RPH (1 hour), RPD (today): queried precisely from "short-term memory" (request_logs)
var count1m, count1h, count1d int64
// RPM: the last 1 minute from now
s.db.Model(&models.RequestLog{}).Where("request_time >= ?", now.Add(-1*time.Minute)).Count(&count1m)
// RPH: the last 1 hour from now
s.db.Model(&models.RequestLog{}).Where("request_time >= ?", now.Add(-1*time.Hour)).Count(&count1h)
// RPD: from midnight today (UTC) until now
s.db.WithContext(ctx).Model(&models.RequestLog{}).Where("request_time >= ?", now.Add(-1*time.Minute)).Count(&count1m)
s.db.WithContext(ctx).Model(&models.RequestLog{}).Where("request_time >= ?", now.Add(-1*time.Hour)).Count(&count1h)
year, month, day := now.UTC().Date()
startOfDay := time.Date(year, month, day, 0, 0, 0, 0, time.UTC)
s.db.Model(&models.RequestLog{}).Where("request_time >= ?", startOfDay).Count(&count1d)
// 2. RP30D (30 days): queried efficiently from "long-term memory" (stats_hourly) to preserve performance
s.db.WithContext(ctx).Model(&models.RequestLog{}).Where("request_time >= ?", startOfDay).Count(&count1d)
var count30d int64
s.db.Model(&models.StatsHourly{}).Where("time >= ?", now.AddDate(0, 0, -30)).Select("COALESCE(SUM(request_count), 0)").Scan(&count30d)
s.db.WithContext(ctx).Model(&models.StatsHourly{}).Where("time >= ?", now.AddDate(0, 0, -30)).Select("COALESCE(SUM(request_count), 0)").Scan(&count30d)
resp.RequestCounts["1m"] = count1m
resp.RequestCounts["1h"] = count1h
@@ -256,7 +248,7 @@ func (s *DashboardQueryService) loadOverviewData() (*models.DashboardStatsRespon
resp.RequestCounts["30d"] = count30d
var upstreams []*models.UpstreamEndpoint
if err := s.db.Find(&upstreams).Error; err != nil {
if err := s.db.WithContext(ctx).Find(&upstreams).Error; err != nil {
s.logger.WithError(err).Warn("Failed to load upstream statuses for dashboard.")
} else {
for _, u := range upstreams {
@@ -269,7 +261,7 @@ func (s *DashboardQueryService) loadOverviewData() (*models.DashboardStatsRespon
return resp, nil
}
func (s *DashboardQueryService) GetRequestStatsForPeriod(period string) (gin.H, error) {
func (s *DashboardQueryService) GetRequestStatsForPeriod(ctx context.Context, period string) (gin.H, error) {
var startTime time.Time
now := time.Now()
switch period {
@@ -288,7 +280,7 @@ func (s *DashboardQueryService) GetRequestStatsForPeriod(period string) (gin.H,
Success int64
}
err := s.db.Model(&models.RequestLog{}).
err := s.db.WithContext(ctx).Model(&models.RequestLog{}).
Select("count(*) as total, sum(case when is_success = true then 1 else 0 end) as success").
Where("request_time >= ?", startTime).
Scan(&result).Error
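With GetGroupStats and QueryHistoricalChart accepting a context, callers can also bound these reads with a deadline, which WithContext propagates to the SQL driver and which the store calls observe. A minimal sketch, assuming a five-second budget chosen purely for illustration:

package handler

import (
	"context"
	"time"

	"gemini-balancer/internal/service"
)

// getGroupStatsWithDeadline is illustrative only: it shows how the new
// context parameter lets a caller cap how long the cache and DB reads may run.
func getGroupStatsWithDeadline(parent context.Context, svc *service.DashboardQueryService, groupID uint) (map[string]any, error) {
	ctx, cancel := context.WithTimeout(parent, 5*time.Second)
	defer cancel()
	return svc.GetGroupStats(ctx, groupID)
}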

View File

@@ -1,8 +1,8 @@
// Filename: internal/service/db_log_writer_service.go
package service
import (
"context"
"encoding/json"
"gemini-balancer/internal/models"
"gemini-balancer/internal/settings"
@@ -35,35 +35,30 @@ func NewDBLogWriterService(db *gorm.DB, s store.Store, settings *settings.Settin
store: s,
SettingsManager: settings,
logger: logger.WithField("component", "DBLogWriter📝"),
// create the buffer using the configured capacity
logBuffer: make(chan *models.RequestLog, bufferCapacity),
stopChan: make(chan struct{}),
logBuffer: make(chan *models.RequestLog, bufferCapacity),
stopChan: make(chan struct{}),
}
}
func (s *DBLogWriterService) Start() {
s.wg.Add(2) // one for the event listener, one for the database writer
// start the event listener
s.wg.Add(2)
go s.eventListenerLoop()
// start the database writer
go s.dbWriterLoop()
s.logger.Info("DBLogWriterService started.")
}
func (s *DBLogWriterService) Stop() {
s.logger.Info("DBLogWriterService stopping...")
close(s.stopChan) // signal all goroutines to stop
s.wg.Wait() // wait for all goroutines to finish
close(s.stopChan)
s.wg.Wait()
s.logger.Info("DBLogWriterService stopped.")
}
// eventListenerLoop receives events from the store and places them into the in-memory buffer
func (s *DBLogWriterService) eventListenerLoop() {
defer s.wg.Done()
sub, err := s.store.Subscribe(models.TopicRequestFinished)
ctx := context.Background()
sub, err := s.store.Subscribe(ctx, models.TopicRequestFinished)
if err != nil {
s.logger.Fatalf("Failed to subscribe to topic %s: %v", models.TopicRequestFinished, err)
return
@@ -80,34 +75,27 @@ func (s *DBLogWriterService) eventListenerLoop() {
s.logger.Errorf("Failed to unmarshal event for logging: %v", err)
continue
}
// place the log portion of the event into the buffer
select {
case s.logBuffer <- &event.RequestLog:
default:
s.logger.Warn("Log buffer is full. A log message might be dropped.")
}
case <-s.stopChan:
s.logger.Info("Event listener loop stopping.")
// close the buffer so dbWriterLoop drains the remaining logs and exits
close(s.logBuffer)
return
}
}
}
// dbWriterLoop reads logs from the in-memory buffer in batches and writes them to the database
func (s *DBLogWriterService) dbWriterLoop() {
defer s.wg.Done()
// read the configuration once at startup
cfg := s.SettingsManager.GetSettings()
batchSize := cfg.LogFlushBatchSize
if batchSize <= 0 {
batchSize = 100
}
flushTimeout := time.Duration(cfg.LogFlushIntervalSeconds) * time.Second
if flushTimeout <= 0 {
flushTimeout = 5 * time.Second
@@ -126,7 +114,7 @@ func (s *DBLogWriterService) dbWriterLoop() {
return
}
batch = append(batch, logEntry)
if len(batch) >= batchSize { // use the configured batch size
if len(batch) >= batchSize {
s.flushBatch(batch)
batch = make([]*models.RequestLog, 0, batchSize)
}
@@ -139,7 +127,6 @@ func (s *DBLogWriterService) dbWriterLoop() {
}
}
// flushBatch writes one batch of logs to the database
func (s *DBLogWriterService) flushBatch(batch []*models.RequestLog) {
if err := s.db.CreateInBatches(batch, len(batch)).Error; err != nil {
s.logger.WithField("batch_size", len(batch)).WithError(err).Error("Failed to flush log batch to database.")

View File

@@ -75,7 +75,7 @@ func NewHealthCheckService(
func (s *HealthCheckService) Start() {
s.logger.Info("Starting HealthCheckService with independent check loops...")
s.wg.Add(4) // Now four loops
s.wg.Add(4)
go s.runKeyCheckLoop()
go s.runUpstreamCheckLoop()
go s.runProxyCheckLoop()
@@ -102,8 +102,6 @@ func (s *HealthCheckService) GetLastHealthCheckResults() map[string]string {
func (s *HealthCheckService) runKeyCheckLoop() {
defer s.wg.Done()
s.logger.Info("Key check dynamic scheduler loop started.")
// main scheduling loop; checks for due work once a minute
ticker := time.NewTicker(1 * time.Minute)
defer ticker.Stop()
@@ -126,26 +124,22 @@ func (s *HealthCheckService) scheduleKeyChecks() {
defer s.groupCheckTimeMutex.Unlock()
for _, group := range groups {
// fetch the group-specific operational config
opConfig, err := s.groupManager.BuildOperationalConfig(group)
if err != nil {
s.logger.WithError(err).WithField("group_id", group.ID).Error("Failed to build operational config for group, skipping health check scheduling.")
continue
}
if opConfig.EnableKeyCheck == nil || !*opConfig.EnableKeyCheck {
continue // skip groups with health checks disabled
continue
}
var intervalMinutes int
if opConfig.KeyCheckIntervalMinutes != nil {
intervalMinutes = *opConfig.KeyCheckIntervalMinutes
}
interval := time.Duration(intervalMinutes) * time.Minute
if interval <= 0 {
continue // skip invalid check intervals
continue
}
if nextCheckTime, ok := s.groupNextCheckTime[group.ID]; !ok || now.After(nextCheckTime) {
s.logger.Infof("Scheduling key check for group '%s' (ID: %d)", group.Name, group.ID)
go s.performKeyChecksForGroup(group, opConfig)
@@ -160,7 +154,6 @@ func (s *HealthCheckService) runUpstreamCheckLoop() {
if s.SettingsManager.GetSettings().EnableUpstreamCheck {
s.performUpstreamChecks()
}
ticker := time.NewTicker(time.Duration(s.SettingsManager.GetSettings().UpstreamCheckIntervalSeconds) * time.Second)
defer ticker.Stop()
@@ -184,7 +177,6 @@ func (s *HealthCheckService) runProxyCheckLoop() {
if s.SettingsManager.GetSettings().EnableProxyCheck {
s.performProxyChecks()
}
ticker := time.NewTicker(time.Duration(s.SettingsManager.GetSettings().ProxyCheckIntervalSeconds) * time.Second)
defer ticker.Stop()
@@ -203,6 +195,7 @@ func (s *HealthCheckService) runProxyCheckLoop() {
}
func (s *HealthCheckService) performKeyChecksForGroup(group *models.KeyGroup, opConfig *models.KeyGroupSettings) {
ctx := context.Background()
settings := s.SettingsManager.GetSettings()
timeout := time.Duration(settings.KeyCheckTimeoutSeconds) * time.Second
@@ -213,11 +206,9 @@ func (s *HealthCheckService) performKeyChecksForGroup(group *models.KeyGroup, op
}
log := s.logger.WithFields(logrus.Fields{"group_id": group.ID, "group_name": group.Name})
log.Infof("Starting key health check cycle.")
var mappingsToCheck []models.GroupAPIKeyMapping
err = s.db.Model(&models.GroupAPIKeyMapping{}).
err = s.db.WithContext(ctx).Model(&models.GroupAPIKeyMapping{}).
Joins("JOIN api_keys ON api_keys.id = group_api_key_mappings.api_key_id").
Where("group_api_key_mappings.key_group_id = ?", group.ID).
Where("api_keys.master_status = ?", models.MasterStatusActive).
@@ -233,7 +224,6 @@ func (s *HealthCheckService) performKeyChecksForGroup(group *models.KeyGroup, op
log.Info("No key mappings to check for this group.")
return
}
log.Infof("Starting health check for %d key mappings.", len(mappingsToCheck))
jobs := make(chan models.GroupAPIKeyMapping, len(mappingsToCheck))
var wg sync.WaitGroup
@@ -242,14 +232,14 @@ func (s *HealthCheckService) performKeyChecksForGroup(group *models.KeyGroup, op
concurrency = *opConfig.KeyCheckConcurrency
}
if concurrency <= 0 {
concurrency = 1 // guarantee at least one worker
concurrency = 1
}
for w := 1; w <= concurrency; w++ {
wg.Add(1)
go func(workerID int) {
defer wg.Done()
for mapping := range jobs {
s.checkAndProcessMapping(&mapping, timeout, endpoint)
s.checkAndProcessMapping(ctx, &mapping, timeout, endpoint)
}
}(w)
}
@@ -261,52 +251,46 @@ func (s *HealthCheckService) performKeyChecksForGroup(group *models.KeyGroup, op
log.Info("Finished key health check cycle.")
}
func (s *HealthCheckService) checkAndProcessMapping(mapping *models.GroupAPIKeyMapping, timeout time.Duration, endpoint string) {
func (s *HealthCheckService) checkAndProcessMapping(ctx context.Context, mapping *models.GroupAPIKeyMapping, timeout time.Duration, endpoint string) {
if mapping.APIKey == nil {
s.logger.Warnf("Skipping check for mapping (G:%d, K:%d) because associated APIKey is nil.", mapping.KeyGroupID, mapping.APIKeyID)
return
}
validationErr := s.keyValidationService.ValidateSingleKey(mapping.APIKey, timeout, endpoint)
// --- Case 1: validation succeeded (healthy) ---
if validationErr == nil {
if mapping.Status != models.StatusActive {
s.activateMapping(mapping)
s.activateMapping(ctx, mapping)
}
return
}
errorString := validationErr.Error()
// --- Case 2: permanent error ---
if CustomErrors.IsPermanentUpstreamError(errorString) {
s.revokeMapping(mapping, validationErr)
s.revokeMapping(ctx, mapping, validationErr)
return
}
// --- Case 3: temporary error ---
if CustomErrors.IsTemporaryUpstreamError(errorString) {
// Log with a higher level (WARN) since this is an actionable, proactive finding.
s.logger.Warnf("Health check for key %d failed with temporary error, applying penalty. Reason: %v", mapping.APIKeyID, validationErr)
s.penalizeMapping(mapping, validationErr)
s.penalizeMapping(ctx, mapping, validationErr)
return
}
// --- Case 4: other unknown or upstream service errors ---
s.logger.Errorf("Health check for key %d failed with a transient or unknown upstream error: %v. Mapping will not be penalized by health check.", mapping.APIKeyID, validationErr)
}
func (s *HealthCheckService) activateMapping(mapping *models.GroupAPIKeyMapping) {
func (s *HealthCheckService) activateMapping(ctx context.Context, mapping *models.GroupAPIKeyMapping) {
oldStatus := mapping.Status
mapping.Status = models.StatusActive
mapping.ConsecutiveErrorCount = 0
mapping.LastError = ""
if err := s.keyRepo.UpdateMapping(mapping); err != nil {
if err := s.keyRepo.UpdateMapping(ctx, mapping); err != nil {
s.logger.WithError(err).Errorf("Failed to activate mapping (G:%d, K:%d).", mapping.KeyGroupID, mapping.APIKeyID)
return
}
s.logger.Infof("Mapping (G:%d, K:%d) successfully activated from %s to ACTIVE.", mapping.KeyGroupID, mapping.APIKeyID, oldStatus)
s.publishKeyStatusChangedEvent(mapping.KeyGroupID, mapping.APIKeyID, oldStatus, mapping.Status)
s.publishKeyStatusChangedEvent(ctx, mapping.KeyGroupID, mapping.APIKeyID, oldStatus, mapping.Status)
}
func (s *HealthCheckService) penalizeMapping(mapping *models.GroupAPIKeyMapping, err error) {
// Re-fetch group-specific operational config to get the correct thresholds
func (s *HealthCheckService) penalizeMapping(ctx context.Context, mapping *models.GroupAPIKeyMapping, err error) {
group, ok := s.groupManager.GetGroupByID(mapping.KeyGroupID)
if !ok {
s.logger.Errorf("Could not find group with ID %d to apply penalty.", mapping.KeyGroupID)
@@ -320,7 +304,6 @@ func (s *HealthCheckService) penalizeMapping(mapping *models.GroupAPIKeyMapping,
oldStatus := mapping.Status
mapping.LastError = err.Error()
mapping.ConsecutiveErrorCount++
// Use the group-specific threshold
threshold := *opConfig.KeyBlacklistThreshold
if mapping.ConsecutiveErrorCount >= threshold {
mapping.Status = models.StatusCooldown
@@ -329,44 +312,41 @@ func (s *HealthCheckService) penalizeMapping(mapping *models.GroupAPIKeyMapping,
mapping.CooldownUntil = &cooldownTime
s.logger.Warnf("Mapping (G:%d, K:%d) reached error threshold (%d) during health check and is now in COOLDOWN for %v.", mapping.KeyGroupID, mapping.APIKeyID, threshold, cooldownDuration)
}
if errDb := s.keyRepo.UpdateMapping(mapping); errDb != nil {
if errDb := s.keyRepo.UpdateMapping(ctx, mapping); errDb != nil {
s.logger.WithError(errDb).Errorf("Failed to penalize mapping (G:%d, K:%d).", mapping.KeyGroupID, mapping.APIKeyID)
return
}
if oldStatus != mapping.Status {
s.publishKeyStatusChangedEvent(mapping.KeyGroupID, mapping.APIKeyID, oldStatus, mapping.Status)
s.publishKeyStatusChangedEvent(ctx, mapping.KeyGroupID, mapping.APIKeyID, oldStatus, mapping.Status)
}
}
func (s *HealthCheckService) revokeMapping(mapping *models.GroupAPIKeyMapping, err error) {
func (s *HealthCheckService) revokeMapping(ctx context.Context, mapping *models.GroupAPIKeyMapping, err error) {
oldStatus := mapping.Status
if oldStatus == models.StatusBanned {
return // Already banned, do nothing.
return
}
mapping.Status = models.StatusBanned
mapping.LastError = "Definitive error: " + err.Error()
mapping.ConsecutiveErrorCount = 0 // Reset counter as this is a final state for this group
if errDb := s.keyRepo.UpdateMapping(mapping); errDb != nil {
mapping.ConsecutiveErrorCount = 0
if errDb := s.keyRepo.UpdateMapping(ctx, mapping); errDb != nil {
s.logger.WithError(errDb).Errorf("Failed to revoke mapping (G:%d, K:%d).", mapping.KeyGroupID, mapping.APIKeyID)
return
}
s.logger.Warnf("Mapping (G:%d, K:%d) has been BANNED due to a definitive error: %v", mapping.KeyGroupID, mapping.APIKeyID, err)
s.publishKeyStatusChangedEvent(mapping.KeyGroupID, mapping.APIKeyID, oldStatus, mapping.Status)
s.publishKeyStatusChangedEvent(ctx, mapping.KeyGroupID, mapping.APIKeyID, oldStatus, mapping.Status)
s.logger.Infof("Triggering MasterStatus update for definitively failed key ID %d.", mapping.APIKeyID)
if err := s.keyRepo.UpdateAPIKeyStatus(mapping.APIKeyID, models.MasterStatusRevoked); err != nil {
if err := s.keyRepo.UpdateAPIKeyStatus(ctx, mapping.APIKeyID, models.MasterStatusRevoked); err != nil {
s.logger.WithError(err).Errorf("Failed to update master status for key ID %d after group-level ban.", mapping.APIKeyID)
}
}
func (s *HealthCheckService) performUpstreamChecks() {
ctx := context.Background()
settings := s.SettingsManager.GetSettings()
timeout := time.Duration(settings.UpstreamCheckTimeoutSeconds) * time.Second
var upstreams []*models.UpstreamEndpoint
if err := s.db.Find(&upstreams).Error; err != nil {
if err := s.db.WithContext(ctx).Find(&upstreams).Error; err != nil {
s.logger.WithError(err).Error("Failed to retrieve upstreams.")
return
}
@@ -390,10 +370,10 @@ func (s *HealthCheckService) performUpstreamChecks() {
s.lastResultsMutex.Unlock()
if oldStatus != newStatus {
s.logger.Infof("Upstream '%s' status changed from %s -> %s", upstream.URL, oldStatus, newStatus)
if err := s.db.Model(upstream).Update("status", newStatus).Error; err != nil {
if err := s.db.WithContext(ctx).Model(upstream).Update("status", newStatus).Error; err != nil {
s.logger.WithError(err).WithField("upstream_id", upstream.ID).Error("Failed to update upstream status.")
} else {
s.publishUpstreamHealthChangedEvent(upstream, oldStatus, newStatus)
s.publishUpstreamHealthChangedEvent(ctx, upstream, oldStatus, newStatus)
}
}
}(u)
@@ -412,10 +392,11 @@ func (s *HealthCheckService) checkEndpoint(urlStr string, timeout time.Duration)
}
func (s *HealthCheckService) performProxyChecks() {
ctx := context.Background()
settings := s.SettingsManager.GetSettings()
timeout := time.Duration(settings.ProxyCheckTimeoutSeconds) * time.Second
var proxies []*models.ProxyConfig
if err := s.db.Find(&proxies).Error; err != nil {
if err := s.db.WithContext(ctx).Find(&proxies).Error; err != nil {
s.logger.WithError(err).Error("Failed to retrieve proxies.")
return
}
@@ -438,7 +419,7 @@ func (s *HealthCheckService) performProxyChecks() {
s.lastResultsMutex.Unlock()
if proxyCfg.Status != newStatus {
s.logger.Infof("Proxy '%s' status changed from %s -> %s", proxyCfg.Address, proxyCfg.Status, newStatus)
if err := s.db.Model(proxyCfg).Update("status", newStatus).Error; err != nil {
if err := s.db.WithContext(ctx).Model(proxyCfg).Update("status", newStatus).Error; err != nil {
s.logger.WithError(err).WithField("proxy_id", proxyCfg.ID).Error("Failed to update proxy status.")
}
}
@@ -482,7 +463,7 @@ func (s *HealthCheckService) checkProxy(proxyCfg *models.ProxyConfig, timeout ti
return true
}
func (s *HealthCheckService) publishKeyStatusChangedEvent(groupID, keyID uint, oldStatus, newStatus models.APIKeyStatus) {
func (s *HealthCheckService) publishKeyStatusChangedEvent(ctx context.Context, groupID, keyID uint, oldStatus, newStatus models.APIKeyStatus) {
event := models.KeyStatusChangedEvent{
KeyID: keyID,
GroupID: groupID,
@@ -496,12 +477,12 @@ func (s *HealthCheckService) publishKeyStatusChangedEvent(groupID, keyID uint, o
s.logger.WithError(err).Errorf("Failed to marshal KeyStatusChangedEvent for group %d.", groupID)
return
}
if err := s.store.Publish(models.TopicKeyStatusChanged, payload); err != nil {
if err := s.store.Publish(ctx, models.TopicKeyStatusChanged, payload); err != nil {
s.logger.WithError(err).Errorf("Failed to publish KeyStatusChangedEvent for group %d.", groupID)
}
}
func (s *HealthCheckService) publishUpstreamHealthChangedEvent(upstream *models.UpstreamEndpoint, oldStatus, newStatus string) {
func (s *HealthCheckService) publishUpstreamHealthChangedEvent(ctx context.Context, upstream *models.UpstreamEndpoint, oldStatus, newStatus string) {
event := models.UpstreamHealthChangedEvent{
UpstreamID: upstream.ID,
UpstreamURL: upstream.URL,
@@ -516,28 +497,20 @@ func (s *HealthCheckService) publishUpstreamHealthChangedEvent(upstream *models.
s.logger.WithError(err).Error("Failed to marshal UpstreamHealthChangedEvent.")
return
}
if err := s.store.Publish(models.TopicUpstreamHealthChanged, payload); err != nil {
if err := s.store.Publish(ctx, models.TopicUpstreamHealthChanged, payload); err != nil {
s.logger.WithError(err).Error("Failed to publish UpstreamHealthChangedEvent.")
}
}
// =========================================================================
// Global Base Key Check (New Logic)
// =========================================================================
func (s *HealthCheckService) runBaseKeyCheckLoop() {
defer s.wg.Done()
s.logger.Info("Global base key check loop started.")
settings := s.SettingsManager.GetSettings()
if !settings.EnableBaseKeyCheck {
s.logger.Info("Global base key check is disabled.")
return
}
// Perform an initial check on startup
s.performBaseKeyChecks()
interval := time.Duration(settings.BaseKeyCheckIntervalMinutes) * time.Minute
if interval <= 0 {
s.logger.Warnf("Invalid BaseKeyCheckIntervalMinutes: %d. Disabling base key check loop.", settings.BaseKeyCheckIntervalMinutes)
@@ -558,6 +531,7 @@ func (s *HealthCheckService) runBaseKeyCheckLoop() {
}
func (s *HealthCheckService) performBaseKeyChecks() {
ctx := context.Background()
s.logger.Info("Starting global base key check cycle.")
settings := s.SettingsManager.GetSettings()
timeout := time.Duration(settings.KeyCheckTimeoutSeconds) * time.Second
@@ -576,7 +550,7 @@ func (s *HealthCheckService) performBaseKeyChecks() {
jobs := make(chan *models.APIKey, len(keys))
var wg sync.WaitGroup
if concurrency <= 0 {
concurrency = 5 // Safe default
concurrency = 5
}
for w := 0; w < concurrency; w++ {
wg.Add(1)
@@ -587,10 +561,10 @@ func (s *HealthCheckService) performBaseKeyChecks() {
if err != nil && CustomErrors.IsPermanentUpstreamError(err.Error()) {
oldStatus := key.MasterStatus
s.logger.Warnf("Key ID %d (%s...) failed definitive base check. Setting MasterStatus to REVOKED. Reason: %v", key.ID, key.APIKey[:4], err)
if updateErr := s.keyRepo.UpdateAPIKeyStatus(key.ID, models.MasterStatusRevoked); updateErr != nil {
if updateErr := s.keyRepo.UpdateAPIKeyStatus(ctx, key.ID, models.MasterStatusRevoked); updateErr != nil {
s.logger.WithError(updateErr).Errorf("Failed to update master status for key ID %d.", key.ID)
} else {
s.publishMasterKeyStatusChangedEvent(key.ID, oldStatus, models.MasterStatusRevoked)
s.publishMasterKeyStatusChangedEvent(ctx, key.ID, oldStatus, models.MasterStatusRevoked)
}
}
}
@@ -604,8 +578,7 @@ func (s *HealthCheckService) performBaseKeyChecks() {
s.logger.Info("Global base key check cycle finished.")
}
// Event publishing helper
func (s *HealthCheckService) publishMasterKeyStatusChangedEvent(keyID uint, oldStatus, newStatus models.MasterAPIKeyStatus) {
func (s *HealthCheckService) publishMasterKeyStatusChangedEvent(ctx context.Context, keyID uint, oldStatus, newStatus models.MasterAPIKeyStatus) {
event := models.MasterKeyStatusChangedEvent{
KeyID: keyID,
OldMasterStatus: oldStatus,
@@ -618,7 +591,7 @@ func (s *HealthCheckService) publishMasterKeyStatusChangedEvent(keyID uint, oldS
s.logger.WithError(err).Errorf("Failed to marshal MasterKeyStatusChangedEvent for key %d.", keyID)
return
}
if err := s.store.Publish(models.TopicMasterKeyStatusChanged, payload); err != nil {
if err := s.store.Publish(ctx, models.TopicMasterKeyStatusChanged, payload); err != nil {
s.logger.WithError(err).Errorf("Failed to publish MasterKeyStatusChangedEvent for key %d.", keyID)
}
}
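performBaseKeyChecks fans keys out to a bounded worker pool fed by a buffered channel. A standalone sketch of that concurrency shape follows; checkKey is a hypothetical stand-in for the real base-key validation call.

package main

import (
	"fmt"
	"sync"
)

// checkKey is a hypothetical stand-in for the real validation request.
func checkKey(key string) error {
	if key == "" {
		return fmt.Errorf("empty key")
	}
	return nil
}

func main() {
	keys := []string{"k1", "k2", "k3", "k4", "k5"}

	concurrency := 2
	if concurrency <= 0 {
		concurrency = 5 // safe default, mirroring the service
	}

	jobs := make(chan string, len(keys))
	var wg sync.WaitGroup

	// Start a fixed number of workers that drain the jobs channel.
	for w := 0; w < concurrency; w++ {
		wg.Add(1)
		go func() {
			defer wg.Done()
			for k := range jobs {
				if err := checkKey(k); err != nil {
					fmt.Printf("key %s failed: %v\n", k, err)
				}
			}
		}()
	}

	for _, k := range keys {
		jobs <- k
	}
	close(jobs)
	wg.Wait()
}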

View File

@@ -2,6 +2,7 @@
package service
import (
"context"
"encoding/json"
"fmt"
"gemini-balancer/internal/models"
@@ -42,88 +43,84 @@ func NewKeyImportService(ts task.Reporter, kr repository.KeyRepository, s store.
}
}
// --- Generic panic-safe task runner ---
func (s *KeyImportService) runTaskWithRecovery(taskID string, resourceID string, taskFunc func()) {
func (s *KeyImportService) runTaskWithRecovery(ctx context.Context, taskID string, resourceID string, taskFunc func()) {
defer func() {
if r := recover(); r != nil {
err := fmt.Errorf("panic recovered in task %s: %v", taskID, r)
s.logger.Error(err)
s.taskService.EndTaskByID(taskID, resourceID, nil, err)
s.taskService.EndTaskByID(ctx, taskID, resourceID, nil, err)
}
}()
taskFunc()
}
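The recovery wrapper converts a panic inside a background task into a normal task failure instead of crashing the process. A minimal sketch of the same pattern, with onFailure standing in for the real "end task with error" call:

package main

import "fmt"

// runWithRecovery runs fn and, if it panics, converts the panic into an error
// handed to onFailure, so a single bad task cannot take the process down.
func runWithRecovery(taskID string, onFailure func(error), fn func()) {
	defer func() {
		if r := recover(); r != nil {
			onFailure(fmt.Errorf("panic recovered in task %s: %v", taskID, r))
		}
	}()
	fn()
}

func main() {
	runWithRecovery("demo-task", func(err error) {
		fmt.Println("task ended with error:", err)
	}, func() {
		panic("boom")
	})
}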
// --- Public Task Starters ---
func (s *KeyImportService) StartAddKeysTask(groupID uint, keysText string, validateOnImport bool) (*task.Status, error) {
func (s *KeyImportService) StartAddKeysTask(ctx context.Context, groupID uint, keysText string, validateOnImport bool) (*task.Status, error) {
keys := utils.ParseKeysFromText(keysText)
if len(keys) == 0 {
return nil, fmt.Errorf("no valid keys found in input text")
}
resourceID := fmt.Sprintf("group-%d", groupID)
taskStatus, err := s.taskService.StartTask(uint(groupID), TaskTypeAddKeysToGroup, resourceID, len(keys), 15*time.Minute)
taskStatus, err := s.taskService.StartTask(ctx, uint(groupID), TaskTypeAddKeysToGroup, resourceID, len(keys), 15*time.Minute)
if err != nil {
return nil, err
}
go s.runTaskWithRecovery(taskStatus.ID, resourceID, func() {
s.runAddKeysTask(taskStatus.ID, resourceID, groupID, keys, validateOnImport)
go s.runTaskWithRecovery(context.Background(), taskStatus.ID, resourceID, func() {
s.runAddKeysTask(context.Background(), taskStatus.ID, resourceID, groupID, keys, validateOnImport)
})
return taskStatus, nil
}
func (s *KeyImportService) StartUnlinkKeysTask(groupID uint, keysText string) (*task.Status, error) {
func (s *KeyImportService) StartUnlinkKeysTask(ctx context.Context, groupID uint, keysText string) (*task.Status, error) {
keys := utils.ParseKeysFromText(keysText)
if len(keys) == 0 {
return nil, fmt.Errorf("no valid keys found")
}
resourceID := fmt.Sprintf("group-%d", groupID)
taskStatus, err := s.taskService.StartTask(uint(groupID), TaskTypeUnlinkKeysFromGroup, resourceID, len(keys), time.Hour)
taskStatus, err := s.taskService.StartTask(ctx, uint(groupID), TaskTypeUnlinkKeysFromGroup, resourceID, len(keys), time.Hour)
if err != nil {
return nil, err
}
go s.runTaskWithRecovery(taskStatus.ID, resourceID, func() {
s.runUnlinkKeysTask(taskStatus.ID, resourceID, groupID, keys)
go s.runTaskWithRecovery(context.Background(), taskStatus.ID, resourceID, func() {
s.runUnlinkKeysTask(context.Background(), taskStatus.ID, resourceID, groupID, keys)
})
return taskStatus, nil
}
func (s *KeyImportService) StartHardDeleteKeysTask(keysText string) (*task.Status, error) {
func (s *KeyImportService) StartHardDeleteKeysTask(ctx context.Context, keysText string) (*task.Status, error) {
keys := utils.ParseKeysFromText(keysText)
if len(keys) == 0 {
return nil, fmt.Errorf("no valid keys found")
}
resourceID := "global_hard_delete" // Global lock
taskStatus, err := s.taskService.StartTask(0, TaskTypeHardDeleteKeys, resourceID, len(keys), time.Hour)
resourceID := "global_hard_delete"
taskStatus, err := s.taskService.StartTask(ctx, 0, TaskTypeHardDeleteKeys, resourceID, len(keys), time.Hour)
if err != nil {
return nil, err
}
go s.runTaskWithRecovery(taskStatus.ID, resourceID, func() {
s.runHardDeleteKeysTask(taskStatus.ID, resourceID, keys)
go s.runTaskWithRecovery(context.Background(), taskStatus.ID, resourceID, func() {
s.runHardDeleteKeysTask(context.Background(), taskStatus.ID, resourceID, keys)
})
return taskStatus, nil
}
func (s *KeyImportService) StartRestoreKeysTask(keysText string) (*task.Status, error) {
func (s *KeyImportService) StartRestoreKeysTask(ctx context.Context, keysText string) (*task.Status, error) {
keys := utils.ParseKeysFromText(keysText)
if len(keys) == 0 {
return nil, fmt.Errorf("no valid keys found")
}
resourceID := "global_restore_keys" // Global lock
taskStatus, err := s.taskService.StartTask(0, TaskTypeRestoreKeys, resourceID, len(keys), time.Hour)
resourceID := "global_restore_keys"
taskStatus, err := s.taskService.StartTask(ctx, 0, TaskTypeRestoreKeys, resourceID, len(keys), time.Hour)
if err != nil {
return nil, err
}
go s.runTaskWithRecovery(taskStatus.ID, resourceID, func() {
s.runRestoreKeysTask(taskStatus.ID, resourceID, keys)
go s.runTaskWithRecovery(context.Background(), taskStatus.ID, resourceID, func() {
s.runRestoreKeysTask(context.Background(), taskStatus.ID, resourceID, keys)
})
return taskStatus, nil
}
// --- Private Task Runners ---
func (s *KeyImportService) runAddKeysTask(taskID string, resourceID string, groupID uint, keys []string, validateOnImport bool) {
// Step 1: Deduplicate the raw input key list.
func (s *KeyImportService) runAddKeysTask(ctx context.Context, taskID string, resourceID string, groupID uint, keys []string, validateOnImport bool) {
uniqueKeysMap := make(map[string]struct{})
var uniqueKeyStrings []string
for _, kStr := range keys {
@@ -133,41 +130,37 @@ func (s *KeyImportService) runAddKeysTask(taskID string, resourceID string, grou
}
}
if len(uniqueKeyStrings) == 0 {
s.taskService.EndTaskByID(taskID, resourceID, gin.H{"newly_linked_count": 0, "already_linked_count": 0}, nil)
s.taskService.EndTaskByID(ctx, taskID, resourceID, gin.H{"newly_linked_count": 0, "already_linked_count": 0}, nil)
return
}
// Step 2: Ensure every key exists in the master table (create or restore) and fetch their full entities.
keysToEnsure := make([]models.APIKey, len(uniqueKeyStrings))
for i, keyStr := range uniqueKeyStrings {
keysToEnsure[i] = models.APIKey{APIKey: keyStr}
}
allKeyModels, err := s.keyRepo.AddKeys(keysToEnsure)
if err != nil {
s.taskService.EndTaskByID(taskID, resourceID, nil, fmt.Errorf("failed to ensure keys exist: %w", err))
s.taskService.EndTaskByID(ctx, taskID, resourceID, nil, fmt.Errorf("failed to ensure keys exist: %w", err))
return
}
// Step 3: Determine which of these keys are already linked to the current group.
alreadyLinkedModels, err := s.keyRepo.GetKeysByValuesAndGroupID(uniqueKeyStrings, groupID)
if err != nil {
s.taskService.EndTaskByID(taskID, resourceID, nil, fmt.Errorf("failed to check for already linked keys: %w", err))
s.taskService.EndTaskByID(ctx, taskID, resourceID, nil, fmt.Errorf("failed to check for already linked keys: %w", err))
return
}
alreadyLinkedIDSet := make(map[uint]struct{})
for _, key := range alreadyLinkedModels {
alreadyLinkedIDSet[key.ID] = struct{}{}
}
// Step 4: Determine the keys that actually need to be linked to the current group (our "working set").
var keysToLink []models.APIKey
for _, key := range allKeyModels {
if _, exists := alreadyLinkedIDSet[key.ID]; !exists {
keysToLink = append(keysToLink, key)
}
}
// Step 5: Update the task's Total to the exact size of the working set.
if err := s.taskService.UpdateTotalByID(taskID, len(keysToLink)); err != nil {
if err := s.taskService.UpdateTotalByID(ctx, taskID, len(keysToLink)); err != nil {
s.logger.WithError(err).Warnf("Failed to update total for task %s", taskID)
}
// Step 6: Link keys to the group in chunks, updating progress as we go.
if len(keysToLink) > 0 {
idsToLink := make([]uint, len(keysToLink))
for i, key := range keysToLink {
@@ -179,44 +172,41 @@ func (s *KeyImportService) runAddKeysTask(taskID string, resourceID string, grou
end = len(idsToLink)
}
chunk := idsToLink[i:end]
if err := s.keyRepo.LinkKeysToGroup(groupID, chunk); err != nil {
s.taskService.EndTaskByID(taskID, resourceID, nil, fmt.Errorf("chunk failed to link keys: %w", err))
if err := s.keyRepo.LinkKeysToGroup(ctx, groupID, chunk); err != nil {
s.taskService.EndTaskByID(ctx, taskID, resourceID, nil, fmt.Errorf("chunk failed to link keys: %w", err))
return
}
_ = s.taskService.UpdateProgressByID(taskID, i+len(chunk))
_ = s.taskService.UpdateProgressByID(ctx, taskID, i+len(chunk))
}
}
// Step 7: Prepare the final result and finish the task.
result := gin.H{
"newly_linked_count": len(keysToLink),
"already_linked_count": len(alreadyLinkedIDSet),
"total_linked_count": len(allKeyModels),
}
// Step 8: Depending on the `validateOnImport` flag, publish events or activate directly (only for newly linked keys).
if len(keysToLink) > 0 {
idsToLink := make([]uint, len(keysToLink))
for i, key := range keysToLink {
idsToLink[i] = key.ID
}
if validateOnImport {
s.publishImportGroupCompletedEvent(groupID, idsToLink)
s.publishImportGroupCompletedEvent(ctx, groupID, idsToLink)
for _, keyID := range idsToLink {
s.publishSingleKeyChangeEvent(groupID, keyID, "", models.StatusPendingValidation, "key_linked")
s.publishSingleKeyChangeEvent(ctx, groupID, keyID, "", models.StatusPendingValidation, "key_linked")
}
} else {
for _, keyID := range idsToLink {
if _, err := s.apiKeyService.UpdateMappingStatus(groupID, keyID, models.StatusActive); err != nil {
if _, err := s.apiKeyService.UpdateMappingStatus(ctx, groupID, keyID, models.StatusActive); err != nil {
s.logger.Errorf("Failed to directly activate key ID %d in group %d: %v", keyID, groupID, err)
}
}
}
}
s.taskService.EndTaskByID(taskID, resourceID, result, nil)
s.taskService.EndTaskByID(ctx, taskID, resourceID, result, nil)
}
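runAddKeysTask walks the working set in fixed-size chunks and reports progress after each chunk. The same shape, extracted into a standalone sketch; the chunk size, process, and report callbacks are placeholders for the repository call and the task-service progress update.

package main

import "fmt"

const chunkSize = 2 // illustrative only; the service defines its own chunk size

// processInChunks handles ids chunk by chunk, calling process for each chunk
// and report with the number of items handled so far.
func processInChunks(ids []uint, process func(chunk []uint) error, report func(done int)) error {
	for i := 0; i < len(ids); i += chunkSize {
		end := i + chunkSize
		if end > len(ids) {
			end = len(ids)
		}
		chunk := ids[i:end]
		if err := process(chunk); err != nil {
			return err
		}
		report(i + len(chunk))
	}
	return nil
}

func main() {
	ids := []uint{1, 2, 3, 4, 5}
	_ = processInChunks(ids,
		func(chunk []uint) error { fmt.Println("linking", chunk); return nil },
		func(done int) { fmt.Println("progress:", done) },
	)
}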
// runUnlinkKeysTask
func (s *KeyImportService) runUnlinkKeysTask(taskID, resourceID string, groupID uint, keys []string) {
func (s *KeyImportService) runUnlinkKeysTask(ctx context.Context, taskID, resourceID string, groupID uint, keys []string) {
uniqueKeysMap := make(map[string]struct{})
var uniqueKeys []string
for _, kStr := range keys {
@@ -225,46 +215,42 @@ func (s *KeyImportService) runUnlinkKeysTask(taskID, resourceID string, groupID
uniqueKeys = append(uniqueKeys, kStr)
}
}
// Step 1: In one pass, find the key entities from the input that actually exist in this group. This is our "working set".
keysToUnlink, err := s.keyRepo.GetKeysByValuesAndGroupID(uniqueKeys, groupID)
if err != nil {
s.taskService.EndTaskByID(taskID, resourceID, nil, fmt.Errorf("failed to find keys to unlink: %w", err))
s.taskService.EndTaskByID(ctx, taskID, resourceID, nil, fmt.Errorf("failed to find keys to unlink: %w", err))
return
}
if len(keysToUnlink) == 0 {
result := gin.H{"unlinked_count": 0, "hard_deleted_count": 0, "not_found_count": len(uniqueKeys)}
s.taskService.EndTaskByID(taskID, resourceID, result, nil)
s.taskService.EndTaskByID(ctx, taskID, resourceID, result, nil)
return
}
idsToUnlink := make([]uint, len(keysToUnlink))
for i, key := range keysToUnlink {
idsToUnlink[i] = key.ID
}
// Step 2: Update the task's Total to the exact size of the working set.
if err := s.taskService.UpdateTotalByID(taskID, len(idsToUnlink)); err != nil {
if err := s.taskService.UpdateTotalByID(ctx, taskID, len(idsToUnlink)); err != nil {
s.logger.WithError(err).Warnf("Failed to update total for task %s", taskID)
}
var totalUnlinked int64
// Step 3: Unlink keys in chunks, reporting progress as we go.
for i := 0; i < len(idsToUnlink); i += chunkSize {
end := i + chunkSize
if end > len(idsToUnlink) {
end = len(idsToUnlink)
}
chunk := idsToUnlink[i:end]
unlinked, err := s.keyRepo.UnlinkKeysFromGroup(groupID, chunk)
unlinked, err := s.keyRepo.UnlinkKeysFromGroup(ctx, groupID, chunk)
if err != nil {
s.taskService.EndTaskByID(taskID, resourceID, nil, fmt.Errorf("chunk failed: could not unlink keys: %w", err))
s.taskService.EndTaskByID(ctx, taskID, resourceID, nil, fmt.Errorf("chunk failed: could not unlink keys: %w", err))
return
}
totalUnlinked += unlinked
for _, keyID := range chunk {
s.publishSingleKeyChangeEvent(groupID, keyID, models.StatusActive, "", "key_unlinked")
s.publishSingleKeyChangeEvent(ctx, groupID, keyID, models.StatusActive, "", "key_unlinked")
}
_ = s.taskService.UpdateProgressByID(taskID, i+len(chunk))
_ = s.taskService.UpdateProgressByID(ctx, taskID, i+len(chunk))
}
totalDeleted, err := s.keyRepo.DeleteOrphanKeys()
@@ -276,10 +262,10 @@ func (s *KeyImportService) runUnlinkKeysTask(taskID, resourceID string, groupID
"hard_deleted_count": totalDeleted,
"not_found_count": len(uniqueKeys) - int(totalUnlinked),
}
s.taskService.EndTaskByID(taskID, resourceID, result, nil)
s.taskService.EndTaskByID(ctx, taskID, resourceID, result, nil)
}
func (s *KeyImportService) runHardDeleteKeysTask(taskID, resourceID string, keys []string) {
func (s *KeyImportService) runHardDeleteKeysTask(ctx context.Context, taskID, resourceID string, keys []string) {
var totalDeleted int64
for i := 0; i < len(keys); i += chunkSize {
end := i + chunkSize
@@ -290,22 +276,21 @@ func (s *KeyImportService) runHardDeleteKeysTask(taskID, resourceID string, keys
deleted, err := s.keyRepo.HardDeleteByValues(chunk)
if err != nil {
s.taskService.EndTaskByID(taskID, resourceID, nil, fmt.Errorf("failed to hard delete chunk: %w", err))
s.taskService.EndTaskByID(ctx, taskID, resourceID, nil, fmt.Errorf("failed to hard delete chunk: %w", err))
return
}
totalDeleted += deleted
_ = s.taskService.UpdateProgressByID(taskID, i+len(chunk))
_ = s.taskService.UpdateProgressByID(ctx, taskID, i+len(chunk))
}
result := gin.H{
"hard_deleted_count": totalDeleted,
"not_found_count": int64(len(keys)) - totalDeleted,
}
s.taskService.EndTaskByID(taskID, resourceID, result, nil)
s.publishChangeEvent(0, "keys_hard_deleted") // Global event
s.taskService.EndTaskByID(ctx, taskID, resourceID, result, nil)
s.publishChangeEvent(ctx, 0, "keys_hard_deleted")
}
func (s *KeyImportService) runRestoreKeysTask(taskID, resourceID string, keys []string) {
func (s *KeyImportService) runRestoreKeysTask(ctx context.Context, taskID, resourceID string, keys []string) {
var restoredCount int64
for i := 0; i < len(keys); i += chunkSize {
end := i + chunkSize
@@ -316,21 +301,21 @@ func (s *KeyImportService) runRestoreKeysTask(taskID, resourceID string, keys []
count, err := s.keyRepo.UpdateMasterStatusByValues(chunk, models.MasterStatusActive)
if err != nil {
s.taskService.EndTaskByID(taskID, resourceID, nil, fmt.Errorf("failed to restore chunk: %w", err))
s.taskService.EndTaskByID(ctx, taskID, resourceID, nil, fmt.Errorf("failed to restore chunk: %w", err))
return
}
restoredCount += count
_ = s.taskService.UpdateProgressByID(taskID, i+len(chunk))
_ = s.taskService.UpdateProgressByID(ctx, taskID, i+len(chunk))
}
result := gin.H{
"restored_count": restoredCount,
"not_found_count": int64(len(keys)) - restoredCount,
}
s.taskService.EndTaskByID(taskID, resourceID, result, nil)
s.publishChangeEvent(0, "keys_bulk_restored") // Global event
s.taskService.EndTaskByID(ctx, taskID, resourceID, result, nil)
s.publishChangeEvent(ctx, 0, "keys_bulk_restored")
}
func (s *KeyImportService) publishSingleKeyChangeEvent(groupID, keyID uint, oldStatus, newStatus models.APIKeyStatus, reason string) {
func (s *KeyImportService) publishSingleKeyChangeEvent(ctx context.Context, groupID, keyID uint, oldStatus, newStatus models.APIKeyStatus, reason string) {
event := models.KeyStatusChangedEvent{
GroupID: groupID,
KeyID: keyID,
@@ -340,7 +325,7 @@ func (s *KeyImportService) publishSingleKeyChangeEvent(groupID, keyID uint, oldS
ChangedAt: time.Now(),
}
eventData, _ := json.Marshal(event)
if err := s.store.Publish(models.TopicKeyStatusChanged, eventData); err != nil {
if err := s.store.Publish(ctx, models.TopicKeyStatusChanged, eventData); err != nil {
s.logger.WithError(err).WithFields(logrus.Fields{
"group_id": groupID,
"key_id": keyID,
@@ -349,16 +334,16 @@ func (s *KeyImportService) publishSingleKeyChangeEvent(groupID, keyID uint, oldS
}
}
func (s *KeyImportService) publishChangeEvent(groupID uint, reason string) {
func (s *KeyImportService) publishChangeEvent(ctx context.Context, groupID uint, reason string) {
event := models.KeyStatusChangedEvent{
GroupID: groupID,
ChangeReason: reason,
}
eventData, _ := json.Marshal(event)
_ = s.store.Publish(models.TopicKeyStatusChanged, eventData)
_ = s.store.Publish(ctx, models.TopicKeyStatusChanged, eventData)
}
func (s *KeyImportService) publishImportGroupCompletedEvent(groupID uint, keyIDs []uint) {
func (s *KeyImportService) publishImportGroupCompletedEvent(ctx context.Context, groupID uint, keyIDs []uint) {
if len(keyIDs) == 0 {
return
}
@@ -372,17 +357,15 @@ func (s *KeyImportService) publishImportGroupCompletedEvent(groupID uint, keyIDs
s.logger.WithError(err).Error("Failed to marshal ImportGroupCompletedEvent.")
return
}
if err := s.store.Publish(models.TopicImportGroupCompleted, eventData); err != nil {
if err := s.store.Publish(ctx, models.TopicImportGroupCompleted, eventData); err != nil {
s.logger.WithError(err).Error("Failed to publish ImportGroupCompletedEvent.")
} else {
s.logger.Infof("Published ImportGroupCompletedEvent for group %d with %d keys.", groupID, len(keyIDs))
}
}
// [NEW] StartUnlinkKeysByFilterTask starts a task to unlink keys matching a status filter.
func (s *KeyImportService) StartUnlinkKeysByFilterTask(groupID uint, statuses []string) (*task.Status, error) {
func (s *KeyImportService) StartUnlinkKeysByFilterTask(ctx context.Context, groupID uint, statuses []string) (*task.Status, error) {
s.logger.Infof("Starting unlink task for group %d with status filter: %v", groupID, statuses)
// 1. [New] Find the keys to operate on.
keyValues, err := s.keyRepo.FindKeyValuesByStatus(groupID, statuses)
if err != nil {
return nil, fmt.Errorf("failed to find keys by filter: %w", err)
@@ -390,8 +373,7 @@ func (s *KeyImportService) StartUnlinkKeysByFilterTask(groupID uint, statuses []
if len(keyValues) == 0 {
return nil, fmt.Errorf("no keys found matching the provided filter")
}
// 2. [REUSE] Convert to text and call the existing, robust unlink task logic.
keysAsText := strings.Join(keyValues, "\n")
s.logger.Infof("Found %d keys to unlink for group %d.", len(keyValues), groupID)
return s.StartUnlinkKeysTask(groupID, keysAsText)
}
return s.StartUnlinkKeysTask(ctx, groupID, keysAsText)
}
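StartUnlinkKeysByFilterTask reuses the existing text-based unlink flow by resolving the status filter to concrete key values first. A toy sketch of that reuse, with findKeyValuesByStatus and startUnlink as hypothetical stand-ins for the repository lookup and the existing task starter:

package main

import (
	"fmt"
	"strings"
)

// findKeyValuesByStatus is a hypothetical stand-in for the repository lookup.
func findKeyValuesByStatus(groupID uint, statuses []string) ([]string, error) {
	return []string{"key-a", "key-b"}, nil
}

// unlinkByFilter resolves the filter to key values, then delegates to the
// existing text-based unlink flow (stubbed here as startUnlink).
func unlinkByFilter(groupID uint, statuses []string, startUnlink func(groupID uint, keysText string) error) error {
	keyValues, err := findKeyValuesByStatus(groupID, statuses)
	if err != nil {
		return fmt.Errorf("failed to find keys by filter: %w", err)
	}
	if len(keyValues) == 0 {
		return fmt.Errorf("no keys found matching the provided filter")
	}
	return startUnlink(groupID, strings.Join(keyValues, "\n"))
}

func main() {
	_ = unlinkByFilter(1, []string{"disabled"}, func(groupID uint, keysText string) error {
		fmt.Printf("unlinking from group %d:\n%s\n", groupID, keysText)
		return nil
	})
}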

View File

@@ -2,6 +2,7 @@
package service
import (
"context"
"encoding/json"
"fmt"
"gemini-balancer/internal/channel"
@@ -62,20 +63,18 @@ func (s *KeyValidationService) ValidateSingleKey(key *models.APIKey, timeout tim
return fmt.Errorf("failed to create request: %w", err)
}
s.channel.ModifyRequest(req, key) // Use the injected channel to modify the request
s.channel.ModifyRequest(req, key)
resp, err := client.Do(req)
if err != nil {
// This is a network-level error (e.g., timeout, DNS issue)
return fmt.Errorf("request failed: %w", err)
}
defer resp.Body.Close()
if resp.StatusCode == http.StatusOK {
return nil // Success
return nil
}
// Read the body for more error details
bodyBytes, readErr := io.ReadAll(resp.Body)
var errorMsg string
if readErr != nil {
@@ -84,7 +83,6 @@ func (s *KeyValidationService) ValidateSingleKey(key *models.APIKey, timeout tim
errorMsg = string(bodyBytes)
}
// This is a validation failure with a specific HTTP status code
return &CustomErrors.APIError{
HTTPStatus: resp.StatusCode,
Message: fmt.Sprintf("Validation failed with status %d: %s", resp.StatusCode, errorMsg),
@@ -92,8 +90,7 @@ func (s *KeyValidationService) ValidateSingleKey(key *models.APIKey, timeout tim
}
}
// --- Async task methods (fully adapted to the new task package) ---
func (s *KeyValidationService) StartTestKeysTask(groupID uint, keysText string) (*task.Status, error) {
func (s *KeyValidationService) StartTestKeysTask(ctx context.Context, groupID uint, keysText string) (*task.Status, error) {
keyStrings := utils.ParseKeysFromText(keysText)
if len(keyStrings) == 0 {
return nil, CustomErrors.NewAPIError(CustomErrors.ErrBadRequest, "no valid keys found in input text")
@@ -111,7 +108,6 @@ func (s *KeyValidationService) StartTestKeysTask(groupID uint, keysText string)
}
group, ok := s.groupManager.GetGroupByID(groupID)
if !ok {
// [FIX] Correctly use the NewAPIError constructor for a missing group.
return nil, CustomErrors.NewAPIError(CustomErrors.ErrGroupNotFound, fmt.Sprintf("group with id %d not found", groupID))
}
opConfig, err := s.groupManager.BuildOperationalConfig(group)
@@ -119,15 +115,15 @@ func (s *KeyValidationService) StartTestKeysTask(groupID uint, keysText string)
return nil, CustomErrors.NewAPIError(CustomErrors.ErrInternalServer, fmt.Sprintf("failed to build operational config: %v", err))
}
resourceID := fmt.Sprintf("group-%d", groupID)
taskStatus, err := s.taskService.StartTask(groupID, TaskTypeTestKeys, resourceID, len(apiKeyModels), time.Hour)
taskStatus, err := s.taskService.StartTask(ctx, groupID, TaskTypeTestKeys, resourceID, len(apiKeyModels), time.Hour)
if err != nil {
return nil, err // Pass up the error from task service (e.g., "task already running")
return nil, err
}
settings := s.SettingsManager.GetSettings()
timeout := time.Duration(settings.KeyCheckTimeoutSeconds) * time.Second
endpoint, err := s.groupManager.BuildKeyCheckEndpoint(groupID)
if err != nil {
s.taskService.EndTaskByID(taskStatus.ID, resourceID, nil, err) // End task with error if endpoint fails
s.taskService.EndTaskByID(ctx, taskStatus.ID, resourceID, nil, err)
return nil, err
}
var concurrency int
@@ -136,11 +132,11 @@ func (s *KeyValidationService) StartTestKeysTask(groupID uint, keysText string)
} else {
concurrency = settings.BaseKeyCheckConcurrency
}
go s.runTestKeysTask(taskStatus.ID, resourceID, groupID, apiKeyModels, timeout, endpoint, concurrency)
go s.runTestKeysTask(context.Background(), taskStatus.ID, resourceID, groupID, apiKeyModels, timeout, endpoint, concurrency)
return taskStatus, nil
}
func (s *KeyValidationService) runTestKeysTask(taskID string, resourceID string, groupID uint, keys []models.APIKey, timeout time.Duration, endpoint string, concurrency int) {
func (s *KeyValidationService) runTestKeysTask(ctx context.Context, taskID string, resourceID string, groupID uint, keys []models.APIKey, timeout time.Duration, endpoint string, concurrency int) {
var wg sync.WaitGroup
var mu sync.Mutex
finalResults := make([]models.KeyTestResult, len(keys))
@@ -165,7 +161,6 @@ func (s *KeyValidationService) runTestKeysTask(taskID string, resourceID string,
var currentResult models.KeyTestResult
event := models.RequestFinishedEvent{
RequestLog: models.RequestLog{
// GroupID and KeyID are pointers in the RequestLog model, so take their addresses.
GroupID: &groupID,
KeyID: &apiKeyModel.ID,
},
@@ -185,14 +180,15 @@ func (s *KeyValidationService) runTestKeysTask(taskID string, resourceID string,
event.RequestLog.IsSuccess = false
}
eventData, _ := json.Marshal(event)
if err := s.store.Publish(models.TopicRequestFinished, eventData); err != nil {
if err := s.store.Publish(ctx, models.TopicRequestFinished, eventData); err != nil {
s.logger.WithError(err).Errorf("Failed to publish RequestFinishedEvent for validation of key ID %d", apiKeyModel.ID)
}
mu.Lock()
finalResults[j.Index] = currentResult
processedCount++
_ = s.taskService.UpdateProgressByID(taskID, processedCount)
_ = s.taskService.UpdateProgressByID(ctx, taskID, processedCount)
mu.Unlock()
}
}()
@@ -202,10 +198,10 @@ func (s *KeyValidationService) runTestKeysTask(taskID string, resourceID string,
}
close(jobs)
wg.Wait()
s.taskService.EndTaskByID(taskID, resourceID, gin.H{"results": finalResults}, nil)
s.taskService.EndTaskByID(ctx, taskID, resourceID, gin.H{"results": finalResults}, nil)
}
func (s *KeyValidationService) StartTestKeysByFilterTask(groupID uint, statuses []string) (*task.Status, error) {
func (s *KeyValidationService) StartTestKeysByFilterTask(ctx context.Context, groupID uint, statuses []string) (*task.Status, error) {
s.logger.Infof("Starting test task for group %d with status filter: %v", groupID, statuses)
keyValues, err := s.keyRepo.FindKeyValuesByStatus(groupID, statuses)
if err != nil {
@@ -216,5 +212,5 @@ func (s *KeyValidationService) StartTestKeysByFilterTask(groupID uint, statuses
}
keysAsText := strings.Join(keyValues, "\n")
s.logger.Infof("Found %d keys to validate for group %d.", len(keyValues), groupID)
return s.StartTestKeysTask(groupID, keysAsText)
return s.StartTestKeysTask(ctx, groupID, keysAsText)
}
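ValidateSingleKey issues a single request against the check endpoint with a per-request timeout and treats any non-200 response as a validation failure. A standalone sketch of that flow using only the standard library; the x-goog-api-key header is an assumption for illustration, while the endpoint matches the documented test URL.

package main

import (
	"context"
	"fmt"
	"io"
	"net/http"
	"time"
)

// validateKey performs one GET against the check endpoint with the key attached.
// A non-200 response is reported together with the response body.
func validateKey(ctx context.Context, endpoint, apiKey string, timeout time.Duration) error {
	ctx, cancel := context.WithTimeout(ctx, timeout)
	defer cancel()

	req, err := http.NewRequestWithContext(ctx, http.MethodGet, endpoint, nil)
	if err != nil {
		return fmt.Errorf("failed to create request: %w", err)
	}
	req.Header.Set("x-goog-api-key", apiKey) // assumed header for the sketch

	resp, err := http.DefaultClient.Do(req)
	if err != nil {
		return fmt.Errorf("request failed: %w", err)
	}
	defer resp.Body.Close()

	if resp.StatusCode == http.StatusOK {
		return nil
	}
	body, _ := io.ReadAll(resp.Body)
	return fmt.Errorf("validation failed with status %d: %s", resp.StatusCode, body)
}

func main() {
	err := validateKey(context.Background(),
		"https://generativelanguage.googleapis.com/v1beta/models", "demo-key", 10*time.Second)
	fmt.Println("validation result:", err)
}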

View File

@@ -3,6 +3,7 @@
package service
import (
"context"
"errors"
apperrors "gemini-balancer/internal/errors"
"gemini-balancer/internal/models"
@@ -43,7 +44,6 @@ func NewResourceService(
aks *APIKeyService,
logger *logrus.Logger,
) *ResourceService {
logger.Debugf("[FORENSIC PROBE | INJECTION | ResourceService] Received 'keyRepo' param. Fingerprint: %p", kr)
rs := &ResourceService{
settingsManager: sm,
groupManager: gm,
@@ -56,43 +56,40 @@ func NewResourceService(
go rs.preWarmCache(logger)
})
return rs
}
// --- [Mode 1: Smart aggregation] ---
func (s *ResourceService) GetResourceFromBasePool(authToken *models.AuthToken, modelName string) (*RequestResources, error) {
func (s *ResourceService) GetResourceFromBasePool(ctx context.Context, authToken *models.AuthToken, modelName string) (*RequestResources, error) {
log := s.logger.WithFields(logrus.Fields{"token_id": authToken.ID, "model_name": modelName, "mode": "BasePool"})
log.Debug("Entering BasePool resource acquisition.")
// 1. Filter all eligible candidate groups and sort them by priority.
candidateGroups := s.filterAndSortCandidateGroups(modelName, authToken.AllowedGroups)
if len(candidateGroups) == 0 {
log.Warn("No candidate groups found for BasePool construction.")
return nil, apperrors.ErrNoKeysAvailable
}
// 2. Select one key from the BasePool according to the system-wide polling strategy.
basePool := &repository.BasePool{
CandidateGroups: candidateGroups,
PollingStrategy: s.settingsManager.GetSettings().PollingStrategy,
}
apiKey, selectedGroup, err := s.keyRepo.SelectOneActiveKeyFromBasePool(basePool)
apiKey, selectedGroup, err := s.keyRepo.SelectOneActiveKeyFromBasePool(ctx, basePool)
if err != nil {
log.WithError(err).Warn("Failed to select a key from the BasePool.")
return nil, apperrors.ErrNoKeysAvailable
}
// 3. Assemble the final resources.
// [Key point] In this mode RequestConfig is always left empty to keep the proxy transparent.
resources, err := s.assembleRequestResources(selectedGroup, apiKey)
if err != nil {
log.WithError(err).Error("Failed to assemble resources after selecting key from BasePool.")
return nil, err
}
resources.RequestConfig = &models.RequestConfig{} // Force it to be empty
resources.RequestConfig = &models.RequestConfig{}
log.Infof("Successfully selected KeyID %d from GroupID %d for the BasePool.", apiKey.ID, selectedGroup.ID)
return resources, nil
}
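GetResourceFromBasePool builds an ordered candidate list and delegates key selection to the repository. Below is a deliberately simplified toy of that idea under a "first group with an active key wins" assumption; the real SelectOneActiveKeyFromBasePool honors the configured polling strategy rather than this rule.

package main

import (
	"errors"
	"fmt"
	"sort"
)

type group struct {
	ID    uint
	Order int
	Keys  []string // active keys only, for the sketch
}

// selectFromBasePool walks candidate groups in ascending Order and returns the
// first active key it finds. This is a simplified stand-in for the repository's
// strategy-aware selection.
func selectFromBasePool(candidates []group) (string, uint, error) {
	sort.SliceStable(candidates, func(i, j int) bool { return candidates[i].Order < candidates[j].Order })
	for _, g := range candidates {
		if len(g.Keys) > 0 {
			return g.Keys[0], g.ID, nil
		}
	}
	return "", 0, errors.New("no keys available")
}

func main() {
	key, groupID, err := selectFromBasePool([]group{
		{ID: 2, Order: 10, Keys: nil},
		{ID: 1, Order: 1, Keys: []string{"key-a"}},
	})
	fmt.Println(key, groupID, err)
}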
// --- [Mode 2: Precise routing] ---
func (s *ResourceService) GetResourceFromGroup(authToken *models.AuthToken, groupName string) (*RequestResources, error) {
func (s *ResourceService) GetResourceFromGroup(ctx context.Context, authToken *models.AuthToken, groupName string) (*RequestResources, error) {
log := s.logger.WithFields(logrus.Fields{"token_id": authToken.ID, "group_name": groupName, "mode": "PreciseRoute"})
log.Debug("Entering PreciseRoute resource acquisition.")
@@ -101,12 +98,11 @@ func (s *ResourceService) GetResourceFromGroup(authToken *models.AuthToken, grou
if !ok {
return nil, apperrors.NewAPIError(apperrors.ErrGroupNotFound, "The specified group does not exist.")
}
if !s.isTokenAllowedForGroup(authToken, targetGroup.ID) {
return nil, apperrors.NewAPIError(apperrors.ErrPermissionDenied, "Token does not have permission to access this group.")
}
apiKey, _, err := s.keyRepo.SelectOneActiveKey(targetGroup)
apiKey, _, err := s.keyRepo.SelectOneActiveKey(ctx, targetGroup)
if err != nil {
log.WithError(err).Warn("Failed to select a key from the precisely targeted group.")
return nil, apperrors.ErrNoKeysAvailable
@@ -132,7 +128,6 @@ func (s *ResourceService) GetAllowedModelsForToken(authToken *models.AuthToken)
if authToken.IsAdmin {
for _, group := range allGroups {
for _, modelMapping := range group.AllowedModels {
allowedModelsSet[modelMapping.ModelName] = struct{}{}
}
}
@@ -144,7 +139,6 @@ func (s *ResourceService) GetAllowedModelsForToken(authToken *models.AuthToken)
for _, group := range allGroups {
if _, ok := allowedGroupIDs[group.ID]; ok {
for _, modelMapping := range group.AllowedModels {
allowedModelsSet[modelMapping.ModelName] = struct{}{}
}
}
@@ -164,14 +158,6 @@ func (s *ResourceService) assembleRequestResources(group *models.KeyGroup, apiKe
return nil, apperrors.NewAPIError(apperrors.ErrConfigurationError, "Selected group has no valid upstream and no global default is set.")
}
var proxyConfig *models.ProxyConfig
// [Note] The proxy logic needs a proxyModule instance; it is left nil for now and the dependency will be re-injected later.
// if group.EnableProxy && s.proxyModule != nil {
// var err error
// proxyConfig, err = s.proxyModule.AssignProxyIfNeeded(apiKey)
// if err != nil {
// s.logger.WithError(err).Warnf("Failed to assign proxy for API key %d.", apiKey.ID)
// }
// }
return &RequestResources{
KeyGroup: group,
APIKey: apiKey,
@@ -194,7 +180,7 @@ func (s *ResourceService) selectUpstreamForGroup(group *models.KeyGroup) *models
func (s *ResourceService) preWarmCache(logger *logrus.Logger) error {
time.Sleep(2 * time.Second)
s.logger.Info("Performing initial key cache pre-warming...")
if err := s.keyRepo.LoadAllKeysToStore(); err != nil {
if err := s.keyRepo.LoadAllKeysToStore(context.Background()); err != nil {
logger.WithError(err).Error("Failed to perform initial key cache pre-warming.")
return err
}
@@ -209,7 +195,6 @@ func (s *ResourceService) GetResourcesForRequest(modelName string, allowedGroups
func (s *ResourceService) filterAndSortCandidateGroups(modelName string, allowedGroupsFromToken []*models.KeyGroup) []*models.KeyGroup {
allGroupsFromCache := s.groupManager.GetAllGroups()
var candidateGroups []*models.KeyGroup
// 1. Determine the permission scope.
allowedGroupIDs := make(map[uint]bool)
isTokenRestricted := len(allowedGroupsFromToken) > 0
if isTokenRestricted {
@@ -217,15 +202,12 @@ func (s *ResourceService) filterAndSortCandidateGroups(modelName string, allowed
allowedGroupIDs[ag.ID] = true
}
}
// 2. Filter.
for _, group := range allGroupsFromCache {
// Check the token's permissions.
if isTokenRestricted && !allowedGroupIDs[group.ID] {
continue
}
// Check whether the model is allowed.
isModelAllowed := false
if len(group.AllowedModels) == 0 { // If the group does not restrict models, allow it.
if len(group.AllowedModels) == 0 {
isModelAllowed = true
} else {
for _, m := range group.AllowedModels {
@@ -239,8 +221,6 @@ func (s *ResourceService) filterAndSortCandidateGroups(modelName string, allowed
candidateGroups = append(candidateGroups, group)
}
}
// 3. Sort by the Order field in ascending order.
sort.SliceStable(candidateGroups, func(i, j int) bool {
return candidateGroups[i].Order < candidateGroups[j].Order
})

View File

@@ -52,7 +52,7 @@ func (s *SecurityService) AuthenticateToken(tokenValue string) (*models.AuthToke
// IsIPBanned
func (s *SecurityService) IsIPBanned(ctx context.Context, ip string) (bool, error) {
banKey := fmt.Sprintf("banned_ip:%s", ip)
return s.store.Exists(banKey)
return s.store.Exists(ctx, banKey)
}
// RecordFailedLoginAttempt
@@ -61,7 +61,7 @@ func (s *SecurityService) RecordFailedLoginAttempt(ctx context.Context, ip strin
return nil
}
count, err := s.store.HIncrBy(loginAttemptsKey, ip, 1)
count, err := s.store.HIncrBy(ctx, loginAttemptsKey, ip, 1)
if err != nil {
return err
}
@@ -71,12 +71,12 @@ func (s *SecurityService) RecordFailedLoginAttempt(ctx context.Context, ip strin
banDuration := s.SettingsManager.GetIPBanDuration()
banKey := fmt.Sprintf("banned_ip:%s", ip)
if err := s.store.Set(banKey, []byte("1"), banDuration); err != nil {
if err := s.store.Set(ctx, banKey, []byte("1"), banDuration); err != nil {
return err
}
s.logger.Warnf("IP BANNED: IP [%s] has been banned for %v due to excessive failed login attempts.", ip, banDuration)
s.store.HDel(loginAttemptsKey, ip)
s.store.HDel(ctx, loginAttemptsKey, ip)
}
return nil
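The security flow counts failed logins per IP, bans the IP for a configurable duration once the threshold is hit, and resets the counter. Below is a toy, in-memory sketch of that flow; the real service keeps these counters in the store (HIncrBy for attempts, Set with a TTL for the ban marker, HDel to reset).

package main

import (
	"fmt"
	"time"
)

// banTracker is an in-memory stand-in for the store-backed counters.
type banTracker struct {
	attempts map[string]int
	banned   map[string]time.Time
	maxTries int
	banFor   time.Duration
}

func (b *banTracker) recordFailure(ip string) {
	b.attempts[ip]++
	if b.attempts[ip] >= b.maxTries {
		b.banned[ip] = time.Now().Add(b.banFor)
		delete(b.attempts, ip) // reset the counter once the ban is in place
	}
}

func (b *banTracker) isBanned(ip string) bool {
	until, ok := b.banned[ip]
	return ok && time.Now().Before(until)
}

func main() {
	t := &banTracker{attempts: map[string]int{}, banned: map[string]time.Time{}, maxTries: 3, banFor: time.Minute}
	for i := 0; i < 3; i++ {
		t.recordFailure("203.0.113.7")
	}
	fmt.Println("banned:", t.isBanned("203.0.113.7"))
}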

View File

@@ -2,6 +2,7 @@
package service
import (
"context"
"encoding/json"
"fmt"
"gemini-balancer/internal/models"
@@ -34,7 +35,7 @@ func NewStatsService(db *gorm.DB, s store.Store, repo repository.KeyRepository,
func (s *StatsService) Start() {
s.logger.Info("Starting event listener for stats maintenance.")
sub, err := s.store.Subscribe(models.TopicKeyStatusChanged)
sub, err := s.store.Subscribe(context.Background(), models.TopicKeyStatusChanged)
if err != nil {
s.logger.Fatalf("Failed to subscribe to topic %s: %v", models.TopicKeyStatusChanged, err)
return
@@ -67,42 +68,43 @@ func (s *StatsService) handleKeyStatusChange(event *models.KeyStatusChangedEvent
s.logger.Warnf("Received KeyStatusChangedEvent with no GroupID. Reason: %s, KeyID: %d. Skipping.", event.ChangeReason, event.KeyID)
return
}
ctx := context.Background()
statsKey := fmt.Sprintf("stats:group:%d", event.GroupID)
s.logger.Infof("Handling key status change for Group %d, KeyID: %d, Reason: %s", event.GroupID, event.KeyID, event.ChangeReason)
switch event.ChangeReason {
case "key_unlinked", "key_hard_deleted":
if event.OldStatus != "" {
s.store.HIncrBy(statsKey, "total_keys", -1)
s.store.HIncrBy(statsKey, fmt.Sprintf("%s_keys", event.OldStatus), -1)
s.store.HIncrBy(ctx, statsKey, "total_keys", -1)
s.store.HIncrBy(ctx, statsKey, fmt.Sprintf("%s_keys", event.OldStatus), -1)
} else {
s.logger.Warnf("Received '%s' event for group %d without OldStatus, forcing recalculation.", event.ChangeReason, event.GroupID)
s.RecalculateGroupKeyStats(event.GroupID)
s.RecalculateGroupKeyStats(ctx, event.GroupID)
}
case "key_linked":
if event.NewStatus != "" {
s.store.HIncrBy(statsKey, "total_keys", 1)
s.store.HIncrBy(statsKey, fmt.Sprintf("%s_keys", event.NewStatus), 1)
s.store.HIncrBy(ctx, statsKey, "total_keys", 1)
s.store.HIncrBy(ctx, statsKey, fmt.Sprintf("%s_keys", event.NewStatus), 1)
} else {
s.logger.Warnf("Received 'key_linked' event for group %d without NewStatus, forcing recalculation.", event.GroupID)
s.RecalculateGroupKeyStats(event.GroupID)
s.RecalculateGroupKeyStats(ctx, event.GroupID)
}
case "manual_update", "error_threshold_reached", "key_recovered", "invalid_api_key":
s.store.HIncrBy(statsKey, fmt.Sprintf("%s_keys", event.OldStatus), -1)
s.store.HIncrBy(statsKey, fmt.Sprintf("%s_keys", event.NewStatus), 1)
s.store.HIncrBy(ctx, statsKey, fmt.Sprintf("%s_keys", event.OldStatus), -1)
s.store.HIncrBy(ctx, statsKey, fmt.Sprintf("%s_keys", event.NewStatus), 1)
default:
s.logger.Warnf("Unhandled event reason '%s' for group %d, forcing recalculation.", event.ChangeReason, event.GroupID)
s.RecalculateGroupKeyStats(event.GroupID)
s.RecalculateGroupKeyStats(ctx, event.GroupID)
}
}
func (s *StatsService) RecalculateGroupKeyStats(groupID uint) error {
func (s *StatsService) RecalculateGroupKeyStats(ctx context.Context, groupID uint) error {
s.logger.Warnf("Performing full recalculation for group %d key stats.", groupID)
var results []struct {
Status models.APIKeyStatus
Count int64
}
if err := s.db.Model(&models.GroupAPIKeyMapping{}).
if err := s.db.WithContext(ctx).Model(&models.GroupAPIKeyMapping{}).
Where("key_group_id = ?", groupID).
Select("status, COUNT(*) as count").
Group("status").
@@ -119,37 +121,25 @@ func (s *StatsService) RecalculateGroupKeyStats(groupID uint) error {
}
updates["total_keys"] = totalKeys
if err := s.store.Del(statsKey); err != nil {
if err := s.store.Del(ctx, statsKey); err != nil {
s.logger.WithError(err).Warnf("Failed to delete stale stats key for group %d before recalculation.", groupID)
}
if err := s.store.HSet(statsKey, updates); err != nil {
if err := s.store.HSet(ctx, statsKey, updates); err != nil {
return fmt.Errorf("failed to HSet recalculated stats for group %d: %w", groupID, err)
}
s.logger.Infof("Successfully recalculated stats for group %d using HSet.", groupID)
return nil
}
func (s *StatsService) GetDashboardStats() (*models.DashboardStatsResponse, error) {
// TODO:
// 1. Fetch key stats for all groups from Redis (HGetAll).
// 2. Fetch the past 24 hours of request counts and error rates from the stats_hourly table.
// 3. Combine them into a DashboardStatsResponse.
// ... The concrete implementation of this method can live in DashboardQueryService;
// here we only make sure StatsService's core responsibility, maintaining the cache, is covered.
// Return an empty object for now so the code compiles.
// Pseudocode:
// keyCounts, _ := s.store.HGetAll("stats:global:keys")
// ...
func (s *StatsService) GetDashboardStats(ctx context.Context) (*models.DashboardStatsResponse, error) {
return &models.DashboardStatsResponse{}, nil
}
func (s *StatsService) AggregateHourlyStats() error {
func (s *StatsService) AggregateHourlyStats(ctx context.Context) error {
s.logger.Info("Starting aggregation of the last hour's request data...")
now := time.Now()
endTime := now.Truncate(time.Hour)       // e.g. 15:23 -> 15:00
startTime := endTime.Add(-1 * time.Hour) // 15:00 -> 14:00
endTime := now.Truncate(time.Hour)
startTime := endTime.Add(-1 * time.Hour)
s.logger.Infof("Aggregating data for time window: [%s, %s)", startTime.Format(time.RFC3339), endTime.Format(time.RFC3339))
type aggregationResult struct {
@@ -161,7 +151,8 @@ func (s *StatsService) AggregateHourlyStats() error {
CompletionTokens int64
}
var results []aggregationResult
err := s.db.Model(&models.RequestLog{}).
err := s.db.WithContext(ctx).Model(&models.RequestLog{}).
Select("group_id, model_name, COUNT(*) as request_count, SUM(CASE WHEN is_success = true THEN 1 ELSE 0 END) as success_count, SUM(prompt_tokens) as prompt_tokens, SUM(completion_tokens) as completion_tokens").
Where("request_time >= ? AND request_time < ?", startTime, endTime).
Group("group_id, model_name").
@@ -179,7 +170,7 @@ func (s *StatsService) AggregateHourlyStats() error {
var hourlyStats []models.StatsHourly
for _, res := range results {
hourlyStats = append(hourlyStats, models.StatsHourly{
Time: startTime, // Every record's timestamp is the start of that hour.
Time: startTime,
GroupID: res.GroupID,
ModelName: res.ModelName,
RequestCount: res.RequestCount,
@@ -189,7 +180,7 @@ func (s *StatsService) AggregateHourlyStats() error {
})
}
return s.db.Clauses(clause.OnConflict{
return s.db.WithContext(ctx).Clauses(clause.OnConflict{
Columns: []clause.Column{{Name: "time"}, {Name: "group_id"}, {Name: "model_name"}},
DoUpdates: clause.AssignmentColumns([]string{"request_count", "success_count", "prompt_tokens", "completion_tokens"}),
}).Create(&hourlyStats).Error
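The aggregation window is always the previous full hour, computed by truncating the current time. A standalone sketch of that window calculation, using only the standard library:

package main

import (
	"fmt"
	"time"
)

func main() {
	now := time.Now().UTC()

	// The window is the previous full hour: at 15:23 we aggregate [14:00, 15:00).
	endTime := now.Truncate(time.Hour)
	startTime := endTime.Add(-1 * time.Hour)

	fmt.Printf("aggregating window: [%s, %s)\n",
		startTime.Format(time.RFC3339), endTime.Format(time.RFC3339))
}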