Files
gemini-banlancer/internal/service/dashboard_query_service.go

797 lines
21 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
// Filename: internal/service/dashboard_query_service.go
package service
import (
	"context"
	"fmt"
	"sort"
	"strconv"
	"sync"
	"sync/atomic"
	"time"

	"gemini-balancer/internal/models"
	"gemini-balancer/internal/store"
	"gemini-balancer/internal/syncer"

	"github.com/gin-gonic/gin"
	"github.com/sirupsen/logrus"
	"gorm.io/gorm"
)
const (
	// overviewCacheChannel is the channel name handed to the overview CacheSyncer.
	overviewCacheChannel = "syncer:cache:dashboard_overview"

	// defaultChartDays is how many days back QueryHistoricalChart looks.
	defaultChartDays = 7

	// cacheLoadTimeout bounds how long GetDashboardOverviewData waits for an
	// empty overview cache to be repopulated.
	cacheLoadTimeout = 30 * time.Second
)

// chartColorPalette lists the colors assigned to chart datasets in order;
// when there are more datasets than entries the colors are reused cyclically.
var chartColorPalette = []string{
	"#FF6384",
	"#36A2EB",
	"#FFCE56",
	"#4BC0C0",
	"#9966FF",
	"#FF9F40",
	"#C9CBCF",
	"#4D5360",
}
// DashboardQueryService serves the dashboard's read-side queries: a cached
// overview kept by a CacheSyncer, plus on-demand group, chart and request
// statistics read from the database and the store.
type DashboardQueryService struct {
	// Data sources and collaborators.
	db             *gorm.DB
	store          store.Store
	overviewSyncer *syncer.CacheSyncer[*models.DashboardStatsResponse]
	logger         *logrus.Entry

	// Lifecycle of the background goroutines launched by Start and joined by Stop.
	ctx      context.Context
	cancel   context.CancelFunc
	stopChan chan struct{}
	wg       sync.WaitGroup

	// Counters, updated atomically and reported by reportMetrics / GetMetrics.
	queryCount        atomic.Uint64
	cacheHits         atomic.Uint64
	cacheMisses       atomic.Uint64
	overviewLoadCount atomic.Uint64

	// Time of the most recent updateLastQueryTime call; guarded by lastQueryMutex.
	lastQueryTime  time.Time
	lastQueryMutex sync.RWMutex
}
// NewDashboardQueryService wires a DashboardQueryService to its database,
// store and logger, and creates the CacheSyncer that backs the overview cache
// (using loadOverviewData as its loader and overviewCacheChannel as its channel).
//
// The service's own background goroutines are not started here; call Start.
// If the CacheSyncer cannot be created, the service context is cancelled and
// the wrapped error is returned.
func NewDashboardQueryService(
	db *gorm.DB,
	s store.Store,
	logger *logrus.Logger,
) (*DashboardQueryService, error) {
	ctx, cancel := context.WithCancel(context.Background())

	svc := &DashboardQueryService{
		db:            db,
		store:         s,
		logger:        logger.WithField("component", "DashboardQuery📈"),
		stopChan:      make(chan struct{}),
		ctx:           ctx,
		cancel:        cancel,
		lastQueryTime: time.Now(),
	}

	cacheSyncer, err := syncer.NewCacheSyncer(svc.loadOverviewData, s, overviewCacheChannel, logger)
	if err != nil {
		// Release the context before bailing out; nothing else holds it yet.
		cancel()
		return nil, fmt.Errorf("failed to create overview cache syncer: %w", err)
	}

	svc.overviewSyncer = cacheSyncer
	return svc, nil
}
// Start launches the service's background goroutines: the cache-invalidation
// event listener and the periodic metrics reporter. Stop waits for both.
func (s *DashboardQueryService) Start() {
	background := []func(){
		s.eventListener,
		s.metricsReporter,
	}
	s.wg.Add(len(background))
	for _, run := range background {
		go run()
	}
	s.logger.Info("DashboardQueryService started")
}
// Stop shuts the service down. It signals the goroutines started by Start to
// exit, cancels the service context, waits for the goroutines, and logs the
// final counters.
//
// A repeated (sequential) call returns immediately instead of panicking on a
// second close of stopChan. Concurrent Stop calls are not synchronized.
func (s *DashboardQueryService) Stop() {
	select {
	case <-s.stopChan:
		// Already stopped; closing stopChan again would panic.
		return
	default:
	}

	s.logger.Info("DashboardQueryService stopping...")
	close(s.stopChan)
	s.cancel()
	s.wg.Wait()

	// Final counters.
	s.logger.WithFields(logrus.Fields{
		"total_queries":  s.queryCount.Load(),
		"cache_hits":     s.cacheHits.Load(),
		"cache_misses":   s.cacheMisses.Load(),
		"overview_loads": s.overviewLoadCount.Load(),
	}).Info("DashboardQueryService stopped")
}
// ==================== Core query methods ====================

