// Filename: internal/repository/key_selector.go package repository import ( "context" "crypto/sha1" "encoding/json" "errors" "fmt" "gemini-balancer/internal/models" "gemini-balancer/internal/store" "io" "sort" "strconv" "time" "gorm.io/gorm" ) const ( CacheTTL = 5 * time.Minute EmptyCacheTTL = 1 * time.Minute ) // SelectOneActiveKey 根据指定的轮询策略,从单个密钥组缓存中选取一个可用的API密钥。 func (r *gormKeyRepository) SelectOneActiveKey(ctx context.Context, group *models.KeyGroup) (*models.APIKey, *models.GroupAPIKeyMapping, error) { if group == nil { return nil, nil, fmt.Errorf("group cannot be nil") } var keyIDStr string var err error switch group.PollingStrategy { case models.StrategySequential: sequentialKey := fmt.Sprintf(KeyGroupSequential, group.ID) keyIDStr, err = r.store.Rotate(ctx, sequentialKey) case models.StrategyWeighted: lruKey := fmt.Sprintf(KeyGroupLRU, group.ID) results, zerr := r.store.ZRange(ctx, lruKey, 0, 0) if zerr == nil { if len(results) > 0 { keyIDStr = results[0] } else { zerr = gorm.ErrRecordNotFound } } err = zerr case models.StrategyRandom: mainPoolKey := fmt.Sprintf(KeyGroupRandomMain, group.ID) cooldownPoolKey := fmt.Sprintf(KeyGroupRandomCooldown, group.ID) keyIDStr, err = r.store.PopAndCycleSetMember(ctx, mainPoolKey, cooldownPoolKey) default: activeKeySetKey := fmt.Sprintf(KeyGroup, group.ID) keyIDStr, err = r.store.SRandMember(ctx, activeKeySetKey) } if err != nil { if errors.Is(err, store.ErrNotFound) || errors.Is(err, gorm.ErrRecordNotFound) { return nil, nil, gorm.ErrRecordNotFound } r.logger.WithError(err).Errorf("Failed to select key for group %d with strategy %s", group.ID, group.PollingStrategy) return nil, nil, err } if keyIDStr == "" { return nil, nil, gorm.ErrRecordNotFound } keyID, parseErr := strconv.ParseUint(keyIDStr, 10, 64) if parseErr != nil { r.logger.WithError(parseErr).Errorf("Invalid key ID format in group %d cache: %s", group.ID, keyIDStr) return nil, nil, fmt.Errorf("invalid key ID in cache: %w", parseErr) } apiKey, mapping, err := r.getKeyDetailsFromCache(ctx, uint(keyID), group.ID) if err != nil { r.logger.WithError(err).Warnf("Cache inconsistency for key ID %d in group %d", keyID, group.ID) return nil, nil, err } if group.PollingStrategy == models.StrategyWeighted { go func() { updateCtx, cancel := r.withTimeout(5 * time.Second) defer cancel() r.UpdateKeyUsageTimestamp(updateCtx, group.ID, uint(keyID)) }() } return apiKey, mapping, nil } // SelectOneActiveKeyFromBasePool 从智能聚合池中选取一个可用Key。 func (r *gormKeyRepository) SelectOneActiveKeyFromBasePool(ctx context.Context, pool *BasePool) (*models.APIKey, *models.KeyGroup, error) { if pool == nil || len(pool.CandidateGroups) == 0 { return nil, nil, fmt.Errorf("invalid or empty base pool configuration") } poolID := r.generatePoolID(pool.CandidateGroups) log := r.logger.WithField("pool_id", poolID) if err := r.ensureBasePoolCacheExists(ctx, pool, poolID); err != nil { if !errors.Is(err, gorm.ErrRecordNotFound) { log.WithError(err).Error("Failed to ensure BasePool cache exists") } return nil, nil, err } var keyIDStr string var err error switch pool.PollingStrategy { case models.StrategySequential: sequentialKey := fmt.Sprintf(BasePoolSequential, poolID) keyIDStr, err = r.store.Rotate(ctx, sequentialKey) case models.StrategyWeighted: lruKey := fmt.Sprintf(BasePoolLRU, poolID) results, zerr := r.store.ZRange(ctx, lruKey, 0, 0) if zerr == nil { if len(results) > 0 { keyIDStr = results[0] } else { zerr = gorm.ErrRecordNotFound } } err = zerr case models.StrategyRandom: mainPoolKey := fmt.Sprintf(BasePoolRandomMain, poolID) cooldownPoolKey := fmt.Sprintf(BasePoolRandomCooldown, poolID) keyIDStr, err = r.store.PopAndCycleSetMember(ctx, mainPoolKey, cooldownPoolKey) default: log.Warnf("Unknown polling strategy '%s'. Using sequential as fallback.", pool.PollingStrategy) sequentialKey := fmt.Sprintf(BasePoolSequential, poolID) keyIDStr, err = r.store.Rotate(ctx, sequentialKey) } if err != nil { if errors.Is(err, store.ErrNotFound) || errors.Is(err, gorm.ErrRecordNotFound) { return nil, nil, gorm.ErrRecordNotFound } log.WithError(err).Errorf("Failed to select key from BasePool with strategy %s", pool.PollingStrategy) return nil, nil, err } if keyIDStr == "" { return nil, nil, gorm.ErrRecordNotFound } go func() { bgCtx, cancel := r.withTimeout(5 * time.Second) defer cancel() r.refreshBasePoolHeartbeat(bgCtx, poolID) }() keyID, parseErr := strconv.ParseUint(keyIDStr, 10, 64) if parseErr != nil { log.WithError(parseErr).Errorf("Invalid key ID format in BasePool cache: %s", keyIDStr) return nil, nil, fmt.Errorf("invalid key ID in cache: %w", parseErr) } keyToGroupMapKey := fmt.Sprintf("basepool:%s:key_to_group", poolID) groupIDStr, err := r.store.HGet(ctx, keyToGroupMapKey, keyIDStr) if err != nil { log.WithError(err).Errorf("Cache inconsistency: KeyID %d found in pool but not in key-to-group map", keyID) return nil, nil, errors.New("cache inconsistency: key has no origin group mapping") } groupID, parseErr := strconv.ParseUint(groupIDStr, 10, 64) if parseErr != nil { log.WithError(parseErr).Errorf("Invalid group ID format in key-to-group map for key %d: %s", keyID, groupIDStr) return nil, nil, errors.New("cache inconsistency: invalid group id in mapping") } apiKey, _, err := r.getKeyDetailsFromCache(ctx, uint(keyID), uint(groupID)) if err != nil { log.WithError(err).Warnf("Cache inconsistency: Failed to get details for key %d in mapped group %d", keyID, groupID) return nil, nil, err } var originGroup *models.KeyGroup for _, g := range pool.CandidateGroups { if g.ID == uint(groupID) { originGroup = g break } } if originGroup == nil { log.Errorf("Logic error: Mapped GroupID %d not found in pool's candidate groups list", groupID) return nil, nil, errors.New("cache inconsistency: mapped group not in candidate list") } if pool.PollingStrategy == models.StrategyWeighted { go func() { bgCtx, cancel := r.withTimeout(5 * time.Second) defer cancel() r.updateKeyUsageTimestampForPool(bgCtx, poolID, uint(keyID)) }() } return apiKey, originGroup, nil } // ensureBasePoolCacheExists 动态创建或验证 BasePool 的 Redis 缓存结构。 func (r *gormKeyRepository) ensureBasePoolCacheExists(ctx context.Context, pool *BasePool, poolID string) error { heartbeatKey := fmt.Sprintf("basepool:%s:heartbeat", poolID) emptyMarkerKey := fmt.Sprintf("basepool:empty:%s", poolID) // 预检查,快速失败 if exists, _ := r.store.Exists(ctx, emptyMarkerKey); exists { return gorm.ErrRecordNotFound } if exists, _ := r.store.Exists(ctx, heartbeatKey); exists { return nil } // 获取分布式锁 lockKey := fmt.Sprintf("lock:basepool:%s", poolID) if err := r.acquireLock(ctx, lockKey); err != nil { return err // acquireLock 内部已记录日志并返回明确错误 } defer r.releaseLock(context.Background(), lockKey) // 双重检查锁定 if exists, _ := r.store.Exists(ctx, emptyMarkerKey); exists { return gorm.ErrRecordNotFound } if exists, _ := r.store.Exists(ctx, heartbeatKey); exists { return nil } // 在执行重度操作前,最后检查一次上下文是否已取消 select { case <-ctx.Done(): return ctx.Err() default: } r.logger.Infof("Building BasePool cache for pool_id '%s'", poolID) // 手动聚合所有 Keys 并同时构建 key-to-group 映射 keyToGroupMap := make(map[string]any) allKeyIDsSet := make(map[string]struct{}) for _, group := range pool.CandidateGroups { groupKeySet := fmt.Sprintf(KeyGroup, group.ID) groupKeyIDs, err := r.store.SMembers(ctx, groupKeySet) if err != nil { r.logger.WithError(err).Warnf("Failed to get members for group %d during pool build", group.ID) continue } groupIDStr := strconv.FormatUint(uint64(group.ID), 10) for _, keyID := range groupKeyIDs { if _, exists := allKeyIDsSet[keyID]; !exists { allKeyIDsSet[keyID] = struct{}{} keyToGroupMap[keyID] = groupIDStr } } } // 处理空池情况 if len(allKeyIDsSet) == 0 { emptyCacheTTL := time.Duration(r.config.Repository.BasePoolTTIMinutes) * time.Minute / 2 if emptyCacheTTL < time.Minute { emptyCacheTTL = time.Minute } r.logger.Warnf("No active keys found for pool_id '%s', setting empty marker.", poolID) if err := r.store.Set(ctx, emptyMarkerKey, []byte("1"), emptyCacheTTL); err != nil { r.logger.WithError(err).Warnf("Failed to set empty marker for pool_id '%s'", poolID) } return gorm.ErrRecordNotFound } allActiveKeyIDs := make([]string, 0, len(allKeyIDsSet)) for keyID := range allKeyIDsSet { allActiveKeyIDs = append(allActiveKeyIDs, keyID) } // 使用 Pipeline 原子化构建所有缓存结构 basePoolTTL := time.Duration(r.config.Repository.BasePoolTTLMinutes) * time.Minute basePoolTTI := time.Duration(r.config.Repository.BasePoolTTIMinutes) * time.Minute mainPoolKey := fmt.Sprintf(BasePoolRandomMain, poolID) sequentialKey := fmt.Sprintf(BasePoolSequential, poolID) cooldownKey := fmt.Sprintf(BasePoolRandomCooldown, poolID) lruKey := fmt.Sprintf(BasePoolLRU, poolID) keyToGroupMapKey := fmt.Sprintf("basepool:%s:key_to_group", poolID) pipe := r.store.Pipeline(ctx) pipe.Del(mainPoolKey, sequentialKey, cooldownKey, lruKey, emptyMarkerKey, keyToGroupMapKey) pipe.SAdd(mainPoolKey, r.toInterfaceSlice(allActiveKeyIDs)...) pipe.LPush(sequentialKey, r.toInterfaceSlice(allActiveKeyIDs)...) if len(keyToGroupMap) > 0 { pipe.HSet(keyToGroupMapKey, keyToGroupMap) pipe.Expire(keyToGroupMapKey, basePoolTTL) } pipe.Expire(mainPoolKey, basePoolTTL) pipe.Expire(sequentialKey, basePoolTTL) pipe.Expire(cooldownKey, basePoolTTL) pipe.Expire(lruKey, basePoolTTL) pipe.Set(heartbeatKey, []byte("1"), basePoolTTI) if err := pipe.Exec(); err != nil { r.logger.WithError(err).Errorf("Failed to populate polling structures for pool_id '%s'", poolID) cleanupCtx, cancel := r.withTimeout(5 * time.Second) defer cancel() r.store.Del(cleanupCtx, mainPoolKey, sequentialKey, cooldownKey, lruKey, heartbeatKey, emptyMarkerKey, keyToGroupMapKey) return err } // 异步填充 LRU 缓存,并传入已构建好的映射 go r.populateBasePoolLRUCache(context.Background(), poolID, allActiveKeyIDs, keyToGroupMap) return nil } // --- 辅助方法 --- // acquireLock 封装了带重试和指数退避的分布式锁获取逻辑。 func (r *gormKeyRepository) acquireLock(ctx context.Context, lockKey string) error { const ( lockTTL = 30 * time.Second lockMaxRetries = 5 lockBaseBackoff = 50 * time.Millisecond ) for i := 0; i < lockMaxRetries; i++ { acquired, err := r.store.SetNX(ctx, lockKey, []byte("1"), lockTTL) if err != nil { r.logger.WithError(err).Error("Failed to attempt acquiring distributed lock") return err } if acquired { return nil } time.Sleep(lockBaseBackoff * (1 << i)) } return fmt.Errorf("failed to acquire lock for key '%s' after %d retries", lockKey, lockMaxRetries) } // releaseLock 封装了分布式锁的释放逻辑。 func (r *gormKeyRepository) releaseLock(ctx context.Context, lockKey string) { if err := r.store.Del(ctx, lockKey); err != nil { r.logger.WithError(err).Errorf("Failed to release distributed lock for key '%s'", lockKey) } } // withTimeout 是 context.WithTimeout 的一个简单包装,便于测试和模拟。 func (r *gormKeyRepository) withTimeout(duration time.Duration) (context.Context, context.CancelFunc) { return context.WithTimeout(context.Background(), duration) } // refreshBasePoolHeartbeat 异步刷新心跳Key的TTI func (r *gormKeyRepository) refreshBasePoolHeartbeat(ctx context.Context, poolID string) { basePoolTTI := time.Duration(r.config.Repository.BasePoolTTIMinutes) * time.Minute heartbeatKey := fmt.Sprintf("basepool:%s:heartbeat", poolID) // 使用 EXPIRE 命令来刷新,如果Key不存在,它什么也不做,是安全的 if err := r.store.Expire(ctx, heartbeatKey, basePoolTTI); err != nil { if ctx.Err() == nil { // 避免在context取消后打印不必要的错误 r.logger.WithError(err).Warnf("Failed to refresh heartbeat for pool_id '%s'", poolID) } } } // populateBasePoolLRUCache 异步填充 BasePool 的 LRU 缓存结构 func (r *gormKeyRepository) populateBasePoolLRUCache( parentCtx context.Context, currentPoolID string, keys []string, keyToGroupMap map[string]any, ) { lruMembers := make(map[string]float64, len(keys)) for _, keyIDStr := range keys { select { case <-parentCtx.Done(): return default: } groupIDStr, ok := keyToGroupMap[keyIDStr].(string) if !ok { continue } keyID, _ := strconv.ParseUint(keyIDStr, 10, 64) groupID, _ := strconv.ParseUint(groupIDStr, 10, 64) mappingKey := fmt.Sprintf(KeyMapping, groupID, keyID) data, err := r.store.Get(parentCtx, mappingKey) if err != nil { continue } var mapping models.GroupAPIKeyMapping if json.Unmarshal(data, &mapping) == nil { var score float64 if mapping.LastUsedAt != nil { score = float64(mapping.LastUsedAt.UnixMilli()) } lruMembers[keyIDStr] = score } } if len(lruMembers) > 0 { lruKey := fmt.Sprintf(BasePoolLRU, currentPoolID) if err := r.store.ZAdd(parentCtx, lruKey, lruMembers); err != nil { if parentCtx.Err() == nil { r.logger.WithError(err).Warnf("Failed to populate LRU cache for pool '%s'", currentPoolID) } } } } // updateKeyUsageTimestampForPool 更新 BasePool 的 LUR ZSET func (r *gormKeyRepository) updateKeyUsageTimestampForPool(ctx context.Context, poolID string, keyID uint) { lruKey := fmt.Sprintf(BasePoolLRU, poolID) err := r.store.ZAdd(ctx, lruKey, map[string]float64{ strconv.FormatUint(uint64(keyID), 10): r.nowMilli(), }) if err != nil { r.logger.WithError(err).Warnf("Failed to update key usage for pool %s", poolID) } } // generatePoolID 根据候选组ID列表生成一个稳定的、唯一的字符串ID func (r *gormKeyRepository) generatePoolID(groups []*models.KeyGroup) string { ids := make([]int, len(groups)) for i, g := range groups { ids[i] = int(g.ID) } sort.Ints(ids) h := sha1.New() io.WriteString(h, fmt.Sprintf("%v", ids)) return fmt.Sprintf("%x", h.Sum(nil)) } // toInterfaceSlice 类型转换辅助函数 func (r *gormKeyRepository) toInterfaceSlice(slice []string) []interface{} { result := make([]interface{}, len(slice)) for i, v := range slice { result[i] = v } return result } // nowMilli 返回当前的Unix毫秒时间戳,用于LRU/Weighted策略 func (r *gormKeyRepository) nowMilli() float64 { return float64(time.Now().UnixMilli()) } // getKeyDetailsFromCache 从缓存中获取Key和Mapping的JSON数据。 func (r *gormKeyRepository) getKeyDetailsFromCache(ctx context.Context, keyID, groupID uint) (*models.APIKey, *models.GroupAPIKeyMapping, error) { apiKeyJSON, err := r.store.Get(ctx, fmt.Sprintf(KeyDetails, keyID)) if err != nil { return nil, nil, fmt.Errorf("failed to get key details for key %d: %w", keyID, err) } var apiKey models.APIKey if err := json.Unmarshal(apiKeyJSON, &apiKey); err != nil { return nil, nil, fmt.Errorf("failed to unmarshal api key %d: %w", keyID, err) } mappingJSON, err := r.store.Get(ctx, fmt.Sprintf(KeyMapping, groupID, keyID)) if err != nil { return nil, nil, fmt.Errorf("failed to get mapping details for key %d in group %d: %w", keyID, groupID, err) } var mapping models.GroupAPIKeyMapping if err := json.Unmarshal(mappingJSON, &mapping); err != nil { return nil, nil, fmt.Errorf("failed to unmarshal mapping for key %d in group %d: %w", keyID, groupID, err) } return &apiKey, &mapping, nil }