Files
gemini-banlancer/internal/service/stats_service.go
2025-11-20 12:24:05 +08:00

197 lines
6.7 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
// Filename: internal/service/stats_service.go
package service
import (
"encoding/json"
"fmt"
"gemini-balancer/internal/models"
"gemini-balancer/internal/repository"
"gemini-balancer/internal/store"
"time"
"github.com/sirupsen/logrus"
"gorm.io/gorm"
"gorm.io/gorm/clause"
)
// StatsService keeps per-group key counters up to date in the store in
// response to key status events, and aggregates request logs from the
// database into hourly statistics.
type StatsService struct {
// db is the GORM handle used for the key-mapping and request-log queries
// and for the hourly statistics upsert.
db *gorm.DB
// store provides the event subscription and the hash operations
// (HIncrBy, HSet, Del) used for the "stats:group:<id>" counters.
store store.Store
// keyRepo is injected by the constructor; it is not used by the methods
// visible in this file.
keyRepo repository.KeyRepository
// logger is a logrus entry tagged with component=StatsService.
logger *logrus.Entry
// stopChan is closed by Stop to tell the listener goroutine started by
// Start to exit.
stopChan chan struct{}
}
// NewStatsService wires a StatsService from its dependencies.
//
// The supplied logger is wrapped in an entry carrying the field
// component=StatsService, and a fresh unbuffered stop channel is created so
// that Stop can signal the listener goroutine launched by Start.
func NewStatsService(db *gorm.DB, s store.Store, repo repository.KeyRepository, logger *logrus.Logger) *StatsService {
	svc := new(StatsService)
	svc.db = db
	svc.store = s
	svc.keyRepo = repo
	svc.logger = logger.WithField("component", "StatsService")
	svc.stopChan = make(chan struct{})
	return svc
}
// Start subscribes to models.TopicKeyStatusChanged and launches a background
// goroutine that decodes every message into a KeyStatusChangedEvent and hands
// it to handleKeyStatusChange.
//
// The goroutine closes the subscription and exits when either Stop is called
// or the subscription channel is closed by the store. Messages whose payload
// cannot be unmarshalled are logged and skipped.
//
// A subscription failure is reported through logger.Fatalf (which, with
// logrus defaults, terminates the process).
func (s *StatsService) Start() {
	s.logger.Info("Starting event listener for stats maintenance.")
	sub, err := s.store.Subscribe(models.TopicKeyStatusChanged)
	if err != nil {
		s.logger.Fatalf("Failed to subscribe to topic %s: %v", models.TopicKeyStatusChanged, err)
		return
	}
	go func() {
		defer sub.Close()
		for {
			select {
			case msg, ok := <-sub.Channel():
				if !ok {
					// A closed channel is always ready to receive and yields
					// zero values; without this check the loop would either
					// spin forever or dereference a nil message.
					s.logger.Warn("Stats event subscription channel closed; stopping listener.")
					return
				}
				var event models.KeyStatusChangedEvent
				if err := json.Unmarshal(msg.Payload, &event); err != nil {
					s.logger.Errorf("Failed to unmarshal KeyStatusChangedEvent: %v", err)
					continue
				}
				s.handleKeyStatusChange(&event)
			case <-s.stopChan:
				s.logger.Info("Stopping stats event listener.")
				return
			}
		}
	}()
}
// Stop signals the listener goroutine started by Start to exit by closing
// stopChan.
//
// Calling Stop again after it has already run is a no-op instead of panicking
// with "close of closed channel". The check is not synchronized, so Stop must
// still not be invoked from several goroutines at the same time.
func (s *StatsService) Stop() {
	select {
	case <-s.stopChan:
		// Already closed by an earlier Stop call; nothing left to do.
	default:
		close(s.stopChan)
	}
}
// handleKeyStatusChange applies one key status event to the cached counters
// stored in the hash "stats:group:<GroupID>".
//
// Behaviour per ChangeReason:
//   - "key_unlinked", "key_hard_deleted": decrement total_keys and the
//     "<OldStatus>_keys" counter.
//   - "key_linked": increment total_keys and the "<NewStatus>_keys" counter.
//   - "manual_update", "error_threshold_reached", "key_recovered",
//     "invalid_api_key": move one key from "<OldStatus>_keys" to
//     "<NewStatus>_keys".
//   - anything else: full recalculation from the database.
//
// Whenever a status needed for an incremental update is missing, the counters
// are rebuilt with RecalculateGroupKeyStats instead, so that an empty status
// never produces a bogus "_keys" field. Events without a GroupID are skipped.
// Recalculation failures are logged; nothing is returned to the caller.
func (s *StatsService) handleKeyStatusChange(event *models.KeyStatusChangedEvent) {
	if event.GroupID == 0 {
		s.logger.Warnf("Received KeyStatusChangedEvent with no GroupID. Reason: %s, KeyID: %d. Skipping.", event.ChangeReason, event.KeyID)
		return
	}
	statsKey := fmt.Sprintf("stats:group:%d", event.GroupID)
	s.logger.Infof("Handling key status change for Group %d, KeyID: %d, Reason: %s", event.GroupID, event.KeyID, event.ChangeReason)

	// recalculate rebuilds the group's counters and surfaces any failure in
	// the log; previously the returned error was silently discarded.
	recalculate := func() {
		if err := s.RecalculateGroupKeyStats(event.GroupID); err != nil {
			s.logger.WithError(err).Errorf("Failed to recalculate key stats for group %d.", event.GroupID)
		}
	}

	switch event.ChangeReason {
	case "key_unlinked", "key_hard_deleted":
		if event.OldStatus == "" {
			s.logger.Warnf("Received '%s' event for group %d without OldStatus, forcing recalculation.", event.ChangeReason, event.GroupID)
			recalculate()
			return
		}
		s.store.HIncrBy(statsKey, "total_keys", -1)
		s.store.HIncrBy(statsKey, fmt.Sprintf("%s_keys", event.OldStatus), -1)
	case "key_linked":
		if event.NewStatus == "" {
			s.logger.Warnf("Received 'key_linked' event for group %d without NewStatus, forcing recalculation.", event.GroupID)
			recalculate()
			return
		}
		s.store.HIncrBy(statsKey, "total_keys", 1)
		s.store.HIncrBy(statsKey, fmt.Sprintf("%s_keys", event.NewStatus), 1)
	case "manual_update", "error_threshold_reached", "key_recovered", "invalid_api_key":
		if event.OldStatus == "" || event.NewStatus == "" {
			// Without both statuses the incremental update would touch the
			// meaningless "_keys" field and leave the real counters wrong.
			s.logger.Warnf("Received '%s' event for group %d without OldStatus or NewStatus, forcing recalculation.", event.ChangeReason, event.GroupID)
			recalculate()
			return
		}
		s.store.HIncrBy(statsKey, fmt.Sprintf("%s_keys", event.OldStatus), -1)
		s.store.HIncrBy(statsKey, fmt.Sprintf("%s_keys", event.NewStatus), 1)
	default:
		s.logger.Warnf("Unhandled event reason '%s' for group %d, forcing recalculation.", event.ChangeReason, event.GroupID)
		recalculate()
	}
}
// RecalculateGroupKeyStats rebuilds the cached key counters of one group from
// the database.
//
// It counts the GroupAPIKeyMapping rows of the group per status, then replaces
// the hash "stats:group:<groupID>" with one "<status>_keys" field per status
// plus a "total_keys" field holding the sum. A failure to delete the old hash
// is only logged; a query failure or a failure to write the new hash is
// returned.
func (s *StatsService) RecalculateGroupKeyStats(groupID uint) error {
	s.logger.Warnf("Performing full recalculation for group %d key stats.", groupID)

	// Field names are matched to the selected columns, so they stay as-is.
	var rows []struct {
		Status models.APIKeyStatus
		Count  int64
	}
	query := s.db.Model(&models.GroupAPIKeyMapping{}).
		Where("key_group_id = ?", groupID).
		Select("status, COUNT(*) as count").
		Group("status")
	if err := query.Scan(&rows).Error; err != nil {
		return err
	}

	var total int64
	fields := make(map[string]interface{}, len(rows)+1)
	for i := range rows {
		fields[fmt.Sprintf("%s_keys", rows[i].Status)] = rows[i].Count
		total += rows[i].Count
	}
	fields["total_keys"] = total

	cacheKey := fmt.Sprintf("stats:group:%d", groupID)
	// Drop the old hash first so counters for statuses that no longer occur
	// do not linger; failing to do so is tolerated.
	if delErr := s.store.Del(cacheKey); delErr != nil {
		s.logger.WithError(delErr).Warnf("Failed to delete stale stats key for group %d before recalculation.", groupID)
	}
	if setErr := s.store.HSet(cacheKey, fields); setErr != nil {
		return fmt.Errorf("failed to HSet recalculated stats for group %d: %w", groupID, setErr)
	}

	s.logger.Infof("Successfully recalculated stats for group %d using HSet.", groupID)
	return nil
}
// GetDashboardStats is a placeholder that currently returns an empty
// DashboardStatsResponse and a nil error.
func (s *StatsService) GetDashboardStats() (*models.DashboardStatsResponse, error) {
	// TODO intended logic:
	//  1. Fetch the key statistics of every group from Redis (HGetAll).
	//  2. Fetch the request count and error rate of the past 24 hours from
	//     the stats_hourly table.
	//  3. Combine both into a DashboardStatsResponse.
	// The concrete implementation can live in DashboardQueryService; here we
	// first make sure StatsService's core duty, maintaining the cache, is done.
	// An empty object is returned for now so that the code compiles.
	// Pseudocode:
	//   keyCounts, _ := s.store.HGetAll("stats:global:keys")
	//   ...
	empty := new(models.DashboardStatsResponse)
	return empty, nil
}
// AggregateHourlyStats rolls up the request logs of the most recently
// completed hour into StatsHourly rows.
//
// The window is [previous full hour, current full hour). Logs are grouped by
// group_id and model_name; each group yields one StatsHourly record stamped
// with the window start. Records are written with an upsert keyed on
// (time, group_id, model_name) that overwrites the four counter columns.
// When the window contains no logs, nothing is written and nil is returned.
func (s *StatsService) AggregateHourlyStats() error {
	s.logger.Info("Starting aggregation of the last hour's request data...")

	// e.g. now = 15:23 -> windowEnd = 15:00, windowStart = 14:00
	windowEnd := time.Now().Truncate(time.Hour)
	windowStart := windowEnd.Add(-time.Hour)
	s.logger.Infof("Aggregating data for time window: [%s, %s)", windowStart.Format(time.RFC3339), windowEnd.Format(time.RFC3339))

	// Field names are matched to the column aliases of the SELECT below.
	type hourlyRow struct {
		GroupID          uint
		ModelName        string
		RequestCount     int64
		SuccessCount     int64
		PromptTokens     int64
		CompletionTokens int64
	}
	const selectExpr = "group_id, model_name, COUNT(*) as request_count, SUM(CASE WHEN is_success = true THEN 1 ELSE 0 END) as success_count, SUM(prompt_tokens) as prompt_tokens, SUM(completion_tokens) as completion_tokens"

	var rows []hourlyRow
	query := s.db.Model(&models.RequestLog{}).
		Select(selectExpr).
		Where("request_time >= ? AND request_time < ?", windowStart, windowEnd).
		Group("group_id, model_name")
	if err := query.Scan(&rows).Error; err != nil {
		return fmt.Errorf("failed to query aggregation data from request_logs: %w", err)
	}
	if len(rows) == 0 {
		s.logger.Info("No request logs found in the last hour to aggregate. Skipping.")
		return nil
	}
	s.logger.Infof("Found %d aggregated data rows to insert/update.", len(rows))

	records := make([]models.StatsHourly, len(rows))
	for i, row := range rows {
		records[i] = models.StatsHourly{
			Time:             windowStart, // every record is stamped with the start of the hour
			GroupID:          row.GroupID,
			ModelName:        row.ModelName,
			RequestCount:     row.RequestCount,
			SuccessCount:     row.SuccessCount,
			PromptTokens:     row.PromptTokens,
			CompletionTokens: row.CompletionTokens,
		}
	}

	upsert := clause.OnConflict{
		Columns:   []clause.Column{{Name: "time"}, {Name: "group_id"}, {Name: "model_name"}},
		DoUpdates: clause.AssignmentColumns([]string{"request_count", "success_count", "prompt_tokens", "completion_tokens"}),
	}
	return s.db.Clauses(upsert).Create(&records).Error
}