diff --git a/internal/container/container.go b/internal/container/container.go index d83a51d..f9aeb1e 100644 --- a/internal/container/container.go +++ b/internal/container/container.go @@ -92,7 +92,6 @@ func BuildContainer() (*dig.Container, error) { // =========== 阶段三: 适配器与处理器层 (Handlers & Adapters) =========== - // 为Channel提供依赖 (Logger 和 *models.SystemSettings 数据插座) container.Provide(channel.NewGeminiChannel) container.Provide(func(ch *channel.GeminiChannel) channel.ChannelProxy { return ch }) diff --git a/internal/handlers/proxy_handler.go b/internal/handlers/proxy_handler.go index 1271e51..109215f 100644 --- a/internal/handlers/proxy_handler.go +++ b/internal/handlers/proxy_handler.go @@ -197,8 +197,8 @@ func (h *ProxyHandler) executeProxyAttempt(c *gin.Context, corrID string, body [ var attemptErr *errors.APIError var isSuccess bool - connectTimeout := time.Duration(h.settingsManager.GetSettings().ConnectTimeoutSeconds) * time.Second - ctx, cancel := context.WithTimeout(c.Request.Context(), connectTimeout) + requestTimeout := time.Duration(h.settingsManager.GetSettings().RequestTimeoutSeconds) * time.Second + ctx, cancel := context.WithTimeout(c.Request.Context(), requestTimeout) defer cancel() attemptReq := c.Request.Clone(ctx) diff --git a/internal/service/apikey_service.go b/internal/service/apikey_service.go index c75f20f..308c392 100644 --- a/internal/service/apikey_service.go +++ b/internal/service/apikey_service.go @@ -29,10 +29,6 @@ const ( TaskTypeUpdateStatusByFilter = "update_status_by_filter" ) -const ( - TestEndpoint = "https://generativelanguage.googleapis.com/v1beta/models" -) - type BatchRestoreResult struct { RestoredCount int `json:"restored_count"` SkippedCount int `json:"skipped_count"` @@ -52,12 +48,6 @@ type PaginatedAPIKeys struct { TotalPages int `json:"total_pages"` } -type KeyTestResult struct { - Key string `json:"key"` - Status string `json:"status"` - Message string `json:"message"` -} - type APIKeyService struct { db *gorm.DB keyRepo repository.KeyRepository @@ -99,37 +89,39 @@ func NewAPIKeyService( func (s *APIKeyService) Start() { requestSub, err := s.store.Subscribe(context.Background(), models.TopicRequestFinished) if err != nil { - s.logger.Fatalf("Failed to subscribe to topic %s: %v", models.TopicRequestFinished, err) + s.logger.Fatalf("Failed to subscribe to %s: %v", models.TopicRequestFinished, err) return } masterKeySub, err := s.store.Subscribe(context.Background(), models.TopicMasterKeyStatusChanged) if err != nil { - s.logger.Fatalf("Failed to subscribe to topic %s: %v", models.TopicMasterKeyStatusChanged, err) + s.logger.Fatalf("Failed to subscribe to %s: %v", models.TopicMasterKeyStatusChanged, err) return } keyStatusSub, err := s.store.Subscribe(context.Background(), models.TopicKeyStatusChanged) if err != nil { - s.logger.Fatalf("Failed to subscribe to topic %s: %v", models.TopicKeyStatusChanged, err) + s.logger.Fatalf("Failed to subscribe to %s: %v", models.TopicKeyStatusChanged, err) return } importSub, err := s.store.Subscribe(context.Background(), models.TopicImportGroupCompleted) if err != nil { - s.logger.Fatalf("Failed to subscribe to topic %s: %v", models.TopicImportGroupCompleted, err) + s.logger.Fatalf("Failed to subscribe to %s: %v", models.TopicImportGroupCompleted, err) return } - s.logger.Info("Started and subscribed to request, master key, health check, and import events.") + + s.logger.Info("Started and subscribed to all event topics") go func() { defer requestSub.Close() defer masterKeySub.Close() defer keyStatusSub.Close() defer importSub.Close() + for { select { case msg := <-requestSub.Channel(): var event models.RequestFinishedEvent if err := json.Unmarshal(msg.Payload, &event); err != nil { - s.logger.Errorf("Failed to unmarshal event for key status update: %v", err) + s.logger.WithError(err).Error("Failed to unmarshal RequestFinishedEvent") continue } s.handleKeyUsageEvent(&event) @@ -137,14 +129,15 @@ func (s *APIKeyService) Start() { case msg := <-masterKeySub.Channel(): var event models.MasterKeyStatusChangedEvent if err := json.Unmarshal(msg.Payload, &event); err != nil { - s.logger.Errorf("Failed to unmarshal MasterKeyStatusChangedEvent: %v", err) + s.logger.WithError(err).Error("Failed to unmarshal MasterKeyStatusChangedEvent") continue } s.handleMasterKeyStatusChangeEvent(&event) + case msg := <-keyStatusSub.Channel(): var event models.KeyStatusChangedEvent if err := json.Unmarshal(msg.Payload, &event); err != nil { - s.logger.Errorf("Failed to unmarshal KeyStatusChangedEvent: %v", err) + s.logger.WithError(err).Error("Failed to unmarshal KeyStatusChangedEvent") continue } s.handleKeyStatusChangeEvent(&event) @@ -152,15 +145,14 @@ func (s *APIKeyService) Start() { case msg := <-importSub.Channel(): var event models.ImportGroupCompletedEvent if err := json.Unmarshal(msg.Payload, &event); err != nil { - s.logger.WithError(err).Error("Failed to unmarshal ImportGroupCompletedEvent.") + s.logger.WithError(err).Error("Failed to unmarshal ImportGroupCompletedEvent") continue } - s.logger.Infof("Received ImportGroupCompletedEvent for group %d, triggering validation for %d keys.", event.GroupID, len(event.KeyIDs)) - + s.logger.Infof("Received import completion for group %d, validating %d keys", event.GroupID, len(event.KeyIDs)) go s.handlePostImportValidation(&event) case <-s.stopChan: - s.logger.Info("Stopping event listener.") + s.logger.Info("Stopping event listener") return } } @@ -175,13 +167,18 @@ func (s *APIKeyService) handleKeyUsageEvent(event *models.RequestFinishedEvent) if event.RequestLog.KeyID == nil || event.RequestLog.GroupID == nil { return } + ctx := context.Background() + groupID := *event.RequestLog.GroupID + keyID := *event.RequestLog.KeyID + if event.RequestLog.IsSuccess { - mapping, err := s.keyRepo.GetMapping(*event.RequestLog.GroupID, *event.RequestLog.KeyID) + mapping, err := s.keyRepo.GetMapping(groupID, keyID) if err != nil { - s.logger.Warnf("[%s] Could not find mapping for G:%d K:%d on successful request: %v", event.CorrelationID, *event.RequestLog.GroupID, *event.RequestLog.KeyID, err) + s.logger.Warnf("[%s] Mapping not found for G:%d K:%d: %v", event.CorrelationID, groupID, keyID, err) return } + statusChanged := false oldStatus := mapping.Status if mapping.Status != models.StatusActive { @@ -193,38 +190,33 @@ func (s *APIKeyService) handleKeyUsageEvent(event *models.RequestFinishedEvent) now := time.Now() mapping.LastUsedAt = &now + if err := s.keyRepo.UpdateMapping(ctx, mapping); err != nil { - s.logger.Errorf("[%s] Failed to update mapping for G:%d K:%d after successful request: %v", event.CorrelationID, *event.RequestLog.GroupID, *event.RequestLog.KeyID, err) + s.logger.Errorf("[%s] Failed to update mapping for G:%d K:%d: %v", event.CorrelationID, groupID, keyID, err) return } + if statusChanged { - go s.publishStatusChangeEvent(ctx, *event.RequestLog.GroupID, *event.RequestLog.KeyID, oldStatus, models.StatusActive, "key_recovered_after_use") + go s.publishStatusChangeEvent(ctx, groupID, keyID, oldStatus, models.StatusActive, "key_recovered_after_use") } return } + if event.Error != nil { - s.judgeKeyErrors( - ctx, - event.CorrelationID, - *event.RequestLog.GroupID, - *event.RequestLog.KeyID, - event.Error, - event.IsPreciseRouting, - ) + s.judgeKeyErrors(ctx, event.CorrelationID, groupID, keyID, event.Error, event.IsPreciseRouting) } } func (s *APIKeyService) handleKeyStatusChangeEvent(event *models.KeyStatusChangedEvent) { ctx := context.Background() - log := s.logger.WithFields(logrus.Fields{ + s.logger.WithFields(logrus.Fields{ "group_id": event.GroupID, "key_id": event.KeyID, "new_status": event.NewStatus, "reason": event.ChangeReason, - }) - log.Info("Received KeyStatusChangedEvent, will update polling caches.") + }).Info("Updating polling caches based on status change") + s.keyRepo.HandleCacheUpdateEvent(ctx, event.GroupID, event.KeyID, event.NewStatus) - log.Info("Polling caches updated based on health check event.") } func (s *APIKeyService) publishStatusChangeEvent(ctx context.Context, groupID, keyID uint, oldStatus, newStatus models.APIKeyStatus, reason string) { @@ -238,7 +230,7 @@ func (s *APIKeyService) publishStatusChangeEvent(ctx context.Context, groupID, k } eventData, _ := json.Marshal(changeEvent) if err := s.store.Publish(ctx, models.TopicKeyStatusChanged, eventData); err != nil { - s.logger.Errorf("Failed to publish key status changed event for group %d: %v", groupID, err) + s.logger.Errorf("Failed to publish status change event for group %d: %v", groupID, err) } } @@ -248,10 +240,12 @@ func (s *APIKeyService) ListAPIKeys(ctx context.Context, params *models.APIKeyQu if err != nil { return nil, err } + totalPages := 0 if total > 0 && params.PageSize > 0 { totalPages = int(math.Ceil(float64(total) / float64(params.PageSize))) } + return &PaginatedAPIKeys{ Items: items, Total: total, @@ -260,34 +254,44 @@ func (s *APIKeyService) ListAPIKeys(ctx context.Context, params *models.APIKeyQu TotalPages: totalPages, }, nil } - s.logger.Infof("Performing heavy in-memory search for group %d with keyword '%s'", params.KeyGroupID, params.Keyword) - var statusesToFilter []string + + s.logger.Infof("Performing in-memory search for group %d with keyword '%s'", params.KeyGroupID, params.Keyword) + + statusesToFilter := []string{"all"} if params.Status != "" { - statusesToFilter = append(statusesToFilter, params.Status) - } else { - statusesToFilter = append(statusesToFilter, "all") + statusesToFilter = []string{params.Status} } + allKeyIDs, err := s.keyRepo.FindKeyIDsByStatus(params.KeyGroupID, statusesToFilter) if err != nil { - return nil, fmt.Errorf("failed to fetch all key IDs for search: %w", err) + return nil, fmt.Errorf("failed to fetch key IDs: %w", err) } + if len(allKeyIDs) == 0 { - return &PaginatedAPIKeys{Items: []*models.APIKeyDetails{}, Total: 0, Page: 1, PageSize: params.PageSize, TotalPages: 0}, nil + return &PaginatedAPIKeys{ + Items: []*models.APIKeyDetails{}, + Total: 0, + Page: 1, + PageSize: params.PageSize, + TotalPages: 0, + }, nil } allKeys, err := s.keyRepo.GetKeysByIDs(allKeyIDs) if err != nil { - return nil, fmt.Errorf("failed to fetch all key details for in-memory search: %w", err) + return nil, fmt.Errorf("failed to fetch keys: %w", err) } + var allMappings []models.GroupAPIKeyMapping - err = s.db.WithContext(ctx).Where("key_group_id = ? AND api_key_id IN ?", params.KeyGroupID, allKeyIDs).Find(&allMappings).Error - if err != nil { - return nil, fmt.Errorf("failed to fetch mappings for in-memory search: %w", err) + if err := s.db.WithContext(ctx).Where("key_group_id = ? AND api_key_id IN ?", params.KeyGroupID, allKeyIDs).Find(&allMappings).Error; err != nil { + return nil, fmt.Errorf("failed to fetch mappings: %w", err) } - mappingMap := make(map[uint]*models.GroupAPIKeyMapping) + + mappingMap := make(map[uint]*models.GroupAPIKeyMapping, len(allMappings)) for i := range allMappings { mappingMap[allMappings[i].APIKeyID] = &allMappings[i] } + var filteredItems []*models.APIKeyDetails for _, key := range allKeys { if strings.Contains(key.APIKey, params.Keyword) { @@ -307,12 +311,15 @@ func (s *APIKeyService) ListAPIKeys(ctx context.Context, params *models.APIKeyQu } } } + sort.Slice(filteredItems, func(i, j int) bool { return filteredItems[i].ID > filteredItems[j].ID }) + total := int64(len(filteredItems)) start := (params.Page - 1) * params.PageSize end := start + params.PageSize + if start < 0 { start = 0 } @@ -328,9 +335,9 @@ func (s *APIKeyService) ListAPIKeys(ctx context.Context, params *models.APIKeyQu if end > len(filteredItems) { end = len(filteredItems) } - paginatedItems := filteredItems[start:end] + return &PaginatedAPIKeys{ - Items: paginatedItems, + Items: filteredItems[start:end], Total: total, Page: params.Page, PageSize: params.PageSize, @@ -344,15 +351,8 @@ func (s *APIKeyService) GetKeysByIds(ctx context.Context, ids []uint) ([]models. func (s *APIKeyService) UpdateAPIKey(ctx context.Context, key *models.APIKey) error { go func() { - bgCtx := context.Background() - var oldKey models.APIKey - if err := s.db.WithContext(bgCtx).First(&oldKey, key.ID).Error; err != nil { - s.logger.Errorf("Failed to find old key state for ID %d before update: %v", key.ID, err) - return - } if err := s.keyRepo.Update(key); err != nil { - s.logger.Errorf("Failed to asynchronously update key ID %d: %v", key.ID, err) - return + s.logger.Errorf("Failed to update key ID %d: %v", key.ID, err) } }() return nil @@ -361,22 +361,24 @@ func (s *APIKeyService) UpdateAPIKey(ctx context.Context, key *models.APIKey) er func (s *APIKeyService) HardDeleteAPIKeyByID(ctx context.Context, id uint) error { groups, err := s.keyRepo.GetGroupsForKey(ctx, id) if err != nil { - s.logger.Warnf("Could not get groups for key ID %d before hard deletion (key might be orphaned): %v", id, err) + s.logger.Warnf("Could not get groups for key %d before deletion: %v", id, err) } - err = s.keyRepo.HardDeleteByID(id) - if err == nil { - for _, groupID := range groups { - event := models.KeyStatusChangedEvent{ - KeyID: id, - GroupID: groupID, - ChangeReason: "key_hard_deleted", - } - eventData, _ := json.Marshal(event) - go s.store.Publish(context.Background(), models.TopicKeyStatusChanged, eventData) - } + if err := s.keyRepo.HardDeleteByID(id); err != nil { + return err } - return err + + for _, groupID := range groups { + event := models.KeyStatusChangedEvent{ + KeyID: id, + GroupID: groupID, + ChangeReason: "key_hard_deleted", + } + eventData, _ := json.Marshal(event) + go s.store.Publish(context.Background(), models.TopicKeyStatusChanged, eventData) + } + + return nil } func (s *APIKeyService) UpdateMappingStatus(ctx context.Context, groupID, keyID uint, newStatus models.APIKeyStatus) (*models.GroupAPIKeyMapping, error) { @@ -388,47 +390,54 @@ func (s *APIKeyService) UpdateMappingStatus(ctx context.Context, groupID, keyID if key.MasterStatus != models.MasterStatusActive && newStatus == models.StatusActive { return nil, CustomErrors.ErrStateConflictMasterRevoked } + mapping, err := s.keyRepo.GetMapping(groupID, keyID) if err != nil { return nil, err } + oldStatus := mapping.Status if oldStatus == newStatus { return mapping, nil } + mapping.Status = newStatus if newStatus == models.StatusActive { mapping.ConsecutiveErrorCount = 0 mapping.LastError = "" } + if err := s.keyRepo.UpdateMapping(ctx, mapping); err != nil { return nil, err } + go s.publishStatusChangeEvent(ctx, groupID, keyID, oldStatus, newStatus, "manual_update") return mapping, nil } func (s *APIKeyService) handleMasterKeyStatusChangeEvent(event *models.MasterKeyStatusChangedEvent) { - ctx := context.Background() - s.logger.Infof("Received MasterKeyStatusChangedEvent for Key ID %d: %s -> %s", event.KeyID, event.OldMasterStatus, event.NewMasterStatus) if event.NewMasterStatus != models.MasterStatusRevoked { return } + + ctx := context.Background() + s.logger.Infof("Key %d revoked, propagating to all groups", event.KeyID) + affectedGroupIDs, err := s.keyRepo.GetGroupsForKey(ctx, event.KeyID) if err != nil { - s.logger.WithError(err).Errorf("Failed to get groups for key ID %d to propagate revoked status.", event.KeyID) + s.logger.WithError(err).Errorf("Failed to get groups for key %d", event.KeyID) return } + if len(affectedGroupIDs) == 0 { - s.logger.Infof("Key ID %d is revoked, but it's not associated with any group. No action needed.", event.KeyID) + s.logger.Infof("Key %d not associated with any group", event.KeyID) return } - s.logger.Infof("Propagating REVOKED astatus for Key ID %d to %d groups.", event.KeyID, len(affectedGroupIDs)) + for _, groupID := range affectedGroupIDs { - _, err := s.UpdateMappingStatus(ctx, groupID, event.KeyID, models.StatusBanned) - if err != nil { + if _, err := s.UpdateMappingStatus(ctx, groupID, event.KeyID, models.StatusBanned); err != nil { if !errors.Is(err, gorm.ErrRecordNotFound) { - s.logger.WithError(err).Errorf("Failed to update mapping status to BANNED for Key ID %d in Group ID %d.", event.KeyID, groupID) + s.logger.WithError(err).Errorf("Failed to ban key %d in group %d", event.KeyID, groupID) } } } @@ -436,59 +445,71 @@ func (s *APIKeyService) handleMasterKeyStatusChangeEvent(event *models.MasterKey func (s *APIKeyService) StartRestoreKeysTask(ctx context.Context, groupID uint, keyIDs []uint) (*task.Status, error) { if len(keyIDs) == 0 { - return nil, CustomErrors.NewAPIError(CustomErrors.ErrBadRequest, "No key IDs provided for restoration.") + return nil, CustomErrors.NewAPIError(CustomErrors.ErrBadRequest, "No key IDs provided") } + resourceID := fmt.Sprintf("group-%d", groupID) taskStatus, err := s.taskService.StartTask(ctx, groupID, TaskTypeRestoreSpecificKeys, resourceID, len(keyIDs), time.Hour) if err != nil { return nil, err } + go s.runRestoreKeysTask(ctx, taskStatus.ID, resourceID, groupID, keyIDs) return taskStatus, nil } -func (s *APIKeyService) runRestoreKeysTask(ctx context.Context, taskID string, resourceID string, groupID uint, keyIDs []uint) { +func (s *APIKeyService) runRestoreKeysTask(ctx context.Context, taskID, resourceID string, groupID uint, keyIDs []uint) { defer func() { if r := recover(); r != nil { - s.logger.Errorf("Panic recovered in runRestoreKeysTask: %v", r) - s.taskService.EndTaskByID(ctx, taskID, resourceID, nil, fmt.Errorf("panic in restore keys task: %v", r)) + s.logger.Errorf("Panic in restore task: %v", r) + s.taskService.EndTaskByID(ctx, taskID, resourceID, nil, fmt.Errorf("panic: %v", r)) } }() + var mappingsToProcess []models.GroupAPIKeyMapping - err := s.db.WithContext(ctx).Preload("APIKey"). + if err := s.db.WithContext(ctx).Preload("APIKey"). Where("key_group_id = ? AND api_key_id IN ?", groupID, keyIDs). - Find(&mappingsToProcess).Error - if err != nil { + Find(&mappingsToProcess).Error; err != nil { s.taskService.EndTaskByID(ctx, taskID, resourceID, nil, err) return } - result := &BatchRestoreResult{ - SkippedKeys: make([]SkippedKeyInfo, 0), - } + + result := &BatchRestoreResult{SkippedKeys: make([]SkippedKeyInfo, 0)} var successfulMappings []*models.GroupAPIKeyMapping - processedCount := 0 - for _, mapping := range mappingsToProcess { - processedCount++ - _ = s.taskService.UpdateProgressByID(ctx, taskID, processedCount) + + for i, mapping := range mappingsToProcess { + _ = s.taskService.UpdateProgressByID(ctx, taskID, i+1) + if mapping.APIKey == nil { result.SkippedCount++ - result.SkippedKeys = append(result.SkippedKeys, SkippedKeyInfo{KeyID: mapping.APIKeyID, Reason: "Associated APIKey entity not found."}) + result.SkippedKeys = append(result.SkippedKeys, SkippedKeyInfo{ + KeyID: mapping.APIKeyID, + Reason: "APIKey not found", + }) continue } + if mapping.APIKey.MasterStatus != models.MasterStatusActive { result.SkippedCount++ - result.SkippedKeys = append(result.SkippedKeys, SkippedKeyInfo{KeyID: mapping.APIKeyID, Reason: fmt.Sprintf("Master status is '%s'.", mapping.APIKey.MasterStatus)}) + result.SkippedKeys = append(result.SkippedKeys, SkippedKeyInfo{ + KeyID: mapping.APIKeyID, + Reason: fmt.Sprintf("Master status is %s", mapping.APIKey.MasterStatus), + }) continue } + oldStatus := mapping.Status if oldStatus != models.StatusActive { mapping.Status = models.StatusActive mapping.ConsecutiveErrorCount = 0 mapping.LastError = "" + if err := s.keyRepo.UpdateMappingWithoutCache(&mapping); err != nil { - s.logger.WithError(err).Warn("Failed to update mapping in DB during restore task.") result.SkippedCount++ - result.SkippedKeys = append(result.SkippedKeys, SkippedKeyInfo{KeyID: mapping.APIKeyID, Reason: "DB update failed."}) + result.SkippedKeys = append(result.SkippedKeys, SkippedKeyInfo{ + KeyID: mapping.APIKeyID, + Reason: "DB update failed", + }) } else { result.RestoredCount++ successfulMappings = append(successfulMappings, &mapping) @@ -498,61 +519,72 @@ func (s *APIKeyService) runRestoreKeysTask(ctx context.Context, taskID string, r result.RestoredCount++ } } + if err := s.keyRepo.HandleCacheUpdateEventBatch(ctx, successfulMappings); err != nil { - s.logger.WithError(err).Error("Failed to perform batch cache update after restore task.") + s.logger.WithError(err).Error("Failed batch cache update after restore") } + result.SkippedCount += (len(keyIDs) - len(mappingsToProcess)) s.taskService.EndTaskByID(ctx, taskID, resourceID, result, nil) } func (s *APIKeyService) StartRestoreAllBannedTask(ctx context.Context, groupID uint) (*task.Status, error) { var bannedKeyIDs []uint - err := s.db.WithContext(ctx).Model(&models.GroupAPIKeyMapping{}). + if err := s.db.WithContext(ctx).Model(&models.GroupAPIKeyMapping{}). Where("key_group_id = ? AND status = ?", groupID, models.StatusBanned). - Pluck("api_key_id", &bannedKeyIDs).Error - if err != nil { + Pluck("api_key_id", &bannedKeyIDs).Error; err != nil { return nil, CustomErrors.ParseDBError(err) } + if len(bannedKeyIDs) == 0 { - return nil, CustomErrors.NewAPIError(CustomErrors.ErrBadRequest, "No banned keys to restore in this group.") + return nil, CustomErrors.NewAPIError(CustomErrors.ErrBadRequest, "No banned keys to restore") } + return s.StartRestoreKeysTask(ctx, groupID, bannedKeyIDs) } func (s *APIKeyService) handlePostImportValidation(event *models.ImportGroupCompletedEvent) { ctx := context.Background() + group, ok := s.groupManager.GetGroupByID(event.GroupID) if !ok { - s.logger.Errorf("Group with id %d not found during post-import validation, aborting.", event.GroupID) + s.logger.Errorf("Group %d not found for post-import validation", event.GroupID) return } + opConfig, err := s.groupManager.BuildOperationalConfig(group) if err != nil { - s.logger.Errorf("Failed to build operational config for group %d, aborting validation: %v", event.GroupID, err) + s.logger.Errorf("Failed to build config for group %d: %v", event.GroupID, err) return } + endpoint, err := s.groupManager.BuildKeyCheckEndpoint(event.GroupID) if err != nil { - s.logger.Errorf("Failed to build key check endpoint for group %d, aborting validation: %v", event.GroupID, err) + s.logger.Errorf("Failed to build endpoint for group %d: %v", event.GroupID, err) return } - globalSettings := s.SettingsManager.GetSettings() - concurrency := globalSettings.BaseKeyCheckConcurrency + + concurrency := s.SettingsManager.GetSettings().BaseKeyCheckConcurrency if opConfig.KeyCheckConcurrency != nil { concurrency = *opConfig.KeyCheckConcurrency } if concurrency <= 0 { concurrency = 10 } - timeout := time.Duration(globalSettings.KeyCheckTimeoutSeconds) * time.Second + + timeout := time.Duration(s.SettingsManager.GetSettings().KeyCheckTimeoutSeconds) * time.Second + keysToValidate, err := s.keyRepo.GetKeysByIDs(event.KeyIDs) if err != nil { - s.logger.Errorf("Failed to get key models for validation in group %d: %v", event.GroupID, err) + s.logger.Errorf("Failed to get keys for validation in group %d: %v", event.GroupID, err) return } - s.logger.Infof("Validating %d keys for group %d with concurrency %d against endpoint %s", len(keysToValidate), event.GroupID, concurrency, endpoint) + + s.logger.Infof("Validating %d keys for group %d (concurrency: %d)", len(keysToValidate), event.GroupID, concurrency) + var wg sync.WaitGroup jobs := make(chan models.APIKey, len(keysToValidate)) + for i := 0; i < concurrency; i++ { wg.Add(1) go func() { @@ -560,9 +592,8 @@ func (s *APIKeyService) handlePostImportValidation(event *models.ImportGroupComp for key := range jobs { validationErr := s.validationService.ValidateSingleKey(&key, timeout, endpoint) if validationErr == nil { - s.logger.Infof("Key ID %d PASSED validation. Status -> ACTIVE.", key.ID) if _, err := s.UpdateMappingStatus(ctx, event.GroupID, key.ID, models.StatusActive); err != nil { - s.logger.Errorf("Failed to update status to ACTIVE for Key ID %d in group %d: %v", key.ID, event.GroupID, err) + s.logger.Errorf("Failed to activate key %d in group %d: %v", key.ID, event.GroupID, err) } } else { var apiErr *CustomErrors.APIError @@ -574,35 +605,34 @@ func (s *APIKeyService) handlePostImportValidation(event *models.ImportGroupComp } }() } + for _, key := range keysToValidate { jobs <- key } close(jobs) wg.Wait() - s.logger.Infof("Finished post-import validation for group %d.", event.GroupID) + + s.logger.Infof("Finished post-import validation for group %d", event.GroupID) } func (s *APIKeyService) StartUpdateStatusByFilterTask(ctx context.Context, groupID uint, sourceStatuses []string, newStatus models.APIKeyStatus) (*task.Status, error) { - s.logger.Infof("Starting task to update status to '%s' for keys in group %d matching statuses: %v", newStatus, groupID, sourceStatuses) - keyIDs, err := s.keyRepo.FindKeyIDsByStatus(groupID, sourceStatuses) if err != nil { return nil, CustomErrors.ParseDBError(err) } + if len(keyIDs) == 0 { now := time.Now() return &task.Status{ - IsRunning: false, - Processed: 0, - Total: 0, - Result: map[string]string{ - "message": "没有找到任何符合当前过滤条件的Key可供操作。", - }, - Error: "", + IsRunning: false, + Processed: 0, + Total: 0, + Result: map[string]string{"message": "没有找到符合条件的Key"}, StartedAt: now, FinishedAt: &now, }, nil } + resourceID := fmt.Sprintf("group-%d-status-update", groupID) taskStatus, err := s.taskService.StartTask(ctx, groupID, TaskTypeUpdateStatusByFilter, resourceID, len(keyIDs), 30*time.Minute) if err != nil { @@ -616,10 +646,11 @@ func (s *APIKeyService) StartUpdateStatusByFilterTask(ctx context.Context, group func (s *APIKeyService) runUpdateStatusByFilterTask(ctx context.Context, taskID, resourceID string, groupID uint, keyIDs []uint, newStatus models.APIKeyStatus) { defer func() { if r := recover(); r != nil { - s.logger.Errorf("Panic recovered in runUpdateStatusByFilterTask: %v", r) - s.taskService.EndTaskByID(ctx, taskID, resourceID, nil, fmt.Errorf("panic in update status task: %v", r)) + s.logger.Errorf("Panic in status update task: %v", r) + s.taskService.EndTaskByID(ctx, taskID, resourceID, nil, fmt.Errorf("panic: %v", r)) } }() + type BatchUpdateResult struct { UpdatedCount int `json:"updated_count"` SkippedCount int `json:"skipped_count"` @@ -629,33 +660,35 @@ func (s *APIKeyService) runUpdateStatusByFilterTask(ctx context.Context, taskID, keys, err := s.keyRepo.GetKeysByIDs(keyIDs) if err != nil { - s.logger.WithError(err).Error("Failed to fetch keys for bulk status update, aborting task.") s.taskService.EndTaskByID(ctx, taskID, resourceID, nil, err) return } - masterStatusMap := make(map[uint]models.MasterAPIKeyStatus) + + masterStatusMap := make(map[uint]models.MasterAPIKeyStatus, len(keys)) for _, key := range keys { masterStatusMap[key.ID] = key.MasterStatus } + var mappings []*models.GroupAPIKeyMapping if err := s.db.WithContext(ctx).Where("key_group_id = ? AND api_key_id IN ?", groupID, keyIDs).Find(&mappings).Error; err != nil { - s.logger.WithError(err).Error("Failed to fetch mappings for bulk status update, aborting task.") s.taskService.EndTaskByID(ctx, taskID, resourceID, nil, err) return } - processedCount := 0 - for _, mapping := range mappings { - processedCount++ - _ = s.taskService.UpdateProgressByID(ctx, taskID, processedCount) + + for i, mapping := range mappings { + _ = s.taskService.UpdateProgressByID(ctx, taskID, i+1) + masterStatus, ok := masterStatusMap[mapping.APIKeyID] if !ok { result.SkippedCount++ continue } + if masterStatus != models.MasterStatusActive && newStatus == models.StatusActive { result.SkippedCount++ continue } + oldStatus := mapping.Status if oldStatus != newStatus { mapping.Status = newStatus @@ -663,6 +696,7 @@ func (s *APIKeyService) runUpdateStatusByFilterTask(ctx context.Context, taskID, mapping.ConsecutiveErrorCount = 0 mapping.LastError = "" } + if err := s.keyRepo.UpdateMappingWithoutCache(mapping); err != nil { result.SkippedCount++ } else { @@ -674,122 +708,112 @@ func (s *APIKeyService) runUpdateStatusByFilterTask(ctx context.Context, taskID, result.UpdatedCount++ } } + result.SkippedCount += (len(keyIDs) - len(mappings)) + if err := s.keyRepo.HandleCacheUpdateEventBatch(ctx, successfulMappings); err != nil { - s.logger.WithError(err).Error("Failed to perform batch cache update after status update task.") + s.logger.WithError(err).Error("Failed batch cache update after status update") } - s.logger.Infof("Finished bulk status update task '%s' for group %d. Updated: %d, Skipped: %d.", taskID, groupID, result.UpdatedCount, result.SkippedCount) + s.taskService.EndTaskByID(ctx, taskID, resourceID, result, nil) } func (s *APIKeyService) HandleRequestResult(group *models.KeyGroup, key *models.APIKey, success bool, apiErr *CustomErrors.APIError) { ctx := context.Background() + if success { if group.PollingStrategy == models.StrategyWeighted { go s.keyRepo.UpdateKeyUsageTimestamp(ctx, group.ID, key.ID) } return } + if apiErr == nil { - s.logger.Warnf("Request failed for KeyID %d in GroupID %d but no specific API error was provided. No action taken.", key.ID, group.ID) return } + errMsg := apiErr.Message if CustomErrors.IsPermanentUpstreamError(errMsg) || CustomErrors.IsTemporaryUpstreamError(errMsg) { - s.logger.Warnf("Request for KeyID %d in GroupID %d failed with a key-specific error: %s. Issuing order to move to cooldown.", key.ID, group.ID, errMsg) go s.keyRepo.SyncKeyStatusInPollingCaches(ctx, group.ID, key.ID, models.StatusCooldown) - } else { - s.logger.Infof("Request for KeyID %d in GroupID %d failed with a non-key-specific error. No punitive action taken. Error: %s", key.ID, group.ID, errMsg) } } -func sanitizeForLog(errMsg string) string { - jsonStartIndex := strings.Index(errMsg, "{") - var cleanMsg string - if jsonStartIndex != -1 { - cleanMsg = strings.TrimSpace(errMsg[:jsonStartIndex]) + " {...}" - } else { - cleanMsg = errMsg - } - const maxLen = 250 - if len(cleanMsg) > maxLen { - return cleanMsg[:maxLen] + "..." - } - return cleanMsg -} +func (s *APIKeyService) judgeKeyErrors(ctx context.Context, correlationID string, groupID, keyID uint, apiErr *CustomErrors.APIError, isPreciseRouting bool) { + logger := s.logger.WithFields(logrus.Fields{ + "group_id": groupID, + "key_id": keyID, + "correlation_id": correlationID, + }) -func (s *APIKeyService) judgeKeyErrors( - ctx context.Context, - correlationID string, - groupID, keyID uint, - apiErr *CustomErrors.APIError, - isPreciseRouting bool, -) { - logger := s.logger.WithFields(logrus.Fields{"group_id": groupID, "key_id": keyID, "correlation_id": correlationID}) mapping, err := s.keyRepo.GetMapping(groupID, keyID) if err != nil { - logger.WithError(err).Warn("Cannot apply consequences for error: mapping not found.") + logger.WithError(err).Warn("Mapping not found, cannot apply error consequences") return } + now := time.Now() mapping.LastUsedAt = &now errorMessage := apiErr.Message + if CustomErrors.IsPermanentUpstreamError(errorMessage) { - logger.Errorf("Permanent error detected. Banning mapping and revoking master key. Reason: %s", sanitizeForLog(errorMessage)) - logger.WithField("full_error_details", errorMessage).Debug("Full details of the permanent error.") + logger.Errorf("Permanent error: %s", sanitizeForLog(errorMessage)) + if mapping.Status != models.StatusBanned { oldStatus := mapping.Status mapping.Status = models.StatusBanned mapping.LastError = errorMessage + if err := s.keyRepo.UpdateMapping(ctx, mapping); err != nil { - logger.WithError(err).Error("Failed to update mapping status to BANNED.") + logger.WithError(err).Error("Failed to ban mapping") } else { - go s.publishStatusChangeEvent(ctx, groupID, keyID, oldStatus, mapping.Status, "permanent_error_banned") + go s.publishStatusChangeEvent(ctx, groupID, keyID, oldStatus, models.StatusBanned, "permanent_error") go s.revokeMasterKey(ctx, keyID, "permanent_upstream_error") } } return } + if CustomErrors.IsTemporaryUpstreamError(errorMessage) { mapping.LastError = errorMessage mapping.ConsecutiveErrorCount++ - var threshold int + + threshold := s.SettingsManager.GetSettings().BlacklistThreshold if isPreciseRouting { - group, ok := s.groupManager.GetGroupByID(groupID) - opConfig, err := s.groupManager.BuildOperationalConfig(group) - if !ok || err != nil { - logger.Warnf("Could not build operational config for group %d in Precise Routing mode. Falling back to global settings.", groupID) - threshold = s.SettingsManager.GetSettings().BlacklistThreshold - } else { - threshold = *opConfig.KeyBlacklistThreshold + if group, ok := s.groupManager.GetGroupByID(groupID); ok { + if opConfig, err := s.groupManager.BuildOperationalConfig(group); err == nil && opConfig.KeyBlacklistThreshold != nil { + threshold = *opConfig.KeyBlacklistThreshold + } } - } else { - threshold = s.SettingsManager.GetSettings().BlacklistThreshold } - logger.Warnf("Temporary error detected. Incrementing error count. New count: %d (Threshold: %d). Reason: %s", mapping.ConsecutiveErrorCount, threshold, sanitizeForLog(errorMessage)) - logger.WithField("full_error_details", errorMessage).Debug("Full details of the temporary error.") + + logger.Warnf("Temporary error (count: %d, threshold: %d): %s", mapping.ConsecutiveErrorCount, threshold, sanitizeForLog(errorMessage)) + oldStatus := mapping.Status newStatus := oldStatus + if mapping.ConsecutiveErrorCount >= threshold && oldStatus == models.StatusActive { newStatus = models.StatusCooldown - logger.Errorf("Putting mapping into COOLDOWN due to reaching temporary error threshold (%d)", threshold) + logger.Errorf("Moving to COOLDOWN after reaching threshold %d", threshold) } + if oldStatus != newStatus { mapping.Status = newStatus } + if err := s.keyRepo.UpdateMapping(ctx, mapping); err != nil { - logger.WithError(err).Error("Failed to update mapping after temporary error.") + logger.WithError(err).Error("Failed to update mapping after temporary error") return } + if oldStatus != newStatus { go s.publishStatusChangeEvent(ctx, groupID, keyID, oldStatus, newStatus, "error_threshold_reached") } return } - logger.Infof("Ignored truly ignorable upstream error. Only updating LastUsedAt. Reason: %s", sanitizeForLog(errorMessage)) - logger.WithField("full_error_details", errorMessage).Debug("Full details of the ignorable error.") + + logger.Infof("Ignorable error: %s", sanitizeForLog(errorMessage)) if err := s.keyRepo.UpdateMapping(ctx, mapping); err != nil { - logger.WithError(err).Error("Failed to update LastUsedAt for ignorable error.") + logger.WithError(err).Error("Failed to update LastUsedAt") } } @@ -797,25 +821,27 @@ func (s *APIKeyService) revokeMasterKey(ctx context.Context, keyID uint, reason key, err := s.keyRepo.GetKeyByID(keyID) if err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { - s.logger.Warnf("Attempted to revoke non-existent key ID %d.", keyID) + s.logger.Warnf("Attempted to revoke non-existent key %d", keyID) } else { - s.logger.Errorf("Failed to get key by ID %d for master status revocation: %v", keyID, err) + s.logger.Errorf("Failed to get key %d for revocation: %v", keyID, err) } return } + if key.MasterStatus == models.MasterStatusRevoked { return } + oldMasterStatus := key.MasterStatus - newMasterStatus := models.MasterStatusRevoked - if err := s.keyRepo.UpdateAPIKeyStatus(ctx, keyID, newMasterStatus); err != nil { - s.logger.Errorf("Failed to update master status to REVOKED for key ID %d: %v", keyID, err) + if err := s.keyRepo.UpdateAPIKeyStatus(ctx, keyID, models.MasterStatusRevoked); err != nil { + s.logger.Errorf("Failed to revoke key %d: %v", keyID, err) return } + masterKeyEvent := models.MasterKeyStatusChangedEvent{ KeyID: keyID, OldMasterStatus: oldMasterStatus, - NewMasterStatus: newMasterStatus, + NewMasterStatus: models.MasterStatusRevoked, ChangeReason: reason, ChangedAt: time.Now(), } @@ -826,3 +852,13 @@ func (s *APIKeyService) revokeMasterKey(ctx context.Context, keyID uint, reason func (s *APIKeyService) GetAPIKeyStringsForExport(ctx context.Context, groupID uint, statuses []string) ([]string, error) { return s.keyRepo.GetKeyStringsByGroupAndStatus(groupID, statuses) } + +func sanitizeForLog(errMsg string) string { + if idx := strings.Index(errMsg, "{"); idx != -1 { + errMsg = strings.TrimSpace(errMsg[:idx]) + " {...}" + } + if len(errMsg) > 250 { + return errMsg[:250] + "..." + } + return errMsg +} diff --git a/internal/service/group_manager.go b/internal/service/group_manager.go index d5ca7b7..d8690dd 100644 --- a/internal/service/group_manager.go +++ b/internal/service/group_manager.go @@ -1,5 +1,4 @@ -// Filename: internal/service/group_manager.go (Syncer升级版) - +// Filename: internal/service/group_manager.go package service import ( @@ -29,10 +28,9 @@ type GroupManagerCacheData struct { Groups []*models.KeyGroup GroupsByName map[string]*models.KeyGroup GroupsByID map[uint]*models.KeyGroup - KeyCounts map[uint]int64 // GroupID -> Total Key Count - KeyStatusCounts map[uint]map[models.APIKeyStatus]int64 // GroupID -> Status -> Count + KeyCounts map[uint]int64 + KeyStatusCounts map[uint]map[models.APIKeyStatus]int64 } - type GroupManager struct { db *gorm.DB keyRepo repository.KeyRepository @@ -41,7 +39,6 @@ type GroupManager struct { syncer *syncer.CacheSyncer[GroupManagerCacheData] logger *logrus.Entry } - type UpdateOrderPayload struct { ID uint `json:"id" binding:"required"` Order int `json:"order"` @@ -49,43 +46,19 @@ type UpdateOrderPayload struct { func NewGroupManagerLoader(db *gorm.DB, logger *logrus.Logger) syncer.LoaderFunc[GroupManagerCacheData] { return func() (GroupManagerCacheData, error) { - logger.Debugf("[GML-LOG 1/5] ---> Entering NewGroupManagerLoader...") var groups []*models.KeyGroup - logger.Debugf("[GML-LOG 2/5] About to execute DB query with Preloads...") - if err := db.Preload("AllowedUpstreams"). Preload("AllowedModels"). Preload("Settings"). Preload("RequestConfig"). + Preload("Mappings"). Find(&groups).Error; err != nil { - logger.Errorf("[GML-LOG] CRITICAL: DB query for groups failed: %v", err) - return GroupManagerCacheData{}, fmt.Errorf("failed to load key groups for cache: %w", err) + return GroupManagerCacheData{}, fmt.Errorf("failed to load groups: %w", err) } - logger.Debugf("[GML-LOG 2.1/5] DB query for groups finished. Found %d group records.", len(groups)) - - var allMappings []*models.GroupAPIKeyMapping - if err := db.Find(&allMappings).Error; err != nil { - logger.Errorf("[GML-LOG] CRITICAL: DB query for mappings failed: %v", err) - return GroupManagerCacheData{}, fmt.Errorf("failed to load key mappings for cache: %w", err) - } - logger.Debugf("[GML-LOG 2.2/5] DB query for mappings finished. Found %d total mapping records.", len(allMappings)) - - mappingsByGroupID := make(map[uint][]*models.GroupAPIKeyMapping) - for i := range allMappings { - mapping := allMappings[i] // Avoid pointer issues with range - mappingsByGroupID[mapping.KeyGroupID] = append(mappingsByGroupID[mapping.KeyGroupID], mapping) - } - - for _, group := range groups { - if mappings, ok := mappingsByGroupID[group.ID]; ok { - group.Mappings = mappings - } - } - logger.Debugf("[GML-LOG 3/5] Finished manually associating mappings to groups.") - keyCounts := make(map[uint]int64, len(groups)) keyStatusCounts := make(map[uint]map[models.APIKeyStatus]int64, len(groups)) - + groupsByName := make(map[string]*models.KeyGroup, len(groups)) + groupsByID := make(map[uint]*models.KeyGroup, len(groups)) for _, group := range groups { keyCounts[group.ID] = int64(len(group.Mappings)) statusCounts := make(map[models.APIKeyStatus]int64) @@ -93,20 +66,9 @@ func NewGroupManagerLoader(db *gorm.DB, logger *logrus.Logger) syncer.LoaderFunc statusCounts[mapping.Status]++ } keyStatusCounts[group.ID] = statusCounts + groupsByName[group.Name] = group + groupsByID[group.ID] = group } - groupsByName := make(map[string]*models.KeyGroup, len(groups)) - groupsByID := make(map[uint]*models.KeyGroup, len(groups)) - - logger.Debugf("[GML-LOG 4/5] Starting to process group records into maps...") - for i, group := range groups { - if group == nil { - logger.Debugf("[GML] CRITICAL: Found a 'nil' group pointer at index %d! This is the most likely cause of the panic.", i) - } else { - groupsByName[group.Name] = group - groupsByID[group.ID] = group - } - } - logger.Debugf("[GML-LOG 5/5] Finished processing records. Building final cache data...") return GroupManagerCacheData{ Groups: groups, GroupsByName: groupsByName, @@ -116,7 +78,6 @@ func NewGroupManagerLoader(db *gorm.DB, logger *logrus.Logger) syncer.LoaderFunc }, nil } } - func NewGroupManager( db *gorm.DB, keyRepo repository.KeyRepository, @@ -134,138 +95,67 @@ func NewGroupManager( logger: logger.WithField("component", "GroupManager"), } } - func (gm *GroupManager) GetAllGroups() []*models.KeyGroup { - cache := gm.syncer.Get() - if len(cache.Groups) == 0 { - return []*models.KeyGroup{} - } - groupsToOrder := cache.Groups - sort.Slice(groupsToOrder, func(i, j int) bool { - if groupsToOrder[i].Order != groupsToOrder[j].Order { - return groupsToOrder[i].Order < groupsToOrder[j].Order + groups := gm.syncer.Get().Groups + sort.Slice(groups, func(i, j int) bool { + if groups[i].Order != groups[j].Order { + return groups[i].Order < groups[j].Order } - return groupsToOrder[i].ID < groupsToOrder[j].ID + return groups[i].ID < groups[j].ID }) - return groupsToOrder + return groups } - func (gm *GroupManager) GetKeyCount(groupID uint) int64 { - cache := gm.syncer.Get() - if len(cache.KeyCounts) == 0 { - return 0 - } - count := cache.KeyCounts[groupID] - return count + return gm.syncer.Get().KeyCounts[groupID] } - func (gm *GroupManager) GetKeyStatusCount(groupID uint) map[models.APIKeyStatus]int64 { - cache := gm.syncer.Get() - if len(cache.KeyStatusCounts) == 0 { - return make(map[models.APIKeyStatus]int64) - } - if counts, ok := cache.KeyStatusCounts[groupID]; ok { + if counts, ok := gm.syncer.Get().KeyStatusCounts[groupID]; ok { return counts } return make(map[models.APIKeyStatus]int64) } - func (gm *GroupManager) GetGroupByName(name string) (*models.KeyGroup, bool) { - cache := gm.syncer.Get() - if len(cache.GroupsByName) == 0 { - return nil, false - } - group, ok := cache.GroupsByName[name] + group, ok := gm.syncer.Get().GroupsByName[name] return group, ok } - func (gm *GroupManager) GetGroupByID(id uint) (*models.KeyGroup, bool) { - cache := gm.syncer.Get() - if len(cache.GroupsByID) == 0 { - return nil, false - } - group, ok := cache.GroupsByID[id] + group, ok := gm.syncer.Get().GroupsByID[id] return group, ok } - func (gm *GroupManager) Stop() { gm.syncer.Stop() } - func (gm *GroupManager) Invalidate() error { return gm.syncer.Invalidate() } - -// --- Write Operations --- - -// CreateKeyGroup creates a new key group, including its operational settings, and invalidates the cache. func (gm *GroupManager) CreateKeyGroup(group *models.KeyGroup, settings *models.KeyGroupSettings) error { if !utils.IsValidGroupName(group.Name) { return errors.New("invalid group name: must contain only lowercase letters, numbers, and hyphens") } err := gm.db.Transaction(func(tx *gorm.DB) error { - // 1. Create the group itself to get an ID if err := tx.Create(group).Error; err != nil { return err } - - // 2. If settings are provided, create the associated GroupSettings record if settings != nil { - // Only marshal non-nil fields to keep the JSON clean - settingsToMarshal := make(map[string]interface{}) - if settings.EnableKeyCheck != nil { - settingsToMarshal["enable_key_check"] = settings.EnableKeyCheck + settingsJSON, err := json.Marshal(settings) + if err != nil { + return fmt.Errorf("failed to marshal settings: %w", err) } - if settings.KeyCheckIntervalMinutes != nil { - settingsToMarshal["key_check_interval_minutes"] = settings.KeyCheckIntervalMinutes + groupSettings := models.GroupSettings{ + GroupID: group.ID, + SettingsJSON: datatypes.JSON(settingsJSON), } - if settings.KeyBlacklistThreshold != nil { - settingsToMarshal["key_blacklist_threshold"] = settings.KeyBlacklistThreshold - } - if settings.KeyCooldownMinutes != nil { - settingsToMarshal["key_cooldown_minutes"] = settings.KeyCooldownMinutes - } - if settings.KeyCheckConcurrency != nil { - settingsToMarshal["key_check_concurrency"] = settings.KeyCheckConcurrency - } - if settings.KeyCheckEndpoint != nil { - settingsToMarshal["key_check_endpoint"] = settings.KeyCheckEndpoint - } - if settings.KeyCheckModel != nil { - settingsToMarshal["key_check_model"] = settings.KeyCheckModel - } - if settings.MaxRetries != nil { - settingsToMarshal["max_retries"] = settings.MaxRetries - } - if settings.EnableSmartGateway != nil { - settingsToMarshal["enable_smart_gateway"] = settings.EnableSmartGateway - } - if len(settingsToMarshal) > 0 { - settingsJSON, err := json.Marshal(settingsToMarshal) - if err != nil { - return fmt.Errorf("failed to marshal group settings: %w", err) - } - groupSettings := models.GroupSettings{ - GroupID: group.ID, - SettingsJSON: datatypes.JSON(settingsJSON), - } - if err := tx.Create(&groupSettings).Error; err != nil { - return fmt.Errorf("failed to save group settings: %w", err) - } + if err := tx.Create(&groupSettings).Error; err != nil { + return err } } return nil }) - - if err != nil { - return err + if err == nil { + go gm.Invalidate() } - - go gm.Invalidate() - return nil + return err } - -// UpdateKeyGroup updates an existing key group, its settings, and associations, then invalidates the cache. func (gm *GroupManager) UpdateKeyGroup(group *models.KeyGroup, newSettings *models.KeyGroupSettings, upstreamURLs []string, modelNames []string) error { if !utils.IsValidGroupName(group.Name) { return fmt.Errorf("invalid group name: must contain only lowercase letters, numbers, and hyphens") @@ -273,7 +163,6 @@ func (gm *GroupManager) UpdateKeyGroup(group *models.KeyGroup, newSettings *mode uniqueUpstreamURLs := uniqueStrings(upstreamURLs) uniqueModelNames := uniqueStrings(modelNames) err := gm.db.Transaction(func(tx *gorm.DB) error { - // --- 1. Update AllowedUpstreams (M:N relationship) --- var upstreams []*models.UpstreamEndpoint if len(uniqueUpstreamURLs) > 0 { if err := tx.Where("url IN ?", uniqueUpstreamURLs).Find(&upstreams).Error; err != nil { @@ -283,7 +172,6 @@ func (gm *GroupManager) UpdateKeyGroup(group *models.KeyGroup, newSettings *mode if err := tx.Model(group).Association("AllowedUpstreams").Replace(upstreams); err != nil { return err } - if err := tx.Model(group).Association("AllowedModels").Clear(); err != nil { return err } @@ -296,11 +184,9 @@ func (gm *GroupManager) UpdateKeyGroup(group *models.KeyGroup, newSettings *mode return err } } - if err := tx.Model(group).Updates(group).Error; err != nil { return err } - var existingSettings models.GroupSettings if err := tx.Where("group_id = ?", group.ID).First(&existingSettings).Error; err != nil && err != gorm.ErrRecordNotFound { return err @@ -308,15 +194,15 @@ func (gm *GroupManager) UpdateKeyGroup(group *models.KeyGroup, newSettings *mode var currentSettingsData models.KeyGroupSettings if len(existingSettings.SettingsJSON) > 0 { if err := json.Unmarshal(existingSettings.SettingsJSON, ¤tSettingsData); err != nil { - return fmt.Errorf("failed to unmarshal existing group settings: %w", err) + return fmt.Errorf("failed to unmarshal existing settings: %w", err) } } if err := reflectutil.MergeNilFields(¤tSettingsData, newSettings); err != nil { - return fmt.Errorf("failed to merge group settings: %w", err) + return fmt.Errorf("failed to merge settings: %w", err) } updatedJSON, err := json.Marshal(currentSettingsData) if err != nil { - return fmt.Errorf("failed to marshal updated group settings: %w", err) + return fmt.Errorf("failed to marshal updated settings: %w", err) } existingSettings.GroupID = group.ID existingSettings.SettingsJSON = datatypes.JSON(updatedJSON) @@ -327,55 +213,25 @@ func (gm *GroupManager) UpdateKeyGroup(group *models.KeyGroup, newSettings *mode } return err } - -// DeleteKeyGroup deletes a key group and subsequently cleans up any keys that have become orphans. func (gm *GroupManager) DeleteKeyGroup(id uint) error { err := gm.db.Transaction(func(tx *gorm.DB) error { - gm.logger.Infof("Starting transaction to delete KeyGroup ID: %d", id) - // Step 1: First, retrieve the group object we are about to delete. var group models.KeyGroup if err := tx.First(&group, id).Error; err != nil { if err == gorm.ErrRecordNotFound { - gm.logger.Warnf("Attempted to delete a non-existent KeyGroup with ID: %d", id) - return nil // Don't treat as an error, the group is already gone. + return nil } - gm.logger.WithError(err).Errorf("Failed to find KeyGroup with ID: %d for deletion", id) return err } - // Step 2: Clear all many-to-many and one-to-many associations using GORM's safe methods. - if err := tx.Model(&group).Association("AllowedUpstreams").Clear(); err != nil { - gm.logger.WithError(err).Errorf("Failed to clear 'AllowedUpstreams' association for KeyGroup ID: %d", id) + if err := tx.Select("AllowedUpstreams", "AllowedModels", "Mappings", "Settings").Delete(&group).Error; err != nil { return err } - if err := tx.Model(&group).Association("AllowedModels").Clear(); err != nil { - gm.logger.WithError(err).Errorf("Failed to clear 'AllowedModels' association for KeyGroup ID: %d", id) - return err - } - if err := tx.Model(&group).Association("Mappings").Clear(); err != nil { - gm.logger.WithError(err).Errorf("Failed to clear 'Mappings' (API Key associations) for KeyGroup ID: %d", id) - return err - } - // Also clear settings if they exist to maintain data integrity - if err := tx.Model(&group).Association("Settings").Delete(group.Settings); err != nil { - gm.logger.WithError(err).Errorf("Failed to delete 'Settings' association for KeyGroup ID: %d", id) - return err - } - // Step 3: Delete the KeyGroup itself. - if err := tx.Delete(&group).Error; err != nil { - gm.logger.WithError(err).Errorf("Failed to delete KeyGroup ID: %d", id) - return err - } - gm.logger.Infof("KeyGroup ID %d associations cleared and entity deleted. Triggering orphan key cleanup.", id) - // Step 4: Trigger the orphan key cleanup (this logic remains the same and is correct). deletedCount, err := gm.keyRepo.DeleteOrphanKeysTx(tx) if err != nil { - gm.logger.WithError(err).Error("Failed to clean up orphan keys after deleting group.") return err } if deletedCount > 0 { - gm.logger.Infof("Successfully cleaned up %d orphan keys.", deletedCount) + gm.logger.Infof("Cleaned up %d orphan keys after deleting group %d", deletedCount, id) } - gm.logger.Infof("Transaction for deleting KeyGroup ID: %d completed successfully.", id) return nil }) if err == nil { @@ -383,7 +239,6 @@ func (gm *GroupManager) DeleteKeyGroup(id uint) error { } return err } - func (gm *GroupManager) CloneKeyGroup(id uint) (*models.KeyGroup, error) { var originalGroup models.KeyGroup if err := gm.db. @@ -392,7 +247,7 @@ func (gm *GroupManager) CloneKeyGroup(id uint) (*models.KeyGroup, error) { Preload("AllowedUpstreams"). Preload("AllowedModels"). First(&originalGroup, id).Error; err != nil { - return nil, fmt.Errorf("failed to find original group with id %d: %w", id, err) + return nil, fmt.Errorf("failed to find original group %d: %w", id, err) } newGroup := originalGroup timestamp := time.Now().Unix() @@ -401,31 +256,25 @@ func (gm *GroupManager) CloneKeyGroup(id uint) (*models.KeyGroup, error) { newGroup.DisplayName = fmt.Sprintf("%s-clone-%d", originalGroup.DisplayName, timestamp) newGroup.CreatedAt = time.Time{} newGroup.UpdatedAt = time.Time{} - newGroup.RequestConfigID = nil newGroup.RequestConfig = nil newGroup.Mappings = nil newGroup.AllowedUpstreams = nil newGroup.AllowedModels = nil err := gm.db.Transaction(func(tx *gorm.DB) error { - if err := tx.Create(&newGroup).Error; err != nil { return err } - if originalGroup.RequestConfig != nil { newRequestConfig := *originalGroup.RequestConfig - newRequestConfig.ID = 0 // Mark as new record - + newRequestConfig.ID = 0 if err := tx.Create(&newRequestConfig).Error; err != nil { return fmt.Errorf("failed to clone request config: %w", err) } - if err := tx.Model(&newGroup).Update("request_config_id", newRequestConfig.ID).Error; err != nil { - return fmt.Errorf("failed to link new group to cloned request config: %w", err) + return fmt.Errorf("failed to link cloned request config: %w", err) } } - var originalSettings models.GroupSettings err := tx.Where("group_id = ?", originalGroup.ID).First(&originalSettings).Error if err == nil && len(originalSettings.SettingsJSON) > 0 { @@ -434,12 +283,11 @@ func (gm *GroupManager) CloneKeyGroup(id uint) (*models.KeyGroup, error) { SettingsJSON: originalSettings.SettingsJSON, } if err := tx.Create(&newSettings).Error; err != nil { - return fmt.Errorf("failed to clone group settings: %w", err) + return fmt.Errorf("failed to clone settings: %w", err) } } else if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) { - return fmt.Errorf("failed to query original group settings: %w", err) + return fmt.Errorf("failed to query original settings: %w", err) } - if len(originalGroup.Mappings) > 0 { newMappings := make([]models.GroupAPIKeyMapping, len(originalGroup.Mappings)) for i, oldMapping := range originalGroup.Mappings { @@ -454,7 +302,7 @@ func (gm *GroupManager) CloneKeyGroup(id uint) (*models.KeyGroup, error) { } } if err := tx.Create(&newMappings).Error; err != nil { - return fmt.Errorf("failed to clone key group mappings: %w", err) + return fmt.Errorf("failed to clone mappings: %w", err) } } if len(originalGroup.AllowedUpstreams) > 0 { @@ -469,13 +317,10 @@ func (gm *GroupManager) CloneKeyGroup(id uint) (*models.KeyGroup, error) { } return nil }) - if err != nil { return nil, err } - go gm.Invalidate() - var finalClonedGroup models.KeyGroup if err := gm.db. Preload("RequestConfig"). @@ -487,10 +332,9 @@ func (gm *GroupManager) CloneKeyGroup(id uint) (*models.KeyGroup, error) { } return &finalClonedGroup, nil } - func (gm *GroupManager) BuildOperationalConfig(group *models.KeyGroup) (*models.KeyGroupSettings, error) { globalSettings := gm.settingsManager.GetSettings() - s := "gemini-1.5-flash" // Per user feedback for default model + defaultModel := "gemini-1.5-flash" opConfig := &models.KeyGroupSettings{ EnableKeyCheck: &globalSettings.EnableBaseKeyCheck, KeyCheckConcurrency: &globalSettings.BaseKeyCheckConcurrency, @@ -498,52 +342,43 @@ func (gm *GroupManager) BuildOperationalConfig(group *models.KeyGroup) (*models. KeyCheckEndpoint: &globalSettings.DefaultUpstreamURL, KeyBlacklistThreshold: &globalSettings.BlacklistThreshold, KeyCooldownMinutes: &globalSettings.KeyCooldownMinutes, - KeyCheckModel: &s, + KeyCheckModel: &defaultModel, MaxRetries: &globalSettings.MaxRetries, EnableSmartGateway: &globalSettings.EnableSmartGateway, } - if group == nil { return opConfig, nil } - var groupSettingsRecord models.GroupSettings err := gm.db.Where("group_id = ?", group.ID).First(&groupSettingsRecord).Error if err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { return opConfig, nil } - - gm.logger.WithError(err).Errorf("Failed to query group settings for group ID %d", group.ID) return nil, err } - if len(groupSettingsRecord.SettingsJSON) == 0 { return opConfig, nil } - var groupSpecificSettings models.KeyGroupSettings if err := json.Unmarshal(groupSettingsRecord.SettingsJSON, &groupSpecificSettings); err != nil { - gm.logger.WithError(err).WithField("group_id", group.ID).Warn("Failed to unmarshal group settings JSON.") - return opConfig, err - } - - if err := reflectutil.MergeNilFields(opConfig, &groupSpecificSettings); err != nil { - gm.logger.WithError(err).WithField("group_id", group.ID).Error("Failed to merge group-specific settings over defaults.") + gm.logger.WithError(err).WithField("group_id", group.ID).Warn("Failed to unmarshal group settings") + return opConfig, nil + } + if err := reflectutil.MergeNilFields(opConfig, &groupSpecificSettings); err != nil { + gm.logger.WithError(err).WithField("group_id", group.ID).Error("Failed to merge group settings") return opConfig, nil } - return opConfig, nil } - func (gm *GroupManager) BuildKeyCheckEndpoint(groupID uint) (string, error) { group, ok := gm.GetGroupByID(groupID) if !ok { - return "", fmt.Errorf("group with id %d not found", groupID) + return "", fmt.Errorf("group %d not found", groupID) } opConfig, err := gm.BuildOperationalConfig(group) if err != nil { - return "", fmt.Errorf("failed to build operational config for group %d: %w", groupID, err) + return "", err } globalSettings := gm.settingsManager.GetSettings() baseURL := globalSettings.DefaultUpstreamURL @@ -551,7 +386,7 @@ func (gm *GroupManager) BuildKeyCheckEndpoint(groupID uint) (string, error) { baseURL = *opConfig.KeyCheckEndpoint } if baseURL == "" { - return "", fmt.Errorf("no key check endpoint or default upstream URL is configured for group %d", groupID) + return "", fmt.Errorf("no endpoint configured for group %d", groupID) } modelName := globalSettings.BaseKeyCheckModel if opConfig.KeyCheckModel != nil && *opConfig.KeyCheckModel != "" { @@ -559,38 +394,31 @@ func (gm *GroupManager) BuildKeyCheckEndpoint(groupID uint) (string, error) { } parsedURL, err := url.Parse(baseURL) if err != nil { - return "", fmt.Errorf("failed to parse base URL '%s': %w", baseURL, err) + return "", fmt.Errorf("invalid URL '%s': %w", baseURL, err) } - cleanedPath := parsedURL.Path - cleanedPath = strings.TrimSuffix(cleanedPath, "/") - cleanedPath = strings.TrimSuffix(cleanedPath, "/v1beta") - parsedURL.Path = path.Join(cleanedPath, "v1beta", "models", modelName) - finalEndpoint := parsedURL.String() - return finalEndpoint, nil + cleanedPath := strings.TrimSuffix(strings.TrimSuffix(parsedURL.Path, "/"), "/v1beta") + parsedURL.Path = path.Join(cleanedPath, "v1beta/models", modelName) + return parsedURL.String(), nil } - func (gm *GroupManager) UpdateOrder(payload []UpdateOrderPayload) error { ordersMap := make(map[uint]int, len(payload)) for _, item := range payload { ordersMap[item.ID] = item.Order } if err := gm.groupRepo.UpdateOrderInTransaction(ordersMap); err != nil { - gm.logger.WithError(err).Error("Failed to update group order in transaction") - return fmt.Errorf("database transaction failed: %w", err) + return fmt.Errorf("failed to update order: %w", err) } - gm.logger.Info("Group order updated successfully, invalidating cache...") go gm.Invalidate() return nil } - func uniqueStrings(slice []string) []string { - keys := make(map[string]struct{}) - list := []string{} - for _, entry := range slice { - if _, value := keys[entry]; !value { - keys[entry] = struct{}{} - list = append(list, entry) + seen := make(map[string]struct{}, len(slice)) + result := make([]string, 0, len(slice)) + for _, s := range slice { + if _, exists := seen[s]; !exists { + seen[s] = struct{}{} + result = append(result, s) } } - return list + return result } diff --git a/internal/store/memory_store.go b/internal/store/memory_store.go index 3d9b268..36adc99 100644 --- a/internal/store/memory_store.go +++ b/internal/store/memory_store.go @@ -704,15 +704,17 @@ func (p *memoryPipeliner) LRem(key string, count int64, value any) { newList := make([]string, 0, len(list)) removed := int64(0) for _, v := range list { - if count != 0 && v == capturedValue && (count < 0 || removed < count) { + shouldRemove := v == capturedValue && (count == 0 || removed < count) + if shouldRemove { removed++ - continue + } else { + newList = append(newList, v) } - newList = append(newList, v) } item.value = newList }) } + func (p *memoryPipeliner) HSet(key string, values map[string]any) { capturedKey := key capturedValues := make(map[string]any, len(values)) @@ -762,17 +764,31 @@ func (p *memoryPipeliner) ZAdd(key string, members map[string]float64) { p.ops = append(p.ops, func() { item, ok := p.store.items[capturedKey] if !ok || item.isExpired() { - item = &memoryStoreItem{value: make(map[string]float64)} + item = &memoryStoreItem{value: make([]zsetMember, 0)} p.store.items[capturedKey] = item } - zset, ok := item.value.(map[string]float64) + zset, ok := item.value.([]zsetMember) if !ok { - zset = make(map[string]float64) - item.value = zset + zset = make([]zsetMember, 0) } - for member, score := range capturedMembers { - zset[member] = score + membersMap := make(map[string]float64, len(zset)) + for _, z := range zset { + membersMap[z.Value] = z.Score } + for memberVal, score := range capturedMembers { + membersMap[memberVal] = score + } + newZSet := make([]zsetMember, 0, len(membersMap)) + for val, score := range membersMap { + newZSet = append(newZSet, zsetMember{Value: val, Score: score}) + } + sort.Slice(newZSet, func(i, j int) bool { + if newZSet[i].Score == newZSet[j].Score { + return newZSet[i].Value < newZSet[j].Value + } + return newZSet[i].Score < newZSet[j].Score + }) + item.value = newZSet }) } func (p *memoryPipeliner) ZRem(key string, members ...any) { @@ -784,13 +800,21 @@ func (p *memoryPipeliner) ZRem(key string, members ...any) { if !ok || item.isExpired() { return } - zset, ok := item.value.(map[string]float64) + zset, ok := item.value.([]zsetMember) if !ok { return } - for _, member := range capturedMembers { - delete(zset, fmt.Sprintf("%v", member)) + membersToRemove := make(map[string]struct{}, len(capturedMembers)) + for _, m := range capturedMembers { + membersToRemove[fmt.Sprintf("%v", m)] = struct{}{} } + newZSet := make([]zsetMember, 0, len(zset)) + for _, z := range zset { + if _, exists := membersToRemove[z.Value]; !exists { + newZSet = append(newZSet, z) + } + } + item.value = newZSet }) }