更新 proxy/handler.go

This commit is contained in:
XOF
2025-12-15 02:21:48 +08:00
parent 3d98525d82
commit 3a6b34fe19

View File

@@ -3,7 +3,7 @@ package proxy
import ( import (
"bytes" "bytes"
"fmt" "compress/gzip"
"io" "io"
"log" "log"
"net/http" "net/http"
@@ -16,49 +16,30 @@ import (
) )
type ProxyHandler struct { type ProxyHandler struct {
validator *security.RequestValidator validator *security.RequestValidator
rateLimiter *security.RateLimiter rateLimiter *security.RateLimiter
cache *cache.MemoryCache cache *cache.MemoryCache
client *http.Client userAgent string
userAgent string maxResponseSize int64
maxSize int64
} }
func NewHandler(validator *security.RequestValidator, rateLimiter *security.RateLimiter, memCache *cache.MemoryCache, userAgent string, maxSize int64) *ProxyHandler { func NewHandler(
validator *security.RequestValidator,
rateLimiter *security.RateLimiter,
cache *cache.MemoryCache,
userAgent string,
maxResponseSize int64,
) *ProxyHandler {
return &ProxyHandler{ return &ProxyHandler{
validator: validator, validator: validator,
rateLimiter: rateLimiter, rateLimiter: rateLimiter,
cache: memCache, cache: cache,
userAgent: userAgent, userAgent: userAgent,
maxSize: maxSize, maxResponseSize: maxResponseSize,
client: &http.Client{
Timeout: 30 * time.Second,
CheckRedirect: func(req *http.Request, via []*http.Request) error {
if len(via) >= 10 {
return fmt.Errorf("too many redirects")
}
return nil
},
},
} }
} }
func (h *ProxyHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { func (h *ProxyHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
// 获取 session ID
cookie, err := r.Cookie("session_id")
if err != nil {
http.Error(w, "Unauthorized", http.StatusUnauthorized)
return
}
sessionID := cookie.Value
// 速率限制检查
if !h.rateLimiter.Allow(sessionID) {
http.Error(w, "Rate limit exceeded", http.StatusTooManyRequests)
return
}
// 获取目标 URL // 获取目标 URL
targetURL := r.URL.Query().Get("url") targetURL := r.URL.Query().Get("url")
if targetURL == "" { if targetURL == "" {
@@ -69,133 +50,181 @@ func (h *ProxyHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
// 验证 URL // 验证 URL
if err := h.validator.ValidateURL(targetURL); err != nil { if err := h.validator.ValidateURL(targetURL); err != nil {
log.Printf("URL validation failed: %v", err) log.Printf("URL validation failed: %v", err)
http.Error(w, "Invalid or blocked URL", http.StatusForbidden) http.Error(w, "Invalid or blocked URL: "+err.Error(), http.StatusForbidden)
return
}
// 速率限制
clientIP := getClientIP(r)
if !h.rateLimiter.Allow(clientIP) {
http.Error(w, "Rate limit exceeded", http.StatusTooManyRequests)
return return
} }
// 检查缓存 // 检查缓存
cacheKey := h.cache.GenerateKey(targetURL) if entry := h.cache.Get(targetURL); entry != nil {
if cached, ok := h.cache.Get(cacheKey); ok { log.Printf("Cache HIT: %s", targetURL)
log.Printf("Cache hit: %s", targetURL) h.serveCached(w, entry)
h.serveCached(w, cached)
return return
} }
// 发起代理请求 log.Printf("Cache MISS: %s", targetURL)
log.Printf("Proxying: %s", targetURL)
if err := h.proxyRequest(w, targetURL, cacheKey); err != nil { // 创建代理请求
log.Printf("Proxy error: %v", err) proxyReq, err := http.NewRequest(r.Method, targetURL, r.Body)
http.Error(w, "Proxy request failed", http.StatusBadGateway)
return
}
}
func (h *ProxyHandler) proxyRequest(w http.ResponseWriter, targetURL, cacheKey string) error {
// 创建请求
req, err := http.NewRequest("GET", targetURL, nil)
if err != nil { if err != nil {
return err http.Error(w, "Failed to create request", http.StatusInternalServerError)
return
} }
// 设置请求头 // 设置请求头
req.Header.Set("User-Agent", h.userAgent) h.setProxyHeaders(proxyReq, r)
req.Header.Set("Accept", "text/html,application/xhtml+xml,application/xml;q=0.9,image/webp,*/*;q=0.8")
req.Header.Set("Accept-Language", "en-US,en;q=0.9")
req.Header.Set("Accept-Encoding", "gzip, deflate, br")
req.Header.Set("DNT", "1")
req.Header.Set("Connection", "keep-alive")
req.Header.Set("Upgrade-Insecure-Requests", "1")
// 发送请求 // 发送请求
resp, err := h.client.Do(req) client := &http.Client{
Timeout: 30 * time.Second,
CheckRedirect: func(req *http.Request, via []*http.Request) error {
// 验证重定向 URL
if err := h.validator.ValidateURL(req.URL.String()); err != nil {
return err
}
return nil
},
}
resp, err := client.Do(proxyReq)
if err != nil { if err != nil {
return err log.Printf("Request failed: %v", err)
http.Error(w, "Failed to fetch URL", http.StatusBadGateway)
return
} }
defer resp.Body.Close() defer resp.Body.Close()
// 限制响应大小 // 读取响应体
limitedReader := io.LimitReader(resp.Body, h.maxSize) body, err := h.readResponseBody(resp)
body, err := io.ReadAll(limitedReader)
if err != nil { if err != nil {
return err log.Printf("Failed to read response: %v", err)
http.Error(w, "Failed to read response", http.StatusInternalServerError)
return
}
// 重写内容
contentType := resp.Header.Get("Content-Type")
body = h.rewriteContent(body, targetURL, contentType)
// 缓存响应
if h.shouldCache(resp) {
h.cache.Set(targetURL, body, resp.Header)
}
// 发送响应
h.sendResponse(w, resp, body)
}
func (h *ProxyHandler) rewriteContent(body []byte, targetURL, contentType string) []byte {
rewriter, err := NewContentRewriter(targetURL)
if err != nil {
log.Printf("Failed to create rewriter: %v", err)
return body
}
contentType = strings.ToLower(contentType)
// HTML 内容
if strings.Contains(contentType, "text/html") {
rewritten, err := rewriter.RewriteHTML(body)
if err != nil {
log.Printf("HTML rewrite failed: %v", err)
return body
}
return rewritten
}
// CSS 内容
if strings.Contains(contentType, "text/css") {
return rewriter.RewriteCSS(body)
}
// JavaScript 内容 - 暂时不重写,可能会破坏功能
// 未来可以添加更智能的 JS 重写
return body
}
func (h *ProxyHandler) readResponseBody(resp *http.Response) ([]byte, error) {
var reader io.Reader = resp.Body
// 处理 gzip 压缩
if resp.Header.Get("Content-Encoding") == "gzip" {
gzReader, err := gzip.NewReader(resp.Body)
if err != nil {
return nil, err
}
defer gzReader.Close()
reader = gzReader
}
// 限制读取大小
limitReader := io.LimitReader(reader, h.maxResponseSize)
return io.ReadAll(limitReader)
}
func (h *ProxyHandler) setProxyHeaders(proxyReq, originalReq *http.Request) {
// 复制必要的请求头
headersToForward := []string{
"Accept",
"Accept-Language",
"Accept-Encoding",
"Cache-Control",
"Referer",
}
for _, header := range headersToForward {
if value := originalReq.Header.Get(header); value != "" {
proxyReq.Header.Set(header, value)
}
}
// 设置自定义 User-Agent
proxyReq.Header.Set("User-Agent", h.userAgent)
// 移除可能暴露代理的头
proxyReq.Header.Del("X-Forwarded-For")
proxyReq.Header.Del("X-Real-IP")
proxyReq.Header.Del("Via")
}
func (h *ProxyHandler) shouldCache(resp *http.Response) bool {
// 只缓存成功的 GET 请求
if resp.Request.Method != "GET" {
return false
}
if resp.StatusCode != http.StatusOK {
return false
}
// 检查 Cache-Control
cacheControl := resp.Header.Get("Cache-Control")
if strings.Contains(cacheControl, "no-store") ||
strings.Contains(cacheControl, "no-cache") ||
strings.Contains(cacheControl, "private") {
return false
} }
// 检查内容类型 // 检查内容类型
contentType := resp.Header.Get("Content-Type") contentType := resp.Header.Get("Content-Type")
// 如果是 HTML进行 URL 重写
if strings.Contains(contentType, "text/html") {
body = h.rewriteHTML(body, targetURL)
} else if strings.Contains(contentType, "text/css") {
body = h.rewriteCSS(body, targetURL)
}
// 缓存静态资源
if h.shouldCache(contentType) {
headers := make(map[string]string)
headers["Content-Type"] = contentType
h.cache.Set(cacheKey, body, headers)
}
// 设置响应头
w.Header().Set("Content-Type", contentType)
w.Header().Set("X-Proxied-By", "SiteProxy")
// 移除可能泄露的头
for _, header := range []string{"Server", "X-Powered-By", "Set-Cookie"} {
w.Header().Del(header)
}
w.WriteHeader(resp.StatusCode)
w.Write(body)
return nil
}
func (h *ProxyHandler) rewriteHTML(body []byte, baseURL string) []byte {
content := string(body)
// 解析基础 URL
base, err := url.Parse(baseURL)
if err != nil {
return body
}
// 重写绝对 URL
content = strings.ReplaceAll(content, `href="`+base.Scheme+`://`+base.Host, `href="/proxy?url=`+url.QueryEscape(base.Scheme+`://`+base.Host))
content = strings.ReplaceAll(content, `src="`+base.Scheme+`://`+base.Host, `src="/proxy?url=`+url.QueryEscape(base.Scheme+`://`+base.Host))
// 重写相对 URL简化版本
// 注意:完整实现需要 HTML 解析器
return []byte(content)
}
func (h *ProxyHandler) rewriteCSS(body []byte, baseURL string) []byte {
content := string(body)
base, err := url.Parse(baseURL)
if err != nil {
return body
}
// 重写 CSS 中的 url()
content = strings.ReplaceAll(content, `url(`+base.Scheme+`://`+base.Host, `url(/proxy?url=`+url.QueryEscape(base.Scheme+`://`+base.Host))
return []byte(content)
}
func (h *ProxyHandler) shouldCache(contentType string) bool {
cacheableTypes := []string{ cacheableTypes := []string{
"image/", "text/html",
"text/css", "text/css",
"application/javascript", "application/javascript",
"application/json", "image/",
"font/", "font/",
} }
for _, t := range cacheableTypes { for _, ct := range cacheableTypes {
if strings.Contains(contentType, t) { if strings.Contains(contentType, ct) {
return true return true
} }
} }
@@ -203,12 +232,72 @@ func (h *ProxyHandler) shouldCache(contentType string) bool {
return false return false
} }
func (h *ProxyHandler) sendResponse(w http.ResponseWriter, resp *http.Response, body []byte) {
// 复制响应头
headersToForward := []string{
"Content-Type",
"Content-Language",
"Last-Modified",
"ETag",
"Expires",
}
for _, header := range headersToForward {
if value := resp.Header.Get(header); value != "" {
w.Header().Set(header, value)
}
}
// 添加自定义头
w.Header().Set("X-Proxied-By", "SiteProxy")
w.Header().Set("X-Cache-Status", "MISS")
// 移除不需要的头
w.Header().Del("Content-Encoding") // 我们已经解压了
w.Header().Del("Content-Length") // 长度可能已改变
// 安全头
w.Header().Set("X-Content-Type-Options", "nosniff")
w.Header().Set("X-Frame-Options", "SAMEORIGIN")
w.Header().Set("Referrer-Policy", "no-referrer")
w.WriteHeader(resp.StatusCode)
w.Write(body)
}
func (h *ProxyHandler) serveCached(w http.ResponseWriter, entry *cache.CacheEntry) { func (h *ProxyHandler) serveCached(w http.ResponseWriter, entry *cache.CacheEntry) {
for key, value := range entry.Headers { for key, value := range entry.Headers {
w.Header().Set(key, value) w.Header().Set(key, value)
} }
w.Header().Set("X-Cache-Status", "HIT") w.Header().Set("X-Cache-Status", "HIT")
w.Header().Set("X-Proxied-By", "SiteProxy") w.Header().Set("X-Proxied-By", "SiteProxy")
w.Header().Set("Age", entry.Age())
w.WriteHeader(http.StatusOK) w.WriteHeader(http.StatusOK)
w.Write(entry.Data) w.Write(entry.Data)
} }
func getClientIP(r *http.Request) string {
// 尝试从各种头中获取真实 IP
if ip := r.Header.Get("X-Real-IP"); ip != "" {
return ip
}
if ip := r.Header.Get("X-Forwarded-For"); ip != "" {
// X-Forwarded-For 可能包含多个 IP
ips := strings.Split(ip, ",")
if len(ips) > 0 {
return strings.TrimSpace(ips[0])
}
}
// 使用远程地址
ip := r.RemoteAddr
if idx := strings.LastIndex(ip, ":"); idx != -1 {
ip = ip[:idx]
}
return ip
}