Fix Services & Update the middleware && others

This commit is contained in:
XOF
2025-11-24 04:48:07 +08:00
parent 3a95a07e8a
commit f2706d6fc8
37 changed files with 4458 additions and 1166 deletions

View File

@@ -3,10 +3,12 @@ package main
import ( import (
"gemini-balancer/internal/app" "gemini-balancer/internal/app"
"gemini-balancer/internal/container" "gemini-balancer/internal/container"
"gemini-balancer/internal/logging"
"log" "log"
) )
func main() { func main() {
defer logging.Close()
cont, err := container.BuildContainer() cont, err := container.BuildContainer()
if err != nil { if err != nil {
log.Fatalf("FATAL: Failed to build dependency container: %v", err) log.Fatalf("FATAL: Failed to build dependency container: %v", err)

View File

@@ -14,6 +14,12 @@ server:
log: log:
level: "debug" level: "debug"
# 日志轮转配置
max_size: 100 # MB
max_backups: 7 # 保留文件数
max_age: 30 # 保留天数
compress: true # 压缩旧日志
redis: redis:
dsn: "redis://localhost:6379/0" dsn: "redis://localhost:6379/0"

2
go.mod
View File

@@ -17,6 +17,8 @@ require (
github.com/spf13/viper v1.20.1 github.com/spf13/viper v1.20.1
go.uber.org/dig v1.19.0 go.uber.org/dig v1.19.0
golang.org/x/net v0.42.0 golang.org/x/net v0.42.0
golang.org/x/time v0.14.0
gopkg.in/natefinch/lumberjack.v2 v2.2.1
gorm.io/datatypes v1.0.5 gorm.io/datatypes v1.0.5
gorm.io/driver/mysql v1.6.0 gorm.io/driver/mysql v1.6.0
gorm.io/driver/postgres v1.6.0 gorm.io/driver/postgres v1.6.0

4
go.sum
View File

@@ -311,6 +311,8 @@ golang.org/x/text v0.16.0/go.mod h1:GhwF1Be+LQoKShO3cGOHzqOgRrGaYc9AvblQOmPVHnI=
golang.org/x/text v0.20.0/go.mod h1:D4IsuqiFMhST5bX19pQ9ikHC2GsaKyk/oF+pn3ducp4= golang.org/x/text v0.20.0/go.mod h1:D4IsuqiFMhST5bX19pQ9ikHC2GsaKyk/oF+pn3ducp4=
golang.org/x/text v0.28.0 h1:rhazDwis8INMIwQ4tpjLDzUhx6RlXqZNPEM0huQojng= golang.org/x/text v0.28.0 h1:rhazDwis8INMIwQ4tpjLDzUhx6RlXqZNPEM0huQojng=
golang.org/x/text v0.28.0/go.mod h1:U8nCwOR8jO/marOQ0QbDiOngZVEBB7MAiitBuMjXiNU= golang.org/x/text v0.28.0/go.mod h1:U8nCwOR8jO/marOQ0QbDiOngZVEBB7MAiitBuMjXiNU=
golang.org/x/time v0.14.0 h1:MRx4UaLrDotUKUdCIqzPC48t1Y9hANFKIRpNx+Te8PI=
golang.org/x/time v0.14.0/go.mod h1:eL/Oa2bBBK0TkX57Fyni+NgnyQQN4LitPmob2Hjnqw4=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc= golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc=
@@ -325,6 +327,8 @@ gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk=
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q=
gopkg.in/errgo.v2 v2.1.0/go.mod h1:hNsd1EY+bozCKY1Ytp96fpM3vjJbqLJn88ws8XvfDNI= gopkg.in/errgo.v2 v2.1.0/go.mod h1:hNsd1EY+bozCKY1Ytp96fpM3vjJbqLJn88ws8XvfDNI=
gopkg.in/natefinch/lumberjack.v2 v2.2.1 h1:bBRl1b0OH9s/DuPhuXpNl+VtCaJXFZ5/uEFST95x9zc=
gopkg.in/natefinch/lumberjack.v2 v2.2.1/go.mod h1:YD8tP3GAjkrDg1eZH7EGmyESg/lsYskCTPBJVb9jqSc=
gopkg.in/yaml.v2 v2.2.1/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.2.1/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
gopkg.in/yaml.v2 v2.2.8/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.2.8/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=

View File

@@ -29,7 +29,9 @@ type DatabaseConfig struct {
// ServerConfig 存储HTTP服务器配置 // ServerConfig 存储HTTP服务器配置
type ServerConfig struct { type ServerConfig struct {
Port string `mapstructure:"port"` Port string `mapstructure:"port"`
Host string `yaml:"host"`
CORSOrigins []string `yaml:"cors_origins"`
} }
// LogConfig 存储日志配置 // LogConfig 存储日志配置
@@ -38,6 +40,12 @@ type LogConfig struct {
Format string `mapstructure:"format" json:"format"` Format string `mapstructure:"format" json:"format"`
EnableFile bool `mapstructure:"enable_file" json:"enable_file"` EnableFile bool `mapstructure:"enable_file" json:"enable_file"`
FilePath string `mapstructure:"file_path" json:"file_path"` FilePath string `mapstructure:"file_path" json:"file_path"`
// 日志轮转配置(可选)
MaxSize int `yaml:"max_size"` // MB默认 100
MaxBackups int `yaml:"max_backups"` // 默认 7
MaxAge int `yaml:"max_age"` // 天,默认 30
Compress bool `yaml:"compress"` // 默认 true
} }
type RedisConfig struct { type RedisConfig struct {

View File

@@ -87,7 +87,7 @@ func BuildContainer() (*dig.Container, error) {
// 为GroupManager配置Syncer // 为GroupManager配置Syncer
container.Provide(func(loader syncer.LoaderFunc[service.GroupManagerCacheData], store store.Store, logger *logrus.Logger) (*syncer.CacheSyncer[service.GroupManagerCacheData], error) { container.Provide(func(loader syncer.LoaderFunc[service.GroupManagerCacheData], store store.Store, logger *logrus.Logger) (*syncer.CacheSyncer[service.GroupManagerCacheData], error) {
const groupUpdateChannel = "groups:cache_invalidation" const groupUpdateChannel = "groups:cache_invalidation"
return syncer.NewCacheSyncer(loader, store, groupUpdateChannel) return syncer.NewCacheSyncer(loader, store, groupUpdateChannel, logger)
}) })
// =========== 阶段三: 适配器与处理器层 (Handlers & Adapters) =========== // =========== 阶段三: 适配器与处理器层 (Handlers & Adapters) ===========

View File

@@ -20,7 +20,7 @@ type Module struct {
func NewModule(gormDB *gorm.DB, store store.Store, settingsManager *settings.SettingsManager, taskReporter task.Reporter, logger *logrus.Logger) (*Module, error) { func NewModule(gormDB *gorm.DB, store store.Store, settingsManager *settings.SettingsManager, taskReporter task.Reporter, logger *logrus.Logger) (*Module, error) {
loader := newManagerLoader(gormDB) loader := newManagerLoader(gormDB)
cacheSyncer, err := syncer.NewCacheSyncer(loader, store, "proxies:cache_invalidation") cacheSyncer, err := syncer.NewCacheSyncer(loader, store, "proxies:cache_invalidation", logger)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@@ -71,6 +71,7 @@ var clientNetworkErrorSubstrings = []string{
"broken pipe", "broken pipe",
"use of closed network connection", "use of closed network connection",
"request canceled", "request canceled",
"invalid query parameters", // 参数解析错误,归类为客户端错误
} }
// IsPermanentUpstreamError checks if an upstream error indicates the key is permanently invalid. // IsPermanentUpstreamError checks if an upstream error indicates the key is permanently invalid.

View File

@@ -5,7 +5,6 @@ import (
"gemini-balancer/internal/errors" "gemini-balancer/internal/errors"
"gemini-balancer/internal/response" "gemini-balancer/internal/response"
"gemini-balancer/internal/service" "gemini-balancer/internal/service"
"strconv"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
) )
@@ -19,22 +18,26 @@ func NewLogHandler(logService *service.LogService) *LogHandler {
} }
func (h *LogHandler) GetLogs(c *gin.Context) { func (h *LogHandler) GetLogs(c *gin.Context) {
// 调用新的服务函数,接收日志列表和总数 queryParams := make(map[string]string)
logs, total, err := h.logService.GetLogs(c) for key, values := range c.Request.URL.Query() {
if len(values) > 0 {
queryParams[key] = values[0]
}
}
params, err := service.ParseLogQueryParams(queryParams)
if err != nil {
response.Error(c, errors.ErrBadRequest)
return
}
logs, total, err := h.logService.GetLogs(c.Request.Context(), params)
if err != nil { if err != nil {
response.Error(c, errors.ErrDatabase) response.Error(c, errors.ErrDatabase)
return return
} }
// 解析分页参数用于响应体
page, _ := strconv.Atoi(c.DefaultQuery("page", "1"))
pageSize, _ := strconv.Atoi(c.DefaultQuery("page_size", "20"))
// 使用标准的分页响应结构
response.Success(c, gin.H{ response.Success(c, gin.H{
"items": logs, "items": logs,
"total": total, "total": total,
"page": page, "page": params.Page,
"page_size": pageSize, "page_size": params.PageSize,
}) })
} }

View File

@@ -9,20 +9,25 @@ import (
"path/filepath" "path/filepath"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
"gopkg.in/natefinch/lumberjack.v2"
) )
// 包级变量,用于存储日志轮转器
var logRotator *lumberjack.Logger
// NewLogger 返回标准的 *logrus.Logger兼容 Fx 依赖注入)
func NewLogger(cfg *config.Config) *logrus.Logger { func NewLogger(cfg *config.Config) *logrus.Logger {
logger := logrus.New() logger := logrus.New()
// 1. 设置日志级别 // 设置日志级别
level, err := logrus.ParseLevel(cfg.Log.Level) level, err := logrus.ParseLevel(cfg.Log.Level)
if err != nil { if err != nil {
logger.WithField("configured_level", cfg.Log.Level).Warn("Invalid log level specified, defaulting to 'info'.") logger.WithField("configured_level", cfg.Log.Level).Warn("Invalid log level, defaulting to 'info'")
level = logrus.InfoLevel level = logrus.InfoLevel
} }
logger.SetLevel(level) logger.SetLevel(level)
// 2. 设置日志格式 // 设置日志格式
if cfg.Log.Format == "json" { if cfg.Log.Format == "json" {
logger.SetFormatter(&logrus.JSONFormatter{ logger.SetFormatter(&logrus.JSONFormatter{
TimestampFormat: "2006-01-02T15:04:05.000Z07:00", TimestampFormat: "2006-01-02T15:04:05.000Z07:00",
@@ -39,36 +44,57 @@ func NewLogger(cfg *config.Config) *logrus.Logger {
}) })
} }
// 3. 设置日志输出 // 添加全局字段
hostname, _ := os.Hostname()
logger = logger.WithFields(logrus.Fields{
"service": "gemini-balancer",
"hostname": hostname,
}).Logger
// 设置日志输出
if cfg.Log.EnableFile { if cfg.Log.EnableFile {
if cfg.Log.FilePath == "" { if cfg.Log.FilePath == "" {
logger.Warn("Log file is enabled but no file path is specified. Logging to console only.") logger.Warn("Log file enabled but no path specified. Logging to console only")
logger.SetOutput(os.Stdout) logger.SetOutput(os.Stdout)
return logger return logger
} }
logDir := filepath.Dir(cfg.Log.FilePath) logDir := filepath.Dir(cfg.Log.FilePath)
if err := os.MkdirAll(logDir, 0755); err != nil { if err := os.MkdirAll(logDir, 0750); err != nil {
logger.WithError(err).Warn("Failed to create log directory. Logging to console only.") logger.WithError(err).Warn("Failed to create log directory. Logging to console only")
logger.SetOutput(os.Stdout) logger.SetOutput(os.Stdout)
return logger return logger
} }
logFile, err := os.OpenFile(cfg.Log.FilePath, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0666) // 配置日志轮转(保存到包级变量)
if err != nil { logRotator = &lumberjack.Logger{
logger.WithError(err).Warn("Failed to open log file. Logging to console only.") Filename: cfg.Log.FilePath,
logger.SetOutput(os.Stdout) MaxSize: getOrDefault(cfg.Log.MaxSize, 100),
return logger MaxBackups: getOrDefault(cfg.Log.MaxBackups, 7),
MaxAge: getOrDefault(cfg.Log.MaxAge, 30),
Compress: cfg.Log.Compress,
} }
// 同时输出到控制台和文件 logger.SetOutput(io.MultiWriter(os.Stdout, logRotator))
logger.SetOutput(io.MultiWriter(os.Stdout, logFile)) logger.WithField("log_file", cfg.Log.FilePath).Info("Logging to both console and file")
logger.WithField("log_file_path", cfg.Log.FilePath).Info("Logging is now configured to output to both console and file.")
} else { } else {
// 仅输出到控制台
logger.SetOutput(os.Stdout) logger.SetOutput(os.Stdout)
} }
logger.Info("Root logger initialized.") logger.Info("Logger initialized successfully")
return logger return logger
} }
// Close 关闭日志轮转器(在 main.go 中调用)
func Close() {
if logRotator != nil {
logRotator.Close()
}
}
func getOrDefault(value, defaultValue int) int {
if value <= 0 {
return defaultValue
}
return value
}

View File

@@ -1,4 +1,5 @@
// Filename: internal/middleware/auth.go // Filename: internal/middleware/auth.go
package middleware package middleware
import ( import (
@@ -7,76 +8,115 @@ import (
"strings" "strings"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/sirupsen/logrus"
) )
// === API Admin 认证管道 (/admin/* API路由) === type ErrorResponse struct {
Error string `json:"error"`
Code string `json:"code,omitempty"`
}
func APIAdminAuthMiddleware(securityService *service.SecurityService) gin.HandlerFunc { // APIAdminAuthMiddleware 管理后台 API 认证
func APIAdminAuthMiddleware(
securityService *service.SecurityService,
logger *logrus.Logger,
) gin.HandlerFunc {
return func(c *gin.Context) { return func(c *gin.Context) {
tokenValue := extractBearerToken(c) tokenValue := extractBearerToken(c)
if tokenValue == "" { if tokenValue == "" {
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "Authorization token is missing"}) c.AbortWithStatusJSON(http.StatusUnauthorized, ErrorResponse{
Error: "Authentication required",
Code: "AUTH_MISSING",
})
return return
} }
// ✅ 只传 token 参数(移除 context
authToken, err := securityService.AuthenticateToken(tokenValue) authToken, err := securityService.AuthenticateToken(tokenValue)
if err != nil || !authToken.IsAdmin { if err != nil {
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "Invalid or non-admin token"}) logger.WithError(err).Warn("Authentication failed")
c.AbortWithStatusJSON(http.StatusUnauthorized, ErrorResponse{
Error: "Invalid authentication",
Code: "AUTH_INVALID",
})
return return
} }
if !authToken.IsAdmin {
c.AbortWithStatusJSON(http.StatusForbidden, ErrorResponse{
Error: "Admin access required",
Code: "AUTH_FORBIDDEN",
})
return
}
c.Set("adminUser", authToken) c.Set("adminUser", authToken)
c.Next() c.Next()
} }
} }
// === /v1 Proxy 认证 === // ProxyAuthMiddleware 代理请求认证
func ProxyAuthMiddleware(
func ProxyAuthMiddleware(securityService *service.SecurityService) gin.HandlerFunc { securityService *service.SecurityService,
logger *logrus.Logger,
) gin.HandlerFunc {
return func(c *gin.Context) { return func(c *gin.Context) {
tokenValue := extractProxyToken(c) tokenValue := extractProxyToken(c)
if tokenValue == "" { if tokenValue == "" {
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "API key is missing from request"}) c.AbortWithStatusJSON(http.StatusUnauthorized, ErrorResponse{
Error: "API key required",
Code: "KEY_MISSING",
})
return return
} }
// ✅ 只传 token 参数(移除 context
authToken, err := securityService.AuthenticateToken(tokenValue) authToken, err := securityService.AuthenticateToken(tokenValue)
if err != nil { if err != nil {
// 通用信息,避免泄露过多信息 c.AbortWithStatusJSON(http.StatusUnauthorized, ErrorResponse{
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "Invalid or inactive token provided"}) Error: "Invalid API key",
Code: "KEY_INVALID",
})
return return
} }
c.Set("authToken", authToken) c.Set("authToken", authToken)
c.Next() c.Next()
} }
} }
// extractProxyToken 按优先级提取 token
func extractProxyToken(c *gin.Context) string { func extractProxyToken(c *gin.Context) string {
if key := c.Query("key"); key != "" { // 优先级 1: Authorization Header
return key if token := extractBearerToken(c); token != "" {
} return token
authHeader := c.GetHeader("Authorization")
if authHeader != "" {
if strings.HasPrefix(authHeader, "Bearer ") {
return strings.TrimPrefix(authHeader, "Bearer ")
}
} }
// 优先级 2: X-Api-Key
if key := c.GetHeader("X-Api-Key"); key != "" { if key := c.GetHeader("X-Api-Key"); key != "" {
return key return key
} }
// 优先级 3: X-Goog-Api-Key
if key := c.GetHeader("X-Goog-Api-Key"); key != "" { if key := c.GetHeader("X-Goog-Api-Key"); key != "" {
return key return key
} }
return ""
// 优先级 4: Query 参数(不推荐)
return c.Query("key")
} }
// === 辅助函数 === // extractBearerToken 提取 Bearer Token
func extractBearerToken(c *gin.Context) string { func extractBearerToken(c *gin.Context) string {
authHeader := c.GetHeader("Authorization") authHeader := c.GetHeader("Authorization")
if authHeader == "" { if authHeader == "" {
return "" return ""
} }
parts := strings.Split(authHeader, " ")
if len(parts) == 2 && parts[0] == "Bearer" { const prefix = "Bearer "
return parts[1] if !strings.HasPrefix(authHeader, prefix) {
return ""
} }
return ""
return strings.TrimSpace(authHeader[len(prefix):])
} }

View File

@@ -0,0 +1,90 @@
// Filename: internal/middleware/cors.go
package middleware
import (
"net/http"
"strings"
"github.com/gin-gonic/gin"
)
type CORSConfig struct {
AllowedOrigins []string
AllowedMethods []string
AllowedHeaders []string
ExposedHeaders []string
AllowCredentials bool
MaxAge int
}
func CORSMiddleware(config CORSConfig) gin.HandlerFunc {
return func(c *gin.Context) {
origin := c.Request.Header.Get("Origin")
// 检查 origin 是否允许
if origin != "" && isOriginAllowed(origin, config.AllowedOrigins) {
c.Writer.Header().Set("Access-Control-Allow-Origin", origin)
}
if config.AllowCredentials {
c.Writer.Header().Set("Access-Control-Allow-Credentials", "true")
}
if len(config.ExposedHeaders) > 0 {
c.Writer.Header().Set("Access-Control-Expose-Headers",
strings.Join(config.ExposedHeaders, ", "))
}
// 处理预检请求
if c.Request.Method == http.MethodOptions {
if len(config.AllowedMethods) > 0 {
c.Writer.Header().Set("Access-Control-Allow-Methods",
strings.Join(config.AllowedMethods, ", "))
}
if len(config.AllowedHeaders) > 0 {
c.Writer.Header().Set("Access-Control-Allow-Headers",
strings.Join(config.AllowedHeaders, ", "))
}
if config.MaxAge > 0 {
c.Writer.Header().Set("Access-Control-Max-Age",
string(rune(config.MaxAge)))
}
c.AbortWithStatus(http.StatusNoContent)
return
}
c.Next()
}
}
func isOriginAllowed(origin string, allowedOrigins []string) bool {
for _, allowed := range allowedOrigins {
if allowed == "*" || allowed == origin {
return true
}
// 支持通配符子域名
if strings.HasPrefix(allowed, "*.") {
domain := strings.TrimPrefix(allowed, "*.")
if strings.HasSuffix(origin, domain) {
return true
}
}
}
return false
}
// 使用示例
func SetupCORS(r *gin.Engine) {
r.Use(CORSMiddleware(CORSConfig{
AllowedOrigins: []string{"https://yourdomain.com", "*.yourdomain.com"},
AllowedMethods: []string{"GET", "POST", "PUT", "DELETE", "OPTIONS"},
AllowedHeaders: []string{"Authorization", "Content-Type", "X-Api-Key"},
ExposedHeaders: []string{"X-Request-Id"},
AllowCredentials: true,
MaxAge: 3600,
}))
}

View File

@@ -1,84 +1,213 @@
// Filename: internal/middleware/log_redaction.go // Filename: internal/middleware/logging.go
package middleware package middleware
import ( import (
"bytes" "bytes"
"io" "io"
"regexp" "regexp"
"strings"
"time" "time"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
) )
const RedactedBodyKey = "redactedBody" const (
const RedactedAuthHeaderKey = "redactedAuthHeader" RedactedBodyKey = "redactedBody"
const RedactedValue = `"[REDACTED]"` RedactedAuthHeaderKey = "redactedAuthHeader"
RedactedValue = `"[REDACTED]"`
)
// 预编译正则表达式(全局变量,提升性能)
var (
// JSON 敏感字段脱敏
jsonSensitiveKeys = regexp.MustCompile(`("(?i:api_key|apikey|token|password|secret|authorization|key|keys|auth)"\s*:\s*)"[^"]*"`)
// Bearer Token 脱敏
bearerTokenPattern = regexp.MustCompile(`^(Bearer\s+)\S+$`)
// URL 中的 key 参数脱敏
queryKeyPattern = regexp.MustCompile(`([?&](?i:key|token|apikey)=)[^&\s]+`)
)
// RedactionMiddleware 请求数据脱敏中间件
func RedactionMiddleware() gin.HandlerFunc { func RedactionMiddleware() gin.HandlerFunc {
// Pre-compile regex for efficiency
jsonKeyPattern := regexp.MustCompile(`("api_key"|"keys")\s*:\s*"[^"]*"`)
bearerTokenPattern := regexp.MustCompile(`^(Bearer\s+)\S+$`)
return func(c *gin.Context) { return func(c *gin.Context) {
// --- 1. Redact Request Body --- // 1. 脱敏请求体
if c.Request.Method == "POST" || c.Request.Method == "PUT" || c.Request.Method == "DELETE" { if shouldRedactBody(c) {
if bodyBytes, err := io.ReadAll(c.Request.Body); err == nil { redactRequestBody(c)
c.Request.Body = io.NopCloser(bytes.NewBuffer(bodyBytes))
bodyString := string(bodyBytes)
redactedBody := jsonKeyPattern.ReplaceAllString(bodyString, `$1:`+RedactedValue)
c.Set(RedactedBodyKey, redactedBody)
}
}
// --- 2. Redact Authorization Header ---
authHeader := c.GetHeader("Authorization")
if authHeader != "" {
if bearerTokenPattern.MatchString(authHeader) {
redactedHeader := bearerTokenPattern.ReplaceAllString(authHeader, `${1}[REDACTED]`)
c.Set(RedactedAuthHeaderKey, redactedHeader)
}
} }
// 2. 脱敏认证头
redactAuthHeader(c)
// 3. 脱敏 URL 查询参数
redactQueryParams(c)
c.Next() c.Next()
} }
} }
// LogrusLogger is a Gin middleware that logs requests using a Logrus logger. // shouldRedactBody 判断是否需要脱敏请求体
// It consumes redacted data prepared by the RedactionMiddleware. func shouldRedactBody(c *gin.Context) bool {
method := c.Request.Method
contentType := c.GetHeader("Content-Type")
// 只处理包含 JSON 的 POST/PUT/PATCH/DELETE 请求
return (method == "POST" || method == "PUT" || method == "PATCH" || method == "DELETE") &&
strings.Contains(contentType, "application/json")
}
// redactRequestBody 脱敏请求体
func redactRequestBody(c *gin.Context) {
bodyBytes, err := io.ReadAll(c.Request.Body)
if err != nil {
return
}
// 恢复请求体供后续使用
c.Request.Body = io.NopCloser(bytes.NewBuffer(bodyBytes))
// 脱敏敏感字段
bodyString := string(bodyBytes)
redactedBody := jsonSensitiveKeys.ReplaceAllString(bodyString, `$1`+RedactedValue)
c.Set(RedactedBodyKey, redactedBody)
}
// redactAuthHeader 脱敏认证头
func redactAuthHeader(c *gin.Context) {
authHeader := c.GetHeader("Authorization")
if authHeader == "" {
return
}
if bearerTokenPattern.MatchString(authHeader) {
redacted := bearerTokenPattern.ReplaceAllString(authHeader, `${1}[REDACTED]`)
c.Set(RedactedAuthHeaderKey, redacted)
} else {
// 对于非 Bearer 的 token全部脱敏
c.Set(RedactedAuthHeaderKey, "[REDACTED]")
}
// 同时处理其他敏感 Header
sensitiveHeaders := []string{"X-Api-Key", "X-Goog-Api-Key", "Api-Key"}
for _, header := range sensitiveHeaders {
if value := c.GetHeader(header); value != "" {
c.Set("redacted_"+header, "[REDACTED]")
}
}
}
// redactQueryParams 脱敏 URL 查询参数
func redactQueryParams(c *gin.Context) {
rawQuery := c.Request.URL.RawQuery
if rawQuery == "" {
return
}
redacted := queryKeyPattern.ReplaceAllString(rawQuery, `${1}[REDACTED]`)
if redacted != rawQuery {
c.Set("redactedQuery", redacted)
}
}
// LogrusLogger Gin 请求日志中间件(使用 Logrus
func LogrusLogger(logger *logrus.Logger) gin.HandlerFunc { func LogrusLogger(logger *logrus.Logger) gin.HandlerFunc {
return func(c *gin.Context) { return func(c *gin.Context) {
start := time.Now() start := time.Now()
path := c.Request.URL.Path path := c.Request.URL.Path
method := c.Request.Method
// Process request // 处理请求
c.Next() c.Next()
// After request, gather data and log // 计算延迟
latency := time.Since(start) latency := time.Since(start)
statusCode := c.Writer.Status() statusCode := c.Writer.Status()
clientIP := c.ClientIP()
entry := logger.WithFields(logrus.Fields{ // 构建日志字段
"status_code": statusCode, fields := logrus.Fields{
"latency_ms": latency.Milliseconds(), "status": statusCode,
"client_ip": c.ClientIP(), "latency_ms": latency.Milliseconds(),
"method": c.Request.Method, "ip": clientIP,
"path": path, "method": method,
}) "path": path,
}
// 添加请求 ID如果存在
if requestID := getRequestID(c); requestID != "" {
fields["request_id"] = requestID
}
// 添加脱敏后的数据
if redactedBody, exists := c.Get(RedactedBodyKey); exists { if redactedBody, exists := c.Get(RedactedBodyKey); exists {
entry = entry.WithField("body", redactedBody) fields["body"] = redactedBody
} }
if redactedAuth, exists := c.Get(RedactedAuthHeaderKey); exists { if redactedAuth, exists := c.Get(RedactedAuthHeaderKey); exists {
entry = entry.WithField("authorization", redactedAuth) fields["authorization"] = redactedAuth
} }
if redactedQuery, exists := c.Get("redactedQuery"); exists {
fields["query"] = redactedQuery
}
// 添加用户信息(如果已认证)
if user := getAuthenticatedUser(c); user != "" {
fields["user"] = user
}
// 根据状态码选择日志级别
entry := logger.WithFields(fields)
if len(c.Errors) > 0 { if len(c.Errors) > 0 {
entry.Error(c.Errors.String()) fields["errors"] = c.Errors.String()
entry.Error("Request failed")
} else { } else {
entry.Info("request handled") switch {
case statusCode >= 500:
entry.Error("Server error")
case statusCode >= 400:
entry.Warn("Client error")
case statusCode >= 300:
entry.Info("Redirect")
default:
// 只在 Debug 模式记录成功请求
if logger.Level >= logrus.DebugLevel {
entry.Debug("Request completed")
}
}
} }
} }
} }
// getRequestID 获取请求 ID
func getRequestID(c *gin.Context) string {
if id, exists := c.Get("request_id"); exists {
if requestID, ok := id.(string); ok {
return requestID
}
}
return ""
}
// getAuthenticatedUser 获取已认证用户标识
func getAuthenticatedUser(c *gin.Context) string {
// 尝试从不同来源获取用户信息
if user, exists := c.Get("adminUser"); exists {
if authToken, ok := user.(interface{ GetID() string }); ok {
return authToken.GetID()
}
}
if user, exists := c.Get("authToken"); exists {
if authToken, ok := user.(interface{ GetID() string }); ok {
return authToken.GetID()
}
}
return ""
}

View File

