314 lines
8.5 KiB
Go
314 lines
8.5 KiB
Go
// Filename: internal/domain/proxy/manager.go
|
|
package proxy
|
|
|
|
import (
|
|
"context"
|
|
"encoding/json"
|
|
"fmt"
|
|
"gemini-balancer/internal/models"
|
|
"gemini-balancer/internal/store"
|
|
"gemini-balancer/internal/syncer"
|
|
"gemini-balancer/internal/task"
|
|
"net"
|
|
"net/http"
|
|
"net/url"
|
|
"strings"
|
|
"sync"
|
|
"time"
|
|
|
|
"github.com/gin-gonic/gin"
|
|
"github.com/sirupsen/logrus"
|
|
"golang.org/x/net/proxy"
|
|
"gorm.io/gorm"
|
|
)
|
|
|
|
// Identifiers and tuning values shared by the proxy manager.
const (
	// TaskTypeProxySync is the task type handed to the task reporter when a
	// background proxy synchronization is started.
	TaskTypeProxySync = "proxy_sync"

	// proxyChunkSize is the batch size used for chunked deletes and for
	// batched inserts of proxy rows.
	proxyChunkSize = 200
)
|
|
|
|
// ProxyCheckResult is the outcome of a connectivity check for one proxy.
type ProxyCheckResult struct {
	// Proxy is the proxy string exactly as it was submitted for checking.
	Proxy string `json:"proxy"`
	// IsAvailable reports whether the connectivity check succeeded.
	IsAvailable bool `json:"is_available"`
	// ResponseTime is the wall-clock duration of the check, in seconds.
	ResponseTime float64 `json:"response_time"`
	// ErrorMessage describes why the check failed; it is empty on success.
	ErrorMessage string `json:"error_message"`
}
|
|
|
|
type managerCacheData struct {
|
|
ActiveProxies []*models.ProxyConfig
|
|
ProxiesByID map[uint]*models.ProxyConfig
|
|
}
|
|
|
|
type manager struct {
|
|
db *gorm.DB
|
|
syncer *syncer.CacheSyncer[managerCacheData]
|
|
task task.Reporter
|
|
store store.Store
|
|
logger *logrus.Entry
|
|
}
|
|
|
|
func newManagerLoader(db *gorm.DB) syncer.LoaderFunc[managerCacheData] {
|
|
return func() (managerCacheData, error) {
|
|
var activeProxies []*models.ProxyConfig
|
|
if err := db.Where("status = ?", "active").Order("assigned_keys_count asc").Find(&activeProxies).Error; err != nil {
|
|
return managerCacheData{}, fmt.Errorf("failed to load active proxies for cache: %w", err)
|
|
}
|
|
|
|
proxiesByID := make(map[uint]*models.ProxyConfig, len(activeProxies))
|
|
for _, proxy := range activeProxies {
|
|
p := *proxy
|
|
proxiesByID[p.ID] = &p
|
|
}
|
|
|
|
return managerCacheData{
|
|
ActiveProxies: activeProxies,
|
|
ProxiesByID: proxiesByID,
|
|
}, nil
|
|
}
|
|
}
|
|
|
|
func newManager(db *gorm.DB, syncer *syncer.CacheSyncer[managerCacheData], taskReporter task.Reporter, store store.Store, logger *logrus.Entry) *manager {
|
|
return &manager{
|
|
db: db,
|
|
syncer: syncer,
|
|
task: taskReporter,
|
|
store: store,
|
|
logger: logger,
|
|
}
|
|
}
|
|
|
|
func (m *manager) SyncProxiesInBackground(ctx context.Context, proxyStrings []string) (*task.Status, error) {
|
|
resourceID := "global_proxy_sync"
|
|
taskStatus, err := m.task.StartTask(ctx, 0, TaskTypeProxySync, resourceID, len(proxyStrings), 0)
|
|
if err != nil {
|
|
return nil, ErrTaskConflict
|
|
}
|
|
go m.runProxySyncTask(context.Background(), taskStatus.ID, proxyStrings)
|
|
return taskStatus, nil
|
|
}
|
|
|
|
func (m *manager) runProxySyncTask(ctx context.Context, taskID string, finalProxyStrings []string) {
|
|
resourceID := "global_proxy_sync"
|
|
var allProxies []models.ProxyConfig
|
|
if err := m.db.Find(&allProxies).Error; err != nil {
|
|
m.task.EndTaskByID(ctx, taskID, resourceID, nil, fmt.Errorf("failed to fetch current proxies: %w", err))
|
|
return
|
|
}
|
|
currentProxyMap := make(map[string]uint)
|
|
for _, p := range allProxies {
|
|
fullString := fmt.Sprintf("%s://%s", p.Protocol, p.Address)
|
|
currentProxyMap[fullString] = p.ID
|
|
}
|
|
finalProxyMap := make(map[string]bool)
|
|
for _, ps := range finalProxyStrings {
|
|
finalProxyMap[strings.TrimSpace(ps)] = true
|
|
}
|
|
var idsToDelete []uint
|
|
var proxiesToAdd []models.ProxyConfig
|
|
for proxyString, id := range currentProxyMap {
|
|
if !finalProxyMap[proxyString] {
|
|
idsToDelete = append(idsToDelete, id)
|
|
}
|
|
}
|
|
for proxyString := range finalProxyMap {
|
|
if _, exists := currentProxyMap[proxyString]; !exists {
|
|
parsed := parseProxyString(proxyString)
|
|
if parsed != nil {
|
|
proxiesToAdd = append(proxiesToAdd, models.ProxyConfig{
|
|
Protocol: parsed.Protocol, Address: parsed.Address, Status: "active",
|
|
})
|
|
}
|
|
}
|
|
}
|
|
if len(idsToDelete) > 0 {
|
|
if err := m.bulkDeleteByIDs(idsToDelete); err != nil {
|
|
m.task.EndTaskByID(ctx, taskID, resourceID, nil, fmt.Errorf("failed during proxy deletion: %w", err))
|
|
return
|
|
}
|
|
}
|
|
if len(proxiesToAdd) > 0 {
|
|
if err := m.bulkAdd(proxiesToAdd); err != nil {
|
|
m.task.EndTaskByID(ctx, taskID, resourceID, nil, fmt.Errorf("failed during proxy addition: %w", err))
|
|
return
|
|
}
|
|
}
|
|
result := gin.H{"added": len(proxiesToAdd), "deleted": len(idsToDelete), "final_total": len(finalProxyMap)}
|
|
m.task.EndTaskByID(ctx, taskID, resourceID, result, nil)
|
|
m.publishChangeEvent(ctx, "proxies_synced")
|
|
go m.invalidate()
|
|
}
|
|
|
|
// parsedProxy is the normalized form of a proxy string: a lower-case protocol
// and an address of the form "host[:port]" or "userinfo@host[:port]".
type parsedProxy struct {
	Protocol string
	Address  string
}

