Files
gemini-banlancer/internal/db/db.go
2025-11-20 12:24:05 +08:00

89 lines
2.6 KiB
Go

// Filename: internal/db/db.go
package db
import (
	"fmt"
	stdlog "log"
	"os"
	"path/filepath"
	"strings"
	"time"

	"gemini-balancer/internal/config"
	"gemini-balancer/internal/db/dialect"

	"github.com/glebarez/sqlite"
	"github.com/sirupsen/logrus"
	"gorm.io/driver/mysql"
	"gorm.io/driver/postgres"
	"gorm.io/gorm"
	"gorm.io/gorm/logger"
)
// NewDB opens the database described by cfg.Database.DSN and returns the GORM
// handle together with the dialect adapter chosen for the detected backend.
//
// The backend is inferred from the DSN's shape (see selectDialect):
//   - "postgres://" or "postgresql://" prefix -> PostgreSQL
//   - contains "@tcp" or "@unix(" (go-sql-driver/mysql format) -> MySQL
//   - anything else -> treated as a SQLite database path
//
// When cfg.Log.Level is "debug", GORM logs every SQL statement to stdout with a
// 1s slow-query threshold; otherwise no logger is passed and GORM falls back to
// its own default. Prepared-statement caching is enabled. The connection pool
// is configured from cfg.Database.MaxIdleConns, MaxOpenConns and
// ConnMaxLifetime, passed through unchanged.
//
// On failure the other return values are nil and the error is wrapped with
// context describing the step that failed.
func NewDB(cfg *config.Config, appLogger *logrus.Logger) (*gorm.DB, dialect.DialectAdapter, error) {
	dbLogger := appLogger.WithField("component", "db")
	dbLogger.Info("Initializing database connection and dialect adapter...")
	dbConfig := cfg.Database

	var gormLogger logger.Interface
	if cfg.Log.Level == "debug" {
		gormLogger = logger.New(
			stdlog.New(os.Stdout, "\r\n", stdlog.LstdFlags),
			logger.Config{
				SlowThreshold:             1 * time.Second,
				LogLevel:                  logger.Info,
				IgnoreRecordNotFoundError: true,
				Colorful:                  true,
			},
		)
		dbLogger.Info("Debug mode enabled, GORM SQL logging is active.")
	}

	dialector, adapter, err := selectDialect(dbConfig.DSN, dbLogger)
	if err != nil {
		return nil, nil, err
	}

	db, err := gorm.Open(dialector, &gorm.Config{
		Logger:      gormLogger,
		PrepareStmt: true,
	})
	if err != nil {
		return nil, nil, fmt.Errorf("opening database connection: %w", err)
	}

	sqlDB, err := db.DB()
	if err != nil {
		return nil, nil, fmt.Errorf("getting underlying sql.DB: %w", err)
	}
	// Zero values go straight to database/sql: 0 idle keeps no idle
	// connections, 0 open means unlimited, 0 lifetime means connections are
	// never expired by age.
	sqlDB.SetMaxIdleConns(dbConfig.MaxIdleConns)
	sqlDB.SetMaxOpenConns(dbConfig.MaxOpenConns)
	sqlDB.SetConnMaxLifetime(dbConfig.ConnMaxLifetime)
	dbLogger.Infof("Connection pool configured: MaxIdle=%d, MaxOpen=%d, MaxLifetime=%v",
		dbConfig.MaxIdleConns, dbConfig.MaxOpenConns, dbConfig.ConnMaxLifetime)

	dbLogger.Info("Database connection established successfully.")
	return db, adapter, nil
}

// selectDialect maps a DSN to the GORM dialector and dialect adapter for its
// backend, adding driver parameters the application relies on.
func selectDialect(dsn string, dbLogger *logrus.Entry) (gorm.Dialector, dialect.DialectAdapter, error) {
	switch {
	case strings.HasPrefix(dsn, "postgres://"), strings.HasPrefix(dsn, "postgresql://"):
		dbLogger.Info("Detected PostgreSQL database.")
		return postgres.Open(dsn), dialect.NewPostgresAdapter(), nil

	case strings.Contains(dsn, "@tcp"), strings.Contains(dsn, "@unix("):
		dbLogger.Info("Detected MySQL database.")
		// go-sql-driver/mysql only scans DATE/DATETIME into time.Time when
		// parseTime=true; force it on unless the DSN already sets it.
		if !strings.Contains(dsn, "parseTime=true") {
			dsn = appendDSNParam(dsn, "parseTime=true")
		}
		// NOTE(review): MySQL is paired with the PostgreSQL adapter, as in the
		// original code. If that adapter emits Postgres-specific SQL this is
		// wrong for MySQL; confirm whether the dialect package offers a MySQL
		// adapter and use it here.
		return mysql.Open(dsn), dialect.NewPostgresAdapter(), nil

	default:
		dbLogger.Info("Using SQLite database.")
		if err := os.MkdirAll(sqliteDataDir(dsn), 0o755); err != nil {
			return nil, nil, fmt.Errorf("creating SQLite directory: %w", err)
		}
		// The glebarez/modernc driver applies pragmas via "_pragma=name(value)";
		// a 5s busy timeout makes concurrent writers wait instead of failing
		// immediately with "database is locked".
		return sqlite.Open(appendDSNParam(dsn, "_pragma=busy_timeout(5000)")), dialect.NewSQLiteAdapter(), nil
	}
}

// appendDSNParam adds a "key=value" parameter to dsn, using "&" when dsn already
// has a query string and "?" otherwise.
func appendDSNParam(dsn, param string) string {
	if strings.Contains(dsn, "?") {
		return dsn + "&" + param
	}
	return dsn + "?" + param
}

// sqliteDataDir returns the directory that must exist for a SQLite DSN's
// database file, ignoring any "file:" URI prefix and query string.
func sqliteDataDir(dsn string) string {
	path, _, _ := strings.Cut(dsn, "?")
	path = strings.TrimPrefix(path, "file:")
	return filepath.Dir(path)
}