118 lines
2.7 KiB
Go
118 lines
2.7 KiB
Go
// security/validator.go
|
|
package security
|
|
|
|
import (
|
|
"fmt"
|
|
"net"
|
|
"net/url"
|
|
"strings"
|
|
)
|
|
|
|
// RequestValidator decides whether an outbound request URL is safe to fetch.
// A URL is rejected when its scheme is not allow-listed, when its host
// matches a blocked domain (exact name or "*.suffix" wildcard), or when its
// host resolves to a non-public or explicitly blocked IP address.
//
// Build values with NewRequestValidator. The zero value allows no schemes
// and therefore rejects every URL.
type RequestValidator struct {
	// blockedDomains holds lower-cased host names without surrounding
	// whitespace or a trailing dot. A key of the form "*.example.com" blocks
	// every subdomain of example.com (but not example.com itself). Only
	// keys mapped to true are stored.
	blockedDomains map[string]bool

	// blockedCIDRs are caller-supplied networks that are blocked in
	// addition to the built-in non-public address ranges.
	blockedCIDRs []*net.IPNet

	// allowedSchemes holds lower-cased URL schemes such as "https".
	allowedSchemes map[string]bool
}

// NewRequestValidator builds a RequestValidator.
//
// blockedDomains maps host names (or "*.suffix" wildcards) to true when they
// must be blocked; entries mapped to false are ignored. Names are normalized
// to lower case with surrounding whitespace and one trailing dot removed, so
// configuration spelling does not matter. The map is copied, so later
// changes by the caller do not affect the validator.
//
// blockedCIDRs lists networks in CIDR notation ("10.0.0.0/8"). A bare IP
// address ("192.0.2.7") is also accepted and blocks that single host.
//
// allowedSchemes lists the URL schemes that may be fetched. They are
// lower-cased because url.Parse always lower-cases the scheme it returns.
// Empty entries are ignored.
func NewRequestValidator(blockedDomains map[string]bool, blockedCIDRs []string, allowedSchemes []string) *RequestValidator {
	v := &RequestValidator{
		blockedDomains: make(map[string]bool, len(blockedDomains)),
		allowedSchemes: make(map[string]bool, len(allowedSchemes)),
	}

	// Copy (never alias) the caller's map, keeping only enabled entries in
	// the normalized form isBlockedDomain compares against.
	for domain, blocked := range blockedDomains {
		if !blocked {
			continue
		}
		name := strings.TrimSuffix(strings.ToLower(strings.TrimSpace(domain)), ".")
		if name != "" {
			v.blockedDomains[name] = true
		}
	}

	// NOTE(review): entries that are neither a CIDR nor an IP address are
	// silently skipped (as before). A typo here weakens the block list with
	// no signal; consider validating the configuration where it is loaded.
	for _, entry := range blockedCIDRs {
		entry = strings.TrimSpace(entry)
		if _, ipNet, err := net.ParseCIDR(entry); err == nil {
			v.blockedCIDRs = append(v.blockedCIDRs, ipNet)
			continue
		}
		if ip := net.ParseIP(entry); ip != nil {
			// Use the 4-byte form for IPv4 so the mask is /32, not /128.
			if ip4 := ip.To4(); ip4 != nil {
				ip = ip4
			}
			bits := len(ip) * 8
			v.blockedCIDRs = append(v.blockedCIDRs, &net.IPNet{IP: ip, Mask: net.CIDRMask(bits, bits)})
		}
	}

	for _, scheme := range allowedSchemes {
		s := strings.ToLower(strings.TrimSpace(scheme))
		if s != "" {
			v.allowedSchemes[s] = true
		}
	}

	return v
}
|
|
|
|
func (v *RequestValidator) ValidateURL(urlStr string) error {
|
|
u, err := url.Parse(urlStr)
|
|
if err != nil {
|
|
return fmt.Errorf("invalid URL format")
|
|
}
|
|
|
|
// 检查协议
|
|
if !v.allowedSchemes[u.Scheme] {
|
|
return fmt.Errorf("scheme not allowed: %s", u.Scheme)
|
|
}
|
|
|
|
// 提取主机名
|
|
host := u.Hostname()
|
|
if host == "" {
|
|
return fmt.Errorf("invalid host")
|
|
}
|
|
|
|
// 检查域名黑名单(支持通配符)
|
|
if v.isBlockedDomain(host) {
|
|
return fmt.Errorf("domain blocked: %s", host)
|
|
}
|
|
|
|
// 解析 IP
|
|
ips, err := net.LookupIP(host)
|
|
if err != nil {
|
|
return fmt.Errorf("cannot resolve host: %s", host)
|
|
}
|
|
|
|
// 检查所有解析的 IP
|
|
for _, ip := range ips {
|
|
if v.isBlockedIP(ip) {
|
|
return fmt.Errorf("IP blocked: %s resolves to %s", host, ip)
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (v *RequestValidator) isBlockedDomain(host string) bool {
|
|
host = strings.ToLower(host)
|
|
|
|
// 精确匹配
|
|
if v.blockedDomains[host] {
|
|
return true
|
|
}
|
|
|
|
// 通配符匹配 (*.example.com)
|
|
for domain := range v.blockedDomains {
|
|
if strings.HasPrefix(domain, "*.") {
|
|
suffix := domain[1:] // 移除 *
|
|
if strings.HasSuffix(host, suffix) {
|
|
return true
|
|
}
|
|
}
|
|
}
|
|
|
|
return false
|
|
}
|
|
|
|
func (v *RequestValidator) isBlockedIP(ip net.IP) bool {
|
|
// 检查是否是私有 IP
|
|
if ip.IsLoopback() || ip.IsPrivate() || ip.IsLinkLocalUnicast() || ip.IsLinkLocalMulticast() {
|
|
return true
|
|
}
|
|
|
|
// 检查 CIDR 黑名单
|
|
for _, ipNet := range v.blockedCIDRs {
|
|
if ipNet.Contains(ip) {
|
|
return true
|
|
}
|
|
}
|
|
|
|
// 检查特殊地址
|
|
if ip.IsUnspecified() || ip.IsMulticast() {
|
|
return true
|
|
}
|
|
|
|
return false
|
|
}
|