142 lines
2.5 KiB
Go
142 lines
2.5 KiB
Go
// Filename: internal/middleware/security.go (简化版)
|
|
|
|
package middleware
|
|
|
|
import (
|
|
"gemini-balancer/internal/service"
|
|
"gemini-balancer/internal/settings"
|
|
"net/http"
|
|
"sync"
|
|
"time"
|
|
|
|
"github.com/gin-gonic/gin"
|
|
"github.com/sirupsen/logrus"
|
|
)
|
|
|
|
// cacheItem is a single cache entry: a cached boolean verdict together with
// the instant after which it must no longer be served.
type cacheItem struct {
	value      bool  // verdict stored by Set
	expiration int64 // expiry instant in Unix nanoseconds
}

// IPBanCache is a small in-memory TTL cache that maps a string key to a
// boolean verdict. All exported methods take mu, so a single instance can be
// shared between goroutines.
//
// Create instances with NewIPBanCache, which initialises the map and starts
// the background cleanup goroutine. Call Stop when the cache is no longer
// needed so that goroutine can exit.
type IPBanCache struct {
	items map[string]*cacheItem
	mu    sync.RWMutex
	ttl   time.Duration

	stop     chan struct{} // closed by Stop to end the cleanup goroutine
	stopOnce sync.Once     // guards close(stop) so Stop may be called repeatedly
}

// NewIPBanCache returns an empty cache whose entries live for one minute and
// starts the goroutine that periodically purges expired entries. The
// goroutine runs until Stop is called.
func NewIPBanCache() *IPBanCache {
	cache := &IPBanCache{
		items: make(map[string]*cacheItem),
		ttl:   1 * time.Minute,
		stop:  make(chan struct{}),
	}

	// Start the background purge loop.
	go cache.cleanup()

	return cache
}

// Stop terminates the background cleanup goroutine. It is safe to call more
// than once and from multiple goroutines. Get, Set and Delete keep working
// after Stop; expired entries are then still reported as misses by Get but
// are no longer removed from the map.
func (c *IPBanCache) Stop() {
	c.stopOnce.Do(func() {
		// stop is nil only for a cache that was not built by NewIPBanCache;
		// closing a nil channel would panic.
		if c.stop != nil {
			close(c.stop)
		}
	})
}

// Get looks up key. The second result reports whether a non-expired entry
// was found; when it is false the first result is always false.
func (c *IPBanCache) Get(key string) (bool, bool) {
	c.mu.RLock()
	defer c.mu.RUnlock()

	item, exists := c.items[key]
	if !exists {
		return false, false
	}

	// An expired entry is treated as a miss; it stays in the map until the
	// cleanup loop removes it or Set overwrites it.
	if time.Now().UnixNano() > item.expiration {
		return false, false
	}

	return item.value, true
}

// Set stores value under key, replacing any existing entry, with an expiry
// of now plus the cache TTL.
func (c *IPBanCache) Set(key string, value bool) {
	c.mu.Lock()
	defer c.mu.Unlock()

	c.items[key] = &cacheItem{
		value:      value,
		expiration: time.Now().Add(c.ttl).UnixNano(),
	}
}

// Delete removes key from the cache. Deleting a missing key is a no-op.
func (c *IPBanCache) Delete(key string) {
	c.mu.Lock()
	defer c.mu.Unlock()
	delete(c.items, key)
}

// cleanup purges expired entries every five minutes until Stop is called.
// It is meant to run in its own goroutine (started by NewIPBanCache).
func (c *IPBanCache) cleanup() {
	ticker := time.NewTicker(5 * time.Minute)
	defer ticker.Stop()

	for {
		select {
		case <-c.stop:
			return
		case <-ticker.C:
			c.removeExpired(time.Now().UnixNano())
		}
	}
}

// removeExpired deletes every entry whose expiration lies before now
// (Unix nanoseconds).
func (c *IPBanCache) removeExpired(now int64) {
	c.mu.Lock()
	defer c.mu.Unlock()

	for key, item := range c.items {
		if now > item.expiration {
			delete(c.items, key)
		}
	}
}
|
|
|
|
func IPBanMiddleware(
|
|
securityService *service.SecurityService,
|
|
settingsManager *settings.SettingsManager,
|
|
banCache *IPBanCache,
|
|
logger *logrus.Logger,
|
|
) gin.HandlerFunc {
|
|
return func(c *gin.Context) {
|
|
if !settingsManager.IsIPBanEnabled() {
|
|
c.Next()
|
|
return
|
|
}
|
|
|
|
ip := c.ClientIP()
|
|
|
|
// 查缓存
|
|
if isBanned, exists := banCache.Get(ip); exists {
|
|
if isBanned {
|
|
logger.WithField("ip", ip).Debug("IP blocked (cached)")
|
|
c.AbortWithStatusJSON(http.StatusForbidden, gin.H{
|
|
"error": "Access denied",
|
|
})
|
|
return
|
|
}
|
|
c.Next()
|
|
return
|
|
}
|
|
|
|
// 查数据库
|
|
ctx := c.Request.Context()
|
|
isBanned, err := securityService.IsIPBanned(ctx, ip)
|
|
if err != nil {
|
|
logger.WithError(err).WithField("ip", ip).Error("Failed to check IP ban status")
|
|
|
|
// 降级策略:允许访问
|
|
c.Next()
|
|
return
|
|
}
|
|
|
|
// 更新缓存
|
|
banCache.Set(ip, isBanned)
|
|
|
|
if isBanned {
|
|
logger.WithField("ip", ip).Info("IP blocked")
|
|
c.AbortWithStatusJSON(http.StatusForbidden, gin.H{
|
|
"error": "Access denied",
|
|
})
|
|
return
|
|
}
|
|
|
|
c.Next()
|
|
}
|
|
}
|