Files
gemini-banlancer/internal/service/group_manager.go
2025-11-20 12:24:05 +08:00

597 lines
20 KiB
Go

// Filename: internal/service/group_manager.go (Syncer-based upgraded version)
package service
import (
"encoding/json"
"errors"
"fmt"
"gemini-balancer/internal/models"
"gemini-balancer/internal/pkg/reflectutil"
"gemini-balancer/internal/repository"
"gemini-balancer/internal/settings"
"gemini-balancer/internal/syncer"
"gemini-balancer/internal/utils"
"net/url"
"path"
"sort"
"strings"
"time"
"github.com/sirupsen/logrus"
"gorm.io/datatypes"
"gorm.io/gorm"
)
// GroupUpdateChannel names the channel used for group cache-invalidation
// messages. NOTE(review): it is not referenced in the code visible here —
// confirm where it is published to / subscribed from.
const GroupUpdateChannel = "groups:cache_invalidation"
// GroupManagerCacheData is one snapshot of all key-group data, built by the
// loader returned from NewGroupManagerLoader and served through the
// GroupManager's CacheSyncer.
type GroupManagerCacheData struct {
// Groups holds every loaded key group with AllowedUpstreams, AllowedModels,
// Settings and RequestConfig preloaded and Mappings attached by the loader.
Groups []*models.KeyGroup
// GroupsByName indexes Groups by KeyGroup.Name.
GroupsByName map[string]*models.KeyGroup
// GroupsByID indexes Groups by KeyGroup.ID.
GroupsByID map[uint]*models.KeyGroup
KeyCounts map[uint]int64 // GroupID -> total number of key mappings in the group
KeyStatusCounts map[uint]map[models.APIKeyStatus]int64 // GroupID -> mapping status -> count
}
// GroupManager provides cached read access to key groups and performs the
// write operations (create, update, delete, clone, reorder) that change them.
type GroupManager struct {
// db is used for direct queries and transactions in the write paths.
db *gorm.DB
// keyRepo is used to remove orphaned keys after a group is deleted.
keyRepo repository.KeyRepository
// groupRepo persists group ordering.
groupRepo repository.GroupRepository
// settingsManager supplies the global defaults for operational config.
settingsManager *settings.SettingsManager
// syncer serves and refreshes the cached GroupManagerCacheData snapshot.
syncer *syncer.CacheSyncer[GroupManagerCacheData]
// logger is tagged with component=GroupManager.
logger *logrus.Entry
}
// UpdateOrderPayload is one entry of a reorder request handled by UpdateOrder:
// the group ID and its new Order value. The `binding` tag presumably drives
// request validation in the HTTP layer — confirm against the handler.
type UpdateOrderPayload struct {
ID uint `json:"id" binding:"required"`
Order int `json:"order"`
}
// NewGroupManagerLoader returns the syncer.LoaderFunc that rebuilds the
// GroupManager cache snapshot from the database.
//
// Each invocation:
//  1. loads all KeyGroups with AllowedUpstreams, AllowedModels, Settings and
//     RequestConfig preloaded;
//  2. loads every GroupAPIKeyMapping in one query and attaches them to each
//     group's Mappings field by KeyGroupID;
//  3. derives per-group total key counts and per-status key counts from the
//     attached mappings;
//  4. indexes groups by Name and by ID.
//
// Nil group pointers are filtered out before any of their fields are read, so
// a bad row can neither panic the loader nor leak into the Groups slice. Any
// query error aborts the load and is returned wrapped.
func NewGroupManagerLoader(db *gorm.DB, logger *logrus.Logger) syncer.LoaderFunc[GroupManagerCacheData] {
	return func() (GroupManagerCacheData, error) {
		logger.Debugf("[GML-LOG 1/5] ---> Entering NewGroupManagerLoader...")
		var groups []*models.KeyGroup
		logger.Debugf("[GML-LOG 2/5] About to execute DB query with Preloads...")
		if err := db.Preload("AllowedUpstreams").
			Preload("AllowedModels").
			Preload("Settings").
			Preload("RequestConfig").
			Find(&groups).Error; err != nil {
			logger.Errorf("[GML-LOG] CRITICAL: DB query for groups failed: %v", err)
			return GroupManagerCacheData{}, fmt.Errorf("failed to load key groups for cache: %w", err)
		}
		logger.Debugf("[GML-LOG 2.1/5] DB query for groups finished. Found %d group records.", len(groups))

		var allMappings []*models.GroupAPIKeyMapping
		if err := db.Find(&allMappings).Error; err != nil {
			logger.Errorf("[GML-LOG] CRITICAL: DB query for mappings failed: %v", err)
			return GroupManagerCacheData{}, fmt.Errorf("failed to load key mappings for cache: %w", err)
		}
		logger.Debugf("[GML-LOG 2.2/5] DB query for mappings finished. Found %d total mapping records.", len(allMappings))

		// Bucket all mappings by owning group so each group can be filled in O(1).
		mappingsByGroupID := make(map[uint][]*models.GroupAPIKeyMapping)
		for _, mapping := range allMappings {
			mappingsByGroupID[mapping.KeyGroupID] = append(mappingsByGroupID[mapping.KeyGroupID], mapping)
		}

		validGroups := make([]*models.KeyGroup, 0, len(groups))
		groupsByName := make(map[string]*models.KeyGroup, len(groups))
		groupsByID := make(map[uint]*models.KeyGroup, len(groups))
		keyCounts := make(map[uint]int64, len(groups))
		keyStatusCounts := make(map[uint]map[models.APIKeyStatus]int64, len(groups))

		logger.Debugf("[GML-LOG 3/5] Associating mappings and processing group records into maps...")
		// Single pass: the nil check must come before the first field access.
		for i, group := range groups {
			if group == nil {
				logger.Warnf("[GML] Skipping nil group pointer at index %d while building group cache.", i)
				continue
			}
			if mappings, ok := mappingsByGroupID[group.ID]; ok {
				group.Mappings = mappings
			}

			statusCounts := make(map[models.APIKeyStatus]int64)
			for _, mapping := range group.Mappings {
				statusCounts[mapping.Status]++
			}
			keyCounts[group.ID] = int64(len(group.Mappings))
			keyStatusCounts[group.ID] = statusCounts

			groupsByName[group.Name] = group
			groupsByID[group.ID] = group
			validGroups = append(validGroups, group)
		}

		logger.Debugf("[GML-LOG 5/5] Finished processing records. Building final cache data...")
		return GroupManagerCacheData{
			Groups:          validGroups,
			GroupsByName:    groupsByName,
			GroupsByID:      groupsByID,
			KeyCounts:       keyCounts,
			KeyStatusCounts: keyStatusCounts,
		}, nil
	}
}
// NewGroupManager wires a GroupManager from its collaborators. The supplied
// logger is wrapped in an entry tagged component=GroupManager.
func NewGroupManager(
	db *gorm.DB,
	keyRepo repository.KeyRepository,
	groupRepo repository.GroupRepository,
	sm *settings.SettingsManager,
	syncer *syncer.CacheSyncer[GroupManagerCacheData],
	logger *logrus.Logger,
) *GroupManager {
	componentLogger := logger.WithField("component", "GroupManager")
	manager := &GroupManager{
		logger:          componentLogger,
		syncer:          syncer,
		settingsManager: sm,
		groupRepo:       groupRepo,
		keyRepo:         keyRepo,
		db:              db,
	}
	return manager
}
// GetAllGroups returns every cached key group ordered by Order ascending, with
// ties broken by ID ascending.
//
// Sorting happens on a freshly allocated slice: the Groups slice inside the
// cache snapshot is presumably shared by every caller of syncer.Get, so sorting
// it in place would mutate shared state (and race with concurrent readers).
// The returned slice is owned by the caller; the *models.KeyGroup elements are
// still the cached pointers and should be treated as read-only. Nil entries are
// skipped. An empty cache yields an empty, non-nil slice.
func (gm *GroupManager) GetAllGroups() []*models.KeyGroup {
	cache := gm.syncer.Get()
	ordered := make([]*models.KeyGroup, 0, len(cache.Groups))
	for _, group := range cache.Groups {
		if group != nil {
			ordered = append(ordered, group)
		}
	}
	sort.Slice(ordered, func(i, j int) bool {
		if ordered[i].Order != ordered[j].Order {
			return ordered[i].Order < ordered[j].Order
		}
		return ordered[i].ID < ordered[j].ID
	})
	return ordered
}
// GetKeyCount reports how many API-key mappings the cached snapshot holds for
// groupID. Unknown groups (or an empty cache) yield 0.
func (gm *GroupManager) GetKeyCount(groupID uint) int64 {
	counts := gm.syncer.Get().KeyCounts
	if total, found := counts[groupID]; found {
		return total
	}
	return 0
}
// GetKeyStatusCount returns the cached per-status key counts for groupID.
// When the group is unknown (or the cache is empty) a new empty map is
// returned instead of nil. NOTE(review): for known groups the cached map
// itself is returned, so callers should not modify it.
func (gm *GroupManager) GetKeyStatusCount(groupID uint) map[models.APIKeyStatus]int64 {
	byGroup := gm.syncer.Get().KeyStatusCounts
	counts, found := byGroup[groupID]
	if !found {
		return make(map[models.APIKeyStatus]int64)
	}
	return counts
}
// GetGroupByName looks up a cached key group by its Name. The boolean is false
// when no such group exists in the current snapshot.
func (gm *GroupManager) GetGroupByName(name string) (*models.KeyGroup, bool) {
	byName := gm.syncer.Get().GroupsByName
	found, ok := byName[name]
	return found, ok
}
// GetGroupByID looks up a cached key group by its primary key. The boolean is
// false when no such group exists in the current snapshot.
func (gm *GroupManager) GetGroupByID(id uint) (*models.KeyGroup, bool) {
	byID := gm.syncer.Get().GroupsByID
	found, ok := byID[id]
	return found, ok
}
// Stop shuts down the underlying cache syncer.
func (gm *GroupManager) Stop() {
	cacheSyncer := gm.syncer
	cacheSyncer.Stop()
}

