Update Context for store

This commit is contained in:
XOF
2025-11-22 14:20:05 +08:00
parent ac0e0a8275
commit 2b0b9b67dc
31 changed files with 817 additions and 1016 deletions

View File

@@ -2,6 +2,7 @@
package repository
import (
"context"
"encoding/json"
"fmt"
"gemini-balancer/internal/models"
@@ -22,7 +23,7 @@ const (
BasePoolRandomCooldown = "basepool:%s:keys:random:cooldown"
)
func (r *gormKeyRepository) LoadAllKeysToStore() error {
func (r *gormKeyRepository) LoadAllKeysToStore(ctx context.Context) error {
r.logger.Info("Starting to load all keys and associations into cache, including polling structures...")
var allMappings []*models.GroupAPIKeyMapping
if err := r.db.Preload("APIKey").Find(&allMappings).Error; err != nil {
@@ -48,7 +49,7 @@ func (r *gormKeyRepository) LoadAllKeysToStore() error {
}
activeKeysByGroup := make(map[uint][]*models.GroupAPIKeyMapping)
pipe := r.store.Pipeline()
pipe := r.store.Pipeline(context.Background())
detailsToSet := make(map[string][]byte)
var allGroups []*models.KeyGroup
if err := r.db.Find(&allGroups).Error; err == nil {
@@ -100,14 +101,14 @@ func (r *gormKeyRepository) LoadAllKeysToStore() error {
pipe.SAdd(fmt.Sprintf(KeyGroup, groupID), activeKeyIDs...)
pipe.LPush(fmt.Sprintf(KeyGroupSequential, groupID), activeKeyIDs...)
pipe.SAdd(fmt.Sprintf(KeyGroupRandomMain, groupID), activeKeyIDs...)
go r.store.ZAdd(fmt.Sprintf(KeyGroupLRU, groupID), lruMembers)
go r.store.ZAdd(context.Background(), fmt.Sprintf(KeyGroupLRU, groupID), lruMembers)
}
if err := pipe.Exec(); err != nil {
return fmt.Errorf("failed to execute pipeline for cache rebuild: %w", err)
}
for key, value := range detailsToSet {
if err := r.store.Set(key, value, 0); err != nil {
if err := r.store.Set(context.Background(), key, value, 0); err != nil {
r.logger.WithError(err).Warnf("Failed to set key detail in cache: %s", key)
}
}
@@ -124,16 +125,16 @@ func (r *gormKeyRepository) updateStoreCacheForKey(key *models.APIKey) error {
if err != nil {
return fmt.Errorf("failed to marshal key %d for cache update: %w", key.ID, err)
}
return r.store.Set(fmt.Sprintf(KeyDetails, key.ID), keyJSON, 0)
return r.store.Set(context.Background(), fmt.Sprintf(KeyDetails, key.ID), keyJSON, 0)
}
func (r *gormKeyRepository) removeStoreCacheForKey(key *models.APIKey) error {
groupIDs, err := r.GetGroupsForKey(key.ID)
func (r *gormKeyRepository) removeStoreCacheForKey(ctx context.Context, key *models.APIKey) error {
groupIDs, err := r.GetGroupsForKey(ctx, key.ID)
if err != nil {
r.logger.Warnf("failed to get groups for key %d to clean up cache lists: %v", key.ID, err)
}
pipe := r.store.Pipeline()
pipe := r.store.Pipeline(ctx)
pipe.Del(fmt.Sprintf(KeyDetails, key.ID))
for _, groupID := range groupIDs {
@@ -144,13 +145,13 @@ func (r *gormKeyRepository) removeStoreCacheForKey(key *models.APIKey) error {
pipe.LRem(fmt.Sprintf(KeyGroupSequential, groupID), 0, keyIDStr)
pipe.SRem(fmt.Sprintf(KeyGroupRandomMain, groupID), keyIDStr)
pipe.SRem(fmt.Sprintf(KeyGroupRandomCooldown, groupID), keyIDStr)
go r.store.ZRem(fmt.Sprintf(KeyGroupLRU, groupID), keyIDStr)
go r.store.ZRem(context.Background(), fmt.Sprintf(KeyGroupLRU, groupID), keyIDStr)
}
return pipe.Exec()
}
func (r *gormKeyRepository) updateStoreCacheForMapping(mapping *models.GroupAPIKeyMapping) error {
pipe := r.store.Pipeline()
pipe := r.store.Pipeline(context.Background())
activeKeyListKey := fmt.Sprintf("group:%d:keys:active", mapping.KeyGroupID)
pipe.LRem(activeKeyListKey, 0, mapping.APIKeyID)
if mapping.Status == models.StatusActive {
@@ -159,7 +160,7 @@ func (r *gormKeyRepository) updateStoreCacheForMapping(mapping *models.GroupAPIK
return pipe.Exec()
}
func (r *gormKeyRepository) HandleCacheUpdateEventBatch(mappings []*models.GroupAPIKeyMapping) error {
func (r *gormKeyRepository) HandleCacheUpdateEventBatch(ctx context.Context, mappings []*models.GroupAPIKeyMapping) error {
if len(mappings) == 0 {
return nil
}
@@ -184,7 +185,7 @@ func (r *gormKeyRepository) HandleCacheUpdateEventBatch(mappings []*models.Group
}
groupUpdates[mapping.KeyGroupID] = update
}
pipe := r.store.Pipeline()
pipe := r.store.Pipeline(context.Background())
var pipelineError error
for groupID, updates := range groupUpdates {
activeKeyListKey := fmt.Sprintf("group:%d:keys:active", groupID)

View File

@@ -7,6 +7,7 @@ import (
"fmt"
"gemini-balancer/internal/models"
"context"
"math/rand"
"strings"
"time"
@@ -115,7 +116,7 @@ func (r *gormKeyRepository) Update(key *models.APIKey) error {
}
func (r *gormKeyRepository) HardDeleteByID(id uint) error {
key, err := r.GetKeyByID(id) // This now returns a decrypted key
key, err := r.GetKeyByID(id)
if err != nil {
return err
}
@@ -125,7 +126,7 @@ func (r *gormKeyRepository) HardDeleteByID(id uint) error {
if err != nil {
return err
}
if err := r.removeStoreCacheForKey(key); err != nil {
if err := r.removeStoreCacheForKey(context.Background(), key); err != nil {
r.logger.Warnf("DB deleted key ID %d, but cache removal failed: %v", id, err)
}
return nil
@@ -140,16 +141,13 @@ func (r *gormKeyRepository) HardDeleteByValues(keyValues []string) (int64, error
hash := sha256.Sum256([]byte(v))
hashes[i] = hex.EncodeToString(hash[:])
}
// Find the full key objects first to update the cache later.
var keysToDelete []models.APIKey
// [MODIFIED] Find by hash.
if err := r.db.Where("api_key_hash IN ?", hashes).Find(&keysToDelete).Error; err != nil {
return 0, err
}
if len(keysToDelete) == 0 {
return 0, nil
}
// Decrypt them to ensure cache has plaintext if needed.
if err := r.decryptKeys(keysToDelete); err != nil {
r.logger.Warnf("Decryption failed for keys before hard delete, cache removal may be impacted: %v", err)
}
@@ -167,7 +165,7 @@ func (r *gormKeyRepository) HardDeleteByValues(keyValues []string) (int64, error
return 0, err
}
for i := range keysToDelete {
if err := r.removeStoreCacheForKey(&keysToDelete[i]); err != nil {
if err := r.removeStoreCacheForKey(context.Background(), &keysToDelete[i]); err != nil {
r.logger.Warnf("DB deleted key ID %d, but cache removal failed: %v", keysToDelete[i].ID, err)
}
}

View File

@@ -2,6 +2,7 @@
package repository
import (
"context"
"crypto/sha256"
"encoding/hex"
"gemini-balancer/internal/models"
@@ -110,13 +111,13 @@ func (r *gormKeyRepository) deleteOrphanKeysLogic(db *gorm.DB) (int64, error) {
}
result := db.Delete(&models.APIKey{}, orphanKeyIDs)
//result := db.Unscoped().Delete(&models.APIKey{}, orphanKeyIDs)
if result.Error != nil {
return 0, result.Error
}
for i := range keysToDelete {
if err := r.removeStoreCacheForKey(&keysToDelete[i]); err != nil {
// [修正] 使用 context.Background() 调用已更新的缓存清理函数
if err := r.removeStoreCacheForKey(context.Background(), &keysToDelete[i]); err != nil {
r.logger.Warnf("DB deleted orphan key ID %d, but cache removal failed: %v", keysToDelete[i].ID, err)
}
}
@@ -144,7 +145,7 @@ func (r *gormKeyRepository) GetActiveMasterKeys() ([]*models.APIKey, error) {
return keys, nil
}
func (r *gormKeyRepository) UpdateAPIKeyStatus(keyID uint, status models.MasterAPIKeyStatus) error {
func (r *gormKeyRepository) UpdateAPIKeyStatus(ctx context.Context, keyID uint, status models.MasterAPIKeyStatus) error {
err := r.executeTransactionWithRetry(func(tx *gorm.DB) error {
result := tx.Model(&models.APIKey{}).
Where("id = ?", keyID).
@@ -160,7 +161,7 @@ func (r *gormKeyRepository) UpdateAPIKeyStatus(keyID uint, status models.MasterA
if err == nil {
r.logger.Infof("MasterStatus for key ID %d changed, triggering a full cache reload.", keyID)
go func() {
if err := r.LoadAllKeysToStore(); err != nil {
if err := r.LoadAllKeysToStore(context.Background()); err != nil {
r.logger.Errorf("Failed to reload cache after MasterStatus change for key ID %d: %v", keyID, err)
}
}()

View File

@@ -2,6 +2,7 @@
package repository
import (
"context"
"crypto/sha256"
"encoding/hex"
"errors"
@@ -14,7 +15,7 @@ import (
"gorm.io/gorm/clause"
)
func (r *gormKeyRepository) LinkKeysToGroup(groupID uint, keyIDs []uint) error {
func (r *gormKeyRepository) LinkKeysToGroup(ctx context.Context, groupID uint, keyIDs []uint) error {
if len(keyIDs) == 0 {
return nil
}
@@ -34,12 +35,12 @@ func (r *gormKeyRepository) LinkKeysToGroup(groupID uint, keyIDs []uint) error {
}
for _, keyID := range keyIDs {
r.store.SAdd(fmt.Sprintf("key:%d:groups", keyID), groupID)
r.store.SAdd(context.Background(), fmt.Sprintf("key:%d:groups", keyID), groupID)
}
return nil
}
func (r *gormKeyRepository) UnlinkKeysFromGroup(groupID uint, keyIDs []uint) (int64, error) {
func (r *gormKeyRepository) UnlinkKeysFromGroup(ctx context.Context, groupID uint, keyIDs []uint) (int64, error) {
if len(keyIDs) == 0 {
return 0, nil
}
@@ -63,16 +64,16 @@ func (r *gormKeyRepository) UnlinkKeysFromGroup(groupID uint, keyIDs []uint) (in
activeKeyListKey := fmt.Sprintf("group:%d:keys:active", groupID)
for _, keyID := range keyIDs {
r.store.SRem(fmt.Sprintf("key:%d:groups", keyID), groupID)
r.store.LRem(activeKeyListKey, 0, strconv.Itoa(int(keyID)))
r.store.SRem(context.Background(), fmt.Sprintf("key:%d:groups", keyID), groupID)
r.store.LRem(context.Background(), activeKeyListKey, 0, strconv.Itoa(int(keyID)))
}
return unlinkedCount, nil
}
func (r *gormKeyRepository) GetGroupsForKey(keyID uint) ([]uint, error) {
func (r *gormKeyRepository) GetGroupsForKey(ctx context.Context, keyID uint) ([]uint, error) {
cacheKey := fmt.Sprintf("key:%d:groups", keyID)
strGroupIDs, err := r.store.SMembers(cacheKey)
strGroupIDs, err := r.store.SMembers(context.Background(), cacheKey)
if err != nil || len(strGroupIDs) == 0 {
var groupIDs []uint
dbErr := r.db.Table("group_api_key_mappings").Where("api_key_id = ?", keyID).Pluck("key_group_id", &groupIDs).Error
@@ -84,7 +85,7 @@ func (r *gormKeyRepository) GetGroupsForKey(keyID uint) ([]uint, error) {
for _, id := range groupIDs {
interfaceSlice = append(interfaceSlice, id)
}
r.store.SAdd(cacheKey, interfaceSlice...)
r.store.SAdd(context.Background(), cacheKey, interfaceSlice...)
}
return groupIDs, nil
}
@@ -103,7 +104,7 @@ func (r *gormKeyRepository) GetMapping(groupID, keyID uint) (*models.GroupAPIKey
return &mapping, err
}
func (r *gormKeyRepository) UpdateMapping(mapping *models.GroupAPIKeyMapping) error {
func (r *gormKeyRepository) UpdateMapping(ctx context.Context, mapping *models.GroupAPIKeyMapping) error {
err := r.executeTransactionWithRetry(func(tx *gorm.DB) error {
return tx.Save(mapping).Error
})

View File

@@ -1,7 +1,8 @@
// Filename: internal/repository/key_selector.go
// Filename: internal/repository/key_selector.go (经审查后最终修复版)
package repository
import (
"context"
"crypto/sha1"
"encoding/json"
"errors"
@@ -23,19 +24,18 @@ const (
)
// SelectOneActiveKey 根据指定的轮询策略从缓存中高效地选取一个可用的API密钥。
func (r *gormKeyRepository) SelectOneActiveKey(group *models.KeyGroup) (*models.APIKey, *models.GroupAPIKeyMapping, error) {
func (r *gormKeyRepository) SelectOneActiveKey(ctx context.Context, group *models.KeyGroup) (*models.APIKey, *models.GroupAPIKeyMapping, error) {
var keyIDStr string
var err error
switch group.PollingStrategy {
case models.StrategySequential:
sequentialKey := fmt.Sprintf(KeyGroupSequential, group.ID)
keyIDStr, err = r.store.Rotate(sequentialKey)
keyIDStr, err = r.store.Rotate(ctx, sequentialKey)
case models.StrategyWeighted:
lruKey := fmt.Sprintf(KeyGroupLRU, group.ID)
results, zerr := r.store.ZRange(lruKey, 0, 0)
results, zerr := r.store.ZRange(ctx, lruKey, 0, 0)
if zerr == nil && len(results) > 0 {
keyIDStr = results[0]
}
@@ -44,11 +44,11 @@ func (r *gormKeyRepository) SelectOneActiveKey(group *models.KeyGroup) (*models.
case models.StrategyRandom:
mainPoolKey := fmt.Sprintf(KeyGroupRandomMain, group.ID)
cooldownPoolKey := fmt.Sprintf(KeyGroupRandomCooldown, group.ID)
keyIDStr, err = r.store.PopAndCycleSetMember(mainPoolKey, cooldownPoolKey)
keyIDStr, err = r.store.PopAndCycleSetMember(ctx, mainPoolKey, cooldownPoolKey)
default: // 默认或未指定策略时,使用基础的随机策略
activeKeySetKey := fmt.Sprintf(KeyGroup, group.ID)
keyIDStr, err = r.store.SRandMember(activeKeySetKey)
keyIDStr, err = r.store.SRandMember(ctx, activeKeySetKey)
}
if err != nil {
@@ -65,27 +65,25 @@ func (r *gormKeyRepository) SelectOneActiveKey(group *models.KeyGroup) (*models.
keyID, _ := strconv.ParseUint(keyIDStr, 10, 64)
apiKey, mapping, err := r.getKeyDetailsFromCache(uint(keyID), group.ID)
apiKey, mapping, err := r.getKeyDetailsFromCache(ctx, uint(keyID), group.ID)
if err != nil {
r.logger.WithError(err).Warnf("Cache inconsistency: Failed to get details for selected key ID %d", keyID)
// TODO 可以在此加入重试逻辑,再次调用 SelectOneActiveKey(group)
return nil, nil, err
}
if group.PollingStrategy == models.StrategyWeighted {
go r.UpdateKeyUsageTimestamp(group.ID, uint(keyID))
go r.UpdateKeyUsageTimestamp(context.Background(), group.ID, uint(keyID))
}
return apiKey, mapping, nil
}
// SelectOneActiveKeyFromBasePool 为智能聚合模式设计的全新轮询器。
func (r *gormKeyRepository) SelectOneActiveKeyFromBasePool(pool *BasePool) (*models.APIKey, *models.KeyGroup, error) {
// 生成唯一的池ID确保不同请求组合的轮询状态相互隔离
func (r *gormKeyRepository) SelectOneActiveKeyFromBasePool(ctx context.Context, pool *BasePool) (*models.APIKey, *models.KeyGroup, error) {
poolID := generatePoolID(pool.CandidateGroups)
log := r.logger.WithField("pool_id", poolID)
if err := r.ensureBasePoolCacheExists(pool, poolID); err != nil {
if err := r.ensureBasePoolCacheExists(ctx, pool, poolID); err != nil {
log.WithError(err).Error("Failed to ensure BasePool cache exists.")
return nil, nil, err
}
@@ -96,10 +94,10 @@ func (r *gormKeyRepository) SelectOneActiveKeyFromBasePool(pool *BasePool) (*mod
switch pool.PollingStrategy {
case models.StrategySequential:
sequentialKey := fmt.Sprintf(BasePoolSequential, poolID)
keyIDStr, err = r.store.Rotate(sequentialKey)
keyIDStr, err = r.store.Rotate(ctx, sequentialKey)
case models.StrategyWeighted:
lruKey := fmt.Sprintf(BasePoolLRU, poolID)
results, zerr := r.store.ZRange(lruKey, 0, 0)
results, zerr := r.store.ZRange(ctx, lruKey, 0, 0)
if zerr == nil && len(results) > 0 {
keyIDStr = results[0]
}
@@ -107,12 +105,11 @@ func (r *gormKeyRepository) SelectOneActiveKeyFromBasePool(pool *BasePool) (*mod
case models.StrategyRandom:
mainPoolKey := fmt.Sprintf(BasePoolRandomMain, poolID)
cooldownPoolKey := fmt.Sprintf(BasePoolRandomCooldown, poolID)
keyIDStr, err = r.store.PopAndCycleSetMember(mainPoolKey, cooldownPoolKey)
default: // 默认策略,应该在 ensureCache 中处理,但作为降级方案
keyIDStr, err = r.store.PopAndCycleSetMember(ctx, mainPoolKey, cooldownPoolKey)
default:
log.Warnf("Default polling strategy triggered inside selection. This should be rare.")
sequentialKey := fmt.Sprintf(BasePoolSequential, poolID)
keyIDStr, err = r.store.LIndex(sequentialKey, 0)
keyIDStr, err = r.store.LIndex(ctx, sequentialKey, 0)
}
if err != nil {
@@ -128,12 +125,10 @@ func (r *gormKeyRepository) SelectOneActiveKeyFromBasePool(pool *BasePool) (*mod
keyID, _ := strconv.ParseUint(keyIDStr, 10, 64)
for _, group := range pool.CandidateGroups {
apiKey, mapping, cacheErr := r.getKeyDetailsFromCache(uint(keyID), group.ID)
apiKey, mapping, cacheErr := r.getKeyDetailsFromCache(ctx, uint(keyID), group.ID)
if cacheErr == nil && apiKey != nil && mapping != nil {
if pool.PollingStrategy == models.StrategyWeighted {
go r.updateKeyUsageTimestampForPool(poolID, uint(keyID))
go r.updateKeyUsageTimestampForPool(context.Background(), poolID, uint(keyID))
}
return apiKey, group, nil
}
@@ -144,42 +139,39 @@ func (r *gormKeyRepository) SelectOneActiveKeyFromBasePool(pool *BasePool) (*mod
}
// ensureBasePoolCacheExists 动态创建 BasePool 的 Redis 结构
func (r *gormKeyRepository) ensureBasePoolCacheExists(pool *BasePool, poolID string) error {
func (r *gormKeyRepository) ensureBasePoolCacheExists(ctx context.Context, pool *BasePool, poolID string) error {
listKey := fmt.Sprintf(BasePoolSequential, poolID)
// --- [逻辑优化] 提前处理“毒丸”,让逻辑更清晰 ---
exists, err := r.store.Exists(listKey)
exists, err := r.store.Exists(ctx, listKey)
if err != nil {
r.logger.WithError(err).Errorf("Failed to check existence for pool_id '%s'", poolID)
return err // 直接返回读取错误
return err
}
if exists {
val, err := r.store.LIndex(listKey, 0)
val, err := r.store.LIndex(ctx, listKey, 0)
if err != nil {
// 如果连 LIndex 都失败,说明缓存可能已损坏,允许重建
r.logger.WithError(err).Warnf("Cache for pool_id '%s' exists but is unreadable. Forcing rebuild.", poolID)
} else {
if val == EmptyPoolPlaceholder {
return gorm.ErrRecordNotFound // 已知为空,直接返回
return gorm.ErrRecordNotFound
}
return nil // 缓存有效,直接返回
return nil
}
}
// --- [锁机制优化] 增加分布式锁,防止并发构建时的惊群效应 ---
lockKey := fmt.Sprintf("lock:basepool:%s", poolID)
acquired, err := r.store.SetNX(lockKey, []byte("1"), 10*time.Second) // 10秒锁超时
acquired, err := r.store.SetNX(ctx, lockKey, []byte("1"), 10*time.Second)
if err != nil {
r.logger.WithError(err).Error("Failed to attempt acquiring distributed lock for basepool build.")
return err
}
if !acquired {
// 未获取到锁,等待一小段时间后重试,让持有锁的协程完成构建
time.Sleep(100 * time.Millisecond)
return r.ensureBasePoolCacheExists(pool, poolID)
return r.ensureBasePoolCacheExists(ctx, pool, poolID)
}
defer r.store.Del(lockKey) // 确保在函数退出时释放锁
// 双重检查,防止在获取锁的间隙,已有其他协程完成了构建
if exists, _ := r.store.Exists(listKey); exists {
defer r.store.Del(context.Background(), lockKey)
if exists, _ := r.store.Exists(ctx, listKey); exists {
return nil
}
r.logger.Infof("BasePool cache for pool_id '%s' not found or is unreadable. Building now...", poolID)
@@ -187,22 +179,15 @@ func (r *gormKeyRepository) ensureBasePoolCacheExists(pool *BasePool, poolID str
lruMembers := make(map[string]float64)
for _, group := range pool.CandidateGroups {
activeKeySetKey := fmt.Sprintf(KeyGroup, group.ID)
groupKeyIDs, err := r.store.SMembers(activeKeySetKey)
// --- [核心修正] ---
// 这是整个问题的根源。我们绝不能在读取失败时,默默地`continue`。
// 任何读取源数据的失败,都必须被视为一次构建过程的彻底失败,并立即中止。
groupKeyIDs, err := r.store.SMembers(ctx, activeKeySetKey)
if err != nil {
r.logger.WithError(err).Errorf("FATAL: Failed to read active keys for group %d during BasePool build. Aborting build process for pool_id '%s'.", group.ID, poolID)
// 返回这个瞬时错误。这会导致本次请求失败,但绝不会写入“毒丸”,
// 从而给了下一次请求一个全新的、成功的机会。
return err
}
// 只有在 SMembers 成功时,才继续处理
allActiveKeyIDs = append(allActiveKeyIDs, groupKeyIDs...)
for _, keyIDStr := range groupKeyIDs {
keyID, _ := strconv.ParseUint(keyIDStr, 10, 64)
_, mapping, err := r.getKeyDetailsFromCache(uint(keyID), group.ID)
_, mapping, err := r.getKeyDetailsFromCache(ctx, uint(keyID), group.ID)
if err == nil && mapping != nil {
var score float64
if mapping.LastUsedAt != nil {
@@ -213,12 +198,9 @@ func (r *gormKeyRepository) ensureBasePoolCacheExists(pool *BasePool, poolID str
}
}
// --- [逻辑修正] ---
// 只有在“我们成功读取了所有数据,但发现数据本身是空的”这种情况下,
// 才允许写入“毒丸”。
if len(allActiveKeyIDs) == 0 {
r.logger.Warnf("No active keys found for any candidate groups for pool_id '%s'. Setting empty pool placeholder.", poolID)
pipe := r.store.Pipeline()
pipe := r.store.Pipeline(ctx)
pipe.LPush(listKey, EmptyPoolPlaceholder)
pipe.Expire(listKey, EmptyCacheTTL)
if err := pipe.Exec(); err != nil {
@@ -226,14 +208,10 @@ func (r *gormKeyRepository) ensureBasePoolCacheExists(pool *BasePool, poolID str
}
return gorm.ErrRecordNotFound
}
// 使用管道填充所有轮询结构
pipe := r.store.Pipeline()
// 1. 顺序
pipe.LPush(fmt.Sprintf(BasePoolSequential, poolID), toInterfaceSlice(allActiveKeyIDs)...)
// 2. 随机
pipe.SAdd(fmt.Sprintf(BasePoolRandomMain, poolID), toInterfaceSlice(allActiveKeyIDs)...)
// 设置合理的过期时间例如5分钟以防止孤儿数据
pipe := r.store.Pipeline(ctx)
pipe.LPush(fmt.Sprintf(BasePoolSequential, poolID), toInterfaceSlice(allActiveKeyIDs)...)
pipe.SAdd(fmt.Sprintf(BasePoolRandomMain, poolID), toInterfaceSlice(allActiveKeyIDs)...)
pipe.Expire(fmt.Sprintf(BasePoolSequential, poolID), CacheTTL)
pipe.Expire(fmt.Sprintf(BasePoolRandomMain, poolID), CacheTTL)
pipe.Expire(fmt.Sprintf(BasePoolRandomCooldown, poolID), CacheTTL)
@@ -244,17 +222,22 @@ func (r *gormKeyRepository) ensureBasePoolCacheExists(pool *BasePool, poolID str
}
if len(lruMembers) > 0 {
r.store.ZAdd(fmt.Sprintf(BasePoolLRU, poolID), lruMembers)
if err := r.store.ZAdd(ctx, fmt.Sprintf(BasePoolLRU, poolID), lruMembers); err != nil {
r.logger.WithError(err).Warnf("Failed to populate LRU cache for pool_id '%s'", poolID)
}
}
return nil
}
// updateKeyUsageTimestampForPool 更新 BasePool 的 LUR ZSET
func (r *gormKeyRepository) updateKeyUsageTimestampForPool(poolID string, keyID uint) {
func (r *gormKeyRepository) updateKeyUsageTimestampForPool(ctx context.Context, poolID string, keyID uint) {
lruKey := fmt.Sprintf(BasePoolLRU, poolID)
r.store.ZAdd(lruKey, map[string]float64{
err := r.store.ZAdd(ctx, lruKey, map[string]float64{
strconv.FormatUint(uint64(keyID), 10): nowMilli(),
})
if err != nil {
r.logger.WithError(err).Warnf("Failed to update key usage for pool %s", poolID)
}
}
// generatePoolID 根据候选组ID列表生成一个稳定的、唯一的字符串ID
@@ -285,8 +268,8 @@ func nowMilli() float64 {
}
// getKeyDetailsFromCache 从缓存中获取Key和Mapping的JSON数据。
func (r *gormKeyRepository) getKeyDetailsFromCache(keyID, groupID uint) (*models.APIKey, *models.GroupAPIKeyMapping, error) {
apiKeyJSON, err := r.store.Get(fmt.Sprintf(KeyDetails, keyID))
func (r *gormKeyRepository) getKeyDetailsFromCache(ctx context.Context, keyID, groupID uint) (*models.APIKey, *models.GroupAPIKeyMapping, error) {
apiKeyJSON, err := r.store.Get(ctx, fmt.Sprintf(KeyDetails, keyID))
if err != nil {
return nil, nil, fmt.Errorf("failed to get key details for key %d: %w", keyID, err)
}
@@ -295,7 +278,7 @@ func (r *gormKeyRepository) getKeyDetailsFromCache(keyID, groupID uint) (*models
return nil, nil, fmt.Errorf("failed to unmarshal api key %d: %w", keyID, err)
}
mappingJSON, err := r.store.Get(fmt.Sprintf(KeyMapping, groupID, keyID))
mappingJSON, err := r.store.Get(ctx, fmt.Sprintf(KeyMapping, groupID, keyID))
if err != nil {
return nil, nil, fmt.Errorf("failed to get mapping details for key %d in group %d: %w", keyID, groupID, err)
}

View File

@@ -1,7 +1,9 @@
// Filename: internal/repository/key_writer.go
package repository
import (
"context"
"fmt"
"gemini-balancer/internal/errors"
"gemini-balancer/internal/models"
@@ -9,7 +11,7 @@ import (
"time"
)
func (r *gormKeyRepository) UpdateKeyUsageTimestamp(groupID, keyID uint) {
func (r *gormKeyRepository) UpdateKeyUsageTimestamp(ctx context.Context, groupID, keyID uint) {
lruKey := fmt.Sprintf(KeyGroupLRU, groupID)
timestamp := float64(time.Now().UnixMilli())
@@ -17,52 +19,51 @@ func (r *gormKeyRepository) UpdateKeyUsageTimestamp(groupID, keyID uint) {
strconv.FormatUint(uint64(keyID), 10): timestamp,
}
if err := r.store.ZAdd(lruKey, members); err != nil {
if err := r.store.ZAdd(ctx, lruKey, members); err != nil {
r.logger.WithError(err).Warnf("Failed to update usage timestamp for key %d in group %d", keyID, groupID)
}
}
func (r *gormKeyRepository) SyncKeyStatusInPollingCaches(groupID, keyID uint, newStatus models.APIKeyStatus) {
func (r *gormKeyRepository) SyncKeyStatusInPollingCaches(ctx context.Context, groupID, keyID uint, newStatus models.APIKeyStatus) {
r.logger.Infof("SYNC: Directly updating polling caches for G:%d K:%d -> %s", groupID, keyID, newStatus)
r.updatePollingCachesLogic(groupID, keyID, newStatus)
r.updatePollingCachesLogic(ctx, groupID, keyID, newStatus)
}
func (r *gormKeyRepository) HandleCacheUpdateEvent(groupID, keyID uint, newStatus models.APIKeyStatus) {
func (r *gormKeyRepository) HandleCacheUpdateEvent(ctx context.Context, groupID, keyID uint, newStatus models.APIKeyStatus) {
r.logger.Infof("EVENT: Updating polling caches for G:%d K:%d -> %s from an event", groupID, keyID, newStatus)
r.updatePollingCachesLogic(groupID, keyID, newStatus)
r.updatePollingCachesLogic(ctx, groupID, keyID, newStatus)
}
func (r *gormKeyRepository) updatePollingCachesLogic(groupID, keyID uint, newStatus models.APIKeyStatus) {
func (r *gormKeyRepository) updatePollingCachesLogic(ctx context.Context, groupID, keyID uint, newStatus models.APIKeyStatus) {
keyIDStr := strconv.FormatUint(uint64(keyID), 10)
sequentialKey := fmt.Sprintf(KeyGroupSequential, groupID)
lruKey := fmt.Sprintf(KeyGroupLRU, groupID)
mainPoolKey := fmt.Sprintf(KeyGroupRandomMain, groupID)
cooldownPoolKey := fmt.Sprintf(KeyGroupRandomCooldown, groupID)
_ = r.store.LRem(sequentialKey, 0, keyIDStr)
_ = r.store.ZRem(lruKey, keyIDStr)
_ = r.store.SRem(mainPoolKey, keyIDStr)
_ = r.store.SRem(cooldownPoolKey, keyIDStr)
_ = r.store.LRem(ctx, sequentialKey, 0, keyIDStr)
_ = r.store.ZRem(ctx, lruKey, keyIDStr)
_ = r.store.SRem(ctx, mainPoolKey, keyIDStr)
_ = r.store.SRem(ctx, cooldownPoolKey, keyIDStr)
if newStatus == models.StatusActive {
if err := r.store.LPush(sequentialKey, keyIDStr); err != nil {
if err := r.store.LPush(ctx, sequentialKey, keyIDStr); err != nil {
r.logger.WithError(err).Warnf("Failed to add key %d to sequential list for group %d", keyID, groupID)
}
members := map[string]float64{keyIDStr: 0}
if err := r.store.ZAdd(lruKey, members); err != nil {
if err := r.store.ZAdd(ctx, lruKey, members); err != nil {
r.logger.WithError(err).Warnf("Failed to add key %d to LRU zset for group %d", keyID, groupID)
}
if err := r.store.SAdd(mainPoolKey, keyIDStr); err != nil {
if err := r.store.SAdd(ctx, mainPoolKey, keyIDStr); err != nil {
r.logger.WithError(err).Warnf("Failed to add key %d to random main pool for group %d", keyID, groupID)
}
}
}
// UpdateKeyStatusAfterRequest is the new central hub for handling feedback.
func (r *gormKeyRepository) UpdateKeyStatusAfterRequest(group *models.KeyGroup, key *models.APIKey, success bool, apiErr *errors.APIError) {
func (r *gormKeyRepository) UpdateKeyStatusAfterRequest(ctx context.Context, group *models.KeyGroup, key *models.APIKey, success bool, apiErr *errors.APIError) {
if success {
if group.PollingStrategy == models.StrategyWeighted {
go r.UpdateKeyUsageTimestamp(group.ID, key.ID)
go r.UpdateKeyUsageTimestamp(context.Background(), group.ID, key.ID)
}
return
}
@@ -72,6 +73,5 @@ func (r *gormKeyRepository) UpdateKeyStatusAfterRequest(group *models.KeyGroup,
}
r.logger.Warnf("Request failed for KeyID %d in GroupID %d with error: %s. Temporarily removing from active polling caches.", key.ID, group.ID, apiErr.Message)
// This call is correct. It uses the synchronous, direct method.
r.SyncKeyStatusInPollingCaches(group.ID, key.ID, models.StatusCooldown)
r.SyncKeyStatusInPollingCaches(ctx, group.ID, key.ID, models.StatusCooldown)
}

View File

@@ -1,7 +1,8 @@
// Filename: internal/repository/repository.go
// Filename: internal/repository/repository.go (经审查后最终修复版)
package repository
import (
"context"
"gemini-balancer/internal/crypto"
"gemini-balancer/internal/errors"
"gemini-balancer/internal/models"
@@ -22,8 +23,8 @@ type BasePool struct {
type KeyRepository interface {
// --- 核心选取与调度 --- key_selector
SelectOneActiveKey(group *models.KeyGroup) (*models.APIKey, *models.GroupAPIKeyMapping, error)
SelectOneActiveKeyFromBasePool(pool *BasePool) (*models.APIKey, *models.KeyGroup, error)
SelectOneActiveKey(ctx context.Context, group *models.KeyGroup) (*models.APIKey, *models.GroupAPIKeyMapping, error)
SelectOneActiveKeyFromBasePool(ctx context.Context, pool *BasePool) (*models.APIKey, *models.KeyGroup, error)
// --- 加密与解密 --- key_crud
Decrypt(key *models.APIKey) error
@@ -37,16 +38,16 @@ type KeyRepository interface {
GetKeyByID(id uint) (*models.APIKey, error)
GetKeyByValue(keyValue string) (*models.APIKey, error)
GetKeysByValues(keyValues []string) ([]models.APIKey, error)
GetKeysByIDs(ids []uint) ([]models.APIKey, error) // [新增] 根据一组主键ID批量获取Key
GetKeysByIDs(ids []uint) ([]models.APIKey, error)
GetKeysByGroup(groupID uint) ([]models.APIKey, error)
CountByGroup(groupID uint) (int64, error)
// --- 多对多关系管理 --- key_mapping
LinkKeysToGroup(groupID uint, keyIDs []uint) error
UnlinkKeysFromGroup(groupID uint, keyIDs []uint) (unlinkedCount int64, err error)
GetGroupsForKey(keyID uint) ([]uint, error)
LinkKeysToGroup(ctx context.Context, groupID uint, keyIDs []uint) error
UnlinkKeysFromGroup(ctx context.Context, groupID uint, keyIDs []uint) (unlinkedCount int64, err error)
GetGroupsForKey(ctx context.Context, keyID uint) ([]uint, error)
GetMapping(groupID, keyID uint) (*models.GroupAPIKeyMapping, error)
UpdateMapping(mapping *models.GroupAPIKeyMapping) error
UpdateMapping(ctx context.Context, mapping *models.GroupAPIKeyMapping) error
GetPaginatedKeysAndMappingsByGroup(params *models.APIKeyQueryParams) ([]*models.APIKeyDetails, int64, error)
GetKeysByValuesAndGroupID(values []string, groupID uint) ([]models.APIKey, error)
FindKeyValuesByStatus(groupID uint, statuses []string) ([]string, error)
@@ -55,8 +56,8 @@ type KeyRepository interface {
UpdateMappingWithoutCache(mapping *models.GroupAPIKeyMapping) error
// --- 缓存管理 --- key_cache
LoadAllKeysToStore() error
HandleCacheUpdateEventBatch(mappings []*models.GroupAPIKeyMapping) error
LoadAllKeysToStore(ctx context.Context) error
HandleCacheUpdateEventBatch(ctx context.Context, mappings []*models.GroupAPIKeyMapping) error
// --- 维护与后台任务 --- key_maintenance
StreamKeysToWriter(groupID uint, statusFilter string, writer io.Writer) error
@@ -65,16 +66,14 @@ type KeyRepository interface {
DeleteOrphanKeys() (int64, error)
DeleteOrphanKeysTx(tx *gorm.DB) (int64, error)
GetActiveMasterKeys() ([]*models.APIKey, error)
UpdateAPIKeyStatus(keyID uint, status models.MasterAPIKeyStatus) error
UpdateAPIKeyStatus(ctx context.Context, keyID uint, status models.MasterAPIKeyStatus) error
HardDeleteSoftDeletedBefore(date time.Time) (int64, error)
// --- 轮询策略的"写"操作 --- key_writer
UpdateKeyUsageTimestamp(groupID, keyID uint)
// 同步更新缓存,供核心业务使用
SyncKeyStatusInPollingCaches(groupID, keyID uint, newStatus models.APIKeyStatus)
// 异步更新缓存,供事件订阅者使用
HandleCacheUpdateEvent(groupID, keyID uint, newStatus models.APIKeyStatus)
UpdateKeyStatusAfterRequest(group *models.KeyGroup, key *models.APIKey, success bool, apiErr *errors.APIError)
UpdateKeyUsageTimestamp(ctx context.Context, groupID, keyID uint)
SyncKeyStatusInPollingCaches(ctx context.Context, groupID, keyID uint, newStatus models.APIKeyStatus)
HandleCacheUpdateEvent(ctx context.Context, groupID, keyID uint, newStatus models.APIKeyStatus)
UpdateKeyStatusAfterRequest(ctx context.Context, group *models.KeyGroup, key *models.APIKey, success bool, apiErr *errors.APIError)
}
type GroupRepository interface {