// Filename: internal/repository/key_mapping.go package repository import ( "context" "crypto/sha256" "encoding/hex" "errors" "fmt" "gemini-balancer/internal/models" "strconv" "strings" "gorm.io/gorm" "gorm.io/gorm/clause" ) func (r *gormKeyRepository) LinkKeysToGroup(ctx context.Context, groupID uint, keyIDs []uint) error { if len(keyIDs) == 0 { return nil } var mappings []models.GroupAPIKeyMapping for _, keyID := range keyIDs { mappings = append(mappings, models.GroupAPIKeyMapping{ KeyGroupID: groupID, APIKeyID: keyID, }) } err := r.executeTransactionWithRetry(func(tx *gorm.DB) error { return tx.Clauses(clause.OnConflict{DoNothing: true}).Create(&mappings).Error }) if err != nil { return err } for _, keyID := range keyIDs { r.store.SAdd(context.Background(), fmt.Sprintf("key:%d:groups", keyID), groupID) } return nil } func (r *gormKeyRepository) UnlinkKeysFromGroup(ctx context.Context, groupID uint, keyIDs []uint) (int64, error) { if len(keyIDs) == 0 { return 0, nil } var unlinkedCount int64 err := r.executeTransactionWithRetry(func(tx *gorm.DB) error { result := tx.Table("group_api_key_mappings"). Where("key_group_id = ? AND api_key_id IN ?", groupID, keyIDs). Delete(nil) if result.Error != nil { return result.Error } unlinkedCount = result.RowsAffected return nil }) if err != nil { return 0, err } activeKeyListKey := fmt.Sprintf("group:%d:keys:active", groupID) for _, keyID := range keyIDs { r.store.SRem(context.Background(), fmt.Sprintf("key:%d:groups", keyID), groupID) r.store.LRem(context.Background(), activeKeyListKey, 0, strconv.Itoa(int(keyID))) } return unlinkedCount, nil } func (r *gormKeyRepository) GetGroupsForKey(ctx context.Context, keyID uint) ([]uint, error) { cacheKey := fmt.Sprintf("key:%d:groups", keyID) strGroupIDs, err := r.store.SMembers(context.Background(), cacheKey) if err != nil || len(strGroupIDs) == 0 { var groupIDs []uint dbErr := r.db.Table("group_api_key_mappings").Where("api_key_id = ?", keyID).Pluck("key_group_id", &groupIDs).Error if dbErr != nil { return nil, dbErr } if len(groupIDs) > 0 { var interfaceSlice []interface{} for _, id := range groupIDs { interfaceSlice = append(interfaceSlice, id) } r.store.SAdd(context.Background(), cacheKey, interfaceSlice...) } return groupIDs, nil } var groupIDs []uint for _, strID := range strGroupIDs { id, _ := strconv.Atoi(strID) groupIDs = append(groupIDs, uint(id)) } return groupIDs, nil } func (r *gormKeyRepository) GetMapping(groupID, keyID uint) (*models.GroupAPIKeyMapping, error) { var mapping models.GroupAPIKeyMapping err := r.db.Where("key_group_id = ? AND api_key_id = ?", groupID, keyID).First(&mapping).Error return &mapping, err } func (r *gormKeyRepository) UpdateMapping(ctx context.Context, mapping *models.GroupAPIKeyMapping) error { err := r.executeTransactionWithRetry(func(tx *gorm.DB) error { return tx.Save(mapping).Error }) if err != nil { return err } return r.updateStoreCacheForMapping(mapping) } // [MODIFIED & FINAL] This is the final version for the core refactoring. func (r *gormKeyRepository) GetPaginatedKeysAndMappingsByGroup(params *models.APIKeyQueryParams) ([]*models.APIKeyDetails, int64, error) { items := make([]*models.APIKeyDetails, 0) var total int64 query := r.db.Table("api_keys"). Select(` api_keys.id, api_keys.created_at, api_keys.updated_at, api_keys.encrypted_key, -- Select encrypted key to be scanned into APIKeyDetails.EncryptedKey api_keys.master_status, m.status, m.last_error, m.consecutive_error_count, m.last_used_at, m.cooldown_until `). Joins("JOIN group_api_key_mappings as m ON m.api_key_id = api_keys.id") if params.KeyGroupID <= 0 { return nil, 0, errors.New("KeyGroupID is required for this query") } query = query.Where("m.key_group_id = ?", params.KeyGroupID) if params.Status != "" { query = query.Where("LOWER(m.status) = LOWER(?)", params.Status) } // Keyword search is now handled by the service layer. if params.Keyword != "" { r.logger.Warn("DB query is ignoring keyword; service layer will perform in-memory filtering.") } countQuery := query.Model(&models.APIKey{}) // Use model for count to avoid GORM issues err := countQuery.Count(&total).Error if err != nil { return nil, 0, err } if total == 0 { return items, 0, nil } offset := (params.Page - 1) * params.PageSize err = query.Order("api_keys.id DESC").Limit(params.PageSize).Offset(offset).Scan(&items).Error if err != nil { return nil, 0, err } // Decrypt all results before returning. This loop is now valid. for i := range items { if items[i].EncryptedKey != "" { plaintext, err := r.crypto.Decrypt(items[i].EncryptedKey) if err == nil { items[i].APIKey = plaintext } else { items[i].APIKey = "[DECRYPTION FAILED]" r.logger.Errorf("Failed to decrypt key ID %d for pagination: %v", items[i].ID, err) } } } return items, total, nil } // [MODIFIED & FINAL] Uses hashes for lookup. func (r *gormKeyRepository) GetKeysByValuesAndGroupID(values []string, groupID uint) ([]models.APIKey, error) { if len(values) == 0 { return []models.APIKey{}, nil } hashes := make([]string, len(values)) for i, v := range values { hash := sha256.Sum256([]byte(v)) hashes[i] = hex.EncodeToString(hash[:]) } var keys []models.APIKey err := r.db.Joins("JOIN group_api_key_mappings ON group_api_key_mappings.api_key_id = api_keys.id"). Where("group_api_key_mappings.key_group_id = ?", groupID). Where("api_keys.api_key_hash IN ?", hashes). Find(&keys).Error if err != nil { return nil, err } return keys, r.decryptKeys(keys) } // [MODIFIED & FINAL] Fetches full objects, decrypts, then extracts strings. func (r *gormKeyRepository) FindKeyValuesByStatus(groupID uint, statuses []string) ([]string, error) { var keys []models.APIKey query := r.db.Table("api_keys"). Select("api_keys.*"). Joins("JOIN group_api_key_mappings as m ON m.api_key_id = api_keys.id"). Where("m.key_group_id = ?", groupID) if len(statuses) > 0 && !(len(statuses) == 1 && statuses[0] == "all") { lowerStatuses := make([]string, len(statuses)) for i, s := range statuses { lowerStatuses[i] = strings.ToLower(s) } query = query.Where("LOWER(m.status) IN (?)", lowerStatuses) } if err := query.Find(&keys).Error; err != nil { return nil, err } if err := r.decryptKeys(keys); err != nil { return nil, fmt.Errorf("decryption failed during FindKeyValuesByStatus: %w", err) } keyValues := make([]string, len(keys)) for i, key := range keys { keyValues[i] = key.APIKey } return keyValues, nil } // [MODIFIED & FINAL] Consistent with the new pattern. func (r *gormKeyRepository) GetKeyStringsByGroupAndStatus(groupID uint, statuses []string) ([]string, error) { var keys []models.APIKey query := r.db.Table("api_keys"). Select("api_keys.*"). Joins("JOIN group_api_key_mappings ON group_api_key_mappings.api_key_id = api_keys.id"). Where("group_api_key_mappings.key_group_id = ?", groupID) if len(statuses) > 0 { isAll := false for _, s := range statuses { if strings.ToLower(s) == "all" { isAll = true break } } if !isAll { lowerStatuses := make([]string, len(statuses)) for i, s := range statuses { lowerStatuses[i] = strings.ToLower(s) } query = query.Where("LOWER(group_api_key_mappings.status) IN ?", lowerStatuses) } } if err := query.Find(&keys).Error; err != nil { return nil, err } if err := r.decryptKeys(keys); err != nil { return nil, fmt.Errorf("decryption failed during GetKeyStringsByGroupAndStatus: %w", err) } keyStrings := make([]string, len(keys)) for i, key := range keys { keyStrings[i] = key.APIKey } return keyStrings, nil } // FindKeyIDsByStatus remains unchanged as it does not deal with key values. func (r *gormKeyRepository) FindKeyIDsByStatus(groupID uint, statuses []string) ([]uint, error) { var keyIDs []uint query := r.db.Table("group_api_key_mappings"). Select("api_key_id"). Where("key_group_id = ?", groupID) if len(statuses) > 0 && !(len(statuses) == 1 && statuses[0] == "all") { lowerStatuses := make([]string, len(statuses)) for i, s := range statuses { lowerStatuses[i] = strings.ToLower(s) } query = query.Where("LOWER(status) IN (?)", lowerStatuses) } if err := query.Pluck("api_key_id", &keyIDs).Error; err != nil { return nil, err } return keyIDs, nil } func (r *gormKeyRepository) UpdateMappingWithoutCache(mapping *models.GroupAPIKeyMapping) error { return r.executeTransactionWithRetry(func(tx *gorm.DB) error { return tx.Save(mapping).Error }) }