Fix requestTimeout & memory store
This commit is contained in:
@@ -92,7 +92,6 @@ func BuildContainer() (*dig.Container, error) {
|
||||
|
||||
// =========== 阶段三: 适配器与处理器层 (Handlers & Adapters) ===========
|
||||
|
||||
// 为Channel提供依赖 (Logger 和 *models.SystemSettings 数据插座)
|
||||
container.Provide(channel.NewGeminiChannel)
|
||||
container.Provide(func(ch *channel.GeminiChannel) channel.ChannelProxy { return ch })
|
||||
|
||||
|
||||
@@ -197,8 +197,8 @@ func (h *ProxyHandler) executeProxyAttempt(c *gin.Context, corrID string, body [
|
||||
var attemptErr *errors.APIError
|
||||
var isSuccess bool
|
||||
|
||||
connectTimeout := time.Duration(h.settingsManager.GetSettings().ConnectTimeoutSeconds) * time.Second
|
||||
ctx, cancel := context.WithTimeout(c.Request.Context(), connectTimeout)
|
||||
requestTimeout := time.Duration(h.settingsManager.GetSettings().RequestTimeoutSeconds) * time.Second
|
||||
ctx, cancel := context.WithTimeout(c.Request.Context(), requestTimeout)
|
||||
defer cancel()
|
||||
|
||||
attemptReq := c.Request.Clone(ctx)
|
||||
|
||||
@@ -29,10 +29,6 @@ const (
|
||||
TaskTypeUpdateStatusByFilter = "update_status_by_filter"
|
||||
)
|
||||
|
||||
const (
|
||||
TestEndpoint = "https://generativelanguage.googleapis.com/v1beta/models"
|
||||
)
|
||||
|
||||
type BatchRestoreResult struct {
|
||||
RestoredCount int `json:"restored_count"`
|
||||
SkippedCount int `json:"skipped_count"`
|
||||
@@ -52,12 +48,6 @@ type PaginatedAPIKeys struct {
|
||||
TotalPages int `json:"total_pages"`
|
||||
}
|
||||
|
||||
type KeyTestResult struct {
|
||||
Key string `json:"key"`
|
||||
Status string `json:"status"`
|
||||
Message string `json:"message"`
|
||||
}
|
||||
|
||||
type APIKeyService struct {
|
||||
db *gorm.DB
|
||||
keyRepo repository.KeyRepository
|
||||
@@ -99,37 +89,39 @@ func NewAPIKeyService(
|
||||
func (s *APIKeyService) Start() {
|
||||
requestSub, err := s.store.Subscribe(context.Background(), models.TopicRequestFinished)
|
||||
if err != nil {
|
||||
s.logger.Fatalf("Failed to subscribe to topic %s: %v", models.TopicRequestFinished, err)
|
||||
s.logger.Fatalf("Failed to subscribe to %s: %v", models.TopicRequestFinished, err)
|
||||
return
|
||||
}
|
||||
masterKeySub, err := s.store.Subscribe(context.Background(), models.TopicMasterKeyStatusChanged)
|
||||
if err != nil {
|
||||
s.logger.Fatalf("Failed to subscribe to topic %s: %v", models.TopicMasterKeyStatusChanged, err)
|
||||
s.logger.Fatalf("Failed to subscribe to %s: %v", models.TopicMasterKeyStatusChanged, err)
|
||||
return
|
||||
}
|
||||
keyStatusSub, err := s.store.Subscribe(context.Background(), models.TopicKeyStatusChanged)
|
||||
if err != nil {
|
||||
s.logger.Fatalf("Failed to subscribe to topic %s: %v", models.TopicKeyStatusChanged, err)
|
||||
s.logger.Fatalf("Failed to subscribe to %s: %v", models.TopicKeyStatusChanged, err)
|
||||
return
|
||||
}
|
||||
importSub, err := s.store.Subscribe(context.Background(), models.TopicImportGroupCompleted)
|
||||
if err != nil {
|
||||
s.logger.Fatalf("Failed to subscribe to topic %s: %v", models.TopicImportGroupCompleted, err)
|
||||
s.logger.Fatalf("Failed to subscribe to %s: %v", models.TopicImportGroupCompleted, err)
|
||||
return
|
||||
}
|
||||
s.logger.Info("Started and subscribed to request, master key, health check, and import events.")
|
||||
|
||||
s.logger.Info("Started and subscribed to all event topics")
|
||||
|
||||
go func() {
|
||||
defer requestSub.Close()
|
||||
defer masterKeySub.Close()
|
||||
defer keyStatusSub.Close()
|
||||
defer importSub.Close()
|
||||
|
||||
for {
|
||||
select {
|
||||
case msg := <-requestSub.Channel():
|
||||
var event models.RequestFinishedEvent
|
||||
if err := json.Unmarshal(msg.Payload, &event); err != nil {
|
||||
s.logger.Errorf("Failed to unmarshal event for key status update: %v", err)
|
||||
s.logger.WithError(err).Error("Failed to unmarshal RequestFinishedEvent")
|
||||
continue
|
||||
}
|
||||
s.handleKeyUsageEvent(&event)
|
||||
@@ -137,14 +129,15 @@ func (s *APIKeyService) Start() {
|
||||
case msg := <-masterKeySub.Channel():
|
||||
var event models.MasterKeyStatusChangedEvent
|
||||
if err := json.Unmarshal(msg.Payload, &event); err != nil {
|
||||
s.logger.Errorf("Failed to unmarshal MasterKeyStatusChangedEvent: %v", err)
|
||||
s.logger.WithError(err).Error("Failed to unmarshal MasterKeyStatusChangedEvent")
|
||||
continue
|
||||
}
|
||||
s.handleMasterKeyStatusChangeEvent(&event)
|
||||
|
||||
case msg := <-keyStatusSub.Channel():
|
||||
var event models.KeyStatusChangedEvent
|
||||
if err := json.Unmarshal(msg.Payload, &event); err != nil {
|
||||
s.logger.Errorf("Failed to unmarshal KeyStatusChangedEvent: %v", err)
|
||||
s.logger.WithError(err).Error("Failed to unmarshal KeyStatusChangedEvent")
|
||||
continue
|
||||
}
|
||||
s.handleKeyStatusChangeEvent(&event)
|
||||
@@ -152,15 +145,14 @@ func (s *APIKeyService) Start() {
|
||||
case msg := <-importSub.Channel():
|
||||
var event models.ImportGroupCompletedEvent
|
||||
if err := json.Unmarshal(msg.Payload, &event); err != nil {
|
||||
s.logger.WithError(err).Error("Failed to unmarshal ImportGroupCompletedEvent.")
|
||||
s.logger.WithError(err).Error("Failed to unmarshal ImportGroupCompletedEvent")
|
||||
continue
|
||||
}
|
||||
s.logger.Infof("Received ImportGroupCompletedEvent for group %d, triggering validation for %d keys.", event.GroupID, len(event.KeyIDs))
|
||||
|
||||
s.logger.Infof("Received import completion for group %d, validating %d keys", event.GroupID, len(event.KeyIDs))
|
||||
go s.handlePostImportValidation(&event)
|
||||
|
||||
case <-s.stopChan:
|
||||
s.logger.Info("Stopping event listener.")
|
||||
s.logger.Info("Stopping event listener")
|
||||
return
|
||||
}
|
||||
}
|
||||
@@ -175,13 +167,18 @@ func (s *APIKeyService) handleKeyUsageEvent(event *models.RequestFinishedEvent)
|
||||
if event.RequestLog.KeyID == nil || event.RequestLog.GroupID == nil {
|
||||
return
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
groupID := *event.RequestLog.GroupID
|
||||
keyID := *event.RequestLog.KeyID
|
||||
|
||||
if event.RequestLog.IsSuccess {
|
||||
mapping, err := s.keyRepo.GetMapping(*event.RequestLog.GroupID, *event.RequestLog.KeyID)
|
||||
mapping, err := s.keyRepo.GetMapping(groupID, keyID)
|
||||
if err != nil {
|
||||
s.logger.Warnf("[%s] Could not find mapping for G:%d K:%d on successful request: %v", event.CorrelationID, *event.RequestLog.GroupID, *event.RequestLog.KeyID, err)
|
||||
s.logger.Warnf("[%s] Mapping not found for G:%d K:%d: %v", event.CorrelationID, groupID, keyID, err)
|
||||
return
|
||||
}
|
||||
|
||||
statusChanged := false
|
||||
oldStatus := mapping.Status
|
||||
if mapping.Status != models.StatusActive {
|
||||
@@ -193,38 +190,33 @@ func (s *APIKeyService) handleKeyUsageEvent(event *models.RequestFinishedEvent)
|
||||
|
||||
now := time.Now()
|
||||
mapping.LastUsedAt = &now
|
||||
|
||||
if err := s.keyRepo.UpdateMapping(ctx, mapping); err != nil {
|
||||
s.logger.Errorf("[%s] Failed to update mapping for G:%d K:%d after successful request: %v", event.CorrelationID, *event.RequestLog.GroupID, *event.RequestLog.KeyID, err)
|
||||
s.logger.Errorf("[%s] Failed to update mapping for G:%d K:%d: %v", event.CorrelationID, groupID, keyID, err)
|
||||
return
|
||||
}
|
||||
|
||||
if statusChanged {
|
||||
go s.publishStatusChangeEvent(ctx, *event.RequestLog.GroupID, *event.RequestLog.KeyID, oldStatus, models.StatusActive, "key_recovered_after_use")
|
||||
go s.publishStatusChangeEvent(ctx, groupID, keyID, oldStatus, models.StatusActive, "key_recovered_after_use")
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if event.Error != nil {
|
||||
s.judgeKeyErrors(
|
||||
ctx,
|
||||
event.CorrelationID,
|
||||
*event.RequestLog.GroupID,
|
||||
*event.RequestLog.KeyID,
|
||||
event.Error,
|
||||
event.IsPreciseRouting,
|
||||
)
|
||||
s.judgeKeyErrors(ctx, event.CorrelationID, groupID, keyID, event.Error, event.IsPreciseRouting)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *APIKeyService) handleKeyStatusChangeEvent(event *models.KeyStatusChangedEvent) {
|
||||
ctx := context.Background()
|
||||
log := s.logger.WithFields(logrus.Fields{
|
||||
s.logger.WithFields(logrus.Fields{
|
||||
"group_id": event.GroupID,
|
||||
"key_id": event.KeyID,
|
||||
"new_status": event.NewStatus,
|
||||
"reason": event.ChangeReason,
|
||||
})
|
||||
log.Info("Received KeyStatusChangedEvent, will update polling caches.")
|
||||
}).Info("Updating polling caches based on status change")
|
||||
|
||||
s.keyRepo.HandleCacheUpdateEvent(ctx, event.GroupID, event.KeyID, event.NewStatus)
|
||||
log.Info("Polling caches updated based on health check event.")
|
||||
}
|
||||
|
||||
func (s *APIKeyService) publishStatusChangeEvent(ctx context.Context, groupID, keyID uint, oldStatus, newStatus models.APIKeyStatus, reason string) {
|
||||
@@ -238,7 +230,7 @@ func (s *APIKeyService) publishStatusChangeEvent(ctx context.Context, groupID, k
|
||||
}
|
||||
eventData, _ := json.Marshal(changeEvent)
|
||||
if err := s.store.Publish(ctx, models.TopicKeyStatusChanged, eventData); err != nil {
|
||||
s.logger.Errorf("Failed to publish key status changed event for group %d: %v", groupID, err)
|
||||
s.logger.Errorf("Failed to publish status change event for group %d: %v", groupID, err)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -248,10 +240,12 @@ func (s *APIKeyService) ListAPIKeys(ctx context.Context, params *models.APIKeyQu
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
totalPages := 0
|
||||
if total > 0 && params.PageSize > 0 {
|
||||
totalPages = int(math.Ceil(float64(total) / float64(params.PageSize)))
|
||||
}
|
||||
|
||||
return &PaginatedAPIKeys{
|
||||
Items: items,
|
||||
Total: total,
|
||||
@@ -260,34 +254,44 @@ func (s *APIKeyService) ListAPIKeys(ctx context.Context, params *models.APIKeyQu
|
||||
TotalPages: totalPages,
|
||||
}, nil
|
||||
}
|
||||
s.logger.Infof("Performing heavy in-memory search for group %d with keyword '%s'", params.KeyGroupID, params.Keyword)
|
||||
var statusesToFilter []string
|
||||
|
||||
s.logger.Infof("Performing in-memory search for group %d with keyword '%s'", params.KeyGroupID, params.Keyword)
|
||||
|
||||
statusesToFilter := []string{"all"}
|
||||
if params.Status != "" {
|
||||
statusesToFilter = append(statusesToFilter, params.Status)
|
||||
} else {
|
||||
statusesToFilter = append(statusesToFilter, "all")
|
||||
statusesToFilter = []string{params.Status}
|
||||
}
|
||||
|
||||
allKeyIDs, err := s.keyRepo.FindKeyIDsByStatus(params.KeyGroupID, statusesToFilter)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to fetch all key IDs for search: %w", err)
|
||||
return nil, fmt.Errorf("failed to fetch key IDs: %w", err)
|
||||
}
|
||||
|
||||
if len(allKeyIDs) == 0 {
|
||||
return &PaginatedAPIKeys{Items: []*models.APIKeyDetails{}, Total: 0, Page: 1, PageSize: params.PageSize, TotalPages: 0}, nil
|
||||
return &PaginatedAPIKeys{
|
||||
Items: []*models.APIKeyDetails{},
|
||||
Total: 0,
|
||||
Page: 1,
|
||||
PageSize: params.PageSize,
|
||||
TotalPages: 0,
|
||||
}, nil
|
||||
}
|
||||
|
||||
allKeys, err := s.keyRepo.GetKeysByIDs(allKeyIDs)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to fetch all key details for in-memory search: %w", err)
|
||||
return nil, fmt.Errorf("failed to fetch keys: %w", err)
|
||||
}
|
||||
|
||||
var allMappings []models.GroupAPIKeyMapping
|
||||
err = s.db.WithContext(ctx).Where("key_group_id = ? AND api_key_id IN ?", params.KeyGroupID, allKeyIDs).Find(&allMappings).Error
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to fetch mappings for in-memory search: %w", err)
|
||||
if err := s.db.WithContext(ctx).Where("key_group_id = ? AND api_key_id IN ?", params.KeyGroupID, allKeyIDs).Find(&allMappings).Error; err != nil {
|
||||
return nil, fmt.Errorf("failed to fetch mappings: %w", err)
|
||||
}
|
||||
mappingMap := make(map[uint]*models.GroupAPIKeyMapping)
|
||||
|
||||
mappingMap := make(map[uint]*models.GroupAPIKeyMapping, len(allMappings))
|
||||
for i := range allMappings {
|
||||
mappingMap[allMappings[i].APIKeyID] = &allMappings[i]
|
||||
}
|
||||
|
||||
var filteredItems []*models.APIKeyDetails
|
||||
for _, key := range allKeys {
|
||||
if strings.Contains(key.APIKey, params.Keyword) {
|
||||
@@ -307,12 +311,15 @@ func (s *APIKeyService) ListAPIKeys(ctx context.Context, params *models.APIKeyQu
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
sort.Slice(filteredItems, func(i, j int) bool {
|
||||
return filteredItems[i].ID > filteredItems[j].ID
|
||||
})
|
||||
|
||||
total := int64(len(filteredItems))
|
||||
start := (params.Page - 1) * params.PageSize
|
||||
end := start + params.PageSize
|
||||
|
||||
if start < 0 {
|
||||
start = 0
|
||||
}
|
||||
@@ -328,9 +335,9 @@ func (s *APIKeyService) ListAPIKeys(ctx context.Context, params *models.APIKeyQu
|
||||
if end > len(filteredItems) {
|
||||
end = len(filteredItems)
|
||||
}
|
||||
paginatedItems := filteredItems[start:end]
|
||||
|
||||
return &PaginatedAPIKeys{
|
||||
Items: paginatedItems,
|
||||
Items: filteredItems[start:end],
|
||||
Total: total,
|
||||
Page: params.Page,
|
||||
PageSize: params.PageSize,
|
||||
@@ -344,15 +351,8 @@ func (s *APIKeyService) GetKeysByIds(ctx context.Context, ids []uint) ([]models.
|
||||
|
||||
func (s *APIKeyService) UpdateAPIKey(ctx context.Context, key *models.APIKey) error {
|
||||
go func() {
|
||||
bgCtx := context.Background()
|
||||
var oldKey models.APIKey
|
||||
if err := s.db.WithContext(bgCtx).First(&oldKey, key.ID).Error; err != nil {
|
||||
s.logger.Errorf("Failed to find old key state for ID %d before update: %v", key.ID, err)
|
||||
return
|
||||
}
|
||||
if err := s.keyRepo.Update(key); err != nil {
|
||||
s.logger.Errorf("Failed to asynchronously update key ID %d: %v", key.ID, err)
|
||||
return
|
||||
s.logger.Errorf("Failed to update key ID %d: %v", key.ID, err)
|
||||
}
|
||||
}()
|
||||
return nil
|
||||
@@ -361,22 +361,24 @@ func (s *APIKeyService) UpdateAPIKey(ctx context.Context, key *models.APIKey) er
|
||||
func (s *APIKeyService) HardDeleteAPIKeyByID(ctx context.Context, id uint) error {
|
||||
groups, err := s.keyRepo.GetGroupsForKey(ctx, id)
|
||||
if err != nil {
|
||||
s.logger.Warnf("Could not get groups for key ID %d before hard deletion (key might be orphaned): %v", id, err)
|
||||
s.logger.Warnf("Could not get groups for key %d before deletion: %v", id, err)
|
||||
}
|
||||
|
||||
err = s.keyRepo.HardDeleteByID(id)
|
||||
if err == nil {
|
||||
for _, groupID := range groups {
|
||||
event := models.KeyStatusChangedEvent{
|
||||
KeyID: id,
|
||||
GroupID: groupID,
|
||||
ChangeReason: "key_hard_deleted",
|
||||
}
|
||||
eventData, _ := json.Marshal(event)
|
||||
go s.store.Publish(context.Background(), models.TopicKeyStatusChanged, eventData)
|
||||
}
|
||||
if err := s.keyRepo.HardDeleteByID(id); err != nil {
|
||||
return err
|
||||
}
|
||||
return err
|
||||
|
||||
for _, groupID := range groups {
|
||||
event := models.KeyStatusChangedEvent{
|
||||
KeyID: id,
|
||||
GroupID: groupID,
|
||||
ChangeReason: "key_hard_deleted",
|
||||
}
|
||||
eventData, _ := json.Marshal(event)
|
||||
go s.store.Publish(context.Background(), models.TopicKeyStatusChanged, eventData)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *APIKeyService) UpdateMappingStatus(ctx context.Context, groupID, keyID uint, newStatus models.APIKeyStatus) (*models.GroupAPIKeyMapping, error) {
|
||||
@@ -388,47 +390,54 @@ func (s *APIKeyService) UpdateMappingStatus(ctx context.Context, groupID, keyID
|
||||
if key.MasterStatus != models.MasterStatusActive && newStatus == models.StatusActive {
|
||||
return nil, CustomErrors.ErrStateConflictMasterRevoked
|
||||
}
|
||||
|
||||
mapping, err := s.keyRepo.GetMapping(groupID, keyID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
oldStatus := mapping.Status
|
||||
if oldStatus == newStatus {
|
||||
return mapping, nil
|
||||
}
|
||||
|
||||
mapping.Status = newStatus
|
||||
if newStatus == models.StatusActive {
|
||||
mapping.ConsecutiveErrorCount = 0
|
||||
mapping.LastError = ""
|
||||
}
|
||||
|
||||
if err := s.keyRepo.UpdateMapping(ctx, mapping); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
go s.publishStatusChangeEvent(ctx, groupID, keyID, oldStatus, newStatus, "manual_update")
|
||||
return mapping, nil
|
||||
}
|
||||
|
||||
func (s *APIKeyService) handleMasterKeyStatusChangeEvent(event *models.MasterKeyStatusChangedEvent) {
|
||||
ctx := context.Background()
|
||||
s.logger.Infof("Received MasterKeyStatusChangedEvent for Key ID %d: %s -> %s", event.KeyID, event.OldMasterStatus, event.NewMasterStatus)
|
||||
if event.NewMasterStatus != models.MasterStatusRevoked {
|
||||
return
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
s.logger.Infof("Key %d revoked, propagating to all groups", event.KeyID)
|
||||
|
||||
affectedGroupIDs, err := s.keyRepo.GetGroupsForKey(ctx, event.KeyID)
|
||||
if err != nil {
|
||||
s.logger.WithError(err).Errorf("Failed to get groups for key ID %d to propagate revoked status.", event.KeyID)
|
||||
s.logger.WithError(err).Errorf("Failed to get groups for key %d", event.KeyID)
|
||||
return
|
||||
}
|
||||
|
||||
if len(affectedGroupIDs) == 0 {
|
||||
s.logger.Infof("Key ID %d is revoked, but it's not associated with any group. No action needed.", event.KeyID)
|
||||
s.logger.Infof("Key %d not associated with any group", event.KeyID)
|
||||
return
|
||||
}
|
||||
s.logger.Infof("Propagating REVOKED astatus for Key ID %d to %d groups.", event.KeyID, len(affectedGroupIDs))
|
||||
|
||||
for _, groupID := range affectedGroupIDs {
|
||||
_, err := s.UpdateMappingStatus(ctx, groupID, event.KeyID, models.StatusBanned)
|
||||
if err != nil {
|
||||
if _, err := s.UpdateMappingStatus(ctx, groupID, event.KeyID, models.StatusBanned); err != nil {
|
||||
if !errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
s.logger.WithError(err).Errorf("Failed to update mapping status to BANNED for Key ID %d in Group ID %d.", event.KeyID, groupID)
|
||||
s.logger.WithError(err).Errorf("Failed to ban key %d in group %d", event.KeyID, groupID)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -436,59 +445,71 @@ func (s *APIKeyService) handleMasterKeyStatusChangeEvent(event *models.MasterKey
|
||||
|
||||
func (s *APIKeyService) StartRestoreKeysTask(ctx context.Context, groupID uint, keyIDs []uint) (*task.Status, error) {
|
||||
if len(keyIDs) == 0 {
|
||||
return nil, CustomErrors.NewAPIError(CustomErrors.ErrBadRequest, "No key IDs provided for restoration.")
|
||||
return nil, CustomErrors.NewAPIError(CustomErrors.ErrBadRequest, "No key IDs provided")
|
||||
}
|
||||
|
||||
resourceID := fmt.Sprintf("group-%d", groupID)
|
||||
taskStatus, err := s.taskService.StartTask(ctx, groupID, TaskTypeRestoreSpecificKeys, resourceID, len(keyIDs), time.Hour)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
go s.runRestoreKeysTask(ctx, taskStatus.ID, resourceID, groupID, keyIDs)
|
||||
return taskStatus, nil
|
||||
}
|
||||
|
||||
func (s *APIKeyService) runRestoreKeysTask(ctx context.Context, taskID string, resourceID string, groupID uint, keyIDs []uint) {
|
||||
func (s *APIKeyService) runRestoreKeysTask(ctx context.Context, taskID, resourceID string, groupID uint, keyIDs []uint) {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
s.logger.Errorf("Panic recovered in runRestoreKeysTask: %v", r)
|
||||
s.taskService.EndTaskByID(ctx, taskID, resourceID, nil, fmt.Errorf("panic in restore keys task: %v", r))
|
||||
s.logger.Errorf("Panic in restore task: %v", r)
|
||||
s.taskService.EndTaskByID(ctx, taskID, resourceID, nil, fmt.Errorf("panic: %v", r))
|
||||
}
|
||||
}()
|
||||
|
||||
var mappingsToProcess []models.GroupAPIKeyMapping
|
||||
err := s.db.WithContext(ctx).Preload("APIKey").
|
||||
if err := s.db.WithContext(ctx).Preload("APIKey").
|
||||
Where("key_group_id = ? AND api_key_id IN ?", groupID, keyIDs).
|
||||
Find(&mappingsToProcess).Error
|
||||
if err != nil {
|
||||
Find(&mappingsToProcess).Error; err != nil {
|
||||
s.taskService.EndTaskByID(ctx, taskID, resourceID, nil, err)
|
||||
return
|
||||
}
|
||||
result := &BatchRestoreResult{
|
||||
SkippedKeys: make([]SkippedKeyInfo, 0),
|
||||
}
|
||||
|
||||
result := &BatchRestoreResult{SkippedKeys: make([]SkippedKeyInfo, 0)}
|
||||
var successfulMappings []*models.GroupAPIKeyMapping
|
||||
processedCount := 0
|
||||
for _, mapping := range mappingsToProcess {
|
||||
processedCount++
|
||||
_ = s.taskService.UpdateProgressByID(ctx, taskID, processedCount)
|
||||
|
||||
for i, mapping := range mappingsToProcess {
|
||||
_ = s.taskService.UpdateProgressByID(ctx, taskID, i+1)
|
||||
|
||||
if mapping.APIKey == nil {
|
||||
result.SkippedCount++
|
||||
result.SkippedKeys = append(result.SkippedKeys, SkippedKeyInfo{KeyID: mapping.APIKeyID, Reason: "Associated APIKey entity not found."})
|
||||
result.SkippedKeys = append(result.SkippedKeys, SkippedKeyInfo{
|
||||
KeyID: mapping.APIKeyID,
|
||||
Reason: "APIKey not found",
|
||||
})
|
||||
continue
|
||||
}
|
||||
|
||||
if mapping.APIKey.MasterStatus != models.MasterStatusActive {
|
||||
result.SkippedCount++
|
||||
result.SkippedKeys = append(result.SkippedKeys, SkippedKeyInfo{KeyID: mapping.APIKeyID, Reason: fmt.Sprintf("Master status is '%s'.", mapping.APIKey.MasterStatus)})
|
||||
result.SkippedKeys = append(result.SkippedKeys, SkippedKeyInfo{
|
||||
KeyID: mapping.APIKeyID,
|
||||
Reason: fmt.Sprintf("Master status is %s", mapping.APIKey.MasterStatus),
|
||||
})
|
||||
continue
|
||||
}
|
||||
|
||||
oldStatus := mapping.Status
|
||||
if oldStatus != models.StatusActive {
|
||||
mapping.Status = models.StatusActive
|
||||
mapping.ConsecutiveErrorCount = 0
|
||||
mapping.LastError = ""
|
||||
|
||||
if err := s.keyRepo.UpdateMappingWithoutCache(&mapping); err != nil {
|
||||
s.logger.WithError(err).Warn("Failed to update mapping in DB during restore task.")
|
||||
result.SkippedCount++
|
||||
result.SkippedKeys = append(result.SkippedKeys, SkippedKeyInfo{KeyID: mapping.APIKeyID, Reason: "DB update failed."})
|
||||
result.SkippedKeys = append(result.SkippedKeys, SkippedKeyInfo{
|
||||
KeyID: mapping.APIKeyID,
|
||||
Reason: "DB update failed",
|
||||
})
|
||||
} else {
|
||||
result.RestoredCount++
|
||||
successfulMappings = append(successfulMappings, &mapping)
|
||||
@@ -498,61 +519,72 @@ func (s *APIKeyService) runRestoreKeysTask(ctx context.Context, taskID string, r
|
||||
result.RestoredCount++
|
||||
}
|
||||
}
|
||||
|
||||
if err := s.keyRepo.HandleCacheUpdateEventBatch(ctx, successfulMappings); err != nil {
|
||||
s.logger.WithError(err).Error("Failed to perform batch cache update after restore task.")
|
||||
s.logger.WithError(err).Error("Failed batch cache update after restore")
|
||||
}
|
||||
|
||||
result.SkippedCount += (len(keyIDs) - len(mappingsToProcess))
|
||||
s.taskService.EndTaskByID(ctx, taskID, resourceID, result, nil)
|
||||
}
|
||||
|
||||
func (s *APIKeyService) StartRestoreAllBannedTask(ctx context.Context, groupID uint) (*task.Status, error) {
|
||||
var bannedKeyIDs []uint
|
||||
err := s.db.WithContext(ctx).Model(&models.GroupAPIKeyMapping{}).
|
||||
if err := s.db.WithContext(ctx).Model(&models.GroupAPIKeyMapping{}).
|
||||
Where("key_group_id = ? AND status = ?", groupID, models.StatusBanned).
|
||||
Pluck("api_key_id", &bannedKeyIDs).Error
|
||||
if err != nil {
|
||||
Pluck("api_key_id", &bannedKeyIDs).Error; err != nil {
|
||||
return nil, CustomErrors.ParseDBError(err)
|
||||
}
|
||||
|
||||
if len(bannedKeyIDs) == 0 {
|
||||
return nil, CustomErrors.NewAPIError(CustomErrors.ErrBadRequest, "No banned keys to restore in this group.")
|
||||
return nil, CustomErrors.NewAPIError(CustomErrors.ErrBadRequest, "No banned keys to restore")
|
||||
}
|
||||
|
||||
return s.StartRestoreKeysTask(ctx, groupID, bannedKeyIDs)
|
||||
}
|
||||
|
||||
func (s *APIKeyService) handlePostImportValidation(event *models.ImportGroupCompletedEvent) {
|
||||
ctx := context.Background()
|
||||
|
||||
group, ok := s.groupManager.GetGroupByID(event.GroupID)
|
||||
if !ok {
|
||||
s.logger.Errorf("Group with id %d not found during post-import validation, aborting.", event.GroupID)
|
||||
s.logger.Errorf("Group %d not found for post-import validation", event.GroupID)
|
||||
return
|
||||
}
|
||||
|
||||
opConfig, err := s.groupManager.BuildOperationalConfig(group)
|
||||
if err != nil {
|
||||
s.logger.Errorf("Failed to build operational config for group %d, aborting validation: %v", event.GroupID, err)
|
||||
s.logger.Errorf("Failed to build config for group %d: %v", event.GroupID, err)
|
||||
return
|
||||
}
|
||||
|
||||
endpoint, err := s.groupManager.BuildKeyCheckEndpoint(event.GroupID)
|
||||
if err != nil {
|
||||
s.logger.Errorf("Failed to build key check endpoint for group %d, aborting validation: %v", event.GroupID, err)
|
||||
s.logger.Errorf("Failed to build endpoint for group %d: %v", event.GroupID, err)
|
||||
return
|
||||
}
|
||||
globalSettings := s.SettingsManager.GetSettings()
|
||||
concurrency := globalSettings.BaseKeyCheckConcurrency
|
||||
|
||||
concurrency := s.SettingsManager.GetSettings().BaseKeyCheckConcurrency
|
||||
if opConfig.KeyCheckConcurrency != nil {
|
||||
concurrency = *opConfig.KeyCheckConcurrency
|
||||
}
|
||||
if concurrency <= 0 {
|
||||
concurrency = 10
|
||||
}
|
||||
timeout := time.Duration(globalSettings.KeyCheckTimeoutSeconds) * time.Second
|
||||
|
||||
timeout := time.Duration(s.SettingsManager.GetSettings().KeyCheckTimeoutSeconds) * time.Second
|
||||
|
||||
keysToValidate, err := s.keyRepo.GetKeysByIDs(event.KeyIDs)
|
||||
if err != nil {
|
||||
s.logger.Errorf("Failed to get key models for validation in group %d: %v", event.GroupID, err)
|
||||
s.logger.Errorf("Failed to get keys for validation in group %d: %v", event.GroupID, err)
|
||||
return
|
||||
}
|
||||
s.logger.Infof("Validating %d keys for group %d with concurrency %d against endpoint %s", len(keysToValidate), event.GroupID, concurrency, endpoint)
|
||||
|
||||
s.logger.Infof("Validating %d keys for group %d (concurrency: %d)", len(keysToValidate), event.GroupID, concurrency)
|
||||
|
||||
var wg sync.WaitGroup
|
||||
jobs := make(chan models.APIKey, len(keysToValidate))
|
||||
|
||||
for i := 0; i < concurrency; i++ {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
@@ -560,9 +592,8 @@ func (s *APIKeyService) handlePostImportValidation(event *models.ImportGroupComp
|
||||
for key := range jobs {
|
||||
validationErr := s.validationService.ValidateSingleKey(&key, timeout, endpoint)
|
||||
if validationErr == nil {
|
||||
s.logger.Infof("Key ID %d PASSED validation. Status -> ACTIVE.", key.ID)
|
||||
if _, err := s.UpdateMappingStatus(ctx, event.GroupID, key.ID, models.StatusActive); err != nil {
|
||||
s.logger.Errorf("Failed to update status to ACTIVE for Key ID %d in group %d: %v", key.ID, event.GroupID, err)
|
||||
s.logger.Errorf("Failed to activate key %d in group %d: %v", key.ID, event.GroupID, err)
|
||||
}
|
||||
} else {
|
||||
var apiErr *CustomErrors.APIError
|
||||
@@ -574,35 +605,34 @@ func (s *APIKeyService) handlePostImportValidation(event *models.ImportGroupComp
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
for _, key := range keysToValidate {
|
||||
jobs <- key
|
||||
}
|
||||
close(jobs)
|
||||
wg.Wait()
|
||||
s.logger.Infof("Finished post-import validation for group %d.", event.GroupID)
|
||||
|
||||
s.logger.Infof("Finished post-import validation for group %d", event.GroupID)
|
||||
}
|
||||
|
||||
func (s *APIKeyService) StartUpdateStatusByFilterTask(ctx context.Context, groupID uint, sourceStatuses []string, newStatus models.APIKeyStatus) (*task.Status, error) {
|
||||
s.logger.Infof("Starting task to update status to '%s' for keys in group %d matching statuses: %v", newStatus, groupID, sourceStatuses)
|
||||
|
||||
keyIDs, err := s.keyRepo.FindKeyIDsByStatus(groupID, sourceStatuses)
|
||||
if err != nil {
|
||||
return nil, CustomErrors.ParseDBError(err)
|
||||
}
|
||||
|
||||
if len(keyIDs) == 0 {
|
||||
now := time.Now()
|
||||
return &task.Status{
|
||||
IsRunning: false,
|
||||
Processed: 0,
|
||||
Total: 0,
|
||||
Result: map[string]string{
|
||||
"message": "没有找到任何符合当前过滤条件的Key可供操作。",
|
||||
},
|
||||
Error: "",
|
||||
IsRunning: false,
|
||||
Processed: 0,
|
||||
Total: 0,
|
||||
Result: map[string]string{"message": "没有找到符合条件的Key"},
|
||||
StartedAt: now,
|
||||
FinishedAt: &now,
|
||||
}, nil
|
||||
}
|
||||
|
||||
resourceID := fmt.Sprintf("group-%d-status-update", groupID)
|
||||
taskStatus, err := s.taskService.StartTask(ctx, groupID, TaskTypeUpdateStatusByFilter, resourceID, len(keyIDs), 30*time.Minute)
|
||||
if err != nil {
|
||||
@@ -616,10 +646,11 @@ func (s *APIKeyService) StartUpdateStatusByFilterTask(ctx context.Context, group
|
||||
func (s *APIKeyService) runUpdateStatusByFilterTask(ctx context.Context, taskID, resourceID string, groupID uint, keyIDs []uint, newStatus models.APIKeyStatus) {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
s.logger.Errorf("Panic recovered in runUpdateStatusByFilterTask: %v", r)
|
||||
s.taskService.EndTaskByID(ctx, taskID, resourceID, nil, fmt.Errorf("panic in update status task: %v", r))
|
||||
s.logger.Errorf("Panic in status update task: %v", r)
|
||||
s.taskService.EndTaskByID(ctx, taskID, resourceID, nil, fmt.Errorf("panic: %v", r))
|
||||
}
|
||||
}()
|
||||
|
||||
type BatchUpdateResult struct {
|
||||
UpdatedCount int `json:"updated_count"`
|
||||
SkippedCount int `json:"skipped_count"`
|
||||
@@ -629,33 +660,35 @@ func (s *APIKeyService) runUpdateStatusByFilterTask(ctx context.Context, taskID,
|
||||
|
||||
keys, err := s.keyRepo.GetKeysByIDs(keyIDs)
|
||||
if err != nil {
|
||||
s.logger.WithError(err).Error("Failed to fetch keys for bulk status update, aborting task.")
|
||||
s.taskService.EndTaskByID(ctx, taskID, resourceID, nil, err)
|
||||
return
|
||||
}
|
||||
masterStatusMap := make(map[uint]models.MasterAPIKeyStatus)
|
||||
|
||||
masterStatusMap := make(map[uint]models.MasterAPIKeyStatus, len(keys))
|
||||
for _, key := range keys {
|
||||
masterStatusMap[key.ID] = key.MasterStatus
|
||||
}
|
||||
|
||||
var mappings []*models.GroupAPIKeyMapping
|
||||
if err := s.db.WithContext(ctx).Where("key_group_id = ? AND api_key_id IN ?", groupID, keyIDs).Find(&mappings).Error; err != nil {
|
||||
s.logger.WithError(err).Error("Failed to fetch mappings for bulk status update, aborting task.")
|
||||
s.taskService.EndTaskByID(ctx, taskID, resourceID, nil, err)
|
||||
return
|
||||
}
|
||||
processedCount := 0
|
||||
for _, mapping := range mappings {
|
||||
processedCount++
|
||||
_ = s.taskService.UpdateProgressByID(ctx, taskID, processedCount)
|
||||
|
||||
for i, mapping := range mappings {
|
||||
_ = s.taskService.UpdateProgressByID(ctx, taskID, i+1)
|
||||
|
||||
masterStatus, ok := masterStatusMap[mapping.APIKeyID]
|
||||
if !ok {
|
||||
result.SkippedCount++
|
||||
continue
|
||||
}
|
||||
|
||||
if masterStatus != models.MasterStatusActive && newStatus == models.StatusActive {
|
||||
result.SkippedCount++
|
||||
continue
|
||||
}
|
||||
|
||||
oldStatus := mapping.Status
|
||||
if oldStatus != newStatus {
|
||||
mapping.Status = newStatus
|
||||
@@ -663,6 +696,7 @@ func (s *APIKeyService) runUpdateStatusByFilterTask(ctx context.Context, taskID,
|
||||
mapping.ConsecutiveErrorCount = 0
|
||||
mapping.LastError = ""
|
||||
}
|
||||
|
||||
if err := s.keyRepo.UpdateMappingWithoutCache(mapping); err != nil {
|
||||
result.SkippedCount++
|
||||
} else {
|
||||
@@ -674,122 +708,112 @@ func (s *APIKeyService) runUpdateStatusByFilterTask(ctx context.Context, taskID,
|
||||
result.UpdatedCount++
|
||||
}
|
||||
}
|
||||
|
||||
result.SkippedCount += (len(keyIDs) - len(mappings))
|
||||
|
||||
if err := s.keyRepo.HandleCacheUpdateEventBatch(ctx, successfulMappings); err != nil {
|
||||
s.logger.WithError(err).Error("Failed to perform batch cache update after status update task.")
|
||||
s.logger.WithError(err).Error("Failed batch cache update after status update")
|
||||
}
|
||||
s.logger.Infof("Finished bulk status update task '%s' for group %d. Updated: %d, Skipped: %d.", taskID, groupID, result.UpdatedCount, result.SkippedCount)
|
||||
|
||||
s.taskService.EndTaskByID(ctx, taskID, resourceID, result, nil)
|
||||
}
|
||||
|
||||
func (s *APIKeyService) HandleRequestResult(group *models.KeyGroup, key *models.APIKey, success bool, apiErr *CustomErrors.APIError) {
|
||||
ctx := context.Background()
|
||||
|
||||
if success {
|
||||
if group.PollingStrategy == models.StrategyWeighted {
|
||||
go s.keyRepo.UpdateKeyUsageTimestamp(ctx, group.ID, key.ID)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if apiErr == nil {
|
||||
s.logger.Warnf("Request failed for KeyID %d in GroupID %d but no specific API error was provided. No action taken.", key.ID, group.ID)
|
||||
return
|
||||
}
|
||||
|
||||
errMsg := apiErr.Message
|
||||
if CustomErrors.IsPermanentUpstreamError(errMsg) || CustomErrors.IsTemporaryUpstreamError(errMsg) {
|
||||
s.logger.Warnf("Request for KeyID %d in GroupID %d failed with a key-specific error: %s. Issuing order to move to cooldown.", key.ID, group.ID, errMsg)
|
||||
go s.keyRepo.SyncKeyStatusInPollingCaches(ctx, group.ID, key.ID, models.StatusCooldown)
|
||||
} else {
|
||||
s.logger.Infof("Request for KeyID %d in GroupID %d failed with a non-key-specific error. No punitive action taken. Error: %s", key.ID, group.ID, errMsg)
|
||||
}
|
||||
}
|
||||
|
||||
func sanitizeForLog(errMsg string) string {
|
||||
jsonStartIndex := strings.Index(errMsg, "{")
|
||||
var cleanMsg string
|
||||
if jsonStartIndex != -1 {
|
||||
cleanMsg = strings.TrimSpace(errMsg[:jsonStartIndex]) + " {...}"
|
||||
} else {
|
||||
cleanMsg = errMsg
|
||||
}
|
||||
const maxLen = 250
|
||||
if len(cleanMsg) > maxLen {
|
||||
return cleanMsg[:maxLen] + "..."
|
||||
}
|
||||
return cleanMsg
|
||||
}
|
||||
func (s *APIKeyService) judgeKeyErrors(ctx context.Context, correlationID string, groupID, keyID uint, apiErr *CustomErrors.APIError, isPreciseRouting bool) {
|
||||
logger := s.logger.WithFields(logrus.Fields{
|
||||
"group_id": groupID,
|
||||
"key_id": keyID,
|
||||
"correlation_id": correlationID,
|
||||
})
|
||||
|
||||
func (s *APIKeyService) judgeKeyErrors(
|
||||
ctx context.Context,
|
||||
correlationID string,
|
||||
groupID, keyID uint,
|
||||
apiErr *CustomErrors.APIError,
|
||||
isPreciseRouting bool,
|
||||
) {
|
||||
logger := s.logger.WithFields(logrus.Fields{"group_id": groupID, "key_id": keyID, "correlation_id": correlationID})
|
||||
mapping, err := s.keyRepo.GetMapping(groupID, keyID)
|
||||
if err != nil {
|
||||
logger.WithError(err).Warn("Cannot apply consequences for error: mapping not found.")
|
||||
logger.WithError(err).Warn("Mapping not found, cannot apply error consequences")
|
||||
return
|
||||
}
|
||||
|
||||
now := time.Now()
|
||||
mapping.LastUsedAt = &now
|
||||
errorMessage := apiErr.Message
|
||||
|
||||
if CustomErrors.IsPermanentUpstreamError(errorMessage) {
|
||||
logger.Errorf("Permanent error detected. Banning mapping and revoking master key. Reason: %s", sanitizeForLog(errorMessage))
|
||||
logger.WithField("full_error_details", errorMessage).Debug("Full details of the permanent error.")
|
||||
logger.Errorf("Permanent error: %s", sanitizeForLog(errorMessage))
|
||||
|
||||
if mapping.Status != models.StatusBanned {
|
||||
oldStatus := mapping.Status
|
||||
mapping.Status = models.StatusBanned
|
||||
mapping.LastError = errorMessage
|
||||
|
||||
if err := s.keyRepo.UpdateMapping(ctx, mapping); err != nil {
|
||||
logger.WithError(err).Error("Failed to update mapping status to BANNED.")
|
||||
logger.WithError(err).Error("Failed to ban mapping")
|
||||
} else {
|
||||
go s.publishStatusChangeEvent(ctx, groupID, keyID, oldStatus, mapping.Status, "permanent_error_banned")
|
||||
go s.publishStatusChangeEvent(ctx, groupID, keyID, oldStatus, models.StatusBanned, "permanent_error")
|
||||
go s.revokeMasterKey(ctx, keyID, "permanent_upstream_error")
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if CustomErrors.IsTemporaryUpstreamError(errorMessage) {
|
||||
mapping.LastError = errorMessage
|
||||
mapping.ConsecutiveErrorCount++
|
||||
var threshold int
|
||||
|
||||
threshold := s.SettingsManager.GetSettings().BlacklistThreshold
|
||||
if isPreciseRouting {
|
||||
group, ok := s.groupManager.GetGroupByID(groupID)
|
||||
opConfig, err := s.groupManager.BuildOperationalConfig(group)
|
||||
if !ok || err != nil {
|
||||
logger.Warnf("Could not build operational config for group %d in Precise Routing mode. Falling back to global settings.", groupID)
|
||||
threshold = s.SettingsManager.GetSettings().BlacklistThreshold
|
||||
} else {
|
||||
threshold = *opConfig.KeyBlacklistThreshold
|
||||
if group, ok := s.groupManager.GetGroupByID(groupID); ok {
|
||||
if opConfig, err := s.groupManager.BuildOperationalConfig(group); err == nil && opConfig.KeyBlacklistThreshold != nil {
|
||||
threshold = *opConfig.KeyBlacklistThreshold
|
||||
}
|
||||
}
|
||||
} else {
|
||||
threshold = s.SettingsManager.GetSettings().BlacklistThreshold
|
||||
}
|
||||
logger.Warnf("Temporary error detected. Incrementing error count. New count: %d (Threshold: %d). Reason: %s", mapping.ConsecutiveErrorCount, threshold, sanitizeForLog(errorMessage))
|
||||
logger.WithField("full_error_details", errorMessage).Debug("Full details of the temporary error.")
|
||||
|
||||
logger.Warnf("Temporary error (count: %d, threshold: %d): %s", mapping.ConsecutiveErrorCount, threshold, sanitizeForLog(errorMessage))
|
||||
|
||||
oldStatus := mapping.Status
|
||||
newStatus := oldStatus
|
||||
|
||||
if mapping.ConsecutiveErrorCount >= threshold && oldStatus == models.StatusActive {
|
||||
newStatus = models.StatusCooldown
|
||||
logger.Errorf("Putting mapping into COOLDOWN due to reaching temporary error threshold (%d)", threshold)
|
||||
logger.Errorf("Moving to COOLDOWN after reaching threshold %d", threshold)
|
||||
}
|
||||
|
||||
if oldStatus != newStatus {
|
||||
mapping.Status = newStatus
|
||||
}
|
||||
|
||||
if err := s.keyRepo.UpdateMapping(ctx, mapping); err != nil {
|
||||
logger.WithError(err).Error("Failed to update mapping after temporary error.")
|
||||
logger.WithError(err).Error("Failed to update mapping after temporary error")
|
||||
return
|
||||
}
|
||||
|
||||
if oldStatus != newStatus {
|
||||
go s.publishStatusChangeEvent(ctx, groupID, keyID, oldStatus, newStatus, "error_threshold_reached")
|
||||
}
|
||||
return
|
||||
}
|
||||
logger.Infof("Ignored truly ignorable upstream error. Only updating LastUsedAt. Reason: %s", sanitizeForLog(errorMessage))
|
||||
logger.WithField("full_error_details", errorMessage).Debug("Full details of the ignorable error.")
|
||||
|
||||
logger.Infof("Ignorable error: %s", sanitizeForLog(errorMessage))
|
||||
if err := s.keyRepo.UpdateMapping(ctx, mapping); err != nil {
|
||||
logger.WithError(err).Error("Failed to update LastUsedAt for ignorable error.")
|
||||
logger.WithError(err).Error("Failed to update LastUsedAt")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -797,25 +821,27 @@ func (s *APIKeyService) revokeMasterKey(ctx context.Context, keyID uint, reason
|
||||
key, err := s.keyRepo.GetKeyByID(keyID)
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
s.logger.Warnf("Attempted to revoke non-existent key ID %d.", keyID)
|
||||
s.logger.Warnf("Attempted to revoke non-existent key %d", keyID)
|
||||
} else {
|
||||
s.logger.Errorf("Failed to get key by ID %d for master status revocation: %v", keyID, err)
|
||||
s.logger.Errorf("Failed to get key %d for revocation: %v", keyID, err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if key.MasterStatus == models.MasterStatusRevoked {
|
||||
return
|
||||
}
|
||||
|
||||
oldMasterStatus := key.MasterStatus
|
||||
newMasterStatus := models.MasterStatusRevoked
|
||||
if err := s.keyRepo.UpdateAPIKeyStatus(ctx, keyID, newMasterStatus); err != nil {
|
||||
s.logger.Errorf("Failed to update master status to REVOKED for key ID %d: %v", keyID, err)
|
||||
if err := s.keyRepo.UpdateAPIKeyStatus(ctx, keyID, models.MasterStatusRevoked); err != nil {
|
||||
s.logger.Errorf("Failed to revoke key %d: %v", keyID, err)
|
||||
return
|
||||
}
|
||||
|
||||
masterKeyEvent := models.MasterKeyStatusChangedEvent{
|
||||
KeyID: keyID,
|
||||
OldMasterStatus: oldMasterStatus,
|
||||
NewMasterStatus: newMasterStatus,
|
||||
NewMasterStatus: models.MasterStatusRevoked,
|
||||
ChangeReason: reason,
|
||||
ChangedAt: time.Now(),
|
||||
}
|
||||
@@ -826,3 +852,13 @@ func (s *APIKeyService) revokeMasterKey(ctx context.Context, keyID uint, reason
|
||||
func (s *APIKeyService) GetAPIKeyStringsForExport(ctx context.Context, groupID uint, statuses []string) ([]string, error) {
|
||||
return s.keyRepo.GetKeyStringsByGroupAndStatus(groupID, statuses)
|
||||
}
|
||||
|
||||
func sanitizeForLog(errMsg string) string {
|
||||
if idx := strings.Index(errMsg, "{"); idx != -1 {
|
||||
errMsg = strings.TrimSpace(errMsg[:idx]) + " {...}"
|
||||
}
|
||||
if len(errMsg) > 250 {
|
||||
return errMsg[:250] + "..."
|
||||
}
|
||||
return errMsg
|
||||
}
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
// Filename: internal/service/group_manager.go (Syncer升级版)
|
||||
|
||||
// Filename: internal/service/group_manager.go
|
||||
package service
|
||||
|
||||
import (
|
||||
@@ -29,10 +28,9 @@ type GroupManagerCacheData struct {
|
||||
Groups []*models.KeyGroup
|
||||
GroupsByName map[string]*models.KeyGroup
|
||||
GroupsByID map[uint]*models.KeyGroup
|
||||
KeyCounts map[uint]int64 // GroupID -> Total Key Count
|
||||
KeyStatusCounts map[uint]map[models.APIKeyStatus]int64 // GroupID -> Status -> Count
|
||||
KeyCounts map[uint]int64
|
||||
KeyStatusCounts map[uint]map[models.APIKeyStatus]int64
|
||||
}
|
||||
|
||||
type GroupManager struct {
|
||||
db *gorm.DB
|
||||
keyRepo repository.KeyRepository
|
||||
@@ -41,7 +39,6 @@ type GroupManager struct {
|
||||
syncer *syncer.CacheSyncer[GroupManagerCacheData]
|
||||
logger *logrus.Entry
|
||||
}
|
||||
|
||||
type UpdateOrderPayload struct {
|
||||
ID uint `json:"id" binding:"required"`
|
||||
Order int `json:"order"`
|
||||
@@ -49,43 +46,19 @@ type UpdateOrderPayload struct {
|
||||
|
||||
func NewGroupManagerLoader(db *gorm.DB, logger *logrus.Logger) syncer.LoaderFunc[GroupManagerCacheData] {
|
||||
return func() (GroupManagerCacheData, error) {
|
||||
logger.Debugf("[GML-LOG 1/5] ---> Entering NewGroupManagerLoader...")
|
||||
var groups []*models.KeyGroup
|
||||
logger.Debugf("[GML-LOG 2/5] About to execute DB query with Preloads...")
|
||||
|
||||
if err := db.Preload("AllowedUpstreams").
|
||||
Preload("AllowedModels").
|
||||
Preload("Settings").
|
||||
Preload("RequestConfig").
|
||||
Preload("Mappings").
|
||||
Find(&groups).Error; err != nil {
|
||||
logger.Errorf("[GML-LOG] CRITICAL: DB query for groups failed: %v", err)
|
||||
return GroupManagerCacheData{}, fmt.Errorf("failed to load key groups for cache: %w", err)
|
||||
return GroupManagerCacheData{}, fmt.Errorf("failed to load groups: %w", err)
|
||||
}
|
||||
logger.Debugf("[GML-LOG 2.1/5] DB query for groups finished. Found %d group records.", len(groups))
|
||||
|
||||
var allMappings []*models.GroupAPIKeyMapping
|
||||
if err := db.Find(&allMappings).Error; err != nil {
|
||||
logger.Errorf("[GML-LOG] CRITICAL: DB query for mappings failed: %v", err)
|
||||
return GroupManagerCacheData{}, fmt.Errorf("failed to load key mappings for cache: %w", err)
|
||||
}
|
||||
logger.Debugf("[GML-LOG 2.2/5] DB query for mappings finished. Found %d total mapping records.", len(allMappings))
|
||||
|
||||
mappingsByGroupID := make(map[uint][]*models.GroupAPIKeyMapping)
|
||||
for i := range allMappings {
|
||||
mapping := allMappings[i] // Avoid pointer issues with range
|
||||
mappingsByGroupID[mapping.KeyGroupID] = append(mappingsByGroupID[mapping.KeyGroupID], mapping)
|
||||
}
|
||||
|
||||
for _, group := range groups {
|
||||
if mappings, ok := mappingsByGroupID[group.ID]; ok {
|
||||
group.Mappings = mappings
|
||||
}
|
||||
}
|
||||
logger.Debugf("[GML-LOG 3/5] Finished manually associating mappings to groups.")
|
||||
|
||||
keyCounts := make(map[uint]int64, len(groups))
|
||||
keyStatusCounts := make(map[uint]map[models.APIKeyStatus]int64, len(groups))
|
||||
|
||||
groupsByName := make(map[string]*models.KeyGroup, len(groups))
|
||||
groupsByID := make(map[uint]*models.KeyGroup, len(groups))
|
||||
for _, group := range groups {
|
||||
keyCounts[group.ID] = int64(len(group.Mappings))
|
||||
statusCounts := make(map[models.APIKeyStatus]int64)
|
||||
@@ -93,20 +66,9 @@ func NewGroupManagerLoader(db *gorm.DB, logger *logrus.Logger) syncer.LoaderFunc
|
||||
statusCounts[mapping.Status]++
|
||||
}
|
||||
keyStatusCounts[group.ID] = statusCounts
|
||||
groupsByName[group.Name] = group
|
||||
groupsByID[group.ID] = group
|
||||
}
|
||||
groupsByName := make(map[string]*models.KeyGroup, len(groups))
|
||||
groupsByID := make(map[uint]*models.KeyGroup, len(groups))
|
||||
|
||||
logger.Debugf("[GML-LOG 4/5] Starting to process group records into maps...")
|
||||
for i, group := range groups {
|
||||
if group == nil {
|
||||
logger.Debugf("[GML] CRITICAL: Found a 'nil' group pointer at index %d! This is the most likely cause of the panic.", i)
|
||||
} else {
|
||||
groupsByName[group.Name] = group
|
||||
groupsByID[group.ID] = group
|
||||
}
|
||||
}
|
||||
logger.Debugf("[GML-LOG 5/5] Finished processing records. Building final cache data...")
|
||||
return GroupManagerCacheData{
|
||||
Groups: groups,
|
||||
GroupsByName: groupsByName,
|
||||
@@ -116,7 +78,6 @@ func NewGroupManagerLoader(db *gorm.DB, logger *logrus.Logger) syncer.LoaderFunc
|
||||
}, nil
|
||||
}
|
||||
}
|
||||
|
||||
func NewGroupManager(
|
||||
db *gorm.DB,
|
||||
keyRepo repository.KeyRepository,
|
||||
@@ -134,138 +95,67 @@ func NewGroupManager(
|
||||
logger: logger.WithField("component", "GroupManager"),
|
||||
}
|
||||
}
|
||||
|
||||
func (gm *GroupManager) GetAllGroups() []*models.KeyGroup {
|
||||
cache := gm.syncer.Get()
|
||||
if len(cache.Groups) == 0 {
|
||||
return []*models.KeyGroup{}
|
||||
}
|
||||
groupsToOrder := cache.Groups
|
||||
sort.Slice(groupsToOrder, func(i, j int) bool {
|
||||
if groupsToOrder[i].Order != groupsToOrder[j].Order {
|
||||
return groupsToOrder[i].Order < groupsToOrder[j].Order
|
||||
groups := gm.syncer.Get().Groups
|
||||
sort.Slice(groups, func(i, j int) bool {
|
||||
if groups[i].Order != groups[j].Order {
|
||||
return groups[i].Order < groups[j].Order
|
||||
}
|
||||
return groupsToOrder[i].ID < groupsToOrder[j].ID
|
||||
return groups[i].ID < groups[j].ID
|
||||
})
|
||||
return groupsToOrder
|
||||
return groups
|
||||
}
|
||||
|
||||
func (gm *GroupManager) GetKeyCount(groupID uint) int64 {
|
||||
cache := gm.syncer.Get()
|
||||
if len(cache.KeyCounts) == 0 {
|
||||
return 0
|
||||
}
|
||||
count := cache.KeyCounts[groupID]
|
||||
return count
|
||||
return gm.syncer.Get().KeyCounts[groupID]
|
||||
}
|
||||
|
||||
func (gm *GroupManager) GetKeyStatusCount(groupID uint) map[models.APIKeyStatus]int64 {
|
||||
cache := gm.syncer.Get()
|
||||
if len(cache.KeyStatusCounts) == 0 {
|
||||
return make(map[models.APIKeyStatus]int64)
|
||||
}
|
||||
if counts, ok := cache.KeyStatusCounts[groupID]; ok {
|
||||
if counts, ok := gm.syncer.Get().KeyStatusCounts[groupID]; ok {
|
||||
return counts
|
||||
}
|
||||
return make(map[models.APIKeyStatus]int64)
|
||||
}
|
||||
|
||||
func (gm *GroupManager) GetGroupByName(name string) (*models.KeyGroup, bool) {
|
||||
cache := gm.syncer.Get()
|
||||
if len(cache.GroupsByName) == 0 {
|
||||
return nil, false
|
||||
}
|
||||
group, ok := cache.GroupsByName[name]
|
||||
group, ok := gm.syncer.Get().GroupsByName[name]
|
||||
return group, ok
|
||||
}
|
||||
|
||||
func (gm *GroupManager) GetGroupByID(id uint) (*models.KeyGroup, bool) {
|
||||
cache := gm.syncer.Get()
|
||||
if len(cache.GroupsByID) == 0 {
|
||||
return nil, false
|
||||
}
|
||||
group, ok := cache.GroupsByID[id]
|
||||
group, ok := gm.syncer.Get().GroupsByID[id]
|
||||
return group, ok
|
||||
}
|
||||
|
||||
func (gm *GroupManager) Stop() {
|
||||
gm.syncer.Stop()
|
||||
}
|
||||
|
||||
func (gm *GroupManager) Invalidate() error {
|
||||
return gm.syncer.Invalidate()
|
||||
}
|
||||
|
||||
// --- Write Operations ---
|
||||
|
||||
// CreateKeyGroup creates a new key group, including its operational settings, and invalidates the cache.
|
||||
func (gm *GroupManager) CreateKeyGroup(group *models.KeyGroup, settings *models.KeyGroupSettings) error {
|
||||
if !utils.IsValidGroupName(group.Name) {
|
||||
return errors.New("invalid group name: must contain only lowercase letters, numbers, and hyphens")
|
||||
}
|
||||
err := gm.db.Transaction(func(tx *gorm.DB) error {
|
||||
// 1. Create the group itself to get an ID
|
||||
if err := tx.Create(group).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 2. If settings are provided, create the associated GroupSettings record
|
||||
if settings != nil {
|
||||
// Only marshal non-nil fields to keep the JSON clean
|
||||
settingsToMarshal := make(map[string]interface{})
|
||||
if settings.EnableKeyCheck != nil {
|
||||
settingsToMarshal["enable_key_check"] = settings.EnableKeyCheck
|
||||
settingsJSON, err := json.Marshal(settings)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal settings: %w", err)
|
||||
}
|
||||
if settings.KeyCheckIntervalMinutes != nil {
|
||||
settingsToMarshal["key_check_interval_minutes"] = settings.KeyCheckIntervalMinutes
|
||||
groupSettings := models.GroupSettings{
|
||||
GroupID: group.ID,
|
||||
SettingsJSON: datatypes.JSON(settingsJSON),
|
||||
}
|
||||
if settings.KeyBlacklistThreshold != nil {
|
||||
settingsToMarshal["key_blacklist_threshold"] = settings.KeyBlacklistThreshold
|
||||
}
|
||||
if settings.KeyCooldownMinutes != nil {
|
||||
settingsToMarshal["key_cooldown_minutes"] = settings.KeyCooldownMinutes
|
||||
}
|
||||
if settings.KeyCheckConcurrency != nil {
|
||||
settingsToMarshal["key_check_concurrency"] = settings.KeyCheckConcurrency
|
||||
}
|
||||
if settings.KeyCheckEndpoint != nil {
|
||||
settingsToMarshal["key_check_endpoint"] = settings.KeyCheckEndpoint
|
||||
}
|
||||
if settings.KeyCheckModel != nil {
|
||||
settingsToMarshal["key_check_model"] = settings.KeyCheckModel
|
||||
}
|
||||
if settings.MaxRetries != nil {
|
||||
settingsToMarshal["max_retries"] = settings.MaxRetries
|
||||
}
|
||||
if settings.EnableSmartGateway != nil {
|
||||
settingsToMarshal["enable_smart_gateway"] = settings.EnableSmartGateway
|
||||
}
|
||||
if len(settingsToMarshal) > 0 {
|
||||
settingsJSON, err := json.Marshal(settingsToMarshal)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal group settings: %w", err)
|
||||
}
|
||||
groupSettings := models.GroupSettings{
|
||||
GroupID: group.ID,
|
||||
SettingsJSON: datatypes.JSON(settingsJSON),
|
||||
}
|
||||
if err := tx.Create(&groupSettings).Error; err != nil {
|
||||
return fmt.Errorf("failed to save group settings: %w", err)
|
||||
}
|
||||
if err := tx.Create(&groupSettings).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
if err == nil {
|
||||
go gm.Invalidate()
|
||||
}
|
||||
|
||||
go gm.Invalidate()
|
||||
return nil
|
||||
return err
|
||||
}
|
||||
|
||||
// UpdateKeyGroup updates an existing key group, its settings, and associations, then invalidates the cache.
|
||||
func (gm *GroupManager) UpdateKeyGroup(group *models.KeyGroup, newSettings *models.KeyGroupSettings, upstreamURLs []string, modelNames []string) error {
|
||||
if !utils.IsValidGroupName(group.Name) {
|
||||
return fmt.Errorf("invalid group name: must contain only lowercase letters, numbers, and hyphens")
|
||||
@@ -273,7 +163,6 @@ func (gm *GroupManager) UpdateKeyGroup(group *models.KeyGroup, newSettings *mode
|
||||
uniqueUpstreamURLs := uniqueStrings(upstreamURLs)
|
||||
uniqueModelNames := uniqueStrings(modelNames)
|
||||
err := gm.db.Transaction(func(tx *gorm.DB) error {
|
||||
// --- 1. Update AllowedUpstreams (M:N relationship) ---
|
||||
var upstreams []*models.UpstreamEndpoint
|
||||
if len(uniqueUpstreamURLs) > 0 {
|
||||
if err := tx.Where("url IN ?", uniqueUpstreamURLs).Find(&upstreams).Error; err != nil {
|
||||
@@ -283,7 +172,6 @@ func (gm *GroupManager) UpdateKeyGroup(group *models.KeyGroup, newSettings *mode
|
||||
if err := tx.Model(group).Association("AllowedUpstreams").Replace(upstreams); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := tx.Model(group).Association("AllowedModels").Clear(); err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -296,11 +184,9 @@ func (gm *GroupManager) UpdateKeyGroup(group *models.KeyGroup, newSettings *mode
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
if err := tx.Model(group).Updates(group).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var existingSettings models.GroupSettings
|
||||
if err := tx.Where("group_id = ?", group.ID).First(&existingSettings).Error; err != nil && err != gorm.ErrRecordNotFound {
|
||||
return err
|
||||
@@ -308,15 +194,15 @@ func (gm *GroupManager) UpdateKeyGroup(group *models.KeyGroup, newSettings *mode
|
||||
var currentSettingsData models.KeyGroupSettings
|
||||
if len(existingSettings.SettingsJSON) > 0 {
|
||||
if err := json.Unmarshal(existingSettings.SettingsJSON, ¤tSettingsData); err != nil {
|
||||
return fmt.Errorf("failed to unmarshal existing group settings: %w", err)
|
||||
return fmt.Errorf("failed to unmarshal existing settings: %w", err)
|
||||
}
|
||||
}
|
||||
if err := reflectutil.MergeNilFields(¤tSettingsData, newSettings); err != nil {
|
||||
return fmt.Errorf("failed to merge group settings: %w", err)
|
||||
return fmt.Errorf("failed to merge settings: %w", err)
|
||||
}
|
||||
updatedJSON, err := json.Marshal(currentSettingsData)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal updated group settings: %w", err)
|
||||
return fmt.Errorf("failed to marshal updated settings: %w", err)
|
||||
}
|
||||
existingSettings.GroupID = group.ID
|
||||
existingSettings.SettingsJSON = datatypes.JSON(updatedJSON)
|
||||
@@ -327,55 +213,25 @@ func (gm *GroupManager) UpdateKeyGroup(group *models.KeyGroup, newSettings *mode
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// DeleteKeyGroup deletes a key group and subsequently cleans up any keys that have become orphans.
|
||||
func (gm *GroupManager) DeleteKeyGroup(id uint) error {
|
||||
err := gm.db.Transaction(func(tx *gorm.DB) error {
|
||||
gm.logger.Infof("Starting transaction to delete KeyGroup ID: %d", id)
|
||||
// Step 1: First, retrieve the group object we are about to delete.
|
||||
var group models.KeyGroup
|
||||
if err := tx.First(&group, id).Error; err != nil {
|
||||
if err == gorm.ErrRecordNotFound {
|
||||
gm.logger.Warnf("Attempted to delete a non-existent KeyGroup with ID: %d", id)
|
||||
return nil // Don't treat as an error, the group is already gone.
|
||||
return nil
|
||||
}
|
||||
gm.logger.WithError(err).Errorf("Failed to find KeyGroup with ID: %d for deletion", id)
|
||||
return err
|
||||
}
|
||||
// Step 2: Clear all many-to-many and one-to-many associations using GORM's safe methods.
|
||||
if err := tx.Model(&group).Association("AllowedUpstreams").Clear(); err != nil {
|
||||
gm.logger.WithError(err).Errorf("Failed to clear 'AllowedUpstreams' association for KeyGroup ID: %d", id)
|
||||
if err := tx.Select("AllowedUpstreams", "AllowedModels", "Mappings", "Settings").Delete(&group).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
if err := tx.Model(&group).Association("AllowedModels").Clear(); err != nil {
|
||||
gm.logger.WithError(err).Errorf("Failed to clear 'AllowedModels' association for KeyGroup ID: %d", id)
|
||||
return err
|
||||
}
|
||||
if err := tx.Model(&group).Association("Mappings").Clear(); err != nil {
|
||||
gm.logger.WithError(err).Errorf("Failed to clear 'Mappings' (API Key associations) for KeyGroup ID: %d", id)
|
||||
return err
|
||||
}
|
||||
// Also clear settings if they exist to maintain data integrity
|
||||
if err := tx.Model(&group).Association("Settings").Delete(group.Settings); err != nil {
|
||||
gm.logger.WithError(err).Errorf("Failed to delete 'Settings' association for KeyGroup ID: %d", id)
|
||||
return err
|
||||
}
|
||||
// Step 3: Delete the KeyGroup itself.
|
||||
if err := tx.Delete(&group).Error; err != nil {
|
||||
gm.logger.WithError(err).Errorf("Failed to delete KeyGroup ID: %d", id)
|
||||
return err
|
||||
}
|
||||
gm.logger.Infof("KeyGroup ID %d associations cleared and entity deleted. Triggering orphan key cleanup.", id)
|
||||
// Step 4: Trigger the orphan key cleanup (this logic remains the same and is correct).
|
||||
deletedCount, err := gm.keyRepo.DeleteOrphanKeysTx(tx)
|
||||
if err != nil {
|
||||
gm.logger.WithError(err).Error("Failed to clean up orphan keys after deleting group.")
|
||||
return err
|
||||
}
|
||||
if deletedCount > 0 {
|
||||
gm.logger.Infof("Successfully cleaned up %d orphan keys.", deletedCount)
|
||||
gm.logger.Infof("Cleaned up %d orphan keys after deleting group %d", deletedCount, id)
|
||||
}
|
||||
gm.logger.Infof("Transaction for deleting KeyGroup ID: %d completed successfully.", id)
|
||||
return nil
|
||||
})
|
||||
if err == nil {
|
||||
@@ -383,7 +239,6 @@ func (gm *GroupManager) DeleteKeyGroup(id uint) error {
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func (gm *GroupManager) CloneKeyGroup(id uint) (*models.KeyGroup, error) {
|
||||
var originalGroup models.KeyGroup
|
||||
if err := gm.db.
|
||||
@@ -392,7 +247,7 @@ func (gm *GroupManager) CloneKeyGroup(id uint) (*models.KeyGroup, error) {
|
||||
Preload("AllowedUpstreams").
|
||||
Preload("AllowedModels").
|
||||
First(&originalGroup, id).Error; err != nil {
|
||||
return nil, fmt.Errorf("failed to find original group with id %d: %w", id, err)
|
||||
return nil, fmt.Errorf("failed to find original group %d: %w", id, err)
|
||||
}
|
||||
newGroup := originalGroup
|
||||
timestamp := time.Now().Unix()
|
||||
@@ -401,31 +256,25 @@ func (gm *GroupManager) CloneKeyGroup(id uint) (*models.KeyGroup, error) {
|
||||
newGroup.DisplayName = fmt.Sprintf("%s-clone-%d", originalGroup.DisplayName, timestamp)
|
||||
newGroup.CreatedAt = time.Time{}
|
||||
newGroup.UpdatedAt = time.Time{}
|
||||
|
||||
newGroup.RequestConfigID = nil
|
||||
newGroup.RequestConfig = nil
|
||||
newGroup.Mappings = nil
|
||||
newGroup.AllowedUpstreams = nil
|
||||
newGroup.AllowedModels = nil
|
||||
err := gm.db.Transaction(func(tx *gorm.DB) error {
|
||||
|
||||
if err := tx.Create(&newGroup).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if originalGroup.RequestConfig != nil {
|
||||
newRequestConfig := *originalGroup.RequestConfig
|
||||
newRequestConfig.ID = 0 // Mark as new record
|
||||
|
||||
newRequestConfig.ID = 0
|
||||
if err := tx.Create(&newRequestConfig).Error; err != nil {
|
||||
return fmt.Errorf("failed to clone request config: %w", err)
|
||||
}
|
||||
|
||||
if err := tx.Model(&newGroup).Update("request_config_id", newRequestConfig.ID).Error; err != nil {
|
||||
return fmt.Errorf("failed to link new group to cloned request config: %w", err)
|
||||
return fmt.Errorf("failed to link cloned request config: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
var originalSettings models.GroupSettings
|
||||
err := tx.Where("group_id = ?", originalGroup.ID).First(&originalSettings).Error
|
||||
if err == nil && len(originalSettings.SettingsJSON) > 0 {
|
||||
@@ -434,12 +283,11 @@ func (gm *GroupManager) CloneKeyGroup(id uint) (*models.KeyGroup, error) {
|
||||
SettingsJSON: originalSettings.SettingsJSON,
|
||||
}
|
||||
if err := tx.Create(&newSettings).Error; err != nil {
|
||||
return fmt.Errorf("failed to clone group settings: %w", err)
|
||||
return fmt.Errorf("failed to clone settings: %w", err)
|
||||
}
|
||||
} else if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return fmt.Errorf("failed to query original group settings: %w", err)
|
||||
return fmt.Errorf("failed to query original settings: %w", err)
|
||||
}
|
||||
|
||||
if len(originalGroup.Mappings) > 0 {
|
||||
newMappings := make([]models.GroupAPIKeyMapping, len(originalGroup.Mappings))
|
||||
for i, oldMapping := range originalGroup.Mappings {
|
||||
@@ -454,7 +302,7 @@ func (gm *GroupManager) CloneKeyGroup(id uint) (*models.KeyGroup, error) {
|
||||
}
|
||||
}
|
||||
if err := tx.Create(&newMappings).Error; err != nil {
|
||||
return fmt.Errorf("failed to clone key group mappings: %w", err)
|
||||
return fmt.Errorf("failed to clone mappings: %w", err)
|
||||
}
|
||||
}
|
||||
if len(originalGroup.AllowedUpstreams) > 0 {
|
||||
@@ -469,13 +317,10 @@ func (gm *GroupManager) CloneKeyGroup(id uint) (*models.KeyGroup, error) {
|
||||
}
|
||||
return nil
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
go gm.Invalidate()
|
||||
|
||||
var finalClonedGroup models.KeyGroup
|
||||
if err := gm.db.
|
||||
Preload("RequestConfig").
|
||||
@@ -487,10 +332,9 @@ func (gm *GroupManager) CloneKeyGroup(id uint) (*models.KeyGroup, error) {
|
||||
}
|
||||
return &finalClonedGroup, nil
|
||||
}
|
||||
|
||||
func (gm *GroupManager) BuildOperationalConfig(group *models.KeyGroup) (*models.KeyGroupSettings, error) {
|
||||
globalSettings := gm.settingsManager.GetSettings()
|
||||
s := "gemini-1.5-flash" // Per user feedback for default model
|
||||
defaultModel := "gemini-1.5-flash"
|
||||
opConfig := &models.KeyGroupSettings{
|
||||
EnableKeyCheck: &globalSettings.EnableBaseKeyCheck,
|
||||
KeyCheckConcurrency: &globalSettings.BaseKeyCheckConcurrency,
|
||||
@@ -498,52 +342,43 @@ func (gm *GroupManager) BuildOperationalConfig(group *models.KeyGroup) (*models.
|
||||
KeyCheckEndpoint: &globalSettings.DefaultUpstreamURL,
|
||||
KeyBlacklistThreshold: &globalSettings.BlacklistThreshold,
|
||||
KeyCooldownMinutes: &globalSettings.KeyCooldownMinutes,
|
||||
KeyCheckModel: &s,
|
||||
KeyCheckModel: &defaultModel,
|
||||
MaxRetries: &globalSettings.MaxRetries,
|
||||
EnableSmartGateway: &globalSettings.EnableSmartGateway,
|
||||
}
|
||||
|
||||
if group == nil {
|
||||
return opConfig, nil
|
||||
}
|
||||
|
||||
var groupSettingsRecord models.GroupSettings
|
||||
err := gm.db.Where("group_id = ?", group.ID).First(&groupSettingsRecord).Error
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return opConfig, nil
|
||||
}
|
||||
|
||||
gm.logger.WithError(err).Errorf("Failed to query group settings for group ID %d", group.ID)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if len(groupSettingsRecord.SettingsJSON) == 0 {
|
||||
return opConfig, nil
|
||||
}
|
||||
|
||||
var groupSpecificSettings models.KeyGroupSettings
|
||||
if err := json.Unmarshal(groupSettingsRecord.SettingsJSON, &groupSpecificSettings); err != nil {
|
||||
gm.logger.WithError(err).WithField("group_id", group.ID).Warn("Failed to unmarshal group settings JSON.")
|
||||
return opConfig, err
|
||||
}
|
||||
|
||||
if err := reflectutil.MergeNilFields(opConfig, &groupSpecificSettings); err != nil {
|
||||
gm.logger.WithError(err).WithField("group_id", group.ID).Error("Failed to merge group-specific settings over defaults.")
|
||||
gm.logger.WithError(err).WithField("group_id", group.ID).Warn("Failed to unmarshal group settings")
|
||||
return opConfig, nil
|
||||
}
|
||||
if err := reflectutil.MergeNilFields(opConfig, &groupSpecificSettings); err != nil {
|
||||
gm.logger.WithError(err).WithField("group_id", group.ID).Error("Failed to merge group settings")
|
||||
return opConfig, nil
|
||||
}
|
||||
|
||||
return opConfig, nil
|
||||
}
|
||||
|
||||
func (gm *GroupManager) BuildKeyCheckEndpoint(groupID uint) (string, error) {
|
||||
group, ok := gm.GetGroupByID(groupID)
|
||||
if !ok {
|
||||
return "", fmt.Errorf("group with id %d not found", groupID)
|
||||
return "", fmt.Errorf("group %d not found", groupID)
|
||||
}
|
||||
opConfig, err := gm.BuildOperationalConfig(group)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to build operational config for group %d: %w", groupID, err)
|
||||
return "", err
|
||||
}
|
||||
globalSettings := gm.settingsManager.GetSettings()
|
||||
baseURL := globalSettings.DefaultUpstreamURL
|
||||
@@ -551,7 +386,7 @@ func (gm *GroupManager) BuildKeyCheckEndpoint(groupID uint) (string, error) {
|
||||
baseURL = *opConfig.KeyCheckEndpoint
|
||||
}
|
||||
if baseURL == "" {
|
||||
return "", fmt.Errorf("no key check endpoint or default upstream URL is configured for group %d", groupID)
|
||||
return "", fmt.Errorf("no endpoint configured for group %d", groupID)
|
||||
}
|
||||
modelName := globalSettings.BaseKeyCheckModel
|
||||
if opConfig.KeyCheckModel != nil && *opConfig.KeyCheckModel != "" {
|
||||
@@ -559,38 +394,31 @@ func (gm *GroupManager) BuildKeyCheckEndpoint(groupID uint) (string, error) {
|
||||
}
|
||||
parsedURL, err := url.Parse(baseURL)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to parse base URL '%s': %w", baseURL, err)
|
||||
return "", fmt.Errorf("invalid URL '%s': %w", baseURL, err)
|
||||
}
|
||||
cleanedPath := parsedURL.Path
|
||||
cleanedPath = strings.TrimSuffix(cleanedPath, "/")
|
||||
cleanedPath = strings.TrimSuffix(cleanedPath, "/v1beta")
|
||||
parsedURL.Path = path.Join(cleanedPath, "v1beta", "models", modelName)
|
||||
finalEndpoint := parsedURL.String()
|
||||
return finalEndpoint, nil
|
||||
cleanedPath := strings.TrimSuffix(strings.TrimSuffix(parsedURL.Path, "/"), "/v1beta")
|
||||
parsedURL.Path = path.Join(cleanedPath, "v1beta/models", modelName)
|
||||
return parsedURL.String(), nil
|
||||
}
|
||||
|
||||
func (gm *GroupManager) UpdateOrder(payload []UpdateOrderPayload) error {
|
||||
ordersMap := make(map[uint]int, len(payload))
|
||||
for _, item := range payload {
|
||||
ordersMap[item.ID] = item.Order
|
||||
}
|
||||
if err := gm.groupRepo.UpdateOrderInTransaction(ordersMap); err != nil {
|
||||
gm.logger.WithError(err).Error("Failed to update group order in transaction")
|
||||
return fmt.Errorf("database transaction failed: %w", err)
|
||||
return fmt.Errorf("failed to update order: %w", err)
|
||||
}
|
||||
gm.logger.Info("Group order updated successfully, invalidating cache...")
|
||||
go gm.Invalidate()
|
||||
return nil
|
||||
}
|
||||
|
||||
func uniqueStrings(slice []string) []string {
|
||||
keys := make(map[string]struct{})
|
||||
list := []string{}
|
||||
for _, entry := range slice {
|
||||
if _, value := keys[entry]; !value {
|
||||
keys[entry] = struct{}{}
|
||||
list = append(list, entry)
|
||||
seen := make(map[string]struct{}, len(slice))
|
||||
result := make([]string, 0, len(slice))
|
||||
for _, s := range slice {
|
||||
if _, exists := seen[s]; !exists {
|
||||
seen[s] = struct{}{}
|
||||
result = append(result, s)
|
||||
}
|
||||
}
|
||||
return list
|
||||
return result
|
||||
}
|
||||
|
||||
@@ -704,15 +704,17 @@ func (p *memoryPipeliner) LRem(key string, count int64, value any) {
|
||||
newList := make([]string, 0, len(list))
|
||||
removed := int64(0)
|
||||
for _, v := range list {
|
||||
if count != 0 && v == capturedValue && (count < 0 || removed < count) {
|
||||
shouldRemove := v == capturedValue && (count == 0 || removed < count)
|
||||
if shouldRemove {
|
||||
removed++
|
||||
continue
|
||||
} else {
|
||||
newList = append(newList, v)
|
||||
}
|
||||
newList = append(newList, v)
|
||||
}
|
||||
item.value = newList
|
||||
})
|
||||
}
|
||||
|
||||
func (p *memoryPipeliner) HSet(key string, values map[string]any) {
|
||||
capturedKey := key
|
||||
capturedValues := make(map[string]any, len(values))
|
||||
@@ -762,17 +764,31 @@ func (p *memoryPipeliner) ZAdd(key string, members map[string]float64) {
|
||||
p.ops = append(p.ops, func() {
|
||||
item, ok := p.store.items[capturedKey]
|
||||
if !ok || item.isExpired() {
|
||||
item = &memoryStoreItem{value: make(map[string]float64)}
|
||||
item = &memoryStoreItem{value: make([]zsetMember, 0)}
|
||||
p.store.items[capturedKey] = item
|
||||
}
|
||||
zset, ok := item.value.(map[string]float64)
|
||||
zset, ok := item.value.([]zsetMember)
|
||||
if !ok {
|
||||
zset = make(map[string]float64)
|
||||
item.value = zset
|
||||
zset = make([]zsetMember, 0)
|
||||
}
|
||||
for member, score := range capturedMembers {
|
||||
zset[member] = score
|
||||
membersMap := make(map[string]float64, len(zset))
|
||||
for _, z := range zset {
|
||||
membersMap[z.Value] = z.Score
|
||||
}
|
||||
for memberVal, score := range capturedMembers {
|
||||
membersMap[memberVal] = score
|
||||
}
|
||||
newZSet := make([]zsetMember, 0, len(membersMap))
|
||||
for val, score := range membersMap {
|
||||
newZSet = append(newZSet, zsetMember{Value: val, Score: score})
|
||||
}
|
||||
sort.Slice(newZSet, func(i, j int) bool {
|
||||
if newZSet[i].Score == newZSet[j].Score {
|
||||
return newZSet[i].Value < newZSet[j].Value
|
||||
}
|
||||
return newZSet[i].Score < newZSet[j].Score
|
||||
})
|
||||
item.value = newZSet
|
||||
})
|
||||
}
|
||||
func (p *memoryPipeliner) ZRem(key string, members ...any) {
|
||||
@@ -784,13 +800,21 @@ func (p *memoryPipeliner) ZRem(key string, members ...any) {
|
||||
if !ok || item.isExpired() {
|
||||
return
|
||||
}
|
||||
zset, ok := item.value.(map[string]float64)
|
||||
zset, ok := item.value.([]zsetMember)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
for _, member := range capturedMembers {
|
||||
delete(zset, fmt.Sprintf("%v", member))
|
||||
membersToRemove := make(map[string]struct{}, len(capturedMembers))
|
||||
for _, m := range capturedMembers {
|
||||
membersToRemove[fmt.Sprintf("%v", m)] = struct{}{}
|
||||
}
|
||||
newZSet := make([]zsetMember, 0, len(zset))
|
||||
for _, z := range zset {
|
||||
if _, exists := membersToRemove[z.Value]; !exists {
|
||||
newZSet = append(newZSet, z)
|
||||
}
|
||||
}
|
||||
item.value = newZSet
|
||||
})
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user