Update Context for store
This commit is contained in:
@@ -2,6 +2,7 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"gemini-balancer/internal/models"
|
||||
@@ -22,7 +23,7 @@ const (
|
||||
BasePoolRandomCooldown = "basepool:%s:keys:random:cooldown"
|
||||
)
|
||||
|
||||
func (r *gormKeyRepository) LoadAllKeysToStore() error {
|
||||
func (r *gormKeyRepository) LoadAllKeysToStore(ctx context.Context) error {
|
||||
r.logger.Info("Starting to load all keys and associations into cache, including polling structures...")
|
||||
var allMappings []*models.GroupAPIKeyMapping
|
||||
if err := r.db.Preload("APIKey").Find(&allMappings).Error; err != nil {
|
||||
@@ -48,7 +49,7 @@ func (r *gormKeyRepository) LoadAllKeysToStore() error {
|
||||
}
|
||||
|
||||
activeKeysByGroup := make(map[uint][]*models.GroupAPIKeyMapping)
|
||||
pipe := r.store.Pipeline()
|
||||
pipe := r.store.Pipeline(context.Background())
|
||||
detailsToSet := make(map[string][]byte)
|
||||
var allGroups []*models.KeyGroup
|
||||
if err := r.db.Find(&allGroups).Error; err == nil {
|
||||
@@ -100,14 +101,14 @@ func (r *gormKeyRepository) LoadAllKeysToStore() error {
|
||||
pipe.SAdd(fmt.Sprintf(KeyGroup, groupID), activeKeyIDs...)
|
||||
pipe.LPush(fmt.Sprintf(KeyGroupSequential, groupID), activeKeyIDs...)
|
||||
pipe.SAdd(fmt.Sprintf(KeyGroupRandomMain, groupID), activeKeyIDs...)
|
||||
go r.store.ZAdd(fmt.Sprintf(KeyGroupLRU, groupID), lruMembers)
|
||||
go r.store.ZAdd(context.Background(), fmt.Sprintf(KeyGroupLRU, groupID), lruMembers)
|
||||
}
|
||||
|
||||
if err := pipe.Exec(); err != nil {
|
||||
return fmt.Errorf("failed to execute pipeline for cache rebuild: %w", err)
|
||||
}
|
||||
for key, value := range detailsToSet {
|
||||
if err := r.store.Set(key, value, 0); err != nil {
|
||||
if err := r.store.Set(context.Background(), key, value, 0); err != nil {
|
||||
r.logger.WithError(err).Warnf("Failed to set key detail in cache: %s", key)
|
||||
}
|
||||
}
|
||||
@@ -124,16 +125,16 @@ func (r *gormKeyRepository) updateStoreCacheForKey(key *models.APIKey) error {
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal key %d for cache update: %w", key.ID, err)
|
||||
}
|
||||
return r.store.Set(fmt.Sprintf(KeyDetails, key.ID), keyJSON, 0)
|
||||
return r.store.Set(context.Background(), fmt.Sprintf(KeyDetails, key.ID), keyJSON, 0)
|
||||
}
|
||||
|
||||
func (r *gormKeyRepository) removeStoreCacheForKey(key *models.APIKey) error {
|
||||
groupIDs, err := r.GetGroupsForKey(key.ID)
|
||||
func (r *gormKeyRepository) removeStoreCacheForKey(ctx context.Context, key *models.APIKey) error {
|
||||
groupIDs, err := r.GetGroupsForKey(ctx, key.ID)
|
||||
if err != nil {
|
||||
r.logger.Warnf("failed to get groups for key %d to clean up cache lists: %v", key.ID, err)
|
||||
}
|
||||
|
||||
pipe := r.store.Pipeline()
|
||||
pipe := r.store.Pipeline(ctx)
|
||||
pipe.Del(fmt.Sprintf(KeyDetails, key.ID))
|
||||
|
||||
for _, groupID := range groupIDs {
|
||||
@@ -144,13 +145,13 @@ func (r *gormKeyRepository) removeStoreCacheForKey(key *models.APIKey) error {
|
||||
pipe.LRem(fmt.Sprintf(KeyGroupSequential, groupID), 0, keyIDStr)
|
||||
pipe.SRem(fmt.Sprintf(KeyGroupRandomMain, groupID), keyIDStr)
|
||||
pipe.SRem(fmt.Sprintf(KeyGroupRandomCooldown, groupID), keyIDStr)
|
||||
go r.store.ZRem(fmt.Sprintf(KeyGroupLRU, groupID), keyIDStr)
|
||||
go r.store.ZRem(context.Background(), fmt.Sprintf(KeyGroupLRU, groupID), keyIDStr)
|
||||
}
|
||||
return pipe.Exec()
|
||||
}
|
||||
|
||||
func (r *gormKeyRepository) updateStoreCacheForMapping(mapping *models.GroupAPIKeyMapping) error {
|
||||
pipe := r.store.Pipeline()
|
||||
pipe := r.store.Pipeline(context.Background())
|
||||
activeKeyListKey := fmt.Sprintf("group:%d:keys:active", mapping.KeyGroupID)
|
||||
pipe.LRem(activeKeyListKey, 0, mapping.APIKeyID)
|
||||
if mapping.Status == models.StatusActive {
|
||||
@@ -159,7 +160,7 @@ func (r *gormKeyRepository) updateStoreCacheForMapping(mapping *models.GroupAPIK
|
||||
return pipe.Exec()
|
||||
}
|
||||
|
||||
func (r *gormKeyRepository) HandleCacheUpdateEventBatch(mappings []*models.GroupAPIKeyMapping) error {
|
||||
func (r *gormKeyRepository) HandleCacheUpdateEventBatch(ctx context.Context, mappings []*models.GroupAPIKeyMapping) error {
|
||||
if len(mappings) == 0 {
|
||||
return nil
|
||||
}
|
||||
@@ -184,7 +185,7 @@ func (r *gormKeyRepository) HandleCacheUpdateEventBatch(mappings []*models.Group
|
||||
}
|
||||
groupUpdates[mapping.KeyGroupID] = update
|
||||
}
|
||||
pipe := r.store.Pipeline()
|
||||
pipe := r.store.Pipeline(context.Background())
|
||||
var pipelineError error
|
||||
for groupID, updates := range groupUpdates {
|
||||
activeKeyListKey := fmt.Sprintf("group:%d:keys:active", groupID)
|
||||
|
||||
@@ -7,6 +7,7 @@ import (
|
||||
"fmt"
|
||||
"gemini-balancer/internal/models"
|
||||
|
||||
"context"
|
||||
"math/rand"
|
||||
"strings"
|
||||
"time"
|
||||
@@ -115,7 +116,7 @@ func (r *gormKeyRepository) Update(key *models.APIKey) error {
|
||||
}
|
||||
|
||||
func (r *gormKeyRepository) HardDeleteByID(id uint) error {
|
||||
key, err := r.GetKeyByID(id) // This now returns a decrypted key
|
||||
key, err := r.GetKeyByID(id)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -125,7 +126,7 @@ func (r *gormKeyRepository) HardDeleteByID(id uint) error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := r.removeStoreCacheForKey(key); err != nil {
|
||||
if err := r.removeStoreCacheForKey(context.Background(), key); err != nil {
|
||||
r.logger.Warnf("DB deleted key ID %d, but cache removal failed: %v", id, err)
|
||||
}
|
||||
return nil
|
||||
@@ -140,16 +141,13 @@ func (r *gormKeyRepository) HardDeleteByValues(keyValues []string) (int64, error
|
||||
hash := sha256.Sum256([]byte(v))
|
||||
hashes[i] = hex.EncodeToString(hash[:])
|
||||
}
|
||||
// Find the full key objects first to update the cache later.
|
||||
var keysToDelete []models.APIKey
|
||||
// [MODIFIED] Find by hash.
|
||||
if err := r.db.Where("api_key_hash IN ?", hashes).Find(&keysToDelete).Error; err != nil {
|
||||
return 0, err
|
||||
}
|
||||
if len(keysToDelete) == 0 {
|
||||
return 0, nil
|
||||
}
|
||||
// Decrypt them to ensure cache has plaintext if needed.
|
||||
if err := r.decryptKeys(keysToDelete); err != nil {
|
||||
r.logger.Warnf("Decryption failed for keys before hard delete, cache removal may be impacted: %v", err)
|
||||
}
|
||||
@@ -167,7 +165,7 @@ func (r *gormKeyRepository) HardDeleteByValues(keyValues []string) (int64, error
|
||||
return 0, err
|
||||
}
|
||||
for i := range keysToDelete {
|
||||
if err := r.removeStoreCacheForKey(&keysToDelete[i]); err != nil {
|
||||
if err := r.removeStoreCacheForKey(context.Background(), &keysToDelete[i]); err != nil {
|
||||
r.logger.Warnf("DB deleted key ID %d, but cache removal failed: %v", keysToDelete[i].ID, err)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"gemini-balancer/internal/models"
|
||||
@@ -110,13 +111,13 @@ func (r *gormKeyRepository) deleteOrphanKeysLogic(db *gorm.DB) (int64, error) {
|
||||
}
|
||||
|
||||
result := db.Delete(&models.APIKey{}, orphanKeyIDs)
|
||||
//result := db.Unscoped().Delete(&models.APIKey{}, orphanKeyIDs)
|
||||
if result.Error != nil {
|
||||
return 0, result.Error
|
||||
}
|
||||
|
||||
for i := range keysToDelete {
|
||||
if err := r.removeStoreCacheForKey(&keysToDelete[i]); err != nil {
|
||||
// [修正] 使用 context.Background() 调用已更新的缓存清理函数
|
||||
if err := r.removeStoreCacheForKey(context.Background(), &keysToDelete[i]); err != nil {
|
||||
r.logger.Warnf("DB deleted orphan key ID %d, but cache removal failed: %v", keysToDelete[i].ID, err)
|
||||
}
|
||||
}
|
||||
@@ -144,7 +145,7 @@ func (r *gormKeyRepository) GetActiveMasterKeys() ([]*models.APIKey, error) {
|
||||
return keys, nil
|
||||
}
|
||||
|
||||
func (r *gormKeyRepository) UpdateAPIKeyStatus(keyID uint, status models.MasterAPIKeyStatus) error {
|
||||
func (r *gormKeyRepository) UpdateAPIKeyStatus(ctx context.Context, keyID uint, status models.MasterAPIKeyStatus) error {
|
||||
err := r.executeTransactionWithRetry(func(tx *gorm.DB) error {
|
||||
result := tx.Model(&models.APIKey{}).
|
||||
Where("id = ?", keyID).
|
||||
@@ -160,7 +161,7 @@ func (r *gormKeyRepository) UpdateAPIKeyStatus(keyID uint, status models.MasterA
|
||||
if err == nil {
|
||||
r.logger.Infof("MasterStatus for key ID %d changed, triggering a full cache reload.", keyID)
|
||||
go func() {
|
||||
if err := r.LoadAllKeysToStore(); err != nil {
|
||||
if err := r.LoadAllKeysToStore(context.Background()); err != nil {
|
||||
r.logger.Errorf("Failed to reload cache after MasterStatus change for key ID %d: %v", keyID, err)
|
||||
}
|
||||
}()
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"errors"
|
||||
@@ -14,7 +15,7 @@ import (
|
||||
"gorm.io/gorm/clause"
|
||||
)
|
||||
|
||||
func (r *gormKeyRepository) LinkKeysToGroup(groupID uint, keyIDs []uint) error {
|
||||
func (r *gormKeyRepository) LinkKeysToGroup(ctx context.Context, groupID uint, keyIDs []uint) error {
|
||||
if len(keyIDs) == 0 {
|
||||
return nil
|
||||
}
|
||||
@@ -34,12 +35,12 @@ func (r *gormKeyRepository) LinkKeysToGroup(groupID uint, keyIDs []uint) error {
|
||||
}
|
||||
|
||||
for _, keyID := range keyIDs {
|
||||
r.store.SAdd(fmt.Sprintf("key:%d:groups", keyID), groupID)
|
||||
r.store.SAdd(context.Background(), fmt.Sprintf("key:%d:groups", keyID), groupID)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *gormKeyRepository) UnlinkKeysFromGroup(groupID uint, keyIDs []uint) (int64, error) {
|
||||
func (r *gormKeyRepository) UnlinkKeysFromGroup(ctx context.Context, groupID uint, keyIDs []uint) (int64, error) {
|
||||
if len(keyIDs) == 0 {
|
||||
return 0, nil
|
||||
}
|
||||
@@ -63,16 +64,16 @@ func (r *gormKeyRepository) UnlinkKeysFromGroup(groupID uint, keyIDs []uint) (in
|
||||
|
||||
activeKeyListKey := fmt.Sprintf("group:%d:keys:active", groupID)
|
||||
for _, keyID := range keyIDs {
|
||||
r.store.SRem(fmt.Sprintf("key:%d:groups", keyID), groupID)
|
||||
r.store.LRem(activeKeyListKey, 0, strconv.Itoa(int(keyID)))
|
||||
r.store.SRem(context.Background(), fmt.Sprintf("key:%d:groups", keyID), groupID)
|
||||
r.store.LRem(context.Background(), activeKeyListKey, 0, strconv.Itoa(int(keyID)))
|
||||
}
|
||||
|
||||
return unlinkedCount, nil
|
||||
}
|
||||
|
||||
func (r *gormKeyRepository) GetGroupsForKey(keyID uint) ([]uint, error) {
|
||||
func (r *gormKeyRepository) GetGroupsForKey(ctx context.Context, keyID uint) ([]uint, error) {
|
||||
cacheKey := fmt.Sprintf("key:%d:groups", keyID)
|
||||
strGroupIDs, err := r.store.SMembers(cacheKey)
|
||||
strGroupIDs, err := r.store.SMembers(context.Background(), cacheKey)
|
||||
if err != nil || len(strGroupIDs) == 0 {
|
||||
var groupIDs []uint
|
||||
dbErr := r.db.Table("group_api_key_mappings").Where("api_key_id = ?", keyID).Pluck("key_group_id", &groupIDs).Error
|
||||
@@ -84,7 +85,7 @@ func (r *gormKeyRepository) GetGroupsForKey(keyID uint) ([]uint, error) {
|
||||
for _, id := range groupIDs {
|
||||
interfaceSlice = append(interfaceSlice, id)
|
||||
}
|
||||
r.store.SAdd(cacheKey, interfaceSlice...)
|
||||
r.store.SAdd(context.Background(), cacheKey, interfaceSlice...)
|
||||
}
|
||||
return groupIDs, nil
|
||||
}
|
||||
@@ -103,7 +104,7 @@ func (r *gormKeyRepository) GetMapping(groupID, keyID uint) (*models.GroupAPIKey
|
||||
return &mapping, err
|
||||
}
|
||||
|
||||
func (r *gormKeyRepository) UpdateMapping(mapping *models.GroupAPIKeyMapping) error {
|
||||
func (r *gormKeyRepository) UpdateMapping(ctx context.Context, mapping *models.GroupAPIKeyMapping) error {
|
||||
err := r.executeTransactionWithRetry(func(tx *gorm.DB) error {
|
||||
return tx.Save(mapping).Error
|
||||
})
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
// Filename: internal/repository/key_selector.go
|
||||
// Filename: internal/repository/key_selector.go (经审查后最终修复版)
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/sha1"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
@@ -23,19 +24,18 @@ const (
|
||||
)
|
||||
|
||||
// SelectOneActiveKey 根据指定的轮询策略,从缓存中高效地选取一个可用的API密钥。
|
||||
|
||||
func (r *gormKeyRepository) SelectOneActiveKey(group *models.KeyGroup) (*models.APIKey, *models.GroupAPIKeyMapping, error) {
|
||||
func (r *gormKeyRepository) SelectOneActiveKey(ctx context.Context, group *models.KeyGroup) (*models.APIKey, *models.GroupAPIKeyMapping, error) {
|
||||
var keyIDStr string
|
||||
var err error
|
||||
|
||||
switch group.PollingStrategy {
|
||||
case models.StrategySequential:
|
||||
sequentialKey := fmt.Sprintf(KeyGroupSequential, group.ID)
|
||||
keyIDStr, err = r.store.Rotate(sequentialKey)
|
||||
keyIDStr, err = r.store.Rotate(ctx, sequentialKey)
|
||||
|
||||
case models.StrategyWeighted:
|
||||
lruKey := fmt.Sprintf(KeyGroupLRU, group.ID)
|
||||
results, zerr := r.store.ZRange(lruKey, 0, 0)
|
||||
results, zerr := r.store.ZRange(ctx, lruKey, 0, 0)
|
||||
if zerr == nil && len(results) > 0 {
|
||||
keyIDStr = results[0]
|
||||
}
|
||||
@@ -44,11 +44,11 @@ func (r *gormKeyRepository) SelectOneActiveKey(group *models.KeyGroup) (*models.
|
||||
case models.StrategyRandom:
|
||||
mainPoolKey := fmt.Sprintf(KeyGroupRandomMain, group.ID)
|
||||
cooldownPoolKey := fmt.Sprintf(KeyGroupRandomCooldown, group.ID)
|
||||
keyIDStr, err = r.store.PopAndCycleSetMember(mainPoolKey, cooldownPoolKey)
|
||||
keyIDStr, err = r.store.PopAndCycleSetMember(ctx, mainPoolKey, cooldownPoolKey)
|
||||
|
||||
default: // 默认或未指定策略时,使用基础的随机策略
|
||||
activeKeySetKey := fmt.Sprintf(KeyGroup, group.ID)
|
||||
keyIDStr, err = r.store.SRandMember(activeKeySetKey)
|
||||
keyIDStr, err = r.store.SRandMember(ctx, activeKeySetKey)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
@@ -65,27 +65,25 @@ func (r *gormKeyRepository) SelectOneActiveKey(group *models.KeyGroup) (*models.
|
||||
|
||||
keyID, _ := strconv.ParseUint(keyIDStr, 10, 64)
|
||||
|
||||
apiKey, mapping, err := r.getKeyDetailsFromCache(uint(keyID), group.ID)
|
||||
apiKey, mapping, err := r.getKeyDetailsFromCache(ctx, uint(keyID), group.ID)
|
||||
if err != nil {
|
||||
r.logger.WithError(err).Warnf("Cache inconsistency: Failed to get details for selected key ID %d", keyID)
|
||||
// TODO 可以在此加入重试逻辑,再次调用 SelectOneActiveKey(group)
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
if group.PollingStrategy == models.StrategyWeighted {
|
||||
go r.UpdateKeyUsageTimestamp(group.ID, uint(keyID))
|
||||
go r.UpdateKeyUsageTimestamp(context.Background(), group.ID, uint(keyID))
|
||||
}
|
||||
|
||||
return apiKey, mapping, nil
|
||||
}
|
||||
|
||||
// SelectOneActiveKeyFromBasePool 为智能聚合模式设计的全新轮询器。
|
||||
func (r *gormKeyRepository) SelectOneActiveKeyFromBasePool(pool *BasePool) (*models.APIKey, *models.KeyGroup, error) {
|
||||
// 生成唯一的池ID,确保不同请求组合的轮询状态相互隔离
|
||||
func (r *gormKeyRepository) SelectOneActiveKeyFromBasePool(ctx context.Context, pool *BasePool) (*models.APIKey, *models.KeyGroup, error) {
|
||||
poolID := generatePoolID(pool.CandidateGroups)
|
||||
log := r.logger.WithField("pool_id", poolID)
|
||||
|
||||
if err := r.ensureBasePoolCacheExists(pool, poolID); err != nil {
|
||||
if err := r.ensureBasePoolCacheExists(ctx, pool, poolID); err != nil {
|
||||
log.WithError(err).Error("Failed to ensure BasePool cache exists.")
|
||||
return nil, nil, err
|
||||
}
|
||||
@@ -96,10 +94,10 @@ func (r *gormKeyRepository) SelectOneActiveKeyFromBasePool(pool *BasePool) (*mod
|
||||
switch pool.PollingStrategy {
|
||||
case models.StrategySequential:
|
||||
sequentialKey := fmt.Sprintf(BasePoolSequential, poolID)
|
||||
keyIDStr, err = r.store.Rotate(sequentialKey)
|
||||
keyIDStr, err = r.store.Rotate(ctx, sequentialKey)
|
||||
case models.StrategyWeighted:
|
||||
lruKey := fmt.Sprintf(BasePoolLRU, poolID)
|
||||
results, zerr := r.store.ZRange(lruKey, 0, 0)
|
||||
results, zerr := r.store.ZRange(ctx, lruKey, 0, 0)
|
||||
if zerr == nil && len(results) > 0 {
|
||||
keyIDStr = results[0]
|
||||
}
|
||||
@@ -107,12 +105,11 @@ func (r *gormKeyRepository) SelectOneActiveKeyFromBasePool(pool *BasePool) (*mod
|
||||
case models.StrategyRandom:
|
||||
mainPoolKey := fmt.Sprintf(BasePoolRandomMain, poolID)
|
||||
cooldownPoolKey := fmt.Sprintf(BasePoolRandomCooldown, poolID)
|
||||
keyIDStr, err = r.store.PopAndCycleSetMember(mainPoolKey, cooldownPoolKey)
|
||||
default: // 默认策略,应该在 ensureCache 中处理,但作为降级方案
|
||||
keyIDStr, err = r.store.PopAndCycleSetMember(ctx, mainPoolKey, cooldownPoolKey)
|
||||
default:
|
||||
log.Warnf("Default polling strategy triggered inside selection. This should be rare.")
|
||||
|
||||
sequentialKey := fmt.Sprintf(BasePoolSequential, poolID)
|
||||
keyIDStr, err = r.store.LIndex(sequentialKey, 0)
|
||||
keyIDStr, err = r.store.LIndex(ctx, sequentialKey, 0)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
@@ -128,12 +125,10 @@ func (r *gormKeyRepository) SelectOneActiveKeyFromBasePool(pool *BasePool) (*mod
|
||||
keyID, _ := strconv.ParseUint(keyIDStr, 10, 64)
|
||||
|
||||
for _, group := range pool.CandidateGroups {
|
||||
apiKey, mapping, cacheErr := r.getKeyDetailsFromCache(uint(keyID), group.ID)
|
||||
apiKey, mapping, cacheErr := r.getKeyDetailsFromCache(ctx, uint(keyID), group.ID)
|
||||
if cacheErr == nil && apiKey != nil && mapping != nil {
|
||||
|
||||
if pool.PollingStrategy == models.StrategyWeighted {
|
||||
|
||||
go r.updateKeyUsageTimestampForPool(poolID, uint(keyID))
|
||||
go r.updateKeyUsageTimestampForPool(context.Background(), poolID, uint(keyID))
|
||||
}
|
||||
return apiKey, group, nil
|
||||
}
|
||||
@@ -144,42 +139,39 @@ func (r *gormKeyRepository) SelectOneActiveKeyFromBasePool(pool *BasePool) (*mod
|
||||
}
|
||||
|
||||
// ensureBasePoolCacheExists 动态创建 BasePool 的 Redis 结构
|
||||
func (r *gormKeyRepository) ensureBasePoolCacheExists(pool *BasePool, poolID string) error {
|
||||
func (r *gormKeyRepository) ensureBasePoolCacheExists(ctx context.Context, pool *BasePool, poolID string) error {
|
||||
listKey := fmt.Sprintf(BasePoolSequential, poolID)
|
||||
|
||||
// --- [逻辑优化] 提前处理“毒丸”,让逻辑更清晰 ---
|
||||
exists, err := r.store.Exists(listKey)
|
||||
exists, err := r.store.Exists(ctx, listKey)
|
||||
if err != nil {
|
||||
r.logger.WithError(err).Errorf("Failed to check existence for pool_id '%s'", poolID)
|
||||
return err // 直接返回读取错误
|
||||
return err
|
||||
}
|
||||
if exists {
|
||||
val, err := r.store.LIndex(listKey, 0)
|
||||
val, err := r.store.LIndex(ctx, listKey, 0)
|
||||
if err != nil {
|
||||
// 如果连 LIndex 都失败,说明缓存可能已损坏,允许重建
|
||||
r.logger.WithError(err).Warnf("Cache for pool_id '%s' exists but is unreadable. Forcing rebuild.", poolID)
|
||||
} else {
|
||||
if val == EmptyPoolPlaceholder {
|
||||
return gorm.ErrRecordNotFound // 已知为空,直接返回
|
||||
return gorm.ErrRecordNotFound
|
||||
}
|
||||
return nil // 缓存有效,直接返回
|
||||
return nil
|
||||
}
|
||||
}
|
||||
// --- [锁机制优化] 增加分布式锁,防止并发构建时的惊群效应 ---
|
||||
|
||||
lockKey := fmt.Sprintf("lock:basepool:%s", poolID)
|
||||
acquired, err := r.store.SetNX(lockKey, []byte("1"), 10*time.Second) // 10秒锁超时
|
||||
acquired, err := r.store.SetNX(ctx, lockKey, []byte("1"), 10*time.Second)
|
||||
if err != nil {
|
||||
r.logger.WithError(err).Error("Failed to attempt acquiring distributed lock for basepool build.")
|
||||
return err
|
||||
}
|
||||
if !acquired {
|
||||
// 未获取到锁,等待一小段时间后重试,让持有锁的协程完成构建
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
return r.ensureBasePoolCacheExists(pool, poolID)
|
||||
return r.ensureBasePoolCacheExists(ctx, pool, poolID)
|
||||
}
|
||||
defer r.store.Del(lockKey) // 确保在函数退出时释放锁
|
||||
// 双重检查,防止在获取锁的间隙,已有其他协程完成了构建
|
||||
if exists, _ := r.store.Exists(listKey); exists {
|
||||
defer r.store.Del(context.Background(), lockKey)
|
||||
|
||||
if exists, _ := r.store.Exists(ctx, listKey); exists {
|
||||
return nil
|
||||
}
|
||||
r.logger.Infof("BasePool cache for pool_id '%s' not found or is unreadable. Building now...", poolID)
|
||||
@@ -187,22 +179,15 @@ func (r *gormKeyRepository) ensureBasePoolCacheExists(pool *BasePool, poolID str
|
||||
lruMembers := make(map[string]float64)
|
||||
for _, group := range pool.CandidateGroups {
|
||||
activeKeySetKey := fmt.Sprintf(KeyGroup, group.ID)
|
||||
groupKeyIDs, err := r.store.SMembers(activeKeySetKey)
|
||||
|
||||
// --- [核心修正] ---
|
||||
// 这是整个问题的根源。我们绝不能在读取失败时,默默地`continue`。
|
||||
// 任何读取源数据的失败,都必须被视为一次构建过程的彻底失败,并立即中止。
|
||||
groupKeyIDs, err := r.store.SMembers(ctx, activeKeySetKey)
|
||||
if err != nil {
|
||||
r.logger.WithError(err).Errorf("FATAL: Failed to read active keys for group %d during BasePool build. Aborting build process for pool_id '%s'.", group.ID, poolID)
|
||||
// 返回这个瞬时错误。这会导致本次请求失败,但绝不会写入“毒丸”,
|
||||
// 从而给了下一次请求一个全新的、成功的机会。
|
||||
return err
|
||||
}
|
||||
// 只有在 SMembers 成功时,才继续处理
|
||||
allActiveKeyIDs = append(allActiveKeyIDs, groupKeyIDs...)
|
||||
for _, keyIDStr := range groupKeyIDs {
|
||||
keyID, _ := strconv.ParseUint(keyIDStr, 10, 64)
|
||||
_, mapping, err := r.getKeyDetailsFromCache(uint(keyID), group.ID)
|
||||
_, mapping, err := r.getKeyDetailsFromCache(ctx, uint(keyID), group.ID)
|
||||
if err == nil && mapping != nil {
|
||||
var score float64
|
||||
if mapping.LastUsedAt != nil {
|
||||
@@ -213,12 +198,9 @@ func (r *gormKeyRepository) ensureBasePoolCacheExists(pool *BasePool, poolID str
|
||||
}
|
||||
}
|
||||
|
||||
// --- [逻辑修正] ---
|
||||
// 只有在“我们成功读取了所有数据,但发现数据本身是空的”这种情况下,
|
||||
// 才允许写入“毒丸”。
|
||||
if len(allActiveKeyIDs) == 0 {
|
||||
r.logger.Warnf("No active keys found for any candidate groups for pool_id '%s'. Setting empty pool placeholder.", poolID)
|
||||
pipe := r.store.Pipeline()
|
||||
pipe := r.store.Pipeline(ctx)
|
||||
pipe.LPush(listKey, EmptyPoolPlaceholder)
|
||||
pipe.Expire(listKey, EmptyCacheTTL)
|
||||
if err := pipe.Exec(); err != nil {
|
||||
@@ -226,14 +208,10 @@ func (r *gormKeyRepository) ensureBasePoolCacheExists(pool *BasePool, poolID str
|
||||
}
|
||||
return gorm.ErrRecordNotFound
|
||||
}
|
||||
// 使用管道填充所有轮询结构
|
||||
pipe := r.store.Pipeline()
|
||||
// 1. 顺序
|
||||
pipe.LPush(fmt.Sprintf(BasePoolSequential, poolID), toInterfaceSlice(allActiveKeyIDs)...)
|
||||
// 2. 随机
|
||||
pipe.SAdd(fmt.Sprintf(BasePoolRandomMain, poolID), toInterfaceSlice(allActiveKeyIDs)...)
|
||||
|
||||
// 设置合理的过期时间,例如5分钟,以防止孤儿数据
|
||||
pipe := r.store.Pipeline(ctx)
|
||||
pipe.LPush(fmt.Sprintf(BasePoolSequential, poolID), toInterfaceSlice(allActiveKeyIDs)...)
|
||||
pipe.SAdd(fmt.Sprintf(BasePoolRandomMain, poolID), toInterfaceSlice(allActiveKeyIDs)...)
|
||||
pipe.Expire(fmt.Sprintf(BasePoolSequential, poolID), CacheTTL)
|
||||
pipe.Expire(fmt.Sprintf(BasePoolRandomMain, poolID), CacheTTL)
|
||||
pipe.Expire(fmt.Sprintf(BasePoolRandomCooldown, poolID), CacheTTL)
|
||||
@@ -244,17 +222,22 @@ func (r *gormKeyRepository) ensureBasePoolCacheExists(pool *BasePool, poolID str
|
||||
}
|
||||
|
||||
if len(lruMembers) > 0 {
|
||||
r.store.ZAdd(fmt.Sprintf(BasePoolLRU, poolID), lruMembers)
|
||||
if err := r.store.ZAdd(ctx, fmt.Sprintf(BasePoolLRU, poolID), lruMembers); err != nil {
|
||||
r.logger.WithError(err).Warnf("Failed to populate LRU cache for pool_id '%s'", poolID)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// updateKeyUsageTimestampForPool 更新 BasePool 的 LUR ZSET
|
||||
func (r *gormKeyRepository) updateKeyUsageTimestampForPool(poolID string, keyID uint) {
|
||||
func (r *gormKeyRepository) updateKeyUsageTimestampForPool(ctx context.Context, poolID string, keyID uint) {
|
||||
lruKey := fmt.Sprintf(BasePoolLRU, poolID)
|
||||
r.store.ZAdd(lruKey, map[string]float64{
|
||||
err := r.store.ZAdd(ctx, lruKey, map[string]float64{
|
||||
strconv.FormatUint(uint64(keyID), 10): nowMilli(),
|
||||
})
|
||||
if err != nil {
|
||||
r.logger.WithError(err).Warnf("Failed to update key usage for pool %s", poolID)
|
||||
}
|
||||
}
|
||||
|
||||
// generatePoolID 根据候选组ID列表生成一个稳定的、唯一的字符串ID
|
||||
@@ -285,8 +268,8 @@ func nowMilli() float64 {
|
||||
}
|
||||
|
||||
// getKeyDetailsFromCache 从缓存中获取Key和Mapping的JSON数据。
|
||||
func (r *gormKeyRepository) getKeyDetailsFromCache(keyID, groupID uint) (*models.APIKey, *models.GroupAPIKeyMapping, error) {
|
||||
apiKeyJSON, err := r.store.Get(fmt.Sprintf(KeyDetails, keyID))
|
||||
func (r *gormKeyRepository) getKeyDetailsFromCache(ctx context.Context, keyID, groupID uint) (*models.APIKey, *models.GroupAPIKeyMapping, error) {
|
||||
apiKeyJSON, err := r.store.Get(ctx, fmt.Sprintf(KeyDetails, keyID))
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("failed to get key details for key %d: %w", keyID, err)
|
||||
}
|
||||
@@ -295,7 +278,7 @@ func (r *gormKeyRepository) getKeyDetailsFromCache(keyID, groupID uint) (*models
|
||||
return nil, nil, fmt.Errorf("failed to unmarshal api key %d: %w", keyID, err)
|
||||
}
|
||||
|
||||
mappingJSON, err := r.store.Get(fmt.Sprintf(KeyMapping, groupID, keyID))
|
||||
mappingJSON, err := r.store.Get(ctx, fmt.Sprintf(KeyMapping, groupID, keyID))
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("failed to get mapping details for key %d in group %d: %w", keyID, groupID, err)
|
||||
}
|
||||
|
||||
@@ -1,7 +1,9 @@
|
||||
// Filename: internal/repository/key_writer.go
|
||||
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"gemini-balancer/internal/errors"
|
||||
"gemini-balancer/internal/models"
|
||||
@@ -9,7 +11,7 @@ import (
|
||||
"time"
|
||||
)
|
||||
|
||||
func (r *gormKeyRepository) UpdateKeyUsageTimestamp(groupID, keyID uint) {
|
||||
func (r *gormKeyRepository) UpdateKeyUsageTimestamp(ctx context.Context, groupID, keyID uint) {
|
||||
lruKey := fmt.Sprintf(KeyGroupLRU, groupID)
|
||||
timestamp := float64(time.Now().UnixMilli())
|
||||
|
||||
@@ -17,52 +19,51 @@ func (r *gormKeyRepository) UpdateKeyUsageTimestamp(groupID, keyID uint) {
|
||||
strconv.FormatUint(uint64(keyID), 10): timestamp,
|
||||
}
|
||||
|
||||
if err := r.store.ZAdd(lruKey, members); err != nil {
|
||||
if err := r.store.ZAdd(ctx, lruKey, members); err != nil {
|
||||
r.logger.WithError(err).Warnf("Failed to update usage timestamp for key %d in group %d", keyID, groupID)
|
||||
}
|
||||
}
|
||||
|
||||
func (r *gormKeyRepository) SyncKeyStatusInPollingCaches(groupID, keyID uint, newStatus models.APIKeyStatus) {
|
||||
func (r *gormKeyRepository) SyncKeyStatusInPollingCaches(ctx context.Context, groupID, keyID uint, newStatus models.APIKeyStatus) {
|
||||
r.logger.Infof("SYNC: Directly updating polling caches for G:%d K:%d -> %s", groupID, keyID, newStatus)
|
||||
r.updatePollingCachesLogic(groupID, keyID, newStatus)
|
||||
r.updatePollingCachesLogic(ctx, groupID, keyID, newStatus)
|
||||
}
|
||||
|
||||
func (r *gormKeyRepository) HandleCacheUpdateEvent(groupID, keyID uint, newStatus models.APIKeyStatus) {
|
||||
func (r *gormKeyRepository) HandleCacheUpdateEvent(ctx context.Context, groupID, keyID uint, newStatus models.APIKeyStatus) {
|
||||
r.logger.Infof("EVENT: Updating polling caches for G:%d K:%d -> %s from an event", groupID, keyID, newStatus)
|
||||
r.updatePollingCachesLogic(groupID, keyID, newStatus)
|
||||
r.updatePollingCachesLogic(ctx, groupID, keyID, newStatus)
|
||||
}
|
||||
|
||||
func (r *gormKeyRepository) updatePollingCachesLogic(groupID, keyID uint, newStatus models.APIKeyStatus) {
|
||||
func (r *gormKeyRepository) updatePollingCachesLogic(ctx context.Context, groupID, keyID uint, newStatus models.APIKeyStatus) {
|
||||
keyIDStr := strconv.FormatUint(uint64(keyID), 10)
|
||||
sequentialKey := fmt.Sprintf(KeyGroupSequential, groupID)
|
||||
lruKey := fmt.Sprintf(KeyGroupLRU, groupID)
|
||||
mainPoolKey := fmt.Sprintf(KeyGroupRandomMain, groupID)
|
||||
cooldownPoolKey := fmt.Sprintf(KeyGroupRandomCooldown, groupID)
|
||||
|
||||
_ = r.store.LRem(sequentialKey, 0, keyIDStr)
|
||||
_ = r.store.ZRem(lruKey, keyIDStr)
|
||||
_ = r.store.SRem(mainPoolKey, keyIDStr)
|
||||
_ = r.store.SRem(cooldownPoolKey, keyIDStr)
|
||||
_ = r.store.LRem(ctx, sequentialKey, 0, keyIDStr)
|
||||
_ = r.store.ZRem(ctx, lruKey, keyIDStr)
|
||||
_ = r.store.SRem(ctx, mainPoolKey, keyIDStr)
|
||||
_ = r.store.SRem(ctx, cooldownPoolKey, keyIDStr)
|
||||
|
||||
if newStatus == models.StatusActive {
|
||||
if err := r.store.LPush(sequentialKey, keyIDStr); err != nil {
|
||||
if err := r.store.LPush(ctx, sequentialKey, keyIDStr); err != nil {
|
||||
r.logger.WithError(err).Warnf("Failed to add key %d to sequential list for group %d", keyID, groupID)
|
||||
}
|
||||
members := map[string]float64{keyIDStr: 0}
|
||||
if err := r.store.ZAdd(lruKey, members); err != nil {
|
||||
if err := r.store.ZAdd(ctx, lruKey, members); err != nil {
|
||||
r.logger.WithError(err).Warnf("Failed to add key %d to LRU zset for group %d", keyID, groupID)
|
||||
}
|
||||
if err := r.store.SAdd(mainPoolKey, keyIDStr); err != nil {
|
||||
if err := r.store.SAdd(ctx, mainPoolKey, keyIDStr); err != nil {
|
||||
r.logger.WithError(err).Warnf("Failed to add key %d to random main pool for group %d", keyID, groupID)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// UpdateKeyStatusAfterRequest is the new central hub for handling feedback.
|
||||
func (r *gormKeyRepository) UpdateKeyStatusAfterRequest(group *models.KeyGroup, key *models.APIKey, success bool, apiErr *errors.APIError) {
|
||||
func (r *gormKeyRepository) UpdateKeyStatusAfterRequest(ctx context.Context, group *models.KeyGroup, key *models.APIKey, success bool, apiErr *errors.APIError) {
|
||||
if success {
|
||||
if group.PollingStrategy == models.StrategyWeighted {
|
||||
go r.UpdateKeyUsageTimestamp(group.ID, key.ID)
|
||||
go r.UpdateKeyUsageTimestamp(context.Background(), group.ID, key.ID)
|
||||
}
|
||||
return
|
||||
}
|
||||
@@ -72,6 +73,5 @@ func (r *gormKeyRepository) UpdateKeyStatusAfterRequest(group *models.KeyGroup,
|
||||
}
|
||||
r.logger.Warnf("Request failed for KeyID %d in GroupID %d with error: %s. Temporarily removing from active polling caches.", key.ID, group.ID, apiErr.Message)
|
||||
|
||||
// This call is correct. It uses the synchronous, direct method.
|
||||
r.SyncKeyStatusInPollingCaches(group.ID, key.ID, models.StatusCooldown)
|
||||
r.SyncKeyStatusInPollingCaches(ctx, group.ID, key.ID, models.StatusCooldown)
|
||||
}
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
// Filename: internal/repository/repository.go
|
||||
// Filename: internal/repository/repository.go (经审查后最终修复版)
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"gemini-balancer/internal/crypto"
|
||||
"gemini-balancer/internal/errors"
|
||||
"gemini-balancer/internal/models"
|
||||
@@ -22,8 +23,8 @@ type BasePool struct {
|
||||
|
||||
type KeyRepository interface {
|
||||
// --- 核心选取与调度 --- key_selector
|
||||
SelectOneActiveKey(group *models.KeyGroup) (*models.APIKey, *models.GroupAPIKeyMapping, error)
|
||||
SelectOneActiveKeyFromBasePool(pool *BasePool) (*models.APIKey, *models.KeyGroup, error)
|
||||
SelectOneActiveKey(ctx context.Context, group *models.KeyGroup) (*models.APIKey, *models.GroupAPIKeyMapping, error)
|
||||
SelectOneActiveKeyFromBasePool(ctx context.Context, pool *BasePool) (*models.APIKey, *models.KeyGroup, error)
|
||||
|
||||
// --- 加密与解密 --- key_crud
|
||||
Decrypt(key *models.APIKey) error
|
||||
@@ -37,16 +38,16 @@ type KeyRepository interface {
|
||||
GetKeyByID(id uint) (*models.APIKey, error)
|
||||
GetKeyByValue(keyValue string) (*models.APIKey, error)
|
||||
GetKeysByValues(keyValues []string) ([]models.APIKey, error)
|
||||
GetKeysByIDs(ids []uint) ([]models.APIKey, error) // [新增] 根据一组主键ID批量获取Key
|
||||
GetKeysByIDs(ids []uint) ([]models.APIKey, error)
|
||||
GetKeysByGroup(groupID uint) ([]models.APIKey, error)
|
||||
CountByGroup(groupID uint) (int64, error)
|
||||
|
||||
// --- 多对多关系管理 --- key_mapping
|
||||
LinkKeysToGroup(groupID uint, keyIDs []uint) error
|
||||
UnlinkKeysFromGroup(groupID uint, keyIDs []uint) (unlinkedCount int64, err error)
|
||||
GetGroupsForKey(keyID uint) ([]uint, error)
|
||||
LinkKeysToGroup(ctx context.Context, groupID uint, keyIDs []uint) error
|
||||
UnlinkKeysFromGroup(ctx context.Context, groupID uint, keyIDs []uint) (unlinkedCount int64, err error)
|
||||
GetGroupsForKey(ctx context.Context, keyID uint) ([]uint, error)
|
||||
GetMapping(groupID, keyID uint) (*models.GroupAPIKeyMapping, error)
|
||||
UpdateMapping(mapping *models.GroupAPIKeyMapping) error
|
||||
UpdateMapping(ctx context.Context, mapping *models.GroupAPIKeyMapping) error
|
||||
GetPaginatedKeysAndMappingsByGroup(params *models.APIKeyQueryParams) ([]*models.APIKeyDetails, int64, error)
|
||||
GetKeysByValuesAndGroupID(values []string, groupID uint) ([]models.APIKey, error)
|
||||
FindKeyValuesByStatus(groupID uint, statuses []string) ([]string, error)
|
||||
@@ -55,8 +56,8 @@ type KeyRepository interface {
|
||||
UpdateMappingWithoutCache(mapping *models.GroupAPIKeyMapping) error
|
||||
|
||||
// --- 缓存管理 --- key_cache
|
||||
LoadAllKeysToStore() error
|
||||
HandleCacheUpdateEventBatch(mappings []*models.GroupAPIKeyMapping) error
|
||||
LoadAllKeysToStore(ctx context.Context) error
|
||||
HandleCacheUpdateEventBatch(ctx context.Context, mappings []*models.GroupAPIKeyMapping) error
|
||||
|
||||
// --- 维护与后台任务 --- key_maintenance
|
||||
StreamKeysToWriter(groupID uint, statusFilter string, writer io.Writer) error
|
||||
@@ -65,16 +66,14 @@ type KeyRepository interface {
|
||||
DeleteOrphanKeys() (int64, error)
|
||||
DeleteOrphanKeysTx(tx *gorm.DB) (int64, error)
|
||||
GetActiveMasterKeys() ([]*models.APIKey, error)
|
||||
UpdateAPIKeyStatus(keyID uint, status models.MasterAPIKeyStatus) error
|
||||
UpdateAPIKeyStatus(ctx context.Context, keyID uint, status models.MasterAPIKeyStatus) error
|
||||
HardDeleteSoftDeletedBefore(date time.Time) (int64, error)
|
||||
|
||||
// --- 轮询策略的"写"操作 --- key_writer
|
||||
UpdateKeyUsageTimestamp(groupID, keyID uint)
|
||||
// 同步更新缓存,供核心业务使用
|
||||
SyncKeyStatusInPollingCaches(groupID, keyID uint, newStatus models.APIKeyStatus)
|
||||
// 异步更新缓存,供事件订阅者使用
|
||||
HandleCacheUpdateEvent(groupID, keyID uint, newStatus models.APIKeyStatus)
|
||||
UpdateKeyStatusAfterRequest(group *models.KeyGroup, key *models.APIKey, success bool, apiErr *errors.APIError)
|
||||
UpdateKeyUsageTimestamp(ctx context.Context, groupID, keyID uint)
|
||||
SyncKeyStatusInPollingCaches(ctx context.Context, groupID, keyID uint, newStatus models.APIKeyStatus)
|
||||
HandleCacheUpdateEvent(ctx context.Context, groupID, keyID uint, newStatus models.APIKeyStatus)
|
||||
UpdateKeyStatusAfterRequest(ctx context.Context, group *models.KeyGroup, key *models.APIKey, success bool, apiErr *errors.APIError)
|
||||
}
|
||||
|
||||
type GroupRepository interface {
|
||||
|
||||
Reference in New Issue
Block a user