397 lines
14 KiB
Go
397 lines
14 KiB
Go
// Filename: internal/service/key_import_service.go
|
||
package service
|
||
|
||
import (
|
||
"encoding/json"
|
||
"fmt"
|
||
"gemini-balancer/internal/models"
|
||
"gemini-balancer/internal/repository"
|
||
"gemini-balancer/internal/store"
|
||
"gemini-balancer/internal/task"
|
||
"gemini-balancer/internal/utils"
|
||
"strings"
|
||
"time"
|
||
|
||
"github.com/gin-gonic/gin"
|
||
"github.com/sirupsen/logrus"
|
||
)
|
||
|
||
// Task type identifiers passed to the task reporter when each bulk operation
// is started.
const (
	TaskTypeAddKeysToGroup      = "add_keys_to_group"
	TaskTypeUnlinkKeysFromGroup = "unlink_keys_from_group"
	TaskTypeHardDeleteKeys      = "hard_delete_keys"
	TaskTypeRestoreKeys         = "restore_keys"
)

// chunkSize is the maximum number of keys (or key IDs) handed to a single
// repository call while a bulk task walks its work set.
const chunkSize = 500
|
||
|
||
type KeyImportService struct {
|
||
taskService task.Reporter
|
||
keyRepo repository.KeyRepository
|
||
store store.Store
|
||
logger *logrus.Entry
|
||
apiKeyService *APIKeyService
|
||
}
|
||
|
||
func NewKeyImportService(ts task.Reporter, kr repository.KeyRepository, s store.Store, as *APIKeyService, logger *logrus.Logger) *KeyImportService {
|
||
return &KeyImportService{
|
||
taskService: ts,
|
||
keyRepo: kr,
|
||
store: s,
|
||
logger: logger.WithField("component", "KeyImportService🚀"),
|
||
apiKeyService: as,
|
||
}
|
||
}
|
||
|
||
// --- 通用的 Panic-Safe 任務執行器 ---
|
||
func (s *KeyImportService) runTaskWithRecovery(taskID string, resourceID string, taskFunc func()) {
|
||
defer func() {
|
||
if r := recover(); r != nil {
|
||
err := fmt.Errorf("panic recovered in task %s: %v", taskID, r)
|
||
s.logger.Error(err)
|
||
s.taskService.EndTaskByID(taskID, resourceID, nil, err)
|
||
}
|
||
}()
|
||
taskFunc()
|
||
}
|
||
|
||
// --- Public Task Starters ---
|
||
|
||
func (s *KeyImportService) StartAddKeysTask(groupID uint, keysText string, validateOnImport bool) (*task.Status, error) {
|
||
keys := utils.ParseKeysFromText(keysText)
|
||
if len(keys) == 0 {
|
||
return nil, fmt.Errorf("no valid keys found in input text")
|
||
}
|
||
resourceID := fmt.Sprintf("group-%d", groupID)
|
||
taskStatus, err := s.taskService.StartTask(uint(groupID), TaskTypeAddKeysToGroup, resourceID, len(keys), 15*time.Minute)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
go s.runTaskWithRecovery(taskStatus.ID, resourceID, func() {
|
||
s.runAddKeysTask(taskStatus.ID, resourceID, groupID, keys, validateOnImport)
|
||
})
|
||
return taskStatus, nil
|
||
}
|
||
|
||
func (s *KeyImportService) StartUnlinkKeysTask(groupID uint, keysText string) (*task.Status, error) {
|
||
keys := utils.ParseKeysFromText(keysText)
|
||
if len(keys) == 0 {
|
||
return nil, fmt.Errorf("no valid keys found")
|
||
}
|
||
resourceID := fmt.Sprintf("group-%d", groupID)
|
||
taskStatus, err := s.taskService.StartTask(uint(groupID), TaskTypeUnlinkKeysFromGroup, resourceID, len(keys), time.Hour)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
go s.runTaskWithRecovery(taskStatus.ID, resourceID, func() {
|
||
s.runUnlinkKeysTask(taskStatus.ID, resourceID, groupID, keys)
|
||
})
|
||
return taskStatus, nil
|
||
}
|
||
|
||
func (s *KeyImportService) StartHardDeleteKeysTask(keysText string) (*task.Status, error) {
|
||
keys := utils.ParseKeysFromText(keysText)
|
||
if len(keys) == 0 {
|
||
return nil, fmt.Errorf("no valid keys found")
|
||
}
|
||
resourceID := "global_hard_delete" // Global lock
|
||
taskStatus, err := s.taskService.StartTask(0, TaskTypeHardDeleteKeys, resourceID, len(keys), time.Hour)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
go s.runTaskWithRecovery(taskStatus.ID, resourceID, func() {
|
||
s.runHardDeleteKeysTask(taskStatus.ID, resourceID, keys)
|
||
})
|
||
return taskStatus, nil
|
||
}
|
||
|
||
func (s *KeyImportService) StartRestoreKeysTask(keysText string) (*task.Status, error) {
|
||
keys := utils.ParseKeysFromText(keysText)
|
||
if len(keys) == 0 {
|
||
return nil, fmt.Errorf("no valid keys found")
|
||
}
|
||
resourceID := "global_restore_keys" // Global lock
|
||
taskStatus, err := s.taskService.StartTask(0, TaskTypeRestoreKeys, resourceID, len(keys), time.Hour)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
go s.runTaskWithRecovery(taskStatus.ID, resourceID, func() {
|
||
s.runRestoreKeysTask(taskStatus.ID, resourceID, keys)
|
||
})
|
||
return taskStatus, nil
|
||
}
|
||
|
||
// --- Private Task Runners ---
|
||
|
||
func (s *KeyImportService) runAddKeysTask(taskID string, resourceID string, groupID uint, keys []string, validateOnImport bool) {
|
||
// 步骤 1: 对输入的原始 key 列表进行去重。
|
||
uniqueKeysMap := make(map[string]struct{})
|
||
var uniqueKeyStrings []string
|
||
for _, kStr := range keys {
|
||
if _, exists := uniqueKeysMap[kStr]; !exists {
|
||
uniqueKeysMap[kStr] = struct{}{}
|
||
uniqueKeyStrings = append(uniqueKeyStrings, kStr)
|
||
}
|
||
}
|
||
if len(uniqueKeyStrings) == 0 {
|
||
s.taskService.EndTaskByID(taskID, resourceID, gin.H{"newly_linked_count": 0, "already_linked_count": 0}, nil)
|
||
return
|
||
}
|
||
// 步骤 2: 确保所有 Key 在主表中存在(创建或恢复),并获取它们完整的实体。
|
||
keysToEnsure := make([]models.APIKey, len(uniqueKeyStrings))
|
||
for i, keyStr := range uniqueKeyStrings {
|
||
keysToEnsure[i] = models.APIKey{APIKey: keyStr}
|
||
}
|
||
allKeyModels, err := s.keyRepo.AddKeys(keysToEnsure)
|
||
if err != nil {
|
||
s.taskService.EndTaskByID(taskID, resourceID, nil, fmt.Errorf("failed to ensure keys exist: %w", err))
|
||
return
|
||
}
|
||
// 步骤 3: 找出在这些 Key 中,哪些【已经】被链接到了当前分组。
|
||
alreadyLinkedModels, err := s.keyRepo.GetKeysByValuesAndGroupID(uniqueKeyStrings, groupID)
|
||
if err != nil {
|
||
s.taskService.EndTaskByID(taskID, resourceID, nil, fmt.Errorf("failed to check for already linked keys: %w", err))
|
||
return
|
||
}
|
||
alreadyLinkedIDSet := make(map[uint]struct{})
|
||
for _, key := range alreadyLinkedModels {
|
||
alreadyLinkedIDSet[key.ID] = struct{}{}
|
||
}
|
||
// 步骤 4: 确定【真正需要】被链接到当前分组的 key 列表 (我们的"工作集")。
|
||
var keysToLink []models.APIKey
|
||
for _, key := range allKeyModels {
|
||
if _, exists := alreadyLinkedIDSet[key.ID]; !exists {
|
||
keysToLink = append(keysToLink, key)
|
||
}
|
||
}
|
||
// 步骤 5: 更新任务的 Total 总量为精确的 "工作集" 大小。
|
||
if err := s.taskService.UpdateTotalByID(taskID, len(keysToLink)); err != nil {
|
||
s.logger.WithError(err).Warnf("Failed to update total for task %s", taskID)
|
||
}
|
||
// 步骤 6: 分块处理【链接Key到组】的操作,并实时更新进度。
|
||
if len(keysToLink) > 0 {
|
||
idsToLink := make([]uint, len(keysToLink))
|
||
for i, key := range keysToLink {
|
||
idsToLink[i] = key.ID
|
||
}
|
||
for i := 0; i < len(idsToLink); i += chunkSize {
|
||
end := i + chunkSize
|
||
if end > len(idsToLink) {
|
||
end = len(idsToLink)
|
||
}
|
||
chunk := idsToLink[i:end]
|
||
if err := s.keyRepo.LinkKeysToGroup(groupID, chunk); err != nil {
|
||
s.taskService.EndTaskByID(taskID, resourceID, nil, fmt.Errorf("chunk failed to link keys: %w", err))
|
||
return
|
||
}
|
||
_ = s.taskService.UpdateProgressByID(taskID, i+len(chunk))
|
||
}
|
||
}
|
||
|
||
// 步骤 7: 准备最终结果并结束任务。
|
||
result := gin.H{
|
||
"newly_linked_count": len(keysToLink),
|
||
"already_linked_count": len(alreadyLinkedIDSet),
|
||
"total_linked_count": len(allKeyModels),
|
||
}
|
||
// 步骤 8: 根据 `validateOnImport` 标志, 发布事件或直接激活 (只对新链接的keys操作)。
|
||
if len(keysToLink) > 0 {
|
||
idsToLink := make([]uint, len(keysToLink))
|
||
for i, key := range keysToLink {
|
||
idsToLink[i] = key.ID
|
||
}
|
||
if validateOnImport {
|
||
s.publishImportGroupCompletedEvent(groupID, idsToLink)
|
||
for _, keyID := range idsToLink {
|
||
s.publishSingleKeyChangeEvent(groupID, keyID, "", models.StatusPendingValidation, "key_linked")
|
||
}
|
||
} else {
|
||
for _, keyID := range idsToLink {
|
||
if _, err := s.apiKeyService.UpdateMappingStatus(groupID, keyID, models.StatusActive); err != nil {
|
||
s.logger.Errorf("Failed to directly activate key ID %d in group %d: %v", keyID, groupID, err)
|
||
}
|
||
}
|
||
}
|
||
}
|
||
s.taskService.EndTaskByID(taskID, resourceID, result, nil)
|
||
}
|
||
|
||
// runUnlinkKeysTask
|
||
func (s *KeyImportService) runUnlinkKeysTask(taskID, resourceID string, groupID uint, keys []string) {
|
||
uniqueKeysMap := make(map[string]struct{})
|
||
var uniqueKeys []string
|
||
for _, kStr := range keys {
|
||
if _, exists := uniqueKeysMap[kStr]; !exists {
|
||
uniqueKeysMap[kStr] = struct{}{}
|
||
uniqueKeys = append(uniqueKeys, kStr)
|
||
}
|
||
}
|
||
// 步骤 1: 一次性找出所有输入 Key 中,实际存在于本组的 Key 实体。这是我们的"工作集"。
|
||
keysToUnlink, err := s.keyRepo.GetKeysByValuesAndGroupID(uniqueKeys, groupID)
|
||
if err != nil {
|
||
s.taskService.EndTaskByID(taskID, resourceID, nil, fmt.Errorf("failed to find keys to unlink: %w", err))
|
||
return
|
||
}
|
||
|
||
if len(keysToUnlink) == 0 {
|
||
result := gin.H{"unlinked_count": 0, "hard_deleted_count": 0, "not_found_count": len(uniqueKeys)}
|
||
s.taskService.EndTaskByID(taskID, resourceID, result, nil)
|
||
return
|
||
}
|
||
idsToUnlink := make([]uint, len(keysToUnlink))
|
||
for i, key := range keysToUnlink {
|
||
idsToUnlink[i] = key.ID
|
||
}
|
||
// 步骤 2: 更新任务的 Total 总量为精确的 "工作集" 大小。
|
||
if err := s.taskService.UpdateTotalByID(taskID, len(idsToUnlink)); err != nil {
|
||
s.logger.WithError(err).Warnf("Failed to update total for task %s", taskID)
|
||
}
|
||
var totalUnlinked int64
|
||
// 步骤 3: 分块处理【解绑Key】的操作,并上报进度。
|
||
for i := 0; i < len(idsToUnlink); i += chunkSize {
|
||
end := i + chunkSize
|
||
if end > len(idsToUnlink) {
|
||
end = len(idsToUnlink)
|
||
}
|
||
chunk := idsToUnlink[i:end]
|
||
unlinked, err := s.keyRepo.UnlinkKeysFromGroup(groupID, chunk)
|
||
if err != nil {
|
||
s.taskService.EndTaskByID(taskID, resourceID, nil, fmt.Errorf("chunk failed: could not unlink keys: %w", err))
|
||
return
|
||
}
|
||
totalUnlinked += unlinked
|
||
|
||
for _, keyID := range chunk {
|
||
s.publishSingleKeyChangeEvent(groupID, keyID, models.StatusActive, "", "key_unlinked")
|
||
}
|
||
|
||
_ = s.taskService.UpdateProgressByID(taskID, i+len(chunk))
|
||
}
|
||
|
||
totalDeleted, err := s.keyRepo.DeleteOrphanKeys()
|
||
if err != nil {
|
||
s.logger.WithError(err).Warn("Orphan key cleanup failed after unlink task.")
|
||
}
|
||
result := gin.H{
|
||
"unlinked_count": totalUnlinked,
|
||
"hard_deleted_count": totalDeleted,
|
||
"not_found_count": len(uniqueKeys) - int(totalUnlinked),
|
||
}
|
||
s.taskService.EndTaskByID(taskID, resourceID, result, nil)
|
||
}
|
||
|
||
func (s *KeyImportService) runHardDeleteKeysTask(taskID, resourceID string, keys []string) {
|
||
var totalDeleted int64
|
||
for i := 0; i < len(keys); i += chunkSize {
|
||
end := i + chunkSize
|
||
if end > len(keys) {
|
||
end = len(keys)
|
||
}
|
||
chunk := keys[i:end]
|
||
|
||
deleted, err := s.keyRepo.HardDeleteByValues(chunk)
|
||
if err != nil {
|
||
s.taskService.EndTaskByID(taskID, resourceID, nil, fmt.Errorf("failed to hard delete chunk: %w", err))
|
||
return
|
||
}
|
||
totalDeleted += deleted
|
||
_ = s.taskService.UpdateProgressByID(taskID, i+len(chunk))
|
||
}
|
||
|
||
result := gin.H{
|
||
"hard_deleted_count": totalDeleted,
|
||
"not_found_count": int64(len(keys)) - totalDeleted,
|
||
}
|
||
s.taskService.EndTaskByID(taskID, resourceID, result, nil)
|
||
s.publishChangeEvent(0, "keys_hard_deleted") // Global event
|
||
}
|
||
|
||
func (s *KeyImportService) runRestoreKeysTask(taskID, resourceID string, keys []string) {
|
||
var restoredCount int64
|
||
for i := 0; i < len(keys); i += chunkSize {
|
||
end := i + chunkSize
|
||
if end > len(keys) {
|
||
end = len(keys)
|
||
}
|
||
chunk := keys[i:end]
|
||
|
||
count, err := s.keyRepo.UpdateMasterStatusByValues(chunk, models.MasterStatusActive)
|
||
if err != nil {
|
||
s.taskService.EndTaskByID(taskID, resourceID, nil, fmt.Errorf("failed to restore chunk: %w", err))
|
||
return
|
||
}
|
||
restoredCount += count
|
||
_ = s.taskService.UpdateProgressByID(taskID, i+len(chunk))
|
||
}
|
||
result := gin.H{
|
||
"restored_count": restoredCount,
|
||
"not_found_count": int64(len(keys)) - restoredCount,
|
||
}
|
||
s.taskService.EndTaskByID(taskID, resourceID, result, nil)
|
||
s.publishChangeEvent(0, "keys_bulk_restored") // Global event
|
||
}
|
||
|
||
func (s *KeyImportService) publishSingleKeyChangeEvent(groupID, keyID uint, oldStatus, newStatus models.APIKeyStatus, reason string) {
|
||
event := models.KeyStatusChangedEvent{
|
||
GroupID: groupID,
|
||
KeyID: keyID,
|
||
OldStatus: oldStatus,
|
||
NewStatus: newStatus,
|
||
ChangeReason: reason,
|
||
ChangedAt: time.Now(),
|
||
}
|
||
eventData, _ := json.Marshal(event)
|
||
if err := s.store.Publish(models.TopicKeyStatusChanged, eventData); err != nil {
|
||
s.logger.WithError(err).WithFields(logrus.Fields{
|
||
"group_id": groupID,
|
||
"key_id": keyID,
|
||
"reason": reason,
|
||
}).Error("Failed to publish single key change event.")
|
||
}
|
||
}
|
||
|
||
func (s *KeyImportService) publishChangeEvent(groupID uint, reason string) {
|
||
event := models.KeyStatusChangedEvent{
|
||
GroupID: groupID,
|
||
ChangeReason: reason,
|
||
}
|
||
eventData, _ := json.Marshal(event)
|
||
_ = s.store.Publish(models.TopicKeyStatusChanged, eventData)
|
||
}
|
||
|
||
func (s *KeyImportService) publishImportGroupCompletedEvent(groupID uint, keyIDs []uint) {
|
||
if len(keyIDs) == 0 {
|
||
return
|
||
}
|
||
event := models.ImportGroupCompletedEvent{
|
||
GroupID: groupID,
|
||
KeyIDs: keyIDs,
|
||
CompletedAt: time.Now(),
|
||
}
|
||
eventData, err := json.Marshal(event)
|
||
if err != nil {
|
||
s.logger.WithError(err).Error("Failed to marshal ImportGroupCompletedEvent.")
|
||
return
|
||
}
|
||
if err := s.store.Publish(models.TopicImportGroupCompleted, eventData); err != nil {
|
||
s.logger.WithError(err).Error("Failed to publish ImportGroupCompletedEvent.")
|
||
} else {
|
||
s.logger.Infof("Published ImportGroupCompletedEvent for group %d with %d keys.", groupID, len(keyIDs))
|
||
}
|
||
}
|
||
|
||
// [NEW] StartUnlinkKeysByFilterTask starts a task to unlink keys matching a status filter.
|
||
func (s *KeyImportService) StartUnlinkKeysByFilterTask(groupID uint, statuses []string) (*task.Status, error) {
|
||
s.logger.Infof("Starting unlink task for group %d with status filter: %v", groupID, statuses)
|
||
// 1. [New] Find the keys to operate on.
|
||
keyValues, err := s.keyRepo.FindKeyValuesByStatus(groupID, statuses)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("failed to find keys by filter: %w", err)
|
||
}
|
||
if len(keyValues) == 0 {
|
||
return nil, fmt.Errorf("no keys found matching the provided filter")
|
||
}
|
||
// 2. [REUSE] Convert to text and call the existing, robust unlink task logic.
|
||
keysAsText := strings.Join(keyValues, "\n")
|
||
s.logger.Infof("Found %d keys to unlink for group %d.", len(keyValues), groupID)
|
||
return s.StartUnlinkKeysTask(groupID, keysAsText)
|
||
} |