更新 proxy/rewriter.go

This commit is contained in:
XOF
2025-12-15 04:16:25 +08:00
parent bc47d0152a
commit 2a8cd1c427

View File

@@ -10,9 +10,10 @@ import (
type ContentRewriter struct { type ContentRewriter struct {
baseURL *url.URL baseURL *url.URL
token string
} }
func NewContentRewriter(baseURL string) (*ContentRewriter, error) { func NewContentRewriter(baseURL, token string) (*ContentRewriter, error) {
u, err := url.Parse(baseURL) u, err := url.Parse(baseURL)
if err != nil { if err != nil {
return nil, err return nil, err
@@ -20,14 +21,13 @@ func NewContentRewriter(baseURL string) (*ContentRewriter, error) {
return &ContentRewriter{ return &ContentRewriter{
baseURL: u, baseURL: u,
token: token,
}, nil }, nil
} }
// RewriteHTML 重写 HTML 内容中的所有 URL
func (r *ContentRewriter) RewriteHTML(body []byte) ([]byte, error) { func (r *ContentRewriter) RewriteHTML(body []byte) ([]byte, error) {
doc, err := html.Parse(bytes.NewReader(body)) doc, err := html.Parse(bytes.NewReader(body))
if err != nil { if err != nil {
// 如果解析失败,使用简单的字符串替换
return r.simpleRewriteHTML(body), nil return r.simpleRewriteHTML(body), nil
} }
@@ -41,16 +41,9 @@ func (r *ContentRewriter) RewriteHTML(body []byte) ([]byte, error) {
return buf.Bytes(), nil return buf.Bytes(), nil
} }
// rewriteNode 递归重写 HTML 节点
func (r *ContentRewriter) rewriteNode(n *html.Node) { func (r *ContentRewriter) rewriteNode(n *html.Node) {
if n.Type == html.ElementNode { if n.Type == html.ElementNode {
// 重写需要处理的属性 attrs := map[string]bool{"href": true, "src": true, "action": true, "data": true}
attrs := map[string]bool{
"href": true,
"src": true,
"action": true,
"data": true,
}
for i, attr := range n.Attr { for i, attr := range n.Attr {
if attrs[attr.Key] { if attrs[attr.Key] {
@@ -59,31 +52,26 @@ func (r *ContentRewriter) rewriteNode(n *html.Node) {
} }
} }
// 处理 srcset 属性
if attr.Key == "srcset" { if attr.Key == "srcset" {
n.Attr[i].Val = r.rewriteSrcset(attr.Val) n.Attr[i].Val = r.rewriteSrcset(attr.Val)
} }
// 处理 style 属性中的 URL
if attr.Key == "style" { if attr.Key == "style" {
n.Attr[i].Val = r.rewriteInlineCSS(attr.Val) n.Attr[i].Val = r.rewriteInlineCSS(attr.Val)
} }
} }
// 处理 <form> 标签 - 确保 action 属性存在 if n.Data == "form" {
if n.Data == "form" { hasAction := false
hasAction := false for i, attr := range n.Attr {
for i, attr := range n.Attr { if attr.Key == "action" {
if attr.Key == "action" { hasAction = true
hasAction = true if attr.Val == "" {
// 空 action 表示提交到当前页面 n.Attr[i].Val = r.rewriteURL(r.baseURL.String())
if attr.Val == "" {
n.Attr[i].Val = r.rewriteURL(r.baseURL.String())
}
break
} }
break
} }
// 如果没有 action 属性,添加一个 }
if !hasAction { if !hasAction {
n.Attr = append(n.Attr, html.Attribute{ n.Attr = append(n.Attr, html.Attribute{
Key: "action", Key: "action",
@@ -91,7 +79,7 @@ func (r *ContentRewriter) rewriteNode(n *html.Node) {
}) })
} }
} }
// 处理 <base> 标签
if n.Data == "base" { if n.Data == "base" {
for i, attr := range n.Attr { for i, attr := range n.Attr {
if attr.Key == "href" { if attr.Key == "href" {
@@ -100,20 +88,16 @@ func (r *ContentRewriter) rewriteNode(n *html.Node) {
} }
} }
// 处理 <style> 标签内容
if n.Data == "style" && n.FirstChild != nil { if n.Data == "style" && n.FirstChild != nil {
if n.FirstChild.Type == html.TextNode { if n.FirstChild.Type == html.TextNode {
n.FirstChild.Data = r.rewriteInlineCSS(n.FirstChild.Data) n.FirstChild.Data = r.rewriteInlineCSS(n.FirstChild.Data)
} }
} }
// 处理 <script> 标签,移除可能的跟踪脚本
if n.Data == "script" { if n.Data == "script" {
for _, attr := range n.Attr { for _, attr := range n.Attr {
if attr.Key == "src" { if attr.Key == "src" {
// 可以在这里过滤掉已知的跟踪脚本
if r.isTrackingScript(attr.Val) { if r.isTrackingScript(attr.Val) {
// 移除此节点
if n.Parent != nil { if n.Parent != nil {
n.Parent.RemoveChild(n) n.Parent.RemoveChild(n)
return return
@@ -124,17 +108,14 @@ func (r *ContentRewriter) rewriteNode(n *html.Node) {
} }
} }
// 递归处理子节点
for c := n.FirstChild; c != nil; c = c.NextSibling { for c := n.FirstChild; c != nil; c = c.NextSibling {
r.rewriteNode(c) r.rewriteNode(c)
} }
} }
// rewriteURL 重写单个 URL
func (r *ContentRewriter) rewriteURL(urlStr string) string { func (r *ContentRewriter) rewriteURL(urlStr string) string {
urlStr = strings.TrimSpace(urlStr) urlStr = strings.TrimSpace(urlStr)
// 跳过特殊协议
if strings.HasPrefix(urlStr, "javascript:") || if strings.HasPrefix(urlStr, "javascript:") ||
strings.HasPrefix(urlStr, "data:") || strings.HasPrefix(urlStr, "data:") ||
strings.HasPrefix(urlStr, "mailto:") || strings.HasPrefix(urlStr, "mailto:") ||
@@ -144,22 +125,23 @@ func (r *ContentRewriter) rewriteURL(urlStr string) string {
return urlStr return urlStr
} }
// 解析 URL
u, err := url.Parse(urlStr) u, err := url.Parse(urlStr)
if err != nil { if err != nil {
return urlStr return urlStr
} }
// 如果是相对 URL转换为绝对 URL
if !u.IsAbs() { if !u.IsAbs() {
u = r.baseURL.ResolveReference(u) u = r.baseURL.ResolveReference(u)
} }
// 生成代理 URL proxyPath := u.Path
return "/proxy?url=" + url.QueryEscape(u.String()) if u.RawQuery != "" {
proxyPath += "?" + u.RawQuery
}
return "/p/" + r.token + proxyPath
} }
// rewriteSrcset 重写 srcset 属性
func (r *ContentRewriter) rewriteSrcset(srcset string) string { func (r *ContentRewriter) rewriteSrcset(srcset string) string {
if srcset == "" { if srcset == "" {
return srcset return srcset
@@ -181,18 +163,14 @@ func (r *ContentRewriter) rewriteSrcset(srcset string) string {
return strings.Join(rewritten, ", ") return strings.Join(rewritten, ", ")
} }
// RewriteCSS 重写 CSS 内容
func (r *ContentRewriter) RewriteCSS(body []byte) []byte { func (r *ContentRewriter) RewriteCSS(body []byte) []byte {
content := string(body) content := string(body)
return []byte(r.rewriteInlineCSS(content)) return []byte(r.rewriteInlineCSS(content))
} }
// rewriteInlineCSS 重写内联 CSS 中的 URL
func (r *ContentRewriter) rewriteInlineCSS(css string) string { func (r *ContentRewriter) rewriteInlineCSS(css string) string {
// 匹配 url(...) 模式
result := css result := css
// 处理 url("...") 和 url('...') 和 url(...)
patterns := []struct { patterns := []struct {
prefix string prefix string
suffix string suffix string
@@ -227,22 +205,15 @@ func (r *ContentRewriter) rewriteInlineCSS(css string) string {
} }
} }
// 处理 @import
result = r.rewriteImports(result) result = r.rewriteImports(result)
return result return result
} }
// rewriteImports 重写 CSS @import 语句
func (r *ContentRewriter) rewriteImports(css string) string { func (r *ContentRewriter) rewriteImports(css string) string {
result := css result := css
patterns := []string{ patterns := []string{`@import "`, `@import '`, `@import url("`, `@import url('`}
`@import "`,
`@import '`,
`@import url("`,
`@import url('`,
}
for _, pattern := range patterns { for _, pattern := range patterns {
start := 0 start := 0
@@ -279,32 +250,28 @@ func (r *ContentRewriter) rewriteImports(css string) string {
return result return result
} }
// simpleRewriteHTML 简单的字符串替换重写(备用方案)
func (r *ContentRewriter) simpleRewriteHTML(body []byte) []byte { func (r *ContentRewriter) simpleRewriteHTML(body []byte) []byte {
content := string(body) content := string(body)
// 重写绝对 URL
baseStr := r.baseURL.Scheme + "://" + r.baseURL.Host baseStr := r.baseURL.Scheme + "://" + r.baseURL.Host
replacements := []struct { replacements := []struct {
old string old string
new string new string
}{ }{
{`href="` + baseStr, `href="/proxy?url=` + url.QueryEscape(baseStr)}, {`href="` + baseStr, `href="/p/` + r.token},
{`src="` + baseStr, `src="/proxy?url=` + url.QueryEscape(baseStr)}, {`src="` + baseStr, `src="/p/` + r.token},
{`action="` + baseStr, `action="/proxy?url=` + url.QueryEscape(baseStr)}, {`action="` + baseStr, `action="/p/` + r.token},
{`href='` + baseStr, `href='/proxy?url=` + url.QueryEscape(baseStr)}, {`href='` + baseStr, `href='/p/` + r.token},
{`src='` + baseStr, `src='/proxy?url=` + url.QueryEscape(baseStr)}, {`src='` + baseStr, `src='/p/` + r.token},
} }
for _, r := range replacements { for _, rep := range replacements {
content = strings.ReplaceAll(content, r.old, r.new) content = strings.ReplaceAll(content, rep.old, rep.new)
} }
return []byte(content) return []byte(content)
} }
// isTrackingScript 检查是否是跟踪脚本
func (r *ContentRewriter) isTrackingScript(src string) bool { func (r *ContentRewriter) isTrackingScript(src string) bool {
trackingDomains := []string{ trackingDomains := []string{
"google-analytics.com", "google-analytics.com",
@@ -324,4 +291,4 @@ func (r *ContentRewriter) isTrackingScript(src string) bool {
} }
return false return false
} }