435 lines
14 KiB
Go
435 lines
14 KiB
Go
// Filename: internal/service/group_manager.go
|
|
package service
|
|
|
|
import (
|
|
"encoding/json"
|
|
"errors"
|
|
"fmt"
|
|
"gemini-balancer/internal/models"
|
|
"gemini-balancer/internal/pkg/reflectutil"
|
|
"gemini-balancer/internal/repository"
|
|
"gemini-balancer/internal/settings"
|
|
"gemini-balancer/internal/store"
|
|
"gemini-balancer/internal/syncer"
|
|
"gemini-balancer/internal/utils"
|
|
"net/url"
|
|
"path"
|
|
"sort"
|
|
"strings"
|
|
"time"
|
|
|
|
"github.com/sirupsen/logrus"
|
|
"gorm.io/datatypes"
|
|
"gorm.io/gorm"
|
|
)
|
|
|
|
const (
	// GroupUpdateChannel names the channel used for group cache invalidation.
	// It holds the same value that NewGroupManagerSyncer hands to
	// syncer.NewCacheSyncer as the channel name.
	GroupUpdateChannel = "groups:cache_invalidation"
)
|
|
|
|
type GroupManagerCacheData struct {
|
|
Groups []*models.KeyGroup
|
|
GroupsByName map[string]*models.KeyGroup
|
|
GroupsByID map[uint]*models.KeyGroup
|
|
KeyCounts map[uint]int64
|
|
KeyStatusCounts map[uint]map[models.APIKeyStatus]int64
|
|
}
|
|
type GroupManager struct {
|
|
db *gorm.DB
|
|
keyRepo repository.KeyRepository
|
|
groupRepo repository.GroupRepository
|
|
settingsManager *settings.SettingsManager
|
|
syncer *syncer.CacheSyncer[GroupManagerCacheData]
|
|
logger *logrus.Entry
|
|
}
|
|
// UpdateOrderPayload is one entry of a reorder request: the group to move
// (ID, required) and the display order to assign to it.
type UpdateOrderPayload struct {
	ID    uint `json:"id" binding:"required"`
	Order int  `json:"order"`
}
|
|
|
|
func NewGroupManagerLoader(db *gorm.DB, logger *logrus.Logger) syncer.LoaderFunc[GroupManagerCacheData] {
|
|
return func() (GroupManagerCacheData, error) {
|
|
var groups []*models.KeyGroup
|
|
if err := db.Preload("AllowedUpstreams").
|
|
Preload("AllowedModels").
|
|
Preload("Settings").
|
|
Preload("RequestConfig").
|
|
Preload("Mappings").
|
|
Find(&groups).Error; err != nil {
|
|
return GroupManagerCacheData{}, fmt.Errorf("failed to load groups: %w", err)
|
|
}
|
|
keyCounts := make(map[uint]int64, len(groups))
|
|
keyStatusCounts := make(map[uint]map[models.APIKeyStatus]int64, len(groups))
|
|
groupsByName := make(map[string]*models.KeyGroup, len(groups))
|
|
groupsByID := make(map[uint]*models.KeyGroup, len(groups))
|
|
for _, group := range groups {
|
|
keyCounts[group.ID] = int64(len(group.Mappings))
|
|
statusCounts := make(map[models.APIKeyStatus]int64)
|
|
for _, mapping := range group.Mappings {
|
|
statusCounts[mapping.Status]++
|
|
}
|
|
keyStatusCounts[group.ID] = statusCounts
|
|
groupsByName[group.Name] = group
|
|
groupsByID[group.ID] = group
|
|
}
|
|
return GroupManagerCacheData{
|
|
Groups: groups,
|
|
GroupsByName: groupsByName,
|
|
GroupsByID: groupsByID,
|
|
KeyCounts: keyCounts,
|
|
KeyStatusCounts: keyStatusCounts,
|
|
}, nil
|
|
}
|
|
}
|
|
func NewGroupManager(
|
|
db *gorm.DB,
|
|
keyRepo repository.KeyRepository,
|
|
groupRepo repository.GroupRepository,
|
|
sm *settings.SettingsManager,
|
|
syncer *syncer.CacheSyncer[GroupManagerCacheData],
|
|
logger *logrus.Logger,
|
|
) *GroupManager {
|
|
return &GroupManager{
|
|
db: db,
|
|
keyRepo: keyRepo,
|
|
groupRepo: groupRepo,
|
|
settingsManager: sm,
|
|
syncer: syncer,
|
|
logger: logger.WithField("component", "GroupManager"),
|
|
}
|
|
}
|
|
func (gm *GroupManager) GetAllGroups() []*models.KeyGroup {
|
|
groups := gm.syncer.Get().Groups
|
|
sort.Slice(groups, func(i, j int) bool {
|
|
if groups[i].Order != groups[j].Order {
|
|
return groups[i].Order < groups[j].Order
|
|
}
|
|
return groups[i].ID < groups[j].ID
|
|
})
|
|
return groups
|
|
}
|
|
func (gm *GroupManager) GetKeyCount(groupID uint) int64 {
|
|
return gm.syncer.Get().KeyCounts[groupID]
|
|
}
|
|
func (gm *GroupManager) GetKeyStatusCount(groupID uint) map[models.APIKeyStatus]int64 {
|
|
if counts, ok := gm.syncer.Get().KeyStatusCounts[groupID]; ok {
|
|
return counts
|
|
}
|
|
return make(map[models.APIKeyStatus]int64)
|
|
}
|
|
func (gm *GroupManager) GetGroupByName(name string) (*models.KeyGroup, bool) {
|
|
group, ok := gm.syncer.Get().GroupsByName[name]
|
|
return group, ok
|
|
}
|
|
func (gm *GroupManager) GetGroupByID(id uint) (*models.KeyGroup, bool) {
|
|
group, ok := gm.syncer.Get().GroupsByID[id]
|
|
return group, ok
|
|
}
|
|
func (gm *GroupManager) Stop() {
|
|
gm.syncer.Stop()
|
|
}
|
|
func (gm *GroupManager) Invalidate() error {
|
|
return gm.syncer.Invalidate()
|
|
}
|
|
func (gm *GroupManager) CreateKeyGroup(group *models.KeyGroup, settings *models.KeyGroupSettings) error {
|
|
if !utils.IsValidGroupName(group.Name) {
|
|
return errors.New("invalid group name: must contain only lowercase letters, numbers, and hyphens")
|
|
}
|
|
err := gm.db.Transaction(func(tx *gorm.DB) error {
|
|
if err := tx.Create(group).Error; err != nil {
|
|
return err
|
|
}
|
|
if settings != nil {
|
|
settingsJSON, err := json.Marshal(settings)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to marshal settings: %w", err)
|
|
}
|
|
groupSettings := models.GroupSettings{
|
|
GroupID: group.ID,
|
|
SettingsJSON: datatypes.JSON(settingsJSON),
|
|
}
|
|
if err := tx.Create(&groupSettings).Error; err != nil {
|
|
return err
|
|
}
|
|
}
|
|
return nil
|
|
})
|
|
if err == nil {
|
|
go gm.Invalidate()
|
|
}
|
|
return err
|
|
}
|
|
func (gm *GroupManager) UpdateKeyGroup(group *models.KeyGroup, newSettings *models.KeyGroupSettings, upstreamURLs []string, modelNames []string) error {
|
|
if !utils.IsValidGroupName(group.Name) {
|
|
return fmt.Errorf("invalid group name: must contain only lowercase letters, numbers, and hyphens")
|
|
}
|
|
uniqueUpstreamURLs := uniqueStrings(upstreamURLs)
|
|
uniqueModelNames := uniqueStrings(modelNames)
|
|
err := gm.db.Transaction(func(tx *gorm.DB) error {
|
|
var upstreams []*models.UpstreamEndpoint
|
|
if len(uniqueUpstreamURLs) > 0 {
|
|
if err := tx.Where("url IN ?", uniqueUpstreamURLs).Find(&upstreams).Error; err != nil {
|
|
return err
|
|
}
|
|
}
|
|
if err := tx.Model(group).Association("AllowedUpstreams").Replace(upstreams); err != nil {
|
|
return err
|
|
}
|
|
if err := tx.Model(group).Association("AllowedModels").Clear(); err != nil {
|
|
return err
|
|
}
|
|
if len(uniqueModelNames) > 0 {
|
|
var newMappings []models.GroupModelMapping
|
|
for _, name := range uniqueModelNames {
|
|
newMappings = append(newMappings, models.GroupModelMapping{ModelName: name})
|
|
}
|
|
if err := tx.Model(group).Association("AllowedModels").Append(newMappings); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
if err := tx.Model(group).Updates(group).Error; err != nil {
|
|
return err
|
|
}
|
|
var existingSettings models.GroupSettings
|
|
if err := tx.Where("group_id = ?", group.ID).First(&existingSettings).Error; err != nil && err != gorm.ErrRecordNotFound {
|
|
return err
|
|
}
|
|
var currentSettingsData models.KeyGroupSettings
|
|
if len(existingSettings.SettingsJSON) > 0 {
|
|
if err := json.Unmarshal(existingSettings.SettingsJSON, ¤tSettingsData); err != nil {
|
|
return fmt.Errorf("failed to unmarshal existing settings: %w", err)
|
|
}
|
|
}
|
|
if err := reflectutil.MergeNilFields(¤tSettingsData, newSettings); err != nil {
|
|
return fmt.Errorf("failed to merge settings: %w", err)
|
|
}
|
|
updatedJSON, err := json.Marshal(currentSettingsData)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to marshal updated settings: %w", err)
|
|
}
|
|
existingSettings.GroupID = group.ID
|
|
existingSettings.SettingsJSON = datatypes.JSON(updatedJSON)
|
|
return tx.Save(&existingSettings).Error
|
|
})
|
|
if err == nil {
|
|
go gm.Invalidate()
|
|
}
|
|
return err
|
|
}
|
|
func (gm *GroupManager) DeleteKeyGroup(id uint) error {
|
|
err := gm.db.Transaction(func(tx *gorm.DB) error {
|
|
var group models.KeyGroup
|
|
if err := tx.First(&group, id).Error; err != nil {
|
|
if err == gorm.ErrRecordNotFound {
|
|
return nil
|
|
}
|
|
return err
|
|
}
|
|
if err := tx.Select("AllowedUpstreams", "AllowedModels", "Mappings", "Settings").Delete(&group).Error; err != nil {
|
|
return err
|
|
}
|
|
deletedCount, err := gm.keyRepo.DeleteOrphanKeysTx(tx)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if deletedCount > 0 {
|
|
gm.logger.Infof("Cleaned up %d orphan keys after deleting group %d", deletedCount, id)
|
|
}
|
|
return nil
|
|
})
|
|
if err == nil {
|
|
go gm.Invalidate()
|
|
}
|
|
return err
|
|
}
|
|
func (gm *GroupManager) CloneKeyGroup(id uint) (*models.KeyGroup, error) {
|
|
var originalGroup models.KeyGroup
|
|
if err := gm.db.
|
|
Preload("RequestConfig").
|
|
Preload("Mappings").
|
|
Preload("AllowedUpstreams").
|
|
Preload("AllowedModels").
|
|
First(&originalGroup, id).Error; err != nil {
|
|
return nil, fmt.Errorf("failed to find original group %d: %w", id, err)
|
|
}
|
|
newGroup := originalGroup
|
|
timestamp := time.Now().Unix()
|
|
newGroup.ID = 0
|
|
newGroup.Name = fmt.Sprintf("%s-clone-%d", originalGroup.Name, timestamp)
|
|
newGroup.DisplayName = fmt.Sprintf("%s-clone-%d", originalGroup.DisplayName, timestamp)
|
|
newGroup.CreatedAt = time.Time{}
|
|
newGroup.UpdatedAt = time.Time{}
|
|
newGroup.RequestConfigID = nil
|
|
newGroup.RequestConfig = nil
|
|
newGroup.Mappings = nil
|
|
newGroup.AllowedUpstreams = nil
|
|
newGroup.AllowedModels = nil
|
|
err := gm.db.Transaction(func(tx *gorm.DB) error {
|
|
if err := tx.Create(&newGroup).Error; err != nil {
|
|
return err
|
|
}
|
|
if originalGroup.RequestConfig != nil {
|
|
newRequestConfig := *originalGroup.RequestConfig
|
|
newRequestConfig.ID = 0
|
|
if err := tx.Create(&newRequestConfig).Error; err != nil {
|
|
return fmt.Errorf("failed to clone request config: %w", err)
|
|
}
|
|
if err := tx.Model(&newGroup).Update("request_config_id", newRequestConfig.ID).Error; err != nil {
|
|
return fmt.Errorf("failed to link cloned request config: %w", err)
|
|
}
|
|
}
|
|
var originalSettings models.GroupSettings
|
|
err := tx.Where("group_id = ?", originalGroup.ID).First(&originalSettings).Error
|
|
if err == nil && len(originalSettings.SettingsJSON) > 0 {
|
|
newSettings := models.GroupSettings{
|
|
GroupID: newGroup.ID,
|
|
SettingsJSON: originalSettings.SettingsJSON,
|
|
}
|
|
if err := tx.Create(&newSettings).Error; err != nil {
|
|
return fmt.Errorf("failed to clone settings: %w", err)
|
|
}
|
|
} else if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) {
|
|
return fmt.Errorf("failed to query original settings: %w", err)
|
|
}
|
|
if len(originalGroup.Mappings) > 0 {
|
|
newMappings := make([]models.GroupAPIKeyMapping, len(originalGroup.Mappings))
|
|
for i, oldMapping := range originalGroup.Mappings {
|
|
newMappings[i] = models.GroupAPIKeyMapping{
|
|
KeyGroupID: newGroup.ID,
|
|
APIKeyID: oldMapping.APIKeyID,
|
|
Status: oldMapping.Status,
|
|
LastError: oldMapping.LastError,
|
|
ConsecutiveErrorCount: oldMapping.ConsecutiveErrorCount,
|
|
LastUsedAt: oldMapping.LastUsedAt,
|
|
CooldownUntil: oldMapping.CooldownUntil,
|
|
}
|
|
}
|
|
if err := tx.Create(&newMappings).Error; err != nil {
|
|
return fmt.Errorf("failed to clone mappings: %w", err)
|
|
}
|
|
}
|
|
if len(originalGroup.AllowedUpstreams) > 0 {
|
|
if err := tx.Model(&newGroup).Association("AllowedUpstreams").Append(originalGroup.AllowedUpstreams); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
if len(originalGroup.AllowedModels) > 0 {
|
|
if err := tx.Model(&newGroup).Association("AllowedModels").Append(originalGroup.AllowedModels); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
return nil
|
|
})
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
go gm.Invalidate()
|
|
var finalClonedGroup models.KeyGroup
|
|
if err := gm.db.
|
|
Preload("RequestConfig").
|
|
Preload("Mappings").
|
|
Preload("AllowedUpstreams").
|
|
Preload("AllowedModels").
|
|
First(&finalClonedGroup, newGroup.ID).Error; err != nil {
|
|
return nil, err
|
|
}
|
|
return &finalClonedGroup, nil
|
|
}
|
|
func (gm *GroupManager) BuildOperationalConfig(group *models.KeyGroup) (*models.KeyGroupSettings, error) {
|
|
globalSettings := gm.settingsManager.GetSettings()
|
|
opConfig := &models.KeyGroupSettings{
|
|
EnableKeyCheck: &globalSettings.EnableBaseKeyCheck,
|
|
KeyCheckConcurrency: &globalSettings.BaseKeyCheckConcurrency,
|
|
KeyCheckIntervalMinutes: &globalSettings.BaseKeyCheckIntervalMinutes,
|
|
KeyCheckEndpoint: &globalSettings.DefaultUpstreamURL,
|
|
KeyBlacklistThreshold: &globalSettings.BlacklistThreshold,
|
|
KeyCooldownMinutes: &globalSettings.KeyCooldownMinutes,
|
|
KeyCheckModel: &globalSettings.BaseKeyCheckModel,
|
|
MaxRetries: &globalSettings.MaxRetries,
|
|
EnableSmartGateway: &globalSettings.EnableSmartGateway,
|
|
}
|
|
if group == nil {
|
|
return opConfig, nil
|
|
}
|
|
var groupSettingsRecord models.GroupSettings
|
|
err := gm.db.Where("group_id = ?", group.ID).First(&groupSettingsRecord).Error
|
|
if err != nil {
|
|
if errors.Is(err, gorm.ErrRecordNotFound) {
|
|
return opConfig, nil
|
|
}
|
|
return nil, err
|
|
}
|
|
if len(groupSettingsRecord.SettingsJSON) == 0 {
|
|
return opConfig, nil
|
|
}
|
|
var groupSpecificSettings models.KeyGroupSettings
|
|
if err := json.Unmarshal(groupSettingsRecord.SettingsJSON, &groupSpecificSettings); err != nil {
|
|
gm.logger.WithError(err).WithField("group_id", group.ID).Warn("Failed to unmarshal group settings")
|
|
return opConfig, nil
|
|
}
|
|
if err := reflectutil.MergeNilFields(opConfig, &groupSpecificSettings); err != nil {
|
|
gm.logger.WithError(err).WithField("group_id", group.ID).Error("Failed to merge group settings")
|
|
return opConfig, nil
|
|
}
|
|
return opConfig, nil
|
|
}
|
|
func (gm *GroupManager) BuildKeyCheckEndpoint(groupID uint) (string, error) {
|
|
group, ok := gm.GetGroupByID(groupID)
|
|
if !ok {
|
|
return "", fmt.Errorf("group %d not found", groupID)
|
|
}
|
|
opConfig, err := gm.BuildOperationalConfig(group)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
globalSettings := gm.settingsManager.GetSettings()
|
|
baseURL := globalSettings.DefaultUpstreamURL
|
|
if opConfig.KeyCheckEndpoint != nil && *opConfig.KeyCheckEndpoint != "" {
|
|
baseURL = *opConfig.KeyCheckEndpoint
|
|
}
|
|
if baseURL == "" {
|
|
return "", fmt.Errorf("no endpoint configured for group %d", groupID)
|
|
}
|
|
modelName := globalSettings.BaseKeyCheckModel
|
|
if opConfig.KeyCheckModel != nil && *opConfig.KeyCheckModel != "" {
|
|
modelName = *opConfig.KeyCheckModel
|
|
}
|
|
parsedURL, err := url.Parse(baseURL)
|
|
if err != nil {
|
|
return "", fmt.Errorf("invalid URL '%s': %w", baseURL, err)
|
|
}
|
|
cleanedPath := strings.TrimSuffix(strings.TrimSuffix(parsedURL.Path, "/"), "/v1beta")
|
|
parsedURL.Path = path.Join(cleanedPath, "v1beta/models", modelName)
|
|
return parsedURL.String(), nil
|
|
}
|
|
func (gm *GroupManager) UpdateOrder(payload []UpdateOrderPayload) error {
|
|
ordersMap := make(map[uint]int, len(payload))
|
|
for _, item := range payload {
|
|
ordersMap[item.ID] = item.Order
|
|
}
|
|
if err := gm.groupRepo.UpdateOrderInTransaction(ordersMap); err != nil {
|
|
return fmt.Errorf("failed to update order: %w", err)
|
|
}
|
|
go gm.Invalidate()
|
|
return nil
|
|
}
|
|
// uniqueStrings returns the distinct values of slice in first-occurrence
// order. The result is always a non-nil slice, even for nil or empty input.
func uniqueStrings(slice []string) []string {
	out := make([]string, 0, len(slice))
	seen := make(map[string]struct{}, len(slice))
	for i := 0; i < len(slice); i++ {
		v := slice[i]
		if _, dup := seen[v]; dup {
			continue
		}
		seen[v] = struct{}{}
		out = append(out, v)
	}
	return out
}
|
|
|
|
// GroupManager配置Syncer
|
|
func NewGroupManagerSyncer(
|
|
loader syncer.LoaderFunc[GroupManagerCacheData],
|
|
store store.Store,
|
|
logger *logrus.Logger,
|
|
) (*syncer.CacheSyncer[GroupManagerCacheData], error) {
|
|
const groupUpdateChannel = "groups:cache_invalidation"
|
|
return syncer.NewCacheSyncer(loader, store, groupUpdateChannel, logger)
|
|
}
|