New file: internal/service/healthcheck_service.go (624 lines)
@@ -0,0 +1,624 @@
// Filename: internal/service/healthcheck_service.go (final calibrated version)

package service

import (
	"context"
	"encoding/json"
	"net"
	"net/http"
	"net/url"
	"sync"
	"time"

	"gemini-balancer/internal/channel"
	CustomErrors "gemini-balancer/internal/errors"
	"gemini-balancer/internal/models"
	"gemini-balancer/internal/repository"
	"gemini-balancer/internal/settings"
	"gemini-balancer/internal/store"

	"github.com/sirupsen/logrus"
	"golang.org/x/net/proxy"
	"gorm.io/gorm"
)

const (
	ProxyCheckTargetURL   = "https://www.google.com/generate_204"
	DefaultGeminiEndpoint = "https://generativelanguage.googleapis.com/v1beta/models"
	StatusActive          = "active"
	StatusInactive        = "inactive"
)

type HealthCheckServiceLogger struct{ *logrus.Entry }
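
// HealthCheckService owns the background health-check loops for per-group API
// keys, upstream endpoints, proxies, and global base keys, and publishes
// status-change events through the store.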
type HealthCheckService struct {
	db                   *gorm.DB
	SettingsManager      *settings.SettingsManager
	store                store.Store
	keyRepo              repository.KeyRepository
	groupManager         *GroupManager
	channel              channel.ChannelProxy
	keyValidationService *KeyValidationService
	logger               *logrus.Entry
	stopChan             chan struct{}
	wg                   sync.WaitGroup
	lastResultsMutex     sync.RWMutex
	lastResults          map[string]string
	groupCheckTimeMutex  sync.Mutex
	groupNextCheckTime   map[uint]time.Time
}

func NewHealthCheckService(
	db *gorm.DB,
	ss *settings.SettingsManager,
	s store.Store,
	repo repository.KeyRepository,
	gm *GroupManager,
	ch channel.ChannelProxy,
	kvs *KeyValidationService,
	logger *logrus.Logger,
) *HealthCheckService {
	return &HealthCheckService{
		db:                   db,
		SettingsManager:      ss,
		store:                s,
		keyRepo:              repo,
		groupManager:         gm,
		channel:              ch,
		keyValidationService: kvs,
		logger:               logger.WithField("component", "HealthCheck🩺"),
		stopChan:             make(chan struct{}),
		lastResults:          make(map[string]string),
		groupNextCheckTime:   make(map[uint]time.Time),
	}
}

func (s *HealthCheckService) Start() {
	s.logger.Info("Starting HealthCheckService with independent check loops...")
	s.wg.Add(4) // Now four loops
	go s.runKeyCheckLoop()
	go s.runUpstreamCheckLoop()
	go s.runProxyCheckLoop()
	go s.runBaseKeyCheckLoop()
}

func (s *HealthCheckService) Stop() {
	s.logger.Info("Stopping HealthCheckService...")
	close(s.stopChan)
	s.wg.Wait()
	s.logger.Info("HealthCheckService stopped gracefully.")
}

func (s *HealthCheckService) GetLastHealthCheckResults() map[string]string {
	s.lastResultsMutex.RLock()
	defer s.lastResultsMutex.RUnlock()
	resultsCopy := make(map[string]string, len(s.lastResults))
	for k, v := range s.lastResults {
		resultsCopy[k] = v
	}
	return resultsCopy
}

func (s *HealthCheckService) runKeyCheckLoop() {
	defer s.wg.Done()
	s.logger.Info("Key check dynamic scheduler loop started.")

	// Main scheduler loop: look for due group checks once a minute.
	ticker := time.NewTicker(1 * time.Minute)
	defer ticker.Stop()

	for {
		select {
		case <-ticker.C:
			s.scheduleKeyChecks()
		case <-s.stopChan:
			s.logger.Info("Key check scheduler loop stopped.")
			return
		}
	}
}
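
// scheduleKeyChecks iterates over all key groups, skips groups with key checks
// disabled or an invalid interval, and launches performKeyChecksForGroup for
// every group whose next check time has arrived.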
func (s *HealthCheckService) scheduleKeyChecks() {
	groups := s.groupManager.GetAllGroups()
	now := time.Now()

	s.groupCheckTimeMutex.Lock()
	defer s.groupCheckTimeMutex.Unlock()

	for _, group := range groups {
		// Fetch the group-specific operational config.
		opConfig, err := s.groupManager.BuildOperationalConfig(group)
		if err != nil {
			s.logger.WithError(err).WithField("group_id", group.ID).Error("Failed to build operational config for group, skipping health check scheduling.")
			continue
		}

		if opConfig.EnableKeyCheck == nil || !*opConfig.EnableKeyCheck {
			continue // Skip groups that have key health checks disabled.
		}

		var intervalMinutes int
		if opConfig.KeyCheckIntervalMinutes != nil {
			intervalMinutes = *opConfig.KeyCheckIntervalMinutes
		}
		interval := time.Duration(intervalMinutes) * time.Minute
		if interval <= 0 {
			continue // Skip groups with an invalid check interval.
		}

		if nextCheckTime, ok := s.groupNextCheckTime[group.ID]; !ok || now.After(nextCheckTime) {
			s.logger.Infof("Scheduling key check for group '%s' (ID: %d)", group.Name, group.ID)
			go s.performKeyChecksForGroup(group, opConfig)
			s.groupNextCheckTime[group.ID] = now.Add(interval)
		}
	}
}

func (s *HealthCheckService) runUpstreamCheckLoop() {
	defer s.wg.Done()
	s.logger.Info("Upstream check loop started.")
	if s.SettingsManager.GetSettings().EnableUpstreamCheck {
		s.performUpstreamChecks()
	}

	ticker := time.NewTicker(time.Duration(s.SettingsManager.GetSettings().UpstreamCheckIntervalSeconds) * time.Second)
	defer ticker.Stop()

	for {
		select {
		case <-ticker.C:
			if s.SettingsManager.GetSettings().EnableUpstreamCheck {
				s.logger.Debug("Upstream check ticker fired.")
				s.performUpstreamChecks()
			}
		case <-s.stopChan:
			s.logger.Info("Upstream check loop stopped.")
			return
		}
	}
}

func (s *HealthCheckService) runProxyCheckLoop() {
	defer s.wg.Done()
	s.logger.Info("Proxy check loop started.")
	if s.SettingsManager.GetSettings().EnableProxyCheck {
		s.performProxyChecks()
	}

	ticker := time.NewTicker(time.Duration(s.SettingsManager.GetSettings().ProxyCheckIntervalSeconds) * time.Second)
	defer ticker.Stop()

	for {
		select {
		case <-ticker.C:
			if s.SettingsManager.GetSettings().EnableProxyCheck {
				s.logger.Debug("Proxy check ticker fired.")
				s.performProxyChecks()
			}
		case <-s.stopChan:
			s.logger.Info("Proxy check loop stopped.")
			return
		}
	}
}
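
// performKeyChecksForGroup loads every checkable key mapping for the group and
// validates each one through a bounded worker pool sized by the group's
// KeyCheckConcurrency setting.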
func (s *HealthCheckService) performKeyChecksForGroup(group *models.KeyGroup, opConfig *models.KeyGroupSettings) {
	settings := s.SettingsManager.GetSettings()
	timeout := time.Duration(settings.KeyCheckTimeoutSeconds) * time.Second

	endpoint, err := s.groupManager.BuildKeyCheckEndpoint(group.ID)
	if err != nil {
		s.logger.WithError(err).WithField("group_id", group.ID).Error("Failed to build key check endpoint for group, skipping check cycle.")
		return
	}

	log := s.logger.WithFields(logrus.Fields{"group_id": group.ID, "group_name": group.Name})

	log.Infof("Starting key health check cycle.")

	var mappingsToCheck []models.GroupAPIKeyMapping
	err = s.db.Model(&models.GroupAPIKeyMapping{}).
		Joins("JOIN api_keys ON api_keys.id = group_api_key_mappings.api_key_id").
		Where("group_api_key_mappings.key_group_id = ?", group.ID).
		Where("api_keys.master_status = ?", models.MasterStatusActive).
		Where("group_api_key_mappings.status IN ?", []models.APIKeyStatus{models.StatusActive, models.StatusDisabled, models.StatusCooldown}).
		Preload("APIKey").
		Find(&mappingsToCheck).Error

	if err != nil {
		log.WithError(err).Error("Failed to fetch key mappings for health check.")
		return
	}
	if len(mappingsToCheck) == 0 {
		log.Info("No key mappings to check for this group.")
		return
	}

	log.Infof("Starting health check for %d key mappings.", len(mappingsToCheck))
	jobs := make(chan models.GroupAPIKeyMapping, len(mappingsToCheck))
	var wg sync.WaitGroup
	var concurrency int
	if opConfig.KeyCheckConcurrency != nil {
		concurrency = *opConfig.KeyCheckConcurrency
	}
	if concurrency <= 0 {
		concurrency = 1 // Guarantee at least one worker.
	}
	for w := 1; w <= concurrency; w++ {
		wg.Add(1)
		go func(workerID int) {
			defer wg.Done()
			for mapping := range jobs {
				s.checkAndProcessMapping(&mapping, timeout, endpoint)
			}
		}(w)
	}
	for _, m := range mappingsToCheck {
		jobs <- m
	}
	close(jobs)
	wg.Wait()
	log.Info("Finished key health check cycle.")
}
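
// checkAndProcessMapping validates a single mapping and routes the result to
// activation, revocation, or penalty handling depending on the error class.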
func (s *HealthCheckService) checkAndProcessMapping(mapping *models.GroupAPIKeyMapping, timeout time.Duration, endpoint string) {
	if mapping.APIKey == nil {
		s.logger.Warnf("Skipping check for mapping (G:%d, K:%d) because associated APIKey is nil.", mapping.KeyGroupID, mapping.APIKeyID)
		return
	}
	validationErr := s.keyValidationService.ValidateSingleKey(mapping.APIKey, timeout, endpoint)
	// --- Diagnosis 1: validation succeeded (healthy) ---
	if validationErr == nil {
		if mapping.Status != models.StatusActive {
			s.activateMapping(mapping)
		}
		return
	}
	errorString := validationErr.Error()
	// --- Diagnosis 2: permanent error ---
	if CustomErrors.IsPermanentUpstreamError(errorString) {
		s.revokeMapping(mapping, validationErr)
		return
	}
	// --- Diagnosis 3: temporary error ---
	if CustomErrors.IsTemporaryUpstreamError(errorString) {
		// Log with a higher level (WARN) since this is an actionable, proactive finding.
		s.logger.Warnf("Health check for key %d failed with temporary error, applying penalty. Reason: %v", mapping.APIKeyID, validationErr)
		s.penalizeMapping(mapping, validationErr)
		return
	}
	// --- Diagnosis 4: other unknown or upstream service errors ---
	s.logger.Errorf("Health check for key %d failed with a transient or unknown upstream error: %v. Mapping will not be penalized by health check.", mapping.APIKeyID, validationErr)
}
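
// activateMapping resets a mapping to ACTIVE, clears its error state, and
// publishes a status-change event.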
func (s *HealthCheckService) activateMapping(mapping *models.GroupAPIKeyMapping) {
	oldStatus := mapping.Status
	mapping.Status = models.StatusActive
	mapping.ConsecutiveErrorCount = 0
	mapping.LastError = ""
	if err := s.keyRepo.UpdateMapping(mapping); err != nil {
		s.logger.WithError(err).Errorf("Failed to activate mapping (G:%d, K:%d).", mapping.KeyGroupID, mapping.APIKeyID)
		return
	}
	s.logger.Infof("Mapping (G:%d, K:%d) successfully activated from %s to ACTIVE.", mapping.KeyGroupID, mapping.APIKeyID, oldStatus)
	s.publishKeyStatusChangedEvent(mapping.KeyGroupID, mapping.APIKeyID, oldStatus, mapping.Status)
}
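
// penalizeMapping increments the mapping's consecutive error count and moves
// it into COOLDOWN once the group-specific blacklist threshold is reached.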
func (s *HealthCheckService) penalizeMapping(mapping *models.GroupAPIKeyMapping, err error) {
	// Re-fetch group-specific operational config to get the correct thresholds
	group, ok := s.groupManager.GetGroupByID(mapping.KeyGroupID)
	if !ok {
		s.logger.Errorf("Could not find group with ID %d to apply penalty.", mapping.KeyGroupID)
		return
	}
	opConfig, buildErr := s.groupManager.BuildOperationalConfig(group)
	if buildErr != nil {
		s.logger.WithError(buildErr).Errorf("Failed to build operational config for group %d during penalty.", mapping.KeyGroupID)
		return
	}
	oldStatus := mapping.Status
	mapping.LastError = err.Error()
	mapping.ConsecutiveErrorCount++
	// Use the group-specific threshold
	threshold := *opConfig.KeyBlacklistThreshold
	if mapping.ConsecutiveErrorCount >= threshold {
		mapping.Status = models.StatusCooldown
		cooldownDuration := time.Duration(*opConfig.KeyCooldownMinutes) * time.Minute
		cooldownTime := time.Now().Add(cooldownDuration)
		mapping.CooldownUntil = &cooldownTime
		s.logger.Warnf("Mapping (G:%d, K:%d) reached error threshold (%d) during health check and is now in COOLDOWN for %v.", mapping.KeyGroupID, mapping.APIKeyID, threshold, cooldownDuration)
	}
	if errDb := s.keyRepo.UpdateMapping(mapping); errDb != nil {
		s.logger.WithError(errDb).Errorf("Failed to penalize mapping (G:%d, K:%d).", mapping.KeyGroupID, mapping.APIKeyID)
		return
	}
	if oldStatus != mapping.Status {
		s.publishKeyStatusChangedEvent(mapping.KeyGroupID, mapping.APIKeyID, oldStatus, mapping.Status)
	}
}
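
// revokeMapping bans a mapping after a definitive upstream error and then
// revokes the underlying key's master status.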
func (s *HealthCheckService) revokeMapping(mapping *models.GroupAPIKeyMapping, err error) {
	oldStatus := mapping.Status
	if oldStatus == models.StatusBanned {
		return // Already banned, do nothing.
	}

	mapping.Status = models.StatusBanned
	mapping.LastError = "Definitive error: " + err.Error()
	mapping.ConsecutiveErrorCount = 0 // Reset counter as this is a final state for this group

	if errDb := s.keyRepo.UpdateMapping(mapping); errDb != nil {
		s.logger.WithError(errDb).Errorf("Failed to revoke mapping (G:%d, K:%d).", mapping.KeyGroupID, mapping.APIKeyID)
		return
	}

	s.logger.Warnf("Mapping (G:%d, K:%d) has been BANNED due to a definitive error: %v", mapping.KeyGroupID, mapping.APIKeyID, err)
	s.publishKeyStatusChangedEvent(mapping.KeyGroupID, mapping.APIKeyID, oldStatus, mapping.Status)

	s.logger.Infof("Triggering MasterStatus update for definitively failed key ID %d.", mapping.APIKeyID)
	if err := s.keyRepo.UpdateAPIKeyStatus(mapping.APIKeyID, models.MasterStatusRevoked); err != nil {
		s.logger.WithError(err).Errorf("Failed to update master status for key ID %d after group-level ban.", mapping.APIKeyID)
	}
}
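
// performUpstreamChecks probes every configured upstream endpoint concurrently
// and persists any status transitions.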
func (s *HealthCheckService) performUpstreamChecks() {
	settings := s.SettingsManager.GetSettings()
	timeout := time.Duration(settings.UpstreamCheckTimeoutSeconds) * time.Second
	var upstreams []*models.UpstreamEndpoint
	if err := s.db.Find(&upstreams).Error; err != nil {
		s.logger.WithError(err).Error("Failed to retrieve upstreams.")
		return
	}
	if len(upstreams) == 0 {
		return
	}
	s.logger.Infof("Starting validation for %d upstreams.", len(upstreams))
	var wg sync.WaitGroup
	for _, u := range upstreams {
		wg.Add(1)
		go func(upstream *models.UpstreamEndpoint) {
			defer wg.Done()
			oldStatus := upstream.Status
			isAlive := s.checkEndpoint(upstream.URL, timeout)
			newStatus := StatusInactive
			if isAlive {
				newStatus = StatusActive
			}
			s.lastResultsMutex.Lock()
			s.lastResults[upstream.URL] = newStatus
			s.lastResultsMutex.Unlock()
			if oldStatus != newStatus {
				s.logger.Infof("Upstream '%s' status changed from %s -> %s", upstream.URL, oldStatus, newStatus)
				if err := s.db.Model(upstream).Update("status", newStatus).Error; err != nil {
					s.logger.WithError(err).WithField("upstream_id", upstream.ID).Error("Failed to update upstream status.")
				} else {
					s.publishUpstreamHealthChangedEvent(upstream, oldStatus, newStatus)
				}
			}
		}(u)
	}
	wg.Wait()
}
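
// checkEndpoint issues a HEAD request and treats any response below HTTP 500
// as healthy.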
func (s *HealthCheckService) checkEndpoint(urlStr string, timeout time.Duration) bool {
	client := http.Client{Timeout: timeout}
	resp, err := client.Head(urlStr)
	if err != nil {
		return false
	}
	defer resp.Body.Close()
	return resp.StatusCode < http.StatusInternalServerError
}
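
// performProxyChecks validates every configured proxy concurrently and
// persists any status transitions.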
func (s *HealthCheckService) performProxyChecks() {
	settings := s.SettingsManager.GetSettings()
	timeout := time.Duration(settings.ProxyCheckTimeoutSeconds) * time.Second
	var proxies []*models.ProxyConfig
	if err := s.db.Find(&proxies).Error; err != nil {
		s.logger.WithError(err).Error("Failed to retrieve proxies.")
		return
	}
	if len(proxies) == 0 {
		return
	}
	s.logger.Infof("Starting validation for %d proxies.", len(proxies))
	var wg sync.WaitGroup
	for _, p := range proxies {
		wg.Add(1)
		go func(proxyCfg *models.ProxyConfig) {
			defer wg.Done()
			isAlive := s.checkProxy(proxyCfg, timeout)
			newStatus := StatusInactive
			if isAlive {
				newStatus = StatusActive
			}
			s.lastResultsMutex.Lock()
			s.lastResults[proxyCfg.Address] = newStatus
			s.lastResultsMutex.Unlock()
			if proxyCfg.Status != newStatus {
				s.logger.Infof("Proxy '%s' status changed from %s -> %s", proxyCfg.Address, proxyCfg.Status, newStatus)
				if err := s.db.Model(proxyCfg).Update("status", newStatus).Error; err != nil {
					s.logger.WithError(err).WithField("proxy_id", proxyCfg.ID).Error("Failed to update proxy status.")
				}
			}
		}(p)
	}
	wg.Wait()
}
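
// checkProxy builds an HTTP client routed through the proxy (HTTP, HTTPS, or
// SOCKS5) and verifies that it can reach the check target URL.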
func (s *HealthCheckService) checkProxy(proxyCfg *models.ProxyConfig, timeout time.Duration) bool {
	transport := &http.Transport{}
	switch proxyCfg.Protocol {
	case "http", "https":
		proxyUrl, err := url.Parse(proxyCfg.Protocol + "://" + proxyCfg.Address)
		if err != nil {
			s.logger.WithError(err).WithField("proxy_address", proxyCfg.Address).Warn("Invalid proxy URL format.")
			return false
		}
		transport.Proxy = http.ProxyURL(proxyUrl)
	case "socks5":
		dialer, err := proxy.SOCKS5("tcp", proxyCfg.Address, nil, proxy.Direct)
		if err != nil {
			s.logger.WithError(err).WithField("proxy_address", proxyCfg.Address).Warn("Failed to create SOCKS5 dialer.")
			return false
		}
		transport.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) {
			return dialer.Dial(network, addr)
		}
	default:
		s.logger.WithField("protocol", proxyCfg.Protocol).Warn("Unsupported proxy protocol.")
		return false
	}
	client := &http.Client{
		Transport: transport,
		Timeout:   timeout,
	}
	resp, err := client.Get(ProxyCheckTargetURL)
	if err != nil {
		return false
	}
	defer resp.Body.Close()
	return true
}
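
// publishKeyStatusChangedEvent emits a KeyStatusChangedEvent on the store's
// key-status topic.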
func (s *HealthCheckService) publishKeyStatusChangedEvent(groupID, keyID uint, oldStatus, newStatus models.APIKeyStatus) {
	event := models.KeyStatusChangedEvent{
		KeyID:        keyID,
		GroupID:      groupID,
		OldStatus:    oldStatus,
		NewStatus:    newStatus,
		ChangeReason: "health_check",
		ChangedAt:    time.Now(),
	}
	payload, err := json.Marshal(event)
	if err != nil {
		s.logger.WithError(err).Errorf("Failed to marshal KeyStatusChangedEvent for group %d.", groupID)
		return
	}
	if err := s.store.Publish(models.TopicKeyStatusChanged, payload); err != nil {
		s.logger.WithError(err).Errorf("Failed to publish KeyStatusChangedEvent for group %d.", groupID)
	}
}

func (s *HealthCheckService) publishUpstreamHealthChangedEvent(upstream *models.UpstreamEndpoint, oldStatus, newStatus string) {
	event := models.UpstreamHealthChangedEvent{
		UpstreamID:  upstream.ID,
		UpstreamURL: upstream.URL,
		OldStatus:   oldStatus,
		NewStatus:   newStatus,
		Latency:     0,
		Reason:      "health_check",
		CheckedAt:   time.Now(),
	}
	payload, err := json.Marshal(event)
	if err != nil {
		s.logger.WithError(err).Error("Failed to marshal UpstreamHealthChangedEvent.")
		return
	}
	if err := s.store.Publish(models.TopicUpstreamHealthChanged, payload); err != nil {
		s.logger.WithError(err).Error("Failed to publish UpstreamHealthChangedEvent.")
	}
}

// =========================================================================
// Global Base Key Check (New Logic)
// =========================================================================

func (s *HealthCheckService) runBaseKeyCheckLoop() {
	defer s.wg.Done()
	s.logger.Info("Global base key check loop started.")
	settings := s.SettingsManager.GetSettings()

	if !settings.EnableBaseKeyCheck {
		s.logger.Info("Global base key check is disabled.")
		return
	}

	// Perform an initial check on startup
	s.performBaseKeyChecks()

	interval := time.Duration(settings.BaseKeyCheckIntervalMinutes) * time.Minute
	if interval <= 0 {
		s.logger.Warnf("Invalid BaseKeyCheckIntervalMinutes: %d. Disabling base key check loop.", settings.BaseKeyCheckIntervalMinutes)
		return
	}
	ticker := time.NewTicker(interval)
	defer ticker.Stop()

	for {
		select {
		case <-ticker.C:
			s.performBaseKeyChecks()
		case <-s.stopChan:
			s.logger.Info("Global base key check loop stopped.")
			return
		}
	}
}
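
// performBaseKeyChecks validates every active master key against the global
// base-check endpoint and revokes keys that fail with a permanent error.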
func (s *HealthCheckService) performBaseKeyChecks() {
	s.logger.Info("Starting global base key check cycle.")
	settings := s.SettingsManager.GetSettings()
	timeout := time.Duration(settings.KeyCheckTimeoutSeconds) * time.Second
	endpoint := settings.BaseKeyCheckEndpoint
	concurrency := settings.BaseKeyCheckConcurrency
	keys, err := s.keyRepo.GetActiveMasterKeys()
	if err != nil {
		s.logger.WithError(err).Error("Failed to fetch active master keys for base check.")
		return
	}
	if len(keys) == 0 {
		s.logger.Info("No active master keys to perform base check on.")
		return
	}
	s.logger.Infof("Performing base check on %d active master keys.", len(keys))
	jobs := make(chan *models.APIKey, len(keys))
	var wg sync.WaitGroup
	if concurrency <= 0 {
		concurrency = 5 // Safe default
	}
	for w := 0; w < concurrency; w++ {
		wg.Add(1)
		go func() {
			defer wg.Done()
			for key := range jobs {
				err := s.keyValidationService.ValidateSingleKey(key, timeout, endpoint)
				if err != nil && CustomErrors.IsPermanentUpstreamError(err.Error()) {
					oldStatus := key.MasterStatus
					s.logger.Warnf("Key ID %d (%s...) failed definitive base check. Setting MasterStatus to REVOKED. Reason: %v", key.ID, key.APIKey[:4], err)
					if updateErr := s.keyRepo.UpdateAPIKeyStatus(key.ID, models.MasterStatusRevoked); updateErr != nil {
						s.logger.WithError(updateErr).Errorf("Failed to update master status for key ID %d.", key.ID)
					} else {
						s.publishMasterKeyStatusChangedEvent(key.ID, oldStatus, models.MasterStatusRevoked)
					}
				}
			}
		}()
	}
	for _, key := range keys {
		jobs <- key
	}
	close(jobs)
	wg.Wait()
	s.logger.Info("Global base key check cycle finished.")
}

// Event publishing helper.
func (s *HealthCheckService) publishMasterKeyStatusChangedEvent(keyID uint, oldStatus, newStatus models.MasterAPIKeyStatus) {
	event := models.MasterKeyStatusChangedEvent{
		KeyID:           keyID,
		OldMasterStatus: oldStatus,
		NewMasterStatus: newStatus,
		ChangeReason:    "base_health_check",
		ChangedAt:       time.Now(),
	}
	payload, err := json.Marshal(event)
	if err != nil {
		s.logger.WithError(err).Errorf("Failed to marshal MasterKeyStatusChangedEvent for key %d.", keyID)
		return
	}
	if err := s.store.Publish(models.TopicMasterKeyStatusChanged, payload); err != nil {
		s.logger.WithError(err).Errorf("Failed to publish MasterKeyStatusChangedEvent for key %d.", keyID)
	}
}