This commit is contained in:
XOF
2025-11-20 12:24:05 +08:00
commit f28bdc751f
164 changed files with 64248 additions and 0 deletions

View File

@@ -0,0 +1,206 @@
// Filename: internal/repository/auth_token.go
package repository
import (
"crypto/sha256"
"encoding/hex"
"fmt"
"gemini-balancer/internal/crypto"
"gemini-balancer/internal/models"
"github.com/sirupsen/logrus"
"gorm.io/gorm"
)
// AuthTokenRepository defines the interface for AuthToken data access.
type AuthTokenRepository interface {
	// GetAllTokensWithGroups returns every token with AllowedGroups preloaded,
	// decrypted for use by the service layer.
	GetAllTokensWithGroups() ([]*models.AuthToken, error)
	// BatchUpdateTokens transactionally upserts/deletes tokens, handling encryption.
	BatchUpdateTokens(updates []*models.TokenUpdateRequest) error
	// GetTokenByHashedValue looks up an active token by its SHA-256 hex hash.
	GetTokenByHashedValue(tokenHash string) (*models.AuthToken, error)
	// SeedAdminToken idempotently inserts the initial admin token (used by the seeder).
	SeedAdminToken(encryptedToken, tokenHash string) error
}
// gormAuthTokenRepository is the GORM-backed implementation of AuthTokenRepository.
type gormAuthTokenRepository struct {
	db     *gorm.DB        // underlying database handle
	crypto *crypto.Service // encrypt/decrypt service for token values
	logger *logrus.Entry   // component-scoped logger
}
// NewAuthTokenRepository builds a GORM-backed AuthTokenRepository with a
// component-scoped logger.
func NewAuthTokenRepository(db *gorm.DB, crypto *crypto.Service, logger *logrus.Logger) AuthTokenRepository {
	repo := &gormAuthTokenRepository{
		db:     db,
		crypto: crypto,
		logger: logger.WithField("component", "repository.authToken🔐"),
	}
	return repo
}
// GetAllTokensWithGroups fetches all tokens and decrypts them for use in services.
func (r *gormAuthTokenRepository) GetAllTokensWithGroups() ([]*models.AuthToken, error) {
	var tokens []*models.AuthToken
	if err := r.db.Preload("AllowedGroups").Find(&tokens).Error; err != nil {
		return nil, err
	}
	// [CRITICAL] Decrypt all tokens before returning them.
	// NOTE(review): decryptTokens as written always returns nil (it logs
	// per-token failures internally), so this branch is currently
	// unreachable — confirm the intended contract.
	if err := r.decryptTokens(tokens); err != nil {
		// Log the error but return the partially decrypted data, as some might be usable.
		r.logger.WithError(err).Error("Batch decryption failed for some auth tokens.")
	}
	return tokens, nil
}
// BatchUpdateTokens provides a transactional way to update all tokens, handling encryption.
// The request list is treated as the desired end state for user tokens:
// user tokens present in the request are upserted, user tokens absent from it
// are deleted, and the admin token (if included) is updated in place.
func (r *gormAuthTokenRepository) BatchUpdateTokens(updates []*models.TokenUpdateRequest) error {
	return r.db.Transaction(func(tx *gorm.DB) error {
		// 1. Separate admin and user tokens from the request
		var adminUpdate *models.TokenUpdateRequest
		var userUpdates []*models.TokenUpdateRequest
		for _, u := range updates {
			if u.IsAdmin {
				adminUpdate = u
			} else {
				userUpdates = append(userUpdates, u)
			}
		}
		// 2. Handle Admin Token Update: re-encrypt the new value and refresh
		// both the ciphertext and the SHA-256 lookup hash.
		if adminUpdate != nil && adminUpdate.Token != "" {
			encryptedToken, err := r.crypto.Encrypt(adminUpdate.Token)
			if err != nil {
				return fmt.Errorf("failed to encrypt admin token: %w", err)
			}
			hash := sha256.Sum256([]byte(adminUpdate.Token))
			tokenHash := hex.EncodeToString(hash[:])
			// Update both encrypted value and the hash
			updateData := map[string]interface{}{
				"encrypted_token": encryptedToken,
				"token_hash":      tokenHash,
			}
			if err := tx.Model(&models.AuthToken{}).Where("is_admin = ?", true).Updates(updateData).Error; err != nil {
				return fmt.Errorf("failed to update admin token in db: %w", err)
			}
		}
		// 3. Handle User Tokens Upsert
		var existingTokens []*models.AuthToken
		if err := tx.Where("is_admin = ?", false).Find(&existingTokens).Error; err != nil {
			return fmt.Errorf("failed to fetch existing user tokens: %w", err)
		}
		// Index existing token IDs so step 4 can compute deletions.
		existingTokenMap := make(map[uint]bool)
		for _, t := range existingTokens {
			existingTokenMap[t.ID] = true
		}
		var tokensToUpsert []models.AuthToken
		for _, req := range userUpdates {
			if req.Token == "" {
				continue // Skip tokens with empty values
			}
			encryptedToken, err := r.crypto.Encrypt(req.Token)
			if err != nil {
				return fmt.Errorf("failed to encrypt token for upsert (ID: %d): %w", req.ID, err)
			}
			hash := sha256.Sum256([]byte(req.Token))
			tokenHash := hex.EncodeToString(hash[:])
			// Resolve the requested group associations (if any).
			var groups []*models.KeyGroup
			if len(req.AllowedGroupIDs) > 0 {
				if err := tx.Find(&groups, req.AllowedGroupIDs).Error; err != nil {
					return fmt.Errorf("failed to find key groups for token %d: %w", req.ID, err)
				}
			}
			// A zero ID creates a new row; a non-zero ID updates the existing one.
			tokensToUpsert = append(tokensToUpsert, models.AuthToken{
				ID:             req.ID,
				EncryptedToken: encryptedToken,
				TokenHash:      tokenHash,
				Description:    req.Description,
				Tag:            req.Tag,
				Status:         req.Status,
				IsAdmin:        false,
				AllowedGroups:  groups,
			})
		}
		if len(tokensToUpsert) > 0 {
			if err := tx.Save(&tokensToUpsert).Error; err != nil {
				return fmt.Errorf("failed to upsert user tokens: %w", err)
			}
		}
		// 4. Handle Deletions: user tokens in the DB but absent from the
		// request are removed (the request is authoritative).
		incomingUserTokenIDs := make(map[uint]bool)
		for _, u := range userUpdates {
			if u.ID != 0 {
				incomingUserTokenIDs[u.ID] = true
			}
		}
		var idsToDelete []uint
		for id := range existingTokenMap {
			if !incomingUserTokenIDs[id] {
				idsToDelete = append(idsToDelete, id)
			}
		}
		if len(idsToDelete) > 0 {
			// Clear many-to-many join rows first to avoid orphaned associations.
			if err := tx.Model(&models.AuthToken{}).Where("id IN ?", idsToDelete).Association("AllowedGroups").Clear(); err != nil {
				return fmt.Errorf("failed to clear associations for tokens to be deleted: %w", err)
			}
			if err := tx.Where("id IN ?", idsToDelete).Delete(&models.AuthToken{}).Error; err != nil {
				return fmt.Errorf("failed to delete user tokens: %w", err)
			}
		}
		return nil
	})
}
// --- Crypto Helper Functions ---
// decryptToken populates token.Token from token.EncryptedToken in place.
// It is a no-op for nil tokens, empty ciphertext, or already-decrypted tokens.
func (r *gormAuthTokenRepository) decryptToken(token *models.AuthToken) error {
	if token == nil {
		return nil
	}
	if token.EncryptedToken == "" || token.Token != "" {
		return nil // nothing to do, or plaintext already present
	}
	plaintext, err := r.crypto.Decrypt(token.EncryptedToken)
	if err != nil {
		return fmt.Errorf("failed to decrypt auth token ID %d: %w", token.ID, err)
	}
	token.Token = plaintext
	return nil
}
// decryptTokens decrypts every token in the slice in place. Individual
// failures are logged and do not stop the batch; the first failure (if any)
// is returned so callers can tell the decryption was only partial.
// [FIX] The previous version unconditionally returned nil, which made the
// error-handling branch in GetAllTokensWithGroups unreachable dead code.
func (r *gormAuthTokenRepository) decryptTokens(tokens []*models.AuthToken) error {
	var firstErr error
	for i := range tokens {
		if err := r.decryptToken(tokens[i]); err != nil {
			r.logger.Error(err) // Log error but continue for other tokens
			if firstErr == nil {
				firstErr = err
			}
		}
	}
	return firstErr
}
// GetTokenByHashedValue finds a token by its SHA256 hash for authentication.
func (r *gormAuthTokenRepository) GetTokenByHashedValue(tokenHash string) (*models.AuthToken, error) {
	var authToken models.AuthToken
	// Look up the active token by its hash; this is the core of the secure
	// authentication path.
	query := r.db.Where("token_hash = ? AND status = 'active'", tokenHash).Preload("AllowedGroups")
	if err := query.First(&authToken).Error; err != nil {
		return nil, err
	}
	// Decrypt before handing the token to the service layer so downstream
	// logic (e.g. ResourceService) receives a fully usable object.
	if err := r.decryptToken(&authToken); err != nil {
		return nil, err
	}
	return &authToken, nil
}
// SeedAdminToken is a special-purpose function for the seeder to insert the initial admin token.
// It is idempotent: if an admin token row already exists, nothing is written.
func (r *gormAuthTokenRepository) SeedAdminToken(encryptedToken, tokenHash string) error {
	adminToken := models.AuthToken{
		EncryptedToken: encryptedToken,
		TokenHash:      tokenHash,
		Description:    "Default Administrator Token",
		Tag:            "SYSTEM_ADMIN",
		IsAdmin:        true,
		Status:         "active", // Ensure the seeded token is active
	}
	// Using FirstOrCreate to be idempotent. If an admin token already exists, it does nothing.
	return r.db.Where(models.AuthToken{IsAdmin: true}).FirstOrCreate(&adminToken).Error
}

View File

