update once

This commit is contained in:
XOF
2026-01-06 02:25:24 +08:00
commit 7bf4f27be3
25 changed files with 4587 additions and 0 deletions

247
internal/cache/badger_cache.go vendored Normal file
View File

@@ -0,0 +1,247 @@
package cache
import (
"fmt"
"path/filepath"
"time"
"github.com/dgraph-io/badger/v4"
"github.com/dgraph-io/badger/v4/options"
"github.com/miekg/dns"
"godns/pkg/logger"
)
// Cache 定义缓存接口
type Cache interface {
Get(key string) (*CachedMsg, bool)
Set(key string, msg *CachedMsg, ttl time.Duration) error
Delete(key string) error
Close() error
Stats() string
}
// CachedMsg represents a cached DNS message with expiration time
type CachedMsg struct {
Msg *dns.Msg `json:"msg"`
Expires time.Time `json:"expires"`
}
// BadgerCache wraps BadgerDB for DNS query caching
type BadgerCache struct {
db *badger.DB
logger logger.Logger
}
// NewBadgerCache creates a new BadgerDB cache instance with optimized settings for embedded devices
func NewBadgerCache(dataPath string, log logger.Logger) (*BadgerCache, error) {
dbPath := filepath.Join(dataPath, "cache")
opts := badger.DefaultOptions(dbPath)
// 针对树莓派等嵌入式设备的优化配置(目标:总内存 ~32MB
// MemTable4MBBadgerDB 默认保持 2 个 MemTable
opts.MemTableSize = 4 << 20 // 4MB (内存占用 ~8MB)
// ValueLog4MB
opts.ValueLogFileSize = 4 << 20 // 4MB
// BlockCache16MB提升读取性能
opts.BlockCacheSize = 16 << 20 // 16MB
// IndexCache8MB加速索引查找
opts.IndexCacheSize = 8 << 20 // 8MB
// Level 0 tables
opts.NumLevelZeroTables = 2
opts.NumLevelZeroTablesStall = 4
// 关闭压缩,节省 CPU
opts.Compression = options.None
// DNS 响应通常较小,内联存储减少磁盘访问
opts.ValueThreshold = 512
// 异步写入,提高性能
opts.SyncWrites = false
// ValueLog 条目数量
opts.ValueLogMaxEntries = 50000
// 压缩线程数
opts.NumCompactors = 2
// 禁用冲突检测,提升写入性能
opts.DetectConflicts = false
// 禁用内部日志
opts.Logger = nil
db, err := badger.Open(opts)
if err != nil {
return nil, fmt.Errorf("failed to open BadgerDB: %w", err)
}
cache := &BadgerCache{db: db, logger: log}
// Start garbage collection routines
go cache.runGC()
go cache.runCompaction()
return cache, nil
}
// Set stores a DNS message in the cache with the given key and TTL
func (bc *BadgerCache) Set(key string, msg *CachedMsg, ttl time.Duration) error {
// Pack DNS message to wire format
dnsData, err := msg.Msg.Pack()
if err != nil {
return fmt.Errorf("failed to pack DNS message: %w", err)
}
// 直接存储二进制数据8字节过期时间 + DNS wire format
// 避免 JSON 序列化开销
expiresBytes := make([]byte, 8)
// 使用 Unix 时间戳(秒)
expiresUnix := msg.Expires.Unix()
for i := 0; i < 8; i++ {
expiresBytes[i] = byte(expiresUnix >> (56 - i*8))
}
// 组合数据:过期时间 + DNS数据
data := append(expiresBytes, dnsData...)
return bc.db.Update(func(txn *badger.Txn) error {
entry := badger.NewEntry([]byte(key), data).WithTTL(ttl)
return txn.SetEntry(entry)
})
}
// Get retrieves a DNS message from the cache
func (bc *BadgerCache) Get(key string) (*CachedMsg, bool) {
var cachedMsg *CachedMsg
err := bc.db.View(func(txn *badger.Txn) error {
item, err := txn.Get([]byte(key))
if err != nil {
return err
}
return item.Value(func(val []byte) error {
// 数据格式8字节过期时间 + DNS wire format
if len(val) < 8 {
return fmt.Errorf("invalid cache data: too short")
}
// 解析过期时间
var expiresUnix int64
for i := 0; i < 8; i++ {
expiresUnix = (expiresUnix << 8) | int64(val[i])
}
expires := time.Unix(expiresUnix, 0)
// 解析 DNS 消息
msg := new(dns.Msg)
if err := msg.Unpack(val[8:]); err != nil {
return fmt.Errorf("failed to unpack DNS message: %w", err)
}
cachedMsg = &CachedMsg{
Msg: msg,
Expires: expires,
}
return nil
})
})
if err != nil {
if err == badger.ErrKeyNotFound {
return nil, false
}
// 缓存数据损坏或格式不兼容,返回未命中,后续 Set 会覆盖
bc.logger.Printf("Cache get error for key %s: %v", key, err)
return nil, false
}
return cachedMsg, true
}
// Delete removes a key from the cache
func (bc *BadgerCache) Delete(key string) error {
return bc.db.Update(func(txn *badger.Txn) error {
return txn.Delete([]byte(key))
})
}
// Close closes the BadgerDB instance
func (bc *BadgerCache) Close() error {
return bc.db.Close()
}
// runGC runs garbage collection periodically to clean up expired entries in value log
func (bc *BadgerCache) runGC() {
ticker := time.NewTicker(15 * time.Minute)
defer ticker.Stop()
for range ticker.C {
// Run GC multiple times until no more rewrite is needed
gcCount := 0
for {
err := bc.db.RunValueLogGC(0.5)
if err != nil {
if err != badger.ErrNoRewrite {
bc.logger.Printf("BadgerDB GC error: %v", err)
}
break
}
gcCount++
// Limit GC runs and add delay to prevent CPU hogging
if gcCount >= 10 {
bc.logger.Printf("BadgerDB GC: reached max runs limit (10)")
break
}
// Sleep briefly between GC cycles to reduce CPU usage
time.Sleep(500 * time.Millisecond)
}
if gcCount > 0 {
bc.logger.Printf("BadgerDB GC: completed %d runs", gcCount)
}
// Check disk usage and clean if necessary
bc.checkAndCleanDiskUsage()
}
}
// runCompaction runs LSM tree compaction periodically to clean up expired key metadata
func (bc *BadgerCache) runCompaction() {
ticker := time.NewTicker(1 * time.Hour)
defer ticker.Stop()
for range ticker.C {
err := bc.db.Flatten(1)
if err != nil {
bc.logger.Printf("BadgerDB compaction error: %v", err)
}
}
}
// checkAndCleanDiskUsage checks if cache exceeds size limit and triggers cleanup
func (bc *BadgerCache) checkAndCleanDiskUsage() {
lsm, vlog := bc.db.Size()
totalSize := lsm + vlog
maxSize := int64(50 << 20) // 50MB limit (适合家用路由器等嵌入式设备)
if totalSize > maxSize {
bc.logger.Printf("Cache size %d MB exceeds limit %d MB, triggering cleanup", totalSize>>20, maxSize>>20)
// Force compaction to reduce size
if err := bc.db.Flatten(2); err != nil {
bc.logger.Printf("BadgerDB flatten error: %v", err)
}
}
}
// Stats returns cache statistics
func (bc *BadgerCache) Stats() string {
lsm, vlog := bc.db.Size()
return fmt.Sprintf("LSM size: %d bytes, Value log size: %d bytes", lsm, vlog)
}

750
internal/handler/handler.go Normal file
View File

