428 lines
15 KiB
Go
428 lines
15 KiB
Go
// Filename: internal/repository/key_selector.go
|
||
package repository
|
||
|
||
import (
|
||
"context"
|
||
"crypto/sha1"
|
||
"encoding/json"
|
||
"errors"
|
||
"fmt"
|
||
"gemini-balancer/internal/models"
|
||
"gemini-balancer/internal/store"
|
||
"io"
|
||
"sort"
|
||
"strconv"
|
||
"time"
|
||
|
||
"gorm.io/gorm"
|
||
)
|
||
|
||
// CacheTTL is a five-minute cache lifetime.
// NOTE(review): presumably the default TTL for populated cache entries; no use
// of it is visible in this file, so confirm against the callers.
const CacheTTL = 5 * time.Minute

// EmptyCacheTTL is a one-minute cache lifetime.
// NOTE(review): presumably the TTL for "empty result" markers; no use of it is
// visible in this file, so confirm against the callers.
const EmptyCacheTTL = time.Minute
|
||
|
||
// SelectOneActiveKey 根据指定的轮询策略,从单个密钥组缓存中选取一个可用的API密钥。
|
||
func (r *gormKeyRepository) SelectOneActiveKey(ctx context.Context, group *models.KeyGroup) (*models.APIKey, *models.GroupAPIKeyMapping, error) {
|
||
if group == nil {
|
||
return nil, nil, fmt.Errorf("group cannot be nil")
|
||
}
|
||
var keyIDStr string
|
||
var err error
|
||
switch group.PollingStrategy {
|
||
case models.StrategySequential:
|
||
sequentialKey := fmt.Sprintf(KeyGroupSequential, group.ID)
|
||
keyIDStr, err = r.store.Rotate(ctx, sequentialKey)
|
||
case models.StrategyWeighted:
|
||
lruKey := fmt.Sprintf(KeyGroupLRU, group.ID)
|
||
results, zerr := r.store.ZRange(ctx, lruKey, 0, 0)
|
||
if zerr == nil {
|
||
if len(results) > 0 {
|
||
keyIDStr = results[0]
|
||
} else {
|
||
zerr = gorm.ErrRecordNotFound
|
||
}
|
||
}
|
||
err = zerr
|
||
case models.StrategyRandom:
|
||
mainPoolKey := fmt.Sprintf(KeyGroupRandomMain, group.ID)
|
||
cooldownPoolKey := fmt.Sprintf(KeyGroupRandomCooldown, group.ID)
|
||
keyIDStr, err = r.store.PopAndCycleSetMember(ctx, mainPoolKey, cooldownPoolKey)
|
||
default:
|
||
activeKeySetKey := fmt.Sprintf(KeyGroup, group.ID)
|
||
keyIDStr, err = r.store.SRandMember(ctx, activeKeySetKey)
|
||
}
|
||
if err != nil {
|
||
if errors.Is(err, store.ErrNotFound) || errors.Is(err, gorm.ErrRecordNotFound) {
|
||
return nil, nil, gorm.ErrRecordNotFound
|
||
}
|
||
r.logger.WithError(err).Errorf("Failed to select key for group %d with strategy %s", group.ID, group.PollingStrategy)
|
||
return nil, nil, err
|
||
}
|
||
if keyIDStr == "" {
|
||
return nil, nil, gorm.ErrRecordNotFound
|
||
}
|
||
keyID, parseErr := strconv.ParseUint(keyIDStr, 10, 64)
|
||
if parseErr != nil {
|
||
r.logger.WithError(parseErr).Errorf("Invalid key ID format in group %d cache: %s", group.ID, keyIDStr)
|
||
return nil, nil, fmt.Errorf("invalid key ID in cache: %w", parseErr)
|
||
}
|
||
apiKey, mapping, err := r.getKeyDetailsFromCache(ctx, uint(keyID), group.ID)
|
||
if err != nil {
|
||
r.logger.WithError(err).Warnf("Cache inconsistency for key ID %d in group %d", keyID, group.ID)
|
||
return nil, nil, err
|
||
}
|
||
if group.PollingStrategy == models.StrategyWeighted {
|
||
go func() {
|
||
updateCtx, cancel := r.withTimeout(5 * time.Second)
|
||
defer cancel()
|
||
r.UpdateKeyUsageTimestamp(updateCtx, group.ID, uint(keyID))
|
||
}()
|
||
}
|
||
return apiKey, mapping, nil
|
||
}
|
||
|
||
// SelectOneActiveKeyFromBasePool 从智能聚合池中选取一个可用Key。
|
||
func (r *gormKeyRepository) SelectOneActiveKeyFromBasePool(ctx context.Context, pool *BasePool) (*models.APIKey, *models.KeyGroup, error) {
|
||
if pool == nil || len(pool.CandidateGroups) == 0 {
|
||
return nil, nil, fmt.Errorf("invalid or empty base pool configuration")
|
||
}
|
||
poolID := r.generatePoolID(pool.CandidateGroups)
|
||
log := r.logger.WithField("pool_id", poolID)
|
||
if err := r.ensureBasePoolCacheExists(ctx, pool, poolID); err != nil {
|
||
if !errors.Is(err, gorm.ErrRecordNotFound) {
|
||
log.WithError(err).Error("Failed to ensure BasePool cache exists")
|
||
}
|
||
return nil, nil, err
|
||
}
|
||
var keyIDStr string
|
||
var err error
|
||
switch pool.PollingStrategy {
|
||
case models.StrategySequential:
|
||
sequentialKey := fmt.Sprintf(BasePoolSequential, poolID)
|
||
keyIDStr, err = r.store.Rotate(ctx, sequentialKey)
|
||
case models.StrategyWeighted:
|
||
lruKey := fmt.Sprintf(BasePoolLRU, poolID)
|
||
results, zerr := r.store.ZRange(ctx, lruKey, 0, 0)
|
||
if zerr == nil {
|
||
if len(results) > 0 {
|
||
keyIDStr = results[0]
|
||
} else {
|
||
zerr = gorm.ErrRecordNotFound
|
||
}
|
||
}
|
||
err = zerr
|
||
case models.StrategyRandom:
|
||
mainPoolKey := fmt.Sprintf(BasePoolRandomMain, poolID)
|
||
cooldownPoolKey := fmt.Sprintf(BasePoolRandomCooldown, poolID)
|
||
keyIDStr, err = r.store.PopAndCycleSetMember(ctx, mainPoolKey, cooldownPoolKey)
|
||
default:
|
||
log.Warnf("Unknown polling strategy '%s'. Using sequential as fallback.", pool.PollingStrategy)
|
||
sequentialKey := fmt.Sprintf(BasePoolSequential, poolID)
|
||
keyIDStr, err = r.store.Rotate(ctx, sequentialKey)
|
||
}
|
||
if err != nil {
|
||
if errors.Is(err, store.ErrNotFound) || errors.Is(err, gorm.ErrRecordNotFound) {
|
||
return nil, nil, gorm.ErrRecordNotFound
|
||
}
|
||
log.WithError(err).Errorf("Failed to select key from BasePool with strategy %s", pool.PollingStrategy)
|
||
return nil, nil, err
|
||
}
|
||
if keyIDStr == "" {
|
||
return nil, nil, gorm.ErrRecordNotFound
|
||
}
|
||
go func() {
|
||
bgCtx, cancel := r.withTimeout(5 * time.Second)
|
||
defer cancel()
|
||
r.refreshBasePoolHeartbeat(bgCtx, poolID)
|
||
}()
|
||
keyID, parseErr := strconv.ParseUint(keyIDStr, 10, 64)
|
||
if parseErr != nil {
|
||
log.WithError(parseErr).Errorf("Invalid key ID format in BasePool cache: %s", keyIDStr)
|
||
return nil, nil, fmt.Errorf("invalid key ID in cache: %w", parseErr)
|
||
}
|
||
keyToGroupMapKey := fmt.Sprintf("basepool:%s:key_to_group", poolID)
|
||
groupIDStr, err := r.store.HGet(ctx, keyToGroupMapKey, keyIDStr)
|
||
if err != nil {
|
||
log.WithError(err).Errorf("Cache inconsistency: KeyID %d found in pool but not in key-to-group map", keyID)
|
||
return nil, nil, errors.New("cache inconsistency: key has no origin group mapping")
|
||
}
|
||
groupID, parseErr := strconv.ParseUint(groupIDStr, 10, 64)
|
||
if parseErr != nil {
|
||
log.WithError(parseErr).Errorf("Invalid group ID format in key-to-group map for key %d: %s", keyID, groupIDStr)
|
||
return nil, nil, errors.New("cache inconsistency: invalid group id in mapping")
|
||
}
|
||
apiKey, _, err := r.getKeyDetailsFromCache(ctx, uint(keyID), uint(groupID))
|
||
if err != nil {
|
||
log.WithError(err).Warnf("Cache inconsistency: Failed to get details for key %d in mapped group %d", keyID, groupID)
|
||
return nil, nil, err
|
||
}
|
||
var originGroup *models.KeyGroup
|
||
for _, g := range pool.CandidateGroups {
|
||
if g.ID == uint(groupID) {
|
||
originGroup = g
|
||
break
|
||
}
|
||
}
|
||
if originGroup == nil {
|
||
log.Errorf("Logic error: Mapped GroupID %d not found in pool's candidate groups list", groupID)
|
||
return nil, nil, errors.New("cache inconsistency: mapped group not in candidate list")
|
||
}
|
||
if pool.PollingStrategy == models.StrategyWeighted {
|
||
go func() {
|
||
bgCtx, cancel := r.withTimeout(5 * time.Second)
|
||
defer cancel()
|
||
r.updateKeyUsageTimestampForPool(bgCtx, poolID, uint(keyID))
|
||
}()
|
||
}
|
||
return apiKey, originGroup, nil
|
||
}
|
||
|
||
// ensureBasePoolCacheExists 动态创建或验证 BasePool 的 Redis 缓存结构。
|
||
func (r *gormKeyRepository) ensureBasePoolCacheExists(ctx context.Context, pool *BasePool, poolID string) error {
|
||
heartbeatKey := fmt.Sprintf("basepool:%s:heartbeat", poolID)
|
||
emptyMarkerKey := fmt.Sprintf("basepool:empty:%s", poolID)
|
||
// 预检查,快速失败
|
||
if exists, _ := r.store.Exists(ctx, emptyMarkerKey); exists {
|
||
return gorm.ErrRecordNotFound
|
||
}
|
||
if exists, _ := r.store.Exists(ctx, heartbeatKey); exists {
|
||
return nil
|
||
}
|
||
// 获取分布式锁
|
||
lockKey := fmt.Sprintf("lock:basepool:%s", poolID)
|
||
if err := r.acquireLock(ctx, lockKey); err != nil {
|
||
return err // acquireLock 内部已记录日志并返回明确错误
|
||
}
|
||
defer r.releaseLock(context.Background(), lockKey)
|
||
// 双重检查锁定
|
||
if exists, _ := r.store.Exists(ctx, emptyMarkerKey); exists {
|
||
return gorm.ErrRecordNotFound
|
||
}
|
||
if exists, _ := r.store.Exists(ctx, heartbeatKey); exists {
|
||
return nil
|
||
}
|
||
// 在执行重度操作前,最后检查一次上下文是否已取消
|
||
select {
|
||
case <-ctx.Done():
|
||
return ctx.Err()
|
||
default:
|
||
}
|
||
r.logger.Infof("Building BasePool cache for pool_id '%s'", poolID)
|
||
// 手动聚合所有 Keys 并同时构建 key-to-group 映射
|
||
keyToGroupMap := make(map[string]any)
|
||
allKeyIDsSet := make(map[string]struct{})
|
||
for _, group := range pool.CandidateGroups {
|
||
groupKeySet := fmt.Sprintf(KeyGroup, group.ID)
|
||
groupKeyIDs, err := r.store.SMembers(ctx, groupKeySet)
|
||
if err != nil {
|
||
r.logger.WithError(err).Warnf("Failed to get members for group %d during pool build", group.ID)
|
||
continue
|
||
}
|
||
groupIDStr := strconv.FormatUint(uint64(group.ID), 10)
|
||
for _, keyID := range groupKeyIDs {
|
||
if _, exists := allKeyIDsSet[keyID]; !exists {
|
||
allKeyIDsSet[keyID] = struct{}{}
|
||
keyToGroupMap[keyID] = groupIDStr
|
||
}
|
||
}
|
||
}
|
||
// 处理空池情况
|
||
if len(allKeyIDsSet) == 0 {
|
||
emptyCacheTTL := time.Duration(r.config.Repository.BasePoolTTIMinutes) * time.Minute / 2
|
||
if emptyCacheTTL < time.Minute {
|
||
emptyCacheTTL = time.Minute
|
||
}
|
||
r.logger.Warnf("No active keys found for pool_id '%s', setting empty marker.", poolID)
|
||
if err := r.store.Set(ctx, emptyMarkerKey, []byte("1"), emptyCacheTTL); err != nil {
|
||
r.logger.WithError(err).Warnf("Failed to set empty marker for pool_id '%s'", poolID)
|
||
}
|
||
return gorm.ErrRecordNotFound
|
||
}
|
||
allActiveKeyIDs := make([]string, 0, len(allKeyIDsSet))
|
||
for keyID := range allKeyIDsSet {
|
||
allActiveKeyIDs = append(allActiveKeyIDs, keyID)
|
||
}
|
||
// 使用 Pipeline 原子化构建所有缓存结构
|
||
basePoolTTL := time.Duration(r.config.Repository.BasePoolTTLMinutes) * time.Minute
|
||
basePoolTTI := time.Duration(r.config.Repository.BasePoolTTIMinutes) * time.Minute
|
||
mainPoolKey := fmt.Sprintf(BasePoolRandomMain, poolID)
|
||
sequentialKey := fmt.Sprintf(BasePoolSequential, poolID)
|
||
cooldownKey := fmt.Sprintf(BasePoolRandomCooldown, poolID)
|
||
lruKey := fmt.Sprintf(BasePoolLRU, poolID)
|
||
keyToGroupMapKey := fmt.Sprintf("basepool:%s:key_to_group", poolID)
|
||
pipe := r.store.Pipeline(ctx)
|
||
pipe.Del(mainPoolKey, sequentialKey, cooldownKey, lruKey, emptyMarkerKey, keyToGroupMapKey)
|
||
pipe.SAdd(mainPoolKey, r.toInterfaceSlice(allActiveKeyIDs)...)
|
||
pipe.LPush(sequentialKey, r.toInterfaceSlice(allActiveKeyIDs)...)
|
||
if len(keyToGroupMap) > 0 {
|
||
pipe.HSet(keyToGroupMapKey, keyToGroupMap)
|
||
pipe.Expire(keyToGroupMapKey, basePoolTTL)
|
||
}
|
||
pipe.Expire(mainPoolKey, basePoolTTL)
|
||
pipe.Expire(sequentialKey, basePoolTTL)
|
||
pipe.Expire(cooldownKey, basePoolTTL)
|
||
pipe.Expire(lruKey, basePoolTTL)
|
||
pipe.Set(heartbeatKey, []byte("1"), basePoolTTI)
|
||
if err := pipe.Exec(); err != nil {
|
||
r.logger.WithError(err).Errorf("Failed to populate polling structures for pool_id '%s'", poolID)
|
||
cleanupCtx, cancel := r.withTimeout(5 * time.Second)
|
||
defer cancel()
|
||
r.store.Del(cleanupCtx, mainPoolKey, sequentialKey, cooldownKey, lruKey, heartbeatKey, emptyMarkerKey, keyToGroupMapKey)
|
||
return err
|
||
}
|
||
// 异步填充 LRU 缓存,并传入已构建好的映射
|
||
go r.populateBasePoolLRUCache(context.Background(), poolID, allActiveKeyIDs, keyToGroupMap)
|
||
return nil
|
||
}
|
||
|
||
// --- Helper methods ---
|
||
// acquireLock 封装了带重试和指数退避的分布式锁获取逻辑。
|
||
func (r *gormKeyRepository) acquireLock(ctx context.Context, lockKey string) error {
|
||
const (
|
||
lockTTL = 30 * time.Second
|
||
lockMaxRetries = 5
|
||
lockBaseBackoff = 50 * time.Millisecond
|
||
)
|
||
for i := 0; i < lockMaxRetries; i++ {
|
||
acquired, err := r.store.SetNX(ctx, lockKey, []byte("1"), lockTTL)
|
||
if err != nil {
|
||
r.logger.WithError(err).Error("Failed to attempt acquiring distributed lock")
|
||
return err
|
||
}
|
||
if acquired {
|
||
return nil
|
||
}
|
||
time.Sleep(lockBaseBackoff * (1 << i))
|
||
}
|
||
return fmt.Errorf("failed to acquire lock for key '%s' after %d retries", lockKey, lockMaxRetries)
|
||
}
|
||
|
||
// releaseLock 封装了分布式锁的释放逻辑。
|
||
func (r *gormKeyRepository) releaseLock(ctx context.Context, lockKey string) {
|
||
if err := r.store.Del(ctx, lockKey); err != nil {
|
||
r.logger.WithError(err).Errorf("Failed to release distributed lock for key '%s'", lockKey)
|
||
}
|
||
}
|
||
|
||
// withTimeout 是 context.WithTimeout 的一个简单包装,便于测试和模拟。
|
||
func (r *gormKeyRepository) withTimeout(duration time.Duration) (context.Context, context.CancelFunc) {
|
||
return context.WithTimeout(context.Background(), duration)
|
||
}
|
||
|
||
// refreshBasePoolHeartbeat 异步刷新心跳Key的TTI
|
||
func (r *gormKeyRepository) refreshBasePoolHeartbeat(ctx context.Context, poolID string) {
|
||
basePoolTTI := time.Duration(r.config.Repository.BasePoolTTIMinutes) * time.Minute
|
||
heartbeatKey := fmt.Sprintf("basepool:%s:heartbeat", poolID)
|
||
// 使用 EXPIRE 命令来刷新,如果Key不存在,它什么也不做,是安全的
|
||
if err := r.store.Expire(ctx, heartbeatKey, basePoolTTI); err != nil {
|
||
if ctx.Err() == nil { // 避免在context取消后打印不必要的错误
|
||
r.logger.WithError(err).Warnf("Failed to refresh heartbeat for pool_id '%s'", poolID)
|
||
}
|
||
}
|
||
}
|
||
|
||
// populateBasePoolLRUCache 异步填充 BasePool 的 LRU 缓存结构
|
||
func (r *gormKeyRepository) populateBasePoolLRUCache(
|
||
parentCtx context.Context,
|
||
currentPoolID string,
|
||
keys []string,
|
||
keyToGroupMap map[string]any,
|
||
) {
|
||
lruMembers := make(map[string]float64, len(keys))
|
||
for _, keyIDStr := range keys {
|
||
select {
|
||
case <-parentCtx.Done():
|
||
return
|
||
default:
|
||
}
|
||
groupIDStr, ok := keyToGroupMap[keyIDStr].(string)
|
||
if !ok {
|
||
continue
|
||
}
|
||
keyID, _ := strconv.ParseUint(keyIDStr, 10, 64)
|
||
groupID, _ := strconv.ParseUint(groupIDStr, 10, 64)
|
||
mappingKey := fmt.Sprintf(KeyMapping, groupID, keyID)
|
||
data, err := r.store.Get(parentCtx, mappingKey)
|
||
if err != nil {
|
||
continue
|
||
}
|
||
var mapping models.GroupAPIKeyMapping
|
||
if json.Unmarshal(data, &mapping) == nil {
|
||
var score float64
|
||
if mapping.LastUsedAt != nil {
|
||
score = float64(mapping.LastUsedAt.UnixMilli())
|
||
}
|
||
lruMembers[keyIDStr] = score
|
||
}
|
||
}
|
||
if len(lruMembers) > 0 {
|
||
lruKey := fmt.Sprintf(BasePoolLRU, currentPoolID)
|
||
if err := r.store.ZAdd(parentCtx, lruKey, lruMembers); err != nil {
|
||
if parentCtx.Err() == nil {
|
||
r.logger.WithError(err).Warnf("Failed to populate LRU cache for pool '%s'", currentPoolID)
|
||
}
|
||
}
|
||
}
|
||
}
|
||
|
||
// updateKeyUsageTimestampForPool 更新 BasePool 的 LUR ZSET
|
||
func (r *gormKeyRepository) updateKeyUsageTimestampForPool(ctx context.Context, poolID string, keyID uint) {
|
||
lruKey := fmt.Sprintf(BasePoolLRU, poolID)
|
||
err := r.store.ZAdd(ctx, lruKey, map[string]float64{
|
||
strconv.FormatUint(uint64(keyID), 10): r.nowMilli(),
|
||
})
|
||
if err != nil {
|
||
r.logger.WithError(err).Warnf("Failed to update key usage for pool %s", poolID)
|
||
}
|
||
}
|
||
|
||
// generatePoolID 根据候选组ID列表生成一个稳定的、唯一的字符串ID
|
||
func (r *gormKeyRepository) generatePoolID(groups []*models.KeyGroup) string {
|
||
ids := make([]int, len(groups))
|
||
for i, g := range groups {
|
||
ids[i] = int(g.ID)
|
||
}
|
||
sort.Ints(ids)
|
||
h := sha1.New()
|
||
io.WriteString(h, fmt.Sprintf("%v", ids))
|
||
return fmt.Sprintf("%x", h.Sum(nil))
|
||
}
|
||
|
||
// toInterfaceSlice 类型转换辅助函数
|
||
func (r *gormKeyRepository) toInterfaceSlice(slice []string) []interface{} {
|
||
result := make([]interface{}, len(slice))
|
||
for i, v := range slice {
|
||
result[i] = v
|
||
}
|
||
return result
|
||
}
|
||
|
||
// nowMilli 返回当前的Unix毫秒时间戳,用于LRU/Weighted策略
|
||
func (r *gormKeyRepository) nowMilli() float64 {
|
||
return float64(time.Now().UnixMilli())
|
||
}
|
||
|
||
// getKeyDetailsFromCache 从缓存中获取Key和Mapping的JSON数据。
|
||
func (r *gormKeyRepository) getKeyDetailsFromCache(ctx context.Context, keyID, groupID uint) (*models.APIKey, *models.GroupAPIKeyMapping, error) {
|
||
apiKeyJSON, err := r.store.Get(ctx, fmt.Sprintf(KeyDetails, keyID))
|
||
if err != nil {
|
||
return nil, nil, fmt.Errorf("failed to get key details for key %d: %w", keyID, err)
|
||
}
|
||
var apiKey models.APIKey
|
||
if err := json.Unmarshal(apiKeyJSON, &apiKey); err != nil {
|
||
return nil, nil, fmt.Errorf("failed to unmarshal api key %d: %w", keyID, err)
|
||
}
|
||
|
||
mappingJSON, err := r.store.Get(ctx, fmt.Sprintf(KeyMapping, groupID, keyID))
|
||
if err != nil {
|
||
return nil, nil, fmt.Errorf("failed to get mapping details for key %d in group %d: %w", keyID, groupID, err)
|
||
}
|
||
var mapping models.GroupAPIKeyMapping
|
||
if err := json.Unmarshal(mappingJSON, &mapping); err != nil {
|
||
return nil, nil, fmt.Errorf("failed to unmarshal mapping for key %d in group %d: %w", keyID, groupID, err)
|
||
}
|
||
|
||
return &apiKey, &mapping, nil
|
||
}
|