@@ -0,0 +1,37 @@
// Filename: internal/repository/group_repository.go
package repository
import (
"gemini-balancer/internal/models"
"gorm.io/gorm"
)
// GetGroupByName returns the key group with the given (unique) name.
func (r *gormGroupRepository) GetGroupByName(name string) (*models.KeyGroup, error) {
	group := new(models.KeyGroup)
	err := r.db.Where("name = ?", name).First(group).Error
	if err != nil {
		return nil, err
	}
	return group, nil
}
// GetAllGroups returns every key group ordered by the user-defined "order"
// column (ascending), falling back to newest-first by id.
// NOTE(review): the double-quoted identifier is ANSI-style; MySQL would need
// backticks — confirm which dialects must be supported.
func (r *gormGroupRepository) GetAllGroups() ([]*models.KeyGroup, error) {
	var groups []*models.KeyGroup
	if err := r.db.Order("\"order\" asc, id desc").Find(&groups).Error; err != nil {
		return nil, err
	}
	return groups, nil
}
// UpdateOrderInTransaction persists the given group ordering (id -> order)
// atomically: either every row is updated or none are.
func (r *gormGroupRepository) UpdateOrderInTransaction(orders map[uint]int) error {
	return r.db.Transaction(func(tx *gorm.DB) error {
		for id, position := range orders {
			if err := tx.Model(&models.KeyGroup{}).Where("id = ?", id).Update("order", position).Error; err != nil {
				return err
			}
		}
		return nil
	})
}

View File

@@ -0,0 +1,204 @@
// Filename: internal/repository/key_cache.go
package repository
import (
"encoding/json"
"fmt"
"gemini-balancer/internal/models"
"strconv"
)
// Cache key format templates. %d placeholders are group or key IDs; %s is a
// base-pool name.
const (
	KeyGroup               = "group:%d:keys:active"          // active key IDs of a group
	KeyDetails             = "key:%d:details"                // JSON blob of a single APIKey
	KeyMapping             = "mapping:%d:%d"                 // JSON blob of a group/key mapping
	KeyGroupSequential     = "group:%d:keys:sequential"      // list used for sequential polling
	KeyGroupLRU            = "group:%d:keys:lru"             // zset scored by last-used time
	KeyGroupRandomMain     = "group:%d:keys:random:main"     // main set for random polling
	KeyGroupRandomCooldown = "group:%d:keys:random:cooldown" // cooldown set for random polling
	BasePoolSequential     = "basepool:%s:keys:sequential"
	BasePoolLRU            = "basepool:%s:keys:lru"
	BasePoolRandomMain     = "basepool:%s:keys:random:main"
	BasePoolRandomCooldown = "basepool:%s:keys:random:cooldown"
)
// LoadAllKeysToStore rebuilds the entire key cache from the database:
// per-key detail blobs, group membership sets, and all polling structures
// (sequential list, random set, LRU zset).
func (r *gormKeyRepository) LoadAllKeysToStore() error {
	r.logger.Info("Starting to load all keys and associations into cache, including polling structures...")
	var allMappings []*models.GroupAPIKeyMapping
	if err := r.db.Preload("APIKey").Find(&allMappings).Error; err != nil {
		return fmt.Errorf("failed to load all mappings with APIKeys from DB: %w", err)
	}
	// Collect the distinct keys referenced by the mappings.
	keyMap := make(map[uint]*models.APIKey)
	for _, m := range allMappings {
		if m.APIKey != nil {
			keyMap[m.APIKey.ID] = m.APIKey
		}
	}
	keysToDecrypt := make([]models.APIKey, 0, len(keyMap))
	for _, k := range keyMap {
		keysToDecrypt = append(keysToDecrypt, *k)
	}
	if err := r.decryptKeys(keysToDecrypt); err != nil {
		r.logger.WithError(err).Error("Critical error during cache preload: batch decryption failed.")
	}
	decryptedKeyMap := make(map[uint]models.APIKey, len(keysToDecrypt))
	for _, k := range keysToDecrypt {
		decryptedKeyMap[k.ID] = k
	}
	activeKeysByGroup := make(map[uint][]*models.GroupAPIKeyMapping)
	pipe := r.store.Pipeline()
	detailsToSet := make(map[string][]byte)
	// Clear every group's cache structures so stale members cannot survive
	// the rebuild.
	var allGroups []*models.KeyGroup
	if err := r.db.Find(&allGroups).Error; err == nil {
		for _, group := range allGroups {
			pipe.Del(
				fmt.Sprintf(KeyGroup, group.ID),
				fmt.Sprintf(KeyGroupSequential, group.ID),
				fmt.Sprintf(KeyGroupLRU, group.ID),
				fmt.Sprintf(KeyGroupRandomMain, group.ID),
				fmt.Sprintf(KeyGroupRandomCooldown, group.ID),
			)
		}
	} else {
		r.logger.WithError(err).Error("Failed to get all groups for cache cleanup")
	}
	for _, mapping := range allMappings {
		if mapping.APIKey == nil {
			continue
		}
		decryptedKey, ok := decryptedKeyMap[mapping.APIKeyID]
		if !ok {
			continue
		}
		keyJSON, _ := json.Marshal(decryptedKey)
		detailsToSet[fmt.Sprintf(KeyDetails, decryptedKey.ID)] = keyJSON
		mappingJSON, _ := json.Marshal(mapping)
		detailsToSet[fmt.Sprintf(KeyMapping, mapping.KeyGroupID, decryptedKey.ID)] = mappingJSON
		if mapping.Status == models.StatusActive {
			activeKeysByGroup[mapping.KeyGroupID] = append(activeKeysByGroup[mapping.KeyGroupID], mapping)
		}
	}
	// Queue set/list rebuilds on the pipeline and collect LRU members per
	// group for after the pipeline commits.
	lruByGroup := make(map[uint]map[string]float64)
	for groupID, activeMappings := range activeKeysByGroup {
		if len(activeMappings) == 0 {
			continue
		}
		var activeKeyIDs []interface{}
		lruMembers := make(map[string]float64)
		for _, mapping := range activeMappings {
			keyIDStr := strconv.FormatUint(uint64(mapping.APIKeyID), 10)
			activeKeyIDs = append(activeKeyIDs, keyIDStr)
			var score float64
			if mapping.LastUsedAt != nil {
				score = float64(mapping.LastUsedAt.UnixMilli())
			}
			lruMembers[keyIDStr] = score
		}
		pipe.SAdd(fmt.Sprintf(KeyGroup, groupID), activeKeyIDs...)
		pipe.LPush(fmt.Sprintf(KeyGroupSequential, groupID), activeKeyIDs...)
		pipe.SAdd(fmt.Sprintf(KeyGroupRandomMain, groupID), activeKeyIDs...)
		lruByGroup[groupID] = lruMembers
	}
	if err := pipe.Exec(); err != nil {
		return fmt.Errorf("failed to execute pipeline for cache rebuild: %w", err)
	}
	// [FIX] Previously `go r.store.ZAdd(...)` inside the loop above: a
	// fire-and-forget goroutine whose error was dropped and which could race
	// the pipeline's DEL of the same zset, leaving a half-built LRU structure.
	// Populate the LRU zsets synchronously after the deletes have committed.
	for groupID, lruMembers := range lruByGroup {
		if err := r.store.ZAdd(fmt.Sprintf(KeyGroupLRU, groupID), lruMembers); err != nil {
			r.logger.WithError(err).Warnf("Failed to rebuild LRU zset for group %d", groupID)
		}
	}
	for key, value := range detailsToSet {
		if err := r.store.Set(key, value, 0); err != nil {
			r.logger.WithError(err).Warnf("Failed to set key detail in cache: %s", key)
		}
	}
	r.logger.Info("Cache rebuild complete, including all polling structures.")
	return nil
}
// updateStoreCacheForKey refreshes the cached detail blob for a single key,
// decrypting it first so the cache always holds the usable plaintext form.
func (r *gormKeyRepository) updateStoreCacheForKey(key *models.APIKey) error {
	if err := r.decryptKey(key); err != nil {
		return fmt.Errorf("failed to decrypt key %d for cache update: %w", key.ID, err)
	}
	payload, err := json.Marshal(key)
	if err != nil {
		return fmt.Errorf("failed to marshal key %d for cache update: %w", key.ID, err)
	}
	cacheKey := fmt.Sprintf(KeyDetails, key.ID)
	return r.store.Set(cacheKey, payload, 0)
}
// removeStoreCacheForKey deletes every cache entry that references the key:
// its detail blob plus its membership in all per-group polling structures.
func (r *gormKeyRepository) removeStoreCacheForKey(key *models.APIKey) error {
	groupIDs, err := r.GetGroupsForKey(key.ID)
	if err != nil {
		// Best effort: still delete the detail blob even if group lookup failed.
		r.logger.Warnf("failed to get groups for key %d to clean up cache lists: %v", key.ID, err)
	}
	pipe := r.store.Pipeline()
	pipe.Del(fmt.Sprintf(KeyDetails, key.ID))
	// Hoisted out of the loop: the member string is loop-invariant.
	keyIDStr := strconv.FormatUint(uint64(key.ID), 10)
	for _, groupID := range groupIDs {
		pipe.Del(fmt.Sprintf(KeyMapping, groupID, key.ID))
		pipe.SRem(fmt.Sprintf(KeyGroup, groupID), keyIDStr)
		pipe.LRem(fmt.Sprintf(KeyGroupSequential, groupID), 0, keyIDStr)
		pipe.SRem(fmt.Sprintf(KeyGroupRandomMain, groupID), keyIDStr)
		pipe.SRem(fmt.Sprintf(KeyGroupRandomCooldown, groupID), keyIDStr)
		// [FIX] Was `go r.store.ZRem(...)`: a fire-and-forget goroutine whose
		// error was silently dropped. Remove from the LRU zset synchronously.
		if zerr := r.store.ZRem(fmt.Sprintf(KeyGroupLRU, groupID), keyIDStr); zerr != nil {
			r.logger.Warnf("failed to remove key %d from LRU zset of group %d: %v", key.ID, groupID, zerr)
		}
	}
	return pipe.Exec()
}
// updateStoreCacheForMapping syncs a mapping's membership in its group's
// active-key list: the key ID is always removed and, if the mapping is
// active, pushed back on (so an active key appears exactly once).
func (r *gormKeyRepository) updateStoreCacheForMapping(mapping *models.GroupAPIKeyMapping) error {
	pipe := r.store.Pipeline()
	// [CONSISTENCY] Use the shared KeyGroup template and a string member, as
	// every other cache writer in this file does; the original inlined the
	// format string and pushed the raw uint.
	activeKeyListKey := fmt.Sprintf(KeyGroup, mapping.KeyGroupID)
	keyIDStr := strconv.FormatUint(uint64(mapping.APIKeyID), 10)
	pipe.LRem(activeKeyListKey, 0, keyIDStr)
	if mapping.Status == models.StatusActive {
		pipe.LPush(activeKeyListKey, keyIDStr)
	}
	return pipe.Exec()
}
// HandleCacheUpdateEventBatch applies a batch of mapping status changes to
// the per-group active-key lists in a single pipeline round-trip. Every
// affected key is first removed from its group's list; keys whose mapping is
// active are then pushed back, so each active key appears exactly once.
func (r *gormKeyRepository) HandleCacheUpdateEventBatch(mappings []*models.GroupAPIKeyMapping) error {
	if len(mappings) == 0 {
		return nil
	}
	// delta accumulates the pending list changes for one group.
	type delta struct {
		toAdd    []interface{}
		toRemove []interface{}
	}
	groupUpdates := make(map[uint]*delta)
	for _, mapping := range mappings {
		keyIDStr := strconv.FormatUint(uint64(mapping.APIKeyID), 10)
		d, ok := groupUpdates[mapping.KeyGroupID]
		if !ok {
			d = &delta{}
			groupUpdates[mapping.KeyGroupID] = d
		}
		// Remove unconditionally; re-add only when the mapping is active.
		d.toRemove = append(d.toRemove, keyIDStr)
		if mapping.Status == models.StatusActive {
			d.toAdd = append(d.toAdd, keyIDStr)
		}
	}
	pipe := r.store.Pipeline()
	for groupID, d := range groupUpdates {
		// [CONSISTENCY] Use the shared KeyGroup template instead of re-inlining
		// the format string.
		activeKeyListKey := fmt.Sprintf(KeyGroup, groupID)
		for _, keyID := range d.toRemove {
			pipe.LRem(activeKeyListKey, 0, keyID)
		}
		if len(d.toAdd) > 0 {
			pipe.LPush(activeKeyListKey, d.toAdd...)
		}
	}
	if err := pipe.Exec(); err != nil {
		return fmt.Errorf("redis pipeline execution failed: %w", err)
	}
	return nil
}

