From f2706d6fc828e709f8427d9ee1b689db1540f64b Mon Sep 17 00:00:00 2001 From: xofine Date: Mon, 24 Nov 2025 04:48:07 +0800 Subject: [PATCH] Fix Services & Update the middleware && others --- cmd/server/main.go | 2 + config.yaml | 6 + go.mod | 2 + go.sum | 4 + internal/config/config.go | 10 +- internal/container/container.go | 2 +- internal/domain/proxy/module.go | 2 +- internal/errors/upstream_errors.go | 1 + internal/handlers/log_handler.go | 25 +- internal/logging/logging.go | 60 +- internal/middleware/auth.go | 92 ++- internal/middleware/cors.go | 90 +++ internal/middleware/log.go | 213 ++++-- internal/middleware/rate_limit.go | 86 +++ internal/middleware/request_id.go | 39 + internal/middleware/security.go | 122 ++- internal/middleware/timeout.go | 52 ++ internal/middleware/web.go | 390 +++++++++- internal/models/runtime.go | 2 + internal/pongo/renderer.go | 104 ++- internal/router/router.go | 384 +++++++--- internal/service/analytics_service.go | 425 +++++++++-- internal/service/dashboard_query_service.go | 773 +++++++++++++++---- internal/service/db_log_writer_service.go | 279 ++++++- internal/service/group_manager.go | 3 +- internal/service/healthcheck_service.go | 782 +++++++++++++++----- internal/service/key_import_service.go | 538 +++++++++----- internal/service/key_validation_service.go | 507 +++++++++---- internal/service/log_service.go | 164 ++-- internal/service/stats_service.go | 126 +++- internal/service/token_manager.go | 2 +- internal/settings/settings.go | 2 +- internal/syncer/syncer.go | 125 +++- internal/task/task.go | 111 +-- internal/webhandlers/auth_handler.go | 95 ++- web/templates/base.html | 2 +- web/templates/settings.html | 2 +- 37 files changed, 4458 insertions(+), 1166 deletions(-) create mode 100644 internal/middleware/cors.go create mode 100644 internal/middleware/rate_limit.go create mode 100644 internal/middleware/request_id.go create mode 100644 internal/middleware/timeout.go diff --git a/cmd/server/main.go b/cmd/server/main.go index 6a40af4..2eda74c 100644 --- a/cmd/server/main.go +++ b/cmd/server/main.go @@ -3,10 +3,12 @@ package main import ( "gemini-balancer/internal/app" "gemini-balancer/internal/container" + "gemini-balancer/internal/logging" "log" ) func main() { + defer logging.Close() cont, err := container.BuildContainer() if err != nil { log.Fatalf("FATAL: Failed to build dependency container: %v", err) diff --git a/config.yaml b/config.yaml index e498190..f0c1824 100644 --- a/config.yaml +++ b/config.yaml @@ -14,6 +14,12 @@ server: log: level: "debug" +# 日志轮转配置 +max_size: 100 # MB +max_backups: 7 # 保留文件数 +max_age: 30 # 保留天数 +compress: true # 压缩旧日志 + redis: dsn: "redis://localhost:6379/0" diff --git a/go.mod b/go.mod index e685f39..38d19c5 100644 --- a/go.mod +++ b/go.mod @@ -17,6 +17,8 @@ require ( github.com/spf13/viper v1.20.1 go.uber.org/dig v1.19.0 golang.org/x/net v0.42.0 + golang.org/x/time v0.14.0 + gopkg.in/natefinch/lumberjack.v2 v2.2.1 gorm.io/datatypes v1.0.5 gorm.io/driver/mysql v1.6.0 gorm.io/driver/postgres v1.6.0 diff --git a/go.sum b/go.sum index cce6ad6..fc28ad2 100644 --- a/go.sum +++ b/go.sum @@ -311,6 +311,8 @@ golang.org/x/text v0.16.0/go.mod h1:GhwF1Be+LQoKShO3cGOHzqOgRrGaYc9AvblQOmPVHnI= golang.org/x/text v0.20.0/go.mod h1:D4IsuqiFMhST5bX19pQ9ikHC2GsaKyk/oF+pn3ducp4= golang.org/x/text v0.28.0 h1:rhazDwis8INMIwQ4tpjLDzUhx6RlXqZNPEM0huQojng= golang.org/x/text v0.28.0/go.mod h1:U8nCwOR8jO/marOQ0QbDiOngZVEBB7MAiitBuMjXiNU= +golang.org/x/time v0.14.0 h1:MRx4UaLrDotUKUdCIqzPC48t1Y9hANFKIRpNx+Te8PI= +golang.org/x/time 
v0.14.0/go.mod h1:eL/Oa2bBBK0TkX57Fyni+NgnyQQN4LitPmob2Hjnqw4= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc= @@ -325,6 +327,8 @@ gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8 gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= gopkg.in/errgo.v2 v2.1.0/go.mod h1:hNsd1EY+bozCKY1Ytp96fpM3vjJbqLJn88ws8XvfDNI= +gopkg.in/natefinch/lumberjack.v2 v2.2.1 h1:bBRl1b0OH9s/DuPhuXpNl+VtCaJXFZ5/uEFST95x9zc= +gopkg.in/natefinch/lumberjack.v2 v2.2.1/go.mod h1:YD8tP3GAjkrDg1eZH7EGmyESg/lsYskCTPBJVb9jqSc= gopkg.in/yaml.v2 v2.2.1/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.2.8/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= diff --git a/internal/config/config.go b/internal/config/config.go index e9c877f..0027d5b 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -29,7 +29,9 @@ type DatabaseConfig struct { // ServerConfig 存储HTTP服务器配置 type ServerConfig struct { - Port string `mapstructure:"port"` + Port string `mapstructure:"port"` + Host string `yaml:"host"` + CORSOrigins []string `yaml:"cors_origins"` } // LogConfig 存储日志配置 @@ -38,6 +40,12 @@ type LogConfig struct { Format string `mapstructure:"format" json:"format"` EnableFile bool `mapstructure:"enable_file" json:"enable_file"` FilePath string `mapstructure:"file_path" json:"file_path"` + + // 日志轮转配置(可选) + MaxSize int `yaml:"max_size"` // MB,默认 100 + MaxBackups int `yaml:"max_backups"` // 默认 7 + MaxAge int `yaml:"max_age"` // 天,默认 30 + Compress bool `yaml:"compress"` // 默认 true } type RedisConfig struct { diff --git a/internal/container/container.go b/internal/container/container.go index f9aeb1e..e8a12e3 100644 --- a/internal/container/container.go +++ b/internal/container/container.go @@ -87,7 +87,7 @@ func BuildContainer() (*dig.Container, error) { // 为GroupManager配置Syncer container.Provide(func(loader syncer.LoaderFunc[service.GroupManagerCacheData], store store.Store, logger *logrus.Logger) (*syncer.CacheSyncer[service.GroupManagerCacheData], error) { const groupUpdateChannel = "groups:cache_invalidation" - return syncer.NewCacheSyncer(loader, store, groupUpdateChannel) + return syncer.NewCacheSyncer(loader, store, groupUpdateChannel, logger) }) // =========== 阶段三: 适配器与处理器层 (Handlers & Adapters) =========== diff --git a/internal/domain/proxy/module.go b/internal/domain/proxy/module.go index 3cd794c..f1b59ae 100644 --- a/internal/domain/proxy/module.go +++ b/internal/domain/proxy/module.go @@ -20,7 +20,7 @@ type Module struct { func NewModule(gormDB *gorm.DB, store store.Store, settingsManager *settings.SettingsManager, taskReporter task.Reporter, logger *logrus.Logger) (*Module, error) { loader := newManagerLoader(gormDB) - cacheSyncer, err := syncer.NewCacheSyncer(loader, store, "proxies:cache_invalidation") + cacheSyncer, err := syncer.NewCacheSyncer(loader, store, "proxies:cache_invalidation", logger) if err != nil { return nil, err } diff --git a/internal/errors/upstream_errors.go b/internal/errors/upstream_errors.go index 751e35a..c85e456 100644 --- 
a/internal/errors/upstream_errors.go +++ b/internal/errors/upstream_errors.go @@ -71,6 +71,7 @@ var clientNetworkErrorSubstrings = []string{ "broken pipe", "use of closed network connection", "request canceled", + "invalid query parameters", // 参数解析错误,归类为客户端错误 } // IsPermanentUpstreamError checks if an upstream error indicates the key is permanently invalid. diff --git a/internal/handlers/log_handler.go b/internal/handlers/log_handler.go index 8e1edeb..4440dbf 100644 --- a/internal/handlers/log_handler.go +++ b/internal/handlers/log_handler.go @@ -5,7 +5,6 @@ import ( "gemini-balancer/internal/errors" "gemini-balancer/internal/response" "gemini-balancer/internal/service" - "strconv" "github.com/gin-gonic/gin" ) @@ -19,22 +18,26 @@ func NewLogHandler(logService *service.LogService) *LogHandler { } func (h *LogHandler) GetLogs(c *gin.Context) { - // 调用新的服务函数,接收日志列表和总数 - logs, total, err := h.logService.GetLogs(c) + queryParams := make(map[string]string) + for key, values := range c.Request.URL.Query() { + if len(values) > 0 { + queryParams[key] = values[0] + } + } + params, err := service.ParseLogQueryParams(queryParams) + if err != nil { + response.Error(c, errors.ErrBadRequest) + return + } + logs, total, err := h.logService.GetLogs(c.Request.Context(), params) if err != nil { response.Error(c, errors.ErrDatabase) return } - - // 解析分页参数用于响应体 - page, _ := strconv.Atoi(c.DefaultQuery("page", "1")) - pageSize, _ := strconv.Atoi(c.DefaultQuery("page_size", "20")) - - // 使用标准的分页响应结构 response.Success(c, gin.H{ "items": logs, "total": total, - "page": page, - "page_size": pageSize, + "page": params.Page, + "page_size": params.PageSize, }) } diff --git a/internal/logging/logging.go b/internal/logging/logging.go index ddf766f..db75182 100644 --- a/internal/logging/logging.go +++ b/internal/logging/logging.go @@ -9,20 +9,25 @@ import ( "path/filepath" "github.com/sirupsen/logrus" + "gopkg.in/natefinch/lumberjack.v2" ) +// 包级变量,用于存储日志轮转器 +var logRotator *lumberjack.Logger + +// NewLogger 返回标准的 *logrus.Logger(兼容 Fx 依赖注入) func NewLogger(cfg *config.Config) *logrus.Logger { logger := logrus.New() - // 1. 设置日志级别 + // 设置日志级别 level, err := logrus.ParseLevel(cfg.Log.Level) if err != nil { - logger.WithField("configured_level", cfg.Log.Level).Warn("Invalid log level specified, defaulting to 'info'.") + logger.WithField("configured_level", cfg.Log.Level).Warn("Invalid log level, defaulting to 'info'") level = logrus.InfoLevel } logger.SetLevel(level) - // 2. 设置日志格式 + // 设置日志格式 if cfg.Log.Format == "json" { logger.SetFormatter(&logrus.JSONFormatter{ TimestampFormat: "2006-01-02T15:04:05.000Z07:00", @@ -39,36 +44,57 @@ func NewLogger(cfg *config.Config) *logrus.Logger { }) } - // 3. 设置日志输出 + // 添加全局字段 + hostname, _ := os.Hostname() + logger = logger.WithFields(logrus.Fields{ + "service": "gemini-balancer", + "hostname": hostname, + }).Logger + + // 设置日志输出 if cfg.Log.EnableFile { if cfg.Log.FilePath == "" { - logger.Warn("Log file is enabled but no file path is specified. Logging to console only.") + logger.Warn("Log file enabled but no path specified. Logging to console only") logger.SetOutput(os.Stdout) return logger } logDir := filepath.Dir(cfg.Log.FilePath) - if err := os.MkdirAll(logDir, 0755); err != nil { - logger.WithError(err).Warn("Failed to create log directory. Logging to console only.") + if err := os.MkdirAll(logDir, 0750); err != nil { + logger.WithError(err).Warn("Failed to create log directory. 
Logging to console only") logger.SetOutput(os.Stdout) return logger } - logFile, err := os.OpenFile(cfg.Log.FilePath, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0666) - if err != nil { - logger.WithError(err).Warn("Failed to open log file. Logging to console only.") - logger.SetOutput(os.Stdout) - return logger + // 配置日志轮转(保存到包级变量) + logRotator = &lumberjack.Logger{ + Filename: cfg.Log.FilePath, + MaxSize: getOrDefault(cfg.Log.MaxSize, 100), + MaxBackups: getOrDefault(cfg.Log.MaxBackups, 7), + MaxAge: getOrDefault(cfg.Log.MaxAge, 30), + Compress: cfg.Log.Compress, } - // 同时输出到控制台和文件 - logger.SetOutput(io.MultiWriter(os.Stdout, logFile)) - logger.WithField("log_file_path", cfg.Log.FilePath).Info("Logging is now configured to output to both console and file.") + logger.SetOutput(io.MultiWriter(os.Stdout, logRotator)) + logger.WithField("log_file", cfg.Log.FilePath).Info("Logging to both console and file") } else { - // 仅输出到控制台 logger.SetOutput(os.Stdout) } - logger.Info("Root logger initialized.") + logger.Info("Logger initialized successfully") return logger } + +// Close 关闭日志轮转器(在 main.go 中调用) +func Close() { + if logRotator != nil { + logRotator.Close() + } +} + +func getOrDefault(value, defaultValue int) int { + if value <= 0 { + return defaultValue + } + return value +} diff --git a/internal/middleware/auth.go b/internal/middleware/auth.go index bef9658..91ea18b 100644 --- a/internal/middleware/auth.go +++ b/internal/middleware/auth.go @@ -1,4 +1,5 @@ // Filename: internal/middleware/auth.go + package middleware import ( @@ -7,76 +8,115 @@ import ( "strings" "github.com/gin-gonic/gin" + "github.com/sirupsen/logrus" ) -// === API Admin 认证管道 (/admin/* API路由) === +type ErrorResponse struct { + Error string `json:"error"` + Code string `json:"code,omitempty"` +} -func APIAdminAuthMiddleware(securityService *service.SecurityService) gin.HandlerFunc { +// APIAdminAuthMiddleware 管理后台 API 认证 +func APIAdminAuthMiddleware( + securityService *service.SecurityService, + logger *logrus.Logger, +) gin.HandlerFunc { return func(c *gin.Context) { tokenValue := extractBearerToken(c) if tokenValue == "" { - c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "Authorization token is missing"}) + c.AbortWithStatusJSON(http.StatusUnauthorized, ErrorResponse{ + Error: "Authentication required", + Code: "AUTH_MISSING", + }) return } + + // ✅ 只传 token 参数(移除 context) authToken, err := securityService.AuthenticateToken(tokenValue) - if err != nil || !authToken.IsAdmin { - c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "Invalid or non-admin token"}) + if err != nil { + logger.WithError(err).Warn("Authentication failed") + c.AbortWithStatusJSON(http.StatusUnauthorized, ErrorResponse{ + Error: "Invalid authentication", + Code: "AUTH_INVALID", + }) return } + + if !authToken.IsAdmin { + c.AbortWithStatusJSON(http.StatusForbidden, ErrorResponse{ + Error: "Admin access required", + Code: "AUTH_FORBIDDEN", + }) + return + } + c.Set("adminUser", authToken) c.Next() } } -// === /v1 Proxy 认证 === - -func ProxyAuthMiddleware(securityService *service.SecurityService) gin.HandlerFunc { +// ProxyAuthMiddleware 代理请求认证 +func ProxyAuthMiddleware( + securityService *service.SecurityService, + logger *logrus.Logger, +) gin.HandlerFunc { return func(c *gin.Context) { tokenValue := extractProxyToken(c) if tokenValue == "" { - c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "API key is missing from request"}) + c.AbortWithStatusJSON(http.StatusUnauthorized, ErrorResponse{ + Error: "API key required", + 
Code: "KEY_MISSING", + }) return } + + // ✅ 只传 token 参数(移除 context) authToken, err := securityService.AuthenticateToken(tokenValue) if err != nil { - // 通用信息,避免泄露过多信息 - c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "Invalid or inactive token provided"}) + c.AbortWithStatusJSON(http.StatusUnauthorized, ErrorResponse{ + Error: "Invalid API key", + Code: "KEY_INVALID", + }) return } + c.Set("authToken", authToken) c.Next() } } +// extractProxyToken 按优先级提取 token func extractProxyToken(c *gin.Context) string { - if key := c.Query("key"); key != "" { - return key - } - authHeader := c.GetHeader("Authorization") - if authHeader != "" { - if strings.HasPrefix(authHeader, "Bearer ") { - return strings.TrimPrefix(authHeader, "Bearer ") - } + // 优先级 1: Authorization Header + if token := extractBearerToken(c); token != "" { + return token } + + // 优先级 2: X-Api-Key if key := c.GetHeader("X-Api-Key"); key != "" { return key } + + // 优先级 3: X-Goog-Api-Key if key := c.GetHeader("X-Goog-Api-Key"); key != "" { return key } - return "" + + // 优先级 4: Query 参数(不推荐) + return c.Query("key") } -// === 辅助函数 === - +// extractBearerToken 提取 Bearer Token func extractBearerToken(c *gin.Context) string { authHeader := c.GetHeader("Authorization") if authHeader == "" { return "" } - parts := strings.Split(authHeader, " ") - if len(parts) == 2 && parts[0] == "Bearer" { - return parts[1] + + const prefix = "Bearer " + if !strings.HasPrefix(authHeader, prefix) { + return "" } - return "" + + return strings.TrimSpace(authHeader[len(prefix):]) } diff --git a/internal/middleware/cors.go b/internal/middleware/cors.go new file mode 100644 index 0000000..c94143f --- /dev/null +++ b/internal/middleware/cors.go @@ -0,0 +1,90 @@ +// Filename: internal/middleware/cors.go + +package middleware + +import ( + "net/http" + "strings" + + "github.com/gin-gonic/gin" +) + +type CORSConfig struct { + AllowedOrigins []string + AllowedMethods []string + AllowedHeaders []string + ExposedHeaders []string + AllowCredentials bool + MaxAge int +} + +func CORSMiddleware(config CORSConfig) gin.HandlerFunc { + return func(c *gin.Context) { + origin := c.Request.Header.Get("Origin") + + // 检查 origin 是否允许 + if origin != "" && isOriginAllowed(origin, config.AllowedOrigins) { + c.Writer.Header().Set("Access-Control-Allow-Origin", origin) + } + + if config.AllowCredentials { + c.Writer.Header().Set("Access-Control-Allow-Credentials", "true") + } + + if len(config.ExposedHeaders) > 0 { + c.Writer.Header().Set("Access-Control-Expose-Headers", + strings.Join(config.ExposedHeaders, ", ")) + } + + // 处理预检请求 + if c.Request.Method == http.MethodOptions { + if len(config.AllowedMethods) > 0 { + c.Writer.Header().Set("Access-Control-Allow-Methods", + strings.Join(config.AllowedMethods, ", ")) + } + + if len(config.AllowedHeaders) > 0 { + c.Writer.Header().Set("Access-Control-Allow-Headers", + strings.Join(config.AllowedHeaders, ", ")) + } + + if config.MaxAge > 0 { + c.Writer.Header().Set("Access-Control-Max-Age", + string(rune(config.MaxAge))) + } + + c.AbortWithStatus(http.StatusNoContent) + return + } + + c.Next() + } +} + +func isOriginAllowed(origin string, allowedOrigins []string) bool { + for _, allowed := range allowedOrigins { + if allowed == "*" || allowed == origin { + return true + } + // 支持通配符子域名 + if strings.HasPrefix(allowed, "*.") { + domain := strings.TrimPrefix(allowed, "*.") + if strings.HasSuffix(origin, domain) { + return true + } + } + } + return false +} + +// 使用示例 +func SetupCORS(r *gin.Engine) { + 
+	r.Use(CORSMiddleware(CORSConfig{
+		AllowedOrigins:   []string{"https://yourdomain.com", "*.yourdomain.com"},
+		AllowedMethods:   []string{"GET", "POST", "PUT", "DELETE", "OPTIONS"},
+		AllowedHeaders:   []string{"Authorization", "Content-Type", "X-Api-Key"},
+		ExposedHeaders:   []string{"X-Request-Id"},
+		AllowCredentials: true,
+		MaxAge:           3600,
+	}))
+}
diff --git a/internal/middleware/log.go b/internal/middleware/log.go
index 4a35e9e..cce1a56 100644
--- a/internal/middleware/log.go
+++ b/internal/middleware/log.go
@@ -1,84 +1,213 @@
-// Filename: internal/middleware/log_redaction.go
+// Filename: internal/middleware/log.go
+
 package middleware

 import (
 	"bytes"
 	"io"
 	"regexp"
+	"strings"
 	"time"

 	"github.com/gin-gonic/gin"
 	"github.com/sirupsen/logrus"
 )

-const RedactedBodyKey = "redactedBody"
-const RedactedAuthHeaderKey = "redactedAuthHeader"
-const RedactedValue = `"[REDACTED]"`
+const (
+	RedactedBodyKey       = "redactedBody"
+	RedactedAuthHeaderKey = "redactedAuthHeader"
+	RedactedValue         = `"[REDACTED]"`
+)
+
+// Pre-compiled regular expressions (package level, compiled once for performance)
+var (
+	// Redacts sensitive JSON fields
+	jsonSensitiveKeys = regexp.MustCompile(`("(?i:api_key|apikey|token|password|secret|authorization|key|keys|auth)"\s*:\s*)"[^"]*"`)
+
+	// Redacts Bearer tokens
+	bearerTokenPattern = regexp.MustCompile(`^(Bearer\s+)\S+$`)
+
+	// Redacts key-like query parameters in URLs
+	queryKeyPattern = regexp.MustCompile(`([?&](?i:key|token|apikey)=)[^&\s]+`)
+)
+
+// RedactionMiddleware redacts sensitive request data before it is logged
 func RedactionMiddleware() gin.HandlerFunc {
-	// Pre-compile regex for efficiency
-	jsonKeyPattern := regexp.MustCompile(`("api_key"|"keys")\s*:\s*"[^"]*"`)
-	bearerTokenPattern := regexp.MustCompile(`^(Bearer\s+)\S+$`)
 	return func(c *gin.Context) {
-		// --- 1. Redact Request Body ---
-		if c.Request.Method == "POST" || c.Request.Method == "PUT" || c.Request.Method == "DELETE" {
-			if bodyBytes, err := io.ReadAll(c.Request.Body); err == nil {
-
-				c.Request.Body = io.NopCloser(bytes.NewBuffer(bodyBytes))
-
-				bodyString := string(bodyBytes)
-
-				redactedBody := jsonKeyPattern.ReplaceAllString(bodyString, `$1:`+RedactedValue)
-
-				c.Set(RedactedBodyKey, redactedBody)
-			}
-		}
-		// --- 2. Redact Authorization Header ---
-		authHeader := c.GetHeader("Authorization")
-		if authHeader != "" {
-			if bearerTokenPattern.MatchString(authHeader) {
-				redactedHeader := bearerTokenPattern.ReplaceAllString(authHeader, `${1}[REDACTED]`)
-				c.Set(RedactedAuthHeaderKey, redactedHeader)
-			}
+		// 1. Redact the request body
+		if shouldRedactBody(c) {
+			redactRequestBody(c)
 		}
+
+		// 2. Redact authentication headers
+		redactAuthHeader(c)
+
+		// 3. Redact URL query parameters
+		redactQueryParams(c)
+
 		c.Next()
 	}
 }

-// LogrusLogger is a Gin middleware that logs requests using a Logrus logger.
-// It consumes redacted data prepared by the RedactionMiddleware.
+// shouldRedactBody reports whether the request body should be redacted
+func shouldRedactBody(c *gin.Context) bool {
+	method := c.Request.Method
+	contentType := c.GetHeader("Content-Type")
+
+	// Only handle JSON POST/PUT/PATCH/DELETE requests
+	return (method == "POST" || method == "PUT" || method == "PATCH" || method == "DELETE") &&
+		strings.Contains(contentType, "application/json")
+}
+
+// redactRequestBody redacts the request body
+func redactRequestBody(c *gin.Context) {
+	bodyBytes, err := io.ReadAll(c.Request.Body)
+	if err != nil {
+		return
+	}
+
+	// Restore the body for downstream handlers
+	c.Request.Body = io.NopCloser(bytes.NewBuffer(bodyBytes))
+
+	// Redact sensitive fields
+	bodyString := string(bodyBytes)
+	redactedBody := jsonSensitiveKeys.ReplaceAllString(bodyString, `$1`+RedactedValue)
+
+	c.Set(RedactedBodyKey, redactedBody)
+}
+
+// redactAuthHeader redacts authentication headers
+func redactAuthHeader(c *gin.Context) {
+	authHeader := c.GetHeader("Authorization")
+	if authHeader == "" {
+		return
+	}
+
+	if bearerTokenPattern.MatchString(authHeader) {
+		redacted := bearerTokenPattern.ReplaceAllString(authHeader, `${1}[REDACTED]`)
+		c.Set(RedactedAuthHeaderKey, redacted)
+	} else {
+		// Redact non-Bearer tokens entirely
+		c.Set(RedactedAuthHeaderKey, "[REDACTED]")
+	}
+
+	// Also handle other sensitive headers
+	sensitiveHeaders := []string{"X-Api-Key", "X-Goog-Api-Key", "Api-Key"}
+	for _, header := range sensitiveHeaders {
+		if value := c.GetHeader(header); value != "" {
+			c.Set("redacted_"+header, "[REDACTED]")
+		}
+	}
+}
+
+// redactQueryParams redacts sensitive URL query parameters
+func redactQueryParams(c *gin.Context) {
+	rawQuery := c.Request.URL.RawQuery
+	if rawQuery == "" {
+		return
+	}
+
+	redacted := queryKeyPattern.ReplaceAllString(rawQuery, `${1}[REDACTED]`)
+	if redacted != rawQuery {
+		c.Set("redactedQuery", redacted)
+	}
+}
+
+// LogrusLogger is a Gin request-logging middleware backed by Logrus
 func LogrusLogger(logger *logrus.Logger) gin.HandlerFunc {
 	return func(c *gin.Context) {
 		start := time.Now()
 		path := c.Request.URL.Path
+		method := c.Request.Method

-		// Process request
+		// Handle the request
 		c.Next()

-		// After request, gather data and log
+		// Compute latency
 		latency := time.Since(start)
 		statusCode := c.Writer.Status()
+		clientIP := c.ClientIP()

-		entry := logger.WithFields(logrus.Fields{
-			"status_code": statusCode,
-			"latency_ms":  latency.Milliseconds(),
-			"client_ip":   c.ClientIP(),
-			"method":      c.Request.Method,
-			"path":        path,
-		})
+		// Build the log fields
+		fields := logrus.Fields{
+			"status":     statusCode,
+			"latency_ms": latency.Milliseconds(),
+			"ip":         clientIP,
+			"method":     method,
+			"path":       path,
+		}

+		// Add the request ID, if present
+		if requestID := getRequestID(c); requestID != "" {
+			fields["request_id"] = requestID
+		}
+
+		// Add the redacted data
 		if redactedBody, exists := c.Get(RedactedBodyKey); exists {
-			entry = entry.WithField("body", redactedBody)
+			fields["body"] = redactedBody
 		}
 		if redactedAuth, exists := c.Get(RedactedAuthHeaderKey); exists {
-			entry = entry.WithField("authorization", redactedAuth)
+			fields["authorization"] = redactedAuth
 		}
+		if redactedQuery, exists := c.Get("redactedQuery"); exists {
+			fields["query"] = redactedQuery
+		}
+
+		// Add user information, if authenticated
+		if user := getAuthenticatedUser(c); user != "" {
+			fields["user"] = user
+		}
+
+		// Choose the log level based on the status code
+		entry := logger.WithFields(fields)
+
 		if len(c.Errors) > 0 {
-			entry.Error(c.Errors.String())
+			// Attach errors on the entry itself; mutating fields here would be
+			// lost, because WithFields has already copied the map
+			entry.WithField("errors", c.Errors.String()).Error("Request failed")
 		} else {
-			entry.Info("request handled")
+			switch {
+			case statusCode >= 500:
+				entry.Error("Server error")
+			case statusCode >= 400:
+				entry.Warn("Client error")
+			case statusCode >= 300:
+				entry.Info("Redirect")
+			default:
+				// Only record successful requests in debug mode
+				if logger.Level >= logrus.DebugLevel {
entry.Debug("Request completed") + } + } } } } + +// getRequestID 获取请求 ID +func getRequestID(c *gin.Context) string { + if id, exists := c.Get("request_id"); exists { + if requestID, ok := id.(string); ok { + return requestID + } + } + return "" +} + +// getAuthenticatedUser 获取已认证用户标识 +func getAuthenticatedUser(c *gin.Context) string { + // 尝试从不同来源获取用户信息 + if user, exists := c.Get("adminUser"); exists { + if authToken, ok := user.(interface{ GetID() string }); ok { + return authToken.GetID() + } + } + + if user, exists := c.Get("authToken"); exists { + if authToken, ok := user.(interface{ GetID() string }); ok { + return authToken.GetID() + } + } + + return "" +} diff --git a/internal/middleware/rate_limit.go b/internal/middleware/rate_limit.go new file mode 100644 index 0000000..3c5bf91 --- /dev/null +++ b/internal/middleware/rate_limit.go @@ -0,0 +1,86 @@ +// Filename: internal/middleware/rate_limit.go + +package middleware + +import ( + "net/http" + "sync" + "time" + + "github.com/gin-gonic/gin" + "golang.org/x/time/rate" +) + +type RateLimiter struct { + limiters map[string]*rate.Limiter + mu sync.RWMutex + r rate.Limit // 每秒请求数 + b int // 突发容量 +} + +func NewRateLimiter(r rate.Limit, b int) *RateLimiter { + return &RateLimiter{ + limiters: make(map[string]*rate.Limiter), + r: r, + b: b, + } +} + +func (rl *RateLimiter) getLimiter(key string) *rate.Limiter { + rl.mu.Lock() + defer rl.mu.Unlock() + + limiter, exists := rl.limiters[key] + if !exists { + limiter = rate.NewLimiter(rl.r, rl.b) + rl.limiters[key] = limiter + } + + return limiter +} + +// 定期清理不活跃的限制器 +func (rl *RateLimiter) cleanup() { + ticker := time.NewTicker(10 * time.Minute) + defer ticker.Stop() + + for range ticker.C { + rl.mu.Lock() + // 简单策略:定期清空(生产环境应该用 LRU) + rl.limiters = make(map[string]*rate.Limiter) + rl.mu.Unlock() + } +} + +func RateLimitMiddleware(limiter *RateLimiter) gin.HandlerFunc { + return func(c *gin.Context) { + // 按 IP 限流 + key := c.ClientIP() + + // 如果有认证 token,按 token 限流(更精确) + if authToken, exists := c.Get("authToken"); exists { + if token, ok := authToken.(interface{ GetID() string }); ok { + key = "token:" + token.GetID() + } + } + + l := limiter.getLimiter(key) + if !l.Allow() { + c.AbortWithStatusJSON(http.StatusTooManyRequests, gin.H{ + "error": "Rate limit exceeded", + "code": "RATE_LIMIT", + }) + return + } + + c.Next() + } +} + +// 使用示例 +func SetupRateLimit(r *gin.Engine) { + limiter := NewRateLimiter(10, 20) // 每秒 10 个请求,突发 20 + go limiter.cleanup() + + r.Use(RateLimitMiddleware(limiter)) +} diff --git a/internal/middleware/request_id.go b/internal/middleware/request_id.go new file mode 100644 index 0000000..ca5fbb2 --- /dev/null +++ b/internal/middleware/request_id.go @@ -0,0 +1,39 @@ +// Filename: internal/middleware/request_id.go + +package middleware + +import ( + "github.com/gin-gonic/gin" + "github.com/google/uuid" +) + +// RequestIDMiddleware 请求 ID 追踪中间件 +func RequestIDMiddleware() gin.HandlerFunc { + return func(c *gin.Context) { + // 1. 尝试从 Header 获取现有的 Request ID + requestID := c.GetHeader("X-Request-Id") + + // 2. 如果没有,生成新的 + if requestID == "" { + requestID = uuid.New().String() + } + + // 3. 设置到 Context + c.Set("request_id", requestID) + + // 4. 
返回给客户端(用于追踪) + c.Writer.Header().Set("X-Request-Id", requestID) + + c.Next() + } +} + +// GetRequestID 获取当前请求的 Request ID +func GetRequestID(c *gin.Context) string { + if id, exists := c.Get("request_id"); exists { + if requestID, ok := id.(string); ok { + return requestID + } + } + return "" +} diff --git a/internal/middleware/security.go b/internal/middleware/security.go index 673ce97..3850426 100644 --- a/internal/middleware/security.go +++ b/internal/middleware/security.go @@ -1,4 +1,4 @@ -// Filename: internal/middleware/security.go +// Filename: internal/middleware/security.go (简化版) package middleware @@ -6,26 +6,136 @@ import ( "gemini-balancer/internal/service" "gemini-balancer/internal/settings" "net/http" + "sync" + "time" "github.com/gin-gonic/gin" + "github.com/sirupsen/logrus" ) -func IPBanMiddleware(securityService *service.SecurityService, settingsManager *settings.SettingsManager) gin.HandlerFunc { +// 简单的缓存项 +type cacheItem struct { + value bool + expiration int64 +} + +// 简单的 TTL 缓存实现 +type IPBanCache struct { + items map[string]*cacheItem + mu sync.RWMutex + ttl time.Duration +} + +func NewIPBanCache() *IPBanCache { + cache := &IPBanCache{ + items: make(map[string]*cacheItem), + ttl: 1 * time.Minute, + } + + // 启动清理协程 + go cache.cleanup() + + return cache +} + +func (c *IPBanCache) Get(key string) (bool, bool) { + c.mu.RLock() + defer c.mu.RUnlock() + + item, exists := c.items[key] + if !exists { + return false, false + } + + // 检查是否过期 + if time.Now().UnixNano() > item.expiration { + return false, false + } + + return item.value, true +} + +func (c *IPBanCache) Set(key string, value bool) { + c.mu.Lock() + defer c.mu.Unlock() + + c.items[key] = &cacheItem{ + value: value, + expiration: time.Now().Add(c.ttl).UnixNano(), + } +} + +func (c *IPBanCache) Delete(key string) { + c.mu.Lock() + defer c.mu.Unlock() + delete(c.items, key) +} + +func (c *IPBanCache) cleanup() { + ticker := time.NewTicker(5 * time.Minute) + defer ticker.Stop() + + for range ticker.C { + c.mu.Lock() + now := time.Now().UnixNano() + for key, item := range c.items { + if now > item.expiration { + delete(c.items, key) + } + } + c.mu.Unlock() + } +} + +func IPBanMiddleware( + securityService *service.SecurityService, + settingsManager *settings.SettingsManager, + banCache *IPBanCache, + logger *logrus.Logger, +) gin.HandlerFunc { return func(c *gin.Context) { if !settingsManager.IsIPBanEnabled() { c.Next() return } + ip := c.ClientIP() - isBanned, err := securityService.IsIPBanned(c.Request.Context(), ip) - if err != nil { + + // 查缓存 + if isBanned, exists := banCache.Get(ip); exists { + if isBanned { + logger.WithField("ip", ip).Debug("IP blocked (cached)") + c.AbortWithStatusJSON(http.StatusForbidden, gin.H{ + "error": "Access denied", + }) + return + } c.Next() return } - if isBanned { - c.AbortWithStatusJSON(http.StatusForbidden, gin.H{"error": "您的IP已被暂时封禁,请稍后再试"}) + + // 查数据库 + ctx := c.Request.Context() + isBanned, err := securityService.IsIPBanned(ctx, ip) + if err != nil { + logger.WithError(err).WithField("ip", ip).Error("Failed to check IP ban status") + + // 降级策略:允许访问 + c.Next() return } + + // 更新缓存 + banCache.Set(ip, isBanned) + + if isBanned { + logger.WithField("ip", ip).Info("IP blocked") + c.AbortWithStatusJSON(http.StatusForbidden, gin.H{ + "error": "Access denied", + }) + return + } + c.Next() } } diff --git a/internal/middleware/timeout.go b/internal/middleware/timeout.go new file mode 100644 index 0000000..b285702 --- /dev/null +++ b/internal/middleware/timeout.go @@ -0,0 +1,52 @@ +// 
Filename: internal/middleware/timeout.go + +package middleware + +import ( + "context" + "net/http" + "time" + + "github.com/gin-gonic/gin" +) + +func TimeoutMiddleware(timeout time.Duration) gin.HandlerFunc { + return func(c *gin.Context) { + // 创建带超时的 context + ctx, cancel := context.WithTimeout(c.Request.Context(), timeout) + defer cancel() + + // 替换 request context + c.Request = c.Request.WithContext(ctx) + + // 使用 channel 等待请求完成 + finished := make(chan struct{}) + + go func() { + c.Next() + close(finished) + }() + + select { + case <-finished: + // 请求正常完成 + return + case <-ctx.Done(): + // 超时 + c.AbortWithStatusJSON(http.StatusGatewayTimeout, gin.H{ + "error": "Request timeout", + "code": "TIMEOUT", + }) + } + } +} + +// 使用示例 +func SetupTimeout(r *gin.Engine) { + // 对 API 路由设置 30 秒超时 + api := r.Group("/api") + api.Use(TimeoutMiddleware(30 * time.Second)) + { + // ... API routes + } +} diff --git a/internal/middleware/web.go b/internal/middleware/web.go index 8d32dba..7831e01 100644 --- a/internal/middleware/web.go +++ b/internal/middleware/web.go @@ -1,23 +1,151 @@ // Filename: internal/middleware/web.go + package middleware import ( + "crypto/sha256" + "encoding/hex" "gemini-balancer/internal/service" - "log" "net/http" + "os" + "strings" + "sync" + "time" "github.com/gin-gonic/gin" + "github.com/sirupsen/logrus" ) const ( AdminSessionCookie = "gemini_admin_session" + SessionMaxAge = 3600 * 24 * 7 // 7天 + CacheTTL = 5 * time.Minute + CleanupInterval = 10 * time.Minute // 降低清理频率 + SessionRefreshTime = 30 * time.Minute ) +// ==================== 缓存层 ==================== + +type authCacheEntry struct { + Token interface{} + ExpiresAt time.Time +} + +type authCache struct { + mu sync.RWMutex + cache map[string]*authCacheEntry + ttl time.Duration +} + +var webAuthCache = newAuthCache(CacheTTL) + +func newAuthCache(ttl time.Duration) *authCache { + c := &authCache{ + cache: make(map[string]*authCacheEntry), + ttl: ttl, + } + go c.cleanupLoop() + return c +} + +func (c *authCache) get(key string) (interface{}, bool) { + c.mu.RLock() + defer c.mu.RUnlock() + + entry, exists := c.cache[key] + if !exists || time.Now().After(entry.ExpiresAt) { + return nil, false + } + return entry.Token, true +} + +func (c *authCache) set(key string, token interface{}) { + c.mu.Lock() + defer c.mu.Unlock() + + c.cache[key] = &authCacheEntry{ + Token: token, + ExpiresAt: time.Now().Add(c.ttl), + } +} + +func (c *authCache) delete(key string) { + c.mu.Lock() + defer c.mu.Unlock() + delete(c.cache, key) +} + +func (c *authCache) cleanupLoop() { + ticker := time.NewTicker(CleanupInterval) + defer ticker.Stop() + + for range ticker.C { + c.cleanup() + } +} + +func (c *authCache) cleanup() { + c.mu.Lock() + defer c.mu.Unlock() + + now := time.Now() + count := 0 + for key, entry := range c.cache { + if now.After(entry.ExpiresAt) { + delete(c.cache, key) + count++ + } + } + + if count > 0 { + logrus.Debugf("[AuthCache] Cleaned up %d expired entries", count) + } +} + +// ==================== 会话刷新缓存 ==================== + +var sessionRefreshCache = struct { + sync.RWMutex + timestamps map[string]time.Time +}{ + timestamps: make(map[string]time.Time), +} + +// 定期清理刷新时间戳 +func init() { + go func() { + ticker := time.NewTicker(1 * time.Hour) + defer ticker.Stop() + + for range ticker.C { + sessionRefreshCache.Lock() + now := time.Now() + for key, ts := range sessionRefreshCache.timestamps { + if now.Sub(ts) > 2*time.Hour { + delete(sessionRefreshCache.timestamps, key) + } + } + sessionRefreshCache.Unlock() + } + }() +} + 
+// ==================== Cookie 操作 ==================== + func SetAdminSessionCookie(c *gin.Context, adminToken string) { - c.SetCookie(AdminSessionCookie, adminToken, 3600*24*7, "/", "", false, true) + secure := c.Request.TLS != nil || c.GetHeader("X-Forwarded-Proto") == "https" + c.SetSameSite(http.SameSiteStrictMode) + c.SetCookie(AdminSessionCookie, adminToken, SessionMaxAge, "/", "", secure, true) +} + +func SetAdminSessionCookieWithAge(c *gin.Context, adminToken string, maxAge int) { + secure := c.Request.TLS != nil || c.GetHeader("X-Forwarded-Proto") == "https" + c.SetSameSite(http.SameSiteStrictMode) + c.SetCookie(AdminSessionCookie, adminToken, maxAge, "/", "", secure, true) } func ClearAdminSessionCookie(c *gin.Context) { + c.SetSameSite(http.SameSiteStrictMode) c.SetCookie(AdminSessionCookie, "", -1, "/", "", false, true) } @@ -29,26 +157,258 @@ func ExtractTokenFromCookie(c *gin.Context) string { return cookie } +// ==================== 认证中间件 ==================== + func WebAdminAuthMiddleware(authService *service.SecurityService) gin.HandlerFunc { + logger := logrus.New() + logger.SetLevel(getLogLevel()) + return func(c *gin.Context) { cookie := ExtractTokenFromCookie(c) - log.Printf("[WebAuth_Guard] Intercepting request for: %s", c.Request.URL.Path) - log.Printf("[WebAuth_Guard] Found session cookie value: '%s'", cookie) - authToken, err := authService.AuthenticateToken(cookie) - if err != nil { - log.Printf("[WebAuth_Guard] FATAL: AuthenticateToken FAILED. Error: %v. Redirecting to /login.", err) - } else if !authToken.IsAdmin { - log.Printf("[WebAuth_Guard] FATAL: Token validated, but IsAdmin is FALSE. Redirecting to /login.") - } else { - log.Printf("[WebAuth_Guard] SUCCESS: Token validated and IsAdmin is TRUE. Allowing access.") - } - if err != nil || !authToken.IsAdmin { + if cookie == "" { + logger.Debug("[WebAuth] No session cookie found") ClearAdminSessionCookie(c) - c.Redirect(http.StatusFound, "/login") - c.Abort() + redirectToLogin(c) return } + + cacheKey := hashToken(cookie) + + if cachedToken, found := webAuthCache.get(cacheKey); found { + logger.Debug("[WebAuth] Using cached token") + c.Set("adminUser", cachedToken) + refreshSessionIfNeeded(c, cookie) + c.Next() + return + } + + logger.Debug("[WebAuth] Cache miss, authenticating...") + authToken, err := authService.AuthenticateToken(cookie) + + if err != nil { + logger.WithError(err).Warn("[WebAuth] Authentication failed") + ClearAdminSessionCookie(c) + webAuthCache.delete(cacheKey) + redirectToLogin(c) + return + } + + if !authToken.IsAdmin { + logger.Warn("[WebAuth] User is not admin") + ClearAdminSessionCookie(c) + webAuthCache.delete(cacheKey) + redirectToLogin(c) + return + } + + webAuthCache.set(cacheKey, authToken) + logger.Debug("[WebAuth] Authentication success, token cached") + c.Set("adminUser", authToken) + refreshSessionIfNeeded(c, cookie) c.Next() } } + +func WebAdminAuthMiddlewareWithLogger(authService *service.SecurityService, logger *logrus.Logger) gin.HandlerFunc { + return func(c *gin.Context) { + cookie := ExtractTokenFromCookie(c) + + if cookie == "" { + logger.Debug("No session cookie found") + ClearAdminSessionCookie(c) + redirectToLogin(c) + return + } + + cacheKey := hashToken(cookie) + if cachedToken, found := webAuthCache.get(cacheKey); found { + c.Set("adminUser", cachedToken) + refreshSessionIfNeeded(c, cookie) + c.Next() + return + } + + authToken, err := authService.AuthenticateToken(cookie) + + if err != nil { + logger.WithError(err).Warn("Token authentication failed") + 
ClearAdminSessionCookie(c) + webAuthCache.delete(cacheKey) + redirectToLogin(c) + return + } + + if !authToken.IsAdmin { + logger.Warn("Token valid but user is not admin") + ClearAdminSessionCookie(c) + webAuthCache.delete(cacheKey) + redirectToLogin(c) + return + } + + webAuthCache.set(cacheKey, authToken) + c.Set("adminUser", authToken) + refreshSessionIfNeeded(c, cookie) + c.Next() + } +} + +// ==================== 辅助函数 ==================== + +func hashToken(token string) string { + h := sha256.Sum256([]byte(token)) + return hex.EncodeToString(h[:]) +} + +func redirectToLogin(c *gin.Context) { + if isAjaxRequest(c) { + c.JSON(http.StatusUnauthorized, gin.H{ + "error": "Session expired", + "code": "AUTH_REQUIRED", + }) + c.Abort() + return + } + + originalPath := c.Request.URL.Path + if originalPath != "/" && originalPath != "/login" { + c.Redirect(http.StatusFound, "/login?redirect="+originalPath) + } else { + c.Redirect(http.StatusFound, "/login") + } + c.Abort() +} + +func isAjaxRequest(c *gin.Context) bool { + // 检查 Content-Type + contentType := c.GetHeader("Content-Type") + if strings.Contains(contentType, "application/json") { + return true + } + + // 检查 Accept(优先检查 JSON) + accept := c.GetHeader("Accept") + if strings.Contains(accept, "application/json") && + !strings.Contains(accept, "text/html") { + return true + } + + // 兼容旧版 XMLHttpRequest + return c.GetHeader("X-Requested-With") == "XMLHttpRequest" +} + +func refreshSessionIfNeeded(c *gin.Context, token string) { + tokenHash := hashToken(token) + + sessionRefreshCache.RLock() + lastRefresh, exists := sessionRefreshCache.timestamps[tokenHash] + sessionRefreshCache.RUnlock() + + if !exists || time.Since(lastRefresh) > SessionRefreshTime { + SetAdminSessionCookie(c, token) + + sessionRefreshCache.Lock() + sessionRefreshCache.timestamps[tokenHash] = time.Now() + sessionRefreshCache.Unlock() + } +} + +func getLogLevel() logrus.Level { + level := os.Getenv("LOG_LEVEL") + switch strings.ToLower(level) { + case "debug": + return logrus.DebugLevel + case "warn": + return logrus.WarnLevel + case "error": + return logrus.ErrorLevel + default: + return logrus.InfoLevel + } +} + +// ==================== 工具函数 ==================== + +func GetAdminUserFromContext(c *gin.Context) (interface{}, bool) { + return c.Get("adminUser") +} + +func InvalidateTokenCache(token string) { + tokenHash := hashToken(token) + webAuthCache.delete(tokenHash) + + // 同时清理刷新时间戳 + sessionRefreshCache.Lock() + delete(sessionRefreshCache.timestamps, tokenHash) + sessionRefreshCache.Unlock() +} + +func ClearAllAuthCache() { + webAuthCache.mu.Lock() + webAuthCache.cache = make(map[string]*authCacheEntry) + webAuthCache.mu.Unlock() + + sessionRefreshCache.Lock() + sessionRefreshCache.timestamps = make(map[string]time.Time) + sessionRefreshCache.Unlock() +} + +// ==================== 调试工具 ==================== + +type SessionInfo struct { + HasCookie bool `json:"has_cookie"` + IsValid bool `json:"is_valid"` + IsAdmin bool `json:"is_admin"` + IsCached bool `json:"is_cached"` + LastActivity string `json:"last_activity"` +} + +func GetSessionInfo(c *gin.Context, authService *service.SecurityService) SessionInfo { + info := SessionInfo{ + HasCookie: false, + IsValid: false, + IsAdmin: false, + IsCached: false, + LastActivity: time.Now().Format(time.RFC3339), + } + + cookie := ExtractTokenFromCookie(c) + if cookie == "" { + return info + } + + info.HasCookie = true + + cacheKey := hashToken(cookie) + if _, found := webAuthCache.get(cacheKey); found { + info.IsCached = true + } 
+ + authToken, err := authService.AuthenticateToken(cookie) + if err != nil { + return info + } + + info.IsValid = true + info.IsAdmin = authToken.IsAdmin + + return info +} + +func GetCacheStats() map[string]interface{} { + webAuthCache.mu.RLock() + cacheSize := len(webAuthCache.cache) + webAuthCache.mu.RUnlock() + + sessionRefreshCache.RLock() + refreshSize := len(sessionRefreshCache.timestamps) + sessionRefreshCache.RUnlock() + return map[string]interface{}{ + "auth_cache_entries": cacheSize, + "refresh_cache_entries": refreshSize, + "ttl_seconds": int(webAuthCache.ttl.Seconds()), + "cleanup_interval": int(CleanupInterval.Seconds()), + "session_refresh_time": int(SessionRefreshTime.Seconds()), + } +} diff --git a/internal/models/runtime.go b/internal/models/runtime.go index f09beeb..ac6bcdc 100644 --- a/internal/models/runtime.go +++ b/internal/models/runtime.go @@ -27,6 +27,8 @@ type SystemSettings struct { BaseKeyCheckEndpoint string `json:"base_key_check_endpoint" default:"https://generativelanguage.googleapis.com/v1beta/models" name:"全局Key检查端点" category:"健康检查" desc:"用于全局Key身份检查的目标URL。"` BaseKeyCheckModel string `json:"base_key_check_model" default:"gemini-2.0-flash-lite" name:"默认Key检查模型" category:"健康检查" desc:"用于分组健康检查和手动密钥测试时的默认回退模型。"` + KeyCheckSchedulerIntervalSeconds int `json:"key_check_scheduler_interval_seconds" default:"60" name:"Key检查调度器间隔(秒)" category:"健康检查" desc:"动态调度器检查各组是否需要执行健康检查的周期。"` + EnableUpstreamCheck bool `json:"enable_upstream_check" default:"true" name:"启用上游检查" category:"健康检查" desc:"是否启用对上游服务(Upstream)的健康检查。"` UpstreamCheckTimeoutSeconds int `json:"upstream_check_timeout_seconds" default:"20" name:"上游检查超时(秒)" category:"健康检查" desc:"对单个上游服务进行健康检查时的网络超时时间。"` diff --git a/internal/pongo/renderer.go b/internal/pongo/renderer.go index f00257c..14a705a 100644 --- a/internal/pongo/renderer.go +++ b/internal/pongo/renderer.go @@ -1,63 +1,96 @@ -// Filename: internal/pongo/renderer.go - package pongo import ( "fmt" "net/http" + "sync" "github.com/flosch/pongo2/v6" "github.com/gin-gonic/gin" "github.com/gin-gonic/gin/render" + "github.com/sirupsen/logrus" ) type Renderer struct { - Context pongo2.Context - tplSet *pongo2.TemplateSet + mu sync.RWMutex + globalContext pongo2.Context + tplSet *pongo2.TemplateSet + logger *logrus.Logger } -func New(directory string, isDebug bool) *Renderer { +func New(directory string, isDebug bool, logger *logrus.Logger) *Renderer { loader := pongo2.MustNewLocalFileSystemLoader(directory) tplSet := pongo2.NewSet("gin-pongo-templates", loader) tplSet.Debug = isDebug - return &Renderer{Context: make(pongo2.Context), tplSet: tplSet} + return &Renderer{ + globalContext: make(pongo2.Context), + tplSet: tplSet, + logger: logger, + } } -// Instance returns a new render.HTML instance for a single request. 
-func (p *Renderer) Instance(name string, data interface{}) render.Render { - var glob pongo2.Context - if p.Context != nil { - glob = p.Context - } +// SetGlobalContext 线程安全地设置全局上下文 +func (p *Renderer) SetGlobalContext(key string, value interface{}) { + p.mu.Lock() + defer p.mu.Unlock() + p.globalContext[key] = value +} +// Warmup 预加载模板 +func (p *Renderer) Warmup(templateNames ...string) error { + for _, name := range templateNames { + if _, err := p.tplSet.FromCache(name); err != nil { + return fmt.Errorf("failed to warmup template '%s': %w", name, err) + } + } + p.logger.WithField("count", len(templateNames)).Info("Templates warmed up") + return nil +} + +func (p *Renderer) Instance(name string, data interface{}) render.Render { + // 安全读取全局上下文 + p.mu.RLock() + glob := make(pongo2.Context, len(p.globalContext)) + for k, v := range p.globalContext { + glob[k] = v + } + p.mu.RUnlock() + + // 解析请求数据 var context pongo2.Context if data != nil { - if ginContext, ok := data.(gin.H); ok { - context = pongo2.Context(ginContext) - } else if pongoContext, ok := data.(pongo2.Context); ok { - context = pongoContext - } else if m, ok := data.(map[string]interface{}); ok { - context = m - } else { + switch v := data.(type) { + case gin.H: + context = pongo2.Context(v) + case pongo2.Context: + context = v + case map[string]interface{}: + context = v + default: context = make(pongo2.Context) } } else { context = make(pongo2.Context) } + // 合并上下文(请求数据优先) for k, v := range glob { - if _, ok := context[k]; !ok { + if _, exists := context[k]; !exists { context[k] = v } } + // 加载模板 tpl, err := p.tplSet.FromCache(name) if err != nil { - panic(fmt.Sprintf("Failed to load template '%s': %v", name, err)) + p.logger.WithError(err).WithField("template", name).Error("Failed to load template") + return &ErrorHTML{ + StatusCode: http.StatusInternalServerError, + Error: fmt.Errorf("template load error: %s", name), + } } return &HTML{ - p: p, Template: tpl, Name: name, Data: context, @@ -65,7 +98,6 @@ func (p *Renderer) Instance(name string, data interface{}) render.Render { } type HTML struct { - p *Renderer Template *pongo2.Template Name string Data pongo2.Context @@ -82,15 +114,31 @@ func (h *HTML) Render(w http.ResponseWriter) error { } func (h *HTML) WriteContentType(w http.ResponseWriter) { - header := w.Header() - if val := header["Content-Type"]; len(val) == 0 { - header["Content-Type"] = []string{"text/html; charset=utf-8"} + if w.Header().Get("Content-Type") == "" { + w.Header().Set("Content-Type", "text/html; charset=utf-8") } } +// ErrorHTML 错误渲染器 +type ErrorHTML struct { + StatusCode int + Error error +} + +func (e *ErrorHTML) Render(w http.ResponseWriter) error { + w.Header().Set("Content-Type", "text/plain; charset=utf-8") + w.WriteHeader(e.StatusCode) + _, err := w.Write([]byte(e.Error.Error())) + return err +} + +func (e *ErrorHTML) WriteContentType(w http.ResponseWriter) { + w.Header().Set("Content-Type", "text/plain; charset=utf-8") +} + +// C 获取或创建 pongo2 上下文 func C(ctx *gin.Context) pongo2.Context { - p, exists := ctx.Get("pongo2") - if exists { + if p, exists := ctx.Get("pongo2"); exists { if pCtx, ok := p.(pongo2.Context); ok { return pCtx } diff --git a/internal/router/router.go b/internal/router/router.go index 4724c1b..6c6d324 100644 --- a/internal/router/router.go +++ b/internal/router/router.go @@ -17,6 +17,7 @@ import ( "github.com/gin-contrib/cors" "github.com/gin-gonic/gin" + "github.com/sirupsen/logrus" ) func NewRouter( @@ -42,70 +43,214 @@ func NewRouter( upstreamModule 
*upstream.Module, proxyModule *proxy.Module, ) *gin.Engine { + // === 1. 创建全局 Logger(统一管理)=== + logger := createLogger(cfg) + + // === 2. 设置 Gin 运行模式 === if cfg.Log.Level != "debug" { gin.SetMode(gin.ReleaseMode) } - router := gin.Default() - router.Static("/static", "./web/static") - // CORS 配置 - config := cors.Config{ - // 允许前端的来源。在生产环境中,需改为实际域名 - AllowOrigins: []string{"http://localhost:9000"}, - AllowMethods: []string{"GET", "POST", "PUT", "PATCH", "DELETE", "OPTIONS"}, - AllowHeaders: []string{"Origin", "Content-Type", "Accept", "Authorization"}, - ExposeHeaders: []string{"Content-Length"}, - AllowCredentials: true, - MaxAge: 12 * time.Hour, - } - router.Use(cors.New(config)) - isDebug := gin.Mode() != gin.ReleaseMode - router.HTMLRender = pongo.New("web/templates", isDebug) + // === 3. 创建 Router(使用 gin.New() 以便完全控制中间件)=== + router := gin.New() - // --- 基础设施 --- - router.GET("/", func(c *gin.Context) { c.Redirect(http.StatusMovedPermanently, "/dashboard") }) - router.GET("/health", func(c *gin.Context) { c.JSON(http.StatusOK, gin.H{"status": "ok"}) }) - // --- 统一的认证管道 --- - apiAdminAuth := middleware.APIAdminAuthMiddleware(securityService) + // === 4. 注册全局中间件(按执行顺序)=== + setupGlobalMiddleware(router, logger) + + // === 5. 配置静态文件和模板 === + setupStaticAndTemplates(router, logger) + + // === 6. 配置 CORS === + setupCORS(router, cfg) + + // === 7. 注册基础路由 === + setupBasicRoutes(router) + + // === 8. 创建认证中间件 === + apiAdminAuth := middleware.APIAdminAuthMiddleware(securityService, logger) webAdminAuth := middleware.WebAdminAuthMiddleware(securityService) - router.Use(gin.RecoveryWithWriter(os.Stdout)) - // --- 将正确的依赖和中间件管道传递下去 --- - registerProxyRoutes(router, proxyHandler, securityService) - registerAdminRoutes(router, apiAdminAuth, keyGroupHandler, tokensHandler, apiKeyHandler, logHandler, settingHandler, dashboardHandler, taskHandler, upstreamModule, proxyModule) - registerPublicAPIRoutes(router, apiAuthHandler, securityService, settingsManager) + // === 9. 注册业务路由(按功能分组)=== + registerPublicAPIRoutes(router, apiAuthHandler, securityService, settingsManager, logger) registerWebRoutes(router, webAdminAuth, webAuthHandler, pageHandler) + registerAdminRoutes(router, apiAdminAuth, keyGroupHandler, tokensHandler, apiKeyHandler, + logHandler, settingHandler, dashboardHandler, taskHandler, upstreamModule, proxyModule) + registerProxyRoutes(router, proxyHandler, securityService, logger) + return router } +// ==================== 辅助函数 ==================== + +// createLogger 创建并配置全局 Logger +func createLogger(cfg *config.Config) *logrus.Logger { + logger := logrus.New() + + // 设置日志格式 + if cfg.Log.Format == "json" { + logger.SetFormatter(&logrus.JSONFormatter{ + TimestampFormat: time.RFC3339, + }) + } else { + logger.SetFormatter(&logrus.TextFormatter{ + FullTimestamp: true, + TimestampFormat: "2006-01-02 15:04:05", + }) + } + + // 设置日志级别 + switch cfg.Log.Level { + case "debug": + logger.SetLevel(logrus.DebugLevel) + case "info": + logger.SetLevel(logrus.InfoLevel) + case "warn": + logger.SetLevel(logrus.WarnLevel) + case "error": + logger.SetLevel(logrus.ErrorLevel) + default: + logger.SetLevel(logrus.InfoLevel) + } + + // 设置输出(可选:输出到文件) + logger.SetOutput(os.Stdout) + + return logger +} + +// setupGlobalMiddleware 设置全局中间件 +func setupGlobalMiddleware(router *gin.Engine, logger *logrus.Logger) { + // 1. 请求 ID 中间件(用于链路追踪) + router.Use(middleware.RequestIDMiddleware()) + + // 2. 数据脱敏中间件(在日志前执行) + router.Use(middleware.RedactionMiddleware()) + + // 3. 
日志中间件 + router.Use(middleware.LogrusLogger(logger)) + + // 4. 错误恢复中间件 + router.Use(gin.RecoveryWithWriter(os.Stdout)) +} + +// setupStaticAndTemplates 配置静态文件和模板 +func setupStaticAndTemplates(router *gin.Engine, logger *logrus.Logger) { + router.Static("/static", "./web/static") + + isDebug := gin.Mode() != gin.ReleaseMode + router.HTMLRender = pongo.New("web/templates", isDebug, logger) +} + +// setupCORS 配置 CORS +func setupCORS(router *gin.Engine, cfg *config.Config) { + corsConfig := cors.Config{ + AllowOrigins: getCORSOrigins(cfg), + AllowMethods: []string{"GET", "POST", "PUT", "PATCH", "DELETE", "OPTIONS"}, + AllowHeaders: []string{"Origin", "Content-Type", "Accept", "Authorization", "X-Request-Id"}, + ExposeHeaders: []string{"Content-Length", "X-Request-Id"}, + AllowCredentials: true, + MaxAge: 12 * time.Hour, + } + router.Use(cors.New(corsConfig)) +} + +// getCORSOrigins 获取 CORS 允许的来源 +func getCORSOrigins(cfg *config.Config) []string { + // 默认值 + origins := []string{"http://localhost:9000"} + + // 从配置读取(修复:移除 nil 检查) + if len(cfg.Server.CORSOrigins) > 0 { + origins = cfg.Server.CORSOrigins + } + + return origins +} + +// setupBasicRoutes 设置基础路由 +func setupBasicRoutes(router *gin.Engine) { + // 根路径重定向 + router.GET("/", func(c *gin.Context) { + c.Redirect(http.StatusMovedPermanently, "/dashboard") + }) + + // 健康检查 + router.GET("/health", handleHealthCheck) + + // 版本信息(可选) + router.GET("/version", handleVersion) +} + +// handleHealthCheck 健康检查处理器 +func handleHealthCheck(c *gin.Context) { + c.JSON(http.StatusOK, gin.H{ + "status": "ok", + "time": time.Now().Unix(), + }) +} + +// handleVersion 版本信息处理器 +func handleVersion(c *gin.Context) { + c.JSON(http.StatusOK, gin.H{ + "version": "1.0.0", // 可以从配置或编译时变量读取 + "build": "latest", + }) +} + +// ==================== 路由注册函数 ==================== + +// registerProxyRoutes 注册代理路由 func registerProxyRoutes( - router *gin.Engine, proxyHandler *handlers.ProxyHandler, securityService *service.SecurityService, + router *gin.Engine, + proxyHandler *handlers.ProxyHandler, + securityService *service.SecurityService, + logger *logrus.Logger, ) { - // 通用的代理认证中间件 - proxyAuthMiddleware := middleware.ProxyAuthMiddleware(securityService) - // --- 模式一: 智能聚合模式 (根路径) --- - // /v1 和 /v1beta 路径作为默认入口,服务于 BasePool 聚合逻辑 + // 创建代理认证中间件 + proxyAuthMiddleware := middleware.ProxyAuthMiddleware(securityService, logger) + + // 模式一: 智能聚合模式(默认入口) + registerAggregateProxyRoutes(router, proxyHandler, proxyAuthMiddleware) + + // 模式二: 精确路由模式(按组名路由) + registerGroupProxyRoutes(router, proxyHandler, proxyAuthMiddleware) +} + +// registerAggregateProxyRoutes 注册聚合代理路由 +func registerAggregateProxyRoutes( + router *gin.Engine, + proxyHandler *handlers.ProxyHandler, + authMiddleware gin.HandlerFunc, +) { + // /v1 路径组 v1 := router.Group("/v1") - v1.Use(proxyAuthMiddleware) + v1.Use(authMiddleware) { v1.Any("/*path", proxyHandler.HandleProxy) } + + // /v1beta 路径组 v1beta := router.Group("/v1beta") - v1beta.Use(proxyAuthMiddleware) + v1beta.Use(authMiddleware) { v1beta.Any("/*path", proxyHandler.HandleProxy) } - // --- 模式二: 精确路由模式 (/proxy/:group_name) --- - // 创建一个新的、物理隔离的路由组,用于按组名精确路由 +} + +// registerGroupProxyRoutes 注册分组代理路由 +func registerGroupProxyRoutes( + router *gin.Engine, + proxyHandler *handlers.ProxyHandler, + authMiddleware gin.HandlerFunc, +) { proxyGroup := router.Group("/proxy/:group_name") - proxyGroup.Use(proxyAuthMiddleware) + proxyGroup.Use(authMiddleware) { - // 捕获所有子路径 (例如 /v1/chat/completions),并全部交给同一个 ProxyHandler。 proxyGroup.Any("/*path", proxyHandler.HandleProxy) } } 
-// registerAdminRoutes +// registerAdminRoutes 注册管理后台 API 路由 func registerAdminRoutes( router *gin.Engine, authMiddleware gin.HandlerFunc, @@ -121,74 +266,112 @@ func registerAdminRoutes( ) { admin := router.Group("/admin", authMiddleware) { - // --- KeyGroup Base Routes --- - admin.POST("/keygroups", keyGroupHandler.CreateKeyGroup) - admin.GET("/keygroups", keyGroupHandler.GetKeyGroups) - admin.PUT("/keygroups/order", keyGroupHandler.UpdateKeyGroupOrder) - // --- KeyGroup Specific Routes (by :id) --- - admin.GET("/keygroups/:id", keyGroupHandler.GetKeyGroups) - admin.PUT("/keygroups/:id", keyGroupHandler.UpdateKeyGroup) - admin.DELETE("/keygroups/:id", keyGroupHandler.DeleteKeyGroup) - admin.POST("/keygroups/:id/clone", keyGroupHandler.CloneKeyGroup) - admin.GET("/keygroups/:id/stats", keyGroupHandler.GetKeyGroupStats) - admin.POST("/keygroups/:id/bulk-actions", apiKeyHandler.HandleBulkAction) - // --- APIKey Sub-resource Routes under a KeyGroup --- - keyGroupAPIKeys := admin.Group("/keygroups/:id/apikeys") - { - keyGroupAPIKeys.GET("", apiKeyHandler.ListKeysForGroup) - keyGroupAPIKeys.GET("/export", apiKeyHandler.ExportKeysForGroup) - keyGroupAPIKeys.POST("/bulk", apiKeyHandler.AddMultipleKeysToGroup) - keyGroupAPIKeys.DELETE("/bulk", apiKeyHandler.UnlinkMultipleKeysFromGroup) - keyGroupAPIKeys.POST("/test", apiKeyHandler.TestKeysForGroup) - keyGroupAPIKeys.PUT("/:keyId", apiKeyHandler.UpdateGroupAPIKeyMapping) - } + // KeyGroup 路由 + registerKeyGroupRoutes(admin, keyGroupHandler, apiKeyHandler) - // Global key operations - admin.GET("/apikeys", apiKeyHandler.ListAPIKeys) - // admin.PUT("/apikeys/:id", apiKeyHandler.UpdateAPIKey) // DEPRECATED: Status is now contextual - admin.POST("/apikeys/test", apiKeyHandler.TestMultipleKeys) // Test keys globally - admin.DELETE("/apikeys/:id", apiKeyHandler.HardDeleteAPIKey) // Hard delete a single key - admin.DELETE("/apikeys/bulk", apiKeyHandler.HardDeleteMultipleKeys) // Hard delete multiple keys - admin.PUT("/apikeys/bulk/restore", apiKeyHandler.RestoreMultipleKeys) // Restore multiple keys globally + // APIKey 全局路由 + registerAPIKeyRoutes(admin, apiKeyHandler) - // --- Global Routes --- - admin.GET("/tokens", tokensHandler.GetAllTokens) - admin.PUT("/tokens", tokensHandler.UpdateTokens) - admin.GET("/logs", logHandler.GetLogs) - admin.GET("/settings", settingHandler.GetSettings) - admin.PUT("/settings", settingHandler.UpdateSettings) - admin.PUT("/settings/reset", settingHandler.ResetSettingsToDefaults) + // 系统管理路由 + registerSystemRoutes(admin, tokensHandler, logHandler, settingHandler, taskHandler) - // 用于查询异步任务的状态 - admin.GET("/tasks/:id", taskHandler.GetTaskStatus) + // 仪表盘路由 + registerDashboardRoutes(admin, dashboardHandler) - // 领域模块 + // 领域模块路由 upstreamModule.RegisterRoutes(admin) proxyModule.RegisterRoutes(admin) - // --- 全局仪表盘路由 --- - dashboard := admin.Group("/dashboard") - { - dashboard.GET("/overview", dashboardHandler.GetOverview) - dashboard.GET("/chart", dashboardHandler.GetChart) - dashboard.GET("/stats/:period", dashboardHandler.GetRequestStats) // 点击详情 - } } } -// registerWebRoutes +// registerKeyGroupRoutes 注册 KeyGroup 相关路由 +func registerKeyGroupRoutes( + admin *gin.RouterGroup, + keyGroupHandler *handlers.KeyGroupHandler, + apiKeyHandler *handlers.APIKeyHandler, +) { + // 基础路由 + admin.POST("/keygroups", keyGroupHandler.CreateKeyGroup) + admin.GET("/keygroups", keyGroupHandler.GetKeyGroups) + admin.PUT("/keygroups/order", keyGroupHandler.UpdateKeyGroupOrder) + + // 特定 KeyGroup 路由 + admin.GET("/keygroups/:id", 
keyGroupHandler.GetKeyGroups) + admin.PUT("/keygroups/:id", keyGroupHandler.UpdateKeyGroup) + admin.DELETE("/keygroups/:id", keyGroupHandler.DeleteKeyGroup) + admin.POST("/keygroups/:id/clone", keyGroupHandler.CloneKeyGroup) + admin.GET("/keygroups/:id/stats", keyGroupHandler.GetKeyGroupStats) + admin.POST("/keygroups/:id/bulk-actions", apiKeyHandler.HandleBulkAction) + + // KeyGroup 的 APIKey 子资源 + keyGroupAPIKeys := admin.Group("/keygroups/:id/apikeys") + { + keyGroupAPIKeys.GET("", apiKeyHandler.ListKeysForGroup) + keyGroupAPIKeys.GET("/export", apiKeyHandler.ExportKeysForGroup) + keyGroupAPIKeys.POST("/bulk", apiKeyHandler.AddMultipleKeysToGroup) + keyGroupAPIKeys.DELETE("/bulk", apiKeyHandler.UnlinkMultipleKeysFromGroup) + keyGroupAPIKeys.POST("/test", apiKeyHandler.TestKeysForGroup) + keyGroupAPIKeys.PUT("/:keyId", apiKeyHandler.UpdateGroupAPIKeyMapping) + } +} + +// registerAPIKeyRoutes 注册 APIKey 全局路由 +func registerAPIKeyRoutes(admin *gin.RouterGroup, apiKeyHandler *handlers.APIKeyHandler) { + admin.GET("/apikeys", apiKeyHandler.ListAPIKeys) + admin.POST("/apikeys/test", apiKeyHandler.TestMultipleKeys) + admin.DELETE("/apikeys/:id", apiKeyHandler.HardDeleteAPIKey) + admin.DELETE("/apikeys/bulk", apiKeyHandler.HardDeleteMultipleKeys) + admin.PUT("/apikeys/bulk/restore", apiKeyHandler.RestoreMultipleKeys) +} + +// registerSystemRoutes 注册系统管理路由 +func registerSystemRoutes( + admin *gin.RouterGroup, + tokensHandler *handlers.TokensHandler, + logHandler *handlers.LogHandler, + settingHandler *handlers.SettingHandler, + taskHandler *handlers.TaskHandler, +) { + // Token 管理 + admin.GET("/tokens", tokensHandler.GetAllTokens) + admin.PUT("/tokens", tokensHandler.UpdateTokens) + + // 日志管理 + admin.GET("/logs", logHandler.GetLogs) + + // 设置管理 + admin.GET("/settings", settingHandler.GetSettings) + admin.PUT("/settings", settingHandler.UpdateSettings) + admin.PUT("/settings/reset", settingHandler.ResetSettingsToDefaults) + + // 任务管理 + admin.GET("/tasks/:id", taskHandler.GetTaskStatus) +} + +// registerDashboardRoutes 注册仪表盘路由 +func registerDashboardRoutes(admin *gin.RouterGroup, dashboardHandler *handlers.DashboardHandler) { + dashboard := admin.Group("/dashboard") + { + dashboard.GET("/overview", dashboardHandler.GetOverview) + dashboard.GET("/chart", dashboardHandler.GetChart) + dashboard.GET("/stats/:period", dashboardHandler.GetRequestStats) + } +} + +// registerWebRoutes 注册 Web 页面路由 func registerWebRoutes( router *gin.Engine, authMiddleware gin.HandlerFunc, webAuthHandler *webhandlers.WebAuthHandler, pageHandler *webhandlers.PageHandler, ) { + // 公开的认证路由 router.GET("/login", webAuthHandler.ShowLoginPage) router.POST("/login", webAuthHandler.HandleLogin) router.GET("/logout", webAuthHandler.HandleLogout) - // For Test only router.Run("127.0.0.1:9000") - // 受保护的Admin Web界面 + + // 受保护的管理界面 webGroup := router.Group("/", authMiddleware) - webGroup.Use(authMiddleware) { webGroup.GET("/keys", pageHandler.ShowKeysPage) webGroup.GET("/settings", pageHandler.ShowConfigEditorPage) @@ -197,14 +380,31 @@ func registerWebRoutes( webGroup.GET("/tasks", pageHandler.ShowTasksPage) webGroup.GET("/chat", pageHandler.ShowChatPage) } - } -// registerPublicAPIRoutes 无需后台登录的公共API路由 -func registerPublicAPIRoutes(router *gin.Engine, apiAuthHandler *handlers.APIAuthHandler, securityService *service.SecurityService, settingsManager *settings.SettingsManager) { - ipBanMiddleware := middleware.IPBanMiddleware(securityService, settingsManager) - publicAPIGroup := router.Group("/api") +// registerPublicAPIRoutes 注册公共 API 路由 
+func registerPublicAPIRoutes( + router *gin.Engine, + apiAuthHandler *handlers.APIAuthHandler, + securityService *service.SecurityService, + settingsManager *settings.SettingsManager, + logger *logrus.Logger, +) { + // 创建 IP 封禁中间件 + ipBanCache := middleware.NewIPBanCache() + ipBanMiddleware := middleware.IPBanMiddleware( + securityService, + settingsManager, + ipBanCache, + logger, + ) + + // 公共 API 路由组 + publicAPI := router.Group("/api") { - publicAPIGroup.POST("/login", ipBanMiddleware, apiAuthHandler.HandleLogin) + publicAPI.POST("/login", ipBanMiddleware, apiAuthHandler.HandleLogin) + // 可以在这里添加其他公共 API 路由 + // publicAPI.POST("/register", ipBanMiddleware, apiAuthHandler.HandleRegister) + // publicAPI.POST("/forgot-password", ipBanMiddleware, apiAuthHandler.HandleForgotPassword) } } diff --git a/internal/service/analytics_service.go b/internal/service/analytics_service.go index 3c0cad9..e2f8d8f 100644 --- a/internal/service/analytics_service.go +++ b/internal/service/analytics_service.go @@ -5,93 +5,179 @@ import ( "context" "encoding/json" "fmt" - "gemini-balancer/internal/db/dialect" - "gemini-balancer/internal/models" - "gemini-balancer/internal/store" "strconv" "strings" "sync" + "sync/atomic" "time" + "gemini-balancer/internal/db/dialect" + "gemini-balancer/internal/models" + "gemini-balancer/internal/settings" + "gemini-balancer/internal/store" + "github.com/sirupsen/logrus" "gorm.io/gorm" ) const ( - flushLoopInterval = 1 * time.Minute + defaultFlushInterval = 1 * time.Minute + maxRetryAttempts = 3 + retryDelay = 5 * time.Second ) -type AnalyticsServiceLogger struct{ *logrus.Entry } - type AnalyticsService struct { - db *gorm.DB - store store.Store - logger *logrus.Entry + db *gorm.DB + store store.Store + logger *logrus.Entry + dialect dialect.DialectAdapter + settingsManager *settings.SettingsManager + stopChan chan struct{} wg sync.WaitGroup - dialect dialect.DialectAdapter + ctx context.Context + cancel context.CancelFunc + + // 统计指标 + eventsReceived atomic.Uint64 + eventsProcessed atomic.Uint64 + eventsFailed atomic.Uint64 + flushCount atomic.Uint64 + recordsFlushed atomic.Uint64 + flushErrors atomic.Uint64 + lastFlushTime time.Time + lastFlushMutex sync.RWMutex + + // 运行时配置 + flushInterval time.Duration + configMutex sync.RWMutex } -func NewAnalyticsService(db *gorm.DB, s store.Store, logger *logrus.Logger, d dialect.DialectAdapter) *AnalyticsService { +func NewAnalyticsService( + db *gorm.DB, + s store.Store, + logger *logrus.Logger, + d dialect.DialectAdapter, + settingsManager *settings.SettingsManager, +) *AnalyticsService { + ctx, cancel := context.WithCancel(context.Background()) + return &AnalyticsService{ - db: db, - store: s, - logger: logger.WithField("component", "Analytics📊"), - stopChan: make(chan struct{}), - dialect: d, + db: db, + store: s, + logger: logger.WithField("component", "Analytics📊"), + dialect: d, + settingsManager: settingsManager, + stopChan: make(chan struct{}), + ctx: ctx, + cancel: cancel, + flushInterval: defaultFlushInterval, + lastFlushTime: time.Now(), } } func (s *AnalyticsService) Start() { - s.wg.Add(2) - go s.flushLoop() + s.wg.Add(3) go s.eventListener() - s.logger.Info("AnalyticsService (Command Side) started.") + go s.flushLoop() + go s.metricsReporter() + + s.logger.WithFields(logrus.Fields{ + "flush_interval": s.flushInterval, + }).Info("AnalyticsService started") } func (s *AnalyticsService) Stop() { + s.logger.Info("AnalyticsService stopping...") close(s.stopChan) + s.cancel() s.wg.Wait() - s.logger.Info("AnalyticsService 
stopped. Performing final data flush...") + + s.logger.Info("Performing final data flush...") s.flushToDB() - s.logger.Info("AnalyticsService final data flush completed.") + + // 输出最终统计 + s.logger.WithFields(logrus.Fields{ + "events_received": s.eventsReceived.Load(), + "events_processed": s.eventsProcessed.Load(), + "events_failed": s.eventsFailed.Load(), + "flush_count": s.flushCount.Load(), + "records_flushed": s.recordsFlushed.Load(), + "flush_errors": s.flushErrors.Load(), + }).Info("AnalyticsService stopped") } +// 事件监听循环 func (s *AnalyticsService) eventListener() { defer s.wg.Done() - sub, err := s.store.Subscribe(context.Background(), models.TopicRequestFinished) + + sub, err := s.store.Subscribe(s.ctx, models.TopicRequestFinished) if err != nil { - s.logger.Fatalf("Failed to subscribe to topic %s: %v", models.TopicRequestFinished, err) + s.logger.WithError(err).Error("Failed to subscribe to request events, analytics disabled") return } - defer sub.Close() - s.logger.Info("AnalyticsService subscribed to request events.") + defer func() { + if err := sub.Close(); err != nil { + s.logger.WithError(err).Warn("Failed to close subscription") + } + }() + + s.logger.Info("Subscribed to request events for analytics") for { select { case msg := <-sub.Channel(): - var event models.RequestFinishedEvent - if err := json.Unmarshal(msg.Payload, &event); err != nil { - s.logger.Errorf("Failed to unmarshal analytics event: %v", err) - continue - } - s.handleAnalyticsEvent(&event) + s.handleMessage(msg) + case <-s.stopChan: - s.logger.Info("AnalyticsService stopping event listener.") + s.logger.Info("Event listener stopping") + return + + case <-s.ctx.Done(): + s.logger.Info("Event listener context cancelled") return } } } -func (s *AnalyticsService) handleAnalyticsEvent(event *models.RequestFinishedEvent) { - if event.RequestLog.GroupID == nil { +// 处理单条消息 +func (s *AnalyticsService) handleMessage(msg *store.Message) { + var event models.RequestFinishedEvent + if err := json.Unmarshal(msg.Payload, &event); err != nil { + s.logger.WithError(err).Error("Failed to unmarshal analytics event") + s.eventsFailed.Add(1) return } - ctx := context.Background() - key := fmt.Sprintf("analytics:hourly:%s", time.Now().UTC().Format("2006-01-02T15")) + + s.eventsReceived.Add(1) + + if err := s.handleAnalyticsEvent(&event); err != nil { + s.eventsFailed.Add(1) + s.logger.WithFields(logrus.Fields{ + "correlation_id": event.CorrelationID, + "group_id": event.RequestLog.GroupID, + }).WithError(err).Warn("Failed to process analytics event") + } else { + s.eventsProcessed.Add(1) + } +} + +// 处理分析事件 +func (s *AnalyticsService) handleAnalyticsEvent(event *models.RequestFinishedEvent) error { + if event.RequestLog.GroupID == nil { + return nil // 跳过无 GroupID 的事件 + } + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + now := time.Now().UTC() + key := fmt.Sprintf("analytics:hourly:%s", now.Format("2006-01-02T15")) fieldPrefix := fmt.Sprintf("%d:%s", *event.RequestLog.GroupID, event.RequestLog.ModelName) + pipe := s.store.Pipeline(ctx) pipe.HIncrBy(key, fieldPrefix+":requests", 1) + if event.RequestLog.IsSuccess { pipe.HIncrBy(key, fieldPrefix+":success", 1) } @@ -101,80 +187,213 @@ func (s *AnalyticsService) handleAnalyticsEvent(event *models.RequestFinishedEve if event.RequestLog.CompletionTokens > 0 { pipe.HIncrBy(key, fieldPrefix+":completion", int64(event.RequestLog.CompletionTokens)) } + + // 设置过期时间(保留48小时) + pipe.Expire(key, 48*time.Hour) + if err := pipe.Exec(); err != nil 
{ - s.logger.Warnf("[%s] Failed to record analytics event to store for group %d: %v", event.CorrelationID, *event.RequestLog.GroupID, err) + return fmt.Errorf("redis pipeline failed: %w", err) } + + return nil } +// 刷新循环 func (s *AnalyticsService) flushLoop() { defer s.wg.Done() - ticker := time.NewTicker(flushLoopInterval) + + s.configMutex.RLock() + interval := s.flushInterval + s.configMutex.RUnlock() + + ticker := time.NewTicker(interval) defer ticker.Stop() + + s.logger.WithField("interval", interval).Info("Flush loop started") + for { select { case <-ticker.C: s.flushToDB() + case <-s.stopChan: + s.logger.Info("Flush loop stopping") + return + + case <-s.ctx.Done(): return } } } +// 刷写到数据库 func (s *AnalyticsService) flushToDB() { - ctx := context.Background() + start := time.Now() + ctx, cancel := context.WithTimeout(context.Background(), 60*time.Second) + defer cancel() + now := time.Now().UTC() - keysToFlush := []string{ - fmt.Sprintf("analytics:hourly:%s", now.Add(-1*time.Hour).Format("2006-01-02T15")), - fmt.Sprintf("analytics:hourly:%s", now.Format("2006-01-02T15")), - } + keysToFlush := s.generateFlushKeys(now) + + totalRecords := 0 + totalErrors := 0 for _, key := range keysToFlush { - data, err := s.store.HGetAll(ctx, key) - if err != nil || len(data) == 0 { - continue + records, err := s.flushSingleKey(ctx, key, now) + if err != nil { + s.logger.WithError(err).WithField("key", key).Error("Failed to flush key") + totalErrors++ + s.flushErrors.Add(1) + } else { + totalRecords += records } + } - statsToFlush, parsedFields := s.parseStatsFromHash(now.Truncate(time.Hour), data) + s.recordsFlushed.Add(uint64(totalRecords)) + s.flushCount.Add(1) - if len(statsToFlush) > 0 { - upsertClause := s.dialect.OnConflictUpdateAll( - []string{"time", "group_id", "model_name"}, - []string{"request_count", "success_count", "prompt_tokens", "completion_tokens"}, - ) - err := s.db.WithContext(ctx).Clauses(upsertClause).Create(&statsToFlush).Error - if err != nil { - s.logger.Errorf("Failed to flush analytics data for key %s: %v", key, err) - } else { - s.logger.Infof("Successfully flushed %d records from key %s.", len(statsToFlush), key) - _ = s.store.HDel(ctx, key, parsedFields...) 
- } - } + s.lastFlushMutex.Lock() + s.lastFlushTime = time.Now() + s.lastFlushMutex.Unlock() + + duration := time.Since(start) + + if totalRecords > 0 || totalErrors > 0 { + s.logger.WithFields(logrus.Fields{ + "records_flushed": totalRecords, + "keys_processed": len(keysToFlush), + "errors": totalErrors, + "duration": duration, + }).Info("Analytics data flush completed") + } else { + s.logger.WithField("duration", duration).Debug("Analytics flush completed (no data)") } } +// 生成需要刷新的 Redis 键 +func (s *AnalyticsService) generateFlushKeys(now time.Time) []string { + keys := make([]string, 0, 4) + + // 当前小时 + keys = append(keys, fmt.Sprintf("analytics:hourly:%s", now.Format("2006-01-02T15"))) + + // 前3个小时(处理延迟和时区问题) + for i := 1; i <= 3; i++ { + pastHour := now.Add(-time.Duration(i) * time.Hour) + keys = append(keys, fmt.Sprintf("analytics:hourly:%s", pastHour.Format("2006-01-02T15"))) + } + + return keys +} + +// 刷写单个 Redis 键 +func (s *AnalyticsService) flushSingleKey(ctx context.Context, key string, baseTime time.Time) (int, error) { + data, err := s.store.HGetAll(ctx, key) + if err != nil { + return 0, fmt.Errorf("failed to get hash data: %w", err) + } + + if len(data) == 0 { + return 0, nil // 无数据,跳过 + } + + // 解析时间戳 + hourStr := strings.TrimPrefix(key, "analytics:hourly:") + recordTime, err := time.Parse("2006-01-02T15", hourStr) + if err != nil { + s.logger.WithError(err).WithField("key", key).Warn("Failed to parse time from key") + recordTime = baseTime.Truncate(time.Hour) + } + + statsToFlush, parsedFields := s.parseStatsFromHash(recordTime, data) + + if len(statsToFlush) == 0 { + return 0, nil + } + + // 使用事务 + 重试机制 + var dbErr error + for attempt := 1; attempt <= maxRetryAttempts; attempt++ { + dbErr = s.upsertStatsWithTransaction(ctx, statsToFlush) + if dbErr == nil { + break + } + + if attempt < maxRetryAttempts { + s.logger.WithFields(logrus.Fields{ + "attempt": attempt, + "key": key, + }).WithError(dbErr).Warn("Database upsert failed, retrying...") + time.Sleep(retryDelay) + } + } + + if dbErr != nil { + return 0, fmt.Errorf("failed to upsert after %d attempts: %w", maxRetryAttempts, dbErr) + } + + // 删除已处理的字段 + if len(parsedFields) > 0 { + if err := s.store.HDel(ctx, key, parsedFields...); err != nil { + s.logger.WithError(err).WithField("key", key).Warn("Failed to delete flushed fields from Redis") + } + } + + return len(statsToFlush), nil +} + +// 使用事务批量 upsert +func (s *AnalyticsService) upsertStatsWithTransaction(ctx context.Context, stats []models.StatsHourly) error { + return s.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error { + upsertClause := s.dialect.OnConflictUpdateAll( + []string{"time", "group_id", "model_name"}, + []string{"request_count", "success_count", "prompt_tokens", "completion_tokens"}, + ) + return tx.Clauses(upsertClause).Create(&stats).Error + }) +} + +// 解析 Redis Hash 数据 func (s *AnalyticsService) parseStatsFromHash(t time.Time, data map[string]string) ([]models.StatsHourly, []string) { tempAggregator := make(map[string]*models.StatsHourly) - var parsedFields []string + parsedFields := make([]string, 0, len(data)) + for field, valueStr := range data { parts := strings.Split(field, ":") if len(parts) != 3 { + s.logger.WithField("field", field).Warn("Invalid field format") continue } - groupIDStr, modelName, counterType := parts[0], parts[1], parts[2] + groupIDStr, modelName, counterType := parts[0], parts[1], parts[2] aggKey := groupIDStr + ":" + modelName + if _, ok := tempAggregator[aggKey]; !ok { gid, err := strconv.Atoi(groupIDStr) if err 
!= nil { + s.logger.WithFields(logrus.Fields{ + "field": field, + "group_id": groupIDStr, + }).Warn("Invalid group ID") continue } + tempAggregator[aggKey] = &models.StatsHourly{ Time: t, GroupID: uint(gid), ModelName: modelName, } } - val, _ := strconv.ParseInt(valueStr, 10, 64) + + val, err := strconv.ParseInt(valueStr, 10, 64) + if err != nil { + s.logger.WithFields(logrus.Fields{ + "field": field, + "value": valueStr, + }).Warn("Invalid counter value") + continue + } + switch counterType { case "requests": tempAggregator[aggKey].RequestCount = val @@ -184,14 +403,92 @@ func (s *AnalyticsService) parseStatsFromHash(t time.Time, data map[string]strin tempAggregator[aggKey].PromptTokens = val case "completion": tempAggregator[aggKey].CompletionTokens = val + default: + s.logger.WithField("counter_type", counterType).Warn("Unknown counter type") + continue } + parsedFields = append(parsedFields, field) } - var result []models.StatsHourly + + result := make([]models.StatsHourly, 0, len(tempAggregator)) for _, stats := range tempAggregator { if stats.RequestCount > 0 { result = append(result, *stats) } } + return result, parsedFields } + +// 定期输出统计信息 +func (s *AnalyticsService) metricsReporter() { + defer s.wg.Done() + + ticker := time.NewTicker(5 * time.Minute) + defer ticker.Stop() + + for { + select { + case <-ticker.C: + s.reportMetrics() + case <-s.stopChan: + return + case <-s.ctx.Done(): + return + } + } +} + +func (s *AnalyticsService) reportMetrics() { + s.lastFlushMutex.RLock() + lastFlush := s.lastFlushTime + s.lastFlushMutex.RUnlock() + + received := s.eventsReceived.Load() + processed := s.eventsProcessed.Load() + failed := s.eventsFailed.Load() + + var successRate float64 + if received > 0 { + successRate = float64(processed) / float64(received) * 100 + } + + s.logger.WithFields(logrus.Fields{ + "events_received": received, + "events_processed": processed, + "events_failed": failed, + "success_rate": fmt.Sprintf("%.2f%%", successRate), + "flush_count": s.flushCount.Load(), + "records_flushed": s.recordsFlushed.Load(), + "flush_errors": s.flushErrors.Load(), + "last_flush_ago": time.Since(lastFlush).Round(time.Second), + }).Info("Analytics metrics") +} + +// GetMetrics 返回当前统计指标(供监控使用) +func (s *AnalyticsService) GetMetrics() map[string]interface{} { + s.lastFlushMutex.RLock() + lastFlush := s.lastFlushTime + s.lastFlushMutex.RUnlock() + + received := s.eventsReceived.Load() + processed := s.eventsProcessed.Load() + + var successRate float64 + if received > 0 { + successRate = float64(processed) / float64(received) * 100 + } + + return map[string]interface{}{ + "events_received": received, + "events_processed": processed, + "events_failed": s.eventsFailed.Load(), + "success_rate": successRate, + "flush_count": s.flushCount.Load(), + "records_flushed": s.recordsFlushed.Load(), + "flush_errors": s.flushErrors.Load(), + "last_flush_ago": time.Since(lastFlush).Seconds(), + "flush_interval": s.flushInterval.Seconds(), + } +} diff --git a/internal/service/dashboard_query_service.go b/internal/service/dashboard_query_service.go index d1a6764..2258111 100644 --- a/internal/service/dashboard_query_service.go +++ b/internal/service/dashboard_query_service.go @@ -4,158 +4,297 @@ package service import ( "context" "fmt" + "strconv" + "sync" + "sync/atomic" + "time" + "gemini-balancer/internal/models" "gemini-balancer/internal/store" "gemini-balancer/internal/syncer" - "strconv" - "time" "github.com/gin-gonic/gin" "github.com/sirupsen/logrus" "gorm.io/gorm" ) -const overviewCacheChannel = 
"syncer:cache:dashboard_overview" +const ( + overviewCacheChannel = "syncer:cache:dashboard_overview" + defaultChartDays = 7 + cacheLoadTimeout = 30 * time.Second +) + +var ( + // 图表颜色调色板 + chartColorPalette = []string{ + "#FF6384", "#36A2EB", "#FFCE56", "#4BC0C0", + "#9966FF", "#FF9F40", "#C9CBCF", "#4D5360", + } +) type DashboardQueryService struct { db *gorm.DB store store.Store overviewSyncer *syncer.CacheSyncer[*models.DashboardStatsResponse] logger *logrus.Entry - stopChan chan struct{} + + stopChan chan struct{} + wg sync.WaitGroup + ctx context.Context + cancel context.CancelFunc + + // 统计指标 + queryCount atomic.Uint64 + cacheHits atomic.Uint64 + cacheMisses atomic.Uint64 + overviewLoadCount atomic.Uint64 + lastQueryTime time.Time + lastQueryMutex sync.RWMutex } -func NewDashboardQueryService(db *gorm.DB, s store.Store, logger *logrus.Logger) (*DashboardQueryService, error) { - qs := &DashboardQueryService{ - db: db, - store: s, - logger: logger.WithField("component", "DashboardQueryService"), - stopChan: make(chan struct{}), +func NewDashboardQueryService( + db *gorm.DB, + s store.Store, + logger *logrus.Logger, +) (*DashboardQueryService, error) { + ctx, cancel := context.WithCancel(context.Background()) + + service := &DashboardQueryService{ + db: db, + store: s, + logger: logger.WithField("component", "DashboardQuery📈"), + stopChan: make(chan struct{}), + ctx: ctx, + cancel: cancel, + lastQueryTime: time.Now(), } - loader := qs.loadOverviewData - overviewSyncer, err := syncer.NewCacheSyncer(loader, s, overviewCacheChannel) + // 创建 CacheSyncer + overviewSyncer, err := syncer.NewCacheSyncer( + service.loadOverviewData, + s, + overviewCacheChannel, + logger, + ) if err != nil { + cancel() return nil, fmt.Errorf("failed to create overview cache syncer: %w", err) } - qs.overviewSyncer = overviewSyncer - return qs, nil + service.overviewSyncer = overviewSyncer + + return service, nil } func (s *DashboardQueryService) Start() { + s.wg.Add(2) go s.eventListener() - s.logger.Info("DashboardQueryService started and listening for invalidation events.") + go s.metricsReporter() + + s.logger.Info("DashboardQueryService started") } func (s *DashboardQueryService) Stop() { + s.logger.Info("DashboardQueryService stopping...") close(s.stopChan) - s.logger.Info("DashboardQueryService and its CacheSyncer have been stopped.") + s.cancel() + s.wg.Wait() + + // 输出最终统计 + s.logger.WithFields(logrus.Fields{ + "total_queries": s.queryCount.Load(), + "cache_hits": s.cacheHits.Load(), + "cache_misses": s.cacheMisses.Load(), + "overview_loads": s.overviewLoadCount.Load(), + }).Info("DashboardQueryService stopped") } +// ==================== 核心查询方法 ==================== + +// GetDashboardOverviewData 获取仪表盘概览数据(带缓存) +func (s *DashboardQueryService) GetDashboardOverviewData() (*models.DashboardStatsResponse, error) { + s.queryCount.Add(1) + + cachedDataPtr := s.overviewSyncer.Get() + if cachedDataPtr == nil { + s.cacheMisses.Add(1) + s.logger.Warn("Overview cache is empty, attempting to load...") + + // 触发立即加载 + if err := s.overviewSyncer.Invalidate(); err != nil { + return nil, fmt.Errorf("failed to trigger cache reload: %w", err) + } + + // 等待加载完成(最多30秒) + ctx, cancel := context.WithTimeout(context.Background(), cacheLoadTimeout) + defer cancel() + + ticker := time.NewTicker(100 * time.Millisecond) + defer ticker.Stop() + + for { + select { + case <-ticker.C: + if data := s.overviewSyncer.Get(); data != nil { + s.cacheHits.Add(1) + return data, nil + } + case <-ctx.Done(): + return nil, fmt.Errorf("timeout 
waiting for overview cache to load") + } + } + } + + s.cacheHits.Add(1) + return cachedDataPtr, nil +} + +// InvalidateOverviewCache 手动失效概览缓存 +func (s *DashboardQueryService) InvalidateOverviewCache() error { + s.logger.Info("Manually invalidating overview cache") + return s.overviewSyncer.Invalidate() +} + +// GetGroupStats 获取指定分组的统计数据 func (s *DashboardQueryService) GetGroupStats(ctx context.Context, groupID uint) (map[string]any, error) { + s.queryCount.Add(1) + s.updateLastQueryTime() + + start := time.Now() + + // 1. 从 Redis 获取 Key 统计 statsKey := fmt.Sprintf("stats:group:%d", groupID) keyStatsMap, err := s.store.HGetAll(ctx, statsKey) if err != nil { - s.logger.WithError(err).Errorf("Failed to get key stats from cache for group %d", groupID) + s.logger.WithError(err).Errorf("Failed to get key stats for group %d", groupID) return nil, fmt.Errorf("failed to get key stats from cache: %w", err) } + keyStats := make(map[string]int64) for k, v := range keyStatsMap { val, _ := strconv.ParseInt(v, 10, 64) keyStats[k] = val } - now := time.Now() + + // 2. 查询请求统计(使用 UTC 时间) + now := time.Now().UTC() oneHourAgo := now.Add(-1 * time.Hour) twentyFourHoursAgo := now.Add(-24 * time.Hour) + type requestStatsResult struct { TotalRequests int64 SuccessRequests int64 } + var last1Hour, last24Hours requestStatsResult - s.db.WithContext(ctx).Model(&models.StatsHourly{}). - Select("SUM(request_count) as total_requests, SUM(success_count) as success_requests"). - Where("group_id = ? AND time >= ?", groupID, oneHourAgo). - Scan(&last1Hour) - s.db.WithContext(ctx).Model(&models.StatsHourly{}). - Select("SUM(request_count) as total_requests, SUM(success_count) as success_requests"). - Where("group_id = ? AND time >= ?", groupID, twentyFourHoursAgo). - Scan(&last24Hours) - failureRate1h := 0.0 - if last1Hour.TotalRequests > 0 { - failureRate1h = float64(last1Hour.TotalRequests-last1Hour.SuccessRequests) / float64(last1Hour.TotalRequests) * 100 - } - failureRate24h := 0.0 - if last24Hours.TotalRequests > 0 { - failureRate24h = float64(last24Hours.TotalRequests-last24Hours.SuccessRequests) / float64(last24Hours.TotalRequests) * 100 - } - last1HourStats := map[string]any{ - "total_requests": last1Hour.TotalRequests, - "success_requests": last1Hour.SuccessRequests, - "failure_rate": failureRate1h, - } - last24HoursStats := map[string]any{ - "total_requests": last24Hours.TotalRequests, - "success_requests": last24Hours.SuccessRequests, - "failure_rate": failureRate24h, + + // 并发查询优化 + var wg sync.WaitGroup + errChan := make(chan error, 2) + + wg.Add(2) + + // 查询最近1小时 + go func() { + defer wg.Done() + if err := s.db.WithContext(ctx).Model(&models.StatsHourly{}). + Select("COALESCE(SUM(request_count), 0) as total_requests, COALESCE(SUM(success_count), 0) as success_requests"). + Where("group_id = ? AND time >= ?", groupID, oneHourAgo). + Scan(&last1Hour).Error; err != nil { + errChan <- fmt.Errorf("failed to query 1h stats: %w", err) + } + }() + + // 查询最近24小时 + go func() { + defer wg.Done() + if err := s.db.WithContext(ctx).Model(&models.StatsHourly{}). + Select("COALESCE(SUM(request_count), 0) as total_requests, COALESCE(SUM(success_count), 0) as success_requests"). + Where("group_id = ? AND time >= ?", groupID, twentyFourHoursAgo). + Scan(&last24Hours).Error; err != nil { + errChan <- fmt.Errorf("failed to query 24h stats: %w", err) + } + }() + + wg.Wait() + close(errChan) + + // 检查错误 + for err := range errChan { + if err != nil { + return nil, err + } } + + // 3. 
计算失败率 + failureRate1h := s.calculateFailureRate(last1Hour.TotalRequests, last1Hour.SuccessRequests) + failureRate24h := s.calculateFailureRate(last24Hours.TotalRequests, last24Hours.SuccessRequests) + result := map[string]any{ - "key_stats": keyStats, - "last_1_hour": last1HourStats, - "last_24_hours": last24HoursStats, + "key_stats": keyStats, + "last_1_hour": map[string]any{ + "total_requests": last1Hour.TotalRequests, + "success_requests": last1Hour.SuccessRequests, + "failed_requests": last1Hour.TotalRequests - last1Hour.SuccessRequests, + "failure_rate": failureRate1h, + }, + "last_24_hours": map[string]any{ + "total_requests": last24Hours.TotalRequests, + "success_requests": last24Hours.SuccessRequests, + "failed_requests": last24Hours.TotalRequests - last24Hours.SuccessRequests, + "failure_rate": failureRate24h, + }, } + + duration := time.Since(start) + s.logger.WithFields(logrus.Fields{ + "group_id": groupID, + "duration": duration, + }).Debug("Group stats query completed") + return result, nil } -func (s *DashboardQueryService) eventListener() { - ctx := context.Background() - keyStatusSub, _ := s.store.Subscribe(ctx, models.TopicKeyStatusChanged) - upstreamStatusSub, _ := s.store.Subscribe(ctx, models.TopicUpstreamHealthChanged) - defer keyStatusSub.Close() - defer upstreamStatusSub.Close() - for { - select { - case <-keyStatusSub.Channel(): - s.logger.Info("Received key status changed event, invalidating overview cache...") - _ = s.InvalidateOverviewCache() - case <-upstreamStatusSub.Channel(): - s.logger.Info("Received upstream status changed event, invalidating overview cache...") - _ = s.InvalidateOverviewCache() - case <-s.stopChan: - s.logger.Info("Stopping dashboard event listener.") - return - } - } -} - -func (s *DashboardQueryService) GetDashboardOverviewData() (*models.DashboardStatsResponse, error) { - cachedDataPtr := s.overviewSyncer.Get() - if cachedDataPtr == nil { - return &models.DashboardStatsResponse{}, fmt.Errorf("overview cache is not available or still syncing") - } - return cachedDataPtr, nil -} - -func (s *DashboardQueryService) InvalidateOverviewCache() error { - return s.overviewSyncer.Invalidate() -} - +// QueryHistoricalChart 查询历史图表数据 func (s *DashboardQueryService) QueryHistoricalChart(ctx context.Context, groupID *uint) (*models.ChartData, error) { + s.queryCount.Add(1) + s.updateLastQueryTime() + + start := time.Now() + type ChartPoint struct { TimeLabel string `gorm:"column:time_label"` ModelName string `gorm:"column:model_name"` TotalRequests int64 `gorm:"column:total_requests"` } - sevenDaysAgo := time.Now().Add(-24 * 7 * time.Hour).Truncate(time.Hour) + + // 查询最近7天数据(使用 UTC) + sevenDaysAgo := time.Now().UTC().AddDate(0, 0, -defaultChartDays).Truncate(time.Hour) + + // 根据数据库类型构建时间格式化子句 sqlFormat, goFormat := s.buildTimeFormatSelectClause() - selectClause := fmt.Sprintf("%s as time_label, model_name, SUM(request_count) as total_requests", sqlFormat) - query := s.db.WithContext(ctx).Model(&models.StatsHourly{}).Select(selectClause).Where("time >= ?", sevenDaysAgo).Group("time_label, model_name").Order("time_label ASC") + selectClause := fmt.Sprintf( + "%s as time_label, model_name, COALESCE(SUM(request_count), 0) as total_requests", + sqlFormat, + ) + + // 构建查询 + query := s.db.WithContext(ctx). + Model(&models.StatsHourly{}). + Select(selectClause). + Where("time >= ?", sevenDaysAgo). + Group("time_label, model_name"). 
+ Order("time_label ASC") + if groupID != nil && *groupID > 0 { query = query.Where("group_id = ?", *groupID) } + var points []ChartPoint if err := query.Find(&points).Error; err != nil { - return nil, err + return nil, fmt.Errorf("failed to query chart data: %w", err) } + + // 构建数据集 datasets := make(map[string]map[string]int64) for _, p := range points { if _, ok := datasets[p.ModelName]; !ok { @@ -163,32 +302,99 @@ func (s *DashboardQueryService) QueryHistoricalChart(ctx context.Context, groupI } datasets[p.ModelName][p.TimeLabel] = p.TotalRequests } + + // 生成时间标签(按小时) var labels []string - for t := sevenDaysAgo; t.Before(time.Now()); t = t.Add(time.Hour) { + for t := sevenDaysAgo; t.Before(time.Now().UTC()); t = t.Add(time.Hour) { labels = append(labels, t.Format(goFormat)) } - chartData := &models.ChartData{Labels: labels, Datasets: make([]models.ChartDataset, 0)} - colorPalette := []string{"#FF6384", "#36A2EB", "#FFCE56", "#4BC0C0", "#9966FF", "#FF9F40"} + + // 构建图表数据 + chartData := &models.ChartData{ + Labels: labels, + Datasets: make([]models.ChartDataset, 0, len(datasets)), + } + colorIndex := 0 for modelName, dataPoints := range datasets { dataArray := make([]int64, len(labels)) for i, label := range labels { dataArray[i] = dataPoints[label] } + chartData.Datasets = append(chartData.Datasets, models.ChartDataset{ Label: modelName, Data: dataArray, - Color: colorPalette[colorIndex%len(colorPalette)], + Color: chartColorPalette[colorIndex%len(chartColorPalette)], }) colorIndex++ } + + duration := time.Since(start) + s.logger.WithFields(logrus.Fields{ + "group_id": groupID, + "points": len(points), + "datasets": len(chartData.Datasets), + "duration": duration, + }).Debug("Historical chart query completed") + return chartData, nil } +// GetRequestStatsForPeriod 获取指定时间段的请求统计 +func (s *DashboardQueryService) GetRequestStatsForPeriod(ctx context.Context, period string) (gin.H, error) { + s.queryCount.Add(1) + s.updateLastQueryTime() + + var startTime time.Time + now := time.Now().UTC() + + switch period { + case "1m": + startTime = now.Add(-1 * time.Minute) + case "1h": + startTime = now.Add(-1 * time.Hour) + case "1d": + year, month, day := now.Date() + startTime = time.Date(year, month, day, 0, 0, 0, 0, time.UTC) + default: + return nil, fmt.Errorf("invalid period specified: %s (must be 1m, 1h, or 1d)", period) + } + + var result struct { + Total int64 + Success int64 + } + + err := s.db.WithContext(ctx).Model(&models.RequestLog{}). + Select("COUNT(*) as total, SUM(CASE WHEN is_success = true THEN 1 ELSE 0 END) as success"). + Where("request_time >= ?", startTime). 
+ Scan(&result).Error + + if err != nil { + return nil, fmt.Errorf("failed to query request stats: %w", err) + } + + return gin.H{ + "period": period, + "total": result.Total, + "success": result.Success, + "failure": result.Total - result.Success, + }, nil +} + +// ==================== 内部方法 ==================== + +// loadOverviewData 加载仪表盘概览数据(供 CacheSyncer 调用) func (s *DashboardQueryService) loadOverviewData() (*models.DashboardStatsResponse, error) { - ctx := context.Background() - s.logger.Info("[CacheSyncer] Starting to load overview data from database...") + ctx, cancel := context.WithTimeout(context.Background(), 60*time.Second) + defer cancel() + + s.overviewLoadCount.Add(1) startTime := time.Now() + + s.logger.Info("Starting to load dashboard overview data...") + resp := &models.DashboardStatsResponse{ KeyStatusCount: make(map[models.APIKeyStatus]int64), MasterStatusCount: make(map[models.MasterAPIKeyStatus]int64), @@ -200,108 +406,391 @@ func (s *DashboardQueryService) loadOverviewData() (*models.DashboardStatsRespon RequestCounts: make(map[string]int64), } + var loadErr error + var wg sync.WaitGroup + errChan := make(chan error, 10) + + // 1. 并发加载 Key 映射状态统计 + wg.Add(1) + go func() { + defer wg.Done() + if err := s.loadMappingStatusStats(ctx, resp); err != nil { + errChan <- fmt.Errorf("mapping stats: %w", err) + } + }() + + // 2. 并发加载 Master Key 状态统计 + wg.Add(1) + go func() { + defer wg.Done() + if err := s.loadMasterStatusStats(ctx, resp); err != nil { + errChan <- fmt.Errorf("master stats: %w", err) + } + }() + + // 3. 并发加载请求统计 + wg.Add(1) + go func() { + defer wg.Done() + if err := s.loadRequestCounts(ctx, resp); err != nil { + errChan <- fmt.Errorf("request counts: %w", err) + } + }() + + // 4. 并发加载上游健康状态 + wg.Add(1) + go func() { + defer wg.Done() + if err := s.loadUpstreamHealth(ctx, resp); err != nil { + // 上游健康状态失败不阻塞整体加载 + s.logger.WithError(err).Warn("Failed to load upstream health status") + } + }() + + // 等待所有加载任务完成 + wg.Wait() + close(errChan) + + // 收集错误 + for err := range errChan { + if err != nil { + loadErr = err + break + } + } + + if loadErr != nil { + s.logger.WithError(loadErr).Error("Failed to load overview data") + return nil, loadErr + } + + duration := time.Since(startTime) + s.logger.WithFields(logrus.Fields{ + "duration": duration, + "total_keys": resp.KeyCount.Value, + "requests_1d": resp.RequestCounts["1d"], + "upstreams": len(resp.UpstreamHealthStatus), + }).Info("Successfully loaded dashboard overview data") + + return resp, nil +} + +// loadMappingStatusStats 加载 Key 映射状态统计 +func (s *DashboardQueryService) loadMappingStatusStats(ctx context.Context, resp *models.DashboardStatsResponse) error { type MappingStatusResult struct { Status models.APIKeyStatus Count int64 } - var mappingStatusResults []MappingStatusResult - if err := s.db.WithContext(ctx).Model(&models.GroupAPIKeyMapping{}).Select("status, count(*) as count").Group("status").Find(&mappingStatusResults).Error; err != nil { - return nil, fmt.Errorf("failed to query mapping status stats: %w", err) + + var results []MappingStatusResult + if err := s.db.WithContext(ctx). + Model(&models.GroupAPIKeyMapping{}). + Select("status, COUNT(*) as count"). + Group("status"). 
+ Find(&results).Error; err != nil { + return err } - for _, res := range mappingStatusResults { + + for _, res := range results { resp.KeyStatusCount[res.Status] = res.Count } + return nil +} + +// loadMasterStatusStats 加载 Master Key 状态统计 +func (s *DashboardQueryService) loadMasterStatusStats(ctx context.Context, resp *models.DashboardStatsResponse) error { type MasterStatusResult struct { Status models.MasterAPIKeyStatus Count int64 } - var masterStatusResults []MasterStatusResult - if err := s.db.WithContext(ctx).Model(&models.APIKey{}).Select("master_status as status, count(*) as count").Group("master_status").Find(&masterStatusResults).Error; err != nil { - return nil, fmt.Errorf("failed to query master status stats: %w", err) + + var results []MasterStatusResult + if err := s.db.WithContext(ctx). + Model(&models.APIKey{}). + Select("master_status as status, COUNT(*) as count"). + Group("master_status"). + Find(&results).Error; err != nil { + return err } + var totalKeys, invalidKeys int64 - for _, res := range masterStatusResults { + for _, res := range results { resp.MasterStatusCount[res.Status] = res.Count totalKeys += res.Count if res.Status != models.MasterStatusActive { invalidKeys += res.Count } } - resp.KeyCount = models.StatCard{Value: float64(totalKeys), SubValue: invalidKeys, SubValueTip: "非活跃身份密钥数"} - now := time.Now() + resp.KeyCount = models.StatCard{ + Value: float64(totalKeys), + SubValue: invalidKeys, + SubValueTip: "非活跃身份密钥数", + } - var count1m, count1h, count1d int64 - s.db.WithContext(ctx).Model(&models.RequestLog{}).Where("request_time >= ?", now.Add(-1*time.Minute)).Count(&count1m) - s.db.WithContext(ctx).Model(&models.RequestLog{}).Where("request_time >= ?", now.Add(-1*time.Hour)).Count(&count1h) - year, month, day := now.UTC().Date() + return nil +} + +// loadRequestCounts 加载请求计数统计 +func (s *DashboardQueryService) loadRequestCounts(ctx context.Context, resp *models.DashboardStatsResponse) error { + now := time.Now().UTC() + + // 使用 RequestLog 表查询短期数据 + var count1m, count1h int64 + + // 最近1分钟 + if err := s.db.WithContext(ctx). + Model(&models.RequestLog{}). + Where("request_time >= ?", now.Add(-1*time.Minute)). + Count(&count1m).Error; err != nil { + return fmt.Errorf("1m count: %w", err) + } + + // 最近1小时 + if err := s.db.WithContext(ctx). + Model(&models.RequestLog{}). + Where("request_time >= ?", now.Add(-1*time.Hour)). + Count(&count1h).Error; err != nil { + return fmt.Errorf("1h count: %w", err) + } + + // 今天(UTC) + year, month, day := now.Date() startOfDay := time.Date(year, month, day, 0, 0, 0, 0, time.UTC) - s.db.WithContext(ctx).Model(&models.RequestLog{}).Where("request_time >= ?", startOfDay).Count(&count1d) + var count1d int64 + if err := s.db.WithContext(ctx). + Model(&models.RequestLog{}). + Where("request_time >= ?", startOfDay). + Count(&count1d).Error; err != nil { + return fmt.Errorf("1d count: %w", err) + } + + // 最近30天使用聚合表 var count30d int64 - s.db.WithContext(ctx).Model(&models.StatsHourly{}).Where("time >= ?", now.AddDate(0, 0, -30)).Select("COALESCE(SUM(request_count), 0)").Scan(&count30d) + if err := s.db.WithContext(ctx). + Model(&models.StatsHourly{}). + Where("time >= ?", now.AddDate(0, 0, -30)). + Select("COALESCE(SUM(request_count), 0)"). 
+ Scan(&count30d).Error; err != nil { + return fmt.Errorf("30d count: %w", err) + } resp.RequestCounts["1m"] = count1m resp.RequestCounts["1h"] = count1h resp.RequestCounts["1d"] = count1d resp.RequestCounts["30d"] = count30d + return nil +} + +// loadUpstreamHealth 加载上游健康状态 +func (s *DashboardQueryService) loadUpstreamHealth(ctx context.Context, resp *models.DashboardStatsResponse) error { var upstreams []*models.UpstreamEndpoint if err := s.db.WithContext(ctx).Find(&upstreams).Error; err != nil { - s.logger.WithError(err).Warn("Failed to load upstream statuses for dashboard.") - } else { - for _, u := range upstreams { - resp.UpstreamHealthStatus[u.URL] = u.Status + return err + } + + for _, u := range upstreams { + resp.UpstreamHealthStatus[u.URL] = u.Status + } + + return nil +} + +// ==================== 事件监听 ==================== + +// eventListener 监听缓存失效事件 +func (s *DashboardQueryService) eventListener() { + defer s.wg.Done() + + // 订阅事件 + keyStatusSub, err1 := s.store.Subscribe(s.ctx, models.TopicKeyStatusChanged) + upstreamStatusSub, err2 := s.store.Subscribe(s.ctx, models.TopicUpstreamHealthChanged) + + // 错误处理 + if err1 != nil { + s.logger.WithError(err1).Error("Failed to subscribe to key status events") + keyStatusSub = nil + } + if err2 != nil { + s.logger.WithError(err2).Error("Failed to subscribe to upstream status events") + upstreamStatusSub = nil + } + + // 如果全部失败,直接返回 + if keyStatusSub == nil && upstreamStatusSub == nil { + s.logger.Error("All event subscriptions failed, listener disabled") + return + } + + // 安全关闭订阅 + defer func() { + if keyStatusSub != nil { + if err := keyStatusSub.Close(); err != nil { + s.logger.WithError(err).Warn("Failed to close key status subscription") + } + } + if upstreamStatusSub != nil { + if err := upstreamStatusSub.Close(); err != nil { + s.logger.WithError(err).Warn("Failed to close upstream status subscription") + } + } + }() + + s.logger.WithFields(logrus.Fields{ + "key_status_sub": keyStatusSub != nil, + "upstream_status_sub": upstreamStatusSub != nil, + }).Info("Event listener started") + + neverReady := make(chan *store.Message) + close(neverReady) // 立即关闭,确保永远不会阻塞 + + for { + // 动态选择有效的 channel + var keyStatusChan <-chan *store.Message = neverReady + if keyStatusSub != nil { + keyStatusChan = keyStatusSub.Channel() + } + + var upstreamStatusChan <-chan *store.Message = neverReady + if upstreamStatusSub != nil { + upstreamStatusChan = upstreamStatusSub.Channel() + } + + select { + case _, ok := <-keyStatusChan: + if !ok { + s.logger.Warn("Key status channel closed") + keyStatusSub = nil + continue + } + s.logger.Debug("Received key status changed event") + if err := s.InvalidateOverviewCache(); err != nil { + s.logger.WithError(err).Warn("Failed to invalidate cache on key status change") + } + + case _, ok := <-upstreamStatusChan: + if !ok { + s.logger.Warn("Upstream status channel closed") + upstreamStatusSub = nil + continue + } + s.logger.Debug("Received upstream status changed event") + if err := s.InvalidateOverviewCache(); err != nil { + s.logger.WithError(err).Warn("Failed to invalidate cache on upstream status change") + } + + case <-s.stopChan: + s.logger.Info("Event listener stopping (stopChan)") + return + + case <-s.ctx.Done(): + s.logger.Info("Event listener stopping (context cancelled)") + return } } - - duration := time.Since(startTime) - s.logger.Infof("[CacheSyncer] Successfully finished loading overview data in %s.", duration) - return resp, nil } -func (s *DashboardQueryService) GetRequestStatsForPeriod(ctx 
context.Context, period string) (gin.H, error) { - var startTime time.Time - now := time.Now() - switch period { - case "1m": - startTime = now.Add(-1 * time.Minute) - case "1h": - startTime = now.Add(-1 * time.Hour) - case "1d": - year, month, day := now.UTC().Date() - startTime = time.Date(year, month, day, 0, 0, 0, 0, time.UTC) - default: - return nil, fmt.Errorf("invalid period specified: %s", period) - } - var result struct { - Total int64 - Success int64 - } +// ==================== 监控指标 ==================== - err := s.db.WithContext(ctx).Model(&models.RequestLog{}). - Select("count(*) as total, sum(case when is_success = true then 1 else 0 end) as success"). - Where("request_time >= ?", startTime). - Scan(&result).Error - if err != nil { - return nil, err +// metricsReporter 定期输出统计信息 +func (s *DashboardQueryService) metricsReporter() { + defer s.wg.Done() + + ticker := time.NewTicker(5 * time.Minute) + defer ticker.Stop() + + for { + select { + case <-ticker.C: + s.reportMetrics() + case <-s.stopChan: + return + case <-s.ctx.Done(): + return + } } - return gin.H{ - "total": result.Total, - "success": result.Success, - "failure": result.Total - result.Success, - }, nil } +func (s *DashboardQueryService) reportMetrics() { + s.lastQueryMutex.RLock() + lastQuery := s.lastQueryTime + s.lastQueryMutex.RUnlock() + + totalQueries := s.queryCount.Load() + hits := s.cacheHits.Load() + misses := s.cacheMisses.Load() + + var cacheHitRate float64 + if hits+misses > 0 { + cacheHitRate = float64(hits) / float64(hits+misses) * 100 + } + + s.logger.WithFields(logrus.Fields{ + "total_queries": totalQueries, + "cache_hits": hits, + "cache_misses": misses, + "cache_hit_rate": fmt.Sprintf("%.2f%%", cacheHitRate), + "overview_loads": s.overviewLoadCount.Load(), + "last_query_ago": time.Since(lastQuery).Round(time.Second), + }).Info("DashboardQuery metrics") +} + +// GetMetrics 返回当前统计指标(供监控使用) +func (s *DashboardQueryService) GetMetrics() map[string]interface{} { + s.lastQueryMutex.RLock() + lastQuery := s.lastQueryTime + s.lastQueryMutex.RUnlock() + + hits := s.cacheHits.Load() + misses := s.cacheMisses.Load() + + var cacheHitRate float64 + if hits+misses > 0 { + cacheHitRate = float64(hits) / float64(hits+misses) * 100 + } + + return map[string]interface{}{ + "total_queries": s.queryCount.Load(), + "cache_hits": hits, + "cache_misses": misses, + "cache_hit_rate": cacheHitRate, + "overview_loads": s.overviewLoadCount.Load(), + "last_query_ago": time.Since(lastQuery).Seconds(), + } +} + +// ==================== 辅助方法 ==================== + +// calculateFailureRate 计算失败率 +func (s *DashboardQueryService) calculateFailureRate(total, success int64) float64 { + if total == 0 { + return 0.0 + } + return float64(total-success) / float64(total) * 100 +} + +// updateLastQueryTime 更新最后查询时间 +func (s *DashboardQueryService) updateLastQueryTime() { + s.lastQueryMutex.Lock() + s.lastQueryTime = time.Now() + s.lastQueryMutex.Unlock() +} + +// buildTimeFormatSelectClause 根据数据库类型构建时间格式化子句 func (s *DashboardQueryService) buildTimeFormatSelectClause() (string, string) { dialect := s.db.Dialector.Name() switch dialect { case "mysql": return "DATE_FORMAT(time, '%Y-%m-%d %H:00:00')", "2006-01-02 15:00:00" + case "postgres": + return "TO_CHAR(time, 'YYYY-MM-DD HH24:00:00')", "2006-01-02 15:00:00" case "sqlite": return "strftime('%Y-%m-%d %H:00:00', time)", "2006-01-02 15:00:00" default: + s.logger.WithField("dialect", dialect).Warn("Unknown database dialect, using SQLite format") return "strftime('%Y-%m-%d %H:00:00', time)", 
"2006-01-02 15:00:00" } } diff --git a/internal/service/db_log_writer_service.go b/internal/service/db_log_writer_service.go index b80904a..563db2a 100644 --- a/internal/service/db_log_writer_service.go +++ b/internal/service/db_log_writer_service.go @@ -4,11 +4,13 @@ package service import ( "context" "encoding/json" + "sync" + "sync/atomic" + "time" + "gemini-balancer/internal/models" "gemini-balancer/internal/settings" "gemini-balancer/internal/store" - "sync" - "time" "github.com/sirupsen/logrus" "gorm.io/gorm" @@ -18,25 +20,47 @@ type DBLogWriterService struct { db *gorm.DB store store.Store logger *logrus.Entry - logBuffer chan *models.RequestLog - stopChan chan struct{} - wg sync.WaitGroup - SettingsManager *settings.SettingsManager + settingsManager *settings.SettingsManager + + logBuffer chan *models.RequestLog + stopChan chan struct{} + wg sync.WaitGroup + ctx context.Context + cancel context.CancelFunc + + // 统计指标 + totalReceived atomic.Uint64 + totalFlushed atomic.Uint64 + totalDropped atomic.Uint64 + flushCount atomic.Uint64 + lastFlushTime time.Time + lastFlushMutex sync.RWMutex } -func NewDBLogWriterService(db *gorm.DB, s store.Store, settings *settings.SettingsManager, logger *logrus.Logger) *DBLogWriterService { - cfg := settings.GetSettings() +func NewDBLogWriterService( + db *gorm.DB, + s store.Store, + settingsManager *settings.SettingsManager, + logger *logrus.Logger, +) *DBLogWriterService { + cfg := settingsManager.GetSettings() bufferCapacity := cfg.LogBufferCapacity if bufferCapacity <= 0 { bufferCapacity = 1000 } + + ctx, cancel := context.WithCancel(context.Background()) + return &DBLogWriterService{ db: db, store: s, - SettingsManager: settings, + settingsManager: settingsManager, logger: logger.WithField("component", "DBLogWriter📝"), logBuffer: make(chan *models.RequestLog, bufferCapacity), stopChan: make(chan struct{}), + ctx: ctx, + cancel: cancel, + lastFlushTime: time.Now(), } } @@ -44,93 +68,276 @@ func (s *DBLogWriterService) Start() { s.wg.Add(2) go s.eventListenerLoop() go s.dbWriterLoop() - s.logger.Info("DBLogWriterService started.") + + // 定期输出统计信息 + s.wg.Add(1) + go s.metricsReporter() + + s.logger.WithFields(logrus.Fields{ + "buffer_capacity": cap(s.logBuffer), + }).Info("DBLogWriterService started") } func (s *DBLogWriterService) Stop() { s.logger.Info("DBLogWriterService stopping...") close(s.stopChan) + s.cancel() // 取消上下文 s.wg.Wait() - s.logger.Info("DBLogWriterService stopped.") + + // 输出最终统计 + s.logger.WithFields(logrus.Fields{ + "total_received": s.totalReceived.Load(), + "total_flushed": s.totalFlushed.Load(), + "total_dropped": s.totalDropped.Load(), + "flush_count": s.flushCount.Load(), + }).Info("DBLogWriterService stopped") } +// 事件监听循环 func (s *DBLogWriterService) eventListenerLoop() { defer s.wg.Done() - ctx := context.Background() - sub, err := s.store.Subscribe(ctx, models.TopicRequestFinished) + sub, err := s.store.Subscribe(s.ctx, models.TopicRequestFinished) if err != nil { - s.logger.Fatalf("Failed to subscribe to topic %s: %v", models.TopicRequestFinished, err) + s.logger.WithError(err).Error("Failed to subscribe to request events, log writing disabled") return } - defer sub.Close() + defer func() { + if err := sub.Close(); err != nil { + s.logger.WithError(err).Warn("Failed to close subscription") + } + }() - s.logger.Info("Subscribed to request events for database logging.") + s.logger.Info("Subscribed to request events for database logging") for { select { case msg := <-sub.Channel(): - var event models.RequestFinishedEvent 
- if err := json.Unmarshal(msg.Payload, &event); err != nil { - s.logger.Errorf("Failed to unmarshal event for logging: %v", err) - continue - } - select { - case s.logBuffer <- &event.RequestLog: - default: - s.logger.Warn("Log buffer is full. A log message might be dropped.") - } + s.handleMessage(msg) + case <-s.stopChan: - s.logger.Info("Event listener loop stopping.") + s.logger.Info("Event listener loop stopping") + close(s.logBuffer) + return + + case <-s.ctx.Done(): + s.logger.Info("Event listener context cancelled") close(s.logBuffer) return } } } +// 处理单条消息 +func (s *DBLogWriterService) handleMessage(msg *store.Message) { + var event models.RequestFinishedEvent + if err := json.Unmarshal(msg.Payload, &event); err != nil { + s.logger.WithError(err).Error("Failed to unmarshal request event") + return + } + + s.totalReceived.Add(1) + + select { + case s.logBuffer <- &event.RequestLog: + // 成功入队 + default: + // 缓冲区满,丢弃日志 + dropped := s.totalDropped.Add(1) + if dropped%100 == 1 { // 每100条丢失输出一次警告 + s.logger.WithFields(logrus.Fields{ + "total_dropped": dropped, + "buffer_capacity": cap(s.logBuffer), + "buffer_len": len(s.logBuffer), + }).Warn("Log buffer full, messages being dropped") + } + } +} + +// 数据库写入循环 func (s *DBLogWriterService) dbWriterLoop() { defer s.wg.Done() - cfg := s.SettingsManager.GetSettings() + cfg := s.settingsManager.GetSettings() batchSize := cfg.LogFlushBatchSize if batchSize <= 0 { batchSize = 100 } - flushTimeout := time.Duration(cfg.LogFlushIntervalSeconds) * time.Second - if flushTimeout <= 0 { - flushTimeout = 5 * time.Second + flushInterval := time.Duration(cfg.LogFlushIntervalSeconds) * time.Second + if flushInterval <= 0 { + flushInterval = 5 * time.Second } + + s.logger.WithFields(logrus.Fields{ + "batch_size": batchSize, + "flush_interval": flushInterval, + }).Info("DB writer loop started") + batch := make([]*models.RequestLog, 0, batchSize) - ticker := time.NewTicker(flushTimeout) + ticker := time.NewTicker(flushInterval) defer ticker.Stop() + + // 配置热更新检查(每分钟) + configTicker := time.NewTicker(1 * time.Minute) + defer configTicker.Stop() + for { select { case logEntry, ok := <-s.logBuffer: if !ok { + // 通道关闭,刷新剩余日志 if len(batch) > 0 { s.flushBatch(batch) } - s.logger.Info("DB writer loop finished.") + s.logger.Info("DB writer loop finished") return } + batch = append(batch, logEntry) if len(batch) >= batchSize { s.flushBatch(batch) batch = make([]*models.RequestLog, 0, batchSize) } + case <-ticker.C: if len(batch) > 0 { s.flushBatch(batch) batch = make([]*models.RequestLog, 0, batchSize) } + + case <-configTicker.C: + // 热更新配置 + cfg := s.settingsManager.GetSettings() + newBatchSize := cfg.LogFlushBatchSize + if newBatchSize <= 0 { + newBatchSize = 100 + } + newFlushInterval := time.Duration(cfg.LogFlushIntervalSeconds) * time.Second + if newFlushInterval <= 0 { + newFlushInterval = 5 * time.Second + } + + if newBatchSize != batchSize { + s.logger.WithFields(logrus.Fields{ + "old": batchSize, + "new": newBatchSize, + }).Info("Batch size updated") + batchSize = newBatchSize + if len(batch) >= batchSize { + s.flushBatch(batch) + batch = make([]*models.RequestLog, 0, batchSize) + } + } + + if newFlushInterval != flushInterval { + s.logger.WithFields(logrus.Fields{ + "old": flushInterval, + "new": newFlushInterval, + }).Info("Flush interval updated") + flushInterval = newFlushInterval + ticker.Reset(flushInterval) + } } } } +// 批量刷写到数据库 func (s *DBLogWriterService) flushBatch(batch []*models.RequestLog) { - if err := s.db.CreateInBatches(batch, 
len(batch)).Error; err != nil { - s.logger.WithField("batch_size", len(batch)).WithError(err).Error("Failed to flush log batch to database.") + if len(batch) == 0 { + return + } + + start := time.Now() + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + + err := s.db.WithContext(ctx).CreateInBatches(batch, len(batch)).Error + duration := time.Since(start) + + s.lastFlushMutex.Lock() + s.lastFlushTime = time.Now() + s.lastFlushMutex.Unlock() + + if err != nil { + s.logger.WithFields(logrus.Fields{ + "batch_size": len(batch), + "duration": duration, + }).WithError(err).Error("Failed to flush log batch to database") } else { - s.logger.Infof("Successfully flushed %d logs to database.", len(batch)) + flushed := s.totalFlushed.Add(uint64(len(batch))) + flushCount := s.flushCount.Add(1) + + // 只在慢写入或大批量时输出日志 + if duration > 1*time.Second || len(batch) > 500 { + s.logger.WithFields(logrus.Fields{ + "batch_size": len(batch), + "duration": duration, + "total_flushed": flushed, + "flush_count": flushCount, + }).Info("Log batch flushed to database") + } else { + s.logger.WithFields(logrus.Fields{ + "batch_size": len(batch), + "duration": duration, + }).Debug("Log batch flushed to database") + } + } +} + +// 定期输出统计信息 +func (s *DBLogWriterService) metricsReporter() { + defer s.wg.Done() + + ticker := time.NewTicker(5 * time.Minute) + defer ticker.Stop() + + for { + select { + case <-ticker.C: + s.reportMetrics() + case <-s.stopChan: + return + case <-s.ctx.Done(): + return + } + } +} + +func (s *DBLogWriterService) reportMetrics() { + s.lastFlushMutex.RLock() + lastFlush := s.lastFlushTime + s.lastFlushMutex.RUnlock() + + received := s.totalReceived.Load() + flushed := s.totalFlushed.Load() + dropped := s.totalDropped.Load() + pending := uint64(len(s.logBuffer)) + + s.logger.WithFields(logrus.Fields{ + "received": received, + "flushed": flushed, + "dropped": dropped, + "pending": pending, + "flush_count": s.flushCount.Load(), + "last_flush": time.Since(lastFlush).Round(time.Second), + "buffer_usage": float64(pending) / float64(cap(s.logBuffer)) * 100, + "success_rate": float64(flushed) / float64(received) * 100, + }).Info("DBLogWriter metrics") +} + +// GetMetrics 返回当前统计指标(供监控使用) +func (s *DBLogWriterService) GetMetrics() map[string]interface{} { + s.lastFlushMutex.RLock() + lastFlush := s.lastFlushTime + s.lastFlushMutex.RUnlock() + + return map[string]interface{}{ + "total_received": s.totalReceived.Load(), + "total_flushed": s.totalFlushed.Load(), + "total_dropped": s.totalDropped.Load(), + "flush_count": s.flushCount.Load(), + "buffer_pending": len(s.logBuffer), + "buffer_capacity": cap(s.logBuffer), + "last_flush_ago": time.Since(lastFlush).Seconds(), } } diff --git a/internal/service/group_manager.go b/internal/service/group_manager.go index d8690dd..4ec10eb 100644 --- a/internal/service/group_manager.go +++ b/internal/service/group_manager.go @@ -334,7 +334,6 @@ func (gm *GroupManager) CloneKeyGroup(id uint) (*models.KeyGroup, error) { } func (gm *GroupManager) BuildOperationalConfig(group *models.KeyGroup) (*models.KeyGroupSettings, error) { globalSettings := gm.settingsManager.GetSettings() - defaultModel := "gemini-1.5-flash" opConfig := &models.KeyGroupSettings{ EnableKeyCheck: &globalSettings.EnableBaseKeyCheck, KeyCheckConcurrency: &globalSettings.BaseKeyCheckConcurrency, @@ -342,7 +341,7 @@ func (gm *GroupManager) BuildOperationalConfig(group *models.KeyGroup) (*models. 
KeyCheckEndpoint: &globalSettings.DefaultUpstreamURL, KeyBlacklistThreshold: &globalSettings.BlacklistThreshold, KeyCooldownMinutes: &globalSettings.KeyCooldownMinutes, - KeyCheckModel: &defaultModel, + KeyCheckModel: &globalSettings.BaseKeyCheckModel, MaxRetries: &globalSettings.MaxRetries, EnableSmartGateway: &globalSettings.EnableSmartGateway, } diff --git a/internal/service/healthcheck_service.go b/internal/service/healthcheck_service.go index 36030a2..f28aedc 100644 --- a/internal/service/healthcheck_service.go +++ b/internal/service/healthcheck_service.go @@ -1,4 +1,4 @@ -// Filename: internal/service/healthcheck_service.go (最终校准版) +// Filename: internal/service/healthcheck_service.go package service @@ -23,17 +23,30 @@ import ( ) const ( - ProxyCheckTargetURL = "https://www.google.com/generate_204" - DefaultGeminiEndpoint = "https://generativelanguage.googleapis.com/v1beta/models" - StatusActive = "active" - StatusInactive = "inactive" + // 业务状态常量 + StatusActive = "active" + StatusInactive = "inactive" + + // 代理检查目标(固定不变) + ProxyCheckTargetURL = "https://www.google.com/generate_204" + + // 并发控制边界 + minHealthCheckConcurrency = 1 + maxHealthCheckConcurrency = 100 + defaultKeyCheckConcurrency = 5 + defaultBaseKeyCheckConcurrency = 5 + + // 兜底默认值 + defaultSchedulerIntervalSeconds = 60 + defaultKeyCheckTimeoutSeconds = 30 + defaultBaseKeyCheckEndpoint = "https://generativelanguage.googleapis.com/v1beta/models" ) type HealthCheckServiceLogger struct{ *logrus.Entry } type HealthCheckService struct { db *gorm.DB - SettingsManager *settings.SettingsManager + settingsManager *settings.SettingsManager store store.Store keyRepo repository.KeyRepository groupManager *GroupManager @@ -46,6 +59,7 @@ type HealthCheckService struct { lastResults map[string]string groupCheckTimeMutex sync.Mutex groupNextCheckTime map[uint]time.Time + httpClient *http.Client } func NewHealthCheckService( @@ -60,7 +74,7 @@ func NewHealthCheckService( ) *HealthCheckService { return &HealthCheckService{ db: db, - SettingsManager: ss, + settingsManager: ss, store: s, keyRepo: repo, groupManager: gm, @@ -70,6 +84,14 @@ func NewHealthCheckService( stopChan: make(chan struct{}), lastResults: make(map[string]string), groupNextCheckTime: make(map[uint]time.Time), + httpClient: &http.Client{ + Transport: &http.Transport{ + MaxIdleConns: 100, + MaxIdleConnsPerHost: 10, + IdleConnTimeout: 90 * time.Second, + DisableKeepAlives: false, + }, + }, } } @@ -86,6 +108,7 @@ func (s *HealthCheckService) Stop() { s.logger.Info("Stopping HealthCheckService...") close(s.stopChan) s.wg.Wait() + s.httpClient.CloseIdleConnections() s.logger.Info("HealthCheckService stopped gracefully.") } @@ -99,10 +122,19 @@ func (s *HealthCheckService) GetLastHealthCheckResults() map[string]string { return resultsCopy } +// ==================== Key Check Loop ==================== + func (s *HealthCheckService) runKeyCheckLoop() { defer s.wg.Done() - s.logger.Info("Key check dynamic scheduler loop started.") - ticker := time.NewTicker(1 * time.Minute) + + settings := s.settingsManager.GetSettings() + schedulerInterval := time.Duration(settings.KeyCheckSchedulerIntervalSeconds) * time.Second + if schedulerInterval <= 0 { + schedulerInterval = time.Duration(defaultSchedulerIntervalSeconds) * time.Second + } + + s.logger.Infof("Key check dynamic scheduler loop started with interval: %v", schedulerInterval) + ticker := time.NewTicker(schedulerInterval) defer ticker.Stop() for { @@ -129,9 +161,11 @@ func (s *HealthCheckService) scheduleKeyChecks() { 
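Reviewer note, not part of the patch: several of the new loops (key check scheduler, upstream/proxy checks, base key check) repeat the same "configured seconds, or fall back to a default" guard before creating their ticker. A minimal standalone sketch of that pattern, using only the standard library and a hypothetical helper name:

package main

import (
	"fmt"
	"time"
)

// durationOrDefault is a hypothetical helper (not in the patch) capturing the
// recurring fallback pattern: a configured value in seconds, replaced by a
// sane default when it is zero or negative.
func durationOrDefault(seconds int, fallback time.Duration) time.Duration {
	if d := time.Duration(seconds) * time.Second; d > 0 {
		return d
	}
	return fallback
}

func main() {
	// Mirrors defaultSchedulerIntervalSeconds = 60 used by runKeyCheckLoop.
	fmt.Println(durationOrDefault(0, 60*time.Second))  // 1m0s
	fmt.Println(durationOrDefault(30, 60*time.Second)) // 30s
}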
s.logger.WithError(err).WithField("group_id", group.ID).Error("Failed to build operational config for group, skipping health check scheduling.") continue } + if opConfig.EnableKeyCheck == nil || !*opConfig.EnableKeyCheck { continue } + var intervalMinutes int if opConfig.KeyCheckIntervalMinutes != nil { intervalMinutes = *opConfig.KeyCheckIntervalMinutes @@ -140,64 +174,41 @@ func (s *HealthCheckService) scheduleKeyChecks() { if interval <= 0 { continue } + if nextCheckTime, ok := s.groupNextCheckTime[group.ID]; !ok || now.After(nextCheckTime) { - s.logger.Infof("Scheduling key check for group '%s' (ID: %d)", group.Name, group.ID) - go s.performKeyChecksForGroup(group, opConfig) + s.logger.WithFields(logrus.Fields{ + "group_id": group.ID, + "group_name": group.Name, + "interval": interval, + }).Info("Scheduling key check for group") + + // 创建带超时的上下文 + ctx, cancel := context.WithTimeout(context.Background(), interval) + go func(g *models.KeyGroup, cfg *models.KeyGroupSettings) { + defer cancel() + select { + case <-s.stopChan: + return + default: + s.performKeyChecksForGroup(ctx, g, cfg) + } + }(group, opConfig) + s.groupNextCheckTime[group.ID] = now.Add(interval) } } } -func (s *HealthCheckService) runUpstreamCheckLoop() { - defer s.wg.Done() - s.logger.Info("Upstream check loop started.") - if s.SettingsManager.GetSettings().EnableUpstreamCheck { - s.performUpstreamChecks() - } - ticker := time.NewTicker(time.Duration(s.SettingsManager.GetSettings().UpstreamCheckIntervalSeconds) * time.Second) - defer ticker.Stop() - - for { - select { - case <-ticker.C: - if s.SettingsManager.GetSettings().EnableUpstreamCheck { - s.logger.Debug("Upstream check ticker fired.") - s.performUpstreamChecks() - } - case <-s.stopChan: - s.logger.Info("Upstream check loop stopped.") - return - } - } -} - -func (s *HealthCheckService) runProxyCheckLoop() { - defer s.wg.Done() - s.logger.Info("Proxy check loop started.") - if s.SettingsManager.GetSettings().EnableProxyCheck { - s.performProxyChecks() - } - ticker := time.NewTicker(time.Duration(s.SettingsManager.GetSettings().ProxyCheckIntervalSeconds) * time.Second) - defer ticker.Stop() - - for { - select { - case <-ticker.C: - if s.SettingsManager.GetSettings().EnableProxyCheck { - s.logger.Debug("Proxy check ticker fired.") - s.performProxyChecks() - } - case <-s.stopChan: - s.logger.Info("Proxy check loop stopped.") - return - } - } -} - -func (s *HealthCheckService) performKeyChecksForGroup(group *models.KeyGroup, opConfig *models.KeyGroupSettings) { - ctx := context.Background() - settings := s.SettingsManager.GetSettings() +func (s *HealthCheckService) performKeyChecksForGroup( + ctx context.Context, + group *models.KeyGroup, + opConfig *models.KeyGroupSettings, +) { + settings := s.settingsManager.GetSettings() timeout := time.Duration(settings.KeyCheckTimeoutSeconds) * time.Second + if timeout <= 0 { + timeout = time.Duration(defaultKeyCheckTimeoutSeconds) * time.Second + } endpoint, err := s.groupManager.BuildKeyCheckEndpoint(group.ID) if err != nil { @@ -205,57 +216,83 @@ func (s *HealthCheckService) performKeyChecksForGroup(group *models.KeyGroup, op return } - log := s.logger.WithFields(logrus.Fields{"group_id": group.ID, "group_name": group.Name}) - log.Infof("Starting key health check cycle.") + log := s.logger.WithFields(logrus.Fields{ + "group_id": group.ID, + "group_name": group.Name, + }) + + log.Info("Starting key health check cycle") + var mappingsToCheck []models.GroupAPIKeyMapping err = 
s.db.WithContext(ctx).Model(&models.GroupAPIKeyMapping{}). Joins("JOIN api_keys ON api_keys.id = group_api_key_mappings.api_key_id"). Where("group_api_key_mappings.key_group_id = ?", group.ID). Where("api_keys.master_status = ?", models.MasterStatusActive). - Where("group_api_key_mappings.status IN ?", []models.APIKeyStatus{models.StatusActive, models.StatusDisabled, models.StatusCooldown}). + Where("group_api_key_mappings.status IN ?", []models.APIKeyStatus{ + models.StatusActive, + models.StatusDisabled, + models.StatusCooldown, + }). Preload("APIKey"). Find(&mappingsToCheck).Error if err != nil { - log.WithError(err).Error("Failed to fetch key mappings for health check.") + log.WithError(err).Error("Failed to fetch key mappings for health check") return } + if len(mappingsToCheck) == 0 { - log.Info("No key mappings to check for this group.") + log.Info("No key mappings to check for this group") return } - log.Infof("Starting health check for %d key mappings.", len(mappingsToCheck)) + + log.WithField("key_count", len(mappingsToCheck)).Info("Starting health check for key mappings") + jobs := make(chan models.GroupAPIKeyMapping, len(mappingsToCheck)) var wg sync.WaitGroup - var concurrency int - if opConfig.KeyCheckConcurrency != nil { - concurrency = *opConfig.KeyCheckConcurrency - } - if concurrency <= 0 { - concurrency = 1 - } + + concurrency := s.getConcurrency(opConfig.KeyCheckConcurrency, defaultKeyCheckConcurrency) + log.WithField("concurrency", concurrency).Debug("Using concurrency for key check") + for w := 1; w <= concurrency; w++ { wg.Add(1) go func(workerID int) { defer wg.Done() for mapping := range jobs { - s.checkAndProcessMapping(ctx, &mapping, timeout, endpoint) + select { + case <-ctx.Done(): + log.Warn("Context cancelled, stopping worker") + return + default: + s.checkAndProcessMapping(ctx, &mapping, timeout, endpoint) + } } }(w) } + for _, m := range mappingsToCheck { jobs <- m } close(jobs) wg.Wait() - log.Info("Finished key health check cycle.") + + log.Info("Finished key health check cycle") } -func (s *HealthCheckService) checkAndProcessMapping(ctx context.Context, mapping *models.GroupAPIKeyMapping, timeout time.Duration, endpoint string) { +func (s *HealthCheckService) checkAndProcessMapping( + ctx context.Context, + mapping *models.GroupAPIKeyMapping, + timeout time.Duration, + endpoint string, +) { if mapping.APIKey == nil { - s.logger.Warnf("Skipping check for mapping (G:%d, K:%d) because associated APIKey is nil.", mapping.KeyGroupID, mapping.APIKeyID) + s.logger.WithFields(logrus.Fields{ + "group_id": mapping.KeyGroupID, + "key_id": mapping.APIKeyID, + }).Warn("Skipping check for mapping because associated APIKey is nil") return } + validationErr := s.keyValidationService.ValidateSingleKey(mapping.APIKey, timeout, endpoint) if validationErr == nil { if mapping.Status != models.StatusActive { @@ -263,17 +300,28 @@ func (s *HealthCheckService) checkAndProcessMapping(ctx context.Context, mapping } return } + errorString := validationErr.Error() if CustomErrors.IsPermanentUpstreamError(errorString) { s.revokeMapping(ctx, mapping, validationErr) return } + if CustomErrors.IsTemporaryUpstreamError(errorString) { - s.logger.Warnf("Health check for key %d failed with temporary error, applying penalty. 
Reason: %v", mapping.APIKeyID, validationErr) + s.logger.WithFields(logrus.Fields{ + "key_id": mapping.APIKeyID, + "group_id": mapping.KeyGroupID, + "error": validationErr.Error(), + }).Warn("Health check failed with temporary error, applying penalty") s.penalizeMapping(ctx, mapping, validationErr) return } - s.logger.Errorf("Health check for key %d failed with a transient or unknown upstream error: %v. Mapping will not be penalized by health check.", mapping.APIKeyID, validationErr) + + s.logger.WithFields(logrus.Fields{ + "key_id": mapping.APIKeyID, + "group_id": mapping.KeyGroupID, + "error": validationErr.Error(), + }).Error("Health check failed with transient or unknown upstream error, mapping will not be penalized") } func (s *HealthCheckService) activateMapping(ctx context.Context, mapping *models.GroupAPIKeyMapping) { @@ -283,39 +331,64 @@ func (s *HealthCheckService) activateMapping(ctx context.Context, mapping *model mapping.LastError = "" if err := s.keyRepo.UpdateMapping(ctx, mapping); err != nil { - s.logger.WithError(err).Errorf("Failed to activate mapping (G:%d, K:%d).", mapping.KeyGroupID, mapping.APIKeyID) + s.logger.WithError(err).WithFields(logrus.Fields{ + "group_id": mapping.KeyGroupID, + "key_id": mapping.APIKeyID, + }).Error("Failed to activate mapping") return } - s.logger.Infof("Mapping (G:%d, K:%d) successfully activated from %s to ACTIVE.", mapping.KeyGroupID, mapping.APIKeyID, oldStatus) + + s.logger.WithFields(logrus.Fields{ + "group_id": mapping.KeyGroupID, + "key_id": mapping.APIKeyID, + "old_status": oldStatus, + "new_status": mapping.Status, + }).Info("Mapping successfully activated") + s.publishKeyStatusChangedEvent(ctx, mapping.KeyGroupID, mapping.APIKeyID, oldStatus, mapping.Status) } func (s *HealthCheckService) penalizeMapping(ctx context.Context, mapping *models.GroupAPIKeyMapping, err error) { group, ok := s.groupManager.GetGroupByID(mapping.KeyGroupID) if !ok { - s.logger.Errorf("Could not find group with ID %d to apply penalty.", mapping.KeyGroupID) + s.logger.WithField("group_id", mapping.KeyGroupID).Error("Could not find group to apply penalty") return } + opConfig, buildErr := s.groupManager.BuildOperationalConfig(group) if buildErr != nil { - s.logger.WithError(buildErr).Errorf("Failed to build operational config for group %d during penalty.", mapping.KeyGroupID) + s.logger.WithError(buildErr).WithField("group_id", mapping.KeyGroupID).Error("Failed to build operational config for group during penalty") return } + oldStatus := mapping.Status mapping.LastError = err.Error() mapping.ConsecutiveErrorCount++ + threshold := *opConfig.KeyBlacklistThreshold if mapping.ConsecutiveErrorCount >= threshold { mapping.Status = models.StatusCooldown cooldownDuration := time.Duration(*opConfig.KeyCooldownMinutes) * time.Minute cooldownTime := time.Now().Add(cooldownDuration) mapping.CooldownUntil = &cooldownTime - s.logger.Warnf("Mapping (G:%d, K:%d) reached error threshold (%d) during health check and is now in COOLDOWN for %v.", mapping.KeyGroupID, mapping.APIKeyID, threshold, cooldownDuration) + + s.logger.WithFields(logrus.Fields{ + "group_id": mapping.KeyGroupID, + "key_id": mapping.APIKeyID, + "error_count": mapping.ConsecutiveErrorCount, + "threshold": threshold, + "cooldown_duration": cooldownDuration, + }).Warn("Mapping reached error threshold and is now in COOLDOWN") } + if errDb := s.keyRepo.UpdateMapping(ctx, mapping); errDb != nil { - s.logger.WithError(errDb).Errorf("Failed to penalize mapping (G:%d, K:%d).", mapping.KeyGroupID, mapping.APIKeyID) 
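Reviewer note, illustrative only: the temporary-error path above increments ConsecutiveErrorCount and moves the mapping into COOLDOWN once the count reaches KeyBlacklistThreshold, for KeyCooldownMinutes. A self-contained sketch of just that decision, with the mapping and group-config fields flattened into plain parameters:

package main

import (
	"fmt"
	"time"
)

// applyPenalty is an illustrative, standalone version of the decision made in
// penalizeMapping: bump the consecutive error count and, once it reaches the
// group's blacklist threshold, compute the cooldown deadline.
func applyPenalty(consecutiveErrors, threshold, cooldownMinutes int, now time.Time) (int, *time.Time) {
	consecutiveErrors++
	if consecutiveErrors >= threshold {
		until := now.Add(time.Duration(cooldownMinutes) * time.Minute)
		return consecutiveErrors, &until
	}
	return consecutiveErrors, nil
}

func main() {
	now := time.Now()
	count, cooldownUntil := applyPenalty(2, 3, 10, now) // third consecutive failure trips the threshold
	fmt.Println(count, cooldownUntil.Sub(now))          // 3 10m0s
}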
+ s.logger.WithError(errDb).WithFields(logrus.Fields{ + "group_id": mapping.KeyGroupID, + "key_id": mapping.APIKeyID, + }).Error("Failed to penalize mapping") return } + if oldStatus != mapping.Status { s.publishKeyStatusChangedEvent(ctx, mapping.KeyGroupID, mapping.APIKeyID, oldStatus, mapping.Status) } @@ -326,144 +399,500 @@ func (s *HealthCheckService) revokeMapping(ctx context.Context, mapping *models. if oldStatus == models.StatusBanned { return } + mapping.Status = models.StatusBanned mapping.LastError = "Definitive error: " + err.Error() mapping.ConsecutiveErrorCount = 0 + if errDb := s.keyRepo.UpdateMapping(ctx, mapping); errDb != nil { - s.logger.WithError(errDb).Errorf("Failed to revoke mapping (G:%d, K:%d).", mapping.KeyGroupID, mapping.APIKeyID) + s.logger.WithError(errDb).WithFields(logrus.Fields{ + "group_id": mapping.KeyGroupID, + "key_id": mapping.APIKeyID, + }).Error("Failed to revoke mapping") return } - s.logger.Warnf("Mapping (G:%d, K:%d) has been BANNED due to a definitive error: %v", mapping.KeyGroupID, mapping.APIKeyID, err) + + s.logger.WithFields(logrus.Fields{ + "group_id": mapping.KeyGroupID, + "key_id": mapping.APIKeyID, + "error": err.Error(), + }).Warn("Mapping has been BANNED due to definitive error") + s.publishKeyStatusChangedEvent(ctx, mapping.KeyGroupID, mapping.APIKeyID, oldStatus, mapping.Status) - s.logger.Infof("Triggering MasterStatus update for definitively failed key ID %d.", mapping.APIKeyID) + + s.logger.WithField("key_id", mapping.APIKeyID).Info("Triggering MasterStatus update for definitively failed key") if err := s.keyRepo.UpdateAPIKeyStatus(ctx, mapping.APIKeyID, models.MasterStatusRevoked); err != nil { - s.logger.WithError(err).Errorf("Failed to update master status for key ID %d after group-level ban.", mapping.APIKeyID) + s.logger.WithError(err).WithField("key_id", mapping.APIKeyID).Error("Failed to update master status after group-level ban") + } +} + +// ==================== Upstream Check Loop ==================== + +func (s *HealthCheckService) runUpstreamCheckLoop() { + defer s.wg.Done() + s.logger.Info("Upstream check loop started") + + settings := s.settingsManager.GetSettings() + if settings.EnableUpstreamCheck { + s.performUpstreamChecks() + } + + interval := time.Duration(settings.UpstreamCheckIntervalSeconds) * time.Second + if interval <= 0 { + interval = 300 * time.Second // 5 分钟兜底 + } + + ticker := time.NewTicker(interval) + defer ticker.Stop() + + for { + select { + case <-ticker.C: + settings := s.settingsManager.GetSettings() + if settings.EnableUpstreamCheck { + s.logger.Debug("Upstream check ticker fired") + s.performUpstreamChecks() + } + case <-s.stopChan: + s.logger.Info("Upstream check loop stopped") + return + } } } func (s *HealthCheckService) performUpstreamChecks() { ctx := context.Background() - settings := s.SettingsManager.GetSettings() + settings := s.settingsManager.GetSettings() timeout := time.Duration(settings.UpstreamCheckTimeoutSeconds) * time.Second + if timeout <= 0 { + timeout = 10 * time.Second + } + var upstreams []*models.UpstreamEndpoint if err := s.db.WithContext(ctx).Find(&upstreams).Error; err != nil { - s.logger.WithError(err).Error("Failed to retrieve upstreams.") + s.logger.WithError(err).Error("Failed to retrieve upstreams") return } + if len(upstreams) == 0 { + s.logger.Debug("No upstreams configured for health check") return } - s.logger.Infof("Starting validation for %d upstreams.", len(upstreams)) + + s.logger.WithFields(logrus.Fields{ + "count": len(upstreams), + "timeout": 
timeout, + }).Info("Starting upstream validation") + + type checkResult struct { + upstreamID uint + url string + oldStatus string + newStatus string + changed bool + err error + } + + results := make([]checkResult, 0, len(upstreams)) + var resultsMutex sync.Mutex var wg sync.WaitGroup + for _, u := range upstreams { wg.Add(1) go func(upstream *models.UpstreamEndpoint) { defer wg.Done() + oldStatus := upstream.Status isAlive := s.checkEndpoint(upstream.URL, timeout) newStatus := StatusInactive if isAlive { newStatus = StatusActive } + s.lastResultsMutex.Lock() s.lastResults[upstream.URL] = newStatus s.lastResultsMutex.Unlock() - if oldStatus != newStatus { - s.logger.Infof("Upstream '%s' status changed from %s -> %s", upstream.URL, oldStatus, newStatus) + + result := checkResult{ + upstreamID: upstream.ID, + url: upstream.URL, + oldStatus: oldStatus, + newStatus: newStatus, + changed: oldStatus != newStatus, + } + + if result.changed { + s.logger.WithFields(logrus.Fields{ + "upstream_id": upstream.ID, + "url": upstream.URL, + "old_status": oldStatus, + "new_status": newStatus, + }).Info("Upstream status changed") + if err := s.db.WithContext(ctx).Model(upstream).Update("status", newStatus).Error; err != nil { - s.logger.WithError(err).WithField("upstream_id", upstream.ID).Error("Failed to update upstream status.") + s.logger.WithError(err).WithField("upstream_id", upstream.ID).Error("Failed to update upstream status") + result.err = err } else { s.publishUpstreamHealthChangedEvent(ctx, upstream, oldStatus, newStatus) } + } else { + s.logger.WithFields(logrus.Fields{ + "upstream_id": upstream.ID, + "url": upstream.URL, + "status": newStatus, + }).Debug("Upstream status unchanged") } + + resultsMutex.Lock() + results = append(results, result) + resultsMutex.Unlock() }(u) } wg.Wait() + + // 汇总统计 + activeCount := 0 + inactiveCount := 0 + changedCount := 0 + errorCount := 0 + + for _, r := range results { + if r.changed { + changedCount++ + } + if r.err != nil { + errorCount++ + } + if r.newStatus == StatusActive { + activeCount++ + } else { + inactiveCount++ + } + } + + s.logger.WithFields(logrus.Fields{ + "total": len(upstreams), + "active": activeCount, + "inactive": inactiveCount, + "changed": changedCount, + "errors": errorCount, + }).Info("Upstream validation cycle completed") } func (s *HealthCheckService) checkEndpoint(urlStr string, timeout time.Duration) bool { - client := http.Client{Timeout: timeout} - resp, err := client.Head(urlStr) + ctx, cancel := context.WithTimeout(context.Background(), timeout) + defer cancel() + + req, err := http.NewRequestWithContext(ctx, http.MethodHead, urlStr, nil) + if err != nil { + s.logger.WithError(err).WithField("url", urlStr).Debug("Failed to create request for endpoint check") + return false + } + + resp, err := s.httpClient.Do(req) if err != nil { return false } defer resp.Body.Close() + return resp.StatusCode < http.StatusInternalServerError } +// ==================== Proxy Check Loop ==================== + +func (s *HealthCheckService) runProxyCheckLoop() { + defer s.wg.Done() + s.logger.Info("Proxy check loop started") + + settings := s.settingsManager.GetSettings() + if settings.EnableProxyCheck { + s.performProxyChecks() + } + + interval := time.Duration(settings.ProxyCheckIntervalSeconds) * time.Second + if interval <= 0 { + interval = 600 * time.Second // 10 分钟兜底 + } + + ticker := time.NewTicker(interval) + defer ticker.Stop() + + for { + select { + case <-ticker.C: + settings := s.settingsManager.GetSettings() + if 
settings.EnableProxyCheck { + s.logger.Debug("Proxy check ticker fired") + s.performProxyChecks() + } + case <-s.stopChan: + s.logger.Info("Proxy check loop stopped") + return + } + } +} + func (s *HealthCheckService) performProxyChecks() { ctx := context.Background() - settings := s.SettingsManager.GetSettings() + settings := s.settingsManager.GetSettings() timeout := time.Duration(settings.ProxyCheckTimeoutSeconds) * time.Second + if timeout <= 0 { + timeout = 15 * time.Second + } + var proxies []*models.ProxyConfig if err := s.db.WithContext(ctx).Find(&proxies).Error; err != nil { - s.logger.WithError(err).Error("Failed to retrieve proxies.") + s.logger.WithError(err).Error("Failed to retrieve proxies") return } + if len(proxies) == 0 { + s.logger.Debug("No proxies configured for health check") return } - s.logger.Infof("Starting validation for %d proxies.", len(proxies)) + + s.logger.WithFields(logrus.Fields{ + "count": len(proxies), + "timeout": timeout, + }).Info("Starting proxy validation") + + activeCount := 0 + inactiveCount := 0 + changedCount := 0 + var statsMutex sync.Mutex var wg sync.WaitGroup + for _, p := range proxies { wg.Add(1) go func(proxyCfg *models.ProxyConfig) { defer wg.Done() + + oldStatus := proxyCfg.Status isAlive := s.checkProxy(proxyCfg, timeout) newStatus := StatusInactive if isAlive { newStatus = StatusActive } + s.lastResultsMutex.Lock() s.lastResults[proxyCfg.Address] = newStatus s.lastResultsMutex.Unlock() - if proxyCfg.Status != newStatus { - s.logger.Infof("Proxy '%s' status changed from %s -> %s", proxyCfg.Address, proxyCfg.Status, newStatus) + + if oldStatus != newStatus { + s.logger.WithFields(logrus.Fields{ + "proxy_id": proxyCfg.ID, + "address": proxyCfg.Address, + "old_status": oldStatus, + "new_status": newStatus, + }).Info("Proxy status changed") + if err := s.db.WithContext(ctx).Model(proxyCfg).Update("status", newStatus).Error; err != nil { - s.logger.WithError(err).WithField("proxy_id", proxyCfg.ID).Error("Failed to update proxy status.") + s.logger.WithError(err).WithField("proxy_id", proxyCfg.ID).Error("Failed to update proxy status") } + + statsMutex.Lock() + changedCount++ + if newStatus == StatusActive { + activeCount++ + } else { + inactiveCount++ + } + statsMutex.Unlock() + } else { + s.logger.WithFields(logrus.Fields{ + "proxy_id": proxyCfg.ID, + "address": proxyCfg.Address, + "status": newStatus, + }).Debug("Proxy status unchanged") + + statsMutex.Lock() + if newStatus == StatusActive { + activeCount++ + } else { + inactiveCount++ + } + statsMutex.Unlock() } }(p) } wg.Wait() + + s.logger.WithFields(logrus.Fields{ + "total": len(proxies), + "active": activeCount, + "inactive": inactiveCount, + "changed": changedCount, + }).Info("Proxy validation cycle completed") } func (s *HealthCheckService) checkProxy(proxyCfg *models.ProxyConfig, timeout time.Duration) bool { transport := &http.Transport{} + switch proxyCfg.Protocol { case "http", "https": proxyUrl, err := url.Parse(proxyCfg.Protocol + "://" + proxyCfg.Address) if err != nil { - s.logger.WithError(err).WithField("proxy_address", proxyCfg.Address).Warn("Invalid proxy URL format.") + s.logger.WithError(err).WithField("proxy_address", proxyCfg.Address).Warn("Invalid proxy URL format") return false } transport.Proxy = http.ProxyURL(proxyUrl) + case "socks5": dialer, err := proxy.SOCKS5("tcp", proxyCfg.Address, nil, proxy.Direct) if err != nil { - s.logger.WithError(err).WithField("proxy_address", proxyCfg.Address).Warn("Failed to create SOCKS5 dialer.") + 
s.logger.WithError(err).WithField("proxy_address", proxyCfg.Address).Warn("Failed to create SOCKS5 dialer") return false } transport.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) { return dialer.Dial(network, addr) } + default: - s.logger.WithField("protocol", proxyCfg.Protocol).Warn("Unsupported proxy protocol.") + s.logger.WithField("protocol", proxyCfg.Protocol).Warn("Unsupported proxy protocol") return false } + client := &http.Client{ Transport: transport, Timeout: timeout, } + resp, err := client.Get(ProxyCheckTargetURL) if err != nil { return false } defer resp.Body.Close() + return true } -func (s *HealthCheckService) publishKeyStatusChangedEvent(ctx context.Context, groupID, keyID uint, oldStatus, newStatus models.APIKeyStatus) { +// ==================== Base Key Check Loop ==================== + +func (s *HealthCheckService) runBaseKeyCheckLoop() { + defer s.wg.Done() + s.logger.Info("Global base key check loop started") + + settings := s.settingsManager.GetSettings() + if !settings.EnableBaseKeyCheck { + s.logger.Info("Global base key check is disabled") + return + } + + // 启动时执行一次 + s.performBaseKeyChecks() + + interval := time.Duration(settings.BaseKeyCheckIntervalMinutes) * time.Minute + if interval <= 0 { + s.logger.WithField("interval", settings.BaseKeyCheckIntervalMinutes).Warn("Invalid BaseKeyCheckIntervalMinutes, disabling base key check loop") + return + } + + ticker := time.NewTicker(interval) + defer ticker.Stop() + + for { + select { + case <-ticker.C: + s.performBaseKeyChecks() + case <-s.stopChan: + s.logger.Info("Global base key check loop stopped") + return + } + } +} + +func (s *HealthCheckService) performBaseKeyChecks() { + ctx := context.Background() + s.logger.Info("Starting global base key check cycle") + + settings := s.settingsManager.GetSettings() + timeout := time.Duration(settings.KeyCheckTimeoutSeconds) * time.Second + if timeout <= 0 { + timeout = time.Duration(defaultKeyCheckTimeoutSeconds) * time.Second + } + + endpoint := settings.BaseKeyCheckEndpoint + if endpoint == "" { + endpoint = defaultBaseKeyCheckEndpoint + s.logger.WithField("endpoint", endpoint).Debug("Using default base key check endpoint") + } + + concurrency := settings.BaseKeyCheckConcurrency + if concurrency <= 0 { + concurrency = defaultBaseKeyCheckConcurrency + } + concurrency = s.ensureConcurrencyBounds(concurrency) + + keys, err := s.keyRepo.GetActiveMasterKeys() + if err != nil { + s.logger.WithError(err).Error("Failed to fetch active master keys for base check") + return + } + + if len(keys) == 0 { + s.logger.Info("No active master keys to perform base check on") + return + } + + s.logger.WithFields(logrus.Fields{ + "key_count": len(keys), + "concurrency": concurrency, + "endpoint": endpoint, + }).Info("Performing base check on active master keys") + + jobs := make(chan *models.APIKey, len(keys)) + var wg sync.WaitGroup + + for w := 0; w < concurrency; w++ { + wg.Add(1) + go func(workerID int) { + defer wg.Done() + for key := range jobs { + select { + case <-s.stopChan: + return + default: + err := s.keyValidationService.ValidateSingleKey(key, timeout, endpoint) + if err != nil && CustomErrors.IsPermanentUpstreamError(err.Error()) { + oldStatus := key.MasterStatus + keyPrefix := key.APIKey + if len(keyPrefix) > 8 { + keyPrefix = keyPrefix[:8] + } + + s.logger.WithFields(logrus.Fields{ + "key_id": key.ID, + "key_prefix": keyPrefix + "...", + "error": err.Error(), + }).Warn("Key failed definitive base check, setting MasterStatus to REVOKED") + + 
if updateErr := s.keyRepo.UpdateAPIKeyStatus(ctx, key.ID, models.MasterStatusRevoked); updateErr != nil { + s.logger.WithError(updateErr).WithField("key_id", key.ID).Error("Failed to update master status") + } else { + s.publishMasterKeyStatusChangedEvent(ctx, key.ID, oldStatus, models.MasterStatusRevoked) + } + } + } + } + }(w) + } + + for _, key := range keys { + jobs <- key + } + close(jobs) + wg.Wait() + + s.logger.Info("Global base key check cycle finished") +} + +// ==================== Event Publishing ==================== + +func (s *HealthCheckService) publishKeyStatusChangedEvent( + ctx context.Context, + groupID, keyID uint, + oldStatus, newStatus models.APIKeyStatus, +) { event := models.KeyStatusChangedEvent{ KeyID: keyID, GroupID: groupID, @@ -472,17 +901,23 @@ func (s *HealthCheckService) publishKeyStatusChangedEvent(ctx context.Context, g ChangeReason: "health_check", ChangedAt: time.Now(), } + payload, err := json.Marshal(event) if err != nil { - s.logger.WithError(err).Errorf("Failed to marshal KeyStatusChangedEvent for group %d.", groupID) + s.logger.WithError(err).WithField("group_id", groupID).Error("Failed to marshal KeyStatusChangedEvent") return } + if err := s.store.Publish(ctx, models.TopicKeyStatusChanged, payload); err != nil { - s.logger.WithError(err).Errorf("Failed to publish KeyStatusChangedEvent for group %d.", groupID) + s.logger.WithError(err).WithField("group_id", groupID).Error("Failed to publish KeyStatusChangedEvent") } } -func (s *HealthCheckService) publishUpstreamHealthChangedEvent(ctx context.Context, upstream *models.UpstreamEndpoint, oldStatus, newStatus string) { +func (s *HealthCheckService) publishUpstreamHealthChangedEvent( + ctx context.Context, + upstream *models.UpstreamEndpoint, + oldStatus, newStatus string, +) { event := models.UpstreamHealthChangedEvent{ UpstreamID: upstream.ID, UpstreamURL: upstream.URL, @@ -492,93 +927,23 @@ func (s *HealthCheckService) publishUpstreamHealthChangedEvent(ctx context.Conte Reason: "health_check", CheckedAt: time.Now(), } + payload, err := json.Marshal(event) if err != nil { - s.logger.WithError(err).Error("Failed to marshal UpstreamHealthChangedEvent.") + s.logger.WithError(err).Error("Failed to marshal UpstreamHealthChangedEvent") return } + if err := s.store.Publish(ctx, models.TopicUpstreamHealthChanged, payload); err != nil { - s.logger.WithError(err).Error("Failed to publish UpstreamHealthChangedEvent.") + s.logger.WithError(err).Error("Failed to publish UpstreamHealthChangedEvent") } } -func (s *HealthCheckService) runBaseKeyCheckLoop() { - defer s.wg.Done() - s.logger.Info("Global base key check loop started.") - settings := s.SettingsManager.GetSettings() - if !settings.EnableBaseKeyCheck { - s.logger.Info("Global base key check is disabled.") - return - } - s.performBaseKeyChecks() - interval := time.Duration(settings.BaseKeyCheckIntervalMinutes) * time.Minute - if interval <= 0 { - s.logger.Warnf("Invalid BaseKeyCheckIntervalMinutes: %d. 
Disabling base key check loop.", settings.BaseKeyCheckIntervalMinutes) - return - } - ticker := time.NewTicker(interval) - defer ticker.Stop() - - for { - select { - case <-ticker.C: - s.performBaseKeyChecks() - case <-s.stopChan: - s.logger.Info("Global base key check loop stopped.") - return - } - } -} - -func (s *HealthCheckService) performBaseKeyChecks() { - ctx := context.Background() - s.logger.Info("Starting global base key check cycle.") - settings := s.SettingsManager.GetSettings() - timeout := time.Duration(settings.KeyCheckTimeoutSeconds) * time.Second - endpoint := settings.BaseKeyCheckEndpoint - concurrency := settings.BaseKeyCheckConcurrency - keys, err := s.keyRepo.GetActiveMasterKeys() - if err != nil { - s.logger.WithError(err).Error("Failed to fetch active master keys for base check.") - return - } - if len(keys) == 0 { - s.logger.Info("No active master keys to perform base check on.") - return - } - s.logger.Infof("Performing base check on %d active master keys.", len(keys)) - jobs := make(chan *models.APIKey, len(keys)) - var wg sync.WaitGroup - if concurrency <= 0 { - concurrency = 5 - } - for w := 0; w < concurrency; w++ { - wg.Add(1) - go func() { - defer wg.Done() - for key := range jobs { - err := s.keyValidationService.ValidateSingleKey(key, timeout, endpoint) - if err != nil && CustomErrors.IsPermanentUpstreamError(err.Error()) { - oldStatus := key.MasterStatus - s.logger.Warnf("Key ID %d (%s...) failed definitive base check. Setting MasterStatus to REVOKED. Reason: %v", key.ID, key.APIKey[:4], err) - if updateErr := s.keyRepo.UpdateAPIKeyStatus(ctx, key.ID, models.MasterStatusRevoked); updateErr != nil { - s.logger.WithError(updateErr).Errorf("Failed to update master status for key ID %d.", key.ID) - } else { - s.publishMasterKeyStatusChangedEvent(ctx, key.ID, oldStatus, models.MasterStatusRevoked) - } - } - } - }() - } - for _, key := range keys { - jobs <- key - } - close(jobs) - wg.Wait() - s.logger.Info("Global base key check cycle finished.") -} - -func (s *HealthCheckService) publishMasterKeyStatusChangedEvent(ctx context.Context, keyID uint, oldStatus, newStatus models.MasterAPIKeyStatus) { +func (s *HealthCheckService) publishMasterKeyStatusChangedEvent( + ctx context.Context, + keyID uint, + oldStatus, newStatus models.MasterAPIKeyStatus, +) { event := models.MasterKeyStatusChangedEvent{ KeyID: keyID, OldMasterStatus: oldStatus, @@ -586,12 +951,47 @@ func (s *HealthCheckService) publishMasterKeyStatusChangedEvent(ctx context.Cont ChangeReason: "base_health_check", ChangedAt: time.Now(), } + payload, err := json.Marshal(event) if err != nil { - s.logger.WithError(err).Errorf("Failed to marshal MasterKeyStatusChangedEvent for key %d.", keyID) + s.logger.WithError(err).WithField("key_id", keyID).Error("Failed to marshal MasterKeyStatusChangedEvent") return } + if err := s.store.Publish(ctx, models.TopicMasterKeyStatusChanged, payload); err != nil { - s.logger.WithError(err).Errorf("Failed to publish MasterKeyStatusChangedEvent for key %d.", keyID) + s.logger.WithError(err).WithField("key_id", keyID).Error("Failed to publish MasterKeyStatusChangedEvent") } } + +// ==================== Helper Methods ==================== + +func (s *HealthCheckService) getConcurrency(configValue *int, defaultValue int) int { + var concurrency int + if configValue != nil && *configValue > 0 { + concurrency = *configValue + } else { + concurrency = defaultValue + } + + return s.ensureConcurrencyBounds(concurrency) +} + +func (s *HealthCheckService) 
ensureConcurrencyBounds(concurrency int) int { + if concurrency < minHealthCheckConcurrency { + s.logger.WithFields(logrus.Fields{ + "requested": concurrency, + "minimum": minHealthCheckConcurrency, + }).Debug("Concurrency below minimum, adjusting") + return minHealthCheckConcurrency + } + + if concurrency > maxHealthCheckConcurrency { + s.logger.WithFields(logrus.Fields{ + "requested": concurrency, + "maximum": maxHealthCheckConcurrency, + }).Warn("Concurrency exceeds maximum, capping it") + return maxHealthCheckConcurrency + } + + return concurrency +} diff --git a/internal/service/key_import_service.go b/internal/service/key_import_service.go index fca6c7c..9f178c8 100644 --- a/internal/service/key_import_service.go +++ b/internal/service/key_import_service.go @@ -23,6 +23,10 @@ const ( TaskTypeHardDeleteKeys = "hard_delete_keys" TaskTypeRestoreKeys = "restore_keys" chunkSize = 500 + + // 任务超时时间常量化 + defaultTaskTimeout = 15 * time.Minute + longTaskTimeout = time.Hour ) type KeyImportService struct { @@ -43,17 +47,19 @@ func NewKeyImportService(ts task.Reporter, kr repository.KeyRepository, s store. } } +// runTaskWithRecovery 统一的任务恢复包装器 func (s *KeyImportService) runTaskWithRecovery(ctx context.Context, taskID string, resourceID string, taskFunc func()) { defer func() { if r := recover(); r != nil { err := fmt.Errorf("panic recovered in task %s: %v", taskID, r) - s.logger.Error(err) + s.logger.WithField("task_id", taskID).Error(err) s.taskService.EndTaskByID(ctx, taskID, resourceID, nil, err) } }() taskFunc() } +// StartAddKeysTask 启动批量添加密钥任务 func (s *KeyImportService) StartAddKeysTask(ctx context.Context, groupID uint, keysText string, validateOnImport bool) (*task.Status, error) { keys := utils.ParseKeysFromText(keysText) if len(keys) == 0 { @@ -61,260 +67,404 @@ func (s *KeyImportService) StartAddKeysTask(ctx context.Context, groupID uint, k } resourceID := fmt.Sprintf("group-%d", groupID) - taskStatus, err := s.taskService.StartTask(ctx, uint(groupID), TaskTypeAddKeysToGroup, resourceID, len(keys), 15*time.Minute) + taskStatus, err := s.taskService.StartTask(ctx, uint(groupID), TaskTypeAddKeysToGroup, resourceID, len(keys), defaultTaskTimeout) if err != nil { return nil, err } + go s.runTaskWithRecovery(context.Background(), taskStatus.ID, resourceID, func() { s.runAddKeysTask(context.Background(), taskStatus.ID, resourceID, groupID, keys, validateOnImport) }) + return taskStatus, nil } +// StartUnlinkKeysTask 启动批量解绑密钥任务 func (s *KeyImportService) StartUnlinkKeysTask(ctx context.Context, groupID uint, keysText string) (*task.Status, error) { keys := utils.ParseKeysFromText(keysText) if len(keys) == 0 { return nil, fmt.Errorf("no valid keys found") } + resourceID := fmt.Sprintf("group-%d", groupID) - taskStatus, err := s.taskService.StartTask(ctx, uint(groupID), TaskTypeUnlinkKeysFromGroup, resourceID, len(keys), time.Hour) + taskStatus, err := s.taskService.StartTask(ctx, uint(groupID), TaskTypeUnlinkKeysFromGroup, resourceID, len(keys), longTaskTimeout) if err != nil { return nil, err } + go s.runTaskWithRecovery(context.Background(), taskStatus.ID, resourceID, func() { s.runUnlinkKeysTask(context.Background(), taskStatus.ID, resourceID, groupID, keys) }) + return taskStatus, nil } +// StartHardDeleteKeysTask 启动硬删除密钥任务 func (s *KeyImportService) StartHardDeleteKeysTask(ctx context.Context, keysText string) (*task.Status, error) { keys := utils.ParseKeysFromText(keysText) if len(keys) == 0 { return nil, fmt.Errorf("no valid keys found") } + resourceID := "global_hard_delete" - 
taskStatus, err := s.taskService.StartTask(ctx, 0, TaskTypeHardDeleteKeys, resourceID, len(keys), time.Hour) + taskStatus, err := s.taskService.StartTask(ctx, 0, TaskTypeHardDeleteKeys, resourceID, len(keys), longTaskTimeout) if err != nil { return nil, err } + go s.runTaskWithRecovery(context.Background(), taskStatus.ID, resourceID, func() { s.runHardDeleteKeysTask(context.Background(), taskStatus.ID, resourceID, keys) }) + return taskStatus, nil } +// StartRestoreKeysTask 启动恢复密钥任务 func (s *KeyImportService) StartRestoreKeysTask(ctx context.Context, keysText string) (*task.Status, error) { keys := utils.ParseKeysFromText(keysText) if len(keys) == 0 { return nil, fmt.Errorf("no valid keys found") } - resourceID := "global_restore_keys" - taskStatus, err := s.taskService.StartTask(ctx, 0, TaskTypeRestoreKeys, resourceID, len(keys), time.Hour) + resourceID := "global_restore_keys" + taskStatus, err := s.taskService.StartTask(ctx, 0, TaskTypeRestoreKeys, resourceID, len(keys), longTaskTimeout) if err != nil { return nil, err } + go s.runTaskWithRecovery(context.Background(), taskStatus.ID, resourceID, func() { s.runRestoreKeysTask(context.Background(), taskStatus.ID, resourceID, keys) }) + return taskStatus, nil } -func (s *KeyImportService) runAddKeysTask(ctx context.Context, taskID string, resourceID string, groupID uint, keys []string, validateOnImport bool) { - uniqueKeysMap := make(map[string]struct{}) - var uniqueKeyStrings []string - for _, kStr := range keys { - if _, exists := uniqueKeysMap[kStr]; !exists { - uniqueKeysMap[kStr] = struct{}{} - uniqueKeyStrings = append(uniqueKeyStrings, kStr) - } - } - if len(uniqueKeyStrings) == 0 { - s.taskService.EndTaskByID(ctx, taskID, resourceID, gin.H{"newly_linked_count": 0, "already_linked_count": 0}, nil) - return - } - keysToEnsure := make([]models.APIKey, len(uniqueKeyStrings)) - for i, keyStr := range uniqueKeyStrings { - keysToEnsure[i] = models.APIKey{APIKey: keyStr} - } - allKeyModels, err := s.keyRepo.AddKeys(keysToEnsure) +// StartUnlinkKeysByFilterTask 根据状态过滤条件批量解绑 +func (s *KeyImportService) StartUnlinkKeysByFilterTask(ctx context.Context, groupID uint, statuses []string) (*task.Status, error) { + s.logger.Infof("Starting unlink task for group %d with status filter: %v", groupID, statuses) + + keyValues, err := s.keyRepo.FindKeyValuesByStatus(groupID, statuses) if err != nil { - s.taskService.EndTaskByID(ctx, taskID, resourceID, nil, fmt.Errorf("failed to ensure keys exist: %w", err)) - return + return nil, fmt.Errorf("failed to find keys by filter: %w", err) } - alreadyLinkedModels, err := s.keyRepo.GetKeysByValuesAndGroupID(uniqueKeyStrings, groupID) - if err != nil { - s.taskService.EndTaskByID(ctx, taskID, resourceID, nil, fmt.Errorf("failed to check for already linked keys: %w", err)) - return - } - alreadyLinkedIDSet := make(map[uint]struct{}) - for _, key := range alreadyLinkedModels { - alreadyLinkedIDSet[key.ID] = struct{}{} - } - var keysToLink []models.APIKey - for _, key := range allKeyModels { - if _, exists := alreadyLinkedIDSet[key.ID]; !exists { - keysToLink = append(keysToLink, key) - } + if len(keyValues) == 0 { + return nil, fmt.Errorf("no keys found matching the provided filter") } + keysAsText := strings.Join(keyValues, "\n") + s.logger.Infof("Found %d keys to unlink for group %d.", len(keyValues), groupID) + + return s.StartUnlinkKeysTask(ctx, groupID, keysAsText) +} + +// ==================== 核心任务执行逻辑 ==================== + +// runAddKeysTask 执行批量添加密钥 +func (s *KeyImportService) runAddKeysTask(ctx 
context.Context, taskID string, resourceID string, groupID uint, keys []string, validateOnImport bool) { + // 1. 去重 + uniqueKeys := s.deduplicateKeys(keys) + if len(uniqueKeys) == 0 { + s.endTaskWithResult(ctx, taskID, resourceID, gin.H{ + "newly_linked_count": 0, + "already_linked_count": 0, + }, nil) + return + } + + // 2. 确保所有密钥在数据库中存在(幂等操作) + allKeyModels, err := s.ensureKeysExist(uniqueKeys) + if err != nil { + s.endTaskWithResult(ctx, taskID, resourceID, nil, fmt.Errorf("failed to ensure keys exist: %w", err)) + return + } + + // 3. 过滤已关联的密钥 + keysToLink, alreadyLinkedCount, err := s.filterNewKeys(allKeyModels, groupID, uniqueKeys) + if err != nil { + s.endTaskWithResult(ctx, taskID, resourceID, nil, fmt.Errorf("failed to check linked keys: %w", err)) + return + } + + // 4. 更新任务的实际处理总数 if err := s.taskService.UpdateTotalByID(ctx, taskID, len(keysToLink)); err != nil { s.logger.WithError(err).Warnf("Failed to update total for task %s", taskID) } + + // 5. 批量关联密钥到组 if len(keysToLink) > 0 { - idsToLink := make([]uint, len(keysToLink)) - for i, key := range keysToLink { - idsToLink[i] = key.ID + if err := s.linkKeysInChunks(ctx, taskID, groupID, keysToLink); err != nil { + s.endTaskWithResult(ctx, taskID, resourceID, nil, err) + return } - for i := 0; i < len(idsToLink); i += chunkSize { - end := i + chunkSize - if end > len(idsToLink) { - end = len(idsToLink) - } - chunk := idsToLink[i:end] - if err := s.keyRepo.LinkKeysToGroup(ctx, groupID, chunk); err != nil { - s.taskService.EndTaskByID(ctx, taskID, resourceID, nil, fmt.Errorf("chunk failed to link keys: %w", err)) - return - } - _ = s.taskService.UpdateProgressByID(ctx, taskID, i+len(chunk)) + } + + // 6. 根据验证标志处理密钥状态 + if len(keysToLink) > 0 { + s.processNewlyLinkedKeys(ctx, groupID, keysToLink, validateOnImport) + } + + // 7. 返回结果 + result := gin.H{ + "newly_linked_count": len(keysToLink), + "already_linked_count": alreadyLinkedCount, + "total_linked_count": len(allKeyModels), + } + s.endTaskWithResult(ctx, taskID, resourceID, result, nil) +} + +// runUnlinkKeysTask 执行批量解绑密钥 +func (s *KeyImportService) runUnlinkKeysTask(ctx context.Context, taskID, resourceID string, groupID uint, keys []string) { + // 1. 去重 + uniqueKeys := s.deduplicateKeys(keys) + + // 2. 查找需要解绑的密钥 + keysToUnlink, err := s.keyRepo.GetKeysByValuesAndGroupID(uniqueKeys, groupID) + if err != nil { + s.endTaskWithResult(ctx, taskID, resourceID, nil, fmt.Errorf("failed to find keys to unlink: %w", err)) + return + } + + if len(keysToUnlink) == 0 { + result := gin.H{ + "unlinked_count": 0, + "hard_deleted_count": 0, + "not_found_count": len(uniqueKeys), } + s.endTaskWithResult(ctx, taskID, resourceID, result, nil) + return + } + + // 3. 提取密钥 ID + idsToUnlink := s.extractKeyIDs(keysToUnlink) + + // 4. 更新任务总数 + if err := s.taskService.UpdateTotalByID(ctx, taskID, len(idsToUnlink)); err != nil { + s.logger.WithError(err).Warnf("Failed to update total for task %s", taskID) + } + + // 5. 批量解绑 + totalUnlinked, err := s.unlinkKeysInChunks(ctx, taskID, groupID, idsToUnlink) + if err != nil { + s.endTaskWithResult(ctx, taskID, resourceID, nil, err) + return + } + + // 6. 清理孤立密钥 + totalDeleted, err := s.keyRepo.DeleteOrphanKeys() + if err != nil { + s.logger.WithError(err).Warn("Orphan key cleanup failed after unlink task.") + } + + // 7. 
返回结果 + result := gin.H{ + "unlinked_count": totalUnlinked, + "hard_deleted_count": totalDeleted, + "not_found_count": len(uniqueKeys) - int(totalUnlinked), + } + s.endTaskWithResult(ctx, taskID, resourceID, result, nil) +} + +// runHardDeleteKeysTask 执行硬删除密钥 +func (s *KeyImportService) runHardDeleteKeysTask(ctx context.Context, taskID, resourceID string, keys []string) { + totalDeleted, err := s.processKeysInChunks(ctx, taskID, keys, func(chunk []string) (int64, error) { + return s.keyRepo.HardDeleteByValues(chunk) + }) + + if err != nil { + s.endTaskWithResult(ctx, taskID, resourceID, nil, err) + return } result := gin.H{ - "newly_linked_count": len(keysToLink), - "already_linked_count": len(alreadyLinkedIDSet), - "total_linked_count": len(allKeyModels), + "hard_deleted_count": totalDeleted, + "not_found_count": int64(len(keys)) - totalDeleted, } - if len(keysToLink) > 0 { - idsToLink := make([]uint, len(keysToLink)) - for i, key := range keysToLink { - idsToLink[i] = key.ID - } - if validateOnImport { - s.publishImportGroupCompletedEvent(ctx, groupID, idsToLink) - for _, keyID := range idsToLink { - s.publishSingleKeyChangeEvent(ctx, groupID, keyID, "", models.StatusPendingValidation, "key_linked") - } - } else { - for _, keyID := range idsToLink { - if _, err := s.apiKeyService.UpdateMappingStatus(ctx, groupID, keyID, models.StatusActive); err != nil { - s.logger.Errorf("Failed to directly activate key ID %d in group %d: %v", keyID, groupID, err) - } - } - } - } - s.taskService.EndTaskByID(ctx, taskID, resourceID, result, nil) + s.endTaskWithResult(ctx, taskID, resourceID, result, nil) + s.publishChangeEvent(ctx, 0, "keys_hard_deleted") } -func (s *KeyImportService) runUnlinkKeysTask(ctx context.Context, taskID, resourceID string, groupID uint, keys []string) { - uniqueKeysMap := make(map[string]struct{}) - var uniqueKeys []string +// runRestoreKeysTask 执行恢复密钥 +func (s *KeyImportService) runRestoreKeysTask(ctx context.Context, taskID, resourceID string, keys []string) { + restoredCount, err := s.processKeysInChunks(ctx, taskID, keys, func(chunk []string) (int64, error) { + return s.keyRepo.UpdateMasterStatusByValues(chunk, models.MasterStatusActive) + }) + + if err != nil { + s.endTaskWithResult(ctx, taskID, resourceID, nil, err) + return + } + + result := gin.H{ + "restored_count": restoredCount, + "not_found_count": int64(len(keys)) - restoredCount, + } + s.endTaskWithResult(ctx, taskID, resourceID, result, nil) + s.publishChangeEvent(ctx, 0, "keys_bulk_restored") +} + +// ==================== 辅助方法 ==================== + +// deduplicateKeys 去重密钥列表 +func (s *KeyImportService) deduplicateKeys(keys []string) []string { + uniqueKeysMap := make(map[string]struct{}, len(keys)) + uniqueKeys := make([]string, 0, len(keys)) + for _, kStr := range keys { if _, exists := uniqueKeysMap[kStr]; !exists { uniqueKeysMap[kStr] = struct{}{} uniqueKeys = append(uniqueKeys, kStr) } } - keysToUnlink, err := s.keyRepo.GetKeysByValuesAndGroupID(uniqueKeys, groupID) + return uniqueKeys +} + +// ensureKeysExist 确保所有密钥在数据库中存在 +func (s *KeyImportService) ensureKeysExist(keys []string) ([]models.APIKey, error) { + keysToEnsure := make([]models.APIKey, len(keys)) + for i, keyStr := range keys { + keysToEnsure[i] = models.APIKey{APIKey: keyStr} + } + return s.keyRepo.AddKeys(keysToEnsure) +} + +// filterNewKeys 过滤已关联的密钥,返回需要新增的密钥 +func (s *KeyImportService) filterNewKeys(allKeyModels []models.APIKey, groupID uint, uniqueKeys []string) ([]models.APIKey, int, error) { + alreadyLinkedModels, err := 
s.keyRepo.GetKeysByValuesAndGroupID(uniqueKeys, groupID) if err != nil { - s.taskService.EndTaskByID(ctx, taskID, resourceID, nil, fmt.Errorf("failed to find keys to unlink: %w", err)) - return + return nil, 0, err } - if len(keysToUnlink) == 0 { - result := gin.H{"unlinked_count": 0, "hard_deleted_count": 0, "not_found_count": len(uniqueKeys)} - s.taskService.EndTaskByID(ctx, taskID, resourceID, result, nil) - return - } - idsToUnlink := make([]uint, len(keysToUnlink)) - for i, key := range keysToUnlink { - idsToUnlink[i] = key.ID + alreadyLinkedIDSet := make(map[uint]struct{}, len(alreadyLinkedModels)) + for _, key := range alreadyLinkedModels { + alreadyLinkedIDSet[key.ID] = struct{}{} } - if err := s.taskService.UpdateTotalByID(ctx, taskID, len(idsToUnlink)); err != nil { - s.logger.WithError(err).Warnf("Failed to update total for task %s", taskID) - } - var totalUnlinked int64 - for i := 0; i < len(idsToUnlink); i += chunkSize { - end := i + chunkSize - if end > len(idsToUnlink) { - end = len(idsToUnlink) + keysToLink := make([]models.APIKey, 0, len(allKeyModels)-len(alreadyLinkedIDSet)) + for _, key := range allKeyModels { + if _, exists := alreadyLinkedIDSet[key.ID]; !exists { + keysToLink = append(keysToLink, key) } + } + + return keysToLink, len(alreadyLinkedIDSet), nil +} + +// extractKeyIDs 提取密钥 ID 列表 +func (s *KeyImportService) extractKeyIDs(keys []models.APIKey) []uint { + ids := make([]uint, len(keys)) + for i, key := range keys { + ids[i] = key.ID + } + return ids +} + +// linkKeysInChunks 分块关联密钥到组 +func (s *KeyImportService) linkKeysInChunks(ctx context.Context, taskID string, groupID uint, keysToLink []models.APIKey) error { + idsToLink := s.extractKeyIDs(keysToLink) + + for i := 0; i < len(idsToLink); i += chunkSize { + end := min(i+chunkSize, len(idsToLink)) + chunk := idsToLink[i:end] + + if err := s.keyRepo.LinkKeysToGroup(ctx, groupID, chunk); err != nil { + return fmt.Errorf("chunk failed to link keys: %w", err) + } + + _ = s.taskService.UpdateProgressByID(ctx, taskID, end) + } + return nil +} + +// unlinkKeysInChunks 分块解绑密钥 +func (s *KeyImportService) unlinkKeysInChunks(ctx context.Context, taskID string, groupID uint, idsToUnlink []uint) (int64, error) { + var totalUnlinked int64 + + for i := 0; i < len(idsToUnlink); i += chunkSize { + end := min(i+chunkSize, len(idsToUnlink)) chunk := idsToUnlink[i:end] + unlinked, err := s.keyRepo.UnlinkKeysFromGroup(ctx, groupID, chunk) if err != nil { - s.taskService.EndTaskByID(ctx, taskID, resourceID, nil, fmt.Errorf("chunk failed: could not unlink keys: %w", err)) - return + return 0, fmt.Errorf("chunk failed: could not unlink keys: %w", err) } + totalUnlinked += unlinked + + // 发布解绑事件 for _, keyID := range chunk { s.publishSingleKeyChangeEvent(ctx, groupID, keyID, models.StatusActive, "", "key_unlinked") } - _ = s.taskService.UpdateProgressByID(ctx, taskID, i+len(chunk)) + + _ = s.taskService.UpdateProgressByID(ctx, taskID, end) } - totalDeleted, err := s.keyRepo.DeleteOrphanKeys() + return totalUnlinked, nil +} + +// processKeysInChunks 通用的分块处理密钥逻辑 +func (s *KeyImportService) processKeysInChunks( + ctx context.Context, + taskID string, + keys []string, + processFunc func(chunk []string) (int64, error), +) (int64, error) { + var totalProcessed int64 + + for i := 0; i < len(keys); i += chunkSize { + end := min(i+chunkSize, len(keys)) + chunk := keys[i:end] + + count, err := processFunc(chunk) + if err != nil { + return 0, fmt.Errorf("failed to process chunk: %w", err) + } + + totalProcessed += count + _ = 
s.taskService.UpdateProgressByID(ctx, taskID, end) + } + + return totalProcessed, nil +} + +// processNewlyLinkedKeys 处理新关联的密钥(验证或直接激活) +func (s *KeyImportService) processNewlyLinkedKeys(ctx context.Context, groupID uint, keysToLink []models.APIKey, validateOnImport bool) { + idsToLink := s.extractKeyIDs(keysToLink) + + if validateOnImport { + // 发布批量导入完成事件,触发验证 + s.publishImportGroupCompletedEvent(ctx, groupID, idsToLink) + + // 发布单个密钥状态变更事件 + for _, keyID := range idsToLink { + s.publishSingleKeyChangeEvent(ctx, groupID, keyID, "", models.StatusPendingValidation, "key_linked") + } + } else { + // 直接激活密钥,不进行验证 + for _, keyID := range idsToLink { + if _, err := s.apiKeyService.UpdateMappingStatus(ctx, groupID, keyID, models.StatusActive); err != nil { + s.logger.WithFields(logrus.Fields{ + "group_id": groupID, + "key_id": keyID, + }).Errorf("Failed to directly activate key: %v", err) + } + } + } +} + +// endTaskWithResult 统一的任务结束处理 +func (s *KeyImportService) endTaskWithResult(ctx context.Context, taskID, resourceID string, result gin.H, err error) { if err != nil { - s.logger.WithError(err).Warn("Orphan key cleanup failed after unlink task.") + s.logger.WithFields(logrus.Fields{ + "task_id": taskID, + "resource_id": resourceID, + }).WithError(err).Error("Task failed") } - result := gin.H{ - "unlinked_count": totalUnlinked, - "hard_deleted_count": totalDeleted, - "not_found_count": len(uniqueKeys) - int(totalUnlinked), - } - s.taskService.EndTaskByID(ctx, taskID, resourceID, result, nil) + s.taskService.EndTaskByID(ctx, taskID, resourceID, result, err) } -func (s *KeyImportService) runHardDeleteKeysTask(ctx context.Context, taskID, resourceID string, keys []string) { - var totalDeleted int64 - for i := 0; i < len(keys); i += chunkSize { - end := i + chunkSize - if end > len(keys) { - end = len(keys) - } - chunk := keys[i:end] - - deleted, err := s.keyRepo.HardDeleteByValues(chunk) - if err != nil { - s.taskService.EndTaskByID(ctx, taskID, resourceID, nil, fmt.Errorf("failed to hard delete chunk: %w", err)) - return - } - totalDeleted += deleted - _ = s.taskService.UpdateProgressByID(ctx, taskID, i+len(chunk)) - } - result := gin.H{ - "hard_deleted_count": totalDeleted, - "not_found_count": int64(len(keys)) - totalDeleted, - } - s.taskService.EndTaskByID(ctx, taskID, resourceID, result, nil) - s.publishChangeEvent(ctx, 0, "keys_hard_deleted") -} - -func (s *KeyImportService) runRestoreKeysTask(ctx context.Context, taskID, resourceID string, keys []string) { - var restoredCount int64 - for i := 0; i < len(keys); i += chunkSize { - end := i + chunkSize - if end > len(keys) { - end = len(keys) - } - chunk := keys[i:end] - - count, err := s.keyRepo.UpdateMasterStatusByValues(chunk, models.MasterStatusActive) - if err != nil { - s.taskService.EndTaskByID(ctx, taskID, resourceID, nil, fmt.Errorf("failed to restore chunk: %w", err)) - return - } - restoredCount += count - _ = s.taskService.UpdateProgressByID(ctx, taskID, i+len(chunk)) - } - result := gin.H{ - "restored_count": restoredCount, - "not_found_count": int64(len(keys)) - restoredCount, - } - s.taskService.EndTaskByID(ctx, taskID, resourceID, result, nil) - s.publishChangeEvent(ctx, 0, "keys_bulk_restored") -} +// ==================== 事件发布方法 ==================== +// publishSingleKeyChangeEvent 发布单个密钥状态变更事件 func (s *KeyImportService) publishSingleKeyChangeEvent(ctx context.Context, groupID, keyID uint, oldStatus, newStatus models.APIKeyStatus, reason string) { event := models.KeyStatusChangedEvent{ GroupID: groupID, @@ -324,56 +474,88 @@ 
func (s *KeyImportService) publishSingleKeyChangeEvent(ctx context.Context, grou ChangeReason: reason, ChangedAt: time.Now(), } - eventData, _ := json.Marshal(event) - if err := s.store.Publish(ctx, models.TopicKeyStatusChanged, eventData); err != nil { - s.logger.WithError(err).WithFields(logrus.Fields{ + + eventData, err := json.Marshal(event) + if err != nil { + s.logger.WithFields(logrus.Fields{ "group_id": groupID, "key_id": keyID, "reason": reason, - }).Error("Failed to publish single key change event.") + }).WithError(err).Error("Failed to marshal key change event") + return + } + + if err := s.store.Publish(ctx, models.TopicKeyStatusChanged, eventData); err != nil { + s.logger.WithFields(logrus.Fields{ + "group_id": groupID, + "key_id": keyID, + "reason": reason, + }).WithError(err).Error("Failed to publish single key change event") } } +// publishChangeEvent 发布通用变更事件 func (s *KeyImportService) publishChangeEvent(ctx context.Context, groupID uint, reason string) { event := models.KeyStatusChangedEvent{ GroupID: groupID, ChangeReason: reason, } - eventData, _ := json.Marshal(event) - _ = s.store.Publish(ctx, models.TopicKeyStatusChanged, eventData) + + eventData, err := json.Marshal(event) + if err != nil { + s.logger.WithFields(logrus.Fields{ + "group_id": groupID, + "reason": reason, + }).WithError(err).Error("Failed to marshal change event") + return + } + + if err := s.store.Publish(ctx, models.TopicKeyStatusChanged, eventData); err != nil { + s.logger.WithFields(logrus.Fields{ + "group_id": groupID, + "reason": reason, + }).WithError(err).Error("Failed to publish change event") + } } +// publishImportGroupCompletedEvent 发布批量导入完成事件 func (s *KeyImportService) publishImportGroupCompletedEvent(ctx context.Context, groupID uint, keyIDs []uint) { if len(keyIDs) == 0 { return } + event := models.ImportGroupCompletedEvent{ GroupID: groupID, KeyIDs: keyIDs, CompletedAt: time.Now(), } + eventData, err := json.Marshal(event) if err != nil { - s.logger.WithError(err).Error("Failed to marshal ImportGroupCompletedEvent.") + s.logger.WithFields(logrus.Fields{ + "group_id": groupID, + "key_count": len(keyIDs), + }).WithError(err).Error("Failed to marshal ImportGroupCompletedEvent") return } + if err := s.store.Publish(ctx, models.TopicImportGroupCompleted, eventData); err != nil { - s.logger.WithError(err).Error("Failed to publish ImportGroupCompletedEvent.") + s.logger.WithFields(logrus.Fields{ + "group_id": groupID, + "key_count": len(keyIDs), + }).WithError(err).Error("Failed to publish ImportGroupCompletedEvent") } else { - s.logger.Infof("Published ImportGroupCompletedEvent for group %d with %d keys.", groupID, len(keyIDs)) + s.logger.WithFields(logrus.Fields{ + "group_id": groupID, + "key_count": len(keyIDs), + }).Info("Published ImportGroupCompletedEvent") } } -func (s *KeyImportService) StartUnlinkKeysByFilterTask(ctx context.Context, groupID uint, statuses []string) (*task.Status, error) { - s.logger.Infof("Starting unlink task for group %d with status filter: %v", groupID, statuses) - keyValues, err := s.keyRepo.FindKeyValuesByStatus(groupID, statuses) - if err != nil { - return nil, fmt.Errorf("failed to find keys by filter: %w", err) +// min 返回两个整数中的较小值 +func min(a, b int) int { + if a < b { + return a } - if len(keyValues) == 0 { - return nil, fmt.Errorf("no keys found matching the provided filter") - } - keysAsText := strings.Join(keyValues, "\n") - s.logger.Infof("Found %d keys to unlink for group %d.", len(keyValues), groupID) - return s.StartUnlinkKeysTask(ctx, groupID, 
keysAsText) + return b } diff --git a/internal/service/key_validation_service.go b/internal/service/key_validation_service.go index 74e6678..32820e3 100644 --- a/internal/service/key_validation_service.go +++ b/internal/service/key_validation_service.go @@ -25,26 +25,38 @@ import ( ) const ( - TaskTypeTestKeys = "test_keys" + TaskTypeTestKeys = "test_keys" + defaultConcurrency = 10 + maxValidationConcurrency = 100 + validationTaskTimeout = time.Hour ) type KeyValidationService struct { taskService task.Reporter channel channel.ChannelProxy db *gorm.DB - SettingsManager *settings.SettingsManager + settingsManager *settings.SettingsManager groupManager *GroupManager store store.Store keyRepo repository.KeyRepository logger *logrus.Entry } -func NewKeyValidationService(ts task.Reporter, ch channel.ChannelProxy, db *gorm.DB, ss *settings.SettingsManager, gm *GroupManager, st store.Store, kr repository.KeyRepository, logger *logrus.Logger) *KeyValidationService { +func NewKeyValidationService( + ts task.Reporter, + ch channel.ChannelProxy, + db *gorm.DB, + ss *settings.SettingsManager, + gm *GroupManager, + st store.Store, + kr repository.KeyRepository, + logger *logrus.Logger, +) *KeyValidationService { return &KeyValidationService{ taskService: ts, channel: ch, db: db, - SettingsManager: ss, + settingsManager: ss, groupManager: gm, store: st, keyRepo: kr, @@ -52,33 +64,393 @@ func NewKeyValidationService(ts task.Reporter, ch channel.ChannelProxy, db *gorm } } +// ==================== 公开接口 ==================== + +// ValidateSingleKey 验证单个密钥 func (s *KeyValidationService) ValidateSingleKey(key *models.APIKey, timeout time.Duration, endpoint string) error { + // 1. 解密密钥 if err := s.keyRepo.Decrypt(key); err != nil { return fmt.Errorf("failed to decrypt key %d for validation: %w", key.ID, err) } + + // 2. 创建 HTTP 客户端和请求 client := &http.Client{Timeout: timeout} req, err := http.NewRequest("GET", endpoint, nil) if err != nil { - s.logger.Errorf("Failed to create request for key validation (ID: %d): %v", key.ID, err) + s.logger.WithFields(logrus.Fields{ + "key_id": key.ID, + "endpoint": endpoint, + }).Error("Failed to create validation request") return fmt.Errorf("failed to create request: %w", err) } + // 3. 修改请求(添加密钥认证头) s.channel.ModifyRequest(req, key) + // 4. 执行请求 resp, err := client.Do(req) if err != nil { return fmt.Errorf("request failed: %w", err) } defer resp.Body.Close() + // 5. 检查响应状态 if resp.StatusCode == http.StatusOK { return nil } + // 6. 处理错误响应 + return s.buildValidationError(resp) +} + +// StartTestKeysTask 启动批量密钥测试任务 +func (s *KeyValidationService) StartTestKeysTask(ctx context.Context, groupID uint, keysText string) (*task.Status, error) { + // 1. 解析和验证输入 + keyStrings := utils.ParseKeysFromText(keysText) + if len(keyStrings) == 0 { + return nil, CustomErrors.NewAPIError(CustomErrors.ErrBadRequest, "no valid keys found in input text") + } + + // 2. 查询密钥模型 + apiKeyModels, err := s.keyRepo.GetKeysByValues(keyStrings) + if err != nil { + return nil, CustomErrors.ParseDBError(err) + } + if len(apiKeyModels) == 0 { + return nil, CustomErrors.NewAPIError(CustomErrors.ErrResourceNotFound, "none of the provided keys were found in the system") + } + + // 3. 批量解密密钥 + if err := s.keyRepo.DecryptBatch(apiKeyModels); err != nil { + s.logger.WithError(err).Error("Failed to batch decrypt keys for validation task") + return nil, CustomErrors.NewAPIError(CustomErrors.ErrInternalServer, "Failed to decrypt keys for validation") + } + + // 4. 
获取组配置 + group, ok := s.groupManager.GetGroupByID(groupID) + if !ok { + return nil, CustomErrors.NewAPIError(CustomErrors.ErrGroupNotFound, fmt.Sprintf("group with id %d not found", groupID)) + } + + opConfig, err := s.groupManager.BuildOperationalConfig(group) + if err != nil { + return nil, CustomErrors.NewAPIError(CustomErrors.ErrInternalServer, fmt.Sprintf("failed to build operational config: %v", err)) + } + + // 5. 构建验证端点 + endpoint, err := s.groupManager.BuildKeyCheckEndpoint(groupID) + if err != nil { + return nil, CustomErrors.NewAPIError(CustomErrors.ErrInternalServer, fmt.Sprintf("failed to build endpoint: %v", err)) + } + + // 6. 创建任务 + resourceID := fmt.Sprintf("group-%d", groupID) + taskStatus, err := s.taskService.StartTask(ctx, groupID, TaskTypeTestKeys, resourceID, len(apiKeyModels), validationTaskTimeout) + if err != nil { + return nil, err + } + + // 7. 准备任务参数 + params := s.buildValidationParams(opConfig) + + // 8. 启动异步验证任务 + go s.runTestKeysTaskWithRecovery(context.Background(), taskStatus.ID, resourceID, groupID, apiKeyModels, params, endpoint) + + return taskStatus, nil +} + +// StartTestKeysByFilterTask 根据状态过滤启动批量测试任务 +func (s *KeyValidationService) StartTestKeysByFilterTask(ctx context.Context, groupID uint, statuses []string) (*task.Status, error) { + s.logger.WithFields(logrus.Fields{ + "group_id": groupID, + "statuses": statuses, + }).Info("Starting test task with status filter") + + // 1. 根据过滤条件查询密钥 + keyValues, err := s.keyRepo.FindKeyValuesByStatus(groupID, statuses) + if err != nil { + return nil, CustomErrors.ParseDBError(err) + } + + if len(keyValues) == 0 { + return nil, CustomErrors.NewAPIError(CustomErrors.ErrNotFound, "No keys found matching the filter criteria") + } + + // 2. 转换为文本格式并启动任务 + keysAsText := strings.Join(keyValues, "\n") + s.logger.Infof("Found %d keys to validate for group %d", len(keyValues), groupID) + + return s.StartTestKeysTask(ctx, groupID, keysAsText) +} + +// ==================== 核心任务执行逻辑 ==================== + +// validationParams 验证参数封装 +type validationParams struct { + timeout time.Duration + concurrency int +} + +// buildValidationParams 构建验证参数 +func (s *KeyValidationService) buildValidationParams(opConfig *models.KeyGroupSettings) validationParams { + settings := s.settingsManager.GetSettings() + // 从配置读取超时时间(而非硬编码) + timeout := time.Duration(settings.KeyCheckTimeoutSeconds) * time.Second + if timeout <= 0 { + timeout = 30 * time.Second // 仅在配置无效时使用默认值 + } + // 从配置读取并发数(优先级:组配置 > 全局配置 > 兜底默认值) + var concurrency int + if opConfig.KeyCheckConcurrency != nil && *opConfig.KeyCheckConcurrency > 0 { + concurrency = *opConfig.KeyCheckConcurrency + } else if settings.BaseKeyCheckConcurrency > 0 { + concurrency = settings.BaseKeyCheckConcurrency + } else { + concurrency = defaultConcurrency // 兜底默认值 + } + // 限制最大并发数(防护措施) + if concurrency > maxValidationConcurrency { + concurrency = maxValidationConcurrency + } + return validationParams{ + timeout: timeout, + concurrency: concurrency, + } +} + +// runTestKeysTaskWithRecovery 带恢复机制的任务执行包装器 +func (s *KeyValidationService) runTestKeysTaskWithRecovery( + ctx context.Context, + taskID string, + resourceID string, + groupID uint, + keys []models.APIKey, + params validationParams, + endpoint string, +) { + defer func() { + if r := recover(); r != nil { + err := fmt.Errorf("panic recovered in validation task %s: %v", taskID, r) + s.logger.WithField("task_id", taskID).Error(err) + s.taskService.EndTaskByID(ctx, taskID, resourceID, nil, err) + } + }() + + s.runTestKeysTask(ctx, taskID, 
resourceID, groupID, keys, params, endpoint) +} + +// runTestKeysTask 执行批量密钥验证任务 +func (s *KeyValidationService) runTestKeysTask( + ctx context.Context, + taskID string, + resourceID string, + groupID uint, + keys []models.APIKey, + params validationParams, + endpoint string, +) { + s.logger.WithFields(logrus.Fields{ + "task_id": taskID, + "group_id": groupID, + "key_count": len(keys), + "concurrency": params.concurrency, + "timeout": params.timeout, + }).Info("Starting validation task") + + // 1. 初始化结果收集 + results := make([]models.KeyTestResult, len(keys)) + + // 2. 创建任务分发器 + dispatcher := newValidationDispatcher( + keys, + params.concurrency, + s, + ctx, + taskID, + groupID, + endpoint, + params.timeout, + ) + + // 3. 执行并发验证 + dispatcher.run(results) + + // 4. 完成任务 + s.taskService.EndTaskByID(ctx, taskID, resourceID, gin.H{"results": results}, nil) + + s.logger.WithFields(logrus.Fields{ + "task_id": taskID, + "group_id": groupID, + "processed": len(results), + }).Info("Validation task completed") +} + +// ==================== 验证调度器 ==================== + +// validationJob 验证作业 +type validationJob struct { + index int + key models.APIKey +} + +// validationDispatcher 验证任务分发器 +type validationDispatcher struct { + keys []models.APIKey + concurrency int + service *KeyValidationService + ctx context.Context + taskID string + groupID uint + endpoint string + timeout time.Duration + + mu sync.Mutex + processedCount int +} + +// newValidationDispatcher 创建验证分发器 +func newValidationDispatcher( + keys []models.APIKey, + concurrency int, + service *KeyValidationService, + ctx context.Context, + taskID string, + groupID uint, + endpoint string, + timeout time.Duration, +) *validationDispatcher { + return &validationDispatcher{ + keys: keys, + concurrency: concurrency, + service: service, + ctx: ctx, + taskID: taskID, + groupID: groupID, + endpoint: endpoint, + timeout: timeout, + } +} + +// run 执行并发验证 +func (d *validationDispatcher) run(results []models.KeyTestResult) { + var wg sync.WaitGroup + jobs := make(chan validationJob, len(d.keys)) + + // 启动 worker pool + for i := 0; i < d.concurrency; i++ { + wg.Add(1) + go d.worker(&wg, jobs, results) + } + + // 分发任务 + for i, key := range d.keys { + jobs <- validationJob{index: i, key: key} + } + close(jobs) + + // 等待所有 worker 完成 + wg.Wait() +} + +// worker 验证工作协程 +func (d *validationDispatcher) worker(wg *sync.WaitGroup, jobs <-chan validationJob, results []models.KeyTestResult) { + defer wg.Done() + + for job := range jobs { + result := d.validateKey(job.key) + + d.mu.Lock() + results[job.index] = result + d.processedCount++ + _ = d.service.taskService.UpdateProgressByID(d.ctx, d.taskID, d.processedCount) + d.mu.Unlock() + } +} + +// validateKey 验证单个密钥并返回结果 +func (d *validationDispatcher) validateKey(key models.APIKey) models.KeyTestResult { + // 1. 执行验证 + validationErr := d.service.ValidateSingleKey(&key, d.timeout, d.endpoint) + + // 2. 构建结果和事件 + result, event := d.buildResultAndEvent(key, validationErr) + + // 3. 
发布验证事件 + d.publishValidationEvent(key.ID, event) + + return result +} + +// buildResultAndEvent 构建验证结果和事件 +func (d *validationDispatcher) buildResultAndEvent(key models.APIKey, validationErr error) (models.KeyTestResult, models.RequestFinishedEvent) { + event := models.RequestFinishedEvent{ + RequestLog: models.RequestLog{ + GroupID: &d.groupID, + KeyID: &key.ID, + }, + } + + if validationErr == nil { + // 验证成功 + event.RequestLog.IsSuccess = true + return models.KeyTestResult{ + Key: key.APIKey, + Status: "valid", + Message: "Validation successful", + }, event + } + + // 验证失败 + event.RequestLog.IsSuccess = false + + var apiErr *CustomErrors.APIError + if CustomErrors.As(validationErr, &apiErr) { + event.Error = apiErr + return models.KeyTestResult{ + Key: key.APIKey, + Status: "invalid", + Message: fmt.Sprintf("Invalid key (HTTP %d): %s", apiErr.HTTPStatus, apiErr.Message), + }, event + } + + // 其他错误 + event.Error = &CustomErrors.APIError{Message: validationErr.Error()} + return models.KeyTestResult{ + Key: key.APIKey, + Status: "error", + Message: "Validation check failed: " + validationErr.Error(), + }, event +} + +// publishValidationEvent 发布验证事件 +func (d *validationDispatcher) publishValidationEvent(keyID uint, event models.RequestFinishedEvent) { + eventData, err := json.Marshal(event) + if err != nil { + d.service.logger.WithFields(logrus.Fields{ + "key_id": keyID, + "group_id": d.groupID, + }).WithError(err).Error("Failed to marshal validation event") + return + } + + if err := d.service.store.Publish(d.ctx, models.TopicRequestFinished, eventData); err != nil { + d.service.logger.WithFields(logrus.Fields{ + "key_id": keyID, + "group_id": d.groupID, + }).WithError(err).Error("Failed to publish validation event") + } +} + +// ==================== 辅助方法 ==================== + +// buildValidationError 构建验证错误 +func (s *KeyValidationService) buildValidationError(resp *http.Response) error { bodyBytes, readErr := io.ReadAll(resp.Body) + var errorMsg string if readErr != nil { errorMsg = "Failed to read error response body" + s.logger.WithError(readErr).Warn("Failed to read validation error response") } else { errorMsg = string(bodyBytes) } @@ -89,128 +461,3 @@ func (s *KeyValidationService) ValidateSingleKey(key *models.APIKey, timeout tim Code: "VALIDATION_FAILED", } } - -func (s *KeyValidationService) StartTestKeysTask(ctx context.Context, groupID uint, keysText string) (*task.Status, error) { - keyStrings := utils.ParseKeysFromText(keysText) - if len(keyStrings) == 0 { - return nil, CustomErrors.NewAPIError(CustomErrors.ErrBadRequest, "no valid keys found in input text") - } - apiKeyModels, err := s.keyRepo.GetKeysByValues(keyStrings) - if err != nil { - return nil, CustomErrors.ParseDBError(err) - } - if len(apiKeyModels) == 0 { - return nil, CustomErrors.NewAPIError(CustomErrors.ErrResourceNotFound, "none of the provided keys were found in the system") - } - if err := s.keyRepo.DecryptBatch(apiKeyModels); err != nil { - s.logger.WithError(err).Error("Failed to batch decrypt keys for validation task.") - return nil, CustomErrors.NewAPIError(CustomErrors.ErrInternalServer, "Failed to decrypt keys for validation") - } - group, ok := s.groupManager.GetGroupByID(groupID) - if !ok { - return nil, CustomErrors.NewAPIError(CustomErrors.ErrGroupNotFound, fmt.Sprintf("group with id %d not found", groupID)) - } - opConfig, err := s.groupManager.BuildOperationalConfig(group) - if err != nil { - return nil, CustomErrors.NewAPIError(CustomErrors.ErrInternalServer, fmt.Sprintf("failed to build 
operational config: %v", err)) - } - resourceID := fmt.Sprintf("group-%d", groupID) - taskStatus, err := s.taskService.StartTask(ctx, groupID, TaskTypeTestKeys, resourceID, len(apiKeyModels), time.Hour) - if err != nil { - return nil, err - } - settings := s.SettingsManager.GetSettings() - timeout := time.Duration(settings.KeyCheckTimeoutSeconds) * time.Second - endpoint, err := s.groupManager.BuildKeyCheckEndpoint(groupID) - if err != nil { - s.taskService.EndTaskByID(ctx, taskStatus.ID, resourceID, nil, err) - return nil, err - } - var concurrency int - if opConfig.KeyCheckConcurrency != nil { - concurrency = *opConfig.KeyCheckConcurrency - } else { - concurrency = settings.BaseKeyCheckConcurrency - } - go s.runTestKeysTask(context.Background(), taskStatus.ID, resourceID, groupID, apiKeyModels, timeout, endpoint, concurrency) - return taskStatus, nil -} - -func (s *KeyValidationService) runTestKeysTask(ctx context.Context, taskID string, resourceID string, groupID uint, keys []models.APIKey, timeout time.Duration, endpoint string, concurrency int) { - var wg sync.WaitGroup - var mu sync.Mutex - finalResults := make([]models.KeyTestResult, len(keys)) - processedCount := 0 - if concurrency <= 0 { - concurrency = 10 - } - type job struct { - Index int - Value models.APIKey - } - jobs := make(chan job, len(keys)) - for i := 0; i < concurrency; i++ { - wg.Add(1) - go func() { - defer wg.Done() - for j := range jobs { - apiKeyModel := j.Value - keyToValidate := apiKeyModel - validationErr := s.ValidateSingleKey(&keyToValidate, timeout, endpoint) - - var currentResult models.KeyTestResult - event := models.RequestFinishedEvent{ - RequestLog: models.RequestLog{ - GroupID: &groupID, - KeyID: &apiKeyModel.ID, - }, - } - if validationErr == nil { - currentResult = models.KeyTestResult{Key: apiKeyModel.APIKey, Status: "valid", Message: "Validation successful."} - event.RequestLog.IsSuccess = true - } else { - var apiErr *CustomErrors.APIError - if CustomErrors.As(validationErr, &apiErr) { - currentResult = models.KeyTestResult{Key: apiKeyModel.APIKey, Status: "invalid", Message: fmt.Sprintf("Invalid key (HTTP %d): %s", apiErr.HTTPStatus, apiErr.Message)} - event.Error = apiErr - } else { - currentResult = models.KeyTestResult{Key: apiKeyModel.APIKey, Status: "error", Message: "Validation check failed: " + validationErr.Error()} - event.Error = &CustomErrors.APIError{Message: validationErr.Error()} - } - event.RequestLog.IsSuccess = false - } - eventData, _ := json.Marshal(event) - - if err := s.store.Publish(ctx, models.TopicRequestFinished, eventData); err != nil { - s.logger.WithError(err).Errorf("Failed to publish RequestFinishedEvent for validation of key ID %d", apiKeyModel.ID) - } - - mu.Lock() - finalResults[j.Index] = currentResult - processedCount++ - _ = s.taskService.UpdateProgressByID(ctx, taskID, processedCount) - mu.Unlock() - } - }() - } - for i, k := range keys { - jobs <- job{Index: i, Value: k} - } - close(jobs) - wg.Wait() - s.taskService.EndTaskByID(ctx, taskID, resourceID, gin.H{"results": finalResults}, nil) -} - -func (s *KeyValidationService) StartTestKeysByFilterTask(ctx context.Context, groupID uint, statuses []string) (*task.Status, error) { - s.logger.Infof("Starting test task for group %d with status filter: %v", groupID, statuses) - keyValues, err := s.keyRepo.FindKeyValuesByStatus(groupID, statuses) - if err != nil { - return nil, CustomErrors.ParseDBError(err) - } - if len(keyValues) == 0 { - return nil, CustomErrors.NewAPIError(CustomErrors.ErrNotFound, "No keys 
found matching the filter criteria.") - } - keysAsText := strings.Join(keyValues, "\n") - s.logger.Infof("Found %d keys to validate for group %d.", len(keyValues), groupID) - return s.StartTestKeysTask(ctx, groupID, keysAsText) -} diff --git a/internal/service/log_service.go b/internal/service/log_service.go index 38cf66d..8ff5262 100644 --- a/internal/service/log_service.go +++ b/internal/service/log_service.go @@ -1,78 +1,152 @@ -// Filename: internal/service/log_service.go package service import ( + "context" + "fmt" "gemini-balancer/internal/models" "strconv" - "github.com/gin-gonic/gin" + "github.com/sirupsen/logrus" "gorm.io/gorm" ) type LogService struct { - db *gorm.DB + db *gorm.DB + logger *logrus.Entry } -func NewLogService(db *gorm.DB) *LogService { - return &LogService{db: db} +func NewLogService(db *gorm.DB, logger *logrus.Logger) *LogService { + return &LogService{ + db: db, + logger: logger.WithField("component", "LogService"), + } } -func (s *LogService) Record(log *models.RequestLog) error { - return s.db.Create(log).Error +func (s *LogService) Record(ctx context.Context, log *models.RequestLog) error { + return s.db.WithContext(ctx).Create(log).Error } -func (s *LogService) GetLogs(c *gin.Context) ([]models.RequestLog, int64, error) { +// LogQueryParams 解耦 Gin,使用结构体传参 +type LogQueryParams struct { + Page int + PageSize int + ModelName string + IsSuccess *bool // 使用指针区分"未设置"和"false" + StatusCode *int + KeyID *uint64 + GroupID *uint64 +} + +func (s *LogService) GetLogs(ctx context.Context, params LogQueryParams) ([]models.RequestLog, int64, error) { + // 参数校验 + if params.Page < 1 { + params.Page = 1 + } + if params.PageSize < 1 || params.PageSize > 100 { + params.PageSize = 20 + } + var logs []models.RequestLog var total int64 - query := s.db.Model(&models.RequestLog{}).Scopes(s.filtersScope(c)) + // 构建基础查询 + query := s.db.WithContext(ctx).Model(&models.RequestLog{}) + query = s.applyFilters(query, params) - // 先计算总数 + // 计算总数 if err := query.Count(&total).Error; err != nil { - return nil, 0, err + return nil, 0, fmt.Errorf("failed to count logs: %w", err) } + if total == 0 { return []models.RequestLog{}, 0, nil } - // 再执行分页查询 - page, _ := strconv.Atoi(c.DefaultQuery("page", "1")) - pageSize, _ := strconv.Atoi(c.DefaultQuery("page_size", "20")) - offset := (page - 1) * pageSize - - err := query.Order("request_time desc").Limit(pageSize).Offset(offset).Find(&logs).Error - if err != nil { - return nil, 0, err + // 分页查询 + offset := (params.Page - 1) * params.PageSize + if err := query.Order("request_time DESC"). + Limit(params.PageSize). + Offset(offset). 
+ Find(&logs).Error; err != nil { + return nil, 0, fmt.Errorf("failed to query logs: %w", err) } return logs, total, nil } -func (s *LogService) filtersScope(c *gin.Context) func(db *gorm.DB) *gorm.DB { - return func(db *gorm.DB) *gorm.DB { - if modelName := c.Query("model_name"); modelName != "" { - db = db.Where("model_name = ?", modelName) - } - if isSuccessStr := c.Query("is_success"); isSuccessStr != "" { - if isSuccess, err := strconv.ParseBool(isSuccessStr); err == nil { - db = db.Where("is_success = ?", isSuccess) - } - } - if statusCodeStr := c.Query("status_code"); statusCodeStr != "" { - if statusCode, err := strconv.Atoi(statusCodeStr); err == nil { - db = db.Where("status_code = ?", statusCode) - } - } - if keyIDStr := c.Query("key_id"); keyIDStr != "" { - if keyID, err := strconv.ParseUint(keyIDStr, 10, 64); err == nil { - db = db.Where("key_id = ?", keyID) - } - } - if groupIDStr := c.Query("group_id"); groupIDStr != "" { - if groupID, err := strconv.ParseUint(groupIDStr, 10, 64); err == nil { - db = db.Where("group_id = ?", groupID) - } - } - return db +func (s *LogService) applyFilters(query *gorm.DB, params LogQueryParams) *gorm.DB { + if params.ModelName != "" { + query = query.Where("model_name = ?", params.ModelName) } + if params.IsSuccess != nil { + query = query.Where("is_success = ?", *params.IsSuccess) + } + if params.StatusCode != nil { + query = query.Where("status_code = ?", *params.StatusCode) + } + if params.KeyID != nil { + query = query.Where("key_id = ?", *params.KeyID) + } + if params.GroupID != nil { + query = query.Where("group_id = ?", *params.GroupID) + } + return query +} + +// ParseLogQueryParams 在 Handler 层调用,解析 Gin 参数 +func ParseLogQueryParams(queryParams map[string]string) (LogQueryParams, error) { + params := LogQueryParams{ + Page: 1, + PageSize: 20, + } + + if pageStr, ok := queryParams["page"]; ok { + if page, err := strconv.Atoi(pageStr); err == nil && page > 0 { + params.Page = page + } + } + + if pageSizeStr, ok := queryParams["page_size"]; ok { + if pageSize, err := strconv.Atoi(pageSizeStr); err == nil && pageSize > 0 { + params.PageSize = pageSize + } + } + + if modelName, ok := queryParams["model_name"]; ok { + params.ModelName = modelName + } + + if isSuccessStr, ok := queryParams["is_success"]; ok { + if isSuccess, err := strconv.ParseBool(isSuccessStr); err == nil { + params.IsSuccess = &isSuccess + } else { + return params, fmt.Errorf("invalid is_success parameter: %s", isSuccessStr) + } + } + + if statusCodeStr, ok := queryParams["status_code"]; ok { + if statusCode, err := strconv.Atoi(statusCodeStr); err == nil { + params.StatusCode = &statusCode + } else { + return params, fmt.Errorf("invalid status_code parameter: %s", statusCodeStr) + } + } + + if keyIDStr, ok := queryParams["key_id"]; ok { + if keyID, err := strconv.ParseUint(keyIDStr, 10, 64); err == nil { + params.KeyID = &keyID + } else { + return params, fmt.Errorf("invalid key_id parameter: %s", keyIDStr) + } + } + + if groupIDStr, ok := queryParams["group_id"]; ok { + if groupID, err := strconv.ParseUint(groupIDStr, 10, 64); err == nil { + params.GroupID = &groupID + } else { + return params, fmt.Errorf("invalid group_id parameter: %s", groupIDStr) + } + } + + return params, nil } diff --git a/internal/service/stats_service.go b/internal/service/stats_service.go index 50f7b9c..6babb30 100644 --- a/internal/service/stats_service.go +++ b/internal/service/stats_service.go @@ -35,34 +35,55 @@ func NewStatsService(db *gorm.DB, s store.Store, repo repository.KeyRepository, 
 func (s *StatsService) Start() {
 	s.logger.Info("Starting event listener for stats maintenance.")
-	sub, err := s.store.Subscribe(context.Background(), models.TopicKeyStatusChanged)
-	if err != nil {
-		s.logger.Fatalf("Failed to subscribe to topic %s: %v", models.TopicKeyStatusChanged, err)
-		return
-	}
-	go func() {
-		defer sub.Close()
-		for {
-			select {
-			case msg := <-sub.Channel():
-				var event models.KeyStatusChangedEvent
-				if err := json.Unmarshal(msg.Payload, &event); err != nil {
-					s.logger.Errorf("Failed to unmarshal KeyStatusChangedEvent: %v", err)
-					continue
-				}
-				s.handleKeyStatusChange(&event)
-			case <-s.stopChan:
-				s.logger.Info("Stopping stats event listener.")
-				return
-			}
-		}
-	}()
+	go s.listenForEvents()
 }
 
 func (s *StatsService) Stop() {
 	close(s.stopChan)
 }
 
+func (s *StatsService) listenForEvents() {
+	for {
+		select {
+		case <-s.stopChan:
+			s.logger.Info("Stopping stats event listener.")
+			return
+		default:
+		}
+
+		ctx, cancel := context.WithCancel(context.Background())
+		sub, err := s.store.Subscribe(ctx, models.TopicKeyStatusChanged)
+		if err != nil {
+			s.logger.Errorf("Failed to subscribe: %v, retrying in 5s", err)
+			cancel()
+			time.Sleep(5 * time.Second)
+			continue
+		}
+
+		s.logger.Info("Subscribed to key status changes")
+		s.handleSubscription(sub, cancel)
+	}
+}
+
+func (s *StatsService) handleSubscription(sub store.Subscription, cancel context.CancelFunc) {
+	defer sub.Close()
+	defer cancel()
+
+	for {
+		select {
+		case msg := <-sub.Channel():
+			var event models.KeyStatusChangedEvent
+			if err := json.Unmarshal(msg.Payload, &event); err != nil {
+				s.logger.Errorf("Failed to unmarshal event: %v", err)
+				continue
+			}
+			s.handleKeyStatusChange(&event)
+		case <-s.stopChan:
+			return
+		}
+	}
+}
+
 func (s *StatsService) handleKeyStatusChange(event *models.KeyStatusChangedEvent) {
 	if event.GroupID == 0 {
 		s.logger.Warnf("Received KeyStatusChangedEvent with no GroupID. Reason: %s, KeyID: %d. 
Skipping.", event.ChangeReason, event.KeyID) @@ -75,23 +96,47 @@ func (s *StatsService) handleKeyStatusChange(event *models.KeyStatusChangedEvent switch event.ChangeReason { case "key_unlinked", "key_hard_deleted": if event.OldStatus != "" { - s.store.HIncrBy(ctx, statsKey, "total_keys", -1) - s.store.HIncrBy(ctx, statsKey, fmt.Sprintf("%s_keys", event.OldStatus), -1) + if _, err := s.store.HIncrBy(ctx, statsKey, "total_keys", -1); err != nil { + s.logger.WithError(err).Errorf("Failed to decrement total_keys for group %d", event.GroupID) + s.RecalculateGroupKeyStats(ctx, event.GroupID) + return + } + if _, err := s.store.HIncrBy(ctx, statsKey, fmt.Sprintf("%s_keys", event.OldStatus), -1); err != nil { + s.logger.WithError(err).Errorf("Failed to decrement %s_keys for group %d", event.OldStatus, event.GroupID) + s.RecalculateGroupKeyStats(ctx, event.GroupID) + return + } } else { s.logger.Warnf("Received '%s' event for group %d without OldStatus, forcing recalculation.", event.ChangeReason, event.GroupID) s.RecalculateGroupKeyStats(ctx, event.GroupID) } case "key_linked": if event.NewStatus != "" { - s.store.HIncrBy(ctx, statsKey, "total_keys", 1) - s.store.HIncrBy(ctx, statsKey, fmt.Sprintf("%s_keys", event.NewStatus), 1) + if _, err := s.store.HIncrBy(ctx, statsKey, "total_keys", 1); err != nil { + s.logger.WithError(err).Errorf("Failed to increment total_keys for group %d", event.GroupID) + s.RecalculateGroupKeyStats(ctx, event.GroupID) + return + } + if _, err := s.store.HIncrBy(ctx, statsKey, fmt.Sprintf("%s_keys", event.NewStatus), 1); err != nil { + s.logger.WithError(err).Errorf("Failed to increment %s_keys for group %d", event.NewStatus, event.GroupID) + s.RecalculateGroupKeyStats(ctx, event.GroupID) + return + } } else { s.logger.Warnf("Received 'key_linked' event for group %d without NewStatus, forcing recalculation.", event.GroupID) s.RecalculateGroupKeyStats(ctx, event.GroupID) } case "manual_update", "error_threshold_reached", "key_recovered", "invalid_api_key": - s.store.HIncrBy(ctx, statsKey, fmt.Sprintf("%s_keys", event.OldStatus), -1) - s.store.HIncrBy(ctx, statsKey, fmt.Sprintf("%s_keys", event.NewStatus), 1) + if _, err := s.store.HIncrBy(ctx, statsKey, fmt.Sprintf("%s_keys", event.OldStatus), -1); err != nil { + s.logger.WithError(err).Errorf("Failed to decrement %s_keys for group %d", event.OldStatus, event.GroupID) + s.RecalculateGroupKeyStats(ctx, event.GroupID) + return + } + if _, err := s.store.HIncrBy(ctx, statsKey, fmt.Sprintf("%s_keys", event.NewStatus), 1); err != nil { + s.logger.WithError(err).Errorf("Failed to increment %s_keys for group %d", event.NewStatus, event.GroupID) + s.RecalculateGroupKeyStats(ctx, event.GroupID) + return + } default: s.logger.Warnf("Unhandled event reason '%s' for group %d, forcing recalculation.", event.ChangeReason, event.GroupID) s.RecalculateGroupKeyStats(ctx, event.GroupID) @@ -113,13 +158,16 @@ func (s *StatsService) RecalculateGroupKeyStats(ctx context.Context, groupID uin } statsKey := fmt.Sprintf("stats:group:%d", groupID) - updates := make(map[string]interface{}) - totalKeys := int64(0) + updates := map[string]interface{}{ + "active_keys": int64(0), + "disabled_keys": int64(0), + "error_keys": int64(0), + "total_keys": int64(0), + } for _, res := range results { updates[fmt.Sprintf("%s_keys", res.Status)] = res.Count - totalKeys += res.Count + updates["total_keys"] = updates["total_keys"].(int64) + res.Count } - updates["total_keys"] = totalKeys if err := s.store.Del(ctx, statsKey); err != nil { 
s.logger.WithError(err).Warnf("Failed to delete stale stats key for group %d before recalculation.", groupID) @@ -180,8 +228,18 @@ func (s *StatsService) AggregateHourlyStats(ctx context.Context) error { }) } - return s.db.WithContext(ctx).Clauses(clause.OnConflict{ + if err := s.db.WithContext(ctx).Clauses(clause.OnConflict{ Columns: []clause.Column{{Name: "time"}, {Name: "group_id"}, {Name: "model_name"}}, DoUpdates: clause.AssignmentColumns([]string{"request_count", "success_count", "prompt_tokens", "completion_tokens"}), - }).Create(&hourlyStats).Error + }).Create(&hourlyStats).Error; err != nil { + return err + } + + if err := s.db.WithContext(ctx). + Where("request_time >= ? AND request_time < ?", startTime, endTime). + Delete(&models.RequestLog{}).Error; err != nil { + s.logger.WithError(err).Warn("Failed to delete aggregated request logs") + } + + return nil } diff --git a/internal/service/token_manager.go b/internal/service/token_manager.go index 4f4fc4c..99ca74d 100644 --- a/internal/service/token_manager.go +++ b/internal/service/token_manager.go @@ -37,7 +37,7 @@ func NewTokenManager(repo repository.AuthTokenRepository, store store.Store, log return tokens, nil } - s, err := syncer.NewCacheSyncer(tokenLoader, store, TopicTokenChanged) + s, err := syncer.NewCacheSyncer(tokenLoader, store, TopicTokenChanged, logger) if err != nil { return nil, fmt.Errorf("failed to create token manager syncer: %w", err) } diff --git a/internal/settings/settings.go b/internal/settings/settings.go index d8217f1..092f0b4 100644 --- a/internal/settings/settings.go +++ b/internal/settings/settings.go @@ -87,7 +87,7 @@ func NewSettingsManager(db *gorm.DB, store store.Store, logger *logrus.Logger) ( return settings, nil } - s, err := syncer.NewCacheSyncer(settingsLoader, store, SettingsUpdateChannel) + s, err := syncer.NewCacheSyncer(settingsLoader, store, SettingsUpdateChannel, logger,) if err != nil { return nil, fmt.Errorf("failed to create system settings syncer: %w", err) } diff --git a/internal/syncer/syncer.go b/internal/syncer/syncer.go index 66bdaf5..cbbf4b5 100644 --- a/internal/syncer/syncer.go +++ b/internal/syncer/syncer.go @@ -4,46 +4,54 @@ import ( "context" "fmt" "gemini-balancer/internal/store" - "log" "sync" "time" + + "github.com/sirupsen/logrus" +) + +const ( + ReconnectDelay = 5 * time.Second + ReloadTimeout = 30 * time.Second ) -// LoaderFunc type LoaderFunc[T any] func() (T, error) -// CacheSyncer type CacheSyncer[T any] struct { mu sync.RWMutex cache T loader LoaderFunc[T] store store.Store channelName string + logger *logrus.Entry stopChan chan struct{} wg sync.WaitGroup } -// NewCacheSyncer func NewCacheSyncer[T any]( loader LoaderFunc[T], store store.Store, channelName string, + logger *logrus.Logger, ) (*CacheSyncer[T], error) { s := &CacheSyncer[T]{ loader: loader, store: store, channelName: channelName, + logger: logger.WithField("component", fmt.Sprintf("CacheSyncer[%s]", channelName)), stopChan: make(chan struct{}), } + if err := s.reload(); err != nil { - return nil, fmt.Errorf("initial load for %s failed: %w", channelName, err) + return nil, fmt.Errorf("initial load failed: %w", err) } + s.wg.Add(1) go s.listenForUpdates() + return s, nil } -// Get, Invalidate, Stop, reload 方法 . 
func (s *CacheSyncer[T]) Get() T { s.mu.RLock() defer s.mu.RUnlock() @@ -51,33 +59,60 @@ func (s *CacheSyncer[T]) Get() T { } func (s *CacheSyncer[T]) Invalidate() error { - log.Printf("INFO: Publishing invalidation notification on channel '%s'", s.channelName) - return s.store.Publish(context.Background(), s.channelName, []byte("reload")) + s.logger.Info("Publishing invalidation notification") + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + if err := s.store.Publish(ctx, s.channelName, []byte("reload")); err != nil { + s.logger.WithError(err).Error("Failed to publish invalidation") + return err + } + return nil } func (s *CacheSyncer[T]) Stop() { close(s.stopChan) s.wg.Wait() - log.Printf("INFO: CacheSyncer for channel '%s' stopped.", s.channelName) + s.logger.Info("CacheSyncer stopped") } func (s *CacheSyncer[T]) reload() error { - log.Printf("INFO: Reloading cache for channel '%s'...", s.channelName) - newData, err := s.loader() - if err != nil { - log.Printf("ERROR: Failed to reload cache for '%s': %v", s.channelName, err) - return err + s.logger.Info("Reloading cache...") + + ctx, cancel := context.WithTimeout(context.Background(), ReloadTimeout) + defer cancel() + + type result struct { + data T + err error + } + resultChan := make(chan result, 1) + + go func() { + data, err := s.loader() + resultChan <- result{data, err} + }() + + select { + case res := <-resultChan: + if res.err != nil { + s.logger.WithError(res.err).Error("Failed to reload cache") + return res.err + } + s.mu.Lock() + s.cache = res.data + s.mu.Unlock() + s.logger.Info("Cache reloaded successfully") + return nil + case <-ctx.Done(): + s.logger.Error("Cache reload timeout") + return fmt.Errorf("reload timeout after %v", ReloadTimeout) } - s.mu.Lock() - s.cache = newData - s.mu.Unlock() - log.Printf("INFO: Cache for channel '%s' reloaded successfully.", s.channelName) - return nil } -// listenForUpdates ... 
func (s *CacheSyncer[T]) listenForUpdates() { defer s.wg.Done() + for { select { case <-s.stopChan: @@ -85,31 +120,39 @@ func (s *CacheSyncer[T]) listenForUpdates() { default: } - subscription, err := s.store.Subscribe(context.Background(), s.channelName) - if err != nil { - log.Printf("ERROR: Failed to subscribe to '%s', retrying in 5s: %v", s.channelName, err) - time.Sleep(5 * time.Second) - continue - } - log.Printf("INFO: Subscribed to channel '%s' for cache invalidation.", s.channelName) - - subscriberLoop: - for { + if err := s.subscribeAndListen(); err != nil { + s.logger.WithError(err).Warnf("Subscription error, retrying in %v", ReconnectDelay) select { - case _, ok := <-subscription.Channel(): - if !ok { - log.Printf("WARN: Subscription channel '%s' closed, will re-subscribe.", s.channelName) - break subscriberLoop - } - log.Printf("INFO: Received invalidation notification on '%s', reloading cache.", s.channelName) - if err := s.reload(); err != nil { - log.Printf("ERROR: Failed to reload cache for '%s' after notification: %v", s.channelName, err) - } + case <-time.After(ReconnectDelay): case <-s.stopChan: - subscription.Close() return } } - subscription.Close() + } +} + +func (s *CacheSyncer[T]) subscribeAndListen() error { + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + subscription, err := s.store.Subscribe(ctx, s.channelName) + if err != nil { + return fmt.Errorf("failed to subscribe: %w", err) + } + defer subscription.Close() + s.logger.Info("Subscribed to channel") + for { + select { + case msg, ok := <-subscription.Channel(): + if !ok { + return fmt.Errorf("subscription channel closed") + } + s.logger.WithField("message", string(msg.Payload)).Info("Received invalidation notification") + if err := s.reload(); err != nil { + s.logger.WithError(err).Error("Failed to reload after notification") + } + case <-s.stopChan: + return nil + } } } diff --git a/internal/task/task.go b/internal/task/task.go index aa204c8..2345d1f 100644 --- a/internal/task/task.go +++ b/internal/task/task.go @@ -1,4 +1,3 @@ -// Filename: internal/task/task.go package task import ( @@ -13,7 +12,9 @@ import ( ) const ( - ResultTTL = 60 * time.Minute + ResultTTL = 60 * time.Minute + DefaultTimeout = 24 * time.Hour + LockTTL = 30 * time.Minute ) type Reporter interface { @@ -65,14 +66,21 @@ func (s *Task) getIsRunningFlagKey(taskID string) string { func (s *Task) StartTask(ctx context.Context, keyGroupID uint, taskType, resourceID string, total int, timeout time.Duration) (*Status, error) { lockKey := s.getResourceLockKey(resourceID) + taskID := fmt.Sprintf("%d-%d", time.Now().UnixNano(), keyGroupID) - if existingTaskID, err := s.store.Get(ctx, lockKey); err == nil && len(existingTaskID) > 0 { + locked, err := s.store.SetNX(ctx, lockKey, []byte(taskID), LockTTL) + if err != nil { + return nil, fmt.Errorf("failed to acquire task lock: %w", err) + } + if !locked { + existingTaskID, _ := s.store.Get(ctx, lockKey) return nil, fmt.Errorf("a task is already running for this resource (ID: %s)", string(existingTaskID)) } - taskID := fmt.Sprintf("%d-%d", time.Now().UnixNano(), keyGroupID) - taskKey := s.getTaskDataKey(taskID) - runningFlagKey := s.getIsRunningFlagKey(taskID) + if timeout == 0 { + timeout = DefaultTimeout + } + status := &Status{ ID: taskID, TaskType: taskType, @@ -81,63 +89,55 @@ func (s *Task) StartTask(ctx context.Context, keyGroupID uint, taskType, resourc Total: total, StartedAt: time.Now(), } - statusBytes, err := json.Marshal(status) - if err != nil { - return 
nil, fmt.Errorf("failed to serialize new task status: %w", err) - } - if timeout == 0 { - timeout = ResultTTL * 24 - } - - if err := s.store.Set(ctx, lockKey, []byte(taskID), timeout); err != nil { - return nil, fmt.Errorf("failed to acquire task resource lock: %w", err) - } - if err := s.store.Set(ctx, taskKey, statusBytes, timeout); err != nil { + if err := s.saveStatus(ctx, taskID, status, timeout); err != nil { _ = s.store.Del(ctx, lockKey) - return nil, fmt.Errorf("failed to set new task data in store: %w", err) + return nil, fmt.Errorf("failed to save task status: %w", err) } + runningFlagKey := s.getIsRunningFlagKey(taskID) if err := s.store.Set(ctx, runningFlagKey, []byte("1"), timeout); err != nil { _ = s.store.Del(ctx, lockKey) - _ = s.store.Del(ctx, taskKey) - return nil, fmt.Errorf("failed to set task running flag: %w", err) + _ = s.store.Del(ctx, s.getTaskDataKey(taskID)) + return nil, fmt.Errorf("failed to set running flag: %w", err) } + return status, nil } func (s *Task) EndTaskByID(ctx context.Context, taskID, resourceID string, resultData any, taskErr error) { lockKey := s.getResourceLockKey(resourceID) - defer func() { - if err := s.store.Del(ctx, lockKey); err != nil { - s.logger.WithError(err).Warnf("Failed to release resource lock '%s' for task %s.", lockKey, taskID) - } - }() runningFlagKey := s.getIsRunningFlagKey(taskID) - _ = s.store.Del(ctx, runningFlagKey) + + defer func() { + _ = s.store.Del(ctx, lockKey) + _ = s.store.Del(ctx, runningFlagKey) + }() status, err := s.GetStatus(ctx, taskID) if err != nil { - s.logger.WithError(err).Errorf("Could not get task status for task ID %s during EndTask. Lock has been released, but task data may be stale.", taskID) + s.logger.WithError(err).Errorf("Failed to get task status for %s during EndTask", taskID) return } + if !status.IsRunning { - s.logger.Warnf("EndTaskByID called for an already finished task: %s", taskID) + s.logger.Warnf("EndTaskByID called for already finished task: %s", taskID) return } + now := time.Now() status.IsRunning = false status.FinishedAt = &now status.DurationSeconds = now.Sub(status.StartedAt).Seconds() + if taskErr != nil { status.Error = taskErr.Error() } else { status.Result = resultData } - updatedTaskBytes, _ := json.Marshal(status) - taskKey := s.getTaskDataKey(taskID) - if err := s.store.Set(ctx, taskKey, updatedTaskBytes, ResultTTL); err != nil { - s.logger.WithError(err).Errorf("Failed to save final status for task %s.", taskID) + + if err := s.saveStatus(ctx, taskID, status, ResultTTL); err != nil { + s.logger.WithError(err).Errorf("Failed to save final status for task %s", taskID) } } @@ -148,43 +148,42 @@ func (s *Task) GetStatus(ctx context.Context, taskID string) (*Status, error) { if errors.Is(err, store.ErrNotFound) { return nil, errors.New("task not found") } - return nil, fmt.Errorf("failed to get task status from store: %w", err) + return nil, fmt.Errorf("failed to get task status: %w", err) } var status Status if err := json.Unmarshal(statusBytes, &status); err != nil { - return nil, fmt.Errorf("corrupted task data in store for ID %s", taskID) + return nil, fmt.Errorf("corrupted task data for ID %s: %w", taskID, err) } + if !status.IsRunning && status.FinishedAt != nil { status.DurationSeconds = status.FinishedAt.Sub(status.StartedAt).Seconds() } + return &status, nil } func (s *Task) updateTask(ctx context.Context, taskID string, updater func(status *Status)) error { runningFlagKey := s.getIsRunningFlagKey(taskID) if _, err := s.store.Get(ctx, runningFlagKey); err != nil { - 
return nil + if errors.Is(err, store.ErrNotFound) { + return errors.New("task is not running") + } + return fmt.Errorf("failed to check running flag: %w", err) } + status, err := s.GetStatus(ctx, taskID) if err != nil { - s.logger.WithError(err).Warnf("Failed to get task status for update on task %s. Update will not be saved.", taskID) - return nil + return fmt.Errorf("failed to get task status: %w", err) } + if !status.IsRunning { - return nil + return errors.New("task is not running") } + updater(status) - statusBytes, marshalErr := json.Marshal(status) - if marshalErr != nil { - s.logger.WithError(marshalErr).Errorf("Failed to serialize status for update on task %s.", taskID) - return nil - } - taskKey := s.getTaskDataKey(taskID) - if err := s.store.Set(ctx, taskKey, statusBytes, ResultTTL*24); err != nil { - s.logger.WithError(err).Warnf("Failed to save update for task %s.", taskID) - } - return nil + + return s.saveStatus(ctx, taskID, status, DefaultTimeout) } func (s *Task) UpdateProgressByID(ctx context.Context, taskID string, processed int) error { @@ -198,3 +197,17 @@ func (s *Task) UpdateTotalByID(ctx context.Context, taskID string, total int) er status.Total = total }) } + +func (s *Task) saveStatus(ctx context.Context, taskID string, status *Status, ttl time.Duration) error { + statusBytes, err := json.Marshal(status) + if err != nil { + return fmt.Errorf("failed to serialize status: %w", err) + } + + taskKey := s.getTaskDataKey(taskID) + if err := s.store.Set(ctx, taskKey, statusBytes, ttl); err != nil { + return fmt.Errorf("failed to save status: %w", err) + } + + return nil +} diff --git a/internal/webhandlers/auth_handler.go b/internal/webhandlers/auth_handler.go index cc7ca0f..ef7cc5d 100644 --- a/internal/webhandlers/auth_handler.go +++ b/internal/webhandlers/auth_handler.go @@ -1,43 +1,118 @@ -// Filename: internal/webhandlers/auth_handler.go (最终现代化改造版) +// Filename: internal/webhandlers/auth_handler.go + package webhandlers import ( "gemini-balancer/internal/middleware" - "gemini-balancer/internal/service" // [核心改造] 依赖service层 + "gemini-balancer/internal/service" "net/http" + "strings" "github.com/gin-gonic/gin" + "github.com/sirupsen/logrus" ) -// WebAuthHandler [核心改造] 依赖关系净化,注入SecurityService +// WebAuthHandler Web 认证处理器 type WebAuthHandler struct { securityService *service.SecurityService + logger *logrus.Logger } -// NewWebAuthHandler [核心改造] 构造函数更新 +// NewWebAuthHandler 创建 WebAuthHandler func NewWebAuthHandler(securityService *service.SecurityService) *WebAuthHandler { + logger := logrus.New() + logger.SetLevel(logrus.InfoLevel) + return &WebAuthHandler{ securityService: securityService, + logger: logger, } } -// ShowLoginPage 保持不变 +// ShowLoginPage 显示登录页面 func (h *WebAuthHandler) ShowLoginPage(c *gin.Context) { errMsg := c.Query("error") - from := c.Query("from") // 可以从登录失败的页面返回 + + // 验证重定向路径(防止开放重定向攻击) + redirectPath := h.validateRedirectPath(c.Query("redirect")) + + // 如果已登录,直接重定向 + if cookie := middleware.ExtractTokenFromCookie(c); cookie != "" { + if _, err := h.securityService.AuthenticateToken(cookie); err == nil { + c.Redirect(http.StatusFound, redirectPath) + return + } + } + c.HTML(http.StatusOK, "auth.html", gin.H{ - "error": errMsg, - "from": from, + "error": errMsg, + "redirect": redirectPath, }) } -// HandleLogin [核心改造] 认证逻辑完全委托给SecurityService +// HandleLogin 已废弃(项目无用户名系统) func (h *WebAuthHandler) HandleLogin(c *gin.Context) { c.Redirect(http.StatusFound, "/login?error=DEPRECATED_LOGIN_METHOD") } -// HandleLogout 保持不变 +// HandleLogout 处理登出请求 func (h 
*WebAuthHandler) HandleLogout(c *gin.Context) { + cookie := middleware.ExtractTokenFromCookie(c) + + if cookie != "" { + // 尝试获取 Token 信息用于日志 + authToken, err := h.securityService.AuthenticateToken(cookie) + if err == nil { + h.logger.WithFields(logrus.Fields{ + "token_id": authToken.ID, + "client_ip": c.ClientIP(), + }).Info("User logged out") + } else { + h.logger.WithField("client_ip", c.ClientIP()).Warn("Logout with invalid token") + } + + // 使缓存失效 + middleware.InvalidateTokenCache(cookie) + } else { + h.logger.WithField("client_ip", c.ClientIP()).Debug("Logout without session cookie") + } + + // 清除 Cookie middleware.ClearAdminSessionCookie(c) + + // 重定向到登录页 c.Redirect(http.StatusFound, "/login") } + +// validateRedirectPath 验证重定向路径(防止开放重定向攻击) +func (h *WebAuthHandler) validateRedirectPath(path string) string { + defaultPath := "/dashboard" + + if path == "" { + return defaultPath + } + + // 只允许内部路径 + if !strings.HasPrefix(path, "/") || strings.HasPrefix(path, "//") { + h.logger.WithField("path", path).Warn("Invalid redirect path blocked") + return defaultPath + } + + // 白名单验证 + allowedPaths := []string{ + "/dashboard", + "/keys", + "/settings", + "/logs", + "/tasks", + "/chat", + } + + for _, allowed := range allowedPaths { + if strings.HasPrefix(path, allowed) { + return path + } + } + + return defaultPath +} diff --git a/web/templates/base.html b/web/templates/base.html index 55df2f2..b213f55 100644 --- a/web/templates/base.html +++ b/web/templates/base.html @@ -123,7 +123,7 @@ {% block core_scripts %} - + {% endblock core_scripts %} diff --git a/web/templates/settings.html b/web/templates/settings.html index f2bea8c..271b1ab 100644 --- a/web/templates/settings.html +++ b/web/templates/settings.html @@ -492,7 +492,7 @@ type="text" id="TEST_MODEL" name="TEST_MODEL" - placeholder="gemini-1.5-flash" + placeholder="gemini-2.0-flash-lite" class="flex-grow px-4 py-3 rounded-lg form-input-themed" />