package handler

import (
	"errors"
	"fmt"
	"net"
	"strconv"
	"strings"
	"sync"
	"time"

	"github.com/miekg/dns"

	"godns/internal/cache"
	"godns/internal/model"
	"godns/internal/stats"
	"godns/pkg/logger"
)

// Handler answers DNS queries by forwarding them to the configured upstreams
// according to a strategy, optionally caching responses and recording
// statistics.
type Handler struct {
	// strategy is compared against the model.Strategy* constants in exchange.
	strategy int
	// commonUpstreams have no Match rules; specialUpstreams have at least one.
	commonUpstreams, specialUpstreams []*model.Upstream
	// builtInCache is nil when caching is disabled or failed to initialize.
	builtInCache cache.Cache
	logger       logger.Logger
	// stats is optional; every use is guarded by a nil check.
	stats stats.StatsRecorder
}

// NewHandler builds a Handler from the given upstream list.
//
// Upstreams that declare at least one Match entry become "special" upstreams,
// all others become "common" upstreams. When builtInCache is true a BadgerDB
// cache is opened at dataPath; if that fails the error is logged and the
// handler runs without a cache.
func NewHandler(strategy int, builtInCache bool, upstreams []*model.Upstream, dataPath string, log logger.Logger, statsRecorder stats.StatsRecorder) *Handler {
	var c cache.Cache
	if builtInCache {
		var err error
		c, err = cache.NewBadgerCache(dataPath, log)
		if err != nil {
			log.Printf("Failed to initialize BadgerDB cache: %v", err)
			log.Printf("Cache will be disabled")
			// Assign a literal nil so later `h.builtInCache != nil` checks
			// cannot be fooled by a typed-nil value inside the interface.
			c = nil
		} else {
			log.Printf("BadgerDB cache initialized successfully at %s", dataPath)
		}
	}

	var commonUpstreams, specialUpstreams []*model.Upstream
	for _, u := range upstreams {
		if len(u.Match) > 0 {
			specialUpstreams = append(specialUpstreams, u)
		} else {
			commonUpstreams = append(commonUpstreams, u)
		}
	}

	return &Handler{
		strategy:         strategy,
		commonUpstreams:  commonUpstreams,
		specialUpstreams: specialUpstreams,
		builtInCache:     c,
		logger:           log,
		stats:            statsRecorder,
	}
}

// matchedUpstreams returns the special upstreams whose rules match the first
// question's name. It falls back to the common upstreams when the request has
// no question or when no special upstream matches.
func (h *Handler) matchedUpstreams(req *dns.Msg) []*model.Upstream {
	if len(req.Question) == 0 {
		return h.commonUpstreams
	}
	name := req.Question[0].Name

	var matched []*model.Upstream
	for _, u := range h.specialUpstreams {
		if u.IsMatch(name) {
			matched = append(matched, u)
		}
	}
	if len(matched) > 0 {
		return matched
	}
	return h.commonUpstreams
}

// LookupIP resolves host to an IP address using the handler's own upstreams.
//
// A literal IP address is parsed and returned directly. Otherwise an A query
// is sent through exchange and the last A record in the answer is returned.
// An error is returned when the answer holds no A record.
func (h *Handler) LookupIP(host string) (ip net.IP, err error) {
	if ip = net.ParseIP(host); ip != nil {
		return ip, nil
	}
	if !strings.HasSuffix(host, ".") {
		host += "."
	}

	m := new(dns.Msg)
	m.Id = dns.Id()
	m.RecursionDesired = true
	m.Question = []dns.Question{{Name: host, Qtype: dns.TypeA, Qclass: dns.ClassINET}}

	res := h.exchange(m)
	// Deliberately keep the LAST A record: the original author notes it is
	// usually a fallback address with a better chance of being reachable.
	for _, rr := range res.Answer {
		if a, ok := rr.(*dns.A); ok {
			ip = a.A
		}
	}
	if ip == nil {
		err = errors.New("no ipv4 address found")
	}
	h.logger.Printf("bootstrap LookupIP: %s %v --> %s %v", host, res.Answer, ip, err)
	return ip, err
}