// parseProxyString normalizes a user-supplied proxy string.
//
// Accepted inputs are URLs with scheme http, https or socks5 (the scheme is
// matched case-insensitively) and the scheme-less shorthand
// "user:pass@host:port", which is treated as a socks5 proxy. Surrounding
// whitespace is ignored.
//
// It returns nil when the input cannot be parsed, uses another scheme, or has
// no host. Inputs that previously caused a nil-pointer dereference (a parse
// failure combined with more than one "@") now return nil as well.
func parseProxyString(proxyStr string) *parsedProxy {
	proxyStr = strings.TrimSpace(proxyStr)

	if !strings.Contains(proxyStr, "://") {
		// Only the "userinfo@host" shorthand is accepted without a scheme;
		// exactly one "@" keeps the split between credentials and host
		// unambiguous.
		if strings.Count(proxyStr, "@") != 1 {
			return nil
		}
		proxyStr = "socks5://" + proxyStr
	}

	u, err := url.Parse(proxyStr)
	if err != nil {
		return nil
	}

	protocol := strings.ToLower(u.Scheme)
	switch protocol {
	case "http", "https", "socks5":
	default:
		return nil
	}

	// A proxy without a host cannot be dialed.
	if u.Host == "" {
		return nil
	}

	address := u.Host
	if u.User != nil {
		address = u.User.String() + "@" + u.Host
	}
	return &parsedProxy{Protocol: protocol, Address: address}
}
|
|
|
|
func (m *manager) bulkDeleteByIDs(ids []uint) error {
|
|
for i := 0; i < len(ids); i += proxyChunkSize {
|
|
end := i + proxyChunkSize
|
|
if end > len(ids) {
|
|
end = len(ids)
|
|
}
|
|
chunk := ids[i:end]
|
|
if err := m.db.Where("id IN ?", chunk).Delete(&models.ProxyConfig{}).Error; err != nil {
|
|
return err
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (m *manager) bulkAdd(proxies []models.ProxyConfig) error {
|
|
return m.db.CreateInBatches(proxies, proxyChunkSize).Error
|
|
}
|
|
|
|
func (m *manager) publishChangeEvent(ctx context.Context, reason string) {
|
|
event := models.ProxyStatusChangedEvent{Action: reason}
|
|
eventData, _ := json.Marshal(event)
|
|
_ = m.store.Publish(ctx, models.TopicProxyStatusChanged, eventData)
|
|
}
|
|
|
|
func (m *manager) assignProxyIfNeeded(apiKey *models.APIKey) (*models.ProxyConfig, error) {
|
|
cacheData := m.syncer.Get()
|
|
if cacheData.ActiveProxies == nil {
|
|
return nil, ErrNoActiveProxies
|
|
}
|
|
if apiKey.ProxyID != nil {
|
|
if proxy, ok := cacheData.ProxiesByID[*apiKey.ProxyID]; ok {
|
|
return proxy, nil
|
|
}
|
|
}
|
|
if len(cacheData.ActiveProxies) == 0 {
|
|
return nil, ErrNoActiveProxies
|
|
}
|
|
bestProxy := cacheData.ActiveProxies[0]
|
|
txErr := m.db.Transaction(func(tx *gorm.DB) error {
|
|
if err := tx.Model(apiKey).Update("proxy_id", bestProxy.ID).Error; err != nil {
|
|
return err
|
|
}
|
|
if err := tx.Model(bestProxy).Update("assigned_keys_count", gorm.Expr("assigned_keys_count + 1")).Error; err != nil {
|
|
return err
|
|
}
|
|
return nil
|
|
})
|
|
if txErr != nil {
|
|
return nil, txErr
|
|
}
|
|
go m.invalidate()
|
|
return bestProxy, nil
|
|
}
|
|
|
|
func (m *manager) invalidate() error {
|
|
m.logger.Info("Proxy cache invalidation triggered.")
|
|
return m.syncer.Invalidate()
|
|
}
|
|
|
|
func (m *manager) stop() {
|
|
m.syncer.Stop()
|
|
}
|
|
|
|
func (m *manager) CheckSingleProxy(proxyURL string, timeout time.Duration) *ProxyCheckResult {
|
|
parsed := parseProxyString(proxyURL)
|
|
if parsed == nil {
|
|
return &ProxyCheckResult{Proxy: proxyURL, IsAvailable: false, ErrorMessage: "Invalid URL format"}
|
|
}
|
|
|
|
proxyCfg := &models.ProxyConfig{Protocol: parsed.Protocol, Address: parsed.Address}
|
|
|
|
startTime := time.Now()
|
|
isAlive := m.checkProxyConnectivity(proxyCfg, timeout)
|
|
latency := time.Since(startTime).Seconds()
|
|
result := &ProxyCheckResult{
|
|
Proxy: proxyURL,
|
|
IsAvailable: isAlive,
|
|
ResponseTime: latency,
|
|
}
|
|
if !isAlive {
|
|
result.ErrorMessage = "Connection failed or timed out"
|
|
}
|
|
return result
|
|
}
|
|
|
|
func (m *manager) CheckMultipleProxies(proxies []string, timeout time.Duration, concurrency int) []*ProxyCheckResult {
|
|
var wg sync.WaitGroup
|
|
jobs := make(chan string, len(proxies))
|
|
resultsChan := make(chan *ProxyCheckResult, len(proxies))
|
|
for i := 0; i < concurrency; i++ {
|
|
wg.Add(1)
|
|
go func() {
|
|
defer wg.Done()
|
|
for proxyURL := range jobs {
|
|
resultsChan <- m.CheckSingleProxy(proxyURL, timeout)
|
|
}
|
|
}()
|
|
}
|
|
for _, p := range proxies {
|
|
jobs <- p
|
|
}
|
|
close(jobs)
|
|
wg.Wait()
|
|
close(resultsChan)
|
|
finalResults := make([]*ProxyCheckResult, 0, len(proxies))
|
|
for res := range resultsChan {
|
|
finalResults = append(finalResults, res)
|
|
}
|
|
return finalResults
|
|
}
|
|
|
|
func (m *manager) checkProxyConnectivity(proxyCfg *models.ProxyConfig, timeout time.Duration) bool {
|
|
const ProxyCheckTargetURL = "https://www.google.com/generate_204"
|
|
transport := &http.Transport{}
|
|
switch proxyCfg.Protocol {
|
|
case "http", "https":
|
|
proxyUrl, err := url.Parse(proxyCfg.Protocol + "://" + proxyCfg.Address)
|
|
if err != nil {
|
|
return false
|
|
}
|
|
transport.Proxy = http.ProxyURL(proxyUrl)
|
|
case "socks5":
|
|
dialer, err := proxy.SOCKS5("tcp", proxyCfg.Address, nil, proxy.Direct)
|
|
if err != nil {
|
|
return false
|
|
}
|
|
transport.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) {
|
|
return dialer.Dial(network, addr)
|
|
}
|
|
default:
|
|
return false
|
|
}
|
|
client := &http.Client{
|
|
Transport: transport,
|
|
Timeout: timeout,
|
|
}
|
|
resp, err := client.Get(ProxyCheckTargetURL)
|
|
if err != nil {
|
|
return false
|
|
}
|
|
defer resp.Body.Close()
|
|
return true
|
|
}
|