diff --git a/cmd/server/main.go b/cmd/server/main.go
index 6a40af4..2eda74c 100644
--- a/cmd/server/main.go
+++ b/cmd/server/main.go
@@ -3,10 +3,12 @@ package main
import (
"gemini-balancer/internal/app"
"gemini-balancer/internal/container"
+ "gemini-balancer/internal/logging"
"log"
)
func main() {
+ defer logging.Close()
cont, err := container.BuildContainer()
if err != nil {
log.Fatalf("FATAL: Failed to build dependency container: %v", err)
diff --git a/config.yaml b/config.yaml
index e498190..f0c1824 100644
--- a/config.yaml
+++ b/config.yaml
@@ -14,6 +14,12 @@ server:
log:
level: "debug"
+# Log rotation settings
+max_size: 100 # MB
+max_backups: 7 # number of rotated files to keep
+max_age: 30 # days to retain
+compress: true # compress rotated logs
+
redis:
dsn: "redis://localhost:6379/0"
diff --git a/go.mod b/go.mod
index e685f39..38d19c5 100644
--- a/go.mod
+++ b/go.mod
@@ -17,6 +17,8 @@ require (
github.com/spf13/viper v1.20.1
go.uber.org/dig v1.19.0
golang.org/x/net v0.42.0
+ golang.org/x/time v0.14.0
+ gopkg.in/natefinch/lumberjack.v2 v2.2.1
gorm.io/datatypes v1.0.5
gorm.io/driver/mysql v1.6.0
gorm.io/driver/postgres v1.6.0
diff --git a/go.sum b/go.sum
index cce6ad6..fc28ad2 100644
--- a/go.sum
+++ b/go.sum
@@ -311,6 +311,8 @@ golang.org/x/text v0.16.0/go.mod h1:GhwF1Be+LQoKShO3cGOHzqOgRrGaYc9AvblQOmPVHnI=
golang.org/x/text v0.20.0/go.mod h1:D4IsuqiFMhST5bX19pQ9ikHC2GsaKyk/oF+pn3ducp4=
golang.org/x/text v0.28.0 h1:rhazDwis8INMIwQ4tpjLDzUhx6RlXqZNPEM0huQojng=
golang.org/x/text v0.28.0/go.mod h1:U8nCwOR8jO/marOQ0QbDiOngZVEBB7MAiitBuMjXiNU=
+golang.org/x/time v0.14.0 h1:MRx4UaLrDotUKUdCIqzPC48t1Y9hANFKIRpNx+Te8PI=
+golang.org/x/time v0.14.0/go.mod h1:eL/Oa2bBBK0TkX57Fyni+NgnyQQN4LitPmob2Hjnqw4=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc=
@@ -325,6 +327,8 @@ gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk=
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q=
gopkg.in/errgo.v2 v2.1.0/go.mod h1:hNsd1EY+bozCKY1Ytp96fpM3vjJbqLJn88ws8XvfDNI=
+gopkg.in/natefinch/lumberjack.v2 v2.2.1 h1:bBRl1b0OH9s/DuPhuXpNl+VtCaJXFZ5/uEFST95x9zc=
+gopkg.in/natefinch/lumberjack.v2 v2.2.1/go.mod h1:YD8tP3GAjkrDg1eZH7EGmyESg/lsYskCTPBJVb9jqSc=
gopkg.in/yaml.v2 v2.2.1/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
gopkg.in/yaml.v2 v2.2.8/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
diff --git a/internal/config/config.go b/internal/config/config.go
index e9c877f..0027d5b 100644
--- a/internal/config/config.go
+++ b/internal/config/config.go
@@ -29,7 +29,9 @@ type DatabaseConfig struct {
// ServerConfig 存储HTTP服务器配置
type ServerConfig struct {
- Port string `mapstructure:"port"`
+ Port string `mapstructure:"port"`
+ Host string `mapstructure:"host"`
+ CORSOrigins []string `mapstructure:"cors_origins"`
}
// LogConfig 存储日志配置
@@ -38,6 +40,12 @@ type LogConfig struct {
Format string `mapstructure:"format" json:"format"`
EnableFile bool `mapstructure:"enable_file" json:"enable_file"`
FilePath string `mapstructure:"file_path" json:"file_path"`
+
+ // Log rotation settings (optional); viper binds via mapstructure tags
+ MaxSize int `mapstructure:"max_size"` // MB, defaults to 100
+ MaxBackups int `mapstructure:"max_backups"` // defaults to 7
+ MaxAge int `mapstructure:"max_age"` // days, defaults to 30
+ Compress bool `mapstructure:"compress"` // compress rotated logs
}
type RedisConfig struct {
diff --git a/internal/container/container.go b/internal/container/container.go
index f9aeb1e..e8a12e3 100644
--- a/internal/container/container.go
+++ b/internal/container/container.go
@@ -87,7 +87,7 @@ func BuildContainer() (*dig.Container, error) {
// 为GroupManager配置Syncer
container.Provide(func(loader syncer.LoaderFunc[service.GroupManagerCacheData], store store.Store, logger *logrus.Logger) (*syncer.CacheSyncer[service.GroupManagerCacheData], error) {
const groupUpdateChannel = "groups:cache_invalidation"
- return syncer.NewCacheSyncer(loader, store, groupUpdateChannel)
+ return syncer.NewCacheSyncer(loader, store, groupUpdateChannel, logger)
})
// =========== 阶段三: 适配器与处理器层 (Handlers & Adapters) ===========
diff --git a/internal/domain/proxy/module.go b/internal/domain/proxy/module.go
index 3cd794c..f1b59ae 100644
--- a/internal/domain/proxy/module.go
+++ b/internal/domain/proxy/module.go
@@ -20,7 +20,7 @@ type Module struct {
func NewModule(gormDB *gorm.DB, store store.Store, settingsManager *settings.SettingsManager, taskReporter task.Reporter, logger *logrus.Logger) (*Module, error) {
loader := newManagerLoader(gormDB)
- cacheSyncer, err := syncer.NewCacheSyncer(loader, store, "proxies:cache_invalidation")
+ cacheSyncer, err := syncer.NewCacheSyncer(loader, store, "proxies:cache_invalidation", logger)
if err != nil {
return nil, err
}
diff --git a/internal/errors/upstream_errors.go b/internal/errors/upstream_errors.go
index 751e35a..c85e456 100644
--- a/internal/errors/upstream_errors.go
+++ b/internal/errors/upstream_errors.go
@@ -71,6 +71,7 @@ var clientNetworkErrorSubstrings = []string{
"broken pipe",
"use of closed network connection",
"request canceled",
+ "invalid query parameters", // 参数解析错误,归类为客户端错误
}
// IsPermanentUpstreamError checks if an upstream error indicates the key is permanently invalid.
diff --git a/internal/handlers/log_handler.go b/internal/handlers/log_handler.go
index 8e1edeb..4440dbf 100644
--- a/internal/handlers/log_handler.go
+++ b/internal/handlers/log_handler.go
@@ -5,7 +5,6 @@ import (
"gemini-balancer/internal/errors"
"gemini-balancer/internal/response"
"gemini-balancer/internal/service"
- "strconv"
"github.com/gin-gonic/gin"
)
@@ -19,22 +18,26 @@ func NewLogHandler(logService *service.LogService) *LogHandler {
}
func (h *LogHandler) GetLogs(c *gin.Context) {
- // 调用新的服务函数,接收日志列表和总数
- logs, total, err := h.logService.GetLogs(c)
+ queryParams := make(map[string]string)
+ for key, values := range c.Request.URL.Query() {
+ if len(values) > 0 {
+ queryParams[key] = values[0]
+ }
+ }
+ params, err := service.ParseLogQueryParams(queryParams)
+ if err != nil {
+ response.Error(c, errors.ErrBadRequest)
+ return
+ }
+ logs, total, err := h.logService.GetLogs(c.Request.Context(), params)
if err != nil {
response.Error(c, errors.ErrDatabase)
return
}
-
- // 解析分页参数用于响应体
- page, _ := strconv.Atoi(c.DefaultQuery("page", "1"))
- pageSize, _ := strconv.Atoi(c.DefaultQuery("page_size", "20"))
-
- // 使用标准的分页响应结构
response.Success(c, gin.H{
"items": logs,
"total": total,
- "page": page,
- "page_size": pageSize,
+ "page": params.Page,
+ "page_size": params.PageSize,
})
}
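The handler above delegates parsing and validation to service.ParseLogQueryParams, which is not part of this diff. A minimal sketch of what such a helper could look like, assuming a params struct with Page and PageSize fields; everything here beyond those two names is hypothetical:

    // Hypothetical sketch; the real ParseLogQueryParams lives in internal/service.
    package service

    import (
        "fmt"
        "strconv"
    )

    type LogQueryParams struct {
        Page     int
        PageSize int
        // filter fields (level, date range, ...) omitted
    }

    func ParseLogQueryParams(q map[string]string) (*LogQueryParams, error) {
        p := &LogQueryParams{Page: 1, PageSize: 20}
        if v, ok := q["page"]; ok {
            n, err := strconv.Atoi(v)
            if err != nil || n < 1 {
                return nil, fmt.Errorf("invalid query parameters: page=%q", v)
            }
            p.Page = n
        }
        if v, ok := q["page_size"]; ok {
            n, err := strconv.Atoi(v)
            if err != nil || n < 1 || n > 500 {
                return nil, fmt.Errorf("invalid query parameters: page_size=%q", v)
            }
            p.PageSize = n
        }
        return p, nil
    }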
diff --git a/internal/logging/logging.go b/internal/logging/logging.go
index ddf766f..db75182 100644
--- a/internal/logging/logging.go
+++ b/internal/logging/logging.go
@@ -9,20 +9,25 @@ import (
"path/filepath"
"github.com/sirupsen/logrus"
+ "gopkg.in/natefinch/lumberjack.v2"
)
+// Package-level handle to the log rotator so Close can release it on shutdown
+var logRotator *lumberjack.Logger
+
+// NewLogger returns a standard *logrus.Logger (injectable via the dig container)
func NewLogger(cfg *config.Config) *logrus.Logger {
logger := logrus.New()
- // 1. 设置日志级别
+ // Set the log level
level, err := logrus.ParseLevel(cfg.Log.Level)
if err != nil {
- logger.WithField("configured_level", cfg.Log.Level).Warn("Invalid log level specified, defaulting to 'info'.")
+ logger.WithField("configured_level", cfg.Log.Level).Warn("Invalid log level, defaulting to 'info'")
level = logrus.InfoLevel
}
logger.SetLevel(level)
- // 2. 设置日志格式
+ // Set the log format
if cfg.Log.Format == "json" {
logger.SetFormatter(&logrus.JSONFormatter{
TimestampFormat: "2006-01-02T15:04:05.000Z07:00",
@@ -39,36 +44,57 @@ func NewLogger(cfg *config.Config) *logrus.Logger {
})
}
- // 3. 设置日志输出
+ // Attach global fields via a hook (see the sketch after this file's diff):
+ // WithFields returns an *logrus.Entry, and assigning entry.Logger back to
+ // logger would silently drop the fields.
+ hostname, _ := os.Hostname()
+ logger.AddHook(&globalFieldsHook{fields: logrus.Fields{
+ "service": "gemini-balancer",
+ "hostname": hostname,
+ }})
+
+ // Configure log output
if cfg.Log.EnableFile {
if cfg.Log.FilePath == "" {
- logger.Warn("Log file is enabled but no file path is specified. Logging to console only.")
+ logger.Warn("Log file enabled but no path specified. Logging to console only")
logger.SetOutput(os.Stdout)
return logger
}
logDir := filepath.Dir(cfg.Log.FilePath)
- if err := os.MkdirAll(logDir, 0755); err != nil {
- logger.WithError(err).Warn("Failed to create log directory. Logging to console only.")
+ if err := os.MkdirAll(logDir, 0750); err != nil {
+ logger.WithError(err).Warn("Failed to create log directory. Logging to console only")
logger.SetOutput(os.Stdout)
return logger
}
- logFile, err := os.OpenFile(cfg.Log.FilePath, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0666)
- if err != nil {
- logger.WithError(err).Warn("Failed to open log file. Logging to console only.")
- logger.SetOutput(os.Stdout)
- return logger
+ // Configure log rotation (kept in the package-level rotator)
+ logRotator = &lumberjack.Logger{
+ Filename: cfg.Log.FilePath,
+ MaxSize: getOrDefault(cfg.Log.MaxSize, 100),
+ MaxBackups: getOrDefault(cfg.Log.MaxBackups, 7),
+ MaxAge: getOrDefault(cfg.Log.MaxAge, 30),
+ Compress: cfg.Log.Compress,
}
- // 同时输出到控制台和文件
- logger.SetOutput(io.MultiWriter(os.Stdout, logFile))
- logger.WithField("log_file_path", cfg.Log.FilePath).Info("Logging is now configured to output to both console and file.")
+ logger.SetOutput(io.MultiWriter(os.Stdout, logRotator))
+ logger.WithField("log_file", cfg.Log.FilePath).Info("Logging to both console and file")
} else {
- // 仅输出到控制台
logger.SetOutput(os.Stdout)
}
- logger.Info("Root logger initialized.")
+ logger.Info("Logger initialized successfully")
return logger
}
+
+// Close closes the log rotator (deferred in main.go). Note that deferred calls
+// are skipped when the process exits via log.Fatalf/os.Exit.
+func Close() {
+ if logRotator != nil {
+ logRotator.Close()
+ }
+}
+
+func getOrDefault(value, defaultValue int) int {
+ if value <= 0 {
+ return defaultValue
+ }
+ return value
+}
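The AddHook call above needs a small hook type, because logrus has no per-logger default fields. The hook is not part of the original patch, so a minimal sketch is given here; the name globalFieldsHook is this edit's assumption:

    // globalFieldsHook copies a fixed set of fields onto every log entry.
    type globalFieldsHook struct {
        fields logrus.Fields
    }

    func (h *globalFieldsHook) Levels() []logrus.Level {
        return logrus.AllLevels
    }

    func (h *globalFieldsHook) Fire(entry *logrus.Entry) error {
        for k, v := range h.fields {
            if _, exists := entry.Data[k]; !exists {
                entry.Data[k] = v
            }
        }
        return nil
    }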
diff --git a/internal/middleware/auth.go b/internal/middleware/auth.go
index bef9658..91ea18b 100644
--- a/internal/middleware/auth.go
+++ b/internal/middleware/auth.go
@@ -1,4 +1,5 @@
// Filename: internal/middleware/auth.go
+
package middleware
import (
@@ -7,76 +8,115 @@ import (
"strings"
"github.com/gin-gonic/gin"
+ "github.com/sirupsen/logrus"
)
-// === API Admin 认证管道 (/admin/* API路由) ===
+type ErrorResponse struct {
+ Error string `json:"error"`
+ Code string `json:"code,omitempty"`
+}
-func APIAdminAuthMiddleware(securityService *service.SecurityService) gin.HandlerFunc {
+// APIAdminAuthMiddleware authenticates admin API requests
+func APIAdminAuthMiddleware(
+ securityService *service.SecurityService,
+ logger *logrus.Logger,
+) gin.HandlerFunc {
return func(c *gin.Context) {
tokenValue := extractBearerToken(c)
if tokenValue == "" {
- c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "Authorization token is missing"})
+ c.AbortWithStatusJSON(http.StatusUnauthorized, ErrorResponse{
+ Error: "Authentication required",
+ Code: "AUTH_MISSING",
+ })
return
}
+
+ // AuthenticateToken now takes only the token (the context parameter was removed)
authToken, err := securityService.AuthenticateToken(tokenValue)
- if err != nil || !authToken.IsAdmin {
- c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "Invalid or non-admin token"})
+ if err != nil {
+ logger.WithError(err).Warn("Authentication failed")
+ c.AbortWithStatusJSON(http.StatusUnauthorized, ErrorResponse{
+ Error: "Invalid authentication",
+ Code: "AUTH_INVALID",
+ })
return
}
+
+ if !authToken.IsAdmin {
+ c.AbortWithStatusJSON(http.StatusForbidden, ErrorResponse{
+ Error: "Admin access required",
+ Code: "AUTH_FORBIDDEN",
+ })
+ return
+ }
+
c.Set("adminUser", authToken)
c.Next()
}
}
-// === /v1 Proxy 认证 ===
-
-func ProxyAuthMiddleware(securityService *service.SecurityService) gin.HandlerFunc {
+// ProxyAuthMiddleware authenticates proxied API requests
+func ProxyAuthMiddleware(
+ securityService *service.SecurityService,
+ logger *logrus.Logger,
+) gin.HandlerFunc {
return func(c *gin.Context) {
tokenValue := extractProxyToken(c)
if tokenValue == "" {
- c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "API key is missing from request"})
+ c.AbortWithStatusJSON(http.StatusUnauthorized, ErrorResponse{
+ Error: "API key required",
+ Code: "KEY_MISSING",
+ })
return
}
+
+ // AuthenticateToken now takes only the token (the context parameter was removed)
authToken, err := securityService.AuthenticateToken(tokenValue)
if err != nil {
- // 通用信息,避免泄露过多信息
- c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "Invalid or inactive token provided"})
+ c.AbortWithStatusJSON(http.StatusUnauthorized, ErrorResponse{
+ Error: "Invalid API key",
+ Code: "KEY_INVALID",
+ })
return
}
+
c.Set("authToken", authToken)
c.Next()
}
}
+// extractProxyToken extracts the API key, trying sources in priority order
func extractProxyToken(c *gin.Context) string {
- if key := c.Query("key"); key != "" {
- return key
- }
- authHeader := c.GetHeader("Authorization")
- if authHeader != "" {
- if strings.HasPrefix(authHeader, "Bearer ") {
- return strings.TrimPrefix(authHeader, "Bearer ")
- }
+ // Priority 1: Authorization header
+ if token := extractBearerToken(c); token != "" {
+ return token
}
+
+ // Priority 2: X-Api-Key header
if key := c.GetHeader("X-Api-Key"); key != "" {
return key
}
+
+ // Priority 3: X-Goog-Api-Key header
if key := c.GetHeader("X-Goog-Api-Key"); key != "" {
return key
}
- return ""
+
+ // Priority 4: query parameter (discouraged; keys can end up in access logs)
+ return c.Query("key")
}
-// === 辅助函数 ===
-
+// extractBearerToken extracts the token from a Bearer Authorization header
func extractBearerToken(c *gin.Context) string {
authHeader := c.GetHeader("Authorization")
if authHeader == "" {
return ""
}
- parts := strings.Split(authHeader, " ")
- if len(parts) == 2 && parts[0] == "Bearer" {
- return parts[1]
+
+ const prefix = "Bearer "
+ if !strings.HasPrefix(authHeader, prefix) {
+ return ""
}
- return ""
+
+ return strings.TrimSpace(authHeader[len(prefix):])
}
diff --git a/internal/middleware/cors.go b/internal/middleware/cors.go
new file mode 100644
index 0000000..c94143f
--- /dev/null
+++ b/internal/middleware/cors.go
@@ -0,0 +1,90 @@
+// Filename: internal/middleware/cors.go
+
+package middleware
+
+import (
+ "net/http"
+ "strings"
+
+ "github.com/gin-gonic/gin"
+)
+
+type CORSConfig struct {
+ AllowedOrigins []string
+ AllowedMethods []string
+ AllowedHeaders []string
+ ExposedHeaders []string
+ AllowCredentials bool
+ MaxAge int
+}
+
+func CORSMiddleware(config CORSConfig) gin.HandlerFunc {
+ return func(c *gin.Context) {
+ origin := c.Request.Header.Get("Origin")
+
+ // Reflect the origin only if it is on the allow list
+ if origin != "" && isOriginAllowed(origin, config.AllowedOrigins) {
+ c.Writer.Header().Set("Access-Control-Allow-Origin", origin)
+ }
+
+ if config.AllowCredentials {
+ c.Writer.Header().Set("Access-Control-Allow-Credentials", "true")
+ }
+
+ if len(config.ExposedHeaders) > 0 {
+ c.Writer.Header().Set("Access-Control-Expose-Headers",
+ strings.Join(config.ExposedHeaders, ", "))
+ }
+
+ // Handle preflight requests
+ if c.Request.Method == http.MethodOptions {
+ if len(config.AllowedMethods) > 0 {
+ c.Writer.Header().Set("Access-Control-Allow-Methods",
+ strings.Join(config.AllowedMethods, ", "))
+ }
+
+ if len(config.AllowedHeaders) > 0 {
+ c.Writer.Header().Set("Access-Control-Allow-Headers",
+ strings.Join(config.AllowedHeaders, ", "))
+ }
+
+ if config.MaxAge > 0 {
+ // string(rune(n)) would yield a Unicode code point, not a decimal string
+ c.Writer.Header().Set("Access-Control-Max-Age",
+ strconv.Itoa(config.MaxAge))
+ }
+
+ c.AbortWithStatus(http.StatusNoContent)
+ return
+ }
+
+ c.Next()
+ }
+}
+
+func isOriginAllowed(origin string, allowedOrigins []string) bool {
+ for _, allowed := range allowedOrigins {
+ if allowed == "*" || allowed == origin {
+ return true
+ }
+ // Support wildcard subdomains (e.g. *.example.com)
+ if strings.HasPrefix(allowed, "*.") {
+ domain := strings.TrimPrefix(allowed, "*.")
+ // Require the "." boundary so "evil-example.com" cannot match "*.example.com"
+ if strings.HasSuffix(origin, "."+domain) {
+ return true
+ }
+ }
+ }
+ return false
+}
+
+// SetupCORS shows example wiring for this middleware.
+func SetupCORS(r *gin.Engine) {
+ r.Use(CORSMiddleware(CORSConfig{
+ AllowedOrigins: []string{"https://yourdomain.com", "*.yourdomain.com"},
+ AllowedMethods: []string{"GET", "POST", "PUT", "DELETE", "OPTIONS"},
+ AllowedHeaders: []string{"Authorization", "Content-Type", "X-Api-Key"},
+ ExposedHeaders: []string{"X-Request-Id"},
+ AllowCredentials: true,
+ MaxAge: 3600,
+ }))
+}
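With the "."-boundary check above, wildcard entries match real subdomains only; a small test sketch (not part of the patch):

    func TestIsOriginAllowed(t *testing.T) {
        allowed := []string{"https://app.example.com", "*.example.com"}

        if !isOriginAllowed("https://app.example.com", allowed) {
            t.Fatal("exact origin should be allowed")
        }
        if !isOriginAllowed("https://api.example.com", allowed) {
            t.Fatal("subdomains should match the *.example.com entry")
        }
        if isOriginAllowed("https://evil-example.com", allowed) {
            t.Fatal("suffix lookalikes must not match")
        }
    }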
diff --git a/internal/middleware/log.go b/internal/middleware/log.go
index 4a35e9e..cce1a56 100644
--- a/internal/middleware/log.go
+++ b/internal/middleware/log.go
@@ -1,84 +1,213 @@
-// Filename: internal/middleware/log_redaction.go
+// Filename: internal/middleware/log.go
+
package middleware
import (
"bytes"
"io"
"regexp"
+ "strings"
"time"
"github.com/gin-gonic/gin"
"github.com/sirupsen/logrus"
)
-const RedactedBodyKey = "redactedBody"
-const RedactedAuthHeaderKey = "redactedAuthHeader"
-const RedactedValue = `"[REDACTED]"`
+const (
+ RedactedBodyKey = "redactedBody"
+ RedactedAuthHeaderKey = "redactedAuthHeader"
+ RedactedValue = `"[REDACTED]"`
+)
+// Precompiled regular expressions (package-level, compiled once)
+var (
+ // Redacts sensitive JSON fields
+ jsonSensitiveKeys = regexp.MustCompile(`("(?i:api_key|apikey|token|password|secret|authorization|key|keys|auth)"\s*:\s*)"[^"]*"`)
+
+ // Redacts Bearer tokens
+ bearerTokenPattern = regexp.MustCompile(`^(Bearer\s+)\S+$`)
+
+ // Redacts key-like query parameters
+ queryKeyPattern = regexp.MustCompile(`([?&](?i:key|token|apikey)=)[^&\s]+`)
+)
+
+// RedactionMiddleware redacts sensitive request data before it is logged
func RedactionMiddleware() gin.HandlerFunc {
- // Pre-compile regex for efficiency
- jsonKeyPattern := regexp.MustCompile(`("api_key"|"keys")\s*:\s*"[^"]*"`)
- bearerTokenPattern := regexp.MustCompile(`^(Bearer\s+)\S+$`)
return func(c *gin.Context) {
- // --- 1. Redact Request Body ---
- if c.Request.Method == "POST" || c.Request.Method == "PUT" || c.Request.Method == "DELETE" {
- if bodyBytes, err := io.ReadAll(c.Request.Body); err == nil {
-
- c.Request.Body = io.NopCloser(bytes.NewBuffer(bodyBytes))
-
- bodyString := string(bodyBytes)
-
- redactedBody := jsonKeyPattern.ReplaceAllString(bodyString, `$1:`+RedactedValue)
-
- c.Set(RedactedBodyKey, redactedBody)
- }
- }
- // --- 2. Redact Authorization Header ---
- authHeader := c.GetHeader("Authorization")
- if authHeader != "" {
- if bearerTokenPattern.MatchString(authHeader) {
- redactedHeader := bearerTokenPattern.ReplaceAllString(authHeader, `${1}[REDACTED]`)
- c.Set(RedactedAuthHeaderKey, redactedHeader)
- }
+ // 1. Redact the request body
+ if shouldRedactBody(c) {
+ redactRequestBody(c)
}
+
+ // 2. Redact the Authorization header
+ redactAuthHeader(c)
+
+ // 3. Redact URL query parameters
+ redactQueryParams(c)
+
c.Next()
}
}
-// LogrusLogger is a Gin middleware that logs requests using a Logrus logger.
-// It consumes redacted data prepared by the RedactionMiddleware.
+// shouldRedactBody reports whether the request body should be captured and redacted
+func shouldRedactBody(c *gin.Context) bool {
+ method := c.Request.Method
+ contentType := c.GetHeader("Content-Type")
+
+ // Only JSON bodies on POST/PUT/PATCH/DELETE requests are captured
+ return (method == "POST" || method == "PUT" || method == "PATCH" || method == "DELETE") &&
+ strings.Contains(contentType, "application/json")
+}
+
+// redactRequestBody reads the body, stores a redacted copy, and restores the body
+func redactRequestBody(c *gin.Context) {
+ bodyBytes, err := io.ReadAll(c.Request.Body)
+ if err != nil {
+ return
+ }
+
+ // Restore the body so downstream handlers can still read it
+ c.Request.Body = io.NopCloser(bytes.NewBuffer(bodyBytes))
+
+ // Redact sensitive fields
+ bodyString := string(bodyBytes)
+ redactedBody := jsonSensitiveKeys.ReplaceAllString(bodyString, `$1`+RedactedValue)
+
+ c.Set(RedactedBodyKey, redactedBody)
+}
+
+// redactAuthHeader redacts the Authorization header and other credential headers
+func redactAuthHeader(c *gin.Context) {
+ authHeader := c.GetHeader("Authorization")
+ if authHeader == "" {
+ return
+ }
+
+ if bearerTokenPattern.MatchString(authHeader) {
+ redacted := bearerTokenPattern.ReplaceAllString(authHeader, `${1}[REDACTED]`)
+ c.Set(RedactedAuthHeaderKey, redacted)
+ } else {
+ // Non-Bearer credentials are redacted entirely
+ c.Set(RedactedAuthHeaderKey, "[REDACTED]")
+ }
+
+ // Mark other sensitive headers as redacted as well
+ sensitiveHeaders := []string{"X-Api-Key", "X-Goog-Api-Key", "Api-Key"}
+ for _, header := range sensitiveHeaders {
+ if value := c.GetHeader(header); value != "" {
+ c.Set("redacted_"+header, "[REDACTED]")
+ }
+ }
+}
+
+// redactQueryParams redacts key-like query parameters in the URL
+func redactQueryParams(c *gin.Context) {
+ rawQuery := c.Request.URL.RawQuery
+ if rawQuery == "" {
+ return
+ }
+
+ redacted := queryKeyPattern.ReplaceAllString(rawQuery, `${1}[REDACTED]`)
+ if redacted != rawQuery {
+ c.Set("redactedQuery", redacted)
+ }
+}
+
+// LogrusLogger is a Gin request-logging middleware backed by logrus; it consumes
+// the redacted data prepared by RedactionMiddleware.
func LogrusLogger(logger *logrus.Logger) gin.HandlerFunc {
return func(c *gin.Context) {
start := time.Now()
path := c.Request.URL.Path
+ method := c.Request.Method
- // Process request
+ // Process the request
c.Next()
- // After request, gather data and log
+ // Compute latency
latency := time.Since(start)
statusCode := c.Writer.Status()
+ clientIP := c.ClientIP()
- entry := logger.WithFields(logrus.Fields{
- "status_code": statusCode,
- "latency_ms": latency.Milliseconds(),
- "client_ip": c.ClientIP(),
- "method": c.Request.Method,
- "path": path,
- })
+ // Build the log fields
+ fields := logrus.Fields{
+ "status": statusCode,
+ "latency_ms": latency.Milliseconds(),
+ "ip": clientIP,
+ "method": method,
+ "path": path,
+ }
+ // Add the request ID if present
+ if requestID := getRequestID(c); requestID != "" {
+ fields["request_id"] = requestID
+ }
+
+ // Attach the redacted request data
if redactedBody, exists := c.Get(RedactedBodyKey); exists {
- entry = entry.WithField("body", redactedBody)
+ fields["body"] = redactedBody
}
if redactedAuth, exists := c.Get(RedactedAuthHeaderKey); exists {
- entry = entry.WithField("authorization", redactedAuth)
+ fields["authorization"] = redactedAuth
}
+ if redactedQuery, exists := c.Get("redactedQuery"); exists {
+ fields["query"] = redactedQuery
+ }
+
+ // Attach the authenticated user identity, if any
+ if user := getAuthenticatedUser(c); user != "" {
+ fields["user"] = user
+ }
+
+ // Pick the log level from the status code
+ entry := logger.WithFields(fields)
+
if len(c.Errors) > 0 {
- entry.Error(c.Errors.String())
+ fields["errors"] = c.Errors.String()
+ entry.Error("Request failed")
} else {
- entry.Info("request handled")
+ switch {
+ case statusCode >= 500:
+ entry.Error("Server error")
+ case statusCode >= 400:
+ entry.Warn("Client error")
+ case statusCode >= 300:
+ entry.Info("Redirect")
+ default:
+ // Successful requests are logged at debug level only
+ if logger.Level >= logrus.DebugLevel {
+ entry.Debug("Request completed")
+ }
+ }
}
}
}
+
+// getRequestID returns the request ID set by RequestIDMiddleware, if any
+func getRequestID(c *gin.Context) string {
+ if id, exists := c.Get("request_id"); exists {
+ if requestID, ok := id.(string); ok {
+ return requestID
+ }
+ }
+ return ""
+}
+
+// getAuthenticatedUser returns an identifier for the authenticated caller, if any
+func getAuthenticatedUser(c *gin.Context) string {
+ // Try the admin context first, then the proxy auth token
+ if user, exists := c.Get("adminUser"); exists {
+ if authToken, ok := user.(interface{ GetID() string }); ok {
+ return authToken.GetID()
+ }
+ }
+
+ if user, exists := c.Get("authToken"); exists {
+ if authToken, ok := user.(interface{ GetID() string }); ok {
+ return authToken.GetID()
+ }
+ }
+
+ return ""
+}
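For reference, a standalone sketch of what the redaction patterns above produce on sample input (illustrative values only):

    package main

    import (
        "fmt"
        "regexp"
    )

    func main() {
        jsonSensitiveKeys := regexp.MustCompile(`("(?i:api_key|apikey|token|password|secret|authorization|key|keys|auth)"\s*:\s*)"[^"]*"`)
        queryKeyPattern := regexp.MustCompile(`([?&](?i:key|token|apikey)=)[^&\s]+`)

        body := `{"model":"gemini-pro","api_key":"AIzaSy-very-secret"}`
        fmt.Println(jsonSensitiveKeys.ReplaceAllString(body, `$1"[REDACTED]"`))
        // {"model":"gemini-pro","api_key":"[REDACTED]"}

        query := "/v1/models?alt=json&key=AIzaSy-very-secret"
        fmt.Println(queryKeyPattern.ReplaceAllString(query, `${1}[REDACTED]`))
        // /v1/models?alt=json&key=[REDACTED]
    }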
diff --git a/internal/middleware/rate_limit.go b/internal/middleware/rate_limit.go
new file mode 100644
index 0000000..3c5bf91
--- /dev/null
+++ b/internal/middleware/rate_limit.go
@@ -0,0 +1,86 @@
+// Filename: internal/middleware/rate_limit.go
+
+package middleware
+
+import (
+ "net/http"
+ "sync"
+ "time"
+
+ "github.com/gin-gonic/gin"
+ "golang.org/x/time/rate"
+)
+
+type RateLimiter struct {
+ limiters map[string]*rate.Limiter
+ mu sync.RWMutex
+ r rate.Limit // requests per second
+ b int // burst size
+}
+
+func NewRateLimiter(r rate.Limit, b int) *RateLimiter {
+ return &RateLimiter{
+ limiters: make(map[string]*rate.Limiter),
+ r: r,
+ b: b,
+ }
+}
+
+func (rl *RateLimiter) getLimiter(key string) *rate.Limiter {
+ rl.mu.Lock()
+ defer rl.mu.Unlock()
+
+ limiter, exists := rl.limiters[key]
+ if !exists {
+ limiter = rate.NewLimiter(rl.r, rl.b)
+ rl.limiters[key] = limiter
+ }
+
+ return limiter
+}
+
+// cleanup periodically resets the limiter map to bound memory use
+func (rl *RateLimiter) cleanup() {
+ ticker := time.NewTicker(10 * time.Minute)
+ defer ticker.Stop()
+
+ for range ticker.C {
+ rl.mu.Lock()
+ // Simple strategy: clear everything periodically (an LRU is better for production)
+ rl.limiters = make(map[string]*rate.Limiter)
+ rl.mu.Unlock()
+ }
+}
+
+func RateLimitMiddleware(limiter *RateLimiter) gin.HandlerFunc {
+ return func(c *gin.Context) {
+ // Rate-limit by client IP by default
+ key := c.ClientIP()
+
+ // If an auth token is present, rate-limit per token instead (more precise)
+ if authToken, exists := c.Get("authToken"); exists {
+ if token, ok := authToken.(interface{ GetID() string }); ok {
+ key = "token:" + token.GetID()
+ }
+ }
+
+ l := limiter.getLimiter(key)
+ if !l.Allow() {
+ c.AbortWithStatusJSON(http.StatusTooManyRequests, gin.H{
+ "error": "Rate limit exceeded",
+ "code": "RATE_LIMIT",
+ })
+ return
+ }
+
+ c.Next()
+ }
+}
+
+// SetupRateLimit shows example wiring for this middleware.
+func SetupRateLimit(r *gin.Engine) {
+ limiter := NewRateLimiter(10, 20) // 10 requests per second, burst of 20
+ go limiter.cleanup()
+
+ r.Use(RateLimitMiddleware(limiter))
+}
diff --git a/internal/middleware/request_id.go b/internal/middleware/request_id.go
new file mode 100644
index 0000000..ca5fbb2
--- /dev/null
+++ b/internal/middleware/request_id.go
@@ -0,0 +1,39 @@
+// Filename: internal/middleware/request_id.go
+
+package middleware
+
+import (
+ "github.com/gin-gonic/gin"
+ "github.com/google/uuid"
+)
+
+// RequestIDMiddleware attaches a request ID to every request for tracing
+func RequestIDMiddleware() gin.HandlerFunc {
+ return func(c *gin.Context) {
+ // 1. Reuse an incoming X-Request-Id header if present
+ requestID := c.GetHeader("X-Request-Id")
+
+ // 2. Otherwise generate a new one
+ if requestID == "" {
+ requestID = uuid.New().String()
+ }
+
+ // 3. Store it on the context
+ c.Set("request_id", requestID)
+
+ // 4. Echo it back to the client for correlation
+ c.Writer.Header().Set("X-Request-Id", requestID)
+
+ c.Next()
+ }
+}
+
+// GetRequestID returns the request ID of the current request, if any
+func GetRequestID(c *gin.Context) string {
+ if id, exists := c.Get("request_id"); exists {
+ if requestID, ok := id.(string); ok {
+ return requestID
+ }
+ }
+ return ""
+}
diff --git a/internal/middleware/security.go b/internal/middleware/security.go
index 673ce97..3850426 100644
--- a/internal/middleware/security.go
+++ b/internal/middleware/security.go
@@ -1,4 +1,4 @@
-// Filename: internal/middleware/security.go
+// Filename: internal/middleware/security.go (simplified)
package middleware
@@ -6,26 +6,136 @@ import (
"gemini-balancer/internal/service"
"gemini-balancer/internal/settings"
"net/http"
+ "sync"
+ "time"
"github.com/gin-gonic/gin"
+ "github.com/sirupsen/logrus"
)
-func IPBanMiddleware(securityService *service.SecurityService, settingsManager *settings.SettingsManager) gin.HandlerFunc {
+// cacheItem is a single cached ban-check result
+type cacheItem struct {
+ value bool
+ expiration int64
+}
+
+// IPBanCache is a minimal TTL cache for ban lookups
+type IPBanCache struct {
+ items map[string]*cacheItem
+ mu sync.RWMutex
+ ttl time.Duration
+}
+
+func NewIPBanCache() *IPBanCache {
+ cache := &IPBanCache{
+ items: make(map[string]*cacheItem),
+ ttl: 1 * time.Minute,
+ }
+
+ // Start the background cleanup goroutine
+ go cache.cleanup()
+
+ return cache
+}
+
+func (c *IPBanCache) Get(key string) (bool, bool) {
+ c.mu.RLock()
+ defer c.mu.RUnlock()
+
+ item, exists := c.items[key]
+ if !exists {
+ return false, false
+ }
+
+ // Treat expired entries as missing
+ if time.Now().UnixNano() > item.expiration {
+ return false, false
+ }
+
+ return item.value, true
+}
+
+func (c *IPBanCache) Set(key string, value bool) {
+ c.mu.Lock()
+ defer c.mu.Unlock()
+
+ c.items[key] = &cacheItem{
+ value: value,
+ expiration: time.Now().Add(c.ttl).UnixNano(),
+ }
+}
+
+func (c *IPBanCache) Delete(key string) {
+ c.mu.Lock()
+ defer c.mu.Unlock()
+ delete(c.items, key)
+}
+
+func (c *IPBanCache) cleanup() {
+ ticker := time.NewTicker(5 * time.Minute)
+ defer ticker.Stop()
+
+ for range ticker.C {
+ c.mu.Lock()
+ now := time.Now().UnixNano()
+ for key, item := range c.items {
+ if now > item.expiration {
+ delete(c.items, key)
+ }
+ }
+ c.mu.Unlock()
+ }
+}
+
+func IPBanMiddleware(
+ securityService *service.SecurityService,
+ settingsManager *settings.SettingsManager,
+ banCache *IPBanCache,
+ logger *logrus.Logger,
+) gin.HandlerFunc {
return func(c *gin.Context) {
if !settingsManager.IsIPBanEnabled() {
c.Next()
return
}
+
ip := c.ClientIP()
- isBanned, err := securityService.IsIPBanned(c.Request.Context(), ip)
- if err != nil {
+
+ // Check the cache first
+ if isBanned, exists := banCache.Get(ip); exists {
+ if isBanned {
+ logger.WithField("ip", ip).Debug("IP blocked (cached)")
+ c.AbortWithStatusJSON(http.StatusForbidden, gin.H{
+ "error": "Access denied",
+ })
+ return
+ }
c.Next()
return
}
- if isBanned {
- c.AbortWithStatusJSON(http.StatusForbidden, gin.H{"error": "您的IP已被暂时封禁,请稍后再试"})
+
+ // Fall back to the database check
+ ctx := c.Request.Context()
+ isBanned, err := securityService.IsIPBanned(ctx, ip)
+ if err != nil {
+ logger.WithError(err).WithField("ip", ip).Error("Failed to check IP ban status")
+
+ // Fail open: allow the request when the ban check errors
+ c.Next()
return
}
+
+ // Update the cache
+ banCache.Set(ip, isBanned)
+
+ if isBanned {
+ logger.WithField("ip", ip).Info("IP blocked")
+ c.AbortWithStatusJSON(http.StatusForbidden, gin.H{
+ "error": "Access denied",
+ })
+ return
+ }
+
c.Next()
}
}
diff --git a/internal/middleware/timeout.go b/internal/middleware/timeout.go
new file mode 100644
index 0000000..b285702
--- /dev/null
+++ b/internal/middleware/timeout.go
@@ -0,0 +1,52 @@
+// Filename: internal/middleware/timeout.go
+
+package middleware
+
+import (
+ "context"
+ "net/http"
+ "time"
+
+ "github.com/gin-gonic/gin"
+)
+
+func TimeoutMiddleware(timeout time.Duration) gin.HandlerFunc {
+ return func(c *gin.Context) {
+ // Create a context with the timeout
+ ctx, cancel := context.WithTimeout(c.Request.Context(), timeout)
+ defer cancel()
+
+ // Swap it into the request
+ c.Request = c.Request.WithContext(ctx)
+
+ // Wait for the handler to finish on a channel.
+ // Caveat: on timeout the goroutine running c.Next() keeps going, so the handler
+ // and the timeout response can race on the shared ResponseWriter; a buffering
+ // wrapper (e.g. gin-contrib/timeout) is safer for production use.
+ finished := make(chan struct{})
+
+ go func() {
+ c.Next()
+ close(finished)
+ }()
+
+ select {
+ case <-finished:
+ // The handler finished normally
+ return
+ case <-ctx.Done():
+ // Timed out
+ c.AbortWithStatusJSON(http.StatusGatewayTimeout, gin.H{
+ "error": "Request timeout",
+ "code": "TIMEOUT",
+ })
+ }
+ }
+}
+
+// SetupTimeout shows example wiring for this middleware.
+func SetupTimeout(r *gin.Engine) {
+ // Apply a 30-second timeout to the API routes
+ api := r.Group("/api")
+ api.Use(TimeoutMiddleware(30 * time.Second))
+ {
+ // ... API routes
+ }
+}
diff --git a/internal/middleware/web.go b/internal/middleware/web.go
index 8d32dba..7831e01 100644
--- a/internal/middleware/web.go
+++ b/internal/middleware/web.go
@@ -1,23 +1,151 @@
// Filename: internal/middleware/web.go
+
package middleware
import (
+ "crypto/sha256"
+ "encoding/hex"
"gemini-balancer/internal/service"
- "log"
"net/http"
+ "os"
+ "strings"
+ "sync"
+ "time"
"github.com/gin-gonic/gin"
+ "github.com/sirupsen/logrus"
)
const (
AdminSessionCookie = "gemini_admin_session"
+ SessionMaxAge = 3600 * 24 * 7 // 7 days
+ CacheTTL = 5 * time.Minute
+ CleanupInterval = 10 * time.Minute // cleanup runs less often than before
+ SessionRefreshTime = 30 * time.Minute
)
+// ==================== Cache layer ====================
+
+type authCacheEntry struct {
+ Token interface{}
+ ExpiresAt time.Time
+}
+
+type authCache struct {
+ mu sync.RWMutex
+ cache map[string]*authCacheEntry
+ ttl time.Duration
+}
+
+var webAuthCache = newAuthCache(CacheTTL)
+
+func newAuthCache(ttl time.Duration) *authCache {
+ c := &authCache{
+ cache: make(map[string]*authCacheEntry),
+ ttl: ttl,
+ }
+ go c.cleanupLoop()
+ return c
+}
+
+func (c *authCache) get(key string) (interface{}, bool) {
+ c.mu.RLock()
+ defer c.mu.RUnlock()
+
+ entry, exists := c.cache[key]
+ if !exists || time.Now().After(entry.ExpiresAt) {
+ return nil, false
+ }
+ return entry.Token, true
+}
+
+func (c *authCache) set(key string, token interface{}) {
+ c.mu.Lock()
+ defer c.mu.Unlock()
+
+ c.cache[key] = &authCacheEntry{
+ Token: token,
+ ExpiresAt: time.Now().Add(c.ttl),
+ }
+}
+
+func (c *authCache) delete(key string) {
+ c.mu.Lock()
+ defer c.mu.Unlock()
+ delete(c.cache, key)
+}
+
+func (c *authCache) cleanupLoop() {
+ ticker := time.NewTicker(CleanupInterval)
+ defer ticker.Stop()
+
+ for range ticker.C {
+ c.cleanup()
+ }
+}
+
+func (c *authCache) cleanup() {
+ c.mu.Lock()
+ defer c.mu.Unlock()
+
+ now := time.Now()
+ count := 0
+ for key, entry := range c.cache {
+ if now.After(entry.ExpiresAt) {
+ delete(c.cache, key)
+ count++
+ }
+ }
+
+ if count > 0 {
+ logrus.Debugf("[AuthCache] Cleaned up %d expired entries", count)
+ }
+}
+
+// ==================== Session refresh cache ====================
+
+var sessionRefreshCache = struct {
+ sync.RWMutex
+ timestamps map[string]time.Time
+}{
+ timestamps: make(map[string]time.Time),
+}
+
+// Periodically prune stale refresh timestamps
+func init() {
+ go func() {
+ ticker := time.NewTicker(1 * time.Hour)
+ defer ticker.Stop()
+
+ for range ticker.C {
+ sessionRefreshCache.Lock()
+ now := time.Now()
+ for key, ts := range sessionRefreshCache.timestamps {
+ if now.Sub(ts) > 2*time.Hour {
+ delete(sessionRefreshCache.timestamps, key)
+ }
+ }
+ sessionRefreshCache.Unlock()
+ }
+ }()
+}
+
+// ==================== Cookie helpers ====================
+
func SetAdminSessionCookie(c *gin.Context, adminToken string) {
- c.SetCookie(AdminSessionCookie, adminToken, 3600*24*7, "/", "", false, true)
+ secure := c.Request.TLS != nil || c.GetHeader("X-Forwarded-Proto") == "https"
+ c.SetSameSite(http.SameSiteStrictMode)
+ c.SetCookie(AdminSessionCookie, adminToken, SessionMaxAge, "/", "", secure, true)
+}
+
+func SetAdminSessionCookieWithAge(c *gin.Context, adminToken string, maxAge int) {
+ secure := c.Request.TLS != nil || c.GetHeader("X-Forwarded-Proto") == "https"
+ c.SetSameSite(http.SameSiteStrictMode)
+ c.SetCookie(AdminSessionCookie, adminToken, maxAge, "/", "", secure, true)
}
func ClearAdminSessionCookie(c *gin.Context) {
+ c.SetSameSite(http.SameSiteStrictMode)
c.SetCookie(AdminSessionCookie, "", -1, "/", "", false, true)
}
@@ -29,26 +157,258 @@ func ExtractTokenFromCookie(c *gin.Context) string {
return cookie
}
+// ==================== Auth middleware ====================
+
func WebAdminAuthMiddleware(authService *service.SecurityService) gin.HandlerFunc {
+ logger := logrus.New()
+ logger.SetLevel(getLogLevel())
+
return func(c *gin.Context) {
cookie := ExtractTokenFromCookie(c)
- log.Printf("[WebAuth_Guard] Intercepting request for: %s", c.Request.URL.Path)
- log.Printf("[WebAuth_Guard] Found session cookie value: '%s'", cookie)
- authToken, err := authService.AuthenticateToken(cookie)
- if err != nil {
- log.Printf("[WebAuth_Guard] FATAL: AuthenticateToken FAILED. Error: %v. Redirecting to /login.", err)
- } else if !authToken.IsAdmin {
- log.Printf("[WebAuth_Guard] FATAL: Token validated, but IsAdmin is FALSE. Redirecting to /login.")
- } else {
- log.Printf("[WebAuth_Guard] SUCCESS: Token validated and IsAdmin is TRUE. Allowing access.")
- }
- if err != nil || !authToken.IsAdmin {
+ if cookie == "" {
+ logger.Debug("[WebAuth] No session cookie found")
ClearAdminSessionCookie(c)
- c.Redirect(http.StatusFound, "/login")
- c.Abort()
+ redirectToLogin(c)
return
}
+
+ cacheKey := hashToken(cookie)
+
+ if cachedToken, found := webAuthCache.get(cacheKey); found {
+ logger.Debug("[WebAuth] Using cached token")
+ c.Set("adminUser", cachedToken)
+ refreshSessionIfNeeded(c, cookie)
+ c.Next()
+ return
+ }
+
+ logger.Debug("[WebAuth] Cache miss, authenticating...")
+ authToken, err := authService.AuthenticateToken(cookie)
+
+ if err != nil {
+ logger.WithError(err).Warn("[WebAuth] Authentication failed")
+ ClearAdminSessionCookie(c)
+ webAuthCache.delete(cacheKey)
+ redirectToLogin(c)
+ return
+ }
+
+ if !authToken.IsAdmin {
+ logger.Warn("[WebAuth] User is not admin")
+ ClearAdminSessionCookie(c)
+ webAuthCache.delete(cacheKey)
+ redirectToLogin(c)
+ return
+ }
+
+ webAuthCache.set(cacheKey, authToken)
+ logger.Debug("[WebAuth] Authentication success, token cached")
+
c.Set("adminUser", authToken)
+ refreshSessionIfNeeded(c, cookie)
c.Next()
}
}
+
+func WebAdminAuthMiddlewareWithLogger(authService *service.SecurityService, logger *logrus.Logger) gin.HandlerFunc {
+ return func(c *gin.Context) {
+ cookie := ExtractTokenFromCookie(c)
+
+ if cookie == "" {
+ logger.Debug("No session cookie found")
+ ClearAdminSessionCookie(c)
+ redirectToLogin(c)
+ return
+ }
+
+ cacheKey := hashToken(cookie)
+ if cachedToken, found := webAuthCache.get(cacheKey); found {
+ c.Set("adminUser", cachedToken)
+ refreshSessionIfNeeded(c, cookie)
+ c.Next()
+ return
+ }
+
+ authToken, err := authService.AuthenticateToken(cookie)
+
+ if err != nil {
+ logger.WithError(err).Warn("Token authentication failed")
+ ClearAdminSessionCookie(c)
+ webAuthCache.delete(cacheKey)
+ redirectToLogin(c)
+ return
+ }
+
+ if !authToken.IsAdmin {
+ logger.Warn("Token valid but user is not admin")
+ ClearAdminSessionCookie(c)
+ webAuthCache.delete(cacheKey)
+ redirectToLogin(c)
+ return
+ }
+
+ webAuthCache.set(cacheKey, authToken)
+ c.Set("adminUser", authToken)
+ refreshSessionIfNeeded(c, cookie)
+ c.Next()
+ }
+}
+
+// ==================== Helpers ====================
+
+func hashToken(token string) string {
+ h := sha256.Sum256([]byte(token))
+ return hex.EncodeToString(h[:])
+}
+
+func redirectToLogin(c *gin.Context) {
+ if isAjaxRequest(c) {
+ c.JSON(http.StatusUnauthorized, gin.H{
+ "error": "Session expired",
+ "code": "AUTH_REQUIRED",
+ })
+ c.Abort()
+ return
+ }
+
+ originalPath := c.Request.URL.Path
+ if originalPath != "/" && originalPath != "/login" {
+ c.Redirect(http.StatusFound, "/login?redirect="+originalPath)
+ } else {
+ c.Redirect(http.StatusFound, "/login")
+ }
+ c.Abort()
+}
+
+func isAjaxRequest(c *gin.Context) bool {
+ // Check the Content-Type
+ contentType := c.GetHeader("Content-Type")
+ if strings.Contains(contentType, "application/json") {
+ return true
+ }
+
+ // Check Accept (JSON requested and HTML not accepted)
+ accept := c.GetHeader("Accept")
+ if strings.Contains(accept, "application/json") &&
+ !strings.Contains(accept, "text/html") {
+ return true
+ }
+
+ // Fall back to the legacy XMLHttpRequest header
+ return c.GetHeader("X-Requested-With") == "XMLHttpRequest"
+}
+
+func refreshSessionIfNeeded(c *gin.Context, token string) {
+ tokenHash := hashToken(token)
+
+ sessionRefreshCache.RLock()
+ lastRefresh, exists := sessionRefreshCache.timestamps[tokenHash]
+ sessionRefreshCache.RUnlock()
+
+ if !exists || time.Since(lastRefresh) > SessionRefreshTime {
+ SetAdminSessionCookie(c, token)
+
+ sessionRefreshCache.Lock()
+ sessionRefreshCache.timestamps[tokenHash] = time.Now()
+ sessionRefreshCache.Unlock()
+ }
+}
+
+func getLogLevel() logrus.Level {
+ level := os.Getenv("LOG_LEVEL")
+ switch strings.ToLower(level) {
+ case "debug":
+ return logrus.DebugLevel
+ case "warn":
+ return logrus.WarnLevel
+ case "error":
+ return logrus.ErrorLevel
+ default:
+ return logrus.InfoLevel
+ }
+}
+
+// ==================== Utilities ====================
+
+func GetAdminUserFromContext(c *gin.Context) (interface{}, bool) {
+ return c.Get("adminUser")
+}
+
+func InvalidateTokenCache(token string) {
+ tokenHash := hashToken(token)
+ webAuthCache.delete(tokenHash)
+
+ // Also drop the refresh timestamp
+ sessionRefreshCache.Lock()
+ delete(sessionRefreshCache.timestamps, tokenHash)
+ sessionRefreshCache.Unlock()
+}
+
+func ClearAllAuthCache() {
+ webAuthCache.mu.Lock()
+ webAuthCache.cache = make(map[string]*authCacheEntry)
+ webAuthCache.mu.Unlock()
+
+ sessionRefreshCache.Lock()
+ sessionRefreshCache.timestamps = make(map[string]time.Time)
+ sessionRefreshCache.Unlock()
+}
+
+// ==================== Debug helpers ====================
+
+type SessionInfo struct {
+ HasCookie bool `json:"has_cookie"`
+ IsValid bool `json:"is_valid"`
+ IsAdmin bool `json:"is_admin"`
+ IsCached bool `json:"is_cached"`
+ LastActivity string `json:"last_activity"`
+}
+
+func GetSessionInfo(c *gin.Context, authService *service.SecurityService) SessionInfo {
+ info := SessionInfo{
+ HasCookie: false,
+ IsValid: false,
+ IsAdmin: false,
+ IsCached: false,
+ LastActivity: time.Now().Format(time.RFC3339),
+ }
+
+ cookie := ExtractTokenFromCookie(c)
+ if cookie == "" {
+ return info
+ }
+
+ info.HasCookie = true
+
+ cacheKey := hashToken(cookie)
+ if _, found := webAuthCache.get(cacheKey); found {
+ info.IsCached = true
+ }
+
+ authToken, err := authService.AuthenticateToken(cookie)
+ if err != nil {
+ return info
+ }
+
+ info.IsValid = true
+ info.IsAdmin = authToken.IsAdmin
+
+ return info
+}
+
+func GetCacheStats() map[string]interface{} {
+ webAuthCache.mu.RLock()
+ cacheSize := len(webAuthCache.cache)
+ webAuthCache.mu.RUnlock()
+
+ sessionRefreshCache.RLock()
+ refreshSize := len(sessionRefreshCache.timestamps)
+ sessionRefreshCache.RUnlock()
+ return map[string]interface{}{
+ "auth_cache_entries": cacheSize,
+ "refresh_cache_entries": refreshSize,
+ "ttl_seconds": int(webAuthCache.ttl.Seconds()),
+ "cleanup_interval": int(CleanupInterval.Seconds()),
+ "session_refresh_time": int(SessionRefreshTime.Seconds()),
+ }
+}
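InvalidateTokenCache and ClearAllAuthCache are not called anywhere in this diff; the natural call site is the logout flow. A hedged sketch of how the existing WebAuthHandler.HandleLogout might use them (the handler body shown here is illustrative, not the project's actual code):

    // Illustrative only; the real WebAuthHandler lives outside this diff.
    func (h *WebAuthHandler) HandleLogout(c *gin.Context) {
        if token := middleware.ExtractTokenFromCookie(c); token != "" {
            // Drop the cached auth entry and its refresh timestamp for this session.
            middleware.InvalidateTokenCache(token)
        }
        middleware.ClearAdminSessionCookie(c)
        c.Redirect(http.StatusFound, "/login")
    }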
diff --git a/internal/models/runtime.go b/internal/models/runtime.go
index f09beeb..ac6bcdc 100644
--- a/internal/models/runtime.go
+++ b/internal/models/runtime.go
@@ -27,6 +27,8 @@ type SystemSettings struct {
BaseKeyCheckEndpoint string `json:"base_key_check_endpoint" default:"https://generativelanguage.googleapis.com/v1beta/models" name:"全局Key检查端点" category:"健康检查" desc:"用于全局Key身份检查的目标URL。"`
BaseKeyCheckModel string `json:"base_key_check_model" default:"gemini-2.0-flash-lite" name:"默认Key检查模型" category:"健康检查" desc:"用于分组健康检查和手动密钥测试时的默认回退模型。"`
+ KeyCheckSchedulerIntervalSeconds int `json:"key_check_scheduler_interval_seconds" default:"60" name:"Key检查调度器间隔(秒)" category:"健康检查" desc:"动态调度器检查各组是否需要执行健康检查的周期。"`
+
EnableUpstreamCheck bool `json:"enable_upstream_check" default:"true" name:"启用上游检查" category:"健康检查" desc:"是否启用对上游服务(Upstream)的健康检查。"`
UpstreamCheckTimeoutSeconds int `json:"upstream_check_timeout_seconds" default:"20" name:"上游检查超时(秒)" category:"健康检查" desc:"对单个上游服务进行健康检查时的网络超时时间。"`
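The new KeyCheckSchedulerIntervalSeconds setting implies a periodic scheduler loop elsewhere in the codebase; a hedged sketch of how such a loop could consume it (only the setting name and its 60-second default come from the patch, the rest is illustrative):

    func runKeyCheckScheduler(ctx context.Context, s *models.SystemSettings, checkDueGroups func()) {
        interval := time.Duration(s.KeyCheckSchedulerIntervalSeconds) * time.Second
        if interval <= 0 {
            interval = 60 * time.Second // default from the struct tag
        }
        ticker := time.NewTicker(interval)
        defer ticker.Stop()

        for {
            select {
            case <-ctx.Done():
                return
            case <-ticker.C:
                checkDueGroups()
            }
        }
    }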
diff --git a/internal/pongo/renderer.go b/internal/pongo/renderer.go
index f00257c..14a705a 100644
--- a/internal/pongo/renderer.go
+++ b/internal/pongo/renderer.go
@@ -1,63 +1,96 @@
-// Filename: internal/pongo/renderer.go
-
package pongo
import (
"fmt"
"net/http"
+ "sync"
"github.com/flosch/pongo2/v6"
"github.com/gin-gonic/gin"
"github.com/gin-gonic/gin/render"
+ "github.com/sirupsen/logrus"
)
type Renderer struct {
- Context pongo2.Context
- tplSet *pongo2.TemplateSet
+ mu sync.RWMutex
+ globalContext pongo2.Context
+ tplSet *pongo2.TemplateSet
+ logger *logrus.Logger
}
-func New(directory string, isDebug bool) *Renderer {
+func New(directory string, isDebug bool, logger *logrus.Logger) *Renderer {
loader := pongo2.MustNewLocalFileSystemLoader(directory)
tplSet := pongo2.NewSet("gin-pongo-templates", loader)
tplSet.Debug = isDebug
- return &Renderer{Context: make(pongo2.Context), tplSet: tplSet}
+ return &Renderer{
+ globalContext: make(pongo2.Context),
+ tplSet: tplSet,
+ logger: logger,
+ }
}
-// Instance returns a new render.HTML instance for a single request.
-func (p *Renderer) Instance(name string, data interface{}) render.Render {
- var glob pongo2.Context
- if p.Context != nil {
- glob = p.Context
- }
+// SetGlobalContext sets a global template variable in a thread-safe way
+func (p *Renderer) SetGlobalContext(key string, value interface{}) {
+ p.mu.Lock()
+ defer p.mu.Unlock()
+ p.globalContext[key] = value
+}
+// Warmup preloads the given templates into the cache
+func (p *Renderer) Warmup(templateNames ...string) error {
+ for _, name := range templateNames {
+ if _, err := p.tplSet.FromCache(name); err != nil {
+ return fmt.Errorf("failed to warmup template '%s': %w", name, err)
+ }
+ }
+ p.logger.WithField("count", len(templateNames)).Info("Templates warmed up")
+ return nil
+}
+
+func (p *Renderer) Instance(name string, data interface{}) render.Render {
+ // Copy the global context under the read lock
+ p.mu.RLock()
+ glob := make(pongo2.Context, len(p.globalContext))
+ for k, v := range p.globalContext {
+ glob[k] = v
+ }
+ p.mu.RUnlock()
+
+ // Normalize the per-request data
var context pongo2.Context
if data != nil {
- if ginContext, ok := data.(gin.H); ok {
- context = pongo2.Context(ginContext)
- } else if pongoContext, ok := data.(pongo2.Context); ok {
- context = pongoContext
- } else if m, ok := data.(map[string]interface{}); ok {
- context = m
- } else {
+ switch v := data.(type) {
+ case gin.H:
+ context = pongo2.Context(v)
+ case pongo2.Context:
+ context = v
+ case map[string]interface{}:
+ context = v
+ default:
context = make(pongo2.Context)
}
} else {
context = make(pongo2.Context)
}
+ // Merge the contexts (request data wins)
for k, v := range glob {
- if _, ok := context[k]; !ok {
+ if _, exists := context[k]; !exists {
context[k] = v
}
}
+ // Load the template
tpl, err := p.tplSet.FromCache(name)
if err != nil {
- panic(fmt.Sprintf("Failed to load template '%s': %v", name, err))
+ p.logger.WithError(err).WithField("template", name).Error("Failed to load template")
+ return &ErrorHTML{
+ StatusCode: http.StatusInternalServerError,
+ Error: fmt.Errorf("template load error: %s", name),
+ }
}
return &HTML{
- p: p,
Template: tpl,
Name: name,
Data: context,
@@ -65,7 +98,6 @@ func (p *Renderer) Instance(name string, data interface{}) render.Render {
}
type HTML struct {
- p *Renderer
Template *pongo2.Template
Name string
Data pongo2.Context
@@ -82,15 +114,31 @@ func (h *HTML) Render(w http.ResponseWriter) error {
}
func (h *HTML) WriteContentType(w http.ResponseWriter) {
- header := w.Header()
- if val := header["Content-Type"]; len(val) == 0 {
- header["Content-Type"] = []string{"text/html; charset=utf-8"}
+ if w.Header().Get("Content-Type") == "" {
+ w.Header().Set("Content-Type", "text/html; charset=utf-8")
}
}
+// ErrorHTML renders a plain-text error instead of panicking on template failures
+type ErrorHTML struct {
+ StatusCode int
+ Error error
+}
+
+func (e *ErrorHTML) Render(w http.ResponseWriter) error {
+ w.Header().Set("Content-Type", "text/plain; charset=utf-8")
+ w.WriteHeader(e.StatusCode)
+ _, err := w.Write([]byte(e.Error.Error()))
+ return err
+}
+
+func (e *ErrorHTML) WriteContentType(w http.ResponseWriter) {
+ w.Header().Set("Content-Type", "text/plain; charset=utf-8")
+}
+
+// C returns the pongo2 context stored on the gin context, or a new one
func C(ctx *gin.Context) pongo2.Context {
- p, exists := ctx.Get("pongo2")
- if exists {
+ if p, exists := ctx.Get("pongo2"); exists {
if pCtx, ok := p.(pongo2.Context); ok {
return pCtx
}
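Warmup is not called anywhere in this diff; a sketch of possible startup wiring (template names are examples, not taken from web/templates):

    func setupTemplates(router *gin.Engine, logger *logrus.Logger) {
        isDebug := gin.Mode() != gin.ReleaseMode
        renderer := pongo.New("web/templates", isDebug, logger)

        // Warm the cache for the most-visited pages; on failure templates compile lazily.
        if err := renderer.Warmup("dashboard.html", "keys.html", "login.html"); err != nil {
            logger.WithError(err).Warn("Template warmup failed")
        }
        router.HTMLRender = renderer
    }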
diff --git a/internal/router/router.go b/internal/router/router.go
index 4724c1b..6c6d324 100644
--- a/internal/router/router.go
+++ b/internal/router/router.go
@@ -17,6 +17,7 @@ import (
"github.com/gin-contrib/cors"
"github.com/gin-gonic/gin"
+ "github.com/sirupsen/logrus"
)
func NewRouter(
@@ -42,70 +43,214 @@ func NewRouter(
upstreamModule *upstream.Module,
proxyModule *proxy.Module,
) *gin.Engine {
+ // === 1. Create the shared router logger ===
+ // Note: this logger is built separately from logging.NewLogger, so the file
+ // and rotation settings do not apply to it.
+ logger := createLogger(cfg)
+
+ // === 2. Set the Gin run mode ===
if cfg.Log.Level != "debug" {
gin.SetMode(gin.ReleaseMode)
}
- router := gin.Default()
- router.Static("/static", "./web/static")
- // CORS 配置
- config := cors.Config{
- // 允许前端的来源。在生产环境中,需改为实际域名
- AllowOrigins: []string{"http://localhost:9000"},
- AllowMethods: []string{"GET", "POST", "PUT", "PATCH", "DELETE", "OPTIONS"},
- AllowHeaders: []string{"Origin", "Content-Type", "Accept", "Authorization"},
- ExposeHeaders: []string{"Content-Length"},
- AllowCredentials: true,
- MaxAge: 12 * time.Hour,
- }
- router.Use(cors.New(config))
- isDebug := gin.Mode() != gin.ReleaseMode
- router.HTMLRender = pongo.New("web/templates", isDebug)
+ // === 3. Create the router (gin.New() for full control over the middleware chain) ===
+ router := gin.New()
- // --- 基础设施 ---
- router.GET("/", func(c *gin.Context) { c.Redirect(http.StatusMovedPermanently, "/dashboard") })
- router.GET("/health", func(c *gin.Context) { c.JSON(http.StatusOK, gin.H{"status": "ok"}) })
- // --- 统一的认证管道 ---
- apiAdminAuth := middleware.APIAdminAuthMiddleware(securityService)
+ // === 4. Register global middleware (in execution order) ===
+ setupGlobalMiddleware(router, logger)
+
+ // === 5. Configure static files and templates ===
+ setupStaticAndTemplates(router, logger)
+
+ // === 6. Configure CORS ===
+ setupCORS(router, cfg)
+
+ // === 7. Register basic routes ===
+ setupBasicRoutes(router)
+
+ // === 8. Create auth middleware ===
+ apiAdminAuth := middleware.APIAdminAuthMiddleware(securityService, logger)
webAdminAuth := middleware.WebAdminAuthMiddleware(securityService)
- router.Use(gin.RecoveryWithWriter(os.Stdout))
- // --- 将正确的依赖和中间件管道传递下去 ---
- registerProxyRoutes(router, proxyHandler, securityService)
- registerAdminRoutes(router, apiAdminAuth, keyGroupHandler, tokensHandler, apiKeyHandler, logHandler, settingHandler, dashboardHandler, taskHandler, upstreamModule, proxyModule)
- registerPublicAPIRoutes(router, apiAuthHandler, securityService, settingsManager)
+ // === 9. Register business routes (grouped by feature) ===
+ registerPublicAPIRoutes(router, apiAuthHandler, securityService, settingsManager, logger)
registerWebRoutes(router, webAdminAuth, webAuthHandler, pageHandler)
+ registerAdminRoutes(router, apiAdminAuth, keyGroupHandler, tokensHandler, apiKeyHandler,
+ logHandler, settingHandler, dashboardHandler, taskHandler, upstreamModule, proxyModule)
+ registerProxyRoutes(router, proxyHandler, securityService, logger)
+
return router
}
+// ==================== Helpers ====================
+
+// createLogger builds and configures the router's logger
+func createLogger(cfg *config.Config) *logrus.Logger {
+ logger := logrus.New()
+
+ // Set the log format
+ if cfg.Log.Format == "json" {
+ logger.SetFormatter(&logrus.JSONFormatter{
+ TimestampFormat: time.RFC3339,
+ })
+ } else {
+ logger.SetFormatter(&logrus.TextFormatter{
+ FullTimestamp: true,
+ TimestampFormat: "2006-01-02 15:04:05",
+ })
+ }
+
+ // Set the log level
+ switch cfg.Log.Level {
+ case "debug":
+ logger.SetLevel(logrus.DebugLevel)
+ case "info":
+ logger.SetLevel(logrus.InfoLevel)
+ case "warn":
+ logger.SetLevel(logrus.WarnLevel)
+ case "error":
+ logger.SetLevel(logrus.ErrorLevel)
+ default:
+ logger.SetLevel(logrus.InfoLevel)
+ }
+
+ // Set the output (optionally a file)
+ logger.SetOutput(os.Stdout)
+
+ return logger
+}
+
+// setupGlobalMiddleware registers the global middleware chain
+func setupGlobalMiddleware(router *gin.Engine, logger *logrus.Logger) {
+ // 1. Request ID middleware (for tracing)
+ router.Use(middleware.RequestIDMiddleware())
+
+ // 2. Redaction middleware (must run before logging)
+ router.Use(middleware.RedactionMiddleware())
+
+ // 3. Request logging
+ router.Use(middleware.LogrusLogger(logger))
+
+ // 4. Panic recovery
+ router.Use(gin.RecoveryWithWriter(os.Stdout))
+}
+
+// setupStaticAndTemplates configures static assets and the HTML renderer
+func setupStaticAndTemplates(router *gin.Engine, logger *logrus.Logger) {
+ router.Static("/static", "./web/static")
+
+ isDebug := gin.Mode() != gin.ReleaseMode
+ router.HTMLRender = pongo.New("web/templates", isDebug, logger)
+}
+
+// setupCORS configures CORS
+func setupCORS(router *gin.Engine, cfg *config.Config) {
+ corsConfig := cors.Config{
+ AllowOrigins: getCORSOrigins(cfg),
+ AllowMethods: []string{"GET", "POST", "PUT", "PATCH", "DELETE", "OPTIONS"},
+ AllowHeaders: []string{"Origin", "Content-Type", "Accept", "Authorization", "X-Request-Id"},
+ ExposeHeaders: []string{"Content-Length", "X-Request-Id"},
+ AllowCredentials: true,
+ MaxAge: 12 * time.Hour,
+ }
+ router.Use(cors.New(corsConfig))
+}
+
+// getCORSOrigins returns the allowed CORS origins
+func getCORSOrigins(cfg *config.Config) []string {
+ // Default value
+ origins := []string{"http://localhost:9000"}
+
+ // Read from config (len() of a nil slice is 0, so no explicit nil check is needed)
+ if len(cfg.Server.CORSOrigins) > 0 {
+ origins = cfg.Server.CORSOrigins
+ }
+
+ return origins
+}
+
+// setupBasicRoutes registers the basic routes
+func setupBasicRoutes(router *gin.Engine) {
+ // Root redirect
+ router.GET("/", func(c *gin.Context) {
+ c.Redirect(http.StatusMovedPermanently, "/dashboard")
+ })
+
+ // Health check
+ router.GET("/health", handleHealthCheck)
+
+ // Version info (optional)
+ router.GET("/version", handleVersion)
+}
+
+// handleHealthCheck serves the health check endpoint
+func handleHealthCheck(c *gin.Context) {
+ c.JSON(http.StatusOK, gin.H{
+ "status": "ok",
+ "time": time.Now().Unix(),
+ })
+}
+
+// handleVersion reports build/version info
+func handleVersion(c *gin.Context) {
+ c.JSON(http.StatusOK, gin.H{
+ "version": "1.0.0", // could come from config or a build-time variable
+ "build": "latest",
+ })
+}
+
+// ==================== Route registration ====================
+
+// registerProxyRoutes registers the proxy routes
func registerProxyRoutes(
- router *gin.Engine, proxyHandler *handlers.ProxyHandler, securityService *service.SecurityService,
+ router *gin.Engine,
+ proxyHandler *handlers.ProxyHandler,
+ securityService *service.SecurityService,
+ logger *logrus.Logger,
) {
- // 通用的代理认证中间件
- proxyAuthMiddleware := middleware.ProxyAuthMiddleware(securityService)
- // --- 模式一: 智能聚合模式 (根路径) ---
- // /v1 和 /v1beta 路径作为默认入口,服务于 BasePool 聚合逻辑
+ // Create the proxy auth middleware
+ proxyAuthMiddleware := middleware.ProxyAuthMiddleware(securityService, logger)
+
+ // Mode 1: aggregated routing (default entry points)
+ registerAggregateProxyRoutes(router, proxyHandler, proxyAuthMiddleware)
+
+ // Mode 2: exact routing by group name
+ registerGroupProxyRoutes(router, proxyHandler, proxyAuthMiddleware)
+}
+
+// registerAggregateProxyRoutes registers the aggregated proxy routes (default /v1 and /v1beta entry points)
+func registerAggregateProxyRoutes(
+ router *gin.Engine,
+ proxyHandler *handlers.ProxyHandler,
+ authMiddleware gin.HandlerFunc,
+) {
+ // /v1 route group
v1 := router.Group("/v1")
- v1.Use(proxyAuthMiddleware)
+ v1.Use(authMiddleware)
{
v1.Any("/*path", proxyHandler.HandleProxy)
}
+
+ // /v1beta route group
v1beta := router.Group("/v1beta")
- v1beta.Use(proxyAuthMiddleware)
+ v1beta.Use(authMiddleware)
{
v1beta.Any("/*path", proxyHandler.HandleProxy)
}
- // --- 模式二: 精确路由模式 (/proxy/:group_name) ---
- // 创建一个新的、物理隔离的路由组,用于按组名精确路由
+}
+
+// registerGroupProxyRoutes registers the per-group proxy routes (/proxy/:group_name)
+func registerGroupProxyRoutes(
+ router *gin.Engine,
+ proxyHandler *handlers.ProxyHandler,
+ authMiddleware gin.HandlerFunc,
+) {
proxyGroup := router.Group("/proxy/:group_name")
- proxyGroup.Use(proxyAuthMiddleware)
+ proxyGroup.Use(authMiddleware)
{
- // 捕获所有子路径 (例如 /v1/chat/completions),并全部交给同一个 ProxyHandler。
proxyGroup.Any("/*path", proxyHandler.HandleProxy)
}
}
-// registerAdminRoutes
+// registerAdminRoutes registers the admin API routes
func registerAdminRoutes(
router *gin.Engine,
authMiddleware gin.HandlerFunc,
@@ -121,74 +266,112 @@ func registerAdminRoutes(
) {
admin := router.Group("/admin", authMiddleware)
{
- // --- KeyGroup Base Routes ---
- admin.POST("/keygroups", keyGroupHandler.CreateKeyGroup)
- admin.GET("/keygroups", keyGroupHandler.GetKeyGroups)
- admin.PUT("/keygroups/order", keyGroupHandler.UpdateKeyGroupOrder)
- // --- KeyGroup Specific Routes (by :id) ---
- admin.GET("/keygroups/:id", keyGroupHandler.GetKeyGroups)
- admin.PUT("/keygroups/:id", keyGroupHandler.UpdateKeyGroup)
- admin.DELETE("/keygroups/:id", keyGroupHandler.DeleteKeyGroup)
- admin.POST("/keygroups/:id/clone", keyGroupHandler.CloneKeyGroup)
- admin.GET("/keygroups/:id/stats", keyGroupHandler.GetKeyGroupStats)
- admin.POST("/keygroups/:id/bulk-actions", apiKeyHandler.HandleBulkAction)
- // --- APIKey Sub-resource Routes under a KeyGroup ---
- keyGroupAPIKeys := admin.Group("/keygroups/:id/apikeys")
- {
- keyGroupAPIKeys.GET("", apiKeyHandler.ListKeysForGroup)
- keyGroupAPIKeys.GET("/export", apiKeyHandler.ExportKeysForGroup)
- keyGroupAPIKeys.POST("/bulk", apiKeyHandler.AddMultipleKeysToGroup)
- keyGroupAPIKeys.DELETE("/bulk", apiKeyHandler.UnlinkMultipleKeysFromGroup)
- keyGroupAPIKeys.POST("/test", apiKeyHandler.TestKeysForGroup)
- keyGroupAPIKeys.PUT("/:keyId", apiKeyHandler.UpdateGroupAPIKeyMapping)
- }
+ // KeyGroup routes
+ registerKeyGroupRoutes(admin, keyGroupHandler, apiKeyHandler)
- // Global key operations
- admin.GET("/apikeys", apiKeyHandler.ListAPIKeys)
- // admin.PUT("/apikeys/:id", apiKeyHandler.UpdateAPIKey) // DEPRECATED: Status is now contextual
- admin.POST("/apikeys/test", apiKeyHandler.TestMultipleKeys) // Test keys globally
- admin.DELETE("/apikeys/:id", apiKeyHandler.HardDeleteAPIKey) // Hard delete a single key
- admin.DELETE("/apikeys/bulk", apiKeyHandler.HardDeleteMultipleKeys) // Hard delete multiple keys
- admin.PUT("/apikeys/bulk/restore", apiKeyHandler.RestoreMultipleKeys) // Restore multiple keys globally
+ // Global APIKey routes
+ registerAPIKeyRoutes(admin, apiKeyHandler)
- // --- Global Routes ---
- admin.GET("/tokens", tokensHandler.GetAllTokens)
- admin.PUT("/tokens", tokensHandler.UpdateTokens)
- admin.GET("/logs", logHandler.GetLogs)
- admin.GET("/settings", settingHandler.GetSettings)
- admin.PUT("/settings", settingHandler.UpdateSettings)
- admin.PUT("/settings/reset", settingHandler.ResetSettingsToDefaults)
+ // System management routes
+ registerSystemRoutes(admin, tokensHandler, logHandler, settingHandler, taskHandler)
- // 用于查询异步任务的状态
- admin.GET("/tasks/:id", taskHandler.GetTaskStatus)
+ // Dashboard routes
+ registerDashboardRoutes(admin, dashboardHandler)
- // 领域模块
+ // Domain module routes
upstreamModule.RegisterRoutes(admin)
proxyModule.RegisterRoutes(admin)
- // --- 全局仪表盘路由 ---
- dashboard := admin.Group("/dashboard")
- {
- dashboard.GET("/overview", dashboardHandler.GetOverview)
- dashboard.GET("/chart", dashboardHandler.GetChart)
- dashboard.GET("/stats/:period", dashboardHandler.GetRequestStats) // 点击详情
- }
}
}
-// registerWebRoutes
+// registerKeyGroupRoutes registers the KeyGroup routes
+func registerKeyGroupRoutes(
+ admin *gin.RouterGroup,
+ keyGroupHandler *handlers.KeyGroupHandler,
+ apiKeyHandler *handlers.APIKeyHandler,
+) {
+ // Collection-level routes
+ admin.POST("/keygroups", keyGroupHandler.CreateKeyGroup)
+ admin.GET("/keygroups", keyGroupHandler.GetKeyGroups)
+ admin.PUT("/keygroups/order", keyGroupHandler.UpdateKeyGroupOrder)
+
+ // Routes for a specific KeyGroup
+ admin.GET("/keygroups/:id", keyGroupHandler.GetKeyGroups)
+ admin.PUT("/keygroups/:id", keyGroupHandler.UpdateKeyGroup)
+ admin.DELETE("/keygroups/:id", keyGroupHandler.DeleteKeyGroup)
+ admin.POST("/keygroups/:id/clone", keyGroupHandler.CloneKeyGroup)
+ admin.GET("/keygroups/:id/stats", keyGroupHandler.GetKeyGroupStats)
+ admin.POST("/keygroups/:id/bulk-actions", apiKeyHandler.HandleBulkAction)
+
+ // APIKey sub-resources of a KeyGroup
+ keyGroupAPIKeys := admin.Group("/keygroups/:id/apikeys")
+ {
+ keyGroupAPIKeys.GET("", apiKeyHandler.ListKeysForGroup)
+ keyGroupAPIKeys.GET("/export", apiKeyHandler.ExportKeysForGroup)
+ keyGroupAPIKeys.POST("/bulk", apiKeyHandler.AddMultipleKeysToGroup)
+ keyGroupAPIKeys.DELETE("/bulk", apiKeyHandler.UnlinkMultipleKeysFromGroup)
+ keyGroupAPIKeys.POST("/test", apiKeyHandler.TestKeysForGroup)
+ keyGroupAPIKeys.PUT("/:keyId", apiKeyHandler.UpdateGroupAPIKeyMapping)
+ }
+}
+
+// registerAPIKeyRoutes registers global APIKey routes
+func registerAPIKeyRoutes(admin *gin.RouterGroup, apiKeyHandler *handlers.APIKeyHandler) {
+ admin.GET("/apikeys", apiKeyHandler.ListAPIKeys)
+ admin.POST("/apikeys/test", apiKeyHandler.TestMultipleKeys)
+ admin.DELETE("/apikeys/:id", apiKeyHandler.HardDeleteAPIKey)
+ admin.DELETE("/apikeys/bulk", apiKeyHandler.HardDeleteMultipleKeys)
+ admin.PUT("/apikeys/bulk/restore", apiKeyHandler.RestoreMultipleKeys)
+}
+
+// registerSystemRoutes registers system management routes
+func registerSystemRoutes(
+ admin *gin.RouterGroup,
+ tokensHandler *handlers.TokensHandler,
+ logHandler *handlers.LogHandler,
+ settingHandler *handlers.SettingHandler,
+ taskHandler *handlers.TaskHandler,
+) {
+ // Token management
+ admin.GET("/tokens", tokensHandler.GetAllTokens)
+ admin.PUT("/tokens", tokensHandler.UpdateTokens)
+
+ // Log management
+ admin.GET("/logs", logHandler.GetLogs)
+
+ // Settings management
+ admin.GET("/settings", settingHandler.GetSettings)
+ admin.PUT("/settings", settingHandler.UpdateSettings)
+ admin.PUT("/settings/reset", settingHandler.ResetSettingsToDefaults)
+
+ // Task management (async task status lookup)
+ admin.GET("/tasks/:id", taskHandler.GetTaskStatus)
+}
+
+// registerDashboardRoutes registers dashboard routes
+func registerDashboardRoutes(admin *gin.RouterGroup, dashboardHandler *handlers.DashboardHandler) {
+ dashboard := admin.Group("/dashboard")
+ {
+ dashboard.GET("/overview", dashboardHandler.GetOverview)
+ dashboard.GET("/chart", dashboardHandler.GetChart)
+ dashboard.GET("/stats/:period", dashboardHandler.GetRequestStats)
+ }
+}
+
+// registerWebRoutes registers web page routes
func registerWebRoutes(
router *gin.Engine,
authMiddleware gin.HandlerFunc,
webAuthHandler *webhandlers.WebAuthHandler,
pageHandler *webhandlers.PageHandler,
) {
+ // Public authentication routes
router.GET("/login", webAuthHandler.ShowLoginPage)
router.POST("/login", webAuthHandler.HandleLogin)
router.GET("/logout", webAuthHandler.HandleLogout)
- // For Test only router.Run("127.0.0.1:9000")
- // 受保护的Admin Web界面
+
+ // Protected admin UI
webGroup := router.Group("/", authMiddleware)
- webGroup.Use(authMiddleware)
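+ // The auth middleware is attached once via Group(); a second Use() call would have run it twice per request.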
{
webGroup.GET("/keys", pageHandler.ShowKeysPage)
webGroup.GET("/settings", pageHandler.ShowConfigEditorPage)
@@ -197,14 +380,31 @@ func registerWebRoutes(
webGroup.GET("/tasks", pageHandler.ShowTasksPage)
webGroup.GET("/chat", pageHandler.ShowChatPage)
}
-
}
-// registerPublicAPIRoutes 无需后台登录的公共API路由
-func registerPublicAPIRoutes(router *gin.Engine, apiAuthHandler *handlers.APIAuthHandler, securityService *service.SecurityService, settingsManager *settings.SettingsManager) {
- ipBanMiddleware := middleware.IPBanMiddleware(securityService, settingsManager)
- publicAPIGroup := router.Group("/api")
+// registerPublicAPIRoutes registers public API routes
+func registerPublicAPIRoutes(
+ router *gin.Engine,
+ apiAuthHandler *handlers.APIAuthHandler,
+ securityService *service.SecurityService,
+ settingsManager *settings.SettingsManager,
+ logger *logrus.Logger,
+) {
+ // Build the IP ban middleware
+ ipBanCache := middleware.NewIPBanCache()
+ ipBanMiddleware := middleware.IPBanMiddleware(
+ securityService,
+ settingsManager,
+ ipBanCache,
+ logger,
+ )
+
+ // Public API route group
+ publicAPI := router.Group("/api")
{
- publicAPIGroup.POST("/login", ipBanMiddleware, apiAuthHandler.HandleLogin)
+ publicAPI.POST("/login", ipBanMiddleware, apiAuthHandler.HandleLogin)
+ // Additional public API routes can be added here, e.g.:
+ // publicAPI.POST("/register", ipBanMiddleware, apiAuthHandler.HandleRegister)
+ // publicAPI.POST("/forgot-password", ipBanMiddleware, apiAuthHandler.HandleForgotPassword)
}
}
diff --git a/internal/service/analytics_service.go b/internal/service/analytics_service.go
index 3c0cad9..e2f8d8f 100644
--- a/internal/service/analytics_service.go
+++ b/internal/service/analytics_service.go
@@ -5,93 +5,179 @@ import (
"context"
"encoding/json"
"fmt"
- "gemini-balancer/internal/db/dialect"
- "gemini-balancer/internal/models"
- "gemini-balancer/internal/store"
"strconv"
"strings"
"sync"
+ "sync/atomic"
"time"
+ "gemini-balancer/internal/db/dialect"
+ "gemini-balancer/internal/models"
+ "gemini-balancer/internal/settings"
+ "gemini-balancer/internal/store"
+
"github.com/sirupsen/logrus"
"gorm.io/gorm"
)
const (
- flushLoopInterval = 1 * time.Minute
+ defaultFlushInterval = 1 * time.Minute
+ maxRetryAttempts = 3
+ retryDelay = 5 * time.Second
)
-type AnalyticsServiceLogger struct{ *logrus.Entry }
-
type AnalyticsService struct {
- db *gorm.DB
- store store.Store
- logger *logrus.Entry
+ db *gorm.DB
+ store store.Store
+ logger *logrus.Entry
+ dialect dialect.DialectAdapter
+ settingsManager *settings.SettingsManager
+
stopChan chan struct{}
wg sync.WaitGroup
- dialect dialect.DialectAdapter
+ ctx context.Context
+ cancel context.CancelFunc
+
+ // Metrics
+ eventsReceived atomic.Uint64
+ eventsProcessed atomic.Uint64
+ eventsFailed atomic.Uint64
+ flushCount atomic.Uint64
+ recordsFlushed atomic.Uint64
+ flushErrors atomic.Uint64
+ lastFlushTime time.Time
+ lastFlushMutex sync.RWMutex
+
+ // Runtime configuration
+ flushInterval time.Duration
+ configMutex sync.RWMutex
}
-func NewAnalyticsService(db *gorm.DB, s store.Store, logger *logrus.Logger, d dialect.DialectAdapter) *AnalyticsService {
+func NewAnalyticsService(
+ db *gorm.DB,
+ s store.Store,
+ logger *logrus.Logger,
+ d dialect.DialectAdapter,
+ settingsManager *settings.SettingsManager,
+) *AnalyticsService {
+ ctx, cancel := context.WithCancel(context.Background())
+
return &AnalyticsService{
- db: db,
- store: s,
- logger: logger.WithField("component", "Analytics📊"),
- stopChan: make(chan struct{}),
- dialect: d,
+ db: db,
+ store: s,
+ logger: logger.WithField("component", "Analytics📊"),
+ dialect: d,
+ settingsManager: settingsManager,
+ stopChan: make(chan struct{}),
+ ctx: ctx,
+ cancel: cancel,
+ flushInterval: defaultFlushInterval,
+ lastFlushTime: time.Now(),
}
}
func (s *AnalyticsService) Start() {
- s.wg.Add(2)
- go s.flushLoop()
+ s.wg.Add(3)
go s.eventListener()
- s.logger.Info("AnalyticsService (Command Side) started.")
+ go s.flushLoop()
+ go s.metricsReporter()
+
+ s.logger.WithFields(logrus.Fields{
+ "flush_interval": s.flushInterval,
+ }).Info("AnalyticsService started")
}
func (s *AnalyticsService) Stop() {
+ s.logger.Info("AnalyticsService stopping...")
close(s.stopChan)
+ s.cancel()
s.wg.Wait()
- s.logger.Info("AnalyticsService stopped. Performing final data flush...")
+
+ s.logger.Info("Performing final data flush...")
s.flushToDB()
- s.logger.Info("AnalyticsService final data flush completed.")
+
+ // Log final statistics
+ s.logger.WithFields(logrus.Fields{
+ "events_received": s.eventsReceived.Load(),
+ "events_processed": s.eventsProcessed.Load(),
+ "events_failed": s.eventsFailed.Load(),
+ "flush_count": s.flushCount.Load(),
+ "records_flushed": s.recordsFlushed.Load(),
+ "flush_errors": s.flushErrors.Load(),
+ }).Info("AnalyticsService stopped")
}
+// Event listening loop
func (s *AnalyticsService) eventListener() {
defer s.wg.Done()
- sub, err := s.store.Subscribe(context.Background(), models.TopicRequestFinished)
+
+ sub, err := s.store.Subscribe(s.ctx, models.TopicRequestFinished)
if err != nil {
- s.logger.Fatalf("Failed to subscribe to topic %s: %v", models.TopicRequestFinished, err)
+ s.logger.WithError(err).Error("Failed to subscribe to request events, analytics disabled")
return
}
- defer sub.Close()
- s.logger.Info("AnalyticsService subscribed to request events.")
+ defer func() {
+ if err := sub.Close(); err != nil {
+ s.logger.WithError(err).Warn("Failed to close subscription")
+ }
+ }()
+
+ s.logger.Info("Subscribed to request events for analytics")
for {
select {
case msg := <-sub.Channel():
- var event models.RequestFinishedEvent
- if err := json.Unmarshal(msg.Payload, &event); err != nil {
- s.logger.Errorf("Failed to unmarshal analytics event: %v", err)
- continue
- }
- s.handleAnalyticsEvent(&event)
+ s.handleMessage(msg)
+
case <-s.stopChan:
- s.logger.Info("AnalyticsService stopping event listener.")
+ s.logger.Info("Event listener stopping")
+ return
+
+ case <-s.ctx.Done():
+ s.logger.Info("Event listener context cancelled")
return
}
}
}
-func (s *AnalyticsService) handleAnalyticsEvent(event *models.RequestFinishedEvent) {
- if event.RequestLog.GroupID == nil {
+// Handle a single message
+func (s *AnalyticsService) handleMessage(msg *store.Message) {
+ var event models.RequestFinishedEvent
+ if err := json.Unmarshal(msg.Payload, &event); err != nil {
+ s.logger.WithError(err).Error("Failed to unmarshal analytics event")
+ s.eventsFailed.Add(1)
return
}
- ctx := context.Background()
- key := fmt.Sprintf("analytics:hourly:%s", time.Now().UTC().Format("2006-01-02T15"))
+
+ s.eventsReceived.Add(1)
+
+ if err := s.handleAnalyticsEvent(&event); err != nil {
+ s.eventsFailed.Add(1)
+ s.logger.WithFields(logrus.Fields{
+ "correlation_id": event.CorrelationID,
+ "group_id": event.RequestLog.GroupID,
+ }).WithError(err).Warn("Failed to process analytics event")
+ } else {
+ s.eventsProcessed.Add(1)
+ }
+}
+
+// Handle an analytics event
+func (s *AnalyticsService) handleAnalyticsEvent(event *models.RequestFinishedEvent) error {
+ if event.RequestLog.GroupID == nil {
+ return nil // skip events without a GroupID
+ }
+
+ ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
+ defer cancel()
+
+ now := time.Now().UTC()
+ key := fmt.Sprintf("analytics:hourly:%s", now.Format("2006-01-02T15"))
fieldPrefix := fmt.Sprintf("%d:%s", *event.RequestLog.GroupID, event.RequestLog.ModelName)
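+ // Hash fields are laid out as "<groupID>:<modelName>:<counter>"; parseStatsFromHash depends on this three-part format.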
+
pipe := s.store.Pipeline(ctx)
pipe.HIncrBy(key, fieldPrefix+":requests", 1)
+
if event.RequestLog.IsSuccess {
pipe.HIncrBy(key, fieldPrefix+":success", 1)
}
@@ -101,80 +187,213 @@ func (s *AnalyticsService) handleAnalyticsEvent(event *models.RequestFinishedEve
if event.RequestLog.CompletionTokens > 0 {
pipe.HIncrBy(key, fieldPrefix+":completion", int64(event.RequestLog.CompletionTokens))
}
+
+ // Set a TTL (keep data for 48 hours)
+ pipe.Expire(key, 48*time.Hour)
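+ // The TTL is a safety net: flushSingleKey deletes fields after they are persisted, so keys normally empty out well before 48h.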
+
if err := pipe.Exec(); err != nil {
- s.logger.Warnf("[%s] Failed to record analytics event to store for group %d: %v", event.CorrelationID, *event.RequestLog.GroupID, err)
+ return fmt.Errorf("redis pipeline failed: %w", err)
}
+
+ return nil
}
+// Flush loop
func (s *AnalyticsService) flushLoop() {
defer s.wg.Done()
- ticker := time.NewTicker(flushLoopInterval)
+
+ s.configMutex.RLock()
+ interval := s.flushInterval
+ s.configMutex.RUnlock()
+
+ ticker := time.NewTicker(interval)
defer ticker.Stop()
+
+ s.logger.WithField("interval", interval).Info("Flush loop started")
+
for {
select {
case <-ticker.C:
s.flushToDB()
+
case <-s.stopChan:
+ s.logger.Info("Flush loop stopping")
+ return
+
+ case <-s.ctx.Done():
return
}
}
}
+// Flush aggregated stats to the database
func (s *AnalyticsService) flushToDB() {
- ctx := context.Background()
+ start := time.Now()
+ ctx, cancel := context.WithTimeout(context.Background(), 60*time.Second)
+ defer cancel()
+
now := time.Now().UTC()
- keysToFlush := []string{
- fmt.Sprintf("analytics:hourly:%s", now.Add(-1*time.Hour).Format("2006-01-02T15")),
- fmt.Sprintf("analytics:hourly:%s", now.Format("2006-01-02T15")),
- }
+ keysToFlush := s.generateFlushKeys(now)
+
+ totalRecords := 0
+ totalErrors := 0
for _, key := range keysToFlush {
- data, err := s.store.HGetAll(ctx, key)
- if err != nil || len(data) == 0 {
- continue
+ records, err := s.flushSingleKey(ctx, key, now)
+ if err != nil {
+ s.logger.WithError(err).WithField("key", key).Error("Failed to flush key")
+ totalErrors++
+ s.flushErrors.Add(1)
+ } else {
+ totalRecords += records
}
+ }
- statsToFlush, parsedFields := s.parseStatsFromHash(now.Truncate(time.Hour), data)
+ s.recordsFlushed.Add(uint64(totalRecords))
+ s.flushCount.Add(1)
- if len(statsToFlush) > 0 {
- upsertClause := s.dialect.OnConflictUpdateAll(
- []string{"time", "group_id", "model_name"},
- []string{"request_count", "success_count", "prompt_tokens", "completion_tokens"},
- )
- err := s.db.WithContext(ctx).Clauses(upsertClause).Create(&statsToFlush).Error
- if err != nil {
- s.logger.Errorf("Failed to flush analytics data for key %s: %v", key, err)
- } else {
- s.logger.Infof("Successfully flushed %d records from key %s.", len(statsToFlush), key)
- _ = s.store.HDel(ctx, key, parsedFields...)
- }
- }
+ s.lastFlushMutex.Lock()
+ s.lastFlushTime = time.Now()
+ s.lastFlushMutex.Unlock()
+
+ duration := time.Since(start)
+
+ if totalRecords > 0 || totalErrors > 0 {
+ s.logger.WithFields(logrus.Fields{
+ "records_flushed": totalRecords,
+ "keys_processed": len(keysToFlush),
+ "errors": totalErrors,
+ "duration": duration,
+ }).Info("Analytics data flush completed")
+ } else {
+ s.logger.WithField("duration", duration).Debug("Analytics flush completed (no data)")
}
}
+// Generate the Redis keys that need flushing
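+// For example, at 12:xx UTC this returns the hourly keys for hours 12, 11, 10, and 09 of the same day.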
+func (s *AnalyticsService) generateFlushKeys(now time.Time) []string {
+ keys := make([]string, 0, 4)
+
+ // Current hour
+ keys = append(keys, fmt.Sprintf("analytics:hourly:%s", now.Format("2006-01-02T15")))
+
+ // Previous 3 hours (covers delayed writes and timezone issues)
+ for i := 1; i <= 3; i++ {
+ pastHour := now.Add(-time.Duration(i) * time.Hour)
+ keys = append(keys, fmt.Sprintf("analytics:hourly:%s", pastHour.Format("2006-01-02T15")))
+ }
+
+ return keys
+}
+
+// Flush a single Redis key
+func (s *AnalyticsService) flushSingleKey(ctx context.Context, key string, baseTime time.Time) (int, error) {
+ data, err := s.store.HGetAll(ctx, key)
+ if err != nil {
+ return 0, fmt.Errorf("failed to get hash data: %w", err)
+ }
+
+ if len(data) == 0 {
+ return 0, nil // no data, skip
+ }
+
+ // Parse the timestamp from the key
+ hourStr := strings.TrimPrefix(key, "analytics:hourly:")
+ recordTime, err := time.Parse("2006-01-02T15", hourStr)
+ if err != nil {
+ s.logger.WithError(err).WithField("key", key).Warn("Failed to parse time from key")
+ recordTime = baseTime.Truncate(time.Hour)
+ }
+
+ statsToFlush, parsedFields := s.parseStatsFromHash(recordTime, data)
+
+ if len(statsToFlush) == 0 {
+ return 0, nil
+ }
+
+ // Transaction with retry
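+ // Note: retries sleep inline, so a persistently failing key can delay a flush cycle by up to (maxRetryAttempts-1)*retryDelay.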
+ var dbErr error
+ for attempt := 1; attempt <= maxRetryAttempts; attempt++ {
+ dbErr = s.upsertStatsWithTransaction(ctx, statsToFlush)
+ if dbErr == nil {
+ break
+ }
+
+ if attempt < maxRetryAttempts {
+ s.logger.WithFields(logrus.Fields{
+ "attempt": attempt,
+ "key": key,
+ }).WithError(dbErr).Warn("Database upsert failed, retrying...")
+ time.Sleep(retryDelay)
+ }
+ }
+
+ if dbErr != nil {
+ return 0, fmt.Errorf("failed to upsert after %d attempts: %w", maxRetryAttempts, dbErr)
+ }
+
+ // Delete the fields that were flushed
+ if len(parsedFields) > 0 {
+ if err := s.store.HDel(ctx, key, parsedFields...); err != nil {
+ s.logger.WithError(err).WithField("key", key).Warn("Failed to delete flushed fields from Redis")
+ }
+ }
+
+ return len(statsToFlush), nil
+}
+
+// Batch upsert within a transaction
+func (s *AnalyticsService) upsertStatsWithTransaction(ctx context.Context, stats []models.StatsHourly) error {
+ return s.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
+ upsertClause := s.dialect.OnConflictUpdateAll(
+ []string{"time", "group_id", "model_name"},
+ []string{"request_count", "success_count", "prompt_tokens", "completion_tokens"},
+ )
+ return tx.Clauses(upsertClause).Create(&stats).Error
+ })
+}
+
+// Parse Redis hash data into hourly stats
func (s *AnalyticsService) parseStatsFromHash(t time.Time, data map[string]string) ([]models.StatsHourly, []string) {
tempAggregator := make(map[string]*models.StatsHourly)
- var parsedFields []string
+ parsedFields := make([]string, 0, len(data))
+
for field, valueStr := range data {
parts := strings.Split(field, ":")
if len(parts) != 3 {
+ s.logger.WithField("field", field).Warn("Invalid field format")
continue
}
- groupIDStr, modelName, counterType := parts[0], parts[1], parts[2]
+ groupIDStr, modelName, counterType := parts[0], parts[1], parts[2]
aggKey := groupIDStr + ":" + modelName
+
if _, ok := tempAggregator[aggKey]; !ok {
gid, err := strconv.Atoi(groupIDStr)
if err != nil {
+ s.logger.WithFields(logrus.Fields{
+ "field": field,
+ "group_id": groupIDStr,
+ }).Warn("Invalid group ID")
continue
}
+
tempAggregator[aggKey] = &models.StatsHourly{
Time: t,
GroupID: uint(gid),
ModelName: modelName,
}
}
- val, _ := strconv.ParseInt(valueStr, 10, 64)
+
+ val, err := strconv.ParseInt(valueStr, 10, 64)
+ if err != nil {
+ s.logger.WithFields(logrus.Fields{
+ "field": field,
+ "value": valueStr,
+ }).Warn("Invalid counter value")
+ continue
+ }
+
switch counterType {
case "requests":
tempAggregator[aggKey].RequestCount = val
@@ -184,14 +403,92 @@ func (s *AnalyticsService) parseStatsFromHash(t time.Time, data map[string]strin
tempAggregator[aggKey].PromptTokens = val
case "completion":
tempAggregator[aggKey].CompletionTokens = val
+ default:
+ s.logger.WithField("counter_type", counterType).Warn("Unknown counter type")
+ continue
}
+
parsedFields = append(parsedFields, field)
}
- var result []models.StatsHourly
+
+ result := make([]models.StatsHourly, 0, len(tempAggregator))
for _, stats := range tempAggregator {
if stats.RequestCount > 0 {
result = append(result, *stats)
}
}
+
return result, parsedFields
}
+
+// Periodically report metrics
+func (s *AnalyticsService) metricsReporter() {
+ defer s.wg.Done()
+
+ ticker := time.NewTicker(5 * time.Minute)
+ defer ticker.Stop()
+
+ for {
+ select {
+ case <-ticker.C:
+ s.reportMetrics()
+ case <-s.stopChan:
+ return
+ case <-s.ctx.Done():
+ return
+ }
+ }
+}
+
+func (s *AnalyticsService) reportMetrics() {
+ s.lastFlushMutex.RLock()
+ lastFlush := s.lastFlushTime
+ s.lastFlushMutex.RUnlock()
+
+ received := s.eventsReceived.Load()
+ processed := s.eventsProcessed.Load()
+ failed := s.eventsFailed.Load()
+
+ var successRate float64
+ if received > 0 {
+ successRate = float64(processed) / float64(received) * 100
+ }
+
+ s.logger.WithFields(logrus.Fields{
+ "events_received": received,
+ "events_processed": processed,
+ "events_failed": failed,
+ "success_rate": fmt.Sprintf("%.2f%%", successRate),
+ "flush_count": s.flushCount.Load(),
+ "records_flushed": s.recordsFlushed.Load(),
+ "flush_errors": s.flushErrors.Load(),
+ "last_flush_ago": time.Since(lastFlush).Round(time.Second),
+ }).Info("Analytics metrics")
+}
+
+// GetMetrics returns the current metrics (for monitoring)
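+// Illustrative usage only (no such endpoint is added in this change): a monitoring handler could expose these counters,
+// e.g. admin.GET("/metrics/analytics", func(c *gin.Context) { c.JSON(http.StatusOK, svc.GetMetrics()) }).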
+func (s *AnalyticsService) GetMetrics() map[string]interface{} {
+ s.lastFlushMutex.RLock()
+ lastFlush := s.lastFlushTime
+ s.lastFlushMutex.RUnlock()
+
+ received := s.eventsReceived.Load()
+ processed := s.eventsProcessed.Load()
+
+ var successRate float64
+ if received > 0 {
+ successRate = float64(processed) / float64(received) * 100
+ }
+
+ return map[string]interface{}{
+ "events_received": received,
+ "events_processed": processed,
+ "events_failed": s.eventsFailed.Load(),
+ "success_rate": successRate,
+ "flush_count": s.flushCount.Load(),
+ "records_flushed": s.recordsFlushed.Load(),
+ "flush_errors": s.flushErrors.Load(),
+ "last_flush_ago": time.Since(lastFlush).Seconds(),
+ "flush_interval": s.flushInterval.Seconds(),
+ }
+}
diff --git a/internal/service/dashboard_query_service.go b/internal/service/dashboard_query_service.go
index d1a6764..2258111 100644
--- a/internal/service/dashboard_query_service.go
+++ b/internal/service/dashboard_query_service.go
@@ -4,158 +4,297 @@ package service
import (
"context"
"fmt"
+ "strconv"
+ "sync"
+ "sync/atomic"
+ "time"
+
"gemini-balancer/internal/models"
"gemini-balancer/internal/store"
"gemini-balancer/internal/syncer"
- "strconv"
- "time"
"github.com/gin-gonic/gin"
"github.com/sirupsen/logrus"
"gorm.io/gorm"
)
-const overviewCacheChannel = "syncer:cache:dashboard_overview"
+const (
+ overviewCacheChannel = "syncer:cache:dashboard_overview"
+ defaultChartDays = 7
+ cacheLoadTimeout = 30 * time.Second
+)
+
+var (
+ // Chart color palette
+ chartColorPalette = []string{
+ "#FF6384", "#36A2EB", "#FFCE56", "#4BC0C0",
+ "#9966FF", "#FF9F40", "#C9CBCF", "#4D5360",
+ }
+)
type DashboardQueryService struct {
db *gorm.DB
store store.Store
overviewSyncer *syncer.CacheSyncer[*models.DashboardStatsResponse]
logger *logrus.Entry
- stopChan chan struct{}
+
+ stopChan chan struct{}
+ wg sync.WaitGroup
+ ctx context.Context
+ cancel context.CancelFunc
+
+ // Metrics
+ queryCount atomic.Uint64
+ cacheHits atomic.Uint64
+ cacheMisses atomic.Uint64
+ overviewLoadCount atomic.Uint64
+ lastQueryTime time.Time
+ lastQueryMutex sync.RWMutex
}
-func NewDashboardQueryService(db *gorm.DB, s store.Store, logger *logrus.Logger) (*DashboardQueryService, error) {
- qs := &DashboardQueryService{
- db: db,
- store: s,
- logger: logger.WithField("component", "DashboardQueryService"),
- stopChan: make(chan struct{}),
+func NewDashboardQueryService(
+ db *gorm.DB,
+ s store.Store,
+ logger *logrus.Logger,
+) (*DashboardQueryService, error) {
+ ctx, cancel := context.WithCancel(context.Background())
+
+ service := &DashboardQueryService{
+ db: db,
+ store: s,
+ logger: logger.WithField("component", "DashboardQuery📈"),
+ stopChan: make(chan struct{}),
+ ctx: ctx,
+ cancel: cancel,
+ lastQueryTime: time.Now(),
}
- loader := qs.loadOverviewData
- overviewSyncer, err := syncer.NewCacheSyncer(loader, s, overviewCacheChannel)
+ // Create the CacheSyncer
+ overviewSyncer, err := syncer.NewCacheSyncer(
+ service.loadOverviewData,
+ s,
+ overviewCacheChannel,
+ logger,
+ )
if err != nil {
+ cancel()
return nil, fmt.Errorf("failed to create overview cache syncer: %w", err)
}
- qs.overviewSyncer = overviewSyncer
- return qs, nil
+ service.overviewSyncer = overviewSyncer
+
+ return service, nil
}
func (s *DashboardQueryService) Start() {
+ s.wg.Add(2)
go s.eventListener()
- s.logger.Info("DashboardQueryService started and listening for invalidation events.")
+ go s.metricsReporter()
+
+ s.logger.Info("DashboardQueryService started")
}
func (s *DashboardQueryService) Stop() {
+ s.logger.Info("DashboardQueryService stopping...")
close(s.stopChan)
- s.logger.Info("DashboardQueryService and its CacheSyncer have been stopped.")
+ s.cancel()
+ s.wg.Wait()
+
+ // Log final statistics
+ s.logger.WithFields(logrus.Fields{
+ "total_queries": s.queryCount.Load(),
+ "cache_hits": s.cacheHits.Load(),
+ "cache_misses": s.cacheMisses.Load(),
+ "overview_loads": s.overviewLoadCount.Load(),
+ }).Info("DashboardQueryService stopped")
}
+// ==================== Core query methods ====================
+
+// GetDashboardOverviewData returns the dashboard overview data (cached)
+func (s *DashboardQueryService) GetDashboardOverviewData() (*models.DashboardStatsResponse, error) {
+ s.queryCount.Add(1)
+
+ cachedDataPtr := s.overviewSyncer.Get()
+ if cachedDataPtr == nil {
+ s.cacheMisses.Add(1)
+ s.logger.Warn("Overview cache is empty, attempting to load...")
+
+ // Trigger an immediate reload
+ if err := s.overviewSyncer.Invalidate(); err != nil {
+ return nil, fmt.Errorf("failed to trigger cache reload: %w", err)
+ }
+
+ // Wait for the load to finish (up to cacheLoadTimeout)
+ ctx, cancel := context.WithTimeout(context.Background(), cacheLoadTimeout)
+ defer cancel()
+
+ ticker := time.NewTicker(100 * time.Millisecond)
+ defer ticker.Stop()
+
+ for {
+ select {
+ case <-ticker.C:
+ if data := s.overviewSyncer.Get(); data != nil {
+ s.cacheHits.Add(1)
+ return data, nil
+ }
+ case <-ctx.Done():
+ return nil, fmt.Errorf("timeout waiting for overview cache to load")
+ }
+ }
+ }
+
+ s.cacheHits.Add(1)
+ return cachedDataPtr, nil
+}
+
+// InvalidateOverviewCache manually invalidates the overview cache
+func (s *DashboardQueryService) InvalidateOverviewCache() error {
+ s.logger.Info("Manually invalidating overview cache")
+ return s.overviewSyncer.Invalidate()
+}
+
+// GetGroupStats returns statistics for the given key group
func (s *DashboardQueryService) GetGroupStats(ctx context.Context, groupID uint) (map[string]any, error) {
+ s.queryCount.Add(1)
+ s.updateLastQueryTime()
+
+ start := time.Now()
+
+ // 1. Fetch key stats from Redis
statsKey := fmt.Sprintf("stats:group:%d", groupID)
keyStatsMap, err := s.store.HGetAll(ctx, statsKey)
if err != nil {
- s.logger.WithError(err).Errorf("Failed to get key stats from cache for group %d", groupID)
+ s.logger.WithError(err).Errorf("Failed to get key stats for group %d", groupID)
return nil, fmt.Errorf("failed to get key stats from cache: %w", err)
}
+
keyStats := make(map[string]int64)
for k, v := range keyStatsMap {
val, _ := strconv.ParseInt(v, 10, 64)
keyStats[k] = val
}
- now := time.Now()
+
+ // 2. Query request stats (in UTC)
+ now := time.Now().UTC()
oneHourAgo := now.Add(-1 * time.Hour)
twentyFourHoursAgo := now.Add(-24 * time.Hour)
+
type requestStatsResult struct {
TotalRequests int64
SuccessRequests int64
}
+
var last1Hour, last24Hours requestStatsResult
- s.db.WithContext(ctx).Model(&models.StatsHourly{}).
- Select("SUM(request_count) as total_requests, SUM(success_count) as success_requests").
- Where("group_id = ? AND time >= ?", groupID, oneHourAgo).
- Scan(&last1Hour)
- s.db.WithContext(ctx).Model(&models.StatsHourly{}).
- Select("SUM(request_count) as total_requests, SUM(success_count) as success_requests").
- Where("group_id = ? AND time >= ?", groupID, twentyFourHoursAgo).
- Scan(&last24Hours)
- failureRate1h := 0.0
- if last1Hour.TotalRequests > 0 {
- failureRate1h = float64(last1Hour.TotalRequests-last1Hour.SuccessRequests) / float64(last1Hour.TotalRequests) * 100
- }
- failureRate24h := 0.0
- if last24Hours.TotalRequests > 0 {
- failureRate24h = float64(last24Hours.TotalRequests-last24Hours.SuccessRequests) / float64(last24Hours.TotalRequests) * 100
- }
- last1HourStats := map[string]any{
- "total_requests": last1Hour.TotalRequests,
- "success_requests": last1Hour.SuccessRequests,
- "failure_rate": failureRate1h,
- }
- last24HoursStats := map[string]any{
- "total_requests": last24Hours.TotalRequests,
- "success_requests": last24Hours.SuccessRequests,
- "failure_rate": failureRate24h,
+
+ // Run both range queries concurrently
+ var wg sync.WaitGroup
+ errChan := make(chan error, 2)
+
+ wg.Add(2)
+
+ // Last 1 hour
+ go func() {
+ defer wg.Done()
+ if err := s.db.WithContext(ctx).Model(&models.StatsHourly{}).
+ Select("COALESCE(SUM(request_count), 0) as total_requests, COALESCE(SUM(success_count), 0) as success_requests").
+ Where("group_id = ? AND time >= ?", groupID, oneHourAgo).
+ Scan(&last1Hour).Error; err != nil {
+ errChan <- fmt.Errorf("failed to query 1h stats: %w", err)
+ }
+ }()
+
+ // Last 24 hours
+ go func() {
+ defer wg.Done()
+ if err := s.db.WithContext(ctx).Model(&models.StatsHourly{}).
+ Select("COALESCE(SUM(request_count), 0) as total_requests, COALESCE(SUM(success_count), 0) as success_requests").
+ Where("group_id = ? AND time >= ?", groupID, twentyFourHoursAgo).
+ Scan(&last24Hours).Error; err != nil {
+ errChan <- fmt.Errorf("failed to query 24h stats: %w", err)
+ }
+ }()
+
+ wg.Wait()
+ close(errChan)
+
+ // Check for errors
+ for err := range errChan {
+ if err != nil {
+ return nil, err
+ }
}
+
+ // 3. Compute failure rates
+ failureRate1h := s.calculateFailureRate(last1Hour.TotalRequests, last1Hour.SuccessRequests)
+ failureRate24h := s.calculateFailureRate(last24Hours.TotalRequests, last24Hours.SuccessRequests)
+
result := map[string]any{
- "key_stats": keyStats,
- "last_1_hour": last1HourStats,
- "last_24_hours": last24HoursStats,
+ "key_stats": keyStats,
+ "last_1_hour": map[string]any{
+ "total_requests": last1Hour.TotalRequests,
+ "success_requests": last1Hour.SuccessRequests,
+ "failed_requests": last1Hour.TotalRequests - last1Hour.SuccessRequests,
+ "failure_rate": failureRate1h,
+ },
+ "last_24_hours": map[string]any{
+ "total_requests": last24Hours.TotalRequests,
+ "success_requests": last24Hours.SuccessRequests,
+ "failed_requests": last24Hours.TotalRequests - last24Hours.SuccessRequests,
+ "failure_rate": failureRate24h,
+ },
}
+
+ duration := time.Since(start)
+ s.logger.WithFields(logrus.Fields{
+ "group_id": groupID,
+ "duration": duration,
+ }).Debug("Group stats query completed")
+
return result, nil
}
-func (s *DashboardQueryService) eventListener() {
- ctx := context.Background()
- keyStatusSub, _ := s.store.Subscribe(ctx, models.TopicKeyStatusChanged)
- upstreamStatusSub, _ := s.store.Subscribe(ctx, models.TopicUpstreamHealthChanged)
- defer keyStatusSub.Close()
- defer upstreamStatusSub.Close()
- for {
- select {
- case <-keyStatusSub.Channel():
- s.logger.Info("Received key status changed event, invalidating overview cache...")
- _ = s.InvalidateOverviewCache()
- case <-upstreamStatusSub.Channel():
- s.logger.Info("Received upstream status changed event, invalidating overview cache...")
- _ = s.InvalidateOverviewCache()
- case <-s.stopChan:
- s.logger.Info("Stopping dashboard event listener.")
- return
- }
- }
-}
-
-func (s *DashboardQueryService) GetDashboardOverviewData() (*models.DashboardStatsResponse, error) {
- cachedDataPtr := s.overviewSyncer.Get()
- if cachedDataPtr == nil {
- return &models.DashboardStatsResponse{}, fmt.Errorf("overview cache is not available or still syncing")
- }
- return cachedDataPtr, nil
-}
-
-func (s *DashboardQueryService) InvalidateOverviewCache() error {
- return s.overviewSyncer.Invalidate()
-}
-
+// QueryHistoricalChart queries historical chart data
func (s *DashboardQueryService) QueryHistoricalChart(ctx context.Context, groupID *uint) (*models.ChartData, error) {
+ s.queryCount.Add(1)
+ s.updateLastQueryTime()
+
+ start := time.Now()
+
type ChartPoint struct {
TimeLabel string `gorm:"column:time_label"`
ModelName string `gorm:"column:model_name"`
TotalRequests int64 `gorm:"column:total_requests"`
}
- sevenDaysAgo := time.Now().Add(-24 * 7 * time.Hour).Truncate(time.Hour)
+
+ // Query the last 7 days of data (UTC)
+ sevenDaysAgo := time.Now().UTC().AddDate(0, 0, -defaultChartDays).Truncate(time.Hour)
+
+ // Build the time-format clause for the current database dialect
sqlFormat, goFormat := s.buildTimeFormatSelectClause()
- selectClause := fmt.Sprintf("%s as time_label, model_name, SUM(request_count) as total_requests", sqlFormat)
- query := s.db.WithContext(ctx).Model(&models.StatsHourly{}).Select(selectClause).Where("time >= ?", sevenDaysAgo).Group("time_label, model_name").Order("time_label ASC")
+ selectClause := fmt.Sprintf(
+ "%s as time_label, model_name, COALESCE(SUM(request_count), 0) as total_requests",
+ sqlFormat,
+ )
+
+ // Build the query
+ query := s.db.WithContext(ctx).
+ Model(&models.StatsHourly{}).
+ Select(selectClause).
+ Where("time >= ?", sevenDaysAgo).
+ Group("time_label, model_name").
+ Order("time_label ASC")
+
if groupID != nil && *groupID > 0 {
query = query.Where("group_id = ?", *groupID)
}
+
var points []ChartPoint
if err := query.Find(&points).Error; err != nil {
- return nil, err
+ return nil, fmt.Errorf("failed to query chart data: %w", err)
}
+
+ // Build the datasets
datasets := make(map[string]map[string]int64)
for _, p := range points {
if _, ok := datasets[p.ModelName]; !ok {
@@ -163,32 +302,99 @@ func (s *DashboardQueryService) QueryHistoricalChart(ctx context.Context, groupI
}
datasets[p.ModelName][p.TimeLabel] = p.TotalRequests
}
+
+ // Generate hourly time labels
var labels []string
- for t := sevenDaysAgo; t.Before(time.Now()); t = t.Add(time.Hour) {
+ for t := sevenDaysAgo; t.Before(time.Now().UTC()); t = t.Add(time.Hour) {
labels = append(labels, t.Format(goFormat))
}
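+ // Models with no data for a given label fall back to 0 when the datasets are assembled below.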
- chartData := &models.ChartData{Labels: labels, Datasets: make([]models.ChartDataset, 0)}
- colorPalette := []string{"#FF6384", "#36A2EB", "#FFCE56", "#4BC0C0", "#9966FF", "#FF9F40"}
+
+ // Assemble the chart data
+ chartData := &models.ChartData{
+ Labels: labels,
+ Datasets: make([]models.ChartDataset, 0, len(datasets)),
+ }
+
colorIndex := 0
for modelName, dataPoints := range datasets {
dataArray := make([]int64, len(labels))
for i, label := range labels {
dataArray[i] = dataPoints[label]
}
+
chartData.Datasets = append(chartData.Datasets, models.ChartDataset{
Label: modelName,
Data: dataArray,
- Color: colorPalette[colorIndex%len(colorPalette)],
+ Color: chartColorPalette[colorIndex%len(chartColorPalette)],
})
colorIndex++
}
+
+ duration := time.Since(start)
+ s.logger.WithFields(logrus.Fields{
+ "group_id": groupID,
+ "points": len(points),
+ "datasets": len(chartData.Datasets),
+ "duration": duration,
+ }).Debug("Historical chart query completed")
+
return chartData, nil
}
+// GetRequestStatsForPeriod returns request statistics for the given period
+func (s *DashboardQueryService) GetRequestStatsForPeriod(ctx context.Context, period string) (gin.H, error) {
+ s.queryCount.Add(1)
+ s.updateLastQueryTime()
+
+ var startTime time.Time
+ now := time.Now().UTC()
+
+ switch period {
+ case "1m":
+ startTime = now.Add(-1 * time.Minute)
+ case "1h":
+ startTime = now.Add(-1 * time.Hour)
+ case "1d":
+ year, month, day := now.Date()
+ startTime = time.Date(year, month, day, 0, 0, 0, 0, time.UTC)
+ default:
+ return nil, fmt.Errorf("invalid period specified: %s (must be 1m, 1h, or 1d)", period)
+ }
+
+ var result struct {
+ Total int64
+ Success int64
+ }
+
+ err := s.db.WithContext(ctx).Model(&models.RequestLog{}).
+ Select("COUNT(*) as total, SUM(CASE WHEN is_success = true THEN 1 ELSE 0 END) as success").
+ Where("request_time >= ?", startTime).
+ Scan(&result).Error
+
+ if err != nil {
+ return nil, fmt.Errorf("failed to query request stats: %w", err)
+ }
+
+ return gin.H{
+ "period": period,
+ "total": result.Total,
+ "success": result.Success,
+ "failure": result.Total - result.Success,
+ }, nil
+}
+
+// ==================== Internal methods ====================
+
+// loadOverviewData loads the dashboard overview data (called by the CacheSyncer)
func (s *DashboardQueryService) loadOverviewData() (*models.DashboardStatsResponse, error) {
- ctx := context.Background()
- s.logger.Info("[CacheSyncer] Starting to load overview data from database...")
+ ctx, cancel := context.WithTimeout(context.Background(), 60*time.Second)
+ defer cancel()
+
+ s.overviewLoadCount.Add(1)
startTime := time.Now()
+
+ s.logger.Info("Starting to load dashboard overview data...")
+
resp := &models.DashboardStatsResponse{
KeyStatusCount: make(map[models.APIKeyStatus]int64),
MasterStatusCount: make(map[models.MasterAPIKeyStatus]int64),
@@ -200,108 +406,391 @@ func (s *DashboardQueryService) loadOverviewData() (*models.DashboardStatsRespon
RequestCounts: make(map[string]int64),
}
+ var loadErr error
+ var wg sync.WaitGroup
+ errChan := make(chan error, 10)
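+ // The buffer is larger than the number of loader goroutines, so error sends below never block.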
+
+ // 1. Load key mapping status stats concurrently
+ wg.Add(1)
+ go func() {
+ defer wg.Done()
+ if err := s.loadMappingStatusStats(ctx, resp); err != nil {
+ errChan <- fmt.Errorf("mapping stats: %w", err)
+ }
+ }()
+
+ // 2. Load master key status stats concurrently
+ wg.Add(1)
+ go func() {
+ defer wg.Done()
+ if err := s.loadMasterStatusStats(ctx, resp); err != nil {
+ errChan <- fmt.Errorf("master stats: %w", err)
+ }
+ }()
+
+ // 3. Load request counts concurrently
+ wg.Add(1)
+ go func() {
+ defer wg.Done()
+ if err := s.loadRequestCounts(ctx, resp); err != nil {
+ errChan <- fmt.Errorf("request counts: %w", err)
+ }
+ }()
+
+ // 4. Load upstream health concurrently
+ wg.Add(1)
+ go func() {
+ defer wg.Done()
+ if err := s.loadUpstreamHealth(ctx, resp); err != nil {
+ // Upstream health failures do not block the overall load
+ s.logger.WithError(err).Warn("Failed to load upstream health status")
+ }
+ }()
+
+ // Wait for all loaders to finish
+ wg.Wait()
+ close(errChan)
+
+ // Collect errors
+ for err := range errChan {
+ if err != nil {
+ loadErr = err
+ break
+ }
+ }
+
+ if loadErr != nil {
+ s.logger.WithError(loadErr).Error("Failed to load overview data")
+ return nil, loadErr
+ }
+
+ duration := time.Since(startTime)
+ s.logger.WithFields(logrus.Fields{
+ "duration": duration,
+ "total_keys": resp.KeyCount.Value,
+ "requests_1d": resp.RequestCounts["1d"],
+ "upstreams": len(resp.UpstreamHealthStatus),
+ }).Info("Successfully loaded dashboard overview data")
+
+ return resp, nil
+}
+
+// loadMappingStatusStats loads key mapping status statistics
+func (s *DashboardQueryService) loadMappingStatusStats(ctx context.Context, resp *models.DashboardStatsResponse) error {
type MappingStatusResult struct {
Status models.APIKeyStatus
Count int64
}
- var mappingStatusResults []MappingStatusResult
- if err := s.db.WithContext(ctx).Model(&models.GroupAPIKeyMapping{}).Select("status, count(*) as count").Group("status").Find(&mappingStatusResults).Error; err != nil {
- return nil, fmt.Errorf("failed to query mapping status stats: %w", err)
+
+ var results []MappingStatusResult
+ if err := s.db.WithContext(ctx).
+ Model(&models.GroupAPIKeyMapping{}).
+ Select("status, COUNT(*) as count").
+ Group("status").
+ Find(&results).Error; err != nil {
+ return err
}
- for _, res := range mappingStatusResults {
+
+ for _, res := range results {
resp.KeyStatusCount[res.Status] = res.Count
}
+ return nil
+}
+
+// loadMasterStatusStats loads master key status statistics
+func (s *DashboardQueryService) loadMasterStatusStats(ctx context.Context, resp *models.DashboardStatsResponse) error {
type MasterStatusResult struct {
Status models.MasterAPIKeyStatus
Count int64
}
- var masterStatusResults []MasterStatusResult
- if err := s.db.WithContext(ctx).Model(&models.APIKey{}).Select("master_status as status, count(*) as count").Group("master_status").Find(&masterStatusResults).Error; err != nil {
- return nil, fmt.Errorf("failed to query master status stats: %w", err)
+
+ var results []MasterStatusResult
+ if err := s.db.WithContext(ctx).
+ Model(&models.APIKey{}).
+ Select("master_status as status, COUNT(*) as count").
+ Group("master_status").
+ Find(&results).Error; err != nil {
+ return err
}
+
var totalKeys, invalidKeys int64
- for _, res := range masterStatusResults {
+ for _, res := range results {
resp.MasterStatusCount[res.Status] = res.Count
totalKeys += res.Count
if res.Status != models.MasterStatusActive {
invalidKeys += res.Count
}
}
- resp.KeyCount = models.StatCard{Value: float64(totalKeys), SubValue: invalidKeys, SubValueTip: "非活跃身份密钥数"}
- now := time.Now()
+ resp.KeyCount = models.StatCard{
+ Value: float64(totalKeys),
+ SubValue: invalidKeys,
+ SubValueTip: "非活跃身份密钥数",
+ }
- var count1m, count1h, count1d int64
- s.db.WithContext(ctx).Model(&models.RequestLog{}).Where("request_time >= ?", now.Add(-1*time.Minute)).Count(&count1m)
- s.db.WithContext(ctx).Model(&models.RequestLog{}).Where("request_time >= ?", now.Add(-1*time.Hour)).Count(&count1h)
- year, month, day := now.UTC().Date()
+ return nil
+}
+
+// loadRequestCounts loads request count statistics
+func (s *DashboardQueryService) loadRequestCounts(ctx context.Context, resp *models.DashboardStatsResponse) error {
+ now := time.Now().UTC()
+
+ // Use the RequestLog table for short-range counts
+ var count1m, count1h int64
+
+ // Last minute
+ if err := s.db.WithContext(ctx).
+ Model(&models.RequestLog{}).
+ Where("request_time >= ?", now.Add(-1*time.Minute)).
+ Count(&count1m).Error; err != nil {
+ return fmt.Errorf("1m count: %w", err)
+ }
+
+ // Last hour
+ if err := s.db.WithContext(ctx).
+ Model(&models.RequestLog{}).
+ Where("request_time >= ?", now.Add(-1*time.Hour)).
+ Count(&count1h).Error; err != nil {
+ return fmt.Errorf("1h count: %w", err)
+ }
+
+ // Today (UTC)
+ year, month, day := now.Date()
startOfDay := time.Date(year, month, day, 0, 0, 0, 0, time.UTC)
- s.db.WithContext(ctx).Model(&models.RequestLog{}).Where("request_time >= ?", startOfDay).Count(&count1d)
+ var count1d int64
+ if err := s.db.WithContext(ctx).
+ Model(&models.RequestLog{}).
+ Where("request_time >= ?", startOfDay).
+ Count(&count1d).Error; err != nil {
+ return fmt.Errorf("1d count: %w", err)
+ }
+
+ // Last 30 days from the aggregated stats table
var count30d int64
- s.db.WithContext(ctx).Model(&models.StatsHourly{}).Where("time >= ?", now.AddDate(0, 0, -30)).Select("COALESCE(SUM(request_count), 0)").Scan(&count30d)
+ if err := s.db.WithContext(ctx).
+ Model(&models.StatsHourly{}).
+ Where("time >= ?", now.AddDate(0, 0, -30)).
+ Select("COALESCE(SUM(request_count), 0)").
+ Scan(&count30d).Error; err != nil {
+ return fmt.Errorf("30d count: %w", err)
+ }
resp.RequestCounts["1m"] = count1m
resp.RequestCounts["1h"] = count1h
resp.RequestCounts["1d"] = count1d
resp.RequestCounts["30d"] = count30d
+ return nil
+}
+
+// loadUpstreamHealth loads the upstream health status
+func (s *DashboardQueryService) loadUpstreamHealth(ctx context.Context, resp *models.DashboardStatsResponse) error {
var upstreams []*models.UpstreamEndpoint
if err := s.db.WithContext(ctx).Find(&upstreams).Error; err != nil {
- s.logger.WithError(err).Warn("Failed to load upstream statuses for dashboard.")
- } else {
- for _, u := range upstreams {
- resp.UpstreamHealthStatus[u.URL] = u.Status
+ return err
+ }
+
+ for _, u := range upstreams {
+ resp.UpstreamHealthStatus[u.URL] = u.Status
+ }
+
+ return nil
+}
+
+// ==================== Event listening ====================
+
+// eventListener listens for cache-invalidation events
+func (s *DashboardQueryService) eventListener() {
+ defer s.wg.Done()
+
+ // Subscribe to events
+ keyStatusSub, err1 := s.store.Subscribe(s.ctx, models.TopicKeyStatusChanged)
+ upstreamStatusSub, err2 := s.store.Subscribe(s.ctx, models.TopicUpstreamHealthChanged)
+
+ // Handle subscription errors
+ if err1 != nil {
+ s.logger.WithError(err1).Error("Failed to subscribe to key status events")
+ keyStatusSub = nil
+ }
+ if err2 != nil {
+ s.logger.WithError(err2).Error("Failed to subscribe to upstream status events")
+ upstreamStatusSub = nil
+ }
+
+ // If both subscriptions failed, bail out
+ if keyStatusSub == nil && upstreamStatusSub == nil {
+ s.logger.Error("All event subscriptions failed, listener disabled")
+ return
+ }
+
+ // Close subscriptions safely
+ defer func() {
+ if keyStatusSub != nil {
+ if err := keyStatusSub.Close(); err != nil {
+ s.logger.WithError(err).Warn("Failed to close key status subscription")
+ }
+ }
+ if upstreamStatusSub != nil {
+ if err := upstreamStatusSub.Close(); err != nil {
+ s.logger.WithError(err).Warn("Failed to close upstream status subscription")
+ }
+ }
+ }()
+
+ s.logger.WithFields(logrus.Fields{
+ "key_status_sub": keyStatusSub != nil,
+ "upstream_status_sub": upstreamStatusSub != nil,
+ }).Info("Event listener started")
+
+ for {
+ // Use nil channels for missing subscriptions: receiving from a nil channel
+ // blocks forever in a select, so a failed subscription simply never fires its case.
+ // (A closed channel would fire immediately and busy-loop.)
+ var keyStatusChan <-chan *store.Message
+ if keyStatusSub != nil {
+ keyStatusChan = keyStatusSub.Channel()
+ }
+
+ var upstreamStatusChan <-chan *store.Message
+ if upstreamStatusSub != nil {
+ upstreamStatusChan = upstreamStatusSub.Channel()
+ }
+
+ select {
+ case _, ok := <-keyStatusChan:
+ if !ok {
+ s.logger.Warn("Key status channel closed")
+ keyStatusSub = nil
+ continue
+ }
+ s.logger.Debug("Received key status changed event")
+ if err := s.InvalidateOverviewCache(); err != nil {
+ s.logger.WithError(err).Warn("Failed to invalidate cache on key status change")
+ }
+
+ case _, ok := <-upstreamStatusChan:
+ if !ok {
+ s.logger.Warn("Upstream status channel closed")
+ upstreamStatusSub = nil
+ continue
+ }
+ s.logger.Debug("Received upstream status changed event")
+ if err := s.InvalidateOverviewCache(); err != nil {
+ s.logger.WithError(err).Warn("Failed to invalidate cache on upstream status change")
+ }
+
+ case <-s.stopChan:
+ s.logger.Info("Event listener stopping (stopChan)")
+ return
+
+ case <-s.ctx.Done():
+ s.logger.Info("Event listener stopping (context cancelled)")
+ return
}
}
-
- duration := time.Since(startTime)
- s.logger.Infof("[CacheSyncer] Successfully finished loading overview data in %s.", duration)
- return resp, nil
}
-func (s *DashboardQueryService) GetRequestStatsForPeriod(ctx context.Context, period string) (gin.H, error) {
- var startTime time.Time
- now := time.Now()
- switch period {
- case "1m":
- startTime = now.Add(-1 * time.Minute)
- case "1h":
- startTime = now.Add(-1 * time.Hour)
- case "1d":
- year, month, day := now.UTC().Date()
- startTime = time.Date(year, month, day, 0, 0, 0, 0, time.UTC)
- default:
- return nil, fmt.Errorf("invalid period specified: %s", period)
- }
- var result struct {
- Total int64
- Success int64
- }
+// ==================== Monitoring metrics ====================
- err := s.db.WithContext(ctx).Model(&models.RequestLog{}).
- Select("count(*) as total, sum(case when is_success = true then 1 else 0 end) as success").
- Where("request_time >= ?", startTime).
- Scan(&result).Error
- if err != nil {
- return nil, err
+// metricsReporter periodically reports metrics
+func (s *DashboardQueryService) metricsReporter() {
+ defer s.wg.Done()
+
+ ticker := time.NewTicker(5 * time.Minute)
+ defer ticker.Stop()
+
+ for {
+ select {
+ case <-ticker.C:
+ s.reportMetrics()
+ case <-s.stopChan:
+ return
+ case <-s.ctx.Done():
+ return
+ }
}
- return gin.H{
- "total": result.Total,
- "success": result.Success,
- "failure": result.Total - result.Success,
- }, nil
}
+func (s *DashboardQueryService) reportMetrics() {
+ s.lastQueryMutex.RLock()
+ lastQuery := s.lastQueryTime
+ s.lastQueryMutex.RUnlock()
+
+ totalQueries := s.queryCount.Load()
+ hits := s.cacheHits.Load()
+ misses := s.cacheMisses.Load()
+
+ var cacheHitRate float64
+ if hits+misses > 0 {
+ cacheHitRate = float64(hits) / float64(hits+misses) * 100
+ }
+
+ s.logger.WithFields(logrus.Fields{
+ "total_queries": totalQueries,
+ "cache_hits": hits,
+ "cache_misses": misses,
+ "cache_hit_rate": fmt.Sprintf("%.2f%%", cacheHitRate),
+ "overview_loads": s.overviewLoadCount.Load(),
+ "last_query_ago": time.Since(lastQuery).Round(time.Second),
+ }).Info("DashboardQuery metrics")
+}
+
+// GetMetrics returns the current metrics (for monitoring)
+func (s *DashboardQueryService) GetMetrics() map[string]interface{} {
+ s.lastQueryMutex.RLock()
+ lastQuery := s.lastQueryTime
+ s.lastQueryMutex.RUnlock()
+
+ hits := s.cacheHits.Load()
+ misses := s.cacheMisses.Load()
+
+ var cacheHitRate float64
+ if hits+misses > 0 {
+ cacheHitRate = float64(hits) / float64(hits+misses) * 100
+ }
+
+ return map[string]interface{}{
+ "total_queries": s.queryCount.Load(),
+ "cache_hits": hits,
+ "cache_misses": misses,
+ "cache_hit_rate": cacheHitRate,
+ "overview_loads": s.overviewLoadCount.Load(),
+ "last_query_ago": time.Since(lastQuery).Seconds(),
+ }
+}
+
+// ==================== Helpers ====================
+
+// calculateFailureRate computes the failure rate as a percentage
+func (s *DashboardQueryService) calculateFailureRate(total, success int64) float64 {
+ if total == 0 {
+ return 0.0
+ }
+ return float64(total-success) / float64(total) * 100
+}
+
+// updateLastQueryTime updates the last query timestamp
+func (s *DashboardQueryService) updateLastQueryTime() {
+ s.lastQueryMutex.Lock()
+ s.lastQueryTime = time.Now()
+ s.lastQueryMutex.Unlock()
+}
+
+// buildTimeFormatSelectClause builds the hour-bucket time format clause for the current database dialect
func (s *DashboardQueryService) buildTimeFormatSelectClause() (string, string) {
dialect := s.db.Dialector.Name()
switch dialect {
case "mysql":
return "DATE_FORMAT(time, '%Y-%m-%d %H:00:00')", "2006-01-02 15:00:00"
+ case "postgres":
+ return "TO_CHAR(time, 'YYYY-MM-DD HH24:00:00')", "2006-01-02 15:00:00"
case "sqlite":
return "strftime('%Y-%m-%d %H:00:00', time)", "2006-01-02 15:00:00"
default:
+ s.logger.WithField("dialect", dialect).Warn("Unknown database dialect, using SQLite format")
return "strftime('%Y-%m-%d %H:00:00', time)", "2006-01-02 15:00:00"
}
}
diff --git a/internal/service/db_log_writer_service.go b/internal/service/db_log_writer_service.go
index b80904a..563db2a 100644
--- a/internal/service/db_log_writer_service.go
+++ b/internal/service/db_log_writer_service.go
@@ -4,11 +4,13 @@ package service
import (
"context"
"encoding/json"
+ "sync"
+ "sync/atomic"
+ "time"
+
"gemini-balancer/internal/models"
"gemini-balancer/internal/settings"
"gemini-balancer/internal/store"
- "sync"
- "time"
"github.com/sirupsen/logrus"
"gorm.io/gorm"
@@ -18,25 +20,47 @@ type DBLogWriterService struct {
db *gorm.DB
store store.Store
logger *logrus.Entry
- logBuffer chan *models.RequestLog
- stopChan chan struct{}
- wg sync.WaitGroup
- SettingsManager *settings.SettingsManager
+ settingsManager *settings.SettingsManager
+
+ logBuffer chan *models.RequestLog
+ stopChan chan struct{}
+ wg sync.WaitGroup
+ ctx context.Context
+ cancel context.CancelFunc
+
+ // Metrics
+ totalReceived atomic.Uint64
+ totalFlushed atomic.Uint64
+ totalDropped atomic.Uint64
+ flushCount atomic.Uint64
+ lastFlushTime time.Time
+ lastFlushMutex sync.RWMutex
}
-func NewDBLogWriterService(db *gorm.DB, s store.Store, settings *settings.SettingsManager, logger *logrus.Logger) *DBLogWriterService {
- cfg := settings.GetSettings()
+func NewDBLogWriterService(
+ db *gorm.DB,
+ s store.Store,
+ settingsManager *settings.SettingsManager,
+ logger *logrus.Logger,
+) *DBLogWriterService {
+ cfg := settingsManager.GetSettings()
bufferCapacity := cfg.LogBufferCapacity
if bufferCapacity <= 0 {
bufferCapacity = 1000
}
+
+ ctx, cancel := context.WithCancel(context.Background())
+
return &DBLogWriterService{
db: db,
store: s,
- SettingsManager: settings,
+ settingsManager: settingsManager,
logger: logger.WithField("component", "DBLogWriter📝"),
logBuffer: make(chan *models.RequestLog, bufferCapacity),
stopChan: make(chan struct{}),
+ ctx: ctx,
+ cancel: cancel,
+ lastFlushTime: time.Now(),
}
}
@@ -44,93 +68,276 @@ func (s *DBLogWriterService) Start() {
s.wg.Add(2)
go s.eventListenerLoop()
go s.dbWriterLoop()
- s.logger.Info("DBLogWriterService started.")
+
+ // Periodically report metrics
+ s.wg.Add(1)
+ go s.metricsReporter()
+
+ s.logger.WithFields(logrus.Fields{
+ "buffer_capacity": cap(s.logBuffer),
+ }).Info("DBLogWriterService started")
}
func (s *DBLogWriterService) Stop() {
s.logger.Info("DBLogWriterService stopping...")
close(s.stopChan)
+ s.cancel() // cancel the context
s.wg.Wait()
- s.logger.Info("DBLogWriterService stopped.")
+
+ // Log final statistics
+ s.logger.WithFields(logrus.Fields{
+ "total_received": s.totalReceived.Load(),
+ "total_flushed": s.totalFlushed.Load(),
+ "total_dropped": s.totalDropped.Load(),
+ "flush_count": s.flushCount.Load(),
+ }).Info("DBLogWriterService stopped")
}
+// Event listening loop
func (s *DBLogWriterService) eventListenerLoop() {
defer s.wg.Done()
- ctx := context.Background()
- sub, err := s.store.Subscribe(ctx, models.TopicRequestFinished)
+ sub, err := s.store.Subscribe(s.ctx, models.TopicRequestFinished)
if err != nil {
- s.logger.Fatalf("Failed to subscribe to topic %s: %v", models.TopicRequestFinished, err)
+ s.logger.WithError(err).Error("Failed to subscribe to request events, log writing disabled")
return
}
- defer sub.Close()
+ defer func() {
+ if err := sub.Close(); err != nil {
+ s.logger.WithError(err).Warn("Failed to close subscription")
+ }
+ }()
- s.logger.Info("Subscribed to request events for database logging.")
+ s.logger.Info("Subscribed to request events for database logging")
for {
select {
case msg := <-sub.Channel():
- var event models.RequestFinishedEvent
- if err := json.Unmarshal(msg.Payload, &event); err != nil {
- s.logger.Errorf("Failed to unmarshal event for logging: %v", err)
- continue
- }
- select {
- case s.logBuffer <- &event.RequestLog:
- default:
- s.logger.Warn("Log buffer is full. A log message might be dropped.")
- }
+ s.handleMessage(msg)
+
case <-s.stopChan:
- s.logger.Info("Event listener loop stopping.")
+ s.logger.Info("Event listener loop stopping")
+ close(s.logBuffer)
+ return
+
+ case <-s.ctx.Done():
+ s.logger.Info("Event listener context cancelled")
close(s.logBuffer)
return
}
}
}
+// Handle a single message
+func (s *DBLogWriterService) handleMessage(msg *store.Message) {
+ var event models.RequestFinishedEvent
+ if err := json.Unmarshal(msg.Payload, &event); err != nil {
+ s.logger.WithError(err).Error("Failed to unmarshal request event")
+ return
+ }
+
+ s.totalReceived.Add(1)
+
+ select {
+ case s.logBuffer <- &event.RequestLog:
+ // enqueued successfully
+ default:
+ // Buffer full: drop the log entry
+ dropped := s.totalDropped.Add(1)
+ if dropped%100 == 1 { // warn once per 100 dropped entries
+ s.logger.WithFields(logrus.Fields{
+ "total_dropped": dropped,
+ "buffer_capacity": cap(s.logBuffer),
+ "buffer_len": len(s.logBuffer),
+ }).Warn("Log buffer full, messages being dropped")
+ }
+ }
+}
+
+// Database writer loop
func (s *DBLogWriterService) dbWriterLoop() {
defer s.wg.Done()
- cfg := s.SettingsManager.GetSettings()
+ cfg := s.settingsManager.GetSettings()
batchSize := cfg.LogFlushBatchSize
if batchSize <= 0 {
batchSize = 100
}
- flushTimeout := time.Duration(cfg.LogFlushIntervalSeconds) * time.Second
- if flushTimeout <= 0 {
- flushTimeout = 5 * time.Second
+ flushInterval := time.Duration(cfg.LogFlushIntervalSeconds) * time.Second
+ if flushInterval <= 0 {
+ flushInterval = 5 * time.Second
}
+
+ s.logger.WithFields(logrus.Fields{
+ "batch_size": batchSize,
+ "flush_interval": flushInterval,
+ }).Info("DB writer loop started")
+
batch := make([]*models.RequestLog, 0, batchSize)
- ticker := time.NewTicker(flushTimeout)
+ ticker := time.NewTicker(flushInterval)
defer ticker.Stop()
+
+ // Config hot-reload check (every minute)
+ configTicker := time.NewTicker(1 * time.Minute)
+ defer configTicker.Stop()
+
for {
select {
case logEntry, ok := <-s.logBuffer:
if !ok {
+ // Channel closed: flush the remaining logs
if len(batch) > 0 {
s.flushBatch(batch)
}
- s.logger.Info("DB writer loop finished.")
+ s.logger.Info("DB writer loop finished")
return
}
+
batch = append(batch, logEntry)
if len(batch) >= batchSize {
s.flushBatch(batch)
batch = make([]*models.RequestLog, 0, batchSize)
}
+
case <-ticker.C:
if len(batch) > 0 {
s.flushBatch(batch)
batch = make([]*models.RequestLog, 0, batchSize)
}
+
+ case <-configTicker.C:
+ // Hot-reload configuration
+ cfg := s.settingsManager.GetSettings()
+ newBatchSize := cfg.LogFlushBatchSize
+ if newBatchSize <= 0 {
+ newBatchSize = 100
+ }
+ newFlushInterval := time.Duration(cfg.LogFlushIntervalSeconds) * time.Second
+ if newFlushInterval <= 0 {
+ newFlushInterval = 5 * time.Second
+ }
+
+ if newBatchSize != batchSize {
+ s.logger.WithFields(logrus.Fields{
+ "old": batchSize,
+ "new": newBatchSize,
+ }).Info("Batch size updated")
+ batchSize = newBatchSize
+ if len(batch) >= batchSize {
+ s.flushBatch(batch)
+ batch = make([]*models.RequestLog, 0, batchSize)
+ }
+ }
+
+ if newFlushInterval != flushInterval {
+ s.logger.WithFields(logrus.Fields{
+ "old": flushInterval,
+ "new": newFlushInterval,
+ }).Info("Flush interval updated")
+ flushInterval = newFlushInterval
+ ticker.Reset(flushInterval)
+ }
}
}
}
+// Flush a batch of logs to the database
func (s *DBLogWriterService) flushBatch(batch []*models.RequestLog) {
- if err := s.db.CreateInBatches(batch, len(batch)).Error; err != nil {
- s.logger.WithField("batch_size", len(batch)).WithError(err).Error("Failed to flush log batch to database.")
+ if len(batch) == 0 {
+ return
+ }
+
+ start := time.Now()
+ ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
+ defer cancel()
+
+ err := s.db.WithContext(ctx).CreateInBatches(batch, len(batch)).Error
+ duration := time.Since(start)
+
+ s.lastFlushMutex.Lock()
+ s.lastFlushTime = time.Now()
+ s.lastFlushMutex.Unlock()
+
+ if err != nil {
+ s.logger.WithFields(logrus.Fields{
+ "batch_size": len(batch),
+ "duration": duration,
+ }).WithError(err).Error("Failed to flush log batch to database")
} else {
- s.logger.Infof("Successfully flushed %d logs to database.", len(batch))
+ flushed := s.totalFlushed.Add(uint64(len(batch)))
+ flushCount := s.flushCount.Add(1)
+
+ // Only log at Info level for slow writes or large batches
+ if duration > 1*time.Second || len(batch) > 500 {
+ s.logger.WithFields(logrus.Fields{
+ "batch_size": len(batch),
+ "duration": duration,
+ "total_flushed": flushed,
+ "flush_count": flushCount,
+ }).Info("Log batch flushed to database")
+ } else {
+ s.logger.WithFields(logrus.Fields{
+ "batch_size": len(batch),
+ "duration": duration,
+ }).Debug("Log batch flushed to database")
+ }
+ }
+}
+
+// Periodically report metrics
+func (s *DBLogWriterService) metricsReporter() {
+ defer s.wg.Done()
+
+ ticker := time.NewTicker(5 * time.Minute)
+ defer ticker.Stop()
+
+ for {
+ select {
+ case <-ticker.C:
+ s.reportMetrics()
+ case <-s.stopChan:
+ return
+ case <-s.ctx.Done():
+ return
+ }
+ }
+}
+
+func (s *DBLogWriterService) reportMetrics() {
+ s.lastFlushMutex.RLock()
+ lastFlush := s.lastFlushTime
+ s.lastFlushMutex.RUnlock()
+
+ received := s.totalReceived.Load()
+ flushed := s.totalFlushed.Load()
+ dropped := s.totalDropped.Load()
+ pending := uint64(len(s.logBuffer))
+
+ // Guard against division by zero before any logs have been received.
+ var successRate float64
+ if received > 0 {
+ successRate = float64(flushed) / float64(received) * 100
+ }
+
+ s.logger.WithFields(logrus.Fields{
+ "received": received,
+ "flushed": flushed,
+ "dropped": dropped,
+ "pending": pending,
+ "flush_count": s.flushCount.Load(),
+ "last_flush": time.Since(lastFlush).Round(time.Second),
+ "buffer_usage": float64(pending) / float64(cap(s.logBuffer)) * 100,
+ "success_rate": successRate,
+ }).Info("DBLogWriter metrics")
+}
+
+// GetMetrics returns the current metrics (for monitoring)
+func (s *DBLogWriterService) GetMetrics() map[string]interface{} {
+ s.lastFlushMutex.RLock()
+ lastFlush := s.lastFlushTime
+ s.lastFlushMutex.RUnlock()
+
+ return map[string]interface{}{
+ "total_received": s.totalReceived.Load(),
+ "total_flushed": s.totalFlushed.Load(),
+ "total_dropped": s.totalDropped.Load(),
+ "flush_count": s.flushCount.Load(),
+ "buffer_pending": len(s.logBuffer),
+ "buffer_capacity": cap(s.logBuffer),
+ "last_flush_ago": time.Since(lastFlush).Seconds(),
}
}
diff --git a/internal/service/group_manager.go b/internal/service/group_manager.go
index d8690dd..4ec10eb 100644
--- a/internal/service/group_manager.go
+++ b/internal/service/group_manager.go
@@ -334,7 +334,6 @@ func (gm *GroupManager) CloneKeyGroup(id uint) (*models.KeyGroup, error) {
}
func (gm *GroupManager) BuildOperationalConfig(group *models.KeyGroup) (*models.KeyGroupSettings, error) {
globalSettings := gm.settingsManager.GetSettings()
- defaultModel := "gemini-1.5-flash"
opConfig := &models.KeyGroupSettings{
EnableKeyCheck: &globalSettings.EnableBaseKeyCheck,
KeyCheckConcurrency: &globalSettings.BaseKeyCheckConcurrency,
@@ -342,7 +341,7 @@ func (gm *GroupManager) BuildOperationalConfig(group *models.KeyGroup) (*models.
KeyCheckEndpoint: &globalSettings.DefaultUpstreamURL,
KeyBlacklistThreshold: &globalSettings.BlacklistThreshold,
KeyCooldownMinutes: &globalSettings.KeyCooldownMinutes,
- KeyCheckModel: &defaultModel,
+ KeyCheckModel: &globalSettings.BaseKeyCheckModel,
MaxRetries: &globalSettings.MaxRetries,
EnableSmartGateway: &globalSettings.EnableSmartGateway,
}
diff --git a/internal/service/healthcheck_service.go b/internal/service/healthcheck_service.go
index 36030a2..f28aedc 100644
--- a/internal/service/healthcheck_service.go
+++ b/internal/service/healthcheck_service.go
@@ -1,4 +1,4 @@
-// Filename: internal/service/healthcheck_service.go (最终校准版)
+// Filename: internal/service/healthcheck_service.go
package service
@@ -23,17 +23,30 @@ import (
)
const (
- ProxyCheckTargetURL = "https://www.google.com/generate_204"
- DefaultGeminiEndpoint = "https://generativelanguage.googleapis.com/v1beta/models"
- StatusActive = "active"
- StatusInactive = "inactive"
+ // Business status constants
+ StatusActive = "active"
+ StatusInactive = "inactive"
+
+ // Proxy check target (constant)
+ ProxyCheckTargetURL = "https://www.google.com/generate_204"
+
+ // Concurrency bounds
+ minHealthCheckConcurrency = 1
+ maxHealthCheckConcurrency = 100
+ defaultKeyCheckConcurrency = 5
+ defaultBaseKeyCheckConcurrency = 5
+
+ // Fallback defaults
+ defaultSchedulerIntervalSeconds = 60
+ defaultKeyCheckTimeoutSeconds = 30
+ defaultBaseKeyCheckEndpoint = "https://generativelanguage.googleapis.com/v1beta/models"
)
type HealthCheckServiceLogger struct{ *logrus.Entry }
type HealthCheckService struct {
db *gorm.DB
- SettingsManager *settings.SettingsManager
+ settingsManager *settings.SettingsManager
store store.Store
keyRepo repository.KeyRepository
groupManager *GroupManager
@@ -46,6 +59,7 @@ type HealthCheckService struct {
lastResults map[string]string
groupCheckTimeMutex sync.Mutex
groupNextCheckTime map[uint]time.Time
+ httpClient *http.Client
}
func NewHealthCheckService(
@@ -60,7 +74,7 @@ func NewHealthCheckService(
) *HealthCheckService {
return &HealthCheckService{
db: db,
- SettingsManager: ss,
+ settingsManager: ss,
store: s,
keyRepo: repo,
groupManager: gm,
@@ -70,6 +84,14 @@ func NewHealthCheckService(
stopChan: make(chan struct{}),
lastResults: make(map[string]string),
groupNextCheckTime: make(map[uint]time.Time),
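+ // Shared HTTP client with connection pooling, reused by checkEndpoint; idle connections are released in Stop().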
+ httpClient: &http.Client{
+ Transport: &http.Transport{
+ MaxIdleConns: 100,
+ MaxIdleConnsPerHost: 10,
+ IdleConnTimeout: 90 * time.Second,
+ DisableKeepAlives: false,
+ },
+ },
}
}
@@ -86,6 +108,7 @@ func (s *HealthCheckService) Stop() {
s.logger.Info("Stopping HealthCheckService...")
close(s.stopChan)
s.wg.Wait()
+ s.httpClient.CloseIdleConnections()
s.logger.Info("HealthCheckService stopped gracefully.")
}
@@ -99,10 +122,19 @@ func (s *HealthCheckService) GetLastHealthCheckResults() map[string]string {
return resultsCopy
}
+// ==================== Key Check Loop ====================
+
func (s *HealthCheckService) runKeyCheckLoop() {
defer s.wg.Done()
- s.logger.Info("Key check dynamic scheduler loop started.")
- ticker := time.NewTicker(1 * time.Minute)
+
+ settings := s.settingsManager.GetSettings()
+ schedulerInterval := time.Duration(settings.KeyCheckSchedulerIntervalSeconds) * time.Second
+ if schedulerInterval <= 0 {
+ schedulerInterval = time.Duration(defaultSchedulerIntervalSeconds) * time.Second
+ }
+
+ s.logger.Infof("Key check dynamic scheduler loop started with interval: %v", schedulerInterval)
+ ticker := time.NewTicker(schedulerInterval)
defer ticker.Stop()
for {
@@ -129,9 +161,11 @@ func (s *HealthCheckService) scheduleKeyChecks() {
s.logger.WithError(err).WithField("group_id", group.ID).Error("Failed to build operational config for group, skipping health check scheduling.")
continue
}
+
if opConfig.EnableKeyCheck == nil || !*opConfig.EnableKeyCheck {
continue
}
+
var intervalMinutes int
if opConfig.KeyCheckIntervalMinutes != nil {
intervalMinutes = *opConfig.KeyCheckIntervalMinutes
@@ -140,64 +174,41 @@ func (s *HealthCheckService) scheduleKeyChecks() {
if interval <= 0 {
continue
}
+
if nextCheckTime, ok := s.groupNextCheckTime[group.ID]; !ok || now.After(nextCheckTime) {
- s.logger.Infof("Scheduling key check for group '%s' (ID: %d)", group.Name, group.ID)
- go s.performKeyChecksForGroup(group, opConfig)
+ s.logger.WithFields(logrus.Fields{
+ "group_id": group.ID,
+ "group_name": group.Name,
+ "interval": interval,
+ }).Info("Scheduling key check for group")
+
+ // Bound each group's check with a timeout equal to its check interval
+ ctx, cancel := context.WithTimeout(context.Background(), interval)
+ go func(g *models.KeyGroup, cfg *models.KeyGroupSettings) {
+ defer cancel()
+ select {
+ case <-s.stopChan:
+ return
+ default:
+ s.performKeyChecksForGroup(ctx, g, cfg)
+ }
+ }(group, opConfig)
+
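+ // Record the next check time immediately so the group is not rescheduled while this check is still running.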
s.groupNextCheckTime[group.ID] = now.Add(interval)
}
}
}
-func (s *HealthCheckService) runUpstreamCheckLoop() {
- defer s.wg.Done()
- s.logger.Info("Upstream check loop started.")
- if s.SettingsManager.GetSettings().EnableUpstreamCheck {
- s.performUpstreamChecks()
- }
- ticker := time.NewTicker(time.Duration(s.SettingsManager.GetSettings().UpstreamCheckIntervalSeconds) * time.Second)
- defer ticker.Stop()
-
- for {
- select {
- case <-ticker.C:
- if s.SettingsManager.GetSettings().EnableUpstreamCheck {
- s.logger.Debug("Upstream check ticker fired.")
- s.performUpstreamChecks()
- }
- case <-s.stopChan:
- s.logger.Info("Upstream check loop stopped.")
- return
- }
- }
-}
-
-func (s *HealthCheckService) runProxyCheckLoop() {
- defer s.wg.Done()
- s.logger.Info("Proxy check loop started.")
- if s.SettingsManager.GetSettings().EnableProxyCheck {
- s.performProxyChecks()
- }
- ticker := time.NewTicker(time.Duration(s.SettingsManager.GetSettings().ProxyCheckIntervalSeconds) * time.Second)
- defer ticker.Stop()
-
- for {
- select {
- case <-ticker.C:
- if s.SettingsManager.GetSettings().EnableProxyCheck {
- s.logger.Debug("Proxy check ticker fired.")
- s.performProxyChecks()
- }
- case <-s.stopChan:
- s.logger.Info("Proxy check loop stopped.")
- return
- }
- }
-}
-
-func (s *HealthCheckService) performKeyChecksForGroup(group *models.KeyGroup, opConfig *models.KeyGroupSettings) {
- ctx := context.Background()
- settings := s.SettingsManager.GetSettings()
+func (s *HealthCheckService) performKeyChecksForGroup(
+ ctx context.Context,
+ group *models.KeyGroup,
+ opConfig *models.KeyGroupSettings,
+) {
+ settings := s.settingsManager.GetSettings()
timeout := time.Duration(settings.KeyCheckTimeoutSeconds) * time.Second
+ if timeout <= 0 {
+ timeout = time.Duration(defaultKeyCheckTimeoutSeconds) * time.Second
+ }
endpoint, err := s.groupManager.BuildKeyCheckEndpoint(group.ID)
if err != nil {
@@ -205,57 +216,83 @@ func (s *HealthCheckService) performKeyChecksForGroup(group *models.KeyGroup, op
return
}
- log := s.logger.WithFields(logrus.Fields{"group_id": group.ID, "group_name": group.Name})
- log.Infof("Starting key health check cycle.")
+ log := s.logger.WithFields(logrus.Fields{
+ "group_id": group.ID,
+ "group_name": group.Name,
+ })
+
+ log.Info("Starting key health check cycle")
+
var mappingsToCheck []models.GroupAPIKeyMapping
err = s.db.WithContext(ctx).Model(&models.GroupAPIKeyMapping{}).
Joins("JOIN api_keys ON api_keys.id = group_api_key_mappings.api_key_id").
Where("group_api_key_mappings.key_group_id = ?", group.ID).
Where("api_keys.master_status = ?", models.MasterStatusActive).
- Where("group_api_key_mappings.status IN ?", []models.APIKeyStatus{models.StatusActive, models.StatusDisabled, models.StatusCooldown}).
+ Where("group_api_key_mappings.status IN ?", []models.APIKeyStatus{
+ models.StatusActive,
+ models.StatusDisabled,
+ models.StatusCooldown,
+ }).
Preload("APIKey").
Find(&mappingsToCheck).Error
if err != nil {
- log.WithError(err).Error("Failed to fetch key mappings for health check.")
+ log.WithError(err).Error("Failed to fetch key mappings for health check")
return
}
+
if len(mappingsToCheck) == 0 {
- log.Info("No key mappings to check for this group.")
+ log.Info("No key mappings to check for this group")
return
}
- log.Infof("Starting health check for %d key mappings.", len(mappingsToCheck))
+
+ log.WithField("key_count", len(mappingsToCheck)).Info("Starting health check for key mappings")
+
jobs := make(chan models.GroupAPIKeyMapping, len(mappingsToCheck))
var wg sync.WaitGroup
- var concurrency int
- if opConfig.KeyCheckConcurrency != nil {
- concurrency = *opConfig.KeyCheckConcurrency
- }
- if concurrency <= 0 {
- concurrency = 1
- }
+
+ concurrency := s.getConcurrency(opConfig.KeyCheckConcurrency, defaultKeyCheckConcurrency)
+ log.WithField("concurrency", concurrency).Debug("Using concurrency for key check")
+
for w := 1; w <= concurrency; w++ {
wg.Add(1)
go func(workerID int) {
defer wg.Done()
for mapping := range jobs {
- s.checkAndProcessMapping(ctx, &mapping, timeout, endpoint)
+ select {
+ case <-ctx.Done():
+ log.Warn("Context cancelled, stopping worker")
+ return
+ default:
+ s.checkAndProcessMapping(ctx, &mapping, timeout, endpoint)
+ }
}
}(w)
}
+
for _, m := range mappingsToCheck {
jobs <- m
}
close(jobs)
wg.Wait()
- log.Info("Finished key health check cycle.")
+
+ log.Info("Finished key health check cycle")
}
-func (s *HealthCheckService) checkAndProcessMapping(ctx context.Context, mapping *models.GroupAPIKeyMapping, timeout time.Duration, endpoint string) {
+func (s *HealthCheckService) checkAndProcessMapping(
+ ctx context.Context,
+ mapping *models.GroupAPIKeyMapping,
+ timeout time.Duration,
+ endpoint string,
+) {
if mapping.APIKey == nil {
- s.logger.Warnf("Skipping check for mapping (G:%d, K:%d) because associated APIKey is nil.", mapping.KeyGroupID, mapping.APIKeyID)
+ s.logger.WithFields(logrus.Fields{
+ "group_id": mapping.KeyGroupID,
+ "key_id": mapping.APIKeyID,
+ }).Warn("Skipping check for mapping because associated APIKey is nil")
return
}
+
validationErr := s.keyValidationService.ValidateSingleKey(mapping.APIKey, timeout, endpoint)
if validationErr == nil {
if mapping.Status != models.StatusActive {
@@ -263,17 +300,28 @@ func (s *HealthCheckService) checkAndProcessMapping(ctx context.Context, mapping
}
return
}
+
errorString := validationErr.Error()
if CustomErrors.IsPermanentUpstreamError(errorString) {
s.revokeMapping(ctx, mapping, validationErr)
return
}
+
if CustomErrors.IsTemporaryUpstreamError(errorString) {
- s.logger.Warnf("Health check for key %d failed with temporary error, applying penalty. Reason: %v", mapping.APIKeyID, validationErr)
+ s.logger.WithFields(logrus.Fields{
+ "key_id": mapping.APIKeyID,
+ "group_id": mapping.KeyGroupID,
+ "error": validationErr.Error(),
+ }).Warn("Health check failed with temporary error, applying penalty")
s.penalizeMapping(ctx, mapping, validationErr)
return
}
- s.logger.Errorf("Health check for key %d failed with a transient or unknown upstream error: %v. Mapping will not be penalized by health check.", mapping.APIKeyID, validationErr)
+
+ s.logger.WithFields(logrus.Fields{
+ "key_id": mapping.APIKeyID,
+ "group_id": mapping.KeyGroupID,
+ "error": validationErr.Error(),
+ }).Error("Health check failed with transient or unknown upstream error, mapping will not be penalized")
}
func (s *HealthCheckService) activateMapping(ctx context.Context, mapping *models.GroupAPIKeyMapping) {
@@ -283,39 +331,64 @@ func (s *HealthCheckService) activateMapping(ctx context.Context, mapping *model
mapping.LastError = ""
if err := s.keyRepo.UpdateMapping(ctx, mapping); err != nil {
- s.logger.WithError(err).Errorf("Failed to activate mapping (G:%d, K:%d).", mapping.KeyGroupID, mapping.APIKeyID)
+ s.logger.WithError(err).WithFields(logrus.Fields{
+ "group_id": mapping.KeyGroupID,
+ "key_id": mapping.APIKeyID,
+ }).Error("Failed to activate mapping")
return
}
- s.logger.Infof("Mapping (G:%d, K:%d) successfully activated from %s to ACTIVE.", mapping.KeyGroupID, mapping.APIKeyID, oldStatus)
+
+ s.logger.WithFields(logrus.Fields{
+ "group_id": mapping.KeyGroupID,
+ "key_id": mapping.APIKeyID,
+ "old_status": oldStatus,
+ "new_status": mapping.Status,
+ }).Info("Mapping successfully activated")
+
s.publishKeyStatusChangedEvent(ctx, mapping.KeyGroupID, mapping.APIKeyID, oldStatus, mapping.Status)
}
func (s *HealthCheckService) penalizeMapping(ctx context.Context, mapping *models.GroupAPIKeyMapping, err error) {
group, ok := s.groupManager.GetGroupByID(mapping.KeyGroupID)
if !ok {
- s.logger.Errorf("Could not find group with ID %d to apply penalty.", mapping.KeyGroupID)
+ s.logger.WithField("group_id", mapping.KeyGroupID).Error("Could not find group to apply penalty")
return
}
+
opConfig, buildErr := s.groupManager.BuildOperationalConfig(group)
if buildErr != nil {
- s.logger.WithError(buildErr).Errorf("Failed to build operational config for group %d during penalty.", mapping.KeyGroupID)
+ s.logger.WithError(buildErr).WithField("group_id", mapping.KeyGroupID).Error("Failed to build operational config for group during penalty")
return
}
+
oldStatus := mapping.Status
mapping.LastError = err.Error()
mapping.ConsecutiveErrorCount++
+
threshold := *opConfig.KeyBlacklistThreshold
if mapping.ConsecutiveErrorCount >= threshold {
mapping.Status = models.StatusCooldown
cooldownDuration := time.Duration(*opConfig.KeyCooldownMinutes) * time.Minute
cooldownTime := time.Now().Add(cooldownDuration)
mapping.CooldownUntil = &cooldownTime
- s.logger.Warnf("Mapping (G:%d, K:%d) reached error threshold (%d) during health check and is now in COOLDOWN for %v.", mapping.KeyGroupID, mapping.APIKeyID, threshold, cooldownDuration)
+
+ s.logger.WithFields(logrus.Fields{
+ "group_id": mapping.KeyGroupID,
+ "key_id": mapping.APIKeyID,
+ "error_count": mapping.ConsecutiveErrorCount,
+ "threshold": threshold,
+ "cooldown_duration": cooldownDuration,
+ }).Warn("Mapping reached error threshold and is now in COOLDOWN")
}
+
if errDb := s.keyRepo.UpdateMapping(ctx, mapping); errDb != nil {
- s.logger.WithError(errDb).Errorf("Failed to penalize mapping (G:%d, K:%d).", mapping.KeyGroupID, mapping.APIKeyID)
+ s.logger.WithError(errDb).WithFields(logrus.Fields{
+ "group_id": mapping.KeyGroupID,
+ "key_id": mapping.APIKeyID,
+ }).Error("Failed to penalize mapping")
return
}
+
if oldStatus != mapping.Status {
s.publishKeyStatusChangedEvent(ctx, mapping.KeyGroupID, mapping.APIKeyID, oldStatus, mapping.Status)
}
@@ -326,144 +399,500 @@ func (s *HealthCheckService) revokeMapping(ctx context.Context, mapping *models.
if oldStatus == models.StatusBanned {
return
}
+
mapping.Status = models.StatusBanned
mapping.LastError = "Definitive error: " + err.Error()
mapping.ConsecutiveErrorCount = 0
+
if errDb := s.keyRepo.UpdateMapping(ctx, mapping); errDb != nil {
- s.logger.WithError(errDb).Errorf("Failed to revoke mapping (G:%d, K:%d).", mapping.KeyGroupID, mapping.APIKeyID)
+ s.logger.WithError(errDb).WithFields(logrus.Fields{
+ "group_id": mapping.KeyGroupID,
+ "key_id": mapping.APIKeyID,
+ }).Error("Failed to revoke mapping")
return
}
- s.logger.Warnf("Mapping (G:%d, K:%d) has been BANNED due to a definitive error: %v", mapping.KeyGroupID, mapping.APIKeyID, err)
+
+ s.logger.WithFields(logrus.Fields{
+ "group_id": mapping.KeyGroupID,
+ "key_id": mapping.APIKeyID,
+ "error": err.Error(),
+ }).Warn("Mapping has been BANNED due to definitive error")
+
s.publishKeyStatusChangedEvent(ctx, mapping.KeyGroupID, mapping.APIKeyID, oldStatus, mapping.Status)
- s.logger.Infof("Triggering MasterStatus update for definitively failed key ID %d.", mapping.APIKeyID)
+
+ s.logger.WithField("key_id", mapping.APIKeyID).Info("Triggering MasterStatus update for definitively failed key")
if err := s.keyRepo.UpdateAPIKeyStatus(ctx, mapping.APIKeyID, models.MasterStatusRevoked); err != nil {
- s.logger.WithError(err).Errorf("Failed to update master status for key ID %d after group-level ban.", mapping.APIKeyID)
+ s.logger.WithError(err).WithField("key_id", mapping.APIKeyID).Error("Failed to update master status after group-level ban")
+ }
+}
+
+// ==================== Upstream Check Loop ====================
+
+func (s *HealthCheckService) runUpstreamCheckLoop() {
+ defer s.wg.Done()
+ s.logger.Info("Upstream check loop started")
+
+ settings := s.settingsManager.GetSettings()
+ if settings.EnableUpstreamCheck {
+ s.performUpstreamChecks()
+ }
+
+ interval := time.Duration(settings.UpstreamCheckIntervalSeconds) * time.Second
+ if interval <= 0 {
+ interval = 300 * time.Second // fall back to 5 minutes
+ }
+
+ ticker := time.NewTicker(interval)
+ defer ticker.Stop()
+
+ for {
+ select {
+ case <-ticker.C:
+ settings := s.settingsManager.GetSettings()
+ if settings.EnableUpstreamCheck {
+ s.logger.Debug("Upstream check ticker fired")
+ s.performUpstreamChecks()
+ }
+ case <-s.stopChan:
+ s.logger.Info("Upstream check loop stopped")
+ return
+ }
}
}
func (s *HealthCheckService) performUpstreamChecks() {
ctx := context.Background()
- settings := s.SettingsManager.GetSettings()
+ settings := s.settingsManager.GetSettings()
timeout := time.Duration(settings.UpstreamCheckTimeoutSeconds) * time.Second
+ if timeout <= 0 {
+ timeout = 10 * time.Second
+ }
+
var upstreams []*models.UpstreamEndpoint
if err := s.db.WithContext(ctx).Find(&upstreams).Error; err != nil {
- s.logger.WithError(err).Error("Failed to retrieve upstreams.")
+ s.logger.WithError(err).Error("Failed to retrieve upstreams")
return
}
+
if len(upstreams) == 0 {
+ s.logger.Debug("No upstreams configured for health check")
return
}
- s.logger.Infof("Starting validation for %d upstreams.", len(upstreams))
+
+ s.logger.WithFields(logrus.Fields{
+ "count": len(upstreams),
+ "timeout": timeout,
+ }).Info("Starting upstream validation")
+
+ type checkResult struct {
+ upstreamID uint
+ url string
+ oldStatus string
+ newStatus string
+ changed bool
+ err error
+ }
+
+ results := make([]checkResult, 0, len(upstreams))
+ var resultsMutex sync.Mutex
var wg sync.WaitGroup
+
for _, u := range upstreams {
wg.Add(1)
go func(upstream *models.UpstreamEndpoint) {
defer wg.Done()
+
oldStatus := upstream.Status
isAlive := s.checkEndpoint(upstream.URL, timeout)
newStatus := StatusInactive
if isAlive {
newStatus = StatusActive
}
+
s.lastResultsMutex.Lock()
s.lastResults[upstream.URL] = newStatus
s.lastResultsMutex.Unlock()
- if oldStatus != newStatus {
- s.logger.Infof("Upstream '%s' status changed from %s -> %s", upstream.URL, oldStatus, newStatus)
+
+ result := checkResult{
+ upstreamID: upstream.ID,
+ url: upstream.URL,
+ oldStatus: oldStatus,
+ newStatus: newStatus,
+ changed: oldStatus != newStatus,
+ }
+
+ if result.changed {
+ s.logger.WithFields(logrus.Fields{
+ "upstream_id": upstream.ID,
+ "url": upstream.URL,
+ "old_status": oldStatus,
+ "new_status": newStatus,
+ }).Info("Upstream status changed")
+
if err := s.db.WithContext(ctx).Model(upstream).Update("status", newStatus).Error; err != nil {
- s.logger.WithError(err).WithField("upstream_id", upstream.ID).Error("Failed to update upstream status.")
+ s.logger.WithError(err).WithField("upstream_id", upstream.ID).Error("Failed to update upstream status")
+ result.err = err
} else {
s.publishUpstreamHealthChangedEvent(ctx, upstream, oldStatus, newStatus)
}
+ } else {
+ s.logger.WithFields(logrus.Fields{
+ "upstream_id": upstream.ID,
+ "url": upstream.URL,
+ "status": newStatus,
+ }).Debug("Upstream status unchanged")
}
+
+ resultsMutex.Lock()
+ results = append(results, result)
+ resultsMutex.Unlock()
}(u)
}
wg.Wait()
+
+ // Aggregate the results
+ activeCount := 0
+ inactiveCount := 0
+ changedCount := 0
+ errorCount := 0
+
+ for _, r := range results {
+ if r.changed {
+ changedCount++
+ }
+ if r.err != nil {
+ errorCount++
+ }
+ if r.newStatus == StatusActive {
+ activeCount++
+ } else {
+ inactiveCount++
+ }
+ }
+
+ s.logger.WithFields(logrus.Fields{
+ "total": len(upstreams),
+ "active": activeCount,
+ "inactive": inactiveCount,
+ "changed": changedCount,
+ "errors": errorCount,
+ }).Info("Upstream validation cycle completed")
}
func (s *HealthCheckService) checkEndpoint(urlStr string, timeout time.Duration) bool {
- client := http.Client{Timeout: timeout}
- resp, err := client.Head(urlStr)
+ ctx, cancel := context.WithTimeout(context.Background(), timeout)
+ defer cancel()
+
+ req, err := http.NewRequestWithContext(ctx, http.MethodHead, urlStr, nil)
+ if err != nil {
+ s.logger.WithError(err).WithField("url", urlStr).Debug("Failed to create request for endpoint check")
+ return false
+ }
+
+ resp, err := s.httpClient.Do(req)
if err != nil {
return false
}
defer resp.Body.Close()
+
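+ // Any response below 500 (including 4xx) counts as the endpoint being reachable.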
return resp.StatusCode < http.StatusInternalServerError
}
+// ==================== Proxy Check Loop ====================
+
+func (s *HealthCheckService) runProxyCheckLoop() {
+ defer s.wg.Done()
+ s.logger.Info("Proxy check loop started")
+
+ settings := s.settingsManager.GetSettings()
+ if settings.EnableProxyCheck {
+ s.performProxyChecks()
+ }
+
+ interval := time.Duration(settings.ProxyCheckIntervalSeconds) * time.Second
+ if interval <= 0 {
+ interval = 600 * time.Second // fall back to 10 minutes
+ }
+
+ ticker := time.NewTicker(interval)
+ defer ticker.Stop()
+
+ for {
+ select {
+ case <-ticker.C:
+ settings := s.settingsManager.GetSettings()
+ if settings.EnableProxyCheck {
+ s.logger.Debug("Proxy check ticker fired")
+ s.performProxyChecks()
+ }
+ case <-s.stopChan:
+ s.logger.Info("Proxy check loop stopped")
+ return
+ }
+ }
+}
+
func (s *HealthCheckService) performProxyChecks() {
ctx := context.Background()
- settings := s.SettingsManager.GetSettings()
+ settings := s.settingsManager.GetSettings()
timeout := time.Duration(settings.ProxyCheckTimeoutSeconds) * time.Second
+ if timeout <= 0 {
+ timeout = 15 * time.Second
+ }
+
var proxies []*models.ProxyConfig
if err := s.db.WithContext(ctx).Find(&proxies).Error; err != nil {
- s.logger.WithError(err).Error("Failed to retrieve proxies.")
+ s.logger.WithError(err).Error("Failed to retrieve proxies")
return
}
+
if len(proxies) == 0 {
+ s.logger.Debug("No proxies configured for health check")
return
}
- s.logger.Infof("Starting validation for %d proxies.", len(proxies))
+
+ s.logger.WithFields(logrus.Fields{
+ "count": len(proxies),
+ "timeout": timeout,
+ }).Info("Starting proxy validation")
+
+ activeCount := 0
+ inactiveCount := 0
+ changedCount := 0
+ var statsMutex sync.Mutex
var wg sync.WaitGroup
+
for _, p := range proxies {
wg.Add(1)
go func(proxyCfg *models.ProxyConfig) {
defer wg.Done()
+
+ oldStatus := proxyCfg.Status
isAlive := s.checkProxy(proxyCfg, timeout)
newStatus := StatusInactive
if isAlive {
newStatus = StatusActive
}
+
s.lastResultsMutex.Lock()
s.lastResults[proxyCfg.Address] = newStatus
s.lastResultsMutex.Unlock()
- if proxyCfg.Status != newStatus {
- s.logger.Infof("Proxy '%s' status changed from %s -> %s", proxyCfg.Address, proxyCfg.Status, newStatus)
+
+ if oldStatus != newStatus {
+ s.logger.WithFields(logrus.Fields{
+ "proxy_id": proxyCfg.ID,
+ "address": proxyCfg.Address,
+ "old_status": oldStatus,
+ "new_status": newStatus,
+ }).Info("Proxy status changed")
+
if err := s.db.WithContext(ctx).Model(proxyCfg).Update("status", newStatus).Error; err != nil {
- s.logger.WithError(err).WithField("proxy_id", proxyCfg.ID).Error("Failed to update proxy status.")
+ s.logger.WithError(err).WithField("proxy_id", proxyCfg.ID).Error("Failed to update proxy status")
}
+
+ statsMutex.Lock()
+ changedCount++
+ if newStatus == StatusActive {
+ activeCount++
+ } else {
+ inactiveCount++
+ }
+ statsMutex.Unlock()
+ } else {
+ s.logger.WithFields(logrus.Fields{
+ "proxy_id": proxyCfg.ID,
+ "address": proxyCfg.Address,
+ "status": newStatus,
+ }).Debug("Proxy status unchanged")
+
+ statsMutex.Lock()
+ if newStatus == StatusActive {
+ activeCount++
+ } else {
+ inactiveCount++
+ }
+ statsMutex.Unlock()
}
}(p)
}
wg.Wait()
+
+ s.logger.WithFields(logrus.Fields{
+ "total": len(proxies),
+ "active": activeCount,
+ "inactive": inactiveCount,
+ "changed": changedCount,
+ }).Info("Proxy validation cycle completed")
}
func (s *HealthCheckService) checkProxy(proxyCfg *models.ProxyConfig, timeout time.Duration) bool {
transport := &http.Transport{}
+
switch proxyCfg.Protocol {
case "http", "https":
proxyUrl, err := url.Parse(proxyCfg.Protocol + "://" + proxyCfg.Address)
if err != nil {
- s.logger.WithError(err).WithField("proxy_address", proxyCfg.Address).Warn("Invalid proxy URL format.")
+ s.logger.WithError(err).WithField("proxy_address", proxyCfg.Address).Warn("Invalid proxy URL format")
return false
}
transport.Proxy = http.ProxyURL(proxyUrl)
+
case "socks5":
dialer, err := proxy.SOCKS5("tcp", proxyCfg.Address, nil, proxy.Direct)
if err != nil {
- s.logger.WithError(err).WithField("proxy_address", proxyCfg.Address).Warn("Failed to create SOCKS5 dialer.")
+ s.logger.WithError(err).WithField("proxy_address", proxyCfg.Address).Warn("Failed to create SOCKS5 dialer")
return false
}
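+ // Note: the SOCKS5 dialer is invoked without the request context, so cancellation relies on the client timeout below.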
transport.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) {
return dialer.Dial(network, addr)
}
+
default:
- s.logger.WithField("protocol", proxyCfg.Protocol).Warn("Unsupported proxy protocol.")
+ s.logger.WithField("protocol", proxyCfg.Protocol).Warn("Unsupported proxy protocol")
return false
}
+
client := &http.Client{
Transport: transport,
Timeout: timeout,
}
+
resp, err := client.Get(ProxyCheckTargetURL)
if err != nil {
return false
}
defer resp.Body.Close()
+
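+ // Reaching the target through the proxy is enough; the response status code is not inspected.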
return true
}
-func (s *HealthCheckService) publishKeyStatusChangedEvent(ctx context.Context, groupID, keyID uint, oldStatus, newStatus models.APIKeyStatus) {
+// ==================== Base Key Check Loop ====================
+
+func (s *HealthCheckService) runBaseKeyCheckLoop() {
+ defer s.wg.Done()
+ s.logger.Info("Global base key check loop started")
+
+ settings := s.settingsManager.GetSettings()
+ if !settings.EnableBaseKeyCheck {
+ s.logger.Info("Global base key check is disabled")
+ return
+ }
+
+ // Run one check immediately at startup
+ s.performBaseKeyChecks()
+
+ interval := time.Duration(settings.BaseKeyCheckIntervalMinutes) * time.Minute
+ if interval <= 0 {
+ s.logger.WithField("interval", settings.BaseKeyCheckIntervalMinutes).Warn("Invalid BaseKeyCheckIntervalMinutes, disabling base key check loop")
+ return
+ }
+
+ ticker := time.NewTicker(interval)
+ defer ticker.Stop()
+
+ for {
+ select {
+ case <-ticker.C:
+ s.performBaseKeyChecks()
+ case <-s.stopChan:
+ s.logger.Info("Global base key check loop stopped")
+ return
+ }
+ }
+}
+
+func (s *HealthCheckService) performBaseKeyChecks() {
+ ctx := context.Background()
+ s.logger.Info("Starting global base key check cycle")
+
+ settings := s.settingsManager.GetSettings()
+ timeout := time.Duration(settings.KeyCheckTimeoutSeconds) * time.Second
+ if timeout <= 0 {
+ timeout = time.Duration(defaultKeyCheckTimeoutSeconds) * time.Second
+ }
+
+ endpoint := settings.BaseKeyCheckEndpoint
+ if endpoint == "" {
+ endpoint = defaultBaseKeyCheckEndpoint
+ s.logger.WithField("endpoint", endpoint).Debug("Using default base key check endpoint")
+ }
+
+ concurrency := settings.BaseKeyCheckConcurrency
+ if concurrency <= 0 {
+ concurrency = defaultBaseKeyCheckConcurrency
+ }
+ concurrency = s.ensureConcurrencyBounds(concurrency)
+
+ keys, err := s.keyRepo.GetActiveMasterKeys()
+ if err != nil {
+ s.logger.WithError(err).Error("Failed to fetch active master keys for base check")
+ return
+ }
+
+ if len(keys) == 0 {
+ s.logger.Info("No active master keys to perform base check on")
+ return
+ }
+
+ s.logger.WithFields(logrus.Fields{
+ "key_count": len(keys),
+ "concurrency": concurrency,
+ "endpoint": endpoint,
+ }).Info("Performing base check on active master keys")
+
+ jobs := make(chan *models.APIKey, len(keys))
+ var wg sync.WaitGroup
+
+ for w := 0; w < concurrency; w++ {
+ wg.Add(1)
+ go func(workerID int) {
+ defer wg.Done()
+ for key := range jobs {
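+ // Stop draining jobs early if the service is shutting down.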
+ select {
+ case <-s.stopChan:
+ return
+ default:
+ err := s.keyValidationService.ValidateSingleKey(key, timeout, endpoint)
+ if err != nil && CustomErrors.IsPermanentUpstreamError(err.Error()) {
+ oldStatus := key.MasterStatus
+ keyPrefix := key.APIKey
+ if len(keyPrefix) > 8 {
+ keyPrefix = keyPrefix[:8]
+ }
+
+ s.logger.WithFields(logrus.Fields{
+ "key_id": key.ID,
+ "key_prefix": keyPrefix + "...",
+ "error": err.Error(),
+ }).Warn("Key failed definitive base check, setting MasterStatus to REVOKED")
+
+ if updateErr := s.keyRepo.UpdateAPIKeyStatus(ctx, key.ID, models.MasterStatusRevoked); updateErr != nil {
+ s.logger.WithError(updateErr).WithField("key_id", key.ID).Error("Failed to update master status")
+ } else {
+ s.publishMasterKeyStatusChangedEvent(ctx, key.ID, oldStatus, models.MasterStatusRevoked)
+ }
+ }
+ }
+ }
+ }(w)
+ }
+
+ for _, key := range keys {
+ jobs <- key
+ }
+ close(jobs)
+ wg.Wait()
+
+ s.logger.Info("Global base key check cycle finished")
+}
+
+// ==================== Event Publishing ====================
+
+func (s *HealthCheckService) publishKeyStatusChangedEvent(
+ ctx context.Context,
+ groupID, keyID uint,
+ oldStatus, newStatus models.APIKeyStatus,
+) {
event := models.KeyStatusChangedEvent{
KeyID: keyID,
GroupID: groupID,
@@ -472,17 +901,23 @@ func (s *HealthCheckService) publishKeyStatusChangedEvent(ctx context.Context, g
ChangeReason: "health_check",
ChangedAt: time.Now(),
}
+
payload, err := json.Marshal(event)
if err != nil {
- s.logger.WithError(err).Errorf("Failed to marshal KeyStatusChangedEvent for group %d.", groupID)
+ s.logger.WithError(err).WithField("group_id", groupID).Error("Failed to marshal KeyStatusChangedEvent")
return
}
+
if err := s.store.Publish(ctx, models.TopicKeyStatusChanged, payload); err != nil {
- s.logger.WithError(err).Errorf("Failed to publish KeyStatusChangedEvent for group %d.", groupID)
+ s.logger.WithError(err).WithField("group_id", groupID).Error("Failed to publish KeyStatusChangedEvent")
}
}
-func (s *HealthCheckService) publishUpstreamHealthChangedEvent(ctx context.Context, upstream *models.UpstreamEndpoint, oldStatus, newStatus string) {
+func (s *HealthCheckService) publishUpstreamHealthChangedEvent(
+ ctx context.Context,
+ upstream *models.UpstreamEndpoint,
+ oldStatus, newStatus string,
+) {
event := models.UpstreamHealthChangedEvent{
UpstreamID: upstream.ID,
UpstreamURL: upstream.URL,
@@ -492,93 +927,23 @@ func (s *HealthCheckService) publishUpstreamHealthChangedEvent(ctx context.Conte
Reason: "health_check",
CheckedAt: time.Now(),
}
+
payload, err := json.Marshal(event)
if err != nil {
- s.logger.WithError(err).Error("Failed to marshal UpstreamHealthChangedEvent.")
+ s.logger.WithError(err).Error("Failed to marshal UpstreamHealthChangedEvent")
return
}
+
if err := s.store.Publish(ctx, models.TopicUpstreamHealthChanged, payload); err != nil {
- s.logger.WithError(err).Error("Failed to publish UpstreamHealthChangedEvent.")
+ s.logger.WithError(err).Error("Failed to publish UpstreamHealthChangedEvent")
}
}
-func (s *HealthCheckService) runBaseKeyCheckLoop() {
- defer s.wg.Done()
- s.logger.Info("Global base key check loop started.")
- settings := s.SettingsManager.GetSettings()
- if !settings.EnableBaseKeyCheck {
- s.logger.Info("Global base key check is disabled.")
- return
- }
- s.performBaseKeyChecks()
- interval := time.Duration(settings.BaseKeyCheckIntervalMinutes) * time.Minute
- if interval <= 0 {
- s.logger.Warnf("Invalid BaseKeyCheckIntervalMinutes: %d. Disabling base key check loop.", settings.BaseKeyCheckIntervalMinutes)
- return
- }
- ticker := time.NewTicker(interval)
- defer ticker.Stop()
-
- for {
- select {
- case <-ticker.C:
- s.performBaseKeyChecks()
- case <-s.stopChan:
- s.logger.Info("Global base key check loop stopped.")
- return
- }
- }
-}
-
-func (s *HealthCheckService) performBaseKeyChecks() {
- ctx := context.Background()
- s.logger.Info("Starting global base key check cycle.")
- settings := s.SettingsManager.GetSettings()
- timeout := time.Duration(settings.KeyCheckTimeoutSeconds) * time.Second
- endpoint := settings.BaseKeyCheckEndpoint
- concurrency := settings.BaseKeyCheckConcurrency
- keys, err := s.keyRepo.GetActiveMasterKeys()
- if err != nil {
- s.logger.WithError(err).Error("Failed to fetch active master keys for base check.")
- return
- }
- if len(keys) == 0 {
- s.logger.Info("No active master keys to perform base check on.")
- return
- }
- s.logger.Infof("Performing base check on %d active master keys.", len(keys))
- jobs := make(chan *models.APIKey, len(keys))
- var wg sync.WaitGroup
- if concurrency <= 0 {
- concurrency = 5
- }
- for w := 0; w < concurrency; w++ {
- wg.Add(1)
- go func() {
- defer wg.Done()
- for key := range jobs {
- err := s.keyValidationService.ValidateSingleKey(key, timeout, endpoint)
- if err != nil && CustomErrors.IsPermanentUpstreamError(err.Error()) {
- oldStatus := key.MasterStatus
- s.logger.Warnf("Key ID %d (%s...) failed definitive base check. Setting MasterStatus to REVOKED. Reason: %v", key.ID, key.APIKey[:4], err)
- if updateErr := s.keyRepo.UpdateAPIKeyStatus(ctx, key.ID, models.MasterStatusRevoked); updateErr != nil {
- s.logger.WithError(updateErr).Errorf("Failed to update master status for key ID %d.", key.ID)
- } else {
- s.publishMasterKeyStatusChangedEvent(ctx, key.ID, oldStatus, models.MasterStatusRevoked)
- }
- }
- }
- }()
- }
- for _, key := range keys {
- jobs <- key
- }
- close(jobs)
- wg.Wait()
- s.logger.Info("Global base key check cycle finished.")
-}
-
-func (s *HealthCheckService) publishMasterKeyStatusChangedEvent(ctx context.Context, keyID uint, oldStatus, newStatus models.MasterAPIKeyStatus) {
+func (s *HealthCheckService) publishMasterKeyStatusChangedEvent(
+ ctx context.Context,
+ keyID uint,
+ oldStatus, newStatus models.MasterAPIKeyStatus,
+) {
event := models.MasterKeyStatusChangedEvent{
KeyID: keyID,
OldMasterStatus: oldStatus,
@@ -586,12 +951,47 @@ func (s *HealthCheckService) publishMasterKeyStatusChangedEvent(ctx context.Cont
ChangeReason: "base_health_check",
ChangedAt: time.Now(),
}
+
payload, err := json.Marshal(event)
if err != nil {
- s.logger.WithError(err).Errorf("Failed to marshal MasterKeyStatusChangedEvent for key %d.", keyID)
+ s.logger.WithError(err).WithField("key_id", keyID).Error("Failed to marshal MasterKeyStatusChangedEvent")
return
}
+
if err := s.store.Publish(ctx, models.TopicMasterKeyStatusChanged, payload); err != nil {
- s.logger.WithError(err).Errorf("Failed to publish MasterKeyStatusChangedEvent for key %d.", keyID)
+ s.logger.WithError(err).WithField("key_id", keyID).Error("Failed to publish MasterKeyStatusChangedEvent")
}
}
+
+// ==================== Helper Methods ====================
+
+func (s *HealthCheckService) getConcurrency(configValue *int, defaultValue int) int {
+ var concurrency int
+ if configValue != nil && *configValue > 0 {
+ concurrency = *configValue
+ } else {
+ concurrency = defaultValue
+ }
+
+ return s.ensureConcurrencyBounds(concurrency)
+}
+
+func (s *HealthCheckService) ensureConcurrencyBounds(concurrency int) int {
+ if concurrency < minHealthCheckConcurrency {
+ s.logger.WithFields(logrus.Fields{
+ "requested": concurrency,
+ "minimum": minHealthCheckConcurrency,
+ }).Debug("Concurrency below minimum, adjusting")
+ return minHealthCheckConcurrency
+ }
+
+ if concurrency > maxHealthCheckConcurrency {
+ s.logger.WithFields(logrus.Fields{
+ "requested": concurrency,
+ "maximum": maxHealthCheckConcurrency,
+ }).Warn("Concurrency exceeds maximum, capping it")
+ return maxHealthCheckConcurrency
+ }
+
+ return concurrency
+}
diff --git a/internal/service/key_import_service.go b/internal/service/key_import_service.go
index fca6c7c..9f178c8 100644
--- a/internal/service/key_import_service.go
+++ b/internal/service/key_import_service.go
@@ -23,6 +23,10 @@ const (
TaskTypeHardDeleteKeys = "hard_delete_keys"
TaskTypeRestoreKeys = "restore_keys"
chunkSize = 500
+
+ // Task timeout constants
+ defaultTaskTimeout = 15 * time.Minute
+ longTaskTimeout = time.Hour
)
type KeyImportService struct {
@@ -43,17 +47,19 @@ func NewKeyImportService(ts task.Reporter, kr repository.KeyRepository, s store.
}
}
+// runTaskWithRecovery wraps a task function with panic recovery and failure reporting.
func (s *KeyImportService) runTaskWithRecovery(ctx context.Context, taskID string, resourceID string, taskFunc func()) {
defer func() {
if r := recover(); r != nil {
err := fmt.Errorf("panic recovered in task %s: %v", taskID, r)
- s.logger.Error(err)
+ s.logger.WithField("task_id", taskID).Error(err)
s.taskService.EndTaskByID(ctx, taskID, resourceID, nil, err)
}
}()
taskFunc()
}
+// StartAddKeysTask starts a background task that links a batch of keys to a group.
func (s *KeyImportService) StartAddKeysTask(ctx context.Context, groupID uint, keysText string, validateOnImport bool) (*task.Status, error) {
keys := utils.ParseKeysFromText(keysText)
if len(keys) == 0 {
@@ -61,260 +67,404 @@ func (s *KeyImportService) StartAddKeysTask(ctx context.Context, groupID uint, k
}
resourceID := fmt.Sprintf("group-%d", groupID)
- taskStatus, err := s.taskService.StartTask(ctx, uint(groupID), TaskTypeAddKeysToGroup, resourceID, len(keys), 15*time.Minute)
+ taskStatus, err := s.taskService.StartTask(ctx, uint(groupID), TaskTypeAddKeysToGroup, resourceID, len(keys), defaultTaskTimeout)
if err != nil {
return nil, err
}
+
go s.runTaskWithRecovery(context.Background(), taskStatus.ID, resourceID, func() {
s.runAddKeysTask(context.Background(), taskStatus.ID, resourceID, groupID, keys, validateOnImport)
})
+
return taskStatus, nil
}
+// StartUnlinkKeysTask starts a background task that unlinks a batch of keys from a group.
func (s *KeyImportService) StartUnlinkKeysTask(ctx context.Context, groupID uint, keysText string) (*task.Status, error) {
keys := utils.ParseKeysFromText(keysText)
if len(keys) == 0 {
return nil, fmt.Errorf("no valid keys found")
}
+
resourceID := fmt.Sprintf("group-%d", groupID)
- taskStatus, err := s.taskService.StartTask(ctx, uint(groupID), TaskTypeUnlinkKeysFromGroup, resourceID, len(keys), time.Hour)
+ taskStatus, err := s.taskService.StartTask(ctx, uint(groupID), TaskTypeUnlinkKeysFromGroup, resourceID, len(keys), longTaskTimeout)
if err != nil {
return nil, err
}
+
go s.runTaskWithRecovery(context.Background(), taskStatus.ID, resourceID, func() {
s.runUnlinkKeysTask(context.Background(), taskStatus.ID, resourceID, groupID, keys)
})
+
return taskStatus, nil
}
+// StartHardDeleteKeysTask starts a background task that permanently deletes keys.
func (s *KeyImportService) StartHardDeleteKeysTask(ctx context.Context, keysText string) (*task.Status, error) {
keys := utils.ParseKeysFromText(keysText)
if len(keys) == 0 {
return nil, fmt.Errorf("no valid keys found")
}
+
resourceID := "global_hard_delete"
- taskStatus, err := s.taskService.StartTask(ctx, 0, TaskTypeHardDeleteKeys, resourceID, len(keys), time.Hour)
+ taskStatus, err := s.taskService.StartTask(ctx, 0, TaskTypeHardDeleteKeys, resourceID, len(keys), longTaskTimeout)
if err != nil {
return nil, err
}
+
go s.runTaskWithRecovery(context.Background(), taskStatus.ID, resourceID, func() {
s.runHardDeleteKeysTask(context.Background(), taskStatus.ID, resourceID, keys)
})
+
return taskStatus, nil
}
+// StartRestoreKeysTask starts a background task that restores keys to the active master status.
func (s *KeyImportService) StartRestoreKeysTask(ctx context.Context, keysText string) (*task.Status, error) {
keys := utils.ParseKeysFromText(keysText)
if len(keys) == 0 {
return nil, fmt.Errorf("no valid keys found")
}
- resourceID := "global_restore_keys"
- taskStatus, err := s.taskService.StartTask(ctx, 0, TaskTypeRestoreKeys, resourceID, len(keys), time.Hour)
+ resourceID := "global_restore_keys"
+ taskStatus, err := s.taskService.StartTask(ctx, 0, TaskTypeRestoreKeys, resourceID, len(keys), longTaskTimeout)
if err != nil {
return nil, err
}
+
go s.runTaskWithRecovery(context.Background(), taskStatus.ID, resourceID, func() {
s.runRestoreKeysTask(context.Background(), taskStatus.ID, resourceID, keys)
})
+
return taskStatus, nil
}
-func (s *KeyImportService) runAddKeysTask(ctx context.Context, taskID string, resourceID string, groupID uint, keys []string, validateOnImport bool) {
- uniqueKeysMap := make(map[string]struct{})
- var uniqueKeyStrings []string
- for _, kStr := range keys {
- if _, exists := uniqueKeysMap[kStr]; !exists {
- uniqueKeysMap[kStr] = struct{}{}
- uniqueKeyStrings = append(uniqueKeyStrings, kStr)
- }
- }
- if len(uniqueKeyStrings) == 0 {
- s.taskService.EndTaskByID(ctx, taskID, resourceID, gin.H{"newly_linked_count": 0, "already_linked_count": 0}, nil)
- return
- }
- keysToEnsure := make([]models.APIKey, len(uniqueKeyStrings))
- for i, keyStr := range uniqueKeyStrings {
- keysToEnsure[i] = models.APIKey{APIKey: keyStr}
- }
- allKeyModels, err := s.keyRepo.AddKeys(keysToEnsure)
+// StartUnlinkKeysByFilterTask unlinks every key in a group whose status matches the given filter.
+func (s *KeyImportService) StartUnlinkKeysByFilterTask(ctx context.Context, groupID uint, statuses []string) (*task.Status, error) {
+ s.logger.Infof("Starting unlink task for group %d with status filter: %v", groupID, statuses)
+
+ keyValues, err := s.keyRepo.FindKeyValuesByStatus(groupID, statuses)
if err != nil {
- s.taskService.EndTaskByID(ctx, taskID, resourceID, nil, fmt.Errorf("failed to ensure keys exist: %w", err))
- return
+ return nil, fmt.Errorf("failed to find keys by filter: %w", err)
}
- alreadyLinkedModels, err := s.keyRepo.GetKeysByValuesAndGroupID(uniqueKeyStrings, groupID)
- if err != nil {
- s.taskService.EndTaskByID(ctx, taskID, resourceID, nil, fmt.Errorf("failed to check for already linked keys: %w", err))
- return
- }
- alreadyLinkedIDSet := make(map[uint]struct{})
- for _, key := range alreadyLinkedModels {
- alreadyLinkedIDSet[key.ID] = struct{}{}
- }
- var keysToLink []models.APIKey
- for _, key := range allKeyModels {
- if _, exists := alreadyLinkedIDSet[key.ID]; !exists {
- keysToLink = append(keysToLink, key)
- }
+ if len(keyValues) == 0 {
+ return nil, fmt.Errorf("no keys found matching the provided filter")
}
+ keysAsText := strings.Join(keyValues, "\n")
+ s.logger.Infof("Found %d keys to unlink for group %d.", len(keyValues), groupID)
+
+ return s.StartUnlinkKeysTask(ctx, groupID, keysAsText)
+}
+
+// ==================== Core task execution ====================
+
+// runAddKeysTask executes the batch key-add workflow.
+func (s *KeyImportService) runAddKeysTask(ctx context.Context, taskID string, resourceID string, groupID uint, keys []string, validateOnImport bool) {
+ // 1. Deduplicate the input keys
+ uniqueKeys := s.deduplicateKeys(keys)
+ if len(uniqueKeys) == 0 {
+ s.endTaskWithResult(ctx, taskID, resourceID, gin.H{
+ "newly_linked_count": 0,
+ "already_linked_count": 0,
+ }, nil)
+ return
+ }
+
+ // 2. Ensure every key exists in the database (idempotent)
+ allKeyModels, err := s.ensureKeysExist(uniqueKeys)
+ if err != nil {
+ s.endTaskWithResult(ctx, taskID, resourceID, nil, fmt.Errorf("failed to ensure keys exist: %w", err))
+ return
+ }
+
+ // 3. Filter out keys already linked to the group
+ keysToLink, alreadyLinkedCount, err := s.filterNewKeys(allKeyModels, groupID, uniqueKeys)
+ if err != nil {
+ s.endTaskWithResult(ctx, taskID, resourceID, nil, fmt.Errorf("failed to check linked keys: %w", err))
+ return
+ }
+
+ // 4. Update the task total to the number of keys actually being processed
if err := s.taskService.UpdateTotalByID(ctx, taskID, len(keysToLink)); err != nil {
s.logger.WithError(err).Warnf("Failed to update total for task %s", taskID)
}
+
+ // 5. Link the new keys to the group in chunks
if len(keysToLink) > 0 {
- idsToLink := make([]uint, len(keysToLink))
- for i, key := range keysToLink {
- idsToLink[i] = key.ID
+ if err := s.linkKeysInChunks(ctx, taskID, groupID, keysToLink); err != nil {
+ s.endTaskWithResult(ctx, taskID, resourceID, nil, err)
+ return
}
- for i := 0; i < len(idsToLink); i += chunkSize {
- end := i + chunkSize
- if end > len(idsToLink) {
- end = len(idsToLink)
- }
- chunk := idsToLink[i:end]
- if err := s.keyRepo.LinkKeysToGroup(ctx, groupID, chunk); err != nil {
- s.taskService.EndTaskByID(ctx, taskID, resourceID, nil, fmt.Errorf("chunk failed to link keys: %w", err))
- return
- }
- _ = s.taskService.UpdateProgressByID(ctx, taskID, i+len(chunk))
+ }
+
+ // 6. Validate or directly activate the newly linked keys, depending on the import flag
+ if len(keysToLink) > 0 {
+ s.processNewlyLinkedKeys(ctx, groupID, keysToLink, validateOnImport)
+ }
+
+ // 7. Report the result
+ result := gin.H{
+ "newly_linked_count": len(keysToLink),
+ "already_linked_count": alreadyLinkedCount,
+ "total_linked_count": len(allKeyModels),
+ }
+ s.endTaskWithResult(ctx, taskID, resourceID, result, nil)
+}
+
+// runUnlinkKeysTask executes the batch unlink workflow.
+func (s *KeyImportService) runUnlinkKeysTask(ctx context.Context, taskID, resourceID string, groupID uint, keys []string) {
+ // 1. Deduplicate the input keys
+ uniqueKeys := s.deduplicateKeys(keys)
+
+ // 2. Look up the keys that are actually linked to the group
+ keysToUnlink, err := s.keyRepo.GetKeysByValuesAndGroupID(uniqueKeys, groupID)
+ if err != nil {
+ s.endTaskWithResult(ctx, taskID, resourceID, nil, fmt.Errorf("failed to find keys to unlink: %w", err))
+ return
+ }
+
+ if len(keysToUnlink) == 0 {
+ result := gin.H{
+ "unlinked_count": 0,
+ "hard_deleted_count": 0,
+ "not_found_count": len(uniqueKeys),
}
+ s.endTaskWithResult(ctx, taskID, resourceID, result, nil)
+ return
+ }
+
+ // 3. Extract the key IDs
+ idsToUnlink := s.extractKeyIDs(keysToUnlink)
+
+ // 4. Update the task total
+ if err := s.taskService.UpdateTotalByID(ctx, taskID, len(idsToUnlink)); err != nil {
+ s.logger.WithError(err).Warnf("Failed to update total for task %s", taskID)
+ }
+
+ // 5. Unlink in chunks
+ totalUnlinked, err := s.unlinkKeysInChunks(ctx, taskID, groupID, idsToUnlink)
+ if err != nil {
+ s.endTaskWithResult(ctx, taskID, resourceID, nil, err)
+ return
+ }
+
+ // 6. Clean up orphaned keys
+ totalDeleted, err := s.keyRepo.DeleteOrphanKeys()
+ if err != nil {
+ s.logger.WithError(err).Warn("Orphan key cleanup failed after unlink task.")
+ }
+
+ // 7. Report the result
+ result := gin.H{
+ "unlinked_count": totalUnlinked,
+ "hard_deleted_count": totalDeleted,
+ "not_found_count": len(uniqueKeys) - int(totalUnlinked),
+ }
+ s.endTaskWithResult(ctx, taskID, resourceID, result, nil)
+}
+
+// runHardDeleteKeysTask permanently deletes the given keys in chunks.
+func (s *KeyImportService) runHardDeleteKeysTask(ctx context.Context, taskID, resourceID string, keys []string) {
+ totalDeleted, err := s.processKeysInChunks(ctx, taskID, keys, func(chunk []string) (int64, error) {
+ return s.keyRepo.HardDeleteByValues(chunk)
+ })
+
+ if err != nil {
+ s.endTaskWithResult(ctx, taskID, resourceID, nil, err)
+ return
}
result := gin.H{
- "newly_linked_count": len(keysToLink),
- "already_linked_count": len(alreadyLinkedIDSet),
- "total_linked_count": len(allKeyModels),
+ "hard_deleted_count": totalDeleted,
+ "not_found_count": int64(len(keys)) - totalDeleted,
}
- if len(keysToLink) > 0 {
- idsToLink := make([]uint, len(keysToLink))
- for i, key := range keysToLink {
- idsToLink[i] = key.ID
- }
- if validateOnImport {
- s.publishImportGroupCompletedEvent(ctx, groupID, idsToLink)
- for _, keyID := range idsToLink {
- s.publishSingleKeyChangeEvent(ctx, groupID, keyID, "", models.StatusPendingValidation, "key_linked")
- }
- } else {
- for _, keyID := range idsToLink {
- if _, err := s.apiKeyService.UpdateMappingStatus(ctx, groupID, keyID, models.StatusActive); err != nil {
- s.logger.Errorf("Failed to directly activate key ID %d in group %d: %v", keyID, groupID, err)
- }
- }
- }
- }
- s.taskService.EndTaskByID(ctx, taskID, resourceID, result, nil)
+ s.endTaskWithResult(ctx, taskID, resourceID, result, nil)
+ s.publishChangeEvent(ctx, 0, "keys_hard_deleted")
}
-func (s *KeyImportService) runUnlinkKeysTask(ctx context.Context, taskID, resourceID string, groupID uint, keys []string) {
- uniqueKeysMap := make(map[string]struct{})
- var uniqueKeys []string
+// runRestoreKeysTask restores the given keys to the active master status in chunks.
+func (s *KeyImportService) runRestoreKeysTask(ctx context.Context, taskID, resourceID string, keys []string) {
+ restoredCount, err := s.processKeysInChunks(ctx, taskID, keys, func(chunk []string) (int64, error) {
+ return s.keyRepo.UpdateMasterStatusByValues(chunk, models.MasterStatusActive)
+ })
+
+ if err != nil {
+ s.endTaskWithResult(ctx, taskID, resourceID, nil, err)
+ return
+ }
+
+ result := gin.H{
+ "restored_count": restoredCount,
+ "not_found_count": int64(len(keys)) - restoredCount,
+ }
+ s.endTaskWithResult(ctx, taskID, resourceID, result, nil)
+ s.publishChangeEvent(ctx, 0, "keys_bulk_restored")
+}
+
+// ==================== Helpers ====================
+
+// deduplicateKeys removes duplicate key strings while preserving input order.
+func (s *KeyImportService) deduplicateKeys(keys []string) []string {
+ uniqueKeysMap := make(map[string]struct{}, len(keys))
+ uniqueKeys := make([]string, 0, len(keys))
+
for _, kStr := range keys {
if _, exists := uniqueKeysMap[kStr]; !exists {
uniqueKeysMap[kStr] = struct{}{}
uniqueKeys = append(uniqueKeys, kStr)
}
}
- keysToUnlink, err := s.keyRepo.GetKeysByValuesAndGroupID(uniqueKeys, groupID)
+ return uniqueKeys
+}
+
+// ensureKeysExist makes sure every key value has a corresponding APIKey row in the database.
+func (s *KeyImportService) ensureKeysExist(keys []string) ([]models.APIKey, error) {
+ keysToEnsure := make([]models.APIKey, len(keys))
+ for i, keyStr := range keys {
+ keysToEnsure[i] = models.APIKey{APIKey: keyStr}
+ }
+ return s.keyRepo.AddKeys(keysToEnsure)
+}
+
+// filterNewKeys returns the keys not yet linked to the group and the count of keys that already were.
+func (s *KeyImportService) filterNewKeys(allKeyModels []models.APIKey, groupID uint, uniqueKeys []string) ([]models.APIKey, int, error) {
+ alreadyLinkedModels, err := s.keyRepo.GetKeysByValuesAndGroupID(uniqueKeys, groupID)
if err != nil {
- s.taskService.EndTaskByID(ctx, taskID, resourceID, nil, fmt.Errorf("failed to find keys to unlink: %w", err))
- return
+ return nil, 0, err
}
- if len(keysToUnlink) == 0 {
- result := gin.H{"unlinked_count": 0, "hard_deleted_count": 0, "not_found_count": len(uniqueKeys)}
- s.taskService.EndTaskByID(ctx, taskID, resourceID, result, nil)
- return
- }
- idsToUnlink := make([]uint, len(keysToUnlink))
- for i, key := range keysToUnlink {
- idsToUnlink[i] = key.ID
+ alreadyLinkedIDSet := make(map[uint]struct{}, len(alreadyLinkedModels))
+ for _, key := range alreadyLinkedModels {
+ alreadyLinkedIDSet[key.ID] = struct{}{}
}
- if err := s.taskService.UpdateTotalByID(ctx, taskID, len(idsToUnlink)); err != nil {
- s.logger.WithError(err).Warnf("Failed to update total for task %s", taskID)
- }
- var totalUnlinked int64
- for i := 0; i < len(idsToUnlink); i += chunkSize {
- end := i + chunkSize
- if end > len(idsToUnlink) {
- end = len(idsToUnlink)
+ keysToLink := make([]models.APIKey, 0, len(allKeyModels)-len(alreadyLinkedIDSet))
+ for _, key := range allKeyModels {
+ if _, exists := alreadyLinkedIDSet[key.ID]; !exists {
+ keysToLink = append(keysToLink, key)
}
+ }
+
+ return keysToLink, len(alreadyLinkedIDSet), nil
+}
+
+// extractKeyIDs collects the IDs of the given keys.
+func (s *KeyImportService) extractKeyIDs(keys []models.APIKey) []uint {
+ ids := make([]uint, len(keys))
+ for i, key := range keys {
+ ids[i] = key.ID
+ }
+ return ids
+}
+
+// linkKeysInChunks links keys to the group in chunks of chunkSize, reporting progress after each chunk.
+func (s *KeyImportService) linkKeysInChunks(ctx context.Context, taskID string, groupID uint, keysToLink []models.APIKey) error {
+ idsToLink := s.extractKeyIDs(keysToLink)
+
+ for i := 0; i < len(idsToLink); i += chunkSize {
+ end := min(i+chunkSize, len(idsToLink))
+ chunk := idsToLink[i:end]
+
+ if err := s.keyRepo.LinkKeysToGroup(ctx, groupID, chunk); err != nil {
+ return fmt.Errorf("chunk failed to link keys: %w", err)
+ }
+
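+ // Progress updates are best-effort; a failed update does not abort the task.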
+ _ = s.taskService.UpdateProgressByID(ctx, taskID, end)
+ }
+ return nil
+}
+
+// unlinkKeysInChunks unlinks keys from the group in chunks of chunkSize and returns how many were unlinked.
+func (s *KeyImportService) unlinkKeysInChunks(ctx context.Context, taskID string, groupID uint, idsToUnlink []uint) (int64, error) {
+ var totalUnlinked int64
+
+ for i := 0; i < len(idsToUnlink); i += chunkSize {
+ end := min(i+chunkSize, len(idsToUnlink))
chunk := idsToUnlink[i:end]
+
unlinked, err := s.keyRepo.UnlinkKeysFromGroup(ctx, groupID, chunk)
if err != nil {
- s.taskService.EndTaskByID(ctx, taskID, resourceID, nil, fmt.Errorf("chunk failed: could not unlink keys: %w", err))
- return
+ return 0, fmt.Errorf("chunk failed: could not unlink keys: %w", err)
}
+
totalUnlinked += unlinked
+
+ // Publish an unlink event for each key in the chunk
for _, keyID := range chunk {
s.publishSingleKeyChangeEvent(ctx, groupID, keyID, models.StatusActive, "", "key_unlinked")
}
- _ = s.taskService.UpdateProgressByID(ctx, taskID, i+len(chunk))
+
+ _ = s.taskService.UpdateProgressByID(ctx, taskID, end)
}
- totalDeleted, err := s.keyRepo.DeleteOrphanKeys()
+ return totalUnlinked, nil
+}
+
+// processKeysInChunks applies processFunc to the keys in chunks of chunkSize and sums the returned counts.
+func (s *KeyImportService) processKeysInChunks(
+ ctx context.Context,
+ taskID string,
+ keys []string,
+ processFunc func(chunk []string) (int64, error),
+) (int64, error) {
+ var totalProcessed int64
+
+ for i := 0; i < len(keys); i += chunkSize {
+ end := min(i+chunkSize, len(keys))
+ chunk := keys[i:end]
+
+ count, err := processFunc(chunk)
+ if err != nil {
+ return 0, fmt.Errorf("failed to process chunk: %w", err)
+ }
+
+ totalProcessed += count
+ _ = s.taskService.UpdateProgressByID(ctx, taskID, end)
+ }
+
+ return totalProcessed, nil
+}
+
+// processNewlyLinkedKeys either queues newly linked keys for validation or activates them directly.
+func (s *KeyImportService) processNewlyLinkedKeys(ctx context.Context, groupID uint, keysToLink []models.APIKey, validateOnImport bool) {
+ idsToLink := s.extractKeyIDs(keysToLink)
+
+ if validateOnImport {
+ // 发布批量导入完成事件,触发验证
+ s.publishImportGroupCompletedEvent(ctx, groupID, idsToLink)
+
+ // Publish a status-change event for each key
+ for _, keyID := range idsToLink {
+ s.publishSingleKeyChangeEvent(ctx, groupID, keyID, "", models.StatusPendingValidation, "key_linked")
+ }
+ } else {
+ // Activate the keys directly without validation
+ for _, keyID := range idsToLink {
+ if _, err := s.apiKeyService.UpdateMappingStatus(ctx, groupID, keyID, models.StatusActive); err != nil {
+ s.logger.WithFields(logrus.Fields{
+ "group_id": groupID,
+ "key_id": keyID,
+ }).Errorf("Failed to directly activate key: %v", err)
+ }
+ }
+ }
+}
+
+// endTaskWithResult finishes the task and logs the error when it failed.
+func (s *KeyImportService) endTaskWithResult(ctx context.Context, taskID, resourceID string, result gin.H, err error) {
if err != nil {
- s.logger.WithError(err).Warn("Orphan key cleanup failed after unlink task.")
+ s.logger.WithFields(logrus.Fields{
+ "task_id": taskID,
+ "resource_id": resourceID,
+ }).WithError(err).Error("Task failed")
}
- result := gin.H{
- "unlinked_count": totalUnlinked,
- "hard_deleted_count": totalDeleted,
- "not_found_count": len(uniqueKeys) - int(totalUnlinked),
- }
- s.taskService.EndTaskByID(ctx, taskID, resourceID, result, nil)
+ s.taskService.EndTaskByID(ctx, taskID, resourceID, result, err)
}
-func (s *KeyImportService) runHardDeleteKeysTask(ctx context.Context, taskID, resourceID string, keys []string) {
- var totalDeleted int64
- for i := 0; i < len(keys); i += chunkSize {
- end := i + chunkSize
- if end > len(keys) {
- end = len(keys)
- }
- chunk := keys[i:end]
-
- deleted, err := s.keyRepo.HardDeleteByValues(chunk)
- if err != nil {
- s.taskService.EndTaskByID(ctx, taskID, resourceID, nil, fmt.Errorf("failed to hard delete chunk: %w", err))
- return
- }
- totalDeleted += deleted
- _ = s.taskService.UpdateProgressByID(ctx, taskID, i+len(chunk))
- }
- result := gin.H{
- "hard_deleted_count": totalDeleted,
- "not_found_count": int64(len(keys)) - totalDeleted,
- }
- s.taskService.EndTaskByID(ctx, taskID, resourceID, result, nil)
- s.publishChangeEvent(ctx, 0, "keys_hard_deleted")
-}
-
-func (s *KeyImportService) runRestoreKeysTask(ctx context.Context, taskID, resourceID string, keys []string) {
- var restoredCount int64
- for i := 0; i < len(keys); i += chunkSize {
- end := i + chunkSize
- if end > len(keys) {
- end = len(keys)
- }
- chunk := keys[i:end]
-
- count, err := s.keyRepo.UpdateMasterStatusByValues(chunk, models.MasterStatusActive)
- if err != nil {
- s.taskService.EndTaskByID(ctx, taskID, resourceID, nil, fmt.Errorf("failed to restore chunk: %w", err))
- return
- }
- restoredCount += count
- _ = s.taskService.UpdateProgressByID(ctx, taskID, i+len(chunk))
- }
- result := gin.H{
- "restored_count": restoredCount,
- "not_found_count": int64(len(keys)) - restoredCount,
- }
- s.taskService.EndTaskByID(ctx, taskID, resourceID, result, nil)
- s.publishChangeEvent(ctx, 0, "keys_bulk_restored")
-}
+// ==================== Event publishing ====================
+
+// publishSingleKeyChangeEvent publishes a status-change event for a single key.
func (s *KeyImportService) publishSingleKeyChangeEvent(ctx context.Context, groupID, keyID uint, oldStatus, newStatus models.APIKeyStatus, reason string) {
event := models.KeyStatusChangedEvent{
GroupID: groupID,
@@ -324,56 +474,88 @@ func (s *KeyImportService) publishSingleKeyChangeEvent(ctx context.Context, grou
ChangeReason: reason,
ChangedAt: time.Now(),
}
- eventData, _ := json.Marshal(event)
- if err := s.store.Publish(ctx, models.TopicKeyStatusChanged, eventData); err != nil {
- s.logger.WithError(err).WithFields(logrus.Fields{
+
+ eventData, err := json.Marshal(event)
+ if err != nil {
+ s.logger.WithFields(logrus.Fields{
"group_id": groupID,
"key_id": keyID,
"reason": reason,
- }).Error("Failed to publish single key change event.")
+ }).WithError(err).Error("Failed to marshal key change event")
+ return
+ }
+
+ if err := s.store.Publish(ctx, models.TopicKeyStatusChanged, eventData); err != nil {
+ s.logger.WithFields(logrus.Fields{
+ "group_id": groupID,
+ "key_id": keyID,
+ "reason": reason,
+ }).WithError(err).Error("Failed to publish single key change event")
}
}
+// publishChangeEvent publishes a generic key-change event for a group.
func (s *KeyImportService) publishChangeEvent(ctx context.Context, groupID uint, reason string) {
event := models.KeyStatusChangedEvent{
GroupID: groupID,
ChangeReason: reason,
}
- eventData, _ := json.Marshal(event)
- _ = s.store.Publish(ctx, models.TopicKeyStatusChanged, eventData)
+
+ eventData, err := json.Marshal(event)
+ if err != nil {
+ s.logger.WithFields(logrus.Fields{
+ "group_id": groupID,
+ "reason": reason,
+ }).WithError(err).Error("Failed to marshal change event")
+ return
+ }
+
+ if err := s.store.Publish(ctx, models.TopicKeyStatusChanged, eventData); err != nil {
+ s.logger.WithFields(logrus.Fields{
+ "group_id": groupID,
+ "reason": reason,
+ }).WithError(err).Error("Failed to publish change event")
+ }
}
+// publishImportGroupCompletedEvent signals that a batch import into the group has completed.
func (s *KeyImportService) publishImportGroupCompletedEvent(ctx context.Context, groupID uint, keyIDs []uint) {
if len(keyIDs) == 0 {
return
}
+
event := models.ImportGroupCompletedEvent{
GroupID: groupID,
KeyIDs: keyIDs,
CompletedAt: time.Now(),
}
+
eventData, err := json.Marshal(event)
if err != nil {
- s.logger.WithError(err).Error("Failed to marshal ImportGroupCompletedEvent.")
+ s.logger.WithFields(logrus.Fields{
+ "group_id": groupID,
+ "key_count": len(keyIDs),
+ }).WithError(err).Error("Failed to marshal ImportGroupCompletedEvent")
return
}
+
if err := s.store.Publish(ctx, models.TopicImportGroupCompleted, eventData); err != nil {
- s.logger.WithError(err).Error("Failed to publish ImportGroupCompletedEvent.")
+ s.logger.WithFields(logrus.Fields{
+ "group_id": groupID,
+ "key_count": len(keyIDs),
+ }).WithError(err).Error("Failed to publish ImportGroupCompletedEvent")
} else {
- s.logger.Infof("Published ImportGroupCompletedEvent for group %d with %d keys.", groupID, len(keyIDs))
+ s.logger.WithFields(logrus.Fields{
+ "group_id": groupID,
+ "key_count": len(keyIDs),
+ }).Info("Published ImportGroupCompletedEvent")
}
}
-func (s *KeyImportService) StartUnlinkKeysByFilterTask(ctx context.Context, groupID uint, statuses []string) (*task.Status, error) {
- s.logger.Infof("Starting unlink task for group %d with status filter: %v", groupID, statuses)
- keyValues, err := s.keyRepo.FindKeyValuesByStatus(groupID, statuses)
- if err != nil {
- return nil, fmt.Errorf("failed to find keys by filter: %w", err)
+// min returns the smaller of two integers
+func min(a, b int) int {
+ if a < b {
+ return a
}
- if len(keyValues) == 0 {
- return nil, fmt.Errorf("no keys found matching the provided filter")
- }
- keysAsText := strings.Join(keyValues, "\n")
- s.logger.Infof("Found %d keys to unlink for group %d.", len(keyValues), groupID)
- return s.StartUnlinkKeysTask(ctx, groupID, keysAsText)
+ return b
}
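
For context, a minimal sketch of how the new `min` helper can replace the manual end-clamping used by the removed chunked loops; `chunkSize` and `forEachChunk` are illustrative stand-ins, not part of this diff:

```go
// Illustrative sketch only; chunkSize stands in for the existing package-level constant.
const chunkSize = 1000

func forEachChunk(keys []string, handle func(chunk []string) error) error {
	for i := 0; i < len(keys); i += chunkSize {
		// min clamps the slice end, replacing the old `if end > len(keys) { end = len(keys) }` pattern.
		end := min(i+chunkSize, len(keys))
		if err := handle(keys[i:end]); err != nil {
			return err
		}
	}
	return nil
}
```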
diff --git a/internal/service/key_validation_service.go b/internal/service/key_validation_service.go
index 74e6678..32820e3 100644
--- a/internal/service/key_validation_service.go
+++ b/internal/service/key_validation_service.go
@@ -25,26 +25,38 @@ import (
)
const (
- TaskTypeTestKeys = "test_keys"
+ TaskTypeTestKeys = "test_keys"
+ defaultConcurrency = 10
+ maxValidationConcurrency = 100
+ validationTaskTimeout = time.Hour
)
type KeyValidationService struct {
taskService task.Reporter
channel channel.ChannelProxy
db *gorm.DB
- SettingsManager *settings.SettingsManager
+ settingsManager *settings.SettingsManager
groupManager *GroupManager
store store.Store
keyRepo repository.KeyRepository
logger *logrus.Entry
}
-func NewKeyValidationService(ts task.Reporter, ch channel.ChannelProxy, db *gorm.DB, ss *settings.SettingsManager, gm *GroupManager, st store.Store, kr repository.KeyRepository, logger *logrus.Logger) *KeyValidationService {
+func NewKeyValidationService(
+ ts task.Reporter,
+ ch channel.ChannelProxy,
+ db *gorm.DB,
+ ss *settings.SettingsManager,
+ gm *GroupManager,
+ st store.Store,
+ kr repository.KeyRepository,
+ logger *logrus.Logger,
+) *KeyValidationService {
return &KeyValidationService{
taskService: ts,
channel: ch,
db: db,
- SettingsManager: ss,
+ settingsManager: ss,
groupManager: gm,
store: st,
keyRepo: kr,
@@ -52,33 +64,393 @@ func NewKeyValidationService(ts task.Reporter, ch channel.ChannelProxy, db *gorm
}
}
+// ==================== Public API ====================
+
+// ValidateSingleKey validates a single API key against the check endpoint
func (s *KeyValidationService) ValidateSingleKey(key *models.APIKey, timeout time.Duration, endpoint string) error {
+	// 1. Decrypt the key
	if err := s.keyRepo.Decrypt(key); err != nil {
		return fmt.Errorf("failed to decrypt key %d for validation: %w", key.ID, err)
	}
+
+	// 2. Create the HTTP client and request
	client := &http.Client{Timeout: timeout}
	req, err := http.NewRequest("GET", endpoint, nil)
	if err != nil {
-		s.logger.Errorf("Failed to create request for key validation (ID: %d): %v", key.ID, err)
+		s.logger.WithFields(logrus.Fields{
+			"key_id":   key.ID,
+			"endpoint": endpoint,
+		}).WithError(err).Error("Failed to create validation request")
		return fmt.Errorf("failed to create request: %w", err)
	}
+	// 3. Modify the request (attach the key auth header)
	s.channel.ModifyRequest(req, key)
+	// 4. Execute the request
	resp, err := client.Do(req)
	if err != nil {
		return fmt.Errorf("request failed: %w", err)
	}
	defer resp.Body.Close()
+	// 5. Check the response status
	if resp.StatusCode == http.StatusOK {
		return nil
	}
+	// 6. Build an error from the failed response
+	return s.buildValidationError(resp)
+}
+
+// StartTestKeysTask starts an asynchronous bulk key-test task
+func (s *KeyValidationService) StartTestKeysTask(ctx context.Context, groupID uint, keysText string) (*task.Status, error) {
+	// 1. Parse and validate the input
+	keyStrings := utils.ParseKeysFromText(keysText)
+	if len(keyStrings) == 0 {
+		return nil, CustomErrors.NewAPIError(CustomErrors.ErrBadRequest, "no valid keys found in input text")
+	}
+
+	// 2. Look up the key models
+ apiKeyModels, err := s.keyRepo.GetKeysByValues(keyStrings)
+ if err != nil {
+ return nil, CustomErrors.ParseDBError(err)
+ }
+ if len(apiKeyModels) == 0 {
+ return nil, CustomErrors.NewAPIError(CustomErrors.ErrResourceNotFound, "none of the provided keys were found in the system")
+ }
+
+	// 3. Batch-decrypt the keys
+ if err := s.keyRepo.DecryptBatch(apiKeyModels); err != nil {
+ s.logger.WithError(err).Error("Failed to batch decrypt keys for validation task")
+ return nil, CustomErrors.NewAPIError(CustomErrors.ErrInternalServer, "Failed to decrypt keys for validation")
+ }
+
+	// 4. Load the group configuration
+ group, ok := s.groupManager.GetGroupByID(groupID)
+ if !ok {
+ return nil, CustomErrors.NewAPIError(CustomErrors.ErrGroupNotFound, fmt.Sprintf("group with id %d not found", groupID))
+ }
+
+ opConfig, err := s.groupManager.BuildOperationalConfig(group)
+ if err != nil {
+ return nil, CustomErrors.NewAPIError(CustomErrors.ErrInternalServer, fmt.Sprintf("failed to build operational config: %v", err))
+ }
+
+	// 5. Build the validation endpoint
+ endpoint, err := s.groupManager.BuildKeyCheckEndpoint(groupID)
+ if err != nil {
+ return nil, CustomErrors.NewAPIError(CustomErrors.ErrInternalServer, fmt.Sprintf("failed to build endpoint: %v", err))
+ }
+
+	// 6. Create the task
+ resourceID := fmt.Sprintf("group-%d", groupID)
+ taskStatus, err := s.taskService.StartTask(ctx, groupID, TaskTypeTestKeys, resourceID, len(apiKeyModels), validationTaskTimeout)
+ if err != nil {
+ return nil, err
+ }
+
+	// 7. Prepare the task parameters
+	params := s.buildValidationParams(opConfig)
+
+	// 8. Launch the asynchronous validation task
+ go s.runTestKeysTaskWithRecovery(context.Background(), taskStatus.ID, resourceID, groupID, apiKeyModels, params, endpoint)
+
+ return taskStatus, nil
+}
+
+// StartTestKeysByFilterTask starts a bulk test task for keys matching the given status filter
+func (s *KeyValidationService) StartTestKeysByFilterTask(ctx context.Context, groupID uint, statuses []string) (*task.Status, error) {
+ s.logger.WithFields(logrus.Fields{
+ "group_id": groupID,
+ "statuses": statuses,
+ }).Info("Starting test task with status filter")
+
+	// 1. Find keys matching the status filter
+ keyValues, err := s.keyRepo.FindKeyValuesByStatus(groupID, statuses)
+ if err != nil {
+ return nil, CustomErrors.ParseDBError(err)
+ }
+
+ if len(keyValues) == 0 {
+ return nil, CustomErrors.NewAPIError(CustomErrors.ErrNotFound, "No keys found matching the filter criteria")
+ }
+
+	// 2. Join into text form and start the task
+ keysAsText := strings.Join(keyValues, "\n")
+ s.logger.Infof("Found %d keys to validate for group %d", len(keyValues), groupID)
+
+ return s.StartTestKeysTask(ctx, groupID, keysAsText)
+}
+
+// ==================== Core task execution ====================
+
+// validationParams bundles the parameters for a validation run
+type validationParams struct {
+ timeout time.Duration
+ concurrency int
+}
+
+// buildValidationParams derives the timeout and concurrency for a validation run
+func (s *KeyValidationService) buildValidationParams(opConfig *models.KeyGroupSettings) validationParams {
+	settings := s.settingsManager.GetSettings()
+	// Read the timeout from settings (instead of hard-coding it)
+	timeout := time.Duration(settings.KeyCheckTimeoutSeconds) * time.Second
+	if timeout <= 0 {
+		timeout = 30 * time.Second // fallback only when the setting is invalid
+	}
+	// Read the concurrency (precedence: group config > global config > built-in default)
+	var concurrency int
+	if opConfig.KeyCheckConcurrency != nil && *opConfig.KeyCheckConcurrency > 0 {
+		concurrency = *opConfig.KeyCheckConcurrency
+	} else if settings.BaseKeyCheckConcurrency > 0 {
+		concurrency = settings.BaseKeyCheckConcurrency
+	} else {
+		concurrency = defaultConcurrency // built-in default
+	}
+	// Cap the concurrency as a safety measure
+	if concurrency > maxValidationConcurrency {
+		concurrency = maxValidationConcurrency
+	}
+ return validationParams{
+ timeout: timeout,
+ concurrency: concurrency,
+ }
+}
+
+// runTestKeysTaskWithRecovery wraps runTestKeysTask with panic recovery
+func (s *KeyValidationService) runTestKeysTaskWithRecovery(
+ ctx context.Context,
+ taskID string,
+ resourceID string,
+ groupID uint,
+ keys []models.APIKey,
+ params validationParams,
+ endpoint string,
+) {
+ defer func() {
+ if r := recover(); r != nil {
+ err := fmt.Errorf("panic recovered in validation task %s: %v", taskID, r)
+ s.logger.WithField("task_id", taskID).Error(err)
+ s.taskService.EndTaskByID(ctx, taskID, resourceID, nil, err)
+ }
+ }()
+
+ s.runTestKeysTask(ctx, taskID, resourceID, groupID, keys, params, endpoint)
+}
+
+// runTestKeysTask runs the bulk key validation task
+func (s *KeyValidationService) runTestKeysTask(
+ ctx context.Context,
+ taskID string,
+ resourceID string,
+ groupID uint,
+ keys []models.APIKey,
+ params validationParams,
+ endpoint string,
+) {
+ s.logger.WithFields(logrus.Fields{
+ "task_id": taskID,
+ "group_id": groupID,
+ "key_count": len(keys),
+ "concurrency": params.concurrency,
+ "timeout": params.timeout,
+ }).Info("Starting validation task")
+
+	// 1. Initialize the result collection
+	results := make([]models.KeyTestResult, len(keys))
+
+	// 2. Create the task dispatcher
+	dispatcher := newValidationDispatcher(
+		keys,
+		params.concurrency,
+		s,
+		ctx,
+		taskID,
+		groupID,
+		endpoint,
+		params.timeout,
+	)
+
+	// 3. Run the concurrent validation
+	dispatcher.run(results)
+
+	// 4. Finish the task
+ s.taskService.EndTaskByID(ctx, taskID, resourceID, gin.H{"results": results}, nil)
+
+ s.logger.WithFields(logrus.Fields{
+ "task_id": taskID,
+ "group_id": groupID,
+ "processed": len(results),
+ }).Info("Validation task completed")
+}
+
+// ==================== Validation dispatcher ====================
+
+// validationJob is a single unit of validation work
+type validationJob struct {
+ index int
+ key models.APIKey
+}
+
+// validationDispatcher fans validation jobs out to a pool of workers
+type validationDispatcher struct {
+ keys []models.APIKey
+ concurrency int
+ service *KeyValidationService
+ ctx context.Context
+ taskID string
+ groupID uint
+ endpoint string
+ timeout time.Duration
+
+ mu sync.Mutex
+ processedCount int
+}
+
+// newValidationDispatcher constructs a validation dispatcher
+func newValidationDispatcher(
+ keys []models.APIKey,
+ concurrency int,
+ service *KeyValidationService,
+ ctx context.Context,
+ taskID string,
+ groupID uint,
+ endpoint string,
+ timeout time.Duration,
+) *validationDispatcher {
+ return &validationDispatcher{
+ keys: keys,
+ concurrency: concurrency,
+ service: service,
+ ctx: ctx,
+ taskID: taskID,
+ groupID: groupID,
+ endpoint: endpoint,
+ timeout: timeout,
+ }
+}
+
+// run performs the concurrent validation
+func (d *validationDispatcher) run(results []models.KeyTestResult) {
+ var wg sync.WaitGroup
+ jobs := make(chan validationJob, len(d.keys))
+
+	// Start the worker pool
+	for i := 0; i < d.concurrency; i++ {
+		wg.Add(1)
+		go d.worker(&wg, jobs, results)
+	}
+
+	// Dispatch the jobs
+	for i, key := range d.keys {
+		jobs <- validationJob{index: i, key: key}
+	}
+	close(jobs)
+
+	// Wait for all workers to finish
+ wg.Wait()
+}
+
+// worker is a validation worker goroutine
+func (d *validationDispatcher) worker(wg *sync.WaitGroup, jobs <-chan validationJob, results []models.KeyTestResult) {
+ defer wg.Done()
+
+ for job := range jobs {
+ result := d.validateKey(job.key)
+
+ d.mu.Lock()
+ results[job.index] = result
+ d.processedCount++
+ _ = d.service.taskService.UpdateProgressByID(d.ctx, d.taskID, d.processedCount)
+ d.mu.Unlock()
+ }
+}
+
+// validateKey validates a single key and returns the result
+func (d *validationDispatcher) validateKey(key models.APIKey) models.KeyTestResult {
+	// 1. Run the validation
+	validationErr := d.service.ValidateSingleKey(&key, d.timeout, d.endpoint)
+
+	// 2. Build the result and event
+	result, event := d.buildResultAndEvent(key, validationErr)
+
+	// 3. Publish the validation event
+ d.publishValidationEvent(key.ID, event)
+
+ return result
+}
+
+// buildResultAndEvent builds the test result and the corresponding request event
+func (d *validationDispatcher) buildResultAndEvent(key models.APIKey, validationErr error) (models.KeyTestResult, models.RequestFinishedEvent) {
+ event := models.RequestFinishedEvent{
+ RequestLog: models.RequestLog{
+ GroupID: &d.groupID,
+ KeyID: &key.ID,
+ },
+ }
+
+ if validationErr == nil {
+		// Validation succeeded
+ event.RequestLog.IsSuccess = true
+ return models.KeyTestResult{
+ Key: key.APIKey,
+ Status: "valid",
+ Message: "Validation successful",
+ }, event
+ }
+
+	// Validation failed
+ event.RequestLog.IsSuccess = false
+
+ var apiErr *CustomErrors.APIError
+ if CustomErrors.As(validationErr, &apiErr) {
+ event.Error = apiErr
+ return models.KeyTestResult{
+ Key: key.APIKey,
+ Status: "invalid",
+ Message: fmt.Sprintf("Invalid key (HTTP %d): %s", apiErr.HTTPStatus, apiErr.Message),
+ }, event
+ }
+
+	// Other (non-API) errors
+ event.Error = &CustomErrors.APIError{Message: validationErr.Error()}
+ return models.KeyTestResult{
+ Key: key.APIKey,
+ Status: "error",
+ Message: "Validation check failed: " + validationErr.Error(),
+ }, event
+}
+
+// publishValidationEvent publishes the validation result as a request-finished event
+func (d *validationDispatcher) publishValidationEvent(keyID uint, event models.RequestFinishedEvent) {
+ eventData, err := json.Marshal(event)
+ if err != nil {
+ d.service.logger.WithFields(logrus.Fields{
+ "key_id": keyID,
+ "group_id": d.groupID,
+ }).WithError(err).Error("Failed to marshal validation event")
+ return
+ }
+
+ if err := d.service.store.Publish(d.ctx, models.TopicRequestFinished, eventData); err != nil {
+ d.service.logger.WithFields(logrus.Fields{
+ "key_id": keyID,
+ "group_id": d.groupID,
+ }).WithError(err).Error("Failed to publish validation event")
+ }
+}
+
+// ==================== Helpers ====================
+
+// buildValidationError builds an error from a failed validation response
+func (s *KeyValidationService) buildValidationError(resp *http.Response) error {
bodyBytes, readErr := io.ReadAll(resp.Body)
+
var errorMsg string
if readErr != nil {
errorMsg = "Failed to read error response body"
+ s.logger.WithError(readErr).Warn("Failed to read validation error response")
} else {
errorMsg = string(bodyBytes)
}
@@ -89,128 +461,3 @@ func (s *KeyValidationService) ValidateSingleKey(key *models.APIKey, timeout tim
Code: "VALIDATION_FAILED",
}
}
-
-func (s *KeyValidationService) StartTestKeysTask(ctx context.Context, groupID uint, keysText string) (*task.Status, error) {
- keyStrings := utils.ParseKeysFromText(keysText)
- if len(keyStrings) == 0 {
- return nil, CustomErrors.NewAPIError(CustomErrors.ErrBadRequest, "no valid keys found in input text")
- }
- apiKeyModels, err := s.keyRepo.GetKeysByValues(keyStrings)
- if err != nil {
- return nil, CustomErrors.ParseDBError(err)
- }
- if len(apiKeyModels) == 0 {
- return nil, CustomErrors.NewAPIError(CustomErrors.ErrResourceNotFound, "none of the provided keys were found in the system")
- }
- if err := s.keyRepo.DecryptBatch(apiKeyModels); err != nil {
- s.logger.WithError(err).Error("Failed to batch decrypt keys for validation task.")
- return nil, CustomErrors.NewAPIError(CustomErrors.ErrInternalServer, "Failed to decrypt keys for validation")
- }
- group, ok := s.groupManager.GetGroupByID(groupID)
- if !ok {
- return nil, CustomErrors.NewAPIError(CustomErrors.ErrGroupNotFound, fmt.Sprintf("group with id %d not found", groupID))
- }
- opConfig, err := s.groupManager.BuildOperationalConfig(group)
- if err != nil {
- return nil, CustomErrors.NewAPIError(CustomErrors.ErrInternalServer, fmt.Sprintf("failed to build operational config: %v", err))
- }
- resourceID := fmt.Sprintf("group-%d", groupID)
- taskStatus, err := s.taskService.StartTask(ctx, groupID, TaskTypeTestKeys, resourceID, len(apiKeyModels), time.Hour)
- if err != nil {
- return nil, err
- }
- settings := s.SettingsManager.GetSettings()
- timeout := time.Duration(settings.KeyCheckTimeoutSeconds) * time.Second
- endpoint, err := s.groupManager.BuildKeyCheckEndpoint(groupID)
- if err != nil {
- s.taskService.EndTaskByID(ctx, taskStatus.ID, resourceID, nil, err)
- return nil, err
- }
- var concurrency int
- if opConfig.KeyCheckConcurrency != nil {
- concurrency = *opConfig.KeyCheckConcurrency
- } else {
- concurrency = settings.BaseKeyCheckConcurrency
- }
- go s.runTestKeysTask(context.Background(), taskStatus.ID, resourceID, groupID, apiKeyModels, timeout, endpoint, concurrency)
- return taskStatus, nil
-}
-
-func (s *KeyValidationService) runTestKeysTask(ctx context.Context, taskID string, resourceID string, groupID uint, keys []models.APIKey, timeout time.Duration, endpoint string, concurrency int) {
- var wg sync.WaitGroup
- var mu sync.Mutex
- finalResults := make([]models.KeyTestResult, len(keys))
- processedCount := 0
- if concurrency <= 0 {
- concurrency = 10
- }
- type job struct {
- Index int
- Value models.APIKey
- }
- jobs := make(chan job, len(keys))
- for i := 0; i < concurrency; i++ {
- wg.Add(1)
- go func() {
- defer wg.Done()
- for j := range jobs {
- apiKeyModel := j.Value
- keyToValidate := apiKeyModel
- validationErr := s.ValidateSingleKey(&keyToValidate, timeout, endpoint)
-
- var currentResult models.KeyTestResult
- event := models.RequestFinishedEvent{
- RequestLog: models.RequestLog{
- GroupID: &groupID,
- KeyID: &apiKeyModel.ID,
- },
- }
- if validationErr == nil {
- currentResult = models.KeyTestResult{Key: apiKeyModel.APIKey, Status: "valid", Message: "Validation successful."}
- event.RequestLog.IsSuccess = true
- } else {
- var apiErr *CustomErrors.APIError
- if CustomErrors.As(validationErr, &apiErr) {
- currentResult = models.KeyTestResult{Key: apiKeyModel.APIKey, Status: "invalid", Message: fmt.Sprintf("Invalid key (HTTP %d): %s", apiErr.HTTPStatus, apiErr.Message)}
- event.Error = apiErr
- } else {
- currentResult = models.KeyTestResult{Key: apiKeyModel.APIKey, Status: "error", Message: "Validation check failed: " + validationErr.Error()}
- event.Error = &CustomErrors.APIError{Message: validationErr.Error()}
- }
- event.RequestLog.IsSuccess = false
- }
- eventData, _ := json.Marshal(event)
-
- if err := s.store.Publish(ctx, models.TopicRequestFinished, eventData); err != nil {
- s.logger.WithError(err).Errorf("Failed to publish RequestFinishedEvent for validation of key ID %d", apiKeyModel.ID)
- }
-
- mu.Lock()
- finalResults[j.Index] = currentResult
- processedCount++
- _ = s.taskService.UpdateProgressByID(ctx, taskID, processedCount)
- mu.Unlock()
- }
- }()
- }
- for i, k := range keys {
- jobs <- job{Index: i, Value: k}
- }
- close(jobs)
- wg.Wait()
- s.taskService.EndTaskByID(ctx, taskID, resourceID, gin.H{"results": finalResults}, nil)
-}
-
-func (s *KeyValidationService) StartTestKeysByFilterTask(ctx context.Context, groupID uint, statuses []string) (*task.Status, error) {
- s.logger.Infof("Starting test task for group %d with status filter: %v", groupID, statuses)
- keyValues, err := s.keyRepo.FindKeyValuesByStatus(groupID, statuses)
- if err != nil {
- return nil, CustomErrors.ParseDBError(err)
- }
- if len(keyValues) == 0 {
- return nil, CustomErrors.NewAPIError(CustomErrors.ErrNotFound, "No keys found matching the filter criteria.")
- }
- keysAsText := strings.Join(keyValues, "\n")
- s.logger.Infof("Found %d keys to validate for group %d.", len(keyValues), groupID)
- return s.StartTestKeysTask(ctx, groupID, keysAsText)
-}
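
A hedged sketch of how a handler might drive the refactored asynchronous flow: StartTestKeysTask returns the task status immediately and the results are collected under the task ID. The KeyHandler type, its fields, and the request shape are assumptions, not part of this diff:

```go
// Hypothetical Gin handler; KeyHandler and its fields are illustrative.
func (h *KeyHandler) TestKeys(c *gin.Context) {
	var req struct {
		GroupID uint   `json:"group_id"`
		Keys    string `json:"keys"`
	}
	if err := c.ShouldBindJSON(&req); err != nil {
		c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
		return
	}

	// Kick off the asynchronous validation task; the client polls the task ID for results.
	status, err := h.validationService.StartTestKeysTask(c.Request.Context(), req.GroupID, req.Keys)
	if err != nil {
		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
		return
	}
	c.JSON(http.StatusAccepted, status)
}
```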
diff --git a/internal/service/log_service.go b/internal/service/log_service.go
index 38cf66d..8ff5262 100644
--- a/internal/service/log_service.go
+++ b/internal/service/log_service.go
@@ -1,78 +1,152 @@
-// Filename: internal/service/log_service.go
package service
import (
+ "context"
+ "fmt"
"gemini-balancer/internal/models"
"strconv"
- "github.com/gin-gonic/gin"
+ "github.com/sirupsen/logrus"
"gorm.io/gorm"
)
type LogService struct {
- db *gorm.DB
+ db *gorm.DB
+ logger *logrus.Entry
}
-func NewLogService(db *gorm.DB) *LogService {
- return &LogService{db: db}
+func NewLogService(db *gorm.DB, logger *logrus.Logger) *LogService {
+ return &LogService{
+ db: db,
+ logger: logger.WithField("component", "LogService"),
+ }
}
-func (s *LogService) Record(log *models.RequestLog) error {
- return s.db.Create(log).Error
+func (s *LogService) Record(ctx context.Context, log *models.RequestLog) error {
+ return s.db.WithContext(ctx).Create(log).Error
}
-func (s *LogService) GetLogs(c *gin.Context) ([]models.RequestLog, int64, error) {
+// LogQueryParams carries log query filters as a plain struct, decoupling the service from Gin
+type LogQueryParams struct {
+ Page int
+ PageSize int
+ ModelName string
+	IsSuccess *bool // pointer distinguishes "unset" from "false"
+ StatusCode *int
+ KeyID *uint64
+ GroupID *uint64
+}
+
+func (s *LogService) GetLogs(ctx context.Context, params LogQueryParams) ([]models.RequestLog, int64, error) {
+	// Validate the parameters
+ if params.Page < 1 {
+ params.Page = 1
+ }
+ if params.PageSize < 1 || params.PageSize > 100 {
+ params.PageSize = 20
+ }
+
var logs []models.RequestLog
var total int64
- query := s.db.Model(&models.RequestLog{}).Scopes(s.filtersScope(c))
+	// Build the base query
+	query := s.db.WithContext(ctx).Model(&models.RequestLog{})
+	query = s.applyFilters(query, params)
-	// 先计算总数
+	// Count the total
if err := query.Count(&total).Error; err != nil {
- return nil, 0, err
+ return nil, 0, fmt.Errorf("failed to count logs: %w", err)
}
+
if total == 0 {
return []models.RequestLog{}, 0, nil
}
- // 再执行分页查询
- page, _ := strconv.Atoi(c.DefaultQuery("page", "1"))
- pageSize, _ := strconv.Atoi(c.DefaultQuery("page_size", "20"))
- offset := (page - 1) * pageSize
-
- err := query.Order("request_time desc").Limit(pageSize).Offset(offset).Find(&logs).Error
- if err != nil {
- return nil, 0, err
+	// Paginated query
+ offset := (params.Page - 1) * params.PageSize
+ if err := query.Order("request_time DESC").
+ Limit(params.PageSize).
+ Offset(offset).
+ Find(&logs).Error; err != nil {
+ return nil, 0, fmt.Errorf("failed to query logs: %w", err)
}
return logs, total, nil
}
-func (s *LogService) filtersScope(c *gin.Context) func(db *gorm.DB) *gorm.DB {
- return func(db *gorm.DB) *gorm.DB {
- if modelName := c.Query("model_name"); modelName != "" {
- db = db.Where("model_name = ?", modelName)
- }
- if isSuccessStr := c.Query("is_success"); isSuccessStr != "" {
- if isSuccess, err := strconv.ParseBool(isSuccessStr); err == nil {
- db = db.Where("is_success = ?", isSuccess)
- }
- }
- if statusCodeStr := c.Query("status_code"); statusCodeStr != "" {
- if statusCode, err := strconv.Atoi(statusCodeStr); err == nil {
- db = db.Where("status_code = ?", statusCode)
- }
- }
- if keyIDStr := c.Query("key_id"); keyIDStr != "" {
- if keyID, err := strconv.ParseUint(keyIDStr, 10, 64); err == nil {
- db = db.Where("key_id = ?", keyID)
- }
- }
- if groupIDStr := c.Query("group_id"); groupIDStr != "" {
- if groupID, err := strconv.ParseUint(groupIDStr, 10, 64); err == nil {
- db = db.Where("group_id = ?", groupID)
- }
- }
- return db
+func (s *LogService) applyFilters(query *gorm.DB, params LogQueryParams) *gorm.DB {
+ if params.ModelName != "" {
+ query = query.Where("model_name = ?", params.ModelName)
}
+ if params.IsSuccess != nil {
+ query = query.Where("is_success = ?", *params.IsSuccess)
+ }
+ if params.StatusCode != nil {
+ query = query.Where("status_code = ?", *params.StatusCode)
+ }
+ if params.KeyID != nil {
+ query = query.Where("key_id = ?", *params.KeyID)
+ }
+ if params.GroupID != nil {
+ query = query.Where("group_id = ?", *params.GroupID)
+ }
+ return query
+}
+
+// ParseLogQueryParams is called in the handler layer to parse raw query parameters
+func ParseLogQueryParams(queryParams map[string]string) (LogQueryParams, error) {
+ params := LogQueryParams{
+ Page: 1,
+ PageSize: 20,
+ }
+
+ if pageStr, ok := queryParams["page"]; ok {
+ if page, err := strconv.Atoi(pageStr); err == nil && page > 0 {
+ params.Page = page
+ }
+ }
+
+ if pageSizeStr, ok := queryParams["page_size"]; ok {
+ if pageSize, err := strconv.Atoi(pageSizeStr); err == nil && pageSize > 0 {
+ params.PageSize = pageSize
+ }
+ }
+
+ if modelName, ok := queryParams["model_name"]; ok {
+ params.ModelName = modelName
+ }
+
+ if isSuccessStr, ok := queryParams["is_success"]; ok {
+ if isSuccess, err := strconv.ParseBool(isSuccessStr); err == nil {
+ params.IsSuccess = &isSuccess
+ } else {
+ return params, fmt.Errorf("invalid is_success parameter: %s", isSuccessStr)
+ }
+ }
+
+ if statusCodeStr, ok := queryParams["status_code"]; ok {
+ if statusCode, err := strconv.Atoi(statusCodeStr); err == nil {
+ params.StatusCode = &statusCode
+ } else {
+ return params, fmt.Errorf("invalid status_code parameter: %s", statusCodeStr)
+ }
+ }
+
+ if keyIDStr, ok := queryParams["key_id"]; ok {
+ if keyID, err := strconv.ParseUint(keyIDStr, 10, 64); err == nil {
+ params.KeyID = &keyID
+ } else {
+ return params, fmt.Errorf("invalid key_id parameter: %s", keyIDStr)
+ }
+ }
+
+ if groupIDStr, ok := queryParams["group_id"]; ok {
+ if groupID, err := strconv.ParseUint(groupIDStr, 10, 64); err == nil {
+ params.GroupID = &groupID
+ } else {
+ return params, fmt.Errorf("invalid group_id parameter: %s", groupIDStr)
+ }
+ }
+
+ return params, nil
}
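
A sketch of the handler-side counterpart implied by this refactor: the handler flattens Gin's query values into a map, calls ParseLogQueryParams, and passes the resulting struct to GetLogs. The LogHandler type and the response shape are assumptions, not part of this diff:

```go
// Hypothetical handler wiring for the Gin-free LogService API.
func (h *LogHandler) GetLogs(c *gin.Context) {
	raw := make(map[string]string)
	for key, values := range c.Request.URL.Query() {
		if len(values) > 0 {
			raw[key] = values[0]
		}
	}

	params, err := service.ParseLogQueryParams(raw)
	if err != nil {
		c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
		return
	}

	logs, total, err := h.logService.GetLogs(c.Request.Context(), params)
	if err != nil {
		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
		return
	}
	c.JSON(http.StatusOK, gin.H{"items": logs, "total": total})
}
```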
diff --git a/internal/service/stats_service.go b/internal/service/stats_service.go
index 50f7b9c..6babb30 100644
--- a/internal/service/stats_service.go
+++ b/internal/service/stats_service.go
@@ -35,34 +35,55 @@ func NewStatsService(db *gorm.DB, s store.Store, repo repository.KeyRepository,
func (s *StatsService) Start() {
s.logger.Info("Starting event listener for stats maintenance.")
- sub, err := s.store.Subscribe(context.Background(), models.TopicKeyStatusChanged)
- if err != nil {
- s.logger.Fatalf("Failed to subscribe to topic %s: %v", models.TopicKeyStatusChanged, err)
- return
- }
- go func() {
- defer sub.Close()
- for {
- select {
- case msg := <-sub.Channel():
- var event models.KeyStatusChangedEvent
- if err := json.Unmarshal(msg.Payload, &event); err != nil {
- s.logger.Errorf("Failed to unmarshal KeyStatusChangedEvent: %v", err)
- continue
- }
- s.handleKeyStatusChange(&event)
- case <-s.stopChan:
- s.logger.Info("Stopping stats event listener.")
- return
- }
- }
- }()
+ go s.listenForEvents()
}
func (s *StatsService) Stop() {
close(s.stopChan)
}
+func (s *StatsService) listenForEvents() {
+ for {
+ select {
+ case <-s.stopChan:
+ s.logger.Info("Stopping stats event listener.")
+ return
+ default:
+ }
+
+ ctx, cancel := context.WithCancel(context.Background())
+ sub, err := s.store.Subscribe(ctx, models.TopicKeyStatusChanged)
+ if err != nil {
+ s.logger.Errorf("Failed to subscribe: %v, retrying in 5s", err)
+ cancel()
+ time.Sleep(5 * time.Second)
+ continue
+ }
+
+ s.logger.Info("Subscribed to key status changes")
+ s.handleSubscription(sub, cancel)
+ }
+}
+
+func (s *StatsService) handleSubscription(sub store.Subscription, cancel context.CancelFunc) {
+ defer sub.Close()
+ defer cancel()
+
+ for {
+ select {
+ case msg := <-sub.Channel():
+ var event models.KeyStatusChangedEvent
+ if err := json.Unmarshal(msg.Payload, &event); err != nil {
+ s.logger.Errorf("Failed to unmarshal event: %v", err)
+ continue
+ }
+ s.handleKeyStatusChange(&event)
+ case <-s.stopChan:
+ return
+ }
+ }
+}
+
func (s *StatsService) handleKeyStatusChange(event *models.KeyStatusChangedEvent) {
if event.GroupID == 0 {
s.logger.Warnf("Received KeyStatusChangedEvent with no GroupID. Reason: %s, KeyID: %d. Skipping.", event.ChangeReason, event.KeyID)
@@ -75,23 +96,47 @@ func (s *StatsService) handleKeyStatusChange(event *models.KeyStatusChangedEvent
switch event.ChangeReason {
case "key_unlinked", "key_hard_deleted":
if event.OldStatus != "" {
- s.store.HIncrBy(ctx, statsKey, "total_keys", -1)
- s.store.HIncrBy(ctx, statsKey, fmt.Sprintf("%s_keys", event.OldStatus), -1)
+ if _, err := s.store.HIncrBy(ctx, statsKey, "total_keys", -1); err != nil {
+ s.logger.WithError(err).Errorf("Failed to decrement total_keys for group %d", event.GroupID)
+ s.RecalculateGroupKeyStats(ctx, event.GroupID)
+ return
+ }
+ if _, err := s.store.HIncrBy(ctx, statsKey, fmt.Sprintf("%s_keys", event.OldStatus), -1); err != nil {
+ s.logger.WithError(err).Errorf("Failed to decrement %s_keys for group %d", event.OldStatus, event.GroupID)
+ s.RecalculateGroupKeyStats(ctx, event.GroupID)
+ return
+ }
} else {
s.logger.Warnf("Received '%s' event for group %d without OldStatus, forcing recalculation.", event.ChangeReason, event.GroupID)
s.RecalculateGroupKeyStats(ctx, event.GroupID)
}
case "key_linked":
if event.NewStatus != "" {
- s.store.HIncrBy(ctx, statsKey, "total_keys", 1)
- s.store.HIncrBy(ctx, statsKey, fmt.Sprintf("%s_keys", event.NewStatus), 1)
+ if _, err := s.store.HIncrBy(ctx, statsKey, "total_keys", 1); err != nil {
+ s.logger.WithError(err).Errorf("Failed to increment total_keys for group %d", event.GroupID)
+ s.RecalculateGroupKeyStats(ctx, event.GroupID)
+ return
+ }
+ if _, err := s.store.HIncrBy(ctx, statsKey, fmt.Sprintf("%s_keys", event.NewStatus), 1); err != nil {
+ s.logger.WithError(err).Errorf("Failed to increment %s_keys for group %d", event.NewStatus, event.GroupID)
+ s.RecalculateGroupKeyStats(ctx, event.GroupID)
+ return
+ }
} else {
s.logger.Warnf("Received 'key_linked' event for group %d without NewStatus, forcing recalculation.", event.GroupID)
s.RecalculateGroupKeyStats(ctx, event.GroupID)
}
case "manual_update", "error_threshold_reached", "key_recovered", "invalid_api_key":
- s.store.HIncrBy(ctx, statsKey, fmt.Sprintf("%s_keys", event.OldStatus), -1)
- s.store.HIncrBy(ctx, statsKey, fmt.Sprintf("%s_keys", event.NewStatus), 1)
+ if _, err := s.store.HIncrBy(ctx, statsKey, fmt.Sprintf("%s_keys", event.OldStatus), -1); err != nil {
+ s.logger.WithError(err).Errorf("Failed to decrement %s_keys for group %d", event.OldStatus, event.GroupID)
+ s.RecalculateGroupKeyStats(ctx, event.GroupID)
+ return
+ }
+ if _, err := s.store.HIncrBy(ctx, statsKey, fmt.Sprintf("%s_keys", event.NewStatus), 1); err != nil {
+ s.logger.WithError(err).Errorf("Failed to increment %s_keys for group %d", event.NewStatus, event.GroupID)
+ s.RecalculateGroupKeyStats(ctx, event.GroupID)
+ return
+ }
default:
s.logger.Warnf("Unhandled event reason '%s' for group %d, forcing recalculation.", event.ChangeReason, event.GroupID)
s.RecalculateGroupKeyStats(ctx, event.GroupID)
@@ -113,13 +158,16 @@ func (s *StatsService) RecalculateGroupKeyStats(ctx context.Context, groupID uin
}
statsKey := fmt.Sprintf("stats:group:%d", groupID)
- updates := make(map[string]interface{})
- totalKeys := int64(0)
+ updates := map[string]interface{}{
+ "active_keys": int64(0),
+ "disabled_keys": int64(0),
+ "error_keys": int64(0),
+ "total_keys": int64(0),
+ }
for _, res := range results {
updates[fmt.Sprintf("%s_keys", res.Status)] = res.Count
- totalKeys += res.Count
+ updates["total_keys"] = updates["total_keys"].(int64) + res.Count
}
- updates["total_keys"] = totalKeys
if err := s.store.Del(ctx, statsKey); err != nil {
s.logger.WithError(err).Warnf("Failed to delete stale stats key for group %d before recalculation.", groupID)
@@ -180,8 +228,18 @@ func (s *StatsService) AggregateHourlyStats(ctx context.Context) error {
})
}
- return s.db.WithContext(ctx).Clauses(clause.OnConflict{
+ if err := s.db.WithContext(ctx).Clauses(clause.OnConflict{
Columns: []clause.Column{{Name: "time"}, {Name: "group_id"}, {Name: "model_name"}},
DoUpdates: clause.AssignmentColumns([]string{"request_count", "success_count", "prompt_tokens", "completion_tokens"}),
- }).Create(&hourlyStats).Error
+ }).Create(&hourlyStats).Error; err != nil {
+ return err
+ }
+
+ if err := s.db.WithContext(ctx).
+ Where("request_time >= ? AND request_time < ?", startTime, endTime).
+ Delete(&models.RequestLog{}).Error; err != nil {
+ s.logger.WithError(err).Warn("Failed to delete aggregated request logs")
+ }
+
+ return nil
}
diff --git a/internal/service/token_manager.go b/internal/service/token_manager.go
index 4f4fc4c..99ca74d 100644
--- a/internal/service/token_manager.go
+++ b/internal/service/token_manager.go
@@ -37,7 +37,7 @@ func NewTokenManager(repo repository.AuthTokenRepository, store store.Store, log
return tokens, nil
}
- s, err := syncer.NewCacheSyncer(tokenLoader, store, TopicTokenChanged)
+ s, err := syncer.NewCacheSyncer(tokenLoader, store, TopicTokenChanged, logger)
if err != nil {
return nil, fmt.Errorf("failed to create token manager syncer: %w", err)
}
diff --git a/internal/settings/settings.go b/internal/settings/settings.go
index d8217f1..092f0b4 100644
--- a/internal/settings/settings.go
+++ b/internal/settings/settings.go
@@ -87,7 +87,7 @@ func NewSettingsManager(db *gorm.DB, store store.Store, logger *logrus.Logger) (
return settings, nil
}
- s, err := syncer.NewCacheSyncer(settingsLoader, store, SettingsUpdateChannel)
+	s, err := syncer.NewCacheSyncer(settingsLoader, store, SettingsUpdateChannel, logger)
if err != nil {
return nil, fmt.Errorf("failed to create system settings syncer: %w", err)
}
diff --git a/internal/syncer/syncer.go b/internal/syncer/syncer.go
index 66bdaf5..cbbf4b5 100644
--- a/internal/syncer/syncer.go
+++ b/internal/syncer/syncer.go
@@ -4,46 +4,54 @@ import (
"context"
"fmt"
"gemini-balancer/internal/store"
- "log"
"sync"
"time"
+
+ "github.com/sirupsen/logrus"
+)
+
+const (
+ ReconnectDelay = 5 * time.Second
+ ReloadTimeout = 30 * time.Second
)
-// LoaderFunc
type LoaderFunc[T any] func() (T, error)
-// CacheSyncer
type CacheSyncer[T any] struct {
mu sync.RWMutex
cache T
loader LoaderFunc[T]
store store.Store
channelName string
+ logger *logrus.Entry
stopChan chan struct{}
wg sync.WaitGroup
}
-// NewCacheSyncer
func NewCacheSyncer[T any](
loader LoaderFunc[T],
store store.Store,
channelName string,
+ logger *logrus.Logger,
) (*CacheSyncer[T], error) {
s := &CacheSyncer[T]{
loader: loader,
store: store,
channelName: channelName,
+ logger: logger.WithField("component", fmt.Sprintf("CacheSyncer[%s]", channelName)),
stopChan: make(chan struct{}),
}
+
if err := s.reload(); err != nil {
- return nil, fmt.Errorf("initial load for %s failed: %w", channelName, err)
+ return nil, fmt.Errorf("initial load failed: %w", err)
}
+
s.wg.Add(1)
go s.listenForUpdates()
+
return s, nil
}
-// Get, Invalidate, Stop, reload 方法 .
func (s *CacheSyncer[T]) Get() T {
s.mu.RLock()
defer s.mu.RUnlock()
@@ -51,33 +59,60 @@ func (s *CacheSyncer[T]) Get() T {
}
func (s *CacheSyncer[T]) Invalidate() error {
- log.Printf("INFO: Publishing invalidation notification on channel '%s'", s.channelName)
- return s.store.Publish(context.Background(), s.channelName, []byte("reload"))
+ s.logger.Info("Publishing invalidation notification")
+ ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
+ defer cancel()
+
+ if err := s.store.Publish(ctx, s.channelName, []byte("reload")); err != nil {
+ s.logger.WithError(err).Error("Failed to publish invalidation")
+ return err
+ }
+ return nil
}
func (s *CacheSyncer[T]) Stop() {
close(s.stopChan)
s.wg.Wait()
- log.Printf("INFO: CacheSyncer for channel '%s' stopped.", s.channelName)
+ s.logger.Info("CacheSyncer stopped")
}
func (s *CacheSyncer[T]) reload() error {
- log.Printf("INFO: Reloading cache for channel '%s'...", s.channelName)
- newData, err := s.loader()
- if err != nil {
- log.Printf("ERROR: Failed to reload cache for '%s': %v", s.channelName, err)
- return err
+ s.logger.Info("Reloading cache...")
+
+ ctx, cancel := context.WithTimeout(context.Background(), ReloadTimeout)
+ defer cancel()
+
+ type result struct {
+ data T
+ err error
+ }
+ resultChan := make(chan result, 1)
+
+ go func() {
+ data, err := s.loader()
+ resultChan <- result{data, err}
+ }()
+
+ select {
+ case res := <-resultChan:
+ if res.err != nil {
+ s.logger.WithError(res.err).Error("Failed to reload cache")
+ return res.err
+ }
+ s.mu.Lock()
+ s.cache = res.data
+ s.mu.Unlock()
+ s.logger.Info("Cache reloaded successfully")
+ return nil
+ case <-ctx.Done():
+ s.logger.Error("Cache reload timeout")
+ return fmt.Errorf("reload timeout after %v", ReloadTimeout)
}
- s.mu.Lock()
- s.cache = newData
- s.mu.Unlock()
- log.Printf("INFO: Cache for channel '%s' reloaded successfully.", s.channelName)
- return nil
}
-// listenForUpdates ...
func (s *CacheSyncer[T]) listenForUpdates() {
defer s.wg.Done()
+
for {
select {
case <-s.stopChan:
@@ -85,31 +120,39 @@ func (s *CacheSyncer[T]) listenForUpdates() {
default:
}
- subscription, err := s.store.Subscribe(context.Background(), s.channelName)
- if err != nil {
- log.Printf("ERROR: Failed to subscribe to '%s', retrying in 5s: %v", s.channelName, err)
- time.Sleep(5 * time.Second)
- continue
- }
- log.Printf("INFO: Subscribed to channel '%s' for cache invalidation.", s.channelName)
-
- subscriberLoop:
- for {
+ if err := s.subscribeAndListen(); err != nil {
+ s.logger.WithError(err).Warnf("Subscription error, retrying in %v", ReconnectDelay)
select {
- case _, ok := <-subscription.Channel():
- if !ok {
- log.Printf("WARN: Subscription channel '%s' closed, will re-subscribe.", s.channelName)
- break subscriberLoop
- }
- log.Printf("INFO: Received invalidation notification on '%s', reloading cache.", s.channelName)
- if err := s.reload(); err != nil {
- log.Printf("ERROR: Failed to reload cache for '%s' after notification: %v", s.channelName, err)
- }
+ case <-time.After(ReconnectDelay):
case <-s.stopChan:
- subscription.Close()
return
}
}
- subscription.Close()
+ }
+}
+
+func (s *CacheSyncer[T]) subscribeAndListen() error {
+
+ ctx, cancel := context.WithCancel(context.Background())
+ defer cancel()
+ subscription, err := s.store.Subscribe(ctx, s.channelName)
+ if err != nil {
+ return fmt.Errorf("failed to subscribe: %w", err)
+ }
+ defer subscription.Close()
+ s.logger.Info("Subscribed to channel")
+ for {
+ select {
+ case msg, ok := <-subscription.Channel():
+ if !ok {
+ return fmt.Errorf("subscription channel closed")
+ }
+ s.logger.WithField("message", string(msg.Payload)).Info("Received invalidation notification")
+ if err := s.reload(); err != nil {
+ s.logger.WithError(err).Error("Failed to reload after notification")
+ }
+ case <-s.stopChan:
+ return nil
+ }
}
}
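
A minimal sketch of the updated constructor signature in use, mirroring the TokenManager and SettingsManager call sites above; the payload type, channel name, and loader body are illustrative, not from this diff:

```go
// Illustrative only: real loaders query the database instead of returning a literal.
func newExampleSyncer(st store.Store, logger *logrus.Logger) (*syncer.CacheSyncer[[]string], error) {
	loader := func() ([]string, error) {
		return []string{"example"}, nil
	}

	s, err := syncer.NewCacheSyncer(loader, st, "example:cache_changed", logger)
	if err != nil {
		return nil, err
	}
	// s.Get() returns the cached value; s.Invalidate() publishes "reload" so every
	// instance, including this one, re-runs the loader.
	return s, nil
}
```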
diff --git a/internal/task/task.go b/internal/task/task.go
index aa204c8..2345d1f 100644
--- a/internal/task/task.go
+++ b/internal/task/task.go
@@ -1,4 +1,3 @@
-// Filename: internal/task/task.go
package task
import (
@@ -13,7 +12,9 @@ import (
)
const (
- ResultTTL = 60 * time.Minute
+ ResultTTL = 60 * time.Minute
+ DefaultTimeout = 24 * time.Hour
+ LockTTL = 30 * time.Minute
)
type Reporter interface {
@@ -65,14 +66,21 @@ func (s *Task) getIsRunningFlagKey(taskID string) string {
func (s *Task) StartTask(ctx context.Context, keyGroupID uint, taskType, resourceID string, total int, timeout time.Duration) (*Status, error) {
lockKey := s.getResourceLockKey(resourceID)
+ taskID := fmt.Sprintf("%d-%d", time.Now().UnixNano(), keyGroupID)
- if existingTaskID, err := s.store.Get(ctx, lockKey); err == nil && len(existingTaskID) > 0 {
+ locked, err := s.store.SetNX(ctx, lockKey, []byte(taskID), LockTTL)
+ if err != nil {
+ return nil, fmt.Errorf("failed to acquire task lock: %w", err)
+ }
+ if !locked {
+ existingTaskID, _ := s.store.Get(ctx, lockKey)
return nil, fmt.Errorf("a task is already running for this resource (ID: %s)", string(existingTaskID))
}
- taskID := fmt.Sprintf("%d-%d", time.Now().UnixNano(), keyGroupID)
- taskKey := s.getTaskDataKey(taskID)
- runningFlagKey := s.getIsRunningFlagKey(taskID)
+ if timeout == 0 {
+ timeout = DefaultTimeout
+ }
+
status := &Status{
ID: taskID,
TaskType: taskType,
@@ -81,63 +89,55 @@ func (s *Task) StartTask(ctx context.Context, keyGroupID uint, taskType, resourc
Total: total,
StartedAt: time.Now(),
}
- statusBytes, err := json.Marshal(status)
- if err != nil {
- return nil, fmt.Errorf("failed to serialize new task status: %w", err)
- }
- if timeout == 0 {
- timeout = ResultTTL * 24
- }
-
- if err := s.store.Set(ctx, lockKey, []byte(taskID), timeout); err != nil {
- return nil, fmt.Errorf("failed to acquire task resource lock: %w", err)
- }
- if err := s.store.Set(ctx, taskKey, statusBytes, timeout); err != nil {
+ if err := s.saveStatus(ctx, taskID, status, timeout); err != nil {
_ = s.store.Del(ctx, lockKey)
- return nil, fmt.Errorf("failed to set new task data in store: %w", err)
+ return nil, fmt.Errorf("failed to save task status: %w", err)
}
+ runningFlagKey := s.getIsRunningFlagKey(taskID)
if err := s.store.Set(ctx, runningFlagKey, []byte("1"), timeout); err != nil {
_ = s.store.Del(ctx, lockKey)
- _ = s.store.Del(ctx, taskKey)
- return nil, fmt.Errorf("failed to set task running flag: %w", err)
+ _ = s.store.Del(ctx, s.getTaskDataKey(taskID))
+ return nil, fmt.Errorf("failed to set running flag: %w", err)
}
+
return status, nil
}
func (s *Task) EndTaskByID(ctx context.Context, taskID, resourceID string, resultData any, taskErr error) {
lockKey := s.getResourceLockKey(resourceID)
- defer func() {
- if err := s.store.Del(ctx, lockKey); err != nil {
- s.logger.WithError(err).Warnf("Failed to release resource lock '%s' for task %s.", lockKey, taskID)
- }
- }()
runningFlagKey := s.getIsRunningFlagKey(taskID)
- _ = s.store.Del(ctx, runningFlagKey)
+
+ defer func() {
+ _ = s.store.Del(ctx, lockKey)
+ _ = s.store.Del(ctx, runningFlagKey)
+ }()
status, err := s.GetStatus(ctx, taskID)
if err != nil {
- s.logger.WithError(err).Errorf("Could not get task status for task ID %s during EndTask. Lock has been released, but task data may be stale.", taskID)
+ s.logger.WithError(err).Errorf("Failed to get task status for %s during EndTask", taskID)
return
}
+
if !status.IsRunning {
- s.logger.Warnf("EndTaskByID called for an already finished task: %s", taskID)
+ s.logger.Warnf("EndTaskByID called for already finished task: %s", taskID)
return
}
+
now := time.Now()
status.IsRunning = false
status.FinishedAt = &now
status.DurationSeconds = now.Sub(status.StartedAt).Seconds()
+
if taskErr != nil {
status.Error = taskErr.Error()
} else {
status.Result = resultData
}
- updatedTaskBytes, _ := json.Marshal(status)
- taskKey := s.getTaskDataKey(taskID)
- if err := s.store.Set(ctx, taskKey, updatedTaskBytes, ResultTTL); err != nil {
- s.logger.WithError(err).Errorf("Failed to save final status for task %s.", taskID)
+
+ if err := s.saveStatus(ctx, taskID, status, ResultTTL); err != nil {
+ s.logger.WithError(err).Errorf("Failed to save final status for task %s", taskID)
}
}
@@ -148,43 +148,42 @@ func (s *Task) GetStatus(ctx context.Context, taskID string) (*Status, error) {
if errors.Is(err, store.ErrNotFound) {
return nil, errors.New("task not found")
}
- return nil, fmt.Errorf("failed to get task status from store: %w", err)
+ return nil, fmt.Errorf("failed to get task status: %w", err)
}
var status Status
if err := json.Unmarshal(statusBytes, &status); err != nil {
- return nil, fmt.Errorf("corrupted task data in store for ID %s", taskID)
+ return nil, fmt.Errorf("corrupted task data for ID %s: %w", taskID, err)
}
+
if !status.IsRunning && status.FinishedAt != nil {
status.DurationSeconds = status.FinishedAt.Sub(status.StartedAt).Seconds()
}
+
return &status, nil
}
func (s *Task) updateTask(ctx context.Context, taskID string, updater func(status *Status)) error {
runningFlagKey := s.getIsRunningFlagKey(taskID)
if _, err := s.store.Get(ctx, runningFlagKey); err != nil {
- return nil
+ if errors.Is(err, store.ErrNotFound) {
+ return errors.New("task is not running")
+ }
+ return fmt.Errorf("failed to check running flag: %w", err)
}
+
status, err := s.GetStatus(ctx, taskID)
if err != nil {
- s.logger.WithError(err).Warnf("Failed to get task status for update on task %s. Update will not be saved.", taskID)
- return nil
+ return fmt.Errorf("failed to get task status: %w", err)
}
+
if !status.IsRunning {
- return nil
+ return errors.New("task is not running")
}
+
updater(status)
- statusBytes, marshalErr := json.Marshal(status)
- if marshalErr != nil {
- s.logger.WithError(marshalErr).Errorf("Failed to serialize status for update on task %s.", taskID)
- return nil
- }
- taskKey := s.getTaskDataKey(taskID)
- if err := s.store.Set(ctx, taskKey, statusBytes, ResultTTL*24); err != nil {
- s.logger.WithError(err).Warnf("Failed to save update for task %s.", taskID)
- }
- return nil
+
+ return s.saveStatus(ctx, taskID, status, DefaultTimeout)
}
func (s *Task) UpdateProgressByID(ctx context.Context, taskID string, processed int) error {
@@ -198,3 +197,17 @@ func (s *Task) UpdateTotalByID(ctx context.Context, taskID string, total int) er
status.Total = total
})
}
+
+func (s *Task) saveStatus(ctx context.Context, taskID string, status *Status, ttl time.Duration) error {
+ statusBytes, err := json.Marshal(status)
+ if err != nil {
+ return fmt.Errorf("failed to serialize status: %w", err)
+ }
+
+ taskKey := s.getTaskDataKey(taskID)
+ if err := s.store.Set(ctx, taskKey, statusBytes, ttl); err != nil {
+ return fmt.Errorf("failed to save status: %w", err)
+ }
+
+ return nil
+}
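
To illustrate the effect of switching from Get-then-Set to SetNX: a second StartTask against the same resource now fails fast instead of racing the first caller. The helper below is a sketch under that assumption, not part of the change:

```go
// Illustrative only: demonstrates the duplicate-start behaviour of the SetNX lock.
func startExclusive(ctx context.Context, tasks *task.Task, groupID uint) error {
	first, err := tasks.StartTask(ctx, groupID, "test_keys", "group-1", 100, 0) // 0 falls back to DefaultTimeout
	if err != nil {
		return err
	}

	// A concurrent caller using the same resource ID is rejected and told which
	// task currently holds the lock.
	if _, err := tasks.StartTask(ctx, groupID, "test_keys", "group-1", 100, 0); err != nil {
		fmt.Printf("duplicate start rejected (running task %s): %v\n", first.ID, err)
	}
	return nil
}
```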
diff --git a/internal/webhandlers/auth_handler.go b/internal/webhandlers/auth_handler.go
index cc7ca0f..ef7cc5d 100644
--- a/internal/webhandlers/auth_handler.go
+++ b/internal/webhandlers/auth_handler.go
@@ -1,43 +1,118 @@
-// Filename: internal/webhandlers/auth_handler.go (最终现代化改造版)
+// Filename: internal/webhandlers/auth_handler.go
+
package webhandlers
import (
"gemini-balancer/internal/middleware"
- "gemini-balancer/internal/service" // [核心改造] 依赖service层
+ "gemini-balancer/internal/service"
"net/http"
+ "strings"
"github.com/gin-gonic/gin"
+ "github.com/sirupsen/logrus"
)
-// WebAuthHandler [核心改造] 依赖关系净化,注入SecurityService
+// WebAuthHandler handles web admin authentication
type WebAuthHandler struct {
securityService *service.SecurityService
+ logger *logrus.Logger
}
-// NewWebAuthHandler [核心改造] 构造函数更新
+// NewWebAuthHandler creates a WebAuthHandler
func NewWebAuthHandler(securityService *service.SecurityService) *WebAuthHandler {
+ logger := logrus.New()
+ logger.SetLevel(logrus.InfoLevel)
+
return &WebAuthHandler{
securityService: securityService,
+ logger: logger,
}
}
-// ShowLoginPage 保持不变
+// ShowLoginPage renders the login page
func (h *WebAuthHandler) ShowLoginPage(c *gin.Context) {
errMsg := c.Query("error")
- from := c.Query("from") // 可以从登录失败的页面返回
+
+	// Validate the redirect path (prevents open-redirect attacks)
+	redirectPath := h.validateRedirectPath(c.Query("redirect"))
+
+	// If already logged in, redirect straight away
+ if cookie := middleware.ExtractTokenFromCookie(c); cookie != "" {
+ if _, err := h.securityService.AuthenticateToken(cookie); err == nil {
+ c.Redirect(http.StatusFound, redirectPath)
+ return
+ }
+ }
+
c.HTML(http.StatusOK, "auth.html", gin.H{
- "error": errMsg,
- "from": from,
+ "error": errMsg,
+ "redirect": redirectPath,
})
}
-// HandleLogin [核心改造] 认证逻辑完全委托给SecurityService
+// HandleLogin is deprecated (the project has no username/password login)
func (h *WebAuthHandler) HandleLogin(c *gin.Context) {
c.Redirect(http.StatusFound, "/login?error=DEPRECATED_LOGIN_METHOD")
}
-// HandleLogout 保持不变
+// HandleLogout handles logout requests
func (h *WebAuthHandler) HandleLogout(c *gin.Context) {
+ cookie := middleware.ExtractTokenFromCookie(c)
+
+ if cookie != "" {
+		// Try to resolve the token so the logout can be logged with its ID
+ authToken, err := h.securityService.AuthenticateToken(cookie)
+ if err == nil {
+ h.logger.WithFields(logrus.Fields{
+ "token_id": authToken.ID,
+ "client_ip": c.ClientIP(),
+ }).Info("User logged out")
+ } else {
+ h.logger.WithField("client_ip", c.ClientIP()).Warn("Logout with invalid token")
+ }
+
+		// Invalidate the cached token
+ middleware.InvalidateTokenCache(cookie)
+ } else {
+ h.logger.WithField("client_ip", c.ClientIP()).Debug("Logout without session cookie")
+ }
+
+	// Clear the session cookie
	middleware.ClearAdminSessionCookie(c)
+
+	// Redirect to the login page
c.Redirect(http.StatusFound, "/login")
}
+
+// validateRedirectPath validates a redirect path to prevent open-redirect attacks
+func (h *WebAuthHandler) validateRedirectPath(path string) string {
+ defaultPath := "/dashboard"
+
+ if path == "" {
+ return defaultPath
+ }
+
+	// Only allow internal paths
+ if !strings.HasPrefix(path, "/") || strings.HasPrefix(path, "//") {
+ h.logger.WithField("path", path).Warn("Invalid redirect path blocked")
+ return defaultPath
+ }
+
+	// Whitelist check
+ allowedPaths := []string{
+ "/dashboard",
+ "/keys",
+ "/settings",
+ "/logs",
+ "/tasks",
+ "/chat",
+ }
+
+ for _, allowed := range allowedPaths {
+ if strings.HasPrefix(path, allowed) {
+ return path
+ }
+ }
+
+ return defaultPath
+}
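
The intended behaviour of validateRedirectPath can be pinned down with a small same-package table test; the cases below are inferred from the function body above and are a sketch, not an existing test file:

```go
// Hypothetical test in package webhandlers.
func TestValidateRedirectPath(t *testing.T) {
	h := &WebAuthHandler{logger: logrus.New()}

	cases := map[string]string{
		"":                     "/dashboard", // empty falls back to the default
		"/keys?page=2":         "/keys?page=2",
		"//evil.example.com":   "/dashboard", // protocol-relative URLs are blocked
		"https://evil.example": "/dashboard", // absolute URLs are blocked
		"/unknown":             "/dashboard", // not on the whitelist
	}
	for in, want := range cases {
		if got := h.validateRedirectPath(in); got != want {
			t.Errorf("validateRedirectPath(%q) = %q, want %q", in, got, want)
		}
	}
}
```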
diff --git a/web/templates/base.html b/web/templates/base.html
index 55df2f2..b213f55 100644
--- a/web/templates/base.html
+++ b/web/templates/base.html
@@ -123,7 +123,7 @@
{% block core_scripts %}
-
+
{% endblock core_scripts %}
diff --git a/web/templates/settings.html b/web/templates/settings.html
index f2bea8c..271b1ab 100644
--- a/web/templates/settings.html
+++ b/web/templates/settings.html
@@ -492,7 +492,7 @@
type="text"
id="TEST_MODEL"
name="TEST_MODEL"
- placeholder="gemini-1.5-flash"
+ placeholder="gemini-2.0-flash-lite"
class="flex-grow px-4 py-3 rounded-lg form-input-themed"
/>