// removeEDNS strips EDNS Client Subnet options from the request's OPT record,
// modifying req in place. All other EDNS options are kept; a request without
// an OPT record is left untouched.
func (h *Handler) removeEDNS(req *dns.Msg) {
	opt := req.IsEdns0()
	if opt == nil {
		return
	}

	var kept []dns.EDNS0
	for _, o := range opt.Option {
		if _, isECS := o.(*dns.EDNS0_SUBNET); isECS {
			h.logger.Printf("Removed EDNS Client Subnet from request")
			continue
		}
		kept = append(kept, o)
	}
	opt.Option = kept
}

// exchange forwards req to the upstreams chosen by the configured strategy and
// merges the results into a single message.
//
// The first non-nil upstream message is used as the base and the Answer
// sections of the remaining ones are appended to it, then de-duplicated.
// When no upstream produced a message (or the strategy is unknown) a fresh
// message with Rcode SERVFAIL is returned, so the result is never nil.
// The request is modified: its EDNS Client Subnet options are removed.
func (h *Handler) exchange(req *dns.Msg) *dns.Msg {
	h.removeEDNS(req)

	var msgs []*dns.Msg
	switch h.strategy {
	case model.StrategyFullest:
		msgs = h.getTheFullestResults(req)
	case model.StrategyFastest:
		msgs = h.getTheFastestResults(req)
	case model.StrategyAnyResult:
		msgs = h.getAnyResult(req)
	}

	var res *dns.Msg
	for _, m := range msgs {
		switch {
		case m == nil:
			// This upstream failed or was not waited for.
		case res == nil:
			res = m
		default:
			res.Answer = append(res.Answer, m.Answer...)
		}
	}

	if res == nil {
		// Every upstream failed: report an error to the caller.
		res = new(dns.Msg)
		res.Rcode = dns.RcodeServerFailure
		return res
	}
	res.Answer = uniqueAnswer(res.Answer)
	return res
}

// getDnsRequestCacheKey derives the cache key for a request as
// "<domain>#<qtype>", with "#DO" appended when the request sets the DNSSEC OK
// bit. It returns "" for a request without a question section.
func getDnsRequestCacheKey(m *dns.Msg) string {
	if len(m.Question) == 0 {
		return ""
	}
	var b strings.Builder
	b.Grow(64)
	b.WriteString(model.GetDomainNameFromDnsMsg(m))
	b.WriteByte('#')
	b.WriteString(strconv.Itoa(int(m.Question[0].Qtype)))
	if o := m.IsEdns0(); o != nil && o.Do() {
		b.WriteString("#DO")
	}
	return b.String()
}

// getDnsResponseTtl returns how long a response should be considered fresh:
// the TTL of the first answer record clamped to [60s, 3600s]. A response
// without answers gets the 60s minimum.
func getDnsResponseTtl(m *dns.Msg) time.Duration {
	const (
		minTTL uint32 = 60   // at least one minute
		maxTTL uint32 = 3600 // at most one hour
	)
	var ttl uint32
	if len(m.Answer) > 0 {
		ttl = m.Answer[0].Header().Ttl
	}
	switch {
	case ttl < minTTL:
		ttl = minTTL
	case ttl > maxTTL:
		ttl = maxTTL
	}
	return time.Duration(ttl) * time.Second
}

// shouldCacheResponse reports whether a response may be cached. Only NOERROR
// and NXDOMAIN qualify; every other rcode (SERVFAIL, FORMERR, ...) is not
// cached.
func shouldCacheResponse(m *dns.Msg) bool {
	return m.Rcode == dns.RcodeSuccess || m.Rcode == dns.RcodeNameError
}

