// Filename: internal/service/healthcheck_service.go (final calibrated version)
package service

import (
	"context"
	"encoding/json"

	"gemini-balancer/internal/channel"
	CustomErrors "gemini-balancer/internal/errors"
	"gemini-balancer/internal/models"
	"gemini-balancer/internal/repository"
	"gemini-balancer/internal/settings"
	"gemini-balancer/internal/store"

	"net"
	"net/http"
	"net/url"
	"sync"
	"time"

	"github.com/sirupsen/logrus"
	"golang.org/x/net/proxy"
	"gorm.io/gorm"
)

const (
	ProxyCheckTargetURL   = "https://www.google.com/generate_204"
	DefaultGeminiEndpoint = "https://generativelanguage.googleapis.com/v1beta/models"
	StatusActive          = "active"
	StatusInactive        = "inactive"
)

type HealthCheckServiceLogger struct{ *logrus.Entry }

// HealthCheckService periodically validates group key mappings, upstream endpoints,
// proxies, and master API keys, and publishes status-change events through the store.
type HealthCheckService struct {
	db                   *gorm.DB
	SettingsManager      *settings.SettingsManager
	store                store.Store
	keyRepo              repository.KeyRepository
	groupManager         *GroupManager
	channel              channel.ChannelProxy
	keyValidationService *KeyValidationService
	logger               *logrus.Entry
	stopChan             chan struct{}
	wg                   sync.WaitGroup

	lastResultsMutex sync.RWMutex
	lastResults      map[string]string

	groupCheckTimeMutex sync.Mutex
	groupNextCheckTime  map[uint]time.Time
}

// NewHealthCheckService wires the service with its dependencies; call Start to launch the check loops.
func NewHealthCheckService(
	db *gorm.DB,
	ss *settings.SettingsManager,
	s store.Store,
	repo repository.KeyRepository,
	gm *GroupManager,
	ch channel.ChannelProxy,
	kvs *KeyValidationService,
	logger *logrus.Logger,
) *HealthCheckService {
	return &HealthCheckService{
		db:                   db,
		SettingsManager:      ss,
		store:                s,
		keyRepo:              repo,
		groupManager:         gm,
		channel:              ch,
		keyValidationService: kvs,
		logger:               logger.WithField("component", "HealthCheck🩺"),
		stopChan:             make(chan struct{}),
		lastResults:          make(map[string]string),
		groupNextCheckTime:   make(map[uint]time.Time),
	}
}

// Start launches the four independent check loops in background goroutines.
func (s *HealthCheckService) Start() {
	s.logger.Info("Starting HealthCheckService with independent check loops...")
	s.wg.Add(4)
	go s.runKeyCheckLoop()
	go s.runUpstreamCheckLoop()
	go s.runProxyCheckLoop()
	go s.runBaseKeyCheckLoop()
}

// Stop signals all loops to exit and blocks until they have finished.
func (s *HealthCheckService) Stop() {
	s.logger.Info("Stopping HealthCheckService...")
	close(s.stopChan)
	s.wg.Wait()
	s.logger.Info("HealthCheckService stopped gracefully.")
}

// GetLastHealthCheckResults returns a copy of the most recent upstream/proxy results,
// keyed by upstream URL or proxy address.
func (s *HealthCheckService) GetLastHealthCheckResults() map[string]string {
	s.lastResultsMutex.RLock()
	defer s.lastResultsMutex.RUnlock()
	resultsCopy := make(map[string]string, len(s.lastResults))
	for k, v := range s.lastResults {
		resultsCopy[k] = v
	}
	return resultsCopy
}

// runKeyCheckLoop wakes up every minute and lets the dynamic scheduler decide which groups are due.
func (s *HealthCheckService) runKeyCheckLoop() {
	defer s.wg.Done()
	s.logger.Info("Key check dynamic scheduler loop started.")
	ticker := time.NewTicker(1 * time.Minute)
	defer ticker.Stop()
	for {
		select {
		case <-ticker.C:
			s.scheduleKeyChecks()
		case <-s.stopChan:
			s.logger.Info("Key check scheduler loop stopped.")
			return
		}
	}
}

