diff --git a/config/config.go b/config/config.go index 4394d3b..b353245 100644 --- a/config/config.go +++ b/config/config.go @@ -1,97 +1,302 @@ // config/config.go -package config +package proxy import ( + "bytes" + "compress/gzip" + "io" "log" - "os" - "strconv" + "net/http" + "net/url" "strings" "time" + + "siteproxy/cache" + "siteproxy/security" ) -type Config struct { - // 认证 - Username string - Password string - SessionSecret string - SessionTimeout time.Duration - - // 安全 - RateLimit int - RateLimitWindow time.Duration - MaxResponseSize int64 - - // 代理 - AllowedSchemes []string - BlockedDomains map[string]bool - BlockedCIDRs []string - UserAgent string - - // 缓存 - CacheEnabled bool - CacheMaxSize int64 - CacheTTL time.Duration +type ProxyHandler struct { + validator *security.RequestValidator + rateLimiter *security.RateLimiter + cache *cache.MemoryCache + userAgent string + maxResponseSize int64 } -func LoadFromEnv() *Config { - cfg := &Config{ - Username: getEnv("AUTH_USERNAME", "admin"), - Password: getEnv("AUTH_PASSWORD", "changeme"), - SessionSecret: getEnv("SESSION_SECRET", generateRandomString(64)), - SessionTimeout: parseDuration(getEnv("SESSION_TIMEOUT", "30m")), - - RateLimit: parseInt(getEnv("RATE_LIMIT_REQUESTS", "100")), - RateLimitWindow: parseDuration(getEnv("RATE_LIMIT_WINDOW", "1m")), - MaxResponseSize: parseInt64(getEnv("MAX_RESPONSE_SIZE", "52428800")), // 50MB - - AllowedSchemes: strings.Split(getEnv("ALLOWED_SCHEMES", "http,https"), ","), - BlockedDomains: parseBlockedDomains(getEnv("BLOCKED_DOMAINS", "localhost,127.0.0.1,0.0.0.0,internal,*.local")), - BlockedCIDRs: strings.Split(getEnv("BLOCKED_CIDRS", "10.0.0.0/8,172.16.0.0/12,192.168.0.0/16,169.254.0.0/16,::1/128,fc00::/7,fe80::/10"), ","), - UserAgent: getEnv("USER_AGENT", "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36"), - - CacheEnabled: getEnv("CACHE_ENABLED", "true") == "true", - CacheMaxSize: parseInt64(getEnv("CACHE_MAX_SIZE", "104857600")), // 100MB - CacheTTL: parseDuration(getEnv("CACHE_TTL", "1h")), +func NewHandler( + validator *security.RequestValidator, + rateLimiter *security.RateLimiter, + cache *cache.MemoryCache, + userAgent string, + maxResponseSize int64, +) *ProxyHandler { + return &ProxyHandler{ + validator: validator, + rateLimiter: rateLimiter, + cache: cache, + userAgent: userAgent, + maxResponseSize: maxResponseSize, + } +} + +func (h *ProxyHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + // 获取目标 URL + targetURL := r.URL.Query().Get("url") + if targetURL == "" { + http.Error(w, "Missing url parameter", http.StatusBadRequest) + return } - if cfg.Password == "changeme" { - log.Fatal("Please set AUTH_PASSWORD in .env file") + // 验证 URL + if err := h.validator.ValidateURL(targetURL); err != nil { + log.Printf("URL validation failed: %v", err) + http.Error(w, "Invalid or blocked URL: "+err.Error(), http.StatusForbidden) + return } - return cfg -} - -func getEnv(key, defaultValue string) string { - if value := os.Getenv(key); value != "" { - return value + // 速率限制 + clientIP := getClientIP(r) + if !h.rateLimiter.Allow(clientIP) { + http.Error(w, "Rate limit exceeded", http.StatusTooManyRequests) + return } - return defaultValue -} - -func parseInt(s string) int { - v, _ := strconv.Atoi(s) - return v -} - -func parseInt64(s string) int64 { - v, _ := strconv.ParseInt(s, 10, 64) - return v -} - -func parseDuration(s string) time.Duration { - d, _ := time.ParseDuration(s) - return d -} - -func parseBlockedDomains(s string) map[string]bool { - domains := make(map[string]bool) - for _, d := range strings.Split(s, ",") { - domains[strings.TrimSpace(d)] = true + + // 检查缓存 + if entry := h.cache.Get(targetURL); entry != nil { + log.Printf("Cache HIT: %s", targetURL) + h.serveCached(w, entry) + return } - return domains + + log.Printf("Cache MISS: %s", targetURL) + + // 创建代理请求 + proxyReq, err := http.NewRequest(r.Method, targetURL, r.Body) + if err != nil { + http.Error(w, "Failed to create request", http.StatusInternalServerError) + return + } + + // 设置请求头 + h.setProxyHeaders(proxyReq, r) + + // 发送请求 + client := &http.Client{ + Timeout: 30 * time.Second, + CheckRedirect: func(req *http.Request, via []*http.Request) error { + // 验证重定向 URL + if err := h.validator.ValidateURL(req.URL.String()); err != nil { + return err + } + return nil + }, + } + + resp, err := client.Do(proxyReq) + if err != nil { + log.Printf("Request failed: %v", err) + http.Error(w, "Failed to fetch URL", http.StatusBadGateway) + return + } + defer resp.Body.Close() + + // 读取响应体 + body, err := h.readResponseBody(resp) + if err != nil { + log.Printf("Failed to read response: %v", err) + http.Error(w, "Failed to read response", http.StatusInternalServerError) + return + } + + // 重写内容 + contentType := resp.Header.Get("Content-Type") + body = h.rewriteContent(body, targetURL, contentType) + + // 缓存响应 + if h.shouldCache(resp) { + h.cache.Set(targetURL, body, resp.Header) + } + + // 发送响应 + h.sendResponse(w, resp, body) } -func generateRandomString(n int) string { - // 简单实现,生产环境应使用 crypto/rand - return "change_this_to_random_string_in_production" +func (h *ProxyHandler) rewriteContent(body []byte, targetURL, contentType string) []byte { + rewriter, err := NewContentRewriter(targetURL) + if err != nil { + log.Printf("Failed to create rewriter: %v", err) + return body + } + + contentType = strings.ToLower(contentType) + + // HTML 内容 + if strings.Contains(contentType, "text/html") { + rewritten, err := rewriter.RewriteHTML(body) + if err != nil { + log.Printf("HTML rewrite failed: %v", err) + return body + } + return rewritten + } + + // CSS 内容 + if strings.Contains(contentType, "text/css") { + return rewriter.RewriteCSS(body) + } + + // JavaScript 内容 - 暂时不重写,可能会破坏功能 + // 未来可以添加更智能的 JS 重写 + + return body +} + +func (h *ProxyHandler) readResponseBody(resp *http.Response) ([]byte, error) { + var reader io.Reader = resp.Body + + // 处理 gzip 压缩 + if resp.Header.Get("Content-Encoding") == "gzip" { + gzReader, err := gzip.NewReader(resp.Body) + if err != nil { + return nil, err + } + defer gzReader.Close() + reader = gzReader + } + + // 限制读取大小 + limitReader := io.LimitReader(reader, h.maxResponseSize) + + return io.ReadAll(limitReader) +} + +func (h *ProxyHandler) setProxyHeaders(proxyReq, originalReq *http.Request) { + // 复制必要的请求头 + headersToForward := []string{ + "Accept", + "Accept-Language", + "Accept-Encoding", + "Cache-Control", + "Referer", + } + + for _, header := range headersToForward { + if value := originalReq.Header.Get(header); value != "" { + proxyReq.Header.Set(header, value) + } + } + + // 设置自定义 User-Agent + proxyReq.Header.Set("User-Agent", h.userAgent) + + // 移除可能暴露代理的头 + proxyReq.Header.Del("X-Forwarded-For") + proxyReq.Header.Del("X-Real-IP") + proxyReq.Header.Del("Via") +} + +func (h *ProxyHandler) shouldCache(resp *http.Response) bool { + // 只缓存成功的 GET 请求 + if resp.Request.Method != "GET" { + return false + } + + if resp.StatusCode != http.StatusOK { + return false + } + + // 检查 Cache-Control + cacheControl := resp.Header.Get("Cache-Control") + if strings.Contains(cacheControl, "no-store") || + strings.Contains(cacheControl, "no-cache") || + strings.Contains(cacheControl, "private") { + return false + } + + // 检查内容类型 + contentType := resp.Header.Get("Content-Type") + cacheableTypes := []string{ + "text/html", + "text/css", + "application/javascript", + "image/", + "font/", + } + + for _, ct := range cacheableTypes { + if strings.Contains(contentType, ct) { + return true + } + } + + return false +} + +func (h *ProxyHandler) sendResponse(w http.ResponseWriter, resp *http.Response, body []byte) { + // 复制响应头 + headersToForward := []string{ + "Content-Type", + "Content-Language", + "Last-Modified", + "ETag", + "Expires", + } + + for _, header := range headersToForward { + if value := resp.Header.Get(header); value != "" { + w.Header().Set(header, value) + } + } + + // 添加自定义头 + w.Header().Set("X-Proxied-By", "SiteProxy") + w.Header().Set("X-Cache-Status", "MISS") + + // 移除不需要的头 + w.Header().Del("Content-Encoding") + w.Header().Del("Content-Length") + + // 安全头 + w.Header().Set("X-Content-Type-Options", "nosniff") + w.Header().Set("X-Frame-Options", "SAMEORIGIN") + w.Header().Set("Referrer-Policy", "no-referrer") + + w.WriteHeader(resp.StatusCode) + w.Write(body) +} + +func (h *ProxyHandler) serveCached(w http.ResponseWriter, entry *cache.CacheEntry) { + for key, value := range entry.Headers { + w.Header().Set(key, value) + } + + w.Header().Set("X-Cache-Status", "HIT") + w.Header().Set("X-Proxied-By", "SiteProxy") + w.Header().Set("Age", entry.Age()) + + w.WriteHeader(http.StatusOK) + w.Write(entry.Data) +} + +func getClientIP(r *http.Request) string { + // 尝试从各种头中获取真实 IP + if ip := r.Header.Get("X-Real-IP"); ip != "" { + return ip + } + + if ip := r.Header.Get("X-Forwarded-For"); ip != "" { + // X-Forwarded-For 可能包含多个 IP + ips := strings.Split(ip, ",") + if len(ips) > 0 { + return strings.TrimSpace(ips[0]) + } + } + + // 使用远程地址 + ip := r.RemoteAddr + if idx := strings.LastIndex(ip, ":"); idx != -1 { + ip = ip[:idx] + } + + return ip }