// proxy/rewriter.go
package proxy
|
||
|
||
import (
|
||
"bytes"
|
||
"golang.org/x/net/html"
|
||
"net/url"
|
||
"strings"
|
||
)
|
||
|
||
// ContentRewriter rewrites URLs found in proxied content so that they are
// served under the "/p/<token>" path prefix.
type ContentRewriter struct {
	// baseURL is the parsed URL of the page being rewritten; relative
	// references are resolved against it and its host decides whether a URL
	// counts as same-origin.
	baseURL *url.URL
	// token is the path segment placed directly after "/p/" in every
	// rewritten URL.
	token string
}
|
||
|
||
func NewContentRewriter(baseURL, token string) (*ContentRewriter, error) {
|
||
u, err := url.Parse(baseURL)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
return &ContentRewriter{
|
||
baseURL: u,
|
||
token: token,
|
||
}, nil
|
||
}
|
||
|
||
func (r *ContentRewriter) RewriteHTML(body []byte) ([]byte, error) {
|
||
doc, err := html.Parse(bytes.NewReader(body))
|
||
if err != nil {
|
||
return r.simpleRewriteHTML(body), nil
|
||
}
|
||
|
||
r.rewriteNode(doc)
|
||
|
||
var buf bytes.Buffer
|
||
if err := html.Render(&buf, doc); err != nil {
|
||
return r.simpleRewriteHTML(body), nil
|
||
}
|
||
|
||
return buf.Bytes(), nil
|
||
}
|
||
|
||
func (r *ContentRewriter) rewriteNode(n *html.Node) {
|
||
if n.Type == html.ElementNode {
|
||
attrs := map[string]bool{"href": true, "src": true, "action": true, "data": true}
|
||
|
||
for i, attr := range n.Attr {
|
||
if attrs[attr.Key] {
|
||
if rewritten := r.rewriteURL(attr.Val); rewritten != attr.Val {
|
||
n.Attr[i].Val = rewritten
|
||
}
|
||
}
|
||
|
||
if attr.Key == "srcset" {
|
||
n.Attr[i].Val = r.rewriteSrcset(attr.Val)
|
||
}
|
||
|
||
if attr.Key == "style" {
|
||
n.Attr[i].Val = r.rewriteInlineCSS(attr.Val)
|
||
}
|
||
}
|
||
|
||
if n.Data == "form" {
|
||
hasAction := false
|
||
for i, attr := range n.Attr {
|
||
if attr.Key == "action" {
|
||
hasAction = true
|
||
if attr.Val == "" {
|
||
n.Attr[i].Val = r.rewriteURL(r.baseURL.String())
|
||
}
|
||
break
|
||
}
|
||
}
|
||
if !hasAction {
|
||
n.Attr = append(n.Attr, html.Attribute{
|
||
Key: "action",
|
||
Val: r.rewriteURL(r.baseURL.String()),
|
||
})
|
||
}
|
||
}
|
||
|
||
if n.Data == "base" {
|
||
for i, attr := range n.Attr {
|
||
if attr.Key == "href" {
|
||
n.Attr[i].Val = r.baseURL.String()
|
||
}
|
||
}
|
||
}
|
||
|
||
if n.Data == "style" && n.FirstChild != nil {
|
||
if n.FirstChild.Type == html.TextNode {
|
||
n.FirstChild.Data = r.rewriteInlineCSS(n.FirstChild.Data)
|
||
}
|
||
}
|
||
|
||
if n.Data == "script" {
|
||
for _, attr := range n.Attr {
|
||
if attr.Key == "src" {
|
||
if r.isTrackingScript(attr.Val) {
|
||
if n.Parent != nil {
|
||
n.Parent.RemoveChild(n)
|
||
return
|
||
}
|
||
}
|
||
}
|
||
}
|
||
}
|
||
}
|
||
|
||
for c := n.FirstChild; c != nil; c = c.NextSibling {
|
||
r.rewriteNode(c)
|
||
}
|
||
}
|
||
|
||
func (r *ContentRewriter) rewriteURL(urlStr string) string {
|
||
urlStr = strings.TrimSpace(urlStr)
|
||
|
||
if strings.HasPrefix(urlStr, "javascript:") ||
|
||
strings.HasPrefix(urlStr, "data:") ||
|
||
strings.HasPrefix(urlStr, "mailto:") ||
|
||
strings.HasPrefix(urlStr, "tel:") ||
|
||
strings.HasPrefix(urlStr, "#") ||
|
||
urlStr == "" {
|
||
return urlStr
|
||
}
|
||
|
||
// 处理协议相对 URL(//domain.com/path)
|
||
if strings.HasPrefix(urlStr, "//") {
|
||
urlStr = r.baseURL.Scheme + ":" + urlStr
|
||
}
|
||
|
||
u, err := url.Parse(urlStr)
|
||
if err != nil {
|
||
return urlStr
|
||
}
|
||
|
||
if !u.IsAbs() {
|
||
u = r.baseURL.ResolveReference(u)
|
||
}
|
||
|
||
// 同域名,只保留路径
|
||
if u.Host == r.baseURL.Host {
|
||
proxyPath := u.Path
|
||
if u.RawQuery != "" {
|
||
proxyPath += "?" + u.RawQuery
|
||
}
|
||
if u.Fragment != "" {
|
||
proxyPath += "#" + u.Fragment
|
||
}
|
||
return "/p/" + r.token + proxyPath
|
||
}
|
||
|
||
// 跨域资源,完整 URL
|
||
return "/p/" + r.token + "/" + u.String()
|
||
}
|
||
|
||
|
||
func (r *ContentRewriter) rewriteSrcset(srcset string) string {
|
||
if srcset == "" {
|
||
return srcset
|
||
}
|
||
|
||
parts := strings.Split(srcset, ",")
|
||
var rewritten []string
|
||
|
||
for _, part := range parts {
|
||
part = strings.TrimSpace(part)
|
||
fields := strings.Fields(part)
|
||
|
||
if len(fields) > 0 {
|
||
fields[0] = r.rewriteURL(fields[0])
|
||
rewritten = append(rewritten, strings.Join(fields, " "))
|
||
}
|
||
}
|
||
|
||
return strings.Join(rewritten, ", ")
|
||
}
|
||
|
||
func (r *ContentRewriter) RewriteCSS(body []byte) []byte {
|
||
content := string(body)
|
||
return []byte(r.rewriteInlineCSS(content))
|
||
}
|
||
|
||
func (r *ContentRewriter) rewriteInlineCSS(css string) string {
|
||
result := css
|
||
|
||
patterns := []struct {
|
||
prefix string
|
||
suffix string
|
||
}{
|
||
{`url("`, `")`},
|
||
{`url('`, `')`},
|
||
{`url(`, `)`},
|
||
}
|
||
|
||
for _, pattern := range patterns {
|
||
start := 0
|
||
for {
|
||
idx := strings.Index(result[start:], pattern.prefix)
|
||
if idx == -1 {
|
||
break
|
||
}
|
||
|
||
idx += start
|
||
urlStart := idx + len(pattern.prefix)
|
||
urlEnd := strings.Index(result[urlStart:], pattern.suffix)
|
||
|
||
if urlEnd == -1 {
|
||
break
|
||
}
|
||
|
||
urlEnd += urlStart
|
||
originalURL := result[urlStart:urlEnd]
|
||
rewrittenURL := r.rewriteURL(originalURL)
|
||
|
||
result = result[:urlStart] + rewrittenURL + result[urlEnd:]
|
||
start = urlStart + len(rewrittenURL)
|
||
}
|
||
}
|
||
|
||
result = r.rewriteImports(result)
|
||
|
||
return result
|
||
}
|
||
|
||
func (r *ContentRewriter) rewriteImports(css string) string {
|
||
result := css
|
||
|
||
patterns := []string{`@import "`, `@import '`, `@import url("`, `@import url('`}
|
||
|
||
for _, pattern := range patterns {
|
||
start := 0
|
||
for {
|
||
idx := strings.Index(result[start:], pattern)
|
||
if idx == -1 {
|
||
break
|
||
}
|
||
|
||
idx += start
|
||
urlStart := idx + len(pattern)
|
||
|
||
var endChar string
|
||
if strings.Contains(pattern, `"`) {
|
||
endChar = `"`
|
||
} else {
|
||
endChar = `'`
|
||
}
|
||
|
||
urlEnd := strings.Index(result[urlStart:], endChar)
|
||
if urlEnd == -1 {
|
||
break
|
||
}
|
||
|
||
urlEnd += urlStart
|
||
originalURL := result[urlStart:urlEnd]
|
||
rewrittenURL := r.rewriteURL(originalURL)
|
||
|
||
result = result[:urlStart] + rewrittenURL + result[urlEnd:]
|
||
start = urlStart + len(rewrittenURL)
|
||
}
|
||
}
|
||
|
||
return result
|
||
}
|
||
|
||
func (r *ContentRewriter) simpleRewriteHTML(body []byte) []byte {
|
||
content := string(body)
|
||
baseStr := r.baseURL.Scheme + "://" + r.baseURL.Host
|
||
|
||
replacements := []struct {
|
||
old string
|
||
new string
|
||
}{
|
||
{`href="` + baseStr, `href="/p/` + r.token},
|
||
{`src="` + baseStr, `src="/p/` + r.token},
|
||
{`action="` + baseStr, `action="/p/` + r.token},
|
||
{`href='` + baseStr, `href='/p/` + r.token},
|
||
{`src='` + baseStr, `src='/p/` + r.token},
|
||
}
|
||
|
||
for _, rep := range replacements {
|
||
content = strings.ReplaceAll(content, rep.old, rep.new)
|
||
}
|
||
|
||
return []byte(content)
|
||
}
|
||
|
||
func (r *ContentRewriter) isTrackingScript(src string) bool {
|
||
trackingDomains := []string{
|
||
"google-analytics.com",
|
||
"googletagmanager.com",
|
||
"facebook.net",
|
||
"doubleclick.net",
|
||
"analytics.js",
|
||
"ga.js",
|
||
"gtag.js",
|
||
}
|
||
|
||
srcLower := strings.ToLower(src)
|
||
for _, domain := range trackingDomains {
|
||
if strings.Contains(srcLower, domain) {
|
||
return true
|
||
}
|
||
}
|
||
|
||
return false
|
||
} |