// Filename: internal/service/group_manager.go package service import ( "encoding/json" "errors" "fmt" "gemini-balancer/internal/models" "gemini-balancer/internal/pkg/reflectutil" "gemini-balancer/internal/repository" "gemini-balancer/internal/settings" "gemini-balancer/internal/store" "gemini-balancer/internal/syncer" "gemini-balancer/internal/utils" "net/url" "path" "sort" "strings" "time" "github.com/sirupsen/logrus" "gorm.io/datatypes" "gorm.io/gorm" ) const GroupUpdateChannel = "groups:cache_invalidation" type GroupManagerCacheData struct { Groups []*models.KeyGroup GroupsByName map[string]*models.KeyGroup GroupsByID map[uint]*models.KeyGroup KeyCounts map[uint]int64 KeyStatusCounts map[uint]map[models.APIKeyStatus]int64 } type GroupManager struct { db *gorm.DB keyRepo repository.KeyRepository groupRepo repository.GroupRepository settingsManager *settings.SettingsManager syncer *syncer.CacheSyncer[GroupManagerCacheData] logger *logrus.Entry } type UpdateOrderPayload struct { ID uint `json:"id" binding:"required"` Order int `json:"order"` } func NewGroupManagerLoader(db *gorm.DB, logger *logrus.Logger) syncer.LoaderFunc[GroupManagerCacheData] { return func() (GroupManagerCacheData, error) { var groups []*models.KeyGroup if err := db.Preload("AllowedUpstreams"). Preload("AllowedModels"). Preload("Settings"). Preload("RequestConfig"). Preload("Mappings"). Find(&groups).Error; err != nil { return GroupManagerCacheData{}, fmt.Errorf("failed to load groups: %w", err) } keyCounts := make(map[uint]int64, len(groups)) keyStatusCounts := make(map[uint]map[models.APIKeyStatus]int64, len(groups)) groupsByName := make(map[string]*models.KeyGroup, len(groups)) groupsByID := make(map[uint]*models.KeyGroup, len(groups)) for _, group := range groups { keyCounts[group.ID] = int64(len(group.Mappings)) statusCounts := make(map[models.APIKeyStatus]int64) for _, mapping := range group.Mappings { statusCounts[mapping.Status]++ } keyStatusCounts[group.ID] = statusCounts groupsByName[group.Name] = group groupsByID[group.ID] = group } return GroupManagerCacheData{ Groups: groups, GroupsByName: groupsByName, GroupsByID: groupsByID, KeyCounts: keyCounts, KeyStatusCounts: keyStatusCounts, }, nil } } func NewGroupManager( db *gorm.DB, keyRepo repository.KeyRepository, groupRepo repository.GroupRepository, sm *settings.SettingsManager, syncer *syncer.CacheSyncer[GroupManagerCacheData], logger *logrus.Logger, ) *GroupManager { return &GroupManager{ db: db, keyRepo: keyRepo, groupRepo: groupRepo, settingsManager: sm, syncer: syncer, logger: logger.WithField("component", "GroupManager"), } } func (gm *GroupManager) GetAllGroups() []*models.KeyGroup { groups := gm.syncer.Get().Groups sort.Slice(groups, func(i, j int) bool { if groups[i].Order != groups[j].Order { return groups[i].Order < groups[j].Order } return groups[i].ID < groups[j].ID }) return groups } func (gm *GroupManager) GetKeyCount(groupID uint) int64 { return gm.syncer.Get().KeyCounts[groupID] } func (gm *GroupManager) GetKeyStatusCount(groupID uint) map[models.APIKeyStatus]int64 { if counts, ok := gm.syncer.Get().KeyStatusCounts[groupID]; ok { return counts } return make(map[models.APIKeyStatus]int64) } func (gm *GroupManager) GetGroupByName(name string) (*models.KeyGroup, bool) { group, ok := gm.syncer.Get().GroupsByName[name] return group, ok } func (gm *GroupManager) GetGroupByID(id uint) (*models.KeyGroup, bool) { group, ok := gm.syncer.Get().GroupsByID[id] return group, ok } func (gm *GroupManager) Stop() { gm.syncer.Stop() } func (gm *GroupManager) Invalidate() error { return gm.syncer.Invalidate() } func (gm *GroupManager) CreateKeyGroup(group *models.KeyGroup, settings *models.KeyGroupSettings) error { if !utils.IsValidGroupName(group.Name) { return errors.New("invalid group name: must contain only lowercase letters, numbers, and hyphens") } err := gm.db.Transaction(func(tx *gorm.DB) error { if err := tx.Create(group).Error; err != nil { return err } if settings != nil { settingsJSON, err := json.Marshal(settings) if err != nil { return fmt.Errorf("failed to marshal settings: %w", err) } groupSettings := models.GroupSettings{ GroupID: group.ID, SettingsJSON: datatypes.JSON(settingsJSON), } if err := tx.Create(&groupSettings).Error; err != nil { return err } } return nil }) if err == nil { go gm.Invalidate() } return err } func (gm *GroupManager) UpdateKeyGroup(group *models.KeyGroup, newSettings *models.KeyGroupSettings, upstreamURLs []string, modelNames []string) error { if !utils.IsValidGroupName(group.Name) { return fmt.Errorf("invalid group name: must contain only lowercase letters, numbers, and hyphens") } uniqueUpstreamURLs := uniqueStrings(upstreamURLs) uniqueModelNames := uniqueStrings(modelNames) err := gm.db.Transaction(func(tx *gorm.DB) error { var upstreams []*models.UpstreamEndpoint if len(uniqueUpstreamURLs) > 0 { if err := tx.Where("url IN ?", uniqueUpstreamURLs).Find(&upstreams).Error; err != nil { return err } } if err := tx.Model(group).Association("AllowedUpstreams").Replace(upstreams); err != nil { return err } if err := tx.Model(group).Association("AllowedModels").Clear(); err != nil { return err } if len(uniqueModelNames) > 0 { var newMappings []models.GroupModelMapping for _, name := range uniqueModelNames { newMappings = append(newMappings, models.GroupModelMapping{ModelName: name}) } if err := tx.Model(group).Association("AllowedModels").Append(newMappings); err != nil { return err } } if err := tx.Model(group).Updates(group).Error; err != nil { return err } var existingSettings models.GroupSettings if err := tx.Where("group_id = ?", group.ID).First(&existingSettings).Error; err != nil && err != gorm.ErrRecordNotFound { return err } var currentSettingsData models.KeyGroupSettings if len(existingSettings.SettingsJSON) > 0 { if err := json.Unmarshal(existingSettings.SettingsJSON, ¤tSettingsData); err != nil { return fmt.Errorf("failed to unmarshal existing settings: %w", err) } } if err := reflectutil.MergeNilFields(¤tSettingsData, newSettings); err != nil { return fmt.Errorf("failed to merge settings: %w", err) } updatedJSON, err := json.Marshal(currentSettingsData) if err != nil { return fmt.Errorf("failed to marshal updated settings: %w", err) } existingSettings.GroupID = group.ID existingSettings.SettingsJSON = datatypes.JSON(updatedJSON) return tx.Save(&existingSettings).Error }) if err == nil { go gm.Invalidate() } return err } func (gm *GroupManager) DeleteKeyGroup(id uint) error { err := gm.db.Transaction(func(tx *gorm.DB) error { var group models.KeyGroup if err := tx.First(&group, id).Error; err != nil { if err == gorm.ErrRecordNotFound { return nil } return err } if err := tx.Select("AllowedUpstreams", "AllowedModels", "Mappings", "Settings").Delete(&group).Error; err != nil { return err } deletedCount, err := gm.keyRepo.DeleteOrphanKeysTx(tx) if err != nil { return err } if deletedCount > 0 { gm.logger.Infof("Cleaned up %d orphan keys after deleting group %d", deletedCount, id) } return nil }) if err == nil { go gm.Invalidate() } return err } func (gm *GroupManager) CloneKeyGroup(id uint) (*models.KeyGroup, error) { var originalGroup models.KeyGroup if err := gm.db. Preload("RequestConfig"). Preload("Mappings"). Preload("AllowedUpstreams"). Preload("AllowedModels"). First(&originalGroup, id).Error; err != nil { return nil, fmt.Errorf("failed to find original group %d: %w", id, err) } newGroup := originalGroup timestamp := time.Now().Unix() newGroup.ID = 0 newGroup.Name = fmt.Sprintf("%s-clone-%d", originalGroup.Name, timestamp) newGroup.DisplayName = fmt.Sprintf("%s-clone-%d", originalGroup.DisplayName, timestamp) newGroup.CreatedAt = time.Time{} newGroup.UpdatedAt = time.Time{} newGroup.RequestConfigID = nil newGroup.RequestConfig = nil newGroup.Mappings = nil newGroup.AllowedUpstreams = nil newGroup.AllowedModels = nil err := gm.db.Transaction(func(tx *gorm.DB) error { if err := tx.Create(&newGroup).Error; err != nil { return err } if originalGroup.RequestConfig != nil { newRequestConfig := *originalGroup.RequestConfig newRequestConfig.ID = 0 if err := tx.Create(&newRequestConfig).Error; err != nil { return fmt.Errorf("failed to clone request config: %w", err) } if err := tx.Model(&newGroup).Update("request_config_id", newRequestConfig.ID).Error; err != nil { return fmt.Errorf("failed to link cloned request config: %w", err) } } var originalSettings models.GroupSettings err := tx.Where("group_id = ?", originalGroup.ID).First(&originalSettings).Error if err == nil && len(originalSettings.SettingsJSON) > 0 { newSettings := models.GroupSettings{ GroupID: newGroup.ID, SettingsJSON: originalSettings.SettingsJSON, } if err := tx.Create(&newSettings).Error; err != nil { return fmt.Errorf("failed to clone settings: %w", err) } } else if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) { return fmt.Errorf("failed to query original settings: %w", err) } if len(originalGroup.Mappings) > 0 { newMappings := make([]models.GroupAPIKeyMapping, len(originalGroup.Mappings)) for i, oldMapping := range originalGroup.Mappings { newMappings[i] = models.GroupAPIKeyMapping{ KeyGroupID: newGroup.ID, APIKeyID: oldMapping.APIKeyID, Status: oldMapping.Status, LastError: oldMapping.LastError, ConsecutiveErrorCount: oldMapping.ConsecutiveErrorCount, LastUsedAt: oldMapping.LastUsedAt, CooldownUntil: oldMapping.CooldownUntil, } } if err := tx.Create(&newMappings).Error; err != nil { return fmt.Errorf("failed to clone mappings: %w", err) } } if len(originalGroup.AllowedUpstreams) > 0 { if err := tx.Model(&newGroup).Association("AllowedUpstreams").Append(originalGroup.AllowedUpstreams); err != nil { return err } } if len(originalGroup.AllowedModels) > 0 { if err := tx.Model(&newGroup).Association("AllowedModels").Append(originalGroup.AllowedModels); err != nil { return err } } return nil }) if err != nil { return nil, err } go gm.Invalidate() var finalClonedGroup models.KeyGroup if err := gm.db. Preload("RequestConfig"). Preload("Mappings"). Preload("AllowedUpstreams"). Preload("AllowedModels"). First(&finalClonedGroup, newGroup.ID).Error; err != nil { return nil, err } return &finalClonedGroup, nil } func (gm *GroupManager) BuildOperationalConfig(group *models.KeyGroup) (*models.KeyGroupSettings, error) { globalSettings := gm.settingsManager.GetSettings() opConfig := &models.KeyGroupSettings{ EnableKeyCheck: &globalSettings.EnableBaseKeyCheck, KeyCheckConcurrency: &globalSettings.BaseKeyCheckConcurrency, KeyCheckIntervalMinutes: &globalSettings.BaseKeyCheckIntervalMinutes, KeyCheckEndpoint: &globalSettings.DefaultUpstreamURL, KeyBlacklistThreshold: &globalSettings.BlacklistThreshold, KeyCooldownMinutes: &globalSettings.KeyCooldownMinutes, KeyCheckModel: &globalSettings.BaseKeyCheckModel, MaxRetries: &globalSettings.MaxRetries, EnableSmartGateway: &globalSettings.EnableSmartGateway, } if group == nil { return opConfig, nil } var groupSettingsRecord models.GroupSettings err := gm.db.Where("group_id = ?", group.ID).First(&groupSettingsRecord).Error if err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { return opConfig, nil } return nil, err } if len(groupSettingsRecord.SettingsJSON) == 0 { return opConfig, nil } var groupSpecificSettings models.KeyGroupSettings if err := json.Unmarshal(groupSettingsRecord.SettingsJSON, &groupSpecificSettings); err != nil { gm.logger.WithError(err).WithField("group_id", group.ID).Warn("Failed to unmarshal group settings") return opConfig, nil } if err := reflectutil.MergeNilFields(opConfig, &groupSpecificSettings); err != nil { gm.logger.WithError(err).WithField("group_id", group.ID).Error("Failed to merge group settings") return opConfig, nil } return opConfig, nil } func (gm *GroupManager) BuildKeyCheckEndpoint(groupID uint) (string, error) { group, ok := gm.GetGroupByID(groupID) if !ok { return "", fmt.Errorf("group %d not found", groupID) } opConfig, err := gm.BuildOperationalConfig(group) if err != nil { return "", err } globalSettings := gm.settingsManager.GetSettings() baseURL := globalSettings.DefaultUpstreamURL if opConfig.KeyCheckEndpoint != nil && *opConfig.KeyCheckEndpoint != "" { baseURL = *opConfig.KeyCheckEndpoint } if baseURL == "" { return "", fmt.Errorf("no endpoint configured for group %d", groupID) } modelName := globalSettings.BaseKeyCheckModel if opConfig.KeyCheckModel != nil && *opConfig.KeyCheckModel != "" { modelName = *opConfig.KeyCheckModel } parsedURL, err := url.Parse(baseURL) if err != nil { return "", fmt.Errorf("invalid URL '%s': %w", baseURL, err) } cleanedPath := strings.TrimSuffix(strings.TrimSuffix(parsedURL.Path, "/"), "/v1beta") parsedURL.Path = path.Join(cleanedPath, "v1beta/models", modelName) return parsedURL.String(), nil } func (gm *GroupManager) UpdateOrder(payload []UpdateOrderPayload) error { ordersMap := make(map[uint]int, len(payload)) for _, item := range payload { ordersMap[item.ID] = item.Order } if err := gm.groupRepo.UpdateOrderInTransaction(ordersMap); err != nil { return fmt.Errorf("failed to update order: %w", err) } go gm.Invalidate() return nil } func uniqueStrings(slice []string) []string { seen := make(map[string]struct{}, len(slice)) result := make([]string, 0, len(slice)) for _, s := range slice { if _, exists := seen[s]; !exists { seen[s] = struct{}{} result = append(result, s) } } return result } // GroupManager配置Syncer func NewGroupManagerSyncer( loader syncer.LoaderFunc[GroupManagerCacheData], store store.Store, logger *logrus.Logger, ) (*syncer.CacheSyncer[GroupManagerCacheData], error) { const groupUpdateChannel = "groups:cache_invalidation" return syncer.NewCacheSyncer(loader, store, groupUpdateChannel, logger) }