From 25715cfe22353d26bcca636b5a0f06724d6e8298 Mon Sep 17 00:00:00 2001 From: XOF Date: Mon, 15 Dec 2025 04:17:27 +0800 Subject: [PATCH] =?UTF-8?q?=E6=9B=B4=E6=96=B0=20main.go?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- main.go | 49 ++++++++++++++++++++++++++++++------------------- 1 file changed, 30 insertions(+), 19 deletions(-) diff --git a/main.go b/main.go index 1de853f..69c39f8 100644 --- a/main.go +++ b/main.go @@ -7,6 +7,7 @@ import ( "net/http" "os" "path/filepath" + "time" "siteproxy/auth" "siteproxy/cache" @@ -16,7 +17,6 @@ import ( ) func main() { - // 加载配置 cfg := config.LoadFromEnv() log.Printf("Starting Secure Site Proxy...") @@ -24,17 +24,14 @@ func main() { log.Printf("Rate limit: %d requests per %v", cfg.RateLimit, cfg.RateLimitWindow) log.Printf("Cache enabled: %v (max: %d MB)", cfg.CacheEnabled, cfg.CacheMaxSize/1024/1024) - // 加载模板 templates, err := loadTemplates() if err != nil { log.Fatalf("Failed to load templates: %v", err) } - // 初始化组件 - sessionMgr := auth.NewSessionManager(cfg.SessionTimeout) - authMw := auth.NewAuthMiddleware(cfg.Username, cfg.Password, sessionMgr, templates) + authSessionMgr := auth.NewSessionManager(cfg.SessionTimeout) + authMw := auth.NewAuthMiddleware(cfg.Username, cfg.Password, authSessionMgr, templates) - // 转换 BlockedDomains 为 map blockedDomainsMap := make(map[string]bool) for _, domain := range cfg.BlockedDomains { blockedDomainsMap[domain] = true @@ -55,10 +52,13 @@ func main() { memCache = cache.NewMemoryCache(0, 0) } + proxySessionMgr := proxy.NewProxySessionManager(30 * time.Minute) + proxyHandler := proxy.NewHandler( validator, rateLimiter, memCache, + proxySessionMgr, cfg.UserAgent, cfg.MaxResponseSize, ) @@ -66,20 +66,36 @@ func main() { indexHandler := proxy.NewIndexHandler(templates) statsHandler := proxy.NewStatsHandler(memCache) - // 设置路由 mux := http.NewServeMux() - // 公开路由 mux.HandleFunc("/login", authMw.Login) mux.HandleFunc("/health", healthCheck) - // 受保护路由 - mux.Handle("/", authMw.Require(indexHandler)) - mux.Handle("/proxy", authMw.Require(proxyHandler)) + mux.Handle("/", authMw.Require(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/" { + http.NotFound(w, r) + return + } + + if r.Method == "POST" { + targetURL := r.FormValue("url") + if targetURL == "" { + http.Error(w, "URL required", http.StatusBadRequest) + return + } + + token := proxySessionMgr.Create(targetURL) + http.Redirect(w, r, "/p/"+token, http.StatusSeeOther) + return + } + + indexHandler.ServeHTTP(w, r) + }))) + + mux.Handle("/p/", authMw.Require(proxyHandler)) mux.Handle("/stats", authMw.Require(statsHandler)) mux.HandleFunc("/logout", authMw.Logout) - // 启动服务器 addr := ":" + cfg.Port log.Printf("Server listening on %s", addr) @@ -92,12 +108,7 @@ func main() { } func loadTemplates() (*template.Template, error) { - // 尝试从多个位置加载模板 - templateDirs := []string{ - "templates", - "./templates", - "/app/templates", - } + templateDirs := []string{"templates", "./templates", "/app/templates"} var templateDir string for _, dir := range templateDirs { @@ -113,7 +124,6 @@ func loadTemplates() (*template.Template, error) { log.Printf("Loading templates from: %s", templateDir) - // 加载所有 .html 文件 pattern := filepath.Join(templateDir, "*.html") tmpl, err := template.ParseGlob(pattern) if err != nil { @@ -129,3 +139,4 @@ func healthCheck(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "application/json") w.Write([]byte(`{"status":"ok","version":"1.0.0"}`)) } +