update once

This commit is contained in:
XOF
2026-01-06 02:25:24 +08:00
commit 7bf4f27be3
25 changed files with 4587 additions and 0 deletions

120
internal/model/config.go Normal file
View File

@@ -0,0 +1,120 @@
package model
import (
"encoding/json"
"net"
"os"
"godns/pkg/logger"
"godns/pkg/utils"
"github.com/pkg/errors"
"github.com/yl2chen/cidranger"
"golang.org/x/net/proxy"
)
const (
_ = iota
StrategyFullest
StrategyFastest
StrategyAnyResult
)
type DohServerConfig struct {
Username string `json:"username,omitempty"` // DoH Basic Auth 用户名(可选)
Password string `json:"password,omitempty"` // DoH Basic Auth 密码(可选)
}
type WebAuth struct {
Username string `json:"username"`
Password string `json:"password"`
}
type Config struct {
ServeAddr string `json:"serve_addr,omitempty"`
WebAddr string `json:"web_addr,omitempty"`
DohServer *DohServerConfig `json:"doh_server,omitempty"`
Strategy int `json:"strategy,omitempty"`
Timeout int `json:"timeout,omitempty"`
SocksProxy string `json:"socks_proxy,omitempty"`
BuiltInCache bool `json:"built_in_cache,omitempty"`
Upstreams []*Upstream `json:"upstreams,omitempty"`
Bootstrap []*Upstream `json:"bootstrap,omitempty"`
Blacklist []string `json:"blacklist,omitempty"`
Debug bool `json:"debug,omitempty"`
Profiling bool `json:"profiling,omitempty"`
// Connection pool settings
MaxActiveConnections int `json:"max_active_connections,omitempty"` // Default: 50
MaxIdleConnections int `json:"max_idle_connections,omitempty"` // Default: 20
// Stats persistence interval in minutes
StatsSaveInterval int `json:"stats_save_interval,omitempty"` // Default: 5 minutes
BlacklistSplited [][]string `json:"-"`
// Web 面板鉴权
WebAuth *WebAuth `json:"web_auth,omitempty"`
}
func (c *Config) ReadInConfig(path string, ipRanger cidranger.Ranger, log logger.Logger) error {
body, err := os.ReadFile(path)
if err != nil {
return err
}
if err := json.Unmarshal([]byte(body), c); err != nil {
return err
}
// Set default connection pool values
if c.MaxActiveConnections == 0 {
c.MaxActiveConnections = 50
}
if c.MaxIdleConnections == 0 {
c.MaxIdleConnections = 20
}
// Set default stats save interval (5 minutes)
if c.StatsSaveInterval == 0 {
c.StatsSaveInterval = 5
}
for i := 0; i < len(c.Bootstrap); i++ {
c.Bootstrap[i].Init(c, ipRanger, log)
if net.ParseIP(c.Bootstrap[i].host) == nil {
return errors.New("Bootstrap 服务器只能使用 IP: " + c.Bootstrap[i].Address)
}
c.Bootstrap[i].InitConnectionPool(nil)
}
for i := 0; i < len(c.Upstreams); i++ {
c.Upstreams[i].Init(c, ipRanger, log)
if err := c.Upstreams[i].Validate(); err != nil {
return err
}
}
c.BlacklistSplited = utils.ParseRules(c.Blacklist)
return nil
}
func (c *Config) GetDialerContext(d *net.Dialer) (proxy.Dialer, proxy.ContextDialer, error) {
dialSocksProxy, err := proxy.SOCKS5("tcp", c.SocksProxy, nil, d)
if err != nil {
return nil, nil, errors.Wrap(err, "Error creating SOCKS5 proxy")
}
if dialContext, ok := dialSocksProxy.(proxy.ContextDialer); !ok {
return nil, nil, errors.New("Failed type assertion to DialContext")
} else {
return dialSocksProxy, dialContext, err
}
}
func (c *Config) StrategyName() string {
switch c.Strategy {
case StrategyFullest:
return "最全结果"
case StrategyFastest:
return "最快结果"
case StrategyAnyResult:
return "任一结果(建议仅 bootstrap"
}
panic("invalid strategy")
}

282
internal/model/upstream.go Normal file
View File