@@ -0,0 +1,86 @@
// Filename: internal/middleware/rate_limit.go
package middleware
import (
"net/http"
"sync"
"time"
"github.com/gin-gonic/gin"
"golang.org/x/time/rate"
)
type RateLimiter struct {
limiters map[string]*rate.Limiter
mu sync.RWMutex
r rate.Limit // 每秒请求数
b int // 突发容量
}
func NewRateLimiter(r rate.Limit, b int) *RateLimiter {
return &RateLimiter{
limiters: make(map[string]*rate.Limiter),
r: r,
b: b,
}
}
func (rl *RateLimiter) getLimiter(key string) *rate.Limiter {
rl.mu.Lock()
defer rl.mu.Unlock()
limiter, exists := rl.limiters[key]
if !exists {
limiter = rate.NewLimiter(rl.r, rl.b)
rl.limiters[key] = limiter
}
return limiter
}
// 定期清理不活跃的限制器
func (rl *RateLimiter) cleanup() {
ticker := time.NewTicker(10 * time.Minute)
defer ticker.Stop()
for range ticker.C {
rl.mu.Lock()
// 简单策略:定期清空(生产环境应该用 LRU
rl.limiters = make(map[string]*rate.Limiter)
rl.mu.Unlock()
}
}
func RateLimitMiddleware(limiter *RateLimiter) gin.HandlerFunc {
return func(c *gin.Context) {
// 按 IP 限流
key := c.ClientIP()
// 如果有认证 token按 token 限流(更精确)
if authToken, exists := c.Get("authToken"); exists {
if token, ok := authToken.(interface{ GetID() string }); ok {
key = "token:" + token.GetID()
}
}
l := limiter.getLimiter(key)
if !l.Allow() {
c.AbortWithStatusJSON(http.StatusTooManyRequests, gin.H{
"error": "Rate limit exceeded",
"code": "RATE_LIMIT",
})
return
}
c.Next()
}
}
// 使用示例
func SetupRateLimit(r *gin.Engine) {
limiter := NewRateLimiter(10, 20) // 每秒 10 个请求,突发 20
go limiter.cleanup()
r.Use(RateLimitMiddleware(limiter))
}

View File

@@ -0,0 +1,39 @@
// Filename: internal/middleware/request_id.go
package middleware
import (
"github.com/gin-gonic/gin"
"github.com/google/uuid"
)
// RequestIDMiddleware 请求 ID 追踪中间件
func RequestIDMiddleware() gin.HandlerFunc {
return func(c *gin.Context) {
// 1. 尝试从 Header 获取现有的 Request ID
requestID := c.GetHeader("X-Request-Id")
// 2. 如果没有,生成新的
if requestID == "" {
requestID = uuid.New().String()
}
// 3. 设置到 Context
c.Set("request_id", requestID)
// 4. 返回给客户端(用于追踪)
c.Writer.Header().Set("X-Request-Id", requestID)
c.Next()
}
}
// GetRequestID 获取当前请求的 Request ID
func GetRequestID(c *gin.Context) string {
if id, exists := c.Get("request_id"); exists {
if requestID, ok := id.(string); ok {
return requestID
}
}
return ""
}

View File

@@ -1,4 +1,4 @@
// Filename: internal/middleware/security.go // Filename: internal/middleware/security.go (简化版)
package middleware package middleware
@@ -6,26 +6,136 @@ import (
"gemini-balancer/internal/service" "gemini-balancer/internal/service"
"gemini-balancer/internal/settings" "gemini-balancer/internal/settings"
"net/http" "net/http"
"sync"
"time"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/sirupsen/logrus"
) )
func IPBanMiddleware(securityService *service.SecurityService, settingsManager *settings.SettingsManager) gin.HandlerFunc { // 简单的缓存项
type cacheItem struct {
value bool
expiration int64
}
// 简单的 TTL 缓存实现
type IPBanCache struct {
items map[string]*cacheItem
mu sync.RWMutex
ttl time.Duration
}
func NewIPBanCache() *IPBanCache {
cache := &IPBanCache{
items: make(map[string]*cacheItem),
ttl: 1 * time.Minute,
}
// 启动清理协程
go cache.cleanup()
return cache
}
func (c *IPBanCache) Get(key string) (bool, bool) {
c.mu.RLock()
defer c.mu.RUnlock()
item, exists := c.items[key]
if !exists {
return false, false
}
// 检查是否过期
if time.Now().UnixNano() > item.expiration {
return false, false
}
return item.value, true
}
func (c *IPBanCache) Set(key string, value bool) {
c.mu.Lock()
defer c.mu.Unlock()
c.items[key] = &cacheItem{
value: value,
expiration: time.Now().Add(c.ttl).UnixNano(),
}
}
func (c *IPBanCache) Delete(key string) {
c.mu.Lock()
defer c.mu.Unlock()
delete(c.items, key)
}
func (c *IPBanCache) cleanup() {
ticker := time.NewTicker(5 * time.Minute)
defer ticker.Stop()
for range ticker.C {
c.mu.Lock()
now := time.Now().UnixNano()
for key, item := range c.items {
if now > item.expiration {
delete(c.items, key)
}
}
c.mu.Unlock()
}
}
func IPBanMiddleware(
securityService *service.SecurityService,
settingsManager *settings.SettingsManager,
banCache *IPBanCache,
logger *logrus.Logger,
) gin.HandlerFunc {
return func(c *gin.Context) { return func(c *gin.Context) {
if !settingsManager.IsIPBanEnabled() { if !settingsManager.IsIPBanEnabled() {
c.Next() c.Next()
return return
} }
ip := c.ClientIP() ip := c.ClientIP()
isBanned, err := securityService.IsIPBanned(c.Request.Context(), ip)
if err != nil { // 查缓存
if isBanned, exists := banCache.Get(ip); exists {
if isBanned {
logger.WithField("ip", ip).Debug("IP blocked (cached)")
c.AbortWithStatusJSON(http.StatusForbidden, gin.H{
"error": "Access denied",
})
return
}
c.Next() c.Next()
return return
} }
if isBanned {
c.AbortWithStatusJSON(http.StatusForbidden, gin.H{"error": "您的IP已被暂时封禁请稍后再试"}) // 查数据库
ctx := c.Request.Context()
isBanned, err := securityService.IsIPBanned(ctx, ip)
if err != nil {
logger.WithError(err).WithField("ip", ip).Error("Failed to check IP ban status")
// 降级策略:允许访问
c.Next()
return return
} }
// 更新缓存
banCache.Set(ip, isBanned)
if isBanned {
logger.WithField("ip", ip).Info("IP blocked")
c.AbortWithStatusJSON(http.StatusForbidden, gin.H{
"error": "Access denied",
})
return
}
c.Next() c.Next()
} }
} }

View File

@@ -0,0 +1,52 @@
// Filename: internal/middleware/timeout.go
package middleware
import (
"context"
"net/http"
"time"
"github.com/gin-gonic/gin"
)
func TimeoutMiddleware(timeout time.Duration) gin.HandlerFunc {
return func(c *gin.Context) {
// 创建带超时的 context
ctx, cancel := context.WithTimeout(c.Request.Context(), timeout)
defer cancel()
// 替换 request context
c.Request = c.Request.WithContext(ctx)
// 使用 channel 等待请求完成
finished := make(chan struct{})
go func() {
c.Next()
close(finished)
}()
select {
case <-finished:
// 请求正常完成
return
case <-ctx.Done():
// 超时
c.AbortWithStatusJSON(http.StatusGatewayTimeout, gin.H{
"error": "Request timeout",
"code": "TIMEOUT",
})
}
}
}
// 使用示例
func SetupTimeout(r *gin.Engine) {
// 对 API 路由设置 30 秒超时
api := r.Group("/api")
api.Use(TimeoutMiddleware(30 * time.Second))
{
// ... API routes
}
}

View File

@@ -1,23 +1,151 @@
// Filename: internal/middleware/web.go // Filename: internal/middleware/web.go
package middleware package middleware
import ( import (
"crypto/sha256"
"encoding/hex"
"gemini-balancer/internal/service" "gemini-balancer/internal/service"
"log"
"net/http" "net/http"
"os"
"strings"
"sync"
"time"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/sirupsen/logrus"
) )
const ( const (
AdminSessionCookie = "gemini_admin_session" AdminSessionCookie = "gemini_admin_session"
SessionMaxAge = 3600 * 24 * 7 // 7天
CacheTTL = 5 * time.Minute
CleanupInterval = 10 * time.Minute // 降低清理频率
SessionRefreshTime = 30 * time.Minute
) )
// ==================== 缓存层 ====================
type authCacheEntry struct {
Token interface{}
ExpiresAt time.Time
}
type authCache struct {
mu sync.RWMutex
cache map[string]*authCacheEntry
ttl time.Duration
}
var webAuthCache = newAuthCache(CacheTTL)
func newAuthCache(ttl time.Duration) *authCache {
c := &authCache{
cache: make(map[string]*authCacheEntry),
ttl: ttl,
}
go c.cleanupLoop()
return c
}
func (c *authCache) get(key string) (interface{}, bool) {
c.mu.RLock()
defer c.mu.RUnlock()
entry, exists := c.cache[key]
if !exists || time.Now().After(entry.ExpiresAt) {
return nil, false
}
return entry.Token, true
}
func (c *authCache) set(key string, token interface{}) {
c.mu.Lock()
defer c.mu.Unlock()
c.cache[key] = &authCacheEntry{
Token: token,
ExpiresAt: time.Now().Add(c.ttl),
}
}
func (c *authCache) delete(key string) {
c.mu.Lock()
defer c.mu.Unlock()
delete(c.cache, key)
}
func (c *authCache) cleanupLoop() {
ticker := time.NewTicker(CleanupInterval)
defer ticker.Stop()
for range ticker.C {
c.cleanup()
}
}
func (c *authCache) cleanup() {
c.mu.Lock()
defer c.mu.Unlock()
now := time.Now()
count := 0
for key, entry := range c.cache {
if now.After(entry.ExpiresAt) {
delete(c.cache, key)
count++
}
}
if count > 0 {
logrus.Debugf("[AuthCache] Cleaned up %d expired entries", count)
}
}
// ==================== 会话刷新缓存 ====================
var sessionRefreshCache = struct {
sync.RWMutex
timestamps map[string]time.Time
}{
timestamps: make(map[string]time.Time),
}
// 定期清理刷新时间戳
func init() {
go func() {
ticker := time.NewTicker(1 * time.Hour)
defer ticker.Stop()
for range ticker.C {
sessionRefreshCache.Lock()
now := time.Now()
for key, ts := range sessionRefreshCache.timestamps {
if now.Sub(ts) > 2*time.Hour {
delete(sessionRefreshCache.timestamps, key)
}
}
sessionRefreshCache.Unlock()
}
}()
}
// ==================== Cookie 操作 ====================
func SetAdminSessionCookie(c *gin.Context, adminToken string) { func SetAdminSessionCookie(c *gin.Context, adminToken string) {
c.SetCookie(AdminSessionCookie, adminToken, 3600*24*7, "/", "", false, true) secure := c.Request.TLS != nil || c.GetHeader("X-Forwarded-Proto") == "https"
c.SetSameSite(http.SameSiteStrictMode)
c.SetCookie(AdminSessionCookie, adminToken, SessionMaxAge, "/", "", secure, true)
}
func SetAdminSessionCookieWithAge(c *gin.Context, adminToken string, maxAge int) {
secure := c.Request.TLS != nil || c.GetHeader("X-Forwarded-Proto") == "https"
c.SetSameSite(http.SameSiteStrictMode)
c.SetCookie(AdminSessionCookie, adminToken, maxAge, "/", "", secure, true)
} }
func ClearAdminSessionCookie(c *gin.Context) { func ClearAdminSessionCookie(c *gin.Context) {
c.SetSameSite(http.SameSiteStrictMode)
c.SetCookie(AdminSessionCookie, "", -1, "/", "", false, true) c.SetCookie(AdminSessionCookie, "", -1, "/", "", false, true)
} }
@@ -29,26 +157,258 @@ func ExtractTokenFromCookie(c *gin.Context) string {
return cookie return cookie
} }
// ==================== 认证中间件 ====================
func WebAdminAuthMiddleware(authService *service.SecurityService) gin.HandlerFunc { func WebAdminAuthMiddleware(authService *service.SecurityService) gin.HandlerFunc {
logger := logrus.New()
logger.SetLevel(getLogLevel())
return func(c *gin.Context) { return func(c *gin.Context) {
cookie := ExtractTokenFromCookie(c) cookie := ExtractTokenFromCookie(c)
log.Printf("[WebAuth_Guard] Intercepting request for: %s", c.Request.URL.Path) if cookie == "" {
log.Printf("[WebAuth_Guard] Found session cookie value: '%s'", cookie) logger.Debug("[WebAuth] No session cookie found")
authToken, err := authService.AuthenticateToken(cookie)
if err != nil {
log.Printf("[WebAuth_Guard] FATAL: AuthenticateToken FAILED. Error: %v. Redirecting to /login.", err)
} else if !authToken.IsAdmin {
log.Printf("[WebAuth_Guard] FATAL: Token validated, but IsAdmin is FALSE. Redirecting to /login.")
} else {
log.Printf("[WebAuth_Guard] SUCCESS: Token validated and IsAdmin is TRUE. Allowing access.")
}
if err != nil || !authToken.IsAdmin {
ClearAdminSessionCookie(c) ClearAdminSessionCookie(c)
c.Redirect(http.StatusFound, "/login") redirectToLogin(c)
c.Abort()
return return
} }
cacheKey := hashToken(cookie)
if cachedToken, found := webAuthCache.get(cacheKey); found {
logger.Debug("[WebAuth] Using cached token")
c.Set("adminUser", cachedToken)
refreshSessionIfNeeded(c, cookie)
c.Next()
return
}
logger.Debug("[WebAuth] Cache miss, authenticating...")
authToken, err := authService.AuthenticateToken(cookie)
if err != nil {
logger.WithError(err).Warn("[WebAuth] Authentication failed")
ClearAdminSessionCookie(c)
webAuthCache.delete(cacheKey)
redirectToLogin(c)
return
}
if !authToken.IsAdmin {
logger.Warn("[WebAuth] User is not admin")
ClearAdminSessionCookie(c)
webAuthCache.delete(cacheKey)
redirectToLogin(c)
return
}
webAuthCache.set(cacheKey, authToken)
logger.Debug("[WebAuth] Authentication success, token cached")
c.Set("adminUser", authToken) c.Set("adminUser", authToken)
refreshSessionIfNeeded(c, cookie)
c.Next() c.Next()
} }
} }
func WebAdminAuthMiddlewareWithLogger(authService *service.SecurityService, logger *logrus.Logger) gin.HandlerFunc {
return func(c *gin.Context) {
cookie := ExtractTokenFromCookie(c)
if cookie == "" {
logger.Debug("No session cookie found")
ClearAdminSessionCookie(c)
redirectToLogin(c)
return
}
cacheKey := hashToken(cookie)
if cachedToken, found := webAuthCache.get(cacheKey); found {
c.Set("adminUser", cachedToken)
refreshSessionIfNeeded(c, cookie)
c.Next()
return
}
authToken, err := authService.AuthenticateToken(cookie)
if err != nil {
logger.WithError(err).Warn("Token authentication failed")
ClearAdminSessionCookie(c)
webAuthCache.delete(cacheKey)
redirectToLogin(c)
return
}
if !authToken.IsAdmin {
logger.Warn("Token valid but user is not admin")
ClearAdminSessionCookie(c)
webAuthCache.delete(cacheKey)
redirectToLogin(c)
return
}
webAuthCache.set(cacheKey, authToken)
c.Set("adminUser", authToken)
refreshSessionIfNeeded(c, cookie)
c.Next()
}
}
// ==================== 辅助函数 ====================
func hashToken(token string) string {
h := sha256.Sum256([]byte(token))
return hex.EncodeToString(h[:])
}
func redirectToLogin(c *gin.Context) {
if isAjaxRequest(c) {
c.JSON(http.StatusUnauthorized, gin.H{
"error": "Session expired",
"code": "AUTH_REQUIRED",
})
c.Abort()
return
}
originalPath := c.Request.URL.Path
if originalPath != "/" && originalPath != "/login" {
c.Redirect(http.StatusFound, "/login?redirect="+originalPath)
} else {
c.Redirect(http.StatusFound, "/login")
}
c.Abort()
}
func isAjaxRequest(c *gin.Context) bool {
// 检查 Content-Type
contentType := c.GetHeader("Content-Type")
if strings.Contains(contentType, "application/json") {
return true
}
// 检查 Accept优先检查 JSON
accept := c.GetHeader("Accept")
if strings.Contains(accept, "application/json") &&
!strings.Contains(accept, "text/html") {
return true
}
// 兼容旧版 XMLHttpRequest
return c.GetHeader("X-Requested-With") == "XMLHttpRequest"
}
func refreshSessionIfNeeded(c *gin.Context, token string) {
tokenHash := hashToken(token)
sessionRefreshCache.RLock()
lastRefresh, exists := sessionRefreshCache.timestamps[tokenHash]
sessionRefreshCache.RUnlock()
if !exists || time.Since(lastRefresh) > SessionRefreshTime {
SetAdminSessionCookie(c, token)
sessionRefreshCache.Lock()
sessionRefreshCache.timestamps[tokenHash] = time.Now()
sessionRefreshCache.Unlock()
}
}
func getLogLevel() logrus.Level {
level := os.Getenv("LOG_LEVEL")
switch strings.ToLower(level) {
case "debug":
return logrus.DebugLevel
case "warn":
return logrus.WarnLevel
case "error":
return logrus.ErrorLevel
default:
return logrus.InfoLevel
}
}
// ==================== 工具函数 ====================
func GetAdminUserFromContext(c *gin.Context) (interface{}, bool) {
return c.Get("adminUser")
}
func InvalidateTokenCache(token string) {
tokenHash := hashToken(token)
webAuthCache.delete(tokenHash)
// 同时清理刷新时间戳
sessionRefreshCache.Lock()
delete(sessionRefreshCache.timestamps, tokenHash)
sessionRefreshCache.Unlock()
}
func ClearAllAuthCache() {
webAuthCache.mu.Lock()
webAuthCache.cache = make(map[string]*authCacheEntry)
webAuthCache.mu.Unlock()
sessionRefreshCache.Lock()
sessionRefreshCache.timestamps = make(map[string]time.Time)
sessionRefreshCache.Unlock()
}
// ==================== 调试工具 ====================
type SessionInfo struct {
HasCookie bool `json:"has_cookie"`
IsValid bool `json:"is_valid"`
IsAdmin bool `json:"is_admin"`
IsCached bool `json:"is_cached"`
LastActivity string `json:"last_activity"`
}
func GetSessionInfo(c *gin.Context, authService *service.SecurityService) SessionInfo {
info := SessionInfo{
HasCookie: false,
IsValid: false,
IsAdmin: false,
IsCached: false,
LastActivity: time.Now().Format(time.RFC3339),
}
cookie := ExtractTokenFromCookie(c)
if cookie == "" {
return info
}
info.HasCookie = true
cacheKey := hashToken(cookie)
if _, found := webAuthCache.get(cacheKey); found {
info.IsCached = true
}
authToken, err := authService.AuthenticateToken(cookie)
if err != nil {
return info
}
info.IsValid = true
info.IsAdmin = authToken.IsAdmin
return info
}
func GetCacheStats() map[string]interface{} {
webAuthCache.mu.RLock()
cacheSize := len(webAuthCache.cache)
webAuthCache.mu.RUnlock()
sessionRefreshCache.RLock()
refreshSize := len(sessionRefreshCache.timestamps)
sessionRefreshCache.RUnlock()
return map[string]interface{}{
"auth_cache_entries": cacheSize,
"refresh_cache_entries": refreshSize,
"ttl_seconds": int(webAuthCache.ttl.Seconds()),
"cleanup_interval": int(CleanupInterval.Seconds()),
"session_refresh_time": int(SessionRefreshTime.Seconds()),
}
}

View File

