Files
gemini-banlancer/internal/service/key_import_service.go
2025-11-22 14:20:05 +08:00

380 lines
13 KiB
Go

// Filename: internal/service/key_import_service.go
package service
import (
"context"
"encoding/json"
"fmt"
"gemini-balancer/internal/models"
"gemini-balancer/internal/repository"
"gemini-balancer/internal/store"
"gemini-balancer/internal/task"
"gemini-balancer/internal/utils"
"strings"
"time"
"github.com/gin-gonic/gin"
"github.com/sirupsen/logrus"
)
// Task type identifiers passed to the task service when KeyImportService
// starts one of its background operations.
const (
	// TaskTypeAddKeysToGroup identifies a task that ensures keys exist and
	// links them to a group.
	TaskTypeAddKeysToGroup = "add_keys_to_group"
	// TaskTypeUnlinkKeysFromGroup identifies a task that removes key/group links.
	TaskTypeUnlinkKeysFromGroup = "unlink_keys_from_group"
	// TaskTypeHardDeleteKeys identifies a task that hard-deletes keys by value.
	TaskTypeHardDeleteKeys = "hard_delete_keys"
	// TaskTypeRestoreKeys identifies a task that sets keys' master status back
	// to active.
	TaskTypeRestoreKeys = "restore_keys"
)

