Files
SiteProxy/proxy/handler-base.go
2025-12-15 21:09:04 +08:00

335 lines
9.4 KiB
Go
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
// proxy/handler.go
package proxy
import (
	"bytes"
	"compress/gzip"
	"fmt"
	"io"
	"log"
	"net"
	"net/http"
	"strings"
	"time"

	"siteproxy/cache"
	"siteproxy/security"
)
// ProxyHandler serves proxied requests under /p/{token}/... (see ServeHTTP).
// Construct it with NewHandler.
type ProxyHandler struct {
// validator checks every upstream URL before it is fetched, including each
// redirect hop.
validator *security.RequestValidator
// rateLimiter is consulted once per request, keyed by getClientIP.
rateLimiter *security.RateLimiter
// cache stores rewritten upstream responses that shouldCache accepts.
cache *cache.MemoryCache
// sessionManager resolves the {token} path segment to a ProxySession.
sessionManager *ProxySessionManager
// userAgent is sent as the User-Agent header on every upstream request.
userAgent string
// maxResponseSize bounds how many (decompressed) bytes of an upstream
// response body are read.
maxResponseSize int64
}
// NewHandler builds a ProxyHandler from its collaborators.
//
// userAgent is sent upstream as the User-Agent of every proxied request, and
// maxResponseSize bounds how many bytes of an upstream body are read.
func NewHandler(
	validator *security.RequestValidator,
	rateLimiter *security.RateLimiter,
	cache *cache.MemoryCache,
	sessionManager *ProxySessionManager,
	userAgent string,
	maxResponseSize int64,
) *ProxyHandler {
	h := new(ProxyHandler)
	h.validator = validator
	h.rateLimiter = rateLimiter
	h.cache = cache
	h.sessionManager = sessionManager
	h.userAgent = userAgent
	h.maxResponseSize = maxResponseSize
	return h
}
// ServeHTTP proxies a request whose path has the form /p/{token}/{target}.
//
// {token} selects a ProxySession whose TargetURL is the base for relative
// targets; {target} may instead be an absolute http(s) URL. The resolved
// upstream URL is checked by the validator and the client is rate-limited.
// GET requests are answered from the cache when possible. Otherwise the
// request is forwarded, and the upstream body is passed through
// rewriteContent, stored in the cache if shouldCache allows it, and sent to
// the client.
func (h *ProxyHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
	token, targetPath := splitProxyPath(r.URL.Path)
	if token == "" {
		http.Error(w, "Invalid token", http.StatusBadRequest)
		return
	}

	session := h.sessionManager.Get(token)
	if session == nil {
		// Try to recover the session target from the cookie set below on an
		// earlier response.
		// NOTE(review): this cookie is client-controlled, so any token plus a
		// forged cookie yields a session for an arbitrary TargetURL; only the
		// ValidateURL check below constrains it. Confirm that is acceptable.
		cookie, err := r.Cookie("ORIGINALHOST_" + token)
		if err != nil {
			http.Error(w, "Session expired or invalid", http.StatusUnauthorized)
			return
		}
		session = &ProxySession{TargetURL: cookie.Value}
	}

	// Remember the session target in a cookie valid for 10 minutes.
	http.SetCookie(w, &http.Cookie{
		Name:     "ORIGINALHOST_" + token,
		Value:    session.TargetURL,
		Path:     "/p/" + token,
		MaxAge:   600,
		HttpOnly: true,
		SameSite: http.SameSiteLaxMode,
	})

	targetURL := buildUpstreamURL(session.TargetURL, targetPath, r.URL.RawQuery)
	if err := h.validator.ValidateURL(targetURL); err != nil {
		log.Printf("URL validation failed: %v", err)
		http.Error(w, "Invalid URL", http.StatusForbidden)
		return
	}

	clientIP := getClientIP(r)
	if !h.rateLimiter.Allow(clientIP) {
		http.Error(w, "Rate limit exceeded", http.StatusTooManyRequests)
		return
	}

	// The cached body is rewritten with the token, so the token must be part
	// of the key; otherwise a page cached for one session would be served to
	// another session carrying the first session's links.
	cacheKey := h.cache.GenerateKey(token + "|" + targetURL)
	// Only GET responses are ever stored (see shouldCache), so only GET may be
	// answered from the cache; a POST to the same URL must reach the upstream.
	if r.Method == "GET" {
		if entry := h.cache.Get(cacheKey); entry != nil {
			log.Printf("Cache HIT: %s", targetURL)
			h.serveCached(w, entry)
			return
		}
	}
	log.Printf("Cache MISS: %s", targetURL)

	var reqBody io.Reader = r.Body
	if r.Method == "POST" || r.Method == "PUT" || r.Method == "PATCH" {
		if r.Body != nil {
			buf, err := io.ReadAll(r.Body)
			if err != nil {
				log.Printf("Failed to read request body: %v", err)
				http.Error(w, "Failed to read request body", http.StatusBadRequest)
				return
			}
			// A *bytes.Reader lets NewRequestWithContext set ContentLength and
			// GetBody, so the body can be replayed on a 307/308 redirect.
			reqBody = bytes.NewReader(buf)
		}
	}

	// Tie the upstream fetch to the client's request so it is cancelled when
	// the client goes away instead of running until the timeout.
	proxyReq, err := http.NewRequestWithContext(r.Context(), r.Method, targetURL, reqBody)
	if err != nil {
		http.Error(w, "Failed to create request", http.StatusInternalServerError)
		return
	}
	h.setProxyHeaders(proxyReq, r)

	const maxRedirects = 10
	client := &http.Client{
		Timeout: 30 * time.Second,
		CheckRedirect: func(req *http.Request, via []*http.Request) error {
			// Supplying CheckRedirect replaces net/http's default 10-hop
			// limit, so enforce it here to stop redirect loops.
			if len(via) >= maxRedirects {
				return fmt.Errorf("stopped after %d redirects", maxRedirects)
			}
			// Every hop must pass the same validation as the initial URL.
			if err := h.validator.ValidateURL(req.URL.String()); err != nil {
				return err
			}
			return nil
		},
	}
	resp, err := client.Do(proxyReq)
	if err != nil {
		log.Printf("Request failed: %v", err)
		http.Error(w, "Failed to fetch URL", http.StatusBadGateway)
		return
	}
	defer resp.Body.Close()

	body, err := h.readResponseBody(resp)
	if err != nil {
		log.Printf("Failed to read response: %v", err)
		http.Error(w, "Failed to read response", http.StatusInternalServerError)
		return
	}

	// After redirects the document actually lives at resp.Request.URL, so
	// relative references must be rewritten against that URL.
	finalURL := targetURL
	if resp.Request != nil && resp.Request.URL != nil {
		finalURL = resp.Request.URL.String()
	}
	contentType := resp.Header.Get("Content-Type")
	body = h.rewriteContent(body, finalURL, contentType, token)

	if h.shouldCache(resp) {
		// The body is stored decoded, so the upstream encoding/length headers
		// no longer describe it.
		headers := make(map[string]string, len(resp.Header))
		for key, values := range resp.Header {
			if key == "Content-Encoding" || key == "Content-Length" {
				continue
			}
			if len(values) > 0 {
				headers[key] = values[0]
			}
		}
		h.cache.Set(cacheKey, body, headers)
	}
	h.sendResponse(w, resp, body)
}

