Files
gemini-banlancer/internal/router/router.go
2025-11-25 16:58:15 +08:00

417 lines
12 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
// Filename: internal/router/router.go
package router
import (
"gemini-balancer/internal/config"
"gemini-balancer/internal/domain/proxy"
"gemini-balancer/internal/domain/upstream"
"gemini-balancer/internal/handlers"
"gemini-balancer/internal/middleware"
"gemini-balancer/internal/pongo"
"gemini-balancer/internal/service"
"gemini-balancer/internal/settings"
"gemini-balancer/internal/webhandlers"
"net/http"
"os"
"time"
"github.com/gin-contrib/cors"
"github.com/gin-gonic/gin"
"github.com/sirupsen/logrus"
)
// NewRouter assembles the application's gin engine: logger, global
// middleware, static assets and HTML templates, CORS, and every route group
// (public API, web pages, admin API, WebSocket, upstream proxy).
//
// The sequence below is significant:
//   - gin.SetMode must run before setupStaticAndTemplates, which reads gin.Mode().
//   - Global middleware and CORS are installed before the business routes,
//     because gin attaches a group's middleware to a route at the moment the
//     route is registered.
func NewRouter(
	// Core services
	cfg *config.Config,
	securityService *service.SecurityService,
	settingsManager *settings.SettingsManager,
	// Core handlers
	proxyHandler *handlers.ProxyHandler,
	apiAuthHandler *handlers.APIAuthHandler,
	// Admin API handlers
	keyGroupHandler *handlers.KeyGroupHandler,
	apiKeyHandler *handlers.APIKeyHandler,
	tokensHandler *handlers.TokensHandler,
	logHandler *handlers.LogHandler,
	settingHandler *handlers.SettingHandler,
	dashboardHandler *handlers.DashboardHandler,
	taskHandler *handlers.TaskHandler,
	wsHandler *handlers.WebSocketHandler,
	// Web page handlers
	webAuthHandler *webhandlers.WebAuthHandler,
	pageHandler *webhandlers.PageHandler,
	// Domain modules
	upstreamModule *upstream.Module,
	proxyModule *proxy.Module,
) *gin.Engine {
	appLogger := createLogger(cfg)

	// Any log level other than exactly "debug" forces gin into release mode;
	// at "debug" gin's current mode is left as it is.
	if cfg.Log.Level != "debug" {
		gin.SetMode(gin.ReleaseMode)
	}

	// gin.New rather than gin.Default: no middleware is pre-installed, so the
	// chain is fully controlled by setupGlobalMiddleware.
	engine := gin.New()
	setupGlobalMiddleware(engine, appLogger)
	setupStaticAndTemplates(engine, appLogger)
	setupCORS(engine, cfg)
	setupBasicRoutes(engine)

	adminAPIAuth := middleware.APIAdminAuthMiddleware(securityService, appLogger)
	adminWebAuth := middleware.WebAdminAuthMiddleware(securityService)

	registerPublicAPIRoutes(engine, apiAuthHandler, securityService, settingsManager, appLogger)
	registerWebRoutes(engine, adminWebAuth, webAuthHandler, pageHandler)
	registerAdminRoutes(engine, adminAPIAuth,
		keyGroupHandler, tokensHandler, apiKeyHandler, logHandler,
		settingHandler, dashboardHandler, taskHandler,
		upstreamModule, proxyModule)
	registerWebSocketRoutes(engine, wsHandler)
	registerProxyRoutes(engine, proxyHandler, securityService, appLogger)

	return engine
}
// ==================== 辅助函数 ====================
// createLogger builds the logrus logger shared by the router and its middleware.
//
// Format: cfg.Log.Format == "json" selects a JSON formatter with RFC 3339
// timestamps; any other value selects the text formatter with full
// "2006-01-02 15:04:05" timestamps.
//
// Level: cfg.Log.Level is parsed with logrus.ParseLevel, so every standard
// logrus level name is accepted case-insensitively ("trace", "debug",
// "info", "warn"/"warning", "error", "fatal", "panic"). An empty level means
// info. An unrecognised level also falls back to info, but a warning is
// logged so the misconfiguration is visible rather than silently ignored.
//
// Output always goes to os.Stdout.
func createLogger(cfg *config.Config) *logrus.Logger {
	logger := logrus.New()

	if cfg.Log.Format == "json" {
		logger.SetFormatter(&logrus.JSONFormatter{
			TimestampFormat: time.RFC3339,
		})
	} else {
		logger.SetFormatter(&logrus.TextFormatter{
			FullTimestamp:   true,
			TimestampFormat: "2006-01-02 15:04:05",
		})
	}

	// Set the destination before anything (including the fallback warning
	// below) is written.
	logger.SetOutput(os.Stdout)

	level := logrus.InfoLevel
	if raw := cfg.Log.Level; raw != "" {
		parsed, err := logrus.ParseLevel(raw)
		if err != nil {
			// The logger is still at its default (info) level here, so the
			// warning is emitted.
			logger.Warnf("unrecognized log level %q, falling back to \"info\"", raw)
		} else {
			level = parsed
		}
	}
	logger.SetLevel(level)

	return logger
}
// setupGlobalMiddleware installs the engine-wide middleware chain. gin runs
// middleware in registration order, so the sequence is significant:
//
//  1. RequestIDMiddleware: assigns a per-request ID for tracing.
//  2. RedactionMiddleware: registered ahead of the logger so it runs first
//     (presumably it masks sensitive data before logging; confirm).
//  3. LogrusLogger: request logging through the shared logger.
//  4. gin.RecoveryWithWriter(os.Stdout): recovers panics and writes them to
//     stdout. It is last, so it only covers what runs after it (the route
//     handlers), not the three middleware above.
func setupGlobalMiddleware(router *gin.Engine, logger *logrus.Logger) {
	// Arguments are evaluated left to right, preserving the order above.
	router.Use(
		middleware.RequestIDMiddleware(),
		middleware.RedactionMiddleware(),
		middleware.LogrusLogger(logger),
		gin.RecoveryWithWriter(os.Stdout),
	)
}
// setupStaticAndTemplates serves files from ./web/static under the /static
// URL prefix and installs the pongo HTML renderer over web/templates.
//
// The renderer's debug flag is true in any gin mode other than release
// (debug or test), so gin.SetMode must already have been called.
func setupStaticAndTemplates(router *gin.Engine, logger *logrus.Logger) {
	const (
		staticURLPrefix = "/static"
		staticDir       = "./web/static"
		templateDir     = "web/templates"
	)

	router.Static(staticURLPrefix, staticDir)

	inRelease := gin.Mode() == gin.ReleaseMode
	router.HTMLRender = pongo.New(templateDir, !inRelease, logger)
}
// setupCORS installs the gin-contrib/cors middleware.
//
// Allowed origins come from getCORSOrigins. Credentials are allowed, the
// X-Request-Id header may be sent and read by browsers, and preflight results
// may be cached for 12 hours. gin binds middleware to routes at registration
// time, so routes registered before this call (in NewRouter, the /static
// route) are not covered.
func setupCORS(router *gin.Engine, cfg *config.Config) {
	var policy cors.Config
	policy.AllowOrigins = getCORSOrigins(cfg)
	policy.AllowMethods = []string{"GET", "POST", "PUT", "PATCH", "DELETE", "OPTIONS"}
	policy.AllowHeaders = []string{"Origin", "Content-Type", "Accept", "Authorization", "X-Request-Id"}
	policy.ExposeHeaders = []string{"Content-Length", "X-Request-Id"}
	policy.AllowCredentials = true
	policy.MaxAge = 12 * time.Hour

	router.Use(cors.New(policy))
}
// getCORSOrigins returns the origins CORS should allow.
//
// If cfg.Server.CORSOrigins has at least one entry, that slice is returned
// as-is (not copied). Otherwise the result is the single local default
// "http://localhost:9000". A nil slice has length 0, so no separate nil check
// is needed.
func getCORSOrigins(cfg *config.Config) []string {
	if configured := cfg.Server.CORSOrigins; len(configured) > 0 {
		return configured
	}
	return []string{"http://localhost:9000"}
}
// setupBasicRoutes registers the unauthenticated utility endpoints:
//
//   - GET /         redirects to /dashboard
//   - GET /health   liveness probe (handleHealthCheck)
//   - GET /version  build information (handleVersion)
//
// The root redirect uses 302 Found rather than 301 Moved Permanently.
// Browsers may cache a 301 indefinitely, so changing the landing page later
// would not take effect for returning users.
func setupBasicRoutes(router *gin.Engine) {
	router.GET("/", func(c *gin.Context) {
		c.Redirect(http.StatusFound, "/dashboard")
	})
	router.GET("/health", handleHealthCheck)
	router.GET("/version", handleVersion)
}
// handleHealthCheck answers liveness probes with HTTP 200 and a JSON body of
// the form {"status": "ok", "time": <current Unix time in seconds>}.
func handleHealthCheck(c *gin.Context) {
	now := time.Now()
	payload := gin.H{
		"status": "ok",
		"time":   now.Unix(),
	}
	c.JSON(http.StatusOK, payload)
}
// buildVersion and buildTag are reported by GET /version. They default to the
// values that were previously hard-coded and can be overridden at link time
// without touching the source, e.g.:
//
//	go build -ldflags "-X gemini-balancer/internal/router.buildVersion=1.2.3 -X gemini-balancer/internal/router.buildTag=abc123"
//
// They are vars rather than consts because -X can only set package-level
// string variables.
var (
	buildVersion = "1.0.0"
	buildTag     = "latest"
)

