This commit is contained in:
XOF
2025-11-20 12:24:05 +08:00
commit f28bdc751f
164 changed files with 64248 additions and 0 deletions

88
internal/db/db.go Normal file
View File

@@ -0,0 +1,88 @@
// Filename: internal/db/db.go
package db
import (
"gemini-balancer/internal/config"
"gemini-balancer/internal/db/dialect"
stdlog "log"
"os"
"path/filepath"
"strings"
"time"
"github.com/glebarez/sqlite"
"github.com/sirupsen/logrus"
"gorm.io/driver/mysql"
"gorm.io/driver/postgres"
"gorm.io/gorm"
"gorm.io/gorm/logger"
)
func NewDB(cfg *config.Config, appLogger *logrus.Logger) (*gorm.DB, dialect.DialectAdapter, error) {
Logger := appLogger.WithField("component", "db")
Logger.Info("Initializing database connection and dialect adapter...")
dbConfig := cfg.Database
dsn := dbConfig.DSN
var gormLogger logger.Interface
if cfg.Log.Level == "debug" {
gormLogger = logger.New(
stdlog.New(os.Stdout, "\r\n", stdlog.LstdFlags),
logger.Config{
SlowThreshold: 1 * time.Second,
LogLevel: logger.Info,
IgnoreRecordNotFoundError: true,
Colorful: true,
},
)
Logger.Info("Debug mode enabled, GORM SQL logging is active.")
}
var dialector gorm.Dialector
var adapter dialect.DialectAdapter
switch {
case strings.HasPrefix(dsn, "postgres://"), strings.HasPrefix(dsn, "postgresql://"):
Logger.Info("Detected PostgreSQL database.")
dialector = postgres.Open(dsn)
adapter = dialect.NewPostgresAdapter()
case strings.Contains(dsn, "@tcp"):
Logger.Info("Detected MySQL database.")
if !strings.Contains(dsn, "parseTime=true") {
if strings.Contains(dsn, "?") {
dsn += "&parseTime=true"
} else {
dsn += "?parseTime=true"
}
}
dialector = mysql.Open(dsn)
adapter = dialect.NewPostgresAdapter()
default:
Logger.Info("Using SQLite database.")
if err := os.MkdirAll(filepath.Dir(dsn), 0755); err != nil {
Logger.Errorf("Failed to create SQLite directory: %v", err)
return nil, nil, err
}
dialector = sqlite.Open(dsn + "?_busy_timeout=5000")
adapter = dialect.NewSQLiteAdapter()
}
db, err := gorm.Open(dialector, &gorm.Config{
Logger: gormLogger,
PrepareStmt: true,
})
if err != nil {
Logger.Errorf("Failed to open database connection: %v", err)
return nil, nil, err
}
sqlDB, err := db.DB()
if err != nil {
Logger.Errorf("Failed to get underlying sql.DB: %v", err)
return nil, nil, err
}
sqlDB.SetMaxIdleConns(dbConfig.MaxIdleConns)
sqlDB.SetMaxOpenConns(dbConfig.MaxOpenConns)
sqlDB.SetConnMaxLifetime(dbConfig.ConnMaxLifetime)
Logger.Infof("Connection pool configured: MaxIdle=%d, MaxOpen=%d, MaxLifetime=%v",
dbConfig.MaxIdleConns, dbConfig.MaxOpenConns, dbConfig.ConnMaxLifetime)
Logger.Info("Database connection established successfully.")
return db, adapter, nil
}

View File

@@ -0,0 +1,14 @@
// Filename: internal/db/dialect/dialect.go
package dialect
import (
"gorm.io/gorm/clause"
)
// “通用语言”接口。
type DialectAdapter interface {
// OnConflictUpdateAll 生成一个完整的、适用于当前数据库的 "ON CONFLICT DO UPDATE" 子句。
// conflictColumns: 唯一的约束列,例如 ["time", "group_id", "model_name"]
// updateColumns: 需要累加更新的列,例如 ["request_count", "success_count", ...]
OnConflictUpdateAll(conflictColumns []string, updateColumns []string) clause.Expression
}

View File