View File

@@ -0,0 +1,280 @@
// Filename: internal/repository/key_crud.go
package repository
import (
"crypto/sha256"
"encoding/hex"
"fmt"
"gemini-balancer/internal/models"
"math/rand"
"strings"
"time"
"gorm.io/gorm"
"gorm.io/gorm/clause"
)
// AddKeys inserts a batch of plaintext API keys, deduplicating against
// existing rows by SHA-256 hash. Soft-deleted duplicates are restored, new
// values are encrypted and created, and the full set of matching rows is
// returned decrypted. All writes happen in one transaction.
func (r *gormKeyRepository) AddKeys(keys []models.APIKey) ([]models.APIKey, error) {
	if len(keys) == 0 {
		return []models.APIKey{}, nil
	}
	// Pre-compute the lookup hash for every incoming plaintext key.
	keyHashes := make([]string, len(keys))
	keyValueToHashMap := make(map[string]string)
	for i, k := range keys {
		// All incoming keys must have plaintext APIKey
		if k.APIKey == "" {
			return nil, fmt.Errorf("cannot add key at index %d: plaintext APIKey is empty", i)
		}
		hash := sha256.Sum256([]byte(k.APIKey))
		hashStr := hex.EncodeToString(hash[:])
		keyHashes[i] = hashStr
		keyValueToHashMap[k.APIKey] = hashStr
	}
	var finalKeys []models.APIKey
	err := r.db.Transaction(func(tx *gorm.DB) error {
		var existingKeys []models.APIKey
		// Query by hash (Unscoped: include soft-deleted rows) to find existing keys.
		if err := tx.Unscoped().Where("api_key_hash IN ?", keyHashes).Find(&existingKeys).Error; err != nil {
			return err
		}
		existingKeyHashMap := make(map[string]models.APIKey)
		for _, k := range existingKeys {
			existingKeyHashMap[k.APIKeyHash] = k
		}
		// Partition the input into brand-new keys and soft-deleted rows to revive.
		var keysToCreate []models.APIKey
		var keysToRestore []uint
		for _, keyObj := range keys {
			keyVal := keyObj.APIKey
			hash := keyValueToHashMap[keyVal]
			if ek, found := existingKeyHashMap[hash]; found {
				if ek.DeletedAt.Valid {
					keysToRestore = append(keysToRestore, ek.ID)
				}
			} else {
				encryptedKey, err := r.crypto.Encrypt(keyVal)
				if err != nil {
					// Only a short prefix of the plaintext is leaked into the error.
					return fmt.Errorf("failed to encrypt key '%s...': %w", keyVal[:min(4, len(keyVal))], err)
				}
				keysToCreate = append(keysToCreate, models.APIKey{
					EncryptedKey: encryptedKey,
					APIKeyHash:   hash,
				})
			}
		}
		// Revive soft-deleted duplicates by clearing deleted_at.
		if len(keysToRestore) > 0 {
			if err := tx.Model(&models.APIKey{}).Unscoped().Where("id IN ?", keysToRestore).Update("deleted_at", nil).Error; err != nil {
				return err
			}
		}
		if len(keysToCreate) > 0 {
			// Create stores only encrypted data and hash; a race with a
			// concurrent insert is absorbed by ON CONFLICT DO NOTHING.
			if err := tx.Clauses(clause.OnConflict{DoNothing: true}, clause.Returning{}).Create(&keysToCreate).Error; err != nil {
				return err
			}
		}
		// Final select by hash covers created, restored and pre-existing rows alike.
		if err := tx.Where("api_key_hash IN ?", keyHashes).Find(&finalKeys).Error; err != nil {
			return err
		}
		// [CRITICAL] Decrypt all keys before returning them to the service layer.
		return r.decryptKeys(finalKeys)
	})
	return finalKeys, err
}
// Update persists a key, re-encrypting it when the caller supplied a new
// plaintext value, then refreshes the key's cache entry (best effort).
func (r *gormKeyRepository) Update(key *models.APIKey) error {
	// [CRITICAL] Before saving, check if the plaintext APIKey field was populated.
	// This indicates a potential change that needs to be re-encrypted.
	if key.APIKey != "" {
		encryptedKey, err := r.crypto.Encrypt(key.APIKey)
		if err != nil {
			return fmt.Errorf("failed to re-encrypt key on update for ID %d: %w", key.ID, err)
		}
		key.EncryptedKey = encryptedKey
		// Recalculate hash as a defensive measure.
		hash := sha256.Sum256([]byte(key.APIKey))
		key.APIKeyHash = hex.EncodeToString(hash[:])
	}
	err := r.executeTransactionWithRetry(func(tx *gorm.DB) error {
		// GORM automatically ignores `key.APIKey` because of the `gorm:"-"` tag.
		return tx.Save(key).Error
	})
	if err != nil {
		return err
	}
	// For the cache update, we need the plaintext. Decrypt if it's not already populated.
	if err := r.decryptKey(key); err != nil {
		r.logger.Warnf("DB updated key ID %d, but decryption for cache failed: %v", key.ID, err)
		return nil // Continue without cache update if decryption fails.
	}
	// Cache failure is non-fatal: the database remains the source of truth.
	if err := r.updateStoreCacheForKey(key); err != nil {
		r.logger.Warnf("DB updated key ID %d, but cache update failed: %v", key.ID, err)
	}
	return nil
}
// HardDeleteByID permanently removes a key row (bypassing soft delete) and
// purges its cache entries.
func (r *gormKeyRepository) HardDeleteByID(id uint) error {
	// Fetch first (decrypted) so the cache can be cleaned after the delete.
	key, err := r.GetKeyByID(id)
	if err != nil {
		return err
	}
	if txErr := r.executeTransactionWithRetry(func(tx *gorm.DB) error {
		return tx.Unscoped().Delete(&models.APIKey{}, id).Error
	}); txErr != nil {
		return txErr
	}
	if cacheErr := r.removeStoreCacheForKey(key); cacheErr != nil {
		r.logger.Warnf("DB deleted key ID %d, but cache removal failed: %v", id, cacheErr)
	}
	return nil
}
// HardDeleteByValues permanently deletes the keys whose plaintext values are
// given (matched via SHA-256 hash) and purges their cache entries. Returns
// the number of rows actually deleted.
func (r *gormKeyRepository) HardDeleteByValues(keyValues []string) (int64, error) {
	if len(keyValues) == 0 {
		return 0, nil
	}
	hashes := make([]string, len(keyValues))
	for i, v := range keyValues {
		hash := sha256.Sum256([]byte(v))
		hashes[i] = hex.EncodeToString(hash[:])
	}
	// Find the full key objects first to update the cache later.
	var keysToDelete []models.APIKey
	if err := r.db.Where("api_key_hash IN ?", hashes).Find(&keysToDelete).Error; err != nil {
		return 0, err
	}
	if len(keysToDelete) == 0 {
		return 0, nil
	}
	// Decrypt them to ensure cache has plaintext if needed.
	if err := r.decryptKeys(keysToDelete); err != nil {
		r.logger.Warnf("Decryption failed for keys before hard delete, cache removal may be impacted: %v", err)
	}
	var deletedCount int64
	err := r.executeTransactionWithRetry(func(tx *gorm.DB) error {
		ids := pluckIDs(keysToDelete)
		// Unscoped: bypass soft delete and remove the rows for good.
		result := tx.Unscoped().Where("id IN ?", ids).Delete(&models.APIKey{})
		if result.Error != nil {
			return result.Error
		}
		deletedCount = result.RowsAffected
		return nil
	})
	if err != nil {
		return 0, err
	}
	// Best-effort cache cleanup, one key at a time.
	for i := range keysToDelete {
		if err := r.removeStoreCacheForKey(&keysToDelete[i]); err != nil {
			r.logger.Warnf("DB deleted key ID %d, but cache removal failed: %v", keysToDelete[i].ID, err)
		}
	}
	return deletedCount, nil
}
// GetKeyByID loads a single key by primary key and decrypts it.
func (r *gormKeyRepository) GetKeyByID(id uint) (*models.APIKey, error) {
	key := new(models.APIKey)
	if err := r.db.First(key, id).Error; err != nil {
		return nil, err
	}
	if err := r.decryptKey(key); err != nil {
		return nil, err
	}
	return key, nil
}
// GetKeysByIDs loads the keys with the given IDs (missing IDs are simply
// absent from the result) and decrypts them before returning.
func (r *gormKeyRepository) GetKeysByIDs(ids []uint) ([]models.APIKey, error) {
	keys := make([]models.APIKey, 0, len(ids))
	if len(ids) == 0 {
		return keys, nil
	}
	if err := r.db.Where("id IN ?", ids).Find(&keys).Error; err != nil {
		return nil, err
	}
	// [CRITICAL] Decrypt before returning.
	return keys, r.decryptKeys(keys)
}
// GetKeyByValue finds a key row by the SHA-256 hash of its plaintext value.
// The plaintext is already known to the caller, so it is copied onto the
// result instead of being decrypted.
func (r *gormKeyRepository) GetKeyByValue(keyValue string) (*models.APIKey, error) {
	digest := sha256.Sum256([]byte(keyValue))
	key := new(models.APIKey)
	err := r.db.Where("api_key_hash = ?", hex.EncodeToString(digest[:])).First(key).Error
	if err != nil {
		return nil, err
	}
	key.APIKey = keyValue
	return key, nil
}
// GetKeysByValues finds the key rows whose SHA-256 hashes match the given
// plaintext values and returns them decrypted.
func (r *gormKeyRepository) GetKeysByValues(keyValues []string) ([]models.APIKey, error) {
	if len(keyValues) == 0 {
		return []models.APIKey{}, nil
	}
	hashes := make([]string, 0, len(keyValues))
	for _, v := range keyValues {
		sum := sha256.Sum256([]byte(v))
		hashes = append(hashes, hex.EncodeToString(sum[:]))
	}
	var keys []models.APIKey
	if err := r.db.Where("api_key_hash IN ?", hashes).Find(&keys).Error; err != nil {
		return nil, err
	}
	return keys, r.decryptKeys(keys)
}
// GetKeysByGroup returns all keys mapped to the given group, decrypted.
func (r *gormKeyRepository) GetKeysByGroup(groupID uint) ([]models.APIKey, error) {
	var keys []models.APIKey
	err := r.db.
		Joins("JOIN group_api_key_mappings on group_api_key_mappings.api_key_id = api_keys.id").
		Where("group_api_key_mappings.key_group_id = ?", groupID).
		Find(&keys).Error
	if err != nil {
		return nil, err
	}
	return keys, r.decryptKeys(keys)
}
// CountByGroup reports how many keys are mapped to the given group.
func (r *gormKeyRepository) CountByGroup(groupID uint) (int64, error) {
	var total int64
	err := r.db.Model(&models.APIKey{}).
		Joins("JOIN group_api_key_mappings on group_api_key_mappings.api_key_id = api_keys.id").
		Where("group_api_key_mappings.key_group_id = ?", groupID).
		Count(&total).Error
	return total, err
}
// --- Helpers ---
// executeTransactionWithRetry runs the operation in a DB transaction,
// retrying up to three times with jittered backoff when SQLite reports
// "database is locked". Any other error aborts immediately.
func (r *gormKeyRepository) executeTransactionWithRetry(operation func(tx *gorm.DB) error) error {
	const (
		maxRetries = 3
		baseDelay  = 50 * time.Millisecond
		maxJitter  = 150 * time.Millisecond
	)
	var lastErr error
	for attempt := 1; attempt <= maxRetries; attempt++ {
		lastErr = r.db.Transaction(operation)
		if lastErr == nil {
			return nil
		}
		if !strings.Contains(lastErr.Error(), "database is locked") {
			return lastErr
		}
		delay := baseDelay + time.Duration(rand.Intn(int(maxJitter)))
		r.logger.Debugf("Database is locked, retrying in %v... (attempt %d/%d)", delay, attempt, maxRetries)
		time.Sleep(delay)
	}
	return lastErr
}
// pluckIDs extracts the primary-key IDs from a slice of keys, preserving order.
func pluckIDs(keys []models.APIKey) []uint {
	ids := make([]uint, len(keys))
	for i := range keys {
		ids[i] = keys[i].ID
	}
	return ids
}

