Files
gemini-banlancer/internal/repository/key_maintenance.go
2025-11-22 14:20:05 +08:00

171 lines
4.9 KiB
Go

// Filename: internal/repository/key_maintenance.go
package repository
import (
"context"
"crypto/sha256"
"encoding/hex"
"gemini-balancer/internal/models"
"io"
"time"
"gorm.io/gorm"
)
// StreamKeysToWriter writes the APIKey value of every key mapped to groupID to
// writer, one per line. When statusFilter is neither empty nor "all", only keys
// whose group mapping status equals statusFilter are included.
//
// Rows are fetched in batches of 1000. If decryptKeys reports an error for a
// batch, the error is logged and the batch is still iterated. Keys whose APIKey
// field is empty are skipped. The first write error aborts the stream and is
// returned.
func (r *gormKeyRepository) StreamKeysToWriter(groupID uint, statusFilter string, writer io.Writer) error {
	filterByStatus := statusFilter != "" && statusFilter != "all"

	query := r.db.Model(&models.APIKey{}).
		Joins("JOIN group_api_key_mappings on group_api_key_mappings.api_key_id = api_keys.id").
		Where("group_api_key_mappings.key_group_id = ?", groupID)
	if filterByStatus {
		query = query.Where("group_api_key_mappings.status = ?", statusFilter)
	}

	// FindInBatches refills rows before each invocation of emitBatch.
	var rows []models.APIKey
	emitBatch := func(_ *gorm.DB, batchNo int) error {
		if decErr := r.decryptKeys(rows); decErr != nil {
			r.logger.Errorf("Failed to decrypt batch %d for streaming: %v", batchNo, decErr)
		}
		for i := range rows {
			plain := rows[i].APIKey
			if plain == "" {
				continue
			}
			if _, wErr := writer.Write([]byte(plain + "\n")); wErr != nil {
				return wErr
			}
		}
		return nil
	}
	return query.FindInBatches(&rows, 1000, emitBatch).Error
}
// UpdateMasterStatusByValues sets master_status to newStatus on every API key
// whose value appears in keyValues. Keys are matched on api_key_hash, using the
// hex-encoded SHA-256 digest of each value, so only digests are sent in the query.
//
// It returns the RowsAffected reported by the UPDATE. An empty keyValues is a
// no-op returning (0, nil). On error the count is 0.
func (r *gormKeyRepository) UpdateMasterStatusByValues(keyValues []string, newStatus models.MasterAPIKeyStatus) (int64, error) {
	if len(keyValues) == 0 {
		return 0, nil
	}

	digests := make([]string, 0, len(keyValues))
	for _, value := range keyValues {
		sum := sha256.Sum256([]byte(value))
		digests = append(digests, hex.EncodeToString(sum[:]))
	}

	// affected holds the count from the most recent attempt of the transaction.
	var affected int64
	txErr := r.executeTransactionWithRetry(func(tx *gorm.DB) error {
		res := tx.Model(&models.APIKey{}).
			Where("api_key_hash IN ?", digests).
			Update("master_status", newStatus)
		affected = res.RowsAffected
		return res.Error
	})
	if txErr != nil {
		return 0, txErr
	}
	return affected, nil
}
// UpdateMasterStatusByID sets master_status to newStatus on the API key with
// primary key keyID, running the UPDATE through executeTransactionWithRetry.
// It returns gorm.ErrRecordNotFound when the UPDATE affects no rows, or the
// database error if the statement fails.
func (r *gormKeyRepository) UpdateMasterStatusByID(keyID uint, newStatus models.MasterAPIKeyStatus) error {
	return r.executeTransactionWithRetry(func(tx *gorm.DB) error {
		res := tx.Model(&models.APIKey{}).Where("id = ?", keyID).Update("master_status", newStatus)
		switch {
		case res.Error != nil:
			return res.Error
		case res.RowsAffected == 0:
			// Nothing matched keyID: report the standard sentinel so callers can use errors.Is.
			return gorm.ErrRecordNotFound
		default:
			return nil
		}
	})
}
// DeleteOrphanKeys deletes every API key that has no group mapping. The work
// runs through deleteOrphanKeysLogic inside executeTransactionWithRetry.
//
// It returns the number of keys deleted by the successful attempt. On error the
// count is always 0. A count recorded by an attempt whose transaction was later
// rolled back, or failed to commit, would not describe the database state.
func (r *gormKeyRepository) DeleteOrphanKeys() (int64, error) {
	var deletedCount int64
	err := r.executeTransactionWithRetry(func(tx *gorm.DB) error {
		count, err := r.deleteOrphanKeysLogic(tx)
		if err != nil {
			return err
		}
		deletedCount = count
		return nil
	})
	if err != nil {
		return 0, err
	}
	return deletedCount, nil
}
// DeleteOrphanKeysTx deletes API keys that have no group mapping using the
// caller-supplied tx. Unlike DeleteOrphanKeys, it does not open or retry a
// transaction itself. Committing or rolling back tx is the caller's job.
func (r *gormKeyRepository) DeleteOrphanKeysTx(tx *gorm.DB) (int64, error) {
	deleted, err := r.deleteOrphanKeysLogic(tx)
	return deleted, err
}
// deleteOrphanKeysLogic deletes every API key that has no row in
// group_api_key_mappings, issuing all queries on db. It then removes each
// deleted key from the key store cache. It returns the RowsAffected of the
// DELETE. Cache-removal failures are logged and do not fail the operation.
//
// Callers in this file pass a transaction as db, so cache removal happens
// before that transaction commits. If the commit later fails, the cache may no
// longer hold keys that still exist in the database.
func (r *gormKeyRepository) deleteOrphanKeysLogic(db *gorm.DB) (int64, error) {
	// Load orphaned keys in one GORM query rather than raw SQL plus a second
	// lookup. Going through GORM applies the model's default scopes, so already
	// soft-deleted keys are not collected only to be filtered out by the delete.
	var keysToDelete []models.APIKey
	err := db.Model(&models.APIKey{}).
		Joins("LEFT JOIN group_api_key_mappings ON api_keys.id = group_api_key_mappings.api_key_id").
		Where("group_api_key_mappings.api_key_id IS NULL").
		Find(&keysToDelete).Error
	if err != nil {
		return 0, err
	}
	if len(keysToDelete) == 0 {
		return 0, nil
	}

	orphanKeyIDs := make([]uint, len(keysToDelete))
	for i := range keysToDelete {
		orphanKeyIDs[i] = keysToDelete[i].ID
	}
	result := db.Delete(&models.APIKey{}, orphanKeyIDs)
	if result.Error != nil {
		return 0, result.Error
	}

	for i := range keysToDelete {
		// No caller context reaches this helper, so cache cleanup uses context.Background().
		if err := r.removeStoreCacheForKey(context.Background(), &keysToDelete[i]); err != nil {
			r.logger.Warnf("DB deleted orphan key ID %d, but cache removal failed: %v", keysToDelete[i].ID, err)
		}
	}
	return result.RowsAffected, nil
}
// HardDeleteSoftDeletedBefore permanently removes API key rows whose deleted_at
// is earlier than date. Unscoped disables GORM's soft-delete scope, so this
// issues a real DELETE instead of another soft delete. Rows with a NULL
// deleted_at never satisfy the comparison and are kept. It returns the number
// of rows removed and any database error.
func (r *gormKeyRepository) HardDeleteSoftDeletedBefore(date time.Time) (int64, error) {
	tx := r.db.Unscoped().
		Where("deleted_at < ?", date).
		Delete(&models.APIKey{})
	return tx.RowsAffected, tx.Error
}
// GetActiveMasterKeys loads every API key whose master_status equals
// models.MasterStatusActive and passes each one to decryptKey. A key for which
// decryptKey fails is logged and still included in the result. Only a query
// failure is returned as an error.
func (r *gormKeyRepository) GetActiveMasterKeys() ([]*models.APIKey, error) {
	var active []*models.APIKey
	if err := r.db.Where("master_status = ?", models.MasterStatusActive).Find(&active).Error; err != nil {
		return nil, err
	}
	for i := range active {
		if decErr := r.decryptKey(active[i]); decErr != nil {
			r.logger.Warnf("Failed to decrypt key ID %d during GetActiveMasterKeys: %v", active[i].ID, decErr)
		}
	}
	return active, nil
}
// UpdateAPIKeyStatus sets master_status to status on the API key identified by
// keyID. The UPDATE runs inside executeTransactionWithRetry and is bound to ctx,
// so the caller's cancellation and deadline apply to the database call. It
// returns gorm.ErrRecordNotFound when no row was affected.
//
// On success it starts a goroutine that reloads all keys into the store. That
// reload uses context.Background() on purpose: it must be allowed to finish
// after this method returns, even if ctx (possibly request-scoped) is cancelled.
func (r *gormKeyRepository) UpdateAPIKeyStatus(ctx context.Context, keyID uint, status models.MasterAPIKeyStatus) error {
	err := r.executeTransactionWithRetry(func(tx *gorm.DB) error {
		// WithContext keeps tx's connection (the open transaction) and attaches ctx to the query.
		result := tx.WithContext(ctx).Model(&models.APIKey{}).
			Where("id = ?", keyID).
			Update("master_status", status)
		if result.Error != nil {
			return result.Error
		}
		if result.RowsAffected == 0 {
			return gorm.ErrRecordNotFound
		}
		return nil
	})
	if err != nil {
		return err
	}

	r.logger.Infof("MasterStatus for key ID %d changed, triggering a full cache reload.", keyID)
	go func() {
		if reloadErr := r.LoadAllKeysToStore(context.Background()); reloadErr != nil {
			r.logger.Errorf("Failed to reload cache after MasterStatus change for key ID %d: %v", keyID, reloadErr)
		}
	}()
	return nil
}