245 lines
8.7 KiB
Go
245 lines
8.7 KiB
Go
// Filename: internal/repository/key_cache.go (final version)
|
||
package repository
|
||
|
||
import (
|
||
"context"
|
||
"encoding/json"
|
||
"fmt"
|
||
"gemini-balancer/internal/models"
|
||
"strconv"
|
||
)
|
||
|
||
// --- Redis key template definitions ---

// Core key templates. KeyGroup takes a group ID, KeyDetails takes an API key
// ID, and KeyMapping takes a group ID followed by an API key ID.
const (
	// KeyGroup is the Redis set holding the IDs of a group's active keys.
	KeyGroup = "group:%d:keys:active"

	// KeyDetails is the K-V entry holding one API key serialized as JSON.
	KeyDetails = "key:%d:details"

	// KeyMapping is the K-V entry holding one group/key mapping serialized as JSON.
	KeyMapping = "mapping:%d:%d"
)

// Per-group polling structures, each formatted with a group ID: a list
// (sequential), a sorted set scored by last-use time (LRU), and two sets
// (random main / random cooldown).
const (
	KeyGroupSequential     = "group:%d:keys:sequential"
	KeyGroupLRU            = "group:%d:keys:lru"
	KeyGroupRandomMain     = "group:%d:keys:random:main"
	KeyGroupRandomCooldown = "group:%d:keys:random:cooldown"
)

