更新 config/config.go

This commit is contained in:
XOF
2025-12-15 02:19:59 +08:00
parent 1b086164ff
commit 3d98525d82

View File

@@ -1,302 +1,202 @@
// config/config.go // config/config.go
package proxy package config
import ( import (
"bytes"
"compress/gzip"
"io"
"log" "log"
"net/http" "os"
"net/url" "strconv"
"strings" "strings"
"time" "time"
"siteproxy/cache"
"siteproxy/security"
) )
type ProxyHandler struct { type Config struct {
validator *security.RequestValidator // 认证配置
rateLimiter *security.RateLimiter Username string
cache *cache.MemoryCache Password string
userAgent string SessionSecret string
maxResponseSize int64 SessionTimeout time.Duration
// 安全配置
RateLimit int
RateLimitWindow time.Duration
MaxResponseSize int64
// 代理配置
AllowedSchemes []string
UserAgent string
// 黑名单配置
BlockedDomains []string
BlockedCIDRs []string
// 缓存配置
CacheEnabled bool
CacheMaxSize int64
CacheTTL time.Duration
// 服务器配置
Port string
} }
func NewHandler( func LoadFromEnv() *Config {
validator *security.RequestValidator, cfg := &Config{
rateLimiter *security.RateLimiter, Username: getEnv("AUTH_USERNAME", "admin"),
cache *cache.MemoryCache, Password: getEnvRequired("AUTH_PASSWORD"),
userAgent string, SessionSecret: getEnvOrGenerate("SESSION_SECRET"),
maxResponseSize int64, SessionTimeout: parseDuration(getEnv("SESSION_TIMEOUT", "30m")),
) *ProxyHandler { RateLimit: parseInt(getEnv("RATE_LIMIT_REQUESTS", "100")),
return &ProxyHandler{ RateLimitWindow: parseDuration(getEnv("RATE_LIMIT_WINDOW", "1m")),
validator: validator, MaxResponseSize: parseInt64(getEnv("MAX_RESPONSE_SIZE", "52428800")),
rateLimiter: rateLimiter, AllowedSchemes: parseList(getEnv("ALLOWED_SCHEMES", "http,https")),
cache: cache, UserAgent: getEnv("USER_AGENT", "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36"),
userAgent: userAgent, BlockedDomains: parseList(getEnv("BLOCKED_DOMAINS", getDefaultBlockedDomains())),
maxResponseSize: maxResponseSize, BlockedCIDRs: parseList(getEnv("BLOCKED_CIDRS", getDefaultBlockedCIDRs())),
CacheEnabled: parseBool(getEnv("CACHE_ENABLED", "true")),
CacheMaxSize: parseInt64(getEnv("CACHE_MAX_SIZE", "104857600")),
CacheTTL: parseDuration(getEnv("CACHE_TTL", "1h")),
Port: getEnv("PORT", "8080"),
}
// 验证配置
cfg.validate()
return cfg
}
func (c *Config) validate() {
if c.Password == "" {
log.Fatal("AUTH_PASSWORD is required")
}
if c.Password == "your_secure_password_here" ||
c.Password == "change_this" {
log.Fatal("Please change the default AUTH_PASSWORD")
}
if len(c.Password) < 8 {
log.Fatal("AUTH_PASSWORD must be at least 8 characters")
}
if c.SessionSecret == "" {
log.Fatal("SESSION_SECRET is required")
}
if len(c.SessionSecret) < 32 {
log.Fatal("SESSION_SECRET must be at least 32 characters")
}
if c.SessionTimeout < time.Minute {
log.Fatal("SESSION_TIMEOUT must be at least 1 minute")
}
if c.RateLimit < 1 {
log.Fatal("RATE_LIMIT_REQUESTS must be at least 1")
}
if c.MaxResponseSize < 1024 {
log.Fatal("MAX_RESPONSE_SIZE must be at least 1KB")
} }
} }
func (h *ProxyHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { func getEnv(key, defaultValue string) string {
// 获取目标 URL if value := os.Getenv(key); value != "" {
targetURL := r.URL.Query().Get("url") return value
if targetURL == "" { }
http.Error(w, "Missing url parameter", http.StatusBadRequest) return defaultValue
return }
func getEnvRequired(key string) string {
value := os.Getenv(key)
if value == "" {
log.Fatalf("%s is required", key)
}
return value
}
func getEnvOrGenerate(key string) string {
if value := os.Getenv(key); value != "" {
return value
} }
// 验证 URL // 生成随机密钥
if err := h.validator.ValidateURL(targetURL); err != nil { secret := GenerateSessionSecret()
log.Printf("URL validation failed: %v", err) log.Printf("WARNING: %s not set, generated random value", key)
http.Error(w, "Invalid or blocked URL: "+err.Error(), http.StatusForbidden) log.Printf("Add this to your .env file: %s=%s", key, secret)
return return secret
} }
// 速率限制 func parseDuration(s string) time.Duration {
clientIP := getClientIP(r) d, err := time.ParseDuration(s)
if !h.rateLimiter.Allow(clientIP) {
http.Error(w, "Rate limit exceeded", http.StatusTooManyRequests)
return
}
// 检查缓存
if entry := h.cache.Get(targetURL); entry != nil {
log.Printf("Cache HIT: %s", targetURL)
h.serveCached(w, entry)
return
}
log.Printf("Cache MISS: %s", targetURL)
// 创建代理请求
proxyReq, err := http.NewRequest(r.Method, targetURL, r.Body)
if err != nil { if err != nil {
http.Error(w, "Failed to create request", http.StatusInternalServerError) log.Fatalf("Invalid duration: %s", s)
return
} }
return d
// 设置请求头 }
h.setProxyHeaders(proxyReq, r)
func parseInt(s string) int {
// 发送请求 i, err := strconv.Atoi(s)
client := &http.Client{
Timeout: 30 * time.Second,
CheckRedirect: func(req *http.Request, via []*http.Request) error {
// 验证重定向 URL
if err := h.validator.ValidateURL(req.URL.String()); err != nil {
return err
}
return nil
},
}
resp, err := client.Do(proxyReq)
if err != nil { if err != nil {
log.Printf("Request failed: %v", err) log.Fatalf("Invalid integer: %s", s)
http.Error(w, "Failed to fetch URL", http.StatusBadGateway)
return
} }
defer resp.Body.Close() return i
}
// 读取响应体
body, err := h.readResponseBody(resp) func parseInt64(s string) int64 {
i, err := strconv.ParseInt(s, 10, 64)
if err != nil { if err != nil {
log.Printf("Failed to read response: %v", err) log.Fatalf("Invalid integer: %s", s)
http.Error(w, "Failed to read response", http.StatusInternalServerError)
return
} }
return i
// 重写内容
contentType := resp.Header.Get("Content-Type")
body = h.rewriteContent(body, targetURL, contentType)
// 缓存响应
if h.shouldCache(resp) {
h.cache.Set(targetURL, body, resp.Header)
}
// 发送响应
h.sendResponse(w, resp, body)
} }
func (h *ProxyHandler) rewriteContent(body []byte, targetURL, contentType string) []byte { func parseBool(s string) bool {
rewriter, err := NewContentRewriter(targetURL) b, err := strconv.ParseBool(s)
if err != nil { if err != nil {
log.Printf("Failed to create rewriter: %v", err) log.Fatalf("Invalid boolean: %s", s)
return body
} }
return b
contentType = strings.ToLower(contentType)
// HTML 内容
if strings.Contains(contentType, "text/html") {
rewritten, err := rewriter.RewriteHTML(body)
if err != nil {
log.Printf("HTML rewrite failed: %v", err)
return body
}
return rewritten
}
// CSS 内容
if strings.Contains(contentType, "text/css") {
return rewriter.RewriteCSS(body)
}
// JavaScript 内容 - 暂时不重写,可能会破坏功能
// 未来可以添加更智能的 JS 重写
return body
} }
func (h *ProxyHandler) readResponseBody(resp *http.Response) ([]byte, error) { func parseList(s string) []string {
var reader io.Reader = resp.Body if s == "" {
return []string{}
// 处理 gzip 压缩
if resp.Header.Get("Content-Encoding") == "gzip" {
gzReader, err := gzip.NewReader(resp.Body)
if err != nil {
return nil, err
}
defer gzReader.Close()
reader = gzReader
} }
// 限制读取大小 parts := strings.Split(s, ",")
limitReader := io.LimitReader(reader, h.maxResponseSize) result := make([]string, 0, len(parts))
return io.ReadAll(limitReader) for _, part := range parts {
} if trimmed := strings.TrimSpace(part); trimmed != "" {
result = append(result, trimmed)
func (h *ProxyHandler) setProxyHeaders(proxyReq, originalReq *http.Request) {
// 复制必要的请求头
headersToForward := []string{
"Accept",
"Accept-Language",
"Accept-Encoding",
"Cache-Control",
"Referer",
}
for _, header := range headersToForward {
if value := originalReq.Header.Get(header); value != "" {
proxyReq.Header.Set(header, value)
} }
} }
// 设置自定义 User-Agent return result
proxyReq.Header.Set("User-Agent", h.userAgent)
// 移除可能暴露代理的头
proxyReq.Header.Del("X-Forwarded-For")
proxyReq.Header.Del("X-Real-IP")
proxyReq.Header.Del("Via")
} }
func (h *ProxyHandler) shouldCache(resp *http.Response) bool { func getDefaultBlockedDomains() string {
// 只缓存成功的 GET 请求 return strings.Join([]string{
if resp.Request.Method != "GET" { "localhost",
return false "127.0.0.1",
} "0.0.0.0",
"*.local",
if resp.StatusCode != http.StatusOK { "internal",
return false "metadata.google.internal",
} "169.254.169.254",
"metadata.azure.com",
// 检查 Cache-Control "metadata.packet.net",
cacheControl := resp.Header.Get("Cache-Control") }, ",")
if strings.Contains(cacheControl, "no-store") ||
strings.Contains(cacheControl, "no-cache") ||
strings.Contains(cacheControl, "private") {
return false
}
// 检查内容类型
contentType := resp.Header.Get("Content-Type")
cacheableTypes := []string{
"text/html",
"text/css",
"application/javascript",
"image/",
"font/",
}
for _, ct := range cacheableTypes {
if strings.Contains(contentType, ct) {
return true
}
}
return false
} }
func (h *ProxyHandler) sendResponse(w http.ResponseWriter, resp *http.Response, body []byte) { func getDefaultBlockedCIDRs() string {
// 复制响应头 return strings.Join([]string{
headersToForward := []string{ "10.0.0.0/8",
"Content-Type", "172.16.0.0/12",
"Content-Language", "192.168.0.0/16",
"Last-Modified", "169.254.0.0/16",
"ETag", "::1/128",
"Expires", "fc00::/7",
} "fe80::/10",
"100.64.0.0/10",
for _, header := range headersToForward { }, ",")
if value := resp.Header.Get(header); value != "" {
w.Header().Set(header, value)
}
}
// 添加自定义头
w.Header().Set("X-Proxied-By", "SiteProxy")
w.Header().Set("X-Cache-Status", "MISS")
// 移除不需要的头
w.Header().Del("Content-Encoding")
w.Header().Del("Content-Length")
// 安全头
w.Header().Set("X-Content-Type-Options", "nosniff")
w.Header().Set("X-Frame-Options", "SAMEORIGIN")
w.Header().Set("Referrer-Policy", "no-referrer")
w.WriteHeader(resp.StatusCode)
w.Write(body)
}
func (h *ProxyHandler) serveCached(w http.ResponseWriter, entry *cache.CacheEntry) {
for key, value := range entry.Headers {
w.Header().Set(key, value)
}
w.Header().Set("X-Cache-Status", "HIT")
w.Header().Set("X-Proxied-By", "SiteProxy")
w.Header().Set("Age", entry.Age())
w.WriteHeader(http.StatusOK)
w.Write(entry.Data)
}
func getClientIP(r *http.Request) string {
// 尝试从各种头中获取真实 IP
if ip := r.Header.Get("X-Real-IP"); ip != "" {
return ip
}
if ip := r.Header.Get("X-Forwarded-For"); ip != "" {
// X-Forwarded-For 可能包含多个 IP
ips := strings.Split(ip, ",")
if len(ips) > 0 {
return strings.TrimSpace(ips[0])
}
}
// 使用远程地址
ip := r.RemoteAddr
if idx := strings.LastIndex(ip, ":"); idx != -1 {
ip = ip[:idx]
}
return ip
} }