// Invalidate asks the cache syncer to refresh the group snapshot and returns
// whatever error the syncer reports.
func (gm *GroupManager) Invalidate() error {
	err := gm.syncer.Invalidate()
	return err
}
// --- Write Operations ---
// CreateKeyGroup creates a new key group, including its operational settings, and invalidates the cache.
//
// The group name must pass utils.IsValidGroupName. Within one transaction the
// group row is inserted (populating group.ID) and, when settings contains at
// least one non-nil field, a GroupSettings row holding only those non-nil
// fields as JSON is created for it. After a successful commit the group cache
// is invalidated in the background; an invalidation failure is logged, not
// returned, because the group itself was created.
func (gm *GroupManager) CreateKeyGroup(group *models.KeyGroup, settings *models.KeyGroupSettings) error {
	if group == nil {
		return errors.New("key group must not be nil")
	}
	if !utils.IsValidGroupName(group.Name) {
		return errors.New("invalid group name: must contain only lowercase letters, numbers, and hyphens")
	}
	err := gm.db.Transaction(func(tx *gorm.DB) error {
		// 1. Create the group itself to get an ID.
		if err := tx.Create(group).Error; err != nil {
			return err
		}
		if settings == nil {
			return nil
		}
		// 2. Persist only the explicitly provided (non-nil) settings so the
		// stored JSON stays minimal and unset fields fall back to defaults.
		settingsToMarshal := make(map[string]any)
		if settings.EnableKeyCheck != nil {
			settingsToMarshal["enable_key_check"] = settings.EnableKeyCheck
		}
		if settings.KeyCheckIntervalMinutes != nil {
			settingsToMarshal["key_check_interval_minutes"] = settings.KeyCheckIntervalMinutes
		}
		if settings.KeyBlacklistThreshold != nil {
			settingsToMarshal["key_blacklist_threshold"] = settings.KeyBlacklistThreshold
		}
		if settings.KeyCooldownMinutes != nil {
			settingsToMarshal["key_cooldown_minutes"] = settings.KeyCooldownMinutes
		}
		if settings.KeyCheckConcurrency != nil {
			settingsToMarshal["key_check_concurrency"] = settings.KeyCheckConcurrency
		}
		if settings.KeyCheckEndpoint != nil {
			settingsToMarshal["key_check_endpoint"] = settings.KeyCheckEndpoint
		}
		if settings.KeyCheckModel != nil {
			settingsToMarshal["key_check_model"] = settings.KeyCheckModel
		}
		if settings.MaxRetries != nil {
			settingsToMarshal["max_retries"] = settings.MaxRetries
		}
		if settings.EnableSmartGateway != nil {
			settingsToMarshal["enable_smart_gateway"] = settings.EnableSmartGateway
		}
		if len(settingsToMarshal) == 0 {
			return nil
		}
		settingsJSON, err := json.Marshal(settingsToMarshal)
		if err != nil {
			return fmt.Errorf("failed to marshal group settings: %w", err)
		}
		groupSettings := models.GroupSettings{
			GroupID:      group.ID,
			SettingsJSON: datatypes.JSON(settingsJSON),
		}
		if err := tx.Create(&groupSettings).Error; err != nil {
			return fmt.Errorf("failed to save group settings: %w", err)
		}
		return nil
	})
	if err != nil {
		return err
	}
	// Refresh the cache asynchronously; surface failures in the log instead of
	// discarding them.
	go func() {
		if invErr := gm.Invalidate(); invErr != nil {
			gm.logger.WithError(invErr).Warn("Failed to invalidate group cache after creating key group.")
		}
	}()
	return nil
}
// UpdateKeyGroup updates an existing key group, its settings, and associations, then invalidates the cache.
//
// upstreamURLs and modelNames are de-duplicated first. Then, in one transaction:
//   - AllowedUpstreams is replaced with the UpstreamEndpoint rows whose url is in
//     upstreamURLs (URLs without a matching row are silently ignored);
//   - AllowedModels is cleared and re-populated with one GroupModelMapping per
//     model name;
//   - the group's own columns are written with GORM's struct-based Updates.
//     NOTE(review): struct-based Updates skips zero-value fields, so clearing a
//     string or setting a number/bool to its zero value is not persisted —
//     confirm this is intended;
//   - the non-nil fields of newSettings are merged over the stored GroupSettings
//     JSON (via reflectutil.MergeNilFields) and the result is saved. A nil
//     newSettings skips the merge.
//
// After a successful commit the cache is invalidated in the background; an
// invalidation failure is logged rather than returned.
func (gm *GroupManager) UpdateKeyGroup(group *models.KeyGroup, newSettings *models.KeyGroupSettings, upstreamURLs []string, modelNames []string) error {
	if group == nil {
		return errors.New("key group must not be nil")
	}
	if !utils.IsValidGroupName(group.Name) {
		return fmt.Errorf("invalid group name: must contain only lowercase letters, numbers, and hyphens")
	}
	uniqueUpstreamURLs := uniqueStrings(upstreamURLs)
	uniqueModelNames := uniqueStrings(modelNames)
	err := gm.db.Transaction(func(tx *gorm.DB) error {
		// --- 1. Update AllowedUpstreams (M:N relationship) ---
		var upstreams []*models.UpstreamEndpoint
		if len(uniqueUpstreamURLs) > 0 {
			if err := tx.Where("url IN ?", uniqueUpstreamURLs).Find(&upstreams).Error; err != nil {
				return err
			}
		}
		if err := tx.Model(group).Association("AllowedUpstreams").Replace(upstreams); err != nil {
			return err
		}
		// --- 2. Rebuild AllowedModels from the requested names ---
		if err := tx.Model(group).Association("AllowedModels").Clear(); err != nil {
			return err
		}
		if len(uniqueModelNames) > 0 {
			newMappings := make([]models.GroupModelMapping, 0, len(uniqueModelNames))
			for _, name := range uniqueModelNames {
				newMappings = append(newMappings, models.GroupModelMapping{ModelName: name})
			}
			if err := tx.Model(group).Association("AllowedModels").Append(newMappings); err != nil {
				return err
			}
		}
		// --- 3. Update the group's own columns ---
		if err := tx.Model(group).Updates(group).Error; err != nil {
			return err
		}
		// --- 4. Merge and persist operational settings ---
		var existingSettings models.GroupSettings
		if err := tx.Where("group_id = ?", group.ID).First(&existingSettings).Error; err != nil && !errors.Is(err, gorm.ErrRecordNotFound) {
			return err
		}
		var currentSettingsData models.KeyGroupSettings
		if len(existingSettings.SettingsJSON) > 0 {
			if err := json.Unmarshal(existingSettings.SettingsJSON, &currentSettingsData); err != nil {
				return fmt.Errorf("failed to unmarshal existing group settings: %w", err)
			}
		}
		// Guard against handing a nil pointer to the reflection-based merge.
		if newSettings != nil {
			if err := reflectutil.MergeNilFields(&currentSettingsData, newSettings); err != nil {
				return fmt.Errorf("failed to merge group settings: %w", err)
			}
		}
		updatedJSON, err := json.Marshal(currentSettingsData)
		if err != nil {
			return fmt.Errorf("failed to marshal updated group settings: %w", err)
		}
		existingSettings.GroupID = group.ID
		existingSettings.SettingsJSON = datatypes.JSON(updatedJSON)
		return tx.Save(&existingSettings).Error
	})
	if err != nil {
		return err
	}
	go func() {
		if invErr := gm.Invalidate(); invErr != nil {
			gm.logger.WithError(invErr).Warn("Failed to invalidate group cache after updating key group.")
		}
	}()
	return nil
}
// DeleteKeyGroup deletes a key group and subsequently cleans up any keys that have become orphans.
//
// All work runs in one transaction: the group's associations are detached, its
// GroupSettings row is deleted, the group itself is deleted, and finally
// keyRepo.DeleteOrphanKeysTx removes keys no longer attached to any group.
// Deleting a group that does not exist is treated as success. After a
// successful commit the cache is invalidated in the background; an
// invalidation failure is logged rather than returned.
func (gm *GroupManager) DeleteKeyGroup(id uint) error {
	err := gm.db.Transaction(func(tx *gorm.DB) error {
		gm.logger.Infof("Starting transaction to delete KeyGroup ID: %d", id)
		// Step 1: Retrieve the group we are about to delete.
		var group models.KeyGroup
		if err := tx.First(&group, id).Error; err != nil {
			if errors.Is(err, gorm.ErrRecordNotFound) {
				gm.logger.Warnf("Attempted to delete a non-existent KeyGroup with ID: %d", id)
				return nil // Don't treat as an error, the group is already gone.
			}
			gm.logger.WithError(err).Errorf("Failed to find KeyGroup with ID: %d for deletion", id)
			return err
		}
		// Step 2: Detach associations through GORM's association API.
		// NOTE(review): for has-many associations GORM's Clear nulls the foreign
		// key instead of deleting rows — confirm that is intended for Mappings
		// and AllowedModels.
		if err := tx.Model(&group).Association("AllowedUpstreams").Clear(); err != nil {
			gm.logger.WithError(err).Errorf("Failed to clear 'AllowedUpstreams' association for KeyGroup ID: %d", id)
			return err
		}
		if err := tx.Model(&group).Association("AllowedModels").Clear(); err != nil {
			gm.logger.WithError(err).Errorf("Failed to clear 'AllowedModels' association for KeyGroup ID: %d", id)
			return err
		}
		if err := tx.Model(&group).Association("Mappings").Clear(); err != nil {
			gm.logger.WithError(err).Errorf("Failed to clear 'Mappings' (API Key associations) for KeyGroup ID: %d", id)
			return err
		}
		// The group was loaded without Preload("Settings"), so group.Settings is
		// not populated and cannot be used to find the settings row. Delete the
		// GroupSettings row directly by its group_id instead.
		if err := tx.Where("group_id = ?", group.ID).Delete(&models.GroupSettings{}).Error; err != nil {
			gm.logger.WithError(err).Errorf("Failed to delete 'Settings' association for KeyGroup ID: %d", id)
			return err
		}
		// Step 3: Delete the KeyGroup itself.
		if err := tx.Delete(&group).Error; err != nil {
			gm.logger.WithError(err).Errorf("Failed to delete KeyGroup ID: %d", id)
			return err
		}
		gm.logger.Infof("KeyGroup ID %d associations cleared and entity deleted. Triggering orphan key cleanup.", id)
		// Step 4: Remove keys left without any group, inside the same transaction.
		deletedCount, err := gm.keyRepo.DeleteOrphanKeysTx(tx)
		if err != nil {
			gm.logger.WithError(err).Error("Failed to clean up orphan keys after deleting group.")
			return err
		}
		if deletedCount > 0 {
			gm.logger.Infof("Successfully cleaned up %d orphan keys.", deletedCount)
		}
		gm.logger.Infof("Transaction for deleting KeyGroup ID: %d completed successfully.", id)
		return nil
	})
	if err != nil {
		return err
	}
	go func() {
		if invErr := gm.Invalidate(); invErr != nil {
			gm.logger.WithError(invErr).Warn("Failed to invalidate group cache after deleting key group.")
		}
	}()
	return nil
}
// CloneKeyGroup copies the key group identified by id and returns the clone,
// reloaded from the database with RequestConfig, Mappings, AllowedUpstreams and
// AllowedModels preloaded.
//
// The clone gets a new ID and "<original>-clone-<unix-seconds>" as Name and
// DisplayName. Inside one transaction it also receives:
//   - a copy of the original RequestConfig as a new row, if one exists;
//   - a copy of the original GroupSettings JSON, if a settings row exists;
//   - copies of every API-key mapping, including status, last error,
//     consecutive error count, last-used time and cooldown;
//   - the same AllowedUpstreams endpoints;
//   - new AllowedModels rows carrying the original model names.
//
// AllowedModels entries are re-created instead of appending the original rows:
// those rows already have primary keys, and appending them to another owner's
// has-many association can make GORM re-point them at the clone, removing them
// from the original group. NOTE(review): this assumes AllowedModels is a
// has-many of GroupModelMapping, as UpdateKeyGroup's Clear+Append suggests.
//
// After the commit the cache is invalidated in the background; an
// invalidation failure is logged rather than returned.
func (gm *GroupManager) CloneKeyGroup(id uint) (*models.KeyGroup, error) {
	var originalGroup models.KeyGroup
	if err := gm.db.
		Preload("RequestConfig").
		Preload("Mappings").
		Preload("AllowedUpstreams").
		Preload("AllowedModels").
		First(&originalGroup, id).Error; err != nil {
		return nil, fmt.Errorf("failed to find original group with id %d: %w", id, err)
	}
	// Start from a shallow copy and reset identity and every association so
	// Create does not touch the original's related rows.
	newGroup := originalGroup
	timestamp := time.Now().Unix()
	newGroup.ID = 0
	newGroup.Name = fmt.Sprintf("%s-clone-%d", originalGroup.Name, timestamp)
	newGroup.DisplayName = fmt.Sprintf("%s-clone-%d", originalGroup.DisplayName, timestamp)
	newGroup.CreatedAt = time.Time{}
	newGroup.UpdatedAt = time.Time{}
	newGroup.RequestConfigID = nil
	newGroup.RequestConfig = nil
	newGroup.Mappings = nil
	newGroup.AllowedUpstreams = nil
	newGroup.AllowedModels = nil
	err := gm.db.Transaction(func(tx *gorm.DB) error {
		if err := tx.Create(&newGroup).Error; err != nil {
			return err
		}
		if originalGroup.RequestConfig != nil {
			newRequestConfig := *originalGroup.RequestConfig
			newRequestConfig.ID = 0 // Mark as new record
			if err := tx.Create(&newRequestConfig).Error; err != nil {
				return fmt.Errorf("failed to clone request config: %w", err)
			}
			if err := tx.Model(&newGroup).Update("request_config_id", newRequestConfig.ID).Error; err != nil {
				return fmt.Errorf("failed to link new group to cloned request config: %w", err)
			}
		}
		var originalSettings models.GroupSettings
		err := tx.Where("group_id = ?", originalGroup.ID).First(&originalSettings).Error
		if err == nil && len(originalSettings.SettingsJSON) > 0 {
			newSettings := models.GroupSettings{
				GroupID:      newGroup.ID,
				SettingsJSON: originalSettings.SettingsJSON,
			}
			if err := tx.Create(&newSettings).Error; err != nil {
				return fmt.Errorf("failed to clone group settings: %w", err)
			}
		} else if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) {
			return fmt.Errorf("failed to query original group settings: %w", err)
		}
		if len(originalGroup.Mappings) > 0 {
			newMappings := make([]models.GroupAPIKeyMapping, len(originalGroup.Mappings))
			for i, oldMapping := range originalGroup.Mappings {
				newMappings[i] = models.GroupAPIKeyMapping{
					KeyGroupID:            newGroup.ID,
					APIKeyID:              oldMapping.APIKeyID,
					Status:                oldMapping.Status,
					LastError:             oldMapping.LastError,
					ConsecutiveErrorCount: oldMapping.ConsecutiveErrorCount,
					LastUsedAt:            oldMapping.LastUsedAt,
					CooldownUntil:         oldMapping.CooldownUntil,
				}
			}
			if err := tx.Create(&newMappings).Error; err != nil {
				return fmt.Errorf("failed to clone key group mappings: %w", err)
			}
		}
		if len(originalGroup.AllowedUpstreams) > 0 {
			if err := tx.Model(&newGroup).Association("AllowedUpstreams").Append(originalGroup.AllowedUpstreams); err != nil {
				return err
			}
		}
		if len(originalGroup.AllowedModels) > 0 {
			// Fresh rows (no primary key) so the original group's rows stay put.
			clonedModels := make([]models.GroupModelMapping, 0, len(originalGroup.AllowedModels))
			for _, allowed := range originalGroup.AllowedModels {
				clonedModels = append(clonedModels, models.GroupModelMapping{ModelName: allowed.ModelName})
			}
			if err := tx.Model(&newGroup).Association("AllowedModels").Append(clonedModels); err != nil {
				return err
			}
		}
		return nil
	})
	if err != nil {
		return nil, err
	}
	go func() {
		if invErr := gm.Invalidate(); invErr != nil {
			gm.logger.WithError(invErr).Warn("Failed to invalidate group cache after cloning key group.")
		}
	}()
	var finalClonedGroup models.KeyGroup
	if err := gm.db.
		Preload("RequestConfig").
		Preload("Mappings").
		Preload("AllowedUpstreams").
		Preload("AllowedModels").
		First(&finalClonedGroup, newGroup.ID).Error; err != nil {
		return nil, err
	}
	return &finalClonedGroup, nil
}
// BuildOperationalConfig returns the effective operational settings for group:
// global defaults from the settings manager, overridden by any group-specific
// values stored as GroupSettings JSON.
//
// Each default is copied into a local before its address is taken, so the
// returned config never holds pointers into the global settings object.
// Without this, writes through the returned config (or by the reflective merge)
// could silently change global settings for every other caller.
//
// Returns:
//   - defaults only, nil error: group is nil, has no settings row, or its
//     settings JSON is empty;
//   - nil, error: the settings query failed for a reason other than "not found";
//   - defaults, error: the stored JSON could not be unmarshalled;
//   - merged config, nil error: otherwise. A merge failure is logged and the
//     config is returned as-is (best-effort).
func (gm *GroupManager) BuildOperationalConfig(group *models.KeyGroup) (*models.KeyGroupSettings, error) {
	globalSettings := gm.settingsManager.GetSettings()

	enableKeyCheck := globalSettings.EnableBaseKeyCheck
	keyCheckConcurrency := globalSettings.BaseKeyCheckConcurrency
	keyCheckIntervalMinutes := globalSettings.BaseKeyCheckIntervalMinutes
	keyCheckEndpoint := globalSettings.DefaultUpstreamURL
	keyBlacklistThreshold := globalSettings.BlacklistThreshold
	keyCooldownMinutes := globalSettings.KeyCooldownMinutes
	// Per user feedback for default model.
	// NOTE(review): this hard-coded default means globalSettings.BaseKeyCheckModel
	// (used as a fallback in BuildKeyCheckEndpoint) is never consulted — confirm.
	keyCheckModel := "gemini-1.5-flash"
	maxRetries := globalSettings.MaxRetries
	enableSmartGateway := globalSettings.EnableSmartGateway

	opConfig := &models.KeyGroupSettings{
		EnableKeyCheck:          &enableKeyCheck,
		KeyCheckConcurrency:     &keyCheckConcurrency,
		KeyCheckIntervalMinutes: &keyCheckIntervalMinutes,
		KeyCheckEndpoint:        &keyCheckEndpoint,
		KeyBlacklistThreshold:   &keyBlacklistThreshold,
		KeyCooldownMinutes:      &keyCooldownMinutes,
		KeyCheckModel:           &keyCheckModel,
		MaxRetries:              &maxRetries,
		EnableSmartGateway:      &enableSmartGateway,
	}
	if group == nil {
		return opConfig, nil
	}
	var groupSettingsRecord models.GroupSettings
	err := gm.db.Where("group_id = ?", group.ID).First(&groupSettingsRecord).Error
	if err != nil {
		if errors.Is(err, gorm.ErrRecordNotFound) {
			return opConfig, nil
		}
		gm.logger.WithError(err).Errorf("Failed to query group settings for group ID %d", group.ID)
		return nil, err
	}
	if len(groupSettingsRecord.SettingsJSON) == 0 {
		return opConfig, nil
	}
	var groupSpecificSettings models.KeyGroupSettings
	if err := json.Unmarshal(groupSettingsRecord.SettingsJSON, &groupSpecificSettings); err != nil {
		gm.logger.WithError(err).WithField("group_id", group.ID).Warn("Failed to unmarshal group settings JSON.")
		return opConfig, err
	}
	if err := reflectutil.MergeNilFields(opConfig, &groupSpecificSettings); err != nil {
		// Best-effort: fall back to whatever opConfig holds rather than failing.
		gm.logger.WithError(err).WithField("group_id", group.ID).Error("Failed to merge group-specific settings over defaults.")
		return opConfig, nil
	}
	return opConfig, nil
}
// BuildKeyCheckEndpoint builds the URL used to check keys of the group with
// groupID, of the form <base>/v1beta/models/<model>.
//
// The base URL is the group's effective KeyCheckEndpoint, falling back to the
// global DefaultUpstreamURL. The model is the effective KeyCheckModel, falling
// back to the global BaseKeyCheckModel. A single trailing "/" and then a
// trailing "/v1beta" are stripped from the base path so the version segment is
// not duplicated. Errors are returned for an unknown group, a failure to build
// the operational config, a missing base URL, or an unparsable base URL.
func (gm *GroupManager) BuildKeyCheckEndpoint(groupID uint) (string, error) {
	group, found := gm.GetGroupByID(groupID)
	if !found {
		return "", fmt.Errorf("group with id %d not found", groupID)
	}
	opConfig, err := gm.BuildOperationalConfig(group)
	if err != nil {
		return "", fmt.Errorf("failed to build operational config for group %d: %w", groupID, err)
	}
	global := gm.settingsManager.GetSettings()

	endpointBase := global.DefaultUpstreamURL
	if ep := opConfig.KeyCheckEndpoint; ep != nil && *ep != "" {
		endpointBase = *ep
	}
	if endpointBase == "" {
		return "", fmt.Errorf("no key check endpoint or default upstream URL is configured for group %d", groupID)
	}

	checkModel := global.BaseKeyCheckModel
	if m := opConfig.KeyCheckModel; m != nil && *m != "" {
		checkModel = *m
	}

	target, err := url.Parse(endpointBase)
	if err != nil {
		return "", fmt.Errorf("failed to parse base URL '%s': %w", endpointBase, err)
	}
	basePath := strings.TrimSuffix(strings.TrimSuffix(target.Path, "/"), "/v1beta")
	target.Path = path.Join(basePath, "v1beta", "models", checkModel)
	return target.String(), nil
}
// UpdateOrder persists new Order values for the groups listed in payload in a
// single repository transaction, then refreshes the group cache in the
// background. If payload lists the same ID more than once, the last entry wins.
// A cache-invalidation failure is logged rather than returned, because the
// order change itself was committed.
func (gm *GroupManager) UpdateOrder(payload []UpdateOrderPayload) error {
	ordersMap := make(map[uint]int, len(payload))
	for _, item := range payload {
		ordersMap[item.ID] = item.Order
	}
	if err := gm.groupRepo.UpdateOrderInTransaction(ordersMap); err != nil {
		gm.logger.WithError(err).Error("Failed to update group order in transaction")
		return fmt.Errorf("database transaction failed: %w", err)
	}
	gm.logger.Info("Group order updated successfully, invalidating cache...")
	go func() {
		if invErr := gm.Invalidate(); invErr != nil {
			gm.logger.WithError(invErr).Warn("Failed to invalidate group cache after reordering groups.")
		}
	}()
	return nil
}
// uniqueStrings returns the distinct values of slice in first-seen order.
// The result is always a non-nil slice, empty when slice is empty or nil.
func uniqueStrings(slice []string) []string {
	seen := make(map[string]struct{}, len(slice))
	out := make([]string, 0, len(slice))
	for _, s := range slice {
		if _, dup := seen[s]; dup {
			continue
		}
		seen[s] = struct{}{}
		out = append(out, s)
	}
	return out
}