Files
godns/internal/model/upstream.go
2026-01-06 02:25:24 +08:00

283 lines
7.4 KiB
Go
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package model
import (
"crypto/tls"
"fmt"
"net"
"runtime"
"strings"
"time"
"github.com/dropbox/godropbox/net2"
"github.com/miekg/dns"
"github.com/pkg/errors"
"github.com/yl2chen/cidranger"
"go.uber.org/atomic"
"godns/pkg/doh"
"godns/pkg/logger"
"godns/pkg/utils"
)
// Upstream describes one upstream DNS server and the runtime state used to
// query it.
//
// The exported fields come from JSON configuration; the unexported fields are
// filled in by Init and InitConnectionPool.
type Upstream struct {
	// IsPrimary marks a "primary" upstream. Validate forbids combining it
	// with UseSocks, and IsValidMsg rejects a primary's answer when any
	// A/AAAA address is outside ipRanger or when the answer section is empty.
	IsPrimary bool `json:"is_primary,omitempty"`
	// UseSocks sends tcp/tcp-tls dials and DoH requests through the SOCKS
	// dialer obtained from config (Validate requires config.SocksProxy).
	UseSocks bool `json:"use_socks,omitempty"`
	// Address is "protocol://host:port" (udp, tcp, tcp-tls) or a DoH URL
	// (https; http is also dispatched to the DoH client by Exchange).
	Address string `json:"address,omitempty"`
	// Match holds raw domain rules; Init parses them into matchSplited for
	// IsMatch.
	Match []string `json:"match,omitempty"`

	// protocol is the scheme before "://" in Address and hostAndPort is
	// everything after it; host and port are split out of hostAndPort by Init
	// for non-https upstreams.
	protocol, hostAndPort, host, port string
	// config is the shared configuration passed to Init (timeouts, pool
	// sizes, SOCKS proxy, blacklist).
	config *Config
	// ipRanger is consulted by IsValidMsg to classify answer IPs as primary.
	ipRanger cidranger.Ranger
	// matchSplited is the parsed form of Match (utils.ParseRules).
	matchSplited [][]string
	// pool is the connection pool, created only for tcp/tcp-tls upstreams.
	pool net2.ConnectionPool
	// dohClient is created only for http/https upstreams.
	dohClient *doh.Client
	// bootstrap, when non-nil, resolves upstream hostnames before dialing.
	bootstrap func(host string) (net.IP, error)
	// logger receives this upstream's log output.
	logger logger.Logger
	// count tracks Exchange calls in flight (reported in tracing logs); it is
	// non-nil once Init has run and doubles as the "already initialized" flag.
	count *atomic.Int64
}
// Init parses up.Address and attaches the upstream to its shared dependencies.
//
// For udp, tcp and tcp-tls upstreams Address must be "protocol://host:port"
// (IPv6 hosts in brackets, e.g. "udp://[2001:db8::1]:53"). For http/https
// (DoH) upstreams only the "protocol://" prefix is required, because the rest
// is a URL handed to the DoH client and may omit the port.
//
// Init panics when Address is malformed or when called twice on the same
// Upstream.
func (up *Upstream) Init(config *Config, ipRanger cidranger.Ranger, log logger.Logger) {
	// Check for re-initialization before touching any field, so a repeated
	// call cannot clobber the parsed address of a live upstream.
	if up.count != nil {
		panic("Upstream 已经初始化过了:" + up.Address)
	}
	var ok bool
	up.protocol, up.hostAndPort, ok = strings.Cut(up.Address, "://")
	// DoH endpoints are URLs (possibly without an explicit port); everything
	// else must be a host:port pair. SplitHostPort handles bracketed IPv6.
	if ok && up.protocol != "https" && up.protocol != "http" {
		var err error
		up.host, up.port, err = net.SplitHostPort(up.hostAndPort)
		ok = err == nil
	}
	if !ok {
		panic("上游地址格式(protocol://host:port)有误:" + up.Address)
	}
	up.matchSplited = utils.ParseRules(up.Match)
	up.count = atomic.NewInt64(0)
	up.config = config
	up.ipRanger = ipRanger
	up.logger = log
}
// SetLogger swaps the logger instance this upstream writes to; all later log
// output from the upstream goes to l.
func (up *Upstream) SetLogger(log logger.Logger) {
	newLogger := log
	up.logger = newLogger
}
// IsMatch reports whether domain is matched by one of the upstream's Match
// rules (in the form pre-parsed by Init).
func (up *Upstream) IsMatch(domain string) bool {
	rules := up.matchSplited
	matched := utils.HasMatchedRule(rules, domain)
	return matched
}
// Validate checks the upstream's flags against its protocol and the shared
// config. It returns an error for an invalid combination:
//   - a non-primary upstream using udp;
//   - a primary upstream with UseSocks;
//   - UseSocks without a configured SOCKS proxy.
//
// A primary upstream that does not use udp is only logged as a warning.
// Init must have run first, since protocol, config and logger are read here.
func (up *Upstream) Validate() error {
	// Each message ends with ":" before the address, consistent with the
	// other messages in this file.
	if !up.IsPrimary && up.protocol == "udp" {
		return errors.New("非 primary 只能使用 tcp(-tls)/https:" + up.Address)
	}
	if up.IsPrimary && up.UseSocks {
		return errors.New("primary 无需接入 socks:" + up.Address)
	}
	if up.UseSocks && up.config.SocksProxy == "" {
		return errors.New("socks 未配置,但是上游已启用:" + up.Address)
	}
	if up.IsPrimary && up.protocol != "udp" {
		up.logger.Println("[WARN] Primary 建议使用 udp 加速获取结果:" + up.Address)
	}
	return nil
}
// conntionFactory dials a fresh connection for the tcp/tcp-tls connection
// pool. network is the pool's registered network (the upstream protocol,
// "tcp" or "tcp-tls") and address is "host:port".
//
// When a bootstrap resolver is set and host is not a literal IP, host is
// resolved first. The original hostname is still used as the TLS ServerName.
// A resolution failure is returned as an error instead of dialing a
// placeholder address. With UseSocks the connection is made through the SOCKS
// dialer from config. An unsupported network yields an error.
func (up *Upstream) conntionFactory(network, address string) (net.Conn, error) {
	up.logger.Printf("connecting to %s://%s", network, address)
	host, port, err := net.SplitHostPort(address)
	if err != nil {
		return nil, err
	}
	if up.bootstrap != nil && net.ParseIP(host) == nil {
		ip, err := up.bootstrap(host)
		if err != nil {
			return nil, fmt.Errorf("bootstrap resolve %s: %w", host, err)
		}
		if ip == nil {
			return nil, fmt.Errorf("bootstrap resolve %s: no address returned", host)
		}
		// JoinHostPort brackets IPv6 results ("[2001:db8::1]:853").
		address = net.JoinHostPort(ip.String(), port)
	}

	timeout := time.Second * time.Duration(up.config.Timeout)
	if up.UseSocks {
		d, _, err := up.config.GetDialerContext(&net.Dialer{
			Timeout: timeout,
		})
		if err != nil {
			return nil, err
		}
		switch network {
		case "tcp":
			return d.Dial(network, address)
		case "tcp-tls":
			conn, err := d.Dial("tcp", address)
			if err != nil {
				return nil, err
			}
			// The handshake runs lazily on first read/write, under the
			// deadline the pool sets on the returned connection.
			return tls.Client(conn, &tls.Config{
				ServerName: host,
			}), nil
		}
	} else {
		d := &net.Dialer{Timeout: timeout}
		switch network {
		case "tcp":
			return d.Dial(network, address)
		case "tcp-tls":
			return tls.DialWithDialer(d, "tcp", address, &tls.Config{
				ServerName: host,
			})
		}
	}
	return nil, fmt.Errorf("unsupported protocol %q for upstream %s", network, up.Address)
}
// InitConnectionPool stores the bootstrap resolver and builds the transport
// this upstream's protocol needs.
//   - http/https: a DoH client (optionally via the SOCKS proxy).
//   - tcp/tcp-tls: a net2 connection pool whose dialer is conntionFactory.
//   - udp: nothing; Exchange creates a dns.Client per query.
//
// Errors from registering the pool are logged, because this method has no
// error return.
func (up *Upstream) InitConnectionPool(bootstrap func(host string) (net.IP, error)) {
	up.bootstrap = bootstrap
	timeout := time.Second * time.Duration(up.config.Timeout)
	if strings.Contains(up.protocol, "http") {
		ops := []doh.ClientOption{
			doh.WithServer(up.Address),
			doh.WithBootstrap(bootstrap),
			doh.WithTimeout(timeout),
			doh.WithLogger(up.logger),
		}
		if up.UseSocks {
			ops = append(ops, doh.WithSocksProxy(up.config.GetDialerContext))
		}
		up.dohClient = doh.NewClient(ops...)
	}
	// Only the tcp/tcp-tls protocols need a connection pool.
	if strings.Contains(up.protocol, "tcp") {
		maxIdleTime := time.Second * time.Duration(up.config.Timeout*10)
		p := net2.NewSimpleConnectionPool(net2.ConnectionOptions{
			MaxActiveConnections: int32(up.config.MaxActiveConnections),
			MaxIdleConnections:   uint32(up.config.MaxIdleConnections),
			MaxIdleTime:          &maxIdleTime,
			DialMaxConcurrency:   20,
			ReadTimeout:          timeout,
			WriteTimeout:         timeout,
			Dial: func(network, address string) (net.Conn, error) {
				conn, err := up.conntionFactory(network, address)
				if err != nil {
					return nil, err
				}
				// Give the new connection an initial deadline; if that
				// fails, don't hand a half-configured conn to the pool.
				if err := conn.SetDeadline(time.Now().Add(timeout)); err != nil {
					_ = conn.Close() // best effort; the SetDeadline error is what matters
					return nil, err
				}
				return conn, nil
			},
		})
		if err := p.Register(up.protocol, up.hostAndPort); err != nil {
			up.logger.Printf("[ERROR] register connection pool %s://%s failed: %s", up.protocol, up.hostAndPort, err)
		}
		up.pool = p
	}
}
// IsValidMsg reports whether r is an acceptable response from this upstream.
//
// Each A/AAAA answer IP is checked against ipRanger (lookup errors are logged
// and that record is skipped):
//   - a blacklisted domain is rejected when an answer IP is in the range;
//   - a primary upstream is rejected when an answer IP is outside the range.
//
// In addition, a primary upstream must return at least one answer record.
// r must be non-nil.
func (up *Upstream) IsValidMsg(r *dns.Msg) bool {
	domain := GetDomainNameFromDnsMsg(r)
	inBlacklist := utils.HasMatchedRule(up.config.BlacklistSplited, domain)

	for _, rr := range r.Answer {
		var addr net.IP
		switch rec := rr.(type) {
		case *dns.A:
			addr = rec.A
		case *dns.AAAA:
			addr = rec.AAAA
		default:
			// Only address records take part in the check.
			continue
		}

		ipIsPrimary, lookupErr := up.ipRanger.Contains(addr)
		if lookupErr != nil {
			up.logger.Printf("ipRanger query ip %s failed: %s", addr, lookupErr)
			continue
		}
		up.logger.Printf("checkPrimary result %s: %s@%s ->domain.inBlacklist:%v ip.IsPrimary:%v up.IsPrimary:%v", up.Address, domain, addr, inBlacklist, ipIsPrimary, up.IsPrimary)

		blacklistedPrimaryIP := inBlacklist && ipIsPrimary
		primaryServerForeignIP := up.IsPrimary && !ipIsPrimary
		if blacklistedPrimaryIP || primaryServerForeignIP {
			return false
		}
	}

	if up.IsPrimary {
		// A primary upstream must actually answer something.
		return len(r.Answer) > 0
	}
	return true
}
// GetDomainNameFromDnsMsg returns the name of the first question in msg, or
// the empty string when msg is nil or carries no questions.
func GetDomainNameFromDnsMsg(msg *dns.Msg) string {
	if msg != nil && len(msg.Question) > 0 {
		first := msg.Question[0]
		return first.Name
	}
	return ""
}
// poolLen returns the number of active connections in the tcp/tcp-tls pool,
// or 0 when no pool has been created (non-tcp upstreams, or before
// InitConnectionPool).
func (up *Upstream) poolLen() int32 {
	if p := up.pool; p != nil {
		return p.NumActive()
	}
	return 0
}
// Exchange sends req to this upstream over its configured protocol and returns
// the response, the round-trip time, and any error.
//
// http/https use the DoH client, udp uses a fresh dns.Client, and tcp/tcp-tls
// borrow a connection from the pool. OPT (EDNS) records are removed from the
// response's additional section before it is returned. Exchange panics on an
// unknown protocol. InitConnectionPool must have run for DoH and tcp upstreams.
func (up *Upstream) Exchange(req *dns.Msg) (*dns.Msg, time.Duration, error) {
	up.logger.Printf("tracing exchange %s worker_count: %d pool_count: %d go_routine: %d --> %s", up.Address, up.count.Inc(), up.poolLen(), runtime.NumGoroutine(), "enter")
	// Use a closure: arguments of a deferred call are evaluated when the defer
	// statement executes. Deferring Printf directly would decrement the
	// counter (and sample pool/goroutine counts) on entry instead of on exit.
	defer func() {
		up.logger.Printf("tracing exchange %s worker_count: %d pool_count: %d go_routine: %d --> %s", up.Address, up.count.Dec(), up.poolLen(), runtime.NumGoroutine(), "exit")
	}()

	var (
		resp     *dns.Msg
		duration time.Duration
		err      error
	)
	switch up.protocol {
	case "https", "http":
		resp, duration, err = up.dohClient.Exchange(req)
	case "udp":
		client := new(dns.Client)
		client.Timeout = time.Second * time.Duration(up.config.Timeout)
		resp, duration, err = client.Exchange(req, up.hostAndPort)
	case "tcp", "tcp-tls":
		conn, errGetConn := up.pool.Get(up.protocol, up.hostAndPort)
		if errGetConn != nil {
			return nil, 0, errGetConn
		}
		// Measure the round trip like the other transports do, instead of
		// always reporting zero for pooled connections.
		start := time.Now()
		resp, err = dnsExchangeWithConn(conn, req)
		duration = time.Since(start)
	default:
		panic(fmt.Sprintf("invalid upstream protocol: %s in address %s", up.protocol, up.Address))
	}

	// Strip EDNS (OPT) records from the additional section.
	if resp != nil && len(resp.Extra) > 0 {
		var newExtra []dns.RR
		for _, rr := range resp.Extra {
			if rr.Header().Rrtype == dns.TypeOPT {
				continue
			}
			newExtra = append(newExtra, rr)
		}
		resp.Extra = newExtra
	}
	return resp, duration, err
}
// dnsExchangeWithConn writes req to a pooled stream connection and reads back
// one DNS message.
//
// On success the connection is released back to the pool. On any error it is
// discarded, since the stream may be out of sync. A reply whose ID differs
// from req.Id counts as an error (dns.ErrId), matching dns.Client's behaviour
// on stream connections. The reply is still returned alongside that error.
func dnsExchangeWithConn(conn net2.ManagedConn, req *dns.Msg) (*dns.Msg, error) {
	co := dns.Conn{Conn: conn}
	if err := co.WriteMsg(req); err != nil {
		conn.DiscardConnection()
		return nil, err
	}
	resp, err := co.ReadMsg()
	if err == nil && resp.Id != req.Id {
		// The reply belongs to a different query; this connection can no
		// longer be trusted to be aligned with our requests.
		err = dns.ErrId
	}
	if err != nil {
		conn.DiscardConnection()
		return resp, err
	}
	conn.ReleaseConnection()
	return resp, nil
}