// Filename: internal/repository/key_crud.go package repository import ( "crypto/sha256" "encoding/hex" "fmt" "gemini-balancer/internal/models" "context" "math/rand" "strings" "time" "gorm.io/gorm" "gorm.io/gorm/clause" ) func (r *gormKeyRepository) AddKeys(keys []models.APIKey) ([]models.APIKey, error) { if len(keys) == 0 { return []models.APIKey{}, nil } keyHashes := make([]string, len(keys)) keyValueToHashMap := make(map[string]string) for i, k := range keys { if k.APIKey == "" { return nil, fmt.Errorf("cannot add key at index %d: plaintext APIKey is empty", i) } hash := sha256.Sum256([]byte(k.APIKey)) hashStr := hex.EncodeToString(hash[:]) keyHashes[i] = hashStr keyValueToHashMap[k.APIKey] = hashStr } var finalKeys []models.APIKey err := r.db.Transaction(func(tx *gorm.DB) error { var existingKeys []models.APIKey if err := tx.Unscoped().Where("api_key_hash IN ?", keyHashes).Find(&existingKeys).Error; err != nil { return err } existingKeyHashMap := make(map[string]models.APIKey) for _, k := range existingKeys { existingKeyHashMap[k.APIKeyHash] = k } var keysToCreate []models.APIKey var keysToRestore []uint for _, keyObj := range keys { keyVal := keyObj.APIKey hash := keyValueToHashMap[keyVal] if ek, found := existingKeyHashMap[hash]; found { if ek.DeletedAt.Valid { keysToRestore = append(keysToRestore, ek.ID) } } else { encryptedKey, err := r.crypto.Encrypt(keyVal) if err != nil { return fmt.Errorf("failed to encrypt key '%s...': %w", keyVal[:min(4, len(keyVal))], err) } keysToCreate = append(keysToCreate, models.APIKey{ EncryptedKey: encryptedKey, APIKeyHash: hash, }) } } if len(keysToRestore) > 0 { if err := tx.Model(&models.APIKey{}).Unscoped().Where("id IN ?", keysToRestore).Update("deleted_at", nil).Error; err != nil { return err } } if len(keysToCreate) > 0 { if err := tx.Clauses(clause.OnConflict{DoNothing: true}, clause.Returning{}).Create(&keysToCreate).Error; err != nil { return err } } if err := tx.Where("api_key_hash IN ?", keyHashes).Find(&finalKeys).Error; err != nil { return err } return r.decryptKeys(finalKeys) }) return finalKeys, err } func (r *gormKeyRepository) Update(key *models.APIKey) error { if key.APIKey != "" { encryptedKey, err := r.crypto.Encrypt(key.APIKey) if err != nil { return fmt.Errorf("failed to re-encrypt key on update for ID %d: %w", key.ID, err) } key.EncryptedKey = encryptedKey // Recalculate hash as a defensive measure. hash := sha256.Sum256([]byte(key.APIKey)) key.APIKeyHash = hex.EncodeToString(hash[:]) } err := r.executeTransactionWithRetry(func(tx *gorm.DB) error { return tx.Save(key).Error }) if err != nil { return err } if err := r.decryptKey(key); err != nil { r.logger.Warnf("DB updated key ID %d, but decryption for cache failed: %v", key.ID, err) return nil } if err := r.updateStoreCacheForKey(key); err != nil { r.logger.Warnf("DB updated key ID %d, but cache update failed: %v", key.ID, err) } return nil } func (r *gormKeyRepository) HardDeleteByID(id uint) error { key, err := r.GetKeyByID(id) if err != nil { return err } err = r.executeTransactionWithRetry(func(tx *gorm.DB) error { return tx.Unscoped().Delete(&models.APIKey{}, id).Error }) if err != nil { return err } if err := r.removeStoreCacheForKey(context.Background(), key); err != nil { r.logger.Warnf("DB deleted key ID %d, but cache removal failed: %v", id, err) } return nil } func (r *gormKeyRepository) HardDeleteByValues(keyValues []string) (int64, error) { if len(keyValues) == 0 { return 0, nil } hashes := make([]string, len(keyValues)) for i, v := range keyValues { hash := sha256.Sum256([]byte(v)) hashes[i] = hex.EncodeToString(hash[:]) } var keysToDelete []models.APIKey if err := r.db.Where("api_key_hash IN ?", hashes).Find(&keysToDelete).Error; err != nil { return 0, err } if len(keysToDelete) == 0 { return 0, nil } if err := r.decryptKeys(keysToDelete); err != nil { r.logger.Warnf("Decryption failed for keys before hard delete, cache removal may be impacted: %v", err) } var deletedCount int64 err := r.executeTransactionWithRetry(func(tx *gorm.DB) error { ids := pluckIDs(keysToDelete) result := tx.Unscoped().Where("id IN ?", ids).Delete(&models.APIKey{}) if result.Error != nil { return result.Error } deletedCount = result.RowsAffected return nil }) if err != nil { return 0, err } for i := range keysToDelete { if err := r.removeStoreCacheForKey(context.Background(), &keysToDelete[i]); err != nil { r.logger.Warnf("DB deleted key ID %d, but cache removal failed: %v", keysToDelete[i].ID, err) } } return deletedCount, nil } func (r *gormKeyRepository) GetKeyByID(id uint) (*models.APIKey, error) { var key models.APIKey if err := r.db.First(&key, id).Error; err != nil { return nil, err } if err := r.decryptKey(&key); err != nil { return nil, err } return &key, nil } func (r *gormKeyRepository) GetKeysByIDs(ids []uint) ([]models.APIKey, error) { if len(ids) == 0 { return []models.APIKey{}, nil } var keys []models.APIKey err := r.db.Where("id IN ?", ids).Find(&keys).Error if err != nil { return nil, err } return keys, r.decryptKeys(keys) } func (r *gormKeyRepository) GetKeyByValue(keyValue string) (*models.APIKey, error) { hash := sha256.Sum256([]byte(keyValue)) hashStr := hex.EncodeToString(hash[:]) var key models.APIKey if err := r.db.Where("api_key_hash = ?", hashStr).First(&key).Error; err != nil { return nil, err } key.APIKey = keyValue return &key, nil } func (r *gormKeyRepository) GetKeysByValues(keyValues []string) ([]models.APIKey, error) { if len(keyValues) == 0 { return []models.APIKey{}, nil } hashes := make([]string, len(keyValues)) for i, v := range keyValues { hash := sha256.Sum256([]byte(v)) hashes[i] = hex.EncodeToString(hash[:]) } var keys []models.APIKey err := r.db.Where("api_key_hash IN ?", hashes).Find(&keys).Error if err != nil { return nil, err } return keys, r.decryptKeys(keys) } func (r *gormKeyRepository) GetKeysByGroup(groupID uint) ([]models.APIKey, error) { var keys []models.APIKey err := r.db.Joins("JOIN group_api_key_mappings on group_api_key_mappings.api_key_id = api_keys.id"). Where("group_api_key_mappings.key_group_id = ?", groupID). Find(&keys).Error if err != nil { return nil, err } return keys, r.decryptKeys(keys) } func (r *gormKeyRepository) CountByGroup(groupID uint) (int64, error) { var count int64 err := r.db.Model(&models.APIKey{}). Joins("JOIN group_api_key_mappings on group_api_key_mappings.api_key_id = api_keys.id"). Where("group_api_key_mappings.key_group_id = ?", groupID). Count(&count).Error return count, err } // --- Helpers --- func (r *gormKeyRepository) executeTransactionWithRetry(operation func(tx *gorm.DB) error) error { const maxRetries = 3 const baseDelay = 50 * time.Millisecond const maxJitter = 150 * time.Millisecond var err error for i := 0; i < maxRetries; i++ { err = r.db.Transaction(operation) if err == nil { return nil } if strings.Contains(err.Error(), "database is locked") { jitter := time.Duration(rand.Intn(int(maxJitter))) totalDelay := baseDelay + jitter r.logger.Debugf("Database is locked, retrying in %v... (attempt %d/%d)", totalDelay, i+1, maxRetries) time.Sleep(totalDelay) continue } break } return err } func pluckIDs(keys []models.APIKey) []uint { ids := make([]uint, 0, len(keys)) for _, key := range keys { ids = append(ids, key.ID) } return ids }