更新 config/config.go
This commit is contained in:
428
config/config.go
428
config/config.go
@@ -1,302 +1,202 @@
|
|||||||
// config/config.go
|
// config/config.go
|
||||||
package proxy
|
package config
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
|
||||||
"compress/gzip"
|
|
||||||
"io"
|
|
||||||
"log"
|
"log"
|
||||||
"net/http"
|
"os"
|
||||||
"net/url"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"siteproxy/cache"
|
|
||||||
"siteproxy/security"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type ProxyHandler struct {
|
type Config struct {
|
||||||
validator *security.RequestValidator
|
// 认证配置
|
||||||
rateLimiter *security.RateLimiter
|
Username string
|
||||||
cache *cache.MemoryCache
|
Password string
|
||||||
userAgent string
|
SessionSecret string
|
||||||
maxResponseSize int64
|
SessionTimeout time.Duration
|
||||||
|
|
||||||
|
// 安全配置
|
||||||
|
RateLimit int
|
||||||
|
RateLimitWindow time.Duration
|
||||||
|
MaxResponseSize int64
|
||||||
|
|
||||||
|
// 代理配置
|
||||||
|
AllowedSchemes []string
|
||||||
|
UserAgent string
|
||||||
|
|
||||||
|
// 黑名单配置
|
||||||
|
BlockedDomains []string
|
||||||
|
BlockedCIDRs []string
|
||||||
|
|
||||||
|
// 缓存配置
|
||||||
|
CacheEnabled bool
|
||||||
|
CacheMaxSize int64
|
||||||
|
CacheTTL time.Duration
|
||||||
|
|
||||||
|
// 服务器配置
|
||||||
|
Port string
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewHandler(
|
func LoadFromEnv() *Config {
|
||||||
validator *security.RequestValidator,
|
cfg := &Config{
|
||||||
rateLimiter *security.RateLimiter,
|
Username: getEnv("AUTH_USERNAME", "admin"),
|
||||||
cache *cache.MemoryCache,
|
Password: getEnvRequired("AUTH_PASSWORD"),
|
||||||
userAgent string,
|
SessionSecret: getEnvOrGenerate("SESSION_SECRET"),
|
||||||
maxResponseSize int64,
|
SessionTimeout: parseDuration(getEnv("SESSION_TIMEOUT", "30m")),
|
||||||
) *ProxyHandler {
|
RateLimit: parseInt(getEnv("RATE_LIMIT_REQUESTS", "100")),
|
||||||
return &ProxyHandler{
|
RateLimitWindow: parseDuration(getEnv("RATE_LIMIT_WINDOW", "1m")),
|
||||||
validator: validator,
|
MaxResponseSize: parseInt64(getEnv("MAX_RESPONSE_SIZE", "52428800")),
|
||||||
rateLimiter: rateLimiter,
|
AllowedSchemes: parseList(getEnv("ALLOWED_SCHEMES", "http,https")),
|
||||||
cache: cache,
|
UserAgent: getEnv("USER_AGENT", "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36"),
|
||||||
userAgent: userAgent,
|
BlockedDomains: parseList(getEnv("BLOCKED_DOMAINS", getDefaultBlockedDomains())),
|
||||||
maxResponseSize: maxResponseSize,
|
BlockedCIDRs: parseList(getEnv("BLOCKED_CIDRS", getDefaultBlockedCIDRs())),
|
||||||
|
CacheEnabled: parseBool(getEnv("CACHE_ENABLED", "true")),
|
||||||
|
CacheMaxSize: parseInt64(getEnv("CACHE_MAX_SIZE", "104857600")),
|
||||||
|
CacheTTL: parseDuration(getEnv("CACHE_TTL", "1h")),
|
||||||
|
Port: getEnv("PORT", "8080"),
|
||||||
|
}
|
||||||
|
|
||||||
|
// 验证配置
|
||||||
|
cfg.validate()
|
||||||
|
|
||||||
|
return cfg
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Config) validate() {
|
||||||
|
if c.Password == "" {
|
||||||
|
log.Fatal("AUTH_PASSWORD is required")
|
||||||
|
}
|
||||||
|
|
||||||
|
if c.Password == "your_secure_password_here" ||
|
||||||
|
c.Password == "change_this" {
|
||||||
|
log.Fatal("Please change the default AUTH_PASSWORD")
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(c.Password) < 8 {
|
||||||
|
log.Fatal("AUTH_PASSWORD must be at least 8 characters")
|
||||||
|
}
|
||||||
|
|
||||||
|
if c.SessionSecret == "" {
|
||||||
|
log.Fatal("SESSION_SECRET is required")
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(c.SessionSecret) < 32 {
|
||||||
|
log.Fatal("SESSION_SECRET must be at least 32 characters")
|
||||||
|
}
|
||||||
|
|
||||||
|
if c.SessionTimeout < time.Minute {
|
||||||
|
log.Fatal("SESSION_TIMEOUT must be at least 1 minute")
|
||||||
|
}
|
||||||
|
|
||||||
|
if c.RateLimit < 1 {
|
||||||
|
log.Fatal("RATE_LIMIT_REQUESTS must be at least 1")
|
||||||
|
}
|
||||||
|
|
||||||
|
if c.MaxResponseSize < 1024 {
|
||||||
|
log.Fatal("MAX_RESPONSE_SIZE must be at least 1KB")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *ProxyHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
func getEnv(key, defaultValue string) string {
|
||||||
// 获取目标 URL
|
if value := os.Getenv(key); value != "" {
|
||||||
targetURL := r.URL.Query().Get("url")
|
return value
|
||||||
if targetURL == "" {
|
}
|
||||||
http.Error(w, "Missing url parameter", http.StatusBadRequest)
|
return defaultValue
|
||||||
return
|
}
|
||||||
|
|
||||||
|
func getEnvRequired(key string) string {
|
||||||
|
value := os.Getenv(key)
|
||||||
|
if value == "" {
|
||||||
|
log.Fatalf("%s is required", key)
|
||||||
|
}
|
||||||
|
return value
|
||||||
|
}
|
||||||
|
|
||||||
|
func getEnvOrGenerate(key string) string {
|
||||||
|
if value := os.Getenv(key); value != "" {
|
||||||
|
return value
|
||||||
}
|
}
|
||||||
|
|
||||||
// 验证 URL
|
// 生成随机密钥
|
||||||
if err := h.validator.ValidateURL(targetURL); err != nil {
|
secret := GenerateSessionSecret()
|
||||||
log.Printf("URL validation failed: %v", err)
|
log.Printf("WARNING: %s not set, generated random value", key)
|
||||||
http.Error(w, "Invalid or blocked URL: "+err.Error(), http.StatusForbidden)
|
log.Printf("Add this to your .env file: %s=%s", key, secret)
|
||||||
return
|
return secret
|
||||||
}
|
}
|
||||||
|
|
||||||
// 速率限制
|
func parseDuration(s string) time.Duration {
|
||||||
clientIP := getClientIP(r)
|
d, err := time.ParseDuration(s)
|
||||||
if !h.rateLimiter.Allow(clientIP) {
|
|
||||||
http.Error(w, "Rate limit exceeded", http.StatusTooManyRequests)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// 检查缓存
|
|
||||||
if entry := h.cache.Get(targetURL); entry != nil {
|
|
||||||
log.Printf("Cache HIT: %s", targetURL)
|
|
||||||
h.serveCached(w, entry)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
log.Printf("Cache MISS: %s", targetURL)
|
|
||||||
|
|
||||||
// 创建代理请求
|
|
||||||
proxyReq, err := http.NewRequest(r.Method, targetURL, r.Body)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
http.Error(w, "Failed to create request", http.StatusInternalServerError)
|
log.Fatalf("Invalid duration: %s", s)
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
return d
|
||||||
// 设置请求头
|
}
|
||||||
h.setProxyHeaders(proxyReq, r)
|
|
||||||
|
func parseInt(s string) int {
|
||||||
// 发送请求
|
i, err := strconv.Atoi(s)
|
||||||
client := &http.Client{
|
|
||||||
Timeout: 30 * time.Second,
|
|
||||||
CheckRedirect: func(req *http.Request, via []*http.Request) error {
|
|
||||||
// 验证重定向 URL
|
|
||||||
if err := h.validator.ValidateURL(req.URL.String()); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
resp, err := client.Do(proxyReq)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Printf("Request failed: %v", err)
|
log.Fatalf("Invalid integer: %s", s)
|
||||||
http.Error(w, "Failed to fetch URL", http.StatusBadGateway)
|
|
||||||
return
|
|
||||||
}
|
}
|
||||||
defer resp.Body.Close()
|
return i
|
||||||
|
}
|
||||||
// 读取响应体
|
|
||||||
body, err := h.readResponseBody(resp)
|
func parseInt64(s string) int64 {
|
||||||
|
i, err := strconv.ParseInt(s, 10, 64)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Printf("Failed to read response: %v", err)
|
log.Fatalf("Invalid integer: %s", s)
|
||||||
http.Error(w, "Failed to read response", http.StatusInternalServerError)
|
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
return i
|
||||||
// 重写内容
|
|
||||||
contentType := resp.Header.Get("Content-Type")
|
|
||||||
body = h.rewriteContent(body, targetURL, contentType)
|
|
||||||
|
|
||||||
// 缓存响应
|
|
||||||
if h.shouldCache(resp) {
|
|
||||||
h.cache.Set(targetURL, body, resp.Header)
|
|
||||||
}
|
|
||||||
|
|
||||||
// 发送响应
|
|
||||||
h.sendResponse(w, resp, body)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *ProxyHandler) rewriteContent(body []byte, targetURL, contentType string) []byte {
|
func parseBool(s string) bool {
|
||||||
rewriter, err := NewContentRewriter(targetURL)
|
b, err := strconv.ParseBool(s)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Printf("Failed to create rewriter: %v", err)
|
log.Fatalf("Invalid boolean: %s", s)
|
||||||
return body
|
|
||||||
}
|
}
|
||||||
|
return b
|
||||||
contentType = strings.ToLower(contentType)
|
|
||||||
|
|
||||||
// HTML 内容
|
|
||||||
if strings.Contains(contentType, "text/html") {
|
|
||||||
rewritten, err := rewriter.RewriteHTML(body)
|
|
||||||
if err != nil {
|
|
||||||
log.Printf("HTML rewrite failed: %v", err)
|
|
||||||
return body
|
|
||||||
}
|
|
||||||
return rewritten
|
|
||||||
}
|
|
||||||
|
|
||||||
// CSS 内容
|
|
||||||
if strings.Contains(contentType, "text/css") {
|
|
||||||
return rewriter.RewriteCSS(body)
|
|
||||||
}
|
|
||||||
|
|
||||||
// JavaScript 内容 - 暂时不重写,可能会破坏功能
|
|
||||||
// 未来可以添加更智能的 JS 重写
|
|
||||||
|
|
||||||
return body
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *ProxyHandler) readResponseBody(resp *http.Response) ([]byte, error) {
|
func parseList(s string) []string {
|
||||||
var reader io.Reader = resp.Body
|
if s == "" {
|
||||||
|
return []string{}
|
||||||
// 处理 gzip 压缩
|
|
||||||
if resp.Header.Get("Content-Encoding") == "gzip" {
|
|
||||||
gzReader, err := gzip.NewReader(resp.Body)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
defer gzReader.Close()
|
|
||||||
reader = gzReader
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// 限制读取大小
|
parts := strings.Split(s, ",")
|
||||||
limitReader := io.LimitReader(reader, h.maxResponseSize)
|
result := make([]string, 0, len(parts))
|
||||||
|
|
||||||
return io.ReadAll(limitReader)
|
for _, part := range parts {
|
||||||
}
|
if trimmed := strings.TrimSpace(part); trimmed != "" {
|
||||||
|
result = append(result, trimmed)
|
||||||
func (h *ProxyHandler) setProxyHeaders(proxyReq, originalReq *http.Request) {
|
|
||||||
// 复制必要的请求头
|
|
||||||
headersToForward := []string{
|
|
||||||
"Accept",
|
|
||||||
"Accept-Language",
|
|
||||||
"Accept-Encoding",
|
|
||||||
"Cache-Control",
|
|
||||||
"Referer",
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, header := range headersToForward {
|
|
||||||
if value := originalReq.Header.Get(header); value != "" {
|
|
||||||
proxyReq.Header.Set(header, value)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// 设置自定义 User-Agent
|
return result
|
||||||
proxyReq.Header.Set("User-Agent", h.userAgent)
|
|
||||||
|
|
||||||
// 移除可能暴露代理的头
|
|
||||||
proxyReq.Header.Del("X-Forwarded-For")
|
|
||||||
proxyReq.Header.Del("X-Real-IP")
|
|
||||||
proxyReq.Header.Del("Via")
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *ProxyHandler) shouldCache(resp *http.Response) bool {
|
func getDefaultBlockedDomains() string {
|
||||||
// 只缓存成功的 GET 请求
|
return strings.Join([]string{
|
||||||
if resp.Request.Method != "GET" {
|
"localhost",
|
||||||
return false
|
"127.0.0.1",
|
||||||
}
|
"0.0.0.0",
|
||||||
|
"*.local",
|
||||||
if resp.StatusCode != http.StatusOK {
|
"internal",
|
||||||
return false
|
"metadata.google.internal",
|
||||||
}
|
"169.254.169.254",
|
||||||
|
"metadata.azure.com",
|
||||||
// 检查 Cache-Control
|
"metadata.packet.net",
|
||||||
cacheControl := resp.Header.Get("Cache-Control")
|
}, ",")
|
||||||
if strings.Contains(cacheControl, "no-store") ||
|
|
||||||
strings.Contains(cacheControl, "no-cache") ||
|
|
||||||
strings.Contains(cacheControl, "private") {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
// 检查内容类型
|
|
||||||
contentType := resp.Header.Get("Content-Type")
|
|
||||||
cacheableTypes := []string{
|
|
||||||
"text/html",
|
|
||||||
"text/css",
|
|
||||||
"application/javascript",
|
|
||||||
"image/",
|
|
||||||
"font/",
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, ct := range cacheableTypes {
|
|
||||||
if strings.Contains(contentType, ct) {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return false
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *ProxyHandler) sendResponse(w http.ResponseWriter, resp *http.Response, body []byte) {
|
func getDefaultBlockedCIDRs() string {
|
||||||
// 复制响应头
|
return strings.Join([]string{
|
||||||
headersToForward := []string{
|
"10.0.0.0/8",
|
||||||
"Content-Type",
|
"172.16.0.0/12",
|
||||||
"Content-Language",
|
"192.168.0.0/16",
|
||||||
"Last-Modified",
|
"169.254.0.0/16",
|
||||||
"ETag",
|
"::1/128",
|
||||||
"Expires",
|
"fc00::/7",
|
||||||
}
|
"fe80::/10",
|
||||||
|
"100.64.0.0/10",
|
||||||
for _, header := range headersToForward {
|
}, ",")
|
||||||
if value := resp.Header.Get(header); value != "" {
|
|
||||||
w.Header().Set(header, value)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// 添加自定义头
|
|
||||||
w.Header().Set("X-Proxied-By", "SiteProxy")
|
|
||||||
w.Header().Set("X-Cache-Status", "MISS")
|
|
||||||
|
|
||||||
// 移除不需要的头
|
|
||||||
w.Header().Del("Content-Encoding")
|
|
||||||
w.Header().Del("Content-Length")
|
|
||||||
|
|
||||||
// 安全头
|
|
||||||
w.Header().Set("X-Content-Type-Options", "nosniff")
|
|
||||||
w.Header().Set("X-Frame-Options", "SAMEORIGIN")
|
|
||||||
w.Header().Set("Referrer-Policy", "no-referrer")
|
|
||||||
|
|
||||||
w.WriteHeader(resp.StatusCode)
|
|
||||||
w.Write(body)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (h *ProxyHandler) serveCached(w http.ResponseWriter, entry *cache.CacheEntry) {
|
|
||||||
for key, value := range entry.Headers {
|
|
||||||
w.Header().Set(key, value)
|
|
||||||
}
|
|
||||||
|
|
||||||
w.Header().Set("X-Cache-Status", "HIT")
|
|
||||||
w.Header().Set("X-Proxied-By", "SiteProxy")
|
|
||||||
w.Header().Set("Age", entry.Age())
|
|
||||||
|
|
||||||
w.WriteHeader(http.StatusOK)
|
|
||||||
w.Write(entry.Data)
|
|
||||||
}
|
|
||||||
|
|
||||||
func getClientIP(r *http.Request) string {
|
|
||||||
// 尝试从各种头中获取真实 IP
|
|
||||||
if ip := r.Header.Get("X-Real-IP"); ip != "" {
|
|
||||||
return ip
|
|
||||||
}
|
|
||||||
|
|
||||||
if ip := r.Header.Get("X-Forwarded-For"); ip != "" {
|
|
||||||
// X-Forwarded-For 可能包含多个 IP
|
|
||||||
ips := strings.Split(ip, ",")
|
|
||||||
if len(ips) > 0 {
|
|
||||||
return strings.TrimSpace(ips[0])
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// 使用远程地址
|
|
||||||
ip := r.RemoteAddr
|
|
||||||
if idx := strings.LastIndex(ip, ":"); idx != -1 {
|
|
||||||
ip = ip[:idx]
|
|
||||||
}
|
|
||||||
|
|
||||||
return ip
|
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user