From 3d98525d8291475844abb92c1aa66a6e88446d76 Mon Sep 17 00:00:00 2001 From: XOF Date: Mon, 15 Dec 2025 02:19:59 +0800 Subject: [PATCH] =?UTF-8?q?=E6=9B=B4=E6=96=B0=20config/config.go?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- config/config.go | 428 ++++++++++++++++++----------------------------- 1 file changed, 164 insertions(+), 264 deletions(-) diff --git a/config/config.go b/config/config.go index b353245..28a3f0e 100644 --- a/config/config.go +++ b/config/config.go @@ -1,302 +1,202 @@ // config/config.go -package proxy +package config import ( - "bytes" - "compress/gzip" - "io" "log" - "net/http" - "net/url" + "os" + "strconv" "strings" "time" - - "siteproxy/cache" - "siteproxy/security" ) -type ProxyHandler struct { - validator *security.RequestValidator - rateLimiter *security.RateLimiter - cache *cache.MemoryCache - userAgent string - maxResponseSize int64 +type Config struct { + // 认证配置 + Username string + Password string + SessionSecret string + SessionTimeout time.Duration + + // 安全配置 + RateLimit int + RateLimitWindow time.Duration + MaxResponseSize int64 + + // 代理配置 + AllowedSchemes []string + UserAgent string + + // 黑名单配置 + BlockedDomains []string + BlockedCIDRs []string + + // 缓存配置 + CacheEnabled bool + CacheMaxSize int64 + CacheTTL time.Duration + + // 服务器配置 + Port string } -func NewHandler( - validator *security.RequestValidator, - rateLimiter *security.RateLimiter, - cache *cache.MemoryCache, - userAgent string, - maxResponseSize int64, -) *ProxyHandler { - return &ProxyHandler{ - validator: validator, - rateLimiter: rateLimiter, - cache: cache, - userAgent: userAgent, - maxResponseSize: maxResponseSize, +func LoadFromEnv() *Config { + cfg := &Config{ + Username: getEnv("AUTH_USERNAME", "admin"), + Password: getEnvRequired("AUTH_PASSWORD"), + SessionSecret: getEnvOrGenerate("SESSION_SECRET"), + SessionTimeout: parseDuration(getEnv("SESSION_TIMEOUT", "30m")), + RateLimit: parseInt(getEnv("RATE_LIMIT_REQUESTS", "100")), + RateLimitWindow: parseDuration(getEnv("RATE_LIMIT_WINDOW", "1m")), + MaxResponseSize: parseInt64(getEnv("MAX_RESPONSE_SIZE", "52428800")), + AllowedSchemes: parseList(getEnv("ALLOWED_SCHEMES", "http,https")), + UserAgent: getEnv("USER_AGENT", "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36"), + BlockedDomains: parseList(getEnv("BLOCKED_DOMAINS", getDefaultBlockedDomains())), + BlockedCIDRs: parseList(getEnv("BLOCKED_CIDRS", getDefaultBlockedCIDRs())), + CacheEnabled: parseBool(getEnv("CACHE_ENABLED", "true")), + CacheMaxSize: parseInt64(getEnv("CACHE_MAX_SIZE", "104857600")), + CacheTTL: parseDuration(getEnv("CACHE_TTL", "1h")), + Port: getEnv("PORT", "8080"), + } + + // 验证配置 + cfg.validate() + + return cfg +} + +func (c *Config) validate() { + if c.Password == "" { + log.Fatal("AUTH_PASSWORD is required") + } + + if c.Password == "your_secure_password_here" || + c.Password == "change_this" { + log.Fatal("Please change the default AUTH_PASSWORD") + } + + if len(c.Password) < 8 { + log.Fatal("AUTH_PASSWORD must be at least 8 characters") + } + + if c.SessionSecret == "" { + log.Fatal("SESSION_SECRET is required") + } + + if len(c.SessionSecret) < 32 { + log.Fatal("SESSION_SECRET must be at least 32 characters") + } + + if c.SessionTimeout < time.Minute { + log.Fatal("SESSION_TIMEOUT must be at least 1 minute") + } + + if c.RateLimit < 1 { + log.Fatal("RATE_LIMIT_REQUESTS must be at least 1") + } + + if c.MaxResponseSize < 1024 { + log.Fatal("MAX_RESPONSE_SIZE must be at least 1KB") } } -func (h *ProxyHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { - // 获取目标 URL - targetURL := r.URL.Query().Get("url") - if targetURL == "" { - http.Error(w, "Missing url parameter", http.StatusBadRequest) - return +func getEnv(key, defaultValue string) string { + if value := os.Getenv(key); value != "" { + return value + } + return defaultValue +} + +func getEnvRequired(key string) string { + value := os.Getenv(key) + if value == "" { + log.Fatalf("%s is required", key) + } + return value +} + +func getEnvOrGenerate(key string) string { + if value := os.Getenv(key); value != "" { + return value } - // 验证 URL - if err := h.validator.ValidateURL(targetURL); err != nil { - log.Printf("URL validation failed: %v", err) - http.Error(w, "Invalid or blocked URL: "+err.Error(), http.StatusForbidden) - return - } - - // 速率限制 - clientIP := getClientIP(r) - if !h.rateLimiter.Allow(clientIP) { - http.Error(w, "Rate limit exceeded", http.StatusTooManyRequests) - return - } - - // 检查缓存 - if entry := h.cache.Get(targetURL); entry != nil { - log.Printf("Cache HIT: %s", targetURL) - h.serveCached(w, entry) - return - } - - log.Printf("Cache MISS: %s", targetURL) - - // 创建代理请求 - proxyReq, err := http.NewRequest(r.Method, targetURL, r.Body) + // 生成随机密钥 + secret := GenerateSessionSecret() + log.Printf("WARNING: %s not set, generated random value", key) + log.Printf("Add this to your .env file: %s=%s", key, secret) + return secret +} + +func parseDuration(s string) time.Duration { + d, err := time.ParseDuration(s) if err != nil { - http.Error(w, "Failed to create request", http.StatusInternalServerError) - return + log.Fatalf("Invalid duration: %s", s) } - - // 设置请求头 - h.setProxyHeaders(proxyReq, r) - - // 发送请求 - client := &http.Client{ - Timeout: 30 * time.Second, - CheckRedirect: func(req *http.Request, via []*http.Request) error { - // 验证重定向 URL - if err := h.validator.ValidateURL(req.URL.String()); err != nil { - return err - } - return nil - }, - } - - resp, err := client.Do(proxyReq) + return d +} + +func parseInt(s string) int { + i, err := strconv.Atoi(s) if err != nil { - log.Printf("Request failed: %v", err) - http.Error(w, "Failed to fetch URL", http.StatusBadGateway) - return + log.Fatalf("Invalid integer: %s", s) } - defer resp.Body.Close() - - // 读取响应体 - body, err := h.readResponseBody(resp) + return i +} + +func parseInt64(s string) int64 { + i, err := strconv.ParseInt(s, 10, 64) if err != nil { - log.Printf("Failed to read response: %v", err) - http.Error(w, "Failed to read response", http.StatusInternalServerError) - return + log.Fatalf("Invalid integer: %s", s) } - - // 重写内容 - contentType := resp.Header.Get("Content-Type") - body = h.rewriteContent(body, targetURL, contentType) - - // 缓存响应 - if h.shouldCache(resp) { - h.cache.Set(targetURL, body, resp.Header) - } - - // 发送响应 - h.sendResponse(w, resp, body) + return i } -func (h *ProxyHandler) rewriteContent(body []byte, targetURL, contentType string) []byte { - rewriter, err := NewContentRewriter(targetURL) +func parseBool(s string) bool { + b, err := strconv.ParseBool(s) if err != nil { - log.Printf("Failed to create rewriter: %v", err) - return body + log.Fatalf("Invalid boolean: %s", s) } - - contentType = strings.ToLower(contentType) - - // HTML 内容 - if strings.Contains(contentType, "text/html") { - rewritten, err := rewriter.RewriteHTML(body) - if err != nil { - log.Printf("HTML rewrite failed: %v", err) - return body - } - return rewritten - } - - // CSS 内容 - if strings.Contains(contentType, "text/css") { - return rewriter.RewriteCSS(body) - } - - // JavaScript 内容 - 暂时不重写,可能会破坏功能 - // 未来可以添加更智能的 JS 重写 - - return body + return b } -func (h *ProxyHandler) readResponseBody(resp *http.Response) ([]byte, error) { - var reader io.Reader = resp.Body - - // 处理 gzip 压缩 - if resp.Header.Get("Content-Encoding") == "gzip" { - gzReader, err := gzip.NewReader(resp.Body) - if err != nil { - return nil, err - } - defer gzReader.Close() - reader = gzReader +func parseList(s string) []string { + if s == "" { + return []string{} } - // 限制读取大小 - limitReader := io.LimitReader(reader, h.maxResponseSize) + parts := strings.Split(s, ",") + result := make([]string, 0, len(parts)) - return io.ReadAll(limitReader) -} - -func (h *ProxyHandler) setProxyHeaders(proxyReq, originalReq *http.Request) { - // 复制必要的请求头 - headersToForward := []string{ - "Accept", - "Accept-Language", - "Accept-Encoding", - "Cache-Control", - "Referer", - } - - for _, header := range headersToForward { - if value := originalReq.Header.Get(header); value != "" { - proxyReq.Header.Set(header, value) + for _, part := range parts { + if trimmed := strings.TrimSpace(part); trimmed != "" { + result = append(result, trimmed) } } - // 设置自定义 User-Agent - proxyReq.Header.Set("User-Agent", h.userAgent) - - // 移除可能暴露代理的头 - proxyReq.Header.Del("X-Forwarded-For") - proxyReq.Header.Del("X-Real-IP") - proxyReq.Header.Del("Via") + return result } -func (h *ProxyHandler) shouldCache(resp *http.Response) bool { - // 只缓存成功的 GET 请求 - if resp.Request.Method != "GET" { - return false - } - - if resp.StatusCode != http.StatusOK { - return false - } - - // 检查 Cache-Control - cacheControl := resp.Header.Get("Cache-Control") - if strings.Contains(cacheControl, "no-store") || - strings.Contains(cacheControl, "no-cache") || - strings.Contains(cacheControl, "private") { - return false - } - - // 检查内容类型 - contentType := resp.Header.Get("Content-Type") - cacheableTypes := []string{ - "text/html", - "text/css", - "application/javascript", - "image/", - "font/", - } - - for _, ct := range cacheableTypes { - if strings.Contains(contentType, ct) { - return true - } - } - - return false +func getDefaultBlockedDomains() string { + return strings.Join([]string{ + "localhost", + "127.0.0.1", + "0.0.0.0", + "*.local", + "internal", + "metadata.google.internal", + "169.254.169.254", + "metadata.azure.com", + "metadata.packet.net", + }, ",") } -func (h *ProxyHandler) sendResponse(w http.ResponseWriter, resp *http.Response, body []byte) { - // 复制响应头 - headersToForward := []string{ - "Content-Type", - "Content-Language", - "Last-Modified", - "ETag", - "Expires", - } - - for _, header := range headersToForward { - if value := resp.Header.Get(header); value != "" { - w.Header().Set(header, value) - } - } - - // 添加自定义头 - w.Header().Set("X-Proxied-By", "SiteProxy") - w.Header().Set("X-Cache-Status", "MISS") - - // 移除不需要的头 - w.Header().Del("Content-Encoding") - w.Header().Del("Content-Length") - - // 安全头 - w.Header().Set("X-Content-Type-Options", "nosniff") - w.Header().Set("X-Frame-Options", "SAMEORIGIN") - w.Header().Set("Referrer-Policy", "no-referrer") - - w.WriteHeader(resp.StatusCode) - w.Write(body) -} - -func (h *ProxyHandler) serveCached(w http.ResponseWriter, entry *cache.CacheEntry) { - for key, value := range entry.Headers { - w.Header().Set(key, value) - } - - w.Header().Set("X-Cache-Status", "HIT") - w.Header().Set("X-Proxied-By", "SiteProxy") - w.Header().Set("Age", entry.Age()) - - w.WriteHeader(http.StatusOK) - w.Write(entry.Data) -} - -func getClientIP(r *http.Request) string { - // 尝试从各种头中获取真实 IP - if ip := r.Header.Get("X-Real-IP"); ip != "" { - return ip - } - - if ip := r.Header.Get("X-Forwarded-For"); ip != "" { - // X-Forwarded-For 可能包含多个 IP - ips := strings.Split(ip, ",") - if len(ips) > 0 { - return strings.TrimSpace(ips[0]) - } - } - - // 使用远程地址 - ip := r.RemoteAddr - if idx := strings.LastIndex(ip, ":"); idx != -1 { - ip = ip[:idx] - } - - return ip +func getDefaultBlockedCIDRs() string { + return strings.Join([]string{ + "10.0.0.0/8", + "172.16.0.0/12", + "192.168.0.0/16", + "169.254.0.0/16", + "::1/128", + "fc00::/7", + "fe80::/10", + "100.64.0.0/10", + }, ",") }