Files
gemini-banlancer/internal/repository/key_crud.go
2025-11-20 12:24:05 +08:00

281 lines
8.3 KiB
Go

// Filename: internal/repository/key_crud.go
package repository
import (
"crypto/sha256"
"encoding/hex"
"fmt"
"gemini-balancer/internal/models"
"math/rand"
"strings"
"time"
"gorm.io/gorm"
"gorm.io/gorm/clause"
)
// AddKeys stores the given keys, identifying each one by the hex-encoded
// SHA-256 hash of its plaintext APIKey, and returns the rows matching those
// hashes after they have been passed through r.decryptKeys.
//
// For every distinct plaintext in keys:
//   - a soft-deleted row with the same hash is restored (deleted_at cleared);
//   - a live row with the same hash is left untouched;
//   - otherwise a new row is created holding only the encrypted key and hash.
//
// A plaintext that appears more than once in keys is processed a single time.
// An empty keys slice yields an empty, non-nil result. An empty plaintext at
// any index is rejected before the database is touched.
//
// All database work runs in one transaction. If any step fails the
// transaction is rolled back and (nil, err) is returned, so callers never
// receive rows that were not committed.
func (r *gormKeyRepository) AddKeys(keys []models.APIKey) ([]models.APIKey, error) {
	if len(keys) == 0 {
		return []models.APIKey{}, nil
	}

	// pendingKey pairs a plaintext with its hash so the transaction body does
	// not need a second lookup structure.
	type pendingKey struct {
		plaintext string
		hash      string
	}
	pending := make([]pendingKey, 0, len(keys))
	keyHashes := make([]string, 0, len(keys))
	seen := make(map[string]struct{}, len(keys))
	for i, k := range keys {
		// All incoming keys must carry a plaintext APIKey.
		if k.APIKey == "" {
			return nil, fmt.Errorf("cannot add key at index %d: plaintext APIKey is empty", i)
		}
		sum := sha256.Sum256([]byte(k.APIKey))
		hashStr := hex.EncodeToString(sum[:])
		// Skip repeats so the same key is not encrypted and inserted twice
		// within one batch.
		if _, dup := seen[hashStr]; dup {
			continue
		}
		seen[hashStr] = struct{}{}
		pending = append(pending, pendingKey{plaintext: k.APIKey, hash: hashStr})
		keyHashes = append(keyHashes, hashStr)
	}

	var finalKeys []models.APIKey
	err := r.db.Transaction(func(tx *gorm.DB) error {
		// Unscoped so that soft-deleted rows are visible and can be restored.
		var existingKeys []models.APIKey
		if err := tx.Unscoped().Where("api_key_hash IN ?", keyHashes).Find(&existingKeys).Error; err != nil {
			return err
		}
		existingByHash := make(map[string]models.APIKey, len(existingKeys))
		for _, ek := range existingKeys {
			existingByHash[ek.APIKeyHash] = ek
		}

		var keysToCreate []models.APIKey
		var keysToRestore []uint
		for _, p := range pending {
			if ek, found := existingByHash[p.hash]; found {
				if ek.DeletedAt.Valid {
					keysToRestore = append(keysToRestore, ek.ID)
				}
				continue
			}
			encryptedKey, err := r.crypto.Encrypt(p.plaintext)
			if err != nil {
				// Only a short prefix of the key is included in the message.
				return fmt.Errorf("failed to encrypt key '%s...': %w", p.plaintext[:min(4, len(p.plaintext))], err)
			}
			keysToCreate = append(keysToCreate, models.APIKey{
				EncryptedKey: encryptedKey,
				APIKeyHash:   p.hash,
			})
		}

		if len(keysToRestore) > 0 {
			if err := tx.Model(&models.APIKey{}).Unscoped().Where("id IN ?", keysToRestore).Update("deleted_at", nil).Error; err != nil {
				return err
			}
		}
		if len(keysToCreate) > 0 {
			// New rows carry only the encrypted data and the hash; conflicting
			// rows are skipped rather than treated as an error.
			if err := tx.Clauses(clause.OnConflict{DoNothing: true}, clause.Returning{}).Create(&keysToCreate).Error; err != nil {
				return err
			}
		}

		// Re-read every relevant row by hash (existing, restored and created).
		if err := tx.Where("api_key_hash IN ?", keyHashes).Find(&finalKeys).Error; err != nil {
			return err
		}
		// Decrypt before handing the keys back to the caller; a failure here
		// aborts (and rolls back) the whole transaction.
		return r.decryptKeys(finalKeys)
	})
	if err != nil {
		// finalKeys may already hold rows from the rolled-back transaction;
		// do not leak them to the caller.
		return nil, err
	}
	return finalKeys, nil
}
// Update saves key to the database and then tries to refresh its cache entry.
//
// When key.APIKey (the plaintext) is non-empty, the encrypted value and the
// SHA-256 hash are recomputed from it before saving. The save runs through
// executeTransactionWithRetry; its error, if any, is returned.
//
// The cache refresh is best-effort: a decryption or cache failure after a
// successful save is logged as a warning and Update still returns nil.
func (r *gormKeyRepository) Update(key *models.APIKey) error {
	if plaintext := key.APIKey; plaintext != "" {
		ciphertext, encErr := r.crypto.Encrypt(plaintext)
		if encErr != nil {
			return fmt.Errorf("failed to re-encrypt key on update for ID %d: %w", key.ID, encErr)
		}
		key.EncryptedKey = ciphertext

		// Keep the stored hash in step with the plaintext.
		digest := sha256.Sum256([]byte(plaintext))
		key.APIKeyHash = hex.EncodeToString(digest[:])
	}

	// NOTE(review): the original comment states the plaintext field is tagged
	// `gorm:"-"` and therefore not persisted by Save — confirm on the model.
	persist := func(tx *gorm.DB) error {
		return tx.Save(key).Error
	}
	if saveErr := r.executeTransactionWithRetry(persist); saveErr != nil {
		return saveErr
	}

	// The cache update works on the decrypted key.
	if decErr := r.decryptKey(key); decErr != nil {
		r.logger.Warnf("DB updated key ID %d, but decryption for cache failed: %v", key.ID, decErr)
		return nil // The database write succeeded; skip the cache update.
	}
	if cacheErr := r.updateStoreCacheForKey(key); cacheErr != nil {
		r.logger.Warnf("DB updated key ID %d, but cache update failed: %v", key.ID, cacheErr)
	}
	return nil
}
// HardDeleteByID permanently removes the key with the given id.
//
// The key is first loaded via GetKeyByID (so a lookup or decryption failure
// aborts the delete), then deleted with Unscoped inside a retried
// transaction. Removing the key from the store cache afterwards is
// best-effort: a failure there is only logged.
func (r *gormKeyRepository) HardDeleteByID(id uint) error {
	target, lookupErr := r.GetKeyByID(id)
	if lookupErr != nil {
		return lookupErr
	}

	purge := func(tx *gorm.DB) error {
		return tx.Unscoped().Delete(&models.APIKey{}, id).Error
	}
	if txErr := r.executeTransactionWithRetry(purge); txErr != nil {
		return txErr
	}

	if cacheErr := r.removeStoreCacheForKey(target); cacheErr != nil {
		r.logger.Warnf("DB deleted key ID %d, but cache removal failed: %v", id, cacheErr)
	}
	return nil
}
// HardDeleteByValues permanently removes the keys whose plaintext values are
// listed in keyValues and reports how many rows the delete affected.
//
// Keys are located by the SHA-256 hash of each value. The matching rows are
// loaded first so their cache entries can be removed after the delete; a
// decryption failure on those rows is logged but does not stop the delete.
// An empty input, or no matching rows, returns (0, nil).
func (r *gormKeyRepository) HardDeleteByValues(keyValues []string) (int64, error) {
	if len(keyValues) == 0 {
		return 0, nil
	}

	digests := make([]string, 0, len(keyValues))
	for _, value := range keyValues {
		sum := sha256.Sum256([]byte(value))
		digests = append(digests, hex.EncodeToString(sum[:]))
	}

	// Load the full rows up front; they are needed for cache removal later.
	var doomed []models.APIKey
	if findErr := r.db.Where("api_key_hash IN ?", digests).Find(&doomed).Error; findErr != nil {
		return 0, findErr
	}
	if len(doomed) == 0 {
		return 0, nil
	}

	// Cache removal may need the plaintext, so attempt decryption now.
	if decErr := r.decryptKeys(doomed); decErr != nil {
		r.logger.Warnf("Decryption failed for keys before hard delete, cache removal may be impacted: %v", decErr)
	}

	var removed int64
	txErr := r.executeTransactionWithRetry(func(tx *gorm.DB) error {
		res := tx.Unscoped().Where("id IN ?", pluckIDs(doomed)).Delete(&models.APIKey{})
		if res.Error != nil {
			return res.Error
		}
		removed = res.RowsAffected
		return nil
	})
	if txErr != nil {
		return 0, txErr
	}

	for idx := range doomed {
		entry := &doomed[idx]
		if cacheErr := r.removeStoreCacheForKey(entry); cacheErr != nil {
			r.logger.Warnf("DB deleted key ID %d, but cache removal failed: %v", entry.ID, cacheErr)
		}
	}
	return removed, nil
}
// GetKeyByID loads the key with the given primary key and runs it through
// r.decryptKey before returning it. Either a lookup error or a decryption
// error results in a nil key.
func (r *gormKeyRepository) GetKeyByID(id uint) (*models.APIKey, error) {
	record := new(models.APIKey)
	if lookupErr := r.db.First(record, id).Error; lookupErr != nil {
		return nil, lookupErr
	}
	if decErr := r.decryptKey(record); decErr != nil {
		return nil, decErr
	}
	return record, nil
}
// GetKeysByIDs loads the keys whose primary keys are in ids and passes them
// through r.decryptKeys. An empty ids slice yields an empty, non-nil result
// without querying. A query error returns a nil slice; a decryption error is
// returned together with the loaded keys.
func (r *gormKeyRepository) GetKeysByIDs(ids []uint) ([]models.APIKey, error) {
	if len(ids) == 0 {
		return []models.APIKey{}, nil
	}

	var found []models.APIKey
	if queryErr := r.db.Where("id IN ?", ids).Find(&found).Error; queryErr != nil {
		return nil, queryErr
	}

	// Decrypt before returning.
	decryptErr := r.decryptKeys(found)
	return found, decryptErr
}
// GetKeyByValue looks up a key by the SHA-256 hash of its plaintext value.
// No decryption is performed: the caller already supplied the plaintext, so
// it is copied straight into the returned key's APIKey field.
func (r *gormKeyRepository) GetKeyByValue(keyValue string) (*models.APIKey, error) {
	digest := sha256.Sum256([]byte(keyValue))

	record := &models.APIKey{}
	lookup := r.db.Where("api_key_hash = ?", hex.EncodeToString(digest[:])).First(record)
	if lookup.Error != nil {
		return nil, lookup.Error
	}

	record.APIKey = keyValue
	return record, nil
}
// GetKeysByValues loads the keys matching the given plaintext values, looked
// up by SHA-256 hash, and passes them through r.decryptKeys. An empty input
// yields an empty, non-nil result without querying. A query error returns a
// nil slice; a decryption error is returned together with the loaded keys.
func (r *gormKeyRepository) GetKeysByValues(keyValues []string) ([]models.APIKey, error) {
	if len(keyValues) == 0 {
		return []models.APIKey{}, nil
	}

	digests := make([]string, 0, len(keyValues))
	for _, value := range keyValues {
		sum := sha256.Sum256([]byte(value))
		digests = append(digests, hex.EncodeToString(sum[:]))
	}

	var found []models.APIKey
	if queryErr := r.db.Where("api_key_hash IN ?", digests).Find(&found).Error; queryErr != nil {
		return nil, queryErr
	}

	decryptErr := r.decryptKeys(found)
	return found, decryptErr
}
// GetKeysByGroup loads every key mapped to groupID through the
// group_api_key_mappings table and passes the result through r.decryptKeys.
// A query error returns a nil slice; a decryption error is returned together
// with the loaded keys.
func (r *gormKeyRepository) GetKeysByGroup(groupID uint) ([]models.APIKey, error) {
	const mappingJoin = "JOIN group_api_key_mappings on group_api_key_mappings.api_key_id = api_keys.id"

	query := r.db.Joins(mappingJoin).Where("group_api_key_mappings.key_group_id = ?", groupID)

	var members []models.APIKey
	if queryErr := query.Find(&members).Error; queryErr != nil {
		return nil, queryErr
	}

	decryptErr := r.decryptKeys(members)
	return members, decryptErr
}
// CountByGroup returns the number of keys mapped to groupID through the
// group_api_key_mappings table, along with any query error.
func (r *gormKeyRepository) CountByGroup(groupID uint) (int64, error) {
	const mappingJoin = "JOIN group_api_key_mappings on group_api_key_mappings.api_key_id = api_keys.id"

	var total int64
	result := r.db.Model(&models.APIKey{}).
		Joins(mappingJoin).
		Where("group_api_key_mappings.key_group_id = ?", groupID).
		Count(&total)
	return total, result.Error
}
// --- Helpers ---
// executeTransactionWithRetry runs operation inside r.db.Transaction and
// retries it when the returned error message contains "database is locked".
//
// At most maxRetries attempts are made. Between attempts it sleeps for
// baseDelay plus a random jitter below maxJitter. Any other error is returned
// immediately. If the final attempt is still locked, that error is returned
// right away: no further "retrying" message is logged and no delay is spent,
// because no retry follows.
//
// operation may be invoked more than once, so it must be safe to re-run after
// a rolled-back attempt.
func (r *gormKeyRepository) executeTransactionWithRetry(operation func(tx *gorm.DB) error) error {
	const (
		maxRetries = 3
		baseDelay  = 50 * time.Millisecond
		maxJitter  = 150 * time.Millisecond
	)

	var err error
	for attempt := 1; attempt <= maxRetries; attempt++ {
		err = r.db.Transaction(operation)
		if err == nil {
			return nil
		}
		// Only lock contention is considered transient.
		if !strings.Contains(err.Error(), "database is locked") {
			return err
		}
		// Out of attempts: do not announce or wait for a retry that will
		// never happen.
		if attempt == maxRetries {
			break
		}
		jitter := time.Duration(rand.Intn(int(maxJitter)))
		totalDelay := baseDelay + jitter
		r.logger.Debugf("Database is locked, retrying in %v... (attempt %d/%d)", totalDelay, attempt, maxRetries)
		time.Sleep(totalDelay)
	}
	return err
}
// pluckIDs returns the ID of every key in keys, in the same order. The result
// is always non-nil, even for an empty or nil input.
func pluckIDs(keys []models.APIKey) []uint {
	ids := make([]uint, len(keys))
	for idx := range keys {
		ids[idx] = keys[idx].ID
	}
	return ids
}