// Filename: internal/router/router.go
|
||
package router
|
||
|
||
import (
|
||
"gemini-balancer/internal/config"
|
||
"gemini-balancer/internal/domain/proxy"
|
||
"gemini-balancer/internal/domain/upstream"
|
||
"gemini-balancer/internal/handlers"
|
||
"gemini-balancer/internal/middleware"
|
||
"gemini-balancer/internal/pongo"
|
||
"gemini-balancer/internal/service"
|
||
"gemini-balancer/internal/settings"
|
||
"gemini-balancer/internal/webhandlers"
|
||
"net/http"
|
||
"os"
|
||
"time"
|
||
|
||
"github.com/gin-contrib/cors"
|
||
"github.com/gin-gonic/gin"
|
||
"github.com/sirupsen/logrus"
|
||
)
|
||
|
||
func NewRouter(
|
||
// Core Services
|
||
cfg *config.Config,
|
||
securityService *service.SecurityService,
|
||
settingsManager *settings.SettingsManager,
|
||
// Core Handlers
|
||
proxyHandler *handlers.ProxyHandler,
|
||
apiAuthHandler *handlers.APIAuthHandler,
|
||
// Admin API Handlers
|
||
keyGroupHandler *handlers.KeyGroupHandler,
|
||
apiKeyHandler *handlers.APIKeyHandler,
|
||
tokensHandler *handlers.TokensHandler,
|
||
logHandler *handlers.LogHandler,
|
||
settingHandler *handlers.SettingHandler,
|
||
dashboardHandler *handlers.DashboardHandler,
|
||
taskHandler *handlers.TaskHandler,
|
||
wsHandler *handlers.WebSocketHandler,
|
||
// Web Page Handlers
|
||
webAuthHandler *webhandlers.WebAuthHandler,
|
||
pageHandler *webhandlers.PageHandler,
|
||
// === Domain Modules ===
|
||
upstreamModule *upstream.Module,
|
||
proxyModule *proxy.Module,
|
||
) *gin.Engine {
|
||
// === 1. 创建全局 Logger(统一管理)===
|
||
logger := createLogger(cfg)
|
||
|
||
// === 2. 设置 Gin 运行模式 ===
|
||
if cfg.Log.Level != "debug" {
|
||
gin.SetMode(gin.ReleaseMode)
|
||
}
|
||
|
||
// === 3. 创建 Router(使用 gin.New() 以便完全控制中间件)===
|
||
router := gin.New()
|
||
|
||
// === 4. 注册全局中间件(按执行顺序)===
|
||
setupGlobalMiddleware(router, logger)
|
||
|
||
// === 5. 配置静态文件和模板 ===
|
||
setupStaticAndTemplates(router, logger)
|
||
|
||
// === 6. 配置 CORS ===
|
||
setupCORS(router, cfg)
|
||
|
||
// === 7. 注册基础路由 ===
|
||
setupBasicRoutes(router)
|
||
|
||
// === 8. 创建认证中间件 ===
|
||
apiAdminAuth := middleware.APIAdminAuthMiddleware(securityService, logger)
|
||
webAdminAuth := middleware.WebAdminAuthMiddleware(securityService)
|
||
|
||
// === 9. 注册业务路由(按功能分组)===
|
||
registerPublicAPIRoutes(router, apiAuthHandler, securityService, settingsManager, logger)
|
||
registerWebRoutes(router, webAdminAuth, webAuthHandler, pageHandler)
|
||
registerAdminRoutes(router, apiAdminAuth, keyGroupHandler, tokensHandler, apiKeyHandler,
|
||
logHandler, settingHandler, dashboardHandler, taskHandler, upstreamModule, proxyModule)
|
||
registerWebSocketRoutes(router, wsHandler)
|
||
registerProxyRoutes(router, proxyHandler, securityService, logger)
|
||
|
||
return router
|
||
}
|
||
|
||
// ==================== 辅助函数 ====================
|
||
|
||
// createLogger 创建并配置全局 Logger
|
||
func createLogger(cfg *config.Config) *logrus.Logger {
|
||
logger := logrus.New()
|
||
|
||
// 设置日志格式
|
||
if cfg.Log.Format == "json" {
|
||
logger.SetFormatter(&logrus.JSONFormatter{
|
||
TimestampFormat: time.RFC3339,
|
||
})
|
||
} else {
|
||
logger.SetFormatter(&logrus.TextFormatter{
|
||
FullTimestamp: true,
|
||
TimestampFormat: "2006-01-02 15:04:05",
|
||
})
|
||
}
|
||
|
||
// 设置日志级别
|
||
switch cfg.Log.Level {
|
||
case "debug":
|
||
logger.SetLevel(logrus.DebugLevel)
|
||
case "info":
|
||
logger.SetLevel(logrus.InfoLevel)
|
||
case "warn":
|
||
logger.SetLevel(logrus.WarnLevel)
|
||
case "error":
|
||
logger.SetLevel(logrus.ErrorLevel)
|
||
default:
|
||
logger.SetLevel(logrus.InfoLevel)
|
||
}
|
||
|
||
// 设置输出(可选:输出到文件)
|
||
logger.SetOutput(os.Stdout)
|
||
|
||
return logger
|
||
}
|
||
|
||
// setupGlobalMiddleware 设置全局中间件
|
||
func setupGlobalMiddleware(router *gin.Engine, logger *logrus.Logger) {
|
||
// 1. 请求 ID 中间件(用于链路追踪)
|
||
router.Use(middleware.RequestIDMiddleware())
|
||
|
||
// 2. 数据脱敏中间件(在日志前执行)
|
||
router.Use(middleware.RedactionMiddleware())
|
||
|
||
// 3. 日志中间件
|
||
router.Use(middleware.LogrusLogger(logger))
|
||
|
||
// 4. 错误恢复中间件
|
||
router.Use(gin.RecoveryWithWriter(os.Stdout))
|
||
}
|
||
|
||
// setupStaticAndTemplates 配置静态文件和模板
|
||
func setupStaticAndTemplates(router *gin.Engine, logger *logrus.Logger) {
|
||
router.Static("/static", "./web/static")
|
||
|
||
isDebug := gin.Mode() != gin.ReleaseMode
|
||
router.HTMLRender = pongo.New("web/templates", isDebug, logger)
|
||
}
|
||
|
||
// setupCORS 配置 CORS
|
||
func setupCORS(router *gin.Engine, cfg *config.Config) {
|
||
corsConfig := cors.Config{
|
||
AllowOrigins: getCORSOrigins(cfg),
|
||
AllowMethods: []string{"GET", "POST", "PUT", "PATCH", "DELETE", "OPTIONS"},
|
||
AllowHeaders: []string{"Origin", "Content-Type", "Accept", "Authorization", "X-Request-Id"},
|
||
ExposeHeaders: []string{"Content-Length", "X-Request-Id"},
|
||
AllowCredentials: true,
|
||
MaxAge: 12 * time.Hour,
|
||
}
|
||
router.Use(cors.New(corsConfig))
|
||
}
|
||
|
||
// getCORSOrigins 获取 CORS 允许的来源
|
||
func getCORSOrigins(cfg *config.Config) []string {
|
||
// 默认值
|
||
origins := []string{"http://localhost:9000"}
|
||
|
||
// 从配置读取(修复:移除 nil 检查)
|
||
if len(cfg.Server.CORSOrigins) > 0 {
|
||
origins = cfg.Server.CORSOrigins
|
||
}
|
||
|
||
return origins
|
||
}
|
||
|
||
// setupBasicRoutes 设置基础路由
|
||
func setupBasicRoutes(router *gin.Engine) {
|
||
// 根路径重定向
|
||
router.GET("/", func(c *gin.Context) {
|
||
c.Redirect(http.StatusMovedPermanently, "/dashboard")
|
||
})
|
||
|
||
// 健康检查
|
||
router.GET("/health", handleHealthCheck)
|
||
|
||
// 版本信息(可选)
|
||
router.GET("/version", handleVersion)
|
||
}
|
||
|
||
// handleHealthCheck 健康检查处理器
|
||
func handleHealthCheck(c *gin.Context) {
|
||
c.JSON(http.StatusOK, gin.H{
|
||
"status": "ok",
|
||
"time": time.Now().Unix(),
|
||
})
|
||
}
|
||
|
||
// handleVersion 版本信息处理器
|
||
func handleVersion(c *gin.Context) {
|
||
c.JSON(http.StatusOK, gin.H{
|
||
"version": "1.0.0", // 可以从配置或编译时变量读取
|
||
"build": "latest",
|
||
})
|
||
}
|
||
|
||
// ==================== 路由注册函数 ====================
|
||
|
||
// registerProxyRoutes 注册代理路由
|
||
func registerProxyRoutes(
|
||
router *gin.Engine,
|
||
proxyHandler *handlers.ProxyHandler,
|
||
securityService *service.SecurityService,
|
||
logger *logrus.Logger,
|
||
) {
|
||
// 创建代理认证中间件
|
||
proxyAuthMiddleware := middleware.ProxyAuthMiddleware(securityService, logger)
|
||
|
||
// 模式一: 智能聚合模式(默认入口)
|
||
registerAggregateProxyRoutes(router, proxyHandler, proxyAuthMiddleware)
|
||
|
||
// 模式二: 精确路由模式(按组名路由)
|
||
registerGroupProxyRoutes(router, proxyHandler, proxyAuthMiddleware)
|
||
}
|
||
|
||
// registerAggregateProxyRoutes 注册聚合代理路由
|
||
func registerAggregateProxyRoutes(
|
||
router *gin.Engine,
|
||
proxyHandler *handlers.ProxyHandler,
|
||
authMiddleware gin.HandlerFunc,
|
||
) {
|
||
// /v1 路径组
|
||
v1 := router.Group("/v1")
|
||
v1.Use(authMiddleware)
|
||
{
|
||
v1.Any("/*path", proxyHandler.HandleProxy)
|
||
}
|
||
|
||
// /v1beta 路径组
|
||
v1beta := router.Group("/v1beta")
|
||
v1beta.Use(authMiddleware)
|
||
{
|
||
v1beta.Any("/*path", proxyHandler.HandleProxy)
|
||
}
|
||
}
|
||
|
||
// registerGroupProxyRoutes 注册分组代理路由
|
||
func registerGroupProxyRoutes(
|
||
router *gin.Engine,
|
||
proxyHandler *handlers.ProxyHandler,
|
||
authMiddleware gin.HandlerFunc,
|
||
) {
|
||
proxyGroup := router.Group("/proxy/:group_name")
|
||
proxyGroup.Use(authMiddleware)
|
||
{
|
||
proxyGroup.Any("/*path", proxyHandler.HandleProxy)
|
||
}
|
||
}
|
||
|
||
// registerAdminRoutes 注册管理后台 API 路由
|
||
func registerAdminRoutes(
|
||
router *gin.Engine,
|
||
authMiddleware gin.HandlerFunc,
|
||
keyGroupHandler *handlers.KeyGroupHandler,
|
||
tokensHandler *handlers.TokensHandler,
|
||
apiKeyHandler *handlers.APIKeyHandler,
|
||
logHandler *handlers.LogHandler,
|
||
settingHandler *handlers.SettingHandler,
|
||
dashboardHandler *handlers.DashboardHandler,
|
||
taskHandler *handlers.TaskHandler,
|
||
upstreamModule *upstream.Module,
|
||
proxyModule *proxy.Module,
|
||
) {
|
||
admin := router.Group("/admin", authMiddleware)
|
||
{
|
||
// KeyGroup 路由
|
||
registerKeyGroupRoutes(admin, keyGroupHandler, apiKeyHandler)
|
||
|
||
// APIKey 全局路由
|
||
registerAPIKeyRoutes(admin, apiKeyHandler)
|
||
|
||
// 系统管理路由
|
||
registerSystemRoutes(admin, tokensHandler, logHandler, settingHandler, taskHandler)
|
||
|
||
// 仪表盘路由
|
||
registerDashboardRoutes(admin, dashboardHandler)
|
||
|
||
// 领域模块路由
|
||
upstreamModule.RegisterRoutes(admin)
|
||
proxyModule.RegisterRoutes(admin)
|
||
}
|
||
}
|
||
|
||
// registerKeyGroupRoutes 注册 KeyGroup 相关路由
|
||
func registerKeyGroupRoutes(
|
||
admin *gin.RouterGroup,
|
||
keyGroupHandler *handlers.KeyGroupHandler,
|
||
apiKeyHandler *handlers.APIKeyHandler,
|
||
) {
|
||
// 基础路由
|
||
admin.POST("/keygroups", keyGroupHandler.CreateKeyGroup)
|
||
admin.GET("/keygroups", keyGroupHandler.GetKeyGroups)
|
||
admin.PUT("/keygroups/order", keyGroupHandler.UpdateKeyGroupOrder)
|
||
|
||
// 特定 KeyGroup 路由
|
||
admin.GET("/keygroups/:id", keyGroupHandler.GetKeyGroups)
|
||
admin.PUT("/keygroups/:id", keyGroupHandler.UpdateKeyGroup)
|
||
admin.DELETE("/keygroups/:id", keyGroupHandler.DeleteKeyGroup)
|
||
admin.POST("/keygroups/:id/clone", keyGroupHandler.CloneKeyGroup)
|
||
admin.GET("/keygroups/:id/stats", keyGroupHandler.GetKeyGroupStats)
|
||
admin.POST("/keygroups/:id/bulk-actions", apiKeyHandler.HandleBulkAction)
|
||
|
||
// KeyGroup 的 APIKey 子资源
|
||
keyGroupAPIKeys := admin.Group("/keygroups/:id/apikeys")
|
||
{
|
||
keyGroupAPIKeys.GET("", apiKeyHandler.ListKeysForGroup)
|
||
keyGroupAPIKeys.GET("/export", apiKeyHandler.ExportKeysForGroup)
|
||
keyGroupAPIKeys.POST("/bulk", apiKeyHandler.AddMultipleKeysToGroup)
|
||
keyGroupAPIKeys.DELETE("/bulk", apiKeyHandler.UnlinkMultipleKeysFromGroup)
|
||
keyGroupAPIKeys.POST("/test", apiKeyHandler.TestKeysForGroup)
|
||
keyGroupAPIKeys.PUT("/:keyId", apiKeyHandler.UpdateGroupAPIKeyMapping)
|
||
}
|
||
}
|
||
|
||
// registerAPIKeyRoutes 注册 APIKey 全局路由
|
||
func registerAPIKeyRoutes(admin *gin.RouterGroup, apiKeyHandler *handlers.APIKeyHandler) {
|
||
admin.GET("/apikeys", apiKeyHandler.ListAPIKeys)
|
||
admin.POST("/apikeys/test", apiKeyHandler.TestMultipleKeys)
|
||
admin.DELETE("/apikeys/:id", apiKeyHandler.HardDeleteAPIKey)
|
||
admin.DELETE("/apikeys/bulk", apiKeyHandler.HardDeleteMultipleKeys)
|
||
admin.PUT("/apikeys/bulk/restore", apiKeyHandler.RestoreMultipleKeys)
|
||
}
|
||
|
||
// registerSystemRoutes 注册系统管理路由
|
||
func registerSystemRoutes(
|
||
admin *gin.RouterGroup,
|
||
tokensHandler *handlers.TokensHandler,
|
||
logHandler *handlers.LogHandler,
|
||
settingHandler *handlers.SettingHandler,
|
||
taskHandler *handlers.TaskHandler,
|
||
) {
|
||
// Token 管理
|
||
admin.GET("/tokens", tokensHandler.GetAllTokens)
|
||
admin.PUT("/tokens", tokensHandler.UpdateTokens)
|
||
|
||
// 日志管理
|
||
admin.GET("/logs", logHandler.GetLogs)
|
||
|
||
// 设置管理
|
||
admin.GET("/settings", settingHandler.GetSettings)
|
||
admin.PUT("/settings", settingHandler.UpdateSettings)
|
||
admin.PUT("/settings/reset", settingHandler.ResetSettingsToDefaults)
|
||
|
||
// 任务管理
|
||
admin.GET("/tasks/:id", taskHandler.GetTaskStatus)
|
||
}
|
||
|
||
// registerDashboardRoutes 注册仪表盘路由
|
||
func registerDashboardRoutes(admin *gin.RouterGroup, dashboardHandler *handlers.DashboardHandler) {
|
||
dashboard := admin.Group("/dashboard")
|
||
{
|
||
dashboard.GET("/overview", dashboardHandler.GetOverview)
|
||
dashboard.GET("/chart", dashboardHandler.GetChart)
|
||
dashboard.GET("/stats/:period", dashboardHandler.GetRequestStats)
|
||
}
|
||
}
|
||
|
||
// registerWebRoutes 注册 Web 页面路由
|
||
func registerWebRoutes(
|
||
router *gin.Engine,
|
||
authMiddleware gin.HandlerFunc,
|
||
webAuthHandler *webhandlers.WebAuthHandler,
|
||
pageHandler *webhandlers.PageHandler,
|
||
) {
|
||
// 公开的认证路由
|
||
router.GET("/login", webAuthHandler.ShowLoginPage)
|
||
router.POST("/login", webAuthHandler.HandleLogin)
|
||
router.GET("/logout", webAuthHandler.HandleLogout)
|
||
|
||
// 受保护的管理界面
|
||
webGroup := router.Group("/", authMiddleware)
|
||
{
|
||
webGroup.GET("/keys", pageHandler.ShowKeysPage)
|
||
webGroup.GET("/settings", pageHandler.ShowConfigEditorPage)
|
||
webGroup.GET("/logs", pageHandler.ShowErrorLogsPage)
|
||
webGroup.GET("/dashboard", pageHandler.ShowDashboardPage)
|
||
webGroup.GET("/tasks", pageHandler.ShowTasksPage)
|
||
webGroup.GET("/chat", pageHandler.ShowChatPage)
|
||
}
|
||
}
|
||
|
||
// registerPublicAPIRoutes 注册公共 API 路由
|
||
func registerPublicAPIRoutes(
|
||
router *gin.Engine,
|
||
apiAuthHandler *handlers.APIAuthHandler,
|
||
securityService *service.SecurityService,
|
||
settingsManager *settings.SettingsManager,
|
||
logger *logrus.Logger,
|
||
) {
|
||
// 创建 IP 封禁中间件
|
||
ipBanCache := middleware.NewIPBanCache()
|
||
ipBanMiddleware := middleware.IPBanMiddleware(
|
||
securityService,
|
||
settingsManager,
|
||
ipBanCache,
|
||
logger,
|
||
)
|
||
|
||
// 公共 API 路由组
|
||
publicAPI := router.Group("/api")
|
||
{
|
||
publicAPI.POST("/login", ipBanMiddleware, apiAuthHandler.HandleLogin)
|
||
// 可以在这里添加其他公共 API 路由
|
||
// publicAPI.POST("/register", ipBanMiddleware, apiAuthHandler.HandleRegister)
|
||
// publicAPI.POST("/forgot-password", ipBanMiddleware, apiAuthHandler.HandleForgotPassword)
|
||
}
|
||
}
|
||
|
||
func registerWebSocketRoutes(router *gin.Engine, wsHandler *handlers.WebSocketHandler) {
|
||
router.GET("/ws/system-logs", wsHandler.HandleSystemLogs)
|
||
}
|