Files
SiteProxy/proxy/handler.go
2025-12-15 04:15:29 +08:00

284 lines
7.7 KiB
Go

// proxy/handler.go
package proxy
import (
	"compress/gzip"
	"fmt"
	"io"
	"log"
	"net"
	"net/http"
	"strings"
	"time"

	"siteproxy/cache"
	"siteproxy/security"
)
// ProxyHandler serves proxied requests addressed as /p/{token}/{path}.
// It implements http.Handler through its ServeHTTP method and holds the
// collaborators and settings shared by every request.
type ProxyHandler struct {
// validator checks each target URL, and each redirect hop, before it is fetched.
validator *security.RequestValidator
// rateLimiter is consulted once per request, keyed by the client address.
rateLimiter *security.RateLimiter
// cache holds rewritten response bodies together with upstream headers.
cache *cache.MemoryCache
// sessionManager resolves the {token} path segment to a proxy session.
sessionManager *ProxySessionManager
// userAgent is sent upstream in place of the client's own User-Agent.
userAgent string
// maxResponseSize bounds how many (decompressed) body bytes are read from upstream.
maxResponseSize int64
}
// NewHandler assembles a ProxyHandler from its collaborators.
//
// userAgent replaces the client's User-Agent on every upstream request, and
// maxResponseSize bounds the number of body bytes read from upstream.
func NewHandler(
	validator *security.RequestValidator,
	rateLimiter *security.RateLimiter,
	cache *cache.MemoryCache,
	sessionManager *ProxySessionManager,
	userAgent string,
	maxResponseSize int64,
) *ProxyHandler {
	h := new(ProxyHandler)
	h.validator = validator
	h.rateLimiter = rateLimiter
	h.cache = cache
	h.sessionManager = sessionManager
	h.userAgent = userAgent
	h.maxResponseSize = maxResponseSize
	return h
}
// ServeHTTP proxies requests of the form /p/{token}/{subpath}?{query}.
//
// The token selects a proxy session whose TargetURL is the upstream base. The
// remaining path and the raw query are appended to form the target URL. The
// target is validated and the client is rate limited. The response is then
// served from cache (GET only) or fetched upstream, passed through the
// content rewriter for HTML/CSS, optionally cached, and written to w.
func (h *ProxyHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
	// Redirect hops allowed per upstream fetch. A custom CheckRedirect replaces
	// net/http's default 10-redirect cap, so the cap is enforced explicitly.
	const maxRedirects = 10

	path := strings.TrimPrefix(r.URL.Path, "/p/")
	token, rest, hasRest := strings.Cut(path, "/")
	if token == "" {
		http.Error(w, "Invalid token", http.StatusBadRequest)
		return
	}
	subPath := ""
	if hasRest {
		subPath = "/" + rest
	}

	session := h.sessionManager.Get(token)
	if session == nil {
		http.Error(w, "Session expired or invalid", http.StatusUnauthorized)
		return
	}

	targetURL := session.TargetURL + subPath
	if r.URL.RawQuery != "" {
		targetURL += "?" + r.URL.RawQuery
	}
	if err := h.validator.ValidateURL(targetURL); err != nil {
		log.Printf("URL validation failed: %v", err)
		http.Error(w, "Invalid URL", http.StatusForbidden)
		return
	}

	clientIP := getClientIP(r)
	if !h.rateLimiter.Allow(clientIP) {
		http.Error(w, "Rate limit exceeded", http.StatusTooManyRequests)
		return
	}

	// The rewritten body is produced with the session token, so a cached body
	// is only valid for the token that produced it. The token belongs in the key.
	cacheKey := token + "|" + targetURL
	// Only GET requests may be answered from, or stored into, the cache. Other
	// methods must always reach the upstream so their side effects happen.
	cacheable := r.Method == http.MethodGet
	if cacheable {
		if entry := h.cache.Get(cacheKey); entry != nil {
			log.Printf("Cache HIT: %s", targetURL)
			h.serveCached(w, entry)
			return
		}
		log.Printf("Cache MISS: %s", targetURL)
	}

	// Tie the upstream request to the client's context so it is cancelled
	// when the client goes away.
	proxyReq, err := http.NewRequestWithContext(r.Context(), r.Method, targetURL, r.Body)
	if err != nil {
		http.Error(w, "Failed to create request", http.StatusInternalServerError)
		return
	}
	h.setProxyHeaders(proxyReq, r)

	client := &http.Client{
		Timeout: 30 * time.Second,
		CheckRedirect: func(req *http.Request, via []*http.Request) error {
			if len(via) >= maxRedirects {
				return fmt.Errorf("stopped after %d redirects", maxRedirects)
			}
			// Every redirect hop must pass the same validation as the initial URL.
			if err := h.validator.ValidateURL(req.URL.String()); err != nil {
				return err
			}
			return nil
		},
	}

	resp, err := client.Do(proxyReq)
	if err != nil {
		log.Printf("Request failed: %v", err)
		http.Error(w, "Failed to fetch URL", http.StatusBadGateway)
		return
	}
	defer resp.Body.Close()

	body, err := h.readResponseBody(resp)
	if err != nil {
		log.Printf("Failed to read response: %v", err)
		http.Error(w, "Failed to read response", http.StatusInternalServerError)
		return
	}

	contentType := resp.Header.Get("Content-Type")
	body = h.rewriteContent(body, targetURL, contentType, token)

	// shouldCache inspects the final request of a redirect chain, which may be
	// a GET even for a POST (303). Gate on the original method as well.
	if cacheable && h.shouldCache(resp) {
		headers := make(map[string]string, len(resp.Header))
		for key, values := range resp.Header {
			// The body is stored decoded and possibly rewritten, so the upstream
			// encoding and length no longer describe it.
			if key == "Content-Encoding" || key == "Content-Length" || len(values) == 0 {
				continue
			}
			headers[key] = values[0]
		}
		h.cache.Set(cacheKey, body, headers)
	}

	h.sendResponse(w, resp, body)
}
// rewriteContent runs HTML and CSS bodies through a ContentRewriter built for
// targetURL and the session token.
//
// Other content types are returned unchanged. If the rewriter cannot be
// created or the HTML rewrite fails, the original body is also returned
// unchanged and the failure is logged.
func (h *ProxyHandler) rewriteContent(body []byte, targetURL, contentType, token string) []byte {
	rw, err := NewContentRewriter(targetURL, token)
	if err != nil {
		log.Printf("Failed to create rewriter: %v", err)
		return body
	}

	ct := strings.ToLower(contentType)
	switch {
	case strings.Contains(ct, "text/html"):
		out, htmlErr := rw.RewriteHTML(body)
		if htmlErr != nil {
			log.Printf("HTML rewrite failed: %v", htmlErr)
			return body
		}
		return out
	case strings.Contains(ct, "text/css"):
		return rw.RewriteCSS(body)
	default:
		return body
	}
}
// readResponseBody reads the upstream response body. It gunzips the body
// when the upstream marked it gzip-encoded, and enforces h.maxResponseSize on
// the decoded size.
//
// A body larger than the limit produces an error instead of a silently
// truncated result. Partial pages are therefore never rewritten, cached, or
// served as if they were complete.
func (h *ProxyHandler) readResponseBody(resp *http.Response) ([]byte, error) {
	var reader io.Reader = resp.Body
	encoding := strings.ToLower(resp.Header.Get("Content-Encoding"))
	if strings.Contains(encoding, "gzip") {
		gzReader, err := gzip.NewReader(resp.Body)
		if err != nil {
			return nil, err
		}
		defer gzReader.Close()
		reader = gzReader
	}

	body, err := io.ReadAll(io.LimitReader(reader, h.maxResponseSize))
	if err != nil {
		return nil, err
	}
	if int64(len(body)) == h.maxResponseSize {
		// The limit was hit exactly. Probe for one more byte to tell an
		// exact-size body apart from a truncated one. A probe error counts
		// as "nothing more to read".
		var probe [1]byte
		if n, _ := io.ReadFull(reader, probe[:]); n > 0 {
			return nil, fmt.Errorf("response body exceeds %d bytes", h.maxResponseSize)
		}
	}
	return body, nil
}
// setProxyHeaders copies a small allow-list of content-negotiation headers
// from the client's request onto proxyReq. It also replaces the User-Agent
// with the proxy's configured value and removes client-identifying headers.
//
// Accept-Encoding is deliberately NOT forwarded. If it is set explicitly,
// net/http stops handling gzip transparently. The upstream may then answer
// with an encoding the proxy cannot decode (br, zstd, deflate). Such a body
// would be fed to the rewriter as compressed bytes and sent to the client
// without its Content-Encoding header. When the header is left unset, the
// transport requests gzip and decodes it automatically.
func (h *ProxyHandler) setProxyHeaders(proxyReq, originalReq *http.Request) {
	for _, header := range []string{"Accept", "Accept-Language", "Cache-Control"} {
		if value := originalReq.Header.Get(header); value != "" {
			proxyReq.Header.Set(header, value)
		}
	}

	proxyReq.Header.Set("User-Agent", h.userAgent)

	// Make sure nothing identifying the client or its origin page goes upstream.
	for _, header := range []string{"X-Forwarded-For", "X-Real-IP", "Via", "Referer"} {
		proxyReq.Header.Del(header)
	}
}
// shouldCache reports whether an upstream response may be stored in the cache.
//
// Only 200 responses to GET are cached, and only when Cache-Control does not
// contain no-store, no-cache or private. The Content-Type must also be
// HTML, CSS, JavaScript, an image or a font. Directive names and media types
// are case-insensitive (RFC 9110), so both headers are compared in lower case.
func (h *ProxyHandler) shouldCache(resp *http.Response) bool {
	if resp.Request == nil || resp.Request.Method != http.MethodGet || resp.StatusCode != http.StatusOK {
		return false
	}

	cacheControl := strings.ToLower(resp.Header.Get("Cache-Control"))
	for _, directive := range []string{"no-store", "no-cache", "private"} {
		if strings.Contains(cacheControl, directive) {
			return false
		}
	}

	contentType := strings.ToLower(resp.Header.Get("Content-Type"))
	// text/javascript is the standard JavaScript media type (RFC 9239).
	// application/javascript is kept for servers that still send it.
	cacheableTypes := []string{"text/html", "text/css", "text/javascript", "application/javascript", "image/", "font/"}
	for _, ct := range cacheableTypes {
		if strings.Contains(contentType, ct) {
			return true
		}
	}
	return false
}
// sendResponse writes a freshly fetched (cache-miss) upstream response to w.
//
// Only a fixed allow-list of upstream headers is copied. The proxy's
// identification and browser-hardening headers are added on top, then the
// upstream status code and the given body are written.
func (h *ProxyHandler) sendResponse(w http.ResponseWriter, resp *http.Response, body []byte) {
	out := w.Header()

	forwarded := [...]string{"Content-Type", "Content-Language", "Last-Modified", "ETag", "Expires"}
	for i := range forwarded {
		v := resp.Header.Get(forwarded[i])
		if v == "" {
			continue
		}
		out.Set(forwarded[i], v)
	}

	fixed := []struct{ name, value string }{
		{"X-Proxied-By", "SiteProxy"},
		{"X-Cache-Status", "MISS"},
		{"X-Content-Type-Options", "nosniff"},
		{"X-Frame-Options", "SAMEORIGIN"},
		{"Referrer-Policy", "no-referrer"},
	}
	for _, hdr := range fixed {
		out.Set(hdr.name, hdr.value)
	}

	w.WriteHeader(resp.StatusCode)
	// A failed write (e.g. client gone) cannot be reported once headers are sent.
	_, _ = w.Write(body)
}
// serveCached writes a cached response to w with status 200 OK.
//
// Cache entries may hold any upstream header, including Set-Cookie or the
// upstream's own security policies. Only the same allow-list that the
// cache-miss path forwards is replayed, so a cache hit never exposes headers
// a miss would have dropped. The proxy's browser-hardening headers are set
// exactly as on a miss.
func (h *ProxyHandler) serveCached(w http.ResponseWriter, entry *cache.CacheEntry) {
	out := w.Header()

	for _, name := range []string{"Content-Type", "Content-Language", "Last-Modified", "ETag", "Expires"} {
		// Stored keys come from resp.Header and are therefore in canonical
		// form ("ETag" is stored as "Etag").
		if value := entry.Headers[http.CanonicalHeaderKey(name)]; value != "" {
			out.Set(name, value)
		}
	}

	out.Set("X-Cache-Status", "HIT")
	out.Set("X-Proxied-By", "SiteProxy")
	out.Set("X-Content-Type-Options", "nosniff")
	out.Set("X-Frame-Options", "SAMEORIGIN")
	out.Set("Referrer-Policy", "no-referrer")
	out.Set("Age", entry.Age())

	w.WriteHeader(http.StatusOK)
	// A failed write (e.g. client gone) cannot be reported once headers are sent.
	_, _ = w.Write(entry.Data)
}
// getClientIP returns the address used to identify the client for rate
// limiting.
//
// It prefers a non-blank X-Real-IP header, then the first non-blank entry of
// X-Forwarded-For. Otherwise it uses the host part of r.RemoteAddr, with
// brackets stripped from IPv6 literals. If RemoteAddr has no port, it is
// returned as-is.
//
// NOTE(review): X-Real-IP and X-Forwarded-For are taken from the request
// unverified. They are only trustworthy when a reverse proxy in front of this
// server overwrites them. Otherwise a client can pick its own rate-limit key.
func getClientIP(r *http.Request) string {
	if ip := strings.TrimSpace(r.Header.Get("X-Real-IP")); ip != "" {
		return ip
	}
	if xff := r.Header.Get("X-Forwarded-For"); xff != "" {
		first, _, _ := strings.Cut(xff, ",")
		if ip := strings.TrimSpace(first); ip != "" {
			return ip
		}
	}
	// SplitHostPort handles "[::1]:8080" -> "::1", which a plain
	// last-colon split gets wrong.
	if host, _, err := net.SplitHostPort(r.RemoteAddr); err == nil {
		return host
	}
	return r.RemoteAddr
}