Fix requestTimeout & memory store

XOF
2025-11-24 00:09:55 +08:00
parent 6c7283d51b
commit 3a95a07e8a
5 changed files with 328 additions and 441 deletions

View File

@@ -92,7 +92,6 @@ func BuildContainer() (*dig.Container, error) {
// =========== Phase 3: Adapters & Handlers Layer ===========
// Provide dependencies for Channel (Logger and the *models.SystemSettings data socket)
container.Provide(channel.NewGeminiChannel)
container.Provide(func(ch *channel.GeminiChannel) channel.ChannelProxy { return ch })

View File

@@ -197,8 +197,8 @@ func (h *ProxyHandler) executeProxyAttempt(c *gin.Context, corrID string, body [
var attemptErr *errors.APIError
var isSuccess bool
connectTimeout := time.Duration(h.settingsManager.GetSettings().ConnectTimeoutSeconds) * time.Second
ctx, cancel := context.WithTimeout(c.Request.Context(), connectTimeout)
requestTimeout := time.Duration(h.settingsManager.GetSettings().RequestTimeoutSeconds) * time.Second
ctx, cancel := context.WithTimeout(c.Request.Context(), requestTimeout)
defer cancel()
attemptReq := c.Request.Clone(ctx)
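
The hunk above now bounds the whole proxy attempt with RequestTimeoutSeconds instead of the connect timeout. A minimal standalone sketch of that pattern, with illustrative names rather than the project's types:

package main

import (
	"context"
	"fmt"
	"net/http"
	"time"
)

func main() {
	req, _ := http.NewRequest(http.MethodGet, "https://example.invalid/v1beta/models", nil)

	// The request timeout bounds the entire attempt (connect, headers, body),
	// which is what the cloned request's context should carry; a connect
	// timeout belongs on the dialer/transport instead.
	requestTimeout := 30 * time.Second
	ctx, cancel := context.WithTimeout(req.Context(), requestTimeout)
	defer cancel()

	attemptReq := req.Clone(ctx)
	deadline, ok := attemptReq.Context().Deadline()
	fmt.Println("attempt deadline set:", ok, deadline)
}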

View File

@@ -29,10 +29,6 @@ const (
TaskTypeUpdateStatusByFilter = "update_status_by_filter"
)
const (
TestEndpoint = "https://generativelanguage.googleapis.com/v1beta/models"
)
type BatchRestoreResult struct {
RestoredCount int `json:"restored_count"`
SkippedCount int `json:"skipped_count"`
@@ -52,12 +48,6 @@ type PaginatedAPIKeys struct {
TotalPages int `json:"total_pages"`
}
type KeyTestResult struct {
Key string `json:"key"`
Status string `json:"status"`
Message string `json:"message"`
}
type APIKeyService struct {
db *gorm.DB
keyRepo repository.KeyRepository
@@ -99,37 +89,39 @@ func NewAPIKeyService(
func (s *APIKeyService) Start() {
requestSub, err := s.store.Subscribe(context.Background(), models.TopicRequestFinished)
if err != nil {
s.logger.Fatalf("Failed to subscribe to topic %s: %v", models.TopicRequestFinished, err)
s.logger.Fatalf("Failed to subscribe to %s: %v", models.TopicRequestFinished, err)
return
}
masterKeySub, err := s.store.Subscribe(context.Background(), models.TopicMasterKeyStatusChanged)
if err != nil {
s.logger.Fatalf("Failed to subscribe to topic %s: %v", models.TopicMasterKeyStatusChanged, err)
s.logger.Fatalf("Failed to subscribe to %s: %v", models.TopicMasterKeyStatusChanged, err)
return
}
keyStatusSub, err := s.store.Subscribe(context.Background(), models.TopicKeyStatusChanged)
if err != nil {
s.logger.Fatalf("Failed to subscribe to topic %s: %v", models.TopicKeyStatusChanged, err)
s.logger.Fatalf("Failed to subscribe to %s: %v", models.TopicKeyStatusChanged, err)
return
}
importSub, err := s.store.Subscribe(context.Background(), models.TopicImportGroupCompleted)
if err != nil {
s.logger.Fatalf("Failed to subscribe to topic %s: %v", models.TopicImportGroupCompleted, err)
s.logger.Fatalf("Failed to subscribe to %s: %v", models.TopicImportGroupCompleted, err)
return
}
s.logger.Info("Started and subscribed to request, master key, health check, and import events.")
s.logger.Info("Started and subscribed to all event topics")
go func() {
defer requestSub.Close()
defer masterKeySub.Close()
defer keyStatusSub.Close()
defer importSub.Close()
for {
select {
case msg := <-requestSub.Channel():
var event models.RequestFinishedEvent
if err := json.Unmarshal(msg.Payload, &event); err != nil {
s.logger.Errorf("Failed to unmarshal event for key status update: %v", err)
s.logger.WithError(err).Error("Failed to unmarshal RequestFinishedEvent")
continue
}
s.handleKeyUsageEvent(&event)
@@ -137,14 +129,15 @@ func (s *APIKeyService) Start() {
case msg := <-masterKeySub.Channel():
var event models.MasterKeyStatusChangedEvent
if err := json.Unmarshal(msg.Payload, &event); err != nil {
s.logger.Errorf("Failed to unmarshal MasterKeyStatusChangedEvent: %v", err)
s.logger.WithError(err).Error("Failed to unmarshal MasterKeyStatusChangedEvent")
continue
}
s.handleMasterKeyStatusChangeEvent(&event)
case msg := <-keyStatusSub.Channel():
var event models.KeyStatusChangedEvent
if err := json.Unmarshal(msg.Payload, &event); err != nil {
s.logger.Errorf("Failed to unmarshal KeyStatusChangedEvent: %v", err)
s.logger.WithError(err).Error("Failed to unmarshal KeyStatusChangedEvent")
continue
}
s.handleKeyStatusChangeEvent(&event)
@@ -152,15 +145,14 @@ func (s *APIKeyService) Start() {
case msg := <-importSub.Channel():
var event models.ImportGroupCompletedEvent
if err := json.Unmarshal(msg.Payload, &event); err != nil {
s.logger.WithError(err).Error("Failed to unmarshal ImportGroupCompletedEvent.")
s.logger.WithError(err).Error("Failed to unmarshal ImportGroupCompletedEvent")
continue
}
s.logger.Infof("Received ImportGroupCompletedEvent for group %d, triggering validation for %d keys.", event.GroupID, len(event.KeyIDs))
s.logger.Infof("Received import completion for group %d, validating %d keys", event.GroupID, len(event.KeyIDs))
go s.handlePostImportValidation(&event)
case <-s.stopChan:
s.logger.Info("Stopping event listener.")
s.logger.Info("Stopping event listener")
return
}
}
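
Start() fans four store subscriptions into one select loop. A minimal sketch of the same shape using plain channels, since the store's Subscribe and message types aren't reproduced here:

package main

import "fmt"

func main() {
	// Unbuffered stand-ins for the store subscriptions in Start(); the real
	// code receives store messages and json.Unmarshals their payloads.
	requestEvents := make(chan []byte)
	importEvents := make(chan []byte)
	stop := make(chan struct{})
	done := make(chan struct{})

	go func() {
		defer close(done)
		for {
			select {
			case payload := <-requestEvents:
				fmt.Println("request finished:", string(payload))
			case payload := <-importEvents:
				fmt.Println("import completed:", string(payload))
			case <-stop:
				fmt.Println("stopping event listener")
				return
			}
		}
	}()

	requestEvents <- []byte(`{"correlation_id":"demo"}`)
	importEvents <- []byte(`{"group_id":1,"key_ids":[2,3]}`)
	close(stop)
	<-done
}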
@@ -175,13 +167,18 @@ func (s *APIKeyService) handleKeyUsageEvent(event *models.RequestFinishedEvent)
if event.RequestLog.KeyID == nil || event.RequestLog.GroupID == nil {
return
}
ctx := context.Background()
groupID := *event.RequestLog.GroupID
keyID := *event.RequestLog.KeyID
if event.RequestLog.IsSuccess {
mapping, err := s.keyRepo.GetMapping(*event.RequestLog.GroupID, *event.RequestLog.KeyID)
mapping, err := s.keyRepo.GetMapping(groupID, keyID)
if err != nil {
s.logger.Warnf("[%s] Could not find mapping for G:%d K:%d on successful request: %v", event.CorrelationID, *event.RequestLog.GroupID, *event.RequestLog.KeyID, err)
s.logger.Warnf("[%s] Mapping not found for G:%d K:%d: %v", event.CorrelationID, groupID, keyID, err)
return
}
statusChanged := false
oldStatus := mapping.Status
if mapping.Status != models.StatusActive {
@@ -193,38 +190,33 @@ func (s *APIKeyService) handleKeyUsageEvent(event *models.RequestFinishedEvent)
now := time.Now()
mapping.LastUsedAt = &now
if err := s.keyRepo.UpdateMapping(ctx, mapping); err != nil {
s.logger.Errorf("[%s] Failed to update mapping for G:%d K:%d after successful request: %v", event.CorrelationID, *event.RequestLog.GroupID, *event.RequestLog.KeyID, err)
s.logger.Errorf("[%s] Failed to update mapping for G:%d K:%d: %v", event.CorrelationID, groupID, keyID, err)
return
}
if statusChanged {
go s.publishStatusChangeEvent(ctx, *event.RequestLog.GroupID, *event.RequestLog.KeyID, oldStatus, models.StatusActive, "key_recovered_after_use")
go s.publishStatusChangeEvent(ctx, groupID, keyID, oldStatus, models.StatusActive, "key_recovered_after_use")
}
return
}
if event.Error != nil {
s.judgeKeyErrors(
ctx,
event.CorrelationID,
*event.RequestLog.GroupID,
*event.RequestLog.KeyID,
event.Error,
event.IsPreciseRouting,
)
s.judgeKeyErrors(ctx, event.CorrelationID, groupID, keyID, event.Error, event.IsPreciseRouting)
}
}
func (s *APIKeyService) handleKeyStatusChangeEvent(event *models.KeyStatusChangedEvent) {
ctx := context.Background()
log := s.logger.WithFields(logrus.Fields{
s.logger.WithFields(logrus.Fields{
"group_id": event.GroupID,
"key_id": event.KeyID,
"new_status": event.NewStatus,
"reason": event.ChangeReason,
})
log.Info("Received KeyStatusChangedEvent, will update polling caches.")
}).Info("Updating polling caches based on status change")
s.keyRepo.HandleCacheUpdateEvent(ctx, event.GroupID, event.KeyID, event.NewStatus)
log.Info("Polling caches updated based on health check event.")
}
func (s *APIKeyService) publishStatusChangeEvent(ctx context.Context, groupID, keyID uint, oldStatus, newStatus models.APIKeyStatus, reason string) {
@@ -238,7 +230,7 @@ func (s *APIKeyService) publishStatusChangeEvent(ctx context.Context, groupID, k
}
eventData, _ := json.Marshal(changeEvent)
if err := s.store.Publish(ctx, models.TopicKeyStatusChanged, eventData); err != nil {
s.logger.Errorf("Failed to publish key status changed event for group %d: %v", groupID, err)
s.logger.Errorf("Failed to publish status change event for group %d: %v", groupID, err)
}
}
@@ -248,10 +240,12 @@ func (s *APIKeyService) ListAPIKeys(ctx context.Context, params *models.APIKeyQu
if err != nil {
return nil, err
}
totalPages := 0
if total > 0 && params.PageSize > 0 {
totalPages = int(math.Ceil(float64(total) / float64(params.PageSize)))
}
return &PaginatedAPIKeys{
Items: items,
Total: total,
@@ -260,34 +254,44 @@ func (s *APIKeyService) ListAPIKeys(ctx context.Context, params *models.APIKeyQu
TotalPages: totalPages,
}, nil
}
s.logger.Infof("Performing heavy in-memory search for group %d with keyword '%s'", params.KeyGroupID, params.Keyword)
var statusesToFilter []string
s.logger.Infof("Performing in-memory search for group %d with keyword '%s'", params.KeyGroupID, params.Keyword)
statusesToFilter := []string{"all"}
if params.Status != "" {
statusesToFilter = append(statusesToFilter, params.Status)
} else {
statusesToFilter = append(statusesToFilter, "all")
statusesToFilter = []string{params.Status}
}
allKeyIDs, err := s.keyRepo.FindKeyIDsByStatus(params.KeyGroupID, statusesToFilter)
if err != nil {
return nil, fmt.Errorf("failed to fetch all key IDs for search: %w", err)
return nil, fmt.Errorf("failed to fetch key IDs: %w", err)
}
if len(allKeyIDs) == 0 {
return &PaginatedAPIKeys{Items: []*models.APIKeyDetails{}, Total: 0, Page: 1, PageSize: params.PageSize, TotalPages: 0}, nil
return &PaginatedAPIKeys{
Items: []*models.APIKeyDetails{},
Total: 0,
Page: 1,
PageSize: params.PageSize,
TotalPages: 0,
}, nil
}
allKeys, err := s.keyRepo.GetKeysByIDs(allKeyIDs)
if err != nil {
return nil, fmt.Errorf("failed to fetch all key details for in-memory search: %w", err)
return nil, fmt.Errorf("failed to fetch keys: %w", err)
}
var allMappings []models.GroupAPIKeyMapping
err = s.db.WithContext(ctx).Where("key_group_id = ? AND api_key_id IN ?", params.KeyGroupID, allKeyIDs).Find(&allMappings).Error
if err != nil {
return nil, fmt.Errorf("failed to fetch mappings for in-memory search: %w", err)
if err := s.db.WithContext(ctx).Where("key_group_id = ? AND api_key_id IN ?", params.KeyGroupID, allKeyIDs).Find(&allMappings).Error; err != nil {
return nil, fmt.Errorf("failed to fetch mappings: %w", err)
}
mappingMap := make(map[uint]*models.GroupAPIKeyMapping)
mappingMap := make(map[uint]*models.GroupAPIKeyMapping, len(allMappings))
for i := range allMappings {
mappingMap[allMappings[i].APIKeyID] = &allMappings[i]
}
var filteredItems []*models.APIKeyDetails
for _, key := range allKeys {
if strings.Contains(key.APIKey, params.Keyword) {
@@ -307,12 +311,15 @@ func (s *APIKeyService) ListAPIKeys(ctx context.Context, params *models.APIKeyQu
}
}
}
sort.Slice(filteredItems, func(i, j int) bool {
return filteredItems[i].ID > filteredItems[j].ID
})
total := int64(len(filteredItems))
start := (params.Page - 1) * params.PageSize
end := start + params.PageSize
if start < 0 {
start = 0
}
@@ -328,9 +335,9 @@ func (s *APIKeyService) ListAPIKeys(ctx context.Context, params *models.APIKeyQu
if end > len(filteredItems) {
end = len(filteredItems)
}
paginatedItems := filteredItems[start:end]
return &PaginatedAPIKeys{
Items: paginatedItems,
Items: filteredItems[start:end],
Total: total,
Page: params.Page,
PageSize: params.PageSize,
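
The in-memory search path above sorts the filtered items and slices them with clamped bounds. A generic version of that clamp, shown only for illustration:

package pagination

// paginate returns the page-th window of size pageSize from items,
// clamping the bounds so an out-of-range page yields an empty slice
// rather than a panic.
func paginate[T any](items []T, page, pageSize int) []T {
	start := (page - 1) * pageSize
	if start < 0 {
		start = 0
	}
	if start > len(items) {
		start = len(items)
	}
	end := start + pageSize
	if end > len(items) {
		end = len(items)
	}
	return items[start:end]
}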
@@ -344,15 +351,8 @@ func (s *APIKeyService) GetKeysByIds(ctx context.Context, ids []uint) ([]models.
func (s *APIKeyService) UpdateAPIKey(ctx context.Context, key *models.APIKey) error {
go func() {
bgCtx := context.Background()
var oldKey models.APIKey
if err := s.db.WithContext(bgCtx).First(&oldKey, key.ID).Error; err != nil {
s.logger.Errorf("Failed to find old key state for ID %d before update: %v", key.ID, err)
return
}
if err := s.keyRepo.Update(key); err != nil {
s.logger.Errorf("Failed to asynchronously update key ID %d: %v", key.ID, err)
return
s.logger.Errorf("Failed to update key ID %d: %v", key.ID, err)
}
}()
return nil
@@ -361,22 +361,24 @@ func (s *APIKeyService) UpdateAPIKey(ctx context.Context, key *models.APIKey) er
func (s *APIKeyService) HardDeleteAPIKeyByID(ctx context.Context, id uint) error {
groups, err := s.keyRepo.GetGroupsForKey(ctx, id)
if err != nil {
s.logger.Warnf("Could not get groups for key ID %d before hard deletion (key might be orphaned): %v", id, err)
s.logger.Warnf("Could not get groups for key %d before deletion: %v", id, err)
}
err = s.keyRepo.HardDeleteByID(id)
if err == nil {
for _, groupID := range groups {
event := models.KeyStatusChangedEvent{
KeyID: id,
GroupID: groupID,
ChangeReason: "key_hard_deleted",
}
eventData, _ := json.Marshal(event)
go s.store.Publish(context.Background(), models.TopicKeyStatusChanged, eventData)
}
if err := s.keyRepo.HardDeleteByID(id); err != nil {
return err
}
return err
for _, groupID := range groups {
event := models.KeyStatusChangedEvent{
KeyID: id,
GroupID: groupID,
ChangeReason: "key_hard_deleted",
}
eventData, _ := json.Marshal(event)
go s.store.Publish(context.Background(), models.TopicKeyStatusChanged, eventData)
}
return nil
}
func (s *APIKeyService) UpdateMappingStatus(ctx context.Context, groupID, keyID uint, newStatus models.APIKeyStatus) (*models.GroupAPIKeyMapping, error) {
@@ -388,47 +390,54 @@ func (s *APIKeyService) UpdateMappingStatus(ctx context.Context, groupID, keyID
if key.MasterStatus != models.MasterStatusActive && newStatus == models.StatusActive {
return nil, CustomErrors.ErrStateConflictMasterRevoked
}
mapping, err := s.keyRepo.GetMapping(groupID, keyID)
if err != nil {
return nil, err
}
oldStatus := mapping.Status
if oldStatus == newStatus {
return mapping, nil
}
mapping.Status = newStatus
if newStatus == models.StatusActive {
mapping.ConsecutiveErrorCount = 0
mapping.LastError = ""
}
if err := s.keyRepo.UpdateMapping(ctx, mapping); err != nil {
return nil, err
}
go s.publishStatusChangeEvent(ctx, groupID, keyID, oldStatus, newStatus, "manual_update")
return mapping, nil
}
func (s *APIKeyService) handleMasterKeyStatusChangeEvent(event *models.MasterKeyStatusChangedEvent) {
ctx := context.Background()
s.logger.Infof("Received MasterKeyStatusChangedEvent for Key ID %d: %s -> %s", event.KeyID, event.OldMasterStatus, event.NewMasterStatus)
if event.NewMasterStatus != models.MasterStatusRevoked {
return
}
ctx := context.Background()
s.logger.Infof("Key %d revoked, propagating to all groups", event.KeyID)
affectedGroupIDs, err := s.keyRepo.GetGroupsForKey(ctx, event.KeyID)
if err != nil {
s.logger.WithError(err).Errorf("Failed to get groups for key ID %d to propagate revoked status.", event.KeyID)
s.logger.WithError(err).Errorf("Failed to get groups for key %d", event.KeyID)
return
}
if len(affectedGroupIDs) == 0 {
s.logger.Infof("Key ID %d is revoked, but it's not associated with any group. No action needed.", event.KeyID)
s.logger.Infof("Key %d not associated with any group", event.KeyID)
return
}
s.logger.Infof("Propagating REVOKED status for Key ID %d to %d groups.", event.KeyID, len(affectedGroupIDs))
for _, groupID := range affectedGroupIDs {
_, err := s.UpdateMappingStatus(ctx, groupID, event.KeyID, models.StatusBanned)
if err != nil {
if _, err := s.UpdateMappingStatus(ctx, groupID, event.KeyID, models.StatusBanned); err != nil {
if !errors.Is(err, gorm.ErrRecordNotFound) {
s.logger.WithError(err).Errorf("Failed to update mapping status to BANNED for Key ID %d in Group ID %d.", event.KeyID, groupID)
s.logger.WithError(err).Errorf("Failed to ban key %d in group %d", event.KeyID, groupID)
}
}
}
@@ -436,59 +445,71 @@ func (s *APIKeyService) handleMasterKeyStatusChangeEvent(event *models.MasterKey
func (s *APIKeyService) StartRestoreKeysTask(ctx context.Context, groupID uint, keyIDs []uint) (*task.Status, error) {
if len(keyIDs) == 0 {
return nil, CustomErrors.NewAPIError(CustomErrors.ErrBadRequest, "No key IDs provided for restoration.")
return nil, CustomErrors.NewAPIError(CustomErrors.ErrBadRequest, "No key IDs provided")
}
resourceID := fmt.Sprintf("group-%d", groupID)
taskStatus, err := s.taskService.StartTask(ctx, groupID, TaskTypeRestoreSpecificKeys, resourceID, len(keyIDs), time.Hour)
if err != nil {
return nil, err
}
go s.runRestoreKeysTask(ctx, taskStatus.ID, resourceID, groupID, keyIDs)
return taskStatus, nil
}
func (s *APIKeyService) runRestoreKeysTask(ctx context.Context, taskID string, resourceID string, groupID uint, keyIDs []uint) {
func (s *APIKeyService) runRestoreKeysTask(ctx context.Context, taskID, resourceID string, groupID uint, keyIDs []uint) {
defer func() {
if r := recover(); r != nil {
s.logger.Errorf("Panic recovered in runRestoreKeysTask: %v", r)
s.taskService.EndTaskByID(ctx, taskID, resourceID, nil, fmt.Errorf("panic in restore keys task: %v", r))
s.logger.Errorf("Panic in restore task: %v", r)
s.taskService.EndTaskByID(ctx, taskID, resourceID, nil, fmt.Errorf("panic: %v", r))
}
}()
var mappingsToProcess []models.GroupAPIKeyMapping
err := s.db.WithContext(ctx).Preload("APIKey").
if err := s.db.WithContext(ctx).Preload("APIKey").
Where("key_group_id = ? AND api_key_id IN ?", groupID, keyIDs).
Find(&mappingsToProcess).Error
if err != nil {
Find(&mappingsToProcess).Error; err != nil {
s.taskService.EndTaskByID(ctx, taskID, resourceID, nil, err)
return
}
result := &BatchRestoreResult{
SkippedKeys: make([]SkippedKeyInfo, 0),
}
result := &BatchRestoreResult{SkippedKeys: make([]SkippedKeyInfo, 0)}
var successfulMappings []*models.GroupAPIKeyMapping
processedCount := 0
for _, mapping := range mappingsToProcess {
processedCount++
_ = s.taskService.UpdateProgressByID(ctx, taskID, processedCount)
for i, mapping := range mappingsToProcess {
_ = s.taskService.UpdateProgressByID(ctx, taskID, i+1)
if mapping.APIKey == nil {
result.SkippedCount++
result.SkippedKeys = append(result.SkippedKeys, SkippedKeyInfo{KeyID: mapping.APIKeyID, Reason: "Associated APIKey entity not found."})
result.SkippedKeys = append(result.SkippedKeys, SkippedKeyInfo{
KeyID: mapping.APIKeyID,
Reason: "APIKey not found",
})
continue
}
if mapping.APIKey.MasterStatus != models.MasterStatusActive {
result.SkippedCount++
result.SkippedKeys = append(result.SkippedKeys, SkippedKeyInfo{KeyID: mapping.APIKeyID, Reason: fmt.Sprintf("Master status is '%s'.", mapping.APIKey.MasterStatus)})
result.SkippedKeys = append(result.SkippedKeys, SkippedKeyInfo{
KeyID: mapping.APIKeyID,
Reason: fmt.Sprintf("Master status is %s", mapping.APIKey.MasterStatus),
})
continue
}
oldStatus := mapping.Status
if oldStatus != models.StatusActive {
mapping.Status = models.StatusActive
mapping.ConsecutiveErrorCount = 0
mapping.LastError = ""
if err := s.keyRepo.UpdateMappingWithoutCache(&mapping); err != nil {
s.logger.WithError(err).Warn("Failed to update mapping in DB during restore task.")
result.SkippedCount++
result.SkippedKeys = append(result.SkippedKeys, SkippedKeyInfo{KeyID: mapping.APIKeyID, Reason: "DB update failed."})
result.SkippedKeys = append(result.SkippedKeys, SkippedKeyInfo{
KeyID: mapping.APIKeyID,
Reason: "DB update failed",
})
} else {
result.RestoredCount++
successfulMappings = append(successfulMappings, &mapping)
@@ -498,61 +519,72 @@ func (s *APIKeyService) runRestoreKeysTask(ctx context.Context, taskID string, r
result.RestoredCount++
}
}
if err := s.keyRepo.HandleCacheUpdateEventBatch(ctx, successfulMappings); err != nil {
s.logger.WithError(err).Error("Failed to perform batch cache update after restore task.")
s.logger.WithError(err).Error("Failed batch cache update after restore")
}
result.SkippedCount += (len(keyIDs) - len(mappingsToProcess))
s.taskService.EndTaskByID(ctx, taskID, resourceID, result, nil)
}
func (s *APIKeyService) StartRestoreAllBannedTask(ctx context.Context, groupID uint) (*task.Status, error) {
var bannedKeyIDs []uint
err := s.db.WithContext(ctx).Model(&models.GroupAPIKeyMapping{}).
if err := s.db.WithContext(ctx).Model(&models.GroupAPIKeyMapping{}).
Where("key_group_id = ? AND status = ?", groupID, models.StatusBanned).
Pluck("api_key_id", &bannedKeyIDs).Error
if err != nil {
Pluck("api_key_id", &bannedKeyIDs).Error; err != nil {
return nil, CustomErrors.ParseDBError(err)
}
if len(bannedKeyIDs) == 0 {
return nil, CustomErrors.NewAPIError(CustomErrors.ErrBadRequest, "No banned keys to restore in this group.")
return nil, CustomErrors.NewAPIError(CustomErrors.ErrBadRequest, "No banned keys to restore")
}
return s.StartRestoreKeysTask(ctx, groupID, bannedKeyIDs)
}
func (s *APIKeyService) handlePostImportValidation(event *models.ImportGroupCompletedEvent) {
ctx := context.Background()
group, ok := s.groupManager.GetGroupByID(event.GroupID)
if !ok {
s.logger.Errorf("Group with id %d not found during post-import validation, aborting.", event.GroupID)
s.logger.Errorf("Group %d not found for post-import validation", event.GroupID)
return
}
opConfig, err := s.groupManager.BuildOperationalConfig(group)
if err != nil {
s.logger.Errorf("Failed to build operational config for group %d, aborting validation: %v", event.GroupID, err)
s.logger.Errorf("Failed to build config for group %d: %v", event.GroupID, err)
return
}
endpoint, err := s.groupManager.BuildKeyCheckEndpoint(event.GroupID)
if err != nil {
s.logger.Errorf("Failed to build key check endpoint for group %d, aborting validation: %v", event.GroupID, err)
s.logger.Errorf("Failed to build endpoint for group %d: %v", event.GroupID, err)
return
}
globalSettings := s.SettingsManager.GetSettings()
concurrency := globalSettings.BaseKeyCheckConcurrency
concurrency := s.SettingsManager.GetSettings().BaseKeyCheckConcurrency
if opConfig.KeyCheckConcurrency != nil {
concurrency = *opConfig.KeyCheckConcurrency
}
if concurrency <= 0 {
concurrency = 10
}
timeout := time.Duration(globalSettings.KeyCheckTimeoutSeconds) * time.Second
timeout := time.Duration(s.SettingsManager.GetSettings().KeyCheckTimeoutSeconds) * time.Second
keysToValidate, err := s.keyRepo.GetKeysByIDs(event.KeyIDs)
if err != nil {
s.logger.Errorf("Failed to get key models for validation in group %d: %v", event.GroupID, err)
s.logger.Errorf("Failed to get keys for validation in group %d: %v", event.GroupID, err)
return
}
s.logger.Infof("Validating %d keys for group %d with concurrency %d against endpoint %s", len(keysToValidate), event.GroupID, concurrency, endpoint)
s.logger.Infof("Validating %d keys for group %d (concurrency: %d)", len(keysToValidate), event.GroupID, concurrency)
var wg sync.WaitGroup
jobs := make(chan models.APIKey, len(keysToValidate))
for i := 0; i < concurrency; i++ {
wg.Add(1)
go func() {
@@ -560,9 +592,8 @@ func (s *APIKeyService) handlePostImportValidation(event *models.ImportGroupComp
for key := range jobs {
validationErr := s.validationService.ValidateSingleKey(&key, timeout, endpoint)
if validationErr == nil {
s.logger.Infof("Key ID %d PASSED validation. Status -> ACTIVE.", key.ID)
if _, err := s.UpdateMappingStatus(ctx, event.GroupID, key.ID, models.StatusActive); err != nil {
s.logger.Errorf("Failed to update status to ACTIVE for Key ID %d in group %d: %v", key.ID, event.GroupID, err)
s.logger.Errorf("Failed to activate key %d in group %d: %v", key.ID, event.GroupID, err)
}
} else {
var apiErr *CustomErrors.APIError
@@ -574,35 +605,34 @@ func (s *APIKeyService) handlePostImportValidation(event *models.ImportGroupComp
}
}()
}
for _, key := range keysToValidate {
jobs <- key
}
close(jobs)
wg.Wait()
s.logger.Infof("Finished post-import validation for group %d.", event.GroupID)
s.logger.Infof("Finished post-import validation for group %d", event.GroupID)
}
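
Post-import validation uses a bounded worker pool: a jobs channel sized to the key count, N workers, and a WaitGroup. The same skeleton in isolation, with a print standing in for ValidateSingleKey:

package main

import (
	"fmt"
	"sync"
)

func main() {
	keys := []string{"k1", "k2", "k3", "k4"}
	concurrency := 2

	jobs := make(chan string, len(keys))
	var wg sync.WaitGroup
	for i := 0; i < concurrency; i++ {
		wg.Add(1)
		go func() {
			defer wg.Done()
			for key := range jobs {
				// The real worker would call the validation service here.
				fmt.Println("validated", key)
			}
		}()
	}
	for _, k := range keys {
		jobs <- k
	}
	close(jobs)
	wg.Wait()
}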
func (s *APIKeyService) StartUpdateStatusByFilterTask(ctx context.Context, groupID uint, sourceStatuses []string, newStatus models.APIKeyStatus) (*task.Status, error) {
s.logger.Infof("Starting task to update status to '%s' for keys in group %d matching statuses: %v", newStatus, groupID, sourceStatuses)
keyIDs, err := s.keyRepo.FindKeyIDsByStatus(groupID, sourceStatuses)
if err != nil {
return nil, CustomErrors.ParseDBError(err)
}
if len(keyIDs) == 0 {
now := time.Now()
return &task.Status{
IsRunning: false,
Processed: 0,
Total: 0,
Result: map[string]string{
"message": "没有找到任何符合当前过滤条件的Key可供操作。",
},
Error: "",
IsRunning: false,
Processed: 0,
Total: 0,
Result: map[string]string{"message": "没有找到符合条件的Key"},
StartedAt: now,
FinishedAt: &now,
}, nil
}
resourceID := fmt.Sprintf("group-%d-status-update", groupID)
taskStatus, err := s.taskService.StartTask(ctx, groupID, TaskTypeUpdateStatusByFilter, resourceID, len(keyIDs), 30*time.Minute)
if err != nil {
@@ -616,10 +646,11 @@ func (s *APIKeyService) StartUpdateStatusByFilterTask(ctx context.Context, group
func (s *APIKeyService) runUpdateStatusByFilterTask(ctx context.Context, taskID, resourceID string, groupID uint, keyIDs []uint, newStatus models.APIKeyStatus) {
defer func() {
if r := recover(); r != nil {
s.logger.Errorf("Panic recovered in runUpdateStatusByFilterTask: %v", r)
s.taskService.EndTaskByID(ctx, taskID, resourceID, nil, fmt.Errorf("panic in update status task: %v", r))
s.logger.Errorf("Panic in status update task: %v", r)
s.taskService.EndTaskByID(ctx, taskID, resourceID, nil, fmt.Errorf("panic: %v", r))
}
}()
type BatchUpdateResult struct {
UpdatedCount int `json:"updated_count"`
SkippedCount int `json:"skipped_count"`
@@ -629,33 +660,35 @@ func (s *APIKeyService) runUpdateStatusByFilterTask(ctx context.Context, taskID,
keys, err := s.keyRepo.GetKeysByIDs(keyIDs)
if err != nil {
s.logger.WithError(err).Error("Failed to fetch keys for bulk status update, aborting task.")
s.taskService.EndTaskByID(ctx, taskID, resourceID, nil, err)
return
}
masterStatusMap := make(map[uint]models.MasterAPIKeyStatus)
masterStatusMap := make(map[uint]models.MasterAPIKeyStatus, len(keys))
for _, key := range keys {
masterStatusMap[key.ID] = key.MasterStatus
}
var mappings []*models.GroupAPIKeyMapping
if err := s.db.WithContext(ctx).Where("key_group_id = ? AND api_key_id IN ?", groupID, keyIDs).Find(&mappings).Error; err != nil {
s.logger.WithError(err).Error("Failed to fetch mappings for bulk status update, aborting task.")
s.taskService.EndTaskByID(ctx, taskID, resourceID, nil, err)
return
}
processedCount := 0
for _, mapping := range mappings {
processedCount++
_ = s.taskService.UpdateProgressByID(ctx, taskID, processedCount)
for i, mapping := range mappings {
_ = s.taskService.UpdateProgressByID(ctx, taskID, i+1)
masterStatus, ok := masterStatusMap[mapping.APIKeyID]
if !ok {
result.SkippedCount++
continue
}
if masterStatus != models.MasterStatusActive && newStatus == models.StatusActive {
result.SkippedCount++
continue
}
oldStatus := mapping.Status
if oldStatus != newStatus {
mapping.Status = newStatus
@@ -663,6 +696,7 @@ func (s *APIKeyService) runUpdateStatusByFilterTask(ctx context.Context, taskID,
mapping.ConsecutiveErrorCount = 0
mapping.LastError = ""
}
if err := s.keyRepo.UpdateMappingWithoutCache(mapping); err != nil {
result.SkippedCount++
} else {
@@ -674,122 +708,112 @@ func (s *APIKeyService) runUpdateStatusByFilterTask(ctx context.Context, taskID,
result.UpdatedCount++
}
}
result.SkippedCount += (len(keyIDs) - len(mappings))
if err := s.keyRepo.HandleCacheUpdateEventBatch(ctx, successfulMappings); err != nil {
s.logger.WithError(err).Error("Failed to perform batch cache update after status update task.")
s.logger.WithError(err).Error("Failed batch cache update after status update")
}
s.logger.Infof("Finished bulk status update task '%s' for group %d. Updated: %d, Skipped: %d.", taskID, groupID, result.UpdatedCount, result.SkippedCount)
s.taskService.EndTaskByID(ctx, taskID, resourceID, result, nil)
}
func (s *APIKeyService) HandleRequestResult(group *models.KeyGroup, key *models.APIKey, success bool, apiErr *CustomErrors.APIError) {
ctx := context.Background()
if success {
if group.PollingStrategy == models.StrategyWeighted {
go s.keyRepo.UpdateKeyUsageTimestamp(ctx, group.ID, key.ID)
}
return
}
if apiErr == nil {
s.logger.Warnf("Request failed for KeyID %d in GroupID %d but no specific API error was provided. No action taken.", key.ID, group.ID)
return
}
errMsg := apiErr.Message
if CustomErrors.IsPermanentUpstreamError(errMsg) || CustomErrors.IsTemporaryUpstreamError(errMsg) {
s.logger.Warnf("Request for KeyID %d in GroupID %d failed with a key-specific error: %s. Issuing order to move to cooldown.", key.ID, group.ID, errMsg)
go s.keyRepo.SyncKeyStatusInPollingCaches(ctx, group.ID, key.ID, models.StatusCooldown)
} else {
s.logger.Infof("Request for KeyID %d in GroupID %d failed with a non-key-specific error. No punitive action taken. Error: %s", key.ID, group.ID, errMsg)
}
}
func sanitizeForLog(errMsg string) string {
jsonStartIndex := strings.Index(errMsg, "{")
var cleanMsg string
if jsonStartIndex != -1 {
cleanMsg = strings.TrimSpace(errMsg[:jsonStartIndex]) + " {...}"
} else {
cleanMsg = errMsg
}
const maxLen = 250
if len(cleanMsg) > maxLen {
return cleanMsg[:maxLen] + "..."
}
return cleanMsg
}
func (s *APIKeyService) judgeKeyErrors(ctx context.Context, correlationID string, groupID, keyID uint, apiErr *CustomErrors.APIError, isPreciseRouting bool) {
logger := s.logger.WithFields(logrus.Fields{
"group_id": groupID,
"key_id": keyID,
"correlation_id": correlationID,
})
func (s *APIKeyService) judgeKeyErrors(
ctx context.Context,
correlationID string,
groupID, keyID uint,
apiErr *CustomErrors.APIError,
isPreciseRouting bool,
) {
logger := s.logger.WithFields(logrus.Fields{"group_id": groupID, "key_id": keyID, "correlation_id": correlationID})
mapping, err := s.keyRepo.GetMapping(groupID, keyID)
if err != nil {
logger.WithError(err).Warn("Cannot apply consequences for error: mapping not found.")
logger.WithError(err).Warn("Mapping not found, cannot apply error consequences")
return
}
now := time.Now()
mapping.LastUsedAt = &now
errorMessage := apiErr.Message
if CustomErrors.IsPermanentUpstreamError(errorMessage) {
logger.Errorf("Permanent error detected. Banning mapping and revoking master key. Reason: %s", sanitizeForLog(errorMessage))
logger.WithField("full_error_details", errorMessage).Debug("Full details of the permanent error.")
logger.Errorf("Permanent error: %s", sanitizeForLog(errorMessage))
if mapping.Status != models.StatusBanned {
oldStatus := mapping.Status
mapping.Status = models.StatusBanned
mapping.LastError = errorMessage
if err := s.keyRepo.UpdateMapping(ctx, mapping); err != nil {
logger.WithError(err).Error("Failed to update mapping status to BANNED.")
logger.WithError(err).Error("Failed to ban mapping")
} else {
go s.publishStatusChangeEvent(ctx, groupID, keyID, oldStatus, mapping.Status, "permanent_error_banned")
go s.publishStatusChangeEvent(ctx, groupID, keyID, oldStatus, models.StatusBanned, "permanent_error")
go s.revokeMasterKey(ctx, keyID, "permanent_upstream_error")
}
}
return
}
if CustomErrors.IsTemporaryUpstreamError(errorMessage) {
mapping.LastError = errorMessage
mapping.ConsecutiveErrorCount++
var threshold int
threshold := s.SettingsManager.GetSettings().BlacklistThreshold
if isPreciseRouting {
group, ok := s.groupManager.GetGroupByID(groupID)
opConfig, err := s.groupManager.BuildOperationalConfig(group)
if !ok || err != nil {
logger.Warnf("Could not build operational config for group %d in Precise Routing mode. Falling back to global settings.", groupID)
threshold = s.SettingsManager.GetSettings().BlacklistThreshold
} else {
threshold = *opConfig.KeyBlacklistThreshold
if group, ok := s.groupManager.GetGroupByID(groupID); ok {
if opConfig, err := s.groupManager.BuildOperationalConfig(group); err == nil && opConfig.KeyBlacklistThreshold != nil {
threshold = *opConfig.KeyBlacklistThreshold
}
}
} else {
threshold = s.SettingsManager.GetSettings().BlacklistThreshold
}
logger.Warnf("Temporary error detected. Incrementing error count. New count: %d (Threshold: %d). Reason: %s", mapping.ConsecutiveErrorCount, threshold, sanitizeForLog(errorMessage))
logger.WithField("full_error_details", errorMessage).Debug("Full details of the temporary error.")
logger.Warnf("Temporary error (count: %d, threshold: %d): %s", mapping.ConsecutiveErrorCount, threshold, sanitizeForLog(errorMessage))
oldStatus := mapping.Status
newStatus := oldStatus
if mapping.ConsecutiveErrorCount >= threshold && oldStatus == models.StatusActive {
newStatus = models.StatusCooldown
logger.Errorf("Putting mapping into COOLDOWN due to reaching temporary error threshold (%d)", threshold)
logger.Errorf("Moving to COOLDOWN after reaching threshold %d", threshold)
}
if oldStatus != newStatus {
mapping.Status = newStatus
}
if err := s.keyRepo.UpdateMapping(ctx, mapping); err != nil {
logger.WithError(err).Error("Failed to update mapping after temporary error.")
logger.WithError(err).Error("Failed to update mapping after temporary error")
return
}
if oldStatus != newStatus {
go s.publishStatusChangeEvent(ctx, groupID, keyID, oldStatus, newStatus, "error_threshold_reached")
}
return
}
logger.Infof("Ignored truly ignorable upstream error. Only updating LastUsedAt. Reason: %s", sanitizeForLog(errorMessage))
logger.WithField("full_error_details", errorMessage).Debug("Full details of the ignorable error.")
logger.Infof("Ignorable error: %s", sanitizeForLog(errorMessage))
if err := s.keyRepo.UpdateMapping(ctx, mapping); err != nil {
logger.WithError(err).Error("Failed to update LastUsedAt for ignorable error.")
logger.WithError(err).Error("Failed to update LastUsedAt")
}
}
@@ -797,25 +821,27 @@ func (s *APIKeyService) revokeMasterKey(ctx context.Context, keyID uint, reason
key, err := s.keyRepo.GetKeyByID(keyID)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
s.logger.Warnf("Attempted to revoke non-existent key ID %d.", keyID)
s.logger.Warnf("Attempted to revoke non-existent key %d", keyID)
} else {
s.logger.Errorf("Failed to get key by ID %d for master status revocation: %v", keyID, err)
s.logger.Errorf("Failed to get key %d for revocation: %v", keyID, err)
}
return
}
if key.MasterStatus == models.MasterStatusRevoked {
return
}
oldMasterStatus := key.MasterStatus
newMasterStatus := models.MasterStatusRevoked
if err := s.keyRepo.UpdateAPIKeyStatus(ctx, keyID, newMasterStatus); err != nil {
s.logger.Errorf("Failed to update master status to REVOKED for key ID %d: %v", keyID, err)
if err := s.keyRepo.UpdateAPIKeyStatus(ctx, keyID, models.MasterStatusRevoked); err != nil {
s.logger.Errorf("Failed to revoke key %d: %v", keyID, err)
return
}
masterKeyEvent := models.MasterKeyStatusChangedEvent{
KeyID: keyID,
OldMasterStatus: oldMasterStatus,
NewMasterStatus: newMasterStatus,
NewMasterStatus: models.MasterStatusRevoked,
ChangeReason: reason,
ChangedAt: time.Now(),
}
@@ -826,3 +852,13 @@ func (s *APIKeyService) revokeMasterKey(ctx context.Context, keyID uint, reason
func (s *APIKeyService) GetAPIKeyStringsForExport(ctx context.Context, groupID uint, statuses []string) ([]string, error) {
return s.keyRepo.GetKeyStringsByGroupAndStatus(groupID, statuses)
}
func sanitizeForLog(errMsg string) string {
if idx := strings.Index(errMsg, "{"); idx != -1 {
errMsg = strings.TrimSpace(errMsg[:idx]) + " {...}"
}
if len(errMsg) > 250 {
return errMsg[:250] + "..."
}
return errMsg
}
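
sanitizeForLog drops everything after the first JSON brace and caps the result at 250 characters; a quick illustration (expected output in the trailing comment):

msg := `googleapi: Error 429: {"error":{"code":429,"status":"RESOURCE_EXHAUSTED"}}`
fmt.Println(sanitizeForLog(msg)) // googleapi: Error 429: {...}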

View File

@@ -1,5 +1,4 @@
// Filename: internal/service/group_manager.go (Syncer upgrade)
// Filename: internal/service/group_manager.go
package service
import (
@@ -29,10 +28,9 @@ type GroupManagerCacheData struct {
Groups []*models.KeyGroup
GroupsByName map[string]*models.KeyGroup
GroupsByID map[uint]*models.KeyGroup
KeyCounts map[uint]int64 // GroupID -> Total Key Count
KeyStatusCounts map[uint]map[models.APIKeyStatus]int64 // GroupID -> Status -> Count
KeyCounts map[uint]int64
KeyStatusCounts map[uint]map[models.APIKeyStatus]int64
}
type GroupManager struct {
db *gorm.DB
keyRepo repository.KeyRepository
@@ -41,7 +39,6 @@ type GroupManager struct {
syncer *syncer.CacheSyncer[GroupManagerCacheData]
logger *logrus.Entry
}
type UpdateOrderPayload struct {
ID uint `json:"id" binding:"required"`
Order int `json:"order"`
@@ -49,43 +46,19 @@ type UpdateOrderPayload struct {
func NewGroupManagerLoader(db *gorm.DB, logger *logrus.Logger) syncer.LoaderFunc[GroupManagerCacheData] {
return func() (GroupManagerCacheData, error) {
logger.Debugf("[GML-LOG 1/5] ---> Entering NewGroupManagerLoader...")
var groups []*models.KeyGroup
logger.Debugf("[GML-LOG 2/5] About to execute DB query with Preloads...")
if err := db.Preload("AllowedUpstreams").
Preload("AllowedModels").
Preload("Settings").
Preload("RequestConfig").
Preload("Mappings").
Find(&groups).Error; err != nil {
logger.Errorf("[GML-LOG] CRITICAL: DB query for groups failed: %v", err)
return GroupManagerCacheData{}, fmt.Errorf("failed to load key groups for cache: %w", err)
return GroupManagerCacheData{}, fmt.Errorf("failed to load groups: %w", err)
}
logger.Debugf("[GML-LOG 2.1/5] DB query for groups finished. Found %d group records.", len(groups))
var allMappings []*models.GroupAPIKeyMapping
if err := db.Find(&allMappings).Error; err != nil {
logger.Errorf("[GML-LOG] CRITICAL: DB query for mappings failed: %v", err)
return GroupManagerCacheData{}, fmt.Errorf("failed to load key mappings for cache: %w", err)
}
logger.Debugf("[GML-LOG 2.2/5] DB query for mappings finished. Found %d total mapping records.", len(allMappings))
mappingsByGroupID := make(map[uint][]*models.GroupAPIKeyMapping)
for i := range allMappings {
mapping := allMappings[i] // Avoid pointer issues with range
mappingsByGroupID[mapping.KeyGroupID] = append(mappingsByGroupID[mapping.KeyGroupID], mapping)
}
for _, group := range groups {
if mappings, ok := mappingsByGroupID[group.ID]; ok {
group.Mappings = mappings
}
}
logger.Debugf("[GML-LOG 3/5] Finished manually associating mappings to groups.")
keyCounts := make(map[uint]int64, len(groups))
keyStatusCounts := make(map[uint]map[models.APIKeyStatus]int64, len(groups))
groupsByName := make(map[string]*models.KeyGroup, len(groups))
groupsByID := make(map[uint]*models.KeyGroup, len(groups))
for _, group := range groups {
keyCounts[group.ID] = int64(len(group.Mappings))
statusCounts := make(map[models.APIKeyStatus]int64)
@@ -93,20 +66,9 @@ func NewGroupManagerLoader(db *gorm.DB, logger *logrus.Logger) syncer.LoaderFunc
statusCounts[mapping.Status]++
}
keyStatusCounts[group.ID] = statusCounts
groupsByName[group.Name] = group
groupsByID[group.ID] = group
}
groupsByName := make(map[string]*models.KeyGroup, len(groups))
groupsByID := make(map[uint]*models.KeyGroup, len(groups))
logger.Debugf("[GML-LOG 4/5] Starting to process group records into maps...")
for i, group := range groups {
if group == nil {
logger.Debugf("[GML] CRITICAL: Found a 'nil' group pointer at index %d! This is the most likely cause of the panic.", i)
} else {
groupsByName[group.Name] = group
groupsByID[group.ID] = group
}
}
logger.Debugf("[GML-LOG 5/5] Finished processing records. Building final cache data...")
return GroupManagerCacheData{
Groups: groups,
GroupsByName: groupsByName,
@@ -116,7 +78,6 @@ func NewGroupManagerLoader(db *gorm.DB, logger *logrus.Logger) syncer.LoaderFunc
}, nil
}
}
func NewGroupManager(
db *gorm.DB,
keyRepo repository.KeyRepository,
@@ -134,138 +95,67 @@ func NewGroupManager(
logger: logger.WithField("component", "GroupManager"),
}
}
func (gm *GroupManager) GetAllGroups() []*models.KeyGroup {
cache := gm.syncer.Get()
if len(cache.Groups) == 0 {
return []*models.KeyGroup{}
}
groupsToOrder := cache.Groups
sort.Slice(groupsToOrder, func(i, j int) bool {
if groupsToOrder[i].Order != groupsToOrder[j].Order {
return groupsToOrder[i].Order < groupsToOrder[j].Order
groups := gm.syncer.Get().Groups
sort.Slice(groups, func(i, j int) bool {
if groups[i].Order != groups[j].Order {
return groups[i].Order < groups[j].Order
}
return groupsToOrder[i].ID < groupsToOrder[j].ID
return groups[i].ID < groups[j].ID
})
return groupsToOrder
return groups
}
func (gm *GroupManager) GetKeyCount(groupID uint) int64 {
cache := gm.syncer.Get()
if len(cache.KeyCounts) == 0 {
return 0
}
count := cache.KeyCounts[groupID]
return count
return gm.syncer.Get().KeyCounts[groupID]
}
func (gm *GroupManager) GetKeyStatusCount(groupID uint) map[models.APIKeyStatus]int64 {
cache := gm.syncer.Get()
if len(cache.KeyStatusCounts) == 0 {
return make(map[models.APIKeyStatus]int64)
}
if counts, ok := cache.KeyStatusCounts[groupID]; ok {
if counts, ok := gm.syncer.Get().KeyStatusCounts[groupID]; ok {
return counts
}
return make(map[models.APIKeyStatus]int64)
}
func (gm *GroupManager) GetGroupByName(name string) (*models.KeyGroup, bool) {
cache := gm.syncer.Get()
if len(cache.GroupsByName) == 0 {
return nil, false
}
group, ok := cache.GroupsByName[name]
group, ok := gm.syncer.Get().GroupsByName[name]
return group, ok
}
func (gm *GroupManager) GetGroupByID(id uint) (*models.KeyGroup, bool) {
cache := gm.syncer.Get()
if len(cache.GroupsByID) == 0 {
return nil, false
}
group, ok := cache.GroupsByID[id]
group, ok := gm.syncer.Get().GroupsByID[id]
return group, ok
}
func (gm *GroupManager) Stop() {
gm.syncer.Stop()
}
func (gm *GroupManager) Invalidate() error {
return gm.syncer.Invalidate()
}
// --- Write Operations ---
// CreateKeyGroup creates a new key group, including its operational settings, and invalidates the cache.
func (gm *GroupManager) CreateKeyGroup(group *models.KeyGroup, settings *models.KeyGroupSettings) error {
if !utils.IsValidGroupName(group.Name) {
return errors.New("invalid group name: must contain only lowercase letters, numbers, and hyphens")
}
err := gm.db.Transaction(func(tx *gorm.DB) error {
// 1. Create the group itself to get an ID
if err := tx.Create(group).Error; err != nil {
return err
}
// 2. If settings are provided, create the associated GroupSettings record
if settings != nil {
// Only marshal non-nil fields to keep the JSON clean
settingsToMarshal := make(map[string]interface{})
if settings.EnableKeyCheck != nil {
settingsToMarshal["enable_key_check"] = settings.EnableKeyCheck
settingsJSON, err := json.Marshal(settings)
if err != nil {
return fmt.Errorf("failed to marshal settings: %w", err)
}
if settings.KeyCheckIntervalMinutes != nil {
settingsToMarshal["key_check_interval_minutes"] = settings.KeyCheckIntervalMinutes
groupSettings := models.GroupSettings{
GroupID: group.ID,
SettingsJSON: datatypes.JSON(settingsJSON),
}
if settings.KeyBlacklistThreshold != nil {
settingsToMarshal["key_blacklist_threshold"] = settings.KeyBlacklistThreshold
}
if settings.KeyCooldownMinutes != nil {
settingsToMarshal["key_cooldown_minutes"] = settings.KeyCooldownMinutes
}
if settings.KeyCheckConcurrency != nil {
settingsToMarshal["key_check_concurrency"] = settings.KeyCheckConcurrency
}
if settings.KeyCheckEndpoint != nil {
settingsToMarshal["key_check_endpoint"] = settings.KeyCheckEndpoint
}
if settings.KeyCheckModel != nil {
settingsToMarshal["key_check_model"] = settings.KeyCheckModel
}
if settings.MaxRetries != nil {
settingsToMarshal["max_retries"] = settings.MaxRetries
}
if settings.EnableSmartGateway != nil {
settingsToMarshal["enable_smart_gateway"] = settings.EnableSmartGateway
}
if len(settingsToMarshal) > 0 {
settingsJSON, err := json.Marshal(settingsToMarshal)
if err != nil {
return fmt.Errorf("failed to marshal group settings: %w", err)
}
groupSettings := models.GroupSettings{
GroupID: group.ID,
SettingsJSON: datatypes.JSON(settingsJSON),
}
if err := tx.Create(&groupSettings).Error; err != nil {
return fmt.Errorf("failed to save group settings: %w", err)
}
if err := tx.Create(&groupSettings).Error; err != nil {
return err
}
}
return nil
})
if err != nil {
return err
if err == nil {
go gm.Invalidate()
}
go gm.Invalidate()
return nil
return err
}
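
CreateKeyGroup now marshals the whole settings struct instead of copying non-nil fields into a map one by one. That keeps the stored JSON clean only if nil pointer fields are omitted during marshaling; a standalone illustration of the omitempty behavior this relies on, with made-up field names (the project's actual struct tags aren't shown in this diff):

package main

import (
	"encoding/json"
	"fmt"
)

type settings struct {
	EnableKeyCheck      *bool `json:"enable_key_check,omitempty"`
	KeyCheckConcurrency *int  `json:"key_check_concurrency,omitempty"`
}

func main() {
	n := 10
	// Nil pointer fields tagged omitempty disappear from the output.
	b, _ := json.Marshal(settings{KeyCheckConcurrency: &n})
	fmt.Println(string(b)) // {"key_check_concurrency":10}
}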
// UpdateKeyGroup updates an existing key group, its settings, and associations, then invalidates the cache.
func (gm *GroupManager) UpdateKeyGroup(group *models.KeyGroup, newSettings *models.KeyGroupSettings, upstreamURLs []string, modelNames []string) error {
if !utils.IsValidGroupName(group.Name) {
return fmt.Errorf("invalid group name: must contain only lowercase letters, numbers, and hyphens")
@@ -273,7 +163,6 @@ func (gm *GroupManager) UpdateKeyGroup(group *models.KeyGroup, newSettings *mode
uniqueUpstreamURLs := uniqueStrings(upstreamURLs)
uniqueModelNames := uniqueStrings(modelNames)
err := gm.db.Transaction(func(tx *gorm.DB) error {
// --- 1. Update AllowedUpstreams (M:N relationship) ---
var upstreams []*models.UpstreamEndpoint
if len(uniqueUpstreamURLs) > 0 {
if err := tx.Where("url IN ?", uniqueUpstreamURLs).Find(&upstreams).Error; err != nil {
@@ -283,7 +172,6 @@ func (gm *GroupManager) UpdateKeyGroup(group *models.KeyGroup, newSettings *mode
if err := tx.Model(group).Association("AllowedUpstreams").Replace(upstreams); err != nil {
return err
}
if err := tx.Model(group).Association("AllowedModels").Clear(); err != nil {
return err
}
@@ -296,11 +184,9 @@ func (gm *GroupManager) UpdateKeyGroup(group *models.KeyGroup, newSettings *mode
return err
}
}
if err := tx.Model(group).Updates(group).Error; err != nil {
return err
}
var existingSettings models.GroupSettings
if err := tx.Where("group_id = ?", group.ID).First(&existingSettings).Error; err != nil && err != gorm.ErrRecordNotFound {
return err
@@ -308,15 +194,15 @@ func (gm *GroupManager) UpdateKeyGroup(group *models.KeyGroup, newSettings *mode
var currentSettingsData models.KeyGroupSettings
if len(existingSettings.SettingsJSON) > 0 {
if err := json.Unmarshal(existingSettings.SettingsJSON, &currentSettingsData); err != nil {
return fmt.Errorf("failed to unmarshal existing group settings: %w", err)
return fmt.Errorf("failed to unmarshal existing settings: %w", err)
}
}
if err := reflectutil.MergeNilFields(&currentSettingsData, newSettings); err != nil {
return fmt.Errorf("failed to merge group settings: %w", err)
return fmt.Errorf("failed to merge settings: %w", err)
}
updatedJSON, err := json.Marshal(currentSettingsData)
if err != nil {
return fmt.Errorf("failed to marshal updated group settings: %w", err)
return fmt.Errorf("failed to marshal updated settings: %w", err)
}
existingSettings.GroupID = group.ID
existingSettings.SettingsJSON = datatypes.JSON(updatedJSON)
@@ -327,55 +213,25 @@ func (gm *GroupManager) UpdateKeyGroup(group *models.KeyGroup, newSettings *mode
}
return err
}
// DeleteKeyGroup deletes a key group and subsequently cleans up any keys that have become orphans.
func (gm *GroupManager) DeleteKeyGroup(id uint) error {
err := gm.db.Transaction(func(tx *gorm.DB) error {
gm.logger.Infof("Starting transaction to delete KeyGroup ID: %d", id)
// Step 1: First, retrieve the group object we are about to delete.
var group models.KeyGroup
if err := tx.First(&group, id).Error; err != nil {
if err == gorm.ErrRecordNotFound {
gm.logger.Warnf("Attempted to delete a non-existent KeyGroup with ID: %d", id)
return nil // Don't treat as an error, the group is already gone.
return nil
}
gm.logger.WithError(err).Errorf("Failed to find KeyGroup with ID: %d for deletion", id)
return err
}
// Step 2: Clear all many-to-many and one-to-many associations using GORM's safe methods.
if err := tx.Model(&group).Association("AllowedUpstreams").Clear(); err != nil {
gm.logger.WithError(err).Errorf("Failed to clear 'AllowedUpstreams' association for KeyGroup ID: %d", id)
if err := tx.Select("AllowedUpstreams", "AllowedModels", "Mappings", "Settings").Delete(&group).Error; err != nil {
return err
}
if err := tx.Model(&group).Association("AllowedModels").Clear(); err != nil {
gm.logger.WithError(err).Errorf("Failed to clear 'AllowedModels' association for KeyGroup ID: %d", id)
return err
}
if err := tx.Model(&group).Association("Mappings").Clear(); err != nil {
gm.logger.WithError(err).Errorf("Failed to clear 'Mappings' (API Key associations) for KeyGroup ID: %d", id)
return err
}
// Also clear settings if they exist to maintain data integrity
if err := tx.Model(&group).Association("Settings").Delete(group.Settings); err != nil {
gm.logger.WithError(err).Errorf("Failed to delete 'Settings' association for KeyGroup ID: %d", id)
return err
}
// Step 3: Delete the KeyGroup itself.
if err := tx.Delete(&group).Error; err != nil {
gm.logger.WithError(err).Errorf("Failed to delete KeyGroup ID: %d", id)
return err
}
gm.logger.Infof("KeyGroup ID %d associations cleared and entity deleted. Triggering orphan key cleanup.", id)
// Step 4: Trigger the orphan key cleanup (this logic remains the same and is correct).
deletedCount, err := gm.keyRepo.DeleteOrphanKeysTx(tx)
if err != nil {
gm.logger.WithError(err).Error("Failed to clean up orphan keys after deleting group.")
return err
}
if deletedCount > 0 {
gm.logger.Infof("Successfully cleaned up %d orphan keys.", deletedCount)
gm.logger.Infof("Cleaned up %d orphan keys after deleting group %d", deletedCount, id)
}
gm.logger.Infof("Transaction for deleting KeyGroup ID: %d completed successfully.", id)
return nil
})
if err == nil {
@@ -383,7 +239,6 @@ func (gm *GroupManager) DeleteKeyGroup(id uint) error {
}
return err
}
func (gm *GroupManager) CloneKeyGroup(id uint) (*models.KeyGroup, error) {
var originalGroup models.KeyGroup
if err := gm.db.
@@ -392,7 +247,7 @@ func (gm *GroupManager) CloneKeyGroup(id uint) (*models.KeyGroup, error) {
Preload("AllowedUpstreams").
Preload("AllowedModels").
First(&originalGroup, id).Error; err != nil {
return nil, fmt.Errorf("failed to find original group with id %d: %w", id, err)
return nil, fmt.Errorf("failed to find original group %d: %w", id, err)
}
newGroup := originalGroup
timestamp := time.Now().Unix()
@@ -401,31 +256,25 @@ func (gm *GroupManager) CloneKeyGroup(id uint) (*models.KeyGroup, error) {
newGroup.DisplayName = fmt.Sprintf("%s-clone-%d", originalGroup.DisplayName, timestamp)
newGroup.CreatedAt = time.Time{}
newGroup.UpdatedAt = time.Time{}
newGroup.RequestConfigID = nil
newGroup.RequestConfig = nil
newGroup.Mappings = nil
newGroup.AllowedUpstreams = nil
newGroup.AllowedModels = nil
err := gm.db.Transaction(func(tx *gorm.DB) error {
if err := tx.Create(&newGroup).Error; err != nil {
return err
}
if originalGroup.RequestConfig != nil {
newRequestConfig := *originalGroup.RequestConfig
newRequestConfig.ID = 0 // Mark as new record
newRequestConfig.ID = 0
if err := tx.Create(&newRequestConfig).Error; err != nil {
return fmt.Errorf("failed to clone request config: %w", err)
}
if err := tx.Model(&newGroup).Update("request_config_id", newRequestConfig.ID).Error; err != nil {
return fmt.Errorf("failed to link new group to cloned request config: %w", err)
return fmt.Errorf("failed to link cloned request config: %w", err)
}
}
var originalSettings models.GroupSettings
err := tx.Where("group_id = ?", originalGroup.ID).First(&originalSettings).Error
if err == nil && len(originalSettings.SettingsJSON) > 0 {
@@ -434,12 +283,11 @@ func (gm *GroupManager) CloneKeyGroup(id uint) (*models.KeyGroup, error) {
SettingsJSON: originalSettings.SettingsJSON,
}
if err := tx.Create(&newSettings).Error; err != nil {
return fmt.Errorf("failed to clone group settings: %w", err)
return fmt.Errorf("failed to clone settings: %w", err)
}
} else if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) {
return fmt.Errorf("failed to query original group settings: %w", err)
return fmt.Errorf("failed to query original settings: %w", err)
}
if len(originalGroup.Mappings) > 0 {
newMappings := make([]models.GroupAPIKeyMapping, len(originalGroup.Mappings))
for i, oldMapping := range originalGroup.Mappings {
@@ -454,7 +302,7 @@ func (gm *GroupManager) CloneKeyGroup(id uint) (*models.KeyGroup, error) {
}
}
if err := tx.Create(&newMappings).Error; err != nil {
return fmt.Errorf("failed to clone key group mappings: %w", err)
return fmt.Errorf("failed to clone mappings: %w", err)
}
}
if len(originalGroup.AllowedUpstreams) > 0 {
@@ -469,13 +317,10 @@ func (gm *GroupManager) CloneKeyGroup(id uint) (*models.KeyGroup, error) {
}
return nil
})
if err != nil {
return nil, err
}
go gm.Invalidate()
var finalClonedGroup models.KeyGroup
if err := gm.db.
Preload("RequestConfig").
@@ -487,10 +332,9 @@ func (gm *GroupManager) CloneKeyGroup(id uint) (*models.KeyGroup, error) {
}
return &finalClonedGroup, nil
}
func (gm *GroupManager) BuildOperationalConfig(group *models.KeyGroup) (*models.KeyGroupSettings, error) {
globalSettings := gm.settingsManager.GetSettings()
s := "gemini-1.5-flash" // Per user feedback for default model
defaultModel := "gemini-1.5-flash"
opConfig := &models.KeyGroupSettings{
EnableKeyCheck: &globalSettings.EnableBaseKeyCheck,
KeyCheckConcurrency: &globalSettings.BaseKeyCheckConcurrency,
@@ -498,52 +342,43 @@ func (gm *GroupManager) BuildOperationalConfig(group *models.KeyGroup) (*models.
KeyCheckEndpoint: &globalSettings.DefaultUpstreamURL,
KeyBlacklistThreshold: &globalSettings.BlacklistThreshold,
KeyCooldownMinutes: &globalSettings.KeyCooldownMinutes,
KeyCheckModel: &s,
KeyCheckModel: &defaultModel,
MaxRetries: &globalSettings.MaxRetries,
EnableSmartGateway: &globalSettings.EnableSmartGateway,
}
if group == nil {
return opConfig, nil
}
var groupSettingsRecord models.GroupSettings
err := gm.db.Where("group_id = ?", group.ID).First(&groupSettingsRecord).Error
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return opConfig, nil
}
gm.logger.WithError(err).Errorf("Failed to query group settings for group ID %d", group.ID)
return nil, err
}
if len(groupSettingsRecord.SettingsJSON) == 0 {
return opConfig, nil
}
var groupSpecificSettings models.KeyGroupSettings
if err := json.Unmarshal(groupSettingsRecord.SettingsJSON, &groupSpecificSettings); err != nil {
gm.logger.WithError(err).WithField("group_id", group.ID).Warn("Failed to unmarshal group settings JSON.")
return opConfig, err
}
if err := reflectutil.MergeNilFields(opConfig, &groupSpecificSettings); err != nil {
gm.logger.WithError(err).WithField("group_id", group.ID).Error("Failed to merge group-specific settings over defaults.")
gm.logger.WithError(err).WithField("group_id", group.ID).Warn("Failed to unmarshal group settings")
return opConfig, nil
}
if err := reflectutil.MergeNilFields(opConfig, &groupSpecificSettings); err != nil {
gm.logger.WithError(err).WithField("group_id", group.ID).Error("Failed to merge group settings")
return opConfig, nil
}
return opConfig, nil
}
func (gm *GroupManager) BuildKeyCheckEndpoint(groupID uint) (string, error) {
group, ok := gm.GetGroupByID(groupID)
if !ok {
return "", fmt.Errorf("group with id %d not found", groupID)
return "", fmt.Errorf("group %d not found", groupID)
}
opConfig, err := gm.BuildOperationalConfig(group)
if err != nil {
return "", fmt.Errorf("failed to build operational config for group %d: %w", groupID, err)
return "", err
}
globalSettings := gm.settingsManager.GetSettings()
baseURL := globalSettings.DefaultUpstreamURL
@@ -551,7 +386,7 @@ func (gm *GroupManager) BuildKeyCheckEndpoint(groupID uint) (string, error) {
baseURL = *opConfig.KeyCheckEndpoint
}
if baseURL == "" {
return "", fmt.Errorf("no key check endpoint or default upstream URL is configured for group %d", groupID)
return "", fmt.Errorf("no endpoint configured for group %d", groupID)
}
modelName := globalSettings.BaseKeyCheckModel
if opConfig.KeyCheckModel != nil && *opConfig.KeyCheckModel != "" {
@@ -559,38 +394,31 @@ func (gm *GroupManager) BuildKeyCheckEndpoint(groupID uint) (string, error) {
}
parsedURL, err := url.Parse(baseURL)
if err != nil {
return "", fmt.Errorf("failed to parse base URL '%s': %w", baseURL, err)
return "", fmt.Errorf("invalid URL '%s': %w", baseURL, err)
}
cleanedPath := parsedURL.Path
cleanedPath = strings.TrimSuffix(cleanedPath, "/")
cleanedPath = strings.TrimSuffix(cleanedPath, "/v1beta")
parsedURL.Path = path.Join(cleanedPath, "v1beta", "models", modelName)
finalEndpoint := parsedURL.String()
return finalEndpoint, nil
cleanedPath := strings.TrimSuffix(strings.TrimSuffix(parsedURL.Path, "/"), "/v1beta")
parsedURL.Path = path.Join(cleanedPath, "v1beta/models", modelName)
return parsedURL.String(), nil
}
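
The rewritten path handling trims a trailing slash and a trailing /v1beta before re-joining, so different base-URL shapes resolve to the same key-check endpoint. A small self-contained example of that logic:

package main

import (
	"fmt"
	"net/url"
	"path"
	"strings"
)

func buildEndpoint(baseURL, model string) string {
	u, err := url.Parse(baseURL)
	if err != nil {
		return ""
	}
	cleaned := strings.TrimSuffix(strings.TrimSuffix(u.Path, "/"), "/v1beta")
	u.Path = path.Join(cleaned, "v1beta/models", model)
	return u.String()
}

func main() {
	for _, base := range []string{
		"https://generativelanguage.googleapis.com",
		"https://generativelanguage.googleapis.com/v1beta/",
	} {
		// Both shapes resolve to .../v1beta/models/gemini-1.5-flash.
		fmt.Println(buildEndpoint(base, "gemini-1.5-flash"))
	}
}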
func (gm *GroupManager) UpdateOrder(payload []UpdateOrderPayload) error {
ordersMap := make(map[uint]int, len(payload))
for _, item := range payload {
ordersMap[item.ID] = item.Order
}
if err := gm.groupRepo.UpdateOrderInTransaction(ordersMap); err != nil {
gm.logger.WithError(err).Error("Failed to update group order in transaction")
return fmt.Errorf("database transaction failed: %w", err)
return fmt.Errorf("failed to update order: %w", err)
}
gm.logger.Info("Group order updated successfully, invalidating cache...")
go gm.Invalidate()
return nil
}
func uniqueStrings(slice []string) []string {
keys := make(map[string]struct{})
list := []string{}
for _, entry := range slice {
if _, value := keys[entry]; !value {
keys[entry] = struct{}{}
list = append(list, entry)
seen := make(map[string]struct{}, len(slice))
result := make([]string, 0, len(slice))
for _, s := range slice {
if _, exists := seen[s]; !exists {
seen[s] = struct{}{}
result = append(result, s)
}
}
return list
return result
}

View File

@@ -704,15 +704,17 @@ func (p *memoryPipeliner) LRem(key string, count int64, value any) {
newList := make([]string, 0, len(list))
removed := int64(0)
for _, v := range list {
if count != 0 && v == capturedValue && (count < 0 || removed < count) {
shouldRemove := v == capturedValue && (count == 0 || removed < count)
if shouldRemove {
removed++
continue
} else {
newList = append(newList, v)
}
newList = append(newList, v)
}
item.value = newList
})
}
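
For reference, Redis LREM interprets count > 0 as remove-from-head, count < 0 as remove-from-tail, and count == 0 as remove-all. A standalone sketch of that full rule over a plain []string, separate from the pipeliner above:

// lrem mirrors Redis LREM semantics on a plain slice:
// count > 0: remove up to count matches scanning from the head;
// count < 0: remove up to -count matches scanning from the tail;
// count == 0: remove every match.
func lrem(list []string, count int64, value string) []string {
	limit := count
	if limit < 0 {
		limit = -limit
	}
	var indexes []int
	if count >= 0 {
		for i, v := range list {
			if v == value && (count == 0 || int64(len(indexes)) < limit) {
				indexes = append(indexes, i)
			}
		}
	} else {
		for i := len(list) - 1; i >= 0; i-- {
			if list[i] == value && int64(len(indexes)) < limit {
				indexes = append(indexes, i)
			}
		}
	}
	drop := make(map[int]struct{}, len(indexes))
	for _, i := range indexes {
		drop[i] = struct{}{}
	}
	out := make([]string, 0, len(list))
	for i, v := range list {
		if _, ok := drop[i]; !ok {
			out = append(out, v)
		}
	}
	return out
}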
func (p *memoryPipeliner) HSet(key string, values map[string]any) {
capturedKey := key
capturedValues := make(map[string]any, len(values))
@@ -762,17 +764,31 @@ func (p *memoryPipeliner) ZAdd(key string, members map[string]float64) {
p.ops = append(p.ops, func() {
item, ok := p.store.items[capturedKey]
if !ok || item.isExpired() {
item = &memoryStoreItem{value: make(map[string]float64)}
item = &memoryStoreItem{value: make([]zsetMember, 0)}
p.store.items[capturedKey] = item
}
zset, ok := item.value.(map[string]float64)
zset, ok := item.value.([]zsetMember)
if !ok {
zset = make(map[string]float64)
item.value = zset
zset = make([]zsetMember, 0)
}
for member, score := range capturedMembers {
zset[member] = score
membersMap := make(map[string]float64, len(zset))
for _, z := range zset {
membersMap[z.Value] = z.Score
}
for memberVal, score := range capturedMembers {
membersMap[memberVal] = score
}
newZSet := make([]zsetMember, 0, len(membersMap))
for val, score := range membersMap {
newZSet = append(newZSet, zsetMember{Value: val, Score: score})
}
sort.Slice(newZSet, func(i, j int) bool {
if newZSet[i].Score == newZSet[j].Score {
return newZSet[i].Value < newZSet[j].Value
}
return newZSet[i].Score < newZSet[j].Score
})
item.value = newZSet
})
}
func (p *memoryPipeliner) ZRem(key string, members ...any) {
@@ -784,13 +800,21 @@ func (p *memoryPipeliner) ZRem(key string, members ...any) {
if !ok || item.isExpired() {
return
}
zset, ok := item.value.(map[string]float64)
zset, ok := item.value.([]zsetMember)
if !ok {
return
}
for _, member := range capturedMembers {
delete(zset, fmt.Sprintf("%v", member))
membersToRemove := make(map[string]struct{}, len(capturedMembers))
for _, m := range capturedMembers {
membersToRemove[fmt.Sprintf("%v", m)] = struct{}{}
}
newZSet := make([]zsetMember, 0, len(zset))
for _, z := range zset {
if _, exists := membersToRemove[z.Value]; !exists {
newZSet = append(newZSet, z)
}
}
item.value = newZSet
})
}