Files
SiteProxy/proxy/handler.go
2025-12-15 18:00:49 +08:00

342 lines
9.9 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
// proxy/handler.go
package proxy
import (
	"compress/gzip"
	"io"
	"log"
	"net"
	"net/http"
	"net/url"
	"strings"
	"time"

	"siteproxy/cache"
	"siteproxy/security"
)
// ProxyHandler is the http.Handler that serves proxied pages under /p/<token>/.
// It carries the collaborators and limits used while handling one request.
type ProxyHandler struct {
// validator checks the upstream URL, and every redirect hop, before it is fetched.
validator *security.RequestValidator
// rateLimiter decides per client IP whether a request may proceed.
rateLimiter *security.RateLimiter
// cache stores rewritten response bodies and headers keyed by target URL.
cache *cache.MemoryCache
// sessionManager maps the token in the request path to a proxy session.
sessionManager *ProxySessionManager
// userAgent is sent as the User-Agent header on upstream requests.
userAgent string
// maxResponseSize caps how many bytes of an upstream body are read.
maxResponseSize int64
}
// NewHandler returns a ProxyHandler wired with the given collaborators.
//
// userAgent is the User-Agent sent to upstream servers, and maxResponseSize
// is the maximum number of upstream body bytes read per response.
func NewHandler(
	validator *security.RequestValidator,
	rateLimiter *security.RateLimiter,
	cache *cache.MemoryCache,
	sessionManager *ProxySessionManager,
	userAgent string,
	maxResponseSize int64,
) *ProxyHandler {
	handler := new(ProxyHandler)

	// Collaborators.
	handler.validator = validator
	handler.rateLimiter = rateLimiter
	handler.cache = cache
	handler.sessionManager = sessionManager

	// Plain settings.
	handler.userAgent = userAgent
	handler.maxResponseSize = maxResponseSize

	return handler
}
// ServeHTTP proxies a request of the form /p/<token>/<target> to the upstream
// site bound to the session identified by <token>.
//
// Processing order: resolve the session, build and validate the upstream URL,
// apply the per-client rate limit, consult the cache (GET only), fetch the
// upstream response, rewrite its content, store it in the cache when
// shouldCache allows it, and write the response.
//
// Status codes written on failure: 400 (empty token), 401 (unknown session),
// 403 (URL rejected by the validator), 429 (rate limited), 502 (upstream
// request failed) and 500 (request construction or body read failed).
func (h *ProxyHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
	token, targetPath := splitProxyPath(r.URL.Path)
	if token == "" {
		http.Error(w, "Invalid token", http.StatusBadRequest)
		return
	}

	session := h.sessionManager.Get(token)
	if session == nil {
		http.Error(w, "Session expired or invalid", http.StatusUnauthorized)
		return
	}

	targetURL := joinTargetURL(session.TargetURL, targetPath, r.URL.RawQuery)
	if err := h.validator.ValidateURL(targetURL); err != nil {
		log.Printf("URL validation failed: %v", err)
		http.Error(w, "Invalid URL", http.StatusForbidden)
		return
	}

	if !h.rateLimiter.Allow(getClientIP(r)) {
		http.Error(w, "Rate limit exceeded", http.StatusTooManyRequests)
		return
	}

	// NOTE(review): cached bodies have already been rewritten with the token
	// of the request that populated the entry, while the key is derived from
	// the target URL only. Confirm that a HIT cannot hand out links carrying
	// another session's token.
	cacheKey := h.cache.GenerateKey(targetURL)

	// Only GET responses are ever stored, so only GET requests may be
	// answered from the cache; any other method must reach the upstream.
	if r.Method == http.MethodGet {
		if entry := h.cache.Get(cacheKey); entry != nil {
			log.Printf("Cache HIT: %s", targetURL)
			h.serveCached(w, entry)
			return
		}
		log.Printf("Cache MISS: %s", targetURL)
	}

	// Bind the upstream request to the client's context so that a client
	// disconnect cancels the upstream fetch.
	proxyReq, err := http.NewRequestWithContext(r.Context(), r.Method, targetURL, r.Body)
	if err != nil {
		http.Error(w, "Failed to create request", http.StatusInternalServerError)
		return
	}
	h.setProxyHeaders(proxyReq, r)

	client := &http.Client{
		Timeout: 30 * time.Second,
		// Every redirect hop is validated like the initial URL.
		CheckRedirect: func(req *http.Request, via []*http.Request) error {
			return h.validator.ValidateURL(req.URL.String())
		},
	}

	resp, err := client.Do(proxyReq)
	if err != nil {
		log.Printf("Request failed: %v", err)
		http.Error(w, "Failed to fetch URL", http.StatusBadGateway)
		return
	}
	if resp.StatusCode == http.StatusNotFound {
		if retryResp := h.retryNotFound(client, r, targetURL, token); retryResp != nil {
			resp.Body.Close()
			resp = retryResp
		}
	}
	// Deferred only after the possible swap so exactly the response that is
	// read below gets closed here.
	defer resp.Body.Close()

	body, err := h.readResponseBody(resp)
	if err != nil {
		log.Printf("Failed to read response: %v", err)
		http.Error(w, "Failed to read response", http.StatusInternalServerError)
		return
	}

	body = h.rewriteContent(body, targetURL, resp.Header.Get("Content-Type"), token)

	if h.shouldCache(resp) {
		headers := make(map[string]string, len(resp.Header))
		for key, values := range resp.Header {
			// The stored body is decoded, so encoding and length no longer apply.
			if key == "Content-Encoding" || key == "Content-Length" {
				continue
			}
			if len(values) > 0 {
				headers[key] = values[0]
			}
		}
		h.cache.Set(cacheKey, body, headers)
	}

	h.sendResponse(w, resp, body)
}

