//proxy/rewriter.go package proxy import ( "bytes" "golang.org/x/net/html" "net/url" "strings" ) type ContentRewriter struct { baseURL *url.URL token string } func NewContentRewriter(baseURL, token string) (*ContentRewriter, error) { u, err := url.Parse(baseURL) if err != nil { return nil, err } return &ContentRewriter{ baseURL: u, token: token, }, nil } func (r *ContentRewriter) RewriteHTML(body []byte) ([]byte, error) { doc, err := html.Parse(bytes.NewReader(body)) if err != nil { return r.simpleRewriteHTML(body), nil } r.rewriteNode(doc) var buf bytes.Buffer if err := html.Render(&buf, doc); err != nil { return r.simpleRewriteHTML(body), nil } return buf.Bytes(), nil } func (r *ContentRewriter) rewriteNode(n *html.Node) { if n.Type == html.ElementNode { attrs := map[string]bool{"href": true, "src": true, "action": true, "data": true} for i, attr := range n.Attr { if attrs[attr.Key] { if rewritten := r.rewriteURL(attr.Val); rewritten != attr.Val { n.Attr[i].Val = rewritten } } if attr.Key == "srcset" { n.Attr[i].Val = r.rewriteSrcset(attr.Val) } if attr.Key == "style" { n.Attr[i].Val = r.rewriteInlineCSS(attr.Val) } } if n.Data == "form" { hasAction := false for i, attr := range n.Attr { if attr.Key == "action" { hasAction = true if attr.Val == "" { n.Attr[i].Val = r.rewriteURL(r.baseURL.String()) } break } } if !hasAction { n.Attr = append(n.Attr, html.Attribute{ Key: "action", Val: r.rewriteURL(r.baseURL.String()), }) } } if n.Data == "base" { for i, attr := range n.Attr { if attr.Key == "href" { n.Attr[i].Val = r.baseURL.String() } } } if n.Data == "style" && n.FirstChild != nil { if n.FirstChild.Type == html.TextNode { n.FirstChild.Data = r.rewriteInlineCSS(n.FirstChild.Data) } } if n.Data == "script" { for _, attr := range n.Attr { if attr.Key == "src" { if r.isTrackingScript(attr.Val) { if n.Parent != nil { n.Parent.RemoveChild(n) return } } } } } } for c := n.FirstChild; c != nil; c = c.NextSibling { r.rewriteNode(c) } } func (r *ContentRewriter) rewriteURL(urlStr string) string { urlStr = strings.TrimSpace(urlStr) if strings.HasPrefix(urlStr, "javascript:") || strings.HasPrefix(urlStr, "data:") || strings.HasPrefix(urlStr, "mailto:") || strings.HasPrefix(urlStr, "tel:") || strings.HasPrefix(urlStr, "#") || urlStr == "" { return urlStr } // 处理协议相对 URL(//domain.com/path) if strings.HasPrefix(urlStr, "//") { urlStr = r.baseURL.Scheme + ":" + urlStr } u, err := url.Parse(urlStr) if err != nil { return urlStr } if !u.IsAbs() { u = r.baseURL.ResolveReference(u) } // 同域名,只保留路径 if u.Host == r.baseURL.Host { proxyPath := u.Path if u.RawQuery != "" { proxyPath += "?" + u.RawQuery } if u.Fragment != "" { proxyPath += "#" + u.Fragment } return "/p/" + r.token + proxyPath } // 跨域资源,完整 URL return "/p/" + r.token + "/" + u.String() } func (r *ContentRewriter) rewriteSrcset(srcset string) string { if srcset == "" { return srcset } parts := strings.Split(srcset, ",") var rewritten []string for _, part := range parts { part = strings.TrimSpace(part) fields := strings.Fields(part) if len(fields) > 0 { fields[0] = r.rewriteURL(fields[0]) rewritten = append(rewritten, strings.Join(fields, " ")) } } return strings.Join(rewritten, ", ") } func (r *ContentRewriter) RewriteCSS(body []byte) []byte { content := string(body) return []byte(r.rewriteInlineCSS(content)) } func (r *ContentRewriter) rewriteInlineCSS(css string) string { result := css patterns := []struct { prefix string suffix string }{ {`url("`, `")`}, {`url('`, `')`}, {`url(`, `)`}, } for _, pattern := range patterns { start := 0 for { idx := strings.Index(result[start:], pattern.prefix) if idx == -1 { break } idx += start urlStart := idx + len(pattern.prefix) urlEnd := strings.Index(result[urlStart:], pattern.suffix) if urlEnd == -1 { break } urlEnd += urlStart originalURL := result[urlStart:urlEnd] rewrittenURL := r.rewriteURL(originalURL) result = result[:urlStart] + rewrittenURL + result[urlEnd:] start = urlStart + len(rewrittenURL) } } result = r.rewriteImports(result) return result } func (r *ContentRewriter) rewriteImports(css string) string { result := css patterns := []string{`@import "`, `@import '`, `@import url("`, `@import url('`} for _, pattern := range patterns { start := 0 for { idx := strings.Index(result[start:], pattern) if idx == -1 { break } idx += start urlStart := idx + len(pattern) var endChar string if strings.Contains(pattern, `"`) { endChar = `"` } else { endChar = `'` } urlEnd := strings.Index(result[urlStart:], endChar) if urlEnd == -1 { break } urlEnd += urlStart originalURL := result[urlStart:urlEnd] rewrittenURL := r.rewriteURL(originalURL) result = result[:urlStart] + rewrittenURL + result[urlEnd:] start = urlStart + len(rewrittenURL) } } return result } func (r *ContentRewriter) simpleRewriteHTML(body []byte) []byte { content := string(body) baseStr := r.baseURL.Scheme + "://" + r.baseURL.Host replacements := []struct { old string new string }{ {`href="` + baseStr, `href="/p/` + r.token}, {`src="` + baseStr, `src="/p/` + r.token}, {`action="` + baseStr, `action="/p/` + r.token}, {`href='` + baseStr, `href='/p/` + r.token}, {`src='` + baseStr, `src='/p/` + r.token}, } for _, rep := range replacements { content = strings.ReplaceAll(content, rep.old, rep.new) } return []byte(content) } func (r *ContentRewriter) isTrackingScript(src string) bool { trackingDomains := []string{ "google-analytics.com", "googletagmanager.com", "facebook.net", "doubleclick.net", "analytics.js", "ga.js", "gtag.js", } srcLower := strings.ToLower(src) for _, domain := range trackingDomains { if strings.Contains(srcLower, domain) { return true } } return false }