View File

@@ -0,0 +1,62 @@
// Filename: internal/repository/key_crypto.go
package repository
import (
"fmt"
"gemini-balancer/internal/models"
)
// decryptKey populates key.APIKey from key.EncryptedKey in place.
// No-op for nil keys, empty ciphertext, or keys already holding plaintext.
func (r *gormKeyRepository) decryptKey(key *models.APIKey) error {
	if key == nil {
		return nil
	}
	if key.EncryptedKey == "" || key.APIKey != "" {
		return nil // nothing to decrypt, or already decrypted
	}
	plaintext, err := r.crypto.Decrypt(key.EncryptedKey)
	if err != nil {
		return fmt.Errorf("failed to decrypt key ID %d: %w", key.ID, err)
	}
	key.APIKey = plaintext
	return nil
}
// decryptKeys decrypts every key in the slice in place. Failures are logged
// and skipped so a single bad record does not poison the whole batch.
func (r *gormKeyRepository) decryptKeys(keys []models.APIKey) error {
	for i := range keys {
		err := r.decryptKey(&keys[i])
		if err == nil {
			continue
		}
		r.logger.Errorf("Batch decrypt error for key index %d: %v", i, err)
	}
	return nil
}
// Decrypt implements the KeyRepository interface: it populates key.APIKey
// from key.EncryptedKey, tolerating nil keys, empty ciphertext, and keys
// whose plaintext is already present.
// [CONSISTENCY] Delegates to decryptKey instead of duplicating its body; the
// two copies had already started to drift (`== ""` vs `len(...) == 0`).
func (r *gormKeyRepository) Decrypt(key *models.APIKey) error {
	return r.decryptKey(key)
}
// DecryptBatch implements the KeyRepository interface: it decrypts all keys
// in place, logging (not propagating) individual failures so the rest of the
// batch proceeds.
func (r *gormKeyRepository) DecryptBatch(keys []models.APIKey) error {
	for i := range keys {
		// Delegate to the robust single-key decryption logic.
		err := r.Decrypt(&keys[i])
		if err == nil {
			continue
		}
		r.logger.Errorf("Batch decrypt error for key index %d (ID: %d): %v", i, keys[i].ID, err)
	}
	return nil
}

View File

@@ -0,0 +1,169 @@
// Filename: internal/repository/key_maintenance.go
package repository
import (
"crypto/sha256"
"encoding/hex"
"gemini-balancer/internal/models"
"io"
"time"
"gorm.io/gorm"
)
// StreamKeysToWriter writes the plaintext of every key in a group (optionally
// filtered by mapping status; "" or "all" means no filter) to the writer, one
// key per line, reading from the DB in batches to bound memory use.
func (r *gormKeyRepository) StreamKeysToWriter(groupID uint, statusFilter string, writer io.Writer) error {
	query := r.db.Model(&models.APIKey{}).
		Joins("JOIN group_api_key_mappings on group_api_key_mappings.api_key_id = api_keys.id").
		Where("group_api_key_mappings.key_group_id = ?", groupID)
	if statusFilter != "" && statusFilter != "all" {
		query = query.Where("group_api_key_mappings.status = ?", statusFilter)
	}
	var batchKeys []models.APIKey
	return query.FindInBatches(&batchKeys, 1000, func(tx *gorm.DB, batch int) error {
		if err := r.decryptKeys(batchKeys); err != nil {
			r.logger.Errorf("Failed to decrypt batch %d for streaming: %v", batch, err)
		}
		for _, key := range batchKeys {
			// Keys that failed to decrypt have an empty plaintext and are skipped.
			if key.APIKey != "" {
				if _, err := writer.Write([]byte(key.APIKey + "\n")); err != nil {
					return err
				}
			}
		}
		return nil
	}).Error
}
// UpdateMasterStatusByValues sets master_status for every key whose plaintext
// value is in keyValues (matched via SHA-256 hash) and returns the number of
// rows changed.
func (r *gormKeyRepository) UpdateMasterStatusByValues(keyValues []string, newStatus models.MasterAPIKeyStatus) (int64, error) {
	if len(keyValues) == 0 {
		return 0, nil
	}
	hashes := make([]string, len(keyValues))
	for i, v := range keyValues {
		hash := sha256.Sum256([]byte(v))
		hashes[i] = hex.EncodeToString(hash[:])
	}
	// result is captured from inside the transaction so RowsAffected is
	// readable after it commits.
	var result *gorm.DB
	err := r.executeTransactionWithRetry(func(tx *gorm.DB) error {
		result = tx.Model(&models.APIKey{}).
			Where("api_key_hash IN ?", hashes).
			Update("master_status", newStatus)
		return result.Error
	})
	if err != nil {
		return 0, err
	}
	return result.RowsAffected, nil
}
// UpdateMasterStatusByID sets the master status of a single key, returning
// gorm.ErrRecordNotFound when no row matches the ID.
func (r *gormKeyRepository) UpdateMasterStatusByID(keyID uint, newStatus models.MasterAPIKeyStatus) error {
	return r.executeTransactionWithRetry(func(tx *gorm.DB) error {
		res := tx.Model(&models.APIKey{}).
			Where("id = ?", keyID).
			Update("master_status", newStatus)
		switch {
		case res.Error != nil:
			return res.Error
		case res.RowsAffected == 0:
			// Map "no rows touched" to the standard not-found error.
			return gorm.ErrRecordNotFound
		default:
			return nil
		}
	})
}
// DeleteOrphanKeys removes keys that no longer belong to any group, inside a
// retrying transaction, and returns how many were deleted.
func (r *gormKeyRepository) DeleteOrphanKeys() (int64, error) {
	var removed int64
	err := r.executeTransactionWithRetry(func(tx *gorm.DB) error {
		n, err := r.deleteOrphanKeysLogic(tx)
		if err == nil {
			removed = n
		}
		return err
	})
	return removed, err
}
// DeleteOrphanKeysTx runs the orphan-key cleanup inside a caller-supplied
// transaction; the caller controls commit/rollback.
func (r *gormKeyRepository) DeleteOrphanKeysTx(tx *gorm.DB) (int64, error) {
	return r.deleteOrphanKeysLogic(tx)
}
// deleteOrphanKeysLogic finds keys with no group mapping, (soft-)deletes them
// via the supplied DB handle, and purges their cache entries. Shared by
// DeleteOrphanKeys and DeleteOrphanKeysTx.
func (r *gormKeyRepository) deleteOrphanKeysLogic(db *gorm.DB) (int64, error) {
	// Orphans are keys with no corresponding row in group_api_key_mappings.
	var orphanKeyIDs []uint
	err := db.Raw(`
SELECT api_keys.id FROM api_keys
LEFT JOIN group_api_key_mappings ON api_keys.id = group_api_key_mappings.api_key_id
WHERE group_api_key_mappings.api_key_id IS NULL`).Scan(&orphanKeyIDs).Error
	if err != nil {
		return 0, err
	}
	if len(orphanKeyIDs) == 0 {
		return 0, nil
	}
	// Load the full rows first so the cache can be cleaned after deletion.
	var keysToDelete []models.APIKey
	if err := db.Where("id IN ?", orphanKeyIDs).Find(&keysToDelete).Error; err != nil {
		return 0, err
	}
	// Soft delete; the commented-out Unscoped variant would purge permanently.
	result := db.Delete(&models.APIKey{}, orphanKeyIDs)
	//result := db.Unscoped().Delete(&models.APIKey{}, orphanKeyIDs)
	if result.Error != nil {
		return 0, result.Error
	}
	// Best-effort cache cleanup for each removed key.
	for i := range keysToDelete {
		if err := r.removeStoreCacheForKey(&keysToDelete[i]); err != nil {
			r.logger.Warnf("DB deleted orphan key ID %d, but cache removal failed: %v", keysToDelete[i].ID, err)
		}
	}
	return result.RowsAffected, nil
}
// HardDeleteSoftDeletedBefore permanently purges rows soft-deleted before the
// given cutoff and returns the number of rows removed.
func (r *gormKeyRepository) HardDeleteSoftDeletedBefore(date time.Time) (int64, error) {
	res := r.db.Unscoped().Where("deleted_at < ?", date).Delete(&models.APIKey{})
	return res.RowsAffected, res.Error
}
// GetActiveMasterKeys returns every key whose master status is active,
// decrypted; keys that fail to decrypt are still returned, with a warning.
func (r *gormKeyRepository) GetActiveMasterKeys() ([]*models.APIKey, error) {
	var keys []*models.APIKey
	if err := r.db.Where("master_status = ?", models.MasterStatusActive).Find(&keys).Error; err != nil {
		return nil, err
	}
	for _, k := range keys {
		if err := r.decryptKey(k); err != nil {
			r.logger.Warnf("Failed to decrypt key ID %d during GetActiveMasterKeys: %v", k.ID, err)
		}
	}
	return keys, nil
}
// UpdateAPIKeyStatus sets a key's master_status and, on success, kicks off a
// full cache reload in the background.
// NOTE(review): this largely duplicates UpdateMasterStatusByID plus the
// reload — consider consolidating. The reload goroutine is fire-and-forget,
// so rapid status changes can trigger overlapping full reloads; confirm this
// is acceptable.
func (r *gormKeyRepository) UpdateAPIKeyStatus(keyID uint, status models.MasterAPIKeyStatus) error {
	err := r.executeTransactionWithRetry(func(tx *gorm.DB) error {
		result := tx.Model(&models.APIKey{}).
			Where("id = ?", keyID).
			Update("master_status", status)
		if result.Error != nil {
			return result.Error
		}
		if result.RowsAffected == 0 {
			// No such key: surface the standard not-found error.
			return gorm.ErrRecordNotFound
		}
		return nil
	})
	if err == nil {
		r.logger.Infof("MasterStatus for key ID %d changed, triggering a full cache reload.", keyID)
		go func() {
			if err := r.LoadAllKeysToStore(); err != nil {
				r.logger.Errorf("Failed to reload cache after MasterStatus change for key ID %d: %v", keyID, err)
			}
		}()
	}
	return err
}