// splitProxyPath splits a "/p/{token}/{target}" request path into the token
// and the remaining target; target is "" when nothing follows the token.
func splitProxyPath(urlPath string) (token, targetPath string) {
	rest := strings.TrimPrefix(urlPath, "/p/")
	if i := strings.Index(rest, "/"); i != -1 {
		return rest[:i], rest[i+1:]
	}
	return rest, ""
}

// buildUpstreamURL resolves the target part of a proxy path into the upstream
// URL. Absolute http(s) targets are used as-is. Anything else is appended to
// the session base URL. rawQuery, when non-empty, is appended last.
func buildUpstreamURL(sessionTarget, targetPath, rawQuery string) string {
	// An absolute target can arrive with a single slash after the scheme
	// ("https:/host/..."), presumably because "//" was collapsed somewhere
	// upstream; restore the missing slash.
	for _, scheme := range []string{"http:/", "https:/"} {
		if strings.HasPrefix(targetPath, scheme) && !strings.HasPrefix(targetPath, scheme+"/") {
			targetPath = scheme + "/" + targetPath[len(scheme):]
		}
	}

	var targetURL string
	switch {
	case strings.HasPrefix(targetPath, "http://"), strings.HasPrefix(targetPath, "https://"):
		targetURL = targetPath
	case targetPath == "":
		targetURL = strings.TrimSuffix(sessionTarget, "/")
	default:
		if !strings.HasPrefix(targetPath, "/") {
			targetPath = "/" + targetPath
		}
		targetURL = strings.TrimSuffix(sessionTarget, "/") + targetPath
	}

	if rawQuery != "" {
		sep := "?"
		if strings.Contains(targetURL, "?") {
			sep = "&"
		}
		targetURL += sep + rawQuery
	}
	return targetURL
}
// rewriteContent runs HTML and CSS bodies through a ContentRewriter built for
// targetURL and token (presumably rewriting links so they route back through
// the proxy). The Content-Type match is case-insensitive and HTML takes
// precedence over CSS.
//
// Any other content type, a failure to build the rewriter, and an HTML
// rewrite error all return body unchanged. The failures are logged.
func (h *ProxyHandler) rewriteContent(body []byte, targetURL, contentType, token string) []byte {
	rw, err := NewContentRewriter(targetURL, token)
	if err != nil {
		log.Printf("Failed to create rewriter: %v", err)
		return body
	}

	switch mediaType := strings.ToLower(contentType); {
	case strings.Contains(mediaType, "text/html"):
		out, htmlErr := rw.RewriteHTML(body)
		if htmlErr != nil {
			log.Printf("HTML rewrite failed: %v", htmlErr)
			return body
		}
		return out
	case strings.Contains(mediaType, "text/css"):
		return rw.RewriteCSS(body)
	default:
		return body
	}
}
// readResponseBody reads resp's body and returns it. The body is gunzipped
// first when the upstream declared Content-Encoding: gzip, matched
// case-insensitively. The outgoing request sets Accept-Encoding itself, so
// the transport does not decompress on our behalf.
//
// At most h.maxResponseSize decoded bytes are accepted. A longer body is
// returned as an error rather than silently truncated, because a truncated
// page would otherwise be served and cached as if it were complete. Applying
// the limit after decompression also bounds gzip bombs.
//
// NOTE: a negative maxResponseSize rejects every response.
func (h *ProxyHandler) readResponseBody(resp *http.Response) ([]byte, error) {
	var reader io.Reader = resp.Body
	if strings.EqualFold(strings.TrimSpace(resp.Header.Get("Content-Encoding")), "gzip") {
		gzReader, err := gzip.NewReader(resp.Body)
		if err == io.EOF {
			// An empty body labelled gzip (e.g. on a 204/304) has nothing to decode.
			return []byte{}, nil
		}
		if err != nil {
			// Falling back to the raw bytes would hand compressed data to a
			// client that is not told about the encoding.
			return nil, fmt.Errorf("decoding gzip response: %w", err)
		}
		defer gzReader.Close()
		reader = gzReader
	}

	// Read one byte past the limit so an oversized body can be told apart
	// from one of exactly maxResponseSize bytes (guarding against overflow).
	const maxInt64 = 1<<63 - 1
	limit := h.maxResponseSize
	if limit >= 0 && limit < maxInt64 {
		limit++
	}
	body, err := io.ReadAll(io.LimitReader(reader, limit))
	if err != nil {
		return nil, fmt.Errorf("reading response body: %w", err)
	}
	if int64(len(body)) > h.maxResponseSize {
		return nil, fmt.Errorf("response body exceeds %d bytes", h.maxResponseSize)
	}
	return body, nil
}
// setProxyHeaders prepares the headers of the outgoing upstream request.
//
// Only a small whitelist of client headers is copied, so cookies,
// Authorization and all other client headers are not forwarded:
//   - Content-Type, so forwarded request bodies can be interpreted
//   - Accept and Accept-Language, for content negotiation
//   - Cache-Control
//
// The proxy then sets its own Accept-Encoding (gzip only, the one encoding
// readResponseBody decodes) and its configured User-Agent.
func (h *ProxyHandler) setProxyHeaders(proxyReq, originalReq *http.Request) {
	headersToForward := []string{
		"Accept",
		"Accept-Language",
		"Cache-Control",
		// Without it the upstream cannot parse forwarded POST/PUT/PATCH
		// bodies such as form submissions or JSON.
		"Content-Type",
	}
	for _, header := range headersToForward {
		if value := originalReq.Header.Get(header); value != "" {
			proxyReq.Header.Set(header, value)
		}
	}
	proxyReq.Header.Set("Accept-Encoding", "gzip")
	proxyReq.Header.Set("User-Agent", h.userAgent)

	// Defensive: never reveal the client or the page it came from upstream,
	// even if proxyReq already carried such headers.
	for _, header := range []string{"X-Forwarded-For", "X-Real-IP", "Via", "Referer"} {
		proxyReq.Header.Del(header)
	}
}
// shouldCache reports whether an upstream response may be stored in the
// shared cache. All of the following must hold:
//   - it answers a GET request with status 200
//   - its Cache-Control forbids neither storing nor sharing
//   - it sets no cookies
//   - its Content-Type is HTML, CSS, JavaScript, an image or a font
//
// Header values are matched case-insensitively.
func (h *ProxyHandler) shouldCache(resp *http.Response) bool {
	if resp.Request == nil || resp.Request.Method != "GET" || resp.StatusCode != http.StatusOK {
		return false
	}

	// Cache-Control directive names are case-insensitive (RFC 9111).
	cacheControl := strings.ToLower(resp.Header.Get("Cache-Control"))
	for _, directive := range []string{"no-store", "no-cache", "private"} {
		if strings.Contains(cacheControl, directive) {
			return false
		}
	}

	// The cache is shared by every client; a response that sets cookies is
	// specific to one client and must never be replayed to another.
	if len(resp.Header.Values("Set-Cookie")) > 0 {
		return false
	}

	contentType := strings.ToLower(resp.Header.Get("Content-Type"))
	// text/javascript is the standard JavaScript type (RFC 9239);
	// application/javascript is kept for older servers.
	cacheableTypes := []string{"text/html", "text/css", "application/javascript", "text/javascript", "image/", "font/"}
	for _, ct := range cacheableTypes {
		if strings.Contains(contentType, ct) {
			return true
		}
	}
	return false
}
// sendResponse writes a freshly fetched (cache-miss) upstream response to the
// client. It sends:
//   - a whitelisted subset of the upstream headers
//   - the proxy's marker and security headers
//   - the upstream status code
//   - body, which may have been rewritten
func (h *ProxyHandler) sendResponse(w http.ResponseWriter, resp *http.Response, body []byte) {
	out := w.Header()

	for _, name := range [...]string{"Content-Type", "Content-Language", "Last-Modified", "ETag", "Expires"} {
		value := resp.Header.Get(name)
		if value == "" {
			continue
		}
		out.Set(name, value)
	}

	proxyHeaders := [...][2]string{
		{"X-Proxied-By", "SiteProxy"},
		{"X-Cache-Status", "MISS"},
		{"X-Content-Type-Options", "nosniff"},
		{"X-Frame-Options", "SAMEORIGIN"},
		{"Referrer-Policy", "no-referrer"},
	}
	for _, kv := range proxyHeaders {
		out.Set(kv[0], kv[1])
	}

	w.WriteHeader(resp.StatusCode)
	// A write error here only means the client went away; nothing to do.
	_, _ = w.Write(body)
}
// serveCached answers a request from a cache entry with status 200.
//
// ServeHTTP stores every upstream header except Content-Encoding and
// Content-Length in the entry, so replaying them all would expose headers the
// cache-miss path drops. Set-Cookie is one example. Only the headers that
// sendResponse whitelists are therefore replayed. The same security headers
// are also set, so a hit is protected exactly like a miss.
func (h *ProxyHandler) serveCached(w http.ResponseWriter, entry *cache.CacheEntry) {
	out := w.Header()
	for key, value := range entry.Headers {
		// Compare canonically: the stored key is "Etag", not "ETag".
		switch http.CanonicalHeaderKey(key) {
		case "Content-Type", "Content-Language", "Last-Modified", "Etag", "Expires":
			out.Set(key, value)
		}
	}
	out.Set("X-Cache-Status", "HIT")
	out.Set("X-Proxied-By", "SiteProxy")
	out.Set("X-Content-Type-Options", "nosniff")
	out.Set("X-Frame-Options", "SAMEORIGIN")
	out.Set("Referrer-Policy", "no-referrer")
	out.Set("Age", entry.Age())
	w.WriteHeader(http.StatusOK)
	// A write error here only means the client went away; nothing to do.
	_, _ = w.Write(entry.Data)
}
// getClientIP returns the client address used as the rate-limiting key.
//
// The address is taken from the first of these that is non-empty:
//   - X-Real-IP
//   - the left-most entry of X-Forwarded-For
//   - the host part of r.RemoteAddr, or RemoteAddr itself if it has no port
//
// NOTE(review): both headers are client-supplied unless a trusted reverse
// proxy overwrites them, so a client talking to this server directly can pick
// its own key and evade rate limiting. Only trust them behind such a proxy.
func getClientIP(r *http.Request) string {
	if ip := strings.TrimSpace(r.Header.Get("X-Real-IP")); ip != "" {
		return ip
	}
	if xff := r.Header.Get("X-Forwarded-For"); xff != "" {
		// The left-most entry is the originating client.
		first := xff
		if i := strings.Index(xff, ","); i != -1 {
			first = xff[:i]
		}
		if first = strings.TrimSpace(first); first != "" {
			return first
		}
	}
	// SplitHostPort understands bracketed IPv6 literals ("[::1]:8080" -> "::1").
	if host, _, err := net.SplitHostPort(r.RemoteAddr); err == nil {
		return host
	}
	return r.RemoteAddr
}