// Filename: internal/service/analytics_service.go

package service

import (
	"context"
	"encoding/json"
	"fmt"
	"strconv"
	"strings"
	"sync"
	"sync/atomic"
	"time"

	"gemini-balancer/internal/db/dialect"
	"gemini-balancer/internal/models"
	"gemini-balancer/internal/settings"
	"gemini-balancer/internal/store"

	"github.com/sirupsen/logrus"
	"gorm.io/gorm"
)

const (
	defaultFlushInterval = 1 * time.Minute
	maxRetryAttempts     = 3
	retryDelay           = 5 * time.Second
)

type AnalyticsService struct {
	db              *gorm.DB
	store           store.Store
	logger          *logrus.Entry
	dialect         dialect.DialectAdapter
	settingsManager *settings.SettingsManager

	stopChan chan struct{}
	wg       sync.WaitGroup
	ctx      context.Context
	cancel   context.CancelFunc

	// Operational counters.
	eventsReceived  atomic.Uint64
	eventsProcessed atomic.Uint64
	eventsFailed    atomic.Uint64
	flushCount      atomic.Uint64
	recordsFlushed  atomic.Uint64
	flushErrors     atomic.Uint64
	lastFlushTime   time.Time
	lastFlushMutex  sync.RWMutex

	// Runtime configuration.
	flushInterval time.Duration
	configMutex   sync.RWMutex
}

func NewAnalyticsService(
	db *gorm.DB,
	s store.Store,
	logger *logrus.Logger,
	d dialect.DialectAdapter,
	settingsManager *settings.SettingsManager,
) *AnalyticsService {
	ctx, cancel := context.WithCancel(context.Background())
	return &AnalyticsService{
		db:              db,
		store:           s,
		logger:          logger.WithField("component", "Analytics📊"),
		dialect:         d,
		settingsManager: settingsManager,
		stopChan:        make(chan struct{}),
		ctx:             ctx,
		cancel:          cancel,
		flushInterval:   defaultFlushInterval,
		lastFlushTime:   time.Now(),
	}
}

func (s *AnalyticsService) Start() {
	s.wg.Add(3)
	go s.eventListener()
	go s.flushLoop()
	go s.metricsReporter()
	s.logger.WithFields(logrus.Fields{
		"flush_interval": s.flushInterval,
	}).Info("AnalyticsService started")
}

func (s *AnalyticsService) Stop() {
	s.logger.Info("AnalyticsService stopping...")
	close(s.stopChan)
	s.cancel()
	s.wg.Wait()

	s.logger.Info("Performing final data flush...")
	s.flushToDB()

	// Log final statistics.
	s.logger.WithFields(logrus.Fields{
		"events_received":  s.eventsReceived.Load(),
		"events_processed": s.eventsProcessed.Load(),
		"events_failed":    s.eventsFailed.Load(),
		"flush_count":      s.flushCount.Load(),
		"records_flushed":  s.recordsFlushed.Load(),
		"flush_errors":     s.flushErrors.Load(),
	}).Info("AnalyticsService stopped")
}

// eventListener consumes request-finished events from the store subscription.
func (s *AnalyticsService) eventListener() {
	defer s.wg.Done()

	sub, err := s.store.Subscribe(s.ctx, models.TopicRequestFinished)
	if err != nil {
		s.logger.WithError(err).Error("Failed to subscribe to request events, analytics disabled")
		return
	}
	defer func() {
		if err := sub.Close(); err != nil {
			s.logger.WithError(err).Warn("Failed to close subscription")
		}
	}()

	s.logger.Info("Subscribed to request events for analytics")

	for {
		select {
		case msg, ok := <-sub.Channel():
			// Guard against a closed channel, which would otherwise spin
			// this loop with nil messages.
			if !ok {
				s.logger.Warn("Subscription channel closed, event listener exiting")
				return
			}
			s.handleMessage(msg)
		case <-s.stopChan:
			s.logger.Info("Event listener stopping")
			return
		case <-s.ctx.Done():
			s.logger.Info("Event listener context cancelled")
			return
		}
	}
}

// handleMessage decodes a single pub/sub message and dispatches it.
func (s *AnalyticsService) handleMessage(msg *store.Message) {
	// Count every message up front so that received == processed + failed.
	s.eventsReceived.Add(1)

	var event models.RequestFinishedEvent
	if err := json.Unmarshal(msg.Payload, &event); err != nil {
		s.logger.WithError(err).Error("Failed to unmarshal analytics event")
		s.eventsFailed.Add(1)
		return
	}

	if err := s.handleAnalyticsEvent(&event); err != nil {
		s.eventsFailed.Add(1)
		s.logger.WithFields(logrus.Fields{
			"correlation_id": event.CorrelationID,
			"group_id":       event.RequestLog.GroupID,
		}).WithError(err).Warn("Failed to process analytics event")
	} else {
		s.eventsProcessed.Add(1)
	}
}
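// handleAnalyticsEvent below buckets counters into one Redis hash per hour.
// An illustrative layout (key format and counter names are real; the group
// ID, model name, timestamp and values are hypothetical):
//
//	key:   analytics:hourly:2025-01-02T15
//	field: 42:gemini-pro:requests   -> 17
//	field: 42:gemini-pro:success    -> 16
//	field: 42:gemini-pro:prompt     -> 12034
//	field: 42:gemini-pro:completion -> 8210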
// handleAnalyticsEvent aggregates one event into the hourly Redis hash.
func (s *AnalyticsService) handleAnalyticsEvent(event *models.RequestFinishedEvent) error {
	if event.RequestLog.GroupID == nil {
		return nil // Skip events without a GroupID.
	}

	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
	defer cancel()

	now := time.Now().UTC()
	key := fmt.Sprintf("analytics:hourly:%s", now.Format("2006-01-02T15"))
	// Field format: <groupID>:<modelName>:<counter>. parseStatsFromHash
	// assumes the model name itself contains no ':'.
	fieldPrefix := fmt.Sprintf("%d:%s", *event.RequestLog.GroupID, event.RequestLog.ModelName)

	pipe := s.store.Pipeline(ctx)
	pipe.HIncrBy(key, fieldPrefix+":requests", 1)
	if event.RequestLog.IsSuccess {
		pipe.HIncrBy(key, fieldPrefix+":success", 1)
	}
	if event.RequestLog.PromptTokens > 0 {
		pipe.HIncrBy(key, fieldPrefix+":prompt", int64(event.RequestLog.PromptTokens))
	}
	if event.RequestLog.CompletionTokens > 0 {
		pipe.HIncrBy(key, fieldPrefix+":completion", int64(event.RequestLog.CompletionTokens))
	}
	// Retain raw counters for 48 hours so delayed flushes can still catch up.
	pipe.Expire(key, 48*time.Hour)

	if err := pipe.Exec(); err != nil {
		return fmt.Errorf("redis pipeline failed: %w", err)
	}
	return nil
}

// flushLoop periodically flushes aggregated counters to the database.
func (s *AnalyticsService) flushLoop() {
	defer s.wg.Done()

	s.configMutex.RLock()
	interval := s.flushInterval
	s.configMutex.RUnlock()

	ticker := time.NewTicker(interval)
	defer ticker.Stop()

	s.logger.WithField("interval", interval).Info("Flush loop started")

	for {
		select {
		case <-ticker.C:
			s.flushToDB()
		case <-s.stopChan:
			s.logger.Info("Flush loop stopping")
			return
		case <-s.ctx.Done():
			return
		}
	}
}

// flushToDB drains the hourly Redis hashes into the database.
func (s *AnalyticsService) flushToDB() {
	start := time.Now()
	ctx, cancel := context.WithTimeout(context.Background(), 60*time.Second)
	defer cancel()

	now := time.Now().UTC()
	keysToFlush := s.generateFlushKeys(now)

	totalRecords := 0
	totalErrors := 0
	for _, key := range keysToFlush {
		records, err := s.flushSingleKey(ctx, key, now)
		if err != nil {
			s.logger.WithError(err).WithField("key", key).Error("Failed to flush key")
			totalErrors++
			s.flushErrors.Add(1)
		} else {
			totalRecords += records
		}
	}

	s.recordsFlushed.Add(uint64(totalRecords))
	s.flushCount.Add(1)

	s.lastFlushMutex.Lock()
	s.lastFlushTime = time.Now()
	s.lastFlushMutex.Unlock()

	duration := time.Since(start)
	if totalRecords > 0 || totalErrors > 0 {
		s.logger.WithFields(logrus.Fields{
			"records_flushed": totalRecords,
			"keys_processed":  len(keysToFlush),
			"errors":          totalErrors,
			"duration":        duration,
		}).Info("Analytics data flush completed")
	} else {
		s.logger.WithField("duration", duration).Debug("Analytics flush completed (no data)")
	}
}

// generateFlushKeys returns the Redis keys eligible for flushing.
func (s *AnalyticsService) generateFlushKeys(now time.Time) []string {
	keys := make([]string, 0, 4)

	// Current hour.
	keys = append(keys, fmt.Sprintf("analytics:hourly:%s", now.Format("2006-01-02T15")))

	// Previous three hours, to cover delayed events and timezone edge cases.
	for i := 1; i <= 3; i++ {
		pastHour := now.Add(-time.Duration(i) * time.Hour)
		keys = append(keys, fmt.Sprintf("analytics:hourly:%s", pastHour.Format("2006-01-02T15")))
	}
	return keys
}
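// For example (hypothetical clock), with now = 2025-01-02T15:30:00Z the
// function above returns:
//
//	analytics:hourly:2025-01-02T15
//	analytics:hourly:2025-01-02T14
//	analytics:hourly:2025-01-02T13
//	analytics:hourly:2025-01-02T12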
// flushSingleKey flushes one hourly hash and deletes the processed fields.
func (s *AnalyticsService) flushSingleKey(ctx context.Context, key string, baseTime time.Time) (int, error) {
	data, err := s.store.HGetAll(ctx, key)
	if err != nil {
		return 0, fmt.Errorf("failed to get hash data: %w", err)
	}
	if len(data) == 0 {
		return 0, nil // No data; skip.
	}

	// Recover the hour timestamp from the key.
	hourStr := strings.TrimPrefix(key, "analytics:hourly:")
	recordTime, err := time.Parse("2006-01-02T15", hourStr)
	if err != nil {
		s.logger.WithError(err).WithField("key", key).Warn("Failed to parse time from key")
		recordTime = baseTime.Truncate(time.Hour)
	}

	statsToFlush, parsedFields := s.parseStatsFromHash(recordTime, data)
	if len(statsToFlush) == 0 {
		return 0, nil
	}

	// Upsert inside a transaction, retrying on transient failures.
	var dbErr error
	for attempt := 1; attempt <= maxRetryAttempts; attempt++ {
		dbErr = s.upsertStatsWithTransaction(ctx, statsToFlush)
		if dbErr == nil {
			break
		}
		if attempt < maxRetryAttempts {
			s.logger.WithFields(logrus.Fields{
				"attempt": attempt,
				"key":     key,
			}).WithError(dbErr).Warn("Database upsert failed, retrying...")
			time.Sleep(retryDelay)
		}
	}
	if dbErr != nil {
		return 0, fmt.Errorf("failed to upsert after %d attempts: %w", maxRetryAttempts, dbErr)
	}

	// Delete the fields that were read. Note: increments landing between
	// HGetAll and HDel are lost; that window is accepted for coarse analytics.
	if len(parsedFields) > 0 {
		if err := s.store.HDel(ctx, key, parsedFields...); err != nil {
			s.logger.WithError(err).WithField("key", key).Warn("Failed to delete flushed fields from Redis")
		}
	}
	return len(statsToFlush), nil
}

// upsertStatsWithTransaction batch-upserts the stats rows inside a transaction.
func (s *AnalyticsService) upsertStatsWithTransaction(ctx context.Context, stats []models.StatsHourly) error {
	return s.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
		upsertClause := s.dialect.OnConflictUpdateAll(
			[]string{"time", "group_id", "model_name"},
			[]string{"request_count", "success_count", "prompt_tokens", "completion_tokens"},
		)
		return tx.Clauses(upsertClause).Create(&stats).Error
	})
}

// parseStatsFromHash converts raw hash fields into StatsHourly rows, returning
// the rows plus the list of fields that were successfully parsed.
func (s *AnalyticsService) parseStatsFromHash(t time.Time, data map[string]string) ([]models.StatsHourly, []string) {
	tempAggregator := make(map[string]*models.StatsHourly)
	parsedFields := make([]string, 0, len(data))

	for field, valueStr := range data {
		parts := strings.Split(field, ":")
		if len(parts) != 3 {
			s.logger.WithField("field", field).Warn("Invalid field format")
			continue
		}
		groupIDStr, modelName, counterType := parts[0], parts[1], parts[2]

		aggKey := groupIDStr + ":" + modelName
		if _, ok := tempAggregator[aggKey]; !ok {
			gid, err := strconv.Atoi(groupIDStr)
			if err != nil {
				s.logger.WithFields(logrus.Fields{
					"field":    field,
					"group_id": groupIDStr,
				}).Warn("Invalid group ID")
				continue
			}
			tempAggregator[aggKey] = &models.StatsHourly{
				Time:      t,
				GroupID:   uint(gid),
				ModelName: modelName,
			}
		}

		val, err := strconv.ParseInt(valueStr, 10, 64)
		if err != nil {
			s.logger.WithFields(logrus.Fields{
				"field": field,
				"value": valueStr,
			}).Warn("Invalid counter value")
			continue
		}

		switch counterType {
		case "requests":
			tempAggregator[aggKey].RequestCount = val
		case "success":
			tempAggregator[aggKey].SuccessCount = val
		case "prompt":
			tempAggregator[aggKey].PromptTokens = val
		case "completion":
			tempAggregator[aggKey].CompletionTokens = val
		default:
			s.logger.WithField("counter_type", counterType).Warn("Unknown counter type")
			continue
		}
		parsedFields = append(parsedFields, field)
	}

	result := make([]models.StatsHourly, 0, len(tempAggregator))
	for _, stats := range tempAggregator {
		if stats.RequestCount > 0 {
			result = append(result, *stats)
		}
	}
	return result, parsedFields
}

// metricsReporter logs service metrics on a fixed interval.
func (s *AnalyticsService) metricsReporter() {
	defer s.wg.Done()

	ticker := time.NewTicker(5 * time.Minute)
	defer ticker.Stop()

	for {
		select {
		case <-ticker.C:
			s.reportMetrics()
		case <-s.stopChan:
			return
		case <-s.ctx.Done():
			return
		}
	}
}

func (s *AnalyticsService) reportMetrics() {
	s.lastFlushMutex.RLock()
	lastFlush := s.lastFlushTime
	s.lastFlushMutex.RUnlock()

	received := s.eventsReceived.Load()
	processed := s.eventsProcessed.Load()
	failed := s.eventsFailed.Load()

	var successRate float64
	if received > 0 {
		successRate = float64(processed) / float64(received) * 100
	}

	s.logger.WithFields(logrus.Fields{
		"events_received":  received,
		"events_processed": processed,
		"events_failed":    failed,
		"success_rate":     fmt.Sprintf("%.2f%%", successRate),
		"flush_count":      s.flushCount.Load(),
		"records_flushed":  s.recordsFlushed.Load(),
		"flush_errors":     s.flushErrors.Load(),
		"last_flush_ago":   time.Since(lastFlush).Round(time.Second),
	}).Info("Analytics metrics")
}
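// A minimal sketch of exposing GetMetrics (below) over HTTP, assuming the
// caller owns the mux and an *AnalyticsService named svc; net/http is not
// imported by this file, so such a handler would live in the server wiring:
//
//	mux.HandleFunc("/debug/analytics", func(w http.ResponseWriter, r *http.Request) {
//		w.Header().Set("Content-Type", "application/json")
//		_ = json.NewEncoder(w).Encode(svc.GetMetrics())
//	})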
// GetMetrics returns a snapshot of the current service metrics for monitoring.
func (s *AnalyticsService) GetMetrics() map[string]interface{} {
	s.lastFlushMutex.RLock()
	lastFlush := s.lastFlushTime
	s.lastFlushMutex.RUnlock()

	// Read the runtime-configurable interval under its lock.
	s.configMutex.RLock()
	interval := s.flushInterval
	s.configMutex.RUnlock()

	received := s.eventsReceived.Load()
	processed := s.eventsProcessed.Load()

	var successRate float64
	if received > 0 {
		successRate = float64(processed) / float64(received) * 100
	}

	return map[string]interface{}{
		"events_received":  received,
		"events_processed": processed,
		"events_failed":    s.eventsFailed.Load(),
		"success_rate":     successRate,
		"flush_count":      s.flushCount.Load(),
		"records_flushed":  s.recordsFlushed.Load(),
		"flush_errors":     s.flushErrors.Load(),
		"last_flush_ago":   time.Since(lastFlush).Seconds(),
		"flush_interval":   interval.Seconds(),
	}
}
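// Usage sketch: how an application might wire the service up. The function
// name and parameters are hypothetical; the concrete dependencies come from
// the application's own bootstrap code.
func exampleAnalyticsLifecycle(
	db *gorm.DB,
	st store.Store,
	logger *logrus.Logger,
	d dialect.DialectAdapter,
	sm *settings.SettingsManager,
) {
	svc := NewAnalyticsService(db, st, logger, d, sm)
	svc.Start()

	// ... serve traffic; every models.TopicRequestFinished event is
	// aggregated into Redis and flushed to the database on each tick.

	// Stop drains the goroutines and performs one final flush.
	svc.Stop()
}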