View File

@@ -0,0 +1,289 @@
// Filename: internal/repository/key_mapping.go
package repository
import (
"crypto/sha256"
"encoding/hex"
"errors"
"fmt"
"gemini-balancer/internal/models"
"strconv"
"strings"
"gorm.io/gorm"
"gorm.io/gorm/clause"
)
// LinkKeysToGroup associates the given keys with the group, ignoring pairs
// that already exist, and mirrors the membership into the reverse cache index.
func (r *gormKeyRepository) LinkKeysToGroup(groupID uint, keyIDs []uint) error {
	if len(keyIDs) == 0 {
		return nil
	}
	mappings := make([]models.GroupAPIKeyMapping, 0, len(keyIDs))
	for _, keyID := range keyIDs {
		mappings = append(mappings, models.GroupAPIKeyMapping{
			KeyGroupID: groupID,
			APIKeyID:   keyID,
		})
	}
	if err := r.executeTransactionWithRetry(func(tx *gorm.DB) error {
		return tx.Clauses(clause.OnConflict{DoNothing: true}).Create(&mappings).Error
	}); err != nil {
		return err
	}
	// Keep the reverse index (key -> groups) in sync.
	for _, keyID := range keyIDs {
		r.store.SAdd(fmt.Sprintf("key:%d:groups", keyID), groupID)
	}
	return nil
}
// UnlinkKeysFromGroup removes the mapping rows for the given keys in a group
// and evicts them from the related cache structures. Returns how many
// mappings were removed.
func (r *gormKeyRepository) UnlinkKeysFromGroup(groupID uint, keyIDs []uint) (int64, error) {
	if len(keyIDs) == 0 {
		return 0, nil
	}
	var unlinkedCount int64
	err := r.executeTransactionWithRetry(func(tx *gorm.DB) error {
		result := tx.Table("group_api_key_mappings").
			Where("key_group_id = ? AND api_key_id IN ?", groupID, keyIDs).
			Delete(nil)
		if result.Error != nil {
			return result.Error
		}
		unlinkedCount = result.RowsAffected
		return nil
	})
	if err != nil {
		return 0, err
	}
	// Best-effort cache cleanup: the reverse index and the group's active list.
	activeKeyListKey := fmt.Sprintf("group:%d:keys:active", groupID)
	for _, keyID := range keyIDs {
		r.store.SRem(fmt.Sprintf("key:%d:groups", keyID), groupID)
		r.store.LRem(activeKeyListKey, 0, strconv.Itoa(int(keyID)))
	}
	return unlinkedCount, nil
}
// GetGroupsForKey returns the IDs of all groups the key belongs to, serving
// from the cache when possible and falling back to (and re-priming from) the
// database on a miss or store error.
func (r *gormKeyRepository) GetGroupsForKey(keyID uint) ([]uint, error) {
	cacheKey := fmt.Sprintf("key:%d:groups", keyID)
	strGroupIDs, err := r.store.SMembers(cacheKey)
	if err != nil || len(strGroupIDs) == 0 {
		// Cache miss (or store error): read from the DB and repopulate.
		var groupIDs []uint
		dbErr := r.db.Table("group_api_key_mappings").Where("api_key_id = ?", keyID).Pluck("key_group_id", &groupIDs).Error
		if dbErr != nil {
			return nil, dbErr
		}
		if len(groupIDs) > 0 {
			members := make([]interface{}, 0, len(groupIDs))
			for _, id := range groupIDs {
				members = append(members, id)
			}
			r.store.SAdd(cacheKey, members...)
		}
		return groupIDs, nil
	}
	groupIDs := make([]uint, 0, len(strGroupIDs))
	for _, strID := range strGroupIDs {
		// [FIX] The parse error was previously discarded (`id, _ :=
		// strconv.Atoi(...)`), so a corrupt cache member silently became
		// group ID 0. Skip malformed entries instead.
		id, parseErr := strconv.ParseUint(strID, 10, 64)
		if parseErr != nil {
			r.logger.Warnf("Skipping malformed group ID %q in cache key %s: %v", strID, cacheKey, parseErr)
			continue
		}
		groupIDs = append(groupIDs, uint(id))
	}
	return groupIDs, nil
}
// GetMapping loads the mapping row for the (group, key) pair.
// [FIX] Returns (nil, err) on failure; the previous version returned a
// pointer to a zero-value mapping alongside a non-nil error, inviting misuse
// of a meaningless struct.
func (r *gormKeyRepository) GetMapping(groupID, keyID uint) (*models.GroupAPIKeyMapping, error) {
	var mapping models.GroupAPIKeyMapping
	if err := r.db.Where("key_group_id = ? AND api_key_id = ?", groupID, keyID).First(&mapping).Error; err != nil {
		return nil, err
	}
	return &mapping, nil
}
// UpdateMapping persists the mapping and then refreshes its cache entry.
func (r *gormKeyRepository) UpdateMapping(mapping *models.GroupAPIKeyMapping) error {
	saveErr := r.executeTransactionWithRetry(func(tx *gorm.DB) error {
		return tx.Save(mapping).Error
	})
	if saveErr != nil {
		return saveErr
	}
	return r.updateStoreCacheForMapping(mapping)
}
// GetPaginatedKeysAndMappingsByGroup returns one page of key+mapping detail
// rows for a group, plus the total row count, decrypting each key for
// display. KeyGroupID is required; keyword filtering is intentionally left
// to the service layer.
func (r *gormKeyRepository) GetPaginatedKeysAndMappingsByGroup(params *models.APIKeyQueryParams) ([]*models.APIKeyDetails, int64, error) {
	items := make([]*models.APIKeyDetails, 0)
	var total int64
	query := r.db.Table("api_keys").
		Select(`
api_keys.id, api_keys.created_at, api_keys.updated_at,
api_keys.encrypted_key, -- Select encrypted key to be scanned into APIKeyDetails.EncryptedKey
api_keys.master_status,
m.status, m.last_error, m.consecutive_error_count, m.last_used_at, m.cooldown_until
`).
		Joins("JOIN group_api_key_mappings as m ON m.api_key_id = api_keys.id")
	if params.KeyGroupID <= 0 {
		return nil, 0, errors.New("KeyGroupID is required for this query")
	}
	query = query.Where("m.key_group_id = ?", params.KeyGroupID)
	if params.Status != "" {
		// Case-insensitive status match.
		query = query.Where("LOWER(m.status) = LOWER(?)", params.Status)
	}
	// Keyword search is now handled by the service layer.
	if params.Keyword != "" {
		r.logger.Warn("DB query is ignoring keyword; service layer will perform in-memory filtering.")
	}
	// NOTE(review): countQuery chains off query; GORM statement reuse between
	// the Count and the later Scan can leak clauses — confirm no clause bleed.
	countQuery := query.Model(&models.APIKey{}) // Use model for count to avoid GORM issues
	err := countQuery.Count(&total).Error
	if err != nil {
		return nil, 0, err
	}
	if total == 0 {
		return items, 0, nil
	}
	offset := (params.Page - 1) * params.PageSize
	err = query.Order("api_keys.id DESC").Limit(params.PageSize).Offset(offset).Scan(&items).Error
	if err != nil {
		return nil, 0, err
	}
	// Decrypt all results before returning. Failures are surfaced in-band as
	// a placeholder rather than failing the whole page.
	for i := range items {
		if items[i].EncryptedKey != "" {
			plaintext, err := r.crypto.Decrypt(items[i].EncryptedKey)
			if err == nil {
				items[i].APIKey = plaintext
			} else {
				items[i].APIKey = "[DECRYPTION FAILED]"
				r.logger.Errorf("Failed to decrypt key ID %d for pagination: %v", items[i].ID, err)
			}
		}
	}
	return items, total, nil
}
// GetKeysByValuesAndGroupID looks up a group's keys by their plaintext values.
// The lookup runs against the stored SHA-256 hash column, so the plaintext
// never appears in SQL; matched rows are decrypted before being returned.
func (r *gormKeyRepository) GetKeysByValuesAndGroupID(values []string, groupID uint) ([]models.APIKey, error) {
	if len(values) == 0 {
		return []models.APIKey{}, nil
	}
	// Hash each requested value; the table only indexes hashes.
	hashes := make([]string, 0, len(values))
	for _, value := range values {
		sum := sha256.Sum256([]byte(value))
		hashes = append(hashes, hex.EncodeToString(sum[:]))
	}
	var keys []models.APIKey
	err := r.db.
		Joins("JOIN group_api_key_mappings ON group_api_key_mappings.api_key_id = api_keys.id").
		Where("group_api_key_mappings.key_group_id = ?", groupID).
		Where("api_keys.api_key_hash IN ?", hashes).
		Find(&keys).Error
	if err != nil {
		return nil, err
	}
	return keys, r.decryptKeys(keys)
}
// FindKeyValuesByStatus returns the decrypted plaintext values of a group's
// keys, optionally filtered by mapping status. An empty statuses slice, or one
// containing the "all" wildcard (case-insensitive, any position), disables the
// filter — now consistent with GetKeyStringsByGroupAndStatus, which previously
// handled the wildcard more leniently than this method did.
func (r *gormKeyRepository) FindKeyValuesByStatus(groupID uint, statuses []string) ([]string, error) {
	var keys []models.APIKey
	query := r.db.Table("api_keys").
		Select("api_keys.*").
		Joins("JOIN group_api_key_mappings as m ON m.api_key_id = api_keys.id").
		Where("m.key_group_id = ?", groupID)
	// Normalize the filter and detect the "all" wildcard in one pass.
	includeAll := len(statuses) == 0
	filter := make([]string, 0, len(statuses))
	for _, s := range statuses {
		lower := strings.ToLower(s)
		if lower == "all" {
			includeAll = true
			break
		}
		filter = append(filter, lower)
	}
	if !includeAll {
		query = query.Where("LOWER(m.status) IN (?)", filter)
	}
	if err := query.Find(&keys).Error; err != nil {
		return nil, err
	}
	if err := r.decryptKeys(keys); err != nil {
		return nil, fmt.Errorf("decryption failed during FindKeyValuesByStatus: %w", err)
	}
	keyValues := make([]string, len(keys))
	for i, key := range keys {
		keyValues[i] = key.APIKey
	}
	return keyValues, nil
}
// GetKeyStringsByGroupAndStatus returns the decrypted plaintext values of a
// group's keys, filtered by mapping status unless the filter is empty or
// contains the "all" wildcard (case-insensitive, any position).
func (r *gormKeyRepository) GetKeyStringsByGroupAndStatus(groupID uint, statuses []string) ([]string, error) {
	query := r.db.Table("api_keys").
		Select("api_keys.*").
		Joins("JOIN group_api_key_mappings ON group_api_key_mappings.api_key_id = api_keys.id").
		Where("group_api_key_mappings.key_group_id = ?", groupID)
	if len(statuses) > 0 {
		// Lowercase the filter and look for the "all" wildcard in one pass.
		wildcard := false
		lowered := make([]string, 0, len(statuses))
		for _, status := range statuses {
			l := strings.ToLower(status)
			if l == "all" {
				wildcard = true
				break
			}
			lowered = append(lowered, l)
		}
		if !wildcard {
			query = query.Where("LOWER(group_api_key_mappings.status) IN ?", lowered)
		}
	}
	var keys []models.APIKey
	if err := query.Find(&keys).Error; err != nil {
		return nil, err
	}
	if err := r.decryptKeys(keys); err != nil {
		return nil, fmt.Errorf("decryption failed during GetKeyStringsByGroupAndStatus: %w", err)
	}
	result := make([]string, len(keys))
	for i := range keys {
		result[i] = keys[i].APIKey
	}
	return result, nil
}
// FindKeyIDsByStatus returns the IDs of keys mapped into a group, optionally
// filtered by mapping status. It only reads the mapping table, so no key
// material is touched or decrypted here.
func (r *gormKeyRepository) FindKeyIDsByStatus(groupID uint, statuses []string) ([]uint, error) {
	query := r.db.Table("group_api_key_mappings").
		Select("api_key_id").
		Where("key_group_id = ?", groupID)
	skipFilter := len(statuses) == 0 || (len(statuses) == 1 && statuses[0] == "all")
	if !skipFilter {
		lowered := make([]string, len(statuses))
		for i, status := range statuses {
			lowered[i] = strings.ToLower(status)
		}
		query = query.Where("LOWER(status) IN (?)", lowered)
	}
	var keyIDs []uint
	if err := query.Pluck("api_key_id", &keyIDs).Error; err != nil {
		return nil, err
	}
	return keyIDs, nil
}
// UpdateMappingWithoutCache persists a group/key mapping inside a retried
// transaction WITHOUT refreshing the store-side cache. Use this only when the
// caller manages (or deliberately skips) the cache itself; otherwise prefer
// UpdateMapping, which keeps the cache in sync after the write.
func (r *gormKeyRepository) UpdateMappingWithoutCache(mapping *models.GroupAPIKeyMapping) error {
	return r.executeTransactionWithRetry(func(tx *gorm.DB) error {
		return tx.Save(mapping).Error
	})
}

