diff --git a/proxy/rewriter-simple.go b/proxy/rewriter-simple.go new file mode 100644 index 0000000..b263ed3 --- /dev/null +++ b/proxy/rewriter-simple.go @@ -0,0 +1,275 @@ +// proxy/rewriter.go +package proxy + +import ( + "bytes" + "fmt" + "golang.org/x/net/html" + "net/url" + "strings" +) + +type ContentRewriter struct { + baseURL *url.URL + token string +} + +func NewContentRewriter(baseURL, token string) (*ContentRewriter, error) { + u, err := url.Parse(baseURL) + if err != nil { + return nil, err + } + + return &ContentRewriter{ + baseURL: u, + token: token, + }, nil +} + +func (r *ContentRewriter) RewriteHTML(body []byte) ([]byte, error) { + doc, err := html.Parse(bytes.NewReader(body)) + if err != nil { + return body, nil + } + + r.rewriteNode(doc) + + var buf bytes.Buffer + if err := html.Render(&buf, doc); err != nil { + return body, nil + } + + return buf.Bytes(), nil +} + +func (r *ContentRewriter) rewriteNode(n *html.Node) { + if n.Type == html.ElementNode { + if n.Data == "head" && n.FirstChild != nil { + script := &html.Node{ + Type: html.ElementNode, + Data: "script", + } + jsCode := fmt.Sprintf(`(function(){var t="/p/%s";var b="%s";var base=new URL(b);function r(u){if(!u||typeof u!=="string")return u;if(u.startsWith(t)||u.startsWith("data:")||u.startsWith("blob:")||u.startsWith("javascript:"))return u;if(u.startsWith("/")){return t+u}try{var a=new URL(u,b);if(a.host===base.host){return t+a.pathname+a.search+a.hash}}catch(e){}return u}var o=XMLHttpRequest.prototype.open;XMLHttpRequest.prototype.open=function(m,u){arguments[1]=r(u);return o.apply(this,arguments)};var f=window.fetch;window.fetch=function(u,opt){return f.call(this,r(u),opt)}})();`, r.token, r.baseURL.String()) + script.AppendChild(&html.Node{ + Type: html.TextNode, + Data: jsCode, + }) + script.NextSibling = n.FirstChild + n.FirstChild.PrevSibling = script + script.Parent = n + n.FirstChild = script + } + + attrs := map[string]bool{"href": true, "src": true, "action": true, "data": true} + + for i, attr := range n.Attr { + if attrs[attr.Key] { + if rewritten := r.rewriteURL(attr.Val); rewritten != attr.Val { + n.Attr[i].Val = rewritten + } + } + + if attr.Key == "srcset" { + n.Attr[i].Val = r.rewriteSrcset(attr.Val) + } + + if attr.Key == "style" { + n.Attr[i].Val = r.rewriteInlineCSS(attr.Val) + } + } + + if n.Data == "form" { + hasAction := false + for i, attr := range n.Attr { + if attr.Key == "action" { + hasAction = true + if attr.Val == "" { + n.Attr[i].Val = "/p/" + r.token + r.baseURL.Path + } + break + } + } + if !hasAction { + n.Attr = append(n.Attr, html.Attribute{ + Key: "action", + Val: "/p/" + r.token + r.baseURL.Path, + }) + } + } + + if n.Data == "style" && n.FirstChild != nil { + if n.FirstChild.Type == html.TextNode { + n.FirstChild.Data = r.rewriteInlineCSS(n.FirstChild.Data) + } + } + } + + for c := n.FirstChild; c != nil; c = c.NextSibling { + r.rewriteNode(c) + } +} + +func (r *ContentRewriter) rewriteURL(urlStr string) string { + urlStr = strings.TrimSpace(urlStr) + + if strings.HasPrefix(urlStr, "javascript:") || + strings.HasPrefix(urlStr, "data:") || + strings.HasPrefix(urlStr, "mailto:") || + strings.HasPrefix(urlStr, "tel:") || + strings.HasPrefix(urlStr, "#") || + urlStr == "" { + return urlStr + } + + if strings.HasPrefix(urlStr, "/p/"+r.token) { + return urlStr + } + + if strings.HasPrefix(urlStr, "/") && !strings.HasPrefix(urlStr, "//") { + return "/p/" + r.token + urlStr + } + + if strings.HasPrefix(urlStr, "//") { + return urlStr + } + + u, err := url.Parse(urlStr) + if err != nil { + return urlStr + } + + if u.Host == r.baseURL.Host { + path := u.Path + if u.RawQuery != "" { + path += "?" + u.RawQuery + } + if u.Fragment != "" { + path += "#" + u.Fragment + } + return "/p/" + r.token + path + } + + if !u.IsAbs() { + resolved := r.baseURL.ResolveReference(u) + if resolved.Host == r.baseURL.Host { + path := resolved.Path + if resolved.RawQuery != "" { + path += "?" + resolved.RawQuery + } + if resolved.Fragment != "" { + path += "#" + resolved.Fragment + } + return "/p/" + r.token + path + } + } + + return urlStr +} + +func (r *ContentRewriter) rewriteSrcset(srcset string) string { + if srcset == "" { + return srcset + } + + parts := strings.Split(srcset, ",") + var rewritten []string + + for _, part := range parts { + part = strings.TrimSpace(part) + fields := strings.Fields(part) + + if len(fields) > 0 { + fields[0] = r.rewriteURL(fields[0]) + rewritten = append(rewritten, strings.Join(fields, " ")) + } + } + + return strings.Join(rewritten, ", ") +} + +func (r *ContentRewriter) RewriteCSS(body []byte) []byte { + content := string(body) + return []byte(r.rewriteInlineCSS(content)) +} + +func (r *ContentRewriter) rewriteInlineCSS(css string) string { + result := css + + patterns := []struct { + prefix string + suffix string + }{ + {`url("`, `")`}, + {`url('`, `')`}, + {`url(`, `)`}, + } + + for _, pattern := range patterns { + start := 0 + for { + idx := strings.Index(result[start:], pattern.prefix) + if idx == -1 { + break + } + + idx += start + urlStart := idx + len(pattern.prefix) + urlEnd := strings.Index(result[urlStart:], pattern.suffix) + + if urlEnd == -1 { + break + } + + urlEnd += urlStart + originalURL := result[urlStart:urlEnd] + rewrittenURL := r.rewriteURL(originalURL) + + result = result[:urlStart] + rewrittenURL + result[urlEnd:] + start = urlStart + len(rewrittenURL) + } + } + + result = r.rewriteImports(result) + + return result +} + +func (r *ContentRewriter) rewriteImports(css string) string { + result := css + + patterns := []string{`@import "`, `@import '`, `@import url("`, `@import url('`} + + for _, pattern := range patterns { + start := 0 + for { + idx := strings.Index(result[start:], pattern) + if idx == -1 { + break + } + + idx += start + urlStart := idx + len(pattern) + + var endChar string + if strings.Contains(pattern, `"`) { + endChar = `"` + } else { + endChar = `'` + } + + urlEnd := strings.Index(result[urlStart:], endChar) + if urlEnd == -1 { + break + } + + urlEnd += urlStart + originalURL := result[urlStart:urlEnd] + rewrittenURL := r.rewriteURL(originalURL) + + result = result[:urlStart] + rewrittenURL + result[urlEnd:] + start = urlStart + len(rewrittenURL) + } + } + + return result +}