Fix Services & Update the middleware && others

This commit is contained in:
XOF
2025-11-24 04:48:07 +08:00
parent 3a95a07e8a
commit f2706d6fc8
37 changed files with 4458 additions and 1166 deletions

View File

@@ -1,4 +1,5 @@
// Filename: internal/middleware/auth.go
package middleware
import (
@@ -7,76 +8,115 @@ import (
"strings"
"github.com/gin-gonic/gin"
"github.com/sirupsen/logrus"
)
// === API Admin 认证管道 (/admin/* API路由) ===
type ErrorResponse struct {
Error string `json:"error"`
Code string `json:"code,omitempty"`
}
func APIAdminAuthMiddleware(securityService *service.SecurityService) gin.HandlerFunc {
// APIAdminAuthMiddleware 管理后台 API 认证
func APIAdminAuthMiddleware(
securityService *service.SecurityService,
logger *logrus.Logger,
) gin.HandlerFunc {
return func(c *gin.Context) {
tokenValue := extractBearerToken(c)
if tokenValue == "" {
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "Authorization token is missing"})
c.AbortWithStatusJSON(http.StatusUnauthorized, ErrorResponse{
Error: "Authentication required",
Code: "AUTH_MISSING",
})
return
}
// ✅ 只传 token 参数(移除 context
authToken, err := securityService.AuthenticateToken(tokenValue)
if err != nil || !authToken.IsAdmin {
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "Invalid or non-admin token"})
if err != nil {
logger.WithError(err).Warn("Authentication failed")
c.AbortWithStatusJSON(http.StatusUnauthorized, ErrorResponse{
Error: "Invalid authentication",
Code: "AUTH_INVALID",
})
return
}
if !authToken.IsAdmin {
c.AbortWithStatusJSON(http.StatusForbidden, ErrorResponse{
Error: "Admin access required",
Code: "AUTH_FORBIDDEN",
})
return
}
c.Set("adminUser", authToken)
c.Next()
}
}
// === /v1 Proxy 认证 ===
func ProxyAuthMiddleware(securityService *service.SecurityService) gin.HandlerFunc {
// ProxyAuthMiddleware 代理请求认证
func ProxyAuthMiddleware(
securityService *service.SecurityService,
logger *logrus.Logger,
) gin.HandlerFunc {
return func(c *gin.Context) {
tokenValue := extractProxyToken(c)
if tokenValue == "" {
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "API key is missing from request"})
c.AbortWithStatusJSON(http.StatusUnauthorized, ErrorResponse{
Error: "API key required",
Code: "KEY_MISSING",
})
return
}
// ✅ 只传 token 参数(移除 context
authToken, err := securityService.AuthenticateToken(tokenValue)
if err != nil {
// 通用信息,避免泄露过多信息
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "Invalid or inactive token provided"})
c.AbortWithStatusJSON(http.StatusUnauthorized, ErrorResponse{
Error: "Invalid API key",
Code: "KEY_INVALID",
})
return
}
c.Set("authToken", authToken)
c.Next()
}
}
// extractProxyToken 按优先级提取 token
func extractProxyToken(c *gin.Context) string {
if key := c.Query("key"); key != "" {
return key
}
authHeader := c.GetHeader("Authorization")
if authHeader != "" {
if strings.HasPrefix(authHeader, "Bearer ") {
return strings.TrimPrefix(authHeader, "Bearer ")
}
// 优先级 1: Authorization Header
if token := extractBearerToken(c); token != "" {
return token
}
// 优先级 2: X-Api-Key
if key := c.GetHeader("X-Api-Key"); key != "" {
return key
}
// 优先级 3: X-Goog-Api-Key
if key := c.GetHeader("X-Goog-Api-Key"); key != "" {
return key
}
return ""
// 优先级 4: Query 参数(不推荐)
return c.Query("key")
}
// === 辅助函数 ===
// extractBearerToken 提取 Bearer Token
func extractBearerToken(c *gin.Context) string {
authHeader := c.GetHeader("Authorization")
if authHeader == "" {
return ""
}
parts := strings.Split(authHeader, " ")
if len(parts) == 2 && parts[0] == "Bearer" {
return parts[1]
const prefix = "Bearer "
if !strings.HasPrefix(authHeader, prefix) {
return ""
}
return ""
return strings.TrimSpace(authHeader[len(prefix):])
}

View File

@@ -0,0 +1,90 @@
// Filename: internal/middleware/cors.go
package middleware
import (
"net/http"
"strings"
"github.com/gin-gonic/gin"
)
type CORSConfig struct {
AllowedOrigins []string
AllowedMethods []string
AllowedHeaders []string
ExposedHeaders []string
AllowCredentials bool
MaxAge int
}
func CORSMiddleware(config CORSConfig) gin.HandlerFunc {
return func(c *gin.Context) {
origin := c.Request.Header.Get("Origin")
// 检查 origin 是否允许
if origin != "" && isOriginAllowed(origin, config.AllowedOrigins) {
c.Writer.Header().Set("Access-Control-Allow-Origin", origin)
}
if config.AllowCredentials {
c.Writer.Header().Set("Access-Control-Allow-Credentials", "true")
}
if len(config.ExposedHeaders) > 0 {
c.Writer.Header().Set("Access-Control-Expose-Headers",
strings.Join(config.ExposedHeaders, ", "))
}
// 处理预检请求
if c.Request.Method == http.MethodOptions {
if len(config.AllowedMethods) > 0 {
c.Writer.Header().Set("Access-Control-Allow-Methods",
strings.Join(config.AllowedMethods, ", "))
}
if len(config.AllowedHeaders) > 0 {
c.Writer.Header().Set("Access-Control-Allow-Headers",
strings.Join(config.AllowedHeaders, ", "))
}
if config.MaxAge > 0 {
c.Writer.Header().Set("Access-Control-Max-Age",
string(rune(config.MaxAge)))
}
c.AbortWithStatus(http.StatusNoContent)
return
}
c.Next()
}
}
func isOriginAllowed(origin string, allowedOrigins []string) bool {
for _, allowed := range allowedOrigins {
if allowed == "*" || allowed == origin {
return true
}
// 支持通配符子域名
if strings.HasPrefix(allowed, "*.") {
domain := strings.TrimPrefix(allowed, "*.")
if strings.HasSuffix(origin, domain) {
return true
}
}
}
return false
}
// 使用示例
func SetupCORS(r *gin.Engine) {
r.Use(CORSMiddleware(CORSConfig{
AllowedOrigins: []string{"https://yourdomain.com", "*.yourdomain.com"},
AllowedMethods: []string{"GET", "POST", "PUT", "DELETE", "OPTIONS"},
AllowedHeaders: []string{"Authorization", "Content-Type", "X-Api-Key"},
ExposedHeaders: []string{"X-Request-Id"},
AllowCredentials: true,
MaxAge: 3600,
}))
}

View File

@@ -1,84 +1,213 @@
// Filename: internal/middleware/log_redaction.go
// Filename: internal/middleware/logging.go
package middleware
import (
"bytes"
"io"
"regexp"
"strings"
"time"
"github.com/gin-gonic/gin"
"github.com/sirupsen/logrus"
)
const RedactedBodyKey = "redactedBody"
const RedactedAuthHeaderKey = "redactedAuthHeader"
const RedactedValue = `"[REDACTED]"`
const (
RedactedBodyKey = "redactedBody"
RedactedAuthHeaderKey = "redactedAuthHeader"
RedactedValue = `"[REDACTED]"`
)
// 预编译正则表达式(全局变量,提升性能)
var (
// JSON 敏感字段脱敏
jsonSensitiveKeys = regexp.MustCompile(`("(?i:api_key|apikey|token|password|secret|authorization|key|keys|auth)"\s*:\s*)"[^"]*"`)
// Bearer Token 脱敏
bearerTokenPattern = regexp.MustCompile(`^(Bearer\s+)\S+$`)
// URL 中的 key 参数脱敏
queryKeyPattern = regexp.MustCompile(`([?&](?i:key|token|apikey)=)[^&\s]+`)
)
// RedactionMiddleware 请求数据脱敏中间件
func RedactionMiddleware() gin.HandlerFunc {
// Pre-compile regex for efficiency
jsonKeyPattern := regexp.MustCompile(`("api_key"|"keys")\s*:\s*"[^"]*"`)
bearerTokenPattern := regexp.MustCompile(`^(Bearer\s+)\S+$`)
return func(c *gin.Context) {
// --- 1. Redact Request Body ---
if c.Request.Method == "POST" || c.Request.Method == "PUT" || c.Request.Method == "DELETE" {
if bodyBytes, err := io.ReadAll(c.Request.Body); err == nil {
c.Request.Body = io.NopCloser(bytes.NewBuffer(bodyBytes))
bodyString := string(bodyBytes)
redactedBody := jsonKeyPattern.ReplaceAllString(bodyString, `$1:`+RedactedValue)
c.Set(RedactedBodyKey, redactedBody)
}
}
// --- 2. Redact Authorization Header ---
authHeader := c.GetHeader("Authorization")
if authHeader != "" {
if bearerTokenPattern.MatchString(authHeader) {
redactedHeader := bearerTokenPattern.ReplaceAllString(authHeader, `${1}[REDACTED]`)
c.Set(RedactedAuthHeaderKey, redactedHeader)
}
// 1. 脱敏请求体
if shouldRedactBody(c) {
redactRequestBody(c)
}
// 2. 脱敏认证头
redactAuthHeader(c)
// 3. 脱敏 URL 查询参数
redactQueryParams(c)
c.Next()
}
}
// LogrusLogger is a Gin middleware that logs requests using a Logrus logger.
// It consumes redacted data prepared by the RedactionMiddleware.
// shouldRedactBody 判断是否需要脱敏请求体
func shouldRedactBody(c *gin.Context) bool {
method := c.Request.Method
contentType := c.GetHeader("Content-Type")
// 只处理包含 JSON 的 POST/PUT/PATCH/DELETE 请求
return (method == "POST" || method == "PUT" || method == "PATCH" || method == "DELETE") &&
strings.Contains(contentType, "application/json")
}
// redactRequestBody 脱敏请求体
func redactRequestBody(c *gin.Context) {
bodyBytes, err := io.ReadAll(c.Request.Body)
if err != nil {
return
}
// 恢复请求体供后续使用
c.Request.Body = io.NopCloser(bytes.NewBuffer(bodyBytes))
// 脱敏敏感字段
bodyString := string(bodyBytes)
redactedBody := jsonSensitiveKeys.ReplaceAllString(bodyString, `$1`+RedactedValue)
c.Set(RedactedBodyKey, redactedBody)
}
// redactAuthHeader 脱敏认证头
func redactAuthHeader(c *gin.Context) {
authHeader := c.GetHeader("Authorization")
if authHeader == "" {
return
}
if bearerTokenPattern.MatchString(authHeader) {
redacted := bearerTokenPattern.ReplaceAllString(authHeader, `${1}[REDACTED]`)
c.Set(RedactedAuthHeaderKey, redacted)
} else {
// 对于非 Bearer 的 token全部脱敏
c.Set(RedactedAuthHeaderKey, "[REDACTED]")
}
// 同时处理其他敏感 Header
sensitiveHeaders := []string{"X-Api-Key", "X-Goog-Api-Key", "Api-Key"}
for _, header := range sensitiveHeaders {
if value := c.GetHeader(header); value != "" {
c.Set("redacted_"+header, "[REDACTED]")
}
}
}
// redactQueryParams 脱敏 URL 查询参数
func redactQueryParams(c *gin.Context) {
rawQuery := c.Request.URL.RawQuery
if rawQuery == "" {
return
}
redacted := queryKeyPattern.ReplaceAllString(rawQuery, `${1}[REDACTED]`)
if redacted != rawQuery {
c.Set("redactedQuery", redacted)
}
}
// LogrusLogger Gin 请求日志中间件(使用 Logrus
func LogrusLogger(logger *logrus.Logger) gin.HandlerFunc {
return func(c *gin.Context) {
start := time.Now()
path := c.Request.URL.Path
method := c.Request.Method
// Process request
// 处理请求
c.Next()
// After request, gather data and log
// 计算延迟
latency := time.Since(start)
statusCode := c.Writer.Status()
clientIP := c.ClientIP()
entry := logger.WithFields(logrus.Fields{
"status_code": statusCode,
"latency_ms": latency.Milliseconds(),
"client_ip": c.ClientIP(),
"method": c.Request.Method,
"path": path,
})
// 构建日志字段
fields := logrus.Fields{
"status": statusCode,
"latency_ms": latency.Milliseconds(),
"ip": clientIP,
"method": method,
"path": path,
}
// 添加请求 ID如果存在
if requestID := getRequestID(c); requestID != "" {
fields["request_id"] = requestID
}
// 添加脱敏后的数据
if redactedBody, exists := c.Get(RedactedBodyKey); exists {
entry = entry.WithField("body", redactedBody)
fields["body"] = redactedBody
}
if redactedAuth, exists := c.Get(RedactedAuthHeaderKey); exists {
entry = entry.WithField("authorization", redactedAuth)
fields["authorization"] = redactedAuth
}
if redactedQuery, exists := c.Get("redactedQuery"); exists {
fields["query"] = redactedQuery
}
// 添加用户信息(如果已认证)
if user := getAuthenticatedUser(c); user != "" {
fields["user"] = user
}
// 根据状态码选择日志级别
entry := logger.WithFields(fields)
if len(c.Errors) > 0 {
entry.Error(c.Errors.String())
fields["errors"] = c.Errors.String()
entry.Error("Request failed")
} else {
entry.Info("request handled")
switch {
case statusCode >= 500:
entry.Error("Server error")
case statusCode >= 400:
entry.Warn("Client error")
case statusCode >= 300:
entry.Info("Redirect")
default:
// 只在 Debug 模式记录成功请求
if logger.Level >= logrus.DebugLevel {
entry.Debug("Request completed")
}
}
}
}
}
// getRequestID 获取请求 ID
func getRequestID(c *gin.Context) string {
if id, exists := c.Get("request_id"); exists {
if requestID, ok := id.(string); ok {
return requestID
}
}
return ""
}
// getAuthenticatedUser 获取已认证用户标识
func getAuthenticatedUser(c *gin.Context) string {
// 尝试从不同来源获取用户信息
if user, exists := c.Get("adminUser"); exists {
if authToken, ok := user.(interface{ GetID() string }); ok {
return authToken.GetID()
}
}
if user, exists := c.Get("authToken"); exists {
if authToken, ok := user.(interface{ GetID() string }); ok {
return authToken.GetID()
}
}
return ""
}

View File

@@ -0,0 +1,86 @@
// Filename: internal/middleware/rate_limit.go
package middleware
import (
"net/http"
"sync"
"time"
"github.com/gin-gonic/gin"
"golang.org/x/time/rate"
)
type RateLimiter struct {
limiters map[string]*rate.Limiter
mu sync.RWMutex
r rate.Limit // 每秒请求数
b int // 突发容量
}
func NewRateLimiter(r rate.Limit, b int) *RateLimiter {
return &RateLimiter{
limiters: make(map[string]*rate.Limiter),
r: r,
b: b,
}
}
func (rl *RateLimiter) getLimiter(key string) *rate.Limiter {
rl.mu.Lock()
defer rl.mu.Unlock()
limiter, exists := rl.limiters[key]
if !exists {
limiter = rate.NewLimiter(rl.r, rl.b)
rl.limiters[key] = limiter
}
return limiter
}
// 定期清理不活跃的限制器
func (rl *RateLimiter) cleanup() {
ticker := time.NewTicker(10 * time.Minute)
defer ticker.Stop()
for range ticker.C {
rl.mu.Lock()
// 简单策略:定期清空(生产环境应该用 LRU
rl.limiters = make(map[string]*rate.Limiter)
rl.mu.Unlock()
}
}
func RateLimitMiddleware(limiter *RateLimiter) gin.HandlerFunc {
return func(c *gin.Context) {
// 按 IP 限流
key := c.ClientIP()
// 如果有认证 token按 token 限流(更精确)
if authToken, exists := c.Get("authToken"); exists {
if token, ok := authToken.(interface{ GetID() string }); ok {
key = "token:" + token.GetID()
}
}
l := limiter.getLimiter(key)
if !l.Allow() {
c.AbortWithStatusJSON(http.StatusTooManyRequests, gin.H{
"error": "Rate limit exceeded",
"code": "RATE_LIMIT",
})
return
}
c.Next()
}
}
// 使用示例
func SetupRateLimit(r *gin.Engine) {
limiter := NewRateLimiter(10, 20) // 每秒 10 个请求,突发 20
go limiter.cleanup()
r.Use(RateLimitMiddleware(limiter))
}

View File

@@ -0,0 +1,39 @@
// Filename: internal/middleware/request_id.go
package middleware
import (
"github.com/gin-gonic/gin"
"github.com/google/uuid"
)
// RequestIDMiddleware 请求 ID 追踪中间件
func RequestIDMiddleware() gin.HandlerFunc {
return func(c *gin.Context) {
// 1. 尝试从 Header 获取现有的 Request ID
requestID := c.GetHeader("X-Request-Id")
// 2. 如果没有,生成新的
if requestID == "" {
requestID = uuid.New().String()
}
// 3. 设置到 Context
c.Set("request_id", requestID)
// 4. 返回给客户端(用于追踪)
c.Writer.Header().Set("X-Request-Id", requestID)
c.Next()
}
}
// GetRequestID 获取当前请求的 Request ID
func GetRequestID(c *gin.Context) string {
if id, exists := c.Get("request_id"); exists {
if requestID, ok := id.(string); ok {
return requestID
}
}
return ""
}

View File

@@ -1,4 +1,4 @@
// Filename: internal/middleware/security.go
// Filename: internal/middleware/security.go (简化版)
package middleware
@@ -6,26 +6,136 @@ import (
"gemini-balancer/internal/service"
"gemini-balancer/internal/settings"
"net/http"
"sync"
"time"
"github.com/gin-gonic/gin"
"github.com/sirupsen/logrus"
)
func IPBanMiddleware(securityService *service.SecurityService, settingsManager *settings.SettingsManager) gin.HandlerFunc {
// 简单的缓存项
type cacheItem struct {
value bool
expiration int64
}
// 简单的 TTL 缓存实现
type IPBanCache struct {
items map[string]*cacheItem
mu sync.RWMutex
ttl time.Duration
}
func NewIPBanCache() *IPBanCache {
cache := &IPBanCache{
items: make(map[string]*cacheItem),
ttl: 1 * time.Minute,
}
// 启动清理协程
go cache.cleanup()
return cache
}
func (c *IPBanCache) Get(key string) (bool, bool) {
c.mu.RLock()
defer c.mu.RUnlock()
item, exists := c.items[key]
if !exists {
return false, false
}
// 检查是否过期
if time.Now().UnixNano() > item.expiration {
return false, false
}
return item.value, true
}
func (c *IPBanCache) Set(key string, value bool) {
c.mu.Lock()
defer c.mu.Unlock()
c.items[key] = &cacheItem{
value: value,
expiration: time.Now().Add(c.ttl).UnixNano(),
}
}
func (c *IPBanCache) Delete(key string) {
c.mu.Lock()
defer c.mu.Unlock()
delete(c.items, key)
}
func (c *IPBanCache) cleanup() {
ticker := time.NewTicker(5 * time.Minute)
defer ticker.Stop()
for range ticker.C {
c.mu.Lock()
now := time.Now().UnixNano()
for key, item := range c.items {
if now > item.expiration {
delete(c.items, key)
}
}
c.mu.Unlock()
}
}
func IPBanMiddleware(
securityService *service.SecurityService,
settingsManager *settings.SettingsManager,
banCache *IPBanCache,
logger *logrus.Logger,
) gin.HandlerFunc {
return func(c *gin.Context) {
if !settingsManager.IsIPBanEnabled() {
c.Next()
return
}
ip := c.ClientIP()
isBanned, err := securityService.IsIPBanned(c.Request.Context(), ip)
if err != nil {
// 查缓存
if isBanned, exists := banCache.Get(ip); exists {
if isBanned {
logger.WithField("ip", ip).Debug("IP blocked (cached)")
c.AbortWithStatusJSON(http.StatusForbidden, gin.H{
"error": "Access denied",
})
return
}
c.Next()
return
}
if isBanned {
c.AbortWithStatusJSON(http.StatusForbidden, gin.H{"error": "您的IP已被暂时封禁请稍后再试"})
// 查数据库
ctx := c.Request.Context()
isBanned, err := securityService.IsIPBanned(ctx, ip)
if err != nil {
logger.WithError(err).WithField("ip", ip).Error("Failed to check IP ban status")
// 降级策略:允许访问
c.Next()
return
}
// 更新缓存
banCache.Set(ip, isBanned)
if isBanned {
logger.WithField("ip", ip).Info("IP blocked")
c.AbortWithStatusJSON(http.StatusForbidden, gin.H{
"error": "Access denied",
})
return
}
c.Next()
}
}

View File

@@ -0,0 +1,52 @@
// Filename: internal/middleware/timeout.go
package middleware
import (
"context"
"net/http"
"time"
"github.com/gin-gonic/gin"
)
func TimeoutMiddleware(timeout time.Duration) gin.HandlerFunc {
return func(c *gin.Context) {
// 创建带超时的 context
ctx, cancel := context.WithTimeout(c.Request.Context(), timeout)
defer cancel()
// 替换 request context
c.Request = c.Request.WithContext(ctx)
// 使用 channel 等待请求完成
finished := make(chan struct{})
go func() {
c.Next()
close(finished)
}()
select {
case <-finished:
// 请求正常完成
return
case <-ctx.Done():
// 超时
c.AbortWithStatusJSON(http.StatusGatewayTimeout, gin.H{
"error": "Request timeout",
"code": "TIMEOUT",
})
}
}
}
// 使用示例
func SetupTimeout(r *gin.Engine) {
// 对 API 路由设置 30 秒超时
api := r.Group("/api")
api.Use(TimeoutMiddleware(30 * time.Second))
{
// ... API routes
}
}

View File

@@ -1,23 +1,151 @@
// Filename: internal/middleware/web.go
package middleware
import (
"crypto/sha256"
"encoding/hex"
"gemini-balancer/internal/service"
"log"
"net/http"
"os"
"strings"
"sync"
"time"
"github.com/gin-gonic/gin"
"github.com/sirupsen/logrus"
)
const (
AdminSessionCookie = "gemini_admin_session"
SessionMaxAge = 3600 * 24 * 7 // 7天
CacheTTL = 5 * time.Minute
CleanupInterval = 10 * time.Minute // 降低清理频率
SessionRefreshTime = 30 * time.Minute
)
// ==================== 缓存层 ====================
type authCacheEntry struct {
Token interface{}
ExpiresAt time.Time
}
type authCache struct {
mu sync.RWMutex
cache map[string]*authCacheEntry
ttl time.Duration
}
var webAuthCache = newAuthCache(CacheTTL)
func newAuthCache(ttl time.Duration) *authCache {
c := &authCache{
cache: make(map[string]*authCacheEntry),
ttl: ttl,
}
go c.cleanupLoop()
return c
}
func (c *authCache) get(key string) (interface{}, bool) {
c.mu.RLock()
defer c.mu.RUnlock()
entry, exists := c.cache[key]
if !exists || time.Now().After(entry.ExpiresAt) {
return nil, false
}
return entry.Token, true
}
func (c *authCache) set(key string, token interface{}) {
c.mu.Lock()
defer c.mu.Unlock()
c.cache[key] = &authCacheEntry{
Token: token,
ExpiresAt: time.Now().Add(c.ttl),
}
}
func (c *authCache) delete(key string) {
c.mu.Lock()
defer c.mu.Unlock()
delete(c.cache, key)
}
func (c *authCache) cleanupLoop() {
ticker := time.NewTicker(CleanupInterval)
defer ticker.Stop()
for range ticker.C {
c.cleanup()
}
}
func (c *authCache) cleanup() {
c.mu.Lock()
defer c.mu.Unlock()
now := time.Now()
count := 0
for key, entry := range c.cache {
if now.After(entry.ExpiresAt) {
delete(c.cache, key)
count++
}
}
if count > 0 {
logrus.Debugf("[AuthCache] Cleaned up %d expired entries", count)
}
}
// ==================== 会话刷新缓存 ====================
var sessionRefreshCache = struct {
sync.RWMutex
timestamps map[string]time.Time
}{
timestamps: make(map[string]time.Time),
}
// 定期清理刷新时间戳
func init() {
go func() {
ticker := time.NewTicker(1 * time.Hour)
defer ticker.Stop()
for range ticker.C {
sessionRefreshCache.Lock()
now := time.Now()
for key, ts := range sessionRefreshCache.timestamps {
if now.Sub(ts) > 2*time.Hour {
delete(sessionRefreshCache.timestamps, key)
}
}
sessionRefreshCache.Unlock()
}
}()
}
// ==================== Cookie 操作 ====================
func SetAdminSessionCookie(c *gin.Context, adminToken string) {
c.SetCookie(AdminSessionCookie, adminToken, 3600*24*7, "/", "", false, true)
secure := c.Request.TLS != nil || c.GetHeader("X-Forwarded-Proto") == "https"
c.SetSameSite(http.SameSiteStrictMode)
c.SetCookie(AdminSessionCookie, adminToken, SessionMaxAge, "/", "", secure, true)
}
func SetAdminSessionCookieWithAge(c *gin.Context, adminToken string, maxAge int) {
secure := c.Request.TLS != nil || c.GetHeader("X-Forwarded-Proto") == "https"
c.SetSameSite(http.SameSiteStrictMode)
c.SetCookie(AdminSessionCookie, adminToken, maxAge, "/", "", secure, true)
}
func ClearAdminSessionCookie(c *gin.Context) {
c.SetSameSite(http.SameSiteStrictMode)
c.SetCookie(AdminSessionCookie, "", -1, "/", "", false, true)
}
@@ -29,26 +157,258 @@ func ExtractTokenFromCookie(c *gin.Context) string {
return cookie
}
// ==================== 认证中间件 ====================
func WebAdminAuthMiddleware(authService *service.SecurityService) gin.HandlerFunc {
logger := logrus.New()
logger.SetLevel(getLogLevel())
return func(c *gin.Context) {
cookie := ExtractTokenFromCookie(c)
log.Printf("[WebAuth_Guard] Intercepting request for: %s", c.Request.URL.Path)
log.Printf("[WebAuth_Guard] Found session cookie value: '%s'", cookie)
authToken, err := authService.AuthenticateToken(cookie)
if err != nil {
log.Printf("[WebAuth_Guard] FATAL: AuthenticateToken FAILED. Error: %v. Redirecting to /login.", err)
} else if !authToken.IsAdmin {
log.Printf("[WebAuth_Guard] FATAL: Token validated, but IsAdmin is FALSE. Redirecting to /login.")
} else {
log.Printf("[WebAuth_Guard] SUCCESS: Token validated and IsAdmin is TRUE. Allowing access.")
}
if err != nil || !authToken.IsAdmin {
if cookie == "" {
logger.Debug("[WebAuth] No session cookie found")
ClearAdminSessionCookie(c)
c.Redirect(http.StatusFound, "/login")
c.Abort()
redirectToLogin(c)
return
}
cacheKey := hashToken(cookie)
if cachedToken, found := webAuthCache.get(cacheKey); found {
logger.Debug("[WebAuth] Using cached token")
c.Set("adminUser", cachedToken)
refreshSessionIfNeeded(c, cookie)
c.Next()
return
}
logger.Debug("[WebAuth] Cache miss, authenticating...")
authToken, err := authService.AuthenticateToken(cookie)
if err != nil {
logger.WithError(err).Warn("[WebAuth] Authentication failed")
ClearAdminSessionCookie(c)
webAuthCache.delete(cacheKey)
redirectToLogin(c)
return
}
if !authToken.IsAdmin {
logger.Warn("[WebAuth] User is not admin")
ClearAdminSessionCookie(c)
webAuthCache.delete(cacheKey)
redirectToLogin(c)
return
}
webAuthCache.set(cacheKey, authToken)
logger.Debug("[WebAuth] Authentication success, token cached")
c.Set("adminUser", authToken)
refreshSessionIfNeeded(c, cookie)
c.Next()
}
}
func WebAdminAuthMiddlewareWithLogger(authService *service.SecurityService, logger *logrus.Logger) gin.HandlerFunc {
return func(c *gin.Context) {
cookie := ExtractTokenFromCookie(c)
if cookie == "" {
logger.Debug("No session cookie found")
ClearAdminSessionCookie(c)
redirectToLogin(c)
return
}
cacheKey := hashToken(cookie)
if cachedToken, found := webAuthCache.get(cacheKey); found {
c.Set("adminUser", cachedToken)
refreshSessionIfNeeded(c, cookie)
c.Next()
return
}
authToken, err := authService.AuthenticateToken(cookie)
if err != nil {
logger.WithError(err).Warn("Token authentication failed")
ClearAdminSessionCookie(c)
webAuthCache.delete(cacheKey)
redirectToLogin(c)
return
}
if !authToken.IsAdmin {
logger.Warn("Token valid but user is not admin")
ClearAdminSessionCookie(c)
webAuthCache.delete(cacheKey)
redirectToLogin(c)
return
}
webAuthCache.set(cacheKey, authToken)
c.Set("adminUser", authToken)
refreshSessionIfNeeded(c, cookie)
c.Next()
}
}
// ==================== 辅助函数 ====================
func hashToken(token string) string {
h := sha256.Sum256([]byte(token))
return hex.EncodeToString(h[:])
}
func redirectToLogin(c *gin.Context) {
if isAjaxRequest(c) {
c.JSON(http.StatusUnauthorized, gin.H{
"error": "Session expired",
"code": "AUTH_REQUIRED",
})
c.Abort()
return
}
originalPath := c.Request.URL.Path
if originalPath != "/" && originalPath != "/login" {
c.Redirect(http.StatusFound, "/login?redirect="+originalPath)
} else {
c.Redirect(http.StatusFound, "/login")
}
c.Abort()
}
func isAjaxRequest(c *gin.Context) bool {
// 检查 Content-Type
contentType := c.GetHeader("Content-Type")
if strings.Contains(contentType, "application/json") {
return true
}
// 检查 Accept优先检查 JSON
accept := c.GetHeader("Accept")
if strings.Contains(accept, "application/json") &&
!strings.Contains(accept, "text/html") {
return true
}
// 兼容旧版 XMLHttpRequest
return c.GetHeader("X-Requested-With") == "XMLHttpRequest"
}
func refreshSessionIfNeeded(c *gin.Context, token string) {
tokenHash := hashToken(token)
sessionRefreshCache.RLock()
lastRefresh, exists := sessionRefreshCache.timestamps[tokenHash]
sessionRefreshCache.RUnlock()
if !exists || time.Since(lastRefresh) > SessionRefreshTime {
SetAdminSessionCookie(c, token)
sessionRefreshCache.Lock()
sessionRefreshCache.timestamps[tokenHash] = time.Now()
sessionRefreshCache.Unlock()
}
}
func getLogLevel() logrus.Level {
level := os.Getenv("LOG_LEVEL")
switch strings.ToLower(level) {
case "debug":
return logrus.DebugLevel
case "warn":
return logrus.WarnLevel
case "error":
return logrus.ErrorLevel
default:
return logrus.InfoLevel
}
}
// ==================== 工具函数 ====================
func GetAdminUserFromContext(c *gin.Context) (interface{}, bool) {
return c.Get("adminUser")
}
func InvalidateTokenCache(token string) {
tokenHash := hashToken(token)
webAuthCache.delete(tokenHash)
// 同时清理刷新时间戳
sessionRefreshCache.Lock()
delete(sessionRefreshCache.timestamps, tokenHash)
sessionRefreshCache.Unlock()
}
func ClearAllAuthCache() {
webAuthCache.mu.Lock()
webAuthCache.cache = make(map[string]*authCacheEntry)
webAuthCache.mu.Unlock()
sessionRefreshCache.Lock()
sessionRefreshCache.timestamps = make(map[string]time.Time)
sessionRefreshCache.Unlock()
}
// ==================== 调试工具 ====================
type SessionInfo struct {
HasCookie bool `json:"has_cookie"`
IsValid bool `json:"is_valid"`
IsAdmin bool `json:"is_admin"`
IsCached bool `json:"is_cached"`
LastActivity string `json:"last_activity"`
}
func GetSessionInfo(c *gin.Context, authService *service.SecurityService) SessionInfo {
info := SessionInfo{
HasCookie: false,
IsValid: false,
IsAdmin: false,
IsCached: false,
LastActivity: time.Now().Format(time.RFC3339),
}
cookie := ExtractTokenFromCookie(c)
if cookie == "" {
return info
}
info.HasCookie = true
cacheKey := hashToken(cookie)
if _, found := webAuthCache.get(cacheKey); found {
info.IsCached = true
}
authToken, err := authService.AuthenticateToken(cookie)
if err != nil {
return info
}
info.IsValid = true
info.IsAdmin = authToken.IsAdmin
return info
}
func GetCacheStats() map[string]interface{} {
webAuthCache.mu.RLock()
cacheSize := len(webAuthCache.cache)
webAuthCache.mu.RUnlock()
sessionRefreshCache.RLock()
refreshSize := len(sessionRefreshCache.timestamps)
sessionRefreshCache.RUnlock()
return map[string]interface{}{
"auth_cache_entries": cacheSize,
"refresh_cache_entries": refreshSize,
"ttl_seconds": int(webAuthCache.ttl.Seconds()),
"cleanup_interval": int(CleanupInterval.Seconds()),
"session_refresh_time": int(SessionRefreshTime.Seconds()),
}
}