// scheduleKeyChecks launches a key check cycle for every group whose per-group interval has elapsed.
func (s *HealthCheckService) scheduleKeyChecks() {
	groups := s.groupManager.GetAllGroups()
	now := time.Now()
	s.groupCheckTimeMutex.Lock()
	defer s.groupCheckTimeMutex.Unlock()
	for _, group := range groups {
		opConfig, err := s.groupManager.BuildOperationalConfig(group)
		if err != nil {
			s.logger.WithError(err).WithField("group_id", group.ID).Error("Failed to build operational config for group, skipping health check scheduling.")
			continue
		}
		if opConfig.EnableKeyCheck == nil || !*opConfig.EnableKeyCheck {
			continue
		}
		var intervalMinutes int
		if opConfig.KeyCheckIntervalMinutes != nil {
			intervalMinutes = *opConfig.KeyCheckIntervalMinutes
		}
		interval := time.Duration(intervalMinutes) * time.Minute
		if interval <= 0 {
			continue
		}
		if nextCheckTime, ok := s.groupNextCheckTime[group.ID]; !ok || now.After(nextCheckTime) {
			s.logger.Infof("Scheduling key check for group '%s' (ID: %d)", group.Name, group.ID)
			go s.performKeyChecksForGroup(group, opConfig)
			s.groupNextCheckTime[group.ID] = now.Add(interval)
		}
	}
}

// runUpstreamCheckLoop periodically validates upstream endpoints when the feature is enabled.
func (s *HealthCheckService) runUpstreamCheckLoop() {
	defer s.wg.Done()
	s.logger.Info("Upstream check loop started.")
	if s.SettingsManager.GetSettings().EnableUpstreamCheck {
		s.performUpstreamChecks()
	}
	ticker := time.NewTicker(time.Duration(s.SettingsManager.GetSettings().UpstreamCheckIntervalSeconds) * time.Second)
	defer ticker.Stop()
	for {
		select {
		case <-ticker.C:
			if s.SettingsManager.GetSettings().EnableUpstreamCheck {
				s.logger.Debug("Upstream check ticker fired.")
				s.performUpstreamChecks()
			}
		case <-s.stopChan:
			s.logger.Info("Upstream check loop stopped.")
			return
		}
	}
}

// runProxyCheckLoop periodically validates configured proxies when the feature is enabled.
func (s *HealthCheckService) runProxyCheckLoop() {
	defer s.wg.Done()
	s.logger.Info("Proxy check loop started.")
	if s.SettingsManager.GetSettings().EnableProxyCheck {
		s.performProxyChecks()
	}
	ticker := time.NewTicker(time.Duration(s.SettingsManager.GetSettings().ProxyCheckIntervalSeconds) * time.Second)
	defer ticker.Stop()
	for {
		select {
		case <-ticker.C:
			if s.SettingsManager.GetSettings().EnableProxyCheck {
				s.logger.Debug("Proxy check ticker fired.")
				s.performProxyChecks()
			}
		case <-s.stopChan:
			s.logger.Info("Proxy check loop stopped.")
			return
		}
	}
}

// performKeyChecksForGroup validates every checkable key mapping in a group using a bounded worker pool.
func (s *HealthCheckService) performKeyChecksForGroup(group *models.KeyGroup, opConfig *models.KeyGroupSettings) {
	ctx := context.Background()
	settings := s.SettingsManager.GetSettings()
	timeout := time.Duration(settings.KeyCheckTimeoutSeconds) * time.Second
	endpoint, err := s.groupManager.BuildKeyCheckEndpoint(group.ID)
	if err != nil {
		s.logger.WithError(err).WithField("group_id", group.ID).Error("Failed to build key check endpoint for group, skipping check cycle.")
		return
	}
	log := s.logger.WithFields(logrus.Fields{"group_id": group.ID, "group_name": group.Name})
	log.Info("Starting key health check cycle.")

	var mappingsToCheck []models.GroupAPIKeyMapping
	err = s.db.WithContext(ctx).Model(&models.GroupAPIKeyMapping{}).
		Joins("JOIN api_keys ON api_keys.id = group_api_key_mappings.api_key_id").
		Where("group_api_key_mappings.key_group_id = ?", group.ID).
		Where("api_keys.master_status = ?", models.MasterStatusActive).
		Where("group_api_key_mappings.status IN ?", []models.APIKeyStatus{models.StatusActive, models.StatusDisabled, models.StatusCooldown}).
		Preload("APIKey").
		Find(&mappingsToCheck).Error
	if err != nil {
		log.WithError(err).Error("Failed to fetch key mappings for health check.")
		return
	}
	if len(mappingsToCheck) == 0 {
		log.Info("No key mappings to check for this group.")
		return
	}
	log.Infof("Starting health check for %d key mappings.", len(mappingsToCheck))

	jobs := make(chan models.GroupAPIKeyMapping, len(mappingsToCheck))
	var wg sync.WaitGroup
	var concurrency int
	if opConfig.KeyCheckConcurrency != nil {
		concurrency = *opConfig.KeyCheckConcurrency
	}
	if concurrency <= 0 {
		concurrency = 1
	}
	for w := 1; w <= concurrency; w++ {
		wg.Add(1)
		go func(workerID int) {
			defer wg.Done()
			for mapping := range jobs {
				s.checkAndProcessMapping(ctx, &mapping, timeout, endpoint)
			}
		}(w)
	}
	for _, m := range mappingsToCheck {
		jobs <- m
	}
	close(jobs)
	wg.Wait()
	log.Info("Finished key health check cycle.")
}

