// proxy/rewriter.go package proxy import ( "bytes" "fmt" "golang.org/x/net/html" "net/url" "strings" ) type ContentRewriter struct { baseURL *url.URL token string } func NewContentRewriter(baseURL, token string) (*ContentRewriter, error) { u, err := url.Parse(baseURL) if err != nil { return nil, err } return &ContentRewriter{ baseURL: u, token: token, }, nil } func (r *ContentRewriter) RewriteHTML(body []byte) ([]byte, error) { doc, err := html.Parse(bytes.NewReader(body)) if err != nil { return body, nil } r.rewriteNode(doc) var buf bytes.Buffer if err := html.Render(&buf, doc); err != nil { return body, nil } return buf.Bytes(), nil } func (r *ContentRewriter) rewriteNode(n *html.Node) { if n.Type == html.ElementNode { if n.Data == "head" && n.FirstChild != nil { script := &html.Node{ Type: html.ElementNode, Data: "script", } jsCode := fmt.Sprintf(`(function(){var t="/p/%s";var b="%s";var base=new URL(b);function r(u){if(!u||typeof u!=="string")return u;if(u.startsWith(t)||u.startsWith("data:")||u.startsWith("blob:")||u.startsWith("javascript:"))return u;if(u.startsWith("/")){return t+u}try{var a=new URL(u,b);if(a.host===base.host){return t+a.pathname+a.search+a.hash}}catch(e){}return u}var o=XMLHttpRequest.prototype.open;XMLHttpRequest.prototype.open=function(m,u){arguments[1]=r(u);return o.apply(this,arguments)};var f=window.fetch;window.fetch=function(u,opt){return f.call(this,r(u),opt)}})();`, r.token, r.baseURL.String()) script.AppendChild(&html.Node{ Type: html.TextNode, Data: jsCode, }) script.NextSibling = n.FirstChild n.FirstChild.PrevSibling = script script.Parent = n n.FirstChild = script } attrs := map[string]bool{"href": true, "src": true, "action": true, "data": true} for i, attr := range n.Attr { if attrs[attr.Key] { if rewritten := r.rewriteURL(attr.Val); rewritten != attr.Val { n.Attr[i].Val = rewritten } } if attr.Key == "srcset" { n.Attr[i].Val = r.rewriteSrcset(attr.Val) } if attr.Key == "style" { n.Attr[i].Val = r.rewriteInlineCSS(attr.Val) } } if n.Data == "form" { hasAction := false for i, attr := range n.Attr { if attr.Key == "action" { hasAction = true if attr.Val == "" { n.Attr[i].Val = "/p/" + r.token + r.baseURL.Path } break } } if !hasAction { n.Attr = append(n.Attr, html.Attribute{ Key: "action", Val: "/p/" + r.token + r.baseURL.Path, }) } } if n.Data == "style" && n.FirstChild != nil { if n.FirstChild.Type == html.TextNode { n.FirstChild.Data = r.rewriteInlineCSS(n.FirstChild.Data) } } } for c := n.FirstChild; c != nil; c = c.NextSibling { r.rewriteNode(c) } } func (r *ContentRewriter) rewriteURL(urlStr string) string { urlStr = strings.TrimSpace(urlStr) if strings.HasPrefix(urlStr, "javascript:") || strings.HasPrefix(urlStr, "data:") || strings.HasPrefix(urlStr, "mailto:") || strings.HasPrefix(urlStr, "tel:") || strings.HasPrefix(urlStr, "#") || urlStr == "" { return urlStr } if strings.HasPrefix(urlStr, "/p/"+r.token) { return urlStr } if strings.HasPrefix(urlStr, "/") && !strings.HasPrefix(urlStr, "//") { return "/p/" + r.token + urlStr } if strings.HasPrefix(urlStr, "//") { return urlStr } u, err := url.Parse(urlStr) if err != nil { return urlStr } if u.Host == r.baseURL.Host { path := u.Path if u.RawQuery != "" { path += "?" + u.RawQuery } if u.Fragment != "" { path += "#" + u.Fragment } return "/p/" + r.token + path } if !u.IsAbs() { resolved := r.baseURL.ResolveReference(u) if resolved.Host == r.baseURL.Host { path := resolved.Path if resolved.RawQuery != "" { path += "?" + resolved.RawQuery } if resolved.Fragment != "" { path += "#" + resolved.Fragment } return "/p/" + r.token + path } } return urlStr } func (r *ContentRewriter) rewriteSrcset(srcset string) string { if srcset == "" { return srcset } parts := strings.Split(srcset, ",") var rewritten []string for _, part := range parts { part = strings.TrimSpace(part) fields := strings.Fields(part) if len(fields) > 0 { fields[0] = r.rewriteURL(fields[0]) rewritten = append(rewritten, strings.Join(fields, " ")) } } return strings.Join(rewritten, ", ") } func (r *ContentRewriter) RewriteCSS(body []byte) []byte { content := string(body) return []byte(r.rewriteInlineCSS(content)) } func (r *ContentRewriter) rewriteInlineCSS(css string) string { result := css patterns := []struct { prefix string suffix string }{ {`url("`, `")`}, {`url('`, `')`}, {`url(`, `)`}, } for _, pattern := range patterns { start := 0 for { idx := strings.Index(result[start:], pattern.prefix) if idx == -1 { break } idx += start urlStart := idx + len(pattern.prefix) urlEnd := strings.Index(result[urlStart:], pattern.suffix) if urlEnd == -1 { break } urlEnd += urlStart originalURL := result[urlStart:urlEnd] rewrittenURL := r.rewriteURL(originalURL) result = result[:urlStart] + rewrittenURL + result[urlEnd:] start = urlStart + len(rewrittenURL) } } result = r.rewriteImports(result) return result } func (r *ContentRewriter) rewriteImports(css string) string { result := css patterns := []string{`@import "`, `@import '`, `@import url("`, `@import url('`} for _, pattern := range patterns { start := 0 for { idx := strings.Index(result[start:], pattern) if idx == -1 { break } idx += start urlStart := idx + len(pattern) var endChar string if strings.Contains(pattern, `"`) { endChar = `"` } else { endChar = `'` } urlEnd := strings.Index(result[urlStart:], endChar) if urlEnd == -1 { break } urlEnd += urlStart originalURL := result[urlStart:urlEnd] rewrittenURL := r.rewriteURL(originalURL) result = result[:urlStart] + rewrittenURL + result[urlEnd:] start = urlStart + len(rewrittenURL) } } return result }