// Filename: internal/repository/key_mapping.go
|
|
package repository
|
|
|
|
import (
|
|
"context"
|
|
"crypto/sha256"
|
|
"encoding/hex"
|
|
"errors"
|
|
"fmt"
|
|
"gemini-balancer/internal/models"
|
|
"strconv"
|
|
"strings"
|
|
|
|
"gorm.io/gorm"
|
|
"gorm.io/gorm/clause"
|
|
)
|
|
|
|
func (r *gormKeyRepository) LinkKeysToGroup(ctx context.Context, groupID uint, keyIDs []uint) error {
|
|
if len(keyIDs) == 0 {
|
|
return nil
|
|
}
|
|
var mappings []models.GroupAPIKeyMapping
|
|
for _, keyID := range keyIDs {
|
|
mappings = append(mappings, models.GroupAPIKeyMapping{
|
|
KeyGroupID: groupID,
|
|
APIKeyID: keyID,
|
|
})
|
|
}
|
|
|
|
err := r.executeTransactionWithRetry(func(tx *gorm.DB) error {
|
|
return tx.Clauses(clause.OnConflict{DoNothing: true}).Create(&mappings).Error
|
|
})
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
for _, keyID := range keyIDs {
|
|
r.store.SAdd(context.Background(), fmt.Sprintf("key:%d:groups", keyID), groupID)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (r *gormKeyRepository) UnlinkKeysFromGroup(ctx context.Context, groupID uint, keyIDs []uint) (int64, error) {
|
|
if len(keyIDs) == 0 {
|
|
return 0, nil
|
|
}
|
|
|
|
var unlinkedCount int64
|
|
err := r.executeTransactionWithRetry(func(tx *gorm.DB) error {
|
|
result := tx.Table("group_api_key_mappings").
|
|
Where("key_group_id = ? AND api_key_id IN ?", groupID, keyIDs).
|
|
Delete(nil)
|
|
|
|
if result.Error != nil {
|
|
return result.Error
|
|
}
|
|
unlinkedCount = result.RowsAffected
|
|
return nil
|
|
})
|
|
|
|
if err != nil {
|
|
return 0, err
|
|
}
|
|
|
|
activeKeyListKey := fmt.Sprintf("group:%d:keys:active", groupID)
|
|
for _, keyID := range keyIDs {
|
|
r.store.SRem(context.Background(), fmt.Sprintf("key:%d:groups", keyID), groupID)
|
|
r.store.LRem(context.Background(), activeKeyListKey, 0, strconv.Itoa(int(keyID)))
|
|
}
|
|
|
|
return unlinkedCount, nil
|
|
}
|
|
|
|
func (r *gormKeyRepository) GetGroupsForKey(ctx context.Context, keyID uint) ([]uint, error) {
|
|
cacheKey := fmt.Sprintf("key:%d:groups", keyID)
|
|
strGroupIDs, err := r.store.SMembers(context.Background(), cacheKey)
|
|
if err != nil || len(strGroupIDs) == 0 {
|
|
var groupIDs []uint
|
|
dbErr := r.db.Table("group_api_key_mappings").Where("api_key_id = ?", keyID).Pluck("key_group_id", &groupIDs).Error
|
|
if dbErr != nil {
|
|
return nil, dbErr
|
|
}
|
|
if len(groupIDs) > 0 {
|
|
var interfaceSlice []interface{}
|
|
for _, id := range groupIDs {
|
|
interfaceSlice = append(interfaceSlice, id)
|
|
}
|
|
r.store.SAdd(context.Background(), cacheKey, interfaceSlice...)
|
|
}
|
|
return groupIDs, nil
|
|
}
|
|
|
|
var groupIDs []uint
|
|
for _, strID := range strGroupIDs {
|
|
id, _ := strconv.Atoi(strID)
|
|
groupIDs = append(groupIDs, uint(id))
|
|
}
|
|
return groupIDs, nil
|
|
}
|
|
|
|
func (r *gormKeyRepository) GetMapping(groupID, keyID uint) (*models.GroupAPIKeyMapping, error) {
|
|
var mapping models.GroupAPIKeyMapping
|
|
err := r.db.Where("key_group_id = ? AND api_key_id = ?", groupID, keyID).First(&mapping).Error
|
|
return &mapping, err
|
|
}
|
|
|
|
func (r *gormKeyRepository) UpdateMapping(ctx context.Context, mapping *models.GroupAPIKeyMapping) error {
|
|
err := r.executeTransactionWithRetry(func(tx *gorm.DB) error {
|
|
return tx.Save(mapping).Error
|
|
})
|
|
if err != nil {
|
|
return err
|
|
}
|
|
return r.updateStoreCacheForMapping(mapping)
|
|
}
|
|
|
|
// [MODIFIED & FINAL] This is the final version for the core refactoring.
|
|
func (r *gormKeyRepository) GetPaginatedKeysAndMappingsByGroup(params *models.APIKeyQueryParams) ([]*models.APIKeyDetails, int64, error) {
|
|
items := make([]*models.APIKeyDetails, 0)
|
|
var total int64
|
|
|
|
query := r.db.Table("api_keys").
|
|
Select(`
|
|
api_keys.id, api_keys.created_at, api_keys.updated_at,
|
|
api_keys.encrypted_key, -- Select encrypted key to be scanned into APIKeyDetails.EncryptedKey
|
|
api_keys.master_status,
|
|
m.status, m.last_error, m.consecutive_error_count, m.last_used_at, m.cooldown_until
|
|
`).
|
|
Joins("JOIN group_api_key_mappings as m ON m.api_key_id = api_keys.id")
|
|
|
|
if params.KeyGroupID <= 0 {
|
|
return nil, 0, errors.New("KeyGroupID is required for this query")
|
|
}
|
|
query = query.Where("m.key_group_id = ?", params.KeyGroupID)
|
|
|
|
if params.Status != "" {
|
|
query = query.Where("LOWER(m.status) = LOWER(?)", params.Status)
|
|
}
|
|
|
|
// Keyword search is now handled by the service layer.
|
|
if params.Keyword != "" {
|
|
r.logger.Warn("DB query is ignoring keyword; service layer will perform in-memory filtering.")
|
|
}
|
|
|
|
countQuery := query.Model(&models.APIKey{}) // Use model for count to avoid GORM issues
|
|
err := countQuery.Count(&total).Error
|
|
if err != nil {
|
|
return nil, 0, err
|
|
}
|
|
if total == 0 {
|
|
return items, 0, nil
|
|
}
|
|
|
|
offset := (params.Page - 1) * params.PageSize
|
|
err = query.Order("api_keys.id DESC").Limit(params.PageSize).Offset(offset).Scan(&items).Error
|
|
if err != nil {
|
|
return nil, 0, err
|
|
}
|
|
|
|
// Decrypt all results before returning. This loop is now valid.
|
|
for i := range items {
|
|
if items[i].EncryptedKey != "" {
|
|
plaintext, err := r.crypto.Decrypt(items[i].EncryptedKey)
|
|
if err == nil {
|
|
items[i].APIKey = plaintext
|
|
} else {
|
|
items[i].APIKey = "[DECRYPTION FAILED]"
|
|
r.logger.Errorf("Failed to decrypt key ID %d for pagination: %v", items[i].ID, err)
|
|
}
|
|
}
|
|
}
|
|
return items, total, nil
|
|
}
|
|
|
|
// [MODIFIED & FINAL] Uses hashes for lookup.
|
|
func (r *gormKeyRepository) GetKeysByValuesAndGroupID(values []string, groupID uint) ([]models.APIKey, error) {
|
|
if len(values) == 0 {
|
|
return []models.APIKey{}, nil
|
|
}
|
|
hashes := make([]string, len(values))
|
|
for i, v := range values {
|
|
hash := sha256.Sum256([]byte(v))
|
|
hashes[i] = hex.EncodeToString(hash[:])
|
|
}
|
|
|
|
var keys []models.APIKey
|
|
err := r.db.Joins("JOIN group_api_key_mappings ON group_api_key_mappings.api_key_id = api_keys.id").
|
|
Where("group_api_key_mappings.key_group_id = ?", groupID).
|
|
Where("api_keys.api_key_hash IN ?", hashes).
|
|
Find(&keys).Error
|
|
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return keys, r.decryptKeys(keys)
|
|
}
|
|
|
|
// [MODIFIED & FINAL] Fetches full objects, decrypts, then extracts strings.
|
|
func (r *gormKeyRepository) FindKeyValuesByStatus(groupID uint, statuses []string) ([]string, error) {
|
|
var keys []models.APIKey
|
|
query := r.db.Table("api_keys").
|
|
Select("api_keys.*").
|
|
Joins("JOIN group_api_key_mappings as m ON m.api_key_id = api_keys.id").
|
|
Where("m.key_group_id = ?", groupID)
|
|
|
|
if len(statuses) > 0 && !(len(statuses) == 1 && statuses[0] == "all") {
|
|
lowerStatuses := make([]string, len(statuses))
|
|
for i, s := range statuses {
|
|
lowerStatuses[i] = strings.ToLower(s)
|
|
}
|
|
query = query.Where("LOWER(m.status) IN (?)", lowerStatuses)
|
|
}
|
|
|
|
if err := query.Find(&keys).Error; err != nil {
|
|
return nil, err
|
|
}
|
|
if err := r.decryptKeys(keys); err != nil {
|
|
return nil, fmt.Errorf("decryption failed during FindKeyValuesByStatus: %w", err)
|
|
}
|
|
|
|
keyValues := make([]string, len(keys))
|
|
for i, key := range keys {
|
|
keyValues[i] = key.APIKey
|
|
}
|
|
return keyValues, nil
|
|
}
|
|
|
|
// [MODIFIED & FINAL] Consistent with the new pattern.
|
|
func (r *gormKeyRepository) GetKeyStringsByGroupAndStatus(groupID uint, statuses []string) ([]string, error) {
|
|
var keys []models.APIKey
|
|
query := r.db.Table("api_keys").
|
|
Select("api_keys.*").
|
|
Joins("JOIN group_api_key_mappings ON group_api_key_mappings.api_key_id = api_keys.id").
|
|
Where("group_api_key_mappings.key_group_id = ?", groupID)
|
|
|
|
if len(statuses) > 0 {
|
|
isAll := false
|
|
for _, s := range statuses {
|
|
if strings.ToLower(s) == "all" {
|
|
isAll = true
|
|
break
|
|
}
|
|
}
|
|
if !isAll {
|
|
lowerStatuses := make([]string, len(statuses))
|
|
for i, s := range statuses {
|
|
lowerStatuses[i] = strings.ToLower(s)
|
|
}
|
|
query = query.Where("LOWER(group_api_key_mappings.status) IN ?", lowerStatuses)
|
|
}
|
|
}
|
|
|
|
if err := query.Find(&keys).Error; err != nil {
|
|
return nil, err
|
|
}
|
|
if err := r.decryptKeys(keys); err != nil {
|
|
return nil, fmt.Errorf("decryption failed during GetKeyStringsByGroupAndStatus: %w", err)
|
|
}
|
|
|
|
keyStrings := make([]string, len(keys))
|
|
for i, key := range keys {
|
|
keyStrings[i] = key.APIKey
|
|
}
|
|
return keyStrings, nil
|
|
}
|
|
|
|
// FindKeyIDsByStatus remains unchanged as it does not deal with key values.
|
|
func (r *gormKeyRepository) FindKeyIDsByStatus(groupID uint, statuses []string) ([]uint, error) {
|
|
var keyIDs []uint
|
|
query := r.db.Table("group_api_key_mappings").
|
|
Select("api_key_id").
|
|
Where("key_group_id = ?", groupID)
|
|
if len(statuses) > 0 && !(len(statuses) == 1 && statuses[0] == "all") {
|
|
lowerStatuses := make([]string, len(statuses))
|
|
for i, s := range statuses {
|
|
lowerStatuses[i] = strings.ToLower(s)
|
|
}
|
|
query = query.Where("LOWER(status) IN (?)", lowerStatuses)
|
|
}
|
|
if err := query.Pluck("api_key_id", &keyIDs).Error; err != nil {
|
|
return nil, err
|
|
}
|
|
return keyIDs, nil
|
|
}
|
|
|
|
func (r *gormKeyRepository) UpdateMappingWithoutCache(mapping *models.GroupAPIKeyMapping) error {
|
|
return r.executeTransactionWithRetry(func(tx *gorm.DB) error {
|
|
return tx.Save(mapping).Error
|
|
})
|
|
}
|