// Filename: internal/service/analytics_service.go
package service

import (
	"context"
	"encoding/json"
	"fmt"
	"strconv"
	"strings"
	"sync"
	"sync/atomic"
	"time"

	"gemini-balancer/internal/db/dialect"
	"gemini-balancer/internal/models"
	"gemini-balancer/internal/settings"
	"gemini-balancer/internal/store"

	"github.com/sirupsen/logrus"
	"gorm.io/gorm"
)

const (
	defaultFlushInterval = 1 * time.Minute
	maxRetryAttempts     = 3
	retryDelay           = 5 * time.Second
)
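
// AnalyticsService consumes request-finished events from the store,
// aggregates per-group and per-model counters into hourly Redis hashes,
// and periodically flushes those aggregates into the database as
// models.StatsHourly rows.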
type AnalyticsService struct {
	db              *gorm.DB
	store           store.Store
	logger          *logrus.Entry
	dialect         dialect.DialectAdapter
	settingsManager *settings.SettingsManager
	stopChan        chan struct{}
	wg              sync.WaitGroup
	ctx             context.Context
	cancel          context.CancelFunc

	// Runtime metrics
	eventsReceived  atomic.Uint64
	eventsProcessed atomic.Uint64
	eventsFailed    atomic.Uint64
	flushCount      atomic.Uint64
	recordsFlushed  atomic.Uint64
	flushErrors     atomic.Uint64
	lastFlushTime   time.Time
	lastFlushMutex  sync.RWMutex

	// Runtime configuration
	flushInterval time.Duration
	configMutex   sync.RWMutex
}
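
// NewAnalyticsService constructs the service; nothing runs until Start is
// called. A minimal usage sketch (the surrounding variables are assumed to
// come from the application's own setup):
//
//	svc := NewAnalyticsService(db, st, logger, adapter, settingsManager)
//	svc.Start()
//	defer svc.Stop()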
func NewAnalyticsService(
	db *gorm.DB,
	s store.Store,
	logger *logrus.Logger,
	d dialect.DialectAdapter,
	settingsManager *settings.SettingsManager,
) *AnalyticsService {
	ctx, cancel := context.WithCancel(context.Background())
	return &AnalyticsService{
		db:              db,
		store:           s,
		logger:          logger.WithField("component", "Analytics📊"),
		dialect:         d,
		settingsManager: settingsManager,
		stopChan:        make(chan struct{}),
		ctx:             ctx,
		cancel:          cancel,
		flushInterval:   defaultFlushInterval,
		lastFlushTime:   time.Now(),
	}
}
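
// Start launches the three background workers: the event listener, the
// periodic flush loop, and the metrics reporter.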
func (s *AnalyticsService) Start() {
	s.wg.Add(3)
	go s.eventListener()
	go s.flushLoop()
	go s.metricsReporter()
	// Read the interval under the config lock; the flush loop is already
	// running and reads the same field.
	s.configMutex.RLock()
	interval := s.flushInterval
	s.configMutex.RUnlock()
	s.logger.WithFields(logrus.Fields{
		"flush_interval": interval,
	}).Info("AnalyticsService started")
}
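
// Stop signals the workers to exit, waits for them to finish, and then
// performs a final flush so in-flight aggregates are not lost on shutdown.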
func (s *AnalyticsService) Stop() {
	s.logger.Info("AnalyticsService stopping...")
	close(s.stopChan)
	s.cancel()
	s.wg.Wait()
	s.logger.Info("Performing final data flush...")
	s.flushToDB()
	// Log final statistics.
	s.logger.WithFields(logrus.Fields{
		"events_received":  s.eventsReceived.Load(),
		"events_processed": s.eventsProcessed.Load(),
		"events_failed":    s.eventsFailed.Load(),
		"flush_count":      s.flushCount.Load(),
		"records_flushed":  s.recordsFlushed.Load(),
		"flush_errors":     s.flushErrors.Load(),
	}).Info("AnalyticsService stopped")
}

// eventListener subscribes to request-finished events and feeds them into
// the aggregation pipeline until the service is stopped.
func (s *AnalyticsService) eventListener() {
	defer s.wg.Done()
	sub, err := s.store.Subscribe(s.ctx, models.TopicRequestFinished)
	if err != nil {
		s.logger.WithError(err).Error("Failed to subscribe to request events, analytics disabled")
		return
	}
	defer func() {
		if err := sub.Close(); err != nil {
			s.logger.WithError(err).Warn("Failed to close subscription")
		}
	}()
	s.logger.Info("Subscribed to request events for analytics")
	for {
		select {
		case msg := <-sub.Channel():
			s.handleMessage(msg)
		case <-s.stopChan:
			s.logger.Info("Event listener stopping")
			return
		case <-s.ctx.Done():
			s.logger.Info("Event listener context cancelled")
			return
		}
	}
}

// handleMessage decodes a single pub/sub message and updates the
// received/processed/failed counters accordingly.
func (s *AnalyticsService) handleMessage(msg *store.Message) {
	var event models.RequestFinishedEvent
	if err := json.Unmarshal(msg.Payload, &event); err != nil {
		s.logger.WithError(err).Error("Failed to unmarshal analytics event")
		s.eventsFailed.Add(1)
		return
	}
	s.eventsReceived.Add(1)
	if err := s.handleAnalyticsEvent(&event); err != nil {
		s.eventsFailed.Add(1)
		s.logger.WithFields(logrus.Fields{
			"correlation_id": event.CorrelationID,
			"group_id":       event.RequestLog.GroupID,
		}).WithError(err).Warn("Failed to process analytics event")
	} else {
		s.eventsProcessed.Add(1)
	}
}

// handleAnalyticsEvent aggregates a single finished request into the Redis
// hash for the current UTC hour. Events without a GroupID are skipped.
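// For example, a successful request for group 42 and model "gemini-pro"
// (illustrative values) arriving at 15:xx UTC on 2024-06-01 increments
// fields of the hash "analytics:hourly:2024-06-01T15":
//
//	42:gemini-pro:requests   +1
//	42:gemini-pro:success    +1
//	42:gemini-pro:prompt     +<prompt tokens>
//	42:gemini-pro:completion +<completion tokens>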
func (s *AnalyticsService) handleAnalyticsEvent(event *models.RequestFinishedEvent) error {
	if event.RequestLog.GroupID == nil {
		return nil // Skip events without a GroupID.
	}
	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
	defer cancel()
	now := time.Now().UTC()
	key := fmt.Sprintf("analytics:hourly:%s", now.Format("2006-01-02T15"))
	fieldPrefix := fmt.Sprintf("%d:%s", *event.RequestLog.GroupID, event.RequestLog.ModelName)
	pipe := s.store.Pipeline(ctx)
	pipe.HIncrBy(key, fieldPrefix+":requests", 1)
	if event.RequestLog.IsSuccess {
		pipe.HIncrBy(key, fieldPrefix+":success", 1)
	}
	if event.RequestLog.PromptTokens > 0 {
		pipe.HIncrBy(key, fieldPrefix+":prompt", int64(event.RequestLog.PromptTokens))
	}
	if event.RequestLog.CompletionTokens > 0 {
		pipe.HIncrBy(key, fieldPrefix+":completion", int64(event.RequestLog.CompletionTokens))
	}
	// Expire the hash after 48 hours so unflushed data does not accumulate.
	pipe.Expire(key, 48*time.Hour)
	if err := pipe.Exec(); err != nil {
		return fmt.Errorf("redis pipeline failed: %w", err)
	}
	return nil
}

// flushLoop triggers a database flush on every tick of the flush interval.
func (s *AnalyticsService) flushLoop() {
	defer s.wg.Done()
	s.configMutex.RLock()
	interval := s.flushInterval
	s.configMutex.RUnlock()
	ticker := time.NewTicker(interval)
	defer ticker.Stop()
	s.logger.WithField("interval", interval).Info("Flush loop started")
	for {
		select {
		case <-ticker.C:
			s.flushToDB()
		case <-s.stopChan:
			s.logger.Info("Flush loop stopping")
			return
		case <-s.ctx.Done():
			return
		}
	}
}

// flushToDB drains the recent hourly hashes from Redis into the database
// and records flush metrics.
func (s *AnalyticsService) flushToDB() {
	start := time.Now()
	ctx, cancel := context.WithTimeout(context.Background(), 60*time.Second)
	defer cancel()
	now := time.Now().UTC()
	keysToFlush := s.generateFlushKeys(now)
	totalRecords := 0
	totalErrors := 0
	for _, key := range keysToFlush {
		records, err := s.flushSingleKey(ctx, key, now)
		if err != nil {
			s.logger.WithError(err).WithField("key", key).Error("Failed to flush key")
			totalErrors++
			s.flushErrors.Add(1)
		} else {
			totalRecords += records
		}
	}
	s.recordsFlushed.Add(uint64(totalRecords))
	s.flushCount.Add(1)
	s.lastFlushMutex.Lock()
	s.lastFlushTime = time.Now()
	s.lastFlushMutex.Unlock()
	duration := time.Since(start)
	if totalRecords > 0 || totalErrors > 0 {
		s.logger.WithFields(logrus.Fields{
			"records_flushed": totalRecords,
			"keys_processed":  len(keysToFlush),
			"errors":          totalErrors,
			"duration":        duration,
		}).Info("Analytics data flush completed")
	} else {
		s.logger.WithField("duration", duration).Debug("Analytics flush completed (no data)")
	}
}

// generateFlushKeys returns the Redis keys to scan on each flush: the
// current hour plus the previous three hours, to cover processing delays
// and timezone issues.
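// For example, at 15:xx UTC on 2024-06-01 the returned keys are:
//
//	analytics:hourly:2024-06-01T15
//	analytics:hourly:2024-06-01T14
//	analytics:hourly:2024-06-01T13
//	analytics:hourly:2024-06-01T12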
func (s *AnalyticsService) generateFlushKeys(now time.Time) []string {
	keys := make([]string, 0, 4)
	// Current hour.
	keys = append(keys, fmt.Sprintf("analytics:hourly:%s", now.Format("2006-01-02T15")))
	// Previous three hours, to handle processing delays and timezone issues.
	for i := 1; i <= 3; i++ {
		pastHour := now.Add(-time.Duration(i) * time.Hour)
		keys = append(keys, fmt.Sprintf("analytics:hourly:%s", pastHour.Format("2006-01-02T15")))
	}
	return keys
}

// flushSingleKey moves the aggregates stored under one Redis key into the
// database, retrying the upsert on failure, and returns the number of rows
// written.
func (s *AnalyticsService) flushSingleKey(ctx context.Context, key string, baseTime time.Time) (int, error) {
	data, err := s.store.HGetAll(ctx, key)
	if err != nil {
		return 0, fmt.Errorf("failed to get hash data: %w", err)
	}
	if len(data) == 0 {
		return 0, nil // No data, skip.
	}
	// Recover the hour timestamp from the key.
	hourStr := strings.TrimPrefix(key, "analytics:hourly:")
	recordTime, err := time.Parse("2006-01-02T15", hourStr)
	if err != nil {
		s.logger.WithError(err).WithField("key", key).Warn("Failed to parse time from key")
		recordTime = baseTime.Truncate(time.Hour)
	}
	statsToFlush, parsedFields := s.parseStatsFromHash(recordTime, data)
	if len(statsToFlush) == 0 {
		return 0, nil
	}
	// Upsert inside a transaction, with retries.
	var dbErr error
	for attempt := 1; attempt <= maxRetryAttempts; attempt++ {
		dbErr = s.upsertStatsWithTransaction(ctx, statsToFlush)
		if dbErr == nil {
			break
		}
		if attempt < maxRetryAttempts {
			s.logger.WithFields(logrus.Fields{
				"attempt": attempt,
				"key":     key,
			}).WithError(dbErr).Warn("Database upsert failed, retrying...")
			time.Sleep(retryDelay)
		}
	}
	if dbErr != nil {
		return 0, fmt.Errorf("failed to upsert after %d attempts: %w", maxRetryAttempts, dbErr)
	}
	// Remove the fields that were flushed successfully.
	if len(parsedFields) > 0 {
		if err := s.store.HDel(ctx, key, parsedFields...); err != nil {
			s.logger.WithError(err).WithField("key", key).Warn("Failed to delete flushed fields from Redis")
		}
	}
	return len(statsToFlush), nil
}

// upsertStatsWithTransaction writes the aggregated rows in a single
// transaction, using the dialect-specific upsert clause.
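// The generated SQL depends on the DialectAdapter; on PostgreSQL, a clause
// built from these columns would be expected to render roughly as (table
// name inferred from models.StatsHourly):
//
//	INSERT INTO stats_hourly (...) VALUES (...)
//	ON CONFLICT (time, group_id, model_name) DO UPDATE SET
//	    request_count = EXCLUDED.request_count, ...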
func (s *AnalyticsService) upsertStatsWithTransaction(ctx context.Context, stats []models.StatsHourly) error {
	return s.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
		upsertClause := s.dialect.OnConflictUpdateAll(
			[]string{"time", "group_id", "model_name"},
			[]string{"request_count", "success_count", "prompt_tokens", "completion_tokens"},
		)
		return tx.Clauses(upsertClause).Create(&stats).Error
	})
}

// parseStatsFromHash converts the raw Redis hash into StatsHourly rows and
// returns the rows plus the fields that were successfully parsed, so the
// caller can delete exactly those fields afterwards.
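// Each field is expected to look like "<groupID>:<modelName>:<counter>",
// e.g. "42:gemini-pro:requests" -> "10". Model names containing ':' would
// break the three-way split and are assumed not to occur.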
func (s *AnalyticsService) parseStatsFromHash(t time.Time, data map[string]string) ([]models.StatsHourly, []string) {
	tempAggregator := make(map[string]*models.StatsHourly)
	parsedFields := make([]string, 0, len(data))
	for field, valueStr := range data {
		parts := strings.Split(field, ":")
		if len(parts) != 3 {
			s.logger.WithField("field", field).Warn("Invalid field format")
			continue
		}
		groupIDStr, modelName, counterType := parts[0], parts[1], parts[2]
		aggKey := groupIDStr + ":" + modelName
		if _, ok := tempAggregator[aggKey]; !ok {
			gid, err := strconv.Atoi(groupIDStr)
			if err != nil {
				s.logger.WithFields(logrus.Fields{
					"field":    field,
					"group_id": groupIDStr,
				}).Warn("Invalid group ID")
				continue
			}
			tempAggregator[aggKey] = &models.StatsHourly{
				Time:      t,
				GroupID:   uint(gid),
				ModelName: modelName,
			}
		}
		val, err := strconv.ParseInt(valueStr, 10, 64)
		if err != nil {
			s.logger.WithFields(logrus.Fields{
				"field": field,
				"value": valueStr,
			}).Warn("Invalid counter value")
			continue
		}
		switch counterType {
		case "requests":
			tempAggregator[aggKey].RequestCount = val
		case "success":
			tempAggregator[aggKey].SuccessCount = val
		case "prompt":
			tempAggregator[aggKey].PromptTokens = val
		case "completion":
			tempAggregator[aggKey].CompletionTokens = val
		default:
			s.logger.WithField("counter_type", counterType).Warn("Unknown counter type")
			continue
		}
		parsedFields = append(parsedFields, field)
	}
	result := make([]models.StatsHourly, 0, len(tempAggregator))
	for _, stats := range tempAggregator {
		if stats.RequestCount > 0 {
			result = append(result, *stats)
		}
	}
	return result, parsedFields
}

// metricsReporter logs the service metrics every five minutes.
func (s *AnalyticsService) metricsReporter() {
	defer s.wg.Done()
	ticker := time.NewTicker(5 * time.Minute)
	defer ticker.Stop()
	for {
		select {
		case <-ticker.C:
			s.reportMetrics()
		case <-s.stopChan:
			return
		case <-s.ctx.Done():
			return
		}
	}
}

func (s *AnalyticsService) reportMetrics() {
	s.lastFlushMutex.RLock()
	lastFlush := s.lastFlushTime
	s.lastFlushMutex.RUnlock()
	received := s.eventsReceived.Load()
	processed := s.eventsProcessed.Load()
	failed := s.eventsFailed.Load()
	var successRate float64
	if received > 0 {
		successRate = float64(processed) / float64(received) * 100
	}
	s.logger.WithFields(logrus.Fields{
		"events_received":  received,
		"events_processed": processed,
		"events_failed":    failed,
		"success_rate":     fmt.Sprintf("%.2f%%", successRate),
		"flush_count":      s.flushCount.Load(),
		"records_flushed":  s.recordsFlushed.Load(),
		"flush_errors":     s.flushErrors.Load(),
		"last_flush_ago":   time.Since(lastFlush).Round(time.Second),
	}).Info("Analytics metrics")
}

// GetMetrics returns the current metrics (for monitoring use).
func (s *AnalyticsService) GetMetrics() map[string]interface{} {
	s.lastFlushMutex.RLock()
	lastFlush := s.lastFlushTime
	s.lastFlushMutex.RUnlock()
	// Read the interval under the config lock, consistent with flushLoop.
	s.configMutex.RLock()
	flushInterval := s.flushInterval
	s.configMutex.RUnlock()
	received := s.eventsReceived.Load()
	processed := s.eventsProcessed.Load()
	var successRate float64
	if received > 0 {
		successRate = float64(processed) / float64(received) * 100
	}
	return map[string]interface{}{
		"events_received":  received,
		"events_processed": processed,
		"events_failed":    s.eventsFailed.Load(),
		"success_rate":     successRate,
		"flush_count":      s.flushCount.Load(),
		"records_flushed":  s.recordsFlushed.Load(),
		"flush_errors":     s.flushErrors.Load(),
		"last_flush_ago":   time.Since(lastFlush).Seconds(),
		"flush_interval":   flushInterval.Seconds(),
	}
}