Files
SiteProxy/proxy/rewriter.go
2025-12-15 05:04:01 +08:00

308 lines
7.9 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
//proxy/rewriter.go
package proxy
import (
"bytes"
"golang.org/x/net/html"
"net/url"
"strings"
)
// ContentRewriter rewrites URLs found in HTML and CSS so that they point at
// proxy paths of the form "/p/<token>/...".
type ContentRewriter struct {
// baseURL is the parsed URL that relative references are resolved against;
// its Host decides whether a link counts as same-host.
baseURL *url.URL
// token is the path segment placed directly after "/p/" in rewritten URLs.
token string
}
// NewContentRewriter builds a ContentRewriter for the page at baseURL.
//
// baseURL is parsed with url.Parse; a parse failure is returned unchanged and
// no rewriter is produced. token is stored as-is and later becomes the path
// segment that follows "/p/" in rewritten URLs.
func NewContentRewriter(baseURL, token string) (*ContentRewriter, error) {
	parsed, parseErr := url.Parse(baseURL)
	if parseErr != nil {
		return nil, parseErr
	}

	rw := new(ContentRewriter)
	rw.baseURL = parsed
	rw.token = token
	return rw, nil
}
// RewriteHTML parses body as HTML, rewrites the URLs in the resulting tree via
// rewriteNode, and returns the re-rendered document.
//
// If parsing or rendering fails, the result of the string-replacement fallback
// simpleRewriteHTML on the original body is returned instead. Every return
// path yields a nil error.
func (r *ContentRewriter) RewriteHTML(body []byte) ([]byte, error) {
	doc, parseErr := html.Parse(bytes.NewReader(body))
	if parseErr != nil {
		return r.simpleRewriteHTML(body), nil
	}

	r.rewriteNode(doc)

	rendered := new(bytes.Buffer)
	if renderErr := html.Render(rendered, doc); renderErr != nil {
		return r.simpleRewriteHTML(body), nil
	}
	return rendered.Bytes(), nil
}
// rewriteNode rewrites the subtree rooted at n in place, depth-first.
//
// For element nodes it:
//   - rewrites the href, src, action and data attributes with rewriteURL,
//     srcset with rewriteSrcset and style with rewriteInlineCSS;
//   - gives every <form> a non-empty action (the rewritten base URL) when the
//     attribute is missing or empty;
//   - sets <base href> to the base URL;
//   - rewrites the CSS text of <style> elements;
//   - detaches <script> elements whose src matches isTrackingScript.
//
// NOTE(review): <base href> is pinned to the upstream URL while all other URLs
// become root-relative "/p/<token>/..." paths; confirm that resolving those
// against this base is the intended behavior.
func (r *ContentRewriter) rewriteNode(n *html.Node) {
	if n.Type == html.ElementNode {
		for i := range n.Attr {
			attr := &n.Attr[i]
			switch attr.Key {
			case "href", "src", "action", "data":
				attr.Val = r.rewriteURL(attr.Val)
			case "srcset":
				attr.Val = r.rewriteSrcset(attr.Val)
			case "style":
				attr.Val = r.rewriteInlineCSS(attr.Val)
			}
		}

		switch n.Data {
		case "form":
			hasAction := false
			for i := range n.Attr {
				if n.Attr[i].Key != "action" {
					continue
				}
				hasAction = true
				if n.Attr[i].Val == "" {
					n.Attr[i].Val = r.rewriteURL(r.baseURL.String())
				}
				break
			}
			if !hasAction {
				n.Attr = append(n.Attr, html.Attribute{
					Key: "action",
					Val: r.rewriteURL(r.baseURL.String()),
				})
			}
		case "base":
			for i := range n.Attr {
				if n.Attr[i].Key == "href" {
					n.Attr[i].Val = r.baseURL.String()
				}
			}
		case "style":
			if text := n.FirstChild; text != nil && text.Type == html.TextNode {
				text.Data = r.rewriteInlineCSS(text.Data)
			}
		case "script":
			if n.Parent != nil {
				for _, attr := range n.Attr {
					if attr.Key == "src" && r.isTrackingScript(attr.Val) {
						n.Parent.RemoveChild(n)
						return
					}
				}
			}
		}
	}

	// Capture the next sibling before recursing: if the child removes itself
	// from the tree, RemoveChild clears its NextSibling pointer, and reading
	// it afterwards would end the walk and skip all remaining siblings.
	for c := n.FirstChild; c != nil; {
		next := c.NextSibling
		r.rewriteNode(c)
		c = next
	}
}
// rewriteURL converts a URL found in proxied content into its proxy form.
//
// Surrounding whitespace is trimmed first. Empty strings, fragment-only
// references ("#...") and javascript:, data:, mailto: and tel: URLs (matched
// case-insensitively) are returned as-is, as is any string url.Parse rejects.
// Everything else is resolved against the base URL and then mapped to:
//
//   - "/p/<token><path>[?query][#fragment]" when the host equals the base host
//     (compared case-insensitively; an empty path becomes "/");
//   - "/p/<token>/<absolute URL>" for any other host.
func (r *ContentRewriter) rewriteURL(urlStr string) string {
	urlStr = strings.TrimSpace(urlStr)
	if urlStr == "" || urlStr[0] == '#' {
		return urlStr
	}
	// URI schemes are case-insensitive, so "JavaScript:" must be skipped too.
	for _, scheme := range []string{"javascript:", "data:", "mailto:", "tel:"} {
		if len(urlStr) >= len(scheme) && strings.EqualFold(urlStr[:len(scheme)], scheme) {
			return urlStr
		}
	}

	// Protocol-relative URL (//domain.com/path): borrow the base scheme. Work
	// on a copy so a parse failure still returns the caller's original text.
	target := urlStr
	if strings.HasPrefix(target, "//") {
		target = r.baseURL.Scheme + ":" + target
	}

	u, err := url.Parse(target)
	if err != nil {
		return urlStr
	}
	if !u.IsAbs() {
		u = r.baseURL.ResolveReference(u)
	}

	proxyRoot := "/p/" + r.token

	// Cross-host resource: embed the complete URL after the token.
	if !strings.EqualFold(u.Host, r.baseURL.Host) {
		return proxyRoot + "/" + u.String()
	}

	// Same host: keep only path, query and fragment. The escaped forms are
	// used so that sequences such as %20, %3F or %23 survive the rewrite
	// instead of being emitted decoded and changing the URL's meaning.
	proxyPath := u.EscapedPath()
	if proxyPath == "" {
		proxyPath = "/"
	}
	if u.RawQuery != "" {
		proxyPath += "?" + u.RawQuery
	}
	if frag := u.EscapedFragment(); frag != "" {
		proxyPath += "#" + frag
	}
	return proxyRoot + proxyPath
}
// rewriteSrcset rewrites every image URL in a srcset attribute value with
// rewriteURL and returns the candidates joined by ", ".
//
// Candidates are parsed the way browsers do rather than by splitting on every
// comma: a URL runs up to the next whitespace, so commas inside a URL (data
// URIs, CDN transform paths) stay part of it. A comma ends a candidate only
// when it trails the URL or follows the descriptors. Descriptors ("2x",
// "480w") are kept, with their whitespace collapsed to single spaces.
func (r *ContentRewriter) rewriteSrcset(srcset string) string {
	const space = " \t\r\n\f"

	var candidates []string
	rest := srcset
	for {
		// Skip separators left over from the previous candidate.
		rest = strings.TrimLeft(rest, space+",")
		if rest == "" {
			break
		}

		urlEnd := strings.IndexAny(rest, space)
		if urlEnd == -1 {
			urlEnd = len(rest)
		}
		rawURL := rest[:urlEnd]
		rest = rest[urlEnd:]

		// A URL that ends in a comma closes its candidate without descriptors.
		if bare := strings.TrimRight(rawURL, ","); bare != rawURL {
			candidates = append(candidates, r.rewriteURL(bare))
			continue
		}

		descEnd := strings.IndexByte(rest, ',')
		if descEnd == -1 {
			descEnd = len(rest)
		}
		descriptors := strings.Fields(rest[:descEnd])
		rest = rest[descEnd:]

		candidate := r.rewriteURL(rawURL)
		if len(descriptors) > 0 {
			candidate += " " + strings.Join(descriptors, " ")
		}
		candidates = append(candidates, candidate)
	}
	return strings.Join(candidates, ", ")
}
// RewriteCSS rewrites the URLs in a CSS document. It converts body to a
// string, passes it through rewriteInlineCSS and returns the result as bytes.
func (r *ContentRewriter) RewriteCSS(body []byte) []byte {
	rewritten := r.rewriteInlineCSS(string(body))
	return []byte(rewritten)
}
// rewriteInlineCSS rewrites the target of every url(...) in css with
// rewriteURL and then hands the result to rewriteImports.
//
// All three spellings are handled in one left-to-right pass: url("x"),
// url('x') and url(x), with optional whitespace after the opening
// parenthesis. Each occurrence is rewritten exactly once; running a separate
// pass for the unquoted form would match the quoted forms again and feed
// rewriteURL a value that still carries its quotes. Text from an unterminated
// url( to the end of the input is copied through unchanged.
func (r *ContentRewriter) rewriteInlineCSS(css string) string {
	const (
		opener = "url("
		space  = " \t\r\n\f"
	)

	var out strings.Builder
	out.Grow(len(css))

	rest := css
	for {
		idx := strings.Index(rest, opener)
		if idx == -1 {
			break
		}
		argStart := idx + len(opener)
		out.WriteString(rest[:argStart])
		rest = rest[argStart:]

		// Locate the start of the URL: skip whitespace, then an optional
		// opening quote, which also decides what terminates the URL.
		i := len(rest) - len(strings.TrimLeft(rest, space))
		terminator := byte(')')
		if i < len(rest) && (rest[i] == '"' || rest[i] == '\'') {
			terminator = rest[i]
			i++
		}

		end := strings.IndexByte(rest[i:], terminator)
		if end == -1 {
			break
		}
		end += i

		out.WriteString(rest[:i])
		out.WriteString(r.rewriteURL(rest[i:end]))
		rest = rest[end:]
	}
	out.WriteString(rest)

	return r.rewriteImports(out.String())
}
// rewriteImports rewrites the targets of string-form CSS imports,
// `@import "x.css"` and `@import 'x.css'`, with rewriteURL. Any amount of
// whitespace may separate the keyword from the quoted URL.
//
// The `@import url(...)` form is deliberately left alone: rewriteInlineCSS
// rewrites every url(...) before it calls this method, so touching those
// targets here would rewrite them a second time. An import whose closing
// quote is missing stops the scan and the remaining text is kept unchanged.
func (r *ContentRewriter) rewriteImports(css string) string {
	const (
		keyword = "@import"
		space   = " \t\r\n\f"
	)

	var out strings.Builder
	out.Grow(len(css))

	rest := css
	for {
		idx := strings.Index(rest, keyword)
		if idx == -1 {
			break
		}
		i := idx + len(keyword)
		i += len(rest[i:]) - len(strings.TrimLeft(rest[i:], space))

		if i >= len(rest) || (rest[i] != '"' && rest[i] != '\'') {
			// Not a quoted import (e.g. the url(...) form): copy it through
			// and keep scanning after the keyword.
			out.WriteString(rest[:i])
			rest = rest[i:]
			continue
		}

		quote := rest[i]
		i++
		end := strings.IndexByte(rest[i:], quote)
		if end == -1 {
			break
		}
		end += i

		out.WriteString(rest[:i])
		out.WriteString(r.rewriteURL(rest[i:end]))
		rest = rest[end:]
	}
	out.WriteString(rest)
	return out.String()
}
// simpleRewriteHTML is the parser-free fallback used by RewriteHTML. It
// replaces the base origin ("<scheme>://<host>") with "/p/<token>" wherever it
// opens the value of an href, src or action attribute, in either quote style.
//
// The origin is replaced only when it is followed by "/", "?", "#" or the
// closing quote. Without that boundary check, a URL whose host merely starts
// with the base host (another domain, or the same host with a port) would be
// cut in the middle of its authority. Relative URLs are not touched.
func (r *ContentRewriter) simpleRewriteHTML(body []byte) []byte {
	origin := r.baseURL.Scheme + "://" + r.baseURL.Host
	proxyRoot := "/p/" + r.token

	var pairs []string
	for _, attr := range []string{"href", "src", "action"} {
		for _, quote := range []string{`"`, `'`} {
			lead := attr + "=" + quote
			for _, next := range []string{"/", "?", "#", quote} {
				pairs = append(pairs, lead+origin+next, lead+proxyRoot+next)
			}
		}
	}
	return []byte(strings.NewReplacer(pairs...).Replace(string(body)))
}
// isTrackingScript reports whether src, compared case-insensitively, contains
// any of a fixed set of tracker markers. The markers are matched as plain
// substrings anywhere in src, for both the domain entries and the file-name
// entries.
func (r *ContentRewriter) isTrackingScript(src string) bool {
	markers := [...]string{
		"google-analytics.com",
		"googletagmanager.com",
		"facebook.net",
		"doubleclick.net",
		"analytics.js",
		"ga.js",
		"gtag.js",
	}

	lowered := strings.ToLower(src)
	for i := range markers {
		if strings.Contains(lowered, markers[i]) {
			return true
		}
	}
	return false
}