// Filename: internal/service/stats_service.go
package service

import (
	"context"
	"encoding/json"
	"fmt"
	"time"

	"gemini-balancer/internal/models"
	"gemini-balancer/internal/repository"
	"gemini-balancer/internal/store"

	"github.com/sirupsen/logrus"
	"gorm.io/gorm"
	"gorm.io/gorm/clause"
)
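
// StatsService keeps per-group key-status counters in the store in sync with
// key lifecycle events, and rolls raw request logs up into hourly statistics.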
type StatsService struct {
	db       *gorm.DB
	store    store.Store
	keyRepo  repository.KeyRepository
	logger   *logrus.Entry
	stopChan chan struct{}
}
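
// NewStatsService wires the service to the database, the shared store, and the
// key repository, and tags all of its log entries with a component field.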
func NewStatsService(db *gorm.DB, s store.Store, repo repository.KeyRepository, logger *logrus.Logger) *StatsService {
	return &StatsService{
		db:       db,
		store:    s,
		keyRepo:  repo,
		logger:   logger.WithField("component", "StatsService"),
		stopChan: make(chan struct{}),
	}
}
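
// Start launches the background event listener and returns immediately;
// processing continues on a dedicated goroutine until Stop is called.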
func (s *StatsService) Start() {
	s.logger.Info("Starting event listener for stats maintenance.")
	go s.listenForEvents()
}
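
// Stop signals the listener goroutine to shut down.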
func (s *StatsService) Stop() {
	close(s.stopChan)
}
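
// listenForEvents subscribes to key-status-change events and resubscribes,
// with a 5-second backoff, whenever the subscription fails or its channel
// closes.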
func (s *StatsService) listenForEvents() {
	for {
		select {
		case <-s.stopChan:
			s.logger.Info("Stopping stats event listener.")
			return
		default:
		}
		ctx, cancel := context.WithCancel(context.Background())
		sub, err := s.store.Subscribe(ctx, models.TopicKeyStatusChanged)
		if err != nil {
			s.logger.Errorf("Failed to subscribe: %v, retrying in 5s", err)
			cancel()
			// Back off before retrying, but abort promptly if Stop is called.
			select {
			case <-s.stopChan:
				s.logger.Info("Stopping stats event listener.")
				return
			case <-time.After(5 * time.Second):
			}
			continue
		}
		s.logger.Info("Subscribed to key status changes")
		s.handleSubscription(sub, cancel)
	}
}
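
// handleSubscription drains messages from a single subscription until the
// channel closes or the service is stopped.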
func (s *StatsService) handleSubscription(sub store.Subscription, cancel context.CancelFunc) {
	defer sub.Close()
	defer cancel()
	for {
		select {
		case msg, ok := <-sub.Channel():
			if !ok {
				// The subscription channel closed (e.g. connection loss);
				// return so listenForEvents can resubscribe instead of
				// spinning on a closed channel.
				return
			}
			var event models.KeyStatusChangedEvent
			if err := json.Unmarshal(msg.Payload, &event); err != nil {
				s.logger.Errorf("Failed to unmarshal event: %v", err)
				continue
			}
			s.handleKeyStatusChange(&event)
		case <-s.stopChan:
			return
		}
	}
}
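
// handleKeyStatusChange applies an incremental update to the group's counter
// hash for known transitions, and falls back to a full recalculation whenever
// an event is incomplete or a store update fails.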
func (s *StatsService) handleKeyStatusChange(event *models.KeyStatusChangedEvent) {
	if event.GroupID == 0 {
		s.logger.Warnf("Received KeyStatusChangedEvent with no GroupID. Reason: %s, KeyID: %d. Skipping.", event.ChangeReason, event.KeyID)
		return
	}
	ctx := context.Background()
	statsKey := fmt.Sprintf("stats:group:%d", event.GroupID)
	s.logger.Infof("Handling key status change for Group %d, KeyID: %d, Reason: %s", event.GroupID, event.KeyID, event.ChangeReason)
	switch event.ChangeReason {
	case "key_unlinked", "key_hard_deleted":
		if event.OldStatus != "" {
			if _, err := s.store.HIncrBy(ctx, statsKey, "total_keys", -1); err != nil {
				s.logger.WithError(err).Errorf("Failed to decrement total_keys for group %d", event.GroupID)
				s.recalcAndLog(ctx, event.GroupID)
				return
			}
			if _, err := s.store.HIncrBy(ctx, statsKey, fmt.Sprintf("%s_keys", event.OldStatus), -1); err != nil {
				s.logger.WithError(err).Errorf("Failed to decrement %s_keys for group %d", event.OldStatus, event.GroupID)
				s.recalcAndLog(ctx, event.GroupID)
				return
			}
		} else {
			s.logger.Warnf("Received '%s' event for group %d without OldStatus, forcing recalculation.", event.ChangeReason, event.GroupID)
			s.recalcAndLog(ctx, event.GroupID)
		}
	case "key_linked":
		if event.NewStatus != "" {
			if _, err := s.store.HIncrBy(ctx, statsKey, "total_keys", 1); err != nil {
				s.logger.WithError(err).Errorf("Failed to increment total_keys for group %d", event.GroupID)
				s.recalcAndLog(ctx, event.GroupID)
				return
			}
			if _, err := s.store.HIncrBy(ctx, statsKey, fmt.Sprintf("%s_keys", event.NewStatus), 1); err != nil {
				s.logger.WithError(err).Errorf("Failed to increment %s_keys for group %d", event.NewStatus, event.GroupID)
				s.recalcAndLog(ctx, event.GroupID)
				return
			}
		} else {
			s.logger.Warnf("Received 'key_linked' event for group %d without NewStatus, forcing recalculation.", event.GroupID)
			s.recalcAndLog(ctx, event.GroupID)
		}
	case "manual_update", "error_threshold_reached", "key_recovered", "invalid_api_key":
		// A status transition within the group: move one key from the old
		// bucket to the new one; total_keys is unchanged.
		if _, err := s.store.HIncrBy(ctx, statsKey, fmt.Sprintf("%s_keys", event.OldStatus), -1); err != nil {
			s.logger.WithError(err).Errorf("Failed to decrement %s_keys for group %d", event.OldStatus, event.GroupID)
			s.recalcAndLog(ctx, event.GroupID)
			return
		}
		if _, err := s.store.HIncrBy(ctx, statsKey, fmt.Sprintf("%s_keys", event.NewStatus), 1); err != nil {
			s.logger.WithError(err).Errorf("Failed to increment %s_keys for group %d", event.NewStatus, event.GroupID)
			s.recalcAndLog(ctx, event.GroupID)
			return
		}
	default:
		s.logger.Warnf("Unhandled event reason '%s' for group %d, forcing recalculation.", event.ChangeReason, event.GroupID)
		s.recalcAndLog(ctx, event.GroupID)
	}
}
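
// recalcAndLog wraps RecalculateGroupKeyStats for the event-handler paths,
// which have no caller to return an error to; failures are logged instead of
// being silently dropped.
func (s *StatsService) recalcAndLog(ctx context.Context, groupID uint) {
	if err := s.RecalculateGroupKeyStats(ctx, groupID); err != nil {
		s.logger.WithError(err).Errorf("Failed to recalculate key stats for group %d", groupID)
	}
}

// RecalculateGroupKeyStats rebuilds a group's counter hash from the database:
// key mappings are grouped by status, the stale hash is deleted, and fresh
// totals are written back in a single HSet.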
func (s *StatsService) RecalculateGroupKeyStats(ctx context.Context, groupID uint) error {
	s.logger.Warnf("Performing full recalculation for group %d key stats.", groupID)
	var results []struct {
		Status models.APIKeyStatus
		Count  int64
	}
	if err := s.db.WithContext(ctx).Model(&models.GroupAPIKeyMapping{}).
		Where("key_group_id = ?", groupID).
		Select("status, COUNT(*) as count").
		Group("status").
		Scan(&results).Error; err != nil {
		return fmt.Errorf("failed to count key statuses for group %d: %w", groupID, err)
	}
	statsKey := fmt.Sprintf("stats:group:%d", groupID)
	updates := map[string]interface{}{
		"active_keys":   int64(0),
		"disabled_keys": int64(0),
		"error_keys":    int64(0),
		"total_keys":    int64(0),
	}
	for _, res := range results {
		updates[fmt.Sprintf("%s_keys", res.Status)] = res.Count
		updates["total_keys"] = updates["total_keys"].(int64) + res.Count
	}
	// Delete the stale hash first so counters for statuses that no longer
	// occur in the group do not linger.
	if err := s.store.Del(ctx, statsKey); err != nil {
		s.logger.WithError(err).Warnf("Failed to delete stale stats key for group %d before recalculation.", groupID)
	}
	if err := s.store.HSet(ctx, statsKey, updates); err != nil {
		return fmt.Errorf("failed to HSet recalculated stats for group %d: %w", groupID, err)
	}
	s.logger.Infof("Successfully recalculated stats for group %d using HSet.", groupID)
	return nil
}
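
// GetDashboardStats is currently a stub and returns an empty response.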
func (s *StatsService) GetDashboardStats(ctx context.Context) (*models.DashboardStatsResponse, error) {
	return &models.DashboardStatsResponse{}, nil
}
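
// AggregateHourlyStats rolls the previous full hour of request logs up into
// hourly stats rows (upserted on time, group, and model) and then deletes the
// raw logs it has aggregated.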
func (s *StatsService) AggregateHourlyStats(ctx context.Context) error {
	s.logger.Info("Starting aggregation of the last hour's request data...")
	now := time.Now()
	endTime := now.Truncate(time.Hour)
	startTime := endTime.Add(-1 * time.Hour)
	s.logger.Infof("Aggregating data for time window: [%s, %s)", startTime.Format(time.RFC3339), endTime.Format(time.RFC3339))
	type aggregationResult struct {
		GroupID          uint
		ModelName        string
		RequestCount     int64
		SuccessCount     int64
		PromptTokens     int64
		CompletionTokens int64
	}
	var results []aggregationResult
	err := s.db.WithContext(ctx).Model(&models.RequestLog{}).
		Select("group_id, model_name, COUNT(*) as request_count, SUM(CASE WHEN is_success = true THEN 1 ELSE 0 END) as success_count, SUM(prompt_tokens) as prompt_tokens, SUM(completion_tokens) as completion_tokens").
		Where("request_time >= ? AND request_time < ?", startTime, endTime).
		Group("group_id, model_name").
		Scan(&results).Error
	if err != nil {
		return fmt.Errorf("failed to query aggregation data from request_logs: %w", err)
	}
	if len(results) == 0 {
		s.logger.Info("No request logs found in the last hour to aggregate. Skipping.")
		return nil
	}
	s.logger.Infof("Found %d aggregated data rows to insert/update.", len(results))
	hourlyStats := make([]models.StatsHourly, 0, len(results))
	for _, res := range results {
		hourlyStats = append(hourlyStats, models.StatsHourly{
			Time:             startTime,
			GroupID:          res.GroupID,
			ModelName:        res.ModelName,
			RequestCount:     res.RequestCount,
			SuccessCount:     res.SuccessCount,
			PromptTokens:     res.PromptTokens,
			CompletionTokens: res.CompletionTokens,
		})
	}
	if err := s.db.WithContext(ctx).Clauses(clause.OnConflict{
		Columns:   []clause.Column{{Name: "time"}, {Name: "group_id"}, {Name: "model_name"}},
		DoUpdates: clause.AssignmentColumns([]string{"request_count", "success_count", "prompt_tokens", "completion_tokens"}),
	}).Create(&hourlyStats).Error; err != nil {
		return fmt.Errorf("failed to upsert hourly stats: %w", err)
	}
	// A failed delete is non-fatal: leftover rows fall inside an already
	// aggregated window, and the upsert above assigns absolute values, so
	// re-aggregating them is idempotent.
	if err := s.db.WithContext(ctx).
		Where("request_time >= ? AND request_time < ?", startTime, endTime).
		Delete(&models.RequestLog{}).Error; err != nil {
		s.logger.WithError(err).Warn("Failed to delete aggregated request logs")
	}
	return nil
}