Files
gemini-banlancer/internal/service/group_manager.go
2025-11-25 16:58:15 +08:00

435 lines
14 KiB
Go

// Filename: internal/service/group_manager.go
package service
import (
"encoding/json"
"errors"
"fmt"
"gemini-balancer/internal/models"
"gemini-balancer/internal/pkg/reflectutil"
"gemini-balancer/internal/repository"
"gemini-balancer/internal/settings"
"gemini-balancer/internal/store"
"gemini-balancer/internal/syncer"
"gemini-balancer/internal/utils"
"net/url"
"path"
"sort"
"strings"
"time"
"github.com/sirupsen/logrus"
"gorm.io/datatypes"
"gorm.io/gorm"
)
// GroupUpdateChannel is the channel name used for group cache invalidation;
// it is the value NewGroupManagerSyncer hands to syncer.NewCacheSyncer.
const GroupUpdateChannel = "groups:cache_invalidation"
// GroupManagerCacheData is the snapshot produced by NewGroupManagerLoader and
// served by the GroupManager's CacheSyncer. The slice and both lookup maps
// reference the same *models.KeyGroup values.
type GroupManagerCacheData struct {
	// Groups holds every key group loaded from the database, with
	// AllowedUpstreams, AllowedModels, Settings, RequestConfig and Mappings
	// preloaded.
	Groups []*models.KeyGroup
	// GroupsByName indexes Groups by KeyGroup.Name.
	GroupsByName map[string]*models.KeyGroup
	// GroupsByID indexes Groups by KeyGroup.ID.
	GroupsByID map[uint]*models.KeyGroup
	// KeyCounts maps a group ID to the number of API-key mappings it has.
	KeyCounts map[uint]int64
	// KeyStatusCounts maps a group ID to its mapping count per key status.
	KeyStatusCounts map[uint]map[models.APIKeyStatus]int64
}
// GroupManager answers key-group lookups from a cached snapshot and performs
// group create/update/delete/clone/reorder operations against the database,
// invalidating the cache after successful writes.
type GroupManager struct {
	// db is used for group writes and for per-group settings lookups.
	db *gorm.DB
	// keyRepo is used by DeleteKeyGroup to clean up orphaned keys.
	keyRepo repository.KeyRepository
	// groupRepo is used by UpdateOrder to persist new group ordering.
	groupRepo repository.GroupRepository
	// settingsManager supplies the global settings that BuildOperationalConfig
	// uses as its base layer.
	settingsManager *settings.SettingsManager
	// syncer holds the cached GroupManagerCacheData snapshot.
	syncer *syncer.CacheSyncer[GroupManagerCacheData]
	// logger is tagged with component=GroupManager.
	logger *logrus.Entry
}
// UpdateOrderPayload is one entry of a group reorder request: the group ID and
// its new Order value. GetAllGroups sorts lower Order values first.
type UpdateOrderPayload struct {
	ID    uint `json:"id" binding:"required"`
	Order int  `json:"order"`
}
// NewGroupManagerLoader returns the loader used by the group CacheSyncer.
//
// Each invocation reads every key group with its AllowedUpstreams,
// AllowedModels, Settings, RequestConfig and Mappings preloaded, then derives
// name/ID indexes plus per-group key totals and per-status key totals from the
// preloaded Mappings. The logger parameter is currently unused.
func NewGroupManagerLoader(db *gorm.DB, logger *logrus.Logger) syncer.LoaderFunc[GroupManagerCacheData] {
	return func() (GroupManagerCacheData, error) {
		query := db.
			Preload("AllowedUpstreams").
			Preload("AllowedModels").
			Preload("Settings").
			Preload("RequestConfig").
			Preload("Mappings")

		var groups []*models.KeyGroup
		if err := query.Find(&groups).Error; err != nil {
			return GroupManagerCacheData{}, fmt.Errorf("failed to load groups: %w", err)
		}

		n := len(groups)
		snapshot := GroupManagerCacheData{
			Groups:          groups,
			GroupsByName:    make(map[string]*models.KeyGroup, n),
			GroupsByID:      make(map[uint]*models.KeyGroup, n),
			KeyCounts:       make(map[uint]int64, n),
			KeyStatusCounts: make(map[uint]map[models.APIKeyStatus]int64, n),
		}
		for _, g := range groups {
			perStatus := make(map[models.APIKeyStatus]int64)
			for _, m := range g.Mappings {
				perStatus[m.Status]++
			}
			snapshot.KeyCounts[g.ID] = int64(len(g.Mappings))
			snapshot.KeyStatusCounts[g.ID] = perStatus
			snapshot.GroupsByName[g.Name] = g
			snapshot.GroupsByID[g.ID] = g
		}
		return snapshot, nil
	}
}
// NewGroupManager assembles a GroupManager from its collaborators. The given
// logger is scoped with the field component=GroupManager.
func NewGroupManager(
	db *gorm.DB,
	keyRepo repository.KeyRepository,
	groupRepo repository.GroupRepository,
	sm *settings.SettingsManager,
	syncer *syncer.CacheSyncer[GroupManagerCacheData],
	logger *logrus.Logger,
) *GroupManager {
	scoped := logger.WithField("component", "GroupManager")
	gm := new(GroupManager)
	gm.db = db
	gm.keyRepo = keyRepo
	gm.groupRepo = groupRepo
	gm.settingsManager = sm
	gm.syncer = syncer
	gm.logger = scoped
	return gm
}
// GetAllGroups returns every cached key group sorted by Order ascending, with
// ties broken by ascending ID.
//
// The cached snapshot is shared by all readers, so it is never sorted in
// place (that would race with concurrent callers and reorder the cache under
// them); a sorted copy is returned instead. The *KeyGroup elements are still
// the cached instances and should be treated as read-only.
func (gm *GroupManager) GetAllGroups() []*models.KeyGroup {
	cached := gm.syncer.Get().Groups
	if len(cached) == 0 {
		// Nothing to sort; preserve nil vs empty exactly as cached.
		return cached
	}
	groups := make([]*models.KeyGroup, len(cached))
	copy(groups, cached)
	sort.Slice(groups, func(i, j int) bool {
		if groups[i].Order != groups[j].Order {
			return groups[i].Order < groups[j].Order
		}
		return groups[i].ID < groups[j].ID
	})
	return groups
}
// GetKeyCount reports how many API-key mappings the cached snapshot holds for
// groupID; an unknown group yields 0.
func (gm *GroupManager) GetKeyCount(groupID uint) int64 {
	counts := gm.syncer.Get().KeyCounts
	total := counts[groupID]
	return total
}
// GetKeyStatusCount returns the cached per-status mapping counts for groupID.
// For a known group this is the map stored in the cache snapshot itself (do
// not mutate it); for an unknown group a new empty map is returned.
func (gm *GroupManager) GetKeyStatusCount(groupID uint) map[models.APIKeyStatus]int64 {
	byGroup := gm.syncer.Get().KeyStatusCounts
	counts, found := byGroup[groupID]
	if !found {
		return make(map[models.APIKeyStatus]int64)
	}
	return counts
}
// GetGroupByName looks up a cached group by its Name; the boolean reports
// whether it was found.
func (gm *GroupManager) GetGroupByName(name string) (*models.KeyGroup, bool) {
	byName := gm.syncer.Get().GroupsByName
	match, found := byName[name]
	return match, found
}
// GetGroupByID looks up a cached group by its ID; the boolean reports whether
// it was found.
func (gm *GroupManager) GetGroupByID(id uint) (*models.KeyGroup, bool) {
	byID := gm.syncer.Get().GroupsByID
	match, found := byID[id]
	return match, found
}
// Stop forwards to the cache syncer's Stop.
func (gm *GroupManager) Stop() {
	s := gm.syncer
	s.Stop()
}
// Invalidate forwards to the cache syncer's Invalidate and returns its error.
// Presumably this marks the group snapshot stale so it is rebuilt by the
// loader — confirm against syncer.CacheSyncer.
func (gm *GroupManager) Invalidate() error {
	s := gm.syncer
	err := s.Invalidate()
	return err
}
// CreateKeyGroup inserts group and, when settings is non-nil, a GroupSettings
// row holding settings serialized as JSON. Both inserts run in a single
// transaction. After a successful commit the group cache is invalidated
// asynchronously.
//
// It returns an error if group is nil, if group.Name fails
// utils.IsValidGroupName, or if serialization or any insert fails.
func (gm *GroupManager) CreateKeyGroup(group *models.KeyGroup, settings *models.KeyGroupSettings) error {
	if group == nil {
		return errors.New("group must not be nil")
	}
	if !utils.IsValidGroupName(group.Name) {
		return errors.New("invalid group name: must contain only lowercase letters, numbers, and hyphens")
	}
	err := gm.db.Transaction(func(tx *gorm.DB) error {
		if err := tx.Create(group).Error; err != nil {
			return err
		}
		if settings == nil {
			return nil
		}
		settingsJSON, err := json.Marshal(settings)
		if err != nil {
			return fmt.Errorf("failed to marshal settings: %w", err)
		}
		groupSettings := models.GroupSettings{
			GroupID:      group.ID,
			SettingsJSON: datatypes.JSON(settingsJSON),
		}
		return tx.Create(&groupSettings).Error
	})
	if err != nil {
		return err
	}
	// The write has already committed, so an invalidation failure is logged
	// rather than returned to the caller.
	go func() {
		if invErr := gm.Invalidate(); invErr != nil {
			gm.logger.WithError(invErr).WithField("group_id", group.ID).Warn("Failed to invalidate group cache after create")
		}
	}()
	return nil
}
// UpdateKeyGroup persists changes to an existing group in one transaction:
//   - AllowedUpstreams is replaced by the UpstreamEndpoint rows whose url is in
//     upstreamURLs (URLs with no matching row are silently ignored);
//   - AllowedModels is cleared and rebuilt from modelNames;
//   - the group's own columns are written with gorm Updates, which skips
//     zero-value struct fields;
//   - newSettings, when non-nil, is merged into the stored settings JSON with
//     reflectutil.MergeNilFields (presumably its non-nil fields override the
//     stored ones — confirm against reflectutil) and the result is saved.
//
// Duplicate URLs and model names are collapsed before use. After a successful
// commit the group cache is invalidated asynchronously.
func (gm *GroupManager) UpdateKeyGroup(group *models.KeyGroup, newSettings *models.KeyGroupSettings, upstreamURLs []string, modelNames []string) error {
	if group == nil {
		return errors.New("group must not be nil")
	}
	if !utils.IsValidGroupName(group.Name) {
		return errors.New("invalid group name: must contain only lowercase letters, numbers, and hyphens")
	}
	uniqueUpstreamURLs := uniqueStrings(upstreamURLs)
	uniqueModelNames := uniqueStrings(modelNames)
	err := gm.db.Transaction(func(tx *gorm.DB) error {
		var upstreams []*models.UpstreamEndpoint
		if len(uniqueUpstreamURLs) > 0 {
			if err := tx.Where("url IN ?", uniqueUpstreamURLs).Find(&upstreams).Error; err != nil {
				return fmt.Errorf("failed to load upstreams: %w", err)
			}
		}
		// An empty list clears the association.
		if err := tx.Model(group).Association("AllowedUpstreams").Replace(upstreams); err != nil {
			return fmt.Errorf("failed to replace allowed upstreams: %w", err)
		}
		if err := tx.Model(group).Association("AllowedModels").Clear(); err != nil {
			return fmt.Errorf("failed to clear allowed models: %w", err)
		}
		if len(uniqueModelNames) > 0 {
			newMappings := make([]models.GroupModelMapping, 0, len(uniqueModelNames))
			for _, name := range uniqueModelNames {
				newMappings = append(newMappings, models.GroupModelMapping{ModelName: name})
			}
			if err := tx.Model(group).Association("AllowedModels").Append(newMappings); err != nil {
				return fmt.Errorf("failed to append allowed models: %w", err)
			}
		}
		if err := tx.Model(group).Updates(group).Error; err != nil {
			return err
		}

		var existingSettings models.GroupSettings
		if err := tx.Where("group_id = ?", group.ID).First(&existingSettings).Error; err != nil && !errors.Is(err, gorm.ErrRecordNotFound) {
			return err
		}
		var currentSettingsData models.KeyGroupSettings
		if len(existingSettings.SettingsJSON) > 0 {
			if err := json.Unmarshal(existingSettings.SettingsJSON, &currentSettingsData); err != nil {
				return fmt.Errorf("failed to unmarshal existing settings: %w", err)
			}
		}
		// Only merge when there is something to merge; the reflection-based
		// helper is not known to accept a nil source.
		if newSettings != nil {
			if err := reflectutil.MergeNilFields(&currentSettingsData, newSettings); err != nil {
				return fmt.Errorf("failed to merge settings: %w", err)
			}
		}
		updatedJSON, err := json.Marshal(currentSettingsData)
		if err != nil {
			return fmt.Errorf("failed to marshal updated settings: %w", err)
		}
		// existingSettings is zero-valued when no row was found, so Save inserts.
		existingSettings.GroupID = group.ID
		existingSettings.SettingsJSON = datatypes.JSON(updatedJSON)
		return tx.Save(&existingSettings).Error
	})
	if err != nil {
		return err
	}
	// The write has already committed, so an invalidation failure is logged
	// rather than returned to the caller.
	go func() {
		if invErr := gm.Invalidate(); invErr != nil {
			gm.logger.WithError(invErr).WithField("group_id", group.ID).Warn("Failed to invalidate group cache after update")
		}
	}()
	return nil
}
// DeleteKeyGroup deletes the group with the given ID, together with its
// AllowedUpstreams, AllowedModels, Mappings and Settings associations (selected
// explicitly so gorm deletes them too). It then calls
// keyRepo.DeleteOrphanKeysTx in the same transaction to clean up keys orphaned
// by the delete.
//
// Deleting an ID that does not exist is a no-op and returns nil. After a
// successful commit the group cache is invalidated asynchronously.
func (gm *GroupManager) DeleteKeyGroup(id uint) error {
	err := gm.db.Transaction(func(tx *gorm.DB) error {
		var group models.KeyGroup
		if err := tx.First(&group, id).Error; err != nil {
			if errors.Is(err, gorm.ErrRecordNotFound) {
				return nil
			}
			return err
		}
		if err := tx.Select("AllowedUpstreams", "AllowedModels", "Mappings", "Settings").Delete(&group).Error; err != nil {
			return err
		}
		deletedCount, err := gm.keyRepo.DeleteOrphanKeysTx(tx)
		if err != nil {
			return err
		}
		if deletedCount > 0 {
			gm.logger.Infof("Cleaned up %d orphan keys after deleting group %d", deletedCount, id)
		}
		return nil
	})
	if err != nil {
		return err
	}
	// The write has already committed, so an invalidation failure is logged
	// rather than returned to the caller.
	go func() {
		if invErr := gm.Invalidate(); invErr != nil {
			gm.logger.WithError(invErr).WithField("group_id", id).Warn("Failed to invalidate group cache after delete")
		}
	}()
	return nil
}
// CloneKeyGroup copies the group with the given ID into a new group whose Name
// and DisplayName get a "-clone-<unix seconds>" suffix. It returns the new group
// reloaded with RequestConfig, Mappings, AllowedUpstreams and AllowedModels.
//
// A single transaction creates the new group and duplicates:
//   - the RequestConfig row, if any;
//   - the GroupSettings JSON, if any;
//   - every API-key mapping, including its status, error and cooldown state;
//   - the allowed-upstream links;
//   - the allowed-model names.
//
// After a successful commit the group cache is invalidated asynchronously.
func (gm *GroupManager) CloneKeyGroup(id uint) (*models.KeyGroup, error) {
	var originalGroup models.KeyGroup
	if err := gm.db.
		Preload("RequestConfig").
		Preload("Mappings").
		Preload("AllowedUpstreams").
		Preload("AllowedModels").
		First(&originalGroup, id).Error; err != nil {
		return nil, fmt.Errorf("failed to find original group %d: %w", id, err)
	}

	// Start from a shallow copy. Reset identity, timestamps and every loaded
	// association so Create does not re-save the original's related rows.
	// Associations are recreated explicitly below.
	newGroup := originalGroup
	timestamp := time.Now().Unix()
	newGroup.ID = 0
	newGroup.Name = fmt.Sprintf("%s-clone-%d", originalGroup.Name, timestamp)
	newGroup.DisplayName = fmt.Sprintf("%s-clone-%d", originalGroup.DisplayName, timestamp)
	newGroup.CreatedAt = time.Time{}
	newGroup.UpdatedAt = time.Time{}
	newGroup.RequestConfigID = nil
	newGroup.RequestConfig = nil
	newGroup.Mappings = nil
	newGroup.AllowedUpstreams = nil
	newGroup.AllowedModels = nil

	err := gm.db.Transaction(func(tx *gorm.DB) error {
		if err := tx.Create(&newGroup).Error; err != nil {
			return err
		}
		if originalGroup.RequestConfig != nil {
			newRequestConfig := *originalGroup.RequestConfig
			newRequestConfig.ID = 0
			if err := tx.Create(&newRequestConfig).Error; err != nil {
				return fmt.Errorf("failed to clone request config: %w", err)
			}
			if err := tx.Model(&newGroup).Update("request_config_id", newRequestConfig.ID).Error; err != nil {
				return fmt.Errorf("failed to link cloned request config: %w", err)
			}
		}

		var originalSettings models.GroupSettings
		err := tx.Where("group_id = ?", originalGroup.ID).First(&originalSettings).Error
		if err == nil && len(originalSettings.SettingsJSON) > 0 {
			newSettings := models.GroupSettings{
				GroupID:      newGroup.ID,
				SettingsJSON: originalSettings.SettingsJSON,
			}
			if err := tx.Create(&newSettings).Error; err != nil {
				return fmt.Errorf("failed to clone settings: %w", err)
			}
		} else if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) {
			return fmt.Errorf("failed to query original settings: %w", err)
		}

		if len(originalGroup.Mappings) > 0 {
			newMappings := make([]models.GroupAPIKeyMapping, len(originalGroup.Mappings))
			for i, oldMapping := range originalGroup.Mappings {
				newMappings[i] = models.GroupAPIKeyMapping{
					KeyGroupID:            newGroup.ID,
					APIKeyID:              oldMapping.APIKeyID,
					Status:                oldMapping.Status,
					LastError:             oldMapping.LastError,
					ConsecutiveErrorCount: oldMapping.ConsecutiveErrorCount,
					LastUsedAt:            oldMapping.LastUsedAt,
					CooldownUntil:         oldMapping.CooldownUntil,
				}
			}
			if err := tx.Create(&newMappings).Error; err != nil {
				return fmt.Errorf("failed to clone mappings: %w", err)
			}
		}

		if len(originalGroup.AllowedUpstreams) > 0 {
			if err := tx.Model(&newGroup).Association("AllowedUpstreams").Append(originalGroup.AllowedUpstreams); err != nil {
				return fmt.Errorf("failed to clone allowed upstreams: %w", err)
			}
		}

		if len(originalGroup.AllowedModels) > 0 {
			// AllowedModels appears to be a per-group has-many of
			// GroupModelMapping (UpdateKeyGroup rebuilds it from bare ModelName
			// values). Appending the original's records, which still carry their
			// primary keys, would make gorm re-point those rows at the clone and
			// strip them from the source group. Build fresh rows instead.
			newModels := make([]models.GroupModelMapping, 0, len(originalGroup.AllowedModels))
			for _, m := range originalGroup.AllowedModels {
				newModels = append(newModels, models.GroupModelMapping{ModelName: m.ModelName})
			}
			if err := tx.Model(&newGroup).Association("AllowedModels").Append(newModels); err != nil {
				return fmt.Errorf("failed to clone allowed models: %w", err)
			}
		}
		return nil
	})
	if err != nil {
		return nil, err
	}

	// The clone has already committed, so an invalidation failure is logged
	// rather than returned to the caller.
	go func() {
		if invErr := gm.Invalidate(); invErr != nil {
			gm.logger.WithError(invErr).WithField("group_id", newGroup.ID).Warn("Failed to invalidate group cache after clone")
		}
	}()

	var finalClonedGroup models.KeyGroup
	if err := gm.db.
		Preload("RequestConfig").
		Preload("Mappings").
		Preload("AllowedUpstreams").
		Preload("AllowedModels").
		First(&finalClonedGroup, newGroup.ID).Error; err != nil {
		return nil, fmt.Errorf("failed to reload cloned group %d: %w", newGroup.ID, err)
	}
	return &finalClonedGroup, nil
}
// BuildOperationalConfig resolves the effective settings for group.
//
// The base layer is taken from the global settings manager. If group is nil,
// that base is returned directly. Otherwise the group's GroupSettings row is
// read from the database, and its JSON document, if non-empty, is overlaid with
// reflectutil.MergeNilFields.
//
// Outcomes:
//   - A missing row or an empty document yields the global base.
//   - An undecodable document is logged and the global base is returned.
//   - A merge failure is logged and the config built so far is returned.
//   - Only a database error other than "record not found" is returned as an
//     error.
//
// NOTE(review): the base-layer pointers address fields of the value returned
// by GetSettings. If that is a shared struct rather than a copy, callers must
// not write through them — confirm against SettingsManager.
func (gm *GroupManager) BuildOperationalConfig(group *models.KeyGroup) (*models.KeyGroupSettings, error) {
	global := gm.settingsManager.GetSettings()
	cfg := &models.KeyGroupSettings{
		EnableKeyCheck:          &global.EnableBaseKeyCheck,
		KeyCheckConcurrency:     &global.BaseKeyCheckConcurrency,
		KeyCheckIntervalMinutes: &global.BaseKeyCheckIntervalMinutes,
		KeyCheckEndpoint:        &global.DefaultUpstreamURL,
		KeyBlacklistThreshold:   &global.BlacklistThreshold,
		KeyCooldownMinutes:      &global.KeyCooldownMinutes,
		KeyCheckModel:           &global.BaseKeyCheckModel,
		MaxRetries:              &global.MaxRetries,
		EnableSmartGateway:      &global.EnableSmartGateway,
	}
	if group == nil {
		return cfg, nil
	}

	var record models.GroupSettings
	switch err := gm.db.Where("group_id = ?", group.ID).First(&record).Error; {
	case err == nil:
		// Row found; continue below.
	case errors.Is(err, gorm.ErrRecordNotFound):
		return cfg, nil
	default:
		return nil, err
	}
	if len(record.SettingsJSON) == 0 {
		return cfg, nil
	}

	var overrides models.KeyGroupSettings
	if err := json.Unmarshal(record.SettingsJSON, &overrides); err != nil {
		gm.logger.WithError(err).WithField("group_id", group.ID).Warn("Failed to unmarshal group settings")
		return cfg, nil
	}
	if err := reflectutil.MergeNilFields(cfg, &overrides); err != nil {
		gm.logger.WithError(err).WithField("group_id", group.ID).Error("Failed to merge group settings")
	}
	return cfg, nil
}
// BuildKeyCheckEndpoint returns the key-probe URL for the given group, in the
// form <base>/v1beta/models/<model>.
//
// The base is the group's KeyCheckEndpoint, falling back to the global
// DefaultUpstreamURL when unset or empty. The model is the group's
// KeyCheckModel, falling back to the global BaseKeyCheckModel. One trailing
// "/" and then one trailing "/v1beta" are stripped from the base path so the
// version segment is not duplicated.
//
// It errors when the group is not cached, when the operational config cannot
// be built, when no base URL is configured, or when the base URL does not
// parse.
func (gm *GroupManager) BuildKeyCheckEndpoint(groupID uint) (string, error) {
	group, found := gm.GetGroupByID(groupID)
	if !found {
		return "", fmt.Errorf("group %d not found", groupID)
	}
	opConfig, err := gm.BuildOperationalConfig(group)
	if err != nil {
		return "", err
	}
	global := gm.settingsManager.GetSettings()

	// pick prefers a non-empty override and otherwise returns the fallback.
	pick := func(override *string, fallback string) string {
		if override == nil || *override == "" {
			return fallback
		}
		return *override
	}

	baseURL := pick(opConfig.KeyCheckEndpoint, global.DefaultUpstreamURL)
	if baseURL == "" {
		return "", fmt.Errorf("no endpoint configured for group %d", groupID)
	}
	modelName := pick(opConfig.KeyCheckModel, global.BaseKeyCheckModel)

	endpoint, err := url.Parse(baseURL)
	if err != nil {
		return "", fmt.Errorf("invalid URL '%s': %w", baseURL, err)
	}
	basePath := strings.TrimSuffix(endpoint.Path, "/")
	basePath = strings.TrimSuffix(basePath, "/v1beta")
	endpoint.Path = path.Join(basePath, "v1beta/models", modelName)
	return endpoint.String(), nil
}
// UpdateOrder assigns new Order values to groups through
// groupRepo.UpdateOrderInTransaction. If payload lists the same ID more than
// once, the last entry wins. An empty payload is a no-op: no repository call
// and no cache invalidation. After a successful update the group cache is
// invalidated asynchronously.
func (gm *GroupManager) UpdateOrder(payload []UpdateOrderPayload) error {
	if len(payload) == 0 {
		return nil
	}
	ordersMap := make(map[uint]int, len(payload))
	for _, item := range payload {
		ordersMap[item.ID] = item.Order
	}
	if err := gm.groupRepo.UpdateOrderInTransaction(ordersMap); err != nil {
		return fmt.Errorf("failed to update order: %w", err)
	}
	// The write has already committed, so an invalidation failure is logged
	// rather than returned to the caller.
	go func() {
		if invErr := gm.Invalidate(); invErr != nil {
			gm.logger.WithError(invErr).Warn("Failed to invalidate group cache after reorder")
		}
	}()
	return nil
}
// uniqueStrings returns the distinct values of slice in first-seen order.
// The result is always a non-nil slice, even for nil or empty input.
func uniqueStrings(slice []string) []string {
	out := make([]string, 0, len(slice))
	seen := make(map[string]bool, len(slice))
	for i := range slice {
		s := slice[i]
		if seen[s] {
			continue
		}
		seen[s] = true
		out = append(out, s)
	}
	return out
}
// NewGroupManagerSyncer configures the CacheSyncer for GroupManager. loader
// produces GroupManagerCacheData snapshots, and cache invalidation goes through
// store on the GroupUpdateChannel channel.
func NewGroupManagerSyncer(
	loader syncer.LoaderFunc[GroupManagerCacheData],
	store store.Store,
	logger *logrus.Logger,
) (*syncer.CacheSyncer[GroupManagerCacheData], error) {
	// Use the package-level constant rather than a private duplicate so every
	// user of this channel name shares a single definition.
	return syncer.NewCacheSyncer(loader, store, GroupUpdateChannel, logger)
}