// chunkSize is the maximum number of keys (or key IDs) handed to a single
// repository call by the batch loops in this file.
const chunkSize = 500
// KeyImportService runs bulk key operations (add/link, unlink, hard delete,
// restore) as background tasks. Progress is reported through a task.Reporter,
// persistence goes through a KeyRepository, and change events are published
// on a store.Store.
type KeyImportService struct {
	// taskService starts tasks and records their total, progress and result.
	taskService task.Reporter
	// keyRepo persists API keys and their group links.
	keyRepo repository.KeyRepository
	// store is the publish target for key-status and import-completed events.
	store store.Store
	// logger is pre-tagged with this service's component name.
	logger *logrus.Entry
	// apiKeyService activates newly linked keys directly when an import is
	// run without validation.
	apiKeyService *APIKeyService
}
// NewKeyImportService builds a KeyImportService from its collaborators. The
// supplied logger is wrapped with a "component" field so that every entry the
// service emits can be attributed to it.
func NewKeyImportService(ts task.Reporter, kr repository.KeyRepository, s store.Store, as *APIKeyService, logger *logrus.Logger) *KeyImportService {
	componentLogger := logger.WithField("component", "KeyImportService🚀")
	svc := &KeyImportService{
		taskService:   ts,
		keyRepo:       kr,
		store:         s,
		apiKeyService: as,
	}
	svc.logger = componentLogger
	return svc
}
// runTaskWithRecovery invokes taskFunc. If taskFunc panics, the panic value is
// turned into an error, logged, and used to end the task identified by
// taskID/resourceID. On a normal return nothing else happens here; the run*
// bodies in this file end their task themselves.
func (s *KeyImportService) runTaskWithRecovery(ctx context.Context, taskID string, resourceID string, taskFunc func()) {
	defer func() {
		// recover must be called directly in the deferred function to take effect.
		r := recover()
		if r == nil {
			return
		}
		panicErr := fmt.Errorf("panic recovered in task %s: %v", taskID, r)
		s.logger.Error(panicErr)
		s.taskService.EndTaskByID(ctx, taskID, resourceID, nil, panicErr)
	}()
	taskFunc()
}
// StartAddKeysTask parses keysText and registers an "add_keys_to_group" task
// for groupID. The actual work (runAddKeysTask) runs in a background
// goroutine with its own context, so it is not tied to ctx.
//
// It returns an error without starting anything if keysText yields no keys,
// and passes through any error from StartTask. It passes a 15-minute duration
// to StartTask (presumably the task's time limit; confirm in task.Reporter).
func (s *KeyImportService) StartAddKeysTask(ctx context.Context, groupID uint, keysText string, validateOnImport bool) (*task.Status, error) {
	parsed := utils.ParseKeysFromText(keysText)
	if len(parsed) == 0 {
		return nil, fmt.Errorf("no valid keys found in input text")
	}

	resourceID := fmt.Sprintf("group-%d", groupID)
	status, err := s.taskService.StartTask(ctx, groupID, TaskTypeAddKeysToGroup, resourceID, len(parsed), 15*time.Minute)
	if err != nil {
		return nil, err
	}

	bg := context.Background()
	go s.runTaskWithRecovery(bg, status.ID, resourceID, func() {
		s.runAddKeysTask(bg, status.ID, resourceID, groupID, parsed, validateOnImport)
	})
	return status, nil
}
// StartUnlinkKeysTask parses keysText and registers an "unlink_keys_from_group"
// task for groupID. runUnlinkKeysTask then runs in a background goroutine with
// a fresh context.
//
// It returns an error if no keys are parsed, or any error from StartTask.
func (s *KeyImportService) StartUnlinkKeysTask(ctx context.Context, groupID uint, keysText string) (*task.Status, error) {
	parsed := utils.ParseKeysFromText(keysText)
	if len(parsed) == 0 {
		return nil, fmt.Errorf("no valid keys found")
	}

	resourceID := fmt.Sprintf("group-%d", groupID)
	status, err := s.taskService.StartTask(ctx, groupID, TaskTypeUnlinkKeysFromGroup, resourceID, len(parsed), time.Hour)
	if err != nil {
		return nil, err
	}

	bg := context.Background()
	go s.runTaskWithRecovery(bg, status.ID, resourceID, func() {
		s.runUnlinkKeysTask(bg, status.ID, resourceID, groupID, parsed)
	})
	return status, nil
}
// StartHardDeleteKeysTask parses keysText and registers a global
// "hard_delete_keys" task (group 0, resource "global_hard_delete").
// runHardDeleteKeysTask then runs in a background goroutine with a fresh
// context.
//
// It returns an error if no keys are parsed, or any error from StartTask.
func (s *KeyImportService) StartHardDeleteKeysTask(ctx context.Context, keysText string) (*task.Status, error) {
	const resourceID = "global_hard_delete"

	parsed := utils.ParseKeysFromText(keysText)
	if len(parsed) == 0 {
		return nil, fmt.Errorf("no valid keys found")
	}

	status, err := s.taskService.StartTask(ctx, 0, TaskTypeHardDeleteKeys, resourceID, len(parsed), time.Hour)
	if err != nil {
		return nil, err
	}

	bg := context.Background()
	go s.runTaskWithRecovery(bg, status.ID, resourceID, func() {
		s.runHardDeleteKeysTask(bg, status.ID, resourceID, parsed)
	})
	return status, nil
}
// StartRestoreKeysTask parses keysText and registers a global "restore_keys"
// task (group 0, resource "global_restore_keys"). runRestoreKeysTask then runs
// in a background goroutine with a fresh context.
//
// It returns an error if no keys are parsed, or any error from StartTask.
func (s *KeyImportService) StartRestoreKeysTask(ctx context.Context, keysText string) (*task.Status, error) {
	const resourceID = "global_restore_keys"

	parsed := utils.ParseKeysFromText(keysText)
	if len(parsed) == 0 {
		return nil, fmt.Errorf("no valid keys found")
	}

	status, err := s.taskService.StartTask(ctx, 0, TaskTypeRestoreKeys, resourceID, len(parsed), time.Hour)
	if err != nil {
		return nil, err
	}

	bg := context.Background()
	go s.runTaskWithRecovery(bg, status.ID, resourceID, func() {
		s.runRestoreKeysTask(bg, status.ID, resourceID, parsed)
	})
	return status, nil
}
// runAddKeysTask is the background body of an "add_keys_to_group" task.
//
// Steps:
//  1. De-duplicate keys, keeping first-seen order.
//  2. Ensure every key exists via keyRepo.AddKeys.
//  3. Link the keys not already linked to groupID, in batches of chunkSize.
//  4. If validateOnImport is set, publish import-completed and per-key
//     "key_linked" events. Otherwise activate each new mapping through
//     apiKeyService.
//
// The task is always ended via EndTaskByID. It ends with an error if any
// repository call fails. Otherwise it ends with newly, already and total
// linked counts. Per-key activation failures are logged but do not fail the
// task.
func (s *KeyImportService) runAddKeysTask(ctx context.Context, taskID string, resourceID string, groupID uint, keys []string, validateOnImport bool) {
	seen := make(map[string]struct{}, len(keys))
	uniqueKeyStrings := make([]string, 0, len(keys))
	for _, kStr := range keys {
		if _, exists := seen[kStr]; exists {
			continue
		}
		seen[kStr] = struct{}{}
		uniqueKeyStrings = append(uniqueKeyStrings, kStr)
	}
	if len(uniqueKeyStrings) == 0 {
		s.taskService.EndTaskByID(ctx, taskID, resourceID, gin.H{"newly_linked_count": 0, "already_linked_count": 0}, nil)
		return
	}

	keysToEnsure := make([]models.APIKey, len(uniqueKeyStrings))
	for i, keyStr := range uniqueKeyStrings {
		keysToEnsure[i] = models.APIKey{APIKey: keyStr}
	}
	allKeyModels, err := s.keyRepo.AddKeys(keysToEnsure)
	if err != nil {
		s.taskService.EndTaskByID(ctx, taskID, resourceID, nil, fmt.Errorf("failed to ensure keys exist: %w", err))
		return
	}

	alreadyLinkedModels, err := s.keyRepo.GetKeysByValuesAndGroupID(uniqueKeyStrings, groupID)
	if err != nil {
		s.taskService.EndTaskByID(ctx, taskID, resourceID, nil, fmt.Errorf("failed to check for already linked keys: %w", err))
		return
	}
	alreadyLinkedIDSet := make(map[uint]struct{}, len(alreadyLinkedModels))
	for i := range alreadyLinkedModels {
		alreadyLinkedIDSet[alreadyLinkedModels[i].ID] = struct{}{}
	}

	// IDs of keys that exist but are not yet linked to this group. They are
	// computed once and reused by both the linking loop and the post-link
	// validation/activation step.
	idsToLink := make([]uint, 0, len(allKeyModels))
	for i := range allKeyModels {
		if _, linked := alreadyLinkedIDSet[allKeyModels[i].ID]; !linked {
			idsToLink = append(idsToLink, allKeyModels[i].ID)
		}
	}

	// The task was started with the raw input count; narrow it to the real work.
	if err := s.taskService.UpdateTotalByID(ctx, taskID, len(idsToLink)); err != nil {
		s.logger.WithError(err).Warnf("Failed to update total for task %s", taskID)
	}

	for start := 0; start < len(idsToLink); start += chunkSize {
		end := start + chunkSize
		if end > len(idsToLink) {
			end = len(idsToLink)
		}
		if err := s.keyRepo.LinkKeysToGroup(ctx, groupID, idsToLink[start:end]); err != nil {
			s.taskService.EndTaskByID(ctx, taskID, resourceID, nil, fmt.Errorf("chunk failed to link keys: %w", err))
			return
		}
		// Progress reporting is best-effort; a failed update must not abort linking.
		_ = s.taskService.UpdateProgressByID(ctx, taskID, end)
	}

	result := gin.H{
		"newly_linked_count":   len(idsToLink),
		"already_linked_count": len(alreadyLinkedIDSet),
		"total_linked_count":   len(allKeyModels),
	}

	switch {
	case len(idsToLink) == 0:
		// Nothing new was linked, so there is nothing to validate or activate.
	case validateOnImport:
		s.publishImportGroupCompletedEvent(ctx, groupID, idsToLink)
		for _, keyID := range idsToLink {
			s.publishSingleKeyChangeEvent(ctx, groupID, keyID, "", models.StatusPendingValidation, "key_linked")
		}
	default:
		for _, keyID := range idsToLink {
			if _, err := s.apiKeyService.UpdateMappingStatus(ctx, groupID, keyID, models.StatusActive); err != nil {
				s.logger.Errorf("Failed to directly activate key ID %d in group %d: %v", keyID, groupID, err)
			}
		}
	}

	s.taskService.EndTaskByID(ctx, taskID, resourceID, result, nil)
}
// runUnlinkKeysTask is the background body of an "unlink_keys_from_group" task.
//
// It de-duplicates keys and looks up those currently linked to groupID. It
// then unlinks them in batches of chunkSize, publishing a "key_unlinked" event
// per key, and runs orphan-key cleanup. Finally it ends the task with
// unlinked, hard-deleted and not-found counts. A lookup or unlink failure ends
// the task with an error; a cleanup failure is only logged.
func (s *KeyImportService) runUnlinkKeysTask(ctx context.Context, taskID, resourceID string, groupID uint, keys []string) {
	seen := make(map[string]struct{})
	var uniqueKeys []string
	for _, k := range keys {
		if _, dup := seen[k]; dup {
			continue
		}
		seen[k] = struct{}{}
		uniqueKeys = append(uniqueKeys, k)
	}

	linked, err := s.keyRepo.GetKeysByValuesAndGroupID(uniqueKeys, groupID)
	if err != nil {
		s.taskService.EndTaskByID(ctx, taskID, resourceID, nil, fmt.Errorf("failed to find keys to unlink: %w", err))
		return
	}
	if len(linked) == 0 {
		s.taskService.EndTaskByID(ctx, taskID, resourceID, gin.H{
			"unlinked_count":     0,
			"hard_deleted_count": 0,
			"not_found_count":    len(uniqueKeys),
		}, nil)
		return
	}

	ids := make([]uint, 0, len(linked))
	for idx := range linked {
		ids = append(ids, linked[idx].ID)
	}
	if err := s.taskService.UpdateTotalByID(ctx, taskID, len(ids)); err != nil {
		s.logger.WithError(err).Warnf("Failed to update total for task %s", taskID)
	}

	var totalUnlinked int64
	for start := 0; start < len(ids); start += chunkSize {
		stop := start + chunkSize
		if stop > len(ids) {
			stop = len(ids)
		}
		batch := ids[start:stop]
		affected, err := s.keyRepo.UnlinkKeysFromGroup(ctx, groupID, batch)
		if err != nil {
			s.taskService.EndTaskByID(ctx, taskID, resourceID, nil, fmt.Errorf("chunk failed: could not unlink keys: %w", err))
			return
		}
		totalUnlinked += affected
		for _, keyID := range batch {
			// NOTE(review): the old status is always reported as StatusActive,
			// whatever the mapping's real status was. Confirm consumers accept this.
			s.publishSingleKeyChangeEvent(ctx, groupID, keyID, models.StatusActive, "", "key_unlinked")
		}
		// Progress reporting is best-effort.
		_ = s.taskService.UpdateProgressByID(ctx, taskID, stop)
	}

	// NOTE(review): DeleteOrphanKeys takes no arguments, so it presumably
	// removes every orphaned key, not only those unlinked here. Confirm that
	// hard_deleted_count is meant to include such keys.
	totalDeleted, err := s.keyRepo.DeleteOrphanKeys()
	if err != nil {
		s.logger.WithError(err).Warn("Orphan key cleanup failed after unlink task.")
	}

	s.taskService.EndTaskByID(ctx, taskID, resourceID, gin.H{
		"unlinked_count":     totalUnlinked,
		"hard_deleted_count": totalDeleted,
		"not_found_count":    len(uniqueKeys) - int(totalUnlinked),
	}, nil)
}
// runHardDeleteKeysTask is the background body of a "hard_delete_keys" task.
//
// It de-duplicates keys and hard-deletes them via keyRepo.HardDeleteByValues
// in batches of chunkSize. It then ends the task with deleted and not-found
// counts and finally publishes a "keys_hard_deleted" change event for group 0.
// A repository failure ends the task with an error and publishes no event.
func (s *KeyImportService) runHardDeleteKeysTask(ctx context.Context, taskID, resourceID string, keys []string) {
	// A repeated input value cannot be deleted twice. It would therefore be
	// counted as "not found", so duplicates are collapsed first, as the
	// add/unlink tasks already do.
	seen := make(map[string]struct{}, len(keys))
	uniqueKeys := make([]string, 0, len(keys))
	for _, k := range keys {
		if _, dup := seen[k]; dup {
			continue
		}
		seen[k] = struct{}{}
		uniqueKeys = append(uniqueKeys, k)
	}
	if len(uniqueKeys) != len(keys) {
		// The task was registered with len(keys); keep the progress total consistent.
		if err := s.taskService.UpdateTotalByID(ctx, taskID, len(uniqueKeys)); err != nil {
			s.logger.WithError(err).Warnf("Failed to update total for task %s", taskID)
		}
	}

	var totalDeleted int64
	for start := 0; start < len(uniqueKeys); start += chunkSize {
		end := start + chunkSize
		if end > len(uniqueKeys) {
			end = len(uniqueKeys)
		}
		deleted, err := s.keyRepo.HardDeleteByValues(uniqueKeys[start:end])
		if err != nil {
			s.taskService.EndTaskByID(ctx, taskID, resourceID, nil, fmt.Errorf("failed to hard delete chunk: %w", err))
			return
		}
		totalDeleted += deleted
		// Progress reporting is best-effort.
		_ = s.taskService.UpdateProgressByID(ctx, taskID, end)
	}

	result := gin.H{
		"hard_deleted_count": totalDeleted,
		"not_found_count":    int64(len(uniqueKeys)) - totalDeleted,
	}
	s.taskService.EndTaskByID(ctx, taskID, resourceID, result, nil)
	s.publishChangeEvent(ctx, 0, "keys_hard_deleted")
}
// runRestoreKeysTask is the background body of a "restore_keys" task.
//
// It de-duplicates keys and sets their master status to
// models.MasterStatusActive via keyRepo.UpdateMasterStatusByValues, in batches
// of chunkSize. It then ends the task with restored and not-found counts and
// publishes a "keys_bulk_restored" change event for group 0. A repository
// failure ends the task with an error and publishes no event.
//
// NOTE(review): not_found_count is derived from the repository's returned
// count. If that count reflects only rows actually changed (driver dependent),
// keys that were already active will be reported as "not found".
func (s *KeyImportService) runRestoreKeysTask(ctx context.Context, taskID, resourceID string, keys []string) {
	// Repeated input values would otherwise inflate not_found_count, because
	// one key can only be restored once. This mirrors the add/unlink tasks.
	seen := make(map[string]struct{}, len(keys))
	uniqueKeys := make([]string, 0, len(keys))
	for _, k := range keys {
		if _, dup := seen[k]; dup {
			continue
		}
		seen[k] = struct{}{}
		uniqueKeys = append(uniqueKeys, k)
	}
	if len(uniqueKeys) != len(keys) {
		// The task was registered with len(keys); keep the progress total consistent.
		if err := s.taskService.UpdateTotalByID(ctx, taskID, len(uniqueKeys)); err != nil {
			s.logger.WithError(err).Warnf("Failed to update total for task %s", taskID)
		}
	}

	var restoredCount int64
	for start := 0; start < len(uniqueKeys); start += chunkSize {
		end := start + chunkSize
		if end > len(uniqueKeys) {
			end = len(uniqueKeys)
		}
		count, err := s.keyRepo.UpdateMasterStatusByValues(uniqueKeys[start:end], models.MasterStatusActive)
		if err != nil {
			s.taskService.EndTaskByID(ctx, taskID, resourceID, nil, fmt.Errorf("failed to restore chunk: %w", err))
			return
		}
		restoredCount += count
		// Progress reporting is best-effort.
		_ = s.taskService.UpdateProgressByID(ctx, taskID, end)
	}

	result := gin.H{
		"restored_count":  restoredCount,
		"not_found_count": int64(len(uniqueKeys)) - restoredCount,
	}
	s.taskService.EndTaskByID(ctx, taskID, resourceID, result, nil)
	s.publishChangeEvent(ctx, 0, "keys_bulk_restored")
}
// publishSingleKeyChangeEvent publishes a KeyStatusChangedEvent for one key on
// models.TopicKeyStatusChanged, stamped with the current time.
//
// Publishing is best-effort. Marshal and publish failures are logged with the
// group, key and reason and are not returned. A marshal failure skips the
// publish rather than sending an empty payload.
func (s *KeyImportService) publishSingleKeyChangeEvent(ctx context.Context, groupID, keyID uint, oldStatus, newStatus models.APIKeyStatus, reason string) {
	event := models.KeyStatusChangedEvent{
		GroupID:      groupID,
		KeyID:        keyID,
		OldStatus:    oldStatus,
		NewStatus:    newStatus,
		ChangeReason: reason,
		ChangedAt:    time.Now(),
	}
	fields := logrus.Fields{
		"group_id": groupID,
		"key_id":   keyID,
		"reason":   reason,
	}
	eventData, err := json.Marshal(event)
	if err != nil {
		s.logger.WithError(err).WithFields(fields).Error("Failed to marshal single key change event.")
		return
	}
	if err := s.store.Publish(ctx, models.TopicKeyStatusChanged, eventData); err != nil {
		s.logger.WithError(err).WithFields(fields).Error("Failed to publish single key change event.")
	}
}
// publishChangeEvent publishes a group-level KeyStatusChangedEvent (no KeyID
// or statuses) carrying only the group, the reason and the current time, on
// models.TopicKeyStatusChanged. In this file it is used after bulk
// hard-delete and restore tasks, with group 0.
//
// Publishing is best-effort. Failures are logged rather than silently
// discarded, and a marshal failure skips the publish.
func (s *KeyImportService) publishChangeEvent(ctx context.Context, groupID uint, reason string) {
	event := models.KeyStatusChangedEvent{
		GroupID:      groupID,
		ChangeReason: reason,
		// Set for consistency with publishSingleKeyChangeEvent; previously this
		// event carried a zero timestamp.
		ChangedAt: time.Now(),
	}
	fields := logrus.Fields{
		"group_id": groupID,
		"reason":   reason,
	}
	eventData, err := json.Marshal(event)
	if err != nil {
		s.logger.WithError(err).WithFields(fields).Error("Failed to marshal key change event.")
		return
	}
	if err := s.store.Publish(ctx, models.TopicKeyStatusChanged, eventData); err != nil {
		s.logger.WithError(err).WithFields(fields).Error("Failed to publish key change event.")
	}
}
// publishImportGroupCompletedEvent announces on models.TopicImportGroupCompleted
// that keyIDs were newly linked to groupID, stamped with the current time. It
// does nothing for an empty keyIDs. Marshal and publish failures are logged,
// not returned; a successful publish is logged at info level.
func (s *KeyImportService) publishImportGroupCompletedEvent(ctx context.Context, groupID uint, keyIDs []uint) {
	if len(keyIDs) == 0 {
		return
	}

	payload, err := json.Marshal(models.ImportGroupCompletedEvent{
		GroupID:     groupID,
		KeyIDs:      keyIDs,
		CompletedAt: time.Now(),
	})
	if err != nil {
		s.logger.WithError(err).Error("Failed to marshal ImportGroupCompletedEvent.")
		return
	}

	if err := s.store.Publish(ctx, models.TopicImportGroupCompleted, payload); err != nil {
		s.logger.WithError(err).Error("Failed to publish ImportGroupCompletedEvent.")
		return
	}
	s.logger.Infof("Published ImportGroupCompletedEvent for group %d with %d keys.", groupID, len(keyIDs))
}
// StartUnlinkKeysByFilterTask looks up the key values in groupID whose status
// matches one of statuses. It then starts an ordinary unlink task for them by
// re-serialising the values as newline-separated text for StartUnlinkKeysTask.
//
// It returns an error if the lookup fails or matches nothing; otherwise it
// returns whatever StartUnlinkKeysTask returns.
func (s *KeyImportService) StartUnlinkKeysByFilterTask(ctx context.Context, groupID uint, statuses []string) (*task.Status, error) {
	s.logger.Infof("Starting unlink task for group %d with status filter: %v", groupID, statuses)

	matches, err := s.keyRepo.FindKeyValuesByStatus(groupID, statuses)
	switch {
	case err != nil:
		return nil, fmt.Errorf("failed to find keys by filter: %w", err)
	case len(matches) == 0:
		return nil, fmt.Errorf("no keys found matching the provided filter")
	}

	s.logger.Infof("Found %d keys to unlink for group %d.", len(matches), groupID)
	return s.StartUnlinkKeysTask(ctx, groupID, strings.Join(matches, "\n"))
}