Files
SiteProxy/proxy/rewriter.go
2025-12-15 18:04:17 +08:00

344 lines
9.9 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
//proxy/rewriter.go
package proxy
import (
"bytes"
"golang.org/x/net/html"
"net/url"
"strings"
)
// ContentRewriter rewrites URLs found in proxied HTML and CSS so that they
// point back through the proxy's "/p/<token>" path prefix.
type ContentRewriter struct {
// baseURL is the parsed URL that relative references are resolved against
// and whose host is compared to decide whether a URL is same-host.
baseURL *url.URL
// token is the path segment placed after "/p/" in every rewritten URL.
token string
}
// NewContentRewriter builds a ContentRewriter for the given upstream base URL
// and proxy token.
//
// baseURL is parsed with url.Parse; if parsing fails the parse error is
// returned unchanged together with a nil rewriter.
func NewContentRewriter(baseURL, token string) (*ContentRewriter, error) {
	parsed, err := url.Parse(baseURL)
	if err != nil {
		return nil, err
	}

	rw := new(ContentRewriter)
	rw.baseURL = parsed
	rw.token = token
	return rw, nil
}
// RewriteHTML parses body as an HTML document, rewrites it in place via
// rewriteNode and returns the re-rendered markup.
//
// If either parsing or rendering fails, the method falls back to the plain
// string substitution done by simpleRewriteHTML on the original bytes. The
// returned error is always nil.
func (r *ContentRewriter) RewriteHTML(body []byte) ([]byte, error) {
	root, parseErr := html.Parse(bytes.NewReader(body))
	if parseErr == nil {
		r.rewriteNode(root)

		out := new(bytes.Buffer)
		if renderErr := html.Render(out, root); renderErr == nil {
			return out.Bytes(), nil
		}
	}

	// Best effort: the tree-based path failed, so patch attributes textually.
	return r.simpleRewriteHTML(body), nil
}
// rewriteNode recursively rewrites the HTML subtree rooted at n in place.
//
// For every element node it:
//   - injects a request-interception script as the first child of <head>;
//   - rewrites href/src/action/data attributes with rewriteURL, srcset with
//     rewriteSrcset and style with rewriteInlineCSS;
//   - gives every <form> a non-empty action pointing at the proxied base URL;
//   - sets <base href> to the base URL;
//   - rewrites the text of <style> elements;
//   - removes <script> elements whose src is reported by isTrackingScript.
//
// Children are then visited in document order.
func (r *ContentRewriter) rewriteNode(n *html.Node) {
	if n.Type == html.ElementNode {
		// Inject the request-interception script at the top of <head> so it
		// runs before any of the page's own scripts.
		if n.Data == "head" {
			script := &html.Node{
				Type: html.ElementNode,
				Data: "script",
			}
			script.AppendChild(&html.Node{
				Type: html.TextNode,
				Data: `(function(){var t="/p/` + r.token + `";var o=XMLHttpRequest.prototype.open;XMLHttpRequest.prototype.open=function(m,u){if(typeof u==="string"&&u.startsWith("/")){arguments[1]=t+u}return o.apply(this,arguments)};var f=window.fetch;window.fetch=function(u,opt){if(typeof u==="string"&&u.startsWith("/")){u=t+u}return f.call(this,u,opt)};var U=window.URL;window.URL=function(u,base){if(typeof u==="string"&&u.startsWith("/")){u=t+u}return new U(u,base)};Object.defineProperty(window.location,"pathname",{get:function(){var p=window.location.href.split(window.location.host)[1]||"/";return p.startsWith(t)?p.substring(t.length):p}});var oa=Element.prototype.setAttribute;Element.prototype.setAttribute=function(n,v){if((n==="href"||n==="src"||n==="action")&&typeof v==="string"&&v.startsWith("/")){v=t+v}return oa.call(this,n,v)}})();`,
			})
			// InsertBefore appends when the reference child is nil, so an
			// empty <head> receives the script as well.
			n.InsertBefore(script, n.FirstChild)
		}

		for i, attr := range n.Attr {
			switch attr.Key {
			case "href", "src", "action", "data":
				n.Attr[i].Val = r.rewriteURL(attr.Val)
			case "srcset":
				n.Attr[i].Val = r.rewriteSrcset(attr.Val)
			case "style":
				n.Attr[i].Val = r.rewriteInlineCSS(attr.Val)
			}
		}

		if n.Data == "form" {
			// A form without an action (or with an empty one) submits to the
			// current document, so point it at the proxied base URL.
			hasAction := false
			for i, attr := range n.Attr {
				if attr.Key == "action" {
					hasAction = true
					if attr.Val == "" {
						n.Attr[i].Val = r.rewriteURL(r.baseURL.String())
					}
					break
				}
			}
			if !hasAction {
				n.Attr = append(n.Attr, html.Attribute{
					Key: "action",
					Val: r.rewriteURL(r.baseURL.String()),
				})
			}
		}

		if n.Data == "base" {
			// Overrides whatever the generic attribute pass wrote to href.
			for i, attr := range n.Attr {
				if attr.Key == "href" {
					n.Attr[i].Val = r.baseURL.String()
				}
			}
		}

		if n.Data == "style" && n.FirstChild != nil {
			if n.FirstChild.Type == html.TextNode {
				n.FirstChild.Data = r.rewriteInlineCSS(n.FirstChild.Data)
			}
		}

		if n.Data == "script" {
			// attr.Val has already been rewritten by the attribute pass above.
			for _, attr := range n.Attr {
				if attr.Key == "src" && r.isTrackingScript(attr.Val) && n.Parent != nil {
					n.Parent.RemoveChild(n)
					return
				}
			}
		}
	}

	// Capture the next sibling before recursing: a child that removes itself
	// (tracking script) has its NextSibling cleared by RemoveChild, which
	// would otherwise end the walk and leave later siblings unrewritten.
	for c := n.FirstChild; c != nil; {
		next := c.NextSibling
		r.rewriteNode(c)
		c = next
	}
}
// rewriteURL maps a URL found in proxied content onto the proxy's
// "/p/<token>" namespace and returns the rewritten string.
//
// Surrounding whitespace is always trimmed. The (trimmed) input is returned
// unchanged when it is empty, a pure fragment ("#..."), already starts with
// "/p/<token>/", uses a non-HTTP(S) scheme (javascript:, data:, mailto:,
// tel:, blob:, ...; matched case-insensitively), or cannot be parsed.
//
// Otherwise:
//   - root-relative paths ("/x") become "/p/<token>/x";
//   - protocol-relative URLs ("//host/x") inherit the base URL's scheme;
//   - relative references are resolved against the base URL;
//   - URLs on the base host are reduced to path, query and fragment and
//     prefixed with "/p/<token>";
//   - URLs on any other host become "/p/<token>/<absolute URL>".
func (r *ContentRewriter) rewriteURL(urlStr string) string {
	urlStr = strings.TrimSpace(urlStr)
	if urlStr == "" || strings.HasPrefix(urlStr, "#") {
		return urlStr
	}

	// Cheap pre-parse check so that payloads url.Parse might reject (for
	// example arbitrary JavaScript) are still passed through untouched.
	for _, scheme := range []string{"javascript:", "data:", "mailto:", "tel:"} {
		if len(urlStr) >= len(scheme) && strings.EqualFold(urlStr[:len(scheme)], scheme) {
			return urlStr
		}
	}

	prefix := "/p/" + r.token
	if strings.HasPrefix(urlStr, prefix+"/") {
		return urlStr // already proxied
	}

	if strings.HasPrefix(urlStr, "//") {
		// Protocol-relative: borrow the scheme of the page being proxied.
		urlStr = r.baseURL.Scheme + ":" + urlStr
	} else if strings.HasPrefix(urlStr, "/") {
		return prefix + urlStr
	}

	u, err := url.Parse(urlStr)
	if err != nil {
		return urlStr
	}

	// url.Parse lower-cases the scheme. Anything that is not HTTP(S) cannot
	// be fetched through the proxy, so leave it alone.
	if u.IsAbs() && u.Scheme != "http" && u.Scheme != "https" {
		return urlStr
	}

	target := u
	if !u.IsAbs() {
		target = r.baseURL.ResolveReference(u)
	}

	if strings.EqualFold(target.Host, r.baseURL.Host) {
		// Same host: keep only path, query and fragment. The escaped forms
		// are used so percent-encoding in the original survives.
		proxyPath := target.EscapedPath()
		if target.RawQuery != "" {
			proxyPath += "?" + target.RawQuery
		}
		if frag := target.EscapedFragment(); frag != "" {
			proxyPath += "#" + frag
		}
		return prefix + proxyPath
	}

	// Different host: embed the full absolute URL after the proxy prefix.
	return prefix + "/" + target.String()
}
// rewriteSrcset rewrites every image URL inside a srcset attribute value with
// rewriteURL and returns the candidates re-joined with ", ".
//
// Tokenisation is modelled on the HTML srcset parsing rules rather than
// splitting on every comma: a URL runs up to the next whitespace and may
// itself contain commas (data: URIs, "?size=100,200"); a URL token that ends
// in a comma has no descriptors; otherwise the descriptors run up to the
// next comma that is not inside parentheses. Descriptor whitespace is
// collapsed to single spaces. An empty srcset is returned unchanged.
func (r *ContentRewriter) rewriteSrcset(srcset string) string {
	if srcset == "" {
		return srcset
	}

	const space = " \t\r\n\f"

	var candidates []string
	rest := srcset
	for {
		// Skip separators left over from the previous candidate.
		rest = strings.TrimLeft(rest, space+",")
		if rest == "" {
			break
		}

		// The URL is everything up to the next whitespace.
		end := strings.IndexAny(rest, space)
		if end == -1 {
			end = len(rest)
		}
		rawURL := rest[:end]
		rest = rest[end:]

		// Trailing commas terminate the candidate: no descriptors follow.
		if bare := strings.TrimRight(rawURL, ","); bare != rawURL {
			candidates = append(candidates, r.rewriteURL(bare))
			continue
		}

		// Descriptors extend to the next comma outside parentheses.
		descEnd := len(rest)
		depth := 0
		for i := 0; i < len(rest) && descEnd == len(rest); i++ {
			switch rest[i] {
			case '(':
				depth++
			case ')':
				if depth > 0 {
					depth--
				}
			case ',':
				if depth == 0 {
					descEnd = i
				}
			}
		}
		descriptors := strings.Fields(rest[:descEnd])
		rest = rest[descEnd:]

		candidate := r.rewriteURL(rawURL)
		if len(descriptors) > 0 {
			candidate += " " + strings.Join(descriptors, " ")
		}
		candidates = append(candidates, candidate)
	}

	return strings.Join(candidates, ", ")
}
// RewriteCSS rewrites the URLs inside a complete stylesheet. It is a
// byte-slice wrapper around rewriteInlineCSS.
func (r *ContentRewriter) RewriteCSS(body []byte) []byte {
	return []byte(r.rewriteInlineCSS(string(body)))
}
// rewriteInlineCSS rewrites every url(...) reference in a piece of CSS with
// rewriteURL and then hands the result to rewriteImports for @import rules.
//
// The text is scanned once. Each url( argument is handled exactly once,
// whether it is double-quoted, single-quoted or bare, and the quoting style
// is preserved. (Running a separate pass per quoting style would let the
// bare url( pass re-match the quoted forms and feed the quotes themselves
// into rewriteURL.) Whitespace between "(" and an opening quote is kept.
// If a url( has no terminator, the remainder of the text is left untouched.
func (r *ContentRewriter) rewriteInlineCSS(css string) string {
	const (
		fn    = "url("
		space = " \t\r\n\f"
	)

	var out strings.Builder
	out.Grow(len(css))

	rest := css
	for {
		idx := strings.Index(rest, fn)
		if idx == -1 {
			break
		}
		argStart := idx + len(fn)
		out.WriteString(rest[:argStart])
		rest = rest[argStart:]

		trimmed := strings.TrimLeft(rest, space)
		if trimmed != "" && (trimmed[0] == '"' || trimmed[0] == '\'') {
			quote := trimmed[0]
			end := strings.IndexByte(trimmed[1:], quote)
			if end == -1 {
				break // unterminated string: leave the tail as is
			}
			out.WriteString(rest[:len(rest)-len(trimmed)]) // leading whitespace
			out.WriteByte(quote)
			out.WriteString(r.rewriteURL(trimmed[1 : 1+end]))
			rest = trimmed[1+end:] // resume at the closing quote
			continue
		}

		end := strings.IndexByte(rest, ')')
		if end == -1 {
			break // unterminated url(: leave the tail as is
		}
		// rewriteURL trims the whitespace around a bare URL itself.
		out.WriteString(r.rewriteURL(rest[:end]))
		rest = rest[end:] // resume at the closing parenthesis
	}
	out.WriteString(rest)

	return r.rewriteImports(out.String())
}
// rewriteImports rewrites the target of CSS @import rules with rewriteURL.
//
// Four literal openers are recognised, each processed in its own pass over
// the text and in this order: `@import "`, `@import '`, `@import url("` and
// `@import url('`. The target runs up to the next matching quote character.
// When an opener has no closing quote, that pass stops and the remaining
// text is left as it is.
func (r *ContentRewriter) rewriteImports(css string) string {
	directives := []struct {
		open  string
		close string
	}{
		{`@import "`, `"`},
		{`@import '`, `'`},
		{`@import url("`, `"`},
		{`@import url('`, `'`},
	}

	for _, d := range directives {
		var out strings.Builder
		rest := css
		for {
			at := strings.Index(rest, d.open)
			if at < 0 {
				break
			}
			targetStart := at + len(d.open)
			targetLen := strings.Index(rest[targetStart:], d.close)
			if targetLen < 0 {
				break
			}
			targetEnd := targetStart + targetLen

			out.WriteString(rest[:targetStart])
			out.WriteString(r.rewriteURL(rest[targetStart:targetEnd]))
			// Continue scanning from the closing quote onwards.
			rest = rest[targetEnd:]
		}
		out.WriteString(rest)
		css = out.String()
	}
	return css
}
// simpleRewriteHTML is the textual fallback used when the HTML cannot be
// parsed or rendered. It replaces the base origin ("<scheme>://<host>") at
// the start of href, src and action attribute values with "/p/<token>".
//
// Both double- and single-quoted attributes are handled for all three
// attribute names. The origin is only replaced when it is followed by "/",
// "?", "#" or the closing quote, so a longer host or an explicit port that
// merely starts with the base origin (for example "https://example.com.evil.org"
// or "https://example.com:8443") is left untouched instead of being turned
// into a broken path.
func (r *ContentRewriter) simpleRewriteHTML(body []byte) []byte {
	origin := r.baseURL.Scheme + "://" + r.baseURL.Host
	proxyPrefix := "/p/" + r.token

	var pairs []string
	for _, attr := range []string{"href", "src", "action"} {
		for _, quote := range []string{`"`, `'`} {
			head := attr + "=" + quote
			// Characters that can legitimately follow the origin.
			for _, next := range []string{"/", "?", "#", quote} {
				pairs = append(pairs, head+origin+next, head+proxyPrefix+next)
			}
		}
	}

	return []byte(strings.NewReplacer(pairs...).Replace(string(body)))
}
// isTrackingScript reports whether a script src refers to a known
// analytics/tracking script. Matching is case-insensitive.
//
// A src matches when it contains one of the tracking domains anywhere in the
// string, or when its file name (the last path segment, ignoring any query
// or fragment) is exactly one of the well-known tracker file names. Domains
// are matched as substrings because the src may already have been rewritten
// to a proxy path with the original URL embedded in it. File names are
// matched exactly so that unrelated scripts such as "omega.js" are not
// mistaken for "ga.js".
func (r *ContentRewriter) isTrackingScript(src string) bool {
	trackingDomains := []string{
		"google-analytics.com",
		"googletagmanager.com",
		"facebook.net",
		"doubleclick.net",
	}
	trackingFiles := []string{
		"analytics.js",
		"ga.js",
		"gtag.js",
	}

	srcLower := strings.ToLower(src)
	for _, domain := range trackingDomains {
		if strings.Contains(srcLower, domain) {
			return true
		}
	}

	// Reduce the src to its bare file name.
	file := srcLower
	if i := strings.IndexAny(file, "?#"); i >= 0 {
		file = file[:i]
	}
	if i := strings.LastIndexByte(file, '/'); i >= 0 {
		file = file[i+1:]
	}
	for _, name := range trackingFiles {
		if file == name {
			return true
		}
	}
	return false
}