Update Context for store

This commit is contained in:
XOF
2025-11-22 14:20:05 +08:00
parent ac0e0a8275
commit 2b0b9b67dc
31 changed files with 817 additions and 1016 deletions

View File

@@ -1,8 +1,8 @@
// Filename: internal/service/apikey_service.go
package service
import (
"context"
"encoding/json"
"errors"
"fmt"
@@ -29,7 +29,6 @@ const (
TaskTypeUpdateStatusByFilter = "update_status_by_filter"
)
// DTOs & Constants
const (
TestEndpoint = "https://generativelanguage.googleapis.com/v1beta/models"
)
@@ -83,7 +82,6 @@ func NewAPIKeyService(
gm *GroupManager,
logger *logrus.Logger,
) *APIKeyService {
logger.Debugf("[FORENSIC PROBE | INJECTION | APIKeyService] Received a KeyRepository. Fingerprint: %p", repo)
return &APIKeyService{
db: db,
keyRepo: repo,
@@ -99,22 +97,22 @@ func NewAPIKeyService(
}
func (s *APIKeyService) Start() {
requestSub, err := s.store.Subscribe(models.TopicRequestFinished)
requestSub, err := s.store.Subscribe(context.Background(), models.TopicRequestFinished)
if err != nil {
s.logger.Fatalf("Failed to subscribe to topic %s: %v", models.TopicRequestFinished, err)
return
}
masterKeySub, err := s.store.Subscribe(models.TopicMasterKeyStatusChanged)
masterKeySub, err := s.store.Subscribe(context.Background(), models.TopicMasterKeyStatusChanged)
if err != nil {
s.logger.Fatalf("Failed to subscribe to topic %s: %v", models.TopicMasterKeyStatusChanged, err)
return
}
keyStatusSub, err := s.store.Subscribe(models.TopicKeyStatusChanged)
keyStatusSub, err := s.store.Subscribe(context.Background(), models.TopicKeyStatusChanged)
if err != nil {
s.logger.Fatalf("Failed to subscribe to topic %s: %v", models.TopicKeyStatusChanged, err)
return
}
importSub, err := s.store.Subscribe(models.TopicImportGroupCompleted)
importSub, err := s.store.Subscribe(context.Background(), models.TopicImportGroupCompleted)
if err != nil {
s.logger.Fatalf("Failed to subscribe to topic %s: %v", models.TopicImportGroupCompleted, err)
return
@@ -177,6 +175,7 @@ func (s *APIKeyService) handleKeyUsageEvent(event *models.RequestFinishedEvent)
if event.RequestLog.KeyID == nil || event.RequestLog.GroupID == nil {
return
}
ctx := context.Background()
if event.RequestLog.IsSuccess {
mapping, err := s.keyRepo.GetMapping(*event.RequestLog.GroupID, *event.RequestLog.KeyID)
if err != nil {
@@ -194,17 +193,18 @@ func (s *APIKeyService) handleKeyUsageEvent(event *models.RequestFinishedEvent)
now := time.Now()
mapping.LastUsedAt = &now
if err := s.keyRepo.UpdateMapping(mapping); err != nil {
if err := s.keyRepo.UpdateMapping(ctx, mapping); err != nil {
s.logger.Errorf("[%s] Failed to update mapping for G:%d K:%d after successful request: %v", event.CorrelationID, *event.RequestLog.GroupID, *event.RequestLog.KeyID, err)
return
}
if statusChanged {
go s.publishStatusChangeEvent(*event.RequestLog.GroupID, *event.RequestLog.KeyID, oldStatus, models.StatusActive, "key_recovered_after_use")
go s.publishStatusChangeEvent(ctx, *event.RequestLog.GroupID, *event.RequestLog.KeyID, oldStatus, models.StatusActive, "key_recovered_after_use")
}
return
}
if event.Error != nil {
s.judgeKeyErrors(
ctx,
event.CorrelationID,
*event.RequestLog.GroupID,
*event.RequestLog.KeyID,
@@ -215,6 +215,7 @@ func (s *APIKeyService) handleKeyUsageEvent(event *models.RequestFinishedEvent)
}
func (s *APIKeyService) handleKeyStatusChangeEvent(event *models.KeyStatusChangedEvent) {
ctx := context.Background()
log := s.logger.WithFields(logrus.Fields{
"group_id": event.GroupID,
"key_id": event.KeyID,
@@ -222,11 +223,11 @@ func (s *APIKeyService) handleKeyStatusChangeEvent(event *models.KeyStatusChange
"reason": event.ChangeReason,
})
log.Info("Received KeyStatusChangedEvent, will update polling caches.")
s.keyRepo.HandleCacheUpdateEvent(event.GroupID, event.KeyID, event.NewStatus)
s.keyRepo.HandleCacheUpdateEvent(ctx, event.GroupID, event.KeyID, event.NewStatus)
log.Info("Polling caches updated based on health check event.")
}
func (s *APIKeyService) publishStatusChangeEvent(groupID, keyID uint, oldStatus, newStatus models.APIKeyStatus, reason string) {
func (s *APIKeyService) publishStatusChangeEvent(ctx context.Context, groupID, keyID uint, oldStatus, newStatus models.APIKeyStatus, reason string) {
changeEvent := models.KeyStatusChangedEvent{
KeyID: keyID,
GroupID: groupID,
@@ -236,13 +237,12 @@ func (s *APIKeyService) publishStatusChangeEvent(groupID, keyID uint, oldStatus,
ChangedAt: time.Now(),
}
eventData, _ := json.Marshal(changeEvent)
if err := s.store.Publish(models.TopicKeyStatusChanged, eventData); err != nil {
if err := s.store.Publish(ctx, models.TopicKeyStatusChanged, eventData); err != nil {
s.logger.Errorf("Failed to publish key status changed event for group %d: %v", groupID, err)
}
}
func (s *APIKeyService) ListAPIKeys(params *models.APIKeyQueryParams) (*PaginatedAPIKeys, error) {
// --- Path 1: High-performance DB pagination (no keyword) ---
func (s *APIKeyService) ListAPIKeys(ctx context.Context, params *models.APIKeyQueryParams) (*PaginatedAPIKeys, error) {
if params.Keyword == "" {
items, total, err := s.keyRepo.GetPaginatedKeysAndMappingsByGroup(params)
if err != nil {
@@ -260,14 +260,12 @@ func (s *APIKeyService) ListAPIKeys(params *models.APIKeyQueryParams) (*Paginate
TotalPages: totalPages,
}, nil
}
// --- Path 2: In-memory search (keyword present) ---
s.logger.Infof("Performing heavy in-memory search for group %d with keyword '%s'", params.KeyGroupID, params.Keyword)
// To get all keys, we fetch all IDs first, then get their full details.
var statusesToFilter []string
if params.Status != "" {
statusesToFilter = append(statusesToFilter, params.Status)
} else {
statusesToFilter = append(statusesToFilter, "all") // "all" gets every status
statusesToFilter = append(statusesToFilter, "all")
}
allKeyIDs, err := s.keyRepo.FindKeyIDsByStatus(params.KeyGroupID, statusesToFilter)
if err != nil {
@@ -277,14 +275,12 @@ func (s *APIKeyService) ListAPIKeys(params *models.APIKeyQueryParams) (*Paginate
return &PaginatedAPIKeys{Items: []*models.APIKeyDetails{}, Total: 0, Page: 1, PageSize: params.PageSize, TotalPages: 0}, nil
}
// This is the heavy operation: getting all keys and decrypting them.
allKeys, err := s.keyRepo.GetKeysByIDs(allKeyIDs)
if err != nil {
return nil, fmt.Errorf("failed to fetch all key details for in-memory search: %w", err)
}
// We also need mappings to build the final `APIKeyDetails`.
var allMappings []models.GroupAPIKeyMapping
err = s.db.Where("key_group_id = ? AND api_key_id IN ?", params.KeyGroupID, allKeyIDs).Find(&allMappings).Error
err = s.db.WithContext(ctx).Where("key_group_id = ? AND api_key_id IN ?", params.KeyGroupID, allKeyIDs).Find(&allMappings).Error
if err != nil {
return nil, fmt.Errorf("failed to fetch mappings for in-memory search: %w", err)
}
@@ -292,7 +288,6 @@ func (s *APIKeyService) ListAPIKeys(params *models.APIKeyQueryParams) (*Paginate
for i := range allMappings {
mappingMap[allMappings[i].APIKeyID] = &allMappings[i]
}
// Filter the results in memory.
var filteredItems []*models.APIKeyDetails
for _, key := range allKeys {
if strings.Contains(key.APIKey, params.Keyword) {
@@ -312,11 +307,9 @@ func (s *APIKeyService) ListAPIKeys(params *models.APIKeyQueryParams) (*Paginate
}
}
}
// Sort the filtered results to ensure consistent pagination (by ID descending).
sort.Slice(filteredItems, func(i, j int) bool {
return filteredItems[i].ID > filteredItems[j].ID
})
// Manually paginate the filtered results.
total := int64(len(filteredItems))
start := (params.Page - 1) * params.PageSize
end := start + params.PageSize
@@ -345,14 +338,15 @@ func (s *APIKeyService) ListAPIKeys(params *models.APIKeyQueryParams) (*Paginate
}, nil
}
func (s *APIKeyService) GetKeysByIds(ids []uint) ([]models.APIKey, error) {
func (s *APIKeyService) GetKeysByIds(ctx context.Context, ids []uint) ([]models.APIKey, error) {
return s.keyRepo.GetKeysByIDs(ids)
}
func (s *APIKeyService) UpdateAPIKey(key *models.APIKey) error {
func (s *APIKeyService) UpdateAPIKey(ctx context.Context, key *models.APIKey) error {
go func() {
bgCtx := context.Background()
var oldKey models.APIKey
if err := s.db.First(&oldKey, key.ID).Error; err != nil {
if err := s.db.WithContext(bgCtx).First(&oldKey, key.ID).Error; err != nil {
s.logger.Errorf("Failed to find old key state for ID %d before update: %v", key.ID, err)
return
}
@@ -364,16 +358,14 @@ func (s *APIKeyService) UpdateAPIKey(key *models.APIKey) error {
return nil
}
func (s *APIKeyService) HardDeleteAPIKeyByID(id uint) error {
// Get all associated groups before deletion to publish correct events
groups, err := s.keyRepo.GetGroupsForKey(id)
func (s *APIKeyService) HardDeleteAPIKeyByID(ctx context.Context, id uint) error {
groups, err := s.keyRepo.GetGroupsForKey(ctx, id)
if err != nil {
s.logger.Warnf("Could not get groups for key ID %d before hard deletion (key might be orphaned): %v", id, err)
}
err = s.keyRepo.HardDeleteByID(id)
if err == nil {
// Publish events for each group the key was a part of
for _, groupID := range groups {
event := models.KeyStatusChangedEvent{
KeyID: id,
@@ -381,13 +373,13 @@ func (s *APIKeyService) HardDeleteAPIKeyByID(id uint) error {
ChangeReason: "key_hard_deleted",
}
eventData, _ := json.Marshal(event)
go s.store.Publish(models.TopicKeyStatusChanged, eventData)
go s.store.Publish(context.Background(), models.TopicKeyStatusChanged, eventData)
}
}
return err
}
func (s *APIKeyService) UpdateMappingStatus(groupID, keyID uint, newStatus models.APIKeyStatus) (*models.GroupAPIKeyMapping, error) {
func (s *APIKeyService) UpdateMappingStatus(ctx context.Context, groupID, keyID uint, newStatus models.APIKeyStatus) (*models.GroupAPIKeyMapping, error) {
key, err := s.keyRepo.GetKeyByID(keyID)
if err != nil {
return nil, CustomErrors.ParseDBError(err)
@@ -409,19 +401,20 @@ func (s *APIKeyService) UpdateMappingStatus(groupID, keyID uint, newStatus model
mapping.ConsecutiveErrorCount = 0
mapping.LastError = ""
}
if err := s.keyRepo.UpdateMapping(mapping); err != nil {
if err := s.keyRepo.UpdateMapping(ctx, mapping); err != nil {
return nil, err
}
go s.publishStatusChangeEvent(groupID, keyID, oldStatus, newStatus, "manual_update")
go s.publishStatusChangeEvent(ctx, groupID, keyID, oldStatus, newStatus, "manual_update")
return mapping, nil
}
func (s *APIKeyService) handleMasterKeyStatusChangeEvent(event *models.MasterKeyStatusChangedEvent) {
ctx := context.Background()
s.logger.Infof("Received MasterKeyStatusChangedEvent for Key ID %d: %s -> %s", event.KeyID, event.OldMasterStatus, event.NewMasterStatus)
if event.NewMasterStatus != models.MasterStatusRevoked {
return
}
affectedGroupIDs, err := s.keyRepo.GetGroupsForKey(event.KeyID)
affectedGroupIDs, err := s.keyRepo.GetGroupsForKey(ctx, event.KeyID)
if err != nil {
s.logger.WithError(err).Errorf("Failed to get groups for key ID %d to propagate revoked status.", event.KeyID)
return
@@ -432,7 +425,7 @@ func (s *APIKeyService) handleMasterKeyStatusChangeEvent(event *models.MasterKey
}
s.logger.Infof("Propagating REVOKED astatus for Key ID %d to %d groups.", event.KeyID, len(affectedGroupIDs))
for _, groupID := range affectedGroupIDs {
_, err := s.UpdateMappingStatus(groupID, event.KeyID, models.StatusBanned)
_, err := s.UpdateMappingStatus(ctx, groupID, event.KeyID, models.StatusBanned)
if err != nil {
if !errors.Is(err, gorm.ErrRecordNotFound) {
s.logger.WithError(err).Errorf("Failed to update mapping status to BANNED for Key ID %d in Group ID %d.", event.KeyID, groupID)
@@ -441,32 +434,32 @@ func (s *APIKeyService) handleMasterKeyStatusChangeEvent(event *models.MasterKey
}
}
func (s *APIKeyService) StartRestoreKeysTask(groupID uint, keyIDs []uint) (*task.Status, error) {
func (s *APIKeyService) StartRestoreKeysTask(ctx context.Context, groupID uint, keyIDs []uint) (*task.Status, error) {
if len(keyIDs) == 0 {
return nil, CustomErrors.NewAPIError(CustomErrors.ErrBadRequest, "No key IDs provided for restoration.")
}
resourceID := fmt.Sprintf("group-%d", groupID)
taskStatus, err := s.taskService.StartTask(groupID, TaskTypeRestoreSpecificKeys, resourceID, len(keyIDs), time.Hour)
taskStatus, err := s.taskService.StartTask(ctx, groupID, TaskTypeRestoreSpecificKeys, resourceID, len(keyIDs), time.Hour)
if err != nil {
return nil, err
}
go s.runRestoreKeysTask(taskStatus.ID, resourceID, groupID, keyIDs)
go s.runRestoreKeysTask(ctx, taskStatus.ID, resourceID, groupID, keyIDs)
return taskStatus, nil
}
func (s *APIKeyService) runRestoreKeysTask(taskID string, resourceID string, groupID uint, keyIDs []uint) {
func (s *APIKeyService) runRestoreKeysTask(ctx context.Context, taskID string, resourceID string, groupID uint, keyIDs []uint) {
defer func() {
if r := recover(); r != nil {
s.logger.Errorf("Panic recovered in runRestoreKeysTask: %v", r)
s.taskService.EndTaskByID(taskID, resourceID, nil, fmt.Errorf("panic in restore keys task: %v", r))
s.taskService.EndTaskByID(ctx, taskID, resourceID, nil, fmt.Errorf("panic in restore keys task: %v", r))
}
}()
var mappingsToProcess []models.GroupAPIKeyMapping
err := s.db.Preload("APIKey").
err := s.db.WithContext(ctx).Preload("APIKey").
Where("key_group_id = ? AND api_key_id IN ?", groupID, keyIDs).
Find(&mappingsToProcess).Error
if err != nil {
s.taskService.EndTaskByID(taskID, resourceID, nil, err)
s.taskService.EndTaskByID(ctx, taskID, resourceID, nil, err)
return
}
result := &BatchRestoreResult{
@@ -476,7 +469,7 @@ func (s *APIKeyService) runRestoreKeysTask(taskID string, resourceID string, gro
processedCount := 0
for _, mapping := range mappingsToProcess {
processedCount++
_ = s.taskService.UpdateProgressByID(taskID, processedCount)
_ = s.taskService.UpdateProgressByID(ctx, taskID, processedCount)
if mapping.APIKey == nil {
result.SkippedCount++
result.SkippedKeys = append(result.SkippedKeys, SkippedKeyInfo{KeyID: mapping.APIKeyID, Reason: "Associated APIKey entity not found."})
@@ -492,33 +485,29 @@ func (s *APIKeyService) runRestoreKeysTask(taskID string, resourceID string, gro
mapping.Status = models.StatusActive
mapping.ConsecutiveErrorCount = 0
mapping.LastError = ""
// Use the version that doesn't trigger individual cache updates.
if err := s.keyRepo.UpdateMappingWithoutCache(&mapping); err != nil {
s.logger.WithError(err).Warn("Failed to update mapping in DB during restore task.")
result.SkippedCount++
result.SkippedKeys = append(result.SkippedKeys, SkippedKeyInfo{KeyID: mapping.APIKeyID, Reason: "DB update failed."})
} else {
result.RestoredCount++
successfulMappings = append(successfulMappings, &mapping) // Collect for batch cache update.
go s.publishStatusChangeEvent(groupID, mapping.APIKeyID, oldStatus, models.StatusActive, "batch_restore")
successfulMappings = append(successfulMappings, &mapping)
go s.publishStatusChangeEvent(ctx, groupID, mapping.APIKeyID, oldStatus, models.StatusActive, "batch_restore")
}
} else {
result.RestoredCount++ // Already active, count as success.
result.RestoredCount++
}
}
// After the loop, perform one single, efficient cache update.
if err := s.keyRepo.HandleCacheUpdateEventBatch(successfulMappings); err != nil {
if err := s.keyRepo.HandleCacheUpdateEventBatch(ctx, successfulMappings); err != nil {
s.logger.WithError(err).Error("Failed to perform batch cache update after restore task.")
// This is not a task-fatal error, so we just log it and continue.
}
// Account for keys that were requested but not found in the initial DB query.
result.SkippedCount += (len(keyIDs) - len(mappingsToProcess))
s.taskService.EndTaskByID(taskID, resourceID, result, nil)
s.taskService.EndTaskByID(ctx, taskID, resourceID, result, nil)
}
func (s *APIKeyService) StartRestoreAllBannedTask(groupID uint) (*task.Status, error) {
func (s *APIKeyService) StartRestoreAllBannedTask(ctx context.Context, groupID uint) (*task.Status, error) {
var bannedKeyIDs []uint
err := s.db.Model(&models.GroupAPIKeyMapping{}).
err := s.db.WithContext(ctx).Model(&models.GroupAPIKeyMapping{}).
Where("key_group_id = ? AND status = ?", groupID, models.StatusBanned).
Pluck("api_key_id", &bannedKeyIDs).Error
if err != nil {
@@ -527,10 +516,11 @@ func (s *APIKeyService) StartRestoreAllBannedTask(groupID uint) (*task.Status, e
if len(bannedKeyIDs) == 0 {
return nil, CustomErrors.NewAPIError(CustomErrors.ErrBadRequest, "No banned keys to restore in this group.")
}
return s.StartRestoreKeysTask(groupID, bannedKeyIDs)
return s.StartRestoreKeysTask(ctx, groupID, bannedKeyIDs)
}
func (s *APIKeyService) handlePostImportValidation(event *models.ImportGroupCompletedEvent) {
ctx := context.Background()
group, ok := s.groupManager.GetGroupByID(event.GroupID)
if !ok {
s.logger.Errorf("Group with id %d not found during post-import validation, aborting.", event.GroupID)
@@ -552,7 +542,7 @@ func (s *APIKeyService) handlePostImportValidation(event *models.ImportGroupComp
concurrency = *opConfig.KeyCheckConcurrency
}
if concurrency <= 0 {
concurrency = 10 // Safety fallback
concurrency = 10
}
timeout := time.Duration(globalSettings.KeyCheckTimeoutSeconds) * time.Second
keysToValidate, err := s.keyRepo.GetKeysByIDs(event.KeyIDs)
@@ -571,7 +561,7 @@ func (s *APIKeyService) handlePostImportValidation(event *models.ImportGroupComp
validationErr := s.validationService.ValidateSingleKey(&key, timeout, endpoint)
if validationErr == nil {
s.logger.Infof("Key ID %d PASSED validation. Status -> ACTIVE.", key.ID)
if _, err := s.UpdateMappingStatus(event.GroupID, key.ID, models.StatusActive); err != nil {
if _, err := s.UpdateMappingStatus(ctx, event.GroupID, key.ID, models.StatusActive); err != nil {
s.logger.Errorf("Failed to update status to ACTIVE for Key ID %d in group %d: %v", key.ID, event.GroupID, err)
}
} else {
@@ -579,7 +569,7 @@ func (s *APIKeyService) handlePostImportValidation(event *models.ImportGroupComp
if !CustomErrors.As(validationErr, &apiErr) {
apiErr = &CustomErrors.APIError{Message: validationErr.Error()}
}
s.judgeKeyErrors("", event.GroupID, key.ID, apiErr, false)
s.judgeKeyErrors(ctx, "", event.GroupID, key.ID, apiErr, false)
}
}
}()
@@ -592,12 +582,9 @@ func (s *APIKeyService) handlePostImportValidation(event *models.ImportGroupComp
s.logger.Infof("Finished post-import validation for group %d.", event.GroupID)
}
// [NEW] StartUpdateStatusByFilterTask starts a background job to update the status of keys
// that match a specific set of source statuses within a group.
func (s *APIKeyService) StartUpdateStatusByFilterTask(groupID uint, sourceStatuses []string, newStatus models.APIKeyStatus) (*task.Status, error) {
func (s *APIKeyService) StartUpdateStatusByFilterTask(ctx context.Context, groupID uint, sourceStatuses []string, newStatus models.APIKeyStatus) (*task.Status, error) {
s.logger.Infof("Starting task to update status to '%s' for keys in group %d matching statuses: %v", newStatus, groupID, sourceStatuses)
// 1. Find key IDs using the new repository method. Using IDs is more efficient for updates.
keyIDs, err := s.keyRepo.FindKeyIDsByStatus(groupID, sourceStatuses)
if err != nil {
return nil, CustomErrors.ParseDBError(err)
@@ -605,35 +592,32 @@ func (s *APIKeyService) StartUpdateStatusByFilterTask(groupID uint, sourceStatus
if len(keyIDs) == 0 {
now := time.Now()
return &task.Status{
IsRunning: false, // The "task" is not running.
IsRunning: false,
Processed: 0,
Total: 0,
Result: map[string]string{ // We use the flexible Result field to pass the message.
Result: map[string]string{
"message": "没有找到任何符合当前过滤条件的Key可供操作。",
},
Error: "", // There is no error.
Error: "",
StartedAt: now,
FinishedAt: &now, // It started and finished at the same time.
}, nil // Return nil for the error, signaling a 200 OK.
FinishedAt: &now,
}, nil
}
// 2. Start a new task using the TaskService, following existing patterns.
resourceID := fmt.Sprintf("group-%d-status-update", groupID)
taskStatus, err := s.taskService.StartTask(groupID, TaskTypeUpdateStatusByFilter, resourceID, len(keyIDs), 30*time.Minute)
taskStatus, err := s.taskService.StartTask(ctx, groupID, TaskTypeUpdateStatusByFilter, resourceID, len(keyIDs), 30*time.Minute)
if err != nil {
return nil, err // Pass up errors like "task already in progress".
return nil, err
}
// 3. Run the core logic in a separate goroutine.
go s.runUpdateStatusByFilterTask(taskStatus.ID, resourceID, groupID, keyIDs, newStatus)
go s.runUpdateStatusByFilterTask(ctx, taskStatus.ID, resourceID, groupID, keyIDs, newStatus)
return taskStatus, nil
}
// [NEW] runUpdateStatusByFilterTask is the private worker function for the above task.
func (s *APIKeyService) runUpdateStatusByFilterTask(taskID, resourceID string, groupID uint, keyIDs []uint, newStatus models.APIKeyStatus) {
func (s *APIKeyService) runUpdateStatusByFilterTask(ctx context.Context, taskID, resourceID string, groupID uint, keyIDs []uint, newStatus models.APIKeyStatus) {
defer func() {
if r := recover(); r != nil {
s.logger.Errorf("Panic recovered in runUpdateStatusByFilterTask: %v", r)
s.taskService.EndTaskByID(taskID, resourceID, nil, fmt.Errorf("panic in update status task: %v", r))
s.taskService.EndTaskByID(ctx, taskID, resourceID, nil, fmt.Errorf("panic in update status task: %v", r))
}
}()
type BatchUpdateResult struct {
@@ -642,31 +626,27 @@ func (s *APIKeyService) runUpdateStatusByFilterTask(taskID, resourceID string, g
}
result := &BatchUpdateResult{}
var successfulMappings []*models.GroupAPIKeyMapping
// 1. Fetch all key master statuses in one go. This is efficient.
keys, err := s.keyRepo.GetKeysByIDs(keyIDs)
if err != nil {
s.logger.WithError(err).Error("Failed to fetch keys for bulk status update, aborting task.")
s.taskService.EndTaskByID(taskID, resourceID, nil, err)
s.taskService.EndTaskByID(ctx, taskID, resourceID, nil, err)
return
}
masterStatusMap := make(map[uint]models.MasterAPIKeyStatus)
for _, key := range keys {
masterStatusMap[key.ID] = key.MasterStatus
}
// 2. [THE REFINEMENT] Fetch all relevant mappings directly using s.db,
// avoiding the need for a new repository method. This pattern is
// already used in other parts of this service.
var mappings []*models.GroupAPIKeyMapping
if err := s.db.Where("key_group_id = ? AND api_key_id IN ?", groupID, keyIDs).Find(&mappings).Error; err != nil {
if err := s.db.WithContext(ctx).Where("key_group_id = ? AND api_key_id IN ?", groupID, keyIDs).Find(&mappings).Error; err != nil {
s.logger.WithError(err).Error("Failed to fetch mappings for bulk status update, aborting task.")
s.taskService.EndTaskByID(taskID, resourceID, nil, err)
s.taskService.EndTaskByID(ctx, taskID, resourceID, nil, err)
return
}
processedCount := 0
for _, mapping := range mappings {
processedCount++
// The progress update should reflect the number of items *being processed*, not the final count.
_ = s.taskService.UpdateProgressByID(taskID, processedCount)
_ = s.taskService.UpdateProgressByID(ctx, taskID, processedCount)
masterStatus, ok := masterStatusMap[mapping.APIKeyID]
if !ok {
result.SkippedCount++
@@ -688,24 +668,25 @@ func (s *APIKeyService) runUpdateStatusByFilterTask(taskID, resourceID string, g
} else {
result.UpdatedCount++
successfulMappings = append(successfulMappings, mapping)
go s.publishStatusChangeEvent(groupID, mapping.APIKeyID, oldStatus, newStatus, "batch_status_update")
go s.publishStatusChangeEvent(ctx, groupID, mapping.APIKeyID, oldStatus, newStatus, "batch_status_update")
}
} else {
result.UpdatedCount++ // Already in desired state, count as success.
result.UpdatedCount++
}
}
result.SkippedCount += (len(keyIDs) - len(mappings))
if err := s.keyRepo.HandleCacheUpdateEventBatch(successfulMappings); err != nil {
if err := s.keyRepo.HandleCacheUpdateEventBatch(ctx, successfulMappings); err != nil {
s.logger.WithError(err).Error("Failed to perform batch cache update after status update task.")
}
s.logger.Infof("Finished bulk status update task '%s' for group %d. Updated: %d, Skipped: %d.", taskID, groupID, result.UpdatedCount, result.SkippedCount)
s.taskService.EndTaskByID(taskID, resourceID, result, nil)
s.taskService.EndTaskByID(ctx, taskID, resourceID, result, nil)
}
func (s *APIKeyService) HandleRequestResult(group *models.KeyGroup, key *models.APIKey, success bool, apiErr *CustomErrors.APIError) {
ctx := context.Background()
if success {
if group.PollingStrategy == models.StrategyWeighted {
go s.keyRepo.UpdateKeyUsageTimestamp(group.ID, key.ID)
go s.keyRepo.UpdateKeyUsageTimestamp(ctx, group.ID, key.ID)
}
return
}
@@ -716,26 +697,20 @@ func (s *APIKeyService) HandleRequestResult(group *models.KeyGroup, key *models.
errMsg := apiErr.Message
if CustomErrors.IsPermanentUpstreamError(errMsg) || CustomErrors.IsTemporaryUpstreamError(errMsg) {
s.logger.Warnf("Request for KeyID %d in GroupID %d failed with a key-specific error: %s. Issuing order to move to cooldown.", key.ID, group.ID, errMsg)
go s.keyRepo.SyncKeyStatusInPollingCaches(group.ID, key.ID, models.StatusCooldown)
go s.keyRepo.SyncKeyStatusInPollingCaches(ctx, group.ID, key.ID, models.StatusCooldown)
} else {
s.logger.Infof("Request for KeyID %d in GroupID %d failed with a non-key-specific error. No punitive action taken. Error: %s", key.ID, group.ID, errMsg)
}
}
// sanitizeForLog truncates long error messages and removes JSON-like content for cleaner Info/Error logs.
func sanitizeForLog(errMsg string) string {
// Find the start of any potential JSON blob or detailed structure.
jsonStartIndex := strings.Index(errMsg, "{")
var cleanMsg string
if jsonStartIndex != -1 {
// If a '{' is found, take everything before it as the summary
// and append a simple placeholder.
cleanMsg = strings.TrimSpace(errMsg[:jsonStartIndex]) + " {...}"
} else {
// If no JSON-like structure is found, use the original message.
cleanMsg = errMsg
}
// Always apply a final length truncation as a safeguard against extremely long non-JSON errors.
const maxLen = 250
if len(cleanMsg) > maxLen {
return cleanMsg[:maxLen] + "..."
@@ -744,6 +719,7 @@ func sanitizeForLog(errMsg string) string {
}
func (s *APIKeyService) judgeKeyErrors(
ctx context.Context,
correlationID string,
groupID, keyID uint,
apiErr *CustomErrors.APIError,
@@ -765,11 +741,11 @@ func (s *APIKeyService) judgeKeyErrors(
oldStatus := mapping.Status
mapping.Status = models.StatusBanned
mapping.LastError = errorMessage
if err := s.keyRepo.UpdateMapping(mapping); err != nil {
if err := s.keyRepo.UpdateMapping(ctx, mapping); err != nil {
logger.WithError(err).Error("Failed to update mapping status to BANNED.")
} else {
go s.publishStatusChangeEvent(groupID, keyID, oldStatus, mapping.Status, "permanent_error_banned")
go s.revokeMasterKey(keyID, "permanent_upstream_error")
go s.publishStatusChangeEvent(ctx, groupID, keyID, oldStatus, mapping.Status, "permanent_error_banned")
go s.revokeMasterKey(ctx, keyID, "permanent_upstream_error")
}
}
return
@@ -801,23 +777,23 @@ func (s *APIKeyService) judgeKeyErrors(
if oldStatus != newStatus {
mapping.Status = newStatus
}
if err := s.keyRepo.UpdateMapping(mapping); err != nil {
if err := s.keyRepo.UpdateMapping(ctx, mapping); err != nil {
logger.WithError(err).Error("Failed to update mapping after temporary error.")
return
}
if oldStatus != newStatus {
go s.publishStatusChangeEvent(groupID, keyID, oldStatus, newStatus, "error_threshold_reached")
go s.publishStatusChangeEvent(ctx, groupID, keyID, oldStatus, newStatus, "error_threshold_reached")
}
return
}
logger.Infof("Ignored truly ignorable upstream error. Only updating LastUsedAt. Reason: %s", sanitizeForLog(errorMessage))
logger.WithField("full_error_details", errorMessage).Debug("Full details of the ignorable error.")
if err := s.keyRepo.UpdateMapping(mapping); err != nil {
if err := s.keyRepo.UpdateMapping(ctx, mapping); err != nil {
logger.WithError(err).Error("Failed to update LastUsedAt for ignorable error.")
}
}
func (s *APIKeyService) revokeMasterKey(keyID uint, reason string) {
func (s *APIKeyService) revokeMasterKey(ctx context.Context, keyID uint, reason string) {
key, err := s.keyRepo.GetKeyByID(keyID)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
@@ -832,7 +808,7 @@ func (s *APIKeyService) revokeMasterKey(keyID uint, reason string) {
}
oldMasterStatus := key.MasterStatus
newMasterStatus := models.MasterStatusRevoked
if err := s.keyRepo.UpdateMasterStatusByID(keyID, newMasterStatus); err != nil {
if err := s.keyRepo.UpdateAPIKeyStatus(ctx, keyID, newMasterStatus); err != nil {
s.logger.Errorf("Failed to update master status to REVOKED for key ID %d: %v", keyID, err)
return
}
@@ -844,9 +820,9 @@ func (s *APIKeyService) revokeMasterKey(keyID uint, reason string) {
ChangedAt: time.Now(),
}
eventData, _ := json.Marshal(masterKeyEvent)
_ = s.store.Publish(models.TopicMasterKeyStatusChanged, eventData)
_ = s.store.Publish(ctx, models.TopicMasterKeyStatusChanged, eventData)
}
func (s *APIKeyService) GetAPIKeyStringsForExport(groupID uint, statuses []string) ([]string, error) {
func (s *APIKeyService) GetAPIKeyStringsForExport(ctx context.Context, groupID uint, statuses []string) ([]string, error) {
return s.keyRepo.GetKeyStringsByGroupAndStatus(groupID, statuses)
}