// GetDashboardOverviewData returns the dashboard overview from the
// CacheSyncer-backed cache.
//
// On a cache miss it calls Invalidate on the syncer (relied on here to trigger
// a reload). It then polls the cache every 100ms until one of these happens:
//   - a value appears;
//   - cacheLoadTimeout elapses;
//   - the service is stopped.
//
// Each call is counted as exactly one cache hit or one cache miss.
func (s *DashboardQueryService) GetDashboardOverviewData() (*models.DashboardStatsResponse, error) {
	s.queryCount.Add(1)
	s.updateLastQueryTime()

	if cached := s.overviewSyncer.Get(); cached != nil {
		s.cacheHits.Add(1)
		return cached, nil
	}

	// Miss: count it once. A later successful poll is deliberately NOT also
	// counted as a hit, otherwise the reported hit rate would be inflated.
	s.cacheMisses.Add(1)
	s.logger.Warn("Overview cache is empty, attempting to load...")

	if err := s.overviewSyncer.Invalidate(); err != nil {
		return nil, fmt.Errorf("failed to trigger cache reload: %w", err)
	}

	// Parent on the service context so Stop() releases callers still waiting.
	ctx, cancel := context.WithTimeout(s.ctx, cacheLoadTimeout)
	defer cancel()

	ticker := time.NewTicker(100 * time.Millisecond)
	defer ticker.Stop()

	for {
		select {
		case <-ticker.C:
			if data := s.overviewSyncer.Get(); data != nil {
				return data, nil
			}
		case <-ctx.Done():
			if err := s.ctx.Err(); err != nil {
				return nil, fmt.Errorf("service stopped while waiting for overview cache to load: %w", err)
			}
			return nil, fmt.Errorf("timeout waiting for overview cache to load")
		}
	}
}
// InvalidateOverviewCache asks the overview CacheSyncer to invalidate its
// cached value and returns the syncer's error, if any. The event listener
// also calls it when key or upstream status changes.
func (s *DashboardQueryService) InvalidateOverviewCache() error {
	cache := s.overviewSyncer
	s.logger.Info("Manually invalidating overview cache")
	return cache.Invalidate()
}
// GetGroupStats returns per-key counters and recent request statistics for a
// single group.
//
// The result map contains:
//   - "key_stats": the fields of the store hash "stats:group:<groupID>",
//     each parsed as a base-10 int64. Unparseable values are logged and
//     reported as whatever strconv.ParseInt returned: 0 for a syntax error,
//     clamped for an out-of-range value.
//   - "last_1_hour" / "last_24_hours": total, successful and failed request
//     counts, plus the failure rate in percent. These are summed from
//     StatsHourly rows for the group whose time is within the window (UTC).
//
// The two database aggregates run concurrently. If either fails, one of the
// errors is returned.
func (s *DashboardQueryService) GetGroupStats(ctx context.Context, groupID uint) (map[string]any, error) {
	s.queryCount.Add(1)
	s.updateLastQueryTime()
	start := time.Now()

	// 1. Per-key counters kept in the store.
	statsKey := fmt.Sprintf("stats:group:%d", groupID)
	keyStatsMap, err := s.store.HGetAll(ctx, statsKey)
	if err != nil {
		s.logger.WithError(err).Errorf("Failed to get key stats for group %d", groupID)
		return nil, fmt.Errorf("failed to get key stats from cache: %w", err)
	}
	keyStats := make(map[string]int64, len(keyStatsMap))
	for field, raw := range keyStatsMap {
		val, parseErr := strconv.ParseInt(raw, 10, 64)
		if parseErr != nil {
			// Keep the value ParseInt produced (unchanged output) but make bad
			// data visible instead of swallowing the error.
			s.logger.WithError(parseErr).WithFields(logrus.Fields{
				"group_id": groupID,
				"field":    field,
				"value":    raw,
			}).Warn("Unparseable key stat value")
		}
		keyStats[field] = val
	}

	// 2. Request totals from the hourly aggregates (UTC windows).
	type requestStatsResult struct {
		TotalRequests   int64
		SuccessRequests int64
	}
	now := time.Now().UTC()
	var last1Hour, last24Hours requestStatsResult

	// sumSince aggregates this group's StatsHourly rows with time >= since into dst.
	sumSince := func(since time.Time, dst *requestStatsResult) error {
		return s.db.WithContext(ctx).Model(&models.StatsHourly{}).
			Select("COALESCE(SUM(request_count), 0) as total_requests, COALESCE(SUM(success_count), 0) as success_requests").
			Where("group_id = ? AND time >= ?", groupID, since).
			Scan(dst).Error
	}

	var wg sync.WaitGroup
	errChan := make(chan error, 2)
	wg.Add(2)
	go func() {
		defer wg.Done()
		if err := sumSince(now.Add(-1*time.Hour), &last1Hour); err != nil {
			errChan <- fmt.Errorf("failed to query 1h stats: %w", err)
		}
	}()
	go func() {
		defer wg.Done()
		if err := sumSince(now.Add(-24*time.Hour), &last24Hours); err != nil {
			errChan <- fmt.Errorf("failed to query 24h stats: %w", err)
		}
	}()
	wg.Wait()
	close(errChan)

	// The channel is closed, so this yields a buffered error or nil immediately.
	if err := <-errChan; err != nil {
		return nil, err
	}

	// 3. Assemble the response; failure rates are percentages (0 when no requests).
	result := map[string]any{
		"key_stats": keyStats,
		"last_1_hour": map[string]any{
			"total_requests":   last1Hour.TotalRequests,
			"success_requests": last1Hour.SuccessRequests,
			"failed_requests":  last1Hour.TotalRequests - last1Hour.SuccessRequests,
			"failure_rate":     s.calculateFailureRate(last1Hour.TotalRequests, last1Hour.SuccessRequests),
		},
		"last_24_hours": map[string]any{
			"total_requests":   last24Hours.TotalRequests,
			"success_requests": last24Hours.SuccessRequests,
			"failed_requests":  last24Hours.TotalRequests - last24Hours.SuccessRequests,
			"failure_rate":     s.calculateFailureRate(last24Hours.TotalRequests, last24Hours.SuccessRequests),
		},
	}

	s.logger.WithFields(logrus.Fields{
		"group_id": groupID,
		"duration": time.Since(start),
	}).Debug("Group stats query completed")

	return result, nil
}
// QueryHistoricalChart builds an hourly request chart for the last
// defaultChartDays days (UTC), with one dataset per model name.
//
// When groupID is non-nil and > 0, only that group's rows are used; otherwise
// all groups are included. Hours with no data are reported as 0. Datasets are
// ordered by model name, so each model keeps the same position and palette
// color from one call to the next.
func (s *DashboardQueryService) QueryHistoricalChart(ctx context.Context, groupID *uint) (*models.ChartData, error) {
	s.queryCount.Add(1)
	s.updateLastQueryTime()
	start := time.Now()

	type ChartPoint struct {
		TimeLabel     string `gorm:"column:time_label"`
		ModelName     string `gorm:"column:model_name"`
		TotalRequests int64  `gorm:"column:total_requests"`
	}

	// Window start: defaultChartDays ago, truncated to the hour (UTC).
	sevenDaysAgo := time.Now().UTC().AddDate(0, 0, -defaultChartDays).Truncate(time.Hour)

	// sqlFormat buckets rows into hourly labels in SQL. goFormat renders the
	// same label text in Go, so rows can be matched to labels by string.
	sqlFormat, goFormat := s.buildTimeFormatSelectClause()
	selectClause := fmt.Sprintf(
		"%s as time_label, model_name, COALESCE(SUM(request_count), 0) as total_requests",
		sqlFormat,
	)

	query := s.db.WithContext(ctx).
		Model(&models.StatsHourly{}).
		Select(selectClause).
		Where("time >= ?", sevenDaysAgo).
		Group("time_label, model_name").
		Order("time_label ASC")
	if groupID != nil && *groupID > 0 {
		query = query.Where("group_id = ?", *groupID)
	}

	var points []ChartPoint
	if err := query.Find(&points).Error; err != nil {
		return nil, fmt.Errorf("failed to query chart data: %w", err)
	}

	// model name -> time label -> request count
	datasets := make(map[string]map[string]int64)
	for _, p := range points {
		byLabel, ok := datasets[p.ModelName]
		if !ok {
			byLabel = make(map[string]int64)
			datasets[p.ModelName] = byLabel
		}
		byLabel[p.TimeLabel] = p.TotalRequests
	}

	// One label per hour, from the window start up to and including the current hour.
	labelEnd := time.Now().UTC() // hoisted: previously re-evaluated every iteration
	labels := make([]string, 0, defaultChartDays*24+1)
	for t := sevenDaysAgo; t.Before(labelEnd); t = t.Add(time.Hour) {
		labels = append(labels, t.Format(goFormat))
	}

	// Map iteration order is randomized; sort so dataset order and color
	// assignment are deterministic across requests.
	modelNames := make([]string, 0, len(datasets))
	for name := range datasets {
		modelNames = append(modelNames, name)
	}
	sort.Strings(modelNames)

	chartData := &models.ChartData{
		Labels:   labels,
		Datasets: make([]models.ChartDataset, 0, len(modelNames)),
	}
	for i, name := range modelNames {
		byLabel := datasets[name]
		values := make([]int64, len(labels))
		for j, label := range labels {
			values[j] = byLabel[label]
		}
		chartData.Datasets = append(chartData.Datasets, models.ChartDataset{
			Label: name,
			Data:  values,
			Color: chartColorPalette[i%len(chartColorPalette)],
		})
	}

	// Log the group id value rather than the pointer address; nil means "all groups".
	var loggedGroupID any
	if groupID != nil {
		loggedGroupID = *groupID
	}
	s.logger.WithFields(logrus.Fields{
		"group_id": loggedGroupID,
		"points":   len(points),
		"datasets": len(chartData.Datasets),
		"duration": time.Since(start),
	}).Debug("Historical chart query completed")

	return chartData, nil
}
// GetRequestStatsForPeriod counts RequestLog rows since the start of the given
// period, and how many of those succeeded.
//
// Supported periods (all in UTC):
//   - "1m": the last minute
//   - "1h": the last hour
//   - "1d": since midnight UTC today (a calendar day, not a rolling 24h)
//
// Any other period returns an error. The result holds "period", "total",
// "success" and "failure" (= total - success).
func (s *DashboardQueryService) GetRequestStatsForPeriod(ctx context.Context, period string) (gin.H, error) {
	s.queryCount.Add(1)
	s.updateLastQueryTime()

	now := time.Now().UTC()
	var startTime time.Time
	switch period {
	case "1m":
		startTime = now.Add(-1 * time.Minute)
	case "1h":
		startTime = now.Add(-1 * time.Hour)
	case "1d":
		year, month, day := now.Date()
		startTime = time.Date(year, month, day, 0, 0, 0, 0, time.UTC)
	default:
		return nil, fmt.Errorf("invalid period specified: %s (must be 1m, 1h, or 1d)", period)
	}

	var result struct {
		Total   int64
		Success int64
	}
	// SUM over zero rows yields NULL. COALESCE pins it to 0, like the other
	// aggregate queries in this service, so an empty period scans cleanly.
	err := s.db.WithContext(ctx).Model(&models.RequestLog{}).
		Select("COUNT(*) as total, COALESCE(SUM(CASE WHEN is_success = true THEN 1 ELSE 0 END), 0) as success").
		Where("request_time >= ?", startTime).
		Scan(&result).Error
	if err != nil {
		return nil, fmt.Errorf("failed to query request stats: %w", err)
	}

	return gin.H{
		"period":  period,
		"total":   result.Total,
		"success": result.Success,
		"failure": result.Total - result.Success,
	}, nil
}
// ==================== Internal methods ====================

