//proxy/rewriter.go
package proxy
import (
"bytes"
"golang.org/x/net/html"
"net/url"
"strings"
)
type ContentRewriter struct {
baseURL *url.URL
}
func NewContentRewriter(baseURL string) (*ContentRewriter, error) {
u, err := url.Parse(baseURL)
if err != nil {
return nil, err
}
return &ContentRewriter{
baseURL: u,
}, nil
}
// RewriteHTML 重写 HTML 内容中的所有 URL
func (r *ContentRewriter) RewriteHTML(body []byte) ([]byte, error) {
doc, err := html.Parse(bytes.NewReader(body))
if err != nil {
// 如果解析失败,使用简单的字符串替换
return r.simpleRewriteHTML(body), nil
}
r.rewriteNode(doc)
var buf bytes.Buffer
if err := html.Render(&buf, doc); err != nil {
return r.simpleRewriteHTML(body), nil
}
return buf.Bytes(), nil
}
// rewriteNode 递归重写 HTML 节点
func (r *ContentRewriter) rewriteNode(n *html.Node) {
if n.Type == html.ElementNode {
// 重写需要处理的属性
attrs := map[string]bool{
"href": true,
"src": true,
"action": true,
"data": true,
}
for i, attr := range n.Attr {
if attrs[attr.Key] {
if rewritten := r.rewriteURL(attr.Val); rewritten != attr.Val {
n.Attr[i].Val = rewritten
}
}
// 处理 srcset 属性
if attr.Key == "srcset" {
n.Attr[i].Val = r.rewriteSrcset(attr.Val)
}
// 处理 style 属性中的 URL
if attr.Key == "style" {
n.Attr[i].Val = r.rewriteInlineCSS(attr.Val)
}
}
// 处理 标签
if n.Data == "base" {
for i, attr := range n.Attr {
if attr.Key == "href" {
n.Attr[i].Val = r.baseURL.String()
}
}
}
// 处理