// Filename: internal/service/dashboard_query_service.go
package service

import (
    "context"
    "fmt"
    "strconv"
    "sync"
    "sync/atomic"
    "time"

    "gemini-balancer/internal/models"
    "gemini-balancer/internal/store"
    "gemini-balancer/internal/syncer"

    "github.com/gin-gonic/gin"
    "github.com/sirupsen/logrus"
    "gorm.io/gorm"
)

const (
    overviewCacheChannel = "syncer:cache:dashboard_overview"
    defaultChartDays     = 7
    cacheLoadTimeout     = 30 * time.Second
)

var (
    // Chart color palette.
    chartColorPalette = []string{
        "#FF6384", "#36A2EB", "#FFCE56", "#4BC0C0",
        "#9966FF", "#FF9F40", "#C9CBCF", "#4D5360",
    }
)

type DashboardQueryService struct {
    db             *gorm.DB
    store          store.Store
    overviewSyncer *syncer.CacheSyncer[*models.DashboardStatsResponse]
    logger         *logrus.Entry
    stopChan       chan struct{}
    wg             sync.WaitGroup
    ctx            context.Context
    cancel         context.CancelFunc

    // Service metrics.
    queryCount        atomic.Uint64
    cacheHits         atomic.Uint64
    cacheMisses       atomic.Uint64
    overviewLoadCount atomic.Uint64
    lastQueryTime     time.Time
    lastQueryMutex    sync.RWMutex
}

func NewDashboardQueryService(
    db *gorm.DB,
    s store.Store,
    logger *logrus.Logger,
) (*DashboardQueryService, error) {
    ctx, cancel := context.WithCancel(context.Background())

    service := &DashboardQueryService{
        db:            db,
        store:         s,
        logger:        logger.WithField("component", "DashboardQuery📈"),
        stopChan:      make(chan struct{}),
        ctx:           ctx,
        cancel:        cancel,
        lastQueryTime: time.Now(),
    }

    // Create the CacheSyncer for the overview data.
    overviewSyncer, err := syncer.NewCacheSyncer(
        service.loadOverviewData,
        s,
        overviewCacheChannel,
        logger,
    )
    if err != nil {
        cancel()
        return nil, fmt.Errorf("failed to create overview cache syncer: %w", err)
    }
    service.overviewSyncer = overviewSyncer

    return service, nil
}

func (s *DashboardQueryService) Start() {
    s.wg.Add(2)
    go s.eventListener()
    go s.metricsReporter()
    s.logger.Info("DashboardQueryService started")
}

func (s *DashboardQueryService) Stop() {
    s.logger.Info("DashboardQueryService stopping...")
    close(s.stopChan)
    s.cancel()
    s.wg.Wait()

    // Log final statistics.
    s.logger.WithFields(logrus.Fields{
        "total_queries":  s.queryCount.Load(),
        "cache_hits":     s.cacheHits.Load(),
        "cache_misses":   s.cacheMisses.Load(),
        "overview_loads": s.overviewLoadCount.Load(),
    }).Info("DashboardQueryService stopped")
}

// ==================== Core query methods ====================

// GetDashboardOverviewData returns the dashboard overview data (cached).
func (s *DashboardQueryService) GetDashboardOverviewData() (*models.DashboardStatsResponse, error) {
    s.queryCount.Add(1)

    cachedDataPtr := s.overviewSyncer.Get()
    if cachedDataPtr == nil {
        s.cacheMisses.Add(1)
        s.logger.Warn("Overview cache is empty, attempting to load...")

        // Trigger an immediate reload.
        if err := s.overviewSyncer.Invalidate(); err != nil {
            return nil, fmt.Errorf("failed to trigger cache reload: %w", err)
        }

        // Wait for the load to complete (up to 30 seconds).
        ctx, cancel := context.WithTimeout(context.Background(), cacheLoadTimeout)
        defer cancel()

        ticker := time.NewTicker(100 * time.Millisecond)
        defer ticker.Stop()

        for {
            select {
            case <-ticker.C:
                if data := s.overviewSyncer.Get(); data != nil {
                    s.cacheHits.Add(1)
                    return data, nil
                }
            case <-ctx.Done():
                return nil, fmt.Errorf("timeout waiting for overview cache to load")
            }
        }
    }

    s.cacheHits.Add(1)
    return cachedDataPtr, nil
}

// InvalidateOverviewCache manually invalidates the overview cache.
func (s *DashboardQueryService) InvalidateOverviewCache() error {
    s.logger.Info("Manually invalidating overview cache")
    return s.overviewSyncer.Invalidate()
}

// GetGroupStats returns statistics for the specified group.
func (s *DashboardQueryService) GetGroupStats(ctx context.Context, groupID uint) (map[string]any, error) {
    s.queryCount.Add(1)
    s.updateLastQueryTime()
    start := time.Now()

    // 1. Fetch key statistics from Redis.
    statsKey := fmt.Sprintf("stats:group:%d", groupID)
    keyStatsMap, err := s.store.HGetAll(ctx, statsKey)
    if err != nil {
        s.logger.WithError(err).Errorf("Failed to get key stats for group %d", groupID)
        return nil, fmt.Errorf("failed to get key stats from cache: %w", err)
    }

    keyStats := make(map[string]int64)
    for k, v := range keyStatsMap {
        val, _ := strconv.ParseInt(v, 10, 64)
        keyStats[k] = val
    }

    // 2. Query request statistics (UTC time).
    now := time.Now().UTC()
    oneHourAgo := now.Add(-1 * time.Hour)
    twentyFourHoursAgo := now.Add(-24 * time.Hour)

    type requestStatsResult struct {
        TotalRequests   int64
        SuccessRequests int64
    }

    var last1Hour, last24Hours requestStatsResult

    // Run the two range queries concurrently.
    var wg sync.WaitGroup
    errChan := make(chan error, 2)
    wg.Add(2)

    // Last 1 hour.
    go func() {
        defer wg.Done()
        if err := s.db.WithContext(ctx).Model(&models.StatsHourly{}).
            Select("COALESCE(SUM(request_count), 0) as total_requests, COALESCE(SUM(success_count), 0) as success_requests").
            Where("group_id = ? AND time >= ?", groupID, oneHourAgo).
            Scan(&last1Hour).Error; err != nil {
            errChan <- fmt.Errorf("failed to query 1h stats: %w", err)
        }
    }()

    // Last 24 hours.
    go func() {
        defer wg.Done()
        if err := s.db.WithContext(ctx).Model(&models.StatsHourly{}).
            Select("COALESCE(SUM(request_count), 0) as total_requests, COALESCE(SUM(success_count), 0) as success_requests").
            Where("group_id = ? AND time >= ?", groupID, twentyFourHoursAgo).
            Scan(&last24Hours).Error; err != nil {
            errChan <- fmt.Errorf("failed to query 24h stats: %w", err)
        }
    }()

    wg.Wait()
    close(errChan)

    // Check for errors.
    for err := range errChan {
        if err != nil {
            return nil, err
        }
    }

    // 3. Compute failure rates.
    failureRate1h := s.calculateFailureRate(last1Hour.TotalRequests, last1Hour.SuccessRequests)
    failureRate24h := s.calculateFailureRate(last24Hours.TotalRequests, last24Hours.SuccessRequests)

    result := map[string]any{
        "key_stats": keyStats,
        "last_1_hour": map[string]any{
            "total_requests":   last1Hour.TotalRequests,
            "success_requests": last1Hour.SuccessRequests,
            "failed_requests":  last1Hour.TotalRequests - last1Hour.SuccessRequests,
            "failure_rate":     failureRate1h,
        },
        "last_24_hours": map[string]any{
            "total_requests":   last24Hours.TotalRequests,
            "success_requests": last24Hours.SuccessRequests,
            "failed_requests":  last24Hours.TotalRequests - last24Hours.SuccessRequests,
            "failure_rate":     failureRate24h,
        },
    }

    duration := time.Since(start)
    s.logger.WithFields(logrus.Fields{
        "group_id": groupID,
        "duration": duration,
    }).Debug("Group stats query completed")

    return result, nil
}

// QueryHistoricalChart queries historical chart data.
func (s *DashboardQueryService) QueryHistoricalChart(ctx context.Context, groupID *uint) (*models.ChartData, error) {
    s.queryCount.Add(1)
    s.updateLastQueryTime()
    start := time.Now()

    type ChartPoint struct {
        TimeLabel     string `gorm:"column:time_label"`
        ModelName     string `gorm:"column:model_name"`
        TotalRequests int64  `gorm:"column:total_requests"`
    }

    // Query the last 7 days of data (UTC).
    sevenDaysAgo := time.Now().UTC().AddDate(0, 0, -defaultChartDays).Truncate(time.Hour)

    // Build the time-formatting clause for the current database dialect.
    sqlFormat, goFormat := s.buildTimeFormatSelectClause()
    selectClause := fmt.Sprintf(
        "%s as time_label, model_name, COALESCE(SUM(request_count), 0) as total_requests",
        sqlFormat,
    )

    // Build the query.
    query := s.db.WithContext(ctx).
        Model(&models.StatsHourly{}).
        Select(selectClause).
        Where("time >= ?", sevenDaysAgo).
        Group("time_label, model_name").
        Order("time_label ASC")

    if groupID != nil && *groupID > 0 {
        query = query.Where("group_id = ?", *groupID)
    }

    var points []ChartPoint
    if err := query.Find(&points).Error; err != nil {
        return nil, fmt.Errorf("failed to query chart data: %w", err)
    }

    // Build per-model datasets.
    datasets := make(map[string]map[string]int64)
    for _, p := range points {
        if _, ok := datasets[p.ModelName]; !ok {
            datasets[p.ModelName] = make(map[string]int64)
        }
        datasets[p.ModelName][p.TimeLabel] = p.TotalRequests
    }

    // Generate hourly time labels.
    var labels []string
    for t := sevenDaysAgo; t.Before(time.Now().UTC()); t = t.Add(time.Hour) {
        labels = append(labels, t.Format(goFormat))
    }

    // Assemble the chart data.
    chartData := &models.ChartData{
        Labels:   labels,
        Datasets: make([]models.ChartDataset, 0, len(datasets)),
    }

    colorIndex := 0
    for modelName, dataPoints := range datasets {
        dataArray := make([]int64, len(labels))
        for i, label := range labels {
            dataArray[i] = dataPoints[label]
        }
        chartData.Datasets = append(chartData.Datasets, models.ChartDataset{
            Label: modelName,
            Data:  dataArray,
            Color: chartColorPalette[colorIndex%len(chartColorPalette)],
        })
        colorIndex++
    }

    duration := time.Since(start)
    s.logger.WithFields(logrus.Fields{
        "group_id": groupID,
        "points":   len(points),
        "datasets": len(chartData.Datasets),
        "duration": duration,
    }).Debug("Historical chart query completed")

    return chartData, nil
}

// GetRequestStatsForPeriod returns request statistics for the specified period.
func (s *DashboardQueryService) GetRequestStatsForPeriod(ctx context.Context, period string) (gin.H, error) {
    s.queryCount.Add(1)
    s.updateLastQueryTime()

    var startTime time.Time
    now := time.Now().UTC()

    switch period {
    case "1m":
        startTime = now.Add(-1 * time.Minute)
    case "1h":
        startTime = now.Add(-1 * time.Hour)
    case "1d":
        year, month, day := now.Date()
        startTime = time.Date(year, month, day, 0, 0, 0, 0, time.UTC)
    default:
        return nil, fmt.Errorf("invalid period specified: %s (must be 1m, 1h, or 1d)", period)
    }

    var result struct {
        Total   int64
        Success int64
    }

    // COALESCE guards against a NULL SUM when no rows match the time range.
    err := s.db.WithContext(ctx).Model(&models.RequestLog{}).
        Select("COUNT(*) as total, COALESCE(SUM(CASE WHEN is_success = true THEN 1 ELSE 0 END), 0) as success").
        Where("request_time >= ?", startTime).
        Scan(&result).Error
    if err != nil {
        return nil, fmt.Errorf("failed to query request stats: %w", err)
    }

    return gin.H{
        "period":  period,
        "total":   result.Total,
        "success": result.Success,
        "failure": result.Total - result.Success,
    }, nil
}

// ==================== Internal methods ====================

// loadOverviewData loads the dashboard overview data (called by the CacheSyncer).
func (s *DashboardQueryService) loadOverviewData() (*models.DashboardStatsResponse, error) {
    ctx, cancel := context.WithTimeout(context.Background(), 60*time.Second)
    defer cancel()

    s.overviewLoadCount.Add(1)
    startTime := time.Now()
    s.logger.Info("Starting to load dashboard overview data...")

    resp := &models.DashboardStatsResponse{
        KeyStatusCount:       make(map[models.APIKeyStatus]int64),
        MasterStatusCount:    make(map[models.MasterAPIKeyStatus]int64),
        KeyCount:             models.StatCard{},
        RequestCount24h:      models.StatCard{},
        TokenCount:           make(map[string]any),
        UpstreamHealthStatus: make(map[string]string),
        RPM:                  models.StatCard{},
        RequestCounts:        make(map[string]int64),
    }

    var loadErr error
    var wg sync.WaitGroup
    errChan := make(chan error, 10)

    // 1. Key mapping status counts (concurrent).
    wg.Add(1)
    go func() {
        defer wg.Done()
        if err := s.loadMappingStatusStats(ctx, resp); err != nil {
            errChan <- fmt.Errorf("mapping stats: %w", err)
        }
    }()

    // 2. Master key status counts (concurrent).
    wg.Add(1)
    go func() {
        defer wg.Done()
        if err := s.loadMasterStatusStats(ctx, resp); err != nil {
            errChan <- fmt.Errorf("master stats: %w", err)
        }
    }()

    // 3. Request counts (concurrent).
    wg.Add(1)
    go func() {
        defer wg.Done()
        if err := s.loadRequestCounts(ctx, resp); err != nil {
            errChan <- fmt.Errorf("request counts: %w", err)
        }
    }()

    // 4. Upstream health status (concurrent).
    wg.Add(1)
    go func() {
        defer wg.Done()
        if err := s.loadUpstreamHealth(ctx, resp); err != nil {
            // An upstream health failure must not block the overall load.
            s.logger.WithError(err).Warn("Failed to load upstream health status")
        }
    }()

    // Wait for all loaders to finish.
    wg.Wait()
    close(errChan)

    // Collect errors.
    for err := range errChan {
        if err != nil {
            loadErr = err
            break
        }
    }

    if loadErr != nil {
        s.logger.WithError(loadErr).Error("Failed to load overview data")
        return nil, loadErr
    }

    duration := time.Since(startTime)
    s.logger.WithFields(logrus.Fields{
        "duration":    duration,
        "total_keys":  resp.KeyCount.Value,
        "requests_1d": resp.RequestCounts["1d"],
        "upstreams":   len(resp.UpstreamHealthStatus),
    }).Info("Successfully loaded dashboard overview data")

    return resp, nil
}

// loadMappingStatusStats loads key mapping status counts.
func (s *DashboardQueryService) loadMappingStatusStats(ctx context.Context, resp *models.DashboardStatsResponse) error {
    type MappingStatusResult struct {
        Status models.APIKeyStatus
        Count  int64
    }

    var results []MappingStatusResult
    if err := s.db.WithContext(ctx).
        Model(&models.GroupAPIKeyMapping{}).
        Select("status, COUNT(*) as count").
        Group("status").
        Find(&results).Error; err != nil {
        return err
    }

    for _, res := range results {
        resp.KeyStatusCount[res.Status] = res.Count
    }
    return nil
}

// loadMasterStatusStats loads master key status counts.
func (s *DashboardQueryService) loadMasterStatusStats(ctx context.Context, resp *models.DashboardStatsResponse) error {
    type MasterStatusResult struct {
        Status models.MasterAPIKeyStatus
        Count  int64
    }

    var results []MasterStatusResult
    if err := s.db.WithContext(ctx).
        Model(&models.APIKey{}).
        Select("master_status as status, COUNT(*) as count").
        Group("master_status").
        Find(&results).Error; err != nil {
        return err
    }

    var totalKeys, invalidKeys int64
    for _, res := range results {
        resp.MasterStatusCount[res.Status] = res.Count
        totalKeys += res.Count
        if res.Status != models.MasterStatusActive {
            invalidKeys += res.Count
        }
    }

    resp.KeyCount = models.StatCard{
        Value:       float64(totalKeys),
        SubValue:    invalidKeys,
        SubValueTip: "Number of inactive master keys",
    }
    return nil
}

// loadRequestCounts loads request count statistics.
func (s *DashboardQueryService) loadRequestCounts(ctx context.Context, resp *models.DashboardStatsResponse) error {
    now := time.Now().UTC()

    // Short-range counts come from the RequestLog table.
    var count1m, count1h int64

    // Last 1 minute.
    if err := s.db.WithContext(ctx).
        Model(&models.RequestLog{}).
        Where("request_time >= ?", now.Add(-1*time.Minute)).
        Count(&count1m).Error; err != nil {
        return fmt.Errorf("1m count: %w", err)
    }

    // Last 1 hour.
    if err := s.db.WithContext(ctx).
        Model(&models.RequestLog{}).
        Where("request_time >= ?", now.Add(-1*time.Hour)).
        Count(&count1h).Error; err != nil {
        return fmt.Errorf("1h count: %w", err)
    }

    // Today (UTC).
    year, month, day := now.Date()
    startOfDay := time.Date(year, month, day, 0, 0, 0, 0, time.UTC)
    var count1d int64
    if err := s.db.WithContext(ctx).
        Model(&models.RequestLog{}).
        Where("request_time >= ?", startOfDay).
        Count(&count1d).Error; err != nil {
        return fmt.Errorf("1d count: %w", err)
    }

    // Last 30 days from the hourly aggregation table.
    var count30d int64
    if err := s.db.WithContext(ctx).
        Model(&models.StatsHourly{}).
        Where("time >= ?", now.AddDate(0, 0, -30)).
        Select("COALESCE(SUM(request_count), 0)").
        Scan(&count30d).Error; err != nil {
        return fmt.Errorf("30d count: %w", err)
    }

    resp.RequestCounts["1m"] = count1m
    resp.RequestCounts["1h"] = count1h
    resp.RequestCounts["1d"] = count1d
    resp.RequestCounts["30d"] = count30d
    return nil
}

// loadUpstreamHealth loads upstream health status.
func (s *DashboardQueryService) loadUpstreamHealth(ctx context.Context, resp *models.DashboardStatsResponse) error {
    var upstreams []*models.UpstreamEndpoint
    if err := s.db.WithContext(ctx).Find(&upstreams).Error; err != nil {
        return err
    }

    for _, u := range upstreams {
        resp.UpstreamHealthStatus[u.URL] = u.Status
    }
    return nil
}

// ==================== Event listening ====================

// eventListener listens for cache-invalidation events.
func (s *DashboardQueryService) eventListener() {
    defer s.wg.Done()

    // Subscribe to events.
    keyStatusSub, err1 := s.store.Subscribe(s.ctx, models.TopicKeyStatusChanged)
    upstreamStatusSub, err2 := s.store.Subscribe(s.ctx, models.TopicUpstreamHealthChanged)

    // Handle subscription errors.
    if err1 != nil {
        s.logger.WithError(err1).Error("Failed to subscribe to key status events")
        keyStatusSub = nil
    }
    if err2 != nil {
        s.logger.WithError(err2).Error("Failed to subscribe to upstream status events")
        upstreamStatusSub = nil
    }

    // If every subscription failed, there is nothing to listen to.
    if keyStatusSub == nil && upstreamStatusSub == nil {
        s.logger.Error("All event subscriptions failed, listener disabled")
        return
    }

    // Close the surviving subscriptions on exit.
    defer func() {
        if keyStatusSub != nil {
            if err := keyStatusSub.Close(); err != nil {
                s.logger.WithError(err).Warn("Failed to close key status subscription")
            }
        }
        if upstreamStatusSub != nil {
            if err := upstreamStatusSub.Close(); err != nil {
                s.logger.WithError(err).Warn("Failed to close upstream status subscription")
            }
        }
    }()

    s.logger.WithFields(logrus.Fields{
        "key_status_sub":      keyStatusSub != nil,
        "upstream_status_sub": upstreamStatusSub != nil,
    }).Info("Event listener started")

    for {
        // Select only the subscriptions that are still valid. A nil channel
        // blocks forever in a select, so an unavailable subscription simply
        // never fires its case.
        var keyStatusChan <-chan *store.Message
        if keyStatusSub != nil {
            keyStatusChan = keyStatusSub.Channel()
        }
        var upstreamStatusChan <-chan *store.Message
        if upstreamStatusSub != nil {
            upstreamStatusChan = upstreamStatusSub.Channel()
        }

        select {
        case _, ok := <-keyStatusChan:
            if !ok {
                s.logger.Warn("Key status channel closed")
                keyStatusSub = nil
                continue
            }
            s.logger.Debug("Received key status changed event")
            if err := s.InvalidateOverviewCache(); err != nil {
                s.logger.WithError(err).Warn("Failed to invalidate cache on key status change")
            }
        case _, ok := <-upstreamStatusChan:
            if !ok {
                s.logger.Warn("Upstream status channel closed")
                upstreamStatusSub = nil
                continue
            }
            s.logger.Debug("Received upstream status changed event")
            if err := s.InvalidateOverviewCache(); err != nil {
                s.logger.WithError(err).Warn("Failed to invalidate cache on upstream status change")
            }
        case <-s.stopChan:
            s.logger.Info("Event listener stopping (stopChan)")
            return
        case <-s.ctx.Done():
            s.logger.Info("Event listener stopping (context cancelled)")
            return
        }
    }
}

// ==================== Monitoring metrics ====================

// metricsReporter periodically logs service metrics.
func (s *DashboardQueryService) metricsReporter() {
    defer s.wg.Done()

    ticker := time.NewTicker(5 * time.Minute)
    defer ticker.Stop()

    for {
        select {
        case <-ticker.C:
            s.reportMetrics()
        case <-s.stopChan:
            return
        case <-s.ctx.Done():
            return
        }
    }
}

func (s *DashboardQueryService) reportMetrics() {
    s.lastQueryMutex.RLock()
    lastQuery := s.lastQueryTime
    s.lastQueryMutex.RUnlock()

    totalQueries := s.queryCount.Load()
    hits := s.cacheHits.Load()
    misses := s.cacheMisses.Load()

    var cacheHitRate float64
    if hits+misses > 0 {
        cacheHitRate = float64(hits) / float64(hits+misses) * 100
    }

    s.logger.WithFields(logrus.Fields{
        "total_queries":  totalQueries,
        "cache_hits":     hits,
        "cache_misses":   misses,
        "cache_hit_rate": fmt.Sprintf("%.2f%%", cacheHitRate),
        "overview_loads": s.overviewLoadCount.Load(),
        "last_query_ago": time.Since(lastQuery).Round(time.Second),
    }).Info("DashboardQuery metrics")
}

// GetMetrics returns the current service metrics (for monitoring).
func (s *DashboardQueryService) GetMetrics() map[string]interface{} {
    s.lastQueryMutex.RLock()
    lastQuery := s.lastQueryTime
    s.lastQueryMutex.RUnlock()

    hits := s.cacheHits.Load()
    misses := s.cacheMisses.Load()
    var cacheHitRate float64
    if hits+misses > 0 {
        cacheHitRate = float64(hits) / float64(hits+misses) * 100
    }

    return map[string]interface{}{
        "total_queries":  s.queryCount.Load(),
        "cache_hits":     hits,
        "cache_misses":   misses,
        "cache_hit_rate": cacheHitRate,
        "overview_loads": s.overviewLoadCount.Load(),
        "last_query_ago": time.Since(lastQuery).Seconds(),
    }
}

// ==================== Helper methods ====================

// calculateFailureRate computes the failure rate as a percentage.
func (s *DashboardQueryService) calculateFailureRate(total, success int64) float64 {
    if total == 0 {
        return 0.0
    }
    return float64(total-success) / float64(total) * 100
}

// updateLastQueryTime records the time of the most recent query.
func (s *DashboardQueryService) updateLastQueryTime() {
    s.lastQueryMutex.Lock()
    s.lastQueryTime = time.Now()
    s.lastQueryMutex.Unlock()
}

// buildTimeFormatSelectClause returns the SQL expression used to bucket
// timestamps by hour and the matching Go time layout for the current
// database dialect.
func (s *DashboardQueryService) buildTimeFormatSelectClause() (string, string) {
    dialect := s.db.Dialector.Name()
    switch dialect {
    case "mysql":
        return "DATE_FORMAT(time, '%Y-%m-%d %H:00:00')", "2006-01-02 15:00:00"
    case "postgres":
        return "TO_CHAR(time, 'YYYY-MM-DD HH24:00:00')", "2006-01-02 15:00:00"
    case "sqlite":
        return "strftime('%Y-%m-%d %H:00:00', time)", "2006-01-02 15:00:00"
    default:
        s.logger.WithField("dialect", dialect).Warn("Unknown database dialect, using SQLite format")
        return "strftime('%Y-%m-%d %H:00:00', time)", "2006-01-02 15:00:00"
    }
}
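
// ==================== Usage sketch (illustrative) ====================

// wireDashboard is a minimal, illustrative sketch of how a caller might wire
// this service into an HTTP layer: construct it, start its background workers,
// and expose two of the query methods. The function name, route paths, and the
// assumption that a ready *gorm.DB and store.Store are available are
// illustrative only and not part of this file's original API; callers should
// also invoke svc.Stop() during graceful shutdown.
func wireDashboard(db *gorm.DB, st store.Store, router *gin.Engine) (*DashboardQueryService, error) {
    svc, err := NewDashboardQueryService(db, st, logrus.New())
    if err != nil {
        return nil, err
    }
    // Launches the event listener and metrics reporter goroutines.
    svc.Start()

    // Cached overview endpoint.
    router.GET("/dashboard/overview", func(c *gin.Context) {
        data, err := svc.GetDashboardOverviewData()
        if err != nil {
            c.JSON(500, gin.H{"error": err.Error()})
            return
        }
        c.JSON(200, data)
    })

    // Period stats endpoint, e.g. /dashboard/requests?period=1h
    router.GET("/dashboard/requests", func(c *gin.Context) {
        stats, err := svc.GetRequestStatsForPeriod(c.Request.Context(), c.DefaultQuery("period", "1d"))
        if err != nil {
            c.JSON(400, gin.H{"error": err.Error()})
            return
        }
        c.JSON(200, stats)
    })

    return svc, nil
}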