// loadOverviewData builds a fresh dashboard overview. It is the loader passed
// to the overview CacheSyncer in NewDashboardQueryService.
//
// Three loads run concurrently: key-mapping status, master-key status and
// request counts. Any error from them fails the whole load, and one of the
// errors is returned. Upstream health is best-effort: a failure there is only
// logged. The load is bounded to 60s and is cancelled when the service is stopped.
func (s *DashboardQueryService) loadOverviewData() (*models.DashboardStatsResponse, error) {
	// Derive from the service context (not context.Background) so Stop()
	// aborts in-flight queries instead of letting them run for up to 60s.
	ctx, cancel := context.WithTimeout(s.ctx, 60*time.Second)
	defer cancel()

	s.overviewLoadCount.Add(1)
	startTime := time.Now()
	s.logger.Info("Starting to load dashboard overview data...")

	resp := &models.DashboardStatsResponse{
		KeyStatusCount:       make(map[models.APIKeyStatus]int64),
		MasterStatusCount:    make(map[models.MasterAPIKeyStatus]int64),
		KeyCount:             models.StatCard{},
		RequestCount24h:      models.StatCard{},
		TokenCount:           make(map[string]any),
		UpstreamHealthStatus: make(map[string]string),
		RPM:                  models.StatCard{},
		RequestCounts:        make(map[string]int64),
	}

	// Each loader below writes a different field of resp, so they can run
	// concurrently without sharing writes.
	required := []struct {
		name string
		load func(context.Context, *models.DashboardStatsResponse) error
	}{
		{"mapping stats", s.loadMappingStatusStats},
		{"master stats", s.loadMasterStatusStats},
		{"request counts", s.loadRequestCounts},
	}

	var wg sync.WaitGroup
	errChan := make(chan error, len(required))
	for _, task := range required {
		wg.Add(1)
		// Pass the task fields as arguments so the goroutine does not depend
		// on per-iteration loop variable semantics.
		go func(name string, load func(context.Context, *models.DashboardStatsResponse) error) {
			defer wg.Done()
			if err := load(ctx, resp); err != nil {
				errChan <- fmt.Errorf("%s: %w", name, err)
			}
		}(task.name, task.load)
	}

	// Upstream health is best-effort and never fails the overall load.
	wg.Add(1)
	go func() {
		defer wg.Done()
		if err := s.loadUpstreamHealth(ctx, resp); err != nil {
			s.logger.WithError(err).Warn("Failed to load upstream health status")
		}
	}()

	wg.Wait()
	close(errChan)

	// The channel is closed, so this yields a buffered error or nil immediately.
	if loadErr := <-errChan; loadErr != nil {
		s.logger.WithError(loadErr).Error("Failed to load overview data")
		return nil, loadErr
	}

	s.logger.WithFields(logrus.Fields{
		"duration":    time.Since(startTime),
		"total_keys":  resp.KeyCount.Value,
		"requests_1d": resp.RequestCounts["1d"],
		"upstreams":   len(resp.UpstreamHealthStatus),
	}).Info("Successfully loaded dashboard overview data")

	return resp, nil
}
// loadMappingStatusStats fills resp.KeyStatusCount with the number of
// GroupAPIKeyMapping rows per status. Statuses with no rows are left absent
// from the map. The database error, if any, is returned unwrapped.
func (s *DashboardQueryService) loadMappingStatusStats(ctx context.Context, resp *models.DashboardStatsResponse) error {
	type mappingRow struct {
		Status models.APIKeyStatus
		Count  int64
	}

	var rows []mappingRow
	err := s.db.WithContext(ctx).
		Model(&models.GroupAPIKeyMapping{}).
		Select("status, COUNT(*) as count").
		Group("status").
		Find(&rows).Error
	if err != nil {
		return err
	}

	for i := range rows {
		resp.KeyStatusCount[rows[i].Status] = rows[i].Count
	}
	return nil
}
// loadMasterStatusStats fills resp.MasterStatusCount with the number of APIKey
// rows per master_status. It also sets resp.KeyCount:
//   - Value: the total number of keys;
//   - SubValue: the number of keys whose status is not
//     models.MasterStatusActive;
//   - SubValueTip: a fixed Chinese label for that sub-value.
//
// The database error, if any, is returned unwrapped.
func (s *DashboardQueryService) loadMasterStatusStats(ctx context.Context, resp *models.DashboardStatsResponse) error {
	type masterRow struct {
		Status models.MasterAPIKeyStatus
		Count  int64
	}

	var rows []masterRow
	err := s.db.WithContext(ctx).
		Model(&models.APIKey{}).
		Select("master_status as status, COUNT(*) as count").
		Group("master_status").
		Find(&rows).Error
	if err != nil {
		return err
	}

	var totalKeys, inactiveKeys int64
	for _, row := range rows {
		resp.MasterStatusCount[row.Status] = row.Count
		totalKeys += row.Count
		if row.Status == models.MasterStatusActive {
			continue
		}
		inactiveKeys += row.Count
	}

	resp.KeyCount = models.StatCard{
		Value:       float64(totalKeys),
		SubValue:    inactiveKeys,
		SubValueTip: "非活跃身份密钥数",
	}
	return nil
}
// loadRequestCounts fills resp.RequestCounts with these keys:
//   - "1m", "1h": RequestLog rows in the last minute / hour.
//   - "1d": RequestLog rows since midnight UTC today.
//   - "30d": the summed request_count of StatsHourly rows from the last 30 days.
//
// The queries run sequentially. On the first failure the error is returned,
// prefixed with the window name, and resp is left untouched.
func (s *DashboardQueryService) loadRequestCounts(ctx context.Context, resp *models.DashboardStatsResponse) error {
	now := time.Now().UTC()
	y, m, d := now.Date()

	// Short windows are counted directly from the request log.
	windows := []struct {
		key   string
		since time.Time
	}{
		{"1m", now.Add(-1 * time.Minute)},
		{"1h", now.Add(-1 * time.Hour)},
		{"1d", time.Date(y, m, d, 0, 0, 0, 0, time.UTC)},
	}

	counts := make(map[string]int64, len(windows)+1)
	for _, w := range windows {
		var n int64
		if err := s.db.WithContext(ctx).
			Model(&models.RequestLog{}).
			Where("request_time >= ?", w.since).
			Count(&n).Error; err != nil {
			return fmt.Errorf("%s count: %w", w.key, err)
		}
		counts[w.key] = n
	}

	// The 30-day total comes from the hourly aggregate table.
	var total30d int64
	if err := s.db.WithContext(ctx).
		Model(&models.StatsHourly{}).
		Where("time >= ?", now.AddDate(0, 0, -30)).
		Select("COALESCE(SUM(request_count), 0)").
		Scan(&total30d).Error; err != nil {
		return fmt.Errorf("30d count: %w", err)
	}
	counts["30d"] = total30d

	// Publish only after every query succeeded.
	for key, n := range counts {
		resp.RequestCounts[key] = n
	}
	return nil
}
// loadUpstreamHealth loads every UpstreamEndpoint row and records each
// endpoint's Status in resp.UpstreamHealthStatus, keyed by its URL. The
// database error, if any, is returned unwrapped.
func (s *DashboardQueryService) loadUpstreamHealth(ctx context.Context, resp *models.DashboardStatsResponse) error {
	var endpoints []*models.UpstreamEndpoint
	err := s.db.WithContext(ctx).Find(&endpoints).Error
	if err != nil {
		return err
	}

	for _, ep := range endpoints {
		resp.UpstreamHealthStatus[ep.URL] = ep.Status
	}
	return nil
}
// ==================== Event listening ====================