// checkAndProcessMapping validates one mapping and activates, penalizes, or revokes it based on the error class.
func (s *HealthCheckService) checkAndProcessMapping(ctx context.Context, mapping *models.GroupAPIKeyMapping, timeout time.Duration, endpoint string) {
	if mapping.APIKey == nil {
		s.logger.Warnf("Skipping check for mapping (G:%d, K:%d) because associated APIKey is nil.", mapping.KeyGroupID, mapping.APIKeyID)
		return
	}
	validationErr := s.keyValidationService.ValidateSingleKey(mapping.APIKey, timeout, endpoint)
	if validationErr == nil {
		if mapping.Status != models.StatusActive {
			s.activateMapping(ctx, mapping)
		}
		return
	}
	errorString := validationErr.Error()
	if CustomErrors.IsPermanentUpstreamError(errorString) {
		s.revokeMapping(ctx, mapping, validationErr)
		return
	}
	if CustomErrors.IsTemporaryUpstreamError(errorString) {
		s.logger.Warnf("Health check for key %d failed with temporary error, applying penalty. Reason: %v", mapping.APIKeyID, validationErr)
		s.penalizeMapping(ctx, mapping, validationErr)
		return
	}
	s.logger.Errorf("Health check for key %d failed with a transient or unknown upstream error: %v. Mapping will not be penalized by health check.", mapping.APIKeyID, validationErr)
}

// activateMapping resets error state and returns the mapping to ACTIVE, publishing a status-change event.
func (s *HealthCheckService) activateMapping(ctx context.Context, mapping *models.GroupAPIKeyMapping) {
	oldStatus := mapping.Status
	mapping.Status = models.StatusActive
	mapping.ConsecutiveErrorCount = 0
	mapping.LastError = ""
	if err := s.keyRepo.UpdateMapping(ctx, mapping); err != nil {
		s.logger.WithError(err).Errorf("Failed to activate mapping (G:%d, K:%d).", mapping.KeyGroupID, mapping.APIKeyID)
		return
	}
	s.logger.Infof("Mapping (G:%d, K:%d) successfully activated from %s to ACTIVE.", mapping.KeyGroupID, mapping.APIKeyID, oldStatus)
	s.publishKeyStatusChangedEvent(ctx, mapping.KeyGroupID, mapping.APIKeyID, oldStatus, mapping.Status)
}

// penalizeMapping increments the consecutive-error count and moves the mapping into COOLDOWN
// once the group's blacklist threshold is reached.
func (s *HealthCheckService) penalizeMapping(ctx context.Context, mapping *models.GroupAPIKeyMapping, err error) {
	group, ok := s.groupManager.GetGroupByID(mapping.KeyGroupID)
	if !ok {
		s.logger.Errorf("Could not find group with ID %d to apply penalty.", mapping.KeyGroupID)
		return
	}
	opConfig, buildErr := s.groupManager.BuildOperationalConfig(group)
	if buildErr != nil {
		s.logger.WithError(buildErr).Errorf("Failed to build operational config for group %d during penalty.", mapping.KeyGroupID)
		return
	}
	oldStatus := mapping.Status
	mapping.LastError = err.Error()
	mapping.ConsecutiveErrorCount++
	threshold := *opConfig.KeyBlacklistThreshold
	if mapping.ConsecutiveErrorCount >= threshold {
		mapping.Status = models.StatusCooldown
		cooldownDuration := time.Duration(*opConfig.KeyCooldownMinutes) * time.Minute
		cooldownTime := time.Now().Add(cooldownDuration)
		mapping.CooldownUntil = &cooldownTime
		s.logger.Warnf("Mapping (G:%d, K:%d) reached error threshold (%d) during health check and is now in COOLDOWN for %v.", mapping.KeyGroupID, mapping.APIKeyID, threshold, cooldownDuration)
	}
	if errDb := s.keyRepo.UpdateMapping(ctx, mapping); errDb != nil {
		s.logger.WithError(errDb).Errorf("Failed to penalize mapping (G:%d, K:%d).", mapping.KeyGroupID, mapping.APIKeyID)
		return
	}
	if oldStatus != mapping.Status {
		s.publishKeyStatusChangedEvent(ctx, mapping.KeyGroupID, mapping.APIKeyID, oldStatus, mapping.Status)
	}
}

