Fix requestTimeout & memory store
@@ -92,7 +92,6 @@ func BuildContainer() (*dig.Container, error) {
 
 	// =========== Phase 3: Handlers & Adapters ===========
 
-	// Provide dependencies for Channel (Logger and the *models.SystemSettings data socket)
 	container.Provide(channel.NewGeminiChannel)
 	container.Provide(func(ch *channel.GeminiChannel) channel.ChannelProxy { return ch })
 
@@ -197,8 +197,8 @@ func (h *ProxyHandler) executeProxyAttempt(c *gin.Context, corrID string, body [
 	var attemptErr *errors.APIError
 	var isSuccess bool
 
-	connectTimeout := time.Duration(h.settingsManager.GetSettings().ConnectTimeoutSeconds) * time.Second
-	ctx, cancel := context.WithTimeout(c.Request.Context(), connectTimeout)
+	requestTimeout := time.Duration(h.settingsManager.GetSettings().RequestTimeoutSeconds) * time.Second
+	ctx, cancel := context.WithTimeout(c.Request.Context(), requestTimeout)
 	defer cancel()
 
 	attemptReq := c.Request.Clone(ctx)
@@ -29,10 +29,6 @@ const (
 	TaskTypeUpdateStatusByFilter = "update_status_by_filter"
 )
 
-const (
-	TestEndpoint = "https://generativelanguage.googleapis.com/v1beta/models"
-)
-
 type BatchRestoreResult struct {
 	RestoredCount int `json:"restored_count"`
 	SkippedCount  int `json:"skipped_count"`
@@ -52,12 +48,6 @@ type PaginatedAPIKeys struct {
 	TotalPages int `json:"total_pages"`
 }
 
-type KeyTestResult struct {
-	Key     string `json:"key"`
-	Status  string `json:"status"`
-	Message string `json:"message"`
-}
-
 type APIKeyService struct {
 	db      *gorm.DB
 	keyRepo repository.KeyRepository
@@ -99,37 +89,39 @@ func NewAPIKeyService(
 func (s *APIKeyService) Start() {
 	requestSub, err := s.store.Subscribe(context.Background(), models.TopicRequestFinished)
 	if err != nil {
-		s.logger.Fatalf("Failed to subscribe to topic %s: %v", models.TopicRequestFinished, err)
+		s.logger.Fatalf("Failed to subscribe to %s: %v", models.TopicRequestFinished, err)
 		return
 	}
 	masterKeySub, err := s.store.Subscribe(context.Background(), models.TopicMasterKeyStatusChanged)
 	if err != nil {
-		s.logger.Fatalf("Failed to subscribe to topic %s: %v", models.TopicMasterKeyStatusChanged, err)
+		s.logger.Fatalf("Failed to subscribe to %s: %v", models.TopicMasterKeyStatusChanged, err)
 		return
 	}
 	keyStatusSub, err := s.store.Subscribe(context.Background(), models.TopicKeyStatusChanged)
 	if err != nil {
-		s.logger.Fatalf("Failed to subscribe to topic %s: %v", models.TopicKeyStatusChanged, err)
+		s.logger.Fatalf("Failed to subscribe to %s: %v", models.TopicKeyStatusChanged, err)
 		return
 	}
 	importSub, err := s.store.Subscribe(context.Background(), models.TopicImportGroupCompleted)
 	if err != nil {
-		s.logger.Fatalf("Failed to subscribe to topic %s: %v", models.TopicImportGroupCompleted, err)
+		s.logger.Fatalf("Failed to subscribe to %s: %v", models.TopicImportGroupCompleted, err)
 		return
 	}
-	s.logger.Info("Started and subscribed to request, master key, health check, and import events.")
+	s.logger.Info("Started and subscribed to all event topics")
 
 	go func() {
		defer requestSub.Close()
 		defer masterKeySub.Close()
 		defer keyStatusSub.Close()
 		defer importSub.Close()
 
 		for {
 			select {
 			case msg := <-requestSub.Channel():
 				var event models.RequestFinishedEvent
 				if err := json.Unmarshal(msg.Payload, &event); err != nil {
-					s.logger.Errorf("Failed to unmarshal event for key status update: %v", err)
+					s.logger.WithError(err).Error("Failed to unmarshal RequestFinishedEvent")
 					continue
 				}
 				s.handleKeyUsageEvent(&event)
@@ -137,14 +129,15 @@ func (s *APIKeyService) Start() {
 			case msg := <-masterKeySub.Channel():
 				var event models.MasterKeyStatusChangedEvent
 				if err := json.Unmarshal(msg.Payload, &event); err != nil {
-					s.logger.Errorf("Failed to unmarshal MasterKeyStatusChangedEvent: %v", err)
+					s.logger.WithError(err).Error("Failed to unmarshal MasterKeyStatusChangedEvent")
 					continue
 				}
 				s.handleMasterKeyStatusChangeEvent(&event)
 
 			case msg := <-keyStatusSub.Channel():
 				var event models.KeyStatusChangedEvent
 				if err := json.Unmarshal(msg.Payload, &event); err != nil {
-					s.logger.Errorf("Failed to unmarshal KeyStatusChangedEvent: %v", err)
+					s.logger.WithError(err).Error("Failed to unmarshal KeyStatusChangedEvent")
 					continue
 				}
 				s.handleKeyStatusChangeEvent(&event)
@@ -152,15 +145,14 @@ func (s *APIKeyService) Start() {
 			case msg := <-importSub.Channel():
 				var event models.ImportGroupCompletedEvent
 				if err := json.Unmarshal(msg.Payload, &event); err != nil {
-					s.logger.WithError(err).Error("Failed to unmarshal ImportGroupCompletedEvent.")
+					s.logger.WithError(err).Error("Failed to unmarshal ImportGroupCompletedEvent")
 					continue
 				}
-				s.logger.Infof("Received ImportGroupCompletedEvent for group %d, triggering validation for %d keys.", event.GroupID, len(event.KeyIDs))
+				s.logger.Infof("Received import completion for group %d, validating %d keys", event.GroupID, len(event.KeyIDs))
 
 				go s.handlePostImportValidation(&event)
 
 			case <-s.stopChan:
-				s.logger.Info("Stopping event listener.")
+				s.logger.Info("Stopping event listener")
 				return
 			}
 		}
@@ -175,13 +167,18 @@ func (s *APIKeyService) handleKeyUsageEvent(event *models.RequestFinishedEvent)
 	if event.RequestLog.KeyID == nil || event.RequestLog.GroupID == nil {
 		return
 	}
 
 	ctx := context.Background()
+	groupID := *event.RequestLog.GroupID
+	keyID := *event.RequestLog.KeyID
 
 	if event.RequestLog.IsSuccess {
-		mapping, err := s.keyRepo.GetMapping(*event.RequestLog.GroupID, *event.RequestLog.KeyID)
+		mapping, err := s.keyRepo.GetMapping(groupID, keyID)
 		if err != nil {
-			s.logger.Warnf("[%s] Could not find mapping for G:%d K:%d on successful request: %v", event.CorrelationID, *event.RequestLog.GroupID, *event.RequestLog.KeyID, err)
+			s.logger.Warnf("[%s] Mapping not found for G:%d K:%d: %v", event.CorrelationID, groupID, keyID, err)
 			return
 		}
 
 		statusChanged := false
 		oldStatus := mapping.Status
 		if mapping.Status != models.StatusActive {
@@ -193,38 +190,33 @@ func (s *APIKeyService) handleKeyUsageEvent(event *models.RequestFinishedEvent)
 
 		now := time.Now()
 		mapping.LastUsedAt = &now
 
 		if err := s.keyRepo.UpdateMapping(ctx, mapping); err != nil {
-			s.logger.Errorf("[%s] Failed to update mapping for G:%d K:%d after successful request: %v", event.CorrelationID, *event.RequestLog.GroupID, *event.RequestLog.KeyID, err)
+			s.logger.Errorf("[%s] Failed to update mapping for G:%d K:%d: %v", event.CorrelationID, groupID, keyID, err)
 			return
 		}
 
 		if statusChanged {
-			go s.publishStatusChangeEvent(ctx, *event.RequestLog.GroupID, *event.RequestLog.KeyID, oldStatus, models.StatusActive, "key_recovered_after_use")
+			go s.publishStatusChangeEvent(ctx, groupID, keyID, oldStatus, models.StatusActive, "key_recovered_after_use")
 		}
 		return
 	}
 
 	if event.Error != nil {
-		s.judgeKeyErrors(
-			ctx,
-			event.CorrelationID,
-			*event.RequestLog.GroupID,
-			*event.RequestLog.KeyID,
-			event.Error,
-			event.IsPreciseRouting,
-		)
+		s.judgeKeyErrors(ctx, event.CorrelationID, groupID, keyID, event.Error, event.IsPreciseRouting)
 	}
 }
 
 func (s *APIKeyService) handleKeyStatusChangeEvent(event *models.KeyStatusChangedEvent) {
 	ctx := context.Background()
-	log := s.logger.WithFields(logrus.Fields{
+	s.logger.WithFields(logrus.Fields{
 		"group_id":   event.GroupID,
 		"key_id":     event.KeyID,
 		"new_status": event.NewStatus,
 		"reason":     event.ChangeReason,
-	})
-	log.Info("Received KeyStatusChangedEvent, will update polling caches.")
+	}).Info("Updating polling caches based on status change")
 	s.keyRepo.HandleCacheUpdateEvent(ctx, event.GroupID, event.KeyID, event.NewStatus)
-	log.Info("Polling caches updated based on health check event.")
 }
 
 func (s *APIKeyService) publishStatusChangeEvent(ctx context.Context, groupID, keyID uint, oldStatus, newStatus models.APIKeyStatus, reason string) {
@@ -238,7 +230,7 @@ func (s *APIKeyService) publishStatusChangeEvent(ctx context.Context, groupID, k
 	}
 	eventData, _ := json.Marshal(changeEvent)
 	if err := s.store.Publish(ctx, models.TopicKeyStatusChanged, eventData); err != nil {
-		s.logger.Errorf("Failed to publish key status changed event for group %d: %v", groupID, err)
+		s.logger.Errorf("Failed to publish status change event for group %d: %v", groupID, err)
 	}
 }
 
@@ -248,10 +240,12 @@ func (s *APIKeyService) ListAPIKeys(ctx context.Context, params *models.APIKeyQu
 		if err != nil {
 			return nil, err
 		}
 
 		totalPages := 0
 		if total > 0 && params.PageSize > 0 {
 			totalPages = int(math.Ceil(float64(total) / float64(params.PageSize)))
 		}
 
 		return &PaginatedAPIKeys{
 			Items: items,
 			Total: total,
@@ -260,34 +254,44 @@ func (s *APIKeyService) ListAPIKeys(ctx context.Context, params *models.APIKeyQu
 			TotalPages: totalPages,
 		}, nil
 	}
-	s.logger.Infof("Performing heavy in-memory search for group %d with keyword '%s'", params.KeyGroupID, params.Keyword)
-	var statusesToFilter []string
+	s.logger.Infof("Performing in-memory search for group %d with keyword '%s'", params.KeyGroupID, params.Keyword)
 
+	statusesToFilter := []string{"all"}
 	if params.Status != "" {
-		statusesToFilter = append(statusesToFilter, params.Status)
-	} else {
-		statusesToFilter = append(statusesToFilter, "all")
+		statusesToFilter = []string{params.Status}
 	}
 
 	allKeyIDs, err := s.keyRepo.FindKeyIDsByStatus(params.KeyGroupID, statusesToFilter)
 	if err != nil {
-		return nil, fmt.Errorf("failed to fetch all key IDs for search: %w", err)
+		return nil, fmt.Errorf("failed to fetch key IDs: %w", err)
 	}
 
 	if len(allKeyIDs) == 0 {
-		return &PaginatedAPIKeys{Items: []*models.APIKeyDetails{}, Total: 0, Page: 1, PageSize: params.PageSize, TotalPages: 0}, nil
+		return &PaginatedAPIKeys{
+			Items:      []*models.APIKeyDetails{},
+			Total:      0,
+			Page:       1,
+			PageSize:   params.PageSize,
+			TotalPages: 0,
+		}, nil
 	}
 
 	allKeys, err := s.keyRepo.GetKeysByIDs(allKeyIDs)
 	if err != nil {
-		return nil, fmt.Errorf("failed to fetch all key details for in-memory search: %w", err)
+		return nil, fmt.Errorf("failed to fetch keys: %w", err)
 	}
 
 	var allMappings []models.GroupAPIKeyMapping
-	err = s.db.WithContext(ctx).Where("key_group_id = ? AND api_key_id IN ?", params.KeyGroupID, allKeyIDs).Find(&allMappings).Error
-	if err != nil {
-		return nil, fmt.Errorf("failed to fetch mappings for in-memory search: %w", err)
+	if err := s.db.WithContext(ctx).Where("key_group_id = ? AND api_key_id IN ?", params.KeyGroupID, allKeyIDs).Find(&allMappings).Error; err != nil {
+		return nil, fmt.Errorf("failed to fetch mappings: %w", err)
 	}
-	mappingMap := make(map[uint]*models.GroupAPIKeyMapping)
+	mappingMap := make(map[uint]*models.GroupAPIKeyMapping, len(allMappings))
 	for i := range allMappings {
 		mappingMap[allMappings[i].APIKeyID] = &allMappings[i]
 	}
 
 	var filteredItems []*models.APIKeyDetails
 	for _, key := range allKeys {
 		if strings.Contains(key.APIKey, params.Keyword) {
@@ -307,12 +311,15 @@ func (s *APIKeyService) ListAPIKeys(ctx context.Context, params *models.APIKeyQu
 			}
 		}
 	}
 
 	sort.Slice(filteredItems, func(i, j int) bool {
 		return filteredItems[i].ID > filteredItems[j].ID
 	})
 
 	total := int64(len(filteredItems))
 	start := (params.Page - 1) * params.PageSize
 	end := start + params.PageSize
 
 	if start < 0 {
 		start = 0
 	}
@@ -328,9 +335,9 @@ func (s *APIKeyService) ListAPIKeys(ctx context.Context, params *models.APIKeyQu
 	if end > len(filteredItems) {
 		end = len(filteredItems)
 	}
-	paginatedItems := filteredItems[start:end]
 	return &PaginatedAPIKeys{
-		Items:    paginatedItems,
+		Items:    filteredItems[start:end],
 		Total:    total,
 		Page:     params.Page,
 		PageSize: params.PageSize,
@@ -344,15 +351,8 @@ func (s *APIKeyService) GetKeysByIds(ctx context.Context, ids []uint) ([]models.
 
 func (s *APIKeyService) UpdateAPIKey(ctx context.Context, key *models.APIKey) error {
 	go func() {
-		bgCtx := context.Background()
-		var oldKey models.APIKey
-		if err := s.db.WithContext(bgCtx).First(&oldKey, key.ID).Error; err != nil {
-			s.logger.Errorf("Failed to find old key state for ID %d before update: %v", key.ID, err)
-			return
-		}
 		if err := s.keyRepo.Update(key); err != nil {
-			s.logger.Errorf("Failed to asynchronously update key ID %d: %v", key.ID, err)
-			return
+			s.logger.Errorf("Failed to update key ID %d: %v", key.ID, err)
 		}
 	}()
 	return nil
@@ -361,22 +361,24 @@ func (s *APIKeyService) UpdateAPIKey(ctx context.Context, key *models.APIKey) er
 func (s *APIKeyService) HardDeleteAPIKeyByID(ctx context.Context, id uint) error {
 	groups, err := s.keyRepo.GetGroupsForKey(ctx, id)
 	if err != nil {
-		s.logger.Warnf("Could not get groups for key ID %d before hard deletion (key might be orphaned): %v", id, err)
+		s.logger.Warnf("Could not get groups for key %d before deletion: %v", id, err)
 	}
 
-	err = s.keyRepo.HardDeleteByID(id)
-	if err == nil {
-		for _, groupID := range groups {
-			event := models.KeyStatusChangedEvent{
-				KeyID:        id,
-				GroupID:      groupID,
-				ChangeReason: "key_hard_deleted",
-			}
-			eventData, _ := json.Marshal(event)
-			go s.store.Publish(context.Background(), models.TopicKeyStatusChanged, eventData)
-		}
+	if err := s.keyRepo.HardDeleteByID(id); err != nil {
+		return err
 	}
-	return err
+
+	for _, groupID := range groups {
+		event := models.KeyStatusChangedEvent{
+			KeyID:        id,
+			GroupID:      groupID,
+			ChangeReason: "key_hard_deleted",
+		}
+		eventData, _ := json.Marshal(event)
+		go s.store.Publish(context.Background(), models.TopicKeyStatusChanged, eventData)
+	}
+
+	return nil
 }
 
 func (s *APIKeyService) UpdateMappingStatus(ctx context.Context, groupID, keyID uint, newStatus models.APIKeyStatus) (*models.GroupAPIKeyMapping, error) {
@@ -388,47 +390,54 @@ func (s *APIKeyService) UpdateMappingStatus(ctx context.Context, groupID, keyID
 	if key.MasterStatus != models.MasterStatusActive && newStatus == models.StatusActive {
 		return nil, CustomErrors.ErrStateConflictMasterRevoked
 	}
 
 	mapping, err := s.keyRepo.GetMapping(groupID, keyID)
 	if err != nil {
 		return nil, err
 	}
 
 	oldStatus := mapping.Status
 	if oldStatus == newStatus {
 		return mapping, nil
 	}
 
 	mapping.Status = newStatus
 	if newStatus == models.StatusActive {
 		mapping.ConsecutiveErrorCount = 0
 		mapping.LastError = ""
 	}
 
 	if err := s.keyRepo.UpdateMapping(ctx, mapping); err != nil {
 		return nil, err
 	}
 
 	go s.publishStatusChangeEvent(ctx, groupID, keyID, oldStatus, newStatus, "manual_update")
 	return mapping, nil
 }
 
 func (s *APIKeyService) handleMasterKeyStatusChangeEvent(event *models.MasterKeyStatusChangedEvent) {
-	ctx := context.Background()
-	s.logger.Infof("Received MasterKeyStatusChangedEvent for Key ID %d: %s -> %s", event.KeyID, event.OldMasterStatus, event.NewMasterStatus)
 	if event.NewMasterStatus != models.MasterStatusRevoked {
 		return
 	}
 
+	ctx := context.Background()
+	s.logger.Infof("Key %d revoked, propagating to all groups", event.KeyID)
 
 	affectedGroupIDs, err := s.keyRepo.GetGroupsForKey(ctx, event.KeyID)
 	if err != nil {
-		s.logger.WithError(err).Errorf("Failed to get groups for key ID %d to propagate revoked status.", event.KeyID)
+		s.logger.WithError(err).Errorf("Failed to get groups for key %d", event.KeyID)
 		return
 	}
 
 	if len(affectedGroupIDs) == 0 {
-		s.logger.Infof("Key ID %d is revoked, but it's not associated with any group. No action needed.", event.KeyID)
+		s.logger.Infof("Key %d not associated with any group", event.KeyID)
 		return
 	}
-	s.logger.Infof("Propagating REVOKED astatus for Key ID %d to %d groups.", event.KeyID, len(affectedGroupIDs))
 	for _, groupID := range affectedGroupIDs {
-		_, err := s.UpdateMappingStatus(ctx, groupID, event.KeyID, models.StatusBanned)
-		if err != nil {
+		if _, err := s.UpdateMappingStatus(ctx, groupID, event.KeyID, models.StatusBanned); err != nil {
 			if !errors.Is(err, gorm.ErrRecordNotFound) {
-				s.logger.WithError(err).Errorf("Failed to update mapping status to BANNED for Key ID %d in Group ID %d.", event.KeyID, groupID)
+				s.logger.WithError(err).Errorf("Failed to ban key %d in group %d", event.KeyID, groupID)
 			}
 		}
 	}
@@ -436,59 +445,71 @@ func (s *APIKeyService) handleMasterKeyStatusChangeEvent(event *models.MasterKey
 
 func (s *APIKeyService) StartRestoreKeysTask(ctx context.Context, groupID uint, keyIDs []uint) (*task.Status, error) {
 	if len(keyIDs) == 0 {
-		return nil, CustomErrors.NewAPIError(CustomErrors.ErrBadRequest, "No key IDs provided for restoration.")
+		return nil, CustomErrors.NewAPIError(CustomErrors.ErrBadRequest, "No key IDs provided")
 	}
 
 	resourceID := fmt.Sprintf("group-%d", groupID)
 	taskStatus, err := s.taskService.StartTask(ctx, groupID, TaskTypeRestoreSpecificKeys, resourceID, len(keyIDs), time.Hour)
 	if err != nil {
 		return nil, err
 	}
 
 	go s.runRestoreKeysTask(ctx, taskStatus.ID, resourceID, groupID, keyIDs)
 	return taskStatus, nil
 }
 
-func (s *APIKeyService) runRestoreKeysTask(ctx context.Context, taskID string, resourceID string, groupID uint, keyIDs []uint) {
+func (s *APIKeyService) runRestoreKeysTask(ctx context.Context, taskID, resourceID string, groupID uint, keyIDs []uint) {
 	defer func() {
 		if r := recover(); r != nil {
-			s.logger.Errorf("Panic recovered in runRestoreKeysTask: %v", r)
-			s.taskService.EndTaskByID(ctx, taskID, resourceID, nil, fmt.Errorf("panic in restore keys task: %v", r))
+			s.logger.Errorf("Panic in restore task: %v", r)
+			s.taskService.EndTaskByID(ctx, taskID, resourceID, nil, fmt.Errorf("panic: %v", r))
 		}
 	}()
 
 	var mappingsToProcess []models.GroupAPIKeyMapping
-	err := s.db.WithContext(ctx).Preload("APIKey").
+	if err := s.db.WithContext(ctx).Preload("APIKey").
 		Where("key_group_id = ? AND api_key_id IN ?", groupID, keyIDs).
-		Find(&mappingsToProcess).Error
-	if err != nil {
+		Find(&mappingsToProcess).Error; err != nil {
 		s.taskService.EndTaskByID(ctx, taskID, resourceID, nil, err)
 		return
 	}
-	result := &BatchRestoreResult{
-		SkippedKeys: make([]SkippedKeyInfo, 0),
-	}
+	result := &BatchRestoreResult{SkippedKeys: make([]SkippedKeyInfo, 0)}
 	var successfulMappings []*models.GroupAPIKeyMapping
-	processedCount := 0
-	for _, mapping := range mappingsToProcess {
-		processedCount++
-		_ = s.taskService.UpdateProgressByID(ctx, taskID, processedCount)
+	for i, mapping := range mappingsToProcess {
+		_ = s.taskService.UpdateProgressByID(ctx, taskID, i+1)
 		if mapping.APIKey == nil {
 			result.SkippedCount++
-			result.SkippedKeys = append(result.SkippedKeys, SkippedKeyInfo{KeyID: mapping.APIKeyID, Reason: "Associated APIKey entity not found."})
+			result.SkippedKeys = append(result.SkippedKeys, SkippedKeyInfo{
+				KeyID:  mapping.APIKeyID,
+				Reason: "APIKey not found",
+			})
 			continue
 		}
 
 		if mapping.APIKey.MasterStatus != models.MasterStatusActive {
 			result.SkippedCount++
-			result.SkippedKeys = append(result.SkippedKeys, SkippedKeyInfo{KeyID: mapping.APIKeyID, Reason: fmt.Sprintf("Master status is '%s'.", mapping.APIKey.MasterStatus)})
+			result.SkippedKeys = append(result.SkippedKeys, SkippedKeyInfo{
+				KeyID:  mapping.APIKeyID,
+				Reason: fmt.Sprintf("Master status is %s", mapping.APIKey.MasterStatus),
+			})
 			continue
 		}
 
 		oldStatus := mapping.Status
 		if oldStatus != models.StatusActive {
 			mapping.Status = models.StatusActive
 			mapping.ConsecutiveErrorCount = 0
 			mapping.LastError = ""
 
 			if err := s.keyRepo.UpdateMappingWithoutCache(&mapping); err != nil {
-				s.logger.WithError(err).Warn("Failed to update mapping in DB during restore task.")
 				result.SkippedCount++
-				result.SkippedKeys = append(result.SkippedKeys, SkippedKeyInfo{KeyID: mapping.APIKeyID, Reason: "DB update failed."})
+				result.SkippedKeys = append(result.SkippedKeys, SkippedKeyInfo{
+					KeyID:  mapping.APIKeyID,
+					Reason: "DB update failed",
+				})
 			} else {
 				result.RestoredCount++
 				successfulMappings = append(successfulMappings, &mapping)
@@ -498,61 +519,72 @@ func (s *APIKeyService) runRestoreKeysTask(ctx context.Context, taskID string, r
 			result.RestoredCount++
 		}
 	}
 
 	if err := s.keyRepo.HandleCacheUpdateEventBatch(ctx, successfulMappings); err != nil {
-		s.logger.WithError(err).Error("Failed to perform batch cache update after restore task.")
+		s.logger.WithError(err).Error("Failed batch cache update after restore")
 	}
 
 	result.SkippedCount += (len(keyIDs) - len(mappingsToProcess))
 	s.taskService.EndTaskByID(ctx, taskID, resourceID, result, nil)
 }
 
 func (s *APIKeyService) StartRestoreAllBannedTask(ctx context.Context, groupID uint) (*task.Status, error) {
 	var bannedKeyIDs []uint
-	err := s.db.WithContext(ctx).Model(&models.GroupAPIKeyMapping{}).
+	if err := s.db.WithContext(ctx).Model(&models.GroupAPIKeyMapping{}).
 		Where("key_group_id = ? AND status = ?", groupID, models.StatusBanned).
-		Pluck("api_key_id", &bannedKeyIDs).Error
-	if err != nil {
+		Pluck("api_key_id", &bannedKeyIDs).Error; err != nil {
 		return nil, CustomErrors.ParseDBError(err)
 	}
 
 	if len(bannedKeyIDs) == 0 {
-		return nil, CustomErrors.NewAPIError(CustomErrors.ErrBadRequest, "No banned keys to restore in this group.")
+		return nil, CustomErrors.NewAPIError(CustomErrors.ErrBadRequest, "No banned keys to restore")
 	}
 
 	return s.StartRestoreKeysTask(ctx, groupID, bannedKeyIDs)
 }
 
 func (s *APIKeyService) handlePostImportValidation(event *models.ImportGroupCompletedEvent) {
 	ctx := context.Background()
 
 	group, ok := s.groupManager.GetGroupByID(event.GroupID)
 	if !ok {
-		s.logger.Errorf("Group with id %d not found during post-import validation, aborting.", event.GroupID)
+		s.logger.Errorf("Group %d not found for post-import validation", event.GroupID)
 		return
 	}
 
 	opConfig, err := s.groupManager.BuildOperationalConfig(group)
 	if err != nil {
-		s.logger.Errorf("Failed to build operational config for group %d, aborting validation: %v", event.GroupID, err)
+		s.logger.Errorf("Failed to build config for group %d: %v", event.GroupID, err)
 		return
 	}
 
 	endpoint, err := s.groupManager.BuildKeyCheckEndpoint(event.GroupID)
 	if err != nil {
-		s.logger.Errorf("Failed to build key check endpoint for group %d, aborting validation: %v", event.GroupID, err)
+		s.logger.Errorf("Failed to build endpoint for group %d: %v", event.GroupID, err)
 		return
 	}
-	globalSettings := s.SettingsManager.GetSettings()
-	concurrency := globalSettings.BaseKeyCheckConcurrency
+	concurrency := s.SettingsManager.GetSettings().BaseKeyCheckConcurrency
 	if opConfig.KeyCheckConcurrency != nil {
 		concurrency = *opConfig.KeyCheckConcurrency
 	}
 	if concurrency <= 0 {
 		concurrency = 10
 	}
-	timeout := time.Duration(globalSettings.KeyCheckTimeoutSeconds) * time.Second
+	timeout := time.Duration(s.SettingsManager.GetSettings().KeyCheckTimeoutSeconds) * time.Second
 
 	keysToValidate, err := s.keyRepo.GetKeysByIDs(event.KeyIDs)
 	if err != nil {
-		s.logger.Errorf("Failed to get key models for validation in group %d: %v", event.GroupID, err)
+		s.logger.Errorf("Failed to get keys for validation in group %d: %v", event.GroupID, err)
 		return
 	}
-	s.logger.Infof("Validating %d keys for group %d with concurrency %d against endpoint %s", len(keysToValidate), event.GroupID, concurrency, endpoint)
+	s.logger.Infof("Validating %d keys for group %d (concurrency: %d)", len(keysToValidate), event.GroupID, concurrency)
 
 	var wg sync.WaitGroup
 	jobs := make(chan models.APIKey, len(keysToValidate))
 
 	for i := 0; i < concurrency; i++ {
 		wg.Add(1)
 		go func() {
@@ -560,9 +592,8 @@ func (s *APIKeyService) handlePostImportValidation(event *models.ImportGroupComp
 			for key := range jobs {
 				validationErr := s.validationService.ValidateSingleKey(&key, timeout, endpoint)
 				if validationErr == nil {
-					s.logger.Infof("Key ID %d PASSED validation. Status -> ACTIVE.", key.ID)
 					if _, err := s.UpdateMappingStatus(ctx, event.GroupID, key.ID, models.StatusActive); err != nil {
-						s.logger.Errorf("Failed to update status to ACTIVE for Key ID %d in group %d: %v", key.ID, event.GroupID, err)
+						s.logger.Errorf("Failed to activate key %d in group %d: %v", key.ID, event.GroupID, err)
 					}
 				} else {
 					var apiErr *CustomErrors.APIError
@@ -574,35 +605,34 @@ func (s *APIKeyService) handlePostImportValidation(event *models.ImportGroupComp
 			}
 		}()
 	}
 
 	for _, key := range keysToValidate {
 		jobs <- key
 	}
 	close(jobs)
 	wg.Wait()
-	s.logger.Infof("Finished post-import validation for group %d.", event.GroupID)
+	s.logger.Infof("Finished post-import validation for group %d", event.GroupID)
 }
 
 func (s *APIKeyService) StartUpdateStatusByFilterTask(ctx context.Context, groupID uint, sourceStatuses []string, newStatus models.APIKeyStatus) (*task.Status, error) {
-	s.logger.Infof("Starting task to update status to '%s' for keys in group %d matching statuses: %v", newStatus, groupID, sourceStatuses)
 
 	keyIDs, err := s.keyRepo.FindKeyIDsByStatus(groupID, sourceStatuses)
 	if err != nil {
 		return nil, CustomErrors.ParseDBError(err)
 	}
 
 	if len(keyIDs) == 0 {
 		now := time.Now()
 		return &task.Status{
 			IsRunning:  false,
 			Processed:  0,
 			Total:      0,
-			Result: map[string]string{
-				"message": "没有找到任何符合当前过滤条件的Key可供操作。",
-			},
-			Error:      "",
+			Result:     map[string]string{"message": "没有找到符合条件的Key"},
 			StartedAt:  now,
 			FinishedAt: &now,
 		}, nil
 	}
 
 	resourceID := fmt.Sprintf("group-%d-status-update", groupID)
 	taskStatus, err := s.taskService.StartTask(ctx, groupID, TaskTypeUpdateStatusByFilter, resourceID, len(keyIDs), 30*time.Minute)
 	if err != nil {
@@ -616,10 +646,11 @@ func (s *APIKeyService) StartUpdateStatusByFilterTask(ctx context.Context, group
 func (s *APIKeyService) runUpdateStatusByFilterTask(ctx context.Context, taskID, resourceID string, groupID uint, keyIDs []uint, newStatus models.APIKeyStatus) {
 	defer func() {
 		if r := recover(); r != nil {
-			s.logger.Errorf("Panic recovered in runUpdateStatusByFilterTask: %v", r)
-			s.taskService.EndTaskByID(ctx, taskID, resourceID, nil, fmt.Errorf("panic in update status task: %v", r))
+			s.logger.Errorf("Panic in status update task: %v", r)
+			s.taskService.EndTaskByID(ctx, taskID, resourceID, nil, fmt.Errorf("panic: %v", r))
 		}
 	}()
 
 	type BatchUpdateResult struct {
 		UpdatedCount int `json:"updated_count"`
 		SkippedCount int `json:"skipped_count"`
@@ -629,33 +660,35 @@ func (s *APIKeyService) runUpdateStatusByFilterTask(ctx context.Context, taskID,
 
 	keys, err := s.keyRepo.GetKeysByIDs(keyIDs)
 	if err != nil {
-		s.logger.WithError(err).Error("Failed to fetch keys for bulk status update, aborting task.")
 		s.taskService.EndTaskByID(ctx, taskID, resourceID, nil, err)
 		return
 	}
-	masterStatusMap := make(map[uint]models.MasterAPIKeyStatus)
+	masterStatusMap := make(map[uint]models.MasterAPIKeyStatus, len(keys))
 	for _, key := range keys {
 		masterStatusMap[key.ID] = key.MasterStatus
 	}
 
 	var mappings []*models.GroupAPIKeyMapping
 	if err := s.db.WithContext(ctx).Where("key_group_id = ? AND api_key_id IN ?", groupID, keyIDs).Find(&mappings).Error; err != nil {
-		s.logger.WithError(err).Error("Failed to fetch mappings for bulk status update, aborting task.")
 		s.taskService.EndTaskByID(ctx, taskID, resourceID, nil, err)
 		return
 	}
-	processedCount := 0
-	for _, mapping := range mappings {
-		processedCount++
-		_ = s.taskService.UpdateProgressByID(ctx, taskID, processedCount)
+	for i, mapping := range mappings {
+		_ = s.taskService.UpdateProgressByID(ctx, taskID, i+1)
 		masterStatus, ok := masterStatusMap[mapping.APIKeyID]
 		if !ok {
 			result.SkippedCount++
 			continue
 		}
 
 		if masterStatus != models.MasterStatusActive && newStatus == models.StatusActive {
 			result.SkippedCount++
 			continue
 		}
 
 		oldStatus := mapping.Status
 		if oldStatus != newStatus {
 			mapping.Status = newStatus
@@ -663,6 +696,7 @@ func (s *APIKeyService) runUpdateStatusByFilterTask(ctx context.Context, taskID,
 			mapping.ConsecutiveErrorCount = 0
 			mapping.LastError = ""
 		}
 
 		if err := s.keyRepo.UpdateMappingWithoutCache(mapping); err != nil {
 			result.SkippedCount++
 		} else {
@@ -674,122 +708,112 @@ func (s *APIKeyService) runUpdateStatusByFilterTask(ctx context.Context, taskID,
 			result.UpdatedCount++
 		}
 	}
 
 	result.SkippedCount += (len(keyIDs) - len(mappings))
 
 	if err := s.keyRepo.HandleCacheUpdateEventBatch(ctx, successfulMappings); err != nil {
-		s.logger.WithError(err).Error("Failed to perform batch cache update after status update task.")
+		s.logger.WithError(err).Error("Failed batch cache update after status update")
 	}
-	s.logger.Infof("Finished bulk status update task '%s' for group %d. Updated: %d, Skipped: %d.", taskID, groupID, result.UpdatedCount, result.SkippedCount)
 	s.taskService.EndTaskByID(ctx, taskID, resourceID, result, nil)
 }
 
 func (s *APIKeyService) HandleRequestResult(group *models.KeyGroup, key *models.APIKey, success bool, apiErr *CustomErrors.APIError) {
 	ctx := context.Background()
 
 	if success {
 		if group.PollingStrategy == models.StrategyWeighted {
 			go s.keyRepo.UpdateKeyUsageTimestamp(ctx, group.ID, key.ID)
 		}
 		return
 	}
 
 	if apiErr == nil {
-		s.logger.Warnf("Request failed for KeyID %d in GroupID %d but no specific API error was provided. No action taken.", key.ID, group.ID)
 		return
 	}
 
 	errMsg := apiErr.Message
 	if CustomErrors.IsPermanentUpstreamError(errMsg) || CustomErrors.IsTemporaryUpstreamError(errMsg) {
-		s.logger.Warnf("Request for KeyID %d in GroupID %d failed with a key-specific error: %s. Issuing order to move to cooldown.", key.ID, group.ID, errMsg)
 		go s.keyRepo.SyncKeyStatusInPollingCaches(ctx, group.ID, key.ID, models.StatusCooldown)
-	} else {
-		s.logger.Infof("Request for KeyID %d in GroupID %d failed with a non-key-specific error. No punitive action taken. Error: %s", key.ID, group.ID, errMsg)
 	}
 }
 
-func sanitizeForLog(errMsg string) string {
-	jsonStartIndex := strings.Index(errMsg, "{")
-	var cleanMsg string
-	if jsonStartIndex != -1 {
-		cleanMsg = strings.TrimSpace(errMsg[:jsonStartIndex]) + " {...}"
-	} else {
-		cleanMsg = errMsg
-	}
-	const maxLen = 250
-	if len(cleanMsg) > maxLen {
-		return cleanMsg[:maxLen] + "..."
-	}
-	return cleanMsg
-}
+func (s *APIKeyService) judgeKeyErrors(ctx context.Context, correlationID string, groupID, keyID uint, apiErr *CustomErrors.APIError, isPreciseRouting bool) {
+	logger := s.logger.WithFields(logrus.Fields{
+		"group_id":       groupID,
+		"key_id":         keyID,
+		"correlation_id": correlationID,
+	})
 
-func (s *APIKeyService) judgeKeyErrors(
-	ctx context.Context,
-	correlationID string,
-	groupID, keyID uint,
-	apiErr *CustomErrors.APIError,
-	isPreciseRouting bool,
-) {
-	logger := s.logger.WithFields(logrus.Fields{"group_id": groupID, "key_id": keyID, "correlation_id": correlationID})
 	mapping, err := s.keyRepo.GetMapping(groupID, keyID)
 	if err != nil {
-		logger.WithError(err).Warn("Cannot apply consequences for error: mapping not found.")
+		logger.WithError(err).Warn("Mapping not found, cannot apply error consequences")
 		return
 	}
 
 	now := time.Now()
 	mapping.LastUsedAt = &now
 	errorMessage := apiErr.Message
 
 	if CustomErrors.IsPermanentUpstreamError(errorMessage) {
-		logger.Errorf("Permanent error detected. Banning mapping and revoking master key. Reason: %s", sanitizeForLog(errorMessage))
-		logger.WithField("full_error_details", errorMessage).Debug("Full details of the permanent error.")
+		logger.Errorf("Permanent error: %s", sanitizeForLog(errorMessage))
 		if mapping.Status != models.StatusBanned {
 			oldStatus := mapping.Status
 			mapping.Status = models.StatusBanned
 			mapping.LastError = errorMessage
 
 			if err := s.keyRepo.UpdateMapping(ctx, mapping); err != nil {
-				logger.WithError(err).Error("Failed to update mapping status to BANNED.")
+				logger.WithError(err).Error("Failed to ban mapping")
 			} else {
-				go s.publishStatusChangeEvent(ctx, groupID, keyID, oldStatus, mapping.Status, "permanent_error_banned")
+				go s.publishStatusChangeEvent(ctx, groupID, keyID, oldStatus, models.StatusBanned, "permanent_error")
 				go s.revokeMasterKey(ctx, keyID, "permanent_upstream_error")
 			}
 		}
 		return
 	}
 
 	if CustomErrors.IsTemporaryUpstreamError(errorMessage) {
 		mapping.LastError = errorMessage
 		mapping.ConsecutiveErrorCount++
-		var threshold int
+		threshold := s.SettingsManager.GetSettings().BlacklistThreshold
 		if isPreciseRouting {
-			group, ok := s.groupManager.GetGroupByID(groupID)
-			opConfig, err := s.groupManager.BuildOperationalConfig(group)
-			if !ok || err != nil {
-				logger.Warnf("Could not build operational config for group %d in Precise Routing mode. Falling back to global settings.", groupID)
-				threshold = s.SettingsManager.GetSettings().BlacklistThreshold
-			} else {
-				threshold = *opConfig.KeyBlacklistThreshold
+			if group, ok := s.groupManager.GetGroupByID(groupID); ok {
+				if opConfig, err := s.groupManager.BuildOperationalConfig(group); err == nil && opConfig.KeyBlacklistThreshold != nil {
+					threshold = *opConfig.KeyBlacklistThreshold
+				}
 			}
-		} else {
-			threshold = s.SettingsManager.GetSettings().BlacklistThreshold
 		}
-		logger.Warnf("Temporary error detected. Incrementing error count. New count: %d (Threshold: %d). Reason: %s", mapping.ConsecutiveErrorCount, threshold, sanitizeForLog(errorMessage))
-		logger.WithField("full_error_details", errorMessage).Debug("Full details of the temporary error.")
+		logger.Warnf("Temporary error (count: %d, threshold: %d): %s", mapping.ConsecutiveErrorCount, threshold, sanitizeForLog(errorMessage))
 
 		oldStatus := mapping.Status
 		newStatus := oldStatus
 
 		if mapping.ConsecutiveErrorCount >= threshold && oldStatus == models.StatusActive {
 			newStatus = models.StatusCooldown
-			logger.Errorf("Putting mapping into COOLDOWN due to reaching temporary error threshold (%d)", threshold)
+			logger.Errorf("Moving to COOLDOWN after reaching threshold %d", threshold)
 		}
 
 		if oldStatus != newStatus {
 			mapping.Status = newStatus
 		}
 
 		if err := s.keyRepo.UpdateMapping(ctx, mapping); err != nil {
-			logger.WithError(err).Error("Failed to update mapping after temporary error.")
+			logger.WithError(err).Error("Failed to update mapping after temporary error")
 			return
 		}
 
 		if oldStatus != newStatus {
 			go s.publishStatusChangeEvent(ctx, groupID, keyID, oldStatus, newStatus, "error_threshold_reached")
 		}
 		return
 	}
-	logger.Infof("Ignored truly ignorable upstream error. Only updating LastUsedAt. Reason: %s", sanitizeForLog(errorMessage))
-	logger.WithField("full_error_details", errorMessage).Debug("Full details of the ignorable error.")
+	logger.Infof("Ignorable error: %s", sanitizeForLog(errorMessage))
 	if err := s.keyRepo.UpdateMapping(ctx, mapping); err != nil {
-		logger.WithError(err).Error("Failed to update LastUsedAt for ignorable error.")
+		logger.WithError(err).Error("Failed to update LastUsedAt")
 	}
 }
 
@@ -797,25 +821,27 @@ func (s *APIKeyService) revokeMasterKey(ctx context.Context, keyID uint, reason
 	key, err := s.keyRepo.GetKeyByID(keyID)
 	if err != nil {
 		if errors.Is(err, gorm.ErrRecordNotFound) {
-			s.logger.Warnf("Attempted to revoke non-existent key ID %d.", keyID)
+			s.logger.Warnf("Attempted to revoke non-existent key %d", keyID)
 		} else {
-			s.logger.Errorf("Failed to get key by ID %d for master status revocation: %v", keyID, err)
+			s.logger.Errorf("Failed to get key %d for revocation: %v", keyID, err)
 		}
 		return
 	}
 
 	if key.MasterStatus == models.MasterStatusRevoked {
 		return
 	}
 
 	oldMasterStatus := key.MasterStatus
-	newMasterStatus := models.MasterStatusRevoked
-	if err := s.keyRepo.UpdateAPIKeyStatus(ctx, keyID, newMasterStatus); err != nil {
-		s.logger.Errorf("Failed to update master status to REVOKED for key ID %d: %v", keyID, err)
+	if err := s.keyRepo.UpdateAPIKeyStatus(ctx, keyID, models.MasterStatusRevoked); err != nil {
+		s.logger.Errorf("Failed to revoke key %d: %v", keyID, err)
 		return
 	}
 
 	masterKeyEvent := models.MasterKeyStatusChangedEvent{
 		KeyID:           keyID,
 		OldMasterStatus: oldMasterStatus,
-		NewMasterStatus: newMasterStatus,
+		NewMasterStatus: models.MasterStatusRevoked,
 		ChangeReason:    reason,
 		ChangedAt:       time.Now(),
 	}
@@ -826,3 +852,13 @@ func (s *APIKeyService) revokeMasterKey(ctx context.Context, keyID uint, reason
 func (s *APIKeyService) GetAPIKeyStringsForExport(ctx context.Context, groupID uint, statuses []string) ([]string, error) {
 	return s.keyRepo.GetKeyStringsByGroupAndStatus(groupID, statuses)
 }
+
+func sanitizeForLog(errMsg string) string {
+	if idx := strings.Index(errMsg, "{"); idx != -1 {
+		errMsg = strings.TrimSpace(errMsg[:idx]) + " {...}"
+	}
+	if len(errMsg) > 250 {
+		return errMsg[:250] + "..."
+	}
+	return errMsg
+}
@@ -1,5 +1,4 @@
-// Filename: internal/service/group_manager.go (Syncer upgrade)
+// Filename: internal/service/group_manager.go
 
 package service
 
 import (
@@ -29,10 +28,9 @@ type GroupManagerCacheData struct {
 	Groups          []*models.KeyGroup
 	GroupsByName    map[string]*models.KeyGroup
 	GroupsByID      map[uint]*models.KeyGroup
-	KeyCounts       map[uint]int64 // GroupID -> Total Key Count
-	KeyStatusCounts map[uint]map[models.APIKeyStatus]int64 // GroupID -> Status -> Count
+	KeyCounts       map[uint]int64
+	KeyStatusCounts map[uint]map[models.APIKeyStatus]int64
 }
 
 type GroupManager struct {
 	db      *gorm.DB
 	keyRepo repository.KeyRepository
@@ -41,7 +39,6 @@ type GroupManager struct {
	syncer *syncer.CacheSyncer[GroupManagerCacheData]
	logger *logrus.Entry
}

type UpdateOrderPayload struct {
	ID    uint `json:"id" binding:"required"`
	Order int  `json:"order"`
@@ -49,43 +46,19 @@ type UpdateOrderPayload struct {

func NewGroupManagerLoader(db *gorm.DB, logger *logrus.Logger) syncer.LoaderFunc[GroupManagerCacheData] {
	return func() (GroupManagerCacheData, error) {
-		logger.Debugf("[GML-LOG 1/5] ---> Entering NewGroupManagerLoader...")
		var groups []*models.KeyGroup
-		logger.Debugf("[GML-LOG 2/5] About to execute DB query with Preloads...")

		if err := db.Preload("AllowedUpstreams").
			Preload("AllowedModels").
			Preload("Settings").
			Preload("RequestConfig").
+			Preload("Mappings").
			Find(&groups).Error; err != nil {
-			logger.Errorf("[GML-LOG] CRITICAL: DB query for groups failed: %v", err)
-			return GroupManagerCacheData{}, fmt.Errorf("failed to load key groups for cache: %w", err)
+			return GroupManagerCacheData{}, fmt.Errorf("failed to load groups: %w", err)
		}
-		logger.Debugf("[GML-LOG 2.1/5] DB query for groups finished. Found %d group records.", len(groups))

-		var allMappings []*models.GroupAPIKeyMapping
-		if err := db.Find(&allMappings).Error; err != nil {
-			logger.Errorf("[GML-LOG] CRITICAL: DB query for mappings failed: %v", err)
-			return GroupManagerCacheData{}, fmt.Errorf("failed to load key mappings for cache: %w", err)
-		}
-		logger.Debugf("[GML-LOG 2.2/5] DB query for mappings finished. Found %d total mapping records.", len(allMappings))

-		mappingsByGroupID := make(map[uint][]*models.GroupAPIKeyMapping)
-		for i := range allMappings {
-			mapping := allMappings[i] // Avoid pointer issues with range
-			mappingsByGroupID[mapping.KeyGroupID] = append(mappingsByGroupID[mapping.KeyGroupID], mapping)
-		}

-		for _, group := range groups {
-			if mappings, ok := mappingsByGroupID[group.ID]; ok {
-				group.Mappings = mappings
-			}
-		}
-		logger.Debugf("[GML-LOG 3/5] Finished manually associating mappings to groups.")

		keyCounts := make(map[uint]int64, len(groups))
		keyStatusCounts := make(map[uint]map[models.APIKeyStatus]int64, len(groups))
+		groupsByName := make(map[string]*models.KeyGroup, len(groups))
+		groupsByID := make(map[uint]*models.KeyGroup, len(groups))
		for _, group := range groups {
			keyCounts[group.ID] = int64(len(group.Mappings))
			statusCounts := make(map[models.APIKeyStatus]int64)
@@ -93,20 +66,9 @@ func NewGroupManagerLoader(db *gorm.DB, logger *logrus.Logger) syncer.LoaderFunc
				statusCounts[mapping.Status]++
			}
			keyStatusCounts[group.ID] = statusCounts
+			groupsByName[group.Name] = group
+			groupsByID[group.ID] = group
		}
-		groupsByName := make(map[string]*models.KeyGroup, len(groups))
-		groupsByID := make(map[uint]*models.KeyGroup, len(groups))

-		logger.Debugf("[GML-LOG 4/5] Starting to process group records into maps...")
-		for i, group := range groups {
-			if group == nil {
-				logger.Debugf("[GML] CRITICAL: Found a 'nil' group pointer at index %d! This is the most likely cause of the panic.", i)
-			} else {
-				groupsByName[group.Name] = group
-				groupsByID[group.ID] = group
-			}
-		}
-		logger.Debugf("[GML-LOG 5/5] Finished processing records. Building final cache data...")
		return GroupManagerCacheData{
			Groups:       groups,
			GroupsByName: groupsByName,
@@ -116,7 +78,6 @@ func NewGroupManagerLoader(db *gorm.DB, logger *logrus.Logger) syncer.LoaderFunc
		}, nil
	}
}
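The loader now lets GORM's Preload pull each group's key mappings instead of loading every mapping separately and stitching them together by KeyGroupID. A minimal, generic sketch of that pattern, with hypothetical Group/Mapping models rather than the project's own types:

package example

import "gorm.io/gorm"

type Mapping struct {
	ID      uint
	GroupID uint
}

type Group struct {
	ID       uint
	Name     string
	Mappings []Mapping `gorm:"foreignKey:GroupID"` // has-many, filled by Preload
}

func loadGroups(db *gorm.DB) ([]Group, error) {
	var groups []Group
	// Preload("Mappings") issues a second SELECT keyed by the parent IDs and
	// attaches the children, so no manual map[GroupID][]Mapping join is needed.
	if err := db.Preload("Mappings").Find(&groups).Error; err != nil {
		return nil, err
	}
	return groups, nil
}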

func NewGroupManager(
	db *gorm.DB,
	keyRepo repository.KeyRepository,
@@ -134,138 +95,67 @@ func NewGroupManager(
		logger:  logger.WithField("component", "GroupManager"),
	}
}

func (gm *GroupManager) GetAllGroups() []*models.KeyGroup {
-	cache := gm.syncer.Get()
-	if len(cache.Groups) == 0 {
-		return []*models.KeyGroup{}
-	}
-	groupsToOrder := cache.Groups
-	sort.Slice(groupsToOrder, func(i, j int) bool {
-		if groupsToOrder[i].Order != groupsToOrder[j].Order {
-			return groupsToOrder[i].Order < groupsToOrder[j].Order
+	groups := gm.syncer.Get().Groups
+	sort.Slice(groups, func(i, j int) bool {
+		if groups[i].Order != groups[j].Order {
+			return groups[i].Order < groups[j].Order
		}
-		return groupsToOrder[i].ID < groupsToOrder[j].ID
+		return groups[i].ID < groups[j].ID
	})
-	return groupsToOrder
+	return groups
}

func (gm *GroupManager) GetKeyCount(groupID uint) int64 {
-	cache := gm.syncer.Get()
-	if len(cache.KeyCounts) == 0 {
-		return 0
-	}
-	count := cache.KeyCounts[groupID]
-	return count
+	return gm.syncer.Get().KeyCounts[groupID]
}
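The emptiness guard dropped above was redundant: indexing a Go map with a missing key yields the value type's zero value, so the one-liner behaves identically. A standalone illustration (not project code):

package main

import "fmt"

func main() {
	counts := map[uint]int64{1: 42}
	fmt.Println(counts[1]) // 42
	fmt.Println(counts[7]) // 0, a missing key reads as the zero value, no ok-check needed
}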

func (gm *GroupManager) GetKeyStatusCount(groupID uint) map[models.APIKeyStatus]int64 {
-	cache := gm.syncer.Get()
-	if len(cache.KeyStatusCounts) == 0 {
-		return make(map[models.APIKeyStatus]int64)
-	}
-	if counts, ok := cache.KeyStatusCounts[groupID]; ok {
+	if counts, ok := gm.syncer.Get().KeyStatusCounts[groupID]; ok {
		return counts
	}
	return make(map[models.APIKeyStatus]int64)
}

func (gm *GroupManager) GetGroupByName(name string) (*models.KeyGroup, bool) {
-	cache := gm.syncer.Get()
-	if len(cache.GroupsByName) == 0 {
-		return nil, false
-	}
-	group, ok := cache.GroupsByName[name]
+	group, ok := gm.syncer.Get().GroupsByName[name]
	return group, ok
}

func (gm *GroupManager) GetGroupByID(id uint) (*models.KeyGroup, bool) {
-	cache := gm.syncer.Get()
-	if len(cache.GroupsByID) == 0 {
-		return nil, false
-	}
-	group, ok := cache.GroupsByID[id]
+	group, ok := gm.syncer.Get().GroupsByID[id]
	return group, ok
}

func (gm *GroupManager) Stop() {
	gm.syncer.Stop()
}

func (gm *GroupManager) Invalidate() error {
	return gm.syncer.Invalidate()
}

-// --- Write Operations ---

-// CreateKeyGroup creates a new key group, including its operational settings, and invalidates the cache.
func (gm *GroupManager) CreateKeyGroup(group *models.KeyGroup, settings *models.KeyGroupSettings) error {
	if !utils.IsValidGroupName(group.Name) {
		return errors.New("invalid group name: must contain only lowercase letters, numbers, and hyphens")
	}
	err := gm.db.Transaction(func(tx *gorm.DB) error {
-		// 1. Create the group itself to get an ID
		if err := tx.Create(group).Error; err != nil {
			return err
		}

-		// 2. If settings are provided, create the associated GroupSettings record
		if settings != nil {
-			// Only marshal non-nil fields to keep the JSON clean
-			settingsToMarshal := make(map[string]interface{})
-			if settings.EnableKeyCheck != nil {
-				settingsToMarshal["enable_key_check"] = settings.EnableKeyCheck
+			settingsJSON, err := json.Marshal(settings)
+			if err != nil {
+				return fmt.Errorf("failed to marshal settings: %w", err)
			}
-			if settings.KeyCheckIntervalMinutes != nil {
-				settingsToMarshal["key_check_interval_minutes"] = settings.KeyCheckIntervalMinutes
+			groupSettings := models.GroupSettings{
+				GroupID:      group.ID,
+				SettingsJSON: datatypes.JSON(settingsJSON),
			}
-			if settings.KeyBlacklistThreshold != nil {
-				settingsToMarshal["key_blacklist_threshold"] = settings.KeyBlacklistThreshold
-			}
-			if settings.KeyCooldownMinutes != nil {
-				settingsToMarshal["key_cooldown_minutes"] = settings.KeyCooldownMinutes
-			}
-			if settings.KeyCheckConcurrency != nil {
-				settingsToMarshal["key_check_concurrency"] = settings.KeyCheckConcurrency
-			}
-			if settings.KeyCheckEndpoint != nil {
-				settingsToMarshal["key_check_endpoint"] = settings.KeyCheckEndpoint
-			}
-			if settings.KeyCheckModel != nil {
-				settingsToMarshal["key_check_model"] = settings.KeyCheckModel
-			}
-			if settings.MaxRetries != nil {
-				settingsToMarshal["max_retries"] = settings.MaxRetries
-			}
-			if settings.EnableSmartGateway != nil {
-				settingsToMarshal["enable_smart_gateway"] = settings.EnableSmartGateway
-			}
-			if len(settingsToMarshal) > 0 {
-				settingsJSON, err := json.Marshal(settingsToMarshal)
-				if err != nil {
-					return fmt.Errorf("failed to marshal group settings: %w", err)
-				}
-				groupSettings := models.GroupSettings{
-					GroupID:      group.ID,
-					SettingsJSON: datatypes.JSON(settingsJSON),
-				}
-				if err := tx.Create(&groupSettings).Error; err != nil {
-					return fmt.Errorf("failed to save group settings: %w", err)
-				}
-			}
+			if err := tx.Create(&groupSettings).Error; err != nil {
+				return err
			}
		}
		return nil
	})
-	if err != nil {
-		return err
+	if err == nil {
+		go gm.Invalidate()
	}
-	go gm.Invalidate()
-	return nil
+	return err
}
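The per-field map that was removed above was a hand-rolled way of dropping unset values; marshaling the settings struct directly only behaves the same if its pointer fields carry `omitempty` JSON tags, which this diff does not show, so treat that as an assumption. A standalone sketch of the behavior:

package main

import (
	"encoding/json"
	"fmt"
)

// settings is a hypothetical stand-in for models.KeyGroupSettings:
// pointer fields tagged with omitempty.
type settings struct {
	EnableKeyCheck *bool `json:"enable_key_check,omitempty"`
	MaxRetries     *int  `json:"max_retries,omitempty"`
}

func main() {
	enabled := true
	b, _ := json.Marshal(settings{EnableKeyCheck: &enabled})
	fmt.Println(string(b)) // {"enable_key_check":true} — the nil MaxRetries field is omitted
}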

-// UpdateKeyGroup updates an existing key group, its settings, and associations, then invalidates the cache.
func (gm *GroupManager) UpdateKeyGroup(group *models.KeyGroup, newSettings *models.KeyGroupSettings, upstreamURLs []string, modelNames []string) error {
	if !utils.IsValidGroupName(group.Name) {
		return fmt.Errorf("invalid group name: must contain only lowercase letters, numbers, and hyphens")
@@ -273,7 +163,6 @@ func (gm *GroupManager) UpdateKeyGroup(group *models.KeyGroup, newSettings *mode
	uniqueUpstreamURLs := uniqueStrings(upstreamURLs)
	uniqueModelNames := uniqueStrings(modelNames)
	err := gm.db.Transaction(func(tx *gorm.DB) error {
-		// --- 1. Update AllowedUpstreams (M:N relationship) ---
		var upstreams []*models.UpstreamEndpoint
		if len(uniqueUpstreamURLs) > 0 {
			if err := tx.Where("url IN ?", uniqueUpstreamURLs).Find(&upstreams).Error; err != nil {
@@ -283,7 +172,6 @@ func (gm *GroupManager) UpdateKeyGroup(group *models.KeyGroup, newSettings *mode
		if err := tx.Model(group).Association("AllowedUpstreams").Replace(upstreams); err != nil {
			return err
		}

		if err := tx.Model(group).Association("AllowedModels").Clear(); err != nil {
			return err
		}
@@ -296,11 +184,9 @@ func (gm *GroupManager) UpdateKeyGroup(group *models.KeyGroup, newSettings *mode
				return err
			}
		}

		if err := tx.Model(group).Updates(group).Error; err != nil {
			return err
		}

		var existingSettings models.GroupSettings
		if err := tx.Where("group_id = ?", group.ID).First(&existingSettings).Error; err != nil && err != gorm.ErrRecordNotFound {
			return err
@@ -308,15 +194,15 @@ func (gm *GroupManager) UpdateKeyGroup(group *models.KeyGroup, newSettings *mode
		var currentSettingsData models.KeyGroupSettings
		if len(existingSettings.SettingsJSON) > 0 {
			if err := json.Unmarshal(existingSettings.SettingsJSON, &currentSettingsData); err != nil {
-				return fmt.Errorf("failed to unmarshal existing group settings: %w", err)
+				return fmt.Errorf("failed to unmarshal existing settings: %w", err)
			}
		}
		if err := reflectutil.MergeNilFields(&currentSettingsData, newSettings); err != nil {
-			return fmt.Errorf("failed to merge group settings: %w", err)
+			return fmt.Errorf("failed to merge settings: %w", err)
		}
		updatedJSON, err := json.Marshal(currentSettingsData)
		if err != nil {
-			return fmt.Errorf("failed to marshal updated group settings: %w", err)
+			return fmt.Errorf("failed to marshal updated settings: %w", err)
		}
		existingSettings.GroupID = group.ID
		existingSettings.SettingsJSON = datatypes.JSON(updatedJSON)
@@ -327,55 +213,25 @@ func (gm *GroupManager) UpdateKeyGroup(group *models.KeyGroup, newSettings *mode
	}
	return err
}

-// DeleteKeyGroup deletes a key group and subsequently cleans up any keys that have become orphans.
func (gm *GroupManager) DeleteKeyGroup(id uint) error {
	err := gm.db.Transaction(func(tx *gorm.DB) error {
-		gm.logger.Infof("Starting transaction to delete KeyGroup ID: %d", id)
-		// Step 1: First, retrieve the group object we are about to delete.
		var group models.KeyGroup
		if err := tx.First(&group, id).Error; err != nil {
			if err == gorm.ErrRecordNotFound {
-				gm.logger.Warnf("Attempted to delete a non-existent KeyGroup with ID: %d", id)
-				return nil // Don't treat as an error, the group is already gone.
+				return nil
			}
-			gm.logger.WithError(err).Errorf("Failed to find KeyGroup with ID: %d for deletion", id)
			return err
		}
-		// Step 2: Clear all many-to-many and one-to-many associations using GORM's safe methods.
-		if err := tx.Model(&group).Association("AllowedUpstreams").Clear(); err != nil {
-			gm.logger.WithError(err).Errorf("Failed to clear 'AllowedUpstreams' association for KeyGroup ID: %d", id)
+		if err := tx.Select("AllowedUpstreams", "AllowedModels", "Mappings", "Settings").Delete(&group).Error; err != nil {
			return err
		}
-		if err := tx.Model(&group).Association("AllowedModels").Clear(); err != nil {
-			gm.logger.WithError(err).Errorf("Failed to clear 'AllowedModels' association for KeyGroup ID: %d", id)
-			return err
-		}
-		if err := tx.Model(&group).Association("Mappings").Clear(); err != nil {
-			gm.logger.WithError(err).Errorf("Failed to clear 'Mappings' (API Key associations) for KeyGroup ID: %d", id)
-			return err
-		}
-		// Also clear settings if they exist to maintain data integrity
-		if err := tx.Model(&group).Association("Settings").Delete(group.Settings); err != nil {
-			gm.logger.WithError(err).Errorf("Failed to delete 'Settings' association for KeyGroup ID: %d", id)
-			return err
-		}
-		// Step 3: Delete the KeyGroup itself.
-		if err := tx.Delete(&group).Error; err != nil {
-			gm.logger.WithError(err).Errorf("Failed to delete KeyGroup ID: %d", id)
-			return err
-		}
-		gm.logger.Infof("KeyGroup ID %d associations cleared and entity deleted. Triggering orphan key cleanup.", id)
-		// Step 4: Trigger the orphan key cleanup (this logic remains the same and is correct).
		deletedCount, err := gm.keyRepo.DeleteOrphanKeysTx(tx)
		if err != nil {
-			gm.logger.WithError(err).Error("Failed to clean up orphan keys after deleting group.")
			return err
		}
		if deletedCount > 0 {
-			gm.logger.Infof("Successfully cleaned up %d orphan keys.", deletedCount)
+			gm.logger.Infof("Cleaned up %d orphan keys after deleting group %d", deletedCount, id)
		}
-		gm.logger.Infof("Transaction for deleting KeyGroup ID: %d completed successfully.", id)
		return nil
	})
	if err == nil {
@@ -383,7 +239,6 @@ func (gm *GroupManager) DeleteKeyGroup(id uint) error {
	}
	return err
}
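The rewrite above collapses several manual Association(...).Clear() calls into a single Select(...).Delete(...). In GORM, selecting association names on a Delete removes those related rows (or join-table entries) together with the parent record, provided the primary key is set. A generic sketch with hypothetical models, not this project's:

package example

import "gorm.io/gorm"

type Tag struct{ ID, GroupID uint }

type Group struct {
	ID   uint
	Tags []Tag // has-many association
}

// deleteGroupWithAssociations removes the group and its Tags in one call,
// instead of clearing each association and then deleting the parent.
func deleteGroupWithAssociations(db *gorm.DB, id uint) error {
	return db.Select("Tags").Delete(&Group{ID: id}).Error
}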

func (gm *GroupManager) CloneKeyGroup(id uint) (*models.KeyGroup, error) {
	var originalGroup models.KeyGroup
	if err := gm.db.
@@ -392,7 +247,7 @@ func (gm *GroupManager) CloneKeyGroup(id uint) (*models.KeyGroup, error) {
		Preload("AllowedUpstreams").
		Preload("AllowedModels").
		First(&originalGroup, id).Error; err != nil {
-		return nil, fmt.Errorf("failed to find original group with id %d: %w", id, err)
+		return nil, fmt.Errorf("failed to find original group %d: %w", id, err)
	}
	newGroup := originalGroup
	timestamp := time.Now().Unix()
@@ -401,31 +256,25 @@ func (gm *GroupManager) CloneKeyGroup(id uint) (*models.KeyGroup, error) {
	newGroup.DisplayName = fmt.Sprintf("%s-clone-%d", originalGroup.DisplayName, timestamp)
	newGroup.CreatedAt = time.Time{}
	newGroup.UpdatedAt = time.Time{}

	newGroup.RequestConfigID = nil
	newGroup.RequestConfig = nil
	newGroup.Mappings = nil
	newGroup.AllowedUpstreams = nil
	newGroup.AllowedModels = nil
	err := gm.db.Transaction(func(tx *gorm.DB) error {

		if err := tx.Create(&newGroup).Error; err != nil {
			return err
		}

		if originalGroup.RequestConfig != nil {
			newRequestConfig := *originalGroup.RequestConfig
-			newRequestConfig.ID = 0 // Mark as new record
+			newRequestConfig.ID = 0

			if err := tx.Create(&newRequestConfig).Error; err != nil {
				return fmt.Errorf("failed to clone request config: %w", err)
			}

			if err := tx.Model(&newGroup).Update("request_config_id", newRequestConfig.ID).Error; err != nil {
-				return fmt.Errorf("failed to link new group to cloned request config: %w", err)
+				return fmt.Errorf("failed to link cloned request config: %w", err)
			}
		}

		var originalSettings models.GroupSettings
		err := tx.Where("group_id = ?", originalGroup.ID).First(&originalSettings).Error
		if err == nil && len(originalSettings.SettingsJSON) > 0 {
@@ -434,12 +283,11 @@ func (gm *GroupManager) CloneKeyGroup(id uint) (*models.KeyGroup, error) {
				SettingsJSON: originalSettings.SettingsJSON,
			}
			if err := tx.Create(&newSettings).Error; err != nil {
-				return fmt.Errorf("failed to clone group settings: %w", err)
+				return fmt.Errorf("failed to clone settings: %w", err)
			}
		} else if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) {
-			return fmt.Errorf("failed to query original group settings: %w", err)
+			return fmt.Errorf("failed to query original settings: %w", err)
		}

		if len(originalGroup.Mappings) > 0 {
			newMappings := make([]models.GroupAPIKeyMapping, len(originalGroup.Mappings))
			for i, oldMapping := range originalGroup.Mappings {
@@ -454,7 +302,7 @@ func (gm *GroupManager) CloneKeyGroup(id uint) (*models.KeyGroup, error) {
				}
			}
			if err := tx.Create(&newMappings).Error; err != nil {
-				return fmt.Errorf("failed to clone key group mappings: %w", err)
+				return fmt.Errorf("failed to clone mappings: %w", err)
			}
		}
		if len(originalGroup.AllowedUpstreams) > 0 {
@@ -469,13 +317,10 @@ func (gm *GroupManager) CloneKeyGroup(id uint) (*models.KeyGroup, error) {
		}
		return nil
	})

	if err != nil {
		return nil, err
	}

	go gm.Invalidate()

	var finalClonedGroup models.KeyGroup
	if err := gm.db.
		Preload("RequestConfig").
@@ -487,10 +332,9 @@ func (gm *GroupManager) CloneKeyGroup(id uint) (*models.KeyGroup, error) {
	}
	return &finalClonedGroup, nil
}

func (gm *GroupManager) BuildOperationalConfig(group *models.KeyGroup) (*models.KeyGroupSettings, error) {
	globalSettings := gm.settingsManager.GetSettings()
-	s := "gemini-1.5-flash" // Per user feedback for default model
+	defaultModel := "gemini-1.5-flash"
	opConfig := &models.KeyGroupSettings{
		EnableKeyCheck:      &globalSettings.EnableBaseKeyCheck,
		KeyCheckConcurrency: &globalSettings.BaseKeyCheckConcurrency,
@@ -498,52 +342,43 @@ func (gm *GroupManager) BuildOperationalConfig(group *models.KeyGroup) (*models.
		KeyCheckEndpoint:      &globalSettings.DefaultUpstreamURL,
		KeyBlacklistThreshold: &globalSettings.BlacklistThreshold,
		KeyCooldownMinutes:    &globalSettings.KeyCooldownMinutes,
-		KeyCheckModel:         &s,
+		KeyCheckModel:         &defaultModel,
		MaxRetries:            &globalSettings.MaxRetries,
		EnableSmartGateway:    &globalSettings.EnableSmartGateway,
	}

	if group == nil {
		return opConfig, nil
	}

	var groupSettingsRecord models.GroupSettings
	err := gm.db.Where("group_id = ?", group.ID).First(&groupSettingsRecord).Error
	if err != nil {
		if errors.Is(err, gorm.ErrRecordNotFound) {
			return opConfig, nil
		}

-		gm.logger.WithError(err).Errorf("Failed to query group settings for group ID %d", group.ID)
		return nil, err
	}

	if len(groupSettingsRecord.SettingsJSON) == 0 {
		return opConfig, nil
	}

	var groupSpecificSettings models.KeyGroupSettings
	if err := json.Unmarshal(groupSettingsRecord.SettingsJSON, &groupSpecificSettings); err != nil {
-		gm.logger.WithError(err).WithField("group_id", group.ID).Warn("Failed to unmarshal group settings JSON.")
-		return opConfig, err
+		gm.logger.WithError(err).WithField("group_id", group.ID).Warn("Failed to unmarshal group settings")
+		return opConfig, nil
	}
-	if err := reflectutil.MergeNilFields(opConfig, &groupSpecificSettings); err != nil {
-		gm.logger.WithError(err).WithField("group_id", group.ID).Error("Failed to merge group-specific settings over defaults.")
+	if err := reflectutil.MergeNilFields(opConfig, &groupSpecificSettings); err != nil {
+		gm.logger.WithError(err).WithField("group_id", group.ID).Error("Failed to merge group settings")
		return opConfig, nil
	}

	return opConfig, nil
}

func (gm *GroupManager) BuildKeyCheckEndpoint(groupID uint) (string, error) {
	group, ok := gm.GetGroupByID(groupID)
	if !ok {
-		return "", fmt.Errorf("group with id %d not found", groupID)
+		return "", fmt.Errorf("group %d not found", groupID)
	}
	opConfig, err := gm.BuildOperationalConfig(group)
	if err != nil {
-		return "", fmt.Errorf("failed to build operational config for group %d: %w", groupID, err)
+		return "", err
	}
	globalSettings := gm.settingsManager.GetSettings()
	baseURL := globalSettings.DefaultUpstreamURL
@@ -551,7 +386,7 @@ func (gm *GroupManager) BuildKeyCheckEndpoint(groupID uint) (string, error) {
		baseURL = *opConfig.KeyCheckEndpoint
	}
	if baseURL == "" {
-		return "", fmt.Errorf("no key check endpoint or default upstream URL is configured for group %d", groupID)
+		return "", fmt.Errorf("no endpoint configured for group %d", groupID)
	}
	modelName := globalSettings.BaseKeyCheckModel
	if opConfig.KeyCheckModel != nil && *opConfig.KeyCheckModel != "" {
@@ -559,38 +394,31 @@ func (gm *GroupManager) BuildKeyCheckEndpoint(groupID uint) (string, error) {
	}
	parsedURL, err := url.Parse(baseURL)
	if err != nil {
-		return "", fmt.Errorf("failed to parse base URL '%s': %w", baseURL, err)
+		return "", fmt.Errorf("invalid URL '%s': %w", baseURL, err)
	}
-	cleanedPath := parsedURL.Path
-	cleanedPath = strings.TrimSuffix(cleanedPath, "/")
-	cleanedPath = strings.TrimSuffix(cleanedPath, "/v1beta")
-	parsedURL.Path = path.Join(cleanedPath, "v1beta", "models", modelName)
-	finalEndpoint := parsedURL.String()
-	return finalEndpoint, nil
+	cleanedPath := strings.TrimSuffix(strings.TrimSuffix(parsedURL.Path, "/"), "/v1beta")
+	parsedURL.Path = path.Join(cleanedPath, "v1beta/models", modelName)
+	return parsedURL.String(), nil
}
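As a sanity check of the path handling above, a standalone sketch that mirrors the trim-then-join logic; the base URLs are illustrative, not taken from any configuration:

package main

import (
	"fmt"
	"net/url"
	"path"
	"strings"
)

func checkEndpoint(base, model string) string {
	u, _ := url.Parse(base)
	// Drop a trailing "/" and a trailing "/v1beta" so the segment is never doubled.
	p := strings.TrimSuffix(strings.TrimSuffix(u.Path, "/"), "/v1beta")
	u.Path = path.Join(p, "v1beta/models", model)
	return u.String()
}

func main() {
	fmt.Println(checkEndpoint("https://example.upstream.dev", "gemini-1.5-flash"))
	fmt.Println(checkEndpoint("https://example.upstream.dev/v1beta/", "gemini-1.5-flash"))
	// Both print: https://example.upstream.dev/v1beta/models/gemini-1.5-flash
}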

func (gm *GroupManager) UpdateOrder(payload []UpdateOrderPayload) error {
	ordersMap := make(map[uint]int, len(payload))
	for _, item := range payload {
		ordersMap[item.ID] = item.Order
	}
	if err := gm.groupRepo.UpdateOrderInTransaction(ordersMap); err != nil {
-		gm.logger.WithError(err).Error("Failed to update group order in transaction")
-		return fmt.Errorf("database transaction failed: %w", err)
+		return fmt.Errorf("failed to update order: %w", err)
	}
-	gm.logger.Info("Group order updated successfully, invalidating cache...")
	go gm.Invalidate()
	return nil
}

func uniqueStrings(slice []string) []string {
-	keys := make(map[string]struct{})
-	list := []string{}
-	for _, entry := range slice {
-		if _, value := keys[entry]; !value {
-			keys[entry] = struct{}{}
-			list = append(list, entry)
+	seen := make(map[string]struct{}, len(slice))
+	result := make([]string, 0, len(slice))
+	for _, s := range slice {
+		if _, exists := seen[s]; !exists {
+			seen[s] = struct{}{}
+			result = append(result, s)
		}
	}
-	return list
+	return result
}
@@ -704,15 +704,17 @@ func (p *memoryPipeliner) LRem(key string, count int64, value any) {
		newList := make([]string, 0, len(list))
		removed := int64(0)
		for _, v := range list {
-			if count != 0 && v == capturedValue && (count < 0 || removed < count) {
+			shouldRemove := v == capturedValue && (count == 0 || removed < count)
+			if shouldRemove {
				removed++
-				continue
+			} else {
+				newList = append(newList, v)
			}
-			newList = append(newList, v)
		}
		item.value = newList
	})
}
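For reference, Redis' LREM treats count == 0 as "remove every occurrence", a positive count as "remove at most count matches scanning head to tail", and a negative count as "remove at most |count| matches scanning tail to head". A standalone sketch of the head-to-tail cases this in-memory pipeliner covers (not the store's own code):

package main

import "fmt"

// lrem mirrors the memory-store predicate: count == 0 removes all matches,
// count > 0 removes at most count matches scanning from the head.
func lrem(list []string, count int64, value string) []string {
	out := make([]string, 0, len(list))
	var removed int64
	for _, v := range list {
		if v == value && (count == 0 || removed < count) {
			removed++
			continue
		}
		out = append(out, v)
	}
	return out
}

func main() {
	l := []string{"a", "b", "a", "a"}
	fmt.Println(lrem(l, 0, "a")) // [b]
	fmt.Println(lrem(l, 1, "a")) // [b a a]
}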

func (p *memoryPipeliner) HSet(key string, values map[string]any) {
	capturedKey := key
	capturedValues := make(map[string]any, len(values))
@@ -762,17 +764,31 @@ func (p *memoryPipeliner) ZAdd(key string, members map[string]float64) {
	p.ops = append(p.ops, func() {
		item, ok := p.store.items[capturedKey]
		if !ok || item.isExpired() {
-			item = &memoryStoreItem{value: make(map[string]float64)}
+			item = &memoryStoreItem{value: make([]zsetMember, 0)}
			p.store.items[capturedKey] = item
		}
-		zset, ok := item.value.(map[string]float64)
+		zset, ok := item.value.([]zsetMember)
		if !ok {
-			zset = make(map[string]float64)
-			item.value = zset
+			zset = make([]zsetMember, 0)
		}
-		for member, score := range capturedMembers {
-			zset[member] = score
+		membersMap := make(map[string]float64, len(zset))
+		for _, z := range zset {
+			membersMap[z.Value] = z.Score
		}
+		for memberVal, score := range capturedMembers {
+			membersMap[memberVal] = score
+		}
+		newZSet := make([]zsetMember, 0, len(membersMap))
+		for val, score := range membersMap {
+			newZSet = append(newZSet, zsetMember{Value: val, Score: score})
+		}
+		sort.Slice(newZSet, func(i, j int) bool {
+			if newZSet[i].Score == newZSet[j].Score {
+				return newZSet[i].Value < newZSet[j].Value
+			}
+			return newZSet[i].Score < newZSet[j].Score
+		})
+		item.value = newZSet
	})
}
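The in-memory sorted set now keeps an ordered slice instead of a bare map, mirroring Redis' ZADD ordering: ascending score, ties broken lexicographically by member. A standalone sketch of that comparator (the zsetMember shape here is assumed to match the store's own type):

package main

import (
	"fmt"
	"sort"
)

type zsetMember struct {
	Value string
	Score float64
}

func main() {
	zset := []zsetMember{{"c", 2}, {"a", 1}, {"b", 1}}
	sort.Slice(zset, func(i, j int) bool {
		if zset[i].Score == zset[j].Score {
			return zset[i].Value < zset[j].Value // tie-break on member, like Redis
		}
		return zset[i].Score < zset[j].Score
	})
	fmt.Println(zset) // [{a 1} {b 1} {c 2}]
}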
func (p *memoryPipeliner) ZRem(key string, members ...any) {
@@ -784,13 +800,21 @@ func (p *memoryPipeliner) ZRem(key string, members ...any) {
		if !ok || item.isExpired() {
			return
		}
-		zset, ok := item.value.(map[string]float64)
+		zset, ok := item.value.([]zsetMember)
		if !ok {
			return
		}
-		for _, member := range capturedMembers {
-			delete(zset, fmt.Sprintf("%v", member))
+		membersToRemove := make(map[string]struct{}, len(capturedMembers))
+		for _, m := range capturedMembers {
+			membersToRemove[fmt.Sprintf("%v", m)] = struct{}{}
		}
+		newZSet := make([]zsetMember, 0, len(zset))
+		for _, z := range zset {
+			if _, exists := membersToRemove[z.Value]; !exists {
+				newZSet = append(newZSet, z)
+			}
+		}
+		item.value = newZSet
	})
}
