// Filename: internal/middleware/security.go (简化版) package middleware import ( "gemini-balancer/internal/service" "gemini-balancer/internal/settings" "net/http" "sync" "time" "github.com/gin-gonic/gin" "github.com/sirupsen/logrus" ) // 简单的缓存项 type cacheItem struct { value bool expiration int64 } // 简单的 TTL 缓存实现 type IPBanCache struct { items map[string]*cacheItem mu sync.RWMutex ttl time.Duration } func NewIPBanCache() *IPBanCache { cache := &IPBanCache{ items: make(map[string]*cacheItem), ttl: 1 * time.Minute, } // 启动清理协程 go cache.cleanup() return cache } func (c *IPBanCache) Get(key string) (bool, bool) { c.mu.RLock() defer c.mu.RUnlock() item, exists := c.items[key] if !exists { return false, false } // 检查是否过期 if time.Now().UnixNano() > item.expiration { return false, false } return item.value, true } func (c *IPBanCache) Set(key string, value bool) { c.mu.Lock() defer c.mu.Unlock() c.items[key] = &cacheItem{ value: value, expiration: time.Now().Add(c.ttl).UnixNano(), } } func (c *IPBanCache) Delete(key string) { c.mu.Lock() defer c.mu.Unlock() delete(c.items, key) } func (c *IPBanCache) cleanup() { ticker := time.NewTicker(5 * time.Minute) defer ticker.Stop() for range ticker.C { c.mu.Lock() now := time.Now().UnixNano() for key, item := range c.items { if now > item.expiration { delete(c.items, key) } } c.mu.Unlock() } } func IPBanMiddleware( securityService *service.SecurityService, settingsManager *settings.SettingsManager, banCache *IPBanCache, logger *logrus.Logger, ) gin.HandlerFunc { return func(c *gin.Context) { if !settingsManager.IsIPBanEnabled() { c.Next() return } ip := c.ClientIP() // 查缓存 if isBanned, exists := banCache.Get(ip); exists { if isBanned { logger.WithField("ip", ip).Debug("IP blocked (cached)") c.AbortWithStatusJSON(http.StatusForbidden, gin.H{ "error": "Access denied", }) return } c.Next() return } // 查数据库 ctx := c.Request.Context() isBanned, err := securityService.IsIPBanned(ctx, ip) if err != nil { logger.WithError(err).WithField("ip", ip).Error("Failed to check IP ban status") // 降级策略:允许访问 c.Next() return } // 更新缓存 banCache.Set(ip, isBanned) if isBanned { logger.WithField("ip", ip).Info("IP blocked") c.AbortWithStatusJSON(http.StatusForbidden, gin.H{ "error": "Access denied", }) return } c.Next() } }