// revokeMapping bans a mapping after a definitive (permanent) upstream error and revokes the key's master status.
func (s *HealthCheckService) revokeMapping(ctx context.Context, mapping *models.GroupAPIKeyMapping, err error) {
	oldStatus := mapping.Status
	if oldStatus == models.StatusBanned {
		return
	}
	mapping.Status = models.StatusBanned
	mapping.LastError = "Definitive error: " + err.Error()
	mapping.ConsecutiveErrorCount = 0
	if errDb := s.keyRepo.UpdateMapping(ctx, mapping); errDb != nil {
		s.logger.WithError(errDb).Errorf("Failed to revoke mapping (G:%d, K:%d).", mapping.KeyGroupID, mapping.APIKeyID)
		return
	}
	s.logger.Warnf("Mapping (G:%d, K:%d) has been BANNED due to a definitive error: %v", mapping.KeyGroupID, mapping.APIKeyID, err)
	s.publishKeyStatusChangedEvent(ctx, mapping.KeyGroupID, mapping.APIKeyID, oldStatus, mapping.Status)

	s.logger.Infof("Triggering MasterStatus update for definitively failed key ID %d.", mapping.APIKeyID)
	if err := s.keyRepo.UpdateAPIKeyStatus(ctx, mapping.APIKeyID, models.MasterStatusRevoked); err != nil {
		s.logger.WithError(err).Errorf("Failed to update master status for key ID %d after group-level ban.", mapping.APIKeyID)
	}
}

// performUpstreamChecks probes every upstream endpoint concurrently and persists status transitions.
func (s *HealthCheckService) performUpstreamChecks() {
	ctx := context.Background()
	settings := s.SettingsManager.GetSettings()
	timeout := time.Duration(settings.UpstreamCheckTimeoutSeconds) * time.Second
	var upstreams []*models.UpstreamEndpoint
	if err := s.db.WithContext(ctx).Find(&upstreams).Error; err != nil {
		s.logger.WithError(err).Error("Failed to retrieve upstreams.")
		return
	}
	if len(upstreams) == 0 {
		return
	}
	s.logger.Infof("Starting validation for %d upstreams.", len(upstreams))
	var wg sync.WaitGroup
	for _, u := range upstreams {
		wg.Add(1)
		go func(upstream *models.UpstreamEndpoint) {
			defer wg.Done()
			oldStatus := upstream.Status
			isAlive := s.checkEndpoint(upstream.URL, timeout)
			newStatus := StatusInactive
			if isAlive {
				newStatus = StatusActive
			}
			s.lastResultsMutex.Lock()
			s.lastResults[upstream.URL] = newStatus
			s.lastResultsMutex.Unlock()
			if oldStatus != newStatus {
				s.logger.Infof("Upstream '%s' status changed from %s -> %s", upstream.URL, oldStatus, newStatus)
				if err := s.db.WithContext(ctx).Model(upstream).Update("status", newStatus).Error; err != nil {
					s.logger.WithError(err).WithField("upstream_id", upstream.ID).Error("Failed to update upstream status.")
				} else {
					s.publishUpstreamHealthChangedEvent(ctx, upstream, oldStatus, newStatus)
				}
			}
		}(u)
	}
	wg.Wait()
}

// checkEndpoint issues a HEAD request and treats any non-5xx response as alive.
func (s *HealthCheckService) checkEndpoint(urlStr string, timeout time.Duration) bool {
	client := http.Client{Timeout: timeout}
	resp, err := client.Head(urlStr)
	if err != nil {
		return false
	}
	defer resp.Body.Close()
	return resp.StatusCode < http.StatusInternalServerError
}

