Initial commit
This commit is contained in:
206
internal/repository/auth_token.go
Normal file
206
internal/repository/auth_token.go
Normal file
@@ -0,0 +1,206 @@
|
||||
// Filename: internal/repository/auth_token.go
|
||||
|
||||
package repository
|
||||
|
||||
import (
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"gemini-balancer/internal/crypto"
|
||||
"gemini-balancer/internal/models"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// AuthTokenRepository defines the interface for AuthToken data access.
|
||||
type AuthTokenRepository interface {
|
||||
GetAllTokensWithGroups() ([]*models.AuthToken, error)
|
||||
BatchUpdateTokens(updates []*models.TokenUpdateRequest) error
|
||||
GetTokenByHashedValue(tokenHash string) (*models.AuthToken, error) // <-- Add this line
|
||||
SeedAdminToken(encryptedToken, tokenHash string) error // <-- And this line for the seeder
|
||||
}
|
||||
|
||||
type gormAuthTokenRepository struct {
|
||||
db *gorm.DB
|
||||
crypto *crypto.Service
|
||||
logger *logrus.Entry
|
||||
}
|
||||
|
||||
func NewAuthTokenRepository(db *gorm.DB, crypto *crypto.Service, logger *logrus.Logger) AuthTokenRepository {
|
||||
return &gormAuthTokenRepository{
|
||||
db: db,
|
||||
crypto: crypto,
|
||||
logger: logger.WithField("component", "repository.authToken🔐"),
|
||||
}
|
||||
}
|
||||
|
||||
// GetAllTokensWithGroups fetches all tokens and decrypts them for use in services.
|
||||
func (r *gormAuthTokenRepository) GetAllTokensWithGroups() ([]*models.AuthToken, error) {
|
||||
var tokens []*models.AuthToken
|
||||
if err := r.db.Preload("AllowedGroups").Find(&tokens).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// [CRITICAL] Decrypt all tokens before returning them.
|
||||
if err := r.decryptTokens(tokens); err != nil {
|
||||
// Log the error but return the partially decrypted data, as some might be usable.
|
||||
r.logger.WithError(err).Error("Batch decryption failed for some auth tokens.")
|
||||
}
|
||||
return tokens, nil
|
||||
}
|
||||
|
||||
// BatchUpdateTokens provides a transactional way to update all tokens, handling encryption.
|
||||
func (r *gormAuthTokenRepository) BatchUpdateTokens(updates []*models.TokenUpdateRequest) error {
|
||||
return r.db.Transaction(func(tx *gorm.DB) error {
|
||||
// 1. Separate admin and user tokens from the request
|
||||
var adminUpdate *models.TokenUpdateRequest
|
||||
var userUpdates []*models.TokenUpdateRequest
|
||||
for _, u := range updates {
|
||||
if u.IsAdmin {
|
||||
adminUpdate = u
|
||||
} else {
|
||||
userUpdates = append(userUpdates, u)
|
||||
}
|
||||
}
|
||||
|
||||
// 2. Handle Admin Token Update
|
||||
if adminUpdate != nil && adminUpdate.Token != "" {
|
||||
encryptedToken, err := r.crypto.Encrypt(adminUpdate.Token)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to encrypt admin token: %w", err)
|
||||
}
|
||||
hash := sha256.Sum256([]byte(adminUpdate.Token))
|
||||
tokenHash := hex.EncodeToString(hash[:])
|
||||
|
||||
// Update both encrypted value and the hash
|
||||
updateData := map[string]interface{}{
|
||||
"encrypted_token": encryptedToken,
|
||||
"token_hash": tokenHash,
|
||||
}
|
||||
if err := tx.Model(&models.AuthToken{}).Where("is_admin = ?", true).Updates(updateData).Error; err != nil {
|
||||
return fmt.Errorf("failed to update admin token in db: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// 3. Handle User Tokens Upsert
|
||||
var existingTokens []*models.AuthToken
|
||||
if err := tx.Where("is_admin = ?", false).Find(&existingTokens).Error; err != nil {
|
||||
return fmt.Errorf("failed to fetch existing user tokens: %w", err)
|
||||
}
|
||||
existingTokenMap := make(map[uint]bool)
|
||||
for _, t := range existingTokens {
|
||||
existingTokenMap[t.ID] = true
|
||||
}
|
||||
|
||||
var tokensToUpsert []models.AuthToken
|
||||
for _, req := range userUpdates {
|
||||
if req.Token == "" {
|
||||
continue // Skip tokens with empty values
|
||||
}
|
||||
encryptedToken, err := r.crypto.Encrypt(req.Token)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to encrypt token for upsert (ID: %d): %w", req.ID, err)
|
||||
}
|
||||
hash := sha256.Sum256([]byte(req.Token))
|
||||
tokenHash := hex.EncodeToString(hash[:])
|
||||
|
||||
var groups []*models.KeyGroup
|
||||
if len(req.AllowedGroupIDs) > 0 {
|
||||
if err := tx.Find(&groups, req.AllowedGroupIDs).Error; err != nil {
|
||||
return fmt.Errorf("failed to find key groups for token %d: %w", req.ID, err)
|
||||
}
|
||||
}
|
||||
tokensToUpsert = append(tokensToUpsert, models.AuthToken{
|
||||
ID: req.ID,
|
||||
EncryptedToken: encryptedToken,
|
||||
TokenHash: tokenHash,
|
||||
Description: req.Description,
|
||||
Tag: req.Tag,
|
||||
Status: req.Status,
|
||||
IsAdmin: false,
|
||||
AllowedGroups: groups,
|
||||
})
|
||||
}
|
||||
if len(tokensToUpsert) > 0 {
|
||||
if err := tx.Save(&tokensToUpsert).Error; err != nil {
|
||||
return fmt.Errorf("failed to upsert user tokens: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// 4. Handle Deletions
|
||||
incomingUserTokenIDs := make(map[uint]bool)
|
||||
for _, u := range userUpdates {
|
||||
if u.ID != 0 {
|
||||
incomingUserTokenIDs[u.ID] = true
|
||||
}
|
||||
}
|
||||
var idsToDelete []uint
|
||||
for id := range existingTokenMap {
|
||||
if !incomingUserTokenIDs[id] {
|
||||
idsToDelete = append(idsToDelete, id)
|
||||
}
|
||||
}
|
||||
if len(idsToDelete) > 0 {
|
||||
if err := tx.Model(&models.AuthToken{}).Where("id IN ?", idsToDelete).Association("AllowedGroups").Clear(); err != nil {
|
||||
return fmt.Errorf("failed to clear associations for tokens to be deleted: %w", err)
|
||||
}
|
||||
if err := tx.Where("id IN ?", idsToDelete).Delete(&models.AuthToken{}).Error; err != nil {
|
||||
return fmt.Errorf("failed to delete user tokens: %w", err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
// --- Crypto Helper Functions ---
|
||||
|
||||
func (r *gormAuthTokenRepository) decryptToken(token *models.AuthToken) error {
|
||||
if token == nil || token.EncryptedToken == "" || token.Token != "" {
|
||||
return nil // Nothing to decrypt or already done
|
||||
}
|
||||
plaintext, err := r.crypto.Decrypt(token.EncryptedToken)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to decrypt auth token ID %d: %w", token.ID, err)
|
||||
}
|
||||
token.Token = plaintext
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *gormAuthTokenRepository) decryptTokens(tokens []*models.AuthToken) error {
|
||||
for i := range tokens {
|
||||
if err := r.decryptToken(tokens[i]); err != nil {
|
||||
r.logger.Error(err) // Log error but continue for other tokens
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetTokenByHashedValue finds a token by its SHA256 hash for authentication.
|
||||
func (r *gormAuthTokenRepository) GetTokenByHashedValue(tokenHash string) (*models.AuthToken, error) {
|
||||
var authToken models.AuthToken
|
||||
// Find the active token by its hash. This is the core of our secure authentication.
|
||||
err := r.db.Where("token_hash = ? AND status = 'active'", tokenHash).Preload("AllowedGroups").First(&authToken).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// [CRITICAL] Decrypt the token before returning it to the service layer.
|
||||
// This ensures that subsequent logic (like in ResourceService) gets the full, usable object.
|
||||
if err := r.decryptToken(&authToken); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &authToken, nil
|
||||
}
|
||||
|
||||
// SeedAdminToken is a special-purpose function for the seeder to insert the initial admin token.
|
||||
func (r *gormAuthTokenRepository) SeedAdminToken(encryptedToken, tokenHash string) error {
|
||||
adminToken := models.AuthToken{
|
||||
EncryptedToken: encryptedToken,
|
||||
TokenHash: tokenHash,
|
||||
Description: "Default Administrator Token",
|
||||
Tag: "SYSTEM_ADMIN",
|
||||
IsAdmin: true,
|
||||
Status: "active", // Ensure the seeded token is active
|
||||
}
|
||||
// Using FirstOrCreate to be idempotent. If an admin token already exists, it does nothing.
|
||||
return r.db.Where(models.AuthToken{IsAdmin: true}).FirstOrCreate(&adminToken).Error
|
||||
}
|
||||
37
internal/repository/group_repository.go
Normal file
37
internal/repository/group_repository.go
Normal file
@@ -0,0 +1,37 @@
|
||||
// Filename: internal/repository/group_repository.go
|
||||
package repository
|
||||
|
||||
import (
|
||||
"gemini-balancer/internal/models"
|
||||
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
func (r *gormGroupRepository) GetGroupByName(name string) (*models.KeyGroup, error) {
|
||||
var group models.KeyGroup
|
||||
if err := r.db.Where("name = ?", name).First(&group).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &group, nil
|
||||
}
|
||||
|
||||
func (r *gormGroupRepository) GetAllGroups() ([]*models.KeyGroup, error) {
|
||||
var groups []*models.KeyGroup
|
||||
if err := r.db.Order("\"order\" asc, id desc").Find(&groups).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return groups, nil
|
||||
}
|
||||
|
||||
// 更新group排序
|
||||
func (r *gormGroupRepository) UpdateOrderInTransaction(orders map[uint]int) error {
|
||||
return r.db.Transaction(func(tx *gorm.DB) error {
|
||||
for id, order := range orders {
|
||||
result := tx.Model(&models.KeyGroup{}).Where("id = ?", id).Update("order", order)
|
||||
if result.Error != nil {
|
||||
return result.Error
|
||||
}
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
204
internal/repository/key_cache.go
Normal file
204
internal/repository/key_cache.go
Normal file
@@ -0,0 +1,204 @@
|
||||
// Filename: internal/repository/key_cache.go
|
||||
package repository
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"gemini-balancer/internal/models"
|
||||
"strconv"
|
||||
)
|
||||
|
||||
const (
|
||||
KeyGroup = "group:%d:keys:active"
|
||||
KeyDetails = "key:%d:details"
|
||||
KeyMapping = "mapping:%d:%d"
|
||||
KeyGroupSequential = "group:%d:keys:sequential"
|
||||
KeyGroupLRU = "group:%d:keys:lru"
|
||||
KeyGroupRandomMain = "group:%d:keys:random:main"
|
||||
KeyGroupRandomCooldown = "group:%d:keys:random:cooldown"
|
||||
BasePoolSequential = "basepool:%s:keys:sequential"
|
||||
BasePoolLRU = "basepool:%s:keys:lru"
|
||||
BasePoolRandomMain = "basepool:%s:keys:random:main"
|
||||
BasePoolRandomCooldown = "basepool:%s:keys:random:cooldown"
|
||||
)
|
||||
|
||||
func (r *gormKeyRepository) LoadAllKeysToStore() error {
|
||||
r.logger.Info("Starting to load all keys and associations into cache, including polling structures...")
|
||||
var allMappings []*models.GroupAPIKeyMapping
|
||||
if err := r.db.Preload("APIKey").Find(&allMappings).Error; err != nil {
|
||||
return fmt.Errorf("failed to load all mappings with APIKeys from DB: %w", err)
|
||||
}
|
||||
|
||||
keyMap := make(map[uint]*models.APIKey)
|
||||
for _, m := range allMappings {
|
||||
if m.APIKey != nil {
|
||||
keyMap[m.APIKey.ID] = m.APIKey
|
||||
}
|
||||
}
|
||||
keysToDecrypt := make([]models.APIKey, 0, len(keyMap))
|
||||
for _, k := range keyMap {
|
||||
keysToDecrypt = append(keysToDecrypt, *k)
|
||||
}
|
||||
if err := r.decryptKeys(keysToDecrypt); err != nil {
|
||||
r.logger.WithError(err).Error("Critical error during cache preload: batch decryption failed.")
|
||||
}
|
||||
decryptedKeyMap := make(map[uint]models.APIKey)
|
||||
for _, k := range keysToDecrypt {
|
||||
decryptedKeyMap[k.ID] = k
|
||||
}
|
||||
|
||||
activeKeysByGroup := make(map[uint][]*models.GroupAPIKeyMapping)
|
||||
pipe := r.store.Pipeline()
|
||||
detailsToSet := make(map[string][]byte)
|
||||
var allGroups []*models.KeyGroup
|
||||
if err := r.db.Find(&allGroups).Error; err == nil {
|
||||
for _, group := range allGroups {
|
||||
pipe.Del(
|
||||
fmt.Sprintf(KeyGroup, group.ID),
|
||||
fmt.Sprintf(KeyGroupSequential, group.ID),
|
||||
fmt.Sprintf(KeyGroupLRU, group.ID),
|
||||
fmt.Sprintf(KeyGroupRandomMain, group.ID),
|
||||
fmt.Sprintf(KeyGroupRandomCooldown, group.ID),
|
||||
)
|
||||
}
|
||||
} else {
|
||||
r.logger.WithError(err).Error("Failed to get all groups for cache cleanup")
|
||||
}
|
||||
|
||||
for _, mapping := range allMappings {
|
||||
if mapping.APIKey == nil {
|
||||
continue
|
||||
}
|
||||
decryptedKey, ok := decryptedKeyMap[mapping.APIKeyID]
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
keyJSON, _ := json.Marshal(decryptedKey)
|
||||
detailsToSet[fmt.Sprintf(KeyDetails, decryptedKey.ID)] = keyJSON
|
||||
mappingJSON, _ := json.Marshal(mapping)
|
||||
detailsToSet[fmt.Sprintf(KeyMapping, mapping.KeyGroupID, decryptedKey.ID)] = mappingJSON
|
||||
if mapping.Status == models.StatusActive {
|
||||
activeKeysByGroup[mapping.KeyGroupID] = append(activeKeysByGroup[mapping.KeyGroupID], mapping)
|
||||
}
|
||||
}
|
||||
|
||||
for groupID, activeMappings := range activeKeysByGroup {
|
||||
if len(activeMappings) == 0 {
|
||||
continue
|
||||
}
|
||||
var activeKeyIDs []interface{}
|
||||
lruMembers := make(map[string]float64)
|
||||
for _, mapping := range activeMappings {
|
||||
keyIDStr := strconv.FormatUint(uint64(mapping.APIKeyID), 10)
|
||||
activeKeyIDs = append(activeKeyIDs, keyIDStr)
|
||||
var score float64
|
||||
if mapping.LastUsedAt != nil {
|
||||
score = float64(mapping.LastUsedAt.UnixMilli())
|
||||
}
|
||||
lruMembers[keyIDStr] = score
|
||||
}
|
||||
pipe.SAdd(fmt.Sprintf(KeyGroup, groupID), activeKeyIDs...)
|
||||
pipe.LPush(fmt.Sprintf(KeyGroupSequential, groupID), activeKeyIDs...)
|
||||
pipe.SAdd(fmt.Sprintf(KeyGroupRandomMain, groupID), activeKeyIDs...)
|
||||
go r.store.ZAdd(fmt.Sprintf(KeyGroupLRU, groupID), lruMembers)
|
||||
}
|
||||
|
||||
if err := pipe.Exec(); err != nil {
|
||||
return fmt.Errorf("failed to execute pipeline for cache rebuild: %w", err)
|
||||
}
|
||||
for key, value := range detailsToSet {
|
||||
if err := r.store.Set(key, value, 0); err != nil {
|
||||
r.logger.WithError(err).Warnf("Failed to set key detail in cache: %s", key)
|
||||
}
|
||||
}
|
||||
|
||||
r.logger.Info("Cache rebuild complete, including all polling structures.")
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *gormKeyRepository) updateStoreCacheForKey(key *models.APIKey) error {
|
||||
if err := r.decryptKey(key); err != nil {
|
||||
return fmt.Errorf("failed to decrypt key %d for cache update: %w", key.ID, err)
|
||||
}
|
||||
keyJSON, err := json.Marshal(key)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal key %d for cache update: %w", key.ID, err)
|
||||
}
|
||||
return r.store.Set(fmt.Sprintf(KeyDetails, key.ID), keyJSON, 0)
|
||||
}
|
||||
|
||||
func (r *gormKeyRepository) removeStoreCacheForKey(key *models.APIKey) error {
|
||||
groupIDs, err := r.GetGroupsForKey(key.ID)
|
||||
if err != nil {
|
||||
r.logger.Warnf("failed to get groups for key %d to clean up cache lists: %v", key.ID, err)
|
||||
}
|
||||
|
||||
pipe := r.store.Pipeline()
|
||||
pipe.Del(fmt.Sprintf(KeyDetails, key.ID))
|
||||
|
||||
for _, groupID := range groupIDs {
|
||||
pipe.Del(fmt.Sprintf(KeyMapping, groupID, key.ID))
|
||||
|
||||
keyIDStr := strconv.FormatUint(uint64(key.ID), 10)
|
||||
pipe.SRem(fmt.Sprintf(KeyGroup, groupID), keyIDStr)
|
||||
pipe.LRem(fmt.Sprintf(KeyGroupSequential, groupID), 0, keyIDStr)
|
||||
pipe.SRem(fmt.Sprintf(KeyGroupRandomMain, groupID), keyIDStr)
|
||||
pipe.SRem(fmt.Sprintf(KeyGroupRandomCooldown, groupID), keyIDStr)
|
||||
go r.store.ZRem(fmt.Sprintf(KeyGroupLRU, groupID), keyIDStr)
|
||||
}
|
||||
return pipe.Exec()
|
||||
}
|
||||
|
||||
func (r *gormKeyRepository) updateStoreCacheForMapping(mapping *models.GroupAPIKeyMapping) error {
|
||||
pipe := r.store.Pipeline()
|
||||
activeKeyListKey := fmt.Sprintf("group:%d:keys:active", mapping.KeyGroupID)
|
||||
pipe.LRem(activeKeyListKey, 0, mapping.APIKeyID)
|
||||
if mapping.Status == models.StatusActive {
|
||||
pipe.LPush(activeKeyListKey, mapping.APIKeyID)
|
||||
}
|
||||
return pipe.Exec()
|
||||
}
|
||||
|
||||
func (r *gormKeyRepository) HandleCacheUpdateEventBatch(mappings []*models.GroupAPIKeyMapping) error {
|
||||
if len(mappings) == 0 {
|
||||
return nil
|
||||
}
|
||||
groupUpdates := make(map[uint]struct {
|
||||
ToAdd []interface{}
|
||||
ToRemove []interface{}
|
||||
})
|
||||
for _, mapping := range mappings {
|
||||
keyIDStr := strconv.FormatUint(uint64(mapping.APIKeyID), 10)
|
||||
update, ok := groupUpdates[mapping.KeyGroupID]
|
||||
if !ok {
|
||||
update = struct {
|
||||
ToAdd []interface{}
|
||||
ToRemove []interface{}
|
||||
}{}
|
||||
}
|
||||
if mapping.Status == models.StatusActive {
|
||||
update.ToRemove = append(update.ToRemove, keyIDStr)
|
||||
update.ToAdd = append(update.ToAdd, keyIDStr)
|
||||
} else {
|
||||
update.ToRemove = append(update.ToRemove, keyIDStr)
|
||||
}
|
||||
groupUpdates[mapping.KeyGroupID] = update
|
||||
}
|
||||
pipe := r.store.Pipeline()
|
||||
var pipelineError error
|
||||
for groupID, updates := range groupUpdates {
|
||||
activeKeyListKey := fmt.Sprintf("group:%d:keys:active", groupID)
|
||||
if len(updates.ToRemove) > 0 {
|
||||
for _, keyID := range updates.ToRemove {
|
||||
pipe.LRem(activeKeyListKey, 0, keyID)
|
||||
}
|
||||
}
|
||||
if len(updates.ToAdd) > 0 {
|
||||
pipe.LPush(activeKeyListKey, updates.ToAdd...)
|
||||
}
|
||||
}
|
||||
if err := pipe.Exec(); err != nil {
|
||||
pipelineError = fmt.Errorf("redis pipeline execution failed: %w", err)
|
||||
}
|
||||
return pipelineError
|
||||
}
|
||||
280
internal/repository/key_crud.go
Normal file
280
internal/repository/key_crud.go
Normal file
@@ -0,0 +1,280 @@
|
||||
// Filename: internal/repository/key_crud.go
|
||||
package repository
|
||||
|
||||
import (
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"gemini-balancer/internal/models"
|
||||
|
||||
"math/rand"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/clause"
|
||||
)
|
||||
|
||||
func (r *gormKeyRepository) AddKeys(keys []models.APIKey) ([]models.APIKey, error) {
|
||||
if len(keys) == 0 {
|
||||
return []models.APIKey{}, nil
|
||||
}
|
||||
keyHashes := make([]string, len(keys))
|
||||
keyValueToHashMap := make(map[string]string)
|
||||
for i, k := range keys {
|
||||
// All incoming keys must have plaintext APIKey
|
||||
if k.APIKey == "" {
|
||||
return nil, fmt.Errorf("cannot add key at index %d: plaintext APIKey is empty", i)
|
||||
}
|
||||
hash := sha256.Sum256([]byte(k.APIKey))
|
||||
hashStr := hex.EncodeToString(hash[:])
|
||||
keyHashes[i] = hashStr
|
||||
keyValueToHashMap[k.APIKey] = hashStr
|
||||
}
|
||||
var finalKeys []models.APIKey
|
||||
err := r.db.Transaction(func(tx *gorm.DB) error {
|
||||
var existingKeys []models.APIKey
|
||||
// [MODIFIED] Query by hash to find existing keys.
|
||||
if err := tx.Unscoped().Where("api_key_hash IN ?", keyHashes).Find(&existingKeys).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
existingKeyHashMap := make(map[string]models.APIKey)
|
||||
for _, k := range existingKeys {
|
||||
existingKeyHashMap[k.APIKeyHash] = k
|
||||
}
|
||||
var keysToCreate []models.APIKey
|
||||
var keysToRestore []uint
|
||||
for _, keyObj := range keys {
|
||||
keyVal := keyObj.APIKey
|
||||
hash := keyValueToHashMap[keyVal]
|
||||
if ek, found := existingKeyHashMap[hash]; found {
|
||||
if ek.DeletedAt.Valid {
|
||||
keysToRestore = append(keysToRestore, ek.ID)
|
||||
}
|
||||
} else {
|
||||
encryptedKey, err := r.crypto.Encrypt(keyVal)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to encrypt key '%s...': %w", keyVal[:min(4, len(keyVal))], err)
|
||||
}
|
||||
keysToCreate = append(keysToCreate, models.APIKey{
|
||||
EncryptedKey: encryptedKey,
|
||||
APIKeyHash: hash,
|
||||
})
|
||||
}
|
||||
}
|
||||
if len(keysToRestore) > 0 {
|
||||
if err := tx.Model(&models.APIKey{}).Unscoped().Where("id IN ?", keysToRestore).Update("deleted_at", nil).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
if len(keysToCreate) > 0 {
|
||||
// [MODIFIED] Create now only provides encrypted data and hash.
|
||||
if err := tx.Clauses(clause.OnConflict{DoNothing: true}, clause.Returning{}).Create(&keysToCreate).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
// [MODIFIED] Final select uses hashes to retrieve all relevant keys.
|
||||
if err := tx.Where("api_key_hash IN ?", keyHashes).Find(&finalKeys).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
// [CRITICAL] Decrypt all keys before returning them to the service layer.
|
||||
return r.decryptKeys(finalKeys)
|
||||
})
|
||||
return finalKeys, err
|
||||
}
|
||||
|
||||
func (r *gormKeyRepository) Update(key *models.APIKey) error {
|
||||
// [CRITICAL] Before saving, check if the plaintext APIKey field was populated.
|
||||
// This indicates a potential change that needs to be re-encrypted.
|
||||
if key.APIKey != "" {
|
||||
encryptedKey, err := r.crypto.Encrypt(key.APIKey)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to re-encrypt key on update for ID %d: %w", key.ID, err)
|
||||
}
|
||||
key.EncryptedKey = encryptedKey
|
||||
// Recalculate hash as a defensive measure.
|
||||
hash := sha256.Sum256([]byte(key.APIKey))
|
||||
key.APIKeyHash = hex.EncodeToString(hash[:])
|
||||
}
|
||||
err := r.executeTransactionWithRetry(func(tx *gorm.DB) error {
|
||||
// GORM automatically ignores `key.APIKey` because of the `gorm:"-"` tag.
|
||||
return tx.Save(key).Error
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
// For the cache update, we need the plaintext. Decrypt if it's not already populated.
|
||||
if err := r.decryptKey(key); err != nil {
|
||||
r.logger.Warnf("DB updated key ID %d, but decryption for cache failed: %v", key.ID, err)
|
||||
return nil // Continue without cache update if decryption fails.
|
||||
}
|
||||
if err := r.updateStoreCacheForKey(key); err != nil {
|
||||
r.logger.Warnf("DB updated key ID %d, but cache update failed: %v", key.ID, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *gormKeyRepository) HardDeleteByID(id uint) error {
|
||||
key, err := r.GetKeyByID(id) // This now returns a decrypted key
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = r.executeTransactionWithRetry(func(tx *gorm.DB) error {
|
||||
return tx.Unscoped().Delete(&models.APIKey{}, id).Error
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := r.removeStoreCacheForKey(key); err != nil {
|
||||
r.logger.Warnf("DB deleted key ID %d, but cache removal failed: %v", id, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *gormKeyRepository) HardDeleteByValues(keyValues []string) (int64, error) {
|
||||
if len(keyValues) == 0 {
|
||||
return 0, nil
|
||||
}
|
||||
hashes := make([]string, len(keyValues))
|
||||
for i, v := range keyValues {
|
||||
hash := sha256.Sum256([]byte(v))
|
||||
hashes[i] = hex.EncodeToString(hash[:])
|
||||
}
|
||||
// Find the full key objects first to update the cache later.
|
||||
var keysToDelete []models.APIKey
|
||||
// [MODIFIED] Find by hash.
|
||||
if err := r.db.Where("api_key_hash IN ?", hashes).Find(&keysToDelete).Error; err != nil {
|
||||
return 0, err
|
||||
}
|
||||
if len(keysToDelete) == 0 {
|
||||
return 0, nil
|
||||
}
|
||||
// Decrypt them to ensure cache has plaintext if needed.
|
||||
if err := r.decryptKeys(keysToDelete); err != nil {
|
||||
r.logger.Warnf("Decryption failed for keys before hard delete, cache removal may be impacted: %v", err)
|
||||
}
|
||||
var deletedCount int64
|
||||
err := r.executeTransactionWithRetry(func(tx *gorm.DB) error {
|
||||
ids := pluckIDs(keysToDelete)
|
||||
result := tx.Unscoped().Where("id IN ?", ids).Delete(&models.APIKey{})
|
||||
if result.Error != nil {
|
||||
return result.Error
|
||||
}
|
||||
deletedCount = result.RowsAffected
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
for i := range keysToDelete {
|
||||
if err := r.removeStoreCacheForKey(&keysToDelete[i]); err != nil {
|
||||
r.logger.Warnf("DB deleted key ID %d, but cache removal failed: %v", keysToDelete[i].ID, err)
|
||||
}
|
||||
}
|
||||
return deletedCount, nil
|
||||
}
|
||||
|
||||
func (r *gormKeyRepository) GetKeyByID(id uint) (*models.APIKey, error) {
|
||||
var key models.APIKey
|
||||
if err := r.db.First(&key, id).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := r.decryptKey(&key); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &key, nil
|
||||
}
|
||||
|
||||
func (r *gormKeyRepository) GetKeysByIDs(ids []uint) ([]models.APIKey, error) {
|
||||
if len(ids) == 0 {
|
||||
return []models.APIKey{}, nil
|
||||
}
|
||||
var keys []models.APIKey
|
||||
err := r.db.Where("id IN ?", ids).Find(&keys).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// [CRITICAL] Decrypt before returning.
|
||||
return keys, r.decryptKeys(keys)
|
||||
}
|
||||
|
||||
func (r *gormKeyRepository) GetKeyByValue(keyValue string) (*models.APIKey, error) {
|
||||
hash := sha256.Sum256([]byte(keyValue))
|
||||
hashStr := hex.EncodeToString(hash[:])
|
||||
var key models.APIKey
|
||||
if err := r.db.Where("api_key_hash = ?", hashStr).First(&key).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
key.APIKey = keyValue
|
||||
return &key, nil
|
||||
}
|
||||
|
||||
func (r *gormKeyRepository) GetKeysByValues(keyValues []string) ([]models.APIKey, error) {
|
||||
if len(keyValues) == 0 {
|
||||
return []models.APIKey{}, nil
|
||||
}
|
||||
hashes := make([]string, len(keyValues))
|
||||
for i, v := range keyValues {
|
||||
hash := sha256.Sum256([]byte(v))
|
||||
hashes[i] = hex.EncodeToString(hash[:])
|
||||
}
|
||||
var keys []models.APIKey
|
||||
err := r.db.Where("api_key_hash IN ?", hashes).Find(&keys).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return keys, r.decryptKeys(keys)
|
||||
}
|
||||
|
||||
func (r *gormKeyRepository) GetKeysByGroup(groupID uint) ([]models.APIKey, error) {
|
||||
var keys []models.APIKey
|
||||
err := r.db.Joins("JOIN group_api_key_mappings on group_api_key_mappings.api_key_id = api_keys.id").
|
||||
Where("group_api_key_mappings.key_group_id = ?", groupID).
|
||||
Find(&keys).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return keys, r.decryptKeys(keys)
|
||||
}
|
||||
|
||||
func (r *gormKeyRepository) CountByGroup(groupID uint) (int64, error) {
|
||||
var count int64
|
||||
err := r.db.Model(&models.APIKey{}).
|
||||
Joins("JOIN group_api_key_mappings on group_api_key_mappings.api_key_id = api_keys.id").
|
||||
Where("group_api_key_mappings.key_group_id = ?", groupID).
|
||||
Count(&count).Error
|
||||
return count, err
|
||||
}
|
||||
|
||||
// --- Helpers ---
|
||||
|
||||
func (r *gormKeyRepository) executeTransactionWithRetry(operation func(tx *gorm.DB) error) error {
|
||||
const maxRetries = 3
|
||||
const baseDelay = 50 * time.Millisecond
|
||||
const maxJitter = 150 * time.Millisecond
|
||||
var err error
|
||||
|
||||
for i := 0; i < maxRetries; i++ {
|
||||
err = r.db.Transaction(operation)
|
||||
if err == nil {
|
||||
return nil
|
||||
}
|
||||
if strings.Contains(err.Error(), "database is locked") {
|
||||
jitter := time.Duration(rand.Intn(int(maxJitter)))
|
||||
totalDelay := baseDelay + jitter
|
||||
r.logger.Debugf("Database is locked, retrying in %v... (attempt %d/%d)", totalDelay, i+1, maxRetries)
|
||||
time.Sleep(totalDelay)
|
||||
continue
|
||||
}
|
||||
break
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func pluckIDs(keys []models.APIKey) []uint {
|
||||
ids := make([]uint, 0, len(keys))
|
||||
for _, key := range keys {
|
||||
ids = append(ids, key.ID)
|
||||
}
|
||||
return ids
|
||||
}
|
||||
62
internal/repository/key_crypto.go
Normal file
62
internal/repository/key_crypto.go
Normal file
@@ -0,0 +1,62 @@
|
||||
// Filename: internal/repository/key_crypto.go
|
||||
package repository
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"gemini-balancer/internal/models"
|
||||
)
|
||||
|
||||
func (r *gormKeyRepository) decryptKey(key *models.APIKey) error {
|
||||
if key == nil || key.EncryptedKey == "" {
|
||||
return nil // Nothing to decrypt
|
||||
}
|
||||
// Avoid re-decrypting if plaintext already exists
|
||||
if key.APIKey != "" {
|
||||
return nil
|
||||
}
|
||||
plaintext, err := r.crypto.Decrypt(key.EncryptedKey)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to decrypt key ID %d: %w", key.ID, err)
|
||||
}
|
||||
key.APIKey = plaintext
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *gormKeyRepository) decryptKeys(keys []models.APIKey) error {
|
||||
for i := range keys {
|
||||
if err := r.decryptKey(&keys[i]); err != nil {
|
||||
// In a batch operation, we log the error but allow the rest to proceed.
|
||||
r.logger.Errorf("Batch decrypt error for key index %d: %v", i, err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Decrypt 实现了 KeyRepository 接口
|
||||
func (r *gormKeyRepository) Decrypt(key *models.APIKey) error {
|
||||
if key == nil || len(key.EncryptedKey) == 0 {
|
||||
return nil // Nothing to decrypt
|
||||
}
|
||||
// Avoid re-decrypting if plaintext already exists
|
||||
if key.APIKey != "" {
|
||||
return nil
|
||||
}
|
||||
plaintext, err := r.crypto.Decrypt(key.EncryptedKey)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to decrypt key ID %d: %w", key.ID, err)
|
||||
}
|
||||
key.APIKey = plaintext
|
||||
return nil
|
||||
}
|
||||
|
||||
// DecryptBatch 实现了 KeyRepository 接口
|
||||
func (r *gormKeyRepository) DecryptBatch(keys []models.APIKey) error {
|
||||
for i := range keys {
|
||||
// This delegates to the robust single-key decryption logic.
|
||||
if err := r.Decrypt(&keys[i]); err != nil {
|
||||
// In a batch operation, we log the error but allow the rest to proceed.
|
||||
r.logger.Errorf("Batch decrypt error for key index %d (ID: %d): %v", i, keys[i].ID, err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
169
internal/repository/key_maintenance.go
Normal file
169
internal/repository/key_maintenance.go
Normal file
@@ -0,0 +1,169 @@
|
||||
// Filename: internal/repository/key_maintenance.go
|
||||
package repository
|
||||
|
||||
import (
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"gemini-balancer/internal/models"
|
||||
"io"
|
||||
"time"
|
||||
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
func (r *gormKeyRepository) StreamKeysToWriter(groupID uint, statusFilter string, writer io.Writer) error {
|
||||
query := r.db.Model(&models.APIKey{}).
|
||||
Joins("JOIN group_api_key_mappings on group_api_key_mappings.api_key_id = api_keys.id").
|
||||
Where("group_api_key_mappings.key_group_id = ?", groupID)
|
||||
|
||||
if statusFilter != "" && statusFilter != "all" {
|
||||
query = query.Where("group_api_key_mappings.status = ?", statusFilter)
|
||||
}
|
||||
var batchKeys []models.APIKey
|
||||
return query.FindInBatches(&batchKeys, 1000, func(tx *gorm.DB, batch int) error {
|
||||
if err := r.decryptKeys(batchKeys); err != nil {
|
||||
r.logger.Errorf("Failed to decrypt batch %d for streaming: %v", batch, err)
|
||||
}
|
||||
for _, key := range batchKeys {
|
||||
if key.APIKey != "" {
|
||||
if _, err := writer.Write([]byte(key.APIKey + "\n")); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}).Error
|
||||
}
|
||||
|
||||
func (r *gormKeyRepository) UpdateMasterStatusByValues(keyValues []string, newStatus models.MasterAPIKeyStatus) (int64, error) {
|
||||
if len(keyValues) == 0 {
|
||||
return 0, nil
|
||||
}
|
||||
hashes := make([]string, len(keyValues))
|
||||
for i, v := range keyValues {
|
||||
hash := sha256.Sum256([]byte(v))
|
||||
hashes[i] = hex.EncodeToString(hash[:])
|
||||
}
|
||||
var result *gorm.DB
|
||||
err := r.executeTransactionWithRetry(func(tx *gorm.DB) error {
|
||||
result = tx.Model(&models.APIKey{}).
|
||||
Where("api_key_hash IN ?", hashes).
|
||||
Update("master_status", newStatus)
|
||||
return result.Error
|
||||
})
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return result.RowsAffected, nil
|
||||
}
|
||||
|
||||
func (r *gormKeyRepository) UpdateMasterStatusByID(keyID uint, newStatus models.MasterAPIKeyStatus) error {
|
||||
return r.executeTransactionWithRetry(func(tx *gorm.DB) error {
|
||||
result := tx.Model(&models.APIKey{}).
|
||||
Where("id = ?", keyID).
|
||||
Update("master_status", newStatus)
|
||||
if result.Error != nil {
|
||||
return result.Error
|
||||
}
|
||||
if result.RowsAffected == 0 {
|
||||
// This ensures that if the key ID doesn't exist, we return a standard "not found" error.
|
||||
return gorm.ErrRecordNotFound
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
func (r *gormKeyRepository) DeleteOrphanKeys() (int64, error) {
|
||||
var deletedCount int64
|
||||
err := r.executeTransactionWithRetry(func(tx *gorm.DB) error {
|
||||
count, err := r.deleteOrphanKeysLogic(tx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
deletedCount = count
|
||||
return nil
|
||||
})
|
||||
return deletedCount, err
|
||||
}
|
||||
|
||||
func (r *gormKeyRepository) DeleteOrphanKeysTx(tx *gorm.DB) (int64, error) {
|
||||
return r.deleteOrphanKeysLogic(tx)
|
||||
}
|
||||
|
||||
func (r *gormKeyRepository) deleteOrphanKeysLogic(db *gorm.DB) (int64, error) {
|
||||
var orphanKeyIDs []uint
|
||||
err := db.Raw(`
|
||||
SELECT api_keys.id FROM api_keys
|
||||
LEFT JOIN group_api_key_mappings ON api_keys.id = group_api_key_mappings.api_key_id
|
||||
WHERE group_api_key_mappings.api_key_id IS NULL`).Scan(&orphanKeyIDs).Error
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
if len(orphanKeyIDs) == 0 {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
var keysToDelete []models.APIKey
|
||||
if err := db.Where("id IN ?", orphanKeyIDs).Find(&keysToDelete).Error; err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
result := db.Delete(&models.APIKey{}, orphanKeyIDs)
|
||||
//result := db.Unscoped().Delete(&models.APIKey{}, orphanKeyIDs)
|
||||
if result.Error != nil {
|
||||
return 0, result.Error
|
||||
}
|
||||
|
||||
for i := range keysToDelete {
|
||||
if err := r.removeStoreCacheForKey(&keysToDelete[i]); err != nil {
|
||||
r.logger.Warnf("DB deleted orphan key ID %d, but cache removal failed: %v", keysToDelete[i].ID, err)
|
||||
}
|
||||
}
|
||||
|
||||
return result.RowsAffected, nil
|
||||
}
|
||||
|
||||
func (r *gormKeyRepository) HardDeleteSoftDeletedBefore(date time.Time) (int64, error) {
|
||||
result := r.db.Unscoped().Where("deleted_at < ?", date).Delete(&models.APIKey{})
|
||||
return result.RowsAffected, result.Error
|
||||
}
|
||||
|
||||
func (r *gormKeyRepository) GetActiveMasterKeys() ([]*models.APIKey, error) {
|
||||
var keys []*models.APIKey
|
||||
err := r.db.Where("master_status = ?", models.MasterStatusActive).Find(&keys).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
for _, key := range keys {
|
||||
if err := r.decryptKey(key); err != nil {
|
||||
r.logger.Warnf("Failed to decrypt key ID %d during GetActiveMasterKeys: %v", key.ID, err)
|
||||
}
|
||||
}
|
||||
|
||||
return keys, nil
|
||||
}
|
||||
|
||||
func (r *gormKeyRepository) UpdateAPIKeyStatus(keyID uint, status models.MasterAPIKeyStatus) error {
|
||||
err := r.executeTransactionWithRetry(func(tx *gorm.DB) error {
|
||||
result := tx.Model(&models.APIKey{}).
|
||||
Where("id = ?", keyID).
|
||||
Update("master_status", status)
|
||||
if result.Error != nil {
|
||||
return result.Error
|
||||
}
|
||||
if result.RowsAffected == 0 {
|
||||
return gorm.ErrRecordNotFound
|
||||
}
|
||||
return nil
|
||||
})
|
||||
if err == nil {
|
||||
r.logger.Infof("MasterStatus for key ID %d changed, triggering a full cache reload.", keyID)
|
||||
go func() {
|
||||
if err := r.LoadAllKeysToStore(); err != nil {
|
||||
r.logger.Errorf("Failed to reload cache after MasterStatus change for key ID %d: %v", keyID, err)
|
||||
}
|
||||
}()
|
||||
}
|
||||
return err
|
||||
}
|
||||
289
internal/repository/key_mapping.go
Normal file
289
internal/repository/key_mapping.go
Normal file
@@ -0,0 +1,289 @@
|
||||
// Filename: internal/repository/key_mapping.go
|
||||
package repository
|
||||
|
||||
import (
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"errors"
|
||||
"fmt"
|
||||
"gemini-balancer/internal/models"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/clause"
|
||||
)
|
||||
|
||||
func (r *gormKeyRepository) LinkKeysToGroup(groupID uint, keyIDs []uint) error {
|
||||
if len(keyIDs) == 0 {
|
||||
return nil
|
||||
}
|
||||
var mappings []models.GroupAPIKeyMapping
|
||||
for _, keyID := range keyIDs {
|
||||
mappings = append(mappings, models.GroupAPIKeyMapping{
|
||||
KeyGroupID: groupID,
|
||||
APIKeyID: keyID,
|
||||
})
|
||||
}
|
||||
|
||||
err := r.executeTransactionWithRetry(func(tx *gorm.DB) error {
|
||||
return tx.Clauses(clause.OnConflict{DoNothing: true}).Create(&mappings).Error
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, keyID := range keyIDs {
|
||||
r.store.SAdd(fmt.Sprintf("key:%d:groups", keyID), groupID)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *gormKeyRepository) UnlinkKeysFromGroup(groupID uint, keyIDs []uint) (int64, error) {
|
||||
if len(keyIDs) == 0 {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
var unlinkedCount int64
|
||||
err := r.executeTransactionWithRetry(func(tx *gorm.DB) error {
|
||||
result := tx.Table("group_api_key_mappings").
|
||||
Where("key_group_id = ? AND api_key_id IN ?", groupID, keyIDs).
|
||||
Delete(nil)
|
||||
|
||||
if result.Error != nil {
|
||||
return result.Error
|
||||
}
|
||||
unlinkedCount = result.RowsAffected
|
||||
return nil
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
activeKeyListKey := fmt.Sprintf("group:%d:keys:active", groupID)
|
||||
for _, keyID := range keyIDs {
|
||||
r.store.SRem(fmt.Sprintf("key:%d:groups", keyID), groupID)
|
||||
r.store.LRem(activeKeyListKey, 0, strconv.Itoa(int(keyID)))
|
||||
}
|
||||
|
||||
return unlinkedCount, nil
|
||||
}
|
||||
|
||||
func (r *gormKeyRepository) GetGroupsForKey(keyID uint) ([]uint, error) {
|
||||
cacheKey := fmt.Sprintf("key:%d:groups", keyID)
|
||||
strGroupIDs, err := r.store.SMembers(cacheKey)
|
||||
if err != nil || len(strGroupIDs) == 0 {
|
||||
var groupIDs []uint
|
||||
dbErr := r.db.Table("group_api_key_mappings").Where("api_key_id = ?", keyID).Pluck("key_group_id", &groupIDs).Error
|
||||
if dbErr != nil {
|
||||
return nil, dbErr
|
||||
}
|
||||
if len(groupIDs) > 0 {
|
||||
var interfaceSlice []interface{}
|
||||
for _, id := range groupIDs {
|
||||
interfaceSlice = append(interfaceSlice, id)
|
||||
}
|
||||
r.store.SAdd(cacheKey, interfaceSlice...)
|
||||
}
|
||||
return groupIDs, nil
|
||||
}
|
||||
|
||||
var groupIDs []uint
|
||||
for _, strID := range strGroupIDs {
|
||||
id, _ := strconv.Atoi(strID)
|
||||
groupIDs = append(groupIDs, uint(id))
|
||||
}
|
||||
return groupIDs, nil
|
||||
}
|
||||
|
||||
func (r *gormKeyRepository) GetMapping(groupID, keyID uint) (*models.GroupAPIKeyMapping, error) {
|
||||
var mapping models.GroupAPIKeyMapping
|
||||
err := r.db.Where("key_group_id = ? AND api_key_id = ?", groupID, keyID).First(&mapping).Error
|
||||
return &mapping, err
|
||||
}
|
||||
|
||||
func (r *gormKeyRepository) UpdateMapping(mapping *models.GroupAPIKeyMapping) error {
|
||||
err := r.executeTransactionWithRetry(func(tx *gorm.DB) error {
|
||||
return tx.Save(mapping).Error
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return r.updateStoreCacheForMapping(mapping)
|
||||
}
|
||||
|
||||
// [MODIFIED & FINAL] This is the final version for the core refactoring.
|
||||
func (r *gormKeyRepository) GetPaginatedKeysAndMappingsByGroup(params *models.APIKeyQueryParams) ([]*models.APIKeyDetails, int64, error) {
|
||||
items := make([]*models.APIKeyDetails, 0)
|
||||
var total int64
|
||||
|
||||
query := r.db.Table("api_keys").
|
||||
Select(`
|
||||
api_keys.id, api_keys.created_at, api_keys.updated_at,
|
||||
api_keys.encrypted_key, -- Select encrypted key to be scanned into APIKeyDetails.EncryptedKey
|
||||
api_keys.master_status,
|
||||
m.status, m.last_error, m.consecutive_error_count, m.last_used_at, m.cooldown_until
|
||||
`).
|
||||
Joins("JOIN group_api_key_mappings as m ON m.api_key_id = api_keys.id")
|
||||
|
||||
if params.KeyGroupID <= 0 {
|
||||
return nil, 0, errors.New("KeyGroupID is required for this query")
|
||||
}
|
||||
query = query.Where("m.key_group_id = ?", params.KeyGroupID)
|
||||
|
||||
if params.Status != "" {
|
||||
query = query.Where("LOWER(m.status) = LOWER(?)", params.Status)
|
||||
}
|
||||
|
||||
// Keyword search is now handled by the service layer.
|
||||
if params.Keyword != "" {
|
||||
r.logger.Warn("DB query is ignoring keyword; service layer will perform in-memory filtering.")
|
||||
}
|
||||
|
||||
countQuery := query.Model(&models.APIKey{}) // Use model for count to avoid GORM issues
|
||||
err := countQuery.Count(&total).Error
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
if total == 0 {
|
||||
return items, 0, nil
|
||||
}
|
||||
|
||||
offset := (params.Page - 1) * params.PageSize
|
||||
err = query.Order("api_keys.id DESC").Limit(params.PageSize).Offset(offset).Scan(&items).Error
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
// Decrypt all results before returning. This loop is now valid.
|
||||
for i := range items {
|
||||
if items[i].EncryptedKey != "" {
|
||||
plaintext, err := r.crypto.Decrypt(items[i].EncryptedKey)
|
||||
if err == nil {
|
||||
items[i].APIKey = plaintext
|
||||
} else {
|
||||
items[i].APIKey = "[DECRYPTION FAILED]"
|
||||
r.logger.Errorf("Failed to decrypt key ID %d for pagination: %v", items[i].ID, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
return items, total, nil
|
||||
}
|
||||
|
||||
// [MODIFIED & FINAL] Uses hashes for lookup.
|
||||
func (r *gormKeyRepository) GetKeysByValuesAndGroupID(values []string, groupID uint) ([]models.APIKey, error) {
|
||||
if len(values) == 0 {
|
||||
return []models.APIKey{}, nil
|
||||
}
|
||||
hashes := make([]string, len(values))
|
||||
for i, v := range values {
|
||||
hash := sha256.Sum256([]byte(v))
|
||||
hashes[i] = hex.EncodeToString(hash[:])
|
||||
}
|
||||
|
||||
var keys []models.APIKey
|
||||
err := r.db.Joins("JOIN group_api_key_mappings ON group_api_key_mappings.api_key_id = api_keys.id").
|
||||
Where("group_api_key_mappings.key_group_id = ?", groupID).
|
||||
Where("api_keys.api_key_hash IN ?", hashes).
|
||||
Find(&keys).Error
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return keys, r.decryptKeys(keys)
|
||||
}
|
||||
|
||||
// [MODIFIED & FINAL] Fetches full objects, decrypts, then extracts strings.
|
||||
func (r *gormKeyRepository) FindKeyValuesByStatus(groupID uint, statuses []string) ([]string, error) {
|
||||
var keys []models.APIKey
|
||||
query := r.db.Table("api_keys").
|
||||
Select("api_keys.*").
|
||||
Joins("JOIN group_api_key_mappings as m ON m.api_key_id = api_keys.id").
|
||||
Where("m.key_group_id = ?", groupID)
|
||||
|
||||
if len(statuses) > 0 && !(len(statuses) == 1 && statuses[0] == "all") {
|
||||
lowerStatuses := make([]string, len(statuses))
|
||||
for i, s := range statuses {
|
||||
lowerStatuses[i] = strings.ToLower(s)
|
||||
}
|
||||
query = query.Where("LOWER(m.status) IN (?)", lowerStatuses)
|
||||
}
|
||||
|
||||
if err := query.Find(&keys).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := r.decryptKeys(keys); err != nil {
|
||||
return nil, fmt.Errorf("decryption failed during FindKeyValuesByStatus: %w", err)
|
||||
}
|
||||
|
||||
keyValues := make([]string, len(keys))
|
||||
for i, key := range keys {
|
||||
keyValues[i] = key.APIKey
|
||||
}
|
||||
return keyValues, nil
|
||||
}
|
||||
|
||||
// [MODIFIED & FINAL] Consistent with the new pattern.
|
||||
func (r *gormKeyRepository) GetKeyStringsByGroupAndStatus(groupID uint, statuses []string) ([]string, error) {
|
||||
var keys []models.APIKey
|
||||
query := r.db.Table("api_keys").
|
||||
Select("api_keys.*").
|
||||
Joins("JOIN group_api_key_mappings ON group_api_key_mappings.api_key_id = api_keys.id").
|
||||
Where("group_api_key_mappings.key_group_id = ?", groupID)
|
||||
|
||||
if len(statuses) > 0 {
|
||||
isAll := false
|
||||
for _, s := range statuses {
|
||||
if strings.ToLower(s) == "all" {
|
||||
isAll = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !isAll {
|
||||
lowerStatuses := make([]string, len(statuses))
|
||||
for i, s := range statuses {
|
||||
lowerStatuses[i] = strings.ToLower(s)
|
||||
}
|
||||
query = query.Where("LOWER(group_api_key_mappings.status) IN ?", lowerStatuses)
|
||||
}
|
||||
}
|
||||
|
||||
if err := query.Find(&keys).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := r.decryptKeys(keys); err != nil {
|
||||
return nil, fmt.Errorf("decryption failed during GetKeyStringsByGroupAndStatus: %w", err)
|
||||
}
|
||||
|
||||
keyStrings := make([]string, len(keys))
|
||||
for i, key := range keys {
|
||||
keyStrings[i] = key.APIKey
|
||||
}
|
||||
return keyStrings, nil
|
||||
}
|
||||
|
||||
// FindKeyIDsByStatus remains unchanged as it does not deal with key values.
|
||||
func (r *gormKeyRepository) FindKeyIDsByStatus(groupID uint, statuses []string) ([]uint, error) {
|
||||
var keyIDs []uint
|
||||
query := r.db.Table("group_api_key_mappings").
|
||||
Select("api_key_id").
|
||||
Where("key_group_id = ?", groupID)
|
||||
if len(statuses) > 0 && !(len(statuses) == 1 && statuses[0] == "all") {
|
||||
lowerStatuses := make([]string, len(statuses))
|
||||
for i, s := range statuses {
|
||||
lowerStatuses[i] = strings.ToLower(s)
|
||||
}
|
||||
query = query.Where("LOWER(status) IN (?)", lowerStatuses)
|
||||
}
|
||||
if err := query.Pluck("api_key_id", &keyIDs).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return keyIDs, nil
|
||||
}
|
||||
|
||||
func (r *gormKeyRepository) UpdateMappingWithoutCache(mapping *models.GroupAPIKeyMapping) error {
|
||||
return r.executeTransactionWithRetry(func(tx *gorm.DB) error {
|
||||
return tx.Save(mapping).Error
|
||||
})
|
||||
}
|
||||
276
internal/repository/key_selector.go
Normal file
276
internal/repository/key_selector.go
Normal file
@@ -0,0 +1,276 @@
|
||||
// Filename: internal/repository/key_selector.go
|
||||
package repository
|
||||
|
||||
import (
|
||||
"crypto/sha1"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"gemini-balancer/internal/models"
|
||||
"gemini-balancer/internal/store"
|
||||
"io"
|
||||
"sort"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
const (
|
||||
CacheTTL = 5 * time.Minute
|
||||
EmptyPoolPlaceholder = "EMPTY_POOL"
|
||||
EmptyCacheTTL = 1 * time.Minute
|
||||
)
|
||||
|
||||
// SelectOneActiveKey 根据指定的轮询策略,从缓存中高效地选取一个可用的API密钥。
|
||||
|
||||
func (r *gormKeyRepository) SelectOneActiveKey(group *models.KeyGroup) (*models.APIKey, *models.GroupAPIKeyMapping, error) {
|
||||
var keyIDStr string
|
||||
var err error
|
||||
|
||||
switch group.PollingStrategy {
|
||||
case models.StrategySequential:
|
||||
sequentialKey := fmt.Sprintf(KeyGroupSequential, group.ID)
|
||||
keyIDStr, err = r.store.Rotate(sequentialKey)
|
||||
|
||||
case models.StrategyWeighted:
|
||||
lruKey := fmt.Sprintf(KeyGroupLRU, group.ID)
|
||||
results, zerr := r.store.ZRange(lruKey, 0, 0)
|
||||
if zerr == nil && len(results) > 0 {
|
||||
keyIDStr = results[0]
|
||||
}
|
||||
err = zerr
|
||||
|
||||
case models.StrategyRandom:
|
||||
mainPoolKey := fmt.Sprintf(KeyGroupRandomMain, group.ID)
|
||||
cooldownPoolKey := fmt.Sprintf(KeyGroupRandomCooldown, group.ID)
|
||||
keyIDStr, err = r.store.PopAndCycleSetMember(mainPoolKey, cooldownPoolKey)
|
||||
|
||||
default: // 默认或未指定策略时,使用基础的随机策略
|
||||
activeKeySetKey := fmt.Sprintf(KeyGroup, group.ID)
|
||||
keyIDStr, err = r.store.SRandMember(activeKeySetKey)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
if errors.Is(err, store.ErrNotFound) || errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, nil, gorm.ErrRecordNotFound
|
||||
}
|
||||
r.logger.WithError(err).Errorf("Failed to select key for group %d with strategy %s", group.ID, group.PollingStrategy)
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
if keyIDStr == "" {
|
||||
return nil, nil, gorm.ErrRecordNotFound
|
||||
}
|
||||
|
||||
keyID, _ := strconv.ParseUint(keyIDStr, 10, 64)
|
||||
|
||||
apiKey, mapping, err := r.getKeyDetailsFromCache(uint(keyID), group.ID)
|
||||
if err != nil {
|
||||
r.logger.WithError(err).Warnf("Cache inconsistency: Failed to get details for selected key ID %d", keyID)
|
||||
// TODO 可以在此加入重试逻辑,再次调用 SelectOneActiveKey(group)
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
if group.PollingStrategy == models.StrategyWeighted {
|
||||
go r.UpdateKeyUsageTimestamp(group.ID, uint(keyID))
|
||||
}
|
||||
|
||||
return apiKey, mapping, nil
|
||||
}
|
||||
|
||||
// SelectOneActiveKeyFromBasePool 为智能聚合模式设计的全新轮询器。
|
||||
func (r *gormKeyRepository) SelectOneActiveKeyFromBasePool(pool *BasePool) (*models.APIKey, *models.KeyGroup, error) {
|
||||
// 生成唯一的池ID,确保不同请求组合的轮询状态相互隔离
|
||||
poolID := generatePoolID(pool.CandidateGroups)
|
||||
log := r.logger.WithField("pool_id", poolID)
|
||||
|
||||
if err := r.ensureBasePoolCacheExists(pool, poolID); err != nil {
|
||||
log.WithError(err).Error("Failed to ensure BasePool cache exists.")
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
var keyIDStr string
|
||||
var err error
|
||||
|
||||
switch pool.PollingStrategy {
|
||||
case models.StrategySequential:
|
||||
sequentialKey := fmt.Sprintf(BasePoolSequential, poolID)
|
||||
keyIDStr, err = r.store.Rotate(sequentialKey)
|
||||
case models.StrategyWeighted:
|
||||
lruKey := fmt.Sprintf(BasePoolLRU, poolID)
|
||||
results, zerr := r.store.ZRange(lruKey, 0, 0)
|
||||
if zerr == nil && len(results) > 0 {
|
||||
keyIDStr = results[0]
|
||||
}
|
||||
err = zerr
|
||||
case models.StrategyRandom:
|
||||
mainPoolKey := fmt.Sprintf(BasePoolRandomMain, poolID)
|
||||
cooldownPoolKey := fmt.Sprintf(BasePoolRandomCooldown, poolID)
|
||||
keyIDStr, err = r.store.PopAndCycleSetMember(mainPoolKey, cooldownPoolKey)
|
||||
default: // 默认策略,应该在 ensureCache 中处理,但作为降级方案
|
||||
log.Warnf("Default polling strategy triggered inside selection. This should be rare.")
|
||||
|
||||
sequentialKey := fmt.Sprintf(BasePoolSequential, poolID)
|
||||
keyIDStr, err = r.store.LIndex(sequentialKey, 0)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
if errors.Is(err, store.ErrNotFound) {
|
||||
return nil, nil, gorm.ErrRecordNotFound
|
||||
}
|
||||
log.WithError(err).Errorf("Failed to select key from BasePool with strategy %s", pool.PollingStrategy)
|
||||
return nil, nil, err
|
||||
}
|
||||
if keyIDStr == "" {
|
||||
return nil, nil, gorm.ErrRecordNotFound
|
||||
}
|
||||
keyID, _ := strconv.ParseUint(keyIDStr, 10, 64)
|
||||
|
||||
for _, group := range pool.CandidateGroups {
|
||||
apiKey, mapping, cacheErr := r.getKeyDetailsFromCache(uint(keyID), group.ID)
|
||||
if cacheErr == nil && apiKey != nil && mapping != nil {
|
||||
|
||||
if pool.PollingStrategy == models.StrategyWeighted {
|
||||
|
||||
go r.updateKeyUsageTimestampForPool(poolID, uint(keyID))
|
||||
}
|
||||
return apiKey, group, nil
|
||||
}
|
||||
}
|
||||
|
||||
log.Errorf("Cache inconsistency: Selected KeyID %d from BasePool but could not find its origin group.", keyID)
|
||||
return nil, nil, errors.New("cache inconsistency: selected key has no origin group")
|
||||
}
|
||||
|
||||
// ensureBasePoolCacheExists 动态创建 BasePool 的 Redis 结构
|
||||
func (r *gormKeyRepository) ensureBasePoolCacheExists(pool *BasePool, poolID string) error {
|
||||
// 使用 LIST 键作为存在性检查的标志
|
||||
listKey := fmt.Sprintf(BasePoolSequential, poolID)
|
||||
exists, err := r.store.Exists(listKey)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if exists {
|
||||
|
||||
val, err := r.store.LIndex(listKey, 0)
|
||||
if err == nil && val == EmptyPoolPlaceholder {
|
||||
return gorm.ErrRecordNotFound
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
r.logger.Infof("BasePool cache for pool_id '%s' not found. Building now...", poolID)
|
||||
|
||||
var allActiveKeyIDs []string
|
||||
lruMembers := make(map[string]float64)
|
||||
for _, group := range pool.CandidateGroups {
|
||||
activeKeySetKey := fmt.Sprintf(KeyGroup, group.ID)
|
||||
groupKeyIDs, err := r.store.SMembers(activeKeySetKey)
|
||||
if err != nil {
|
||||
r.logger.WithError(err).Warnf("Failed to get active keys for group %d during BasePool build", group.ID)
|
||||
continue
|
||||
}
|
||||
allActiveKeyIDs = append(allActiveKeyIDs, groupKeyIDs...)
|
||||
|
||||
for _, keyIDStr := range groupKeyIDs {
|
||||
keyID, _ := strconv.ParseUint(keyIDStr, 10, 64)
|
||||
_, mapping, err := r.getKeyDetailsFromCache(uint(keyID), group.ID)
|
||||
if err == nil && mapping != nil {
|
||||
var score float64
|
||||
if mapping.LastUsedAt != nil {
|
||||
score = float64(mapping.LastUsedAt.UnixMilli())
|
||||
}
|
||||
lruMembers[keyIDStr] = score
|
||||
}
|
||||
}
|
||||
}
|
||||
if len(allActiveKeyIDs) == 0 {
|
||||
pipe := r.store.Pipeline()
|
||||
pipe.LPush(listKey, EmptyPoolPlaceholder)
|
||||
pipe.Expire(listKey, EmptyCacheTTL)
|
||||
if err := pipe.Exec(); err != nil {
|
||||
r.logger.WithError(err).Errorf("Failed to set empty pool placeholder for pool_id '%s'", poolID)
|
||||
}
|
||||
return gorm.ErrRecordNotFound
|
||||
}
|
||||
// 使用管道填充所有轮询结构
|
||||
pipe := r.store.Pipeline()
|
||||
// 1. 顺序
|
||||
pipe.LPush(fmt.Sprintf(BasePoolSequential, poolID), toInterfaceSlice(allActiveKeyIDs)...)
|
||||
// 2. 随机
|
||||
pipe.SAdd(fmt.Sprintf(BasePoolRandomMain, poolID), toInterfaceSlice(allActiveKeyIDs)...)
|
||||
|
||||
// 设置合理的过期时间,例如5分钟,以防止孤儿数据
|
||||
pipe.Expire(fmt.Sprintf(BasePoolSequential, poolID), CacheTTL)
|
||||
pipe.Expire(fmt.Sprintf(BasePoolRandomMain, poolID), CacheTTL)
|
||||
pipe.Expire(fmt.Sprintf(BasePoolRandomCooldown, poolID), CacheTTL)
|
||||
pipe.Expire(fmt.Sprintf(BasePoolLRU, poolID), CacheTTL)
|
||||
|
||||
if err := pipe.Exec(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if len(lruMembers) > 0 {
|
||||
r.store.ZAdd(fmt.Sprintf(BasePoolLRU, poolID), lruMembers)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// updateKeyUsageTimestampForPool 更新 BasePool 的 LUR ZSET
|
||||
func (r *gormKeyRepository) updateKeyUsageTimestampForPool(poolID string, keyID uint) {
|
||||
lruKey := fmt.Sprintf(BasePoolLRU, poolID)
|
||||
r.store.ZAdd(lruKey, map[string]float64{
|
||||
strconv.FormatUint(uint64(keyID), 10): nowMilli(),
|
||||
})
|
||||
}
|
||||
|
||||
// generatePoolID 根据候选组ID列表生成一个稳定的、唯一的字符串ID
|
||||
func generatePoolID(groups []*models.KeyGroup) string {
|
||||
ids := make([]int, len(groups))
|
||||
for i, g := range groups {
|
||||
ids[i] = int(g.ID)
|
||||
}
|
||||
sort.Ints(ids)
|
||||
|
||||
h := sha1.New()
|
||||
io.WriteString(h, fmt.Sprintf("%v", ids))
|
||||
return fmt.Sprintf("%x", h.Sum(nil))
|
||||
}
|
||||
|
||||
// toInterfaceSlice 类型转换辅助函数
|
||||
func toInterfaceSlice(slice []string) []interface{} {
|
||||
result := make([]interface{}, len(slice))
|
||||
for i, v := range slice {
|
||||
result[i] = v
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// nowMilli 返回当前的Unix毫秒时间戳,用于LRU/Weighted策略
|
||||
func nowMilli() float64 {
|
||||
return float64(time.Now().UnixMilli())
|
||||
}
|
||||
|
||||
// getKeyDetailsFromCache 从缓存中获取Key和Mapping的JSON数据。
|
||||
func (r *gormKeyRepository) getKeyDetailsFromCache(keyID, groupID uint) (*models.APIKey, *models.GroupAPIKeyMapping, error) {
|
||||
apiKeyJSON, err := r.store.Get(fmt.Sprintf(KeyDetails, keyID))
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("failed to get key details for key %d: %w", keyID, err)
|
||||
}
|
||||
var apiKey models.APIKey
|
||||
if err := json.Unmarshal(apiKeyJSON, &apiKey); err != nil {
|
||||
return nil, nil, fmt.Errorf("failed to unmarshal api key %d: %w", keyID, err)
|
||||
}
|
||||
|
||||
mappingJSON, err := r.store.Get(fmt.Sprintf(KeyMapping, groupID, keyID))
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("failed to get mapping details for key %d in group %d: %w", keyID, groupID, err)
|
||||
}
|
||||
var mapping models.GroupAPIKeyMapping
|
||||
if err := json.Unmarshal(mappingJSON, &mapping); err != nil {
|
||||
return nil, nil, fmt.Errorf("failed to unmarshal mapping for key %d in group %d: %w", keyID, groupID, err)
|
||||
}
|
||||
|
||||
return &apiKey, &mapping, nil
|
||||
}
|
||||
77
internal/repository/key_writer.go
Normal file
77
internal/repository/key_writer.go
Normal file
@@ -0,0 +1,77 @@
|
||||
// Filename: internal/repository/key_writer.go
|
||||
package repository
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"gemini-balancer/internal/errors"
|
||||
"gemini-balancer/internal/models"
|
||||
"strconv"
|
||||
"time"
|
||||
)
|
||||
|
||||
func (r *gormKeyRepository) UpdateKeyUsageTimestamp(groupID, keyID uint) {
|
||||
lruKey := fmt.Sprintf(KeyGroupLRU, groupID)
|
||||
timestamp := float64(time.Now().UnixMilli())
|
||||
|
||||
members := map[string]float64{
|
||||
strconv.FormatUint(uint64(keyID), 10): timestamp,
|
||||
}
|
||||
|
||||
if err := r.store.ZAdd(lruKey, members); err != nil {
|
||||
r.logger.WithError(err).Warnf("Failed to update usage timestamp for key %d in group %d", keyID, groupID)
|
||||
}
|
||||
}
|
||||
|
||||
func (r *gormKeyRepository) SyncKeyStatusInPollingCaches(groupID, keyID uint, newStatus models.APIKeyStatus) {
|
||||
r.logger.Infof("SYNC: Directly updating polling caches for G:%d K:%d -> %s", groupID, keyID, newStatus)
|
||||
r.updatePollingCachesLogic(groupID, keyID, newStatus)
|
||||
}
|
||||
|
||||
func (r *gormKeyRepository) HandleCacheUpdateEvent(groupID, keyID uint, newStatus models.APIKeyStatus) {
|
||||
r.logger.Infof("EVENT: Updating polling caches for G:%d K:%d -> %s from an event", groupID, keyID, newStatus)
|
||||
r.updatePollingCachesLogic(groupID, keyID, newStatus)
|
||||
}
|
||||
|
||||
func (r *gormKeyRepository) updatePollingCachesLogic(groupID, keyID uint, newStatus models.APIKeyStatus) {
|
||||
keyIDStr := strconv.FormatUint(uint64(keyID), 10)
|
||||
sequentialKey := fmt.Sprintf(KeyGroupSequential, groupID)
|
||||
lruKey := fmt.Sprintf(KeyGroupLRU, groupID)
|
||||
mainPoolKey := fmt.Sprintf(KeyGroupRandomMain, groupID)
|
||||
cooldownPoolKey := fmt.Sprintf(KeyGroupRandomCooldown, groupID)
|
||||
|
||||
_ = r.store.LRem(sequentialKey, 0, keyIDStr)
|
||||
_ = r.store.ZRem(lruKey, keyIDStr)
|
||||
_ = r.store.SRem(mainPoolKey, keyIDStr)
|
||||
_ = r.store.SRem(cooldownPoolKey, keyIDStr)
|
||||
|
||||
if newStatus == models.StatusActive {
|
||||
if err := r.store.LPush(sequentialKey, keyIDStr); err != nil {
|
||||
r.logger.WithError(err).Warnf("Failed to add key %d to sequential list for group %d", keyID, groupID)
|
||||
}
|
||||
members := map[string]float64{keyIDStr: 0}
|
||||
if err := r.store.ZAdd(lruKey, members); err != nil {
|
||||
r.logger.WithError(err).Warnf("Failed to add key %d to LRU zset for group %d", keyID, groupID)
|
||||
}
|
||||
if err := r.store.SAdd(mainPoolKey, keyIDStr); err != nil {
|
||||
r.logger.WithError(err).Warnf("Failed to add key %d to random main pool for group %d", keyID, groupID)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// UpdateKeyStatusAfterRequest is the new central hub for handling feedback.
|
||||
func (r *gormKeyRepository) UpdateKeyStatusAfterRequest(group *models.KeyGroup, key *models.APIKey, success bool, apiErr *errors.APIError) {
|
||||
if success {
|
||||
if group.PollingStrategy == models.StrategyWeighted {
|
||||
go r.UpdateKeyUsageTimestamp(group.ID, key.ID)
|
||||
}
|
||||
return
|
||||
}
|
||||
if apiErr == nil {
|
||||
r.logger.Warnf("Request failed for KeyID %d in GroupID %d but no specific API error was provided.", key.ID, group.ID)
|
||||
return
|
||||
}
|
||||
r.logger.Warnf("Request failed for KeyID %d in GroupID %d with error: %s. Temporarily removing from active polling caches.", key.ID, group.ID, apiErr.Message)
|
||||
|
||||
// This call is correct. It uses the synchronous, direct method.
|
||||
r.SyncKeyStatusInPollingCaches(group.ID, key.ID, models.StatusCooldown)
|
||||
}
|
||||
107
internal/repository/repository.go
Normal file
107
internal/repository/repository.go
Normal file
@@ -0,0 +1,107 @@
|
||||
// Filename: internal/repository/repository.go
|
||||
package repository
|
||||
|
||||
import (
|
||||
"gemini-balancer/internal/crypto"
|
||||
"gemini-balancer/internal/errors"
|
||||
"gemini-balancer/internal/models"
|
||||
"gemini-balancer/internal/store"
|
||||
"io"
|
||||
"time"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// BasePool 虚拟的临时资源池,用于智能聚合模式。
|
||||
type BasePool struct {
|
||||
CandidateGroups []*models.KeyGroup
|
||||
PollingStrategy models.PollingStrategy
|
||||
}
|
||||
|
||||
type KeyRepository interface {
|
||||
// --- 核心选取与调度 --- key_selector
|
||||
SelectOneActiveKey(group *models.KeyGroup) (*models.APIKey, *models.GroupAPIKeyMapping, error)
|
||||
SelectOneActiveKeyFromBasePool(pool *BasePool) (*models.APIKey, *models.KeyGroup, error)
|
||||
|
||||
// --- 加密与解密 --- key_crud
|
||||
Decrypt(key *models.APIKey) error
|
||||
DecryptBatch(keys []models.APIKey) error
|
||||
|
||||
// --- 基础增删改查 --- key_crud
|
||||
AddKeys(keys []models.APIKey) ([]models.APIKey, error)
|
||||
Update(key *models.APIKey) error
|
||||
HardDeleteByID(id uint) error
|
||||
HardDeleteByValues(keyValues []string) (int64, error)
|
||||
GetKeyByID(id uint) (*models.APIKey, error)
|
||||
GetKeyByValue(keyValue string) (*models.APIKey, error)
|
||||
GetKeysByValues(keyValues []string) ([]models.APIKey, error)
|
||||
GetKeysByIDs(ids []uint) ([]models.APIKey, error) // [新增] 根据一组主键ID批量获取Key
|
||||
GetKeysByGroup(groupID uint) ([]models.APIKey, error)
|
||||
CountByGroup(groupID uint) (int64, error)
|
||||
|
||||
// --- 多对多关系管理 --- key_mapping
|
||||
LinkKeysToGroup(groupID uint, keyIDs []uint) error
|
||||
UnlinkKeysFromGroup(groupID uint, keyIDs []uint) (unlinkedCount int64, err error)
|
||||
GetGroupsForKey(keyID uint) ([]uint, error)
|
||||
GetMapping(groupID, keyID uint) (*models.GroupAPIKeyMapping, error)
|
||||
UpdateMapping(mapping *models.GroupAPIKeyMapping) error
|
||||
GetPaginatedKeysAndMappingsByGroup(params *models.APIKeyQueryParams) ([]*models.APIKeyDetails, int64, error)
|
||||
GetKeysByValuesAndGroupID(values []string, groupID uint) ([]models.APIKey, error)
|
||||
FindKeyValuesByStatus(groupID uint, statuses []string) ([]string, error)
|
||||
FindKeyIDsByStatus(groupID uint, statuses []string) ([]uint, error)
|
||||
GetKeyStringsByGroupAndStatus(groupID uint, statuses []string) ([]string, error)
|
||||
UpdateMappingWithoutCache(mapping *models.GroupAPIKeyMapping) error
|
||||
|
||||
// --- 缓存管理 --- key_cache
|
||||
LoadAllKeysToStore() error
|
||||
HandleCacheUpdateEventBatch(mappings []*models.GroupAPIKeyMapping) error
|
||||
|
||||
// --- 维护与后台任务 --- key_maintenance
|
||||
StreamKeysToWriter(groupID uint, statusFilter string, writer io.Writer) error
|
||||
UpdateMasterStatusByValues(keyValues []string, newStatus models.MasterAPIKeyStatus) (int64, error)
|
||||
UpdateMasterStatusByID(keyID uint, newStatus models.MasterAPIKeyStatus) error
|
||||
DeleteOrphanKeys() (int64, error)
|
||||
DeleteOrphanKeysTx(tx *gorm.DB) (int64, error)
|
||||
GetActiveMasterKeys() ([]*models.APIKey, error)
|
||||
UpdateAPIKeyStatus(keyID uint, status models.MasterAPIKeyStatus) error
|
||||
HardDeleteSoftDeletedBefore(date time.Time) (int64, error)
|
||||
|
||||
// --- 轮询策略的"写"操作 --- key_writer
|
||||
UpdateKeyUsageTimestamp(groupID, keyID uint)
|
||||
// 同步更新缓存,供核心业务使用
|
||||
SyncKeyStatusInPollingCaches(groupID, keyID uint, newStatus models.APIKeyStatus)
|
||||
// 异步更新缓存,供事件订阅者使用
|
||||
HandleCacheUpdateEvent(groupID, keyID uint, newStatus models.APIKeyStatus)
|
||||
UpdateKeyStatusAfterRequest(group *models.KeyGroup, key *models.APIKey, success bool, apiErr *errors.APIError)
|
||||
}
|
||||
|
||||
type GroupRepository interface {
|
||||
GetGroupByName(name string) (*models.KeyGroup, error)
|
||||
GetAllGroups() ([]*models.KeyGroup, error)
|
||||
UpdateOrderInTransaction(orders map[uint]int) error
|
||||
}
|
||||
|
||||
type gormKeyRepository struct {
|
||||
db *gorm.DB
|
||||
store store.Store
|
||||
logger *logrus.Entry
|
||||
crypto *crypto.Service
|
||||
}
|
||||
|
||||
type gormGroupRepository struct {
|
||||
db *gorm.DB
|
||||
}
|
||||
|
||||
func NewKeyRepository(db *gorm.DB, s store.Store, logger *logrus.Logger, crypto *crypto.Service) KeyRepository {
|
||||
return &gormKeyRepository{
|
||||
db: db,
|
||||
store: s,
|
||||
logger: logger.WithField("component", "repository.key🔗"),
|
||||
crypto: crypto,
|
||||
}
|
||||
}
|
||||
|
||||
func NewGroupRepository(db *gorm.DB) GroupRepository {
|
||||
return &gormGroupRepository{db: db}
|
||||
}
|
||||
Reference in New Issue
Block a user