// validateResponse performs sanity checks on an upstream response to reduce
// the risk of cache poisoning. It returns true when the response looks valid.
//
// Checks performed:
//   - resp must not be nil;
//   - when both messages carry a question, name (case-insensitive), type and
//     class of the first question must match; if either lacks a question the
//     response is accepted without further checks;
//   - a CNAME record must be owned by the requested name or by the target of
//     another CNAME in the answer, otherwise the response is rejected;
//   - other records with an unrelated owner name, and TTLs above 7 days, only
//     produce a warning in debugLogger.
func validateResponse(req *dns.Msg, resp *dns.Msg, debugLogger logger.Logger) bool {
	if resp == nil {
		return false
	}
	// Some responses carry no question section; skip validation for those.
	if len(req.Question) == 0 || len(resp.Question) == 0 {
		return true
	}

	want, got := req.Question[0], resp.Question[0]
	if !strings.EqualFold(want.Name, got.Name) {
		debugLogger.Printf("DNS response validation failed: domain mismatch - request: %s, response: %s", want.Name, got.Name)
		return false
	}
	if want.Qtype != got.Qtype {
		debugLogger.Printf("DNS response validation failed: qtype mismatch - request: %d, response: %d", want.Qtype, got.Qtype)
		return false
	}
	if want.Qclass != got.Qclass {
		debugLogger.Printf("DNS response validation failed: qclass mismatch - request: %d, response: %d", want.Qclass, got.Qclass)
		return false
	}

	normalize := func(name string) string {
		return strings.ToLower(strings.TrimSuffix(name, "."))
	}

	// The set of acceptable owner names: the requested name plus every CNAME
	// target that appears in the answer.
	requestDomain := normalize(want.Name)
	validDomains := map[string]bool{requestDomain: true}
	for _, rr := range resp.Answer {
		if cname, ok := rr.(*dns.CNAME); ok && rr.Header().Rrtype == dns.TypeCNAME {
			validDomains[normalize(cname.Target)] = true
		}
	}

	for _, rr := range resp.Answer {
		hdr := rr.Header()
		owner := normalize(hdr.Name)
		if validDomains[owner] {
			continue
		}
		if hdr.Rrtype == dns.TypeCNAME {
			// A CNAME outside the chain is rejected outright.
			debugLogger.Printf("DNS response validation failed: CNAME domain mismatch - request: %s, CNAME: %s", requestDomain, owner)
			return false
		}
		// Some servers add extra records; warn but do not reject.
		debugLogger.Printf("DNS response validation warning: answer domain not in valid chain - request: %s, answer: %s (type: %d)", requestDomain, owner, hdr.Rrtype)
	}

	// Flag TTLs longer than 7 days (604800 seconds) as suspicious.
	for _, rr := range resp.Answer {
		if ttl := rr.Header().Ttl; ttl > 604800 {
			debugLogger.Printf("DNS response validation warning: suspiciously high TTL: %d seconds for %s", ttl, rr.Header().Name)
		}
	}
	return true
}

// HandleDnsMsg is the core query path: statistics, cache lookup, upstream
// exchange, and cache store.
//
// clientIP and domain are used for statistics only; when domain is empty it is
// taken from the first question. Flow:
//  1. A fresh cache entry with at least one answer is returned immediately.
//  2. Otherwise the request is sent upstream. If that yields SERVFAIL and any
//     cache entry exists (even an expired one), it is served with a 12s TTL.
//  3. The upstream response is turned into a reply for req and, when it is
//     cacheable and passes validateResponse, stored in the cache.
//
// The upstream response code is preserved in the reply. Validation runs
// against the upstream's own question section, before SetReply replaces it
// with the request's. Requests without a question bypass the cache.
func (h *Handler) HandleDnsMsg(req *dns.Msg, clientIP, domain string) *dns.Msg {
	h.logger.Printf("godns::request %+v\n", req)

	if h.stats != nil {
		h.stats.RecordQuery()
		if domain == "" && len(req.Question) > 0 {
			domain = req.Question[0].Name
		}
		if clientIP != "" || domain != "" {
			h.stats.RecordClientQuery(clientIP, domain)
		}
	}

	var cacheKey string
	if h.builtInCache != nil {
		cacheKey = getDnsRequestCacheKey(req)
	}
	// An empty key means "no question"; such requests must not share a single
	// cache slot, so they skip the cache entirely.
	useCache := cacheKey != ""

	var respCache *dns.Msg
	if useCache {
		if v, ok := h.builtInCache.Get(cacheKey); ok {
			if h.stats != nil {
				h.stats.RecordCacheHit()
			}
			// Work on a copy: replyUpdateTtl mutates the message.
			respCache = v.Msg.Copy()
			if remaining := time.Until(v.Expires); remaining > 0 {
				msg := replyUpdateTtl(req, respCache, uint32(remaining.Seconds()))
				if len(msg.Answer) > 0 {
					return msg
				}
			}
		} else if h.stats != nil {
			h.stats.RecordCacheMiss()
		}
	}

	resp := h.exchange(req)
	if resp.Rcode == dns.RcodeServerFailure {
		if h.stats != nil {
			h.stats.RecordFailed()
		}
		// Degrade gracefully: serve any cached entry, even an expired one.
		if respCache != nil {
			msg := replyUpdateTtl(req, respCache, 12)
			if len(msg.Answer) > 0 {
				return msg
			}
		}
	}

	// Decide cacheability BEFORE SetReply: SetReply overwrites both the rcode
	// and the question section, which would make these checks meaningless.
	cacheable := useCache && shouldCacheResponse(resp) && validateResponse(req, resp, h.logger)

	// SetReply resets Rcode to NOERROR; keep the upstream's verdict so that
	// SERVFAIL / NXDOMAIN reach the client unchanged.
	rcode := resp.Rcode
	resp.SetReply(req)
	resp.Rcode = rcode
	h.logger.Printf("godns::resp: %+v\n", resp)

	if cacheable {
		ttl := getDnsResponseTtl(resp)
		entry := &cache.CachedMsg{
			Msg:     resp,
			Expires: time.Now().Add(ttl),
		}
		// Keep the entry one hour past its freshness window so it can still be
		// served as a stale fallback when upstreams fail.
		if err := h.builtInCache.Set(cacheKey, entry, ttl+time.Hour); err != nil {
			h.logger.Printf("Failed to cache response: %v", err)
		}
	}
	return resp
}

