更新 main.go

This commit is contained in:
XOF
2025-12-15 04:17:27 +08:00
parent 2a8cd1c427
commit 25715cfe22

49
main.go
View File

@@ -7,6 +7,7 @@ import (
"net/http" "net/http"
"os" "os"
"path/filepath" "path/filepath"
"time"
"siteproxy/auth" "siteproxy/auth"
"siteproxy/cache" "siteproxy/cache"
@@ -16,7 +17,6 @@ import (
) )
func main() { func main() {
// 加载配置
cfg := config.LoadFromEnv() cfg := config.LoadFromEnv()
log.Printf("Starting Secure Site Proxy...") log.Printf("Starting Secure Site Proxy...")
@@ -24,17 +24,14 @@ func main() {
log.Printf("Rate limit: %d requests per %v", cfg.RateLimit, cfg.RateLimitWindow) log.Printf("Rate limit: %d requests per %v", cfg.RateLimit, cfg.RateLimitWindow)
log.Printf("Cache enabled: %v (max: %d MB)", cfg.CacheEnabled, cfg.CacheMaxSize/1024/1024) log.Printf("Cache enabled: %v (max: %d MB)", cfg.CacheEnabled, cfg.CacheMaxSize/1024/1024)
// 加载模板
templates, err := loadTemplates() templates, err := loadTemplates()
if err != nil { if err != nil {
log.Fatalf("Failed to load templates: %v", err) log.Fatalf("Failed to load templates: %v", err)
} }
// 初始化组件 authSessionMgr := auth.NewSessionManager(cfg.SessionTimeout)
sessionMgr := auth.NewSessionManager(cfg.SessionTimeout) authMw := auth.NewAuthMiddleware(cfg.Username, cfg.Password, authSessionMgr, templates)
authMw := auth.NewAuthMiddleware(cfg.Username, cfg.Password, sessionMgr, templates)
// 转换 BlockedDomains 为 map
blockedDomainsMap := make(map[string]bool) blockedDomainsMap := make(map[string]bool)
for _, domain := range cfg.BlockedDomains { for _, domain := range cfg.BlockedDomains {
blockedDomainsMap[domain] = true blockedDomainsMap[domain] = true
@@ -55,10 +52,13 @@ func main() {
memCache = cache.NewMemoryCache(0, 0) memCache = cache.NewMemoryCache(0, 0)
} }
proxySessionMgr := proxy.NewProxySessionManager(30 * time.Minute)
proxyHandler := proxy.NewHandler( proxyHandler := proxy.NewHandler(
validator, validator,
rateLimiter, rateLimiter,
memCache, memCache,
proxySessionMgr,
cfg.UserAgent, cfg.UserAgent,
cfg.MaxResponseSize, cfg.MaxResponseSize,
) )
@@ -66,20 +66,36 @@ func main() {
indexHandler := proxy.NewIndexHandler(templates) indexHandler := proxy.NewIndexHandler(templates)
statsHandler := proxy.NewStatsHandler(memCache) statsHandler := proxy.NewStatsHandler(memCache)
// 设置路由
mux := http.NewServeMux() mux := http.NewServeMux()
// 公开路由
mux.HandleFunc("/login", authMw.Login) mux.HandleFunc("/login", authMw.Login)
mux.HandleFunc("/health", healthCheck) mux.HandleFunc("/health", healthCheck)
// 受保护路由 mux.Handle("/", authMw.Require(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
mux.Handle("/", authMw.Require(indexHandler)) if r.URL.Path != "/" {
mux.Handle("/proxy", authMw.Require(proxyHandler)) http.NotFound(w, r)
return
}
if r.Method == "POST" {
targetURL := r.FormValue("url")
if targetURL == "" {
http.Error(w, "URL required", http.StatusBadRequest)
return
}
token := proxySessionMgr.Create(targetURL)
http.Redirect(w, r, "/p/"+token, http.StatusSeeOther)
return
}
indexHandler.ServeHTTP(w, r)
})))
mux.Handle("/p/", authMw.Require(proxyHandler))
mux.Handle("/stats", authMw.Require(statsHandler)) mux.Handle("/stats", authMw.Require(statsHandler))
mux.HandleFunc("/logout", authMw.Logout) mux.HandleFunc("/logout", authMw.Logout)
// 启动服务器
addr := ":" + cfg.Port addr := ":" + cfg.Port
log.Printf("Server listening on %s", addr) log.Printf("Server listening on %s", addr)
@@ -92,12 +108,7 @@ func main() {
} }
func loadTemplates() (*template.Template, error) { func loadTemplates() (*template.Template, error) {
// 尝试从多个位置加载模板 templateDirs := []string{"templates", "./templates", "/app/templates"}
templateDirs := []string{
"templates",
"./templates",
"/app/templates",
}
var templateDir string var templateDir string
for _, dir := range templateDirs { for _, dir := range templateDirs {
@@ -113,7 +124,6 @@ func loadTemplates() (*template.Template, error) {
log.Printf("Loading templates from: %s", templateDir) log.Printf("Loading templates from: %s", templateDir)
// 加载所有 .html 文件
pattern := filepath.Join(templateDir, "*.html") pattern := filepath.Join(templateDir, "*.html")
tmpl, err := template.ParseGlob(pattern) tmpl, err := template.ParseGlob(pattern)
if err != nil { if err != nil {
@@ -129,3 +139,4 @@ func healthCheck(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json") w.Header().Set("Content-Type", "application/json")
w.Write([]byte(`{"status":"ok","version":"1.0.0"}`)) w.Write([]byte(`{"status":"ok","version":"1.0.0"}`))
} }