Files
gemini-banlancer/internal/service/apikey_service.go

865 lines
26 KiB
Go

// Filename: internal/service/apikey_service.go
package service
import (
"context"
"encoding/json"
"errors"
"fmt"
"gemini-balancer/internal/channel"
CustomErrors "gemini-balancer/internal/errors"
"gemini-balancer/internal/models"
"gemini-balancer/internal/repository"
"gemini-balancer/internal/settings"
"gemini-balancer/internal/store"
"gemini-balancer/internal/task"
"math"
"sort"
"strings"
"sync"
"time"
"github.com/sirupsen/logrus"
"gorm.io/gorm"
)
// Task type identifiers handed to the task service when a background job is
// started for a key group.
const (
	// TaskTypeUpdateStatusByFilter labels a job started by StartUpdateStatusByFilterTask.
	TaskTypeUpdateStatusByFilter = "update_status_by_filter"
	// TaskTypeRestoreSpecificKeys labels a job started by StartRestoreKeysTask
	// (StartRestoreAllBannedTask also goes through it).
	TaskTypeRestoreSpecificKeys = "restore_specific_keys_in_group"
	// TaskTypeRestoreAllBannedInGroup names a restore-everything job.
	// NOTE(review): not referenced by the code shown here — confirm it is still used.
	TaskTypeRestoreAllBannedInGroup = "restore_all_banned_in_group"
)
// BatchRestoreResult is the result payload that runRestoreKeysTask hands to the
// task service when a restore task finishes.
type BatchRestoreResult struct {
// RestoredCount counts keys that ended up active, including keys whose mapping
// was already active before the task ran.
RestoredCount int `json:"restored_count"`
// SkippedCount counts keys that were not restored. It also includes requested
// key IDs that had no mapping in the group at all.
SkippedCount int `json:"skipped_count"`
// SkippedKeys gives a reason per skipped mapping. Requested IDs with no mapping
// are counted in SkippedCount but are not listed here.
SkippedKeys []SkippedKeyInfo `json:"skipped_keys"`
}
// SkippedKeyInfo identifies a key that a restore task did not restore, and why.
type SkippedKeyInfo struct {
// KeyID is the API key ID from the group mapping.
KeyID uint `json:"key_id"`
// Reason is a human-readable explanation (e.g. missing key, non-active master
// status, failed DB update).
Reason string `json:"reason"`
}
// PaginatedAPIKeys is one page of key details as returned by ListAPIKeys.
type PaginatedAPIKeys struct {
// Items holds the keys on this page.
Items []*models.APIKeyDetails `json:"items"`
// Total is the number of matching keys across all pages.
Total int64 `json:"total"`
// Page and PageSize echo the request parameters.
Page int `json:"page"`
PageSize int `json:"page_size"`
// TotalPages is ceil(Total / PageSize).
TotalPages int `json:"total_pages"`
}
// APIKeyService manages API keys and their per-group mappings: listing and
// keyword search, manual and batch status changes, background restore/update
// tasks, and reactions to request and key-status events received from the store.
type APIKeyService struct {
db *gorm.DB
// keyRepo reads and writes keys and group mappings and maintains the polling caches.
keyRepo repository.KeyRepository
// NOTE(review): channel is not used by the methods in this file — confirm it is still needed.
channel channel.ChannelProxy
// store provides the pub/sub used by Start (subscribe) and the publish helpers.
store store.Store
SettingsManager *settings.SettingsManager
// taskService tracks progress and completion of background tasks.
taskService task.Reporter
logger *logrus.Entry
// stopChan is closed by Stop to end the event loop started by Start.
stopChan chan struct{}
validationService *KeyValidationService
groupManager *GroupManager
}
// NewAPIKeyService assembles an APIKeyService from its collaborators.
//
// The returned service does not listen for events until Start is called. Its
// logger is derived from logger with the field component=APIKeyService🔑.
func NewAPIKeyService(
	db *gorm.DB,
	repo repository.KeyRepository,
	ch channel.ChannelProxy,
	s store.Store,
	sm *settings.SettingsManager,
	ts task.Reporter,
	vs *KeyValidationService,
	gm *GroupManager,
	logger *logrus.Logger,
) *APIKeyService {
	svc := new(APIKeyService)
	svc.db = db
	svc.keyRepo = repo
	svc.channel = ch
	svc.store = s
	svc.SettingsManager = sm
	svc.taskService = ts
	svc.logger = logger.WithField("component", "APIKeyService🔑")
	svc.stopChan = make(chan struct{})
	svc.validationService = vs
	svc.groupManager = gm
	return svc
}
// Start subscribes to the event topics this service consumes and launches one
// goroutine that dispatches events until Stop is called:
//
//   - TopicRequestFinished        -> handleKeyUsageEvent
//   - TopicMasterKeyStatusChanged -> handleMasterKeyStatusChangeEvent
//   - TopicKeyStatusChanged       -> handleKeyStatusChangeEvent
//   - TopicImportGroupCompleted   -> handlePostImportValidation (own goroutine)
//
// A subscription failure is reported with Fatalf. Subscriptions opened before
// the failing one are closed first, so nothing leaks if the logger's exit
// function has been replaced and Fatalf returns.
//
// If the store closes a subscription channel, that topic is disabled. Its
// local channel is set to nil, which select never picks. This prevents the loop
// from spinning on zero-value receives.
func (s *APIKeyService) Start() {
	requestSub, err := s.store.Subscribe(context.Background(), models.TopicRequestFinished)
	if err != nil {
		s.logger.Fatalf("Failed to subscribe to %s: %v", models.TopicRequestFinished, err)
		return
	}
	masterKeySub, err := s.store.Subscribe(context.Background(), models.TopicMasterKeyStatusChanged)
	if err != nil {
		requestSub.Close()
		s.logger.Fatalf("Failed to subscribe to %s: %v", models.TopicMasterKeyStatusChanged, err)
		return
	}
	keyStatusSub, err := s.store.Subscribe(context.Background(), models.TopicKeyStatusChanged)
	if err != nil {
		requestSub.Close()
		masterKeySub.Close()
		s.logger.Fatalf("Failed to subscribe to %s: %v", models.TopicKeyStatusChanged, err)
		return
	}
	importSub, err := s.store.Subscribe(context.Background(), models.TopicImportGroupCompleted)
	if err != nil {
		requestSub.Close()
		masterKeySub.Close()
		keyStatusSub.Close()
		s.logger.Fatalf("Failed to subscribe to %s: %v", models.TopicImportGroupCompleted, err)
		return
	}
	s.logger.Info("Started and subscribed to all event topics")
	go func() {
		defer requestSub.Close()
		defer masterKeySub.Close()
		defer keyStatusSub.Close()
		defer importSub.Close()

		// Held in locals so a closed channel can be disabled by setting it to nil.
		requestCh := requestSub.Channel()
		masterKeyCh := masterKeySub.Channel()
		keyStatusCh := keyStatusSub.Channel()
		importCh := importSub.Channel()

		for {
			select {
			case msg, ok := <-requestCh:
				if !ok {
					s.logger.Warnf("Subscription to %s closed", models.TopicRequestFinished)
					requestCh = nil
					continue
				}
				var event models.RequestFinishedEvent
				if err := json.Unmarshal(msg.Payload, &event); err != nil {
					s.logger.WithError(err).Error("Failed to unmarshal RequestFinishedEvent")
					continue
				}
				s.handleKeyUsageEvent(&event)
			case msg, ok := <-masterKeyCh:
				if !ok {
					s.logger.Warnf("Subscription to %s closed", models.TopicMasterKeyStatusChanged)
					masterKeyCh = nil
					continue
				}
				var event models.MasterKeyStatusChangedEvent
				if err := json.Unmarshal(msg.Payload, &event); err != nil {
					s.logger.WithError(err).Error("Failed to unmarshal MasterKeyStatusChangedEvent")
					continue
				}
				s.handleMasterKeyStatusChangeEvent(&event)
			case msg, ok := <-keyStatusCh:
				if !ok {
					s.logger.Warnf("Subscription to %s closed", models.TopicKeyStatusChanged)
					keyStatusCh = nil
					continue
				}
				var event models.KeyStatusChangedEvent
				if err := json.Unmarshal(msg.Payload, &event); err != nil {
					s.logger.WithError(err).Error("Failed to unmarshal KeyStatusChangedEvent")
					continue
				}
				s.handleKeyStatusChangeEvent(&event)
			case msg, ok := <-importCh:
				if !ok {
					s.logger.Warnf("Subscription to %s closed", models.TopicImportGroupCompleted)
					importCh = nil
					continue
				}
				var event models.ImportGroupCompletedEvent
				if err := json.Unmarshal(msg.Payload, &event); err != nil {
					s.logger.WithError(err).Error("Failed to unmarshal ImportGroupCompletedEvent")
					continue
				}
				s.logger.Infof("Received import completion for group %d, validating %d keys", event.GroupID, len(event.KeyIDs))
				// Validation can be slow; run it off the dispatch loop.
				go s.handlePostImportValidation(&event)
			case <-s.stopChan:
				s.logger.Info("Stopping event listener")
				return
			}
		}
	}()
}
// Stop signals the event loop started by Start to exit.
//
// Calling Stop again after it has returned is a no-op instead of a
// "close of closed channel" panic. Concurrent calls to Stop are not
// synchronized and must be avoided.
func (s *APIKeyService) Stop() {
	select {
	case <-s.stopChan:
		// Already closed by an earlier Stop.
	default:
		close(s.stopChan)
	}
}
// handleKeyUsageEvent applies the outcome of a finished request to the
// (group, key) mapping it used. Requests without both a key and a group ID are
// ignored.
//
// On success the mapping is marked active. If it was not already active, its
// error state is also reset and a "key_recovered_after_use" status change is
// published. LastUsedAt is always recorded. On failure, the attached error (if
// any) is handed to judgeKeyErrors.
func (s *APIKeyService) handleKeyUsageEvent(event *models.RequestFinishedEvent) {
	if event.RequestLog.KeyID == nil || event.RequestLog.GroupID == nil {
		return
	}
	ctx := context.Background()
	groupID, keyID := *event.RequestLog.GroupID, *event.RequestLog.KeyID

	if !event.RequestLog.IsSuccess {
		if event.Error != nil {
			s.judgeKeyErrors(ctx, event.CorrelationID, groupID, keyID, event.Error, event.IsPreciseRouting)
		}
		return
	}

	mapping, err := s.keyRepo.GetMapping(groupID, keyID)
	if err != nil {
		s.logger.Warnf("[%s] Mapping not found for G:%d K:%d: %v", event.CorrelationID, groupID, keyID, err)
		return
	}
	previous := mapping.Status
	recovered := previous != models.StatusActive
	if recovered {
		mapping.Status = models.StatusActive
		mapping.ConsecutiveErrorCount = 0
		mapping.LastError = ""
	}
	usedAt := time.Now()
	mapping.LastUsedAt = &usedAt

	if updateErr := s.keyRepo.UpdateMapping(ctx, mapping); updateErr != nil {
		s.logger.Errorf("[%s] Failed to update mapping for G:%d K:%d: %v", event.CorrelationID, groupID, keyID, updateErr)
		return
	}
	if recovered {
		go s.publishStatusChangeEvent(ctx, groupID, keyID, previous, models.StatusActive, "key_recovered_after_use")
	}
}
// handleKeyStatusChangeEvent logs a key status change and forwards it to the
// repository so its polling caches reflect the new status.
func (s *APIKeyService) handleKeyStatusChangeEvent(event *models.KeyStatusChangedEvent) {
	fields := logrus.Fields{
		"group_id":   event.GroupID,
		"key_id":     event.KeyID,
		"new_status": event.NewStatus,
		"reason":     event.ChangeReason,
	}
	s.logger.WithFields(fields).Info("Updating polling caches based on status change")
	s.keyRepo.HandleCacheUpdateEvent(context.Background(), event.GroupID, event.KeyID, event.NewStatus)
}
// publishStatusChangeEvent publishes a KeyStatusChangedEvent for (groupID,
// keyID) on TopicKeyStatusChanged, stamped with the current time.
//
// Failures are logged, not returned. Callers typically run this in its own
// goroutine.
func (s *APIKeyService) publishStatusChangeEvent(ctx context.Context, groupID, keyID uint, oldStatus, newStatus models.APIKeyStatus, reason string) {
	changeEvent := models.KeyStatusChangedEvent{
		KeyID:        keyID,
		GroupID:      groupID,
		OldStatus:    oldStatus,
		NewStatus:    newStatus,
		ChangeReason: reason,
		ChangedAt:    time.Now(),
	}
	eventData, err := json.Marshal(changeEvent)
	if err != nil {
		s.logger.Errorf("Failed to marshal status change event for key %d in group %d: %v", keyID, groupID, err)
		return
	}
	if err := s.store.Publish(ctx, models.TopicKeyStatusChanged, eventData); err != nil {
		s.logger.Errorf("Failed to publish status change event for key %d in group %d: %v", keyID, groupID, err)
	}
}
// ListAPIKeys returns one page of key details for params.KeyGroupID.
//
// Without a keyword, pagination is delegated to the repository. With a keyword,
// every key in the group (restricted to params.Status when set) is loaded. The
// keys are matched in memory by case-sensitive substring on the key string,
// ordered by ID descending, and paginated here. Keys without a mapping in the
// group are dropped.
//
// Pagination in the keyword path is defensive: a page below 1 is treated as
// page 1, and a non-positive page size yields an empty page with TotalPages 0.
// It never panics or divides by zero.
func (s *APIKeyService) ListAPIKeys(ctx context.Context, params *models.APIKeyQueryParams) (*PaginatedAPIKeys, error) {
	if params.Keyword == "" {
		items, total, err := s.keyRepo.GetPaginatedKeysAndMappingsByGroup(params)
		if err != nil {
			return nil, err
		}
		return &PaginatedAPIKeys{
			Items:      items,
			Total:      total,
			Page:       params.Page,
			PageSize:   params.PageSize,
			TotalPages: apiKeyTotalPages(total, params.PageSize),
		}, nil
	}

	s.logger.Infof("Performing in-memory search for group %d with keyword '%s'", params.KeyGroupID, params.Keyword)
	statusesToFilter := []string{"all"}
	if params.Status != "" {
		statusesToFilter = []string{params.Status}
	}
	allKeyIDs, err := s.keyRepo.FindKeyIDsByStatus(params.KeyGroupID, statusesToFilter)
	if err != nil {
		return nil, fmt.Errorf("failed to fetch key IDs: %w", err)
	}
	if len(allKeyIDs) == 0 {
		return &PaginatedAPIKeys{
			Items:      []*models.APIKeyDetails{},
			Total:      0,
			Page:       1,
			PageSize:   params.PageSize,
			TotalPages: 0,
		}, nil
	}
	allKeys, err := s.keyRepo.GetKeysByIDs(allKeyIDs)
	if err != nil {
		return nil, fmt.Errorf("failed to fetch keys: %w", err)
	}
	var allMappings []models.GroupAPIKeyMapping
	if err := s.db.WithContext(ctx).Where("key_group_id = ? AND api_key_id IN ?", params.KeyGroupID, allKeyIDs).Find(&allMappings).Error; err != nil {
		return nil, fmt.Errorf("failed to fetch mappings: %w", err)
	}
	mappingMap := make(map[uint]*models.GroupAPIKeyMapping, len(allMappings))
	for i := range allMappings {
		mappingMap[allMappings[i].APIKeyID] = &allMappings[i]
	}

	var filteredItems []*models.APIKeyDetails
	for _, key := range allKeys {
		if !strings.Contains(key.APIKey, params.Keyword) {
			continue
		}
		mapping, ok := mappingMap[key.ID]
		if !ok {
			continue
		}
		filteredItems = append(filteredItems, &models.APIKeyDetails{
			ID:                    key.ID,
			CreatedAt:             key.CreatedAt,
			UpdatedAt:             key.UpdatedAt,
			APIKey:                key.APIKey,
			MasterStatus:          key.MasterStatus,
			Status:                mapping.Status,
			LastError:             mapping.LastError,
			ConsecutiveErrorCount: mapping.ConsecutiveErrorCount,
			LastUsedAt:            mapping.LastUsedAt,
			CooldownUntil:         mapping.CooldownUntil,
		})
	}
	sort.Slice(filteredItems, func(i, j int) bool {
		return filteredItems[i].ID > filteredItems[j].ID
	})

	total := int64(len(filteredItems))
	start, end := apiKeyPageBounds(params.Page, params.PageSize, len(filteredItems))
	// Non-nil empty slice so an empty page serializes as [] rather than null.
	pageItems := []*models.APIKeyDetails{}
	if start < end {
		pageItems = filteredItems[start:end]
	}
	return &PaginatedAPIKeys{
		Items:      pageItems,
		Total:      total,
		Page:       params.Page,
		PageSize:   params.PageSize,
		TotalPages: apiKeyTotalPages(total, params.PageSize),
	}, nil
}

