// Filename: internal/service/key_import_service.go package service import ( "context" "encoding/json" "fmt" "gemini-balancer/internal/models" "gemini-balancer/internal/repository" "gemini-balancer/internal/store" "gemini-balancer/internal/task" "gemini-balancer/internal/utils" "strings" "time" "github.com/gin-gonic/gin" "github.com/sirupsen/logrus" ) const ( TaskTypeAddKeysToGroup = "add_keys_to_group" TaskTypeUnlinkKeysFromGroup = "unlink_keys_from_group" TaskTypeHardDeleteKeys = "hard_delete_keys" TaskTypeRestoreKeys = "restore_keys" chunkSize = 500 ) type KeyImportService struct { taskService task.Reporter keyRepo repository.KeyRepository store store.Store logger *logrus.Entry apiKeyService *APIKeyService } func NewKeyImportService(ts task.Reporter, kr repository.KeyRepository, s store.Store, as *APIKeyService, logger *logrus.Logger) *KeyImportService { return &KeyImportService{ taskService: ts, keyRepo: kr, store: s, logger: logger.WithField("component", "KeyImportService🚀"), apiKeyService: as, } } func (s *KeyImportService) runTaskWithRecovery(ctx context.Context, taskID string, resourceID string, taskFunc func()) { defer func() { if r := recover(); r != nil { err := fmt.Errorf("panic recovered in task %s: %v", taskID, r) s.logger.Error(err) s.taskService.EndTaskByID(ctx, taskID, resourceID, nil, err) } }() taskFunc() } func (s *KeyImportService) StartAddKeysTask(ctx context.Context, groupID uint, keysText string, validateOnImport bool) (*task.Status, error) { keys := utils.ParseKeysFromText(keysText) if len(keys) == 0 { return nil, fmt.Errorf("no valid keys found in input text") } resourceID := fmt.Sprintf("group-%d", groupID) taskStatus, err := s.taskService.StartTask(ctx, uint(groupID), TaskTypeAddKeysToGroup, resourceID, len(keys), 15*time.Minute) if err != nil { return nil, err } go s.runTaskWithRecovery(context.Background(), taskStatus.ID, resourceID, func() { s.runAddKeysTask(context.Background(), taskStatus.ID, resourceID, groupID, keys, validateOnImport) }) return taskStatus, nil } func (s *KeyImportService) StartUnlinkKeysTask(ctx context.Context, groupID uint, keysText string) (*task.Status, error) { keys := utils.ParseKeysFromText(keysText) if len(keys) == 0 { return nil, fmt.Errorf("no valid keys found") } resourceID := fmt.Sprintf("group-%d", groupID) taskStatus, err := s.taskService.StartTask(ctx, uint(groupID), TaskTypeUnlinkKeysFromGroup, resourceID, len(keys), time.Hour) if err != nil { return nil, err } go s.runTaskWithRecovery(context.Background(), taskStatus.ID, resourceID, func() { s.runUnlinkKeysTask(context.Background(), taskStatus.ID, resourceID, groupID, keys) }) return taskStatus, nil } func (s *KeyImportService) StartHardDeleteKeysTask(ctx context.Context, keysText string) (*task.Status, error) { keys := utils.ParseKeysFromText(keysText) if len(keys) == 0 { return nil, fmt.Errorf("no valid keys found") } resourceID := "global_hard_delete" taskStatus, err := s.taskService.StartTask(ctx, 0, TaskTypeHardDeleteKeys, resourceID, len(keys), time.Hour) if err != nil { return nil, err } go s.runTaskWithRecovery(context.Background(), taskStatus.ID, resourceID, func() { s.runHardDeleteKeysTask(context.Background(), taskStatus.ID, resourceID, keys) }) return taskStatus, nil } func (s *KeyImportService) StartRestoreKeysTask(ctx context.Context, keysText string) (*task.Status, error) { keys := utils.ParseKeysFromText(keysText) if len(keys) == 0 { return nil, fmt.Errorf("no valid keys found") } resourceID := "global_restore_keys" taskStatus, err := s.taskService.StartTask(ctx, 0, TaskTypeRestoreKeys, resourceID, len(keys), time.Hour) if err != nil { return nil, err } go s.runTaskWithRecovery(context.Background(), taskStatus.ID, resourceID, func() { s.runRestoreKeysTask(context.Background(), taskStatus.ID, resourceID, keys) }) return taskStatus, nil } func (s *KeyImportService) runAddKeysTask(ctx context.Context, taskID string, resourceID string, groupID uint, keys []string, validateOnImport bool) { uniqueKeysMap := make(map[string]struct{}) var uniqueKeyStrings []string for _, kStr := range keys { if _, exists := uniqueKeysMap[kStr]; !exists { uniqueKeysMap[kStr] = struct{}{} uniqueKeyStrings = append(uniqueKeyStrings, kStr) } } if len(uniqueKeyStrings) == 0 { s.taskService.EndTaskByID(ctx, taskID, resourceID, gin.H{"newly_linked_count": 0, "already_linked_count": 0}, nil) return } keysToEnsure := make([]models.APIKey, len(uniqueKeyStrings)) for i, keyStr := range uniqueKeyStrings { keysToEnsure[i] = models.APIKey{APIKey: keyStr} } allKeyModels, err := s.keyRepo.AddKeys(keysToEnsure) if err != nil { s.taskService.EndTaskByID(ctx, taskID, resourceID, nil, fmt.Errorf("failed to ensure keys exist: %w", err)) return } alreadyLinkedModels, err := s.keyRepo.GetKeysByValuesAndGroupID(uniqueKeyStrings, groupID) if err != nil { s.taskService.EndTaskByID(ctx, taskID, resourceID, nil, fmt.Errorf("failed to check for already linked keys: %w", err)) return } alreadyLinkedIDSet := make(map[uint]struct{}) for _, key := range alreadyLinkedModels { alreadyLinkedIDSet[key.ID] = struct{}{} } var keysToLink []models.APIKey for _, key := range allKeyModels { if _, exists := alreadyLinkedIDSet[key.ID]; !exists { keysToLink = append(keysToLink, key) } } if err := s.taskService.UpdateTotalByID(ctx, taskID, len(keysToLink)); err != nil { s.logger.WithError(err).Warnf("Failed to update total for task %s", taskID) } if len(keysToLink) > 0 { idsToLink := make([]uint, len(keysToLink)) for i, key := range keysToLink { idsToLink[i] = key.ID } for i := 0; i < len(idsToLink); i += chunkSize { end := i + chunkSize if end > len(idsToLink) { end = len(idsToLink) } chunk := idsToLink[i:end] if err := s.keyRepo.LinkKeysToGroup(ctx, groupID, chunk); err != nil { s.taskService.EndTaskByID(ctx, taskID, resourceID, nil, fmt.Errorf("chunk failed to link keys: %w", err)) return } _ = s.taskService.UpdateProgressByID(ctx, taskID, i+len(chunk)) } } result := gin.H{ "newly_linked_count": len(keysToLink), "already_linked_count": len(alreadyLinkedIDSet), "total_linked_count": len(allKeyModels), } if len(keysToLink) > 0 { idsToLink := make([]uint, len(keysToLink)) for i, key := range keysToLink { idsToLink[i] = key.ID } if validateOnImport { s.publishImportGroupCompletedEvent(ctx, groupID, idsToLink) for _, keyID := range idsToLink { s.publishSingleKeyChangeEvent(ctx, groupID, keyID, "", models.StatusPendingValidation, "key_linked") } } else { for _, keyID := range idsToLink { if _, err := s.apiKeyService.UpdateMappingStatus(ctx, groupID, keyID, models.StatusActive); err != nil { s.logger.Errorf("Failed to directly activate key ID %d in group %d: %v", keyID, groupID, err) } } } } s.taskService.EndTaskByID(ctx, taskID, resourceID, result, nil) } func (s *KeyImportService) runUnlinkKeysTask(ctx context.Context, taskID, resourceID string, groupID uint, keys []string) { uniqueKeysMap := make(map[string]struct{}) var uniqueKeys []string for _, kStr := range keys { if _, exists := uniqueKeysMap[kStr]; !exists { uniqueKeysMap[kStr] = struct{}{} uniqueKeys = append(uniqueKeys, kStr) } } keysToUnlink, err := s.keyRepo.GetKeysByValuesAndGroupID(uniqueKeys, groupID) if err != nil { s.taskService.EndTaskByID(ctx, taskID, resourceID, nil, fmt.Errorf("failed to find keys to unlink: %w", err)) return } if len(keysToUnlink) == 0 { result := gin.H{"unlinked_count": 0, "hard_deleted_count": 0, "not_found_count": len(uniqueKeys)} s.taskService.EndTaskByID(ctx, taskID, resourceID, result, nil) return } idsToUnlink := make([]uint, len(keysToUnlink)) for i, key := range keysToUnlink { idsToUnlink[i] = key.ID } if err := s.taskService.UpdateTotalByID(ctx, taskID, len(idsToUnlink)); err != nil { s.logger.WithError(err).Warnf("Failed to update total for task %s", taskID) } var totalUnlinked int64 for i := 0; i < len(idsToUnlink); i += chunkSize { end := i + chunkSize if end > len(idsToUnlink) { end = len(idsToUnlink) } chunk := idsToUnlink[i:end] unlinked, err := s.keyRepo.UnlinkKeysFromGroup(ctx, groupID, chunk) if err != nil { s.taskService.EndTaskByID(ctx, taskID, resourceID, nil, fmt.Errorf("chunk failed: could not unlink keys: %w", err)) return } totalUnlinked += unlinked for _, keyID := range chunk { s.publishSingleKeyChangeEvent(ctx, groupID, keyID, models.StatusActive, "", "key_unlinked") } _ = s.taskService.UpdateProgressByID(ctx, taskID, i+len(chunk)) } totalDeleted, err := s.keyRepo.DeleteOrphanKeys() if err != nil { s.logger.WithError(err).Warn("Orphan key cleanup failed after unlink task.") } result := gin.H{ "unlinked_count": totalUnlinked, "hard_deleted_count": totalDeleted, "not_found_count": len(uniqueKeys) - int(totalUnlinked), } s.taskService.EndTaskByID(ctx, taskID, resourceID, result, nil) } func (s *KeyImportService) runHardDeleteKeysTask(ctx context.Context, taskID, resourceID string, keys []string) { var totalDeleted int64 for i := 0; i < len(keys); i += chunkSize { end := i + chunkSize if end > len(keys) { end = len(keys) } chunk := keys[i:end] deleted, err := s.keyRepo.HardDeleteByValues(chunk) if err != nil { s.taskService.EndTaskByID(ctx, taskID, resourceID, nil, fmt.Errorf("failed to hard delete chunk: %w", err)) return } totalDeleted += deleted _ = s.taskService.UpdateProgressByID(ctx, taskID, i+len(chunk)) } result := gin.H{ "hard_deleted_count": totalDeleted, "not_found_count": int64(len(keys)) - totalDeleted, } s.taskService.EndTaskByID(ctx, taskID, resourceID, result, nil) s.publishChangeEvent(ctx, 0, "keys_hard_deleted") } func (s *KeyImportService) runRestoreKeysTask(ctx context.Context, taskID, resourceID string, keys []string) { var restoredCount int64 for i := 0; i < len(keys); i += chunkSize { end := i + chunkSize if end > len(keys) { end = len(keys) } chunk := keys[i:end] count, err := s.keyRepo.UpdateMasterStatusByValues(chunk, models.MasterStatusActive) if err != nil { s.taskService.EndTaskByID(ctx, taskID, resourceID, nil, fmt.Errorf("failed to restore chunk: %w", err)) return } restoredCount += count _ = s.taskService.UpdateProgressByID(ctx, taskID, i+len(chunk)) } result := gin.H{ "restored_count": restoredCount, "not_found_count": int64(len(keys)) - restoredCount, } s.taskService.EndTaskByID(ctx, taskID, resourceID, result, nil) s.publishChangeEvent(ctx, 0, "keys_bulk_restored") } func (s *KeyImportService) publishSingleKeyChangeEvent(ctx context.Context, groupID, keyID uint, oldStatus, newStatus models.APIKeyStatus, reason string) { event := models.KeyStatusChangedEvent{ GroupID: groupID, KeyID: keyID, OldStatus: oldStatus, NewStatus: newStatus, ChangeReason: reason, ChangedAt: time.Now(), } eventData, _ := json.Marshal(event) if err := s.store.Publish(ctx, models.TopicKeyStatusChanged, eventData); err != nil { s.logger.WithError(err).WithFields(logrus.Fields{ "group_id": groupID, "key_id": keyID, "reason": reason, }).Error("Failed to publish single key change event.") } } func (s *KeyImportService) publishChangeEvent(ctx context.Context, groupID uint, reason string) { event := models.KeyStatusChangedEvent{ GroupID: groupID, ChangeReason: reason, } eventData, _ := json.Marshal(event) _ = s.store.Publish(ctx, models.TopicKeyStatusChanged, eventData) } func (s *KeyImportService) publishImportGroupCompletedEvent(ctx context.Context, groupID uint, keyIDs []uint) { if len(keyIDs) == 0 { return } event := models.ImportGroupCompletedEvent{ GroupID: groupID, KeyIDs: keyIDs, CompletedAt: time.Now(), } eventData, err := json.Marshal(event) if err != nil { s.logger.WithError(err).Error("Failed to marshal ImportGroupCompletedEvent.") return } if err := s.store.Publish(ctx, models.TopicImportGroupCompleted, eventData); err != nil { s.logger.WithError(err).Error("Failed to publish ImportGroupCompletedEvent.") } else { s.logger.Infof("Published ImportGroupCompletedEvent for group %d with %d keys.", groupID, len(keyIDs)) } } func (s *KeyImportService) StartUnlinkKeysByFilterTask(ctx context.Context, groupID uint, statuses []string) (*task.Status, error) { s.logger.Infof("Starting unlink task for group %d with status filter: %v", groupID, statuses) keyValues, err := s.keyRepo.FindKeyValuesByStatus(groupID, statuses) if err != nil { return nil, fmt.Errorf("failed to find keys by filter: %w", err) } if len(keyValues) == 0 { return nil, fmt.Errorf("no keys found matching the provided filter") } keysAsText := strings.Join(keyValues, "\n") s.logger.Infof("Found %d keys to unlink for group %d.", len(keyValues), groupID) return s.StartUnlinkKeysTask(ctx, groupID, keysAsText) }