diff --git a/proxy/handler.go b/proxy/handler.go
index 9e1d2bf..f616cc7 100644
--- a/proxy/handler.go
+++ b/proxy/handler.go
@@ -3,7 +3,7 @@ package proxy
 
 import (
 	"bytes"
-	"fmt"
+	"compress/gzip"
 	"io"
 	"log"
 	"net/http"
@@ -16,49 +16,30 @@ import (
 )
 
 type ProxyHandler struct {
-	validator   *security.RequestValidator
-	rateLimiter *security.RateLimiter
-	cache       *cache.MemoryCache
-	client      *http.Client
-	userAgent   string
-	maxSize     int64
+	validator       *security.RequestValidator
+	rateLimiter     *security.RateLimiter
+	cache           *cache.MemoryCache
+	userAgent       string
+	maxResponseSize int64
 }
 
-func NewHandler(validator *security.RequestValidator, rateLimiter *security.RateLimiter, memCache *cache.MemoryCache, userAgent string, maxSize int64) *ProxyHandler {
+func NewHandler(
+	validator *security.RequestValidator,
+	rateLimiter *security.RateLimiter,
+	cache *cache.MemoryCache,
+	userAgent string,
+	maxResponseSize int64,
+) *ProxyHandler {
 	return &ProxyHandler{
-		validator:   validator,
-		rateLimiter: rateLimiter,
-		cache:       memCache,
-		userAgent:   userAgent,
-		maxSize:     maxSize,
-		client: &http.Client{
-			Timeout: 30 * time.Second,
-			CheckRedirect: func(req *http.Request, via []*http.Request) error {
-				if len(via) >= 10 {
-					return fmt.Errorf("too many redirects")
-				}
-				return nil
-			},
-		},
+		validator:       validator,
+		rateLimiter:     rateLimiter,
+		cache:           cache,
+		userAgent:       userAgent,
+		maxResponseSize: maxResponseSize,
 	}
 }
 
 func (h *ProxyHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
-	// Get the session ID
-	cookie, err := r.Cookie("session_id")
-	if err != nil {
-		http.Error(w, "Unauthorized", http.StatusUnauthorized)
-		return
-	}
-
-	sessionID := cookie.Value
-
-	// Rate limit check
-	if !h.rateLimiter.Allow(sessionID) {
-		http.Error(w, "Rate limit exceeded", http.StatusTooManyRequests)
-		return
-	}
-
 	// Get the target URL
 	targetURL := r.URL.Query().Get("url")
 	if targetURL == "" {
@@ -69,133 +50,181 @@ func (h *ProxyHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
 	// Validate the URL
 	if err := h.validator.ValidateURL(targetURL); err != nil {
 		log.Printf("URL validation failed: %v", err)
-		http.Error(w, "Invalid or blocked URL", http.StatusForbidden)
+		http.Error(w, "Invalid or blocked URL: "+err.Error(), http.StatusForbidden)
+		return
+	}
+
+	// Rate limiting
+	clientIP := getClientIP(r)
+	if !h.rateLimiter.Allow(clientIP) {
+		http.Error(w, "Rate limit exceeded", http.StatusTooManyRequests)
 		return
 	}
 
 	// Check the cache
-	cacheKey := h.cache.GenerateKey(targetURL)
-	if cached, ok := h.cache.Get(cacheKey); ok {
-		log.Printf("Cache hit: %s", targetURL)
-		h.serveCached(w, cached)
+	if entry := h.cache.Get(targetURL); entry != nil {
+		log.Printf("Cache HIT: %s", targetURL)
+		h.serveCached(w, entry)
 		return
 	}
 
-	// Make the proxy request
-	log.Printf("Proxying: %s", targetURL)
-	if err := h.proxyRequest(w, targetURL, cacheKey); err != nil {
-		log.Printf("Proxy error: %v", err)
-		http.Error(w, "Proxy request failed", http.StatusBadGateway)
-		return
-	}
-}
-
-func (h *ProxyHandler) proxyRequest(w http.ResponseWriter, targetURL, cacheKey string) error {
-	// Create the request
-	req, err := http.NewRequest("GET", targetURL, nil)
+	log.Printf("Cache MISS: %s", targetURL)
+
+	// Create the proxy request
+	proxyReq, err := http.NewRequest(r.Method, targetURL, r.Body)
 	if err != nil {
-		return err
+		http.Error(w, "Failed to create request", http.StatusInternalServerError)
+		return
 	}
 
 	// Set request headers
-	req.Header.Set("User-Agent", h.userAgent)
-	req.Header.Set("Accept", "text/html,application/xhtml+xml,application/xml;q=0.9,image/webp,*/*;q=0.8")
-	req.Header.Set("Accept-Language", "en-US,en;q=0.9")
-	req.Header.Set("Accept-Encoding", "gzip, deflate, br")
-	req.Header.Set("DNT", "1")
-	req.Header.Set("Connection", "keep-alive")
-	req.Header.Set("Upgrade-Insecure-Requests", "1")
+	h.setProxyHeaders(proxyReq, r)
 
 	// Send the request
-	resp, err := h.client.Do(req)
+	client := &http.Client{
+		Timeout: 30 * time.Second,
+		CheckRedirect: func(req *http.Request, via []*http.Request) error {
+			// Validate redirect URLs
+			if err := h.validator.ValidateURL(req.URL.String()); err != nil {
+				return err
+			}
+			return nil
+		},
+	}
+
+	resp, err := client.Do(proxyReq)
 	if err != nil {
-		return err
+		log.Printf("Request failed: %v", err)
+		http.Error(w, "Failed to fetch URL", http.StatusBadGateway)
+		return
 	}
 	defer resp.Body.Close()
 
-	// Limit the response size
-	limitedReader := io.LimitReader(resp.Body, h.maxSize)
-	body, err := io.ReadAll(limitedReader)
+	// Read the response body
+	body, err := h.readResponseBody(resp)
 	if err != nil {
-		return err
+		log.Printf("Failed to read response: %v", err)
+		http.Error(w, "Failed to read response", http.StatusInternalServerError)
+		return
+	}
+
+	// Rewrite the content
+	contentType := resp.Header.Get("Content-Type")
+	body = h.rewriteContent(body, targetURL, contentType)
+
+	// Cache the response
+	if h.shouldCache(resp) {
+		h.cache.Set(targetURL, body, resp.Header)
+	}
+
+	// Send the response
+	h.sendResponse(w, resp, body)
+}
+
+func (h *ProxyHandler) rewriteContent(body []byte, targetURL, contentType string) []byte {
+	rewriter, err := NewContentRewriter(targetURL)
+	if err != nil {
+		log.Printf("Failed to create rewriter: %v", err)
+		return body
+	}
+
+	contentType = strings.ToLower(contentType)
+
+	// HTML content
+	if strings.Contains(contentType, "text/html") {
+		rewritten, err := rewriter.RewriteHTML(body)
+		if err != nil {
+			log.Printf("HTML rewrite failed: %v", err)
+			return body
+		}
+		return rewritten
+	}
+
+	// CSS content
+	if strings.Contains(contentType, "text/css") {
+		return rewriter.RewriteCSS(body)
+	}
+
+	// JavaScript content is not rewritten for now, since rewriting could break functionality.
+	// Smarter JS rewriting could be added later.
+
+	return body
+}
+
+func (h *ProxyHandler) readResponseBody(resp *http.Response) ([]byte, error) {
+	var reader io.Reader = resp.Body
+
+	// Handle gzip compression
+	if resp.Header.Get("Content-Encoding") == "gzip" {
+		gzReader, err := gzip.NewReader(resp.Body)
+		if err != nil {
+			return nil, err
+		}
+		defer gzReader.Close()
+		reader = gzReader
+	}
+
+	// Limit the read size
+	limitReader := io.LimitReader(reader, h.maxResponseSize)
+
+	return io.ReadAll(limitReader)
+}
+
+func (h *ProxyHandler) setProxyHeaders(proxyReq, originalReq *http.Request) {
+	// Copy the request headers we need to forward
+	headersToForward := []string{
+		"Accept",
+		"Accept-Language",
+		"Accept-Encoding",
+		"Cache-Control",
+		"Referer",
+	}
+
+	for _, header := range headersToForward {
+		if value := originalReq.Header.Get(header); value != "" {
+			proxyReq.Header.Set(header, value)
+		}
+	}
+
+	// Set the custom User-Agent
+	proxyReq.Header.Set("User-Agent", h.userAgent)
+
+	// Remove headers that could expose the proxy
+	proxyReq.Header.Del("X-Forwarded-For")
+	proxyReq.Header.Del("X-Real-IP")
+	proxyReq.Header.Del("Via")
+}
+
+func (h *ProxyHandler) shouldCache(resp *http.Response) bool {
+	// Only cache successful GET requests
+	if resp.Request.Method != "GET" {
+		return false
+	}
+
+	if resp.StatusCode != http.StatusOK {
+		return false
+	}
+
+	// Check Cache-Control
+	cacheControl := resp.Header.Get("Cache-Control")
+	if strings.Contains(cacheControl, "no-store") ||
+		strings.Contains(cacheControl, "no-cache") ||
+		strings.Contains(cacheControl, "private") {
+		return false
 	}
 
 	// Check the content type
 	contentType := resp.Header.Get("Content-Type")
-
-	// If it is HTML, rewrite the URLs
-	if strings.Contains(contentType, "text/html") {
-		body = h.rewriteHTML(body, targetURL)
-	} else if strings.Contains(contentType, "text/css") {
-		body = h.rewriteCSS(body, targetURL)
-	}
-
-	// Cache static assets
-	if h.shouldCache(contentType) {
-		headers := make(map[string]string)
-		headers["Content-Type"] = contentType
-		h.cache.Set(cacheKey, body, headers)
-	}
-
-	// Set response headers
-	w.Header().Set("Content-Type", contentType)
-	w.Header().Set("X-Proxied-By", "SiteProxy")
-
-	// Remove headers that could leak information
-	for _, header := range []string{"Server", "X-Powered-By", "Set-Cookie"} {
-		w.Header().Del(header)
-	}
-
-	w.WriteHeader(resp.StatusCode)
-	w.Write(body)
-
-	return nil
-}
-
-func (h *ProxyHandler) rewriteHTML(body []byte, baseURL string) []byte {
-	content := string(body)
-
-	// Parse the base URL
-	base, err := url.Parse(baseURL)
-	if err != nil {
-		return body
-	}
-
-	// Rewrite absolute URLs
-	content = strings.ReplaceAll(content, `href="`+base.Scheme+`://`+base.Host, `href="/proxy?url=`+url.QueryEscape(base.Scheme+`://`+base.Host))
-	content = strings.ReplaceAll(content, `src="`+base.Scheme+`://`+base.Host, `src="/proxy?url=`+url.QueryEscape(base.Scheme+`://`+base.Host))
-
-	// Rewrite relative URLs (simplified version)
-	// Note: a full implementation needs an HTML parser
-
-	return []byte(content)
-}
-
-func (h *ProxyHandler) rewriteCSS(body []byte, baseURL string) []byte {
-	content := string(body)
-
-	base, err := url.Parse(baseURL)
-	if err != nil {
-		return body
-	}
-
-	// Rewrite url() references in the CSS
-	content = strings.ReplaceAll(content, `url(`+base.Scheme+`://`+base.Host, `url(/proxy?url=`+url.QueryEscape(base.Scheme+`://`+base.Host))
-
-	return []byte(content)
-}
-
-func (h *ProxyHandler) shouldCache(contentType string) bool {
 	cacheableTypes := []string{
-		"image/",
+		"text/html",
 		"text/css",
 		"application/javascript",
-		"application/json",
+		"image/",
 		"font/",
 	}
 
-	for _, t := range cacheableTypes {
-		if strings.Contains(contentType, t) {
+	for _, ct := range cacheableTypes {
+		if strings.Contains(contentType, ct) {
 			return true
 		}
 	}
@@ -203,12 +232,72 @@ func (h *ProxyHandler) shouldCache(contentType string) bool {
 	return false
 }
 
+func (h *ProxyHandler) sendResponse(w http.ResponseWriter, resp *http.Response, body []byte) {
+	// Copy the response headers we forward
+	headersToForward := []string{
+		"Content-Type",
+		"Content-Language",
+		"Last-Modified",
+		"ETag",
+		"Expires",
+	}
+
+	for _, header := range headersToForward {
+		if value := resp.Header.Get(header); value != "" {
+			w.Header().Set(header, value)
+		}
+	}
+
+	// Add custom headers
+	w.Header().Set("X-Proxied-By", "SiteProxy")
+	w.Header().Set("X-Cache-Status", "MISS")
+
+	// Remove headers we no longer need
+	w.Header().Del("Content-Encoding") // the body was already decompressed
+	w.Header().Del("Content-Length")   // the length may have changed
+
+	// Security headers
+	w.Header().Set("X-Content-Type-Options", "nosniff")
+	w.Header().Set("X-Frame-Options", "SAMEORIGIN")
+	w.Header().Set("Referrer-Policy", "no-referrer")
+
+	w.WriteHeader(resp.StatusCode)
+	w.Write(body)
+}
+
 func (h *ProxyHandler) serveCached(w http.ResponseWriter, entry *cache.CacheEntry) {
 	for key, value := range entry.Headers {
 		w.Header().Set(key, value)
 	}
 
+	w.Header().Set("X-Cache-Status", "HIT")
 	w.Header().Set("X-Proxied-By", "SiteProxy")
+	w.Header().Set("Age", entry.Age())
+
 	w.WriteHeader(http.StatusOK)
 	w.Write(entry.Data)
 }
+
+func getClientIP(r *http.Request) string {
+	// Try the common proxy headers for the real client IP
+	if ip := r.Header.Get("X-Real-IP"); ip != "" {
+		return ip
+	}
+
+	if ip := r.Header.Get("X-Forwarded-For"); ip != "" {
+		// X-Forwarded-For may contain multiple IPs
+		ips := strings.Split(ip, ",")
+		if len(ips) > 0 {
+			return strings.TrimSpace(ips[0])
+		}
+	}
+
+	// Fall back to the remote address
+	ip := r.RemoteAddr
+	if idx := strings.LastIndex(ip, ":"); idx != -1 {
+		ip = ip[:idx]
+	}
+
+	return ip
+}
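
For reviewers, a minimal sketch of how the new constructor might be wired into a server. Only proxy.NewHandler's signature comes from this diff; the module path (example.com/siteproxy) and the security/cache constructors (NewRequestValidator, NewRateLimiter, NewMemoryCache) are hypothetical placeholders.

    package main

    import (
    	"log"
    	"net/http"

    	"example.com/siteproxy/cache"    // assumed module path
    	"example.com/siteproxy/proxy"
    	"example.com/siteproxy/security"
    )

    func main() {
    	// Hypothetical constructors; only proxy.NewHandler is taken from the diff.
    	validator := security.NewRequestValidator()
    	limiter := security.NewRateLimiter(60) // e.g. 60 requests per minute per client IP
    	memCache := cache.NewMemoryCache()

    	handler := proxy.NewHandler(
    		validator,
    		limiter,
    		memCache,
    		"SiteProxy/1.0",
    		10<<20, // 10 MiB response size cap
    	)

    	mux := http.NewServeMux()
    	mux.Handle("/proxy", handler) // requests arrive as GET /proxy?url=<target>
    	log.Fatal(http.ListenAndServe(":8080", mux))
    }

Note that with this wiring the handler now rate-limits by client IP (via getClientIP) rather than by session cookie, so no session middleware is required in front of it.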