Files
gemini-banlancer/internal/middleware/security.go

142 lines
2.5 KiB
Go

// Filename: internal/middleware/security.go (简化版)
package middleware
import (
"gemini-balancer/internal/service"
"gemini-balancer/internal/settings"
"net/http"
"sync"
"time"
"github.com/gin-gonic/gin"
"github.com/sirupsen/logrus"
)
// cacheItem is one cached verdict together with the instant at which it stops
// being valid.
type cacheItem struct {
	// value is the cached flag; the middleware stores true for a banned IP.
	value bool
	// expiration is the expiry instant expressed in Unix nanoseconds.
	expiration int64
}

// IPBanCache is a small in-memory TTL cache of boolean verdicts keyed by
// string. All methods are safe for concurrent use. Create it with
// NewIPBanCache, and call Close once the cache is no longer needed so that the
// background cleanup goroutine exits.
type IPBanCache struct {
	items map[string]*cacheItem
	mu    sync.RWMutex
	ttl   time.Duration

	// stop is closed by Close to terminate the cleanup goroutine.
	stop chan struct{}
	// stopOnce makes Close safe to call more than once.
	stopOnce sync.Once
}

// NewIPBanCache returns an empty cache whose entries live for one minute and
// starts the background goroutine that periodically evicts expired entries.
// Call Close to stop that goroutine.
func NewIPBanCache() *IPBanCache {
	cache := &IPBanCache{
		items: make(map[string]*cacheItem),
		ttl:   1 * time.Minute,
		stop:  make(chan struct{}),
	}
	// Start the periodic eviction loop; it runs until Close is called.
	go cache.cleanup()
	return cache
}

// Get looks up key. It returns (value, true) for an entry that has not yet
// expired, and (false, false) when the key is absent or its entry has expired.
// Expired entries are not removed here because only the read lock is held;
// removal is left to the cleanup loop.
func (c *IPBanCache) Get(key string) (bool, bool) {
	c.mu.RLock()
	defer c.mu.RUnlock()

	item, exists := c.items[key]
	if !exists {
		return false, false
	}
	// An expired entry is reported as a miss.
	if time.Now().UnixNano() > item.expiration {
		return false, false
	}
	return item.value, true
}

// Set stores value under key, replacing any existing entry, with an expiry of
// now plus the cache TTL.
func (c *IPBanCache) Set(key string, value bool) {
	c.mu.Lock()
	defer c.mu.Unlock()

	c.items[key] = &cacheItem{
		value:      value,
		expiration: time.Now().Add(c.ttl).UnixNano(),
	}
}

// Delete removes key from the cache. Deleting an absent key is a no-op.
func (c *IPBanCache) Delete(key string) {
	c.mu.Lock()
	defer c.mu.Unlock()

	delete(c.items, key)
}

// Close stops the background cleanup goroutine started by NewIPBanCache. It is
// idempotent and safe to call concurrently. Entries already in the cache stay
// readable and still expire on lookup; they are simply no longer evicted in
// the background.
func (c *IPBanCache) Close() {
	c.stopOnce.Do(func() {
		// stop is nil for a cache that was not built by NewIPBanCache.
		if c.stop != nil {
			close(c.stop)
		}
	})
}

// cleanup evicts expired entries every five minutes until Close is called.
// It is meant to run on its own goroutine.
func (c *IPBanCache) cleanup() {
	ticker := time.NewTicker(5 * time.Minute)
	defer ticker.Stop()

	for {
		select {
		case <-c.stop:
			return
		case <-ticker.C:
			c.mu.Lock()
			now := time.Now().UnixNano()
			for key, item := range c.items {
				if now > item.expiration {
					// Deleting during range is permitted for Go maps.
					delete(c.items, key)
				}
			}
			c.mu.Unlock()
		}
	}
}
// IPBanMiddleware returns a gin handler that rejects requests from banned
// client IPs with HTTP 403 and the JSON body {"error": "Access denied"}.
//
// When settingsManager reports that IP banning is disabled, every request is
// passed through. Otherwise the verdict for c.ClientIP() is taken from
// banCache when a live entry exists. On a cache miss it is obtained from
// securityService.IsIPBanned (described by the original author as a database
// lookup) and stored in banCache. If that lookup fails, the error is logged,
// nothing is cached, and the request is allowed through (fail-open).
func IPBanMiddleware(
	securityService *service.SecurityService,
	settingsManager *settings.SettingsManager,
	banCache *IPBanCache,
	logger *logrus.Logger,
) gin.HandlerFunc {
	// deny aborts the handler chain with the response used for banned IPs.
	deny := func(c *gin.Context) {
		c.AbortWithStatusJSON(http.StatusForbidden, gin.H{
			"error": "Access denied",
		})
	}

	return func(c *gin.Context) {
		if !settingsManager.IsIPBanEnabled() {
			c.Next()
			return
		}

		clientIP := c.ClientIP()

		// Fast path: a live cache entry decides the request on its own.
		cachedVerdict, hit := banCache.Get(clientIP)
		switch {
		case hit && cachedVerdict:
			logger.WithField("ip", clientIP).Debug("IP blocked (cached)")
			deny(c)
			return
		case hit:
			c.Next()
			return
		}

		// Slow path: ask the security service.
		reqCtx := c.Request.Context()
		banned, err := securityService.IsIPBanned(reqCtx, clientIP)
		if err != nil {
			logger.WithError(err).WithField("ip", clientIP).Error("Failed to check IP ban status")
			// Degraded mode: let the request through rather than block on a failed check.
			c.Next()
			return
		}

		// Remember the verdict so later requests from this IP skip the lookup.
		banCache.Set(clientIP, banned)

		if !banned {
			c.Next()
			return
		}
		logger.WithField("ip", clientIP).Info("IP blocked")
		deny(c)
	}
}