diff --git a/internal/handler/handler.go b/internal/handler/handler.go index 524e75c..2a94e16 100644 --- a/internal/handler/handler.go +++ b/internal/handler/handler.go @@ -7,6 +7,7 @@ import ( "strings" "sync" "time" + "strconv" "github.com/miekg/dns" @@ -164,21 +165,19 @@ func (h *Handler) exchange(req *dns.Msg) *dns.Msg { } func getDnsRequestCacheKey(m *dns.Msg) string { - var dnssec string - if o := m.IsEdns0(); o != nil { - // 区分 DNSSEC 请求,避免将非 DNSSEC 响应返回给需要 DNSSEC 的客户端 - if o.Do() { - dnssec = "DO" - } - // 服务多区域的公共dns使用 - // for _, s := range o.Option { - // switch e := s.(type) { - // case *dns.EDNS0_SUBNET: - // edns = e.Address.String() - // } - // } + if len(m.Question) == 0 { + return "" } - return fmt.Sprintf("%s#%d#%s", model.GetDomainNameFromDnsMsg(m), m.Question[0].Qtype, dnssec) + + var b strings.Builder + b.Grow(64) + b.WriteString(model.GetDomainNameFromDnsMsg(m)) + b.WriteByte('#') + b.WriteString(strconv.Itoa(int(m.Question[0].Qtype))) + if o := m.IsEdns0(); o != nil && o.Do() { + b.WriteString("#DO") + } + return b.String() } func getDnsResponseTtl(m *dns.Msg) time.Duration { @@ -417,8 +416,14 @@ func (h *Handler) HandleRequest(w dns.ResponseWriter, req *dns.Msg) { } } -// uniqueAnswer 去除重复的 DNS 资源记录 -// 基于域名、类型和记录数据进行去重,比字符串分割更高效和可靠 +var builderPool = sync.Pool{ + New: func() interface{} { + b := new(strings.Builder) + b.Grow(128) + return b + }, +} + func uniqueAnswer(records []dns.RR) []dns.RR { if len(records) == 0 { return records @@ -437,11 +442,8 @@ func uniqueAnswer(records []dns.RR) []dns.RR { continue } - // 构造唯一键:域名 + 类型 + 记录数据 - // 使用 strings.Builder 优化字符串拼接性能 - var builder strings.Builder - builder.Grow(128) // Pre-allocate reasonable capacity - + builder := builderPool.Get().(*strings.Builder) + var key string switch v := rr.(type) { case *dns.A: @@ -492,9 +494,10 @@ func uniqueAnswer(records []dns.RR) []dns.RR { builder.WriteString(v.Mbox) key = builder.String() default: - // 对于其他类型,回退到完整字符串表示 key = rr.String() } + builder.Reset() + builderPool.Put(builder) if !seen[key] { seen[key] = true @@ -551,7 +554,6 @@ func (h *Handler) getTheFastestResults(req *dns.Msg) []*dns.Msg { go func(j int) { msg, _, err := preferUpstreams[j].Exchange(req.Copy()) - // 记录上游服务器统计 if h.stats != nil { h.stats.RecordUpstreamQuery(preferUpstreams[j].Address, err != nil) } @@ -563,12 +565,12 @@ func (h *Handler) getTheFastestResults(req *dns.Msg) []*dns.Msg { mutex.Lock() defer mutex.Unlock() - finishedCount++ - // 已经结束直接退出 if finished { return } + finishedCount++ + if err == nil { if preferUpstreams[j].IsValidMsg(msg) { if preferUpstreams[j].IsPrimary { @@ -578,27 +580,15 @@ func (h *Handler) getTheFastestResults(req *dns.Msg) []*dns.Msg { } msgs[j] = msg } else if preferUpstreams[j].IsPrimary { - // 策略:国内 DNS 返回了 国外 服务器,计数但是不记入结果,以 国外 DNS 为准 primaryIndex = append(primaryIndex, j) } } - // 全部结束直接退出 - if finishedCount == len(preferUpstreams) { - finished = true - wg.Done() - return - } - // 两组 DNS 都有一个返回结果,退出 - if len(primaryIndex) > 0 && len(freedomIndex) > 0 { - finished = true - wg.Done() - return - } - // 满足任一条件退出 - // - 国内 DNS 返回了 国内 服务器 - // - 国内 DNS 返回国外服务器 且 国外 DNS 有可用结果 - if len(primaryIndex) > 0 && (msgs[primaryIndex[0]] != nil || len(freedomIndex) > 0) { + shouldFinish := finishedCount == len(preferUpstreams) || + (len(primaryIndex) > 0 && len(freedomIndex) > 0) || + (len(primaryIndex) > 0 && (msgs[primaryIndex[0]] != nil || len(freedomIndex) > 0)) + + if shouldFinish && !finished { finished = true wg.Done() } @@ -611,19 +601,19 @@ func (h *Handler) getTheFastestResults(req *dns.Msg) []*dns.Msg { func (h *Handler) getAnyResult(req *dns.Msg) []*dns.Msg { matchedUpstreams := h.matchedUpstreams(req) - - var wg sync.WaitGroup - wg.Add(1) msgs := make([]*dns.Msg, len(matchedUpstreams)) + var mutex sync.Mutex var finishedCount int var finished bool + var wg sync.WaitGroup + wg.Add(1) + for i := 0; i < len(matchedUpstreams); i++ { go func(j int) { msg, _, err := matchedUpstreams[j].Exchange(req.Copy()) - // 记录上游服务器统计 if h.stats != nil { h.stats.RecordUpstreamQuery(matchedUpstreams[j].Address, err != nil) } @@ -631,16 +621,18 @@ func (h *Handler) getAnyResult(req *dns.Msg) []*dns.Msg { if err != nil { h.logger.Printf("upstream error %s: %v %s", matchedUpstreams[j].Address, model.GetDomainNameFromDnsMsg(req), err) } + mutex.Lock() defer mutex.Unlock() - finishedCount++ if finished { return } - // 已结束或任意上游返回成功时退出 - if err == nil || finishedCount == len(matchedUpstreams) { + finishedCount++ + shouldFinish := err == nil || finishedCount == len(matchedUpstreams) + + if shouldFinish && !finished { finished = true msgs[j] = msg wg.Done()