update once

This commit is contained in:
XOF
2026-01-06 02:25:24 +08:00
commit 7bf4f27be3
25 changed files with 4587 additions and 0 deletions

173
pkg/doh/client.go Normal file
View File

@@ -0,0 +1,173 @@
package doh
import (
"context"
"encoding/base64"
"io"
"net"
"net/http"
"net/http/httptrace"
"strings"
"time"
"github.com/miekg/dns"
"github.com/pkg/errors"
"golang.org/x/net/proxy"
)
const (
dohMediaType = "application/dns-message"
)
// Logger 定义可选的日志接口
type Logger interface {
Printf(format string, v ...interface{})
}
type clientOptions struct {
timeout time.Duration
server string
bootstrap func(domain string) (net.IP, error)
getDialer func(d *net.Dialer) (proxy.Dialer, proxy.ContextDialer, error)
logger Logger
}
type ClientOption func(*clientOptions) error
func WithTimeout(t time.Duration) ClientOption {
return func(o *clientOptions) error {
o.timeout = t
return nil
}
}
func WithSocksProxy(getDialer func(d *net.Dialer) (proxy.Dialer, proxy.ContextDialer, error)) ClientOption {
return func(o *clientOptions) error {
o.getDialer = getDialer
return nil
}
}
func WithServer(server string) ClientOption {
return func(o *clientOptions) error {
o.server = server
return nil
}
}
func WithBootstrap(resolver func(domain string) (net.IP, error)) ClientOption {
return func(o *clientOptions) error {
o.bootstrap = resolver
return nil
}
}
func WithLogger(logger Logger) ClientOption {
return func(o *clientOptions) error {
o.logger = logger
return nil
}
}
type Client struct {
opt *clientOptions
cli *http.Client
traceCtx context.Context
}
func NewClient(opts ...ClientOption) *Client {
o := new(clientOptions)
for _, f := range opts {
f(o)
}
clientTrace := &httptrace.ClientTrace{
GotConn: func(info httptrace.GotConnInfo) {
if o.logger != nil {
o.logger.Printf("http conn was reused: %t", info.Reused)
}
},
}
var transport *http.Transport
if o.bootstrap != nil {
transport = &http.Transport{
DialContext: func(ctx context.Context, network, address string) (net.Conn, error) {
urls := strings.Split(address, ":")
ipv4, err := o.bootstrap(urls[0])
if err != nil {
return nil, errors.Wrap(err, "bootstrap")
}
urls[0] = ipv4.String()
if o.getDialer != nil {
dialer, _, err := o.getDialer(&net.Dialer{
Timeout: o.timeout,
})
if err != nil {
return nil, err
}
return dialer.Dial("tcp", strings.Join(urls, ":"))
}
return (&net.Dialer{
Timeout: o.timeout,
}).DialContext(ctx, network, strings.Join(urls, ":"))
},
}
}
return &Client{
opt: o,
traceCtx: httptrace.WithClientTrace(context.Background(), clientTrace),
cli: &http.Client{
Transport: transport,
Timeout: o.timeout,
},
}
}
func (c *Client) Exchange(req *dns.Msg) (r *dns.Msg, rtt time.Duration, err error) {
var (
buf []byte
begin = time.Now()
origID = req.Id
hreq *http.Request
)
// Set DNS ID as zero accoreding to RFC8484 (cache friendly)
req.Id = 0
buf, err = req.Pack()
if err != nil {
return
}
hreq, err = http.NewRequestWithContext(c.traceCtx, http.MethodGet, c.opt.server+"?dns="+base64.RawURLEncoding.EncodeToString(buf), nil)
if err != nil {
return
}
hreq.Header.Add("Accept", dohMediaType)
hreq.Header.Add("User-Agent", "godns-doh-client/0.1")
resp, err := c.cli.Do(hreq)
if err != nil {
return
}
defer resp.Body.Close()
content, err := io.ReadAll(resp.Body)
if err != nil {
return
}
if resp.StatusCode != http.StatusOK {
err = errors.New("DoH query failed: " + string(content))
return
}
r = new(dns.Msg)
err = r.Unpack(content)
r.Id = origID
rtt = time.Since(begin)
return
}

