// proxy/handler.go package proxy import ( "bytes" "compress/gzip" "io" "log" "net/http" "net/url" "strings" "time" "siteproxy/cache" "siteproxy/security" ) type ProxyHandler struct { validator *security.RequestValidator rateLimiter *security.RateLimiter cache *cache.MemoryCache userAgent string maxResponseSize int64 } func NewHandler( validator *security.RequestValidator, rateLimiter *security.RateLimiter, cache *cache.MemoryCache, userAgent string, maxResponseSize int64, ) *ProxyHandler { return &ProxyHandler{ validator: validator, rateLimiter: rateLimiter, cache: cache, userAgent: userAgent, maxResponseSize: maxResponseSize, } } func (h *ProxyHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { // 获取目标 URL targetURL := r.URL.Query().Get("url") if targetURL == "" { http.Error(w, "Missing url parameter", http.StatusBadRequest) return } // 验证 URL if err := h.validator.ValidateURL(targetURL); err != nil { log.Printf("URL validation failed: %v", err) http.Error(w, "Invalid or blocked URL: "+err.Error(), http.StatusForbidden) return } // 速率限制 clientIP := getClientIP(r) if !h.rateLimiter.Allow(clientIP) { http.Error(w, "Rate limit exceeded", http.StatusTooManyRequests) return } // 检查缓存 if entry := h.cache.Get(targetURL); entry != nil { log.Printf("Cache HIT: %s", targetURL) h.serveCached(w, entry) return } log.Printf("Cache MISS: %s", targetURL) // 创建代理请求 proxyReq, err := http.NewRequest(r.Method, targetURL, r.Body) if err != nil { http.Error(w, "Failed to create request", http.StatusInternalServerError) return } // 设置请求头 h.setProxyHeaders(proxyReq, r) // 发送请求 client := &http.Client{ Timeout: 30 * time.Second, CheckRedirect: func(req *http.Request, via []*http.Request) error { // 验证重定向 URL if err := h.validator.ValidateURL(req.URL.String()); err != nil { return err } return nil }, } resp, err := client.Do(proxyReq) if err != nil { log.Printf("Request failed: %v", err) http.Error(w, "Failed to fetch URL", http.StatusBadGateway) return } defer resp.Body.Close() // 读取响应体 body, err := h.readResponseBody(resp) if err != nil { log.Printf("Failed to read response: %v", err) http.Error(w, "Failed to read response", http.StatusInternalServerError) return } // 重写内容 contentType := resp.Header.Get("Content-Type") body = h.rewriteContent(body, targetURL, contentType) // 缓存响应 if h.shouldCache(resp) { h.cache.Set(targetURL, body, resp.Header) } // 发送响应 h.sendResponse(w, resp, body) } func (h *ProxyHandler) rewriteContent(body []byte, targetURL, contentType string) []byte { rewriter, err := NewContentRewriter(targetURL) if err != nil { log.Printf("Failed to create rewriter: %v", err) return body } contentType = strings.ToLower(contentType) // HTML 内容 if strings.Contains(contentType, "text/html") { rewritten, err := rewriter.RewriteHTML(body) if err != nil { log.Printf("HTML rewrite failed: %v", err) return body } return rewritten } // CSS 内容 if strings.Contains(contentType, "text/css") { return rewriter.RewriteCSS(body) } // JavaScript 内容 - 暂时不重写,可能会破坏功能 // 未来可以添加更智能的 JS 重写 return body } func (h *ProxyHandler) readResponseBody(resp *http.Response) ([]byte, error) { var reader io.Reader = resp.Body // 处理 gzip 压缩 if resp.Header.Get("Content-Encoding") == "gzip" { gzReader, err := gzip.NewReader(resp.Body) if err != nil { return nil, err } defer gzReader.Close() reader = gzReader } // 限制读取大小 limitReader := io.LimitReader(reader, h.maxResponseSize) return io.ReadAll(limitReader) } func (h *ProxyHandler) setProxyHeaders(proxyReq, originalReq *http.Request) { // 复制必要的请求头 headersToForward := []string{ "Accept", "Accept-Language", "Accept-Encoding", "Cache-Control", "Referer", } for _, header := range headersToForward { if value := originalReq.Header.Get(header); value != "" { proxyReq.Header.Set(header, value) } } // 设置自定义 User-Agent proxyReq.Header.Set("User-Agent", h.userAgent) // 移除可能暴露代理的头 proxyReq.Header.Del("X-Forwarded-For") proxyReq.Header.Del("X-Real-IP") proxyReq.Header.Del("Via") } func (h *ProxyHandler) shouldCache(resp *http.Response) bool { // 只缓存成功的 GET 请求 if resp.Request.Method != "GET" { return false } if resp.StatusCode != http.StatusOK { return false } // 检查 Cache-Control cacheControl := resp.Header.Get("Cache-Control") if strings.Contains(cacheControl, "no-store") || strings.Contains(cacheControl, "no-cache") || strings.Contains(cacheControl, "private") { return false } // 检查内容类型 contentType := resp.Header.Get("Content-Type") cacheableTypes := []string{ "text/html", "text/css", "application/javascript", "image/", "font/", } for _, ct := range cacheableTypes { if strings.Contains(contentType, ct) { return true } } return false } func (h *ProxyHandler) sendResponse(w http.ResponseWriter, resp *http.Response, body []byte) { // 复制响应头 headersToForward := []string{ "Content-Type", "Content-Language", "Last-Modified", "ETag", "Expires", } for _, header := range headersToForward { if value := resp.Header.Get(header); value != "" { w.Header().Set(header, value) } } // 添加自定义头 w.Header().Set("X-Proxied-By", "SiteProxy") w.Header().Set("X-Cache-Status", "MISS") // 移除不需要的头 w.Header().Del("Content-Encoding") // 我们已经解压了 w.Header().Del("Content-Length") // 长度可能已改变 // 安全头 w.Header().Set("X-Content-Type-Options", "nosniff") w.Header().Set("X-Frame-Options", "SAMEORIGIN") w.Header().Set("Referrer-Policy", "no-referrer") w.WriteHeader(resp.StatusCode) w.Write(body) } func (h *ProxyHandler) serveCached(w http.ResponseWriter, entry *cache.CacheEntry) { for key, value := range entry.Headers { w.Header().Set(key, value) } w.Header().Set("X-Cache-Status", "HIT") w.Header().Set("X-Proxied-By", "SiteProxy") w.Header().Set("Age", entry.Age()) w.WriteHeader(http.StatusOK) w.Write(entry.Data) } func getClientIP(r *http.Request) string { // 尝试从各种头中获取真实 IP if ip := r.Header.Get("X-Real-IP"); ip != "" { return ip } if ip := r.Header.Get("X-Forwarded-For"); ip != "" { // X-Forwarded-For 可能包含多个 IP ips := strings.Split(ip, ",") if len(ips) > 0 { return strings.TrimSpace(ips[0]) } } // 使用远程地址 ip := r.RemoteAddr if idx := strings.LastIndex(ip, ":"); idx != -1 { ip = ip[:idx] } return ip }