commit 7bf4f27be3c0417d002bf2242895c00c09fda361 Author: xofine Date: Tue Jan 6 02:25:24 2026 +0800 update once diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..151a849 --- /dev/null +++ b/.gitignore @@ -0,0 +1,22 @@ +# Binaries for programs and plugins +*.exe +*.exe~ +*.dll +*.so +*.dylib + +# Test binary, built with `go test -c` +*.test + +# Output of the go coverage tool, specifically when used with LiteIDE +*.out + +# Dependency directories (remove the comment below to include it) +# vendor/ +/data/* +!/data/config.json.example +/main +/godns +/debug +/build +.DS_Store diff --git a/.goreleaser.yml b/.goreleaser.yml new file mode 100644 index 0000000..72799da --- /dev/null +++ b/.goreleaser.yml @@ -0,0 +1,72 @@ +version: 2 + +before: + hooks: + - go mod tidy -v +builds: + - env: + - CGO_ENABLED=0 + ldflags: + - -s -w -X main.version={{.Version}} + goos: + - linux + - windows + - darwin + goarch: + - arm + - arm64 + - 386 + - amd64 + - mips + - mipsle + - s390x + - riscv64 + gomips: + - softfloat + ignore: + - goos: windows + goarch: arm + - goos: windows + goarch: arm64 + main: . + binary: godns +universal_binaries: + - name_template: "godns" + replace: false +checksum: + name_template: "checksums.txt" +snapshot: + version_template: "{{ .Version }}-SNAPSHOT-{{ .ShortCommit }}" +archives: + - name_template: "godns_{{ .Os }}_{{ .Arch }}" + formats: ["zip"] + files: + - LICENSE + - README.md + - data +dockers_v2: + - images: + - "ghcr.io/xofine/{{ .ProjectName }}" + tags: + - "{{ .Version }}" + - latest + platforms: + - linux/amd64 + - linux/arm64 + extra_files: + - README.md + labels: + "org.opencontainers.image.created": "{{.Date}}" + "org.opencontainers.image.title": "{{.ProjectName}}" + "org.opencontainers.image.revision": "{{.FullCommit}}" + "org.opencontainers.image.version": "{{.Version}}" +changelog: + sort: asc + filters: + exclude: + - "^docs:" + - "^test:" + - "^chore" + - Merge pull request + - Merge branch + - go mod tidy diff --git a/Dockerfile b/Dockerfile new file mode 100644 index 0000000..22840e7 --- /dev/null +++ b/Dockerfile @@ -0,0 +1,15 @@ +# 构建阶段 +FROM golang:alpine AS builder +WORKDIR /build +COPY . . +RUN go build -ldflags="-s -w" -o godns . + +# 运行阶段 +FROM alpine:latest +RUN apk --no-cache add ca-certificates +WORKDIR /godns +COPY --from=builder /build/godns . +COPY data ./data + +VOLUME ["/godns/data"] +ENTRYPOINT ["/godns/godns"] diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000..2b318c8 --- /dev/null +++ b/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2022 naiba (xofine after) + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/README.md b/README.md new file mode 100644 index 0000000..8234634 --- /dev/null +++ b/README.md @@ -0,0 +1,168 @@ +# GoDNS + +基于[NbDNS](https://github.com/naiba/nbdns)的个人修改版,并因为原名太过霸道而被迫改名。 + +以下内容来自原项目介绍。 + +:seal: 一个聪明的 DNS 中继器,可提升 DNS 解析准确性,自带管理面板,可替代 AdguardHome。 + +![截图](./doc/screenshot.png) + +## 快速开始 + +1. 从 [releases](https://github.com/naiba/nbdns/releases) 下载最新版本 +2. 下载 [china_ip_list.txt](https://github.com/17mon/china_ip_list/raw/master/china_ip_list.txt) 到 `data` 文件夹 +3. 创建配置文件 `data/config.json`(参考下方配置示例) +4. 启动 `./godns` +5. 访问 `http://localhost:8854` 查看监控面板 +6. DNS TCP/UDP `127.0.0.1:8853`, DoH `http://localhost:8854/dns-query` + +**文件结构:** +``` +|- godns +|- data + |- config.json + |- china_ip_list.txt +``` + +**测试命令:** +```bash +dig @127.0.0.1 -p 8853 www.baidu.com +dig @127.0.0.1 -p 8853 www.google.com +``` +Windows 上的 [dig](https://help.dyn.com/how-to-use-binds-dig-tool/) 工具 + +## 配置示例 + +```json +{ + "serve_addr": "127.0.0.1:8853", + "web_addr": "0.0.0.0:8854", + "strategy": 2, + "timeout": 4, + "built_in_cache": true, + "socks_proxy": "192.168.1.254:3838", + "bootstrap": [ + {"address": "tcp://8.8.4.4:53"}, + {"address": "tcp://1.0.0.1:53"} + ], + "upstreams": [ + {"address": "udp://223.5.5.5:53", "is_primary": true}, + {"address": "udp://223.6.6.6:53", "is_primary": true}, + {"address": "tcp-tls://dns.google:853", "use_socks": true}, + {"address": "tcp-tls://one.one.one.one:853", "use_socks": true}, + {"address": "https://user:pass@doh.example.com/dns-query", "match": [".onion"]} + ], + "doh_server": { + "username": "admin", + "password": "secret" + }, + "blacklist": [".bing.com"] +} +``` + +### 配置说明 + +| 字段 | 说明 | 默认值 | +| ---------------- | ---------------------------------------------------- | -------------- | +| `serve_addr` | DNS 服务监听地址 | 必填 | +| `web_addr` | Web 面板和 DoH 服务端口 | `0.0.0.0:8854` | +| `strategy` | 查询策略:1-最全结果,2-最快结果(推荐),3-任一结果 | `2` | +| `timeout` | 上游超时时间(秒) | `4` | +| `built_in_cache` | 启用内建缓存 | `false` | +| `socks_proxy` | SOCKS5 代理地址 | 可选 | +| `bootstrap` | Bootstrap DNS 服务器(仅支持 IP) | 必填 | +| `upstreams` | 上游 DNS 列表 | 必填 | +| `doh_server` | DoH 服务配置 | 可选 | +| `blacklist` | 域名黑名单(强制使用非 primary DNS) | 可选 | + +**上游 DNS 配置:** +- `is_primary`: 标记国内 DNS +- `use_socks`: 通过 SOCKS5 代理连接 +- `match`: 仅匹配特定域名后缀 + +**域名匹配规则:** +- `.` 匹配所有 +- `a.com` 仅匹配 a.com +- `.a.com` 匹配 a.a.com, c.a.com, e.d.a.com 等 + +## 功能特性 + +### :chart_with_upwards_trend: Web 监控面板 +访问 `http://localhost:8854` 查看: +- 运行时状态(运行时长、内存、Goroutines、GC) +- DNS 查询统计(总查询数、缓存命中率、失败数) +- 上游服务器状态(查询数、错误率、最后使用时间) +- Top 客户端 IP 和查询域名排行 +- 统计数据重置功能 + +### :lock: DoH (DNS over HTTPS) +DoH 服务与 Web 面板共用端口,访问路径:`/dns-query` + +**配置示例:** +```json +{ + "doh_server": { + "username": "admin", + "password": "secret" + } +} +``` + +**测试:** +```bash +curl -v -H "Accept: application/dns-message" \ + -u "user:password" \ + "http://localhost:8854/dns-query?dns=AAABAAABAAAAAAAAA3d3dwdleGFtcGxlA2NvbQAAAQAB" +``` + +**浏览器配置(Firefox):** +设置 → 网络设置 → 启用基于 HTTPS 的 DNS → 自定义 → `http://your-server:8854/dns-query` + +## 部署 + +### :whale: Docker +```bash +docker run --name godns --restart always -d \ + -v /path/to/data:/godns/data \ + -p 8853:8853/udp \ + -p 8854:8854 \ + ghcr.io/xofine/godns +``` + +### :package: OpenWRT 自启动 +首先在 release 下载对应的二进制解压 zip 包后放置到 `/root`,然后 `chmod -R 777 /root/godns` 赋予执行权限,然后创建 `/etc/init.d/godns`: + +```shell +#!/bin/sh /etc/rc.common +USE_PROCD=1 +# After network starts +START=21 +# Before network stops +STOP=89 + +cmd=/root/godns/godns +name=godns +pid_file="/var/run/${name}.pid" + +start_service() { + echo "Starting ${name}" + procd_open_instance + procd_set_param command ${cmd} + procd_set_param respawn + + # respawn automatically if something died, be careful if you have an alternative process supervisor + # if process exits sooner than respawn_threshold, it is considered crashed and after 5 retries the service is stopped + # if process finishes later than respawn_threshold, it is restarted unconditionally, regardless of error code + # notice that this is literal respawning of the process, no in a respawn-on-failure sense + procd_set_param respawn ${respawn_threshold:-3600} ${respawn_timeout:-5} ${respawn_retry:-5} + + procd_set_param stdout 1 # forward stdout of the command to logd + procd_set_param stderr 1 # same for stderr + procd_set_param pidfile ${pid_file} # write a pid file on instance start and remove it on stop + procd_close_instance + echo "${name} has been started" +} +``` + +赋予执行权限 `chmod +x /etc/init.d/godns` 然后启动服务 `/etc/init.d/godns enable && /etc/init.d/godns start` diff --git a/data/config.json.example b/data/config.json.example new file mode 100644 index 0000000..be4f194 --- /dev/null +++ b/data/config.json.example @@ -0,0 +1,61 @@ +{ + "debug": false, + "profiling": false, + "strategy": 2, + "timeout": 2, + "serve_addr": "127.0.0.1:8853", + "web_addr": "0.0.0.0:8854", + "socks_proxy": "", + "built_in_cache": false, + "max_active_connections": 50, + "max_idle_connections": 20, + "stats_save_interval": 5, + "doh_server": { + "username": "user", + "password": "password" + }, + "web_auth": { + "username": "admin", + "password": "your_secure_password" + } + "bootstrap": [ + { + "address": "udp://223.5.5.5:53" + }, + { + "address": "udp://223.6.6.6:53" + } + ], + "upstreams": [ + { + "address": "udp://223.5.5.5:53", + "is_primary": true + }, + { + "address": "udp://223.6.6.6:53", + "is_primary": true + }, + { + "address": "udp://114.114.114.114:53", + "is_primary": true + }, + { + "address": "udp://119.28.28.28:53", + "is_primary": true + }, + { + "address": "tcp-tls://one.one.one.one:853", + "use_socks": false + }, + { + "address": "https://dns.google/dns-query", + "use_socks": false, + "match": [ + ".*\\.onion" + ] + } + ], + "blacklist": [ + "^.*\\.?bing.com*" + ] +} diff --git a/doc/screenshot.png b/doc/screenshot.png new file mode 100644 index 0000000..cfcf6a2 Binary files /dev/null and b/doc/screenshot.png differ diff --git a/go.mod b/go.mod new file mode 100644 index 0000000..a06d188 --- /dev/null +++ b/go.mod @@ -0,0 +1,44 @@ +module godns + +go 1.23.0 + +toolchain go1.24.4 + +require ( + github.com/blang/semver v3.5.1+incompatible + github.com/dgraph-io/badger/v4 v4.8.0 + github.com/dropbox/godropbox v0.0.0-20230623171840-436d2007a9fd + github.com/miekg/dns v1.1.62 + github.com/pkg/errors v0.9.1 + github.com/rhysd/go-github-selfupdate v1.2.3 + github.com/yl2chen/cidranger v1.0.2 + go.uber.org/atomic v1.11.0 + golang.org/x/net v0.41.0 + golang.org/x/text v0.26.0 +) + +require ( + github.com/cespare/xxhash/v2 v2.3.0 // indirect + github.com/dgraph-io/ristretto/v2 v2.2.0 // indirect + github.com/dustin/go-humanize v1.0.1 // indirect + github.com/go-logr/logr v1.4.3 // indirect + github.com/go-logr/stdr v1.2.2 // indirect + github.com/google/flatbuffers v25.2.10+incompatible // indirect + github.com/google/go-github/v30 v30.1.0 // indirect + github.com/google/go-querystring v1.0.0 // indirect + github.com/inconshreveable/go-update v0.0.0-20160112193335-8152e7eb6ccf // indirect + github.com/klauspost/compress v1.18.0 // indirect + github.com/tcnksm/go-gitconfig v0.1.2 // indirect + github.com/ulikunitz/xz v0.5.9 // indirect + go.opentelemetry.io/auto/sdk v1.1.0 // indirect + go.opentelemetry.io/otel v1.37.0 // indirect + go.opentelemetry.io/otel/metric v1.37.0 // indirect + go.opentelemetry.io/otel/trace v1.37.0 // indirect + golang.org/x/crypto v0.39.0 // indirect + golang.org/x/mod v0.25.0 // indirect + golang.org/x/oauth2 v0.27.0 // indirect + golang.org/x/sync v0.15.0 // indirect + golang.org/x/sys v0.34.0 // indirect + golang.org/x/tools v0.33.0 // indirect + google.golang.org/protobuf v1.36.6 // indirect +) diff --git a/go.sum b/go.sum new file mode 100644 index 0000000..400f8ec --- /dev/null +++ b/go.sum @@ -0,0 +1,136 @@ +github.com/blang/semver v3.5.1+incompatible h1:cQNTCjp13qL8KC3Nbxr/y2Bqb63oX6wdnnjpJbkM4JQ= +github.com/blang/semver v3.5.1+incompatible/go.mod h1:kRBLl5iJ+tD4TcOOxsy/0fnwebNt5EWlYSAyrTnjyyk= +github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs= +github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/dgraph-io/badger/v4 v4.8.0 h1:JYph1ChBijCw8SLeybvPINizbDKWZ5n/GYbz2yhN/bs= +github.com/dgraph-io/badger/v4 v4.8.0/go.mod h1:U6on6e8k/RTbUWxqKR0MvugJuVmkxSNc79ap4917h4w= +github.com/dgraph-io/ristretto/v2 v2.2.0 h1:bkY3XzJcXoMuELV8F+vS8kzNgicwQFAaGINAEJdWGOM= +github.com/dgraph-io/ristretto/v2 v2.2.0/go.mod h1:RZrm63UmcBAaYWC1DotLYBmTvgkrs0+XhBd7Npn7/zI= +github.com/dgryski/go-farm v0.0.0-20240924180020-3414d57e47da h1:aIftn67I1fkbMa512G+w+Pxci9hJPB8oMnkcP3iZF38= +github.com/dgryski/go-farm v0.0.0-20240924180020-3414d57e47da/go.mod h1:SqUrOPUnsFjfmXRMNPybcSiG0BgUW2AuFH8PAnS2iTw= +github.com/dropbox/godropbox v0.0.0-20230623171840-436d2007a9fd h1:s2vYw+2c+7GR1ccOaDuDcKsmNB/4RIxyu5liBm1VRbs= +github.com/dropbox/godropbox v0.0.0-20230623171840-436d2007a9fd/go.mod h1:Vr/Q4p40Kce7JAHDITjDhiy/zk07W4tqD5YVi5FD0PA= +github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY= +github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto= +github.com/fsnotify/fsnotify v1.4.7/go.mod h1:jwhsz4b93w/PPRr/qN1Yymfu8t87LnFCMoQvtojpjFo= +github.com/go-logr/logr v1.2.2/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A= +github.com/go-logr/logr v1.4.3 h1:CjnDlHq8ikf6E492q6eKboGOC0T8CDaOvkHCIg8idEI= +github.com/go-logr/logr v1.4.3/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY= +github.com/go-logr/stdr v1.2.2 h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag= +github.com/go-logr/stdr v1.2.2/go.mod h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre4VKE= +github.com/gogo/protobuf v1.3.1/go.mod h1:SlYgWuQ5SjCEi6WLHjHCa1yvBfUnHcTbrrZtXPKa29o= +github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= +github.com/golang/protobuf v1.3.2/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= +github.com/google/flatbuffers v25.2.10+incompatible h1:F3vclr7C3HpB1k9mxCGRMXq6FdUalZ6H/pNX4FP1v0Q= +github.com/google/flatbuffers v25.2.10+incompatible/go.mod h1:1AeVuKshWv4vARoZatz6mlQ0JxURH0Kv5+zNeJKJCa8= +github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= +github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU= +github.com/google/go-github/v30 v30.1.0 h1:VLDx+UolQICEOKu2m4uAoMti1SxuEBAl7RSEG16L+Oo= +github.com/google/go-github/v30 v30.1.0/go.mod h1:n8jBpHl45a/rlBUtRJMOG4GhNADUQFEufcolZ95JfU8= +github.com/google/go-querystring v1.0.0 h1:Xkwi/a1rcvNg1PPYe5vI8GbeBY/jrVuDX5ASuANWTrk= +github.com/google/go-querystring v1.0.0/go.mod h1:odCYkC5MyYFN7vkCjXpyrEuKhc/BUO6wN/zVPAxq5ck= +github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpOxQnU= +github.com/inconshreveable/go-update v0.0.0-20160112193335-8152e7eb6ccf h1:WfD7VjIE6z8dIvMsI4/s+1qr5EL+zoIGev1BQj1eoJ8= +github.com/inconshreveable/go-update v0.0.0-20160112193335-8152e7eb6ccf/go.mod h1:hyb9oH7vZsitZCiBt0ZvifOrB+qc8PS5IiilCIb87rg= +github.com/kisielk/errcheck v1.2.0/go.mod h1:/BMXB+zMLi60iA8Vv6Ksmxu/1UDYcXs4uQLJ+jE2L00= +github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck= +github.com/klauspost/compress v1.18.0 h1:c/Cqfb0r+Yi+JtIEq73FWXVkRonBlf0CRNYc8Zttxdo= +github.com/klauspost/compress v1.18.0/go.mod h1:2Pp+KzxcywXVXMr50+X0Q/Lsb43OQHYWRCY2AiWywWQ= +github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= +github.com/kr/pretty v0.2.0/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI= +github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= +github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk= +github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= +github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= +github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= +github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= +github.com/miekg/dns v1.1.62 h1:cN8OuEF1/x5Rq6Np+h1epln8OiyPWV+lROx9LxcGgIQ= +github.com/miekg/dns v1.1.62/go.mod h1:mvDlcItzm+br7MToIKqkglaGhlFMHJ9DTNNWONWXbNQ= +github.com/onsi/ginkgo v1.6.0/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE= +github.com/onsi/gomega v1.4.2 h1:3mYCb7aPxS/RU7TI1y4rkEn1oKmPRjNJLNEXgw7MH2I= +github.com/onsi/gomega v1.4.2/go.mod h1:ex+gbHU/CVuBBDIJjb2X0qEXbFg53c61hWP/1CpauHY= +github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= +github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/rhysd/go-github-selfupdate v1.2.3 h1:iaa+J202f+Nc+A8zi75uccC8Wg3omaM7HDeimXA22Ag= +github.com/rhysd/go-github-selfupdate v1.2.3/go.mod h1:mp/N8zj6jFfBQy/XMYoWsmfzxazpPAODuqarmPDe2Rg= +github.com/rogpeppe/go-internal v1.13.1 h1:KvO1DLK/DRN07sQ1LQKScxyZJuNnedQ5/wKSR38lUII= +github.com/rogpeppe/go-internal v1.13.1/go.mod h1:uMEvuHeurkdAXX61udpOXGD/AzZDWNMNyH2VO9fmH0o= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= +github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= +github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= +github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA= +github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= +github.com/tcnksm/go-gitconfig v0.1.2 h1:iiDhRitByXAEyjgBqsKi9QU4o2TNtv9kPP3RgPgXBPw= +github.com/tcnksm/go-gitconfig v0.1.2/go.mod h1:/8EhP4H7oJZdIPyT+/UIsG87kTzrzM4UsLGSItWYCpE= +github.com/ulikunitz/xz v0.5.9 h1:RsKRIA2MO8x56wkkcd3LbtcE/uMszhb6DpRf+3uwa3I= +github.com/ulikunitz/xz v0.5.9/go.mod h1:nbz6k7qbPmH4IRqmfOplQw/tblSgqTqBwxkY0oWt/14= +github.com/yl2chen/cidranger v1.0.2 h1:lbOWZVCG1tCRX4u24kuM1Tb4nHqWkDxwLdoS+SevawU= +github.com/yl2chen/cidranger v1.0.2/go.mod h1:9U1yz7WPYDwf0vpNWFaeRh0bjwz5RVgRy/9UEQfHl0g= +go.opentelemetry.io/auto/sdk v1.1.0 h1:cH53jehLUN6UFLY71z+NDOiNJqDdPRaXzTel0sJySYA= +go.opentelemetry.io/auto/sdk v1.1.0/go.mod h1:3wSPjt5PWp2RhlCcmmOial7AvC4DQqZb7a7wCow3W8A= +go.opentelemetry.io/otel v1.37.0 h1:9zhNfelUvx0KBfu/gb+ZgeAfAgtWrfHJZcAqFC228wQ= +go.opentelemetry.io/otel v1.37.0/go.mod h1:ehE/umFRLnuLa/vSccNq9oS1ErUlkkK71gMcN34UG8I= +go.opentelemetry.io/otel/metric v1.37.0 h1:mvwbQS5m0tbmqML4NqK+e3aDiO02vsf/WgbsdpcPoZE= +go.opentelemetry.io/otel/metric v1.37.0/go.mod h1:04wGrZurHYKOc+RKeye86GwKiTb9FKm1WHtO+4EVr2E= +go.opentelemetry.io/otel/trace v1.37.0 h1:HLdcFNbRQBE2imdSEgm/kwqmQj1Or1l/7bW6mxVK7z4= +go.opentelemetry.io/otel/trace v1.37.0/go.mod h1:TlgrlQ+PtQO5XFerSPUYG0JSgGyryXewPGyayAWSBS0= +go.uber.org/atomic v1.11.0 h1:ZvwS0R+56ePWxUNi+Atn9dWONBPp/AUETXlHW0DxSjE= +go.uber.org/atomic v1.11.0/go.mod h1:LUxbIzbOniOlMKjJjyPfpl4v+PKK2cNJn91OQbhoJI0= +golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= +golang.org/x/crypto v0.0.0-20201221181555-eec23a3978ad/go.mod h1:jdWPYTVW3xRLrWPugEBEK3UY2ZEsg3UU495nc5E+M+I= +golang.org/x/crypto v0.39.0 h1:SHs+kF4LP+f+p14esP5jAoDpHU8Gu/v9lFRK6IT5imM= +golang.org/x/crypto v0.39.0/go.mod h1:L+Xg3Wf6HoL4Bn4238Z6ft6KfEpN0tJGo53AAPC632U= +golang.org/x/mod v0.25.0 h1:n7a+ZbQKQA/Ysbyb0/6IbB1H/X41mKgbhfv7AfG/44w= +golang.org/x/mod v0.25.0/go.mod h1:IXM97Txy2VM4PJ3gI61r1YEk/gAj6zAHN3AdZt6S9Ww= +golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= +golang.org/x/net v0.0.0-20180906233101-161cd47e91fd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= +golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= +golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= +golang.org/x/net v0.0.0-20200202094626-16171245cfb2/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= +golang.org/x/net v0.41.0 h1:vBTly1HeNPEn3wtREYfy4GZ/NECgw2Cnl+nK6Nz3uvw= +golang.org/x/net v0.41.0/go.mod h1:B/K4NNqkfmg07DQYrbwvSluqCJOOXwUjeb/5lOisjbA= +golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= +golang.org/x/oauth2 v0.0.0-20181106182150-f42d05182288/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= +golang.org/x/oauth2 v0.27.0 h1:da9Vo7/tDv5RH/7nZDz1eMGS/q1Vv1N/7FCrBhI9I3M= +golang.org/x/oauth2 v0.27.0/go.mod h1:onh5ek6nERTohokkhCD/y2cV4Do3fxFHFuAejCkRWT8= +golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.15.0 h1:KWH3jNZsfyT6xfAfKiz6MRNmd46ByHDYaZ7KSkCtdW8= +golang.org/x/sync v0.15.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA= +golang.org/x/sys v0.0.0-20180909124046-d0be0721c37e/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20191026070338-33540a1f6037/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.34.0 h1:H5Y5sJ2L2JRdyv7ROF1he/lPdvFsd0mJHFw2ThKHxLA= +golang.org/x/sys v0.34.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k= +golang.org/x/term v0.0.0-20201117132131-f5c789dd3221/go.mod h1:Nr5EML6q2oocZ2LXRh80K7BxOlk5/8JxuGnuhpl+muw= +golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= +golang.org/x/text v0.3.5/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= +golang.org/x/text v0.26.0 h1:P42AVeLghgTYr4+xUnTRKDMqpar+PtX7KWuNQL21L8M= +golang.org/x/text v0.26.0/go.mod h1:QK15LZJUUQVJxhz7wXgxSy/CJaTFjd0G+YLonydOVQA= +golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= +golang.org/x/tools v0.0.0-20181030221726-6c7e314b6563/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= +golang.org/x/tools v0.33.0 h1:4qz2S3zmRxbGIhDIAgjxvFutSvH5EfnsYrRBj0UI0bc= +golang.org/x/tools v0.33.0/go.mod h1:CIJMaWEY88juyUfo7UbgPqbC8rU2OqfAV1h2Qp0oMYI= +google.golang.org/appengine v1.1.0/go.mod h1:EbEs0AVv82hx2wNQdGPgUI5lhzA/G0D9YwlJXL52JkM= +google.golang.org/appengine v1.3.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4= +google.golang.org/protobuf v1.36.6 h1:z1NpPI8ku2WgiWnf+t9wTPsn6eP1L7ksHUlkfLvd9xY= +google.golang.org/protobuf v1.36.6/go.mod h1:jduwjTPXsFjZGTmRluh+L6NjiWu7pchiJ2/5YcXBHnY= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= +gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= +gopkg.in/fsnotify.v1 v1.4.7/go.mod h1:Tz8NjZHkW78fSQdbUxIjBTcgA1z1m8ZHf0WmKUhAMys= +gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7/go.mod h1:dt/ZhP58zS4L8KSrWDmTeBkI65Dw0HsyUHuEVlX15mw= +gopkg.in/yaml.v2 v2.2.1/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= +gopkg.in/yaml.v2 v2.2.2 h1:ZCJp+EgiOT7lHqUV2J862kp8Qj64Jo6az82+3Td9dZw= +gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/internal/cache/badger_cache.go b/internal/cache/badger_cache.go new file mode 100644 index 0000000..333591a --- /dev/null +++ b/internal/cache/badger_cache.go @@ -0,0 +1,247 @@ +package cache + +import ( + "fmt" + "path/filepath" + "time" + + "github.com/dgraph-io/badger/v4" + "github.com/dgraph-io/badger/v4/options" + "github.com/miekg/dns" + "godns/pkg/logger" +) + +// Cache 定义缓存接口 +type Cache interface { + Get(key string) (*CachedMsg, bool) + Set(key string, msg *CachedMsg, ttl time.Duration) error + Delete(key string) error + Close() error + Stats() string +} + +// CachedMsg represents a cached DNS message with expiration time +type CachedMsg struct { + Msg *dns.Msg `json:"msg"` + Expires time.Time `json:"expires"` +} + +// BadgerCache wraps BadgerDB for DNS query caching +type BadgerCache struct { + db *badger.DB + logger logger.Logger +} + +// NewBadgerCache creates a new BadgerDB cache instance with optimized settings for embedded devices +func NewBadgerCache(dataPath string, log logger.Logger) (*BadgerCache, error) { + dbPath := filepath.Join(dataPath, "cache") + + opts := badger.DefaultOptions(dbPath) + + // 针对树莓派等嵌入式设备的优化配置(目标:总内存 ~32MB) + // MemTable:4MB,BadgerDB 默认保持 2 个 MemTable + opts.MemTableSize = 4 << 20 // 4MB (内存占用 ~8MB) + + // ValueLog:4MB + opts.ValueLogFileSize = 4 << 20 // 4MB + + // BlockCache:16MB,提升读取性能 + opts.BlockCacheSize = 16 << 20 // 16MB + + // IndexCache:8MB,加速索引查找 + opts.IndexCacheSize = 8 << 20 // 8MB + + // Level 0 tables + opts.NumLevelZeroTables = 2 + opts.NumLevelZeroTablesStall = 4 + + // 关闭压缩,节省 CPU + opts.Compression = options.None + + // DNS 响应通常较小,内联存储减少磁盘访问 + opts.ValueThreshold = 512 + + // 异步写入,提高性能 + opts.SyncWrites = false + + // ValueLog 条目数量 + opts.ValueLogMaxEntries = 50000 + + // 压缩线程数 + opts.NumCompactors = 2 + + // 禁用冲突检测,提升写入性能 + opts.DetectConflicts = false + + // 禁用内部日志 + opts.Logger = nil + + db, err := badger.Open(opts) + if err != nil { + return nil, fmt.Errorf("failed to open BadgerDB: %w", err) + } + + cache := &BadgerCache{db: db, logger: log} + + // Start garbage collection routines + go cache.runGC() + go cache.runCompaction() + + return cache, nil +} + +// Set stores a DNS message in the cache with the given key and TTL +func (bc *BadgerCache) Set(key string, msg *CachedMsg, ttl time.Duration) error { + // Pack DNS message to wire format + dnsData, err := msg.Msg.Pack() + if err != nil { + return fmt.Errorf("failed to pack DNS message: %w", err) + } + + // 直接存储二进制数据:8字节过期时间 + DNS wire format + // 避免 JSON 序列化开销 + expiresBytes := make([]byte, 8) + // 使用 Unix 时间戳(秒) + expiresUnix := msg.Expires.Unix() + for i := 0; i < 8; i++ { + expiresBytes[i] = byte(expiresUnix >> (56 - i*8)) + } + + // 组合数据:过期时间 + DNS数据 + data := append(expiresBytes, dnsData...) + + return bc.db.Update(func(txn *badger.Txn) error { + entry := badger.NewEntry([]byte(key), data).WithTTL(ttl) + return txn.SetEntry(entry) + }) +} + +// Get retrieves a DNS message from the cache +func (bc *BadgerCache) Get(key string) (*CachedMsg, bool) { + var cachedMsg *CachedMsg + + err := bc.db.View(func(txn *badger.Txn) error { + item, err := txn.Get([]byte(key)) + if err != nil { + return err + } + + return item.Value(func(val []byte) error { + // 数据格式:8字节过期时间 + DNS wire format + if len(val) < 8 { + return fmt.Errorf("invalid cache data: too short") + } + + // 解析过期时间 + var expiresUnix int64 + for i := 0; i < 8; i++ { + expiresUnix = (expiresUnix << 8) | int64(val[i]) + } + expires := time.Unix(expiresUnix, 0) + + // 解析 DNS 消息 + msg := new(dns.Msg) + if err := msg.Unpack(val[8:]); err != nil { + return fmt.Errorf("failed to unpack DNS message: %w", err) + } + + cachedMsg = &CachedMsg{ + Msg: msg, + Expires: expires, + } + return nil + }) + }) + + if err != nil { + if err == badger.ErrKeyNotFound { + return nil, false + } + // 缓存数据损坏或格式不兼容,返回未命中,后续 Set 会覆盖 + bc.logger.Printf("Cache get error for key %s: %v", key, err) + return nil, false + } + + return cachedMsg, true +} + +// Delete removes a key from the cache +func (bc *BadgerCache) Delete(key string) error { + return bc.db.Update(func(txn *badger.Txn) error { + return txn.Delete([]byte(key)) + }) +} + +// Close closes the BadgerDB instance +func (bc *BadgerCache) Close() error { + return bc.db.Close() +} + +// runGC runs garbage collection periodically to clean up expired entries in value log +func (bc *BadgerCache) runGC() { + ticker := time.NewTicker(15 * time.Minute) + defer ticker.Stop() + + for range ticker.C { + // Run GC multiple times until no more rewrite is needed + gcCount := 0 + for { + err := bc.db.RunValueLogGC(0.5) + if err != nil { + if err != badger.ErrNoRewrite { + bc.logger.Printf("BadgerDB GC error: %v", err) + } + break + } + gcCount++ + // Limit GC runs and add delay to prevent CPU hogging + if gcCount >= 10 { + bc.logger.Printf("BadgerDB GC: reached max runs limit (10)") + break + } + // Sleep briefly between GC cycles to reduce CPU usage + time.Sleep(500 * time.Millisecond) + } + + if gcCount > 0 { + bc.logger.Printf("BadgerDB GC: completed %d runs", gcCount) + } + + // Check disk usage and clean if necessary + bc.checkAndCleanDiskUsage() + } +} + +// runCompaction runs LSM tree compaction periodically to clean up expired key metadata +func (bc *BadgerCache) runCompaction() { + ticker := time.NewTicker(1 * time.Hour) + defer ticker.Stop() + + for range ticker.C { + err := bc.db.Flatten(1) + if err != nil { + bc.logger.Printf("BadgerDB compaction error: %v", err) + } + } +} + +// checkAndCleanDiskUsage checks if cache exceeds size limit and triggers cleanup +func (bc *BadgerCache) checkAndCleanDiskUsage() { + lsm, vlog := bc.db.Size() + totalSize := lsm + vlog + maxSize := int64(50 << 20) // 50MB limit (适合家用路由器等嵌入式设备) + + if totalSize > maxSize { + bc.logger.Printf("Cache size %d MB exceeds limit %d MB, triggering cleanup", totalSize>>20, maxSize>>20) + // Force compaction to reduce size + if err := bc.db.Flatten(2); err != nil { + bc.logger.Printf("BadgerDB flatten error: %v", err) + } + } +} + +// Stats returns cache statistics +func (bc *BadgerCache) Stats() string { + lsm, vlog := bc.db.Size() + return fmt.Sprintf("LSM size: %d bytes, Value log size: %d bytes", lsm, vlog) +} diff --git a/internal/handler/handler.go b/internal/handler/handler.go new file mode 100644 index 0000000..524e75c --- /dev/null +++ b/internal/handler/handler.go @@ -0,0 +1,750 @@ +package handler + +import ( + "errors" + "fmt" + "net" + "strings" + "sync" + "time" + + "github.com/miekg/dns" + + "godns/internal/cache" + "godns/internal/model" + "godns/internal/stats" + "godns/pkg/logger" +) + +type Handler struct { + strategy int + commonUpstreams, specialUpstreams []*model.Upstream + builtInCache cache.Cache + logger logger.Logger + stats stats.StatsRecorder +} + +func NewHandler(strategy int, builtInCache bool, + upstreams []*model.Upstream, + dataPath string, + log logger.Logger, + statsRecorder stats.StatsRecorder) *Handler { + var c cache.Cache + if builtInCache { + var err error + c, err = cache.NewBadgerCache(dataPath, log) + if err != nil { + log.Printf("Failed to initialize BadgerDB cache: %v", err) + log.Printf("Cache will be disabled") + c = nil + } else { + log.Printf("BadgerDB cache initialized successfully at %s", dataPath) + } + } + var commonUpstreams, specialUpstreams []*model.Upstream + for i := 0; i < len(upstreams); i++ { + if len(upstreams[i].Match) > 0 { + specialUpstreams = append(specialUpstreams, upstreams[i]) + } else { + commonUpstreams = append(commonUpstreams, upstreams[i]) + } + } + return &Handler{ + strategy: strategy, + commonUpstreams: commonUpstreams, + specialUpstreams: specialUpstreams, + builtInCache: c, + logger: log, + stats: statsRecorder, + } +} + +func (h *Handler) matchedUpstreams(req *dns.Msg) []*model.Upstream { + if len(req.Question) == 0 { + return h.commonUpstreams + } + q := req.Question[0] + var matchedUpstreams []*model.Upstream + for i := 0; i < len(h.specialUpstreams); i++ { + if h.specialUpstreams[i].IsMatch(q.Name) { + matchedUpstreams = append(matchedUpstreams, h.specialUpstreams[i]) + } + } + if len(matchedUpstreams) > 0 { + return matchedUpstreams + } + return h.commonUpstreams +} + +func (h *Handler) LookupIP(host string) (ip net.IP, err error) { + if ip = net.ParseIP(host); ip != nil { + return ip, nil + } + if !strings.HasSuffix(host, ".") { + host += "." + } + m := new(dns.Msg) + m.Id = dns.Id() + m.RecursionDesired = true + m.Question = make([]dns.Question, 1) + m.Question[0] = dns.Question{Name: host, Qtype: dns.TypeA, Qclass: dns.ClassINET} + res := h.exchange(m) + // 取一个 IPv4 地址 + for i := 0; i < len(res.Answer); i++ { + if aRecord, ok := res.Answer[i].(*dns.A); ok { + ip = aRecord.A + } + } + // 选取最后一个(一般是备用,存活率高一些) + if ip == nil { + err = errors.New("no ipv4 address found") + } + + h.logger.Printf("bootstrap LookupIP: %s %v --> %s %v", host, res.Answer, ip, err) + return +} + +// removeEDNS 清理请求中的 EDNS 客户端子网信息 +func (h *Handler) removeEDNS(req *dns.Msg) { + opt := req.IsEdns0() + if opt == nil { + return + } + + // 过滤掉 EDNS Client Subnet 选项 + var newOptions []dns.EDNS0 + for _, option := range opt.Option { + if _, ok := option.(*dns.EDNS0_SUBNET); !ok { + // 保留非 ECS 的其他选项 + newOptions = append(newOptions, option) + } else { + h.logger.Printf("Removed EDNS Client Subnet from request") + } + } + opt.Option = newOptions +} + +func (h *Handler) exchange(req *dns.Msg) *dns.Msg { + // 清理 EDNS 客户端子网信息 + h.removeEDNS(req) + + var msgs []*dns.Msg + + switch h.strategy { + case model.StrategyFullest: + msgs = h.getTheFullestResults(req) + case model.StrategyFastest: + msgs = h.getTheFastestResults(req) + case model.StrategyAnyResult: + msgs = h.getAnyResult(req) + } + + var res *dns.Msg + + for i := 0; i < len(msgs); i++ { + if msgs[i] == nil { + continue + } + if res == nil { + res = msgs[i] + continue + } + res.Answer = append(res.Answer, msgs[i].Answer...) + } + + if res == nil { + // 如果全部上游挂了要返回错误 + res = new(dns.Msg) + res.Rcode = dns.RcodeServerFailure + } else { + res.Answer = uniqueAnswer(res.Answer) + } + + return res +} + +func getDnsRequestCacheKey(m *dns.Msg) string { + var dnssec string + if o := m.IsEdns0(); o != nil { + // 区分 DNSSEC 请求,避免将非 DNSSEC 响应返回给需要 DNSSEC 的客户端 + if o.Do() { + dnssec = "DO" + } + // 服务多区域的公共dns使用 + // for _, s := range o.Option { + // switch e := s.(type) { + // case *dns.EDNS0_SUBNET: + // edns = e.Address.String() + // } + // } + } + return fmt.Sprintf("%s#%d#%s", model.GetDomainNameFromDnsMsg(m), m.Question[0].Qtype, dnssec) +} + +func getDnsResponseTtl(m *dns.Msg) time.Duration { + var ttl uint32 + if len(m.Answer) > 0 { + ttl = m.Answer[0].Header().Ttl + } + if ttl < 60 { + ttl = 60 // 最小 ttl 1 分钟 + } else if ttl > 3600 { + ttl = 3600 // 最大 ttl 1 小时 + } + return time.Duration(ttl) * time.Second +} + +// shouldCacheResponse 判断响应是否应该被缓存 +func shouldCacheResponse(m *dns.Msg) bool { + // 不缓存服务器错误响应 + if m.Rcode == dns.RcodeServerFailure { + return false + } + + // 不缓存格式错误的响应 + if m.Rcode == dns.RcodeFormatError { + return false + } + + // NXDOMAIN (域名不存在) 可以缓存,但时间较短(由 getDnsResponseTtl 控制) + // NOERROR 和 NXDOMAIN 都可以缓存 + return m.Rcode == dns.RcodeSuccess || m.Rcode == dns.RcodeNameError +} + +// validateResponse 验证 DNS 响应,防止缓存投毒 +// 返回 true 表示响应有效,false 表示可能存在投毒风险 +func validateResponse(req *dns.Msg, resp *dns.Msg, debugLogger logger.Logger) bool { + // 1. 检查响应是否为空 + if resp == nil { + return false + } + + // 2. 检查请求和响应的问题数量 + if len(req.Question) == 0 || len(resp.Question) == 0 { + return true // 如果没有问题部分,跳过验证(某些响应可能没有问题部分) + } + + // 3. 验证域名匹配(不区分大小写) + if !strings.EqualFold(req.Question[0].Name, resp.Question[0].Name) { + debugLogger.Printf("DNS response validation failed: domain mismatch - request: %s, response: %s", + req.Question[0].Name, resp.Question[0].Name) + return false + } + + // 4. 验证查询类型匹配 + if req.Question[0].Qtype != resp.Question[0].Qtype { + debugLogger.Printf("DNS response validation failed: qtype mismatch - request: %d, response: %d", + req.Question[0].Qtype, resp.Question[0].Qtype) + return false + } + + // 5. 验证查询类别匹配(通常都是 IN - Internet) + if req.Question[0].Qclass != resp.Question[0].Qclass { + debugLogger.Printf("DNS response validation failed: qclass mismatch - request: %d, response: %d", + req.Question[0].Qclass, resp.Question[0].Qclass) + return false + } + + // 6. 验证 Answer 部分的域名(防止返回无关域名的记录) + requestDomain := strings.ToLower(strings.TrimSuffix(req.Question[0].Name, ".")) + validDomains := make(map[string]bool) + validDomains[requestDomain] = true + + // 第一遍:收集所有 CNAME 目标域名 + for _, answer := range resp.Answer { + if answer.Header().Rrtype == dns.TypeCNAME { + if cname, ok := answer.(*dns.CNAME); ok { + cnameTarget := strings.ToLower(strings.TrimSuffix(cname.Target, ".")) + validDomains[cnameTarget] = true + } + } + } + + // 第二遍:验证所有应答记录 + for _, answer := range resp.Answer { + answerDomain := strings.ToLower(strings.TrimSuffix(answer.Header().Name, ".")) + + // 检查应答记录的域名是否在有效域名列表中 + if !validDomains[answerDomain] { + // 对于 CNAME 记录,域名必须是请求域名 + if answer.Header().Rrtype == dns.TypeCNAME { + if answerDomain != requestDomain { + debugLogger.Printf("DNS response validation failed: CNAME domain mismatch - request: %s, CNAME: %s", + requestDomain, answerDomain) + return false + } + } else { + // 对于其他记录类型,记录警告但不拒绝(某些服务器可能返回额外记录) + debugLogger.Printf("DNS response validation warning: answer domain not in valid chain - request: %s, answer: %s (type: %d)", + requestDomain, answerDomain, answer.Header().Rrtype) + } + } + } + + // 7. 检查 TTL 值的合理性(防止异常的 TTL 值) + for _, answer := range resp.Answer { + ttl := answer.Header().Ttl + // TTL 不应该超过 7 天(604800 秒) + if ttl > 604800 { + debugLogger.Printf("DNS response validation warning: suspiciously high TTL: %d seconds for %s", + ttl, answer.Header().Name) + } + } + + return true +} + +// HandleDnsMsg 处理 DNS 查询的核心逻辑(支持缓存和统计) +// clientIP 和 domain 用于统计,如果为空则自动从请求中提取 domain +func (h *Handler) HandleDnsMsg(req *dns.Msg, clientIP, domain string) *dns.Msg { + h.logger.Printf("godns::request %+v\n", req) + + // 记录查询统计 + if h.stats != nil { + h.stats.RecordQuery() + + // 提取域名(如果未提供) + if domain == "" && len(req.Question) > 0 { + domain = req.Question[0].Name + } + + // 记录客户端查询 + if clientIP != "" || domain != "" { + h.stats.RecordClientQuery(clientIP, domain) + } + } + + // 检查缓存 + var cacheKey string + var respCache *dns.Msg + if h.builtInCache != nil { + cacheKey = getDnsRequestCacheKey(req) + if v, ok := h.builtInCache.Get(cacheKey); ok { + if h.stats != nil { + h.stats.RecordCacheHit() + } + respCache = v.Msg.Copy() + if v.Expires.After(time.Now()) { + msg := replyUpdateTtl(req, respCache, uint32(time.Until(v.Expires).Seconds())) + if len(msg.Answer) > 0 { + return msg + } + } + } else { + if h.stats != nil { + h.stats.RecordCacheMiss() + } + } + } + + // 从上游获取响应 + resp := h.exchange(req) + + if resp.Rcode == dns.RcodeServerFailure { + if h.stats != nil { + h.stats.RecordFailed() + } + // 上游失败时使用任何可用缓存(即使过期)作为降级 + if respCache != nil { + msg := replyUpdateTtl(req, respCache, 12) + if len(msg.Answer) > 0 { + return msg + } + } + } + + resp.SetReply(req) + h.logger.Printf("godns::resp: %+v\n", resp) + + // 验证响应并缓存(防止缓存投毒) + if h.builtInCache != nil && shouldCacheResponse(resp) && validateResponse(req, resp, h.logger) { + ttl := getDnsResponseTtl(resp) + cachedMsg := &cache.CachedMsg{ + Msg: resp, + Expires: time.Now().Add(ttl), + } + if err := h.builtInCache.Set(cacheKey, cachedMsg, ttl+time.Hour); err != nil { + h.logger.Printf("Failed to cache response: %v", err) + } + } + + return resp +} + +// extractClientIPFromDNS 从 DNS 请求中提取客户端 IP +// 优先级:EDNS Client Subnet > RemoteAddr +func extractClientIPFromDNS(w dns.ResponseWriter, req *dns.Msg) string { + // 1. 优先检查 EDNS Client Subnet (ECS) + // ECS 是 DNS 协议标准,用于传递真实客户端 IP + if opt := req.IsEdns0(); opt != nil { + for _, option := range opt.Option { + if ecs, ok := option.(*dns.EDNS0_SUBNET); ok { + // ECS 中的 Address 就是客户端真实 IP + return ecs.Address.String() + } + } + } + + // 2. 从 RemoteAddr 获取 + var clientIP string + if addr := w.RemoteAddr(); addr != nil { + if udpAddr, ok := addr.(*net.UDPAddr); ok { + clientIP = udpAddr.IP.String() + } else if tcpAddr, ok := addr.(*net.TCPAddr); ok { + clientIP = tcpAddr.IP.String() + } + } + + return clientIP +} + +func (h *Handler) HandleRequest(w dns.ResponseWriter, req *dns.Msg) { + // 提取客户端 IP + clientIP := extractClientIPFromDNS(w, req) + + // 提取域名 + var domain string + if len(req.Question) > 0 { + domain = req.Question[0].Name + } + + // 调用核心处理逻辑 + resp := h.HandleDnsMsg(req, clientIP, domain) + + // 写入响应 + if err := w.WriteMsg(resp); err != nil { + h.logger.Printf("WriteMsg error: %+v", err) + } +} + +// uniqueAnswer 去除重复的 DNS 资源记录 +// 基于域名、类型和记录数据进行去重,比字符串分割更高效和可靠 +func uniqueAnswer(records []dns.RR) []dns.RR { + if len(records) == 0 { + return records + } + + seen := make(map[string]bool, len(records)) + result := make([]dns.RR, 0, len(records)) + + for _, rr := range records { + if rr == nil { + continue + } + + header := rr.Header() + if header == nil { + continue + } + + // 构造唯一键:域名 + 类型 + 记录数据 + // 使用 strings.Builder 优化字符串拼接性能 + var builder strings.Builder + builder.Grow(128) // Pre-allocate reasonable capacity + + var key string + switch v := rr.(type) { + case *dns.A: + builder.WriteString(header.Name) + builder.WriteString("|A|") + builder.WriteString(v.A.String()) + key = builder.String() + case *dns.AAAA: + builder.WriteString(header.Name) + builder.WriteString("|AAAA|") + builder.WriteString(v.AAAA.String()) + key = builder.String() + case *dns.CNAME: + builder.WriteString(header.Name) + builder.WriteString("|CNAME|") + builder.WriteString(v.Target) + key = builder.String() + case *dns.MX: + builder.WriteString(header.Name) + builder.WriteString("|MX|") + builder.WriteString(fmt.Sprintf("%d|%s", v.Preference, v.Mx)) + key = builder.String() + case *dns.NS: + builder.WriteString(header.Name) + builder.WriteString("|NS|") + builder.WriteString(v.Ns) + key = builder.String() + case *dns.PTR: + builder.WriteString(header.Name) + builder.WriteString("|PTR|") + builder.WriteString(v.Ptr) + key = builder.String() + case *dns.TXT: + builder.WriteString(header.Name) + builder.WriteString("|TXT|") + builder.WriteString(strings.Join(v.Txt, "|")) + key = builder.String() + case *dns.SRV: + builder.WriteString(header.Name) + builder.WriteString("|SRV|") + builder.WriteString(fmt.Sprintf("%d|%d|%d|%s", v.Priority, v.Weight, v.Port, v.Target)) + key = builder.String() + case *dns.SOA: + builder.WriteString(header.Name) + builder.WriteString("|SOA|") + builder.WriteString(v.Ns) + builder.WriteString("|") + builder.WriteString(v.Mbox) + key = builder.String() + default: + // 对于其他类型,回退到完整字符串表示 + key = rr.String() + } + + if !seen[key] { + seen[key] = true + result = append(result, rr) + } + } + + return result +} + +func (h *Handler) getTheFullestResults(req *dns.Msg) []*dns.Msg { + matchedUpstreams := h.matchedUpstreams(req) + var wg sync.WaitGroup + wg.Add(len(matchedUpstreams)) + msgs := make([]*dns.Msg, len(matchedUpstreams)) + + for i := 0; i < len(matchedUpstreams); i++ { + go func(j int) { + defer wg.Done() + msg, _, err := matchedUpstreams[j].Exchange(req.Copy()) + + // 记录上游服务器统计 + if h.stats != nil { + h.stats.RecordUpstreamQuery(matchedUpstreams[j].Address, err != nil) + } + + if err != nil { + h.logger.Printf("upstream error %s: %v %s", matchedUpstreams[j].Address, model.GetDomainNameFromDnsMsg(req), err) + return + } + if matchedUpstreams[j].IsValidMsg(msg) { + msgs[j] = msg + } + }(i) + } + + wg.Wait() + return msgs +} + +func (h *Handler) getTheFastestResults(req *dns.Msg) []*dns.Msg { + preferUpstreams := h.matchedUpstreams(req) + msgs := make([]*dns.Msg, len(preferUpstreams)) + + var mutex sync.Mutex + var finishedCount int + var finished bool + var freedomIndex, primaryIndex []int + + var wg sync.WaitGroup + wg.Add(1) + + for i := 0; i < len(preferUpstreams); i++ { + go func(j int) { + msg, _, err := preferUpstreams[j].Exchange(req.Copy()) + + // 记录上游服务器统计 + if h.stats != nil { + h.stats.RecordUpstreamQuery(preferUpstreams[j].Address, err != nil) + } + + if err != nil { + h.logger.Printf("upstream error %s: %v %s", preferUpstreams[j].Address, model.GetDomainNameFromDnsMsg(req), err) + } + + mutex.Lock() + defer mutex.Unlock() + + finishedCount++ + // 已经结束直接退出 + if finished { + return + } + + if err == nil { + if preferUpstreams[j].IsValidMsg(msg) { + if preferUpstreams[j].IsPrimary { + primaryIndex = append(primaryIndex, j) + } else { + freedomIndex = append(freedomIndex, j) + } + msgs[j] = msg + } else if preferUpstreams[j].IsPrimary { + // 策略:国内 DNS 返回了 国外 服务器,计数但是不记入结果,以 国外 DNS 为准 + primaryIndex = append(primaryIndex, j) + } + } + + // 全部结束直接退出 + if finishedCount == len(preferUpstreams) { + finished = true + wg.Done() + return + } + // 两组 DNS 都有一个返回结果,退出 + if len(primaryIndex) > 0 && len(freedomIndex) > 0 { + finished = true + wg.Done() + return + } + // 满足任一条件退出 + // - 国内 DNS 返回了 国内 服务器 + // - 国内 DNS 返回国外服务器 且 国外 DNS 有可用结果 + if len(primaryIndex) > 0 && (msgs[primaryIndex[0]] != nil || len(freedomIndex) > 0) { + finished = true + wg.Done() + } + }(i) + } + + wg.Wait() + return msgs +} + +func (h *Handler) getAnyResult(req *dns.Msg) []*dns.Msg { + matchedUpstreams := h.matchedUpstreams(req) + + var wg sync.WaitGroup + wg.Add(1) + msgs := make([]*dns.Msg, len(matchedUpstreams)) + var mutex sync.Mutex + var finishedCount int + var finished bool + + for i := 0; i < len(matchedUpstreams); i++ { + go func(j int) { + msg, _, err := matchedUpstreams[j].Exchange(req.Copy()) + + // 记录上游服务器统计 + if h.stats != nil { + h.stats.RecordUpstreamQuery(matchedUpstreams[j].Address, err != nil) + } + + if err != nil { + h.logger.Printf("upstream error %s: %v %s", matchedUpstreams[j].Address, model.GetDomainNameFromDnsMsg(req), err) + } + mutex.Lock() + defer mutex.Unlock() + + finishedCount++ + if finished { + return + } + + // 已结束或任意上游返回成功时退出 + if err == nil || finishedCount == len(matchedUpstreams) { + finished = true + msgs[j] = msg + wg.Done() + } + }(i) + } + + wg.Wait() + return msgs +} + +// Close properly shuts down the cache +func (h *Handler) Close() error { + if h.builtInCache != nil { + return h.builtInCache.Close() + } + return nil +} + +// GetCacheStats returns cache statistics +func (h *Handler) GetCacheStats() string { + if h.builtInCache != nil { + return h.builtInCache.Stats() + } + return "Cache disabled" +} + +// replyUpdateTtl 准备缓存响应以发送给客户端,执行必要的修正: +// 1. 设置正确的 Message ID(通过 SetReply) +// 2. 更新所有 RR 的 TTL 为剩余时间(最低 0) +// 3. 调整 OPT RR 的 UDP size 为客户端请求的值 +// 4. 清除 ECS Scope Length(标记为缓存答案) +// 5. 检查过期的 RRSIG 并移除 +func replyUpdateTtl(req *dns.Msg, resp *dns.Msg, ttl uint32) *dns.Msg { + now := time.Now().Unix() + + // 辅助函数:更新 RR 列表的 TTL,并检测过期 RRSIG + updateRRs := func(rrs []dns.RR) []dns.RR { + var validRRs []dns.RR + for _, rr := range rrs { + header := rr.Header() + if header == nil { + continue + } + + // 检查 RRSIG 是否过期 + if rrsig, ok := rr.(*dns.RRSIG); ok { + if rrsig.Expiration > 0 && uint32(now) > rrsig.Expiration { + // RRSIG 已过期,跳过这条记录 + continue + } + } + + // 更新 TTL(最低为 0) + header.Ttl = ttl + validRRs = append(validRRs, rr) + } + return validRRs + } + + // 更新所有部分的 TTL 并移除过期 RRSIG + resp.Answer = updateRRs(resp.Answer) + resp.Ns = updateRRs(resp.Ns) + + // Extra 部分需要特殊处理 OPT RR + var validExtra []dns.RR + var reqOpt *dns.OPT + if reqOpt = req.IsEdns0(); reqOpt != nil { + // 客户端有 EDNS0,获取其 UDP size + } + + for _, rr := range resp.Extra { + if opt, ok := rr.(*dns.OPT); ok { + // 处理 OPT RR + if reqOpt != nil { + // 使用客户端请求的 UDP size + opt.SetUDPSize(reqOpt.UDPSize()) + } + + // 清除 ECS Scope Length + for i, option := range opt.Option { + if ecs, ok := option.(*dns.EDNS0_SUBNET); ok { + // 将 Scope Length 设为 0,表示这是缓存答案 + ecs.SourceScope = 0 + opt.Option[i] = ecs + } + } + validExtra = append(validExtra, opt) + } else { + // 非 OPT RR,正常更新 TTL 和检查 RRSIG + header := rr.Header() + if header != nil { + if rrsig, ok := rr.(*dns.RRSIG); ok { + if rrsig.Expiration > 0 && uint32(now) > rrsig.Expiration { + continue // 跳过过期的 RRSIG + } + } + header.Ttl = ttl + } + validExtra = append(validExtra, rr) + } + } + resp.Extra = validExtra + + // SetReply 会设置正确的 Message ID 和其他响应标志 + return resp.SetReply(req) +} diff --git a/internal/model/config.go b/internal/model/config.go new file mode 100644 index 0000000..bfdd0de --- /dev/null +++ b/internal/model/config.go @@ -0,0 +1,120 @@ +package model + +import ( + "encoding/json" + "net" + "os" + + "godns/pkg/logger" + "godns/pkg/utils" + "github.com/pkg/errors" + "github.com/yl2chen/cidranger" + "golang.org/x/net/proxy" +) + +const ( + _ = iota + StrategyFullest + StrategyFastest + StrategyAnyResult +) + +type DohServerConfig struct { + Username string `json:"username,omitempty"` // DoH Basic Auth 用户名(可选) + Password string `json:"password,omitempty"` // DoH Basic Auth 密码(可选) +} + +type WebAuth struct { + Username string `json:"username"` + Password string `json:"password"` +} + +type Config struct { + ServeAddr string `json:"serve_addr,omitempty"` + WebAddr string `json:"web_addr,omitempty"` + DohServer *DohServerConfig `json:"doh_server,omitempty"` + Strategy int `json:"strategy,omitempty"` + Timeout int `json:"timeout,omitempty"` + SocksProxy string `json:"socks_proxy,omitempty"` + BuiltInCache bool `json:"built_in_cache,omitempty"` + Upstreams []*Upstream `json:"upstreams,omitempty"` + Bootstrap []*Upstream `json:"bootstrap,omitempty"` + Blacklist []string `json:"blacklist,omitempty"` + + Debug bool `json:"debug,omitempty"` + Profiling bool `json:"profiling,omitempty"` + + // Connection pool settings + MaxActiveConnections int `json:"max_active_connections,omitempty"` // Default: 50 + MaxIdleConnections int `json:"max_idle_connections,omitempty"` // Default: 20 + + // Stats persistence interval in minutes + StatsSaveInterval int `json:"stats_save_interval,omitempty"` // Default: 5 minutes + + BlacklistSplited [][]string `json:"-"` + // Web 面板鉴权 + WebAuth *WebAuth `json:"web_auth,omitempty"` +} + +func (c *Config) ReadInConfig(path string, ipRanger cidranger.Ranger, log logger.Logger) error { + body, err := os.ReadFile(path) + if err != nil { + return err + } + if err := json.Unmarshal([]byte(body), c); err != nil { + return err + } + + // Set default connection pool values + if c.MaxActiveConnections == 0 { + c.MaxActiveConnections = 50 + } + if c.MaxIdleConnections == 0 { + c.MaxIdleConnections = 20 + } + + // Set default stats save interval (5 minutes) + if c.StatsSaveInterval == 0 { + c.StatsSaveInterval = 5 + } + + for i := 0; i < len(c.Bootstrap); i++ { + c.Bootstrap[i].Init(c, ipRanger, log) + if net.ParseIP(c.Bootstrap[i].host) == nil { + return errors.New("Bootstrap 服务器只能使用 IP: " + c.Bootstrap[i].Address) + } + c.Bootstrap[i].InitConnectionPool(nil) + } + for i := 0; i < len(c.Upstreams); i++ { + c.Upstreams[i].Init(c, ipRanger, log) + if err := c.Upstreams[i].Validate(); err != nil { + return err + } + } + c.BlacklistSplited = utils.ParseRules(c.Blacklist) + return nil +} + +func (c *Config) GetDialerContext(d *net.Dialer) (proxy.Dialer, proxy.ContextDialer, error) { + dialSocksProxy, err := proxy.SOCKS5("tcp", c.SocksProxy, nil, d) + if err != nil { + return nil, nil, errors.Wrap(err, "Error creating SOCKS5 proxy") + } + if dialContext, ok := dialSocksProxy.(proxy.ContextDialer); !ok { + return nil, nil, errors.New("Failed type assertion to DialContext") + } else { + return dialSocksProxy, dialContext, err + } +} + +func (c *Config) StrategyName() string { + switch c.Strategy { + case StrategyFullest: + return "最全结果" + case StrategyFastest: + return "最快结果" + case StrategyAnyResult: + return "任一结果(建议仅 bootstrap)" + } + panic("invalid strategy") +} diff --git a/internal/model/upstream.go b/internal/model/upstream.go new file mode 100644 index 0000000..e2896ba --- /dev/null +++ b/internal/model/upstream.go @@ -0,0 +1,282 @@ +package model + +import ( + "crypto/tls" + "fmt" + "net" + "runtime" + "strings" + "time" + + "github.com/dropbox/godropbox/net2" + "github.com/miekg/dns" + "github.com/pkg/errors" + "github.com/yl2chen/cidranger" + "go.uber.org/atomic" + + "godns/pkg/doh" + "godns/pkg/logger" + "godns/pkg/utils" +) + +type Upstream struct { + IsPrimary bool `json:"is_primary,omitempty"` + UseSocks bool `json:"use_socks,omitempty"` + Address string `json:"address,omitempty"` + Match []string `json:"match,omitempty"` + + protocol, hostAndPort, host, port string + config *Config + ipRanger cidranger.Ranger + matchSplited [][]string + + pool net2.ConnectionPool + dohClient *doh.Client + bootstrap func(host string) (net.IP, error) + logger logger.Logger + + count *atomic.Int64 +} + +func (up *Upstream) Init(config *Config, ipRanger cidranger.Ranger, log logger.Logger) { + var ok bool + up.protocol, up.hostAndPort, ok = strings.Cut(up.Address, "://") + if ok && up.protocol != "https" { + up.host, up.port, ok = strings.Cut(up.hostAndPort, ":") + } + if !ok { + panic("上游地址格式(protocol://host:port)有误:" + up.Address) + } + + if up.count != nil { + panic("Upstream 已经初始化过了:" + up.Address) + } + + up.matchSplited = utils.ParseRules(up.Match) + up.count = atomic.NewInt64(0) + up.config = config + up.ipRanger = ipRanger + up.logger = log +} + +// SetLogger 更新 upstream 的 logger 实例 +func (up *Upstream) SetLogger(log logger.Logger) { + up.logger = log +} + +func (up *Upstream) IsMatch(domain string) bool { + return utils.HasMatchedRule(up.matchSplited, domain) +} + +func (up *Upstream) Validate() error { + if !up.IsPrimary && up.protocol == "udp" { + return errors.New("非 primary 只能使用 tcp(-tls)/https:" + up.Address) + } + if up.IsPrimary && up.UseSocks { + return errors.New("primary 无需接入 socks:" + up.Address) + } + if up.UseSocks && up.config.SocksProxy == "" { + return errors.New("socks 未配置,但是上游已启用:" + up.Address) + } + if up.IsPrimary && up.protocol != "udp" { + up.logger.Println("[WARN] Primary 建议使用 udp 加速获取结果:" + up.Address) + } + return nil +} + +func (up *Upstream) conntionFactory(network, address string) (net.Conn, error) { + up.logger.Printf("connecting to %s://%s", network, address) + + host, port, err := net.SplitHostPort(address) + if err != nil { + return nil, err + } + + if up.bootstrap != nil && net.ParseIP(host) == nil { + ip, err := up.bootstrap(host) + if err != nil { + address = fmt.Sprintf("%s:%s", "0.0.0.0", port) + } else { + address = fmt.Sprintf("%s:%s", ip.String(), port) + } + } + + if up.UseSocks { + d, _, err := up.config.GetDialerContext(&net.Dialer{ + Timeout: time.Second * time.Duration(up.config.Timeout), + }) + if err != nil { + return nil, err + } + switch network { + case "tcp": + return d.Dial(network, address) + case "tcp-tls": + conn, err := d.Dial("tcp", address) + if err != nil { + return nil, err + } + return tls.Client(conn, &tls.Config{ + ServerName: host, + }), nil + } + } else { + var d net.Dialer + d.Timeout = time.Second * time.Duration(up.config.Timeout) + switch network { + case "tcp": + return d.Dial(network, address) + case "tcp-tls": + return tls.DialWithDialer(&d, "tcp", address, &tls.Config{ + ServerName: host, + }) + } + } + + panic("wrong protocol: " + network) +} + +func (up *Upstream) InitConnectionPool(bootstrap func(host string) (net.IP, error)) { + up.bootstrap = bootstrap + + if strings.Contains(up.protocol, "http") { + ops := []doh.ClientOption{ + doh.WithServer(up.Address), + doh.WithBootstrap(bootstrap), + doh.WithTimeout(time.Second * time.Duration(up.config.Timeout)), + doh.WithLogger(up.logger), + } + if up.UseSocks { + ops = append(ops, doh.WithSocksProxy(up.config.GetDialerContext)) + } + up.dohClient = doh.NewClient(ops...) + } + + // 只需要启用 tcp/tcp-tls 协议的连接池 + if strings.Contains(up.protocol, "tcp") { + maxIdleTime := time.Second * time.Duration(up.config.Timeout*10) + timeout := time.Second * time.Duration(up.config.Timeout) + p := net2.NewSimpleConnectionPool(net2.ConnectionOptions{ + MaxActiveConnections: int32(up.config.MaxActiveConnections), + MaxIdleConnections: uint32(up.config.MaxIdleConnections), + MaxIdleTime: &maxIdleTime, + DialMaxConcurrency: 20, + ReadTimeout: timeout, + WriteTimeout: timeout, + Dial: func(network, address string) (net.Conn, error) { + dialer, err := up.conntionFactory(network, address) + if err != nil { + return nil, err + } + dialer.SetDeadline(time.Now().Add(timeout)) + return dialer, nil + }, + }) + p.Register(up.protocol, up.hostAndPort) + up.pool = p + } +} + +func (up *Upstream) IsValidMsg(r *dns.Msg) bool { + domain := GetDomainNameFromDnsMsg(r) + inBlacklist := utils.HasMatchedRule(up.config.BlacklistSplited, domain) + for i := 0; i < len(r.Answer); i++ { + var ip net.IP + typeA, ok := r.Answer[i].(*dns.A) + if ok { + ip = typeA.A + } else { + typeAAAA, ok := r.Answer[i].(*dns.AAAA) + if !ok { + continue + } + ip = typeAAAA.AAAA + } + isPrimary, err := up.ipRanger.Contains(ip) + if err != nil { + up.logger.Printf("ipRanger query ip %s failed: %s", ip, err) + continue + } + + up.logger.Printf("checkPrimary result %s: %s@%s ->domain.inBlacklist:%v ip.IsPrimary:%v up.IsPrimary:%v", up.Address, domain, ip, inBlacklist, isPrimary, up.IsPrimary) + + // 黑名单中的域名,如果是 primary 即不可用 + if inBlacklist && isPrimary { + return false + } + // 如果是 server 是 primary,但是 ip 不是 primary,也不可用 + if up.IsPrimary && !isPrimary { + return false + } + } + return !up.IsPrimary || len(r.Answer) > 0 +} + +func GetDomainNameFromDnsMsg(msg *dns.Msg) string { + if msg == nil || len(msg.Question) == 0 { + return "" + } + return msg.Question[0].Name +} + +func (up *Upstream) poolLen() int32 { + if up.pool == nil { + return 0 + } + return up.pool.NumActive() +} + +func (up *Upstream) Exchange(req *dns.Msg) (*dns.Msg, time.Duration, error) { + up.logger.Printf("tracing exchange %s worker_count: %d pool_count: %d go_routine: %d --> %s", up.Address, up.count.Inc(), up.poolLen(), runtime.NumGoroutine(), "enter") + defer up.logger.Printf("tracing exchange %s worker_count: %d pool_count: %d go_routine: %d --> %s", up.Address, up.count.Dec(), up.poolLen(), runtime.NumGoroutine(), "exit") + + var resp *dns.Msg + var duration time.Duration + var err error + + switch up.protocol { + case "https", "http": + resp, duration, err = up.dohClient.Exchange(req) + case "udp": + client := new(dns.Client) + client.Timeout = time.Second * time.Duration(up.config.Timeout) + resp, duration, err = client.Exchange(req, up.hostAndPort) + case "tcp", "tcp-tls": + conn, errGetConn := up.pool.Get(up.protocol, up.hostAndPort) + if errGetConn != nil { + return nil, 0, errGetConn + } + resp, err = dnsExchangeWithConn(conn, req) + default: + panic(fmt.Sprintf("invalid upstream protocol: %s in address %s", up.protocol, up.Address)) + } + + // 清理 EDNS 信息 + if resp != nil && len(resp.Extra) > 0 { + var newExtra []dns.RR + for i := 0; i < len(resp.Extra); i++ { + if resp.Extra[i].Header().Rrtype == dns.TypeOPT { + continue + } + newExtra = append(newExtra, resp.Extra[i]) + } + resp.Extra = newExtra + } + + return resp, duration, err +} + +func dnsExchangeWithConn(conn net2.ManagedConn, req *dns.Msg) (*dns.Msg, error) { + var resp *dns.Msg + co := dns.Conn{Conn: conn} + err := co.WriteMsg(req) + if err == nil { + resp, err = co.ReadMsg() + } + if err == nil { + conn.ReleaseConnection() + } else { + conn.DiscardConnection() + } + return resp, err +} diff --git a/internal/model/upstream_test.go b/internal/model/upstream_test.go new file mode 100644 index 0000000..87f48e8 --- /dev/null +++ b/internal/model/upstream_test.go @@ -0,0 +1,120 @@ +package model + +import ( + "index/suffixarray" + "strings" + "testing" + + "godns/pkg/utils" +) + +var primaryLocations = []string{"中国", "省", "市", "自治区"} +var nonPrimaryLocations = []string{"台湾", "香港", "澳门"} + +var primaryLocationsBytes = [][]byte{[]byte("中国"), []byte("省"), []byte("市"), []byte("自治区")} +var nonPrimaryLocationsBytes = [][]byte{[]byte("台湾"), []byte("香港"), []byte("澳门")} + +func BenchmarkCheckPrimary(b *testing.B) { + for i := 0; i < b.N; i++ { + checkPrimary("哈哈") + } +} + +func BenchmarkCheckPrimaryStringsContains(b *testing.B) { + for i := 0; i < b.N; i++ { + checkPrimaryStringsContains("哈哈") + } +} + +func TestIsMatch(t *testing.T) { + var up Upstream + up.matchSplited = utils.ParseRules([]string{"."}) + checkUpstreamMatch(&up, map[string]bool{ + "": false, + "a.com.": true, + "b.a.com.": true, + ".b.a.com.cn.": true, + "b.a.com.cn.": true, + "d.b.a.com.": true, + }, t) + + up.matchSplited = utils.ParseRules([]string{""}) + checkUpstreamMatch(&up, map[string]bool{ + "": false, + "a.com.": false, + "b.a.com.": false, + ".b.a.com.cn.": false, + "b.a.com.cn.": false, + "d.b.a.com.": false, + }, t) + + up.matchSplited = utils.ParseRules([]string{"a.com."}) + checkUpstreamMatch(&up, map[string]bool{ + "": false, + "a.com.": true, + "b.a.com.": false, + ".b.a.com.cn.": false, + "b.a.com.cn.": false, + "d.b.a.com.": false, + }, t) + + up.matchSplited = utils.ParseRules([]string{".a.com."}) + checkUpstreamMatch(&up, map[string]bool{ + "": false, + "a.com.": false, + "b.a.com.": true, + ".b.a.com.cn.": false, + "b.a.com.cn.": false, + "d.b.a.com.": true, + }, t) + + up.matchSplited = utils.ParseRules([]string{"b.d.com."}) + checkUpstreamMatch(&up, map[string]bool{ + "": false, + "a.com.": false, + ".a.com.": false, + "b.d.com.": true, + ".b.d.com.cn.": false, + "b.d.com.cn.": false, + ".c.d.com.": false, + "b.d.a.com.": false, + }, t) +} + +func checkUpstreamMatch(up *Upstream, cases map[string]bool, t *testing.T) { + for k, v := range cases { + isMatch := up.IsMatch(k) + if isMatch != v { + t.Errorf("Upstream(%s).IsMatch(%s) = %v, want %v", up.matchSplited, k, isMatch, v) + } + } +} + +func checkPrimary(str string) bool { + index := suffixarray.New([]byte(str)) + for i := 0; i < len(nonPrimaryLocationsBytes); i++ { + if len(index.Lookup(nonPrimaryLocationsBytes[i], 1)) > 0 { + return false + } + } + for i := 0; i < len(primaryLocationsBytes); i++ { + if len(index.Lookup(primaryLocationsBytes[i], 1)) > 0 { + return true + } + } + return false +} + +func checkPrimaryStringsContains(str string) bool { + for i := 0; i < len(nonPrimaryLocations); i++ { + if strings.Contains(str, nonPrimaryLocations[i]) { + return false + } + } + for i := 0; i < len(primaryLocations); i++ { + if strings.Contains(str, primaryLocations[i]) { + return true + } + } + return false +} diff --git a/internal/stats/stats.go b/internal/stats/stats.go new file mode 100644 index 0000000..4e7701a --- /dev/null +++ b/internal/stats/stats.go @@ -0,0 +1,649 @@ +package stats + +import ( + "encoding/json" + "os" + "path/filepath" + "runtime" + "sort" + "sync" + "sync/atomic" + "time" +) + +// StatsRecorder 定义统计接口 +type StatsRecorder interface { + RecordQuery() + RecordDoHQuery() + RecordCacheHit() + RecordCacheMiss() + RecordFailed() + RecordUpstreamQuery(address string, isError bool) + RecordClientQuery(clientIP, domain string) + GetSnapshot() StatsSnapshot + Reset() + Save(dataPath string) error + Load(dataPath string) error +} + +// Stats DNS服务器统计信息 +type Stats struct { + StartTime time.Time // 应用启动时间(不持久化) + StatsStartTime time.Time // 统计数据开始时间(可持久化) + + // 查询统计 + TotalQueries atomic.Uint64 + DoHQueries atomic.Uint64 + CacheHits atomic.Uint64 + CacheMisses atomic.Uint64 + FailedQueries atomic.Uint64 + + // 上游服务器统计 + upstreamStats map[string]*UpstreamStats + mu sync.RWMutex + + // Top N 统计 + topClients *TopNTracker // 客户端 IP Top N + topDomains *TopNTracker // 查询域名 Top N +} + +// UpstreamStats 上游服务器统计 +type UpstreamStats struct { + Address string + TotalQueries atomic.Uint64 + Errors atomic.Uint64 + LastUsed time.Time + mu sync.RWMutex +} + +// NewStats 创建统计实例 +func NewStats() *Stats { + now := time.Now() + return &Stats{ + StartTime: now, + StatsStartTime: now, + upstreamStats: make(map[string]*UpstreamStats), + topClients: NewTopNTracker(100), // 最多保留 100 个客户端 IP + topDomains: NewTopNTracker(200), // 最多保留 200 个域名 + } +} + +// RecordQuery 记录DNS查询 +func (s *Stats) RecordQuery() { + s.TotalQueries.Add(1) +} + +// RecordDoHQuery 记录DoH查询 +func (s *Stats) RecordDoHQuery() { + s.DoHQueries.Add(1) +} + +// RecordCacheHit 记录缓存命中 +func (s *Stats) RecordCacheHit() { + s.CacheHits.Add(1) +} + +// RecordCacheMiss 记录缓存未命中 +func (s *Stats) RecordCacheMiss() { + s.CacheMisses.Add(1) +} + +// RecordFailed 记录查询失败 +func (s *Stats) RecordFailed() { + s.FailedQueries.Add(1) +} + +// RecordUpstreamQuery 记录上游服务器查询 +func (s *Stats) RecordUpstreamQuery(address string, isError bool) { + // 先尝试读锁快速查找 + s.mu.RLock() + us, ok := s.upstreamStats[address] + s.mu.RUnlock() + + // 如果不存在才使用写锁创建 + if !ok { + s.mu.Lock() + // 双重检查,防止并发创建 + us, ok = s.upstreamStats[address] + if !ok { + us = &UpstreamStats{ + Address: address, + } + s.upstreamStats[address] = us + } + s.mu.Unlock() + } + + us.TotalQueries.Add(1) + if isError { + us.Errors.Add(1) + } + us.mu.Lock() + us.LastUsed = time.Now() + us.mu.Unlock() +} + +// RecordClientQuery 记录客户端查询(IP 和域名) +func (s *Stats) RecordClientQuery(clientIP, domain string) { + if clientIP != "" { + s.topClients.Record(clientIP, "") + } + if domain != "" { + s.topDomains.Record(domain, clientIP) + } +} + +// Reset 重置统计数据 +func (s *Stats) Reset() { + s.mu.Lock() + defer s.mu.Unlock() + + // 重置统计开始时间 + s.StatsStartTime = time.Now() + + // 重置查询统计 + s.TotalQueries.Store(0) + s.DoHQueries.Store(0) + s.CacheHits.Store(0) + s.CacheMisses.Store(0) + s.FailedQueries.Store(0) + + // 重置上游服务器统计 + s.upstreamStats = make(map[string]*UpstreamStats) + + // 重置 Top N 统计 + s.topClients = NewTopNTracker(100) + s.topDomains = NewTopNTracker(200) +} + +// RuntimeStats 运行时统计信息 +type RuntimeStats struct { + Uptime int64 `json:"uptime"` // 运行时间(秒) + UptimeStr string `json:"uptime_str"` // 运行时间(可读格式) + StatsDuration int64 `json:"stats_duration"` // 统计时长(秒) + StatsDurationStr string `json:"stats_duration_str"` // 统计时长(可读格式) + Goroutines int `json:"goroutines"` // Goroutine数量 + MemAllocMB uint64 `json:"mem_alloc_mb"` // 已分配内存(MB) + MemTotalMB uint64 `json:"mem_total_mb"` // 总分配内存(MB) + MemSysMB uint64 `json:"mem_sys_mb"` // 系统内存(MB) + NumGC uint32 `json:"num_gc"` // GC次数 +} + +// QueryStats 查询统计信息 +type QueryStats struct { + Total uint64 `json:"total"` // 总查询数 + DoH uint64 `json:"doh"` // DoH查询数 + CacheHits uint64 `json:"cache_hits"` // 缓存命中数 + CacheMisses uint64 `json:"cache_misses"` // 缓存未命中数 + Failed uint64 `json:"failed"` // 失败查询数 + HitRate float64 `json:"hit_rate"` // 缓存命中率 +} + +// UpstreamStatsJSON 上游服务器统计(JSON格式) +type UpstreamStatsJSON struct { + Address string `json:"address"` // 服务器地址 + TotalQueries uint64 `json:"total_queries"` // 总查询数 + Errors uint64 `json:"errors"` // 错误数 + ErrorRate float64 `json:"error_rate"` // 错误率 + LastUsed string `json:"last_used"` // 最后使用时间 +} + +// TopNItemJSON Top N 项目(JSON格式) +type TopNItemJSON struct { + Key string `json:"key"` // IP 地址或域名 + Count uint64 `json:"count"` // 查询次数 + TopClient string `json:"top_client,omitempty"` // 查询最多的客户端 IP(仅域名统计有) +} + +// StatsSnapshot 完整统计快照 +type StatsSnapshot struct { + Runtime RuntimeStats `json:"runtime"` // 运行时信息 + Queries QueryStats `json:"queries"` // 查询统计 + Upstreams []UpstreamStatsJSON `json:"upstreams"` // 上游服务器统计 + TopClients []TopNItemJSON `json:"top_clients"` // Top 客户端 IP + TopDomains []TopNItemJSON `json:"top_domains"` // Top 查询域名 +} + +// GetSnapshot 获取统计快照 +func (s *Stats) GetSnapshot() StatsSnapshot { + // 运行时信息 + var m runtime.MemStats + runtime.ReadMemStats(&m) + + uptime := time.Since(s.StartTime) + uptimeStr := formatDuration(uptime) + + statsDuration := time.Since(s.StatsStartTime) + statsDurationStr := formatDuration(statsDuration) + + runtimeStats := RuntimeStats{ + Uptime: int64(uptime.Seconds()), + UptimeStr: uptimeStr, + StatsDuration: int64(statsDuration.Seconds()), + StatsDurationStr: statsDurationStr, + Goroutines: runtime.NumGoroutine(), + MemAllocMB: m.Alloc / 1024 / 1024, + MemTotalMB: m.TotalAlloc / 1024 / 1024, + MemSysMB: m.Sys / 1024 / 1024, + NumGC: m.NumGC, + } + + // 查询统计 + total := s.TotalQueries.Load() + hits := s.CacheHits.Load() + misses := s.CacheMisses.Load() + failed := s.FailedQueries.Load() + + var hitRate float64 + if total > 0 { + hitRate = float64(hits) / float64(total) * 100 + } + + queryStats := QueryStats{ + Total: total, + DoH: s.DoHQueries.Load(), + CacheHits: hits, + CacheMisses: misses, + Failed: failed, + HitRate: hitRate, + } + + // 上游服务器统计 + s.mu.RLock() + upstreams := make([]UpstreamStatsJSON, 0, len(s.upstreamStats)) + for _, us := range s.upstreamStats { + queries := us.TotalQueries.Load() + errors := us.Errors.Load() + var errorRate float64 + if queries > 0 { + errorRate = float64(errors) / float64(queries) * 100 + } + + us.mu.RLock() + lastUsed := us.LastUsed.Format("2006-01-02 15:04:05") + if us.LastUsed.IsZero() { + lastUsed = "Never" + } + us.mu.RUnlock() + + upstreams = append(upstreams, UpstreamStatsJSON{ + Address: us.Address, + TotalQueries: queries, + Errors: errors, + ErrorRate: errorRate, + LastUsed: lastUsed, + }) + } + s.mu.RUnlock() + + // 按服务器地址字符串排序 + sort.Slice(upstreams, func(i, j int) bool { + return upstreams[i].Address < upstreams[j].Address + }) + + // Top N 客户端 IP + topClients := make([]TopNItemJSON, 0) + for _, item := range s.topClients.GetTopN(20) { // 返回 Top 20 + topClients = append(topClients, TopNItemJSON{ + Key: item.Key, + Count: item.Count, + }) + } + + // Top N 查询域名 + topDomains := make([]TopNItemJSON, 0) + for _, item := range s.topDomains.GetTopN(20) { // 返回 Top 20 + topDomains = append(topDomains, TopNItemJSON{ + Key: item.Key, + Count: item.Count, + TopClient: item.TopClient, + }) + } + + return StatsSnapshot{ + Runtime: runtimeStats, + Queries: queryStats, + Upstreams: upstreams, + TopClients: topClients, + TopDomains: topDomains, + } +} + +// formatDuration 格式化时长为可读格式 +func formatDuration(d time.Duration) string { + days := int(d.Hours()) / 24 + hours := int(d.Hours()) % 24 + minutes := int(d.Minutes()) % 60 + seconds := int(d.Seconds()) % 60 + + if days > 0 { + return formatString("%d天%d小时%d分钟", days, hours, minutes) + } else if hours > 0 { + return formatString("%d小时%d分钟%d秒", hours, minutes, seconds) + } else if minutes > 0 { + return formatString("%d分钟%d秒", minutes, seconds) + } + return formatString("%d秒", seconds) +} + +// formatString 简单的字符串格式化 +func formatString(format string, args ...interface{}) string { + result := format + for _, arg := range args { + switch v := arg.(type) { + case int: + result = replaceFirst(result, "%d", itoa(v)) + } + } + return result +} + +// replaceFirst 替换第一个匹配的字符串 +func replaceFirst(s, old, new string) string { + for i := 0; i <= len(s)-len(old); i++ { + if s[i:i+len(old)] == old { + return s[:i] + new + s[i+len(old):] + } + } + return s +} + +// itoa 整数转字符串 +func itoa(i int) string { + if i == 0 { + return "0" + } + negative := i < 0 + if negative { + i = -i + } + var buf [32]byte + pos := len(buf) + for i > 0 { + pos-- + buf[pos] = byte('0' + i%10) + i /= 10 + } + if negative { + pos-- + buf[pos] = '-' + } + return string(buf[pos:]) +} + +// TopNTracker 追踪 Top N 项目,内存可控 +type TopNTracker struct { + mu sync.RWMutex + items map[string]*TopNItem + maxItems int // 最大保留项目数 +} + +// TopNItem Top N 项目统计 +type TopNItem struct { + Key string + Count uint64 + TopClient string // 对于域名统计,记录查询最多的客户端 IP + clients map[string]uint64 // 临时记录客户端分布(仅用于找 Top1) +} + +// PersistentStats 持久化统计数据结构 +type PersistentStats struct { + StatsStartTime time.Time `json:"stats_start_time"` // 统计开始时间(可持久化) + TotalQueries uint64 `json:"total_queries"` + DoHQueries uint64 `json:"doh_queries"` + CacheHits uint64 `json:"cache_hits"` + CacheMisses uint64 `json:"cache_misses"` + FailedQueries uint64 `json:"failed_queries"` + Upstreams map[string]*PersistentUpstream `json:"upstreams"` + TopClients []PersistentTopNItem `json:"top_clients"` + TopDomains []PersistentTopNItem `json:"top_domains"` +} + +// PersistentUpstream 持久化上游服务器统计 +type PersistentUpstream struct { + Address string `json:"address"` + TotalQueries uint64 `json:"total_queries"` + Errors uint64 `json:"errors"` + LastUsed time.Time `json:"last_used"` +} + +// PersistentTopNItem 持久化 Top N 项目 +type PersistentTopNItem struct { + Key string `json:"key"` + Count uint64 `json:"count"` + TopClient string `json:"top_client,omitempty"` + Clients map[string]uint64 `json:"clients,omitempty"` +} + +// NewTopNTracker 创建 Top N 追踪器 +func NewTopNTracker(maxItems int) *TopNTracker { + return &TopNTracker{ + items: make(map[string]*TopNItem), + maxItems: maxItems, + } +} + +// Record 记录一次访问(可选关联的客户端 IP) +func (t *TopNTracker) Record(key, associatedClient string) { + t.mu.Lock() + defer t.mu.Unlock() + + item, exists := t.items[key] + if !exists { + // 如果超过最大数量,删除计数最少的项 + if len(t.items) >= t.maxItems { + t.evictLowest() + } + item = &TopNItem{ + Key: key, + clients: make(map[string]uint64), + } + t.items[key] = item + } + + item.Count++ + + // 如果有关联客户端,记录客户端分布 + if associatedClient != "" { + item.clients[associatedClient]++ + // 更新 Top1 客户端 + if item.clients[associatedClient] > item.clients[item.TopClient] { + item.TopClient = associatedClient + } + } +} + +// evictLowest 删除计数最少的项(不加锁,由调用者加锁) +func (t *TopNTracker) evictLowest() { + var minKey string + var minCount uint64 = ^uint64(0) // 最大值 + + for key, item := range t.items { + if item.Count < minCount { + minCount = item.Count + minKey = key + } + } + + if minKey != "" { + delete(t.items, minKey) + } +} + +// GetTopN 获取 Top N 列表 +func (t *TopNTracker) GetTopN(n int) []TopNItem { + t.mu.RLock() + defer t.mu.RUnlock() + + // 复制所有项 + items := make([]TopNItem, 0, len(t.items)) + for _, item := range t.items { + items = append(items, TopNItem{ + Key: item.Key, + Count: item.Count, + TopClient: item.TopClient, + }) + } + + // 按查询次数降序排序 + sort.Slice(items, func(i, j int) bool { + return items[i].Count > items[j].Count + }) + + // 返回前 N 项 + if n > len(items) { + n = len(items) + } + return items[:n] +} + +// Save 保存统计数据到 JSON 文件 +func (s *Stats) Save(dataPath string) error { + s.mu.RLock() + defer s.mu.RUnlock() + + // 准备持久化数据 + persistent := PersistentStats{ + StatsStartTime: s.StatsStartTime, + TotalQueries: s.TotalQueries.Load(), + DoHQueries: s.DoHQueries.Load(), + CacheHits: s.CacheHits.Load(), + CacheMisses: s.CacheMisses.Load(), + FailedQueries: s.FailedQueries.Load(), + Upstreams: make(map[string]*PersistentUpstream), + TopClients: make([]PersistentTopNItem, 0), + TopDomains: make([]PersistentTopNItem, 0), + } + + // 保存上游服务器统计 + for addr, us := range s.upstreamStats { + us.mu.RLock() + persistent.Upstreams[addr] = &PersistentUpstream{ + Address: us.Address, + TotalQueries: us.TotalQueries.Load(), + Errors: us.Errors.Load(), + LastUsed: us.LastUsed, + } + us.mu.RUnlock() + } + + // 保存 Top 客户端 + s.topClients.mu.RLock() + for _, item := range s.topClients.items { + persistent.TopClients = append(persistent.TopClients, PersistentTopNItem{ + Key: item.Key, + Count: item.Count, + TopClient: item.TopClient, + Clients: item.clients, + }) + } + s.topClients.mu.RUnlock() + + // 保存 Top 域名 + s.topDomains.mu.RLock() + for _, item := range s.topDomains.items { + persistent.TopDomains = append(persistent.TopDomains, PersistentTopNItem{ + Key: item.Key, + Count: item.Count, + TopClient: item.TopClient, + Clients: item.clients, + }) + } + s.topDomains.mu.RUnlock() + + // 序列化为 JSON + data, err := json.MarshalIndent(persistent, "", " ") + if err != nil { + return err + } + + // 确保目录存在 + statsPath := filepath.Join(dataPath, "cache") + if err := os.MkdirAll(statsPath, 0755); err != nil { + return err + } + + // 写入文件 + statsFile := filepath.Join(statsPath, "stats.json") + return os.WriteFile(statsFile, data, 0644) +} + +// Load 从 JSON 文件加载统计数据 +func (s *Stats) Load(dataPath string) error { + statsFile := filepath.Join(dataPath, "cache", "stats.json") + + // 检查文件是否存在 + if _, err := os.Stat(statsFile); os.IsNotExist(err) { + return nil // 文件不存在不是错误,返回 nil + } + + // 读取文件 + data, err := os.ReadFile(statsFile) + if err != nil { + return err + } + + // 解析 JSON + var persistent PersistentStats + if err := json.Unmarshal(data, &persistent); err != nil { + return err + } + + // 恢复统计数据 + s.mu.Lock() + defer s.mu.Unlock() + + // StartTime 保持为应用启动时间,不从磁盘恢复 + // 只恢复 StatsStartTime(统计数据开始时间) + s.StatsStartTime = persistent.StatsStartTime + s.TotalQueries.Store(persistent.TotalQueries) + s.DoHQueries.Store(persistent.DoHQueries) + s.CacheHits.Store(persistent.CacheHits) + s.CacheMisses.Store(persistent.CacheMisses) + s.FailedQueries.Store(persistent.FailedQueries) + + // 恢复上游服务器统计 + for addr, pus := range persistent.Upstreams { + us := &UpstreamStats{ + Address: pus.Address, + LastUsed: pus.LastUsed, + } + us.TotalQueries.Store(pus.TotalQueries) + us.Errors.Store(pus.Errors) + s.upstreamStats[addr] = us + } + + // 恢复 Top 客户端 + s.topClients.mu.Lock() + for _, pitem := range persistent.TopClients { + item := &TopNItem{ + Key: pitem.Key, + Count: pitem.Count, + TopClient: pitem.TopClient, + clients: pitem.Clients, + } + if item.clients == nil { + item.clients = make(map[string]uint64) + } + s.topClients.items[pitem.Key] = item + } + s.topClients.mu.Unlock() + + // 恢复 Top 域名 + s.topDomains.mu.Lock() + for _, pitem := range persistent.TopDomains { + item := &TopNItem{ + Key: pitem.Key, + Count: pitem.Count, + TopClient: pitem.TopClient, + clients: pitem.Clients, + } + if item.clients == nil { + item.clients = make(map[string]uint64) + } + s.topDomains.items[pitem.Key] = item + } + s.topDomains.mu.Unlock() + + return nil +} diff --git a/internal/web/handler.go b/internal/web/handler.go new file mode 100644 index 0000000..23d6cd1 --- /dev/null +++ b/internal/web/handler.go @@ -0,0 +1,205 @@ +package web + +import ( + "crypto/subtle" + "embed" + "encoding/json" + "io/fs" + "net/http" + + "godns/internal/stats" + "godns/pkg/logger" +) + +//go:embed static/* +var staticFiles embed.FS + +// Handler Web服务处理器 +type Handler struct { + stats stats.StatsRecorder + version string + checkUpdateCh chan<- struct{} + logger logger.Logger + username string + password string +} + +// NewHandler 创建Web处理器 +func NewHandler(s stats.StatsRecorder, ver string, checkCh chan<- struct{}, log logger.Logger, username, password string) *Handler { + return &Handler{ + stats: s, + version: ver, + checkUpdateCh: checkCh, + logger: log, + username: username, + password: password, + } +} + +// basicAuth 中间件 +func (h *Handler) basicAuth(next http.HandlerFunc) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + // 如果未配置鉴权,直接放行 + if h.username == "" || h.password == "" { + next(w, r) + return + } + user, pass, ok := r.BasicAuth() + if !ok || subtle.ConstantTimeCompare([]byte(user), []byte(h.username)) != 1 || + subtle.ConstantTimeCompare([]byte(pass), []byte(h.password)) != 1 { + w.Header().Set("WWW-Authenticate", `Basic realm="NBDNS Monitor"`) + http.Error(w, "Unauthorized", http.StatusUnauthorized) + return + } + next(w, r) + } +} + +// RegisterRoutes 注册路由 +func (h *Handler) RegisterRoutes(mux *http.ServeMux) { + // API路由 + mux.HandleFunc("/api/stats", h.basicAuth(h.handleStats)) + mux.HandleFunc("/api/version", h.basicAuth(h.handleVersion)) + mux.HandleFunc("/api/check-update", h.basicAuth(h.handleCheckUpdate)) + mux.HandleFunc("/api/stats/reset", h.basicAuth(h.handleStatsReset)) + + // 静态文件服务 + staticFS, err := fs.Sub(staticFiles, "static") + if err != nil { + h.logger.Printf("Failed to load static files: %v", err) + return + } + mux.Handle("/", h.basicAuth(func(w http.ResponseWriter, r *http.Request) { + http.FileServer(http.FS(staticFS)).ServeHTTP(w, r) + })) +} + +// handleStats 处理统计信息请求 +func (h *Handler) handleStats(w http.ResponseWriter, r *http.Request) { + // 只允许GET请求 + if r.Method != http.MethodGet { + http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) + return + } + + // 获取统计快照 + snapshot := h.stats.GetSnapshot() + + // 设置响应头 + w.Header().Set("Content-Type", "application/json") + w.Header().Set("Access-Control-Allow-Origin", "*") + w.Header().Set("Cache-Control", "no-cache, no-store, must-revalidate") + + // 编码JSON并返回 + if err := json.NewEncoder(w).Encode(snapshot); err != nil { + h.logger.Printf("Error encoding stats JSON: %v", err) + http.Error(w, "Internal server error", http.StatusInternalServerError) + return + } +} + +// ResetResponse 重置响应 +type ResetResponse struct { + Success bool `json:"success"` + Message string `json:"message"` +} + +// handleStatsReset 处理统计数据重置请求 +func (h *Handler) handleStatsReset(w http.ResponseWriter, r *http.Request) { + // 只允许POST请求 + if r.Method != http.MethodPost { + http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) + return + } + + // 重置统计数据 + h.stats.Reset() + h.logger.Printf("Statistics reset by user request") + + // 设置响应头 + w.Header().Set("Content-Type", "application/json") + w.Header().Set("Access-Control-Allow-Origin", "*") + + // 返回成功响应 + if err := json.NewEncoder(w).Encode(ResetResponse{ + Success: true, + Message: "统计数据已重置", + }); err != nil { + h.logger.Printf("Error encoding reset response JSON: %v", err) + http.Error(w, "Internal server error", http.StatusInternalServerError) + return + } +} + +// VersionResponse 版本信息响应 +type VersionResponse struct { + Version string `json:"version"` +} + +// handleVersion 处理版本查询请求 +func (h *Handler) handleVersion(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodGet { + http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) + return + } + + ver := h.version + if ver == "" { + ver = "0.0.0" + } + + w.Header().Set("Content-Type", "application/json") + w.Header().Set("Access-Control-Allow-Origin", "*") + w.Header().Set("Cache-Control", "no-cache, no-store, must-revalidate") + + if err := json.NewEncoder(w).Encode(VersionResponse{Version: ver}); err != nil { + h.logger.Printf("Error encoding version JSON: %v", err) + http.Error(w, "Internal server error", http.StatusInternalServerError) + return + } +} + +// UpdateCheckResponse 更新检查响应 +type UpdateCheckResponse struct { + HasUpdate bool `json:"has_update"` + CurrentVersion string `json:"current_version"` + LatestVersion string `json:"latest_version"` + Message string `json:"message"` +} + +// handleCheckUpdate 处理检查更新请求(生产者2:用户手动触发) +func (h *Handler) handleCheckUpdate(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodGet { + http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) + return + } + + ver := h.version + if ver == "" { + ver = "0.0.0" + } + + w.Header().Set("Content-Type", "application/json") + w.Header().Set("Access-Control-Allow-Origin", "*") + w.Header().Set("Cache-Control", "no-cache, no-store, must-revalidate") + + // 触发后台检查更新(非阻塞) + select { + case h.checkUpdateCh <- struct{}{}: + h.logger.Printf("Update check triggered by user") + json.NewEncoder(w).Encode(UpdateCheckResponse{ + HasUpdate: false, + CurrentVersion: ver, + LatestVersion: ver, + Message: "已触发更新检查,请查看服务器日志", + }) + default: + // 如果通道已满,说明已经在检查中 + json.NewEncoder(w).Encode(UpdateCheckResponse{ + HasUpdate: false, + CurrentVersion: ver, + LatestVersion: ver, + Message: "更新检查正在进行中", + }) + } +} diff --git a/internal/web/static/app.js b/internal/web/static/app.js new file mode 100644 index 0000000..5c44475 --- /dev/null +++ b/internal/web/static/app.js @@ -0,0 +1,307 @@ +// 自动刷新间隔(毫秒) +const REFRESH_INTERVAL = 3000; +let refreshTimer = null; +let countdownTimer = null; +let countdown = 0; +let isCheckingUpdate = false; +let isResettingStats = false; + +// 格式化数字,添加千位分隔符 +function formatNumber(num) { + return num.toString().replace(/\B(?=(\d{3})+(?!\d))/g, ","); +} + +// 格式化百分比 +function formatPercent(num) { + return num.toFixed(2) + '%'; +} + +// 更新运行时信息 +function updateRuntimeStats(runtime) { + document.getElementById('uptime').textContent = runtime.uptime_str || '-'; + document.getElementById('goroutines').textContent = formatNumber(runtime.goroutines || 0); + document.getElementById('mem-alloc').textContent = formatNumber(runtime.mem_alloc_mb || 0) + ' MB'; + document.getElementById('mem-sys').textContent = formatNumber(runtime.mem_sys_mb || 0) + ' MB'; + document.getElementById('mem-total').textContent = formatNumber(runtime.mem_total_mb || 0) + ' MB'; + document.getElementById('num-gc').textContent = formatNumber(runtime.num_gc || 0); + + // 更新统计时长 + const statsDuration = runtime.stats_duration_str || '-'; + document.getElementById('stats-duration').textContent = '统计时长: ' + statsDuration; +} + +// 更新查询统计 +function updateQueryStats(queries) { + document.getElementById('total-queries').textContent = formatNumber(queries.total || 0); + document.getElementById('doh-queries').textContent = formatNumber(queries.doh || 0); + document.getElementById('cache-hits').textContent = formatNumber(queries.cache_hits || 0); + document.getElementById('cache-misses').textContent = formatNumber(queries.cache_misses || 0); + document.getElementById('failed-queries').textContent = formatNumber(queries.failed || 0); + document.getElementById('hit-rate').textContent = formatPercent(queries.hit_rate || 0); +} + +// 更新上游服务器表格 +function updateUpstreamTable(upstreams) { + const tbody = document.getElementById('upstream-tbody'); + + if (!upstreams || upstreams.length === 0) { + tbody.innerHTML = '暂无数据'; + return; + } + + let html = ''; + upstreams.forEach(upstream => { + const errorClass = upstream.error_rate > 10 ? 'error-high' : ''; + html += ` + + ${upstream.address || '-'} + ${formatNumber(upstream.total_queries || 0)} + ${formatNumber(upstream.errors || 0)} + ${formatPercent(upstream.error_rate || 0)} + ${upstream.last_used || 'Never'} + + `; + }); + tbody.innerHTML = html; +} + +// 更新 Top 客户端 IP 表格 +function updateTopClientsTable(topClients) { + const tbody = document.getElementById('top-clients-tbody'); + + if (!topClients || topClients.length === 0) { + tbody.innerHTML = '暂无数据'; + return; + } + + let html = ''; + topClients.forEach((client, index) => { + const rankClass = index < 3 ? `rank-${index + 1}` : ''; + html += ` + + ${index + 1} + ${client.key || '-'} + ${formatNumber(client.count || 0)} + + `; + }); + tbody.innerHTML = html; +} + +// 更新 Top 查询域名表格 +function updateTopDomainsTable(topDomains) { + const tbody = document.getElementById('top-domains-tbody'); + + if (!topDomains || topDomains.length === 0) { + tbody.innerHTML = '暂无数据'; + return; + } + + let html = ''; + topDomains.forEach((domain, index) => { + const rankClass = index < 3 ? `rank-${index + 1}` : ''; + const topClient = domain.top_client || '-'; + html += ` + + ${index + 1} + ${domain.key || '-'} + ${formatNumber(domain.count || 0)} + ${topClient} + + `; + }); + tbody.innerHTML = html; +} + +// 更新倒计时显示 +function updateCountdown() { + countdown--; + if (countdown <= 0) { + countdown = 0; + } + document.getElementById('last-update').textContent = `下次刷新: ${countdown}秒`; +} + +// 重置倒计时 +function resetCountdown() { + countdown = REFRESH_INTERVAL / 1000; + if (countdownTimer) { + clearInterval(countdownTimer); + } + countdownTimer = setInterval(updateCountdown, 1000); + updateCountdown(); +} + +// 加载统计数据 +async function loadStats() { + try { + const response = await fetch('/api/stats'); + if (!response.ok) { + throw new Error('获取统计数据失败'); + } + + const data = await response.json(); + + // 更新各部分数据 + updateRuntimeStats(data.runtime); + updateQueryStats(data.queries); + updateUpstreamTable(data.upstreams); + updateTopClientsTable(data.top_clients); + updateTopDomainsTable(data.top_domains); + + // 重置倒计时 + resetCountdown(); + + } catch (error) { + console.error('加载统计数据出错:', error); + document.getElementById('last-update').textContent = '加载失败'; + } +} + +// 启动自动刷新 +function startAutoRefresh() { + if (refreshTimer) { + clearInterval(refreshTimer); + } + refreshTimer = setInterval(loadStats, REFRESH_INTERVAL); +} + +// 停止自动刷新 +function stopAutoRefresh() { + if (refreshTimer) { + clearInterval(refreshTimer); + refreshTimer = null; + } + if (countdownTimer) { + clearInterval(countdownTimer); + countdownTimer = null; + } +} + +// 加载版本号 +async function loadVersion() { + try { + const response = await fetch('/api/version'); + if (!response.ok) { + throw new Error('获取版本号失败'); + } + const data = await response.json(); + document.getElementById('version-display').textContent = 'v' + data.version; + } catch (error) { + console.error('加载版本号出错:', error); + document.getElementById('version-display').textContent = 'v0.0.0'; + } +} + +// 检查更新 +async function checkUpdate() { + if (isCheckingUpdate) { + return; + } + + const btn = document.getElementById('check-update-btn'); + const originalText = btn.textContent; + + try { + isCheckingUpdate = true; + btn.textContent = '⏳'; + btn.disabled = true; + + const response = await fetch('/api/check-update'); + if (!response.ok) { + throw new Error('检查更新失败'); + } + + const data = await response.json(); + + if (data.has_update) { + alert(`${data.message}\n当前版本: v${data.current_version}\n最新版本: v${data.latest_version}\n\n请访问 GitHub 下载最新版本`); + } else { + alert(`${data.message}\n当前版本: v${data.current_version}`); + } + } catch (error) { + console.error('检查更新出错:', error); + alert('检查更新失败,请稍后再试'); + } finally { + isCheckingUpdate = false; + btn.textContent = originalText; + btn.disabled = false; + } +} + +// 重置统计数据 +async function resetStats() { + if (isResettingStats) { + return; + } + + // 确认对话框 + if (!confirm('确定要重置所有统计数据吗?此操作无法撤销。')) { + return; + } + + const btn = document.getElementById('reset-stats-btn'); + const originalText = btn.textContent; + + try { + isResettingStats = true; + btn.textContent = '⏳ 重置中...'; + btn.disabled = true; + + const response = await fetch('/api/stats/reset', { + method: 'POST' + }); + + if (!response.ok) { + throw new Error('重置统计数据失败'); + } + + const data = await response.json(); + + if (data.success) { + alert(data.message || '统计数据已重置'); + // 立即刷新数据 + await loadStats(); + } else { + alert('重置失败: ' + (data.message || '未知错误')); + } + } catch (error) { + console.error('重置统计数据出错:', error); + alert('重置统计数据失败,请稍后再试'); + } finally { + isResettingStats = false; + btn.textContent = originalText; + btn.disabled = false; + } +} + +// 页面加载完成后初始化 +document.addEventListener('DOMContentLoaded', function() { + // 立即加载一次数据 + loadStats(); + loadVersion(); + + // 启动自动刷新 + startAutoRefresh(); + + // 绑定检查更新按钮 + document.getElementById('check-update-btn').addEventListener('click', checkUpdate); + + // 绑定重置统计按钮 + document.getElementById('reset-stats-btn').addEventListener('click', resetStats); + + // 页面可见性变化时控制刷新 + document.addEventListener('visibilitychange', function() { + if (document.hidden) { + stopAutoRefresh(); + } else { + loadStats(); + startAutoRefresh(); + } + }); +}); + +// 页面卸载时停止刷新 +window.addEventListener('beforeunload', function() { + stopAutoRefresh(); +}); diff --git a/internal/web/static/index.html b/internal/web/static/index.html new file mode 100644 index 0000000..9228bab --- /dev/null +++ b/internal/web/static/index.html @@ -0,0 +1,172 @@ + + + + + + + GoDNS 监控面板 + + + + + +
+
+