@@ -0,0 +1,282 @@
package model
import (
"crypto/tls"
"fmt"
"net"
"runtime"
"strings"
"time"
"github.com/dropbox/godropbox/net2"
"github.com/miekg/dns"
"github.com/pkg/errors"
"github.com/yl2chen/cidranger"
"go.uber.org/atomic"
"godns/pkg/doh"
"godns/pkg/logger"
"godns/pkg/utils"
)
type Upstream struct {
IsPrimary bool `json:"is_primary,omitempty"`
UseSocks bool `json:"use_socks,omitempty"`
Address string `json:"address,omitempty"`
Match []string `json:"match,omitempty"`
protocol, hostAndPort, host, port string
config *Config
ipRanger cidranger.Ranger
matchSplited [][]string
pool net2.ConnectionPool
dohClient *doh.Client
bootstrap func(host string) (net.IP, error)
logger logger.Logger
count *atomic.Int64
}
func (up *Upstream) Init(config *Config, ipRanger cidranger.Ranger, log logger.Logger) {
var ok bool
up.protocol, up.hostAndPort, ok = strings.Cut(up.Address, "://")
if ok && up.protocol != "https" {
up.host, up.port, ok = strings.Cut(up.hostAndPort, ":")
}
if !ok {
panic("上游地址格式(protocol://host:port)有误:" + up.Address)
}
if up.count != nil {
panic("Upstream 已经初始化过了:" + up.Address)
}
up.matchSplited = utils.ParseRules(up.Match)
up.count = atomic.NewInt64(0)
up.config = config
up.ipRanger = ipRanger
up.logger = log
}
// SetLogger 更新 upstream 的 logger 实例
func (up *Upstream) SetLogger(log logger.Logger) {
up.logger = log
}
func (up *Upstream) IsMatch(domain string) bool {
return utils.HasMatchedRule(up.matchSplited, domain)
}
func (up *Upstream) Validate() error {
if !up.IsPrimary && up.protocol == "udp" {
return errors.New("非 primary 只能使用 tcp(-tls)/https" + up.Address)
}
if up.IsPrimary && up.UseSocks {
return errors.New("primary 无需接入 socks" + up.Address)
}
if up.UseSocks && up.config.SocksProxy == "" {
return errors.New("socks 未配置,但是上游已启用:" + up.Address)
}
if up.IsPrimary && up.protocol != "udp" {
up.logger.Println("[WARN] Primary 建议使用 udp 加速获取结果:" + up.Address)
}
return nil
}
func (up *Upstream) conntionFactory(network, address string) (net.Conn, error) {
up.logger.Printf("connecting to %s://%s", network, address)
host, port, err := net.SplitHostPort(address)
if err != nil {
return nil, err
}
if up.bootstrap != nil && net.ParseIP(host) == nil {
ip, err := up.bootstrap(host)
if err != nil {
address = fmt.Sprintf("%s:%s", "0.0.0.0", port)
} else {
address = fmt.Sprintf("%s:%s", ip.String(), port)
}
}
if up.UseSocks {
d, _, err := up.config.GetDialerContext(&net.Dialer{
Timeout: time.Second * time.Duration(up.config.Timeout),
})
if err != nil {
return nil, err
}
switch network {
case "tcp":
return d.Dial(network, address)
case "tcp-tls":
conn, err := d.Dial("tcp", address)
if err != nil {
return nil, err
}
return tls.Client(conn, &tls.Config{
ServerName: host,
}), nil
}
} else {
var d net.Dialer
d.Timeout = time.Second * time.Duration(up.config.Timeout)
switch network {
case "tcp":
return d.Dial(network, address)
case "tcp-tls":
return tls.DialWithDialer(&d, "tcp", address, &tls.Config{
ServerName: host,
})
}
}
panic("wrong protocol: " + network)
}
func (up *Upstream) InitConnectionPool(bootstrap func(host string) (net.IP, error)) {
up.bootstrap = bootstrap
if strings.Contains(up.protocol, "http") {
ops := []doh.ClientOption{
doh.WithServer(up.Address),
doh.WithBootstrap(bootstrap),
doh.WithTimeout(time.Second * time.Duration(up.config.Timeout)),
doh.WithLogger(up.logger),
}
if up.UseSocks {
ops = append(ops, doh.WithSocksProxy(up.config.GetDialerContext))
}
up.dohClient = doh.NewClient(ops...)
}
// 只需要启用 tcp/tcp-tls 协议的连接池
if strings.Contains(up.protocol, "tcp") {
maxIdleTime := time.Second * time.Duration(up.config.Timeout*10)
timeout := time.Second * time.Duration(up.config.Timeout)
p := net2.NewSimpleConnectionPool(net2.ConnectionOptions{
MaxActiveConnections: int32(up.config.MaxActiveConnections),
MaxIdleConnections: uint32(up.config.MaxIdleConnections),
MaxIdleTime: &maxIdleTime,
DialMaxConcurrency: 20,
ReadTimeout: timeout,
WriteTimeout: timeout,
Dial: func(network, address string) (net.Conn, error) {
dialer, err := up.conntionFactory(network, address)
if err != nil {
return nil, err
}
dialer.SetDeadline(time.Now().Add(timeout))
return dialer, nil
},
})
p.Register(up.protocol, up.hostAndPort)
up.pool = p
}
}
func (up *Upstream) IsValidMsg(r *dns.Msg) bool {
domain := GetDomainNameFromDnsMsg(r)
inBlacklist := utils.HasMatchedRule(up.config.BlacklistSplited, domain)
for i := 0; i < len(r.Answer); i++ {
var ip net.IP
typeA, ok := r.Answer[i].(*dns.A)
if ok {
ip = typeA.A
} else {
typeAAAA, ok := r.Answer[i].(*dns.AAAA)
if !ok {
continue
}
ip = typeAAAA.AAAA
}
isPrimary, err := up.ipRanger.Contains(ip)
if err != nil {
up.logger.Printf("ipRanger query ip %s failed: %s", ip, err)
continue
}
up.logger.Printf("checkPrimary result %s: %s@%s ->domain.inBlacklist:%v ip.IsPrimary:%v up.IsPrimary:%v", up.Address, domain, ip, inBlacklist, isPrimary, up.IsPrimary)
// 黑名单中的域名,如果是 primary 即不可用
if inBlacklist && isPrimary {
return false
}
// 如果是 server 是 primary但是 ip 不是 primary也不可用
if up.IsPrimary && !isPrimary {
return false
}
}
return !up.IsPrimary || len(r.Answer) > 0
}
func GetDomainNameFromDnsMsg(msg *dns.Msg) string {
if msg == nil || len(msg.Question) == 0 {
return ""
}
return msg.Question[0].Name
}
func (up *Upstream) poolLen() int32 {
if up.pool == nil {
return 0
}
return up.pool.NumActive()
}
func (up *Upstream) Exchange(req *dns.Msg) (*dns.Msg, time.Duration, error) {
up.logger.Printf("tracing exchange %s worker_count: %d pool_count: %d go_routine: %d --> %s", up.Address, up.count.Inc(), up.poolLen(), runtime.NumGoroutine(), "enter")
defer up.logger.Printf("tracing exchange %s worker_count: %d pool_count: %d go_routine: %d --> %s", up.Address, up.count.Dec(), up.poolLen(), runtime.NumGoroutine(), "exit")
var resp *dns.Msg
var duration time.Duration
var err error
switch up.protocol {
case "https", "http":
resp, duration, err = up.dohClient.Exchange(req)
case "udp":
client := new(dns.Client)
client.Timeout = time.Second * time.Duration(up.config.Timeout)
resp, duration, err = client.Exchange(req, up.hostAndPort)
case "tcp", "tcp-tls":
conn, errGetConn := up.pool.Get(up.protocol, up.hostAndPort)
if errGetConn != nil {
return nil, 0, errGetConn
}
resp, err = dnsExchangeWithConn(conn, req)
default:
panic(fmt.Sprintf("invalid upstream protocol: %s in address %s", up.protocol, up.Address))
}
// 清理 EDNS 信息
if resp != nil && len(resp.Extra) > 0 {
var newExtra []dns.RR
for i := 0; i < len(resp.Extra); i++ {
if resp.Extra[i].Header().Rrtype == dns.TypeOPT {
continue
}
newExtra = append(newExtra, resp.Extra[i])
}
resp.Extra = newExtra
}
return resp, duration, err
}
func dnsExchangeWithConn(conn net2.ManagedConn, req *dns.Msg) (*dns.Msg, error) {
var resp *dns.Msg
co := dns.Conn{Conn: conn}
err := co.WriteMsg(req)
if err == nil {
resp, err = co.ReadMsg()
}
if err == nil {
conn.ReleaseConnection()
} else {
conn.DiscardConnection()
}
return resp, err
}

