// Filename: internal/service/apikey_service.go

package service

import (
	"encoding/json"
	"errors"
	"fmt"
	"gemini-balancer/internal/channel"
	CustomErrors "gemini-balancer/internal/errors"
	"gemini-balancer/internal/models"
	"gemini-balancer/internal/repository"
	"gemini-balancer/internal/settings"
	"gemini-balancer/internal/store"
	"gemini-balancer/internal/task"
	"math"
	"sort"
	"strings"
	"sync"
	"time"

	"github.com/sirupsen/logrus"
	"gorm.io/gorm"
)

const (
	TaskTypeRestoreAllBannedInGroup = "restore_all_banned_in_group"
	TaskTypeRestoreSpecificKeys     = "restore_specific_keys_in_group"
	TaskTypeUpdateStatusByFilter    = "update_status_by_filter"
)

// DTOs & Constants
const (
	TestEndpoint = "https://generativelanguage.googleapis.com/v1beta/models"
)

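// BatchRestoreResult summarizes the outcome of a batch key-restore task.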
type BatchRestoreResult struct {
	RestoredCount int              `json:"restored_count"`
	SkippedCount  int              `json:"skipped_count"`
	SkippedKeys   []SkippedKeyInfo `json:"skipped_keys"`
}

type SkippedKeyInfo struct {
	KeyID  uint   `json:"key_id"`
	Reason string `json:"reason"`
}

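// PaginatedAPIKeys is the page envelope returned by ListAPIKeys.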
type PaginatedAPIKeys struct {
	Items      []*models.APIKeyDetails `json:"items"`
	Total      int64                   `json:"total"`
	Page       int                     `json:"page"`
	PageSize   int                     `json:"page_size"`
	TotalPages int                     `json:"total_pages"`
}

type KeyTestResult struct {
	Key     string `json:"key"`
	Status  string `json:"status"`
	Message string `json:"message"`
}

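// APIKeyService manages the lifecycle of API keys: it reacts to request,
// health-check and import events, keeps group/key mapping statuses and the
// polling caches in sync, and runs batch restore and status-update tasks.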
type APIKeyService struct {
	db                *gorm.DB
	keyRepo           repository.KeyRepository
	channel           channel.ChannelProxy
	store             store.Store
	SettingsManager   *settings.SettingsManager
	taskService       task.Reporter
	logger            *logrus.Entry
	stopChan          chan struct{}
	validationService *KeyValidationService
	groupManager      *GroupManager
}

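// NewAPIKeyService wires an APIKeyService with its repository, event store,
// settings manager, task reporter, validation service and group manager, and
// scopes all log output to the APIKeyService component.
//
// Illustrative wiring only (the real dependencies are constructed elsewhere in
// the application):
//
//	svc := NewAPIKeyService(db, keyRepo, proxy, eventStore, settingsMgr, reporter, validator, groupMgr, logger)
//	svc.Start()
//	defer svc.Stop()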
func NewAPIKeyService(
	db *gorm.DB,
	repo repository.KeyRepository,
	ch channel.ChannelProxy,
	s store.Store,
	sm *settings.SettingsManager,
	ts task.Reporter,
	vs *KeyValidationService,
	gm *GroupManager,
	logger *logrus.Logger,
) *APIKeyService {
	logger.Debugf("[FORENSIC PROBE | INJECTION | APIKeyService] Received a KeyRepository. Fingerprint: %p", repo)
	return &APIKeyService{
		db:                db,
		keyRepo:           repo,
		channel:           ch,
		store:             s,
		SettingsManager:   sm,
		taskService:       ts,
		logger:            logger.WithField("component", "APIKeyService🔑"),
		stopChan:          make(chan struct{}),
		validationService: vs,
		groupManager:      gm,
	}
}

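// Start subscribes to request, master-key, key-status and import events and
// dispatches them on a background goroutine until Stop is called.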
func (s *APIKeyService) Start() {
	requestSub, err := s.store.Subscribe(models.TopicRequestFinished)
	if err != nil {
		s.logger.Fatalf("Failed to subscribe to topic %s: %v", models.TopicRequestFinished, err)
		return
	}
	masterKeySub, err := s.store.Subscribe(models.TopicMasterKeyStatusChanged)
	if err != nil {
		s.logger.Fatalf("Failed to subscribe to topic %s: %v", models.TopicMasterKeyStatusChanged, err)
		return
	}
	keyStatusSub, err := s.store.Subscribe(models.TopicKeyStatusChanged)
	if err != nil {
		s.logger.Fatalf("Failed to subscribe to topic %s: %v", models.TopicKeyStatusChanged, err)
		return
	}
	importSub, err := s.store.Subscribe(models.TopicImportGroupCompleted)
	if err != nil {
		s.logger.Fatalf("Failed to subscribe to topic %s: %v", models.TopicImportGroupCompleted, err)
		return
	}
	s.logger.Info("Started and subscribed to request, master key, health check, and import events.")

	go func() {
		defer requestSub.Close()
		defer masterKeySub.Close()
		defer keyStatusSub.Close()
		defer importSub.Close()
		for {
			select {
			case msg := <-requestSub.Channel():
				var event models.RequestFinishedEvent
				if err := json.Unmarshal(msg.Payload, &event); err != nil {
					s.logger.Errorf("Failed to unmarshal event for key status update: %v", err)
					continue
				}
				s.handleKeyUsageEvent(&event)

			case msg := <-masterKeySub.Channel():
				var event models.MasterKeyStatusChangedEvent
				if err := json.Unmarshal(msg.Payload, &event); err != nil {
					s.logger.Errorf("Failed to unmarshal MasterKeyStatusChangedEvent: %v", err)
					continue
				}
				s.handleMasterKeyStatusChangeEvent(&event)
			case msg := <-keyStatusSub.Channel():
				var event models.KeyStatusChangedEvent
				if err := json.Unmarshal(msg.Payload, &event); err != nil {
					s.logger.Errorf("Failed to unmarshal KeyStatusChangedEvent: %v", err)
					continue
				}
				s.handleKeyStatusChangeEvent(&event)

			case msg := <-importSub.Channel():
				var event models.ImportGroupCompletedEvent
				if err := json.Unmarshal(msg.Payload, &event); err != nil {
					s.logger.WithError(err).Error("Failed to unmarshal ImportGroupCompletedEvent.")
					continue
				}
				s.logger.Infof("Received ImportGroupCompletedEvent for group %d, triggering validation for %d keys.", event.GroupID, len(event.KeyIDs))

				go s.handlePostImportValidation(&event)

			case <-s.stopChan:
				s.logger.Info("Stopping event listener.")
				return
			}
		}
	}()
}

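// Stop terminates the event-listener goroutine started by Start.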
func (s *APIKeyService) Stop() {
	close(s.stopChan)
}

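// handleKeyUsageEvent reconciles a mapping after a finished request: a success
// refreshes LastUsedAt and may recover the key to ACTIVE, while a failure is
// handed to judgeKeyErrors for classification.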
func (s *APIKeyService) handleKeyUsageEvent(event *models.RequestFinishedEvent) {
	if event.RequestLog.KeyID == nil || event.RequestLog.GroupID == nil {
		return
	}
	if event.RequestLog.IsSuccess {
		mapping, err := s.keyRepo.GetMapping(*event.RequestLog.GroupID, *event.RequestLog.KeyID)
		if err != nil {
			s.logger.Warnf("[%s] Could not find mapping for G:%d K:%d on successful request: %v", event.CorrelationID, *event.RequestLog.GroupID, *event.RequestLog.KeyID, err)
			return
		}
		statusChanged := false
		oldStatus := mapping.Status
		if mapping.Status != models.StatusActive {
			mapping.Status = models.StatusActive
			mapping.ConsecutiveErrorCount = 0
			mapping.LastError = ""
			statusChanged = true
		}

		now := time.Now()
		mapping.LastUsedAt = &now
		if err := s.keyRepo.UpdateMapping(mapping); err != nil {
			s.logger.Errorf("[%s] Failed to update mapping for G:%d K:%d after successful request: %v", event.CorrelationID, *event.RequestLog.GroupID, *event.RequestLog.KeyID, err)
			return
		}
		if statusChanged {
			go s.publishStatusChangeEvent(*event.RequestLog.GroupID, *event.RequestLog.KeyID, oldStatus, models.StatusActive, "key_recovered_after_use")
		}
		return
	}
	if event.Error != nil {
		s.judgeKeyErrors(
			event.CorrelationID,
			*event.RequestLog.GroupID,
			*event.RequestLog.KeyID,
			event.Error,
			event.IsPreciseRouting,
		)
	}
}

func (s *APIKeyService) handleKeyStatusChangeEvent(event *models.KeyStatusChangedEvent) {
	log := s.logger.WithFields(logrus.Fields{
		"group_id":   event.GroupID,
		"key_id":     event.KeyID,
		"new_status": event.NewStatus,
		"reason":     event.ChangeReason,
	})
	log.Info("Received KeyStatusChangedEvent, will update polling caches.")
	s.keyRepo.HandleCacheUpdateEvent(event.GroupID, event.KeyID, event.NewStatus)
	log.Info("Polling caches updated based on health check event.")
}

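// publishStatusChangeEvent broadcasts a KeyStatusChangedEvent so that other
// components (such as the polling caches) can react to the transition.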
func (s *APIKeyService) publishStatusChangeEvent(groupID, keyID uint, oldStatus, newStatus models.APIKeyStatus, reason string) {
	changeEvent := models.KeyStatusChangedEvent{
		KeyID:        keyID,
		GroupID:      groupID,
		OldStatus:    oldStatus,
		NewStatus:    newStatus,
		ChangeReason: reason,
		ChangedAt:    time.Now(),
	}
	eventData, _ := json.Marshal(changeEvent)
	if err := s.store.Publish(models.TopicKeyStatusChanged, eventData); err != nil {
		s.logger.Errorf("Failed to publish key status changed event for group %d: %v", groupID, err)
	}
}

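// ListAPIKeys returns one page of key details for a group. Without a keyword
// it paginates in the database; with a keyword it loads the group's keys and
// filters them in memory, which is considerably more expensive.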
func (s *APIKeyService) ListAPIKeys(params *models.APIKeyQueryParams) (*PaginatedAPIKeys, error) {
	// --- Path 1: High-performance DB pagination (no keyword) ---
	if params.Keyword == "" {
		items, total, err := s.keyRepo.GetPaginatedKeysAndMappingsByGroup(params)
		if err != nil {
			return nil, err
		}
		totalPages := 0
		if total > 0 && params.PageSize > 0 {
			totalPages = int(math.Ceil(float64(total) / float64(params.PageSize)))
		}
		return &PaginatedAPIKeys{
			Items:      items,
			Total:      total,
			Page:       params.Page,
			PageSize:   params.PageSize,
			TotalPages: totalPages,
		}, nil
	}
	// --- Path 2: In-memory search (keyword present) ---
	s.logger.Infof("Performing heavy in-memory search for group %d with keyword '%s'", params.KeyGroupID, params.Keyword)
	// To get all keys, we fetch all IDs first, then get their full details.
	var statusesToFilter []string
	if params.Status != "" {
		statusesToFilter = append(statusesToFilter, params.Status)
	} else {
		statusesToFilter = append(statusesToFilter, "all") // "all" gets every status
	}
	allKeyIDs, err := s.keyRepo.FindKeyIDsByStatus(params.KeyGroupID, statusesToFilter)
	if err != nil {
		return nil, fmt.Errorf("failed to fetch all key IDs for search: %w", err)
	}
	if len(allKeyIDs) == 0 {
		return &PaginatedAPIKeys{Items: []*models.APIKeyDetails{}, Total: 0, Page: 1, PageSize: params.PageSize, TotalPages: 0}, nil
	}

	// This is the heavy operation: getting all keys and decrypting them.
	allKeys, err := s.keyRepo.GetKeysByIDs(allKeyIDs)
	if err != nil {
		return nil, fmt.Errorf("failed to fetch all key details for in-memory search: %w", err)
	}
	// We also need mappings to build the final `APIKeyDetails`.
	var allMappings []models.GroupAPIKeyMapping
	err = s.db.Where("key_group_id = ? AND api_key_id IN ?", params.KeyGroupID, allKeyIDs).Find(&allMappings).Error
	if err != nil {
		return nil, fmt.Errorf("failed to fetch mappings for in-memory search: %w", err)
	}
	mappingMap := make(map[uint]*models.GroupAPIKeyMapping)
	for i := range allMappings {
		mappingMap[allMappings[i].APIKeyID] = &allMappings[i]
	}
	// Filter the results in memory.
	var filteredItems []*models.APIKeyDetails
	for _, key := range allKeys {
		if strings.Contains(key.APIKey, params.Keyword) {
			if mapping, ok := mappingMap[key.ID]; ok {
				filteredItems = append(filteredItems, &models.APIKeyDetails{
					ID:                    key.ID,
					CreatedAt:             key.CreatedAt,
					UpdatedAt:             key.UpdatedAt,
					APIKey:                key.APIKey,
					MasterStatus:          key.MasterStatus,
					Status:                mapping.Status,
					LastError:             mapping.LastError,
					ConsecutiveErrorCount: mapping.ConsecutiveErrorCount,
					LastUsedAt:            mapping.LastUsedAt,
					CooldownUntil:         mapping.CooldownUntil,
				})
			}
		}
	}
	// Sort the filtered results to ensure consistent pagination (by ID descending).
	sort.Slice(filteredItems, func(i, j int) bool {
		return filteredItems[i].ID > filteredItems[j].ID
	})
	// Manually paginate the filtered results.
	total := int64(len(filteredItems))
	start := (params.Page - 1) * params.PageSize
	end := start + params.PageSize
	if start < 0 {
		start = 0
	}
	if start >= len(filteredItems) {
		return &PaginatedAPIKeys{
			Items:      []*models.APIKeyDetails{},
			Total:      total,
			Page:       params.Page,
			PageSize:   params.PageSize,
			TotalPages: int(math.Ceil(float64(total) / float64(params.PageSize))),
		}, nil
	}
	if end > len(filteredItems) {
		end = len(filteredItems)
	}
	paginatedItems := filteredItems[start:end]
	return &PaginatedAPIKeys{
		Items:      paginatedItems,
		Total:      total,
		Page:       params.Page,
		PageSize:   params.PageSize,
		TotalPages: int(math.Ceil(float64(total) / float64(params.PageSize))),
	}, nil
}

func (s *APIKeyService) GetKeysByIds(ids []uint) ([]models.APIKey, error) {
	return s.keyRepo.GetKeysByIDs(ids)
}

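// UpdateAPIKey persists changes to a key asynchronously: it returns nil
// immediately and logs any failure from the background update.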
func (s *APIKeyService) UpdateAPIKey(key *models.APIKey) error {
	go func() {
		var oldKey models.APIKey
		if err := s.db.First(&oldKey, key.ID).Error; err != nil {
			s.logger.Errorf("Failed to find old key state for ID %d before update: %v", key.ID, err)
			return
		}
		if err := s.keyRepo.Update(key); err != nil {
			s.logger.Errorf("Failed to asynchronously update key ID %d: %v", key.ID, err)
			return
		}
	}()
	return nil
}

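// HardDeleteAPIKeyByID permanently removes a key and publishes a
// KeyStatusChangedEvent for every group the key belonged to.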
func (s *APIKeyService) HardDeleteAPIKeyByID(id uint) error {
	// Get all associated groups before deletion to publish correct events
	groups, err := s.keyRepo.GetGroupsForKey(id)
	if err != nil {
		s.logger.Warnf("Could not get groups for key ID %d before hard deletion (key might be orphaned): %v", id, err)
	}

	err = s.keyRepo.HardDeleteByID(id)
	if err == nil {
		// Publish events for each group the key was a part of
		for _, groupID := range groups {
			event := models.KeyStatusChangedEvent{
				KeyID:        id,
				GroupID:      groupID,
				ChangeReason: "key_hard_deleted",
			}
			eventData, _ := json.Marshal(event)
			go s.store.Publish(models.TopicKeyStatusChanged, eventData)
		}
	}
	return err
}

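// UpdateMappingStatus manually sets a key's status within a group. It refuses
// to activate a key whose master status is not active and publishes the
// resulting change.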
func (s *APIKeyService) UpdateMappingStatus(groupID, keyID uint, newStatus models.APIKeyStatus) (*models.GroupAPIKeyMapping, error) {
	key, err := s.keyRepo.GetKeyByID(keyID)
	if err != nil {
		return nil, CustomErrors.ParseDBError(err)
	}

	if key.MasterStatus != models.MasterStatusActive && newStatus == models.StatusActive {
		return nil, CustomErrors.ErrStateConflictMasterRevoked
	}
	mapping, err := s.keyRepo.GetMapping(groupID, keyID)
	if err != nil {
		return nil, err
	}
	oldStatus := mapping.Status
	if oldStatus == newStatus {
		return mapping, nil
	}
	mapping.Status = newStatus
	if newStatus == models.StatusActive {
		mapping.ConsecutiveErrorCount = 0
		mapping.LastError = ""
	}
	if err := s.keyRepo.UpdateMapping(mapping); err != nil {
		return nil, err
	}
	go s.publishStatusChangeEvent(groupID, keyID, oldStatus, newStatus, "manual_update")
	return mapping, nil
}

func (s *APIKeyService) handleMasterKeyStatusChangeEvent(event *models.MasterKeyStatusChangedEvent) {
	s.logger.Infof("Received MasterKeyStatusChangedEvent for Key ID %d: %s -> %s", event.KeyID, event.OldMasterStatus, event.NewMasterStatus)
	if event.NewMasterStatus != models.MasterStatusRevoked {
		return
	}
	affectedGroupIDs, err := s.keyRepo.GetGroupsForKey(event.KeyID)
	if err != nil {
		s.logger.WithError(err).Errorf("Failed to get groups for key ID %d to propagate revoked status.", event.KeyID)
		return
	}
	if len(affectedGroupIDs) == 0 {
		s.logger.Infof("Key ID %d is revoked, but it's not associated with any group. No action needed.", event.KeyID)
		return
	}
	s.logger.Infof("Propagating REVOKED status for Key ID %d to %d groups.", event.KeyID, len(affectedGroupIDs))
	for _, groupID := range affectedGroupIDs {
		_, err := s.UpdateMappingStatus(groupID, event.KeyID, models.StatusBanned)
		if err != nil {
			if !errors.Is(err, gorm.ErrRecordNotFound) {
				s.logger.WithError(err).Errorf("Failed to update mapping status to BANNED for Key ID %d in Group ID %d.", event.KeyID, groupID)
			}
		}
	}
}

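// StartRestoreKeysTask launches a background task that attempts to return the
// given keys of a group to ACTIVE, skipping keys whose master status forbids it.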
func (s *APIKeyService) StartRestoreKeysTask(groupID uint, keyIDs []uint) (*task.Status, error) {
	if len(keyIDs) == 0 {
		return nil, CustomErrors.NewAPIError(CustomErrors.ErrBadRequest, "No key IDs provided for restoration.")
	}
	resourceID := fmt.Sprintf("group-%d", groupID)
	taskStatus, err := s.taskService.StartTask(groupID, TaskTypeRestoreSpecificKeys, resourceID, len(keyIDs), time.Hour)
	if err != nil {
		return nil, err
	}
	go s.runRestoreKeysTask(taskStatus.ID, resourceID, groupID, keyIDs)
	return taskStatus, nil
}

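// runRestoreKeysTask is the worker behind StartRestoreKeysTask. It updates
// mappings without per-item cache writes and applies a single batch cache
// update once the loop has finished.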
func (s *APIKeyService) runRestoreKeysTask(taskID string, resourceID string, groupID uint, keyIDs []uint) {
	defer func() {
		if r := recover(); r != nil {
			s.logger.Errorf("Panic recovered in runRestoreKeysTask: %v", r)
			s.taskService.EndTaskByID(taskID, resourceID, nil, fmt.Errorf("panic in restore keys task: %v", r))
		}
	}()
	var mappingsToProcess []models.GroupAPIKeyMapping
	err := s.db.Preload("APIKey").
		Where("key_group_id = ? AND api_key_id IN ?", groupID, keyIDs).
		Find(&mappingsToProcess).Error
	if err != nil {
		s.taskService.EndTaskByID(taskID, resourceID, nil, err)
		return
	}
	result := &BatchRestoreResult{
		SkippedKeys: make([]SkippedKeyInfo, 0),
	}
	var successfulMappings []*models.GroupAPIKeyMapping
	processedCount := 0
	for _, mapping := range mappingsToProcess {
		mapping := mapping // Per-iteration copy so &mapping below is a unique pointer (matters on Go versions before 1.22).
		processedCount++
		_ = s.taskService.UpdateProgressByID(taskID, processedCount)
		if mapping.APIKey == nil {
			result.SkippedCount++
			result.SkippedKeys = append(result.SkippedKeys, SkippedKeyInfo{KeyID: mapping.APIKeyID, Reason: "Associated APIKey entity not found."})
			continue
		}
		if mapping.APIKey.MasterStatus != models.MasterStatusActive {
			result.SkippedCount++
			result.SkippedKeys = append(result.SkippedKeys, SkippedKeyInfo{KeyID: mapping.APIKeyID, Reason: fmt.Sprintf("Master status is '%s'.", mapping.APIKey.MasterStatus)})
			continue
		}
		oldStatus := mapping.Status
		if oldStatus != models.StatusActive {
			mapping.Status = models.StatusActive
			mapping.ConsecutiveErrorCount = 0
			mapping.LastError = ""
			// Use the version that doesn't trigger individual cache updates.
			if err := s.keyRepo.UpdateMappingWithoutCache(&mapping); err != nil {
				s.logger.WithError(err).Warn("Failed to update mapping in DB during restore task.")
				result.SkippedCount++
				result.SkippedKeys = append(result.SkippedKeys, SkippedKeyInfo{KeyID: mapping.APIKeyID, Reason: "DB update failed."})
			} else {
				result.RestoredCount++
				successfulMappings = append(successfulMappings, &mapping) // Collect for batch cache update.
				go s.publishStatusChangeEvent(groupID, mapping.APIKeyID, oldStatus, models.StatusActive, "batch_restore")
			}
		} else {
			result.RestoredCount++ // Already active, count as success.
		}
	}
	// After the loop, perform one single, efficient cache update.
	if err := s.keyRepo.HandleCacheUpdateEventBatch(successfulMappings); err != nil {
		s.logger.WithError(err).Error("Failed to perform batch cache update after restore task.")
		// This is not a task-fatal error, so we just log it and continue.
	}
	// Account for keys that were requested but not found in the initial DB query.
	result.SkippedCount += (len(keyIDs) - len(mappingsToProcess))
	s.taskService.EndTaskByID(taskID, resourceID, result, nil)
}

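// StartRestoreAllBannedTask restores every BANNED key in the group by
// delegating to StartRestoreKeysTask.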
func (s *APIKeyService) StartRestoreAllBannedTask(groupID uint) (*task.Status, error) {
	var bannedKeyIDs []uint
	err := s.db.Model(&models.GroupAPIKeyMapping{}).
		Where("key_group_id = ? AND status = ?", groupID, models.StatusBanned).
		Pluck("api_key_id", &bannedKeyIDs).Error
	if err != nil {
		return nil, CustomErrors.ParseDBError(err)
	}
	if len(bannedKeyIDs) == 0 {
		return nil, CustomErrors.NewAPIError(CustomErrors.ErrBadRequest, "No banned keys to restore in this group.")
	}
	return s.StartRestoreKeysTask(groupID, bannedKeyIDs)
}

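// handlePostImportValidation validates freshly imported keys against the
// group's key-check endpoint using a bounded worker pool, activating keys that
// pass and routing failures through judgeKeyErrors.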
func (s *APIKeyService) handlePostImportValidation(event *models.ImportGroupCompletedEvent) {
	group, ok := s.groupManager.GetGroupByID(event.GroupID)
	if !ok {
		s.logger.Errorf("Group with id %d not found during post-import validation, aborting.", event.GroupID)
		return
	}
	opConfig, err := s.groupManager.BuildOperationalConfig(group)
	if err != nil {
		s.logger.Errorf("Failed to build operational config for group %d, aborting validation: %v", event.GroupID, err)
		return
	}
	endpoint, err := s.groupManager.BuildKeyCheckEndpoint(event.GroupID)
	if err != nil {
		s.logger.Errorf("Failed to build key check endpoint for group %d, aborting validation: %v", event.GroupID, err)
		return
	}
	globalSettings := s.SettingsManager.GetSettings()
	concurrency := globalSettings.BaseKeyCheckConcurrency
	if opConfig.KeyCheckConcurrency != nil {
		concurrency = *opConfig.KeyCheckConcurrency
	}
	if concurrency <= 0 {
		concurrency = 10 // Safety fallback
	}
	timeout := time.Duration(globalSettings.KeyCheckTimeoutSeconds) * time.Second
	keysToValidate, err := s.keyRepo.GetKeysByIDs(event.KeyIDs)
	if err != nil {
		s.logger.Errorf("Failed to get key models for validation in group %d: %v", event.GroupID, err)
		return
	}
	s.logger.Infof("Validating %d keys for group %d with concurrency %d against endpoint %s", len(keysToValidate), event.GroupID, concurrency, endpoint)
	var wg sync.WaitGroup
	jobs := make(chan models.APIKey, len(keysToValidate))
	for i := 0; i < concurrency; i++ {
		wg.Add(1)
		go func() {
			defer wg.Done()
			for key := range jobs {
				validationErr := s.validationService.ValidateSingleKey(&key, timeout, endpoint)
				if validationErr == nil {
					s.logger.Infof("Key ID %d PASSED validation. Status -> ACTIVE.", key.ID)
					if _, err := s.UpdateMappingStatus(event.GroupID, key.ID, models.StatusActive); err != nil {
						s.logger.Errorf("Failed to update status to ACTIVE for Key ID %d in group %d: %v", key.ID, event.GroupID, err)
					}
				} else {
					var apiErr *CustomErrors.APIError
					if !CustomErrors.As(validationErr, &apiErr) {
						apiErr = &CustomErrors.APIError{Message: validationErr.Error()}
					}
					s.judgeKeyErrors("", event.GroupID, key.ID, apiErr, false)
				}
			}
		}()
	}
	for _, key := range keysToValidate {
		jobs <- key
	}
	close(jobs)
	wg.Wait()
	s.logger.Infof("Finished post-import validation for group %d.", event.GroupID)
}

// [NEW] StartUpdateStatusByFilterTask starts a background job to update the status of keys
// that match a specific set of source statuses within a group.
func (s *APIKeyService) StartUpdateStatusByFilterTask(groupID uint, sourceStatuses []string, newStatus models.APIKeyStatus) (*task.Status, error) {
	s.logger.Infof("Starting task to update status to '%s' for keys in group %d matching statuses: %v", newStatus, groupID, sourceStatuses)

	// 1. Find key IDs using the new repository method. Using IDs is more efficient for updates.
	keyIDs, err := s.keyRepo.FindKeyIDsByStatus(groupID, sourceStatuses)
	if err != nil {
		return nil, CustomErrors.ParseDBError(err)
	}
	if len(keyIDs) == 0 {
		now := time.Now()
		return &task.Status{
			IsRunning: false, // The "task" is not running.
			Processed: 0,
			Total:     0,
			Result: map[string]string{ // We use the flexible Result field to pass the message.
				"message": "No keys matching the current filter are available for this operation.",
			},
			Error:      "", // There is no error.
			StartedAt:  now,
			FinishedAt: &now, // It started and finished at the same time.
		}, nil // Return nil for the error, signaling a 200 OK.
	}
	// 2. Start a new task using the TaskService, following existing patterns.
	resourceID := fmt.Sprintf("group-%d-status-update", groupID)
	taskStatus, err := s.taskService.StartTask(groupID, TaskTypeUpdateStatusByFilter, resourceID, len(keyIDs), 30*time.Minute)
	if err != nil {
		return nil, err // Pass up errors like "task already in progress".
	}

	// 3. Run the core logic in a separate goroutine.
	go s.runUpdateStatusByFilterTask(taskStatus.ID, resourceID, groupID, keyIDs, newStatus)
	return taskStatus, nil
}

// [NEW] runUpdateStatusByFilterTask is the private worker function for the above task.
func (s *APIKeyService) runUpdateStatusByFilterTask(taskID, resourceID string, groupID uint, keyIDs []uint, newStatus models.APIKeyStatus) {
	defer func() {
		if r := recover(); r != nil {
			s.logger.Errorf("Panic recovered in runUpdateStatusByFilterTask: %v", r)
			s.taskService.EndTaskByID(taskID, resourceID, nil, fmt.Errorf("panic in update status task: %v", r))
		}
	}()
	type BatchUpdateResult struct {
		UpdatedCount int `json:"updated_count"`
		SkippedCount int `json:"skipped_count"`
	}
	result := &BatchUpdateResult{}
	var successfulMappings []*models.GroupAPIKeyMapping
	// 1. Fetch all key master statuses in one go. This is efficient.
	keys, err := s.keyRepo.GetKeysByIDs(keyIDs)
	if err != nil {
		s.logger.WithError(err).Error("Failed to fetch keys for bulk status update, aborting task.")
		s.taskService.EndTaskByID(taskID, resourceID, nil, err)
		return
	}
	masterStatusMap := make(map[uint]models.MasterAPIKeyStatus)
	for _, key := range keys {
		masterStatusMap[key.ID] = key.MasterStatus
	}
	// 2. [THE REFINEMENT] Fetch all relevant mappings directly using s.db,
	// avoiding the need for a new repository method. This pattern is
	// already used in other parts of this service.
	var mappings []*models.GroupAPIKeyMapping
	if err := s.db.Where("key_group_id = ? AND api_key_id IN ?", groupID, keyIDs).Find(&mappings).Error; err != nil {
		s.logger.WithError(err).Error("Failed to fetch mappings for bulk status update, aborting task.")
		s.taskService.EndTaskByID(taskID, resourceID, nil, err)
		return
	}
	processedCount := 0
	for _, mapping := range mappings {
		processedCount++
		// The progress update should reflect the number of items *being processed*, not the final count.
		_ = s.taskService.UpdateProgressByID(taskID, processedCount)
		masterStatus, ok := masterStatusMap[mapping.APIKeyID]
		if !ok {
			result.SkippedCount++
			continue
		}
		if masterStatus != models.MasterStatusActive && newStatus == models.StatusActive {
			result.SkippedCount++
			continue
		}
		oldStatus := mapping.Status
		if oldStatus != newStatus {
			mapping.Status = newStatus
			if newStatus == models.StatusActive {
				mapping.ConsecutiveErrorCount = 0
				mapping.LastError = ""
			}
			if err := s.keyRepo.UpdateMappingWithoutCache(mapping); err != nil {
				result.SkippedCount++
			} else {
				result.UpdatedCount++
				successfulMappings = append(successfulMappings, mapping)
				go s.publishStatusChangeEvent(groupID, mapping.APIKeyID, oldStatus, newStatus, "batch_status_update")
			}
		} else {
			result.UpdatedCount++ // Already in desired state, count as success.
		}
	}
	result.SkippedCount += (len(keyIDs) - len(mappings))
	if err := s.keyRepo.HandleCacheUpdateEventBatch(successfulMappings); err != nil {
		s.logger.WithError(err).Error("Failed to perform batch cache update after status update task.")
	}
	s.logger.Infof("Finished bulk status update task '%s' for group %d. Updated: %d, Skipped: %d.", taskID, groupID, result.UpdatedCount, result.SkippedCount)
	s.taskService.EndTaskByID(taskID, resourceID, result, nil)
}

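// HandleRequestResult applies a fast, cache-only reaction to a request outcome:
// successful requests under the weighted strategy refresh the usage timestamp,
// and key-specific failures push the key into cooldown in the polling caches.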
func (s *APIKeyService) HandleRequestResult(group *models.KeyGroup, key *models.APIKey, success bool, apiErr *CustomErrors.APIError) {
	if success {
		if group.PollingStrategy == models.StrategyWeighted {
			go s.keyRepo.UpdateKeyUsageTimestamp(group.ID, key.ID)
		}
		return
	}
	if apiErr == nil {
		s.logger.Warnf("Request failed for KeyID %d in GroupID %d but no specific API error was provided. No action taken.", key.ID, group.ID)
		return
	}
	errMsg := apiErr.Message
	if CustomErrors.IsPermanentUpstreamError(errMsg) || CustomErrors.IsTemporaryUpstreamError(errMsg) {
		s.logger.Warnf("Request for KeyID %d in GroupID %d failed with a key-specific error: %s. Issuing order to move to cooldown.", key.ID, group.ID, errMsg)
		go s.keyRepo.SyncKeyStatusInPollingCaches(group.ID, key.ID, models.StatusCooldown)
	} else {
		s.logger.Infof("Request for KeyID %d in GroupID %d failed with a non-key-specific error. No punitive action taken. Error: %s", key.ID, group.ID, errMsg)
	}
}

// sanitizeForLog truncates long error messages and removes JSON-like content for cleaner Info/Error logs.
func sanitizeForLog(errMsg string) string {
	// Find the start of any potential JSON blob or detailed structure.
	jsonStartIndex := strings.Index(errMsg, "{")
	var cleanMsg string
	if jsonStartIndex != -1 {
		// If a '{' is found, take everything before it as the summary
		// and append a simple placeholder.
		cleanMsg = strings.TrimSpace(errMsg[:jsonStartIndex]) + " {...}"
	} else {
		// If no JSON-like structure is found, use the original message.
		cleanMsg = errMsg
	}
	// Always apply a final length truncation as a safeguard against extremely long non-JSON errors.
	const maxLen = 250
	if len(cleanMsg) > maxLen {
		return cleanMsg[:maxLen] + "..."
	}
	return cleanMsg
}

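// judgeKeyErrors classifies an upstream error for a group/key mapping:
// permanent errors ban the mapping and revoke the master key, temporary errors
// raise the consecutive-error count and may trigger a cooldown once the
// blacklist threshold is reached, and ignorable errors only refresh LastUsedAt.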
func (s *APIKeyService) judgeKeyErrors(
	correlationID string,
	groupID, keyID uint,
	apiErr *CustomErrors.APIError,
	isPreciseRouting bool,
) {
	logger := s.logger.WithFields(logrus.Fields{"group_id": groupID, "key_id": keyID, "correlation_id": correlationID})
	mapping, err := s.keyRepo.GetMapping(groupID, keyID)
	if err != nil {
		logger.WithError(err).Warn("Cannot apply consequences for error: mapping not found.")
		return
	}
	now := time.Now()
	mapping.LastUsedAt = &now
	errorMessage := apiErr.Message
	if CustomErrors.IsPermanentUpstreamError(errorMessage) {
		logger.Errorf("Permanent error detected. Banning mapping and revoking master key. Reason: %s", sanitizeForLog(errorMessage))
		logger.WithField("full_error_details", errorMessage).Debug("Full details of the permanent error.")
		if mapping.Status != models.StatusBanned {
			oldStatus := mapping.Status
			mapping.Status = models.StatusBanned
			mapping.LastError = errorMessage
			if err := s.keyRepo.UpdateMapping(mapping); err != nil {
				logger.WithError(err).Error("Failed to update mapping status to BANNED.")
			} else {
				go s.publishStatusChangeEvent(groupID, keyID, oldStatus, mapping.Status, "permanent_error_banned")
				go s.revokeMasterKey(keyID, "permanent_upstream_error")
			}
		}
		return
	}
	if CustomErrors.IsTemporaryUpstreamError(errorMessage) {
		mapping.LastError = errorMessage
		mapping.ConsecutiveErrorCount++
		var threshold int
		if isPreciseRouting {
			group, ok := s.groupManager.GetGroupByID(groupID)
			opConfig, err := s.groupManager.BuildOperationalConfig(group)
			if !ok || err != nil {
				logger.Warnf("Could not build operational config for group %d in Precise Routing mode. Falling back to global settings.", groupID)
				threshold = s.SettingsManager.GetSettings().BlacklistThreshold
			} else {
				threshold = *opConfig.KeyBlacklistThreshold
			}
		} else {
			threshold = s.SettingsManager.GetSettings().BlacklistThreshold
		}
		logger.Warnf("Temporary error detected. Incrementing error count. New count: %d (Threshold: %d). Reason: %s", mapping.ConsecutiveErrorCount, threshold, sanitizeForLog(errorMessage))
		logger.WithField("full_error_details", errorMessage).Debug("Full details of the temporary error.")
		oldStatus := mapping.Status
		newStatus := oldStatus
		if mapping.ConsecutiveErrorCount >= threshold && oldStatus == models.StatusActive {
			newStatus = models.StatusCooldown
			logger.Errorf("Putting mapping into COOLDOWN due to reaching temporary error threshold (%d)", threshold)
		}
		if oldStatus != newStatus {
			mapping.Status = newStatus
		}
		if err := s.keyRepo.UpdateMapping(mapping); err != nil {
			logger.WithError(err).Error("Failed to update mapping after temporary error.")
			return
		}
		if oldStatus != newStatus {
			go s.publishStatusChangeEvent(groupID, keyID, oldStatus, newStatus, "error_threshold_reached")
		}
		return
	}
	logger.Infof("Ignored truly ignorable upstream error. Only updating LastUsedAt. Reason: %s", sanitizeForLog(errorMessage))
	logger.WithField("full_error_details", errorMessage).Debug("Full details of the ignorable error.")
	if err := s.keyRepo.UpdateMapping(mapping); err != nil {
		logger.WithError(err).Error("Failed to update LastUsedAt for ignorable error.")
	}
}

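// revokeMasterKey sets a key's master status to REVOKED (if it is not already)
// and publishes a MasterKeyStatusChangedEvent.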
func (s *APIKeyService) revokeMasterKey(keyID uint, reason string) {
	key, err := s.keyRepo.GetKeyByID(keyID)
	if err != nil {
		if errors.Is(err, gorm.ErrRecordNotFound) {
			s.logger.Warnf("Attempted to revoke non-existent key ID %d.", keyID)
		} else {
			s.logger.Errorf("Failed to get key by ID %d for master status revocation: %v", keyID, err)
		}
		return
	}
	if key.MasterStatus == models.MasterStatusRevoked {
		return
	}
	oldMasterStatus := key.MasterStatus
	newMasterStatus := models.MasterStatusRevoked
	if err := s.keyRepo.UpdateMasterStatusByID(keyID, newMasterStatus); err != nil {
		s.logger.Errorf("Failed to update master status to REVOKED for key ID %d: %v", keyID, err)
		return
	}
	masterKeyEvent := models.MasterKeyStatusChangedEvent{
		KeyID:           keyID,
		OldMasterStatus: oldMasterStatus,
		NewMasterStatus: newMasterStatus,
		ChangeReason:    reason,
		ChangedAt:       time.Now(),
	}
	eventData, _ := json.Marshal(masterKeyEvent)
	_ = s.store.Publish(models.TopicMasterKeyStatusChanged, eventData)
}

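// GetAPIKeyStringsForExport returns the raw key strings of a group, filtered
// by the given statuses, for export.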
func (s *APIKeyService) GetAPIKeyStringsForExport(groupID uint, statuses []string) ([]string, error) {
	return s.keyRepo.GetKeyStringsByGroupAndStatus(groupID, statuses)
}