// Filename: internal/repository/key_cache.go (final version)

package repository

import (
	"context"
	"encoding/json"
	"fmt"
	"strconv"

	"gemini-balancer/internal/models"
)

// --- Redis key constants ---
const (
	KeyGroup               = "group:%d:keys:active"
	KeyDetails             = "key:%d:details"
	KeyMapping             = "mapping:%d:%d"
	KeyGroupSequential     = "group:%d:keys:sequential"
	KeyGroupLRU            = "group:%d:keys:lru"
	KeyGroupRandomMain     = "group:%d:keys:random:main"
	KeyGroupRandomCooldown = "group:%d:keys:random:cooldown"
	BasePoolSequential     = "basepool:%s:keys:sequential"
	BasePoolLRU            = "basepool:%s:keys:lru"
	BasePoolRandomMain     = "basepool:%s:keys:random:main"
	BasePoolRandomCooldown = "basepool:%s:keys:random:cooldown"
)

// LoadAllKeysToStore loads every API key and group mapping from the database and fully rebuilds the Redis cache.
func (r *gormKeyRepository) LoadAllKeysToStore(ctx context.Context) error {
	r.logger.Info("Starting full cache rebuild for all keys and polling structures.")

	var allMappings []*models.GroupAPIKeyMapping
	if err := r.db.Preload("APIKey").Find(&allMappings).Error; err != nil {
		return fmt.Errorf("failed to load mappings with preloaded APIKeys: %w", err)
	}

	// 1. Batch-decrypt every key referenced by the mappings.
	keyMap := make(map[uint]*models.APIKey)
	for _, m := range allMappings {
		if m.APIKey != nil {
			keyMap[m.APIKey.ID] = m.APIKey
		}
	}
	keysToDecrypt := make([]models.APIKey, 0, len(keyMap))
	for _, k := range keyMap {
		keysToDecrypt = append(keysToDecrypt, *k)
	}
	if err := r.decryptKeys(keysToDecrypt); err != nil {
		r.logger.WithError(err).Error("Batch decryption failed during cache rebuild.")
		// Even if decryption fails, keep loading the unencrypted or already-decrypted keys.
	}
	decryptedKeyMap := make(map[uint]models.APIKey)
	for _, k := range keysToDecrypt {
		decryptedKeyMap[k.ID] = k
	}

	// 2. Clear the stale polling structures of every group.
	pipe := r.store.Pipeline(ctx)
	var allGroups []*models.KeyGroup
	if err := r.db.Find(&allGroups).Error; err == nil {
		for _, group := range allGroups {
			pipe.Del(
				fmt.Sprintf(KeyGroup, group.ID),
				fmt.Sprintf(KeyGroupSequential, group.ID),
				fmt.Sprintf(KeyGroupLRU, group.ID),
				fmt.Sprintf(KeyGroupRandomMain, group.ID),
				fmt.Sprintf(KeyGroupRandomCooldown, group.ID),
			)
		}
	} else {
		r.logger.WithError(err).Error("Failed to get groups for cache cleanup; proceeding with rebuild.")
	}

	// 3. Prepare the batch update payloads.
	activeKeysByGroup := make(map[uint][]*models.GroupAPIKeyMapping)
	detailsToSet := make(map[string]any)
	for _, mapping := range allMappings {
		if mapping.APIKey == nil {
			continue
		}
		decryptedKey, ok := decryptedKeyMap[mapping.APIKeyID]
		if !ok {
			continue // skip keys that failed to decrypt
		}

		// Prepare the MSet payloads for KeyDetails and KeyMapping.
		keyJSON, _ := json.Marshal(decryptedKey)
		detailsToSet[fmt.Sprintf(KeyDetails, decryptedKey.ID)] = keyJSON
		mappingJSON, _ := json.Marshal(mapping)
		detailsToSet[fmt.Sprintf(KeyMapping, mapping.KeyGroupID, decryptedKey.ID)] = mappingJSON

		if mapping.Status == models.StatusActive {
			activeKeysByGroup[mapping.KeyGroupID] = append(activeKeysByGroup[mapping.KeyGroupID], mapping)
		}
	}

	// 4. Write the key-detail and mapping caches in one MSet.
	if len(detailsToSet) > 0 {
		if err := r.store.MSet(ctx, detailsToSet); err != nil {
			r.logger.WithError(err).Error("Failed to MSet key details and mappings during cache rebuild.")
		}
	}
	// 5. Rebuild the polling structures of every group inside the pipeline.
	for groupID, activeMappings := range activeKeysByGroup {
		if len(activeMappings) == 0 {
			continue
		}
		var activeKeyIDs []any
		lruMembers := make(map[string]float64)
		for _, mapping := range activeMappings {
			keyIDStr := strconv.FormatUint(uint64(mapping.APIKeyID), 10)
			activeKeyIDs = append(activeKeyIDs, keyIDStr)
			var score float64
			if mapping.LastUsedAt != nil {
				score = float64(mapping.LastUsedAt.UnixMilli())
			}
			lruMembers[keyIDStr] = score
		}
		pipe.SAdd(fmt.Sprintf(KeyGroup, groupID), activeKeyIDs...)
		pipe.LPush(fmt.Sprintf(KeyGroupSequential, groupID), activeKeyIDs...)
		pipe.SAdd(fmt.Sprintf(KeyGroupRandomMain, groupID), activeKeyIDs...)
		pipe.ZAdd(fmt.Sprintf(KeyGroupLRU, groupID), lruMembers)
	}

	// 6. Execute the pipeline.
	if err := pipe.Exec(); err != nil {
		return fmt.Errorf("pipeline execution for polling structures failed: %w", err)
	}

	r.logger.Info("Full cache rebuild completed successfully.")
	return nil
}

// updateStoreCacheForKey refreshes the detail (K-V) cache of a single APIKey.
func (r *gormKeyRepository) updateStoreCacheForKey(key *models.APIKey) error {
	if err := r.decryptKey(key); err != nil {
		return fmt.Errorf("failed to decrypt key %d for cache update: %w", key.ID, err)
	}
	keyJSON, err := json.Marshal(key)
	if err != nil {
		return fmt.Errorf("failed to marshal key %d for cache update: %w", key.ID, err)
	}
	return r.store.Set(context.Background(), fmt.Sprintf(KeyDetails, key.ID), keyJSON, 0)
}

// removeStoreCacheForKey removes an APIKey from every cache structure.
func (r *gormKeyRepository) removeStoreCacheForKey(ctx context.Context, key *models.APIKey) error {
	groupIDs, err := r.GetGroupsForKey(ctx, key.ID)
	if err != nil {
		r.logger.WithError(err).Warnf("Failed to get groups for key %d during cache removal, cleanup may be partial.", key.ID)
	}

	pipe := r.store.Pipeline(ctx)
	pipe.Del(fmt.Sprintf(KeyDetails, key.ID))

	keyIDStr := strconv.FormatUint(uint64(key.ID), 10)
	for _, groupID := range groupIDs {
		pipe.Del(fmt.Sprintf(KeyMapping, groupID, key.ID))
		pipe.SRem(fmt.Sprintf(KeyGroup, groupID), keyIDStr)
		pipe.LRem(fmt.Sprintf(KeyGroupSequential, groupID), 0, keyIDStr)
		pipe.SRem(fmt.Sprintf(KeyGroupRandomMain, groupID), keyIDStr)
		pipe.SRem(fmt.Sprintf(KeyGroupRandomCooldown, groupID), keyIDStr)
		pipe.ZRem(fmt.Sprintf(KeyGroupLRU, groupID), keyIDStr)
	}
	return pipe.Exec()
}

// updateStoreCacheForMapping atomically updates every related cache structure based on the status of a single mapping.
func (r *gormKeyRepository) updateStoreCacheForMapping(mapping *models.GroupAPIKeyMapping) error {
	keyIDStr := strconv.FormatUint(uint64(mapping.APIKeyID), 10)
	groupID := mapping.KeyGroupID
	ctx := context.Background()
	pipe := r.store.Pipeline(ctx)

	// Unconditionally remove the key from every polling structure first so the state stays clean.
	pipe.SRem(fmt.Sprintf(KeyGroup, groupID), keyIDStr)
	pipe.LRem(fmt.Sprintf(KeyGroupSequential, groupID), 0, keyIDStr)
	pipe.SRem(fmt.Sprintf(KeyGroupRandomMain, groupID), keyIDStr)
	pipe.SRem(fmt.Sprintf(KeyGroupRandomCooldown, groupID), keyIDStr)
	pipe.ZRem(fmt.Sprintf(KeyGroupLRU, groupID), keyIDStr)

	// If the new status is Active, re-add the key to every polling structure.
	if mapping.Status == models.StatusActive {
		pipe.SAdd(fmt.Sprintf(KeyGroup, groupID), keyIDStr)
		pipe.LPush(fmt.Sprintf(KeyGroupSequential, groupID), keyIDStr)
		pipe.SAdd(fmt.Sprintf(KeyGroupRandomMain, groupID), keyIDStr)
		var score float64
		if mapping.LastUsedAt != nil {
			score = float64(mapping.LastUsedAt.UnixMilli())
		}
		pipe.ZAdd(fmt.Sprintf(KeyGroupLRU, groupID), map[string]float64{keyIDStr: score})
	}

	// Regardless of the status, refresh the mapping's K-V detail cache.
	mappingJSON, err := json.Marshal(mapping)
	if err != nil {
		return fmt.Errorf("failed to marshal mapping: %w", err)
	}
	pipe.Set(fmt.Sprintf(KeyMapping, groupID, mapping.APIKeyID), mappingJSON, 0)
	return pipe.Exec()
}

// HandleCacheUpdateEventBatch atomically updates the caches of multiple mappings in a single batch.
func (r *gormKeyRepository) HandleCacheUpdateEventBatch(ctx context.Context, mappings []*models.GroupAPIKeyMapping) error {
	if len(mappings) == 0 {
		return nil
	}

	pipe := r.store.Pipeline(ctx)
	for _, mapping := range mappings {
		keyIDStr := strconv.FormatUint(uint64(mapping.APIKeyID), 10)
		groupID := mapping.KeyGroupID

		// For every mapping in the batch, run the full "remove first, then re-add" logic.
		pipe.SRem(fmt.Sprintf(KeyGroup, groupID), keyIDStr)
		pipe.LRem(fmt.Sprintf(KeyGroupSequential, groupID), 0, keyIDStr)
		pipe.SRem(fmt.Sprintf(KeyGroupRandomMain, groupID), keyIDStr)
		pipe.SRem(fmt.Sprintf(KeyGroupRandomCooldown, groupID), keyIDStr)
		pipe.ZRem(fmt.Sprintf(KeyGroupLRU, groupID), keyIDStr)

		if mapping.Status == models.StatusActive {
			pipe.SAdd(fmt.Sprintf(KeyGroup, groupID), keyIDStr)
			pipe.LPush(fmt.Sprintf(KeyGroupSequential, groupID), keyIDStr)
			pipe.SAdd(fmt.Sprintf(KeyGroupRandomMain, groupID), keyIDStr)
			var score float64
			if mapping.LastUsedAt != nil {
				score = float64(mapping.LastUsedAt.UnixMilli())
			}
			pipe.ZAdd(fmt.Sprintf(KeyGroupLRU, groupID), map[string]float64{keyIDStr: score})
		}

		// Ignore individual marshal errors within the batch so the remaining updates still succeed,
		// but skip the K-V write for this entry rather than caching a nil payload.
		if mappingJSON, err := json.Marshal(mapping); err == nil {
			pipe.Set(fmt.Sprintf(KeyMapping, groupID, mapping.APIKeyID), mappingJSON, 0)
		}
	}
	return pipe.Exec()
}
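
// --- Usage sketch (illustrative only, not part of the repository API) ---
// A minimal sketch of how code in this package might combine a full rebuild with an
// incremental batch update. The function name and parameters below are hypothetical;
// the actual call sites depend on how the surrounding service wires up startup and
// status-change events.
//
//	func rebuildThenApply(ctx context.Context, repo *gormKeyRepository, changed []*models.GroupAPIKeyMapping) error {
//		// Full rebuild, e.g. at startup or after a Redis flush.
//		if err := repo.LoadAllKeysToStore(ctx); err != nil {
//			return err
//		}
//		// Apply status changes that arrived afterwards in a single pipeline round trip.
//		return repo.HandleCacheUpdateEventBatch(ctx, changed)
//	}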