//proxy/rewriter.go package proxy import ( "bytes" "golang.org/x/net/html" "net/url" "strings" ) type ContentRewriter struct { baseURL *url.URL token string } func NewContentRewriter(baseURL, token string) (*ContentRewriter, error) { u, err := url.Parse(baseURL) if err != nil { return nil, err } return &ContentRewriter{ baseURL: u, token: token, }, nil } func (r *ContentRewriter) RewriteHTML(body []byte) ([]byte, error) { doc, err := html.Parse(bytes.NewReader(body)) if err != nil { return r.simpleRewriteHTML(body), nil } r.rewriteNode(doc) var buf bytes.Buffer if err := html.Render(&buf, doc); err != nil { return r.simpleRewriteHTML(body), nil } return buf.Bytes(), nil } func (r *ContentRewriter) rewriteNode(n *html.Node) { if n.Type == html.ElementNode { // 在 head 标签开始处插入 base 标签 if n.Data == "head" && n.FirstChild != nil { baseNode := &html.Node{ Type: html.ElementNode, Data: "base", Attr: []html.Attribute{ {Key: "href", Val: "/p/" + r.token + "/"}, }, } baseNode.NextSibling = n.FirstChild n.FirstChild.PrevSibling = baseNode baseNode.Parent = n n.FirstChild = baseNode } attrs := map[string]bool{"href": true, "src": true, "action": true, "data": true} for i, attr := range n.Attr { if attrs[attr.Key] { if rewritten := r.rewriteURL(attr.Val); rewritten != attr.Val { n.Attr[i].Val = rewritten } } if attr.Key == "srcset" { n.Attr[i].Val = r.rewriteSrcset(attr.Val) } if attr.Key == "style" { n.Attr[i].Val = r.rewriteInlineCSS(attr.Val) } } if n.Data == "form" { hasAction := false for i, attr := range n.Attr { if attr.Key == "action" { hasAction = true if attr.Val == "" { n.Attr[i].Val = r.rewriteURL(r.baseURL.String()) } break } } if !hasAction { n.Attr = append(n.Attr, html.Attribute{ Key: "action", Val: r.rewriteURL(r.baseURL.String()), }) } } if n.Data == "base" { for i, attr := range n.Attr { if attr.Key == "href" { n.Attr[i].Val = r.baseURL.String() } } } if n.Data == "style" && n.FirstChild != nil { if n.FirstChild.Type == html.TextNode { n.FirstChild.Data = r.rewriteInlineCSS(n.FirstChild.Data) } } if n.Data == "script" { for _, attr := range n.Attr { if attr.Key == "src" { if r.isTrackingScript(attr.Val) { if n.Parent != nil { n.Parent.RemoveChild(n) return } } } } } } for c := n.FirstChild; c != nil; c = c.NextSibling { r.rewriteNode(c) } } func (r *ContentRewriter) rewriteURL(urlStr string) string { urlStr = strings.TrimSpace(urlStr) if strings.HasPrefix(urlStr, "javascript:") || strings.HasPrefix(urlStr, "data:") || strings.HasPrefix(urlStr, "mailto:") || strings.HasPrefix(urlStr, "tel:") || strings.HasPrefix(urlStr, "#") || urlStr == "" { return urlStr } if strings.HasPrefix(urlStr, "/p/"+r.token+"/") { return urlStr } if strings.HasPrefix(urlStr, "//") { urlStr = r.baseURL.Scheme + ":" + urlStr } if strings.HasPrefix(urlStr, "/") && !strings.HasPrefix(urlStr, "//") { return "/p/" + r.token + urlStr } u, err := url.Parse(urlStr) if err != nil { return urlStr } if !u.IsAbs() { resolved := r.baseURL.ResolveReference(u) proxyPath := resolved.Path if resolved.RawQuery != "" { proxyPath += "?" + resolved.RawQuery } if resolved.Fragment != "" { proxyPath += "#" + resolved.Fragment } return "/p/" + r.token + proxyPath } if u.Host == r.baseURL.Host { proxyPath := u.Path if u.RawQuery != "" { proxyPath += "?" + u.RawQuery } if u.Fragment != "" { proxyPath += "#" + u.Fragment } return "/p/" + r.token + proxyPath } return "/p/" + r.token + "/" + u.String() } func (r *ContentRewriter) rewriteSrcset(srcset string) string { if srcset == "" { return srcset } parts := strings.Split(srcset, ",") var rewritten []string for _, part := range parts { part = strings.TrimSpace(part) fields := strings.Fields(part) if len(fields) > 0 { fields[0] = r.rewriteURL(fields[0]) rewritten = append(rewritten, strings.Join(fields, " ")) } } return strings.Join(rewritten, ", ") } func (r *ContentRewriter) RewriteCSS(body []byte) []byte { content := string(body) return []byte(r.rewriteInlineCSS(content)) } func (r *ContentRewriter) rewriteInlineCSS(css string) string { result := css patterns := []struct { prefix string suffix string }{ {`url("`, `")`}, {`url('`, `')`}, {`url(`, `)`}, } for _, pattern := range patterns { start := 0 for { idx := strings.Index(result[start:], pattern.prefix) if idx == -1 { break } idx += start urlStart := idx + len(pattern.prefix) urlEnd := strings.Index(result[urlStart:], pattern.suffix) if urlEnd == -1 { break } urlEnd += urlStart originalURL := result[urlStart:urlEnd] rewrittenURL := r.rewriteURL(originalURL) result = result[:urlStart] + rewrittenURL + result[urlEnd:] start = urlStart + len(rewrittenURL) } } result = r.rewriteImports(result) return result } func (r *ContentRewriter) rewriteImports(css string) string { result := css patterns := []string{`@import "`, `@import '`, `@import url("`, `@import url('`} for _, pattern := range patterns { start := 0 for { idx := strings.Index(result[start:], pattern) if idx == -1 { break } idx += start urlStart := idx + len(pattern) var endChar string if strings.Contains(pattern, `"`) { endChar = `"` } else { endChar = `'` } urlEnd := strings.Index(result[urlStart:], endChar) if urlEnd == -1 { break } urlEnd += urlStart originalURL := result[urlStart:urlEnd] rewrittenURL := r.rewriteURL(originalURL) result = result[:urlStart] + rewrittenURL + result[urlEnd:] start = urlStart + len(rewrittenURL) } } return result } func (r *ContentRewriter) simpleRewriteHTML(body []byte) []byte { content := string(body) baseStr := r.baseURL.Scheme + "://" + r.baseURL.Host replacements := []struct { old string new string }{ {`href="` + baseStr, `href="/p/` + r.token}, {`src="` + baseStr, `src="/p/` + r.token}, {`action="` + baseStr, `action="/p/` + r.token}, {`href='` + baseStr, `href='/p/` + r.token}, {`src='` + baseStr, `src='/p/` + r.token}, } for _, rep := range replacements { content = strings.ReplaceAll(content, rep.old, rep.new) } return []byte(content) } func (r *ContentRewriter) isTrackingScript(src string) bool { trackingDomains := []string{ "google-analytics.com", "googletagmanager.com", "facebook.net", "doubleclick.net", "analytics.js", "ga.js", "gtag.js", } srcLower := strings.ToLower(src) for _, domain := range trackingDomains { if strings.Contains(srcLower, domain) { return true } } return false }