// Filename: internal/service/stats_service.go

package service

import (
	"encoding/json"
	"fmt"
	"time"

	"gemini-balancer/internal/models"
	"gemini-balancer/internal/repository"
	"gemini-balancer/internal/store"

	"github.com/sirupsen/logrus"
	"gorm.io/gorm"
	"gorm.io/gorm/clause"
)
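
// StatsService keeps per-group API key counters in the cache store up to
// date by reacting to key lifecycle events, and rolls raw request logs up
// into hourly aggregates.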
type StatsService struct {
	db       *gorm.DB
	store    store.Store
	keyRepo  repository.KeyRepository
	logger   *logrus.Entry
	stopChan chan struct{}
}
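
// NewStatsService wires the service to its database, cache store, and key
// repository, and scopes all of its log entries to this component.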
func NewStatsService(db *gorm.DB, s store.Store, repo repository.KeyRepository, logger *logrus.Logger) *StatsService {
	return &StatsService{
		db:       db,
		store:    s,
		keyRepo:  repo,
		logger:   logger.WithField("component", "StatsService"),
		stopChan: make(chan struct{}),
	}
}
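
// Start subscribes to key status change events and dispatches them on a
// background goroutine until Stop is called. A sketch of typical wiring at
// application startup (db, st, keyRepo, and logger are assumed to be
// constructed elsewhere):
//
//	svc := NewStatsService(db, st, keyRepo, logger)
//	svc.Start()
//	defer svc.Stop()
//
//	ticker := time.NewTicker(time.Hour)
//	defer ticker.Stop()
//	go func() {
//		for range ticker.C {
//			if err := svc.AggregateHourlyStats(); err != nil {
//				logger.WithError(err).Error("Hourly stats aggregation failed.")
//			}
//		}
//	}()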
func (s *StatsService) Start() {
	s.logger.Info("Starting event listener for stats maintenance.")
	sub, err := s.store.Subscribe(models.TopicKeyStatusChanged)
	if err != nil {
		// Fatalf exits the process: a failed subscription would leave the
		// cached stats permanently stale, so failing fast is deliberate.
		s.logger.Fatalf("Failed to subscribe to topic %s: %v", models.TopicKeyStatusChanged, err)
	}
	go func() {
		defer sub.Close()
		for {
			select {
			case msg, ok := <-sub.Channel():
				if !ok {
					// The subscription channel was closed underneath us;
					// without this check a closed channel would spin the
					// loop on zero-value messages.
					s.logger.Warn("Subscription channel closed, stopping stats event listener.")
					return
				}
				var event models.KeyStatusChangedEvent
				if err := json.Unmarshal(msg.Payload, &event); err != nil {
					s.logger.Errorf("Failed to unmarshal KeyStatusChangedEvent: %v", err)
					continue
				}
				s.handleKeyStatusChange(&event)
			case <-s.stopChan:
				s.logger.Info("Stopping stats event listener.")
				return
			}
		}
	}()
}
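
// Stop signals the event listener goroutine started by Start to exit.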
func (s *StatsService) Stop() {
	close(s.stopChan)
}
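
// handleKeyStatusChange applies an incremental delta to the group's cached
// counters for the event reasons it recognizes, and falls back to a full
// recalculation whenever the event does not carry enough information to
// adjust the counters safely.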
func (s *StatsService) handleKeyStatusChange(event *models.KeyStatusChangedEvent) {
	if event.GroupID == 0 {
		s.logger.Warnf("Received KeyStatusChangedEvent with no GroupID. Reason: %s, KeyID: %d. Skipping.", event.ChangeReason, event.KeyID)
		return
	}
	statsKey := fmt.Sprintf("stats:group:%d", event.GroupID)
	s.logger.Infof("Handling key status change for Group %d, KeyID: %d, Reason: %s", event.GroupID, event.KeyID, event.ChangeReason)

	switch event.ChangeReason {
	case "key_unlinked", "key_hard_deleted":
		if event.OldStatus != "" {
			s.store.HIncrBy(statsKey, "total_keys", -1)
			s.store.HIncrBy(statsKey, fmt.Sprintf("%s_keys", event.OldStatus), -1)
		} else {
			s.logger.Warnf("Received '%s' event for group %d without OldStatus, forcing recalculation.", event.ChangeReason, event.GroupID)
			if err := s.RecalculateGroupKeyStats(event.GroupID); err != nil {
				s.logger.WithError(err).Errorf("Recalculation failed for group %d.", event.GroupID)
			}
		}
	case "key_linked":
		if event.NewStatus != "" {
			s.store.HIncrBy(statsKey, "total_keys", 1)
			s.store.HIncrBy(statsKey, fmt.Sprintf("%s_keys", event.NewStatus), 1)
		} else {
			s.logger.Warnf("Received 'key_linked' event for group %d without NewStatus, forcing recalculation.", event.GroupID)
			if err := s.RecalculateGroupKeyStats(event.GroupID); err != nil {
				s.logger.WithError(err).Errorf("Recalculation failed for group %d.", event.GroupID)
			}
		}
	case "manual_update", "error_threshold_reached", "key_recovered", "invalid_api_key":
		// A status transition within the group: move one key from the old
		// bucket to the new one; total_keys is unchanged.
		s.store.HIncrBy(statsKey, fmt.Sprintf("%s_keys", event.OldStatus), -1)
		s.store.HIncrBy(statsKey, fmt.Sprintf("%s_keys", event.NewStatus), 1)
	default:
		s.logger.Warnf("Unhandled event reason '%s' for group %d, forcing recalculation.", event.ChangeReason, event.GroupID)
		if err := s.RecalculateGroupKeyStats(event.GroupID); err != nil {
			s.logger.WithError(err).Errorf("Recalculation failed for group %d.", event.GroupID)
		}
	}
}
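
// RecalculateGroupKeyStats rebuilds a group's cached counters from scratch by
// counting key statuses in the database. It is the self-healing fallback for
// events the incremental path cannot handle.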
func (s *StatsService) RecalculateGroupKeyStats(groupID uint) error {
	s.logger.Warnf("Performing full recalculation for group %d key stats.", groupID)
	var results []struct {
		Status models.APIKeyStatus
		Count  int64
	}
	if err := s.db.Model(&models.GroupAPIKeyMapping{}).
		Where("key_group_id = ?", groupID).
		Select("status, COUNT(*) as count").
		Group("status").
		Scan(&results).Error; err != nil {
		return fmt.Errorf("failed to count key statuses for group %d: %w", groupID, err)
	}
	statsKey := fmt.Sprintf("stats:group:%d", groupID)

	updates := make(map[string]interface{})
	totalKeys := int64(0)
	for _, res := range results {
		updates[fmt.Sprintf("%s_keys", res.Status)] = res.Count
		totalKeys += res.Count
	}
	updates["total_keys"] = totalKeys

	// Delete first so stale per-status fields from a previous state do not
	// linger in the hash after the fresh counts are written.
	if err := s.store.Del(statsKey); err != nil {
		s.logger.WithError(err).Warnf("Failed to delete stale stats key for group %d before recalculation.", groupID)
	}
	if err := s.store.HSet(statsKey, updates); err != nil {
		return fmt.Errorf("failed to HSet recalculated stats for group %d: %w", groupID, err)
	}
	s.logger.Infof("Successfully recalculated stats for group %d using HSet.", groupID)
	return nil
}
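
// GetDashboardStats will assemble the dashboard response from the cached key
// counters and the stats_hourly table; for now it returns an empty response
// (see the TODO below).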
func (s *StatsService) GetDashboardStats() (*models.DashboardStatsResponse, error) {
	// TODO:
	// 1. Read every group's key counters from Redis (HGetAll).
	// 2. Pull request counts and error rates for the past 24 hours from the
	//    stats_hourly table.
	// 3. Combine both into a DashboardStatsResponse.
	// The concrete implementation can be completed in DashboardQueryService;
	// here we only need StatsService's core responsibility (maintaining the
	// cache) to be covered. An empty object is returned so the code compiles.

	// Pseudocode:
	// keyCounts, _ := s.store.HGetAll("stats:global:keys")
	// ...

	return &models.DashboardStatsResponse{}, nil
}
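
// AggregateHourlyStats rolls the previous full hour of request_logs up into
// stats_hourly, one row per (group, model) pair, using an upsert so the job
// can be re-run for the same hour without duplicating rows.
//
// Note: the OnConflict clause below relies on a composite unique index over
// (time, group_id, model_name) on stats_hourly. As an illustration only (the
// real definition lives in internal/models), GORM tags providing such an
// index would look like:
//
//	type StatsHourly struct {
//		Time      time.Time `gorm:"uniqueIndex:idx_stats_hourly"`
//		GroupID   uint      `gorm:"uniqueIndex:idx_stats_hourly"`
//		ModelName string    `gorm:"uniqueIndex:idx_stats_hourly"`
//		// request_count, success_count, prompt_tokens, completion_tokens ...
//	}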
func (s *StatsService) AggregateHourlyStats() error {
	s.logger.Info("Starting aggregation of the last hour's request data...")
	now := time.Now()
	endTime := now.Truncate(time.Hour)       // e.g. 15:23 -> 15:00
	startTime := endTime.Add(-1 * time.Hour) // 15:00 -> 14:00

	s.logger.Infof("Aggregating data for time window: [%s, %s)", startTime.Format(time.RFC3339), endTime.Format(time.RFC3339))
	type aggregationResult struct {
		GroupID          uint
		ModelName        string
		RequestCount     int64
		SuccessCount     int64
		PromptTokens     int64
		CompletionTokens int64
	}
	var results []aggregationResult
	err := s.db.Model(&models.RequestLog{}).
		Select("group_id, model_name, COUNT(*) as request_count, SUM(CASE WHEN is_success = true THEN 1 ELSE 0 END) as success_count, SUM(prompt_tokens) as prompt_tokens, SUM(completion_tokens) as completion_tokens").
		Where("request_time >= ? AND request_time < ?", startTime, endTime).
		Group("group_id, model_name").
		Scan(&results).Error
	if err != nil {
		return fmt.Errorf("failed to query aggregation data from request_logs: %w", err)
	}
	if len(results) == 0 {
		s.logger.Info("No request logs found in the last hour to aggregate. Skipping.")
		return nil
	}

	s.logger.Infof("Found %d aggregated data rows to insert/update.", len(results))

	hourlyStats := make([]models.StatsHourly, 0, len(results))
	for _, res := range results {
		hourlyStats = append(hourlyStats, models.StatsHourly{
			Time:             startTime, // every row is stamped with the start of the aggregated hour
			GroupID:          res.GroupID,
			ModelName:        res.ModelName,
			RequestCount:     res.RequestCount,
			SuccessCount:     res.SuccessCount,
			PromptTokens:     res.PromptTokens,
			CompletionTokens: res.CompletionTokens,
		})
	}

	// Upsert: re-running the job for the same hour overwrites the previous
	// aggregates instead of failing on the unique index.
	return s.db.Clauses(clause.OnConflict{
		Columns:   []clause.Column{{Name: "time"}, {Name: "group_id"}, {Name: "model_name"}},
		DoUpdates: clause.AssignmentColumns([]string{"request_count", "success_count", "prompt_tokens", "completion_tokens"}),
	}).Create(&hourlyStats).Error
}