// eventListener subscribes to the key-status-changed and
// upstream-health-changed topics. Whenever an event arrives it invalidates the
// overview cache. It runs until stopChan is closed or the service context is
// cancelled.
//
// A subscription that fails, or whose channel closes, is disabled by setting
// its receive channel to nil. A nil channel never becomes ready in a select,
// so the remaining cases keep working without busy-looping. (A closed channel
// would be always ready and spin the loop.)
func (s *DashboardQueryService) eventListener() {
	defer s.wg.Done()

	keyStatusSub, err1 := s.store.Subscribe(s.ctx, models.TopicKeyStatusChanged)
	upstreamStatusSub, err2 := s.store.Subscribe(s.ctx, models.TopicUpstreamHealthChanged)

	if err1 != nil {
		s.logger.WithError(err1).Error("Failed to subscribe to key status events")
		keyStatusSub = nil
	}
	if err2 != nil {
		s.logger.WithError(err2).Error("Failed to subscribe to upstream status events")
		upstreamStatusSub = nil
	}

	// Nothing to listen to.
	if keyStatusSub == nil && upstreamStatusSub == nil {
		s.logger.Error("All event subscriptions failed, listener disabled")
		return
	}

	// Close every subscription that was opened, even if its channel closed later.
	defer func() {
		if keyStatusSub != nil {
			if err := keyStatusSub.Close(); err != nil {
				s.logger.WithError(err).Warn("Failed to close key status subscription")
			}
		}
		if upstreamStatusSub != nil {
			if err := upstreamStatusSub.Close(); err != nil {
				s.logger.WithError(err).Warn("Failed to close upstream status subscription")
			}
		}
	}()

	s.logger.WithFields(logrus.Fields{
		"key_status_sub":      keyStatusSub != nil,
		"upstream_status_sub": upstreamStatusSub != nil,
	}).Info("Event listener started")

	// nil channels disable their select case.
	var keyStatusChan, upstreamStatusChan <-chan *store.Message
	if keyStatusSub != nil {
		keyStatusChan = keyStatusSub.Channel()
	}
	if upstreamStatusSub != nil {
		upstreamStatusChan = upstreamStatusSub.Channel()
	}

	for {
		select {
		case _, ok := <-keyStatusChan:
			if !ok {
				s.logger.Warn("Key status channel closed")
				keyStatusChan = nil
				continue
			}
			s.logger.Debug("Received key status changed event")
			if err := s.InvalidateOverviewCache(); err != nil {
				s.logger.WithError(err).Warn("Failed to invalidate cache on key status change")
			}
		case _, ok := <-upstreamStatusChan:
			if !ok {
				s.logger.Warn("Upstream status channel closed")
				upstreamStatusChan = nil
				continue
			}
			s.logger.Debug("Received upstream status changed event")
			if err := s.InvalidateOverviewCache(); err != nil {
				s.logger.WithError(err).Warn("Failed to invalidate cache on upstream status change")
			}
		case <-s.stopChan:
			s.logger.Info("Event listener stopping (stopChan)")
			return
		case <-s.ctx.Done():
			s.logger.Info("Event listener stopping (context cancelled)")
			return
		}
	}
}
// ==================== Metrics ====================