// Base-pool polling structures, formatted with a string pool identifier.
// NOTE(review): these are not referenced in this file; confirm their users elsewhere.
const (
	BasePoolSequential     = "basepool:%s:keys:sequential"
	BasePoolLRU            = "basepool:%s:keys:lru"
	BasePoolRandomMain     = "basepool:%s:keys:random:main"
	BasePoolRandomCooldown = "basepool:%s:keys:random:cooldown"
)
|
||
|
||
// LoadAllKeysToStore 从数据库加载所有密钥和映射关系,并完整重建Redis缓存。
|
||
func (r *gormKeyRepository) LoadAllKeysToStore(ctx context.Context) error {
|
||
r.logger.Info("Starting full cache rebuild for all keys and polling structures.")
|
||
|
||
var allMappings []*models.GroupAPIKeyMapping
|
||
if err := r.db.Preload("APIKey").Find(&allMappings).Error; err != nil {
|
||
return fmt.Errorf("failed to load mappings with preloaded APIKeys: %w", err)
|
||
}
|
||
|
||
// 1. 批量解密所有涉及的密钥
|
||
keyMap := make(map[uint]*models.APIKey)
|
||
for _, m := range allMappings {
|
||
if m.APIKey != nil {
|
||
keyMap[m.APIKey.ID] = m.APIKey
|
||
}
|
||
}
|
||
keysToDecrypt := make([]models.APIKey, 0, len(keyMap))
|
||
for _, k := range keyMap {
|
||
keysToDecrypt = append(keysToDecrypt, *k)
|
||
}
|
||
if err := r.decryptKeys(keysToDecrypt); err != nil {
|
||
r.logger.WithError(err).Error("Batch decryption failed during cache rebuild.")
|
||
// 即使解密失败,也继续尝试加载未加密或已解密的部分
|
||
}
|
||
decryptedKeyMap := make(map[uint]models.APIKey)
|
||
for _, k := range keysToDecrypt {
|
||
decryptedKeyMap[k.ID] = k
|
||
}
|
||
|
||
// 2. 清理所有分组的旧轮询结构
|
||
pipe := r.store.Pipeline(ctx)
|
||
var allGroups []*models.KeyGroup
|
||
if err := r.db.Find(&allGroups).Error; err == nil {
|
||
for _, group := range allGroups {
|
||
pipe.Del(
|
||
fmt.Sprintf(KeyGroup, group.ID),
|
||
fmt.Sprintf(KeyGroupSequential, group.ID),
|
||
fmt.Sprintf(KeyGroupLRU, group.ID),
|
||
fmt.Sprintf(KeyGroupRandomMain, group.ID),
|
||
fmt.Sprintf(KeyGroupRandomCooldown, group.ID),
|
||
)
|
||
}
|
||
} else {
|
||
r.logger.WithError(err).Error("Failed to get groups for cache cleanup; proceeding with rebuild.")
|
||
}
|
||
|
||
// 3. 准备批量更新数据
|
||
activeKeysByGroup := make(map[uint][]*models.GroupAPIKeyMapping)
|
||
detailsToSet := make(map[string]any)
|
||
|
||
for _, mapping := range allMappings {
|
||
if mapping.APIKey == nil {
|
||
continue
|
||
}
|
||
decryptedKey, ok := decryptedKeyMap[mapping.APIKeyID]
|
||
if !ok {
|
||
continue // 跳过解密失败的密钥
|
||
}
|
||
|
||
// 准备 KeyDetails 和 KeyMapping 的 MSet 数据
|
||
keyJSON, _ := json.Marshal(decryptedKey)
|
||
detailsToSet[fmt.Sprintf(KeyDetails, decryptedKey.ID)] = keyJSON
|
||
mappingJSON, _ := json.Marshal(mapping)
|
||
detailsToSet[fmt.Sprintf(KeyMapping, mapping.KeyGroupID, decryptedKey.ID)] = mappingJSON
|
||
|
||
if mapping.Status == models.StatusActive {
|
||
activeKeysByGroup[mapping.KeyGroupID] = append(activeKeysByGroup[mapping.KeyGroupID], mapping)
|
||
}
|
||
}
|
||
|
||
// 4. 使用 MSet 批量写入详情和映射缓存
|
||
if len(detailsToSet) > 0 {
|
||
if err := r.store.MSet(ctx, detailsToSet); err != nil {
|
||
r.logger.WithError(err).Error("Failed to MSet key details and mappings during cache rebuild.")
|
||
}
|
||
}
|
||
|
||
// 5. 在Pipeline中重建所有分组的轮询结构
|
||
for groupID, activeMappings := range activeKeysByGroup {
|
||
if len(activeMappings) == 0 {
|
||
continue
|
||
}
|
||
var activeKeyIDs []interface{}
|
||
lruMembers := make(map[string]float64)
|
||
for _, mapping := range activeMappings {
|
||
keyIDStr := strconv.FormatUint(uint64(mapping.APIKeyID), 10)
|
||
activeKeyIDs = append(activeKeyIDs, keyIDStr)
|
||
var score float64
|
||
if mapping.LastUsedAt != nil {
|
||
score = float64(mapping.LastUsedAt.UnixMilli())
|
||
}
|
||
lruMembers[keyIDStr] = score
|
||
}
|
||
pipe.SAdd(fmt.Sprintf(KeyGroup, groupID), activeKeyIDs...)
|
||
pipe.LPush(fmt.Sprintf(KeyGroupSequential, groupID), activeKeyIDs...)
|
||
pipe.SAdd(fmt.Sprintf(KeyGroupRandomMain, groupID), activeKeyIDs...)
|
||
pipe.ZAdd(fmt.Sprintf(KeyGroupLRU, groupID), lruMembers)
|
||
}
|
||
|
||
// 6. 执行Pipeline
|
||
if err := pipe.Exec(); err != nil {
|
||
return fmt.Errorf("pipeline execution for polling structures failed: %w", err)
|
||
}
|
||
|
||
r.logger.Info("Full cache rebuild completed successfully.")
|
||
return nil
|
||
}
|
||
|
||
// updateStoreCacheForKey 更新单个APIKey的详情缓存 (K-V)。
|
||
func (r *gormKeyRepository) updateStoreCacheForKey(key *models.APIKey) error {
|
||
if err := r.decryptKey(key); err != nil {
|
||
return fmt.Errorf("failed to decrypt key %d for cache update: %w", key.ID, err)
|
||
}
|
||
keyJSON, err := json.Marshal(key)
|
||
if err != nil {
|
||
return fmt.Errorf("failed to marshal key %d for cache update: %w", key.ID, err)
|
||
}
|
||
return r.store.Set(context.Background(), fmt.Sprintf(KeyDetails, key.ID), keyJSON, 0)
|
||
}
|
||
|
||
// removeStoreCacheForKey 从所有缓存结构中彻底移除一个APIKey。
|
||
func (r *gormKeyRepository) removeStoreCacheForKey(ctx context.Context, key *models.APIKey) error {
|
||
groupIDs, err := r.GetGroupsForKey(ctx, key.ID)
|
||
if err != nil {
|
||
r.logger.WithError(err).Warnf("Failed to get groups for key %d during cache removal, cleanup may be partial.", key.ID)
|
||
}
|
||
|
||
pipe := r.store.Pipeline(ctx)
|
||
pipe.Del(fmt.Sprintf(KeyDetails, key.ID))
|
||
|
||
keyIDStr := strconv.FormatUint(uint64(key.ID), 10)
|
||
for _, groupID := range groupIDs {
|
||
pipe.Del(fmt.Sprintf(KeyMapping, groupID, key.ID))
|
||
pipe.SRem(fmt.Sprintf(KeyGroup, groupID), keyIDStr)
|
||
pipe.LRem(fmt.Sprintf(KeyGroupSequential, groupID), 0, keyIDStr)
|
||
pipe.SRem(fmt.Sprintf(KeyGroupRandomMain, groupID), keyIDStr)
|
||
pipe.SRem(fmt.Sprintf(KeyGroupRandomCooldown, groupID), keyIDStr)
|
||
pipe.ZRem(fmt.Sprintf(KeyGroupLRU, groupID), keyIDStr)
|
||
}
|
||
|
||
return pipe.Exec()
|
||
}
|
||
|
||
// updateStoreCacheForMapping 根据单个映射关系的状态,原子性地更新所有相关的缓存结构。
|
||
func (r *gormKeyRepository) updateStoreCacheForMapping(mapping *models.GroupAPIKeyMapping) error {
|
||
keyIDStr := strconv.FormatUint(uint64(mapping.APIKeyID), 10)
|
||
groupID := mapping.KeyGroupID
|
||
ctx := context.Background()
|
||
|
||
pipe := r.store.Pipeline(ctx)
|
||
|
||
// 统一、无条件地从所有轮询结构中移除,确保状态清洁
|
||
pipe.SRem(fmt.Sprintf(KeyGroup, groupID), keyIDStr)
|
||
pipe.LRem(fmt.Sprintf(KeyGroupSequential, groupID), 0, keyIDStr)
|
||
pipe.SRem(fmt.Sprintf(KeyGroupRandomMain, groupID), keyIDStr)
|
||
pipe.SRem(fmt.Sprintf(KeyGroupRandomCooldown, groupID), keyIDStr)
|
||
pipe.ZRem(fmt.Sprintf(KeyGroupLRU, groupID), keyIDStr)
|
||
|
||
// 如果新状态是 Active,则重新添加到所有轮询结构中
|
||
if mapping.Status == models.StatusActive {
|
||
pipe.SAdd(fmt.Sprintf(KeyGroup, groupID), keyIDStr)
|
||
pipe.LPush(fmt.Sprintf(KeyGroupSequential, groupID), keyIDStr)
|
||
pipe.SAdd(fmt.Sprintf(KeyGroupRandomMain, groupID), keyIDStr)
|
||
|
||
var score float64
|
||
if mapping.LastUsedAt != nil {
|
||
score = float64(mapping.LastUsedAt.UnixMilli())
|
||
}
|
||
pipe.ZAdd(fmt.Sprintf(KeyGroupLRU, groupID), map[string]float64{keyIDStr: score})
|
||
}
|
||
|
||
// 无论状态如何,都更新映射详情的 K-V 缓存
|
||
mappingJSON, err := json.Marshal(mapping)
|
||
if err != nil {
|
||
return fmt.Errorf("failed to marshal mapping: %w", err)
|
||
}
|
||
pipe.Set(fmt.Sprintf(KeyMapping, groupID, mapping.APIKeyID), mappingJSON, 0)
|
||
|
||
return pipe.Exec()
|
||
}
|
||
|
||
// HandleCacheUpdateEventBatch 批量、原子性地更新多个映射关系的缓存。
|
||
func (r *gormKeyRepository) HandleCacheUpdateEventBatch(ctx context.Context, mappings []*models.GroupAPIKeyMapping) error {
|
||
if len(mappings) == 0 {
|
||
return nil
|
||
}
|
||
|
||
pipe := r.store.Pipeline(ctx)
|
||
|
||
for _, mapping := range mappings {
|
||
keyIDStr := strconv.FormatUint(uint64(mapping.APIKeyID), 10)
|
||
groupID := mapping.KeyGroupID
|
||
|
||
// 对于批处理中的每一个mapping,都执行完整的、正确的“先删后增”逻辑
|
||
pipe.SRem(fmt.Sprintf(KeyGroup, groupID), keyIDStr)
|
||
pipe.LRem(fmt.Sprintf(KeyGroupSequential, groupID), 0, keyIDStr)
|
||
pipe.SRem(fmt.Sprintf(KeyGroupRandomMain, groupID), keyIDStr)
|
||
pipe.SRem(fmt.Sprintf(KeyGroupRandomCooldown, groupID), keyIDStr)
|
||
pipe.ZRem(fmt.Sprintf(KeyGroupLRU, groupID), keyIDStr)
|
||
|
||
if mapping.Status == models.StatusActive {
|
||
pipe.SAdd(fmt.Sprintf(KeyGroup, groupID), keyIDStr)
|
||
pipe.LPush(fmt.Sprintf(KeyGroupSequential, groupID), keyIDStr)
|
||
pipe.SAdd(fmt.Sprintf(KeyGroupRandomMain, groupID), keyIDStr)
|
||
|
||
var score float64
|
||
if mapping.LastUsedAt != nil {
|
||
score = float64(mapping.LastUsedAt.UnixMilli())
|
||
}
|
||
pipe.ZAdd(fmt.Sprintf(KeyGroupLRU, groupID), map[string]float64{keyIDStr: score})
|
||
}
|
||
|
||
mappingJSON, _ := json.Marshal(mapping) // 在批处理中忽略单个marshal错误,以保证大部分更新成功
|
||
pipe.Set(fmt.Sprintf(KeyMapping, groupID, mapping.APIKeyID), mappingJSON, 0)
|
||
}
|
||
|
||
return pipe.Exec()
|
||
}
|