@@ -0,0 +1,750 @@
package handler
import (
"errors"
"fmt"
"net"
"strings"
"sync"
"time"
"github.com/miekg/dns"
"godns/internal/cache"
"godns/internal/model"
"godns/internal/stats"
"godns/pkg/logger"
)
type Handler struct {
strategy int
commonUpstreams, specialUpstreams []*model.Upstream
builtInCache cache.Cache
logger logger.Logger
stats stats.StatsRecorder
}
func NewHandler(strategy int, builtInCache bool,
upstreams []*model.Upstream,
dataPath string,
log logger.Logger,
statsRecorder stats.StatsRecorder) *Handler {
var c cache.Cache
if builtInCache {
var err error
c, err = cache.NewBadgerCache(dataPath, log)
if err != nil {
log.Printf("Failed to initialize BadgerDB cache: %v", err)
log.Printf("Cache will be disabled")
c = nil
} else {
log.Printf("BadgerDB cache initialized successfully at %s", dataPath)
}
}
var commonUpstreams, specialUpstreams []*model.Upstream
for i := 0; i < len(upstreams); i++ {
if len(upstreams[i].Match) > 0 {
specialUpstreams = append(specialUpstreams, upstreams[i])
} else {
commonUpstreams = append(commonUpstreams, upstreams[i])
}
}
return &Handler{
strategy: strategy,
commonUpstreams: commonUpstreams,
specialUpstreams: specialUpstreams,
builtInCache: c,
logger: log,
stats: statsRecorder,
}
}
func (h *Handler) matchedUpstreams(req *dns.Msg) []*model.Upstream {
if len(req.Question) == 0 {
return h.commonUpstreams
}
q := req.Question[0]
var matchedUpstreams []*model.Upstream
for i := 0; i < len(h.specialUpstreams); i++ {
if h.specialUpstreams[i].IsMatch(q.Name) {
matchedUpstreams = append(matchedUpstreams, h.specialUpstreams[i])
}
}
if len(matchedUpstreams) > 0 {
return matchedUpstreams
}
return h.commonUpstreams
}
func (h *Handler) LookupIP(host string) (ip net.IP, err error) {
if ip = net.ParseIP(host); ip != nil {
return ip, nil
}
if !strings.HasSuffix(host, ".") {
host += "."
}
m := new(dns.Msg)
m.Id = dns.Id()
m.RecursionDesired = true
m.Question = make([]dns.Question, 1)
m.Question[0] = dns.Question{Name: host, Qtype: dns.TypeA, Qclass: dns.ClassINET}
res := h.exchange(m)
// 取一个 IPv4 地址
for i := 0; i < len(res.Answer); i++ {
if aRecord, ok := res.Answer[i].(*dns.A); ok {
ip = aRecord.A
}
}
// 选取最后一个(一般是备用,存活率高一些)
if ip == nil {
err = errors.New("no ipv4 address found")
}
h.logger.Printf("bootstrap LookupIP: %s %v --> %s %v", host, res.Answer, ip, err)
return
}
// removeEDNS 清理请求中的 EDNS 客户端子网信息
func (h *Handler) removeEDNS(req *dns.Msg) {
opt := req.IsEdns0()
if opt == nil {
return
}
// 过滤掉 EDNS Client Subnet 选项
var newOptions []dns.EDNS0
for _, option := range opt.Option {
if _, ok := option.(*dns.EDNS0_SUBNET); !ok {
// 保留非 ECS 的其他选项
newOptions = append(newOptions, option)
} else {
h.logger.Printf("Removed EDNS Client Subnet from request")
}
}
opt.Option = newOptions
}
func (h *Handler) exchange(req *dns.Msg) *dns.Msg {
// 清理 EDNS 客户端子网信息
h.removeEDNS(req)
var msgs []*dns.Msg
switch h.strategy {
case model.StrategyFullest:
msgs = h.getTheFullestResults(req)
case model.StrategyFastest:
msgs = h.getTheFastestResults(req)
case model.StrategyAnyResult:
msgs = h.getAnyResult(req)
}
var res *dns.Msg
for i := 0; i < len(msgs); i++ {
if msgs[i] == nil {
continue
}
if res == nil {
res = msgs[i]
continue
}
res.Answer = append(res.Answer, msgs[i].Answer...)
}
if res == nil {
// 如果全部上游挂了要返回错误
res = new(dns.Msg)
res.Rcode = dns.RcodeServerFailure
} else {
res.Answer = uniqueAnswer(res.Answer)
}
return res
}
func getDnsRequestCacheKey(m *dns.Msg) string {
var dnssec string
if o := m.IsEdns0(); o != nil {
// 区分 DNSSEC 请求,避免将非 DNSSEC 响应返回给需要 DNSSEC 的客户端
if o.Do() {
dnssec = "DO"
}
// 服务多区域的公共dns使用
// for _, s := range o.Option {
// switch e := s.(type) {
// case *dns.EDNS0_SUBNET:
// edns = e.Address.String()
// }
// }
}
return fmt.Sprintf("%s#%d#%s", model.GetDomainNameFromDnsMsg(m), m.Question[0].Qtype, dnssec)
}
func getDnsResponseTtl(m *dns.Msg) time.Duration {
var ttl uint32
if len(m.Answer) > 0 {
ttl = m.Answer[0].Header().Ttl
}
if ttl < 60 {
ttl = 60 // 最小 ttl 1 分钟
} else if ttl > 3600 {
ttl = 3600 // 最大 ttl 1 小时
}
return time.Duration(ttl) * time.Second
}
// shouldCacheResponse 判断响应是否应该被缓存
func shouldCacheResponse(m *dns.Msg) bool {
// 不缓存服务器错误响应
if m.Rcode == dns.RcodeServerFailure {
return false
}
// 不缓存格式错误的响应
if m.Rcode == dns.RcodeFormatError {
return false
}
// NXDOMAIN (域名不存在) 可以缓存,但时间较短(由 getDnsResponseTtl 控制)
// NOERROR 和 NXDOMAIN 都可以缓存
return m.Rcode == dns.RcodeSuccess || m.Rcode == dns.RcodeNameError
}
// validateResponse 验证 DNS 响应,防止缓存投毒
// 返回 true 表示响应有效false 表示可能存在投毒风险
func validateResponse(req *dns.Msg, resp *dns.Msg, debugLogger logger.Logger) bool {
// 1. 检查响应是否为空
if resp == nil {
return false
}
// 2. 检查请求和响应的问题数量
if len(req.Question) == 0 || len(resp.Question) == 0 {
return true // 如果没有问题部分,跳过验证(某些响应可能没有问题部分)
}
// 3. 验证域名匹配(不区分大小写)
if !strings.EqualFold(req.Question[0].Name, resp.Question[0].Name) {
debugLogger.Printf("DNS response validation failed: domain mismatch - request: %s, response: %s",
req.Question[0].Name, resp.Question[0].Name)
return false
}
// 4. 验证查询类型匹配
if req.Question[0].Qtype != resp.Question[0].Qtype {
debugLogger.Printf("DNS response validation failed: qtype mismatch - request: %d, response: %d",
req.Question[0].Qtype, resp.Question[0].Qtype)
return false
}
// 5. 验证查询类别匹配(通常都是 IN - Internet
if req.Question[0].Qclass != resp.Question[0].Qclass {
debugLogger.Printf("DNS response validation failed: qclass mismatch - request: %d, response: %d",
req.Question[0].Qclass, resp.Question[0].Qclass)
return false
}
// 6. 验证 Answer 部分的域名(防止返回无关域名的记录)
requestDomain := strings.ToLower(strings.TrimSuffix(req.Question[0].Name, "."))
validDomains := make(map[string]bool)
validDomains[requestDomain] = true
// 第一遍:收集所有 CNAME 目标域名
for _, answer := range resp.Answer {
if answer.Header().Rrtype == dns.TypeCNAME {
if cname, ok := answer.(*dns.CNAME); ok {
cnameTarget := strings.ToLower(strings.TrimSuffix(cname.Target, "."))
validDomains[cnameTarget] = true
}
}
}
// 第二遍:验证所有应答记录
for _, answer := range resp.Answer {
answerDomain := strings.ToLower(strings.TrimSuffix(answer.Header().Name, "."))
// 检查应答记录的域名是否在有效域名列表中
if !validDomains[answerDomain] {
// 对于 CNAME 记录,域名必须是请求域名
if answer.Header().Rrtype == dns.TypeCNAME {
if answerDomain != requestDomain {
debugLogger.Printf("DNS response validation failed: CNAME domain mismatch - request: %s, CNAME: %s",
requestDomain, answerDomain)
return false
}
} else {
// 对于其他记录类型,记录警告但不拒绝(某些服务器可能返回额外记录)
debugLogger.Printf("DNS response validation warning: answer domain not in valid chain - request: %s, answer: %s (type: %d)",
requestDomain, answerDomain, answer.Header().Rrtype)
}
}
}
// 7. 检查 TTL 值的合理性(防止异常的 TTL 值)
for _, answer := range resp.Answer {
ttl := answer.Header().Ttl
// TTL 不应该超过 7 天604800 秒)
if ttl > 604800 {
debugLogger.Printf("DNS response validation warning: suspiciously high TTL: %d seconds for %s",
ttl, answer.Header().Name)
}
}
return true
}
// HandleDnsMsg 处理 DNS 查询的核心逻辑(支持缓存和统计)
// clientIP 和 domain 用于统计,如果为空则自动从请求中提取 domain
func (h *Handler) HandleDnsMsg(req *dns.Msg, clientIP, domain string) *dns.Msg {
h.logger.Printf("godns::request %+v\n", req)
// 记录查询统计
if h.stats != nil {
h.stats.RecordQuery()
// 提取域名(如果未提供)
if domain == "" && len(req.Question) > 0 {
domain = req.Question[0].Name
}
// 记录客户端查询
if clientIP != "" || domain != "" {
h.stats.RecordClientQuery(clientIP, domain)
}
}
// 检查缓存
var cacheKey string
var respCache *dns.Msg
if h.builtInCache != nil {
cacheKey = getDnsRequestCacheKey(req)
if v, ok := h.builtInCache.Get(cacheKey); ok {
if h.stats != nil {
h.stats.RecordCacheHit()
}
respCache = v.Msg.Copy()
if v.Expires.After(time.Now()) {
msg := replyUpdateTtl(req, respCache, uint32(time.Until(v.Expires).Seconds()))
if len(msg.Answer) > 0 {
return msg
}
}
} else {
if h.stats != nil {
h.stats.RecordCacheMiss()
}
}
}
// 从上游获取响应
resp := h.exchange(req)
if resp.Rcode == dns.RcodeServerFailure {
if h.stats != nil {
h.stats.RecordFailed()
}
// 上游失败时使用任何可用缓存(即使过期)作为降级
if respCache != nil {
msg := replyUpdateTtl(req, respCache, 12)
if len(msg.Answer) > 0 {
return msg
}
}
}
resp.SetReply(req)
h.logger.Printf("godns::resp: %+v\n", resp)
// 验证响应并缓存(防止缓存投毒)
if h.builtInCache != nil && shouldCacheResponse(resp) && validateResponse(req, resp, h.logger) {
ttl := getDnsResponseTtl(resp)
cachedMsg := &cache.CachedMsg{
Msg: resp,
Expires: time.Now().Add(ttl),
}
if err := h.builtInCache.Set(cacheKey, cachedMsg, ttl+time.Hour); err != nil {
h.logger.Printf("Failed to cache response: %v", err)
}
}
return resp
}
// extractClientIPFromDNS 从 DNS 请求中提取客户端 IP
// 优先级EDNS Client Subnet > RemoteAddr
func extractClientIPFromDNS(w dns.ResponseWriter, req *dns.Msg) string {
// 1. 优先检查 EDNS Client Subnet (ECS)
// ECS 是 DNS 协议标准,用于传递真实客户端 IP
if opt := req.IsEdns0(); opt != nil {
for _, option := range opt.Option {
if ecs, ok := option.(*dns.EDNS0_SUBNET); ok {
// ECS 中的 Address 就是客户端真实 IP
return ecs.Address.String()
}
}
}
// 2. 从 RemoteAddr 获取
var clientIP string
if addr := w.RemoteAddr(); addr != nil {
if udpAddr, ok := addr.(*net.UDPAddr); ok {
clientIP = udpAddr.IP.String()
} else if tcpAddr, ok := addr.(*net.TCPAddr); ok {
clientIP = tcpAddr.IP.String()
}
}
return clientIP
}
func (h *Handler) HandleRequest(w dns.ResponseWriter, req *dns.Msg) {
// 提取客户端 IP
clientIP := extractClientIPFromDNS(w, req)
// 提取域名
var domain string
if len(req.Question) > 0 {
domain = req.Question[0].Name
}
// 调用核心处理逻辑
resp := h.HandleDnsMsg(req, clientIP, domain)
// 写入响应
if err := w.WriteMsg(resp); err != nil {
h.logger.Printf("WriteMsg error: %+v", err)
}
}
// uniqueAnswer 去除重复的 DNS 资源记录
// 基于域名、类型和记录数据进行去重,比字符串分割更高效和可靠
func uniqueAnswer(records []dns.RR) []dns.RR {
if len(records) == 0 {
return records
}
seen := make(map[string]bool, len(records))
result := make([]dns.RR, 0, len(records))
for _, rr := range records {
if rr == nil {
continue
}
header := rr.Header()
if header == nil {
continue
}
// 构造唯一键:域名 + 类型 + 记录数据
// 使用 strings.Builder 优化字符串拼接性能
var builder strings.Builder
builder.Grow(128) // Pre-allocate reasonable capacity
var key string
switch v := rr.(type) {
case *dns.A:
builder.WriteString(header.Name)
builder.WriteString("|A|")
builder.WriteString(v.A.String())
key = builder.String()
case *dns.AAAA:
builder.WriteString(header.Name)
builder.WriteString("|AAAA|")
builder.WriteString(v.AAAA.String())
key = builder.String()
case *dns.CNAME:
builder.WriteString(header.Name)
builder.WriteString("|CNAME|")
builder.WriteString(v.Target)
key = builder.String()
case *dns.MX:
builder.WriteString(header.Name)
builder.WriteString("|MX|")
builder.WriteString(fmt.Sprintf("%d|%s", v.Preference, v.Mx))
key = builder.String()
case *dns.NS:
builder.WriteString(header.Name)
builder.WriteString("|NS|")
builder.WriteString(v.Ns)
key = builder.String()
case *dns.PTR:
builder.WriteString(header.Name)
builder.WriteString("|PTR|")
builder.WriteString(v.Ptr)
key = builder.String()
case *dns.TXT:
builder.WriteString(header.Name)
builder.WriteString("|TXT|")
builder.WriteString(strings.Join(v.Txt, "|"))
key = builder.String()
case *dns.SRV:
builder.WriteString(header.Name)
builder.WriteString("|SRV|")
builder.WriteString(fmt.Sprintf("%d|%d|%d|%s", v.Priority, v.Weight, v.Port, v.Target))
key = builder.String()
case *dns.SOA:
builder.WriteString(header.Name)
builder.WriteString("|SOA|")
builder.WriteString(v.Ns)
builder.WriteString("|")
builder.WriteString(v.Mbox)
key = builder.String()
default:
// 对于其他类型,回退到完整字符串表示
key = rr.String()
}
if !seen[key] {
seen[key] = true
result = append(result, rr)
}
}
return result
}
func (h *Handler) getTheFullestResults(req *dns.Msg) []*dns.Msg {
matchedUpstreams := h.matchedUpstreams(req)
var wg sync.WaitGroup
wg.Add(len(matchedUpstreams))
msgs := make([]*dns.Msg, len(matchedUpstreams))
for i := 0; i < len(matchedUpstreams); i++ {
go func(j int) {
defer wg.Done()
msg, _, err := matchedUpstreams[j].Exchange(req.Copy())
// 记录上游服务器统计
if h.stats != nil {
h.stats.RecordUpstreamQuery(matchedUpstreams[j].Address, err != nil)
}
if err != nil {
h.logger.Printf("upstream error %s: %v %s", matchedUpstreams[j].Address, model.GetDomainNameFromDnsMsg(req), err)
return
}
if matchedUpstreams[j].IsValidMsg(msg) {
msgs[j] = msg
}
}(i)
}
wg.Wait()
return msgs
}
func (h *Handler) getTheFastestResults(req *dns.Msg) []*dns.Msg {
preferUpstreams := h.matchedUpstreams(req)
msgs := make([]*dns.Msg, len(preferUpstreams))
var mutex sync.Mutex
var finishedCount int
var finished bool
var freedomIndex, primaryIndex []int
var wg sync.WaitGroup
wg.Add(1)
for i := 0; i < len(preferUpstreams); i++ {
go func(j int) {
msg, _, err := preferUpstreams[j].Exchange(req.Copy())
// 记录上游服务器统计
if h.stats != nil {
h.stats.RecordUpstreamQuery(preferUpstreams[j].Address, err != nil)
}
if err != nil {
h.logger.Printf("upstream error %s: %v %s", preferUpstreams[j].Address, model.GetDomainNameFromDnsMsg(req), err)
}
mutex.Lock()
defer mutex.Unlock()
finishedCount++
// 已经结束直接退出
if finished {
return
}
if err == nil {
if preferUpstreams[j].IsValidMsg(msg) {
if preferUpstreams[j].IsPrimary {
primaryIndex = append(primaryIndex, j)
} else {
freedomIndex = append(freedomIndex, j)
}
msgs[j] = msg
} else if preferUpstreams[j].IsPrimary {
// 策略:国内 DNS 返回了 国外 服务器,计数但是不记入结果,以 国外 DNS 为准
primaryIndex = append(primaryIndex, j)
}
}
// 全部结束直接退出
if finishedCount == len(preferUpstreams) {
finished = true
wg.Done()
return
}
// 两组 DNS 都有一个返回结果,退出
if len(primaryIndex) > 0 && len(freedomIndex) > 0 {
finished = true
wg.Done()
return
}
// 满足任一条件退出
// - 国内 DNS 返回了 国内 服务器
// - 国内 DNS 返回国外服务器 且 国外 DNS 有可用结果
if len(primaryIndex) > 0 && (msgs[primaryIndex[0]] != nil || len(freedomIndex) > 0) {
finished = true
wg.Done()
}
}(i)
}
wg.Wait()
return msgs
}
func (h *Handler) getAnyResult(req *dns.Msg) []*dns.Msg {
matchedUpstreams := h.matchedUpstreams(req)
var wg sync.WaitGroup
wg.Add(1)
msgs := make([]*dns.Msg, len(matchedUpstreams))
var mutex sync.Mutex
var finishedCount int
var finished bool
for i := 0; i < len(matchedUpstreams); i++ {
go func(j int) {
msg, _, err := matchedUpstreams[j].Exchange(req.Copy())
// 记录上游服务器统计
if h.stats != nil {
h.stats.RecordUpstreamQuery(matchedUpstreams[j].Address, err != nil)
}
if err != nil {
h.logger.Printf("upstream error %s: %v %s", matchedUpstreams[j].Address, model.GetDomainNameFromDnsMsg(req), err)
}
mutex.Lock()
defer mutex.Unlock()
finishedCount++
if finished {
return
}
// 已结束或任意上游返回成功时退出
if err == nil || finishedCount == len(matchedUpstreams) {
finished = true
msgs[j] = msg
wg.Done()
}
}(i)
}
wg.Wait()
return msgs
}
// Close properly shuts down the cache
func (h *Handler) Close() error {
if h.builtInCache != nil {
return h.builtInCache.Close()
}
return nil
}
// GetCacheStats returns cache statistics
func (h *Handler) GetCacheStats() string {
if h.builtInCache != nil {
return h.builtInCache.Stats()
}
return "Cache disabled"
}
// replyUpdateTtl 准备缓存响应以发送给客户端,执行必要的修正:
// 1. 设置正确的 Message ID通过 SetReply
// 2. 更新所有 RR 的 TTL 为剩余时间(最低 0
// 3. 调整 OPT RR 的 UDP size 为客户端请求的值
// 4. 清除 ECS Scope Length标记为缓存答案
// 5. 检查过期的 RRSIG 并移除
func replyUpdateTtl(req *dns.Msg, resp *dns.Msg, ttl uint32) *dns.Msg {
now := time.Now().Unix()
// 辅助函数:更新 RR 列表的 TTL并检测过期 RRSIG
updateRRs := func(rrs []dns.RR) []dns.RR {
var validRRs []dns.RR
for _, rr := range rrs {
header := rr.Header()
if header == nil {
continue
}
// 检查 RRSIG 是否过期
if rrsig, ok := rr.(*dns.RRSIG); ok {
if rrsig.Expiration > 0 && uint32(now) > rrsig.Expiration {
// RRSIG 已过期,跳过这条记录
continue
}
}
// 更新 TTL最低为 0
header.Ttl = ttl
validRRs = append(validRRs, rr)
}
return validRRs
}
// 更新所有部分的 TTL 并移除过期 RRSIG
resp.Answer = updateRRs(resp.Answer)
resp.Ns = updateRRs(resp.Ns)
// Extra 部分需要特殊处理 OPT RR
var validExtra []dns.RR
var reqOpt *dns.OPT
if reqOpt = req.IsEdns0(); reqOpt != nil {
// 客户端有 EDNS0获取其 UDP size
}
for _, rr := range resp.Extra {
if opt, ok := rr.(*dns.OPT); ok {
// 处理 OPT RR
if reqOpt != nil {
// 使用客户端请求的 UDP size
opt.SetUDPSize(reqOpt.UDPSize())
}
// 清除 ECS Scope Length
for i, option := range opt.Option {
if ecs, ok := option.(*dns.EDNS0_SUBNET); ok {
// 将 Scope Length 设为 0表示这是缓存答案
ecs.SourceScope = 0
opt.Option[i] = ecs
}
}
validExtra = append(validExtra, opt)
} else {
// 非 OPT RR正常更新 TTL 和检查 RRSIG
header := rr.Header()
if header != nil {
if rrsig, ok := rr.(*dns.RRSIG); ok {
if rrsig.Expiration > 0 && uint32(now) > rrsig.Expiration {
continue // 跳过过期的 RRSIG
}
}
header.Ttl = ttl
}
validExtra = append(validExtra, rr)
}
}
resp.Extra = validExtra
// SetReply 会设置正确的 Message ID 和其他响应标志
return resp.SetReply(req)
}

