// Filename: internal/service/stats_service.go package service import ( "encoding/json" "fmt" "gemini-balancer/internal/models" "gemini-balancer/internal/repository" "gemini-balancer/internal/store" "time" "github.com/sirupsen/logrus" "gorm.io/gorm" "gorm.io/gorm/clause" ) type StatsService struct { db *gorm.DB store store.Store keyRepo repository.KeyRepository logger *logrus.Entry stopChan chan struct{} } func NewStatsService(db *gorm.DB, s store.Store, repo repository.KeyRepository, logger *logrus.Logger) *StatsService { return &StatsService{ db: db, store: s, keyRepo: repo, logger: logger.WithField("component", "StatsService"), stopChan: make(chan struct{}), } } func (s *StatsService) Start() { s.logger.Info("Starting event listener for stats maintenance.") sub, err := s.store.Subscribe(models.TopicKeyStatusChanged) if err != nil { s.logger.Fatalf("Failed to subscribe to topic %s: %v", models.TopicKeyStatusChanged, err) return } go func() { defer sub.Close() for { select { case msg := <-sub.Channel(): var event models.KeyStatusChangedEvent if err := json.Unmarshal(msg.Payload, &event); err != nil { s.logger.Errorf("Failed to unmarshal KeyStatusChangedEvent: %v", err) continue } s.handleKeyStatusChange(&event) case <-s.stopChan: s.logger.Info("Stopping stats event listener.") return } } }() } func (s *StatsService) Stop() { close(s.stopChan) } func (s *StatsService) handleKeyStatusChange(event *models.KeyStatusChangedEvent) { if event.GroupID == 0 { s.logger.Warnf("Received KeyStatusChangedEvent with no GroupID. Reason: %s, KeyID: %d. Skipping.", event.ChangeReason, event.KeyID) return } statsKey := fmt.Sprintf("stats:group:%d", event.GroupID) s.logger.Infof("Handling key status change for Group %d, KeyID: %d, Reason: %s", event.GroupID, event.KeyID, event.ChangeReason) switch event.ChangeReason { case "key_unlinked", "key_hard_deleted": if event.OldStatus != "" { s.store.HIncrBy(statsKey, "total_keys", -1) s.store.HIncrBy(statsKey, fmt.Sprintf("%s_keys", event.OldStatus), -1) } else { s.logger.Warnf("Received '%s' event for group %d without OldStatus, forcing recalculation.", event.ChangeReason, event.GroupID) s.RecalculateGroupKeyStats(event.GroupID) } case "key_linked": if event.NewStatus != "" { s.store.HIncrBy(statsKey, "total_keys", 1) s.store.HIncrBy(statsKey, fmt.Sprintf("%s_keys", event.NewStatus), 1) } else { s.logger.Warnf("Received 'key_linked' event for group %d without NewStatus, forcing recalculation.", event.GroupID) s.RecalculateGroupKeyStats(event.GroupID) } case "manual_update", "error_threshold_reached", "key_recovered", "invalid_api_key": s.store.HIncrBy(statsKey, fmt.Sprintf("%s_keys", event.OldStatus), -1) s.store.HIncrBy(statsKey, fmt.Sprintf("%s_keys", event.NewStatus), 1) default: s.logger.Warnf("Unhandled event reason '%s' for group %d, forcing recalculation.", event.ChangeReason, event.GroupID) s.RecalculateGroupKeyStats(event.GroupID) } } func (s *StatsService) RecalculateGroupKeyStats(groupID uint) error { s.logger.Warnf("Performing full recalculation for group %d key stats.", groupID) var results []struct { Status models.APIKeyStatus Count int64 } if err := s.db.Model(&models.GroupAPIKeyMapping{}). Where("key_group_id = ?", groupID). Select("status, COUNT(*) as count"). Group("status"). Scan(&results).Error; err != nil { return err } statsKey := fmt.Sprintf("stats:group:%d", groupID) updates := make(map[string]interface{}) totalKeys := int64(0) for _, res := range results { updates[fmt.Sprintf("%s_keys", res.Status)] = res.Count totalKeys += res.Count } updates["total_keys"] = totalKeys if err := s.store.Del(statsKey); err != nil { s.logger.WithError(err).Warnf("Failed to delete stale stats key for group %d before recalculation.", groupID) } if err := s.store.HSet(statsKey, updates); err != nil { return fmt.Errorf("failed to HSet recalculated stats for group %d: %w", groupID, err) } s.logger.Infof("Successfully recalculated stats for group %d using HSet.", groupID) return nil } func (s *StatsService) GetDashboardStats() (*models.DashboardStatsResponse, error) { // TODO 逻辑: // 1. 从Redis中获取所有分组的Key统计 (HGetAll) // 2. 从 stats_hourly 表中获取过去24小时的请求数和错误率 // 3. 组合成 DashboardStatsResponse // ... 这个方法的具体实现,我们可以在DashboardQueryService中完成, // 这里我们先确保StatsService的核心职责(维护缓存)已经完成。 // 为了编译通过,我们先返回一个空对象。 // 伪代码: // keyCounts, _ := s.store.HGetAll("stats:global:keys") // ... return &models.DashboardStatsResponse{}, nil } func (s *StatsService) AggregateHourlyStats() error { s.logger.Info("Starting aggregation of the last hour's request data...") now := time.Now() endTime := now.Truncate(time.Hour) // 例如:15:23 -> 15:00 startTime := endTime.Add(-1 * time.Hour) // 15:00 -> 14:00 s.logger.Infof("Aggregating data for time window: [%s, %s)", startTime.Format(time.RFC3339), endTime.Format(time.RFC3339)) type aggregationResult struct { GroupID uint ModelName string RequestCount int64 SuccessCount int64 PromptTokens int64 CompletionTokens int64 } var results []aggregationResult err := s.db.Model(&models.RequestLog{}). Select("group_id, model_name, COUNT(*) as request_count, SUM(CASE WHEN is_success = true THEN 1 ELSE 0 END) as success_count, SUM(prompt_tokens) as prompt_tokens, SUM(completion_tokens) as completion_tokens"). Where("request_time >= ? AND request_time < ?", startTime, endTime). Group("group_id, model_name"). Scan(&results).Error if err != nil { return fmt.Errorf("failed to query aggregation data from request_logs: %w", err) } if len(results) == 0 { s.logger.Info("No request logs found in the last hour to aggregate. Skipping.") return nil } s.logger.Infof("Found %d aggregated data rows to insert/update.", len(results)) var hourlyStats []models.StatsHourly for _, res := range results { hourlyStats = append(hourlyStats, models.StatsHourly{ Time: startTime, // 所有记录的时间戳都是该小时的起点 GroupID: res.GroupID, ModelName: res.ModelName, RequestCount: res.RequestCount, SuccessCount: res.SuccessCount, PromptTokens: res.PromptTokens, CompletionTokens: res.CompletionTokens, }) } return s.db.Clauses(clause.OnConflict{ Columns: []clause.Column{{Name: "time"}, {Name: "group_id"}, {Name: "model_name"}}, DoUpdates: clause.AssignmentColumns([]string{"request_count", "success_count", "prompt_tokens", "completion_tokens"}), }).Create(&hourlyStats).Error }