Fix services, update the middleware, and other changes

This commit is contained in:
XOF
2025-11-24 04:48:07 +08:00
parent 3a95a07e8a
commit f2706d6fc8
37 changed files with 4458 additions and 1166 deletions

View File

@@ -5,93 +5,179 @@ import (
"context"
"encoding/json"
"fmt"
"gemini-balancer/internal/db/dialect"
"gemini-balancer/internal/models"
"gemini-balancer/internal/store"
"strconv"
"strings"
"sync"
"sync/atomic"
"time"
"gemini-balancer/internal/db/dialect"
"gemini-balancer/internal/models"
"gemini-balancer/internal/settings"
"gemini-balancer/internal/store"
"github.com/sirupsen/logrus"
"gorm.io/gorm"
)
const (
flushLoopInterval = 1 * time.Minute
defaultFlushInterval = 1 * time.Minute
maxRetryAttempts = 3
retryDelay = 5 * time.Second
)
type AnalyticsServiceLogger struct{ *logrus.Entry }
type AnalyticsService struct {
db *gorm.DB
store store.Store
logger *logrus.Entry
db *gorm.DB
store store.Store
logger *logrus.Entry
dialect dialect.DialectAdapter
settingsManager *settings.SettingsManager
stopChan chan struct{}
wg sync.WaitGroup
dialect dialect.DialectAdapter
ctx context.Context
cancel context.CancelFunc
// metrics counters
eventsReceived atomic.Uint64
eventsProcessed atomic.Uint64
eventsFailed atomic.Uint64
flushCount atomic.Uint64
recordsFlushed atomic.Uint64
flushErrors atomic.Uint64
lastFlushTime time.Time
lastFlushMutex sync.RWMutex
// runtime configuration
flushInterval time.Duration
configMutex sync.RWMutex
}
func NewAnalyticsService(db *gorm.DB, s store.Store, logger *logrus.Logger, d dialect.DialectAdapter) *AnalyticsService {
func NewAnalyticsService(
db *gorm.DB,
s store.Store,
logger *logrus.Logger,
d dialect.DialectAdapter,
settingsManager *settings.SettingsManager,
) *AnalyticsService {
ctx, cancel := context.WithCancel(context.Background())
return &AnalyticsService{
db: db,
store: s,
logger: logger.WithField("component", "Analytics📊"),
stopChan: make(chan struct{}),
dialect: d,
db: db,
store: s,
logger: logger.WithField("component", "Analytics📊"),
dialect: d,
settingsManager: settingsManager,
stopChan: make(chan struct{}),
ctx: ctx,
cancel: cancel,
flushInterval: defaultFlushInterval,
lastFlushTime: time.Now(),
}
}
func (s *AnalyticsService) Start() {
s.wg.Add(2)
go s.flushLoop()
s.wg.Add(3)
go s.eventListener()
s.logger.Info("AnalyticsService (Command Side) started.")
go s.flushLoop()
go s.metricsReporter()
s.logger.WithFields(logrus.Fields{
"flush_interval": s.flushInterval,
}).Info("AnalyticsService started")
}
func (s *AnalyticsService) Stop() {
s.logger.Info("AnalyticsService stopping...")
close(s.stopChan)
s.cancel()
s.wg.Wait()
s.logger.Info("AnalyticsService stopped. Performing final data flush...")
s.logger.Info("Performing final data flush...")
s.flushToDB()
s.logger.Info("AnalyticsService final data flush completed.")
// log final statistics
s.logger.WithFields(logrus.Fields{
"events_received": s.eventsReceived.Load(),
"events_processed": s.eventsProcessed.Load(),
"events_failed": s.eventsFailed.Load(),
"flush_count": s.flushCount.Load(),
"records_flushed": s.recordsFlushed.Load(),
"flush_errors": s.flushErrors.Load(),
}).Info("AnalyticsService stopped")
}
// Event listener loop
func (s *AnalyticsService) eventListener() {
defer s.wg.Done()
sub, err := s.store.Subscribe(context.Background(), models.TopicRequestFinished)
sub, err := s.store.Subscribe(s.ctx, models.TopicRequestFinished)
if err != nil {
s.logger.Fatalf("Failed to subscribe to topic %s: %v", models.TopicRequestFinished, err)
s.logger.WithError(err).Error("Failed to subscribe to request events, analytics disabled")
return
}
defer sub.Close()
s.logger.Info("AnalyticsService subscribed to request events.")
defer func() {
if err := sub.Close(); err != nil {
s.logger.WithError(err).Warn("Failed to close subscription")
}
}()
s.logger.Info("Subscribed to request events for analytics")
for {
select {
case msg := <-sub.Channel():
var event models.RequestFinishedEvent
if err := json.Unmarshal(msg.Payload, &event); err != nil {
s.logger.Errorf("Failed to unmarshal analytics event: %v", err)
continue
}
s.handleAnalyticsEvent(&event)
s.handleMessage(msg)
case <-s.stopChan:
s.logger.Info("AnalyticsService stopping event listener.")
s.logger.Info("Event listener stopping")
return
case <-s.ctx.Done():
s.logger.Info("Event listener context cancelled")
return
}
}
}
func (s *AnalyticsService) handleAnalyticsEvent(event *models.RequestFinishedEvent) {
if event.RequestLog.GroupID == nil {
// Process a single message
func (s *AnalyticsService) handleMessage(msg *store.Message) {
var event models.RequestFinishedEvent
if err := json.Unmarshal(msg.Payload, &event); err != nil {
s.logger.WithError(err).Error("Failed to unmarshal analytics event")
s.eventsFailed.Add(1)
return
}
ctx := context.Background()
key := fmt.Sprintf("analytics:hourly:%s", time.Now().UTC().Format("2006-01-02T15"))
s.eventsReceived.Add(1)
if err := s.handleAnalyticsEvent(&event); err != nil {
s.eventsFailed.Add(1)
s.logger.WithFields(logrus.Fields{
"correlation_id": event.CorrelationID,
"group_id": event.RequestLog.GroupID,
}).WithError(err).Warn("Failed to process analytics event")
} else {
s.eventsProcessed.Add(1)
}
}
// Process an analytics event
func (s *AnalyticsService) handleAnalyticsEvent(event *models.RequestFinishedEvent) error {
if event.RequestLog.GroupID == nil {
return nil // skip events without a GroupID
}
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
now := time.Now().UTC()
key := fmt.Sprintf("analytics:hourly:%s", now.Format("2006-01-02T15"))
fieldPrefix := fmt.Sprintf("%d:%s", *event.RequestLog.GroupID, event.RequestLog.ModelName)
pipe := s.store.Pipeline(ctx)
pipe.HIncrBy(key, fieldPrefix+":requests", 1)
if event.RequestLog.IsSuccess {
pipe.HIncrBy(key, fieldPrefix+":success", 1)
}
@@ -101,80 +187,213 @@ func (s *AnalyticsService) handleAnalyticsEvent(event *models.RequestFinishedEve
if event.RequestLog.CompletionTokens > 0 {
pipe.HIncrBy(key, fieldPrefix+":completion", int64(event.RequestLog.CompletionTokens))
}
// Set an expiry so the key is retained for 48 hours
pipe.Expire(key, 48*time.Hour)
if err := pipe.Exec(); err != nil {
s.logger.Warnf("[%s] Failed to record analytics event to store for group %d: %v", event.CorrelationID, *event.RequestLog.GroupID, err)
return fmt.Errorf("redis pipeline failed: %w", err)
}
return nil
}
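For reference, the Redis layout this produces is one hash per UTC hour, keyed analytics:hourly:<YYYY-MM-DDTHH>, with one field per group/model/counter triple. A minimal standalone sketch of the naming scheme (group ID 42 and the model name are made-up example values, not part of this commit):

package main

import (
    "fmt"
    "time"
)

func main() {
    now := time.Date(2025, 11, 24, 3, 15, 0, 0, time.UTC)
    key := fmt.Sprintf("analytics:hourly:%s", now.Format("2006-01-02T15"))
    field := fmt.Sprintf("%d:%s", 42, "gemini-1.5-flash")
    fmt.Println(key)                 // analytics:hourly:2025-11-24T03
    fmt.Println(field + ":requests") // 42:gemini-1.5-flash:requests
    fmt.Println(field + ":success")  // 42:gemini-1.5-flash:success
}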
// Flush loop
func (s *AnalyticsService) flushLoop() {
defer s.wg.Done()
ticker := time.NewTicker(flushLoopInterval)
s.configMutex.RLock()
interval := s.flushInterval
s.configMutex.RUnlock()
ticker := time.NewTicker(interval)
defer ticker.Stop()
s.logger.WithField("interval", interval).Info("Flush loop started")
for {
select {
case <-ticker.C:
s.flushToDB()
case <-s.stopChan:
s.logger.Info("Flush loop stopping")
return
case <-s.ctx.Done():
return
}
}
}
// Flush to the database
func (s *AnalyticsService) flushToDB() {
ctx := context.Background()
start := time.Now()
ctx, cancel := context.WithTimeout(context.Background(), 60*time.Second)
defer cancel()
now := time.Now().UTC()
keysToFlush := []string{
fmt.Sprintf("analytics:hourly:%s", now.Add(-1*time.Hour).Format("2006-01-02T15")),
fmt.Sprintf("analytics:hourly:%s", now.Format("2006-01-02T15")),
}
keysToFlush := s.generateFlushKeys(now)
totalRecords := 0
totalErrors := 0
for _, key := range keysToFlush {
data, err := s.store.HGetAll(ctx, key)
if err != nil || len(data) == 0 {
continue
records, err := s.flushSingleKey(ctx, key, now)
if err != nil {
s.logger.WithError(err).WithField("key", key).Error("Failed to flush key")
totalErrors++
s.flushErrors.Add(1)
} else {
totalRecords += records
}
}
statsToFlush, parsedFields := s.parseStatsFromHash(now.Truncate(time.Hour), data)
s.recordsFlushed.Add(uint64(totalRecords))
s.flushCount.Add(1)
if len(statsToFlush) > 0 {
upsertClause := s.dialect.OnConflictUpdateAll(
[]string{"time", "group_id", "model_name"},
[]string{"request_count", "success_count", "prompt_tokens", "completion_tokens"},
)
err := s.db.WithContext(ctx).Clauses(upsertClause).Create(&statsToFlush).Error
if err != nil {
s.logger.Errorf("Failed to flush analytics data for key %s: %v", key, err)
} else {
s.logger.Infof("Successfully flushed %d records from key %s.", len(statsToFlush), key)
_ = s.store.HDel(ctx, key, parsedFields...)
}
}
s.lastFlushMutex.Lock()
s.lastFlushTime = time.Now()
s.lastFlushMutex.Unlock()
duration := time.Since(start)
if totalRecords > 0 || totalErrors > 0 {
s.logger.WithFields(logrus.Fields{
"records_flushed": totalRecords,
"keys_processed": len(keysToFlush),
"errors": totalErrors,
"duration": duration,
}).Info("Analytics data flush completed")
} else {
s.logger.WithField("duration", duration).Debug("Analytics flush completed (no data)")
}
}
// Generate the Redis keys that need to be flushed
func (s *AnalyticsService) generateFlushKeys(now time.Time) []string {
keys := make([]string, 0, 4)
// Current hour
keys = append(keys, fmt.Sprintf("analytics:hourly:%s", now.Format("2006-01-02T15")))
// Previous 3 hours, to cover late events and timezone skew
for i := 1; i <= 3; i++ {
pastHour := now.Add(-time.Duration(i) * time.Hour)
keys = append(keys, fmt.Sprintf("analytics:hourly:%s", pastHour.Format("2006-01-02T15")))
}
return keys
}
// Flush a single Redis key
func (s *AnalyticsService) flushSingleKey(ctx context.Context, key string, baseTime time.Time) (int, error) {
data, err := s.store.HGetAll(ctx, key)
if err != nil {
return 0, fmt.Errorf("failed to get hash data: %w", err)
}
if len(data) == 0 {
return 0, nil // no data, skip
}
// Parse the timestamp from the key
hourStr := strings.TrimPrefix(key, "analytics:hourly:")
recordTime, err := time.Parse("2006-01-02T15", hourStr)
if err != nil {
s.logger.WithError(err).WithField("key", key).Warn("Failed to parse time from key")
recordTime = baseTime.Truncate(time.Hour)
}
statsToFlush, parsedFields := s.parseStatsFromHash(recordTime, data)
if len(statsToFlush) == 0 {
return 0, nil
}
// Use a transaction with retries
var dbErr error
for attempt := 1; attempt <= maxRetryAttempts; attempt++ {
dbErr = s.upsertStatsWithTransaction(ctx, statsToFlush)
if dbErr == nil {
break
}
if attempt < maxRetryAttempts {
s.logger.WithFields(logrus.Fields{
"attempt": attempt,
"key": key,
}).WithError(dbErr).Warn("Database upsert failed, retrying...")
time.Sleep(retryDelay)
}
}
if dbErr != nil {
return 0, fmt.Errorf("failed to upsert after %d attempts: %w", maxRetryAttempts, dbErr)
}
// Delete the fields that were flushed
if len(parsedFields) > 0 {
if err := s.store.HDel(ctx, key, parsedFields...); err != nil {
s.logger.WithError(err).WithField("key", key).Warn("Failed to delete flushed fields from Redis")
}
}
return len(statsToFlush), nil
}
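The fixed-delay retry in flushSingleKey could also live in a small helper; a minimal sketch under the assumption that a plain sleep between attempts is acceptable (withRetries is a hypothetical name, not part of this commit):

// withRetries runs fn up to attempts times, sleeping delay between failed attempts.
func withRetries(attempts int, delay time.Duration, fn func() error) error {
    var err error
    for i := 1; i <= attempts; i++ {
        if err = fn(); err == nil {
            return nil
        }
        if i < attempts {
            time.Sleep(delay)
        }
    }
    return err
}

flushSingleKey would then call withRetries(maxRetryAttempts, retryDelay, func() error { return s.upsertStatsWithTransaction(ctx, statsToFlush) }).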
// Batch upsert within a transaction
func (s *AnalyticsService) upsertStatsWithTransaction(ctx context.Context, stats []models.StatsHourly) error {
return s.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
upsertClause := s.dialect.OnConflictUpdateAll(
[]string{"time", "group_id", "model_name"},
[]string{"request_count", "success_count", "prompt_tokens", "completion_tokens"},
)
return tx.Clauses(upsertClause).Create(&stats).Error
})
}
// Parse the Redis hash data
func (s *AnalyticsService) parseStatsFromHash(t time.Time, data map[string]string) ([]models.StatsHourly, []string) {
tempAggregator := make(map[string]*models.StatsHourly)
var parsedFields []string
parsedFields := make([]string, 0, len(data))
for field, valueStr := range data {
parts := strings.Split(field, ":")
if len(parts) != 3 {
s.logger.WithField("field", field).Warn("Invalid field format")
continue
}
groupIDStr, modelName, counterType := parts[0], parts[1], parts[2]
groupIDStr, modelName, counterType := parts[0], parts[1], parts[2]
aggKey := groupIDStr + ":" + modelName
if _, ok := tempAggregator[aggKey]; !ok {
gid, err := strconv.Atoi(groupIDStr)
if err != nil {
s.logger.WithFields(logrus.Fields{
"field": field,
"group_id": groupIDStr,
}).Warn("Invalid group ID")
continue
}
tempAggregator[aggKey] = &models.StatsHourly{
Time: t,
GroupID: uint(gid),
ModelName: modelName,
}
}
val, _ := strconv.ParseInt(valueStr, 10, 64)
val, err := strconv.ParseInt(valueStr, 10, 64)
if err != nil {
s.logger.WithFields(logrus.Fields{
"field": field,
"value": valueStr,
}).Warn("Invalid counter value")
continue
}
switch counterType {
case "requests":
tempAggregator[aggKey].RequestCount = val
@@ -184,14 +403,92 @@ func (s *AnalyticsService) parseStatsFromHash(t time.Time, data map[string]strin
tempAggregator[aggKey].PromptTokens = val
case "completion":
tempAggregator[aggKey].CompletionTokens = val
default:
s.logger.WithField("counter_type", counterType).Warn("Unknown counter type")
continue
}
parsedFields = append(parsedFields, field)
}
var result []models.StatsHourly
result := make([]models.StatsHourly, 0, len(tempAggregator))
for _, stats := range tempAggregator {
if stats.RequestCount > 0 {
result = append(result, *stats)
}
}
return result, parsedFields
}
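As a worked example of the aggregation above: the two hash fields "7:gemini-1.5-pro:requests" = "10" and "7:gemini-1.5-pro:success" = "9" (hypothetical values) collapse into a single StatsHourly row with GroupID 7, ModelName "gemini-1.5-pro", RequestCount 10 and SuccessCount 9, and both field names are returned in parsedFields so they can be deleted from Redis once the upsert succeeds.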
// Periodically report statistics
func (s *AnalyticsService) metricsReporter() {
defer s.wg.Done()
ticker := time.NewTicker(5 * time.Minute)
defer ticker.Stop()
for {
select {
case <-ticker.C:
s.reportMetrics()
case <-s.stopChan:
return
case <-s.ctx.Done():
return
}
}
}
func (s *AnalyticsService) reportMetrics() {
s.lastFlushMutex.RLock()
lastFlush := s.lastFlushTime
s.lastFlushMutex.RUnlock()
received := s.eventsReceived.Load()
processed := s.eventsProcessed.Load()
failed := s.eventsFailed.Load()
var successRate float64
if received > 0 {
successRate = float64(processed) / float64(received) * 100
}
s.logger.WithFields(logrus.Fields{
"events_received": received,
"events_processed": processed,
"events_failed": failed,
"success_rate": fmt.Sprintf("%.2f%%", successRate),
"flush_count": s.flushCount.Load(),
"records_flushed": s.recordsFlushed.Load(),
"flush_errors": s.flushErrors.Load(),
"last_flush_ago": time.Since(lastFlush).Round(time.Second),
}).Info("Analytics metrics")
}
// GetMetrics returns the current statistics (for monitoring).
func (s *AnalyticsService) GetMetrics() map[string]interface{} {
s.lastFlushMutex.RLock()
lastFlush := s.lastFlushTime
s.lastFlushMutex.RUnlock()
received := s.eventsReceived.Load()
processed := s.eventsProcessed.Load()
var successRate float64
if received > 0 {
successRate = float64(processed) / float64(received) * 100
}
return map[string]interface{}{
"events_received": received,
"events_processed": processed,
"events_failed": s.eventsFailed.Load(),
"success_rate": successRate,
"flush_count": s.flushCount.Load(),
"records_flushed": s.recordsFlushed.Load(),
"flush_errors": s.flushErrors.Load(),
"last_flush_ago": time.Since(lastFlush).Seconds(),
"flush_interval": s.flushInterval.Seconds(),
}
}
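Since GetMetrics returns a plain map, wiring it into an admin endpoint is a one-liner. A minimal sketch assuming a *gin.Engine named r and an *AnalyticsService named svc are in scope (the route path is made up, not part of this commit):

// Expose the analytics counters as JSON for monitoring.
r.GET("/internal/metrics/analytics", func(c *gin.Context) {
    c.JSON(200, svc.GetMetrics())
})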

View File

@@ -4,158 +4,297 @@ package service
import (
"context"
"fmt"
"strconv"
"sync"
"sync/atomic"
"time"
"gemini-balancer/internal/models"
"gemini-balancer/internal/store"
"gemini-balancer/internal/syncer"
"strconv"
"time"
"github.com/gin-gonic/gin"
"github.com/sirupsen/logrus"
"gorm.io/gorm"
)
const overviewCacheChannel = "syncer:cache:dashboard_overview"
const (
overviewCacheChannel = "syncer:cache:dashboard_overview"
defaultChartDays = 7
cacheLoadTimeout = 30 * time.Second
)
var (
// Chart color palette
chartColorPalette = []string{
"#FF6384", "#36A2EB", "#FFCE56", "#4BC0C0",
"#9966FF", "#FF9F40", "#C9CBCF", "#4D5360",
}
)
type DashboardQueryService struct {
db *gorm.DB
store store.Store
overviewSyncer *syncer.CacheSyncer[*models.DashboardStatsResponse]
logger *logrus.Entry
stopChan chan struct{}
stopChan chan struct{}
wg sync.WaitGroup
ctx context.Context
cancel context.CancelFunc
// metrics counters
queryCount atomic.Uint64
cacheHits atomic.Uint64
cacheMisses atomic.Uint64
overviewLoadCount atomic.Uint64
lastQueryTime time.Time
lastQueryMutex sync.RWMutex
}
func NewDashboardQueryService(db *gorm.DB, s store.Store, logger *logrus.Logger) (*DashboardQueryService, error) {
qs := &DashboardQueryService{
db: db,
store: s,
logger: logger.WithField("component", "DashboardQueryService"),
stopChan: make(chan struct{}),
func NewDashboardQueryService(
db *gorm.DB,
s store.Store,
logger *logrus.Logger,
) (*DashboardQueryService, error) {
ctx, cancel := context.WithCancel(context.Background())
service := &DashboardQueryService{
db: db,
store: s,
logger: logger.WithField("component", "DashboardQuery📈"),
stopChan: make(chan struct{}),
ctx: ctx,
cancel: cancel,
lastQueryTime: time.Now(),
}
loader := qs.loadOverviewData
overviewSyncer, err := syncer.NewCacheSyncer(loader, s, overviewCacheChannel)
// Create the CacheSyncer
overviewSyncer, err := syncer.NewCacheSyncer(
service.loadOverviewData,
s,
overviewCacheChannel,
logger,
)
if err != nil {
cancel()
return nil, fmt.Errorf("failed to create overview cache syncer: %w", err)
}
qs.overviewSyncer = overviewSyncer
return qs, nil
service.overviewSyncer = overviewSyncer
return service, nil
}
func (s *DashboardQueryService) Start() {
s.wg.Add(2)
go s.eventListener()
s.logger.Info("DashboardQueryService started and listening for invalidation events.")
go s.metricsReporter()
s.logger.Info("DashboardQueryService started")
}
func (s *DashboardQueryService) Stop() {
s.logger.Info("DashboardQueryService stopping...")
close(s.stopChan)
s.logger.Info("DashboardQueryService and its CacheSyncer have been stopped.")
s.cancel()
s.wg.Wait()
// log final statistics
s.logger.WithFields(logrus.Fields{
"total_queries": s.queryCount.Load(),
"cache_hits": s.cacheHits.Load(),
"cache_misses": s.cacheMisses.Load(),
"overview_loads": s.overviewLoadCount.Load(),
}).Info("DashboardQueryService stopped")
}
// ==================== Core query methods ====================
// GetDashboardOverviewData returns the dashboard overview data (cached).
func (s *DashboardQueryService) GetDashboardOverviewData() (*models.DashboardStatsResponse, error) {
s.queryCount.Add(1)
cachedDataPtr := s.overviewSyncer.Get()
if cachedDataPtr == nil {
s.cacheMisses.Add(1)
s.logger.Warn("Overview cache is empty, attempting to load...")
// Trigger an immediate reload
if err := s.overviewSyncer.Invalidate(); err != nil {
return nil, fmt.Errorf("failed to trigger cache reload: %w", err)
}
// Wait up to 30 seconds for the load to complete
ctx, cancel := context.WithTimeout(context.Background(), cacheLoadTimeout)
defer cancel()
ticker := time.NewTicker(100 * time.Millisecond)
defer ticker.Stop()
for {
select {
case <-ticker.C:
if data := s.overviewSyncer.Get(); data != nil {
s.cacheHits.Add(1)
return data, nil
}
case <-ctx.Done():
return nil, fmt.Errorf("timeout waiting for overview cache to load")
}
}
}
s.cacheHits.Add(1)
return cachedDataPtr, nil
}
// InvalidateOverviewCache manually invalidates the overview cache.
func (s *DashboardQueryService) InvalidateOverviewCache() error {
s.logger.Info("Manually invalidating overview cache")
return s.overviewSyncer.Invalidate()
}
// GetGroupStats returns statistics for the given group.
func (s *DashboardQueryService) GetGroupStats(ctx context.Context, groupID uint) (map[string]any, error) {
s.queryCount.Add(1)
s.updateLastQueryTime()
start := time.Now()
// 1. Fetch key stats from Redis
statsKey := fmt.Sprintf("stats:group:%d", groupID)
keyStatsMap, err := s.store.HGetAll(ctx, statsKey)
if err != nil {
s.logger.WithError(err).Errorf("Failed to get key stats from cache for group %d", groupID)
s.logger.WithError(err).Errorf("Failed to get key stats for group %d", groupID)
return nil, fmt.Errorf("failed to get key stats from cache: %w", err)
}
keyStats := make(map[string]int64)
for k, v := range keyStatsMap {
val, _ := strconv.ParseInt(v, 10, 64)
keyStats[k] = val
}
now := time.Now()
// 2. Query request stats (using UTC time)
now := time.Now().UTC()
oneHourAgo := now.Add(-1 * time.Hour)
twentyFourHoursAgo := now.Add(-24 * time.Hour)
type requestStatsResult struct {
TotalRequests int64
SuccessRequests int64
}
var last1Hour, last24Hours requestStatsResult
s.db.WithContext(ctx).Model(&models.StatsHourly{}).
Select("SUM(request_count) as total_requests, SUM(success_count) as success_requests").
Where("group_id = ? AND time >= ?", groupID, oneHourAgo).
Scan(&last1Hour)
s.db.WithContext(ctx).Model(&models.StatsHourly{}).
Select("SUM(request_count) as total_requests, SUM(success_count) as success_requests").
Where("group_id = ? AND time >= ?", groupID, twentyFourHoursAgo).
Scan(&last24Hours)
failureRate1h := 0.0
if last1Hour.TotalRequests > 0 {
failureRate1h = float64(last1Hour.TotalRequests-last1Hour.SuccessRequests) / float64(last1Hour.TotalRequests) * 100
}
failureRate24h := 0.0
if last24Hours.TotalRequests > 0 {
failureRate24h = float64(last24Hours.TotalRequests-last24Hours.SuccessRequests) / float64(last24Hours.TotalRequests) * 100
}
last1HourStats := map[string]any{
"total_requests": last1Hour.TotalRequests,
"success_requests": last1Hour.SuccessRequests,
"failure_rate": failureRate1h,
}
last24HoursStats := map[string]any{
"total_requests": last24Hours.TotalRequests,
"success_requests": last24Hours.SuccessRequests,
"failure_rate": failureRate24h,
// Run both queries concurrently
var wg sync.WaitGroup
errChan := make(chan error, 2)
wg.Add(2)
// Query the last 1 hour
go func() {
defer wg.Done()
if err := s.db.WithContext(ctx).Model(&models.StatsHourly{}).
Select("COALESCE(SUM(request_count), 0) as total_requests, COALESCE(SUM(success_count), 0) as success_requests").
Where("group_id = ? AND time >= ?", groupID, oneHourAgo).
Scan(&last1Hour).Error; err != nil {
errChan <- fmt.Errorf("failed to query 1h stats: %w", err)
}
}()
// Query the last 24 hours
go func() {
defer wg.Done()
if err := s.db.WithContext(ctx).Model(&models.StatsHourly{}).
Select("COALESCE(SUM(request_count), 0) as total_requests, COALESCE(SUM(success_count), 0) as success_requests").
Where("group_id = ? AND time >= ?", groupID, twentyFourHoursAgo).
Scan(&last24Hours).Error; err != nil {
errChan <- fmt.Errorf("failed to query 24h stats: %w", err)
}
}()
wg.Wait()
close(errChan)
// Check for errors
for err := range errChan {
if err != nil {
return nil, err
}
}
// 3. Compute failure rates
failureRate1h := s.calculateFailureRate(last1Hour.TotalRequests, last1Hour.SuccessRequests)
failureRate24h := s.calculateFailureRate(last24Hours.TotalRequests, last24Hours.SuccessRequests)
result := map[string]any{
"key_stats": keyStats,
"last_1_hour": last1HourStats,
"last_24_hours": last24HoursStats,
"key_stats": keyStats,
"last_1_hour": map[string]any{
"total_requests": last1Hour.TotalRequests,
"success_requests": last1Hour.SuccessRequests,
"failed_requests": last1Hour.TotalRequests - last1Hour.SuccessRequests,
"failure_rate": failureRate1h,
},
"last_24_hours": map[string]any{
"total_requests": last24Hours.TotalRequests,
"success_requests": last24Hours.SuccessRequests,
"failed_requests": last24Hours.TotalRequests - last24Hours.SuccessRequests,
"failure_rate": failureRate24h,
},
}
duration := time.Since(start)
s.logger.WithFields(logrus.Fields{
"group_id": groupID,
"duration": duration,
}).Debug("Group stats query completed")
return result, nil
}
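The WaitGroup-plus-error-channel fan-out above could equivalently use golang.org/x/sync/errgroup, which also cancels the sibling query as soon as one fails; a minimal sketch of the same two queries (an alternative, not part of this commit, and it would require adding the errgroup dependency):

g, gctx := errgroup.WithContext(ctx)
g.Go(func() error {
    return s.db.WithContext(gctx).Model(&models.StatsHourly{}).
        Select("COALESCE(SUM(request_count), 0) as total_requests, COALESCE(SUM(success_count), 0) as success_requests").
        Where("group_id = ? AND time >= ?", groupID, oneHourAgo).
        Scan(&last1Hour).Error
})
g.Go(func() error {
    return s.db.WithContext(gctx).Model(&models.StatsHourly{}).
        Select("COALESCE(SUM(request_count), 0) as total_requests, COALESCE(SUM(success_count), 0) as success_requests").
        Where("group_id = ? AND time >= ?", groupID, twentyFourHoursAgo).
        Scan(&last24Hours).Error
})
if err := g.Wait(); err != nil {
    return nil, err
}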
func (s *DashboardQueryService) eventListener() {
ctx := context.Background()
keyStatusSub, _ := s.store.Subscribe(ctx, models.TopicKeyStatusChanged)
upstreamStatusSub, _ := s.store.Subscribe(ctx, models.TopicUpstreamHealthChanged)
defer keyStatusSub.Close()
defer upstreamStatusSub.Close()
for {
select {
case <-keyStatusSub.Channel():
s.logger.Info("Received key status changed event, invalidating overview cache...")
_ = s.InvalidateOverviewCache()
case <-upstreamStatusSub.Channel():
s.logger.Info("Received upstream status changed event, invalidating overview cache...")
_ = s.InvalidateOverviewCache()
case <-s.stopChan:
s.logger.Info("Stopping dashboard event listener.")
return
}
}
}
func (s *DashboardQueryService) GetDashboardOverviewData() (*models.DashboardStatsResponse, error) {
cachedDataPtr := s.overviewSyncer.Get()
if cachedDataPtr == nil {
return &models.DashboardStatsResponse{}, fmt.Errorf("overview cache is not available or still syncing")
}
return cachedDataPtr, nil
}
func (s *DashboardQueryService) InvalidateOverviewCache() error {
return s.overviewSyncer.Invalidate()
}
// QueryHistoricalChart queries historical chart data.
func (s *DashboardQueryService) QueryHistoricalChart(ctx context.Context, groupID *uint) (*models.ChartData, error) {
s.queryCount.Add(1)
s.updateLastQueryTime()
start := time.Now()
type ChartPoint struct {
TimeLabel string `gorm:"column:time_label"`
ModelName string `gorm:"column:model_name"`
TotalRequests int64 `gorm:"column:total_requests"`
}
sevenDaysAgo := time.Now().Add(-24 * 7 * time.Hour).Truncate(time.Hour)
// Query the last 7 days of data (UTC)
sevenDaysAgo := time.Now().UTC().AddDate(0, 0, -defaultChartDays).Truncate(time.Hour)
// Build the time-format clause for the database dialect
sqlFormat, goFormat := s.buildTimeFormatSelectClause()
selectClause := fmt.Sprintf("%s as time_label, model_name, SUM(request_count) as total_requests", sqlFormat)
query := s.db.WithContext(ctx).Model(&models.StatsHourly{}).Select(selectClause).Where("time >= ?", sevenDaysAgo).Group("time_label, model_name").Order("time_label ASC")
selectClause := fmt.Sprintf(
"%s as time_label, model_name, COALESCE(SUM(request_count), 0) as total_requests",
sqlFormat,
)
// Build the query
query := s.db.WithContext(ctx).
Model(&models.StatsHourly{}).
Select(selectClause).
Where("time >= ?", sevenDaysAgo).
Group("time_label, model_name").
Order("time_label ASC")
if groupID != nil && *groupID > 0 {
query = query.Where("group_id = ?", *groupID)
}
var points []ChartPoint
if err := query.Find(&points).Error; err != nil {
return nil, err
return nil, fmt.Errorf("failed to query chart data: %w", err)
}
// Build the datasets
datasets := make(map[string]map[string]int64)
for _, p := range points {
if _, ok := datasets[p.ModelName]; !ok {
@@ -163,32 +302,99 @@ func (s *DashboardQueryService) QueryHistoricalChart(ctx context.Context, groupI
}
datasets[p.ModelName][p.TimeLabel] = p.TotalRequests
}
// Generate hourly time labels
var labels []string
for t := sevenDaysAgo; t.Before(time.Now()); t = t.Add(time.Hour) {
for t := sevenDaysAgo; t.Before(time.Now().UTC()); t = t.Add(time.Hour) {
labels = append(labels, t.Format(goFormat))
}
chartData := &models.ChartData{Labels: labels, Datasets: make([]models.ChartDataset, 0)}
colorPalette := []string{"#FF6384", "#36A2EB", "#FFCE56", "#4BC0C0", "#9966FF", "#FF9F40"}
// Assemble the chart data
chartData := &models.ChartData{
Labels: labels,
Datasets: make([]models.ChartDataset, 0, len(datasets)),
}
colorIndex := 0
for modelName, dataPoints := range datasets {
dataArray := make([]int64, len(labels))
for i, label := range labels {
dataArray[i] = dataPoints[label]
}
chartData.Datasets = append(chartData.Datasets, models.ChartDataset{
Label: modelName,
Data: dataArray,
Color: colorPalette[colorIndex%len(colorPalette)],
Color: chartColorPalette[colorIndex%len(chartColorPalette)],
})
colorIndex++
}
duration := time.Since(start)
s.logger.WithFields(logrus.Fields{
"group_id": groupID,
"points": len(points),
"datasets": len(chartData.Datasets),
"duration": duration,
}).Debug("Historical chart query completed")
return chartData, nil
}
// GetRequestStatsForPeriod returns request statistics for the given period.
func (s *DashboardQueryService) GetRequestStatsForPeriod(ctx context.Context, period string) (gin.H, error) {
s.queryCount.Add(1)
s.updateLastQueryTime()
var startTime time.Time
now := time.Now().UTC()
switch period {
case "1m":
startTime = now.Add(-1 * time.Minute)
case "1h":
startTime = now.Add(-1 * time.Hour)
case "1d":
year, month, day := now.Date()
startTime = time.Date(year, month, day, 0, 0, 0, 0, time.UTC)
default:
return nil, fmt.Errorf("invalid period specified: %s (must be 1m, 1h, or 1d)", period)
}
var result struct {
Total int64
Success int64
}
err := s.db.WithContext(ctx).Model(&models.RequestLog{}).
Select("COUNT(*) as total, SUM(CASE WHEN is_success = true THEN 1 ELSE 0 END) as success").
Where("request_time >= ?", startTime).
Scan(&result).Error
if err != nil {
return nil, fmt.Errorf("failed to query request stats: %w", err)
}
return gin.H{
"period": period,
"total": result.Total,
"success": result.Success,
"failure": result.Total - result.Success,
}, nil
}
// ==================== Internal methods ====================
// loadOverviewData loads the dashboard overview data (called by the CacheSyncer).
func (s *DashboardQueryService) loadOverviewData() (*models.DashboardStatsResponse, error) {
ctx := context.Background()
s.logger.Info("[CacheSyncer] Starting to load overview data from database...")
ctx, cancel := context.WithTimeout(context.Background(), 60*time.Second)
defer cancel()
s.overviewLoadCount.Add(1)
startTime := time.Now()
s.logger.Info("Starting to load dashboard overview data...")
resp := &models.DashboardStatsResponse{
KeyStatusCount: make(map[models.APIKeyStatus]int64),
MasterStatusCount: make(map[models.MasterAPIKeyStatus]int64),
@@ -200,108 +406,391 @@ func (s *DashboardQueryService) loadOverviewData() (*models.DashboardStatsRespon
RequestCounts: make(map[string]int64),
}
var loadErr error
var wg sync.WaitGroup
errChan := make(chan error, 10)
// 1. Load key mapping status stats concurrently
wg.Add(1)
go func() {
defer wg.Done()
if err := s.loadMappingStatusStats(ctx, resp); err != nil {
errChan <- fmt.Errorf("mapping stats: %w", err)
}
}()
// 2. Load master key status stats concurrently
wg.Add(1)
go func() {
defer wg.Done()
if err := s.loadMasterStatusStats(ctx, resp); err != nil {
errChan <- fmt.Errorf("master stats: %w", err)
}
}()
// 3. Load request counts concurrently
wg.Add(1)
go func() {
defer wg.Done()
if err := s.loadRequestCounts(ctx, resp); err != nil {
errChan <- fmt.Errorf("request counts: %w", err)
}
}()
// 4. Load upstream health status concurrently
wg.Add(1)
go func() {
defer wg.Done()
if err := s.loadUpstreamHealth(ctx, resp); err != nil {
// An upstream health failure does not block the overall load
s.logger.WithError(err).Warn("Failed to load upstream health status")
}
}()
// Wait for all load tasks to complete
wg.Wait()
close(errChan)
// Collect errors
for err := range errChan {
if err != nil {
loadErr = err
break
}
}
if loadErr != nil {
s.logger.WithError(loadErr).Error("Failed to load overview data")
return nil, loadErr
}
duration := time.Since(startTime)
s.logger.WithFields(logrus.Fields{
"duration": duration,
"total_keys": resp.KeyCount.Value,
"requests_1d": resp.RequestCounts["1d"],
"upstreams": len(resp.UpstreamHealthStatus),
}).Info("Successfully loaded dashboard overview data")
return resp, nil
}
// loadMappingStatusStats loads key mapping status statistics.
func (s *DashboardQueryService) loadMappingStatusStats(ctx context.Context, resp *models.DashboardStatsResponse) error {
type MappingStatusResult struct {
Status models.APIKeyStatus
Count int64
}
var mappingStatusResults []MappingStatusResult
if err := s.db.WithContext(ctx).Model(&models.GroupAPIKeyMapping{}).Select("status, count(*) as count").Group("status").Find(&mappingStatusResults).Error; err != nil {
return nil, fmt.Errorf("failed to query mapping status stats: %w", err)
var results []MappingStatusResult
if err := s.db.WithContext(ctx).
Model(&models.GroupAPIKeyMapping{}).
Select("status, COUNT(*) as count").
Group("status").
Find(&results).Error; err != nil {
return err
}
for _, res := range mappingStatusResults {
for _, res := range results {
resp.KeyStatusCount[res.Status] = res.Count
}
return nil
}
// loadMasterStatusStats loads master key status statistics.
func (s *DashboardQueryService) loadMasterStatusStats(ctx context.Context, resp *models.DashboardStatsResponse) error {
type MasterStatusResult struct {
Status models.MasterAPIKeyStatus
Count int64
}
var masterStatusResults []MasterStatusResult
if err := s.db.WithContext(ctx).Model(&models.APIKey{}).Select("master_status as status, count(*) as count").Group("master_status").Find(&masterStatusResults).Error; err != nil {
return nil, fmt.Errorf("failed to query master status stats: %w", err)
var results []MasterStatusResult
if err := s.db.WithContext(ctx).
Model(&models.APIKey{}).
Select("master_status as status, COUNT(*) as count").
Group("master_status").
Find(&results).Error; err != nil {
return err
}
var totalKeys, invalidKeys int64
for _, res := range masterStatusResults {
for _, res := range results {
resp.MasterStatusCount[res.Status] = res.Count
totalKeys += res.Count
if res.Status != models.MasterStatusActive {
invalidKeys += res.Count
}
}
resp.KeyCount = models.StatCard{Value: float64(totalKeys), SubValue: invalidKeys, SubValueTip: "非活跃身份密钥数"}
now := time.Now()
resp.KeyCount = models.StatCard{
Value: float64(totalKeys),
SubValue: invalidKeys,
SubValueTip: "非活跃身份密钥数",
}
var count1m, count1h, count1d int64
s.db.WithContext(ctx).Model(&models.RequestLog{}).Where("request_time >= ?", now.Add(-1*time.Minute)).Count(&count1m)
s.db.WithContext(ctx).Model(&models.RequestLog{}).Where("request_time >= ?", now.Add(-1*time.Hour)).Count(&count1h)
year, month, day := now.UTC().Date()
return nil
}
// loadRequestCounts loads request count statistics.
func (s *DashboardQueryService) loadRequestCounts(ctx context.Context, resp *models.DashboardStatsResponse) error {
now := time.Now().UTC()
// Use the RequestLog table for short-range counts
var count1m, count1h int64
// Last 1 minute
if err := s.db.WithContext(ctx).
Model(&models.RequestLog{}).
Where("request_time >= ?", now.Add(-1*time.Minute)).
Count(&count1m).Error; err != nil {
return fmt.Errorf("1m count: %w", err)
}
// Last 1 hour
if err := s.db.WithContext(ctx).
Model(&models.RequestLog{}).
Where("request_time >= ?", now.Add(-1*time.Hour)).
Count(&count1h).Error; err != nil {
return fmt.Errorf("1h count: %w", err)
}
// Today (UTC)
year, month, day := now.Date()
startOfDay := time.Date(year, month, day, 0, 0, 0, 0, time.UTC)
s.db.WithContext(ctx).Model(&models.RequestLog{}).Where("request_time >= ?", startOfDay).Count(&count1d)
var count1d int64
if err := s.db.WithContext(ctx).
Model(&models.RequestLog{}).
Where("request_time >= ?", startOfDay).
Count(&count1d).Error; err != nil {
return fmt.Errorf("1d count: %w", err)
}
// Last 30 days, using the aggregation table
var count30d int64
s.db.WithContext(ctx).Model(&models.StatsHourly{}).Where("time >= ?", now.AddDate(0, 0, -30)).Select("COALESCE(SUM(request_count), 0)").Scan(&count30d)
if err := s.db.WithContext(ctx).
Model(&models.StatsHourly{}).
Where("time >= ?", now.AddDate(0, 0, -30)).
Select("COALESCE(SUM(request_count), 0)").
Scan(&count30d).Error; err != nil {
return fmt.Errorf("30d count: %w", err)
}
resp.RequestCounts["1m"] = count1m
resp.RequestCounts["1h"] = count1h
resp.RequestCounts["1d"] = count1d
resp.RequestCounts["30d"] = count30d
return nil
}
// loadUpstreamHealth loads upstream health status.
func (s *DashboardQueryService) loadUpstreamHealth(ctx context.Context, resp *models.DashboardStatsResponse) error {
var upstreams []*models.UpstreamEndpoint
if err := s.db.WithContext(ctx).Find(&upstreams).Error; err != nil {
s.logger.WithError(err).Warn("Failed to load upstream statuses for dashboard.")
} else {
for _, u := range upstreams {
resp.UpstreamHealthStatus[u.URL] = u.Status
return err
}
for _, u := range upstreams {
resp.UpstreamHealthStatus[u.URL] = u.Status
}
return nil
}
// ==================== Event listening ====================
// eventListener listens for cache invalidation events.
func (s *DashboardQueryService) eventListener() {
defer s.wg.Done()
// Subscribe to events
keyStatusSub, err1 := s.store.Subscribe(s.ctx, models.TopicKeyStatusChanged)
upstreamStatusSub, err2 := s.store.Subscribe(s.ctx, models.TopicUpstreamHealthChanged)
// Handle subscription errors
if err1 != nil {
s.logger.WithError(err1).Error("Failed to subscribe to key status events")
keyStatusSub = nil
}
if err2 != nil {
s.logger.WithError(err2).Error("Failed to subscribe to upstream status events")
upstreamStatusSub = nil
}
// If both subscriptions failed, return immediately
if keyStatusSub == nil && upstreamStatusSub == nil {
s.logger.Error("All event subscriptions failed, listener disabled")
return
}
// Close subscriptions safely
defer func() {
if keyStatusSub != nil {
if err := keyStatusSub.Close(); err != nil {
s.logger.WithError(err).Warn("Failed to close key status subscription")
}
}
if upstreamStatusSub != nil {
if err := upstreamStatusSub.Close(); err != nil {
s.logger.WithError(err).Warn("Failed to close upstream status subscription")
}
}
}()
s.logger.WithFields(logrus.Fields{
"key_status_sub": keyStatusSub != nil,
"upstream_status_sub": upstreamStatusSub != nil,
}).Info("Event listener started")
for {
// Dynamically pick only the valid channels. A nil channel blocks forever in
// a select, so a failed or closed subscription is simply ignored instead of
// busy-looping on a closed placeholder channel.
var keyStatusChan <-chan *store.Message
if keyStatusSub != nil {
keyStatusChan = keyStatusSub.Channel()
}
var upstreamStatusChan <-chan *store.Message
if upstreamStatusSub != nil {
upstreamStatusChan = upstreamStatusSub.Channel()
}
select {
case _, ok := <-keyStatusChan:
if !ok {
s.logger.Warn("Key status channel closed")
keyStatusSub = nil
continue
}
s.logger.Debug("Received key status changed event")
if err := s.InvalidateOverviewCache(); err != nil {
s.logger.WithError(err).Warn("Failed to invalidate cache on key status change")
}
case _, ok := <-upstreamStatusChan:
if !ok {
s.logger.Warn("Upstream status channel closed")
upstreamStatusSub = nil
continue
}
s.logger.Debug("Received upstream status changed event")
if err := s.InvalidateOverviewCache(); err != nil {
s.logger.WithError(err).Warn("Failed to invalidate cache on upstream status change")
}
case <-s.stopChan:
s.logger.Info("Event listener stopping (stopChan)")
return
case <-s.ctx.Done():
s.logger.Info("Event listener stopping (context cancelled)")
return
}
}
duration := time.Since(startTime)
s.logger.Infof("[CacheSyncer] Successfully finished loading overview data in %s.", duration)
return resp, nil
}
func (s *DashboardQueryService) GetRequestStatsForPeriod(ctx context.Context, period string) (gin.H, error) {
var startTime time.Time
now := time.Now()
switch period {
case "1m":
startTime = now.Add(-1 * time.Minute)
case "1h":
startTime = now.Add(-1 * time.Hour)
case "1d":
year, month, day := now.UTC().Date()
startTime = time.Date(year, month, day, 0, 0, 0, 0, time.UTC)
default:
return nil, fmt.Errorf("invalid period specified: %s", period)
}
var result struct {
Total int64
Success int64
}
// ==================== Monitoring metrics ====================
err := s.db.WithContext(ctx).Model(&models.RequestLog{}).
Select("count(*) as total, sum(case when is_success = true then 1 else 0 end) as success").
Where("request_time >= ?", startTime).
Scan(&result).Error
if err != nil {
return nil, err
// metricsReporter periodically reports statistics.
func (s *DashboardQueryService) metricsReporter() {
defer s.wg.Done()
ticker := time.NewTicker(5 * time.Minute)
defer ticker.Stop()
for {
select {
case <-ticker.C:
s.reportMetrics()
case <-s.stopChan:
return
case <-s.ctx.Done():
return
}
}
return gin.H{
"total": result.Total,
"success": result.Success,
"failure": result.Total - result.Success,
}, nil
}
func (s *DashboardQueryService) reportMetrics() {
s.lastQueryMutex.RLock()
lastQuery := s.lastQueryTime
s.lastQueryMutex.RUnlock()
totalQueries := s.queryCount.Load()
hits := s.cacheHits.Load()
misses := s.cacheMisses.Load()
var cacheHitRate float64
if hits+misses > 0 {
cacheHitRate = float64(hits) / float64(hits+misses) * 100
}
s.logger.WithFields(logrus.Fields{
"total_queries": totalQueries,
"cache_hits": hits,
"cache_misses": misses,
"cache_hit_rate": fmt.Sprintf("%.2f%%", cacheHitRate),
"overview_loads": s.overviewLoadCount.Load(),
"last_query_ago": time.Since(lastQuery).Round(time.Second),
}).Info("DashboardQuery metrics")
}
// GetMetrics returns the current statistics (for monitoring).
func (s *DashboardQueryService) GetMetrics() map[string]interface{} {
s.lastQueryMutex.RLock()
lastQuery := s.lastQueryTime
s.lastQueryMutex.RUnlock()
hits := s.cacheHits.Load()
misses := s.cacheMisses.Load()
var cacheHitRate float64
if hits+misses > 0 {
cacheHitRate = float64(hits) / float64(hits+misses) * 100
}
return map[string]interface{}{
"total_queries": s.queryCount.Load(),
"cache_hits": hits,
"cache_misses": misses,
"cache_hit_rate": cacheHitRate,
"overview_loads": s.overviewLoadCount.Load(),
"last_query_ago": time.Since(lastQuery).Seconds(),
}
}
// ==================== Helper methods ====================
// calculateFailureRate computes the failure rate as a percentage.
func (s *DashboardQueryService) calculateFailureRate(total, success int64) float64 {
if total == 0 {
return 0.0
}
return float64(total-success) / float64(total) * 100
}
// updateLastQueryTime updates the last query timestamp.
func (s *DashboardQueryService) updateLastQueryTime() {
s.lastQueryMutex.Lock()
s.lastQueryTime = time.Now()
s.lastQueryMutex.Unlock()
}
// buildTimeFormatSelectClause builds the time-format clause for the current database dialect.
func (s *DashboardQueryService) buildTimeFormatSelectClause() (string, string) {
dialect := s.db.Dialector.Name()
switch dialect {
case "mysql":
return "DATE_FORMAT(time, '%Y-%m-%d %H:00:00')", "2006-01-02 15:00:00"
case "postgres":
return "TO_CHAR(time, 'YYYY-MM-DD HH24:00:00')", "2006-01-02 15:00:00"
case "sqlite":
return "strftime('%Y-%m-%d %H:00:00', time)", "2006-01-02 15:00:00"
default:
s.logger.WithField("dialect", dialect).Warn("Unknown database dialect, using SQLite format")
return "strftime('%Y-%m-%d %H:00:00', time)", "2006-01-02 15:00:00"
}
}
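The SQL format string and the Go layout returned together are deliberately paired: the database renders time_label with the SQL expression, while the label axis in QueryHistoricalChart is generated with the Go layout, so both must produce identical strings for the dataPoints[label] lookups to hit. A small check of the Go side (illustrative only):

t := time.Date(2025, 11, 24, 3, 0, 0, 0, time.UTC)
fmt.Println(t.Format("2006-01-02 15:00:00")) // 2025-11-24 03:00:00, matching the DATE_FORMAT/TO_CHAR/strftime output above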

View File

@@ -4,11 +4,13 @@ package service
import (
"context"
"encoding/json"
"sync"
"sync/atomic"
"time"
"gemini-balancer/internal/models"
"gemini-balancer/internal/settings"
"gemini-balancer/internal/store"
"sync"
"time"
"github.com/sirupsen/logrus"
"gorm.io/gorm"
@@ -18,25 +20,47 @@ type DBLogWriterService struct {
db *gorm.DB
store store.Store
logger *logrus.Entry
logBuffer chan *models.RequestLog
stopChan chan struct{}
wg sync.WaitGroup
SettingsManager *settings.SettingsManager
settingsManager *settings.SettingsManager
logBuffer chan *models.RequestLog
stopChan chan struct{}
wg sync.WaitGroup
ctx context.Context
cancel context.CancelFunc
// metrics counters
totalReceived atomic.Uint64
totalFlushed atomic.Uint64
totalDropped atomic.Uint64
flushCount atomic.Uint64
lastFlushTime time.Time
lastFlushMutex sync.RWMutex
}
func NewDBLogWriterService(db *gorm.DB, s store.Store, settings *settings.SettingsManager, logger *logrus.Logger) *DBLogWriterService {
cfg := settings.GetSettings()
func NewDBLogWriterService(
db *gorm.DB,
s store.Store,
settingsManager *settings.SettingsManager,
logger *logrus.Logger,
) *DBLogWriterService {
cfg := settingsManager.GetSettings()
bufferCapacity := cfg.LogBufferCapacity
if bufferCapacity <= 0 {
bufferCapacity = 1000
}
ctx, cancel := context.WithCancel(context.Background())
return &DBLogWriterService{
db: db,
store: s,
SettingsManager: settings,
settingsManager: settingsManager,
logger: logger.WithField("component", "DBLogWriter📝"),
logBuffer: make(chan *models.RequestLog, bufferCapacity),
stopChan: make(chan struct{}),
ctx: ctx,
cancel: cancel,
lastFlushTime: time.Now(),
}
}
@@ -44,93 +68,276 @@ func (s *DBLogWriterService) Start() {
s.wg.Add(2)
go s.eventListenerLoop()
go s.dbWriterLoop()
s.logger.Info("DBLogWriterService started.")
// Periodically report statistics
s.wg.Add(1)
go s.metricsReporter()
s.logger.WithFields(logrus.Fields{
"buffer_capacity": cap(s.logBuffer),
}).Info("DBLogWriterService started")
}
func (s *DBLogWriterService) Stop() {
s.logger.Info("DBLogWriterService stopping...")
close(s.stopChan)
s.cancel() // cancel the context
s.wg.Wait()
s.logger.Info("DBLogWriterService stopped.")
// log final statistics
s.logger.WithFields(logrus.Fields{
"total_received": s.totalReceived.Load(),
"total_flushed": s.totalFlushed.Load(),
"total_dropped": s.totalDropped.Load(),
"flush_count": s.flushCount.Load(),
}).Info("DBLogWriterService stopped")
}
// Event listener loop
func (s *DBLogWriterService) eventListenerLoop() {
defer s.wg.Done()
ctx := context.Background()
sub, err := s.store.Subscribe(ctx, models.TopicRequestFinished)
sub, err := s.store.Subscribe(s.ctx, models.TopicRequestFinished)
if err != nil {
s.logger.Fatalf("Failed to subscribe to topic %s: %v", models.TopicRequestFinished, err)
s.logger.WithError(err).Error("Failed to subscribe to request events, log writing disabled")
return
}
defer sub.Close()
defer func() {
if err := sub.Close(); err != nil {
s.logger.WithError(err).Warn("Failed to close subscription")
}
}()
s.logger.Info("Subscribed to request events for database logging.")
s.logger.Info("Subscribed to request events for database logging")
for {
select {
case msg := <-sub.Channel():
var event models.RequestFinishedEvent
if err := json.Unmarshal(msg.Payload, &event); err != nil {
s.logger.Errorf("Failed to unmarshal event for logging: %v", err)
continue
}
select {
case s.logBuffer <- &event.RequestLog:
default:
s.logger.Warn("Log buffer is full. A log message might be dropped.")
}
s.handleMessage(msg)
case <-s.stopChan:
s.logger.Info("Event listener loop stopping.")
s.logger.Info("Event listener loop stopping")
close(s.logBuffer)
return
case <-s.ctx.Done():
s.logger.Info("Event listener context cancelled")
close(s.logBuffer)
return
}
}
}
// Process a single message
func (s *DBLogWriterService) handleMessage(msg *store.Message) {
var event models.RequestFinishedEvent
if err := json.Unmarshal(msg.Payload, &event); err != nil {
s.logger.WithError(err).Error("Failed to unmarshal request event")
return
}
s.totalReceived.Add(1)
select {
case s.logBuffer <- &event.RequestLog:
// enqueued successfully
default:
// Buffer full: drop the log entry
dropped := s.totalDropped.Add(1)
if dropped%100 == 1 { // warn once per 100 dropped entries
s.logger.WithFields(logrus.Fields{
"total_dropped": dropped,
"buffer_capacity": cap(s.logBuffer),
"buffer_len": len(s.logBuffer),
}).Warn("Log buffer full, messages being dropped")
}
}
}
// Database writer loop
func (s *DBLogWriterService) dbWriterLoop() {
defer s.wg.Done()
cfg := s.SettingsManager.GetSettings()
cfg := s.settingsManager.GetSettings()
batchSize := cfg.LogFlushBatchSize
if batchSize <= 0 {
batchSize = 100
}
flushTimeout := time.Duration(cfg.LogFlushIntervalSeconds) * time.Second
if flushTimeout <= 0 {
flushTimeout = 5 * time.Second
flushInterval := time.Duration(cfg.LogFlushIntervalSeconds) * time.Second
if flushInterval <= 0 {
flushInterval = 5 * time.Second
}
s.logger.WithFields(logrus.Fields{
"batch_size": batchSize,
"flush_interval": flushInterval,
}).Info("DB writer loop started")
batch := make([]*models.RequestLog, 0, batchSize)
ticker := time.NewTicker(flushTimeout)
ticker := time.NewTicker(flushInterval)
defer ticker.Stop()
// Config hot-reload check (every minute)
configTicker := time.NewTicker(1 * time.Minute)
defer configTicker.Stop()
for {
select {
case logEntry, ok := <-s.logBuffer:
if !ok {
// Channel closed; flush any remaining logs
if len(batch) > 0 {
s.flushBatch(batch)
}
s.logger.Info("DB writer loop finished.")
s.logger.Info("DB writer loop finished")
return
}
batch = append(batch, logEntry)
if len(batch) >= batchSize {
s.flushBatch(batch)
batch = make([]*models.RequestLog, 0, batchSize)
}
case <-ticker.C:
if len(batch) > 0 {
s.flushBatch(batch)
batch = make([]*models.RequestLog, 0, batchSize)
}
case <-configTicker.C:
// Hot-reload the configuration
cfg := s.settingsManager.GetSettings()
newBatchSize := cfg.LogFlushBatchSize
if newBatchSize <= 0 {
newBatchSize = 100
}
newFlushInterval := time.Duration(cfg.LogFlushIntervalSeconds) * time.Second
if newFlushInterval <= 0 {
newFlushInterval = 5 * time.Second
}
if newBatchSize != batchSize {
s.logger.WithFields(logrus.Fields{
"old": batchSize,
"new": newBatchSize,
}).Info("Batch size updated")
batchSize = newBatchSize
if len(batch) >= batchSize {
s.flushBatch(batch)
batch = make([]*models.RequestLog, 0, batchSize)
}
}
if newFlushInterval != flushInterval {
s.logger.WithFields(logrus.Fields{
"old": flushInterval,
"new": newFlushInterval,
}).Info("Flush interval updated")
flushInterval = newFlushInterval
ticker.Reset(flushInterval)
}
}
}
}
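Note that only the batch size and flush interval are hot-reloaded by the config ticker above; the buffer capacity is fixed when the service is constructed, so changing LogBufferCapacity still requires a restart.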
// Flush a batch to the database
func (s *DBLogWriterService) flushBatch(batch []*models.RequestLog) {
if err := s.db.CreateInBatches(batch, len(batch)).Error; err != nil {
s.logger.WithField("batch_size", len(batch)).WithError(err).Error("Failed to flush log batch to database.")
if len(batch) == 0 {
return
}
start := time.Now()
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
err := s.db.WithContext(ctx).CreateInBatches(batch, len(batch)).Error
duration := time.Since(start)
s.lastFlushMutex.Lock()
s.lastFlushTime = time.Now()
s.lastFlushMutex.Unlock()
if err != nil {
s.logger.WithFields(logrus.Fields{
"batch_size": len(batch),
"duration": duration,
}).WithError(err).Error("Failed to flush log batch to database")
} else {
s.logger.Infof("Successfully flushed %d logs to database.", len(batch))
flushed := s.totalFlushed.Add(uint64(len(batch)))
flushCount := s.flushCount.Add(1)
// Only log at Info level for slow writes or large batches
if duration > 1*time.Second || len(batch) > 500 {
s.logger.WithFields(logrus.Fields{
"batch_size": len(batch),
"duration": duration,
"total_flushed": flushed,
"flush_count": flushCount,
}).Info("Log batch flushed to database")
} else {
s.logger.WithFields(logrus.Fields{
"batch_size": len(batch),
"duration": duration,
}).Debug("Log batch flushed to database")
}
}
}
// Periodically report statistics
func (s *DBLogWriterService) metricsReporter() {
defer s.wg.Done()
ticker := time.NewTicker(5 * time.Minute)
defer ticker.Stop()
for {
select {
case <-ticker.C:
s.reportMetrics()
case <-s.stopChan:
return
case <-s.ctx.Done():
return
}
}
}
func (s *DBLogWriterService) reportMetrics() {
s.lastFlushMutex.RLock()
lastFlush := s.lastFlushTime
s.lastFlushMutex.RUnlock()
received := s.totalReceived.Load()
flushed := s.totalFlushed.Load()
dropped := s.totalDropped.Load()
pending := uint64(len(s.logBuffer))
// Guard against division by zero before any events have been received.
var successRate float64
if received > 0 {
successRate = float64(flushed) / float64(received) * 100
}
s.logger.WithFields(logrus.Fields{
"received": received,
"flushed": flushed,
"dropped": dropped,
"pending": pending,
"flush_count": s.flushCount.Load(),
"last_flush": time.Since(lastFlush).Round(time.Second),
"buffer_usage": float64(pending) / float64(cap(s.logBuffer)) * 100,
"success_rate": successRate,
}).Info("DBLogWriter metrics")
}
// GetMetrics returns the current statistics (for monitoring).
func (s *DBLogWriterService) GetMetrics() map[string]interface{} {
s.lastFlushMutex.RLock()
lastFlush := s.lastFlushTime
s.lastFlushMutex.RUnlock()
return map[string]interface{}{
"total_received": s.totalReceived.Load(),
"total_flushed": s.totalFlushed.Load(),
"total_dropped": s.totalDropped.Load(),
"flush_count": s.flushCount.Load(),
"buffer_pending": len(s.logBuffer),
"buffer_capacity": cap(s.logBuffer),
"last_flush_ago": time.Since(lastFlush).Seconds(),
}
}
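For completeness, a minimal sketch of how the writer is wired up at startup (the variable names db, redisStore, settingsManager and logger are assumptions, not taken from this diff):

writer := NewDBLogWriterService(db, redisStore, settingsManager, logger)
writer.Start()
// ...
writer.Stop() // closes the buffer and flushes whatever is still queued before returning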

View File

@@ -334,7 +334,6 @@ func (gm *GroupManager) CloneKeyGroup(id uint) (*models.KeyGroup, error) {
}
func (gm *GroupManager) BuildOperationalConfig(group *models.KeyGroup) (*models.KeyGroupSettings, error) {
globalSettings := gm.settingsManager.GetSettings()
defaultModel := "gemini-1.5-flash"
opConfig := &models.KeyGroupSettings{
EnableKeyCheck: &globalSettings.EnableBaseKeyCheck,
KeyCheckConcurrency: &globalSettings.BaseKeyCheckConcurrency,
@@ -342,7 +341,7 @@ func (gm *GroupManager) BuildOperationalConfig(group *models.KeyGroup) (*models.
KeyCheckEndpoint: &globalSettings.DefaultUpstreamURL,
KeyBlacklistThreshold: &globalSettings.BlacklistThreshold,
KeyCooldownMinutes: &globalSettings.KeyCooldownMinutes,
KeyCheckModel: &defaultModel,
KeyCheckModel: &globalSettings.BaseKeyCheckModel,
MaxRetries: &globalSettings.MaxRetries,
EnableSmartGateway: &globalSettings.EnableSmartGateway,
}

File diff suppressed because it is too large

View File

@@ -23,6 +23,10 @@ const (
TaskTypeHardDeleteKeys = "hard_delete_keys"
TaskTypeRestoreKeys = "restore_keys"
chunkSize = 500
// Task timeout constants
defaultTaskTimeout = 15 * time.Minute
longTaskTimeout = time.Hour
)
type KeyImportService struct {
@@ -43,17 +47,19 @@ func NewKeyImportService(ts task.Reporter, kr repository.KeyRepository, s store.
}
}
// runTaskWithRecovery is the shared panic-recovery wrapper for tasks.
func (s *KeyImportService) runTaskWithRecovery(ctx context.Context, taskID string, resourceID string, taskFunc func()) {
defer func() {
if r := recover(); r != nil {
err := fmt.Errorf("panic recovered in task %s: %v", taskID, r)
s.logger.Error(err)
s.logger.WithField("task_id", taskID).Error(err)
s.taskService.EndTaskByID(ctx, taskID, resourceID, nil, err)
}
}()
taskFunc()
}
// StartAddKeysTask starts a bulk add-keys task.
func (s *KeyImportService) StartAddKeysTask(ctx context.Context, groupID uint, keysText string, validateOnImport bool) (*task.Status, error) {
keys := utils.ParseKeysFromText(keysText)
if len(keys) == 0 {
@@ -61,260 +67,404 @@ func (s *KeyImportService) StartAddKeysTask(ctx context.Context, groupID uint, k
}
resourceID := fmt.Sprintf("group-%d", groupID)
taskStatus, err := s.taskService.StartTask(ctx, uint(groupID), TaskTypeAddKeysToGroup, resourceID, len(keys), 15*time.Minute)
taskStatus, err := s.taskService.StartTask(ctx, uint(groupID), TaskTypeAddKeysToGroup, resourceID, len(keys), defaultTaskTimeout)
if err != nil {
return nil, err
}
go s.runTaskWithRecovery(context.Background(), taskStatus.ID, resourceID, func() {
s.runAddKeysTask(context.Background(), taskStatus.ID, resourceID, groupID, keys, validateOnImport)
})
return taskStatus, nil
}
// StartUnlinkKeysTask starts a bulk unlink-keys task.
func (s *KeyImportService) StartUnlinkKeysTask(ctx context.Context, groupID uint, keysText string) (*task.Status, error) {
keys := utils.ParseKeysFromText(keysText)
if len(keys) == 0 {
return nil, fmt.Errorf("no valid keys found")
}
resourceID := fmt.Sprintf("group-%d", groupID)
taskStatus, err := s.taskService.StartTask(ctx, uint(groupID), TaskTypeUnlinkKeysFromGroup, resourceID, len(keys), time.Hour)
taskStatus, err := s.taskService.StartTask(ctx, uint(groupID), TaskTypeUnlinkKeysFromGroup, resourceID, len(keys), longTaskTimeout)
if err != nil {
return nil, err
}
go s.runTaskWithRecovery(context.Background(), taskStatus.ID, resourceID, func() {
s.runUnlinkKeysTask(context.Background(), taskStatus.ID, resourceID, groupID, keys)
})
return taskStatus, nil
}
// StartHardDeleteKeysTask starts a hard-delete-keys task.
func (s *KeyImportService) StartHardDeleteKeysTask(ctx context.Context, keysText string) (*task.Status, error) {
keys := utils.ParseKeysFromText(keysText)
if len(keys) == 0 {
return nil, fmt.Errorf("no valid keys found")
}
resourceID := "global_hard_delete"
taskStatus, err := s.taskService.StartTask(ctx, 0, TaskTypeHardDeleteKeys, resourceID, len(keys), time.Hour)
taskStatus, err := s.taskService.StartTask(ctx, 0, TaskTypeHardDeleteKeys, resourceID, len(keys), longTaskTimeout)
if err != nil {
return nil, err
}
go s.runTaskWithRecovery(context.Background(), taskStatus.ID, resourceID, func() {
s.runHardDeleteKeysTask(context.Background(), taskStatus.ID, resourceID, keys)
})
return taskStatus, nil
}
// StartRestoreKeysTask starts a restore-keys task.
func (s *KeyImportService) StartRestoreKeysTask(ctx context.Context, keysText string) (*task.Status, error) {
keys := utils.ParseKeysFromText(keysText)
if len(keys) == 0 {
return nil, fmt.Errorf("no valid keys found")
}
resourceID := "global_restore_keys"
taskStatus, err := s.taskService.StartTask(ctx, 0, TaskTypeRestoreKeys, resourceID, len(keys), time.Hour)
resourceID := "global_restore_keys"
taskStatus, err := s.taskService.StartTask(ctx, 0, TaskTypeRestoreKeys, resourceID, len(keys), longTaskTimeout)
if err != nil {
return nil, err
}
go s.runTaskWithRecovery(context.Background(), taskStatus.ID, resourceID, func() {
s.runRestoreKeysTask(context.Background(), taskStatus.ID, resourceID, keys)
})
return taskStatus, nil
}
func (s *KeyImportService) runAddKeysTask(ctx context.Context, taskID string, resourceID string, groupID uint, keys []string, validateOnImport bool) {
uniqueKeysMap := make(map[string]struct{})
var uniqueKeyStrings []string
for _, kStr := range keys {
if _, exists := uniqueKeysMap[kStr]; !exists {
uniqueKeysMap[kStr] = struct{}{}
uniqueKeyStrings = append(uniqueKeyStrings, kStr)
}
}
if len(uniqueKeyStrings) == 0 {
s.taskService.EndTaskByID(ctx, taskID, resourceID, gin.H{"newly_linked_count": 0, "already_linked_count": 0}, nil)
return
}
keysToEnsure := make([]models.APIKey, len(uniqueKeyStrings))
for i, keyStr := range uniqueKeyStrings {
keysToEnsure[i] = models.APIKey{APIKey: keyStr}
}
allKeyModels, err := s.keyRepo.AddKeys(keysToEnsure)
// StartUnlinkKeysByFilterTask bulk-unlinks keys matching the given status filter.
func (s *KeyImportService) StartUnlinkKeysByFilterTask(ctx context.Context, groupID uint, statuses []string) (*task.Status, error) {
s.logger.Infof("Starting unlink task for group %d with status filter: %v", groupID, statuses)
keyValues, err := s.keyRepo.FindKeyValuesByStatus(groupID, statuses)
if err != nil {
s.taskService.EndTaskByID(ctx, taskID, resourceID, nil, fmt.Errorf("failed to ensure keys exist: %w", err))
return
return nil, fmt.Errorf("failed to find keys by filter: %w", err)
}
alreadyLinkedModels, err := s.keyRepo.GetKeysByValuesAndGroupID(uniqueKeyStrings, groupID)
if err != nil {
s.taskService.EndTaskByID(ctx, taskID, resourceID, nil, fmt.Errorf("failed to check for already linked keys: %w", err))
return
}
alreadyLinkedIDSet := make(map[uint]struct{})
for _, key := range alreadyLinkedModels {
alreadyLinkedIDSet[key.ID] = struct{}{}
}
var keysToLink []models.APIKey
for _, key := range allKeyModels {
if _, exists := alreadyLinkedIDSet[key.ID]; !exists {
keysToLink = append(keysToLink, key)
}
if len(keyValues) == 0 {
return nil, fmt.Errorf("no keys found matching the provided filter")
}
keysAsText := strings.Join(keyValues, "\n")
s.logger.Infof("Found %d keys to unlink for group %d.", len(keyValues), groupID)
return s.StartUnlinkKeysTask(ctx, groupID, keysAsText)
}
// ==================== Core task execution ====================
// runAddKeysTask performs the bulk add of keys.
func (s *KeyImportService) runAddKeysTask(ctx context.Context, taskID string, resourceID string, groupID uint, keys []string, validateOnImport bool) {
// 1. Deduplicate
uniqueKeys := s.deduplicateKeys(keys)
if len(uniqueKeys) == 0 {
s.endTaskWithResult(ctx, taskID, resourceID, gin.H{
"newly_linked_count": 0,
"already_linked_count": 0,
}, nil)
return
}
// 2. Ensure all keys exist in the database (idempotent)
allKeyModels, err := s.ensureKeysExist(uniqueKeys)
if err != nil {
s.endTaskWithResult(ctx, taskID, resourceID, nil, fmt.Errorf("failed to ensure keys exist: %w", err))
return
}
// 3. Filter out keys that are already linked
keysToLink, alreadyLinkedCount, err := s.filterNewKeys(allKeyModels, groupID, uniqueKeys)
if err != nil {
s.endTaskWithResult(ctx, taskID, resourceID, nil, fmt.Errorf("failed to check linked keys: %w", err))
return
}
// 4. Update the task's actual total
if err := s.taskService.UpdateTotalByID(ctx, taskID, len(keysToLink)); err != nil {
s.logger.WithError(err).Warnf("Failed to update total for task %s", taskID)
}
// 5. Link the keys to the group in chunks
if len(keysToLink) > 0 {
idsToLink := make([]uint, len(keysToLink))
for i, key := range keysToLink {
idsToLink[i] = key.ID
if err := s.linkKeysInChunks(ctx, taskID, groupID, keysToLink); err != nil {
s.endTaskWithResult(ctx, taskID, resourceID, nil, err)
return
}
for i := 0; i < len(idsToLink); i += chunkSize {
end := i + chunkSize
if end > len(idsToLink) {
end = len(idsToLink)
}
chunk := idsToLink[i:end]
if err := s.keyRepo.LinkKeysToGroup(ctx, groupID, chunk); err != nil {
s.taskService.EndTaskByID(ctx, taskID, resourceID, nil, fmt.Errorf("chunk failed to link keys: %w", err))
return
}
_ = s.taskService.UpdateProgressByID(ctx, taskID, i+len(chunk))
}
// 6. Handle key status according to the validation flag
if len(keysToLink) > 0 {
s.processNewlyLinkedKeys(ctx, groupID, keysToLink, validateOnImport)
}
// 7. Report the result
result := gin.H{
"newly_linked_count": len(keysToLink),
"already_linked_count": alreadyLinkedCount,
"total_linked_count": len(allKeyModels),
}
s.endTaskWithResult(ctx, taskID, resourceID, result, nil)
}
// runUnlinkKeysTask performs the bulk key unlink
func (s *KeyImportService) runUnlinkKeysTask(ctx context.Context, taskID, resourceID string, groupID uint, keys []string) {
// 1. Deduplicate the input keys
uniqueKeys := s.deduplicateKeys(keys)
// 2. Look up the keys that need to be unlinked
keysToUnlink, err := s.keyRepo.GetKeysByValuesAndGroupID(uniqueKeys, groupID)
if err != nil {
s.endTaskWithResult(ctx, taskID, resourceID, nil, fmt.Errorf("failed to find keys to unlink: %w", err))
return
}
if len(keysToUnlink) == 0 {
result := gin.H{
"unlinked_count": 0,
"hard_deleted_count": 0,
"not_found_count": len(uniqueKeys),
}
s.endTaskWithResult(ctx, taskID, resourceID, result, nil)
return
}
// 3. Extract the key IDs
idsToUnlink := s.extractKeyIDs(keysToUnlink)
// 4. Update the task total
if err := s.taskService.UpdateTotalByID(ctx, taskID, len(idsToUnlink)); err != nil {
s.logger.WithError(err).Warnf("Failed to update total for task %s", taskID)
}
// 5. Unlink in chunks
totalUnlinked, err := s.unlinkKeysInChunks(ctx, taskID, groupID, idsToUnlink)
if err != nil {
s.endTaskWithResult(ctx, taskID, resourceID, nil, err)
return
}
// 6. Clean up orphaned keys
totalDeleted, err := s.keyRepo.DeleteOrphanKeys()
if err != nil {
s.logger.WithError(err).Warn("Orphan key cleanup failed after unlink task.")
}
// 7. Report the result
result := gin.H{
"unlinked_count": totalUnlinked,
"hard_deleted_count": totalDeleted,
"not_found_count": len(uniqueKeys) - int(totalUnlinked),
}
s.endTaskWithResult(ctx, taskID, resourceID, result, nil)
}
// runHardDeleteKeysTask hard-deletes keys
func (s *KeyImportService) runHardDeleteKeysTask(ctx context.Context, taskID, resourceID string, keys []string) {
totalDeleted, err := s.processKeysInChunks(ctx, taskID, keys, func(chunk []string) (int64, error) {
return s.keyRepo.HardDeleteByValues(chunk)
})
if err != nil {
s.endTaskWithResult(ctx, taskID, resourceID, nil, err)
return
}
result := gin.H{
"newly_linked_count": len(keysToLink),
"already_linked_count": len(alreadyLinkedIDSet),
"total_linked_count": len(allKeyModels),
"hard_deleted_count": totalDeleted,
"not_found_count": int64(len(keys)) - totalDeleted,
}
if len(keysToLink) > 0 {
idsToLink := make([]uint, len(keysToLink))
for i, key := range keysToLink {
idsToLink[i] = key.ID
}
if validateOnImport {
s.publishImportGroupCompletedEvent(ctx, groupID, idsToLink)
for _, keyID := range idsToLink {
s.publishSingleKeyChangeEvent(ctx, groupID, keyID, "", models.StatusPendingValidation, "key_linked")
}
} else {
for _, keyID := range idsToLink {
if _, err := s.apiKeyService.UpdateMappingStatus(ctx, groupID, keyID, models.StatusActive); err != nil {
s.logger.Errorf("Failed to directly activate key ID %d in group %d: %v", keyID, groupID, err)
}
}
}
}
s.taskService.EndTaskByID(ctx, taskID, resourceID, result, nil)
s.endTaskWithResult(ctx, taskID, resourceID, result, nil)
s.publishChangeEvent(ctx, 0, "keys_hard_deleted")
}
func (s *KeyImportService) runUnlinkKeysTask(ctx context.Context, taskID, resourceID string, groupID uint, keys []string) {
uniqueKeysMap := make(map[string]struct{})
var uniqueKeys []string
// runRestoreKeysTask restores keys
func (s *KeyImportService) runRestoreKeysTask(ctx context.Context, taskID, resourceID string, keys []string) {
restoredCount, err := s.processKeysInChunks(ctx, taskID, keys, func(chunk []string) (int64, error) {
return s.keyRepo.UpdateMasterStatusByValues(chunk, models.MasterStatusActive)
})
if err != nil {
s.endTaskWithResult(ctx, taskID, resourceID, nil, err)
return
}
result := gin.H{
"restored_count": restoredCount,
"not_found_count": int64(len(keys)) - restoredCount,
}
s.endTaskWithResult(ctx, taskID, resourceID, result, nil)
s.publishChangeEvent(ctx, 0, "keys_bulk_restored")
}
// ==================== Helpers ====================
// deduplicateKeys removes duplicate key strings while preserving order
func (s *KeyImportService) deduplicateKeys(keys []string) []string {
uniqueKeysMap := make(map[string]struct{}, len(keys))
uniqueKeys := make([]string, 0, len(keys))
for _, kStr := range keys {
if _, exists := uniqueKeysMap[kStr]; !exists {
uniqueKeysMap[kStr] = struct{}{}
uniqueKeys = append(uniqueKeys, kStr)
}
}
keysToUnlink, err := s.keyRepo.GetKeysByValuesAndGroupID(uniqueKeys, groupID)
return uniqueKeys
}
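// Example (illustrative): deduplicateKeys on []string{"a", "b", "a"} returns
// []string{"a", "b"}, preserving first-occurrence order.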
// ensureKeysExist makes sure every key exists in the database
func (s *KeyImportService) ensureKeysExist(keys []string) ([]models.APIKey, error) {
keysToEnsure := make([]models.APIKey, len(keys))
for i, keyStr := range keys {
keysToEnsure[i] = models.APIKey{APIKey: keyStr}
}
return s.keyRepo.AddKeys(keysToEnsure)
}
// filterNewKeys filters out keys already linked to the group and returns the ones that still need linking
func (s *KeyImportService) filterNewKeys(allKeyModels []models.APIKey, groupID uint, uniqueKeys []string) ([]models.APIKey, int, error) {
alreadyLinkedModels, err := s.keyRepo.GetKeysByValuesAndGroupID(uniqueKeys, groupID)
if err != nil {
s.taskService.EndTaskByID(ctx, taskID, resourceID, nil, fmt.Errorf("failed to find keys to unlink: %w", err))
return
return nil, 0, err
}
if len(keysToUnlink) == 0 {
result := gin.H{"unlinked_count": 0, "hard_deleted_count": 0, "not_found_count": len(uniqueKeys)}
s.taskService.EndTaskByID(ctx, taskID, resourceID, result, nil)
return
}
idsToUnlink := make([]uint, len(keysToUnlink))
for i, key := range keysToUnlink {
idsToUnlink[i] = key.ID
alreadyLinkedIDSet := make(map[uint]struct{}, len(alreadyLinkedModels))
for _, key := range alreadyLinkedModels {
alreadyLinkedIDSet[key.ID] = struct{}{}
}
if err := s.taskService.UpdateTotalByID(ctx, taskID, len(idsToUnlink)); err != nil {
s.logger.WithError(err).Warnf("Failed to update total for task %s", taskID)
}
var totalUnlinked int64
for i := 0; i < len(idsToUnlink); i += chunkSize {
end := i + chunkSize
if end > len(idsToUnlink) {
end = len(idsToUnlink)
keysToLink := make([]models.APIKey, 0, len(allKeyModels)-len(alreadyLinkedIDSet))
for _, key := range allKeyModels {
if _, exists := alreadyLinkedIDSet[key.ID]; !exists {
keysToLink = append(keysToLink, key)
}
}
return keysToLink, len(alreadyLinkedIDSet), nil
}
// extractKeyIDs collects the IDs of the given keys
func (s *KeyImportService) extractKeyIDs(keys []models.APIKey) []uint {
ids := make([]uint, len(keys))
for i, key := range keys {
ids[i] = key.ID
}
return ids
}
// linkKeysInChunks links keys to the group in chunks
func (s *KeyImportService) linkKeysInChunks(ctx context.Context, taskID string, groupID uint, keysToLink []models.APIKey) error {
idsToLink := s.extractKeyIDs(keysToLink)
for i := 0; i < len(idsToLink); i += chunkSize {
end := min(i+chunkSize, len(idsToLink))
chunk := idsToLink[i:end]
if err := s.keyRepo.LinkKeysToGroup(ctx, groupID, chunk); err != nil {
return fmt.Errorf("chunk failed to link keys: %w", err)
}
_ = s.taskService.UpdateProgressByID(ctx, taskID, end)
}
return nil
}
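// Worked example (the chunkSize value here is hypothetical): with chunkSize = 100
// and 250 key IDs, the chunks are ids[0:100], ids[100:200] and ids[200:250],
// and progress is reported as 100, 200 and finally 250.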
// unlinkKeysInChunks unlinks keys from the group in chunks
func (s *KeyImportService) unlinkKeysInChunks(ctx context.Context, taskID string, groupID uint, idsToUnlink []uint) (int64, error) {
var totalUnlinked int64
for i := 0; i < len(idsToUnlink); i += chunkSize {
end := min(i+chunkSize, len(idsToUnlink))
chunk := idsToUnlink[i:end]
unlinked, err := s.keyRepo.UnlinkKeysFromGroup(ctx, groupID, chunk)
if err != nil {
s.taskService.EndTaskByID(ctx, taskID, resourceID, nil, fmt.Errorf("chunk failed: could not unlink keys: %w", err))
return
return 0, fmt.Errorf("chunk failed: could not unlink keys: %w", err)
}
totalUnlinked += unlinked
// Publish an unlink event for each key
for _, keyID := range chunk {
s.publishSingleKeyChangeEvent(ctx, groupID, keyID, models.StatusActive, "", "key_unlinked")
}
_ = s.taskService.UpdateProgressByID(ctx, taskID, i+len(chunk))
_ = s.taskService.UpdateProgressByID(ctx, taskID, end)
}
totalDeleted, err := s.keyRepo.DeleteOrphanKeys()
return totalUnlinked, nil
}
// processKeysInChunks is the generic chunked key-processing loop
func (s *KeyImportService) processKeysInChunks(
ctx context.Context,
taskID string,
keys []string,
processFunc func(chunk []string) (int64, error),
) (int64, error) {
var totalProcessed int64
for i := 0; i < len(keys); i += chunkSize {
end := min(i+chunkSize, len(keys))
chunk := keys[i:end]
count, err := processFunc(chunk)
if err != nil {
return 0, fmt.Errorf("failed to process chunk: %w", err)
}
totalProcessed += count
_ = s.taskService.UpdateProgressByID(ctx, taskID, end)
}
return totalProcessed, nil
}
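// Usage sketch: any bulk operation that maps a chunk of key strings to an
// affected-row count plugs in as processFunc, e.g. reusing the repository
// method already shown above:
//
//	deleted, err := s.processKeysInChunks(ctx, taskID, keys, func(chunk []string) (int64, error) {
//		return s.keyRepo.HardDeleteByValues(chunk)
//	})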
// processNewlyLinkedKeys handles newly linked keys (validate or activate directly)
func (s *KeyImportService) processNewlyLinkedKeys(ctx context.Context, groupID uint, keysToLink []models.APIKey, validateOnImport bool) {
idsToLink := s.extractKeyIDs(keysToLink)
if validateOnImport {
// Publish the bulk-import-completed event to trigger validation
s.publishImportGroupCompletedEvent(ctx, groupID, idsToLink)
// Publish a status-change event for each key
for _, keyID := range idsToLink {
s.publishSingleKeyChangeEvent(ctx, groupID, keyID, "", models.StatusPendingValidation, "key_linked")
}
} else {
// Activate the keys directly, skipping validation
for _, keyID := range idsToLink {
if _, err := s.apiKeyService.UpdateMappingStatus(ctx, groupID, keyID, models.StatusActive); err != nil {
s.logger.WithFields(logrus.Fields{
"group_id": groupID,
"key_id": keyID,
}).Errorf("Failed to directly activate key: %v", err)
}
}
}
}
// endTaskWithResult is the single place where tasks are finalized
func (s *KeyImportService) endTaskWithResult(ctx context.Context, taskID, resourceID string, result gin.H, err error) {
if err != nil {
s.logger.WithError(err).Warn("Orphan key cleanup failed after unlink task.")
s.logger.WithFields(logrus.Fields{
"task_id": taskID,
"resource_id": resourceID,
}).WithError(err).Error("Task failed")
}
result := gin.H{
"unlinked_count": totalUnlinked,
"hard_deleted_count": totalDeleted,
"not_found_count": len(uniqueKeys) - int(totalUnlinked),
}
s.taskService.EndTaskByID(ctx, taskID, resourceID, result, nil)
s.taskService.EndTaskByID(ctx, taskID, resourceID, result, err)
}
func (s *KeyImportService) runHardDeleteKeysTask(ctx context.Context, taskID, resourceID string, keys []string) {
var totalDeleted int64
for i := 0; i < len(keys); i += chunkSize {
end := i + chunkSize
if end > len(keys) {
end = len(keys)
}
chunk := keys[i:end]
deleted, err := s.keyRepo.HardDeleteByValues(chunk)
if err != nil {
s.taskService.EndTaskByID(ctx, taskID, resourceID, nil, fmt.Errorf("failed to hard delete chunk: %w", err))
return
}
totalDeleted += deleted
_ = s.taskService.UpdateProgressByID(ctx, taskID, i+len(chunk))
}
result := gin.H{
"hard_deleted_count": totalDeleted,
"not_found_count": int64(len(keys)) - totalDeleted,
}
s.taskService.EndTaskByID(ctx, taskID, resourceID, result, nil)
s.publishChangeEvent(ctx, 0, "keys_hard_deleted")
}
func (s *KeyImportService) runRestoreKeysTask(ctx context.Context, taskID, resourceID string, keys []string) {
var restoredCount int64
for i := 0; i < len(keys); i += chunkSize {
end := i + chunkSize
if end > len(keys) {
end = len(keys)
}
chunk := keys[i:end]
count, err := s.keyRepo.UpdateMasterStatusByValues(chunk, models.MasterStatusActive)
if err != nil {
s.taskService.EndTaskByID(ctx, taskID, resourceID, nil, fmt.Errorf("failed to restore chunk: %w", err))
return
}
restoredCount += count
_ = s.taskService.UpdateProgressByID(ctx, taskID, i+len(chunk))
}
result := gin.H{
"restored_count": restoredCount,
"not_found_count": int64(len(keys)) - restoredCount,
}
s.taskService.EndTaskByID(ctx, taskID, resourceID, result, nil)
s.publishChangeEvent(ctx, 0, "keys_bulk_restored")
}
// ==================== Event publishing ====================
// publishSingleKeyChangeEvent publishes a status-change event for a single key
func (s *KeyImportService) publishSingleKeyChangeEvent(ctx context.Context, groupID, keyID uint, oldStatus, newStatus models.APIKeyStatus, reason string) {
event := models.KeyStatusChangedEvent{
GroupID: groupID,
@@ -324,56 +474,88 @@ func (s *KeyImportService) publishSingleKeyChangeEvent(ctx context.Context, grou
ChangeReason: reason,
ChangedAt: time.Now(),
}
eventData, _ := json.Marshal(event)
if err := s.store.Publish(ctx, models.TopicKeyStatusChanged, eventData); err != nil {
s.logger.WithError(err).WithFields(logrus.Fields{
eventData, err := json.Marshal(event)
if err != nil {
s.logger.WithFields(logrus.Fields{
"group_id": groupID,
"key_id": keyID,
"reason": reason,
}).Error("Failed to publish single key change event.")
}).WithError(err).Error("Failed to marshal key change event")
return
}
if err := s.store.Publish(ctx, models.TopicKeyStatusChanged, eventData); err != nil {
s.logger.WithFields(logrus.Fields{
"group_id": groupID,
"key_id": keyID,
"reason": reason,
}).WithError(err).Error("Failed to publish single key change event")
}
}
// publishChangeEvent publishes a generic change event
func (s *KeyImportService) publishChangeEvent(ctx context.Context, groupID uint, reason string) {
event := models.KeyStatusChangedEvent{
GroupID: groupID,
ChangeReason: reason,
}
eventData, _ := json.Marshal(event)
_ = s.store.Publish(ctx, models.TopicKeyStatusChanged, eventData)
eventData, err := json.Marshal(event)
if err != nil {
s.logger.WithFields(logrus.Fields{
"group_id": groupID,
"reason": reason,
}).WithError(err).Error("Failed to marshal change event")
return
}
if err := s.store.Publish(ctx, models.TopicKeyStatusChanged, eventData); err != nil {
s.logger.WithFields(logrus.Fields{
"group_id": groupID,
"reason": reason,
}).WithError(err).Error("Failed to publish change event")
}
}
// publishImportGroupCompletedEvent publishes the bulk-import-completed event
func (s *KeyImportService) publishImportGroupCompletedEvent(ctx context.Context, groupID uint, keyIDs []uint) {
if len(keyIDs) == 0 {
return
}
event := models.ImportGroupCompletedEvent{
GroupID: groupID,
KeyIDs: keyIDs,
CompletedAt: time.Now(),
}
eventData, err := json.Marshal(event)
if err != nil {
s.logger.WithError(err).Error("Failed to marshal ImportGroupCompletedEvent.")
s.logger.WithFields(logrus.Fields{
"group_id": groupID,
"key_count": len(keyIDs),
}).WithError(err).Error("Failed to marshal ImportGroupCompletedEvent")
return
}
if err := s.store.Publish(ctx, models.TopicImportGroupCompleted, eventData); err != nil {
s.logger.WithError(err).Error("Failed to publish ImportGroupCompletedEvent.")
s.logger.WithFields(logrus.Fields{
"group_id": groupID,
"key_count": len(keyIDs),
}).WithError(err).Error("Failed to publish ImportGroupCompletedEvent")
} else {
s.logger.Infof("Published ImportGroupCompletedEvent for group %d with %d keys.", groupID, len(keyIDs))
s.logger.WithFields(logrus.Fields{
"group_id": groupID,
"key_count": len(keyIDs),
}).Info("Published ImportGroupCompletedEvent")
}
}
func (s *KeyImportService) StartUnlinkKeysByFilterTask(ctx context.Context, groupID uint, statuses []string) (*task.Status, error) {
s.logger.Infof("Starting unlink task for group %d with status filter: %v", groupID, statuses)
keyValues, err := s.keyRepo.FindKeyValuesByStatus(groupID, statuses)
if err != nil {
return nil, fmt.Errorf("failed to find keys by filter: %w", err)
// min returns the smaller of two integers
func min(a, b int) int {
if a < b {
return a
}
if len(keyValues) == 0 {
return nil, fmt.Errorf("no keys found matching the provided filter")
}
keysAsText := strings.Join(keyValues, "\n")
s.logger.Infof("Found %d keys to unlink for group %d.", len(keyValues), groupID)
return s.StartUnlinkKeysTask(ctx, groupID, keysAsText)
return b
}
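// Note: Go 1.21+ predeclares a generic min builtin; this local helper is only
// needed on older toolchains (the module's Go version is not visible in this
// diff), and on newer ones it simply shadows the builtin within this package.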

View File

@@ -25,26 +25,38 @@ import (
)
const (
TaskTypeTestKeys = "test_keys"
TaskTypeTestKeys = "test_keys"
defaultConcurrency = 10
maxValidationConcurrency = 100
validationTaskTimeout = time.Hour
)
type KeyValidationService struct {
taskService task.Reporter
channel channel.ChannelProxy
db *gorm.DB
SettingsManager *settings.SettingsManager
settingsManager *settings.SettingsManager
groupManager *GroupManager
store store.Store
keyRepo repository.KeyRepository
logger *logrus.Entry
}
func NewKeyValidationService(ts task.Reporter, ch channel.ChannelProxy, db *gorm.DB, ss *settings.SettingsManager, gm *GroupManager, st store.Store, kr repository.KeyRepository, logger *logrus.Logger) *KeyValidationService {
func NewKeyValidationService(
ts task.Reporter,
ch channel.ChannelProxy,
db *gorm.DB,
ss *settings.SettingsManager,
gm *GroupManager,
st store.Store,
kr repository.KeyRepository,
logger *logrus.Logger,
) *KeyValidationService {
return &KeyValidationService{
taskService: ts,
channel: ch,
db: db,
SettingsManager: ss,
settingsManager: ss,
groupManager: gm,
store: st,
keyRepo: kr,
@@ -52,33 +64,393 @@ func NewKeyValidationService(ts task.Reporter, ch channel.ChannelProxy, db *gorm
}
}
// ==================== Public API ====================
// ValidateSingleKey validates a single key
func (s *KeyValidationService) ValidateSingleKey(key *models.APIKey, timeout time.Duration, endpoint string) error {
// 1. Decrypt the key
if err := s.keyRepo.Decrypt(key); err != nil {
return fmt.Errorf("failed to decrypt key %d for validation: %w", key.ID, err)
}
// 2. Create the HTTP client and request
client := &http.Client{Timeout: timeout}
req, err := http.NewRequest("GET", endpoint, nil)
if err != nil {
s.logger.Errorf("Failed to create request for key validation (ID: %d): %v", key.ID, err)
s.logger.WithFields(logrus.Fields{
"key_id": key.ID,
"endpoint": endpoint,
}).Error("Failed to create validation request")
return fmt.Errorf("failed to create request: %w", err)
}
// 3. Let the channel modify the request (adds the key auth header)
s.channel.ModifyRequest(req, key)
// 4. Execute the request
resp, err := client.Do(req)
if err != nil {
return fmt.Errorf("request failed: %w", err)
}
defer resp.Body.Close()
// 5. Check the response status
if resp.StatusCode == http.StatusOK {
return nil
}
// 6. Build an error from the failure response
return s.buildValidationError(resp)
}
// StartTestKeysTask starts a bulk key test task
func (s *KeyValidationService) StartTestKeysTask(ctx context.Context, groupID uint, keysText string) (*task.Status, error) {
// 1. Parse and validate the input
keyStrings := utils.ParseKeysFromText(keysText)
if len(keyStrings) == 0 {
return nil, CustomErrors.NewAPIError(CustomErrors.ErrBadRequest, "no valid keys found in input text")
}
// 2. Load the key models
apiKeyModels, err := s.keyRepo.GetKeysByValues(keyStrings)
if err != nil {
return nil, CustomErrors.ParseDBError(err)
}
if len(apiKeyModels) == 0 {
return nil, CustomErrors.NewAPIError(CustomErrors.ErrResourceNotFound, "none of the provided keys were found in the system")
}
// 3. Batch-decrypt the keys
if err := s.keyRepo.DecryptBatch(apiKeyModels); err != nil {
s.logger.WithError(err).Error("Failed to batch decrypt keys for validation task")
return nil, CustomErrors.NewAPIError(CustomErrors.ErrInternalServer, "Failed to decrypt keys for validation")
}
// 4. Fetch the group configuration
group, ok := s.groupManager.GetGroupByID(groupID)
if !ok {
return nil, CustomErrors.NewAPIError(CustomErrors.ErrGroupNotFound, fmt.Sprintf("group with id %d not found", groupID))
}
opConfig, err := s.groupManager.BuildOperationalConfig(group)
if err != nil {
return nil, CustomErrors.NewAPIError(CustomErrors.ErrInternalServer, fmt.Sprintf("failed to build operational config: %v", err))
}
// 5. Build the validation endpoint
endpoint, err := s.groupManager.BuildKeyCheckEndpoint(groupID)
if err != nil {
return nil, CustomErrors.NewAPIError(CustomErrors.ErrInternalServer, fmt.Sprintf("failed to build endpoint: %v", err))
}
// 6. Create the task
resourceID := fmt.Sprintf("group-%d", groupID)
taskStatus, err := s.taskService.StartTask(ctx, groupID, TaskTypeTestKeys, resourceID, len(apiKeyModels), validationTaskTimeout)
if err != nil {
return nil, err
}
// 7. Prepare the task parameters
params := s.buildValidationParams(opConfig)
// 8. Launch the asynchronous validation task
go s.runTestKeysTaskWithRecovery(context.Background(), taskStatus.ID, resourceID, groupID, apiKeyModels, params, endpoint)
return taskStatus, nil
}
// StartTestKeysByFilterTask starts a bulk test task based on a status filter
func (s *KeyValidationService) StartTestKeysByFilterTask(ctx context.Context, groupID uint, statuses []string) (*task.Status, error) {
s.logger.WithFields(logrus.Fields{
"group_id": groupID,
"statuses": statuses,
}).Info("Starting test task with status filter")
// 1. Look up keys matching the filter
keyValues, err := s.keyRepo.FindKeyValuesByStatus(groupID, statuses)
if err != nil {
return nil, CustomErrors.ParseDBError(err)
}
if len(keyValues) == 0 {
return nil, CustomErrors.NewAPIError(CustomErrors.ErrNotFound, "No keys found matching the filter criteria")
}
// 2. Convert to text form and start the task
keysAsText := strings.Join(keyValues, "\n")
s.logger.Infof("Found %d keys to validate for group %d", len(keyValues), groupID)
return s.StartTestKeysTask(ctx, groupID, keysAsText)
}
// ==================== Core task execution ====================
// validationParams bundles the validation parameters
type validationParams struct {
timeout time.Duration
concurrency int
}
// buildValidationParams builds the validation parameters
func (s *KeyValidationService) buildValidationParams(opConfig *models.KeyGroupSettings) validationParams {
settings := s.settingsManager.GetSettings()
// Read the timeout from settings (rather than hard-coding it)
timeout := time.Duration(settings.KeyCheckTimeoutSeconds) * time.Second
if timeout <= 0 {
timeout = 30 * time.Second // fall back to a default only when the setting is invalid
}
// Read the concurrency from settings (precedence: group config > global config > default)
var concurrency int
if opConfig.KeyCheckConcurrency != nil && *opConfig.KeyCheckConcurrency > 0 {
concurrency = *opConfig.KeyCheckConcurrency
} else if settings.BaseKeyCheckConcurrency > 0 {
concurrency = settings.BaseKeyCheckConcurrency
} else {
concurrency = defaultConcurrency // last-resort default
}
// Cap the concurrency as a safety measure
if concurrency > maxValidationConcurrency {
concurrency = maxValidationConcurrency
}
return validationParams{
timeout: timeout,
concurrency: concurrency,
}
}
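// Worked example (values are illustrative, not taken from the codebase): with
// opConfig.KeyCheckConcurrency = nil, settings.BaseKeyCheckConcurrency = 250
// and settings.KeyCheckTimeoutSeconds = 0, the result is concurrency = 100
// (capped at maxValidationConcurrency) and timeout = 30s (the fallback,
// because the configured timeout is not positive).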
// runTestKeysTaskWithRecovery wraps the task execution with panic recovery
func (s *KeyValidationService) runTestKeysTaskWithRecovery(
ctx context.Context,
taskID string,
resourceID string,
groupID uint,
keys []models.APIKey,
params validationParams,
endpoint string,
) {
defer func() {
if r := recover(); r != nil {
err := fmt.Errorf("panic recovered in validation task %s: %v", taskID, r)
s.logger.WithField("task_id", taskID).Error(err)
s.taskService.EndTaskByID(ctx, taskID, resourceID, nil, err)
}
}()
s.runTestKeysTask(ctx, taskID, resourceID, groupID, keys, params, endpoint)
}
// runTestKeysTask runs the bulk key validation task
func (s *KeyValidationService) runTestKeysTask(
ctx context.Context,
taskID string,
resourceID string,
groupID uint,
keys []models.APIKey,
params validationParams,
endpoint string,
) {
s.logger.WithFields(logrus.Fields{
"task_id": taskID,
"group_id": groupID,
"key_count": len(keys),
"concurrency": params.concurrency,
"timeout": params.timeout,
}).Info("Starting validation task")
// 1. Initialize result collection
results := make([]models.KeyTestResult, len(keys))
// 2. Create the job dispatcher
dispatcher := newValidationDispatcher(
keys,
params.concurrency,
s,
ctx,
taskID,
groupID,
endpoint,
params.timeout,
)
// 3. Run the concurrent validation
dispatcher.run(results)
// 4. Finish the task
s.taskService.EndTaskByID(ctx, taskID, resourceID, gin.H{"results": results}, nil)
s.logger.WithFields(logrus.Fields{
"task_id": taskID,
"group_id": groupID,
"processed": len(results),
}).Info("Validation task completed")
}
// ==================== Validation dispatcher ====================
// validationJob is a single validation job
type validationJob struct {
index int
key models.APIKey
}
// validationDispatcher fans validation jobs out to a worker pool
type validationDispatcher struct {
keys []models.APIKey
concurrency int
service *KeyValidationService
ctx context.Context
taskID string
groupID uint
endpoint string
timeout time.Duration
mu sync.Mutex
processedCount int
}
// newValidationDispatcher creates a validation dispatcher
func newValidationDispatcher(
keys []models.APIKey,
concurrency int,
service *KeyValidationService,
ctx context.Context,
taskID string,
groupID uint,
endpoint string,
timeout time.Duration,
) *validationDispatcher {
return &validationDispatcher{
keys: keys,
concurrency: concurrency,
service: service,
ctx: ctx,
taskID: taskID,
groupID: groupID,
endpoint: endpoint,
timeout: timeout,
}
}
// run performs the concurrent validation
func (d *validationDispatcher) run(results []models.KeyTestResult) {
var wg sync.WaitGroup
jobs := make(chan validationJob, len(d.keys))
// Start the worker pool
for i := 0; i < d.concurrency; i++ {
wg.Add(1)
go d.worker(&wg, jobs, results)
}
// Dispatch the jobs
for i, key := range d.keys {
jobs <- validationJob{index: i, key: key}
}
close(jobs)
// Wait for all workers to finish
wg.Wait()
}
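// Note: results is pre-sized to len(keys) and each worker writes only to its
// own job index, so the output order matches the input key order regardless
// of which worker finishes first.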
// worker is a validation worker goroutine
func (d *validationDispatcher) worker(wg *sync.WaitGroup, jobs <-chan validationJob, results []models.KeyTestResult) {
defer wg.Done()
for job := range jobs {
result := d.validateKey(job.key)
d.mu.Lock()
results[job.index] = result
d.processedCount++
_ = d.service.taskService.UpdateProgressByID(d.ctx, d.taskID, d.processedCount)
d.mu.Unlock()
}
}
// validateKey validates a single key and returns the result
func (d *validationDispatcher) validateKey(key models.APIKey) models.KeyTestResult {
// 1. Run the validation
validationErr := d.service.ValidateSingleKey(&key, d.timeout, d.endpoint)
// 2. Build the result and event
result, event := d.buildResultAndEvent(key, validationErr)
// 3. Publish the validation event
d.publishValidationEvent(key.ID, event)
return result
}
// buildResultAndEvent builds the validation result and the corresponding event
func (d *validationDispatcher) buildResultAndEvent(key models.APIKey, validationErr error) (models.KeyTestResult, models.RequestFinishedEvent) {
event := models.RequestFinishedEvent{
RequestLog: models.RequestLog{
GroupID: &d.groupID,
KeyID: &key.ID,
},
}
if validationErr == nil {
// Validation succeeded
event.RequestLog.IsSuccess = true
return models.KeyTestResult{
Key: key.APIKey,
Status: "valid",
Message: "Validation successful",
}, event
}
// Validation failed
event.RequestLog.IsSuccess = false
var apiErr *CustomErrors.APIError
if CustomErrors.As(validationErr, &apiErr) {
event.Error = apiErr
return models.KeyTestResult{
Key: key.APIKey,
Status: "invalid",
Message: fmt.Sprintf("Invalid key (HTTP %d): %s", apiErr.HTTPStatus, apiErr.Message),
}, event
}
// Any other error
event.Error = &CustomErrors.APIError{Message: validationErr.Error()}
return models.KeyTestResult{
Key: key.APIKey,
Status: "error",
Message: "Validation check failed: " + validationErr.Error(),
}, event
}
// publishValidationEvent publishes the validation event
func (d *validationDispatcher) publishValidationEvent(keyID uint, event models.RequestFinishedEvent) {
eventData, err := json.Marshal(event)
if err != nil {
d.service.logger.WithFields(logrus.Fields{
"key_id": keyID,
"group_id": d.groupID,
}).WithError(err).Error("Failed to marshal validation event")
return
}
if err := d.service.store.Publish(d.ctx, models.TopicRequestFinished, eventData); err != nil {
d.service.logger.WithFields(logrus.Fields{
"key_id": keyID,
"group_id": d.groupID,
}).WithError(err).Error("Failed to publish validation event")
}
}
// ==================== Helpers ====================
// buildValidationError builds a validation error from the HTTP response
func (s *KeyValidationService) buildValidationError(resp *http.Response) error {
bodyBytes, readErr := io.ReadAll(resp.Body)
var errorMsg string
if readErr != nil {
errorMsg = "Failed to read error response body"
s.logger.WithError(readErr).Warn("Failed to read validation error response")
} else {
errorMsg = string(bodyBytes)
}
@@ -89,128 +461,3 @@ func (s *KeyValidationService) ValidateSingleKey(key *models.APIKey, timeout tim
Code: "VALIDATION_FAILED",
}
}
func (s *KeyValidationService) StartTestKeysTask(ctx context.Context, groupID uint, keysText string) (*task.Status, error) {
keyStrings := utils.ParseKeysFromText(keysText)
if len(keyStrings) == 0 {
return nil, CustomErrors.NewAPIError(CustomErrors.ErrBadRequest, "no valid keys found in input text")
}
apiKeyModels, err := s.keyRepo.GetKeysByValues(keyStrings)
if err != nil {
return nil, CustomErrors.ParseDBError(err)
}
if len(apiKeyModels) == 0 {
return nil, CustomErrors.NewAPIError(CustomErrors.ErrResourceNotFound, "none of the provided keys were found in the system")
}
if err := s.keyRepo.DecryptBatch(apiKeyModels); err != nil {
s.logger.WithError(err).Error("Failed to batch decrypt keys for validation task.")
return nil, CustomErrors.NewAPIError(CustomErrors.ErrInternalServer, "Failed to decrypt keys for validation")
}
group, ok := s.groupManager.GetGroupByID(groupID)
if !ok {
return nil, CustomErrors.NewAPIError(CustomErrors.ErrGroupNotFound, fmt.Sprintf("group with id %d not found", groupID))
}
opConfig, err := s.groupManager.BuildOperationalConfig(group)
if err != nil {
return nil, CustomErrors.NewAPIError(CustomErrors.ErrInternalServer, fmt.Sprintf("failed to build operational config: %v", err))
}
resourceID := fmt.Sprintf("group-%d", groupID)
taskStatus, err := s.taskService.StartTask(ctx, groupID, TaskTypeTestKeys, resourceID, len(apiKeyModels), time.Hour)
if err != nil {
return nil, err
}
settings := s.SettingsManager.GetSettings()
timeout := time.Duration(settings.KeyCheckTimeoutSeconds) * time.Second
endpoint, err := s.groupManager.BuildKeyCheckEndpoint(groupID)
if err != nil {
s.taskService.EndTaskByID(ctx, taskStatus.ID, resourceID, nil, err)
return nil, err
}
var concurrency int
if opConfig.KeyCheckConcurrency != nil {
concurrency = *opConfig.KeyCheckConcurrency
} else {
concurrency = settings.BaseKeyCheckConcurrency
}
go s.runTestKeysTask(context.Background(), taskStatus.ID, resourceID, groupID, apiKeyModels, timeout, endpoint, concurrency)
return taskStatus, nil
}
func (s *KeyValidationService) runTestKeysTask(ctx context.Context, taskID string, resourceID string, groupID uint, keys []models.APIKey, timeout time.Duration, endpoint string, concurrency int) {
var wg sync.WaitGroup
var mu sync.Mutex
finalResults := make([]models.KeyTestResult, len(keys))
processedCount := 0
if concurrency <= 0 {
concurrency = 10
}
type job struct {
Index int
Value models.APIKey
}
jobs := make(chan job, len(keys))
for i := 0; i < concurrency; i++ {
wg.Add(1)
go func() {
defer wg.Done()
for j := range jobs {
apiKeyModel := j.Value
keyToValidate := apiKeyModel
validationErr := s.ValidateSingleKey(&keyToValidate, timeout, endpoint)
var currentResult models.KeyTestResult
event := models.RequestFinishedEvent{
RequestLog: models.RequestLog{
GroupID: &groupID,
KeyID: &apiKeyModel.ID,
},
}
if validationErr == nil {
currentResult = models.KeyTestResult{Key: apiKeyModel.APIKey, Status: "valid", Message: "Validation successful."}
event.RequestLog.IsSuccess = true
} else {
var apiErr *CustomErrors.APIError
if CustomErrors.As(validationErr, &apiErr) {
currentResult = models.KeyTestResult{Key: apiKeyModel.APIKey, Status: "invalid", Message: fmt.Sprintf("Invalid key (HTTP %d): %s", apiErr.HTTPStatus, apiErr.Message)}
event.Error = apiErr
} else {
currentResult = models.KeyTestResult{Key: apiKeyModel.APIKey, Status: "error", Message: "Validation check failed: " + validationErr.Error()}
event.Error = &CustomErrors.APIError{Message: validationErr.Error()}
}
event.RequestLog.IsSuccess = false
}
eventData, _ := json.Marshal(event)
if err := s.store.Publish(ctx, models.TopicRequestFinished, eventData); err != nil {
s.logger.WithError(err).Errorf("Failed to publish RequestFinishedEvent for validation of key ID %d", apiKeyModel.ID)
}
mu.Lock()
finalResults[j.Index] = currentResult
processedCount++
_ = s.taskService.UpdateProgressByID(ctx, taskID, processedCount)
mu.Unlock()
}
}()
}
for i, k := range keys {
jobs <- job{Index: i, Value: k}
}
close(jobs)
wg.Wait()
s.taskService.EndTaskByID(ctx, taskID, resourceID, gin.H{"results": finalResults}, nil)
}
func (s *KeyValidationService) StartTestKeysByFilterTask(ctx context.Context, groupID uint, statuses []string) (*task.Status, error) {
s.logger.Infof("Starting test task for group %d with status filter: %v", groupID, statuses)
keyValues, err := s.keyRepo.FindKeyValuesByStatus(groupID, statuses)
if err != nil {
return nil, CustomErrors.ParseDBError(err)
}
if len(keyValues) == 0 {
return nil, CustomErrors.NewAPIError(CustomErrors.ErrNotFound, "No keys found matching the filter criteria.")
}
keysAsText := strings.Join(keyValues, "\n")
s.logger.Infof("Found %d keys to validate for group %d.", len(keyValues), groupID)
return s.StartTestKeysTask(ctx, groupID, keysAsText)
}

View File

@@ -1,78 +1,152 @@
// Filename: internal/service/log_service.go
package service
import (
"context"
"fmt"
"gemini-balancer/internal/models"
"strconv"
"github.com/gin-gonic/gin"
"github.com/sirupsen/logrus"
"gorm.io/gorm"
)
type LogService struct {
db *gorm.DB
db *gorm.DB
logger *logrus.Entry
}
func NewLogService(db *gorm.DB) *LogService {
return &LogService{db: db}
func NewLogService(db *gorm.DB, logger *logrus.Logger) *LogService {
return &LogService{
db: db,
logger: logger.WithField("component", "LogService"),
}
}
func (s *LogService) Record(log *models.RequestLog) error {
return s.db.Create(log).Error
func (s *LogService) Record(ctx context.Context, log *models.RequestLog) error {
return s.db.WithContext(ctx).Create(log).Error
}
func (s *LogService) GetLogs(c *gin.Context) ([]models.RequestLog, int64, error) {
// LogQueryParams decouples the service from Gin by passing query parameters as a struct
type LogQueryParams struct {
Page int
PageSize int
ModelName string
IsSuccess *bool // pointer distinguishes "not set" from "false"
StatusCode *int
KeyID *uint64
GroupID *uint64
}
func (s *LogService) GetLogs(ctx context.Context, params LogQueryParams) ([]models.RequestLog, int64, error) {
// Validate the parameters
if params.Page < 1 {
params.Page = 1
}
if params.PageSize < 1 || params.PageSize > 100 {
params.PageSize = 20
}
var logs []models.RequestLog
var total int64
query := s.db.Model(&models.RequestLog{}).Scopes(s.filtersScope(c))
// Build the base query
query := s.db.WithContext(ctx).Model(&models.RequestLog{})
query = s.applyFilters(query, params)
// Count the total
// Count the total
if err := query.Count(&total).Error; err != nil {
return nil, 0, err
return nil, 0, fmt.Errorf("failed to count logs: %w", err)
}
if total == 0 {
return []models.RequestLog{}, 0, nil
}
// Then run the paginated query
page, _ := strconv.Atoi(c.DefaultQuery("page", "1"))
pageSize, _ := strconv.Atoi(c.DefaultQuery("page_size", "20"))
offset := (page - 1) * pageSize
err := query.Order("request_time desc").Limit(pageSize).Offset(offset).Find(&logs).Error
if err != nil {
return nil, 0, err
// Paginated query
offset := (params.Page - 1) * params.PageSize
if err := query.Order("request_time DESC").
Limit(params.PageSize).
Offset(offset).
Find(&logs).Error; err != nil {
return nil, 0, fmt.Errorf("failed to query logs: %w", err)
}
return logs, total, nil
}
func (s *LogService) filtersScope(c *gin.Context) func(db *gorm.DB) *gorm.DB {
return func(db *gorm.DB) *gorm.DB {
if modelName := c.Query("model_name"); modelName != "" {
db = db.Where("model_name = ?", modelName)
}
if isSuccessStr := c.Query("is_success"); isSuccessStr != "" {
if isSuccess, err := strconv.ParseBool(isSuccessStr); err == nil {
db = db.Where("is_success = ?", isSuccess)
}
}
if statusCodeStr := c.Query("status_code"); statusCodeStr != "" {
if statusCode, err := strconv.Atoi(statusCodeStr); err == nil {
db = db.Where("status_code = ?", statusCode)
}
}
if keyIDStr := c.Query("key_id"); keyIDStr != "" {
if keyID, err := strconv.ParseUint(keyIDStr, 10, 64); err == nil {
db = db.Where("key_id = ?", keyID)
}
}
if groupIDStr := c.Query("group_id"); groupIDStr != "" {
if groupID, err := strconv.ParseUint(groupIDStr, 10, 64); err == nil {
db = db.Where("group_id = ?", groupID)
}
}
return db
func (s *LogService) applyFilters(query *gorm.DB, params LogQueryParams) *gorm.DB {
if params.ModelName != "" {
query = query.Where("model_name = ?", params.ModelName)
}
if params.IsSuccess != nil {
query = query.Where("is_success = ?", *params.IsSuccess)
}
if params.StatusCode != nil {
query = query.Where("status_code = ?", *params.StatusCode)
}
if params.KeyID != nil {
query = query.Where("key_id = ?", *params.KeyID)
}
if params.GroupID != nil {
query = query.Where("group_id = ?", *params.GroupID)
}
return query
}
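// Note: the filters are combined with AND; a nil pointer field (or an empty
// ModelName) contributes no WHERE clause at all.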
// ParseLogQueryParams is called at the handler layer to parse Gin query parameters
func ParseLogQueryParams(queryParams map[string]string) (LogQueryParams, error) {
params := LogQueryParams{
Page: 1,
PageSize: 20,
}
if pageStr, ok := queryParams["page"]; ok {
if page, err := strconv.Atoi(pageStr); err == nil && page > 0 {
params.Page = page
}
}
if pageSizeStr, ok := queryParams["page_size"]; ok {
if pageSize, err := strconv.Atoi(pageSizeStr); err == nil && pageSize > 0 {
params.PageSize = pageSize
}
}
if modelName, ok := queryParams["model_name"]; ok {
params.ModelName = modelName
}
if isSuccessStr, ok := queryParams["is_success"]; ok {
if isSuccess, err := strconv.ParseBool(isSuccessStr); err == nil {
params.IsSuccess = &isSuccess
} else {
return params, fmt.Errorf("invalid is_success parameter: %s", isSuccessStr)
}
}
if statusCodeStr, ok := queryParams["status_code"]; ok {
if statusCode, err := strconv.Atoi(statusCodeStr); err == nil {
params.StatusCode = &statusCode
} else {
return params, fmt.Errorf("invalid status_code parameter: %s", statusCodeStr)
}
}
if keyIDStr, ok := queryParams["key_id"]; ok {
if keyID, err := strconv.ParseUint(keyIDStr, 10, 64); err == nil {
params.KeyID = &keyID
} else {
return params, fmt.Errorf("invalid key_id parameter: %s", keyIDStr)
}
}
if groupIDStr, ok := queryParams["group_id"]; ok {
if groupID, err := strconv.ParseUint(groupIDStr, 10, 64); err == nil {
params.GroupID = &groupID
} else {
return params, fmt.Errorf("invalid group_id parameter: %s", groupIDStr)
}
}
return params, nil
}
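// Handler-layer wiring sketch (hypothetical; the handler name and JSON shape
// are assumptions, not part of this commit). It shows the intended flow:
// collect the raw query values, parse them with ParseLogQueryParams, then
// call GetLogs with the request context.
func getLogsHandlerSketch(logSvc *LogService) gin.HandlerFunc {
	return func(c *gin.Context) {
		raw := make(map[string]string)
		for _, k := range []string{"page", "page_size", "model_name", "is_success", "status_code", "key_id", "group_id"} {
			if v := c.Query(k); v != "" {
				raw[k] = v
			}
		}
		params, err := ParseLogQueryParams(raw)
		if err != nil {
			c.JSON(400, gin.H{"error": err.Error()}) // real code would use http.StatusBadRequest
			return
		}
		logs, total, err := logSvc.GetLogs(c.Request.Context(), params)
		if err != nil {
			c.JSON(500, gin.H{"error": "failed to query logs"})
			return
		}
		c.JSON(200, gin.H{"items": logs, "total": total})
	}
}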

View File

@@ -35,34 +35,55 @@ func NewStatsService(db *gorm.DB, s store.Store, repo repository.KeyRepository,
func (s *StatsService) Start() {
s.logger.Info("Starting event listener for stats maintenance.")
sub, err := s.store.Subscribe(context.Background(), models.TopicKeyStatusChanged)
if err != nil {
s.logger.Fatalf("Failed to subscribe to topic %s: %v", models.TopicKeyStatusChanged, err)
return
}
go func() {
defer sub.Close()
for {
select {
case msg := <-sub.Channel():
var event models.KeyStatusChangedEvent
if err := json.Unmarshal(msg.Payload, &event); err != nil {
s.logger.Errorf("Failed to unmarshal KeyStatusChangedEvent: %v", err)
continue
}
s.handleKeyStatusChange(&event)
case <-s.stopChan:
s.logger.Info("Stopping stats event listener.")
return
}
}
}()
go s.listenForEvents()
}
func (s *StatsService) Stop() {
close(s.stopChan)
}
func (s *StatsService) listenForEvents() {
for {
select {
case <-s.stopChan:
s.logger.Info("Stopping stats event listener.")
return
default:
}
ctx, cancel := context.WithCancel(context.Background())
sub, err := s.store.Subscribe(ctx, models.TopicKeyStatusChanged)
if err != nil {
s.logger.Errorf("Failed to subscribe: %v, retrying in 5s", err)
cancel()
time.Sleep(5 * time.Second)
continue
}
s.logger.Info("Subscribed to key status changes")
s.handleSubscription(sub, cancel)
}
}
func (s *StatsService) handleSubscription(sub store.Subscription, cancel context.CancelFunc) {
defer sub.Close()
defer cancel()
for {
select {
case msg := <-sub.Channel():
var event models.KeyStatusChangedEvent
if err := json.Unmarshal(msg.Payload, &event); err != nil {
s.logger.Errorf("Failed to unmarshal event: %v", err)
continue
}
s.handleKeyStatusChange(&event)
case <-s.stopChan:
return
}
}
}
func (s *StatsService) handleKeyStatusChange(event *models.KeyStatusChangedEvent) {
if event.GroupID == 0 {
s.logger.Warnf("Received KeyStatusChangedEvent with no GroupID. Reason: %s, KeyID: %d. Skipping.", event.ChangeReason, event.KeyID)
@@ -75,23 +96,47 @@ func (s *StatsService) handleKeyStatusChange(event *models.KeyStatusChangedEvent
switch event.ChangeReason {
case "key_unlinked", "key_hard_deleted":
if event.OldStatus != "" {
s.store.HIncrBy(ctx, statsKey, "total_keys", -1)
s.store.HIncrBy(ctx, statsKey, fmt.Sprintf("%s_keys", event.OldStatus), -1)
if _, err := s.store.HIncrBy(ctx, statsKey, "total_keys", -1); err != nil {
s.logger.WithError(err).Errorf("Failed to decrement total_keys for group %d", event.GroupID)
s.RecalculateGroupKeyStats(ctx, event.GroupID)
return
}
if _, err := s.store.HIncrBy(ctx, statsKey, fmt.Sprintf("%s_keys", event.OldStatus), -1); err != nil {
s.logger.WithError(err).Errorf("Failed to decrement %s_keys for group %d", event.OldStatus, event.GroupID)
s.RecalculateGroupKeyStats(ctx, event.GroupID)
return
}
} else {
s.logger.Warnf("Received '%s' event for group %d without OldStatus, forcing recalculation.", event.ChangeReason, event.GroupID)
s.RecalculateGroupKeyStats(ctx, event.GroupID)
}
case "key_linked":
if event.NewStatus != "" {
s.store.HIncrBy(ctx, statsKey, "total_keys", 1)
s.store.HIncrBy(ctx, statsKey, fmt.Sprintf("%s_keys", event.NewStatus), 1)
if _, err := s.store.HIncrBy(ctx, statsKey, "total_keys", 1); err != nil {
s.logger.WithError(err).Errorf("Failed to increment total_keys for group %d", event.GroupID)
s.RecalculateGroupKeyStats(ctx, event.GroupID)
return
}
if _, err := s.store.HIncrBy(ctx, statsKey, fmt.Sprintf("%s_keys", event.NewStatus), 1); err != nil {
s.logger.WithError(err).Errorf("Failed to increment %s_keys for group %d", event.NewStatus, event.GroupID)
s.RecalculateGroupKeyStats(ctx, event.GroupID)
return
}
} else {
s.logger.Warnf("Received 'key_linked' event for group %d without NewStatus, forcing recalculation.", event.GroupID)
s.RecalculateGroupKeyStats(ctx, event.GroupID)
}
case "manual_update", "error_threshold_reached", "key_recovered", "invalid_api_key":
s.store.HIncrBy(ctx, statsKey, fmt.Sprintf("%s_keys", event.OldStatus), -1)
s.store.HIncrBy(ctx, statsKey, fmt.Sprintf("%s_keys", event.NewStatus), 1)
if _, err := s.store.HIncrBy(ctx, statsKey, fmt.Sprintf("%s_keys", event.OldStatus), -1); err != nil {
s.logger.WithError(err).Errorf("Failed to decrement %s_keys for group %d", event.OldStatus, event.GroupID)
s.RecalculateGroupKeyStats(ctx, event.GroupID)
return
}
if _, err := s.store.HIncrBy(ctx, statsKey, fmt.Sprintf("%s_keys", event.NewStatus), 1); err != nil {
s.logger.WithError(err).Errorf("Failed to increment %s_keys for group %d", event.NewStatus, event.GroupID)
s.RecalculateGroupKeyStats(ctx, event.GroupID)
return
}
default:
s.logger.Warnf("Unhandled event reason '%s' for group %d, forcing recalculation.", event.ChangeReason, event.GroupID)
s.RecalculateGroupKeyStats(ctx, event.GroupID)
@@ -113,13 +158,16 @@ func (s *StatsService) RecalculateGroupKeyStats(ctx context.Context, groupID uin
}
statsKey := fmt.Sprintf("stats:group:%d", groupID)
updates := make(map[string]interface{})
totalKeys := int64(0)
updates := map[string]interface{}{
"active_keys": int64(0),
"disabled_keys": int64(0),
"error_keys": int64(0),
"total_keys": int64(0),
}
for _, res := range results {
updates[fmt.Sprintf("%s_keys", res.Status)] = res.Count
totalKeys += res.Count
updates["total_keys"] = updates["total_keys"].(int64) + res.Count
}
updates["total_keys"] = totalKeys
if err := s.store.Del(ctx, statsKey); err != nil {
s.logger.WithError(err).Warnf("Failed to delete stale stats key for group %d before recalculation.", groupID)
@@ -180,8 +228,18 @@ func (s *StatsService) AggregateHourlyStats(ctx context.Context) error {
})
}
return s.db.WithContext(ctx).Clauses(clause.OnConflict{
if err := s.db.WithContext(ctx).Clauses(clause.OnConflict{
Columns: []clause.Column{{Name: "time"}, {Name: "group_id"}, {Name: "model_name"}},
DoUpdates: clause.AssignmentColumns([]string{"request_count", "success_count", "prompt_tokens", "completion_tokens"}),
}).Create(&hourlyStats).Error
}).Create(&hourlyStats).Error; err != nil {
return err
}
if err := s.db.WithContext(ctx).
Where("request_time >= ? AND request_time < ?", startTime, endTime).
Delete(&models.RequestLog{}).Error; err != nil {
s.logger.WithError(err).Warn("Failed to delete aggregated request logs")
}
return nil
}
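// Note (a reading of the code above, not a stated guarantee): the upsert and
// the log cleanup are separate statements, so a failed delete only logs a
// warning; because the ON CONFLICT clause overwrites the existing row instead
// of adding to it, re-running the aggregation for the same window stays
// idempotent and does not double-count.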

View File

@@ -37,7 +37,7 @@ func NewTokenManager(repo repository.AuthTokenRepository, store store.Store, log
return tokens, nil
}
s, err := syncer.NewCacheSyncer(tokenLoader, store, TopicTokenChanged)
s, err := syncer.NewCacheSyncer(tokenLoader, store, TopicTokenChanged, logger)
if err != nil {
return nil, fmt.Errorf("failed to create token manager syncer: %w", err)
}