@@ -0,0 +1,30 @@
// Filename: internal/db/mysql_adapter.go
package dialect
import (
"gorm.io/gorm"
"gorm.io/gorm/clause"
)
type mysqlAdapter struct{}
func NewMySQLAdapter() DialectAdapter {
return &mysqlAdapter{}
}
func (a *mysqlAdapter) OnConflictUpdateAll(conflictColumns []string, updateColumns []string) clause.Expression {
conflictCols := make([]clause.Column, len(conflictColumns))
for i, col := range conflictColumns {
conflictCols[i] = clause.Column{Name: col}
}
assignments := make(map[string]interface{})
for _, col := range updateColumns {
assignments[col] = gorm.Expr(col + " + VALUES(" + col + ")")
}
return clause.OnConflict{
Columns: conflictCols,
DoUpdates: clause.Assignments(assignments),
}
}

View File

@@ -0,0 +1,30 @@
// Filename: internal/db/dialect/postgres_adapter.go (全新文件)
package dialect
import (
"gorm.io/gorm"
"gorm.io/gorm/clause"
)
type postgresAdapter struct{}
func NewPostgresAdapter() DialectAdapter {
return &postgresAdapter{}
}
func (a *postgresAdapter) OnConflictUpdateAll(conflictColumns []string, updateColumns []string) clause.Expression {
conflictCols := make([]clause.Column, len(conflictColumns))
for i, col := range conflictColumns {
conflictCols[i] = clause.Column{Name: col}
}
assignments := make(map[string]interface{})
for _, col := range updateColumns {
assignments[col] = gorm.Expr(col + " + excluded." + col)
}
return clause.OnConflict{
Columns: conflictCols,
DoUpdates: clause.Assignments(assignments),
}
}

View File

@@ -0,0 +1,30 @@
// Filename: internal/db/sqlite_adapter.go (全新文件 - 最终版)
package dialect
import (
"gorm.io/gorm"
"gorm.io/gorm/clause"
)
type sqliteAdapter struct{}
func NewSQLiteAdapter() DialectAdapter {
return &sqliteAdapter{}
}
func (a *sqliteAdapter) OnConflictUpdateAll(conflictColumns []string, updateColumns []string) clause.Expression {
conflictCols := make([]clause.Column, len(conflictColumns))
for i, col := range conflictColumns {
conflictCols[i] = clause.Column{Name: col}
}
assignments := make(map[string]interface{})
for _, col := range updateColumns {
assignments[col] = gorm.Expr(col + " + excluded." + col)
}
return clause.OnConflict{
Columns: conflictCols,
DoUpdates: clause.Assignments(assignments),
}
}

View File

@@ -0,0 +1,36 @@
// Filename: internal/db/migrations/migrations.go (全新)
package migrations
import (
"gemini-balancer/internal/models"
"github.com/sirupsen/logrus"
"gorm.io/gorm"
)
// RunMigrations 负责执行所有的数据库模式迁移。
func RunMigrations(db *gorm.DB, logger *logrus.Logger) error {
log := logger.WithField("component", "migrations")
log.Info("Running database schema migrations...")
// 集中管理所有需要被创建或更新的表。
err := db.AutoMigrate(
&models.UpstreamEndpoint{},
&models.ProxyConfig{},
&models.APIKey{},
&models.KeyGroup{},
&models.GroupModelMapping{},
&models.AuthToken{},
&models.RequestLog{},
&models.StatsHourly{},
&models.FileRecord{},
&models.Setting{},
&models.GroupSettings{},
&models.GroupAPIKeyMapping{},
)
if err != nil {
log.Errorf("Database schema migration failed: %v", err)
return err
}
log.Info("Database schema migrations completed successfully.")
return nil
}

View File

@@ -0,0 +1,62 @@
// Filename: internal/db/migrations/versioned_migrations.go
package migrations
import (
"fmt"
"gemini-balancer/internal/config"
"github.com/sirupsen/logrus"
"gorm.io/gorm"
)
// RunVersionedMigrations 负责运行所有已注册的版本化迁移。
func RunVersionedMigrations(db *gorm.DB, cfg *config.Config, logger *logrus.Logger) error {
log := logger.WithField("component", "versioned_migrations")
log.Info("Checking for versioned database migrations...")
if err := db.AutoMigrate(&MigrationHistory{}); err != nil {
log.Errorf("Failed to create migration history table: %v", err)
return err
}
var executedMigrations []MigrationHistory
db.Find(&executedMigrations)
executedVersions := make(map[string]bool)
for _, m := range executedMigrations {
executedVersions[m.Version] = true
}
for _, migration := range migrationRegistry {
if !executedVersions[migration.Version] {
log.Infof("Running migration %s: %s", migration.Version, migration.Description)
if err := migration.Migrate(db, cfg, log); err != nil {
log.Errorf("Migration %s failed: %v", migration.Version, err)
return fmt.Errorf("migration %s failed: %w", migration.Version, err)
}
db.Create(&MigrationHistory{Version: migration.Version})
log.Infof("Migration %s completed successfully.", migration.Version)
}
}
log.Info("All versioned migrations are up to date.")
return nil
}
type MigrationFunc func(db *gorm.DB, cfg *config.Config, logger *logrus.Entry) error
type VersionedMigration struct {
Version string
Description string
Migrate MigrationFunc
}
type MigrationHistory struct {
Version string `gorm:"primaryKey"`
}
var migrationRegistry = []VersionedMigration{
/*{
Version: "20250828_encrypt_existing_auth_tokens",
Description: "Encrypt plaintext tokens and populate new crypto columns in auth_tokens table.",
Migrate: MigrateAuthTokenEncryption,
},*/
}