120
internal/model/config.go Normal file
View File

@@ -0,0 +1,120 @@
package model
import (
"encoding/json"
"net"
"os"
"godns/pkg/logger"
"godns/pkg/utils"
"github.com/pkg/errors"
"github.com/yl2chen/cidranger"
"golang.org/x/net/proxy"
)
const (
_ = iota
StrategyFullest
StrategyFastest
StrategyAnyResult
)
type DohServerConfig struct {
Username string `json:"username,omitempty"` // DoH Basic Auth 用户名(可选)
Password string `json:"password,omitempty"` // DoH Basic Auth 密码(可选)
}
type WebAuth struct {
Username string `json:"username"`
Password string `json:"password"`
}
type Config struct {
ServeAddr string `json:"serve_addr,omitempty"`
WebAddr string `json:"web_addr,omitempty"`
DohServer *DohServerConfig `json:"doh_server,omitempty"`
Strategy int `json:"strategy,omitempty"`
Timeout int `json:"timeout,omitempty"`
SocksProxy string `json:"socks_proxy,omitempty"`
BuiltInCache bool `json:"built_in_cache,omitempty"`
Upstreams []*Upstream `json:"upstreams,omitempty"`
Bootstrap []*Upstream `json:"bootstrap,omitempty"`
Blacklist []string `json:"blacklist,omitempty"`
Debug bool `json:"debug,omitempty"`
Profiling bool `json:"profiling,omitempty"`
// Connection pool settings
MaxActiveConnections int `json:"max_active_connections,omitempty"` // Default: 50
MaxIdleConnections int `json:"max_idle_connections,omitempty"` // Default: 20
// Stats persistence interval in minutes
StatsSaveInterval int `json:"stats_save_interval,omitempty"` // Default: 5 minutes
BlacklistSplited [][]string `json:"-"`
// Web 面板鉴权
WebAuth *WebAuth `json:"web_auth,omitempty"`
}
func (c *Config) ReadInConfig(path string, ipRanger cidranger.Ranger, log logger.Logger) error {
body, err := os.ReadFile(path)
if err != nil {
return err
}
if err := json.Unmarshal([]byte(body), c); err != nil {
return err
}
// Set default connection pool values
if c.MaxActiveConnections == 0 {
c.MaxActiveConnections = 50
}
if c.MaxIdleConnections == 0 {
c.MaxIdleConnections = 20
}
// Set default stats save interval (5 minutes)
if c.StatsSaveInterval == 0 {
c.StatsSaveInterval = 5
}
for i := 0; i < len(c.Bootstrap); i++ {
c.Bootstrap[i].Init(c, ipRanger, log)
if net.ParseIP(c.Bootstrap[i].host) == nil {
return errors.New("Bootstrap 服务器只能使用 IP: " + c.Bootstrap[i].Address)
}
c.Bootstrap[i].InitConnectionPool(nil)
}
for i := 0; i < len(c.Upstreams); i++ {
c.Upstreams[i].Init(c, ipRanger, log)
if err := c.Upstreams[i].Validate(); err != nil {
return err
}
}
c.BlacklistSplited = utils.ParseRules(c.Blacklist)
return nil
}
func (c *Config) GetDialerContext(d *net.Dialer) (proxy.Dialer, proxy.ContextDialer, error) {
dialSocksProxy, err := proxy.SOCKS5("tcp", c.SocksProxy, nil, d)
if err != nil {
return nil, nil, errors.Wrap(err, "Error creating SOCKS5 proxy")
}
if dialContext, ok := dialSocksProxy.(proxy.ContextDialer); !ok {
return nil, nil, errors.New("Failed type assertion to DialContext")
} else {
return dialSocksProxy, dialContext, err
}
}
func (c *Config) StrategyName() string {
switch c.Strategy {
case StrategyFullest:
return "最全结果"
case StrategyFastest:
return "最快结果"
case StrategyAnyResult:
return "任一结果(建议仅 bootstrap"
}
panic("invalid strategy")
}