View File

@@ -0,0 +1,276 @@
// Filename: internal/repository/key_selector.go
package repository
import (
"crypto/sha1"
"encoding/json"
"errors"
"fmt"
"gemini-balancer/internal/models"
"gemini-balancer/internal/store"
"io"
"sort"
"strconv"
"time"
"gorm.io/gorm"
)
// Cache lifetimes for the dynamically built BasePool polling structures.
const (
	// CacheTTL bounds how long a built polling pool lives in the store
	// before it must be rebuilt from the per-group active-key sets.
	CacheTTL = 5 * time.Minute
	// EmptyPoolPlaceholder is pushed as the sole list element when a pool
	// has no active keys, so repeated lookups short-circuit instead of
	// rebuilding a still-empty pool on every request.
	EmptyPoolPlaceholder = "EMPTY_POOL"
	// EmptyCacheTTL is the shorter lifetime of the placeholder above,
	// letting an empty pool be re-checked sooner than a populated one.
	EmptyCacheTTL = 1 * time.Minute
)
// SelectOneActiveKey efficiently picks one usable API key for the group from
// the cache, honoring the group's configured polling strategy (sequential,
// weighted/LRU, random-with-cooldown, or plain random as the fallback).
// It returns gorm.ErrRecordNotFound when the active pool is empty.
func (r *gormKeyRepository) SelectOneActiveKey(group *models.KeyGroup) (*models.APIKey, *models.GroupAPIKeyMapping, error) {
	var keyIDStr string
	var err error
	switch group.PollingStrategy {
	case models.StrategySequential:
		sequentialKey := fmt.Sprintf(KeyGroupSequential, group.ID)
		keyIDStr, err = r.store.Rotate(sequentialKey)
	case models.StrategyWeighted:
		// Lowest score = least recently used; take the head of the ZSET.
		lruKey := fmt.Sprintf(KeyGroupLRU, group.ID)
		results, zerr := r.store.ZRange(lruKey, 0, 0)
		if zerr == nil && len(results) > 0 {
			keyIDStr = results[0]
		}
		err = zerr
	case models.StrategyRandom:
		mainPoolKey := fmt.Sprintf(KeyGroupRandomMain, group.ID)
		cooldownPoolKey := fmt.Sprintf(KeyGroupRandomCooldown, group.ID)
		keyIDStr, err = r.store.PopAndCycleSetMember(mainPoolKey, cooldownPoolKey)
	default: // No (or unknown) strategy configured: basic random pick.
		activeKeySetKey := fmt.Sprintf(KeyGroup, group.ID)
		keyIDStr, err = r.store.SRandMember(activeKeySetKey)
	}
	if err != nil {
		if errors.Is(err, store.ErrNotFound) || errors.Is(err, gorm.ErrRecordNotFound) {
			return nil, nil, gorm.ErrRecordNotFound
		}
		r.logger.WithError(err).Errorf("Failed to select key for group %d with strategy %s", group.ID, group.PollingStrategy)
		return nil, nil, err
	}
	if keyIDStr == "" {
		return nil, nil, gorm.ErrRecordNotFound
	}
	keyID, parseErr := strconv.ParseUint(keyIDStr, 10, 64)
	if parseErr != nil {
		// A non-numeric member means the cache is corrupted; surface the
		// error instead of silently looking up key ID 0.
		r.logger.WithError(parseErr).Errorf("Corrupted cache member %q for group %d", keyIDStr, group.ID)
		return nil, nil, parseErr
	}
	apiKey, mapping, err := r.getKeyDetailsFromCache(uint(keyID), group.ID)
	if err != nil {
		r.logger.WithError(err).Warnf("Cache inconsistency: Failed to get details for selected key ID %d", keyID)
		// TODO: retry here by calling SelectOneActiveKey(group) again.
		return nil, nil, err
	}
	// Weighted strategy: push the key to the back of the LRU queue off the hot path.
	if group.PollingStrategy == models.StrategyWeighted {
		go r.UpdateKeyUsageTimestamp(group.ID, uint(keyID))
	}
	return apiKey, mapping, nil
}
// SelectOneActiveKeyFromBasePool picks one usable key from a virtual BasePool
// built for the smart-aggregation mode. The pool's rotation state is keyed by
// a stable pool ID derived from the candidate-group set, so different request
// combinations rotate independently. It returns gorm.ErrRecordNotFound when
// the pool has no active keys.
func (r *gormKeyRepository) SelectOneActiveKeyFromBasePool(pool *BasePool) (*models.APIKey, *models.KeyGroup, error) {
	// The stable pool ID isolates rotation state per group combination.
	poolID := generatePoolID(pool.CandidateGroups)
	log := r.logger.WithField("pool_id", poolID)
	if err := r.ensureBasePoolCacheExists(pool, poolID); err != nil {
		log.WithError(err).Error("Failed to ensure BasePool cache exists.")
		return nil, nil, err
	}
	var keyIDStr string
	var err error
	switch pool.PollingStrategy {
	case models.StrategySequential:
		sequentialKey := fmt.Sprintf(BasePoolSequential, poolID)
		keyIDStr, err = r.store.Rotate(sequentialKey)
	case models.StrategyWeighted:
		lruKey := fmt.Sprintf(BasePoolLRU, poolID)
		results, zerr := r.store.ZRange(lruKey, 0, 0)
		if zerr == nil && len(results) > 0 {
			keyIDStr = results[0]
		}
		err = zerr
	case models.StrategyRandom:
		mainPoolKey := fmt.Sprintf(BasePoolRandomMain, poolID)
		cooldownPoolKey := fmt.Sprintf(BasePoolRandomCooldown, poolID)
		keyIDStr, err = r.store.PopAndCycleSetMember(mainPoolKey, cooldownPoolKey)
	default: // Should be normalized during cache build; degrade gracefully here.
		log.Warnf("Default polling strategy triggered inside selection. This should be rare.")
		sequentialKey := fmt.Sprintf(BasePoolSequential, poolID)
		keyIDStr, err = r.store.LIndex(sequentialKey, 0)
	}
	if err != nil {
		if errors.Is(err, store.ErrNotFound) {
			return nil, nil, gorm.ErrRecordNotFound
		}
		log.WithError(err).Errorf("Failed to select key from BasePool with strategy %s", pool.PollingStrategy)
		return nil, nil, err
	}
	if keyIDStr == "" {
		return nil, nil, gorm.ErrRecordNotFound
	}
	keyID, parseErr := strconv.ParseUint(keyIDStr, 10, 64)
	if parseErr != nil {
		// A non-numeric member means the pool cache is corrupted; surface it
		// instead of silently looking up key ID 0.
		log.WithError(parseErr).Errorf("Corrupted BasePool cache member %q", keyIDStr)
		return nil, nil, parseErr
	}
	// The pool stores only key IDs; find which candidate group owns the key.
	for _, group := range pool.CandidateGroups {
		apiKey, mapping, cacheErr := r.getKeyDetailsFromCache(uint(keyID), group.ID)
		if cacheErr == nil && apiKey != nil && mapping != nil {
			if pool.PollingStrategy == models.StrategyWeighted {
				go r.updateKeyUsageTimestampForPool(poolID, uint(keyID))
			}
			return apiKey, group, nil
		}
	}
	log.Errorf("Cache inconsistency: Selected KeyID %d from BasePool but could not find its origin group.", keyID)
	return nil, nil, errors.New("cache inconsistency: selected key has no origin group")
}
// ensureBasePoolCacheExists lazily builds the store-side polling structures
// (sequential list, random set, LRU ZSET) for a BasePool when absent. An empty
// pool is cached with a short-lived placeholder so repeated requests do not
// rebuild it; in that case gorm.ErrRecordNotFound is returned.
func (r *gormKeyRepository) ensureBasePoolCacheExists(pool *BasePool, poolID string) error {
	// The sequential LIST doubles as the existence marker for the whole pool.
	listKey := fmt.Sprintf(BasePoolSequential, poolID)
	exists, err := r.store.Exists(listKey)
	if err != nil {
		return err
	}
	if exists {
		val, err := r.store.LIndex(listKey, 0)
		if err == nil && val == EmptyPoolPlaceholder {
			return gorm.ErrRecordNotFound
		}
		return nil
	}
	r.logger.Infof("BasePool cache for pool_id '%s' not found. Building now...", poolID)
	// Collect every active key ID across all candidate groups, plus each
	// key's last-used timestamp as the LRU score for the weighted strategy.
	var allActiveKeyIDs []string
	lruMembers := make(map[string]float64)
	for _, group := range pool.CandidateGroups {
		activeKeySetKey := fmt.Sprintf(KeyGroup, group.ID)
		groupKeyIDs, err := r.store.SMembers(activeKeySetKey)
		if err != nil {
			r.logger.WithError(err).Warnf("Failed to get active keys for group %d during BasePool build", group.ID)
			continue
		}
		allActiveKeyIDs = append(allActiveKeyIDs, groupKeyIDs...)
		for _, keyIDStr := range groupKeyIDs {
			keyID, parseErr := strconv.ParseUint(keyIDStr, 10, 64)
			if parseErr != nil {
				// Skip corrupted members rather than scoring key ID 0.
				continue
			}
			_, mapping, err := r.getKeyDetailsFromCache(uint(keyID), group.ID)
			if err == nil && mapping != nil {
				var score float64
				if mapping.LastUsedAt != nil {
					score = float64(mapping.LastUsedAt.UnixMilli())
				}
				lruMembers[keyIDStr] = score
			}
		}
	}
	if len(allActiveKeyIDs) == 0 {
		// Cache the emptiness briefly so we do not rebuild on every request.
		pipe := r.store.Pipeline()
		pipe.LPush(listKey, EmptyPoolPlaceholder)
		pipe.Expire(listKey, EmptyCacheTTL)
		if err := pipe.Exec(); err != nil {
			r.logger.WithError(err).Errorf("Failed to set empty pool placeholder for pool_id '%s'", poolID)
		}
		return gorm.ErrRecordNotFound
	}
	// Populate the LRU ZSET BEFORE the pipeline sets its TTL: Expire on a key
	// that does not exist yet is a no-op, and the previous order (Expire in
	// the pipeline, ZAdd afterwards) left the ZSET without a TTL — an orphan.
	lruKey := fmt.Sprintf(BasePoolLRU, poolID)
	if len(lruMembers) > 0 {
		if err := r.store.ZAdd(lruKey, lruMembers); err != nil {
			r.logger.WithError(err).Warnf("Failed to populate LRU zset for pool_id '%s'", poolID)
		}
	}
	// Fill the remaining polling structures and bound all lifetimes
	// (about CacheTTL) in one pipeline to avoid orphaned data.
	pipe := r.store.Pipeline()
	pipe.LPush(listKey, toInterfaceSlice(allActiveKeyIDs)...)
	pipe.SAdd(fmt.Sprintf(BasePoolRandomMain, poolID), toInterfaceSlice(allActiveKeyIDs)...)
	pipe.Expire(listKey, CacheTTL)
	pipe.Expire(fmt.Sprintf(BasePoolRandomMain, poolID), CacheTTL)
	pipe.Expire(fmt.Sprintf(BasePoolRandomCooldown, poolID), CacheTTL)
	pipe.Expire(lruKey, CacheTTL)
	return pipe.Exec()
}
// updateKeyUsageTimestampForPool bumps a key's last-used score in the BasePool
// LRU ZSET so the weighted strategy rotates through members fairly.
func (r *gormKeyRepository) updateKeyUsageTimestampForPool(poolID string, keyID uint) {
	lruKey := fmt.Sprintf(BasePoolLRU, poolID)
	members := map[string]float64{
		strconv.FormatUint(uint64(keyID), 10): nowMilli(),
	}
	if err := r.store.ZAdd(lruKey, members); err != nil {
		// Mirror UpdateKeyUsageTimestamp: a lost bump only skews rotation
		// fairness, so log (instead of silently dropping) and continue.
		r.logger.WithError(err).Warnf("Failed to update usage timestamp for key %d in pool %s", keyID, poolID)
	}
}
// generatePoolID derives a stable, order-independent identifier for a set of
// candidate groups: the SHA-1 hex digest of their sorted numeric IDs.
func generatePoolID(groups []*models.KeyGroup) string {
	ids := make([]int, 0, len(groups))
	for _, group := range groups {
		ids = append(ids, int(group.ID))
	}
	// Sorting makes the ID independent of the caller's group ordering.
	sort.Ints(ids)
	digest := sha1.Sum([]byte(fmt.Sprintf("%v", ids)))
	return fmt.Sprintf("%x", digest[:])
}
// toInterfaceSlice adapts a string slice to []interface{}, as required by the
// variadic store operations (SAdd, LPush, ...).
func toInterfaceSlice(slice []string) []interface{} {
	result := make([]interface{}, 0, len(slice))
	for _, value := range slice {
		result = append(result, value)
	}
	return result
}
// nowMilli 返回当前的Unix毫秒时间戳用于LRU/Weighted策略
func nowMilli() float64 {
return float64(time.Now().UnixMilli())
}
// getKeyDetailsFromCache loads the cached JSON blobs for an API key and its
// group mapping, and unmarshals them into model structs.
func (r *gormKeyRepository) getKeyDetailsFromCache(keyID, groupID uint) (*models.APIKey, *models.GroupAPIKeyMapping, error) {
	var apiKey models.APIKey
	apiKeyJSON, err := r.store.Get(fmt.Sprintf(KeyDetails, keyID))
	if err != nil {
		return nil, nil, fmt.Errorf("failed to get key details for key %d: %w", keyID, err)
	}
	if err = json.Unmarshal(apiKeyJSON, &apiKey); err != nil {
		return nil, nil, fmt.Errorf("failed to unmarshal api key %d: %w", keyID, err)
	}
	var mapping models.GroupAPIKeyMapping
	mappingJSON, err := r.store.Get(fmt.Sprintf(KeyMapping, groupID, keyID))
	if err != nil {
		return nil, nil, fmt.Errorf("failed to get mapping details for key %d in group %d: %w", keyID, groupID, err)
	}
	if err = json.Unmarshal(mappingJSON, &mapping); err != nil {
		return nil, nil, fmt.Errorf("failed to unmarshal mapping for key %d in group %d: %w", keyID, groupID, err)
	}
	return &apiKey, &mapping, nil
}