// apiKeyTotalPages returns ceil(total/pageSize). It returns 0 when either
// argument is non-positive, which avoids dividing by a zero page size.
func apiKeyTotalPages(total int64, pageSize int) int {
	if total <= 0 || pageSize <= 0 {
		return 0
	}
	return int(math.Ceil(float64(total) / float64(pageSize)))
}

// apiKeyPageBounds returns the half-open slice bounds [start, end) of the
// 1-based page within n items.
//
// A page below 1 is treated as page 1. A non-positive pageSize, or a page past
// the last one, yields the empty range (0, 0).
func apiKeyPageBounds(page, pageSize, n int) (start, end int) {
	if n <= 0 || pageSize <= 0 {
		return 0, 0
	}
	if page < 1 {
		page = 1
	}
	lastPage := n / pageSize
	if n%pageSize != 0 {
		lastPage++
	}
	// Checking against lastPage first keeps (page-1)*pageSize from overflowing.
	if page > lastPage {
		return 0, 0
	}
	start = (page - 1) * pageSize
	end = start + pageSize
	if end > n {
		end = n
	}
	return start, end
}
// GetKeysByIds loads the API keys with the given IDs from the repository.
// ctx is accepted but not used; the repository call takes no context.
func (s *APIKeyService) GetKeysByIds(ctx context.Context, ids []uint) ([]models.APIKey, error) {
	keys, err := s.keyRepo.GetKeysByIDs(ids)
	return keys, err
}
// UpdateAPIKey persists changes to key and returns any repository error.
//
// The update runs synchronously for two reasons. Callers learn about failures
// instead of always receiving nil. And the repository never touches *key after
// this method returns, so there is no race with the caller's further use of key.
func (s *APIKeyService) UpdateAPIKey(ctx context.Context, key *models.APIKey) error {
	if err := s.keyRepo.Update(key); err != nil {
		return fmt.Errorf("update key %d: %w", key.ID, err)
	}
	return nil
}
// HardDeleteAPIKeyByID permanently deletes the key with the given ID. It then
// publishes a "key_hard_deleted" status change for every group the key belonged
// to, so subscribers can update their caches.
//
// Group membership is read before the delete. NOTE(review): presumably the
// mappings are gone afterwards; confirm. If that lookup fails, the delete still
// proceeds, but no events are published. The events are sent on a background
// context because they are published asynchronously, after this call (and
// possibly the caller's ctx) has ended.
func (s *APIKeyService) HardDeleteAPIKeyByID(ctx context.Context, id uint) error {
	groups, err := s.keyRepo.GetGroupsForKey(ctx, id)
	if err != nil {
		s.logger.Warnf("Could not get groups for key %d before deletion: %v", id, err)
	}
	if err := s.keyRepo.HardDeleteByID(id); err != nil {
		return err
	}
	// A deleted key has no meaningful old/new status; both stay at the zero value.
	var noStatus models.APIKeyStatus
	for _, groupID := range groups {
		go s.publishStatusChangeEvent(context.Background(), groupID, id, noStatus, noStatus, "key_hard_deleted")
	}
	return nil
}
// UpdateMappingStatus sets the status of the (groupID, keyID) mapping and
// returns the updated mapping.
//
// Activating a key whose master status is not active fails with
// ErrStateConflictMasterRevoked. Setting the current status again is a no-op.
// Activating also clears the error counter and last error. Errors from
// GetMapping are returned unwrapped, because callers compare them with
// gorm.ErrRecordNotFound.
//
// The "manual_update" status change is published asynchronously on a
// background context. ctx may be request-scoped and cancelled as soon as this
// returns.
func (s *APIKeyService) UpdateMappingStatus(ctx context.Context, groupID, keyID uint, newStatus models.APIKeyStatus) (*models.GroupAPIKeyMapping, error) {
	key, err := s.keyRepo.GetKeyByID(keyID)
	if err != nil {
		return nil, CustomErrors.ParseDBError(err)
	}
	if key.MasterStatus != models.MasterStatusActive && newStatus == models.StatusActive {
		return nil, CustomErrors.ErrStateConflictMasterRevoked
	}
	mapping, err := s.keyRepo.GetMapping(groupID, keyID)
	if err != nil {
		return nil, err
	}
	oldStatus := mapping.Status
	if oldStatus == newStatus {
		return mapping, nil
	}
	mapping.Status = newStatus
	if newStatus == models.StatusActive {
		mapping.ConsecutiveErrorCount = 0
		mapping.LastError = ""
	}
	if err := s.keyRepo.UpdateMapping(ctx, mapping); err != nil {
		return nil, err
	}
	go s.publishStatusChangeEvent(context.Background(), groupID, keyID, oldStatus, newStatus, "manual_update")
	return mapping, nil
}
// handleMasterKeyStatusChangeEvent propagates a master-key revocation by banning
// the key in every group it belongs to. Events for any other new master status
// are ignored. "Record not found" errors are skipped silently; other per-group
// failures are logged and do not stop the loop.
func (s *APIKeyService) handleMasterKeyStatusChangeEvent(event *models.MasterKeyStatusChangedEvent) {
	if event.NewMasterStatus != models.MasterStatusRevoked {
		return
	}
	keyID := event.KeyID
	ctx := context.Background()
	s.logger.Infof("Key %d revoked, propagating to all groups", keyID)

	groupIDs, err := s.keyRepo.GetGroupsForKey(ctx, keyID)
	switch {
	case err != nil:
		s.logger.WithError(err).Errorf("Failed to get groups for key %d", keyID)
		return
	case len(groupIDs) == 0:
		s.logger.Infof("Key %d not associated with any group", keyID)
		return
	}

	for _, gid := range groupIDs {
		_, banErr := s.UpdateMappingStatus(ctx, gid, keyID, models.StatusBanned)
		if banErr == nil || errors.Is(banErr, gorm.ErrRecordNotFound) {
			continue
		}
		s.logger.WithError(banErr).Errorf("Failed to ban key %d in group %d", keyID, gid)
	}
}
// StartRestoreKeysTask registers a restore task for the given keys in groupID,
// with a one-hour timeout. It then runs the task in the background and returns
// its initial status.
//
// ctx is used only to register the task. The background work runs on a fresh
// context, because ctx comes from the caller (possibly a request) and may be
// cancelled as soon as this returns, which would abort the task's DB queries.
// keyIDs is copied so later changes by the caller cannot affect the running task.
func (s *APIKeyService) StartRestoreKeysTask(ctx context.Context, groupID uint, keyIDs []uint) (*task.Status, error) {
	if len(keyIDs) == 0 {
		return nil, CustomErrors.NewAPIError(CustomErrors.ErrBadRequest, "No key IDs provided")
	}
	resourceID := fmt.Sprintf("group-%d", groupID)
	taskStatus, err := s.taskService.StartTask(ctx, groupID, TaskTypeRestoreSpecificKeys, resourceID, len(keyIDs), time.Hour)
	if err != nil {
		return nil, err
	}
	ids := append([]uint(nil), keyIDs...)
	go s.runRestoreKeysTask(context.Background(), taskStatus.ID, resourceID, groupID, ids)
	return taskStatus, nil
}
// runRestoreKeysTask is the body of a restore task. For each requested key that
// has a mapping in groupID and an active master key, it sets the mapping to
// active and clears its error state. It then refreshes the polling caches in one
// batch and ends the task with a BatchRestoreResult.
//
// Mappings that are already active count as restored without a write. Requested
// IDs with no mapping are added to SkippedCount. A panic ends the task with an
// error instead of leaving it running.
func (s *APIKeyService) runRestoreKeysTask(ctx context.Context, taskID, resourceID string, groupID uint, keyIDs []uint) {
	defer func() {
		if r := recover(); r != nil {
			s.logger.Errorf("Panic in restore task: %v", r)
			s.taskService.EndTaskByID(ctx, taskID, resourceID, nil, fmt.Errorf("panic: %v", r))
		}
	}()
	var mappingsToProcess []models.GroupAPIKeyMapping
	if err := s.db.WithContext(ctx).Preload("APIKey").
		Where("key_group_id = ? AND api_key_id IN ?", groupID, keyIDs).
		Find(&mappingsToProcess).Error; err != nil {
		s.taskService.EndTaskByID(ctx, taskID, resourceID, nil, err)
		return
	}
	result := &BatchRestoreResult{SkippedKeys: make([]SkippedKeyInfo, 0)}
	skip := func(keyID uint, reason string) {
		result.SkippedCount++
		result.SkippedKeys = append(result.SkippedKeys, SkippedKeyInfo{KeyID: keyID, Reason: reason})
	}
	var successfulMappings []*models.GroupAPIKeyMapping
	for i := range mappingsToProcess {
		// Point at the slice element, not at a range variable. Before Go 1.22 a
		// range variable is shared across iterations, so &mapping would make every
		// entry of successfulMappings alias the last mapping processed.
		mapping := &mappingsToProcess[i]
		_ = s.taskService.UpdateProgressByID(ctx, taskID, i+1)
		if mapping.APIKey == nil {
			skip(mapping.APIKeyID, "APIKey not found")
			continue
		}
		if mapping.APIKey.MasterStatus != models.MasterStatusActive {
			skip(mapping.APIKeyID, fmt.Sprintf("Master status is %s", mapping.APIKey.MasterStatus))
			continue
		}
		if mapping.Status == models.StatusActive {
			// Already active: nothing to write, but it counts as restored.
			result.RestoredCount++
			continue
		}
		oldStatus := mapping.Status
		mapping.Status = models.StatusActive
		mapping.ConsecutiveErrorCount = 0
		mapping.LastError = ""
		if err := s.keyRepo.UpdateMappingWithoutCache(mapping); err != nil {
			skip(mapping.APIKeyID, "DB update failed")
			continue
		}
		result.RestoredCount++
		successfulMappings = append(successfulMappings, mapping)
		go s.publishStatusChangeEvent(ctx, groupID, mapping.APIKeyID, oldStatus, models.StatusActive, "batch_restore")
	}
	if err := s.keyRepo.HandleCacheUpdateEventBatch(ctx, successfulMappings); err != nil {
		s.logger.WithError(err).Error("Failed batch cache update after restore")
	}
	// Requested IDs that had no mapping in this group.
	result.SkippedCount += (len(keyIDs) - len(mappingsToProcess))
	s.taskService.EndTaskByID(ctx, taskID, resourceID, result, nil)
}
// StartRestoreAllBannedTask collects every banned key in groupID and starts a
// restore task for them via StartRestoreKeysTask. It returns a bad-request API
// error when the group has no banned keys.
func (s *APIKeyService) StartRestoreAllBannedTask(ctx context.Context, groupID uint) (*task.Status, error) {
	var ids []uint
	bannedQuery := s.db.WithContext(ctx).
		Model(&models.GroupAPIKeyMapping{}).
		Where("key_group_id = ? AND status = ?", groupID, models.StatusBanned)
	if err := bannedQuery.Pluck("api_key_id", &ids).Error; err != nil {
		return nil, CustomErrors.ParseDBError(err)
	}
	if len(ids) > 0 {
		return s.StartRestoreKeysTask(ctx, groupID, ids)
	}
	return nil, CustomErrors.NewAPIError(CustomErrors.ErrBadRequest, "No banned keys to restore")
}
// handlePostImportValidation validates the keys just imported into a group and
// applies the result to each (group, key) mapping. A key that passes is set to
// active via UpdateMappingStatus. A key that fails has its error passed to
// judgeKeyErrors; non-APIError failures are wrapped using their message.
//
// Concurrency comes from the group's operational config, falling back to the
// global BaseKeyCheckConcurrency. Non-positive values become 10. No more
// workers are started than there are keys.
func (s *APIKeyService) handlePostImportValidation(event *models.ImportGroupCompletedEvent) {
	ctx := context.Background()
	group, ok := s.groupManager.GetGroupByID(event.GroupID)
	if !ok {
		s.logger.Errorf("Group %d not found for post-import validation", event.GroupID)
		return
	}
	opConfig, err := s.groupManager.BuildOperationalConfig(group)
	if err != nil {
		s.logger.Errorf("Failed to build config for group %d: %v", event.GroupID, err)
		return
	}
	endpoint, err := s.groupManager.BuildKeyCheckEndpoint(event.GroupID)
	if err != nil {
		s.logger.Errorf("Failed to build endpoint for group %d: %v", event.GroupID, err)
		return
	}
	concurrency := s.SettingsManager.GetSettings().BaseKeyCheckConcurrency
	if opConfig.KeyCheckConcurrency != nil {
		concurrency = *opConfig.KeyCheckConcurrency
	}
	if concurrency <= 0 {
		concurrency = 10
	}
	timeout := time.Duration(s.SettingsManager.GetSettings().KeyCheckTimeoutSeconds) * time.Second
	keysToValidate, err := s.keyRepo.GetKeysByIDs(event.KeyIDs)
	if err != nil {
		s.logger.Errorf("Failed to get keys for validation in group %d: %v", event.GroupID, err)
		return
	}
	s.logger.Infof("Validating %d keys for group %d (concurrency: %d)", len(keysToValidate), event.GroupID, concurrency)

	// Extra workers beyond the number of keys would only sit idle until the
	// jobs channel is closed.
	workers := concurrency
	if workers > len(keysToValidate) {
		workers = len(keysToValidate)
	}

	// The buffer holds every key, so the channel can be filled up front without blocking.
	jobs := make(chan models.APIKey, len(keysToValidate))
	for _, key := range keysToValidate {
		jobs <- key
	}
	close(jobs)

	var wg sync.WaitGroup
	for i := 0; i < workers; i++ {
		wg.Add(1)
		go func() {
			defer wg.Done()
			for key := range jobs {
				validationErr := s.validationService.ValidateSingleKey(&key, timeout, endpoint)
				if validationErr == nil {
					if _, err := s.UpdateMappingStatus(ctx, event.GroupID, key.ID, models.StatusActive); err != nil {
						s.logger.Errorf("Failed to activate key %d in group %d: %v", key.ID, event.GroupID, err)
					}
					continue
				}
				var apiErr *CustomErrors.APIError
				if !CustomErrors.As(validationErr, &apiErr) {
					apiErr = &CustomErrors.APIError{Message: validationErr.Error()}
				}
				s.judgeKeyErrors(ctx, "", event.GroupID, key.ID, apiErr, false)
			}
		}()
	}
	wg.Wait()
	s.logger.Infof("Finished post-import validation for group %d", event.GroupID)
}
// StartUpdateStatusByFilterTask starts a background task (30-minute timeout)
// that moves every key in groupID whose status is in sourceStatuses to
// newStatus.
//
// If no key matches, no task is started. Instead, an already-finished status is
// returned whose message reads "no keys matched the filter".
//
// ctx is used only to find the keys and register the task. The background work
// runs on a fresh context, because ctx comes from the caller (possibly a
// request) and may be cancelled as soon as this returns.
func (s *APIKeyService) StartUpdateStatusByFilterTask(ctx context.Context, groupID uint, sourceStatuses []string, newStatus models.APIKeyStatus) (*task.Status, error) {
	keyIDs, err := s.keyRepo.FindKeyIDsByStatus(groupID, sourceStatuses)
	if err != nil {
		return nil, CustomErrors.ParseDBError(err)
	}
	if len(keyIDs) == 0 {
		now := time.Now()
		return &task.Status{
			IsRunning:  false,
			Processed:  0,
			Total:      0,
			Result:     map[string]string{"message": "没有找到符合条件的Key"},
			StartedAt:  now,
			FinishedAt: &now,
		}, nil
	}
	resourceID := fmt.Sprintf("group-%d-status-update", groupID)
	taskStatus, err := s.taskService.StartTask(ctx, groupID, TaskTypeUpdateStatusByFilter, resourceID, len(keyIDs), 30*time.Minute)
	if err != nil {
		return nil, err
	}
	go s.runUpdateStatusByFilterTask(context.Background(), taskStatus.ID, resourceID, groupID, keyIDs, newStatus)
	return taskStatus, nil
}
// runUpdateStatusByFilterTask is the body of a status-update task. It sets each
// mapping of keyIDs in groupID to newStatus, refreshes the polling caches in one
// batch, and ends the task with updated/skipped counts.
//
// A mapping is skipped when its key cannot be loaded, when activation is
// requested for a key whose master status is not active, or when the DB write
// fails. Mappings already at newStatus count as updated without a write.
// Activation also clears the error state. A panic ends the task with an error.
func (s *APIKeyService) runUpdateStatusByFilterTask(ctx context.Context, taskID, resourceID string, groupID uint, keyIDs []uint, newStatus models.APIKeyStatus) {
	defer func() {
		r := recover()
		if r == nil {
			return
		}
		s.logger.Errorf("Panic in status update task: %v", r)
		s.taskService.EndTaskByID(ctx, taskID, resourceID, nil, fmt.Errorf("panic: %v", r))
	}()
	type BatchUpdateResult struct {
		UpdatedCount int `json:"updated_count"`
		SkippedCount int `json:"skipped_count"`
	}

	keys, err := s.keyRepo.GetKeysByIDs(keyIDs)
	if err != nil {
		s.taskService.EndTaskByID(ctx, taskID, resourceID, nil, err)
		return
	}
	masterStatusByKey := make(map[uint]models.MasterAPIKeyStatus, len(keys))
	for i := range keys {
		masterStatusByKey[keys[i].ID] = keys[i].MasterStatus
	}

	var mappings []*models.GroupAPIKeyMapping
	if err := s.db.WithContext(ctx).Where("key_group_id = ? AND api_key_id IN ?", groupID, keyIDs).Find(&mappings).Error; err != nil {
		s.taskService.EndTaskByID(ctx, taskID, resourceID, nil, err)
		return
	}

	result := &BatchUpdateResult{}
	var changed []*models.GroupAPIKeyMapping
	for idx, m := range mappings {
		_ = s.taskService.UpdateProgressByID(ctx, taskID, idx+1)
		master, found := masterStatusByKey[m.APIKeyID]
		if !found || (newStatus == models.StatusActive && master != models.MasterStatusActive) {
			result.SkippedCount++
			continue
		}
		if m.Status == newStatus {
			result.UpdatedCount++
			continue
		}
		previous := m.Status
		m.Status = newStatus
		if newStatus == models.StatusActive {
			m.ConsecutiveErrorCount = 0
			m.LastError = ""
		}
		if writeErr := s.keyRepo.UpdateMappingWithoutCache(m); writeErr != nil {
			result.SkippedCount++
			continue
		}
		result.UpdatedCount++
		changed = append(changed, m)
		go s.publishStatusChangeEvent(ctx, groupID, m.APIKeyID, previous, newStatus, "batch_status_update")
	}
	// IDs that had no mapping in this group.
	result.SkippedCount += len(keyIDs) - len(mappings)

	if err := s.keyRepo.HandleCacheUpdateEventBatch(ctx, changed); err != nil {
		s.logger.WithError(err).Error("Failed batch cache update after status update")
	}
	s.taskService.EndTaskByID(ctx, taskID, resourceID, result, nil)
}
// HandleRequestResult applies the immediate polling-cache side effects of a
// request made with key in group. Both effects run in background goroutines.
//
//   - Success in a weighted-strategy group: the key's usage timestamp is updated.
//   - Failure with a permanent or temporary upstream error: the key is synced
//     to cooldown in the polling caches.
//   - Anything else: no effect.
func (s *APIKeyService) HandleRequestResult(group *models.KeyGroup, key *models.APIKey, success bool, apiErr *CustomErrors.APIError) {
	ctx := context.Background()
	switch {
	case success:
		if group.PollingStrategy == models.StrategyWeighted {
			go s.keyRepo.UpdateKeyUsageTimestamp(ctx, group.ID, key.ID)
		}
	case apiErr != nil:
		msg := apiErr.Message
		if CustomErrors.IsPermanentUpstreamError(msg) || CustomErrors.IsTemporaryUpstreamError(msg) {
			go s.keyRepo.SyncKeyStatusInPollingCaches(ctx, group.ID, key.ID, models.StatusCooldown)
		}
	}
}
// judgeKeyErrors classifies an upstream error for the (groupID, keyID) mapping
// and applies the consequences. LastUsedAt is set in every case, though it is
// not persisted when a permanent error hits an already-banned mapping. apiErr
// must be non-nil.
//
//   - Permanent error: the mapping is banned, and if that write succeeds a
//     status change is published and the master key is revoked.
//   - Temporary error: the error counter goes up. An active mapping moves to
//     cooldown once the counter reaches the blacklist threshold. The threshold
//     is the global setting, overridden by the group's config only for
//     precise routing.
//   - Anything else: only LastUsedAt is persisted.
func (s *APIKeyService) judgeKeyErrors(ctx context.Context, correlationID string, groupID, keyID uint, apiErr *CustomErrors.APIError, isPreciseRouting bool) {
	entry := s.logger.WithFields(logrus.Fields{
		"group_id":       groupID,
		"key_id":         keyID,
		"correlation_id": correlationID,
	})
	mapping, err := s.keyRepo.GetMapping(groupID, keyID)
	if err != nil {
		entry.WithError(err).Warn("Mapping not found, cannot apply error consequences")
		return
	}
	usedAt := time.Now()
	mapping.LastUsedAt = &usedAt
	msg := apiErr.Message

	switch {
	case CustomErrors.IsPermanentUpstreamError(msg):
		entry.Errorf("Permanent error: %s", sanitizeForLog(msg))
		if mapping.Status == models.StatusBanned {
			return
		}
		previous := mapping.Status
		mapping.Status = models.StatusBanned
		mapping.LastError = msg
		if updateErr := s.keyRepo.UpdateMapping(ctx, mapping); updateErr != nil {
			entry.WithError(updateErr).Error("Failed to ban mapping")
			return
		}
		go s.publishStatusChangeEvent(ctx, groupID, keyID, previous, models.StatusBanned, "permanent_error")
		go s.revokeMasterKey(ctx, keyID, "permanent_upstream_error")

	case CustomErrors.IsTemporaryUpstreamError(msg):
		mapping.LastError = msg
		mapping.ConsecutiveErrorCount++
		threshold := s.SettingsManager.GetSettings().BlacklistThreshold
		if isPreciseRouting {
			if group, ok := s.groupManager.GetGroupByID(groupID); ok {
				if opConfig, cfgErr := s.groupManager.BuildOperationalConfig(group); cfgErr == nil && opConfig.KeyBlacklistThreshold != nil {
					threshold = *opConfig.KeyBlacklistThreshold
				}
			}
		}
		entry.Warnf("Temporary error (count: %d, threshold: %d): %s", mapping.ConsecutiveErrorCount, threshold, sanitizeForLog(msg))
		previous := mapping.Status
		next := previous
		if previous == models.StatusActive && mapping.ConsecutiveErrorCount >= threshold {
			next = models.StatusCooldown
			entry.Errorf("Moving to COOLDOWN after reaching threshold %d", threshold)
		}
		mapping.Status = next
		if updateErr := s.keyRepo.UpdateMapping(ctx, mapping); updateErr != nil {
			entry.WithError(updateErr).Error("Failed to update mapping after temporary error")
			return
		}
		if next != previous {
			go s.publishStatusChangeEvent(ctx, groupID, keyID, previous, next, "error_threshold_reached")
		}

	default:
		entry.Infof("Ignorable error: %s", sanitizeForLog(msg))
		if updateErr := s.keyRepo.UpdateMapping(ctx, mapping); updateErr != nil {
			entry.WithError(updateErr).Error("Failed to update LastUsedAt")
		}
	}
}
// revokeMasterKey sets the master status of keyID to revoked and publishes a
// MasterKeyStatusChangedEvent. The event is what triggers banning the key in
// all of its groups (see handleMasterKeyStatusChangeEvent).
//
// Missing keys and keys that are already revoked are left alone. Failures are
// logged instead of discarded. A lost event leaves the key revoked but still
// unbanned in its other groups, so it must at least be visible in the logs.
func (s *APIKeyService) revokeMasterKey(ctx context.Context, keyID uint, reason string) {
	key, err := s.keyRepo.GetKeyByID(keyID)
	if err != nil {
		if errors.Is(err, gorm.ErrRecordNotFound) {
			s.logger.Warnf("Attempted to revoke non-existent key %d", keyID)
		} else {
			s.logger.Errorf("Failed to get key %d for revocation: %v", keyID, err)
		}
		return
	}
	if key.MasterStatus == models.MasterStatusRevoked {
		return
	}
	oldMasterStatus := key.MasterStatus
	if err := s.keyRepo.UpdateAPIKeyStatus(ctx, keyID, models.MasterStatusRevoked); err != nil {
		s.logger.Errorf("Failed to revoke key %d: %v", keyID, err)
		return
	}
	masterKeyEvent := models.MasterKeyStatusChangedEvent{
		KeyID:           keyID,
		OldMasterStatus: oldMasterStatus,
		NewMasterStatus: models.MasterStatusRevoked,
		ChangeReason:    reason,
		ChangedAt:       time.Now(),
	}
	eventData, err := json.Marshal(masterKeyEvent)
	if err != nil {
		s.logger.Errorf("Failed to marshal master key status event for key %d: %v", keyID, err)
		return
	}
	if err := s.store.Publish(ctx, models.TopicMasterKeyStatusChanged, eventData); err != nil {
		s.logger.Errorf("Failed to publish master key status change for key %d: %v", keyID, err)
	}
}
// GetAPIKeyStringsForExport returns the raw key strings in groupID whose status
// matches one of statuses, as loaded by the repository. ctx is not used.
func (s *APIKeyService) GetAPIKeyStringsForExport(ctx context.Context, groupID uint, statuses []string) ([]string, error) {
	keyStrings, err := s.keyRepo.GetKeyStringsByGroupAndStatus(groupID, statuses)
	return keyStrings, err
}
// sanitizeForLog shortens an upstream error message for logging.
//
// Everything from the first '{' onward (presumably an embedded JSON payload) is
// replaced with " {...}". The result is capped at 250 bytes plus a "..." suffix.
// The cut is moved back to the nearest UTF-8 rune boundary, so a multi-byte
// character is never split into invalid UTF-8.
func sanitizeForLog(errMsg string) string {
	const maxLen = 250
	if idx := strings.Index(errMsg, "{"); idx != -1 {
		errMsg = strings.TrimSpace(errMsg[:idx]) + " {...}"
	}
	if len(errMsg) <= maxLen {
		return errMsg
	}
	// Ranging over a string yields the byte offset where each rune starts; keep
	// the last such boundary that does not exceed maxLen.
	cut := 0
	for i := range errMsg {
		if i > maxLen {
			break
		}
		cut = i
	}
	return errMsg[:cut] + "..."
}