// Filename: internal/domain/proxy/manager.go package proxy import ( "encoding/json" "fmt" "gemini-balancer/internal/models" "gemini-balancer/internal/store" "gemini-balancer/internal/syncer" "gemini-balancer/internal/task" "context" "net" "net/http" "net/url" "strings" "sync" "time" "github.com/gin-gonic/gin" "github.com/sirupsen/logrus" "golang.org/x/net/proxy" "gorm.io/gorm" ) const ( TaskTypeProxySync = "proxy_sync" proxyChunkSize = 200 // 代理同步的批量大小 ) type ProxyCheckResult struct { Proxy string `json:"proxy"` IsAvailable bool `json:"is_available"` ResponseTime float64 `json:"response_time"` ErrorMessage string `json:"error_message"` } // managerCacheData type managerCacheData struct { ActiveProxies []*models.ProxyConfig ProxiesByID map[uint]*models.ProxyConfig } // manager结构体 type manager struct { db *gorm.DB syncer *syncer.CacheSyncer[managerCacheData] task task.Reporter store store.Store logger *logrus.Entry } func newManagerLoader(db *gorm.DB) syncer.LoaderFunc[managerCacheData] { return func() (managerCacheData, error) { var activeProxies []*models.ProxyConfig if err := db.Where("status = ?", "active").Order("assigned_keys_count asc").Find(&activeProxies).Error; err != nil { return managerCacheData{}, fmt.Errorf("failed to load active proxies for cache: %w", err) } proxiesByID := make(map[uint]*models.ProxyConfig, len(activeProxies)) for _, proxy := range activeProxies { p := *proxy proxiesByID[p.ID] = &p } return managerCacheData{ ActiveProxies: activeProxies, ProxiesByID: proxiesByID, }, nil } } func newManager(db *gorm.DB, syncer *syncer.CacheSyncer[managerCacheData], taskReporter task.Reporter, store store.Store, logger *logrus.Entry) *manager { return &manager{ db: db, syncer: syncer, task: taskReporter, store: store, logger: logger, } } func (m *manager) SyncProxiesInBackground(proxyStrings []string) (*task.Status, error) { resourceID := "global_proxy_sync" taskStatus, err := m.task.StartTask(0, TaskTypeProxySync, resourceID, len(proxyStrings), 0) if err != nil { return nil, ErrTaskConflict } go m.runProxySyncTask(taskStatus.ID, proxyStrings) return taskStatus, nil } func (m *manager) runProxySyncTask(taskID string, finalProxyStrings []string) { resourceID := "global_proxy_sync" var allProxies []models.ProxyConfig if err := m.db.Find(&allProxies).Error; err != nil { m.task.EndTaskByID(taskID, resourceID, nil, fmt.Errorf("failed to fetch current proxies: %w", err)) return } currentProxyMap := make(map[string]uint) for _, p := range allProxies { fullString := fmt.Sprintf("%s://%s", p.Protocol, p.Address) currentProxyMap[fullString] = p.ID } finalProxyMap := make(map[string]bool) for _, ps := range finalProxyStrings { finalProxyMap[strings.TrimSpace(ps)] = true } var idsToDelete []uint var proxiesToAdd []models.ProxyConfig for proxyString, id := range currentProxyMap { if !finalProxyMap[proxyString] { idsToDelete = append(idsToDelete, id) } } for proxyString := range finalProxyMap { if _, exists := currentProxyMap[proxyString]; !exists { parsed := parseProxyString(proxyString) if parsed != nil { proxiesToAdd = append(proxiesToAdd, models.ProxyConfig{ Protocol: parsed.Protocol, Address: parsed.Address, Status: "active", }) } } } if len(idsToDelete) > 0 { if err := m.bulkDeleteByIDs(idsToDelete); err != nil { m.task.EndTaskByID(taskID, resourceID, nil, fmt.Errorf("failed during proxy deletion: %w", err)) return } } if len(proxiesToAdd) > 0 { if err := m.bulkAdd(proxiesToAdd); err != nil { m.task.EndTaskByID(taskID, resourceID, nil, fmt.Errorf("failed during proxy addition: %w", err)) return } } result := gin.H{"added": len(proxiesToAdd), "deleted": len(idsToDelete), "final_total": len(finalProxyMap)} m.task.EndTaskByID(taskID, resourceID, result, nil) m.publishChangeEvent("proxies_synced") go m.invalidate() } type parsedProxy struct{ Protocol, Address string } func parseProxyString(proxyStr string) *parsedProxy { proxyStr = strings.TrimSpace(proxyStr) u, err := url.Parse(proxyStr) if err != nil || !strings.Contains(proxyStr, "://") { if strings.Contains(proxyStr, "@") { parts := strings.Split(proxyStr, "@") if len(parts) == 2 { proxyStr = "socks5://" + proxyStr u, err = url.Parse(proxyStr) if err != nil { return nil } } } else { return nil } } protocol := strings.ToLower(u.Scheme) if protocol != "http" && protocol != "https" && protocol != "socks5" { return nil } address := u.Host if u.User != nil { address = u.User.String() + "@" + u.Host } return &parsedProxy{Protocol: protocol, Address: address} } func (m *manager) bulkDeleteByIDs(ids []uint) error { for i := 0; i < len(ids); i += proxyChunkSize { end := i + proxyChunkSize if end > len(ids) { end = len(ids) } chunk := ids[i:end] if err := m.db.Where("id IN ?", chunk).Delete(&models.ProxyConfig{}).Error; err != nil { return err } } return nil } func (m *manager) bulkAdd(proxies []models.ProxyConfig) error { return m.db.CreateInBatches(proxies, proxyChunkSize).Error } func (m *manager) publishChangeEvent(reason string) { event := models.ProxyStatusChangedEvent{Action: reason} eventData, _ := json.Marshal(event) _ = m.store.Publish(models.TopicProxyStatusChanged, eventData) } func (m *manager) assignProxyIfNeeded(apiKey *models.APIKey) (*models.ProxyConfig, error) { cacheData := m.syncer.Get() if cacheData.ActiveProxies == nil { return nil, ErrNoActiveProxies } if apiKey.ProxyID != nil { if proxy, ok := cacheData.ProxiesByID[*apiKey.ProxyID]; ok { return proxy, nil } } if len(cacheData.ActiveProxies) == 0 { return nil, ErrNoActiveProxies } bestProxy := cacheData.ActiveProxies[0] txErr := m.db.Transaction(func(tx *gorm.DB) error { if err := tx.Model(apiKey).Update("proxy_id", bestProxy.ID).Error; err != nil { return err } if err := tx.Model(bestProxy).Update("assigned_keys_count", gorm.Expr("assigned_keys_count + 1")).Error; err != nil { return err } return nil }) if txErr != nil { return nil, txErr } go m.invalidate() return bestProxy, nil } func (m *manager) invalidate() error { m.logger.Info("Proxy cache invalidation triggered.") return m.syncer.Invalidate() } func (m *manager) stop() { m.syncer.Stop() } func (m *manager) CheckSingleProxy(proxyURL string, timeout time.Duration) *ProxyCheckResult { parsed := parseProxyString(proxyURL) if parsed == nil { return &ProxyCheckResult{Proxy: proxyURL, IsAvailable: false, ErrorMessage: "Invalid URL format"} } proxyCfg := &models.ProxyConfig{Protocol: parsed.Protocol, Address: parsed.Address} startTime := time.Now() isAlive := m.checkProxyConnectivity(proxyCfg, timeout) latency := time.Since(startTime).Seconds() result := &ProxyCheckResult{ Proxy: proxyURL, IsAvailable: isAlive, ResponseTime: latency, } if !isAlive { result.ErrorMessage = "Connection failed or timed out" } return result } func (m *manager) CheckMultipleProxies(proxies []string, timeout time.Duration, concurrency int) []*ProxyCheckResult { var wg sync.WaitGroup jobs := make(chan string, len(proxies)) resultsChan := make(chan *ProxyCheckResult, len(proxies)) for i := 0; i < concurrency; i++ { wg.Add(1) go func() { defer wg.Done() for proxyURL := range jobs { resultsChan <- m.CheckSingleProxy(proxyURL, timeout) } }() } for _, p := range proxies { jobs <- p } close(jobs) wg.Wait() close(resultsChan) finalResults := make([]*ProxyCheckResult, 0, len(proxies)) for res := range resultsChan { finalResults = append(finalResults, res) } return finalResults } func (m *manager) checkProxyConnectivity(proxyCfg *models.ProxyConfig, timeout time.Duration) bool { const ProxyCheckTargetURL = "https://www.google.com/generate_204" transport := &http.Transport{} switch proxyCfg.Protocol { case "http", "https": proxyUrl, err := url.Parse(proxyCfg.Protocol + "://" + proxyCfg.Address) if err != nil { return false } transport.Proxy = http.ProxyURL(proxyUrl) case "socks5": dialer, err := proxy.SOCKS5("tcp", proxyCfg.Address, nil, proxy.Direct) if err != nil { return false } transport.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) { return dialer.Dial(network, addr) } default: return false } client := &http.Client{ Transport: transport, Timeout: timeout, } resp, err := client.Get(ProxyCheckTargetURL) if err != nil { return false } defer resp.Body.Close() return true }