// extractClientIPFromDNS returns the client IP for a request as a string.
// The address in an EDNS Client Subnet option takes priority; otherwise the IP
// of the connection's UDP or TCP remote address is used. It returns "" when
// neither is available.
func extractClientIPFromDNS(w dns.ResponseWriter, req *dns.Msg) string {
	if opt := req.IsEdns0(); opt != nil {
		for _, o := range opt.Option {
			if ecs, ok := o.(*dns.EDNS0_SUBNET); ok {
				return ecs.Address.String()
			}
		}
	}

	switch addr := w.RemoteAddr().(type) {
	case *net.UDPAddr:
		return addr.IP.String()
	case *net.TCPAddr:
		return addr.IP.String()
	}
	return ""
}

// HandleRequest is the dns.Handler-style entry point: it resolves req via
// HandleDnsMsg and writes the reply to w, logging any write error.
func (h *Handler) HandleRequest(w dns.ResponseWriter, req *dns.Msg) {
	clientIP := extractClientIPFromDNS(w, req)

	var domain string
	if len(req.Question) > 0 {
		domain = req.Question[0].Name
	}

	resp := h.HandleDnsMsg(req, clientIP, domain)
	if err := w.WriteMsg(resp); err != nil {
		h.logger.Printf("WriteMsg error: %+v", err)
	}
}

// builderPool hands out pre-grown string builders.
//
// NOTE(review): uniqueAnswer no longer uses this pool, because
// strings.Builder.Reset releases the underlying buffer and so pooling builders
// saves nothing. It is kept in case other files in the package reference it.
var builderPool = sync.Pool{
	New: func() interface{} {
		b := new(strings.Builder)
		b.Grow(128)
		return b
	},
}

// answerKey returns the de-duplication key for rr. Common record types are
// keyed on owner name, type and record data (TTL excluded); any other type
// falls back to the record's full string form.
func answerKey(rr dns.RR, header *dns.RR_Header) string {
	switch v := rr.(type) {
	case *dns.A:
		return header.Name + "|A|" + v.A.String()
	case *dns.AAAA:
		return header.Name + "|AAAA|" + v.AAAA.String()
	case *dns.CNAME:
		return header.Name + "|CNAME|" + v.Target
	case *dns.MX:
		return fmt.Sprintf("%s|MX|%d|%s", header.Name, v.Preference, v.Mx)
	case *dns.NS:
		return header.Name + "|NS|" + v.Ns
	case *dns.PTR:
		return header.Name + "|PTR|" + v.Ptr
	case *dns.TXT:
		return header.Name + "|TXT|" + strings.Join(v.Txt, "|")
	case *dns.SRV:
		return fmt.Sprintf("%s|SRV|%d|%d|%d|%s", header.Name, v.Priority, v.Weight, v.Port, v.Target)
	case *dns.SOA:
		return header.Name + "|SOA|" + v.Ns + "|" + v.Mbox
	default:
		return rr.String()
	}
}

