diff --git a/main.go b/main.go index b0b06b2..1d411a0 100644 --- a/main.go +++ b/main.go @@ -2,9 +2,11 @@ package main import ( + "html/template" "log" "net/http" "os" + "path/filepath" "siteproxy/auth" "siteproxy/cache" @@ -22,9 +24,15 @@ func main() { log.Printf("Rate limit: %d requests per %v", cfg.RateLimit, cfg.RateLimitWindow) log.Printf("Cache enabled: %v (max: %d MB)", cfg.CacheEnabled, cfg.CacheMaxSize/1024/1024) + // 加载模板 + templates, err := loadTemplates() + if err != nil { + log.Fatalf("Failed to load templates: %v", err) + } + // 初始化组件 sessionMgr := auth.NewSessionManager(cfg.SessionTimeout) - authMw := auth.NewAuthMiddleware(cfg.Username, cfg.Password, sessionMgr) + authMw := auth.NewAuthMiddleware(cfg.Username, cfg.Password, sessionMgr, templates) validator := security.NewRequestValidator( cfg.BlockedDomains, @@ -38,7 +46,7 @@ func main() { if cfg.CacheEnabled { memCache = cache.NewMemoryCache(cfg.CacheMaxSize, cfg.CacheTTL) } else { - memCache = cache.NewMemoryCache(0, 0) // 禁用缓存 + memCache = cache.NewMemoryCache(0, 0) } proxyHandler := proxy.NewHandler( @@ -49,6 +57,7 @@ func main() { cfg.MaxResponseSize, ) + indexHandler := proxy.NewIndexHandler(templates) statsHandler := proxy.NewStatsHandler(memCache) // 设置路由 @@ -59,31 +68,58 @@ func main() { mux.HandleFunc("/health", healthCheck) // 受保护路由 - mux.Handle("/", authMw.Require(http.HandlerFunc(proxy.ServeIndexPage))) + mux.Handle("/", authMw.Require(indexHandler)) mux.Handle("/proxy", authMw.Require(proxyHandler)) mux.Handle("/stats", authMw.Require(statsHandler)) mux.HandleFunc("/logout", authMw.Logout) // 启动服务器 - port := getEnv("PORT", "8080") - addr := ":" + port + addr := ":" + cfg.Port log.Printf("Server listening on %s", addr) log.Printf("Login with username: %s", cfg.Username) + log.Printf("Access at: http://localhost:%s", cfg.Port) if err := http.ListenAndServe(addr, mux); err != nil { log.Fatal(err) } } -func healthCheck(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Content-Type", "application/json") - w.Write([]byte(`{"status":"ok"}`)) +func loadTemplates() (*template.Template, error) { + // 尝试从多个位置加载模板 + templateDirs := []string{ + "templates", + "./templates", + "/app/templates", + } + + var templateDir string + for _, dir := range templateDirs { + if _, err := os.Stat(dir); err == nil { + templateDir = dir + break + } + } + + if templateDir == "" { + return nil, os.ErrNotExist + } + + log.Printf("Loading templates from: %s", templateDir) + + // 加载所有 .html 文件 + pattern := filepath.Join(templateDir, "*.html") + tmpl, err := template.ParseGlob(pattern) + if err != nil { + return nil, err + } + + log.Printf("Loaded templates: %v", tmpl.DefinedTemplates()) + + return tmpl, nil } -func getEnv(key, defaultValue string) string { - if value := os.Getenv(key); value != "" { - return value - } - return defaultValue +func healthCheck(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.Write([]byte(`{"status":"ok","version":"1.0.0"}`)) }