132
pkg/doh/server.go Normal file
View File

@@ -0,0 +1,132 @@
package doh
import (
"encoding/base64"
"net"
"net/http"
"strings"
"github.com/miekg/dns"
"godns/internal/stats"
)
type DoHServer struct {
username, password string
handler func(req *dns.Msg, clientIP, domain string) *dns.Msg
stats stats.StatsRecorder
}
func NewServer(username, password string, handler func(req *dns.Msg, clientIP, domain string) *dns.Msg, statsRecorder stats.StatsRecorder) *DoHServer {
return &DoHServer{
username: username,
password: password,
handler: handler,
stats: statsRecorder,
}
}
// RegisterRoutes 注册 DoH 路由到现有的 HTTP 服务器
func (s *DoHServer) RegisterRoutes(mux *http.ServeMux) {
mux.HandleFunc("/dns-query", s.handleQuery)
}
func (s *DoHServer) handleQuery(w http.ResponseWriter, r *http.Request) {
if s.username != "" && s.password != "" {
username, password, ok := r.BasicAuth()
if !ok || username != s.username || password != s.password {
w.Header().Set("WWW-Authenticate", `Basic realm="dns"`)
w.WriteHeader(http.StatusUnauthorized)
return
}
}
accept := r.Header.Get("Accept")
if accept != dohMediaType {
w.WriteHeader(http.StatusUnsupportedMediaType)
w.Write([]byte("unsupported media type: " + accept))
return
}
query := r.URL.Query().Get("dns")
if query == "" {
w.WriteHeader(http.StatusBadRequest)
return
}
data, err := base64.RawURLEncoding.DecodeString(query)
if err != nil {
w.WriteHeader(http.StatusBadRequest)
w.Write([]byte(err.Error()))
return
}
msg := new(dns.Msg)
if err := msg.Unpack(data); err != nil {
w.WriteHeader(http.StatusBadRequest)
w.Write([]byte(err.Error()))
return
}
// 记录 DoH 查询统计
if s.stats != nil {
s.stats.RecordDoHQuery()
}
// 提取客户端 IP
clientIP := extractClientIP(r)
// 提取域名
var domain string
if len(msg.Question) > 0 {
domain = msg.Question[0].Name
}
resp := s.handler(msg, clientIP, domain)
if resp == nil {
w.WriteHeader(http.StatusInternalServerError)
w.Write([]byte("nil response"))
return
}
data, err = resp.Pack()
if err != nil {
w.WriteHeader(http.StatusInternalServerError)
w.Write([]byte(err.Error()))
return
}
w.Header().Set("Content-Type", dohMediaType)
w.Write(data)
}
// extractClientIP 从 HTTP 请求中提取真实的客户端 IP
func extractClientIP(r *http.Request) string {
// 1. 优先检查 X-Forwarded-For适用于多层代理
if xff := r.Header.Get("X-Forwarded-For"); xff != "" {
// X-Forwarded-For 格式: client, proxy1, proxy2
// 取第一个 IP最原始的客户端 IP
parts := strings.Split(xff, ",")
if len(parts) > 0 {
clientIP := strings.TrimSpace(parts[0])
// 验证是否为有效 IP
if ip := net.ParseIP(clientIP); ip != nil {
return clientIP
}
}
}
// 2. 检查 X-Real-IP单层代理常用
if xri := r.Header.Get("X-Real-IP"); xri != "" {
if ip := net.ParseIP(xri); ip != nil {
return xri
}
}
// 3. 使用 RemoteAddr需要去掉端口号
if host, _, err := net.SplitHostPort(r.RemoteAddr); err == nil {
return host
}
// 4. 如果无法解析端口,直接返回(可能已经是纯 IP
return r.RemoteAddr
}