282
internal/model/upstream.go Normal file
View File

@@ -0,0 +1,282 @@
package model
import (
"crypto/tls"
"fmt"
"net"
"runtime"
"strings"
"time"
"github.com/dropbox/godropbox/net2"
"github.com/miekg/dns"
"github.com/pkg/errors"
"github.com/yl2chen/cidranger"
"go.uber.org/atomic"
"godns/pkg/doh"
"godns/pkg/logger"
"godns/pkg/utils"
)
type Upstream struct {
IsPrimary bool `json:"is_primary,omitempty"`
UseSocks bool `json:"use_socks,omitempty"`
Address string `json:"address,omitempty"`
Match []string `json:"match,omitempty"`
protocol, hostAndPort, host, port string
config *Config
ipRanger cidranger.Ranger
matchSplited [][]string
pool net2.ConnectionPool
dohClient *doh.Client
bootstrap func(host string) (net.IP, error)
logger logger.Logger
count *atomic.Int64
}
func (up *Upstream) Init(config *Config, ipRanger cidranger.Ranger, log logger.Logger) {
var ok bool
up.protocol, up.hostAndPort, ok = strings.Cut(up.Address, "://")
if ok && up.protocol != "https" {
up.host, up.port, ok = strings.Cut(up.hostAndPort, ":")
}
if !ok {
panic("上游地址格式(protocol://host:port)有误:" + up.Address)
}
if up.count != nil {
panic("Upstream 已经初始化过了:" + up.Address)
}
up.matchSplited = utils.ParseRules(up.Match)
up.count = atomic.NewInt64(0)
up.config = config
up.ipRanger = ipRanger
up.logger = log
}
// SetLogger 更新 upstream 的 logger 实例
func (up *Upstream) SetLogger(log logger.Logger) {
up.logger = log
}
func (up *Upstream) IsMatch(domain string) bool {
return utils.HasMatchedRule(up.matchSplited, domain)
}
func (up *Upstream) Validate() error {
if !up.IsPrimary && up.protocol == "udp" {
return errors.New("非 primary 只能使用 tcp(-tls)/https" + up.Address)
}
if up.IsPrimary && up.UseSocks {
return errors.New("primary 无需接入 socks" + up.Address)
}
if up.UseSocks && up.config.SocksProxy == "" {
return errors.New("socks 未配置,但是上游已启用:" + up.Address)
}
if up.IsPrimary && up.protocol != "udp" {
up.logger.Println("[WARN] Primary 建议使用 udp 加速获取结果:" + up.Address)
}
return nil
}
func (up *Upstream) conntionFactory(network, address string) (net.Conn, error) {
up.logger.Printf("connecting to %s://%s", network, address)
host, port, err := net.SplitHostPort(address)
if err != nil {
return nil, err
}
if up.bootstrap != nil && net.ParseIP(host) == nil {
ip, err := up.bootstrap(host)
if err != nil {
address = fmt.Sprintf("%s:%s", "0.0.0.0", port)
} else {
address = fmt.Sprintf("%s:%s", ip.String(), port)
}
}
if up.UseSocks {
d, _, err := up.config.GetDialerContext(&net.Dialer{
Timeout: time.Second * time.Duration(up.config.Timeout),
})
if err != nil {
return nil, err
}
switch network {
case "tcp":
return d.Dial(network, address)
case "tcp-tls":
conn, err := d.Dial("tcp", address)
if err != nil {
return nil, err
}
return tls.Client(conn, &tls.Config{
ServerName: host,
}), nil
}
} else {
var d net.Dialer
d.Timeout = time.Second * time.Duration(up.config.Timeout)
switch network {
case "tcp":
return d.Dial(network, address)
case "tcp-tls":
return tls.DialWithDialer(&d, "tcp", address, &tls.Config{
ServerName: host,
})
}
}
panic("wrong protocol: " + network)
}
func (up *Upstream) InitConnectionPool(bootstrap func(host string) (net.IP, error)) {
up.bootstrap = bootstrap
if strings.Contains(up.protocol, "http") {
ops := []doh.ClientOption{
doh.WithServer(up.Address),
doh.WithBootstrap(bootstrap),
doh.WithTimeout(time.Second * time.Duration(up.config.Timeout)),
doh.WithLogger(up.logger),
}
if up.UseSocks {
ops = append(ops, doh.WithSocksProxy(up.config.GetDialerContext))
}
up.dohClient = doh.NewClient(ops...)
}
// 只需要启用 tcp/tcp-tls 协议的连接池
if strings.Contains(up.protocol, "tcp") {
maxIdleTime := time.Second * time.Duration(up.config.Timeout*10)
timeout := time.Second * time.Duration(up.config.Timeout)
p := net2.NewSimpleConnectionPool(net2.ConnectionOptions{
MaxActiveConnections: int32(up.config.MaxActiveConnections),
MaxIdleConnections: uint32(up.config.MaxIdleConnections),
MaxIdleTime: &maxIdleTime,
DialMaxConcurrency: 20,
ReadTimeout: timeout,
WriteTimeout: timeout,
Dial: func(network, address string) (net.Conn, error) {
dialer, err := up.conntionFactory(network, address)
if err != nil {
return nil, err
}
dialer.SetDeadline(time.Now().Add(timeout))
return dialer, nil
},
})
p.Register(up.protocol, up.hostAndPort)
up.pool = p
}
}
func (up *Upstream) IsValidMsg(r *dns.Msg) bool {
domain := GetDomainNameFromDnsMsg(r)
inBlacklist := utils.HasMatchedRule(up.config.BlacklistSplited, domain)
for i := 0; i < len(r.Answer); i++ {
var ip net.IP
typeA, ok := r.Answer[i].(*dns.A)
if ok {
ip = typeA.A
} else {
typeAAAA, ok := r.Answer[i].(*dns.AAAA)
if !ok {
continue
}
ip = typeAAAA.AAAA
}
isPrimary, err := up.ipRanger.Contains(ip)
if err != nil {
up.logger.Printf("ipRanger query ip %s failed: %s", ip, err)
continue
}
up.logger.Printf("checkPrimary result %s: %s@%s ->domain.inBlacklist:%v ip.IsPrimary:%v up.IsPrimary:%v", up.Address, domain, ip, inBlacklist, isPrimary, up.IsPrimary)
// 黑名单中的域名,如果是 primary 即不可用
if inBlacklist && isPrimary {
return false
}
// 如果是 server 是 primary但是 ip 不是 primary也不可用
if up.IsPrimary && !isPrimary {
return false
}
}
return !up.IsPrimary || len(r.Answer) > 0
}
func GetDomainNameFromDnsMsg(msg *dns.Msg) string {
if msg == nil || len(msg.Question) == 0 {
return ""
}
return msg.Question[0].Name
}
func (up *Upstream) poolLen() int32 {
if up.pool == nil {
return 0
}
return up.pool.NumActive()
}
func (up *Upstream) Exchange(req *dns.Msg) (*dns.Msg, time.Duration, error) {
up.logger.Printf("tracing exchange %s worker_count: %d pool_count: %d go_routine: %d --> %s", up.Address, up.count.Inc(), up.poolLen(), runtime.NumGoroutine(), "enter")
defer up.logger.Printf("tracing exchange %s worker_count: %d pool_count: %d go_routine: %d --> %s", up.Address, up.count.Dec(), up.poolLen(), runtime.NumGoroutine(), "exit")
var resp *dns.Msg
var duration time.Duration
var err error
switch up.protocol {
case "https", "http":
resp, duration, err = up.dohClient.Exchange(req)
case "udp":
client := new(dns.Client)
client.Timeout = time.Second * time.Duration(up.config.Timeout)
resp, duration, err = client.Exchange(req, up.hostAndPort)
case "tcp", "tcp-tls":
conn, errGetConn := up.pool.Get(up.protocol, up.hostAndPort)
if errGetConn != nil {
return nil, 0, errGetConn
}
resp, err = dnsExchangeWithConn(conn, req)
default:
panic(fmt.Sprintf("invalid upstream protocol: %s in address %s", up.protocol, up.Address))
}
// 清理 EDNS 信息
if resp != nil && len(resp.Extra) > 0 {
var newExtra []dns.RR
for i := 0; i < len(resp.Extra); i++ {
if resp.Extra[i].Header().Rrtype == dns.TypeOPT {
continue
}
newExtra = append(newExtra, resp.Extra[i])
}
resp.Extra = newExtra
}
return resp, duration, err
}
func dnsExchangeWithConn(conn net2.ManagedConn, req *dns.Msg) (*dns.Msg, error) {
var resp *dns.Msg
co := dns.Conn{Conn: conn}
err := co.WriteMsg(req)
if err == nil {
resp, err = co.ReadMsg()
}
if err == nil {
conn.ReleaseConnection()
} else {
conn.DiscardConnection()
}
return resp, err
}

View File

@@ -0,0 +1,120 @@
package model
import (
"index/suffixarray"
"strings"
"testing"
"godns/pkg/utils"
)
var primaryLocations = []string{"中国", "省", "市", "自治区"}
var nonPrimaryLocations = []string{"台湾", "香港", "澳门"}
var primaryLocationsBytes = [][]byte{[]byte("中国"), []byte("省"), []byte("市"), []byte("自治区")}
var nonPrimaryLocationsBytes = [][]byte{[]byte("台湾"), []byte("香港"), []byte("澳门")}
func BenchmarkCheckPrimary(b *testing.B) {
for i := 0; i < b.N; i++ {
checkPrimary("哈哈")
}
}
func BenchmarkCheckPrimaryStringsContains(b *testing.B) {
for i := 0; i < b.N; i++ {
checkPrimaryStringsContains("哈哈")
}
}
func TestIsMatch(t *testing.T) {
var up Upstream
up.matchSplited = utils.ParseRules([]string{"."})
checkUpstreamMatch(&up, map[string]bool{
"": false,
"a.com.": true,
"b.a.com.": true,
".b.a.com.cn.": true,
"b.a.com.cn.": true,
"d.b.a.com.": true,
}, t)
up.matchSplited = utils.ParseRules([]string{""})
checkUpstreamMatch(&up, map[string]bool{
"": false,
"a.com.": false,
"b.a.com.": false,
".b.a.com.cn.": false,
"b.a.com.cn.": false,
"d.b.a.com.": false,
}, t)
up.matchSplited = utils.ParseRules([]string{"a.com."})
checkUpstreamMatch(&up, map[string]bool{
"": false,
"a.com.": true,
"b.a.com.": false,
".b.a.com.cn.": false,
"b.a.com.cn.": false,
"d.b.a.com.": false,
}, t)
up.matchSplited = utils.ParseRules([]string{".a.com."})
checkUpstreamMatch(&up, map[string]bool{
"": false,
"a.com.": false,
"b.a.com.": true,
".b.a.com.cn.": false,
"b.a.com.cn.": false,
"d.b.a.com.": true,
}, t)
up.matchSplited = utils.ParseRules([]string{"b.d.com."})
checkUpstreamMatch(&up, map[string]bool{
"": false,
"a.com.": false,
".a.com.": false,
"b.d.com.": true,
".b.d.com.cn.": false,
"b.d.com.cn.": false,
".c.d.com.": false,
"b.d.a.com.": false,
}, t)
}
func checkUpstreamMatch(up *Upstream, cases map[string]bool, t *testing.T) {
for k, v := range cases {
isMatch := up.IsMatch(k)
if isMatch != v {
t.Errorf("Upstream(%s).IsMatch(%s) = %v, want %v", up.matchSplited, k, isMatch, v)
}
}
}
func checkPrimary(str string) bool {
index := suffixarray.New([]byte(str))
for i := 0; i < len(nonPrimaryLocationsBytes); i++ {
if len(index.Lookup(nonPrimaryLocationsBytes[i], 1)) > 0 {
return false
}
}
for i := 0; i < len(primaryLocationsBytes); i++ {
if len(index.Lookup(primaryLocationsBytes[i], 1)) > 0 {
return true
}
}
return false
}
func checkPrimaryStringsContains(str string) bool {
for i := 0; i < len(nonPrimaryLocations); i++ {
if strings.Contains(str, nonPrimaryLocations[i]) {
return false
}
}
for i := 0; i < len(primaryLocations); i++ {
if strings.Contains(str, primaryLocations[i]) {
return true
}
}
return false
}

