// Filename: internal/service/stats_service.go
// Filename: internal/service/stats_service.go
|
|
package service
|
|
|
|
import (
|
|
"context"
|
|
"encoding/json"
|
|
"fmt"
|
|
"gemini-balancer/internal/models"
|
|
"gemini-balancer/internal/repository"
|
|
"gemini-balancer/internal/store"
|
|
"time"
|
|
|
|
"github.com/sirupsen/logrus"
|
|
"gorm.io/gorm"
|
|
"gorm.io/gorm/clause"
|
|
)
|
|
|
|
// StatsService maintains cached per-group key statistics and hourly request
// aggregates. It listens for key status change events (see Start) and keeps
// the store-side counters in sync with the database.
type StatsService struct {
	db *gorm.DB // relational database (request logs, hourly stats, key mappings)

	store store.Store // key/value + pub-sub store holding the "stats:group:<id>" hashes

	keyRepo repository.KeyRepository // key persistence layer (injected; not used in this file's visible methods)

	logger *logrus.Entry // component-scoped logger ("component" = "StatsService")

	stopChan chan struct{} // closed by Stop to terminate the event-listener goroutine
}
|
|
|
|
func NewStatsService(db *gorm.DB, s store.Store, repo repository.KeyRepository, logger *logrus.Logger) *StatsService {
|
|
return &StatsService{
|
|
db: db,
|
|
store: s,
|
|
keyRepo: repo,
|
|
logger: logger.WithField("component", "StatsService"),
|
|
stopChan: make(chan struct{}),
|
|
}
|
|
}
|
|
|
|
func (s *StatsService) Start() {
|
|
s.logger.Info("Starting event listener for stats maintenance.")
|
|
sub, err := s.store.Subscribe(context.Background(), models.TopicKeyStatusChanged)
|
|
if err != nil {
|
|
s.logger.Fatalf("Failed to subscribe to topic %s: %v", models.TopicKeyStatusChanged, err)
|
|
return
|
|
}
|
|
go func() {
|
|
defer sub.Close()
|
|
for {
|
|
select {
|
|
case msg := <-sub.Channel():
|
|
var event models.KeyStatusChangedEvent
|
|
if err := json.Unmarshal(msg.Payload, &event); err != nil {
|
|
s.logger.Errorf("Failed to unmarshal KeyStatusChangedEvent: %v", err)
|
|
continue
|
|
}
|
|
s.handleKeyStatusChange(&event)
|
|
case <-s.stopChan:
|
|
s.logger.Info("Stopping stats event listener.")
|
|
return
|
|
}
|
|
}
|
|
}()
|
|
}
|
|
|
|
// Stop signals the event-listener goroutine started by Start to exit.
// NOTE(review): it must be called at most once — closing an already-closed
// channel panics; consider sync.Once if multiple callers may stop the service.
func (s *StatsService) Stop() {
	close(s.stopChan)
}
|
|
|
|
func (s *StatsService) handleKeyStatusChange(event *models.KeyStatusChangedEvent) {
|
|
if event.GroupID == 0 {
|
|
s.logger.Warnf("Received KeyStatusChangedEvent with no GroupID. Reason: %s, KeyID: %d. Skipping.", event.ChangeReason, event.KeyID)
|
|
return
|
|
}
|
|
ctx := context.Background()
|
|
statsKey := fmt.Sprintf("stats:group:%d", event.GroupID)
|
|
s.logger.Infof("Handling key status change for Group %d, KeyID: %d, Reason: %s", event.GroupID, event.KeyID, event.ChangeReason)
|
|
|
|
switch event.ChangeReason {
|
|
case "key_unlinked", "key_hard_deleted":
|
|
if event.OldStatus != "" {
|
|
s.store.HIncrBy(ctx, statsKey, "total_keys", -1)
|
|
s.store.HIncrBy(ctx, statsKey, fmt.Sprintf("%s_keys", event.OldStatus), -1)
|
|
} else {
|
|
s.logger.Warnf("Received '%s' event for group %d without OldStatus, forcing recalculation.", event.ChangeReason, event.GroupID)
|
|
s.RecalculateGroupKeyStats(ctx, event.GroupID)
|
|
}
|
|
case "key_linked":
|
|
if event.NewStatus != "" {
|
|
s.store.HIncrBy(ctx, statsKey, "total_keys", 1)
|
|
s.store.HIncrBy(ctx, statsKey, fmt.Sprintf("%s_keys", event.NewStatus), 1)
|
|
} else {
|
|
s.logger.Warnf("Received 'key_linked' event for group %d without NewStatus, forcing recalculation.", event.GroupID)
|
|
s.RecalculateGroupKeyStats(ctx, event.GroupID)
|
|
}
|
|
case "manual_update", "error_threshold_reached", "key_recovered", "invalid_api_key":
|
|
s.store.HIncrBy(ctx, statsKey, fmt.Sprintf("%s_keys", event.OldStatus), -1)
|
|
s.store.HIncrBy(ctx, statsKey, fmt.Sprintf("%s_keys", event.NewStatus), 1)
|
|
default:
|
|
s.logger.Warnf("Unhandled event reason '%s' for group %d, forcing recalculation.", event.ChangeReason, event.GroupID)
|
|
s.RecalculateGroupKeyStats(ctx, event.GroupID)
|
|
}
|
|
}
|
|
|
|
func (s *StatsService) RecalculateGroupKeyStats(ctx context.Context, groupID uint) error {
|
|
s.logger.Warnf("Performing full recalculation for group %d key stats.", groupID)
|
|
var results []struct {
|
|
Status models.APIKeyStatus
|
|
Count int64
|
|
}
|
|
if err := s.db.WithContext(ctx).Model(&models.GroupAPIKeyMapping{}).
|
|
Where("key_group_id = ?", groupID).
|
|
Select("status, COUNT(*) as count").
|
|
Group("status").
|
|
Scan(&results).Error; err != nil {
|
|
return err
|
|
}
|
|
statsKey := fmt.Sprintf("stats:group:%d", groupID)
|
|
|
|
updates := make(map[string]interface{})
|
|
totalKeys := int64(0)
|
|
for _, res := range results {
|
|
updates[fmt.Sprintf("%s_keys", res.Status)] = res.Count
|
|
totalKeys += res.Count
|
|
}
|
|
updates["total_keys"] = totalKeys
|
|
|
|
if err := s.store.Del(ctx, statsKey); err != nil {
|
|
s.logger.WithError(err).Warnf("Failed to delete stale stats key for group %d before recalculation.", groupID)
|
|
}
|
|
if err := s.store.HSet(ctx, statsKey, updates); err != nil {
|
|
return fmt.Errorf("failed to HSet recalculated stats for group %d: %w", groupID, err)
|
|
}
|
|
s.logger.Infof("Successfully recalculated stats for group %d using HSet.", groupID)
|
|
return nil
|
|
}
|
|
|
|
// GetDashboardStats returns aggregate statistics for the dashboard.
// NOTE(review): currently a stub — it always returns an empty response and a
// nil error; the real aggregation has not been implemented yet.
func (s *StatsService) GetDashboardStats(ctx context.Context) (*models.DashboardStatsResponse, error) {
	return &models.DashboardStatsResponse{}, nil
}
|
|
|
|
func (s *StatsService) AggregateHourlyStats(ctx context.Context) error {
|
|
s.logger.Info("Starting aggregation of the last hour's request data...")
|
|
now := time.Now()
|
|
endTime := now.Truncate(time.Hour)
|
|
startTime := endTime.Add(-1 * time.Hour)
|
|
|
|
s.logger.Infof("Aggregating data for time window: [%s, %s)", startTime.Format(time.RFC3339), endTime.Format(time.RFC3339))
|
|
type aggregationResult struct {
|
|
GroupID uint
|
|
ModelName string
|
|
RequestCount int64
|
|
SuccessCount int64
|
|
PromptTokens int64
|
|
CompletionTokens int64
|
|
}
|
|
var results []aggregationResult
|
|
|
|
err := s.db.WithContext(ctx).Model(&models.RequestLog{}).
|
|
Select("group_id, model_name, COUNT(*) as request_count, SUM(CASE WHEN is_success = true THEN 1 ELSE 0 END) as success_count, SUM(prompt_tokens) as prompt_tokens, SUM(completion_tokens) as completion_tokens").
|
|
Where("request_time >= ? AND request_time < ?", startTime, endTime).
|
|
Group("group_id, model_name").
|
|
Scan(&results).Error
|
|
if err != nil {
|
|
return fmt.Errorf("failed to query aggregation data from request_logs: %w", err)
|
|
}
|
|
if len(results) == 0 {
|
|
s.logger.Info("No request logs found in the last hour to aggregate. Skipping.")
|
|
return nil
|
|
}
|
|
|
|
s.logger.Infof("Found %d aggregated data rows to insert/update.", len(results))
|
|
|
|
var hourlyStats []models.StatsHourly
|
|
for _, res := range results {
|
|
hourlyStats = append(hourlyStats, models.StatsHourly{
|
|
Time: startTime,
|
|
GroupID: res.GroupID,
|
|
ModelName: res.ModelName,
|
|
RequestCount: res.RequestCount,
|
|
SuccessCount: res.SuccessCount,
|
|
PromptTokens: res.PromptTokens,
|
|
CompletionTokens: res.CompletionTokens,
|
|
})
|
|
}
|
|
|
|
return s.db.WithContext(ctx).Clauses(clause.OnConflict{
|
|
Columns: []clause.Column{{Name: "time"}, {Name: "group_id"}, {Name: "model_name"}},
|
|
DoUpdates: clause.AssignmentColumns([]string{"request_count", "success_count", "prompt_tokens", "completion_tokens"}),
|
|
}).Create(&hourlyStats).Error
|
|
}
|