// proxy/handler.go
package proxy

import (
	"compress/gzip"
	"io"
	"log"
	"net/http"
	"net/url"
	"strings"
	"time"

	"siteproxy/cache"
	"siteproxy/security"
)

// ProxyHandler serves requests under /p/<token>/... by fetching the
// corresponding URL on the session's target site, rewriting the response,
// and caching it.
type ProxyHandler struct {
	validator       *security.RequestValidator
	rateLimiter     *security.RateLimiter
	cache           *cache.MemoryCache
	sessionManager  *ProxySessionManager
	userAgent       string
	maxResponseSize int64
}

// NewHandler wires a ProxyHandler with its validation, rate-limiting,
// caching, and session dependencies.
func NewHandler(
	validator *security.RequestValidator,
	rateLimiter *security.RateLimiter,
	cache *cache.MemoryCache,
	sessionManager *ProxySessionManager,
	userAgent string,
	maxResponseSize int64,
) *ProxyHandler {
	return &ProxyHandler{
		validator:       validator,
		rateLimiter:     rateLimiter,
		cache:           cache,
		sessionManager:  sessionManager,
		userAgent:       userAgent,
		maxResponseSize: maxResponseSize,
	}
}

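// Wiring sketch (illustrative, not part of this handler): one way the proxy
// might be mounted under "/p/". The constructor calls for the collaborators
// are assumptions; this file only requires that they are built elsewhere and
// passed in, and the user agent and size cap below are hypothetical values.
//
//	mux := http.NewServeMux()
//	mux.Handle("/p/", NewHandler(validator, limiter, memCache, sessions,
//		"SiteProxy/1.0", 10<<20)) // 10 MiB response cap
//	log.Fatal(http.ListenAndServe(":8080", mux))
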
// ServeHTTP handles proxied requests of the form /p/<token>/<target-path>.
func (h *ProxyHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
	path := strings.TrimPrefix(r.URL.Path, "/p/")

	// Split the session token from the path to request on the target site.
	slashIdx := strings.Index(path, "/")

	var token, targetPath string

	if slashIdx == -1 {
		token = path
		targetPath = ""
	} else {
		token = path[:slashIdx]
		targetPath = path[slashIdx+1:]
	}

	if token == "" {
		http.Error(w, "Invalid token", http.StatusBadRequest)
		return
	}

	session := h.sessionManager.Get(token)
	if session == nil {
		http.Error(w, "Session expired or invalid", http.StatusUnauthorized)
		return
	}

	var targetURL string

	// An absolute URL embedded in the path may arrive with its "//" collapsed
	// to a single slash ("http:/..."); restore the double slash first.
	if strings.HasPrefix(targetPath, "http:/") && !strings.HasPrefix(targetPath, "http://") {
		targetPath = strings.Replace(targetPath, "http:/", "http://", 1)
	}
	if strings.HasPrefix(targetPath, "https:/") && !strings.HasPrefix(targetPath, "https://") {
		targetPath = strings.Replace(targetPath, "https:/", "https://", 1)
	}

	// An absolute URL is proxied as-is; anything else is resolved against the
	// session's target site.
	if strings.HasPrefix(targetPath, "http://") || strings.HasPrefix(targetPath, "https://") {
		targetURL = targetPath
	} else {
		baseURL := strings.TrimSuffix(session.TargetURL, "/")
		if targetPath == "" {
			targetURL = baseURL
		} else {
			if !strings.HasPrefix(targetPath, "/") {
				targetPath = "/" + targetPath
			}
			targetURL = baseURL + targetPath
		}
	}

	if r.URL.RawQuery != "" {
		if strings.Contains(targetURL, "?") {
			targetURL += "&" + r.URL.RawQuery
		} else {
			targetURL += "?" + r.URL.RawQuery
		}
	}

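	// Illustrative mapping produced above, assuming a session whose TargetURL
	// is "https://example.com" (the token value is hypothetical):
	//
	//	GET /p/abc123/docs/page?x=1                -> https://example.com/docs/page?x=1
	//	GET /p/abc123/https://cdn.example.com/a.js -> https://cdn.example.com/a.js
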
	if err := h.validator.ValidateURL(targetURL); err != nil {
		log.Printf("URL validation failed: %v", err)
		http.Error(w, "Invalid URL", http.StatusForbidden)
		return
	}

	clientIP := getClientIP(r)
	if !h.rateLimiter.Allow(clientIP) {
		http.Error(w, "Rate limit exceeded", http.StatusTooManyRequests)
		return
	}

	cacheKey := h.cache.GenerateKey(targetURL)
	if entry := h.cache.Get(cacheKey); entry != nil {
		log.Printf("Cache HIT: %s", targetURL)
		h.serveCached(w, entry)
		return
	}
	log.Printf("Cache MISS: %s", targetURL)

	proxyReq, err := http.NewRequest(r.Method, targetURL, r.Body)
	if err != nil {
		http.Error(w, "Failed to create request", http.StatusInternalServerError)
		return
	}

	h.setProxyHeaders(proxyReq, r)

	client := &http.Client{
		Timeout: 30 * time.Second,
		// Re-validate every redirect hop so an allowed URL cannot bounce the
		// proxy to a blocked one.
		CheckRedirect: func(req *http.Request, via []*http.Request) error {
			if err := h.validator.ValidateURL(req.URL.String()); err != nil {
				return err
			}
			return nil
		},
	}

	resp, err := client.Do(proxyReq)
	if err != nil {
		log.Printf("Request failed: %v", err)
		http.Error(w, "Failed to fetch URL", http.StatusBadGateway)
		return
	}
	defer resp.Body.Close()

	// Handle 404s: when the request came from a page served by this proxy,
	// retry using the host and path parsed from the original target URL.
	if resp.StatusCode == 404 {
		if referer := r.Header.Get("Referer"); referer != "" && strings.Contains(referer, "/p/"+token) {
			if u, err := url.Parse(targetURL); err == nil && u.Host != "" {
				// Rebuild the URL from the original targetURL's host + path.
				baseURL := u.Scheme + "://" + u.Host
				correctPath := u.Path
				if u.RawQuery != "" {
					correctPath += "?" + u.RawQuery
				}

				newTargetURL := baseURL + correctPath

				// Avoid re-requesting the same URL.
				if newTargetURL != targetURL {
					retryReq, err := http.NewRequest(r.Method, newTargetURL, nil)
					if err == nil {
						h.setProxyHeaders(retryReq, r)

						if retryResp, err := client.Do(retryReq); err == nil {
							if retryResp.StatusCode == 200 {
								resp.Body.Close()
								resp = retryResp
								defer resp.Body.Close()
								log.Printf("404 retry success: %s -> %s", targetURL, newTargetURL)
							} else {
								// Do not leak the retry response when it is unusable.
								retryResp.Body.Close()
							}
						}
					}
				}
			}
		}
	}

	body, err := h.readResponseBody(resp)
	if err != nil {
		log.Printf("Failed to read response: %v", err)
		http.Error(w, "Failed to read response", http.StatusInternalServerError)
		return
	}

	contentType := resp.Header.Get("Content-Type")
	body = h.rewriteContent(body, targetURL, contentType, token)

	if h.shouldCache(resp) {
		headers := make(map[string]string)
		for key, values := range resp.Header {
			// The cached body is already decoded and rewritten, so the
			// upstream encoding and length no longer apply.
			if key == "Content-Encoding" || key == "Content-Length" {
				continue
			}
			if len(values) > 0 {
				headers[key] = values[0]
			}
		}
		h.cache.Set(cacheKey, body, headers)
	}

	h.sendResponse(w, resp, body)
}

// rewriteContent rewrites HTML and CSS so embedded links keep resolving
// through the proxy; other content types are passed through unchanged.
func (h *ProxyHandler) rewriteContent(body []byte, targetURL, contentType, token string) []byte {
	rewriter, err := NewContentRewriter(targetURL, token)
	if err != nil {
		log.Printf("Failed to create rewriter: %v", err)
		return body
	}

	contentType = strings.ToLower(contentType)

	if strings.Contains(contentType, "text/html") {
		rewritten, err := rewriter.RewriteHTML(body)
		if err != nil {
			log.Printf("HTML rewrite failed: %v", err)
			return body
		}
		return rewritten
	}

	if strings.Contains(contentType, "text/css") {
		return rewriter.RewriteCSS(body)
	}

	return body
}

// readResponseBody decompresses a gzip-encoded body if needed and reads at
// most maxResponseSize bytes of it.
func (h *ProxyHandler) readResponseBody(resp *http.Response) ([]byte, error) {
	var reader io.Reader = resp.Body

	if resp.Header.Get("Content-Encoding") == "gzip" {
		gzReader, err := gzip.NewReader(resp.Body)
		if err != nil {
			// The gzip header has already been consumed from resp.Body, so the
			// raw stream cannot be salvaged here.
			return nil, err
		}
		defer gzReader.Close()
		reader = gzReader
	}

	limitReader := io.LimitReader(reader, h.maxResponseSize)
	return io.ReadAll(limitReader)
}

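// Note that io.LimitReader silently truncates anything past maxResponseSize.
// A sketch of how an oversized body could be rejected instead (hypothetical,
// not used by this handler; assumes the standard "errors" package):
//
//	limited := io.LimitReader(reader, h.maxResponseSize+1)
//	data, err := io.ReadAll(limited)
//	if err == nil && int64(len(data)) > h.maxResponseSize {
//		err = errors.New("response exceeds maxResponseSize")
//	}
//	return data, err
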
// setProxyHeaders copies a small allow-list of request headers onto the
// upstream request and strips headers that would identify the client or
// this proxy.
func (h *ProxyHandler) setProxyHeaders(proxyReq, originalReq *http.Request) {
	headersToForward := []string{
		"Accept",
		"Accept-Language",
		"Cache-Control",
	}

	for _, header := range headersToForward {
		if value := originalReq.Header.Get(header); value != "" {
			proxyReq.Header.Set(header, value)
		}
	}

	proxyReq.Header.Set("Accept-Encoding", "gzip")
	proxyReq.Header.Set("User-Agent", h.userAgent)
	proxyReq.Header.Del("X-Forwarded-For")
	proxyReq.Header.Del("X-Real-IP")
	proxyReq.Header.Del("Via")
	proxyReq.Header.Del("Referer")
}

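// Illustrative upstream request produced by the rules above for a page on
// example.com (all header values here are hypothetical):
//
//	GET /docs/page HTTP/1.1
//	Host: example.com
//	Accept: text/html,application/xhtml+xml
//	Accept-Language: en-US
//	Accept-Encoding: gzip
//	User-Agent: <h.userAgent>
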
// shouldCache reports whether a response is safe to cache: successful GETs of
// common static types that the origin has not marked uncacheable.
func (h *ProxyHandler) shouldCache(resp *http.Response) bool {
	if resp.Request.Method != "GET" || resp.StatusCode != http.StatusOK {
		return false
	}

	cacheControl := resp.Header.Get("Cache-Control")
	if strings.Contains(cacheControl, "no-store") ||
		strings.Contains(cacheControl, "no-cache") ||
		strings.Contains(cacheControl, "private") {
		return false
	}

	contentType := resp.Header.Get("Content-Type")
	cacheableTypes := []string{"text/html", "text/css", "application/javascript", "image/", "font/"}

	for _, ct := range cacheableTypes {
		if strings.Contains(contentType, ct) {
			return true
		}
	}

	return false
}

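// Examples of the policy above (illustrative responses):
//
//	200 GET, Content-Type: text/css, no Cache-Control        -> cached
//	200 GET, Content-Type: text/html, Cache-Control: private -> not cached
//	200 GET, Content-Type: application/json                  -> not cached (type not listed)
//	200 POST, any Content-Type                                -> not cached
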
func (h *ProxyHandler) sendResponse(w http.ResponseWriter, resp *http.Response, body []byte) {
	headersToForward := []string{"Content-Type", "Content-Language", "Last-Modified", "ETag", "Expires"}

	for _, header := range headersToForward {
		if value := resp.Header.Get(header); value != "" {
			w.Header().Set(header, value)
		}
	}

	w.Header().Set("X-Proxied-By", "SiteProxy")
	w.Header().Set("X-Cache-Status", "MISS")
	w.Header().Set("X-Content-Type-Options", "nosniff")
	w.Header().Set("X-Frame-Options", "SAMEORIGIN")
	w.Header().Set("Referrer-Policy", "no-referrer")

	w.WriteHeader(resp.StatusCode)
	w.Write(body)
}

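// Illustrative response headers the browser sees on a cache miss (values
// other than the fixed ones set above are hypothetical):
//
//	Content-Type: text/html; charset=utf-8
//	X-Proxied-By: SiteProxy
//	X-Cache-Status: MISS
//	X-Content-Type-Options: nosniff
//	X-Frame-Options: SAMEORIGIN
//	Referrer-Policy: no-referrer
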
func (h *ProxyHandler) serveCached(w http.ResponseWriter, entry *cache.CacheEntry) {
	for key, value := range entry.Headers {
		if key == "Content-Encoding" || key == "Content-Length" {
			continue
		}
		w.Header().Set(key, value)
	}

	w.Header().Set("X-Cache-Status", "HIT")
	w.Header().Set("X-Proxied-By", "SiteProxy")
	w.Header().Set("Age", entry.Age())

	w.WriteHeader(http.StatusOK)
	w.Write(entry.Data)
}

// getClientIP extracts the caller's IP for rate limiting, preferring
// proxy-set headers and falling back to RemoteAddr with its port stripped.
func getClientIP(r *http.Request) string {
	if ip := r.Header.Get("X-Real-IP"); ip != "" {
		return ip
	}

	if ip := r.Header.Get("X-Forwarded-For"); ip != "" {
		ips := strings.Split(ip, ",")
		if len(ips) > 0 {
			return strings.TrimSpace(ips[0])
		}
	}

	ip := r.RemoteAddr
	if idx := strings.LastIndex(ip, ":"); idx != -1 {
		ip = ip[:idx]
	}

	return ip
}
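
// X-Real-IP and X-Forwarded-For are client-controlled unless a trusted
// reverse proxy sets them, so rate limiting keyed on them can be spoofed.
// A stricter sketch that trusts only RemoteAddr (hypothetical alternative;
// assumes the standard "net" package):
//
//	if host, _, err := net.SplitHostPort(r.RemoteAddr); err == nil {
//		return host
//	}
//	return r.RemoteAddr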