更新 proxy/handler.go

This commit is contained in:
XOF
2025-12-15 04:15:29 +08:00
parent b318e76623
commit bc47d0152a

View File

@@ -17,6 +17,7 @@ type ProxyHandler struct {
validator *security.RequestValidator validator *security.RequestValidator
rateLimiter *security.RateLimiter rateLimiter *security.RateLimiter
cache *cache.MemoryCache cache *cache.MemoryCache
sessionManager *ProxySessionManager
userAgent string userAgent string
maxResponseSize int64 maxResponseSize int64
} }
@@ -25,6 +26,7 @@ func NewHandler(
validator *security.RequestValidator, validator *security.RequestValidator,
rateLimiter *security.RateLimiter, rateLimiter *security.RateLimiter,
cache *cache.MemoryCache, cache *cache.MemoryCache,
sessionManager *ProxySessionManager,
userAgent string, userAgent string,
maxResponseSize int64, maxResponseSize int64,
) *ProxyHandler { ) *ProxyHandler {
@@ -32,34 +34,50 @@ func NewHandler(
validator: validator, validator: validator,
rateLimiter: rateLimiter, rateLimiter: rateLimiter,
cache: cache, cache: cache,
sessionManager: sessionManager,
userAgent: userAgent, userAgent: userAgent,
maxResponseSize: maxResponseSize, maxResponseSize: maxResponseSize,
} }
} }
func (h *ProxyHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { func (h *ProxyHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
// 获取目标 URL path := strings.TrimPrefix(r.URL.Path, "/p/")
targetURL := r.URL.Query().Get("url") parts := strings.SplitN(path, "/", 2)
if targetURL == "" {
http.Error(w, "Missing url parameter", http.StatusBadRequest) if len(parts) == 0 || parts[0] == "" {
http.Error(w, "Invalid token", http.StatusBadRequest)
return return
} }
// 验证 URL token := parts[0]
subPath := ""
if len(parts) > 1 {
subPath = "/" + parts[1]
}
session := h.sessionManager.Get(token)
if session == nil {
http.Error(w, "Session expired or invalid", http.StatusUnauthorized)
return
}
targetURL := session.TargetURL + subPath
if r.URL.RawQuery != "" {
targetURL += "?" + r.URL.RawQuery
}
if err := h.validator.ValidateURL(targetURL); err != nil { if err := h.validator.ValidateURL(targetURL); err != nil {
log.Printf("URL validation failed: %v", err) log.Printf("URL validation failed: %v", err)
http.Error(w, "Invalid or blocked URL: "+err.Error(), http.StatusForbidden) http.Error(w, "Invalid URL", http.StatusForbidden)
return return
} }
// 速率限制
clientIP := getClientIP(r) clientIP := getClientIP(r)
if !h.rateLimiter.Allow(clientIP) { if !h.rateLimiter.Allow(clientIP) {
http.Error(w, "Rate limit exceeded", http.StatusTooManyRequests) http.Error(w, "Rate limit exceeded", http.StatusTooManyRequests)
return return
} }
// 检查缓存
if entry := h.cache.Get(targetURL); entry != nil { if entry := h.cache.Get(targetURL); entry != nil {
log.Printf("Cache HIT: %s", targetURL) log.Printf("Cache HIT: %s", targetURL)
h.serveCached(w, entry) h.serveCached(w, entry)
@@ -68,17 +86,14 @@ func (h *ProxyHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
log.Printf("Cache MISS: %s", targetURL) log.Printf("Cache MISS: %s", targetURL)
// 创建代理请求
proxyReq, err := http.NewRequest(r.Method, targetURL, r.Body) proxyReq, err := http.NewRequest(r.Method, targetURL, r.Body)
if err != nil { if err != nil {
http.Error(w, "Failed to create request", http.StatusInternalServerError) http.Error(w, "Failed to create request", http.StatusInternalServerError)
return return
} }
// 设置请求头
h.setProxyHeaders(proxyReq, r) h.setProxyHeaders(proxyReq, r)
// 发送请求
client := &http.Client{ client := &http.Client{
Timeout: 30 * time.Second, Timeout: 30 * time.Second,
CheckRedirect: func(req *http.Request, via []*http.Request) error { CheckRedirect: func(req *http.Request, via []*http.Request) error {
@@ -97,7 +112,6 @@ func (h *ProxyHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
} }
defer resp.Body.Close() defer resp.Body.Close()
// 读取响应体(会自动解压 gzip
body, err := h.readResponseBody(resp) body, err := h.readResponseBody(resp)
if err != nil { if err != nil {
log.Printf("Failed to read response: %v", err) log.Printf("Failed to read response: %v", err)
@@ -105,15 +119,12 @@ func (h *ProxyHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
return return
} }
// 重写内容
contentType := resp.Header.Get("Content-Type") contentType := resp.Header.Get("Content-Type")
body = h.rewriteContent(body, targetURL, contentType) body = h.rewriteContent(body, targetURL, contentType, token)
// 缓存响应 - 转换 headers 并删除 Content-Encoding
if h.shouldCache(resp) { if h.shouldCache(resp) {
headers := make(map[string]string) headers := make(map[string]string)
for key, values := range resp.Header { for key, values := range resp.Header {
// 跳过 Content-Encoding因为我们已经解压了
if key == "Content-Encoding" || key == "Content-Length" { if key == "Content-Encoding" || key == "Content-Length" {
continue continue
} }
@@ -124,12 +135,11 @@ func (h *ProxyHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
h.cache.Set(targetURL, body, headers) h.cache.Set(targetURL, body, headers)
} }
// 发送响应
h.sendResponse(w, resp, body) h.sendResponse(w, resp, body)
} }
func (h *ProxyHandler) rewriteContent(body []byte, targetURL, contentType string) []byte { func (h *ProxyHandler) rewriteContent(body []byte, targetURL, contentType, token string) []byte {
rewriter, err := NewContentRewriter(targetURL) rewriter, err := NewContentRewriter(targetURL, token)
if err != nil { if err != nil {
log.Printf("Failed to create rewriter: %v", err) log.Printf("Failed to create rewriter: %v", err)
return body return body
@@ -137,7 +147,6 @@ func (h *ProxyHandler) rewriteContent(body []byte, targetURL, contentType string
contentType = strings.ToLower(contentType) contentType = strings.ToLower(contentType)
// HTML 内容
if strings.Contains(contentType, "text/html") { if strings.Contains(contentType, "text/html") {
rewritten, err := rewriter.RewriteHTML(body) rewritten, err := rewriter.RewriteHTML(body)
if err != nil { if err != nil {
@@ -147,7 +156,6 @@ func (h *ProxyHandler) rewriteContent(body []byte, targetURL, contentType string
return rewritten return rewritten
} }
// CSS 内容
if strings.Contains(contentType, "text/css") { if strings.Contains(contentType, "text/css") {
return rewriter.RewriteCSS(body) return rewriter.RewriteCSS(body)
} }
@@ -158,7 +166,6 @@ func (h *ProxyHandler) rewriteContent(body []byte, targetURL, contentType string
func (h *ProxyHandler) readResponseBody(resp *http.Response) ([]byte, error) { func (h *ProxyHandler) readResponseBody(resp *http.Response) ([]byte, error) {
var reader io.Reader = resp.Body var reader io.Reader = resp.Body
// 处理各种压缩格式
encoding := strings.ToLower(resp.Header.Get("Content-Encoding")) encoding := strings.ToLower(resp.Header.Get("Content-Encoding"))
if strings.Contains(encoding, "gzip") { if strings.Contains(encoding, "gzip") {
gzReader, err := gzip.NewReader(resp.Body) gzReader, err := gzip.NewReader(resp.Body)
@@ -169,20 +176,17 @@ func (h *ProxyHandler) readResponseBody(resp *http.Response) ([]byte, error) {
reader = gzReader reader = gzReader
} }
// 限制读取大小
limitReader := io.LimitReader(reader, h.maxResponseSize) limitReader := io.LimitReader(reader, h.maxResponseSize)
return io.ReadAll(limitReader) return io.ReadAll(limitReader)
} }
func (h *ProxyHandler) setProxyHeaders(proxyReq, originalReq *http.Request) { func (h *ProxyHandler) setProxyHeaders(proxyReq, originalReq *http.Request) {
// 复制必要的请求头
headersToForward := []string{ headersToForward := []string{
"Accept", "Accept",
"Accept-Language", "Accept-Language",
"Accept-Encoding", "Accept-Encoding",
"Cache-Control", "Cache-Control",
"Referer",
} }
for _, header := range headersToForward { for _, header := range headersToForward {
@@ -191,26 +195,18 @@ func (h *ProxyHandler) setProxyHeaders(proxyReq, originalReq *http.Request) {
} }
} }
// 设置自定义 User-Agent
proxyReq.Header.Set("User-Agent", h.userAgent) proxyReq.Header.Set("User-Agent", h.userAgent)
// 移除可能暴露代理的头
proxyReq.Header.Del("X-Forwarded-For") proxyReq.Header.Del("X-Forwarded-For")
proxyReq.Header.Del("X-Real-IP") proxyReq.Header.Del("X-Real-IP")
proxyReq.Header.Del("Via") proxyReq.Header.Del("Via")
proxyReq.Header.Del("Referer")
} }
func (h *ProxyHandler) shouldCache(resp *http.Response) bool { func (h *ProxyHandler) shouldCache(resp *http.Response) bool {
// 只缓存成功的 GET 请求 if resp.Request.Method != "GET" || resp.StatusCode != http.StatusOK {
if resp.Request.Method != "GET" {
return false return false
} }
if resp.StatusCode != http.StatusOK {
return false
}
// 检查 Cache-Control
cacheControl := resp.Header.Get("Cache-Control") cacheControl := resp.Header.Get("Cache-Control")
if strings.Contains(cacheControl, "no-store") || if strings.Contains(cacheControl, "no-store") ||
strings.Contains(cacheControl, "no-cache") || strings.Contains(cacheControl, "no-cache") ||
@@ -218,15 +214,8 @@ func (h *ProxyHandler) shouldCache(resp *http.Response) bool {
return false return false
} }
// 检查内容类型
contentType := resp.Header.Get("Content-Type") contentType := resp.Header.Get("Content-Type")
cacheableTypes := []string{ cacheableTypes := []string{"text/html", "text/css", "application/javascript", "image/", "font/"}
"text/html",
"text/css",
"application/javascript",
"image/",
"font/",
}
for _, ct := range cacheableTypes { for _, ct := range cacheableTypes {
if strings.Contains(contentType, ct) { if strings.Contains(contentType, ct) {
@@ -238,14 +227,7 @@ func (h *ProxyHandler) shouldCache(resp *http.Response) bool {
} }
func (h *ProxyHandler) sendResponse(w http.ResponseWriter, resp *http.Response, body []byte) { func (h *ProxyHandler) sendResponse(w http.ResponseWriter, resp *http.Response, body []byte) {
// 复制响应头 headersToForward := []string{"Content-Type", "Content-Language", "Last-Modified", "ETag", "Expires"}
headersToForward := []string{
"Content-Type",
"Content-Language",
"Last-Modified",
"ETag",
"Expires",
}
for _, header := range headersToForward { for _, header := range headersToForward {
if value := resp.Header.Get(header); value != "" { if value := resp.Header.Get(header); value != "" {
@@ -253,25 +235,18 @@ func (h *ProxyHandler) sendResponse(w http.ResponseWriter, resp *http.Response,
} }
} }
// 添加自定义头
w.Header().Set("X-Proxied-By", "SiteProxy") w.Header().Set("X-Proxied-By", "SiteProxy")
w.Header().Set("X-Cache-Status", "MISS") w.Header().Set("X-Cache-Status", "MISS")
// 安全头
w.Header().Set("X-Content-Type-Options", "nosniff") w.Header().Set("X-Content-Type-Options", "nosniff")
w.Header().Set("X-Frame-Options", "SAMEORIGIN") w.Header().Set("X-Frame-Options", "SAMEORIGIN")
w.Header().Set("Referrer-Policy", "no-referrer") w.Header().Set("Referrer-Policy", "no-referrer")
// 不设置 Content-Encoding 和 Content-Length让 Go 自动处理
w.WriteHeader(resp.StatusCode) w.WriteHeader(resp.StatusCode)
w.Write(body) w.Write(body)
} }
func (h *ProxyHandler) serveCached(w http.ResponseWriter, entry *cache.CacheEntry) { func (h *ProxyHandler) serveCached(w http.ResponseWriter, entry *cache.CacheEntry) {
for key, value := range entry.Headers { for key, value := range entry.Headers {
// 跳过这些头,让 Go 自动处理
if key == "Content-Encoding" || key == "Content-Length" { if key == "Content-Encoding" || key == "Content-Length" {
continue continue
} }
@@ -287,20 +262,17 @@ func (h *ProxyHandler) serveCached(w http.ResponseWriter, entry *cache.CacheEntr
} }
func getClientIP(r *http.Request) string { func getClientIP(r *http.Request) string {
// 尝试从各种头中获取真实 IP
if ip := r.Header.Get("X-Real-IP"); ip != "" { if ip := r.Header.Get("X-Real-IP"); ip != "" {
return ip return ip
} }
if ip := r.Header.Get("X-Forwarded-For"); ip != "" { if ip := r.Header.Get("X-Forwarded-For"); ip != "" {
// X-Forwarded-For 可能包含多个 IP
ips := strings.Split(ip, ",") ips := strings.Split(ip, ",")
if len(ips) > 0 { if len(ips) > 0 {
return strings.TrimSpace(ips[0]) return strings.TrimSpace(ips[0])
} }
} }
// 使用远程地址
ip := r.RemoteAddr ip := r.RemoteAddr
if idx := strings.LastIndex(ip, ":"); idx != -1 { if idx := strings.LastIndex(ip, ":"); idx != -1 {
ip = ip[:idx] ip = ip[:idx]