// Filename: internal/repository/key_selector.go (经审查后最终修复版) package repository import ( "context" "crypto/sha1" "encoding/json" "errors" "fmt" "gemini-balancer/internal/models" "gemini-balancer/internal/store" "io" "sort" "strconv" "time" "gorm.io/gorm" ) const ( CacheTTL = 5 * time.Minute EmptyPoolPlaceholder = "EMPTY_POOL" EmptyCacheTTL = 1 * time.Minute ) // SelectOneActiveKey 根据指定的轮询策略,从缓存中高效地选取一个可用的API密钥。 func (r *gormKeyRepository) SelectOneActiveKey(ctx context.Context, group *models.KeyGroup) (*models.APIKey, *models.GroupAPIKeyMapping, error) { var keyIDStr string var err error switch group.PollingStrategy { case models.StrategySequential: sequentialKey := fmt.Sprintf(KeyGroupSequential, group.ID) keyIDStr, err = r.store.Rotate(ctx, sequentialKey) case models.StrategyWeighted: lruKey := fmt.Sprintf(KeyGroupLRU, group.ID) results, zerr := r.store.ZRange(ctx, lruKey, 0, 0) if zerr == nil && len(results) > 0 { keyIDStr = results[0] } err = zerr case models.StrategyRandom: mainPoolKey := fmt.Sprintf(KeyGroupRandomMain, group.ID) cooldownPoolKey := fmt.Sprintf(KeyGroupRandomCooldown, group.ID) keyIDStr, err = r.store.PopAndCycleSetMember(ctx, mainPoolKey, cooldownPoolKey) default: // 默认或未指定策略时,使用基础的随机策略 activeKeySetKey := fmt.Sprintf(KeyGroup, group.ID) keyIDStr, err = r.store.SRandMember(ctx, activeKeySetKey) } if err != nil { if errors.Is(err, store.ErrNotFound) || errors.Is(err, gorm.ErrRecordNotFound) { return nil, nil, gorm.ErrRecordNotFound } r.logger.WithError(err).Errorf("Failed to select key for group %d with strategy %s", group.ID, group.PollingStrategy) return nil, nil, err } if keyIDStr == "" { return nil, nil, gorm.ErrRecordNotFound } keyID, _ := strconv.ParseUint(keyIDStr, 10, 64) apiKey, mapping, err := r.getKeyDetailsFromCache(ctx, uint(keyID), group.ID) if err != nil { r.logger.WithError(err).Warnf("Cache inconsistency: Failed to get details for selected key ID %d", keyID) return nil, nil, err } if group.PollingStrategy == models.StrategyWeighted { go r.UpdateKeyUsageTimestamp(context.Background(), group.ID, uint(keyID)) } return apiKey, mapping, nil } // SelectOneActiveKeyFromBasePool 为智能聚合模式设计的全新轮询器。 func (r *gormKeyRepository) SelectOneActiveKeyFromBasePool(ctx context.Context, pool *BasePool) (*models.APIKey, *models.KeyGroup, error) { poolID := generatePoolID(pool.CandidateGroups) log := r.logger.WithField("pool_id", poolID) if err := r.ensureBasePoolCacheExists(ctx, pool, poolID); err != nil { log.WithError(err).Error("Failed to ensure BasePool cache exists.") return nil, nil, err } var keyIDStr string var err error switch pool.PollingStrategy { case models.StrategySequential: sequentialKey := fmt.Sprintf(BasePoolSequential, poolID) keyIDStr, err = r.store.Rotate(ctx, sequentialKey) case models.StrategyWeighted: lruKey := fmt.Sprintf(BasePoolLRU, poolID) results, zerr := r.store.ZRange(ctx, lruKey, 0, 0) if zerr == nil && len(results) > 0 { keyIDStr = results[0] } err = zerr case models.StrategyRandom: mainPoolKey := fmt.Sprintf(BasePoolRandomMain, poolID) cooldownPoolKey := fmt.Sprintf(BasePoolRandomCooldown, poolID) keyIDStr, err = r.store.PopAndCycleSetMember(ctx, mainPoolKey, cooldownPoolKey) default: log.Warnf("Default polling strategy triggered inside selection. This should be rare.") sequentialKey := fmt.Sprintf(BasePoolSequential, poolID) keyIDStr, err = r.store.LIndex(ctx, sequentialKey, 0) } if err != nil { if errors.Is(err, store.ErrNotFound) { return nil, nil, gorm.ErrRecordNotFound } log.WithError(err).Errorf("Failed to select key from BasePool with strategy %s", pool.PollingStrategy) return nil, nil, err } if keyIDStr == "" { return nil, nil, gorm.ErrRecordNotFound } keyID, _ := strconv.ParseUint(keyIDStr, 10, 64) for _, group := range pool.CandidateGroups { apiKey, mapping, cacheErr := r.getKeyDetailsFromCache(ctx, uint(keyID), group.ID) if cacheErr == nil && apiKey != nil && mapping != nil { if pool.PollingStrategy == models.StrategyWeighted { go r.updateKeyUsageTimestampForPool(context.Background(), poolID, uint(keyID)) } return apiKey, group, nil } } log.Errorf("Cache inconsistency: Selected KeyID %d from BasePool but could not find its origin group.", keyID) return nil, nil, errors.New("cache inconsistency: selected key has no origin group") } // ensureBasePoolCacheExists 动态创建 BasePool 的 Redis 结构 func (r *gormKeyRepository) ensureBasePoolCacheExists(ctx context.Context, pool *BasePool, poolID string) error { listKey := fmt.Sprintf(BasePoolSequential, poolID) exists, err := r.store.Exists(ctx, listKey) if err != nil { r.logger.WithError(err).Errorf("Failed to check existence for pool_id '%s'", poolID) return err } if exists { val, err := r.store.LIndex(ctx, listKey, 0) if err != nil { r.logger.WithError(err).Warnf("Cache for pool_id '%s' exists but is unreadable. Forcing rebuild.", poolID) } else { if val == EmptyPoolPlaceholder { return gorm.ErrRecordNotFound } return nil } } lockKey := fmt.Sprintf("lock:basepool:%s", poolID) acquired, err := r.store.SetNX(ctx, lockKey, []byte("1"), 10*time.Second) if err != nil { r.logger.WithError(err).Error("Failed to attempt acquiring distributed lock for basepool build.") return err } if !acquired { time.Sleep(100 * time.Millisecond) return r.ensureBasePoolCacheExists(ctx, pool, poolID) } defer r.store.Del(context.Background(), lockKey) if exists, _ := r.store.Exists(ctx, listKey); exists { return nil } r.logger.Infof("BasePool cache for pool_id '%s' not found or is unreadable. Building now...", poolID) var allActiveKeyIDs []string lruMembers := make(map[string]float64) for _, group := range pool.CandidateGroups { activeKeySetKey := fmt.Sprintf(KeyGroup, group.ID) groupKeyIDs, err := r.store.SMembers(ctx, activeKeySetKey) if err != nil { r.logger.WithError(err).Errorf("FATAL: Failed to read active keys for group %d during BasePool build. Aborting build process for pool_id '%s'.", group.ID, poolID) return err } allActiveKeyIDs = append(allActiveKeyIDs, groupKeyIDs...) for _, keyIDStr := range groupKeyIDs { keyID, _ := strconv.ParseUint(keyIDStr, 10, 64) _, mapping, err := r.getKeyDetailsFromCache(ctx, uint(keyID), group.ID) if err == nil && mapping != nil { var score float64 if mapping.LastUsedAt != nil { score = float64(mapping.LastUsedAt.UnixMilli()) } lruMembers[keyIDStr] = score } } } if len(allActiveKeyIDs) == 0 { r.logger.Warnf("No active keys found for any candidate groups for pool_id '%s'. Setting empty pool placeholder.", poolID) pipe := r.store.Pipeline(ctx) pipe.LPush(listKey, EmptyPoolPlaceholder) pipe.Expire(listKey, EmptyCacheTTL) if err := pipe.Exec(); err != nil { r.logger.WithError(err).Errorf("Failed to set empty pool placeholder for pool_id '%s'", poolID) } return gorm.ErrRecordNotFound } pipe := r.store.Pipeline(ctx) pipe.LPush(fmt.Sprintf(BasePoolSequential, poolID), toInterfaceSlice(allActiveKeyIDs)...) pipe.SAdd(fmt.Sprintf(BasePoolRandomMain, poolID), toInterfaceSlice(allActiveKeyIDs)...) pipe.Expire(fmt.Sprintf(BasePoolSequential, poolID), CacheTTL) pipe.Expire(fmt.Sprintf(BasePoolRandomMain, poolID), CacheTTL) pipe.Expire(fmt.Sprintf(BasePoolRandomCooldown, poolID), CacheTTL) pipe.Expire(fmt.Sprintf(BasePoolLRU, poolID), CacheTTL) if err := pipe.Exec(); err != nil { return err } if len(lruMembers) > 0 { if err := r.store.ZAdd(ctx, fmt.Sprintf(BasePoolLRU, poolID), lruMembers); err != nil { r.logger.WithError(err).Warnf("Failed to populate LRU cache for pool_id '%s'", poolID) } } return nil } // updateKeyUsageTimestampForPool 更新 BasePool 的 LUR ZSET func (r *gormKeyRepository) updateKeyUsageTimestampForPool(ctx context.Context, poolID string, keyID uint) { lruKey := fmt.Sprintf(BasePoolLRU, poolID) err := r.store.ZAdd(ctx, lruKey, map[string]float64{ strconv.FormatUint(uint64(keyID), 10): nowMilli(), }) if err != nil { r.logger.WithError(err).Warnf("Failed to update key usage for pool %s", poolID) } } // generatePoolID 根据候选组ID列表生成一个稳定的、唯一的字符串ID func generatePoolID(groups []*models.KeyGroup) string { ids := make([]int, len(groups)) for i, g := range groups { ids[i] = int(g.ID) } sort.Ints(ids) h := sha1.New() io.WriteString(h, fmt.Sprintf("%v", ids)) return fmt.Sprintf("%x", h.Sum(nil)) } // toInterfaceSlice 类型转换辅助函数 func toInterfaceSlice(slice []string) []interface{} { result := make([]interface{}, len(slice)) for i, v := range slice { result[i] = v } return result } // nowMilli 返回当前的Unix毫秒时间戳,用于LRU/Weighted策略 func nowMilli() float64 { return float64(time.Now().UnixMilli()) } // getKeyDetailsFromCache 从缓存中获取Key和Mapping的JSON数据。 func (r *gormKeyRepository) getKeyDetailsFromCache(ctx context.Context, keyID, groupID uint) (*models.APIKey, *models.GroupAPIKeyMapping, error) { apiKeyJSON, err := r.store.Get(ctx, fmt.Sprintf(KeyDetails, keyID)) if err != nil { return nil, nil, fmt.Errorf("failed to get key details for key %d: %w", keyID, err) } var apiKey models.APIKey if err := json.Unmarshal(apiKeyJSON, &apiKey); err != nil { return nil, nil, fmt.Errorf("failed to unmarshal api key %d: %w", keyID, err) } mappingJSON, err := r.store.Get(ctx, fmt.Sprintf(KeyMapping, groupID, keyID)) if err != nil { return nil, nil, fmt.Errorf("failed to get mapping details for key %d in group %d: %w", keyID, groupID, err) } var mapping models.GroupAPIKeyMapping if err := json.Unmarshal(mappingJSON, &mapping); err != nil { return nil, nil, fmt.Errorf("failed to unmarshal mapping for key %d in group %d: %w", keyID, groupID, err) } return &apiKey, &mapping, nil }