171 lines
4.9 KiB
Go
171 lines
4.9 KiB
Go
// Filename: internal/repository/key_maintenance.go
|
|
package repository
|
|
|
|
import (
|
|
"context"
|
|
"crypto/sha256"
|
|
"encoding/hex"
|
|
"gemini-balancer/internal/models"
|
|
"io"
|
|
"time"
|
|
|
|
"gorm.io/gorm"
|
|
)
|
|
|
|
func (r *gormKeyRepository) StreamKeysToWriter(groupID uint, statusFilter string, writer io.Writer) error {
|
|
query := r.db.Model(&models.APIKey{}).
|
|
Joins("JOIN group_api_key_mappings on group_api_key_mappings.api_key_id = api_keys.id").
|
|
Where("group_api_key_mappings.key_group_id = ?", groupID)
|
|
|
|
if statusFilter != "" && statusFilter != "all" {
|
|
query = query.Where("group_api_key_mappings.status = ?", statusFilter)
|
|
}
|
|
var batchKeys []models.APIKey
|
|
return query.FindInBatches(&batchKeys, 1000, func(tx *gorm.DB, batch int) error {
|
|
if err := r.decryptKeys(batchKeys); err != nil {
|
|
r.logger.Errorf("Failed to decrypt batch %d for streaming: %v", batch, err)
|
|
}
|
|
for _, key := range batchKeys {
|
|
if key.APIKey != "" {
|
|
if _, err := writer.Write([]byte(key.APIKey + "\n")); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
}
|
|
return nil
|
|
}).Error
|
|
}
|
|
|
|
func (r *gormKeyRepository) UpdateMasterStatusByValues(keyValues []string, newStatus models.MasterAPIKeyStatus) (int64, error) {
|
|
if len(keyValues) == 0 {
|
|
return 0, nil
|
|
}
|
|
hashes := make([]string, len(keyValues))
|
|
for i, v := range keyValues {
|
|
hash := sha256.Sum256([]byte(v))
|
|
hashes[i] = hex.EncodeToString(hash[:])
|
|
}
|
|
var result *gorm.DB
|
|
err := r.executeTransactionWithRetry(func(tx *gorm.DB) error {
|
|
result = tx.Model(&models.APIKey{}).
|
|
Where("api_key_hash IN ?", hashes).
|
|
Update("master_status", newStatus)
|
|
return result.Error
|
|
})
|
|
if err != nil {
|
|
return 0, err
|
|
}
|
|
return result.RowsAffected, nil
|
|
}
|
|
|
|
func (r *gormKeyRepository) UpdateMasterStatusByID(keyID uint, newStatus models.MasterAPIKeyStatus) error {
|
|
return r.executeTransactionWithRetry(func(tx *gorm.DB) error {
|
|
result := tx.Model(&models.APIKey{}).
|
|
Where("id = ?", keyID).
|
|
Update("master_status", newStatus)
|
|
if result.Error != nil {
|
|
return result.Error
|
|
}
|
|
if result.RowsAffected == 0 {
|
|
// This ensures that if the key ID doesn't exist, we return a standard "not found" error.
|
|
return gorm.ErrRecordNotFound
|
|
}
|
|
return nil
|
|
})
|
|
}
|
|
|
|
func (r *gormKeyRepository) DeleteOrphanKeys() (int64, error) {
|
|
var deletedCount int64
|
|
err := r.executeTransactionWithRetry(func(tx *gorm.DB) error {
|
|
count, err := r.deleteOrphanKeysLogic(tx)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
deletedCount = count
|
|
return nil
|
|
})
|
|
return deletedCount, err
|
|
}
|
|
|
|
func (r *gormKeyRepository) DeleteOrphanKeysTx(tx *gorm.DB) (int64, error) {
|
|
return r.deleteOrphanKeysLogic(tx)
|
|
}
|
|
|
|
func (r *gormKeyRepository) deleteOrphanKeysLogic(db *gorm.DB) (int64, error) {
|
|
var orphanKeyIDs []uint
|
|
err := db.Raw(`
|
|
SELECT api_keys.id FROM api_keys
|
|
LEFT JOIN group_api_key_mappings ON api_keys.id = group_api_key_mappings.api_key_id
|
|
WHERE group_api_key_mappings.api_key_id IS NULL`).Scan(&orphanKeyIDs).Error
|
|
if err != nil {
|
|
return 0, err
|
|
}
|
|
|
|
if len(orphanKeyIDs) == 0 {
|
|
return 0, nil
|
|
}
|
|
|
|
var keysToDelete []models.APIKey
|
|
if err := db.Where("id IN ?", orphanKeyIDs).Find(&keysToDelete).Error; err != nil {
|
|
return 0, err
|
|
}
|
|
|
|
result := db.Delete(&models.APIKey{}, orphanKeyIDs)
|
|
if result.Error != nil {
|
|
return 0, result.Error
|
|
}
|
|
|
|
for i := range keysToDelete {
|
|
// [修正] 使用 context.Background() 调用已更新的缓存清理函数
|
|
if err := r.removeStoreCacheForKey(context.Background(), &keysToDelete[i]); err != nil {
|
|
r.logger.Warnf("DB deleted orphan key ID %d, but cache removal failed: %v", keysToDelete[i].ID, err)
|
|
}
|
|
}
|
|
|
|
return result.RowsAffected, nil
|
|
}
|
|
|
|
func (r *gormKeyRepository) HardDeleteSoftDeletedBefore(date time.Time) (int64, error) {
|
|
result := r.db.Unscoped().Where("deleted_at < ?", date).Delete(&models.APIKey{})
|
|
return result.RowsAffected, result.Error
|
|
}
|
|
|
|
func (r *gormKeyRepository) GetActiveMasterKeys() ([]*models.APIKey, error) {
|
|
var keys []*models.APIKey
|
|
err := r.db.Where("master_status = ?", models.MasterStatusActive).Find(&keys).Error
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
for _, key := range keys {
|
|
if err := r.decryptKey(key); err != nil {
|
|
r.logger.Warnf("Failed to decrypt key ID %d during GetActiveMasterKeys: %v", key.ID, err)
|
|
}
|
|
}
|
|
|
|
return keys, nil
|
|
}
|
|
|
|
func (r *gormKeyRepository) UpdateAPIKeyStatus(ctx context.Context, keyID uint, status models.MasterAPIKeyStatus) error {
|
|
err := r.executeTransactionWithRetry(func(tx *gorm.DB) error {
|
|
result := tx.Model(&models.APIKey{}).
|
|
Where("id = ?", keyID).
|
|
Update("master_status", status)
|
|
if result.Error != nil {
|
|
return result.Error
|
|
}
|
|
if result.RowsAffected == 0 {
|
|
return gorm.ErrRecordNotFound
|
|
}
|
|
return nil
|
|
})
|
|
if err == nil {
|
|
r.logger.Infof("MasterStatus for key ID %d changed, triggering a full cache reload.", keyID)
|
|
go func() {
|
|
if err := r.LoadAllKeysToStore(context.Background()); err != nil {
|
|
r.logger.Errorf("Failed to reload cache after MasterStatus change for key ID %d: %v", keyID, err)
|
|
}
|
|
}()
|
|
}
|
|
return err
|
|
}
|