Update Context for store

This commit is contained in:
XOF
2025-11-22 14:20:05 +08:00
parent ac0e0a8275
commit 2b0b9b67dc
31 changed files with 817 additions and 1016 deletions

View File

@@ -1,7 +1,8 @@
// Filename: internal/repository/key_selector.go
// Filename: internal/repository/key_selector.go (经审查后最终修复版)
package repository
import (
"context"
"crypto/sha1"
"encoding/json"
"errors"
@@ -23,19 +24,18 @@ const (
)
// SelectOneActiveKey 根据指定的轮询策略从缓存中高效地选取一个可用的API密钥。
func (r *gormKeyRepository) SelectOneActiveKey(group *models.KeyGroup) (*models.APIKey, *models.GroupAPIKeyMapping, error) {
func (r *gormKeyRepository) SelectOneActiveKey(ctx context.Context, group *models.KeyGroup) (*models.APIKey, *models.GroupAPIKeyMapping, error) {
var keyIDStr string
var err error
switch group.PollingStrategy {
case models.StrategySequential:
sequentialKey := fmt.Sprintf(KeyGroupSequential, group.ID)
keyIDStr, err = r.store.Rotate(sequentialKey)
keyIDStr, err = r.store.Rotate(ctx, sequentialKey)
case models.StrategyWeighted:
lruKey := fmt.Sprintf(KeyGroupLRU, group.ID)
results, zerr := r.store.ZRange(lruKey, 0, 0)
results, zerr := r.store.ZRange(ctx, lruKey, 0, 0)
if zerr == nil && len(results) > 0 {
keyIDStr = results[0]
}
@@ -44,11 +44,11 @@ func (r *gormKeyRepository) SelectOneActiveKey(group *models.KeyGroup) (*models.
case models.StrategyRandom:
mainPoolKey := fmt.Sprintf(KeyGroupRandomMain, group.ID)
cooldownPoolKey := fmt.Sprintf(KeyGroupRandomCooldown, group.ID)
keyIDStr, err = r.store.PopAndCycleSetMember(mainPoolKey, cooldownPoolKey)
keyIDStr, err = r.store.PopAndCycleSetMember(ctx, mainPoolKey, cooldownPoolKey)
default: // 默认或未指定策略时,使用基础的随机策略
activeKeySetKey := fmt.Sprintf(KeyGroup, group.ID)
keyIDStr, err = r.store.SRandMember(activeKeySetKey)
keyIDStr, err = r.store.SRandMember(ctx, activeKeySetKey)
}
if err != nil {
@@ -65,27 +65,25 @@ func (r *gormKeyRepository) SelectOneActiveKey(group *models.KeyGroup) (*models.
keyID, _ := strconv.ParseUint(keyIDStr, 10, 64)
apiKey, mapping, err := r.getKeyDetailsFromCache(uint(keyID), group.ID)
apiKey, mapping, err := r.getKeyDetailsFromCache(ctx, uint(keyID), group.ID)
if err != nil {
r.logger.WithError(err).Warnf("Cache inconsistency: Failed to get details for selected key ID %d", keyID)
// TODO 可以在此加入重试逻辑,再次调用 SelectOneActiveKey(ctx, group)
return nil, nil, err
}
if group.PollingStrategy == models.StrategyWeighted {
go r.UpdateKeyUsageTimestamp(group.ID, uint(keyID))
go r.UpdateKeyUsageTimestamp(context.Background(), group.ID, uint(keyID))
}
return apiKey, mapping, nil
}
// SelectOneActiveKeyFromBasePool 为智能聚合模式设计的全新轮询器。
func (r *gormKeyRepository) SelectOneActiveKeyFromBasePool(pool *BasePool) (*models.APIKey, *models.KeyGroup, error) {
// 生成唯一的池ID确保不同请求组合的轮询状态相互隔离
func (r *gormKeyRepository) SelectOneActiveKeyFromBasePool(ctx context.Context, pool *BasePool) (*models.APIKey, *models.KeyGroup, error) {
poolID := generatePoolID(pool.CandidateGroups)
log := r.logger.WithField("pool_id", poolID)
if err := r.ensureBasePoolCacheExists(pool, poolID); err != nil {
if err := r.ensureBasePoolCacheExists(ctx, pool, poolID); err != nil {
log.WithError(err).Error("Failed to ensure BasePool cache exists.")
return nil, nil, err
}
@@ -96,10 +94,10 @@ func (r *gormKeyRepository) SelectOneActiveKeyFromBasePool(pool *BasePool) (*mod
switch pool.PollingStrategy {
case models.StrategySequential:
sequentialKey := fmt.Sprintf(BasePoolSequential, poolID)
keyIDStr, err = r.store.Rotate(sequentialKey)
keyIDStr, err = r.store.Rotate(ctx, sequentialKey)
case models.StrategyWeighted:
lruKey := fmt.Sprintf(BasePoolLRU, poolID)
results, zerr := r.store.ZRange(lruKey, 0, 0)
results, zerr := r.store.ZRange(ctx, lruKey, 0, 0)
if zerr == nil && len(results) > 0 {
keyIDStr = results[0]
}
@@ -107,12 +105,11 @@ func (r *gormKeyRepository) SelectOneActiveKeyFromBasePool(pool *BasePool) (*mod
case models.StrategyRandom:
mainPoolKey := fmt.Sprintf(BasePoolRandomMain, poolID)
cooldownPoolKey := fmt.Sprintf(BasePoolRandomCooldown, poolID)
keyIDStr, err = r.store.PopAndCycleSetMember(mainPoolKey, cooldownPoolKey)
default: // 默认策略,应该在 ensureCache 中处理,但作为降级方案
keyIDStr, err = r.store.PopAndCycleSetMember(ctx, mainPoolKey, cooldownPoolKey)
default:
log.Warnf("Default polling strategy triggered inside selection. This should be rare.")
sequentialKey := fmt.Sprintf(BasePoolSequential, poolID)
keyIDStr, err = r.store.LIndex(sequentialKey, 0)
keyIDStr, err = r.store.LIndex(ctx, sequentialKey, 0)
}
if err != nil {
@@ -128,12 +125,10 @@ func (r *gormKeyRepository) SelectOneActiveKeyFromBasePool(pool *BasePool) (*mod
keyID, _ := strconv.ParseUint(keyIDStr, 10, 64)
for _, group := range pool.CandidateGroups {
apiKey, mapping, cacheErr := r.getKeyDetailsFromCache(uint(keyID), group.ID)
apiKey, mapping, cacheErr := r.getKeyDetailsFromCache(ctx, uint(keyID), group.ID)
if cacheErr == nil && apiKey != nil && mapping != nil {
if pool.PollingStrategy == models.StrategyWeighted {
go r.updateKeyUsageTimestampForPool(poolID, uint(keyID))
go r.updateKeyUsageTimestampForPool(context.Background(), poolID, uint(keyID))
}
return apiKey, group, nil
}
@@ -144,42 +139,39 @@ func (r *gormKeyRepository) SelectOneActiveKeyFromBasePool(pool *BasePool) (*mod
}
// ensureBasePoolCacheExists 动态创建 BasePool 的 Redis 结构
func (r *gormKeyRepository) ensureBasePoolCacheExists(pool *BasePool, poolID string) error {
func (r *gormKeyRepository) ensureBasePoolCacheExists(ctx context.Context, pool *BasePool, poolID string) error {
listKey := fmt.Sprintf(BasePoolSequential, poolID)
// --- [逻辑优化] 提前处理“毒丸”,让逻辑更清晰 ---
exists, err := r.store.Exists(listKey)
exists, err := r.store.Exists(ctx, listKey)
if err != nil {
r.logger.WithError(err).Errorf("Failed to check existence for pool_id '%s'", poolID)
return err // 直接返回读取错误
return err
}
if exists {
val, err := r.store.LIndex(listKey, 0)
val, err := r.store.LIndex(ctx, listKey, 0)
if err != nil {
// 如果连 LIndex 都失败,说明缓存可能已损坏,允许重建
r.logger.WithError(err).Warnf("Cache for pool_id '%s' exists but is unreadable. Forcing rebuild.", poolID)
} else {
if val == EmptyPoolPlaceholder {
return gorm.ErrRecordNotFound // 已知为空,直接返回
return gorm.ErrRecordNotFound
}
return nil // 缓存有效,直接返回
return nil
}
}
// --- [锁机制优化] 增加分布式锁,防止并发构建时的惊群效应 ---
lockKey := fmt.Sprintf("lock:basepool:%s", poolID)
acquired, err := r.store.SetNX(lockKey, []byte("1"), 10*time.Second) // 10秒锁超时
acquired, err := r.store.SetNX(ctx, lockKey, []byte("1"), 10*time.Second)
if err != nil {
r.logger.WithError(err).Error("Failed to attempt acquiring distributed lock for basepool build.")
return err
}
if !acquired {
// 未获取到锁,等待一小段时间后重试,让持有锁的协程完成构建
time.Sleep(100 * time.Millisecond)
return r.ensureBasePoolCacheExists(pool, poolID)
return r.ensureBasePoolCacheExists(ctx, pool, poolID)
}
defer r.store.Del(lockKey) // 确保在函数退出时释放锁
// 双重检查,防止在获取锁的间隙,已有其他协程完成了构建
if exists, _ := r.store.Exists(listKey); exists {
defer r.store.Del(context.Background(), lockKey)
if exists, _ := r.store.Exists(ctx, listKey); exists {
return nil
}
r.logger.Infof("BasePool cache for pool_id '%s' not found or is unreadable. Building now...", poolID)
@@ -187,22 +179,15 @@ func (r *gormKeyRepository) ensureBasePoolCacheExists(pool *BasePool, poolID str
lruMembers := make(map[string]float64)
for _, group := range pool.CandidateGroups {
activeKeySetKey := fmt.Sprintf(KeyGroup, group.ID)
groupKeyIDs, err := r.store.SMembers(activeKeySetKey)
// --- [核心修正] ---
// 这是整个问题的根源。我们绝不能在读取失败时,默默地`continue`。
// 任何读取源数据的失败,都必须被视为一次构建过程的彻底失败,并立即中止。
groupKeyIDs, err := r.store.SMembers(ctx, activeKeySetKey)
if err != nil {
r.logger.WithError(err).Errorf("FATAL: Failed to read active keys for group %d during BasePool build. Aborting build process for pool_id '%s'.", group.ID, poolID)
// 返回这个瞬时错误。这会导致本次请求失败,但绝不会写入“毒丸”,
// 从而给了下一次请求一个全新的、成功的机会。
return err
}
// 只有在 SMembers 成功时,才继续处理
allActiveKeyIDs = append(allActiveKeyIDs, groupKeyIDs...)
for _, keyIDStr := range groupKeyIDs {
keyID, _ := strconv.ParseUint(keyIDStr, 10, 64)
_, mapping, err := r.getKeyDetailsFromCache(uint(keyID), group.ID)
_, mapping, err := r.getKeyDetailsFromCache(ctx, uint(keyID), group.ID)
if err == nil && mapping != nil {
var score float64
if mapping.LastUsedAt != nil {
@@ -213,12 +198,9 @@ func (r *gormKeyRepository) ensureBasePoolCacheExists(pool *BasePool, poolID str
}
}
// --- [逻辑修正] ---
// 只有在“我们成功读取了所有数据,但发现数据本身是空的”这种情况下,
// 才允许写入“毒丸”。
if len(allActiveKeyIDs) == 0 {
r.logger.Warnf("No active keys found for any candidate groups for pool_id '%s'. Setting empty pool placeholder.", poolID)
pipe := r.store.Pipeline()
pipe := r.store.Pipeline(ctx)
pipe.LPush(listKey, EmptyPoolPlaceholder)
pipe.Expire(listKey, EmptyCacheTTL)
if err := pipe.Exec(); err != nil {
@@ -226,14 +208,10 @@ func (r *gormKeyRepository) ensureBasePoolCacheExists(pool *BasePool, poolID str
}
return gorm.ErrRecordNotFound
}
// 使用管道填充所有轮询结构
pipe := r.store.Pipeline()
// 1. 顺序
pipe.LPush(fmt.Sprintf(BasePoolSequential, poolID), toInterfaceSlice(allActiveKeyIDs)...)
// 2. 随机
pipe.SAdd(fmt.Sprintf(BasePoolRandomMain, poolID), toInterfaceSlice(allActiveKeyIDs)...)
// 设置合理的过期时间例如5分钟以防止孤儿数据
pipe := r.store.Pipeline(ctx)
pipe.LPush(fmt.Sprintf(BasePoolSequential, poolID), toInterfaceSlice(allActiveKeyIDs)...)
pipe.SAdd(fmt.Sprintf(BasePoolRandomMain, poolID), toInterfaceSlice(allActiveKeyIDs)...)
pipe.Expire(fmt.Sprintf(BasePoolSequential, poolID), CacheTTL)
pipe.Expire(fmt.Sprintf(BasePoolRandomMain, poolID), CacheTTL)
pipe.Expire(fmt.Sprintf(BasePoolRandomCooldown, poolID), CacheTTL)
@@ -244,17 +222,22 @@ func (r *gormKeyRepository) ensureBasePoolCacheExists(pool *BasePool, poolID str
}
if len(lruMembers) > 0 {
r.store.ZAdd(fmt.Sprintf(BasePoolLRU, poolID), lruMembers)
if err := r.store.ZAdd(ctx, fmt.Sprintf(BasePoolLRU, poolID), lruMembers); err != nil {
r.logger.WithError(err).Warnf("Failed to populate LRU cache for pool_id '%s'", poolID)
}
}
return nil
}
// updateKeyUsageTimestampForPool 更新 BasePool 的 LRU ZSET
func (r *gormKeyRepository) updateKeyUsageTimestampForPool(poolID string, keyID uint) {
func (r *gormKeyRepository) updateKeyUsageTimestampForPool(ctx context.Context, poolID string, keyID uint) {
lruKey := fmt.Sprintf(BasePoolLRU, poolID)
r.store.ZAdd(lruKey, map[string]float64{
err := r.store.ZAdd(ctx, lruKey, map[string]float64{
strconv.FormatUint(uint64(keyID), 10): nowMilli(),
})
if err != nil {
r.logger.WithError(err).Warnf("Failed to update key usage for pool %s", poolID)
}
}
// generatePoolID 根据候选组ID列表生成一个稳定的、唯一的字符串ID
@@ -285,8 +268,8 @@ func nowMilli() float64 {
}
// getKeyDetailsFromCache 从缓存中获取Key和Mapping的JSON数据。
func (r *gormKeyRepository) getKeyDetailsFromCache(keyID, groupID uint) (*models.APIKey, *models.GroupAPIKeyMapping, error) {
apiKeyJSON, err := r.store.Get(fmt.Sprintf(KeyDetails, keyID))
func (r *gormKeyRepository) getKeyDetailsFromCache(ctx context.Context, keyID, groupID uint) (*models.APIKey, *models.GroupAPIKeyMapping, error) {
apiKeyJSON, err := r.store.Get(ctx, fmt.Sprintf(KeyDetails, keyID))
if err != nil {
return nil, nil, fmt.Errorf("failed to get key details for key %d: %w", keyID, err)
}
@@ -295,7 +278,7 @@ func (r *gormKeyRepository) getKeyDetailsFromCache(keyID, groupID uint) (*models
return nil, nil, fmt.Errorf("failed to unmarshal api key %d: %w", keyID, err)
}
mappingJSON, err := r.store.Get(fmt.Sprintf(KeyMapping, groupID, keyID))
mappingJSON, err := r.store.Get(ctx, fmt.Sprintf(KeyMapping, groupID, keyID))
if err != nil {
return nil, nil, fmt.Errorf("failed to get mapping details for key %d in group %d: %w", keyID, groupID, err)
}