添加 security/validator.go
This commit is contained in:
117
security/validator.go
Normal file
117
security/validator.go
Normal file
@@ -0,0 +1,117 @@
|
|||||||
|
// security/validator.go
|
||||||
|
package security
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
"net/url"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
type RequestValidator struct {
|
||||||
|
blockedDomains map[string]bool
|
||||||
|
blockedCIDRs []*net.IPNet
|
||||||
|
allowedSchemes map[string]bool
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewRequestValidator(blockedDomains map[string]bool, blockedCIDRs []string, allowedSchemes []string) *RequestValidator {
|
||||||
|
v := &RequestValidator{
|
||||||
|
blockedDomains: blockedDomains,
|
||||||
|
allowedSchemes: make(map[string]bool),
|
||||||
|
}
|
||||||
|
|
||||||
|
// 解析 CIDR
|
||||||
|
for _, cidr := range blockedCIDRs {
|
||||||
|
_, ipNet, err := net.ParseCIDR(strings.TrimSpace(cidr))
|
||||||
|
if err == nil {
|
||||||
|
v.blockedCIDRs = append(v.blockedCIDRs, ipNet)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 设置允许的协议
|
||||||
|
for _, scheme := range allowedSchemes {
|
||||||
|
v.allowedSchemes[strings.TrimSpace(scheme)] = true
|
||||||
|
}
|
||||||
|
|
||||||
|
return v
|
||||||
|
}
|
||||||
|
|
||||||
|
func (v *RequestValidator) ValidateURL(urlStr string) error {
|
||||||
|
u, err := url.Parse(urlStr)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("invalid URL format")
|
||||||
|
}
|
||||||
|
|
||||||
|
// 检查协议
|
||||||
|
if !v.allowedSchemes[u.Scheme] {
|
||||||
|
return fmt.Errorf("scheme not allowed: %s", u.Scheme)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 提取主机名
|
||||||
|
host := u.Hostname()
|
||||||
|
if host == "" {
|
||||||
|
return fmt.Errorf("invalid host")
|
||||||
|
}
|
||||||
|
|
||||||
|
// 检查域名黑名单(支持通配符)
|
||||||
|
if v.isBlockedDomain(host) {
|
||||||
|
return fmt.Errorf("domain blocked: %s", host)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 解析 IP
|
||||||
|
ips, err := net.LookupIP(host)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("cannot resolve host: %s", host)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 检查所有解析的 IP
|
||||||
|
for _, ip := range ips {
|
||||||
|
if v.isBlockedIP(ip) {
|
||||||
|
return fmt.Errorf("IP blocked: %s resolves to %s", host, ip)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (v *RequestValidator) isBlockedDomain(host string) bool {
|
||||||
|
host = strings.ToLower(host)
|
||||||
|
|
||||||
|
// 精确匹配
|
||||||
|
if v.blockedDomains[host] {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
// 通配符匹配 (*.example.com)
|
||||||
|
for domain := range v.blockedDomains {
|
||||||
|
if strings.HasPrefix(domain, "*.") {
|
||||||
|
suffix := domain[1:] // 移除 *
|
||||||
|
if strings.HasSuffix(host, suffix) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func (v *RequestValidator) isBlockedIP(ip net.IP) bool {
|
||||||
|
// 检查是否是私有 IP
|
||||||
|
if ip.IsLoopback() || ip.IsPrivate() || ip.IsLinkLocalUnicast() || ip.IsLinkLocalMulticast() {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
// 检查 CIDR 黑名单
|
||||||
|
for _, ipNet := range v.blockedCIDRs {
|
||||||
|
if ipNet.Contains(ip) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 检查特殊地址
|
||||||
|
if ip.IsUnspecified() || ip.IsMulticast() {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
return false
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user