View File

@@ -0,0 +1,120 @@
package model
import (
"index/suffixarray"
"strings"
"testing"
"godns/pkg/utils"
)
var primaryLocations = []string{"中国", "省", "市", "自治区"}
var nonPrimaryLocations = []string{"台湾", "香港", "澳门"}
var primaryLocationsBytes = [][]byte{[]byte("中国"), []byte("省"), []byte("市"), []byte("自治区")}
var nonPrimaryLocationsBytes = [][]byte{[]byte("台湾"), []byte("香港"), []byte("澳门")}
func BenchmarkCheckPrimary(b *testing.B) {
for i := 0; i < b.N; i++ {
checkPrimary("哈哈")
}
}
func BenchmarkCheckPrimaryStringsContains(b *testing.B) {
for i := 0; i < b.N; i++ {
checkPrimaryStringsContains("哈哈")
}
}
func TestIsMatch(t *testing.T) {
var up Upstream
up.matchSplited = utils.ParseRules([]string{"."})
checkUpstreamMatch(&up, map[string]bool{
"": false,
"a.com.": true,
"b.a.com.": true,
".b.a.com.cn.": true,
"b.a.com.cn.": true,
"d.b.a.com.": true,
}, t)
up.matchSplited = utils.ParseRules([]string{""})
checkUpstreamMatch(&up, map[string]bool{
"": false,
"a.com.": false,
"b.a.com.": false,
".b.a.com.cn.": false,
"b.a.com.cn.": false,
"d.b.a.com.": false,
}, t)
up.matchSplited = utils.ParseRules([]string{"a.com."})
checkUpstreamMatch(&up, map[string]bool{
"": false,
"a.com.": true,
"b.a.com.": false,
".b.a.com.cn.": false,
"b.a.com.cn.": false,
"d.b.a.com.": false,
}, t)
up.matchSplited = utils.ParseRules([]string{".a.com."})
checkUpstreamMatch(&up, map[string]bool{
"": false,
"a.com.": false,
"b.a.com.": true,
".b.a.com.cn.": false,
"b.a.com.cn.": false,
"d.b.a.com.": true,
}, t)
up.matchSplited = utils.ParseRules([]string{"b.d.com."})
checkUpstreamMatch(&up, map[string]bool{
"": false,
"a.com.": false,
".a.com.": false,
"b.d.com.": true,
".b.d.com.cn.": false,
"b.d.com.cn.": false,
".c.d.com.": false,
"b.d.a.com.": false,
}, t)
}
func checkUpstreamMatch(up *Upstream, cases map[string]bool, t *testing.T) {
for k, v := range cases {
isMatch := up.IsMatch(k)
if isMatch != v {
t.Errorf("Upstream(%s).IsMatch(%s) = %v, want %v", up.matchSplited, k, isMatch, v)
}
}
}
func checkPrimary(str string) bool {
index := suffixarray.New([]byte(str))
for i := 0; i < len(nonPrimaryLocationsBytes); i++ {
if len(index.Lookup(nonPrimaryLocationsBytes[i], 1)) > 0 {
return false
}
}
for i := 0; i < len(primaryLocationsBytes); i++ {
if len(index.Lookup(primaryLocationsBytes[i], 1)) > 0 {
return true
}
}
return false
}
func checkPrimaryStringsContains(str string) bool {
for i := 0; i < len(nonPrimaryLocations); i++ {
if strings.Contains(str, nonPrimaryLocations[i]) {
return false
}
}
for i := 0; i < len(primaryLocations); i++ {
if strings.Contains(str, primaryLocations[i]) {
return true
}
}
return false
}