diff --git a/security/ratelimit.go b/security/ratelimit.go new file mode 100644 index 0000000..f48c1b6 --- /dev/null +++ b/security/ratelimit.go @@ -0,0 +1,72 @@ +// security/ratelimit.go +package security + +import ( + "sync" + "time" +) + +type RequestCounter struct { + Count int + ResetTime time.Time +} + +type RateLimiter struct { + requests sync.Map + limit int + window time.Duration +} + +func NewRateLimiter(limit int, window time.Duration) *RateLimiter { + rl := &RateLimiter{ + limit: limit, + window: window, + } + + // 启动清理协程 + go rl.cleanup() + + return rl +} + +func (rl *RateLimiter) Allow(sessionID string) bool { + now := time.Now() + + val, _ := rl.requests.LoadOrStore(sessionID, &RequestCounter{ + Count: 0, + ResetTime: now.Add(rl.window), + }) + + counter := val.(*RequestCounter) + + // 检查是否需要重置 + if now.After(counter.ResetTime) { + counter.Count = 1 + counter.ResetTime = now.Add(rl.window) + return true + } + + // 检查是否超过限制 + if counter.Count >= rl.limit { + return false + } + + counter.Count++ + return true +} + +func (rl *RateLimiter) cleanup() { + ticker := time.NewTicker(5 * time.Minute) + defer ticker.Stop() + + for range ticker.C { + now := time.Now() + rl.requests.Range(func(key, value interface{}) bool { + counter := value.(*RequestCounter) + if now.After(counter.ResetTime.Add(5 * time.Minute)) { + rl.requests.Delete(key) + } + return true + }) + } +}