// Filename: internal/repository/key_selector.go package repository import ( "crypto/sha1" "encoding/json" "errors" "fmt" "gemini-balancer/internal/models" "gemini-balancer/internal/store" "io" "sort" "strconv" "strings" "time" "gorm.io/gorm" ) const ( CacheTTL = 5 * time.Minute EmptyPoolPlaceholder = "EMPTY_POOL" EmptyCacheTTL = 1 * time.Minute ) // SelectOneActiveKey 根据指定的轮询策略,从缓存中高效地选取一个可用的API密钥。 func (r *gormKeyRepository) SelectOneActiveKey(group *models.KeyGroup) (*models.APIKey, *models.GroupAPIKeyMapping, error) { var keyIDStr string var err error switch group.PollingStrategy { case models.StrategySequential: sequentialKey := fmt.Sprintf(KeyGroupSequential, group.ID) keyIDStr, err = r.store.Rotate(sequentialKey) case models.StrategyWeighted: lruKey := fmt.Sprintf(KeyGroupLRU, group.ID) results, zerr := r.store.ZRange(lruKey, 0, 0) if zerr == nil && len(results) > 0 { keyIDStr = results[0] } err = zerr case models.StrategyRandom: mainPoolKey := fmt.Sprintf(KeyGroupRandomMain, group.ID) cooldownPoolKey := fmt.Sprintf(KeyGroupRandomCooldown, group.ID) keyIDStr, err = r.store.PopAndCycleSetMember(mainPoolKey, cooldownPoolKey) default: // 默认或未指定策略时,使用基础的随机策略 activeKeySetKey := fmt.Sprintf(KeyGroup, group.ID) keyIDStr, err = r.store.SRandMember(activeKeySetKey) } if err != nil { if errors.Is(err, store.ErrNotFound) || errors.Is(err, gorm.ErrRecordNotFound) { return nil, nil, gorm.ErrRecordNotFound } r.logger.WithError(err).Errorf("Failed to select key for group %d with strategy %s", group.ID, group.PollingStrategy) return nil, nil, err } if keyIDStr == "" { return nil, nil, gorm.ErrRecordNotFound } keyID, _ := strconv.ParseUint(keyIDStr, 10, 64) apiKey, mapping, err := r.getKeyDetailsFromCache(uint(keyID), group.ID) if err != nil { r.logger.WithError(err).Warnf("Cache inconsistency: Failed to get details for selected key ID %d", keyID) // TODO 可以在此加入重试逻辑,再次调用 SelectOneActiveKey(group) return nil, nil, err } if group.PollingStrategy == models.StrategyWeighted { go r.UpdateKeyUsageTimestamp(group.ID, uint(keyID)) } return apiKey, mapping, nil } // SelectOneActiveKeyFromBasePool 为智能聚合模式设计的全新轮询器。 func (r *gormKeyRepository) SelectOneActiveKeyFromBasePool(pool *BasePool) (*models.APIKey, *models.KeyGroup, error) { protocol := "default" if pool.Protocol != "" { protocol = string(pool.Protocol) } // 生成唯一的池ID,确保不同请求组合的轮询状态相互隔离 poolID := generatePoolID(pool.CandidateGroups, protocol) log := r.logger.WithField("pool_id", poolID).WithField("protocol", protocol) if err := r.ensureBasePoolCacheExists(pool, poolID); err != nil { log.WithError(err).Error("Failed to ensure BasePool cache exists.") if errors.Is(err, gorm.ErrRecordNotFound) { return nil, nil, err } return nil, nil, fmt.Errorf("unexpected error while ensuring base pool cache: %w", err) } var keyIDStr string var err error switch pool.PollingStrategy { case models.StrategySequential: sequentialKey := fmt.Sprintf(BasePoolSequential, poolID) keyIDStr, err = r.store.Rotate(sequentialKey) case models.StrategyWeighted: lruKey := fmt.Sprintf(BasePoolLRU, poolID) results, zerr := r.store.ZRange(lruKey, 0, 0) if zerr == nil && len(results) > 0 { keyIDStr = results[0] } err = zerr case models.StrategyRandom: mainPoolKey := fmt.Sprintf(BasePoolRandomMain, poolID) cooldownPoolKey := fmt.Sprintf(BasePoolRandomCooldown, poolID) keyIDStr, err = r.store.PopAndCycleSetMember(mainPoolKey, cooldownPoolKey) default: // 默认策略,应该在 ensureCache 中处理,但作为降级方案 log.Warnf("Default polling strategy triggered inside selection. This should be rare.") sequentialKey := fmt.Sprintf(BasePoolSequential, poolID) keyIDStr, err = r.store.LIndex(sequentialKey, 0) } if err != nil { if errors.Is(err, store.ErrNotFound) { return nil, nil, gorm.ErrRecordNotFound } log.WithError(err).Errorf("Failed to select key from BasePool with strategy %s", pool.PollingStrategy) return nil, nil, err } if keyIDStr == "" { return nil, nil, gorm.ErrRecordNotFound } keyID, _ := strconv.ParseUint(keyIDStr, 10, 64) for _, group := range pool.CandidateGroups { apiKey, mapping, cacheErr := r.getKeyDetailsFromCache(uint(keyID), group.ID) if cacheErr == nil && apiKey != nil && mapping != nil { if pool.PollingStrategy == models.StrategyWeighted { go r.updateKeyUsageTimestampForPool(poolID, uint(keyID)) } return apiKey, group, nil } } log.Errorf("Cache inconsistency: Selected KeyID %d from BasePool but could not find its origin group.", keyID) return nil, nil, errors.New("cache inconsistency: selected key has no origin group") } // ensureBasePoolCacheExists 动态创建 BasePool 的 Redis 结构 func (r *gormKeyRepository) ensureBasePoolCacheExists(pool *BasePool, poolID string) error { listKey := fmt.Sprintf(BasePoolSequential, poolID) exists, err := r.store.Exists(listKey) if err != nil { r.logger.WithError(err).Errorf("Failed to check existence of basepool key: %s", listKey) return err } if exists { val, err := r.store.LIndex(listKey, 0) if err != nil { return err } if val == EmptyPoolPlaceholder { return gorm.ErrRecordNotFound } return nil } lockKey := fmt.Sprintf("lock:basepool:%s", poolID) acquired, err := r.store.SetNX(lockKey, []byte("1"), 10*time.Second) if err != nil { r.logger.WithError(err).Errorf("Failed to acquire distributed lock for basepool build: %s", lockKey) return err } if !acquired { time.Sleep(100 * time.Millisecond) return r.ensureBasePoolCacheExists(pool, poolID) } defer r.store.Del(lockKey) if exists, _ := r.store.Exists(listKey); exists { return nil } r.logger.Infof("BasePool cache for pool_id '%s' not found. Building now...", poolID) var allActiveKeyIDs []string lruMembers := make(map[string]float64) for _, group := range pool.CandidateGroups { activeKeySetKey := fmt.Sprintf(KeyGroup, group.ID) groupKeyIDs, err := r.store.SMembers(activeKeySetKey) if err != nil { r.logger.WithError(err).Warnf("Failed to get active keys for group %d during BasePool build", group.ID) continue } for _, keyIDStr := range groupKeyIDs { keyID, _ := strconv.ParseUint(keyIDStr, 10, 64) _, mapping, err := r.getKeyDetailsFromCache(uint(keyID), group.ID) if err != nil { if errors.Is(err, store.ErrNotFound) || strings.Contains(err.Error(), "failed to get") { r.logger.WithError(err).Warnf("Cache inconsistency detected for KeyID %s in GroupID %d. Skipping.", keyIDStr, group.ID) continue } else { return err } } allActiveKeyIDs = append(allActiveKeyIDs, keyIDStr) if mapping != nil && mapping.LastUsedAt != nil { lruMembers[keyIDStr] = float64(mapping.LastUsedAt.UnixMilli()) } } } if len(allActiveKeyIDs) == 0 { pipe := r.store.Pipeline() pipe.LPush(listKey, EmptyPoolPlaceholder) pipe.Expire(listKey, EmptyCacheTTL) if err := pipe.Exec(); err != nil { r.logger.WithError(err).Errorf("Failed to set empty pool placeholder for pool_id '%s'", poolID) } return gorm.ErrRecordNotFound } pipe := r.store.Pipeline() pipe.LPush(fmt.Sprintf(BasePoolSequential, poolID), toInterfaceSlice(allActiveKeyIDs)...) pipe.SAdd(fmt.Sprintf(BasePoolRandomMain, poolID), toInterfaceSlice(allActiveKeyIDs)...) pipe.Expire(fmt.Sprintf(BasePoolSequential, poolID), CacheTTL) pipe.Expire(fmt.Sprintf(BasePoolRandomMain, poolID), CacheTTL) pipe.Expire(fmt.Sprintf(BasePoolRandomCooldown, poolID), CacheTTL) pipe.Expire(fmt.Sprintf(BasePoolLRU, poolID), CacheTTL) if err := pipe.Exec(); err != nil { return err } if len(lruMembers) > 0 { r.store.ZAdd(fmt.Sprintf(BasePoolLRU, poolID), lruMembers) } return nil } // updateKeyUsageTimestampForPool 更新 BasePool 的 LUR ZSET func (r *gormKeyRepository) updateKeyUsageTimestampForPool(poolID string, keyID uint) { lruKey := fmt.Sprintf(BasePoolLRU, poolID) r.store.ZAdd(lruKey, map[string]float64{ strconv.FormatUint(uint64(keyID), 10): nowMilli(), }) } // generatePoolID 根据候选组ID列表生成一个稳定的、唯一的字符串ID func generatePoolID(groups []*models.KeyGroup, protocol string) string { ids := make([]int, len(groups)) for i, g := range groups { ids[i] = int(g.ID) } sort.Ints(ids) h := sha1.New() io.WriteString(h, fmt.Sprintf("protocol:%s;groups:%v", protocol, ids)) return fmt.Sprintf("%x", h.Sum(nil)) } // toInterfaceSlice 类型转换辅助函数 func toInterfaceSlice(slice []string) []interface{} { result := make([]interface{}, len(slice)) for i, v := range slice { result[i] = v } return result } // nowMilli 返回当前的Unix毫秒时间戳,用于LRU/Weighted策略 func nowMilli() float64 { return float64(time.Now().UnixMilli()) } // getKeyDetailsFromCache 从缓存中获取Key和Mapping的JSON数据。 func (r *gormKeyRepository) getKeyDetailsFromCache(keyID, groupID uint) (*models.APIKey, *models.GroupAPIKeyMapping, error) { apiKeyJSON, err := r.store.Get(fmt.Sprintf(KeyDetails, keyID)) if err != nil { return nil, nil, fmt.Errorf("failed to get key details for key %d: %w", keyID, err) } var apiKey models.APIKey if err := json.Unmarshal(apiKeyJSON, &apiKey); err != nil { return nil, nil, fmt.Errorf("failed to unmarshal api key %d: %w", keyID, err) } mappingJSON, err := r.store.Get(fmt.Sprintf(KeyMapping, groupID, keyID)) if err != nil { return nil, nil, fmt.Errorf("failed to get mapping details for key %d in group %d: %w", keyID, groupID, err) } var mapping models.GroupAPIKeyMapping if err := json.Unmarshal(mappingJSON, &mapping); err != nil { return nil, nil, fmt.Errorf("failed to unmarshal mapping for key %d in group %d: %w", keyID, groupID, err) } return &apiKey, &mapping, nil }