更新 main.go

This commit is contained in:
XOF
2025-12-15 01:54:32 +08:00
parent 10f2c1c97b
commit 7d4800efb9

62
main.go
View File

@@ -2,9 +2,11 @@
package main package main
import ( import (
"html/template"
"log" "log"
"net/http" "net/http"
"os" "os"
"path/filepath"
"siteproxy/auth" "siteproxy/auth"
"siteproxy/cache" "siteproxy/cache"
@@ -22,9 +24,15 @@ func main() {
log.Printf("Rate limit: %d requests per %v", cfg.RateLimit, cfg.RateLimitWindow) log.Printf("Rate limit: %d requests per %v", cfg.RateLimit, cfg.RateLimitWindow)
log.Printf("Cache enabled: %v (max: %d MB)", cfg.CacheEnabled, cfg.CacheMaxSize/1024/1024) log.Printf("Cache enabled: %v (max: %d MB)", cfg.CacheEnabled, cfg.CacheMaxSize/1024/1024)
// 加载模板
templates, err := loadTemplates()
if err != nil {
log.Fatalf("Failed to load templates: %v", err)
}
// 初始化组件 // 初始化组件
sessionMgr := auth.NewSessionManager(cfg.SessionTimeout) sessionMgr := auth.NewSessionManager(cfg.SessionTimeout)
authMw := auth.NewAuthMiddleware(cfg.Username, cfg.Password, sessionMgr) authMw := auth.NewAuthMiddleware(cfg.Username, cfg.Password, sessionMgr, templates)
validator := security.NewRequestValidator( validator := security.NewRequestValidator(
cfg.BlockedDomains, cfg.BlockedDomains,
@@ -38,7 +46,7 @@ func main() {
if cfg.CacheEnabled { if cfg.CacheEnabled {
memCache = cache.NewMemoryCache(cfg.CacheMaxSize, cfg.CacheTTL) memCache = cache.NewMemoryCache(cfg.CacheMaxSize, cfg.CacheTTL)
} else { } else {
memCache = cache.NewMemoryCache(0, 0) // 禁用缓存 memCache = cache.NewMemoryCache(0, 0)
} }
proxyHandler := proxy.NewHandler( proxyHandler := proxy.NewHandler(
@@ -49,6 +57,7 @@ func main() {
cfg.MaxResponseSize, cfg.MaxResponseSize,
) )
indexHandler := proxy.NewIndexHandler(templates)
statsHandler := proxy.NewStatsHandler(memCache) statsHandler := proxy.NewStatsHandler(memCache)
// 设置路由 // 设置路由
@@ -59,31 +68,58 @@ func main() {
mux.HandleFunc("/health", healthCheck) mux.HandleFunc("/health", healthCheck)
// 受保护路由 // 受保护路由
mux.Handle("/", authMw.Require(http.HandlerFunc(proxy.ServeIndexPage))) mux.Handle("/", authMw.Require(indexHandler))
mux.Handle("/proxy", authMw.Require(proxyHandler)) mux.Handle("/proxy", authMw.Require(proxyHandler))
mux.Handle("/stats", authMw.Require(statsHandler)) mux.Handle("/stats", authMw.Require(statsHandler))
mux.HandleFunc("/logout", authMw.Logout) mux.HandleFunc("/logout", authMw.Logout)
// 启动服务器 // 启动服务器
port := getEnv("PORT", "8080") addr := ":" + cfg.Port
addr := ":" + port
log.Printf("Server listening on %s", addr) log.Printf("Server listening on %s", addr)
log.Printf("Login with username: %s", cfg.Username) log.Printf("Login with username: %s", cfg.Username)
log.Printf("Access at: http://localhost:%s", cfg.Port)
if err := http.ListenAndServe(addr, mux); err != nil { if err := http.ListenAndServe(addr, mux); err != nil {
log.Fatal(err) log.Fatal(err)
} }
} }
func healthCheck(w http.ResponseWriter, r *http.Request) { func loadTemplates() (*template.Template, error) {
w.Header().Set("Content-Type", "application/json") // 尝试从多个位置加载模板
w.Write([]byte(`{"status":"ok"}`)) templateDirs := []string{
"templates",
"./templates",
"/app/templates",
}
var templateDir string
for _, dir := range templateDirs {
if _, err := os.Stat(dir); err == nil {
templateDir = dir
break
}
}
if templateDir == "" {
return nil, os.ErrNotExist
}
log.Printf("Loading templates from: %s", templateDir)
// 加载所有 .html 文件
pattern := filepath.Join(templateDir, "*.html")
tmpl, err := template.ParseGlob(pattern)
if err != nil {
return nil, err
}
log.Printf("Loaded templates: %v", tmpl.DefinedTemplates())
return tmpl, nil
} }
func getEnv(key, defaultValue string) string { func healthCheck(w http.ResponseWriter, r *http.Request) {
if value := os.Getenv(key); value != "" { w.Header().Set("Content-Type", "application/json")
return value w.Write([]byte(`{"status":"ok","version":"1.0.0"}`))
}
return defaultValue
} }