From bc47d0152a8808cf2073acc2380e1139fec494a4 Mon Sep 17 00:00:00 2001 From: XOF Date: Mon, 15 Dec 2025 04:15:29 +0800 Subject: [PATCH] =?UTF-8?q?=E6=9B=B4=E6=96=B0=20proxy/handler.go?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- proxy/handler.go | 94 +++++++++++++++++------------------------------- 1 file changed, 33 insertions(+), 61 deletions(-) diff --git a/proxy/handler.go b/proxy/handler.go index 6598ac9..7032e8f 100644 --- a/proxy/handler.go +++ b/proxy/handler.go @@ -17,6 +17,7 @@ type ProxyHandler struct { validator *security.RequestValidator rateLimiter *security.RateLimiter cache *cache.MemoryCache + sessionManager *ProxySessionManager userAgent string maxResponseSize int64 } @@ -25,6 +26,7 @@ func NewHandler( validator *security.RequestValidator, rateLimiter *security.RateLimiter, cache *cache.MemoryCache, + sessionManager *ProxySessionManager, userAgent string, maxResponseSize int64, ) *ProxyHandler { @@ -32,34 +34,50 @@ func NewHandler( validator: validator, rateLimiter: rateLimiter, cache: cache, + sessionManager: sessionManager, userAgent: userAgent, maxResponseSize: maxResponseSize, } } func (h *ProxyHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { - // 获取目标 URL - targetURL := r.URL.Query().Get("url") - if targetURL == "" { - http.Error(w, "Missing url parameter", http.StatusBadRequest) + path := strings.TrimPrefix(r.URL.Path, "/p/") + parts := strings.SplitN(path, "/", 2) + + if len(parts) == 0 || parts[0] == "" { + http.Error(w, "Invalid token", http.StatusBadRequest) return } - // 验证 URL + token := parts[0] + subPath := "" + if len(parts) > 1 { + subPath = "/" + parts[1] + } + + session := h.sessionManager.Get(token) + if session == nil { + http.Error(w, "Session expired or invalid", http.StatusUnauthorized) + return + } + + targetURL := session.TargetURL + subPath + if r.URL.RawQuery != "" { + targetURL += "?" + r.URL.RawQuery + } + if err := h.validator.ValidateURL(targetURL); err != nil { log.Printf("URL validation failed: %v", err) - http.Error(w, "Invalid or blocked URL: "+err.Error(), http.StatusForbidden) + http.Error(w, "Invalid URL", http.StatusForbidden) return } - // 速率限制 clientIP := getClientIP(r) if !h.rateLimiter.Allow(clientIP) { http.Error(w, "Rate limit exceeded", http.StatusTooManyRequests) return } - // 检查缓存 if entry := h.cache.Get(targetURL); entry != nil { log.Printf("Cache HIT: %s", targetURL) h.serveCached(w, entry) @@ -68,17 +86,14 @@ func (h *ProxyHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { log.Printf("Cache MISS: %s", targetURL) - // 创建代理请求 proxyReq, err := http.NewRequest(r.Method, targetURL, r.Body) if err != nil { http.Error(w, "Failed to create request", http.StatusInternalServerError) return } - // 设置请求头 h.setProxyHeaders(proxyReq, r) - // 发送请求 client := &http.Client{ Timeout: 30 * time.Second, CheckRedirect: func(req *http.Request, via []*http.Request) error { @@ -97,7 +112,6 @@ func (h *ProxyHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { } defer resp.Body.Close() - // 读取响应体(会自动解压 gzip) body, err := h.readResponseBody(resp) if err != nil { log.Printf("Failed to read response: %v", err) @@ -105,15 +119,12 @@ func (h *ProxyHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { return } - // 重写内容 contentType := resp.Header.Get("Content-Type") - body = h.rewriteContent(body, targetURL, contentType) + body = h.rewriteContent(body, targetURL, contentType, token) - // 缓存响应 - 转换 headers 并删除 Content-Encoding if h.shouldCache(resp) { headers := make(map[string]string) for key, values := range resp.Header { - // 跳过 Content-Encoding,因为我们已经解压了 if key == "Content-Encoding" || key == "Content-Length" { continue } @@ -124,12 +135,11 @@ func (h *ProxyHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { h.cache.Set(targetURL, body, headers) } - // 发送响应 h.sendResponse(w, resp, body) } -func (h *ProxyHandler) rewriteContent(body []byte, targetURL, contentType string) []byte { - rewriter, err := NewContentRewriter(targetURL) +func (h *ProxyHandler) rewriteContent(body []byte, targetURL, contentType, token string) []byte { + rewriter, err := NewContentRewriter(targetURL, token) if err != nil { log.Printf("Failed to create rewriter: %v", err) return body @@ -137,7 +147,6 @@ func (h *ProxyHandler) rewriteContent(body []byte, targetURL, contentType string contentType = strings.ToLower(contentType) - // HTML 内容 if strings.Contains(contentType, "text/html") { rewritten, err := rewriter.RewriteHTML(body) if err != nil { @@ -147,7 +156,6 @@ func (h *ProxyHandler) rewriteContent(body []byte, targetURL, contentType string return rewritten } - // CSS 内容 if strings.Contains(contentType, "text/css") { return rewriter.RewriteCSS(body) } @@ -158,7 +166,6 @@ func (h *ProxyHandler) rewriteContent(body []byte, targetURL, contentType string func (h *ProxyHandler) readResponseBody(resp *http.Response) ([]byte, error) { var reader io.Reader = resp.Body - // 处理各种压缩格式 encoding := strings.ToLower(resp.Header.Get("Content-Encoding")) if strings.Contains(encoding, "gzip") { gzReader, err := gzip.NewReader(resp.Body) @@ -169,20 +176,17 @@ func (h *ProxyHandler) readResponseBody(resp *http.Response) ([]byte, error) { reader = gzReader } - // 限制读取大小 limitReader := io.LimitReader(reader, h.maxResponseSize) return io.ReadAll(limitReader) } func (h *ProxyHandler) setProxyHeaders(proxyReq, originalReq *http.Request) { - // 复制必要的请求头 headersToForward := []string{ "Accept", "Accept-Language", "Accept-Encoding", "Cache-Control", - "Referer", } for _, header := range headersToForward { @@ -191,26 +195,18 @@ func (h *ProxyHandler) setProxyHeaders(proxyReq, originalReq *http.Request) { } } - // 设置自定义 User-Agent proxyReq.Header.Set("User-Agent", h.userAgent) - - // 移除可能暴露代理的头 proxyReq.Header.Del("X-Forwarded-For") proxyReq.Header.Del("X-Real-IP") proxyReq.Header.Del("Via") + proxyReq.Header.Del("Referer") } func (h *ProxyHandler) shouldCache(resp *http.Response) bool { - // 只缓存成功的 GET 请求 - if resp.Request.Method != "GET" { + if resp.Request.Method != "GET" || resp.StatusCode != http.StatusOK { return false } - if resp.StatusCode != http.StatusOK { - return false - } - - // 检查 Cache-Control cacheControl := resp.Header.Get("Cache-Control") if strings.Contains(cacheControl, "no-store") || strings.Contains(cacheControl, "no-cache") || @@ -218,15 +214,8 @@ func (h *ProxyHandler) shouldCache(resp *http.Response) bool { return false } - // 检查内容类型 contentType := resp.Header.Get("Content-Type") - cacheableTypes := []string{ - "text/html", - "text/css", - "application/javascript", - "image/", - "font/", - } + cacheableTypes := []string{"text/html", "text/css", "application/javascript", "image/", "font/"} for _, ct := range cacheableTypes { if strings.Contains(contentType, ct) { @@ -238,14 +227,7 @@ func (h *ProxyHandler) shouldCache(resp *http.Response) bool { } func (h *ProxyHandler) sendResponse(w http.ResponseWriter, resp *http.Response, body []byte) { - // 复制响应头 - headersToForward := []string{ - "Content-Type", - "Content-Language", - "Last-Modified", - "ETag", - "Expires", - } + headersToForward := []string{"Content-Type", "Content-Language", "Last-Modified", "ETag", "Expires"} for _, header := range headersToForward { if value := resp.Header.Get(header); value != "" { @@ -253,25 +235,18 @@ func (h *ProxyHandler) sendResponse(w http.ResponseWriter, resp *http.Response, } } - // 添加自定义头 w.Header().Set("X-Proxied-By", "SiteProxy") w.Header().Set("X-Cache-Status", "MISS") - - // 安全头 w.Header().Set("X-Content-Type-Options", "nosniff") w.Header().Set("X-Frame-Options", "SAMEORIGIN") w.Header().Set("Referrer-Policy", "no-referrer") - // 不设置 Content-Encoding 和 Content-Length,让 Go 自动处理 - w.WriteHeader(resp.StatusCode) w.Write(body) } - func (h *ProxyHandler) serveCached(w http.ResponseWriter, entry *cache.CacheEntry) { for key, value := range entry.Headers { - // 跳过这些头,让 Go 自动处理 if key == "Content-Encoding" || key == "Content-Length" { continue } @@ -287,20 +262,17 @@ func (h *ProxyHandler) serveCached(w http.ResponseWriter, entry *cache.CacheEntr } func getClientIP(r *http.Request) string { - // 尝试从各种头中获取真实 IP if ip := r.Header.Get("X-Real-IP"); ip != "" { return ip } if ip := r.Header.Get("X-Forwarded-For"); ip != "" { - // X-Forwarded-For 可能包含多个 IP ips := strings.Split(ip, ",") if len(ips) > 0 { return strings.TrimSpace(ips[0]) } } - // 使用远程地址 ip := r.RemoteAddr if idx := strings.LastIndex(ip, ":"); idx != -1 { ip = ip[:idx]