View File

@@ -0,0 +1,77 @@
// Filename: internal/repository/key_writer.go
package repository
import (
"fmt"
"gemini-balancer/internal/errors"
"gemini-balancer/internal/models"
"strconv"
"time"
)
// UpdateKeyUsageTimestamp records "now" as the last-used score for a key in
// the group's LRU ZSET, used by the weighted polling strategy.
func (r *gormKeyRepository) UpdateKeyUsageTimestamp(groupID, keyID uint) {
	member := strconv.FormatUint(uint64(keyID), 10)
	score := float64(time.Now().UnixMilli())
	lruKey := fmt.Sprintf(KeyGroupLRU, groupID)
	if err := r.store.ZAdd(lruKey, map[string]float64{member: score}); err != nil {
		// A lost bump only affects rotation fairness; log and continue.
		r.logger.WithError(err).Warnf("Failed to update usage timestamp for key %d in group %d", keyID, groupID)
	}
}
// SyncKeyStatusInPollingCaches synchronously applies a key-status change to
// the group's polling caches. Intended for core business paths that need the
// cache updated before continuing (contrast with HandleCacheUpdateEvent,
// which serves event subscribers).
func (r *gormKeyRepository) SyncKeyStatusInPollingCaches(groupID, keyID uint, newStatus models.APIKeyStatus) {
	r.logger.Infof("SYNC: Directly updating polling caches for G:%d K:%d -> %s", groupID, keyID, newStatus)
	r.updatePollingCachesLogic(groupID, keyID, newStatus)
}
// HandleCacheUpdateEvent applies a key-status change to the group's polling
// caches in response to a published event. It shares the same cache logic as
// SyncKeyStatusInPollingCaches; only the invocation path (event vs. direct
// call) differs, which the log prefix records.
func (r *gormKeyRepository) HandleCacheUpdateEvent(groupID, keyID uint, newStatus models.APIKeyStatus) {
	r.logger.Infof("EVENT: Updating polling caches for G:%d K:%d -> %s from an event", groupID, keyID, newStatus)
	r.updatePollingCachesLogic(groupID, keyID, newStatus)
}
// updatePollingCachesLogic removes a key from every polling structure of the
// group and, when the new status is Active, re-inserts it into each of them.
// Removal errors are deliberately ignored: the member may simply be absent.
func (r *gormKeyRepository) updatePollingCachesLogic(groupID, keyID uint, newStatus models.APIKeyStatus) {
	member := strconv.FormatUint(uint64(keyID), 10)
	sequentialKey := fmt.Sprintf(KeyGroupSequential, groupID)
	lruKey := fmt.Sprintf(KeyGroupLRU, groupID)
	mainPoolKey := fmt.Sprintf(KeyGroupRandomMain, groupID)
	cooldownPoolKey := fmt.Sprintf(KeyGroupRandomCooldown, groupID)
	// Purge the key from all four structures first (best-effort).
	_ = r.store.LRem(sequentialKey, 0, member)
	_ = r.store.ZRem(lruKey, member)
	_ = r.store.SRem(mainPoolKey, member)
	_ = r.store.SRem(cooldownPoolKey, member)
	if newStatus != models.StatusActive {
		return
	}
	// Active again: put the key back into every active polling structure.
	if err := r.store.LPush(sequentialKey, member); err != nil {
		r.logger.WithError(err).Warnf("Failed to add key %d to sequential list for group %d", keyID, groupID)
	}
	if err := r.store.ZAdd(lruKey, map[string]float64{member: 0}); err != nil {
		r.logger.WithError(err).Warnf("Failed to add key %d to LRU zset for group %d", keyID, groupID)
	}
	if err := r.store.SAdd(mainPoolKey, member); err != nil {
		r.logger.WithError(err).Warnf("Failed to add key %d to random main pool for group %d", keyID, groupID)
	}
}
// UpdateKeyStatusAfterRequest is the central feedback hub invoked after each
// upstream request. On success it only refreshes the LRU timestamp (weighted
// strategy); on failure with a concrete APIError it moves the key into
// cooldown in the polling caches so it is skipped by subsequent selections.
func (r *gormKeyRepository) UpdateKeyStatusAfterRequest(group *models.KeyGroup, key *models.APIKey, success bool, apiErr *errors.APIError) {
	if success {
		if group.PollingStrategy == models.StrategyWeighted {
			// Off the hot path: bump the last-used score asynchronously.
			go r.UpdateKeyUsageTimestamp(group.ID, key.ID)
		}
		return
	}
	if apiErr == nil {
		// Failure without a concrete APIError: nothing actionable, just record it.
		r.logger.Warnf("Request failed for KeyID %d in GroupID %d but no specific API error was provided.", key.ID, group.ID)
		return
	}
	r.logger.Warnf("Request failed for KeyID %d in GroupID %d with error: %s. Temporarily removing from active polling caches.", key.ID, group.ID, apiErr.Message)
	// Synchronous on purpose: the key must leave the active pools before the
	// next selection can run.
	r.SyncKeyStatusInPollingCaches(group.ID, key.ID, models.StatusCooldown)
}

