New
This commit is contained in:
8
internal/domain/proxy/errors.go
Normal file
8
internal/domain/proxy/errors.go
Normal file
@@ -0,0 +1,8 @@
|
||||
package proxy
|
||||
|
||||
import "errors"
|
||||
|
||||
var (
|
||||
ErrNoActiveProxies = errors.New("no active proxies available in the pool")
|
||||
ErrTaskConflict = errors.New("a sync task is already in progress for proxies")
|
||||
)
|
||||
269
internal/domain/proxy/handler.go
Normal file
269
internal/domain/proxy/handler.go
Normal file
@@ -0,0 +1,269 @@
|
||||
// Filename: internal/domain/proxy/handler.go
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"gemini-balancer/internal/errors"
|
||||
"gemini-balancer/internal/models"
|
||||
"gemini-balancer/internal/response"
|
||||
"gemini-balancer/internal/settings"
|
||||
"gemini-balancer/internal/store"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
type handler struct {
|
||||
db *gorm.DB
|
||||
manager *manager
|
||||
store store.Store
|
||||
settings *settings.SettingsManager
|
||||
}
|
||||
|
||||
func newHandler(db *gorm.DB, m *manager, s store.Store, sp *settings.SettingsManager) *handler {
|
||||
return &handler{
|
||||
db: db,
|
||||
manager: m,
|
||||
store: s,
|
||||
settings: sp,
|
||||
}
|
||||
}
|
||||
|
||||
// === 领域暴露的公共API ===
|
||||
|
||||
func (h *handler) registerRoutes(rg *gin.RouterGroup) {
|
||||
proxyRoutes := rg.Group("/proxies")
|
||||
{
|
||||
proxyRoutes.PUT("/sync", h.SyncProxies)
|
||||
proxyRoutes.POST("/check", h.CheckSingleProxy)
|
||||
proxyRoutes.POST("/check-all", h.CheckAllProxies)
|
||||
|
||||
proxyRoutes.POST("/", h.CreateProxyConfig)
|
||||
proxyRoutes.GET("/", h.ListProxyConfigs)
|
||||
proxyRoutes.GET("/:id", h.GetProxyConfig)
|
||||
proxyRoutes.PUT("/:id", h.UpdateProxyConfig)
|
||||
proxyRoutes.DELETE("/:id", h.DeleteProxyConfig)
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
// --- 请求 DTO ---
|
||||
type CreateProxyConfigRequest struct {
|
||||
Address string `json:"address" binding:"required"`
|
||||
Protocol string `json:"protocol" binding:"required,oneof=http https socks5"`
|
||||
Status string `json:"status" binding:"omitempty,oneof=active inactive"`
|
||||
Description string `json:"description"`
|
||||
}
|
||||
|
||||
type UpdateProxyConfigRequest struct {
|
||||
Address *string `json:"address"`
|
||||
Protocol *string `json:"protocol" binding:"omitempty,oneof=http https socks5"`
|
||||
Status *string `json:"status" binding:"omitempty,oneof=active inactive"`
|
||||
Description *string `json:"description"`
|
||||
}
|
||||
|
||||
// 单个检测的请求体 (与前端JS对齐)
|
||||
type CheckSingleProxyRequest struct {
|
||||
Proxy string `json:"proxy" binding:"required"`
|
||||
}
|
||||
|
||||
// 批量检测的请求体
|
||||
type CheckAllProxiesRequest struct {
|
||||
Proxies []string `json:"proxies" binding:"required"`
|
||||
}
|
||||
|
||||
// --- Handler 方法 ---
|
||||
|
||||
func (h *handler) CreateProxyConfig(c *gin.Context) {
|
||||
var req CreateProxyConfigRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.Error(c, errors.NewAPIError(errors.ErrInvalidJSON, err.Error()))
|
||||
return
|
||||
}
|
||||
|
||||
if req.Status == "" {
|
||||
req.Status = "active" // 默认状态
|
||||
}
|
||||
|
||||
proxyConfig := models.ProxyConfig{
|
||||
Address: req.Address,
|
||||
Protocol: req.Protocol,
|
||||
Status: req.Status,
|
||||
Description: req.Description,
|
||||
}
|
||||
|
||||
if err := h.db.Create(&proxyConfig).Error; err != nil {
|
||||
response.Error(c, errors.ParseDBError(err))
|
||||
return
|
||||
}
|
||||
// 写操作后,发布事件并使缓存失效
|
||||
h.publishAndInvalidate(proxyConfig.ID, "created")
|
||||
response.Created(c, proxyConfig)
|
||||
}
|
||||
|
||||
func (h *handler) ListProxyConfigs(c *gin.Context) {
|
||||
var proxyConfigs []models.ProxyConfig
|
||||
if err := h.db.Find(&proxyConfigs).Error; err != nil {
|
||||
response.Error(c, errors.ParseDBError(err))
|
||||
return
|
||||
}
|
||||
response.Success(c, proxyConfigs)
|
||||
}
|
||||
|
||||
func (h *handler) GetProxyConfig(c *gin.Context) {
|
||||
id, err := strconv.Atoi(c.Param("id"))
|
||||
if err != nil {
|
||||
response.Error(c, errors.NewAPIError(errors.ErrBadRequest, "Invalid ID format"))
|
||||
return
|
||||
}
|
||||
|
||||
var proxyConfig models.ProxyConfig
|
||||
if err := h.db.First(&proxyConfig, id).Error; err != nil {
|
||||
response.Error(c, errors.ParseDBError(err))
|
||||
return
|
||||
}
|
||||
response.Success(c, proxyConfig)
|
||||
}
|
||||
|
||||
func (h *handler) UpdateProxyConfig(c *gin.Context) {
|
||||
id, err := strconv.Atoi(c.Param("id"))
|
||||
if err != nil {
|
||||
response.Error(c, errors.NewAPIError(errors.ErrBadRequest, "Invalid ID format"))
|
||||
return
|
||||
}
|
||||
|
||||
var req UpdateProxyConfigRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.Error(c, errors.NewAPIError(errors.ErrInvalidJSON, err.Error()))
|
||||
return
|
||||
}
|
||||
|
||||
var proxyConfig models.ProxyConfig
|
||||
if err := h.db.First(&proxyConfig, id).Error; err != nil {
|
||||
response.Error(c, errors.ParseDBError(err))
|
||||
return
|
||||
}
|
||||
|
||||
if req.Address != nil {
|
||||
proxyConfig.Address = *req.Address
|
||||
}
|
||||
if req.Protocol != nil {
|
||||
proxyConfig.Protocol = *req.Protocol
|
||||
}
|
||||
if req.Status != nil {
|
||||
proxyConfig.Status = *req.Status
|
||||
}
|
||||
if req.Description != nil {
|
||||
proxyConfig.Description = *req.Description
|
||||
}
|
||||
|
||||
if err := h.db.Save(&proxyConfig).Error; err != nil {
|
||||
response.Error(c, errors.ParseDBError(err))
|
||||
return
|
||||
}
|
||||
|
||||
h.publishAndInvalidate(uint(id), "updated")
|
||||
response.Success(c, proxyConfig)
|
||||
}
|
||||
|
||||
func (h *handler) DeleteProxyConfig(c *gin.Context) {
|
||||
id, err := strconv.Atoi(c.Param("id"))
|
||||
if err != nil {
|
||||
response.Error(c, errors.NewAPIError(errors.ErrBadRequest, "Invalid ID format"))
|
||||
return
|
||||
}
|
||||
|
||||
var count int64
|
||||
if err := h.db.Model(&models.APIKey{}).Where("proxy_id = ?", id).Count(&count).Error; err != nil {
|
||||
response.Error(c, errors.ParseDBError(err))
|
||||
return
|
||||
}
|
||||
if count > 0 {
|
||||
response.Error(c, errors.NewAPIError(errors.ErrDuplicateResource, "Cannot delete proxy config that is still in use by API keys"))
|
||||
return
|
||||
}
|
||||
|
||||
result := h.db.Delete(&models.ProxyConfig{}, id)
|
||||
if result.Error != nil {
|
||||
response.Error(c, errors.ParseDBError(result.Error))
|
||||
return
|
||||
}
|
||||
if result.RowsAffected == 0 {
|
||||
response.Error(c, errors.ErrResourceNotFound)
|
||||
return
|
||||
}
|
||||
|
||||
h.publishAndInvalidate(uint(id), "deleted")
|
||||
response.NoContent(c)
|
||||
}
|
||||
|
||||
// publishAndInvalidate 统一事件发布和缓存失效逻辑
|
||||
func (h *handler) publishAndInvalidate(proxyID uint, action string) {
|
||||
go h.manager.invalidate()
|
||||
go func() {
|
||||
event := models.ProxyStatusChangedEvent{ProxyID: proxyID, Action: action}
|
||||
eventData, _ := json.Marshal(event)
|
||||
_ = h.store.Publish(models.TopicProxyStatusChanged, eventData)
|
||||
}()
|
||||
}
|
||||
|
||||
// 新的 Handler 方法和 DTO
|
||||
type SyncProxiesRequest struct {
|
||||
Proxies []string `json:"proxies"`
|
||||
}
|
||||
|
||||
func (h *handler) SyncProxies(c *gin.Context) {
|
||||
var req SyncProxiesRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.Error(c, errors.NewAPIError(errors.ErrInvalidJSON, err.Error()))
|
||||
return
|
||||
}
|
||||
taskStatus, err := h.manager.SyncProxiesInBackground(req.Proxies)
|
||||
if err != nil {
|
||||
|
||||
if errors.Is(err, ErrTaskConflict) {
|
||||
|
||||
response.Error(c, errors.NewAPIError(errors.ErrTaskInProgress, err.Error()))
|
||||
} else {
|
||||
|
||||
response.Error(c, errors.NewAPIError(errors.ErrInternalServer, err.Error()))
|
||||
}
|
||||
return
|
||||
}
|
||||
response.Success(c, gin.H{
|
||||
"message": "Proxy synchronization task started.",
|
||||
"task": taskStatus,
|
||||
})
|
||||
}
|
||||
|
||||
func (h *handler) CheckSingleProxy(c *gin.Context) {
|
||||
var req CheckSingleProxyRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.Error(c, errors.NewAPIError(errors.ErrInvalidJSON, err.Error()))
|
||||
return
|
||||
}
|
||||
|
||||
cfg := h.settings.GetSettings()
|
||||
timeout := time.Duration(cfg.ProxyCheckTimeoutSeconds) * time.Second
|
||||
result := h.manager.CheckSingleProxy(req.Proxy, timeout)
|
||||
response.Success(c, result)
|
||||
}
|
||||
|
||||
func (h *handler) CheckAllProxies(c *gin.Context) {
|
||||
var req CheckAllProxiesRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.Error(c, errors.NewAPIError(errors.ErrInvalidJSON, err.Error()))
|
||||
return
|
||||
}
|
||||
cfg := h.settings.GetSettings()
|
||||
timeout := time.Duration(cfg.ProxyCheckTimeoutSeconds) * time.Second
|
||||
|
||||
concurrency := cfg.ProxyCheckConcurrency
|
||||
if concurrency <= 0 {
|
||||
concurrency = 5 // 如果配置不合法,提供一个安全的默认值
|
||||
}
|
||||
results := h.manager.CheckMultipleProxies(req.Proxies, timeout, concurrency)
|
||||
response.Success(c, results)
|
||||
}
|
||||
315
internal/domain/proxy/manager.go
Normal file
315
internal/domain/proxy/manager.go
Normal file
@@ -0,0 +1,315 @@
|
||||
// Filename: internal/domain/proxy/manager.go
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"gemini-balancer/internal/models"
|
||||
"gemini-balancer/internal/store"
|
||||
"gemini-balancer/internal/syncer"
|
||||
"gemini-balancer/internal/task"
|
||||
|
||||
"context"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/sirupsen/logrus"
|
||||
"golang.org/x/net/proxy"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
const (
|
||||
TaskTypeProxySync = "proxy_sync"
|
||||
proxyChunkSize = 200 // 代理同步的批量大小
|
||||
)
|
||||
|
||||
type ProxyCheckResult struct {
|
||||
Proxy string `json:"proxy"`
|
||||
IsAvailable bool `json:"is_available"`
|
||||
ResponseTime float64 `json:"response_time"`
|
||||
ErrorMessage string `json:"error_message"`
|
||||
}
|
||||
|
||||
// managerCacheData
|
||||
type managerCacheData struct {
|
||||
ActiveProxies []*models.ProxyConfig
|
||||
ProxiesByID map[uint]*models.ProxyConfig
|
||||
}
|
||||
|
||||
// manager结构体
|
||||
type manager struct {
|
||||
db *gorm.DB
|
||||
syncer *syncer.CacheSyncer[managerCacheData]
|
||||
task task.Reporter
|
||||
store store.Store
|
||||
logger *logrus.Entry
|
||||
}
|
||||
|
||||
func newManagerLoader(db *gorm.DB) syncer.LoaderFunc[managerCacheData] {
|
||||
return func() (managerCacheData, error) {
|
||||
var activeProxies []*models.ProxyConfig
|
||||
if err := db.Where("status = ?", "active").Order("assigned_keys_count asc").Find(&activeProxies).Error; err != nil {
|
||||
return managerCacheData{}, fmt.Errorf("failed to load active proxies for cache: %w", err)
|
||||
}
|
||||
|
||||
proxiesByID := make(map[uint]*models.ProxyConfig, len(activeProxies))
|
||||
for _, proxy := range activeProxies {
|
||||
p := *proxy
|
||||
proxiesByID[p.ID] = &p
|
||||
}
|
||||
|
||||
return managerCacheData{
|
||||
ActiveProxies: activeProxies,
|
||||
ProxiesByID: proxiesByID,
|
||||
}, nil
|
||||
}
|
||||
}
|
||||
|
||||
func newManager(db *gorm.DB, syncer *syncer.CacheSyncer[managerCacheData], taskReporter task.Reporter, store store.Store, logger *logrus.Entry) *manager {
|
||||
return &manager{
|
||||
db: db,
|
||||
syncer: syncer,
|
||||
task: taskReporter,
|
||||
store: store,
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
func (m *manager) SyncProxiesInBackground(proxyStrings []string) (*task.Status, error) {
|
||||
resourceID := "global_proxy_sync"
|
||||
taskStatus, err := m.task.StartTask(0, TaskTypeProxySync, resourceID, len(proxyStrings), 0)
|
||||
if err != nil {
|
||||
return nil, ErrTaskConflict
|
||||
}
|
||||
go m.runProxySyncTask(taskStatus.ID, proxyStrings)
|
||||
return taskStatus, nil
|
||||
}
|
||||
|
||||
func (m *manager) runProxySyncTask(taskID string, finalProxyStrings []string) {
|
||||
resourceID := "global_proxy_sync"
|
||||
var allProxies []models.ProxyConfig
|
||||
if err := m.db.Find(&allProxies).Error; err != nil {
|
||||
m.task.EndTaskByID(taskID, resourceID, nil, fmt.Errorf("failed to fetch current proxies: %w", err))
|
||||
return
|
||||
}
|
||||
currentProxyMap := make(map[string]uint)
|
||||
for _, p := range allProxies {
|
||||
fullString := fmt.Sprintf("%s://%s", p.Protocol, p.Address)
|
||||
currentProxyMap[fullString] = p.ID
|
||||
}
|
||||
finalProxyMap := make(map[string]bool)
|
||||
for _, ps := range finalProxyStrings {
|
||||
finalProxyMap[strings.TrimSpace(ps)] = true
|
||||
}
|
||||
var idsToDelete []uint
|
||||
var proxiesToAdd []models.ProxyConfig
|
||||
for proxyString, id := range currentProxyMap {
|
||||
if !finalProxyMap[proxyString] {
|
||||
idsToDelete = append(idsToDelete, id)
|
||||
}
|
||||
}
|
||||
for proxyString := range finalProxyMap {
|
||||
if _, exists := currentProxyMap[proxyString]; !exists {
|
||||
parsed := parseProxyString(proxyString)
|
||||
if parsed != nil {
|
||||
proxiesToAdd = append(proxiesToAdd, models.ProxyConfig{
|
||||
Protocol: parsed.Protocol, Address: parsed.Address, Status: "active",
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
if len(idsToDelete) > 0 {
|
||||
if err := m.bulkDeleteByIDs(idsToDelete); err != nil {
|
||||
m.task.EndTaskByID(taskID, resourceID, nil, fmt.Errorf("failed during proxy deletion: %w", err))
|
||||
return
|
||||
}
|
||||
}
|
||||
if len(proxiesToAdd) > 0 {
|
||||
if err := m.bulkAdd(proxiesToAdd); err != nil {
|
||||
m.task.EndTaskByID(taskID, resourceID, nil, fmt.Errorf("failed during proxy addition: %w", err))
|
||||
return
|
||||
}
|
||||
}
|
||||
result := gin.H{"added": len(proxiesToAdd), "deleted": len(idsToDelete), "final_total": len(finalProxyMap)}
|
||||
m.task.EndTaskByID(taskID, resourceID, result, nil)
|
||||
m.publishChangeEvent("proxies_synced")
|
||||
go m.invalidate()
|
||||
}
|
||||
|
||||
type parsedProxy struct{ Protocol, Address string }
|
||||
|
||||
func parseProxyString(proxyStr string) *parsedProxy {
|
||||
proxyStr = strings.TrimSpace(proxyStr)
|
||||
u, err := url.Parse(proxyStr)
|
||||
if err != nil || !strings.Contains(proxyStr, "://") {
|
||||
if strings.Contains(proxyStr, "@") {
|
||||
parts := strings.Split(proxyStr, "@")
|
||||
if len(parts) == 2 {
|
||||
proxyStr = "socks5://" + proxyStr
|
||||
u, err = url.Parse(proxyStr)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
} else {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
protocol := strings.ToLower(u.Scheme)
|
||||
if protocol != "http" && protocol != "https" && protocol != "socks5" {
|
||||
return nil
|
||||
}
|
||||
address := u.Host
|
||||
if u.User != nil {
|
||||
address = u.User.String() + "@" + u.Host
|
||||
}
|
||||
return &parsedProxy{Protocol: protocol, Address: address}
|
||||
}
|
||||
|
||||
func (m *manager) bulkDeleteByIDs(ids []uint) error {
|
||||
for i := 0; i < len(ids); i += proxyChunkSize {
|
||||
end := i + proxyChunkSize
|
||||
if end > len(ids) {
|
||||
end = len(ids)
|
||||
}
|
||||
chunk := ids[i:end]
|
||||
if err := m.db.Where("id IN ?", chunk).Delete(&models.ProxyConfig{}).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
func (m *manager) bulkAdd(proxies []models.ProxyConfig) error {
|
||||
return m.db.CreateInBatches(proxies, proxyChunkSize).Error
|
||||
}
|
||||
|
||||
func (m *manager) publishChangeEvent(reason string) {
|
||||
event := models.ProxyStatusChangedEvent{Action: reason}
|
||||
eventData, _ := json.Marshal(event)
|
||||
_ = m.store.Publish(models.TopicProxyStatusChanged, eventData)
|
||||
}
|
||||
|
||||
func (m *manager) assignProxyIfNeeded(apiKey *models.APIKey) (*models.ProxyConfig, error) {
|
||||
cacheData := m.syncer.Get()
|
||||
if cacheData.ActiveProxies == nil {
|
||||
return nil, ErrNoActiveProxies
|
||||
}
|
||||
if apiKey.ProxyID != nil {
|
||||
if proxy, ok := cacheData.ProxiesByID[*apiKey.ProxyID]; ok {
|
||||
return proxy, nil
|
||||
}
|
||||
}
|
||||
if len(cacheData.ActiveProxies) == 0 {
|
||||
return nil, ErrNoActiveProxies
|
||||
}
|
||||
bestProxy := cacheData.ActiveProxies[0]
|
||||
txErr := m.db.Transaction(func(tx *gorm.DB) error {
|
||||
if err := tx.Model(apiKey).Update("proxy_id", bestProxy.ID).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
if err := tx.Model(bestProxy).Update("assigned_keys_count", gorm.Expr("assigned_keys_count + 1")).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
})
|
||||
if txErr != nil {
|
||||
return nil, txErr
|
||||
}
|
||||
go m.invalidate()
|
||||
return bestProxy, nil
|
||||
}
|
||||
|
||||
func (m *manager) invalidate() error {
|
||||
m.logger.Info("Proxy cache invalidation triggered.")
|
||||
return m.syncer.Invalidate()
|
||||
}
|
||||
|
||||
func (m *manager) stop() {
|
||||
m.syncer.Stop()
|
||||
}
|
||||
|
||||
func (m *manager) CheckSingleProxy(proxyURL string, timeout time.Duration) *ProxyCheckResult {
|
||||
parsed := parseProxyString(proxyURL)
|
||||
if parsed == nil {
|
||||
return &ProxyCheckResult{Proxy: proxyURL, IsAvailable: false, ErrorMessage: "Invalid URL format"}
|
||||
}
|
||||
|
||||
proxyCfg := &models.ProxyConfig{Protocol: parsed.Protocol, Address: parsed.Address}
|
||||
|
||||
startTime := time.Now()
|
||||
isAlive := m.checkProxyConnectivity(proxyCfg, timeout)
|
||||
latency := time.Since(startTime).Seconds()
|
||||
result := &ProxyCheckResult{
|
||||
Proxy: proxyURL,
|
||||
IsAvailable: isAlive,
|
||||
ResponseTime: latency,
|
||||
}
|
||||
if !isAlive {
|
||||
result.ErrorMessage = "Connection failed or timed out"
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
func (m *manager) CheckMultipleProxies(proxies []string, timeout time.Duration, concurrency int) []*ProxyCheckResult {
|
||||
var wg sync.WaitGroup
|
||||
jobs := make(chan string, len(proxies))
|
||||
resultsChan := make(chan *ProxyCheckResult, len(proxies))
|
||||
for i := 0; i < concurrency; i++ {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
for proxyURL := range jobs {
|
||||
resultsChan <- m.CheckSingleProxy(proxyURL, timeout)
|
||||
}
|
||||
}()
|
||||
}
|
||||
for _, p := range proxies {
|
||||
jobs <- p
|
||||
}
|
||||
close(jobs)
|
||||
wg.Wait()
|
||||
close(resultsChan)
|
||||
finalResults := make([]*ProxyCheckResult, 0, len(proxies))
|
||||
for res := range resultsChan {
|
||||
finalResults = append(finalResults, res)
|
||||
}
|
||||
return finalResults
|
||||
}
|
||||
|
||||
func (m *manager) checkProxyConnectivity(proxyCfg *models.ProxyConfig, timeout time.Duration) bool {
|
||||
const ProxyCheckTargetURL = "https://www.google.com/generate_204"
|
||||
transport := &http.Transport{}
|
||||
switch proxyCfg.Protocol {
|
||||
case "http", "https":
|
||||
proxyUrl, err := url.Parse(proxyCfg.Protocol + "://" + proxyCfg.Address)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
transport.Proxy = http.ProxyURL(proxyUrl)
|
||||
case "socks5":
|
||||
dialer, err := proxy.SOCKS5("tcp", proxyCfg.Address, nil, proxy.Direct)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
transport.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||
return dialer.Dial(network, addr)
|
||||
}
|
||||
default:
|
||||
return false
|
||||
}
|
||||
client := &http.Client{
|
||||
Transport: transport,
|
||||
Timeout: timeout,
|
||||
}
|
||||
resp, err := client.Get(ProxyCheckTargetURL)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
return true
|
||||
}
|
||||
45
internal/domain/proxy/module.go
Normal file
45
internal/domain/proxy/module.go
Normal file
@@ -0,0 +1,45 @@
|
||||
// Filename: internal/domain/proxy/module.go
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"gemini-balancer/internal/models"
|
||||
"gemini-balancer/internal/settings"
|
||||
"gemini-balancer/internal/store"
|
||||
"gemini-balancer/internal/syncer"
|
||||
"gemini-balancer/internal/task"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/sirupsen/logrus"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
type Module struct {
|
||||
manager *manager
|
||||
handler *handler
|
||||
}
|
||||
|
||||
func NewModule(gormDB *gorm.DB, store store.Store, settingsManager *settings.SettingsManager, taskReporter task.Reporter, logger *logrus.Logger) (*Module, error) {
|
||||
loader := newManagerLoader(gormDB)
|
||||
cacheSyncer, err := syncer.NewCacheSyncer(loader, store, "proxies:cache_invalidation")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
manager := newManager(gormDB, cacheSyncer, taskReporter, store, logger.WithField("domain", "proxy"))
|
||||
handler := newHandler(gormDB, manager, store, settingsManager)
|
||||
return &Module{
|
||||
manager: manager,
|
||||
handler: handler,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (m *Module) AssignProxyIfNeeded(apiKey *models.APIKey) (*models.ProxyConfig, error) {
|
||||
return m.manager.assignProxyIfNeeded(apiKey)
|
||||
}
|
||||
|
||||
func (m *Module) RegisterRoutes(router *gin.RouterGroup) {
|
||||
m.handler.registerRoutes(router)
|
||||
}
|
||||
|
||||
func (m *Module) Stop() {
|
||||
m.manager.stop()
|
||||
}
|
||||
167
internal/domain/upstream/handler.go
Normal file
167
internal/domain/upstream/handler.go
Normal file
@@ -0,0 +1,167 @@
|
||||
// Filename: internal/domain/upstream/handler.go
|
||||
package upstream
|
||||
|
||||
import (
|
||||
"gemini-balancer/internal/errors"
|
||||
"gemini-balancer/internal/models"
|
||||
"gemini-balancer/internal/response"
|
||||
"net/url"
|
||||
"strconv"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
type Handler struct {
|
||||
service *Service
|
||||
}
|
||||
|
||||
func NewHandler(service *Service) *Handler {
|
||||
return &Handler{service: service}
|
||||
}
|
||||
|
||||
// ------ DTOs and Validation ------
|
||||
type CreateUpstreamRequest struct {
|
||||
URL string `json:"url" binding:"required"`
|
||||
Weight int `json:"weight" binding:"omitempty,gte=1,lte=1000"`
|
||||
Status string `json:"status" binding:"omitempty,oneof=active inactive"`
|
||||
Description string `json:"description"`
|
||||
}
|
||||
type UpdateUpstreamRequest struct {
|
||||
URL *string `json:"url"`
|
||||
Weight *int `json:"weight" binding:"omitempty,gte=1,lte=1000"`
|
||||
Status *string `json:"status" binding:"omitempty,oneof=active inactive"`
|
||||
Description *string `json:"description"`
|
||||
}
|
||||
|
||||
func isValidURL(rawURL string) bool {
|
||||
u, err := url.ParseRequestURI(rawURL)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
return u.Scheme == "http" || u.Scheme == "https"
|
||||
}
|
||||
|
||||
// --- Handler ---
|
||||
|
||||
func (h *Handler) CreateUpstream(c *gin.Context) {
|
||||
var req CreateUpstreamRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.Error(c, errors.NewAPIError(errors.ErrInvalidJSON, err.Error()))
|
||||
return
|
||||
}
|
||||
if !isValidURL(req.URL) {
|
||||
response.Error(c, errors.NewAPIError(errors.ErrValidation, "Invalid URL format"))
|
||||
return
|
||||
}
|
||||
|
||||
upstream := models.UpstreamEndpoint{
|
||||
URL: req.URL,
|
||||
Weight: req.Weight,
|
||||
Status: req.Status,
|
||||
Description: req.Description,
|
||||
}
|
||||
|
||||
if err := h.service.Create(&upstream); err != nil {
|
||||
response.Error(c, errors.ParseDBError(err))
|
||||
return
|
||||
}
|
||||
response.Created(c, upstream)
|
||||
}
|
||||
|
||||
func (h *Handler) ListUpstreams(c *gin.Context) {
|
||||
upstreams, err := h.service.List()
|
||||
if err != nil {
|
||||
response.Error(c, errors.ParseDBError(err))
|
||||
return
|
||||
}
|
||||
response.Success(c, upstreams)
|
||||
}
|
||||
|
||||
func (h *Handler) GetUpstream(c *gin.Context) {
|
||||
id, err := strconv.Atoi(c.Param("id"))
|
||||
if err != nil {
|
||||
response.Error(c, errors.NewAPIError(errors.ErrBadRequest, "Invalid ID format"))
|
||||
return
|
||||
}
|
||||
|
||||
upstream, err := h.service.GetByID(id)
|
||||
if err != nil {
|
||||
response.Error(c, errors.ParseDBError(err))
|
||||
return
|
||||
}
|
||||
response.Success(c, upstream)
|
||||
}
|
||||
|
||||
func (h *Handler) UpdateUpstream(c *gin.Context) {
|
||||
id, err := strconv.Atoi(c.Param("id"))
|
||||
if err != nil {
|
||||
response.Error(c, errors.NewAPIError(errors.ErrBadRequest, "Invalid ID format"))
|
||||
return
|
||||
}
|
||||
var req UpdateUpstreamRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.Error(c, errors.NewAPIError(errors.ErrInvalidJSON, err.Error()))
|
||||
return
|
||||
}
|
||||
|
||||
upstream, err := h.service.GetByID(id)
|
||||
if err != nil {
|
||||
response.Error(c, errors.ParseDBError(err))
|
||||
return
|
||||
}
|
||||
|
||||
if req.URL != nil {
|
||||
if !isValidURL(*req.URL) {
|
||||
response.Error(c, errors.NewAPIError(errors.ErrValidation, "Invalid URL format"))
|
||||
return
|
||||
}
|
||||
upstream.URL = *req.URL
|
||||
}
|
||||
if req.Weight != nil {
|
||||
upstream.Weight = *req.Weight
|
||||
}
|
||||
if req.Status != nil {
|
||||
upstream.Status = *req.Status
|
||||
}
|
||||
if req.Description != nil {
|
||||
upstream.Description = *req.Description
|
||||
}
|
||||
|
||||
if err := h.service.Update(upstream); err != nil {
|
||||
response.Error(c, errors.ParseDBError(err))
|
||||
return
|
||||
}
|
||||
response.Success(c, upstream)
|
||||
}
|
||||
|
||||
func (h *Handler) DeleteUpstream(c *gin.Context) {
|
||||
id, err := strconv.Atoi(c.Param("id"))
|
||||
if err != nil {
|
||||
response.Error(c, errors.NewAPIError(errors.ErrBadRequest, "Invalid ID format"))
|
||||
return
|
||||
}
|
||||
|
||||
rowsAffected, err := h.service.Delete(id)
|
||||
if err != nil {
|
||||
response.Error(c, errors.ParseDBError(err))
|
||||
return
|
||||
}
|
||||
if rowsAffected == 0 {
|
||||
response.Error(c, errors.ErrResourceNotFound)
|
||||
return
|
||||
}
|
||||
response.NoContent(c)
|
||||
}
|
||||
|
||||
// RegisterRoutes
|
||||
|
||||
func (h *Handler) RegisterRoutes(rg *gin.RouterGroup) {
|
||||
upstreamRoutes := rg.Group("/upstreams")
|
||||
{
|
||||
upstreamRoutes.POST("/", h.CreateUpstream)
|
||||
upstreamRoutes.GET("/", h.ListUpstreams)
|
||||
upstreamRoutes.GET("/:id", h.GetUpstream)
|
||||
upstreamRoutes.PUT("/:id", h.UpdateUpstream)
|
||||
upstreamRoutes.DELETE("/:id", h.DeleteUpstream)
|
||||
}
|
||||
}
|
||||
36
internal/domain/upstream/module.go
Normal file
36
internal/domain/upstream/module.go
Normal file
@@ -0,0 +1,36 @@
|
||||
// Filename: internal/domain/upstream/module.go
|
||||
package upstream
|
||||
|
||||
import (
|
||||
"gemini-balancer/internal/models"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
type Module struct {
|
||||
service *Service
|
||||
handler *Handler
|
||||
}
|
||||
|
||||
func NewModule(db *gorm.DB) *Module {
|
||||
service := NewService(db)
|
||||
handler := NewHandler(service)
|
||||
|
||||
return &Module{
|
||||
service: service,
|
||||
handler: handler,
|
||||
}
|
||||
}
|
||||
|
||||
// === 领域暴露的公共API ===
|
||||
|
||||
// SelectActiveWeighted
|
||||
|
||||
func (m *Module) SelectActiveWeighted(upstreams []*models.UpstreamEndpoint) (*models.UpstreamEndpoint, error) {
|
||||
return m.service.SelectActiveWeighted(upstreams)
|
||||
}
|
||||
|
||||
func (m *Module) RegisterRoutes(router *gin.RouterGroup) {
|
||||
m.handler.RegisterRoutes(router)
|
||||
}
|
||||
84
internal/domain/upstream/service.go
Normal file
84
internal/domain/upstream/service.go
Normal file
@@ -0,0 +1,84 @@
|
||||
// Filename: internal/domain/upstream/service.go
|
||||
package upstream
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"gemini-balancer/internal/models"
|
||||
"math/rand"
|
||||
"time"
|
||||
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
type Service struct {
|
||||
db *gorm.DB
|
||||
}
|
||||
|
||||
func NewService(db *gorm.DB) *Service {
|
||||
rand.New(rand.NewSource(time.Now().UnixNano()))
|
||||
return &Service{db: db}
|
||||
}
|
||||
|
||||
func (s *Service) SelectActiveWeighted(upstreams []*models.UpstreamEndpoint) (*models.UpstreamEndpoint, error) {
|
||||
activeUpstreams := make([]*models.UpstreamEndpoint, 0)
|
||||
totalWeight := 0
|
||||
for _, u := range upstreams {
|
||||
if u.Status == "active" {
|
||||
activeUpstreams = append(activeUpstreams, u)
|
||||
totalWeight += u.Weight
|
||||
}
|
||||
}
|
||||
if len(activeUpstreams) == 0 {
|
||||
return nil, errors.New("no active upstream endpoints available")
|
||||
}
|
||||
if totalWeight <= 0 || len(activeUpstreams) == 1 {
|
||||
return activeUpstreams[0], nil
|
||||
}
|
||||
randomWeight := rand.Intn(totalWeight)
|
||||
for _, u := range activeUpstreams {
|
||||
randomWeight -= u.Weight
|
||||
if randomWeight < 0 {
|
||||
return u, nil
|
||||
}
|
||||
}
|
||||
return activeUpstreams[len(activeUpstreams)-1], nil
|
||||
}
|
||||
|
||||
// CRUD,供Handler调用
|
||||
|
||||
func (s *Service) Create(upstream *models.UpstreamEndpoint) error {
|
||||
if upstream.Weight == 0 {
|
||||
upstream.Weight = 100 // 默认权重
|
||||
}
|
||||
if upstream.Status == "" {
|
||||
upstream.Status = "active" // 默认状态
|
||||
}
|
||||
return s.db.Create(upstream).Error
|
||||
}
|
||||
|
||||
// List Service层只做数据库查询
|
||||
func (s *Service) List() ([]models.UpstreamEndpoint, error) {
|
||||
var upstreams []models.UpstreamEndpoint
|
||||
err := s.db.Find(&upstreams).Error
|
||||
return upstreams, err
|
||||
}
|
||||
|
||||
// GetByID Service层只做数据库查询
|
||||
func (s *Service) GetByID(id int) (*models.UpstreamEndpoint, error) {
|
||||
var upstream models.UpstreamEndpoint
|
||||
if err := s.db.First(&upstream, id).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &upstream, nil
|
||||
}
|
||||
|
||||
// Update Service层只做数据库更新
|
||||
func (s *Service) Update(upstream *models.UpstreamEndpoint) error {
|
||||
return s.db.Save(upstream).Error
|
||||
}
|
||||
|
||||
// Delete Service层只做数据库删除
|
||||
func (s *Service) Delete(id int) (int64, error) {
|
||||
result := s.db.Delete(&models.UpstreamEndpoint{}, id)
|
||||
return result.RowsAffected, result.Error
|
||||
}
|
||||
Reference in New Issue
Block a user