Files
godns/internal/handler/handler.go
2026-01-06 19:16:21 +08:00

743 lines
19 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package handler
import (
"errors"
"fmt"
"net"
"strings"
"sync"
"time"
"strconv"
"github.com/miekg/dns"
"godns/internal/cache"
"godns/internal/model"
"godns/internal/stats"
"godns/pkg/logger"
)
// Handler answers DNS queries by forwarding them to the configured upstream
// servers, optionally caching the responses and recording statistics.
type Handler struct {
	// strategy selects how upstream replies are gathered; exchange switches
	// on it against the model.Strategy* constants.
	strategy int

	// commonUpstreams holds the upstreams without Match rules.
	commonUpstreams []*model.Upstream
	// specialUpstreams holds the upstreams that declare at least one Match
	// rule; they take precedence when they match the queried name.
	specialUpstreams []*model.Upstream

	// builtInCache is nil when caching is disabled or failed to initialize.
	builtInCache cache.Cache

	logger logger.Logger

	// stats may be nil, in which case no statistics are recorded.
	stats stats.StatsRecorder
}
// NewHandler builds a Handler.
//
// When builtInCache is true a BadgerDB-backed cache is opened at dataPath; if
// that fails the error is logged and the handler runs without a cache.
// Upstreams that carry at least one Match rule are stored as special
// upstreams, all others as common upstreams.
func NewHandler(strategy int, builtInCache bool,
	upstreams []*model.Upstream,
	dataPath string,
	log logger.Logger,
	statsRecorder stats.StatsRecorder) *Handler {
	h := &Handler{
		strategy: strategy,
		logger:   log,
		stats:    statsRecorder,
	}

	if builtInCache {
		store, err := cache.NewBadgerCache(dataPath, log)
		if err == nil {
			log.Printf("BadgerDB cache initialized successfully at %s", dataPath)
			h.builtInCache = store
		} else {
			// Leave h.builtInCache as a nil interface so callers can test it.
			log.Printf("Failed to initialize BadgerDB cache: %v", err)
			log.Printf("Cache will be disabled")
		}
	}

	for _, up := range upstreams {
		if len(up.Match) > 0 {
			h.specialUpstreams = append(h.specialUpstreams, up)
			continue
		}
		h.commonUpstreams = append(h.commonUpstreams, up)
	}
	return h
}
// matchedUpstreams returns the upstreams that should serve req: every special
// upstream whose rules match the first question's name, or the common
// upstreams when the request has no question or nothing matches.
func (h *Handler) matchedUpstreams(req *dns.Msg) []*model.Upstream {
	if len(req.Question) == 0 {
		return h.commonUpstreams
	}

	name := req.Question[0].Name
	var hits []*model.Upstream
	for _, up := range h.specialUpstreams {
		if up.IsMatch(name) {
			hits = append(hits, up)
		}
	}

	if len(hits) == 0 {
		return h.commonUpstreams
	}
	return hits
}
// LookupIP resolves host to an IPv4 address.
//
// A literal IP address is returned as-is. Otherwise an A query is sent through
// the handler's own upstreams and the address of the LAST A record in the
// answer is returned. An error is returned when the answer holds no A record.
func (h *Handler) LookupIP(host string) (ip net.IP, err error) {
	if literal := net.ParseIP(host); literal != nil {
		return literal, nil
	}

	// Queries need a fully-qualified name.
	fqdn := host
	if !strings.HasSuffix(fqdn, ".") {
		fqdn += "."
	}

	query := &dns.Msg{}
	query.Id = dns.Id()
	query.RecursionDesired = true
	query.Question = []dns.Question{
		{Name: fqdn, Qtype: dns.TypeA, Qclass: dns.ClassINET},
	}

	res := h.exchange(query)

	// Keep overwriting so the last A record wins.
	for _, rr := range res.Answer {
		if a, ok := rr.(*dns.A); ok {
			ip = a.A
		}
	}
	if ip == nil {
		err = errors.New("no ipv4 address found")
	}

	h.logger.Printf("bootstrap LookupIP: %s %v --> %s %v", fqdn, res.Answer, ip, err)
	return ip, err
}
// removeEDNS strips every EDNS Client Subnet option from the request's OPT
// record, leaving all other EDNS options in place. Requests without an OPT
// record are left untouched. The request is modified in place.
func (h *Handler) removeEDNS(req *dns.Msg) {
	opt := req.IsEdns0()
	if opt == nil {
		return
	}

	var kept []dns.EDNS0
	for _, o := range opt.Option {
		switch o.(type) {
		case *dns.EDNS0_SUBNET:
			h.logger.Printf("Removed EDNS Client Subnet from request")
		default:
			// Any non-ECS option is preserved.
			kept = append(kept, o)
		}
	}
	opt.Option = kept
}
// exchange forwards req to the upstreams using the configured strategy and
// merges the replies into a single message.
//
// Client-subnet information is stripped from req first. The first non-nil
// reply becomes the base message and the Answer sections of later replies are
// appended to it, after which duplicate answers are removed. When no reply is
// available (including an unknown strategy) a SERVFAIL message is returned.
func (h *Handler) exchange(req *dns.Msg) *dns.Msg {
	h.removeEDNS(req)

	var replies []*dns.Msg
	switch h.strategy {
	case model.StrategyFullest:
		replies = h.getTheFullestResults(req)
	case model.StrategyFastest:
		replies = h.getTheFastestResults(req)
	case model.StrategyAnyResult:
		replies = h.getAnyResult(req)
	}

	var merged *dns.Msg
	for _, reply := range replies {
		switch {
		case reply == nil:
			// This upstream produced nothing usable.
		case merged == nil:
			merged = reply
		default:
			merged.Answer = append(merged.Answer, reply.Answer...)
		}
	}

	if merged != nil {
		merged.Answer = uniqueAnswer(merged.Answer)
		return merged
	}

	// Every upstream failed: report a server failure.
	failure := new(dns.Msg)
	failure.Rcode = dns.RcodeServerFailure
	return failure
}
// getDnsRequestCacheKey derives the cache key for a request in the form
// "<domain>#<qtype>", with "#DO" appended when the request carries an OPT
// record with the DNSSEC-OK bit set. It returns "" for a request without a
// question.
func getDnsRequestCacheKey(m *dns.Msg) string {
	if len(m.Question) == 0 {
		return ""
	}

	key := model.GetDomainNameFromDnsMsg(m) + "#" + strconv.Itoa(int(m.Question[0].Qtype))
	if opt := m.IsEdns0(); opt != nil && opt.Do() {
		key += "#DO"
	}
	return key
}
// getDnsResponseTtl returns how long a response may be cached: the TTL of the
// first answer record, clamped to the range [60s, 3600s]. A response without
// answers yields the 60-second minimum.
func getDnsResponseTtl(m *dns.Msg) time.Duration {
	const (
		minSeconds = 60   // floor: one minute
		maxSeconds = 3600 // ceiling: one hour
	)

	var seconds uint32
	if len(m.Answer) > 0 {
		seconds = m.Answer[0].Header().Ttl
	}

	switch {
	case seconds < minSeconds:
		seconds = minSeconds
	case seconds > maxSeconds:
		seconds = maxSeconds
	}
	return time.Duration(seconds) * time.Second
}
// shouldCacheResponse reports whether a response is eligible for caching.
// Only NOERROR and NXDOMAIN responses qualify; every other response code
// (SERVFAIL, FORMERR, ...) is rejected.
func shouldCacheResponse(m *dns.Msg) bool {
	switch m.Rcode {
	case dns.RcodeSuccess, dns.RcodeNameError:
		return true
	default:
		return false
	}
}
// validateResponse performs sanity checks on an upstream response to reduce
// the risk of cache poisoning. It returns true when the response looks
// consistent with the request and false when it should not be trusted.
//
// Checks, in order:
//  1. A nil response is invalid.
//  2. If either message has no question section, validation is skipped and
//     the response is accepted.
//  3. The first question's name (case-insensitive), type and class must match.
//  4. A CNAME record whose owner is neither the queried name nor the target
//     of some CNAME in the answer invalidates the response. Other record
//     types in that situation only produce a warning.
//  5. TTLs above 7 days only produce a warning.
func validateResponse(req *dns.Msg, resp *dns.Msg, debugLogger logger.Logger) bool {
	if resp == nil {
		return false
	}
	if len(req.Question) == 0 || len(resp.Question) == 0 {
		return true
	}

	asked, echoed := req.Question[0], resp.Question[0]
	switch {
	case !strings.EqualFold(asked.Name, echoed.Name):
		debugLogger.Printf("DNS response validation failed: domain mismatch - request: %s, response: %s",
			asked.Name, echoed.Name)
		return false
	case asked.Qtype != echoed.Qtype:
		debugLogger.Printf("DNS response validation failed: qtype mismatch - request: %d, response: %d",
			asked.Qtype, echoed.Qtype)
		return false
	case asked.Qclass != echoed.Qclass:
		debugLogger.Printf("DNS response validation failed: qclass mismatch - request: %d, response: %d",
			asked.Qclass, echoed.Qclass)
		return false
	}

	// Names are compared lower-cased and without the trailing root dot.
	normalize := func(name string) string {
		return strings.ToLower(strings.TrimSuffix(name, "."))
	}

	requestDomain := normalize(asked.Name)
	allowed := map[string]bool{requestDomain: true}

	// First pass: every CNAME target becomes an acceptable owner name.
	for _, rr := range resp.Answer {
		if rr.Header().Rrtype != dns.TypeCNAME {
			continue
		}
		if cname, ok := rr.(*dns.CNAME); ok {
			allowed[normalize(cname.Target)] = true
		}
	}

	// Second pass: check each record's owner name against the allowed set.
	for _, rr := range resp.Answer {
		hdr := rr.Header()
		owner := normalize(hdr.Name)
		if allowed[owner] {
			continue
		}
		if hdr.Rrtype == dns.TypeCNAME {
			// owner is not in the allowed set, so it cannot be the request
			// domain: a CNAME outside the chain is rejected.
			debugLogger.Printf("DNS response validation failed: CNAME domain mismatch - request: %s, CNAME: %s",
				requestDomain, owner)
			return false
		}
		// Other record types are tolerated; some servers add extra records.
		debugLogger.Printf("DNS response validation warning: answer domain not in valid chain - request: %s, answer: %s (type: %d)",
			requestDomain, owner, hdr.Rrtype)
	}

	// Third pass: flag implausibly large TTLs (more than 7 days).
	const maxPlausibleTTL = 604800
	for _, rr := range resp.Answer {
		if ttl := rr.Header().Ttl; ttl > maxPlausibleTTL {
			debugLogger.Printf("DNS response validation warning: suspiciously high TTL: %d seconds for %s",
				ttl, rr.Header().Name)
		}
	}
	return true
}
// HandleDnsMsg is the core query path: it records statistics, consults the
// cache, forwards the query upstream, and caches the validated response.
//
// clientIP and domain are used for statistics only; when domain is empty it is
// taken from the first question of req.
//
// Flow:
//   - A fresh cache entry with at least one answer is returned immediately.
//   - Otherwise the query is forwarded. If the upstream result is SERVFAIL and
//     a cache entry (even an expired one) with answers exists, that entry is
//     served with a short TTL of 12 seconds.
//   - The upstream response is turned into a reply for req and, if it passes
//     shouldCacheResponse and validateResponse, stored in the cache.
//
// The upstream response code is preserved in the returned message.
func (h *Handler) HandleDnsMsg(req *dns.Msg, clientIP, domain string) *dns.Msg {
	h.logger.Printf("godns::request %+v\n", req)

	// Query statistics.
	if h.stats != nil {
		h.stats.RecordQuery()
		if domain == "" && len(req.Question) > 0 {
			domain = req.Question[0].Name
		}
		if clientIP != "" || domain != "" {
			h.stats.RecordClientQuery(clientIP, domain)
		}
	}

	// Cache lookup.
	var cacheKey string
	var respCache *dns.Msg
	if h.builtInCache != nil {
		cacheKey = getDnsRequestCacheKey(req)
		if v, ok := h.builtInCache.Get(cacheKey); ok {
			if h.stats != nil {
				h.stats.RecordCacheHit()
			}
			// Work on a copy so the stored entry is never mutated.
			respCache = v.Msg.Copy()
			if v.Expires.After(time.Now()) {
				msg := replyUpdateTtl(req, respCache, uint32(time.Until(v.Expires).Seconds()))
				if len(msg.Answer) > 0 {
					return msg
				}
			}
		} else if h.stats != nil {
			h.stats.RecordCacheMiss()
		}
	}

	// Ask the upstreams.
	resp := h.exchange(req)
	if resp.Rcode == dns.RcodeServerFailure {
		if h.stats != nil {
			h.stats.RecordFailed()
		}
		// Degrade gracefully: serve any cached entry, even an expired one.
		if respCache != nil {
			msg := replyUpdateTtl(req, respCache, 12)
			if len(msg.Answer) > 0 {
				return msg
			}
		}
	}

	// Decide cacheability on the response as the upstream sent it. SetReply
	// below overwrites the question section and resets Rcode to NOERROR, so
	// checking afterwards would compare the request with itself and would
	// treat SERVFAIL as cacheable. An empty key (request without a question)
	// is never stored.
	cacheable := h.builtInCache != nil && cacheKey != "" &&
		shouldCacheResponse(resp) && validateResponse(req, resp, h.logger)

	// Turn the upstream message into a reply for this request while keeping
	// the upstream response code (NXDOMAIN, SERVFAIL, ...).
	rcode := resp.Rcode
	resp.SetReply(req)
	resp.Rcode = rcode
	h.logger.Printf("godns::resp: %+v\n", resp)

	if cacheable {
		ttl := getDnsResponseTtl(resp)
		cachedMsg := &cache.CachedMsg{
			Msg:     resp,
			Expires: time.Now().Add(ttl),
		}
		// Keep the entry an hour past its logical expiry so it stays
		// available as a stale fallback when upstreams fail.
		if err := h.builtInCache.Set(cacheKey, cachedMsg, ttl+time.Hour); err != nil {
			h.logger.Printf("Failed to cache response: %v", err)
		}
	}
	return resp
}
// extractClientIPFromDNS returns the client IP for a request as a string.
//
// Precedence:
//  1. The address of the first EDNS Client Subnet option that carries a
//     usable address. An option with no address, or with the unspecified
//     address (0.0.0.0 / ::, as sent by clients that opt out of ECS), is
//     ignored so that it is not mistaken for a real client.
//  2. The IP of the connection's remote address, for UDP and TCP.
//
// It returns "" when neither source yields an address.
func extractClientIPFromDNS(w dns.ResponseWriter, req *dns.Msg) string {
	if opt := req.IsEdns0(); opt != nil {
		for _, option := range opt.Option {
			ecs, ok := option.(*dns.EDNS0_SUBNET)
			if !ok {
				continue
			}
			if len(ecs.Address) == 0 || ecs.Address.IsUnspecified() {
				// No real client address here; keep looking, then fall back
				// to the transport address.
				continue
			}
			return ecs.Address.String()
		}
	}

	switch addr := w.RemoteAddr().(type) {
	case *net.UDPAddr:
		return addr.IP.String()
	case *net.TCPAddr:
		return addr.IP.String()
	}
	return ""
}
// HandleRequest is the dns server entry point: it derives the client IP and
// queried domain for statistics, delegates to HandleDnsMsg, and writes the
// resulting message back to the client. Write errors are logged.
func (h *Handler) HandleRequest(w dns.ResponseWriter, req *dns.Msg) {
	domain := ""
	if len(req.Question) > 0 {
		domain = req.Question[0].Name
	}

	// The client IP is read before HandleDnsMsg runs, because forwarding
	// strips the client-subnet option from req.
	resp := h.HandleDnsMsg(req, extractClientIPFromDNS(w, req), domain)

	if err := w.WriteMsg(resp); err != nil {
		h.logger.Printf("WriteMsg error: %+v", err)
	}
}
// builderPool hands out *strings.Builder values; freshly created builders are
// pre-grown to 128 bytes.
//
// NOTE(review): strings.Builder.Reset releases the underlying buffer, so a
// builder that is Reset before being returned to the pool does not keep this
// capacity.
var builderPool = sync.Pool{
	New: func() interface{} {
		builder := &strings.Builder{}
		builder.Grow(128)
		return builder
	},
}
// uniqueAnswer removes duplicate records while preserving the order of first
// occurrence. Nil records and records without a header are dropped.
//
// For the common record types two records are duplicates when their owner
// name and record data match (TTL and class are ignored). For SOA only the
// primary name server and mailbox are compared. Any other type is compared by
// its full textual form, which includes the TTL.
//
// Keys are built with plain concatenation: a pooled strings.Builder gives no
// benefit here because Reset discards the builder's buffer.
func uniqueAnswer(records []dns.RR) []dns.RR {
	if len(records) == 0 {
		return records
	}

	seen := make(map[string]struct{}, len(records))
	result := make([]dns.RR, 0, len(records))
	for _, rr := range records {
		if rr == nil {
			continue
		}
		header := rr.Header()
		if header == nil {
			continue
		}

		var key string
		switch v := rr.(type) {
		case *dns.A:
			key = header.Name + "|A|" + v.A.String()
		case *dns.AAAA:
			key = header.Name + "|AAAA|" + v.AAAA.String()
		case *dns.CNAME:
			key = header.Name + "|CNAME|" + v.Target
		case *dns.MX:
			key = fmt.Sprintf("%s|MX|%d|%s", header.Name, v.Preference, v.Mx)
		case *dns.NS:
			key = header.Name + "|NS|" + v.Ns
		case *dns.PTR:
			key = header.Name + "|PTR|" + v.Ptr
		case *dns.TXT:
			key = header.Name + "|TXT|" + strings.Join(v.Txt, "|")
		case *dns.SRV:
			key = fmt.Sprintf("%s|SRV|%d|%d|%d|%s", header.Name, v.Priority, v.Weight, v.Port, v.Target)
		case *dns.SOA:
			key = header.Name + "|SOA|" + v.Ns + "|" + v.Mbox
		default:
			key = rr.String()
		}

		if _, dup := seen[key]; dup {
			continue
		}
		seen[key] = struct{}{}
		result = append(result, rr)
	}
	return result
}
// getTheFullestResults queries every matched upstream concurrently and waits
// for all of them. The returned slice has one slot per upstream, in upstream
// order; a slot is nil when that upstream returned an error or a message that
// its IsValidMsg check rejected.
func (h *Handler) getTheFullestResults(req *dns.Msg) []*dns.Msg {
	upstreams := h.matchedUpstreams(req)
	results := make([]*dns.Msg, len(upstreams))

	var pending sync.WaitGroup
	for idx, up := range upstreams {
		pending.Add(1)
		go func(slot int, up *model.Upstream) {
			defer pending.Done()

			// Each upstream gets its own copy of the request.
			reply, _, err := up.Exchange(req.Copy())

			if h.stats != nil {
				h.stats.RecordUpstreamQuery(up.Address, err != nil)
			}
			if err != nil {
				h.logger.Printf("upstream error %s: %v %s", up.Address, model.GetDomainNameFromDnsMsg(req), err)
				return
			}
			// Every goroutine writes only its own slot, so no lock is needed.
			if up.IsValidMsg(reply) {
				results[slot] = reply
			}
		}(idx, up)
	}
	pending.Wait()
	return results
}
// getTheFastestResults queries every matched upstream concurrently and returns
// as soon as enough replies have arrived, without waiting for slower
// upstreams.
//
// The returned slice has one slot per upstream; a slot is set only when that
// upstream answered without error and its IsValidMsg check accepted the
// message before the function decided to finish.
//
// The function finishes when either
//   - every upstream has completed, or
//   - at least one primary upstream has answered without error AND either the
//     first such primary delivered a valid message or some non-primary
//     upstream delivered a valid one.
//
// Replies that arrive after that point are discarded. With no upstreams an
// empty slice is returned immediately.
func (h *Handler) getTheFastestResults(req *dns.Msg) []*dns.Msg {
	preferUpstreams := h.matchedUpstreams(req)
	msgs := make([]*dns.Msg, len(preferUpstreams))
	if len(preferUpstreams) == 0 {
		// No goroutine would ever signal completion, so waiting below would
		// block forever.
		return msgs
	}

	var mutex sync.Mutex // guards everything below plus msgs
	var finishedCount int
	var finished bool
	// primaryIndex: primary upstreams that answered without error.
	// freedomIndex: non-primary upstreams that delivered a valid message.
	var freedomIndex, primaryIndex []int

	// wg counts a single "decision reached" event, not the goroutines.
	var wg sync.WaitGroup
	wg.Add(1)

	for i := 0; i < len(preferUpstreams); i++ {
		go func(j int) {
			upstream := preferUpstreams[j]
			msg, _, err := upstream.Exchange(req.Copy())
			if h.stats != nil {
				h.stats.RecordUpstreamQuery(upstream.Address, err != nil)
			}
			if err != nil {
				h.logger.Printf("upstream error %s: %v %s", upstream.Address, model.GetDomainNameFromDnsMsg(req), err)
			}

			mutex.Lock()
			defer mutex.Unlock()
			if finished {
				// The caller has already been released; msgs must not be
				// touched any more.
				return
			}
			finishedCount++

			if err == nil {
				valid := upstream.IsValidMsg(msg)
				if valid {
					msgs[j] = msg
				}
				switch {
				case upstream.IsPrimary:
					primaryIndex = append(primaryIndex, j)
				case valid:
					freedomIndex = append(freedomIndex, j)
				}
			}

			primarySettled := len(primaryIndex) > 0 &&
				(msgs[primaryIndex[0]] != nil || len(freedomIndex) > 0)
			if finishedCount == len(preferUpstreams) || primarySettled {
				finished = true
				wg.Done()
			}
		}(i)
	}
	wg.Wait()
	return msgs
}
// getAnyResult queries every matched upstream concurrently and returns as soon
// as the first upstream answers without error, or once all of them have
// completed.
//
// The returned slice has one slot per upstream. At most one slot is set: the
// one belonging to the upstream that ended the wait. If every upstream failed,
// that slot holds whatever message the last upstream returned with its error.
// With no upstreams an empty slice is returned immediately.
func (h *Handler) getAnyResult(req *dns.Msg) []*dns.Msg {
	matchedUpstreams := h.matchedUpstreams(req)
	msgs := make([]*dns.Msg, len(matchedUpstreams))
	if len(matchedUpstreams) == 0 {
		// No goroutine would ever signal completion, so waiting below would
		// block forever.
		return msgs
	}

	var mutex sync.Mutex // guards finishedCount, finished and msgs
	var finishedCount int
	var finished bool

	// wg counts a single "decision reached" event, not the goroutines.
	var wg sync.WaitGroup
	wg.Add(1)

	for i := 0; i < len(matchedUpstreams); i++ {
		go func(j int) {
			upstream := matchedUpstreams[j]
			msg, _, err := upstream.Exchange(req.Copy())
			if h.stats != nil {
				h.stats.RecordUpstreamQuery(upstream.Address, err != nil)
			}
			if err != nil {
				h.logger.Printf("upstream error %s: %v %s", upstream.Address, model.GetDomainNameFromDnsMsg(req), err)
			}

			mutex.Lock()
			defer mutex.Unlock()
			if finished {
				// A result has already been handed to the caller.
				return
			}
			finishedCount++

			if err == nil || finishedCount == len(matchedUpstreams) {
				finished = true
				msgs[j] = msg
				wg.Done()
			}
		}(i)
	}
	wg.Wait()
	return msgs
}
// Close shuts down the built-in cache and returns its error, if any. It is a
// no-op returning nil when the handler has no cache.
func (h *Handler) Close() error {
	if h.builtInCache == nil {
		return nil
	}
	return h.builtInCache.Close()
}
// GetCacheStats returns the cache's statistics string, or "Cache disabled"
// when the handler has no cache.
func (h *Handler) GetCacheStats() string {
	if h.builtInCache == nil {
		return "Cache disabled"
	}
	return h.builtInCache.Stats()
}
// replyUpdateTtl prepares a cached response for delivery to the client that
// sent req. resp is modified in place and returned.
//
//  1. Every record in the Answer, Authority and Additional sections gets its
//     TTL set to ttl (the remaining cache lifetime in seconds).
//  2. RRSIG records whose signature expiration lies in the past are removed.
//  3. An OPT record in the Additional section keeps its TTL field, which
//     carries the extended RCODE and flags rather than a TTL. Its UDP size is
//     set to the size advertised by the client when req has an OPT record,
//     and the scope length of any EDNS Client Subnet option is reset to 0.
//  4. SetReply copies the message ID, question and header flags from req.
//     The cached response code is preserved across that call, because
//     SetReply would otherwise reset it to NOERROR.
func replyUpdateTtl(req *dns.Msg, resp *dns.Msg, ttl uint32) *dns.Msg {
	now := uint32(time.Now().Unix())

	// expired reports whether rr is an RRSIG whose validity period has ended.
	expired := func(rr dns.RR) bool {
		sig, ok := rr.(*dns.RRSIG)
		return ok && sig.Expiration > 0 && now > sig.Expiration
	}

	// updateRRs rewrites the TTL of every record and drops expired RRSIGs and
	// records without a header.
	updateRRs := func(rrs []dns.RR) []dns.RR {
		var kept []dns.RR
		for _, rr := range rrs {
			header := rr.Header()
			if header == nil || expired(rr) {
				continue
			}
			header.Ttl = ttl
			kept = append(kept, rr)
		}
		return kept
	}

	resp.Answer = updateRRs(resp.Answer)
	resp.Ns = updateRRs(resp.Ns)

	// The Additional section needs special care for the OPT pseudo-record.
	reqOpt := req.IsEdns0()
	var validExtra []dns.RR
	for _, rr := range resp.Extra {
		if opt, ok := rr.(*dns.OPT); ok {
			if reqOpt != nil {
				opt.SetUDPSize(reqOpt.UDPSize())
			}
			for _, option := range opt.Option {
				if ecs, ok := option.(*dns.EDNS0_SUBNET); ok {
					// ecs is a pointer, so this updates the option in place.
					ecs.SourceScope = 0
				}
			}
			validExtra = append(validExtra, opt)
			continue
		}

		if header := rr.Header(); header != nil {
			if expired(rr) {
				continue
			}
			header.Ttl = ttl
		}
		validExtra = append(validExtra, rr)
	}
	resp.Extra = validExtra

	rcode := resp.Rcode
	resp.SetReply(req)
	resp.Rcode = rcode
	return resp
}