Fix Services & Update the middleware && others
This commit is contained in:
@@ -1,4 +1,5 @@
|
||||
// Filename: internal/middleware/auth.go
|
||||
|
||||
package middleware
|
||||
|
||||
import (
|
||||
@@ -7,76 +8,115 @@ import (
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// === API Admin 认证管道 (/admin/* API路由) ===
|
||||
type ErrorResponse struct {
|
||||
Error string `json:"error"`
|
||||
Code string `json:"code,omitempty"`
|
||||
}
|
||||
|
||||
func APIAdminAuthMiddleware(securityService *service.SecurityService) gin.HandlerFunc {
|
||||
// APIAdminAuthMiddleware 管理后台 API 认证
|
||||
func APIAdminAuthMiddleware(
|
||||
securityService *service.SecurityService,
|
||||
logger *logrus.Logger,
|
||||
) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
tokenValue := extractBearerToken(c)
|
||||
if tokenValue == "" {
|
||||
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "Authorization token is missing"})
|
||||
c.AbortWithStatusJSON(http.StatusUnauthorized, ErrorResponse{
|
||||
Error: "Authentication required",
|
||||
Code: "AUTH_MISSING",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// ✅ 只传 token 参数(移除 context)
|
||||
authToken, err := securityService.AuthenticateToken(tokenValue)
|
||||
if err != nil || !authToken.IsAdmin {
|
||||
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "Invalid or non-admin token"})
|
||||
if err != nil {
|
||||
logger.WithError(err).Warn("Authentication failed")
|
||||
c.AbortWithStatusJSON(http.StatusUnauthorized, ErrorResponse{
|
||||
Error: "Invalid authentication",
|
||||
Code: "AUTH_INVALID",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
if !authToken.IsAdmin {
|
||||
c.AbortWithStatusJSON(http.StatusForbidden, ErrorResponse{
|
||||
Error: "Admin access required",
|
||||
Code: "AUTH_FORBIDDEN",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
c.Set("adminUser", authToken)
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
// === /v1 Proxy 认证 ===
|
||||
|
||||
func ProxyAuthMiddleware(securityService *service.SecurityService) gin.HandlerFunc {
|
||||
// ProxyAuthMiddleware 代理请求认证
|
||||
func ProxyAuthMiddleware(
|
||||
securityService *service.SecurityService,
|
||||
logger *logrus.Logger,
|
||||
) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
tokenValue := extractProxyToken(c)
|
||||
if tokenValue == "" {
|
||||
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "API key is missing from request"})
|
||||
c.AbortWithStatusJSON(http.StatusUnauthorized, ErrorResponse{
|
||||
Error: "API key required",
|
||||
Code: "KEY_MISSING",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// ✅ 只传 token 参数(移除 context)
|
||||
authToken, err := securityService.AuthenticateToken(tokenValue)
|
||||
if err != nil {
|
||||
// 通用信息,避免泄露过多信息
|
||||
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "Invalid or inactive token provided"})
|
||||
c.AbortWithStatusJSON(http.StatusUnauthorized, ErrorResponse{
|
||||
Error: "Invalid API key",
|
||||
Code: "KEY_INVALID",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
c.Set("authToken", authToken)
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
// extractProxyToken 按优先级提取 token
|
||||
func extractProxyToken(c *gin.Context) string {
|
||||
if key := c.Query("key"); key != "" {
|
||||
return key
|
||||
}
|
||||
authHeader := c.GetHeader("Authorization")
|
||||
if authHeader != "" {
|
||||
if strings.HasPrefix(authHeader, "Bearer ") {
|
||||
return strings.TrimPrefix(authHeader, "Bearer ")
|
||||
}
|
||||
// 优先级 1: Authorization Header
|
||||
if token := extractBearerToken(c); token != "" {
|
||||
return token
|
||||
}
|
||||
|
||||
// 优先级 2: X-Api-Key
|
||||
if key := c.GetHeader("X-Api-Key"); key != "" {
|
||||
return key
|
||||
}
|
||||
|
||||
// 优先级 3: X-Goog-Api-Key
|
||||
if key := c.GetHeader("X-Goog-Api-Key"); key != "" {
|
||||
return key
|
||||
}
|
||||
return ""
|
||||
|
||||
// 优先级 4: Query 参数(不推荐)
|
||||
return c.Query("key")
|
||||
}
|
||||
|
||||
// === 辅助函数 ===
|
||||
|
||||
// extractBearerToken 提取 Bearer Token
|
||||
func extractBearerToken(c *gin.Context) string {
|
||||
authHeader := c.GetHeader("Authorization")
|
||||
if authHeader == "" {
|
||||
return ""
|
||||
}
|
||||
parts := strings.Split(authHeader, " ")
|
||||
if len(parts) == 2 && parts[0] == "Bearer" {
|
||||
return parts[1]
|
||||
|
||||
const prefix = "Bearer "
|
||||
if !strings.HasPrefix(authHeader, prefix) {
|
||||
return ""
|
||||
}
|
||||
return ""
|
||||
|
||||
return strings.TrimSpace(authHeader[len(prefix):])
|
||||
}
|
||||
|
||||
90
internal/middleware/cors.go
Normal file
90
internal/middleware/cors.go
Normal file
@@ -0,0 +1,90 @@
|
||||
// Filename: internal/middleware/cors.go
|
||||
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
type CORSConfig struct {
|
||||
AllowedOrigins []string
|
||||
AllowedMethods []string
|
||||
AllowedHeaders []string
|
||||
ExposedHeaders []string
|
||||
AllowCredentials bool
|
||||
MaxAge int
|
||||
}
|
||||
|
||||
func CORSMiddleware(config CORSConfig) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
origin := c.Request.Header.Get("Origin")
|
||||
|
||||
// 检查 origin 是否允许
|
||||
if origin != "" && isOriginAllowed(origin, config.AllowedOrigins) {
|
||||
c.Writer.Header().Set("Access-Control-Allow-Origin", origin)
|
||||
}
|
||||
|
||||
if config.AllowCredentials {
|
||||
c.Writer.Header().Set("Access-Control-Allow-Credentials", "true")
|
||||
}
|
||||
|
||||
if len(config.ExposedHeaders) > 0 {
|
||||
c.Writer.Header().Set("Access-Control-Expose-Headers",
|
||||
strings.Join(config.ExposedHeaders, ", "))
|
||||
}
|
||||
|
||||
// 处理预检请求
|
||||
if c.Request.Method == http.MethodOptions {
|
||||
if len(config.AllowedMethods) > 0 {
|
||||
c.Writer.Header().Set("Access-Control-Allow-Methods",
|
||||
strings.Join(config.AllowedMethods, ", "))
|
||||
}
|
||||
|
||||
if len(config.AllowedHeaders) > 0 {
|
||||
c.Writer.Header().Set("Access-Control-Allow-Headers",
|
||||
strings.Join(config.AllowedHeaders, ", "))
|
||||
}
|
||||
|
||||
if config.MaxAge > 0 {
|
||||
c.Writer.Header().Set("Access-Control-Max-Age",
|
||||
string(rune(config.MaxAge)))
|
||||
}
|
||||
|
||||
c.AbortWithStatus(http.StatusNoContent)
|
||||
return
|
||||
}
|
||||
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
func isOriginAllowed(origin string, allowedOrigins []string) bool {
|
||||
for _, allowed := range allowedOrigins {
|
||||
if allowed == "*" || allowed == origin {
|
||||
return true
|
||||
}
|
||||
// 支持通配符子域名
|
||||
if strings.HasPrefix(allowed, "*.") {
|
||||
domain := strings.TrimPrefix(allowed, "*.")
|
||||
if strings.HasSuffix(origin, domain) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// 使用示例
|
||||
func SetupCORS(r *gin.Engine) {
|
||||
r.Use(CORSMiddleware(CORSConfig{
|
||||
AllowedOrigins: []string{"https://yourdomain.com", "*.yourdomain.com"},
|
||||
AllowedMethods: []string{"GET", "POST", "PUT", "DELETE", "OPTIONS"},
|
||||
AllowedHeaders: []string{"Authorization", "Content-Type", "X-Api-Key"},
|
||||
ExposedHeaders: []string{"X-Request-Id"},
|
||||
AllowCredentials: true,
|
||||
MaxAge: 3600,
|
||||
}))
|
||||
}
|
||||
@@ -1,84 +1,213 @@
|
||||
// Filename: internal/middleware/log_redaction.go
|
||||
// Filename: internal/middleware/logging.go
|
||||
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"io"
|
||||
"regexp"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
const RedactedBodyKey = "redactedBody"
|
||||
const RedactedAuthHeaderKey = "redactedAuthHeader"
|
||||
const RedactedValue = `"[REDACTED]"`
|
||||
const (
|
||||
RedactedBodyKey = "redactedBody"
|
||||
RedactedAuthHeaderKey = "redactedAuthHeader"
|
||||
RedactedValue = `"[REDACTED]"`
|
||||
)
|
||||
|
||||
// 预编译正则表达式(全局变量,提升性能)
|
||||
var (
|
||||
// JSON 敏感字段脱敏
|
||||
jsonSensitiveKeys = regexp.MustCompile(`("(?i:api_key|apikey|token|password|secret|authorization|key|keys|auth)"\s*:\s*)"[^"]*"`)
|
||||
|
||||
// Bearer Token 脱敏
|
||||
bearerTokenPattern = regexp.MustCompile(`^(Bearer\s+)\S+$`)
|
||||
|
||||
// URL 中的 key 参数脱敏
|
||||
queryKeyPattern = regexp.MustCompile(`([?&](?i:key|token|apikey)=)[^&\s]+`)
|
||||
)
|
||||
|
||||
// RedactionMiddleware 请求数据脱敏中间件
|
||||
func RedactionMiddleware() gin.HandlerFunc {
|
||||
// Pre-compile regex for efficiency
|
||||
jsonKeyPattern := regexp.MustCompile(`("api_key"|"keys")\s*:\s*"[^"]*"`)
|
||||
bearerTokenPattern := regexp.MustCompile(`^(Bearer\s+)\S+$`)
|
||||
return func(c *gin.Context) {
|
||||
// --- 1. Redact Request Body ---
|
||||
if c.Request.Method == "POST" || c.Request.Method == "PUT" || c.Request.Method == "DELETE" {
|
||||
if bodyBytes, err := io.ReadAll(c.Request.Body); err == nil {
|
||||
|
||||
c.Request.Body = io.NopCloser(bytes.NewBuffer(bodyBytes))
|
||||
|
||||
bodyString := string(bodyBytes)
|
||||
|
||||
redactedBody := jsonKeyPattern.ReplaceAllString(bodyString, `$1:`+RedactedValue)
|
||||
|
||||
c.Set(RedactedBodyKey, redactedBody)
|
||||
}
|
||||
}
|
||||
// --- 2. Redact Authorization Header ---
|
||||
authHeader := c.GetHeader("Authorization")
|
||||
if authHeader != "" {
|
||||
if bearerTokenPattern.MatchString(authHeader) {
|
||||
redactedHeader := bearerTokenPattern.ReplaceAllString(authHeader, `${1}[REDACTED]`)
|
||||
c.Set(RedactedAuthHeaderKey, redactedHeader)
|
||||
}
|
||||
// 1. 脱敏请求体
|
||||
if shouldRedactBody(c) {
|
||||
redactRequestBody(c)
|
||||
}
|
||||
|
||||
// 2. 脱敏认证头
|
||||
redactAuthHeader(c)
|
||||
|
||||
// 3. 脱敏 URL 查询参数
|
||||
redactQueryParams(c)
|
||||
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
// LogrusLogger is a Gin middleware that logs requests using a Logrus logger.
|
||||
// It consumes redacted data prepared by the RedactionMiddleware.
|
||||
// shouldRedactBody 判断是否需要脱敏请求体
|
||||
func shouldRedactBody(c *gin.Context) bool {
|
||||
method := c.Request.Method
|
||||
contentType := c.GetHeader("Content-Type")
|
||||
|
||||
// 只处理包含 JSON 的 POST/PUT/PATCH/DELETE 请求
|
||||
return (method == "POST" || method == "PUT" || method == "PATCH" || method == "DELETE") &&
|
||||
strings.Contains(contentType, "application/json")
|
||||
}
|
||||
|
||||
// redactRequestBody 脱敏请求体
|
||||
func redactRequestBody(c *gin.Context) {
|
||||
bodyBytes, err := io.ReadAll(c.Request.Body)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
// 恢复请求体供后续使用
|
||||
c.Request.Body = io.NopCloser(bytes.NewBuffer(bodyBytes))
|
||||
|
||||
// 脱敏敏感字段
|
||||
bodyString := string(bodyBytes)
|
||||
redactedBody := jsonSensitiveKeys.ReplaceAllString(bodyString, `$1`+RedactedValue)
|
||||
|
||||
c.Set(RedactedBodyKey, redactedBody)
|
||||
}
|
||||
|
||||
// redactAuthHeader 脱敏认证头
|
||||
func redactAuthHeader(c *gin.Context) {
|
||||
authHeader := c.GetHeader("Authorization")
|
||||
if authHeader == "" {
|
||||
return
|
||||
}
|
||||
|
||||
if bearerTokenPattern.MatchString(authHeader) {
|
||||
redacted := bearerTokenPattern.ReplaceAllString(authHeader, `${1}[REDACTED]`)
|
||||
c.Set(RedactedAuthHeaderKey, redacted)
|
||||
} else {
|
||||
// 对于非 Bearer 的 token,全部脱敏
|
||||
c.Set(RedactedAuthHeaderKey, "[REDACTED]")
|
||||
}
|
||||
|
||||
// 同时处理其他敏感 Header
|
||||
sensitiveHeaders := []string{"X-Api-Key", "X-Goog-Api-Key", "Api-Key"}
|
||||
for _, header := range sensitiveHeaders {
|
||||
if value := c.GetHeader(header); value != "" {
|
||||
c.Set("redacted_"+header, "[REDACTED]")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// redactQueryParams 脱敏 URL 查询参数
|
||||
func redactQueryParams(c *gin.Context) {
|
||||
rawQuery := c.Request.URL.RawQuery
|
||||
if rawQuery == "" {
|
||||
return
|
||||
}
|
||||
|
||||
redacted := queryKeyPattern.ReplaceAllString(rawQuery, `${1}[REDACTED]`)
|
||||
if redacted != rawQuery {
|
||||
c.Set("redactedQuery", redacted)
|
||||
}
|
||||
}
|
||||
|
||||
// LogrusLogger Gin 请求日志中间件(使用 Logrus)
|
||||
func LogrusLogger(logger *logrus.Logger) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
start := time.Now()
|
||||
path := c.Request.URL.Path
|
||||
method := c.Request.Method
|
||||
|
||||
// Process request
|
||||
// 处理请求
|
||||
c.Next()
|
||||
|
||||
// After request, gather data and log
|
||||
// 计算延迟
|
||||
latency := time.Since(start)
|
||||
statusCode := c.Writer.Status()
|
||||
clientIP := c.ClientIP()
|
||||
|
||||
entry := logger.WithFields(logrus.Fields{
|
||||
"status_code": statusCode,
|
||||
"latency_ms": latency.Milliseconds(),
|
||||
"client_ip": c.ClientIP(),
|
||||
"method": c.Request.Method,
|
||||
"path": path,
|
||||
})
|
||||
// 构建日志字段
|
||||
fields := logrus.Fields{
|
||||
"status": statusCode,
|
||||
"latency_ms": latency.Milliseconds(),
|
||||
"ip": clientIP,
|
||||
"method": method,
|
||||
"path": path,
|
||||
}
|
||||
|
||||
// 添加请求 ID(如果存在)
|
||||
if requestID := getRequestID(c); requestID != "" {
|
||||
fields["request_id"] = requestID
|
||||
}
|
||||
|
||||
// 添加脱敏后的数据
|
||||
if redactedBody, exists := c.Get(RedactedBodyKey); exists {
|
||||
entry = entry.WithField("body", redactedBody)
|
||||
fields["body"] = redactedBody
|
||||
}
|
||||
|
||||
if redactedAuth, exists := c.Get(RedactedAuthHeaderKey); exists {
|
||||
entry = entry.WithField("authorization", redactedAuth)
|
||||
fields["authorization"] = redactedAuth
|
||||
}
|
||||
|
||||
if redactedQuery, exists := c.Get("redactedQuery"); exists {
|
||||
fields["query"] = redactedQuery
|
||||
}
|
||||
|
||||
// 添加用户信息(如果已认证)
|
||||
if user := getAuthenticatedUser(c); user != "" {
|
||||
fields["user"] = user
|
||||
}
|
||||
|
||||
// 根据状态码选择日志级别
|
||||
entry := logger.WithFields(fields)
|
||||
|
||||
if len(c.Errors) > 0 {
|
||||
entry.Error(c.Errors.String())
|
||||
fields["errors"] = c.Errors.String()
|
||||
entry.Error("Request failed")
|
||||
} else {
|
||||
entry.Info("request handled")
|
||||
switch {
|
||||
case statusCode >= 500:
|
||||
entry.Error("Server error")
|
||||
case statusCode >= 400:
|
||||
entry.Warn("Client error")
|
||||
case statusCode >= 300:
|
||||
entry.Info("Redirect")
|
||||
default:
|
||||
// 只在 Debug 模式记录成功请求
|
||||
if logger.Level >= logrus.DebugLevel {
|
||||
entry.Debug("Request completed")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// getRequestID 获取请求 ID
|
||||
func getRequestID(c *gin.Context) string {
|
||||
if id, exists := c.Get("request_id"); exists {
|
||||
if requestID, ok := id.(string); ok {
|
||||
return requestID
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// getAuthenticatedUser 获取已认证用户标识
|
||||
func getAuthenticatedUser(c *gin.Context) string {
|
||||
// 尝试从不同来源获取用户信息
|
||||
if user, exists := c.Get("adminUser"); exists {
|
||||
if authToken, ok := user.(interface{ GetID() string }); ok {
|
||||
return authToken.GetID()
|
||||
}
|
||||
}
|
||||
|
||||
if user, exists := c.Get("authToken"); exists {
|
||||
if authToken, ok := user.(interface{ GetID() string }); ok {
|
||||
return authToken.GetID()
|
||||
}
|
||||
}
|
||||
|
||||
return ""
|
||||
}
|
||||
|
||||
86
internal/middleware/rate_limit.go
Normal file
86
internal/middleware/rate_limit.go
Normal file
@@ -0,0 +1,86 @@
|
||||
// Filename: internal/middleware/rate_limit.go
|
||||
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"golang.org/x/time/rate"
|
||||
)
|
||||
|
||||
type RateLimiter struct {
|
||||
limiters map[string]*rate.Limiter
|
||||
mu sync.RWMutex
|
||||
r rate.Limit // 每秒请求数
|
||||
b int // 突发容量
|
||||
}
|
||||
|
||||
func NewRateLimiter(r rate.Limit, b int) *RateLimiter {
|
||||
return &RateLimiter{
|
||||
limiters: make(map[string]*rate.Limiter),
|
||||
r: r,
|
||||
b: b,
|
||||
}
|
||||
}
|
||||
|
||||
func (rl *RateLimiter) getLimiter(key string) *rate.Limiter {
|
||||
rl.mu.Lock()
|
||||
defer rl.mu.Unlock()
|
||||
|
||||
limiter, exists := rl.limiters[key]
|
||||
if !exists {
|
||||
limiter = rate.NewLimiter(rl.r, rl.b)
|
||||
rl.limiters[key] = limiter
|
||||
}
|
||||
|
||||
return limiter
|
||||
}
|
||||
|
||||
// 定期清理不活跃的限制器
|
||||
func (rl *RateLimiter) cleanup() {
|
||||
ticker := time.NewTicker(10 * time.Minute)
|
||||
defer ticker.Stop()
|
||||
|
||||
for range ticker.C {
|
||||
rl.mu.Lock()
|
||||
// 简单策略:定期清空(生产环境应该用 LRU)
|
||||
rl.limiters = make(map[string]*rate.Limiter)
|
||||
rl.mu.Unlock()
|
||||
}
|
||||
}
|
||||
|
||||
func RateLimitMiddleware(limiter *RateLimiter) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
// 按 IP 限流
|
||||
key := c.ClientIP()
|
||||
|
||||
// 如果有认证 token,按 token 限流(更精确)
|
||||
if authToken, exists := c.Get("authToken"); exists {
|
||||
if token, ok := authToken.(interface{ GetID() string }); ok {
|
||||
key = "token:" + token.GetID()
|
||||
}
|
||||
}
|
||||
|
||||
l := limiter.getLimiter(key)
|
||||
if !l.Allow() {
|
||||
c.AbortWithStatusJSON(http.StatusTooManyRequests, gin.H{
|
||||
"error": "Rate limit exceeded",
|
||||
"code": "RATE_LIMIT",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
// 使用示例
|
||||
func SetupRateLimit(r *gin.Engine) {
|
||||
limiter := NewRateLimiter(10, 20) // 每秒 10 个请求,突发 20
|
||||
go limiter.cleanup()
|
||||
|
||||
r.Use(RateLimitMiddleware(limiter))
|
||||
}
|
||||
39
internal/middleware/request_id.go
Normal file
39
internal/middleware/request_id.go
Normal file
@@ -0,0 +1,39 @@
|
||||
// Filename: internal/middleware/request_id.go
|
||||
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
// RequestIDMiddleware 请求 ID 追踪中间件
|
||||
func RequestIDMiddleware() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
// 1. 尝试从 Header 获取现有的 Request ID
|
||||
requestID := c.GetHeader("X-Request-Id")
|
||||
|
||||
// 2. 如果没有,生成新的
|
||||
if requestID == "" {
|
||||
requestID = uuid.New().String()
|
||||
}
|
||||
|
||||
// 3. 设置到 Context
|
||||
c.Set("request_id", requestID)
|
||||
|
||||
// 4. 返回给客户端(用于追踪)
|
||||
c.Writer.Header().Set("X-Request-Id", requestID)
|
||||
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
// GetRequestID 获取当前请求的 Request ID
|
||||
func GetRequestID(c *gin.Context) string {
|
||||
if id, exists := c.Get("request_id"); exists {
|
||||
if requestID, ok := id.(string); ok {
|
||||
return requestID
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
@@ -1,4 +1,4 @@
|
||||
// Filename: internal/middleware/security.go
|
||||
// Filename: internal/middleware/security.go (简化版)
|
||||
|
||||
package middleware
|
||||
|
||||
@@ -6,26 +6,136 @@ import (
|
||||
"gemini-balancer/internal/service"
|
||||
"gemini-balancer/internal/settings"
|
||||
"net/http"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
func IPBanMiddleware(securityService *service.SecurityService, settingsManager *settings.SettingsManager) gin.HandlerFunc {
|
||||
// 简单的缓存项
|
||||
type cacheItem struct {
|
||||
value bool
|
||||
expiration int64
|
||||
}
|
||||
|
||||
// 简单的 TTL 缓存实现
|
||||
type IPBanCache struct {
|
||||
items map[string]*cacheItem
|
||||
mu sync.RWMutex
|
||||
ttl time.Duration
|
||||
}
|
||||
|
||||
func NewIPBanCache() *IPBanCache {
|
||||
cache := &IPBanCache{
|
||||
items: make(map[string]*cacheItem),
|
||||
ttl: 1 * time.Minute,
|
||||
}
|
||||
|
||||
// 启动清理协程
|
||||
go cache.cleanup()
|
||||
|
||||
return cache
|
||||
}
|
||||
|
||||
func (c *IPBanCache) Get(key string) (bool, bool) {
|
||||
c.mu.RLock()
|
||||
defer c.mu.RUnlock()
|
||||
|
||||
item, exists := c.items[key]
|
||||
if !exists {
|
||||
return false, false
|
||||
}
|
||||
|
||||
// 检查是否过期
|
||||
if time.Now().UnixNano() > item.expiration {
|
||||
return false, false
|
||||
}
|
||||
|
||||
return item.value, true
|
||||
}
|
||||
|
||||
func (c *IPBanCache) Set(key string, value bool) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
c.items[key] = &cacheItem{
|
||||
value: value,
|
||||
expiration: time.Now().Add(c.ttl).UnixNano(),
|
||||
}
|
||||
}
|
||||
|
||||
func (c *IPBanCache) Delete(key string) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
delete(c.items, key)
|
||||
}
|
||||
|
||||
func (c *IPBanCache) cleanup() {
|
||||
ticker := time.NewTicker(5 * time.Minute)
|
||||
defer ticker.Stop()
|
||||
|
||||
for range ticker.C {
|
||||
c.mu.Lock()
|
||||
now := time.Now().UnixNano()
|
||||
for key, item := range c.items {
|
||||
if now > item.expiration {
|
||||
delete(c.items, key)
|
||||
}
|
||||
}
|
||||
c.mu.Unlock()
|
||||
}
|
||||
}
|
||||
|
||||
func IPBanMiddleware(
|
||||
securityService *service.SecurityService,
|
||||
settingsManager *settings.SettingsManager,
|
||||
banCache *IPBanCache,
|
||||
logger *logrus.Logger,
|
||||
) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
if !settingsManager.IsIPBanEnabled() {
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
|
||||
ip := c.ClientIP()
|
||||
isBanned, err := securityService.IsIPBanned(c.Request.Context(), ip)
|
||||
if err != nil {
|
||||
|
||||
// 查缓存
|
||||
if isBanned, exists := banCache.Get(ip); exists {
|
||||
if isBanned {
|
||||
logger.WithField("ip", ip).Debug("IP blocked (cached)")
|
||||
c.AbortWithStatusJSON(http.StatusForbidden, gin.H{
|
||||
"error": "Access denied",
|
||||
})
|
||||
return
|
||||
}
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
if isBanned {
|
||||
c.AbortWithStatusJSON(http.StatusForbidden, gin.H{"error": "您的IP已被暂时封禁,请稍后再试"})
|
||||
|
||||
// 查数据库
|
||||
ctx := c.Request.Context()
|
||||
isBanned, err := securityService.IsIPBanned(ctx, ip)
|
||||
if err != nil {
|
||||
logger.WithError(err).WithField("ip", ip).Error("Failed to check IP ban status")
|
||||
|
||||
// 降级策略:允许访问
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
|
||||
// 更新缓存
|
||||
banCache.Set(ip, isBanned)
|
||||
|
||||
if isBanned {
|
||||
logger.WithField("ip", ip).Info("IP blocked")
|
||||
c.AbortWithStatusJSON(http.StatusForbidden, gin.H{
|
||||
"error": "Access denied",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
52
internal/middleware/timeout.go
Normal file
52
internal/middleware/timeout.go
Normal file
@@ -0,0 +1,52 @@
|
||||
// Filename: internal/middleware/timeout.go
|
||||
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func TimeoutMiddleware(timeout time.Duration) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
// 创建带超时的 context
|
||||
ctx, cancel := context.WithTimeout(c.Request.Context(), timeout)
|
||||
defer cancel()
|
||||
|
||||
// 替换 request context
|
||||
c.Request = c.Request.WithContext(ctx)
|
||||
|
||||
// 使用 channel 等待请求完成
|
||||
finished := make(chan struct{})
|
||||
|
||||
go func() {
|
||||
c.Next()
|
||||
close(finished)
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-finished:
|
||||
// 请求正常完成
|
||||
return
|
||||
case <-ctx.Done():
|
||||
// 超时
|
||||
c.AbortWithStatusJSON(http.StatusGatewayTimeout, gin.H{
|
||||
"error": "Request timeout",
|
||||
"code": "TIMEOUT",
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 使用示例
|
||||
func SetupTimeout(r *gin.Engine) {
|
||||
// 对 API 路由设置 30 秒超时
|
||||
api := r.Group("/api")
|
||||
api.Use(TimeoutMiddleware(30 * time.Second))
|
||||
{
|
||||
// ... API routes
|
||||
}
|
||||
}
|
||||
@@ -1,23 +1,151 @@
|
||||
// Filename: internal/middleware/web.go
|
||||
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"gemini-balancer/internal/service"
|
||||
"log"
|
||||
"net/http"
|
||||
"os"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
const (
|
||||
AdminSessionCookie = "gemini_admin_session"
|
||||
SessionMaxAge = 3600 * 24 * 7 // 7天
|
||||
CacheTTL = 5 * time.Minute
|
||||
CleanupInterval = 10 * time.Minute // 降低清理频率
|
||||
SessionRefreshTime = 30 * time.Minute
|
||||
)
|
||||
|
||||
// ==================== 缓存层 ====================
|
||||
|
||||
type authCacheEntry struct {
|
||||
Token interface{}
|
||||
ExpiresAt time.Time
|
||||
}
|
||||
|
||||
type authCache struct {
|
||||
mu sync.RWMutex
|
||||
cache map[string]*authCacheEntry
|
||||
ttl time.Duration
|
||||
}
|
||||
|
||||
var webAuthCache = newAuthCache(CacheTTL)
|
||||
|
||||
func newAuthCache(ttl time.Duration) *authCache {
|
||||
c := &authCache{
|
||||
cache: make(map[string]*authCacheEntry),
|
||||
ttl: ttl,
|
||||
}
|
||||
go c.cleanupLoop()
|
||||
return c
|
||||
}
|
||||
|
||||
func (c *authCache) get(key string) (interface{}, bool) {
|
||||
c.mu.RLock()
|
||||
defer c.mu.RUnlock()
|
||||
|
||||
entry, exists := c.cache[key]
|
||||
if !exists || time.Now().After(entry.ExpiresAt) {
|
||||
return nil, false
|
||||
}
|
||||
return entry.Token, true
|
||||
}
|
||||
|
||||
func (c *authCache) set(key string, token interface{}) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
c.cache[key] = &authCacheEntry{
|
||||
Token: token,
|
||||
ExpiresAt: time.Now().Add(c.ttl),
|
||||
}
|
||||
}
|
||||
|
||||
func (c *authCache) delete(key string) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
delete(c.cache, key)
|
||||
}
|
||||
|
||||
func (c *authCache) cleanupLoop() {
|
||||
ticker := time.NewTicker(CleanupInterval)
|
||||
defer ticker.Stop()
|
||||
|
||||
for range ticker.C {
|
||||
c.cleanup()
|
||||
}
|
||||
}
|
||||
|
||||
func (c *authCache) cleanup() {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
now := time.Now()
|
||||
count := 0
|
||||
for key, entry := range c.cache {
|
||||
if now.After(entry.ExpiresAt) {
|
||||
delete(c.cache, key)
|
||||
count++
|
||||
}
|
||||
}
|
||||
|
||||
if count > 0 {
|
||||
logrus.Debugf("[AuthCache] Cleaned up %d expired entries", count)
|
||||
}
|
||||
}
|
||||
|
||||
// ==================== 会话刷新缓存 ====================
|
||||
|
||||
var sessionRefreshCache = struct {
|
||||
sync.RWMutex
|
||||
timestamps map[string]time.Time
|
||||
}{
|
||||
timestamps: make(map[string]time.Time),
|
||||
}
|
||||
|
||||
// 定期清理刷新时间戳
|
||||
func init() {
|
||||
go func() {
|
||||
ticker := time.NewTicker(1 * time.Hour)
|
||||
defer ticker.Stop()
|
||||
|
||||
for range ticker.C {
|
||||
sessionRefreshCache.Lock()
|
||||
now := time.Now()
|
||||
for key, ts := range sessionRefreshCache.timestamps {
|
||||
if now.Sub(ts) > 2*time.Hour {
|
||||
delete(sessionRefreshCache.timestamps, key)
|
||||
}
|
||||
}
|
||||
sessionRefreshCache.Unlock()
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
// ==================== Cookie 操作 ====================
|
||||
|
||||
func SetAdminSessionCookie(c *gin.Context, adminToken string) {
|
||||
c.SetCookie(AdminSessionCookie, adminToken, 3600*24*7, "/", "", false, true)
|
||||
secure := c.Request.TLS != nil || c.GetHeader("X-Forwarded-Proto") == "https"
|
||||
c.SetSameSite(http.SameSiteStrictMode)
|
||||
c.SetCookie(AdminSessionCookie, adminToken, SessionMaxAge, "/", "", secure, true)
|
||||
}
|
||||
|
||||
func SetAdminSessionCookieWithAge(c *gin.Context, adminToken string, maxAge int) {
|
||||
secure := c.Request.TLS != nil || c.GetHeader("X-Forwarded-Proto") == "https"
|
||||
c.SetSameSite(http.SameSiteStrictMode)
|
||||
c.SetCookie(AdminSessionCookie, adminToken, maxAge, "/", "", secure, true)
|
||||
}
|
||||
|
||||
func ClearAdminSessionCookie(c *gin.Context) {
|
||||
c.SetSameSite(http.SameSiteStrictMode)
|
||||
c.SetCookie(AdminSessionCookie, "", -1, "/", "", false, true)
|
||||
}
|
||||
|
||||
@@ -29,26 +157,258 @@ func ExtractTokenFromCookie(c *gin.Context) string {
|
||||
return cookie
|
||||
}
|
||||
|
||||
// ==================== 认证中间件 ====================
|
||||
|
||||
func WebAdminAuthMiddleware(authService *service.SecurityService) gin.HandlerFunc {
|
||||
logger := logrus.New()
|
||||
logger.SetLevel(getLogLevel())
|
||||
|
||||
return func(c *gin.Context) {
|
||||
cookie := ExtractTokenFromCookie(c)
|
||||
log.Printf("[WebAuth_Guard] Intercepting request for: %s", c.Request.URL.Path)
|
||||
log.Printf("[WebAuth_Guard] Found session cookie value: '%s'", cookie)
|
||||
authToken, err := authService.AuthenticateToken(cookie)
|
||||
if err != nil {
|
||||
log.Printf("[WebAuth_Guard] FATAL: AuthenticateToken FAILED. Error: %v. Redirecting to /login.", err)
|
||||
} else if !authToken.IsAdmin {
|
||||
log.Printf("[WebAuth_Guard] FATAL: Token validated, but IsAdmin is FALSE. Redirecting to /login.")
|
||||
} else {
|
||||
log.Printf("[WebAuth_Guard] SUCCESS: Token validated and IsAdmin is TRUE. Allowing access.")
|
||||
}
|
||||
if err != nil || !authToken.IsAdmin {
|
||||
if cookie == "" {
|
||||
logger.Debug("[WebAuth] No session cookie found")
|
||||
ClearAdminSessionCookie(c)
|
||||
c.Redirect(http.StatusFound, "/login")
|
||||
c.Abort()
|
||||
redirectToLogin(c)
|
||||
return
|
||||
}
|
||||
|
||||
cacheKey := hashToken(cookie)
|
||||
|
||||
if cachedToken, found := webAuthCache.get(cacheKey); found {
|
||||
logger.Debug("[WebAuth] Using cached token")
|
||||
c.Set("adminUser", cachedToken)
|
||||
refreshSessionIfNeeded(c, cookie)
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
|
||||
logger.Debug("[WebAuth] Cache miss, authenticating...")
|
||||
authToken, err := authService.AuthenticateToken(cookie)
|
||||
|
||||
if err != nil {
|
||||
logger.WithError(err).Warn("[WebAuth] Authentication failed")
|
||||
ClearAdminSessionCookie(c)
|
||||
webAuthCache.delete(cacheKey)
|
||||
redirectToLogin(c)
|
||||
return
|
||||
}
|
||||
|
||||
if !authToken.IsAdmin {
|
||||
logger.Warn("[WebAuth] User is not admin")
|
||||
ClearAdminSessionCookie(c)
|
||||
webAuthCache.delete(cacheKey)
|
||||
redirectToLogin(c)
|
||||
return
|
||||
}
|
||||
|
||||
webAuthCache.set(cacheKey, authToken)
|
||||
logger.Debug("[WebAuth] Authentication success, token cached")
|
||||
|
||||
c.Set("adminUser", authToken)
|
||||
refreshSessionIfNeeded(c, cookie)
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
func WebAdminAuthMiddlewareWithLogger(authService *service.SecurityService, logger *logrus.Logger) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
cookie := ExtractTokenFromCookie(c)
|
||||
|
||||
if cookie == "" {
|
||||
logger.Debug("No session cookie found")
|
||||
ClearAdminSessionCookie(c)
|
||||
redirectToLogin(c)
|
||||
return
|
||||
}
|
||||
|
||||
cacheKey := hashToken(cookie)
|
||||
if cachedToken, found := webAuthCache.get(cacheKey); found {
|
||||
c.Set("adminUser", cachedToken)
|
||||
refreshSessionIfNeeded(c, cookie)
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
|
||||
authToken, err := authService.AuthenticateToken(cookie)
|
||||
|
||||
if err != nil {
|
||||
logger.WithError(err).Warn("Token authentication failed")
|
||||
ClearAdminSessionCookie(c)
|
||||
webAuthCache.delete(cacheKey)
|
||||
redirectToLogin(c)
|
||||
return
|
||||
}
|
||||
|
||||
if !authToken.IsAdmin {
|
||||
logger.Warn("Token valid but user is not admin")
|
||||
ClearAdminSessionCookie(c)
|
||||
webAuthCache.delete(cacheKey)
|
||||
redirectToLogin(c)
|
||||
return
|
||||
}
|
||||
|
||||
webAuthCache.set(cacheKey, authToken)
|
||||
c.Set("adminUser", authToken)
|
||||
refreshSessionIfNeeded(c, cookie)
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
// ==================== 辅助函数 ====================
|
||||
|
||||
func hashToken(token string) string {
|
||||
h := sha256.Sum256([]byte(token))
|
||||
return hex.EncodeToString(h[:])
|
||||
}
|
||||
|
||||
func redirectToLogin(c *gin.Context) {
|
||||
if isAjaxRequest(c) {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{
|
||||
"error": "Session expired",
|
||||
"code": "AUTH_REQUIRED",
|
||||
})
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
originalPath := c.Request.URL.Path
|
||||
if originalPath != "/" && originalPath != "/login" {
|
||||
c.Redirect(http.StatusFound, "/login?redirect="+originalPath)
|
||||
} else {
|
||||
c.Redirect(http.StatusFound, "/login")
|
||||
}
|
||||
c.Abort()
|
||||
}
|
||||
|
||||
func isAjaxRequest(c *gin.Context) bool {
|
||||
// 检查 Content-Type
|
||||
contentType := c.GetHeader("Content-Type")
|
||||
if strings.Contains(contentType, "application/json") {
|
||||
return true
|
||||
}
|
||||
|
||||
// 检查 Accept(优先检查 JSON)
|
||||
accept := c.GetHeader("Accept")
|
||||
if strings.Contains(accept, "application/json") &&
|
||||
!strings.Contains(accept, "text/html") {
|
||||
return true
|
||||
}
|
||||
|
||||
// 兼容旧版 XMLHttpRequest
|
||||
return c.GetHeader("X-Requested-With") == "XMLHttpRequest"
|
||||
}
|
||||
|
||||
func refreshSessionIfNeeded(c *gin.Context, token string) {
|
||||
tokenHash := hashToken(token)
|
||||
|
||||
sessionRefreshCache.RLock()
|
||||
lastRefresh, exists := sessionRefreshCache.timestamps[tokenHash]
|
||||
sessionRefreshCache.RUnlock()
|
||||
|
||||
if !exists || time.Since(lastRefresh) > SessionRefreshTime {
|
||||
SetAdminSessionCookie(c, token)
|
||||
|
||||
sessionRefreshCache.Lock()
|
||||
sessionRefreshCache.timestamps[tokenHash] = time.Now()
|
||||
sessionRefreshCache.Unlock()
|
||||
}
|
||||
}
|
||||
|
||||
func getLogLevel() logrus.Level {
|
||||
level := os.Getenv("LOG_LEVEL")
|
||||
switch strings.ToLower(level) {
|
||||
case "debug":
|
||||
return logrus.DebugLevel
|
||||
case "warn":
|
||||
return logrus.WarnLevel
|
||||
case "error":
|
||||
return logrus.ErrorLevel
|
||||
default:
|
||||
return logrus.InfoLevel
|
||||
}
|
||||
}
|
||||
|
||||
// ==================== 工具函数 ====================
|
||||
|
||||
func GetAdminUserFromContext(c *gin.Context) (interface{}, bool) {
|
||||
return c.Get("adminUser")
|
||||
}
|
||||
|
||||
func InvalidateTokenCache(token string) {
|
||||
tokenHash := hashToken(token)
|
||||
webAuthCache.delete(tokenHash)
|
||||
|
||||
// 同时清理刷新时间戳
|
||||
sessionRefreshCache.Lock()
|
||||
delete(sessionRefreshCache.timestamps, tokenHash)
|
||||
sessionRefreshCache.Unlock()
|
||||
}
|
||||
|
||||
func ClearAllAuthCache() {
|
||||
webAuthCache.mu.Lock()
|
||||
webAuthCache.cache = make(map[string]*authCacheEntry)
|
||||
webAuthCache.mu.Unlock()
|
||||
|
||||
sessionRefreshCache.Lock()
|
||||
sessionRefreshCache.timestamps = make(map[string]time.Time)
|
||||
sessionRefreshCache.Unlock()
|
||||
}
|
||||
|
||||
// ==================== 调试工具 ====================
|
||||
|
||||
type SessionInfo struct {
|
||||
HasCookie bool `json:"has_cookie"`
|
||||
IsValid bool `json:"is_valid"`
|
||||
IsAdmin bool `json:"is_admin"`
|
||||
IsCached bool `json:"is_cached"`
|
||||
LastActivity string `json:"last_activity"`
|
||||
}
|
||||
|
||||
func GetSessionInfo(c *gin.Context, authService *service.SecurityService) SessionInfo {
|
||||
info := SessionInfo{
|
||||
HasCookie: false,
|
||||
IsValid: false,
|
||||
IsAdmin: false,
|
||||
IsCached: false,
|
||||
LastActivity: time.Now().Format(time.RFC3339),
|
||||
}
|
||||
|
||||
cookie := ExtractTokenFromCookie(c)
|
||||
if cookie == "" {
|
||||
return info
|
||||
}
|
||||
|
||||
info.HasCookie = true
|
||||
|
||||
cacheKey := hashToken(cookie)
|
||||
if _, found := webAuthCache.get(cacheKey); found {
|
||||
info.IsCached = true
|
||||
}
|
||||
|
||||
authToken, err := authService.AuthenticateToken(cookie)
|
||||
if err != nil {
|
||||
return info
|
||||
}
|
||||
|
||||
info.IsValid = true
|
||||
info.IsAdmin = authToken.IsAdmin
|
||||
|
||||
return info
|
||||
}
|
||||
|
||||
func GetCacheStats() map[string]interface{} {
|
||||
webAuthCache.mu.RLock()
|
||||
cacheSize := len(webAuthCache.cache)
|
||||
webAuthCache.mu.RUnlock()
|
||||
|
||||
sessionRefreshCache.RLock()
|
||||
refreshSize := len(sessionRefreshCache.timestamps)
|
||||
sessionRefreshCache.RUnlock()
|
||||
return map[string]interface{}{
|
||||
"auth_cache_entries": cacheSize,
|
||||
"refresh_cache_entries": refreshSize,
|
||||
"ttl_seconds": int(webAuthCache.ttl.Seconds()),
|
||||
"cleanup_interval": int(CleanupInterval.Seconds()),
|
||||
"session_refresh_time": int(SessionRefreshTime.Seconds()),
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user