// security/validator.go package security import ( "fmt" "net" "net/url" "strings" ) type RequestValidator struct { blockedDomains map[string]bool blockedCIDRs []*net.IPNet allowedSchemes map[string]bool } func NewRequestValidator(blockedDomains map[string]bool, blockedCIDRs []string, allowedSchemes []string) *RequestValidator { v := &RequestValidator{ blockedDomains: blockedDomains, allowedSchemes: make(map[string]bool), } // 解析 CIDR for _, cidr := range blockedCIDRs { _, ipNet, err := net.ParseCIDR(strings.TrimSpace(cidr)) if err == nil { v.blockedCIDRs = append(v.blockedCIDRs, ipNet) } } // 设置允许的协议 for _, scheme := range allowedSchemes { v.allowedSchemes[strings.TrimSpace(scheme)] = true } return v } func (v *RequestValidator) ValidateURL(urlStr string) error { u, err := url.Parse(urlStr) if err != nil { return fmt.Errorf("invalid URL format") } // 检查协议 if !v.allowedSchemes[u.Scheme] { return fmt.Errorf("scheme not allowed: %s", u.Scheme) } // 提取主机名 host := u.Hostname() if host == "" { return fmt.Errorf("invalid host") } // 检查域名黑名单(支持通配符) if v.isBlockedDomain(host) { return fmt.Errorf("domain blocked: %s", host) } // 解析 IP ips, err := net.LookupIP(host) if err != nil { return fmt.Errorf("cannot resolve host: %s", host) } // 检查所有解析的 IP for _, ip := range ips { if v.isBlockedIP(ip) { return fmt.Errorf("IP blocked: %s resolves to %s", host, ip) } } return nil } func (v *RequestValidator) isBlockedDomain(host string) bool { host = strings.ToLower(host) // 精确匹配 if v.blockedDomains[host] { return true } // 通配符匹配 (*.example.com) for domain := range v.blockedDomains { if strings.HasPrefix(domain, "*.") { suffix := domain[1:] // 移除 * if strings.HasSuffix(host, suffix) { return true } } } return false } func (v *RequestValidator) isBlockedIP(ip net.IP) bool { // 检查是否是私有 IP if ip.IsLoopback() || ip.IsPrivate() || ip.IsLinkLocalUnicast() || ip.IsLinkLocalMulticast() { return true } // 检查 CIDR 黑名单 for _, ipNet := range v.blockedCIDRs { if ipNet.Contains(ip) { return true } } // 检查特殊地址 if ip.IsUnspecified() || ip.IsMulticast() { return true } return false }