// uniqueAnswer removes duplicate records while preserving the order of first
// occurrence. Nil records and records without a header are dropped. An empty
// input is returned unchanged.
func uniqueAnswer(records []dns.RR) []dns.RR {
	if len(records) == 0 {
		return records
	}

	seen := make(map[string]bool, len(records))
	result := make([]dns.RR, 0, len(records))
	for _, rr := range records {
		if rr == nil {
			continue
		}
		header := rr.Header()
		if header == nil {
			continue
		}
		key := answerKey(rr, header)
		if seen[key] {
			continue
		}
		seen[key] = true
		result = append(result, rr)
	}
	return result
}

// getTheFullestResults queries every matched upstream in parallel and waits
// for all of them. The returned slice is index-aligned with the upstream list;
// an entry is nil when that upstream failed or returned a message it considers
// invalid.
func (h *Handler) getTheFullestResults(req *dns.Msg) []*dns.Msg {
	upstreams := h.matchedUpstreams(req)
	msgs := make([]*dns.Msg, len(upstreams))

	var wg sync.WaitGroup
	wg.Add(len(upstreams))
	for i, u := range upstreams {
		go func(j int, up *model.Upstream) {
			defer wg.Done()

			msg, _, err := up.Exchange(req.Copy())
			if h.stats != nil {
				h.stats.RecordUpstreamQuery(up.Address, err != nil)
			}
			if err != nil {
				h.logger.Printf("upstream error %s: %v %s", up.Address, model.GetDomainNameFromDnsMsg(req), err)
				return
			}
			// Each goroutine writes only its own slot, so no lock is needed.
			if up.IsValidMsg(msg) {
				msgs[j] = msg
			}
		}(i, u)
	}
	wg.Wait()
	return msgs
}

// getTheFastestResults queries every matched upstream in parallel and returns
// as soon as the stop condition is met, without waiting for slower upstreams.
//
// It stops when all upstreams have reported, or when at least one primary
// upstream has answered without error and either the first such primary gave
// a valid message or some non-primary upstream gave a valid message. Results
// arriving after that point are discarded. The returned slice is index-aligned
// with the upstream list. With no upstreams it returns immediately.
func (h *Handler) getTheFastestResults(req *dns.Msg) []*dns.Msg {
	upstreams := h.matchedUpstreams(req)
	msgs := make([]*dns.Msg, len(upstreams))
	if len(upstreams) == 0 {
		// Nobody would ever signal completion; waiting would block forever.
		return msgs
	}

	var (
		mu            sync.Mutex
		reported      int
		finished      bool
		freedomIndex  []int // non-primary upstreams with a valid message
		primaryIndex  []int // primary upstreams that answered without error
		done          = make(chan struct{})
	)

	for i, u := range upstreams {
		go func(j int, up *model.Upstream) {
			msg, _, err := up.Exchange(req.Copy())
			if h.stats != nil {
				h.stats.RecordUpstreamQuery(up.Address, err != nil)
			}
			if err != nil {
				h.logger.Printf("upstream error %s: %v %s", up.Address, model.GetDomainNameFromDnsMsg(req), err)
			}

			mu.Lock()
			defer mu.Unlock()
			if finished {
				// The caller already has its result; do not touch msgs.
				return
			}
			reported++

			if err == nil {
				valid := up.IsValidMsg(msg)
				switch {
				case valid && up.IsPrimary:
					primaryIndex = append(primaryIndex, j)
					msgs[j] = msg
				case valid:
					freedomIndex = append(freedomIndex, j)
					msgs[j] = msg
				case up.IsPrimary:
					// An invalid primary answer still counts as "primary has
					// spoken", which lets a non-primary result be accepted.
					primaryIndex = append(primaryIndex, j)
				}
			}

			primaryDecided := len(primaryIndex) > 0 &&
				(len(freedomIndex) > 0 || msgs[primaryIndex[0]] != nil)
			if reported == len(upstreams) || primaryDecided {
				finished = true
				close(done)
			}
		}(i, u)
	}

	<-done
	return msgs
}

