// Filename: internal/service/key_import_service.go

package service

import (
	"context"
	"encoding/json"
	"fmt"
	"strings"
	"time"

	"github.com/gin-gonic/gin"
	"github.com/sirupsen/logrus"

	"gemini-balancer/internal/models"
	"gemini-balancer/internal/repository"
	"gemini-balancer/internal/store"
	"gemini-balancer/internal/task"
	"gemini-balancer/internal/utils"
)
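
// Task type identifiers reported to the task service, plus the batch size
// used when linking, unlinking, deleting, or restoring keys in bulk.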
const (
	TaskTypeAddKeysToGroup      = "add_keys_to_group"
	TaskTypeUnlinkKeysFromGroup = "unlink_keys_from_group"
	TaskTypeHardDeleteKeys      = "hard_delete_keys"
	TaskTypeRestoreKeys         = "restore_keys"

	chunkSize = 500
)
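
// KeyImportService runs long-lived key import and maintenance jobs in the
// background and reports their progress through the task service.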
type KeyImportService struct {
	taskService   task.Reporter
	keyRepo       repository.KeyRepository
	store         store.Store
	logger        *logrus.Entry
	apiKeyService *APIKeyService
}
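
// NewKeyImportService wires the service with its task reporter, key
// repository, event store, and API key service dependencies.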
func NewKeyImportService(ts task.Reporter, kr repository.KeyRepository, s store.Store, as *APIKeyService, logger *logrus.Logger) *KeyImportService {
	return &KeyImportService{
		taskService:   ts,
		keyRepo:       kr,
		store:         s,
		logger:        logger.WithField("component", "KeyImportService🚀"),
		apiKeyService: as,
	}
}
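
// runTaskWithRecovery executes taskFunc, converting any panic into a logged
// error and a failed task status so the task never hangs.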
func (s *KeyImportService) runTaskWithRecovery(ctx context.Context, taskID string, resourceID string, taskFunc func()) {
	defer func() {
		if r := recover(); r != nil {
			err := fmt.Errorf("panic recovered in task %s: %v", taskID, r)
			s.logger.Error(err)
			s.taskService.EndTaskByID(ctx, taskID, resourceID, nil, err)
		}
	}()
	taskFunc()
}
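
// StartAddKeysTask parses the submitted key text and launches a background
// task that links the keys to the given group, optionally validating them.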
func (s *KeyImportService) StartAddKeysTask(ctx context.Context, groupID uint, keysText string, validateOnImport bool) (*task.Status, error) {
	keys := utils.ParseKeysFromText(keysText)
	if len(keys) == 0 {
		return nil, fmt.Errorf("no valid keys found in input text")
	}
	resourceID := fmt.Sprintf("group-%d", groupID)

	taskStatus, err := s.taskService.StartTask(ctx, groupID, TaskTypeAddKeysToGroup, resourceID, len(keys), 15*time.Minute)
	if err != nil {
		return nil, err
	}
	go s.runTaskWithRecovery(context.Background(), taskStatus.ID, resourceID, func() {
		s.runAddKeysTask(context.Background(), taskStatus.ID, resourceID, groupID, keys, validateOnImport)
	})
	return taskStatus, nil
}
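
// StartUnlinkKeysTask launches a background task that detaches the given
// keys from a group and cleans up any keys left without a group.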
func (s *KeyImportService) StartUnlinkKeysTask(ctx context.Context, groupID uint, keysText string) (*task.Status, error) {
	keys := utils.ParseKeysFromText(keysText)
	if len(keys) == 0 {
		return nil, fmt.Errorf("no valid keys found")
	}
	resourceID := fmt.Sprintf("group-%d", groupID)
	taskStatus, err := s.taskService.StartTask(ctx, groupID, TaskTypeUnlinkKeysFromGroup, resourceID, len(keys), time.Hour)
	if err != nil {
		return nil, err
	}
	go s.runTaskWithRecovery(context.Background(), taskStatus.ID, resourceID, func() {
		s.runUnlinkKeysTask(context.Background(), taskStatus.ID, resourceID, groupID, keys)
	})
	return taskStatus, nil
}
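
// StartHardDeleteKeysTask launches a global background task that
// permanently removes the given keys from the system.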
func (s *KeyImportService) StartHardDeleteKeysTask(ctx context.Context, keysText string) (*task.Status, error) {
	keys := utils.ParseKeysFromText(keysText)
	if len(keys) == 0 {
		return nil, fmt.Errorf("no valid keys found")
	}
	resourceID := "global_hard_delete"
	taskStatus, err := s.taskService.StartTask(ctx, 0, TaskTypeHardDeleteKeys, resourceID, len(keys), time.Hour)
	if err != nil {
		return nil, err
	}
	go s.runTaskWithRecovery(context.Background(), taskStatus.ID, resourceID, func() {
		s.runHardDeleteKeysTask(context.Background(), taskStatus.ID, resourceID, keys)
	})
	return taskStatus, nil
}
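
// StartRestoreKeysTask launches a global background task that resets the
// master status of the given keys back to active.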
func (s *KeyImportService) StartRestoreKeysTask(ctx context.Context, keysText string) (*task.Status, error) {
	keys := utils.ParseKeysFromText(keysText)
	if len(keys) == 0 {
		return nil, fmt.Errorf("no valid keys found")
	}
	resourceID := "global_restore_keys"

	taskStatus, err := s.taskService.StartTask(ctx, 0, TaskTypeRestoreKeys, resourceID, len(keys), time.Hour)
	if err != nil {
		return nil, err
	}
	go s.runTaskWithRecovery(context.Background(), taskStatus.ID, resourceID, func() {
		s.runRestoreKeysTask(context.Background(), taskStatus.ID, resourceID, keys)
	})
	return taskStatus, nil
}
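
// runAddKeysTask deduplicates the submitted keys, ensures they exist in the
// key table, links the ones not already in the group in chunks, and then
// either queues the new links for validation or activates them directly.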
func (s *KeyImportService) runAddKeysTask(ctx context.Context, taskID string, resourceID string, groupID uint, keys []string, validateOnImport bool) {
	uniqueKeysMap := make(map[string]struct{})
	var uniqueKeyStrings []string
	for _, kStr := range keys {
		if _, exists := uniqueKeysMap[kStr]; !exists {
			uniqueKeysMap[kStr] = struct{}{}
			uniqueKeyStrings = append(uniqueKeyStrings, kStr)
		}
	}
	if len(uniqueKeyStrings) == 0 {
		s.taskService.EndTaskByID(ctx, taskID, resourceID, gin.H{"newly_linked_count": 0, "already_linked_count": 0}, nil)
		return
	}
	keysToEnsure := make([]models.APIKey, len(uniqueKeyStrings))
	for i, keyStr := range uniqueKeyStrings {
		keysToEnsure[i] = models.APIKey{APIKey: keyStr}
	}
	allKeyModels, err := s.keyRepo.AddKeys(keysToEnsure)
	if err != nil {
		s.taskService.EndTaskByID(ctx, taskID, resourceID, nil, fmt.Errorf("failed to ensure keys exist: %w", err))
		return
	}
	alreadyLinkedModels, err := s.keyRepo.GetKeysByValuesAndGroupID(uniqueKeyStrings, groupID)
	if err != nil {
		s.taskService.EndTaskByID(ctx, taskID, resourceID, nil, fmt.Errorf("failed to check for already linked keys: %w", err))
		return
	}
	alreadyLinkedIDSet := make(map[uint]struct{})
	for _, key := range alreadyLinkedModels {
		alreadyLinkedIDSet[key.ID] = struct{}{}
	}
	// Collect the IDs of keys not yet linked to this group; the same ID list
	// drives both the chunked linking below and the follow-up events.
	var idsToLink []uint
	for _, key := range allKeyModels {
		if _, exists := alreadyLinkedIDSet[key.ID]; !exists {
			idsToLink = append(idsToLink, key.ID)
		}
	}

	if err := s.taskService.UpdateTotalByID(ctx, taskID, len(idsToLink)); err != nil {
		s.logger.WithError(err).Warnf("Failed to update total for task %s", taskID)
	}
	for i := 0; i < len(idsToLink); i += chunkSize {
		end := i + chunkSize
		if end > len(idsToLink) {
			end = len(idsToLink)
		}
		chunk := idsToLink[i:end]
		if err := s.keyRepo.LinkKeysToGroup(ctx, groupID, chunk); err != nil {
			s.taskService.EndTaskByID(ctx, taskID, resourceID, nil, fmt.Errorf("chunk failed to link keys: %w", err))
			return
		}
		_ = s.taskService.UpdateProgressByID(ctx, taskID, i+len(chunk))
	}

	result := gin.H{
		"newly_linked_count":   len(idsToLink),
		"already_linked_count": len(alreadyLinkedIDSet),
		"total_linked_count":   len(allKeyModels),
	}
	if len(idsToLink) > 0 {
		if validateOnImport {
			s.publishImportGroupCompletedEvent(ctx, groupID, idsToLink)
			for _, keyID := range idsToLink {
				s.publishSingleKeyChangeEvent(ctx, groupID, keyID, "", models.StatusPendingValidation, "key_linked")
			}
		} else {
			for _, keyID := range idsToLink {
				if _, err := s.apiKeyService.UpdateMappingStatus(ctx, groupID, keyID, models.StatusActive); err != nil {
					s.logger.Errorf("Failed to directly activate key ID %d in group %d: %v", keyID, groupID, err)
				}
			}
		}
	}
	s.taskService.EndTaskByID(ctx, taskID, resourceID, result, nil)
}
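
// runUnlinkKeysTask deduplicates the submitted keys, unlinks the matching
// ones from the group in chunks, publishes a status-change event per key,
// and finally deletes any keys that no longer belong to any group.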
func (s *KeyImportService) runUnlinkKeysTask(ctx context.Context, taskID, resourceID string, groupID uint, keys []string) {
	uniqueKeysMap := make(map[string]struct{})
	var uniqueKeys []string
	for _, kStr := range keys {
		if _, exists := uniqueKeysMap[kStr]; !exists {
			uniqueKeysMap[kStr] = struct{}{}
			uniqueKeys = append(uniqueKeys, kStr)
		}
	}
	keysToUnlink, err := s.keyRepo.GetKeysByValuesAndGroupID(uniqueKeys, groupID)
	if err != nil {
		s.taskService.EndTaskByID(ctx, taskID, resourceID, nil, fmt.Errorf("failed to find keys to unlink: %w", err))
		return
	}

	if len(keysToUnlink) == 0 {
		result := gin.H{"unlinked_count": 0, "hard_deleted_count": 0, "not_found_count": len(uniqueKeys)}
		s.taskService.EndTaskByID(ctx, taskID, resourceID, result, nil)
		return
	}
	idsToUnlink := make([]uint, len(keysToUnlink))
	for i, key := range keysToUnlink {
		idsToUnlink[i] = key.ID
	}

	if err := s.taskService.UpdateTotalByID(ctx, taskID, len(idsToUnlink)); err != nil {
		s.logger.WithError(err).Warnf("Failed to update total for task %s", taskID)
	}
	var totalUnlinked int64
	for i := 0; i < len(idsToUnlink); i += chunkSize {
		end := i + chunkSize
		if end > len(idsToUnlink) {
			end = len(idsToUnlink)
		}
		chunk := idsToUnlink[i:end]
		unlinked, err := s.keyRepo.UnlinkKeysFromGroup(ctx, groupID, chunk)
		if err != nil {
			s.taskService.EndTaskByID(ctx, taskID, resourceID, nil, fmt.Errorf("chunk failed: could not unlink keys: %w", err))
			return
		}
		totalUnlinked += unlinked
		for _, keyID := range chunk {
			s.publishSingleKeyChangeEvent(ctx, groupID, keyID, models.StatusActive, "", "key_unlinked")
		}
		_ = s.taskService.UpdateProgressByID(ctx, taskID, i+len(chunk))
	}

	// Unlinking may leave keys with no remaining group; purge them so they
	// do not linger as orphans.
	totalDeleted, err := s.keyRepo.DeleteOrphanKeys()
	if err != nil {
		s.logger.WithError(err).Warn("Orphan key cleanup failed after unlink task.")
	}
	result := gin.H{
		"unlinked_count":     totalUnlinked,
		"hard_deleted_count": totalDeleted,
		"not_found_count":    len(uniqueKeys) - int(totalUnlinked),
	}
	s.taskService.EndTaskByID(ctx, taskID, resourceID, result, nil)
}
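
// runHardDeleteKeysTask permanently deletes the submitted keys in chunks
// and broadcasts a global change event once the task completes.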
func (s *KeyImportService) runHardDeleteKeysTask(ctx context.Context, taskID, resourceID string, keys []string) {
	var totalDeleted int64
	for i := 0; i < len(keys); i += chunkSize {
		end := i + chunkSize
		if end > len(keys) {
			end = len(keys)
		}
		chunk := keys[i:end]

		deleted, err := s.keyRepo.HardDeleteByValues(chunk)
		if err != nil {
			s.taskService.EndTaskByID(ctx, taskID, resourceID, nil, fmt.Errorf("failed to hard delete chunk: %w", err))
			return
		}
		totalDeleted += deleted
		_ = s.taskService.UpdateProgressByID(ctx, taskID, i+len(chunk))
	}
	result := gin.H{
		"hard_deleted_count": totalDeleted,
		"not_found_count":    int64(len(keys)) - totalDeleted,
	}
	s.taskService.EndTaskByID(ctx, taskID, resourceID, result, nil)
	s.publishChangeEvent(ctx, 0, "keys_hard_deleted")
}
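
// runRestoreKeysTask flips the master status of the submitted keys back to
// active in chunks and broadcasts a global change event when done.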
func (s *KeyImportService) runRestoreKeysTask(ctx context.Context, taskID, resourceID string, keys []string) {
	var restoredCount int64
	for i := 0; i < len(keys); i += chunkSize {
		end := i + chunkSize
		if end > len(keys) {
			end = len(keys)
		}
		chunk := keys[i:end]

		count, err := s.keyRepo.UpdateMasterStatusByValues(chunk, models.MasterStatusActive)
		if err != nil {
			s.taskService.EndTaskByID(ctx, taskID, resourceID, nil, fmt.Errorf("failed to restore chunk: %w", err))
			return
		}
		restoredCount += count
		_ = s.taskService.UpdateProgressByID(ctx, taskID, i+len(chunk))
	}
	result := gin.H{
		"restored_count":  restoredCount,
		"not_found_count": int64(len(keys)) - restoredCount,
	}
	s.taskService.EndTaskByID(ctx, taskID, resourceID, result, nil)
	s.publishChangeEvent(ctx, 0, "keys_bulk_restored")
}
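
// publishSingleKeyChangeEvent emits a KeyStatusChanged event for one key,
// logging (but not propagating) any publish failure.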
func (s *KeyImportService) publishSingleKeyChangeEvent(ctx context.Context, groupID, keyID uint, oldStatus, newStatus models.APIKeyStatus, reason string) {
	event := models.KeyStatusChangedEvent{
		GroupID:      groupID,
		KeyID:        keyID,
		OldStatus:    oldStatus,
		NewStatus:    newStatus,
		ChangeReason: reason,
		ChangedAt:    time.Now(),
	}
	eventData, _ := json.Marshal(event)
	if err := s.store.Publish(ctx, models.TopicKeyStatusChanged, eventData); err != nil {
		s.logger.WithError(err).WithFields(logrus.Fields{
			"group_id": groupID,
			"key_id":   keyID,
			"reason":   reason,
		}).Error("Failed to publish single key change event.")
	}
}
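
// publishChangeEvent emits a coarse-grained KeyStatusChanged event carrying
// only a group ID and a reason, used by bulk operations.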
func (s *KeyImportService) publishChangeEvent(ctx context.Context, groupID uint, reason string) {
	event := models.KeyStatusChangedEvent{
		GroupID:      groupID,
		ChangeReason: reason,
	}
	eventData, _ := json.Marshal(event)
	_ = s.store.Publish(ctx, models.TopicKeyStatusChanged, eventData)
}
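
// publishImportGroupCompletedEvent notifies subscribers that a batch of
// newly linked keys is ready, e.g. for follow-up validation.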
func (s *KeyImportService) publishImportGroupCompletedEvent(ctx context.Context, groupID uint, keyIDs []uint) {
	if len(keyIDs) == 0 {
		return
	}
	event := models.ImportGroupCompletedEvent{
		GroupID:     groupID,
		KeyIDs:      keyIDs,
		CompletedAt: time.Now(),
	}
	eventData, err := json.Marshal(event)
	if err != nil {
		s.logger.WithError(err).Error("Failed to marshal ImportGroupCompletedEvent.")
		return
	}
	if err := s.store.Publish(ctx, models.TopicImportGroupCompleted, eventData); err != nil {
		s.logger.WithError(err).Error("Failed to publish ImportGroupCompletedEvent.")
	} else {
		s.logger.Infof("Published ImportGroupCompletedEvent for group %d with %d keys.", groupID, len(keyIDs))
	}
}
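
// StartUnlinkKeysByFilterTask resolves the key values matching the given
// status filter and delegates to StartUnlinkKeysTask for the actual unlink.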
func (s *KeyImportService) StartUnlinkKeysByFilterTask(ctx context.Context, groupID uint, statuses []string) (*task.Status, error) {
	s.logger.Infof("Starting unlink task for group %d with status filter: %v", groupID, statuses)
	keyValues, err := s.keyRepo.FindKeyValuesByStatus(groupID, statuses)
	if err != nil {
		return nil, fmt.Errorf("failed to find keys by filter: %w", err)
	}
	if len(keyValues) == 0 {
		return nil, fmt.Errorf("no keys found matching the provided filter")
	}
	keysAsText := strings.Join(keyValues, "\n")
	s.logger.Infof("Found %d keys to unlink for group %d.", len(keyValues), groupID)
	return s.StartUnlinkKeysTask(ctx, groupID, keysAsText)
}