internal/repository/key_mapping.go (new file, 289 lines)
@@ -0,0 +1,289 @@
// Filename: internal/repository/key_mapping.go

package repository

import (
    "crypto/sha256"
    "encoding/hex"
    "errors"
    "fmt"
    "gemini-balancer/internal/models"
    "strconv"
    "strings"

    "gorm.io/gorm"
    "gorm.io/gorm/clause"
)
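
// LinkKeysToGroup creates group/key mapping rows for the given key IDs and
// records the association in the cache store. Existing mappings are left
// untouched (ON CONFLICT DO NOTHING).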
func (r *gormKeyRepository) LinkKeysToGroup(groupID uint, keyIDs []uint) error {
    if len(keyIDs) == 0 {
        return nil
    }
    var mappings []models.GroupAPIKeyMapping
    for _, keyID := range keyIDs {
        mappings = append(mappings, models.GroupAPIKeyMapping{
            KeyGroupID: groupID,
            APIKeyID:   keyID,
        })
    }

    err := r.executeTransactionWithRetry(func(tx *gorm.DB) error {
        return tx.Clauses(clause.OnConflict{DoNothing: true}).Create(&mappings).Error
    })
    if err != nil {
        return err
    }

    for _, keyID := range keyIDs {
        r.store.SAdd(fmt.Sprintf("key:%d:groups", keyID), groupID)
    }
    return nil
}
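
// UnlinkKeysFromGroup deletes the mapping rows between the group and the given
// key IDs, removes the association from the cache store and the group's
// active-key list, and returns how many mappings were removed.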
func (r *gormKeyRepository) UnlinkKeysFromGroup(groupID uint, keyIDs []uint) (int64, error) {
    if len(keyIDs) == 0 {
        return 0, nil
    }

    var unlinkedCount int64
    err := r.executeTransactionWithRetry(func(tx *gorm.DB) error {
        result := tx.Table("group_api_key_mappings").
            Where("key_group_id = ? AND api_key_id IN ?", groupID, keyIDs).
            Delete(nil)

        if result.Error != nil {
            return result.Error
        }
        unlinkedCount = result.RowsAffected
        return nil
    })

    if err != nil {
        return 0, err
    }

    activeKeyListKey := fmt.Sprintf("group:%d:keys:active", groupID)
    for _, keyID := range keyIDs {
        r.store.SRem(fmt.Sprintf("key:%d:groups", keyID), groupID)
        r.store.LRem(activeKeyListKey, 0, strconv.Itoa(int(keyID)))
    }

    return unlinkedCount, nil
}
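
// GetGroupsForKey returns the IDs of all groups a key belongs to, preferring
// the cache store and falling back to the database on a miss.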
func (r *gormKeyRepository) GetGroupsForKey(keyID uint) ([]uint, error) {
    cacheKey := fmt.Sprintf("key:%d:groups", keyID)
    strGroupIDs, err := r.store.SMembers(cacheKey)
    if err != nil || len(strGroupIDs) == 0 {
        // Cache miss (or store error): fall back to the database and repopulate the cache.
        var groupIDs []uint
        dbErr := r.db.Table("group_api_key_mappings").Where("api_key_id = ?", keyID).Pluck("key_group_id", &groupIDs).Error
        if dbErr != nil {
            return nil, dbErr
        }
        if len(groupIDs) > 0 {
            var interfaceSlice []interface{}
            for _, id := range groupIDs {
                interfaceSlice = append(interfaceSlice, id)
            }
            r.store.SAdd(cacheKey, interfaceSlice...)
        }
        return groupIDs, nil
    }

    var groupIDs []uint
    for _, strID := range strGroupIDs {
        id, convErr := strconv.Atoi(strID)
        if convErr != nil {
            // Skip malformed cache entries instead of silently treating them as group 0.
            continue
        }
        groupIDs = append(groupIDs, uint(id))
    }
    return groupIDs, nil
}
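
// GetMapping loads the mapping row for a single group/key pair.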
func (r *gormKeyRepository) GetMapping(groupID, keyID uint) (*models.GroupAPIKeyMapping, error) {
    var mapping models.GroupAPIKeyMapping
    err := r.db.Where("key_group_id = ? AND api_key_id = ?", groupID, keyID).First(&mapping).Error
    return &mapping, err
}
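
// UpdateMapping persists a mapping row and then refreshes the related cache entry.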
func (r *gormKeyRepository) UpdateMapping(mapping *models.GroupAPIKeyMapping) error {
    err := r.executeTransactionWithRetry(func(tx *gorm.DB) error {
        return tx.Save(mapping).Error
    })
    if err != nil {
        return err
    }
    return r.updateStoreCacheForMapping(mapping)
}
// [MODIFIED & FINAL] This is the final version for the core refactoring.
|
||||
func (r *gormKeyRepository) GetPaginatedKeysAndMappingsByGroup(params *models.APIKeyQueryParams) ([]*models.APIKeyDetails, int64, error) {
|
||||
items := make([]*models.APIKeyDetails, 0)
|
||||
var total int64
|
||||
|
||||
query := r.db.Table("api_keys").
|
||||
Select(`
|
||||
api_keys.id, api_keys.created_at, api_keys.updated_at,
|
||||
api_keys.encrypted_key, -- Select encrypted key to be scanned into APIKeyDetails.EncryptedKey
|
||||
api_keys.master_status,
|
||||
m.status, m.last_error, m.consecutive_error_count, m.last_used_at, m.cooldown_until
|
||||
`).
|
||||
Joins("JOIN group_api_key_mappings as m ON m.api_key_id = api_keys.id")
|
||||
|
||||
if params.KeyGroupID <= 0 {
|
||||
return nil, 0, errors.New("KeyGroupID is required for this query")
|
||||
}
|
||||
query = query.Where("m.key_group_id = ?", params.KeyGroupID)
|
||||
|
||||
if params.Status != "" {
|
||||
query = query.Where("LOWER(m.status) = LOWER(?)", params.Status)
|
||||
}
|
||||
|
||||
// Keyword search is now handled by the service layer.
|
||||
if params.Keyword != "" {
|
||||
r.logger.Warn("DB query is ignoring keyword; service layer will perform in-memory filtering.")
|
||||
}
|
||||
|
||||
countQuery := query.Model(&models.APIKey{}) // Use model for count to avoid GORM issues
|
||||
err := countQuery.Count(&total).Error
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
if total == 0 {
|
||||
return items, 0, nil
|
||||
}
|
||||
|
||||
offset := (params.Page - 1) * params.PageSize
|
||||
err = query.Order("api_keys.id DESC").Limit(params.PageSize).Offset(offset).Scan(&items).Error
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
// Decrypt all results before returning. This loop is now valid.
|
||||
for i := range items {
|
||||
if items[i].EncryptedKey != "" {
|
||||
plaintext, err := r.crypto.Decrypt(items[i].EncryptedKey)
|
||||
if err == nil {
|
||||
items[i].APIKey = plaintext
|
||||
} else {
|
||||
items[i].APIKey = "[DECRYPTION FAILED]"
|
||||
r.logger.Errorf("Failed to decrypt key ID %d for pagination: %v", items[i].ID, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
return items, total, nil
|
||||
}
|
||||
|
||||
// [MODIFIED & FINAL] Uses hashes for lookup.
|
||||
func (r *gormKeyRepository) GetKeysByValuesAndGroupID(values []string, groupID uint) ([]models.APIKey, error) {
|
||||
if len(values) == 0 {
|
||||
return []models.APIKey{}, nil
|
||||
}
|
||||
hashes := make([]string, len(values))
|
||||
for i, v := range values {
|
||||
hash := sha256.Sum256([]byte(v))
|
||||
hashes[i] = hex.EncodeToString(hash[:])
|
||||
}
|
||||
|
||||
var keys []models.APIKey
|
||||
err := r.db.Joins("JOIN group_api_key_mappings ON group_api_key_mappings.api_key_id = api_keys.id").
|
||||
Where("group_api_key_mappings.key_group_id = ?", groupID).
|
||||
Where("api_keys.api_key_hash IN ?", hashes).
|
||||
Find(&keys).Error
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return keys, r.decryptKeys(keys)
|
||||
}
|
||||
|
||||
// [MODIFIED & FINAL] Fetches full objects, decrypts, then extracts strings.
|
||||
func (r *gormKeyRepository) FindKeyValuesByStatus(groupID uint, statuses []string) ([]string, error) {
|
||||
var keys []models.APIKey
|
||||
query := r.db.Table("api_keys").
|
||||
Select("api_keys.*").
|
||||
Joins("JOIN group_api_key_mappings as m ON m.api_key_id = api_keys.id").
|
||||
Where("m.key_group_id = ?", groupID)
|
||||
|
||||
if len(statuses) > 0 && !(len(statuses) == 1 && statuses[0] == "all") {
|
||||
lowerStatuses := make([]string, len(statuses))
|
||||
for i, s := range statuses {
|
||||
lowerStatuses[i] = strings.ToLower(s)
|
||||
}
|
||||
query = query.Where("LOWER(m.status) IN (?)", lowerStatuses)
|
||||
}
|
||||
|
||||
if err := query.Find(&keys).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := r.decryptKeys(keys); err != nil {
|
||||
return nil, fmt.Errorf("decryption failed during FindKeyValuesByStatus: %w", err)
|
||||
}
|
||||
|
||||
keyValues := make([]string, len(keys))
|
||||
for i, key := range keys {
|
||||
keyValues[i] = key.APIKey
|
||||
}
|
||||
return keyValues, nil
|
||||
}
|
||||
|
||||
// [MODIFIED & FINAL] Consistent with the new pattern.
|
||||
func (r *gormKeyRepository) GetKeyStringsByGroupAndStatus(groupID uint, statuses []string) ([]string, error) {
|
||||
var keys []models.APIKey
|
||||
query := r.db.Table("api_keys").
|
||||
Select("api_keys.*").
|
||||
Joins("JOIN group_api_key_mappings ON group_api_key_mappings.api_key_id = api_keys.id").
|
||||
Where("group_api_key_mappings.key_group_id = ?", groupID)
|
||||
|
||||
if len(statuses) > 0 {
|
||||
isAll := false
|
||||
for _, s := range statuses {
|
||||
if strings.ToLower(s) == "all" {
|
||||
isAll = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !isAll {
|
||||
lowerStatuses := make([]string, len(statuses))
|
||||
for i, s := range statuses {
|
||||
lowerStatuses[i] = strings.ToLower(s)
|
||||
}
|
||||
query = query.Where("LOWER(group_api_key_mappings.status) IN ?", lowerStatuses)
|
||||
}
|
||||
}
|
||||
|
||||
if err := query.Find(&keys).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := r.decryptKeys(keys); err != nil {
|
||||
return nil, fmt.Errorf("decryption failed during GetKeyStringsByGroupAndStatus: %w", err)
|
||||
}
|
||||
|
||||
keyStrings := make([]string, len(keys))
|
||||
for i, key := range keys {
|
||||
keyStrings[i] = key.APIKey
|
||||
}
|
||||
return keyStrings, nil
|
||||
}
|
||||
|
||||
// FindKeyIDsByStatus remains unchanged as it does not deal with key values.
func (r *gormKeyRepository) FindKeyIDsByStatus(groupID uint, statuses []string) ([]uint, error) {
    var keyIDs []uint
    query := r.db.Table("group_api_key_mappings").
        Select("api_key_id").
        Where("key_group_id = ?", groupID)
    if len(statuses) > 0 && !(len(statuses) == 1 && statuses[0] == "all") {
        lowerStatuses := make([]string, len(statuses))
        for i, s := range statuses {
            lowerStatuses[i] = strings.ToLower(s)
        }
        query = query.Where("LOWER(status) IN (?)", lowerStatuses)
    }
    if err := query.Pluck("api_key_id", &keyIDs).Error; err != nil {
        return nil, err
    }
    return keyIDs, nil
}
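
// UpdateMappingWithoutCache persists a mapping row without touching the cache.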
func (r *gormKeyRepository) UpdateMappingWithoutCache(mapping *models.GroupAPIKeyMapping) error {
    return r.executeTransactionWithRetry(func(tx *gorm.DB) error {
        return tx.Save(mapping).Error
    })
}