Files
SiteProxy/security/validator.go
2025-12-15 01:15:27 +08:00

118 lines
2.7 KiB
Go

// security/validator.go
package security
import (
"fmt"
"net"
"net/url"
"strings"
)
// RequestValidator checks outbound proxy target URLs against a scheme
// allowlist, a domain blocklist (exact names and "*.suffix" wildcards) and an
// IP blocklist (built-in special-purpose ranges plus configured CIDRs).
type RequestValidator struct {
	// blockedDomains maps a domain name or "*.suffix" pattern to whether it is
	// blocked. It is the caller's map, not a copy.
	blockedDomains map[string]bool
	// blockedCIDRs holds the successfully parsed CIDR / single-IP entries.
	blockedCIDRs []*net.IPNet
	// allowedSchemes is the set of lower-case URL schemes that may be proxied.
	allowedSchemes map[string]bool
}

// NewRequestValidator builds a RequestValidator.
//
// blockedDomains is stored as-is (not copied), so later changes the caller
// makes to that map are visible to the validator.
//
// Each blockedCIDRs entry may be one of two forms. A CIDR range
// ("10.0.0.0/8") is used as given. A bare IP address ("169.254.169.254") is
// treated as a single-host range. Blank and unparseable entries are skipped.
//
// allowedSchemes entries are trimmed and lower-cased, because url.Parse
// always lower-cases the scheme it returns. Blank entries are skipped.
func NewRequestValidator(blockedDomains map[string]bool, blockedCIDRs []string, allowedSchemes []string) *RequestValidator {
	v := &RequestValidator{
		blockedDomains: blockedDomains,
		allowedSchemes: make(map[string]bool, len(allowedSchemes)),
	}

	for _, entry := range blockedCIDRs {
		entry = strings.TrimSpace(entry)
		if entry == "" {
			continue
		}
		if _, ipNet, err := net.ParseCIDR(entry); err == nil {
			v.blockedCIDRs = append(v.blockedCIDRs, ipNet)
			continue
		}
		// ParseCIDR rejects addresses without a prefix length, but listing a
		// single host that way is common, so accept it as a /32 or /128.
		if ip := net.ParseIP(entry); ip != nil {
			bits := 8 * net.IPv6len
			if ip4 := ip.To4(); ip4 != nil {
				ip, bits = ip4, 8*net.IPv4len
			}
			v.blockedCIDRs = append(v.blockedCIDRs, &net.IPNet{IP: ip, Mask: net.CIDRMask(bits, bits)})
		}
		// NOTE(review): other malformed entries are still ignored silently;
		// the constructor has no error return to report them.
	}

	for _, scheme := range allowedSchemes {
		scheme = strings.ToLower(strings.TrimSpace(scheme))
		if scheme != "" {
			v.allowedSchemes[scheme] = true
		}
	}
	return v
}

// ValidateURL returns nil when urlStr may be fetched. That requires all of
// the following:
//   - the URL parses;
//   - its scheme is allowed;
//   - its host is non-empty and not on the domain blocklist;
//   - none of the addresses the host resolves to is rejected by isBlockedIP.
//
// Otherwise it returns an error describing the first failed check.
//
// The resolved addresses are not returned to the caller, so whoever performs
// the request presumably resolves the host again.
// NOTE(review): that gap allows DNS rebinding (a name that answers with a
// public address now and a private one at connect time). Consider also
// enforcing isBlockedIP at dial time, e.g. via net.Dialer.Control.
func (v *RequestValidator) ValidateURL(urlStr string) error {
	u, err := url.Parse(urlStr)
	if err != nil {
		return fmt.Errorf("invalid URL format")
	}

	// url.Parse lower-cases the scheme, so a lower-case allowlist suffices.
	if !v.allowedSchemes[u.Scheme] {
		return fmt.Errorf("scheme not allowed: %s", u.Scheme)
	}

	// Hostname strips the port and IPv6 brackets.
	host := u.Hostname()
	if host == "" {
		return fmt.Errorf("invalid host")
	}

	if v.isBlockedDomain(host) {
		return fmt.Errorf("domain blocked: %s", host)
	}

	// Resolve the host exactly as written (a trailing dot is kept) so the
	// checked addresses match what a fully-qualified lookup would return.
	ips, err := net.LookupIP(host)
	if err != nil {
		return fmt.Errorf("cannot resolve host: %s", host)
	}

	// Every address must pass: a single internal address is enough to reject.
	for _, ip := range ips {
		if v.isBlockedIP(ip) {
			return fmt.Errorf("IP blocked: %s resolves to %s", host, ip)
		}
	}
	return nil
}

// isBlockedDomain reports whether host matches an enabled (true-valued)
// entry of blockedDomains.
//
// Matching is case-insensitive and ignores trailing dots, so
// "Example.COM." matches "example.com". A "*.example.com" entry matches any
// proper subdomain of example.com, but not example.com itself.
func (v *RequestValidator) isBlockedDomain(host string) bool {
	// A trailing dot marks a fully-qualified name ("example.com.") that still
	// resolves, so it must not let a host slip past the blocklist.
	host = strings.TrimRight(strings.ToLower(host), ".")
	if host == "" {
		return false
	}

	// Fast path: the entry is already stored in canonical form.
	if v.blockedDomains[host] {
		return true
	}

	for entry, blocked := range v.blockedDomains {
		if !blocked {
			continue
		}
		pattern := strings.TrimRight(strings.ToLower(strings.TrimSpace(entry)), ".")
		if strings.HasPrefix(pattern, "*.") {
			// Keep the leading "." of the suffix so "*.example.com" does not
			// match "badexample.com".
			if strings.HasSuffix(host, pattern[1:]) {
				return true
			}
			continue
		}
		if pattern == host {
			return true
		}
	}
	return false
}

// isBlockedIP reports whether ip must not be contacted. An address is
// blocked if it is any of the following:
//   - loopback, private (RFC 1918 / RFC 4193) or link-local;
//   - unspecified or multicast;
//   - in one of the IPv4 special-purpose ranges listed below;
//   - inside a configured CIDR.
//
// IPv4-mapped IPv6 addresses are classified by their IPv4 part. A nil or
// malformed ip is treated as blocked (fail closed).
func (v *RequestValidator) isBlockedIP(ip net.IP) bool {
	if ip.To16() == nil {
		return true
	}

	if ip.IsLoopback() || ip.IsPrivate() || ip.IsLinkLocalUnicast() || ip.IsLinkLocalMulticast() ||
		ip.IsUnspecified() || ip.IsMulticast() {
		return true
	}

	if ip4 := ip.To4(); ip4 != nil {
		switch {
		case ip4[0] == 0:
			// 0.0.0.0/8 "this network" (RFC 1122); not a valid destination.
			return true
		case ip4[0] == 100 && ip4[1]&0xC0 == 64:
			// 100.64.0.0/10 shared address space (RFC 6598). This is
			// infrastructure-internal and hosts some cloud metadata services,
			// e.g. 100.100.100.200.
			return true
		case ip4[0] >= 240:
			// 240.0.0.0/4 reserved, including 255.255.255.255 broadcast.
			return true
		}
	}

	for _, ipNet := range v.blockedCIDRs {
		if ipNet.Contains(ip) {
			return true
		}
	}
	return false
}