@@ -27,6 +27,8 @@ type SystemSettings struct {
BaseKeyCheckEndpoint string `json:"base_key_check_endpoint" default:"https://generativelanguage.googleapis.com/v1beta/models" name:"全局Key检查端点" category:"健康检查" desc:"用于全局Key身份检查的目标URL。"` BaseKeyCheckEndpoint string `json:"base_key_check_endpoint" default:"https://generativelanguage.googleapis.com/v1beta/models" name:"全局Key检查端点" category:"健康检查" desc:"用于全局Key身份检查的目标URL。"`
BaseKeyCheckModel string `json:"base_key_check_model" default:"gemini-2.0-flash-lite" name:"默认Key检查模型" category:"健康检查" desc:"用于分组健康检查和手动密钥测试时的默认回退模型。"` BaseKeyCheckModel string `json:"base_key_check_model" default:"gemini-2.0-flash-lite" name:"默认Key检查模型" category:"健康检查" desc:"用于分组健康检查和手动密钥测试时的默认回退模型。"`
KeyCheckSchedulerIntervalSeconds int `json:"key_check_scheduler_interval_seconds" default:"60" name:"Key检查调度器间隔(秒)" category:"健康检查" desc:"动态调度器检查各组是否需要执行健康检查的周期。"`
EnableUpstreamCheck bool `json:"enable_upstream_check" default:"true" name:"启用上游检查" category:"健康检查" desc:"是否启用对上游服务(Upstream)的健康检查。"` EnableUpstreamCheck bool `json:"enable_upstream_check" default:"true" name:"启用上游检查" category:"健康检查" desc:"是否启用对上游服务(Upstream)的健康检查。"`
UpstreamCheckTimeoutSeconds int `json:"upstream_check_timeout_seconds" default:"20" name:"上游检查超时(秒)" category:"健康检查" desc:"对单个上游服务进行健康检查时的网络超时时间。"` UpstreamCheckTimeoutSeconds int `json:"upstream_check_timeout_seconds" default:"20" name:"上游检查超时(秒)" category:"健康检查" desc:"对单个上游服务进行健康检查时的网络超时时间。"`

View File

@@ -1,63 +1,96 @@
// Filename: internal/pongo/renderer.go
package pongo package pongo
import ( import (
"fmt" "fmt"
"net/http" "net/http"
"sync"
"github.com/flosch/pongo2/v6" "github.com/flosch/pongo2/v6"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/gin-gonic/gin/render" "github.com/gin-gonic/gin/render"
"github.com/sirupsen/logrus"
) )
type Renderer struct { type Renderer struct {
Context pongo2.Context mu sync.RWMutex
tplSet *pongo2.TemplateSet globalContext pongo2.Context
tplSet *pongo2.TemplateSet
logger *logrus.Logger
} }
func New(directory string, isDebug bool) *Renderer { func New(directory string, isDebug bool, logger *logrus.Logger) *Renderer {
loader := pongo2.MustNewLocalFileSystemLoader(directory) loader := pongo2.MustNewLocalFileSystemLoader(directory)
tplSet := pongo2.NewSet("gin-pongo-templates", loader) tplSet := pongo2.NewSet("gin-pongo-templates", loader)
tplSet.Debug = isDebug tplSet.Debug = isDebug
return &Renderer{Context: make(pongo2.Context), tplSet: tplSet} return &Renderer{
globalContext: make(pongo2.Context),
tplSet: tplSet,
logger: logger,
}
} }
// Instance returns a new render.HTML instance for a single request. // SetGlobalContext 线程安全地设置全局上下文
func (p *Renderer) Instance(name string, data interface{}) render.Render { func (p *Renderer) SetGlobalContext(key string, value interface{}) {
var glob pongo2.Context p.mu.Lock()
if p.Context != nil { defer p.mu.Unlock()
glob = p.Context p.globalContext[key] = value
} }
// Warmup 预加载模板
func (p *Renderer) Warmup(templateNames ...string) error {
for _, name := range templateNames {
if _, err := p.tplSet.FromCache(name); err != nil {
return fmt.Errorf("failed to warmup template '%s': %w", name, err)
}
}
p.logger.WithField("count", len(templateNames)).Info("Templates warmed up")
return nil
}
func (p *Renderer) Instance(name string, data interface{}) render.Render {
// 安全读取全局上下文
p.mu.RLock()
glob := make(pongo2.Context, len(p.globalContext))
for k, v := range p.globalContext {
glob[k] = v
}
p.mu.RUnlock()
// 解析请求数据
var context pongo2.Context var context pongo2.Context
if data != nil { if data != nil {
if ginContext, ok := data.(gin.H); ok { switch v := data.(type) {
context = pongo2.Context(ginContext) case gin.H:
} else if pongoContext, ok := data.(pongo2.Context); ok { context = pongo2.Context(v)
context = pongoContext case pongo2.Context:
} else if m, ok := data.(map[string]interface{}); ok { context = v
context = m case map[string]interface{}:
} else { context = v
default:
context = make(pongo2.Context) context = make(pongo2.Context)
} }
} else { } else {
context = make(pongo2.Context) context = make(pongo2.Context)
} }
// 合并上下文(请求数据优先)
for k, v := range glob { for k, v := range glob {
if _, ok := context[k]; !ok { if _, exists := context[k]; !exists {
context[k] = v context[k] = v
} }
} }
// 加载模板
tpl, err := p.tplSet.FromCache(name) tpl, err := p.tplSet.FromCache(name)
if err != nil { if err != nil {
panic(fmt.Sprintf("Failed to load template '%s': %v", name, err)) p.logger.WithError(err).WithField("template", name).Error("Failed to load template")
return &ErrorHTML{
StatusCode: http.StatusInternalServerError,
Error: fmt.Errorf("template load error: %s", name),
}
} }
return &HTML{ return &HTML{
p: p,
Template: tpl, Template: tpl,
Name: name, Name: name,
Data: context, Data: context,
@@ -65,7 +98,6 @@ func (p *Renderer) Instance(name string, data interface{}) render.Render {
} }
type HTML struct { type HTML struct {
p *Renderer
Template *pongo2.Template Template *pongo2.Template
Name string Name string
Data pongo2.Context Data pongo2.Context
@@ -82,15 +114,31 @@ func (h *HTML) Render(w http.ResponseWriter) error {
} }
func (h *HTML) WriteContentType(w http.ResponseWriter) { func (h *HTML) WriteContentType(w http.ResponseWriter) {
header := w.Header() if w.Header().Get("Content-Type") == "" {
if val := header["Content-Type"]; len(val) == 0 { w.Header().Set("Content-Type", "text/html; charset=utf-8")
header["Content-Type"] = []string{"text/html; charset=utf-8"}
} }
} }
// ErrorHTML 错误渲染器
type ErrorHTML struct {
StatusCode int
Error error
}
func (e *ErrorHTML) Render(w http.ResponseWriter) error {
w.Header().Set("Content-Type", "text/plain; charset=utf-8")
w.WriteHeader(e.StatusCode)
_, err := w.Write([]byte(e.Error.Error()))
return err
}
func (e *ErrorHTML) WriteContentType(w http.ResponseWriter) {
w.Header().Set("Content-Type", "text/plain; charset=utf-8")
}
// C 获取或创建 pongo2 上下文
func C(ctx *gin.Context) pongo2.Context { func C(ctx *gin.Context) pongo2.Context {
p, exists := ctx.Get("pongo2") if p, exists := ctx.Get("pongo2"); exists {
if exists {
if pCtx, ok := p.(pongo2.Context); ok { if pCtx, ok := p.(pongo2.Context); ok {
return pCtx return pCtx
} }

View File

@@ -17,6 +17,7 @@ import (
"github.com/gin-contrib/cors" "github.com/gin-contrib/cors"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/sirupsen/logrus"
) )
func NewRouter( func NewRouter(
@@ -42,70 +43,214 @@ func NewRouter(
upstreamModule *upstream.Module, upstreamModule *upstream.Module,
proxyModule *proxy.Module, proxyModule *proxy.Module,
) *gin.Engine { ) *gin.Engine {
// === 1. 创建全局 Logger统一管理===
logger := createLogger(cfg)
// === 2. 设置 Gin 运行模式 ===
if cfg.Log.Level != "debug" { if cfg.Log.Level != "debug" {
gin.SetMode(gin.ReleaseMode) gin.SetMode(gin.ReleaseMode)
} }
router := gin.Default()
router.Static("/static", "./web/static") // === 3. 创建 Router使用 gin.New() 以便完全控制中间件)===
// CORS 配置 router := gin.New()
config := cors.Config{
// 允许前端的来源。在生产环境中,需改为实际域名
AllowOrigins: []string{"http://localhost:9000"},
AllowMethods: []string{"GET", "POST", "PUT", "PATCH", "DELETE", "OPTIONS"},
AllowHeaders: []string{"Origin", "Content-Type", "Accept", "Authorization"},
ExposeHeaders: []string{"Content-Length"},
AllowCredentials: true,
MaxAge: 12 * time.Hour,
}
router.Use(cors.New(config))
isDebug := gin.Mode() != gin.ReleaseMode
router.HTMLRender = pongo.New("web/templates", isDebug)
// --- 基础设施 --- // === 4. 注册全局中间件(按执行顺序)===
router.GET("/", func(c *gin.Context) { c.Redirect(http.StatusMovedPermanently, "/dashboard") }) setupGlobalMiddleware(router, logger)
router.GET("/health", func(c *gin.Context) { c.JSON(http.StatusOK, gin.H{"status": "ok"}) })
// --- 统一的认证管道 --- // === 5. 配置静态文件和模板 ===
apiAdminAuth := middleware.APIAdminAuthMiddleware(securityService) setupStaticAndTemplates(router, logger)
// === 6. 配置 CORS ===
setupCORS(router, cfg)
// === 7. 注册基础路由 ===
setupBasicRoutes(router)
// === 8. 创建认证中间件 ===
apiAdminAuth := middleware.APIAdminAuthMiddleware(securityService, logger)
webAdminAuth := middleware.WebAdminAuthMiddleware(securityService) webAdminAuth := middleware.WebAdminAuthMiddleware(securityService)
router.Use(gin.RecoveryWithWriter(os.Stdout)) // === 9. 注册业务路由(按功能分组)===
// --- 将正确的依赖和中间件管道传递下去 --- registerPublicAPIRoutes(router, apiAuthHandler, securityService, settingsManager, logger)
registerProxyRoutes(router, proxyHandler, securityService)
registerAdminRoutes(router, apiAdminAuth, keyGroupHandler, tokensHandler, apiKeyHandler, logHandler, settingHandler, dashboardHandler, taskHandler, upstreamModule, proxyModule)
registerPublicAPIRoutes(router, apiAuthHandler, securityService, settingsManager)
registerWebRoutes(router, webAdminAuth, webAuthHandler, pageHandler) registerWebRoutes(router, webAdminAuth, webAuthHandler, pageHandler)
registerAdminRoutes(router, apiAdminAuth, keyGroupHandler, tokensHandler, apiKeyHandler,
logHandler, settingHandler, dashboardHandler, taskHandler, upstreamModule, proxyModule)
registerProxyRoutes(router, proxyHandler, securityService, logger)
return router return router
} }
// ==================== 辅助函数 ====================
// createLogger 创建并配置全局 Logger
func createLogger(cfg *config.Config) *logrus.Logger {
logger := logrus.New()
// 设置日志格式
if cfg.Log.Format == "json" {
logger.SetFormatter(&logrus.JSONFormatter{
TimestampFormat: time.RFC3339,
})
} else {
logger.SetFormatter(&logrus.TextFormatter{
FullTimestamp: true,
TimestampFormat: "2006-01-02 15:04:05",
})
}
// 设置日志级别
switch cfg.Log.Level {
case "debug":
logger.SetLevel(logrus.DebugLevel)
case "info":
logger.SetLevel(logrus.InfoLevel)
case "warn":
logger.SetLevel(logrus.WarnLevel)
case "error":
logger.SetLevel(logrus.ErrorLevel)
default:
logger.SetLevel(logrus.InfoLevel)
}
// 设置输出(可选:输出到文件)
logger.SetOutput(os.Stdout)
return logger
}
// setupGlobalMiddleware 设置全局中间件
func setupGlobalMiddleware(router *gin.Engine, logger *logrus.Logger) {
// 1. 请求 ID 中间件(用于链路追踪)
router.Use(middleware.RequestIDMiddleware())
// 2. 数据脱敏中间件(在日志前执行)
router.Use(middleware.RedactionMiddleware())
// 3. 日志中间件
router.Use(middleware.LogrusLogger(logger))
// 4. 错误恢复中间件
router.Use(gin.RecoveryWithWriter(os.Stdout))
}
// setupStaticAndTemplates 配置静态文件和模板
func setupStaticAndTemplates(router *gin.Engine, logger *logrus.Logger) {
router.Static("/static", "./web/static")
isDebug := gin.Mode() != gin.ReleaseMode
router.HTMLRender = pongo.New("web/templates", isDebug, logger)
}
// setupCORS 配置 CORS
func setupCORS(router *gin.Engine, cfg *config.Config) {
corsConfig := cors.Config{
AllowOrigins: getCORSOrigins(cfg),
AllowMethods: []string{"GET", "POST", "PUT", "PATCH", "DELETE", "OPTIONS"},
AllowHeaders: []string{"Origin", "Content-Type", "Accept", "Authorization", "X-Request-Id"},
ExposeHeaders: []string{"Content-Length", "X-Request-Id"},
AllowCredentials: true,
MaxAge: 12 * time.Hour,
}
router.Use(cors.New(corsConfig))
}
// getCORSOrigins 获取 CORS 允许的来源
func getCORSOrigins(cfg *config.Config) []string {
// 默认值
origins := []string{"http://localhost:9000"}
// 从配置读取(修复:移除 nil 检查)
if len(cfg.Server.CORSOrigins) > 0 {
origins = cfg.Server.CORSOrigins
}
return origins
}
// setupBasicRoutes 设置基础路由
func setupBasicRoutes(router *gin.Engine) {
// 根路径重定向
router.GET("/", func(c *gin.Context) {
c.Redirect(http.StatusMovedPermanently, "/dashboard")
})
// 健康检查
router.GET("/health", handleHealthCheck)
// 版本信息(可选)
router.GET("/version", handleVersion)
}
// handleHealthCheck 健康检查处理器
func handleHealthCheck(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{
"status": "ok",
"time": time.Now().Unix(),
})
}
// handleVersion 版本信息处理器
func handleVersion(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{
"version": "1.0.0", // 可以从配置或编译时变量读取
"build": "latest",
})
}
// ==================== 路由注册函数 ====================
// registerProxyRoutes 注册代理路由
func registerProxyRoutes( func registerProxyRoutes(
router *gin.Engine, proxyHandler *handlers.ProxyHandler, securityService *service.SecurityService, router *gin.Engine,
proxyHandler *handlers.ProxyHandler,
securityService *service.SecurityService,
logger *logrus.Logger,
) { ) {
// 通用的代理认证中间件 // 创建代理认证中间件
proxyAuthMiddleware := middleware.ProxyAuthMiddleware(securityService) proxyAuthMiddleware := middleware.ProxyAuthMiddleware(securityService, logger)
// --- 模式一: 智能聚合模式 (根路径) ---
// /v1 和 /v1beta 路径作为默认入口,服务于 BasePool 聚合逻辑 // 模式一: 智能聚合模式(默认入口)
registerAggregateProxyRoutes(router, proxyHandler, proxyAuthMiddleware)
// 模式二: 精确路由模式(按组名路由)
registerGroupProxyRoutes(router, proxyHandler, proxyAuthMiddleware)
}
// registerAggregateProxyRoutes 注册聚合代理路由
func registerAggregateProxyRoutes(
router *gin.Engine,
proxyHandler *handlers.ProxyHandler,
authMiddleware gin.HandlerFunc,
) {
// /v1 路径组
v1 := router.Group("/v1") v1 := router.Group("/v1")
v1.Use(proxyAuthMiddleware) v1.Use(authMiddleware)
{ {
v1.Any("/*path", proxyHandler.HandleProxy) v1.Any("/*path", proxyHandler.HandleProxy)
} }
// /v1beta 路径组
v1beta := router.Group("/v1beta") v1beta := router.Group("/v1beta")
v1beta.Use(proxyAuthMiddleware) v1beta.Use(authMiddleware)
{ {
v1beta.Any("/*path", proxyHandler.HandleProxy) v1beta.Any("/*path", proxyHandler.HandleProxy)
} }
// --- 模式二: 精确路由模式 (/proxy/:group_name) --- }
// 创建一个新的、物理隔离的路由组,用于按组名精确路由
// registerGroupProxyRoutes 注册分组代理路由
func registerGroupProxyRoutes(
router *gin.Engine,
proxyHandler *handlers.ProxyHandler,
authMiddleware gin.HandlerFunc,
) {
proxyGroup := router.Group("/proxy/:group_name") proxyGroup := router.Group("/proxy/:group_name")
proxyGroup.Use(proxyAuthMiddleware) proxyGroup.Use(authMiddleware)
{ {
// 捕获所有子路径 (例如 /v1/chat/completions),并全部交给同一个 ProxyHandler。
proxyGroup.Any("/*path", proxyHandler.HandleProxy) proxyGroup.Any("/*path", proxyHandler.HandleProxy)
} }
} }
// registerAdminRoutes // registerAdminRoutes 注册管理后台 API 路由
func registerAdminRoutes( func registerAdminRoutes(
router *gin.Engine, router *gin.Engine,
authMiddleware gin.HandlerFunc, authMiddleware gin.HandlerFunc,
@@ -121,74 +266,112 @@ func registerAdminRoutes(
) { ) {
admin := router.Group("/admin", authMiddleware) admin := router.Group("/admin", authMiddleware)
{ {
// --- KeyGroup Base Routes --- // KeyGroup 路由
admin.POST("/keygroups", keyGroupHandler.CreateKeyGroup) registerKeyGroupRoutes(admin, keyGroupHandler, apiKeyHandler)
admin.GET("/keygroups", keyGroupHandler.GetKeyGroups)
admin.PUT("/keygroups/order", keyGroupHandler.UpdateKeyGroupOrder)
// --- KeyGroup Specific Routes (by :id) ---
admin.GET("/keygroups/:id", keyGroupHandler.GetKeyGroups)
admin.PUT("/keygroups/:id", keyGroupHandler.UpdateKeyGroup)
admin.DELETE("/keygroups/:id", keyGroupHandler.DeleteKeyGroup)
admin.POST("/keygroups/:id/clone", keyGroupHandler.CloneKeyGroup)
admin.GET("/keygroups/:id/stats", keyGroupHandler.GetKeyGroupStats)
admin.POST("/keygroups/:id/bulk-actions", apiKeyHandler.HandleBulkAction)
// --- APIKey Sub-resource Routes under a KeyGroup ---
keyGroupAPIKeys := admin.Group("/keygroups/:id/apikeys")
{
keyGroupAPIKeys.GET("", apiKeyHandler.ListKeysForGroup)
keyGroupAPIKeys.GET("/export", apiKeyHandler.ExportKeysForGroup)
keyGroupAPIKeys.POST("/bulk", apiKeyHandler.AddMultipleKeysToGroup)
keyGroupAPIKeys.DELETE("/bulk", apiKeyHandler.UnlinkMultipleKeysFromGroup)
keyGroupAPIKeys.POST("/test", apiKeyHandler.TestKeysForGroup)
keyGroupAPIKeys.PUT("/:keyId", apiKeyHandler.UpdateGroupAPIKeyMapping)
}
// Global key operations // APIKey 全局路由
admin.GET("/apikeys", apiKeyHandler.ListAPIKeys) registerAPIKeyRoutes(admin, apiKeyHandler)
// admin.PUT("/apikeys/:id", apiKeyHandler.UpdateAPIKey) // DEPRECATED: Status is now contextual
admin.POST("/apikeys/test", apiKeyHandler.TestMultipleKeys) // Test keys globally
admin.DELETE("/apikeys/:id", apiKeyHandler.HardDeleteAPIKey) // Hard delete a single key
admin.DELETE("/apikeys/bulk", apiKeyHandler.HardDeleteMultipleKeys) // Hard delete multiple keys
admin.PUT("/apikeys/bulk/restore", apiKeyHandler.RestoreMultipleKeys) // Restore multiple keys globally
// --- Global Routes --- // 系统管理路由
admin.GET("/tokens", tokensHandler.GetAllTokens) registerSystemRoutes(admin, tokensHandler, logHandler, settingHandler, taskHandler)
admin.PUT("/tokens", tokensHandler.UpdateTokens)
admin.GET("/logs", logHandler.GetLogs)
admin.GET("/settings", settingHandler.GetSettings)
admin.PUT("/settings", settingHandler.UpdateSettings)
admin.PUT("/settings/reset", settingHandler.ResetSettingsToDefaults)
// 用于查询异步任务的状态 // 仪表盘路由
admin.GET("/tasks/:id", taskHandler.GetTaskStatus) registerDashboardRoutes(admin, dashboardHandler)
// 领域模块 // 领域模块路由
upstreamModule.RegisterRoutes(admin) upstreamModule.RegisterRoutes(admin)
proxyModule.RegisterRoutes(admin) proxyModule.RegisterRoutes(admin)
// --- 全局仪表盘路由 ---
dashboard := admin.Group("/dashboard")
{
dashboard.GET("/overview", dashboardHandler.GetOverview)
dashboard.GET("/chart", dashboardHandler.GetChart)
dashboard.GET("/stats/:period", dashboardHandler.GetRequestStats) // 点击详情
}
} }
} }
// registerWebRoutes // registerKeyGroupRoutes 注册 KeyGroup 相关路由
func registerKeyGroupRoutes(
admin *gin.RouterGroup,
keyGroupHandler *handlers.KeyGroupHandler,
apiKeyHandler *handlers.APIKeyHandler,
) {
// 基础路由
admin.POST("/keygroups", keyGroupHandler.CreateKeyGroup)
admin.GET("/keygroups", keyGroupHandler.GetKeyGroups)
admin.PUT("/keygroups/order", keyGroupHandler.UpdateKeyGroupOrder)
// 特定 KeyGroup 路由
admin.GET("/keygroups/:id", keyGroupHandler.GetKeyGroups)
admin.PUT("/keygroups/:id", keyGroupHandler.UpdateKeyGroup)
admin.DELETE("/keygroups/:id", keyGroupHandler.DeleteKeyGroup)
admin.POST("/keygroups/:id/clone", keyGroupHandler.CloneKeyGroup)
admin.GET("/keygroups/:id/stats", keyGroupHandler.GetKeyGroupStats)
admin.POST("/keygroups/:id/bulk-actions", apiKeyHandler.HandleBulkAction)
// KeyGroup 的 APIKey 子资源
keyGroupAPIKeys := admin.Group("/keygroups/:id/apikeys")
{
keyGroupAPIKeys.GET("", apiKeyHandler.ListKeysForGroup)
keyGroupAPIKeys.GET("/export", apiKeyHandler.ExportKeysForGroup)
keyGroupAPIKeys.POST("/bulk", apiKeyHandler.AddMultipleKeysToGroup)
keyGroupAPIKeys.DELETE("/bulk", apiKeyHandler.UnlinkMultipleKeysFromGroup)
keyGroupAPIKeys.POST("/test", apiKeyHandler.TestKeysForGroup)
keyGroupAPIKeys.PUT("/:keyId", apiKeyHandler.UpdateGroupAPIKeyMapping)
}
}
// registerAPIKeyRoutes 注册 APIKey 全局路由
func registerAPIKeyRoutes(admin *gin.RouterGroup, apiKeyHandler *handlers.APIKeyHandler) {
admin.GET("/apikeys", apiKeyHandler.ListAPIKeys)
admin.POST("/apikeys/test", apiKeyHandler.TestMultipleKeys)
admin.DELETE("/apikeys/:id", apiKeyHandler.HardDeleteAPIKey)
admin.DELETE("/apikeys/bulk", apiKeyHandler.HardDeleteMultipleKeys)
admin.PUT("/apikeys/bulk/restore", apiKeyHandler.RestoreMultipleKeys)
}
// registerSystemRoutes 注册系统管理路由
func registerSystemRoutes(
admin *gin.RouterGroup,
tokensHandler *handlers.TokensHandler,
logHandler *handlers.LogHandler,
settingHandler *handlers.SettingHandler,
taskHandler *handlers.TaskHandler,
) {
// Token 管理
admin.GET("/tokens", tokensHandler.GetAllTokens)
admin.PUT("/tokens", tokensHandler.UpdateTokens)
// 日志管理
admin.GET("/logs", logHandler.GetLogs)
// 设置管理
admin.GET("/settings", settingHandler.GetSettings)
admin.PUT("/settings", settingHandler.UpdateSettings)
admin.PUT("/settings/reset", settingHandler.ResetSettingsToDefaults)
// 任务管理
admin.GET("/tasks/:id", taskHandler.GetTaskStatus)
}
// registerDashboardRoutes 注册仪表盘路由
func registerDashboardRoutes(admin *gin.RouterGroup, dashboardHandler *handlers.DashboardHandler) {
dashboard := admin.Group("/dashboard")
{
dashboard.GET("/overview", dashboardHandler.GetOverview)
dashboard.GET("/chart", dashboardHandler.GetChart)
dashboard.GET("/stats/:period", dashboardHandler.GetRequestStats)
}
}
// registerWebRoutes 注册 Web 页面路由
func registerWebRoutes( func registerWebRoutes(
router *gin.Engine, router *gin.Engine,
authMiddleware gin.HandlerFunc, authMiddleware gin.HandlerFunc,
webAuthHandler *webhandlers.WebAuthHandler, webAuthHandler *webhandlers.WebAuthHandler,
pageHandler *webhandlers.PageHandler, pageHandler *webhandlers.PageHandler,
) { ) {
// 公开的认证路由
router.GET("/login", webAuthHandler.ShowLoginPage) router.GET("/login", webAuthHandler.ShowLoginPage)
router.POST("/login", webAuthHandler.HandleLogin) router.POST("/login", webAuthHandler.HandleLogin)
router.GET("/logout", webAuthHandler.HandleLogout) router.GET("/logout", webAuthHandler.HandleLogout)
// For Test only router.Run("127.0.0.1:9000")
// 受保护的Admin Web界面 // 受保护的管理界面
webGroup := router.Group("/", authMiddleware) webGroup := router.Group("/", authMiddleware)
webGroup.Use(authMiddleware)
{ {
webGroup.GET("/keys", pageHandler.ShowKeysPage) webGroup.GET("/keys", pageHandler.ShowKeysPage)
webGroup.GET("/settings", pageHandler.ShowConfigEditorPage) webGroup.GET("/settings", pageHandler.ShowConfigEditorPage)
@@ -197,14 +380,31 @@ func registerWebRoutes(
webGroup.GET("/tasks", pageHandler.ShowTasksPage) webGroup.GET("/tasks", pageHandler.ShowTasksPage)
webGroup.GET("/chat", pageHandler.ShowChatPage) webGroup.GET("/chat", pageHandler.ShowChatPage)
} }
} }
// registerPublicAPIRoutes 无需后台登录的公共API路由 // registerPublicAPIRoutes 注册公共 API 路由
func registerPublicAPIRoutes(router *gin.Engine, apiAuthHandler *handlers.APIAuthHandler, securityService *service.SecurityService, settingsManager *settings.SettingsManager) { func registerPublicAPIRoutes(
ipBanMiddleware := middleware.IPBanMiddleware(securityService, settingsManager) router *gin.Engine,
publicAPIGroup := router.Group("/api") apiAuthHandler *handlers.APIAuthHandler,
securityService *service.SecurityService,
settingsManager *settings.SettingsManager,
logger *logrus.Logger,
) {
// 创建 IP 封禁中间件
ipBanCache := middleware.NewIPBanCache()
ipBanMiddleware := middleware.IPBanMiddleware(
securityService,
settingsManager,
ipBanCache,
logger,
)
// 公共 API 路由组
publicAPI := router.Group("/api")
{ {
publicAPIGroup.POST("/login", ipBanMiddleware, apiAuthHandler.HandleLogin) publicAPI.POST("/login", ipBanMiddleware, apiAuthHandler.HandleLogin)
// 可以在这里添加其他公共 API 路由
// publicAPI.POST("/register", ipBanMiddleware, apiAuthHandler.HandleRegister)
// publicAPI.POST("/forgot-password", ipBanMiddleware, apiAuthHandler.HandleForgotPassword)
} }
} }

View File

@@ -5,93 +5,179 @@ import (
"context" "context"
"encoding/json" "encoding/json"
"fmt" "fmt"
"gemini-balancer/internal/db/dialect"
"gemini-balancer/internal/models"
"gemini-balancer/internal/store"
"strconv" "strconv"
"strings" "strings"
"sync" "sync"
"sync/atomic"
"time" "time"
"gemini-balancer/internal/db/dialect"
"gemini-balancer/internal/models"
"gemini-balancer/internal/settings"
"gemini-balancer/internal/store"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
"gorm.io/gorm" "gorm.io/gorm"
) )
const ( const (
flushLoopInterval = 1 * time.Minute defaultFlushInterval = 1 * time.Minute
maxRetryAttempts = 3
retryDelay = 5 * time.Second
) )
type AnalyticsServiceLogger struct{ *logrus.Entry }
type AnalyticsService struct { type AnalyticsService struct {
db *gorm.DB db *gorm.DB
store store.Store store store.Store
logger *logrus.Entry logger *logrus.Entry
dialect dialect.DialectAdapter
settingsManager *settings.SettingsManager
stopChan chan struct{} stopChan chan struct{}
wg sync.WaitGroup wg sync.WaitGroup
dialect dialect.DialectAdapter ctx context.Context
cancel context.CancelFunc
// 统计指标
eventsReceived atomic.Uint64
eventsProcessed atomic.Uint64
eventsFailed atomic.Uint64
flushCount atomic.Uint64
recordsFlushed atomic.Uint64
flushErrors atomic.Uint64
lastFlushTime time.Time
lastFlushMutex sync.RWMutex
// 运行时配置
flushInterval time.Duration
configMutex sync.RWMutex
} }
func NewAnalyticsService(db *gorm.DB, s store.Store, logger *logrus.Logger, d dialect.DialectAdapter) *AnalyticsService { func NewAnalyticsService(
db *gorm.DB,
s store.Store,
logger *logrus.Logger,
d dialect.DialectAdapter,
settingsManager *settings.SettingsManager,
) *AnalyticsService {
ctx, cancel := context.WithCancel(context.Background())
return &AnalyticsService{ return &AnalyticsService{
db: db, db: db,
store: s, store: s,
logger: logger.WithField("component", "Analytics📊"), logger: logger.WithField("component", "Analytics📊"),
stopChan: make(chan struct{}), dialect: d,
dialect: d, settingsManager: settingsManager,
stopChan: make(chan struct{}),
ctx: ctx,
cancel: cancel,
flushInterval: defaultFlushInterval,
lastFlushTime: time.Now(),
} }
} }
func (s *AnalyticsService) Start() { func (s *AnalyticsService) Start() {
s.wg.Add(2) s.wg.Add(3)
go s.flushLoop()
go s.eventListener() go s.eventListener()
s.logger.Info("AnalyticsService (Command Side) started.") go s.flushLoop()
go s.metricsReporter()
s.logger.WithFields(logrus.Fields{
"flush_interval": s.flushInterval,
}).Info("AnalyticsService started")
} }
func (s *AnalyticsService) Stop() { func (s *AnalyticsService) Stop() {
s.logger.Info("AnalyticsService stopping...")
close(s.stopChan) close(s.stopChan)
s.cancel()
s.wg.Wait() s.wg.Wait()
s.logger.Info("AnalyticsService stopped. Performing final data flush...")
s.logger.Info("Performing final data flush...")
s.flushToDB() s.flushToDB()
s.logger.Info("AnalyticsService final data flush completed.")
// 输出最终统计
s.logger.WithFields(logrus.Fields{
"events_received": s.eventsReceived.Load(),
"events_processed": s.eventsProcessed.Load(),
"events_failed": s.eventsFailed.Load(),
"flush_count": s.flushCount.Load(),
"records_flushed": s.recordsFlushed.Load(),
"flush_errors": s.flushErrors.Load(),
}).Info("AnalyticsService stopped")
} }
// 事件监听循环
func (s *AnalyticsService) eventListener() { func (s *AnalyticsService) eventListener() {
defer s.wg.Done() defer s.wg.Done()
sub, err := s.store.Subscribe(context.Background(), models.TopicRequestFinished)
sub, err := s.store.Subscribe(s.ctx, models.TopicRequestFinished)
if err != nil { if err != nil {
s.logger.Fatalf("Failed to subscribe to topic %s: %v", models.TopicRequestFinished, err) s.logger.WithError(err).Error("Failed to subscribe to request events, analytics disabled")
return return
} }
defer sub.Close() defer func() {
s.logger.Info("AnalyticsService subscribed to request events.") if err := sub.Close(); err != nil {
s.logger.WithError(err).Warn("Failed to close subscription")
}
}()
s.logger.Info("Subscribed to request events for analytics")
for { for {
select { select {
case msg := <-sub.Channel(): case msg := <-sub.Channel():
var event models.RequestFinishedEvent s.handleMessage(msg)
if err := json.Unmarshal(msg.Payload, &event); err != nil {
s.logger.Errorf("Failed to unmarshal analytics event: %v", err)
continue
}
s.handleAnalyticsEvent(&event)
case <-s.stopChan: case <-s.stopChan:
s.logger.Info("AnalyticsService stopping event listener.") s.logger.Info("Event listener stopping")
return
case <-s.ctx.Done():
s.logger.Info("Event listener context cancelled")
return return
} }
} }
} }
func (s *AnalyticsService) handleAnalyticsEvent(event *models.RequestFinishedEvent) { // 处理单条消息
if event.RequestLog.GroupID == nil { func (s *AnalyticsService) handleMessage(msg *store.Message) {
var event models.RequestFinishedEvent
if err := json.Unmarshal(msg.Payload, &event); err != nil {
s.logger.WithError(err).Error("Failed to unmarshal analytics event")
s.eventsFailed.Add(1)
return return
} }
ctx := context.Background()
key := fmt.Sprintf("analytics:hourly:%s", time.Now().UTC().Format("2006-01-02T15")) s.eventsReceived.Add(1)
if err := s.handleAnalyticsEvent(&event); err != nil {
s.eventsFailed.Add(1)
s.logger.WithFields(logrus.Fields{
"correlation_id": event.CorrelationID,
"group_id": event.RequestLog.GroupID,
}).WithError(err).Warn("Failed to process analytics event")
} else {
s.eventsProcessed.Add(1)
}
}
// 处理分析事件
func (s *AnalyticsService) handleAnalyticsEvent(event *models.RequestFinishedEvent) error {
if event.RequestLog.GroupID == nil {
return nil // 跳过无 GroupID 的事件
}
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
now := time.Now().UTC()
key := fmt.Sprintf("analytics:hourly:%s", now.Format("2006-01-02T15"))
fieldPrefix := fmt.Sprintf("%d:%s", *event.RequestLog.GroupID, event.RequestLog.ModelName) fieldPrefix := fmt.Sprintf("%d:%s", *event.RequestLog.GroupID, event.RequestLog.ModelName)
pipe := s.store.Pipeline(ctx) pipe := s.store.Pipeline(ctx)
pipe.HIncrBy(key, fieldPrefix+":requests", 1) pipe.HIncrBy(key, fieldPrefix+":requests", 1)
if event.RequestLog.IsSuccess { if event.RequestLog.IsSuccess {
pipe.HIncrBy(key, fieldPrefix+":success", 1) pipe.HIncrBy(key, fieldPrefix+":success", 1)
} }
@@ -101,80 +187,213 @@ func (s *AnalyticsService) handleAnalyticsEvent(event *models.RequestFinishedEve
if event.RequestLog.CompletionTokens > 0 { if event.RequestLog.CompletionTokens > 0 {
pipe.HIncrBy(key, fieldPrefix+":completion", int64(event.RequestLog.CompletionTokens)) pipe.HIncrBy(key, fieldPrefix+":completion", int64(event.RequestLog.CompletionTokens))
} }
// 设置过期时间保留48小时
pipe.Expire(key, 48*time.Hour)
if err := pipe.Exec(); err != nil { if err := pipe.Exec(); err != nil {
s.logger.Warnf("[%s] Failed to record analytics event to store for group %d: %v", event.CorrelationID, *event.RequestLog.GroupID, err) return fmt.Errorf("redis pipeline failed: %w", err)
} }
return nil
} }
// 刷新循环
func (s *AnalyticsService) flushLoop() { func (s *AnalyticsService) flushLoop() {
defer s.wg.Done() defer s.wg.Done()
ticker := time.NewTicker(flushLoopInterval)
s.configMutex.RLock()
interval := s.flushInterval
s.configMutex.RUnlock()
ticker := time.NewTicker(interval)
defer ticker.Stop() defer ticker.Stop()
s.logger.WithField("interval", interval).Info("Flush loop started")
for { for {
select { select {
case <-ticker.C: case <-ticker.C:
s.flushToDB() s.flushToDB()
case <-s.stopChan: case <-s.stopChan:
s.logger.Info("Flush loop stopping")
return
case <-s.ctx.Done():
return return
} }
} }
} }
// 刷写到数据库
func (s *AnalyticsService) flushToDB() { func (s *AnalyticsService) flushToDB() {
ctx := context.Background() start := time.Now()
ctx, cancel := context.WithTimeout(context.Background(), 60*time.Second)
defer cancel()
now := time.Now().UTC() now := time.Now().UTC()
keysToFlush := []string{ keysToFlush := s.generateFlushKeys(now)
fmt.Sprintf("analytics:hourly:%s", now.Add(-1*time.Hour).Format("2006-01-02T15")),
fmt.Sprintf("analytics:hourly:%s", now.Format("2006-01-02T15")), totalRecords := 0
} totalErrors := 0
for _, key := range keysToFlush { for _, key := range keysToFlush {
data, err := s.store.HGetAll(ctx, key) records, err := s.flushSingleKey(ctx, key, now)
if err != nil || len(data) == 0 { if err != nil {
continue s.logger.WithError(err).WithField("key", key).Error("Failed to flush key")
totalErrors++
s.flushErrors.Add(1)
} else {
totalRecords += records
} }
}
statsToFlush, parsedFields := s.parseStatsFromHash(now.Truncate(time.Hour), data) s.recordsFlushed.Add(uint64(totalRecords))
s.flushCount.Add(1)
if len(statsToFlush) > 0 { s.lastFlushMutex.Lock()
upsertClause := s.dialect.OnConflictUpdateAll( s.lastFlushTime = time.Now()
[]string{"time", "group_id", "model_name"}, s.lastFlushMutex.Unlock()
[]string{"request_count", "success_count", "prompt_tokens", "completion_tokens"},
) duration := time.Since(start)
err := s.db.WithContext(ctx).Clauses(upsertClause).Create(&statsToFlush).Error
if err != nil { if totalRecords > 0 || totalErrors > 0 {
s.logger.Errorf("Failed to flush analytics data for key %s: %v", key, err) s.logger.WithFields(logrus.Fields{
} else { "records_flushed": totalRecords,
s.logger.Infof("Successfully flushed %d records from key %s.", len(statsToFlush), key) "keys_processed": len(keysToFlush),
_ = s.store.HDel(ctx, key, parsedFields...) "errors": totalErrors,
} "duration": duration,
} }).Info("Analytics data flush completed")
} else {
s.logger.WithField("duration", duration).Debug("Analytics flush completed (no data)")
} }
} }
// 生成需要刷新的 Redis 键
func (s *AnalyticsService) generateFlushKeys(now time.Time) []string {
keys := make([]string, 0, 4)
// 当前小时
keys = append(keys, fmt.Sprintf("analytics:hourly:%s", now.Format("2006-01-02T15")))
// 前3个小时处理延迟和时区问题
for i := 1; i <= 3; i++ {
pastHour := now.Add(-time.Duration(i) * time.Hour)
keys = append(keys, fmt.Sprintf("analytics:hourly:%s", pastHour.Format("2006-01-02T15")))
}
return keys
}
// 刷写单个 Redis 键
func (s *AnalyticsService) flushSingleKey(ctx context.Context, key string, baseTime time.Time) (int, error) {
data, err := s.store.HGetAll(ctx, key)
if err != nil {
return 0, fmt.Errorf("failed to get hash data: %w", err)
}
if len(data) == 0 {
return 0, nil // 无数据,跳过
}
// 解析时间戳
hourStr := strings.TrimPrefix(key, "analytics:hourly:")
recordTime, err := time.Parse("2006-01-02T15", hourStr)
if err != nil {
s.logger.WithError(err).WithField("key", key).Warn("Failed to parse time from key")
recordTime = baseTime.Truncate(time.Hour)
}
statsToFlush, parsedFields := s.parseStatsFromHash(recordTime, data)
if len(statsToFlush) == 0 {
return 0, nil
}
// 使用事务 + 重试机制
var dbErr error
for attempt := 1; attempt <= maxRetryAttempts; attempt++ {
dbErr = s.upsertStatsWithTransaction(ctx, statsToFlush)
if dbErr == nil {
break
}
if attempt < maxRetryAttempts {
s.logger.WithFields(logrus.Fields{
"attempt": attempt,
"key": key,
}).WithError(dbErr).Warn("Database upsert failed, retrying...")
time.Sleep(retryDelay)
}
}
if dbErr != nil {
return 0, fmt.Errorf("failed to upsert after %d attempts: %w", maxRetryAttempts, dbErr)
}
// 删除已处理的字段
if len(parsedFields) > 0 {
if err := s.store.HDel(ctx, key, parsedFields...); err != nil {
s.logger.WithError(err).WithField("key", key).Warn("Failed to delete flushed fields from Redis")
}
}
return len(statsToFlush), nil
}
// 使用事务批量 upsert
func (s *AnalyticsService) upsertStatsWithTransaction(ctx context.Context, stats []models.StatsHourly) error {
return s.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
upsertClause := s.dialect.OnConflictUpdateAll(
[]string{"time", "group_id", "model_name"},
[]string{"request_count", "success_count", "prompt_tokens", "completion_tokens"},
)
return tx.Clauses(upsertClause).Create(&stats).Error
})
}
// 解析 Redis Hash 数据
func (s *AnalyticsService) parseStatsFromHash(t time.Time, data map[string]string) ([]models.StatsHourly, []string) { func (s *AnalyticsService) parseStatsFromHash(t time.Time, data map[string]string) ([]models.StatsHourly, []string) {
tempAggregator := make(map[string]*models.StatsHourly) tempAggregator := make(map[string]*models.StatsHourly)
var parsedFields []string parsedFields := make([]string, 0, len(data))
for field, valueStr := range data { for field, valueStr := range data {
parts := strings.Split(field, ":") parts := strings.Split(field, ":")
if len(parts) != 3 { if len(parts) != 3 {
s.logger.WithField("field", field).Warn("Invalid field format")
continue continue
} }
groupIDStr, modelName, counterType := parts[0], parts[1], parts[2]
groupIDStr, modelName, counterType := parts[0], parts[1], parts[2]
aggKey := groupIDStr + ":" + modelName aggKey := groupIDStr + ":" + modelName
if _, ok := tempAggregator[aggKey]; !ok { if _, ok := tempAggregator[aggKey]; !ok {
gid, err := strconv.Atoi(groupIDStr) gid, err := strconv.Atoi(groupIDStr)
if err != nil { if err != nil {
s.logger.WithFields(logrus.Fields{
"field": field,
"group_id": groupIDStr,
}).Warn("Invalid group ID")
continue continue
} }
tempAggregator[aggKey] = &models.StatsHourly{ tempAggregator[aggKey] = &models.StatsHourly{
Time: t, Time: t,
GroupID: uint(gid), GroupID: uint(gid),
ModelName: modelName, ModelName: modelName,
} }
} }
val, _ := strconv.ParseInt(valueStr, 10, 64)
val, err := strconv.ParseInt(valueStr, 10, 64)
if err != nil {
s.logger.WithFields(logrus.Fields{
"field": field,
"value": valueStr,
}).Warn("Invalid counter value")
continue
}
switch counterType { switch counterType {
case "requests": case "requests":
tempAggregator[aggKey].RequestCount = val tempAggregator[aggKey].RequestCount = val
@@ -184,14 +403,92 @@ func (s *AnalyticsService) parseStatsFromHash(t time.Time, data map[string]strin
tempAggregator[aggKey].PromptTokens = val tempAggregator[aggKey].PromptTokens = val
case "completion": case "completion":
tempAggregator[aggKey].CompletionTokens = val tempAggregator[aggKey].CompletionTokens = val
default:
s.logger.WithField("counter_type", counterType).Warn("Unknown counter type")
continue
} }
parsedFields = append(parsedFields, field) parsedFields = append(parsedFields, field)
} }
var result []models.StatsHourly
result := make([]models.StatsHourly, 0, len(tempAggregator))
for _, stats := range tempAggregator { for _, stats := range tempAggregator {
if stats.RequestCount > 0 { if stats.RequestCount > 0 {
result = append(result, *stats) result = append(result, *stats)
} }
} }
return result, parsedFields return result, parsedFields
} }
// 定期输出统计信息
func (s *AnalyticsService) metricsReporter() {
defer s.wg.Done()
ticker := time.NewTicker(5 * time.Minute)
defer ticker.Stop()
for {
select {
case <-ticker.C:
s.reportMetrics()
case <-s.stopChan:
return
case <-s.ctx.Done():
return
}
}
}
func (s *AnalyticsService) reportMetrics() {
s.lastFlushMutex.RLock()
lastFlush := s.lastFlushTime
s.lastFlushMutex.RUnlock()
received := s.eventsReceived.Load()
processed := s.eventsProcessed.Load()
failed := s.eventsFailed.Load()
var successRate float64
if received > 0 {
successRate = float64(processed) / float64(received) * 100
}
s.logger.WithFields(logrus.Fields{
"events_received": received,
"events_processed": processed,
"events_failed": failed,
"success_rate": fmt.Sprintf("%.2f%%", successRate),
"flush_count": s.flushCount.Load(),
"records_flushed": s.recordsFlushed.Load(),
"flush_errors": s.flushErrors.Load(),
"last_flush_ago": time.Since(lastFlush).Round(time.Second),
}).Info("Analytics metrics")
}
// GetMetrics 返回当前统计指标(供监控使用)
func (s *AnalyticsService) GetMetrics() map[string]interface{} {
s.lastFlushMutex.RLock()
lastFlush := s.lastFlushTime
s.lastFlushMutex.RUnlock()
received := s.eventsReceived.Load()
processed := s.eventsProcessed.Load()
var successRate float64
if received > 0 {
successRate = float64(processed) / float64(received) * 100
}
return map[string]interface{}{
"events_received": received,
"events_processed": processed,
"events_failed": s.eventsFailed.Load(),
"success_rate": successRate,
"flush_count": s.flushCount.Load(),
"records_flushed": s.recordsFlushed.Load(),
"flush_errors": s.flushErrors.Load(),
"last_flush_ago": time.Since(lastFlush).Seconds(),
"flush_interval": s.flushInterval.Seconds(),
}
}

View File

@@ -4,158 +4,297 @@ package service
import ( import (
"context" "context"
"fmt" "fmt"
"strconv"
"sync"
"sync/atomic"
"time"
"gemini-balancer/internal/models" "gemini-balancer/internal/models"
"gemini-balancer/internal/store" "gemini-balancer/internal/store"
"gemini-balancer/internal/syncer" "gemini-balancer/internal/syncer"
"strconv"
"time"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
"gorm.io/gorm" "gorm.io/gorm"
) )
const overviewCacheChannel = "syncer:cache:dashboard_overview" const (
overviewCacheChannel = "syncer:cache:dashboard_overview"
defaultChartDays = 7
cacheLoadTimeout = 30 * time.Second
)
var (
// 图表颜色调色板
chartColorPalette = []string{
"#FF6384", "#36A2EB", "#FFCE56", "#4BC0C0",
"#9966FF", "#FF9F40", "#C9CBCF", "#4D5360",
}
)
type DashboardQueryService struct { type DashboardQueryService struct {
db *gorm.DB db *gorm.DB
store store.Store store store.Store
overviewSyncer *syncer.CacheSyncer[*models.DashboardStatsResponse] overviewSyncer *syncer.CacheSyncer[*models.DashboardStatsResponse]
logger *logrus.Entry logger *logrus.Entry
stopChan chan struct{}
stopChan chan struct{}
wg sync.WaitGroup
ctx context.Context
cancel context.CancelFunc
// 统计指标
queryCount atomic.Uint64
cacheHits atomic.Uint64
cacheMisses atomic.Uint64
overviewLoadCount atomic.Uint64
lastQueryTime time.Time
lastQueryMutex sync.RWMutex
} }
func NewDashboardQueryService(db *gorm.DB, s store.Store, logger *logrus.Logger) (*DashboardQueryService, error) { func NewDashboardQueryService(
qs := &DashboardQueryService{ db *gorm.DB,
db: db, s store.Store,
store: s, logger *logrus.Logger,
logger: logger.WithField("component", "DashboardQueryService"), ) (*DashboardQueryService, error) {
stopChan: make(chan struct{}), ctx, cancel := context.WithCancel(context.Background())
service := &DashboardQueryService{
db: db,
store: s,
logger: logger.WithField("component", "DashboardQuery📈"),
stopChan: make(chan struct{}),
ctx: ctx,
cancel: cancel,
lastQueryTime: time.Now(),
} }
loader := qs.loadOverviewData // 创建 CacheSyncer
overviewSyncer, err := syncer.NewCacheSyncer(loader, s, overviewCacheChannel) overviewSyncer, err := syncer.NewCacheSyncer(
service.loadOverviewData,
s,
overviewCacheChannel,
logger,
)
if err != nil { if err != nil {
cancel()
return nil, fmt.Errorf("failed to create overview cache syncer: %w", err) return nil, fmt.Errorf("failed to create overview cache syncer: %w", err)
} }
qs.overviewSyncer = overviewSyncer service.overviewSyncer = overviewSyncer
return qs, nil
return service, nil
} }
func (s *DashboardQueryService) Start() { func (s *DashboardQueryService) Start() {
s.wg.Add(2)
go s.eventListener() go s.eventListener()
s.logger.Info("DashboardQueryService started and listening for invalidation events.") go s.metricsReporter()
s.logger.Info("DashboardQueryService started")
} }
func (s *DashboardQueryService) Stop() { func (s *DashboardQueryService) Stop() {
s.logger.Info("DashboardQueryService stopping...")
close(s.stopChan) close(s.stopChan)
s.logger.Info("DashboardQueryService and its CacheSyncer have been stopped.") s.cancel()
s.wg.Wait()
// 输出最终统计
s.logger.WithFields(logrus.Fields{
"total_queries": s.queryCount.Load(),
"cache_hits": s.cacheHits.Load(),
"cache_misses": s.cacheMisses.Load(),
"overview_loads": s.overviewLoadCount.Load(),
}).Info("DashboardQueryService stopped")
} }
// ==================== 核心查询方法 ====================
// GetDashboardOverviewData 获取仪表盘概览数据(带缓存)
func (s *DashboardQueryService) GetDashboardOverviewData() (*models.DashboardStatsResponse, error) {
s.queryCount.Add(1)
cachedDataPtr := s.overviewSyncer.Get()
if cachedDataPtr == nil {
s.cacheMisses.Add(1)
s.logger.Warn("Overview cache is empty, attempting to load...")
// 触发立即加载
if err := s.overviewSyncer.Invalidate(); err != nil {
return nil, fmt.Errorf("failed to trigger cache reload: %w", err)
}
// 等待加载完成最多30秒
ctx, cancel := context.WithTimeout(context.Background(), cacheLoadTimeout)
defer cancel()
ticker := time.NewTicker(100 * time.Millisecond)
defer ticker.Stop()
for {
select {
case <-ticker.C:
if data := s.overviewSyncer.Get(); data != nil {
s.cacheHits.Add(1)
return data, nil
}
case <-ctx.Done():
return nil, fmt.Errorf("timeout waiting for overview cache to load")
}
}
}
s.cacheHits.Add(1)
return cachedDataPtr, nil
}
// InvalidateOverviewCache 手动失效概览缓存
func (s *DashboardQueryService) InvalidateOverviewCache() error {
s.logger.Info("Manually invalidating overview cache")
return s.overviewSyncer.Invalidate()
}
// GetGroupStats 获取指定分组的统计数据
func (s *DashboardQueryService) GetGroupStats(ctx context.Context, groupID uint) (map[string]any, error) { func (s *DashboardQueryService) GetGroupStats(ctx context.Context, groupID uint) (map[string]any, error) {
s.queryCount.Add(1)
s.updateLastQueryTime()
start := time.Now()
// 1. 从 Redis 获取 Key 统计
statsKey := fmt.Sprintf("stats:group:%d", groupID) statsKey := fmt.Sprintf("stats:group:%d", groupID)
keyStatsMap, err := s.store.HGetAll(ctx, statsKey) keyStatsMap, err := s.store.HGetAll(ctx, statsKey)
if err != nil { if err != nil {
s.logger.WithError(err).Errorf("Failed to get key stats from cache for group %d", groupID) s.logger.WithError(err).Errorf("Failed to get key stats for group %d", groupID)
return nil, fmt.Errorf("failed to get key stats from cache: %w", err) return nil, fmt.Errorf("failed to get key stats from cache: %w", err)
} }
keyStats := make(map[string]int64) keyStats := make(map[string]int64)
for k, v := range keyStatsMap { for k, v := range keyStatsMap {
val, _ := strconv.ParseInt(v, 10, 64) val, _ := strconv.ParseInt(v, 10, 64)
keyStats[k] = val keyStats[k] = val
} }
now := time.Now()
// 2. 查询请求统计(使用 UTC 时间)
now := time.Now().UTC()
oneHourAgo := now.Add(-1 * time.Hour) oneHourAgo := now.Add(-1 * time.Hour)
twentyFourHoursAgo := now.Add(-24 * time.Hour) twentyFourHoursAgo := now.Add(-24 * time.Hour)
type requestStatsResult struct { type requestStatsResult struct {
TotalRequests int64 TotalRequests int64
SuccessRequests int64 SuccessRequests int64
} }
var last1Hour, last24Hours requestStatsResult var last1Hour, last24Hours requestStatsResult
s.db.WithContext(ctx).Model(&models.StatsHourly{}).
Select("SUM(request_count) as total_requests, SUM(success_count) as success_requests"). // 并发查询优化
Where("group_id = ? AND time >= ?", groupID, oneHourAgo). var wg sync.WaitGroup
Scan(&last1Hour) errChan := make(chan error, 2)
s.db.WithContext(ctx).Model(&models.StatsHourly{}).
Select("SUM(request_count) as total_requests, SUM(success_count) as success_requests"). wg.Add(2)
Where("group_id = ? AND time >= ?", groupID, twentyFourHoursAgo).
Scan(&last24Hours) // 查询最近1小时
failureRate1h := 0.0 go func() {
if last1Hour.TotalRequests > 0 { defer wg.Done()
failureRate1h = float64(last1Hour.TotalRequests-last1Hour.SuccessRequests) / float64(last1Hour.TotalRequests) * 100 if err := s.db.WithContext(ctx).Model(&models.StatsHourly{}).
} Select("COALESCE(SUM(request_count), 0) as total_requests, COALESCE(SUM(success_count), 0) as success_requests").
failureRate24h := 0.0 Where("group_id = ? AND time >= ?", groupID, oneHourAgo).
if last24Hours.TotalRequests > 0 { Scan(&last1Hour).Error; err != nil {
failureRate24h = float64(last24Hours.TotalRequests-last24Hours.SuccessRequests) / float64(last24Hours.TotalRequests) * 100 errChan <- fmt.Errorf("failed to query 1h stats: %w", err)
} }
last1HourStats := map[string]any{ }()
"total_requests": last1Hour.TotalRequests,
"success_requests": last1Hour.SuccessRequests, // 查询最近24小时
"failure_rate": failureRate1h, go func() {
} defer wg.Done()
last24HoursStats := map[string]any{ if err := s.db.WithContext(ctx).Model(&models.StatsHourly{}).
"total_requests": last24Hours.TotalRequests, Select("COALESCE(SUM(request_count), 0) as total_requests, COALESCE(SUM(success_count), 0) as success_requests").
"success_requests": last24Hours.SuccessRequests, Where("group_id = ? AND time >= ?", groupID, twentyFourHoursAgo).
"failure_rate": failureRate24h, Scan(&last24Hours).Error; err != nil {
errChan <- fmt.Errorf("failed to query 24h stats: %w", err)
}
}()
wg.Wait()
close(errChan)
// 检查错误
for err := range errChan {
if err != nil {
return nil, err
}
} }
// 3. 计算失败率
failureRate1h := s.calculateFailureRate(last1Hour.TotalRequests, last1Hour.SuccessRequests)
failureRate24h := s.calculateFailureRate(last24Hours.TotalRequests, last24Hours.SuccessRequests)
result := map[string]any{ result := map[string]any{
"key_stats": keyStats, "key_stats": keyStats,
"last_1_hour": last1HourStats, "last_1_hour": map[string]any{
"last_24_hours": last24HoursStats, "total_requests": last1Hour.TotalRequests,
"success_requests": last1Hour.SuccessRequests,
"failed_requests": last1Hour.TotalRequests - last1Hour.SuccessRequests,
"failure_rate": failureRate1h,
},
"last_24_hours": map[string]any{
"total_requests": last24Hours.TotalRequests,
"success_requests": last24Hours.SuccessRequests,
"failed_requests": last24Hours.TotalRequests - last24Hours.SuccessRequests,
"failure_rate": failureRate24h,
},
} }
duration := time.Since(start)
s.logger.WithFields(logrus.Fields{
"group_id": groupID,
"duration": duration,
}).Debug("Group stats query completed")
return result, nil return result, nil
} }
func (s *DashboardQueryService) eventListener() { // QueryHistoricalChart 查询历史图表数据
ctx := context.Background()
keyStatusSub, _ := s.store.Subscribe(ctx, models.TopicKeyStatusChanged)
upstreamStatusSub, _ := s.store.Subscribe(ctx, models.TopicUpstreamHealthChanged)
defer keyStatusSub.Close()
defer upstreamStatusSub.Close()
for {
select {
case <-keyStatusSub.Channel():
s.logger.Info("Received key status changed event, invalidating overview cache...")
_ = s.InvalidateOverviewCache()
case <-upstreamStatusSub.Channel():
s.logger.Info("Received upstream status changed event, invalidating overview cache...")
_ = s.InvalidateOverviewCache()
case <-s.stopChan:
s.logger.Info("Stopping dashboard event listener.")
return
}
}
}
func (s *DashboardQueryService) GetDashboardOverviewData() (*models.DashboardStatsResponse, error) {
cachedDataPtr := s.overviewSyncer.Get()
if cachedDataPtr == nil {
return &models.DashboardStatsResponse{}, fmt.Errorf("overview cache is not available or still syncing")
}
return cachedDataPtr, nil
}
func (s *DashboardQueryService) InvalidateOverviewCache() error {
return s.overviewSyncer.Invalidate()
}
func (s *DashboardQueryService) QueryHistoricalChart(ctx context.Context, groupID *uint) (*models.ChartData, error) { func (s *DashboardQueryService) QueryHistoricalChart(ctx context.Context, groupID *uint) (*models.ChartData, error) {
s.queryCount.Add(1)
s.updateLastQueryTime()
start := time.Now()
type ChartPoint struct { type ChartPoint struct {
TimeLabel string `gorm:"column:time_label"` TimeLabel string `gorm:"column:time_label"`
ModelName string `gorm:"column:model_name"` ModelName string `gorm:"column:model_name"`
TotalRequests int64 `gorm:"column:total_requests"` TotalRequests int64 `gorm:"column:total_requests"`
} }
sevenDaysAgo := time.Now().Add(-24 * 7 * time.Hour).Truncate(time.Hour)
// 查询最近7天数据使用 UTC
sevenDaysAgo := time.Now().UTC().AddDate(0, 0, -defaultChartDays).Truncate(time.Hour)
// 根据数据库类型构建时间格式化子句
sqlFormat, goFormat := s.buildTimeFormatSelectClause() sqlFormat, goFormat := s.buildTimeFormatSelectClause()
selectClause := fmt.Sprintf("%s as time_label, model_name, SUM(request_count) as total_requests", sqlFormat) selectClause := fmt.Sprintf(
query := s.db.WithContext(ctx).Model(&models.StatsHourly{}).Select(selectClause).Where("time >= ?", sevenDaysAgo).Group("time_label, model_name").Order("time_label ASC") "%s as time_label, model_name, COALESCE(SUM(request_count), 0) as total_requests",
sqlFormat,
)
// 构建查询
query := s.db.WithContext(ctx).
Model(&models.StatsHourly{}).
Select(selectClause).
Where("time >= ?", sevenDaysAgo).
Group("time_label, model_name").
Order("time_label ASC")
if groupID != nil && *groupID > 0 { if groupID != nil && *groupID > 0 {
query = query.Where("group_id = ?", *groupID) query = query.Where("group_id = ?", *groupID)
} }
var points []ChartPoint var points []ChartPoint
if err := query.Find(&points).Error; err != nil { if err := query.Find(&points).Error; err != nil {
return nil, err return nil, fmt.Errorf("failed to query chart data: %w", err)
} }
// 构建数据集
datasets := make(map[string]map[string]int64) datasets := make(map[string]map[string]int64)
for _, p := range points { for _, p := range points {
if _, ok := datasets[p.ModelName]; !ok { if _, ok := datasets[p.ModelName]; !ok {
@@ -163,32 +302,99 @@ func (s *DashboardQueryService) QueryHistoricalChart(ctx context.Context, groupI
} }
datasets[p.ModelName][p.TimeLabel] = p.TotalRequests datasets[p.ModelName][p.TimeLabel] = p.TotalRequests
} }
// 生成时间标签(按小时)
var labels []string var labels []string
for t := sevenDaysAgo; t.Before(time.Now()); t = t.Add(time.Hour) { for t := sevenDaysAgo; t.Before(time.Now().UTC()); t = t.Add(time.Hour) {
labels = append(labels, t.Format(goFormat)) labels = append(labels, t.Format(goFormat))
} }
chartData := &models.ChartData{Labels: labels, Datasets: make([]models.ChartDataset, 0)}
colorPalette := []string{"#FF6384", "#36A2EB", "#FFCE56", "#4BC0C0", "#9966FF", "#FF9F40"} // 构建图表数据
chartData := &models.ChartData{
Labels: labels,
Datasets: make([]models.ChartDataset, 0, len(datasets)),
}
colorIndex := 0 colorIndex := 0
for modelName, dataPoints := range datasets { for modelName, dataPoints := range datasets {
dataArray := make([]int64, len(labels)) dataArray := make([]int64, len(labels))
for i, label := range labels { for i, label := range labels {
dataArray[i] = dataPoints[label] dataArray[i] = dataPoints[label]
} }
chartData.Datasets = append(chartData.Datasets, models.ChartDataset{ chartData.Datasets = append(chartData.Datasets, models.ChartDataset{
Label: modelName, Label: modelName,
Data: dataArray, Data: dataArray,
Color: colorPalette[colorIndex%len(colorPalette)], Color: chartColorPalette[colorIndex%len(chartColorPalette)],
}) })
colorIndex++ colorIndex++
} }
duration := time.Since(start)
s.logger.WithFields(logrus.Fields{
"group_id": groupID,
"points": len(points),
"datasets": len(chartData.Datasets),
"duration": duration,
}).Debug("Historical chart query completed")
return chartData, nil return chartData, nil
} }
// GetRequestStatsForPeriod 获取指定时间段的请求统计
func (s *DashboardQueryService) GetRequestStatsForPeriod(ctx context.Context, period string) (gin.H, error) {
s.queryCount.Add(1)
s.updateLastQueryTime()
var startTime time.Time
now := time.Now().UTC()
switch period {
case "1m":
startTime = now.Add(-1 * time.Minute)
case "1h":
startTime = now.Add(-1 * time.Hour)
case "1d":
year, month, day := now.Date()
startTime = time.Date(year, month, day, 0, 0, 0, 0, time.UTC)
default:
return nil, fmt.Errorf("invalid period specified: %s (must be 1m, 1h, or 1d)", period)
}
var result struct {
Total int64
Success int64
}
err := s.db.WithContext(ctx).Model(&models.RequestLog{}).
Select("COUNT(*) as total, SUM(CASE WHEN is_success = true THEN 1 ELSE 0 END) as success").
Where("request_time >= ?", startTime).
Scan(&result).Error
if err != nil {
return nil, fmt.Errorf("failed to query request stats: %w", err)
}
return gin.H{
"period": period,
"total": result.Total,
"success": result.Success,
"failure": result.Total - result.Success,
}, nil
}
// ==================== 内部方法 ====================
// loadOverviewData 加载仪表盘概览数据(供 CacheSyncer 调用)
func (s *DashboardQueryService) loadOverviewData() (*models.DashboardStatsResponse, error) { func (s *DashboardQueryService) loadOverviewData() (*models.DashboardStatsResponse, error) {
ctx := context.Background() ctx, cancel := context.WithTimeout(context.Background(), 60*time.Second)
s.logger.Info("[CacheSyncer] Starting to load overview data from database...") defer cancel()
s.overviewLoadCount.Add(1)
startTime := time.Now() startTime := time.Now()
s.logger.Info("Starting to load dashboard overview data...")
resp := &models.DashboardStatsResponse{ resp := &models.DashboardStatsResponse{
KeyStatusCount: make(map[models.APIKeyStatus]int64), KeyStatusCount: make(map[models.APIKeyStatus]int64),
MasterStatusCount: make(map[models.MasterAPIKeyStatus]int64), MasterStatusCount: make(map[models.MasterAPIKeyStatus]int64),
@@ -200,108 +406,391 @@ func (s *DashboardQueryService) loadOverviewData() (*models.DashboardStatsRespon
RequestCounts: make(map[string]int64), RequestCounts: make(map[string]int64),
} }
var loadErr error
var wg sync.WaitGroup
errChan := make(chan error, 10)
// 1. 并发加载 Key 映射状态统计
wg.Add(1)
go func() {
defer wg.Done()
if err := s.loadMappingStatusStats(ctx, resp); err != nil {
errChan <- fmt.Errorf("mapping stats: %w", err)
}
}()
// 2. 并发加载 Master Key 状态统计
wg.Add(1)
go func() {
defer wg.Done()
if err := s.loadMasterStatusStats(ctx, resp); err != nil {
errChan <- fmt.Errorf("master stats: %w", err)
}
}()
// 3. 并发加载请求统计
wg.Add(1)
go func() {
defer wg.Done()
if err := s.loadRequestCounts(ctx, resp); err != nil {
errChan <- fmt.Errorf("request counts: %w", err)
}
}()
// 4. 并发加载上游健康状态
wg.Add(1)
go func() {
defer wg.Done()
if err := s.loadUpstreamHealth(ctx, resp); err != nil {
// 上游健康状态失败不阻塞整体加载
s.logger.WithError(err).Warn("Failed to load upstream health status")
}
}()
// 等待所有加载任务完成
wg.Wait()
close(errChan)
// 收集错误
for err := range errChan {
if err != nil {
loadErr = err
break
}
}
if loadErr != nil {
s.logger.WithError(loadErr).Error("Failed to load overview data")
return nil, loadErr
}
duration := time.Since(startTime)
s.logger.WithFields(logrus.Fields{
"duration": duration,
"total_keys": resp.KeyCount.Value,
"requests_1d": resp.RequestCounts["1d"],
"upstreams": len(resp.UpstreamHealthStatus),
}).Info("Successfully loaded dashboard overview data")
return resp, nil
}
// loadMappingStatusStats 加载 Key 映射状态统计
func (s *DashboardQueryService) loadMappingStatusStats(ctx context.Context, resp *models.DashboardStatsResponse) error {
type MappingStatusResult struct { type MappingStatusResult struct {
Status models.APIKeyStatus Status models.APIKeyStatus
Count int64 Count int64
} }
var mappingStatusResults []MappingStatusResult
if err := s.db.WithContext(ctx).Model(&models.GroupAPIKeyMapping{}).Select("status, count(*) as count").Group("status").Find(&mappingStatusResults).Error; err != nil { var results []MappingStatusResult
return nil, fmt.Errorf("failed to query mapping status stats: %w", err) if err := s.db.WithContext(ctx).
Model(&models.GroupAPIKeyMapping{}).
Select("status, COUNT(*) as count").
Group("status").
Find(&results).Error; err != nil {
return err
} }
for _, res := range mappingStatusResults {
for _, res := range results {
resp.KeyStatusCount[res.Status] = res.Count resp.KeyStatusCount[res.Status] = res.Count
} }
return nil
}
// loadMasterStatusStats 加载 Master Key 状态统计
func (s *DashboardQueryService) loadMasterStatusStats(ctx context.Context, resp *models.DashboardStatsResponse) error {
type MasterStatusResult struct { type MasterStatusResult struct {
Status models.MasterAPIKeyStatus Status models.MasterAPIKeyStatus
Count int64 Count int64
} }
var masterStatusResults []MasterStatusResult
if err := s.db.WithContext(ctx).Model(&models.APIKey{}).Select("master_status as status, count(*) as count").Group("master_status").Find(&masterStatusResults).Error; err != nil { var results []MasterStatusResult
return nil, fmt.Errorf("failed to query master status stats: %w", err) if err := s.db.WithContext(ctx).
Model(&models.APIKey{}).
Select("master_status as status, COUNT(*) as count").
Group("master_status").
Find(&results).Error; err != nil {
return err
} }
var totalKeys, invalidKeys int64 var totalKeys, invalidKeys int64
for _, res := range masterStatusResults { for _, res := range results {
resp.MasterStatusCount[res.Status] = res.Count resp.MasterStatusCount[res.Status] = res.Count
totalKeys += res.Count totalKeys += res.Count
if res.Status != models.MasterStatusActive { if res.Status != models.MasterStatusActive {
invalidKeys += res.Count invalidKeys += res.Count
} }
} }
resp.KeyCount = models.StatCard{Value: float64(totalKeys), SubValue: invalidKeys, SubValueTip: "非活跃身份密钥数"}
now := time.Now() resp.KeyCount = models.StatCard{
Value: float64(totalKeys),
SubValue: invalidKeys,
SubValueTip: "非活跃身份密钥数",
}
var count1m, count1h, count1d int64 return nil
s.db.WithContext(ctx).Model(&models.RequestLog{}).Where("request_time >= ?", now.Add(-1*time.Minute)).Count(&count1m) }
s.db.WithContext(ctx).Model(&models.RequestLog{}).Where("request_time >= ?", now.Add(-1*time.Hour)).Count(&count1h)
year, month, day := now.UTC().Date() // loadRequestCounts 加载请求计数统计
func (s *DashboardQueryService) loadRequestCounts(ctx context.Context, resp *models.DashboardStatsResponse) error {
now := time.Now().UTC()
// 使用 RequestLog 表查询短期数据
var count1m, count1h int64
// 最近1分钟
if err := s.db.WithContext(ctx).
Model(&models.RequestLog{}).
Where("request_time >= ?", now.Add(-1*time.Minute)).
Count(&count1m).Error; err != nil {
return fmt.Errorf("1m count: %w", err)
}
// 最近1小时
if err := s.db.WithContext(ctx).
Model(&models.RequestLog{}).
Where("request_time >= ?", now.Add(-1*time.Hour)).
Count(&count1h).Error; err != nil {
return fmt.Errorf("1h count: %w", err)
}
// 今天UTC
year, month, day := now.Date()
startOfDay := time.Date(year, month, day, 0, 0, 0, 0, time.UTC) startOfDay := time.Date(year, month, day, 0, 0, 0, 0, time.UTC)
s.db.WithContext(ctx).Model(&models.RequestLog{}).Where("request_time >= ?", startOfDay).Count(&count1d)
var count1d int64
if err := s.db.WithContext(ctx).
Model(&models.RequestLog{}).
Where("request_time >= ?", startOfDay).
Count(&count1d).Error; err != nil {
return fmt.Errorf("1d count: %w", err)
}
// 最近30天使用聚合表
var count30d int64 var count30d int64
s.db.WithContext(ctx).Model(&models.StatsHourly{}).Where("time >= ?", now.AddDate(0, 0, -30)).Select("COALESCE(SUM(request_count), 0)").Scan(&count30d) if err := s.db.WithContext(ctx).
Model(&models.StatsHourly{}).
Where("time >= ?", now.AddDate(0, 0, -30)).
Select("COALESCE(SUM(request_count), 0)").
Scan(&count30d).Error; err != nil {
return fmt.Errorf("30d count: %w", err)
}
resp.RequestCounts["1m"] = count1m resp.RequestCounts["1m"] = count1m
resp.RequestCounts["1h"] = count1h resp.RequestCounts["1h"] = count1h
resp.RequestCounts["1d"] = count1d resp.RequestCounts["1d"] = count1d
resp.RequestCounts["30d"] = count30d resp.RequestCounts["30d"] = count30d
return nil
}
// loadUpstreamHealth 加载上游健康状态
func (s *DashboardQueryService) loadUpstreamHealth(ctx context.Context, resp *models.DashboardStatsResponse) error {
var upstreams []*models.UpstreamEndpoint var upstreams []*models.UpstreamEndpoint
if err := s.db.WithContext(ctx).Find(&upstreams).Error; err != nil { if err := s.db.WithContext(ctx).Find(&upstreams).Error; err != nil {
s.logger.WithError(err).Warn("Failed to load upstream statuses for dashboard.") return err
} else { }
for _, u := range upstreams {
resp.UpstreamHealthStatus[u.URL] = u.Status for _, u := range upstreams {
resp.UpstreamHealthStatus[u.URL] = u.Status
}
return nil
}
// ==================== 事件监听 ====================
// eventListener 监听缓存失效事件
func (s *DashboardQueryService) eventListener() {
defer s.wg.Done()
// 订阅事件
keyStatusSub, err1 := s.store.Subscribe(s.ctx, models.TopicKeyStatusChanged)
upstreamStatusSub, err2 := s.store.Subscribe(s.ctx, models.TopicUpstreamHealthChanged)
// 错误处理
if err1 != nil {
s.logger.WithError(err1).Error("Failed to subscribe to key status events")
keyStatusSub = nil
}
if err2 != nil {
s.logger.WithError(err2).Error("Failed to subscribe to upstream status events")
upstreamStatusSub = nil
}
// 如果全部失败,直接返回
if keyStatusSub == nil && upstreamStatusSub == nil {
s.logger.Error("All event subscriptions failed, listener disabled")
return
}
// 安全关闭订阅
defer func() {
if keyStatusSub != nil {
if err := keyStatusSub.Close(); err != nil {
s.logger.WithError(err).Warn("Failed to close key status subscription")
}
}
if upstreamStatusSub != nil {
if err := upstreamStatusSub.Close(); err != nil {
s.logger.WithError(err).Warn("Failed to close upstream status subscription")
}
}
}()
s.logger.WithFields(logrus.Fields{
"key_status_sub": keyStatusSub != nil,
"upstream_status_sub": upstreamStatusSub != nil,
}).Info("Event listener started")
neverReady := make(chan *store.Message)
close(neverReady) // 立即关闭,确保永远不会阻塞
for {
// 动态选择有效的 channel
var keyStatusChan <-chan *store.Message = neverReady
if keyStatusSub != nil {
keyStatusChan = keyStatusSub.Channel()
}
var upstreamStatusChan <-chan *store.Message = neverReady
if upstreamStatusSub != nil {
upstreamStatusChan = upstreamStatusSub.Channel()
}
select {
case _, ok := <-keyStatusChan:
if !ok {
s.logger.Warn("Key status channel closed")
keyStatusSub = nil
continue
}
s.logger.Debug("Received key status changed event")
if err := s.InvalidateOverviewCache(); err != nil {
s.logger.WithError(err).Warn("Failed to invalidate cache on key status change")
}
case _, ok := <-upstreamStatusChan:
if !ok {
s.logger.Warn("Upstream status channel closed")
upstreamStatusSub = nil
continue
}
s.logger.Debug("Received upstream status changed event")
if err := s.InvalidateOverviewCache(); err != nil {
s.logger.WithError(err).Warn("Failed to invalidate cache on upstream status change")
}
case <-s.stopChan:
s.logger.Info("Event listener stopping (stopChan)")
return
case <-s.ctx.Done():
s.logger.Info("Event listener stopping (context cancelled)")
return
} }
} }
duration := time.Since(startTime)
s.logger.Infof("[CacheSyncer] Successfully finished loading overview data in %s.", duration)
return resp, nil
} }
func (s *DashboardQueryService) GetRequestStatsForPeriod(ctx context.Context, period string) (gin.H, error) { // ==================== 监控指标 ====================
var startTime time.Time
now := time.Now()
switch period {
case "1m":
startTime = now.Add(-1 * time.Minute)
case "1h":
startTime = now.Add(-1 * time.Hour)
case "1d":
year, month, day := now.UTC().Date()
startTime = time.Date(year, month, day, 0, 0, 0, 0, time.UTC)
default:
return nil, fmt.Errorf("invalid period specified: %s", period)
}
var result struct {
Total int64
Success int64
}
err := s.db.WithContext(ctx).Model(&models.RequestLog{}). // metricsReporter 定期输出统计信息
Select("count(*) as total, sum(case when is_success = true then 1 else 0 end) as success"). func (s *DashboardQueryService) metricsReporter() {
Where("request_time >= ?", startTime). defer s.wg.Done()
Scan(&result).Error
if err != nil { ticker := time.NewTicker(5 * time.Minute)
return nil, err defer ticker.Stop()
for {
select {
case <-ticker.C:
s.reportMetrics()
case <-s.stopChan:
return
case <-s.ctx.Done():
return
}
} }
return gin.H{
"total": result.Total,
"success": result.Success,
"failure": result.Total - result.Success,
}, nil
} }
func (s *DashboardQueryService) reportMetrics() {
s.lastQueryMutex.RLock()
lastQuery := s.lastQueryTime
s.lastQueryMutex.RUnlock()
totalQueries := s.queryCount.Load()
hits := s.cacheHits.Load()
misses := s.cacheMisses.Load()
var cacheHitRate float64
if hits+misses > 0 {
cacheHitRate = float64(hits) / float64(hits+misses) * 100
}
s.logger.WithFields(logrus.Fields{
"total_queries": totalQueries,
"cache_hits": hits,
"cache_misses": misses,
"cache_hit_rate": fmt.Sprintf("%.2f%%", cacheHitRate),
"overview_loads": s.overviewLoadCount.Load(),
"last_query_ago": time.Since(lastQuery).Round(time.Second),
}).Info("DashboardQuery metrics")
}
// GetMetrics 返回当前统计指标(供监控使用)
func (s *DashboardQueryService) GetMetrics() map[string]interface{} {
s.lastQueryMutex.RLock()
lastQuery := s.lastQueryTime
s.lastQueryMutex.RUnlock()
hits := s.cacheHits.Load()
misses := s.cacheMisses.Load()
var cacheHitRate float64
if hits+misses > 0 {
cacheHitRate = float64(hits) / float64(hits+misses) * 100
}
return map[string]interface{}{
"total_queries": s.queryCount.Load(),
"cache_hits": hits,
"cache_misses": misses,
"cache_hit_rate": cacheHitRate,
"overview_loads": s.overviewLoadCount.Load(),
"last_query_ago": time.Since(lastQuery).Seconds(),
}
}
// ==================== 辅助方法 ====================
// calculateFailureRate 计算失败率
func (s *DashboardQueryService) calculateFailureRate(total, success int64) float64 {
if total == 0 {
return 0.0
}
return float64(total-success) / float64(total) * 100
}
// updateLastQueryTime 更新最后查询时间
func (s *DashboardQueryService) updateLastQueryTime() {
s.lastQueryMutex.Lock()
s.lastQueryTime = time.Now()
s.lastQueryMutex.Unlock()
}
// buildTimeFormatSelectClause 根据数据库类型构建时间格式化子句
func (s *DashboardQueryService) buildTimeFormatSelectClause() (string, string) { func (s *DashboardQueryService) buildTimeFormatSelectClause() (string, string) {
dialect := s.db.Dialector.Name() dialect := s.db.Dialector.Name()
switch dialect { switch dialect {
case "mysql": case "mysql":
return "DATE_FORMAT(time, '%Y-%m-%d %H:00:00')", "2006-01-02 15:00:00" return "DATE_FORMAT(time, '%Y-%m-%d %H:00:00')", "2006-01-02 15:00:00"
case "postgres":
return "TO_CHAR(time, 'YYYY-MM-DD HH24:00:00')", "2006-01-02 15:00:00"
case "sqlite": case "sqlite":
return "strftime('%Y-%m-%d %H:00:00', time)", "2006-01-02 15:00:00" return "strftime('%Y-%m-%d %H:00:00', time)", "2006-01-02 15:00:00"
default: default:
s.logger.WithField("dialect", dialect).Warn("Unknown database dialect, using SQLite format")
return "strftime('%Y-%m-%d %H:00:00', time)", "2006-01-02 15:00:00" return "strftime('%Y-%m-%d %H:00:00', time)", "2006-01-02 15:00:00"
} }
} }

View File

@@ -4,11 +4,13 @@ package service
import ( import (
"context" "context"
"encoding/json" "encoding/json"
"sync"
"sync/atomic"
"time"
"gemini-balancer/internal/models" "gemini-balancer/internal/models"
"gemini-balancer/internal/settings" "gemini-balancer/internal/settings"
"gemini-balancer/internal/store" "gemini-balancer/internal/store"
"sync"
"time"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
"gorm.io/gorm" "gorm.io/gorm"
@@ -18,25 +20,47 @@ type DBLogWriterService struct {
db *gorm.DB db *gorm.DB
store store.Store store store.Store
logger *logrus.Entry logger *logrus.Entry
logBuffer chan *models.RequestLog settingsManager *settings.SettingsManager
stopChan chan struct{}
wg sync.WaitGroup logBuffer chan *models.RequestLog
SettingsManager *settings.SettingsManager stopChan chan struct{}
wg sync.WaitGroup
ctx context.Context
cancel context.CancelFunc
// 统计指标
totalReceived atomic.Uint64
totalFlushed atomic.Uint64
totalDropped atomic.Uint64
flushCount atomic.Uint64
lastFlushTime time.Time
lastFlushMutex sync.RWMutex
} }
func NewDBLogWriterService(db *gorm.DB, s store.Store, settings *settings.SettingsManager, logger *logrus.Logger) *DBLogWriterService { func NewDBLogWriterService(
cfg := settings.GetSettings() db *gorm.DB,
s store.Store,
settingsManager *settings.SettingsManager,
logger *logrus.Logger,
) *DBLogWriterService {
cfg := settingsManager.GetSettings()
bufferCapacity := cfg.LogBufferCapacity bufferCapacity := cfg.LogBufferCapacity
if bufferCapacity <= 0 { if bufferCapacity <= 0 {
bufferCapacity = 1000 bufferCapacity = 1000
} }
ctx, cancel := context.WithCancel(context.Background())
return &DBLogWriterService{ return &DBLogWriterService{
db: db, db: db,
store: s, store: s,
SettingsManager: settings, settingsManager: settingsManager,
logger: logger.WithField("component", "DBLogWriter📝"), logger: logger.WithField("component", "DBLogWriter📝"),
logBuffer: make(chan *models.RequestLog, bufferCapacity), logBuffer: make(chan *models.RequestLog, bufferCapacity),
stopChan: make(chan struct{}), stopChan: make(chan struct{}),
ctx: ctx,
cancel: cancel,
lastFlushTime: time.Now(),
} }
} }
@@ -44,93 +68,276 @@ func (s *DBLogWriterService) Start() {
s.wg.Add(2) s.wg.Add(2)
go s.eventListenerLoop() go s.eventListenerLoop()
go s.dbWriterLoop() go s.dbWriterLoop()
s.logger.Info("DBLogWriterService started.")
// 定期输出统计信息
s.wg.Add(1)
go s.metricsReporter()
s.logger.WithFields(logrus.Fields{
"buffer_capacity": cap(s.logBuffer),
}).Info("DBLogWriterService started")
} }
func (s *DBLogWriterService) Stop() { func (s *DBLogWriterService) Stop() {
s.logger.Info("DBLogWriterService stopping...") s.logger.Info("DBLogWriterService stopping...")
close(s.stopChan) close(s.stopChan)
s.cancel() // 取消上下文
s.wg.Wait() s.wg.Wait()
s.logger.Info("DBLogWriterService stopped.")
// 输出最终统计
s.logger.WithFields(logrus.Fields{
"total_received": s.totalReceived.Load(),
"total_flushed": s.totalFlushed.Load(),
"total_dropped": s.totalDropped.Load(),
"flush_count": s.flushCount.Load(),
}).Info("DBLogWriterService stopped")
} }
// 事件监听循环
func (s *DBLogWriterService) eventListenerLoop() { func (s *DBLogWriterService) eventListenerLoop() {
defer s.wg.Done() defer s.wg.Done()
ctx := context.Background() sub, err := s.store.Subscribe(s.ctx, models.TopicRequestFinished)
sub, err := s.store.Subscribe(ctx, models.TopicRequestFinished)
if err != nil { if err != nil {
s.logger.Fatalf("Failed to subscribe to topic %s: %v", models.TopicRequestFinished, err) s.logger.WithError(err).Error("Failed to subscribe to request events, log writing disabled")
return return
} }
defer sub.Close() defer func() {
if err := sub.Close(); err != nil {
s.logger.WithError(err).Warn("Failed to close subscription")
}
}()
s.logger.Info("Subscribed to request events for database logging.") s.logger.Info("Subscribed to request events for database logging")
for { for {
select { select {
case msg := <-sub.Channel(): case msg := <-sub.Channel():
var event models.RequestFinishedEvent s.handleMessage(msg)
if err := json.Unmarshal(msg.Payload, &event); err != nil {
s.logger.Errorf("Failed to unmarshal event for logging: %v", err)
continue
}
select {
case s.logBuffer <- &event.RequestLog:
default:
s.logger.Warn("Log buffer is full. A log message might be dropped.")
}
case <-s.stopChan: case <-s.stopChan:
s.logger.Info("Event listener loop stopping.") s.logger.Info("Event listener loop stopping")
close(s.logBuffer)
return
case <-s.ctx.Done():
s.logger.Info("Event listener context cancelled")
close(s.logBuffer) close(s.logBuffer)
return return
} }
} }
} }
// 处理单条消息
func (s *DBLogWriterService) handleMessage(msg *store.Message) {
var event models.RequestFinishedEvent
if err := json.Unmarshal(msg.Payload, &event); err != nil {
s.logger.WithError(err).Error("Failed to unmarshal request event")
return
}
s.totalReceived.Add(1)
select {
case s.logBuffer <- &event.RequestLog:
// 成功入队
default:
// 缓冲区满,丢弃日志
dropped := s.totalDropped.Add(1)
if dropped%100 == 1 { // 每100条丢失输出一次警告
s.logger.WithFields(logrus.Fields{
"total_dropped": dropped,
"buffer_capacity": cap(s.logBuffer),
"buffer_len": len(s.logBuffer),
}).Warn("Log buffer full, messages being dropped")
}
}
}
// 数据库写入循环
func (s *DBLogWriterService) dbWriterLoop() { func (s *DBLogWriterService) dbWriterLoop() {
defer s.wg.Done() defer s.wg.Done()
cfg := s.SettingsManager.GetSettings() cfg := s.settingsManager.GetSettings()
batchSize := cfg.LogFlushBatchSize batchSize := cfg.LogFlushBatchSize
if batchSize <= 0 { if batchSize <= 0 {
batchSize = 100 batchSize = 100
} }
flushTimeout := time.Duration(cfg.LogFlushIntervalSeconds) * time.Second flushInterval := time.Duration(cfg.LogFlushIntervalSeconds) * time.Second
if flushTimeout <= 0 { if flushInterval <= 0 {
flushTimeout = 5 * time.Second flushInterval = 5 * time.Second
} }
s.logger.WithFields(logrus.Fields{
"batch_size": batchSize,
"flush_interval": flushInterval,
}).Info("DB writer loop started")
batch := make([]*models.RequestLog, 0, batchSize) batch := make([]*models.RequestLog, 0, batchSize)
ticker := time.NewTicker(flushTimeout) ticker := time.NewTicker(flushInterval)
defer ticker.Stop() defer ticker.Stop()
// 配置热更新检查(每分钟)
configTicker := time.NewTicker(1 * time.Minute)
defer configTicker.Stop()
for { for {
select { select {
case logEntry, ok := <-s.logBuffer: case logEntry, ok := <-s.logBuffer:
if !ok { if !ok {
// 通道关闭,刷新剩余日志
if len(batch) > 0 { if len(batch) > 0 {
s.flushBatch(batch) s.flushBatch(batch)
} }
s.logger.Info("DB writer loop finished.") s.logger.Info("DB writer loop finished")
return return
} }
batch = append(batch, logEntry) batch = append(batch, logEntry)
if len(batch) >= batchSize { if len(batch) >= batchSize {
s.flushBatch(batch) s.flushBatch(batch)
batch = make([]*models.RequestLog, 0, batchSize) batch = make([]*models.RequestLog, 0, batchSize)
} }
case <-ticker.C: case <-ticker.C:
if len(batch) > 0 { if len(batch) > 0 {
s.flushBatch(batch) s.flushBatch(batch)
batch = make([]*models.RequestLog, 0, batchSize) batch = make([]*models.RequestLog, 0, batchSize)
} }
case <-configTicker.C:
// 热更新配置
cfg := s.settingsManager.GetSettings()
newBatchSize := cfg.LogFlushBatchSize
if newBatchSize <= 0 {
newBatchSize = 100
}
newFlushInterval := time.Duration(cfg.LogFlushIntervalSeconds) * time.Second
if newFlushInterval <= 0 {
newFlushInterval = 5 * time.Second
}
if newBatchSize != batchSize {
s.logger.WithFields(logrus.Fields{
"old": batchSize,
"new": newBatchSize,
}).Info("Batch size updated")
batchSize = newBatchSize
if len(batch) >= batchSize {
s.flushBatch(batch)
batch = make([]*models.RequestLog, 0, batchSize)
}
}
if newFlushInterval != flushInterval {
s.logger.WithFields(logrus.Fields{
"old": flushInterval,
"new": newFlushInterval,
}).Info("Flush interval updated")
flushInterval = newFlushInterval
ticker.Reset(flushInterval)
}
} }
} }
} }
// 批量刷写到数据库
func (s *DBLogWriterService) flushBatch(batch []*models.RequestLog) { func (s *DBLogWriterService) flushBatch(batch []*models.RequestLog) {
if err := s.db.CreateInBatches(batch, len(batch)).Error; err != nil { if len(batch) == 0 {
s.logger.WithField("batch_size", len(batch)).WithError(err).Error("Failed to flush log batch to database.") return
}
start := time.Now()
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
err := s.db.WithContext(ctx).CreateInBatches(batch, len(batch)).Error
duration := time.Since(start)
s.lastFlushMutex.Lock()
s.lastFlushTime = time.Now()
s.lastFlushMutex.Unlock()
if err != nil {
s.logger.WithFields(logrus.Fields{
"batch_size": len(batch),
"duration": duration,
}).WithError(err).Error("Failed to flush log batch to database")
} else { } else {
s.logger.Infof("Successfully flushed %d logs to database.", len(batch)) flushed := s.totalFlushed.Add(uint64(len(batch)))
flushCount := s.flushCount.Add(1)
// 只在慢写入或大批量时输出日志
if duration > 1*time.Second || len(batch) > 500 {
s.logger.WithFields(logrus.Fields{
"batch_size": len(batch),
"duration": duration,
"total_flushed": flushed,
"flush_count": flushCount,
}).Info("Log batch flushed to database")
} else {
s.logger.WithFields(logrus.Fields{
"batch_size": len(batch),
"duration": duration,
}).Debug("Log batch flushed to database")
}
}
}
// 定期输出统计信息
func (s *DBLogWriterService) metricsReporter() {
defer s.wg.Done()
ticker := time.NewTicker(5 * time.Minute)
defer ticker.Stop()
for {
select {
case <-ticker.C:
s.reportMetrics()
case <-s.stopChan:
return
case <-s.ctx.Done():
return
}
}
}
func (s *DBLogWriterService) reportMetrics() {
s.lastFlushMutex.RLock()
lastFlush := s.lastFlushTime
s.lastFlushMutex.RUnlock()
received := s.totalReceived.Load()
flushed := s.totalFlushed.Load()
dropped := s.totalDropped.Load()
pending := uint64(len(s.logBuffer))
s.logger.WithFields(logrus.Fields{
"received": received,
"flushed": flushed,
"dropped": dropped,
"pending": pending,
"flush_count": s.flushCount.Load(),
"last_flush": time.Since(lastFlush).Round(time.Second),
"buffer_usage": float64(pending) / float64(cap(s.logBuffer)) * 100,
"success_rate": float64(flushed) / float64(received) * 100,
}).Info("DBLogWriter metrics")
}
// GetMetrics 返回当前统计指标(供监控使用)
func (s *DBLogWriterService) GetMetrics() map[string]interface{} {
s.lastFlushMutex.RLock()
lastFlush := s.lastFlushTime
s.lastFlushMutex.RUnlock()
return map[string]interface{}{
"total_received": s.totalReceived.Load(),
"total_flushed": s.totalFlushed.Load(),
"total_dropped": s.totalDropped.Load(),
"flush_count": s.flushCount.Load(),
"buffer_pending": len(s.logBuffer),
"buffer_capacity": cap(s.logBuffer),
"last_flush_ago": time.Since(lastFlush).Seconds(),
} }
} }

View File

@@ -334,7 +334,6 @@ func (gm *GroupManager) CloneKeyGroup(id uint) (*models.KeyGroup, error) {
} }
func (gm *GroupManager) BuildOperationalConfig(group *models.KeyGroup) (*models.KeyGroupSettings, error) { func (gm *GroupManager) BuildOperationalConfig(group *models.KeyGroup) (*models.KeyGroupSettings, error) {
globalSettings := gm.settingsManager.GetSettings() globalSettings := gm.settingsManager.GetSettings()
defaultModel := "gemini-1.5-flash"
opConfig := &models.KeyGroupSettings{ opConfig := &models.KeyGroupSettings{
EnableKeyCheck: &globalSettings.EnableBaseKeyCheck, EnableKeyCheck: &globalSettings.EnableBaseKeyCheck,
KeyCheckConcurrency: &globalSettings.BaseKeyCheckConcurrency, KeyCheckConcurrency: &globalSettings.BaseKeyCheckConcurrency,
@@ -342,7 +341,7 @@ func (gm *GroupManager) BuildOperationalConfig(group *models.KeyGroup) (*models.
KeyCheckEndpoint: &globalSettings.DefaultUpstreamURL, KeyCheckEndpoint: &globalSettings.DefaultUpstreamURL,
KeyBlacklistThreshold: &globalSettings.BlacklistThreshold, KeyBlacklistThreshold: &globalSettings.BlacklistThreshold,
KeyCooldownMinutes: &globalSettings.KeyCooldownMinutes, KeyCooldownMinutes: &globalSettings.KeyCooldownMinutes,
KeyCheckModel: &defaultModel, KeyCheckModel: &globalSettings.BaseKeyCheckModel,
MaxRetries: &globalSettings.MaxRetries, MaxRetries: &globalSettings.MaxRetries,
EnableSmartGateway: &globalSettings.EnableSmartGateway, EnableSmartGateway: &globalSettings.EnableSmartGateway,
} }

File diff suppressed because it is too large Load Diff

View File

@@ -23,6 +23,10 @@ const (
TaskTypeHardDeleteKeys = "hard_delete_keys" TaskTypeHardDeleteKeys = "hard_delete_keys"
TaskTypeRestoreKeys = "restore_keys" TaskTypeRestoreKeys = "restore_keys"
chunkSize = 500 chunkSize = 500
// 任务超时时间常量化
defaultTaskTimeout = 15 * time.Minute
longTaskTimeout = time.Hour
) )
type KeyImportService struct { type KeyImportService struct {
@@ -43,17 +47,19 @@ func NewKeyImportService(ts task.Reporter, kr repository.KeyRepository, s store.
} }
} }
// runTaskWithRecovery 统一的任务恢复包装器
func (s *KeyImportService) runTaskWithRecovery(ctx context.Context, taskID string, resourceID string, taskFunc func()) { func (s *KeyImportService) runTaskWithRecovery(ctx context.Context, taskID string, resourceID string, taskFunc func()) {
defer func() { defer func() {
if r := recover(); r != nil { if r := recover(); r != nil {
err := fmt.Errorf("panic recovered in task %s: %v", taskID, r) err := fmt.Errorf("panic recovered in task %s: %v", taskID, r)
s.logger.Error(err) s.logger.WithField("task_id", taskID).Error(err)
s.taskService.EndTaskByID(ctx, taskID, resourceID, nil, err) s.taskService.EndTaskByID(ctx, taskID, resourceID, nil, err)
} }
}() }()
taskFunc() taskFunc()
} }
// StartAddKeysTask 启动批量添加密钥任务
func (s *KeyImportService) StartAddKeysTask(ctx context.Context, groupID uint, keysText string, validateOnImport bool) (*task.Status, error) { func (s *KeyImportService) StartAddKeysTask(ctx context.Context, groupID uint, keysText string, validateOnImport bool) (*task.Status, error) {
keys := utils.ParseKeysFromText(keysText) keys := utils.ParseKeysFromText(keysText)
if len(keys) == 0 { if len(keys) == 0 {
@@ -61,260 +67,404 @@ func (s *KeyImportService) StartAddKeysTask(ctx context.Context, groupID uint, k
} }
resourceID := fmt.Sprintf("group-%d", groupID) resourceID := fmt.Sprintf("group-%d", groupID)
taskStatus, err := s.taskService.StartTask(ctx, uint(groupID), TaskTypeAddKeysToGroup, resourceID, len(keys), 15*time.Minute) taskStatus, err := s.taskService.StartTask(ctx, uint(groupID), TaskTypeAddKeysToGroup, resourceID, len(keys), defaultTaskTimeout)
if err != nil { if err != nil {
return nil, err return nil, err
} }
go s.runTaskWithRecovery(context.Background(), taskStatus.ID, resourceID, func() { go s.runTaskWithRecovery(context.Background(), taskStatus.ID, resourceID, func() {
s.runAddKeysTask(context.Background(), taskStatus.ID, resourceID, groupID, keys, validateOnImport) s.runAddKeysTask(context.Background(), taskStatus.ID, resourceID, groupID, keys, validateOnImport)
}) })
return taskStatus, nil return taskStatus, nil
} }
// StartUnlinkKeysTask 启动批量解绑密钥任务
func (s *KeyImportService) StartUnlinkKeysTask(ctx context.Context, groupID uint, keysText string) (*task.Status, error) { func (s *KeyImportService) StartUnlinkKeysTask(ctx context.Context, groupID uint, keysText string) (*task.Status, error) {
keys := utils.ParseKeysFromText(keysText) keys := utils.ParseKeysFromText(keysText)
if len(keys) == 0 { if len(keys) == 0 {
return nil, fmt.Errorf("no valid keys found") return nil, fmt.Errorf("no valid keys found")
} }
resourceID := fmt.Sprintf("group-%d", groupID) resourceID := fmt.Sprintf("group-%d", groupID)
taskStatus, err := s.taskService.StartTask(ctx, uint(groupID), TaskTypeUnlinkKeysFromGroup, resourceID, len(keys), time.Hour) taskStatus, err := s.taskService.StartTask(ctx, uint(groupID), TaskTypeUnlinkKeysFromGroup, resourceID, len(keys), longTaskTimeout)
if err != nil { if err != nil {
return nil, err return nil, err
} }
go s.runTaskWithRecovery(context.Background(), taskStatus.ID, resourceID, func() { go s.runTaskWithRecovery(context.Background(), taskStatus.ID, resourceID, func() {
s.runUnlinkKeysTask(context.Background(), taskStatus.ID, resourceID, groupID, keys) s.runUnlinkKeysTask(context.Background(), taskStatus.ID, resourceID, groupID, keys)
}) })
return taskStatus, nil return taskStatus, nil
} }
// StartHardDeleteKeysTask 启动硬删除密钥任务
func (s *KeyImportService) StartHardDeleteKeysTask(ctx context.Context, keysText string) (*task.Status, error) { func (s *KeyImportService) StartHardDeleteKeysTask(ctx context.Context, keysText string) (*task.Status, error) {
keys := utils.ParseKeysFromText(keysText) keys := utils.ParseKeysFromText(keysText)
if len(keys) == 0 { if len(keys) == 0 {
return nil, fmt.Errorf("no valid keys found") return nil, fmt.Errorf("no valid keys found")
} }
resourceID := "global_hard_delete" resourceID := "global_hard_delete"
taskStatus, err := s.taskService.StartTask(ctx, 0, TaskTypeHardDeleteKeys, resourceID, len(keys), time.Hour) taskStatus, err := s.taskService.StartTask(ctx, 0, TaskTypeHardDeleteKeys, resourceID, len(keys), longTaskTimeout)
if err != nil { if err != nil {
return nil, err return nil, err
} }
go s.runTaskWithRecovery(context.Background(), taskStatus.ID, resourceID, func() { go s.runTaskWithRecovery(context.Background(), taskStatus.ID, resourceID, func() {
s.runHardDeleteKeysTask(context.Background(), taskStatus.ID, resourceID, keys) s.runHardDeleteKeysTask(context.Background(), taskStatus.ID, resourceID, keys)
}) })
return taskStatus, nil return taskStatus, nil
} }
// StartRestoreKeysTask 启动恢复密钥任务
func (s *KeyImportService) StartRestoreKeysTask(ctx context.Context, keysText string) (*task.Status, error) { func (s *KeyImportService) StartRestoreKeysTask(ctx context.Context, keysText string) (*task.Status, error) {
keys := utils.ParseKeysFromText(keysText) keys := utils.ParseKeysFromText(keysText)
if len(keys) == 0 { if len(keys) == 0 {
return nil, fmt.Errorf("no valid keys found") return nil, fmt.Errorf("no valid keys found")
} }
resourceID := "global_restore_keys"
taskStatus, err := s.taskService.StartTask(ctx, 0, TaskTypeRestoreKeys, resourceID, len(keys), time.Hour) resourceID := "global_restore_keys"
taskStatus, err := s.taskService.StartTask(ctx, 0, TaskTypeRestoreKeys, resourceID, len(keys), longTaskTimeout)
if err != nil { if err != nil {
return nil, err return nil, err
} }
go s.runTaskWithRecovery(context.Background(), taskStatus.ID, resourceID, func() { go s.runTaskWithRecovery(context.Background(), taskStatus.ID, resourceID, func() {
s.runRestoreKeysTask(context.Background(), taskStatus.ID, resourceID, keys) s.runRestoreKeysTask(context.Background(), taskStatus.ID, resourceID, keys)
}) })
return taskStatus, nil return taskStatus, nil
} }
func (s *KeyImportService) runAddKeysTask(ctx context.Context, taskID string, resourceID string, groupID uint, keys []string, validateOnImport bool) { // StartUnlinkKeysByFilterTask 根据状态过滤条件批量解绑
uniqueKeysMap := make(map[string]struct{}) func (s *KeyImportService) StartUnlinkKeysByFilterTask(ctx context.Context, groupID uint, statuses []string) (*task.Status, error) {
var uniqueKeyStrings []string s.logger.Infof("Starting unlink task for group %d with status filter: %v", groupID, statuses)
for _, kStr := range keys {
if _, exists := uniqueKeysMap[kStr]; !exists { keyValues, err := s.keyRepo.FindKeyValuesByStatus(groupID, statuses)
uniqueKeysMap[kStr] = struct{}{}
uniqueKeyStrings = append(uniqueKeyStrings, kStr)
}
}
if len(uniqueKeyStrings) == 0 {
s.taskService.EndTaskByID(ctx, taskID, resourceID, gin.H{"newly_linked_count": 0, "already_linked_count": 0}, nil)
return
}
keysToEnsure := make([]models.APIKey, len(uniqueKeyStrings))
for i, keyStr := range uniqueKeyStrings {
keysToEnsure[i] = models.APIKey{APIKey: keyStr}
}
allKeyModels, err := s.keyRepo.AddKeys(keysToEnsure)
if err != nil { if err != nil {
s.taskService.EndTaskByID(ctx, taskID, resourceID, nil, fmt.Errorf("failed to ensure keys exist: %w", err)) return nil, fmt.Errorf("failed to find keys by filter: %w", err)
return
} }
alreadyLinkedModels, err := s.keyRepo.GetKeysByValuesAndGroupID(uniqueKeyStrings, groupID) if len(keyValues) == 0 {
if err != nil { return nil, fmt.Errorf("no keys found matching the provided filter")
s.taskService.EndTaskByID(ctx, taskID, resourceID, nil, fmt.Errorf("failed to check for already linked keys: %w", err))
return
}
alreadyLinkedIDSet := make(map[uint]struct{})
for _, key := range alreadyLinkedModels {
alreadyLinkedIDSet[key.ID] = struct{}{}
}
var keysToLink []models.APIKey
for _, key := range allKeyModels {
if _, exists := alreadyLinkedIDSet[key.ID]; !exists {
keysToLink = append(keysToLink, key)
}
} }
keysAsText := strings.Join(keyValues, "\n")
s.logger.Infof("Found %d keys to unlink for group %d.", len(keyValues), groupID)
return s.StartUnlinkKeysTask(ctx, groupID, keysAsText)
}
// ==================== 核心任务执行逻辑 ====================
// runAddKeysTask 执行批量添加密钥
func (s *KeyImportService) runAddKeysTask(ctx context.Context, taskID string, resourceID string, groupID uint, keys []string, validateOnImport bool) {
// 1. 去重
uniqueKeys := s.deduplicateKeys(keys)
if len(uniqueKeys) == 0 {
s.endTaskWithResult(ctx, taskID, resourceID, gin.H{
"newly_linked_count": 0,
"already_linked_count": 0,
}, nil)
return
}
// 2. 确保所有密钥在数据库中存在(幂等操作)
allKeyModels, err := s.ensureKeysExist(uniqueKeys)
if err != nil {
s.endTaskWithResult(ctx, taskID, resourceID, nil, fmt.Errorf("failed to ensure keys exist: %w", err))
return
}
// 3. 过滤已关联的密钥
keysToLink, alreadyLinkedCount, err := s.filterNewKeys(allKeyModels, groupID, uniqueKeys)
if err != nil {
s.endTaskWithResult(ctx, taskID, resourceID, nil, fmt.Errorf("failed to check linked keys: %w", err))
return
}
// 4. 更新任务的实际处理总数
if err := s.taskService.UpdateTotalByID(ctx, taskID, len(keysToLink)); err != nil { if err := s.taskService.UpdateTotalByID(ctx, taskID, len(keysToLink)); err != nil {
s.logger.WithError(err).Warnf("Failed to update total for task %s", taskID) s.logger.WithError(err).Warnf("Failed to update total for task %s", taskID)
} }
// 5. 批量关联密钥到组
if len(keysToLink) > 0 { if len(keysToLink) > 0 {
idsToLink := make([]uint, len(keysToLink)) if err := s.linkKeysInChunks(ctx, taskID, groupID, keysToLink); err != nil {
for i, key := range keysToLink { s.endTaskWithResult(ctx, taskID, resourceID, nil, err)
idsToLink[i] = key.ID return
} }
for i := 0; i < len(idsToLink); i += chunkSize { }
end := i + chunkSize
if end > len(idsToLink) { // 6. 根据验证标志处理密钥状态
end = len(idsToLink) if len(keysToLink) > 0 {
} s.processNewlyLinkedKeys(ctx, groupID, keysToLink, validateOnImport)
chunk := idsToLink[i:end] }
if err := s.keyRepo.LinkKeysToGroup(ctx, groupID, chunk); err != nil {
s.taskService.EndTaskByID(ctx, taskID, resourceID, nil, fmt.Errorf("chunk failed to link keys: %w", err)) // 7. 返回结果
return result := gin.H{
} "newly_linked_count": len(keysToLink),
_ = s.taskService.UpdateProgressByID(ctx, taskID, i+len(chunk)) "already_linked_count": alreadyLinkedCount,
"total_linked_count": len(allKeyModels),
}
s.endTaskWithResult(ctx, taskID, resourceID, result, nil)
}
// runUnlinkKeysTask 执行批量解绑密钥
func (s *KeyImportService) runUnlinkKeysTask(ctx context.Context, taskID, resourceID string, groupID uint, keys []string) {
// 1. 去重
uniqueKeys := s.deduplicateKeys(keys)
// 2. 查找需要解绑的密钥
keysToUnlink, err := s.keyRepo.GetKeysByValuesAndGroupID(uniqueKeys, groupID)
if err != nil {
s.endTaskWithResult(ctx, taskID, resourceID, nil, fmt.Errorf("failed to find keys to unlink: %w", err))
return
}
if len(keysToUnlink) == 0 {
result := gin.H{
"unlinked_count": 0,
"hard_deleted_count": 0,
"not_found_count": len(uniqueKeys),
} }
s.endTaskWithResult(ctx, taskID, resourceID, result, nil)
return
}
// 3. 提取密钥 ID
idsToUnlink := s.extractKeyIDs(keysToUnlink)
// 4. 更新任务总数
if err := s.taskService.UpdateTotalByID(ctx, taskID, len(idsToUnlink)); err != nil {
s.logger.WithError(err).Warnf("Failed to update total for task %s", taskID)
}
// 5. 批量解绑
totalUnlinked, err := s.unlinkKeysInChunks(ctx, taskID, groupID, idsToUnlink)
if err != nil {
s.endTaskWithResult(ctx, taskID, resourceID, nil, err)
return
}
// 6. 清理孤立密钥
totalDeleted, err := s.keyRepo.DeleteOrphanKeys()
if err != nil {
s.logger.WithError(err).Warn("Orphan key cleanup failed after unlink task.")
}
// 7. 返回结果
result := gin.H{
"unlinked_count": totalUnlinked,
"hard_deleted_count": totalDeleted,
"not_found_count": len(uniqueKeys) - int(totalUnlinked),
}
s.endTaskWithResult(ctx, taskID, resourceID, result, nil)
}
// runHardDeleteKeysTask 执行硬删除密钥
func (s *KeyImportService) runHardDeleteKeysTask(ctx context.Context, taskID, resourceID string, keys []string) {
totalDeleted, err := s.processKeysInChunks(ctx, taskID, keys, func(chunk []string) (int64, error) {
return s.keyRepo.HardDeleteByValues(chunk)
})
if err != nil {
s.endTaskWithResult(ctx, taskID, resourceID, nil, err)
return
} }
result := gin.H{ result := gin.H{
"newly_linked_count": len(keysToLink), "hard_deleted_count": totalDeleted,
"already_linked_count": len(alreadyLinkedIDSet), "not_found_count": int64(len(keys)) - totalDeleted,
"total_linked_count": len(allKeyModels),
} }
if len(keysToLink) > 0 { s.endTaskWithResult(ctx, taskID, resourceID, result, nil)
idsToLink := make([]uint, len(keysToLink)) s.publishChangeEvent(ctx, 0, "keys_hard_deleted")
for i, key := range keysToLink {
idsToLink[i] = key.ID
}
if validateOnImport {
s.publishImportGroupCompletedEvent(ctx, groupID, idsToLink)
for _, keyID := range idsToLink {
s.publishSingleKeyChangeEvent(ctx, groupID, keyID, "", models.StatusPendingValidation, "key_linked")
}
} else {
for _, keyID := range idsToLink {
if _, err := s.apiKeyService.UpdateMappingStatus(ctx, groupID, keyID, models.StatusActive); err != nil {
s.logger.Errorf("Failed to directly activate key ID %d in group %d: %v", keyID, groupID, err)
}
}
}
}
s.taskService.EndTaskByID(ctx, taskID, resourceID, result, nil)
} }
func (s *KeyImportService) runUnlinkKeysTask(ctx context.Context, taskID, resourceID string, groupID uint, keys []string) { // runRestoreKeysTask 执行恢复密钥
uniqueKeysMap := make(map[string]struct{}) func (s *KeyImportService) runRestoreKeysTask(ctx context.Context, taskID, resourceID string, keys []string) {
var uniqueKeys []string restoredCount, err := s.processKeysInChunks(ctx, taskID, keys, func(chunk []string) (int64, error) {
return s.keyRepo.UpdateMasterStatusByValues(chunk, models.MasterStatusActive)
})
if err != nil {
s.endTaskWithResult(ctx, taskID, resourceID, nil, err)
return
}
result := gin.H{
"restored_count": restoredCount,
"not_found_count": int64(len(keys)) - restoredCount,
}
s.endTaskWithResult(ctx, taskID, resourceID, result, nil)
s.publishChangeEvent(ctx, 0, "keys_bulk_restored")
}
// ==================== 辅助方法 ====================
// deduplicateKeys 去重密钥列表
func (s *KeyImportService) deduplicateKeys(keys []string) []string {
uniqueKeysMap := make(map[string]struct{}, len(keys))
uniqueKeys := make([]string, 0, len(keys))
for _, kStr := range keys { for _, kStr := range keys {
if _, exists := uniqueKeysMap[kStr]; !exists { if _, exists := uniqueKeysMap[kStr]; !exists {
uniqueKeysMap[kStr] = struct{}{} uniqueKeysMap[kStr] = struct{}{}
uniqueKeys = append(uniqueKeys, kStr) uniqueKeys = append(uniqueKeys, kStr)
} }
} }
keysToUnlink, err := s.keyRepo.GetKeysByValuesAndGroupID(uniqueKeys, groupID) return uniqueKeys
}
// ensureKeysExist 确保所有密钥在数据库中存在
func (s *KeyImportService) ensureKeysExist(keys []string) ([]models.APIKey, error) {
keysToEnsure := make([]models.APIKey, len(keys))
for i, keyStr := range keys {
keysToEnsure[i] = models.APIKey{APIKey: keyStr}
}
return s.keyRepo.AddKeys(keysToEnsure)
}
// filterNewKeys 过滤已关联的密钥,返回需要新增的密钥
func (s *KeyImportService) filterNewKeys(allKeyModels []models.APIKey, groupID uint, uniqueKeys []string) ([]models.APIKey, int, error) {
alreadyLinkedModels, err := s.keyRepo.GetKeysByValuesAndGroupID(uniqueKeys, groupID)
if err != nil { if err != nil {
s.taskService.EndTaskByID(ctx, taskID, resourceID, nil, fmt.Errorf("failed to find keys to unlink: %w", err)) return nil, 0, err
return
} }
if len(keysToUnlink) == 0 { alreadyLinkedIDSet := make(map[uint]struct{}, len(alreadyLinkedModels))
result := gin.H{"unlinked_count": 0, "hard_deleted_count": 0, "not_found_count": len(uniqueKeys)} for _, key := range alreadyLinkedModels {
s.taskService.EndTaskByID(ctx, taskID, resourceID, result, nil) alreadyLinkedIDSet[key.ID] = struct{}{}
return
}
idsToUnlink := make([]uint, len(keysToUnlink))
for i, key := range keysToUnlink {
idsToUnlink[i] = key.ID
} }
if err := s.taskService.UpdateTotalByID(ctx, taskID, len(idsToUnlink)); err != nil { keysToLink := make([]models.APIKey, 0, len(allKeyModels)-len(alreadyLinkedIDSet))
s.logger.WithError(err).Warnf("Failed to update total for task %s", taskID) for _, key := range allKeyModels {
} if _, exists := alreadyLinkedIDSet[key.ID]; !exists {
var totalUnlinked int64 keysToLink = append(keysToLink, key)
for i := 0; i < len(idsToUnlink); i += chunkSize {
end := i + chunkSize
if end > len(idsToUnlink) {
end = len(idsToUnlink)
} }
}
return keysToLink, len(alreadyLinkedIDSet), nil
}
// extractKeyIDs 提取密钥 ID 列表
func (s *KeyImportService) extractKeyIDs(keys []models.APIKey) []uint {
ids := make([]uint, len(keys))
for i, key := range keys {
ids[i] = key.ID
}
return ids
}
// linkKeysInChunks 分块关联密钥到组
func (s *KeyImportService) linkKeysInChunks(ctx context.Context, taskID string, groupID uint, keysToLink []models.APIKey) error {
idsToLink := s.extractKeyIDs(keysToLink)
for i := 0; i < len(idsToLink); i += chunkSize {
end := min(i+chunkSize, len(idsToLink))
chunk := idsToLink[i:end]
if err := s.keyRepo.LinkKeysToGroup(ctx, groupID, chunk); err != nil {
return fmt.Errorf("chunk failed to link keys: %w", err)
}
_ = s.taskService.UpdateProgressByID(ctx, taskID, end)
}
return nil
}
// unlinkKeysInChunks 分块解绑密钥
func (s *KeyImportService) unlinkKeysInChunks(ctx context.Context, taskID string, groupID uint, idsToUnlink []uint) (int64, error) {
var totalUnlinked int64
for i := 0; i < len(idsToUnlink); i += chunkSize {
end := min(i+chunkSize, len(idsToUnlink))
chunk := idsToUnlink[i:end] chunk := idsToUnlink[i:end]
unlinked, err := s.keyRepo.UnlinkKeysFromGroup(ctx, groupID, chunk) unlinked, err := s.keyRepo.UnlinkKeysFromGroup(ctx, groupID, chunk)
if err != nil { if err != nil {
s.taskService.EndTaskByID(ctx, taskID, resourceID, nil, fmt.Errorf("chunk failed: could not unlink keys: %w", err)) return 0, fmt.Errorf("chunk failed: could not unlink keys: %w", err)
return
} }
totalUnlinked += unlinked totalUnlinked += unlinked
// 发布解绑事件
for _, keyID := range chunk { for _, keyID := range chunk {
s.publishSingleKeyChangeEvent(ctx, groupID, keyID, models.StatusActive, "", "key_unlinked") s.publishSingleKeyChangeEvent(ctx, groupID, keyID, models.StatusActive, "", "key_unlinked")
} }
_ = s.taskService.UpdateProgressByID(ctx, taskID, i+len(chunk))
_ = s.taskService.UpdateProgressByID(ctx, taskID, end)
} }
totalDeleted, err := s.keyRepo.DeleteOrphanKeys() return totalUnlinked, nil
}
// processKeysInChunks 通用的分块处理密钥逻辑
func (s *KeyImportService) processKeysInChunks(
ctx context.Context,
taskID string,
keys []string,
processFunc func(chunk []string) (int64, error),
) (int64, error) {
var totalProcessed int64
for i := 0; i < len(keys); i += chunkSize {
end := min(i+chunkSize, len(keys))
chunk := keys[i:end]
count, err := processFunc(chunk)
if err != nil {
return 0, fmt.Errorf("failed to process chunk: %w", err)
}
totalProcessed += count
_ = s.taskService.UpdateProgressByID(ctx, taskID, end)
}
return totalProcessed, nil
}
// processNewlyLinkedKeys 处理新关联的密钥(验证或直接激活)
func (s *KeyImportService) processNewlyLinkedKeys(ctx context.Context, groupID uint, keysToLink []models.APIKey, validateOnImport bool) {
idsToLink := s.extractKeyIDs(keysToLink)
if validateOnImport {
// 发布批量导入完成事件,触发验证
s.publishImportGroupCompletedEvent(ctx, groupID, idsToLink)
// 发布单个密钥状态变更事件
for _, keyID := range idsToLink {
s.publishSingleKeyChangeEvent(ctx, groupID, keyID, "", models.StatusPendingValidation, "key_linked")
}
} else {
// 直接激活密钥,不进行验证
for _, keyID := range idsToLink {
if _, err := s.apiKeyService.UpdateMappingStatus(ctx, groupID, keyID, models.StatusActive); err != nil {
s.logger.WithFields(logrus.Fields{
"group_id": groupID,
"key_id": keyID,
}).Errorf("Failed to directly activate key: %v", err)
}
}
}
}
// endTaskWithResult 统一的任务结束处理
func (s *KeyImportService) endTaskWithResult(ctx context.Context, taskID, resourceID string, result gin.H, err error) {
if err != nil { if err != nil {
s.logger.WithError(err).Warn("Orphan key cleanup failed after unlink task.") s.logger.WithFields(logrus.Fields{
"task_id": taskID,
"resource_id": resourceID,
}).WithError(err).Error("Task failed")
} }
result := gin.H{ s.taskService.EndTaskByID(ctx, taskID, resourceID, result, err)
"unlinked_count": totalUnlinked,
"hard_deleted_count": totalDeleted,
"not_found_count": len(uniqueKeys) - int(totalUnlinked),
}
s.taskService.EndTaskByID(ctx, taskID, resourceID, result, nil)
} }
func (s *KeyImportService) runHardDeleteKeysTask(ctx context.Context, taskID, resourceID string, keys []string) { // ==================== 事件发布方法 ====================
var totalDeleted int64
for i := 0; i < len(keys); i += chunkSize {
end := i + chunkSize
if end > len(keys) {
end = len(keys)
}
chunk := keys[i:end]
deleted, err := s.keyRepo.HardDeleteByValues(chunk)
if err != nil {
s.taskService.EndTaskByID(ctx, taskID, resourceID, nil, fmt.Errorf("failed to hard delete chunk: %w", err))
return
}
totalDeleted += deleted
_ = s.taskService.UpdateProgressByID(ctx, taskID, i+len(chunk))
}
result := gin.H{
"hard_deleted_count": totalDeleted,
"not_found_count": int64(len(keys)) - totalDeleted,
}
s.taskService.EndTaskByID(ctx, taskID, resourceID, result, nil)
s.publishChangeEvent(ctx, 0, "keys_hard_deleted")
}
func (s *KeyImportService) runRestoreKeysTask(ctx context.Context, taskID, resourceID string, keys []string) {
var restoredCount int64
for i := 0; i < len(keys); i += chunkSize {
end := i + chunkSize
if end > len(keys) {
end = len(keys)
}
chunk := keys[i:end]
count, err := s.keyRepo.UpdateMasterStatusByValues(chunk, models.MasterStatusActive)
if err != nil {
s.taskService.EndTaskByID(ctx, taskID, resourceID, nil, fmt.Errorf("failed to restore chunk: %w", err))
return
}
restoredCount += count
_ = s.taskService.UpdateProgressByID(ctx, taskID, i+len(chunk))
}
result := gin.H{
"restored_count": restoredCount,
"not_found_count": int64(len(keys)) - restoredCount,
}
s.taskService.EndTaskByID(ctx, taskID, resourceID, result, nil)
s.publishChangeEvent(ctx, 0, "keys_bulk_restored")
}
// publishSingleKeyChangeEvent 发布单个密钥状态变更事件
func (s *KeyImportService) publishSingleKeyChangeEvent(ctx context.Context, groupID, keyID uint, oldStatus, newStatus models.APIKeyStatus, reason string) { func (s *KeyImportService) publishSingleKeyChangeEvent(ctx context.Context, groupID, keyID uint, oldStatus, newStatus models.APIKeyStatus, reason string) {
event := models.KeyStatusChangedEvent{ event := models.KeyStatusChangedEvent{
GroupID: groupID, GroupID: groupID,
@@ -324,56 +474,88 @@ func (s *KeyImportService) publishSingleKeyChangeEvent(ctx context.Context, grou
ChangeReason: reason, ChangeReason: reason,
ChangedAt: time.Now(), ChangedAt: time.Now(),
} }
eventData, _ := json.Marshal(event)
if err := s.store.Publish(ctx, models.TopicKeyStatusChanged, eventData); err != nil { eventData, err := json.Marshal(event)
s.logger.WithError(err).WithFields(logrus.Fields{ if err != nil {
s.logger.WithFields(logrus.Fields{
"group_id": groupID, "group_id": groupID,
"key_id": keyID, "key_id": keyID,
"reason": reason, "reason": reason,
}).Error("Failed to publish single key change event.") }).WithError(err).Error("Failed to marshal key change event")
return
}
if err := s.store.Publish(ctx, models.TopicKeyStatusChanged, eventData); err != nil {
s.logger.WithFields(logrus.Fields{
"group_id": groupID,
"key_id": keyID,
"reason": reason,
}).WithError(err).Error("Failed to publish single key change event")
} }
} }
// publishChangeEvent 发布通用变更事件
func (s *KeyImportService) publishChangeEvent(ctx context.Context, groupID uint, reason string) { func (s *KeyImportService) publishChangeEvent(ctx context.Context, groupID uint, reason string) {
event := models.KeyStatusChangedEvent{ event := models.KeyStatusChangedEvent{
GroupID: groupID, GroupID: groupID,
ChangeReason: reason, ChangeReason: reason,
} }
eventData, _ := json.Marshal(event)
_ = s.store.Publish(ctx, models.TopicKeyStatusChanged, eventData) eventData, err := json.Marshal(event)
if err != nil {
s.logger.WithFields(logrus.Fields{
"group_id": groupID,
"reason": reason,
}).WithError(err).Error("Failed to marshal change event")
return
}
if err := s.store.Publish(ctx, models.TopicKeyStatusChanged, eventData); err != nil {
s.logger.WithFields(logrus.Fields{
"group_id": groupID,
"reason": reason,
}).WithError(err).Error("Failed to publish change event")
}
} }
// publishImportGroupCompletedEvent 发布批量导入完成事件
func (s *KeyImportService) publishImportGroupCompletedEvent(ctx context.Context, groupID uint, keyIDs []uint) { func (s *KeyImportService) publishImportGroupCompletedEvent(ctx context.Context, groupID uint, keyIDs []uint) {
if len(keyIDs) == 0 { if len(keyIDs) == 0 {
return return
} }
event := models.ImportGroupCompletedEvent{ event := models.ImportGroupCompletedEvent{
GroupID: groupID, GroupID: groupID,
KeyIDs: keyIDs, KeyIDs: keyIDs,
CompletedAt: time.Now(), CompletedAt: time.Now(),
} }
eventData, err := json.Marshal(event) eventData, err := json.Marshal(event)
if err != nil { if err != nil {
s.logger.WithError(err).Error("Failed to marshal ImportGroupCompletedEvent.") s.logger.WithFields(logrus.Fields{
"group_id": groupID,
"key_count": len(keyIDs),
}).WithError(err).Error("Failed to marshal ImportGroupCompletedEvent")
return return
} }
if err := s.store.Publish(ctx, models.TopicImportGroupCompleted, eventData); err != nil { if err := s.store.Publish(ctx, models.TopicImportGroupCompleted, eventData); err != nil {
s.logger.WithError(err).Error("Failed to publish ImportGroupCompletedEvent.") s.logger.WithFields(logrus.Fields{
"group_id": groupID,
"key_count": len(keyIDs),
}).WithError(err).Error("Failed to publish ImportGroupCompletedEvent")
} else { } else {
s.logger.Infof("Published ImportGroupCompletedEvent for group %d with %d keys.", groupID, len(keyIDs)) s.logger.WithFields(logrus.Fields{
"group_id": groupID,
"key_count": len(keyIDs),
}).Info("Published ImportGroupCompletedEvent")
} }
} }
func (s *KeyImportService) StartUnlinkKeysByFilterTask(ctx context.Context, groupID uint, statuses []string) (*task.Status, error) { // min 返回两个整数中的较小值
s.logger.Infof("Starting unlink task for group %d with status filter: %v", groupID, statuses) func min(a, b int) int {
keyValues, err := s.keyRepo.FindKeyValuesByStatus(groupID, statuses) if a < b {
if err != nil { return a
return nil, fmt.Errorf("failed to find keys by filter: %w", err)
} }
if len(keyValues) == 0 { return b
return nil, fmt.Errorf("no keys found matching the provided filter")
}
keysAsText := strings.Join(keyValues, "\n")
s.logger.Infof("Found %d keys to unlink for group %d.", len(keyValues), groupID)
return s.StartUnlinkKeysTask(ctx, groupID, keysAsText)
} }

View File

@@ -25,26 +25,38 @@ import (
) )
const ( const (
TaskTypeTestKeys = "test_keys" TaskTypeTestKeys = "test_keys"
defaultConcurrency = 10
maxValidationConcurrency = 100
validationTaskTimeout = time.Hour
) )
type KeyValidationService struct { type KeyValidationService struct {
taskService task.Reporter taskService task.Reporter
channel channel.ChannelProxy channel channel.ChannelProxy
db *gorm.DB db *gorm.DB
SettingsManager *settings.SettingsManager settingsManager *settings.SettingsManager
groupManager *GroupManager groupManager *GroupManager
store store.Store store store.Store
keyRepo repository.KeyRepository keyRepo repository.KeyRepository
logger *logrus.Entry logger *logrus.Entry
} }
func NewKeyValidationService(ts task.Reporter, ch channel.ChannelProxy, db *gorm.DB, ss *settings.SettingsManager, gm *GroupManager, st store.Store, kr repository.KeyRepository, logger *logrus.Logger) *KeyValidationService { func NewKeyValidationService(
ts task.Reporter,
ch channel.ChannelProxy,
db *gorm.DB,
ss *settings.SettingsManager,
gm *GroupManager,
st store.Store,
kr repository.KeyRepository,
logger *logrus.Logger,
) *KeyValidationService {
return &KeyValidationService{ return &KeyValidationService{
taskService: ts, taskService: ts,
channel: ch, channel: ch,
db: db, db: db,
SettingsManager: ss, settingsManager: ss,
groupManager: gm, groupManager: gm,
store: st, store: st,
keyRepo: kr, keyRepo: kr,
@@ -52,33 +64,393 @@ func NewKeyValidationService(ts task.Reporter, ch channel.ChannelProxy, db *gorm
} }
} }
// ==================== 公开接口 ====================
// ValidateSingleKey 验证单个密钥
func (s *KeyValidationService) ValidateSingleKey(key *models.APIKey, timeout time.Duration, endpoint string) error { func (s *KeyValidationService) ValidateSingleKey(key *models.APIKey, timeout time.Duration, endpoint string) error {
// 1. 解密密钥
if err := s.keyRepo.Decrypt(key); err != nil { if err := s.keyRepo.Decrypt(key); err != nil {
return fmt.Errorf("failed to decrypt key %d for validation: %w", key.ID, err) return fmt.Errorf("failed to decrypt key %d for validation: %w", key.ID, err)
} }
// 2. 创建 HTTP 客户端和请求
client := &http.Client{Timeout: timeout} client := &http.Client{Timeout: timeout}
req, err := http.NewRequest("GET", endpoint, nil) req, err := http.NewRequest("GET", endpoint, nil)
if err != nil { if err != nil {
s.logger.Errorf("Failed to create request for key validation (ID: %d): %v", key.ID, err) s.logger.WithFields(logrus.Fields{
"key_id": key.ID,
"endpoint": endpoint,
}).Error("Failed to create validation request")
return fmt.Errorf("failed to create request: %w", err) return fmt.Errorf("failed to create request: %w", err)
} }
// 3. 修改请求(添加密钥认证头)
s.channel.ModifyRequest(req, key) s.channel.ModifyRequest(req, key)
// 4. 执行请求
resp, err := client.Do(req) resp, err := client.Do(req)
if err != nil { if err != nil {
return fmt.Errorf("request failed: %w", err) return fmt.Errorf("request failed: %w", err)
} }
defer resp.Body.Close() defer resp.Body.Close()
// 5. 检查响应状态
if resp.StatusCode == http.StatusOK { if resp.StatusCode == http.StatusOK {
return nil return nil
} }
// 6. 处理错误响应
return s.buildValidationError(resp)
}
// StartTestKeysTask 启动批量密钥测试任务
func (s *KeyValidationService) StartTestKeysTask(ctx context.Context, groupID uint, keysText string) (*task.Status, error) {
// 1. 解析和验证输入
keyStrings := utils.ParseKeysFromText(keysText)
if len(keyStrings) == 0 {
return nil, CustomErrors.NewAPIError(CustomErrors.ErrBadRequest, "no valid keys found in input text")
}
// 2. 查询密钥模型
apiKeyModels, err := s.keyRepo.GetKeysByValues(keyStrings)
if err != nil {
return nil, CustomErrors.ParseDBError(err)
}
if len(apiKeyModels) == 0 {
return nil, CustomErrors.NewAPIError(CustomErrors.ErrResourceNotFound, "none of the provided keys were found in the system")
}
// 3. 批量解密密钥
if err := s.keyRepo.DecryptBatch(apiKeyModels); err != nil {
s.logger.WithError(err).Error("Failed to batch decrypt keys for validation task")
return nil, CustomErrors.NewAPIError(CustomErrors.ErrInternalServer, "Failed to decrypt keys for validation")
}
// 4. 获取组配置
group, ok := s.groupManager.GetGroupByID(groupID)
if !ok {
return nil, CustomErrors.NewAPIError(CustomErrors.ErrGroupNotFound, fmt.Sprintf("group with id %d not found", groupID))
}
opConfig, err := s.groupManager.BuildOperationalConfig(group)
if err != nil {
return nil, CustomErrors.NewAPIError(CustomErrors.ErrInternalServer, fmt.Sprintf("failed to build operational config: %v", err))
}
// 5. 构建验证端点
endpoint, err := s.groupManager.BuildKeyCheckEndpoint(groupID)
if err != nil {
return nil, CustomErrors.NewAPIError(CustomErrors.ErrInternalServer, fmt.Sprintf("failed to build endpoint: %v", err))
}
// 6. 创建任务
resourceID := fmt.Sprintf("group-%d", groupID)
taskStatus, err := s.taskService.StartTask(ctx, groupID, TaskTypeTestKeys, resourceID, len(apiKeyModels), validationTaskTimeout)
if err != nil {
return nil, err
}
// 7. 准备任务参数
params := s.buildValidationParams(opConfig)
// 8. 启动异步验证任务
go s.runTestKeysTaskWithRecovery(context.Background(), taskStatus.ID, resourceID, groupID, apiKeyModels, params, endpoint)
return taskStatus, nil
}
// StartTestKeysByFilterTask 根据状态过滤启动批量测试任务
func (s *KeyValidationService) StartTestKeysByFilterTask(ctx context.Context, groupID uint, statuses []string) (*task.Status, error) {
s.logger.WithFields(logrus.Fields{
"group_id": groupID,
"statuses": statuses,
}).Info("Starting test task with status filter")
// 1. 根据过滤条件查询密钥
keyValues, err := s.keyRepo.FindKeyValuesByStatus(groupID, statuses)
if err != nil {
return nil, CustomErrors.ParseDBError(err)
}
if len(keyValues) == 0 {
return nil, CustomErrors.NewAPIError(CustomErrors.ErrNotFound, "No keys found matching the filter criteria")
}
// 2. 转换为文本格式并启动任务
keysAsText := strings.Join(keyValues, "\n")
s.logger.Infof("Found %d keys to validate for group %d", len(keyValues), groupID)
return s.StartTestKeysTask(ctx, groupID, keysAsText)
}
// ==================== 核心任务执行逻辑 ====================
// validationParams 验证参数封装
type validationParams struct {
timeout time.Duration
concurrency int
}
// buildValidationParams 构建验证参数
func (s *KeyValidationService) buildValidationParams(opConfig *models.KeyGroupSettings) validationParams {
settings := s.settingsManager.GetSettings()
// 从配置读取超时时间(而非硬编码)
timeout := time.Duration(settings.KeyCheckTimeoutSeconds) * time.Second
if timeout <= 0 {
timeout = 30 * time.Second // 仅在配置无效时使用默认值
}
// 从配置读取并发数(优先级:组配置 > 全局配置 > 兜底默认值)
var concurrency int
if opConfig.KeyCheckConcurrency != nil && *opConfig.KeyCheckConcurrency > 0 {
concurrency = *opConfig.KeyCheckConcurrency
} else if settings.BaseKeyCheckConcurrency > 0 {
concurrency = settings.BaseKeyCheckConcurrency
} else {
concurrency = defaultConcurrency // 兜底默认值
}
// 限制最大并发数(防护措施)
if concurrency > maxValidationConcurrency {
concurrency = maxValidationConcurrency
}
return validationParams{
timeout: timeout,
concurrency: concurrency,
}
}
// runTestKeysTaskWithRecovery 带恢复机制的任务执行包装器
func (s *KeyValidationService) runTestKeysTaskWithRecovery(
ctx context.Context,
taskID string,
resourceID string,
groupID uint,
keys []models.APIKey,
params validationParams,
endpoint string,
) {
defer func() {
if r := recover(); r != nil {
err := fmt.Errorf("panic recovered in validation task %s: %v", taskID, r)
s.logger.WithField("task_id", taskID).Error(err)
s.taskService.EndTaskByID(ctx, taskID, resourceID, nil, err)
}
}()
s.runTestKeysTask(ctx, taskID, resourceID, groupID, keys, params, endpoint)
}
// runTestKeysTask 执行批量密钥验证任务
func (s *KeyValidationService) runTestKeysTask(
ctx context.Context,
taskID string,
resourceID string,
groupID uint,
keys []models.APIKey,
params validationParams,
endpoint string,
) {
s.logger.WithFields(logrus.Fields{
"task_id": taskID,
"group_id": groupID,
"key_count": len(keys),
"concurrency": params.concurrency,
"timeout": params.timeout,
}).Info("Starting validation task")
// 1. 初始化结果收集
results := make([]models.KeyTestResult, len(keys))
// 2. 创建任务分发器
dispatcher := newValidationDispatcher(
keys,
params.concurrency,
s,
ctx,
taskID,
groupID,
endpoint,
params.timeout,
)
// 3. 执行并发验证
dispatcher.run(results)
// 4. 完成任务
s.taskService.EndTaskByID(ctx, taskID, resourceID, gin.H{"results": results}, nil)
s.logger.WithFields(logrus.Fields{
"task_id": taskID,
"group_id": groupID,
"processed": len(results),
}).Info("Validation task completed")
}
// ==================== 验证调度器 ====================
// validationJob 验证作业
type validationJob struct {
index int
key models.APIKey
}
// validationDispatcher 验证任务分发器
type validationDispatcher struct {
keys []models.APIKey
concurrency int
service *KeyValidationService
ctx context.Context
taskID string
groupID uint
endpoint string
timeout time.Duration
mu sync.Mutex
processedCount int
}
// newValidationDispatcher 创建验证分发器
func newValidationDispatcher(
keys []models.APIKey,
concurrency int,
service *KeyValidationService,
ctx context.Context,
taskID string,
groupID uint,
endpoint string,
timeout time.Duration,
) *validationDispatcher {
return &validationDispatcher{
keys: keys,
concurrency: concurrency,
service: service,
ctx: ctx,
taskID: taskID,
groupID: groupID,
endpoint: endpoint,
timeout: timeout,
}
}
// run 执行并发验证
func (d *validationDispatcher) run(results []models.KeyTestResult) {
var wg sync.WaitGroup
jobs := make(chan validationJob, len(d.keys))
// 启动 worker pool
for i := 0; i < d.concurrency; i++ {
wg.Add(1)
go d.worker(&wg, jobs, results)
}
// 分发任务
for i, key := range d.keys {
jobs <- validationJob{index: i, key: key}
}
close(jobs)
// 等待所有 worker 完成
wg.Wait()
}
// worker 验证工作协程
func (d *validationDispatcher) worker(wg *sync.WaitGroup, jobs <-chan validationJob, results []models.KeyTestResult) {
defer wg.Done()
for job := range jobs {
result := d.validateKey(job.key)
d.mu.Lock()
results[job.index] = result
d.processedCount++
_ = d.service.taskService.UpdateProgressByID(d.ctx, d.taskID, d.processedCount)
d.mu.Unlock()
}
}
// validateKey 验证单个密钥并返回结果
func (d *validationDispatcher) validateKey(key models.APIKey) models.KeyTestResult {
// 1. 执行验证
validationErr := d.service.ValidateSingleKey(&key, d.timeout, d.endpoint)
// 2. 构建结果和事件
result, event := d.buildResultAndEvent(key, validationErr)
// 3. 发布验证事件
d.publishValidationEvent(key.ID, event)
return result
}
// buildResultAndEvent 构建验证结果和事件
func (d *validationDispatcher) buildResultAndEvent(key models.APIKey, validationErr error) (models.KeyTestResult, models.RequestFinishedEvent) {
event := models.RequestFinishedEvent{
RequestLog: models.RequestLog{
GroupID: &d.groupID,
KeyID: &key.ID,
},
}
if validationErr == nil {
// 验证成功
event.RequestLog.IsSuccess = true
return models.KeyTestResult{
Key: key.APIKey,
Status: "valid",
Message: "Validation successful",
}, event
}
// 验证失败
event.RequestLog.IsSuccess = false
var apiErr *CustomErrors.APIError
if CustomErrors.As(validationErr, &apiErr) {
event.Error = apiErr
return models.KeyTestResult{
Key: key.APIKey,
Status: "invalid",
Message: fmt.Sprintf("Invalid key (HTTP %d): %s", apiErr.HTTPStatus, apiErr.Message),
}, event
}
// 其他错误
event.Error = &CustomErrors.APIError{Message: validationErr.Error()}
return models.KeyTestResult{
Key: key.APIKey,
Status: "error",
Message: "Validation check failed: " + validationErr.Error(),
}, event
}
// publishValidationEvent 发布验证事件
func (d *validationDispatcher) publishValidationEvent(keyID uint, event models.RequestFinishedEvent) {
eventData, err := json.Marshal(event)
if err != nil {
d.service.logger.WithFields(logrus.Fields{
"key_id": keyID,
"group_id": d.groupID,
}).WithError(err).Error("Failed to marshal validation event")
return
}
if err := d.service.store.Publish(d.ctx, models.TopicRequestFinished, eventData); err != nil {
d.service.logger.WithFields(logrus.Fields{
"key_id": keyID,
"group_id": d.groupID,
}).WithError(err).Error("Failed to publish validation event")
}
}
// ==================== 辅助方法 ====================
// buildValidationError 构建验证错误
func (s *KeyValidationService) buildValidationError(resp *http.Response) error {
bodyBytes, readErr := io.ReadAll(resp.Body) bodyBytes, readErr := io.ReadAll(resp.Body)
var errorMsg string var errorMsg string
if readErr != nil { if readErr != nil {
errorMsg = "Failed to read error response body" errorMsg = "Failed to read error response body"
s.logger.WithError(readErr).Warn("Failed to read validation error response")
} else { } else {
errorMsg = string(bodyBytes) errorMsg = string(bodyBytes)
} }
@@ -89,128 +461,3 @@ func (s *KeyValidationService) ValidateSingleKey(key *models.APIKey, timeout tim
Code: "VALIDATION_FAILED", Code: "VALIDATION_FAILED",
} }
} }
func (s *KeyValidationService) StartTestKeysTask(ctx context.Context, groupID uint, keysText string) (*task.Status, error) {
keyStrings := utils.ParseKeysFromText(keysText)
if len(keyStrings) == 0 {
return nil, CustomErrors.NewAPIError(CustomErrors.ErrBadRequest, "no valid keys found in input text")
}
apiKeyModels, err := s.keyRepo.GetKeysByValues(keyStrings)
if err != nil {
return nil, CustomErrors.ParseDBError(err)
}
if len(apiKeyModels) == 0 {
return nil, CustomErrors.NewAPIError(CustomErrors.ErrResourceNotFound, "none of the provided keys were found in the system")
}
if err := s.keyRepo.DecryptBatch(apiKeyModels); err != nil {
s.logger.WithError(err).Error("Failed to batch decrypt keys for validation task.")
return nil, CustomErrors.NewAPIError(CustomErrors.ErrInternalServer, "Failed to decrypt keys for validation")
}
group, ok := s.groupManager.GetGroupByID(groupID)
if !ok {
return nil, CustomErrors.NewAPIError(CustomErrors.ErrGroupNotFound, fmt.Sprintf("group with id %d not found", groupID))
}
opConfig, err := s.groupManager.BuildOperationalConfig(group)
if err != nil {
return nil, CustomErrors.NewAPIError(CustomErrors.ErrInternalServer, fmt.Sprintf("failed to build operational config: %v", err))
}
resourceID := fmt.Sprintf("group-%d", groupID)
taskStatus, err := s.taskService.StartTask(ctx, groupID, TaskTypeTestKeys, resourceID, len(apiKeyModels), time.Hour)
if err != nil {
return nil, err
}
settings := s.SettingsManager.GetSettings()
timeout := time.Duration(settings.KeyCheckTimeoutSeconds) * time.Second
endpoint, err := s.groupManager.BuildKeyCheckEndpoint(groupID)
if err != nil {
s.taskService.EndTaskByID(ctx, taskStatus.ID, resourceID, nil, err)
return nil, err
}
var concurrency int
if opConfig.KeyCheckConcurrency != nil {
concurrency = *opConfig.KeyCheckConcurrency
} else {
concurrency = settings.BaseKeyCheckConcurrency
}
go s.runTestKeysTask(context.Background(), taskStatus.ID, resourceID, groupID, apiKeyModels, timeout, endpoint, concurrency)
return taskStatus, nil
}
func (s *KeyValidationService) runTestKeysTask(ctx context.Context, taskID string, resourceID string, groupID uint, keys []models.APIKey, timeout time.Duration, endpoint string, concurrency int) {
var wg sync.WaitGroup
var mu sync.Mutex
finalResults := make([]models.KeyTestResult, len(keys))
processedCount := 0
if concurrency <= 0 {
concurrency = 10
}
type job struct {
Index int
Value models.APIKey
}
jobs := make(chan job, len(keys))
for i := 0; i < concurrency; i++ {
wg.Add(1)
go func() {
defer wg.Done()
for j := range jobs {
apiKeyModel := j.Value
keyToValidate := apiKeyModel
validationErr := s.ValidateSingleKey(&keyToValidate, timeout, endpoint)
var currentResult models.KeyTestResult
event := models.RequestFinishedEvent{
RequestLog: models.RequestLog{
GroupID: &groupID,
KeyID: &apiKeyModel.ID,
},
}
if validationErr == nil {
currentResult = models.KeyTestResult{Key: apiKeyModel.APIKey, Status: "valid", Message: "Validation successful."}
event.RequestLog.IsSuccess = true
} else {
var apiErr *CustomErrors.APIError
if CustomErrors.As(validationErr, &apiErr) {
currentResult = models.KeyTestResult{Key: apiKeyModel.APIKey, Status: "invalid", Message: fmt.Sprintf("Invalid key (HTTP %d): %s", apiErr.HTTPStatus, apiErr.Message)}
event.Error = apiErr
} else {
currentResult = models.KeyTestResult{Key: apiKeyModel.APIKey, Status: "error", Message: "Validation check failed: " + validationErr.Error()}
event.Error = &CustomErrors.APIError{Message: validationErr.Error()}
}
event.RequestLog.IsSuccess = false
}
eventData, _ := json.Marshal(event)
if err := s.store.Publish(ctx, models.TopicRequestFinished, eventData); err != nil {
s.logger.WithError(err).Errorf("Failed to publish RequestFinishedEvent for validation of key ID %d", apiKeyModel.ID)
}
mu.Lock()
finalResults[j.Index] = currentResult
processedCount++
_ = s.taskService.UpdateProgressByID(ctx, taskID, processedCount)
mu.Unlock()
}
}()
}
for i, k := range keys {
jobs <- job{Index: i, Value: k}
}
close(jobs)
wg.Wait()
s.taskService.EndTaskByID(ctx, taskID, resourceID, gin.H{"results": finalResults}, nil)
}
func (s *KeyValidationService) StartTestKeysByFilterTask(ctx context.Context, groupID uint, statuses []string) (*task.Status, error) {
s.logger.Infof("Starting test task for group %d with status filter: %v", groupID, statuses)
keyValues, err := s.keyRepo.FindKeyValuesByStatus(groupID, statuses)
if err != nil {
return nil, CustomErrors.ParseDBError(err)
}
if len(keyValues) == 0 {
return nil, CustomErrors.NewAPIError(CustomErrors.ErrNotFound, "No keys found matching the filter criteria.")
}
keysAsText := strings.Join(keyValues, "\n")
s.logger.Infof("Found %d keys to validate for group %d.", len(keyValues), groupID)
return s.StartTestKeysTask(ctx, groupID, keysAsText)
}

View File

@@ -1,78 +1,152 @@
// Filename: internal/service/log_service.go
package service package service
import ( import (
"context"
"fmt"
"gemini-balancer/internal/models" "gemini-balancer/internal/models"
"strconv" "strconv"
"github.com/gin-gonic/gin" "github.com/sirupsen/logrus"
"gorm.io/gorm" "gorm.io/gorm"
) )
type LogService struct { type LogService struct {
db *gorm.DB db *gorm.DB
logger *logrus.Entry
} }
func NewLogService(db *gorm.DB) *LogService { func NewLogService(db *gorm.DB, logger *logrus.Logger) *LogService {
return &LogService{db: db} return &LogService{
db: db,
logger: logger.WithField("component", "LogService"),
}
} }
func (s *LogService) Record(log *models.RequestLog) error { func (s *LogService) Record(ctx context.Context, log *models.RequestLog) error {
return s.db.Create(log).Error return s.db.WithContext(ctx).Create(log).Error
} }
func (s *LogService) GetLogs(c *gin.Context) ([]models.RequestLog, int64, error) { // LogQueryParams 解耦 Gin使用结构体传参
type LogQueryParams struct {
Page int
PageSize int
ModelName string
IsSuccess *bool // 使用指针区分"未设置"和"false"
StatusCode *int
KeyID *uint64
GroupID *uint64
}
func (s *LogService) GetLogs(ctx context.Context, params LogQueryParams) ([]models.RequestLog, int64, error) {
// 参数校验
if params.Page < 1 {
params.Page = 1
}
if params.PageSize < 1 || params.PageSize > 100 {
params.PageSize = 20
}
var logs []models.RequestLog var logs []models.RequestLog
var total int64 var total int64
query := s.db.Model(&models.RequestLog{}).Scopes(s.filtersScope(c)) // 构建基础查询
query := s.db.WithContext(ctx).Model(&models.RequestLog{})
query = s.applyFilters(query, params)
// 计算总数 // 计算总数
if err := query.Count(&total).Error; err != nil { if err := query.Count(&total).Error; err != nil {
return nil, 0, err return nil, 0, fmt.Errorf("failed to count logs: %w", err)
} }
if total == 0 { if total == 0 {
return []models.RequestLog{}, 0, nil return []models.RequestLog{}, 0, nil
} }
// 再执行分页查询 // 分页查询
page, _ := strconv.Atoi(c.DefaultQuery("page", "1")) offset := (params.Page - 1) * params.PageSize
pageSize, _ := strconv.Atoi(c.DefaultQuery("page_size", "20")) if err := query.Order("request_time DESC").
offset := (page - 1) * pageSize Limit(params.PageSize).
Offset(offset).
err := query.Order("request_time desc").Limit(pageSize).Offset(offset).Find(&logs).Error Find(&logs).Error; err != nil {
if err != nil { return nil, 0, fmt.Errorf("failed to query logs: %w", err)
return nil, 0, err
} }
return logs, total, nil return logs, total, nil
} }
func (s *LogService) filtersScope(c *gin.Context) func(db *gorm.DB) *gorm.DB { func (s *LogService) applyFilters(query *gorm.DB, params LogQueryParams) *gorm.DB {
return func(db *gorm.DB) *gorm.DB { if params.ModelName != "" {
if modelName := c.Query("model_name"); modelName != "" { query = query.Where("model_name = ?", params.ModelName)
db = db.Where("model_name = ?", modelName)
}
if isSuccessStr := c.Query("is_success"); isSuccessStr != "" {
if isSuccess, err := strconv.ParseBool(isSuccessStr); err == nil {
db = db.Where("is_success = ?", isSuccess)
}
}
if statusCodeStr := c.Query("status_code"); statusCodeStr != "" {
if statusCode, err := strconv.Atoi(statusCodeStr); err == nil {
db = db.Where("status_code = ?", statusCode)
}
}
if keyIDStr := c.Query("key_id"); keyIDStr != "" {
if keyID, err := strconv.ParseUint(keyIDStr, 10, 64); err == nil {
db = db.Where("key_id = ?", keyID)
}
}
if groupIDStr := c.Query("group_id"); groupIDStr != "" {
if groupID, err := strconv.ParseUint(groupIDStr, 10, 64); err == nil {
db = db.Where("group_id = ?", groupID)
}
}
return db
} }
if params.IsSuccess != nil {
query = query.Where("is_success = ?", *params.IsSuccess)
}
if params.StatusCode != nil {
query = query.Where("status_code = ?", *params.StatusCode)
}
if params.KeyID != nil {
query = query.Where("key_id = ?", *params.KeyID)
}
if params.GroupID != nil {
query = query.Where("group_id = ?", *params.GroupID)
}
return query
}
// ParseLogQueryParams 在 Handler 层调用,解析 Gin 参数
func ParseLogQueryParams(queryParams map[string]string) (LogQueryParams, error) {
params := LogQueryParams{
Page: 1,
PageSize: 20,
}
if pageStr, ok := queryParams["page"]; ok {
if page, err := strconv.Atoi(pageStr); err == nil && page > 0 {
params.Page = page
}
}
if pageSizeStr, ok := queryParams["page_size"]; ok {
if pageSize, err := strconv.Atoi(pageSizeStr); err == nil && pageSize > 0 {
params.PageSize = pageSize
}
}
if modelName, ok := queryParams["model_name"]; ok {
params.ModelName = modelName
}
if isSuccessStr, ok := queryParams["is_success"]; ok {
if isSuccess, err := strconv.ParseBool(isSuccessStr); err == nil {
params.IsSuccess = &isSuccess
} else {
return params, fmt.Errorf("invalid is_success parameter: %s", isSuccessStr)
}
}
if statusCodeStr, ok := queryParams["status_code"]; ok {
if statusCode, err := strconv.Atoi(statusCodeStr); err == nil {
params.StatusCode = &statusCode
} else {
return params, fmt.Errorf("invalid status_code parameter: %s", statusCodeStr)
}
}
if keyIDStr, ok := queryParams["key_id"]; ok {
if keyID, err := strconv.ParseUint(keyIDStr, 10, 64); err == nil {
params.KeyID = &keyID
} else {
return params, fmt.Errorf("invalid key_id parameter: %s", keyIDStr)
}
}
if groupIDStr, ok := queryParams["group_id"]; ok {
if groupID, err := strconv.ParseUint(groupIDStr, 10, 64); err == nil {
params.GroupID = &groupID
} else {
return params, fmt.Errorf("invalid group_id parameter: %s", groupIDStr)
}
}
return params, nil
} }

View File

@@ -35,34 +35,55 @@ func NewStatsService(db *gorm.DB, s store.Store, repo repository.KeyRepository,
func (s *StatsService) Start() { func (s *StatsService) Start() {
s.logger.Info("Starting event listener for stats maintenance.") s.logger.Info("Starting event listener for stats maintenance.")
sub, err := s.store.Subscribe(context.Background(), models.TopicKeyStatusChanged) go s.listenForEvents()
if err != nil {
s.logger.Fatalf("Failed to subscribe to topic %s: %v", models.TopicKeyStatusChanged, err)
return
}
go func() {
defer sub.Close()
for {
select {
case msg := <-sub.Channel():
var event models.KeyStatusChangedEvent
if err := json.Unmarshal(msg.Payload, &event); err != nil {
s.logger.Errorf("Failed to unmarshal KeyStatusChangedEvent: %v", err)
continue
}
s.handleKeyStatusChange(&event)
case <-s.stopChan:
s.logger.Info("Stopping stats event listener.")
return
}
}
}()
} }
func (s *StatsService) Stop() { func (s *StatsService) Stop() {
close(s.stopChan) close(s.stopChan)
} }
func (s *StatsService) listenForEvents() {
for {
select {
case <-s.stopChan:
s.logger.Info("Stopping stats event listener.")
return
default:
}
ctx, cancel := context.WithCancel(context.Background())
sub, err := s.store.Subscribe(ctx, models.TopicKeyStatusChanged)
if err != nil {
s.logger.Errorf("Failed to subscribe: %v, retrying in 5s", err)
cancel()
time.Sleep(5 * time.Second)
continue
}
s.logger.Info("Subscribed to key status changes")
s.handleSubscription(sub, cancel)
}
}
func (s *StatsService) handleSubscription(sub store.Subscription, cancel context.CancelFunc) {
defer sub.Close()
defer cancel()
for {
select {
case msg := <-sub.Channel():
var event models.KeyStatusChangedEvent
if err := json.Unmarshal(msg.Payload, &event); err != nil {
s.logger.Errorf("Failed to unmarshal event: %v", err)
continue
}
s.handleKeyStatusChange(&event)
case <-s.stopChan:
return
}
}
}
func (s *StatsService) handleKeyStatusChange(event *models.KeyStatusChangedEvent) { func (s *StatsService) handleKeyStatusChange(event *models.KeyStatusChangedEvent) {
if event.GroupID == 0 { if event.GroupID == 0 {
s.logger.Warnf("Received KeyStatusChangedEvent with no GroupID. Reason: %s, KeyID: %d. Skipping.", event.ChangeReason, event.KeyID) s.logger.Warnf("Received KeyStatusChangedEvent with no GroupID. Reason: %s, KeyID: %d. Skipping.", event.ChangeReason, event.KeyID)
@@ -75,23 +96,47 @@ func (s *StatsService) handleKeyStatusChange(event *models.KeyStatusChangedEvent
switch event.ChangeReason { switch event.ChangeReason {
case "key_unlinked", "key_hard_deleted": case "key_unlinked", "key_hard_deleted":
if event.OldStatus != "" { if event.OldStatus != "" {
s.store.HIncrBy(ctx, statsKey, "total_keys", -1) if _, err := s.store.HIncrBy(ctx, statsKey, "total_keys", -1); err != nil {
s.store.HIncrBy(ctx, statsKey, fmt.Sprintf("%s_keys", event.OldStatus), -1) s.logger.WithError(err).Errorf("Failed to decrement total_keys for group %d", event.GroupID)
s.RecalculateGroupKeyStats(ctx, event.GroupID)
return
}
if _, err := s.store.HIncrBy(ctx, statsKey, fmt.Sprintf("%s_keys", event.OldStatus), -1); err != nil {
s.logger.WithError(err).Errorf("Failed to decrement %s_keys for group %d", event.OldStatus, event.GroupID)
s.RecalculateGroupKeyStats(ctx, event.GroupID)
return
}
} else { } else {
s.logger.Warnf("Received '%s' event for group %d without OldStatus, forcing recalculation.", event.ChangeReason, event.GroupID) s.logger.Warnf("Received '%s' event for group %d without OldStatus, forcing recalculation.", event.ChangeReason, event.GroupID)
s.RecalculateGroupKeyStats(ctx, event.GroupID) s.RecalculateGroupKeyStats(ctx, event.GroupID)
} }
case "key_linked": case "key_linked":
if event.NewStatus != "" { if event.NewStatus != "" {
s.store.HIncrBy(ctx, statsKey, "total_keys", 1) if _, err := s.store.HIncrBy(ctx, statsKey, "total_keys", 1); err != nil {
s.store.HIncrBy(ctx, statsKey, fmt.Sprintf("%s_keys", event.NewStatus), 1) s.logger.WithError(err).Errorf("Failed to increment total_keys for group %d", event.GroupID)
s.RecalculateGroupKeyStats(ctx, event.GroupID)
return
}
if _, err := s.store.HIncrBy(ctx, statsKey, fmt.Sprintf("%s_keys", event.NewStatus), 1); err != nil {
s.logger.WithError(err).Errorf("Failed to increment %s_keys for group %d", event.NewStatus, event.GroupID)
s.RecalculateGroupKeyStats(ctx, event.GroupID)
return
}
} else { } else {
s.logger.Warnf("Received 'key_linked' event for group %d without NewStatus, forcing recalculation.", event.GroupID) s.logger.Warnf("Received 'key_linked' event for group %d without NewStatus, forcing recalculation.", event.GroupID)
s.RecalculateGroupKeyStats(ctx, event.GroupID) s.RecalculateGroupKeyStats(ctx, event.GroupID)
} }
case "manual_update", "error_threshold_reached", "key_recovered", "invalid_api_key": case "manual_update", "error_threshold_reached", "key_recovered", "invalid_api_key":
s.store.HIncrBy(ctx, statsKey, fmt.Sprintf("%s_keys", event.OldStatus), -1) if _, err := s.store.HIncrBy(ctx, statsKey, fmt.Sprintf("%s_keys", event.OldStatus), -1); err != nil {
s.store.HIncrBy(ctx, statsKey, fmt.Sprintf("%s_keys", event.NewStatus), 1) s.logger.WithError(err).Errorf("Failed to decrement %s_keys for group %d", event.OldStatus, event.GroupID)
s.RecalculateGroupKeyStats(ctx, event.GroupID)
return
}
if _, err := s.store.HIncrBy(ctx, statsKey, fmt.Sprintf("%s_keys", event.NewStatus), 1); err != nil {
s.logger.WithError(err).Errorf("Failed to increment %s_keys for group %d", event.NewStatus, event.GroupID)
s.RecalculateGroupKeyStats(ctx, event.GroupID)
return
}
default: default:
s.logger.Warnf("Unhandled event reason '%s' for group %d, forcing recalculation.", event.ChangeReason, event.GroupID) s.logger.Warnf("Unhandled event reason '%s' for group %d, forcing recalculation.", event.ChangeReason, event.GroupID)
s.RecalculateGroupKeyStats(ctx, event.GroupID) s.RecalculateGroupKeyStats(ctx, event.GroupID)
@@ -113,13 +158,16 @@ func (s *StatsService) RecalculateGroupKeyStats(ctx context.Context, groupID uin
} }
statsKey := fmt.Sprintf("stats:group:%d", groupID) statsKey := fmt.Sprintf("stats:group:%d", groupID)
updates := make(map[string]interface{}) updates := map[string]interface{}{
totalKeys := int64(0) "active_keys": int64(0),
"disabled_keys": int64(0),
"error_keys": int64(0),
"total_keys": int64(0),
}
for _, res := range results { for _, res := range results {
updates[fmt.Sprintf("%s_keys", res.Status)] = res.Count updates[fmt.Sprintf("%s_keys", res.Status)] = res.Count
totalKeys += res.Count updates["total_keys"] = updates["total_keys"].(int64) + res.Count
} }
updates["total_keys"] = totalKeys
if err := s.store.Del(ctx, statsKey); err != nil { if err := s.store.Del(ctx, statsKey); err != nil {
s.logger.WithError(err).Warnf("Failed to delete stale stats key for group %d before recalculation.", groupID) s.logger.WithError(err).Warnf("Failed to delete stale stats key for group %d before recalculation.", groupID)
@@ -180,8 +228,18 @@ func (s *StatsService) AggregateHourlyStats(ctx context.Context) error {
}) })
} }
return s.db.WithContext(ctx).Clauses(clause.OnConflict{ if err := s.db.WithContext(ctx).Clauses(clause.OnConflict{
Columns: []clause.Column{{Name: "time"}, {Name: "group_id"}, {Name: "model_name"}}, Columns: []clause.Column{{Name: "time"}, {Name: "group_id"}, {Name: "model_name"}},
DoUpdates: clause.AssignmentColumns([]string{"request_count", "success_count", "prompt_tokens", "completion_tokens"}), DoUpdates: clause.AssignmentColumns([]string{"request_count", "success_count", "prompt_tokens", "completion_tokens"}),
}).Create(&hourlyStats).Error }).Create(&hourlyStats).Error; err != nil {
return err
}
if err := s.db.WithContext(ctx).
Where("request_time >= ? AND request_time < ?", startTime, endTime).
Delete(&models.RequestLog{}).Error; err != nil {
s.logger.WithError(err).Warn("Failed to delete aggregated request logs")
}
return nil
} }

View File

@@ -37,7 +37,7 @@ func NewTokenManager(repo repository.AuthTokenRepository, store store.Store, log
return tokens, nil return tokens, nil
} }
s, err := syncer.NewCacheSyncer(tokenLoader, store, TopicTokenChanged) s, err := syncer.NewCacheSyncer(tokenLoader, store, TopicTokenChanged, logger)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to create token manager syncer: %w", err) return nil, fmt.Errorf("failed to create token manager syncer: %w", err)
} }

View File

@@ -87,7 +87,7 @@ func NewSettingsManager(db *gorm.DB, store store.Store, logger *logrus.Logger) (
return settings, nil return settings, nil
} }
s, err := syncer.NewCacheSyncer(settingsLoader, store, SettingsUpdateChannel) s, err := syncer.NewCacheSyncer(settingsLoader, store, SettingsUpdateChannel, logger,)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to create system settings syncer: %w", err) return nil, fmt.Errorf("failed to create system settings syncer: %w", err)
} }

View File

@@ -4,46 +4,54 @@ import (
"context" "context"
"fmt" "fmt"
"gemini-balancer/internal/store" "gemini-balancer/internal/store"
"log"
"sync" "sync"
"time" "time"
"github.com/sirupsen/logrus"
)
const (
ReconnectDelay = 5 * time.Second
ReloadTimeout = 30 * time.Second
) )
// LoaderFunc
type LoaderFunc[T any] func() (T, error) type LoaderFunc[T any] func() (T, error)
// CacheSyncer
type CacheSyncer[T any] struct { type CacheSyncer[T any] struct {
mu sync.RWMutex mu sync.RWMutex
cache T cache T
loader LoaderFunc[T] loader LoaderFunc[T]
store store.Store store store.Store
channelName string channelName string
logger *logrus.Entry
stopChan chan struct{} stopChan chan struct{}
wg sync.WaitGroup wg sync.WaitGroup
} }
// NewCacheSyncer
func NewCacheSyncer[T any]( func NewCacheSyncer[T any](
loader LoaderFunc[T], loader LoaderFunc[T],
store store.Store, store store.Store,
channelName string, channelName string,
logger *logrus.Logger,
) (*CacheSyncer[T], error) { ) (*CacheSyncer[T], error) {
s := &CacheSyncer[T]{ s := &CacheSyncer[T]{
loader: loader, loader: loader,
store: store, store: store,
channelName: channelName, channelName: channelName,
logger: logger.WithField("component", fmt.Sprintf("CacheSyncer[%s]", channelName)),
stopChan: make(chan struct{}), stopChan: make(chan struct{}),
} }
if err := s.reload(); err != nil { if err := s.reload(); err != nil {
return nil, fmt.Errorf("initial load for %s failed: %w", channelName, err) return nil, fmt.Errorf("initial load failed: %w", err)
} }
s.wg.Add(1) s.wg.Add(1)
go s.listenForUpdates() go s.listenForUpdates()
return s, nil return s, nil
} }
// Get, Invalidate, Stop, reload 方法 .
func (s *CacheSyncer[T]) Get() T { func (s *CacheSyncer[T]) Get() T {
s.mu.RLock() s.mu.RLock()
defer s.mu.RUnlock() defer s.mu.RUnlock()
@@ -51,33 +59,60 @@ func (s *CacheSyncer[T]) Get() T {
} }
func (s *CacheSyncer[T]) Invalidate() error { func (s *CacheSyncer[T]) Invalidate() error {
log.Printf("INFO: Publishing invalidation notification on channel '%s'", s.channelName) s.logger.Info("Publishing invalidation notification")
return s.store.Publish(context.Background(), s.channelName, []byte("reload")) ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
if err := s.store.Publish(ctx, s.channelName, []byte("reload")); err != nil {
s.logger.WithError(err).Error("Failed to publish invalidation")
return err
}
return nil
} }
func (s *CacheSyncer[T]) Stop() { func (s *CacheSyncer[T]) Stop() {
close(s.stopChan) close(s.stopChan)
s.wg.Wait() s.wg.Wait()
log.Printf("INFO: CacheSyncer for channel '%s' stopped.", s.channelName) s.logger.Info("CacheSyncer stopped")
} }
func (s *CacheSyncer[T]) reload() error { func (s *CacheSyncer[T]) reload() error {
log.Printf("INFO: Reloading cache for channel '%s'...", s.channelName) s.logger.Info("Reloading cache...")
newData, err := s.loader()
if err != nil { ctx, cancel := context.WithTimeout(context.Background(), ReloadTimeout)
log.Printf("ERROR: Failed to reload cache for '%s': %v", s.channelName, err) defer cancel()
return err
type result struct {
data T
err error
}
resultChan := make(chan result, 1)
go func() {
data, err := s.loader()
resultChan <- result{data, err}
}()
select {
case res := <-resultChan:
if res.err != nil {
s.logger.WithError(res.err).Error("Failed to reload cache")
return res.err
}
s.mu.Lock()
s.cache = res.data
s.mu.Unlock()
s.logger.Info("Cache reloaded successfully")
return nil
case <-ctx.Done():
s.logger.Error("Cache reload timeout")
return fmt.Errorf("reload timeout after %v", ReloadTimeout)
} }
s.mu.Lock()
s.cache = newData
s.mu.Unlock()
log.Printf("INFO: Cache for channel '%s' reloaded successfully.", s.channelName)
return nil
} }
// listenForUpdates ...
func (s *CacheSyncer[T]) listenForUpdates() { func (s *CacheSyncer[T]) listenForUpdates() {
defer s.wg.Done() defer s.wg.Done()
for { for {
select { select {
case <-s.stopChan: case <-s.stopChan:
@@ -85,31 +120,39 @@ func (s *CacheSyncer[T]) listenForUpdates() {
default: default:
} }
subscription, err := s.store.Subscribe(context.Background(), s.channelName) if err := s.subscribeAndListen(); err != nil {
if err != nil { s.logger.WithError(err).Warnf("Subscription error, retrying in %v", ReconnectDelay)
log.Printf("ERROR: Failed to subscribe to '%s', retrying in 5s: %v", s.channelName, err)
time.Sleep(5 * time.Second)
continue
}
log.Printf("INFO: Subscribed to channel '%s' for cache invalidation.", s.channelName)
subscriberLoop:
for {
select { select {
case _, ok := <-subscription.Channel(): case <-time.After(ReconnectDelay):
if !ok {
log.Printf("WARN: Subscription channel '%s' closed, will re-subscribe.", s.channelName)
break subscriberLoop
}
log.Printf("INFO: Received invalidation notification on '%s', reloading cache.", s.channelName)
if err := s.reload(); err != nil {
log.Printf("ERROR: Failed to reload cache for '%s' after notification: %v", s.channelName, err)
}
case <-s.stopChan: case <-s.stopChan:
subscription.Close()
return return
} }
} }
subscription.Close() }
}
func (s *CacheSyncer[T]) subscribeAndListen() error {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
subscription, err := s.store.Subscribe(ctx, s.channelName)
if err != nil {
return fmt.Errorf("failed to subscribe: %w", err)
}
defer subscription.Close()
s.logger.Info("Subscribed to channel")
for {
select {
case msg, ok := <-subscription.Channel():
if !ok {
return fmt.Errorf("subscription channel closed")
}
s.logger.WithField("message", string(msg.Payload)).Info("Received invalidation notification")
if err := s.reload(); err != nil {
s.logger.WithError(err).Error("Failed to reload after notification")
}
case <-s.stopChan:
return nil
}
} }
} }

View File

@@ -1,4 +1,3 @@
// Filename: internal/task/task.go
package task package task
import ( import (
@@ -13,7 +12,9 @@ import (
) )
const ( const (
ResultTTL = 60 * time.Minute ResultTTL = 60 * time.Minute
DefaultTimeout = 24 * time.Hour
LockTTL = 30 * time.Minute
) )
type Reporter interface { type Reporter interface {
@@ -65,14 +66,21 @@ func (s *Task) getIsRunningFlagKey(taskID string) string {
func (s *Task) StartTask(ctx context.Context, keyGroupID uint, taskType, resourceID string, total int, timeout time.Duration) (*Status, error) { func (s *Task) StartTask(ctx context.Context, keyGroupID uint, taskType, resourceID string, total int, timeout time.Duration) (*Status, error) {
lockKey := s.getResourceLockKey(resourceID) lockKey := s.getResourceLockKey(resourceID)
taskID := fmt.Sprintf("%d-%d", time.Now().UnixNano(), keyGroupID)
if existingTaskID, err := s.store.Get(ctx, lockKey); err == nil && len(existingTaskID) > 0 { locked, err := s.store.SetNX(ctx, lockKey, []byte(taskID), LockTTL)
if err != nil {
return nil, fmt.Errorf("failed to acquire task lock: %w", err)
}
if !locked {
existingTaskID, _ := s.store.Get(ctx, lockKey)
return nil, fmt.Errorf("a task is already running for this resource (ID: %s)", string(existingTaskID)) return nil, fmt.Errorf("a task is already running for this resource (ID: %s)", string(existingTaskID))
} }
taskID := fmt.Sprintf("%d-%d", time.Now().UnixNano(), keyGroupID) if timeout == 0 {
taskKey := s.getTaskDataKey(taskID) timeout = DefaultTimeout
runningFlagKey := s.getIsRunningFlagKey(taskID) }
status := &Status{ status := &Status{
ID: taskID, ID: taskID,
TaskType: taskType, TaskType: taskType,
@@ -81,63 +89,55 @@ func (s *Task) StartTask(ctx context.Context, keyGroupID uint, taskType, resourc
Total: total, Total: total,
StartedAt: time.Now(), StartedAt: time.Now(),
} }
statusBytes, err := json.Marshal(status)
if err != nil {
return nil, fmt.Errorf("failed to serialize new task status: %w", err)
}
if timeout == 0 { if err := s.saveStatus(ctx, taskID, status, timeout); err != nil {
timeout = ResultTTL * 24
}
if err := s.store.Set(ctx, lockKey, []byte(taskID), timeout); err != nil {
return nil, fmt.Errorf("failed to acquire task resource lock: %w", err)
}
if err := s.store.Set(ctx, taskKey, statusBytes, timeout); err != nil {
_ = s.store.Del(ctx, lockKey) _ = s.store.Del(ctx, lockKey)
return nil, fmt.Errorf("failed to set new task data in store: %w", err) return nil, fmt.Errorf("failed to save task status: %w", err)
} }
runningFlagKey := s.getIsRunningFlagKey(taskID)
if err := s.store.Set(ctx, runningFlagKey, []byte("1"), timeout); err != nil { if err := s.store.Set(ctx, runningFlagKey, []byte("1"), timeout); err != nil {
_ = s.store.Del(ctx, lockKey) _ = s.store.Del(ctx, lockKey)
_ = s.store.Del(ctx, taskKey) _ = s.store.Del(ctx, s.getTaskDataKey(taskID))
return nil, fmt.Errorf("failed to set task running flag: %w", err) return nil, fmt.Errorf("failed to set running flag: %w", err)
} }
return status, nil return status, nil
} }
func (s *Task) EndTaskByID(ctx context.Context, taskID, resourceID string, resultData any, taskErr error) { func (s *Task) EndTaskByID(ctx context.Context, taskID, resourceID string, resultData any, taskErr error) {
lockKey := s.getResourceLockKey(resourceID) lockKey := s.getResourceLockKey(resourceID)
defer func() {
if err := s.store.Del(ctx, lockKey); err != nil {
s.logger.WithError(err).Warnf("Failed to release resource lock '%s' for task %s.", lockKey, taskID)
}
}()
runningFlagKey := s.getIsRunningFlagKey(taskID) runningFlagKey := s.getIsRunningFlagKey(taskID)
_ = s.store.Del(ctx, runningFlagKey)
defer func() {
_ = s.store.Del(ctx, lockKey)
_ = s.store.Del(ctx, runningFlagKey)
}()
status, err := s.GetStatus(ctx, taskID) status, err := s.GetStatus(ctx, taskID)
if err != nil { if err != nil {
s.logger.WithError(err).Errorf("Could not get task status for task ID %s during EndTask. Lock has been released, but task data may be stale.", taskID) s.logger.WithError(err).Errorf("Failed to get task status for %s during EndTask", taskID)
return return
} }
if !status.IsRunning { if !status.IsRunning {
s.logger.Warnf("EndTaskByID called for an already finished task: %s", taskID) s.logger.Warnf("EndTaskByID called for already finished task: %s", taskID)
return return
} }
now := time.Now() now := time.Now()
status.IsRunning = false status.IsRunning = false
status.FinishedAt = &now status.FinishedAt = &now
status.DurationSeconds = now.Sub(status.StartedAt).Seconds() status.DurationSeconds = now.Sub(status.StartedAt).Seconds()
if taskErr != nil { if taskErr != nil {
status.Error = taskErr.Error() status.Error = taskErr.Error()
} else { } else {
status.Result = resultData status.Result = resultData
} }
updatedTaskBytes, _ := json.Marshal(status)
taskKey := s.getTaskDataKey(taskID) if err := s.saveStatus(ctx, taskID, status, ResultTTL); err != nil {
if err := s.store.Set(ctx, taskKey, updatedTaskBytes, ResultTTL); err != nil { s.logger.WithError(err).Errorf("Failed to save final status for task %s", taskID)
s.logger.WithError(err).Errorf("Failed to save final status for task %s.", taskID)
} }
} }
@@ -148,43 +148,42 @@ func (s *Task) GetStatus(ctx context.Context, taskID string) (*Status, error) {
if errors.Is(err, store.ErrNotFound) { if errors.Is(err, store.ErrNotFound) {
return nil, errors.New("task not found") return nil, errors.New("task not found")
} }
return nil, fmt.Errorf("failed to get task status from store: %w", err) return nil, fmt.Errorf("failed to get task status: %w", err)
} }
var status Status var status Status
if err := json.Unmarshal(statusBytes, &status); err != nil { if err := json.Unmarshal(statusBytes, &status); err != nil {
return nil, fmt.Errorf("corrupted task data in store for ID %s", taskID) return nil, fmt.Errorf("corrupted task data for ID %s: %w", taskID, err)
} }
if !status.IsRunning && status.FinishedAt != nil { if !status.IsRunning && status.FinishedAt != nil {
status.DurationSeconds = status.FinishedAt.Sub(status.StartedAt).Seconds() status.DurationSeconds = status.FinishedAt.Sub(status.StartedAt).Seconds()
} }
return &status, nil return &status, nil
} }
func (s *Task) updateTask(ctx context.Context, taskID string, updater func(status *Status)) error { func (s *Task) updateTask(ctx context.Context, taskID string, updater func(status *Status)) error {
runningFlagKey := s.getIsRunningFlagKey(taskID) runningFlagKey := s.getIsRunningFlagKey(taskID)
if _, err := s.store.Get(ctx, runningFlagKey); err != nil { if _, err := s.store.Get(ctx, runningFlagKey); err != nil {
return nil if errors.Is(err, store.ErrNotFound) {
return errors.New("task is not running")
}
return fmt.Errorf("failed to check running flag: %w", err)
} }
status, err := s.GetStatus(ctx, taskID) status, err := s.GetStatus(ctx, taskID)
if err != nil { if err != nil {
s.logger.WithError(err).Warnf("Failed to get task status for update on task %s. Update will not be saved.", taskID) return fmt.Errorf("failed to get task status: %w", err)
return nil
} }
if !status.IsRunning { if !status.IsRunning {
return nil return errors.New("task is not running")
} }
updater(status) updater(status)
statusBytes, marshalErr := json.Marshal(status)
if marshalErr != nil { return s.saveStatus(ctx, taskID, status, DefaultTimeout)
s.logger.WithError(marshalErr).Errorf("Failed to serialize status for update on task %s.", taskID)
return nil
}
taskKey := s.getTaskDataKey(taskID)
if err := s.store.Set(ctx, taskKey, statusBytes, ResultTTL*24); err != nil {
s.logger.WithError(err).Warnf("Failed to save update for task %s.", taskID)
}
return nil
} }
func (s *Task) UpdateProgressByID(ctx context.Context, taskID string, processed int) error { func (s *Task) UpdateProgressByID(ctx context.Context, taskID string, processed int) error {
@@ -198,3 +197,17 @@ func (s *Task) UpdateTotalByID(ctx context.Context, taskID string, total int) er
status.Total = total status.Total = total
}) })
} }
func (s *Task) saveStatus(ctx context.Context, taskID string, status *Status, ttl time.Duration) error {
statusBytes, err := json.Marshal(status)
if err != nil {
return fmt.Errorf("failed to serialize status: %w", err)
}
taskKey := s.getTaskDataKey(taskID)
if err := s.store.Set(ctx, taskKey, statusBytes, ttl); err != nil {
return fmt.Errorf("failed to save status: %w", err)
}
return nil
}

View File

@@ -1,43 +1,118 @@
// Filename: internal/webhandlers/auth_handler.go (最终现代化改造版) // Filename: internal/webhandlers/auth_handler.go
package webhandlers package webhandlers
import ( import (
"gemini-balancer/internal/middleware" "gemini-balancer/internal/middleware"
"gemini-balancer/internal/service" // [核心改造] 依赖service层 "gemini-balancer/internal/service"
"net/http" "net/http"
"strings"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/sirupsen/logrus"
) )
// WebAuthHandler [核心改造] 依赖关系净化注入SecurityService // WebAuthHandler Web 认证处理器
type WebAuthHandler struct { type WebAuthHandler struct {
securityService *service.SecurityService securityService *service.SecurityService
logger *logrus.Logger
} }
// NewWebAuthHandler [核心改造] 构造函数更新 // NewWebAuthHandler 创建 WebAuthHandler
func NewWebAuthHandler(securityService *service.SecurityService) *WebAuthHandler { func NewWebAuthHandler(securityService *service.SecurityService) *WebAuthHandler {
logger := logrus.New()
logger.SetLevel(logrus.InfoLevel)
return &WebAuthHandler{ return &WebAuthHandler{
securityService: securityService, securityService: securityService,
logger: logger,
} }
} }
// ShowLoginPage 保持不变 // ShowLoginPage 显示登录页面
func (h *WebAuthHandler) ShowLoginPage(c *gin.Context) { func (h *WebAuthHandler) ShowLoginPage(c *gin.Context) {
errMsg := c.Query("error") errMsg := c.Query("error")
from := c.Query("from") // 可以从登录失败的页面返回
// 验证重定向路径(防止开放重定向攻击)
redirectPath := h.validateRedirectPath(c.Query("redirect"))
// 如果已登录,直接重定向
if cookie := middleware.ExtractTokenFromCookie(c); cookie != "" {
if _, err := h.securityService.AuthenticateToken(cookie); err == nil {
c.Redirect(http.StatusFound, redirectPath)
return
}
}
c.HTML(http.StatusOK, "auth.html", gin.H{ c.HTML(http.StatusOK, "auth.html", gin.H{
"error": errMsg, "error": errMsg,
"from": from, "redirect": redirectPath,
}) })
} }
// HandleLogin [核心改造] 认证逻辑完全委托给SecurityService // HandleLogin 已废弃(项目无用户名系统)
func (h *WebAuthHandler) HandleLogin(c *gin.Context) { func (h *WebAuthHandler) HandleLogin(c *gin.Context) {
c.Redirect(http.StatusFound, "/login?error=DEPRECATED_LOGIN_METHOD") c.Redirect(http.StatusFound, "/login?error=DEPRECATED_LOGIN_METHOD")
} }
// HandleLogout 保持不变 // HandleLogout 处理登出请求
func (h *WebAuthHandler) HandleLogout(c *gin.Context) { func (h *WebAuthHandler) HandleLogout(c *gin.Context) {
cookie := middleware.ExtractTokenFromCookie(c)
if cookie != "" {
// 尝试获取 Token 信息用于日志
authToken, err := h.securityService.AuthenticateToken(cookie)
if err == nil {
h.logger.WithFields(logrus.Fields{
"token_id": authToken.ID,
"client_ip": c.ClientIP(),
}).Info("User logged out")
} else {
h.logger.WithField("client_ip", c.ClientIP()).Warn("Logout with invalid token")
}
// 使缓存失效
middleware.InvalidateTokenCache(cookie)
} else {
h.logger.WithField("client_ip", c.ClientIP()).Debug("Logout without session cookie")
}
// 清除 Cookie
middleware.ClearAdminSessionCookie(c) middleware.ClearAdminSessionCookie(c)
// 重定向到登录页
c.Redirect(http.StatusFound, "/login") c.Redirect(http.StatusFound, "/login")
} }
// validateRedirectPath 验证重定向路径(防止开放重定向攻击)
func (h *WebAuthHandler) validateRedirectPath(path string) string {
defaultPath := "/dashboard"
if path == "" {
return defaultPath
}
// 只允许内部路径
if !strings.HasPrefix(path, "/") || strings.HasPrefix(path, "//") {
h.logger.WithField("path", path).Warn("Invalid redirect path blocked")
return defaultPath
}
// 白名单验证
allowedPaths := []string{
"/dashboard",
"/keys",
"/settings",
"/logs",
"/tasks",
"/chat",
}
for _, allowed := range allowedPaths {
if strings.HasPrefix(path, allowed) {
return path
}
}
return defaultPath
}

View File

@@ -123,7 +123,7 @@
{% block core_scripts %} {% block core_scripts %}
<script src="https://cdnjs.cloudflare.com/ajax/libs/animejs/3.2.1/anime.min.js"></script> <script src="https://cdnjs.cloudflare.com/ajax/libs/animejs/3.2.1/anime.min.js"></script>
<script src="https://cdn.jsdelivr.net/npm/sweetalert2@11"></script> <script src="https://cdnjs.cloudflare.com/ajax/libs/sweetalert2/11.23.0/sweetalert2.all.min.js"></script>
<script src="/static/js/main.js" type="module" defer></script> <script src="/static/js/main.js" type="module" defer></script>
{% endblock core_scripts %} {% endblock core_scripts %}
<!-- [核心] Block 2: 留给子页面的脚本扩展插槽 --> <!-- [核心] Block 2: 留给子页面的脚本扩展插槽 -->

View File

@@ -492,7 +492,7 @@
type="text" type="text"
id="TEST_MODEL" id="TEST_MODEL"
name="TEST_MODEL" name="TEST_MODEL"
placeholder="gemini-1.5-flash" placeholder="gemini-2.0-flash-lite"
class="flex-grow px-4 py-3 rounded-lg form-input-themed" class="flex-grow px-4 py-3 rounded-lg form-input-themed"
/> />
<button <button