// performProxyChecks probes every configured proxy concurrently and persists status transitions.
func (s *HealthCheckService) performProxyChecks() {
	ctx := context.Background()
	settings := s.SettingsManager.GetSettings()
	timeout := time.Duration(settings.ProxyCheckTimeoutSeconds) * time.Second
	var proxies []*models.ProxyConfig
	if err := s.db.WithContext(ctx).Find(&proxies).Error; err != nil {
		s.logger.WithError(err).Error("Failed to retrieve proxies.")
		return
	}
	if len(proxies) == 0 {
		return
	}
	s.logger.Infof("Starting validation for %d proxies.", len(proxies))
	var wg sync.WaitGroup
	for _, p := range proxies {
		wg.Add(1)
		go func(proxyCfg *models.ProxyConfig) {
			defer wg.Done()
			isAlive := s.checkProxy(proxyCfg, timeout)
			newStatus := StatusInactive
			if isAlive {
				newStatus = StatusActive
			}
			s.lastResultsMutex.Lock()
			s.lastResults[proxyCfg.Address] = newStatus
			s.lastResultsMutex.Unlock()
			if proxyCfg.Status != newStatus {
				s.logger.Infof("Proxy '%s' status changed from %s -> %s", proxyCfg.Address, proxyCfg.Status, newStatus)
				if err := s.db.WithContext(ctx).Model(proxyCfg).Update("status", newStatus).Error; err != nil {
					s.logger.WithError(err).WithField("proxy_id", proxyCfg.ID).Error("Failed to update proxy status.")
				}
			}
		}(p)
	}
	wg.Wait()
}

// checkProxy builds an HTTP client routed through the given proxy (HTTP/HTTPS or SOCKS5)
// and reports whether the probe URL is reachable through it.
func (s *HealthCheckService) checkProxy(proxyCfg *models.ProxyConfig, timeout time.Duration) bool {
	transport := &http.Transport{}
	switch proxyCfg.Protocol {
	case "http", "https":
		proxyUrl, err := url.Parse(proxyCfg.Protocol + "://" + proxyCfg.Address)
		if err != nil {
			s.logger.WithError(err).WithField("proxy_address", proxyCfg.Address).Warn("Invalid proxy URL format.")
			return false
		}
		transport.Proxy = http.ProxyURL(proxyUrl)
	case "socks5":
		dialer, err := proxy.SOCKS5("tcp", proxyCfg.Address, nil, proxy.Direct)
		if err != nil {
			s.logger.WithError(err).WithField("proxy_address", proxyCfg.Address).Warn("Failed to create SOCKS5 dialer.")
			return false
		}
		transport.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) {
			return dialer.Dial(network, addr)
		}
	default:
		s.logger.WithField("protocol", proxyCfg.Protocol).Warn("Unsupported proxy protocol.")
		return false
	}
	client := &http.Client{
		Transport: transport,
		Timeout:   timeout,
	}
	resp, err := client.Get(ProxyCheckTargetURL)
	if err != nil {
		return false
	}
	defer resp.Body.Close()
	return true
}

func (s *HealthCheckService) publishKeyStatusChangedEvent(ctx context.Context, groupID, keyID uint, oldStatus, newStatus models.APIKeyStatus) {
	event := models.KeyStatusChangedEvent{
		KeyID:        keyID,
		GroupID:      groupID,
		OldStatus:    oldStatus,
		NewStatus:    newStatus,
		ChangeReason: "health_check",
		ChangedAt:    time.Now(),
	}
	payload, err := json.Marshal(event)
	if err != nil {
		s.logger.WithError(err).Errorf("Failed to marshal KeyStatusChangedEvent for group %d.", groupID)
		return
	}
	if err := s.store.Publish(ctx, models.TopicKeyStatusChanged, payload); err != nil {
		s.logger.WithError(err).Errorf("Failed to publish KeyStatusChangedEvent for group %d.", groupID)
	}
}

func (s *HealthCheckService) publishUpstreamHealthChangedEvent(ctx context.Context, upstream *models.UpstreamEndpoint, oldStatus, newStatus string) {
	event := models.UpstreamHealthChangedEvent{
		UpstreamID:  upstream.ID,
		UpstreamURL: upstream.URL,
		OldStatus:   oldStatus,
		NewStatus:   newStatus,
		Latency:     0,
		Reason:      "health_check",
		CheckedAt:   time.Now(),
	}
	payload, err := json.Marshal(event)
	if err != nil {
		s.logger.WithError(err).Error("Failed to marshal UpstreamHealthChangedEvent.")
		return
	}
	if err := s.store.Publish(ctx, models.TopicUpstreamHealthChanged, payload); err != nil {
		s.logger.WithError(err).Error("Failed to publish UpstreamHealthChangedEvent.")
	}
}