// metricsReporter calls reportMetrics every five minutes. It returns when
// stopChan is closed or the service context is cancelled.
func (s *DashboardQueryService) metricsReporter() {
	defer s.wg.Done()

	const reportInterval = 5 * time.Minute
	ticker := time.NewTicker(reportInterval)
	defer ticker.Stop()

	for {
		select {
		case <-s.stopChan:
			return
		case <-s.ctx.Done():
			return
		case <-ticker.C:
			s.reportMetrics()
		}
	}
}
// reportMetrics logs a snapshot of the service counters at Info level. The
// snapshot includes:
//   - the cache hit rate, formatted as a percentage string;
//   - the time since lastQueryTime, rounded to the second.
func (s *DashboardQueryService) reportMetrics() {
	s.lastQueryMutex.RLock()
	lastQuery := s.lastQueryTime
	s.lastQueryMutex.RUnlock()

	total := s.queryCount.Load()
	hits, misses := s.cacheHits.Load(), s.cacheMisses.Load()

	hitRate := 0.0
	if lookups := hits + misses; lookups > 0 {
		hitRate = float64(hits) / float64(lookups) * 100
	}

	fields := logrus.Fields{
		"total_queries":  total,
		"cache_hits":     hits,
		"cache_misses":   misses,
		"cache_hit_rate": fmt.Sprintf("%.2f%%", hitRate),
		"overview_loads": s.overviewLoadCount.Load(),
		"last_query_ago": time.Since(lastQuery).Round(time.Second),
	}
	s.logger.WithFields(fields).Info("DashboardQuery metrics")
}
// GetMetrics returns a snapshot of the service counters for monitoring. It
// contains:
//   - total queries;
//   - cache hits and misses;
//   - cache hit rate (percent; 0 when there have been no lookups);
//   - overview load count;
//   - seconds since lastQueryTime was last set.
func (s *DashboardQueryService) GetMetrics() map[string]interface{} {
	s.lastQueryMutex.RLock()
	lastQuery := s.lastQueryTime
	s.lastQueryMutex.RUnlock()

	hits, misses := s.cacheHits.Load(), s.cacheMisses.Load()
	var hitRate float64
	if lookups := hits + misses; lookups != 0 {
		hitRate = float64(hits) / float64(lookups) * 100
	}

	metrics := make(map[string]interface{}, 6)
	metrics["total_queries"] = s.queryCount.Load()
	metrics["cache_hits"] = hits
	metrics["cache_misses"] = misses
	metrics["cache_hit_rate"] = hitRate
	metrics["overview_loads"] = s.overviewLoadCount.Load()
	metrics["last_query_ago"] = time.Since(lastQuery).Seconds()
	return metrics
}
// ==================== Helpers ====================

