Fix basepool & optimize repo

XOF
2025-11-23 22:42:58 +08:00
parent 2b0b9b67dc
commit 6c7283d51b
16 changed files with 1312 additions and 723 deletions

View File

@@ -1,4 +1,4 @@
// Filename: internal/repository/key_cache.go
// Filename: internal/repository/key_cache.go (final version)
package repository
import (
@@ -9,6 +9,7 @@ import (
"strconv"
)
// --- Redis key constant definitions ---
const (
KeyGroup = "group:%d:keys:active"
KeyDetails = "key:%d:details"
@@ -23,13 +24,16 @@ const (
BasePoolRandomCooldown = "basepool:%s:keys:random:cooldown"
)
// LoadAllKeysToStore loads all keys and mapping relations from the database and fully rebuilds the Redis cache.
func (r *gormKeyRepository) LoadAllKeysToStore(ctx context.Context) error {
r.logger.Info("Starting to load all keys and associations into cache, including polling structures...")
r.logger.Info("Starting full cache rebuild for all keys and polling structures.")
var allMappings []*models.GroupAPIKeyMapping
if err := r.db.Preload("APIKey").Find(&allMappings).Error; err != nil {
return fmt.Errorf("failed to load all mappings with APIKeys from DB: %w", err)
return fmt.Errorf("failed to load mappings with preloaded APIKeys: %w", err)
}
// 1. Batch-decrypt all keys involved
keyMap := make(map[uint]*models.APIKey)
for _, m := range allMappings {
if m.APIKey != nil {
@@ -41,16 +45,16 @@ func (r *gormKeyRepository) LoadAllKeysToStore(ctx context.Context) error {
keysToDecrypt = append(keysToDecrypt, *k)
}
if err := r.decryptKeys(keysToDecrypt); err != nil {
r.logger.WithError(err).Error("Critical error during cache preload: batch decryption failed.")
r.logger.WithError(err).Error("Batch decryption failed during cache rebuild.")
// Even if decryption fails, keep loading the unencrypted or already-decrypted portion
}
decryptedKeyMap := make(map[uint]models.APIKey)
for _, k := range keysToDecrypt {
decryptedKeyMap[k.ID] = k
}
activeKeysByGroup := make(map[uint][]*models.GroupAPIKeyMapping)
pipe := r.store.Pipeline(context.Background())
detailsToSet := make(map[string][]byte)
// 2. Clean up stale polling structures for all groups
pipe := r.store.Pipeline(ctx)
var allGroups []*models.KeyGroup
if err := r.db.Find(&allGroups).Error; err == nil {
for _, group := range allGroups {
@@ -63,26 +67,41 @@ func (r *gormKeyRepository) LoadAllKeysToStore(ctx context.Context) error {
)
}
} else {
r.logger.WithError(err).Error("Failed to get all groups for cache cleanup")
r.logger.WithError(err).Error("Failed to get groups for cache cleanup; proceeding with rebuild.")
}
// 3. Prepare the data for the batch update
activeKeysByGroup := make(map[uint][]*models.GroupAPIKeyMapping)
detailsToSet := make(map[string]any)
for _, mapping := range allMappings {
if mapping.APIKey == nil {
continue
}
decryptedKey, ok := decryptedKeyMap[mapping.APIKeyID]
if !ok {
continue
continue // skip keys that failed to decrypt
}
// Prepare the MSet payloads for KeyDetails and KeyMapping
keyJSON, _ := json.Marshal(decryptedKey)
detailsToSet[fmt.Sprintf(KeyDetails, decryptedKey.ID)] = keyJSON
mappingJSON, _ := json.Marshal(mapping)
detailsToSet[fmt.Sprintf(KeyMapping, mapping.KeyGroupID, decryptedKey.ID)] = mappingJSON
if mapping.Status == models.StatusActive {
activeKeysByGroup[mapping.KeyGroupID] = append(activeKeysByGroup[mapping.KeyGroupID], mapping)
}
}
// 4. Bulk-write the details and mapping caches via MSet
if len(detailsToSet) > 0 {
if err := r.store.MSet(ctx, detailsToSet); err != nil {
r.logger.WithError(err).Error("Failed to MSet key details and mappings during cache rebuild.")
}
}
// 5. Rebuild each group's polling structures inside the pipeline
for groupID, activeMappings := range activeKeysByGroup {
if len(activeMappings) == 0 {
continue
@@ -101,22 +120,19 @@ func (r *gormKeyRepository) LoadAllKeysToStore(ctx context.Context) error {
pipe.SAdd(fmt.Sprintf(KeyGroup, groupID), activeKeyIDs...)
pipe.LPush(fmt.Sprintf(KeyGroupSequential, groupID), activeKeyIDs...)
pipe.SAdd(fmt.Sprintf(KeyGroupRandomMain, groupID), activeKeyIDs...)
go r.store.ZAdd(context.Background(), fmt.Sprintf(KeyGroupLRU, groupID), lruMembers)
pipe.ZAdd(fmt.Sprintf(KeyGroupLRU, groupID), lruMembers)
}
// 6. Execute the pipeline
if err := pipe.Exec(); err != nil {
return fmt.Errorf("failed to execute pipeline for cache rebuild: %w", err)
}
for key, value := range detailsToSet {
if err := r.store.Set(context.Background(), key, value, 0); err != nil {
r.logger.WithError(err).Warnf("Failed to set key detail in cache: %s", key)
}
return fmt.Errorf("pipeline execution for polling structures failed: %w", err)
}
r.logger.Info("Cache rebuild complete, including all polling structures.")
r.logger.Info("Full cache rebuild completed successfully.")
return nil
}
// updateStoreCacheForKey updates the details (K-V) cache for a single APIKey.
func (r *gormKeyRepository) updateStoreCacheForKey(key *models.APIKey) error {
if err := r.decryptKey(key); err != nil {
return fmt.Errorf("failed to decrypt key %d for cache update: %w", key.ID, err)
@@ -128,78 +144,101 @@ func (r *gormKeyRepository) updateStoreCacheForKey(key *models.APIKey) error {
return r.store.Set(context.Background(), fmt.Sprintf(KeyDetails, key.ID), keyJSON, 0)
}
// removeStoreCacheForKey removes an APIKey from every cache structure.
func (r *gormKeyRepository) removeStoreCacheForKey(ctx context.Context, key *models.APIKey) error {
groupIDs, err := r.GetGroupsForKey(ctx, key.ID)
if err != nil {
r.logger.Warnf("failed to get groups for key %d to clean up cache lists: %v", key.ID, err)
r.logger.WithError(err).Warnf("Failed to get groups for key %d during cache removal, cleanup may be partial.", key.ID)
}
pipe := r.store.Pipeline(ctx)
pipe.Del(fmt.Sprintf(KeyDetails, key.ID))
keyIDStr := strconv.FormatUint(uint64(key.ID), 10)
for _, groupID := range groupIDs {
pipe.Del(fmt.Sprintf(KeyMapping, groupID, key.ID))
keyIDStr := strconv.FormatUint(uint64(key.ID), 10)
pipe.SRem(fmt.Sprintf(KeyGroup, groupID), keyIDStr)
pipe.LRem(fmt.Sprintf(KeyGroupSequential, groupID), 0, keyIDStr)
pipe.SRem(fmt.Sprintf(KeyGroupRandomMain, groupID), keyIDStr)
pipe.SRem(fmt.Sprintf(KeyGroupRandomCooldown, groupID), keyIDStr)
go r.store.ZRem(context.Background(), fmt.Sprintf(KeyGroupLRU, groupID), keyIDStr)
pipe.ZRem(fmt.Sprintf(KeyGroupLRU, groupID), keyIDStr)
}
return pipe.Exec()
}
// updateStoreCacheForMapping atomically updates all related cache structures based on a single mapping's status.
func (r *gormKeyRepository) updateStoreCacheForMapping(mapping *models.GroupAPIKeyMapping) error {
pipe := r.store.Pipeline(context.Background())
activeKeyListKey := fmt.Sprintf("group:%d:keys:active", mapping.KeyGroupID)
pipe.LRem(activeKeyListKey, 0, mapping.APIKeyID)
keyIDStr := strconv.FormatUint(uint64(mapping.APIKeyID), 10)
groupID := mapping.KeyGroupID
ctx := context.Background()
pipe := r.store.Pipeline(ctx)
// Unconditionally remove from every polling structure first to guarantee a clean state
pipe.SRem(fmt.Sprintf(KeyGroup, groupID), keyIDStr)
pipe.LRem(fmt.Sprintf(KeyGroupSequential, groupID), 0, keyIDStr)
pipe.SRem(fmt.Sprintf(KeyGroupRandomMain, groupID), keyIDStr)
pipe.SRem(fmt.Sprintf(KeyGroupRandomCooldown, groupID), keyIDStr)
pipe.ZRem(fmt.Sprintf(KeyGroupLRU, groupID), keyIDStr)
// If the new status is Active, re-add the key to all polling structures
if mapping.Status == models.StatusActive {
pipe.LPush(activeKeyListKey, mapping.APIKeyID)
pipe.SAdd(fmt.Sprintf(KeyGroup, groupID), keyIDStr)
pipe.LPush(fmt.Sprintf(KeyGroupSequential, groupID), keyIDStr)
pipe.SAdd(fmt.Sprintf(KeyGroupRandomMain, groupID), keyIDStr)
var score float64
if mapping.LastUsedAt != nil {
score = float64(mapping.LastUsedAt.UnixMilli())
}
pipe.ZAdd(fmt.Sprintf(KeyGroupLRU, groupID), map[string]float64{keyIDStr: score})
}
// Regardless of status, update the mapping-details K-V cache
mappingJSON, err := json.Marshal(mapping)
if err != nil {
return fmt.Errorf("failed to marshal mapping: %w", err)
}
pipe.Set(fmt.Sprintf(KeyMapping, groupID, mapping.APIKeyID), mappingJSON, 0)
return pipe.Exec()
}
// HandleCacheUpdateEventBatch atomically updates the caches for multiple mappings in one batch.
func (r *gormKeyRepository) HandleCacheUpdateEventBatch(ctx context.Context, mappings []*models.GroupAPIKeyMapping) error {
if len(mappings) == 0 {
return nil
}
groupUpdates := make(map[uint]struct {
ToAdd []interface{}
ToRemove []interface{}
})
pipe := r.store.Pipeline(ctx)
for _, mapping := range mappings {
keyIDStr := strconv.FormatUint(uint64(mapping.APIKeyID), 10)
update, ok := groupUpdates[mapping.KeyGroupID]
if !ok {
update = struct {
ToAdd []interface{}
ToRemove []interface{}
}{}
}
groupID := mapping.KeyGroupID
// For every mapping in the batch, apply the full remove-then-add logic
pipe.SRem(fmt.Sprintf(KeyGroup, groupID), keyIDStr)
pipe.LRem(fmt.Sprintf(KeyGroupSequential, groupID), 0, keyIDStr)
pipe.SRem(fmt.Sprintf(KeyGroupRandomMain, groupID), keyIDStr)
pipe.SRem(fmt.Sprintf(KeyGroupRandomCooldown, groupID), keyIDStr)
pipe.ZRem(fmt.Sprintf(KeyGroupLRU, groupID), keyIDStr)
if mapping.Status == models.StatusActive {
update.ToRemove = append(update.ToRemove, keyIDStr)
update.ToAdd = append(update.ToAdd, keyIDStr)
} else {
update.ToRemove = append(update.ToRemove, keyIDStr)
}
groupUpdates[mapping.KeyGroupID] = update
}
pipe := r.store.Pipeline(context.Background())
var pipelineError error
for groupID, updates := range groupUpdates {
activeKeyListKey := fmt.Sprintf("group:%d:keys:active", groupID)
if len(updates.ToRemove) > 0 {
for _, keyID := range updates.ToRemove {
pipe.LRem(activeKeyListKey, 0, keyID)
pipe.SAdd(fmt.Sprintf(KeyGroup, groupID), keyIDStr)
pipe.LPush(fmt.Sprintf(KeyGroupSequential, groupID), keyIDStr)
pipe.SAdd(fmt.Sprintf(KeyGroupRandomMain, groupID), keyIDStr)
var score float64
if mapping.LastUsedAt != nil {
score = float64(mapping.LastUsedAt.UnixMilli())
}
pipe.ZAdd(fmt.Sprintf(KeyGroupLRU, groupID), map[string]float64{keyIDStr: score})
}
if len(updates.ToAdd) > 0 {
pipe.LPush(activeKeyListKey, updates.ToAdd...)
}
mappingJSON, _ := json.Marshal(mapping) // ignore individual marshal errors in the batch so most updates still succeed
pipe.Set(fmt.Sprintf(KeyMapping, groupID, mapping.APIKeyID), mappingJSON, 0)
}
if err := pipe.Exec(); err != nil {
pipelineError = fmt.Errorf("redis pipeline execution failed: %w", err)
}
return pipelineError
return pipe.Exec()
}
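For reference, the remove-then-conditionally-re-add update performed by updateStoreCacheForMapping and HandleCacheUpdateEventBatch can be sketched against a plain go-redis v9 pipeline. This is a minimal illustration, not the project's store.Store abstraction; apart from group:%d:keys:active, the key-name patterns below are placeholders assumed for the example.

package cachesketch

import (
	"context"
	"fmt"
	"strconv"

	"github.com/redis/go-redis/v9"
)

// Key-name patterns: only the active-set pattern is taken from key_cache.go;
// the others are placeholders assumed for this sketch.
const (
	keyGroupActive     = "group:%d:keys:active"
	keyGroupSequential = "group:%d:keys:sequential"
	keyGroupRandomMain = "group:%d:keys:random:main"
	keyGroupRandomCool = "group:%d:keys:random:cooldown"
	keyGroupLRU        = "group:%d:keys:lru"
)

// resyncMapping mirrors the pattern above: remove the key ID from every
// polling structure first, then re-add it only when the mapping is active,
// so the pipeline leaves the cache consistent regardless of its prior state.
func resyncMapping(ctx context.Context, rdb *redis.Client, groupID, keyID uint, active bool, lastUsedMilli int64) error {
	id := strconv.FormatUint(uint64(keyID), 10)
	pipe := rdb.Pipeline()

	pipe.SRem(ctx, fmt.Sprintf(keyGroupActive, groupID), id)
	pipe.LRem(ctx, fmt.Sprintf(keyGroupSequential, groupID), 0, id)
	pipe.SRem(ctx, fmt.Sprintf(keyGroupRandomMain, groupID), id)
	pipe.SRem(ctx, fmt.Sprintf(keyGroupRandomCool, groupID), id)
	pipe.ZRem(ctx, fmt.Sprintf(keyGroupLRU, groupID), id)

	if active {
		pipe.SAdd(ctx, fmt.Sprintf(keyGroupActive, groupID), id)
		pipe.LPush(ctx, fmt.Sprintf(keyGroupSequential, groupID), id)
		pipe.SAdd(ctx, fmt.Sprintf(keyGroupRandomMain, groupID), id)
		pipe.ZAdd(ctx, fmt.Sprintf(keyGroupLRU, groupID), redis.Z{Score: float64(lastUsedMilli), Member: id})
	}

	_, err := pipe.Exec(ctx)
	return err
}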

View File

@@ -23,7 +23,6 @@ func (r *gormKeyRepository) AddKeys(keys []models.APIKey) ([]models.APIKey, erro
keyHashes := make([]string, len(keys))
keyValueToHashMap := make(map[string]string)
for i, k := range keys {
// All incoming keys must have plaintext APIKey
if k.APIKey == "" {
return nil, fmt.Errorf("cannot add key at index %d: plaintext APIKey is empty", i)
}
@@ -35,7 +34,6 @@ func (r *gormKeyRepository) AddKeys(keys []models.APIKey) ([]models.APIKey, erro
var finalKeys []models.APIKey
err := r.db.Transaction(func(tx *gorm.DB) error {
var existingKeys []models.APIKey
// [MODIFIED] Query by hash to find existing keys.
if err := tx.Unscoped().Where("api_key_hash IN ?", keyHashes).Find(&existingKeys).Error; err != nil {
return err
}
@@ -69,24 +67,20 @@ func (r *gormKeyRepository) AddKeys(keys []models.APIKey) ([]models.APIKey, erro
}
}
if len(keysToCreate) > 0 {
// [MODIFIED] Create now only provides encrypted data and hash.
if err := tx.Clauses(clause.OnConflict{DoNothing: true}, clause.Returning{}).Create(&keysToCreate).Error; err != nil {
return err
}
}
// [MODIFIED] Final select uses hashes to retrieve all relevant keys.
if err := tx.Where("api_key_hash IN ?", keyHashes).Find(&finalKeys).Error; err != nil {
return err
}
// [CRITICAL] Decrypt all keys before returning them to the service layer.
return r.decryptKeys(finalKeys)
})
return finalKeys, err
}
func (r *gormKeyRepository) Update(key *models.APIKey) error {
// [CRITICAL] Before saving, check if the plaintext APIKey field was populated.
// This indicates a potential change that needs to be re-encrypted.
if key.APIKey != "" {
encryptedKey, err := r.crypto.Encrypt(key.APIKey)
if err != nil {
@@ -98,16 +92,16 @@ func (r *gormKeyRepository) Update(key *models.APIKey) error {
key.APIKeyHash = hex.EncodeToString(hash[:])
}
err := r.executeTransactionWithRetry(func(tx *gorm.DB) error {
// GORM automatically ignores `key.APIKey` because of the `gorm:"-"` tag.
return tx.Save(key).Error
})
if err != nil {
return err
}
// For the cache update, we need the plaintext. Decrypt if it's not already populated.
if err := r.decryptKey(key); err != nil {
r.logger.Warnf("DB updated key ID %d, but decryption for cache failed: %v", key.ID, err)
return nil // Continue without cache update if decryption fails.
return nil
}
if err := r.updateStoreCacheForKey(key); err != nil {
r.logger.Warnf("DB updated key ID %d, but cache update failed: %v", key.ID, err)
@@ -192,7 +186,6 @@ func (r *gormKeyRepository) GetKeysByIDs(ids []uint) ([]models.APIKey, error) {
if err != nil {
return nil, err
}
// [CRITICAL] Decrypt before returning.
return keys, r.decryptKeys(keys)
}
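For the deduplication flow above (query by api_key_hash, persist only encrypted key material), the lookup hash can be sketched on its own. SHA-256 is assumed here, matching the hash[:] / hex.EncodeToString usage in Update; the encryption step itself stays behind the project's crypto.Service.

package keysketch

import (
	"crypto/sha256"
	"encoding/hex"
)

// hashAPIKey derives the deterministic lookup hash stored in api_key_hash.
// Hashing the plaintext lets AddKeys deduplicate with a plain
// `WHERE api_key_hash IN ?` query while the key itself is stored only in
// encrypted form.
func hashAPIKey(plaintext string) string {
	sum := sha256.Sum256([]byte(plaintext))
	return hex.EncodeToString(sum[:])
}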

View File

@@ -1,4 +1,4 @@
// Filename: internal/repository/key_selector.go (final fixed version after review)
// Filename: internal/repository/key_selector.go
package repository
import (
@@ -18,39 +18,40 @@ import (
)
const (
CacheTTL = 5 * time.Minute
EmptyPoolPlaceholder = "EMPTY_POOL"
EmptyCacheTTL = 1 * time.Minute
CacheTTL = 5 * time.Minute
EmptyCacheTTL = 1 * time.Minute
)
// SelectOneActiveKey efficiently selects an available API key from the cache according to the configured polling strategy.
// SelectOneActiveKey selects an available API key from a single key group's cache according to the configured polling strategy.
func (r *gormKeyRepository) SelectOneActiveKey(ctx context.Context, group *models.KeyGroup) (*models.APIKey, *models.GroupAPIKeyMapping, error) {
if group == nil {
return nil, nil, fmt.Errorf("group cannot be nil")
}
var keyIDStr string
var err error
switch group.PollingStrategy {
case models.StrategySequential:
sequentialKey := fmt.Sprintf(KeyGroupSequential, group.ID)
keyIDStr, err = r.store.Rotate(ctx, sequentialKey)
case models.StrategyWeighted:
lruKey := fmt.Sprintf(KeyGroupLRU, group.ID)
results, zerr := r.store.ZRange(ctx, lruKey, 0, 0)
if zerr == nil && len(results) > 0 {
keyIDStr = results[0]
if zerr == nil {
if len(results) > 0 {
keyIDStr = results[0]
} else {
zerr = gorm.ErrRecordNotFound
}
}
err = zerr
case models.StrategyRandom:
mainPoolKey := fmt.Sprintf(KeyGroupRandomMain, group.ID)
cooldownPoolKey := fmt.Sprintf(KeyGroupRandomCooldown, group.ID)
keyIDStr, err = r.store.PopAndCycleSetMember(ctx, mainPoolKey, cooldownPoolKey)
default: // fall back to the basic random strategy when no strategy is specified
default:
activeKeySetKey := fmt.Sprintf(KeyGroup, group.ID)
keyIDStr, err = r.store.SRandMember(ctx, activeKeySetKey)
}
if err != nil {
if errors.Is(err, store.ErrNotFound) || errors.Is(err, gorm.ErrRecordNotFound) {
return nil, nil, gorm.ErrRecordNotFound
@@ -58,39 +59,44 @@ func (r *gormKeyRepository) SelectOneActiveKey(ctx context.Context, group *model
r.logger.WithError(err).Errorf("Failed to select key for group %d with strategy %s", group.ID, group.PollingStrategy)
return nil, nil, err
}
if keyIDStr == "" {
return nil, nil, gorm.ErrRecordNotFound
}
keyID, _ := strconv.ParseUint(keyIDStr, 10, 64)
keyID, parseErr := strconv.ParseUint(keyIDStr, 10, 64)
if parseErr != nil {
r.logger.WithError(parseErr).Errorf("Invalid key ID format in group %d cache: %s", group.ID, keyIDStr)
return nil, nil, fmt.Errorf("invalid key ID in cache: %w", parseErr)
}
apiKey, mapping, err := r.getKeyDetailsFromCache(ctx, uint(keyID), group.ID)
if err != nil {
r.logger.WithError(err).Warnf("Cache inconsistency: Failed to get details for selected key ID %d", keyID)
r.logger.WithError(err).Warnf("Cache inconsistency for key ID %d in group %d", keyID, group.ID)
return nil, nil, err
}
if group.PollingStrategy == models.StrategyWeighted {
go r.UpdateKeyUsageTimestamp(context.Background(), group.ID, uint(keyID))
go func() {
updateCtx, cancel := r.withTimeout(5 * time.Second)
defer cancel()
r.UpdateKeyUsageTimestamp(updateCtx, group.ID, uint(keyID))
}()
}
return apiKey, mapping, nil
}
// SelectOneActiveKeyFromBasePool is the new selector designed for smart aggregation mode
// SelectOneActiveKeyFromBasePool selects an available key from the smart aggregation pool
func (r *gormKeyRepository) SelectOneActiveKeyFromBasePool(ctx context.Context, pool *BasePool) (*models.APIKey, *models.KeyGroup, error) {
poolID := generatePoolID(pool.CandidateGroups)
if pool == nil || len(pool.CandidateGroups) == 0 {
return nil, nil, fmt.Errorf("invalid or empty base pool configuration")
}
poolID := r.generatePoolID(pool.CandidateGroups)
log := r.logger.WithField("pool_id", poolID)
if err := r.ensureBasePoolCacheExists(ctx, pool, poolID); err != nil {
log.WithError(err).Error("Failed to ensure BasePool cache exists.")
if !errors.Is(err, gorm.ErrRecordNotFound) {
log.WithError(err).Error("Failed to ensure BasePool cache exists")
}
return nil, nil, err
}
var keyIDStr string
var err error
switch pool.PollingStrategy {
case models.StrategySequential:
sequentialKey := fmt.Sprintf(BasePoolSequential, poolID)
@@ -98,8 +104,12 @@ func (r *gormKeyRepository) SelectOneActiveKeyFromBasePool(ctx context.Context,
case models.StrategyWeighted:
lruKey := fmt.Sprintf(BasePoolLRU, poolID)
results, zerr := r.store.ZRange(ctx, lruKey, 0, 0)
if zerr == nil && len(results) > 0 {
keyIDStr = results[0]
if zerr == nil {
if len(results) > 0 {
keyIDStr = results[0]
} else {
zerr = gorm.ErrRecordNotFound
}
}
err = zerr
case models.StrategyRandom:
@@ -107,13 +117,12 @@ func (r *gormKeyRepository) SelectOneActiveKeyFromBasePool(ctx context.Context,
cooldownPoolKey := fmt.Sprintf(BasePoolRandomCooldown, poolID)
keyIDStr, err = r.store.PopAndCycleSetMember(ctx, mainPoolKey, cooldownPoolKey)
default:
log.Warnf("Default polling strategy triggered inside selection. This should be rare.")
log.Warnf("Unknown polling strategy '%s'. Using sequential as fallback.", pool.PollingStrategy)
sequentialKey := fmt.Sprintf(BasePoolSequential, poolID)
keyIDStr, err = r.store.LIndex(ctx, sequentialKey, 0)
keyIDStr, err = r.store.Rotate(ctx, sequentialKey)
}
if err != nil {
if errors.Is(err, store.ErrNotFound) {
if errors.Is(err, store.ErrNotFound) || errors.Is(err, gorm.ErrRecordNotFound) {
return nil, nil, gorm.ErrRecordNotFound
}
log.WithError(err).Errorf("Failed to select key from BasePool with strategy %s", pool.PollingStrategy)
@@ -122,118 +131,246 @@ func (r *gormKeyRepository) SelectOneActiveKeyFromBasePool(ctx context.Context,
if keyIDStr == "" {
return nil, nil, gorm.ErrRecordNotFound
}
keyID, _ := strconv.ParseUint(keyIDStr, 10, 64)
for _, group := range pool.CandidateGroups {
apiKey, mapping, cacheErr := r.getKeyDetailsFromCache(ctx, uint(keyID), group.ID)
if cacheErr == nil && apiKey != nil && mapping != nil {
if pool.PollingStrategy == models.StrategyWeighted {
go r.updateKeyUsageTimestampForPool(context.Background(), poolID, uint(keyID))
}
return apiKey, group, nil
go func() {
bgCtx, cancel := r.withTimeout(5 * time.Second)
defer cancel()
r.refreshBasePoolHeartbeat(bgCtx, poolID)
}()
keyID, parseErr := strconv.ParseUint(keyIDStr, 10, 64)
if parseErr != nil {
log.WithError(parseErr).Errorf("Invalid key ID format in BasePool cache: %s", keyIDStr)
return nil, nil, fmt.Errorf("invalid key ID in cache: %w", parseErr)
}
keyToGroupMapKey := fmt.Sprintf("basepool:%s:key_to_group", poolID)
groupIDStr, err := r.store.HGet(ctx, keyToGroupMapKey, keyIDStr)
if err != nil {
log.WithError(err).Errorf("Cache inconsistency: KeyID %d found in pool but not in key-to-group map", keyID)
return nil, nil, errors.New("cache inconsistency: key has no origin group mapping")
}
groupID, parseErr := strconv.ParseUint(groupIDStr, 10, 64)
if parseErr != nil {
log.WithError(parseErr).Errorf("Invalid group ID format in key-to-group map for key %d: %s", keyID, groupIDStr)
return nil, nil, errors.New("cache inconsistency: invalid group id in mapping")
}
apiKey, _, err := r.getKeyDetailsFromCache(ctx, uint(keyID), uint(groupID))
if err != nil {
log.WithError(err).Warnf("Cache inconsistency: Failed to get details for key %d in mapped group %d", keyID, groupID)
return nil, nil, err
}
var originGroup *models.KeyGroup
for _, g := range pool.CandidateGroups {
if g.ID == uint(groupID) {
originGroup = g
break
}
}
log.Errorf("Cache inconsistency: Selected KeyID %d from BasePool but could not find its origin group.", keyID)
return nil, nil, errors.New("cache inconsistency: selected key has no origin group")
if originGroup == nil {
log.Errorf("Logic error: Mapped GroupID %d not found in pool's candidate groups list", groupID)
return nil, nil, errors.New("cache inconsistency: mapped group not in candidate list")
}
if pool.PollingStrategy == models.StrategyWeighted {
go func() {
bgCtx, cancel := r.withTimeout(5 * time.Second)
defer cancel()
r.updateKeyUsageTimestampForPool(bgCtx, poolID, uint(keyID))
}()
}
return apiKey, originGroup, nil
}
// ensureBasePoolCacheExists dynamically creates the BasePool's Redis structures
// ensureBasePoolCacheExists dynamically creates or validates the BasePool's Redis cache structures
func (r *gormKeyRepository) ensureBasePoolCacheExists(ctx context.Context, pool *BasePool, poolID string) error {
listKey := fmt.Sprintf(BasePoolSequential, poolID)
exists, err := r.store.Exists(ctx, listKey)
if err != nil {
r.logger.WithError(err).Errorf("Failed to check existence for pool_id '%s'", poolID)
return err
heartbeatKey := fmt.Sprintf("basepool:%s:heartbeat", poolID)
emptyMarkerKey := fmt.Sprintf("basepool:empty:%s", poolID)
// Pre-check so we can fail fast
if exists, _ := r.store.Exists(ctx, emptyMarkerKey); exists {
return gorm.ErrRecordNotFound
}
if exists {
val, err := r.store.LIndex(ctx, listKey, 0)
if err != nil {
r.logger.WithError(err).Warnf("Cache for pool_id '%s' exists but is unreadable. Forcing rebuild.", poolID)
} else {
if val == EmptyPoolPlaceholder {
return gorm.ErrRecordNotFound
}
return nil
}
}
lockKey := fmt.Sprintf("lock:basepool:%s", poolID)
acquired, err := r.store.SetNX(ctx, lockKey, []byte("1"), 10*time.Second)
if err != nil {
r.logger.WithError(err).Error("Failed to attempt acquiring distributed lock for basepool build.")
return err
}
if !acquired {
time.Sleep(100 * time.Millisecond)
return r.ensureBasePoolCacheExists(ctx, pool, poolID)
}
defer r.store.Del(context.Background(), lockKey)
if exists, _ := r.store.Exists(ctx, listKey); exists {
if exists, _ := r.store.Exists(ctx, heartbeatKey); exists {
return nil
}
r.logger.Infof("BasePool cache for pool_id '%s' not found or is unreadable. Building now...", poolID)
var allActiveKeyIDs []string
lruMembers := make(map[string]float64)
// Acquire the distributed lock
lockKey := fmt.Sprintf("lock:basepool:%s", poolID)
if err := r.acquireLock(ctx, lockKey); err != nil {
return err // acquireLock already logs and returns a descriptive error
}
defer r.releaseLock(context.Background(), lockKey)
// Double-checked locking
if exists, _ := r.store.Exists(ctx, emptyMarkerKey); exists {
return gorm.ErrRecordNotFound
}
if exists, _ := r.store.Exists(ctx, heartbeatKey); exists {
return nil
}
// Before the heavy work, check one last time whether the context has been cancelled
select {
case <-ctx.Done():
return ctx.Err()
default:
}
r.logger.Infof("Building BasePool cache for pool_id '%s'", poolID)
// Manually aggregate all keys while building the key-to-group map
keyToGroupMap := make(map[string]any)
allKeyIDsSet := make(map[string]struct{})
for _, group := range pool.CandidateGroups {
activeKeySetKey := fmt.Sprintf(KeyGroup, group.ID)
groupKeyIDs, err := r.store.SMembers(ctx, activeKeySetKey)
groupKeySet := fmt.Sprintf(KeyGroup, group.ID)
groupKeyIDs, err := r.store.SMembers(ctx, groupKeySet)
if err != nil {
r.logger.WithError(err).Errorf("FATAL: Failed to read active keys for group %d during BasePool build. Aborting build process for pool_id '%s'.", group.ID, poolID)
return err
r.logger.WithError(err).Warnf("Failed to get members for group %d during pool build", group.ID)
continue
}
allActiveKeyIDs = append(allActiveKeyIDs, groupKeyIDs...)
for _, keyIDStr := range groupKeyIDs {
keyID, _ := strconv.ParseUint(keyIDStr, 10, 64)
_, mapping, err := r.getKeyDetailsFromCache(ctx, uint(keyID), group.ID)
if err == nil && mapping != nil {
var score float64
if mapping.LastUsedAt != nil {
score = float64(mapping.LastUsedAt.UnixMilli())
}
lruMembers[keyIDStr] = score
groupIDStr := strconv.FormatUint(uint64(group.ID), 10)
for _, keyID := range groupKeyIDs {
if _, exists := allKeyIDsSet[keyID]; !exists {
allKeyIDsSet[keyID] = struct{}{}
keyToGroupMap[keyID] = groupIDStr
}
}
}
if len(allActiveKeyIDs) == 0 {
r.logger.Warnf("No active keys found for any candidate groups for pool_id '%s'. Setting empty pool placeholder.", poolID)
pipe := r.store.Pipeline(ctx)
pipe.LPush(listKey, EmptyPoolPlaceholder)
pipe.Expire(listKey, EmptyCacheTTL)
if err := pipe.Exec(); err != nil {
r.logger.WithError(err).Errorf("Failed to set empty pool placeholder for pool_id '%s'", poolID)
// Handle the empty-pool case
if len(allKeyIDsSet) == 0 {
emptyCacheTTL := time.Duration(r.config.Repository.BasePoolTTIMinutes) * time.Minute / 2
if emptyCacheTTL < time.Minute {
emptyCacheTTL = time.Minute
}
r.logger.Warnf("No active keys found for pool_id '%s', setting empty marker.", poolID)
if err := r.store.Set(ctx, emptyMarkerKey, []byte("1"), emptyCacheTTL); err != nil {
r.logger.WithError(err).Warnf("Failed to set empty marker for pool_id '%s'", poolID)
}
return gorm.ErrRecordNotFound
}
allActiveKeyIDs := make([]string, 0, len(allKeyIDsSet))
for keyID := range allKeyIDsSet {
allActiveKeyIDs = append(allActiveKeyIDs, keyID)
}
// Build all cache structures atomically via a pipeline
basePoolTTL := time.Duration(r.config.Repository.BasePoolTTLMinutes) * time.Minute
basePoolTTI := time.Duration(r.config.Repository.BasePoolTTIMinutes) * time.Minute
mainPoolKey := fmt.Sprintf(BasePoolRandomMain, poolID)
sequentialKey := fmt.Sprintf(BasePoolSequential, poolID)
cooldownKey := fmt.Sprintf(BasePoolRandomCooldown, poolID)
lruKey := fmt.Sprintf(BasePoolLRU, poolID)
keyToGroupMapKey := fmt.Sprintf("basepool:%s:key_to_group", poolID)
pipe := r.store.Pipeline(ctx)
pipe.LPush(fmt.Sprintf(BasePoolSequential, poolID), toInterfaceSlice(allActiveKeyIDs)...)
pipe.SAdd(fmt.Sprintf(BasePoolRandomMain, poolID), toInterfaceSlice(allActiveKeyIDs)...)
pipe.Expire(fmt.Sprintf(BasePoolSequential, poolID), CacheTTL)
pipe.Expire(fmt.Sprintf(BasePoolRandomMain, poolID), CacheTTL)
pipe.Expire(fmt.Sprintf(BasePoolRandomCooldown, poolID), CacheTTL)
pipe.Expire(fmt.Sprintf(BasePoolLRU, poolID), CacheTTL)
pipe.Del(mainPoolKey, sequentialKey, cooldownKey, lruKey, emptyMarkerKey, keyToGroupMapKey)
pipe.SAdd(mainPoolKey, r.toInterfaceSlice(allActiveKeyIDs)...)
pipe.LPush(sequentialKey, r.toInterfaceSlice(allActiveKeyIDs)...)
if len(keyToGroupMap) > 0 {
pipe.HSet(keyToGroupMapKey, keyToGroupMap)
pipe.Expire(keyToGroupMapKey, basePoolTTL)
}
pipe.Expire(mainPoolKey, basePoolTTL)
pipe.Expire(sequentialKey, basePoolTTL)
pipe.Expire(cooldownKey, basePoolTTL)
pipe.Expire(lruKey, basePoolTTL)
pipe.Set(heartbeatKey, []byte("1"), basePoolTTI)
if err := pipe.Exec(); err != nil {
r.logger.WithError(err).Errorf("Failed to populate polling structures for pool_id '%s'", poolID)
cleanupCtx, cancel := r.withTimeout(5 * time.Second)
defer cancel()
r.store.Del(cleanupCtx, mainPoolKey, sequentialKey, cooldownKey, lruKey, heartbeatKey, emptyMarkerKey, keyToGroupMapKey)
return err
}
// Asynchronously populate the LRU cache, passing in the already-built map
go r.populateBasePoolLRUCache(context.Background(), poolID, allActiveKeyIDs, keyToGroupMap)
return nil
}
if len(lruMembers) > 0 {
if err := r.store.ZAdd(ctx, fmt.Sprintf(BasePoolLRU, poolID), lruMembers); err != nil {
r.logger.WithError(err).Warnf("Failed to populate LRU cache for pool_id '%s'", poolID)
// --- Helper methods ---
// acquireLock wraps distributed-lock acquisition with retries and exponential backoff.
func (r *gormKeyRepository) acquireLock(ctx context.Context, lockKey string) error {
const (
lockTTL = 30 * time.Second
lockMaxRetries = 5
lockBaseBackoff = 50 * time.Millisecond
)
for i := 0; i < lockMaxRetries; i++ {
acquired, err := r.store.SetNX(ctx, lockKey, []byte("1"), lockTTL)
if err != nil {
r.logger.WithError(err).Error("Failed to attempt acquiring distributed lock")
return err
}
if acquired {
return nil
}
time.Sleep(lockBaseBackoff * (1 << i))
}
return fmt.Errorf("failed to acquire lock for key '%s' after %d retries", lockKey, lockMaxRetries)
}
// releaseLock releases the distributed lock.
func (r *gormKeyRepository) releaseLock(ctx context.Context, lockKey string) {
if err := r.store.Del(ctx, lockKey); err != nil {
r.logger.WithError(err).Errorf("Failed to release distributed lock for key '%s'", lockKey)
}
}
// withTimeout is a thin wrapper around context.WithTimeout that eases testing and mocking.
func (r *gormKeyRepository) withTimeout(duration time.Duration) (context.Context, context.CancelFunc) {
return context.WithTimeout(context.Background(), duration)
}
// refreshBasePoolHeartbeat asynchronously refreshes the heartbeat key's TTI
func (r *gormKeyRepository) refreshBasePoolHeartbeat(ctx context.Context, poolID string) {
basePoolTTI := time.Duration(r.config.Repository.BasePoolTTIMinutes) * time.Minute
heartbeatKey := fmt.Sprintf("basepool:%s:heartbeat", poolID)
// Refresh via EXPIRE; if the key does not exist this is a safe no-op
if err := r.store.Expire(ctx, heartbeatKey, basePoolTTI); err != nil {
if ctx.Err() == nil { // avoid noisy errors after the context has been cancelled
r.logger.WithError(err).Warnf("Failed to refresh heartbeat for pool_id '%s'", poolID)
}
}
}
// populateBasePoolLRUCache asynchronously populates the BasePool's LRU cache structure
func (r *gormKeyRepository) populateBasePoolLRUCache(
parentCtx context.Context,
currentPoolID string,
keys []string,
keyToGroupMap map[string]any,
) {
lruMembers := make(map[string]float64, len(keys))
for _, keyIDStr := range keys {
select {
case <-parentCtx.Done():
return
default:
}
groupIDStr, ok := keyToGroupMap[keyIDStr].(string)
if !ok {
continue
}
keyID, _ := strconv.ParseUint(keyIDStr, 10, 64)
groupID, _ := strconv.ParseUint(groupIDStr, 10, 64)
mappingKey := fmt.Sprintf(KeyMapping, groupID, keyID)
data, err := r.store.Get(parentCtx, mappingKey)
if err != nil {
continue
}
var mapping models.GroupAPIKeyMapping
if json.Unmarshal(data, &mapping) == nil {
var score float64
if mapping.LastUsedAt != nil {
score = float64(mapping.LastUsedAt.UnixMilli())
}
lruMembers[keyIDStr] = score
}
}
if len(lruMembers) > 0 {
lruKey := fmt.Sprintf(BasePoolLRU, currentPoolID)
if err := r.store.ZAdd(parentCtx, lruKey, lruMembers); err != nil {
if parentCtx.Err() == nil {
r.logger.WithError(err).Warnf("Failed to populate LRU cache for pool '%s'", currentPoolID)
}
}
}
return nil
}
// updateKeyUsageTimestampForPool updates the BasePool's LRU ZSET
func (r *gormKeyRepository) updateKeyUsageTimestampForPool(ctx context.Context, poolID string, keyID uint) {
lruKey := fmt.Sprintf(BasePoolLRU, poolID)
err := r.store.ZAdd(ctx, lruKey, map[string]float64{
strconv.FormatUint(uint64(keyID), 10): nowMilli(),
strconv.FormatUint(uint64(keyID), 10): r.nowMilli(),
})
if err != nil {
r.logger.WithError(err).Warnf("Failed to update key usage for pool %s", poolID)
@@ -241,20 +378,19 @@ func (r *gormKeyRepository) updateKeyUsageTimestampForPool(ctx context.Context,
}
// generatePoolID derives a stable, unique string ID from the list of candidate group IDs
func generatePoolID(groups []*models.KeyGroup) string {
func (r *gormKeyRepository) generatePoolID(groups []*models.KeyGroup) string {
ids := make([]int, len(groups))
for i, g := range groups {
ids[i] = int(g.ID)
}
sort.Ints(ids)
h := sha1.New()
io.WriteString(h, fmt.Sprintf("%v", ids))
return fmt.Sprintf("%x", h.Sum(nil))
}
// toInterfaceSlice is a type-conversion helper
func toInterfaceSlice(slice []string) []interface{} {
func (r *gormKeyRepository) toInterfaceSlice(slice []string) []interface{} {
result := make([]interface{}, len(slice))
for i, v := range slice {
result[i] = v
@@ -263,7 +399,7 @@ func toInterfaceSlice(slice []string) []interface{} {
}
// nowMilli returns the current Unix timestamp in milliseconds, used by the LRU/Weighted strategies
func nowMilli() float64 {
func (r *gormKeyRepository) nowMilli() float64 {
return float64(time.Now().UnixMilli())
}
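As a standalone illustration of generatePoolID: because the candidate group IDs are sorted before hashing, the same set of groups yields the same pool ID regardless of the order in which the caller supplies them. A minimal sketch of that derivation for a plain ID slice:

package poolsketch

import (
	"crypto/sha1"
	"fmt"
	"io"
	"sort"
)

// poolID reproduces the derivation used by generatePoolID.
func poolID(groupIDs []int) string {
	ids := append([]int(nil), groupIDs...) // copy so the caller's slice is untouched
	sort.Ints(ids)
	h := sha1.New()
	io.WriteString(h, fmt.Sprintf("%v", ids)) // e.g. "[1 2 3]"
	return fmt.Sprintf("%x", h.Sum(nil))
}

// poolID([]int{3, 1, 2}) and poolID([]int{2, 3, 1}) both hash "[1 2 3]" and
// therefore produce the same pool ID.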

View File

@@ -1,8 +1,9 @@
// Filename: internal/repository/repository.go (final fixed version after review)
// Filename: internal/repository/repository.go
package repository
import (
"context"
"gemini-balancer/internal/config"
"gemini-balancer/internal/crypto"
"gemini-balancer/internal/errors"
"gemini-balancer/internal/models"
@@ -87,18 +88,20 @@ type gormKeyRepository struct {
store store.Store
logger *logrus.Entry
crypto *crypto.Service
config *config.Config
}
type gormGroupRepository struct {
db *gorm.DB
}
func NewKeyRepository(db *gorm.DB, s store.Store, logger *logrus.Logger, crypto *crypto.Service) KeyRepository {
func NewKeyRepository(db *gorm.DB, s store.Store, logger *logrus.Logger, crypto *crypto.Service, cfg *config.Config) KeyRepository {
return &gormKeyRepository{
db: db,
store: s,
logger: logger.WithField("component", "repository.key🔗"),
crypto: crypto,
config: cfg,
}
}
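Call sites that construct the repository now pass the application config as well, since the BasePool TTL/TTI values are read from it. A hedged wiring sketch follows; the repository and store import paths are assumed to follow the same module layout as the imports shown above, and every dependency is assumed to be created elsewhere during application start-up.

package bootstrap

import (
	"gemini-balancer/internal/config"
	"gemini-balancer/internal/crypto"
	"gemini-balancer/internal/repository"
	"gemini-balancer/internal/store"

	"github.com/sirupsen/logrus"
	"gorm.io/gorm"
)

// buildKeyRepository shows the updated constructor call with the extra *config.Config argument.
func buildKeyRepository(db *gorm.DB, s store.Store, logger *logrus.Logger, cs *crypto.Service, cfg *config.Config) repository.KeyRepository {
	return repository.NewKeyRepository(db, s, logger, cs, cfg)
}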