// splitProxyPath splits a request path of the form "/p/<token>/<rest>" into
// the token and the remaining target path (without its leading slash). When
// nothing follows the token the target path is empty.
func splitProxyPath(path string) (token, targetPath string) {
	path = strings.TrimPrefix(path, "/p/")
	idx := strings.Index(path, "/")
	if idx == -1 {
		return path, ""
	}
	return path[:idx], path[idx+1:]
}

// joinTargetURL builds the upstream URL for a proxied request. An absolute
// http(s) targetPath is used as is; anything else is appended to base.
// rawQuery, when non-empty, is appended with "?" or "&" as appropriate.
func joinTargetURL(base, targetPath, rawQuery string) string {
	// Restore the second slash of an absolute URL embedded in the path
	// ("http:/host" -> "http://host"); presumably it is lost when the
	// request path is cleaned before reaching this handler.
	switch {
	case strings.HasPrefix(targetPath, "http:/") && !strings.HasPrefix(targetPath, "http://"):
		targetPath = strings.Replace(targetPath, "http:/", "http://", 1)
	case strings.HasPrefix(targetPath, "https:/") && !strings.HasPrefix(targetPath, "https://"):
		targetPath = strings.Replace(targetPath, "https:/", "https://", 1)
	}

	var target string
	if strings.HasPrefix(targetPath, "http://") || strings.HasPrefix(targetPath, "https://") {
		target = targetPath
	} else {
		target = strings.TrimSuffix(base, "/")
		if targetPath != "" {
			if !strings.HasPrefix(targetPath, "/") {
				targetPath = "/" + targetPath
			}
			target += targetPath
		}
	}

	if rawQuery != "" {
		sep := "?"
		if strings.Contains(target, "?") {
			sep = "&"
		}
		target += sep + rawQuery
	}
	return target
}

