// Filename: internal/service/analytics_service.go

package service

import (
	"context"
	"encoding/json"
	"fmt"
	"strconv"
	"strings"
	"sync"
	"sync/atomic"
	"time"

	"gemini-balancer/internal/db/dialect"
	"gemini-balancer/internal/models"
	"gemini-balancer/internal/settings"
	"gemini-balancer/internal/store"

	"github.com/sirupsen/logrus"
	"gorm.io/gorm"
)

const (
	defaultFlushInterval = 1 * time.Minute
	maxRetryAttempts     = 3
	retryDelay           = 5 * time.Second
)

// AnalyticsService aggregates per-request analytics events into hourly
// counters in the store and periodically flushes them to the database.
type AnalyticsService struct {
	db              *gorm.DB
	store           store.Store
	logger          *logrus.Entry
	dialect         dialect.DialectAdapter
	settingsManager *settings.SettingsManager

	stopChan chan struct{}
	wg       sync.WaitGroup
	ctx      context.Context
	cancel   context.CancelFunc

	// Runtime metrics.
	eventsReceived  atomic.Uint64
	eventsProcessed atomic.Uint64
	eventsFailed    atomic.Uint64
	flushCount      atomic.Uint64
	recordsFlushed  atomic.Uint64
	flushErrors     atomic.Uint64
	lastFlushTime   time.Time
	lastFlushMutex  sync.RWMutex

	// Runtime configuration.
	flushInterval time.Duration
	configMutex   sync.RWMutex
}

// NewAnalyticsService wires up an AnalyticsService; call Start to begin
// processing and Stop to shut it down gracefully.
func NewAnalyticsService(
	db *gorm.DB,
	s store.Store,
	logger *logrus.Logger,
	d dialect.DialectAdapter,
	settingsManager *settings.SettingsManager,
) *AnalyticsService {
	ctx, cancel := context.WithCancel(context.Background())

	return &AnalyticsService{
		db:              db,
		store:           s,
		logger:          logger.WithField("component", "Analytics📊"),
		dialect:         d,
		settingsManager: settingsManager,
		stopChan:        make(chan struct{}),
		ctx:             ctx,
		cancel:          cancel,
		flushInterval:   defaultFlushInterval,
		lastFlushTime:   time.Now(),
	}
}
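
// Example wiring (a minimal sketch; the concrete *gorm.DB, store.Store,
// dialect.DialectAdapter, and *settings.SettingsManager values come from the
// application's own bootstrap code and are assumed here):
//
//	svc := service.NewAnalyticsService(db, st, logger, adapter, settingsMgr)
//	svc.Start()
//	defer svc.Stop() // the final flush happens inside Stop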

// Start launches the event listener, flush loop, and metrics reporter goroutines.
func (s *AnalyticsService) Start() {
	s.wg.Add(3)
	go s.eventListener()
	go s.flushLoop()
	go s.metricsReporter()

	s.logger.WithFields(logrus.Fields{
		"flush_interval": s.flushInterval,
	}).Info("AnalyticsService started")
}

// Stop shuts down the background goroutines, performs a final flush, and
// logs the lifetime counters.
func (s *AnalyticsService) Stop() {
	s.logger.Info("AnalyticsService stopping...")
	close(s.stopChan)
	s.cancel()
	s.wg.Wait()

	s.logger.Info("Performing final data flush...")
	s.flushToDB()

	// Log final statistics.
	s.logger.WithFields(logrus.Fields{
		"events_received":  s.eventsReceived.Load(),
		"events_processed": s.eventsProcessed.Load(),
		"events_failed":    s.eventsFailed.Load(),
		"flush_count":      s.flushCount.Load(),
		"records_flushed":  s.recordsFlushed.Load(),
		"flush_errors":     s.flushErrors.Load(),
	}).Info("AnalyticsService stopped")
}

// eventListener consumes request-finished events from the store subscription
// until shutdown.
func (s *AnalyticsService) eventListener() {
	defer s.wg.Done()

	sub, err := s.store.Subscribe(s.ctx, models.TopicRequestFinished)
	if err != nil {
		s.logger.WithError(err).Error("Failed to subscribe to request events, analytics disabled")
		return
	}
	defer func() {
		if err := sub.Close(); err != nil {
			s.logger.WithError(err).Warn("Failed to close subscription")
		}
	}()

	s.logger.Info("Subscribed to request events for analytics")

	for {
		select {
		case msg, ok := <-sub.Channel():
			// Guard against a closed channel; otherwise this select would
			// spin on nil messages after the subscription shuts down.
			if !ok {
				s.logger.Warn("Subscription channel closed, event listener stopping")
				return
			}
			s.handleMessage(msg)

		case <-s.stopChan:
			s.logger.Info("Event listener stopping")
			return

		case <-s.ctx.Done():
			s.logger.Info("Event listener context cancelled")
			return
		}
	}
}

// handleMessage decodes and processes a single event message.
func (s *AnalyticsService) handleMessage(msg *store.Message) {
	// Count every message pulled off the subscription, including ones that
	// fail to decode, so that received == processed + failed holds.
	s.eventsReceived.Add(1)

	var event models.RequestFinishedEvent
	if err := json.Unmarshal(msg.Payload, &event); err != nil {
		s.logger.WithError(err).Error("Failed to unmarshal analytics event")
		s.eventsFailed.Add(1)
		return
	}

	if err := s.handleAnalyticsEvent(&event); err != nil {
		s.eventsFailed.Add(1)
		s.logger.WithFields(logrus.Fields{
			"correlation_id": event.CorrelationID,
			"group_id":       event.RequestLog.GroupID,
		}).WithError(err).Warn("Failed to process analytics event")
	} else {
		s.eventsProcessed.Add(1)
	}
}

// handleAnalyticsEvent increments the hourly counters for one finished request.
func (s *AnalyticsService) handleAnalyticsEvent(event *models.RequestFinishedEvent) error {
	if event.RequestLog.GroupID == nil {
		return nil // Skip events without a GroupID.
	}

	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
	defer cancel()

	now := time.Now().UTC()
	key := fmt.Sprintf("analytics:hourly:%s", now.Format("2006-01-02T15"))
	fieldPrefix := fmt.Sprintf("%d:%s", *event.RequestLog.GroupID, event.RequestLog.ModelName)

	pipe := s.store.Pipeline(ctx)
	pipe.HIncrBy(key, fieldPrefix+":requests", 1)

	if event.RequestLog.IsSuccess {
		pipe.HIncrBy(key, fieldPrefix+":success", 1)
	}
	if event.RequestLog.PromptTokens > 0 {
		pipe.HIncrBy(key, fieldPrefix+":prompt", int64(event.RequestLog.PromptTokens))
	}
	if event.RequestLog.CompletionTokens > 0 {
		pipe.HIncrBy(key, fieldPrefix+":completion", int64(event.RequestLog.CompletionTokens))
	}

	// Expire the key after 48 hours so unflushed hours are eventually reclaimed.
	pipe.Expire(key, 48*time.Hour)

	if err := pipe.Exec(); err != nil {
		return fmt.Errorf("redis pipeline failed: %w", err)
	}

	return nil
}
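
// The hourly hash layout produced above, illustrated for group 7 and model
// "gemini-pro" during the 14:00 UTC hour of 2024-05-01 (values are examples):
//
//	key:   analytics:hourly:2024-05-01T14
//	field: 7:gemini-pro:requests   -> 120
//	field: 7:gemini-pro:success    -> 117
//	field: 7:gemini-pro:prompt     -> 95310
//	field: 7:gemini-pro:completion -> 40422
//
// Note that parseStatsFromHash splits fields on ':' into exactly three parts,
// so this layout assumes model names themselves contain no ':'.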

// flushLoop periodically flushes the hourly counters to the database.
func (s *AnalyticsService) flushLoop() {
	defer s.wg.Done()

	// The interval is read once at startup; changing flushInterval at runtime
	// does not reset an already-running ticker.
	s.configMutex.RLock()
	interval := s.flushInterval
	s.configMutex.RUnlock()

	ticker := time.NewTicker(interval)
	defer ticker.Stop()

	s.logger.WithField("interval", interval).Info("Flush loop started")

	for {
		select {
		case <-ticker.C:
			s.flushToDB()

		case <-s.stopChan:
			s.logger.Info("Flush loop stopping")
			return

		case <-s.ctx.Done():
			return
		}
	}
}

// flushToDB drains the recent hourly keys from the store into the database.
func (s *AnalyticsService) flushToDB() {
	start := time.Now()
	ctx, cancel := context.WithTimeout(context.Background(), 60*time.Second)
	defer cancel()

	now := time.Now().UTC()
	keysToFlush := s.generateFlushKeys(now)

	totalRecords := 0
	totalErrors := 0

	for _, key := range keysToFlush {
		records, err := s.flushSingleKey(ctx, key, now)
		if err != nil {
			s.logger.WithError(err).WithField("key", key).Error("Failed to flush key")
			totalErrors++
			s.flushErrors.Add(1)
		} else {
			totalRecords += records
		}
	}

	s.recordsFlushed.Add(uint64(totalRecords))
	s.flushCount.Add(1)

	s.lastFlushMutex.Lock()
	s.lastFlushTime = time.Now()
	s.lastFlushMutex.Unlock()

	duration := time.Since(start)

	if totalRecords > 0 || totalErrors > 0 {
		s.logger.WithFields(logrus.Fields{
			"records_flushed": totalRecords,
			"keys_processed":  len(keysToFlush),
			"errors":          totalErrors,
			"duration":        duration,
		}).Info("Analytics data flush completed")
	} else {
		s.logger.WithField("duration", duration).Debug("Analytics flush completed (no data)")
	}
}

// generateFlushKeys returns the Redis keys to flush for the current hour and
// the three hours before it.
func (s *AnalyticsService) generateFlushKeys(now time.Time) []string {
	keys := make([]string, 0, 4)

	// Current hour.
	keys = append(keys, fmt.Sprintf("analytics:hourly:%s", now.Format("2006-01-02T15")))

	// Previous 3 hours, to pick up late-arriving events and hour-boundary edge cases.
	for i := 1; i <= 3; i++ {
		pastHour := now.Add(-time.Duration(i) * time.Hour)
		keys = append(keys, fmt.Sprintf("analytics:hourly:%s", pastHour.Format("2006-01-02T15")))
	}

	return keys
}
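
// For example, with now = 2024-05-01T14:37Z, generateFlushKeys returns:
//
//	analytics:hourly:2024-05-01T14
//	analytics:hourly:2024-05-01T13
//	analytics:hourly:2024-05-01T12
//	analytics:hourly:2024-05-01T11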

// flushSingleKey reads one hourly hash, upserts its rows into the database,
// and deletes the fields it successfully persisted. It returns the number of
// rows written.
func (s *AnalyticsService) flushSingleKey(ctx context.Context, key string, baseTime time.Time) (int, error) {
	data, err := s.store.HGetAll(ctx, key)
	if err != nil {
		return 0, fmt.Errorf("failed to get hash data: %w", err)
	}

	if len(data) == 0 {
		return 0, nil // No data; skip.
	}

	// Parse the hour timestamp back out of the key.
	hourStr := strings.TrimPrefix(key, "analytics:hourly:")
	recordTime, err := time.Parse("2006-01-02T15", hourStr)
	if err != nil {
		s.logger.WithError(err).WithField("key", key).Warn("Failed to parse time from key")
		recordTime = baseTime.Truncate(time.Hour)
	}

	statsToFlush, parsedFields := s.parseStatsFromHash(recordTime, data)

	if len(statsToFlush) == 0 {
		return 0, nil
	}

	// Upsert inside a transaction, with bounded retries.
	var dbErr error
	for attempt := 1; attempt <= maxRetryAttempts; attempt++ {
		dbErr = s.upsertStatsWithTransaction(ctx, statsToFlush)
		if dbErr == nil {
			break
		}

		if attempt < maxRetryAttempts {
			s.logger.WithFields(logrus.Fields{
				"attempt": attempt,
				"key":     key,
			}).WithError(dbErr).Warn("Database upsert failed, retrying...")
			time.Sleep(retryDelay)
		}
	}

	if dbErr != nil {
		return 0, fmt.Errorf("failed to upsert after %d attempts: %w", maxRetryAttempts, dbErr)
	}

	// Delete the processed fields so the next flush does not double-count them.
	if len(parsedFields) > 0 {
		if err := s.store.HDel(ctx, key, parsedFields...); err != nil {
			s.logger.WithError(err).WithField("key", key).Warn("Failed to delete flushed fields from Redis")
		}
	}

	return len(statsToFlush), nil
}

// upsertStatsWithTransaction batch-upserts the hourly rows inside a single
// transaction, using the dialect adapter to build the conflict clause.
func (s *AnalyticsService) upsertStatsWithTransaction(ctx context.Context, stats []models.StatsHourly) error {
	return s.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
		upsertClause := s.dialect.OnConflictUpdateAll(
			[]string{"time", "group_id", "model_name"},
			[]string{"request_count", "success_count", "prompt_tokens", "completion_tokens"},
		)
		return tx.Clauses(upsertClause).Create(&stats).Error
	})
}
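
// Because flushSingleKey deletes the hash fields after each flush, a single
// hour can be flushed several times, each carrying only the counts that
// accumulated since the previous flush. For those partial flushes to add up,
// the conflict clause built by the dialect adapter has to be additive rather
// than a plain overwrite. A sketch of the PostgreSQL form this would need
// (hypothetical SQL; the actual statement depends on the project's
// DialectAdapter implementation):
//
//	INSERT INTO stats_hourly (time, group_id, model_name, request_count, ...)
//	VALUES (...)
//	ON CONFLICT (time, group_id, model_name) DO UPDATE SET
//	    request_count = stats_hourly.request_count + EXCLUDED.request_count,
//	    ...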

// parseStatsFromHash turns the raw hash fields into StatsHourly rows and
// returns the field names that were successfully parsed.
func (s *AnalyticsService) parseStatsFromHash(t time.Time, data map[string]string) ([]models.StatsHourly, []string) {
	tempAggregator := make(map[string]*models.StatsHourly)
	parsedFields := make([]string, 0, len(data))

	for field, valueStr := range data {
		// Fields are "<groupID>:<modelName>:<counterType>".
		parts := strings.Split(field, ":")
		if len(parts) != 3 {
			s.logger.WithField("field", field).Warn("Invalid field format")
			continue
		}

		groupIDStr, modelName, counterType := parts[0], parts[1], parts[2]
		aggKey := groupIDStr + ":" + modelName

		if _, ok := tempAggregator[aggKey]; !ok {
			gid, err := strconv.Atoi(groupIDStr)
			if err != nil {
				s.logger.WithFields(logrus.Fields{
					"field":    field,
					"group_id": groupIDStr,
				}).Warn("Invalid group ID")
				continue
			}

			tempAggregator[aggKey] = &models.StatsHourly{
				Time:      t,
				GroupID:   uint(gid),
				ModelName: modelName,
			}
		}

		val, err := strconv.ParseInt(valueStr, 10, 64)
		if err != nil {
			s.logger.WithFields(logrus.Fields{
				"field": field,
				"value": valueStr,
			}).Warn("Invalid counter value")
			continue
		}

		switch counterType {
		case "requests":
			tempAggregator[aggKey].RequestCount = val
		case "success":
			tempAggregator[aggKey].SuccessCount = val
		case "prompt":
			tempAggregator[aggKey].PromptTokens = val
		case "completion":
			tempAggregator[aggKey].CompletionTokens = val
		default:
			s.logger.WithField("counter_type", counterType).Warn("Unknown counter type")
			continue
		}

		parsedFields = append(parsedFields, field)
	}

	// Keep only groups that saw at least one request; handleAnalyticsEvent
	// always increments ":requests", so any other counter implies one exists.
	result := make([]models.StatsHourly, 0, len(tempAggregator))
	for _, stats := range tempAggregator {
		if stats.RequestCount > 0 {
			result = append(result, *stats)
		}
	}

	return result, parsedFields
}
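
// Worked example: the field "7:gemini-pro:prompt" with value "95310" parses
// into groupID 7, model "gemini-pro", counter "prompt", and sets
// PromptTokens = 95310 on the aggregated StatsHourly row for that hour.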

// metricsReporter periodically logs the service's counters.
func (s *AnalyticsService) metricsReporter() {
	defer s.wg.Done()

	ticker := time.NewTicker(5 * time.Minute)
	defer ticker.Stop()

	for {
		select {
		case <-ticker.C:
			s.reportMetrics()
		case <-s.stopChan:
			return
		case <-s.ctx.Done():
			return
		}
	}
}

func (s *AnalyticsService) reportMetrics() {
	s.lastFlushMutex.RLock()
	lastFlush := s.lastFlushTime
	s.lastFlushMutex.RUnlock()

	received := s.eventsReceived.Load()
	processed := s.eventsProcessed.Load()
	failed := s.eventsFailed.Load()

	var successRate float64
	if received > 0 {
		successRate = float64(processed) / float64(received) * 100
	}

	s.logger.WithFields(logrus.Fields{
		"events_received":  received,
		"events_processed": processed,
		"events_failed":    failed,
		"success_rate":     fmt.Sprintf("%.2f%%", successRate),
		"flush_count":      s.flushCount.Load(),
		"records_flushed":  s.recordsFlushed.Load(),
		"flush_errors":     s.flushErrors.Load(),
		"last_flush_ago":   time.Since(lastFlush).Round(time.Second),
	}).Info("Analytics metrics")
}

// GetMetrics returns the current counters as a map, for external monitoring.
func (s *AnalyticsService) GetMetrics() map[string]interface{} {
	s.lastFlushMutex.RLock()
	lastFlush := s.lastFlushTime
	s.lastFlushMutex.RUnlock()

	// Read flushInterval under the config mutex to avoid a data race with
	// runtime configuration updates.
	s.configMutex.RLock()
	interval := s.flushInterval
	s.configMutex.RUnlock()

	received := s.eventsReceived.Load()
	processed := s.eventsProcessed.Load()

	var successRate float64
	if received > 0 {
		successRate = float64(processed) / float64(received) * 100
	}

	return map[string]interface{}{
		"events_received":  received,
		"events_processed": processed,
		"events_failed":    s.eventsFailed.Load(),
		"success_rate":     successRate,
		"flush_count":      s.flushCount.Load(),
		"records_flushed":  s.recordsFlushed.Load(),
		"flush_errors":     s.flushErrors.Load(),
		"last_flush_ago":   time.Since(lastFlush).Seconds(),
		"flush_interval":   interval.Seconds(),
	}
}
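
// GetMetrics can be surfaced however the application exposes diagnostics; a
// minimal sketch with net/http (hypothetical route, not part of this package):
//
//	http.HandleFunc("/debug/analytics", func(w http.ResponseWriter, r *http.Request) {
//		w.Header().Set("Content-Type", "application/json")
//		_ = json.NewEncoder(w).Encode(svc.GetMetrics())
//	})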