更新 proxy/handler.go
This commit is contained in:
@@ -17,6 +17,7 @@ type ProxyHandler struct {
|
|||||||
validator *security.RequestValidator
|
validator *security.RequestValidator
|
||||||
rateLimiter *security.RateLimiter
|
rateLimiter *security.RateLimiter
|
||||||
cache *cache.MemoryCache
|
cache *cache.MemoryCache
|
||||||
|
sessionManager *ProxySessionManager
|
||||||
userAgent string
|
userAgent string
|
||||||
maxResponseSize int64
|
maxResponseSize int64
|
||||||
}
|
}
|
||||||
@@ -25,6 +26,7 @@ func NewHandler(
|
|||||||
validator *security.RequestValidator,
|
validator *security.RequestValidator,
|
||||||
rateLimiter *security.RateLimiter,
|
rateLimiter *security.RateLimiter,
|
||||||
cache *cache.MemoryCache,
|
cache *cache.MemoryCache,
|
||||||
|
sessionManager *ProxySessionManager,
|
||||||
userAgent string,
|
userAgent string,
|
||||||
maxResponseSize int64,
|
maxResponseSize int64,
|
||||||
) *ProxyHandler {
|
) *ProxyHandler {
|
||||||
@@ -32,34 +34,50 @@ func NewHandler(
|
|||||||
validator: validator,
|
validator: validator,
|
||||||
rateLimiter: rateLimiter,
|
rateLimiter: rateLimiter,
|
||||||
cache: cache,
|
cache: cache,
|
||||||
|
sessionManager: sessionManager,
|
||||||
userAgent: userAgent,
|
userAgent: userAgent,
|
||||||
maxResponseSize: maxResponseSize,
|
maxResponseSize: maxResponseSize,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *ProxyHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
func (h *ProxyHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||||
// 获取目标 URL
|
path := strings.TrimPrefix(r.URL.Path, "/p/")
|
||||||
targetURL := r.URL.Query().Get("url")
|
parts := strings.SplitN(path, "/", 2)
|
||||||
if targetURL == "" {
|
|
||||||
http.Error(w, "Missing url parameter", http.StatusBadRequest)
|
if len(parts) == 0 || parts[0] == "" {
|
||||||
|
http.Error(w, "Invalid token", http.StatusBadRequest)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// 验证 URL
|
token := parts[0]
|
||||||
|
subPath := ""
|
||||||
|
if len(parts) > 1 {
|
||||||
|
subPath = "/" + parts[1]
|
||||||
|
}
|
||||||
|
|
||||||
|
session := h.sessionManager.Get(token)
|
||||||
|
if session == nil {
|
||||||
|
http.Error(w, "Session expired or invalid", http.StatusUnauthorized)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
targetURL := session.TargetURL + subPath
|
||||||
|
if r.URL.RawQuery != "" {
|
||||||
|
targetURL += "?" + r.URL.RawQuery
|
||||||
|
}
|
||||||
|
|
||||||
if err := h.validator.ValidateURL(targetURL); err != nil {
|
if err := h.validator.ValidateURL(targetURL); err != nil {
|
||||||
log.Printf("URL validation failed: %v", err)
|
log.Printf("URL validation failed: %v", err)
|
||||||
http.Error(w, "Invalid or blocked URL: "+err.Error(), http.StatusForbidden)
|
http.Error(w, "Invalid URL", http.StatusForbidden)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// 速率限制
|
|
||||||
clientIP := getClientIP(r)
|
clientIP := getClientIP(r)
|
||||||
if !h.rateLimiter.Allow(clientIP) {
|
if !h.rateLimiter.Allow(clientIP) {
|
||||||
http.Error(w, "Rate limit exceeded", http.StatusTooManyRequests)
|
http.Error(w, "Rate limit exceeded", http.StatusTooManyRequests)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// 检查缓存
|
|
||||||
if entry := h.cache.Get(targetURL); entry != nil {
|
if entry := h.cache.Get(targetURL); entry != nil {
|
||||||
log.Printf("Cache HIT: %s", targetURL)
|
log.Printf("Cache HIT: %s", targetURL)
|
||||||
h.serveCached(w, entry)
|
h.serveCached(w, entry)
|
||||||
@@ -68,17 +86,14 @@ func (h *ProxyHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
|||||||
|
|
||||||
log.Printf("Cache MISS: %s", targetURL)
|
log.Printf("Cache MISS: %s", targetURL)
|
||||||
|
|
||||||
// 创建代理请求
|
|
||||||
proxyReq, err := http.NewRequest(r.Method, targetURL, r.Body)
|
proxyReq, err := http.NewRequest(r.Method, targetURL, r.Body)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
http.Error(w, "Failed to create request", http.StatusInternalServerError)
|
http.Error(w, "Failed to create request", http.StatusInternalServerError)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// 设置请求头
|
|
||||||
h.setProxyHeaders(proxyReq, r)
|
h.setProxyHeaders(proxyReq, r)
|
||||||
|
|
||||||
// 发送请求
|
|
||||||
client := &http.Client{
|
client := &http.Client{
|
||||||
Timeout: 30 * time.Second,
|
Timeout: 30 * time.Second,
|
||||||
CheckRedirect: func(req *http.Request, via []*http.Request) error {
|
CheckRedirect: func(req *http.Request, via []*http.Request) error {
|
||||||
@@ -97,7 +112,6 @@ func (h *ProxyHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
|||||||
}
|
}
|
||||||
defer resp.Body.Close()
|
defer resp.Body.Close()
|
||||||
|
|
||||||
// 读取响应体(会自动解压 gzip)
|
|
||||||
body, err := h.readResponseBody(resp)
|
body, err := h.readResponseBody(resp)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Printf("Failed to read response: %v", err)
|
log.Printf("Failed to read response: %v", err)
|
||||||
@@ -105,15 +119,12 @@ func (h *ProxyHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// 重写内容
|
|
||||||
contentType := resp.Header.Get("Content-Type")
|
contentType := resp.Header.Get("Content-Type")
|
||||||
body = h.rewriteContent(body, targetURL, contentType)
|
body = h.rewriteContent(body, targetURL, contentType, token)
|
||||||
|
|
||||||
// 缓存响应 - 转换 headers 并删除 Content-Encoding
|
|
||||||
if h.shouldCache(resp) {
|
if h.shouldCache(resp) {
|
||||||
headers := make(map[string]string)
|
headers := make(map[string]string)
|
||||||
for key, values := range resp.Header {
|
for key, values := range resp.Header {
|
||||||
// 跳过 Content-Encoding,因为我们已经解压了
|
|
||||||
if key == "Content-Encoding" || key == "Content-Length" {
|
if key == "Content-Encoding" || key == "Content-Length" {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
@@ -124,12 +135,11 @@ func (h *ProxyHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
|||||||
h.cache.Set(targetURL, body, headers)
|
h.cache.Set(targetURL, body, headers)
|
||||||
}
|
}
|
||||||
|
|
||||||
// 发送响应
|
|
||||||
h.sendResponse(w, resp, body)
|
h.sendResponse(w, resp, body)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *ProxyHandler) rewriteContent(body []byte, targetURL, contentType string) []byte {
|
func (h *ProxyHandler) rewriteContent(body []byte, targetURL, contentType, token string) []byte {
|
||||||
rewriter, err := NewContentRewriter(targetURL)
|
rewriter, err := NewContentRewriter(targetURL, token)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Printf("Failed to create rewriter: %v", err)
|
log.Printf("Failed to create rewriter: %v", err)
|
||||||
return body
|
return body
|
||||||
@@ -137,7 +147,6 @@ func (h *ProxyHandler) rewriteContent(body []byte, targetURL, contentType string
|
|||||||
|
|
||||||
contentType = strings.ToLower(contentType)
|
contentType = strings.ToLower(contentType)
|
||||||
|
|
||||||
// HTML 内容
|
|
||||||
if strings.Contains(contentType, "text/html") {
|
if strings.Contains(contentType, "text/html") {
|
||||||
rewritten, err := rewriter.RewriteHTML(body)
|
rewritten, err := rewriter.RewriteHTML(body)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -147,7 +156,6 @@ func (h *ProxyHandler) rewriteContent(body []byte, targetURL, contentType string
|
|||||||
return rewritten
|
return rewritten
|
||||||
}
|
}
|
||||||
|
|
||||||
// CSS 内容
|
|
||||||
if strings.Contains(contentType, "text/css") {
|
if strings.Contains(contentType, "text/css") {
|
||||||
return rewriter.RewriteCSS(body)
|
return rewriter.RewriteCSS(body)
|
||||||
}
|
}
|
||||||
@@ -158,7 +166,6 @@ func (h *ProxyHandler) rewriteContent(body []byte, targetURL, contentType string
|
|||||||
func (h *ProxyHandler) readResponseBody(resp *http.Response) ([]byte, error) {
|
func (h *ProxyHandler) readResponseBody(resp *http.Response) ([]byte, error) {
|
||||||
var reader io.Reader = resp.Body
|
var reader io.Reader = resp.Body
|
||||||
|
|
||||||
// 处理各种压缩格式
|
|
||||||
encoding := strings.ToLower(resp.Header.Get("Content-Encoding"))
|
encoding := strings.ToLower(resp.Header.Get("Content-Encoding"))
|
||||||
if strings.Contains(encoding, "gzip") {
|
if strings.Contains(encoding, "gzip") {
|
||||||
gzReader, err := gzip.NewReader(resp.Body)
|
gzReader, err := gzip.NewReader(resp.Body)
|
||||||
@@ -169,20 +176,17 @@ func (h *ProxyHandler) readResponseBody(resp *http.Response) ([]byte, error) {
|
|||||||
reader = gzReader
|
reader = gzReader
|
||||||
}
|
}
|
||||||
|
|
||||||
// 限制读取大小
|
|
||||||
limitReader := io.LimitReader(reader, h.maxResponseSize)
|
limitReader := io.LimitReader(reader, h.maxResponseSize)
|
||||||
|
|
||||||
return io.ReadAll(limitReader)
|
return io.ReadAll(limitReader)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *ProxyHandler) setProxyHeaders(proxyReq, originalReq *http.Request) {
|
func (h *ProxyHandler) setProxyHeaders(proxyReq, originalReq *http.Request) {
|
||||||
// 复制必要的请求头
|
|
||||||
headersToForward := []string{
|
headersToForward := []string{
|
||||||
"Accept",
|
"Accept",
|
||||||
"Accept-Language",
|
"Accept-Language",
|
||||||
"Accept-Encoding",
|
"Accept-Encoding",
|
||||||
"Cache-Control",
|
"Cache-Control",
|
||||||
"Referer",
|
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, header := range headersToForward {
|
for _, header := range headersToForward {
|
||||||
@@ -191,26 +195,18 @@ func (h *ProxyHandler) setProxyHeaders(proxyReq, originalReq *http.Request) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// 设置自定义 User-Agent
|
|
||||||
proxyReq.Header.Set("User-Agent", h.userAgent)
|
proxyReq.Header.Set("User-Agent", h.userAgent)
|
||||||
|
|
||||||
// 移除可能暴露代理的头
|
|
||||||
proxyReq.Header.Del("X-Forwarded-For")
|
proxyReq.Header.Del("X-Forwarded-For")
|
||||||
proxyReq.Header.Del("X-Real-IP")
|
proxyReq.Header.Del("X-Real-IP")
|
||||||
proxyReq.Header.Del("Via")
|
proxyReq.Header.Del("Via")
|
||||||
|
proxyReq.Header.Del("Referer")
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *ProxyHandler) shouldCache(resp *http.Response) bool {
|
func (h *ProxyHandler) shouldCache(resp *http.Response) bool {
|
||||||
// 只缓存成功的 GET 请求
|
if resp.Request.Method != "GET" || resp.StatusCode != http.StatusOK {
|
||||||
if resp.Request.Method != "GET" {
|
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
if resp.StatusCode != http.StatusOK {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
// 检查 Cache-Control
|
|
||||||
cacheControl := resp.Header.Get("Cache-Control")
|
cacheControl := resp.Header.Get("Cache-Control")
|
||||||
if strings.Contains(cacheControl, "no-store") ||
|
if strings.Contains(cacheControl, "no-store") ||
|
||||||
strings.Contains(cacheControl, "no-cache") ||
|
strings.Contains(cacheControl, "no-cache") ||
|
||||||
@@ -218,15 +214,8 @@ func (h *ProxyHandler) shouldCache(resp *http.Response) bool {
|
|||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
// 检查内容类型
|
|
||||||
contentType := resp.Header.Get("Content-Type")
|
contentType := resp.Header.Get("Content-Type")
|
||||||
cacheableTypes := []string{
|
cacheableTypes := []string{"text/html", "text/css", "application/javascript", "image/", "font/"}
|
||||||
"text/html",
|
|
||||||
"text/css",
|
|
||||||
"application/javascript",
|
|
||||||
"image/",
|
|
||||||
"font/",
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, ct := range cacheableTypes {
|
for _, ct := range cacheableTypes {
|
||||||
if strings.Contains(contentType, ct) {
|
if strings.Contains(contentType, ct) {
|
||||||
@@ -238,14 +227,7 @@ func (h *ProxyHandler) shouldCache(resp *http.Response) bool {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (h *ProxyHandler) sendResponse(w http.ResponseWriter, resp *http.Response, body []byte) {
|
func (h *ProxyHandler) sendResponse(w http.ResponseWriter, resp *http.Response, body []byte) {
|
||||||
// 复制响应头
|
headersToForward := []string{"Content-Type", "Content-Language", "Last-Modified", "ETag", "Expires"}
|
||||||
headersToForward := []string{
|
|
||||||
"Content-Type",
|
|
||||||
"Content-Language",
|
|
||||||
"Last-Modified",
|
|
||||||
"ETag",
|
|
||||||
"Expires",
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, header := range headersToForward {
|
for _, header := range headersToForward {
|
||||||
if value := resp.Header.Get(header); value != "" {
|
if value := resp.Header.Get(header); value != "" {
|
||||||
@@ -253,25 +235,18 @@ func (h *ProxyHandler) sendResponse(w http.ResponseWriter, resp *http.Response,
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// 添加自定义头
|
|
||||||
w.Header().Set("X-Proxied-By", "SiteProxy")
|
w.Header().Set("X-Proxied-By", "SiteProxy")
|
||||||
w.Header().Set("X-Cache-Status", "MISS")
|
w.Header().Set("X-Cache-Status", "MISS")
|
||||||
|
|
||||||
// 安全头
|
|
||||||
w.Header().Set("X-Content-Type-Options", "nosniff")
|
w.Header().Set("X-Content-Type-Options", "nosniff")
|
||||||
w.Header().Set("X-Frame-Options", "SAMEORIGIN")
|
w.Header().Set("X-Frame-Options", "SAMEORIGIN")
|
||||||
w.Header().Set("Referrer-Policy", "no-referrer")
|
w.Header().Set("Referrer-Policy", "no-referrer")
|
||||||
|
|
||||||
// 不设置 Content-Encoding 和 Content-Length,让 Go 自动处理
|
|
||||||
|
|
||||||
w.WriteHeader(resp.StatusCode)
|
w.WriteHeader(resp.StatusCode)
|
||||||
w.Write(body)
|
w.Write(body)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
func (h *ProxyHandler) serveCached(w http.ResponseWriter, entry *cache.CacheEntry) {
|
func (h *ProxyHandler) serveCached(w http.ResponseWriter, entry *cache.CacheEntry) {
|
||||||
for key, value := range entry.Headers {
|
for key, value := range entry.Headers {
|
||||||
// 跳过这些头,让 Go 自动处理
|
|
||||||
if key == "Content-Encoding" || key == "Content-Length" {
|
if key == "Content-Encoding" || key == "Content-Length" {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
@@ -287,20 +262,17 @@ func (h *ProxyHandler) serveCached(w http.ResponseWriter, entry *cache.CacheEntr
|
|||||||
}
|
}
|
||||||
|
|
||||||
func getClientIP(r *http.Request) string {
|
func getClientIP(r *http.Request) string {
|
||||||
// 尝试从各种头中获取真实 IP
|
|
||||||
if ip := r.Header.Get("X-Real-IP"); ip != "" {
|
if ip := r.Header.Get("X-Real-IP"); ip != "" {
|
||||||
return ip
|
return ip
|
||||||
}
|
}
|
||||||
|
|
||||||
if ip := r.Header.Get("X-Forwarded-For"); ip != "" {
|
if ip := r.Header.Get("X-Forwarded-For"); ip != "" {
|
||||||
// X-Forwarded-For 可能包含多个 IP
|
|
||||||
ips := strings.Split(ip, ",")
|
ips := strings.Split(ip, ",")
|
||||||
if len(ips) > 0 {
|
if len(ips) > 0 {
|
||||||
return strings.TrimSpace(ips[0])
|
return strings.TrimSpace(ips[0])
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// 使用远程地址
|
|
||||||
ip := r.RemoteAddr
|
ip := r.RemoteAddr
|
||||||
if idx := strings.LastIndex(ip, ":"); idx != -1 {
|
if idx := strings.LastIndex(ip, ":"); idx != -1 {
|
||||||
ip = ip[:idx]
|
ip = ip[:idx]
|
||||||
|
|||||||
Reference in New Issue
Block a user