Fix Services & Update the middleware && others
This commit is contained in:
@@ -23,6 +23,10 @@ const (
|
||||
TaskTypeHardDeleteKeys = "hard_delete_keys"
|
||||
TaskTypeRestoreKeys = "restore_keys"
|
||||
chunkSize = 500
|
||||
|
||||
// 任务超时时间常量化
|
||||
defaultTaskTimeout = 15 * time.Minute
|
||||
longTaskTimeout = time.Hour
|
||||
)
|
||||
|
||||
type KeyImportService struct {
|
||||
@@ -43,17 +47,19 @@ func NewKeyImportService(ts task.Reporter, kr repository.KeyRepository, s store.
|
||||
}
|
||||
}
|
||||
|
||||
// runTaskWithRecovery 统一的任务恢复包装器
|
||||
func (s *KeyImportService) runTaskWithRecovery(ctx context.Context, taskID string, resourceID string, taskFunc func()) {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
err := fmt.Errorf("panic recovered in task %s: %v", taskID, r)
|
||||
s.logger.Error(err)
|
||||
s.logger.WithField("task_id", taskID).Error(err)
|
||||
s.taskService.EndTaskByID(ctx, taskID, resourceID, nil, err)
|
||||
}
|
||||
}()
|
||||
taskFunc()
|
||||
}
|
||||
|
||||
// StartAddKeysTask 启动批量添加密钥任务
|
||||
func (s *KeyImportService) StartAddKeysTask(ctx context.Context, groupID uint, keysText string, validateOnImport bool) (*task.Status, error) {
|
||||
keys := utils.ParseKeysFromText(keysText)
|
||||
if len(keys) == 0 {
|
||||
@@ -61,260 +67,404 @@ func (s *KeyImportService) StartAddKeysTask(ctx context.Context, groupID uint, k
|
||||
}
|
||||
resourceID := fmt.Sprintf("group-%d", groupID)
|
||||
|
||||
taskStatus, err := s.taskService.StartTask(ctx, uint(groupID), TaskTypeAddKeysToGroup, resourceID, len(keys), 15*time.Minute)
|
||||
taskStatus, err := s.taskService.StartTask(ctx, uint(groupID), TaskTypeAddKeysToGroup, resourceID, len(keys), defaultTaskTimeout)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
go s.runTaskWithRecovery(context.Background(), taskStatus.ID, resourceID, func() {
|
||||
s.runAddKeysTask(context.Background(), taskStatus.ID, resourceID, groupID, keys, validateOnImport)
|
||||
})
|
||||
|
||||
return taskStatus, nil
|
||||
}
|
||||
|
||||
// StartUnlinkKeysTask 启动批量解绑密钥任务
|
||||
func (s *KeyImportService) StartUnlinkKeysTask(ctx context.Context, groupID uint, keysText string) (*task.Status, error) {
|
||||
keys := utils.ParseKeysFromText(keysText)
|
||||
if len(keys) == 0 {
|
||||
return nil, fmt.Errorf("no valid keys found")
|
||||
}
|
||||
|
||||
resourceID := fmt.Sprintf("group-%d", groupID)
|
||||
taskStatus, err := s.taskService.StartTask(ctx, uint(groupID), TaskTypeUnlinkKeysFromGroup, resourceID, len(keys), time.Hour)
|
||||
taskStatus, err := s.taskService.StartTask(ctx, uint(groupID), TaskTypeUnlinkKeysFromGroup, resourceID, len(keys), longTaskTimeout)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
go s.runTaskWithRecovery(context.Background(), taskStatus.ID, resourceID, func() {
|
||||
s.runUnlinkKeysTask(context.Background(), taskStatus.ID, resourceID, groupID, keys)
|
||||
})
|
||||
|
||||
return taskStatus, nil
|
||||
}
|
||||
|
||||
// StartHardDeleteKeysTask 启动硬删除密钥任务
|
||||
func (s *KeyImportService) StartHardDeleteKeysTask(ctx context.Context, keysText string) (*task.Status, error) {
|
||||
keys := utils.ParseKeysFromText(keysText)
|
||||
if len(keys) == 0 {
|
||||
return nil, fmt.Errorf("no valid keys found")
|
||||
}
|
||||
|
||||
resourceID := "global_hard_delete"
|
||||
taskStatus, err := s.taskService.StartTask(ctx, 0, TaskTypeHardDeleteKeys, resourceID, len(keys), time.Hour)
|
||||
taskStatus, err := s.taskService.StartTask(ctx, 0, TaskTypeHardDeleteKeys, resourceID, len(keys), longTaskTimeout)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
go s.runTaskWithRecovery(context.Background(), taskStatus.ID, resourceID, func() {
|
||||
s.runHardDeleteKeysTask(context.Background(), taskStatus.ID, resourceID, keys)
|
||||
})
|
||||
|
||||
return taskStatus, nil
|
||||
}
|
||||
|
||||
// StartRestoreKeysTask 启动恢复密钥任务
|
||||
func (s *KeyImportService) StartRestoreKeysTask(ctx context.Context, keysText string) (*task.Status, error) {
|
||||
keys := utils.ParseKeysFromText(keysText)
|
||||
if len(keys) == 0 {
|
||||
return nil, fmt.Errorf("no valid keys found")
|
||||
}
|
||||
resourceID := "global_restore_keys"
|
||||
|
||||
taskStatus, err := s.taskService.StartTask(ctx, 0, TaskTypeRestoreKeys, resourceID, len(keys), time.Hour)
|
||||
resourceID := "global_restore_keys"
|
||||
taskStatus, err := s.taskService.StartTask(ctx, 0, TaskTypeRestoreKeys, resourceID, len(keys), longTaskTimeout)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
go s.runTaskWithRecovery(context.Background(), taskStatus.ID, resourceID, func() {
|
||||
s.runRestoreKeysTask(context.Background(), taskStatus.ID, resourceID, keys)
|
||||
})
|
||||
|
||||
return taskStatus, nil
|
||||
}
|
||||
|
||||
func (s *KeyImportService) runAddKeysTask(ctx context.Context, taskID string, resourceID string, groupID uint, keys []string, validateOnImport bool) {
|
||||
uniqueKeysMap := make(map[string]struct{})
|
||||
var uniqueKeyStrings []string
|
||||
for _, kStr := range keys {
|
||||
if _, exists := uniqueKeysMap[kStr]; !exists {
|
||||
uniqueKeysMap[kStr] = struct{}{}
|
||||
uniqueKeyStrings = append(uniqueKeyStrings, kStr)
|
||||
}
|
||||
}
|
||||
if len(uniqueKeyStrings) == 0 {
|
||||
s.taskService.EndTaskByID(ctx, taskID, resourceID, gin.H{"newly_linked_count": 0, "already_linked_count": 0}, nil)
|
||||
return
|
||||
}
|
||||
keysToEnsure := make([]models.APIKey, len(uniqueKeyStrings))
|
||||
for i, keyStr := range uniqueKeyStrings {
|
||||
keysToEnsure[i] = models.APIKey{APIKey: keyStr}
|
||||
}
|
||||
allKeyModels, err := s.keyRepo.AddKeys(keysToEnsure)
|
||||
// StartUnlinkKeysByFilterTask 根据状态过滤条件批量解绑
|
||||
func (s *KeyImportService) StartUnlinkKeysByFilterTask(ctx context.Context, groupID uint, statuses []string) (*task.Status, error) {
|
||||
s.logger.Infof("Starting unlink task for group %d with status filter: %v", groupID, statuses)
|
||||
|
||||
keyValues, err := s.keyRepo.FindKeyValuesByStatus(groupID, statuses)
|
||||
if err != nil {
|
||||
s.taskService.EndTaskByID(ctx, taskID, resourceID, nil, fmt.Errorf("failed to ensure keys exist: %w", err))
|
||||
return
|
||||
return nil, fmt.Errorf("failed to find keys by filter: %w", err)
|
||||
}
|
||||
alreadyLinkedModels, err := s.keyRepo.GetKeysByValuesAndGroupID(uniqueKeyStrings, groupID)
|
||||
if err != nil {
|
||||
s.taskService.EndTaskByID(ctx, taskID, resourceID, nil, fmt.Errorf("failed to check for already linked keys: %w", err))
|
||||
return
|
||||
}
|
||||
alreadyLinkedIDSet := make(map[uint]struct{})
|
||||
for _, key := range alreadyLinkedModels {
|
||||
alreadyLinkedIDSet[key.ID] = struct{}{}
|
||||
}
|
||||
var keysToLink []models.APIKey
|
||||
for _, key := range allKeyModels {
|
||||
if _, exists := alreadyLinkedIDSet[key.ID]; !exists {
|
||||
keysToLink = append(keysToLink, key)
|
||||
}
|
||||
if len(keyValues) == 0 {
|
||||
return nil, fmt.Errorf("no keys found matching the provided filter")
|
||||
}
|
||||
|
||||
keysAsText := strings.Join(keyValues, "\n")
|
||||
s.logger.Infof("Found %d keys to unlink for group %d.", len(keyValues), groupID)
|
||||
|
||||
return s.StartUnlinkKeysTask(ctx, groupID, keysAsText)
|
||||
}
|
||||
|
||||
// ==================== 核心任务执行逻辑 ====================
|
||||
|
||||
// runAddKeysTask 执行批量添加密钥
|
||||
func (s *KeyImportService) runAddKeysTask(ctx context.Context, taskID string, resourceID string, groupID uint, keys []string, validateOnImport bool) {
|
||||
// 1. 去重
|
||||
uniqueKeys := s.deduplicateKeys(keys)
|
||||
if len(uniqueKeys) == 0 {
|
||||
s.endTaskWithResult(ctx, taskID, resourceID, gin.H{
|
||||
"newly_linked_count": 0,
|
||||
"already_linked_count": 0,
|
||||
}, nil)
|
||||
return
|
||||
}
|
||||
|
||||
// 2. 确保所有密钥在数据库中存在(幂等操作)
|
||||
allKeyModels, err := s.ensureKeysExist(uniqueKeys)
|
||||
if err != nil {
|
||||
s.endTaskWithResult(ctx, taskID, resourceID, nil, fmt.Errorf("failed to ensure keys exist: %w", err))
|
||||
return
|
||||
}
|
||||
|
||||
// 3. 过滤已关联的密钥
|
||||
keysToLink, alreadyLinkedCount, err := s.filterNewKeys(allKeyModels, groupID, uniqueKeys)
|
||||
if err != nil {
|
||||
s.endTaskWithResult(ctx, taskID, resourceID, nil, fmt.Errorf("failed to check linked keys: %w", err))
|
||||
return
|
||||
}
|
||||
|
||||
// 4. 更新任务的实际处理总数
|
||||
if err := s.taskService.UpdateTotalByID(ctx, taskID, len(keysToLink)); err != nil {
|
||||
s.logger.WithError(err).Warnf("Failed to update total for task %s", taskID)
|
||||
}
|
||||
|
||||
// 5. 批量关联密钥到组
|
||||
if len(keysToLink) > 0 {
|
||||
idsToLink := make([]uint, len(keysToLink))
|
||||
for i, key := range keysToLink {
|
||||
idsToLink[i] = key.ID
|
||||
if err := s.linkKeysInChunks(ctx, taskID, groupID, keysToLink); err != nil {
|
||||
s.endTaskWithResult(ctx, taskID, resourceID, nil, err)
|
||||
return
|
||||
}
|
||||
for i := 0; i < len(idsToLink); i += chunkSize {
|
||||
end := i + chunkSize
|
||||
if end > len(idsToLink) {
|
||||
end = len(idsToLink)
|
||||
}
|
||||
chunk := idsToLink[i:end]
|
||||
if err := s.keyRepo.LinkKeysToGroup(ctx, groupID, chunk); err != nil {
|
||||
s.taskService.EndTaskByID(ctx, taskID, resourceID, nil, fmt.Errorf("chunk failed to link keys: %w", err))
|
||||
return
|
||||
}
|
||||
_ = s.taskService.UpdateProgressByID(ctx, taskID, i+len(chunk))
|
||||
}
|
||||
|
||||
// 6. 根据验证标志处理密钥状态
|
||||
if len(keysToLink) > 0 {
|
||||
s.processNewlyLinkedKeys(ctx, groupID, keysToLink, validateOnImport)
|
||||
}
|
||||
|
||||
// 7. 返回结果
|
||||
result := gin.H{
|
||||
"newly_linked_count": len(keysToLink),
|
||||
"already_linked_count": alreadyLinkedCount,
|
||||
"total_linked_count": len(allKeyModels),
|
||||
}
|
||||
s.endTaskWithResult(ctx, taskID, resourceID, result, nil)
|
||||
}
|
||||
|
||||
// runUnlinkKeysTask 执行批量解绑密钥
|
||||
func (s *KeyImportService) runUnlinkKeysTask(ctx context.Context, taskID, resourceID string, groupID uint, keys []string) {
|
||||
// 1. 去重
|
||||
uniqueKeys := s.deduplicateKeys(keys)
|
||||
|
||||
// 2. 查找需要解绑的密钥
|
||||
keysToUnlink, err := s.keyRepo.GetKeysByValuesAndGroupID(uniqueKeys, groupID)
|
||||
if err != nil {
|
||||
s.endTaskWithResult(ctx, taskID, resourceID, nil, fmt.Errorf("failed to find keys to unlink: %w", err))
|
||||
return
|
||||
}
|
||||
|
||||
if len(keysToUnlink) == 0 {
|
||||
result := gin.H{
|
||||
"unlinked_count": 0,
|
||||
"hard_deleted_count": 0,
|
||||
"not_found_count": len(uniqueKeys),
|
||||
}
|
||||
s.endTaskWithResult(ctx, taskID, resourceID, result, nil)
|
||||
return
|
||||
}
|
||||
|
||||
// 3. 提取密钥 ID
|
||||
idsToUnlink := s.extractKeyIDs(keysToUnlink)
|
||||
|
||||
// 4. 更新任务总数
|
||||
if err := s.taskService.UpdateTotalByID(ctx, taskID, len(idsToUnlink)); err != nil {
|
||||
s.logger.WithError(err).Warnf("Failed to update total for task %s", taskID)
|
||||
}
|
||||
|
||||
// 5. 批量解绑
|
||||
totalUnlinked, err := s.unlinkKeysInChunks(ctx, taskID, groupID, idsToUnlink)
|
||||
if err != nil {
|
||||
s.endTaskWithResult(ctx, taskID, resourceID, nil, err)
|
||||
return
|
||||
}
|
||||
|
||||
// 6. 清理孤立密钥
|
||||
totalDeleted, err := s.keyRepo.DeleteOrphanKeys()
|
||||
if err != nil {
|
||||
s.logger.WithError(err).Warn("Orphan key cleanup failed after unlink task.")
|
||||
}
|
||||
|
||||
// 7. 返回结果
|
||||
result := gin.H{
|
||||
"unlinked_count": totalUnlinked,
|
||||
"hard_deleted_count": totalDeleted,
|
||||
"not_found_count": len(uniqueKeys) - int(totalUnlinked),
|
||||
}
|
||||
s.endTaskWithResult(ctx, taskID, resourceID, result, nil)
|
||||
}
|
||||
|
||||
// runHardDeleteKeysTask 执行硬删除密钥
|
||||
func (s *KeyImportService) runHardDeleteKeysTask(ctx context.Context, taskID, resourceID string, keys []string) {
|
||||
totalDeleted, err := s.processKeysInChunks(ctx, taskID, keys, func(chunk []string) (int64, error) {
|
||||
return s.keyRepo.HardDeleteByValues(chunk)
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
s.endTaskWithResult(ctx, taskID, resourceID, nil, err)
|
||||
return
|
||||
}
|
||||
|
||||
result := gin.H{
|
||||
"newly_linked_count": len(keysToLink),
|
||||
"already_linked_count": len(alreadyLinkedIDSet),
|
||||
"total_linked_count": len(allKeyModels),
|
||||
"hard_deleted_count": totalDeleted,
|
||||
"not_found_count": int64(len(keys)) - totalDeleted,
|
||||
}
|
||||
if len(keysToLink) > 0 {
|
||||
idsToLink := make([]uint, len(keysToLink))
|
||||
for i, key := range keysToLink {
|
||||
idsToLink[i] = key.ID
|
||||
}
|
||||
if validateOnImport {
|
||||
s.publishImportGroupCompletedEvent(ctx, groupID, idsToLink)
|
||||
for _, keyID := range idsToLink {
|
||||
s.publishSingleKeyChangeEvent(ctx, groupID, keyID, "", models.StatusPendingValidation, "key_linked")
|
||||
}
|
||||
} else {
|
||||
for _, keyID := range idsToLink {
|
||||
if _, err := s.apiKeyService.UpdateMappingStatus(ctx, groupID, keyID, models.StatusActive); err != nil {
|
||||
s.logger.Errorf("Failed to directly activate key ID %d in group %d: %v", keyID, groupID, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
s.taskService.EndTaskByID(ctx, taskID, resourceID, result, nil)
|
||||
s.endTaskWithResult(ctx, taskID, resourceID, result, nil)
|
||||
s.publishChangeEvent(ctx, 0, "keys_hard_deleted")
|
||||
}
|
||||
|
||||
func (s *KeyImportService) runUnlinkKeysTask(ctx context.Context, taskID, resourceID string, groupID uint, keys []string) {
|
||||
uniqueKeysMap := make(map[string]struct{})
|
||||
var uniqueKeys []string
|
||||
// runRestoreKeysTask 执行恢复密钥
|
||||
func (s *KeyImportService) runRestoreKeysTask(ctx context.Context, taskID, resourceID string, keys []string) {
|
||||
restoredCount, err := s.processKeysInChunks(ctx, taskID, keys, func(chunk []string) (int64, error) {
|
||||
return s.keyRepo.UpdateMasterStatusByValues(chunk, models.MasterStatusActive)
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
s.endTaskWithResult(ctx, taskID, resourceID, nil, err)
|
||||
return
|
||||
}
|
||||
|
||||
result := gin.H{
|
||||
"restored_count": restoredCount,
|
||||
"not_found_count": int64(len(keys)) - restoredCount,
|
||||
}
|
||||
s.endTaskWithResult(ctx, taskID, resourceID, result, nil)
|
||||
s.publishChangeEvent(ctx, 0, "keys_bulk_restored")
|
||||
}
|
||||
|
||||
// ==================== 辅助方法 ====================
|
||||
|
||||
// deduplicateKeys 去重密钥列表
|
||||
func (s *KeyImportService) deduplicateKeys(keys []string) []string {
|
||||
uniqueKeysMap := make(map[string]struct{}, len(keys))
|
||||
uniqueKeys := make([]string, 0, len(keys))
|
||||
|
||||
for _, kStr := range keys {
|
||||
if _, exists := uniqueKeysMap[kStr]; !exists {
|
||||
uniqueKeysMap[kStr] = struct{}{}
|
||||
uniqueKeys = append(uniqueKeys, kStr)
|
||||
}
|
||||
}
|
||||
keysToUnlink, err := s.keyRepo.GetKeysByValuesAndGroupID(uniqueKeys, groupID)
|
||||
return uniqueKeys
|
||||
}
|
||||
|
||||
// ensureKeysExist 确保所有密钥在数据库中存在
|
||||
func (s *KeyImportService) ensureKeysExist(keys []string) ([]models.APIKey, error) {
|
||||
keysToEnsure := make([]models.APIKey, len(keys))
|
||||
for i, keyStr := range keys {
|
||||
keysToEnsure[i] = models.APIKey{APIKey: keyStr}
|
||||
}
|
||||
return s.keyRepo.AddKeys(keysToEnsure)
|
||||
}
|
||||
|
||||
// filterNewKeys 过滤已关联的密钥,返回需要新增的密钥
|
||||
func (s *KeyImportService) filterNewKeys(allKeyModels []models.APIKey, groupID uint, uniqueKeys []string) ([]models.APIKey, int, error) {
|
||||
alreadyLinkedModels, err := s.keyRepo.GetKeysByValuesAndGroupID(uniqueKeys, groupID)
|
||||
if err != nil {
|
||||
s.taskService.EndTaskByID(ctx, taskID, resourceID, nil, fmt.Errorf("failed to find keys to unlink: %w", err))
|
||||
return
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
if len(keysToUnlink) == 0 {
|
||||
result := gin.H{"unlinked_count": 0, "hard_deleted_count": 0, "not_found_count": len(uniqueKeys)}
|
||||
s.taskService.EndTaskByID(ctx, taskID, resourceID, result, nil)
|
||||
return
|
||||
}
|
||||
idsToUnlink := make([]uint, len(keysToUnlink))
|
||||
for i, key := range keysToUnlink {
|
||||
idsToUnlink[i] = key.ID
|
||||
alreadyLinkedIDSet := make(map[uint]struct{}, len(alreadyLinkedModels))
|
||||
for _, key := range alreadyLinkedModels {
|
||||
alreadyLinkedIDSet[key.ID] = struct{}{}
|
||||
}
|
||||
|
||||
if err := s.taskService.UpdateTotalByID(ctx, taskID, len(idsToUnlink)); err != nil {
|
||||
s.logger.WithError(err).Warnf("Failed to update total for task %s", taskID)
|
||||
}
|
||||
var totalUnlinked int64
|
||||
for i := 0; i < len(idsToUnlink); i += chunkSize {
|
||||
end := i + chunkSize
|
||||
if end > len(idsToUnlink) {
|
||||
end = len(idsToUnlink)
|
||||
keysToLink := make([]models.APIKey, 0, len(allKeyModels)-len(alreadyLinkedIDSet))
|
||||
for _, key := range allKeyModels {
|
||||
if _, exists := alreadyLinkedIDSet[key.ID]; !exists {
|
||||
keysToLink = append(keysToLink, key)
|
||||
}
|
||||
}
|
||||
|
||||
return keysToLink, len(alreadyLinkedIDSet), nil
|
||||
}
|
||||
|
||||
// extractKeyIDs 提取密钥 ID 列表
|
||||
func (s *KeyImportService) extractKeyIDs(keys []models.APIKey) []uint {
|
||||
ids := make([]uint, len(keys))
|
||||
for i, key := range keys {
|
||||
ids[i] = key.ID
|
||||
}
|
||||
return ids
|
||||
}
|
||||
|
||||
// linkKeysInChunks 分块关联密钥到组
|
||||
func (s *KeyImportService) linkKeysInChunks(ctx context.Context, taskID string, groupID uint, keysToLink []models.APIKey) error {
|
||||
idsToLink := s.extractKeyIDs(keysToLink)
|
||||
|
||||
for i := 0; i < len(idsToLink); i += chunkSize {
|
||||
end := min(i+chunkSize, len(idsToLink))
|
||||
chunk := idsToLink[i:end]
|
||||
|
||||
if err := s.keyRepo.LinkKeysToGroup(ctx, groupID, chunk); err != nil {
|
||||
return fmt.Errorf("chunk failed to link keys: %w", err)
|
||||
}
|
||||
|
||||
_ = s.taskService.UpdateProgressByID(ctx, taskID, end)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// unlinkKeysInChunks 分块解绑密钥
|
||||
func (s *KeyImportService) unlinkKeysInChunks(ctx context.Context, taskID string, groupID uint, idsToUnlink []uint) (int64, error) {
|
||||
var totalUnlinked int64
|
||||
|
||||
for i := 0; i < len(idsToUnlink); i += chunkSize {
|
||||
end := min(i+chunkSize, len(idsToUnlink))
|
||||
chunk := idsToUnlink[i:end]
|
||||
|
||||
unlinked, err := s.keyRepo.UnlinkKeysFromGroup(ctx, groupID, chunk)
|
||||
if err != nil {
|
||||
s.taskService.EndTaskByID(ctx, taskID, resourceID, nil, fmt.Errorf("chunk failed: could not unlink keys: %w", err))
|
||||
return
|
||||
return 0, fmt.Errorf("chunk failed: could not unlink keys: %w", err)
|
||||
}
|
||||
|
||||
totalUnlinked += unlinked
|
||||
|
||||
// 发布解绑事件
|
||||
for _, keyID := range chunk {
|
||||
s.publishSingleKeyChangeEvent(ctx, groupID, keyID, models.StatusActive, "", "key_unlinked")
|
||||
}
|
||||
_ = s.taskService.UpdateProgressByID(ctx, taskID, i+len(chunk))
|
||||
|
||||
_ = s.taskService.UpdateProgressByID(ctx, taskID, end)
|
||||
}
|
||||
|
||||
totalDeleted, err := s.keyRepo.DeleteOrphanKeys()
|
||||
return totalUnlinked, nil
|
||||
}
|
||||
|
||||
// processKeysInChunks 通用的分块处理密钥逻辑
|
||||
func (s *KeyImportService) processKeysInChunks(
|
||||
ctx context.Context,
|
||||
taskID string,
|
||||
keys []string,
|
||||
processFunc func(chunk []string) (int64, error),
|
||||
) (int64, error) {
|
||||
var totalProcessed int64
|
||||
|
||||
for i := 0; i < len(keys); i += chunkSize {
|
||||
end := min(i+chunkSize, len(keys))
|
||||
chunk := keys[i:end]
|
||||
|
||||
count, err := processFunc(chunk)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("failed to process chunk: %w", err)
|
||||
}
|
||||
|
||||
totalProcessed += count
|
||||
_ = s.taskService.UpdateProgressByID(ctx, taskID, end)
|
||||
}
|
||||
|
||||
return totalProcessed, nil
|
||||
}
|
||||
|
||||
// processNewlyLinkedKeys 处理新关联的密钥(验证或直接激活)
|
||||
func (s *KeyImportService) processNewlyLinkedKeys(ctx context.Context, groupID uint, keysToLink []models.APIKey, validateOnImport bool) {
|
||||
idsToLink := s.extractKeyIDs(keysToLink)
|
||||
|
||||
if validateOnImport {
|
||||
// 发布批量导入完成事件,触发验证
|
||||
s.publishImportGroupCompletedEvent(ctx, groupID, idsToLink)
|
||||
|
||||
// 发布单个密钥状态变更事件
|
||||
for _, keyID := range idsToLink {
|
||||
s.publishSingleKeyChangeEvent(ctx, groupID, keyID, "", models.StatusPendingValidation, "key_linked")
|
||||
}
|
||||
} else {
|
||||
// 直接激活密钥,不进行验证
|
||||
for _, keyID := range idsToLink {
|
||||
if _, err := s.apiKeyService.UpdateMappingStatus(ctx, groupID, keyID, models.StatusActive); err != nil {
|
||||
s.logger.WithFields(logrus.Fields{
|
||||
"group_id": groupID,
|
||||
"key_id": keyID,
|
||||
}).Errorf("Failed to directly activate key: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// endTaskWithResult 统一的任务结束处理
|
||||
func (s *KeyImportService) endTaskWithResult(ctx context.Context, taskID, resourceID string, result gin.H, err error) {
|
||||
if err != nil {
|
||||
s.logger.WithError(err).Warn("Orphan key cleanup failed after unlink task.")
|
||||
s.logger.WithFields(logrus.Fields{
|
||||
"task_id": taskID,
|
||||
"resource_id": resourceID,
|
||||
}).WithError(err).Error("Task failed")
|
||||
}
|
||||
result := gin.H{
|
||||
"unlinked_count": totalUnlinked,
|
||||
"hard_deleted_count": totalDeleted,
|
||||
"not_found_count": len(uniqueKeys) - int(totalUnlinked),
|
||||
}
|
||||
s.taskService.EndTaskByID(ctx, taskID, resourceID, result, nil)
|
||||
s.taskService.EndTaskByID(ctx, taskID, resourceID, result, err)
|
||||
}
|
||||
|
||||
func (s *KeyImportService) runHardDeleteKeysTask(ctx context.Context, taskID, resourceID string, keys []string) {
|
||||
var totalDeleted int64
|
||||
for i := 0; i < len(keys); i += chunkSize {
|
||||
end := i + chunkSize
|
||||
if end > len(keys) {
|
||||
end = len(keys)
|
||||
}
|
||||
chunk := keys[i:end]
|
||||
|
||||
deleted, err := s.keyRepo.HardDeleteByValues(chunk)
|
||||
if err != nil {
|
||||
s.taskService.EndTaskByID(ctx, taskID, resourceID, nil, fmt.Errorf("failed to hard delete chunk: %w", err))
|
||||
return
|
||||
}
|
||||
totalDeleted += deleted
|
||||
_ = s.taskService.UpdateProgressByID(ctx, taskID, i+len(chunk))
|
||||
}
|
||||
result := gin.H{
|
||||
"hard_deleted_count": totalDeleted,
|
||||
"not_found_count": int64(len(keys)) - totalDeleted,
|
||||
}
|
||||
s.taskService.EndTaskByID(ctx, taskID, resourceID, result, nil)
|
||||
s.publishChangeEvent(ctx, 0, "keys_hard_deleted")
|
||||
}
|
||||
|
||||
func (s *KeyImportService) runRestoreKeysTask(ctx context.Context, taskID, resourceID string, keys []string) {
|
||||
var restoredCount int64
|
||||
for i := 0; i < len(keys); i += chunkSize {
|
||||
end := i + chunkSize
|
||||
if end > len(keys) {
|
||||
end = len(keys)
|
||||
}
|
||||
chunk := keys[i:end]
|
||||
|
||||
count, err := s.keyRepo.UpdateMasterStatusByValues(chunk, models.MasterStatusActive)
|
||||
if err != nil {
|
||||
s.taskService.EndTaskByID(ctx, taskID, resourceID, nil, fmt.Errorf("failed to restore chunk: %w", err))
|
||||
return
|
||||
}
|
||||
restoredCount += count
|
||||
_ = s.taskService.UpdateProgressByID(ctx, taskID, i+len(chunk))
|
||||
}
|
||||
result := gin.H{
|
||||
"restored_count": restoredCount,
|
||||
"not_found_count": int64(len(keys)) - restoredCount,
|
||||
}
|
||||
s.taskService.EndTaskByID(ctx, taskID, resourceID, result, nil)
|
||||
s.publishChangeEvent(ctx, 0, "keys_bulk_restored")
|
||||
}
|
||||
// ==================== 事件发布方法 ====================
|
||||
|
||||
// publishSingleKeyChangeEvent 发布单个密钥状态变更事件
|
||||
func (s *KeyImportService) publishSingleKeyChangeEvent(ctx context.Context, groupID, keyID uint, oldStatus, newStatus models.APIKeyStatus, reason string) {
|
||||
event := models.KeyStatusChangedEvent{
|
||||
GroupID: groupID,
|
||||
@@ -324,56 +474,88 @@ func (s *KeyImportService) publishSingleKeyChangeEvent(ctx context.Context, grou
|
||||
ChangeReason: reason,
|
||||
ChangedAt: time.Now(),
|
||||
}
|
||||
eventData, _ := json.Marshal(event)
|
||||
if err := s.store.Publish(ctx, models.TopicKeyStatusChanged, eventData); err != nil {
|
||||
s.logger.WithError(err).WithFields(logrus.Fields{
|
||||
|
||||
eventData, err := json.Marshal(event)
|
||||
if err != nil {
|
||||
s.logger.WithFields(logrus.Fields{
|
||||
"group_id": groupID,
|
||||
"key_id": keyID,
|
||||
"reason": reason,
|
||||
}).Error("Failed to publish single key change event.")
|
||||
}).WithError(err).Error("Failed to marshal key change event")
|
||||
return
|
||||
}
|
||||
|
||||
if err := s.store.Publish(ctx, models.TopicKeyStatusChanged, eventData); err != nil {
|
||||
s.logger.WithFields(logrus.Fields{
|
||||
"group_id": groupID,
|
||||
"key_id": keyID,
|
||||
"reason": reason,
|
||||
}).WithError(err).Error("Failed to publish single key change event")
|
||||
}
|
||||
}
|
||||
|
||||
// publishChangeEvent 发布通用变更事件
|
||||
func (s *KeyImportService) publishChangeEvent(ctx context.Context, groupID uint, reason string) {
|
||||
event := models.KeyStatusChangedEvent{
|
||||
GroupID: groupID,
|
||||
ChangeReason: reason,
|
||||
}
|
||||
eventData, _ := json.Marshal(event)
|
||||
_ = s.store.Publish(ctx, models.TopicKeyStatusChanged, eventData)
|
||||
|
||||
eventData, err := json.Marshal(event)
|
||||
if err != nil {
|
||||
s.logger.WithFields(logrus.Fields{
|
||||
"group_id": groupID,
|
||||
"reason": reason,
|
||||
}).WithError(err).Error("Failed to marshal change event")
|
||||
return
|
||||
}
|
||||
|
||||
if err := s.store.Publish(ctx, models.TopicKeyStatusChanged, eventData); err != nil {
|
||||
s.logger.WithFields(logrus.Fields{
|
||||
"group_id": groupID,
|
||||
"reason": reason,
|
||||
}).WithError(err).Error("Failed to publish change event")
|
||||
}
|
||||
}
|
||||
|
||||
// publishImportGroupCompletedEvent 发布批量导入完成事件
|
||||
func (s *KeyImportService) publishImportGroupCompletedEvent(ctx context.Context, groupID uint, keyIDs []uint) {
|
||||
if len(keyIDs) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
event := models.ImportGroupCompletedEvent{
|
||||
GroupID: groupID,
|
||||
KeyIDs: keyIDs,
|
||||
CompletedAt: time.Now(),
|
||||
}
|
||||
|
||||
eventData, err := json.Marshal(event)
|
||||
if err != nil {
|
||||
s.logger.WithError(err).Error("Failed to marshal ImportGroupCompletedEvent.")
|
||||
s.logger.WithFields(logrus.Fields{
|
||||
"group_id": groupID,
|
||||
"key_count": len(keyIDs),
|
||||
}).WithError(err).Error("Failed to marshal ImportGroupCompletedEvent")
|
||||
return
|
||||
}
|
||||
|
||||
if err := s.store.Publish(ctx, models.TopicImportGroupCompleted, eventData); err != nil {
|
||||
s.logger.WithError(err).Error("Failed to publish ImportGroupCompletedEvent.")
|
||||
s.logger.WithFields(logrus.Fields{
|
||||
"group_id": groupID,
|
||||
"key_count": len(keyIDs),
|
||||
}).WithError(err).Error("Failed to publish ImportGroupCompletedEvent")
|
||||
} else {
|
||||
s.logger.Infof("Published ImportGroupCompletedEvent for group %d with %d keys.", groupID, len(keyIDs))
|
||||
s.logger.WithFields(logrus.Fields{
|
||||
"group_id": groupID,
|
||||
"key_count": len(keyIDs),
|
||||
}).Info("Published ImportGroupCompletedEvent")
|
||||
}
|
||||
}
|
||||
|
||||
func (s *KeyImportService) StartUnlinkKeysByFilterTask(ctx context.Context, groupID uint, statuses []string) (*task.Status, error) {
|
||||
s.logger.Infof("Starting unlink task for group %d with status filter: %v", groupID, statuses)
|
||||
keyValues, err := s.keyRepo.FindKeyValuesByStatus(groupID, statuses)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to find keys by filter: %w", err)
|
||||
// min 返回两个整数中的较小值
|
||||
func min(a, b int) int {
|
||||
if a < b {
|
||||
return a
|
||||
}
|
||||
if len(keyValues) == 0 {
|
||||
return nil, fmt.Errorf("no keys found matching the provided filter")
|
||||
}
|
||||
keysAsText := strings.Join(keyValues, "\n")
|
||||
s.logger.Infof("Found %d keys to unlink for group %d.", len(keyValues), groupID)
|
||||
return s.StartUnlinkKeysTask(ctx, groupID, keysAsText)
|
||||
return b
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user