Files
gemini-banlancer/internal/service/dashboard_query_service.go
2025-11-20 12:24:05 +08:00

316 lines
11 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
// Filename: internal/service/dashboard_query_service.go
package service
import (
	"fmt"
	"sort"
	"strconv"
	"time"

	"gemini-balancer/internal/models"
	"gemini-balancer/internal/store"
	"gemini-balancer/internal/syncer"

	"github.com/gin-gonic/gin"
	"github.com/sirupsen/logrus"
	"gorm.io/gorm"
)
// overviewCacheChannel is the channel name handed to the overview CacheSyncer
// when it is constructed in NewDashboardQueryService.
// NOTE(review): presumably the pub/sub channel the syncer uses to broadcast
// cache invalidation — confirm against the syncer package.
const overviewCacheChannel = "syncer:cache:dashboard_overview"
// DashboardQueryService serves all dashboard data queries made by the front end.
type DashboardQueryService struct {
// db is the GORM handle used for the statistics queries.
db *gorm.DB
// store backs the per-group stats hash lookups and the event subscriptions.
store store.Store
// overviewSyncer holds the cached dashboard overview, loaded by loadOverviewData.
overviewSyncer *syncer.CacheSyncer[*models.DashboardStatsResponse]
// logger is tagged with component=DashboardQueryService by the constructor.
logger *logrus.Entry
// stopChan is closed by Stop to end the eventListener goroutine.
stopChan chan struct{}
}
// NewDashboardQueryService builds a DashboardQueryService around the given
// database handle and store, and wires up the overview cache syncer with
// loadOverviewData as its loader.
//
// It returns an error if the cache syncer cannot be created.
func NewDashboardQueryService(db *gorm.DB, s store.Store, logger *logrus.Logger) (*DashboardQueryService, error) {
	svc := new(DashboardQueryService)
	svc.db = db
	svc.store = s
	svc.logger = logger.WithField("component", "DashboardQueryService")
	svc.stopChan = make(chan struct{})

	// The syncer calls back into the service to (re)load the overview.
	cache, err := syncer.NewCacheSyncer(svc.loadOverviewData, s, overviewCacheChannel)
	if err != nil {
		return nil, fmt.Errorf("failed to create overview cache syncer: %w", err)
	}
	svc.overviewSyncer = cache
	return svc, nil
}
// Start launches the eventListener goroutine and logs that the service is
// running. It returns immediately; the goroutine exits after Stop closes
// stopChan.
func (s *DashboardQueryService) Start() {
go s.eventListener()
s.logger.Info("DashboardQueryService started and listening for invalidation events.")
}
// Stop signals the eventListener goroutine to exit by closing stopChan.
//
// Repeated sequential calls are safe: once the channel has been closed, later
// calls return without doing anything instead of panicking on a second close.
// Stop is not synchronized, so it must not be called from several goroutines
// at the same time.
//
// NOTE(review): the log line mentions the CacheSyncer, but nothing here stops
// overviewSyncer — confirm whether it needs an explicit shutdown.
func (s *DashboardQueryService) Stop() {
	select {
	case <-s.stopChan:
		// A receive only succeeds once the channel is closed, i.e. Stop
		// already ran.
		return
	default:
	}
	close(s.stopChan)
	s.logger.Info("DashboardQueryService and its CacheSyncer have been stopped.")
}
// GetGroupStats returns the statistics shown for a single key group.
//
// The result map has three entries:
//   - "key_stats": the "stats:group:<id>" hash from the store, with every value
//     parsed as a base-10 int64;
//   - "last_1_hour" and "last_24_hours": total requests, successful requests
//     and the failure rate (in percent) summed from the StatsHourly rows of the
//     group whose time is at or after now-1h and now-24h respectively.
//
// An error is returned when the store lookup or either database query fails.
// A hash value that does not parse is logged and stored with whatever
// strconv.ParseInt returned (0 for a syntax error).
//
// NOTE(review): if StatsHourly.time holds hour-aligned bucket starts, the
// "time >= now-1h" filter skips the bucket containing now-1h — confirm that is
// the intended window.
func (s *DashboardQueryService) GetGroupStats(groupID uint) (map[string]any, error) {
	statsKey := fmt.Sprintf("stats:group:%d", groupID)
	keyStatsMap, err := s.store.HGetAll(statsKey)
	if err != nil {
		s.logger.WithError(err).Errorf("Failed to get key stats from cache for group %d", groupID)
		return nil, fmt.Errorf("failed to get key stats from cache: %w", err)
	}

	keyStats := make(map[string]int64, len(keyStatsMap))
	for k, v := range keyStatsMap {
		val, parseErr := strconv.ParseInt(v, 10, 64)
		if parseErr != nil {
			// Keep the entry so the key still shows up, but surface the bad
			// value instead of silently reporting it as a number.
			s.logger.WithError(parseErr).Warnf("Failed to parse cached stat %q for group %d; using %d", k, groupID, val)
		}
		keyStats[k] = val
	}

	now := time.Now()
	total1h, success1h, err := s.groupRequestWindow(groupID, now.Add(-1*time.Hour))
	if err != nil {
		return nil, fmt.Errorf("failed to query last 1 hour stats for group %d: %w", groupID, err)
	}
	total24h, success24h, err := s.groupRequestWindow(groupID, now.Add(-24*time.Hour))
	if err != nil {
		return nil, fmt.Errorf("failed to query last 24 hours stats for group %d: %w", groupID, err)
	}

	// windowStats renders one time window, deriving the failure rate in percent.
	windowStats := func(total, success int64) map[string]any {
		failureRate := 0.0
		if total > 0 {
			failureRate = float64(total-success) / float64(total) * 100
		}
		return map[string]any{
			"total_requests":   total,
			"success_requests": success,
			"failure_rate":     failureRate,
		}
	}

	return map[string]any{
		"key_stats":     keyStats,
		"last_1_hour":   windowStats(total1h, success1h),
		"last_24_hours": windowStats(total24h, success24h),
	}, nil
}

