// Filename: internal/service/key_import_service.go package service import ( "encoding/json" "fmt" "gemini-balancer/internal/models" "gemini-balancer/internal/repository" "gemini-balancer/internal/store" "gemini-balancer/internal/task" "gemini-balancer/internal/utils" "strings" "time" "github.com/gin-gonic/gin" "github.com/sirupsen/logrus" ) const ( TaskTypeAddKeysToGroup = "add_keys_to_group" TaskTypeUnlinkKeysFromGroup = "unlink_keys_from_group" TaskTypeHardDeleteKeys = "hard_delete_keys" TaskTypeRestoreKeys = "restore_keys" chunkSize = 500 ) type KeyImportService struct { taskService task.Reporter keyRepo repository.KeyRepository store store.Store logger *logrus.Entry apiKeyService *APIKeyService } func NewKeyImportService(ts task.Reporter, kr repository.KeyRepository, s store.Store, as *APIKeyService, logger *logrus.Logger) *KeyImportService { return &KeyImportService{ taskService: ts, keyRepo: kr, store: s, logger: logger.WithField("component", "KeyImportService🚀"), apiKeyService: as, } } // --- 通用的 Panic-Safe 任務執行器 --- func (s *KeyImportService) runTaskWithRecovery(taskID string, resourceID string, taskFunc func()) { defer func() { if r := recover(); r != nil { err := fmt.Errorf("panic recovered in task %s: %v", taskID, r) s.logger.Error(err) s.taskService.EndTaskByID(taskID, resourceID, nil, err) } }() taskFunc() } // --- Public Task Starters --- func (s *KeyImportService) StartAddKeysTask(groupID uint, keysText string, validateOnImport bool) (*task.Status, error) { keys := utils.ParseKeysFromText(keysText) if len(keys) == 0 { return nil, fmt.Errorf("no valid keys found in input text") } resourceID := fmt.Sprintf("group-%d", groupID) taskStatus, err := s.taskService.StartTask(uint(groupID), TaskTypeAddKeysToGroup, resourceID, len(keys), 15*time.Minute) if err != nil { return nil, err } go s.runTaskWithRecovery(taskStatus.ID, resourceID, func() { s.runAddKeysTask(taskStatus.ID, resourceID, groupID, keys, validateOnImport) }) return taskStatus, nil } func (s *KeyImportService) StartUnlinkKeysTask(groupID uint, keysText string) (*task.Status, error) { keys := utils.ParseKeysFromText(keysText) if len(keys) == 0 { return nil, fmt.Errorf("no valid keys found") } resourceID := fmt.Sprintf("group-%d", groupID) taskStatus, err := s.taskService.StartTask(uint(groupID), TaskTypeUnlinkKeysFromGroup, resourceID, len(keys), time.Hour) if err != nil { return nil, err } go s.runTaskWithRecovery(taskStatus.ID, resourceID, func() { s.runUnlinkKeysTask(taskStatus.ID, resourceID, groupID, keys) }) return taskStatus, nil } func (s *KeyImportService) StartHardDeleteKeysTask(keysText string) (*task.Status, error) { keys := utils.ParseKeysFromText(keysText) if len(keys) == 0 { return nil, fmt.Errorf("no valid keys found") } resourceID := "global_hard_delete" // Global lock taskStatus, err := s.taskService.StartTask(0, TaskTypeHardDeleteKeys, resourceID, len(keys), time.Hour) if err != nil { return nil, err } go s.runTaskWithRecovery(taskStatus.ID, resourceID, func() { s.runHardDeleteKeysTask(taskStatus.ID, resourceID, keys) }) return taskStatus, nil } func (s *KeyImportService) StartRestoreKeysTask(keysText string) (*task.Status, error) { keys := utils.ParseKeysFromText(keysText) if len(keys) == 0 { return nil, fmt.Errorf("no valid keys found") } resourceID := "global_restore_keys" // Global lock taskStatus, err := s.taskService.StartTask(0, TaskTypeRestoreKeys, resourceID, len(keys), time.Hour) if err != nil { return nil, err } go s.runTaskWithRecovery(taskStatus.ID, resourceID, func() { s.runRestoreKeysTask(taskStatus.ID, resourceID, keys) }) return taskStatus, nil } // --- Private Task Runners --- func (s *KeyImportService) runAddKeysTask(taskID string, resourceID string, groupID uint, keys []string, validateOnImport bool) { // 步骤 1: 对输入的原始 key 列表进行去重。 uniqueKeysMap := make(map[string]struct{}) var uniqueKeyStrings []string for _, kStr := range keys { if _, exists := uniqueKeysMap[kStr]; !exists { uniqueKeysMap[kStr] = struct{}{} uniqueKeyStrings = append(uniqueKeyStrings, kStr) } } if len(uniqueKeyStrings) == 0 { s.taskService.EndTaskByID(taskID, resourceID, gin.H{"newly_linked_count": 0, "already_linked_count": 0}, nil) return } // 步骤 2: 确保所有 Key 在主表中存在(创建或恢复),并获取它们完整的实体。 keysToEnsure := make([]models.APIKey, len(uniqueKeyStrings)) for i, keyStr := range uniqueKeyStrings { keysToEnsure[i] = models.APIKey{APIKey: keyStr} } allKeyModels, err := s.keyRepo.AddKeys(keysToEnsure) if err != nil { s.taskService.EndTaskByID(taskID, resourceID, nil, fmt.Errorf("failed to ensure keys exist: %w", err)) return } // 步骤 3: 找出在这些 Key 中,哪些【已经】被链接到了当前分组。 alreadyLinkedModels, err := s.keyRepo.GetKeysByValuesAndGroupID(uniqueKeyStrings, groupID) if err != nil { s.taskService.EndTaskByID(taskID, resourceID, nil, fmt.Errorf("failed to check for already linked keys: %w", err)) return } alreadyLinkedIDSet := make(map[uint]struct{}) for _, key := range alreadyLinkedModels { alreadyLinkedIDSet[key.ID] = struct{}{} } // 步骤 4: 确定【真正需要】被链接到当前分组的 key 列表 (我们的"工作集")。 var keysToLink []models.APIKey for _, key := range allKeyModels { if _, exists := alreadyLinkedIDSet[key.ID]; !exists { keysToLink = append(keysToLink, key) } } // 步骤 5: 更新任务的 Total 总量为精确的 "工作集" 大小。 if err := s.taskService.UpdateTotalByID(taskID, len(keysToLink)); err != nil { s.logger.WithError(err).Warnf("Failed to update total for task %s", taskID) } // 步骤 6: 分块处理【链接Key到组】的操作,并实时更新进度。 if len(keysToLink) > 0 { idsToLink := make([]uint, len(keysToLink)) for i, key := range keysToLink { idsToLink[i] = key.ID } for i := 0; i < len(idsToLink); i += chunkSize { end := i + chunkSize if end > len(idsToLink) { end = len(idsToLink) } chunk := idsToLink[i:end] if err := s.keyRepo.LinkKeysToGroup(groupID, chunk); err != nil { s.taskService.EndTaskByID(taskID, resourceID, nil, fmt.Errorf("chunk failed to link keys: %w", err)) return } _ = s.taskService.UpdateProgressByID(taskID, i+len(chunk)) } } // 步骤 7: 准备最终结果并结束任务。 result := gin.H{ "newly_linked_count": len(keysToLink), "already_linked_count": len(alreadyLinkedIDSet), "total_linked_count": len(allKeyModels), } // 步骤 8: 根据 `validateOnImport` 标志, 发布事件或直接激活 (只对新链接的keys操作)。 if len(keysToLink) > 0 { idsToLink := make([]uint, len(keysToLink)) for i, key := range keysToLink { idsToLink[i] = key.ID } if validateOnImport { s.publishImportGroupCompletedEvent(groupID, idsToLink) for _, keyID := range idsToLink { s.publishSingleKeyChangeEvent(groupID, keyID, "", models.StatusPendingValidation, "key_linked") } } else { for _, keyID := range idsToLink { if _, err := s.apiKeyService.UpdateMappingStatus(groupID, keyID, models.StatusActive); err != nil { s.logger.Errorf("Failed to directly activate key ID %d in group %d: %v", keyID, groupID, err) } } } } s.taskService.EndTaskByID(taskID, resourceID, result, nil) } // runUnlinkKeysTask func (s *KeyImportService) runUnlinkKeysTask(taskID, resourceID string, groupID uint, keys []string) { uniqueKeysMap := make(map[string]struct{}) var uniqueKeys []string for _, kStr := range keys { if _, exists := uniqueKeysMap[kStr]; !exists { uniqueKeysMap[kStr] = struct{}{} uniqueKeys = append(uniqueKeys, kStr) } } // 步骤 1: 一次性找出所有输入 Key 中,实际存在于本组的 Key 实体。这是我们的"工作集"。 keysToUnlink, err := s.keyRepo.GetKeysByValuesAndGroupID(uniqueKeys, groupID) if err != nil { s.taskService.EndTaskByID(taskID, resourceID, nil, fmt.Errorf("failed to find keys to unlink: %w", err)) return } if len(keysToUnlink) == 0 { result := gin.H{"unlinked_count": 0, "hard_deleted_count": 0, "not_found_count": len(uniqueKeys)} s.taskService.EndTaskByID(taskID, resourceID, result, nil) return } idsToUnlink := make([]uint, len(keysToUnlink)) for i, key := range keysToUnlink { idsToUnlink[i] = key.ID } // 步骤 2: 更新任务的 Total 总量为精确的 "工作集" 大小。 if err := s.taskService.UpdateTotalByID(taskID, len(idsToUnlink)); err != nil { s.logger.WithError(err).Warnf("Failed to update total for task %s", taskID) } var totalUnlinked int64 // 步骤 3: 分块处理【解绑Key】的操作,并上报进度。 for i := 0; i < len(idsToUnlink); i += chunkSize { end := i + chunkSize if end > len(idsToUnlink) { end = len(idsToUnlink) } chunk := idsToUnlink[i:end] unlinked, err := s.keyRepo.UnlinkKeysFromGroup(groupID, chunk) if err != nil { s.taskService.EndTaskByID(taskID, resourceID, nil, fmt.Errorf("chunk failed: could not unlink keys: %w", err)) return } totalUnlinked += unlinked for _, keyID := range chunk { s.publishSingleKeyChangeEvent(groupID, keyID, models.StatusActive, "", "key_unlinked") } _ = s.taskService.UpdateProgressByID(taskID, i+len(chunk)) } totalDeleted, err := s.keyRepo.DeleteOrphanKeys() if err != nil { s.logger.WithError(err).Warn("Orphan key cleanup failed after unlink task.") } result := gin.H{ "unlinked_count": totalUnlinked, "hard_deleted_count": totalDeleted, "not_found_count": len(uniqueKeys) - int(totalUnlinked), } s.taskService.EndTaskByID(taskID, resourceID, result, nil) } func (s *KeyImportService) runHardDeleteKeysTask(taskID, resourceID string, keys []string) { var totalDeleted int64 for i := 0; i < len(keys); i += chunkSize { end := i + chunkSize if end > len(keys) { end = len(keys) } chunk := keys[i:end] deleted, err := s.keyRepo.HardDeleteByValues(chunk) if err != nil { s.taskService.EndTaskByID(taskID, resourceID, nil, fmt.Errorf("failed to hard delete chunk: %w", err)) return } totalDeleted += deleted _ = s.taskService.UpdateProgressByID(taskID, i+len(chunk)) } result := gin.H{ "hard_deleted_count": totalDeleted, "not_found_count": int64(len(keys)) - totalDeleted, } s.taskService.EndTaskByID(taskID, resourceID, result, nil) s.publishChangeEvent(0, "keys_hard_deleted") // Global event } func (s *KeyImportService) runRestoreKeysTask(taskID, resourceID string, keys []string) { var restoredCount int64 for i := 0; i < len(keys); i += chunkSize { end := i + chunkSize if end > len(keys) { end = len(keys) } chunk := keys[i:end] count, err := s.keyRepo.UpdateMasterStatusByValues(chunk, models.MasterStatusActive) if err != nil { s.taskService.EndTaskByID(taskID, resourceID, nil, fmt.Errorf("failed to restore chunk: %w", err)) return } restoredCount += count _ = s.taskService.UpdateProgressByID(taskID, i+len(chunk)) } result := gin.H{ "restored_count": restoredCount, "not_found_count": int64(len(keys)) - restoredCount, } s.taskService.EndTaskByID(taskID, resourceID, result, nil) s.publishChangeEvent(0, "keys_bulk_restored") // Global event } func (s *KeyImportService) publishSingleKeyChangeEvent(groupID, keyID uint, oldStatus, newStatus models.APIKeyStatus, reason string) { event := models.KeyStatusChangedEvent{ GroupID: groupID, KeyID: keyID, OldStatus: oldStatus, NewStatus: newStatus, ChangeReason: reason, ChangedAt: time.Now(), } eventData, _ := json.Marshal(event) if err := s.store.Publish(models.TopicKeyStatusChanged, eventData); err != nil { s.logger.WithError(err).WithFields(logrus.Fields{ "group_id": groupID, "key_id": keyID, "reason": reason, }).Error("Failed to publish single key change event.") } } func (s *KeyImportService) publishChangeEvent(groupID uint, reason string) { event := models.KeyStatusChangedEvent{ GroupID: groupID, ChangeReason: reason, } eventData, _ := json.Marshal(event) _ = s.store.Publish(models.TopicKeyStatusChanged, eventData) } func (s *KeyImportService) publishImportGroupCompletedEvent(groupID uint, keyIDs []uint) { if len(keyIDs) == 0 { return } event := models.ImportGroupCompletedEvent{ GroupID: groupID, KeyIDs: keyIDs, CompletedAt: time.Now(), } eventData, err := json.Marshal(event) if err != nil { s.logger.WithError(err).Error("Failed to marshal ImportGroupCompletedEvent.") return } if err := s.store.Publish(models.TopicImportGroupCompleted, eventData); err != nil { s.logger.WithError(err).Error("Failed to publish ImportGroupCompletedEvent.") } else { s.logger.Infof("Published ImportGroupCompletedEvent for group %d with %d keys.", groupID, len(keyIDs)) } } // [NEW] StartUnlinkKeysByFilterTask starts a task to unlink keys matching a status filter. func (s *KeyImportService) StartUnlinkKeysByFilterTask(groupID uint, statuses []string) (*task.Status, error) { s.logger.Infof("Starting unlink task for group %d with status filter: %v", groupID, statuses) // 1. [New] Find the keys to operate on. keyValues, err := s.keyRepo.FindKeyValuesByStatus(groupID, statuses) if err != nil { return nil, fmt.Errorf("failed to find keys by filter: %w", err) } if len(keyValues) == 0 { return nil, fmt.Errorf("no keys found matching the provided filter") } // 2. [REUSE] Convert to text and call the existing, robust unlink task logic. keysAsText := strings.Join(keyValues, "\n") s.logger.Infof("Found %d keys to unlink for group %d.", len(keyValues), groupID) return s.StartUnlinkKeysTask(groupID, keysAsText) }