// Filename: internal/repository/key_cache.go

package repository

import (
	"encoding/json"
	"fmt"
	"strconv"

	"gemini-balancer/internal/models"
)

// Cache key name templates. Group-scoped keys are parameterized by the numeric
// group ID; base-pool keys are parameterized by the pool name.
const (
	KeyGroup               = "group:%d:keys:active"
	KeyDetails             = "key:%d:details"
	KeyMapping             = "mapping:%d:%d"
	KeyGroupSequential     = "group:%d:keys:sequential"
	KeyGroupLRU            = "group:%d:keys:lru"
	KeyGroupRandomMain     = "group:%d:keys:random:main"
	KeyGroupRandomCooldown = "group:%d:keys:random:cooldown"
	BasePoolSequential     = "basepool:%s:keys:sequential"
	BasePoolLRU            = "basepool:%s:keys:lru"
	BasePoolRandomMain     = "basepool:%s:keys:random:main"
	BasePoolRandomCooldown = "basepool:%s:keys:random:cooldown"
)

// LoadAllKeysToStore rebuilds the key cache from the database: it loads every
// group/key mapping, batch-decrypts the keys, clears each group's cached
// structures, and repopulates the active set, sequential list, LRU sorted set,
// and random pool for every group.
func (r *gormKeyRepository) LoadAllKeysToStore() error {
	r.logger.Info("Starting to load all keys and associations into cache, including polling structures...")

	var allMappings []*models.GroupAPIKeyMapping
	if err := r.db.Preload("APIKey").Find(&allMappings).Error; err != nil {
		return fmt.Errorf("failed to load all mappings with APIKeys from DB: %w", err)
	}

	// Deduplicate keys by ID so each key is decrypted only once.
	keyMap := make(map[uint]*models.APIKey)
	for _, m := range allMappings {
		if m.APIKey != nil {
			keyMap[m.APIKey.ID] = m.APIKey
		}
	}

	keysToDecrypt := make([]models.APIKey, 0, len(keyMap))
	for _, k := range keyMap {
		keysToDecrypt = append(keysToDecrypt, *k)
	}
	if err := r.decryptKeys(keysToDecrypt); err != nil {
		r.logger.WithError(err).Error("Critical error during cache preload: batch decryption failed.")
	}

	decryptedKeyMap := make(map[uint]models.APIKey)
	for _, k := range keysToDecrypt {
		decryptedKeyMap[k.ID] = k
	}

	activeKeysByGroup := make(map[uint][]*models.GroupAPIKeyMapping)
	pipe := r.store.Pipeline()
	detailsToSet := make(map[string][]byte)

	// Queue deletion of every group's cached structures before rebuilding them.
	var allGroups []*models.KeyGroup
	if err := r.db.Find(&allGroups).Error; err == nil {
		for _, group := range allGroups {
			pipe.Del(
				fmt.Sprintf(KeyGroup, group.ID),
				fmt.Sprintf(KeyGroupSequential, group.ID),
				fmt.Sprintf(KeyGroupLRU, group.ID),
				fmt.Sprintf(KeyGroupRandomMain, group.ID),
				fmt.Sprintf(KeyGroupRandomCooldown, group.ID),
			)
		}
	} else {
		r.logger.WithError(err).Error("Failed to get all groups for cache cleanup")
	}

	// Stage per-key detail and per-mapping entries, and collect the active
	// mappings for each group.
	for _, mapping := range allMappings {
		if mapping.APIKey == nil {
			continue
		}
		decryptedKey, ok := decryptedKeyMap[mapping.APIKeyID]
		if !ok {
			continue
		}

		keyJSON, _ := json.Marshal(decryptedKey)
		detailsToSet[fmt.Sprintf(KeyDetails, decryptedKey.ID)] = keyJSON

		mappingJSON, _ := json.Marshal(mapping)
		detailsToSet[fmt.Sprintf(KeyMapping, mapping.KeyGroupID, decryptedKey.ID)] = mappingJSON

		if mapping.Status == models.StatusActive {
			activeKeysByGroup[mapping.KeyGroupID] = append(activeKeysByGroup[mapping.KeyGroupID], mapping)
		}
	}

	// Rebuild the polling structures for every group that has active keys.
	for groupID, activeMappings := range activeKeysByGroup {
		if len(activeMappings) == 0 {
			continue
		}

		var activeKeyIDs []interface{}
		lruMembers := make(map[string]float64)
		for _, mapping := range activeMappings {
			keyIDStr := strconv.FormatUint(uint64(mapping.APIKeyID), 10)
			activeKeyIDs = append(activeKeyIDs, keyIDStr)

			var score float64
			if mapping.LastUsedAt != nil {
				score = float64(mapping.LastUsedAt.UnixMilli())
			}
			lruMembers[keyIDStr] = score
		}

		pipe.SAdd(fmt.Sprintf(KeyGroup, groupID), activeKeyIDs...)
		pipe.LPush(fmt.Sprintf(KeyGroupSequential, groupID), activeKeyIDs...)
		pipe.SAdd(fmt.Sprintf(KeyGroupRandomMain, groupID), activeKeyIDs...)
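		// The LRU sorted set is scored by last-use time in milliseconds (0 for
		// never-used keys) and is written directly against the store in a
		// fire-and-forget goroutine rather than through the pipeline.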
		go r.store.ZAdd(fmt.Sprintf(KeyGroupLRU, groupID), lruMembers)
	}

	if err := pipe.Exec(); err != nil {
		return fmt.Errorf("failed to execute pipeline for cache rebuild: %w", err)
	}

	// Write the staged detail/mapping entries individually; a failure here is
	// logged but does not abort the rebuild.
	for key, value := range detailsToSet {
		if err := r.store.Set(key, value, 0); err != nil {
			r.logger.WithError(err).Warnf("Failed to set key detail in cache: %s", key)
		}
	}

	r.logger.Info("Cache rebuild complete, including all polling structures.")
	return nil
}

// updateStoreCacheForKey refreshes the cached details entry for a single key.
func (r *gormKeyRepository) updateStoreCacheForKey(key *models.APIKey) error {
	if err := r.decryptKey(key); err != nil {
		return fmt.Errorf("failed to decrypt key %d for cache update: %w", key.ID, err)
	}
	keyJSON, err := json.Marshal(key)
	if err != nil {
		return fmt.Errorf("failed to marshal key %d for cache update: %w", key.ID, err)
	}
	return r.store.Set(fmt.Sprintf(KeyDetails, key.ID), keyJSON, 0)
}

// removeStoreCacheForKey deletes a key's details and mapping entries and
// removes it from every cached structure of every group it belongs to.
func (r *gormKeyRepository) removeStoreCacheForKey(key *models.APIKey) error {
	groupIDs, err := r.GetGroupsForKey(key.ID)
	if err != nil {
		r.logger.Warnf("failed to get groups for key %d to clean up cache lists: %v", key.ID, err)
	}

	pipe := r.store.Pipeline()
	pipe.Del(fmt.Sprintf(KeyDetails, key.ID))

	for _, groupID := range groupIDs {
		pipe.Del(fmt.Sprintf(KeyMapping, groupID, key.ID))

		keyIDStr := strconv.FormatUint(uint64(key.ID), 10)
		pipe.SRem(fmt.Sprintf(KeyGroup, groupID), keyIDStr)
		pipe.LRem(fmt.Sprintf(KeyGroupSequential, groupID), 0, keyIDStr)
		pipe.SRem(fmt.Sprintf(KeyGroupRandomMain, groupID), keyIDStr)
		pipe.SRem(fmt.Sprintf(KeyGroupRandomCooldown, groupID), keyIDStr)

		// The LRU entry is removed directly against the store, mirroring how it
		// is written in LoadAllKeysToStore.
		go r.store.ZRem(fmt.Sprintf(KeyGroupLRU, groupID), keyIDStr)
	}

	return pipe.Exec()
}

// updateStoreCacheForMapping synchronizes a single mapping's membership in its
// group's active-key structure: the key is always removed first and re-added
// only when the mapping is active.
func (r *gormKeyRepository) updateStoreCacheForMapping(mapping *models.GroupAPIKeyMapping) error {
	pipe := r.store.Pipeline()
	activeKeyListKey := fmt.Sprintf(KeyGroup, mapping.KeyGroupID)

	pipe.LRem(activeKeyListKey, 0, mapping.APIKeyID)
	if mapping.Status == models.StatusActive {
		pipe.LPush(activeKeyListKey, mapping.APIKeyID)
	}
	return pipe.Exec()
}

// HandleCacheUpdateEventBatch applies a batch of mapping status changes to the
// per-group active-key structures in one pipeline. Every affected key is
// removed first, so re-adding an active key never duplicates it.
func (r *gormKeyRepository) HandleCacheUpdateEventBatch(mappings []*models.GroupAPIKeyMapping) error {
	if len(mappings) == 0 {
		return nil
	}

	groupUpdates := make(map[uint]struct {
		ToAdd    []interface{}
		ToRemove []interface{}
	})

	for _, mapping := range mappings {
		keyIDStr := strconv.FormatUint(uint64(mapping.APIKeyID), 10)
		update, ok := groupUpdates[mapping.KeyGroupID]
		if !ok {
			update = struct {
				ToAdd    []interface{}
				ToRemove []interface{}
			}{}
		}

		if mapping.Status == models.StatusActive {
			update.ToRemove = append(update.ToRemove, keyIDStr)
			update.ToAdd = append(update.ToAdd, keyIDStr)
		} else {
			update.ToRemove = append(update.ToRemove, keyIDStr)
		}
		groupUpdates[mapping.KeyGroupID] = update
	}

	pipe := r.store.Pipeline()
	var pipelineError error

	for groupID, updates := range groupUpdates {
		activeKeyListKey := fmt.Sprintf(KeyGroup, groupID)
		if len(updates.ToRemove) > 0 {
			for _, keyID := range updates.ToRemove {
				pipe.LRem(activeKeyListKey, 0, keyID)
			}
		}
		if len(updates.ToAdd) > 0 {
			pipe.LPush(activeKeyListKey, updates.ToAdd...)
		}
	}

	if err := pipe.Exec(); err != nil {
		pipelineError = fmt.Errorf("redis pipeline execution failed: %w", err)
	}

	return pipelineError
}
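
// groupCacheKeys is an illustrative sketch (a hypothetical helper, not
// referenced elsewhere in this file): it collects the per-group cache keys
// that LoadAllKeysToStore clears and rebuilds, composed from the key-name
// templates defined at the top of this file.
func groupCacheKeys(groupID uint) []string {
	return []string{
		fmt.Sprintf(KeyGroup, groupID),
		fmt.Sprintf(KeyGroupSequential, groupID),
		fmt.Sprintf(KeyGroupLRU, groupID),
		fmt.Sprintf(KeyGroupRandomMain, groupID),
		fmt.Sprintf(KeyGroupRandomCooldown, groupID),
	}
}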