206 lines
6.8 KiB
Go
206 lines
6.8 KiB
Go
// Filename: internal/repository/key_cache.go
|
|
package repository
|
|
|
|
import (
|
|
"context"
|
|
"encoding/json"
|
|
"fmt"
|
|
"gemini-balancer/internal/models"
|
|
"strconv"
|
|
)
|
|
|
|
// Cache key formats, to be filled in with fmt.Sprintf.
const (
	// Per-key and per-mapping entries.

	// KeyDetails is the entry holding one key's JSON, formatted with the key ID.
	KeyDetails = "key:%d:details"
	// KeyMapping is the entry holding one mapping's JSON, formatted with the
	// group ID followed by the key ID.
	KeyMapping = "mapping:%d:%d"

	// Per-group structures, each formatted with the group ID.

	// KeyGroup holds the IDs of the group's active keys.
	KeyGroup = "group:%d:keys:active"
	// KeyGroupSequential holds the group's active key IDs for sequential polling.
	KeyGroupSequential = "group:%d:keys:sequential"
	// KeyGroupLRU holds the group's active key IDs scored by last use.
	KeyGroupLRU = "group:%d:keys:lru"
	// KeyGroupRandomMain holds the group's active key IDs for random polling.
	KeyGroupRandomMain = "group:%d:keys:random:main"
	// KeyGroupRandomCooldown is the companion of KeyGroupRandomMain.
	KeyGroupRandomCooldown = "group:%d:keys:random:cooldown"

	// Base-pool counterparts of the per-group structures, each formatted with
	// a string identifier (%s) instead of a numeric group ID.

	BasePoolSequential     = "basepool:%s:keys:sequential"
	BasePoolLRU            = "basepool:%s:keys:lru"
	BasePoolRandomMain     = "basepool:%s:keys:random:main"
	BasePoolRandomCooldown = "basepool:%s:keys:random:cooldown"
)
|
|
|
|
func (r *gormKeyRepository) LoadAllKeysToStore(ctx context.Context) error {
|
|
r.logger.Info("Starting to load all keys and associations into cache, including polling structures...")
|
|
var allMappings []*models.GroupAPIKeyMapping
|
|
if err := r.db.Preload("APIKey").Find(&allMappings).Error; err != nil {
|
|
return fmt.Errorf("failed to load all mappings with APIKeys from DB: %w", err)
|
|
}
|
|
|
|
keyMap := make(map[uint]*models.APIKey)
|
|
for _, m := range allMappings {
|
|
if m.APIKey != nil {
|
|
keyMap[m.APIKey.ID] = m.APIKey
|
|
}
|
|
}
|
|
keysToDecrypt := make([]models.APIKey, 0, len(keyMap))
|
|
for _, k := range keyMap {
|
|
keysToDecrypt = append(keysToDecrypt, *k)
|
|
}
|
|
if err := r.decryptKeys(keysToDecrypt); err != nil {
|
|
r.logger.WithError(err).Error("Critical error during cache preload: batch decryption failed.")
|
|
}
|
|
decryptedKeyMap := make(map[uint]models.APIKey)
|
|
for _, k := range keysToDecrypt {
|
|
decryptedKeyMap[k.ID] = k
|
|
}
|
|
|
|
activeKeysByGroup := make(map[uint][]*models.GroupAPIKeyMapping)
|
|
pipe := r.store.Pipeline(context.Background())
|
|
detailsToSet := make(map[string][]byte)
|
|
var allGroups []*models.KeyGroup
|
|
if err := r.db.Find(&allGroups).Error; err == nil {
|
|
for _, group := range allGroups {
|
|
pipe.Del(
|
|
fmt.Sprintf(KeyGroup, group.ID),
|
|
fmt.Sprintf(KeyGroupSequential, group.ID),
|
|
fmt.Sprintf(KeyGroupLRU, group.ID),
|
|
fmt.Sprintf(KeyGroupRandomMain, group.ID),
|
|
fmt.Sprintf(KeyGroupRandomCooldown, group.ID),
|
|
)
|
|
}
|
|
} else {
|
|
r.logger.WithError(err).Error("Failed to get all groups for cache cleanup")
|
|
}
|
|
|
|
for _, mapping := range allMappings {
|
|
if mapping.APIKey == nil {
|
|
continue
|
|
}
|
|
decryptedKey, ok := decryptedKeyMap[mapping.APIKeyID]
|
|
if !ok {
|
|
continue
|
|
}
|
|
keyJSON, _ := json.Marshal(decryptedKey)
|
|
detailsToSet[fmt.Sprintf(KeyDetails, decryptedKey.ID)] = keyJSON
|
|
mappingJSON, _ := json.Marshal(mapping)
|
|
detailsToSet[fmt.Sprintf(KeyMapping, mapping.KeyGroupID, decryptedKey.ID)] = mappingJSON
|
|
if mapping.Status == models.StatusActive {
|
|
activeKeysByGroup[mapping.KeyGroupID] = append(activeKeysByGroup[mapping.KeyGroupID], mapping)
|
|
}
|
|
}
|
|
|
|
for groupID, activeMappings := range activeKeysByGroup {
|
|
if len(activeMappings) == 0 {
|
|
continue
|
|
}
|
|
var activeKeyIDs []interface{}
|
|
lruMembers := make(map[string]float64)
|
|
for _, mapping := range activeMappings {
|
|
keyIDStr := strconv.FormatUint(uint64(mapping.APIKeyID), 10)
|
|
activeKeyIDs = append(activeKeyIDs, keyIDStr)
|
|
var score float64
|
|
if mapping.LastUsedAt != nil {
|
|
score = float64(mapping.LastUsedAt.UnixMilli())
|
|
}
|
|
lruMembers[keyIDStr] = score
|
|
}
|
|
pipe.SAdd(fmt.Sprintf(KeyGroup, groupID), activeKeyIDs...)
|
|
pipe.LPush(fmt.Sprintf(KeyGroupSequential, groupID), activeKeyIDs...)
|
|
pipe.SAdd(fmt.Sprintf(KeyGroupRandomMain, groupID), activeKeyIDs...)
|
|
go r.store.ZAdd(context.Background(), fmt.Sprintf(KeyGroupLRU, groupID), lruMembers)
|
|
}
|
|
|
|
if err := pipe.Exec(); err != nil {
|
|
return fmt.Errorf("failed to execute pipeline for cache rebuild: %w", err)
|
|
}
|
|
for key, value := range detailsToSet {
|
|
if err := r.store.Set(context.Background(), key, value, 0); err != nil {
|
|
r.logger.WithError(err).Warnf("Failed to set key detail in cache: %s", key)
|
|
}
|
|
}
|
|
|
|
r.logger.Info("Cache rebuild complete, including all polling structures.")
|
|
return nil
|
|
}
|
|
|
|
func (r *gormKeyRepository) updateStoreCacheForKey(key *models.APIKey) error {
|
|
if err := r.decryptKey(key); err != nil {
|
|
return fmt.Errorf("failed to decrypt key %d for cache update: %w", key.ID, err)
|
|
}
|
|
keyJSON, err := json.Marshal(key)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to marshal key %d for cache update: %w", key.ID, err)
|
|
}
|
|
return r.store.Set(context.Background(), fmt.Sprintf(KeyDetails, key.ID), keyJSON, 0)
|
|
}
|
|
|
|
func (r *gormKeyRepository) removeStoreCacheForKey(ctx context.Context, key *models.APIKey) error {
|
|
groupIDs, err := r.GetGroupsForKey(ctx, key.ID)
|
|
if err != nil {
|
|
r.logger.Warnf("failed to get groups for key %d to clean up cache lists: %v", key.ID, err)
|
|
}
|
|
|
|
pipe := r.store.Pipeline(ctx)
|
|
pipe.Del(fmt.Sprintf(KeyDetails, key.ID))
|
|
|
|
for _, groupID := range groupIDs {
|
|
pipe.Del(fmt.Sprintf(KeyMapping, groupID, key.ID))
|
|
|
|
keyIDStr := strconv.FormatUint(uint64(key.ID), 10)
|
|
pipe.SRem(fmt.Sprintf(KeyGroup, groupID), keyIDStr)
|
|
pipe.LRem(fmt.Sprintf(KeyGroupSequential, groupID), 0, keyIDStr)
|
|
pipe.SRem(fmt.Sprintf(KeyGroupRandomMain, groupID), keyIDStr)
|
|
pipe.SRem(fmt.Sprintf(KeyGroupRandomCooldown, groupID), keyIDStr)
|
|
go r.store.ZRem(context.Background(), fmt.Sprintf(KeyGroupLRU, groupID), keyIDStr)
|
|
}
|
|
return pipe.Exec()
|
|
}
|
|
|
|
func (r *gormKeyRepository) updateStoreCacheForMapping(mapping *models.GroupAPIKeyMapping) error {
|
|
pipe := r.store.Pipeline(context.Background())
|
|
activeKeyListKey := fmt.Sprintf("group:%d:keys:active", mapping.KeyGroupID)
|
|
pipe.LRem(activeKeyListKey, 0, mapping.APIKeyID)
|
|
if mapping.Status == models.StatusActive {
|
|
pipe.LPush(activeKeyListKey, mapping.APIKeyID)
|
|
}
|
|
return pipe.Exec()
|
|
}
|
|
|
|
func (r *gormKeyRepository) HandleCacheUpdateEventBatch(ctx context.Context, mappings []*models.GroupAPIKeyMapping) error {
|
|
if len(mappings) == 0 {
|
|
return nil
|
|
}
|
|
groupUpdates := make(map[uint]struct {
|
|
ToAdd []interface{}
|
|
ToRemove []interface{}
|
|
})
|
|
for _, mapping := range mappings {
|
|
keyIDStr := strconv.FormatUint(uint64(mapping.APIKeyID), 10)
|
|
update, ok := groupUpdates[mapping.KeyGroupID]
|
|
if !ok {
|
|
update = struct {
|
|
ToAdd []interface{}
|
|
ToRemove []interface{}
|
|
}{}
|
|
}
|
|
if mapping.Status == models.StatusActive {
|
|
update.ToRemove = append(update.ToRemove, keyIDStr)
|
|
update.ToAdd = append(update.ToAdd, keyIDStr)
|
|
} else {
|
|
update.ToRemove = append(update.ToRemove, keyIDStr)
|
|
}
|
|
groupUpdates[mapping.KeyGroupID] = update
|
|
}
|
|
pipe := r.store.Pipeline(context.Background())
|
|
var pipelineError error
|
|
for groupID, updates := range groupUpdates {
|
|
activeKeyListKey := fmt.Sprintf("group:%d:keys:active", groupID)
|
|
if len(updates.ToRemove) > 0 {
|
|
for _, keyID := range updates.ToRemove {
|
|
pipe.LRem(activeKeyListKey, 0, keyID)
|
|
}
|
|
}
|
|
if len(updates.ToAdd) > 0 {
|
|
pipe.LPush(activeKeyListKey, updates.ToAdd...)
|
|
}
|
|
}
|
|
if err := pipe.Exec(); err != nil {
|
|
pipelineError = fmt.Errorf("redis pipeline execution failed: %w", err)
|
|
}
|
|
return pipelineError
|
|
}
|