Update Context for store
This commit is contained in:
@@ -2,6 +2,7 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"gemini-balancer/internal/models"
|
||||
@@ -34,7 +35,7 @@ func NewStatsService(db *gorm.DB, s store.Store, repo repository.KeyRepository,
|
||||
|
||||
func (s *StatsService) Start() {
|
||||
s.logger.Info("Starting event listener for stats maintenance.")
|
||||
sub, err := s.store.Subscribe(models.TopicKeyStatusChanged)
|
||||
sub, err := s.store.Subscribe(context.Background(), models.TopicKeyStatusChanged)
|
||||
if err != nil {
|
||||
s.logger.Fatalf("Failed to subscribe to topic %s: %v", models.TopicKeyStatusChanged, err)
|
||||
return
|
||||
@@ -67,42 +68,43 @@ func (s *StatsService) handleKeyStatusChange(event *models.KeyStatusChangedEvent
|
||||
s.logger.Warnf("Received KeyStatusChangedEvent with no GroupID. Reason: %s, KeyID: %d. Skipping.", event.ChangeReason, event.KeyID)
|
||||
return
|
||||
}
|
||||
ctx := context.Background()
|
||||
statsKey := fmt.Sprintf("stats:group:%d", event.GroupID)
|
||||
s.logger.Infof("Handling key status change for Group %d, KeyID: %d, Reason: %s", event.GroupID, event.KeyID, event.ChangeReason)
|
||||
|
||||
switch event.ChangeReason {
|
||||
case "key_unlinked", "key_hard_deleted":
|
||||
if event.OldStatus != "" {
|
||||
s.store.HIncrBy(statsKey, "total_keys", -1)
|
||||
s.store.HIncrBy(statsKey, fmt.Sprintf("%s_keys", event.OldStatus), -1)
|
||||
s.store.HIncrBy(ctx, statsKey, "total_keys", -1)
|
||||
s.store.HIncrBy(ctx, statsKey, fmt.Sprintf("%s_keys", event.OldStatus), -1)
|
||||
} else {
|
||||
s.logger.Warnf("Received '%s' event for group %d without OldStatus, forcing recalculation.", event.ChangeReason, event.GroupID)
|
||||
s.RecalculateGroupKeyStats(event.GroupID)
|
||||
s.RecalculateGroupKeyStats(ctx, event.GroupID)
|
||||
}
|
||||
case "key_linked":
|
||||
if event.NewStatus != "" {
|
||||
s.store.HIncrBy(statsKey, "total_keys", 1)
|
||||
s.store.HIncrBy(statsKey, fmt.Sprintf("%s_keys", event.NewStatus), 1)
|
||||
s.store.HIncrBy(ctx, statsKey, "total_keys", 1)
|
||||
s.store.HIncrBy(ctx, statsKey, fmt.Sprintf("%s_keys", event.NewStatus), 1)
|
||||
} else {
|
||||
s.logger.Warnf("Received 'key_linked' event for group %d without NewStatus, forcing recalculation.", event.GroupID)
|
||||
s.RecalculateGroupKeyStats(event.GroupID)
|
||||
s.RecalculateGroupKeyStats(ctx, event.GroupID)
|
||||
}
|
||||
case "manual_update", "error_threshold_reached", "key_recovered", "invalid_api_key":
|
||||
s.store.HIncrBy(statsKey, fmt.Sprintf("%s_keys", event.OldStatus), -1)
|
||||
s.store.HIncrBy(statsKey, fmt.Sprintf("%s_keys", event.NewStatus), 1)
|
||||
s.store.HIncrBy(ctx, statsKey, fmt.Sprintf("%s_keys", event.OldStatus), -1)
|
||||
s.store.HIncrBy(ctx, statsKey, fmt.Sprintf("%s_keys", event.NewStatus), 1)
|
||||
default:
|
||||
s.logger.Warnf("Unhandled event reason '%s' for group %d, forcing recalculation.", event.ChangeReason, event.GroupID)
|
||||
s.RecalculateGroupKeyStats(event.GroupID)
|
||||
s.RecalculateGroupKeyStats(ctx, event.GroupID)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *StatsService) RecalculateGroupKeyStats(groupID uint) error {
|
||||
func (s *StatsService) RecalculateGroupKeyStats(ctx context.Context, groupID uint) error {
|
||||
s.logger.Warnf("Performing full recalculation for group %d key stats.", groupID)
|
||||
var results []struct {
|
||||
Status models.APIKeyStatus
|
||||
Count int64
|
||||
}
|
||||
if err := s.db.Model(&models.GroupAPIKeyMapping{}).
|
||||
if err := s.db.WithContext(ctx).Model(&models.GroupAPIKeyMapping{}).
|
||||
Where("key_group_id = ?", groupID).
|
||||
Select("status, COUNT(*) as count").
|
||||
Group("status").
|
||||
@@ -119,37 +121,25 @@ func (s *StatsService) RecalculateGroupKeyStats(groupID uint) error {
|
||||
}
|
||||
updates["total_keys"] = totalKeys
|
||||
|
||||
if err := s.store.Del(statsKey); err != nil {
|
||||
if err := s.store.Del(ctx, statsKey); err != nil {
|
||||
s.logger.WithError(err).Warnf("Failed to delete stale stats key for group %d before recalculation.", groupID)
|
||||
}
|
||||
if err := s.store.HSet(statsKey, updates); err != nil {
|
||||
if err := s.store.HSet(ctx, statsKey, updates); err != nil {
|
||||
return fmt.Errorf("failed to HSet recalculated stats for group %d: %w", groupID, err)
|
||||
}
|
||||
s.logger.Infof("Successfully recalculated stats for group %d using HSet.", groupID)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *StatsService) GetDashboardStats() (*models.DashboardStatsResponse, error) {
|
||||
// TODO 逻辑:
|
||||
// 1. 从Redis中获取所有分组的Key统计 (HGetAll)
|
||||
// 2. 从 stats_hourly 表中获取过去24小时的请求数和错误率
|
||||
// 3. 组合成 DashboardStatsResponse
|
||||
// ... 这个方法的具体实现,我们可以在DashboardQueryService中完成,
|
||||
// 这里我们先确保StatsService的核心职责(维护缓存)已经完成。
|
||||
// 为了编译通过,我们先返回一个空对象。
|
||||
|
||||
// 伪代码:
|
||||
// keyCounts, _ := s.store.HGetAll("stats:global:keys")
|
||||
// ...
|
||||
|
||||
func (s *StatsService) GetDashboardStats(ctx context.Context) (*models.DashboardStatsResponse, error) {
|
||||
return &models.DashboardStatsResponse{}, nil
|
||||
}
|
||||
|
||||
func (s *StatsService) AggregateHourlyStats() error {
|
||||
func (s *StatsService) AggregateHourlyStats(ctx context.Context) error {
|
||||
s.logger.Info("Starting aggregation of the last hour's request data...")
|
||||
now := time.Now()
|
||||
endTime := now.Truncate(time.Hour) // 例如:15:23 -> 15:00
|
||||
startTime := endTime.Add(-1 * time.Hour) // 15:00 -> 14:00
|
||||
endTime := now.Truncate(time.Hour)
|
||||
startTime := endTime.Add(-1 * time.Hour)
|
||||
|
||||
s.logger.Infof("Aggregating data for time window: [%s, %s)", startTime.Format(time.RFC3339), endTime.Format(time.RFC3339))
|
||||
type aggregationResult struct {
|
||||
@@ -161,7 +151,8 @@ func (s *StatsService) AggregateHourlyStats() error {
|
||||
CompletionTokens int64
|
||||
}
|
||||
var results []aggregationResult
|
||||
err := s.db.Model(&models.RequestLog{}).
|
||||
|
||||
err := s.db.WithContext(ctx).Model(&models.RequestLog{}).
|
||||
Select("group_id, model_name, COUNT(*) as request_count, SUM(CASE WHEN is_success = true THEN 1 ELSE 0 END) as success_count, SUM(prompt_tokens) as prompt_tokens, SUM(completion_tokens) as completion_tokens").
|
||||
Where("request_time >= ? AND request_time < ?", startTime, endTime).
|
||||
Group("group_id, model_name").
|
||||
@@ -179,7 +170,7 @@ func (s *StatsService) AggregateHourlyStats() error {
|
||||
var hourlyStats []models.StatsHourly
|
||||
for _, res := range results {
|
||||
hourlyStats = append(hourlyStats, models.StatsHourly{
|
||||
Time: startTime, // 所有记录的时间戳都是该小时的起点
|
||||
Time: startTime,
|
||||
GroupID: res.GroupID,
|
||||
ModelName: res.ModelName,
|
||||
RequestCount: res.RequestCount,
|
||||
@@ -189,7 +180,7 @@ func (s *StatsService) AggregateHourlyStats() error {
|
||||
})
|
||||
}
|
||||
|
||||
return s.db.Clauses(clause.OnConflict{
|
||||
return s.db.WithContext(ctx).Clauses(clause.OnConflict{
|
||||
Columns: []clause.Column{{Name: "time"}, {Name: "group_id"}, {Name: "model_name"}},
|
||||
DoUpdates: clause.AssignmentColumns([]string{"request_count", "success_count", "prompt_tokens", "completion_tokens"}),
|
||||
}).Create(&hourlyStats).Error
|
||||
|
||||
Reference in New Issue
Block a user