Update Context for store

This commit is contained in:
XOF
2025-11-22 14:20:05 +08:00
parent ac0e0a8275
commit 2b0b9b67dc
31 changed files with 817 additions and 1016 deletions

View File

@@ -75,7 +75,7 @@ func NewHealthCheckService(
func (s *HealthCheckService) Start() {
s.logger.Info("Starting HealthCheckService with independent check loops...")
s.wg.Add(4) // Now four loops
s.wg.Add(4)
go s.runKeyCheckLoop()
go s.runUpstreamCheckLoop()
go s.runProxyCheckLoop()
@@ -102,8 +102,6 @@ func (s *HealthCheckService) GetLastHealthCheckResults() map[string]string {
func (s *HealthCheckService) runKeyCheckLoop() {
defer s.wg.Done()
s.logger.Info("Key check dynamic scheduler loop started.")
// 主调度循环,每分钟检查一次任务
ticker := time.NewTicker(1 * time.Minute)
defer ticker.Stop()
@@ -126,26 +124,22 @@ func (s *HealthCheckService) scheduleKeyChecks() {
defer s.groupCheckTimeMutex.Unlock()
for _, group := range groups {
// 获取特定于组的运营配置
opConfig, err := s.groupManager.BuildOperationalConfig(group)
if err != nil {
s.logger.WithError(err).WithField("group_id", group.ID).Error("Failed to build operational config for group, skipping health check scheduling.")
continue
}
if opConfig.EnableKeyCheck == nil || !*opConfig.EnableKeyCheck {
continue // 跳过禁用了健康检查的组
continue
}
var intervalMinutes int
if opConfig.KeyCheckIntervalMinutes != nil {
intervalMinutes = *opConfig.KeyCheckIntervalMinutes
}
interval := time.Duration(intervalMinutes) * time.Minute
if interval <= 0 {
continue // 跳过无效的检查周期
continue
}
if nextCheckTime, ok := s.groupNextCheckTime[group.ID]; !ok || now.After(nextCheckTime) {
s.logger.Infof("Scheduling key check for group '%s' (ID: %d)", group.Name, group.ID)
go s.performKeyChecksForGroup(group, opConfig)
@@ -160,7 +154,6 @@ func (s *HealthCheckService) runUpstreamCheckLoop() {
if s.SettingsManager.GetSettings().EnableUpstreamCheck {
s.performUpstreamChecks()
}
ticker := time.NewTicker(time.Duration(s.SettingsManager.GetSettings().UpstreamCheckIntervalSeconds) * time.Second)
defer ticker.Stop()
@@ -184,7 +177,6 @@ func (s *HealthCheckService) runProxyCheckLoop() {
if s.SettingsManager.GetSettings().EnableProxyCheck {
s.performProxyChecks()
}
ticker := time.NewTicker(time.Duration(s.SettingsManager.GetSettings().ProxyCheckIntervalSeconds) * time.Second)
defer ticker.Stop()
@@ -203,6 +195,7 @@ func (s *HealthCheckService) runProxyCheckLoop() {
}
func (s *HealthCheckService) performKeyChecksForGroup(group *models.KeyGroup, opConfig *models.KeyGroupSettings) {
ctx := context.Background()
settings := s.SettingsManager.GetSettings()
timeout := time.Duration(settings.KeyCheckTimeoutSeconds) * time.Second
@@ -213,11 +206,9 @@ func (s *HealthCheckService) performKeyChecksForGroup(group *models.KeyGroup, op
}
log := s.logger.WithFields(logrus.Fields{"group_id": group.ID, "group_name": group.Name})
log.Infof("Starting key health check cycle.")
var mappingsToCheck []models.GroupAPIKeyMapping
err = s.db.Model(&models.GroupAPIKeyMapping{}).
err = s.db.WithContext(ctx).Model(&models.GroupAPIKeyMapping{}).
Joins("JOIN api_keys ON api_keys.id = group_api_key_mappings.api_key_id").
Where("group_api_key_mappings.key_group_id = ?", group.ID).
Where("api_keys.master_status = ?", models.MasterStatusActive).
@@ -233,7 +224,6 @@ func (s *HealthCheckService) performKeyChecksForGroup(group *models.KeyGroup, op
log.Info("No key mappings to check for this group.")
return
}
log.Infof("Starting health check for %d key mappings.", len(mappingsToCheck))
jobs := make(chan models.GroupAPIKeyMapping, len(mappingsToCheck))
var wg sync.WaitGroup
@@ -242,14 +232,14 @@ func (s *HealthCheckService) performKeyChecksForGroup(group *models.KeyGroup, op
concurrency = *opConfig.KeyCheckConcurrency
}
if concurrency <= 0 {
concurrency = 1 // 保证至少有一个 worker
concurrency = 1
}
for w := 1; w <= concurrency; w++ {
wg.Add(1)
go func(workerID int) {
defer wg.Done()
for mapping := range jobs {
s.checkAndProcessMapping(&mapping, timeout, endpoint)
s.checkAndProcessMapping(ctx, &mapping, timeout, endpoint)
}
}(w)
}
@@ -261,52 +251,46 @@ func (s *HealthCheckService) performKeyChecksForGroup(group *models.KeyGroup, op
log.Info("Finished key health check cycle.")
}
func (s *HealthCheckService) checkAndProcessMapping(mapping *models.GroupAPIKeyMapping, timeout time.Duration, endpoint string) {
func (s *HealthCheckService) checkAndProcessMapping(ctx context.Context, mapping *models.GroupAPIKeyMapping, timeout time.Duration, endpoint string) {
if mapping.APIKey == nil {
s.logger.Warnf("Skipping check for mapping (G:%d, K:%d) because associated APIKey is nil.", mapping.KeyGroupID, mapping.APIKeyID)
return
}
validationErr := s.keyValidationService.ValidateSingleKey(mapping.APIKey, timeout, endpoint)
// --- 诊断一:验证成功 (健康) ---
if validationErr == nil {
if mapping.Status != models.StatusActive {
s.activateMapping(mapping)
s.activateMapping(ctx, mapping)
}
return
}
errorString := validationErr.Error()
// --- 诊断二:永久性错误 ---
if CustomErrors.IsPermanentUpstreamError(errorString) {
s.revokeMapping(mapping, validationErr)
s.revokeMapping(ctx, mapping, validationErr)
return
}
// --- 诊断三:暂时性错误 ---
if CustomErrors.IsTemporaryUpstreamError(errorString) {
// Log with a higher level (WARN) since this is an actionable, proactive finding.
s.logger.Warnf("Health check for key %d failed with temporary error, applying penalty. Reason: %v", mapping.APIKeyID, validationErr)
s.penalizeMapping(mapping, validationErr)
s.penalizeMapping(ctx, mapping, validationErr)
return
}
// --- 诊断四:其他未知或上游服务错误 ---
s.logger.Errorf("Health check for key %d failed with a transient or unknown upstream error: %v. Mapping will not be penalized by health check.", mapping.APIKeyID, validationErr)
}
func (s *HealthCheckService) activateMapping(mapping *models.GroupAPIKeyMapping) {
func (s *HealthCheckService) activateMapping(ctx context.Context, mapping *models.GroupAPIKeyMapping) {
oldStatus := mapping.Status
mapping.Status = models.StatusActive
mapping.ConsecutiveErrorCount = 0
mapping.LastError = ""
if err := s.keyRepo.UpdateMapping(mapping); err != nil {
if err := s.keyRepo.UpdateMapping(ctx, mapping); err != nil {
s.logger.WithError(err).Errorf("Failed to activate mapping (G:%d, K:%d).", mapping.KeyGroupID, mapping.APIKeyID)
return
}
s.logger.Infof("Mapping (G:%d, K:%d) successfully activated from %s to ACTIVE.", mapping.KeyGroupID, mapping.APIKeyID, oldStatus)
s.publishKeyStatusChangedEvent(mapping.KeyGroupID, mapping.APIKeyID, oldStatus, mapping.Status)
s.publishKeyStatusChangedEvent(ctx, mapping.KeyGroupID, mapping.APIKeyID, oldStatus, mapping.Status)
}
func (s *HealthCheckService) penalizeMapping(mapping *models.GroupAPIKeyMapping, err error) {
// Re-fetch group-specific operational config to get the correct thresholds
func (s *HealthCheckService) penalizeMapping(ctx context.Context, mapping *models.GroupAPIKeyMapping, err error) {
group, ok := s.groupManager.GetGroupByID(mapping.KeyGroupID)
if !ok {
s.logger.Errorf("Could not find group with ID %d to apply penalty.", mapping.KeyGroupID)
@@ -320,7 +304,6 @@ func (s *HealthCheckService) penalizeMapping(mapping *models.GroupAPIKeyMapping,
oldStatus := mapping.Status
mapping.LastError = err.Error()
mapping.ConsecutiveErrorCount++
// Use the group-specific threshold
threshold := *opConfig.KeyBlacklistThreshold
if mapping.ConsecutiveErrorCount >= threshold {
mapping.Status = models.StatusCooldown
@@ -329,44 +312,41 @@ func (s *HealthCheckService) penalizeMapping(mapping *models.GroupAPIKeyMapping,
mapping.CooldownUntil = &cooldownTime
s.logger.Warnf("Mapping (G:%d, K:%d) reached error threshold (%d) during health check and is now in COOLDOWN for %v.", mapping.KeyGroupID, mapping.APIKeyID, threshold, cooldownDuration)
}
if errDb := s.keyRepo.UpdateMapping(mapping); errDb != nil {
if errDb := s.keyRepo.UpdateMapping(ctx, mapping); errDb != nil {
s.logger.WithError(errDb).Errorf("Failed to penalize mapping (G:%d, K:%d).", mapping.KeyGroupID, mapping.APIKeyID)
return
}
if oldStatus != mapping.Status {
s.publishKeyStatusChangedEvent(mapping.KeyGroupID, mapping.APIKeyID, oldStatus, mapping.Status)
s.publishKeyStatusChangedEvent(ctx, mapping.KeyGroupID, mapping.APIKeyID, oldStatus, mapping.Status)
}
}
func (s *HealthCheckService) revokeMapping(mapping *models.GroupAPIKeyMapping, err error) {
func (s *HealthCheckService) revokeMapping(ctx context.Context, mapping *models.GroupAPIKeyMapping, err error) {
oldStatus := mapping.Status
if oldStatus == models.StatusBanned {
return // Already banned, do nothing.
return
}
mapping.Status = models.StatusBanned
mapping.LastError = "Definitive error: " + err.Error()
mapping.ConsecutiveErrorCount = 0 // Reset counter as this is a final state for this group
if errDb := s.keyRepo.UpdateMapping(mapping); errDb != nil {
mapping.ConsecutiveErrorCount = 0
if errDb := s.keyRepo.UpdateMapping(ctx, mapping); errDb != nil {
s.logger.WithError(errDb).Errorf("Failed to revoke mapping (G:%d, K:%d).", mapping.KeyGroupID, mapping.APIKeyID)
return
}
s.logger.Warnf("Mapping (G:%d, K:%d) has been BANNED due to a definitive error: %v", mapping.KeyGroupID, mapping.APIKeyID, err)
s.publishKeyStatusChangedEvent(mapping.KeyGroupID, mapping.APIKeyID, oldStatus, mapping.Status)
s.publishKeyStatusChangedEvent(ctx, mapping.KeyGroupID, mapping.APIKeyID, oldStatus, mapping.Status)
s.logger.Infof("Triggering MasterStatus update for definitively failed key ID %d.", mapping.APIKeyID)
if err := s.keyRepo.UpdateAPIKeyStatus(mapping.APIKeyID, models.MasterStatusRevoked); err != nil {
if err := s.keyRepo.UpdateAPIKeyStatus(ctx, mapping.APIKeyID, models.MasterStatusRevoked); err != nil {
s.logger.WithError(err).Errorf("Failed to update master status for key ID %d after group-level ban.", mapping.APIKeyID)
}
}
func (s *HealthCheckService) performUpstreamChecks() {
ctx := context.Background()
settings := s.SettingsManager.GetSettings()
timeout := time.Duration(settings.UpstreamCheckTimeoutSeconds) * time.Second
var upstreams []*models.UpstreamEndpoint
if err := s.db.Find(&upstreams).Error; err != nil {
if err := s.db.WithContext(ctx).Find(&upstreams).Error; err != nil {
s.logger.WithError(err).Error("Failed to retrieve upstreams.")
return
}
@@ -390,10 +370,10 @@ func (s *HealthCheckService) performUpstreamChecks() {
s.lastResultsMutex.Unlock()
if oldStatus != newStatus {
s.logger.Infof("Upstream '%s' status changed from %s -> %s", upstream.URL, oldStatus, newStatus)
if err := s.db.Model(upstream).Update("status", newStatus).Error; err != nil {
if err := s.db.WithContext(ctx).Model(upstream).Update("status", newStatus).Error; err != nil {
s.logger.WithError(err).WithField("upstream_id", upstream.ID).Error("Failed to update upstream status.")
} else {
s.publishUpstreamHealthChangedEvent(upstream, oldStatus, newStatus)
s.publishUpstreamHealthChangedEvent(ctx, upstream, oldStatus, newStatus)
}
}
}(u)
@@ -412,10 +392,11 @@ func (s *HealthCheckService) checkEndpoint(urlStr string, timeout time.Duration)
}
func (s *HealthCheckService) performProxyChecks() {
ctx := context.Background()
settings := s.SettingsManager.GetSettings()
timeout := time.Duration(settings.ProxyCheckTimeoutSeconds) * time.Second
var proxies []*models.ProxyConfig
if err := s.db.Find(&proxies).Error; err != nil {
if err := s.db.WithContext(ctx).Find(&proxies).Error; err != nil {
s.logger.WithError(err).Error("Failed to retrieve proxies.")
return
}
@@ -438,7 +419,7 @@ func (s *HealthCheckService) performProxyChecks() {
s.lastResultsMutex.Unlock()
if proxyCfg.Status != newStatus {
s.logger.Infof("Proxy '%s' status changed from %s -> %s", proxyCfg.Address, proxyCfg.Status, newStatus)
if err := s.db.Model(proxyCfg).Update("status", newStatus).Error; err != nil {
if err := s.db.WithContext(ctx).Model(proxyCfg).Update("status", newStatus).Error; err != nil {
s.logger.WithError(err).WithField("proxy_id", proxyCfg.ID).Error("Failed to update proxy status.")
}
}
@@ -482,7 +463,7 @@ func (s *HealthCheckService) checkProxy(proxyCfg *models.ProxyConfig, timeout ti
return true
}
func (s *HealthCheckService) publishKeyStatusChangedEvent(groupID, keyID uint, oldStatus, newStatus models.APIKeyStatus) {
func (s *HealthCheckService) publishKeyStatusChangedEvent(ctx context.Context, groupID, keyID uint, oldStatus, newStatus models.APIKeyStatus) {
event := models.KeyStatusChangedEvent{
KeyID: keyID,
GroupID: groupID,
@@ -496,12 +477,12 @@ func (s *HealthCheckService) publishKeyStatusChangedEvent(groupID, keyID uint, o
s.logger.WithError(err).Errorf("Failed to marshal KeyStatusChangedEvent for group %d.", groupID)
return
}
if err := s.store.Publish(models.TopicKeyStatusChanged, payload); err != nil {
if err := s.store.Publish(ctx, models.TopicKeyStatusChanged, payload); err != nil {
s.logger.WithError(err).Errorf("Failed to publish KeyStatusChangedEvent for group %d.", groupID)
}
}
func (s *HealthCheckService) publishUpstreamHealthChangedEvent(upstream *models.UpstreamEndpoint, oldStatus, newStatus string) {
func (s *HealthCheckService) publishUpstreamHealthChangedEvent(ctx context.Context, upstream *models.UpstreamEndpoint, oldStatus, newStatus string) {
event := models.UpstreamHealthChangedEvent{
UpstreamID: upstream.ID,
UpstreamURL: upstream.URL,
@@ -516,28 +497,20 @@ func (s *HealthCheckService) publishUpstreamHealthChangedEvent(upstream *models.
s.logger.WithError(err).Error("Failed to marshal UpstreamHealthChangedEvent.")
return
}
if err := s.store.Publish(models.TopicUpstreamHealthChanged, payload); err != nil {
if err := s.store.Publish(ctx, models.TopicUpstreamHealthChanged, payload); err != nil {
s.logger.WithError(err).Error("Failed to publish UpstreamHealthChangedEvent.")
}
}
// =========================================================================
// Global Base Key Check (New Logic)
// =========================================================================
func (s *HealthCheckService) runBaseKeyCheckLoop() {
defer s.wg.Done()
s.logger.Info("Global base key check loop started.")
settings := s.SettingsManager.GetSettings()
if !settings.EnableBaseKeyCheck {
s.logger.Info("Global base key check is disabled.")
return
}
// Perform an initial check on startup
s.performBaseKeyChecks()
interval := time.Duration(settings.BaseKeyCheckIntervalMinutes) * time.Minute
if interval <= 0 {
s.logger.Warnf("Invalid BaseKeyCheckIntervalMinutes: %d. Disabling base key check loop.", settings.BaseKeyCheckIntervalMinutes)
@@ -558,6 +531,7 @@ func (s *HealthCheckService) runBaseKeyCheckLoop() {
}
func (s *HealthCheckService) performBaseKeyChecks() {
ctx := context.Background()
s.logger.Info("Starting global base key check cycle.")
settings := s.SettingsManager.GetSettings()
timeout := time.Duration(settings.KeyCheckTimeoutSeconds) * time.Second
@@ -576,7 +550,7 @@ func (s *HealthCheckService) performBaseKeyChecks() {
jobs := make(chan *models.APIKey, len(keys))
var wg sync.WaitGroup
if concurrency <= 0 {
concurrency = 5 // Safe default
concurrency = 5
}
for w := 0; w < concurrency; w++ {
wg.Add(1)
@@ -587,10 +561,10 @@ func (s *HealthCheckService) performBaseKeyChecks() {
if err != nil && CustomErrors.IsPermanentUpstreamError(err.Error()) {
oldStatus := key.MasterStatus
s.logger.Warnf("Key ID %d (%s...) failed definitive base check. Setting MasterStatus to REVOKED. Reason: %v", key.ID, key.APIKey[:4], err)
if updateErr := s.keyRepo.UpdateAPIKeyStatus(key.ID, models.MasterStatusRevoked); updateErr != nil {
if updateErr := s.keyRepo.UpdateAPIKeyStatus(ctx, key.ID, models.MasterStatusRevoked); updateErr != nil {
s.logger.WithError(updateErr).Errorf("Failed to update master status for key ID %d.", key.ID)
} else {
s.publishMasterKeyStatusChangedEvent(key.ID, oldStatus, models.MasterStatusRevoked)
s.publishMasterKeyStatusChangedEvent(ctx, key.ID, oldStatus, models.MasterStatusRevoked)
}
}
}
@@ -604,8 +578,7 @@ func (s *HealthCheckService) performBaseKeyChecks() {
s.logger.Info("Global base key check cycle finished.")
}
// 事件发布辅助函数
func (s *HealthCheckService) publishMasterKeyStatusChangedEvent(keyID uint, oldStatus, newStatus models.MasterAPIKeyStatus) {
func (s *HealthCheckService) publishMasterKeyStatusChangedEvent(ctx context.Context, keyID uint, oldStatus, newStatus models.MasterAPIKeyStatus) {
event := models.MasterKeyStatusChangedEvent{
KeyID: keyID,
OldMasterStatus: oldStatus,
@@ -618,7 +591,7 @@ func (s *HealthCheckService) publishMasterKeyStatusChangedEvent(keyID uint, oldS
s.logger.WithError(err).Errorf("Failed to marshal MasterKeyStatusChangedEvent for key %d.", keyID)
return
}
if err := s.store.Publish(models.TopicMasterKeyStatusChanged, payload); err != nil {
if err := s.store.Publish(ctx, models.TopicMasterKeyStatusChanged, payload); err != nil {
s.logger.WithError(err).Errorf("Failed to publish MasterKeyStatusChangedEvent for key %d.", keyID)
}
}