// Filename: internal/service/key_import_service.go package service import ( "context" "encoding/json" "fmt" "gemini-balancer/internal/models" "gemini-balancer/internal/repository" "gemini-balancer/internal/store" "gemini-balancer/internal/task" "gemini-balancer/internal/utils" "strings" "time" "github.com/gin-gonic/gin" "github.com/sirupsen/logrus" ) const ( TaskTypeAddKeysToGroup = "add_keys_to_group" TaskTypeUnlinkKeysFromGroup = "unlink_keys_from_group" TaskTypeHardDeleteKeys = "hard_delete_keys" TaskTypeRestoreKeys = "restore_keys" chunkSize = 500 // 任务超时时间常量化 defaultTaskTimeout = 15 * time.Minute longTaskTimeout = time.Hour ) type KeyImportService struct { taskService task.Reporter keyRepo repository.KeyRepository store store.Store logger *logrus.Entry apiKeyService *APIKeyService } func NewKeyImportService(ts task.Reporter, kr repository.KeyRepository, s store.Store, as *APIKeyService, logger *logrus.Logger) *KeyImportService { return &KeyImportService{ taskService: ts, keyRepo: kr, store: s, logger: logger.WithField("component", "KeyImportService🚀"), apiKeyService: as, } } // runTaskWithRecovery 统一的任务恢复包装器 func (s *KeyImportService) runTaskWithRecovery(ctx context.Context, taskID string, resourceID string, taskFunc func()) { defer func() { if r := recover(); r != nil { err := fmt.Errorf("panic recovered in task %s: %v", taskID, r) s.logger.WithField("task_id", taskID).Error(err) s.taskService.EndTaskByID(ctx, taskID, resourceID, nil, err) } }() taskFunc() } // StartAddKeysTask 启动批量添加密钥任务 func (s *KeyImportService) StartAddKeysTask(ctx context.Context, groupID uint, keysText string, validateOnImport bool) (*task.Status, error) { keys := utils.ParseKeysFromText(keysText) if len(keys) == 0 { return nil, fmt.Errorf("no valid keys found in input text") } resourceID := fmt.Sprintf("group-%d", groupID) taskStatus, err := s.taskService.StartTask(ctx, uint(groupID), TaskTypeAddKeysToGroup, resourceID, len(keys), defaultTaskTimeout) if err != nil { return nil, err } go s.runTaskWithRecovery(context.Background(), taskStatus.ID, resourceID, func() { s.runAddKeysTask(context.Background(), taskStatus.ID, resourceID, groupID, keys, validateOnImport) }) return taskStatus, nil } // StartUnlinkKeysTask 启动批量解绑密钥任务 func (s *KeyImportService) StartUnlinkKeysTask(ctx context.Context, groupID uint, keysText string) (*task.Status, error) { keys := utils.ParseKeysFromText(keysText) if len(keys) == 0 { return nil, fmt.Errorf("no valid keys found") } resourceID := fmt.Sprintf("group-%d", groupID) taskStatus, err := s.taskService.StartTask(ctx, uint(groupID), TaskTypeUnlinkKeysFromGroup, resourceID, len(keys), longTaskTimeout) if err != nil { return nil, err } go s.runTaskWithRecovery(context.Background(), taskStatus.ID, resourceID, func() { s.runUnlinkKeysTask(context.Background(), taskStatus.ID, resourceID, groupID, keys) }) return taskStatus, nil } // StartHardDeleteKeysTask 启动硬删除密钥任务 func (s *KeyImportService) StartHardDeleteKeysTask(ctx context.Context, keysText string) (*task.Status, error) { keys := utils.ParseKeysFromText(keysText) if len(keys) == 0 { return nil, fmt.Errorf("no valid keys found") } resourceID := "global_hard_delete" taskStatus, err := s.taskService.StartTask(ctx, 0, TaskTypeHardDeleteKeys, resourceID, len(keys), longTaskTimeout) if err != nil { return nil, err } go s.runTaskWithRecovery(context.Background(), taskStatus.ID, resourceID, func() { s.runHardDeleteKeysTask(context.Background(), taskStatus.ID, resourceID, keys) }) return taskStatus, nil } // StartRestoreKeysTask 启动恢复密钥任务 func (s *KeyImportService) StartRestoreKeysTask(ctx context.Context, keysText string) (*task.Status, error) { keys := utils.ParseKeysFromText(keysText) if len(keys) == 0 { return nil, fmt.Errorf("no valid keys found") } resourceID := "global_restore_keys" taskStatus, err := s.taskService.StartTask(ctx, 0, TaskTypeRestoreKeys, resourceID, len(keys), longTaskTimeout) if err != nil { return nil, err } go s.runTaskWithRecovery(context.Background(), taskStatus.ID, resourceID, func() { s.runRestoreKeysTask(context.Background(), taskStatus.ID, resourceID, keys) }) return taskStatus, nil } // StartUnlinkKeysByFilterTask 根据状态过滤条件批量解绑 func (s *KeyImportService) StartUnlinkKeysByFilterTask(ctx context.Context, groupID uint, statuses []string) (*task.Status, error) { s.logger.Infof("Starting unlink task for group %d with status filter: %v", groupID, statuses) keyValues, err := s.keyRepo.FindKeyValuesByStatus(groupID, statuses) if err != nil { return nil, fmt.Errorf("failed to find keys by filter: %w", err) } if len(keyValues) == 0 { return nil, fmt.Errorf("no keys found matching the provided filter") } keysAsText := strings.Join(keyValues, "\n") s.logger.Infof("Found %d keys to unlink for group %d.", len(keyValues), groupID) return s.StartUnlinkKeysTask(ctx, groupID, keysAsText) } // ==================== 核心任务执行逻辑 ==================== // runAddKeysTask 执行批量添加密钥 func (s *KeyImportService) runAddKeysTask(ctx context.Context, taskID string, resourceID string, groupID uint, keys []string, validateOnImport bool) { // 1. 去重 uniqueKeys := s.deduplicateKeys(keys) if len(uniqueKeys) == 0 { s.endTaskWithResult(ctx, taskID, resourceID, gin.H{ "newly_linked_count": 0, "already_linked_count": 0, }, nil) return } // 2. 确保所有密钥在数据库中存在(幂等操作) allKeyModels, err := s.ensureKeysExist(uniqueKeys) if err != nil { s.endTaskWithResult(ctx, taskID, resourceID, nil, fmt.Errorf("failed to ensure keys exist: %w", err)) return } // 3. 过滤已关联的密钥 keysToLink, alreadyLinkedCount, err := s.filterNewKeys(allKeyModels, groupID, uniqueKeys) if err != nil { s.endTaskWithResult(ctx, taskID, resourceID, nil, fmt.Errorf("failed to check linked keys: %w", err)) return } // 4. 更新任务的实际处理总数 if err := s.taskService.UpdateTotalByID(ctx, taskID, len(keysToLink)); err != nil { s.logger.WithError(err).Warnf("Failed to update total for task %s", taskID) } // 5. 批量关联密钥到组 if len(keysToLink) > 0 { if err := s.linkKeysInChunks(ctx, taskID, groupID, keysToLink); err != nil { s.endTaskWithResult(ctx, taskID, resourceID, nil, err) return } } // 6. 根据验证标志处理密钥状态 if len(keysToLink) > 0 { s.processNewlyLinkedKeys(ctx, groupID, keysToLink, validateOnImport) } // 7. 返回结果 result := gin.H{ "newly_linked_count": len(keysToLink), "already_linked_count": alreadyLinkedCount, "total_linked_count": len(allKeyModels), } s.endTaskWithResult(ctx, taskID, resourceID, result, nil) } // runUnlinkKeysTask 执行批量解绑密钥 func (s *KeyImportService) runUnlinkKeysTask(ctx context.Context, taskID, resourceID string, groupID uint, keys []string) { // 1. 去重 uniqueKeys := s.deduplicateKeys(keys) // 2. 查找需要解绑的密钥 keysToUnlink, err := s.keyRepo.GetKeysByValuesAndGroupID(uniqueKeys, groupID) if err != nil { s.endTaskWithResult(ctx, taskID, resourceID, nil, fmt.Errorf("failed to find keys to unlink: %w", err)) return } if len(keysToUnlink) == 0 { result := gin.H{ "unlinked_count": 0, "hard_deleted_count": 0, "not_found_count": len(uniqueKeys), } s.endTaskWithResult(ctx, taskID, resourceID, result, nil) return } // 3. 提取密钥 ID idsToUnlink := s.extractKeyIDs(keysToUnlink) // 4. 更新任务总数 if err := s.taskService.UpdateTotalByID(ctx, taskID, len(idsToUnlink)); err != nil { s.logger.WithError(err).Warnf("Failed to update total for task %s", taskID) } // 5. 批量解绑 totalUnlinked, err := s.unlinkKeysInChunks(ctx, taskID, groupID, idsToUnlink) if err != nil { s.endTaskWithResult(ctx, taskID, resourceID, nil, err) return } // 6. 清理孤立密钥 totalDeleted, err := s.keyRepo.DeleteOrphanKeys() if err != nil { s.logger.WithError(err).Warn("Orphan key cleanup failed after unlink task.") } // 7. 返回结果 result := gin.H{ "unlinked_count": totalUnlinked, "hard_deleted_count": totalDeleted, "not_found_count": len(uniqueKeys) - int(totalUnlinked), } s.endTaskWithResult(ctx, taskID, resourceID, result, nil) } // runHardDeleteKeysTask 执行硬删除密钥 func (s *KeyImportService) runHardDeleteKeysTask(ctx context.Context, taskID, resourceID string, keys []string) { totalDeleted, err := s.processKeysInChunks(ctx, taskID, keys, func(chunk []string) (int64, error) { return s.keyRepo.HardDeleteByValues(chunk) }) if err != nil { s.endTaskWithResult(ctx, taskID, resourceID, nil, err) return } result := gin.H{ "hard_deleted_count": totalDeleted, "not_found_count": int64(len(keys)) - totalDeleted, } s.endTaskWithResult(ctx, taskID, resourceID, result, nil) s.publishChangeEvent(ctx, 0, "keys_hard_deleted") } // runRestoreKeysTask 执行恢复密钥 func (s *KeyImportService) runRestoreKeysTask(ctx context.Context, taskID, resourceID string, keys []string) { restoredCount, err := s.processKeysInChunks(ctx, taskID, keys, func(chunk []string) (int64, error) { return s.keyRepo.UpdateMasterStatusByValues(chunk, models.MasterStatusActive) }) if err != nil { s.endTaskWithResult(ctx, taskID, resourceID, nil, err) return } result := gin.H{ "restored_count": restoredCount, "not_found_count": int64(len(keys)) - restoredCount, } s.endTaskWithResult(ctx, taskID, resourceID, result, nil) s.publishChangeEvent(ctx, 0, "keys_bulk_restored") } // ==================== 辅助方法 ==================== // deduplicateKeys 去重密钥列表 func (s *KeyImportService) deduplicateKeys(keys []string) []string { uniqueKeysMap := make(map[string]struct{}, len(keys)) uniqueKeys := make([]string, 0, len(keys)) for _, kStr := range keys { if _, exists := uniqueKeysMap[kStr]; !exists { uniqueKeysMap[kStr] = struct{}{} uniqueKeys = append(uniqueKeys, kStr) } } return uniqueKeys } // ensureKeysExist 确保所有密钥在数据库中存在 func (s *KeyImportService) ensureKeysExist(keys []string) ([]models.APIKey, error) { keysToEnsure := make([]models.APIKey, len(keys)) for i, keyStr := range keys { keysToEnsure[i] = models.APIKey{APIKey: keyStr} } return s.keyRepo.AddKeys(keysToEnsure) } // filterNewKeys 过滤已关联的密钥,返回需要新增的密钥 func (s *KeyImportService) filterNewKeys(allKeyModels []models.APIKey, groupID uint, uniqueKeys []string) ([]models.APIKey, int, error) { alreadyLinkedModels, err := s.keyRepo.GetKeysByValuesAndGroupID(uniqueKeys, groupID) if err != nil { return nil, 0, err } alreadyLinkedIDSet := make(map[uint]struct{}, len(alreadyLinkedModels)) for _, key := range alreadyLinkedModels { alreadyLinkedIDSet[key.ID] = struct{}{} } keysToLink := make([]models.APIKey, 0, len(allKeyModels)-len(alreadyLinkedIDSet)) for _, key := range allKeyModels { if _, exists := alreadyLinkedIDSet[key.ID]; !exists { keysToLink = append(keysToLink, key) } } return keysToLink, len(alreadyLinkedIDSet), nil } // extractKeyIDs 提取密钥 ID 列表 func (s *KeyImportService) extractKeyIDs(keys []models.APIKey) []uint { ids := make([]uint, len(keys)) for i, key := range keys { ids[i] = key.ID } return ids } // linkKeysInChunks 分块关联密钥到组 func (s *KeyImportService) linkKeysInChunks(ctx context.Context, taskID string, groupID uint, keysToLink []models.APIKey) error { idsToLink := s.extractKeyIDs(keysToLink) for i := 0; i < len(idsToLink); i += chunkSize { end := min(i+chunkSize, len(idsToLink)) chunk := idsToLink[i:end] if err := s.keyRepo.LinkKeysToGroup(ctx, groupID, chunk); err != nil { return fmt.Errorf("chunk failed to link keys: %w", err) } _ = s.taskService.UpdateProgressByID(ctx, taskID, end) } return nil } // unlinkKeysInChunks 分块解绑密钥 func (s *KeyImportService) unlinkKeysInChunks(ctx context.Context, taskID string, groupID uint, idsToUnlink []uint) (int64, error) { var totalUnlinked int64 for i := 0; i < len(idsToUnlink); i += chunkSize { end := min(i+chunkSize, len(idsToUnlink)) chunk := idsToUnlink[i:end] unlinked, err := s.keyRepo.UnlinkKeysFromGroup(ctx, groupID, chunk) if err != nil { return 0, fmt.Errorf("chunk failed: could not unlink keys: %w", err) } totalUnlinked += unlinked // 发布解绑事件 for _, keyID := range chunk { s.publishSingleKeyChangeEvent(ctx, groupID, keyID, models.StatusActive, "", "key_unlinked") } _ = s.taskService.UpdateProgressByID(ctx, taskID, end) } return totalUnlinked, nil } // processKeysInChunks 通用的分块处理密钥逻辑 func (s *KeyImportService) processKeysInChunks( ctx context.Context, taskID string, keys []string, processFunc func(chunk []string) (int64, error), ) (int64, error) { var totalProcessed int64 for i := 0; i < len(keys); i += chunkSize { end := min(i+chunkSize, len(keys)) chunk := keys[i:end] count, err := processFunc(chunk) if err != nil { return 0, fmt.Errorf("failed to process chunk: %w", err) } totalProcessed += count _ = s.taskService.UpdateProgressByID(ctx, taskID, end) } return totalProcessed, nil } // processNewlyLinkedKeys 处理新关联的密钥(验证或直接激活) func (s *KeyImportService) processNewlyLinkedKeys(ctx context.Context, groupID uint, keysToLink []models.APIKey, validateOnImport bool) { idsToLink := s.extractKeyIDs(keysToLink) if validateOnImport { // 发布批量导入完成事件,触发验证 s.publishImportGroupCompletedEvent(ctx, groupID, idsToLink) // 发布单个密钥状态变更事件 for _, keyID := range idsToLink { s.publishSingleKeyChangeEvent(ctx, groupID, keyID, "", models.StatusPendingValidation, "key_linked") } } else { // 直接激活密钥,不进行验证 for _, keyID := range idsToLink { if _, err := s.apiKeyService.UpdateMappingStatus(ctx, groupID, keyID, models.StatusActive); err != nil { s.logger.WithFields(logrus.Fields{ "group_id": groupID, "key_id": keyID, }).Errorf("Failed to directly activate key: %v", err) } } } } // endTaskWithResult 统一的任务结束处理 func (s *KeyImportService) endTaskWithResult(ctx context.Context, taskID, resourceID string, result gin.H, err error) { if err != nil { s.logger.WithFields(logrus.Fields{ "task_id": taskID, "resource_id": resourceID, }).WithError(err).Error("Task failed") } s.taskService.EndTaskByID(ctx, taskID, resourceID, result, err) } // ==================== 事件发布方法 ==================== // publishSingleKeyChangeEvent 发布单个密钥状态变更事件 func (s *KeyImportService) publishSingleKeyChangeEvent(ctx context.Context, groupID, keyID uint, oldStatus, newStatus models.APIKeyStatus, reason string) { event := models.KeyStatusChangedEvent{ GroupID: groupID, KeyID: keyID, OldStatus: oldStatus, NewStatus: newStatus, ChangeReason: reason, ChangedAt: time.Now(), } eventData, err := json.Marshal(event) if err != nil { s.logger.WithFields(logrus.Fields{ "group_id": groupID, "key_id": keyID, "reason": reason, }).WithError(err).Error("Failed to marshal key change event") return } if err := s.store.Publish(ctx, models.TopicKeyStatusChanged, eventData); err != nil { s.logger.WithFields(logrus.Fields{ "group_id": groupID, "key_id": keyID, "reason": reason, }).WithError(err).Error("Failed to publish single key change event") } } // publishChangeEvent 发布通用变更事件 func (s *KeyImportService) publishChangeEvent(ctx context.Context, groupID uint, reason string) { event := models.KeyStatusChangedEvent{ GroupID: groupID, ChangeReason: reason, } eventData, err := json.Marshal(event) if err != nil { s.logger.WithFields(logrus.Fields{ "group_id": groupID, "reason": reason, }).WithError(err).Error("Failed to marshal change event") return } if err := s.store.Publish(ctx, models.TopicKeyStatusChanged, eventData); err != nil { s.logger.WithFields(logrus.Fields{ "group_id": groupID, "reason": reason, }).WithError(err).Error("Failed to publish change event") } } // publishImportGroupCompletedEvent 发布批量导入完成事件 func (s *KeyImportService) publishImportGroupCompletedEvent(ctx context.Context, groupID uint, keyIDs []uint) { if len(keyIDs) == 0 { return } event := models.ImportGroupCompletedEvent{ GroupID: groupID, KeyIDs: keyIDs, CompletedAt: time.Now(), } eventData, err := json.Marshal(event) if err != nil { s.logger.WithFields(logrus.Fields{ "group_id": groupID, "key_count": len(keyIDs), }).WithError(err).Error("Failed to marshal ImportGroupCompletedEvent") return } if err := s.store.Publish(ctx, models.TopicImportGroupCompleted, eventData); err != nil { s.logger.WithFields(logrus.Fields{ "group_id": groupID, "key_count": len(keyIDs), }).WithError(err).Error("Failed to publish ImportGroupCompletedEvent") } else { s.logger.WithFields(logrus.Fields{ "group_id": groupID, "key_count": len(keyIDs), }).Info("Published ImportGroupCompletedEvent") } } // min 返回两个整数中的较小值 func min(a, b int) int { if a < b { return a } return b }