// Filename: internal/service/dashboard_query_service.go

package service

import (
    "fmt"
    "gemini-balancer/internal/models"
    "gemini-balancer/internal/store"
    "gemini-balancer/internal/syncer"
    "strconv"
    "time"

    "github.com/gin-gonic/gin"
    "github.com/sirupsen/logrus"
    "gorm.io/gorm"
)

const overviewCacheChannel = "syncer:cache:dashboard_overview"

// DashboardQueryService serves all dashboard data queries consumed by the frontend.
type DashboardQueryService struct {
    db             *gorm.DB
    store          store.Store
    overviewSyncer *syncer.CacheSyncer[*models.DashboardStatsResponse]
    logger         *logrus.Entry
    stopChan       chan struct{}
}

// NewDashboardQueryService wires up the service and its overview CacheSyncer.
func NewDashboardQueryService(db *gorm.DB, s store.Store, logger *logrus.Logger) (*DashboardQueryService, error) {
    qs := &DashboardQueryService{
        db:       db,
        store:    s,
        logger:   logger.WithField("component", "DashboardQueryService"),
        stopChan: make(chan struct{}),
    }

    loader := qs.loadOverviewData
    overviewSyncer, err := syncer.NewCacheSyncer(loader, s, overviewCacheChannel)
    if err != nil {
        return nil, fmt.Errorf("failed to create overview cache syncer: %w", err)
    }
    qs.overviewSyncer = overviewSyncer
    return qs, nil
}
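
// Wiring sketch (illustrative only): the surrounding application is assumed to
// provide the *gorm.DB, store.Store and *logrus.Logger from its own setup code;
// this just shows the intended call order for this service.
//
//    qs, err := NewDashboardQueryService(db, st, logger)
//    if err != nil {
//        logger.Fatalf("dashboard query service: %v", err)
//    }
//    qs.Start()      // starts the background invalidation listener
//    defer qs.Stop() // closes stopChan and shuts the listener down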

// Start launches the background event listener.
func (s *DashboardQueryService) Start() {
    go s.eventListener()
    s.logger.Info("DashboardQueryService started and listening for invalidation events.")
}

// Stop signals the event listener to shut down.
func (s *DashboardQueryService) Stop() {
    close(s.stopChan)
    s.logger.Info("DashboardQueryService and its CacheSyncer have been stopped.")
}

// GetGroupStats returns per-key counters from the cache plus request and
// failure statistics for the last hour and the last 24 hours of a single group.
func (s *DashboardQueryService) GetGroupStats(groupID uint) (map[string]any, error) {
    statsKey := fmt.Sprintf("stats:group:%d", groupID)
    keyStatsMap, err := s.store.HGetAll(statsKey)
    if err != nil {
        s.logger.WithError(err).Errorf("Failed to get key stats from cache for group %d", groupID)
        return nil, fmt.Errorf("failed to get key stats from cache: %w", err)
    }
    keyStats := make(map[string]int64)
    for k, v := range keyStatsMap {
        // Non-numeric values are counted as zero.
        val, _ := strconv.ParseInt(v, 10, 64)
        keyStats[k] = val
    }

    now := time.Now()
    oneHourAgo := now.Add(-1 * time.Hour)
    twentyFourHoursAgo := now.Add(-24 * time.Hour)

    type requestStatsResult struct {
        TotalRequests   int64
        SuccessRequests int64
    }
    var last1Hour, last24Hours requestStatsResult
    s.db.Model(&models.StatsHourly{}).
        Select("SUM(request_count) as total_requests, SUM(success_count) as success_requests").
        Where("group_id = ? AND time >= ?", groupID, oneHourAgo).
        Scan(&last1Hour)
    s.db.Model(&models.StatsHourly{}).
        Select("SUM(request_count) as total_requests, SUM(success_count) as success_requests").
        Where("group_id = ? AND time >= ?", groupID, twentyFourHoursAgo).
        Scan(&last24Hours)

    failureRate1h := 0.0
    if last1Hour.TotalRequests > 0 {
        failureRate1h = float64(last1Hour.TotalRequests-last1Hour.SuccessRequests) / float64(last1Hour.TotalRequests) * 100
    }
    failureRate24h := 0.0
    if last24Hours.TotalRequests > 0 {
        failureRate24h = float64(last24Hours.TotalRequests-last24Hours.SuccessRequests) / float64(last24Hours.TotalRequests) * 100
    }

    last1HourStats := map[string]any{
        "total_requests":   last1Hour.TotalRequests,
        "success_requests": last1Hour.SuccessRequests,
        "failure_rate":     failureRate1h,
    }
    last24HoursStats := map[string]any{
        "total_requests":   last24Hours.TotalRequests,
        "success_requests": last24Hours.SuccessRequests,
        "failure_rate":     failureRate24h,
    }
    result := map[string]any{
        "key_stats":     keyStats,
        "last_1_hour":   last1HourStats,
        "last_24_hours": last24HoursStats,
    }
    return result, nil
}
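
// For illustration only, the map returned above has roughly this shape once
// serialized (the numbers are made up, and the field names inside "key_stats"
// depend on what the store records per group):
//
//    {
//        "key_stats":     {"active": 12, "cooling_down": 1},
//        "last_1_hour":   {"total_requests": 420, "success_requests": 400, "failure_rate": 4.76},
//        "last_24_hours": {"total_requests": 9000, "success_requests": 8820, "failure_rate": 2.0}
//    }
//
// failure_rate is a percentage: (total_requests - success_requests) / total_requests * 100,
// e.g. (420 - 400) / 420 * 100 ≈ 4.76.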

// eventListener invalidates the overview cache whenever key status or upstream
// health change events are published, until Stop is called.
func (s *DashboardQueryService) eventListener() {
    keyStatusSub, err := s.store.Subscribe(models.TopicKeyStatusChanged)
    if err != nil {
        s.logger.WithError(err).Error("Failed to subscribe to key status events; overview cache invalidation is disabled.")
        return
    }
    defer keyStatusSub.Close()

    upstreamStatusSub, err := s.store.Subscribe(models.TopicUpstreamHealthChanged)
    if err != nil {
        s.logger.WithError(err).Error("Failed to subscribe to upstream health events; overview cache invalidation is disabled.")
        return
    }
    defer upstreamStatusSub.Close()

    for {
        select {
        case <-keyStatusSub.Channel():
            s.logger.Info("Received key status changed event, invalidating overview cache...")
            _ = s.InvalidateOverviewCache()
        case <-upstreamStatusSub.Channel():
            s.logger.Info("Received upstream status changed event, invalidating overview cache...")
            _ = s.InvalidateOverviewCache()
        case <-s.stopChan:
            s.logger.Info("Stopping dashboard event listener.")
            return
        }
    }
}

// GetDashboardOverviewData serves the dashboard overview data straight from the
// CacheSyncer's cache, avoiding a database round trip.
func (s *DashboardQueryService) GetDashboardOverviewData() (*models.DashboardStatsResponse, error) {
    cachedDataPtr := s.overviewSyncer.Get()
    if cachedDataPtr == nil {
        return &models.DashboardStatsResponse{}, fmt.Errorf("overview cache is not available or still syncing")
    }
    return cachedDataPtr, nil
}

// InvalidateOverviewCache asks the overview CacheSyncer to discard its cached data.
func (s *DashboardQueryService) InvalidateOverviewCache() error {
    return s.overviewSyncer.Invalidate()
}
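
// Minimal HTTP handler sketch (the handler name and status-code handling are
// illustrative, not an existing route in this project; gin is already imported
// by this file):
//
//    func (s *DashboardQueryService) handleOverview(c *gin.Context) {
//        data, err := s.GetDashboardOverviewData()
//        if err != nil {
//            // Cache not warmed up yet; report the endpoint as temporarily unavailable.
//            c.JSON(503, gin.H{"error": err.Error()})
//            return
//        }
//        c.JSON(200, data)
//    }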

// QueryHistoricalChart queries historical chart data: hourly request counts per
// model for the last seven days, optionally filtered by group.
func (s *DashboardQueryService) QueryHistoricalChart(groupID *uint) (*models.ChartData, error) {
    type ChartPoint struct {
        TimeLabel     string `gorm:"column:time_label"`
        ModelName     string `gorm:"column:model_name"`
        TotalRequests int64  `gorm:"column:total_requests"`
    }

    sevenDaysAgo := time.Now().Add(-24 * 7 * time.Hour).Truncate(time.Hour)
    sqlFormat, goFormat := s.buildTimeFormatSelectClause()
    selectClause := fmt.Sprintf("%s as time_label, model_name, SUM(request_count) as total_requests", sqlFormat)
    query := s.db.Model(&models.StatsHourly{}).
        Select(selectClause).
        Where("time >= ?", sevenDaysAgo).
        Group("time_label, model_name").
        Order("time_label ASC")
    if groupID != nil && *groupID > 0 {
        query = query.Where("group_id = ?", *groupID)
    }

    var points []ChartPoint
    if err := query.Find(&points).Error; err != nil {
        return nil, err
    }

    // Pivot the flat result set into model -> (time label -> count).
    datasets := make(map[string]map[string]int64)
    for _, p := range points {
        if _, ok := datasets[p.ModelName]; !ok {
            datasets[p.ModelName] = make(map[string]int64)
        }
        datasets[p.ModelName][p.TimeLabel] = p.TotalRequests
    }

    // Generate one label per hourly bucket so every dataset gets a value (or
    // zero) for every bucket, even when a model had no traffic in that hour.
    var labels []string
    for t := sevenDaysAgo; t.Before(time.Now()); t = t.Add(time.Hour) {
        labels = append(labels, t.Format(goFormat))
    }

    chartData := &models.ChartData{Labels: labels, Datasets: make([]models.ChartDataset, 0)}
    colorPalette := []string{"#FF6384", "#36A2EB", "#FFCE56", "#4BC0C0", "#9966FF", "#FF9F40"}
    colorIndex := 0
    for modelName, dataPoints := range datasets {
        dataArray := make([]int64, len(labels))
        for i, label := range labels {
            dataArray[i] = dataPoints[label]
        }
        chartData.Datasets = append(chartData.Datasets, models.ChartDataset{
            Label: modelName,
            Data:  dataArray,
            Color: colorPalette[colorIndex%len(colorPalette)],
        })
        colorIndex++
    }
    return chartData, nil
}
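
// Shape of the result, for illustration (model names and counts are invented):
// Labels holds one entry per hourly bucket, and each dataset's Data slice is
// aligned index-for-index with Labels, with 0 where that model had no traffic.
//
//    Labels:   ["2024-05-01 13:00:00", "2024-05-01 14:00:00", ...]
//    Datasets: [{Label: "model-a", Data: [42, 0, ...], Color: "#FF6384"}, ...]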

func (s *DashboardQueryService) loadOverviewData() (*models.DashboardStatsResponse, error) {
    s.logger.Info("[CacheSyncer] Starting to load overview data from database...")
    startTime := time.Now()

    resp := &models.DashboardStatsResponse{
        KeyStatusCount:       make(map[models.APIKeyStatus]int64),
        MasterStatusCount:    make(map[models.MasterAPIKeyStatus]int64),
        KeyCount:             models.StatCard{}, // Ensure KeyCount is an empty struct rather than nil.
        RequestCount24h:      models.StatCard{}, // Same as above.
        TokenCount:           make(map[string]any),
        UpstreamHealthStatus: make(map[string]string),
        RPM:                  models.StatCard{},
        RequestCounts:        make(map[string]int64),
    }

    // --- 1. Aggregate Operational Status from Mappings ---
    type MappingStatusResult struct {
        Status models.APIKeyStatus
        Count  int64
    }
    var mappingStatusResults []MappingStatusResult
    if err := s.db.Model(&models.GroupAPIKeyMapping{}).Select("status, count(*) as count").Group("status").Find(&mappingStatusResults).Error; err != nil {
        return nil, fmt.Errorf("failed to query mapping status stats: %w", err)
    }
    for _, res := range mappingStatusResults {
        resp.KeyStatusCount[res.Status] = res.Count
    }

    // --- 2. Aggregate Master Status from APIKeys ---
    type MasterStatusResult struct {
        Status models.MasterAPIKeyStatus
        Count  int64
    }
    var masterStatusResults []MasterStatusResult
    if err := s.db.Model(&models.APIKey{}).Select("master_status as status, count(*) as count").Group("master_status").Find(&masterStatusResults).Error; err != nil {
        return nil, fmt.Errorf("failed to query master status stats: %w", err)
    }
    var totalKeys, invalidKeys int64
    for _, res := range masterStatusResults {
        resp.MasterStatusCount[res.Status] = res.Count
        totalKeys += res.Count
        if res.Status != models.MasterStatusActive {
            invalidKeys += res.Count
        }
    }
    resp.KeyCount = models.StatCard{Value: float64(totalKeys), SubValue: invalidKeys, SubValueTip: "非活跃身份密钥数"} // Tip text: "number of inactive identity keys".

    now := time.Now()

    // --- 3. Request counters ---
    // RPM (1 minute), RPH (1 hour) and RPD (today) are counted precisely from
    // the "short-term memory" table (request_logs).
    var count1m, count1h, count1d int64
    // RPM: the last minute, counted back from now.
    s.db.Model(&models.RequestLog{}).Where("request_time >= ?", now.Add(-1*time.Minute)).Count(&count1m)
    // RPH: the last hour, counted back from now.
    s.db.Model(&models.RequestLog{}).Where("request_time >= ?", now.Add(-1*time.Hour)).Count(&count1h)

    // RPD: from midnight today (UTC) until now.
    year, month, day := now.UTC().Date()
    startOfDay := time.Date(year, month, day, 0, 0, 0, 0, time.UTC)
    s.db.Model(&models.RequestLog{}).Where("request_time >= ?", startOfDay).Count(&count1d)

    // RP30D (30 days) is summed from the pre-aggregated "long-term memory"
    // table (stats_hourly) to keep this query cheap.
    var count30d int64
    s.db.Model(&models.StatsHourly{}).Where("time >= ?", now.AddDate(0, 0, -30)).Select("COALESCE(SUM(request_count), 0)").Scan(&count30d)

    resp.RequestCounts["1m"] = count1m
    resp.RequestCounts["1h"] = count1h
    resp.RequestCounts["1d"] = count1d
    resp.RequestCounts["30d"] = count30d

    // --- 4. Upstream health status ---
    var upstreams []*models.UpstreamEndpoint
    if err := s.db.Find(&upstreams).Error; err != nil {
        s.logger.WithError(err).Warn("Failed to load upstream statuses for dashboard.")
    } else {
        for _, u := range upstreams {
            resp.UpstreamHealthStatus[u.URL] = u.Status
        }
    }

    duration := time.Since(startTime)
    s.logger.Infof("[CacheSyncer] Successfully finished loading overview data in %s.", duration)
    return resp, nil
}

// GetRequestStatsForPeriod returns total, success and failure request counts
// for one of the supported periods: "1m", "1h" or "1d" (since midnight UTC).
func (s *DashboardQueryService) GetRequestStatsForPeriod(period string) (gin.H, error) {
    var startTime time.Time
    now := time.Now()
    switch period {
    case "1m":
        startTime = now.Add(-1 * time.Minute)
    case "1h":
        startTime = now.Add(-1 * time.Hour)
    case "1d":
        year, month, day := now.UTC().Date()
        startTime = time.Date(year, month, day, 0, 0, 0, 0, time.UTC)
    default:
        return nil, fmt.Errorf("invalid period specified: %s", period)
    }

    var result struct {
        Total   int64
        Success int64
    }

    err := s.db.Model(&models.RequestLog{}).
        Select("count(*) as total, sum(case when is_success = true then 1 else 0 end) as success").
        Where("request_time >= ?", startTime).
        Scan(&result).Error
    if err != nil {
        return nil, err
    }
    return gin.H{
        "total":   result.Total,
        "success": result.Success,
        "failure": result.Total - result.Success,
    }, nil
}

// buildTimeFormatSelectClause returns a dialect-specific SQL expression that
// truncates the `time` column to the hour, together with the matching Go time
// layout, so labels generated in SQL and in Go compare as equal strings.
func (s *DashboardQueryService) buildTimeFormatSelectClause() (string, string) {
    dialect := s.db.Dialector.Name()
    switch dialect {
    case "mysql":
        return "DATE_FORMAT(time, '%Y-%m-%d %H:00:00')", "2006-01-02 15:00:00"
    case "sqlite":
        return "strftime('%Y-%m-%d %H:00:00', time)", "2006-01-02 15:00:00"
    default:
        // Fall back to the SQLite-style expression for unrecognized dialects.
        return "strftime('%Y-%m-%d %H:00:00', time)", "2006-01-02 15:00:00"
    }
}
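
// Illustrative check of the label formats above (a sketch one might adapt into
// a test; the test name and placement are hypothetical):
//
//    func TestHourlyLabelFormat(t *testing.T) {
//        ts := time.Date(2024, 5, 1, 13, 37, 0, 0, time.UTC)
//        // The MySQL DATE_FORMAT and SQLite strftime expressions above and the
//        // Go layout "2006-01-02 15:00:00" should all render this hour identically.
//        if got := ts.Truncate(time.Hour).Format("2006-01-02 15:00:00"); got != "2024-05-01 13:00:00" {
//            t.Fatalf("unexpected label: %s", got)
//        }
//    }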