New
This commit is contained in:
88
internal/db/db.go
Normal file
88
internal/db/db.go
Normal file
@@ -0,0 +1,88 @@
|
||||
// Filename: internal/db/db.go
|
||||
package db
|
||||
|
||||
import (
|
||||
"gemini-balancer/internal/config"
|
||||
"gemini-balancer/internal/db/dialect"
|
||||
stdlog "log"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/glebarez/sqlite"
|
||||
"github.com/sirupsen/logrus"
|
||||
"gorm.io/driver/mysql"
|
||||
"gorm.io/driver/postgres"
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/logger"
|
||||
)
|
||||
|
||||
func NewDB(cfg *config.Config, appLogger *logrus.Logger) (*gorm.DB, dialect.DialectAdapter, error) {
|
||||
Logger := appLogger.WithField("component", "db")
|
||||
Logger.Info("Initializing database connection and dialect adapter...")
|
||||
dbConfig := cfg.Database
|
||||
dsn := dbConfig.DSN
|
||||
var gormLogger logger.Interface
|
||||
if cfg.Log.Level == "debug" {
|
||||
gormLogger = logger.New(
|
||||
stdlog.New(os.Stdout, "\r\n", stdlog.LstdFlags),
|
||||
logger.Config{
|
||||
SlowThreshold: 1 * time.Second,
|
||||
LogLevel: logger.Info,
|
||||
IgnoreRecordNotFoundError: true,
|
||||
Colorful: true,
|
||||
},
|
||||
)
|
||||
Logger.Info("Debug mode enabled, GORM SQL logging is active.")
|
||||
}
|
||||
|
||||
var dialector gorm.Dialector
|
||||
var adapter dialect.DialectAdapter
|
||||
switch {
|
||||
case strings.HasPrefix(dsn, "postgres://"), strings.HasPrefix(dsn, "postgresql://"):
|
||||
Logger.Info("Detected PostgreSQL database.")
|
||||
dialector = postgres.Open(dsn)
|
||||
adapter = dialect.NewPostgresAdapter()
|
||||
case strings.Contains(dsn, "@tcp"):
|
||||
Logger.Info("Detected MySQL database.")
|
||||
if !strings.Contains(dsn, "parseTime=true") {
|
||||
if strings.Contains(dsn, "?") {
|
||||
dsn += "&parseTime=true"
|
||||
} else {
|
||||
dsn += "?parseTime=true"
|
||||
}
|
||||
}
|
||||
dialector = mysql.Open(dsn)
|
||||
adapter = dialect.NewPostgresAdapter()
|
||||
default:
|
||||
Logger.Info("Using SQLite database.")
|
||||
if err := os.MkdirAll(filepath.Dir(dsn), 0755); err != nil {
|
||||
Logger.Errorf("Failed to create SQLite directory: %v", err)
|
||||
return nil, nil, err
|
||||
}
|
||||
dialector = sqlite.Open(dsn + "?_busy_timeout=5000")
|
||||
adapter = dialect.NewSQLiteAdapter()
|
||||
}
|
||||
db, err := gorm.Open(dialector, &gorm.Config{
|
||||
Logger: gormLogger,
|
||||
PrepareStmt: true,
|
||||
})
|
||||
if err != nil {
|
||||
Logger.Errorf("Failed to open database connection: %v", err)
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
sqlDB, err := db.DB()
|
||||
if err != nil {
|
||||
Logger.Errorf("Failed to get underlying sql.DB: %v", err)
|
||||
return nil, nil, err
|
||||
}
|
||||
sqlDB.SetMaxIdleConns(dbConfig.MaxIdleConns)
|
||||
sqlDB.SetMaxOpenConns(dbConfig.MaxOpenConns)
|
||||
sqlDB.SetConnMaxLifetime(dbConfig.ConnMaxLifetime)
|
||||
Logger.Infof("Connection pool configured: MaxIdle=%d, MaxOpen=%d, MaxLifetime=%v",
|
||||
dbConfig.MaxIdleConns, dbConfig.MaxOpenConns, dbConfig.ConnMaxLifetime)
|
||||
Logger.Info("Database connection established successfully.")
|
||||
return db, adapter, nil
|
||||
}
|
||||
14
internal/db/dialect/dialect.go
Normal file
14
internal/db/dialect/dialect.go
Normal file
@@ -0,0 +1,14 @@
|
||||
// Filename: internal/db/dialect/dialect.go
|
||||
package dialect
|
||||
|
||||
import (
|
||||
"gorm.io/gorm/clause"
|
||||
)
|
||||
|
||||
// “通用语言”接口。
|
||||
type DialectAdapter interface {
|
||||
// OnConflictUpdateAll 生成一个完整的、适用于当前数据库的 "ON CONFLICT DO UPDATE" 子句。
|
||||
// conflictColumns: 唯一的约束列,例如 ["time", "group_id", "model_name"]
|
||||
// updateColumns: 需要累加更新的列,例如 ["request_count", "success_count", ...]
|
||||
OnConflictUpdateAll(conflictColumns []string, updateColumns []string) clause.Expression
|
||||
}
|
||||
30
internal/db/dialect/mysql_adapter.go
Normal file
30
internal/db/dialect/mysql_adapter.go
Normal file
@@ -0,0 +1,30 @@
|
||||
// Filename: internal/db/mysql_adapter.go
|
||||
package dialect
|
||||
|
||||
import (
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/clause"
|
||||
)
|
||||
|
||||
type mysqlAdapter struct{}
|
||||
|
||||
func NewMySQLAdapter() DialectAdapter {
|
||||
return &mysqlAdapter{}
|
||||
}
|
||||
|
||||
func (a *mysqlAdapter) OnConflictUpdateAll(conflictColumns []string, updateColumns []string) clause.Expression {
|
||||
conflictCols := make([]clause.Column, len(conflictColumns))
|
||||
for i, col := range conflictColumns {
|
||||
conflictCols[i] = clause.Column{Name: col}
|
||||
}
|
||||
|
||||
assignments := make(map[string]interface{})
|
||||
for _, col := range updateColumns {
|
||||
assignments[col] = gorm.Expr(col + " + VALUES(" + col + ")")
|
||||
}
|
||||
|
||||
return clause.OnConflict{
|
||||
Columns: conflictCols,
|
||||
DoUpdates: clause.Assignments(assignments),
|
||||
}
|
||||
}
|
||||
30
internal/db/dialect/postgres_adapter.go
Normal file
30
internal/db/dialect/postgres_adapter.go
Normal file
@@ -0,0 +1,30 @@
|
||||
// Filename: internal/db/dialect/postgres_adapter.go (全新文件)
|
||||
package dialect
|
||||
|
||||
import (
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/clause"
|
||||
)
|
||||
|
||||
type postgresAdapter struct{}
|
||||
|
||||
func NewPostgresAdapter() DialectAdapter {
|
||||
return &postgresAdapter{}
|
||||
}
|
||||
|
||||
func (a *postgresAdapter) OnConflictUpdateAll(conflictColumns []string, updateColumns []string) clause.Expression {
|
||||
conflictCols := make([]clause.Column, len(conflictColumns))
|
||||
for i, col := range conflictColumns {
|
||||
conflictCols[i] = clause.Column{Name: col}
|
||||
}
|
||||
|
||||
assignments := make(map[string]interface{})
|
||||
for _, col := range updateColumns {
|
||||
assignments[col] = gorm.Expr(col + " + excluded." + col)
|
||||
}
|
||||
|
||||
return clause.OnConflict{
|
||||
Columns: conflictCols,
|
||||
DoUpdates: clause.Assignments(assignments),
|
||||
}
|
||||
}
|
||||
30
internal/db/dialect/sqlite_adapter.go
Normal file
30
internal/db/dialect/sqlite_adapter.go
Normal file
@@ -0,0 +1,30 @@
|
||||
// Filename: internal/db/sqlite_adapter.go (全新文件 - 最终版)
|
||||
package dialect
|
||||
|
||||
import (
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/clause"
|
||||
)
|
||||
|
||||
type sqliteAdapter struct{}
|
||||
|
||||
func NewSQLiteAdapter() DialectAdapter {
|
||||
return &sqliteAdapter{}
|
||||
}
|
||||
|
||||
func (a *sqliteAdapter) OnConflictUpdateAll(conflictColumns []string, updateColumns []string) clause.Expression {
|
||||
conflictCols := make([]clause.Column, len(conflictColumns))
|
||||
for i, col := range conflictColumns {
|
||||
conflictCols[i] = clause.Column{Name: col}
|
||||
}
|
||||
|
||||
assignments := make(map[string]interface{})
|
||||
for _, col := range updateColumns {
|
||||
assignments[col] = gorm.Expr(col + " + excluded." + col)
|
||||
}
|
||||
|
||||
return clause.OnConflict{
|
||||
Columns: conflictCols,
|
||||
DoUpdates: clause.Assignments(assignments),
|
||||
}
|
||||
}
|
||||
36
internal/db/migrations/migrations.go
Normal file
36
internal/db/migrations/migrations.go
Normal file
@@ -0,0 +1,36 @@
|
||||
// Filename: internal/db/migrations/migrations.go (全新)
|
||||
package migrations
|
||||
|
||||
import (
|
||||
"gemini-balancer/internal/models"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// RunMigrations 负责执行所有的数据库模式迁移。
|
||||
func RunMigrations(db *gorm.DB, logger *logrus.Logger) error {
|
||||
log := logger.WithField("component", "migrations")
|
||||
log.Info("Running database schema migrations...")
|
||||
// 集中管理所有需要被创建或更新的表。
|
||||
err := db.AutoMigrate(
|
||||
&models.UpstreamEndpoint{},
|
||||
&models.ProxyConfig{},
|
||||
&models.APIKey{},
|
||||
&models.KeyGroup{},
|
||||
&models.GroupModelMapping{},
|
||||
&models.AuthToken{},
|
||||
&models.RequestLog{},
|
||||
&models.StatsHourly{},
|
||||
&models.FileRecord{},
|
||||
&models.Setting{},
|
||||
&models.GroupSettings{},
|
||||
&models.GroupAPIKeyMapping{},
|
||||
)
|
||||
if err != nil {
|
||||
log.Errorf("Database schema migration failed: %v", err)
|
||||
return err
|
||||
}
|
||||
log.Info("Database schema migrations completed successfully.")
|
||||
return nil
|
||||
}
|
||||
62
internal/db/migrations/versioned_migrations.go
Normal file
62
internal/db/migrations/versioned_migrations.go
Normal file
@@ -0,0 +1,62 @@
|
||||
// Filename: internal/db/migrations/versioned_migrations.go
|
||||
|
||||
package migrations
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"gemini-balancer/internal/config"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// RunVersionedMigrations 负责运行所有已注册的版本化迁移。
|
||||
func RunVersionedMigrations(db *gorm.DB, cfg *config.Config, logger *logrus.Logger) error {
|
||||
log := logger.WithField("component", "versioned_migrations")
|
||||
log.Info("Checking for versioned database migrations...")
|
||||
|
||||
if err := db.AutoMigrate(&MigrationHistory{}); err != nil {
|
||||
log.Errorf("Failed to create migration history table: %v", err)
|
||||
return err
|
||||
}
|
||||
|
||||
var executedMigrations []MigrationHistory
|
||||
db.Find(&executedMigrations)
|
||||
executedVersions := make(map[string]bool)
|
||||
for _, m := range executedMigrations {
|
||||
executedVersions[m.Version] = true
|
||||
}
|
||||
|
||||
for _, migration := range migrationRegistry {
|
||||
if !executedVersions[migration.Version] {
|
||||
log.Infof("Running migration %s: %s", migration.Version, migration.Description)
|
||||
if err := migration.Migrate(db, cfg, log); err != nil {
|
||||
log.Errorf("Migration %s failed: %v", migration.Version, err)
|
||||
return fmt.Errorf("migration %s failed: %w", migration.Version, err)
|
||||
}
|
||||
db.Create(&MigrationHistory{Version: migration.Version})
|
||||
log.Infof("Migration %s completed successfully.", migration.Version)
|
||||
}
|
||||
}
|
||||
|
||||
log.Info("All versioned migrations are up to date.")
|
||||
return nil
|
||||
}
|
||||
|
||||
type MigrationFunc func(db *gorm.DB, cfg *config.Config, logger *logrus.Entry) error
|
||||
type VersionedMigration struct {
|
||||
Version string
|
||||
Description string
|
||||
Migrate MigrationFunc
|
||||
}
|
||||
type MigrationHistory struct {
|
||||
Version string `gorm:"primaryKey"`
|
||||
}
|
||||
|
||||
var migrationRegistry = []VersionedMigration{
|
||||
/*{
|
||||
Version: "20250828_encrypt_existing_auth_tokens",
|
||||
Description: "Encrypt plaintext tokens and populate new crypto columns in auth_tokens table.",
|
||||
Migrate: MigrateAuthTokenEncryption,
|
||||
},*/
|
||||
}
|
||||
87
internal/db/seeder/seeder.go
Normal file
87
internal/db/seeder/seeder.go
Normal file
@@ -0,0 +1,87 @@
|
||||
// Filename: internal/db/seeder/seeder.go
|
||||
package seeder
|
||||
|
||||
import (
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"gemini-balancer/internal/crypto"
|
||||
"gemini-balancer/internal/models"
|
||||
"gemini-balancer/internal/repository"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// RunSeeder now requires the crypto service to create the initial admin token securely.
|
||||
func RunSeeder(db *gorm.DB, cryptoService *crypto.Service, logger *logrus.Logger) {
|
||||
log := logger.WithField("component", "seeder")
|
||||
log.Info("Running database seeder...")
|
||||
// [REFACTORED] Admin token seeding is now crypto-aware.
|
||||
var count int64
|
||||
db.Model(&models.AuthToken{}).Where("is_admin = ?", true).Count(&count)
|
||||
if count == 0 {
|
||||
log.Info("No admin token found, attempting to seed one...")
|
||||
const adminTokenPlaintext = "admin-secret-token" // The default token
|
||||
// 1. Encrypt and Hash the token
|
||||
encryptedToken, err := cryptoService.Encrypt(adminTokenPlaintext)
|
||||
if err != nil {
|
||||
log.Fatalf("FATAL: Failed to encrypt default admin token during seeding: %v. Server cannot start.", err)
|
||||
return
|
||||
}
|
||||
hash := sha256.Sum256([]byte(adminTokenPlaintext))
|
||||
tokenHash := hex.EncodeToString(hash[:])
|
||||
// 2. Use the repository to seed the token
|
||||
// Note: We create a temporary repository instance here just for the seeder.
|
||||
repo := repository.NewAuthTokenRepository(db, cryptoService, logger)
|
||||
if err := repo.SeedAdminToken(encryptedToken, tokenHash); err != nil {
|
||||
log.Warnf("Failed to seed admin token using repository: %v", err)
|
||||
} else {
|
||||
log.Infof("Default admin token has been seeded successfully. Please use '%s' for your first login.", adminTokenPlaintext)
|
||||
}
|
||||
} else {
|
||||
log.Info("Admin token already exists, seeder skipped.")
|
||||
}
|
||||
|
||||
// This functionality should be replaced by a proper user/token management UI in the future.
|
||||
linkAllKeysToDefaultGroup(db, log)
|
||||
}
|
||||
|
||||
// linkAllKeysToDefaultGroup ensures every key belongs to at least one group.
|
||||
func linkAllKeysToDefaultGroup(db *gorm.DB, logger *logrus.Entry) {
|
||||
logger.Info("Linking existing API keys to the default group as a fallback...")
|
||||
// 1. Find a default group (the first one for simplicity)
|
||||
var defaultGroup models.KeyGroup
|
||||
if err := db.Order("id asc").First(&defaultGroup).Error; err != nil {
|
||||
logger.Warnf("Seeder: Could not find a default key group to link keys to: %v", err)
|
||||
return
|
||||
}
|
||||
// 2. Find all "orphan keys" that don't belong to any group
|
||||
var orphanKeys []*models.APIKey
|
||||
err := db.Raw(`
|
||||
SELECT * FROM api_keys
|
||||
WHERE id NOT IN (SELECT DISTINCT api_key_id FROM group_api_key_mappings)
|
||||
AND deleted_at IS NULL
|
||||
`).Scan(&orphanKeys).Error
|
||||
if err != nil {
|
||||
logger.Errorf("Seeder: Failed to query for orphan keys: %v", err)
|
||||
return
|
||||
}
|
||||
if len(orphanKeys) == 0 {
|
||||
logger.Info("Seeder: No orphan API keys found to link.")
|
||||
return
|
||||
}
|
||||
// 3. Create GroupAPIKeyMapping records manually
|
||||
logger.Infof("Seeder: Found %d orphan keys. Creating mappings for them in group '%s' (ID: %d)...", len(orphanKeys), defaultGroup.Name, defaultGroup.ID)
|
||||
var newMappings []models.GroupAPIKeyMapping
|
||||
for _, key := range orphanKeys {
|
||||
newMappings = append(newMappings, models.GroupAPIKeyMapping{
|
||||
KeyGroupID: defaultGroup.ID,
|
||||
APIKeyID: key.ID,
|
||||
})
|
||||
}
|
||||
if err := db.Create(&newMappings).Error; err != nil {
|
||||
logger.Errorf("Seeder: Failed to create key mappings for orphan keys: %v", err)
|
||||
} else {
|
||||
logger.Info("Successfully created mappings for orphan API keys.")
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user