253 lines
7.9 KiB
Go
253 lines
7.9 KiB
Go
// Filename: gemini-balancer/internal/settings/settings.go (final audited revision)
|
||
package settings
|
||
|
||
import (
|
||
"encoding/json"
|
||
"fmt"
|
||
"gemini-balancer/internal/models"
|
||
"gemini-balancer/internal/store"
|
||
"gemini-balancer/internal/syncer"
|
||
"reflect"
|
||
"strconv"
|
||
"strings"
|
||
|
||
"github.com/sirupsen/logrus"
|
||
"gorm.io/gorm"
|
||
"gorm.io/gorm/clause"
|
||
)
|
||
|
||
const (
	// SettingsUpdateChannel is the channel name handed to the settings cache
	// syncer; presumably it carries invalidation notices between instances.
	SettingsUpdateChannel = "system_settings:updated"

	// DefaultGeminiEndpoint is the stock BaseKeyCheckEndpoint. While the stored
	// endpoint still equals this value (or is empty), the settings loader derives
	// the endpoint from DefaultUpstreamURL instead.
	DefaultGeminiEndpoint = "https://generativelanguage.googleapis.com/v1beta/models"
)
|
||
|
||
var _ models.SettingsManager = (*SettingsManager)(nil)
|
||
|
||
// SettingsManager 负责管理系统的动态设置,包括从数据库加载、缓存同步和更新。
|
||
type SettingsManager struct {
|
||
db *gorm.DB
|
||
syncer *syncer.CacheSyncer[*models.SystemSettings]
|
||
logger *logrus.Entry
|
||
jsonToFieldType map[string]reflect.Type // 用于将JSON字段映射到Go类型
|
||
}
|
||
|
||
// NewSettingsManager 创建一个新的 SettingsManager 实例。
|
||
func NewSettingsManager(db *gorm.DB, store store.Store, logger *logrus.Logger) (*SettingsManager, error) {
|
||
sm := &SettingsManager{
|
||
db: db,
|
||
logger: logger.WithField("component", "SettingsManager⚙️"),
|
||
jsonToFieldType: make(map[string]reflect.Type),
|
||
}
|
||
|
||
settingsType := reflect.TypeOf(models.SystemSettings{})
|
||
for i := 0; i < settingsType.NumField(); i++ {
|
||
field := settingsType.Field(i)
|
||
jsonTag := field.Tag.Get("json")
|
||
if jsonTag != "" && jsonTag != "-" {
|
||
sm.jsonToFieldType[jsonTag] = field.Type
|
||
}
|
||
}
|
||
|
||
settingsLoader := func() (*models.SystemSettings, error) {
|
||
sm.logger.Info("Loading system settings from database...")
|
||
var dbRecords []models.Setting
|
||
if err := sm.db.Find(&dbRecords).Error; err != nil {
|
||
return nil, fmt.Errorf("failed to load system settings from db: %w", err)
|
||
}
|
||
|
||
settingsMap := make(map[string]string)
|
||
for _, record := range dbRecords {
|
||
settingsMap[record.Key] = record.Value
|
||
}
|
||
|
||
settings := defaultSystemSettings()
|
||
v := reflect.ValueOf(settings).Elem()
|
||
|
||
for i := 0; i < v.NumField(); i++ {
|
||
field := v.Type().Field(i)
|
||
fieldValue := v.Field(i)
|
||
jsonTag := field.Tag.Get("json")
|
||
|
||
if dbValue, ok := settingsMap[jsonTag]; ok {
|
||
if err := parseAndSetField(fieldValue, dbValue); err != nil {
|
||
sm.logger.Warnf("Failed to set field '%s' from DB value '%s': %v. Using default.", field.Name, dbValue, err)
|
||
}
|
||
}
|
||
}
|
||
|
||
// [评估确认] 派生逻辑与原始版本在功能和日志行为上完全一致。
|
||
if (settings.BaseKeyCheckEndpoint == DefaultGeminiEndpoint || settings.BaseKeyCheckEndpoint == "") && settings.DefaultUpstreamURL != "" {
|
||
derivedEndpoint := strings.TrimSuffix(settings.DefaultUpstreamURL, "/") + "/models"
|
||
sm.logger.Infof("BaseKeyCheckEndpoint is dynamically derived from DefaultUpstreamURL: %s", derivedEndpoint)
|
||
settings.BaseKeyCheckEndpoint = derivedEndpoint
|
||
} else if settings.BaseKeyCheckEndpoint != DefaultGeminiEndpoint && settings.BaseKeyCheckEndpoint != "" {
|
||
// 恢复 else 日志,以明确告知用户正在使用自定义覆盖。
|
||
sm.logger.Infof("BaseKeyCheckEndpoint is using a user-defined override: %s", settings.BaseKeyCheckEndpoint)
|
||
}
|
||
|
||
sm.logger.Info("System settings loaded and cached.")
|
||
return settings, nil
|
||
}
|
||
|
||
s, err := syncer.NewCacheSyncer(settingsLoader, store, SettingsUpdateChannel)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("failed to create system settings syncer: %w", err)
|
||
}
|
||
sm.syncer = s
|
||
|
||
if err := sm.ensureSettingsInitialized(); err != nil {
|
||
return nil, fmt.Errorf("failed to ensure system settings are initialized: %w", err)
|
||
}
|
||
|
||
return sm, nil
|
||
}
|
||
|
||
// GetSettings 返回当前缓存的系统设置。
|
||
func (sm *SettingsManager) GetSettings() *models.SystemSettings {
|
||
return sm.syncer.Get()
|
||
}
|
||
|
||
// UpdateSettings 更新一个或多个系统设置。
|
||
func (sm *SettingsManager) UpdateSettings(settingsMap map[string]interface{}) error {
|
||
var settingsToUpdate []models.Setting
|
||
|
||
for key, value := range settingsMap {
|
||
fieldType, ok := sm.jsonToFieldType[key]
|
||
if !ok {
|
||
sm.logger.Warnf("Received update for unknown setting key '%s', ignoring.", key)
|
||
continue
|
||
}
|
||
|
||
dbValue, err := sm.convertToDBValue(key, value, fieldType)
|
||
if err != nil {
|
||
sm.logger.Warnf("Failed to convert value for setting '%s': %v. Skipping update.", key, err)
|
||
continue
|
||
}
|
||
|
||
settingsToUpdate = append(settingsToUpdate, models.Setting{
|
||
Key: key,
|
||
Value: dbValue,
|
||
})
|
||
}
|
||
|
||
if len(settingsToUpdate) > 0 {
|
||
err := sm.db.Clauses(clause.OnConflict{
|
||
Columns: []clause.Column{{Name: "key"}},
|
||
DoUpdates: clause.AssignmentColumns([]string{"value"}),
|
||
}).Create(&settingsToUpdate).Error
|
||
if err != nil {
|
||
return fmt.Errorf("failed to update settings in db: %w", err)
|
||
}
|
||
}
|
||
|
||
if err := sm.syncer.Invalidate(); err != nil {
|
||
sm.logger.Errorf("CRITICAL: Database settings updated, but cache invalidation failed: %v", err)
|
||
return fmt.Errorf("settings updated but cache invalidation failed, system may be inconsistent: %w", err)
|
||
}
|
||
|
||
return nil
|
||
}
|
||
|
||
// ResetAndSaveSettings 将所有设置重置为其默认值。
|
||
func (sm *SettingsManager) ResetAndSaveSettings() (*models.SystemSettings, error) {
|
||
defaults := defaultSystemSettings()
|
||
settingsToSave := sm.buildSettingsFromDefaults(defaults)
|
||
|
||
if len(settingsToSave) > 0 {
|
||
err := sm.db.Clauses(clause.OnConflict{
|
||
Columns: []clause.Column{{Name: "key"}},
|
||
DoUpdates: clause.AssignmentColumns([]string{"value", "name", "description", "category", "default_value"}),
|
||
}).Create(&settingsToSave).Error
|
||
if err != nil {
|
||
return nil, fmt.Errorf("failed to reset settings in db: %w", err)
|
||
}
|
||
}
|
||
|
||
if err := sm.syncer.Invalidate(); err != nil {
|
||
sm.logger.Errorf("CRITICAL: Database settings reset, but cache invalidation failed: %v", err)
|
||
return nil, fmt.Errorf("settings reset but cache invalidation failed: %w", err)
|
||
}
|
||
|
||
return defaults, nil
|
||
}
|
||
|
||
// --- Private helpers ---
|
||
|
||
func (sm *SettingsManager) ensureSettingsInitialized() error {
|
||
defaults := defaultSystemSettings()
|
||
settingsToCreate := sm.buildSettingsFromDefaults(defaults)
|
||
|
||
for _, setting := range settingsToCreate {
|
||
var existing models.Setting
|
||
err := sm.db.Where("key = ?", setting.Key).First(&existing).Error
|
||
|
||
if err == gorm.ErrRecordNotFound {
|
||
sm.logger.Infof("Initializing new setting '%s'", setting.Key)
|
||
if createErr := sm.db.Create(&setting).Error; createErr != nil {
|
||
return fmt.Errorf("failed to create initial setting '%s': %w", setting.Key, createErr)
|
||
}
|
||
} else if err != nil {
|
||
return fmt.Errorf("failed to check for existing setting '%s': %w", setting.Key, err)
|
||
}
|
||
}
|
||
return nil
|
||
}
|
||
|
||
func (sm *SettingsManager) buildSettingsFromDefaults(defaults *models.SystemSettings) []models.Setting {
|
||
v := reflect.ValueOf(defaults).Elem()
|
||
t := v.Type()
|
||
var settings []models.Setting
|
||
|
||
for i := 0; i < t.NumField(); i++ {
|
||
field := t.Field(i)
|
||
fieldValue := v.Field(i)
|
||
key := field.Tag.Get("json")
|
||
|
||
if key == "" || key == "-" {
|
||
continue
|
||
}
|
||
|
||
var defaultValue string
|
||
kind := fieldValue.Kind()
|
||
|
||
if kind == reflect.Slice || kind == reflect.Map {
|
||
jsonBytes, _ := json.Marshal(fieldValue.Interface())
|
||
defaultValue = string(jsonBytes)
|
||
} else {
|
||
defaultValue = field.Tag.Get("default")
|
||
}
|
||
|
||
settings = append(settings, models.Setting{
|
||
Key: key,
|
||
Value: defaultValue,
|
||
Name: field.Tag.Get("name"),
|
||
Description: field.Tag.Get("desc"),
|
||
Category: field.Tag.Get("category"),
|
||
DefaultValue: field.Tag.Get("default"),
|
||
})
|
||
}
|
||
return settings
|
||
}
|
||
|
||
// [修正] 使用空白标识符 `_` 修复 "unused parameter" 警告。
|
||
func (sm *SettingsManager) convertToDBValue(_ string, value interface{}, fieldType reflect.Type) (string, error) {
|
||
kind := fieldType.Kind()
|
||
|
||
switch kind {
|
||
case reflect.Slice, reflect.Map:
|
||
jsonBytes, err := json.Marshal(value)
|
||
if err != nil {
|
||
return "", fmt.Errorf("failed to marshal to JSON: %w", err)
|
||
}
|
||
return string(jsonBytes), nil
|
||
|
||
case reflect.Bool:
|
||
b, ok := value.(bool)
|
||
if !ok {
|
||
return "", fmt.Errorf("expected bool, but got %T", value)
|
||
}
|
||
return strconv.FormatBool(b), nil
|
||
|
||
default:
|
||
return fmt.Sprintf("%v", value), nil
|
||
}
|
||
}
|