View File

@@ -0,0 +1,107 @@
// Filename: internal/repository/repository.go
package repository
import (
"gemini-balancer/internal/crypto"
"gemini-balancer/internal/errors"
"gemini-balancer/internal/models"
"gemini-balancer/internal/store"
"io"
"time"
"github.com/sirupsen/logrus"
"gorm.io/gorm"
)
// BasePool is a virtual, ephemeral resource pool used by the smart-aggregation
// mode: it merges the active keys of several candidate groups and rotates
// across them with a single polling strategy.
type BasePool struct {
	// CandidateGroups are the key groups whose active keys make up the pool.
	CandidateGroups []*models.KeyGroup
	// PollingStrategy selects how keys are rotated across the merged pool.
	PollingStrategy models.PollingStrategy
}
// KeyRepository is the data-access contract for API keys, their group
// mappings, and the store-backed polling caches. The trailing tags on the
// section headers name the implementation file for each group of methods.
type KeyRepository interface {
	// --- Core selection & scheduling --- key_selector
	SelectOneActiveKey(group *models.KeyGroup) (*models.APIKey, *models.GroupAPIKeyMapping, error)
	SelectOneActiveKeyFromBasePool(pool *BasePool) (*models.APIKey, *models.KeyGroup, error)
	// --- Encryption & decryption --- key_crud
	Decrypt(key *models.APIKey) error
	DecryptBatch(keys []models.APIKey) error
	// --- Basic CRUD --- key_crud
	AddKeys(keys []models.APIKey) ([]models.APIKey, error)
	Update(key *models.APIKey) error
	HardDeleteByID(id uint) error
	HardDeleteByValues(keyValues []string) (int64, error)
	GetKeyByID(id uint) (*models.APIKey, error)
	GetKeyByValue(keyValue string) (*models.APIKey, error)
	GetKeysByValues(keyValues []string) ([]models.APIKey, error)
	GetKeysByIDs(ids []uint) ([]models.APIKey, error) // [New] Batch-fetch keys by a set of primary-key IDs.
	GetKeysByGroup(groupID uint) ([]models.APIKey, error)
	CountByGroup(groupID uint) (int64, error)
	// --- Many-to-many mapping management --- key_mapping
	LinkKeysToGroup(groupID uint, keyIDs []uint) error
	UnlinkKeysFromGroup(groupID uint, keyIDs []uint) (unlinkedCount int64, err error)
	GetGroupsForKey(keyID uint) ([]uint, error)
	GetMapping(groupID, keyID uint) (*models.GroupAPIKeyMapping, error)
	UpdateMapping(mapping *models.GroupAPIKeyMapping) error
	GetPaginatedKeysAndMappingsByGroup(params *models.APIKeyQueryParams) ([]*models.APIKeyDetails, int64, error)
	GetKeysByValuesAndGroupID(values []string, groupID uint) ([]models.APIKey, error)
	FindKeyValuesByStatus(groupID uint, statuses []string) ([]string, error)
	FindKeyIDsByStatus(groupID uint, statuses []string) ([]uint, error)
	GetKeyStringsByGroupAndStatus(groupID uint, statuses []string) ([]string, error)
	UpdateMappingWithoutCache(mapping *models.GroupAPIKeyMapping) error
	// --- Cache management --- key_cache
	LoadAllKeysToStore() error
	HandleCacheUpdateEventBatch(mappings []*models.GroupAPIKeyMapping) error
	// --- Maintenance & background tasks --- key_maintenance
	StreamKeysToWriter(groupID uint, statusFilter string, writer io.Writer) error
	UpdateMasterStatusByValues(keyValues []string, newStatus models.MasterAPIKeyStatus) (int64, error)
	UpdateMasterStatusByID(keyID uint, newStatus models.MasterAPIKeyStatus) error
	DeleteOrphanKeys() (int64, error)
	DeleteOrphanKeysTx(tx *gorm.DB) (int64, error)
	GetActiveMasterKeys() ([]*models.APIKey, error)
	UpdateAPIKeyStatus(keyID uint, status models.MasterAPIKeyStatus) error
	HardDeleteSoftDeletedBefore(date time.Time) (int64, error)
	// --- "Write" side of the polling strategies --- key_writer
	UpdateKeyUsageTimestamp(groupID, keyID uint)
	// Synchronous cache update, for core business paths.
	SyncKeyStatusInPollingCaches(groupID, keyID uint, newStatus models.APIKeyStatus)
	// Asynchronous cache update, for event subscribers.
	HandleCacheUpdateEvent(groupID, keyID uint, newStatus models.APIKeyStatus)
	UpdateKeyStatusAfterRequest(group *models.KeyGroup, key *models.APIKey, success bool, apiErr *errors.APIError)
}
// GroupRepository is the data-access contract for key groups.
type GroupRepository interface {
	// GetGroupByName fetches a single group by its name.
	GetGroupByName(name string) (*models.KeyGroup, error)
	// GetAllGroups returns every key group.
	GetAllGroups() ([]*models.KeyGroup, error)
	// UpdateOrderInTransaction applies an order value per group ID atomically.
	UpdateOrderInTransaction(orders map[uint]int) error
}
// gormKeyRepository is the GORM-backed KeyRepository implementation. It pairs
// the relational database with a store (polling caches) and uses the crypto
// service to encrypt/decrypt key material.
type gormKeyRepository struct {
	db     *gorm.DB        // relational persistence
	store  store.Store     // cache / polling structures
	logger *logrus.Entry   // component-scoped logger
	crypto *crypto.Service // key encryption/decryption
}
// gormGroupRepository is the GORM-backed GroupRepository implementation.
type gormGroupRepository struct {
	db *gorm.DB
}
// NewKeyRepository wires up the GORM-backed KeyRepository with its database,
// cache store, crypto service, and a component-scoped logger entry.
func NewKeyRepository(db *gorm.DB, s store.Store, logger *logrus.Logger, crypto *crypto.Service) KeyRepository {
	return &gormKeyRepository{
		db:     db,
		store:  s,
		logger: logger.WithField("component", "repository.key🔗"),
		crypto: crypto,
	}
}
// NewGroupRepository wires up the GORM-backed GroupRepository.
func NewGroupRepository(db *gorm.DB) GroupRepository {
	return &gormGroupRepository{db: db}
}