281 lines
8.3 KiB
Go
281 lines
8.3 KiB
Go
// Filename: internal/repository/key_crud.go
|
|
package repository
|
|
|
|
import (
|
|
"crypto/sha256"
|
|
"encoding/hex"
|
|
"fmt"
|
|
"gemini-balancer/internal/models"
|
|
|
|
"math/rand"
|
|
"strings"
|
|
"time"
|
|
|
|
"gorm.io/gorm"
|
|
"gorm.io/gorm/clause"
|
|
)
|
|
|
|
func (r *gormKeyRepository) AddKeys(keys []models.APIKey) ([]models.APIKey, error) {
|
|
if len(keys) == 0 {
|
|
return []models.APIKey{}, nil
|
|
}
|
|
keyHashes := make([]string, len(keys))
|
|
keyValueToHashMap := make(map[string]string)
|
|
for i, k := range keys {
|
|
// All incoming keys must have plaintext APIKey
|
|
if k.APIKey == "" {
|
|
return nil, fmt.Errorf("cannot add key at index %d: plaintext APIKey is empty", i)
|
|
}
|
|
hash := sha256.Sum256([]byte(k.APIKey))
|
|
hashStr := hex.EncodeToString(hash[:])
|
|
keyHashes[i] = hashStr
|
|
keyValueToHashMap[k.APIKey] = hashStr
|
|
}
|
|
var finalKeys []models.APIKey
|
|
err := r.db.Transaction(func(tx *gorm.DB) error {
|
|
var existingKeys []models.APIKey
|
|
// [MODIFIED] Query by hash to find existing keys.
|
|
if err := tx.Unscoped().Where("api_key_hash IN ?", keyHashes).Find(&existingKeys).Error; err != nil {
|
|
return err
|
|
}
|
|
existingKeyHashMap := make(map[string]models.APIKey)
|
|
for _, k := range existingKeys {
|
|
existingKeyHashMap[k.APIKeyHash] = k
|
|
}
|
|
var keysToCreate []models.APIKey
|
|
var keysToRestore []uint
|
|
for _, keyObj := range keys {
|
|
keyVal := keyObj.APIKey
|
|
hash := keyValueToHashMap[keyVal]
|
|
if ek, found := existingKeyHashMap[hash]; found {
|
|
if ek.DeletedAt.Valid {
|
|
keysToRestore = append(keysToRestore, ek.ID)
|
|
}
|
|
} else {
|
|
encryptedKey, err := r.crypto.Encrypt(keyVal)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to encrypt key '%s...': %w", keyVal[:min(4, len(keyVal))], err)
|
|
}
|
|
keysToCreate = append(keysToCreate, models.APIKey{
|
|
EncryptedKey: encryptedKey,
|
|
APIKeyHash: hash,
|
|
})
|
|
}
|
|
}
|
|
if len(keysToRestore) > 0 {
|
|
if err := tx.Model(&models.APIKey{}).Unscoped().Where("id IN ?", keysToRestore).Update("deleted_at", nil).Error; err != nil {
|
|
return err
|
|
}
|
|
}
|
|
if len(keysToCreate) > 0 {
|
|
// [MODIFIED] Create now only provides encrypted data and hash.
|
|
if err := tx.Clauses(clause.OnConflict{DoNothing: true}, clause.Returning{}).Create(&keysToCreate).Error; err != nil {
|
|
return err
|
|
}
|
|
}
|
|
// [MODIFIED] Final select uses hashes to retrieve all relevant keys.
|
|
if err := tx.Where("api_key_hash IN ?", keyHashes).Find(&finalKeys).Error; err != nil {
|
|
return err
|
|
}
|
|
// [CRITICAL] Decrypt all keys before returning them to the service layer.
|
|
return r.decryptKeys(finalKeys)
|
|
})
|
|
return finalKeys, err
|
|
}
|
|
|
|
func (r *gormKeyRepository) Update(key *models.APIKey) error {
|
|
// [CRITICAL] Before saving, check if the plaintext APIKey field was populated.
|
|
// This indicates a potential change that needs to be re-encrypted.
|
|
if key.APIKey != "" {
|
|
encryptedKey, err := r.crypto.Encrypt(key.APIKey)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to re-encrypt key on update for ID %d: %w", key.ID, err)
|
|
}
|
|
key.EncryptedKey = encryptedKey
|
|
// Recalculate hash as a defensive measure.
|
|
hash := sha256.Sum256([]byte(key.APIKey))
|
|
key.APIKeyHash = hex.EncodeToString(hash[:])
|
|
}
|
|
err := r.executeTransactionWithRetry(func(tx *gorm.DB) error {
|
|
// GORM automatically ignores `key.APIKey` because of the `gorm:"-"` tag.
|
|
return tx.Save(key).Error
|
|
})
|
|
if err != nil {
|
|
return err
|
|
}
|
|
// For the cache update, we need the plaintext. Decrypt if it's not already populated.
|
|
if err := r.decryptKey(key); err != nil {
|
|
r.logger.Warnf("DB updated key ID %d, but decryption for cache failed: %v", key.ID, err)
|
|
return nil // Continue without cache update if decryption fails.
|
|
}
|
|
if err := r.updateStoreCacheForKey(key); err != nil {
|
|
r.logger.Warnf("DB updated key ID %d, but cache update failed: %v", key.ID, err)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (r *gormKeyRepository) HardDeleteByID(id uint) error {
|
|
key, err := r.GetKeyByID(id) // This now returns a decrypted key
|
|
if err != nil {
|
|
return err
|
|
}
|
|
err = r.executeTransactionWithRetry(func(tx *gorm.DB) error {
|
|
return tx.Unscoped().Delete(&models.APIKey{}, id).Error
|
|
})
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if err := r.removeStoreCacheForKey(key); err != nil {
|
|
r.logger.Warnf("DB deleted key ID %d, but cache removal failed: %v", id, err)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (r *gormKeyRepository) HardDeleteByValues(keyValues []string) (int64, error) {
|
|
if len(keyValues) == 0 {
|
|
return 0, nil
|
|
}
|
|
hashes := make([]string, len(keyValues))
|
|
for i, v := range keyValues {
|
|
hash := sha256.Sum256([]byte(v))
|
|
hashes[i] = hex.EncodeToString(hash[:])
|
|
}
|
|
// Find the full key objects first to update the cache later.
|
|
var keysToDelete []models.APIKey
|
|
// [MODIFIED] Find by hash.
|
|
if err := r.db.Where("api_key_hash IN ?", hashes).Find(&keysToDelete).Error; err != nil {
|
|
return 0, err
|
|
}
|
|
if len(keysToDelete) == 0 {
|
|
return 0, nil
|
|
}
|
|
// Decrypt them to ensure cache has plaintext if needed.
|
|
if err := r.decryptKeys(keysToDelete); err != nil {
|
|
r.logger.Warnf("Decryption failed for keys before hard delete, cache removal may be impacted: %v", err)
|
|
}
|
|
var deletedCount int64
|
|
err := r.executeTransactionWithRetry(func(tx *gorm.DB) error {
|
|
ids := pluckIDs(keysToDelete)
|
|
result := tx.Unscoped().Where("id IN ?", ids).Delete(&models.APIKey{})
|
|
if result.Error != nil {
|
|
return result.Error
|
|
}
|
|
deletedCount = result.RowsAffected
|
|
return nil
|
|
})
|
|
if err != nil {
|
|
return 0, err
|
|
}
|
|
for i := range keysToDelete {
|
|
if err := r.removeStoreCacheForKey(&keysToDelete[i]); err != nil {
|
|
r.logger.Warnf("DB deleted key ID %d, but cache removal failed: %v", keysToDelete[i].ID, err)
|
|
}
|
|
}
|
|
return deletedCount, nil
|
|
}
|
|
|
|
func (r *gormKeyRepository) GetKeyByID(id uint) (*models.APIKey, error) {
|
|
var key models.APIKey
|
|
if err := r.db.First(&key, id).Error; err != nil {
|
|
return nil, err
|
|
}
|
|
if err := r.decryptKey(&key); err != nil {
|
|
return nil, err
|
|
}
|
|
return &key, nil
|
|
}
|
|
|
|
func (r *gormKeyRepository) GetKeysByIDs(ids []uint) ([]models.APIKey, error) {
|
|
if len(ids) == 0 {
|
|
return []models.APIKey{}, nil
|
|
}
|
|
var keys []models.APIKey
|
|
err := r.db.Where("id IN ?", ids).Find(&keys).Error
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
// [CRITICAL] Decrypt before returning.
|
|
return keys, r.decryptKeys(keys)
|
|
}
|
|
|
|
func (r *gormKeyRepository) GetKeyByValue(keyValue string) (*models.APIKey, error) {
|
|
hash := sha256.Sum256([]byte(keyValue))
|
|
hashStr := hex.EncodeToString(hash[:])
|
|
var key models.APIKey
|
|
if err := r.db.Where("api_key_hash = ?", hashStr).First(&key).Error; err != nil {
|
|
return nil, err
|
|
}
|
|
key.APIKey = keyValue
|
|
return &key, nil
|
|
}
|
|
|
|
func (r *gormKeyRepository) GetKeysByValues(keyValues []string) ([]models.APIKey, error) {
|
|
if len(keyValues) == 0 {
|
|
return []models.APIKey{}, nil
|
|
}
|
|
hashes := make([]string, len(keyValues))
|
|
for i, v := range keyValues {
|
|
hash := sha256.Sum256([]byte(v))
|
|
hashes[i] = hex.EncodeToString(hash[:])
|
|
}
|
|
var keys []models.APIKey
|
|
err := r.db.Where("api_key_hash IN ?", hashes).Find(&keys).Error
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return keys, r.decryptKeys(keys)
|
|
}
|
|
|
|
func (r *gormKeyRepository) GetKeysByGroup(groupID uint) ([]models.APIKey, error) {
|
|
var keys []models.APIKey
|
|
err := r.db.Joins("JOIN group_api_key_mappings on group_api_key_mappings.api_key_id = api_keys.id").
|
|
Where("group_api_key_mappings.key_group_id = ?", groupID).
|
|
Find(&keys).Error
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return keys, r.decryptKeys(keys)
|
|
}
|
|
|
|
func (r *gormKeyRepository) CountByGroup(groupID uint) (int64, error) {
|
|
var count int64
|
|
err := r.db.Model(&models.APIKey{}).
|
|
Joins("JOIN group_api_key_mappings on group_api_key_mappings.api_key_id = api_keys.id").
|
|
Where("group_api_key_mappings.key_group_id = ?", groupID).
|
|
Count(&count).Error
|
|
return count, err
|
|
}
|
|
|
|
// --- Helpers ---
|
|
|
|
func (r *gormKeyRepository) executeTransactionWithRetry(operation func(tx *gorm.DB) error) error {
|
|
const maxRetries = 3
|
|
const baseDelay = 50 * time.Millisecond
|
|
const maxJitter = 150 * time.Millisecond
|
|
var err error
|
|
|
|
for i := 0; i < maxRetries; i++ {
|
|
err = r.db.Transaction(operation)
|
|
if err == nil {
|
|
return nil
|
|
}
|
|
if strings.Contains(err.Error(), "database is locked") {
|
|
jitter := time.Duration(rand.Intn(int(maxJitter)))
|
|
totalDelay := baseDelay + jitter
|
|
r.logger.Debugf("Database is locked, retrying in %v... (attempt %d/%d)", totalDelay, i+1, maxRetries)
|
|
time.Sleep(totalDelay)
|
|
continue
|
|
}
|
|
break
|
|
}
|
|
return err
|
|
}
|
|
|
|
func pluckIDs(keys []models.APIKey) []uint {
|
|
ids := make([]uint, 0, len(keys))
|
|
for _, key := range keys {
|
|
ids = append(ids, key.ID)
|
|
}
|
|
return ids
|
|
}
|