From 984a8518cd8531d6fc5245edb032154027a6d155 Mon Sep 17 00:00:00 2001 From: XOF Date: Mon, 15 Dec 2025 01:15:27 +0800 Subject: [PATCH] =?UTF-8?q?=E6=B7=BB=E5=8A=A0=20security/validator.go?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- security/validator.go | 117 ++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 117 insertions(+) create mode 100644 security/validator.go diff --git a/security/validator.go b/security/validator.go new file mode 100644 index 0000000..e4ee5a2 --- /dev/null +++ b/security/validator.go @@ -0,0 +1,117 @@ +// security/validator.go +package security + +import ( + "fmt" + "net" + "net/url" + "strings" +) + +type RequestValidator struct { + blockedDomains map[string]bool + blockedCIDRs []*net.IPNet + allowedSchemes map[string]bool +} + +func NewRequestValidator(blockedDomains map[string]bool, blockedCIDRs []string, allowedSchemes []string) *RequestValidator { + v := &RequestValidator{ + blockedDomains: blockedDomains, + allowedSchemes: make(map[string]bool), + } + + // 解析 CIDR + for _, cidr := range blockedCIDRs { + _, ipNet, err := net.ParseCIDR(strings.TrimSpace(cidr)) + if err == nil { + v.blockedCIDRs = append(v.blockedCIDRs, ipNet) + } + } + + // 设置允许的协议 + for _, scheme := range allowedSchemes { + v.allowedSchemes[strings.TrimSpace(scheme)] = true + } + + return v +} + +func (v *RequestValidator) ValidateURL(urlStr string) error { + u, err := url.Parse(urlStr) + if err != nil { + return fmt.Errorf("invalid URL format") + } + + // 检查协议 + if !v.allowedSchemes[u.Scheme] { + return fmt.Errorf("scheme not allowed: %s", u.Scheme) + } + + // 提取主机名 + host := u.Hostname() + if host == "" { + return fmt.Errorf("invalid host") + } + + // 检查域名黑名单(支持通配符) + if v.isBlockedDomain(host) { + return fmt.Errorf("domain blocked: %s", host) + } + + // 解析 IP + ips, err := net.LookupIP(host) + if err != nil { + return fmt.Errorf("cannot resolve host: %s", host) + } + + // 检查所有解析的 IP + for _, ip := range ips { + if v.isBlockedIP(ip) { + return fmt.Errorf("IP blocked: %s resolves to %s", host, ip) + } + } + + return nil +} + +func (v *RequestValidator) isBlockedDomain(host string) bool { + host = strings.ToLower(host) + + // 精确匹配 + if v.blockedDomains[host] { + return true + } + + // 通配符匹配 (*.example.com) + for domain := range v.blockedDomains { + if strings.HasPrefix(domain, "*.") { + suffix := domain[1:] // 移除 * + if strings.HasSuffix(host, suffix) { + return true + } + } + } + + return false +} + +func (v *RequestValidator) isBlockedIP(ip net.IP) bool { + // 检查是否是私有 IP + if ip.IsLoopback() || ip.IsPrivate() || ip.IsLinkLocalUnicast() || ip.IsLinkLocalMulticast() { + return true + } + + // 检查 CIDR 黑名单 + for _, ipNet := range v.blockedCIDRs { + if ipNet.Contains(ip) { + return true + } + } + + // 检查特殊地址 + if ip.IsUnspecified() || ip.IsMulticast() { + return true + } + + return false +}