// retryNotFound handles an upstream 404 by retrying once with a URL rebuilt
// from the scheme, host, path and query of targetURL. The retry is attempted
// only when the Referer shows the request came from a page of the same proxy
// session and the rebuilt URL differs from targetURL. It returns the retry
// response when that response is 200 OK, and nil otherwise; the caller owns
// and must close the returned body.
func (h *ProxyHandler) retryNotFound(client *http.Client, r *http.Request, targetURL, token string) *http.Response {
	referer := r.Header.Get("Referer")
	if referer == "" || !strings.Contains(referer, "/p/"+token) {
		return nil
	}

	u, err := url.Parse(targetURL)
	if err != nil || u.Host == "" {
		return nil
	}
	retryURL := u.Scheme + "://" + u.Host + u.Path
	if u.RawQuery != "" {
		retryURL += "?" + u.RawQuery
	}
	// Avoid requesting the very same URL a second time.
	if retryURL == targetURL {
		return nil
	}

	retryReq, err := http.NewRequestWithContext(r.Context(), r.Method, retryURL, nil)
	if err != nil {
		log.Printf("404 retry request failed: %v", err)
		return nil
	}
	h.setProxyHeaders(retryReq, r)

	retryResp, err := client.Do(retryReq)
	if err != nil {
		return nil
	}
	if retryResp.StatusCode != http.StatusOK {
		// An unsuccessful retry is discarded; close it so the connection
		// is not leaked.
		retryResp.Body.Close()
		return nil
	}
	log.Printf("404 retry success: %s -> %s", targetURL, retryURL)
	return retryResp
}
// rewriteContent passes HTML and CSS bodies through a ContentRewriter built
// for targetURL and token and returns the rewritten bytes. The content type
// is matched case-insensitively. Bodies of any other type are returned
// unchanged, as is the original body when the rewriter cannot be created or
// the HTML rewrite fails.
func (h *ProxyHandler) rewriteContent(body []byte, targetURL, contentType, token string) []byte {
	rw, err := NewContentRewriter(targetURL, token)
	if err != nil {
		log.Printf("Failed to create rewriter: %v", err)
		return body
	}

	switch mediaType := strings.ToLower(contentType); {
	case strings.Contains(mediaType, "text/html"):
		out, htmlErr := rw.RewriteHTML(body)
		if htmlErr != nil {
			log.Printf("HTML rewrite failed: %v", htmlErr)
			return body
		}
		return out
	case strings.Contains(mediaType, "text/css"):
		return rw.RewriteCSS(body)
	default:
		return body
	}
}
// readResponseBody returns the decoded upstream body, reading at most
// h.maxResponseSize bytes.
//
// A body labelled "Content-Encoding: gzip" (matched case-insensitively) is
// decompressed first, so the cap applies to the decompressed bytes. An empty
// body with that label yields an empty result. A non-empty body that does not
// start with a valid gzip header is reported as an error: gzip.NewReader has
// already consumed the beginning of the stream at that point, so falling back
// to the raw body would return data with its start missing.
//
// NOTE: bodies longer than the cap are truncated silently, not rejected.
func (h *ProxyHandler) readResponseBody(resp *http.Response) ([]byte, error) {
	var reader io.Reader = resp.Body

	encoding := strings.TrimSpace(resp.Header.Get("Content-Encoding"))
	if strings.EqualFold(encoding, "gzip") {
		gz, err := gzip.NewReader(resp.Body)
		switch {
		case err == io.EOF:
			// Nothing to decode (e.g. a bodiless response that still
			// advertises its encoding).
			return []byte{}, nil
		case err != nil:
			return nil, err
		}
		defer gz.Close()
		reader = gz
	}

	return io.ReadAll(io.LimitReader(reader, h.maxResponseSize))
}
// setProxyHeaders prepares the headers of the upstream request proxyReq from
// the client's request originalReq.
//
// Only an allow-list of client headers is copied. Content-Type is part of it
// because the client's request body is forwarded and the upstream needs the
// type to interpret it. Accept-Encoding is fixed to gzip, the only encoding
// readResponseBody decodes, and User-Agent is replaced by h.userAgent.
// Headers that could reveal the client or the proxy are removed.
func (h *ProxyHandler) setProxyHeaders(proxyReq, originalReq *http.Request) {
	forwarded := []string{
		"Accept",
		"Accept-Language",
		"Cache-Control",
		"Content-Type",
	}
	for _, name := range forwarded {
		if value := originalReq.Header.Get(name); value != "" {
			proxyReq.Header.Set(name, value)
		}
	}

	proxyReq.Header.Set("Accept-Encoding", "gzip")
	proxyReq.Header.Set("User-Agent", h.userAgent)

	// Defensive: these are absent on a freshly built request, but must never
	// reach the upstream even if proxyReq was prepared elsewhere.
	for _, name := range []string{"X-Forwarded-For", "X-Real-IP", "Via", "Referer"} {
		proxyReq.Header.Del(name)
	}
}
// shouldCache reports whether an upstream response may be stored in the cache.
//
// A response qualifies only when it answers a GET with 200 OK, its
// Cache-Control does not contain no-store, no-cache or private, and its
// Content-Type is one of the cacheable kinds (HTML, CSS, JavaScript, images,
// fonts). Directive names and media types are case-insensitive in HTTP, so
// both headers are lower-cased before matching.
func (h *ProxyHandler) shouldCache(resp *http.Response) bool {
	if resp == nil || resp.Request == nil {
		return false
	}
	if resp.Request.Method != http.MethodGet || resp.StatusCode != http.StatusOK {
		return false
	}

	cacheControl := strings.ToLower(resp.Header.Get("Cache-Control"))
	for _, directive := range []string{"no-store", "no-cache", "private"} {
		if strings.Contains(cacheControl, directive) {
			return false
		}
	}

	contentType := strings.ToLower(resp.Header.Get("Content-Type"))
	cacheableTypes := []string{
		"text/html",
		"text/css",
		"application/javascript",
		"text/javascript", // the other common JavaScript media type
		"image/",
		"font/",
	}
	for _, ct := range cacheableTypes {
		if strings.Contains(contentType, ct) {
			return true
		}
	}
	return false
}
// sendResponse writes a freshly fetched upstream response to the client.
//
// Only a fixed allow-list of upstream headers is copied; every other upstream
// header is dropped. The proxy then adds its own marker, cache-status and
// security headers, sends the upstream status code, and writes body.
func (h *ProxyHandler) sendResponse(w http.ResponseWriter, resp *http.Response, body []byte) {
	out := w.Header()

	passthrough := [...]string{"Content-Type", "Content-Language", "Last-Modified", "ETag", "Expires"}
	for i := range passthrough {
		name := passthrough[i]
		value := resp.Header.Get(name)
		if value == "" {
			continue
		}
		out.Set(name, value)
	}

	// Headers the proxy always sets itself, as name/value pairs.
	fixed := [...][2]string{
		{"X-Proxied-By", "SiteProxy"},
		{"X-Cache-Status", "MISS"},
		{"X-Content-Type-Options", "nosniff"},
		{"X-Frame-Options", "SAMEORIGIN"},
		{"Referrer-Policy", "no-referrer"},
	}
	for _, pair := range fixed {
		out.Set(pair[0], pair[1])
	}

	w.WriteHeader(resp.StatusCode)
	w.Write(body)
}
// serveCached writes a cache entry to the client with status 200 OK.
//
// A cache HIT exposes the same header set as sendResponse does on a MISS:
// only the allow-listed upstream headers stored with the entry are replayed
// (so stored headers such as Set-Cookie never reach another client), and the
// same marker and security headers are added. X-Cache-Status is HIT and Age
// carries entry.Age().
func (h *ProxyHandler) serveCached(w http.ResponseWriter, entry *cache.CacheEntry) {
	out := w.Header()

	for key, value := range entry.Headers {
		// Stored keys come from the upstream response header map, so they
		// are compared in canonical form ("Etag", not "ETag").
		switch http.CanonicalHeaderKey(key) {
		case "Content-Type", "Content-Language", "Last-Modified", "Etag", "Expires":
			if value != "" {
				out.Set(key, value)
			}
		}
	}

	out.Set("X-Proxied-By", "SiteProxy")
	out.Set("X-Cache-Status", "HIT")
	out.Set("X-Content-Type-Options", "nosniff")
	out.Set("X-Frame-Options", "SAMEORIGIN")
	out.Set("Referrer-Policy", "no-referrer")
	out.Set("Age", entry.Age())

	w.WriteHeader(http.StatusOK)
	w.Write(entry.Data)
}
// getClientIP returns the address used to identify the client of r.
//
// Precedence: the X-Real-IP header, then the first entry of X-Forwarded-For,
// then the host part of r.RemoteAddr. Header values are trimmed and empty
// values are skipped. IPv6 peer addresses are returned without brackets
// ("[::1]:8080" yields "::1"), and a RemoteAddr without a port is returned
// as is.
//
// NOTE(review): both headers are supplied by the client unless a trusted
// reverse proxy overwrites them; confirm the deployment does, otherwise a
// client can pick its own identity.
func getClientIP(r *http.Request) string {
	if ip := strings.TrimSpace(r.Header.Get("X-Real-IP")); ip != "" {
		return ip
	}

	if forwarded := r.Header.Get("X-Forwarded-For"); forwarded != "" {
		first := forwarded
		if idx := strings.Index(forwarded, ","); idx != -1 {
			first = forwarded[:idx]
		}
		if ip := strings.TrimSpace(first); ip != "" {
			return ip
		}
	}

	host, _, err := net.SplitHostPort(r.RemoteAddr)
	if err != nil {
		// No port to strip; only drop the brackets of a bare IPv6 literal.
		return strings.TrimSuffix(strings.TrimPrefix(r.RemoteAddr, "["), "]")
	}
	return host
}