更新 main.go
This commit is contained in:
62
main.go
62
main.go
@@ -2,9 +2,11 @@
|
|||||||
package main
|
package main
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"html/template"
|
||||||
"log"
|
"log"
|
||||||
"net/http"
|
"net/http"
|
||||||
"os"
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
|
||||||
"siteproxy/auth"
|
"siteproxy/auth"
|
||||||
"siteproxy/cache"
|
"siteproxy/cache"
|
||||||
@@ -22,9 +24,15 @@ func main() {
|
|||||||
log.Printf("Rate limit: %d requests per %v", cfg.RateLimit, cfg.RateLimitWindow)
|
log.Printf("Rate limit: %d requests per %v", cfg.RateLimit, cfg.RateLimitWindow)
|
||||||
log.Printf("Cache enabled: %v (max: %d MB)", cfg.CacheEnabled, cfg.CacheMaxSize/1024/1024)
|
log.Printf("Cache enabled: %v (max: %d MB)", cfg.CacheEnabled, cfg.CacheMaxSize/1024/1024)
|
||||||
|
|
||||||
|
// 加载模板
|
||||||
|
templates, err := loadTemplates()
|
||||||
|
if err != nil {
|
||||||
|
log.Fatalf("Failed to load templates: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
// 初始化组件
|
// 初始化组件
|
||||||
sessionMgr := auth.NewSessionManager(cfg.SessionTimeout)
|
sessionMgr := auth.NewSessionManager(cfg.SessionTimeout)
|
||||||
authMw := auth.NewAuthMiddleware(cfg.Username, cfg.Password, sessionMgr)
|
authMw := auth.NewAuthMiddleware(cfg.Username, cfg.Password, sessionMgr, templates)
|
||||||
|
|
||||||
validator := security.NewRequestValidator(
|
validator := security.NewRequestValidator(
|
||||||
cfg.BlockedDomains,
|
cfg.BlockedDomains,
|
||||||
@@ -38,7 +46,7 @@ func main() {
|
|||||||
if cfg.CacheEnabled {
|
if cfg.CacheEnabled {
|
||||||
memCache = cache.NewMemoryCache(cfg.CacheMaxSize, cfg.CacheTTL)
|
memCache = cache.NewMemoryCache(cfg.CacheMaxSize, cfg.CacheTTL)
|
||||||
} else {
|
} else {
|
||||||
memCache = cache.NewMemoryCache(0, 0) // 禁用缓存
|
memCache = cache.NewMemoryCache(0, 0)
|
||||||
}
|
}
|
||||||
|
|
||||||
proxyHandler := proxy.NewHandler(
|
proxyHandler := proxy.NewHandler(
|
||||||
@@ -49,6 +57,7 @@ func main() {
|
|||||||
cfg.MaxResponseSize,
|
cfg.MaxResponseSize,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
indexHandler := proxy.NewIndexHandler(templates)
|
||||||
statsHandler := proxy.NewStatsHandler(memCache)
|
statsHandler := proxy.NewStatsHandler(memCache)
|
||||||
|
|
||||||
// 设置路由
|
// 设置路由
|
||||||
@@ -59,31 +68,58 @@ func main() {
|
|||||||
mux.HandleFunc("/health", healthCheck)
|
mux.HandleFunc("/health", healthCheck)
|
||||||
|
|
||||||
// 受保护路由
|
// 受保护路由
|
||||||
mux.Handle("/", authMw.Require(http.HandlerFunc(proxy.ServeIndexPage)))
|
mux.Handle("/", authMw.Require(indexHandler))
|
||||||
mux.Handle("/proxy", authMw.Require(proxyHandler))
|
mux.Handle("/proxy", authMw.Require(proxyHandler))
|
||||||
mux.Handle("/stats", authMw.Require(statsHandler))
|
mux.Handle("/stats", authMw.Require(statsHandler))
|
||||||
mux.HandleFunc("/logout", authMw.Logout)
|
mux.HandleFunc("/logout", authMw.Logout)
|
||||||
|
|
||||||
// 启动服务器
|
// 启动服务器
|
||||||
port := getEnv("PORT", "8080")
|
addr := ":" + cfg.Port
|
||||||
addr := ":" + port
|
|
||||||
|
|
||||||
log.Printf("Server listening on %s", addr)
|
log.Printf("Server listening on %s", addr)
|
||||||
log.Printf("Login with username: %s", cfg.Username)
|
log.Printf("Login with username: %s", cfg.Username)
|
||||||
|
log.Printf("Access at: http://localhost:%s", cfg.Port)
|
||||||
|
|
||||||
if err := http.ListenAndServe(addr, mux); err != nil {
|
if err := http.ListenAndServe(addr, mux); err != nil {
|
||||||
log.Fatal(err)
|
log.Fatal(err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func healthCheck(w http.ResponseWriter, r *http.Request) {
|
func loadTemplates() (*template.Template, error) {
|
||||||
w.Header().Set("Content-Type", "application/json")
|
// 尝试从多个位置加载模板
|
||||||
w.Write([]byte(`{"status":"ok"}`))
|
templateDirs := []string{
|
||||||
|
"templates",
|
||||||
|
"./templates",
|
||||||
|
"/app/templates",
|
||||||
|
}
|
||||||
|
|
||||||
|
var templateDir string
|
||||||
|
for _, dir := range templateDirs {
|
||||||
|
if _, err := os.Stat(dir); err == nil {
|
||||||
|
templateDir = dir
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if templateDir == "" {
|
||||||
|
return nil, os.ErrNotExist
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Printf("Loading templates from: %s", templateDir)
|
||||||
|
|
||||||
|
// 加载所有 .html 文件
|
||||||
|
pattern := filepath.Join(templateDir, "*.html")
|
||||||
|
tmpl, err := template.ParseGlob(pattern)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Printf("Loaded templates: %v", tmpl.DefinedTemplates())
|
||||||
|
|
||||||
|
return tmpl, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func getEnv(key, defaultValue string) string {
|
func healthCheck(w http.ResponseWriter, r *http.Request) {
|
||||||
if value := os.Getenv(key); value != "" {
|
w.Header().Set("Content-Type", "application/json")
|
||||||
return value
|
w.Write([]byte(`{"status":"ok","version":"1.0.0"}`))
|
||||||
}
|
|
||||||
return defaultValue
|
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user