Fix Services & Update the middleware && others
This commit is contained in:
@@ -3,10 +3,12 @@ package main
|
||||
import (
|
||||
"gemini-balancer/internal/app"
|
||||
"gemini-balancer/internal/container"
|
||||
"gemini-balancer/internal/logging"
|
||||
"log"
|
||||
)
|
||||
|
||||
func main() {
|
||||
defer logging.Close()
|
||||
cont, err := container.BuildContainer()
|
||||
if err != nil {
|
||||
log.Fatalf("FATAL: Failed to build dependency container: %v", err)
|
||||
|
||||
@@ -14,6 +14,12 @@ server:
|
||||
log:
|
||||
level: "debug"
|
||||
|
||||
# 日志轮转配置
|
||||
max_size: 100 # MB
|
||||
max_backups: 7 # 保留文件数
|
||||
max_age: 30 # 保留天数
|
||||
compress: true # 压缩旧日志
|
||||
|
||||
redis:
|
||||
dsn: "redis://localhost:6379/0"
|
||||
|
||||
|
||||
2
go.mod
2
go.mod
@@ -17,6 +17,8 @@ require (
|
||||
github.com/spf13/viper v1.20.1
|
||||
go.uber.org/dig v1.19.0
|
||||
golang.org/x/net v0.42.0
|
||||
golang.org/x/time v0.14.0
|
||||
gopkg.in/natefinch/lumberjack.v2 v2.2.1
|
||||
gorm.io/datatypes v1.0.5
|
||||
gorm.io/driver/mysql v1.6.0
|
||||
gorm.io/driver/postgres v1.6.0
|
||||
|
||||
4
go.sum
4
go.sum
@@ -311,6 +311,8 @@ golang.org/x/text v0.16.0/go.mod h1:GhwF1Be+LQoKShO3cGOHzqOgRrGaYc9AvblQOmPVHnI=
|
||||
golang.org/x/text v0.20.0/go.mod h1:D4IsuqiFMhST5bX19pQ9ikHC2GsaKyk/oF+pn3ducp4=
|
||||
golang.org/x/text v0.28.0 h1:rhazDwis8INMIwQ4tpjLDzUhx6RlXqZNPEM0huQojng=
|
||||
golang.org/x/text v0.28.0/go.mod h1:U8nCwOR8jO/marOQ0QbDiOngZVEBB7MAiitBuMjXiNU=
|
||||
golang.org/x/time v0.14.0 h1:MRx4UaLrDotUKUdCIqzPC48t1Y9hANFKIRpNx+Te8PI=
|
||||
golang.org/x/time v0.14.0/go.mod h1:eL/Oa2bBBK0TkX57Fyni+NgnyQQN4LitPmob2Hjnqw4=
|
||||
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
|
||||
golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
|
||||
golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc=
|
||||
@@ -325,6 +327,8 @@ gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8
|
||||
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk=
|
||||
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q=
|
||||
gopkg.in/errgo.v2 v2.1.0/go.mod h1:hNsd1EY+bozCKY1Ytp96fpM3vjJbqLJn88ws8XvfDNI=
|
||||
gopkg.in/natefinch/lumberjack.v2 v2.2.1 h1:bBRl1b0OH9s/DuPhuXpNl+VtCaJXFZ5/uEFST95x9zc=
|
||||
gopkg.in/natefinch/lumberjack.v2 v2.2.1/go.mod h1:YD8tP3GAjkrDg1eZH7EGmyESg/lsYskCTPBJVb9jqSc=
|
||||
gopkg.in/yaml.v2 v2.2.1/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
|
||||
gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
|
||||
gopkg.in/yaml.v2 v2.2.8/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
|
||||
|
||||
@@ -29,7 +29,9 @@ type DatabaseConfig struct {
|
||||
|
||||
// ServerConfig 存储HTTP服务器配置
|
||||
type ServerConfig struct {
|
||||
Port string `mapstructure:"port"`
|
||||
Port string `mapstructure:"port"`
|
||||
Host string `yaml:"host"`
|
||||
CORSOrigins []string `yaml:"cors_origins"`
|
||||
}
|
||||
|
||||
// LogConfig 存储日志配置
|
||||
@@ -38,6 +40,12 @@ type LogConfig struct {
|
||||
Format string `mapstructure:"format" json:"format"`
|
||||
EnableFile bool `mapstructure:"enable_file" json:"enable_file"`
|
||||
FilePath string `mapstructure:"file_path" json:"file_path"`
|
||||
|
||||
// 日志轮转配置(可选)
|
||||
MaxSize int `yaml:"max_size"` // MB,默认 100
|
||||
MaxBackups int `yaml:"max_backups"` // 默认 7
|
||||
MaxAge int `yaml:"max_age"` // 天,默认 30
|
||||
Compress bool `yaml:"compress"` // 默认 true
|
||||
}
|
||||
|
||||
type RedisConfig struct {
|
||||
|
||||
@@ -87,7 +87,7 @@ func BuildContainer() (*dig.Container, error) {
|
||||
// 为GroupManager配置Syncer
|
||||
container.Provide(func(loader syncer.LoaderFunc[service.GroupManagerCacheData], store store.Store, logger *logrus.Logger) (*syncer.CacheSyncer[service.GroupManagerCacheData], error) {
|
||||
const groupUpdateChannel = "groups:cache_invalidation"
|
||||
return syncer.NewCacheSyncer(loader, store, groupUpdateChannel)
|
||||
return syncer.NewCacheSyncer(loader, store, groupUpdateChannel, logger)
|
||||
})
|
||||
|
||||
// =========== 阶段三: 适配器与处理器层 (Handlers & Adapters) ===========
|
||||
|
||||
@@ -20,7 +20,7 @@ type Module struct {
|
||||
|
||||
func NewModule(gormDB *gorm.DB, store store.Store, settingsManager *settings.SettingsManager, taskReporter task.Reporter, logger *logrus.Logger) (*Module, error) {
|
||||
loader := newManagerLoader(gormDB)
|
||||
cacheSyncer, err := syncer.NewCacheSyncer(loader, store, "proxies:cache_invalidation")
|
||||
cacheSyncer, err := syncer.NewCacheSyncer(loader, store, "proxies:cache_invalidation", logger)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -71,6 +71,7 @@ var clientNetworkErrorSubstrings = []string{
|
||||
"broken pipe",
|
||||
"use of closed network connection",
|
||||
"request canceled",
|
||||
"invalid query parameters", // 参数解析错误,归类为客户端错误
|
||||
}
|
||||
|
||||
// IsPermanentUpstreamError checks if an upstream error indicates the key is permanently invalid.
|
||||
|
||||
@@ -5,7 +5,6 @@ import (
|
||||
"gemini-balancer/internal/errors"
|
||||
"gemini-balancer/internal/response"
|
||||
"gemini-balancer/internal/service"
|
||||
"strconv"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
@@ -19,22 +18,26 @@ func NewLogHandler(logService *service.LogService) *LogHandler {
|
||||
}
|
||||
|
||||
func (h *LogHandler) GetLogs(c *gin.Context) {
|
||||
// 调用新的服务函数,接收日志列表和总数
|
||||
logs, total, err := h.logService.GetLogs(c)
|
||||
queryParams := make(map[string]string)
|
||||
for key, values := range c.Request.URL.Query() {
|
||||
if len(values) > 0 {
|
||||
queryParams[key] = values[0]
|
||||
}
|
||||
}
|
||||
params, err := service.ParseLogQueryParams(queryParams)
|
||||
if err != nil {
|
||||
response.Error(c, errors.ErrBadRequest)
|
||||
return
|
||||
}
|
||||
logs, total, err := h.logService.GetLogs(c.Request.Context(), params)
|
||||
if err != nil {
|
||||
response.Error(c, errors.ErrDatabase)
|
||||
return
|
||||
}
|
||||
|
||||
// 解析分页参数用于响应体
|
||||
page, _ := strconv.Atoi(c.DefaultQuery("page", "1"))
|
||||
pageSize, _ := strconv.Atoi(c.DefaultQuery("page_size", "20"))
|
||||
|
||||
// 使用标准的分页响应结构
|
||||
response.Success(c, gin.H{
|
||||
"items": logs,
|
||||
"total": total,
|
||||
"page": page,
|
||||
"page_size": pageSize,
|
||||
"page": params.Page,
|
||||
"page_size": params.PageSize,
|
||||
})
|
||||
}
|
||||
|
||||
@@ -9,20 +9,25 @@ import (
|
||||
"path/filepath"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
"gopkg.in/natefinch/lumberjack.v2"
|
||||
)
|
||||
|
||||
// 包级变量,用于存储日志轮转器
|
||||
var logRotator *lumberjack.Logger
|
||||
|
||||
// NewLogger 返回标准的 *logrus.Logger(兼容 Fx 依赖注入)
|
||||
func NewLogger(cfg *config.Config) *logrus.Logger {
|
||||
logger := logrus.New()
|
||||
|
||||
// 1. 设置日志级别
|
||||
// 设置日志级别
|
||||
level, err := logrus.ParseLevel(cfg.Log.Level)
|
||||
if err != nil {
|
||||
logger.WithField("configured_level", cfg.Log.Level).Warn("Invalid log level specified, defaulting to 'info'.")
|
||||
logger.WithField("configured_level", cfg.Log.Level).Warn("Invalid log level, defaulting to 'info'")
|
||||
level = logrus.InfoLevel
|
||||
}
|
||||
logger.SetLevel(level)
|
||||
|
||||
// 2. 设置日志格式
|
||||
// 设置日志格式
|
||||
if cfg.Log.Format == "json" {
|
||||
logger.SetFormatter(&logrus.JSONFormatter{
|
||||
TimestampFormat: "2006-01-02T15:04:05.000Z07:00",
|
||||
@@ -39,36 +44,57 @@ func NewLogger(cfg *config.Config) *logrus.Logger {
|
||||
})
|
||||
}
|
||||
|
||||
// 3. 设置日志输出
|
||||
// 添加全局字段
|
||||
hostname, _ := os.Hostname()
|
||||
logger = logger.WithFields(logrus.Fields{
|
||||
"service": "gemini-balancer",
|
||||
"hostname": hostname,
|
||||
}).Logger
|
||||
|
||||
// 设置日志输出
|
||||
if cfg.Log.EnableFile {
|
||||
if cfg.Log.FilePath == "" {
|
||||
logger.Warn("Log file is enabled but no file path is specified. Logging to console only.")
|
||||
logger.Warn("Log file enabled but no path specified. Logging to console only")
|
||||
logger.SetOutput(os.Stdout)
|
||||
return logger
|
||||
}
|
||||
|
||||
logDir := filepath.Dir(cfg.Log.FilePath)
|
||||
if err := os.MkdirAll(logDir, 0755); err != nil {
|
||||
logger.WithError(err).Warn("Failed to create log directory. Logging to console only.")
|
||||
if err := os.MkdirAll(logDir, 0750); err != nil {
|
||||
logger.WithError(err).Warn("Failed to create log directory. Logging to console only")
|
||||
logger.SetOutput(os.Stdout)
|
||||
return logger
|
||||
}
|
||||
|
||||
logFile, err := os.OpenFile(cfg.Log.FilePath, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0666)
|
||||
if err != nil {
|
||||
logger.WithError(err).Warn("Failed to open log file. Logging to console only.")
|
||||
logger.SetOutput(os.Stdout)
|
||||
return logger
|
||||
// 配置日志轮转(保存到包级变量)
|
||||
logRotator = &lumberjack.Logger{
|
||||
Filename: cfg.Log.FilePath,
|
||||
MaxSize: getOrDefault(cfg.Log.MaxSize, 100),
|
||||
MaxBackups: getOrDefault(cfg.Log.MaxBackups, 7),
|
||||
MaxAge: getOrDefault(cfg.Log.MaxAge, 30),
|
||||
Compress: cfg.Log.Compress,
|
||||
}
|
||||
|
||||
// 同时输出到控制台和文件
|
||||
logger.SetOutput(io.MultiWriter(os.Stdout, logFile))
|
||||
logger.WithField("log_file_path", cfg.Log.FilePath).Info("Logging is now configured to output to both console and file.")
|
||||
logger.SetOutput(io.MultiWriter(os.Stdout, logRotator))
|
||||
logger.WithField("log_file", cfg.Log.FilePath).Info("Logging to both console and file")
|
||||
} else {
|
||||
// 仅输出到控制台
|
||||
logger.SetOutput(os.Stdout)
|
||||
}
|
||||
|
||||
logger.Info("Root logger initialized.")
|
||||
logger.Info("Logger initialized successfully")
|
||||
return logger
|
||||
}
|
||||
|
||||
// Close 关闭日志轮转器(在 main.go 中调用)
|
||||
func Close() {
|
||||
if logRotator != nil {
|
||||
logRotator.Close()
|
||||
}
|
||||
}
|
||||
|
||||
func getOrDefault(value, defaultValue int) int {
|
||||
if value <= 0 {
|
||||
return defaultValue
|
||||
}
|
||||
return value
|
||||
}
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
// Filename: internal/middleware/auth.go
|
||||
|
||||
package middleware
|
||||
|
||||
import (
|
||||
@@ -7,76 +8,115 @@ import (
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// === API Admin 认证管道 (/admin/* API路由) ===
|
||||
type ErrorResponse struct {
|
||||
Error string `json:"error"`
|
||||
Code string `json:"code,omitempty"`
|
||||
}
|
||||
|
||||
func APIAdminAuthMiddleware(securityService *service.SecurityService) gin.HandlerFunc {
|
||||
// APIAdminAuthMiddleware 管理后台 API 认证
|
||||
func APIAdminAuthMiddleware(
|
||||
securityService *service.SecurityService,
|
||||
logger *logrus.Logger,
|
||||
) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
tokenValue := extractBearerToken(c)
|
||||
if tokenValue == "" {
|
||||
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "Authorization token is missing"})
|
||||
c.AbortWithStatusJSON(http.StatusUnauthorized, ErrorResponse{
|
||||
Error: "Authentication required",
|
||||
Code: "AUTH_MISSING",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// ✅ 只传 token 参数(移除 context)
|
||||
authToken, err := securityService.AuthenticateToken(tokenValue)
|
||||
if err != nil || !authToken.IsAdmin {
|
||||
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "Invalid or non-admin token"})
|
||||
if err != nil {
|
||||
logger.WithError(err).Warn("Authentication failed")
|
||||
c.AbortWithStatusJSON(http.StatusUnauthorized, ErrorResponse{
|
||||
Error: "Invalid authentication",
|
||||
Code: "AUTH_INVALID",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
if !authToken.IsAdmin {
|
||||
c.AbortWithStatusJSON(http.StatusForbidden, ErrorResponse{
|
||||
Error: "Admin access required",
|
||||
Code: "AUTH_FORBIDDEN",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
c.Set("adminUser", authToken)
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
// === /v1 Proxy 认证 ===
|
||||
|
||||
func ProxyAuthMiddleware(securityService *service.SecurityService) gin.HandlerFunc {
|
||||
// ProxyAuthMiddleware 代理请求认证
|
||||
func ProxyAuthMiddleware(
|
||||
securityService *service.SecurityService,
|
||||
logger *logrus.Logger,
|
||||
) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
tokenValue := extractProxyToken(c)
|
||||
if tokenValue == "" {
|
||||
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "API key is missing from request"})
|
||||
c.AbortWithStatusJSON(http.StatusUnauthorized, ErrorResponse{
|
||||
Error: "API key required",
|
||||
Code: "KEY_MISSING",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// ✅ 只传 token 参数(移除 context)
|
||||
authToken, err := securityService.AuthenticateToken(tokenValue)
|
||||
if err != nil {
|
||||
// 通用信息,避免泄露过多信息
|
||||
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "Invalid or inactive token provided"})
|
||||
c.AbortWithStatusJSON(http.StatusUnauthorized, ErrorResponse{
|
||||
Error: "Invalid API key",
|
||||
Code: "KEY_INVALID",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
c.Set("authToken", authToken)
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
// extractProxyToken 按优先级提取 token
|
||||
func extractProxyToken(c *gin.Context) string {
|
||||
if key := c.Query("key"); key != "" {
|
||||
return key
|
||||
}
|
||||
authHeader := c.GetHeader("Authorization")
|
||||
if authHeader != "" {
|
||||
if strings.HasPrefix(authHeader, "Bearer ") {
|
||||
return strings.TrimPrefix(authHeader, "Bearer ")
|
||||
}
|
||||
// 优先级 1: Authorization Header
|
||||
if token := extractBearerToken(c); token != "" {
|
||||
return token
|
||||
}
|
||||
|
||||
// 优先级 2: X-Api-Key
|
||||
if key := c.GetHeader("X-Api-Key"); key != "" {
|
||||
return key
|
||||
}
|
||||
|
||||
// 优先级 3: X-Goog-Api-Key
|
||||
if key := c.GetHeader("X-Goog-Api-Key"); key != "" {
|
||||
return key
|
||||
}
|
||||
return ""
|
||||
|
||||
// 优先级 4: Query 参数(不推荐)
|
||||
return c.Query("key")
|
||||
}
|
||||
|
||||
// === 辅助函数 ===
|
||||
|
||||
// extractBearerToken 提取 Bearer Token
|
||||
func extractBearerToken(c *gin.Context) string {
|
||||
authHeader := c.GetHeader("Authorization")
|
||||
if authHeader == "" {
|
||||
return ""
|
||||
}
|
||||
parts := strings.Split(authHeader, " ")
|
||||
if len(parts) == 2 && parts[0] == "Bearer" {
|
||||
return parts[1]
|
||||
|
||||
const prefix = "Bearer "
|
||||
if !strings.HasPrefix(authHeader, prefix) {
|
||||
return ""
|
||||
}
|
||||
return ""
|
||||
|
||||
return strings.TrimSpace(authHeader[len(prefix):])
|
||||
}
|
||||
|
||||
90
internal/middleware/cors.go
Normal file
90
internal/middleware/cors.go
Normal file
@@ -0,0 +1,90 @@
|
||||
// Filename: internal/middleware/cors.go
|
||||
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
type CORSConfig struct {
|
||||
AllowedOrigins []string
|
||||
AllowedMethods []string
|
||||
AllowedHeaders []string
|
||||
ExposedHeaders []string
|
||||
AllowCredentials bool
|
||||
MaxAge int
|
||||
}
|
||||
|
||||
func CORSMiddleware(config CORSConfig) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
origin := c.Request.Header.Get("Origin")
|
||||
|
||||
// 检查 origin 是否允许
|
||||
if origin != "" && isOriginAllowed(origin, config.AllowedOrigins) {
|
||||
c.Writer.Header().Set("Access-Control-Allow-Origin", origin)
|
||||
}
|
||||
|
||||
if config.AllowCredentials {
|
||||
c.Writer.Header().Set("Access-Control-Allow-Credentials", "true")
|
||||
}
|
||||
|
||||
if len(config.ExposedHeaders) > 0 {
|
||||
c.Writer.Header().Set("Access-Control-Expose-Headers",
|
||||
strings.Join(config.ExposedHeaders, ", "))
|
||||
}
|
||||
|
||||
// 处理预检请求
|
||||
if c.Request.Method == http.MethodOptions {
|
||||
if len(config.AllowedMethods) > 0 {
|
||||
c.Writer.Header().Set("Access-Control-Allow-Methods",
|
||||
strings.Join(config.AllowedMethods, ", "))
|
||||
}
|
||||
|
||||
if len(config.AllowedHeaders) > 0 {
|
||||
c.Writer.Header().Set("Access-Control-Allow-Headers",
|
||||
strings.Join(config.AllowedHeaders, ", "))
|
||||
}
|
||||
|
||||
if config.MaxAge > 0 {
|
||||
c.Writer.Header().Set("Access-Control-Max-Age",
|
||||
string(rune(config.MaxAge)))
|
||||
}
|
||||
|
||||
c.AbortWithStatus(http.StatusNoContent)
|
||||
return
|
||||
}
|
||||
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
func isOriginAllowed(origin string, allowedOrigins []string) bool {
|
||||
for _, allowed := range allowedOrigins {
|
||||
if allowed == "*" || allowed == origin {
|
||||
return true
|
||||
}
|
||||
// 支持通配符子域名
|
||||
if strings.HasPrefix(allowed, "*.") {
|
||||
domain := strings.TrimPrefix(allowed, "*.")
|
||||
if strings.HasSuffix(origin, domain) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// 使用示例
|
||||
func SetupCORS(r *gin.Engine) {
|
||||
r.Use(CORSMiddleware(CORSConfig{
|
||||
AllowedOrigins: []string{"https://yourdomain.com", "*.yourdomain.com"},
|
||||
AllowedMethods: []string{"GET", "POST", "PUT", "DELETE", "OPTIONS"},
|
||||
AllowedHeaders: []string{"Authorization", "Content-Type", "X-Api-Key"},
|
||||
ExposedHeaders: []string{"X-Request-Id"},
|
||||
AllowCredentials: true,
|
||||
MaxAge: 3600,
|
||||
}))
|
||||
}
|
||||
@@ -1,84 +1,213 @@
|
||||
// Filename: internal/middleware/log_redaction.go
|
||||
// Filename: internal/middleware/logging.go
|
||||
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"io"
|
||||
"regexp"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
const RedactedBodyKey = "redactedBody"
|
||||
const RedactedAuthHeaderKey = "redactedAuthHeader"
|
||||
const RedactedValue = `"[REDACTED]"`
|
||||
const (
|
||||
RedactedBodyKey = "redactedBody"
|
||||
RedactedAuthHeaderKey = "redactedAuthHeader"
|
||||
RedactedValue = `"[REDACTED]"`
|
||||
)
|
||||
|
||||
// 预编译正则表达式(全局变量,提升性能)
|
||||
var (
|
||||
// JSON 敏感字段脱敏
|
||||
jsonSensitiveKeys = regexp.MustCompile(`("(?i:api_key|apikey|token|password|secret|authorization|key|keys|auth)"\s*:\s*)"[^"]*"`)
|
||||
|
||||
// Bearer Token 脱敏
|
||||
bearerTokenPattern = regexp.MustCompile(`^(Bearer\s+)\S+$`)
|
||||
|
||||
// URL 中的 key 参数脱敏
|
||||
queryKeyPattern = regexp.MustCompile(`([?&](?i:key|token|apikey)=)[^&\s]+`)
|
||||
)
|
||||
|
||||
// RedactionMiddleware 请求数据脱敏中间件
|
||||
func RedactionMiddleware() gin.HandlerFunc {
|
||||
// Pre-compile regex for efficiency
|
||||
jsonKeyPattern := regexp.MustCompile(`("api_key"|"keys")\s*:\s*"[^"]*"`)
|
||||
bearerTokenPattern := regexp.MustCompile(`^(Bearer\s+)\S+$`)
|
||||
return func(c *gin.Context) {
|
||||
// --- 1. Redact Request Body ---
|
||||
if c.Request.Method == "POST" || c.Request.Method == "PUT" || c.Request.Method == "DELETE" {
|
||||
if bodyBytes, err := io.ReadAll(c.Request.Body); err == nil {
|
||||
|
||||
c.Request.Body = io.NopCloser(bytes.NewBuffer(bodyBytes))
|
||||
|
||||
bodyString := string(bodyBytes)
|
||||
|
||||
redactedBody := jsonKeyPattern.ReplaceAllString(bodyString, `$1:`+RedactedValue)
|
||||
|
||||
c.Set(RedactedBodyKey, redactedBody)
|
||||
}
|
||||
}
|
||||
// --- 2. Redact Authorization Header ---
|
||||
authHeader := c.GetHeader("Authorization")
|
||||
if authHeader != "" {
|
||||
if bearerTokenPattern.MatchString(authHeader) {
|
||||
redactedHeader := bearerTokenPattern.ReplaceAllString(authHeader, `${1}[REDACTED]`)
|
||||
c.Set(RedactedAuthHeaderKey, redactedHeader)
|
||||
}
|
||||
// 1. 脱敏请求体
|
||||
if shouldRedactBody(c) {
|
||||
redactRequestBody(c)
|
||||
}
|
||||
|
||||
// 2. 脱敏认证头
|
||||
redactAuthHeader(c)
|
||||
|
||||
// 3. 脱敏 URL 查询参数
|
||||
redactQueryParams(c)
|
||||
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
// LogrusLogger is a Gin middleware that logs requests using a Logrus logger.
|
||||
// It consumes redacted data prepared by the RedactionMiddleware.
|
||||
// shouldRedactBody 判断是否需要脱敏请求体
|
||||
func shouldRedactBody(c *gin.Context) bool {
|
||||
method := c.Request.Method
|
||||
contentType := c.GetHeader("Content-Type")
|
||||
|
||||
// 只处理包含 JSON 的 POST/PUT/PATCH/DELETE 请求
|
||||
return (method == "POST" || method == "PUT" || method == "PATCH" || method == "DELETE") &&
|
||||
strings.Contains(contentType, "application/json")
|
||||
}
|
||||
|
||||
// redactRequestBody 脱敏请求体
|
||||
func redactRequestBody(c *gin.Context) {
|
||||
bodyBytes, err := io.ReadAll(c.Request.Body)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
// 恢复请求体供后续使用
|
||||
c.Request.Body = io.NopCloser(bytes.NewBuffer(bodyBytes))
|
||||
|
||||
// 脱敏敏感字段
|
||||
bodyString := string(bodyBytes)
|
||||
redactedBody := jsonSensitiveKeys.ReplaceAllString(bodyString, `$1`+RedactedValue)
|
||||
|
||||
c.Set(RedactedBodyKey, redactedBody)
|
||||
}
|
||||
|
||||
// redactAuthHeader 脱敏认证头
|
||||
func redactAuthHeader(c *gin.Context) {
|
||||
authHeader := c.GetHeader("Authorization")
|
||||
if authHeader == "" {
|
||||
return
|
||||
}
|
||||
|
||||
if bearerTokenPattern.MatchString(authHeader) {
|
||||
redacted := bearerTokenPattern.ReplaceAllString(authHeader, `${1}[REDACTED]`)
|
||||
c.Set(RedactedAuthHeaderKey, redacted)
|
||||
} else {
|
||||
// 对于非 Bearer 的 token,全部脱敏
|
||||
c.Set(RedactedAuthHeaderKey, "[REDACTED]")
|
||||
}
|
||||
|
||||
// 同时处理其他敏感 Header
|
||||
sensitiveHeaders := []string{"X-Api-Key", "X-Goog-Api-Key", "Api-Key"}
|
||||
for _, header := range sensitiveHeaders {
|
||||
if value := c.GetHeader(header); value != "" {
|
||||
c.Set("redacted_"+header, "[REDACTED]")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// redactQueryParams 脱敏 URL 查询参数
|
||||
func redactQueryParams(c *gin.Context) {
|
||||
rawQuery := c.Request.URL.RawQuery
|
||||
if rawQuery == "" {
|
||||
return
|
||||
}
|
||||
|
||||
redacted := queryKeyPattern.ReplaceAllString(rawQuery, `${1}[REDACTED]`)
|
||||
if redacted != rawQuery {
|
||||
c.Set("redactedQuery", redacted)
|
||||
}
|
||||
}
|
||||
|
||||
// LogrusLogger Gin 请求日志中间件(使用 Logrus)
|
||||
func LogrusLogger(logger *logrus.Logger) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
start := time.Now()
|
||||
path := c.Request.URL.Path
|
||||
method := c.Request.Method
|
||||
|
||||
// Process request
|
||||
// 处理请求
|
||||
c.Next()
|
||||
|
||||
// After request, gather data and log
|
||||
// 计算延迟
|
||||
latency := time.Since(start)
|
||||
statusCode := c.Writer.Status()
|
||||
clientIP := c.ClientIP()
|
||||
|
||||
entry := logger.WithFields(logrus.Fields{
|
||||
"status_code": statusCode,
|
||||
"latency_ms": latency.Milliseconds(),
|
||||
"client_ip": c.ClientIP(),
|
||||
"method": c.Request.Method,
|
||||
"path": path,
|
||||
})
|
||||
// 构建日志字段
|
||||
fields := logrus.Fields{
|
||||
"status": statusCode,
|
||||
"latency_ms": latency.Milliseconds(),
|
||||
"ip": clientIP,
|
||||
"method": method,
|
||||
"path": path,
|
||||
}
|
||||
|
||||
// 添加请求 ID(如果存在)
|
||||
if requestID := getRequestID(c); requestID != "" {
|
||||
fields["request_id"] = requestID
|
||||
}
|
||||
|
||||
// 添加脱敏后的数据
|
||||
if redactedBody, exists := c.Get(RedactedBodyKey); exists {
|
||||
entry = entry.WithField("body", redactedBody)
|
||||
fields["body"] = redactedBody
|
||||
}
|
||||
|
||||
if redactedAuth, exists := c.Get(RedactedAuthHeaderKey); exists {
|
||||
entry = entry.WithField("authorization", redactedAuth)
|
||||
fields["authorization"] = redactedAuth
|
||||
}
|
||||
|
||||
if redactedQuery, exists := c.Get("redactedQuery"); exists {
|
||||
fields["query"] = redactedQuery
|
||||
}
|
||||
|
||||
// 添加用户信息(如果已认证)
|
||||
if user := getAuthenticatedUser(c); user != "" {
|
||||
fields["user"] = user
|
||||
}
|
||||
|
||||
// 根据状态码选择日志级别
|
||||
entry := logger.WithFields(fields)
|
||||
|
||||
if len(c.Errors) > 0 {
|
||||
entry.Error(c.Errors.String())
|
||||
fields["errors"] = c.Errors.String()
|
||||
entry.Error("Request failed")
|
||||
} else {
|
||||
entry.Info("request handled")
|
||||
switch {
|
||||
case statusCode >= 500:
|
||||
entry.Error("Server error")
|
||||
case statusCode >= 400:
|
||||
entry.Warn("Client error")
|
||||
case statusCode >= 300:
|
||||
entry.Info("Redirect")
|
||||
default:
|
||||
// 只在 Debug 模式记录成功请求
|
||||
if logger.Level >= logrus.DebugLevel {
|
||||
entry.Debug("Request completed")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// getRequestID 获取请求 ID
|
||||
func getRequestID(c *gin.Context) string {
|
||||
if id, exists := c.Get("request_id"); exists {
|
||||
if requestID, ok := id.(string); ok {
|
||||
return requestID
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// getAuthenticatedUser 获取已认证用户标识
|
||||
func getAuthenticatedUser(c *gin.Context) string {
|
||||
// 尝试从不同来源获取用户信息
|
||||
if user, exists := c.Get("adminUser"); exists {
|
||||
if authToken, ok := user.(interface{ GetID() string }); ok {
|
||||
return authToken.GetID()
|
||||
}
|
||||
}
|
||||
|
||||
if user, exists := c.Get("authToken"); exists {
|
||||
if authToken, ok := user.(interface{ GetID() string }); ok {
|
||||
return authToken.GetID()
|
||||
}
|
||||
}
|
||||
|
||||
return ""
|
||||
}
|
||||
|
||||
86
internal/middleware/rate_limit.go
Normal file
86
internal/middleware/rate_limit.go
Normal file
@@ -0,0 +1,86 @@
|
||||
// Filename: internal/middleware/rate_limit.go
|
||||
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"golang.org/x/time/rate"
|
||||
)
|
||||
|
||||
type RateLimiter struct {
|
||||
limiters map[string]*rate.Limiter
|
||||
mu sync.RWMutex
|
||||
r rate.Limit // 每秒请求数
|
||||
b int // 突发容量
|
||||
}
|
||||
|
||||
func NewRateLimiter(r rate.Limit, b int) *RateLimiter {
|
||||
return &RateLimiter{
|
||||
limiters: make(map[string]*rate.Limiter),
|
||||
r: r,
|
||||
b: b,
|
||||
}
|
||||
}
|
||||
|
||||
func (rl *RateLimiter) getLimiter(key string) *rate.Limiter {
|
||||
rl.mu.Lock()
|
||||
defer rl.mu.Unlock()
|
||||
|
||||
limiter, exists := rl.limiters[key]
|
||||
if !exists {
|
||||
limiter = rate.NewLimiter(rl.r, rl.b)
|
||||
rl.limiters[key] = limiter
|
||||
}
|
||||
|
||||
return limiter
|
||||
}
|
||||
|
||||
// 定期清理不活跃的限制器
|
||||
func (rl *RateLimiter) cleanup() {
|
||||
ticker := time.NewTicker(10 * time.Minute)
|
||||
defer ticker.Stop()
|
||||
|
||||
for range ticker.C {
|
||||
rl.mu.Lock()
|
||||
// 简单策略:定期清空(生产环境应该用 LRU)
|
||||
rl.limiters = make(map[string]*rate.Limiter)
|
||||
rl.mu.Unlock()
|
||||
}
|
||||
}
|
||||
|
||||
func RateLimitMiddleware(limiter *RateLimiter) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
// 按 IP 限流
|
||||
key := c.ClientIP()
|
||||
|
||||
// 如果有认证 token,按 token 限流(更精确)
|
||||
if authToken, exists := c.Get("authToken"); exists {
|
||||
if token, ok := authToken.(interface{ GetID() string }); ok {
|
||||
key = "token:" + token.GetID()
|
||||
}
|
||||
}
|
||||
|
||||
l := limiter.getLimiter(key)
|
||||
if !l.Allow() {
|
||||
c.AbortWithStatusJSON(http.StatusTooManyRequests, gin.H{
|
||||
"error": "Rate limit exceeded",
|
||||
"code": "RATE_LIMIT",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
// 使用示例
|
||||
func SetupRateLimit(r *gin.Engine) {
|
||||
limiter := NewRateLimiter(10, 20) // 每秒 10 个请求,突发 20
|
||||
go limiter.cleanup()
|
||||
|
||||
r.Use(RateLimitMiddleware(limiter))
|
||||
}
|
||||
39
internal/middleware/request_id.go
Normal file
39
internal/middleware/request_id.go
Normal file
@@ -0,0 +1,39 @@
|
||||
// Filename: internal/middleware/request_id.go
|
||||
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
// RequestIDMiddleware 请求 ID 追踪中间件
|
||||
func RequestIDMiddleware() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
// 1. 尝试从 Header 获取现有的 Request ID
|
||||
requestID := c.GetHeader("X-Request-Id")
|
||||
|
||||
// 2. 如果没有,生成新的
|
||||
if requestID == "" {
|
||||
requestID = uuid.New().String()
|
||||
}
|
||||
|
||||
// 3. 设置到 Context
|
||||
c.Set("request_id", requestID)
|
||||
|
||||
// 4. 返回给客户端(用于追踪)
|
||||
c.Writer.Header().Set("X-Request-Id", requestID)
|
||||
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
// GetRequestID 获取当前请求的 Request ID
|
||||
func GetRequestID(c *gin.Context) string {
|
||||
if id, exists := c.Get("request_id"); exists {
|
||||
if requestID, ok := id.(string); ok {
|
||||
return requestID
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
@@ -1,4 +1,4 @@
|
||||
// Filename: internal/middleware/security.go
|
||||
// Filename: internal/middleware/security.go (简化版)
|
||||
|
||||
package middleware
|
||||
|
||||
@@ -6,26 +6,136 @@ import (
|
||||
"gemini-balancer/internal/service"
|
||||
"gemini-balancer/internal/settings"
|
||||
"net/http"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
func IPBanMiddleware(securityService *service.SecurityService, settingsManager *settings.SettingsManager) gin.HandlerFunc {
|
||||
// 简单的缓存项
|
||||
type cacheItem struct {
|
||||
value bool
|
||||
expiration int64
|
||||
}
|
||||
|
||||
// 简单的 TTL 缓存实现
|
||||
type IPBanCache struct {
|
||||
items map[string]*cacheItem
|
||||
mu sync.RWMutex
|
||||
ttl time.Duration
|
||||
}
|
||||
|
||||
func NewIPBanCache() *IPBanCache {
|
||||
cache := &IPBanCache{
|
||||
items: make(map[string]*cacheItem),
|
||||
ttl: 1 * time.Minute,
|
||||
}
|
||||
|
||||
// 启动清理协程
|
||||
go cache.cleanup()
|
||||
|
||||
return cache
|
||||
}
|
||||
|
||||
func (c *IPBanCache) Get(key string) (bool, bool) {
|
||||
c.mu.RLock()
|
||||
defer c.mu.RUnlock()
|
||||
|
||||
item, exists := c.items[key]
|
||||
if !exists {
|
||||
return false, false
|
||||
}
|
||||
|
||||
// 检查是否过期
|
||||
if time.Now().UnixNano() > item.expiration {
|
||||
return false, false
|
||||
}
|
||||
|
||||
return item.value, true
|
||||
}
|
||||
|
||||
func (c *IPBanCache) Set(key string, value bool) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
c.items[key] = &cacheItem{
|
||||
value: value,
|
||||
expiration: time.Now().Add(c.ttl).UnixNano(),
|
||||
}
|
||||
}
|
||||
|
||||
func (c *IPBanCache) Delete(key string) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
delete(c.items, key)
|
||||
}
|
||||
|
||||
func (c *IPBanCache) cleanup() {
|
||||
ticker := time.NewTicker(5 * time.Minute)
|
||||
defer ticker.Stop()
|
||||
|
||||
for range ticker.C {
|
||||
c.mu.Lock()
|
||||
now := time.Now().UnixNano()
|
||||
for key, item := range c.items {
|
||||
if now > item.expiration {
|
||||
delete(c.items, key)
|
||||
}
|
||||
}
|
||||
c.mu.Unlock()
|
||||
}
|
||||
}
|
||||
|
||||
func IPBanMiddleware(
|
||||
securityService *service.SecurityService,
|
||||
settingsManager *settings.SettingsManager,
|
||||
banCache *IPBanCache,
|
||||
logger *logrus.Logger,
|
||||
) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
if !settingsManager.IsIPBanEnabled() {
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
|
||||
ip := c.ClientIP()
|
||||
isBanned, err := securityService.IsIPBanned(c.Request.Context(), ip)
|
||||
if err != nil {
|
||||
|
||||
// 查缓存
|
||||
if isBanned, exists := banCache.Get(ip); exists {
|
||||
if isBanned {
|
||||
logger.WithField("ip", ip).Debug("IP blocked (cached)")
|
||||
c.AbortWithStatusJSON(http.StatusForbidden, gin.H{
|
||||
"error": "Access denied",
|
||||
})
|
||||
return
|
||||
}
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
if isBanned {
|
||||
c.AbortWithStatusJSON(http.StatusForbidden, gin.H{"error": "您的IP已被暂时封禁,请稍后再试"})
|
||||
|
||||
// 查数据库
|
||||
ctx := c.Request.Context()
|
||||
isBanned, err := securityService.IsIPBanned(ctx, ip)
|
||||
if err != nil {
|
||||
logger.WithError(err).WithField("ip", ip).Error("Failed to check IP ban status")
|
||||
|
||||
// 降级策略:允许访问
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
|
||||
// 更新缓存
|
||||
banCache.Set(ip, isBanned)
|
||||
|
||||
if isBanned {
|
||||
logger.WithField("ip", ip).Info("IP blocked")
|
||||
c.AbortWithStatusJSON(http.StatusForbidden, gin.H{
|
||||
"error": "Access denied",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
52
internal/middleware/timeout.go
Normal file
52
internal/middleware/timeout.go
Normal file
@@ -0,0 +1,52 @@
|
||||
// Filename: internal/middleware/timeout.go
|
||||
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func TimeoutMiddleware(timeout time.Duration) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
// 创建带超时的 context
|
||||
ctx, cancel := context.WithTimeout(c.Request.Context(), timeout)
|
||||
defer cancel()
|
||||
|
||||
// 替换 request context
|
||||
c.Request = c.Request.WithContext(ctx)
|
||||
|
||||
// 使用 channel 等待请求完成
|
||||
finished := make(chan struct{})
|
||||
|
||||
go func() {
|
||||
c.Next()
|
||||
close(finished)
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-finished:
|
||||
// 请求正常完成
|
||||
return
|
||||
case <-ctx.Done():
|
||||
// 超时
|
||||
c.AbortWithStatusJSON(http.StatusGatewayTimeout, gin.H{
|
||||
"error": "Request timeout",
|
||||
"code": "TIMEOUT",
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 使用示例
|
||||
func SetupTimeout(r *gin.Engine) {
|
||||
// 对 API 路由设置 30 秒超时
|
||||
api := r.Group("/api")
|
||||
api.Use(TimeoutMiddleware(30 * time.Second))
|
||||
{
|
||||
// ... API routes
|
||||
}
|
||||
}
|
||||
@@ -1,23 +1,151 @@
|
||||
// Filename: internal/middleware/web.go
|
||||
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"gemini-balancer/internal/service"
|
||||
"log"
|
||||
"net/http"
|
||||
"os"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
const (
|
||||
AdminSessionCookie = "gemini_admin_session"
|
||||
SessionMaxAge = 3600 * 24 * 7 // 7天
|
||||
CacheTTL = 5 * time.Minute
|
||||
CleanupInterval = 10 * time.Minute // 降低清理频率
|
||||
SessionRefreshTime = 30 * time.Minute
|
||||
)
|
||||
|
||||
// ==================== 缓存层 ====================
|
||||
|
||||
type authCacheEntry struct {
|
||||
Token interface{}
|
||||
ExpiresAt time.Time
|
||||
}
|
||||
|
||||
type authCache struct {
|
||||
mu sync.RWMutex
|
||||
cache map[string]*authCacheEntry
|
||||
ttl time.Duration
|
||||
}
|
||||
|
||||
var webAuthCache = newAuthCache(CacheTTL)
|
||||
|
||||
func newAuthCache(ttl time.Duration) *authCache {
|
||||
c := &authCache{
|
||||
cache: make(map[string]*authCacheEntry),
|
||||
ttl: ttl,
|
||||
}
|
||||
go c.cleanupLoop()
|
||||
return c
|
||||
}
|
||||
|
||||
func (c *authCache) get(key string) (interface{}, bool) {
|
||||
c.mu.RLock()
|
||||
defer c.mu.RUnlock()
|
||||
|
||||
entry, exists := c.cache[key]
|
||||
if !exists || time.Now().After(entry.ExpiresAt) {
|
||||
return nil, false
|
||||
}
|
||||
return entry.Token, true
|
||||
}
|
||||
|
||||
func (c *authCache) set(key string, token interface{}) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
c.cache[key] = &authCacheEntry{
|
||||
Token: token,
|
||||
ExpiresAt: time.Now().Add(c.ttl),
|
||||
}
|
||||
}
|
||||
|
||||
func (c *authCache) delete(key string) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
delete(c.cache, key)
|
||||
}
|
||||
|
||||
func (c *authCache) cleanupLoop() {
|
||||
ticker := time.NewTicker(CleanupInterval)
|
||||
defer ticker.Stop()
|
||||
|
||||
for range ticker.C {
|
||||
c.cleanup()
|
||||
}
|
||||
}
|
||||
|
||||
func (c *authCache) cleanup() {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
now := time.Now()
|
||||
count := 0
|
||||
for key, entry := range c.cache {
|
||||
if now.After(entry.ExpiresAt) {
|
||||
delete(c.cache, key)
|
||||
count++
|
||||
}
|
||||
}
|
||||
|
||||
if count > 0 {
|
||||
logrus.Debugf("[AuthCache] Cleaned up %d expired entries", count)
|
||||
}
|
||||
}
|
||||
|
||||
// ==================== 会话刷新缓存 ====================
|
||||
|
||||
var sessionRefreshCache = struct {
|
||||
sync.RWMutex
|
||||
timestamps map[string]time.Time
|
||||
}{
|
||||
timestamps: make(map[string]time.Time),
|
||||
}
|
||||
|
||||
// 定期清理刷新时间戳
|
||||
func init() {
|
||||
go func() {
|
||||
ticker := time.NewTicker(1 * time.Hour)
|
||||
defer ticker.Stop()
|
||||
|
||||
for range ticker.C {
|
||||
sessionRefreshCache.Lock()
|
||||
now := time.Now()
|
||||
for key, ts := range sessionRefreshCache.timestamps {
|
||||
if now.Sub(ts) > 2*time.Hour {
|
||||
delete(sessionRefreshCache.timestamps, key)
|
||||
}
|
||||
}
|
||||
sessionRefreshCache.Unlock()
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
// ==================== Cookie 操作 ====================
|
||||
|
||||
func SetAdminSessionCookie(c *gin.Context, adminToken string) {
|
||||
c.SetCookie(AdminSessionCookie, adminToken, 3600*24*7, "/", "", false, true)
|
||||
secure := c.Request.TLS != nil || c.GetHeader("X-Forwarded-Proto") == "https"
|
||||
c.SetSameSite(http.SameSiteStrictMode)
|
||||
c.SetCookie(AdminSessionCookie, adminToken, SessionMaxAge, "/", "", secure, true)
|
||||
}
|
||||
|
||||
func SetAdminSessionCookieWithAge(c *gin.Context, adminToken string, maxAge int) {
|
||||
secure := c.Request.TLS != nil || c.GetHeader("X-Forwarded-Proto") == "https"
|
||||
c.SetSameSite(http.SameSiteStrictMode)
|
||||
c.SetCookie(AdminSessionCookie, adminToken, maxAge, "/", "", secure, true)
|
||||
}
|
||||
|
||||
func ClearAdminSessionCookie(c *gin.Context) {
|
||||
c.SetSameSite(http.SameSiteStrictMode)
|
||||
c.SetCookie(AdminSessionCookie, "", -1, "/", "", false, true)
|
||||
}
|
||||
|
||||
@@ -29,26 +157,258 @@ func ExtractTokenFromCookie(c *gin.Context) string {
|
||||
return cookie
|
||||
}
|
||||
|
||||
// ==================== 认证中间件 ====================
|
||||
|
||||
func WebAdminAuthMiddleware(authService *service.SecurityService) gin.HandlerFunc {
|
||||
logger := logrus.New()
|
||||
logger.SetLevel(getLogLevel())
|
||||
|
||||
return func(c *gin.Context) {
|
||||
cookie := ExtractTokenFromCookie(c)
|
||||
log.Printf("[WebAuth_Guard] Intercepting request for: %s", c.Request.URL.Path)
|
||||
log.Printf("[WebAuth_Guard] Found session cookie value: '%s'", cookie)
|
||||
authToken, err := authService.AuthenticateToken(cookie)
|
||||
if err != nil {
|
||||
log.Printf("[WebAuth_Guard] FATAL: AuthenticateToken FAILED. Error: %v. Redirecting to /login.", err)
|
||||
} else if !authToken.IsAdmin {
|
||||
log.Printf("[WebAuth_Guard] FATAL: Token validated, but IsAdmin is FALSE. Redirecting to /login.")
|
||||
} else {
|
||||
log.Printf("[WebAuth_Guard] SUCCESS: Token validated and IsAdmin is TRUE. Allowing access.")
|
||||
}
|
||||
if err != nil || !authToken.IsAdmin {
|
||||
if cookie == "" {
|
||||
logger.Debug("[WebAuth] No session cookie found")
|
||||
ClearAdminSessionCookie(c)
|
||||
c.Redirect(http.StatusFound, "/login")
|
||||
c.Abort()
|
||||
redirectToLogin(c)
|
||||
return
|
||||
}
|
||||
|
||||
cacheKey := hashToken(cookie)
|
||||
|
||||
if cachedToken, found := webAuthCache.get(cacheKey); found {
|
||||
logger.Debug("[WebAuth] Using cached token")
|
||||
c.Set("adminUser", cachedToken)
|
||||
refreshSessionIfNeeded(c, cookie)
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
|
||||
logger.Debug("[WebAuth] Cache miss, authenticating...")
|
||||
authToken, err := authService.AuthenticateToken(cookie)
|
||||
|
||||
if err != nil {
|
||||
logger.WithError(err).Warn("[WebAuth] Authentication failed")
|
||||
ClearAdminSessionCookie(c)
|
||||
webAuthCache.delete(cacheKey)
|
||||
redirectToLogin(c)
|
||||
return
|
||||
}
|
||||
|
||||
if !authToken.IsAdmin {
|
||||
logger.Warn("[WebAuth] User is not admin")
|
||||
ClearAdminSessionCookie(c)
|
||||
webAuthCache.delete(cacheKey)
|
||||
redirectToLogin(c)
|
||||
return
|
||||
}
|
||||
|
||||
webAuthCache.set(cacheKey, authToken)
|
||||
logger.Debug("[WebAuth] Authentication success, token cached")
|
||||
|
||||
c.Set("adminUser", authToken)
|
||||
refreshSessionIfNeeded(c, cookie)
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
func WebAdminAuthMiddlewareWithLogger(authService *service.SecurityService, logger *logrus.Logger) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
cookie := ExtractTokenFromCookie(c)
|
||||
|
||||
if cookie == "" {
|
||||
logger.Debug("No session cookie found")
|
||||
ClearAdminSessionCookie(c)
|
||||
redirectToLogin(c)
|
||||
return
|
||||
}
|
||||
|
||||
cacheKey := hashToken(cookie)
|
||||
if cachedToken, found := webAuthCache.get(cacheKey); found {
|
||||
c.Set("adminUser", cachedToken)
|
||||
refreshSessionIfNeeded(c, cookie)
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
|
||||
authToken, err := authService.AuthenticateToken(cookie)
|
||||
|
||||
if err != nil {
|
||||
logger.WithError(err).Warn("Token authentication failed")
|
||||
ClearAdminSessionCookie(c)
|
||||
webAuthCache.delete(cacheKey)
|
||||
redirectToLogin(c)
|
||||
return
|
||||
}
|
||||
|
||||
if !authToken.IsAdmin {
|
||||
logger.Warn("Token valid but user is not admin")
|
||||
ClearAdminSessionCookie(c)
|
||||
webAuthCache.delete(cacheKey)
|
||||
redirectToLogin(c)
|
||||
return
|
||||
}
|
||||
|
||||
webAuthCache.set(cacheKey, authToken)
|
||||
c.Set("adminUser", authToken)
|
||||
refreshSessionIfNeeded(c, cookie)
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
// ==================== 辅助函数 ====================
|
||||
|
||||
func hashToken(token string) string {
|
||||
h := sha256.Sum256([]byte(token))
|
||||
return hex.EncodeToString(h[:])
|
||||
}
|
||||
|
||||
func redirectToLogin(c *gin.Context) {
|
||||
if isAjaxRequest(c) {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{
|
||||
"error": "Session expired",
|
||||
"code": "AUTH_REQUIRED",
|
||||
})
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
originalPath := c.Request.URL.Path
|
||||
if originalPath != "/" && originalPath != "/login" {
|
||||
c.Redirect(http.StatusFound, "/login?redirect="+originalPath)
|
||||
} else {
|
||||
c.Redirect(http.StatusFound, "/login")
|
||||
}
|
||||
c.Abort()
|
||||
}
|
||||
|
||||
func isAjaxRequest(c *gin.Context) bool {
|
||||
// 检查 Content-Type
|
||||
contentType := c.GetHeader("Content-Type")
|
||||
if strings.Contains(contentType, "application/json") {
|
||||
return true
|
||||
}
|
||||
|
||||
// 检查 Accept(优先检查 JSON)
|
||||
accept := c.GetHeader("Accept")
|
||||
if strings.Contains(accept, "application/json") &&
|
||||
!strings.Contains(accept, "text/html") {
|
||||
return true
|
||||
}
|
||||
|
||||
// 兼容旧版 XMLHttpRequest
|
||||
return c.GetHeader("X-Requested-With") == "XMLHttpRequest"
|
||||
}
|
||||
|
||||
func refreshSessionIfNeeded(c *gin.Context, token string) {
|
||||
tokenHash := hashToken(token)
|
||||
|
||||
sessionRefreshCache.RLock()
|
||||
lastRefresh, exists := sessionRefreshCache.timestamps[tokenHash]
|
||||
sessionRefreshCache.RUnlock()
|
||||
|
||||
if !exists || time.Since(lastRefresh) > SessionRefreshTime {
|
||||
SetAdminSessionCookie(c, token)
|
||||
|
||||
sessionRefreshCache.Lock()
|
||||
sessionRefreshCache.timestamps[tokenHash] = time.Now()
|
||||
sessionRefreshCache.Unlock()
|
||||
}
|
||||
}
|
||||
|
||||
func getLogLevel() logrus.Level {
|
||||
level := os.Getenv("LOG_LEVEL")
|
||||
switch strings.ToLower(level) {
|
||||
case "debug":
|
||||
return logrus.DebugLevel
|
||||
case "warn":
|
||||
return logrus.WarnLevel
|
||||
case "error":
|
||||
return logrus.ErrorLevel
|
||||
default:
|
||||
return logrus.InfoLevel
|
||||
}
|
||||
}
|
||||
|
||||
// ==================== 工具函数 ====================
|
||||
|
||||
func GetAdminUserFromContext(c *gin.Context) (interface{}, bool) {
|
||||
return c.Get("adminUser")
|
||||
}
|
||||
|
||||
func InvalidateTokenCache(token string) {
|
||||
tokenHash := hashToken(token)
|
||||
webAuthCache.delete(tokenHash)
|
||||
|
||||
// 同时清理刷新时间戳
|
||||
sessionRefreshCache.Lock()
|
||||
delete(sessionRefreshCache.timestamps, tokenHash)
|
||||
sessionRefreshCache.Unlock()
|
||||
}
|
||||
|
||||
func ClearAllAuthCache() {
|
||||
webAuthCache.mu.Lock()
|
||||
webAuthCache.cache = make(map[string]*authCacheEntry)
|
||||
webAuthCache.mu.Unlock()
|
||||
|
||||
sessionRefreshCache.Lock()
|
||||
sessionRefreshCache.timestamps = make(map[string]time.Time)
|
||||
sessionRefreshCache.Unlock()
|
||||
}
|
||||
|
||||
// ==================== 调试工具 ====================
|
||||
|
||||
type SessionInfo struct {
|
||||
HasCookie bool `json:"has_cookie"`
|
||||
IsValid bool `json:"is_valid"`
|
||||
IsAdmin bool `json:"is_admin"`
|
||||
IsCached bool `json:"is_cached"`
|
||||
LastActivity string `json:"last_activity"`
|
||||
}
|
||||
|
||||
func GetSessionInfo(c *gin.Context, authService *service.SecurityService) SessionInfo {
|
||||
info := SessionInfo{
|
||||
HasCookie: false,
|
||||
IsValid: false,
|
||||
IsAdmin: false,
|
||||
IsCached: false,
|
||||
LastActivity: time.Now().Format(time.RFC3339),
|
||||
}
|
||||
|
||||
cookie := ExtractTokenFromCookie(c)
|
||||
if cookie == "" {
|
||||
return info
|
||||
}
|
||||
|
||||
info.HasCookie = true
|
||||
|
||||
cacheKey := hashToken(cookie)
|
||||
if _, found := webAuthCache.get(cacheKey); found {
|
||||
info.IsCached = true
|
||||
}
|
||||
|
||||
authToken, err := authService.AuthenticateToken(cookie)
|
||||
if err != nil {
|
||||
return info
|
||||
}
|
||||
|
||||
info.IsValid = true
|
||||
info.IsAdmin = authToken.IsAdmin
|
||||
|
||||
return info
|
||||
}
|
||||
|
||||
func GetCacheStats() map[string]interface{} {
|
||||
webAuthCache.mu.RLock()
|
||||
cacheSize := len(webAuthCache.cache)
|
||||
webAuthCache.mu.RUnlock()
|
||||
|
||||
sessionRefreshCache.RLock()
|
||||
refreshSize := len(sessionRefreshCache.timestamps)
|
||||
sessionRefreshCache.RUnlock()
|
||||
return map[string]interface{}{
|
||||
"auth_cache_entries": cacheSize,
|
||||
"refresh_cache_entries": refreshSize,
|
||||
"ttl_seconds": int(webAuthCache.ttl.Seconds()),
|
||||
"cleanup_interval": int(CleanupInterval.Seconds()),
|
||||
"session_refresh_time": int(SessionRefreshTime.Seconds()),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -27,6 +27,8 @@ type SystemSettings struct {
|
||||
BaseKeyCheckEndpoint string `json:"base_key_check_endpoint" default:"https://generativelanguage.googleapis.com/v1beta/models" name:"全局Key检查端点" category:"健康检查" desc:"用于全局Key身份检查的目标URL。"`
|
||||
BaseKeyCheckModel string `json:"base_key_check_model" default:"gemini-2.0-flash-lite" name:"默认Key检查模型" category:"健康检查" desc:"用于分组健康检查和手动密钥测试时的默认回退模型。"`
|
||||
|
||||
KeyCheckSchedulerIntervalSeconds int `json:"key_check_scheduler_interval_seconds" default:"60" name:"Key检查调度器间隔(秒)" category:"健康检查" desc:"动态调度器检查各组是否需要执行健康检查的周期。"`
|
||||
|
||||
EnableUpstreamCheck bool `json:"enable_upstream_check" default:"true" name:"启用上游检查" category:"健康检查" desc:"是否启用对上游服务(Upstream)的健康检查。"`
|
||||
UpstreamCheckTimeoutSeconds int `json:"upstream_check_timeout_seconds" default:"20" name:"上游检查超时(秒)" category:"健康检查" desc:"对单个上游服务进行健康检查时的网络超时时间。"`
|
||||
|
||||
|
||||
@@ -1,63 +1,96 @@
|
||||
// Filename: internal/pongo/renderer.go
|
||||
|
||||
package pongo
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"sync"
|
||||
|
||||
"github.com/flosch/pongo2/v6"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/gin-gonic/gin/render"
|
||||
"github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
type Renderer struct {
|
||||
Context pongo2.Context
|
||||
tplSet *pongo2.TemplateSet
|
||||
mu sync.RWMutex
|
||||
globalContext pongo2.Context
|
||||
tplSet *pongo2.TemplateSet
|
||||
logger *logrus.Logger
|
||||
}
|
||||
|
||||
func New(directory string, isDebug bool) *Renderer {
|
||||
func New(directory string, isDebug bool, logger *logrus.Logger) *Renderer {
|
||||
loader := pongo2.MustNewLocalFileSystemLoader(directory)
|
||||
tplSet := pongo2.NewSet("gin-pongo-templates", loader)
|
||||
tplSet.Debug = isDebug
|
||||
return &Renderer{Context: make(pongo2.Context), tplSet: tplSet}
|
||||
return &Renderer{
|
||||
globalContext: make(pongo2.Context),
|
||||
tplSet: tplSet,
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
// Instance returns a new render.HTML instance for a single request.
|
||||
func (p *Renderer) Instance(name string, data interface{}) render.Render {
|
||||
var glob pongo2.Context
|
||||
if p.Context != nil {
|
||||
glob = p.Context
|
||||
}
|
||||
// SetGlobalContext 线程安全地设置全局上下文
|
||||
func (p *Renderer) SetGlobalContext(key string, value interface{}) {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
p.globalContext[key] = value
|
||||
}
|
||||
|
||||
// Warmup 预加载模板
|
||||
func (p *Renderer) Warmup(templateNames ...string) error {
|
||||
for _, name := range templateNames {
|
||||
if _, err := p.tplSet.FromCache(name); err != nil {
|
||||
return fmt.Errorf("failed to warmup template '%s': %w", name, err)
|
||||
}
|
||||
}
|
||||
p.logger.WithField("count", len(templateNames)).Info("Templates warmed up")
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *Renderer) Instance(name string, data interface{}) render.Render {
|
||||
// 安全读取全局上下文
|
||||
p.mu.RLock()
|
||||
glob := make(pongo2.Context, len(p.globalContext))
|
||||
for k, v := range p.globalContext {
|
||||
glob[k] = v
|
||||
}
|
||||
p.mu.RUnlock()
|
||||
|
||||
// 解析请求数据
|
||||
var context pongo2.Context
|
||||
if data != nil {
|
||||
if ginContext, ok := data.(gin.H); ok {
|
||||
context = pongo2.Context(ginContext)
|
||||
} else if pongoContext, ok := data.(pongo2.Context); ok {
|
||||
context = pongoContext
|
||||
} else if m, ok := data.(map[string]interface{}); ok {
|
||||
context = m
|
||||
} else {
|
||||
switch v := data.(type) {
|
||||
case gin.H:
|
||||
context = pongo2.Context(v)
|
||||
case pongo2.Context:
|
||||
context = v
|
||||
case map[string]interface{}:
|
||||
context = v
|
||||
default:
|
||||
context = make(pongo2.Context)
|
||||
}
|
||||
} else {
|
||||
context = make(pongo2.Context)
|
||||
}
|
||||
|
||||
// 合并上下文(请求数据优先)
|
||||
for k, v := range glob {
|
||||
if _, ok := context[k]; !ok {
|
||||
if _, exists := context[k]; !exists {
|
||||
context[k] = v
|
||||
}
|
||||
}
|
||||
|
||||
// 加载模板
|
||||
tpl, err := p.tplSet.FromCache(name)
|
||||
if err != nil {
|
||||
panic(fmt.Sprintf("Failed to load template '%s': %v", name, err))
|
||||
p.logger.WithError(err).WithField("template", name).Error("Failed to load template")
|
||||
return &ErrorHTML{
|
||||
StatusCode: http.StatusInternalServerError,
|
||||
Error: fmt.Errorf("template load error: %s", name),
|
||||
}
|
||||
}
|
||||
|
||||
return &HTML{
|
||||
p: p,
|
||||
Template: tpl,
|
||||
Name: name,
|
||||
Data: context,
|
||||
@@ -65,7 +98,6 @@ func (p *Renderer) Instance(name string, data interface{}) render.Render {
|
||||
}
|
||||
|
||||
type HTML struct {
|
||||
p *Renderer
|
||||
Template *pongo2.Template
|
||||
Name string
|
||||
Data pongo2.Context
|
||||
@@ -82,15 +114,31 @@ func (h *HTML) Render(w http.ResponseWriter) error {
|
||||
}
|
||||
|
||||
func (h *HTML) WriteContentType(w http.ResponseWriter) {
|
||||
header := w.Header()
|
||||
if val := header["Content-Type"]; len(val) == 0 {
|
||||
header["Content-Type"] = []string{"text/html; charset=utf-8"}
|
||||
if w.Header().Get("Content-Type") == "" {
|
||||
w.Header().Set("Content-Type", "text/html; charset=utf-8")
|
||||
}
|
||||
}
|
||||
|
||||
// ErrorHTML 错误渲染器
|
||||
type ErrorHTML struct {
|
||||
StatusCode int
|
||||
Error error
|
||||
}
|
||||
|
||||
func (e *ErrorHTML) Render(w http.ResponseWriter) error {
|
||||
w.Header().Set("Content-Type", "text/plain; charset=utf-8")
|
||||
w.WriteHeader(e.StatusCode)
|
||||
_, err := w.Write([]byte(e.Error.Error()))
|
||||
return err
|
||||
}
|
||||
|
||||
func (e *ErrorHTML) WriteContentType(w http.ResponseWriter) {
|
||||
w.Header().Set("Content-Type", "text/plain; charset=utf-8")
|
||||
}
|
||||
|
||||
// C 获取或创建 pongo2 上下文
|
||||
func C(ctx *gin.Context) pongo2.Context {
|
||||
p, exists := ctx.Get("pongo2")
|
||||
if exists {
|
||||
if p, exists := ctx.Get("pongo2"); exists {
|
||||
if pCtx, ok := p.(pongo2.Context); ok {
|
||||
return pCtx
|
||||
}
|
||||
|
||||
@@ -17,6 +17,7 @@ import (
|
||||
|
||||
"github.com/gin-contrib/cors"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
func NewRouter(
|
||||
@@ -42,70 +43,214 @@ func NewRouter(
|
||||
upstreamModule *upstream.Module,
|
||||
proxyModule *proxy.Module,
|
||||
) *gin.Engine {
|
||||
// === 1. 创建全局 Logger(统一管理)===
|
||||
logger := createLogger(cfg)
|
||||
|
||||
// === 2. 设置 Gin 运行模式 ===
|
||||
if cfg.Log.Level != "debug" {
|
||||
gin.SetMode(gin.ReleaseMode)
|
||||
}
|
||||
router := gin.Default()
|
||||
|
||||
router.Static("/static", "./web/static")
|
||||
// CORS 配置
|
||||
config := cors.Config{
|
||||
// 允许前端的来源。在生产环境中,需改为实际域名
|
||||
AllowOrigins: []string{"http://localhost:9000"},
|
||||
AllowMethods: []string{"GET", "POST", "PUT", "PATCH", "DELETE", "OPTIONS"},
|
||||
AllowHeaders: []string{"Origin", "Content-Type", "Accept", "Authorization"},
|
||||
ExposeHeaders: []string{"Content-Length"},
|
||||
AllowCredentials: true,
|
||||
MaxAge: 12 * time.Hour,
|
||||
}
|
||||
router.Use(cors.New(config))
|
||||
isDebug := gin.Mode() != gin.ReleaseMode
|
||||
router.HTMLRender = pongo.New("web/templates", isDebug)
|
||||
// === 3. 创建 Router(使用 gin.New() 以便完全控制中间件)===
|
||||
router := gin.New()
|
||||
|
||||
// --- 基础设施 ---
|
||||
router.GET("/", func(c *gin.Context) { c.Redirect(http.StatusMovedPermanently, "/dashboard") })
|
||||
router.GET("/health", func(c *gin.Context) { c.JSON(http.StatusOK, gin.H{"status": "ok"}) })
|
||||
// --- 统一的认证管道 ---
|
||||
apiAdminAuth := middleware.APIAdminAuthMiddleware(securityService)
|
||||
// === 4. 注册全局中间件(按执行顺序)===
|
||||
setupGlobalMiddleware(router, logger)
|
||||
|
||||
// === 5. 配置静态文件和模板 ===
|
||||
setupStaticAndTemplates(router, logger)
|
||||
|
||||
// === 6. 配置 CORS ===
|
||||
setupCORS(router, cfg)
|
||||
|
||||
// === 7. 注册基础路由 ===
|
||||
setupBasicRoutes(router)
|
||||
|
||||
// === 8. 创建认证中间件 ===
|
||||
apiAdminAuth := middleware.APIAdminAuthMiddleware(securityService, logger)
|
||||
webAdminAuth := middleware.WebAdminAuthMiddleware(securityService)
|
||||
|
||||
router.Use(gin.RecoveryWithWriter(os.Stdout))
|
||||
// --- 将正确的依赖和中间件管道传递下去 ---
|
||||
registerProxyRoutes(router, proxyHandler, securityService)
|
||||
registerAdminRoutes(router, apiAdminAuth, keyGroupHandler, tokensHandler, apiKeyHandler, logHandler, settingHandler, dashboardHandler, taskHandler, upstreamModule, proxyModule)
|
||||
registerPublicAPIRoutes(router, apiAuthHandler, securityService, settingsManager)
|
||||
// === 9. 注册业务路由(按功能分组)===
|
||||
registerPublicAPIRoutes(router, apiAuthHandler, securityService, settingsManager, logger)
|
||||
registerWebRoutes(router, webAdminAuth, webAuthHandler, pageHandler)
|
||||
registerAdminRoutes(router, apiAdminAuth, keyGroupHandler, tokensHandler, apiKeyHandler,
|
||||
logHandler, settingHandler, dashboardHandler, taskHandler, upstreamModule, proxyModule)
|
||||
registerProxyRoutes(router, proxyHandler, securityService, logger)
|
||||
|
||||
return router
|
||||
}
|
||||
|
||||
// ==================== 辅助函数 ====================
|
||||
|
||||
// createLogger 创建并配置全局 Logger
|
||||
func createLogger(cfg *config.Config) *logrus.Logger {
|
||||
logger := logrus.New()
|
||||
|
||||
// 设置日志格式
|
||||
if cfg.Log.Format == "json" {
|
||||
logger.SetFormatter(&logrus.JSONFormatter{
|
||||
TimestampFormat: time.RFC3339,
|
||||
})
|
||||
} else {
|
||||
logger.SetFormatter(&logrus.TextFormatter{
|
||||
FullTimestamp: true,
|
||||
TimestampFormat: "2006-01-02 15:04:05",
|
||||
})
|
||||
}
|
||||
|
||||
// 设置日志级别
|
||||
switch cfg.Log.Level {
|
||||
case "debug":
|
||||
logger.SetLevel(logrus.DebugLevel)
|
||||
case "info":
|
||||
logger.SetLevel(logrus.InfoLevel)
|
||||
case "warn":
|
||||
logger.SetLevel(logrus.WarnLevel)
|
||||
case "error":
|
||||
logger.SetLevel(logrus.ErrorLevel)
|
||||
default:
|
||||
logger.SetLevel(logrus.InfoLevel)
|
||||
}
|
||||
|
||||
// 设置输出(可选:输出到文件)
|
||||
logger.SetOutput(os.Stdout)
|
||||
|
||||
return logger
|
||||
}
|
||||
|
||||
// setupGlobalMiddleware 设置全局中间件
|
||||
func setupGlobalMiddleware(router *gin.Engine, logger *logrus.Logger) {
|
||||
// 1. 请求 ID 中间件(用于链路追踪)
|
||||
router.Use(middleware.RequestIDMiddleware())
|
||||
|
||||
// 2. 数据脱敏中间件(在日志前执行)
|
||||
router.Use(middleware.RedactionMiddleware())
|
||||
|
||||
// 3. 日志中间件
|
||||
router.Use(middleware.LogrusLogger(logger))
|
||||
|
||||
// 4. 错误恢复中间件
|
||||
router.Use(gin.RecoveryWithWriter(os.Stdout))
|
||||
}
|
||||
|
||||
// setupStaticAndTemplates 配置静态文件和模板
|
||||
func setupStaticAndTemplates(router *gin.Engine, logger *logrus.Logger) {
|
||||
router.Static("/static", "./web/static")
|
||||
|
||||
isDebug := gin.Mode() != gin.ReleaseMode
|
||||
router.HTMLRender = pongo.New("web/templates", isDebug, logger)
|
||||
}
|
||||
|
||||
// setupCORS 配置 CORS
|
||||
func setupCORS(router *gin.Engine, cfg *config.Config) {
|
||||
corsConfig := cors.Config{
|
||||
AllowOrigins: getCORSOrigins(cfg),
|
||||
AllowMethods: []string{"GET", "POST", "PUT", "PATCH", "DELETE", "OPTIONS"},
|
||||
AllowHeaders: []string{"Origin", "Content-Type", "Accept", "Authorization", "X-Request-Id"},
|
||||
ExposeHeaders: []string{"Content-Length", "X-Request-Id"},
|
||||
AllowCredentials: true,
|
||||
MaxAge: 12 * time.Hour,
|
||||
}
|
||||
router.Use(cors.New(corsConfig))
|
||||
}
|
||||
|
||||
// getCORSOrigins 获取 CORS 允许的来源
|
||||
func getCORSOrigins(cfg *config.Config) []string {
|
||||
// 默认值
|
||||
origins := []string{"http://localhost:9000"}
|
||||
|
||||
// 从配置读取(修复:移除 nil 检查)
|
||||
if len(cfg.Server.CORSOrigins) > 0 {
|
||||
origins = cfg.Server.CORSOrigins
|
||||
}
|
||||
|
||||
return origins
|
||||
}
|
||||
|
||||
// setupBasicRoutes 设置基础路由
|
||||
func setupBasicRoutes(router *gin.Engine) {
|
||||
// 根路径重定向
|
||||
router.GET("/", func(c *gin.Context) {
|
||||
c.Redirect(http.StatusMovedPermanently, "/dashboard")
|
||||
})
|
||||
|
||||
// 健康检查
|
||||
router.GET("/health", handleHealthCheck)
|
||||
|
||||
// 版本信息(可选)
|
||||
router.GET("/version", handleVersion)
|
||||
}
|
||||
|
||||
// handleHealthCheck 健康检查处理器
|
||||
func handleHealthCheck(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"status": "ok",
|
||||
"time": time.Now().Unix(),
|
||||
})
|
||||
}
|
||||
|
||||
// handleVersion 版本信息处理器
|
||||
func handleVersion(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"version": "1.0.0", // 可以从配置或编译时变量读取
|
||||
"build": "latest",
|
||||
})
|
||||
}
|
||||
|
||||
// ==================== 路由注册函数 ====================
|
||||
|
||||
// registerProxyRoutes 注册代理路由
|
||||
func registerProxyRoutes(
|
||||
router *gin.Engine, proxyHandler *handlers.ProxyHandler, securityService *service.SecurityService,
|
||||
router *gin.Engine,
|
||||
proxyHandler *handlers.ProxyHandler,
|
||||
securityService *service.SecurityService,
|
||||
logger *logrus.Logger,
|
||||
) {
|
||||
// 通用的代理认证中间件
|
||||
proxyAuthMiddleware := middleware.ProxyAuthMiddleware(securityService)
|
||||
// --- 模式一: 智能聚合模式 (根路径) ---
|
||||
// /v1 和 /v1beta 路径作为默认入口,服务于 BasePool 聚合逻辑
|
||||
// 创建代理认证中间件
|
||||
proxyAuthMiddleware := middleware.ProxyAuthMiddleware(securityService, logger)
|
||||
|
||||
// 模式一: 智能聚合模式(默认入口)
|
||||
registerAggregateProxyRoutes(router, proxyHandler, proxyAuthMiddleware)
|
||||
|
||||
// 模式二: 精确路由模式(按组名路由)
|
||||
registerGroupProxyRoutes(router, proxyHandler, proxyAuthMiddleware)
|
||||
}
|
||||
|
||||
// registerAggregateProxyRoutes 注册聚合代理路由
|
||||
func registerAggregateProxyRoutes(
|
||||
router *gin.Engine,
|
||||
proxyHandler *handlers.ProxyHandler,
|
||||
authMiddleware gin.HandlerFunc,
|
||||
) {
|
||||
// /v1 路径组
|
||||
v1 := router.Group("/v1")
|
||||
v1.Use(proxyAuthMiddleware)
|
||||
v1.Use(authMiddleware)
|
||||
{
|
||||
v1.Any("/*path", proxyHandler.HandleProxy)
|
||||
}
|
||||
|
||||
// /v1beta 路径组
|
||||
v1beta := router.Group("/v1beta")
|
||||
v1beta.Use(proxyAuthMiddleware)
|
||||
v1beta.Use(authMiddleware)
|
||||
{
|
||||
v1beta.Any("/*path", proxyHandler.HandleProxy)
|
||||
}
|
||||
// --- 模式二: 精确路由模式 (/proxy/:group_name) ---
|
||||
// 创建一个新的、物理隔离的路由组,用于按组名精确路由
|
||||
}
|
||||
|
||||
// registerGroupProxyRoutes 注册分组代理路由
|
||||
func registerGroupProxyRoutes(
|
||||
router *gin.Engine,
|
||||
proxyHandler *handlers.ProxyHandler,
|
||||
authMiddleware gin.HandlerFunc,
|
||||
) {
|
||||
proxyGroup := router.Group("/proxy/:group_name")
|
||||
proxyGroup.Use(proxyAuthMiddleware)
|
||||
proxyGroup.Use(authMiddleware)
|
||||
{
|
||||
// 捕获所有子路径 (例如 /v1/chat/completions),并全部交给同一个 ProxyHandler。
|
||||
proxyGroup.Any("/*path", proxyHandler.HandleProxy)
|
||||
}
|
||||
}
|
||||
|
||||
// registerAdminRoutes
|
||||
// registerAdminRoutes 注册管理后台 API 路由
|
||||
func registerAdminRoutes(
|
||||
router *gin.Engine,
|
||||
authMiddleware gin.HandlerFunc,
|
||||
@@ -121,74 +266,112 @@ func registerAdminRoutes(
|
||||
) {
|
||||
admin := router.Group("/admin", authMiddleware)
|
||||
{
|
||||
// --- KeyGroup Base Routes ---
|
||||
admin.POST("/keygroups", keyGroupHandler.CreateKeyGroup)
|
||||
admin.GET("/keygroups", keyGroupHandler.GetKeyGroups)
|
||||
admin.PUT("/keygroups/order", keyGroupHandler.UpdateKeyGroupOrder)
|
||||
// --- KeyGroup Specific Routes (by :id) ---
|
||||
admin.GET("/keygroups/:id", keyGroupHandler.GetKeyGroups)
|
||||
admin.PUT("/keygroups/:id", keyGroupHandler.UpdateKeyGroup)
|
||||
admin.DELETE("/keygroups/:id", keyGroupHandler.DeleteKeyGroup)
|
||||
admin.POST("/keygroups/:id/clone", keyGroupHandler.CloneKeyGroup)
|
||||
admin.GET("/keygroups/:id/stats", keyGroupHandler.GetKeyGroupStats)
|
||||
admin.POST("/keygroups/:id/bulk-actions", apiKeyHandler.HandleBulkAction)
|
||||
// --- APIKey Sub-resource Routes under a KeyGroup ---
|
||||
keyGroupAPIKeys := admin.Group("/keygroups/:id/apikeys")
|
||||
{
|
||||
keyGroupAPIKeys.GET("", apiKeyHandler.ListKeysForGroup)
|
||||
keyGroupAPIKeys.GET("/export", apiKeyHandler.ExportKeysForGroup)
|
||||
keyGroupAPIKeys.POST("/bulk", apiKeyHandler.AddMultipleKeysToGroup)
|
||||
keyGroupAPIKeys.DELETE("/bulk", apiKeyHandler.UnlinkMultipleKeysFromGroup)
|
||||
keyGroupAPIKeys.POST("/test", apiKeyHandler.TestKeysForGroup)
|
||||
keyGroupAPIKeys.PUT("/:keyId", apiKeyHandler.UpdateGroupAPIKeyMapping)
|
||||
}
|
||||
// KeyGroup 路由
|
||||
registerKeyGroupRoutes(admin, keyGroupHandler, apiKeyHandler)
|
||||
|
||||
// Global key operations
|
||||
admin.GET("/apikeys", apiKeyHandler.ListAPIKeys)
|
||||
// admin.PUT("/apikeys/:id", apiKeyHandler.UpdateAPIKey) // DEPRECATED: Status is now contextual
|
||||
admin.POST("/apikeys/test", apiKeyHandler.TestMultipleKeys) // Test keys globally
|
||||
admin.DELETE("/apikeys/:id", apiKeyHandler.HardDeleteAPIKey) // Hard delete a single key
|
||||
admin.DELETE("/apikeys/bulk", apiKeyHandler.HardDeleteMultipleKeys) // Hard delete multiple keys
|
||||
admin.PUT("/apikeys/bulk/restore", apiKeyHandler.RestoreMultipleKeys) // Restore multiple keys globally
|
||||
// APIKey 全局路由
|
||||
registerAPIKeyRoutes(admin, apiKeyHandler)
|
||||
|
||||
// --- Global Routes ---
|
||||
admin.GET("/tokens", tokensHandler.GetAllTokens)
|
||||
admin.PUT("/tokens", tokensHandler.UpdateTokens)
|
||||
admin.GET("/logs", logHandler.GetLogs)
|
||||
admin.GET("/settings", settingHandler.GetSettings)
|
||||
admin.PUT("/settings", settingHandler.UpdateSettings)
|
||||
admin.PUT("/settings/reset", settingHandler.ResetSettingsToDefaults)
|
||||
// 系统管理路由
|
||||
registerSystemRoutes(admin, tokensHandler, logHandler, settingHandler, taskHandler)
|
||||
|
||||
// 用于查询异步任务的状态
|
||||
admin.GET("/tasks/:id", taskHandler.GetTaskStatus)
|
||||
// 仪表盘路由
|
||||
registerDashboardRoutes(admin, dashboardHandler)
|
||||
|
||||
// 领域模块
|
||||
// 领域模块路由
|
||||
upstreamModule.RegisterRoutes(admin)
|
||||
proxyModule.RegisterRoutes(admin)
|
||||
// --- 全局仪表盘路由 ---
|
||||
dashboard := admin.Group("/dashboard")
|
||||
{
|
||||
dashboard.GET("/overview", dashboardHandler.GetOverview)
|
||||
dashboard.GET("/chart", dashboardHandler.GetChart)
|
||||
dashboard.GET("/stats/:period", dashboardHandler.GetRequestStats) // 点击详情
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// registerWebRoutes
|
||||
// registerKeyGroupRoutes 注册 KeyGroup 相关路由
|
||||
func registerKeyGroupRoutes(
|
||||
admin *gin.RouterGroup,
|
||||
keyGroupHandler *handlers.KeyGroupHandler,
|
||||
apiKeyHandler *handlers.APIKeyHandler,
|
||||
) {
|
||||
// 基础路由
|
||||
admin.POST("/keygroups", keyGroupHandler.CreateKeyGroup)
|
||||
admin.GET("/keygroups", keyGroupHandler.GetKeyGroups)
|
||||
admin.PUT("/keygroups/order", keyGroupHandler.UpdateKeyGroupOrder)
|
||||
|
||||
// 特定 KeyGroup 路由
|
||||
admin.GET("/keygroups/:id", keyGroupHandler.GetKeyGroups)
|
||||
admin.PUT("/keygroups/:id", keyGroupHandler.UpdateKeyGroup)
|
||||
admin.DELETE("/keygroups/:id", keyGroupHandler.DeleteKeyGroup)
|
||||
admin.POST("/keygroups/:id/clone", keyGroupHandler.CloneKeyGroup)
|
||||
admin.GET("/keygroups/:id/stats", keyGroupHandler.GetKeyGroupStats)
|
||||
admin.POST("/keygroups/:id/bulk-actions", apiKeyHandler.HandleBulkAction)
|
||||
|
||||
// KeyGroup 的 APIKey 子资源
|
||||
keyGroupAPIKeys := admin.Group("/keygroups/:id/apikeys")
|
||||
{
|
||||
keyGroupAPIKeys.GET("", apiKeyHandler.ListKeysForGroup)
|
||||
keyGroupAPIKeys.GET("/export", apiKeyHandler.ExportKeysForGroup)
|
||||
keyGroupAPIKeys.POST("/bulk", apiKeyHandler.AddMultipleKeysToGroup)
|
||||
keyGroupAPIKeys.DELETE("/bulk", apiKeyHandler.UnlinkMultipleKeysFromGroup)
|
||||
keyGroupAPIKeys.POST("/test", apiKeyHandler.TestKeysForGroup)
|
||||
keyGroupAPIKeys.PUT("/:keyId", apiKeyHandler.UpdateGroupAPIKeyMapping)
|
||||
}
|
||||
}
|
||||
|
||||
// registerAPIKeyRoutes 注册 APIKey 全局路由
|
||||
func registerAPIKeyRoutes(admin *gin.RouterGroup, apiKeyHandler *handlers.APIKeyHandler) {
|
||||
admin.GET("/apikeys", apiKeyHandler.ListAPIKeys)
|
||||
admin.POST("/apikeys/test", apiKeyHandler.TestMultipleKeys)
|
||||
admin.DELETE("/apikeys/:id", apiKeyHandler.HardDeleteAPIKey)
|
||||
admin.DELETE("/apikeys/bulk", apiKeyHandler.HardDeleteMultipleKeys)
|
||||
admin.PUT("/apikeys/bulk/restore", apiKeyHandler.RestoreMultipleKeys)
|
||||
}
|
||||
|
||||
// registerSystemRoutes 注册系统管理路由
|
||||
func registerSystemRoutes(
|
||||
admin *gin.RouterGroup,
|
||||
tokensHandler *handlers.TokensHandler,
|
||||
logHandler *handlers.LogHandler,
|
||||
settingHandler *handlers.SettingHandler,
|
||||
taskHandler *handlers.TaskHandler,
|
||||
) {
|
||||
// Token 管理
|
||||
admin.GET("/tokens", tokensHandler.GetAllTokens)
|
||||
admin.PUT("/tokens", tokensHandler.UpdateTokens)
|
||||
|
||||
// 日志管理
|
||||
admin.GET("/logs", logHandler.GetLogs)
|
||||
|
||||
// 设置管理
|
||||
admin.GET("/settings", settingHandler.GetSettings)
|
||||
admin.PUT("/settings", settingHandler.UpdateSettings)
|
||||
admin.PUT("/settings/reset", settingHandler.ResetSettingsToDefaults)
|
||||
|
||||
// 任务管理
|
||||
admin.GET("/tasks/:id", taskHandler.GetTaskStatus)
|
||||
}
|
||||
|
||||
// registerDashboardRoutes 注册仪表盘路由
|
||||
func registerDashboardRoutes(admin *gin.RouterGroup, dashboardHandler *handlers.DashboardHandler) {
|
||||
dashboard := admin.Group("/dashboard")
|
||||
{
|
||||
dashboard.GET("/overview", dashboardHandler.GetOverview)
|
||||
dashboard.GET("/chart", dashboardHandler.GetChart)
|
||||
dashboard.GET("/stats/:period", dashboardHandler.GetRequestStats)
|
||||
}
|
||||
}
|
||||
|
||||
// registerWebRoutes 注册 Web 页面路由
|
||||
func registerWebRoutes(
|
||||
router *gin.Engine,
|
||||
authMiddleware gin.HandlerFunc,
|
||||
webAuthHandler *webhandlers.WebAuthHandler,
|
||||
pageHandler *webhandlers.PageHandler,
|
||||
) {
|
||||
// 公开的认证路由
|
||||
router.GET("/login", webAuthHandler.ShowLoginPage)
|
||||
router.POST("/login", webAuthHandler.HandleLogin)
|
||||
router.GET("/logout", webAuthHandler.HandleLogout)
|
||||
// For Test only router.Run("127.0.0.1:9000")
|
||||
// 受保护的Admin Web界面
|
||||
|
||||
// 受保护的管理界面
|
||||
webGroup := router.Group("/", authMiddleware)
|
||||
webGroup.Use(authMiddleware)
|
||||
{
|
||||
webGroup.GET("/keys", pageHandler.ShowKeysPage)
|
||||
webGroup.GET("/settings", pageHandler.ShowConfigEditorPage)
|
||||
@@ -197,14 +380,31 @@ func registerWebRoutes(
|
||||
webGroup.GET("/tasks", pageHandler.ShowTasksPage)
|
||||
webGroup.GET("/chat", pageHandler.ShowChatPage)
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
// registerPublicAPIRoutes 无需后台登录的公共API路由
|
||||
func registerPublicAPIRoutes(router *gin.Engine, apiAuthHandler *handlers.APIAuthHandler, securityService *service.SecurityService, settingsManager *settings.SettingsManager) {
|
||||
ipBanMiddleware := middleware.IPBanMiddleware(securityService, settingsManager)
|
||||
publicAPIGroup := router.Group("/api")
|
||||
// registerPublicAPIRoutes 注册公共 API 路由
|
||||
func registerPublicAPIRoutes(
|
||||
router *gin.Engine,
|
||||
apiAuthHandler *handlers.APIAuthHandler,
|
||||
securityService *service.SecurityService,
|
||||
settingsManager *settings.SettingsManager,
|
||||
logger *logrus.Logger,
|
||||
) {
|
||||
// 创建 IP 封禁中间件
|
||||
ipBanCache := middleware.NewIPBanCache()
|
||||
ipBanMiddleware := middleware.IPBanMiddleware(
|
||||
securityService,
|
||||
settingsManager,
|
||||
ipBanCache,
|
||||
logger,
|
||||
)
|
||||
|
||||
// 公共 API 路由组
|
||||
publicAPI := router.Group("/api")
|
||||
{
|
||||
publicAPIGroup.POST("/login", ipBanMiddleware, apiAuthHandler.HandleLogin)
|
||||
publicAPI.POST("/login", ipBanMiddleware, apiAuthHandler.HandleLogin)
|
||||
// 可以在这里添加其他公共 API 路由
|
||||
// publicAPI.POST("/register", ipBanMiddleware, apiAuthHandler.HandleRegister)
|
||||
// publicAPI.POST("/forgot-password", ipBanMiddleware, apiAuthHandler.HandleForgotPassword)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -5,93 +5,179 @@ import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"gemini-balancer/internal/db/dialect"
|
||||
"gemini-balancer/internal/models"
|
||||
"gemini-balancer/internal/store"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"gemini-balancer/internal/db/dialect"
|
||||
"gemini-balancer/internal/models"
|
||||
"gemini-balancer/internal/settings"
|
||||
"gemini-balancer/internal/store"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
const (
|
||||
flushLoopInterval = 1 * time.Minute
|
||||
defaultFlushInterval = 1 * time.Minute
|
||||
maxRetryAttempts = 3
|
||||
retryDelay = 5 * time.Second
|
||||
)
|
||||
|
||||
type AnalyticsServiceLogger struct{ *logrus.Entry }
|
||||
|
||||
type AnalyticsService struct {
|
||||
db *gorm.DB
|
||||
store store.Store
|
||||
logger *logrus.Entry
|
||||
db *gorm.DB
|
||||
store store.Store
|
||||
logger *logrus.Entry
|
||||
dialect dialect.DialectAdapter
|
||||
settingsManager *settings.SettingsManager
|
||||
|
||||
stopChan chan struct{}
|
||||
wg sync.WaitGroup
|
||||
dialect dialect.DialectAdapter
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
|
||||
// 统计指标
|
||||
eventsReceived atomic.Uint64
|
||||
eventsProcessed atomic.Uint64
|
||||
eventsFailed atomic.Uint64
|
||||
flushCount atomic.Uint64
|
||||
recordsFlushed atomic.Uint64
|
||||
flushErrors atomic.Uint64
|
||||
lastFlushTime time.Time
|
||||
lastFlushMutex sync.RWMutex
|
||||
|
||||
// 运行时配置
|
||||
flushInterval time.Duration
|
||||
configMutex sync.RWMutex
|
||||
}
|
||||
|
||||
func NewAnalyticsService(db *gorm.DB, s store.Store, logger *logrus.Logger, d dialect.DialectAdapter) *AnalyticsService {
|
||||
func NewAnalyticsService(
|
||||
db *gorm.DB,
|
||||
s store.Store,
|
||||
logger *logrus.Logger,
|
||||
d dialect.DialectAdapter,
|
||||
settingsManager *settings.SettingsManager,
|
||||
) *AnalyticsService {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
|
||||
return &AnalyticsService{
|
||||
db: db,
|
||||
store: s,
|
||||
logger: logger.WithField("component", "Analytics📊"),
|
||||
stopChan: make(chan struct{}),
|
||||
dialect: d,
|
||||
db: db,
|
||||
store: s,
|
||||
logger: logger.WithField("component", "Analytics📊"),
|
||||
dialect: d,
|
||||
settingsManager: settingsManager,
|
||||
stopChan: make(chan struct{}),
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
flushInterval: defaultFlushInterval,
|
||||
lastFlushTime: time.Now(),
|
||||
}
|
||||
}
|
||||
|
||||
func (s *AnalyticsService) Start() {
|
||||
s.wg.Add(2)
|
||||
go s.flushLoop()
|
||||
s.wg.Add(3)
|
||||
go s.eventListener()
|
||||
s.logger.Info("AnalyticsService (Command Side) started.")
|
||||
go s.flushLoop()
|
||||
go s.metricsReporter()
|
||||
|
||||
s.logger.WithFields(logrus.Fields{
|
||||
"flush_interval": s.flushInterval,
|
||||
}).Info("AnalyticsService started")
|
||||
}
|
||||
|
||||
func (s *AnalyticsService) Stop() {
|
||||
s.logger.Info("AnalyticsService stopping...")
|
||||
close(s.stopChan)
|
||||
s.cancel()
|
||||
s.wg.Wait()
|
||||
s.logger.Info("AnalyticsService stopped. Performing final data flush...")
|
||||
|
||||
s.logger.Info("Performing final data flush...")
|
||||
s.flushToDB()
|
||||
s.logger.Info("AnalyticsService final data flush completed.")
|
||||
|
||||
// 输出最终统计
|
||||
s.logger.WithFields(logrus.Fields{
|
||||
"events_received": s.eventsReceived.Load(),
|
||||
"events_processed": s.eventsProcessed.Load(),
|
||||
"events_failed": s.eventsFailed.Load(),
|
||||
"flush_count": s.flushCount.Load(),
|
||||
"records_flushed": s.recordsFlushed.Load(),
|
||||
"flush_errors": s.flushErrors.Load(),
|
||||
}).Info("AnalyticsService stopped")
|
||||
}
|
||||
|
||||
// 事件监听循环
|
||||
func (s *AnalyticsService) eventListener() {
|
||||
defer s.wg.Done()
|
||||
sub, err := s.store.Subscribe(context.Background(), models.TopicRequestFinished)
|
||||
|
||||
sub, err := s.store.Subscribe(s.ctx, models.TopicRequestFinished)
|
||||
if err != nil {
|
||||
s.logger.Fatalf("Failed to subscribe to topic %s: %v", models.TopicRequestFinished, err)
|
||||
s.logger.WithError(err).Error("Failed to subscribe to request events, analytics disabled")
|
||||
return
|
||||
}
|
||||
defer sub.Close()
|
||||
s.logger.Info("AnalyticsService subscribed to request events.")
|
||||
defer func() {
|
||||
if err := sub.Close(); err != nil {
|
||||
s.logger.WithError(err).Warn("Failed to close subscription")
|
||||
}
|
||||
}()
|
||||
|
||||
s.logger.Info("Subscribed to request events for analytics")
|
||||
|
||||
for {
|
||||
select {
|
||||
case msg := <-sub.Channel():
|
||||
var event models.RequestFinishedEvent
|
||||
if err := json.Unmarshal(msg.Payload, &event); err != nil {
|
||||
s.logger.Errorf("Failed to unmarshal analytics event: %v", err)
|
||||
continue
|
||||
}
|
||||
s.handleAnalyticsEvent(&event)
|
||||
s.handleMessage(msg)
|
||||
|
||||
case <-s.stopChan:
|
||||
s.logger.Info("AnalyticsService stopping event listener.")
|
||||
s.logger.Info("Event listener stopping")
|
||||
return
|
||||
|
||||
case <-s.ctx.Done():
|
||||
s.logger.Info("Event listener context cancelled")
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *AnalyticsService) handleAnalyticsEvent(event *models.RequestFinishedEvent) {
|
||||
if event.RequestLog.GroupID == nil {
|
||||
// 处理单条消息
|
||||
func (s *AnalyticsService) handleMessage(msg *store.Message) {
|
||||
var event models.RequestFinishedEvent
|
||||
if err := json.Unmarshal(msg.Payload, &event); err != nil {
|
||||
s.logger.WithError(err).Error("Failed to unmarshal analytics event")
|
||||
s.eventsFailed.Add(1)
|
||||
return
|
||||
}
|
||||
ctx := context.Background()
|
||||
key := fmt.Sprintf("analytics:hourly:%s", time.Now().UTC().Format("2006-01-02T15"))
|
||||
|
||||
s.eventsReceived.Add(1)
|
||||
|
||||
if err := s.handleAnalyticsEvent(&event); err != nil {
|
||||
s.eventsFailed.Add(1)
|
||||
s.logger.WithFields(logrus.Fields{
|
||||
"correlation_id": event.CorrelationID,
|
||||
"group_id": event.RequestLog.GroupID,
|
||||
}).WithError(err).Warn("Failed to process analytics event")
|
||||
} else {
|
||||
s.eventsProcessed.Add(1)
|
||||
}
|
||||
}
|
||||
|
||||
// 处理分析事件
|
||||
func (s *AnalyticsService) handleAnalyticsEvent(event *models.RequestFinishedEvent) error {
|
||||
if event.RequestLog.GroupID == nil {
|
||||
return nil // 跳过无 GroupID 的事件
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
now := time.Now().UTC()
|
||||
key := fmt.Sprintf("analytics:hourly:%s", now.Format("2006-01-02T15"))
|
||||
fieldPrefix := fmt.Sprintf("%d:%s", *event.RequestLog.GroupID, event.RequestLog.ModelName)
|
||||
|
||||
pipe := s.store.Pipeline(ctx)
|
||||
pipe.HIncrBy(key, fieldPrefix+":requests", 1)
|
||||
|
||||
if event.RequestLog.IsSuccess {
|
||||
pipe.HIncrBy(key, fieldPrefix+":success", 1)
|
||||
}
|
||||
@@ -101,80 +187,213 @@ func (s *AnalyticsService) handleAnalyticsEvent(event *models.RequestFinishedEve
|
||||
if event.RequestLog.CompletionTokens > 0 {
|
||||
pipe.HIncrBy(key, fieldPrefix+":completion", int64(event.RequestLog.CompletionTokens))
|
||||
}
|
||||
|
||||
// 设置过期时间(保留48小时)
|
||||
pipe.Expire(key, 48*time.Hour)
|
||||
|
||||
if err := pipe.Exec(); err != nil {
|
||||
s.logger.Warnf("[%s] Failed to record analytics event to store for group %d: %v", event.CorrelationID, *event.RequestLog.GroupID, err)
|
||||
return fmt.Errorf("redis pipeline failed: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// 刷新循环
|
||||
func (s *AnalyticsService) flushLoop() {
|
||||
defer s.wg.Done()
|
||||
ticker := time.NewTicker(flushLoopInterval)
|
||||
|
||||
s.configMutex.RLock()
|
||||
interval := s.flushInterval
|
||||
s.configMutex.RUnlock()
|
||||
|
||||
ticker := time.NewTicker(interval)
|
||||
defer ticker.Stop()
|
||||
|
||||
s.logger.WithField("interval", interval).Info("Flush loop started")
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
s.flushToDB()
|
||||
|
||||
case <-s.stopChan:
|
||||
s.logger.Info("Flush loop stopping")
|
||||
return
|
||||
|
||||
case <-s.ctx.Done():
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 刷写到数据库
|
||||
func (s *AnalyticsService) flushToDB() {
|
||||
ctx := context.Background()
|
||||
start := time.Now()
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 60*time.Second)
|
||||
defer cancel()
|
||||
|
||||
now := time.Now().UTC()
|
||||
keysToFlush := []string{
|
||||
fmt.Sprintf("analytics:hourly:%s", now.Add(-1*time.Hour).Format("2006-01-02T15")),
|
||||
fmt.Sprintf("analytics:hourly:%s", now.Format("2006-01-02T15")),
|
||||
}
|
||||
keysToFlush := s.generateFlushKeys(now)
|
||||
|
||||
totalRecords := 0
|
||||
totalErrors := 0
|
||||
|
||||
for _, key := range keysToFlush {
|
||||
data, err := s.store.HGetAll(ctx, key)
|
||||
if err != nil || len(data) == 0 {
|
||||
continue
|
||||
records, err := s.flushSingleKey(ctx, key, now)
|
||||
if err != nil {
|
||||
s.logger.WithError(err).WithField("key", key).Error("Failed to flush key")
|
||||
totalErrors++
|
||||
s.flushErrors.Add(1)
|
||||
} else {
|
||||
totalRecords += records
|
||||
}
|
||||
}
|
||||
|
||||
statsToFlush, parsedFields := s.parseStatsFromHash(now.Truncate(time.Hour), data)
|
||||
s.recordsFlushed.Add(uint64(totalRecords))
|
||||
s.flushCount.Add(1)
|
||||
|
||||
if len(statsToFlush) > 0 {
|
||||
upsertClause := s.dialect.OnConflictUpdateAll(
|
||||
[]string{"time", "group_id", "model_name"},
|
||||
[]string{"request_count", "success_count", "prompt_tokens", "completion_tokens"},
|
||||
)
|
||||
err := s.db.WithContext(ctx).Clauses(upsertClause).Create(&statsToFlush).Error
|
||||
if err != nil {
|
||||
s.logger.Errorf("Failed to flush analytics data for key %s: %v", key, err)
|
||||
} else {
|
||||
s.logger.Infof("Successfully flushed %d records from key %s.", len(statsToFlush), key)
|
||||
_ = s.store.HDel(ctx, key, parsedFields...)
|
||||
}
|
||||
}
|
||||
s.lastFlushMutex.Lock()
|
||||
s.lastFlushTime = time.Now()
|
||||
s.lastFlushMutex.Unlock()
|
||||
|
||||
duration := time.Since(start)
|
||||
|
||||
if totalRecords > 0 || totalErrors > 0 {
|
||||
s.logger.WithFields(logrus.Fields{
|
||||
"records_flushed": totalRecords,
|
||||
"keys_processed": len(keysToFlush),
|
||||
"errors": totalErrors,
|
||||
"duration": duration,
|
||||
}).Info("Analytics data flush completed")
|
||||
} else {
|
||||
s.logger.WithField("duration", duration).Debug("Analytics flush completed (no data)")
|
||||
}
|
||||
}
|
||||
|
||||
// 生成需要刷新的 Redis 键
|
||||
func (s *AnalyticsService) generateFlushKeys(now time.Time) []string {
|
||||
keys := make([]string, 0, 4)
|
||||
|
||||
// 当前小时
|
||||
keys = append(keys, fmt.Sprintf("analytics:hourly:%s", now.Format("2006-01-02T15")))
|
||||
|
||||
// 前3个小时(处理延迟和时区问题)
|
||||
for i := 1; i <= 3; i++ {
|
||||
pastHour := now.Add(-time.Duration(i) * time.Hour)
|
||||
keys = append(keys, fmt.Sprintf("analytics:hourly:%s", pastHour.Format("2006-01-02T15")))
|
||||
}
|
||||
|
||||
return keys
|
||||
}
|
||||
|
||||
// 刷写单个 Redis 键
|
||||
func (s *AnalyticsService) flushSingleKey(ctx context.Context, key string, baseTime time.Time) (int, error) {
|
||||
data, err := s.store.HGetAll(ctx, key)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("failed to get hash data: %w", err)
|
||||
}
|
||||
|
||||
if len(data) == 0 {
|
||||
return 0, nil // 无数据,跳过
|
||||
}
|
||||
|
||||
// 解析时间戳
|
||||
hourStr := strings.TrimPrefix(key, "analytics:hourly:")
|
||||
recordTime, err := time.Parse("2006-01-02T15", hourStr)
|
||||
if err != nil {
|
||||
s.logger.WithError(err).WithField("key", key).Warn("Failed to parse time from key")
|
||||
recordTime = baseTime.Truncate(time.Hour)
|
||||
}
|
||||
|
||||
statsToFlush, parsedFields := s.parseStatsFromHash(recordTime, data)
|
||||
|
||||
if len(statsToFlush) == 0 {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
// 使用事务 + 重试机制
|
||||
var dbErr error
|
||||
for attempt := 1; attempt <= maxRetryAttempts; attempt++ {
|
||||
dbErr = s.upsertStatsWithTransaction(ctx, statsToFlush)
|
||||
if dbErr == nil {
|
||||
break
|
||||
}
|
||||
|
||||
if attempt < maxRetryAttempts {
|
||||
s.logger.WithFields(logrus.Fields{
|
||||
"attempt": attempt,
|
||||
"key": key,
|
||||
}).WithError(dbErr).Warn("Database upsert failed, retrying...")
|
||||
time.Sleep(retryDelay)
|
||||
}
|
||||
}
|
||||
|
||||
if dbErr != nil {
|
||||
return 0, fmt.Errorf("failed to upsert after %d attempts: %w", maxRetryAttempts, dbErr)
|
||||
}
|
||||
|
||||
// 删除已处理的字段
|
||||
if len(parsedFields) > 0 {
|
||||
if err := s.store.HDel(ctx, key, parsedFields...); err != nil {
|
||||
s.logger.WithError(err).WithField("key", key).Warn("Failed to delete flushed fields from Redis")
|
||||
}
|
||||
}
|
||||
|
||||
return len(statsToFlush), nil
|
||||
}
|
||||
|
||||
// 使用事务批量 upsert
|
||||
func (s *AnalyticsService) upsertStatsWithTransaction(ctx context.Context, stats []models.StatsHourly) error {
|
||||
return s.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
|
||||
upsertClause := s.dialect.OnConflictUpdateAll(
|
||||
[]string{"time", "group_id", "model_name"},
|
||||
[]string{"request_count", "success_count", "prompt_tokens", "completion_tokens"},
|
||||
)
|
||||
return tx.Clauses(upsertClause).Create(&stats).Error
|
||||
})
|
||||
}
|
||||
|
||||
// 解析 Redis Hash 数据
|
||||
func (s *AnalyticsService) parseStatsFromHash(t time.Time, data map[string]string) ([]models.StatsHourly, []string) {
|
||||
tempAggregator := make(map[string]*models.StatsHourly)
|
||||
var parsedFields []string
|
||||
parsedFields := make([]string, 0, len(data))
|
||||
|
||||
for field, valueStr := range data {
|
||||
parts := strings.Split(field, ":")
|
||||
if len(parts) != 3 {
|
||||
s.logger.WithField("field", field).Warn("Invalid field format")
|
||||
continue
|
||||
}
|
||||
groupIDStr, modelName, counterType := parts[0], parts[1], parts[2]
|
||||
|
||||
groupIDStr, modelName, counterType := parts[0], parts[1], parts[2]
|
||||
aggKey := groupIDStr + ":" + modelName
|
||||
|
||||
if _, ok := tempAggregator[aggKey]; !ok {
|
||||
gid, err := strconv.Atoi(groupIDStr)
|
||||
if err != nil {
|
||||
s.logger.WithFields(logrus.Fields{
|
||||
"field": field,
|
||||
"group_id": groupIDStr,
|
||||
}).Warn("Invalid group ID")
|
||||
continue
|
||||
}
|
||||
|
||||
tempAggregator[aggKey] = &models.StatsHourly{
|
||||
Time: t,
|
||||
GroupID: uint(gid),
|
||||
ModelName: modelName,
|
||||
}
|
||||
}
|
||||
val, _ := strconv.ParseInt(valueStr, 10, 64)
|
||||
|
||||
val, err := strconv.ParseInt(valueStr, 10, 64)
|
||||
if err != nil {
|
||||
s.logger.WithFields(logrus.Fields{
|
||||
"field": field,
|
||||
"value": valueStr,
|
||||
}).Warn("Invalid counter value")
|
||||
continue
|
||||
}
|
||||
|
||||
switch counterType {
|
||||
case "requests":
|
||||
tempAggregator[aggKey].RequestCount = val
|
||||
@@ -184,14 +403,92 @@ func (s *AnalyticsService) parseStatsFromHash(t time.Time, data map[string]strin
|
||||
tempAggregator[aggKey].PromptTokens = val
|
||||
case "completion":
|
||||
tempAggregator[aggKey].CompletionTokens = val
|
||||
default:
|
||||
s.logger.WithField("counter_type", counterType).Warn("Unknown counter type")
|
||||
continue
|
||||
}
|
||||
|
||||
parsedFields = append(parsedFields, field)
|
||||
}
|
||||
var result []models.StatsHourly
|
||||
|
||||
result := make([]models.StatsHourly, 0, len(tempAggregator))
|
||||
for _, stats := range tempAggregator {
|
||||
if stats.RequestCount > 0 {
|
||||
result = append(result, *stats)
|
||||
}
|
||||
}
|
||||
|
||||
return result, parsedFields
|
||||
}
|
||||
|
||||
// 定期输出统计信息
|
||||
func (s *AnalyticsService) metricsReporter() {
|
||||
defer s.wg.Done()
|
||||
|
||||
ticker := time.NewTicker(5 * time.Minute)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
s.reportMetrics()
|
||||
case <-s.stopChan:
|
||||
return
|
||||
case <-s.ctx.Done():
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *AnalyticsService) reportMetrics() {
|
||||
s.lastFlushMutex.RLock()
|
||||
lastFlush := s.lastFlushTime
|
||||
s.lastFlushMutex.RUnlock()
|
||||
|
||||
received := s.eventsReceived.Load()
|
||||
processed := s.eventsProcessed.Load()
|
||||
failed := s.eventsFailed.Load()
|
||||
|
||||
var successRate float64
|
||||
if received > 0 {
|
||||
successRate = float64(processed) / float64(received) * 100
|
||||
}
|
||||
|
||||
s.logger.WithFields(logrus.Fields{
|
||||
"events_received": received,
|
||||
"events_processed": processed,
|
||||
"events_failed": failed,
|
||||
"success_rate": fmt.Sprintf("%.2f%%", successRate),
|
||||
"flush_count": s.flushCount.Load(),
|
||||
"records_flushed": s.recordsFlushed.Load(),
|
||||
"flush_errors": s.flushErrors.Load(),
|
||||
"last_flush_ago": time.Since(lastFlush).Round(time.Second),
|
||||
}).Info("Analytics metrics")
|
||||
}
|
||||
|
||||
// GetMetrics 返回当前统计指标(供监控使用)
|
||||
func (s *AnalyticsService) GetMetrics() map[string]interface{} {
|
||||
s.lastFlushMutex.RLock()
|
||||
lastFlush := s.lastFlushTime
|
||||
s.lastFlushMutex.RUnlock()
|
||||
|
||||
received := s.eventsReceived.Load()
|
||||
processed := s.eventsProcessed.Load()
|
||||
|
||||
var successRate float64
|
||||
if received > 0 {
|
||||
successRate = float64(processed) / float64(received) * 100
|
||||
}
|
||||
|
||||
return map[string]interface{}{
|
||||
"events_received": received,
|
||||
"events_processed": processed,
|
||||
"events_failed": s.eventsFailed.Load(),
|
||||
"success_rate": successRate,
|
||||
"flush_count": s.flushCount.Load(),
|
||||
"records_flushed": s.recordsFlushed.Load(),
|
||||
"flush_errors": s.flushErrors.Load(),
|
||||
"last_flush_ago": time.Since(lastFlush).Seconds(),
|
||||
"flush_interval": s.flushInterval.Seconds(),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -4,158 +4,297 @@ package service
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"gemini-balancer/internal/models"
|
||||
"gemini-balancer/internal/store"
|
||||
"gemini-balancer/internal/syncer"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/sirupsen/logrus"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
const overviewCacheChannel = "syncer:cache:dashboard_overview"
|
||||
const (
|
||||
overviewCacheChannel = "syncer:cache:dashboard_overview"
|
||||
defaultChartDays = 7
|
||||
cacheLoadTimeout = 30 * time.Second
|
||||
)
|
||||
|
||||
var (
|
||||
// 图表颜色调色板
|
||||
chartColorPalette = []string{
|
||||
"#FF6384", "#36A2EB", "#FFCE56", "#4BC0C0",
|
||||
"#9966FF", "#FF9F40", "#C9CBCF", "#4D5360",
|
||||
}
|
||||
)
|
||||
|
||||
type DashboardQueryService struct {
|
||||
db *gorm.DB
|
||||
store store.Store
|
||||
overviewSyncer *syncer.CacheSyncer[*models.DashboardStatsResponse]
|
||||
logger *logrus.Entry
|
||||
stopChan chan struct{}
|
||||
|
||||
stopChan chan struct{}
|
||||
wg sync.WaitGroup
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
|
||||
// 统计指标
|
||||
queryCount atomic.Uint64
|
||||
cacheHits atomic.Uint64
|
||||
cacheMisses atomic.Uint64
|
||||
overviewLoadCount atomic.Uint64
|
||||
lastQueryTime time.Time
|
||||
lastQueryMutex sync.RWMutex
|
||||
}
|
||||
|
||||
func NewDashboardQueryService(db *gorm.DB, s store.Store, logger *logrus.Logger) (*DashboardQueryService, error) {
|
||||
qs := &DashboardQueryService{
|
||||
db: db,
|
||||
store: s,
|
||||
logger: logger.WithField("component", "DashboardQueryService"),
|
||||
stopChan: make(chan struct{}),
|
||||
func NewDashboardQueryService(
|
||||
db *gorm.DB,
|
||||
s store.Store,
|
||||
logger *logrus.Logger,
|
||||
) (*DashboardQueryService, error) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
|
||||
service := &DashboardQueryService{
|
||||
db: db,
|
||||
store: s,
|
||||
logger: logger.WithField("component", "DashboardQuery📈"),
|
||||
stopChan: make(chan struct{}),
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
lastQueryTime: time.Now(),
|
||||
}
|
||||
|
||||
loader := qs.loadOverviewData
|
||||
overviewSyncer, err := syncer.NewCacheSyncer(loader, s, overviewCacheChannel)
|
||||
// 创建 CacheSyncer
|
||||
overviewSyncer, err := syncer.NewCacheSyncer(
|
||||
service.loadOverviewData,
|
||||
s,
|
||||
overviewCacheChannel,
|
||||
logger,
|
||||
)
|
||||
if err != nil {
|
||||
cancel()
|
||||
return nil, fmt.Errorf("failed to create overview cache syncer: %w", err)
|
||||
}
|
||||
qs.overviewSyncer = overviewSyncer
|
||||
return qs, nil
|
||||
service.overviewSyncer = overviewSyncer
|
||||
|
||||
return service, nil
|
||||
}
|
||||
|
||||
func (s *DashboardQueryService) Start() {
|
||||
s.wg.Add(2)
|
||||
go s.eventListener()
|
||||
s.logger.Info("DashboardQueryService started and listening for invalidation events.")
|
||||
go s.metricsReporter()
|
||||
|
||||
s.logger.Info("DashboardQueryService started")
|
||||
}
|
||||
|
||||
func (s *DashboardQueryService) Stop() {
|
||||
s.logger.Info("DashboardQueryService stopping...")
|
||||
close(s.stopChan)
|
||||
s.logger.Info("DashboardQueryService and its CacheSyncer have been stopped.")
|
||||
s.cancel()
|
||||
s.wg.Wait()
|
||||
|
||||
// 输出最终统计
|
||||
s.logger.WithFields(logrus.Fields{
|
||||
"total_queries": s.queryCount.Load(),
|
||||
"cache_hits": s.cacheHits.Load(),
|
||||
"cache_misses": s.cacheMisses.Load(),
|
||||
"overview_loads": s.overviewLoadCount.Load(),
|
||||
}).Info("DashboardQueryService stopped")
|
||||
}
|
||||
|
||||
// ==================== 核心查询方法 ====================
|
||||
|
||||
// GetDashboardOverviewData 获取仪表盘概览数据(带缓存)
|
||||
func (s *DashboardQueryService) GetDashboardOverviewData() (*models.DashboardStatsResponse, error) {
|
||||
s.queryCount.Add(1)
|
||||
|
||||
cachedDataPtr := s.overviewSyncer.Get()
|
||||
if cachedDataPtr == nil {
|
||||
s.cacheMisses.Add(1)
|
||||
s.logger.Warn("Overview cache is empty, attempting to load...")
|
||||
|
||||
// 触发立即加载
|
||||
if err := s.overviewSyncer.Invalidate(); err != nil {
|
||||
return nil, fmt.Errorf("failed to trigger cache reload: %w", err)
|
||||
}
|
||||
|
||||
// 等待加载完成(最多30秒)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), cacheLoadTimeout)
|
||||
defer cancel()
|
||||
|
||||
ticker := time.NewTicker(100 * time.Millisecond)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
if data := s.overviewSyncer.Get(); data != nil {
|
||||
s.cacheHits.Add(1)
|
||||
return data, nil
|
||||
}
|
||||
case <-ctx.Done():
|
||||
return nil, fmt.Errorf("timeout waiting for overview cache to load")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
s.cacheHits.Add(1)
|
||||
return cachedDataPtr, nil
|
||||
}
|
||||
|
||||
// InvalidateOverviewCache 手动失效概览缓存
|
||||
func (s *DashboardQueryService) InvalidateOverviewCache() error {
|
||||
s.logger.Info("Manually invalidating overview cache")
|
||||
return s.overviewSyncer.Invalidate()
|
||||
}
|
||||
|
||||
// GetGroupStats 获取指定分组的统计数据
|
||||
func (s *DashboardQueryService) GetGroupStats(ctx context.Context, groupID uint) (map[string]any, error) {
|
||||
s.queryCount.Add(1)
|
||||
s.updateLastQueryTime()
|
||||
|
||||
start := time.Now()
|
||||
|
||||
// 1. 从 Redis 获取 Key 统计
|
||||
statsKey := fmt.Sprintf("stats:group:%d", groupID)
|
||||
keyStatsMap, err := s.store.HGetAll(ctx, statsKey)
|
||||
if err != nil {
|
||||
s.logger.WithError(err).Errorf("Failed to get key stats from cache for group %d", groupID)
|
||||
s.logger.WithError(err).Errorf("Failed to get key stats for group %d", groupID)
|
||||
return nil, fmt.Errorf("failed to get key stats from cache: %w", err)
|
||||
}
|
||||
|
||||
keyStats := make(map[string]int64)
|
||||
for k, v := range keyStatsMap {
|
||||
val, _ := strconv.ParseInt(v, 10, 64)
|
||||
keyStats[k] = val
|
||||
}
|
||||
now := time.Now()
|
||||
|
||||
// 2. 查询请求统计(使用 UTC 时间)
|
||||
now := time.Now().UTC()
|
||||
oneHourAgo := now.Add(-1 * time.Hour)
|
||||
twentyFourHoursAgo := now.Add(-24 * time.Hour)
|
||||
|
||||
type requestStatsResult struct {
|
||||
TotalRequests int64
|
||||
SuccessRequests int64
|
||||
}
|
||||
|
||||
var last1Hour, last24Hours requestStatsResult
|
||||
s.db.WithContext(ctx).Model(&models.StatsHourly{}).
|
||||
Select("SUM(request_count) as total_requests, SUM(success_count) as success_requests").
|
||||
Where("group_id = ? AND time >= ?", groupID, oneHourAgo).
|
||||
Scan(&last1Hour)
|
||||
s.db.WithContext(ctx).Model(&models.StatsHourly{}).
|
||||
Select("SUM(request_count) as total_requests, SUM(success_count) as success_requests").
|
||||
Where("group_id = ? AND time >= ?", groupID, twentyFourHoursAgo).
|
||||
Scan(&last24Hours)
|
||||
failureRate1h := 0.0
|
||||
if last1Hour.TotalRequests > 0 {
|
||||
failureRate1h = float64(last1Hour.TotalRequests-last1Hour.SuccessRequests) / float64(last1Hour.TotalRequests) * 100
|
||||
}
|
||||
failureRate24h := 0.0
|
||||
if last24Hours.TotalRequests > 0 {
|
||||
failureRate24h = float64(last24Hours.TotalRequests-last24Hours.SuccessRequests) / float64(last24Hours.TotalRequests) * 100
|
||||
}
|
||||
last1HourStats := map[string]any{
|
||||
"total_requests": last1Hour.TotalRequests,
|
||||
"success_requests": last1Hour.SuccessRequests,
|
||||
"failure_rate": failureRate1h,
|
||||
}
|
||||
last24HoursStats := map[string]any{
|
||||
"total_requests": last24Hours.TotalRequests,
|
||||
"success_requests": last24Hours.SuccessRequests,
|
||||
"failure_rate": failureRate24h,
|
||||
|
||||
// 并发查询优化
|
||||
var wg sync.WaitGroup
|
||||
errChan := make(chan error, 2)
|
||||
|
||||
wg.Add(2)
|
||||
|
||||
// 查询最近1小时
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
if err := s.db.WithContext(ctx).Model(&models.StatsHourly{}).
|
||||
Select("COALESCE(SUM(request_count), 0) as total_requests, COALESCE(SUM(success_count), 0) as success_requests").
|
||||
Where("group_id = ? AND time >= ?", groupID, oneHourAgo).
|
||||
Scan(&last1Hour).Error; err != nil {
|
||||
errChan <- fmt.Errorf("failed to query 1h stats: %w", err)
|
||||
}
|
||||
}()
|
||||
|
||||
// 查询最近24小时
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
if err := s.db.WithContext(ctx).Model(&models.StatsHourly{}).
|
||||
Select("COALESCE(SUM(request_count), 0) as total_requests, COALESCE(SUM(success_count), 0) as success_requests").
|
||||
Where("group_id = ? AND time >= ?", groupID, twentyFourHoursAgo).
|
||||
Scan(&last24Hours).Error; err != nil {
|
||||
errChan <- fmt.Errorf("failed to query 24h stats: %w", err)
|
||||
}
|
||||
}()
|
||||
|
||||
wg.Wait()
|
||||
close(errChan)
|
||||
|
||||
// 检查错误
|
||||
for err := range errChan {
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
// 3. 计算失败率
|
||||
failureRate1h := s.calculateFailureRate(last1Hour.TotalRequests, last1Hour.SuccessRequests)
|
||||
failureRate24h := s.calculateFailureRate(last24Hours.TotalRequests, last24Hours.SuccessRequests)
|
||||
|
||||
result := map[string]any{
|
||||
"key_stats": keyStats,
|
||||
"last_1_hour": last1HourStats,
|
||||
"last_24_hours": last24HoursStats,
|
||||
"key_stats": keyStats,
|
||||
"last_1_hour": map[string]any{
|
||||
"total_requests": last1Hour.TotalRequests,
|
||||
"success_requests": last1Hour.SuccessRequests,
|
||||
"failed_requests": last1Hour.TotalRequests - last1Hour.SuccessRequests,
|
||||
"failure_rate": failureRate1h,
|
||||
},
|
||||
"last_24_hours": map[string]any{
|
||||
"total_requests": last24Hours.TotalRequests,
|
||||
"success_requests": last24Hours.SuccessRequests,
|
||||
"failed_requests": last24Hours.TotalRequests - last24Hours.SuccessRequests,
|
||||
"failure_rate": failureRate24h,
|
||||
},
|
||||
}
|
||||
|
||||
duration := time.Since(start)
|
||||
s.logger.WithFields(logrus.Fields{
|
||||
"group_id": groupID,
|
||||
"duration": duration,
|
||||
}).Debug("Group stats query completed")
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (s *DashboardQueryService) eventListener() {
|
||||
ctx := context.Background()
|
||||
keyStatusSub, _ := s.store.Subscribe(ctx, models.TopicKeyStatusChanged)
|
||||
upstreamStatusSub, _ := s.store.Subscribe(ctx, models.TopicUpstreamHealthChanged)
|
||||
defer keyStatusSub.Close()
|
||||
defer upstreamStatusSub.Close()
|
||||
for {
|
||||
select {
|
||||
case <-keyStatusSub.Channel():
|
||||
s.logger.Info("Received key status changed event, invalidating overview cache...")
|
||||
_ = s.InvalidateOverviewCache()
|
||||
case <-upstreamStatusSub.Channel():
|
||||
s.logger.Info("Received upstream status changed event, invalidating overview cache...")
|
||||
_ = s.InvalidateOverviewCache()
|
||||
case <-s.stopChan:
|
||||
s.logger.Info("Stopping dashboard event listener.")
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *DashboardQueryService) GetDashboardOverviewData() (*models.DashboardStatsResponse, error) {
|
||||
cachedDataPtr := s.overviewSyncer.Get()
|
||||
if cachedDataPtr == nil {
|
||||
return &models.DashboardStatsResponse{}, fmt.Errorf("overview cache is not available or still syncing")
|
||||
}
|
||||
return cachedDataPtr, nil
|
||||
}
|
||||
|
||||
func (s *DashboardQueryService) InvalidateOverviewCache() error {
|
||||
return s.overviewSyncer.Invalidate()
|
||||
}
|
||||
|
||||
// QueryHistoricalChart 查询历史图表数据
|
||||
func (s *DashboardQueryService) QueryHistoricalChart(ctx context.Context, groupID *uint) (*models.ChartData, error) {
|
||||
s.queryCount.Add(1)
|
||||
s.updateLastQueryTime()
|
||||
|
||||
start := time.Now()
|
||||
|
||||
type ChartPoint struct {
|
||||
TimeLabel string `gorm:"column:time_label"`
|
||||
ModelName string `gorm:"column:model_name"`
|
||||
TotalRequests int64 `gorm:"column:total_requests"`
|
||||
}
|
||||
sevenDaysAgo := time.Now().Add(-24 * 7 * time.Hour).Truncate(time.Hour)
|
||||
|
||||
// 查询最近7天数据(使用 UTC)
|
||||
sevenDaysAgo := time.Now().UTC().AddDate(0, 0, -defaultChartDays).Truncate(time.Hour)
|
||||
|
||||
// 根据数据库类型构建时间格式化子句
|
||||
sqlFormat, goFormat := s.buildTimeFormatSelectClause()
|
||||
selectClause := fmt.Sprintf("%s as time_label, model_name, SUM(request_count) as total_requests", sqlFormat)
|
||||
query := s.db.WithContext(ctx).Model(&models.StatsHourly{}).Select(selectClause).Where("time >= ?", sevenDaysAgo).Group("time_label, model_name").Order("time_label ASC")
|
||||
selectClause := fmt.Sprintf(
|
||||
"%s as time_label, model_name, COALESCE(SUM(request_count), 0) as total_requests",
|
||||
sqlFormat,
|
||||
)
|
||||
|
||||
// 构建查询
|
||||
query := s.db.WithContext(ctx).
|
||||
Model(&models.StatsHourly{}).
|
||||
Select(selectClause).
|
||||
Where("time >= ?", sevenDaysAgo).
|
||||
Group("time_label, model_name").
|
||||
Order("time_label ASC")
|
||||
|
||||
if groupID != nil && *groupID > 0 {
|
||||
query = query.Where("group_id = ?", *groupID)
|
||||
}
|
||||
|
||||
var points []ChartPoint
|
||||
if err := query.Find(&points).Error; err != nil {
|
||||
return nil, err
|
||||
return nil, fmt.Errorf("failed to query chart data: %w", err)
|
||||
}
|
||||
|
||||
// 构建数据集
|
||||
datasets := make(map[string]map[string]int64)
|
||||
for _, p := range points {
|
||||
if _, ok := datasets[p.ModelName]; !ok {
|
||||
@@ -163,32 +302,99 @@ func (s *DashboardQueryService) QueryHistoricalChart(ctx context.Context, groupI
|
||||
}
|
||||
datasets[p.ModelName][p.TimeLabel] = p.TotalRequests
|
||||
}
|
||||
|
||||
// 生成时间标签(按小时)
|
||||
var labels []string
|
||||
for t := sevenDaysAgo; t.Before(time.Now()); t = t.Add(time.Hour) {
|
||||
for t := sevenDaysAgo; t.Before(time.Now().UTC()); t = t.Add(time.Hour) {
|
||||
labels = append(labels, t.Format(goFormat))
|
||||
}
|
||||
chartData := &models.ChartData{Labels: labels, Datasets: make([]models.ChartDataset, 0)}
|
||||
colorPalette := []string{"#FF6384", "#36A2EB", "#FFCE56", "#4BC0C0", "#9966FF", "#FF9F40"}
|
||||
|
||||
// 构建图表数据
|
||||
chartData := &models.ChartData{
|
||||
Labels: labels,
|
||||
Datasets: make([]models.ChartDataset, 0, len(datasets)),
|
||||
}
|
||||
|
||||
colorIndex := 0
|
||||
for modelName, dataPoints := range datasets {
|
||||
dataArray := make([]int64, len(labels))
|
||||
for i, label := range labels {
|
||||
dataArray[i] = dataPoints[label]
|
||||
}
|
||||
|
||||
chartData.Datasets = append(chartData.Datasets, models.ChartDataset{
|
||||
Label: modelName,
|
||||
Data: dataArray,
|
||||
Color: colorPalette[colorIndex%len(colorPalette)],
|
||||
Color: chartColorPalette[colorIndex%len(chartColorPalette)],
|
||||
})
|
||||
colorIndex++
|
||||
}
|
||||
|
||||
duration := time.Since(start)
|
||||
s.logger.WithFields(logrus.Fields{
|
||||
"group_id": groupID,
|
||||
"points": len(points),
|
||||
"datasets": len(chartData.Datasets),
|
||||
"duration": duration,
|
||||
}).Debug("Historical chart query completed")
|
||||
|
||||
return chartData, nil
|
||||
}
|
||||
|
||||
// GetRequestStatsForPeriod 获取指定时间段的请求统计
|
||||
func (s *DashboardQueryService) GetRequestStatsForPeriod(ctx context.Context, period string) (gin.H, error) {
|
||||
s.queryCount.Add(1)
|
||||
s.updateLastQueryTime()
|
||||
|
||||
var startTime time.Time
|
||||
now := time.Now().UTC()
|
||||
|
||||
switch period {
|
||||
case "1m":
|
||||
startTime = now.Add(-1 * time.Minute)
|
||||
case "1h":
|
||||
startTime = now.Add(-1 * time.Hour)
|
||||
case "1d":
|
||||
year, month, day := now.Date()
|
||||
startTime = time.Date(year, month, day, 0, 0, 0, 0, time.UTC)
|
||||
default:
|
||||
return nil, fmt.Errorf("invalid period specified: %s (must be 1m, 1h, or 1d)", period)
|
||||
}
|
||||
|
||||
var result struct {
|
||||
Total int64
|
||||
Success int64
|
||||
}
|
||||
|
||||
err := s.db.WithContext(ctx).Model(&models.RequestLog{}).
|
||||
Select("COUNT(*) as total, SUM(CASE WHEN is_success = true THEN 1 ELSE 0 END) as success").
|
||||
Where("request_time >= ?", startTime).
|
||||
Scan(&result).Error
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to query request stats: %w", err)
|
||||
}
|
||||
|
||||
return gin.H{
|
||||
"period": period,
|
||||
"total": result.Total,
|
||||
"success": result.Success,
|
||||
"failure": result.Total - result.Success,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// ==================== 内部方法 ====================
|
||||
|
||||
// loadOverviewData 加载仪表盘概览数据(供 CacheSyncer 调用)
|
||||
func (s *DashboardQueryService) loadOverviewData() (*models.DashboardStatsResponse, error) {
|
||||
ctx := context.Background()
|
||||
s.logger.Info("[CacheSyncer] Starting to load overview data from database...")
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 60*time.Second)
|
||||
defer cancel()
|
||||
|
||||
s.overviewLoadCount.Add(1)
|
||||
startTime := time.Now()
|
||||
|
||||
s.logger.Info("Starting to load dashboard overview data...")
|
||||
|
||||
resp := &models.DashboardStatsResponse{
|
||||
KeyStatusCount: make(map[models.APIKeyStatus]int64),
|
||||
MasterStatusCount: make(map[models.MasterAPIKeyStatus]int64),
|
||||
@@ -200,108 +406,391 @@ func (s *DashboardQueryService) loadOverviewData() (*models.DashboardStatsRespon
|
||||
RequestCounts: make(map[string]int64),
|
||||
}
|
||||
|
||||
var loadErr error
|
||||
var wg sync.WaitGroup
|
||||
errChan := make(chan error, 10)
|
||||
|
||||
// 1. 并发加载 Key 映射状态统计
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
if err := s.loadMappingStatusStats(ctx, resp); err != nil {
|
||||
errChan <- fmt.Errorf("mapping stats: %w", err)
|
||||
}
|
||||
}()
|
||||
|
||||
// 2. 并发加载 Master Key 状态统计
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
if err := s.loadMasterStatusStats(ctx, resp); err != nil {
|
||||
errChan <- fmt.Errorf("master stats: %w", err)
|
||||
}
|
||||
}()
|
||||
|
||||
// 3. 并发加载请求统计
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
if err := s.loadRequestCounts(ctx, resp); err != nil {
|
||||
errChan <- fmt.Errorf("request counts: %w", err)
|
||||
}
|
||||
}()
|
||||
|
||||
// 4. 并发加载上游健康状态
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
if err := s.loadUpstreamHealth(ctx, resp); err != nil {
|
||||
// 上游健康状态失败不阻塞整体加载
|
||||
s.logger.WithError(err).Warn("Failed to load upstream health status")
|
||||
}
|
||||
}()
|
||||
|
||||
// 等待所有加载任务完成
|
||||
wg.Wait()
|
||||
close(errChan)
|
||||
|
||||
// 收集错误
|
||||
for err := range errChan {
|
||||
if err != nil {
|
||||
loadErr = err
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if loadErr != nil {
|
||||
s.logger.WithError(loadErr).Error("Failed to load overview data")
|
||||
return nil, loadErr
|
||||
}
|
||||
|
||||
duration := time.Since(startTime)
|
||||
s.logger.WithFields(logrus.Fields{
|
||||
"duration": duration,
|
||||
"total_keys": resp.KeyCount.Value,
|
||||
"requests_1d": resp.RequestCounts["1d"],
|
||||
"upstreams": len(resp.UpstreamHealthStatus),
|
||||
}).Info("Successfully loaded dashboard overview data")
|
||||
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
// loadMappingStatusStats 加载 Key 映射状态统计
|
||||
func (s *DashboardQueryService) loadMappingStatusStats(ctx context.Context, resp *models.DashboardStatsResponse) error {
|
||||
type MappingStatusResult struct {
|
||||
Status models.APIKeyStatus
|
||||
Count int64
|
||||
}
|
||||
var mappingStatusResults []MappingStatusResult
|
||||
if err := s.db.WithContext(ctx).Model(&models.GroupAPIKeyMapping{}).Select("status, count(*) as count").Group("status").Find(&mappingStatusResults).Error; err != nil {
|
||||
return nil, fmt.Errorf("failed to query mapping status stats: %w", err)
|
||||
|
||||
var results []MappingStatusResult
|
||||
if err := s.db.WithContext(ctx).
|
||||
Model(&models.GroupAPIKeyMapping{}).
|
||||
Select("status, COUNT(*) as count").
|
||||
Group("status").
|
||||
Find(&results).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
for _, res := range mappingStatusResults {
|
||||
|
||||
for _, res := range results {
|
||||
resp.KeyStatusCount[res.Status] = res.Count
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// loadMasterStatusStats 加载 Master Key 状态统计
|
||||
func (s *DashboardQueryService) loadMasterStatusStats(ctx context.Context, resp *models.DashboardStatsResponse) error {
|
||||
type MasterStatusResult struct {
|
||||
Status models.MasterAPIKeyStatus
|
||||
Count int64
|
||||
}
|
||||
var masterStatusResults []MasterStatusResult
|
||||
if err := s.db.WithContext(ctx).Model(&models.APIKey{}).Select("master_status as status, count(*) as count").Group("master_status").Find(&masterStatusResults).Error; err != nil {
|
||||
return nil, fmt.Errorf("failed to query master status stats: %w", err)
|
||||
|
||||
var results []MasterStatusResult
|
||||
if err := s.db.WithContext(ctx).
|
||||
Model(&models.APIKey{}).
|
||||
Select("master_status as status, COUNT(*) as count").
|
||||
Group("master_status").
|
||||
Find(&results).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var totalKeys, invalidKeys int64
|
||||
for _, res := range masterStatusResults {
|
||||
for _, res := range results {
|
||||
resp.MasterStatusCount[res.Status] = res.Count
|
||||
totalKeys += res.Count
|
||||
if res.Status != models.MasterStatusActive {
|
||||
invalidKeys += res.Count
|
||||
}
|
||||
}
|
||||
resp.KeyCount = models.StatCard{Value: float64(totalKeys), SubValue: invalidKeys, SubValueTip: "非活跃身份密钥数"}
|
||||
|
||||
now := time.Now()
|
||||
resp.KeyCount = models.StatCard{
|
||||
Value: float64(totalKeys),
|
||||
SubValue: invalidKeys,
|
||||
SubValueTip: "非活跃身份密钥数",
|
||||
}
|
||||
|
||||
var count1m, count1h, count1d int64
|
||||
s.db.WithContext(ctx).Model(&models.RequestLog{}).Where("request_time >= ?", now.Add(-1*time.Minute)).Count(&count1m)
|
||||
s.db.WithContext(ctx).Model(&models.RequestLog{}).Where("request_time >= ?", now.Add(-1*time.Hour)).Count(&count1h)
|
||||
year, month, day := now.UTC().Date()
|
||||
return nil
|
||||
}
|
||||
|
||||
// loadRequestCounts 加载请求计数统计
|
||||
func (s *DashboardQueryService) loadRequestCounts(ctx context.Context, resp *models.DashboardStatsResponse) error {
|
||||
now := time.Now().UTC()
|
||||
|
||||
// 使用 RequestLog 表查询短期数据
|
||||
var count1m, count1h int64
|
||||
|
||||
// 最近1分钟
|
||||
if err := s.db.WithContext(ctx).
|
||||
Model(&models.RequestLog{}).
|
||||
Where("request_time >= ?", now.Add(-1*time.Minute)).
|
||||
Count(&count1m).Error; err != nil {
|
||||
return fmt.Errorf("1m count: %w", err)
|
||||
}
|
||||
|
||||
// 最近1小时
|
||||
if err := s.db.WithContext(ctx).
|
||||
Model(&models.RequestLog{}).
|
||||
Where("request_time >= ?", now.Add(-1*time.Hour)).
|
||||
Count(&count1h).Error; err != nil {
|
||||
return fmt.Errorf("1h count: %w", err)
|
||||
}
|
||||
|
||||
// 今天(UTC)
|
||||
year, month, day := now.Date()
|
||||
startOfDay := time.Date(year, month, day, 0, 0, 0, 0, time.UTC)
|
||||
s.db.WithContext(ctx).Model(&models.RequestLog{}).Where("request_time >= ?", startOfDay).Count(&count1d)
|
||||
|
||||
var count1d int64
|
||||
if err := s.db.WithContext(ctx).
|
||||
Model(&models.RequestLog{}).
|
||||
Where("request_time >= ?", startOfDay).
|
||||
Count(&count1d).Error; err != nil {
|
||||
return fmt.Errorf("1d count: %w", err)
|
||||
}
|
||||
|
||||
// 最近30天使用聚合表
|
||||
var count30d int64
|
||||
s.db.WithContext(ctx).Model(&models.StatsHourly{}).Where("time >= ?", now.AddDate(0, 0, -30)).Select("COALESCE(SUM(request_count), 0)").Scan(&count30d)
|
||||
if err := s.db.WithContext(ctx).
|
||||
Model(&models.StatsHourly{}).
|
||||
Where("time >= ?", now.AddDate(0, 0, -30)).
|
||||
Select("COALESCE(SUM(request_count), 0)").
|
||||
Scan(&count30d).Error; err != nil {
|
||||
return fmt.Errorf("30d count: %w", err)
|
||||
}
|
||||
|
||||
resp.RequestCounts["1m"] = count1m
|
||||
resp.RequestCounts["1h"] = count1h
|
||||
resp.RequestCounts["1d"] = count1d
|
||||
resp.RequestCounts["30d"] = count30d
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// loadUpstreamHealth 加载上游健康状态
|
||||
func (s *DashboardQueryService) loadUpstreamHealth(ctx context.Context, resp *models.DashboardStatsResponse) error {
|
||||
var upstreams []*models.UpstreamEndpoint
|
||||
if err := s.db.WithContext(ctx).Find(&upstreams).Error; err != nil {
|
||||
s.logger.WithError(err).Warn("Failed to load upstream statuses for dashboard.")
|
||||
} else {
|
||||
for _, u := range upstreams {
|
||||
resp.UpstreamHealthStatus[u.URL] = u.Status
|
||||
return err
|
||||
}
|
||||
|
||||
for _, u := range upstreams {
|
||||
resp.UpstreamHealthStatus[u.URL] = u.Status
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ==================== 事件监听 ====================
|
||||
|
||||
// eventListener 监听缓存失效事件
|
||||
func (s *DashboardQueryService) eventListener() {
|
||||
defer s.wg.Done()
|
||||
|
||||
// 订阅事件
|
||||
keyStatusSub, err1 := s.store.Subscribe(s.ctx, models.TopicKeyStatusChanged)
|
||||
upstreamStatusSub, err2 := s.store.Subscribe(s.ctx, models.TopicUpstreamHealthChanged)
|
||||
|
||||
// 错误处理
|
||||
if err1 != nil {
|
||||
s.logger.WithError(err1).Error("Failed to subscribe to key status events")
|
||||
keyStatusSub = nil
|
||||
}
|
||||
if err2 != nil {
|
||||
s.logger.WithError(err2).Error("Failed to subscribe to upstream status events")
|
||||
upstreamStatusSub = nil
|
||||
}
|
||||
|
||||
// 如果全部失败,直接返回
|
||||
if keyStatusSub == nil && upstreamStatusSub == nil {
|
||||
s.logger.Error("All event subscriptions failed, listener disabled")
|
||||
return
|
||||
}
|
||||
|
||||
// 安全关闭订阅
|
||||
defer func() {
|
||||
if keyStatusSub != nil {
|
||||
if err := keyStatusSub.Close(); err != nil {
|
||||
s.logger.WithError(err).Warn("Failed to close key status subscription")
|
||||
}
|
||||
}
|
||||
if upstreamStatusSub != nil {
|
||||
if err := upstreamStatusSub.Close(); err != nil {
|
||||
s.logger.WithError(err).Warn("Failed to close upstream status subscription")
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
s.logger.WithFields(logrus.Fields{
|
||||
"key_status_sub": keyStatusSub != nil,
|
||||
"upstream_status_sub": upstreamStatusSub != nil,
|
||||
}).Info("Event listener started")
|
||||
|
||||
neverReady := make(chan *store.Message)
|
||||
close(neverReady) // 立即关闭,确保永远不会阻塞
|
||||
|
||||
for {
|
||||
// 动态选择有效的 channel
|
||||
var keyStatusChan <-chan *store.Message = neverReady
|
||||
if keyStatusSub != nil {
|
||||
keyStatusChan = keyStatusSub.Channel()
|
||||
}
|
||||
|
||||
var upstreamStatusChan <-chan *store.Message = neverReady
|
||||
if upstreamStatusSub != nil {
|
||||
upstreamStatusChan = upstreamStatusSub.Channel()
|
||||
}
|
||||
|
||||
select {
|
||||
case _, ok := <-keyStatusChan:
|
||||
if !ok {
|
||||
s.logger.Warn("Key status channel closed")
|
||||
keyStatusSub = nil
|
||||
continue
|
||||
}
|
||||
s.logger.Debug("Received key status changed event")
|
||||
if err := s.InvalidateOverviewCache(); err != nil {
|
||||
s.logger.WithError(err).Warn("Failed to invalidate cache on key status change")
|
||||
}
|
||||
|
||||
case _, ok := <-upstreamStatusChan:
|
||||
if !ok {
|
||||
s.logger.Warn("Upstream status channel closed")
|
||||
upstreamStatusSub = nil
|
||||
continue
|
||||
}
|
||||
s.logger.Debug("Received upstream status changed event")
|
||||
if err := s.InvalidateOverviewCache(); err != nil {
|
||||
s.logger.WithError(err).Warn("Failed to invalidate cache on upstream status change")
|
||||
}
|
||||
|
||||
case <-s.stopChan:
|
||||
s.logger.Info("Event listener stopping (stopChan)")
|
||||
return
|
||||
|
||||
case <-s.ctx.Done():
|
||||
s.logger.Info("Event listener stopping (context cancelled)")
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
duration := time.Since(startTime)
|
||||
s.logger.Infof("[CacheSyncer] Successfully finished loading overview data in %s.", duration)
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
func (s *DashboardQueryService) GetRequestStatsForPeriod(ctx context.Context, period string) (gin.H, error) {
|
||||
var startTime time.Time
|
||||
now := time.Now()
|
||||
switch period {
|
||||
case "1m":
|
||||
startTime = now.Add(-1 * time.Minute)
|
||||
case "1h":
|
||||
startTime = now.Add(-1 * time.Hour)
|
||||
case "1d":
|
||||
year, month, day := now.UTC().Date()
|
||||
startTime = time.Date(year, month, day, 0, 0, 0, 0, time.UTC)
|
||||
default:
|
||||
return nil, fmt.Errorf("invalid period specified: %s", period)
|
||||
}
|
||||
var result struct {
|
||||
Total int64
|
||||
Success int64
|
||||
}
|
||||
// ==================== 监控指标 ====================
|
||||
|
||||
err := s.db.WithContext(ctx).Model(&models.RequestLog{}).
|
||||
Select("count(*) as total, sum(case when is_success = true then 1 else 0 end) as success").
|
||||
Where("request_time >= ?", startTime).
|
||||
Scan(&result).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
// metricsReporter 定期输出统计信息
|
||||
func (s *DashboardQueryService) metricsReporter() {
|
||||
defer s.wg.Done()
|
||||
|
||||
ticker := time.NewTicker(5 * time.Minute)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
s.reportMetrics()
|
||||
case <-s.stopChan:
|
||||
return
|
||||
case <-s.ctx.Done():
|
||||
return
|
||||
}
|
||||
}
|
||||
return gin.H{
|
||||
"total": result.Total,
|
||||
"success": result.Success,
|
||||
"failure": result.Total - result.Success,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *DashboardQueryService) reportMetrics() {
|
||||
s.lastQueryMutex.RLock()
|
||||
lastQuery := s.lastQueryTime
|
||||
s.lastQueryMutex.RUnlock()
|
||||
|
||||
totalQueries := s.queryCount.Load()
|
||||
hits := s.cacheHits.Load()
|
||||
misses := s.cacheMisses.Load()
|
||||
|
||||
var cacheHitRate float64
|
||||
if hits+misses > 0 {
|
||||
cacheHitRate = float64(hits) / float64(hits+misses) * 100
|
||||
}
|
||||
|
||||
s.logger.WithFields(logrus.Fields{
|
||||
"total_queries": totalQueries,
|
||||
"cache_hits": hits,
|
||||
"cache_misses": misses,
|
||||
"cache_hit_rate": fmt.Sprintf("%.2f%%", cacheHitRate),
|
||||
"overview_loads": s.overviewLoadCount.Load(),
|
||||
"last_query_ago": time.Since(lastQuery).Round(time.Second),
|
||||
}).Info("DashboardQuery metrics")
|
||||
}
|
||||
|
||||
// GetMetrics 返回当前统计指标(供监控使用)
|
||||
func (s *DashboardQueryService) GetMetrics() map[string]interface{} {
|
||||
s.lastQueryMutex.RLock()
|
||||
lastQuery := s.lastQueryTime
|
||||
s.lastQueryMutex.RUnlock()
|
||||
|
||||
hits := s.cacheHits.Load()
|
||||
misses := s.cacheMisses.Load()
|
||||
|
||||
var cacheHitRate float64
|
||||
if hits+misses > 0 {
|
||||
cacheHitRate = float64(hits) / float64(hits+misses) * 100
|
||||
}
|
||||
|
||||
return map[string]interface{}{
|
||||
"total_queries": s.queryCount.Load(),
|
||||
"cache_hits": hits,
|
||||
"cache_misses": misses,
|
||||
"cache_hit_rate": cacheHitRate,
|
||||
"overview_loads": s.overviewLoadCount.Load(),
|
||||
"last_query_ago": time.Since(lastQuery).Seconds(),
|
||||
}
|
||||
}
|
||||
|
||||
// ==================== 辅助方法 ====================
|
||||
|
||||
// calculateFailureRate 计算失败率
|
||||
func (s *DashboardQueryService) calculateFailureRate(total, success int64) float64 {
|
||||
if total == 0 {
|
||||
return 0.0
|
||||
}
|
||||
return float64(total-success) / float64(total) * 100
|
||||
}
|
||||
|
||||
// updateLastQueryTime 更新最后查询时间
|
||||
func (s *DashboardQueryService) updateLastQueryTime() {
|
||||
s.lastQueryMutex.Lock()
|
||||
s.lastQueryTime = time.Now()
|
||||
s.lastQueryMutex.Unlock()
|
||||
}
|
||||
|
||||
// buildTimeFormatSelectClause 根据数据库类型构建时间格式化子句
|
||||
func (s *DashboardQueryService) buildTimeFormatSelectClause() (string, string) {
|
||||
dialect := s.db.Dialector.Name()
|
||||
switch dialect {
|
||||
case "mysql":
|
||||
return "DATE_FORMAT(time, '%Y-%m-%d %H:00:00')", "2006-01-02 15:00:00"
|
||||
case "postgres":
|
||||
return "TO_CHAR(time, 'YYYY-MM-DD HH24:00:00')", "2006-01-02 15:00:00"
|
||||
case "sqlite":
|
||||
return "strftime('%Y-%m-%d %H:00:00', time)", "2006-01-02 15:00:00"
|
||||
default:
|
||||
s.logger.WithField("dialect", dialect).Warn("Unknown database dialect, using SQLite format")
|
||||
return "strftime('%Y-%m-%d %H:00:00', time)", "2006-01-02 15:00:00"
|
||||
}
|
||||
}
|
||||
|
||||
@@ -4,11 +4,13 @@ package service
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"gemini-balancer/internal/models"
|
||||
"gemini-balancer/internal/settings"
|
||||
"gemini-balancer/internal/store"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
"gorm.io/gorm"
|
||||
@@ -18,25 +20,47 @@ type DBLogWriterService struct {
|
||||
db *gorm.DB
|
||||
store store.Store
|
||||
logger *logrus.Entry
|
||||
logBuffer chan *models.RequestLog
|
||||
stopChan chan struct{}
|
||||
wg sync.WaitGroup
|
||||
SettingsManager *settings.SettingsManager
|
||||
settingsManager *settings.SettingsManager
|
||||
|
||||
logBuffer chan *models.RequestLog
|
||||
stopChan chan struct{}
|
||||
wg sync.WaitGroup
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
|
||||
// 统计指标
|
||||
totalReceived atomic.Uint64
|
||||
totalFlushed atomic.Uint64
|
||||
totalDropped atomic.Uint64
|
||||
flushCount atomic.Uint64
|
||||
lastFlushTime time.Time
|
||||
lastFlushMutex sync.RWMutex
|
||||
}
|
||||
|
||||
func NewDBLogWriterService(db *gorm.DB, s store.Store, settings *settings.SettingsManager, logger *logrus.Logger) *DBLogWriterService {
|
||||
cfg := settings.GetSettings()
|
||||
func NewDBLogWriterService(
|
||||
db *gorm.DB,
|
||||
s store.Store,
|
||||
settingsManager *settings.SettingsManager,
|
||||
logger *logrus.Logger,
|
||||
) *DBLogWriterService {
|
||||
cfg := settingsManager.GetSettings()
|
||||
bufferCapacity := cfg.LogBufferCapacity
|
||||
if bufferCapacity <= 0 {
|
||||
bufferCapacity = 1000
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
|
||||
return &DBLogWriterService{
|
||||
db: db,
|
||||
store: s,
|
||||
SettingsManager: settings,
|
||||
settingsManager: settingsManager,
|
||||
logger: logger.WithField("component", "DBLogWriter📝"),
|
||||
logBuffer: make(chan *models.RequestLog, bufferCapacity),
|
||||
stopChan: make(chan struct{}),
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
lastFlushTime: time.Now(),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -44,93 +68,276 @@ func (s *DBLogWriterService) Start() {
|
||||
s.wg.Add(2)
|
||||
go s.eventListenerLoop()
|
||||
go s.dbWriterLoop()
|
||||
s.logger.Info("DBLogWriterService started.")
|
||||
|
||||
// 定期输出统计信息
|
||||
s.wg.Add(1)
|
||||
go s.metricsReporter()
|
||||
|
||||
s.logger.WithFields(logrus.Fields{
|
||||
"buffer_capacity": cap(s.logBuffer),
|
||||
}).Info("DBLogWriterService started")
|
||||
}
|
||||
|
||||
func (s *DBLogWriterService) Stop() {
|
||||
s.logger.Info("DBLogWriterService stopping...")
|
||||
close(s.stopChan)
|
||||
s.cancel() // 取消上下文
|
||||
s.wg.Wait()
|
||||
s.logger.Info("DBLogWriterService stopped.")
|
||||
|
||||
// 输出最终统计
|
||||
s.logger.WithFields(logrus.Fields{
|
||||
"total_received": s.totalReceived.Load(),
|
||||
"total_flushed": s.totalFlushed.Load(),
|
||||
"total_dropped": s.totalDropped.Load(),
|
||||
"flush_count": s.flushCount.Load(),
|
||||
}).Info("DBLogWriterService stopped")
|
||||
}
|
||||
|
||||
// 事件监听循环
|
||||
func (s *DBLogWriterService) eventListenerLoop() {
|
||||
defer s.wg.Done()
|
||||
|
||||
ctx := context.Background()
|
||||
sub, err := s.store.Subscribe(ctx, models.TopicRequestFinished)
|
||||
sub, err := s.store.Subscribe(s.ctx, models.TopicRequestFinished)
|
||||
if err != nil {
|
||||
s.logger.Fatalf("Failed to subscribe to topic %s: %v", models.TopicRequestFinished, err)
|
||||
s.logger.WithError(err).Error("Failed to subscribe to request events, log writing disabled")
|
||||
return
|
||||
}
|
||||
defer sub.Close()
|
||||
defer func() {
|
||||
if err := sub.Close(); err != nil {
|
||||
s.logger.WithError(err).Warn("Failed to close subscription")
|
||||
}
|
||||
}()
|
||||
|
||||
s.logger.Info("Subscribed to request events for database logging.")
|
||||
s.logger.Info("Subscribed to request events for database logging")
|
||||
|
||||
for {
|
||||
select {
|
||||
case msg := <-sub.Channel():
|
||||
var event models.RequestFinishedEvent
|
||||
if err := json.Unmarshal(msg.Payload, &event); err != nil {
|
||||
s.logger.Errorf("Failed to unmarshal event for logging: %v", err)
|
||||
continue
|
||||
}
|
||||
select {
|
||||
case s.logBuffer <- &event.RequestLog:
|
||||
default:
|
||||
s.logger.Warn("Log buffer is full. A log message might be dropped.")
|
||||
}
|
||||
s.handleMessage(msg)
|
||||
|
||||
case <-s.stopChan:
|
||||
s.logger.Info("Event listener loop stopping.")
|
||||
s.logger.Info("Event listener loop stopping")
|
||||
close(s.logBuffer)
|
||||
return
|
||||
|
||||
case <-s.ctx.Done():
|
||||
s.logger.Info("Event listener context cancelled")
|
||||
close(s.logBuffer)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 处理单条消息
|
||||
func (s *DBLogWriterService) handleMessage(msg *store.Message) {
|
||||
var event models.RequestFinishedEvent
|
||||
if err := json.Unmarshal(msg.Payload, &event); err != nil {
|
||||
s.logger.WithError(err).Error("Failed to unmarshal request event")
|
||||
return
|
||||
}
|
||||
|
||||
s.totalReceived.Add(1)
|
||||
|
||||
select {
|
||||
case s.logBuffer <- &event.RequestLog:
|
||||
// 成功入队
|
||||
default:
|
||||
// 缓冲区满,丢弃日志
|
||||
dropped := s.totalDropped.Add(1)
|
||||
if dropped%100 == 1 { // 每100条丢失输出一次警告
|
||||
s.logger.WithFields(logrus.Fields{
|
||||
"total_dropped": dropped,
|
||||
"buffer_capacity": cap(s.logBuffer),
|
||||
"buffer_len": len(s.logBuffer),
|
||||
}).Warn("Log buffer full, messages being dropped")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 数据库写入循环
|
||||
func (s *DBLogWriterService) dbWriterLoop() {
|
||||
defer s.wg.Done()
|
||||
|
||||
cfg := s.SettingsManager.GetSettings()
|
||||
cfg := s.settingsManager.GetSettings()
|
||||
batchSize := cfg.LogFlushBatchSize
|
||||
if batchSize <= 0 {
|
||||
batchSize = 100
|
||||
}
|
||||
flushTimeout := time.Duration(cfg.LogFlushIntervalSeconds) * time.Second
|
||||
if flushTimeout <= 0 {
|
||||
flushTimeout = 5 * time.Second
|
||||
flushInterval := time.Duration(cfg.LogFlushIntervalSeconds) * time.Second
|
||||
if flushInterval <= 0 {
|
||||
flushInterval = 5 * time.Second
|
||||
}
|
||||
|
||||
s.logger.WithFields(logrus.Fields{
|
||||
"batch_size": batchSize,
|
||||
"flush_interval": flushInterval,
|
||||
}).Info("DB writer loop started")
|
||||
|
||||
batch := make([]*models.RequestLog, 0, batchSize)
|
||||
ticker := time.NewTicker(flushTimeout)
|
||||
ticker := time.NewTicker(flushInterval)
|
||||
defer ticker.Stop()
|
||||
|
||||
// 配置热更新检查(每分钟)
|
||||
configTicker := time.NewTicker(1 * time.Minute)
|
||||
defer configTicker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case logEntry, ok := <-s.logBuffer:
|
||||
if !ok {
|
||||
// 通道关闭,刷新剩余日志
|
||||
if len(batch) > 0 {
|
||||
s.flushBatch(batch)
|
||||
}
|
||||
s.logger.Info("DB writer loop finished.")
|
||||
s.logger.Info("DB writer loop finished")
|
||||
return
|
||||
}
|
||||
|
||||
batch = append(batch, logEntry)
|
||||
if len(batch) >= batchSize {
|
||||
s.flushBatch(batch)
|
||||
batch = make([]*models.RequestLog, 0, batchSize)
|
||||
}
|
||||
|
||||
case <-ticker.C:
|
||||
if len(batch) > 0 {
|
||||
s.flushBatch(batch)
|
||||
batch = make([]*models.RequestLog, 0, batchSize)
|
||||
}
|
||||
|
||||
case <-configTicker.C:
|
||||
// 热更新配置
|
||||
cfg := s.settingsManager.GetSettings()
|
||||
newBatchSize := cfg.LogFlushBatchSize
|
||||
if newBatchSize <= 0 {
|
||||
newBatchSize = 100
|
||||
}
|
||||
newFlushInterval := time.Duration(cfg.LogFlushIntervalSeconds) * time.Second
|
||||
if newFlushInterval <= 0 {
|
||||
newFlushInterval = 5 * time.Second
|
||||
}
|
||||
|
||||
if newBatchSize != batchSize {
|
||||
s.logger.WithFields(logrus.Fields{
|
||||
"old": batchSize,
|
||||
"new": newBatchSize,
|
||||
}).Info("Batch size updated")
|
||||
batchSize = newBatchSize
|
||||
if len(batch) >= batchSize {
|
||||
s.flushBatch(batch)
|
||||
batch = make([]*models.RequestLog, 0, batchSize)
|
||||
}
|
||||
}
|
||||
|
||||
if newFlushInterval != flushInterval {
|
||||
s.logger.WithFields(logrus.Fields{
|
||||
"old": flushInterval,
|
||||
"new": newFlushInterval,
|
||||
}).Info("Flush interval updated")
|
||||
flushInterval = newFlushInterval
|
||||
ticker.Reset(flushInterval)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 批量刷写到数据库
|
||||
func (s *DBLogWriterService) flushBatch(batch []*models.RequestLog) {
|
||||
if err := s.db.CreateInBatches(batch, len(batch)).Error; err != nil {
|
||||
s.logger.WithField("batch_size", len(batch)).WithError(err).Error("Failed to flush log batch to database.")
|
||||
if len(batch) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
start := time.Now()
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
defer cancel()
|
||||
|
||||
err := s.db.WithContext(ctx).CreateInBatches(batch, len(batch)).Error
|
||||
duration := time.Since(start)
|
||||
|
||||
s.lastFlushMutex.Lock()
|
||||
s.lastFlushTime = time.Now()
|
||||
s.lastFlushMutex.Unlock()
|
||||
|
||||
if err != nil {
|
||||
s.logger.WithFields(logrus.Fields{
|
||||
"batch_size": len(batch),
|
||||
"duration": duration,
|
||||
}).WithError(err).Error("Failed to flush log batch to database")
|
||||
} else {
|
||||
s.logger.Infof("Successfully flushed %d logs to database.", len(batch))
|
||||
flushed := s.totalFlushed.Add(uint64(len(batch)))
|
||||
flushCount := s.flushCount.Add(1)
|
||||
|
||||
// 只在慢写入或大批量时输出日志
|
||||
if duration > 1*time.Second || len(batch) > 500 {
|
||||
s.logger.WithFields(logrus.Fields{
|
||||
"batch_size": len(batch),
|
||||
"duration": duration,
|
||||
"total_flushed": flushed,
|
||||
"flush_count": flushCount,
|
||||
}).Info("Log batch flushed to database")
|
||||
} else {
|
||||
s.logger.WithFields(logrus.Fields{
|
||||
"batch_size": len(batch),
|
||||
"duration": duration,
|
||||
}).Debug("Log batch flushed to database")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 定期输出统计信息
|
||||
func (s *DBLogWriterService) metricsReporter() {
|
||||
defer s.wg.Done()
|
||||
|
||||
ticker := time.NewTicker(5 * time.Minute)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
s.reportMetrics()
|
||||
case <-s.stopChan:
|
||||
return
|
||||
case <-s.ctx.Done():
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *DBLogWriterService) reportMetrics() {
|
||||
s.lastFlushMutex.RLock()
|
||||
lastFlush := s.lastFlushTime
|
||||
s.lastFlushMutex.RUnlock()
|
||||
|
||||
received := s.totalReceived.Load()
|
||||
flushed := s.totalFlushed.Load()
|
||||
dropped := s.totalDropped.Load()
|
||||
pending := uint64(len(s.logBuffer))
|
||||
|
||||
s.logger.WithFields(logrus.Fields{
|
||||
"received": received,
|
||||
"flushed": flushed,
|
||||
"dropped": dropped,
|
||||
"pending": pending,
|
||||
"flush_count": s.flushCount.Load(),
|
||||
"last_flush": time.Since(lastFlush).Round(time.Second),
|
||||
"buffer_usage": float64(pending) / float64(cap(s.logBuffer)) * 100,
|
||||
"success_rate": float64(flushed) / float64(received) * 100,
|
||||
}).Info("DBLogWriter metrics")
|
||||
}
|
||||
|
||||
// GetMetrics 返回当前统计指标(供监控使用)
|
||||
func (s *DBLogWriterService) GetMetrics() map[string]interface{} {
|
||||
s.lastFlushMutex.RLock()
|
||||
lastFlush := s.lastFlushTime
|
||||
s.lastFlushMutex.RUnlock()
|
||||
|
||||
return map[string]interface{}{
|
||||
"total_received": s.totalReceived.Load(),
|
||||
"total_flushed": s.totalFlushed.Load(),
|
||||
"total_dropped": s.totalDropped.Load(),
|
||||
"flush_count": s.flushCount.Load(),
|
||||
"buffer_pending": len(s.logBuffer),
|
||||
"buffer_capacity": cap(s.logBuffer),
|
||||
"last_flush_ago": time.Since(lastFlush).Seconds(),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -334,7 +334,6 @@ func (gm *GroupManager) CloneKeyGroup(id uint) (*models.KeyGroup, error) {
|
||||
}
|
||||
func (gm *GroupManager) BuildOperationalConfig(group *models.KeyGroup) (*models.KeyGroupSettings, error) {
|
||||
globalSettings := gm.settingsManager.GetSettings()
|
||||
defaultModel := "gemini-1.5-flash"
|
||||
opConfig := &models.KeyGroupSettings{
|
||||
EnableKeyCheck: &globalSettings.EnableBaseKeyCheck,
|
||||
KeyCheckConcurrency: &globalSettings.BaseKeyCheckConcurrency,
|
||||
@@ -342,7 +341,7 @@ func (gm *GroupManager) BuildOperationalConfig(group *models.KeyGroup) (*models.
|
||||
KeyCheckEndpoint: &globalSettings.DefaultUpstreamURL,
|
||||
KeyBlacklistThreshold: &globalSettings.BlacklistThreshold,
|
||||
KeyCooldownMinutes: &globalSettings.KeyCooldownMinutes,
|
||||
KeyCheckModel: &defaultModel,
|
||||
KeyCheckModel: &globalSettings.BaseKeyCheckModel,
|
||||
MaxRetries: &globalSettings.MaxRetries,
|
||||
EnableSmartGateway: &globalSettings.EnableSmartGateway,
|
||||
}
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -23,6 +23,10 @@ const (
|
||||
TaskTypeHardDeleteKeys = "hard_delete_keys"
|
||||
TaskTypeRestoreKeys = "restore_keys"
|
||||
chunkSize = 500
|
||||
|
||||
// 任务超时时间常量化
|
||||
defaultTaskTimeout = 15 * time.Minute
|
||||
longTaskTimeout = time.Hour
|
||||
)
|
||||
|
||||
type KeyImportService struct {
|
||||
@@ -43,17 +47,19 @@ func NewKeyImportService(ts task.Reporter, kr repository.KeyRepository, s store.
|
||||
}
|
||||
}
|
||||
|
||||
// runTaskWithRecovery 统一的任务恢复包装器
|
||||
func (s *KeyImportService) runTaskWithRecovery(ctx context.Context, taskID string, resourceID string, taskFunc func()) {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
err := fmt.Errorf("panic recovered in task %s: %v", taskID, r)
|
||||
s.logger.Error(err)
|
||||
s.logger.WithField("task_id", taskID).Error(err)
|
||||
s.taskService.EndTaskByID(ctx, taskID, resourceID, nil, err)
|
||||
}
|
||||
}()
|
||||
taskFunc()
|
||||
}
|
||||
|
||||
// StartAddKeysTask 启动批量添加密钥任务
|
||||
func (s *KeyImportService) StartAddKeysTask(ctx context.Context, groupID uint, keysText string, validateOnImport bool) (*task.Status, error) {
|
||||
keys := utils.ParseKeysFromText(keysText)
|
||||
if len(keys) == 0 {
|
||||
@@ -61,260 +67,404 @@ func (s *KeyImportService) StartAddKeysTask(ctx context.Context, groupID uint, k
|
||||
}
|
||||
resourceID := fmt.Sprintf("group-%d", groupID)
|
||||
|
||||
taskStatus, err := s.taskService.StartTask(ctx, uint(groupID), TaskTypeAddKeysToGroup, resourceID, len(keys), 15*time.Minute)
|
||||
taskStatus, err := s.taskService.StartTask(ctx, uint(groupID), TaskTypeAddKeysToGroup, resourceID, len(keys), defaultTaskTimeout)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
go s.runTaskWithRecovery(context.Background(), taskStatus.ID, resourceID, func() {
|
||||
s.runAddKeysTask(context.Background(), taskStatus.ID, resourceID, groupID, keys, validateOnImport)
|
||||
})
|
||||
|
||||
return taskStatus, nil
|
||||
}
|
||||
|
||||
// StartUnlinkKeysTask 启动批量解绑密钥任务
|
||||
func (s *KeyImportService) StartUnlinkKeysTask(ctx context.Context, groupID uint, keysText string) (*task.Status, error) {
|
||||
keys := utils.ParseKeysFromText(keysText)
|
||||
if len(keys) == 0 {
|
||||
return nil, fmt.Errorf("no valid keys found")
|
||||
}
|
||||
|
||||
resourceID := fmt.Sprintf("group-%d", groupID)
|
||||
taskStatus, err := s.taskService.StartTask(ctx, uint(groupID), TaskTypeUnlinkKeysFromGroup, resourceID, len(keys), time.Hour)
|
||||
taskStatus, err := s.taskService.StartTask(ctx, uint(groupID), TaskTypeUnlinkKeysFromGroup, resourceID, len(keys), longTaskTimeout)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
go s.runTaskWithRecovery(context.Background(), taskStatus.ID, resourceID, func() {
|
||||
s.runUnlinkKeysTask(context.Background(), taskStatus.ID, resourceID, groupID, keys)
|
||||
})
|
||||
|
||||
return taskStatus, nil
|
||||
}
|
||||
|
||||
// StartHardDeleteKeysTask 启动硬删除密钥任务
|
||||
func (s *KeyImportService) StartHardDeleteKeysTask(ctx context.Context, keysText string) (*task.Status, error) {
|
||||
keys := utils.ParseKeysFromText(keysText)
|
||||
if len(keys) == 0 {
|
||||
return nil, fmt.Errorf("no valid keys found")
|
||||
}
|
||||
|
||||
resourceID := "global_hard_delete"
|
||||
taskStatus, err := s.taskService.StartTask(ctx, 0, TaskTypeHardDeleteKeys, resourceID, len(keys), time.Hour)
|
||||
taskStatus, err := s.taskService.StartTask(ctx, 0, TaskTypeHardDeleteKeys, resourceID, len(keys), longTaskTimeout)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
go s.runTaskWithRecovery(context.Background(), taskStatus.ID, resourceID, func() {
|
||||
s.runHardDeleteKeysTask(context.Background(), taskStatus.ID, resourceID, keys)
|
||||
})
|
||||
|
||||
return taskStatus, nil
|
||||
}
|
||||
|
||||
// StartRestoreKeysTask 启动恢复密钥任务
|
||||
func (s *KeyImportService) StartRestoreKeysTask(ctx context.Context, keysText string) (*task.Status, error) {
|
||||
keys := utils.ParseKeysFromText(keysText)
|
||||
if len(keys) == 0 {
|
||||
return nil, fmt.Errorf("no valid keys found")
|
||||
}
|
||||
resourceID := "global_restore_keys"
|
||||
|
||||
taskStatus, err := s.taskService.StartTask(ctx, 0, TaskTypeRestoreKeys, resourceID, len(keys), time.Hour)
|
||||
resourceID := "global_restore_keys"
|
||||
taskStatus, err := s.taskService.StartTask(ctx, 0, TaskTypeRestoreKeys, resourceID, len(keys), longTaskTimeout)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
go s.runTaskWithRecovery(context.Background(), taskStatus.ID, resourceID, func() {
|
||||
s.runRestoreKeysTask(context.Background(), taskStatus.ID, resourceID, keys)
|
||||
})
|
||||
|
||||
return taskStatus, nil
|
||||
}
|
||||
|
||||
func (s *KeyImportService) runAddKeysTask(ctx context.Context, taskID string, resourceID string, groupID uint, keys []string, validateOnImport bool) {
|
||||
uniqueKeysMap := make(map[string]struct{})
|
||||
var uniqueKeyStrings []string
|
||||
for _, kStr := range keys {
|
||||
if _, exists := uniqueKeysMap[kStr]; !exists {
|
||||
uniqueKeysMap[kStr] = struct{}{}
|
||||
uniqueKeyStrings = append(uniqueKeyStrings, kStr)
|
||||
}
|
||||
}
|
||||
if len(uniqueKeyStrings) == 0 {
|
||||
s.taskService.EndTaskByID(ctx, taskID, resourceID, gin.H{"newly_linked_count": 0, "already_linked_count": 0}, nil)
|
||||
return
|
||||
}
|
||||
keysToEnsure := make([]models.APIKey, len(uniqueKeyStrings))
|
||||
for i, keyStr := range uniqueKeyStrings {
|
||||
keysToEnsure[i] = models.APIKey{APIKey: keyStr}
|
||||
}
|
||||
allKeyModels, err := s.keyRepo.AddKeys(keysToEnsure)
|
||||
// StartUnlinkKeysByFilterTask 根据状态过滤条件批量解绑
|
||||
func (s *KeyImportService) StartUnlinkKeysByFilterTask(ctx context.Context, groupID uint, statuses []string) (*task.Status, error) {
|
||||
s.logger.Infof("Starting unlink task for group %d with status filter: %v", groupID, statuses)
|
||||
|
||||
keyValues, err := s.keyRepo.FindKeyValuesByStatus(groupID, statuses)
|
||||
if err != nil {
|
||||
s.taskService.EndTaskByID(ctx, taskID, resourceID, nil, fmt.Errorf("failed to ensure keys exist: %w", err))
|
||||
return
|
||||
return nil, fmt.Errorf("failed to find keys by filter: %w", err)
|
||||
}
|
||||
alreadyLinkedModels, err := s.keyRepo.GetKeysByValuesAndGroupID(uniqueKeyStrings, groupID)
|
||||
if err != nil {
|
||||
s.taskService.EndTaskByID(ctx, taskID, resourceID, nil, fmt.Errorf("failed to check for already linked keys: %w", err))
|
||||
return
|
||||
}
|
||||
alreadyLinkedIDSet := make(map[uint]struct{})
|
||||
for _, key := range alreadyLinkedModels {
|
||||
alreadyLinkedIDSet[key.ID] = struct{}{}
|
||||
}
|
||||
var keysToLink []models.APIKey
|
||||
for _, key := range allKeyModels {
|
||||
if _, exists := alreadyLinkedIDSet[key.ID]; !exists {
|
||||
keysToLink = append(keysToLink, key)
|
||||
}
|
||||
if len(keyValues) == 0 {
|
||||
return nil, fmt.Errorf("no keys found matching the provided filter")
|
||||
}
|
||||
|
||||
keysAsText := strings.Join(keyValues, "\n")
|
||||
s.logger.Infof("Found %d keys to unlink for group %d.", len(keyValues), groupID)
|
||||
|
||||
return s.StartUnlinkKeysTask(ctx, groupID, keysAsText)
|
||||
}
|
||||
|
||||
// ==================== 核心任务执行逻辑 ====================
|
||||
|
||||
// runAddKeysTask 执行批量添加密钥
|
||||
func (s *KeyImportService) runAddKeysTask(ctx context.Context, taskID string, resourceID string, groupID uint, keys []string, validateOnImport bool) {
|
||||
// 1. 去重
|
||||
uniqueKeys := s.deduplicateKeys(keys)
|
||||
if len(uniqueKeys) == 0 {
|
||||
s.endTaskWithResult(ctx, taskID, resourceID, gin.H{
|
||||
"newly_linked_count": 0,
|
||||
"already_linked_count": 0,
|
||||
}, nil)
|
||||
return
|
||||
}
|
||||
|
||||
// 2. 确保所有密钥在数据库中存在(幂等操作)
|
||||
allKeyModels, err := s.ensureKeysExist(uniqueKeys)
|
||||
if err != nil {
|
||||
s.endTaskWithResult(ctx, taskID, resourceID, nil, fmt.Errorf("failed to ensure keys exist: %w", err))
|
||||
return
|
||||
}
|
||||
|
||||
// 3. 过滤已关联的密钥
|
||||
keysToLink, alreadyLinkedCount, err := s.filterNewKeys(allKeyModels, groupID, uniqueKeys)
|
||||
if err != nil {
|
||||
s.endTaskWithResult(ctx, taskID, resourceID, nil, fmt.Errorf("failed to check linked keys: %w", err))
|
||||
return
|
||||
}
|
||||
|
||||
// 4. 更新任务的实际处理总数
|
||||
if err := s.taskService.UpdateTotalByID(ctx, taskID, len(keysToLink)); err != nil {
|
||||
s.logger.WithError(err).Warnf("Failed to update total for task %s", taskID)
|
||||
}
|
||||
|
||||
// 5. 批量关联密钥到组
|
||||
if len(keysToLink) > 0 {
|
||||
idsToLink := make([]uint, len(keysToLink))
|
||||
for i, key := range keysToLink {
|
||||
idsToLink[i] = key.ID
|
||||
if err := s.linkKeysInChunks(ctx, taskID, groupID, keysToLink); err != nil {
|
||||
s.endTaskWithResult(ctx, taskID, resourceID, nil, err)
|
||||
return
|
||||
}
|
||||
for i := 0; i < len(idsToLink); i += chunkSize {
|
||||
end := i + chunkSize
|
||||
if end > len(idsToLink) {
|
||||
end = len(idsToLink)
|
||||
}
|
||||
chunk := idsToLink[i:end]
|
||||
if err := s.keyRepo.LinkKeysToGroup(ctx, groupID, chunk); err != nil {
|
||||
s.taskService.EndTaskByID(ctx, taskID, resourceID, nil, fmt.Errorf("chunk failed to link keys: %w", err))
|
||||
return
|
||||
}
|
||||
_ = s.taskService.UpdateProgressByID(ctx, taskID, i+len(chunk))
|
||||
}
|
||||
|
||||
// 6. 根据验证标志处理密钥状态
|
||||
if len(keysToLink) > 0 {
|
||||
s.processNewlyLinkedKeys(ctx, groupID, keysToLink, validateOnImport)
|
||||
}
|
||||
|
||||
// 7. 返回结果
|
||||
result := gin.H{
|
||||
"newly_linked_count": len(keysToLink),
|
||||
"already_linked_count": alreadyLinkedCount,
|
||||
"total_linked_count": len(allKeyModels),
|
||||
}
|
||||
s.endTaskWithResult(ctx, taskID, resourceID, result, nil)
|
||||
}
|
||||
|
||||
// runUnlinkKeysTask 执行批量解绑密钥
|
||||
func (s *KeyImportService) runUnlinkKeysTask(ctx context.Context, taskID, resourceID string, groupID uint, keys []string) {
|
||||
// 1. 去重
|
||||
uniqueKeys := s.deduplicateKeys(keys)
|
||||
|
||||
// 2. 查找需要解绑的密钥
|
||||
keysToUnlink, err := s.keyRepo.GetKeysByValuesAndGroupID(uniqueKeys, groupID)
|
||||
if err != nil {
|
||||
s.endTaskWithResult(ctx, taskID, resourceID, nil, fmt.Errorf("failed to find keys to unlink: %w", err))
|
||||
return
|
||||
}
|
||||
|
||||
if len(keysToUnlink) == 0 {
|
||||
result := gin.H{
|
||||
"unlinked_count": 0,
|
||||
"hard_deleted_count": 0,
|
||||
"not_found_count": len(uniqueKeys),
|
||||
}
|
||||
s.endTaskWithResult(ctx, taskID, resourceID, result, nil)
|
||||
return
|
||||
}
|
||||
|
||||
// 3. 提取密钥 ID
|
||||
idsToUnlink := s.extractKeyIDs(keysToUnlink)
|
||||
|
||||
// 4. 更新任务总数
|
||||
if err := s.taskService.UpdateTotalByID(ctx, taskID, len(idsToUnlink)); err != nil {
|
||||
s.logger.WithError(err).Warnf("Failed to update total for task %s", taskID)
|
||||
}
|
||||
|
||||
// 5. 批量解绑
|
||||
totalUnlinked, err := s.unlinkKeysInChunks(ctx, taskID, groupID, idsToUnlink)
|
||||
if err != nil {
|
||||
s.endTaskWithResult(ctx, taskID, resourceID, nil, err)
|
||||
return
|
||||
}
|
||||
|
||||
// 6. 清理孤立密钥
|
||||
totalDeleted, err := s.keyRepo.DeleteOrphanKeys()
|
||||
if err != nil {
|
||||
s.logger.WithError(err).Warn("Orphan key cleanup failed after unlink task.")
|
||||
}
|
||||
|
||||
// 7. 返回结果
|
||||
result := gin.H{
|
||||
"unlinked_count": totalUnlinked,
|
||||
"hard_deleted_count": totalDeleted,
|
||||
"not_found_count": len(uniqueKeys) - int(totalUnlinked),
|
||||
}
|
||||
s.endTaskWithResult(ctx, taskID, resourceID, result, nil)
|
||||
}
|
||||
|
||||
// runHardDeleteKeysTask 执行硬删除密钥
|
||||
func (s *KeyImportService) runHardDeleteKeysTask(ctx context.Context, taskID, resourceID string, keys []string) {
|
||||
totalDeleted, err := s.processKeysInChunks(ctx, taskID, keys, func(chunk []string) (int64, error) {
|
||||
return s.keyRepo.HardDeleteByValues(chunk)
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
s.endTaskWithResult(ctx, taskID, resourceID, nil, err)
|
||||
return
|
||||
}
|
||||
|
||||
result := gin.H{
|
||||
"newly_linked_count": len(keysToLink),
|
||||
"already_linked_count": len(alreadyLinkedIDSet),
|
||||
"total_linked_count": len(allKeyModels),
|
||||
"hard_deleted_count": totalDeleted,
|
||||
"not_found_count": int64(len(keys)) - totalDeleted,
|
||||
}
|
||||
if len(keysToLink) > 0 {
|
||||
idsToLink := make([]uint, len(keysToLink))
|
||||
for i, key := range keysToLink {
|
||||
idsToLink[i] = key.ID
|
||||
}
|
||||
if validateOnImport {
|
||||
s.publishImportGroupCompletedEvent(ctx, groupID, idsToLink)
|
||||
for _, keyID := range idsToLink {
|
||||
s.publishSingleKeyChangeEvent(ctx, groupID, keyID, "", models.StatusPendingValidation, "key_linked")
|
||||
}
|
||||
} else {
|
||||
for _, keyID := range idsToLink {
|
||||
if _, err := s.apiKeyService.UpdateMappingStatus(ctx, groupID, keyID, models.StatusActive); err != nil {
|
||||
s.logger.Errorf("Failed to directly activate key ID %d in group %d: %v", keyID, groupID, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
s.taskService.EndTaskByID(ctx, taskID, resourceID, result, nil)
|
||||
s.endTaskWithResult(ctx, taskID, resourceID, result, nil)
|
||||
s.publishChangeEvent(ctx, 0, "keys_hard_deleted")
|
||||
}
|
||||
|
||||
func (s *KeyImportService) runUnlinkKeysTask(ctx context.Context, taskID, resourceID string, groupID uint, keys []string) {
|
||||
uniqueKeysMap := make(map[string]struct{})
|
||||
var uniqueKeys []string
|
||||
// runRestoreKeysTask 执行恢复密钥
|
||||
func (s *KeyImportService) runRestoreKeysTask(ctx context.Context, taskID, resourceID string, keys []string) {
|
||||
restoredCount, err := s.processKeysInChunks(ctx, taskID, keys, func(chunk []string) (int64, error) {
|
||||
return s.keyRepo.UpdateMasterStatusByValues(chunk, models.MasterStatusActive)
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
s.endTaskWithResult(ctx, taskID, resourceID, nil, err)
|
||||
return
|
||||
}
|
||||
|
||||
result := gin.H{
|
||||
"restored_count": restoredCount,
|
||||
"not_found_count": int64(len(keys)) - restoredCount,
|
||||
}
|
||||
s.endTaskWithResult(ctx, taskID, resourceID, result, nil)
|
||||
s.publishChangeEvent(ctx, 0, "keys_bulk_restored")
|
||||
}
|
||||
|
||||
// ==================== 辅助方法 ====================
|
||||
|
||||
// deduplicateKeys 去重密钥列表
|
||||
func (s *KeyImportService) deduplicateKeys(keys []string) []string {
|
||||
uniqueKeysMap := make(map[string]struct{}, len(keys))
|
||||
uniqueKeys := make([]string, 0, len(keys))
|
||||
|
||||
for _, kStr := range keys {
|
||||
if _, exists := uniqueKeysMap[kStr]; !exists {
|
||||
uniqueKeysMap[kStr] = struct{}{}
|
||||
uniqueKeys = append(uniqueKeys, kStr)
|
||||
}
|
||||
}
|
||||
keysToUnlink, err := s.keyRepo.GetKeysByValuesAndGroupID(uniqueKeys, groupID)
|
||||
return uniqueKeys
|
||||
}
|
||||
|
||||
// ensureKeysExist 确保所有密钥在数据库中存在
|
||||
func (s *KeyImportService) ensureKeysExist(keys []string) ([]models.APIKey, error) {
|
||||
keysToEnsure := make([]models.APIKey, len(keys))
|
||||
for i, keyStr := range keys {
|
||||
keysToEnsure[i] = models.APIKey{APIKey: keyStr}
|
||||
}
|
||||
return s.keyRepo.AddKeys(keysToEnsure)
|
||||
}
|
||||
|
||||
// filterNewKeys 过滤已关联的密钥,返回需要新增的密钥
|
||||
func (s *KeyImportService) filterNewKeys(allKeyModels []models.APIKey, groupID uint, uniqueKeys []string) ([]models.APIKey, int, error) {
|
||||
alreadyLinkedModels, err := s.keyRepo.GetKeysByValuesAndGroupID(uniqueKeys, groupID)
|
||||
if err != nil {
|
||||
s.taskService.EndTaskByID(ctx, taskID, resourceID, nil, fmt.Errorf("failed to find keys to unlink: %w", err))
|
||||
return
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
if len(keysToUnlink) == 0 {
|
||||
result := gin.H{"unlinked_count": 0, "hard_deleted_count": 0, "not_found_count": len(uniqueKeys)}
|
||||
s.taskService.EndTaskByID(ctx, taskID, resourceID, result, nil)
|
||||
return
|
||||
}
|
||||
idsToUnlink := make([]uint, len(keysToUnlink))
|
||||
for i, key := range keysToUnlink {
|
||||
idsToUnlink[i] = key.ID
|
||||
alreadyLinkedIDSet := make(map[uint]struct{}, len(alreadyLinkedModels))
|
||||
for _, key := range alreadyLinkedModels {
|
||||
alreadyLinkedIDSet[key.ID] = struct{}{}
|
||||
}
|
||||
|
||||
if err := s.taskService.UpdateTotalByID(ctx, taskID, len(idsToUnlink)); err != nil {
|
||||
s.logger.WithError(err).Warnf("Failed to update total for task %s", taskID)
|
||||
}
|
||||
var totalUnlinked int64
|
||||
for i := 0; i < len(idsToUnlink); i += chunkSize {
|
||||
end := i + chunkSize
|
||||
if end > len(idsToUnlink) {
|
||||
end = len(idsToUnlink)
|
||||
keysToLink := make([]models.APIKey, 0, len(allKeyModels)-len(alreadyLinkedIDSet))
|
||||
for _, key := range allKeyModels {
|
||||
if _, exists := alreadyLinkedIDSet[key.ID]; !exists {
|
||||
keysToLink = append(keysToLink, key)
|
||||
}
|
||||
}
|
||||
|
||||
return keysToLink, len(alreadyLinkedIDSet), nil
|
||||
}
|
||||
|
||||
// extractKeyIDs 提取密钥 ID 列表
|
||||
func (s *KeyImportService) extractKeyIDs(keys []models.APIKey) []uint {
|
||||
ids := make([]uint, len(keys))
|
||||
for i, key := range keys {
|
||||
ids[i] = key.ID
|
||||
}
|
||||
return ids
|
||||
}
|
||||
|
||||
// linkKeysInChunks 分块关联密钥到组
|
||||
func (s *KeyImportService) linkKeysInChunks(ctx context.Context, taskID string, groupID uint, keysToLink []models.APIKey) error {
|
||||
idsToLink := s.extractKeyIDs(keysToLink)
|
||||
|
||||
for i := 0; i < len(idsToLink); i += chunkSize {
|
||||
end := min(i+chunkSize, len(idsToLink))
|
||||
chunk := idsToLink[i:end]
|
||||
|
||||
if err := s.keyRepo.LinkKeysToGroup(ctx, groupID, chunk); err != nil {
|
||||
return fmt.Errorf("chunk failed to link keys: %w", err)
|
||||
}
|
||||
|
||||
_ = s.taskService.UpdateProgressByID(ctx, taskID, end)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// unlinkKeysInChunks 分块解绑密钥
|
||||
func (s *KeyImportService) unlinkKeysInChunks(ctx context.Context, taskID string, groupID uint, idsToUnlink []uint) (int64, error) {
|
||||
var totalUnlinked int64
|
||||
|
||||
for i := 0; i < len(idsToUnlink); i += chunkSize {
|
||||
end := min(i+chunkSize, len(idsToUnlink))
|
||||
chunk := idsToUnlink[i:end]
|
||||
|
||||
unlinked, err := s.keyRepo.UnlinkKeysFromGroup(ctx, groupID, chunk)
|
||||
if err != nil {
|
||||
s.taskService.EndTaskByID(ctx, taskID, resourceID, nil, fmt.Errorf("chunk failed: could not unlink keys: %w", err))
|
||||
return
|
||||
return 0, fmt.Errorf("chunk failed: could not unlink keys: %w", err)
|
||||
}
|
||||
|
||||
totalUnlinked += unlinked
|
||||
|
||||
// 发布解绑事件
|
||||
for _, keyID := range chunk {
|
||||
s.publishSingleKeyChangeEvent(ctx, groupID, keyID, models.StatusActive, "", "key_unlinked")
|
||||
}
|
||||
_ = s.taskService.UpdateProgressByID(ctx, taskID, i+len(chunk))
|
||||
|
||||
_ = s.taskService.UpdateProgressByID(ctx, taskID, end)
|
||||
}
|
||||
|
||||
totalDeleted, err := s.keyRepo.DeleteOrphanKeys()
|
||||
return totalUnlinked, nil
|
||||
}
|
||||
|
||||
// processKeysInChunks 通用的分块处理密钥逻辑
|
||||
func (s *KeyImportService) processKeysInChunks(
|
||||
ctx context.Context,
|
||||
taskID string,
|
||||
keys []string,
|
||||
processFunc func(chunk []string) (int64, error),
|
||||
) (int64, error) {
|
||||
var totalProcessed int64
|
||||
|
||||
for i := 0; i < len(keys); i += chunkSize {
|
||||
end := min(i+chunkSize, len(keys))
|
||||
chunk := keys[i:end]
|
||||
|
||||
count, err := processFunc(chunk)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("failed to process chunk: %w", err)
|
||||
}
|
||||
|
||||
totalProcessed += count
|
||||
_ = s.taskService.UpdateProgressByID(ctx, taskID, end)
|
||||
}
|
||||
|
||||
return totalProcessed, nil
|
||||
}
|
||||
|
||||
// processNewlyLinkedKeys 处理新关联的密钥(验证或直接激活)
|
||||
func (s *KeyImportService) processNewlyLinkedKeys(ctx context.Context, groupID uint, keysToLink []models.APIKey, validateOnImport bool) {
|
||||
idsToLink := s.extractKeyIDs(keysToLink)
|
||||
|
||||
if validateOnImport {
|
||||
// 发布批量导入完成事件,触发验证
|
||||
s.publishImportGroupCompletedEvent(ctx, groupID, idsToLink)
|
||||
|
||||
// 发布单个密钥状态变更事件
|
||||
for _, keyID := range idsToLink {
|
||||
s.publishSingleKeyChangeEvent(ctx, groupID, keyID, "", models.StatusPendingValidation, "key_linked")
|
||||
}
|
||||
} else {
|
||||
// 直接激活密钥,不进行验证
|
||||
for _, keyID := range idsToLink {
|
||||
if _, err := s.apiKeyService.UpdateMappingStatus(ctx, groupID, keyID, models.StatusActive); err != nil {
|
||||
s.logger.WithFields(logrus.Fields{
|
||||
"group_id": groupID,
|
||||
"key_id": keyID,
|
||||
}).Errorf("Failed to directly activate key: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// endTaskWithResult 统一的任务结束处理
|
||||
func (s *KeyImportService) endTaskWithResult(ctx context.Context, taskID, resourceID string, result gin.H, err error) {
|
||||
if err != nil {
|
||||
s.logger.WithError(err).Warn("Orphan key cleanup failed after unlink task.")
|
||||
s.logger.WithFields(logrus.Fields{
|
||||
"task_id": taskID,
|
||||
"resource_id": resourceID,
|
||||
}).WithError(err).Error("Task failed")
|
||||
}
|
||||
result := gin.H{
|
||||
"unlinked_count": totalUnlinked,
|
||||
"hard_deleted_count": totalDeleted,
|
||||
"not_found_count": len(uniqueKeys) - int(totalUnlinked),
|
||||
}
|
||||
s.taskService.EndTaskByID(ctx, taskID, resourceID, result, nil)
|
||||
s.taskService.EndTaskByID(ctx, taskID, resourceID, result, err)
|
||||
}
|
||||
|
||||
func (s *KeyImportService) runHardDeleteKeysTask(ctx context.Context, taskID, resourceID string, keys []string) {
|
||||
var totalDeleted int64
|
||||
for i := 0; i < len(keys); i += chunkSize {
|
||||
end := i + chunkSize
|
||||
if end > len(keys) {
|
||||
end = len(keys)
|
||||
}
|
||||
chunk := keys[i:end]
|
||||
|
||||
deleted, err := s.keyRepo.HardDeleteByValues(chunk)
|
||||
if err != nil {
|
||||
s.taskService.EndTaskByID(ctx, taskID, resourceID, nil, fmt.Errorf("failed to hard delete chunk: %w", err))
|
||||
return
|
||||
}
|
||||
totalDeleted += deleted
|
||||
_ = s.taskService.UpdateProgressByID(ctx, taskID, i+len(chunk))
|
||||
}
|
||||
result := gin.H{
|
||||
"hard_deleted_count": totalDeleted,
|
||||
"not_found_count": int64(len(keys)) - totalDeleted,
|
||||
}
|
||||
s.taskService.EndTaskByID(ctx, taskID, resourceID, result, nil)
|
||||
s.publishChangeEvent(ctx, 0, "keys_hard_deleted")
|
||||
}
|
||||
|
||||
func (s *KeyImportService) runRestoreKeysTask(ctx context.Context, taskID, resourceID string, keys []string) {
|
||||
var restoredCount int64
|
||||
for i := 0; i < len(keys); i += chunkSize {
|
||||
end := i + chunkSize
|
||||
if end > len(keys) {
|
||||
end = len(keys)
|
||||
}
|
||||
chunk := keys[i:end]
|
||||
|
||||
count, err := s.keyRepo.UpdateMasterStatusByValues(chunk, models.MasterStatusActive)
|
||||
if err != nil {
|
||||
s.taskService.EndTaskByID(ctx, taskID, resourceID, nil, fmt.Errorf("failed to restore chunk: %w", err))
|
||||
return
|
||||
}
|
||||
restoredCount += count
|
||||
_ = s.taskService.UpdateProgressByID(ctx, taskID, i+len(chunk))
|
||||
}
|
||||
result := gin.H{
|
||||
"restored_count": restoredCount,
|
||||
"not_found_count": int64(len(keys)) - restoredCount,
|
||||
}
|
||||
s.taskService.EndTaskByID(ctx, taskID, resourceID, result, nil)
|
||||
s.publishChangeEvent(ctx, 0, "keys_bulk_restored")
|
||||
}
|
||||
// ==================== 事件发布方法 ====================
|
||||
|
||||
// publishSingleKeyChangeEvent 发布单个密钥状态变更事件
|
||||
func (s *KeyImportService) publishSingleKeyChangeEvent(ctx context.Context, groupID, keyID uint, oldStatus, newStatus models.APIKeyStatus, reason string) {
|
||||
event := models.KeyStatusChangedEvent{
|
||||
GroupID: groupID,
|
||||
@@ -324,56 +474,88 @@ func (s *KeyImportService) publishSingleKeyChangeEvent(ctx context.Context, grou
|
||||
ChangeReason: reason,
|
||||
ChangedAt: time.Now(),
|
||||
}
|
||||
eventData, _ := json.Marshal(event)
|
||||
if err := s.store.Publish(ctx, models.TopicKeyStatusChanged, eventData); err != nil {
|
||||
s.logger.WithError(err).WithFields(logrus.Fields{
|
||||
|
||||
eventData, err := json.Marshal(event)
|
||||
if err != nil {
|
||||
s.logger.WithFields(logrus.Fields{
|
||||
"group_id": groupID,
|
||||
"key_id": keyID,
|
||||
"reason": reason,
|
||||
}).Error("Failed to publish single key change event.")
|
||||
}).WithError(err).Error("Failed to marshal key change event")
|
||||
return
|
||||
}
|
||||
|
||||
if err := s.store.Publish(ctx, models.TopicKeyStatusChanged, eventData); err != nil {
|
||||
s.logger.WithFields(logrus.Fields{
|
||||
"group_id": groupID,
|
||||
"key_id": keyID,
|
||||
"reason": reason,
|
||||
}).WithError(err).Error("Failed to publish single key change event")
|
||||
}
|
||||
}
|
||||
|
||||
// publishChangeEvent 发布通用变更事件
|
||||
func (s *KeyImportService) publishChangeEvent(ctx context.Context, groupID uint, reason string) {
|
||||
event := models.KeyStatusChangedEvent{
|
||||
GroupID: groupID,
|
||||
ChangeReason: reason,
|
||||
}
|
||||
eventData, _ := json.Marshal(event)
|
||||
_ = s.store.Publish(ctx, models.TopicKeyStatusChanged, eventData)
|
||||
|
||||
eventData, err := json.Marshal(event)
|
||||
if err != nil {
|
||||
s.logger.WithFields(logrus.Fields{
|
||||
"group_id": groupID,
|
||||
"reason": reason,
|
||||
}).WithError(err).Error("Failed to marshal change event")
|
||||
return
|
||||
}
|
||||
|
||||
if err := s.store.Publish(ctx, models.TopicKeyStatusChanged, eventData); err != nil {
|
||||
s.logger.WithFields(logrus.Fields{
|
||||
"group_id": groupID,
|
||||
"reason": reason,
|
||||
}).WithError(err).Error("Failed to publish change event")
|
||||
}
|
||||
}
|
||||
|
||||
// publishImportGroupCompletedEvent 发布批量导入完成事件
|
||||
func (s *KeyImportService) publishImportGroupCompletedEvent(ctx context.Context, groupID uint, keyIDs []uint) {
|
||||
if len(keyIDs) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
event := models.ImportGroupCompletedEvent{
|
||||
GroupID: groupID,
|
||||
KeyIDs: keyIDs,
|
||||
CompletedAt: time.Now(),
|
||||
}
|
||||
|
||||
eventData, err := json.Marshal(event)
|
||||
if err != nil {
|
||||
s.logger.WithError(err).Error("Failed to marshal ImportGroupCompletedEvent.")
|
||||
s.logger.WithFields(logrus.Fields{
|
||||
"group_id": groupID,
|
||||
"key_count": len(keyIDs),
|
||||
}).WithError(err).Error("Failed to marshal ImportGroupCompletedEvent")
|
||||
return
|
||||
}
|
||||
|
||||
if err := s.store.Publish(ctx, models.TopicImportGroupCompleted, eventData); err != nil {
|
||||
s.logger.WithError(err).Error("Failed to publish ImportGroupCompletedEvent.")
|
||||
s.logger.WithFields(logrus.Fields{
|
||||
"group_id": groupID,
|
||||
"key_count": len(keyIDs),
|
||||
}).WithError(err).Error("Failed to publish ImportGroupCompletedEvent")
|
||||
} else {
|
||||
s.logger.Infof("Published ImportGroupCompletedEvent for group %d with %d keys.", groupID, len(keyIDs))
|
||||
s.logger.WithFields(logrus.Fields{
|
||||
"group_id": groupID,
|
||||
"key_count": len(keyIDs),
|
||||
}).Info("Published ImportGroupCompletedEvent")
|
||||
}
|
||||
}
|
||||
|
||||
func (s *KeyImportService) StartUnlinkKeysByFilterTask(ctx context.Context, groupID uint, statuses []string) (*task.Status, error) {
|
||||
s.logger.Infof("Starting unlink task for group %d with status filter: %v", groupID, statuses)
|
||||
keyValues, err := s.keyRepo.FindKeyValuesByStatus(groupID, statuses)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to find keys by filter: %w", err)
|
||||
// min 返回两个整数中的较小值
|
||||
func min(a, b int) int {
|
||||
if a < b {
|
||||
return a
|
||||
}
|
||||
if len(keyValues) == 0 {
|
||||
return nil, fmt.Errorf("no keys found matching the provided filter")
|
||||
}
|
||||
keysAsText := strings.Join(keyValues, "\n")
|
||||
s.logger.Infof("Found %d keys to unlink for group %d.", len(keyValues), groupID)
|
||||
return s.StartUnlinkKeysTask(ctx, groupID, keysAsText)
|
||||
return b
|
||||
}
|
||||
|
||||
@@ -25,26 +25,38 @@ import (
|
||||
)
|
||||
|
||||
const (
|
||||
TaskTypeTestKeys = "test_keys"
|
||||
TaskTypeTestKeys = "test_keys"
|
||||
defaultConcurrency = 10
|
||||
maxValidationConcurrency = 100
|
||||
validationTaskTimeout = time.Hour
|
||||
)
|
||||
|
||||
type KeyValidationService struct {
|
||||
taskService task.Reporter
|
||||
channel channel.ChannelProxy
|
||||
db *gorm.DB
|
||||
SettingsManager *settings.SettingsManager
|
||||
settingsManager *settings.SettingsManager
|
||||
groupManager *GroupManager
|
||||
store store.Store
|
||||
keyRepo repository.KeyRepository
|
||||
logger *logrus.Entry
|
||||
}
|
||||
|
||||
func NewKeyValidationService(ts task.Reporter, ch channel.ChannelProxy, db *gorm.DB, ss *settings.SettingsManager, gm *GroupManager, st store.Store, kr repository.KeyRepository, logger *logrus.Logger) *KeyValidationService {
|
||||
func NewKeyValidationService(
|
||||
ts task.Reporter,
|
||||
ch channel.ChannelProxy,
|
||||
db *gorm.DB,
|
||||
ss *settings.SettingsManager,
|
||||
gm *GroupManager,
|
||||
st store.Store,
|
||||
kr repository.KeyRepository,
|
||||
logger *logrus.Logger,
|
||||
) *KeyValidationService {
|
||||
return &KeyValidationService{
|
||||
taskService: ts,
|
||||
channel: ch,
|
||||
db: db,
|
||||
SettingsManager: ss,
|
||||
settingsManager: ss,
|
||||
groupManager: gm,
|
||||
store: st,
|
||||
keyRepo: kr,
|
||||
@@ -52,33 +64,393 @@ func NewKeyValidationService(ts task.Reporter, ch channel.ChannelProxy, db *gorm
|
||||
}
|
||||
}
|
||||
|
||||
// ==================== 公开接口 ====================
|
||||
|
||||
// ValidateSingleKey 验证单个密钥
|
||||
func (s *KeyValidationService) ValidateSingleKey(key *models.APIKey, timeout time.Duration, endpoint string) error {
|
||||
// 1. 解密密钥
|
||||
if err := s.keyRepo.Decrypt(key); err != nil {
|
||||
return fmt.Errorf("failed to decrypt key %d for validation: %w", key.ID, err)
|
||||
}
|
||||
|
||||
// 2. 创建 HTTP 客户端和请求
|
||||
client := &http.Client{Timeout: timeout}
|
||||
req, err := http.NewRequest("GET", endpoint, nil)
|
||||
if err != nil {
|
||||
s.logger.Errorf("Failed to create request for key validation (ID: %d): %v", key.ID, err)
|
||||
s.logger.WithFields(logrus.Fields{
|
||||
"key_id": key.ID,
|
||||
"endpoint": endpoint,
|
||||
}).Error("Failed to create validation request")
|
||||
return fmt.Errorf("failed to create request: %w", err)
|
||||
}
|
||||
|
||||
// 3. 修改请求(添加密钥认证头)
|
||||
s.channel.ModifyRequest(req, key)
|
||||
|
||||
// 4. 执行请求
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return fmt.Errorf("request failed: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
// 5. 检查响应状态
|
||||
if resp.StatusCode == http.StatusOK {
|
||||
return nil
|
||||
}
|
||||
|
||||
// 6. 处理错误响应
|
||||
return s.buildValidationError(resp)
|
||||
}
|
||||
|
||||
// StartTestKeysTask 启动批量密钥测试任务
|
||||
func (s *KeyValidationService) StartTestKeysTask(ctx context.Context, groupID uint, keysText string) (*task.Status, error) {
|
||||
// 1. 解析和验证输入
|
||||
keyStrings := utils.ParseKeysFromText(keysText)
|
||||
if len(keyStrings) == 0 {
|
||||
return nil, CustomErrors.NewAPIError(CustomErrors.ErrBadRequest, "no valid keys found in input text")
|
||||
}
|
||||
|
||||
// 2. 查询密钥模型
|
||||
apiKeyModels, err := s.keyRepo.GetKeysByValues(keyStrings)
|
||||
if err != nil {
|
||||
return nil, CustomErrors.ParseDBError(err)
|
||||
}
|
||||
if len(apiKeyModels) == 0 {
|
||||
return nil, CustomErrors.NewAPIError(CustomErrors.ErrResourceNotFound, "none of the provided keys were found in the system")
|
||||
}
|
||||
|
||||
// 3. 批量解密密钥
|
||||
if err := s.keyRepo.DecryptBatch(apiKeyModels); err != nil {
|
||||
s.logger.WithError(err).Error("Failed to batch decrypt keys for validation task")
|
||||
return nil, CustomErrors.NewAPIError(CustomErrors.ErrInternalServer, "Failed to decrypt keys for validation")
|
||||
}
|
||||
|
||||
// 4. 获取组配置
|
||||
group, ok := s.groupManager.GetGroupByID(groupID)
|
||||
if !ok {
|
||||
return nil, CustomErrors.NewAPIError(CustomErrors.ErrGroupNotFound, fmt.Sprintf("group with id %d not found", groupID))
|
||||
}
|
||||
|
||||
opConfig, err := s.groupManager.BuildOperationalConfig(group)
|
||||
if err != nil {
|
||||
return nil, CustomErrors.NewAPIError(CustomErrors.ErrInternalServer, fmt.Sprintf("failed to build operational config: %v", err))
|
||||
}
|
||||
|
||||
// 5. 构建验证端点
|
||||
endpoint, err := s.groupManager.BuildKeyCheckEndpoint(groupID)
|
||||
if err != nil {
|
||||
return nil, CustomErrors.NewAPIError(CustomErrors.ErrInternalServer, fmt.Sprintf("failed to build endpoint: %v", err))
|
||||
}
|
||||
|
||||
// 6. 创建任务
|
||||
resourceID := fmt.Sprintf("group-%d", groupID)
|
||||
taskStatus, err := s.taskService.StartTask(ctx, groupID, TaskTypeTestKeys, resourceID, len(apiKeyModels), validationTaskTimeout)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 7. 准备任务参数
|
||||
params := s.buildValidationParams(opConfig)
|
||||
|
||||
// 8. 启动异步验证任务
|
||||
go s.runTestKeysTaskWithRecovery(context.Background(), taskStatus.ID, resourceID, groupID, apiKeyModels, params, endpoint)
|
||||
|
||||
return taskStatus, nil
|
||||
}
|
||||
|
||||
// StartTestKeysByFilterTask 根据状态过滤启动批量测试任务
|
||||
func (s *KeyValidationService) StartTestKeysByFilterTask(ctx context.Context, groupID uint, statuses []string) (*task.Status, error) {
|
||||
s.logger.WithFields(logrus.Fields{
|
||||
"group_id": groupID,
|
||||
"statuses": statuses,
|
||||
}).Info("Starting test task with status filter")
|
||||
|
||||
// 1. 根据过滤条件查询密钥
|
||||
keyValues, err := s.keyRepo.FindKeyValuesByStatus(groupID, statuses)
|
||||
if err != nil {
|
||||
return nil, CustomErrors.ParseDBError(err)
|
||||
}
|
||||
|
||||
if len(keyValues) == 0 {
|
||||
return nil, CustomErrors.NewAPIError(CustomErrors.ErrNotFound, "No keys found matching the filter criteria")
|
||||
}
|
||||
|
||||
// 2. 转换为文本格式并启动任务
|
||||
keysAsText := strings.Join(keyValues, "\n")
|
||||
s.logger.Infof("Found %d keys to validate for group %d", len(keyValues), groupID)
|
||||
|
||||
return s.StartTestKeysTask(ctx, groupID, keysAsText)
|
||||
}
|
||||
|
||||
// ==================== 核心任务执行逻辑 ====================
|
||||
|
||||
// validationParams 验证参数封装
|
||||
type validationParams struct {
|
||||
timeout time.Duration
|
||||
concurrency int
|
||||
}
|
||||
|
||||
// buildValidationParams 构建验证参数
|
||||
func (s *KeyValidationService) buildValidationParams(opConfig *models.KeyGroupSettings) validationParams {
|
||||
settings := s.settingsManager.GetSettings()
|
||||
// 从配置读取超时时间(而非硬编码)
|
||||
timeout := time.Duration(settings.KeyCheckTimeoutSeconds) * time.Second
|
||||
if timeout <= 0 {
|
||||
timeout = 30 * time.Second // 仅在配置无效时使用默认值
|
||||
}
|
||||
// 从配置读取并发数(优先级:组配置 > 全局配置 > 兜底默认值)
|
||||
var concurrency int
|
||||
if opConfig.KeyCheckConcurrency != nil && *opConfig.KeyCheckConcurrency > 0 {
|
||||
concurrency = *opConfig.KeyCheckConcurrency
|
||||
} else if settings.BaseKeyCheckConcurrency > 0 {
|
||||
concurrency = settings.BaseKeyCheckConcurrency
|
||||
} else {
|
||||
concurrency = defaultConcurrency // 兜底默认值
|
||||
}
|
||||
// 限制最大并发数(防护措施)
|
||||
if concurrency > maxValidationConcurrency {
|
||||
concurrency = maxValidationConcurrency
|
||||
}
|
||||
return validationParams{
|
||||
timeout: timeout,
|
||||
concurrency: concurrency,
|
||||
}
|
||||
}
|
||||
|
||||
// runTestKeysTaskWithRecovery 带恢复机制的任务执行包装器
|
||||
func (s *KeyValidationService) runTestKeysTaskWithRecovery(
|
||||
ctx context.Context,
|
||||
taskID string,
|
||||
resourceID string,
|
||||
groupID uint,
|
||||
keys []models.APIKey,
|
||||
params validationParams,
|
||||
endpoint string,
|
||||
) {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
err := fmt.Errorf("panic recovered in validation task %s: %v", taskID, r)
|
||||
s.logger.WithField("task_id", taskID).Error(err)
|
||||
s.taskService.EndTaskByID(ctx, taskID, resourceID, nil, err)
|
||||
}
|
||||
}()
|
||||
|
||||
s.runTestKeysTask(ctx, taskID, resourceID, groupID, keys, params, endpoint)
|
||||
}
|
||||
|
||||
// runTestKeysTask 执行批量密钥验证任务
|
||||
func (s *KeyValidationService) runTestKeysTask(
|
||||
ctx context.Context,
|
||||
taskID string,
|
||||
resourceID string,
|
||||
groupID uint,
|
||||
keys []models.APIKey,
|
||||
params validationParams,
|
||||
endpoint string,
|
||||
) {
|
||||
s.logger.WithFields(logrus.Fields{
|
||||
"task_id": taskID,
|
||||
"group_id": groupID,
|
||||
"key_count": len(keys),
|
||||
"concurrency": params.concurrency,
|
||||
"timeout": params.timeout,
|
||||
}).Info("Starting validation task")
|
||||
|
||||
// 1. 初始化结果收集
|
||||
results := make([]models.KeyTestResult, len(keys))
|
||||
|
||||
// 2. 创建任务分发器
|
||||
dispatcher := newValidationDispatcher(
|
||||
keys,
|
||||
params.concurrency,
|
||||
s,
|
||||
ctx,
|
||||
taskID,
|
||||
groupID,
|
||||
endpoint,
|
||||
params.timeout,
|
||||
)
|
||||
|
||||
// 3. 执行并发验证
|
||||
dispatcher.run(results)
|
||||
|
||||
// 4. 完成任务
|
||||
s.taskService.EndTaskByID(ctx, taskID, resourceID, gin.H{"results": results}, nil)
|
||||
|
||||
s.logger.WithFields(logrus.Fields{
|
||||
"task_id": taskID,
|
||||
"group_id": groupID,
|
||||
"processed": len(results),
|
||||
}).Info("Validation task completed")
|
||||
}
|
||||
|
||||
// ==================== 验证调度器 ====================
|
||||
|
||||
// validationJob 验证作业
|
||||
type validationJob struct {
|
||||
index int
|
||||
key models.APIKey
|
||||
}
|
||||
|
||||
// validationDispatcher 验证任务分发器
|
||||
type validationDispatcher struct {
|
||||
keys []models.APIKey
|
||||
concurrency int
|
||||
service *KeyValidationService
|
||||
ctx context.Context
|
||||
taskID string
|
||||
groupID uint
|
||||
endpoint string
|
||||
timeout time.Duration
|
||||
|
||||
mu sync.Mutex
|
||||
processedCount int
|
||||
}
|
||||
|
||||
// newValidationDispatcher 创建验证分发器
|
||||
func newValidationDispatcher(
|
||||
keys []models.APIKey,
|
||||
concurrency int,
|
||||
service *KeyValidationService,
|
||||
ctx context.Context,
|
||||
taskID string,
|
||||
groupID uint,
|
||||
endpoint string,
|
||||
timeout time.Duration,
|
||||
) *validationDispatcher {
|
||||
return &validationDispatcher{
|
||||
keys: keys,
|
||||
concurrency: concurrency,
|
||||
service: service,
|
||||
ctx: ctx,
|
||||
taskID: taskID,
|
||||
groupID: groupID,
|
||||
endpoint: endpoint,
|
||||
timeout: timeout,
|
||||
}
|
||||
}
|
||||
|
||||
// run 执行并发验证
|
||||
func (d *validationDispatcher) run(results []models.KeyTestResult) {
|
||||
var wg sync.WaitGroup
|
||||
jobs := make(chan validationJob, len(d.keys))
|
||||
|
||||
// 启动 worker pool
|
||||
for i := 0; i < d.concurrency; i++ {
|
||||
wg.Add(1)
|
||||
go d.worker(&wg, jobs, results)
|
||||
}
|
||||
|
||||
// 分发任务
|
||||
for i, key := range d.keys {
|
||||
jobs <- validationJob{index: i, key: key}
|
||||
}
|
||||
close(jobs)
|
||||
|
||||
// 等待所有 worker 完成
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
// worker 验证工作协程
|
||||
func (d *validationDispatcher) worker(wg *sync.WaitGroup, jobs <-chan validationJob, results []models.KeyTestResult) {
|
||||
defer wg.Done()
|
||||
|
||||
for job := range jobs {
|
||||
result := d.validateKey(job.key)
|
||||
|
||||
d.mu.Lock()
|
||||
results[job.index] = result
|
||||
d.processedCount++
|
||||
_ = d.service.taskService.UpdateProgressByID(d.ctx, d.taskID, d.processedCount)
|
||||
d.mu.Unlock()
|
||||
}
|
||||
}
|
||||
|
||||
// validateKey 验证单个密钥并返回结果
|
||||
func (d *validationDispatcher) validateKey(key models.APIKey) models.KeyTestResult {
|
||||
// 1. 执行验证
|
||||
validationErr := d.service.ValidateSingleKey(&key, d.timeout, d.endpoint)
|
||||
|
||||
// 2. 构建结果和事件
|
||||
result, event := d.buildResultAndEvent(key, validationErr)
|
||||
|
||||
// 3. 发布验证事件
|
||||
d.publishValidationEvent(key.ID, event)
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
// buildResultAndEvent 构建验证结果和事件
|
||||
func (d *validationDispatcher) buildResultAndEvent(key models.APIKey, validationErr error) (models.KeyTestResult, models.RequestFinishedEvent) {
|
||||
event := models.RequestFinishedEvent{
|
||||
RequestLog: models.RequestLog{
|
||||
GroupID: &d.groupID,
|
||||
KeyID: &key.ID,
|
||||
},
|
||||
}
|
||||
|
||||
if validationErr == nil {
|
||||
// 验证成功
|
||||
event.RequestLog.IsSuccess = true
|
||||
return models.KeyTestResult{
|
||||
Key: key.APIKey,
|
||||
Status: "valid",
|
||||
Message: "Validation successful",
|
||||
}, event
|
||||
}
|
||||
|
||||
// 验证失败
|
||||
event.RequestLog.IsSuccess = false
|
||||
|
||||
var apiErr *CustomErrors.APIError
|
||||
if CustomErrors.As(validationErr, &apiErr) {
|
||||
event.Error = apiErr
|
||||
return models.KeyTestResult{
|
||||
Key: key.APIKey,
|
||||
Status: "invalid",
|
||||
Message: fmt.Sprintf("Invalid key (HTTP %d): %s", apiErr.HTTPStatus, apiErr.Message),
|
||||
}, event
|
||||
}
|
||||
|
||||
// 其他错误
|
||||
event.Error = &CustomErrors.APIError{Message: validationErr.Error()}
|
||||
return models.KeyTestResult{
|
||||
Key: key.APIKey,
|
||||
Status: "error",
|
||||
Message: "Validation check failed: " + validationErr.Error(),
|
||||
}, event
|
||||
}
|
||||
|
||||
// publishValidationEvent 发布验证事件
|
||||
func (d *validationDispatcher) publishValidationEvent(keyID uint, event models.RequestFinishedEvent) {
|
||||
eventData, err := json.Marshal(event)
|
||||
if err != nil {
|
||||
d.service.logger.WithFields(logrus.Fields{
|
||||
"key_id": keyID,
|
||||
"group_id": d.groupID,
|
||||
}).WithError(err).Error("Failed to marshal validation event")
|
||||
return
|
||||
}
|
||||
|
||||
if err := d.service.store.Publish(d.ctx, models.TopicRequestFinished, eventData); err != nil {
|
||||
d.service.logger.WithFields(logrus.Fields{
|
||||
"key_id": keyID,
|
||||
"group_id": d.groupID,
|
||||
}).WithError(err).Error("Failed to publish validation event")
|
||||
}
|
||||
}
|
||||
|
||||
// ==================== 辅助方法 ====================
|
||||
|
||||
// buildValidationError 构建验证错误
|
||||
func (s *KeyValidationService) buildValidationError(resp *http.Response) error {
|
||||
bodyBytes, readErr := io.ReadAll(resp.Body)
|
||||
|
||||
var errorMsg string
|
||||
if readErr != nil {
|
||||
errorMsg = "Failed to read error response body"
|
||||
s.logger.WithError(readErr).Warn("Failed to read validation error response")
|
||||
} else {
|
||||
errorMsg = string(bodyBytes)
|
||||
}
|
||||
@@ -89,128 +461,3 @@ func (s *KeyValidationService) ValidateSingleKey(key *models.APIKey, timeout tim
|
||||
Code: "VALIDATION_FAILED",
|
||||
}
|
||||
}
|
||||
|
||||
func (s *KeyValidationService) StartTestKeysTask(ctx context.Context, groupID uint, keysText string) (*task.Status, error) {
|
||||
keyStrings := utils.ParseKeysFromText(keysText)
|
||||
if len(keyStrings) == 0 {
|
||||
return nil, CustomErrors.NewAPIError(CustomErrors.ErrBadRequest, "no valid keys found in input text")
|
||||
}
|
||||
apiKeyModels, err := s.keyRepo.GetKeysByValues(keyStrings)
|
||||
if err != nil {
|
||||
return nil, CustomErrors.ParseDBError(err)
|
||||
}
|
||||
if len(apiKeyModels) == 0 {
|
||||
return nil, CustomErrors.NewAPIError(CustomErrors.ErrResourceNotFound, "none of the provided keys were found in the system")
|
||||
}
|
||||
if err := s.keyRepo.DecryptBatch(apiKeyModels); err != nil {
|
||||
s.logger.WithError(err).Error("Failed to batch decrypt keys for validation task.")
|
||||
return nil, CustomErrors.NewAPIError(CustomErrors.ErrInternalServer, "Failed to decrypt keys for validation")
|
||||
}
|
||||
group, ok := s.groupManager.GetGroupByID(groupID)
|
||||
if !ok {
|
||||
return nil, CustomErrors.NewAPIError(CustomErrors.ErrGroupNotFound, fmt.Sprintf("group with id %d not found", groupID))
|
||||
}
|
||||
opConfig, err := s.groupManager.BuildOperationalConfig(group)
|
||||
if err != nil {
|
||||
return nil, CustomErrors.NewAPIError(CustomErrors.ErrInternalServer, fmt.Sprintf("failed to build operational config: %v", err))
|
||||
}
|
||||
resourceID := fmt.Sprintf("group-%d", groupID)
|
||||
taskStatus, err := s.taskService.StartTask(ctx, groupID, TaskTypeTestKeys, resourceID, len(apiKeyModels), time.Hour)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
settings := s.SettingsManager.GetSettings()
|
||||
timeout := time.Duration(settings.KeyCheckTimeoutSeconds) * time.Second
|
||||
endpoint, err := s.groupManager.BuildKeyCheckEndpoint(groupID)
|
||||
if err != nil {
|
||||
s.taskService.EndTaskByID(ctx, taskStatus.ID, resourceID, nil, err)
|
||||
return nil, err
|
||||
}
|
||||
var concurrency int
|
||||
if opConfig.KeyCheckConcurrency != nil {
|
||||
concurrency = *opConfig.KeyCheckConcurrency
|
||||
} else {
|
||||
concurrency = settings.BaseKeyCheckConcurrency
|
||||
}
|
||||
go s.runTestKeysTask(context.Background(), taskStatus.ID, resourceID, groupID, apiKeyModels, timeout, endpoint, concurrency)
|
||||
return taskStatus, nil
|
||||
}
|
||||
|
||||
func (s *KeyValidationService) runTestKeysTask(ctx context.Context, taskID string, resourceID string, groupID uint, keys []models.APIKey, timeout time.Duration, endpoint string, concurrency int) {
|
||||
var wg sync.WaitGroup
|
||||
var mu sync.Mutex
|
||||
finalResults := make([]models.KeyTestResult, len(keys))
|
||||
processedCount := 0
|
||||
if concurrency <= 0 {
|
||||
concurrency = 10
|
||||
}
|
||||
type job struct {
|
||||
Index int
|
||||
Value models.APIKey
|
||||
}
|
||||
jobs := make(chan job, len(keys))
|
||||
for i := 0; i < concurrency; i++ {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
for j := range jobs {
|
||||
apiKeyModel := j.Value
|
||||
keyToValidate := apiKeyModel
|
||||
validationErr := s.ValidateSingleKey(&keyToValidate, timeout, endpoint)
|
||||
|
||||
var currentResult models.KeyTestResult
|
||||
event := models.RequestFinishedEvent{
|
||||
RequestLog: models.RequestLog{
|
||||
GroupID: &groupID,
|
||||
KeyID: &apiKeyModel.ID,
|
||||
},
|
||||
}
|
||||
if validationErr == nil {
|
||||
currentResult = models.KeyTestResult{Key: apiKeyModel.APIKey, Status: "valid", Message: "Validation successful."}
|
||||
event.RequestLog.IsSuccess = true
|
||||
} else {
|
||||
var apiErr *CustomErrors.APIError
|
||||
if CustomErrors.As(validationErr, &apiErr) {
|
||||
currentResult = models.KeyTestResult{Key: apiKeyModel.APIKey, Status: "invalid", Message: fmt.Sprintf("Invalid key (HTTP %d): %s", apiErr.HTTPStatus, apiErr.Message)}
|
||||
event.Error = apiErr
|
||||
} else {
|
||||
currentResult = models.KeyTestResult{Key: apiKeyModel.APIKey, Status: "error", Message: "Validation check failed: " + validationErr.Error()}
|
||||
event.Error = &CustomErrors.APIError{Message: validationErr.Error()}
|
||||
}
|
||||
event.RequestLog.IsSuccess = false
|
||||
}
|
||||
eventData, _ := json.Marshal(event)
|
||||
|
||||
if err := s.store.Publish(ctx, models.TopicRequestFinished, eventData); err != nil {
|
||||
s.logger.WithError(err).Errorf("Failed to publish RequestFinishedEvent for validation of key ID %d", apiKeyModel.ID)
|
||||
}
|
||||
|
||||
mu.Lock()
|
||||
finalResults[j.Index] = currentResult
|
||||
processedCount++
|
||||
_ = s.taskService.UpdateProgressByID(ctx, taskID, processedCount)
|
||||
mu.Unlock()
|
||||
}
|
||||
}()
|
||||
}
|
||||
for i, k := range keys {
|
||||
jobs <- job{Index: i, Value: k}
|
||||
}
|
||||
close(jobs)
|
||||
wg.Wait()
|
||||
s.taskService.EndTaskByID(ctx, taskID, resourceID, gin.H{"results": finalResults}, nil)
|
||||
}
|
||||
|
||||
func (s *KeyValidationService) StartTestKeysByFilterTask(ctx context.Context, groupID uint, statuses []string) (*task.Status, error) {
|
||||
s.logger.Infof("Starting test task for group %d with status filter: %v", groupID, statuses)
|
||||
keyValues, err := s.keyRepo.FindKeyValuesByStatus(groupID, statuses)
|
||||
if err != nil {
|
||||
return nil, CustomErrors.ParseDBError(err)
|
||||
}
|
||||
if len(keyValues) == 0 {
|
||||
return nil, CustomErrors.NewAPIError(CustomErrors.ErrNotFound, "No keys found matching the filter criteria.")
|
||||
}
|
||||
keysAsText := strings.Join(keyValues, "\n")
|
||||
s.logger.Infof("Found %d keys to validate for group %d.", len(keyValues), groupID)
|
||||
return s.StartTestKeysTask(ctx, groupID, keysAsText)
|
||||
}
|
||||
|
||||
@@ -1,78 +1,152 @@
|
||||
// Filename: internal/service/log_service.go
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"gemini-balancer/internal/models"
|
||||
"strconv"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/sirupsen/logrus"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
type LogService struct {
|
||||
db *gorm.DB
|
||||
db *gorm.DB
|
||||
logger *logrus.Entry
|
||||
}
|
||||
|
||||
func NewLogService(db *gorm.DB) *LogService {
|
||||
return &LogService{db: db}
|
||||
func NewLogService(db *gorm.DB, logger *logrus.Logger) *LogService {
|
||||
return &LogService{
|
||||
db: db,
|
||||
logger: logger.WithField("component", "LogService"),
|
||||
}
|
||||
}
|
||||
|
||||
func (s *LogService) Record(log *models.RequestLog) error {
|
||||
return s.db.Create(log).Error
|
||||
func (s *LogService) Record(ctx context.Context, log *models.RequestLog) error {
|
||||
return s.db.WithContext(ctx).Create(log).Error
|
||||
}
|
||||
|
||||
func (s *LogService) GetLogs(c *gin.Context) ([]models.RequestLog, int64, error) {
|
||||
// LogQueryParams 解耦 Gin,使用结构体传参
|
||||
type LogQueryParams struct {
|
||||
Page int
|
||||
PageSize int
|
||||
ModelName string
|
||||
IsSuccess *bool // 使用指针区分"未设置"和"false"
|
||||
StatusCode *int
|
||||
KeyID *uint64
|
||||
GroupID *uint64
|
||||
}
|
||||
|
||||
func (s *LogService) GetLogs(ctx context.Context, params LogQueryParams) ([]models.RequestLog, int64, error) {
|
||||
// 参数校验
|
||||
if params.Page < 1 {
|
||||
params.Page = 1
|
||||
}
|
||||
if params.PageSize < 1 || params.PageSize > 100 {
|
||||
params.PageSize = 20
|
||||
}
|
||||
|
||||
var logs []models.RequestLog
|
||||
var total int64
|
||||
|
||||
query := s.db.Model(&models.RequestLog{}).Scopes(s.filtersScope(c))
|
||||
// 构建基础查询
|
||||
query := s.db.WithContext(ctx).Model(&models.RequestLog{})
|
||||
query = s.applyFilters(query, params)
|
||||
|
||||
// 先计算总数
|
||||
// 计算总数
|
||||
if err := query.Count(&total).Error; err != nil {
|
||||
return nil, 0, err
|
||||
return nil, 0, fmt.Errorf("failed to count logs: %w", err)
|
||||
}
|
||||
|
||||
if total == 0 {
|
||||
return []models.RequestLog{}, 0, nil
|
||||
}
|
||||
|
||||
// 再执行分页查询
|
||||
page, _ := strconv.Atoi(c.DefaultQuery("page", "1"))
|
||||
pageSize, _ := strconv.Atoi(c.DefaultQuery("page_size", "20"))
|
||||
offset := (page - 1) * pageSize
|
||||
|
||||
err := query.Order("request_time desc").Limit(pageSize).Offset(offset).Find(&logs).Error
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
// 分页查询
|
||||
offset := (params.Page - 1) * params.PageSize
|
||||
if err := query.Order("request_time DESC").
|
||||
Limit(params.PageSize).
|
||||
Offset(offset).
|
||||
Find(&logs).Error; err != nil {
|
||||
return nil, 0, fmt.Errorf("failed to query logs: %w", err)
|
||||
}
|
||||
|
||||
return logs, total, nil
|
||||
}
|
||||
|
||||
func (s *LogService) filtersScope(c *gin.Context) func(db *gorm.DB) *gorm.DB {
|
||||
return func(db *gorm.DB) *gorm.DB {
|
||||
if modelName := c.Query("model_name"); modelName != "" {
|
||||
db = db.Where("model_name = ?", modelName)
|
||||
}
|
||||
if isSuccessStr := c.Query("is_success"); isSuccessStr != "" {
|
||||
if isSuccess, err := strconv.ParseBool(isSuccessStr); err == nil {
|
||||
db = db.Where("is_success = ?", isSuccess)
|
||||
}
|
||||
}
|
||||
if statusCodeStr := c.Query("status_code"); statusCodeStr != "" {
|
||||
if statusCode, err := strconv.Atoi(statusCodeStr); err == nil {
|
||||
db = db.Where("status_code = ?", statusCode)
|
||||
}
|
||||
}
|
||||
if keyIDStr := c.Query("key_id"); keyIDStr != "" {
|
||||
if keyID, err := strconv.ParseUint(keyIDStr, 10, 64); err == nil {
|
||||
db = db.Where("key_id = ?", keyID)
|
||||
}
|
||||
}
|
||||
if groupIDStr := c.Query("group_id"); groupIDStr != "" {
|
||||
if groupID, err := strconv.ParseUint(groupIDStr, 10, 64); err == nil {
|
||||
db = db.Where("group_id = ?", groupID)
|
||||
}
|
||||
}
|
||||
return db
|
||||
func (s *LogService) applyFilters(query *gorm.DB, params LogQueryParams) *gorm.DB {
|
||||
if params.ModelName != "" {
|
||||
query = query.Where("model_name = ?", params.ModelName)
|
||||
}
|
||||
if params.IsSuccess != nil {
|
||||
query = query.Where("is_success = ?", *params.IsSuccess)
|
||||
}
|
||||
if params.StatusCode != nil {
|
||||
query = query.Where("status_code = ?", *params.StatusCode)
|
||||
}
|
||||
if params.KeyID != nil {
|
||||
query = query.Where("key_id = ?", *params.KeyID)
|
||||
}
|
||||
if params.GroupID != nil {
|
||||
query = query.Where("group_id = ?", *params.GroupID)
|
||||
}
|
||||
return query
|
||||
}
|
||||
|
||||
// ParseLogQueryParams 在 Handler 层调用,解析 Gin 参数
|
||||
func ParseLogQueryParams(queryParams map[string]string) (LogQueryParams, error) {
|
||||
params := LogQueryParams{
|
||||
Page: 1,
|
||||
PageSize: 20,
|
||||
}
|
||||
|
||||
if pageStr, ok := queryParams["page"]; ok {
|
||||
if page, err := strconv.Atoi(pageStr); err == nil && page > 0 {
|
||||
params.Page = page
|
||||
}
|
||||
}
|
||||
|
||||
if pageSizeStr, ok := queryParams["page_size"]; ok {
|
||||
if pageSize, err := strconv.Atoi(pageSizeStr); err == nil && pageSize > 0 {
|
||||
params.PageSize = pageSize
|
||||
}
|
||||
}
|
||||
|
||||
if modelName, ok := queryParams["model_name"]; ok {
|
||||
params.ModelName = modelName
|
||||
}
|
||||
|
||||
if isSuccessStr, ok := queryParams["is_success"]; ok {
|
||||
if isSuccess, err := strconv.ParseBool(isSuccessStr); err == nil {
|
||||
params.IsSuccess = &isSuccess
|
||||
} else {
|
||||
return params, fmt.Errorf("invalid is_success parameter: %s", isSuccessStr)
|
||||
}
|
||||
}
|
||||
|
||||
if statusCodeStr, ok := queryParams["status_code"]; ok {
|
||||
if statusCode, err := strconv.Atoi(statusCodeStr); err == nil {
|
||||
params.StatusCode = &statusCode
|
||||
} else {
|
||||
return params, fmt.Errorf("invalid status_code parameter: %s", statusCodeStr)
|
||||
}
|
||||
}
|
||||
|
||||
if keyIDStr, ok := queryParams["key_id"]; ok {
|
||||
if keyID, err := strconv.ParseUint(keyIDStr, 10, 64); err == nil {
|
||||
params.KeyID = &keyID
|
||||
} else {
|
||||
return params, fmt.Errorf("invalid key_id parameter: %s", keyIDStr)
|
||||
}
|
||||
}
|
||||
|
||||
if groupIDStr, ok := queryParams["group_id"]; ok {
|
||||
if groupID, err := strconv.ParseUint(groupIDStr, 10, 64); err == nil {
|
||||
params.GroupID = &groupID
|
||||
} else {
|
||||
return params, fmt.Errorf("invalid group_id parameter: %s", groupIDStr)
|
||||
}
|
||||
}
|
||||
|
||||
return params, nil
|
||||
}
|
||||
|
||||
@@ -35,34 +35,55 @@ func NewStatsService(db *gorm.DB, s store.Store, repo repository.KeyRepository,
|
||||
|
||||
func (s *StatsService) Start() {
|
||||
s.logger.Info("Starting event listener for stats maintenance.")
|
||||
sub, err := s.store.Subscribe(context.Background(), models.TopicKeyStatusChanged)
|
||||
if err != nil {
|
||||
s.logger.Fatalf("Failed to subscribe to topic %s: %v", models.TopicKeyStatusChanged, err)
|
||||
return
|
||||
}
|
||||
go func() {
|
||||
defer sub.Close()
|
||||
for {
|
||||
select {
|
||||
case msg := <-sub.Channel():
|
||||
var event models.KeyStatusChangedEvent
|
||||
if err := json.Unmarshal(msg.Payload, &event); err != nil {
|
||||
s.logger.Errorf("Failed to unmarshal KeyStatusChangedEvent: %v", err)
|
||||
continue
|
||||
}
|
||||
s.handleKeyStatusChange(&event)
|
||||
case <-s.stopChan:
|
||||
s.logger.Info("Stopping stats event listener.")
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
go s.listenForEvents()
|
||||
}
|
||||
|
||||
func (s *StatsService) Stop() {
|
||||
close(s.stopChan)
|
||||
}
|
||||
|
||||
func (s *StatsService) listenForEvents() {
|
||||
for {
|
||||
select {
|
||||
case <-s.stopChan:
|
||||
s.logger.Info("Stopping stats event listener.")
|
||||
return
|
||||
default:
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
sub, err := s.store.Subscribe(ctx, models.TopicKeyStatusChanged)
|
||||
if err != nil {
|
||||
s.logger.Errorf("Failed to subscribe: %v, retrying in 5s", err)
|
||||
cancel()
|
||||
time.Sleep(5 * time.Second)
|
||||
continue
|
||||
}
|
||||
|
||||
s.logger.Info("Subscribed to key status changes")
|
||||
s.handleSubscription(sub, cancel)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *StatsService) handleSubscription(sub store.Subscription, cancel context.CancelFunc) {
|
||||
defer sub.Close()
|
||||
defer cancel()
|
||||
|
||||
for {
|
||||
select {
|
||||
case msg := <-sub.Channel():
|
||||
var event models.KeyStatusChangedEvent
|
||||
if err := json.Unmarshal(msg.Payload, &event); err != nil {
|
||||
s.logger.Errorf("Failed to unmarshal event: %v", err)
|
||||
continue
|
||||
}
|
||||
s.handleKeyStatusChange(&event)
|
||||
case <-s.stopChan:
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *StatsService) handleKeyStatusChange(event *models.KeyStatusChangedEvent) {
|
||||
if event.GroupID == 0 {
|
||||
s.logger.Warnf("Received KeyStatusChangedEvent with no GroupID. Reason: %s, KeyID: %d. Skipping.", event.ChangeReason, event.KeyID)
|
||||
@@ -75,23 +96,47 @@ func (s *StatsService) handleKeyStatusChange(event *models.KeyStatusChangedEvent
|
||||
switch event.ChangeReason {
|
||||
case "key_unlinked", "key_hard_deleted":
|
||||
if event.OldStatus != "" {
|
||||
s.store.HIncrBy(ctx, statsKey, "total_keys", -1)
|
||||
s.store.HIncrBy(ctx, statsKey, fmt.Sprintf("%s_keys", event.OldStatus), -1)
|
||||
if _, err := s.store.HIncrBy(ctx, statsKey, "total_keys", -1); err != nil {
|
||||
s.logger.WithError(err).Errorf("Failed to decrement total_keys for group %d", event.GroupID)
|
||||
s.RecalculateGroupKeyStats(ctx, event.GroupID)
|
||||
return
|
||||
}
|
||||
if _, err := s.store.HIncrBy(ctx, statsKey, fmt.Sprintf("%s_keys", event.OldStatus), -1); err != nil {
|
||||
s.logger.WithError(err).Errorf("Failed to decrement %s_keys for group %d", event.OldStatus, event.GroupID)
|
||||
s.RecalculateGroupKeyStats(ctx, event.GroupID)
|
||||
return
|
||||
}
|
||||
} else {
|
||||
s.logger.Warnf("Received '%s' event for group %d without OldStatus, forcing recalculation.", event.ChangeReason, event.GroupID)
|
||||
s.RecalculateGroupKeyStats(ctx, event.GroupID)
|
||||
}
|
||||
case "key_linked":
|
||||
if event.NewStatus != "" {
|
||||
s.store.HIncrBy(ctx, statsKey, "total_keys", 1)
|
||||
s.store.HIncrBy(ctx, statsKey, fmt.Sprintf("%s_keys", event.NewStatus), 1)
|
||||
if _, err := s.store.HIncrBy(ctx, statsKey, "total_keys", 1); err != nil {
|
||||
s.logger.WithError(err).Errorf("Failed to increment total_keys for group %d", event.GroupID)
|
||||
s.RecalculateGroupKeyStats(ctx, event.GroupID)
|
||||
return
|
||||
}
|
||||
if _, err := s.store.HIncrBy(ctx, statsKey, fmt.Sprintf("%s_keys", event.NewStatus), 1); err != nil {
|
||||
s.logger.WithError(err).Errorf("Failed to increment %s_keys for group %d", event.NewStatus, event.GroupID)
|
||||
s.RecalculateGroupKeyStats(ctx, event.GroupID)
|
||||
return
|
||||
}
|
||||
} else {
|
||||
s.logger.Warnf("Received 'key_linked' event for group %d without NewStatus, forcing recalculation.", event.GroupID)
|
||||
s.RecalculateGroupKeyStats(ctx, event.GroupID)
|
||||
}
|
||||
case "manual_update", "error_threshold_reached", "key_recovered", "invalid_api_key":
|
||||
s.store.HIncrBy(ctx, statsKey, fmt.Sprintf("%s_keys", event.OldStatus), -1)
|
||||
s.store.HIncrBy(ctx, statsKey, fmt.Sprintf("%s_keys", event.NewStatus), 1)
|
||||
if _, err := s.store.HIncrBy(ctx, statsKey, fmt.Sprintf("%s_keys", event.OldStatus), -1); err != nil {
|
||||
s.logger.WithError(err).Errorf("Failed to decrement %s_keys for group %d", event.OldStatus, event.GroupID)
|
||||
s.RecalculateGroupKeyStats(ctx, event.GroupID)
|
||||
return
|
||||
}
|
||||
if _, err := s.store.HIncrBy(ctx, statsKey, fmt.Sprintf("%s_keys", event.NewStatus), 1); err != nil {
|
||||
s.logger.WithError(err).Errorf("Failed to increment %s_keys for group %d", event.NewStatus, event.GroupID)
|
||||
s.RecalculateGroupKeyStats(ctx, event.GroupID)
|
||||
return
|
||||
}
|
||||
default:
|
||||
s.logger.Warnf("Unhandled event reason '%s' for group %d, forcing recalculation.", event.ChangeReason, event.GroupID)
|
||||
s.RecalculateGroupKeyStats(ctx, event.GroupID)
|
||||
@@ -113,13 +158,16 @@ func (s *StatsService) RecalculateGroupKeyStats(ctx context.Context, groupID uin
|
||||
}
|
||||
statsKey := fmt.Sprintf("stats:group:%d", groupID)
|
||||
|
||||
updates := make(map[string]interface{})
|
||||
totalKeys := int64(0)
|
||||
updates := map[string]interface{}{
|
||||
"active_keys": int64(0),
|
||||
"disabled_keys": int64(0),
|
||||
"error_keys": int64(0),
|
||||
"total_keys": int64(0),
|
||||
}
|
||||
for _, res := range results {
|
||||
updates[fmt.Sprintf("%s_keys", res.Status)] = res.Count
|
||||
totalKeys += res.Count
|
||||
updates["total_keys"] = updates["total_keys"].(int64) + res.Count
|
||||
}
|
||||
updates["total_keys"] = totalKeys
|
||||
|
||||
if err := s.store.Del(ctx, statsKey); err != nil {
|
||||
s.logger.WithError(err).Warnf("Failed to delete stale stats key for group %d before recalculation.", groupID)
|
||||
@@ -180,8 +228,18 @@ func (s *StatsService) AggregateHourlyStats(ctx context.Context) error {
|
||||
})
|
||||
}
|
||||
|
||||
return s.db.WithContext(ctx).Clauses(clause.OnConflict{
|
||||
if err := s.db.WithContext(ctx).Clauses(clause.OnConflict{
|
||||
Columns: []clause.Column{{Name: "time"}, {Name: "group_id"}, {Name: "model_name"}},
|
||||
DoUpdates: clause.AssignmentColumns([]string{"request_count", "success_count", "prompt_tokens", "completion_tokens"}),
|
||||
}).Create(&hourlyStats).Error
|
||||
}).Create(&hourlyStats).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := s.db.WithContext(ctx).
|
||||
Where("request_time >= ? AND request_time < ?", startTime, endTime).
|
||||
Delete(&models.RequestLog{}).Error; err != nil {
|
||||
s.logger.WithError(err).Warn("Failed to delete aggregated request logs")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -37,7 +37,7 @@ func NewTokenManager(repo repository.AuthTokenRepository, store store.Store, log
|
||||
return tokens, nil
|
||||
}
|
||||
|
||||
s, err := syncer.NewCacheSyncer(tokenLoader, store, TopicTokenChanged)
|
||||
s, err := syncer.NewCacheSyncer(tokenLoader, store, TopicTokenChanged, logger)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create token manager syncer: %w", err)
|
||||
}
|
||||
|
||||
@@ -87,7 +87,7 @@ func NewSettingsManager(db *gorm.DB, store store.Store, logger *logrus.Logger) (
|
||||
return settings, nil
|
||||
}
|
||||
|
||||
s, err := syncer.NewCacheSyncer(settingsLoader, store, SettingsUpdateChannel)
|
||||
s, err := syncer.NewCacheSyncer(settingsLoader, store, SettingsUpdateChannel, logger,)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create system settings syncer: %w", err)
|
||||
}
|
||||
|
||||
@@ -4,46 +4,54 @@ import (
|
||||
"context"
|
||||
"fmt"
|
||||
"gemini-balancer/internal/store"
|
||||
"log"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
const (
|
||||
ReconnectDelay = 5 * time.Second
|
||||
ReloadTimeout = 30 * time.Second
|
||||
)
|
||||
|
||||
// LoaderFunc
|
||||
type LoaderFunc[T any] func() (T, error)
|
||||
|
||||
// CacheSyncer
|
||||
type CacheSyncer[T any] struct {
|
||||
mu sync.RWMutex
|
||||
cache T
|
||||
loader LoaderFunc[T]
|
||||
store store.Store
|
||||
channelName string
|
||||
logger *logrus.Entry
|
||||
stopChan chan struct{}
|
||||
wg sync.WaitGroup
|
||||
}
|
||||
|
||||
// NewCacheSyncer
|
||||
func NewCacheSyncer[T any](
|
||||
loader LoaderFunc[T],
|
||||
store store.Store,
|
||||
channelName string,
|
||||
logger *logrus.Logger,
|
||||
) (*CacheSyncer[T], error) {
|
||||
s := &CacheSyncer[T]{
|
||||
loader: loader,
|
||||
store: store,
|
||||
channelName: channelName,
|
||||
logger: logger.WithField("component", fmt.Sprintf("CacheSyncer[%s]", channelName)),
|
||||
stopChan: make(chan struct{}),
|
||||
}
|
||||
|
||||
if err := s.reload(); err != nil {
|
||||
return nil, fmt.Errorf("initial load for %s failed: %w", channelName, err)
|
||||
return nil, fmt.Errorf("initial load failed: %w", err)
|
||||
}
|
||||
|
||||
s.wg.Add(1)
|
||||
go s.listenForUpdates()
|
||||
|
||||
return s, nil
|
||||
}
|
||||
|
||||
// Get, Invalidate, Stop, reload 方法 .
|
||||
func (s *CacheSyncer[T]) Get() T {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
@@ -51,33 +59,60 @@ func (s *CacheSyncer[T]) Get() T {
|
||||
}
|
||||
|
||||
func (s *CacheSyncer[T]) Invalidate() error {
|
||||
log.Printf("INFO: Publishing invalidation notification on channel '%s'", s.channelName)
|
||||
return s.store.Publish(context.Background(), s.channelName, []byte("reload"))
|
||||
s.logger.Info("Publishing invalidation notification")
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
if err := s.store.Publish(ctx, s.channelName, []byte("reload")); err != nil {
|
||||
s.logger.WithError(err).Error("Failed to publish invalidation")
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *CacheSyncer[T]) Stop() {
|
||||
close(s.stopChan)
|
||||
s.wg.Wait()
|
||||
log.Printf("INFO: CacheSyncer for channel '%s' stopped.", s.channelName)
|
||||
s.logger.Info("CacheSyncer stopped")
|
||||
}
|
||||
|
||||
func (s *CacheSyncer[T]) reload() error {
|
||||
log.Printf("INFO: Reloading cache for channel '%s'...", s.channelName)
|
||||
newData, err := s.loader()
|
||||
if err != nil {
|
||||
log.Printf("ERROR: Failed to reload cache for '%s': %v", s.channelName, err)
|
||||
return err
|
||||
s.logger.Info("Reloading cache...")
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), ReloadTimeout)
|
||||
defer cancel()
|
||||
|
||||
type result struct {
|
||||
data T
|
||||
err error
|
||||
}
|
||||
resultChan := make(chan result, 1)
|
||||
|
||||
go func() {
|
||||
data, err := s.loader()
|
||||
resultChan <- result{data, err}
|
||||
}()
|
||||
|
||||
select {
|
||||
case res := <-resultChan:
|
||||
if res.err != nil {
|
||||
s.logger.WithError(res.err).Error("Failed to reload cache")
|
||||
return res.err
|
||||
}
|
||||
s.mu.Lock()
|
||||
s.cache = res.data
|
||||
s.mu.Unlock()
|
||||
s.logger.Info("Cache reloaded successfully")
|
||||
return nil
|
||||
case <-ctx.Done():
|
||||
s.logger.Error("Cache reload timeout")
|
||||
return fmt.Errorf("reload timeout after %v", ReloadTimeout)
|
||||
}
|
||||
s.mu.Lock()
|
||||
s.cache = newData
|
||||
s.mu.Unlock()
|
||||
log.Printf("INFO: Cache for channel '%s' reloaded successfully.", s.channelName)
|
||||
return nil
|
||||
}
|
||||
|
||||
// listenForUpdates ...
|
||||
func (s *CacheSyncer[T]) listenForUpdates() {
|
||||
defer s.wg.Done()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-s.stopChan:
|
||||
@@ -85,31 +120,39 @@ func (s *CacheSyncer[T]) listenForUpdates() {
|
||||
default:
|
||||
}
|
||||
|
||||
subscription, err := s.store.Subscribe(context.Background(), s.channelName)
|
||||
if err != nil {
|
||||
log.Printf("ERROR: Failed to subscribe to '%s', retrying in 5s: %v", s.channelName, err)
|
||||
time.Sleep(5 * time.Second)
|
||||
continue
|
||||
}
|
||||
log.Printf("INFO: Subscribed to channel '%s' for cache invalidation.", s.channelName)
|
||||
|
||||
subscriberLoop:
|
||||
for {
|
||||
if err := s.subscribeAndListen(); err != nil {
|
||||
s.logger.WithError(err).Warnf("Subscription error, retrying in %v", ReconnectDelay)
|
||||
select {
|
||||
case _, ok := <-subscription.Channel():
|
||||
if !ok {
|
||||
log.Printf("WARN: Subscription channel '%s' closed, will re-subscribe.", s.channelName)
|
||||
break subscriberLoop
|
||||
}
|
||||
log.Printf("INFO: Received invalidation notification on '%s', reloading cache.", s.channelName)
|
||||
if err := s.reload(); err != nil {
|
||||
log.Printf("ERROR: Failed to reload cache for '%s' after notification: %v", s.channelName, err)
|
||||
}
|
||||
case <-time.After(ReconnectDelay):
|
||||
case <-s.stopChan:
|
||||
subscription.Close()
|
||||
return
|
||||
}
|
||||
}
|
||||
subscription.Close()
|
||||
}
|
||||
}
|
||||
|
||||
func (s *CacheSyncer[T]) subscribeAndListen() error {
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
subscription, err := s.store.Subscribe(ctx, s.channelName)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to subscribe: %w", err)
|
||||
}
|
||||
defer subscription.Close()
|
||||
s.logger.Info("Subscribed to channel")
|
||||
for {
|
||||
select {
|
||||
case msg, ok := <-subscription.Channel():
|
||||
if !ok {
|
||||
return fmt.Errorf("subscription channel closed")
|
||||
}
|
||||
s.logger.WithField("message", string(msg.Payload)).Info("Received invalidation notification")
|
||||
if err := s.reload(); err != nil {
|
||||
s.logger.WithError(err).Error("Failed to reload after notification")
|
||||
}
|
||||
case <-s.stopChan:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
// Filename: internal/task/task.go
|
||||
package task
|
||||
|
||||
import (
|
||||
@@ -13,7 +12,9 @@ import (
|
||||
)
|
||||
|
||||
const (
|
||||
ResultTTL = 60 * time.Minute
|
||||
ResultTTL = 60 * time.Minute
|
||||
DefaultTimeout = 24 * time.Hour
|
||||
LockTTL = 30 * time.Minute
|
||||
)
|
||||
|
||||
type Reporter interface {
|
||||
@@ -65,14 +66,21 @@ func (s *Task) getIsRunningFlagKey(taskID string) string {
|
||||
|
||||
func (s *Task) StartTask(ctx context.Context, keyGroupID uint, taskType, resourceID string, total int, timeout time.Duration) (*Status, error) {
|
||||
lockKey := s.getResourceLockKey(resourceID)
|
||||
taskID := fmt.Sprintf("%d-%d", time.Now().UnixNano(), keyGroupID)
|
||||
|
||||
if existingTaskID, err := s.store.Get(ctx, lockKey); err == nil && len(existingTaskID) > 0 {
|
||||
locked, err := s.store.SetNX(ctx, lockKey, []byte(taskID), LockTTL)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to acquire task lock: %w", err)
|
||||
}
|
||||
if !locked {
|
||||
existingTaskID, _ := s.store.Get(ctx, lockKey)
|
||||
return nil, fmt.Errorf("a task is already running for this resource (ID: %s)", string(existingTaskID))
|
||||
}
|
||||
|
||||
taskID := fmt.Sprintf("%d-%d", time.Now().UnixNano(), keyGroupID)
|
||||
taskKey := s.getTaskDataKey(taskID)
|
||||
runningFlagKey := s.getIsRunningFlagKey(taskID)
|
||||
if timeout == 0 {
|
||||
timeout = DefaultTimeout
|
||||
}
|
||||
|
||||
status := &Status{
|
||||
ID: taskID,
|
||||
TaskType: taskType,
|
||||
@@ -81,63 +89,55 @@ func (s *Task) StartTask(ctx context.Context, keyGroupID uint, taskType, resourc
|
||||
Total: total,
|
||||
StartedAt: time.Now(),
|
||||
}
|
||||
statusBytes, err := json.Marshal(status)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to serialize new task status: %w", err)
|
||||
}
|
||||
|
||||
if timeout == 0 {
|
||||
timeout = ResultTTL * 24
|
||||
}
|
||||
|
||||
if err := s.store.Set(ctx, lockKey, []byte(taskID), timeout); err != nil {
|
||||
return nil, fmt.Errorf("failed to acquire task resource lock: %w", err)
|
||||
}
|
||||
if err := s.store.Set(ctx, taskKey, statusBytes, timeout); err != nil {
|
||||
if err := s.saveStatus(ctx, taskID, status, timeout); err != nil {
|
||||
_ = s.store.Del(ctx, lockKey)
|
||||
return nil, fmt.Errorf("failed to set new task data in store: %w", err)
|
||||
return nil, fmt.Errorf("failed to save task status: %w", err)
|
||||
}
|
||||
|
||||
runningFlagKey := s.getIsRunningFlagKey(taskID)
|
||||
if err := s.store.Set(ctx, runningFlagKey, []byte("1"), timeout); err != nil {
|
||||
_ = s.store.Del(ctx, lockKey)
|
||||
_ = s.store.Del(ctx, taskKey)
|
||||
return nil, fmt.Errorf("failed to set task running flag: %w", err)
|
||||
_ = s.store.Del(ctx, s.getTaskDataKey(taskID))
|
||||
return nil, fmt.Errorf("failed to set running flag: %w", err)
|
||||
}
|
||||
|
||||
return status, nil
|
||||
}
|
||||
|
||||
func (s *Task) EndTaskByID(ctx context.Context, taskID, resourceID string, resultData any, taskErr error) {
|
||||
lockKey := s.getResourceLockKey(resourceID)
|
||||
defer func() {
|
||||
if err := s.store.Del(ctx, lockKey); err != nil {
|
||||
s.logger.WithError(err).Warnf("Failed to release resource lock '%s' for task %s.", lockKey, taskID)
|
||||
}
|
||||
}()
|
||||
runningFlagKey := s.getIsRunningFlagKey(taskID)
|
||||
_ = s.store.Del(ctx, runningFlagKey)
|
||||
|
||||
defer func() {
|
||||
_ = s.store.Del(ctx, lockKey)
|
||||
_ = s.store.Del(ctx, runningFlagKey)
|
||||
}()
|
||||
|
||||
status, err := s.GetStatus(ctx, taskID)
|
||||
if err != nil {
|
||||
s.logger.WithError(err).Errorf("Could not get task status for task ID %s during EndTask. Lock has been released, but task data may be stale.", taskID)
|
||||
s.logger.WithError(err).Errorf("Failed to get task status for %s during EndTask", taskID)
|
||||
return
|
||||
}
|
||||
|
||||
if !status.IsRunning {
|
||||
s.logger.Warnf("EndTaskByID called for an already finished task: %s", taskID)
|
||||
s.logger.Warnf("EndTaskByID called for already finished task: %s", taskID)
|
||||
return
|
||||
}
|
||||
|
||||
now := time.Now()
|
||||
status.IsRunning = false
|
||||
status.FinishedAt = &now
|
||||
status.DurationSeconds = now.Sub(status.StartedAt).Seconds()
|
||||
|
||||
if taskErr != nil {
|
||||
status.Error = taskErr.Error()
|
||||
} else {
|
||||
status.Result = resultData
|
||||
}
|
||||
updatedTaskBytes, _ := json.Marshal(status)
|
||||
taskKey := s.getTaskDataKey(taskID)
|
||||
if err := s.store.Set(ctx, taskKey, updatedTaskBytes, ResultTTL); err != nil {
|
||||
s.logger.WithError(err).Errorf("Failed to save final status for task %s.", taskID)
|
||||
|
||||
if err := s.saveStatus(ctx, taskID, status, ResultTTL); err != nil {
|
||||
s.logger.WithError(err).Errorf("Failed to save final status for task %s", taskID)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -148,43 +148,42 @@ func (s *Task) GetStatus(ctx context.Context, taskID string) (*Status, error) {
|
||||
if errors.Is(err, store.ErrNotFound) {
|
||||
return nil, errors.New("task not found")
|
||||
}
|
||||
return nil, fmt.Errorf("failed to get task status from store: %w", err)
|
||||
return nil, fmt.Errorf("failed to get task status: %w", err)
|
||||
}
|
||||
|
||||
var status Status
|
||||
if err := json.Unmarshal(statusBytes, &status); err != nil {
|
||||
return nil, fmt.Errorf("corrupted task data in store for ID %s", taskID)
|
||||
return nil, fmt.Errorf("corrupted task data for ID %s: %w", taskID, err)
|
||||
}
|
||||
|
||||
if !status.IsRunning && status.FinishedAt != nil {
|
||||
status.DurationSeconds = status.FinishedAt.Sub(status.StartedAt).Seconds()
|
||||
}
|
||||
|
||||
return &status, nil
|
||||
}
|
||||
|
||||
func (s *Task) updateTask(ctx context.Context, taskID string, updater func(status *Status)) error {
|
||||
runningFlagKey := s.getIsRunningFlagKey(taskID)
|
||||
if _, err := s.store.Get(ctx, runningFlagKey); err != nil {
|
||||
return nil
|
||||
if errors.Is(err, store.ErrNotFound) {
|
||||
return errors.New("task is not running")
|
||||
}
|
||||
return fmt.Errorf("failed to check running flag: %w", err)
|
||||
}
|
||||
|
||||
status, err := s.GetStatus(ctx, taskID)
|
||||
if err != nil {
|
||||
s.logger.WithError(err).Warnf("Failed to get task status for update on task %s. Update will not be saved.", taskID)
|
||||
return nil
|
||||
return fmt.Errorf("failed to get task status: %w", err)
|
||||
}
|
||||
|
||||
if !status.IsRunning {
|
||||
return nil
|
||||
return errors.New("task is not running")
|
||||
}
|
||||
|
||||
updater(status)
|
||||
statusBytes, marshalErr := json.Marshal(status)
|
||||
if marshalErr != nil {
|
||||
s.logger.WithError(marshalErr).Errorf("Failed to serialize status for update on task %s.", taskID)
|
||||
return nil
|
||||
}
|
||||
taskKey := s.getTaskDataKey(taskID)
|
||||
if err := s.store.Set(ctx, taskKey, statusBytes, ResultTTL*24); err != nil {
|
||||
s.logger.WithError(err).Warnf("Failed to save update for task %s.", taskID)
|
||||
}
|
||||
return nil
|
||||
|
||||
return s.saveStatus(ctx, taskID, status, DefaultTimeout)
|
||||
}
|
||||
|
||||
func (s *Task) UpdateProgressByID(ctx context.Context, taskID string, processed int) error {
|
||||
@@ -198,3 +197,17 @@ func (s *Task) UpdateTotalByID(ctx context.Context, taskID string, total int) er
|
||||
status.Total = total
|
||||
})
|
||||
}
|
||||
|
||||
func (s *Task) saveStatus(ctx context.Context, taskID string, status *Status, ttl time.Duration) error {
|
||||
statusBytes, err := json.Marshal(status)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to serialize status: %w", err)
|
||||
}
|
||||
|
||||
taskKey := s.getTaskDataKey(taskID)
|
||||
if err := s.store.Set(ctx, taskKey, statusBytes, ttl); err != nil {
|
||||
return fmt.Errorf("failed to save status: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -1,43 +1,118 @@
|
||||
// Filename: internal/webhandlers/auth_handler.go (最终现代化改造版)
|
||||
// Filename: internal/webhandlers/auth_handler.go
|
||||
|
||||
package webhandlers
|
||||
|
||||
import (
|
||||
"gemini-balancer/internal/middleware"
|
||||
"gemini-balancer/internal/service" // [核心改造] 依赖service层
|
||||
"gemini-balancer/internal/service"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// WebAuthHandler [核心改造] 依赖关系净化,注入SecurityService
|
||||
// WebAuthHandler Web 认证处理器
|
||||
type WebAuthHandler struct {
|
||||
securityService *service.SecurityService
|
||||
logger *logrus.Logger
|
||||
}
|
||||
|
||||
// NewWebAuthHandler [核心改造] 构造函数更新
|
||||
// NewWebAuthHandler 创建 WebAuthHandler
|
||||
func NewWebAuthHandler(securityService *service.SecurityService) *WebAuthHandler {
|
||||
logger := logrus.New()
|
||||
logger.SetLevel(logrus.InfoLevel)
|
||||
|
||||
return &WebAuthHandler{
|
||||
securityService: securityService,
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
// ShowLoginPage 保持不变
|
||||
// ShowLoginPage 显示登录页面
|
||||
func (h *WebAuthHandler) ShowLoginPage(c *gin.Context) {
|
||||
errMsg := c.Query("error")
|
||||
from := c.Query("from") // 可以从登录失败的页面返回
|
||||
|
||||
// 验证重定向路径(防止开放重定向攻击)
|
||||
redirectPath := h.validateRedirectPath(c.Query("redirect"))
|
||||
|
||||
// 如果已登录,直接重定向
|
||||
if cookie := middleware.ExtractTokenFromCookie(c); cookie != "" {
|
||||
if _, err := h.securityService.AuthenticateToken(cookie); err == nil {
|
||||
c.Redirect(http.StatusFound, redirectPath)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
c.HTML(http.StatusOK, "auth.html", gin.H{
|
||||
"error": errMsg,
|
||||
"from": from,
|
||||
"error": errMsg,
|
||||
"redirect": redirectPath,
|
||||
})
|
||||
}
|
||||
|
||||
// HandleLogin [核心改造] 认证逻辑完全委托给SecurityService
|
||||
// HandleLogin 已废弃(项目无用户名系统)
|
||||
func (h *WebAuthHandler) HandleLogin(c *gin.Context) {
|
||||
c.Redirect(http.StatusFound, "/login?error=DEPRECATED_LOGIN_METHOD")
|
||||
}
|
||||
|
||||
// HandleLogout 保持不变
|
||||
// HandleLogout 处理登出请求
|
||||
func (h *WebAuthHandler) HandleLogout(c *gin.Context) {
|
||||
cookie := middleware.ExtractTokenFromCookie(c)
|
||||
|
||||
if cookie != "" {
|
||||
// 尝试获取 Token 信息用于日志
|
||||
authToken, err := h.securityService.AuthenticateToken(cookie)
|
||||
if err == nil {
|
||||
h.logger.WithFields(logrus.Fields{
|
||||
"token_id": authToken.ID,
|
||||
"client_ip": c.ClientIP(),
|
||||
}).Info("User logged out")
|
||||
} else {
|
||||
h.logger.WithField("client_ip", c.ClientIP()).Warn("Logout with invalid token")
|
||||
}
|
||||
|
||||
// 使缓存失效
|
||||
middleware.InvalidateTokenCache(cookie)
|
||||
} else {
|
||||
h.logger.WithField("client_ip", c.ClientIP()).Debug("Logout without session cookie")
|
||||
}
|
||||
|
||||
// 清除 Cookie
|
||||
middleware.ClearAdminSessionCookie(c)
|
||||
|
||||
// 重定向到登录页
|
||||
c.Redirect(http.StatusFound, "/login")
|
||||
}
|
||||
|
||||
// validateRedirectPath 验证重定向路径(防止开放重定向攻击)
|
||||
func (h *WebAuthHandler) validateRedirectPath(path string) string {
|
||||
defaultPath := "/dashboard"
|
||||
|
||||
if path == "" {
|
||||
return defaultPath
|
||||
}
|
||||
|
||||
// 只允许内部路径
|
||||
if !strings.HasPrefix(path, "/") || strings.HasPrefix(path, "//") {
|
||||
h.logger.WithField("path", path).Warn("Invalid redirect path blocked")
|
||||
return defaultPath
|
||||
}
|
||||
|
||||
// 白名单验证
|
||||
allowedPaths := []string{
|
||||
"/dashboard",
|
||||
"/keys",
|
||||
"/settings",
|
||||
"/logs",
|
||||
"/tasks",
|
||||
"/chat",
|
||||
}
|
||||
|
||||
for _, allowed := range allowedPaths {
|
||||
if strings.HasPrefix(path, allowed) {
|
||||
return path
|
||||
}
|
||||
}
|
||||
|
||||
return defaultPath
|
||||
}
|
||||
|
||||
@@ -123,7 +123,7 @@
|
||||
|
||||
{% block core_scripts %}
|
||||
<script src="https://cdnjs.cloudflare.com/ajax/libs/animejs/3.2.1/anime.min.js"></script>
|
||||
<script src="https://cdn.jsdelivr.net/npm/sweetalert2@11"></script>
|
||||
<script src="https://cdnjs.cloudflare.com/ajax/libs/sweetalert2/11.23.0/sweetalert2.all.min.js"></script>
|
||||
<script src="/static/js/main.js" type="module" defer></script>
|
||||
{% endblock core_scripts %}
|
||||
<!-- [核心] Block 2: 留给子页面的脚本扩展插槽 -->
|
||||
|
||||
@@ -492,7 +492,7 @@
|
||||
type="text"
|
||||
id="TEST_MODEL"
|
||||
name="TEST_MODEL"
|
||||
placeholder="gemini-1.5-flash"
|
||||
placeholder="gemini-2.0-flash-lite"
|
||||
class="flex-grow px-4 py-3 rounded-lg form-input-themed"
|
||||
/>
|
||||
<button
|
||||
|
||||
Reference in New Issue
Block a user