GoDNS 监控面板

+
+ 正在加载... +
+
+ +
+ +
+

运行时信息

+
+
+ 运行时长 + - +
+
+ Goroutines + - +
+
+ 已分配内存 + - +
+
+ 系统内存 + - +
+
+ 总分配内存 + - +
+
+ GC 次数 + - +
+
+
+ + +
+
+

DNS 查询统计

+
+ 统计时长: - + +
+
+
+
+ 总查询数 + - +
+
+ DoH 请求 + - +
+
+ 缓存命中 + - +
+
+ 缓存未命中 + - +
+
+ 失败查询 + - +
+
+ 缓存命中率 + - +
+
+
+ + +
+

上游服务器统计

+
+ + + + + + + + + + + + + + + +
服务器地址总查询数错误数错误率最后使用
暂无数据
+
+
+ + +
+

Top 客户端 IP

+
+ + + + + + + + + + + + + +
排名IP 地址查询次数
暂无数据
+
+
+ + +
+

Top 查询域名

+
+ + + + + + + + + + + + + + +
排名域名查询次数Top 客户端
暂无数据
+
+
+
+ +
+ +
+
+ + + + + \ No newline at end of file diff --git a/internal/web/static/style.css b/internal/web/static/style.css new file mode 100644 index 0000000..9c93d79 --- /dev/null +++ b/internal/web/static/style.css @@ -0,0 +1,358 @@ +* { + margin: 0; + padding: 0; + box-sizing: border-box; +} + +body { + font-family: -apple-system, BlinkMacSystemFont, "Segoe UI", "Roboto", sans-serif; + background: #0f172a; + min-height: 100vh; + padding: 20px; + color: #f1f5f9; +} + +.container { + max-width: 1600px; + margin: 0 auto; +} + +header { + background: #1e293b; + padding: 24px 32px; + border-radius: 12px; + border: 1px solid #334155; + margin-bottom: 24px; + display: flex; + justify-content: space-between; + align-items: center; +} + +h1 { + font-size: 1.875rem; + font-weight: 600; + color: #f1f5f9; +} + +.update-info { + display: flex; + align-items: center; + gap: 12px; +} + +#last-update { + color: #94a3b8; + font-size: 0.875rem; +} + +.dashboard { + display: grid; + grid-template-columns: repeat(auto-fit, minmax(min(100%, 600px), 1fr)); + gap: 16px; +} + +.card { + background: #1e293b; + padding: 24px; + border-radius: 12px; + border: 1px solid #334155; +} + +.card.full-width { + grid-column: 1 / -1; +} + +h2 { + color: #f1f5f9; + margin-bottom: 20px; + font-size: 1.125rem; + font-weight: 600; + border-bottom: 1px solid #334155; + padding-bottom: 12px; +} + +.card-header { + display: flex; + justify-content: space-between; + align-items: center; + gap: 12px; + flex-wrap: wrap; + margin-bottom: 20px; +} + +.card-header h2 { + margin-bottom: 0; + padding-bottom: 0; + border-bottom: none; + flex: 1; + min-width: 200px; +} + +.stats-controls { + display: flex; + align-items: center; + gap: 12px; +} + +.stats-duration { + color: #94a3b8; + font-size: 0.875rem; +} + +.reset-btn { + background: #f1f5f9; + color: #0f172a; + border: none; + padding: 8px 16px; + border-radius: 6px; + cursor: pointer; + font-size: 0.875rem; + font-weight: 500; + transition: all 0.15s; +} + +.reset-btn:hover:not(:disabled) { + background: #cbd5e1; +} + +.reset-btn:disabled { + cursor: not-allowed; + opacity: 0.5; +} + +.stats-grid { + display: grid; + grid-template-columns: repeat(auto-fit, minmax(min(100%, 200px), 1fr)); + gap: 12px; +} + +.stat-item { + background: #0f172a; + padding: 16px; + border-radius: 8px; + border: 1px solid #334155; + display: flex; + flex-direction: column; + gap: 8px; + transition: border-color 0.15s; +} + +.stat-item:hover { + border-color: #475569; +} + +.stat-item.highlight { + background: #1e293b; + border-color: #3b82f6; +} + +.stat-item.success { + background: #1e293b; + border-color: #22c55e; +} + +.stat-item.info { + background: #1e293b; + border-color: #06b6d4; +} + +.stat-item.warning { + background: #1e293b; + border-color: #f59e0b; +} + +.stat-label { + font-size: 0.875rem; + color: #94a3b8; +} + +.stat-value { + font-size: 1.5rem; + font-weight: 600; + color: #f1f5f9; +} + +.table-container { + overflow-x: auto; +} + +table { + width: 100%; + border-collapse: collapse; + margin-top: 8px; +} + +thead { + background: #0f172a; + border-bottom: 1px solid #334155; +} + +th { + padding: 12px 16px; + text-align: left; + font-weight: 500; + font-size: 0.875rem; + color: #94a3b8; +} + +tbody tr { + border-bottom: 1px solid #334155; + transition: background 0.15s; +} + +tbody tr:hover { + background: #0f172a; +} + +td { + padding: 12px 16px; + font-size: 0.875rem; +} + +.no-data { + text-align: center; + color: #64748b; + padding: 24px; +} + +.error-high { + color: #ef4444; + font-weight: 500; +} + +.rank-cell { + font-weight: 600; + text-align: center; + width: 60px; +} + +#top-clients-table th:nth-child(1), +#top-clients-table td:nth-child(1), +#top-domains-table th:nth-child(1), +#top-domains-table td:nth-child(1) { + width: 60px; +} + +.rank-1 { + background: rgba(234, 179, 8, 0.1); +} + +.rank-1 .rank-cell { + color: #eab308; +} + +.rank-2 { + background: rgba(148, 163, 184, 0.1); +} + +.rank-2 .rank-cell { + color: #94a3b8; +} + +.rank-3 { + background: rgba(251, 146, 60, 0.1); +} + +.rank-3 .rank-cell { + color: #fb923c; +} + +.domain-cell { + max-width: 300px; + overflow: hidden; + text-overflow: ellipsis; + white-space: nowrap; +} + +footer { + text-align: center; + color: #94a3b8; + margin-top: 24px; + padding: 20px; + font-size: 0.875rem; +} + +.footer-content { + display: flex; + justify-content: center; + align-items: center; + gap: 16px; + flex-wrap: wrap; +} + +.version-info { + display: flex; + align-items: center; + gap: 8px; + background: #1e293b; + padding: 6px 12px; + border-radius: 6px; + border: 1px solid #334155; +} + +#version-display { + font-weight: 500; +} + +.update-btn { + background: transparent; + color: #94a3b8; + border: 1px solid #334155; + width: 28px; + height: 28px; + border-radius: 6px; + cursor: pointer; + font-size: 14px; + display: flex; + align-items: center; + justify-content: center; + transition: all 0.15s; +} + +.update-btn:hover:not(:disabled) { + background: #334155; + color: #f1f5f9; +} + +.update-btn:disabled { + cursor: not-allowed; + opacity: 0.5; +} + +@media (max-width: 768px) { + body { + padding: 12px; + } + + header { + flex-direction: column; + gap: 12px; + padding: 16px; + } + + h1 { + font-size: 1.5rem; + } + + .card { + padding: 16px; + } + + .stats-grid { + grid-template-columns: 1fr; + } + + table { + font-size: 0.8rem; + min-width: 500px; + } + + th, td { + padding: 8px; + } + + #upstream-table th:nth-child(5), + #upstream-table td:nth-child(5) { + display: none; + } +} diff --git a/main.go b/main.go new file mode 100644 index 0000000..154dbdf --- /dev/null +++ b/main.go @@ -0,0 +1,276 @@ +package main + +import ( + "errors" + "log" + "math/rand" + "net" + "net/http" + _ "net/http/pprof" + "os" + "os/signal" + "path/filepath" + "strings" + "syscall" + "time" + + "github.com/blang/semver" + "github.com/miekg/dns" + "github.com/rhysd/go-github-selfupdate/selfupdate" + "github.com/yl2chen/cidranger" + + "godns/internal/handler" + "godns/internal/model" + "godns/internal/stats" + "godns/internal/web" + "godns/pkg/doh" + "godns/pkg/logger" +) + +var ( + version string + + config *model.Config + dataPath string +) + +func main() { + dataPath = detectDataPath() + + ipRanger := loadIPRanger(dataPath + "china_ip_list.txt") + + // 先创建一个临时 logger 用于读取配置 + tempLogger := logger.New(false) + + config = &model.Config{} + if err := config.ReadInConfig(dataPath+"/config.json", ipRanger, tempLogger); err != nil { + panic(err) + } + + // 设置默认 Web 监听地址 + if config.WebAddr == "" { + config.WebAddr = "0.0.0.0:8854" + } + + // 根据配置创建正式的 logger 和 stats 实例 + debugLogger := logger.New(config.Debug) + statsRecorder := stats.NewStats() + + // 加载持久化的统计数据 + if err := statsRecorder.Load(dataPath); err != nil { + log.Printf("Failed to load stats from disk: %v", err) + } else { + log.Printf("Stats loaded successfully from disk") + } + + // 更新 upstreams 的 logger 为正式的 logger + for i := 0; i < len(config.Bootstrap); i++ { + config.Bootstrap[i].SetLogger(debugLogger) + } + for i := 0; i < len(config.Upstreams); i++ { + config.Upstreams[i].SetLogger(debugLogger) + } + + // Bootstrap handler 不需要缓存,只是用于初始化连接 + bootstrapHandler := handler.NewHandler(model.StrategyAnyResult, false, config.Bootstrap, dataPath, debugLogger, nil) + + for i := 0; i < len(config.Upstreams); i++ { + config.Upstreams[i].InitConnectionPool(bootstrapHandler.LookupIP) + } + + server := &dns.Server{Addr: config.ServeAddr, Net: "udp"} + serverTCP := &dns.Server{Addr: config.ServeAddr, Net: "tcp"} + + // 只有 upstream handler 需要缓存 + upstreamHandler := handler.NewHandler(config.Strategy, config.BuiltInCache, config.Upstreams, dataPath, debugLogger, statsRecorder) + dns.HandleFunc(".", upstreamHandler.HandleRequest) + + // Setup graceful shutdown + defer func() { + // 保存统计数据 + log.Printf("Saving stats before shutdown...") + if err := statsRecorder.Save(dataPath); err != nil { + log.Printf("Error saving stats: %v", err) + } else { + log.Printf("Stats saved successfully") + } + + // 关闭缓存 + if err := upstreamHandler.Close(); err != nil { + log.Printf("Error closing cache: %v", err) + } + }() + + log.Println("==== DNS Server ====") + log.Println("端口:", config.ServeAddr) + log.Println("模式:", config.StrategyName()) + log.Println("数据:", dataPath) + if config.BuiltInCache { + log.Println("启用 BadgerDB 缓存: 最大 40MB") + } else { + log.Println("禁用缓存") + } + + log.Println("版本:", version) + + // 创建更新检查通道 + checkUpdateCh := make(chan struct{}, 1) + + // 启动 Web 服务(监控面板 + DoH + pprof) + webServerHandler := http.NewServeMux() + + // 注册监控面板路由 + var webUsername, webPassword string + if config.WebAuth != nil { + webUsername = config.WebAuth.Username + webPassword = config.WebAuth.Password + } + webHandler := web.NewHandler(statsRecorder, version, checkUpdateCh, debugLogger, webUsername, webPassword) + webHandler.RegisterRoutes(webServerHandler) + + // 如果启用 DoH,注册 DoH 路由 + if config.DohServer != nil { + dohServer := doh.NewServer(config.DohServer.Username, config.DohServer.Password, upstreamHandler.HandleDnsMsg, statsRecorder) + dohServer.RegisterRoutes(webServerHandler) + log.Printf("DoH 服务: http://%s/dns-query", config.WebAddr) + } + + // 如果启用 profiling,注册 pprof 路由 + if config.Profiling { + webServerHandler.HandleFunc("/debug/", http.DefaultServeMux.ServeHTTP) + log.Printf("性能分析: http://%s/debug/pprof/", config.WebAddr) + } + + go http.ListenAndServe(config.WebAddr, webServerHandler) + log.Printf("监控面板: http://%s/", config.WebAddr) + + // 定时保存统计数据(使用配置的间隔) + statsSaveTicker := time.NewTicker(time.Duration(config.StatsSaveInterval) * time.Minute) + defer statsSaveTicker.Stop() + + go func() { + for range statsSaveTicker.C { + if err := statsRecorder.Save(dataPath); err != nil { + debugLogger.Printf("Failed to save stats to disk: %v", err) + } else { + debugLogger.Printf("Stats saved successfully to disk") + } + } + }() + + stopCh := make(chan error) + + // 启动后台更新检查 + go checkUpdate(checkUpdateCh, stopCh, debugLogger) + + // 定时触发更新检查(生产者1:定时器) + if version != "" { + go func() { + // 启动时立即检查一次 + select { + case checkUpdateCh <- struct{}{}: + default: + } + + // 定时检查 + ticker := time.NewTicker(time.Duration(40+rand.Intn(20)) * time.Minute) + defer ticker.Stop() + for range ticker.C { + select { + case checkUpdateCh <- struct{}{}: + default: + // 如果通道已满,跳过本次 + } + } + }() + } + + go func() { + stopCh <- server.ListenAndServe() + }() + go func() { + stopCh <- serverTCP.ListenAndServe() + }() + + sigCh := make(chan os.Signal, 1) + signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM) + go func() { + <-sigCh + log.Println("Shutting down...") + stopCh <- errors.New("shutdown signal received") + }() + + log.Printf("server stopped: %+v", <-stopCh) +} + +// checkUpdate 监听 channel 触发更新检查 +func checkUpdate(checkCh <-chan struct{}, stopCh chan<- error, debugLogger logger.Logger) { + for range checkCh { + // 如果 version 为空,使用默认值 + ver := version + if ver == "" { + ver = "0.0.0" + } + v := semver.MustParse(ver) + latest, err := selfupdate.UpdateSelf(v, "xofine/godns") + if err != nil { + debugLogger.Printf("Error checking for updates: %v", err) + continue + } + if latest.Version.Equals(v) { + debugLogger.Printf("No update available, current version: %s", v) + } else { + log.Printf("Updated to version: %s", latest.Version) + stopCh <- errors.New("Server upgraded to " + latest.Version.String()) + return + } + } +} + +func loadIPRanger(path string) cidranger.Ranger { + ipRanger := cidranger.NewPCTrieRanger() + + content, err := os.ReadFile(path) + if err != nil { + panic(err) + } + lines := strings.Split(string(content), "\n") + + for i := 0; i < len(lines); i++ { + if strings.TrimSpace(lines[i]) == "" { + continue + } + _, network, err := net.ParseCIDR(lines[i]) + if err != nil { + panic(err) + } + if err := ipRanger.Insert(cidranger.NewBasicRangerEntry(*network)); err != nil { + panic(err) + } + } + + return ipRanger +} + +func detectDataPath() string { + ex, err := os.Executable() + if err != nil { + panic(err) + } + pwd, err := os.Getwd() + if err != nil { + panic(err) + } + pathList := []string{filepath.Dir(ex), pwd} + + for _, path := range pathList { + if f, err := os.Stat(path + "/data/china_ip_list.txt"); err == nil { + if f.Size() == 1024*200 { + panic("离线IP库 china_ip_list.txt 文件损坏,请重新下载") + } + return path + "/data/" + } + } + + panic("没有检测到IP数据 data/china_ip_list.txt") +} diff --git a/pkg/doh/client.go b/pkg/doh/client.go new file mode 100644 index 0000000..ec3413c --- /dev/null +++ b/pkg/doh/client.go @@ -0,0 +1,173 @@ +package doh + +import ( + "context" + "encoding/base64" + "io" + "net" + "net/http" + "net/http/httptrace" + "strings" + "time" + + "github.com/miekg/dns" + "github.com/pkg/errors" + "golang.org/x/net/proxy" +) + +const ( + dohMediaType = "application/dns-message" +) + +// Logger 定义可选的日志接口 +type Logger interface { + Printf(format string, v ...interface{}) +} + +type clientOptions struct { + timeout time.Duration + server string + bootstrap func(domain string) (net.IP, error) + getDialer func(d *net.Dialer) (proxy.Dialer, proxy.ContextDialer, error) + logger Logger +} + +type ClientOption func(*clientOptions) error + +func WithTimeout(t time.Duration) ClientOption { + return func(o *clientOptions) error { + o.timeout = t + return nil + } +} + +func WithSocksProxy(getDialer func(d *net.Dialer) (proxy.Dialer, proxy.ContextDialer, error)) ClientOption { + return func(o *clientOptions) error { + o.getDialer = getDialer + return nil + } +} + +func WithServer(server string) ClientOption { + return func(o *clientOptions) error { + o.server = server + return nil + } +} + +func WithBootstrap(resolver func(domain string) (net.IP, error)) ClientOption { + return func(o *clientOptions) error { + o.bootstrap = resolver + return nil + } +} + +func WithLogger(logger Logger) ClientOption { + return func(o *clientOptions) error { + o.logger = logger + return nil + } +} + +type Client struct { + opt *clientOptions + cli *http.Client + traceCtx context.Context +} + +func NewClient(opts ...ClientOption) *Client { + o := new(clientOptions) + for _, f := range opts { + f(o) + } + + clientTrace := &httptrace.ClientTrace{ + GotConn: func(info httptrace.GotConnInfo) { + if o.logger != nil { + o.logger.Printf("http conn was reused: %t", info.Reused) + } + }, + } + + var transport *http.Transport + + if o.bootstrap != nil { + transport = &http.Transport{ + DialContext: func(ctx context.Context, network, address string) (net.Conn, error) { + urls := strings.Split(address, ":") + ipv4, err := o.bootstrap(urls[0]) + if err != nil { + return nil, errors.Wrap(err, "bootstrap") + } + urls[0] = ipv4.String() + + if o.getDialer != nil { + dialer, _, err := o.getDialer(&net.Dialer{ + Timeout: o.timeout, + }) + if err != nil { + return nil, err + } + return dialer.Dial("tcp", strings.Join(urls, ":")) + } + + return (&net.Dialer{ + Timeout: o.timeout, + }).DialContext(ctx, network, strings.Join(urls, ":")) + }, + } + } + + return &Client{ + opt: o, + traceCtx: httptrace.WithClientTrace(context.Background(), clientTrace), + cli: &http.Client{ + Transport: transport, + Timeout: o.timeout, + }, + } +} + +func (c *Client) Exchange(req *dns.Msg) (r *dns.Msg, rtt time.Duration, err error) { + var ( + buf []byte + begin = time.Now() + origID = req.Id + hreq *http.Request + ) + + // Set DNS ID as zero accoreding to RFC8484 (cache friendly) + req.Id = 0 + buf, err = req.Pack() + if err != nil { + return + } + + hreq, err = http.NewRequestWithContext(c.traceCtx, http.MethodGet, c.opt.server+"?dns="+base64.RawURLEncoding.EncodeToString(buf), nil) + if err != nil { + return + } + hreq.Header.Add("Accept", dohMediaType) + hreq.Header.Add("User-Agent", "godns-doh-client/0.1") + + resp, err := c.cli.Do(hreq) + if err != nil { + return + } + defer resp.Body.Close() + + content, err := io.ReadAll(resp.Body) + if err != nil { + return + } + if resp.StatusCode != http.StatusOK { + err = errors.New("DoH query failed: " + string(content)) + return + } + + r = new(dns.Msg) + err = r.Unpack(content) + r.Id = origID + rtt = time.Since(begin) + return +} diff --git a/pkg/doh/server.go b/pkg/doh/server.go new file mode 100644 index 0000000..36009e3 --- /dev/null +++ b/pkg/doh/server.go @@ -0,0 +1,132 @@ +package doh + +import ( + "encoding/base64" + "net" + "net/http" + "strings" + + "github.com/miekg/dns" + "godns/internal/stats" +) + +type DoHServer struct { + username, password string + handler func(req *dns.Msg, clientIP, domain string) *dns.Msg + stats stats.StatsRecorder +} + +func NewServer(username, password string, handler func(req *dns.Msg, clientIP, domain string) *dns.Msg, statsRecorder stats.StatsRecorder) *DoHServer { + return &DoHServer{ + username: username, + password: password, + handler: handler, + stats: statsRecorder, + } +} + +// RegisterRoutes 注册 DoH 路由到现有的 HTTP 服务器 +func (s *DoHServer) RegisterRoutes(mux *http.ServeMux) { + mux.HandleFunc("/dns-query", s.handleQuery) +} + +func (s *DoHServer) handleQuery(w http.ResponseWriter, r *http.Request) { + if s.username != "" && s.password != "" { + username, password, ok := r.BasicAuth() + if !ok || username != s.username || password != s.password { + w.Header().Set("WWW-Authenticate", `Basic realm="dns"`) + w.WriteHeader(http.StatusUnauthorized) + return + } + } + + accept := r.Header.Get("Accept") + if accept != dohMediaType { + w.WriteHeader(http.StatusUnsupportedMediaType) + w.Write([]byte("unsupported media type: " + accept)) + return + } + + query := r.URL.Query().Get("dns") + if query == "" { + w.WriteHeader(http.StatusBadRequest) + return + } + + data, err := base64.RawURLEncoding.DecodeString(query) + if err != nil { + w.WriteHeader(http.StatusBadRequest) + w.Write([]byte(err.Error())) + return + } + + msg := new(dns.Msg) + if err := msg.Unpack(data); err != nil { + w.WriteHeader(http.StatusBadRequest) + w.Write([]byte(err.Error())) + return + } + + // 记录 DoH 查询统计 + if s.stats != nil { + s.stats.RecordDoHQuery() + } + + // 提取客户端 IP + clientIP := extractClientIP(r) + + // 提取域名 + var domain string + if len(msg.Question) > 0 { + domain = msg.Question[0].Name + } + + resp := s.handler(msg, clientIP, domain) + if resp == nil { + w.WriteHeader(http.StatusInternalServerError) + w.Write([]byte("nil response")) + return + } + + data, err = resp.Pack() + if err != nil { + w.WriteHeader(http.StatusInternalServerError) + w.Write([]byte(err.Error())) + return + } + + w.Header().Set("Content-Type", dohMediaType) + w.Write(data) +} + +// extractClientIP 从 HTTP 请求中提取真实的客户端 IP +func extractClientIP(r *http.Request) string { + // 1. 优先检查 X-Forwarded-For(适用于多层代理) + if xff := r.Header.Get("X-Forwarded-For"); xff != "" { + // X-Forwarded-For 格式: client, proxy1, proxy2 + // 取第一个 IP(最原始的客户端 IP) + parts := strings.Split(xff, ",") + if len(parts) > 0 { + clientIP := strings.TrimSpace(parts[0]) + // 验证是否为有效 IP + if ip := net.ParseIP(clientIP); ip != nil { + return clientIP + } + } + } + + // 2. 检查 X-Real-IP(单层代理常用) + if xri := r.Header.Get("X-Real-IP"); xri != "" { + if ip := net.ParseIP(xri); ip != nil { + return xri + } + } + + // 3. 使用 RemoteAddr,需要去掉端口号 + if host, _, err := net.SplitHostPort(r.RemoteAddr); err == nil { + return host + } + + // 4. 如果无法解析端口,直接返回(可能已经是纯 IP) + return r.RemoteAddr +} diff --git a/pkg/logger/logger.go b/pkg/logger/logger.go new file mode 100644 index 0000000..1c82d96 --- /dev/null +++ b/pkg/logger/logger.go @@ -0,0 +1,37 @@ +package logger + +import ( + "log" + "os" +) + +// Logger 定义日志接口 +type Logger interface { + Printf(format string, v ...interface{}) + Println(v ...interface{}) +} + +// DebugLogger 实现 Logger 接口,支持调试模式 +type DebugLogger struct { + Debug bool +} + +// New 创建新的日志实例 +func New(debug bool) Logger { + if !debug { + log.SetOutput(os.Stdout) + } + return &DebugLogger{Debug: debug} +} + +func (l *DebugLogger) Printf(format string, v ...interface{}) { + if l.Debug { + log.Printf(format, v...) + } +} + +func (l *DebugLogger) Println(v ...interface{}) { + if l.Debug { + log.Println(v...) + } +} diff --git a/pkg/qqwry/qqwry.go b/pkg/qqwry/qqwry.go new file mode 100644 index 0000000..c526aed --- /dev/null +++ b/pkg/qqwry/qqwry.go @@ -0,0 +1,175 @@ +package qqwry + +import ( + "bytes" + "encoding/binary" + "errors" + "io/ioutil" + "net" + "strings" + "sync" + + "golang.org/x/text/encoding/simplifiedchinese" + "golang.org/x/text/transform" +) + +var ( + data []byte + dataLen uint32 + ipCache *sync.Map +) + +const ( + indexLen = 7 + redirectMode1 = 0x01 + redirectMode2 = 0x02 +) + +type cache struct { + City string + Isp string +} + +func byte3ToUInt32(data []byte) uint32 { + i := uint32(data[0]) & 0xff + i |= (uint32(data[1]) << 8) & 0xff00 + i |= (uint32(data[2]) << 16) & 0xff0000 + return i +} + +func gb18030Decode(src []byte) string { + in := bytes.NewReader(src) + out := transform.NewReader(in, simplifiedchinese.GB18030.NewDecoder()) + d, _ := ioutil.ReadAll(out) + return string(d) +} + +// QueryIP 从内存或缓存查询IP +func QueryIP(ip net.IP) (city string, isp string, err error) { + ip32 := binary.BigEndian.Uint32(ip) + + if ipCache != nil { + if v, ok := ipCache.Load(ip32); ok { + city = v.(cache).City + isp = v.(cache).Isp + return + } + } + + posA := binary.LittleEndian.Uint32(data[:4]) + posZ := binary.LittleEndian.Uint32(data[4:8]) + var offset uint32 = 0 + for { + mid := posA + (((posZ-posA)/indexLen)>>1)*indexLen + buf := data[mid : mid+indexLen] + _ip := binary.LittleEndian.Uint32(buf[:4]) + if posZ-posA == indexLen { + offset = byte3ToUInt32(buf[4:]) + buf = data[mid+indexLen : mid+indexLen+indexLen] + if ip32 < binary.LittleEndian.Uint32(buf[:4]) { + break + } else { + offset = 0 + break + } + } + if _ip > ip32 { + posZ = mid + } else if _ip < ip32 { + posA = mid + } else if _ip == ip32 { + offset = byte3ToUInt32(buf[4:]) + break + } + } + if offset <= 0 { + err = errors.New("ip not found") + return + } + posM := offset + 4 + mode := data[posM] + var ispPos uint32 + switch mode { + case redirectMode1: + posC := byte3ToUInt32(data[posM+1 : posM+4]) + mode = data[posC] + posCA := posC + if mode == redirectMode2 { + posCA = byte3ToUInt32(data[posC+1 : posC+4]) + posC += 4 + } + for i := posCA; i < dataLen; i++ { + if data[i] == 0 { + city = string(data[posCA:i]) + break + } + } + if mode != redirectMode2 { + posC += uint32(len(city) + 1) + } + ispPos = posC + case redirectMode2: + posCA := byte3ToUInt32(data[posM+1 : posM+4]) + for i := posCA; i < dataLen; i++ { + if data[i] == 0 { + city = string(data[posCA:i]) + break + } + } + ispPos = offset + 8 + default: + posCA := offset + 4 + for i := posCA; i < dataLen; i++ { + if data[i] == 0 { + city = string(data[posCA:i]) + break + } + } + ispPos = offset + uint32(5+len(city)) + } + if city != "" { + city = strings.TrimSpace(gb18030Decode([]byte(city))) + } + ispMode := data[ispPos] + if ispMode == redirectMode1 || ispMode == redirectMode2 { + ispPos = byte3ToUInt32(data[ispPos+1 : ispPos+4]) + } + if ispPos > 0 { + for i := ispPos; i < dataLen; i++ { + if data[i] == 0 { + isp = string(data[ispPos:i]) + if isp != "" { + if strings.Contains(isp, "CZ88.NET") { + isp = "" + } else { + isp = strings.TrimSpace(gb18030Decode([]byte(isp))) + } + } + break + } + } + } + if ipCache != nil { + ipCache.Store(ip32, cache{City: city, Isp: isp}) + } + return +} + +// LoadData 从内存加载IP数据库 +func LoadData(database []byte) { + data = database + dataLen = uint32(len(data)) +} + +// LoadFile 从文件加载IP数据库 +func LoadFile(filepath string, useCache bool) (err error) { + data, err = ioutil.ReadFile(filepath) + if err != nil { + return + } + dataLen = uint32(len(data)) + if useCache { + ipCache = new(sync.Map) + } + return +} diff --git a/pkg/utils/utils.go b/pkg/utils/utils.go new file mode 100644 index 0000000..a34d990 --- /dev/null +++ b/pkg/utils/utils.go @@ -0,0 +1,45 @@ +package utils + +import "strings" + +func ParseRules(rulesRaw []string) [][]string { + var rules [][]string + for _, r := range rulesRaw { + if r == "" { + continue + } + if !strings.HasSuffix(r, ".") { + r += "." + } + rules = append(rules, strings.Split(r, ".")) + } + return rules +} + +func HasMatchedRule(rules [][]string, domain string) bool { + var hasMatch bool +OUTER: + for _, m := range rules { + domainSplited := strings.Split(domain, ".") + i := len(m) - 1 + j := len(domainSplited) - 1 + // 从根域名开始匹配 + for i >= 0 && j >= 0 { + if m[i] != domainSplited[j] && m[i] != "" { + continue OUTER + } + i-- + j-- + } + // 如果规则中还有剩余,但是域名已经匹配完了,检查规则最后一位是否是任意匹配 + if j != -1 && i == -1 && m[0] != "" { + continue OUTER + } + hasMatch = i == -1 + // 如果匹配到了,就不用再匹配了 + if hasMatch { + break + } + } + return hasMatch +}