// Filename: internal/service/healthcheck_service.go

package service

import (
	"context"
	"encoding/json"
	"net"
	"net/http"
	"net/url"
	"sync"
	"time"

	"gemini-balancer/internal/channel"
	CustomErrors "gemini-balancer/internal/errors"
	"gemini-balancer/internal/models"
	"gemini-balancer/internal/repository"
	"gemini-balancer/internal/settings"
	"gemini-balancer/internal/store"

	"github.com/sirupsen/logrus"
	"golang.org/x/net/proxy"
	"gorm.io/gorm"
)

const (
	// Business status constants
	StatusActive   = "active"
	StatusInactive = "inactive"

	// Proxy check target (fixed)
	ProxyCheckTargetURL = "https://www.google.com/generate_204"

	// Concurrency control bounds
	minHealthCheckConcurrency      = 1
	maxHealthCheckConcurrency      = 100
	defaultKeyCheckConcurrency     = 5
	defaultBaseKeyCheckConcurrency = 5

	// Fallback defaults
	defaultSchedulerIntervalSeconds = 60
	defaultKeyCheckTimeoutSeconds   = 30
	defaultBaseKeyCheckEndpoint     = "https://generativelanguage.googleapis.com/v1beta/models"
)

type HealthCheckServiceLogger struct{ *logrus.Entry }
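
// HealthCheckService runs the background health check loops for API keys,
// upstream endpoints, and proxies, and publishes status change events.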
type HealthCheckService struct {
	db                   *gorm.DB
	settingsManager      *settings.SettingsManager
	store                store.Store
	keyRepo              repository.KeyRepository
	groupManager         *GroupManager
	channel              channel.ChannelProxy
	keyValidationService *KeyValidationService
	logger               *logrus.Entry
	stopChan             chan struct{}
	wg                   sync.WaitGroup
	lastResultsMutex     sync.RWMutex
	lastResults          map[string]string
	groupCheckTimeMutex  sync.Mutex
	groupNextCheckTime   map[uint]time.Time
	httpClient           *http.Client
}
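
// NewHealthCheckService wires up the service with its dependencies and a
// shared HTTP client used for upstream endpoint checks.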
func NewHealthCheckService(
	db *gorm.DB,
	ss *settings.SettingsManager,
	s store.Store,
	repo repository.KeyRepository,
	gm *GroupManager,
	ch channel.ChannelProxy,
	kvs *KeyValidationService,
	logger *logrus.Logger,
) *HealthCheckService {
	return &HealthCheckService{
		db:                   db,
		settingsManager:      ss,
		store:                s,
		keyRepo:              repo,
		groupManager:         gm,
		channel:              ch,
		keyValidationService: kvs,
		logger:               logger.WithField("component", "HealthCheck🩺"),
		stopChan:             make(chan struct{}),
		lastResults:          make(map[string]string),
		groupNextCheckTime:   make(map[uint]time.Time),
		httpClient: &http.Client{
			Transport: &http.Transport{
				MaxIdleConns:        100,
				MaxIdleConnsPerHost: 10,
				IdleConnTimeout:     90 * time.Second,
				DisableKeepAlives:   false,
			},
		},
	}
}
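
// Start launches the four independent check loops; Stop waits for all of
// them to exit via the service WaitGroup.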
func (s *HealthCheckService) Start() {
	s.logger.Info("Starting HealthCheckService with independent check loops...")
	s.wg.Add(4)
	go s.runKeyCheckLoop()
	go s.runUpstreamCheckLoop()
	go s.runProxyCheckLoop()
	go s.runBaseKeyCheckLoop()
}
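
// Stop signals all loops to exit, waits for them to finish, and releases
// idle HTTP connections.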
func (s *HealthCheckService) Stop() {
	s.logger.Info("Stopping HealthCheckService...")
	close(s.stopChan)
	s.wg.Wait()
	s.httpClient.CloseIdleConnections()
	s.logger.Info("HealthCheckService stopped gracefully.")
}
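
// GetLastHealthCheckResults returns a copy of the most recent upstream and
// proxy check results, keyed by upstream URL or proxy address.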
func (s *HealthCheckService) GetLastHealthCheckResults() map[string]string {
	s.lastResultsMutex.RLock()
	defer s.lastResultsMutex.RUnlock()
	resultsCopy := make(map[string]string, len(s.lastResults))
	for k, v := range s.lastResults {
		resultsCopy[k] = v
	}
	return resultsCopy
}

// ==================== Key Check Loop ====================
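
// runKeyCheckLoop ticks on a fixed scheduler interval and delegates to
// scheduleKeyChecks, which decides per group whether a check is due.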
func (s *HealthCheckService) runKeyCheckLoop() {
	defer s.wg.Done()

	settings := s.settingsManager.GetSettings()
	schedulerInterval := time.Duration(settings.KeyCheckSchedulerIntervalSeconds) * time.Second
	if schedulerInterval <= 0 {
		schedulerInterval = time.Duration(defaultSchedulerIntervalSeconds) * time.Second
	}

	s.logger.Infof("Key check dynamic scheduler loop started with interval: %v", schedulerInterval)
	ticker := time.NewTicker(schedulerInterval)
	defer ticker.Stop()

	for {
		select {
		case <-ticker.C:
			s.scheduleKeyChecks()
		case <-s.stopChan:
			s.logger.Info("Key check scheduler loop stopped.")
			return
		}
	}
}
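
// scheduleKeyChecks walks all key groups and, for each group whose key check
// is enabled and whose next-check time has passed, launches an asynchronous
// check cycle and records the next scheduled time.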
func (s *HealthCheckService) scheduleKeyChecks() {
	groups := s.groupManager.GetAllGroups()
	now := time.Now()

	s.groupCheckTimeMutex.Lock()
	defer s.groupCheckTimeMutex.Unlock()

	for _, group := range groups {
		opConfig, err := s.groupManager.BuildOperationalConfig(group)
		if err != nil {
			s.logger.WithError(err).WithField("group_id", group.ID).Error("Failed to build operational config for group, skipping health check scheduling.")
			continue
		}

		if opConfig.EnableKeyCheck == nil || !*opConfig.EnableKeyCheck {
			continue
		}

		var intervalMinutes int
		if opConfig.KeyCheckIntervalMinutes != nil {
			intervalMinutes = *opConfig.KeyCheckIntervalMinutes
		}
		interval := time.Duration(intervalMinutes) * time.Minute
		if interval <= 0 {
			continue
		}

		if nextCheckTime, ok := s.groupNextCheckTime[group.ID]; !ok || now.After(nextCheckTime) {
			s.logger.WithFields(logrus.Fields{
				"group_id":   group.ID,
				"group_name": group.Name,
				"interval":   interval,
			}).Info("Scheduling key check for group")

			// Create a context with a timeout for this check cycle.
			ctx, cancel := context.WithTimeout(context.Background(), interval)
			go func(g *models.KeyGroup, cfg *models.KeyGroupSettings) {
				defer cancel()
				select {
				case <-s.stopChan:
					return
				default:
					s.performKeyChecksForGroup(ctx, g, cfg)
				}
			}(group, opConfig)

			s.groupNextCheckTime[group.ID] = now.Add(interval)
		}
	}
}
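
// performKeyChecksForGroup loads every active, disabled, or cooling-down key
// mapping in the group whose master key is still active, then validates the
// keys through a bounded worker pool.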
func (s *HealthCheckService) performKeyChecksForGroup(
	ctx context.Context,
	group *models.KeyGroup,
	opConfig *models.KeyGroupSettings,
) {
	settings := s.settingsManager.GetSettings()
	timeout := time.Duration(settings.KeyCheckTimeoutSeconds) * time.Second
	if timeout <= 0 {
		timeout = time.Duration(defaultKeyCheckTimeoutSeconds) * time.Second
	}

	endpoint, err := s.groupManager.BuildKeyCheckEndpoint(group.ID)
	if err != nil {
		s.logger.WithError(err).WithField("group_id", group.ID).Error("Failed to build key check endpoint for group, skipping check cycle.")
		return
	}

	log := s.logger.WithFields(logrus.Fields{
		"group_id":   group.ID,
		"group_name": group.Name,
	})

	log.Info("Starting key health check cycle")

	var mappingsToCheck []models.GroupAPIKeyMapping
	err = s.db.WithContext(ctx).Model(&models.GroupAPIKeyMapping{}).
		Joins("JOIN api_keys ON api_keys.id = group_api_key_mappings.api_key_id").
		Where("group_api_key_mappings.key_group_id = ?", group.ID).
		Where("api_keys.master_status = ?", models.MasterStatusActive).
		Where("group_api_key_mappings.status IN ?", []models.APIKeyStatus{
			models.StatusActive,
			models.StatusDisabled,
			models.StatusCooldown,
		}).
		Preload("APIKey").
		Find(&mappingsToCheck).Error

	if err != nil {
		log.WithError(err).Error("Failed to fetch key mappings for health check")
		return
	}

	if len(mappingsToCheck) == 0 {
		log.Info("No key mappings to check for this group")
		return
	}

	log.WithField("key_count", len(mappingsToCheck)).Info("Starting health check for key mappings")

	jobs := make(chan models.GroupAPIKeyMapping, len(mappingsToCheck))
	var wg sync.WaitGroup

	concurrency := s.getConcurrency(opConfig.KeyCheckConcurrency, defaultKeyCheckConcurrency)
	log.WithField("concurrency", concurrency).Debug("Using concurrency for key check")

	for w := 1; w <= concurrency; w++ {
		wg.Add(1)
		go func(workerID int) {
			defer wg.Done()
			for mapping := range jobs {
				select {
				case <-ctx.Done():
					log.Warn("Context cancelled, stopping worker")
					return
				default:
					s.checkAndProcessMapping(ctx, &mapping, timeout, endpoint)
				}
			}
		}(w)
	}

	for _, m := range mappingsToCheck {
		jobs <- m
	}
	close(jobs)
	wg.Wait()

	log.Info("Finished key health check cycle")
}
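
// checkAndProcessMapping validates one key mapping and routes the result:
// success reactivates the mapping, a permanent upstream error revokes it, a
// temporary upstream error applies a penalty, and anything else is only
// logged.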
func (s *HealthCheckService) checkAndProcessMapping(
	ctx context.Context,
	mapping *models.GroupAPIKeyMapping,
	timeout time.Duration,
	endpoint string,
) {
	if mapping.APIKey == nil {
		s.logger.WithFields(logrus.Fields{
			"group_id": mapping.KeyGroupID,
			"key_id":   mapping.APIKeyID,
		}).Warn("Skipping check for mapping because associated APIKey is nil")
		return
	}

	validationErr := s.keyValidationService.ValidateSingleKey(mapping.APIKey, timeout, endpoint)
	if validationErr == nil {
		if mapping.Status != models.StatusActive {
			s.activateMapping(ctx, mapping)
		}
		return
	}

	errorString := validationErr.Error()
	if CustomErrors.IsPermanentUpstreamError(errorString) {
		s.revokeMapping(ctx, mapping, validationErr)
		return
	}

	if CustomErrors.IsTemporaryUpstreamError(errorString) {
		s.logger.WithFields(logrus.Fields{
			"key_id":   mapping.APIKeyID,
			"group_id": mapping.KeyGroupID,
			"error":    validationErr.Error(),
		}).Warn("Health check failed with temporary error, applying penalty")
		s.penalizeMapping(ctx, mapping, validationErr)
		return
	}

	s.logger.WithFields(logrus.Fields{
		"key_id":   mapping.APIKeyID,
		"group_id": mapping.KeyGroupID,
		"error":    validationErr.Error(),
	}).Error("Health check failed with transient or unknown upstream error, mapping will not be penalized")
}
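
// activateMapping resets a previously failing mapping to active, clears its
// error counters, and publishes a status change event.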
func (s *HealthCheckService) activateMapping(ctx context.Context, mapping *models.GroupAPIKeyMapping) {
	oldStatus := mapping.Status
	mapping.Status = models.StatusActive
	mapping.ConsecutiveErrorCount = 0
	mapping.LastError = ""

	if err := s.keyRepo.UpdateMapping(ctx, mapping); err != nil {
		s.logger.WithError(err).WithFields(logrus.Fields{
			"group_id": mapping.KeyGroupID,
			"key_id":   mapping.APIKeyID,
		}).Error("Failed to activate mapping")
		return
	}

	s.logger.WithFields(logrus.Fields{
		"group_id":   mapping.KeyGroupID,
		"key_id":     mapping.APIKeyID,
		"old_status": oldStatus,
		"new_status": mapping.Status,
	}).Info("Mapping successfully activated")

	s.publishKeyStatusChangedEvent(ctx, mapping.KeyGroupID, mapping.APIKeyID, oldStatus, mapping.Status)
}
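
// penalizeMapping records a temporary failure; once the consecutive error
// count reaches the group's blacklist threshold, the mapping is placed in
// cooldown until the configured cooldown window has elapsed.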
func (s *HealthCheckService) penalizeMapping(ctx context.Context, mapping *models.GroupAPIKeyMapping, err error) {
	group, ok := s.groupManager.GetGroupByID(mapping.KeyGroupID)
	if !ok {
		s.logger.WithField("group_id", mapping.KeyGroupID).Error("Could not find group to apply penalty")
		return
	}

	opConfig, buildErr := s.groupManager.BuildOperationalConfig(group)
	if buildErr != nil {
		s.logger.WithError(buildErr).WithField("group_id", mapping.KeyGroupID).Error("Failed to build operational config for group during penalty")
		return
	}

	oldStatus := mapping.Status
	mapping.LastError = err.Error()
	mapping.ConsecutiveErrorCount++

	threshold := *opConfig.KeyBlacklistThreshold
	if mapping.ConsecutiveErrorCount >= threshold {
		mapping.Status = models.StatusCooldown
		cooldownDuration := time.Duration(*opConfig.KeyCooldownMinutes) * time.Minute
		cooldownTime := time.Now().Add(cooldownDuration)
		mapping.CooldownUntil = &cooldownTime

		s.logger.WithFields(logrus.Fields{
			"group_id":          mapping.KeyGroupID,
			"key_id":            mapping.APIKeyID,
			"error_count":       mapping.ConsecutiveErrorCount,
			"threshold":         threshold,
			"cooldown_duration": cooldownDuration,
		}).Warn("Mapping reached error threshold and is now in COOLDOWN")
	}

	if errDb := s.keyRepo.UpdateMapping(ctx, mapping); errDb != nil {
		s.logger.WithError(errDb).WithFields(logrus.Fields{
			"group_id": mapping.KeyGroupID,
			"key_id":   mapping.APIKeyID,
		}).Error("Failed to penalize mapping")
		return
	}

	if oldStatus != mapping.Status {
		s.publishKeyStatusChangedEvent(ctx, mapping.KeyGroupID, mapping.APIKeyID, oldStatus, mapping.Status)
	}
}
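
// revokeMapping bans a mapping after a definitive upstream error and also
// marks the underlying key's master status as revoked.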
func (s *HealthCheckService) revokeMapping(ctx context.Context, mapping *models.GroupAPIKeyMapping, err error) {
	oldStatus := mapping.Status
	if oldStatus == models.StatusBanned {
		return
	}

	mapping.Status = models.StatusBanned
	mapping.LastError = "Definitive error: " + err.Error()
	mapping.ConsecutiveErrorCount = 0

	if errDb := s.keyRepo.UpdateMapping(ctx, mapping); errDb != nil {
		s.logger.WithError(errDb).WithFields(logrus.Fields{
			"group_id": mapping.KeyGroupID,
			"key_id":   mapping.APIKeyID,
		}).Error("Failed to revoke mapping")
		return
	}

	s.logger.WithFields(logrus.Fields{
		"group_id": mapping.KeyGroupID,
		"key_id":   mapping.APIKeyID,
		"error":    err.Error(),
	}).Warn("Mapping has been BANNED due to definitive error")

	s.publishKeyStatusChangedEvent(ctx, mapping.KeyGroupID, mapping.APIKeyID, oldStatus, mapping.Status)

	s.logger.WithField("key_id", mapping.APIKeyID).Info("Triggering MasterStatus update for definitively failed key")
	if err := s.keyRepo.UpdateAPIKeyStatus(ctx, mapping.APIKeyID, models.MasterStatusRevoked); err != nil {
		s.logger.WithError(err).WithField("key_id", mapping.APIKeyID).Error("Failed to update master status after group-level ban")
	}
}

// ==================== Upstream Check Loop ====================

func (s *HealthCheckService) runUpstreamCheckLoop() {
	defer s.wg.Done()
	s.logger.Info("Upstream check loop started")

	settings := s.settingsManager.GetSettings()
	if settings.EnableUpstreamCheck {
		s.performUpstreamChecks()
	}

	interval := time.Duration(settings.UpstreamCheckIntervalSeconds) * time.Second
	if interval <= 0 {
		interval = 300 * time.Second // fallback: 5 minutes
	}

	ticker := time.NewTicker(interval)
	defer ticker.Stop()

	for {
		select {
		case <-ticker.C:
			settings := s.settingsManager.GetSettings()
			if settings.EnableUpstreamCheck {
				s.logger.Debug("Upstream check ticker fired")
				s.performUpstreamChecks()
			}
		case <-s.stopChan:
			s.logger.Info("Upstream check loop stopped")
			return
		}
	}
}
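
// performUpstreamChecks probes every configured upstream endpoint in
// parallel, records the result in lastResults, persists status changes, and
// publishes an event for each change.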
func (s *HealthCheckService) performUpstreamChecks() {
	ctx := context.Background()
	settings := s.settingsManager.GetSettings()
	timeout := time.Duration(settings.UpstreamCheckTimeoutSeconds) * time.Second
	if timeout <= 0 {
		timeout = 10 * time.Second
	}

	var upstreams []*models.UpstreamEndpoint
	if err := s.db.WithContext(ctx).Find(&upstreams).Error; err != nil {
		s.logger.WithError(err).Error("Failed to retrieve upstreams")
		return
	}

	if len(upstreams) == 0 {
		s.logger.Debug("No upstreams configured for health check")
		return
	}

	s.logger.WithFields(logrus.Fields{
		"count":   len(upstreams),
		"timeout": timeout,
	}).Info("Starting upstream validation")

	type checkResult struct {
		upstreamID uint
		url        string
		oldStatus  string
		newStatus  string
		changed    bool
		err        error
	}

	results := make([]checkResult, 0, len(upstreams))
	var resultsMutex sync.Mutex
	var wg sync.WaitGroup

	for _, u := range upstreams {
		wg.Add(1)
		go func(upstream *models.UpstreamEndpoint) {
			defer wg.Done()

			oldStatus := upstream.Status
			isAlive := s.checkEndpoint(upstream.URL, timeout)
			newStatus := StatusInactive
			if isAlive {
				newStatus = StatusActive
			}

			s.lastResultsMutex.Lock()
			s.lastResults[upstream.URL] = newStatus
			s.lastResultsMutex.Unlock()

			result := checkResult{
				upstreamID: upstream.ID,
				url:        upstream.URL,
				oldStatus:  oldStatus,
				newStatus:  newStatus,
				changed:    oldStatus != newStatus,
			}

			if result.changed {
				s.logger.WithFields(logrus.Fields{
					"upstream_id": upstream.ID,
					"url":         upstream.URL,
					"old_status":  oldStatus,
					"new_status":  newStatus,
				}).Info("Upstream status changed")

				if err := s.db.WithContext(ctx).Model(upstream).Update("status", newStatus).Error; err != nil {
					s.logger.WithError(err).WithField("upstream_id", upstream.ID).Error("Failed to update upstream status")
					result.err = err
				} else {
					s.publishUpstreamHealthChangedEvent(ctx, upstream, oldStatus, newStatus)
				}
			} else {
				s.logger.WithFields(logrus.Fields{
					"upstream_id": upstream.ID,
					"url":         upstream.URL,
					"status":      newStatus,
				}).Debug("Upstream status unchanged")
			}

			resultsMutex.Lock()
			results = append(results, result)
			resultsMutex.Unlock()
		}(u)
	}
	wg.Wait()

	// Aggregate statistics for the cycle summary log.
	activeCount := 0
	inactiveCount := 0
	changedCount := 0
	errorCount := 0

	for _, r := range results {
		if r.changed {
			changedCount++
		}
		if r.err != nil {
			errorCount++
		}
		if r.newStatus == StatusActive {
			activeCount++
		} else {
			inactiveCount++
		}
	}

	s.logger.WithFields(logrus.Fields{
		"total":    len(upstreams),
		"active":   activeCount,
		"inactive": inactiveCount,
		"changed":  changedCount,
		"errors":   errorCount,
	}).Info("Upstream validation cycle completed")
}
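
// checkEndpoint issues a HEAD request and treats any response below HTTP 500
// as healthy; request creation failures and transport errors count as
// unhealthy.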
func (s *HealthCheckService) checkEndpoint(urlStr string, timeout time.Duration) bool {
	ctx, cancel := context.WithTimeout(context.Background(), timeout)
	defer cancel()

	req, err := http.NewRequestWithContext(ctx, http.MethodHead, urlStr, nil)
	if err != nil {
		s.logger.WithError(err).WithField("url", urlStr).Debug("Failed to create request for endpoint check")
		return false
	}

	resp, err := s.httpClient.Do(req)
	if err != nil {
		return false
	}
	defer resp.Body.Close()

	return resp.StatusCode < http.StatusInternalServerError
}

// ==================== Proxy Check Loop ====================

func (s *HealthCheckService) runProxyCheckLoop() {
	defer s.wg.Done()
	s.logger.Info("Proxy check loop started")

	settings := s.settingsManager.GetSettings()
	if settings.EnableProxyCheck {
		s.performProxyChecks()
	}

	interval := time.Duration(settings.ProxyCheckIntervalSeconds) * time.Second
	if interval <= 0 {
		interval = 600 * time.Second // fallback: 10 minutes
	}

	ticker := time.NewTicker(interval)
	defer ticker.Stop()

	for {
		select {
		case <-ticker.C:
			settings := s.settingsManager.GetSettings()
			if settings.EnableProxyCheck {
				s.logger.Debug("Proxy check ticker fired")
				s.performProxyChecks()
			}
		case <-s.stopChan:
			s.logger.Info("Proxy check loop stopped")
			return
		}
	}
}
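
// performProxyChecks validates every configured proxy in parallel, updates
// lastResults and the persisted status, and logs a summary of the cycle.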
func (s *HealthCheckService) performProxyChecks() {
	ctx := context.Background()
	settings := s.settingsManager.GetSettings()
	timeout := time.Duration(settings.ProxyCheckTimeoutSeconds) * time.Second
	if timeout <= 0 {
		timeout = 15 * time.Second
	}

	var proxies []*models.ProxyConfig
	if err := s.db.WithContext(ctx).Find(&proxies).Error; err != nil {
		s.logger.WithError(err).Error("Failed to retrieve proxies")
		return
	}

	if len(proxies) == 0 {
		s.logger.Debug("No proxies configured for health check")
		return
	}

	s.logger.WithFields(logrus.Fields{
		"count":   len(proxies),
		"timeout": timeout,
	}).Info("Starting proxy validation")

	activeCount := 0
	inactiveCount := 0
	changedCount := 0
	var statsMutex sync.Mutex
	var wg sync.WaitGroup

	for _, p := range proxies {
		wg.Add(1)
		go func(proxyCfg *models.ProxyConfig) {
			defer wg.Done()

			oldStatus := proxyCfg.Status
			isAlive := s.checkProxy(proxyCfg, timeout)
			newStatus := StatusInactive
			if isAlive {
				newStatus = StatusActive
			}

			s.lastResultsMutex.Lock()
			s.lastResults[proxyCfg.Address] = newStatus
			s.lastResultsMutex.Unlock()

			if oldStatus != newStatus {
				s.logger.WithFields(logrus.Fields{
					"proxy_id":   proxyCfg.ID,
					"address":    proxyCfg.Address,
					"old_status": oldStatus,
					"new_status": newStatus,
				}).Info("Proxy status changed")

				if err := s.db.WithContext(ctx).Model(proxyCfg).Update("status", newStatus).Error; err != nil {
					s.logger.WithError(err).WithField("proxy_id", proxyCfg.ID).Error("Failed to update proxy status")
				}

				statsMutex.Lock()
				changedCount++
				if newStatus == StatusActive {
					activeCount++
				} else {
					inactiveCount++
				}
				statsMutex.Unlock()
			} else {
				s.logger.WithFields(logrus.Fields{
					"proxy_id": proxyCfg.ID,
					"address":  proxyCfg.Address,
					"status":   newStatus,
				}).Debug("Proxy status unchanged")

				statsMutex.Lock()
				if newStatus == StatusActive {
					activeCount++
				} else {
					inactiveCount++
				}
				statsMutex.Unlock()
			}
		}(p)
	}
	wg.Wait()

	s.logger.WithFields(logrus.Fields{
		"total":    len(proxies),
		"active":   activeCount,
		"inactive": inactiveCount,
		"changed":  changedCount,
	}).Info("Proxy validation cycle completed")
}
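
// checkProxy builds a transport for the proxy's protocol (http, https, or
// socks5) and reports the proxy healthy if a GET to the fixed generate_204
// target completes without a transport error, regardless of status code.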
func (s *HealthCheckService) checkProxy(proxyCfg *models.ProxyConfig, timeout time.Duration) bool {
	transport := &http.Transport{}

	switch proxyCfg.Protocol {
	case "http", "https":
		proxyUrl, err := url.Parse(proxyCfg.Protocol + "://" + proxyCfg.Address)
		if err != nil {
			s.logger.WithError(err).WithField("proxy_address", proxyCfg.Address).Warn("Invalid proxy URL format")
			return false
		}
		transport.Proxy = http.ProxyURL(proxyUrl)

	case "socks5":
		dialer, err := proxy.SOCKS5("tcp", proxyCfg.Address, nil, proxy.Direct)
		if err != nil {
			s.logger.WithError(err).WithField("proxy_address", proxyCfg.Address).Warn("Failed to create SOCKS5 dialer")
			return false
		}
		transport.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) {
			return dialer.Dial(network, addr)
		}

	default:
		s.logger.WithField("protocol", proxyCfg.Protocol).Warn("Unsupported proxy protocol")
		return false
	}

	client := &http.Client{
		Transport: transport,
		Timeout:   timeout,
	}

	resp, err := client.Get(ProxyCheckTargetURL)
	if err != nil {
		return false
	}
	defer resp.Body.Close()

	return true
}

// ==================== Base Key Check Loop ====================

func (s *HealthCheckService) runBaseKeyCheckLoop() {
	defer s.wg.Done()
	s.logger.Info("Global base key check loop started")

	settings := s.settingsManager.GetSettings()
	if !settings.EnableBaseKeyCheck {
		s.logger.Info("Global base key check is disabled")
		return
	}

	// Run once at startup.
	s.performBaseKeyChecks()

	interval := time.Duration(settings.BaseKeyCheckIntervalMinutes) * time.Minute
	if interval <= 0 {
		s.logger.WithField("interval", settings.BaseKeyCheckIntervalMinutes).Warn("Invalid BaseKeyCheckIntervalMinutes, disabling base key check loop")
		return
	}

	ticker := time.NewTicker(interval)
	defer ticker.Stop()

	for {
		select {
		case <-ticker.C:
			s.performBaseKeyChecks()
		case <-s.stopChan:
			s.logger.Info("Global base key check loop stopped")
			return
		}
	}
}
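
// performBaseKeyChecks validates every active master key against the global
// base endpoint and revokes keys that fail with a permanent upstream error.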
func (s *HealthCheckService) performBaseKeyChecks() {
	ctx := context.Background()
	s.logger.Info("Starting global base key check cycle")

	settings := s.settingsManager.GetSettings()
	timeout := time.Duration(settings.KeyCheckTimeoutSeconds) * time.Second
	if timeout <= 0 {
		timeout = time.Duration(defaultKeyCheckTimeoutSeconds) * time.Second
	}

	endpoint := settings.BaseKeyCheckEndpoint
	if endpoint == "" {
		endpoint = defaultBaseKeyCheckEndpoint
		s.logger.WithField("endpoint", endpoint).Debug("Using default base key check endpoint")
	}

	concurrency := settings.BaseKeyCheckConcurrency
	if concurrency <= 0 {
		concurrency = defaultBaseKeyCheckConcurrency
	}
	concurrency = s.ensureConcurrencyBounds(concurrency)

	keys, err := s.keyRepo.GetActiveMasterKeys()
	if err != nil {
		s.logger.WithError(err).Error("Failed to fetch active master keys for base check")
		return
	}

	if len(keys) == 0 {
		s.logger.Info("No active master keys to perform base check on")
		return
	}

	s.logger.WithFields(logrus.Fields{
		"key_count":   len(keys),
		"concurrency": concurrency,
		"endpoint":    endpoint,
	}).Info("Performing base check on active master keys")

	jobs := make(chan *models.APIKey, len(keys))
	var wg sync.WaitGroup

	for w := 0; w < concurrency; w++ {
		wg.Add(1)
		go func(workerID int) {
			defer wg.Done()
			for key := range jobs {
				select {
				case <-s.stopChan:
					return
				default:
					err := s.keyValidationService.ValidateSingleKey(key, timeout, endpoint)
					if err != nil && CustomErrors.IsPermanentUpstreamError(err.Error()) {
						oldStatus := key.MasterStatus
						keyPrefix := key.APIKey
						if len(keyPrefix) > 8 {
							keyPrefix = keyPrefix[:8]
						}

						s.logger.WithFields(logrus.Fields{
							"key_id":     key.ID,
							"key_prefix": keyPrefix + "...",
							"error":      err.Error(),
						}).Warn("Key failed definitive base check, setting MasterStatus to REVOKED")

						if updateErr := s.keyRepo.UpdateAPIKeyStatus(ctx, key.ID, models.MasterStatusRevoked); updateErr != nil {
							s.logger.WithError(updateErr).WithField("key_id", key.ID).Error("Failed to update master status")
						} else {
							s.publishMasterKeyStatusChangedEvent(ctx, key.ID, oldStatus, models.MasterStatusRevoked)
						}
					}
				}
			}
		}(w)
	}

	for _, key := range keys {
		jobs <- key
	}
	close(jobs)
	wg.Wait()

	s.logger.Info("Global base key check cycle finished")
}

// ==================== Event Publishing ====================
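
// publishKeyStatusChangedEvent emits a KeyStatusChangedEvent on the store's
// TopicKeyStatusChanged topic; marshal or publish failures are only logged.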
func (s *HealthCheckService) publishKeyStatusChangedEvent(
	ctx context.Context,
	groupID, keyID uint,
	oldStatus, newStatus models.APIKeyStatus,
) {
	event := models.KeyStatusChangedEvent{
		KeyID:        keyID,
		GroupID:      groupID,
		OldStatus:    oldStatus,
		NewStatus:    newStatus,
		ChangeReason: "health_check",
		ChangedAt:    time.Now(),
	}

	payload, err := json.Marshal(event)
	if err != nil {
		s.logger.WithError(err).WithField("group_id", groupID).Error("Failed to marshal KeyStatusChangedEvent")
		return
	}

	if err := s.store.Publish(ctx, models.TopicKeyStatusChanged, payload); err != nil {
		s.logger.WithError(err).WithField("group_id", groupID).Error("Failed to publish KeyStatusChangedEvent")
	}
}

func (s *HealthCheckService) publishUpstreamHealthChangedEvent(
	ctx context.Context,
	upstream *models.UpstreamEndpoint,
	oldStatus, newStatus string,
) {
	event := models.UpstreamHealthChangedEvent{
		UpstreamID:  upstream.ID,
		UpstreamURL: upstream.URL,
		OldStatus:   oldStatus,
		NewStatus:   newStatus,
		Latency:     0,
		Reason:      "health_check",
		CheckedAt:   time.Now(),
	}

	payload, err := json.Marshal(event)
	if err != nil {
		s.logger.WithError(err).Error("Failed to marshal UpstreamHealthChangedEvent")
		return
	}

	if err := s.store.Publish(ctx, models.TopicUpstreamHealthChanged, payload); err != nil {
		s.logger.WithError(err).Error("Failed to publish UpstreamHealthChangedEvent")
	}
}

func (s *HealthCheckService) publishMasterKeyStatusChangedEvent(
	ctx context.Context,
	keyID uint,
	oldStatus, newStatus models.MasterAPIKeyStatus,
) {
	event := models.MasterKeyStatusChangedEvent{
		KeyID:           keyID,
		OldMasterStatus: oldStatus,
		NewMasterStatus: newStatus,
		ChangeReason:    "base_health_check",
		ChangedAt:       time.Now(),
	}

	payload, err := json.Marshal(event)
	if err != nil {
		s.logger.WithError(err).WithField("key_id", keyID).Error("Failed to marshal MasterKeyStatusChangedEvent")
		return
	}

	if err := s.store.Publish(ctx, models.TopicMasterKeyStatusChanged, payload); err != nil {
		s.logger.WithError(err).WithField("key_id", keyID).Error("Failed to publish MasterKeyStatusChangedEvent")
	}
}

// ==================== Helper Methods ====================
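
// getConcurrency resolves a per-group concurrency setting, falling back to
// the given default when unset, and clamps the result to the allowed bounds.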
func (s *HealthCheckService) getConcurrency(configValue *int, defaultValue int) int {
	var concurrency int
	if configValue != nil && *configValue > 0 {
		concurrency = *configValue
	} else {
		concurrency = defaultValue
	}

	return s.ensureConcurrencyBounds(concurrency)
}
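
// ensureConcurrencyBounds clamps a requested worker count to the range
// [minHealthCheckConcurrency, maxHealthCheckConcurrency].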
func (s *HealthCheckService) ensureConcurrencyBounds(concurrency int) int {
	if concurrency < minHealthCheckConcurrency {
		s.logger.WithFields(logrus.Fields{
			"requested": concurrency,
			"minimum":   minHealthCheckConcurrency,
		}).Debug("Concurrency below minimum, adjusting")
		return minHealthCheckConcurrency
	}

	if concurrency > maxHealthCheckConcurrency {
		s.logger.WithFields(logrus.Fields{
			"requested": concurrency,
			"maximum":   maxHealthCheckConcurrency,
		}).Warn("Concurrency exceeds maximum, capping it")
		return maxHealthCheckConcurrency
	}

	return concurrency
}