// Filename: internal/service/stats_service.go package service import ( "context" "encoding/json" "fmt" "gemini-balancer/internal/models" "gemini-balancer/internal/repository" "gemini-balancer/internal/store" "time" "github.com/sirupsen/logrus" "gorm.io/gorm" "gorm.io/gorm/clause" ) type StatsService struct { db *gorm.DB store store.Store keyRepo repository.KeyRepository logger *logrus.Entry stopChan chan struct{} } func NewStatsService(db *gorm.DB, s store.Store, repo repository.KeyRepository, logger *logrus.Logger) *StatsService { return &StatsService{ db: db, store: s, keyRepo: repo, logger: logger.WithField("component", "StatsService"), stopChan: make(chan struct{}), } } func (s *StatsService) Start() { s.logger.Info("Starting event listener for stats maintenance.") sub, err := s.store.Subscribe(context.Background(), models.TopicKeyStatusChanged) if err != nil { s.logger.Fatalf("Failed to subscribe to topic %s: %v", models.TopicKeyStatusChanged, err) return } go func() { defer sub.Close() for { select { case msg := <-sub.Channel(): var event models.KeyStatusChangedEvent if err := json.Unmarshal(msg.Payload, &event); err != nil { s.logger.Errorf("Failed to unmarshal KeyStatusChangedEvent: %v", err) continue } s.handleKeyStatusChange(&event) case <-s.stopChan: s.logger.Info("Stopping stats event listener.") return } } }() } func (s *StatsService) Stop() { close(s.stopChan) } func (s *StatsService) handleKeyStatusChange(event *models.KeyStatusChangedEvent) { if event.GroupID == 0 { s.logger.Warnf("Received KeyStatusChangedEvent with no GroupID. Reason: %s, KeyID: %d. Skipping.", event.ChangeReason, event.KeyID) return } ctx := context.Background() statsKey := fmt.Sprintf("stats:group:%d", event.GroupID) s.logger.Infof("Handling key status change for Group %d, KeyID: %d, Reason: %s", event.GroupID, event.KeyID, event.ChangeReason) switch event.ChangeReason { case "key_unlinked", "key_hard_deleted": if event.OldStatus != "" { s.store.HIncrBy(ctx, statsKey, "total_keys", -1) s.store.HIncrBy(ctx, statsKey, fmt.Sprintf("%s_keys", event.OldStatus), -1) } else { s.logger.Warnf("Received '%s' event for group %d without OldStatus, forcing recalculation.", event.ChangeReason, event.GroupID) s.RecalculateGroupKeyStats(ctx, event.GroupID) } case "key_linked": if event.NewStatus != "" { s.store.HIncrBy(ctx, statsKey, "total_keys", 1) s.store.HIncrBy(ctx, statsKey, fmt.Sprintf("%s_keys", event.NewStatus), 1) } else { s.logger.Warnf("Received 'key_linked' event for group %d without NewStatus, forcing recalculation.", event.GroupID) s.RecalculateGroupKeyStats(ctx, event.GroupID) } case "manual_update", "error_threshold_reached", "key_recovered", "invalid_api_key": s.store.HIncrBy(ctx, statsKey, fmt.Sprintf("%s_keys", event.OldStatus), -1) s.store.HIncrBy(ctx, statsKey, fmt.Sprintf("%s_keys", event.NewStatus), 1) default: s.logger.Warnf("Unhandled event reason '%s' for group %d, forcing recalculation.", event.ChangeReason, event.GroupID) s.RecalculateGroupKeyStats(ctx, event.GroupID) } } func (s *StatsService) RecalculateGroupKeyStats(ctx context.Context, groupID uint) error { s.logger.Warnf("Performing full recalculation for group %d key stats.", groupID) var results []struct { Status models.APIKeyStatus Count int64 } if err := s.db.WithContext(ctx).Model(&models.GroupAPIKeyMapping{}). Where("key_group_id = ?", groupID). Select("status, COUNT(*) as count"). Group("status"). Scan(&results).Error; err != nil { return err } statsKey := fmt.Sprintf("stats:group:%d", groupID) updates := make(map[string]interface{}) totalKeys := int64(0) for _, res := range results { updates[fmt.Sprintf("%s_keys", res.Status)] = res.Count totalKeys += res.Count } updates["total_keys"] = totalKeys if err := s.store.Del(ctx, statsKey); err != nil { s.logger.WithError(err).Warnf("Failed to delete stale stats key for group %d before recalculation.", groupID) } if err := s.store.HSet(ctx, statsKey, updates); err != nil { return fmt.Errorf("failed to HSet recalculated stats for group %d: %w", groupID, err) } s.logger.Infof("Successfully recalculated stats for group %d using HSet.", groupID) return nil } func (s *StatsService) GetDashboardStats(ctx context.Context) (*models.DashboardStatsResponse, error) { return &models.DashboardStatsResponse{}, nil } func (s *StatsService) AggregateHourlyStats(ctx context.Context) error { s.logger.Info("Starting aggregation of the last hour's request data...") now := time.Now() endTime := now.Truncate(time.Hour) startTime := endTime.Add(-1 * time.Hour) s.logger.Infof("Aggregating data for time window: [%s, %s)", startTime.Format(time.RFC3339), endTime.Format(time.RFC3339)) type aggregationResult struct { GroupID uint ModelName string RequestCount int64 SuccessCount int64 PromptTokens int64 CompletionTokens int64 } var results []aggregationResult err := s.db.WithContext(ctx).Model(&models.RequestLog{}). Select("group_id, model_name, COUNT(*) as request_count, SUM(CASE WHEN is_success = true THEN 1 ELSE 0 END) as success_count, SUM(prompt_tokens) as prompt_tokens, SUM(completion_tokens) as completion_tokens"). Where("request_time >= ? AND request_time < ?", startTime, endTime). Group("group_id, model_name"). Scan(&results).Error if err != nil { return fmt.Errorf("failed to query aggregation data from request_logs: %w", err) } if len(results) == 0 { s.logger.Info("No request logs found in the last hour to aggregate. Skipping.") return nil } s.logger.Infof("Found %d aggregated data rows to insert/update.", len(results)) var hourlyStats []models.StatsHourly for _, res := range results { hourlyStats = append(hourlyStats, models.StatsHourly{ Time: startTime, GroupID: res.GroupID, ModelName: res.ModelName, RequestCount: res.RequestCount, SuccessCount: res.SuccessCount, PromptTokens: res.PromptTokens, CompletionTokens: res.CompletionTokens, }) } return s.db.WithContext(ctx).Clauses(clause.OnConflict{ Columns: []clause.Column{{Name: "time"}, {Name: "group_id"}, {Name: "model_name"}}, DoUpdates: clause.AssignmentColumns([]string{"request_count", "success_count", "prompt_tokens", "completion_tokens"}), }).Create(&hourlyStats).Error }