// calculateFailureRate returns the share of failed requests, as a percentage:
// (total - success) / total * 100. It returns 0 when total is 0, to avoid
// dividing by zero.
func (s *DashboardQueryService) calculateFailureRate(total, success int64) float64 {
	if total != 0 {
		failed := total - success
		return float64(failed) / float64(total) * 100
	}
	return 0.0
}
// updateLastQueryTime records the current time as lastQueryTime, holding the
// write lock of lastQueryMutex while doing so.
func (s *DashboardQueryService) updateLastQueryTime() {
	s.lastQueryMutex.Lock()
	defer s.lastQueryMutex.Unlock()
	s.lastQueryTime = time.Now()
}
// buildTimeFormatSelectClause picks an SQL expression that formats the `time`
// column as an hourly label, based on the gorm dialector name. It also returns
// the Go time layout that renders the same label text.
//
// Supported dialects are "mysql", "postgres" and "sqlite". Any other dialect
// logs a warning and falls back to the SQLite expression.
func (s *DashboardQueryService) buildTimeFormatSelectClause() (string, string) {
	const (
		hourLayout = "2006-01-02 15:00:00"
		sqliteExpr = "strftime('%Y-%m-%d %H:00:00', time)"
	)

	var expr string
	switch dialect := s.db.Dialector.Name(); dialect {
	case "mysql":
		expr = "DATE_FORMAT(time, '%Y-%m-%d %H:00:00')"
	case "postgres":
		expr = "TO_CHAR(time, 'YYYY-MM-DD HH24:00:00')"
	case "sqlite":
		expr = sqliteExpr
	default:
		s.logger.WithField("dialect", dialect).Warn("Unknown database dialect, using SQLite format")
		expr = sqliteExpr
	}
	return expr, hourLayout
}