View File

@@ -0,0 +1,87 @@
// Filename: internal/db/seeder/seeder.go
package seeder
import (
"crypto/sha256"
"encoding/hex"
"gemini-balancer/internal/crypto"
"gemini-balancer/internal/models"
"gemini-balancer/internal/repository"
"github.com/sirupsen/logrus"
"gorm.io/gorm"
)
// RunSeeder now requires the crypto service to create the initial admin token securely.
func RunSeeder(db *gorm.DB, cryptoService *crypto.Service, logger *logrus.Logger) {
log := logger.WithField("component", "seeder")
log.Info("Running database seeder...")
// [REFACTORED] Admin token seeding is now crypto-aware.
var count int64
db.Model(&models.AuthToken{}).Where("is_admin = ?", true).Count(&count)
if count == 0 {
log.Info("No admin token found, attempting to seed one...")
const adminTokenPlaintext = "admin-secret-token" // The default token
// 1. Encrypt and Hash the token
encryptedToken, err := cryptoService.Encrypt(adminTokenPlaintext)
if err != nil {
log.Fatalf("FATAL: Failed to encrypt default admin token during seeding: %v. Server cannot start.", err)
return
}
hash := sha256.Sum256([]byte(adminTokenPlaintext))
tokenHash := hex.EncodeToString(hash[:])
// 2. Use the repository to seed the token
// Note: We create a temporary repository instance here just for the seeder.
repo := repository.NewAuthTokenRepository(db, cryptoService, logger)
if err := repo.SeedAdminToken(encryptedToken, tokenHash); err != nil {
log.Warnf("Failed to seed admin token using repository: %v", err)
} else {
log.Infof("Default admin token has been seeded successfully. Please use '%s' for your first login.", adminTokenPlaintext)
}
} else {
log.Info("Admin token already exists, seeder skipped.")
}
// This functionality should be replaced by a proper user/token management UI in the future.
linkAllKeysToDefaultGroup(db, log)
}
// linkAllKeysToDefaultGroup ensures every key belongs to at least one group.
func linkAllKeysToDefaultGroup(db *gorm.DB, logger *logrus.Entry) {
logger.Info("Linking existing API keys to the default group as a fallback...")
// 1. Find a default group (the first one for simplicity)
var defaultGroup models.KeyGroup
if err := db.Order("id asc").First(&defaultGroup).Error; err != nil {
logger.Warnf("Seeder: Could not find a default key group to link keys to: %v", err)
return
}
// 2. Find all "orphan keys" that don't belong to any group
var orphanKeys []*models.APIKey
err := db.Raw(`
SELECT * FROM api_keys
WHERE id NOT IN (SELECT DISTINCT api_key_id FROM group_api_key_mappings)
AND deleted_at IS NULL
`).Scan(&orphanKeys).Error
if err != nil {
logger.Errorf("Seeder: Failed to query for orphan keys: %v", err)
return
}
if len(orphanKeys) == 0 {
logger.Info("Seeder: No orphan API keys found to link.")
return
}
// 3. Create GroupAPIKeyMapping records manually
logger.Infof("Seeder: Found %d orphan keys. Creating mappings for them in group '%s' (ID: %d)...", len(orphanKeys), defaultGroup.Name, defaultGroup.ID)
var newMappings []models.GroupAPIKeyMapping
for _, key := range orphanKeys {
newMappings = append(newMappings, models.GroupAPIKeyMapping{
KeyGroupID: defaultGroup.ID,
APIKeyID: key.ID,
})
}
if err := db.Create(&newMappings).Error; err != nil {
logger.Errorf("Seeder: Failed to create key mappings for orphan keys: %v", err)
} else {
logger.Info("Successfully created mappings for orphan API keys.")
}
}