// groupRequestWindow sums request_count and success_count over the StatsHourly
// rows of groupID whose time is at or after since. COALESCE makes an empty
// window come back as 0 rather than NULL.
func (s *DashboardQueryService) groupRequestWindow(groupID uint, since time.Time) (total, success int64, err error) {
	var row struct {
		TotalRequests   int64
		SuccessRequests int64
	}
	err = s.db.Model(&models.StatsHourly{}).
		Select("COALESCE(SUM(request_count), 0) as total_requests, COALESCE(SUM(success_count), 0) as success_requests").
		Where("group_id = ? AND time >= ?", groupID, since).
		Scan(&row).Error
	if err != nil {
		return 0, 0, err
	}
	return row.TotalRequests, row.SuccessRequests, nil
}
// eventListener runs until Stop closes stopChan. It subscribes to the
// key-status and upstream-health topics on the store and invalidates the
// overview cache every time either topic delivers a message.
//
// If a subscription cannot be established, the error is logged and the
// listener exits (closing any subscription it already opened) instead of
// dereferencing an unusable subscription. A failed invalidation is logged and
// the loop keeps running. If a subscription channel is closed, that channel is
// dropped from the select so the loop does not spin on it.
//
// NOTE(review): the channels are fetched once up front; this assumes
// Channel() returns the same channel on every call — confirm against the
// store implementation.
func (s *DashboardQueryService) eventListener() {
	keyStatusSub, err := s.store.Subscribe(models.TopicKeyStatusChanged)
	if err != nil {
		s.logger.WithError(err).Error("Failed to subscribe to key status events; dashboard event listener not started.")
		return
	}
	defer keyStatusSub.Close()

	upstreamStatusSub, err := s.store.Subscribe(models.TopicUpstreamHealthChanged)
	if err != nil {
		s.logger.WithError(err).Error("Failed to subscribe to upstream status events; dashboard event listener not started.")
		return
	}
	defer upstreamStatusSub.Close()

	invalidate := func() {
		if invErr := s.InvalidateOverviewCache(); invErr != nil {
			s.logger.WithError(invErr).Warn("Failed to invalidate overview cache.")
		}
	}

	keyEvents := keyStatusSub.Channel()
	upstreamEvents := upstreamStatusSub.Channel()
	for {
		select {
		case _, ok := <-keyEvents:
			if !ok {
				// A nil channel blocks forever, which removes this case
				// from the select.
				s.logger.Warn("Key status subscription channel closed; no longer listening for key status events.")
				keyEvents = nil
				continue
			}
			s.logger.Info("Received key status changed event, invalidating overview cache...")
			invalidate()
		case _, ok := <-upstreamEvents:
			if !ok {
				s.logger.Warn("Upstream status subscription channel closed; no longer listening for upstream status events.")
				upstreamEvents = nil
				continue
			}
			s.logger.Info("Received upstream status changed event, invalidating overview cache...")
			invalidate()
		case <-s.stopChan:
			s.logger.Info("Stopping dashboard event listener.")
			return
		}
	}
}
// GetDashboardOverviewData returns the dashboard overview currently held by
// the overview cache syncer, without touching the database.
//
// When the syncer has no data, an empty response is returned together with an
// error.
func (s *DashboardQueryService) GetDashboardOverviewData() (*models.DashboardStatsResponse, error) {
	if overview := s.overviewSyncer.Get(); overview != nil {
		return overview, nil
	}
	return &models.DashboardStatsResponse{}, fmt.Errorf("overview cache is not available or still syncing")
}
// InvalidateOverviewCache asks the overview cache syncer to invalidate its
// cached data and returns whatever error the syncer reports.
func (s *DashboardQueryService) InvalidateOverviewCache() error {
return s.overviewSyncer.Invalidate()
}
// QueryHistoricalChart builds the per-model request chart for the last seven
// days.
//
// Hourly request counts are read from StatsHourly, optionally restricted to a
// single group (a nil or zero groupID means all groups). The returned chart
// has one label per hour from seven days ago up to now and one dataset per
// model name; hours with no rows for a model are reported as 0.
//
// Datasets are emitted in ascending model-name order, so both their order and
// their palette colors are the same on every call.
//
// The query error, if any, is returned unchanged.
func (s *DashboardQueryService) QueryHistoricalChart(groupID *uint) (*models.ChartData, error) {
	type ChartPoint struct {
		TimeLabel     string `gorm:"column:time_label"`
		ModelName     string `gorm:"column:model_name"`
		TotalRequests int64  `gorm:"column:total_requests"`
	}

	// Take a single snapshot of "now" so the query window and the label
	// range are derived from the same instant.
	now := time.Now()
	sevenDaysAgo := now.Add(-24 * 7 * time.Hour).Truncate(time.Hour)

	sqlFormat, goFormat := s.buildTimeFormatSelectClause()
	selectClause := fmt.Sprintf("%s as time_label, model_name, SUM(request_count) as total_requests", sqlFormat)
	query := s.db.Model(&models.StatsHourly{}).
		Select(selectClause).
		Where("time >= ?", sevenDaysAgo).
		Group("time_label, model_name").
		Order("time_label ASC")
	if groupID != nil && *groupID > 0 {
		query = query.Where("group_id = ?", *groupID)
	}

	var points []ChartPoint
	if err := query.Find(&points).Error; err != nil {
		return nil, err
	}

	// Index the rows as model name -> time label -> request total.
	datasets := make(map[string]map[string]int64)
	for _, p := range points {
		perLabel, ok := datasets[p.ModelName]
		if !ok {
			perLabel = make(map[string]int64)
			datasets[p.ModelName] = perLabel
		}
		perLabel[p.TimeLabel] = p.TotalRequests
	}

	// One label per hour; goFormat renders the same text as the SQL label.
	labels := make([]string, 0, 7*24+1)
	for t := sevenDaysAgo; t.Before(now); t = t.Add(time.Hour) {
		labels = append(labels, t.Format(goFormat))
	}

	// Map iteration order is random, so sort the model names to keep the
	// dataset order and color assignment stable between calls.
	modelNames := make([]string, 0, len(datasets))
	for name := range datasets {
		modelNames = append(modelNames, name)
	}
	sort.Strings(modelNames)

	chartData := &models.ChartData{Labels: labels, Datasets: make([]models.ChartDataset, 0, len(modelNames))}
	colorPalette := []string{"#FF6384", "#36A2EB", "#FFCE56", "#4BC0C0", "#9966FF", "#FF9F40"}
	for i, modelName := range modelNames {
		dataPoints := datasets[modelName]
		dataArray := make([]int64, len(labels))
		for j, label := range labels {
			dataArray[j] = dataPoints[label]
		}
		chartData.Datasets = append(chartData.Datasets, models.ChartDataset{
			Label: modelName,
			Data:  dataArray,
			Color: colorPalette[i%len(colorPalette)],
		})
	}
	return chartData, nil
}
// loadOverviewData assembles the dashboard overview from the database. It is
// the loader handed to the overview cache syncer.
//
// It fills, in order:
//   - KeyStatusCount: GroupAPIKeyMapping rows counted per status;
//   - MasterStatusCount and KeyCount: APIKey rows counted per master_status,
//     with every status other than MasterStatusActive counted as non-active;
//   - RequestCounts: "1m", "1h" and "1d" from RequestLog, "30d" from StatsHourly;
//   - UpstreamHealthStatus: status of every UpstreamEndpoint keyed by URL.
//
// A failure of either status aggregation aborts the load with an error. The
// request counters and upstream statuses are best effort: a failing query is
// logged as a warning, the affected value keeps what the query left in it
// (normally 0 / empty), and loading continues.
func (s *DashboardQueryService) loadOverviewData() (*models.DashboardStatsResponse, error) {
	s.logger.Info("[CacheSyncer] Starting to load overview data from database...")
	startTime := time.Now()
	resp := &models.DashboardStatsResponse{
		KeyStatusCount:       make(map[models.APIKeyStatus]int64),
		MasterStatusCount:    make(map[models.MasterAPIKeyStatus]int64),
		KeyCount:             models.StatCard{}, // make sure KeyCount is an empty struct rather than nil
		RequestCount24h:      models.StatCard{}, // same as above
		TokenCount:           make(map[string]any),
		UpstreamHealthStatus: make(map[string]string),
		RPM:                  models.StatCard{},
		RequestCounts:        make(map[string]int64),
	}

	// --- Operational status, aggregated from the group/key mappings ---
	type MappingStatusResult struct {
		Status models.APIKeyStatus
		Count  int64
	}
	var mappingStatusResults []MappingStatusResult
	if err := s.db.Model(&models.GroupAPIKeyMapping{}).Select("status, count(*) as count").Group("status").Find(&mappingStatusResults).Error; err != nil {
		return nil, fmt.Errorf("failed to query mapping status stats: %w", err)
	}
	for _, res := range mappingStatusResults {
		resp.KeyStatusCount[res.Status] = res.Count
	}

	// --- Master status, aggregated from the API keys ---
	type MasterStatusResult struct {
		Status models.MasterAPIKeyStatus
		Count  int64
	}
	var masterStatusResults []MasterStatusResult
	if err := s.db.Model(&models.APIKey{}).Select("master_status as status, count(*) as count").Group("master_status").Find(&masterStatusResults).Error; err != nil {
		return nil, fmt.Errorf("failed to query master status stats: %w", err)
	}
	var totalKeys, invalidKeys int64
	for _, res := range masterStatusResults {
		resp.MasterStatusCount[res.Status] = res.Count
		totalKeys += res.Count
		if res.Status != models.MasterStatusActive {
			invalidKeys += res.Count
		}
	}
	resp.KeyCount = models.StatCard{Value: float64(totalKeys), SubValue: invalidKeys, SubValueTip: "非活跃身份密钥数"}

	// --- Request counters ---
	now := time.Now()

	// countLogsSince counts the RequestLog rows at or after since. Errors are
	// logged rather than returned so one failing counter does not discard the
	// rest of the overview.
	countLogsSince := func(window string, since time.Time) int64 {
		var n int64
		if err := s.db.Model(&models.RequestLog{}).Where("request_time >= ?", since).Count(&n).Error; err != nil {
			s.logger.WithError(err).Warnf("Failed to count requests for the %s window.", window)
		}
		return n
	}

	// Short windows are counted exactly from request_logs.
	// 1m / 1h: the last minute / hour counted back from now.
	resp.RequestCounts["1m"] = countLogsSince("1m", now.Add(-1*time.Minute))
	resp.RequestCounts["1h"] = countLogsSince("1h", now.Add(-1*time.Hour))
	// 1d: from today's midnight (UTC) until now.
	year, month, day := now.UTC().Date()
	startOfDay := time.Date(year, month, day, 0, 0, 0, 0, time.UTC)
	resp.RequestCounts["1d"] = countLogsSince("1d", startOfDay)

	// 30d: summed from the pre-aggregated stats_hourly rows instead of
	// counting individual log rows.
	var count30d int64
	if err := s.db.Model(&models.StatsHourly{}).Where("time >= ?", now.AddDate(0, 0, -30)).Select("COALESCE(SUM(request_count), 0)").Scan(&count30d).Error; err != nil {
		s.logger.WithError(err).Warn("Failed to sum requests for the 30d window.")
	}
	resp.RequestCounts["30d"] = count30d

	// --- Upstream health ---
	var upstreams []*models.UpstreamEndpoint
	if err := s.db.Find(&upstreams).Error; err != nil {
		s.logger.WithError(err).Warn("Failed to load upstream statuses for dashboard.")
	} else {
		for _, u := range upstreams {
			resp.UpstreamHealthStatus[u.URL] = u.Status
		}
	}

	duration := time.Since(startTime)
	s.logger.Infof("[CacheSyncer] Successfully finished loading overview data in %s.", duration)
	return resp, nil
}
// GetRequestStatsForPeriod counts the RequestLog rows of a recent period and
// splits them into successes and failures.
//
// Accepted periods are "1m" (last minute), "1h" (last hour) and "1d" (since
// today's midnight in UTC); any other value yields an error. The result has
// the keys "total", "success" and "failure", where failure is total minus
// success. A query error is returned unchanged.
func (s *DashboardQueryService) GetRequestStatsForPeriod(period string) (gin.H, error) {
	current := time.Now()

	var since time.Time
	switch period {
	case "1m":
		since = current.Add(-1 * time.Minute)
	case "1h":
		since = current.Add(-1 * time.Hour)
	case "1d":
		y, m, d := current.UTC().Date()
		since = time.Date(y, m, d, 0, 0, 0, 0, time.UTC)
	default:
		return nil, fmt.Errorf("invalid period specified: %s", period)
	}

	var counts struct {
		Total   int64
		Success int64
	}
	query := s.db.Model(&models.RequestLog{}).
		Select("count(*) as total, sum(case when is_success = true then 1 else 0 end) as success").
		Where("request_time >= ?", since)
	if err := query.Scan(&counts).Error; err != nil {
		return nil, err
	}

	stats := gin.H{
		"total":   counts.Total,
		"success": counts.Success,
		"failure": counts.Total - counts.Success,
	}
	return stats, nil
}
// buildTimeFormatSelectClause returns two values: the SQL expression that
// renders the time column as an hour-bucket label ("YYYY-MM-DD HH:00:00"), and
// the Go time layout that renders a time.Time as the same label text.
//
// The "mysql" dialect gets DATE_FORMAT; "sqlite" and every other dialect name
// get strftime.
// NOTE(review): strftime is SQLite syntax — confirm no other dialect reaches
// the fallback.
func (s *DashboardQueryService) buildTimeFormatSelectClause() (string, string) {
	const hourLayout = "2006-01-02 15:00:00"

	if s.db.Dialector.Name() == "mysql" {
		return "DATE_FORMAT(time, '%Y-%m-%d %H:00:00')", hourLayout
	}
	return "strftime('%Y-%m-%d %H:00:00', time)", hourLayout
}