// proxy/handler.go package proxy import ( "compress/gzip" "io" "log" "net/http" "strings" "time" "siteproxy/cache" "siteproxy/security" ) type ProxyHandler struct { validator *security.RequestValidator rateLimiter *security.RateLimiter cache *cache.MemoryCache sessionManager *ProxySessionManager userAgent string maxResponseSize int64 } func NewHandler( validator *security.RequestValidator, rateLimiter *security.RateLimiter, cache *cache.MemoryCache, sessionManager *ProxySessionManager, userAgent string, maxResponseSize int64, ) *ProxyHandler { return &ProxyHandler{ validator: validator, rateLimiter: rateLimiter, cache: cache, sessionManager: sessionManager, userAgent: userAgent, maxResponseSize: maxResponseSize, } } func (h *ProxyHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { path := strings.TrimPrefix(r.URL.Path, "/p/") slashIdx := strings.Index(path, "/") var token, targetPath string if slashIdx == -1 { token = path targetPath = "" } else { token = path[:slashIdx] targetPath = path[slashIdx+1:] } if token == "" { http.Error(w, "Invalid token", http.StatusBadRequest) return } session := h.sessionManager.Get(token) if session == nil { http.Error(w, "Session expired or invalid", http.StatusUnauthorized) return } var targetURL string if strings.HasPrefix(targetPath, "http:/") && !strings.HasPrefix(targetPath, "http://") { targetPath = strings.Replace(targetPath, "http:/", "http://", 1) } if strings.HasPrefix(targetPath, "https:/") && !strings.HasPrefix(targetPath, "https://") { targetPath = strings.Replace(targetPath, "https:/", "https://", 1) } if strings.HasPrefix(targetPath, "http://") || strings.HasPrefix(targetPath, "https://") { targetURL = targetPath } else { baseURL := strings.TrimSuffix(session.TargetURL, "/") if targetPath == "" { targetURL = baseURL } else { if !strings.HasPrefix(targetPath, "/") { targetPath = "/" + targetPath } targetURL = baseURL + targetPath } } if r.URL.RawQuery != "" { if strings.Contains(targetURL, "?") { targetURL += "&" + r.URL.RawQuery } else { targetURL += "?" + r.URL.RawQuery } } if err := h.validator.ValidateURL(targetURL); err != nil { log.Printf("URL validation failed: %v", err) http.Error(w, "Invalid URL", http.StatusForbidden) return } clientIP := getClientIP(r) if !h.rateLimiter.Allow(clientIP) { http.Error(w, "Rate limit exceeded", http.StatusTooManyRequests) return } cacheKey := h.cache.GenerateKey(targetURL) if entry := h.cache.Get(cacheKey); entry != nil { log.Printf("Cache HIT: %s", targetURL) h.serveCached(w, entry) return } log.Printf("Cache MISS: %s", targetURL) proxyReq, err := http.NewRequest(r.Method, targetURL, r.Body) if err != nil { http.Error(w, "Failed to create request", http.StatusInternalServerError) return } h.setProxyHeaders(proxyReq, r) client := &http.Client{ Timeout: 30 * time.Second, CheckRedirect: func(req *http.Request, via []*http.Request) error { if err := h.validator.ValidateURL(req.URL.String()); err != nil { return err } return nil }, } resp, err := client.Do(proxyReq) if err != nil { log.Printf("Request failed: %v", err) http.Error(w, "Failed to fetch URL", http.StatusBadGateway) return } defer resp.Body.Close() body, err := h.readResponseBody(resp) if err != nil { log.Printf("Failed to read response: %v", err) http.Error(w, "Failed to read response", http.StatusInternalServerError) return } contentType := resp.Header.Get("Content-Type") body = h.rewriteContent(body, targetURL, contentType, token) if h.shouldCache(resp) { headers := make(map[string]string) for key, values := range resp.Header { if key == "Content-Encoding" || key == "Content-Length" { continue } if len(values) > 0 { headers[key] = values[0] } } h.cache.Set(cacheKey, body, headers) } h.sendResponse(w, resp, body) } func (h *ProxyHandler) rewriteContent(body []byte, targetURL, contentType, token string) []byte { rewriter, err := NewContentRewriter(targetURL, token) if err != nil { log.Printf("Failed to create rewriter: %v", err) return body } contentType = strings.ToLower(contentType) if strings.Contains(contentType, "text/html") { rewritten, err := rewriter.RewriteHTML(body) if err != nil { log.Printf("HTML rewrite failed: %v", err) return body } return rewritten } if strings.Contains(contentType, "text/css") { return rewriter.RewriteCSS(body) } return body } func (h *ProxyHandler) readResponseBody(resp *http.Response) ([]byte, error) { var reader io.Reader = resp.Body encoding := strings.ToLower(resp.Header.Get("Content-Encoding")) if strings.Contains(encoding, "gzip") { gzReader, err := gzip.NewReader(resp.Body) if err != nil { return nil, err } defer gzReader.Close() reader = gzReader } limitReader := io.LimitReader(reader, h.maxResponseSize) return io.ReadAll(limitReader) } func (h *ProxyHandler) setProxyHeaders(proxyReq, originalReq *http.Request) { headersToForward := []string{ "Accept", "Accept-Language", "Accept-Encoding", "Cache-Control", } for _, header := range headersToForward { if value := originalReq.Header.Get(header); value != "" { proxyReq.Header.Set(header, value) } } proxyReq.Header.Set("User-Agent", h.userAgent) proxyReq.Header.Del("X-Forwarded-For") proxyReq.Header.Del("X-Real-IP") proxyReq.Header.Del("Via") proxyReq.Header.Del("Referer") } func (h *ProxyHandler) shouldCache(resp *http.Response) bool { if resp.Request.Method != "GET" || resp.StatusCode != http.StatusOK { return false } cacheControl := resp.Header.Get("Cache-Control") if strings.Contains(cacheControl, "no-store") || strings.Contains(cacheControl, "no-cache") || strings.Contains(cacheControl, "private") { return false } contentType := resp.Header.Get("Content-Type") cacheableTypes := []string{"text/html", "text/css", "application/javascript", "image/", "font/"} for _, ct := range cacheableTypes { if strings.Contains(contentType, ct) { return true } } return false } func (h *ProxyHandler) sendResponse(w http.ResponseWriter, resp *http.Response, body []byte) { headersToForward := []string{"Content-Type", "Content-Language", "Last-Modified", "ETag", "Expires"} for _, header := range headersToForward { if value := resp.Header.Get(header); value != "" { w.Header().Set(header, value) } } w.Header().Set("X-Proxied-By", "SiteProxy") w.Header().Set("X-Cache-Status", "MISS") w.Header().Set("X-Content-Type-Options", "nosniff") w.Header().Set("X-Frame-Options", "SAMEORIGIN") w.Header().Set("Referrer-Policy", "no-referrer") w.WriteHeader(resp.StatusCode) w.Write(body) } func (h *ProxyHandler) serveCached(w http.ResponseWriter, entry *cache.CacheEntry) { for key, value := range entry.Headers { if key == "Content-Encoding" || key == "Content-Length" { continue } w.Header().Set(key, value) } w.Header().Set("X-Cache-Status", "HIT") w.Header().Set("X-Proxied-By", "SiteProxy") w.Header().Set("Age", entry.Age()) w.WriteHeader(http.StatusOK) w.Write(entry.Data) } func getClientIP(r *http.Request) string { if ip := r.Header.Get("X-Real-IP"); ip != "" { return ip } if ip := r.Header.Get("X-Forwarded-For"); ip != "" { ips := strings.Split(ip, ",") if len(ips) > 0 { return strings.TrimSpace(ips[0]) } } ip := r.RemoteAddr if idx := strings.LastIndex(ip, ":"); idx != -1 { ip = ip[:idx] } return ip }