// runBaseKeyCheckLoop periodically validates all active master keys against the global base endpoint.
func (s *HealthCheckService) runBaseKeyCheckLoop() {
	defer s.wg.Done()
	s.logger.Info("Global base key check loop started.")
	settings := s.SettingsManager.GetSettings()
	if !settings.EnableBaseKeyCheck {
		s.logger.Info("Global base key check is disabled.")
		return
	}
	s.performBaseKeyChecks()
	interval := time.Duration(settings.BaseKeyCheckIntervalMinutes) * time.Minute
	if interval <= 0 {
		s.logger.Warnf("Invalid BaseKeyCheckIntervalMinutes: %d. Disabling base key check loop.", settings.BaseKeyCheckIntervalMinutes)
		return
	}
	ticker := time.NewTicker(interval)
	defer ticker.Stop()
	for {
		select {
		case <-ticker.C:
			s.performBaseKeyChecks()
		case <-s.stopChan:
			s.logger.Info("Global base key check loop stopped.")
			return
		}
	}
}

// performBaseKeyChecks validates every active master key with a worker pool and revokes keys
// that fail with a permanent upstream error.
func (s *HealthCheckService) performBaseKeyChecks() {
	ctx := context.Background()
	s.logger.Info("Starting global base key check cycle.")
	settings := s.SettingsManager.GetSettings()
	timeout := time.Duration(settings.KeyCheckTimeoutSeconds) * time.Second
	endpoint := settings.BaseKeyCheckEndpoint
	concurrency := settings.BaseKeyCheckConcurrency
	keys, err := s.keyRepo.GetActiveMasterKeys()
	if err != nil {
		s.logger.WithError(err).Error("Failed to fetch active master keys for base check.")
		return
	}
	if len(keys) == 0 {
		s.logger.Info("No active master keys to perform base check on.")
		return
	}
	s.logger.Infof("Performing base check on %d active master keys.", len(keys))
	jobs := make(chan *models.APIKey, len(keys))
	var wg sync.WaitGroup
	if concurrency <= 0 {
		concurrency = 5
	}
	for w := 0; w < concurrency; w++ {
		wg.Add(1)
		go func() {
			defer wg.Done()
			for key := range jobs {
				err := s.keyValidationService.ValidateSingleKey(key, timeout, endpoint)
				if err != nil && CustomErrors.IsPermanentUpstreamError(err.Error()) {
					oldStatus := key.MasterStatus
					s.logger.Warnf("Key ID %d (%s...) failed definitive base check. Setting MasterStatus to REVOKED. Reason: %v", key.ID, key.APIKey[:4], err)
					if updateErr := s.keyRepo.UpdateAPIKeyStatus(ctx, key.ID, models.MasterStatusRevoked); updateErr != nil {
						s.logger.WithError(updateErr).Errorf("Failed to update master status for key ID %d.", key.ID)
					} else {
						s.publishMasterKeyStatusChangedEvent(ctx, key.ID, oldStatus, models.MasterStatusRevoked)
					}
				}
			}
		}()
	}
	for _, key := range keys {
		jobs <- key
	}
	close(jobs)
	wg.Wait()
	s.logger.Info("Global base key check cycle finished.")
}

func (s *HealthCheckService) publishMasterKeyStatusChangedEvent(ctx context.Context, keyID uint, oldStatus, newStatus models.MasterAPIKeyStatus) {
	event := models.MasterKeyStatusChangedEvent{
		KeyID:           keyID,
		OldMasterStatus: oldStatus,
		NewMasterStatus: newStatus,
		ChangeReason:    "base_health_check",
		ChangedAt:       time.Now(),
	}
	payload, err := json.Marshal(event)
	if err != nil {
		s.logger.WithError(err).Errorf("Failed to marshal MasterKeyStatusChangedEvent for key %d.", keyID)
		return
	}
	if err := s.store.Publish(ctx, models.TopicMasterKeyStatusChanged, payload); err != nil {
		s.logger.WithError(err).Errorf("Failed to publish MasterKeyStatusChangedEvent for key %d.", keyID)
	}
}