Files
gemini-banlancer/internal/repository/key_mapping.go
2025-11-22 14:20:05 +08:00

291 lines
8.5 KiB
Go

// Filename: internal/repository/key_mapping.go
package repository
import (
"context"
"crypto/sha256"
"encoding/hex"
"errors"
"fmt"
"gemini-balancer/internal/models"
"strconv"
"strings"
"gorm.io/gorm"
"gorm.io/gorm/clause"
)
// LinkKeysToGroup creates a group/key mapping row for every ID in keyIDs. Pairs that
// already exist are skipped (INSERT ... ON CONFLICT DO NOTHING). It then updates the
// per-key "key:<id>:groups" cache sets.
//
// The cache set is only extended for keys whose set is already populated. GetGroupsForKey
// treats an empty set as a cache miss and reloads the full group list from the database.
// Creating a fresh set that holds only groupID would therefore hide the key's other
// groups from every later cached lookup. Uncached keys are left alone, so the next
// GetGroupsForKey call loads the complete, just-committed list. Cache writes are
// best-effort; their results are not checked.
//
// NOTE(review): ctx is unused. Cache calls use context.Background(), so they do not
// follow the caller's cancellation. Confirm this is intended.
func (r *gormKeyRepository) LinkKeysToGroup(ctx context.Context, groupID uint, keyIDs []uint) error {
	if len(keyIDs) == 0 {
		return nil
	}
	mappings := make([]models.GroupAPIKeyMapping, 0, len(keyIDs))
	for _, keyID := range keyIDs {
		mappings = append(mappings, models.GroupAPIKeyMapping{
			KeyGroupID: groupID,
			APIKeyID:   keyID,
		})
	}
	err := r.executeTransactionWithRetry(func(tx *gorm.DB) error {
		return tx.Clauses(clause.OnConflict{DoNothing: true}).Create(&mappings).Error
	})
	if err != nil {
		return err
	}
	for _, keyID := range keyIDs {
		cacheKey := fmt.Sprintf("key:%d:groups", keyID)
		members, cacheErr := r.store.SMembers(context.Background(), cacheKey)
		if cacheErr != nil || len(members) == 0 {
			// Not cached (or unreadable): let GetGroupsForKey repopulate from the DB.
			continue
		}
		r.store.SAdd(context.Background(), cacheKey, groupID)
	}
	return nil
}
// UnlinkKeysFromGroup deletes the group_api_key_mappings rows that tie groupID to any of
// keyIDs, inside executeTransactionWithRetry. It returns how many rows were deleted.
//
// After a successful delete, each key is removed from its "key:<id>:groups" cache set and
// from the group's "group:<id>:keys:active" list. Those cache calls are best-effort:
// their results are ignored. An empty keyIDs is a no-op that returns (0, nil).
func (r *gormKeyRepository) UnlinkKeysFromGroup(ctx context.Context, groupID uint, keyIDs []uint) (int64, error) {
	if len(keyIDs) == 0 {
		return 0, nil
	}

	var removed int64
	deleteMappings := func(tx *gorm.DB) error {
		res := tx.Table("group_api_key_mappings").
			Where("key_group_id = ? AND api_key_id IN ?", groupID, keyIDs).
			Delete(nil)
		if res.Error == nil {
			removed = res.RowsAffected
		}
		return res.Error
	}
	if err := r.executeTransactionWithRetry(deleteMappings); err != nil {
		return 0, err
	}

	bg := context.Background()
	activeList := fmt.Sprintf("group:%d:keys:active", groupID)
	for _, id := range keyIDs {
		r.store.SRem(bg, fmt.Sprintf("key:%d:groups", id), groupID)
		r.store.LRem(bg, activeList, 0, strconv.Itoa(int(id)))
	}
	return removed, nil
}
// GetGroupsForKey returns the IDs of every group the key is mapped to.
//
// It reads the "key:<id>:groups" cache set first. An empty set or a read error is a miss:
// the IDs are then plucked from group_api_key_mappings and written back to the set. A
// cached member that does not parse as an unsigned integer is treated as corruption. That
// member is evicted and the list is reloaded from the database, instead of silently
// reporting group 0 (or a wrapped-around ID for negative values).
//
// NOTE(review): ctx is unused. Cache calls use context.Background(). Confirm this is
// intended.
func (r *gormKeyRepository) GetGroupsForKey(ctx context.Context, keyID uint) ([]uint, error) {
	cacheKey := fmt.Sprintf("key:%d:groups", keyID)

	// loadFromDB fetches the authoritative list and (re)seeds the cache set.
	loadFromDB := func() ([]uint, error) {
		var groupIDs []uint
		dbErr := r.db.Table("group_api_key_mappings").
			Where("api_key_id = ?", keyID).
			Pluck("key_group_id", &groupIDs).Error
		if dbErr != nil {
			return nil, dbErr
		}
		if len(groupIDs) > 0 {
			members := make([]interface{}, len(groupIDs))
			for i, id := range groupIDs {
				members[i] = id
			}
			r.store.SAdd(context.Background(), cacheKey, members...)
		}
		return groupIDs, nil
	}

	strGroupIDs, err := r.store.SMembers(context.Background(), cacheKey)
	if err != nil || len(strGroupIDs) == 0 {
		return loadFromDB()
	}

	groupIDs := make([]uint, 0, len(strGroupIDs))
	for _, strID := range strGroupIDs {
		id, parseErr := strconv.ParseUint(strID, 10, 0)
		if parseErr != nil {
			// Corrupt cache entry: drop it and fall back to the database.
			r.store.SRem(context.Background(), cacheKey, strID)
			return loadFromDB()
		}
		groupIDs = append(groupIDs, uint(id))
	}
	return groupIDs, nil
}
// GetMapping loads the single mapping row for (groupID, keyID) using GORM's First.
//
// The returned pointer is never nil. When err is non-nil (including GORM's
// record-not-found error), it points at a zero-valued or partially filled mapping.
// Callers must check err before trusting it.
func (r *gormKeyRepository) GetMapping(groupID, keyID uint) (*models.GroupAPIKeyMapping, error) {
	mapping := &models.GroupAPIKeyMapping{}
	result := r.db.
		Where("key_group_id = ? AND api_key_id = ?", groupID, keyID).
		First(mapping)
	return mapping, result.Error
}
// UpdateMapping persists mapping with tx.Save inside executeTransactionWithRetry. If that
// succeeds, it refreshes the store cache through updateStoreCacheForMapping and returns
// that call's error. A failed save is returned as-is and the cache is not touched.
//
// NOTE(review): ctx is accepted but not used here.
func (r *gormKeyRepository) UpdateMapping(ctx context.Context, mapping *models.GroupAPIKeyMapping) error {
	save := func(tx *gorm.DB) error {
		return tx.Save(mapping).Error
	}
	if err := r.executeTransactionWithRetry(save); err != nil {
		return err
	}
	return r.updateStoreCacheForMapping(mapping)
}
// GetPaginatedKeysAndMappingsByGroup returns one page of keys in params.KeyGroupID, joined
// with their group-specific mapping fields (status, last error, error count, last use,
// cooldown). It also returns the total number of matching rows.
//
// Filtering: params.Status is matched case-insensitively against m.status.
// params.Keyword is NOT applied here; a warning is logged and the service layer is
// expected to filter in memory.
//
// Paging: rows are ordered by api_keys.id descending, with
// LIMIT params.PageSize OFFSET (params.Page-1)*params.PageSize.
//
// Each row's encrypted_key is decrypted into APIKey. A row that fails to decrypt gets
// APIKey "[DECRYPTION FAILED]" and an error log, rather than failing the whole page.
// When total is 0, an empty (non-nil) slice is returned without running the page query.
func (r *gormKeyRepository) GetPaginatedKeysAndMappingsByGroup(params *models.APIKeyQueryParams) ([]*models.APIKeyDetails, int64, error) {
	if params.KeyGroupID <= 0 {
		return nil, 0, errors.New("KeyGroupID is required for this query")
	}

	query := r.db.Table("api_keys").
		Select(`
api_keys.id, api_keys.created_at, api_keys.updated_at,
api_keys.encrypted_key, -- Select encrypted key to be scanned into APIKeyDetails.EncryptedKey
api_keys.master_status,
m.status, m.last_error, m.consecutive_error_count, m.last_used_at, m.cooldown_until
`).
		Joins("JOIN group_api_key_mappings as m ON m.api_key_id = api_keys.id").
		Where("m.key_group_id = ?", params.KeyGroupID)
	if params.Status != "" {
		query = query.Where("LOWER(m.status) = LOWER(?)", params.Status)
	}
	if params.Keyword != "" {
		r.logger.Warn("DB query is ignoring keyword; service layer will perform in-memory filtering.")
	}

	// Count through the APIKey model; the same statement is then reused for the page query.
	var total int64
	if err := query.Model(&models.APIKey{}).Count(&total).Error; err != nil {
		return nil, 0, err
	}
	details := make([]*models.APIKeyDetails, 0)
	if total == 0 {
		return details, 0, nil
	}

	offset := (params.Page - 1) * params.PageSize
	if err := query.Order("api_keys.id DESC").Limit(params.PageSize).Offset(offset).Scan(&details).Error; err != nil {
		return nil, 0, err
	}

	for _, d := range details {
		if d.EncryptedKey == "" {
			continue
		}
		plaintext, decErr := r.crypto.Decrypt(d.EncryptedKey)
		if decErr != nil {
			d.APIKey = "[DECRYPTION FAILED]"
			r.logger.Errorf("Failed to decrypt key ID %d for pagination: %v", d.ID, decErr)
			continue
		}
		d.APIKey = plaintext
	}
	return details, total, nil
}
// GetKeysByValuesAndGroupID returns the API keys in groupID whose plaintext value is one of
// values, with APIKey decrypted.
//
// The lookup compares the hex-encoded SHA-256 of each value against api_keys.api_key_hash,
// so plaintext values never appear in the SQL. An empty values slice returns an empty
// (non-nil) slice without querying.
//
// On a decryption failure, it returns nil and a wrapped error. This matches
// FindKeyValuesByStatus and GetKeyStringsByGroupAndStatus, and avoids handing callers
// partially decrypted rows alongside an error.
func (r *gormKeyRepository) GetKeysByValuesAndGroupID(values []string, groupID uint) ([]models.APIKey, error) {
	if len(values) == 0 {
		return []models.APIKey{}, nil
	}
	hashes := make([]string, len(values))
	for i, v := range values {
		sum := sha256.Sum256([]byte(v))
		hashes[i] = hex.EncodeToString(sum[:])
	}
	var keys []models.APIKey
	err := r.db.Joins("JOIN group_api_key_mappings ON group_api_key_mappings.api_key_id = api_keys.id").
		Where("group_api_key_mappings.key_group_id = ?", groupID).
		Where("api_keys.api_key_hash IN ?", hashes).
		Find(&keys).Error
	if err != nil {
		return nil, err
	}
	if err := r.decryptKeys(keys); err != nil {
		return nil, fmt.Errorf("decryption failed during GetKeysByValuesAndGroupID: %w", err)
	}
	return keys, nil
}
// FindKeyValuesByStatus returns the decrypted values of the API keys mapped to groupID
// whose mapping status is one of statuses. The comparison is case-insensitive.
//
// No status filter is applied when statuses is empty or contains "all" in any letter
// case; this is the same rule GetKeyStringsByGroupAndStatus uses. Without it,
// []string{"ALL"} would be compared literally against LOWER(status) and match nothing.
//
// A decryption failure aborts the call with a wrapped error.
func (r *gormKeyRepository) FindKeyValuesByStatus(groupID uint, statuses []string) ([]string, error) {
	query := r.db.Table("api_keys").
		Select("api_keys.*").
		Joins("JOIN group_api_key_mappings as m ON m.api_key_id = api_keys.id").
		Where("m.key_group_id = ?", groupID)

	filterByStatus := len(statuses) > 0
	lowerStatuses := make([]string, len(statuses))
	for i, s := range statuses {
		if strings.EqualFold(s, "all") {
			filterByStatus = false
			break
		}
		lowerStatuses[i] = strings.ToLower(s)
	}
	if filterByStatus {
		query = query.Where("LOWER(m.status) IN (?)", lowerStatuses)
	}

	var keys []models.APIKey
	if err := query.Find(&keys).Error; err != nil {
		return nil, err
	}
	if err := r.decryptKeys(keys); err != nil {
		return nil, fmt.Errorf("decryption failed during FindKeyValuesByStatus: %w", err)
	}
	keyValues := make([]string, len(keys))
	for i := range keys {
		keyValues[i] = keys[i].APIKey
	}
	return keyValues, nil
}
// GetKeyStringsByGroupAndStatus returns the decrypted values of the API keys mapped to
// groupID whose mapping status is one of statuses. The comparison is case-insensitive.
//
// No status filter is applied when statuses is empty or when any element equals "all"
// in any letter case. A decryption failure aborts the call with a wrapped error. When
// nothing matches, an empty (non-nil) slice is returned.
func (r *gormKeyRepository) GetKeyStringsByGroupAndStatus(groupID uint, statuses []string) ([]string, error) {
	query := r.db.Table("api_keys").
		Select("api_keys.*").
		Joins("JOIN group_api_key_mappings ON group_api_key_mappings.api_key_id = api_keys.id").
		Where("group_api_key_mappings.key_group_id = ?", groupID)

	wanted := make([]string, 0, len(statuses))
	wantsAll := false
	for _, s := range statuses {
		lower := strings.ToLower(s)
		if lower == "all" {
			wantsAll = true
			break
		}
		wanted = append(wanted, lower)
	}
	if len(statuses) > 0 && !wantsAll {
		query = query.Where("LOWER(group_api_key_mappings.status) IN ?", wanted)
	}

	var keys []models.APIKey
	if err := query.Find(&keys).Error; err != nil {
		return nil, err
	}
	if err := r.decryptKeys(keys); err != nil {
		return nil, fmt.Errorf("decryption failed during GetKeyStringsByGroupAndStatus: %w", err)
	}

	out := make([]string, 0, len(keys))
	for i := range keys {
		out = append(out, keys[i].APIKey)
	}
	return out, nil
}
// FindKeyIDsByStatus returns the IDs of the API keys mapped to groupID whose mapping
// status is one of statuses. The comparison is case-insensitive. It reads only
// group_api_key_mappings and never touches key values.
//
// No status filter is applied when statuses is empty or contains "all" in any letter
// case; this matches GetKeyStringsByGroupAndStatus.
func (r *gormKeyRepository) FindKeyIDsByStatus(groupID uint, statuses []string) ([]uint, error) {
	query := r.db.Table("group_api_key_mappings").
		Select("api_key_id").
		Where("key_group_id = ?", groupID)

	filterByStatus := len(statuses) > 0
	lowerStatuses := make([]string, len(statuses))
	for i, s := range statuses {
		if strings.EqualFold(s, "all") {
			filterByStatus = false
			break
		}
		lowerStatuses[i] = strings.ToLower(s)
	}
	if filterByStatus {
		query = query.Where("LOWER(status) IN (?)", lowerStatuses)
	}

	var keyIDs []uint
	if err := query.Pluck("api_key_id", &keyIDs).Error; err != nil {
		return nil, err
	}
	return keyIDs, nil
}
// UpdateMappingWithoutCache persists mapping with tx.Save inside executeTransactionWithRetry
// and returns that result. Unlike UpdateMapping, it does not refresh the store cache
// afterwards.
func (r *gormKeyRepository) UpdateMappingWithoutCache(mapping *models.GroupAPIKeyMapping) error {
	save := func(tx *gorm.DB) error {
		return tx.Save(mapping).Error
	}
	return r.executeTransactionWithRetry(save)
}