// getAnyResult queries every matched upstream in parallel and returns as soon
// as one answers without error, or once all of them have failed. Only the slot
// of the upstream that ended the wait is filled; it may hold nil when every
// upstream failed. With no upstreams it returns immediately.
func (h *Handler) getAnyResult(req *dns.Msg) []*dns.Msg {
	upstreams := h.matchedUpstreams(req)
	msgs := make([]*dns.Msg, len(upstreams))
	if len(upstreams) == 0 {
		// Nobody would ever signal completion; waiting would block forever.
		return msgs
	}

	var (
		mu       sync.Mutex
		reported int
		finished bool
		done     = make(chan struct{})
	)

	for i, u := range upstreams {
		go func(j int, up *model.Upstream) {
			msg, _, err := up.Exchange(req.Copy())
			if h.stats != nil {
				h.stats.RecordUpstreamQuery(up.Address, err != nil)
			}
			if err != nil {
				h.logger.Printf("upstream error %s: %v %s", up.Address, model.GetDomainNameFromDnsMsg(req), err)
			}

			mu.Lock()
			defer mu.Unlock()
			if finished {
				return
			}
			reported++
			if err == nil || reported == len(upstreams) {
				finished = true
				msgs[j] = msg
				close(done)
			}
		}(i, u)
	}

	<-done
	return msgs
}

// Close shuts down the built-in cache, if one is configured.
func (h *Handler) Close() error {
	if h.builtInCache == nil {
		return nil
	}
	return h.builtInCache.Close()
}

// GetCacheStats returns the cache's statistics string, or "Cache disabled"
// when no cache is configured.
func (h *Handler) GetCacheStats() string {
	if h.builtInCache == nil {
		return "Cache disabled"
	}
	return h.builtInCache.Stats()
}

// replyUpdateTtl
准备缓存响应以发送给客户端,执行必要的修正: // 1. 设置正确的 Message ID(通过 SetReply) // 2. 更新所有 RR 的 TTL 为剩余时间(最低 0) // 3. 调整 OPT RR 的 UDP size 为客户端请求的值 // 4. 清除 ECS Scope Length(标记为缓存答案) // 5. 检查过期的 RRSIG 并移除 func replyUpdateTtl(req *dns.Msg, resp *dns.Msg, ttl uint32) *dns.Msg { now := time.Now().Unix() // 辅助函数:更新 RR 列表的 TTL,并检测过期 RRSIG updateRRs := func(rrs []dns.RR) []dns.RR { var validRRs []dns.RR for _, rr := range rrs { header := rr.Header() if header == nil { continue } // 检查 RRSIG 是否过期 if rrsig, ok := rr.(*dns.RRSIG); ok { if rrsig.Expiration > 0 && uint32(now) > rrsig.Expiration { // RRSIG 已过期,跳过这条记录 continue } } // 更新 TTL(最低为 0) header.Ttl = ttl validRRs = append(validRRs, rr) } return validRRs } // 更新所有部分的 TTL 并移除过期 RRSIG resp.Answer = updateRRs(resp.Answer) resp.Ns = updateRRs(resp.Ns) // Extra 部分需要特殊处理 OPT RR var validExtra []dns.RR var reqOpt *dns.OPT if reqOpt = req.IsEdns0(); reqOpt != nil { // 客户端有 EDNS0,获取其 UDP size } for _, rr := range resp.Extra { if opt, ok := rr.(*dns.OPT); ok { // 处理 OPT RR if reqOpt != nil { // 使用客户端请求的 UDP size opt.SetUDPSize(reqOpt.UDPSize()) } // 清除 ECS Scope Length for i, option := range opt.Option { if ecs, ok := option.(*dns.EDNS0_SUBNET); ok { // 将 Scope Length 设为 0,表示这是缓存答案 ecs.SourceScope = 0 opt.Option[i] = ecs } } validExtra = append(validExtra, opt) } else { // 非 OPT RR,正常更新 TTL 和检查 RRSIG header := rr.Header() if header != nil { if rrsig, ok := rr.(*dns.RRSIG); ok { if rrsig.Expiration > 0 && uint32(now) > rrsig.Expiration { continue // 跳过过期的 RRSIG } } header.Ttl = ttl } validExtra = append(validExtra, rr) } } resp.Extra = validExtra // SetReply 会设置正确的 Message ID 和其他响应标志 return resp.SetReply(req) }