297 lines
9.8 KiB
Go
297 lines
9.8 KiB
Go
// Filename: internal/repository/key_selector.go
|
||
package repository
|
||
|
||
import (
|
||
"crypto/sha1"
|
||
"encoding/json"
|
||
"errors"
|
||
"fmt"
|
||
"gemini-balancer/internal/models"
|
||
"gemini-balancer/internal/store"
|
||
"io"
|
||
"sort"
|
||
"strconv"
|
||
"strings"
|
||
"time"
|
||
|
||
"gorm.io/gorm"
|
||
)
|
||
|
||
const (
	// CacheTTL is the lifetime given to the BasePool cache structures
	// (sequential list, random main/cooldown sets, LRU sorted set) built in this file.
	CacheTTL = 5 * time.Minute

	// EmptyCacheTTL is the lifetime of the EmptyPoolPlaceholder marker, so an empty
	// pool is re-checked sooner than a populated one.
	EmptyCacheTTL = time.Minute

	// EmptyPoolPlaceholder is pushed into a BasePool's sequential list when no active
	// key was found, marking the pool as "built but empty".
	EmptyPoolPlaceholder = "EMPTY_POOL"
)
|
||
|
||
// SelectOneActiveKey 根据指定的轮询策略,从缓存中高效地选取一个可用的API密钥。
|
||
|
||
func (r *gormKeyRepository) SelectOneActiveKey(group *models.KeyGroup) (*models.APIKey, *models.GroupAPIKeyMapping, error) {
|
||
var keyIDStr string
|
||
var err error
|
||
|
||
switch group.PollingStrategy {
|
||
case models.StrategySequential:
|
||
sequentialKey := fmt.Sprintf(KeyGroupSequential, group.ID)
|
||
keyIDStr, err = r.store.Rotate(sequentialKey)
|
||
|
||
case models.StrategyWeighted:
|
||
lruKey := fmt.Sprintf(KeyGroupLRU, group.ID)
|
||
results, zerr := r.store.ZRange(lruKey, 0, 0)
|
||
if zerr == nil && len(results) > 0 {
|
||
keyIDStr = results[0]
|
||
}
|
||
err = zerr
|
||
|
||
case models.StrategyRandom:
|
||
mainPoolKey := fmt.Sprintf(KeyGroupRandomMain, group.ID)
|
||
cooldownPoolKey := fmt.Sprintf(KeyGroupRandomCooldown, group.ID)
|
||
keyIDStr, err = r.store.PopAndCycleSetMember(mainPoolKey, cooldownPoolKey)
|
||
|
||
default: // 默认或未指定策略时,使用基础的随机策略
|
||
activeKeySetKey := fmt.Sprintf(KeyGroup, group.ID)
|
||
keyIDStr, err = r.store.SRandMember(activeKeySetKey)
|
||
}
|
||
|
||
if err != nil {
|
||
if errors.Is(err, store.ErrNotFound) || errors.Is(err, gorm.ErrRecordNotFound) {
|
||
return nil, nil, gorm.ErrRecordNotFound
|
||
}
|
||
r.logger.WithError(err).Errorf("Failed to select key for group %d with strategy %s", group.ID, group.PollingStrategy)
|
||
return nil, nil, err
|
||
}
|
||
|
||
if keyIDStr == "" {
|
||
return nil, nil, gorm.ErrRecordNotFound
|
||
}
|
||
|
||
keyID, _ := strconv.ParseUint(keyIDStr, 10, 64)
|
||
|
||
apiKey, mapping, err := r.getKeyDetailsFromCache(uint(keyID), group.ID)
|
||
if err != nil {
|
||
r.logger.WithError(err).Warnf("Cache inconsistency: Failed to get details for selected key ID %d", keyID)
|
||
// TODO 可以在此加入重试逻辑,再次调用 SelectOneActiveKey(group)
|
||
return nil, nil, err
|
||
}
|
||
|
||
if group.PollingStrategy == models.StrategyWeighted {
|
||
go r.UpdateKeyUsageTimestamp(group.ID, uint(keyID))
|
||
}
|
||
|
||
return apiKey, mapping, nil
|
||
}
|
||
|
||
// SelectOneActiveKeyFromBasePool 为智能聚合模式设计的全新轮询器。
|
||
func (r *gormKeyRepository) SelectOneActiveKeyFromBasePool(pool *BasePool) (*models.APIKey, *models.KeyGroup, error) {
|
||
protocol := "default"
|
||
if pool.Protocol != "" {
|
||
protocol = string(pool.Protocol)
|
||
}
|
||
// 生成唯一的池ID,确保不同请求组合的轮询状态相互隔离
|
||
poolID := generatePoolID(pool.CandidateGroups, protocol)
|
||
log := r.logger.WithField("pool_id", poolID).WithField("protocol", protocol)
|
||
|
||
if err := r.ensureBasePoolCacheExists(pool, poolID); err != nil {
|
||
log.WithError(err).Error("Failed to ensure BasePool cache exists.")
|
||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||
return nil, nil, err
|
||
}
|
||
return nil, nil, fmt.Errorf("unexpected error while ensuring base pool cache: %w", err)
|
||
}
|
||
|
||
var keyIDStr string
|
||
var err error
|
||
|
||
switch pool.PollingStrategy {
|
||
case models.StrategySequential:
|
||
sequentialKey := fmt.Sprintf(BasePoolSequential, poolID)
|
||
keyIDStr, err = r.store.Rotate(sequentialKey)
|
||
case models.StrategyWeighted:
|
||
lruKey := fmt.Sprintf(BasePoolLRU, poolID)
|
||
results, zerr := r.store.ZRange(lruKey, 0, 0)
|
||
if zerr == nil && len(results) > 0 {
|
||
keyIDStr = results[0]
|
||
}
|
||
err = zerr
|
||
case models.StrategyRandom:
|
||
mainPoolKey := fmt.Sprintf(BasePoolRandomMain, poolID)
|
||
cooldownPoolKey := fmt.Sprintf(BasePoolRandomCooldown, poolID)
|
||
keyIDStr, err = r.store.PopAndCycleSetMember(mainPoolKey, cooldownPoolKey)
|
||
default: // 默认策略,应该在 ensureCache 中处理,但作为降级方案
|
||
log.Warnf("Default polling strategy triggered inside selection. This should be rare.")
|
||
|
||
sequentialKey := fmt.Sprintf(BasePoolSequential, poolID)
|
||
keyIDStr, err = r.store.LIndex(sequentialKey, 0)
|
||
}
|
||
|
||
if err != nil {
|
||
if errors.Is(err, store.ErrNotFound) {
|
||
return nil, nil, gorm.ErrRecordNotFound
|
||
}
|
||
log.WithError(err).Errorf("Failed to select key from BasePool with strategy %s", pool.PollingStrategy)
|
||
return nil, nil, err
|
||
}
|
||
if keyIDStr == "" {
|
||
return nil, nil, gorm.ErrRecordNotFound
|
||
}
|
||
keyID, _ := strconv.ParseUint(keyIDStr, 10, 64)
|
||
|
||
for _, group := range pool.CandidateGroups {
|
||
apiKey, mapping, cacheErr := r.getKeyDetailsFromCache(uint(keyID), group.ID)
|
||
if cacheErr == nil && apiKey != nil && mapping != nil {
|
||
|
||
if pool.PollingStrategy == models.StrategyWeighted {
|
||
|
||
go r.updateKeyUsageTimestampForPool(poolID, uint(keyID))
|
||
}
|
||
return apiKey, group, nil
|
||
}
|
||
}
|
||
|
||
log.Errorf("Cache inconsistency: Selected KeyID %d from BasePool but could not find its origin group.", keyID)
|
||
return nil, nil, errors.New("cache inconsistency: selected key has no origin group")
|
||
}
|
||
|
||
// ensureBasePoolCacheExists 动态创建 BasePool 的 Redis 结构
|
||
func (r *gormKeyRepository) ensureBasePoolCacheExists(pool *BasePool, poolID string) error {
|
||
listKey := fmt.Sprintf(BasePoolSequential, poolID)
|
||
exists, err := r.store.Exists(listKey)
|
||
if err != nil {
|
||
r.logger.WithError(err).Errorf("Failed to check existence of basepool key: %s", listKey)
|
||
return err
|
||
}
|
||
if exists {
|
||
val, err := r.store.LIndex(listKey, 0)
|
||
if err != nil {
|
||
return err
|
||
}
|
||
if val == EmptyPoolPlaceholder {
|
||
return gorm.ErrRecordNotFound
|
||
}
|
||
return nil
|
||
}
|
||
lockKey := fmt.Sprintf("lock:basepool:%s", poolID)
|
||
acquired, err := r.store.SetNX(lockKey, []byte("1"), 10*time.Second)
|
||
if err != nil {
|
||
r.logger.WithError(err).Errorf("Failed to acquire distributed lock for basepool build: %s", lockKey)
|
||
return err
|
||
}
|
||
if !acquired {
|
||
time.Sleep(100 * time.Millisecond)
|
||
return r.ensureBasePoolCacheExists(pool, poolID)
|
||
}
|
||
defer r.store.Del(lockKey)
|
||
if exists, _ := r.store.Exists(listKey); exists {
|
||
return nil
|
||
}
|
||
r.logger.Infof("BasePool cache for pool_id '%s' not found. Building now...", poolID)
|
||
var allActiveKeyIDs []string
|
||
lruMembers := make(map[string]float64)
|
||
|
||
for _, group := range pool.CandidateGroups {
|
||
activeKeySetKey := fmt.Sprintf(KeyGroup, group.ID)
|
||
groupKeyIDs, err := r.store.SMembers(activeKeySetKey)
|
||
if err != nil {
|
||
r.logger.WithError(err).Warnf("Failed to get active keys for group %d during BasePool build", group.ID)
|
||
continue
|
||
}
|
||
for _, keyIDStr := range groupKeyIDs {
|
||
keyID, _ := strconv.ParseUint(keyIDStr, 10, 64)
|
||
|
||
_, mapping, err := r.getKeyDetailsFromCache(uint(keyID), group.ID)
|
||
if err != nil {
|
||
if errors.Is(err, store.ErrNotFound) || strings.Contains(err.Error(), "failed to get") {
|
||
r.logger.WithError(err).Warnf("Cache inconsistency detected for KeyID %s in GroupID %d. Skipping.", keyIDStr, group.ID)
|
||
continue
|
||
} else {
|
||
return err
|
||
}
|
||
}
|
||
allActiveKeyIDs = append(allActiveKeyIDs, keyIDStr)
|
||
if mapping != nil && mapping.LastUsedAt != nil {
|
||
lruMembers[keyIDStr] = float64(mapping.LastUsedAt.UnixMilli())
|
||
}
|
||
}
|
||
}
|
||
if len(allActiveKeyIDs) == 0 {
|
||
pipe := r.store.Pipeline()
|
||
pipe.LPush(listKey, EmptyPoolPlaceholder)
|
||
pipe.Expire(listKey, EmptyCacheTTL)
|
||
if err := pipe.Exec(); err != nil {
|
||
r.logger.WithError(err).Errorf("Failed to set empty pool placeholder for pool_id '%s'", poolID)
|
||
}
|
||
return gorm.ErrRecordNotFound
|
||
}
|
||
pipe := r.store.Pipeline()
|
||
pipe.LPush(fmt.Sprintf(BasePoolSequential, poolID), toInterfaceSlice(allActiveKeyIDs)...)
|
||
pipe.SAdd(fmt.Sprintf(BasePoolRandomMain, poolID), toInterfaceSlice(allActiveKeyIDs)...)
|
||
pipe.Expire(fmt.Sprintf(BasePoolSequential, poolID), CacheTTL)
|
||
pipe.Expire(fmt.Sprintf(BasePoolRandomMain, poolID), CacheTTL)
|
||
pipe.Expire(fmt.Sprintf(BasePoolRandomCooldown, poolID), CacheTTL)
|
||
pipe.Expire(fmt.Sprintf(BasePoolLRU, poolID), CacheTTL)
|
||
if err := pipe.Exec(); err != nil {
|
||
return err
|
||
}
|
||
if len(lruMembers) > 0 {
|
||
r.store.ZAdd(fmt.Sprintf(BasePoolLRU, poolID), lruMembers)
|
||
}
|
||
return nil
|
||
}
|
||
|
||
// updateKeyUsageTimestampForPool 更新 BasePool 的 LUR ZSET
|
||
func (r *gormKeyRepository) updateKeyUsageTimestampForPool(poolID string, keyID uint) {
|
||
lruKey := fmt.Sprintf(BasePoolLRU, poolID)
|
||
r.store.ZAdd(lruKey, map[string]float64{
|
||
strconv.FormatUint(uint64(keyID), 10): nowMilli(),
|
||
})
|
||
}
|
||
|
||
// generatePoolID 根据候选组ID列表生成一个稳定的、唯一的字符串ID
|
||
func generatePoolID(groups []*models.KeyGroup, protocol string) string {
|
||
ids := make([]int, len(groups))
|
||
for i, g := range groups {
|
||
ids[i] = int(g.ID)
|
||
}
|
||
sort.Ints(ids)
|
||
|
||
h := sha1.New()
|
||
io.WriteString(h, fmt.Sprintf("protocol:%s;groups:%v", protocol, ids))
|
||
return fmt.Sprintf("%x", h.Sum(nil))
|
||
}
|
||
|
||
// toInterfaceSlice copies slice into a new []interface{} so it can be spread into
// variadic interface{} parameters. A nil or empty input yields a non-nil empty slice.
func toInterfaceSlice(slice []string) []interface{} {
	out := make([]interface{}, 0, len(slice))
	for _, s := range slice {
		out = append(out, s)
	}
	return out
}
|
||
|
||
// nowMilli returns the current Unix time in whole milliseconds as a float64, the score
// type passed to ZAdd for the LRU / weighted strategy.
func nowMilli() float64 {
	ms := time.Now().UnixMilli()
	return float64(ms)
}
|
||
|
||
// getKeyDetailsFromCache 从缓存中获取Key和Mapping的JSON数据。
|
||
func (r *gormKeyRepository) getKeyDetailsFromCache(keyID, groupID uint) (*models.APIKey, *models.GroupAPIKeyMapping, error) {
|
||
apiKeyJSON, err := r.store.Get(fmt.Sprintf(KeyDetails, keyID))
|
||
if err != nil {
|
||
return nil, nil, fmt.Errorf("failed to get key details for key %d: %w", keyID, err)
|
||
}
|
||
var apiKey models.APIKey
|
||
if err := json.Unmarshal(apiKeyJSON, &apiKey); err != nil {
|
||
return nil, nil, fmt.Errorf("failed to unmarshal api key %d: %w", keyID, err)
|
||
}
|
||
|
||
mappingJSON, err := r.store.Get(fmt.Sprintf(KeyMapping, groupID, keyID))
|
||
if err != nil {
|
||
return nil, nil, fmt.Errorf("failed to get mapping details for key %d in group %d: %w", keyID, groupID, err)
|
||
}
|
||
var mapping models.GroupAPIKeyMapping
|
||
if err := json.Unmarshal(mappingJSON, &mapping); err != nil {
|
||
return nil, nil, fmt.Errorf("failed to unmarshal mapping for key %d in group %d: %w", keyID, groupID, err)
|
||
}
|
||
|
||
return &apiKey, &mapping, nil
|
||
}
|