649
internal/stats/stats.go Normal file
View File

@@ -0,0 +1,649 @@
package stats
import (
"encoding/json"
"os"
"path/filepath"
"runtime"
"sort"
"sync"
"sync/atomic"
"time"
)
// StatsRecorder 定义统计接口
type StatsRecorder interface {
RecordQuery()
RecordDoHQuery()
RecordCacheHit()
RecordCacheMiss()
RecordFailed()
RecordUpstreamQuery(address string, isError bool)
RecordClientQuery(clientIP, domain string)
GetSnapshot() StatsSnapshot
Reset()
Save(dataPath string) error
Load(dataPath string) error
}
// Stats DNS服务器统计信息
type Stats struct {
StartTime time.Time // 应用启动时间(不持久化)
StatsStartTime time.Time // 统计数据开始时间(可持久化)
// 查询统计
TotalQueries atomic.Uint64
DoHQueries atomic.Uint64
CacheHits atomic.Uint64
CacheMisses atomic.Uint64
FailedQueries atomic.Uint64
// 上游服务器统计
upstreamStats map[string]*UpstreamStats
mu sync.RWMutex
// Top N 统计
topClients *TopNTracker // 客户端 IP Top N
topDomains *TopNTracker // 查询域名 Top N
}
// UpstreamStats 上游服务器统计
type UpstreamStats struct {
Address string
TotalQueries atomic.Uint64
Errors atomic.Uint64
LastUsed time.Time
mu sync.RWMutex
}
// NewStats 创建统计实例
func NewStats() *Stats {
now := time.Now()
return &Stats{
StartTime: now,
StatsStartTime: now,
upstreamStats: make(map[string]*UpstreamStats),
topClients: NewTopNTracker(100), // 最多保留 100 个客户端 IP
topDomains: NewTopNTracker(200), // 最多保留 200 个域名
}
}
// RecordQuery 记录DNS查询
func (s *Stats) RecordQuery() {
s.TotalQueries.Add(1)
}
// RecordDoHQuery 记录DoH查询
func (s *Stats) RecordDoHQuery() {
s.DoHQueries.Add(1)
}
// RecordCacheHit 记录缓存命中
func (s *Stats) RecordCacheHit() {
s.CacheHits.Add(1)
}
// RecordCacheMiss 记录缓存未命中
func (s *Stats) RecordCacheMiss() {
s.CacheMisses.Add(1)
}
// RecordFailed 记录查询失败
func (s *Stats) RecordFailed() {
s.FailedQueries.Add(1)
}
// RecordUpstreamQuery 记录上游服务器查询
func (s *Stats) RecordUpstreamQuery(address string, isError bool) {
// 先尝试读锁快速查找
s.mu.RLock()
us, ok := s.upstreamStats[address]
s.mu.RUnlock()
// 如果不存在才使用写锁创建
if !ok {
s.mu.Lock()
// 双重检查,防止并发创建
us, ok = s.upstreamStats[address]
if !ok {
us = &UpstreamStats{
Address: address,
}
s.upstreamStats[address] = us
}
s.mu.Unlock()
}
us.TotalQueries.Add(1)
if isError {
us.Errors.Add(1)
}
us.mu.Lock()
us.LastUsed = time.Now()
us.mu.Unlock()
}
// RecordClientQuery 记录客户端查询IP 和域名)
func (s *Stats) RecordClientQuery(clientIP, domain string) {
if clientIP != "" {
s.topClients.Record(clientIP, "")
}
if domain != "" {
s.topDomains.Record(domain, clientIP)
}
}
// Reset 重置统计数据
func (s *Stats) Reset() {
s.mu.Lock()
defer s.mu.Unlock()
// 重置统计开始时间
s.StatsStartTime = time.Now()
// 重置查询统计
s.TotalQueries.Store(0)
s.DoHQueries.Store(0)
s.CacheHits.Store(0)
s.CacheMisses.Store(0)
s.FailedQueries.Store(0)
// 重置上游服务器统计
s.upstreamStats = make(map[string]*UpstreamStats)
// 重置 Top N 统计
s.topClients = NewTopNTracker(100)
s.topDomains = NewTopNTracker(200)
}
// RuntimeStats 运行时统计信息
type RuntimeStats struct {
Uptime int64 `json:"uptime"` // 运行时间(秒)
UptimeStr string `json:"uptime_str"` // 运行时间(可读格式)
StatsDuration int64 `json:"stats_duration"` // 统计时长(秒)
StatsDurationStr string `json:"stats_duration_str"` // 统计时长(可读格式)
Goroutines int `json:"goroutines"` // Goroutine数量
MemAllocMB uint64 `json:"mem_alloc_mb"` // 已分配内存MB
MemTotalMB uint64 `json:"mem_total_mb"` // 总分配内存MB
MemSysMB uint64 `json:"mem_sys_mb"` // 系统内存MB
NumGC uint32 `json:"num_gc"` // GC次数
}
// QueryStats 查询统计信息
type QueryStats struct {
Total uint64 `json:"total"` // 总查询数
DoH uint64 `json:"doh"` // DoH查询数
CacheHits uint64 `json:"cache_hits"` // 缓存命中数
CacheMisses uint64 `json:"cache_misses"` // 缓存未命中数
Failed uint64 `json:"failed"` // 失败查询数
HitRate float64 `json:"hit_rate"` // 缓存命中率
}
// UpstreamStatsJSON 上游服务器统计JSON格式
type UpstreamStatsJSON struct {
Address string `json:"address"` // 服务器地址
TotalQueries uint64 `json:"total_queries"` // 总查询数
Errors uint64 `json:"errors"` // 错误数
ErrorRate float64 `json:"error_rate"` // 错误率
LastUsed string `json:"last_used"` // 最后使用时间
}
// TopNItemJSON Top N 项目JSON格式
type TopNItemJSON struct {
Key string `json:"key"` // IP 地址或域名
Count uint64 `json:"count"` // 查询次数
TopClient string `json:"top_client,omitempty"` // 查询最多的客户端 IP仅域名统计有
}
// StatsSnapshot 完整统计快照
type StatsSnapshot struct {
Runtime RuntimeStats `json:"runtime"` // 运行时信息
Queries QueryStats `json:"queries"` // 查询统计
Upstreams []UpstreamStatsJSON `json:"upstreams"` // 上游服务器统计
TopClients []TopNItemJSON `json:"top_clients"` // Top 客户端 IP
TopDomains []TopNItemJSON `json:"top_domains"` // Top 查询域名
}
// GetSnapshot 获取统计快照
func (s *Stats) GetSnapshot() StatsSnapshot {
// 运行时信息
var m runtime.MemStats
runtime.ReadMemStats(&m)
uptime := time.Since(s.StartTime)
uptimeStr := formatDuration(uptime)
statsDuration := time.Since(s.StatsStartTime)
statsDurationStr := formatDuration(statsDuration)
runtimeStats := RuntimeStats{
Uptime: int64(uptime.Seconds()),
UptimeStr: uptimeStr,
StatsDuration: int64(statsDuration.Seconds()),
StatsDurationStr: statsDurationStr,
Goroutines: runtime.NumGoroutine(),
MemAllocMB: m.Alloc / 1024 / 1024,
MemTotalMB: m.TotalAlloc / 1024 / 1024,
MemSysMB: m.Sys / 1024 / 1024,
NumGC: m.NumGC,
}
// 查询统计
total := s.TotalQueries.Load()
hits := s.CacheHits.Load()
misses := s.CacheMisses.Load()
failed := s.FailedQueries.Load()
var hitRate float64
if total > 0 {
hitRate = float64(hits) / float64(total) * 100
}
queryStats := QueryStats{
Total: total,
DoH: s.DoHQueries.Load(),
CacheHits: hits,
CacheMisses: misses,
Failed: failed,
HitRate: hitRate,
}
// 上游服务器统计
s.mu.RLock()
upstreams := make([]UpstreamStatsJSON, 0, len(s.upstreamStats))
for _, us := range s.upstreamStats {
queries := us.TotalQueries.Load()
errors := us.Errors.Load()
var errorRate float64
if queries > 0 {
errorRate = float64(errors) / float64(queries) * 100
}
us.mu.RLock()
lastUsed := us.LastUsed.Format("2006-01-02 15:04:05")
if us.LastUsed.IsZero() {
lastUsed = "Never"
}
us.mu.RUnlock()
upstreams = append(upstreams, UpstreamStatsJSON{
Address: us.Address,
TotalQueries: queries,
Errors: errors,
ErrorRate: errorRate,
LastUsed: lastUsed,
})
}
s.mu.RUnlock()
// 按服务器地址字符串排序
sort.Slice(upstreams, func(i, j int) bool {
return upstreams[i].Address < upstreams[j].Address
})
// Top N 客户端 IP
topClients := make([]TopNItemJSON, 0)
for _, item := range s.topClients.GetTopN(20) { // 返回 Top 20
topClients = append(topClients, TopNItemJSON{
Key: item.Key,
Count: item.Count,
})
}
// Top N 查询域名
topDomains := make([]TopNItemJSON, 0)
for _, item := range s.topDomains.GetTopN(20) { // 返回 Top 20
topDomains = append(topDomains, TopNItemJSON{
Key: item.Key,
Count: item.Count,
TopClient: item.TopClient,
})
}
return StatsSnapshot{
Runtime: runtimeStats,
Queries: queryStats,
Upstreams: upstreams,
TopClients: topClients,
TopDomains: topDomains,
}
}
// formatDuration 格式化时长为可读格式
func formatDuration(d time.Duration) string {
days := int(d.Hours()) / 24
hours := int(d.Hours()) % 24
minutes := int(d.Minutes()) % 60
seconds := int(d.Seconds()) % 60
if days > 0 {
return formatString("%d天%d小时%d分钟", days, hours, minutes)
} else if hours > 0 {
return formatString("%d小时%d分钟%d秒", hours, minutes, seconds)
} else if minutes > 0 {
return formatString("%d分钟%d秒", minutes, seconds)
}
return formatString("%d秒", seconds)
}
// formatString 简单的字符串格式化
func formatString(format string, args ...interface{}) string {
result := format
for _, arg := range args {
switch v := arg.(type) {
case int:
result = replaceFirst(result, "%d", itoa(v))
}
}
return result
}
// replaceFirst 替换第一个匹配的字符串
func replaceFirst(s, old, new string) string {
for i := 0; i <= len(s)-len(old); i++ {
if s[i:i+len(old)] == old {
return s[:i] + new + s[i+len(old):]
}
}
return s
}
// itoa 整数转字符串
func itoa(i int) string {
if i == 0 {
return "0"
}
negative := i < 0
if negative {
i = -i
}
var buf [32]byte
pos := len(buf)
for i > 0 {
pos--
buf[pos] = byte('0' + i%10)
i /= 10
}
if negative {
pos--
buf[pos] = '-'
}
return string(buf[pos:])
}
// TopNTracker 追踪 Top N 项目,内存可控
type TopNTracker struct {
mu sync.RWMutex
items map[string]*TopNItem
maxItems int // 最大保留项目数
}
// TopNItem Top N 项目统计
type TopNItem struct {
Key string
Count uint64
TopClient string // 对于域名统计,记录查询最多的客户端 IP
clients map[string]uint64 // 临时记录客户端分布(仅用于找 Top1
}
// PersistentStats 持久化统计数据结构
type PersistentStats struct {
StatsStartTime time.Time `json:"stats_start_time"` // 统计开始时间(可持久化)
TotalQueries uint64 `json:"total_queries"`
DoHQueries uint64 `json:"doh_queries"`
CacheHits uint64 `json:"cache_hits"`
CacheMisses uint64 `json:"cache_misses"`
FailedQueries uint64 `json:"failed_queries"`
Upstreams map[string]*PersistentUpstream `json:"upstreams"`
TopClients []PersistentTopNItem `json:"top_clients"`
TopDomains []PersistentTopNItem `json:"top_domains"`
}
// PersistentUpstream 持久化上游服务器统计
type PersistentUpstream struct {
Address string `json:"address"`
TotalQueries uint64 `json:"total_queries"`
Errors uint64 `json:"errors"`
LastUsed time.Time `json:"last_used"`
}
// PersistentTopNItem 持久化 Top N 项目
type PersistentTopNItem struct {
Key string `json:"key"`
Count uint64 `json:"count"`
TopClient string `json:"top_client,omitempty"`
Clients map[string]uint64 `json:"clients,omitempty"`
}
// NewTopNTracker 创建 Top N 追踪器
func NewTopNTracker(maxItems int) *TopNTracker {
return &TopNTracker{
items: make(map[string]*TopNItem),
maxItems: maxItems,
}
}
// Record 记录一次访问(可选关联的客户端 IP
func (t *TopNTracker) Record(key, associatedClient string) {
t.mu.Lock()
defer t.mu.Unlock()
item, exists := t.items[key]
if !exists {
// 如果超过最大数量,删除计数最少的项
if len(t.items) >= t.maxItems {
t.evictLowest()
}
item = &TopNItem{
Key: key,
clients: make(map[string]uint64),
}
t.items[key] = item
}
item.Count++
// 如果有关联客户端,记录客户端分布
if associatedClient != "" {
item.clients[associatedClient]++
// 更新 Top1 客户端
if item.clients[associatedClient] > item.clients[item.TopClient] {
item.TopClient = associatedClient
}
}
}
// evictLowest 删除计数最少的项(不加锁,由调用者加锁)
func (t *TopNTracker) evictLowest() {
var minKey string
var minCount uint64 = ^uint64(0) // 最大值
for key, item := range t.items {
if item.Count < minCount {
minCount = item.Count
minKey = key
}
}
if minKey != "" {
delete(t.items, minKey)
}
}
// GetTopN 获取 Top N 列表
func (t *TopNTracker) GetTopN(n int) []TopNItem {
t.mu.RLock()
defer t.mu.RUnlock()
// 复制所有项
items := make([]TopNItem, 0, len(t.items))
for _, item := range t.items {
items = append(items, TopNItem{
Key: item.Key,
Count: item.Count,
TopClient: item.TopClient,
})
}
// 按查询次数降序排序
sort.Slice(items, func(i, j int) bool {
return items[i].Count > items[j].Count
})
// 返回前 N 项
if n > len(items) {
n = len(items)
}
return items[:n]
}
// Save 保存统计数据到 JSON 文件
func (s *Stats) Save(dataPath string) error {
s.mu.RLock()
defer s.mu.RUnlock()
// 准备持久化数据
persistent := PersistentStats{
StatsStartTime: s.StatsStartTime,
TotalQueries: s.TotalQueries.Load(),
DoHQueries: s.DoHQueries.Load(),
CacheHits: s.CacheHits.Load(),
CacheMisses: s.CacheMisses.Load(),
FailedQueries: s.FailedQueries.Load(),
Upstreams: make(map[string]*PersistentUpstream),
TopClients: make([]PersistentTopNItem, 0),
TopDomains: make([]PersistentTopNItem, 0),
}
// 保存上游服务器统计
for addr, us := range s.upstreamStats {
us.mu.RLock()
persistent.Upstreams[addr] = &PersistentUpstream{
Address: us.Address,
TotalQueries: us.TotalQueries.Load(),
Errors: us.Errors.Load(),
LastUsed: us.LastUsed,
}
us.mu.RUnlock()
}
// 保存 Top 客户端
s.topClients.mu.RLock()
for _, item := range s.topClients.items {
persistent.TopClients = append(persistent.TopClients, PersistentTopNItem{
Key: item.Key,
Count: item.Count,
TopClient: item.TopClient,
Clients: item.clients,
})
}
s.topClients.mu.RUnlock()
// 保存 Top 域名
s.topDomains.mu.RLock()
for _, item := range s.topDomains.items {
persistent.TopDomains = append(persistent.TopDomains, PersistentTopNItem{
Key: item.Key,
Count: item.Count,
TopClient: item.TopClient,
Clients: item.clients,
})
}
s.topDomains.mu.RUnlock()
// 序列化为 JSON
data, err := json.MarshalIndent(persistent, "", " ")
if err != nil {
return err
}
// 确保目录存在
statsPath := filepath.Join(dataPath, "cache")
if err := os.MkdirAll(statsPath, 0755); err != nil {
return err
}
// 写入文件
statsFile := filepath.Join(statsPath, "stats.json")
return os.WriteFile(statsFile, data, 0644)
}
// Load 从 JSON 文件加载统计数据
func (s *Stats) Load(dataPath string) error {
statsFile := filepath.Join(dataPath, "cache", "stats.json")
// 检查文件是否存在
if _, err := os.Stat(statsFile); os.IsNotExist(err) {
return nil // 文件不存在不是错误,返回 nil
}
// 读取文件
data, err := os.ReadFile(statsFile)
if err != nil {
return err
}
// 解析 JSON
var persistent PersistentStats
if err := json.Unmarshal(data, &persistent); err != nil {
return err
}
// 恢复统计数据
s.mu.Lock()
defer s.mu.Unlock()
// StartTime 保持为应用启动时间,不从磁盘恢复
// 只恢复 StatsStartTime统计数据开始时间
s.StatsStartTime = persistent.StatsStartTime
s.TotalQueries.Store(persistent.TotalQueries)
s.DoHQueries.Store(persistent.DoHQueries)
s.CacheHits.Store(persistent.CacheHits)
s.CacheMisses.Store(persistent.CacheMisses)
s.FailedQueries.Store(persistent.FailedQueries)
// 恢复上游服务器统计
for addr, pus := range persistent.Upstreams {
us := &UpstreamStats{
Address: pus.Address,
LastUsed: pus.LastUsed,
}
us.TotalQueries.Store(pus.TotalQueries)
us.Errors.Store(pus.Errors)
s.upstreamStats[addr] = us
}
// 恢复 Top 客户端
s.topClients.mu.Lock()
for _, pitem := range persistent.TopClients {
item := &TopNItem{
Key: pitem.Key,
Count: pitem.Count,
TopClient: pitem.TopClient,
clients: pitem.Clients,
}
if item.clients == nil {
item.clients = make(map[string]uint64)
}
s.topClients.items[pitem.Key] = item
}
s.topClients.mu.Unlock()
// 恢复 Top 域名
s.topDomains.mu.Lock()
for _, pitem := range persistent.TopDomains {
item := &TopNItem{
Key: pitem.Key,
Count: pitem.Count,
TopClient: pitem.TopClient,
clients: pitem.Clients,
}
if item.clients == nil {
item.clients = make(map[string]uint64)
}
s.topDomains.items[pitem.Key] = item
}
s.topDomains.mu.Unlock()
return nil
}

