更新 main.go
This commit is contained in:
49
main.go
49
main.go
@@ -7,6 +7,7 @@ import (
|
|||||||
"net/http"
|
"net/http"
|
||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
|
"time"
|
||||||
|
|
||||||
"siteproxy/auth"
|
"siteproxy/auth"
|
||||||
"siteproxy/cache"
|
"siteproxy/cache"
|
||||||
@@ -16,7 +17,6 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
func main() {
|
func main() {
|
||||||
// 加载配置
|
|
||||||
cfg := config.LoadFromEnv()
|
cfg := config.LoadFromEnv()
|
||||||
|
|
||||||
log.Printf("Starting Secure Site Proxy...")
|
log.Printf("Starting Secure Site Proxy...")
|
||||||
@@ -24,17 +24,14 @@ func main() {
|
|||||||
log.Printf("Rate limit: %d requests per %v", cfg.RateLimit, cfg.RateLimitWindow)
|
log.Printf("Rate limit: %d requests per %v", cfg.RateLimit, cfg.RateLimitWindow)
|
||||||
log.Printf("Cache enabled: %v (max: %d MB)", cfg.CacheEnabled, cfg.CacheMaxSize/1024/1024)
|
log.Printf("Cache enabled: %v (max: %d MB)", cfg.CacheEnabled, cfg.CacheMaxSize/1024/1024)
|
||||||
|
|
||||||
// 加载模板
|
|
||||||
templates, err := loadTemplates()
|
templates, err := loadTemplates()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatalf("Failed to load templates: %v", err)
|
log.Fatalf("Failed to load templates: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// 初始化组件
|
authSessionMgr := auth.NewSessionManager(cfg.SessionTimeout)
|
||||||
sessionMgr := auth.NewSessionManager(cfg.SessionTimeout)
|
authMw := auth.NewAuthMiddleware(cfg.Username, cfg.Password, authSessionMgr, templates)
|
||||||
authMw := auth.NewAuthMiddleware(cfg.Username, cfg.Password, sessionMgr, templates)
|
|
||||||
|
|
||||||
// 转换 BlockedDomains 为 map
|
|
||||||
blockedDomainsMap := make(map[string]bool)
|
blockedDomainsMap := make(map[string]bool)
|
||||||
for _, domain := range cfg.BlockedDomains {
|
for _, domain := range cfg.BlockedDomains {
|
||||||
blockedDomainsMap[domain] = true
|
blockedDomainsMap[domain] = true
|
||||||
@@ -55,10 +52,13 @@ func main() {
|
|||||||
memCache = cache.NewMemoryCache(0, 0)
|
memCache = cache.NewMemoryCache(0, 0)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
proxySessionMgr := proxy.NewProxySessionManager(30 * time.Minute)
|
||||||
|
|
||||||
proxyHandler := proxy.NewHandler(
|
proxyHandler := proxy.NewHandler(
|
||||||
validator,
|
validator,
|
||||||
rateLimiter,
|
rateLimiter,
|
||||||
memCache,
|
memCache,
|
||||||
|
proxySessionMgr,
|
||||||
cfg.UserAgent,
|
cfg.UserAgent,
|
||||||
cfg.MaxResponseSize,
|
cfg.MaxResponseSize,
|
||||||
)
|
)
|
||||||
@@ -66,20 +66,36 @@ func main() {
|
|||||||
indexHandler := proxy.NewIndexHandler(templates)
|
indexHandler := proxy.NewIndexHandler(templates)
|
||||||
statsHandler := proxy.NewStatsHandler(memCache)
|
statsHandler := proxy.NewStatsHandler(memCache)
|
||||||
|
|
||||||
// 设置路由
|
|
||||||
mux := http.NewServeMux()
|
mux := http.NewServeMux()
|
||||||
|
|
||||||
// 公开路由
|
|
||||||
mux.HandleFunc("/login", authMw.Login)
|
mux.HandleFunc("/login", authMw.Login)
|
||||||
mux.HandleFunc("/health", healthCheck)
|
mux.HandleFunc("/health", healthCheck)
|
||||||
|
|
||||||
// 受保护路由
|
mux.Handle("/", authMw.Require(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
mux.Handle("/", authMw.Require(indexHandler))
|
if r.URL.Path != "/" {
|
||||||
mux.Handle("/proxy", authMw.Require(proxyHandler))
|
http.NotFound(w, r)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if r.Method == "POST" {
|
||||||
|
targetURL := r.FormValue("url")
|
||||||
|
if targetURL == "" {
|
||||||
|
http.Error(w, "URL required", http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
token := proxySessionMgr.Create(targetURL)
|
||||||
|
http.Redirect(w, r, "/p/"+token, http.StatusSeeOther)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
indexHandler.ServeHTTP(w, r)
|
||||||
|
})))
|
||||||
|
|
||||||
|
mux.Handle("/p/", authMw.Require(proxyHandler))
|
||||||
mux.Handle("/stats", authMw.Require(statsHandler))
|
mux.Handle("/stats", authMw.Require(statsHandler))
|
||||||
mux.HandleFunc("/logout", authMw.Logout)
|
mux.HandleFunc("/logout", authMw.Logout)
|
||||||
|
|
||||||
// 启动服务器
|
|
||||||
addr := ":" + cfg.Port
|
addr := ":" + cfg.Port
|
||||||
|
|
||||||
log.Printf("Server listening on %s", addr)
|
log.Printf("Server listening on %s", addr)
|
||||||
@@ -92,12 +108,7 @@ func main() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func loadTemplates() (*template.Template, error) {
|
func loadTemplates() (*template.Template, error) {
|
||||||
// 尝试从多个位置加载模板
|
templateDirs := []string{"templates", "./templates", "/app/templates"}
|
||||||
templateDirs := []string{
|
|
||||||
"templates",
|
|
||||||
"./templates",
|
|
||||||
"/app/templates",
|
|
||||||
}
|
|
||||||
|
|
||||||
var templateDir string
|
var templateDir string
|
||||||
for _, dir := range templateDirs {
|
for _, dir := range templateDirs {
|
||||||
@@ -113,7 +124,6 @@ func loadTemplates() (*template.Template, error) {
|
|||||||
|
|
||||||
log.Printf("Loading templates from: %s", templateDir)
|
log.Printf("Loading templates from: %s", templateDir)
|
||||||
|
|
||||||
// 加载所有 .html 文件
|
|
||||||
pattern := filepath.Join(templateDir, "*.html")
|
pattern := filepath.Join(templateDir, "*.html")
|
||||||
tmpl, err := template.ParseGlob(pattern)
|
tmpl, err := template.ParseGlob(pattern)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -129,3 +139,4 @@ func healthCheck(w http.ResponseWriter, r *http.Request) {
|
|||||||
w.Header().Set("Content-Type", "application/json")
|
w.Header().Set("Content-Type", "application/json")
|
||||||
w.Write([]byte(`{"status":"ok","version":"1.0.0"}`))
|
w.Write([]byte(`{"status":"ok","version":"1.0.0"}`))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user