206 lines
5.8 KiB
Go
206 lines
5.8 KiB
Go
package web
|
||
|
||
import (
	"crypto/sha256"
	"crypto/subtle"
	"embed"
	"encoding/json"
	"io/fs"
	"net/http"

	"godns/internal/stats"
	"godns/pkg/logger"
)
|
||
|
||
//go:embed static/*
|
||
var staticFiles embed.FS
|
||
|
||
// Handler Web服务处理器
|
||
type Handler struct {
|
||
stats stats.StatsRecorder
|
||
version string
|
||
checkUpdateCh chan<- struct{}
|
||
logger logger.Logger
|
||
username string
|
||
password string
|
||
}
|
||
|
||
// NewHandler 创建Web处理器
|
||
func NewHandler(s stats.StatsRecorder, ver string, checkCh chan<- struct{}, log logger.Logger, username, password string) *Handler {
|
||
return &Handler{
|
||
stats: s,
|
||
version: ver,
|
||
checkUpdateCh: checkCh,
|
||
logger: log,
|
||
username: username,
|
||
password: password,
|
||
}
|
||
}
|
||
|
||
// basicAuth 中间件
|
||
func (h *Handler) basicAuth(next http.HandlerFunc) http.HandlerFunc {
|
||
return func(w http.ResponseWriter, r *http.Request) {
|
||
// 如果未配置鉴权,直接放行
|
||
if h.username == "" || h.password == "" {
|
||
next(w, r)
|
||
return
|
||
}
|
||
user, pass, ok := r.BasicAuth()
|
||
if !ok || subtle.ConstantTimeCompare([]byte(user), []byte(h.username)) != 1 ||
|
||
subtle.ConstantTimeCompare([]byte(pass), []byte(h.password)) != 1 {
|
||
w.Header().Set("WWW-Authenticate", `Basic realm="NBDNS Monitor"`)
|
||
http.Error(w, "Unauthorized", http.StatusUnauthorized)
|
||
return
|
||
}
|
||
next(w, r)
|
||
}
|
||
}
|
||
|
||
// RegisterRoutes 注册路由
|
||
func (h *Handler) RegisterRoutes(mux *http.ServeMux) {
|
||
// API路由
|
||
mux.HandleFunc("/api/stats", h.basicAuth(h.handleStats))
|
||
mux.HandleFunc("/api/version", h.basicAuth(h.handleVersion))
|
||
mux.HandleFunc("/api/check-update", h.basicAuth(h.handleCheckUpdate))
|
||
mux.HandleFunc("/api/stats/reset", h.basicAuth(h.handleStatsReset))
|
||
|
||
// 静态文件服务
|
||
staticFS, err := fs.Sub(staticFiles, "static")
|
||
if err != nil {
|
||
h.logger.Printf("Failed to load static files: %v", err)
|
||
return
|
||
}
|
||
mux.Handle("/", h.basicAuth(func(w http.ResponseWriter, r *http.Request) {
|
||
http.FileServer(http.FS(staticFS)).ServeHTTP(w, r)
|
||
}))
|
||
}
|
||
|
||
// handleStats 处理统计信息请求
|
||
func (h *Handler) handleStats(w http.ResponseWriter, r *http.Request) {
|
||
// 只允许GET请求
|
||
if r.Method != http.MethodGet {
|
||
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
||
return
|
||
}
|
||
|
||
// 获取统计快照
|
||
snapshot := h.stats.GetSnapshot()
|
||
|
||
// 设置响应头
|
||
w.Header().Set("Content-Type", "application/json")
|
||
w.Header().Set("Access-Control-Allow-Origin", "*")
|
||
w.Header().Set("Cache-Control", "no-cache, no-store, must-revalidate")
|
||
|
||
// 编码JSON并返回
|
||
if err := json.NewEncoder(w).Encode(snapshot); err != nil {
|
||
h.logger.Printf("Error encoding stats JSON: %v", err)
|
||
http.Error(w, "Internal server error", http.StatusInternalServerError)
|
||
return
|
||
}
|
||
}
|
||
|
||
// ResetResponse is the JSON body returned by the statistics-reset endpoint.
type ResetResponse struct {
	// Success is true when the statistics were reset.
	Success bool `json:"success"`
	// Message is a human-readable description of the outcome.
	Message string `json:"message"`
}
|
||
|
||
// handleStatsReset 处理统计数据重置请求
|
||
func (h *Handler) handleStatsReset(w http.ResponseWriter, r *http.Request) {
|
||
// 只允许POST请求
|
||
if r.Method != http.MethodPost {
|
||
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
||
return
|
||
}
|
||
|
||
// 重置统计数据
|
||
h.stats.Reset()
|
||
h.logger.Printf("Statistics reset by user request")
|
||
|
||
// 设置响应头
|
||
w.Header().Set("Content-Type", "application/json")
|
||
w.Header().Set("Access-Control-Allow-Origin", "*")
|
||
|
||
// 返回成功响应
|
||
if err := json.NewEncoder(w).Encode(ResetResponse{
|
||
Success: true,
|
||
Message: "统计数据已重置",
|
||
}); err != nil {
|
||
h.logger.Printf("Error encoding reset response JSON: %v", err)
|
||
http.Error(w, "Internal server error", http.StatusInternalServerError)
|
||
return
|
||
}
|
||
}
|
||
|
||
// VersionResponse is the JSON body returned by the version endpoint.
type VersionResponse struct {
	// Version is the running build's version string.
	Version string `json:"version"`
}
|
||
|
||
// handleVersion 处理版本查询请求
|
||
func (h *Handler) handleVersion(w http.ResponseWriter, r *http.Request) {
|
||
if r.Method != http.MethodGet {
|
||
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
||
return
|
||
}
|
||
|
||
ver := h.version
|
||
if ver == "" {
|
||
ver = "0.0.0"
|
||
}
|
||
|
||
w.Header().Set("Content-Type", "application/json")
|
||
w.Header().Set("Access-Control-Allow-Origin", "*")
|
||
w.Header().Set("Cache-Control", "no-cache, no-store, must-revalidate")
|
||
|
||
if err := json.NewEncoder(w).Encode(VersionResponse{Version: ver}); err != nil {
|
||
h.logger.Printf("Error encoding version JSON: %v", err)
|
||
http.Error(w, "Internal server error", http.StatusInternalServerError)
|
||
return
|
||
}
|
||
}
|
||
|
||
// UpdateCheckResponse is the JSON body returned by the update-check endpoint.
type UpdateCheckResponse struct {
	// HasUpdate indicates whether a newer version is available.
	HasUpdate bool `json:"has_update"`
	// CurrentVersion is the version of the running build.
	CurrentVersion string `json:"current_version"`
	// LatestVersion is the newest known version.
	LatestVersion string `json:"latest_version"`
	// Message is a human-readable status for the user.
	Message string `json:"message"`
}
|
||
|
||
// handleCheckUpdate 处理检查更新请求(生产者2:用户手动触发)
|
||
func (h *Handler) handleCheckUpdate(w http.ResponseWriter, r *http.Request) {
|
||
if r.Method != http.MethodGet {
|
||
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
||
return
|
||
}
|
||
|
||
ver := h.version
|
||
if ver == "" {
|
||
ver = "0.0.0"
|
||
}
|
||
|
||
w.Header().Set("Content-Type", "application/json")
|
||
w.Header().Set("Access-Control-Allow-Origin", "*")
|
||
w.Header().Set("Cache-Control", "no-cache, no-store, must-revalidate")
|
||
|
||
// 触发后台检查更新(非阻塞)
|
||
select {
|
||
case h.checkUpdateCh <- struct{}{}:
|
||
h.logger.Printf("Update check triggered by user")
|
||
json.NewEncoder(w).Encode(UpdateCheckResponse{
|
||
HasUpdate: false,
|
||
CurrentVersion: ver,
|
||
LatestVersion: ver,
|
||
Message: "已触发更新检查,请查看服务器日志",
|
||
})
|
||
default:
|
||
// 如果通道已满,说明已经在检查中
|
||
json.NewEncoder(w).Encode(UpdateCheckResponse{
|
||
HasUpdate: false,
|
||
CurrentVersion: ver,
|
||
LatestVersion: ver,
|
||
Message: "更新检查正在进行中",
|
||
})
|
||
}
|
||
}
|