// Filename: internal/middleware/rate_limit.go package middleware import ( "net/http" "sync" "time" "github.com/gin-gonic/gin" "golang.org/x/time/rate" ) type RateLimiter struct { limiters map[string]*rate.Limiter mu sync.RWMutex r rate.Limit // 每秒请求数 b int // 突发容量 } func NewRateLimiter(r rate.Limit, b int) *RateLimiter { return &RateLimiter{ limiters: make(map[string]*rate.Limiter), r: r, b: b, } } func (rl *RateLimiter) getLimiter(key string) *rate.Limiter { rl.mu.Lock() defer rl.mu.Unlock() limiter, exists := rl.limiters[key] if !exists { limiter = rate.NewLimiter(rl.r, rl.b) rl.limiters[key] = limiter } return limiter } // 定期清理不活跃的限制器 func (rl *RateLimiter) cleanup() { ticker := time.NewTicker(10 * time.Minute) defer ticker.Stop() for range ticker.C { rl.mu.Lock() // 简单策略:定期清空(生产环境应该用 LRU) rl.limiters = make(map[string]*rate.Limiter) rl.mu.Unlock() } } func RateLimitMiddleware(limiter *RateLimiter) gin.HandlerFunc { return func(c *gin.Context) { // 按 IP 限流 key := c.ClientIP() // 如果有认证 token,按 token 限流(更精确) if authToken, exists := c.Get("authToken"); exists { if token, ok := authToken.(interface{ GetID() string }); ok { key = "token:" + token.GetID() } } l := limiter.getLimiter(key) if !l.Allow() { c.AbortWithStatusJSON(http.StatusTooManyRequests, gin.H{ "error": "Rate limit exceeded", "code": "RATE_LIMIT", }) return } c.Next() } } // 使用示例 func SetupRateLimit(r *gin.Engine) { limiter := NewRateLimiter(10, 20) // 每秒 10 个请求,突发 20 go limiter.cleanup() r.Use(RateLimitMiddleware(limiter)) }