205
internal/web/handler.go Normal file
View File

@@ -0,0 +1,205 @@
package web
import (
"crypto/subtle"
"embed"
"encoding/json"
"io/fs"
"net/http"
"godns/internal/stats"
"godns/pkg/logger"
)
//go:embed static/*
var staticFiles embed.FS
// Handler Web服务处理器
type Handler struct {
stats stats.StatsRecorder
version string
checkUpdateCh chan<- struct{}
logger logger.Logger
username string
password string
}
// NewHandler 创建Web处理器
func NewHandler(s stats.StatsRecorder, ver string, checkCh chan<- struct{}, log logger.Logger, username, password string) *Handler {
return &Handler{
stats: s,
version: ver,
checkUpdateCh: checkCh,
logger: log,
username: username,
password: password,
}
}
// basicAuth 中间件
func (h *Handler) basicAuth(next http.HandlerFunc) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
// 如果未配置鉴权,直接放行
if h.username == "" || h.password == "" {
next(w, r)
return
}
user, pass, ok := r.BasicAuth()
if !ok || subtle.ConstantTimeCompare([]byte(user), []byte(h.username)) != 1 ||
subtle.ConstantTimeCompare([]byte(pass), []byte(h.password)) != 1 {
w.Header().Set("WWW-Authenticate", `Basic realm="NBDNS Monitor"`)
http.Error(w, "Unauthorized", http.StatusUnauthorized)
return
}
next(w, r)
}
}
// RegisterRoutes 注册路由
func (h *Handler) RegisterRoutes(mux *http.ServeMux) {
// API路由
mux.HandleFunc("/api/stats", h.basicAuth(h.handleStats))
mux.HandleFunc("/api/version", h.basicAuth(h.handleVersion))
mux.HandleFunc("/api/check-update", h.basicAuth(h.handleCheckUpdate))
mux.HandleFunc("/api/stats/reset", h.basicAuth(h.handleStatsReset))
// 静态文件服务
staticFS, err := fs.Sub(staticFiles, "static")
if err != nil {
h.logger.Printf("Failed to load static files: %v", err)
return
}
mux.Handle("/", h.basicAuth(func(w http.ResponseWriter, r *http.Request) {
http.FileServer(http.FS(staticFS)).ServeHTTP(w, r)
}))
}
// handleStats 处理统计信息请求
func (h *Handler) handleStats(w http.ResponseWriter, r *http.Request) {
// 只允许GET请求
if r.Method != http.MethodGet {
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
return
}
// 获取统计快照
snapshot := h.stats.GetSnapshot()
// 设置响应头
w.Header().Set("Content-Type", "application/json")
w.Header().Set("Access-Control-Allow-Origin", "*")
w.Header().Set("Cache-Control", "no-cache, no-store, must-revalidate")
// 编码JSON并返回
if err := json.NewEncoder(w).Encode(snapshot); err != nil {
h.logger.Printf("Error encoding stats JSON: %v", err)
http.Error(w, "Internal server error", http.StatusInternalServerError)
return
}
}
// ResetResponse 重置响应
type ResetResponse struct {
Success bool `json:"success"`
Message string `json:"message"`
}
// handleStatsReset 处理统计数据重置请求
func (h *Handler) handleStatsReset(w http.ResponseWriter, r *http.Request) {
// 只允许POST请求
if r.Method != http.MethodPost {
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
return
}
// 重置统计数据
h.stats.Reset()
h.logger.Printf("Statistics reset by user request")
// 设置响应头
w.Header().Set("Content-Type", "application/json")
w.Header().Set("Access-Control-Allow-Origin", "*")
// 返回成功响应
if err := json.NewEncoder(w).Encode(ResetResponse{
Success: true,
Message: "统计数据已重置",
}); err != nil {
h.logger.Printf("Error encoding reset response JSON: %v", err)
http.Error(w, "Internal server error", http.StatusInternalServerError)
return
}
}
// VersionResponse 版本信息响应
type VersionResponse struct {
Version string `json:"version"`
}
// handleVersion 处理版本查询请求
func (h *Handler) handleVersion(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodGet {
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
return
}
ver := h.version
if ver == "" {
ver = "0.0.0"
}
w.Header().Set("Content-Type", "application/json")
w.Header().Set("Access-Control-Allow-Origin", "*")
w.Header().Set("Cache-Control", "no-cache, no-store, must-revalidate")
if err := json.NewEncoder(w).Encode(VersionResponse{Version: ver}); err != nil {
h.logger.Printf("Error encoding version JSON: %v", err)
http.Error(w, "Internal server error", http.StatusInternalServerError)
return
}
}
// UpdateCheckResponse 更新检查响应
type UpdateCheckResponse struct {
HasUpdate bool `json:"has_update"`
CurrentVersion string `json:"current_version"`
LatestVersion string `json:"latest_version"`
Message string `json:"message"`
}
// handleCheckUpdate 处理检查更新请求生产者2用户手动触发
func (h *Handler) handleCheckUpdate(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodGet {
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
return
}
ver := h.version
if ver == "" {
ver = "0.0.0"
}
w.Header().Set("Content-Type", "application/json")
w.Header().Set("Access-Control-Allow-Origin", "*")
w.Header().Set("Cache-Control", "no-cache, no-store, must-revalidate")
// 触发后台检查更新(非阻塞)
select {
case h.checkUpdateCh <- struct{}{}:
h.logger.Printf("Update check triggered by user")
json.NewEncoder(w).Encode(UpdateCheckResponse{
HasUpdate: false,
CurrentVersion: ver,
LatestVersion: ver,
Message: "已触发更新检查,请查看服务器日志",
})
default:
// 如果通道已满,说明已经在检查中
json.NewEncoder(w).Encode(UpdateCheckResponse{
HasUpdate: false,
CurrentVersion: ver,
LatestVersion: ver,
Message: "更新检查正在进行中",
})
}
}

