Files
gemini-banlancer/internal/repository/key_selector.go
2025-11-21 19:33:05 +08:00

297 lines
9.8 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
// Filename: internal/repository/key_selector.go
package repository
import (
"crypto/sha1"
"encoding/json"
"errors"
"fmt"
"gemini-balancer/internal/models"
"gemini-balancer/internal/store"
"io"
"sort"
"strconv"
"strings"
"time"
"gorm.io/gorm"
)
// Cache lifetimes and markers used when building BasePool cache structures.
const (
	// CacheTTL is the expiry applied to the cache keys of a populated BasePool.
	CacheTTL = 5 * time.Minute
	// EmptyCacheTTL is the expiry of the placeholder list written for a
	// BasePool that turned out to have no active keys.
	EmptyCacheTTL = time.Minute

	// EmptyPoolPlaceholder is the single list element stored to remember that
	// a BasePool was built and found empty.
	EmptyPoolPlaceholder = "EMPTY_POOL"
)
// SelectOneActiveKey picks one API key for the given group out of the cache,
// using the group's polling strategy, and returns the cached key together with
// its group mapping.
//
// Strategy dispatch (each reads a per-group cache structure through r.store):
//   - StrategySequential: store.Rotate on the group's sequential key.
//   - StrategyWeighted: the first element of ZRange(lruKey, 0, 0); after a
//     successful lookup the key's usage timestamp is refreshed in a goroutine.
//   - StrategyRandom: store.PopAndCycleSetMember over the main/cooldown sets.
//   - any other value: store.SRandMember on the group's active-key set.
//
// It returns gorm.ErrRecordNotFound when the store reports not-found or yields
// an empty ID, and a descriptive error when the selected ID is not a valid
// unsigned integer or its details cannot be loaded from the cache.
func (r *gormKeyRepository) SelectOneActiveKey(group *models.KeyGroup) (*models.APIKey, *models.GroupAPIKeyMapping, error) {
	var (
		keyIDStr string
		err      error
	)
	switch group.PollingStrategy {
	case models.StrategySequential:
		sequentialKey := fmt.Sprintf(KeyGroupSequential, group.ID)
		keyIDStr, err = r.store.Rotate(sequentialKey)
	case models.StrategyWeighted:
		lruKey := fmt.Sprintf(KeyGroupLRU, group.ID)
		results, zerr := r.store.ZRange(lruKey, 0, 0)
		if zerr == nil && len(results) > 0 {
			keyIDStr = results[0]
		}
		err = zerr
	case models.StrategyRandom:
		mainPoolKey := fmt.Sprintf(KeyGroupRandomMain, group.ID)
		cooldownPoolKey := fmt.Sprintf(KeyGroupRandomCooldown, group.ID)
		keyIDStr, err = r.store.PopAndCycleSetMember(mainPoolKey, cooldownPoolKey)
	default:
		// Unknown or unspecified strategy: fall back to a plain random pick.
		activeKeySetKey := fmt.Sprintf(KeyGroup, group.ID)
		keyIDStr, err = r.store.SRandMember(activeKeySetKey)
	}
	if err != nil {
		if errors.Is(err, store.ErrNotFound) || errors.Is(err, gorm.ErrRecordNotFound) {
			return nil, nil, gorm.ErrRecordNotFound
		}
		r.logger.WithError(err).Errorf("Failed to select key for group %d with strategy %s", group.ID, group.PollingStrategy)
		return nil, nil, err
	}
	if keyIDStr == "" {
		return nil, nil, gorm.ErrRecordNotFound
	}

	// A malformed ID must not silently become key 0; surface it as a cache
	// inconsistency instead. Parsing at uint width also avoids truncation in
	// the conversion below.
	parsedID, err := strconv.ParseUint(keyIDStr, 10, strconv.IntSize)
	if err != nil {
		r.logger.WithError(err).Warnf("Cache inconsistency: selected key ID %q for group %d is not a valid ID", keyIDStr, group.ID)
		return nil, nil, fmt.Errorf("cache inconsistency: invalid key ID %q selected for group %d: %w", keyIDStr, group.ID, err)
	}
	keyID := uint(parsedID)

	apiKey, mapping, err := r.getKeyDetailsFromCache(keyID, group.ID)
	if err != nil {
		r.logger.WithError(err).Warnf("Cache inconsistency: Failed to get details for selected key ID %d", keyID)
		// TODO: retry logic could be added here by calling SelectOneActiveKey(group) again.
		return nil, nil, err
	}
	if group.PollingStrategy == models.StrategyWeighted {
		// Fire-and-forget: the timestamp refresh must not delay the caller.
		go r.UpdateKeyUsageTimestamp(group.ID, keyID)
	}
	return apiKey, mapping, nil
}
// SelectOneActiveKeyFromBasePool selects one API key from an aggregated
// BasePool (a set of candidate groups plus a protocol) and returns the key
// together with the candidate group it was found in.
//
// The pool's cache structures are keyed by an ID derived from the candidate
// group IDs and the protocol, so different combinations keep independent
// polling state. The structures are built on demand before selection.
//
// It returns gorm.ErrRecordNotFound when the pool is empty or the store
// reports not-found, and a descriptive error when the selected ID is not a
// valid unsigned integer or cannot be matched to any candidate group.
func (r *gormKeyRepository) SelectOneActiveKeyFromBasePool(pool *BasePool) (*models.APIKey, *models.KeyGroup, error) {
	protocol := "default"
	if pool.Protocol != "" {
		protocol = string(pool.Protocol)
	}
	// A unique pool ID keeps the polling state of different request
	// combinations isolated from each other.
	poolID := generatePoolID(pool.CandidateGroups, protocol)
	log := r.logger.WithField("pool_id", poolID).WithField("protocol", protocol)

	if err := r.ensureBasePoolCacheExists(pool, poolID); err != nil {
		log.WithError(err).Error("Failed to ensure BasePool cache exists.")
		if errors.Is(err, gorm.ErrRecordNotFound) {
			return nil, nil, err
		}
		return nil, nil, fmt.Errorf("unexpected error while ensuring base pool cache: %w", err)
	}

	var (
		keyIDStr string
		err      error
	)
	switch pool.PollingStrategy {
	case models.StrategySequential:
		sequentialKey := fmt.Sprintf(BasePoolSequential, poolID)
		keyIDStr, err = r.store.Rotate(sequentialKey)
	case models.StrategyWeighted:
		lruKey := fmt.Sprintf(BasePoolLRU, poolID)
		results, zerr := r.store.ZRange(lruKey, 0, 0)
		if zerr == nil && len(results) > 0 {
			keyIDStr = results[0]
		}
		err = zerr
	case models.StrategyRandom:
		mainPoolKey := fmt.Sprintf(BasePoolRandomMain, poolID)
		cooldownPoolKey := fmt.Sprintf(BasePoolRandomCooldown, poolID)
		keyIDStr, err = r.store.PopAndCycleSetMember(mainPoolKey, cooldownPoolKey)
	default:
		// Fallback only: read the head of the sequential list without
		// rotating it.
		log.Warnf("Default polling strategy triggered inside selection. This should be rare.")
		sequentialKey := fmt.Sprintf(BasePoolSequential, poolID)
		keyIDStr, err = r.store.LIndex(sequentialKey, 0)
	}
	if err != nil {
		// Same not-found mapping as SelectOneActiveKey.
		if errors.Is(err, store.ErrNotFound) || errors.Is(err, gorm.ErrRecordNotFound) {
			return nil, nil, gorm.ErrRecordNotFound
		}
		log.WithError(err).Errorf("Failed to select key from BasePool with strategy %s", pool.PollingStrategy)
		return nil, nil, err
	}
	if keyIDStr == "" {
		return nil, nil, gorm.ErrRecordNotFound
	}

	// A non-numeric entry must not be looked up as key 0; report it as a
	// cache inconsistency. Parsing at uint width also avoids truncation in
	// the conversion below.
	parsedID, err := strconv.ParseUint(keyIDStr, 10, strconv.IntSize)
	if err != nil {
		log.WithError(err).Errorf("Cache inconsistency: selected entry %q from BasePool is not a valid key ID", keyIDStr)
		return nil, nil, fmt.Errorf("cache inconsistency: invalid key ID %q selected from base pool: %w", keyIDStr, err)
	}
	keyID := uint(parsedID)

	// The pool stores bare key IDs, so probe each candidate group until one
	// has both the key details and a mapping for this key.
	for _, group := range pool.CandidateGroups {
		apiKey, mapping, cacheErr := r.getKeyDetailsFromCache(keyID, group.ID)
		if cacheErr != nil || apiKey == nil || mapping == nil {
			continue
		}
		if pool.PollingStrategy == models.StrategyWeighted {
			// Fire-and-forget: the timestamp refresh must not delay the caller.
			go r.updateKeyUsageTimestampForPool(poolID, keyID)
		}
		return apiKey, group, nil
	}
	log.Errorf("Cache inconsistency: Selected KeyID %d from BasePool but could not find its origin group.", keyID)
	return nil, nil, errors.New("cache inconsistency: selected key has no origin group")
}
// ensureBasePoolCacheExists makes sure the Redis structures backing the given
// BasePool exist, building them from the candidate groups' active-key sets
// when the pool's sequential list is missing.
//
// The sequential list doubles as the "cache is built" marker. A list whose
// head is EmptyPoolPlaceholder records that the pool was built and found
// empty; in that case gorm.ErrRecordNotFound is returned.
//
// Concurrent builders are serialized with store.SetNX on a per-pool lock key
// (10-second expiry); a caller that does not get the lock sleeps 100ms and
// checks again.
//
// NOTE(review): the lock is deleted unconditionally on return; if a build
// outlives the 10-second expiry this may remove a lock taken by another
// builder — confirm whether the store offers an owner-checked release.
func (r *gormKeyRepository) ensureBasePoolCacheExists(pool *BasePool, poolID string) error {
	listKey := fmt.Sprintf(BasePoolSequential, poolID)
	lockKey := fmt.Sprintf("lock:basepool:%s", poolID)

	// cacheState reports whether the pool's list already exists. It returns
	// gorm.ErrRecordNotFound when the list holds the empty-pool placeholder.
	cacheState := func() (bool, error) {
		exists, err := r.store.Exists(listKey)
		if err != nil {
			r.logger.WithError(err).Errorf("Failed to check existence of basepool key: %s", listKey)
			return false, err
		}
		if !exists {
			return false, nil
		}
		head, err := r.store.LIndex(listKey, 0)
		if err != nil {
			return false, err
		}
		if head == EmptyPoolPlaceholder {
			return false, gorm.ErrRecordNotFound
		}
		return true, nil
	}

	// Wait until either the cache shows up or this caller owns the build
	// lock. A loop is used instead of recursion so that waiting does not
	// grow the stack.
	for {
		ready, err := cacheState()
		if err != nil {
			return err
		}
		if ready {
			return nil
		}
		acquired, err := r.store.SetNX(lockKey, []byte("1"), 10*time.Second)
		if err != nil {
			r.logger.WithError(err).Errorf("Failed to acquire distributed lock for basepool build: %s", lockKey)
			return err
		}
		if acquired {
			break
		}
		time.Sleep(100 * time.Millisecond)
	}
	defer r.store.Del(lockKey)

	// Re-check under the lock: another builder may have finished (or stored
	// the empty-pool placeholder) between the check above and the lock.
	ready, err := cacheState()
	if err != nil {
		return err
	}
	if ready {
		return nil
	}

	r.logger.Infof("BasePool cache for pool_id '%s' not found. Building now...", poolID)
	var allActiveKeyIDs []string
	lruMembers := make(map[string]float64)
	for _, group := range pool.CandidateGroups {
		activeKeySetKey := fmt.Sprintf(KeyGroup, group.ID)
		groupKeyIDs, err := r.store.SMembers(activeKeySetKey)
		if err != nil {
			r.logger.WithError(err).Warnf("Failed to get active keys for group %d during BasePool build", group.ID)
			continue
		}
		for _, keyIDStr := range groupKeyIDs {
			keyID, err := strconv.ParseUint(keyIDStr, 10, strconv.IntSize)
			if err != nil {
				// A malformed member must not be looked up as key 0.
				r.logger.WithError(err).Warnf("Cache inconsistency detected for KeyID %s in GroupID %d. Skipping.", keyIDStr, group.ID)
				continue
			}
			_, mapping, err := r.getKeyDetailsFromCache(uint(keyID), group.ID)
			if err != nil {
				// Missing cache entries are skipped; getKeyDetailsFromCache
				// prefixes those errors with "failed to get". Anything else
				// (e.g. a decode failure) aborts the build.
				if errors.Is(err, store.ErrNotFound) || strings.Contains(err.Error(), "failed to get") {
					r.logger.WithError(err).Warnf("Cache inconsistency detected for KeyID %s in GroupID %d. Skipping.", keyIDStr, group.ID)
					continue
				}
				return err
			}
			allActiveKeyIDs = append(allActiveKeyIDs, keyIDStr)

			// Every active key gets an LRU entry. A key that was never used
			// keeps score 0 so it sorts first; leaving it out would make it
			// unselectable through ZRange(lruKey, 0, 0).
			var lastUsed float64
			if mapping != nil && mapping.LastUsedAt != nil {
				lastUsed = float64(mapping.LastUsedAt.UnixMilli())
			}
			lruMembers[keyIDStr] = lastUsed
		}
	}

	if len(allActiveKeyIDs) == 0 {
		// Remember the empty result briefly so repeated requests do not
		// rebuild the pool every time.
		pipe := r.store.Pipeline()
		pipe.LPush(listKey, EmptyPoolPlaceholder)
		pipe.Expire(listKey, EmptyCacheTTL)
		if err := pipe.Exec(); err != nil {
			r.logger.WithError(err).Errorf("Failed to set empty pool placeholder for pool_id '%s'", poolID)
		}
		return gorm.ErrRecordNotFound
	}

	// Populate the LRU set before the pipeline runs: the Expire queued below
	// only takes effect on a key that already exists, and the list (the
	// readiness marker) must not become visible before the LRU set is filled.
	lruKey := fmt.Sprintf(BasePoolLRU, poolID)
	if len(lruMembers) > 0 {
		r.store.ZAdd(lruKey, lruMembers)
	}

	pipe := r.store.Pipeline()
	pipe.LPush(listKey, toInterfaceSlice(allActiveKeyIDs)...)
	pipe.SAdd(fmt.Sprintf(BasePoolRandomMain, poolID), toInterfaceSlice(allActiveKeyIDs)...)
	pipe.Expire(listKey, CacheTTL)
	pipe.Expire(fmt.Sprintf(BasePoolRandomMain, poolID), CacheTTL)
	// NOTE(review): the cooldown set is not created here, so this Expire is
	// presumably a no-op until the set exists — confirm where its TTL is set.
	pipe.Expire(fmt.Sprintf(BasePoolRandomCooldown, poolID), CacheTTL)
	pipe.Expire(lruKey, CacheTTL)
	if err := pipe.Exec(); err != nil {
		return err
	}
	return nil
}
// updateKeyUsageTimestampForPool sets the score of keyID in the BasePool's
// LRU sorted set to the current Unix time in milliseconds.
func (r *gormKeyRepository) updateKeyUsageTimestampForPool(poolID string, keyID uint) {
	zsetKey := fmt.Sprintf(BasePoolLRU, poolID)
	member := strconv.FormatUint(uint64(keyID), 10)
	scores := map[string]float64{member: nowMilli()}
	r.store.ZAdd(zsetKey, scores)
}
// generatePoolID derives a stable identifier from the protocol and the
// candidate groups' IDs. The IDs are sorted first, so the order of groups does
// not affect the result. The identifier is the hex-encoded SHA-1 of the string
// "protocol:<protocol>;groups:<sorted ids>".
func generatePoolID(groups []*models.KeyGroup, protocol string) string {
	sortedIDs := make([]int, 0, len(groups))
	for _, grp := range groups {
		sortedIDs = append(sortedIDs, int(grp.ID))
	}
	sort.Ints(sortedIDs)

	payload := fmt.Sprintf("protocol:%s;groups:%v", protocol, sortedIDs)
	hasher := sha1.New()
	io.WriteString(hasher, payload)
	digest := hasher.Sum(nil)
	return fmt.Sprintf("%x", digest)
}
// toInterfaceSlice copies a []string into a new []interface{} of the same
// length, as needed by variadic interface{} parameters. The result is always
// non-nil, even for a nil or empty input.
func toInterfaceSlice(slice []string) []interface{} {
	converted := make([]interface{}, 0, len(slice))
	for _, item := range slice {
		converted = append(converted, item)
	}
	return converted
}
// nowMilli returns the current Unix time in milliseconds as a float64, the
// form used for scores in the LRU/weighted strategy.
func nowMilli() float64 {
	millis := time.Now().UnixMilli()
	return float64(millis)
}
// getKeyDetailsFromCache loads the JSON-encoded API key and its group mapping
// from the store and decodes both.
//
// Any store or decode failure is wrapped with %w and returned with nil
// results. The "failed to get ..." wording of the store errors is matched by
// ensureBasePoolCacheExists, so it must stay as is.
func (r *gormKeyRepository) getKeyDetailsFromCache(keyID, groupID uint) (*models.APIKey, *models.GroupAPIKeyMapping, error) {
	rawKey, getKeyErr := r.store.Get(fmt.Sprintf(KeyDetails, keyID))
	if getKeyErr != nil {
		return nil, nil, fmt.Errorf("failed to get key details for key %d: %w", keyID, getKeyErr)
	}
	apiKey := new(models.APIKey)
	if decodeErr := json.Unmarshal(rawKey, apiKey); decodeErr != nil {
		return nil, nil, fmt.Errorf("failed to unmarshal api key %d: %w", keyID, decodeErr)
	}

	rawMapping, getMappingErr := r.store.Get(fmt.Sprintf(KeyMapping, groupID, keyID))
	if getMappingErr != nil {
		return nil, nil, fmt.Errorf("failed to get mapping details for key %d in group %d: %w", keyID, groupID, getMappingErr)
	}
	mapping := new(models.GroupAPIKeyMapping)
	if decodeErr := json.Unmarshal(rawMapping, mapping); decodeErr != nil {
		return nil, nil, fmt.Errorf("failed to unmarshal mapping for key %d in group %d: %w", keyID, groupID, decodeErr)
	}
	return apiKey, mapping, nil
}