// handleVersion responds with HTTP 200 and a JSON body of the form
// {"version": buildVersion, "build": buildTag}.
func handleVersion(c *gin.Context) {
	c.JSON(http.StatusOK, gin.H{
		"version": buildVersion,
		"build":   buildTag,
	})
}
// ==================== 路由注册函数 ====================
// registerProxyRoutes mounts both upstream-proxy entry points behind a single
// middleware.ProxyAuthMiddleware instance:
//
//   - aggregate mode (default entry): /v1/* and /v1beta/*
//   - exact mode (routed by group name): /proxy/:group_name/*
func registerProxyRoutes(
	router *gin.Engine,
	proxyHandler *handlers.ProxyHandler,
	securityService *service.SecurityService,
	logger *logrus.Logger,
) {
	auth := middleware.ProxyAuthMiddleware(securityService, logger)

	registrars := []func(*gin.Engine, *handlers.ProxyHandler, gin.HandlerFunc){
		registerAggregateProxyRoutes,
		registerGroupProxyRoutes,
	}
	for _, register := range registrars {
		register(router, proxyHandler, auth)
	}
}
// registerAggregateProxyRoutes sends every HTTP method under /v1 and /v1beta
// to proxyHandler.HandleProxy, with authMiddleware running first. The
// remainder of the URL is exposed to the handler as the catch-all "path"
// parameter.
func registerAggregateProxyRoutes(
	router *gin.Engine,
	proxyHandler *handlers.ProxyHandler,
	authMiddleware gin.HandlerFunc,
) {
	for _, prefix := range []string{"/v1", "/v1beta"} {
		versioned := router.Group(prefix)
		versioned.Use(authMiddleware)
		versioned.Any("/*path", proxyHandler.HandleProxy)
	}
}
// registerGroupProxyRoutes sends every HTTP method under /proxy/:group_name to
// proxyHandler.HandleProxy, with authMiddleware running first. The handler
// sees "group_name" and the catch-all "path" as route parameters, so callers
// can address a specific group by name.
func registerGroupProxyRoutes(
	router *gin.Engine,
	proxyHandler *handlers.ProxyHandler,
	authMiddleware gin.HandlerFunc,
) {
	router.
		Group("/proxy/:group_name", authMiddleware).
		Any("/*path", proxyHandler.HandleProxy)
}
// registerAdminRoutes mounts the admin JSON API under /admin, with
// authMiddleware applied to every route in the group.
//
// Sub-registrars run in a fixed order: key groups, global API keys, system
// management, dashboard, then the upstream and proxy domain modules, which
// attach their own routes to the same authenticated group.
func registerAdminRoutes(
	router *gin.Engine,
	authMiddleware gin.HandlerFunc,
	keyGroupHandler *handlers.KeyGroupHandler,
	tokensHandler *handlers.TokensHandler,
	apiKeyHandler *handlers.APIKeyHandler,
	logHandler *handlers.LogHandler,
	settingHandler *handlers.SettingHandler,
	dashboardHandler *handlers.DashboardHandler,
	taskHandler *handlers.TaskHandler,
	upstreamModule *upstream.Module,
	proxyModule *proxy.Module,
) {
	protected := router.Group("/admin", authMiddleware)

	registerKeyGroupRoutes(protected, keyGroupHandler, apiKeyHandler)
	registerAPIKeyRoutes(protected, apiKeyHandler)
	registerSystemRoutes(protected, tokensHandler, logHandler, settingHandler, taskHandler)
	registerDashboardRoutes(protected, dashboardHandler)

	upstreamModule.RegisterRoutes(protected)
	proxyModule.RegisterRoutes(protected)
}
// registerKeyGroupRoutes registers the /keygroups resource on the admin group:
// collection operations, per-group operations addressed by :id, and the API
// keys linked to a group under /keygroups/:id/apikeys.
func registerKeyGroupRoutes(
	admin *gin.RouterGroup,
	keyGroupHandler *handlers.KeyGroupHandler,
	apiKeyHandler *handlers.APIKeyHandler,
) {
	// An empty relative path registers the group's own path (/keygroups).
	groups := admin.Group("/keygroups")

	// Collection
	groups.POST("", keyGroupHandler.CreateKeyGroup)
	groups.GET("", keyGroupHandler.GetKeyGroups)
	groups.PUT("/order", keyGroupHandler.UpdateKeyGroupOrder)

	// Single key group
	// NOTE(review): GET /keygroups/:id is served by GetKeyGroups, the same
	// handler as the list route. Confirm it inspects the "id" parameter, or
	// whether a dedicated single-item handler was intended.
	groups.GET("/:id", keyGroupHandler.GetKeyGroups)
	groups.PUT("/:id", keyGroupHandler.UpdateKeyGroup)
	groups.DELETE("/:id", keyGroupHandler.DeleteKeyGroup)
	groups.POST("/:id/clone", keyGroupHandler.CloneKeyGroup)
	groups.GET("/:id/stats", keyGroupHandler.GetKeyGroupStats)
	groups.POST("/:id/bulk-actions", apiKeyHandler.HandleBulkAction)

	// API keys linked to a key group
	linked := groups.Group("/:id/apikeys")
	linked.GET("", apiKeyHandler.ListKeysForGroup)
	linked.GET("/export", apiKeyHandler.ExportKeysForGroup)
	linked.POST("/bulk", apiKeyHandler.AddMultipleKeysToGroup)
	linked.DELETE("/bulk", apiKeyHandler.UnlinkMultipleKeysFromGroup)
	linked.POST("/test", apiKeyHandler.TestKeysForGroup)
	linked.PUT("/:keyId", apiKeyHandler.UpdateGroupAPIKeyMapping)
}
// registerAPIKeyRoutes registers the global /apikeys routes on the admin
// group: listing, bulk testing, hard deletion (single and bulk) and bulk
// restore.
//
// /apikeys/bulk and /apikeys/:id share a path level. This relies on gin
// matching the static "bulk" segment ahead of the :id parameter.
func registerAPIKeyRoutes(admin *gin.RouterGroup, apiKeyHandler *handlers.APIKeyHandler) {
	keys := admin.Group("/apikeys")
	keys.GET("", apiKeyHandler.ListAPIKeys)
	keys.POST("/test", apiKeyHandler.TestMultipleKeys)
	keys.DELETE("/:id", apiKeyHandler.HardDeleteAPIKey)
	keys.DELETE("/bulk", apiKeyHandler.HardDeleteMultipleKeys)
	keys.PUT("/bulk/restore", apiKeyHandler.RestoreMultipleKeys)
}
// registerSystemRoutes registers the system-management routes on the admin
// group: tokens, logs, settings (including reset to defaults) and task status
// lookup. Routes are registered in table order.
func registerSystemRoutes(
	admin *gin.RouterGroup,
	tokensHandler *handlers.TokensHandler,
	logHandler *handlers.LogHandler,
	settingHandler *handlers.SettingHandler,
	taskHandler *handlers.TaskHandler,
) {
	table := []struct {
		method  string
		path    string
		handler gin.HandlerFunc
	}{
		// Tokens
		{"GET", "/tokens", tokensHandler.GetAllTokens},
		{"PUT", "/tokens", tokensHandler.UpdateTokens},
		// Logs
		{"GET", "/logs", logHandler.GetLogs},
		// Settings
		{"GET", "/settings", settingHandler.GetSettings},
		{"PUT", "/settings", settingHandler.UpdateSettings},
		{"PUT", "/settings/reset", settingHandler.ResetSettingsToDefaults},
		// Tasks
		{"GET", "/tasks/:id", taskHandler.GetTaskStatus},
	}
	for _, route := range table {
		admin.Handle(route.method, route.path, route.handler)
	}
}
// registerDashboardRoutes registers the read-only dashboard endpoints under
// /dashboard on the admin group: overview, chart data, and request statistics
// for the period given by the ":period" route parameter.
func registerDashboardRoutes(admin *gin.RouterGroup, dashboardHandler *handlers.DashboardHandler) {
	dash := admin.Group("/dashboard")
	dash.GET("/overview", dashboardHandler.GetOverview)
	dash.GET("/chart", dashboardHandler.GetChart)
	dash.GET("/stats/:period", dashboardHandler.GetRequestStats)
}
// registerWebRoutes registers the server-rendered web UI.
//
// /login (GET and POST) and /logout are public. The admin pages are grouped
// under "/" with authMiddleware applied.
func registerWebRoutes(
	router *gin.Engine,
	authMiddleware gin.HandlerFunc,
	webAuthHandler *webhandlers.WebAuthHandler,
	pageHandler *webhandlers.PageHandler,
) {
	router.GET("/login", webAuthHandler.ShowLoginPage)
	router.POST("/login", webAuthHandler.HandleLogin)
	// NOTE(review): logout is reachable via GET, so a cross-site link or
	// <img> can log an admin out. Consider POST (needs a matching frontend
	// change).
	router.GET("/logout", webAuthHandler.HandleLogout)

	protected := router.Group("/", authMiddleware)
	pages := []struct {
		path string
		show gin.HandlerFunc
	}{
		{"/keys", pageHandler.ShowKeysPage},
		{"/settings", pageHandler.ShowConfigEditorPage},
		{"/logs", pageHandler.ShowErrorLogsPage},
		{"/dashboard", pageHandler.ShowDashboardPage},
		{"/tasks", pageHandler.ShowTasksPage},
		{"/chat", pageHandler.ShowChatPage},
	}
	for _, page := range pages {
		protected.GET(page.path, page.show)
	}
}
// registerPublicAPIRoutes registers the unauthenticated JSON API under /api.
//
// POST /api/login is guarded by middleware.IPBanMiddleware, built here with a
// fresh IPBanCache (presumably to throttle or ban repeated failed logins per
// IP; confirm against the middleware). Any further public endpoints should be
// added to the same group and should reuse banGuard.
func registerPublicAPIRoutes(
	router *gin.Engine,
	apiAuthHandler *handlers.APIAuthHandler,
	securityService *service.SecurityService,
	settingsManager *settings.SettingsManager,
	logger *logrus.Logger,
) {
	banGuard := middleware.IPBanMiddleware(
		securityService,
		settingsManager,
		middleware.NewIPBanCache(),
		logger,
	)

	api := router.Group("/api")
	api.POST("/login", banGuard, apiAuthHandler.HandleLogin)
}
// registerWebSocketRoutes registers the WebSocket endpoint that streams
// system logs at GET /ws/system-logs.
//
// NOTE(review): unlike the admin API and web pages, this route is registered
// without any admin auth middleware. Unless HandleSystemLogs authenticates
// the connection itself, any client can read the system log stream. Verify,
// and consider attaching the web admin auth (browsers send cookies on the
// WebSocket handshake).
func registerWebSocketRoutes(router *gin.Engine, wsHandler *handlers.WebSocketHandler) {
	ws := router.Group("/ws")
	ws.GET("/system-logs", wsHandler.HandleSystemLogs)
}