307
internal/web/static/app.js Normal file
View File

@@ -0,0 +1,307 @@
// 自动刷新间隔(毫秒)
const REFRESH_INTERVAL = 3000;
let refreshTimer = null;
let countdownTimer = null;
let countdown = 0;
let isCheckingUpdate = false;
let isResettingStats = false;
// 格式化数字,添加千位分隔符
function formatNumber(num) {
return num.toString().replace(/\B(?=(\d{3})+(?!\d))/g, ",");
}
// 格式化百分比
function formatPercent(num) {
return num.toFixed(2) + '%';
}
// 更新运行时信息
function updateRuntimeStats(runtime) {
document.getElementById('uptime').textContent = runtime.uptime_str || '-';
document.getElementById('goroutines').textContent = formatNumber(runtime.goroutines || 0);
document.getElementById('mem-alloc').textContent = formatNumber(runtime.mem_alloc_mb || 0) + ' MB';
document.getElementById('mem-sys').textContent = formatNumber(runtime.mem_sys_mb || 0) + ' MB';
document.getElementById('mem-total').textContent = formatNumber(runtime.mem_total_mb || 0) + ' MB';
document.getElementById('num-gc').textContent = formatNumber(runtime.num_gc || 0);
// 更新统计时长
const statsDuration = runtime.stats_duration_str || '-';
document.getElementById('stats-duration').textContent = '统计时长: ' + statsDuration;
}
// 更新查询统计
function updateQueryStats(queries) {
document.getElementById('total-queries').textContent = formatNumber(queries.total || 0);
document.getElementById('doh-queries').textContent = formatNumber(queries.doh || 0);
document.getElementById('cache-hits').textContent = formatNumber(queries.cache_hits || 0);
document.getElementById('cache-misses').textContent = formatNumber(queries.cache_misses || 0);
document.getElementById('failed-queries').textContent = formatNumber(queries.failed || 0);
document.getElementById('hit-rate').textContent = formatPercent(queries.hit_rate || 0);
}
// 更新上游服务器表格
function updateUpstreamTable(upstreams) {
const tbody = document.getElementById('upstream-tbody');
if (!upstreams || upstreams.length === 0) {
tbody.innerHTML = '<tr><td colspan="5" class="no-data">暂无数据</td></tr>';
return;
}
let html = '';
upstreams.forEach(upstream => {
const errorClass = upstream.error_rate > 10 ? 'error-high' : '';
html += `
<tr>
<td>${upstream.address || '-'}</td>
<td>${formatNumber(upstream.total_queries || 0)}</td>
<td class="${errorClass}">${formatNumber(upstream.errors || 0)}</td>
<td class="${errorClass}">${formatPercent(upstream.error_rate || 0)}</td>
<td>${upstream.last_used || 'Never'}</td>
</tr>
`;
});
tbody.innerHTML = html;
}
// 更新 Top 客户端 IP 表格
function updateTopClientsTable(topClients) {
const tbody = document.getElementById('top-clients-tbody');
if (!topClients || topClients.length === 0) {
tbody.innerHTML = '<tr><td colspan="3" class="no-data">暂无数据</td></tr>';
return;
}
let html = '';
topClients.forEach((client, index) => {
const rankClass = index < 3 ? `rank-${index + 1}` : '';
html += `
<tr class="${rankClass}">
<td class="rank-cell">${index + 1}</td>
<td>${client.key || '-'}</td>
<td>${formatNumber(client.count || 0)}</td>
</tr>
`;
});
tbody.innerHTML = html;
}
// 更新 Top 查询域名表格
function updateTopDomainsTable(topDomains) {
const tbody = document.getElementById('top-domains-tbody');
if (!topDomains || topDomains.length === 0) {
tbody.innerHTML = '<tr><td colspan="4" class="no-data">暂无数据</td></tr>';
return;
}
let html = '';
topDomains.forEach((domain, index) => {
const rankClass = index < 3 ? `rank-${index + 1}` : '';
const topClient = domain.top_client || '-';
html += `
<tr class="${rankClass}">
<td class="rank-cell">${index + 1}</td>
<td class="domain-cell" title="${domain.key}">${domain.key || '-'}</td>
<td>${formatNumber(domain.count || 0)}</td>
<td>${topClient}</td>
</tr>
`;
});
tbody.innerHTML = html;
}
// 更新倒计时显示
function updateCountdown() {
countdown--;
if (countdown <= 0) {
countdown = 0;
}
document.getElementById('last-update').textContent = `下次刷新: ${countdown}`;
}
// 重置倒计时
function resetCountdown() {
countdown = REFRESH_INTERVAL / 1000;
if (countdownTimer) {
clearInterval(countdownTimer);
}
countdownTimer = setInterval(updateCountdown, 1000);
updateCountdown();
}
// 加载统计数据
async function loadStats() {
try {
const response = await fetch('/api/stats');
if (!response.ok) {
throw new Error('获取统计数据失败');
}
const data = await response.json();
// 更新各部分数据
updateRuntimeStats(data.runtime);
updateQueryStats(data.queries);
updateUpstreamTable(data.upstreams);
updateTopClientsTable(data.top_clients);
updateTopDomainsTable(data.top_domains);
// 重置倒计时
resetCountdown();
} catch (error) {
console.error('加载统计数据出错:', error);
document.getElementById('last-update').textContent = '加载失败';
}
}
// 启动自动刷新
function startAutoRefresh() {
if (refreshTimer) {
clearInterval(refreshTimer);
}
refreshTimer = setInterval(loadStats, REFRESH_INTERVAL);
}
// 停止自动刷新
function stopAutoRefresh() {
if (refreshTimer) {
clearInterval(refreshTimer);
refreshTimer = null;
}
if (countdownTimer) {
clearInterval(countdownTimer);
countdownTimer = null;
}
}
// 加载版本号
async function loadVersion() {
try {
const response = await fetch('/api/version');
if (!response.ok) {
throw new Error('获取版本号失败');
}
const data = await response.json();
document.getElementById('version-display').textContent = 'v' + data.version;
} catch (error) {
console.error('加载版本号出错:', error);
document.getElementById('version-display').textContent = 'v0.0.0';
}
}
// 检查更新
async function checkUpdate() {
if (isCheckingUpdate) {
return;
}
const btn = document.getElementById('check-update-btn');
const originalText = btn.textContent;
try {
isCheckingUpdate = true;
btn.textContent = '⏳';
btn.disabled = true;
const response = await fetch('/api/check-update');
if (!response.ok) {
throw new Error('检查更新失败');
}
const data = await response.json();
if (data.has_update) {
alert(`${data.message}\n当前版本: v${data.current_version}\n最新版本: v${data.latest_version}\n\n请访问 GitHub 下载最新版本`);
} else {
alert(`${data.message}\n当前版本: v${data.current_version}`);
}
} catch (error) {
console.error('检查更新出错:', error);
alert('检查更新失败,请稍后再试');
} finally {
isCheckingUpdate = false;
btn.textContent = originalText;
btn.disabled = false;
}
}
// 重置统计数据
async function resetStats() {
if (isResettingStats) {
return;
}
// 确认对话框
if (!confirm('确定要重置所有统计数据吗?此操作无法撤销。')) {
return;
}
const btn = document.getElementById('reset-stats-btn');
const originalText = btn.textContent;
try {
isResettingStats = true;
btn.textContent = '⏳ 重置中...';
btn.disabled = true;
const response = await fetch('/api/stats/reset', {
method: 'POST'
});
if (!response.ok) {
throw new Error('重置统计数据失败');
}
const data = await response.json();
if (data.success) {
alert(data.message || '统计数据已重置');
// 立即刷新数据
await loadStats();
} else {
alert('重置失败: ' + (data.message || '未知错误'));
}
} catch (error) {
console.error('重置统计数据出错:', error);
alert('重置统计数据失败,请稍后再试');
} finally {
isResettingStats = false;
btn.textContent = originalText;
btn.disabled = false;
}
}
// 页面加载完成后初始化
document.addEventListener('DOMContentLoaded', function() {
// 立即加载一次数据
loadStats();
loadVersion();
// 启动自动刷新
startAutoRefresh();
// 绑定检查更新按钮
document.getElementById('check-update-btn').addEventListener('click', checkUpdate);
// 绑定重置统计按钮
document.getElementById('reset-stats-btn').addEventListener('click', resetStats);
// 页面可见性变化时控制刷新
document.addEventListener('visibilitychange', function() {
if (document.hidden) {
stopAutoRefresh();
} else {
loadStats();
startAutoRefresh();
}
});
});
// 页面卸载时停止刷新
window.addEventListener('beforeunload', function() {
stopAutoRefresh();
});

