update once
This commit is contained in:
282
internal/model/upstream.go
Normal file
282
internal/model/upstream.go
Normal file
@@ -0,0 +1,282 @@
|
||||
package model
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
"net"
|
||||
"runtime"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/dropbox/godropbox/net2"
|
||||
"github.com/miekg/dns"
|
||||
"github.com/pkg/errors"
|
||||
"github.com/yl2chen/cidranger"
|
||||
"go.uber.org/atomic"
|
||||
|
||||
"godns/pkg/doh"
|
||||
"godns/pkg/logger"
|
||||
"godns/pkg/utils"
|
||||
)
|
||||
|
||||
type Upstream struct {
|
||||
IsPrimary bool `json:"is_primary,omitempty"`
|
||||
UseSocks bool `json:"use_socks,omitempty"`
|
||||
Address string `json:"address,omitempty"`
|
||||
Match []string `json:"match,omitempty"`
|
||||
|
||||
protocol, hostAndPort, host, port string
|
||||
config *Config
|
||||
ipRanger cidranger.Ranger
|
||||
matchSplited [][]string
|
||||
|
||||
pool net2.ConnectionPool
|
||||
dohClient *doh.Client
|
||||
bootstrap func(host string) (net.IP, error)
|
||||
logger logger.Logger
|
||||
|
||||
count *atomic.Int64
|
||||
}
|
||||
|
||||
func (up *Upstream) Init(config *Config, ipRanger cidranger.Ranger, log logger.Logger) {
|
||||
var ok bool
|
||||
up.protocol, up.hostAndPort, ok = strings.Cut(up.Address, "://")
|
||||
if ok && up.protocol != "https" {
|
||||
up.host, up.port, ok = strings.Cut(up.hostAndPort, ":")
|
||||
}
|
||||
if !ok {
|
||||
panic("上游地址格式(protocol://host:port)有误:" + up.Address)
|
||||
}
|
||||
|
||||
if up.count != nil {
|
||||
panic("Upstream 已经初始化过了:" + up.Address)
|
||||
}
|
||||
|
||||
up.matchSplited = utils.ParseRules(up.Match)
|
||||
up.count = atomic.NewInt64(0)
|
||||
up.config = config
|
||||
up.ipRanger = ipRanger
|
||||
up.logger = log
|
||||
}
|
||||
|
||||
// SetLogger 更新 upstream 的 logger 实例
|
||||
func (up *Upstream) SetLogger(log logger.Logger) {
|
||||
up.logger = log
|
||||
}
|
||||
|
||||
func (up *Upstream) IsMatch(domain string) bool {
|
||||
return utils.HasMatchedRule(up.matchSplited, domain)
|
||||
}
|
||||
|
||||
func (up *Upstream) Validate() error {
|
||||
if !up.IsPrimary && up.protocol == "udp" {
|
||||
return errors.New("非 primary 只能使用 tcp(-tls)/https:" + up.Address)
|
||||
}
|
||||
if up.IsPrimary && up.UseSocks {
|
||||
return errors.New("primary 无需接入 socks:" + up.Address)
|
||||
}
|
||||
if up.UseSocks && up.config.SocksProxy == "" {
|
||||
return errors.New("socks 未配置,但是上游已启用:" + up.Address)
|
||||
}
|
||||
if up.IsPrimary && up.protocol != "udp" {
|
||||
up.logger.Println("[WARN] Primary 建议使用 udp 加速获取结果:" + up.Address)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (up *Upstream) conntionFactory(network, address string) (net.Conn, error) {
|
||||
up.logger.Printf("connecting to %s://%s", network, address)
|
||||
|
||||
host, port, err := net.SplitHostPort(address)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if up.bootstrap != nil && net.ParseIP(host) == nil {
|
||||
ip, err := up.bootstrap(host)
|
||||
if err != nil {
|
||||
address = fmt.Sprintf("%s:%s", "0.0.0.0", port)
|
||||
} else {
|
||||
address = fmt.Sprintf("%s:%s", ip.String(), port)
|
||||
}
|
||||
}
|
||||
|
||||
if up.UseSocks {
|
||||
d, _, err := up.config.GetDialerContext(&net.Dialer{
|
||||
Timeout: time.Second * time.Duration(up.config.Timeout),
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
switch network {
|
||||
case "tcp":
|
||||
return d.Dial(network, address)
|
||||
case "tcp-tls":
|
||||
conn, err := d.Dial("tcp", address)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return tls.Client(conn, &tls.Config{
|
||||
ServerName: host,
|
||||
}), nil
|
||||
}
|
||||
} else {
|
||||
var d net.Dialer
|
||||
d.Timeout = time.Second * time.Duration(up.config.Timeout)
|
||||
switch network {
|
||||
case "tcp":
|
||||
return d.Dial(network, address)
|
||||
case "tcp-tls":
|
||||
return tls.DialWithDialer(&d, "tcp", address, &tls.Config{
|
||||
ServerName: host,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
panic("wrong protocol: " + network)
|
||||
}
|
||||
|
||||
func (up *Upstream) InitConnectionPool(bootstrap func(host string) (net.IP, error)) {
|
||||
up.bootstrap = bootstrap
|
||||
|
||||
if strings.Contains(up.protocol, "http") {
|
||||
ops := []doh.ClientOption{
|
||||
doh.WithServer(up.Address),
|
||||
doh.WithBootstrap(bootstrap),
|
||||
doh.WithTimeout(time.Second * time.Duration(up.config.Timeout)),
|
||||
doh.WithLogger(up.logger),
|
||||
}
|
||||
if up.UseSocks {
|
||||
ops = append(ops, doh.WithSocksProxy(up.config.GetDialerContext))
|
||||
}
|
||||
up.dohClient = doh.NewClient(ops...)
|
||||
}
|
||||
|
||||
// 只需要启用 tcp/tcp-tls 协议的连接池
|
||||
if strings.Contains(up.protocol, "tcp") {
|
||||
maxIdleTime := time.Second * time.Duration(up.config.Timeout*10)
|
||||
timeout := time.Second * time.Duration(up.config.Timeout)
|
||||
p := net2.NewSimpleConnectionPool(net2.ConnectionOptions{
|
||||
MaxActiveConnections: int32(up.config.MaxActiveConnections),
|
||||
MaxIdleConnections: uint32(up.config.MaxIdleConnections),
|
||||
MaxIdleTime: &maxIdleTime,
|
||||
DialMaxConcurrency: 20,
|
||||
ReadTimeout: timeout,
|
||||
WriteTimeout: timeout,
|
||||
Dial: func(network, address string) (net.Conn, error) {
|
||||
dialer, err := up.conntionFactory(network, address)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
dialer.SetDeadline(time.Now().Add(timeout))
|
||||
return dialer, nil
|
||||
},
|
||||
})
|
||||
p.Register(up.protocol, up.hostAndPort)
|
||||
up.pool = p
|
||||
}
|
||||
}
|
||||
|
||||
func (up *Upstream) IsValidMsg(r *dns.Msg) bool {
|
||||
domain := GetDomainNameFromDnsMsg(r)
|
||||
inBlacklist := utils.HasMatchedRule(up.config.BlacklistSplited, domain)
|
||||
for i := 0; i < len(r.Answer); i++ {
|
||||
var ip net.IP
|
||||
typeA, ok := r.Answer[i].(*dns.A)
|
||||
if ok {
|
||||
ip = typeA.A
|
||||
} else {
|
||||
typeAAAA, ok := r.Answer[i].(*dns.AAAA)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
ip = typeAAAA.AAAA
|
||||
}
|
||||
isPrimary, err := up.ipRanger.Contains(ip)
|
||||
if err != nil {
|
||||
up.logger.Printf("ipRanger query ip %s failed: %s", ip, err)
|
||||
continue
|
||||
}
|
||||
|
||||
up.logger.Printf("checkPrimary result %s: %s@%s ->domain.inBlacklist:%v ip.IsPrimary:%v up.IsPrimary:%v", up.Address, domain, ip, inBlacklist, isPrimary, up.IsPrimary)
|
||||
|
||||
// 黑名单中的域名,如果是 primary 即不可用
|
||||
if inBlacklist && isPrimary {
|
||||
return false
|
||||
}
|
||||
// 如果是 server 是 primary,但是 ip 不是 primary,也不可用
|
||||
if up.IsPrimary && !isPrimary {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return !up.IsPrimary || len(r.Answer) > 0
|
||||
}
|
||||
|
||||
func GetDomainNameFromDnsMsg(msg *dns.Msg) string {
|
||||
if msg == nil || len(msg.Question) == 0 {
|
||||
return ""
|
||||
}
|
||||
return msg.Question[0].Name
|
||||
}
|
||||
|
||||
func (up *Upstream) poolLen() int32 {
|
||||
if up.pool == nil {
|
||||
return 0
|
||||
}
|
||||
return up.pool.NumActive()
|
||||
}
|
||||
|
||||
func (up *Upstream) Exchange(req *dns.Msg) (*dns.Msg, time.Duration, error) {
|
||||
up.logger.Printf("tracing exchange %s worker_count: %d pool_count: %d go_routine: %d --> %s", up.Address, up.count.Inc(), up.poolLen(), runtime.NumGoroutine(), "enter")
|
||||
defer up.logger.Printf("tracing exchange %s worker_count: %d pool_count: %d go_routine: %d --> %s", up.Address, up.count.Dec(), up.poolLen(), runtime.NumGoroutine(), "exit")
|
||||
|
||||
var resp *dns.Msg
|
||||
var duration time.Duration
|
||||
var err error
|
||||
|
||||
switch up.protocol {
|
||||
case "https", "http":
|
||||
resp, duration, err = up.dohClient.Exchange(req)
|
||||
case "udp":
|
||||
client := new(dns.Client)
|
||||
client.Timeout = time.Second * time.Duration(up.config.Timeout)
|
||||
resp, duration, err = client.Exchange(req, up.hostAndPort)
|
||||
case "tcp", "tcp-tls":
|
||||
conn, errGetConn := up.pool.Get(up.protocol, up.hostAndPort)
|
||||
if errGetConn != nil {
|
||||
return nil, 0, errGetConn
|
||||
}
|
||||
resp, err = dnsExchangeWithConn(conn, req)
|
||||
default:
|
||||
panic(fmt.Sprintf("invalid upstream protocol: %s in address %s", up.protocol, up.Address))
|
||||
}
|
||||
|
||||
// 清理 EDNS 信息
|
||||
if resp != nil && len(resp.Extra) > 0 {
|
||||
var newExtra []dns.RR
|
||||
for i := 0; i < len(resp.Extra); i++ {
|
||||
if resp.Extra[i].Header().Rrtype == dns.TypeOPT {
|
||||
continue
|
||||
}
|
||||
newExtra = append(newExtra, resp.Extra[i])
|
||||
}
|
||||
resp.Extra = newExtra
|
||||
}
|
||||
|
||||
return resp, duration, err
|
||||
}
|
||||
|
||||
func dnsExchangeWithConn(conn net2.ManagedConn, req *dns.Msg) (*dns.Msg, error) {
|
||||
var resp *dns.Msg
|
||||
co := dns.Conn{Conn: conn}
|
||||
err := co.WriteMsg(req)
|
||||
if err == nil {
|
||||
resp, err = co.ReadMsg()
|
||||
}
|
||||
if err == nil {
|
||||
conn.ReleaseConnection()
|
||||
} else {
|
||||
conn.DiscardConnection()
|
||||
}
|
||||
return resp, err
|
||||
}
|
||||
Reference in New Issue
Block a user