package main import ( "errors" "log" "math/rand" "net" "net/http" _ "net/http/pprof" "os" "os/signal" "path/filepath" "strings" "syscall" "time" "github.com/blang/semver" "github.com/miekg/dns" "github.com/rhysd/go-github-selfupdate/selfupdate" "github.com/yl2chen/cidranger" "godns/internal/handler" "godns/internal/model" "godns/internal/stats" "godns/internal/web" "godns/pkg/doh" "godns/pkg/logger" ) var ( version string config *model.Config dataPath string ) func main() { dataPath = detectDataPath() ipRanger := loadIPRanger(dataPath + "china_ip_list.txt") // 先创建一个临时 logger 用于读取配置 tempLogger := logger.New(false) config = &model.Config{} if err := config.ReadInConfig(dataPath+"/config.json", ipRanger, tempLogger); err != nil { panic(err) } // 设置默认 Web 监听地址 if config.WebAddr == "" { config.WebAddr = "0.0.0.0:8854" } // 根据配置创建正式的 logger 和 stats 实例 debugLogger := logger.New(config.Debug) statsRecorder := stats.NewStats() // 加载持久化的统计数据 if err := statsRecorder.Load(dataPath); err != nil { log.Printf("Failed to load stats from disk: %v", err) } else { log.Printf("Stats loaded successfully from disk") } // 更新 upstreams 的 logger 为正式的 logger for i := 0; i < len(config.Bootstrap); i++ { config.Bootstrap[i].SetLogger(debugLogger) } for i := 0; i < len(config.Upstreams); i++ { config.Upstreams[i].SetLogger(debugLogger) } // Bootstrap handler 不需要缓存,只是用于初始化连接 bootstrapHandler := handler.NewHandler(model.StrategyAnyResult, false, config.Bootstrap, dataPath, debugLogger, nil) for i := 0; i < len(config.Upstreams); i++ { config.Upstreams[i].InitConnectionPool(bootstrapHandler.LookupIP) } server := &dns.Server{Addr: config.ServeAddr, Net: "udp"} serverTCP := &dns.Server{Addr: config.ServeAddr, Net: "tcp"} // 只有 upstream handler 需要缓存 upstreamHandler := handler.NewHandler(config.Strategy, config.BuiltInCache, config.Upstreams, dataPath, debugLogger, statsRecorder) dns.HandleFunc(".", upstreamHandler.HandleRequest) // Setup graceful shutdown defer func() { // 保存统计数据 log.Printf("Saving stats before shutdown...") if err := statsRecorder.Save(dataPath); err != nil { log.Printf("Error saving stats: %v", err) } else { log.Printf("Stats saved successfully") } // 关闭缓存 if err := upstreamHandler.Close(); err != nil { log.Printf("Error closing cache: %v", err) } }() log.Println("==== DNS Server ====") log.Println("端口:", config.ServeAddr) log.Println("模式:", config.StrategyName()) log.Println("数据:", dataPath) if config.BuiltInCache { log.Println("启用 BadgerDB 缓存: 最大 40MB") } else { log.Println("禁用缓存") } log.Println("版本:", version) // 创建更新检查通道 checkUpdateCh := make(chan struct{}, 1) // 启动 Web 服务(监控面板 + DoH + pprof) webServerHandler := http.NewServeMux() // 注册监控面板路由 var webUsername, webPassword string if config.WebAuth != nil { webUsername = config.WebAuth.Username webPassword = config.WebAuth.Password } webHandler := web.NewHandler(statsRecorder, version, checkUpdateCh, debugLogger, webUsername, webPassword) webHandler.RegisterRoutes(webServerHandler) // 如果启用 DoH,注册 DoH 路由 if config.DohServer != nil { dohServer := doh.NewServer(config.DohServer.Username, config.DohServer.Password, upstreamHandler.HandleDnsMsg, statsRecorder) dohServer.RegisterRoutes(webServerHandler) log.Printf("DoH 服务: http://%s/dns-query", config.WebAddr) } // 如果启用 profiling,注册 pprof 路由 if config.Profiling { webServerHandler.HandleFunc("/debug/", http.DefaultServeMux.ServeHTTP) log.Printf("性能分析: http://%s/debug/pprof/", config.WebAddr) } go http.ListenAndServe(config.WebAddr, webServerHandler) log.Printf("监控面板: http://%s/", config.WebAddr) // 定时保存统计数据(使用配置的间隔) statsSaveTicker := time.NewTicker(time.Duration(config.StatsSaveInterval) * time.Minute) defer statsSaveTicker.Stop() go func() { for range statsSaveTicker.C { if err := statsRecorder.Save(dataPath); err != nil { debugLogger.Printf("Failed to save stats to disk: %v", err) } else { debugLogger.Printf("Stats saved successfully to disk") } } }() stopCh := make(chan error) // 启动后台更新检查 go checkUpdate(checkUpdateCh, stopCh, debugLogger) // 定时触发更新检查(生产者1:定时器) if version != "" { go func() { // 启动时立即检查一次 select { case checkUpdateCh <- struct{}{}: default: } // 定时检查 ticker := time.NewTicker(time.Duration(40+rand.Intn(20)) * time.Minute) defer ticker.Stop() for range ticker.C { select { case checkUpdateCh <- struct{}{}: default: // 如果通道已满,跳过本次 } } }() } go func() { stopCh <- server.ListenAndServe() }() go func() { stopCh <- serverTCP.ListenAndServe() }() sigCh := make(chan os.Signal, 1) signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM) go func() { <-sigCh log.Println("Shutting down...") stopCh <- errors.New("shutdown signal received") }() log.Printf("server stopped: %+v", <-stopCh) } // checkUpdate 监听 channel 触发更新检查 func checkUpdate(checkCh <-chan struct{}, stopCh chan<- error, debugLogger logger.Logger) { for range checkCh { // 如果 version 为空,使用默认值 ver := version if ver == "" { ver = "0.0.0" } v := semver.MustParse(ver) latest, err := selfupdate.UpdateSelf(v, "xofine/godns") if err != nil { debugLogger.Printf("Error checking for updates: %v", err) continue } if latest.Version.Equals(v) { debugLogger.Printf("No update available, current version: %s", v) } else { log.Printf("Updated to version: %s", latest.Version) stopCh <- errors.New("Server upgraded to " + latest.Version.String()) return } } } func loadIPRanger(path string) cidranger.Ranger { ipRanger := cidranger.NewPCTrieRanger() content, err := os.ReadFile(path) if err != nil { panic(err) } lines := strings.Split(string(content), "\n") for i := 0; i < len(lines); i++ { if strings.TrimSpace(lines[i]) == "" { continue } _, network, err := net.ParseCIDR(lines[i]) if err != nil { panic(err) } if err := ipRanger.Insert(cidranger.NewBasicRangerEntry(*network)); err != nil { panic(err) } } return ipRanger } func detectDataPath() string { ex, err := os.Executable() if err != nil { panic(err) } pwd, err := os.Getwd() if err != nil { panic(err) } pathList := []string{filepath.Dir(ex), pwd} for _, path := range pathList { if f, err := os.Stat(path + "/data/china_ip_list.txt"); err == nil { if f.Size() == 1024*200 { panic("离线IP库 china_ip_list.txt 文件损坏,请重新下载") } return path + "/data/" } } panic("没有检测到IP数据 data/china_ip_list.txt") }