View File

@@ -0,0 +1,172 @@
<!DOCTYPE html>
<html lang="zh-CN">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>GoDNS 监控面板</title>
<link rel="icon"
href="data:image/svg+xml,<svg xmlns=%22http://www.w3.org/2000/svg%22 viewBox=%220 0 100 100%22><text y=%22.9em%22 font-size=%2290%22>🌐</text></svg>">
<link rel="stylesheet" href="style.css?v=1.2.4">
</head>
<body>
<div class="container">
<header>
<h1>GoDNS 监控面板</h1>
<div class="update-info">
<span id="last-update">正在加载...</span>
</div>
</header>
<div class="dashboard">
<!-- 运行时信息 -->
<section class="card">
<h2>运行时信息</h2>
<div class="stats-grid">
<div class="stat-item">
<span class="stat-label">运行时长</span>
<span class="stat-value" id="uptime">-</span>
</div>
<div class="stat-item">
<span class="stat-label">Goroutines</span>
<span class="stat-value" id="goroutines">-</span>
</div>
<div class="stat-item">
<span class="stat-label">已分配内存</span>
<span class="stat-value" id="mem-alloc">-</span>
</div>
<div class="stat-item">
<span class="stat-label">系统内存</span>
<span class="stat-value" id="mem-sys">-</span>
</div>
<div class="stat-item">
<span class="stat-label">总分配内存</span>
<span class="stat-value" id="mem-total">-</span>
</div>
<div class="stat-item">
<span class="stat-label">GC 次数</span>
<span class="stat-value" id="num-gc">-</span>
</div>
</div>
</section>
<!-- DNS 查询统计 -->
<section class="card">
<div class="card-header">
<h2>DNS 查询统计</h2>
<div class="stats-controls">
<span class="stats-duration" id="stats-duration">统计时长: -</span>
<button id="reset-stats-btn" class="reset-btn" title="重置统计数据">🔄 重置</button>
</div>
</div>
<div class="stats-grid">
<div class="stat-item highlight">
<span class="stat-label">总查询数</span>
<span class="stat-value" id="total-queries">-</span>
</div>
<div class="stat-item info">
<span class="stat-label">DoH 请求</span>
<span class="stat-value" id="doh-queries">-</span>
</div>
<div class="stat-item success">
<span class="stat-label">缓存命中</span>
<span class="stat-value" id="cache-hits">-</span>
</div>
<div class="stat-item">
<span class="stat-label">缓存未命中</span>
<span class="stat-value" id="cache-misses">-</span>
</div>
<div class="stat-item warning">
<span class="stat-label">失败查询</span>
<span class="stat-value" id="failed-queries">-</span>
</div>
<div class="stat-item highlight">
<span class="stat-label">缓存命中率</span>
<span class="stat-value" id="hit-rate">-</span>
</div>
</div>
</section>
<!-- 上游服务器统计 -->
<section class="card full-width">
<h2>上游服务器统计</h2>
<div class="table-container">
<table id="upstream-table">
<thead>
<tr>
<th>服务器地址</th>
<th>总查询数</th>
<th>错误数</th>
<th>错误率</th>
<th>最后使用</th>
</tr>
</thead>
<tbody id="upstream-tbody">
<tr>
<td colspan="5" class="no-data">暂无数据</td>
</tr>
</tbody>
</table>
</div>
</section>
<!-- Top 客户端 IP -->
<section class="card">
<h2>Top 客户端 IP</h2>
<div class="table-container">
<table id="top-clients-table">
<thead>
<tr>
<th>排名</th>
<th>IP 地址</th>
<th>查询次数</th>
</tr>
</thead>
<tbody id="top-clients-tbody">
<tr>
<td colspan="3" class="no-data">暂无数据</td>
</tr>
</tbody>
</table>
</div>
</section>
<!-- Top 查询域名 -->
<section class="card">
<h2>Top 查询域名</h2>
<div class="table-container">
<table id="top-domains-table">
<thead>
<tr>
<th>排名</th>
<th>域名</th>
<th>查询次数</th>
<th>Top 客户端</th>
</tr>
</thead>
<tbody id="top-domains-tbody">
<tr>
<td colspan="4" class="no-data">暂无数据</td>
</tr>
</tbody>
</table>
</div>
</section>
</div>
<footer>
<div class="footer-content">
<span>GoDNS - 智能 DNS 服务器</span>
<div class="version-info">
<span id="version-display">v0.0.0</span>
<button id="check-update-btn" class="update-btn" title="检查更新">🔄</button>
</div>
</div>
</footer>
</div>
<script src="app.js?v=1.2.0"></script>
</body>
</html>

View File

@@ -0,0 +1,358 @@
* {
margin: 0;
padding: 0;
box-sizing: border-box;
}
body {
font-family: -apple-system, BlinkMacSystemFont, "Segoe UI", "Roboto", sans-serif;
background: #0f172a;
min-height: 100vh;
padding: 20px;
color: #f1f5f9;
}
.container {
max-width: 1600px;
margin: 0 auto;
}
header {
background: #1e293b;
padding: 24px 32px;
border-radius: 12px;
border: 1px solid #334155;
margin-bottom: 24px;
display: flex;
justify-content: space-between;
align-items: center;
}
h1 {
font-size: 1.875rem;
font-weight: 600;
color: #f1f5f9;
}
.update-info {
display: flex;
align-items: center;
gap: 12px;
}
#last-update {
color: #94a3b8;
font-size: 0.875rem;
}
.dashboard {
display: grid;
grid-template-columns: repeat(auto-fit, minmax(min(100%, 600px), 1fr));
gap: 16px;
}
.card {
background: #1e293b;
padding: 24px;
border-radius: 12px;
border: 1px solid #334155;
}
.card.full-width {
grid-column: 1 / -1;
}
h2 {
color: #f1f5f9;
margin-bottom: 20px;
font-size: 1.125rem;
font-weight: 600;
border-bottom: 1px solid #334155;
padding-bottom: 12px;
}
.card-header {
display: flex;
justify-content: space-between;
align-items: center;
gap: 12px;
flex-wrap: wrap;
margin-bottom: 20px;
}
.card-header h2 {
margin-bottom: 0;
padding-bottom: 0;
border-bottom: none;
flex: 1;
min-width: 200px;
}
.stats-controls {
display: flex;
align-items: center;
gap: 12px;
}
.stats-duration {
color: #94a3b8;
font-size: 0.875rem;
}
.reset-btn {
background: #f1f5f9;
color: #0f172a;
border: none;
padding: 8px 16px;
border-radius: 6px;
cursor: pointer;
font-size: 0.875rem;
font-weight: 500;
transition: all 0.15s;
}
.reset-btn:hover:not(:disabled) {
background: #cbd5e1;
}
.reset-btn:disabled {
cursor: not-allowed;
opacity: 0.5;
}
.stats-grid {
display: grid;
grid-template-columns: repeat(auto-fit, minmax(min(100%, 200px), 1fr));
gap: 12px;
}
.stat-item {
background: #0f172a;
padding: 16px;
border-radius: 8px;
border: 1px solid #334155;
display: flex;
flex-direction: column;
gap: 8px;
transition: border-color 0.15s;
}
.stat-item:hover {
border-color: #475569;
}
.stat-item.highlight {
background: #1e293b;
border-color: #3b82f6;
}
.stat-item.success {
background: #1e293b;
border-color: #22c55e;
}
.stat-item.info {
background: #1e293b;
border-color: #06b6d4;
}
.stat-item.warning {
background: #1e293b;
border-color: #f59e0b;
}
.stat-label {
font-size: 0.875rem;
color: #94a3b8;
}
.stat-value {
font-size: 1.5rem;
font-weight: 600;
color: #f1f5f9;
}
.table-container {
overflow-x: auto;
}
table {
width: 100%;
border-collapse: collapse;
margin-top: 8px;
}
thead {
background: #0f172a;
border-bottom: 1px solid #334155;
}
th {
padding: 12px 16px;
text-align: left;
font-weight: 500;
font-size: 0.875rem;
color: #94a3b8;
}
tbody tr {
border-bottom: 1px solid #334155;
transition: background 0.15s;
}
tbody tr:hover {
background: #0f172a;
}
td {
padding: 12px 16px;
font-size: 0.875rem;
}
.no-data {
text-align: center;
color: #64748b;
padding: 24px;
}
.error-high {
color: #ef4444;
font-weight: 500;
}
.rank-cell {
font-weight: 600;
text-align: center;
width: 60px;
}
#top-clients-table th:nth-child(1),
#top-clients-table td:nth-child(1),
#top-domains-table th:nth-child(1),
#top-domains-table td:nth-child(1) {
width: 60px;
}
.rank-1 {
background: rgba(234, 179, 8, 0.1);
}
.rank-1 .rank-cell {
color: #eab308;
}
.rank-2 {
background: rgba(148, 163, 184, 0.1);
}
.rank-2 .rank-cell {
color: #94a3b8;
}
.rank-3 {
background: rgba(251, 146, 60, 0.1);
}
.rank-3 .rank-cell {
color: #fb923c;
}
.domain-cell {
max-width: 300px;
overflow: hidden;
text-overflow: ellipsis;
white-space: nowrap;
}
footer {
text-align: center;
color: #94a3b8;
margin-top: 24px;
padding: 20px;
font-size: 0.875rem;
}
.footer-content {
display: flex;
justify-content: center;
align-items: center;
gap: 16px;
flex-wrap: wrap;
}
.version-info {
display: flex;
align-items: center;
gap: 8px;
background: #1e293b;
padding: 6px 12px;
border-radius: 6px;
border: 1px solid #334155;
}
#version-display {
font-weight: 500;
}
.update-btn {
background: transparent;
color: #94a3b8;
border: 1px solid #334155;
width: 28px;
height: 28px;
border-radius: 6px;
cursor: pointer;
font-size: 14px;
display: flex;
align-items: center;
justify-content: center;
transition: all 0.15s;
}
.update-btn:hover:not(:disabled) {
background: #334155;
color: #f1f5f9;
}
.update-btn:disabled {
cursor: not-allowed;
opacity: 0.5;
}
@media (max-width: 768px) {
body {
padding: 12px;
}
header {
flex-direction: column;
gap: 12px;
padding: 16px;
}
h1 {
font-size: 1.5rem;
}
.card {
padding: 16px;
}
.stats-grid {
grid-template-columns: 1fr;
}
table {
font-size: 0.8rem;
min-width: 500px;
}
th, td {
padding: 8px;
}
#upstream-table th:nth-child(5),
#upstream-table td:nth-child(5) {
display: none;
}
}