update once
This commit is contained in:
173
pkg/doh/client.go
Normal file
173
pkg/doh/client.go
Normal file
@@ -0,0 +1,173 @@
|
||||
package doh
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httptrace"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
"github.com/pkg/errors"
|
||||
"golang.org/x/net/proxy"
|
||||
)
|
||||
|
||||
const (
|
||||
dohMediaType = "application/dns-message"
|
||||
)
|
||||
|
||||
// Logger 定义可选的日志接口
|
||||
type Logger interface {
|
||||
Printf(format string, v ...interface{})
|
||||
}
|
||||
|
||||
type clientOptions struct {
|
||||
timeout time.Duration
|
||||
server string
|
||||
bootstrap func(domain string) (net.IP, error)
|
||||
getDialer func(d *net.Dialer) (proxy.Dialer, proxy.ContextDialer, error)
|
||||
logger Logger
|
||||
}
|
||||
|
||||
type ClientOption func(*clientOptions) error
|
||||
|
||||
func WithTimeout(t time.Duration) ClientOption {
|
||||
return func(o *clientOptions) error {
|
||||
o.timeout = t
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func WithSocksProxy(getDialer func(d *net.Dialer) (proxy.Dialer, proxy.ContextDialer, error)) ClientOption {
|
||||
return func(o *clientOptions) error {
|
||||
o.getDialer = getDialer
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func WithServer(server string) ClientOption {
|
||||
return func(o *clientOptions) error {
|
||||
o.server = server
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func WithBootstrap(resolver func(domain string) (net.IP, error)) ClientOption {
|
||||
return func(o *clientOptions) error {
|
||||
o.bootstrap = resolver
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func WithLogger(logger Logger) ClientOption {
|
||||
return func(o *clientOptions) error {
|
||||
o.logger = logger
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
type Client struct {
|
||||
opt *clientOptions
|
||||
cli *http.Client
|
||||
traceCtx context.Context
|
||||
}
|
||||
|
||||
func NewClient(opts ...ClientOption) *Client {
|
||||
o := new(clientOptions)
|
||||
for _, f := range opts {
|
||||
f(o)
|
||||
}
|
||||
|
||||
clientTrace := &httptrace.ClientTrace{
|
||||
GotConn: func(info httptrace.GotConnInfo) {
|
||||
if o.logger != nil {
|
||||
o.logger.Printf("http conn was reused: %t", info.Reused)
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
var transport *http.Transport
|
||||
|
||||
if o.bootstrap != nil {
|
||||
transport = &http.Transport{
|
||||
DialContext: func(ctx context.Context, network, address string) (net.Conn, error) {
|
||||
urls := strings.Split(address, ":")
|
||||
ipv4, err := o.bootstrap(urls[0])
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "bootstrap")
|
||||
}
|
||||
urls[0] = ipv4.String()
|
||||
|
||||
if o.getDialer != nil {
|
||||
dialer, _, err := o.getDialer(&net.Dialer{
|
||||
Timeout: o.timeout,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return dialer.Dial("tcp", strings.Join(urls, ":"))
|
||||
}
|
||||
|
||||
return (&net.Dialer{
|
||||
Timeout: o.timeout,
|
||||
}).DialContext(ctx, network, strings.Join(urls, ":"))
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
return &Client{
|
||||
opt: o,
|
||||
traceCtx: httptrace.WithClientTrace(context.Background(), clientTrace),
|
||||
cli: &http.Client{
|
||||
Transport: transport,
|
||||
Timeout: o.timeout,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Client) Exchange(req *dns.Msg) (r *dns.Msg, rtt time.Duration, err error) {
|
||||
var (
|
||||
buf []byte
|
||||
begin = time.Now()
|
||||
origID = req.Id
|
||||
hreq *http.Request
|
||||
)
|
||||
|
||||
// Set DNS ID as zero accoreding to RFC8484 (cache friendly)
|
||||
req.Id = 0
|
||||
buf, err = req.Pack()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
hreq, err = http.NewRequestWithContext(c.traceCtx, http.MethodGet, c.opt.server+"?dns="+base64.RawURLEncoding.EncodeToString(buf), nil)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
hreq.Header.Add("Accept", dohMediaType)
|
||||
hreq.Header.Add("User-Agent", "godns-doh-client/0.1")
|
||||
|
||||
resp, err := c.cli.Do(hreq)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
content, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
err = errors.New("DoH query failed: " + string(content))
|
||||
return
|
||||
}
|
||||
|
||||
r = new(dns.Msg)
|
||||
err = r.Unpack(content)
|
||||
r.Id = origID
|
||||
rtt = time.Since(begin)
|
||||
return
|
||||
}
|
||||
132
pkg/doh/server.go
Normal file
132
pkg/doh/server.go
Normal file
@@ -0,0 +1,132 @@
|
||||
package doh
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"net"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
"godns/internal/stats"
|
||||
)
|
||||
|
||||
type DoHServer struct {
|
||||
username, password string
|
||||
handler func(req *dns.Msg, clientIP, domain string) *dns.Msg
|
||||
stats stats.StatsRecorder
|
||||
}
|
||||
|
||||
func NewServer(username, password string, handler func(req *dns.Msg, clientIP, domain string) *dns.Msg, statsRecorder stats.StatsRecorder) *DoHServer {
|
||||
return &DoHServer{
|
||||
username: username,
|
||||
password: password,
|
||||
handler: handler,
|
||||
stats: statsRecorder,
|
||||
}
|
||||
}
|
||||
|
||||
// RegisterRoutes 注册 DoH 路由到现有的 HTTP 服务器
|
||||
func (s *DoHServer) RegisterRoutes(mux *http.ServeMux) {
|
||||
mux.HandleFunc("/dns-query", s.handleQuery)
|
||||
}
|
||||
|
||||
func (s *DoHServer) handleQuery(w http.ResponseWriter, r *http.Request) {
|
||||
if s.username != "" && s.password != "" {
|
||||
username, password, ok := r.BasicAuth()
|
||||
if !ok || username != s.username || password != s.password {
|
||||
w.Header().Set("WWW-Authenticate", `Basic realm="dns"`)
|
||||
w.WriteHeader(http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
accept := r.Header.Get("Accept")
|
||||
if accept != dohMediaType {
|
||||
w.WriteHeader(http.StatusUnsupportedMediaType)
|
||||
w.Write([]byte("unsupported media type: " + accept))
|
||||
return
|
||||
}
|
||||
|
||||
query := r.URL.Query().Get("dns")
|
||||
if query == "" {
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
data, err := base64.RawURLEncoding.DecodeString(query)
|
||||
if err != nil {
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
w.Write([]byte(err.Error()))
|
||||
return
|
||||
}
|
||||
|
||||
msg := new(dns.Msg)
|
||||
if err := msg.Unpack(data); err != nil {
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
w.Write([]byte(err.Error()))
|
||||
return
|
||||
}
|
||||
|
||||
// 记录 DoH 查询统计
|
||||
if s.stats != nil {
|
||||
s.stats.RecordDoHQuery()
|
||||
}
|
||||
|
||||
// 提取客户端 IP
|
||||
clientIP := extractClientIP(r)
|
||||
|
||||
// 提取域名
|
||||
var domain string
|
||||
if len(msg.Question) > 0 {
|
||||
domain = msg.Question[0].Name
|
||||
}
|
||||
|
||||
resp := s.handler(msg, clientIP, domain)
|
||||
if resp == nil {
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
w.Write([]byte("nil response"))
|
||||
return
|
||||
}
|
||||
|
||||
data, err = resp.Pack()
|
||||
if err != nil {
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
w.Write([]byte(err.Error()))
|
||||
return
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", dohMediaType)
|
||||
w.Write(data)
|
||||
}
|
||||
|
||||
// extractClientIP 从 HTTP 请求中提取真实的客户端 IP
|
||||
func extractClientIP(r *http.Request) string {
|
||||
// 1. 优先检查 X-Forwarded-For(适用于多层代理)
|
||||
if xff := r.Header.Get("X-Forwarded-For"); xff != "" {
|
||||
// X-Forwarded-For 格式: client, proxy1, proxy2
|
||||
// 取第一个 IP(最原始的客户端 IP)
|
||||
parts := strings.Split(xff, ",")
|
||||
if len(parts) > 0 {
|
||||
clientIP := strings.TrimSpace(parts[0])
|
||||
// 验证是否为有效 IP
|
||||
if ip := net.ParseIP(clientIP); ip != nil {
|
||||
return clientIP
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 2. 检查 X-Real-IP(单层代理常用)
|
||||
if xri := r.Header.Get("X-Real-IP"); xri != "" {
|
||||
if ip := net.ParseIP(xri); ip != nil {
|
||||
return xri
|
||||
}
|
||||
}
|
||||
|
||||
// 3. 使用 RemoteAddr,需要去掉端口号
|
||||
if host, _, err := net.SplitHostPort(r.RemoteAddr); err == nil {
|
||||
return host
|
||||
}
|
||||
|
||||
// 4. 如果无法解析端口,直接返回(可能已经是纯 IP)
|
||||
return r.RemoteAddr
|
||||
}
|
||||
37
pkg/logger/logger.go
Normal file
37
pkg/logger/logger.go
Normal file
@@ -0,0 +1,37 @@
|
||||
package logger
|
||||
|
||||
import (
|
||||
"log"
|
||||
"os"
|
||||
)
|
||||
|
||||
// Logger 定义日志接口
|
||||
type Logger interface {
|
||||
Printf(format string, v ...interface{})
|
||||
Println(v ...interface{})
|
||||
}
|
||||
|
||||
// DebugLogger 实现 Logger 接口,支持调试模式
|
||||
type DebugLogger struct {
|
||||
Debug bool
|
||||
}
|
||||
|
||||
// New 创建新的日志实例
|
||||
func New(debug bool) Logger {
|
||||
if !debug {
|
||||
log.SetOutput(os.Stdout)
|
||||
}
|
||||
return &DebugLogger{Debug: debug}
|
||||
}
|
||||
|
||||
func (l *DebugLogger) Printf(format string, v ...interface{}) {
|
||||
if l.Debug {
|
||||
log.Printf(format, v...)
|
||||
}
|
||||
}
|
||||
|
||||
func (l *DebugLogger) Println(v ...interface{}) {
|
||||
if l.Debug {
|
||||
log.Println(v...)
|
||||
}
|
||||
}
|
||||
175
pkg/qqwry/qqwry.go
Normal file
175
pkg/qqwry/qqwry.go
Normal file
@@ -0,0 +1,175 @@
|
||||
package qqwry
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"io/ioutil"
|
||||
"net"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"golang.org/x/text/encoding/simplifiedchinese"
|
||||
"golang.org/x/text/transform"
|
||||
)
|
||||
|
||||
var (
|
||||
data []byte
|
||||
dataLen uint32
|
||||
ipCache *sync.Map
|
||||
)
|
||||
|
||||
const (
|
||||
indexLen = 7
|
||||
redirectMode1 = 0x01
|
||||
redirectMode2 = 0x02
|
||||
)
|
||||
|
||||
type cache struct {
|
||||
City string
|
||||
Isp string
|
||||
}
|
||||
|
||||
func byte3ToUInt32(data []byte) uint32 {
|
||||
i := uint32(data[0]) & 0xff
|
||||
i |= (uint32(data[1]) << 8) & 0xff00
|
||||
i |= (uint32(data[2]) << 16) & 0xff0000
|
||||
return i
|
||||
}
|
||||
|
||||
func gb18030Decode(src []byte) string {
|
||||
in := bytes.NewReader(src)
|
||||
out := transform.NewReader(in, simplifiedchinese.GB18030.NewDecoder())
|
||||
d, _ := ioutil.ReadAll(out)
|
||||
return string(d)
|
||||
}
|
||||
|
||||
// QueryIP 从内存或缓存查询IP
|
||||
func QueryIP(ip net.IP) (city string, isp string, err error) {
|
||||
ip32 := binary.BigEndian.Uint32(ip)
|
||||
|
||||
if ipCache != nil {
|
||||
if v, ok := ipCache.Load(ip32); ok {
|
||||
city = v.(cache).City
|
||||
isp = v.(cache).Isp
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
posA := binary.LittleEndian.Uint32(data[:4])
|
||||
posZ := binary.LittleEndian.Uint32(data[4:8])
|
||||
var offset uint32 = 0
|
||||
for {
|
||||
mid := posA + (((posZ-posA)/indexLen)>>1)*indexLen
|
||||
buf := data[mid : mid+indexLen]
|
||||
_ip := binary.LittleEndian.Uint32(buf[:4])
|
||||
if posZ-posA == indexLen {
|
||||
offset = byte3ToUInt32(buf[4:])
|
||||
buf = data[mid+indexLen : mid+indexLen+indexLen]
|
||||
if ip32 < binary.LittleEndian.Uint32(buf[:4]) {
|
||||
break
|
||||
} else {
|
||||
offset = 0
|
||||
break
|
||||
}
|
||||
}
|
||||
if _ip > ip32 {
|
||||
posZ = mid
|
||||
} else if _ip < ip32 {
|
||||
posA = mid
|
||||
} else if _ip == ip32 {
|
||||
offset = byte3ToUInt32(buf[4:])
|
||||
break
|
||||
}
|
||||
}
|
||||
if offset <= 0 {
|
||||
err = errors.New("ip not found")
|
||||
return
|
||||
}
|
||||
posM := offset + 4
|
||||
mode := data[posM]
|
||||
var ispPos uint32
|
||||
switch mode {
|
||||
case redirectMode1:
|
||||
posC := byte3ToUInt32(data[posM+1 : posM+4])
|
||||
mode = data[posC]
|
||||
posCA := posC
|
||||
if mode == redirectMode2 {
|
||||
posCA = byte3ToUInt32(data[posC+1 : posC+4])
|
||||
posC += 4
|
||||
}
|
||||
for i := posCA; i < dataLen; i++ {
|
||||
if data[i] == 0 {
|
||||
city = string(data[posCA:i])
|
||||
break
|
||||
}
|
||||
}
|
||||
if mode != redirectMode2 {
|
||||
posC += uint32(len(city) + 1)
|
||||
}
|
||||
ispPos = posC
|
||||
case redirectMode2:
|
||||
posCA := byte3ToUInt32(data[posM+1 : posM+4])
|
||||
for i := posCA; i < dataLen; i++ {
|
||||
if data[i] == 0 {
|
||||
city = string(data[posCA:i])
|
||||
break
|
||||
}
|
||||
}
|
||||
ispPos = offset + 8
|
||||
default:
|
||||
posCA := offset + 4
|
||||
for i := posCA; i < dataLen; i++ {
|
||||
if data[i] == 0 {
|
||||
city = string(data[posCA:i])
|
||||
break
|
||||
}
|
||||
}
|
||||
ispPos = offset + uint32(5+len(city))
|
||||
}
|
||||
if city != "" {
|
||||
city = strings.TrimSpace(gb18030Decode([]byte(city)))
|
||||
}
|
||||
ispMode := data[ispPos]
|
||||
if ispMode == redirectMode1 || ispMode == redirectMode2 {
|
||||
ispPos = byte3ToUInt32(data[ispPos+1 : ispPos+4])
|
||||
}
|
||||
if ispPos > 0 {
|
||||
for i := ispPos; i < dataLen; i++ {
|
||||
if data[i] == 0 {
|
||||
isp = string(data[ispPos:i])
|
||||
if isp != "" {
|
||||
if strings.Contains(isp, "CZ88.NET") {
|
||||
isp = ""
|
||||
} else {
|
||||
isp = strings.TrimSpace(gb18030Decode([]byte(isp)))
|
||||
}
|
||||
}
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
if ipCache != nil {
|
||||
ipCache.Store(ip32, cache{City: city, Isp: isp})
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// LoadData 从内存加载IP数据库
|
||||
func LoadData(database []byte) {
|
||||
data = database
|
||||
dataLen = uint32(len(data))
|
||||
}
|
||||
|
||||
// LoadFile 从文件加载IP数据库
|
||||
func LoadFile(filepath string, useCache bool) (err error) {
|
||||
data, err = ioutil.ReadFile(filepath)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
dataLen = uint32(len(data))
|
||||
if useCache {
|
||||
ipCache = new(sync.Map)
|
||||
}
|
||||
return
|
||||
}
|
||||
45
pkg/utils/utils.go
Normal file
45
pkg/utils/utils.go
Normal file
@@ -0,0 +1,45 @@
|
||||
package utils
|
||||
|
||||
import "strings"
|
||||
|
||||
func ParseRules(rulesRaw []string) [][]string {
|
||||
var rules [][]string
|
||||
for _, r := range rulesRaw {
|
||||
if r == "" {
|
||||
continue
|
||||
}
|
||||
if !strings.HasSuffix(r, ".") {
|
||||
r += "."
|
||||
}
|
||||
rules = append(rules, strings.Split(r, "."))
|
||||
}
|
||||
return rules
|
||||
}
|
||||
|
||||
func HasMatchedRule(rules [][]string, domain string) bool {
|
||||
var hasMatch bool
|
||||
OUTER:
|
||||
for _, m := range rules {
|
||||
domainSplited := strings.Split(domain, ".")
|
||||
i := len(m) - 1
|
||||
j := len(domainSplited) - 1
|
||||
// 从根域名开始匹配
|
||||
for i >= 0 && j >= 0 {
|
||||
if m[i] != domainSplited[j] && m[i] != "" {
|
||||
continue OUTER
|
||||
}
|
||||
i--
|
||||
j--
|
||||
}
|
||||
// 如果规则中还有剩余,但是域名已经匹配完了,检查规则最后一位是否是任意匹配
|
||||
if j != -1 && i == -1 && m[0] != "" {
|
||||
continue OUTER
|
||||
}
|
||||
hasMatch = i == -1
|
||||
// 如果匹配到了,就不用再匹配了
|
||||
if hasMatch {
|
||||
break
|
||||
}
|
||||
}
|
||||
return hasMatch
|
||||
}
|
||||
Reference in New Issue
Block a user