internal/app/app.go (new file, 142 lines)
@@ -0,0 +1,142 @@
// Filename: internal/app/app.go
package app

import (
	"context"
	"fmt"
	"gemini-balancer/internal/config"
	"gemini-balancer/internal/crypto"
	"gemini-balancer/internal/db/migrations"
	"gemini-balancer/internal/db/seeder"
	"gemini-balancer/internal/scheduler"
	"gemini-balancer/internal/service"
	"gemini-balancer/internal/settings"
	"net/http"
	"os"
	"os/signal"
	"syscall"
	"time"

	"github.com/gin-gonic/gin"
	"github.com/sirupsen/logrus"
	"gorm.io/gorm"
)

// App
type App struct {
	Config        *config.Config
	Router        *gin.Engine
	DB            *gorm.DB
	Logger        *logrus.Logger
	CryptoService *crypto.Service

	// Background services with independent lifecycles
	ResourceService    *service.ResourceService
	APIKeyService      *service.APIKeyService
	DBLogWriter        *service.DBLogWriterService
	AnalyticsService   *service.AnalyticsService
	HealthCheckService *service.HealthCheckService
	SettingsManager    *settings.SettingsManager
	GroupManager       *service.GroupManager
	TokenManager       *service.TokenManager
	Scheduler          *scheduler.Scheduler
}

// NewApp
func NewApp(
	cfg *config.Config,
	router *gin.Engine,
	db *gorm.DB,
	logger *logrus.Logger,
	cryptoService *crypto.Service,
	resourceService *service.ResourceService,
	apiKeyService *service.APIKeyService,
	dbLogWriter *service.DBLogWriterService,
	analyticsService *service.AnalyticsService,
	healthCheckService *service.HealthCheckService,
	settingsManager *settings.SettingsManager,
	groupManager *service.GroupManager,
	tokenManager *service.TokenManager,
	scheduler *scheduler.Scheduler,
) *App {
	return &App{
		Config:             cfg,
		Router:             router,
		DB:                 db,
		Logger:             logger,
		CryptoService:      cryptoService,
		ResourceService:    resourceService,
		APIKeyService:      apiKeyService,
		DBLogWriter:        dbLogWriter,
		AnalyticsService:   analyticsService,
		HealthCheckService: healthCheckService,
		SettingsManager:    settingsManager,
		GroupManager:       groupManager,
		TokenManager:       tokenManager,
		Scheduler:          scheduler,
	}
}

// Run: the startup flow is now actively orchestrated by the App rather than passively accepted.
func (a *App) Run() error {
	// --- Phase 1: database setup ---
	a.Logger.Info("* [SYSTEM] * Preparing: Database Setup ...")

	// Step 1: run the base migrations so all tables exist.
	if err := migrations.RunMigrations(a.DB, a.Logger); err != nil {
		return fmt.Errorf("initial database migration failed: %w", err)
	}

	// Step 2: run all versioned data migrations.
	if err := migrations.RunVersionedMigrations(a.DB, a.Config, a.Logger); err != nil {
		return fmt.Errorf("failed to run versioned migrations: %w", err)
	}

	// Step 3: seed the database.
	seeder.RunSeeder(a.DB, a.CryptoService, a.Logger)
	a.Logger.Info("* [SYSTEM] * All Utils READY. ---")

	// --- Phase 2: start background services ---
	a.Logger.Info("* [SYSTEM] * Starting main: Background Services ")
	a.APIKeyService.Start()
	a.DBLogWriter.Start()
	a.AnalyticsService.Start()
	a.HealthCheckService.Start()
	a.Scheduler.Start()
	a.Logger.Info("* [SYSTEM] * All Background Services are RUNNING. ")

	// --- Phase 3: graceful shutdown (deferred, runs in reverse order) ---
	defer a.Scheduler.Stop()
	defer a.HealthCheckService.Stop()
	defer a.AnalyticsService.Stop()
	defer a.APIKeyService.Stop()
	defer a.DBLogWriter.Stop()
	defer a.SettingsManager.Stop()
	defer a.TokenManager.Stop()

	// --- Phase 4: HTTP server ---
	serverAddr := fmt.Sprintf("0.0.0.0:%s", a.Config.Server.Port)
	srv := &http.Server{
		Addr:    serverAddr,
		Handler: a.Router,
	}
	go func() {
		a.Logger.Infof("* [SYSTEM] * HTTP Server Now Listening on %s ", serverAddr)
		if err := srv.ListenAndServe(); err != nil && err != http.ErrServerClosed {
			a.Logger.Fatalf("HTTP server listen error: %s\n", err)
		}
	}()

	// --- Phase 5: handle shutdown signals ---
	quit := make(chan os.Signal, 1)
	signal.Notify(quit, syscall.SIGINT, syscall.SIGTERM)
	<-quit
	a.Logger.Info("* [SYSTEM] * Shutdown signal received. Executing graceful shutdown...")
	ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
	defer cancel()
	if err := srv.Shutdown(ctx); err != nil {
		a.Logger.Fatal("Server forced to shutdown:", err)
	}
	a.Logger.Info("* [SYSTEM] * All units have ceased operations. Mission complete. ")
	return nil
}
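Reviewer note: for orientation, a minimal sketch of how this App could be bootstrapped from an entrypoint, assuming the dig-based container from internal/container below; the cmd path and main shape are assumptions, not part of this commit.

// Hypothetical entrypoint (path and wiring assumed, not in this commit).
package main

import (
	"log"

	"gemini-balancer/internal/app"
	"gemini-balancer/internal/container"
)

func main() {
	c, err := container.BuildContainer()
	if err != nil {
		log.Fatalf("failed to build DI container: %v", err)
	}
	// Resolve the fully wired *app.App and hand control to its Run loop.
	if err := c.Invoke(func(a *app.App) error { return a.Run() }); err != nil {
		log.Fatalf("application exited with error: %v", err)
	}
}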
internal/channel/channel.go (new file, 47 lines)
@@ -0,0 +1,47 @@
// Filename: internal/channel/channel.go
package channel

import (
	"context"
	"gemini-balancer/internal/errors"
	"gemini-balancer/internal/models"
	"net/http"
	"time"

	"github.com/gin-gonic/gin"
)

// ChannelProxy is the unified interface that every protocol channel must implement.
type ChannelProxy interface {
	// RewritePath is the core path-rewriting hook: it converts a client path into the path the upstream expects.
	RewritePath(basePath, originalPath string) string

	TransformRequest(c *gin.Context, requestBody []byte) (newBody []byte, modelName string, err error)
	// IsStreamRequest reports whether the request is a streaming request.
	IsStreamRequest(c *gin.Context, bodyBytes []byte) bool
	// IsOpenAICompatibleRequest reports whether the request uses an OpenAI-compatible path.
	IsOpenAICompatibleRequest(c *gin.Context) bool
	// ExtractModel extracts the model name from the request.
	ExtractModel(c *gin.Context, bodyBytes []byte) string
	// ValidateKey checks that an API key is valid.
	ValidateKey(ctx context.Context, apiKey *models.APIKey, targetURL string, timeout time.Duration) *errors.APIError
	// ModifyRequest mutates the request before it is sent upstream (e.g. adds authentication).
	ModifyRequest(req *http.Request, apiKey *models.APIKey)
	// ProcessSmartStreamRequest handles the core streaming proxy request.
	ProcessSmartStreamRequest(c *gin.Context, params SmartRequestParams)
	// Other non-streaming handlers, etc...
}

// SmartRequestParams is a parameter container that passes all high-level dependencies down to the lower layers in one clean step.
type SmartRequestParams struct {
	CorrelationID        string
	APIKey               *models.APIKey
	UpstreamURL          string
	RequestBody          []byte
	OriginalRequest      models.GeminiRequest
	EventLogger          *models.RequestFinishedEvent
	MaxRetries           int
	RetryDelay           time.Duration
	LogTruncationLimit   int
	StreamingRetryPrompt string // <--- carries the continuation prompt
}
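Reviewer note: a sketch of the caller side, to show how SmartRequestParams is meant to be assembled in one step. The helper name and all concrete values are illustrative, not from this commit.

// Hypothetical caller-side wiring (names and values illustrative).
func dispatchStream(c *gin.Context, ch channel.ChannelProxy, key *models.APIKey,
	upstreamURL string, body []byte, original models.GeminiRequest,
	ev *models.RequestFinishedEvent) {
	params := channel.SmartRequestParams{
		CorrelationID:        c.GetString("correlation_id"), // assumed middleware key
		APIKey:               key,
		UpstreamURL:          upstreamURL,
		RequestBody:          body,
		OriginalRequest:      original,
		EventLogger:          ev,
		MaxRetries:           3,               // illustrative
		RetryDelay:           2 * time.Second, // illustrative
		LogTruncationLimit:   1024,            // illustrative
		StreamingRetryPrompt: channel.SmartRetryPrompt,
	}
	ch.ProcessSmartStreamRequest(c, params)
}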
internal/channel/gemini_channel.go (new file, 434 lines)
@@ -0,0 +1,434 @@
// Filename: internal/channel/gemini_channel.go
package channel

import (
	"bufio"
	"bytes"
	"context"
	"encoding/json"
	"fmt"
	CustomErrors "gemini-balancer/internal/errors"
	"gemini-balancer/internal/models"
	"io"
	"net/http"
	"net/url"
	"strings"
	"time"

	"github.com/gin-gonic/gin"
	"github.com/sirupsen/logrus"
)

const SmartRetryPrompt = "Continue exactly where you left off..."

var _ ChannelProxy = (*GeminiChannel)(nil)

type GeminiChannel struct {
	logger     *logrus.Logger
	httpClient *http.Client
}

// Local struct used to extract request metadata safely.
type requestMetadata struct {
	Model  string `json:"model"`
	Stream bool   `json:"stream"`
}

func NewGeminiChannel(logger *logrus.Logger, cfg *models.SystemSettings) *GeminiChannel {
	transport := &http.Transport{
		Proxy:                 http.ProxyFromEnvironment,
		MaxIdleConns:          cfg.TransportMaxIdleConns,
		MaxIdleConnsPerHost:   cfg.TransportMaxIdleConnsPerHost,
		IdleConnTimeout:       time.Duration(cfg.TransportIdleConnTimeoutSecs) * time.Second,
		TLSHandshakeTimeout:   time.Duration(cfg.TransportTLSHandshakeTimeout) * time.Second,
		ExpectContinueTimeout: 1 * time.Second,
	}
	return &GeminiChannel{
		logger: logger,
		httpClient: &http.Client{
			Transport: transport,
			Timeout:   0, // no overall timeout: streaming responses may run indefinitely
		},
	}
}

// TransformRequest
func (ch *GeminiChannel) TransformRequest(c *gin.Context, requestBody []byte) (newBody []byte, modelName string, err error) {
	var p struct {
		Model string `json:"model"`
	}
	_ = json.Unmarshal(requestBody, &p)
	modelName = strings.TrimPrefix(p.Model, "models/")

	if modelName == "" {
		modelName = ch.extractModelFromPath(c.Request.URL.Path)
	}
	return requestBody, modelName, nil
}

func (ch *GeminiChannel) extractModelFromPath(path string) string {
	parts := strings.Split(path, "/")
	for _, part := range parts {
		if strings.HasPrefix(part, "gemini-") || strings.HasPrefix(part, "text-") || strings.HasPrefix(part, "embedding-") {
			modelPart := strings.Split(part, ":")[0]
			return modelPart
		}
	}
	return ""
}

func (ch *GeminiChannel) IsOpenAICompatibleRequest(c *gin.Context) bool {
	path := c.Request.URL.Path
	return strings.Contains(path, "/v1/chat/completions") || strings.Contains(path, "/v1/embeddings")
}

func (ch *GeminiChannel) ValidateKey(
	ctx context.Context,
	apiKey *models.APIKey,
	targetURL string,
	timeout time.Duration,
) *CustomErrors.APIError {
	client := &http.Client{
		Timeout: timeout,
	}
	req, err := http.NewRequestWithContext(ctx, "GET", targetURL, nil)
	if err != nil {
		return CustomErrors.NewAPIError(CustomErrors.ErrInternalServer, "failed to create validation request: "+err.Error())
	}

	ch.ModifyRequest(req, apiKey)
	resp, err := client.Do(req)
	if err != nil {
		return CustomErrors.NewAPIError(CustomErrors.ErrBadGateway, "failed to send validation request: "+err.Error())
	}
	defer resp.Body.Close()
	if resp.StatusCode >= 200 && resp.StatusCode < 300 {
		return nil
	}
	errorBody, _ := io.ReadAll(resp.Body)
	parsedMessage := CustomErrors.ParseUpstreamError(errorBody)
	return &CustomErrors.APIError{
		HTTPStatus: resp.StatusCode,
		Code:       fmt.Sprintf("UPSTREAM_%d", resp.StatusCode),
		Message:    parsedMessage,
	}
}

func (ch *GeminiChannel) ModifyRequest(req *http.Request, apiKey *models.APIKey) {
	// TODO: [Future Refactoring] Decouple auth logic from URL path.
	// The authentication method (e.g., Bearer token vs. API key in query) should ideally be a property
	// of the UpstreamEndpoint or a new "AuthProfile" entity, rather than being hardcoded based on URL patterns.
	// This would make the channel more generic and adaptable to new upstream provider types.
	if strings.Contains(req.URL.Path, "/v1beta/openai/") {
		req.Header.Set("Authorization", "Bearer "+apiKey.APIKey)
	} else {
		req.Header.Del("Authorization")
		q := req.URL.Query()
		q.Set("key", apiKey.APIKey)
		req.URL.RawQuery = q.Encode()
	}
}

func (ch *GeminiChannel) IsStreamRequest(c *gin.Context, bodyBytes []byte) bool {
	if strings.HasSuffix(c.Request.URL.Path, ":streamGenerateContent") {
		return true
	}
	var meta requestMetadata
	if err := json.Unmarshal(bodyBytes, &meta); err == nil {
		return meta.Stream
	}
	return false
}

func (ch *GeminiChannel) ExtractModel(c *gin.Context, bodyBytes []byte) string {
	_, modelName, _ := ch.TransformRequest(c, bodyBytes)
	return modelName
}

func (ch *GeminiChannel) RewritePath(basePath, originalPath string) string {
	tempCtx := &gin.Context{Request: &http.Request{URL: &url.URL{Path: originalPath}}}
	var rewrittenSegment string
	if ch.IsOpenAICompatibleRequest(tempCtx) {
		var apiEndpoint string
		v1Index := strings.LastIndex(originalPath, "/v1/")
		if v1Index != -1 {
			apiEndpoint = originalPath[v1Index+len("/v1/"):]
		} else {
			apiEndpoint = strings.TrimPrefix(originalPath, "/")
		}
		rewrittenSegment = "v1beta/openai/" + apiEndpoint
	} else {
		tempPath := originalPath
		if strings.HasPrefix(tempPath, "/v1/") {
			tempPath = "/v1beta/" + strings.TrimPrefix(tempPath, "/v1/")
		}
		rewrittenSegment = strings.TrimPrefix(tempPath, "/")
	}
	trimmedBasePath := strings.TrimSuffix(basePath, "/")
	pathToJoin := rewrittenSegment

	// Avoid doubling a version segment the base path already ends with.
	versionPrefixes := []string{"v1beta", "v1"}
	for _, prefix := range versionPrefixes {
		if strings.HasSuffix(trimmedBasePath, "/"+prefix) && strings.HasPrefix(pathToJoin, prefix+"/") {
			pathToJoin = strings.TrimPrefix(pathToJoin, prefix+"/")
			break
		}
	}
	finalPath, err := url.JoinPath(trimmedBasePath, pathToJoin)
	if err != nil {
		return trimmedBasePath + "/" + strings.TrimPrefix(pathToJoin, "/")
	}
	return finalPath
}

func (ch *GeminiChannel) ModifyResponse(resp *http.Response) error {
	// Stub implementation; no logic is needed yet.
	return nil
}

func (ch *GeminiChannel) HandleError(c *gin.Context, err error) {
	// Stub implementation; no logic is needed yet.
}

// ==========================================================
// ========== The core engine of "smart routing" ============
// ==========================================================
func (ch *GeminiChannel) ProcessSmartStreamRequest(c *gin.Context, params SmartRequestParams) {
	log := ch.logger.WithField("correlation_id", params.CorrelationID)

	targetURL, err := url.Parse(params.UpstreamURL)
	if err != nil {
		log.WithError(err).Error("Failed to parse upstream URL")
		errToJSON(c, CustomErrors.NewAPIError(CustomErrors.ErrInternalServer, "Invalid upstream URL format"))
		return
	}
	targetURL.Path = c.Request.URL.Path
	targetURL.RawQuery = c.Request.URL.RawQuery
	initialReq, err := http.NewRequestWithContext(c.Request.Context(), "POST", targetURL.String(), bytes.NewReader(params.RequestBody))
	if err != nil {
		log.WithError(err).Error("Failed to create initial smart request")
		errToJSON(c, CustomErrors.NewAPIError(CustomErrors.ErrInternalServer, err.Error()))
		return
	}
	ch.ModifyRequest(initialReq, params.APIKey)
	initialReq.Header.Del("Authorization")
	resp, err := ch.httpClient.Do(initialReq)
	if err != nil {
		log.WithError(err).Error("Initial smart request failed")
		errToJSON(c, CustomErrors.NewAPIError(CustomErrors.ErrBadGateway, err.Error()))
		return
	}
	if resp.StatusCode != http.StatusOK {
		log.Warnf("Initial request received non-200 status: %d", resp.StatusCode)
		standardizedResp := ch.standardizeError(resp, params.LogTruncationLimit, log)
		// Copy headers before writing the status line; headers added after
		// WriteHeader would be silently ignored.
		for key, values := range standardizedResp.Header {
			for _, value := range values {
				c.Writer.Header().Add(key, value)
			}
		}
		c.Writer.WriteHeader(standardizedResp.StatusCode)
		io.Copy(c.Writer, standardizedResp.Body)
		params.EventLogger.IsSuccess = false
		params.EventLogger.StatusCode = resp.StatusCode
		return
	}
	ch.processStreamAndRetry(c, initialReq.Header, resp.Body, params, log)
}

func (ch *GeminiChannel) processStreamAndRetry(
	c *gin.Context, initialRequestHeaders http.Header, initialReader io.ReadCloser, params SmartRequestParams, log *logrus.Entry,
) {
	defer initialReader.Close()
	c.Writer.Header().Set("Content-Type", "text/event-stream; charset=utf-8")
	c.Writer.Header().Set("Cache-Control", "no-cache")
	c.Writer.Header().Set("Connection", "keep-alive")
	flusher, _ := c.Writer.(http.Flusher)
	var accumulatedText strings.Builder
	consecutiveRetryCount := 0
	currentReader := initialReader
	maxRetries := params.MaxRetries
	retryDelay := params.RetryDelay
	log.Infof("Starting smart stream session. Max retries: %d", maxRetries)
	for {
		var interruptionReason string
		scanner := bufio.NewScanner(currentReader)
		for scanner.Scan() {
			line := scanner.Text()
			if line == "" {
				continue
			}
			fmt.Fprintf(c.Writer, "%s\n\n", line)
			flusher.Flush()
			if !strings.HasPrefix(line, "data: ") {
				continue
			}
			data := strings.TrimPrefix(line, "data: ")
			var payload models.GeminiSSEPayload
			if err := json.Unmarshal([]byte(data), &payload); err != nil {
				continue
			}
			if len(payload.Candidates) > 0 {
				candidate := payload.Candidates[0]
				if candidate.Content != nil && len(candidate.Content.Parts) > 0 {
					accumulatedText.WriteString(candidate.Content.Parts[0].Text)
				}
				if candidate.FinishReason == "STOP" {
					log.Info("Stream finished successfully with STOP reason.")
					params.EventLogger.IsSuccess = true
					return
				}
				if candidate.FinishReason != "" {
					log.Warnf("Stream interrupted with abnormal finish reason: %s", candidate.FinishReason)
					interruptionReason = candidate.FinishReason
					break
				}
			}
		}
		currentReader.Close()
		if interruptionReason == "" {
			if err := scanner.Err(); err != nil {
				log.WithError(err).Warn("Stream scanner encountered an error.")
				interruptionReason = "SCANNER_ERROR"
			} else {
				log.Warn("Stream dropped unexpectedly without a finish reason.")
				interruptionReason = "CONNECTION_DROP"
			}
		}
		if consecutiveRetryCount >= maxRetries {
			log.Errorf("Retry limit exceeded. Last interruption: %s. Sending final error.", interruptionReason)
			errData, _ := json.Marshal(map[string]interface{}{"error": map[string]interface{}{"code": http.StatusGatewayTimeout, "status": "DEADLINE_EXCEEDED", "message": fmt.Sprintf("Proxy retry limit exceeded. Last interruption: %s.", interruptionReason)}})
			fmt.Fprintf(c.Writer, "event: error\ndata: %s\n\n", string(errData))
			flusher.Flush()
			return
		}
		consecutiveRetryCount++
		params.EventLogger.Retries = consecutiveRetryCount
		log.Infof("Stream interrupted. Attempting retry %d/%d after %v.", consecutiveRetryCount, maxRetries, retryDelay)
		time.Sleep(retryDelay)
		retryBody, _ := buildRetryRequestBody(params.OriginalRequest, accumulatedText.String())
		retryBodyBytes, _ := json.Marshal(retryBody)

		retryReq, _ := http.NewRequestWithContext(c.Request.Context(), "POST", params.UpstreamURL, bytes.NewReader(retryBodyBytes))
		retryReq.Header = initialRequestHeaders
		ch.ModifyRequest(retryReq, params.APIKey)
		retryReq.Header.Del("Authorization")

		retryResp, err := ch.httpClient.Do(retryReq)
		if err != nil || retryResp.StatusCode != http.StatusOK || retryResp.Body == nil {
			if err != nil {
				log.WithError(err).Errorf("Retry request failed.")
			} else {
				log.Errorf("Retry request received non-200 status: %d", retryResp.StatusCode)
				if retryResp.Body != nil {
					retryResp.Body.Close()
				}
			}
			continue
		}
		currentReader = retryResp.Body
	}
}

func buildRetryRequestBody(originalBody models.GeminiRequest, accumulatedText string) (models.GeminiRequest, error) {
	retryBody := originalBody
	lastUserIndex := -1
	for i := len(retryBody.Contents) - 1; i >= 0; i-- {
		if retryBody.Contents[i].Role == "user" {
			lastUserIndex = i
			break
		}
	}
	history := []models.GeminiContent{
		{Role: "model", Parts: []models.Part{{Text: accumulatedText}}},
		{Role: "user", Parts: []models.Part{{Text: SmartRetryPrompt}}},
	}
	if lastUserIndex != -1 {
		newContents := make([]models.GeminiContent, 0, len(retryBody.Contents)+2)
		newContents = append(newContents, retryBody.Contents[:lastUserIndex+1]...)
		newContents = append(newContents, history...)
		newContents = append(newContents, retryBody.Contents[lastUserIndex+1:]...)
		retryBody.Contents = newContents
	} else {
		retryBody.Contents = append(retryBody.Contents, history...)
	}
	return retryBody, nil
}

// ===============================================
// ===== Helper functions (inherited and hardened) =====
// ===============================================

type googleAPIError struct {
	Error struct {
		Code    int           `json:"code"`
		Message string        `json:"message"`
		Status  string        `json:"status"`
		Details []interface{} `json:"details,omitempty"`
	} `json:"error"`
}

func statusToGoogleStatus(code int) string {
	switch code {
	case 400:
		return "INVALID_ARGUMENT"
	case 401:
		return "UNAUTHENTICATED"
	case 403:
		return "PERMISSION_DENIED"
	case 404:
		return "NOT_FOUND"
	case 429:
		return "RESOURCE_EXHAUSTED"
	case 500:
		return "INTERNAL"
	case 503:
		return "UNAVAILABLE"
	case 504:
		return "DEADLINE_EXCEEDED"
	default:
		return "UNKNOWN"
	}
}

func truncate(s string, n int) string {
	if n > 0 && len(s) > n {
		return fmt.Sprintf("%s... [truncated %d bytes]", s[:n], len(s)-n)
	}
	return s
}

// standardizeError
func (ch *GeminiChannel) standardizeError(resp *http.Response, truncateLimit int, log *logrus.Entry) *http.Response {
	bodyBytes, err := io.ReadAll(resp.Body)
	if err != nil {
		log.WithError(err).Error("Failed to read upstream error body")
		bodyBytes = []byte("Failed to read upstream error body: " + err.Error())
	}
	resp.Body.Close()
	log.Errorf("Upstream error body: %s", truncate(string(bodyBytes), truncateLimit))
	var standardizedPayload googleAPIError
	if json.Unmarshal(bodyBytes, &standardizedPayload) != nil || standardizedPayload.Error.Code == 0 {
		standardizedPayload.Error.Code = resp.StatusCode
		standardizedPayload.Error.Message = http.StatusText(resp.StatusCode)
		standardizedPayload.Error.Status = statusToGoogleStatus(resp.StatusCode)
		standardizedPayload.Error.Details = []interface{}{map[string]string{
			"@type": "proxy.upstream.error",
			"body":  truncate(string(bodyBytes), truncateLimit),
		}}
	}
	newBodyBytes, _ := json.Marshal(standardizedPayload)
	newResp := &http.Response{
		StatusCode: resp.StatusCode,
		Status:     resp.Status,
		Header:     http.Header{},
		Body:       io.NopCloser(bytes.NewReader(newBodyBytes)),
	}
	newResp.Header.Set("Content-Type", "application/json; charset=utf-8")
	newResp.Header.Set("Access-Control-Allow-Origin", "*")
	return newResp
}

// errToJSON
func errToJSON(c *gin.Context, apiErr *CustomErrors.APIError) {
	c.JSON(apiErr.HTTPStatus, gin.H{"error": apiErr})
}
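Reviewer note: tracing RewritePath above, these are the rewrites it should produce. The upstream hosts are made up and the expected outputs are hand-derived from the code, not executed, so treat this as a sketch.

// In package channel; expectations derived by reading RewritePath.
func ExampleGeminiChannel_RewritePath() {
	ch := &GeminiChannel{}

	// OpenAI-compatible paths are re-rooted under v1beta/openai/.
	fmt.Println(ch.RewritePath("https://upstream.example.com", "/v1/chat/completions"))
	// expected: https://upstream.example.com/v1beta/openai/chat/completions

	// Native paths get /v1/ upgraded to /v1beta/, and a version segment the
	// base path already ends with is not duplicated.
	fmt.Println(ch.RewritePath("https://upstream.example.com/v1beta", "/v1/models/gemini-pro:generateContent"))
	// expected: https://upstream.example.com/v1beta/models/gemini-pro:generateContent
}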
internal/config/config.go (new file, 81 lines)
@@ -0,0 +1,81 @@
package config

import (
	"fmt"
	"strings"
	"time"

	"github.com/spf13/viper"
)

// Config aggregates all application configuration.
type Config struct {
	Database      DatabaseConfig
	Server        ServerConfig
	Log           LogConfig
	Redis         RedisConfig `mapstructure:"redis"`
	SessionSecret string      `mapstructure:"session_secret"`
	EncryptionKey string      `mapstructure:"encryption_key"`
}

// DatabaseConfig holds the database connection settings.
type DatabaseConfig struct {
	DSN             string        `mapstructure:"dsn"`
	MaxIdleConns    int           `mapstructure:"max_idle_conns"`
	MaxOpenConns    int           `mapstructure:"max_open_conns"`
	ConnMaxLifetime time.Duration `mapstructure:"conn_max_lifetime"`
}

// ServerConfig holds the HTTP server configuration.
type ServerConfig struct {
	Port string `mapstructure:"port"`
}

// LogConfig holds the logging configuration.
type LogConfig struct {
	Level      string `mapstructure:"level" json:"level"`
	Format     string `mapstructure:"format" json:"format"`
	EnableFile bool   `mapstructure:"enable_file" json:"enable_file"`
	FilePath   string `mapstructure:"file_path" json:"file_path"`
}

type RedisConfig struct {
	DSN string `mapstructure:"dsn"`
}

// LoadConfig loads configuration from the config file and environment variables.
func LoadConfig() (*Config, error) {
	// Config file name and search path
	viper.SetConfigName("config")
	viper.SetConfigType("yaml")
	viper.AddConfigPath(".")

	// Allow overrides from environment variables
	viper.SetEnvKeyReplacer(strings.NewReplacer(".", "_"))
	viper.AutomaticEnv()

	// Defaults
	viper.SetDefault("server.port", "8080")
	viper.SetDefault("log.level", "info")
	viper.SetDefault("log.format", "text")
	viper.SetDefault("log.enable_file", false)
	viper.SetDefault("log.file_path", "logs/gemini-balancer.log")
	viper.SetDefault("database.type", "sqlite")
	viper.SetDefault("database.dsn", "gemini-balancer.db")
	viper.SetDefault("database.max_idle_conns", 10)
	viper.SetDefault("database.max_open_conns", 100)
	viper.SetDefault("database.conn_max_lifetime", "1h")
	viper.SetDefault("encryption_key", "")

	// Read the config file; a missing file is fine, any other read error is not.
	if err := viper.ReadInConfig(); err != nil {
		if _, ok := err.(viper.ConfigFileNotFoundError); !ok {
			return nil, fmt.Errorf("error reading config file: %w", err)
		}
	}
	var cfg Config
	if err := viper.Unmarshal(&cfg); err != nil {
		return nil, fmt.Errorf("unable to decode config into struct: %w", err)
	}
	return &cfg, nil
}
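Reviewer note: because SetEnvKeyReplacer swaps "." for "_" and AutomaticEnv is on, every nested key with a default can be overridden by an upper-case, underscore-separated environment variable. A small sketch (the override value is illustrative):

// SERVER_PORT overrides server.port, DATABASE_DSN overrides database.dsn, etc.
package main

import (
	"fmt"
	"log"
	"os"

	"gemini-balancer/internal/config"
)

func main() {
	os.Setenv("SERVER_PORT", "9090") // illustrative override
	cfg, err := config.LoadConfig()
	if err != nil {
		log.Fatalf("config load failed: %v", err)
	}
	fmt.Println(cfg.Server.Port) // expected: 9090
}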
internal/container/container.go (new file, 123 lines)
@@ -0,0 +1,123 @@
// Filename: internal/container/container.go
package container

import (
	"fmt"
	"gemini-balancer/internal/app"
	"gemini-balancer/internal/channel"
	"gemini-balancer/internal/config"
	"gemini-balancer/internal/crypto"
	"gemini-balancer/internal/db"
	"gemini-balancer/internal/db/dialect"
	"gemini-balancer/internal/db/migrations"
	"gemini-balancer/internal/domain/proxy"
	"gemini-balancer/internal/domain/upstream"
	"gemini-balancer/internal/handlers"
	"gemini-balancer/internal/logging"
	"gemini-balancer/internal/models"
	"gemini-balancer/internal/repository"
	"gemini-balancer/internal/router"
	"gemini-balancer/internal/scheduler"
	"gemini-balancer/internal/service"
	"gemini-balancer/internal/settings"
	"gemini-balancer/internal/store"
	"gemini-balancer/internal/syncer"
	"gemini-balancer/internal/task"
	"gemini-balancer/internal/webhandlers"

	"github.com/sirupsen/logrus"
	"go.uber.org/dig"
	"gorm.io/gorm"
)

func BuildContainer() (*dig.Container, error) {
	container := dig.New()

	// =========== Phase 1: Infrastructure ===========
	container.Provide(config.LoadConfig)

	container.Provide(func(cfg *config.Config, logger *logrus.Logger) (*gorm.DB, dialect.DialectAdapter, error) {
		gormDB, adapter, err := db.NewDB(cfg, logger)
		if err != nil {
			return nil, nil, err
		}
		// Run versioned migrations as part of database construction.
		if err := migrations.RunVersionedMigrations(gormDB, cfg, logger); err != nil {
			return nil, nil, fmt.Errorf("failed to run versioned migrations: %w", err)
		}
		return gormDB, adapter, nil
	})
	container.Provide(store.NewStore)
	container.Provide(logging.NewLogger)
	container.Provide(crypto.NewService)
	container.Provide(repository.NewAuthTokenRepository)
	container.Provide(repository.NewGroupRepository)
	container.Provide(repository.NewKeyRepository)
	// Repository interface bindings
	//container.Provide(func(r *repository.gormKeyRepository) repository.KeyRepository { return r })
	//container.Provide(func(r *repository.GormGroupRepository) repository.GroupRepository { return r })

	// SettingsManager.
	container.Provide(settings.NewSettingsManager)
	// On top of SettingsManager, expose a standard, safe "data socket" so that
	// modules depend only on the data they need rather than the whole manager.
	container.Provide(func(sm *settings.SettingsManager) *models.SystemSettings { return sm.GetSettings() })

	// =========== Phase 2: Core services ===========
	container.Provide(service.NewDBLogWriterService)
	container.Provide(service.NewSecurityService)
	container.Provide(service.NewKeyImportService)
	container.Provide(service.NewKeyValidationService)
	container.Provide(service.NewTokenManager)
	container.Provide(service.NewAPIKeyService)
	container.Provide(service.NewGroupManager)
	container.Provide(service.NewResourceService)
	container.Provide(service.NewAnalyticsService)
	container.Provide(service.NewLogService)
	container.Provide(service.NewHealthCheckService)
	container.Provide(service.NewStatsService)
	container.Provide(service.NewDashboardQueryService)
	container.Provide(scheduler.NewScheduler)
	container.Provide(task.NewTask)

	// --- Task Reporter ---
	container.Provide(func(t *task.Task) task.Reporter { return t })
	// --- Syncer & Loader for GroupManager ---
	container.Provide(service.NewGroupManagerLoader)
	// Configure the cache syncer for GroupManager.
	container.Provide(func(loader syncer.LoaderFunc[service.GroupManagerCacheData], store store.Store, logger *logrus.Logger) (*syncer.CacheSyncer[service.GroupManagerCacheData], error) {
		const groupUpdateChannel = "groups:cache_invalidation"
		return syncer.NewCacheSyncer(loader, store, groupUpdateChannel)
	})

	// =========== Phase 3: Adapters & handlers ===========

	// Provide the channel's dependencies (logger and the *models.SystemSettings data socket).
	container.Provide(channel.NewGeminiChannel)
	container.Provide(func(ch *channel.GeminiChannel) channel.ChannelProxy { return ch })

	// --- API Handlers ---
	container.Provide(handlers.NewAPIKeyHandler)
	container.Provide(handlers.NewKeyGroupHandler)
	container.Provide(handlers.NewTokensHandler)
	container.Provide(handlers.NewLogHandler)
	container.Provide(handlers.NewSettingHandler)
	container.Provide(handlers.NewDashboardHandler)
	container.Provide(handlers.NewAPIAuthHandler)
	container.Provide(handlers.NewProxyHandler)
	container.Provide(handlers.NewTaskHandler)

	// --- Domain Modules ---
	container.Provide(upstream.NewModule)
	container.Provide(proxy.NewModule)

	// --- Web Page Handlers ---
	container.Provide(webhandlers.NewWebAuthHandler)
	container.Provide(webhandlers.NewPageHandler)

	// =========== Top-level application layer ===========
	container.Provide(router.NewRouter)
	container.Provide(app.NewApp)

	return container, nil
}
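Reviewer note: dig's Provide returns an error (for example, on a duplicate constructor for the same output type; the second crypto.NewService registration in the original was removed above for that reason), and the calls here discard it. If stricter wiring is wanted, a fail-fast wrapper is possible; this helper is a suggestion, not part of the commit:

// Hypothetical helper: surface registration failures instead of dropping them.
func mustProvide(c *dig.Container, ctors ...interface{}) error {
	for _, ctor := range ctors {
		if err := c.Provide(ctor); err != nil {
			return fmt.Errorf("dig provide %T failed: %w", ctor, err)
		}
	}
	return nil
}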
internal/crypto/crypto.go (new file, 76 lines)
@@ -0,0 +1,76 @@
// Filename: internal/crypto/crypto.go
package crypto

import (
	"crypto/aes"
	"crypto/cipher"
	"crypto/rand"
	"encoding/hex"
	"fmt"
	"gemini-balancer/internal/config" // [NEW] Crypto service now depends on Config
	"io"
	"os"
)

type Service struct {
	gcm cipher.AEAD
}

func NewService(cfg *config.Config) (*Service, error) {
	keyHex := cfg.EncryptionKey
	if keyHex == "" {
		// Fall back to the environment variable if the key is not in the config file.
		keyHex = os.Getenv("ENCRYPTION_KEY")
		if keyHex == "" {
			return nil, fmt.Errorf("encryption key is not configured: please set 'encryption_key' in config.yaml or the ENCRYPTION_KEY environment variable")
		}
	}
	key, err := hex.DecodeString(keyHex)
	if err != nil {
		return nil, fmt.Errorf("failed to decode encryption key from hex: %w", err)
	}
	if len(key) != 32 {
		return nil, fmt.Errorf("invalid encryption key length: must be 32 bytes (64 hex characters), got %d bytes", len(key))
	}
	block, err := aes.NewCipher(key)
	if err != nil {
		return nil, fmt.Errorf("failed to create AES cipher block: %w", err)
	}
	gcm, err := cipher.NewGCM(block)
	if err != nil {
		return nil, fmt.Errorf("failed to create GCM block cipher: %w", err)
	}
	return &Service{gcm: gcm}, nil
}

// Encrypt encrypts plaintext and returns hex-encoded ciphertext.
func (s *Service) Encrypt(plaintext string) (string, error) {
	nonce := make([]byte, s.gcm.NonceSize())
	if _, err := io.ReadFull(rand.Reader, nonce); err != nil {
		return "", fmt.Errorf("failed to generate nonce: %w", err)
	}

	// Seal prepends the nonce to the ciphertext so Decrypt can recover it.
	ciphertext := s.gcm.Seal(nonce, nonce, []byte(plaintext), nil)
	return hex.EncodeToString(ciphertext), nil
}

// Decrypt decrypts a hex-encoded ciphertext and returns the plaintext.
func (s *Service) Decrypt(hexCiphertext string) (string, error) {
	ciphertext, err := hex.DecodeString(hexCiphertext)
	if err != nil {
		return "", fmt.Errorf("ciphertext hex decode error: %w", err)
	}

	nonceSize := s.gcm.NonceSize()
	if len(ciphertext) < nonceSize {
		return "", fmt.Errorf("invalid ciphertext: too short")
	}

	nonce, encryptedMessage := ciphertext[:nonceSize], ciphertext[nonceSize:]
	plaintext, err := s.gcm.Open(nil, nonce, encryptedMessage, nil)
	if err != nil {
		return "", fmt.Errorf("failed to decrypt: %w", err)
	}

	return string(plaintext), nil
}
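Reviewer note: NewService requires a 32-byte key supplied as 64 hex characters. A one-off generator sketch, standard library only:

package main

import (
	"crypto/rand"
	"encoding/hex"
	"fmt"
)

func main() {
	key := make([]byte, 32) // 32 random bytes -> 64 hex characters
	if _, err := rand.Read(key); err != nil {
		panic(err)
	}
	fmt.Println(hex.EncodeToString(key)) // paste into config.yaml as encryption_key
}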
internal/db/db.go (new file, 88 lines)
@@ -0,0 +1,88 @@
// Filename: internal/db/db.go
package db

import (
	"gemini-balancer/internal/config"
	"gemini-balancer/internal/db/dialect"
	stdlog "log"
	"os"
	"path/filepath"
	"strings"
	"time"

	"github.com/glebarez/sqlite"
	"github.com/sirupsen/logrus"
	"gorm.io/driver/mysql"
	"gorm.io/driver/postgres"
	"gorm.io/gorm"
	"gorm.io/gorm/logger"
)

func NewDB(cfg *config.Config, appLogger *logrus.Logger) (*gorm.DB, dialect.DialectAdapter, error) {
	log := appLogger.WithField("component", "db")
	log.Info("Initializing database connection and dialect adapter...")
	dbConfig := cfg.Database
	dsn := dbConfig.DSN
	var gormLogger logger.Interface
	if cfg.Log.Level == "debug" {
		gormLogger = logger.New(
			stdlog.New(os.Stdout, "\r\n", stdlog.LstdFlags),
			logger.Config{
				SlowThreshold:             1 * time.Second,
				LogLevel:                  logger.Info,
				IgnoreRecordNotFoundError: true,
				Colorful:                  true,
			},
		)
		log.Info("Debug mode enabled, GORM SQL logging is active.")
	}

	var dialector gorm.Dialector
	var adapter dialect.DialectAdapter
	switch {
	case strings.HasPrefix(dsn, "postgres://"), strings.HasPrefix(dsn, "postgresql://"):
		log.Info("Detected PostgreSQL database.")
		dialector = postgres.Open(dsn)
		adapter = dialect.NewPostgresAdapter()
	case strings.Contains(dsn, "@tcp"):
		log.Info("Detected MySQL database.")
		if !strings.Contains(dsn, "parseTime=true") {
			if strings.Contains(dsn, "?") {
				dsn += "&parseTime=true"
			} else {
				dsn += "?parseTime=true"
			}
		}
		dialector = mysql.Open(dsn)
		adapter = dialect.NewMySQLAdapter()
	default:
		log.Info("Using SQLite database.")
		if err := os.MkdirAll(filepath.Dir(dsn), 0755); err != nil {
			log.Errorf("Failed to create SQLite directory: %v", err)
			return nil, nil, err
		}
		dialector = sqlite.Open(dsn + "?_busy_timeout=5000")
		adapter = dialect.NewSQLiteAdapter()
	}
	db, err := gorm.Open(dialector, &gorm.Config{
		Logger:      gormLogger,
		PrepareStmt: true,
	})
	if err != nil {
		log.Errorf("Failed to open database connection: %v", err)
		return nil, nil, err
	}

	sqlDB, err := db.DB()
	if err != nil {
		log.Errorf("Failed to get underlying sql.DB: %v", err)
		return nil, nil, err
	}
	sqlDB.SetMaxIdleConns(dbConfig.MaxIdleConns)
	sqlDB.SetMaxOpenConns(dbConfig.MaxOpenConns)
	sqlDB.SetConnMaxLifetime(dbConfig.ConnMaxLifetime)
	log.Infof("Connection pool configured: MaxIdle=%d, MaxOpen=%d, MaxLifetime=%v",
		dbConfig.MaxIdleConns, dbConfig.MaxOpenConns, dbConfig.ConnMaxLifetime)
	log.Info("Database connection established successfully.")
	return db, adapter, nil
}
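Reviewer note: dialect detection in NewDB keys off the DSN shape alone (the original assigned the Postgres adapter in the MySQL branch; that is corrected above). The rule, extracted for illustration, with made-up sample DSNs:

package main

import (
	"fmt"
	"strings"
)

// Mirrors the switch in NewDB above.
func detect(dsn string) string {
	switch {
	case strings.HasPrefix(dsn, "postgres://"), strings.HasPrefix(dsn, "postgresql://"):
		return "postgres"
	case strings.Contains(dsn, "@tcp"):
		return "mysql (parseTime=true appended if missing)"
	default:
		return "sqlite (+ _busy_timeout=5000)"
	}
}

func main() {
	for _, dsn := range []string{
		"postgres://user:pass@localhost:5432/balancer", // illustrative
		"user:pass@tcp(localhost:3306)/balancer",       // illustrative
		"data/gemini-balancer.db",                      // illustrative
	} {
		fmt.Printf("%-46s -> %s\n", dsn, detect(dsn))
	}
}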
internal/db/dialect/dialect.go (new file, 14 lines)
@@ -0,0 +1,14 @@
// Filename: internal/db/dialect/dialect.go
package dialect

import (
	"gorm.io/gorm/clause"
)

// DialectAdapter is the "common language" interface spoken by every supported database.
type DialectAdapter interface {
	// OnConflictUpdateAll builds a complete "ON CONFLICT DO UPDATE" clause appropriate for the current database.
	// conflictColumns: the unique-constraint columns, e.g. ["time", "group_id", "model_name"]
	// updateColumns: the columns to update by accumulation, e.g. ["request_count", "success_count", ...]
	OnConflictUpdateAll(conflictColumns []string, updateColumns []string) clause.Expression
}
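Reviewer note: a sketch of how a caller would use the adapter with GORM; the table and column names simply reuse the examples from the doc comment above, so treat them as assumptions.

// Accumulating upsert via the adapter (column names taken from the doc comment).
func upsertHourlyStats(db *gorm.DB, adapter dialect.DialectAdapter, rows []models.StatsHourly) error {
	expr := adapter.OnConflictUpdateAll(
		[]string{"time", "group_id", "model_name"},
		[]string{"request_count", "success_count"},
	)
	// GORM renders the expression as the dialect-specific upsert clause.
	return db.Clauses(expr).Create(&rows).Error
}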
internal/db/dialect/mysql_adapter.go (new file, 30 lines)
@@ -0,0 +1,30 @@
// Filename: internal/db/dialect/mysql_adapter.go
package dialect

import (
	"gorm.io/gorm"
	"gorm.io/gorm/clause"
)

type mysqlAdapter struct{}

func NewMySQLAdapter() DialectAdapter {
	return &mysqlAdapter{}
}

func (a *mysqlAdapter) OnConflictUpdateAll(conflictColumns []string, updateColumns []string) clause.Expression {
	conflictCols := make([]clause.Column, len(conflictColumns))
	for i, col := range conflictColumns {
		conflictCols[i] = clause.Column{Name: col}
	}

	// MySQL reads the incoming row via VALUES(col) inside ON DUPLICATE KEY UPDATE.
	assignments := make(map[string]interface{})
	for _, col := range updateColumns {
		assignments[col] = gorm.Expr(col + " + VALUES(" + col + ")")
	}

	return clause.OnConflict{
		Columns:   conflictCols,
		DoUpdates: clause.Assignments(assignments),
	}
}
internal/db/dialect/postgres_adapter.go (new file, 30 lines)
@@ -0,0 +1,30 @@
// Filename: internal/db/dialect/postgres_adapter.go
package dialect

import (
	"gorm.io/gorm"
	"gorm.io/gorm/clause"
)

type postgresAdapter struct{}

func NewPostgresAdapter() DialectAdapter {
	return &postgresAdapter{}
}

func (a *postgresAdapter) OnConflictUpdateAll(conflictColumns []string, updateColumns []string) clause.Expression {
	conflictCols := make([]clause.Column, len(conflictColumns))
	for i, col := range conflictColumns {
		conflictCols[i] = clause.Column{Name: col}
	}

	// PostgreSQL reads the incoming row via the "excluded" pseudo-table.
	assignments := make(map[string]interface{})
	for _, col := range updateColumns {
		assignments[col] = gorm.Expr(col + " + excluded." + col)
	}

	return clause.OnConflict{
		Columns:   conflictCols,
		DoUpdates: clause.Assignments(assignments),
	}
}
internal/db/dialect/sqlite_adapter.go (new file, 30 lines)
@@ -0,0 +1,30 @@
// Filename: internal/db/dialect/sqlite_adapter.go
package dialect

import (
	"gorm.io/gorm"
	"gorm.io/gorm/clause"
)

type sqliteAdapter struct{}

func NewSQLiteAdapter() DialectAdapter {
	return &sqliteAdapter{}
}

func (a *sqliteAdapter) OnConflictUpdateAll(conflictColumns []string, updateColumns []string) clause.Expression {
	conflictCols := make([]clause.Column, len(conflictColumns))
	for i, col := range conflictColumns {
		conflictCols[i] = clause.Column{Name: col}
	}

	// SQLite uses the same "excluded" pseudo-table syntax as PostgreSQL.
	assignments := make(map[string]interface{})
	for _, col := range updateColumns {
		assignments[col] = gorm.Expr(col + " + excluded." + col)
	}

	return clause.OnConflict{
		Columns:   conflictCols,
		DoUpdates: clause.Assignments(assignments),
	}
}
internal/db/migrations/migrations.go (new file, 36 lines)
@@ -0,0 +1,36 @@
// Filename: internal/db/migrations/migrations.go
package migrations

import (
	"gemini-balancer/internal/models"

	"github.com/sirupsen/logrus"
	"gorm.io/gorm"
)

// RunMigrations executes all database schema migrations.
func RunMigrations(db *gorm.DB, logger *logrus.Logger) error {
	log := logger.WithField("component", "migrations")
	log.Info("Running database schema migrations...")
	// All tables that need to be created or updated are managed centrally here.
	err := db.AutoMigrate(
		&models.UpstreamEndpoint{},
		&models.ProxyConfig{},
		&models.APIKey{},
		&models.KeyGroup{},
		&models.GroupModelMapping{},
		&models.AuthToken{},
		&models.RequestLog{},
		&models.StatsHourly{},
		&models.FileRecord{},
		&models.Setting{},
		&models.GroupSettings{},
		&models.GroupAPIKeyMapping{},
	)
	if err != nil {
		log.Errorf("Database schema migration failed: %v", err)
		return err
	}
	log.Info("Database schema migrations completed successfully.")
	return nil
}
internal/db/migrations/versioned_migrations.go (new file, 62 lines)
@@ -0,0 +1,62 @@
// Filename: internal/db/migrations/versioned_migrations.go

package migrations

import (
	"fmt"
	"gemini-balancer/internal/config"

	"github.com/sirupsen/logrus"
	"gorm.io/gorm"
)

// RunVersionedMigrations runs every registered versioned migration that has not yet been applied.
func RunVersionedMigrations(db *gorm.DB, cfg *config.Config, logger *logrus.Logger) error {
	log := logger.WithField("component", "versioned_migrations")
	log.Info("Checking for versioned database migrations...")

	if err := db.AutoMigrate(&MigrationHistory{}); err != nil {
		log.Errorf("Failed to create migration history table: %v", err)
		return err
	}

	var executedMigrations []MigrationHistory
	db.Find(&executedMigrations)
	executedVersions := make(map[string]bool)
	for _, m := range executedMigrations {
		executedVersions[m.Version] = true
	}

	for _, migration := range migrationRegistry {
		if !executedVersions[migration.Version] {
			log.Infof("Running migration %s: %s", migration.Version, migration.Description)
			if err := migration.Migrate(db, cfg, log); err != nil {
				log.Errorf("Migration %s failed: %v", migration.Version, err)
				return fmt.Errorf("migration %s failed: %w", migration.Version, err)
			}
			db.Create(&MigrationHistory{Version: migration.Version})
			log.Infof("Migration %s completed successfully.", migration.Version)
		}
	}

	log.Info("All versioned migrations are up to date.")
	return nil
}

type MigrationFunc func(db *gorm.DB, cfg *config.Config, logger *logrus.Entry) error
type VersionedMigration struct {
	Version     string
	Description string
	Migrate     MigrationFunc
}
type MigrationHistory struct {
	Version string `gorm:"primaryKey"`
}

var migrationRegistry = []VersionedMigration{
	/*{
		Version:     "20250828_encrypt_existing_auth_tokens",
		Description: "Encrypt plaintext tokens and populate new crypto columns in auth_tokens table.",
		Migrate:     MigrateAuthTokenEncryption,
	},*/
}
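Reviewer note: the commented-out entry shows the intended shape of the registry; an explicit no-op sketch for reference (version string and body are hypothetical):

var exampleMigration = VersionedMigration{
	Version:     "20990101_example", // hypothetical
	Description: "Illustrative no-op migration.",
	Migrate: func(db *gorm.DB, cfg *config.Config, logger *logrus.Entry) error {
		logger.Info("running example migration")
		return nil // a non-nil error aborts the run and leaves the version unrecorded
	},
}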
internal/db/seeder/seeder.go (new file, 87 lines)
@@ -0,0 +1,87 @@
// Filename: internal/db/seeder/seeder.go
package seeder

import (
	"crypto/sha256"
	"encoding/hex"
	"gemini-balancer/internal/crypto"
	"gemini-balancer/internal/models"
	"gemini-balancer/internal/repository"

	"github.com/sirupsen/logrus"
	"gorm.io/gorm"
)

// RunSeeder now requires the crypto service to create the initial admin token securely.
func RunSeeder(db *gorm.DB, cryptoService *crypto.Service, logger *logrus.Logger) {
	log := logger.WithField("component", "seeder")
	log.Info("Running database seeder...")
	// [REFACTORED] Admin token seeding is now crypto-aware.
	var count int64
	db.Model(&models.AuthToken{}).Where("is_admin = ?", true).Count(&count)
	if count == 0 {
		log.Info("No admin token found, attempting to seed one...")
		const adminTokenPlaintext = "admin-secret-token" // The default token
		// 1. Encrypt and hash the token
		encryptedToken, err := cryptoService.Encrypt(adminTokenPlaintext)
		if err != nil {
			log.Fatalf("FATAL: Failed to encrypt default admin token during seeding: %v. Server cannot start.", err)
			return
		}
		hash := sha256.Sum256([]byte(adminTokenPlaintext))
		tokenHash := hex.EncodeToString(hash[:])
		// 2. Use the repository to seed the token.
		// Note: we create a temporary repository instance here just for the seeder.
		repo := repository.NewAuthTokenRepository(db, cryptoService, logger)
		if err := repo.SeedAdminToken(encryptedToken, tokenHash); err != nil {
			log.Warnf("Failed to seed admin token using repository: %v", err)
		} else {
			log.Infof("Default admin token has been seeded successfully. Please use '%s' for your first login.", adminTokenPlaintext)
		}
	} else {
		log.Info("Admin token already exists, seeder skipped.")
	}

	// This functionality should be replaced by a proper user/token management UI in the future.
	linkAllKeysToDefaultGroup(db, log)
}

// linkAllKeysToDefaultGroup ensures every key belongs to at least one group.
func linkAllKeysToDefaultGroup(db *gorm.DB, logger *logrus.Entry) {
	logger.Info("Linking existing API keys to the default group as a fallback...")
	// 1. Find a default group (the first one, for simplicity)
	var defaultGroup models.KeyGroup
	if err := db.Order("id asc").First(&defaultGroup).Error; err != nil {
		logger.Warnf("Seeder: Could not find a default key group to link keys to: %v", err)
		return
	}
	// 2. Find all "orphan keys" that don't belong to any group
	var orphanKeys []*models.APIKey
	err := db.Raw(`
        SELECT * FROM api_keys
        WHERE id NOT IN (SELECT DISTINCT api_key_id FROM group_api_key_mappings)
        AND deleted_at IS NULL
    `).Scan(&orphanKeys).Error
	if err != nil {
		logger.Errorf("Seeder: Failed to query for orphan keys: %v", err)
		return
	}
	if len(orphanKeys) == 0 {
		logger.Info("Seeder: No orphan API keys found to link.")
		return
	}
	// 3. Create GroupAPIKeyMapping records manually
	logger.Infof("Seeder: Found %d orphan keys. Creating mappings for them in group '%s' (ID: %d)...", len(orphanKeys), defaultGroup.Name, defaultGroup.ID)
	var newMappings []models.GroupAPIKeyMapping
	for _, key := range orphanKeys {
		newMappings = append(newMappings, models.GroupAPIKeyMapping{
			KeyGroupID: defaultGroup.ID,
			APIKeyID:   key.ID,
		})
	}
	if err := db.Create(&newMappings).Error; err != nil {
		logger.Errorf("Seeder: Failed to create key mappings for orphan keys: %v", err)
	} else {
		logger.Info("Successfully created mappings for orphan API keys.")
	}
}
internal/domain/proxy/errors.go (new file, 8 lines)
@@ -0,0 +1,8 @@
package proxy

import "errors"

var (
	ErrNoActiveProxies = errors.New("no active proxies available in the pool")
	ErrTaskConflict    = errors.New("a sync task is already in progress for proxies")
)
internal/domain/proxy/handler.go (new file, 269 lines)
@@ -0,0 +1,269 @@
// Filename: internal/domain/proxy/handler.go
package proxy

import (
	"encoding/json"
	"strconv"
	"time"

	"gemini-balancer/internal/errors"
	"gemini-balancer/internal/models"
	"gemini-balancer/internal/response"
	"gemini-balancer/internal/settings"
	"gemini-balancer/internal/store"

	"github.com/gin-gonic/gin"
	"gorm.io/gorm"
)

type handler struct {
	db       *gorm.DB
	manager  *manager
	store    store.Store
	settings *settings.SettingsManager
}

func newHandler(db *gorm.DB, m *manager, s store.Store, sp *settings.SettingsManager) *handler {
	return &handler{
		db:       db,
		manager:  m,
		store:    s,
		settings: sp,
	}
}

// === Public API exposed by this domain ===

func (h *handler) registerRoutes(rg *gin.RouterGroup) {
	proxyRoutes := rg.Group("/proxies")
	{
		proxyRoutes.PUT("/sync", h.SyncProxies)
		proxyRoutes.POST("/check", h.CheckSingleProxy)
		proxyRoutes.POST("/check-all", h.CheckAllProxies)

		proxyRoutes.POST("/", h.CreateProxyConfig)
		proxyRoutes.GET("/", h.ListProxyConfigs)
		proxyRoutes.GET("/:id", h.GetProxyConfig)
		proxyRoutes.PUT("/:id", h.UpdateProxyConfig)
		proxyRoutes.DELETE("/:id", h.DeleteProxyConfig)
	}
}

// --- Request DTOs ---
type CreateProxyConfigRequest struct {
	Address     string `json:"address" binding:"required"`
	Protocol    string `json:"protocol" binding:"required,oneof=http https socks5"`
	Status      string `json:"status" binding:"omitempty,oneof=active inactive"`
	Description string `json:"description"`
}

type UpdateProxyConfigRequest struct {
	Address     *string `json:"address"`
	Protocol    *string `json:"protocol" binding:"omitempty,oneof=http https socks5"`
	Status      *string `json:"status" binding:"omitempty,oneof=active inactive"`
	Description *string `json:"description"`
}

// Request body for checking a single proxy (aligned with the frontend JS).
type CheckSingleProxyRequest struct {
	Proxy string `json:"proxy" binding:"required"`
}

// Request body for checking proxies in bulk.
type CheckAllProxiesRequest struct {
	Proxies []string `json:"proxies" binding:"required"`
}

// --- Handler methods ---

func (h *handler) CreateProxyConfig(c *gin.Context) {
	var req CreateProxyConfigRequest
	if err := c.ShouldBindJSON(&req); err != nil {
		response.Error(c, errors.NewAPIError(errors.ErrInvalidJSON, err.Error()))
		return
	}

	if req.Status == "" {
		req.Status = "active" // default status
	}

	proxyConfig := models.ProxyConfig{
		Address:     req.Address,
		Protocol:    req.Protocol,
		Status:      req.Status,
		Description: req.Description,
	}

	if err := h.db.Create(&proxyConfig).Error; err != nil {
		response.Error(c, errors.ParseDBError(err))
		return
	}
	// After a write, publish the event and invalidate the cache.
	h.publishAndInvalidate(proxyConfig.ID, "created")
	response.Created(c, proxyConfig)
}

func (h *handler) ListProxyConfigs(c *gin.Context) {
	var proxyConfigs []models.ProxyConfig
	if err := h.db.Find(&proxyConfigs).Error; err != nil {
		response.Error(c, errors.ParseDBError(err))
		return
	}
	response.Success(c, proxyConfigs)
}

func (h *handler) GetProxyConfig(c *gin.Context) {
	id, err := strconv.Atoi(c.Param("id"))
	if err != nil {
		response.Error(c, errors.NewAPIError(errors.ErrBadRequest, "Invalid ID format"))
		return
	}

	var proxyConfig models.ProxyConfig
	if err := h.db.First(&proxyConfig, id).Error; err != nil {
		response.Error(c, errors.ParseDBError(err))
		return
	}
	response.Success(c, proxyConfig)
}

func (h *handler) UpdateProxyConfig(c *gin.Context) {
	id, err := strconv.Atoi(c.Param("id"))
	if err != nil {
		response.Error(c, errors.NewAPIError(errors.ErrBadRequest, "Invalid ID format"))
		return
	}

	var req UpdateProxyConfigRequest
	if err := c.ShouldBindJSON(&req); err != nil {
		response.Error(c, errors.NewAPIError(errors.ErrInvalidJSON, err.Error()))
		return
	}

	var proxyConfig models.ProxyConfig
	if err := h.db.First(&proxyConfig, id).Error; err != nil {
		response.Error(c, errors.ParseDBError(err))
		return
	}

	if req.Address != nil {
		proxyConfig.Address = *req.Address
	}
	if req.Protocol != nil {
		proxyConfig.Protocol = *req.Protocol
	}
	if req.Status != nil {
		proxyConfig.Status = *req.Status
	}
	if req.Description != nil {
		proxyConfig.Description = *req.Description
	}

	if err := h.db.Save(&proxyConfig).Error; err != nil {
		response.Error(c, errors.ParseDBError(err))
		return
	}

	h.publishAndInvalidate(uint(id), "updated")
	response.Success(c, proxyConfig)
}

func (h *handler) DeleteProxyConfig(c *gin.Context) {
	id, err := strconv.Atoi(c.Param("id"))
	if err != nil {
		response.Error(c, errors.NewAPIError(errors.ErrBadRequest, "Invalid ID format"))
		return
	}

	var count int64
	if err := h.db.Model(&models.APIKey{}).Where("proxy_id = ?", id).Count(&count).Error; err != nil {
		response.Error(c, errors.ParseDBError(err))
		return
	}
	if count > 0 {
		response.Error(c, errors.NewAPIError(errors.ErrDuplicateResource, "Cannot delete proxy config that is still in use by API keys"))
		return
	}

	result := h.db.Delete(&models.ProxyConfig{}, id)
	if result.Error != nil {
		response.Error(c, errors.ParseDBError(result.Error))
		return
	}
	if result.RowsAffected == 0 {
		response.Error(c, errors.ErrResourceNotFound)
		return
	}

	h.publishAndInvalidate(uint(id), "deleted")
	response.NoContent(c)
}

// publishAndInvalidate centralizes event publishing and cache invalidation.
func (h *handler) publishAndInvalidate(proxyID uint, action string) {
	go h.manager.invalidate()
	go func() {
		event := models.ProxyStatusChangedEvent{ProxyID: proxyID, Action: action}
		eventData, _ := json.Marshal(event)
		_ = h.store.Publish(models.TopicProxyStatusChanged, eventData)
	}()
}

// New handler methods and DTOs
type SyncProxiesRequest struct {
	Proxies []string `json:"proxies"`
}

func (h *handler) SyncProxies(c *gin.Context) {
	var req SyncProxiesRequest
	if err := c.ShouldBindJSON(&req); err != nil {
		response.Error(c, errors.NewAPIError(errors.ErrInvalidJSON, err.Error()))
		return
	}
	taskStatus, err := h.manager.SyncProxiesInBackground(req.Proxies)
	if err != nil {
		if errors.Is(err, ErrTaskConflict) {
			response.Error(c, errors.NewAPIError(errors.ErrTaskInProgress, err.Error()))
		} else {
			response.Error(c, errors.NewAPIError(errors.ErrInternalServer, err.Error()))
		}
		return
	}
	response.Success(c, gin.H{
		"message": "Proxy synchronization task started.",
		"task":    taskStatus,
	})
}

func (h *handler) CheckSingleProxy(c *gin.Context) {
	var req CheckSingleProxyRequest
	if err := c.ShouldBindJSON(&req); err != nil {
		response.Error(c, errors.NewAPIError(errors.ErrInvalidJSON, err.Error()))
		return
	}

	cfg := h.settings.GetSettings()
	timeout := time.Duration(cfg.ProxyCheckTimeoutSeconds) * time.Second
	result := h.manager.CheckSingleProxy(req.Proxy, timeout)
	response.Success(c, result)
}

func (h *handler) CheckAllProxies(c *gin.Context) {
	var req CheckAllProxiesRequest
	if err := c.ShouldBindJSON(&req); err != nil {
		response.Error(c, errors.NewAPIError(errors.ErrInvalidJSON, err.Error()))
		return
	}
	cfg := h.settings.GetSettings()
	timeout := time.Duration(cfg.ProxyCheckTimeoutSeconds) * time.Second

	concurrency := cfg.ProxyCheckConcurrency
	if concurrency <= 0 {
		concurrency = 5 // fall back to a safe default if the configured value is invalid
	}
	results := h.manager.CheckMultipleProxies(req.Proxies, timeout, concurrency)
	response.Success(c, results)
}
315
internal/domain/proxy/manager.go
Normal file
@@ -0,0 +1,315 @@
// Filename: internal/domain/proxy/manager.go
package proxy

import (
	"encoding/json"
	"fmt"
	"gemini-balancer/internal/models"
	"gemini-balancer/internal/store"
	"gemini-balancer/internal/syncer"
	"gemini-balancer/internal/task"

	"context"
	"net"
	"net/http"
	"net/url"
	"strings"
	"sync"
	"time"

	"github.com/gin-gonic/gin"
	"github.com/sirupsen/logrus"
	"golang.org/x/net/proxy"
	"gorm.io/gorm"
)

const (
	TaskTypeProxySync = "proxy_sync"
	proxyChunkSize    = 200 // Batch size for proxy synchronization.
)

type ProxyCheckResult struct {
	Proxy        string  `json:"proxy"`
	IsAvailable  bool    `json:"is_available"`
	ResponseTime float64 `json:"response_time"`
	ErrorMessage string  `json:"error_message"`
}

// managerCacheData
type managerCacheData struct {
	ActiveProxies []*models.ProxyConfig
	ProxiesByID   map[uint]*models.ProxyConfig
}

// manager struct
type manager struct {
	db     *gorm.DB
	syncer *syncer.CacheSyncer[managerCacheData]
	task   task.Reporter
	store  store.Store
	logger *logrus.Entry
}

func newManagerLoader(db *gorm.DB) syncer.LoaderFunc[managerCacheData] {
	return func() (managerCacheData, error) {
		var activeProxies []*models.ProxyConfig
		if err := db.Where("status = ?", "active").Order("assigned_keys_count asc").Find(&activeProxies).Error; err != nil {
			return managerCacheData{}, fmt.Errorf("failed to load active proxies for cache: %w", err)
		}

		proxiesByID := make(map[uint]*models.ProxyConfig, len(activeProxies))
		for _, proxy := range activeProxies {
			p := *proxy
			proxiesByID[p.ID] = &p
		}

		return managerCacheData{
			ActiveProxies: activeProxies,
			ProxiesByID:   proxiesByID,
		}, nil
	}
}

func newManager(db *gorm.DB, syncer *syncer.CacheSyncer[managerCacheData], taskReporter task.Reporter, store store.Store, logger *logrus.Entry) *manager {
	return &manager{
		db:     db,
		syncer: syncer,
		task:   taskReporter,
		store:  store,
		logger: logger,
	}
}

func (m *manager) SyncProxiesInBackground(proxyStrings []string) (*task.Status, error) {
	resourceID := "global_proxy_sync"
	taskStatus, err := m.task.StartTask(0, TaskTypeProxySync, resourceID, len(proxyStrings), 0)
	if err != nil {
		return nil, ErrTaskConflict
	}
	go m.runProxySyncTask(taskStatus.ID, proxyStrings)
	return taskStatus, nil
}

func (m *manager) runProxySyncTask(taskID string, finalProxyStrings []string) {
	resourceID := "global_proxy_sync"
	var allProxies []models.ProxyConfig
	if err := m.db.Find(&allProxies).Error; err != nil {
		m.task.EndTaskByID(taskID, resourceID, nil, fmt.Errorf("failed to fetch current proxies: %w", err))
		return
	}
	currentProxyMap := make(map[string]uint)
	for _, p := range allProxies {
		fullString := fmt.Sprintf("%s://%s", p.Protocol, p.Address)
		currentProxyMap[fullString] = p.ID
	}
	finalProxyMap := make(map[string]bool)
	for _, ps := range finalProxyStrings {
		finalProxyMap[strings.TrimSpace(ps)] = true
	}
	var idsToDelete []uint
	var proxiesToAdd []models.ProxyConfig
	for proxyString, id := range currentProxyMap {
		if !finalProxyMap[proxyString] {
			idsToDelete = append(idsToDelete, id)
		}
	}
	for proxyString := range finalProxyMap {
		if _, exists := currentProxyMap[proxyString]; !exists {
			parsed := parseProxyString(proxyString)
			if parsed != nil {
				proxiesToAdd = append(proxiesToAdd, models.ProxyConfig{
					Protocol: parsed.Protocol, Address: parsed.Address, Status: "active",
				})
			}
		}
	}
	if len(idsToDelete) > 0 {
		if err := m.bulkDeleteByIDs(idsToDelete); err != nil {
			m.task.EndTaskByID(taskID, resourceID, nil, fmt.Errorf("failed during proxy deletion: %w", err))
			return
		}
	}
	if len(proxiesToAdd) > 0 {
		if err := m.bulkAdd(proxiesToAdd); err != nil {
			m.task.EndTaskByID(taskID, resourceID, nil, fmt.Errorf("failed during proxy addition: %w", err))
			return
		}
	}
	result := gin.H{"added": len(proxiesToAdd), "deleted": len(idsToDelete), "final_total": len(finalProxyMap)}
	m.task.EndTaskByID(taskID, resourceID, result, nil)
	m.publishChangeEvent("proxies_synced")
	go m.invalidate()
}

type parsedProxy struct{ Protocol, Address string }

func parseProxyString(proxyStr string) *parsedProxy {
	proxyStr = strings.TrimSpace(proxyStr)
	u, err := url.Parse(proxyStr)
	if err != nil || !strings.Contains(proxyStr, "://") {
		if strings.Contains(proxyStr, "@") {
			parts := strings.Split(proxyStr, "@")
			if len(parts) == 2 {
				proxyStr = "socks5://" + proxyStr
				u, err = url.Parse(proxyStr)
				if err != nil {
					return nil
				}
			}
		} else {
			return nil
		}
	}
	if u == nil {
		// Guard against a nil dereference: the socks5 fallback above only runs for
		// exactly one "@", so an unparsable string with multiple "@" could reach here.
		return nil
	}
	protocol := strings.ToLower(u.Scheme)
	if protocol != "http" && protocol != "https" && protocol != "socks5" {
		return nil
	}
	address := u.Host
	if u.User != nil {
		address = u.User.String() + "@" + u.Host
	}
	return &parsedProxy{Protocol: protocol, Address: address}
}
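For reference, how parseProxyString treats a few representative inputs (derived from the code above; the addresses are illustrative):

// parseProxyString("http://1.2.3.4:8080")     -> &parsedProxy{Protocol: "http", Address: "1.2.3.4:8080"}
// parseProxyString("user:pass@5.6.7.8:1080")  -> &parsedProxy{Protocol: "socks5", Address: "user:pass@5.6.7.8:1080"}
//                                                (bare credentials without a scheme default to socks5)
// parseProxyString("ftp://1.2.3.4:21")        -> nil (unsupported scheme)
// parseProxyString("1.2.3.4:8080")            -> nil (no scheme and no "@" to trigger the socks5 fallback)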
func (m *manager) bulkDeleteByIDs(ids []uint) error {
	for i := 0; i < len(ids); i += proxyChunkSize {
		end := i + proxyChunkSize
		if end > len(ids) {
			end = len(ids)
		}
		chunk := ids[i:end]
		if err := m.db.Where("id IN ?", chunk).Delete(&models.ProxyConfig{}).Error; err != nil {
			return err
		}
	}
	return nil
}

func (m *manager) bulkAdd(proxies []models.ProxyConfig) error {
	return m.db.CreateInBatches(proxies, proxyChunkSize).Error
}

func (m *manager) publishChangeEvent(reason string) {
	event := models.ProxyStatusChangedEvent{Action: reason}
	eventData, _ := json.Marshal(event)
	_ = m.store.Publish(models.TopicProxyStatusChanged, eventData)
}

func (m *manager) assignProxyIfNeeded(apiKey *models.APIKey) (*models.ProxyConfig, error) {
	cacheData := m.syncer.Get()
	if cacheData.ActiveProxies == nil {
		return nil, ErrNoActiveProxies
	}
	if apiKey.ProxyID != nil {
		if proxy, ok := cacheData.ProxiesByID[*apiKey.ProxyID]; ok {
			return proxy, nil
		}
	}
	if len(cacheData.ActiveProxies) == 0 {
		return nil, ErrNoActiveProxies
	}
	bestProxy := cacheData.ActiveProxies[0]
	txErr := m.db.Transaction(func(tx *gorm.DB) error {
		if err := tx.Model(apiKey).Update("proxy_id", bestProxy.ID).Error; err != nil {
			return err
		}
		if err := tx.Model(bestProxy).Update("assigned_keys_count", gorm.Expr("assigned_keys_count + 1")).Error; err != nil {
			return err
		}
		return nil
	})
	if txErr != nil {
		return nil, txErr
	}
	go m.invalidate()
	return bestProxy, nil
}

func (m *manager) invalidate() error {
	m.logger.Info("Proxy cache invalidation triggered.")
	return m.syncer.Invalidate()
}

func (m *manager) stop() {
	m.syncer.Stop()
}
func (m *manager) CheckSingleProxy(proxyURL string, timeout time.Duration) *ProxyCheckResult {
	parsed := parseProxyString(proxyURL)
	if parsed == nil {
		return &ProxyCheckResult{Proxy: proxyURL, IsAvailable: false, ErrorMessage: "Invalid URL format"}
	}

	proxyCfg := &models.ProxyConfig{Protocol: parsed.Protocol, Address: parsed.Address}

	startTime := time.Now()
	isAlive := m.checkProxyConnectivity(proxyCfg, timeout)
	latency := time.Since(startTime).Seconds()
	result := &ProxyCheckResult{
		Proxy:        proxyURL,
		IsAvailable:  isAlive,
		ResponseTime: latency,
	}
	if !isAlive {
		result.ErrorMessage = "Connection failed or timed out"
	}
	return result
}

func (m *manager) CheckMultipleProxies(proxies []string, timeout time.Duration, concurrency int) []*ProxyCheckResult {
	var wg sync.WaitGroup
	jobs := make(chan string, len(proxies))
	resultsChan := make(chan *ProxyCheckResult, len(proxies))
	for i := 0; i < concurrency; i++ {
		wg.Add(1)
		go func() {
			defer wg.Done()
			for proxyURL := range jobs {
				resultsChan <- m.CheckSingleProxy(proxyURL, timeout)
			}
		}()
	}
	for _, p := range proxies {
		jobs <- p
	}
	close(jobs)
	wg.Wait()
	close(resultsChan)
	finalResults := make([]*ProxyCheckResult, 0, len(proxies))
	for res := range resultsChan {
		finalResults = append(finalResults, res)
	}
	return finalResults
}

func (m *manager) checkProxyConnectivity(proxyCfg *models.ProxyConfig, timeout time.Duration) bool {
	const ProxyCheckTargetURL = "https://www.google.com/generate_204"
	transport := &http.Transport{}
	switch proxyCfg.Protocol {
	case "http", "https":
		proxyUrl, err := url.Parse(proxyCfg.Protocol + "://" + proxyCfg.Address)
		if err != nil {
			return false
		}
		transport.Proxy = http.ProxyURL(proxyUrl)
	case "socks5":
		dialer, err := proxy.SOCKS5("tcp", proxyCfg.Address, nil, proxy.Direct)
		if err != nil {
			return false
		}
		transport.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) {
			return dialer.Dial(network, addr)
		}
	default:
		return false
	}
	client := &http.Client{
		Transport: transport,
		Timeout:   timeout,
	}
	resp, err := client.Get(ProxyCheckTargetURL)
	if err != nil {
		return false
	}
	defer resp.Body.Close()
	return true
}
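A minimal caller sketch for the concurrent checker above, from inside the proxy package (the proxy list and the 10-worker / 5-second budget are illustrative, not configuration defaults):

// Probe a batch of proxies with a bounded worker pool.
proxies := []string{"http://1.2.3.4:8080", "socks5://5.6.7.8:1080"}
results := m.CheckMultipleProxies(proxies, 5*time.Second, 10)
for _, r := range results {
	if !r.IsAvailable {
		m.logger.Warnf("proxy %s failed: %s", r.Proxy, r.ErrorMessage)
	}
}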
45
internal/domain/proxy/module.go
Normal file
@@ -0,0 +1,45 @@
// Filename: internal/domain/proxy/module.go
package proxy

import (
	"gemini-balancer/internal/models"
	"gemini-balancer/internal/settings"
	"gemini-balancer/internal/store"
	"gemini-balancer/internal/syncer"
	"gemini-balancer/internal/task"

	"github.com/gin-gonic/gin"
	"github.com/sirupsen/logrus"
	"gorm.io/gorm"
)

type Module struct {
	manager *manager
	handler *handler
}

func NewModule(gormDB *gorm.DB, store store.Store, settingsManager *settings.SettingsManager, taskReporter task.Reporter, logger *logrus.Logger) (*Module, error) {
	loader := newManagerLoader(gormDB)
	cacheSyncer, err := syncer.NewCacheSyncer(loader, store, "proxies:cache_invalidation")
	if err != nil {
		return nil, err
	}
	manager := newManager(gormDB, cacheSyncer, taskReporter, store, logger.WithField("domain", "proxy"))
	handler := newHandler(gormDB, manager, store, settingsManager)
	return &Module{
		manager: manager,
		handler: handler,
	}, nil
}

func (m *Module) AssignProxyIfNeeded(apiKey *models.APIKey) (*models.ProxyConfig, error) {
	return m.manager.assignProxyIfNeeded(apiKey)
}

func (m *Module) RegisterRoutes(router *gin.RouterGroup) {
	m.handler.registerRoutes(router)
}

func (m *Module) Stop() {
	m.manager.stop()
}
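A rough wiring sketch for this module, assuming a bootstrap similar to internal/app (the variable names and route group are illustrative, not the actual app code):

proxyModule, err := proxy.NewModule(db, redisStore, settingsManager, taskReporter, logger)
if err != nil {
	logger.Fatalf("failed to build proxy module: %v", err)
}
proxyModule.RegisterRoutes(apiGroup) // e.g. an authenticated /api/v1 route group
defer proxyModule.Stop()             // stops the underlying cache syncer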
167
internal/domain/upstream/handler.go
Normal file
@@ -0,0 +1,167 @@
// Filename: internal/domain/upstream/handler.go
package upstream

import (
	"gemini-balancer/internal/errors"
	"gemini-balancer/internal/models"
	"gemini-balancer/internal/response"
	"net/url"
	"strconv"

	"github.com/gin-gonic/gin"
)

type Handler struct {
	service *Service
}

func NewHandler(service *Service) *Handler {
	return &Handler{service: service}
}

// ------ DTOs and Validation ------
type CreateUpstreamRequest struct {
	URL         string `json:"url" binding:"required"`
	Weight      int    `json:"weight" binding:"omitempty,gte=1,lte=1000"`
	Status      string `json:"status" binding:"omitempty,oneof=active inactive"`
	Description string `json:"description"`
}

type UpdateUpstreamRequest struct {
	URL         *string `json:"url"`
	Weight      *int    `json:"weight" binding:"omitempty,gte=1,lte=1000"`
	Status      *string `json:"status" binding:"omitempty,oneof=active inactive"`
	Description *string `json:"description"`
}

func isValidURL(rawURL string) bool {
	u, err := url.ParseRequestURI(rawURL)
	if err != nil {
		return false
	}
	return u.Scheme == "http" || u.Scheme == "https"
}

// --- Handler ---

func (h *Handler) CreateUpstream(c *gin.Context) {
	var req CreateUpstreamRequest
	if err := c.ShouldBindJSON(&req); err != nil {
		response.Error(c, errors.NewAPIError(errors.ErrInvalidJSON, err.Error()))
		return
	}
	if !isValidURL(req.URL) {
		response.Error(c, errors.NewAPIError(errors.ErrValidation, "Invalid URL format"))
		return
	}

	upstream := models.UpstreamEndpoint{
		URL:         req.URL,
		Weight:      req.Weight,
		Status:      req.Status,
		Description: req.Description,
	}

	if err := h.service.Create(&upstream); err != nil {
		response.Error(c, errors.ParseDBError(err))
		return
	}
	response.Created(c, upstream)
}

func (h *Handler) ListUpstreams(c *gin.Context) {
	upstreams, err := h.service.List()
	if err != nil {
		response.Error(c, errors.ParseDBError(err))
		return
	}
	response.Success(c, upstreams)
}

func (h *Handler) GetUpstream(c *gin.Context) {
	id, err := strconv.Atoi(c.Param("id"))
	if err != nil {
		response.Error(c, errors.NewAPIError(errors.ErrBadRequest, "Invalid ID format"))
		return
	}

	upstream, err := h.service.GetByID(id)
	if err != nil {
		response.Error(c, errors.ParseDBError(err))
		return
	}
	response.Success(c, upstream)
}

func (h *Handler) UpdateUpstream(c *gin.Context) {
	id, err := strconv.Atoi(c.Param("id"))
	if err != nil {
		response.Error(c, errors.NewAPIError(errors.ErrBadRequest, "Invalid ID format"))
		return
	}
	var req UpdateUpstreamRequest
	if err := c.ShouldBindJSON(&req); err != nil {
		response.Error(c, errors.NewAPIError(errors.ErrInvalidJSON, err.Error()))
		return
	}

	upstream, err := h.service.GetByID(id)
	if err != nil {
		response.Error(c, errors.ParseDBError(err))
		return
	}

	if req.URL != nil {
		if !isValidURL(*req.URL) {
			response.Error(c, errors.NewAPIError(errors.ErrValidation, "Invalid URL format"))
			return
		}
		upstream.URL = *req.URL
	}
	if req.Weight != nil {
		upstream.Weight = *req.Weight
	}
	if req.Status != nil {
		upstream.Status = *req.Status
	}
	if req.Description != nil {
		upstream.Description = *req.Description
	}

	if err := h.service.Update(upstream); err != nil {
		response.Error(c, errors.ParseDBError(err))
		return
	}
	response.Success(c, upstream)
}

func (h *Handler) DeleteUpstream(c *gin.Context) {
	id, err := strconv.Atoi(c.Param("id"))
	if err != nil {
		response.Error(c, errors.NewAPIError(errors.ErrBadRequest, "Invalid ID format"))
		return
	}

	rowsAffected, err := h.service.Delete(id)
	if err != nil {
		response.Error(c, errors.ParseDBError(err))
		return
	}
	if rowsAffected == 0 {
		response.Error(c, errors.ErrResourceNotFound)
		return
	}
	response.NoContent(c)
}

// RegisterRoutes

func (h *Handler) RegisterRoutes(rg *gin.RouterGroup) {
	upstreamRoutes := rg.Group("/upstreams")
	{
		upstreamRoutes.POST("/", h.CreateUpstream)
		upstreamRoutes.GET("/", h.ListUpstreams)
		upstreamRoutes.GET("/:id", h.GetUpstream)
		upstreamRoutes.PUT("/:id", h.UpdateUpstream)
		upstreamRoutes.DELETE("/:id", h.DeleteUpstream)
	}
}
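One subtlety worth noting: a create request that omits Weight passes binding (omitempty skips the zero value), and Service.Create then fills in the default of 100. A hypothetical payload, shown as a comment:

// POST /upstreams/ — Weight and Status omitted, so the service defaults
// them to 100 and "active" respectively (see service.go below):
// {"url": "https://generativelanguage.googleapis.com", "description": "primary"}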
36
internal/domain/upstream/module.go
Normal file
@@ -0,0 +1,36 @@
// Filename: internal/domain/upstream/module.go
package upstream

import (
	"gemini-balancer/internal/models"

	"github.com/gin-gonic/gin"
	"gorm.io/gorm"
)

type Module struct {
	service *Service
	handler *Handler
}

func NewModule(db *gorm.DB) *Module {
	service := NewService(db)
	handler := NewHandler(service)

	return &Module{
		service: service,
		handler: handler,
	}
}

// === Public API exposed by the domain ===

// SelectActiveWeighted

func (m *Module) SelectActiveWeighted(upstreams []*models.UpstreamEndpoint) (*models.UpstreamEndpoint, error) {
	return m.service.SelectActiveWeighted(upstreams)
}

func (m *Module) RegisterRoutes(router *gin.RouterGroup) {
	m.handler.RegisterRoutes(router)
}
84
internal/domain/upstream/service.go
Normal file
@@ -0,0 +1,84 @@
// Filename: internal/domain/upstream/service.go
package upstream

import (
	"errors"
	"gemini-balancer/internal/models"
	"math/rand"
	"time"

	"gorm.io/gorm"
)

type Service struct {
	db *gorm.DB
}

func NewService(db *gorm.DB) *Service {
	// Seed the package-level RNG used by SelectActiveWeighted. The original
	// constructed a *rand.Rand and discarded it, which seeded nothing.
	// (On Go 1.20+ the global source is auto-seeded, making this redundant.)
	rand.Seed(time.Now().UnixNano())
	return &Service{db: db}
}

func (s *Service) SelectActiveWeighted(upstreams []*models.UpstreamEndpoint) (*models.UpstreamEndpoint, error) {
	activeUpstreams := make([]*models.UpstreamEndpoint, 0)
	totalWeight := 0
	for _, u := range upstreams {
		if u.Status == "active" {
			activeUpstreams = append(activeUpstreams, u)
			totalWeight += u.Weight
		}
	}
	if len(activeUpstreams) == 0 {
		return nil, errors.New("no active upstream endpoints available")
	}
	if totalWeight <= 0 || len(activeUpstreams) == 1 {
		return activeUpstreams[0], nil
	}
	randomWeight := rand.Intn(totalWeight)
	for _, u := range activeUpstreams {
		randomWeight -= u.Weight
		if randomWeight < 0 {
			return u, nil
		}
	}
	return activeUpstreams[len(activeUpstreams)-1], nil
}
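A worked example of the selection math above: with two active endpoints weighted 70 and 30, rand.Intn(100) yields r in [0, 100); any r in [0, 70) makes r - 70 negative at the first endpoint, and r in [70, 100) falls through to the second, giving the intended 70/30 split. An illustrative tally (svc is a *Service; URLs are made up):

a := &models.UpstreamEndpoint{URL: "https://a.example", Status: "active", Weight: 70}
b := &models.UpstreamEndpoint{URL: "https://b.example", Status: "active", Weight: 30}
picks := map[string]int{}
for i := 0; i < 10000; i++ {
	u, _ := svc.SelectActiveWeighted([]*models.UpstreamEndpoint{a, b})
	picks[u.URL]++
}
// Expect picks["https://a.example"] ≈ 7000 and picks["https://b.example"] ≈ 3000.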
// CRUD methods, called by the Handler.

func (s *Service) Create(upstream *models.UpstreamEndpoint) error {
	if upstream.Weight == 0 {
		upstream.Weight = 100 // Default weight.
	}
	if upstream.Status == "" {
		upstream.Status = "active" // Default status.
	}
	return s.db.Create(upstream).Error
}

// List performs only the database query at the service layer.
func (s *Service) List() ([]models.UpstreamEndpoint, error) {
	var upstreams []models.UpstreamEndpoint
	err := s.db.Find(&upstreams).Error
	return upstreams, err
}

// GetByID performs only the database query at the service layer.
func (s *Service) GetByID(id int) (*models.UpstreamEndpoint, error) {
	var upstream models.UpstreamEndpoint
	if err := s.db.First(&upstream, id).Error; err != nil {
		return nil, err
	}
	return &upstream, nil
}

// Update performs only the database update at the service layer.
func (s *Service) Update(upstream *models.UpstreamEndpoint) error {
	return s.db.Save(upstream).Error
}

// Delete performs only the database delete at the service layer.
func (s *Service) Delete(id int) (int64, error) {
	result := s.db.Delete(&models.UpstreamEndpoint{}, id)
	return result.RowsAffected, result.Error
}
101
internal/errors/api_error.go
Normal file
@@ -0,0 +1,101 @@
// Filename: internal/errors/api_error.go
package errors

import (
	"errors"
	"net/http"
	"strings"

	"github.com/go-sql-driver/mysql"
	"github.com/jackc/pgx/v5/pgconn"
	"gorm.io/gorm"
)

// APIError defines a standard error structure for API responses.
type APIError struct {
	HTTPStatus int
	Code       string
	Message    string
}

// Error implements the error interface.
func (e *APIError) Error() string {
	return e.Message
}

// Predefined API errors
var (
	ErrBadRequest         = &APIError{HTTPStatus: http.StatusBadRequest, Code: "BAD_REQUEST", Message: "Invalid request parameters"}
	ErrInvalidJSON        = &APIError{HTTPStatus: http.StatusBadRequest, Code: "INVALID_JSON", Message: "Invalid JSON format"}
	ErrValidation         = &APIError{HTTPStatus: http.StatusBadRequest, Code: "VALIDATION_FAILED", Message: "Input validation failed"}
	ErrDuplicateResource  = &APIError{HTTPStatus: http.StatusConflict, Code: "DUPLICATE_RESOURCE", Message: "Resource already exists"}
	ErrResourceNotFound   = &APIError{HTTPStatus: http.StatusNotFound, Code: "NOT_FOUND", Message: "Resource not found"}
	ErrInternalServer     = &APIError{HTTPStatus: http.StatusInternalServerError, Code: "INTERNAL_SERVER_ERROR", Message: "An unexpected error occurred"}
	ErrDatabase           = &APIError{HTTPStatus: http.StatusInternalServerError, Code: "DATABASE_ERROR", Message: "Database operation failed"}
	ErrUnauthorized       = &APIError{HTTPStatus: http.StatusUnauthorized, Code: "UNAUTHORIZED", Message: "Authentication failed"}
	ErrForbidden          = &APIError{HTTPStatus: http.StatusForbidden, Code: "FORBIDDEN", Message: "You do not have permission to access this resource"}
	ErrTaskInProgress     = &APIError{HTTPStatus: http.StatusConflict, Code: "TASK_IN_PROGRESS", Message: "A task is already in progress"}
	ErrBadGateway         = &APIError{HTTPStatus: http.StatusBadGateway, Code: "BAD_GATEWAY", Message: "Upstream service error"}
	ErrNoActiveKeys       = &APIError{HTTPStatus: http.StatusServiceUnavailable, Code: "NO_ACTIVE_KEYS", Message: "No active API keys available for this group"}
	ErrMaxRetriesExceeded = &APIError{HTTPStatus: http.StatusBadGateway, Code: "MAX_RETRIES_EXCEEDED", Message: "Request failed after maximum retries"}
	ErrNoKeysAvailable    = &APIError{HTTPStatus: http.StatusServiceUnavailable, Code: "NO_KEYS_AVAILABLE", Message: "No API keys available to process the request"}

	ErrStateConflict      = &APIError{HTTPStatus: http.StatusConflict, Code: "STATE_CONFLICT", Message: "The operation cannot be completed due to the current state of the resource."}
	ErrGroupNotFound      = &APIError{HTTPStatus: http.StatusNotFound, Code: "GROUP_NOT_FOUND", Message: "The specified group was not found."}
	ErrPermissionDenied   = &APIError{HTTPStatus: http.StatusForbidden, Code: "PERMISSION_DENIED", Message: "Permission denied for this operation."}
	ErrConfigurationError = &APIError{HTTPStatus: http.StatusInternalServerError, Code: "CONFIGURATION_ERROR", Message: "A configuration error prevents this request from being processed."}

	ErrStateConflictMasterRevoked = &APIError{HTTPStatus: http.StatusConflict, Code: "STATE_CONFLICT_MASTER_REVOKED", Message: "Cannot perform this operation on a revoked key."}
	ErrNotFound                   = &APIError{HTTPStatus: http.StatusNotFound, Code: "NOT_FOUND", Message: "Resource not found"}
	ErrNoKeysMatchFilter          = &APIError{HTTPStatus: http.StatusBadRequest, Code: "NO_KEYS_MATCH_FILTER", Message: "No keys were found that match the provided filter criteria."}
)

// NewAPIError creates a new APIError with a custom message.
func NewAPIError(base *APIError, message string) *APIError {
	return &APIError{
		HTTPStatus: base.HTTPStatus,
		Code:       base.Code,
		Message:    message,
	}
}

// NewAPIErrorWithUpstream creates a new APIError specifically for wrapping raw upstream errors.
func NewAPIErrorWithUpstream(statusCode int, code string, upstreamMessage string) *APIError {
	return &APIError{
		HTTPStatus: statusCode,
		Code:       code,
		Message:    upstreamMessage,
	}
}

// ParseDBError intelligently converts a GORM error into a standard APIError.
func ParseDBError(err error) *APIError {
	if err == nil {
		return nil
	}

	if errors.Is(err, gorm.ErrRecordNotFound) {
		return ErrResourceNotFound
	}

	var pgErr *pgconn.PgError
	if errors.As(err, &pgErr) {
		if pgErr.Code == "23505" { // unique_violation
			return ErrDuplicateResource
		}
	}

	var mysqlErr *mysql.MySQLError
	if errors.As(err, &mysqlErr) {
		if mysqlErr.Number == 1062 { // Duplicate entry
			return ErrDuplicateResource
		}
	}

	// Generic check for SQLite.
	if strings.Contains(strings.ToLower(err.Error()), "unique constraint failed") {
		return ErrDuplicateResource
	}

	return ErrDatabase
}
19
internal/errors/errors.go
Normal file
@@ -0,0 +1,19 @@
// Filename: internal/errors/errors.go

package errors

import (
	std_errors "errors" // Alias for the standard library errors package.
)

func Is(err, target error) bool {
	return std_errors.Is(err, target)
}

func As(err error, target any) bool {
	return std_errors.As(err, target)
}

func Unwrap(err error) error {
	return std_errors.Unwrap(err)
}
111
internal/errors/upstream_errors.go
Normal file
@@ -0,0 +1,111 @@
// Filename: internal/errors/upstream_errors.go
package errors

import (
	"strings"
)

// TODO: [Future Evolution] This file establishes the new, granular error classification framework.
// The next step is to refactor the handleKeyUsageEvent method in APIKeyService to utilize these new
// classifiers and implement the corresponding actions:
//
// 1. On IsPermanentUpstreamError:
//    - Set mapping status to models.StatusBanned.
//    - Set the master APIKey's status to models.MasterStatusRevoked.
//    - This is a "one-strike, you're out" policy for definitively invalid keys.
//
// 2. On IsTemporaryUpstreamError:
//    - Increment mapping.ConsecutiveErrorCount.
//    - Check against the blacklist threshold to potentially set status to models.StatusDisabled.
//    - This is for recoverable errors that are the key's fault (e.g., quota limits).
//
// 3. On ALL other upstream errors (that are not Permanent or Temporary):
//    - These are treated as "Truly Ignorable" from the key's perspective (e.g., 503 Service Unavailable).
//    - Do NOT increment the error count. Only update LastUsedAt.
//    - This prevents good keys from being punished for upstream service instability.

// --- 1. Permanent Errors ---
// Errors that indicate the API Key itself is permanently invalid.
// Action: Ban mapping, Revoke Master Key.
var permanentErrorSubstrings = []string{
	"invalid api key",
	"api key not valid",
	"api key suspended",
	"api key not found", // Must be lowercase: containsSubstring lowercases only the haystack.
	"api key expired",
	"permission denied", // Often indicates the key lacks permissions for the target model/service.
	"permission_denied", // Catches the 'status' field in Google's JSON error, e.g., "status": "PERMISSION_DENIED".
	"service_disabled",  // Catches the 'reason' field for disabled APIs, e.g., "reason": "SERVICE_DISABLED".
	"api has not been used",
}

// --- 2. Temporary Errors ---
// Errors that are attributable to the key's state but are recoverable over time.
// Action: Increment consecutive error count, potentially disable the key.
var temporaryErrorSubstrings = []string{
	"quota",
	"limit reached",
	"insufficient",
	"billing",
	"exceeded",
	"too many requests",
}

// --- 3. Unretryable Request Errors ---
// Errors indicating a problem with the user's request, not the key. Retrying with a new key is pointless.
// Action: Abort the retry loop immediately in ProxyHandler.
var unretryableRequestErrorSubstrings = []string{
	"invalid content",
	"invalid argument",
	"malformed",
	"unsupported",
	"invalid model",
}

// --- 4. Ignorable Client/Network Errors ---
// Network-level errors, typically caused by the client disconnecting.
// Action: Ignore for logging and metrics purposes.
var clientNetworkErrorSubstrings = []string{
	"context canceled",
	"connection reset by peer",
	"broken pipe",
	"use of closed network connection",
	"request canceled",
}

// IsPermanentUpstreamError checks if an upstream error indicates the key is permanently invalid.
func IsPermanentUpstreamError(msg string) bool {
	return containsSubstring(msg, permanentErrorSubstrings)
}

// IsTemporaryUpstreamError checks if an upstream error is due to temporary, key-specific limits.
func IsTemporaryUpstreamError(msg string) bool {
	return containsSubstring(msg, temporaryErrorSubstrings)
}

// IsUnretryableRequestError checks if an upstream error is due to a malformed user request.
func IsUnretryableRequestError(msg string) bool {
	return containsSubstring(msg, unretryableRequestErrorSubstrings)
}

// IsClientNetworkError checks if an error is a common, ignorable client-side network issue.
func IsClientNetworkError(err error) bool {
	if err == nil {
		return false
	}
	return containsSubstring(err.Error(), clientNetworkErrorSubstrings)
}

// containsSubstring is a helper function to avoid code repetition.
func containsSubstring(s string, substrings []string) bool {
	if s == "" {
		return false
	}
	lowerS := strings.ToLower(s)
	for _, sub := range substrings {
		if strings.Contains(lowerS, sub) {
			return true
		}
	}
	return false
}
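A sketch of how these classifiers are intended to drive handleKeyUsageEvent, following the three-way policy in the TODO above. The mapping type name, its field spellings, and the function shape are assumptions for illustration, not the actual APIKeyService code:

// Hypothetical dispatch mirroring the TODO policy; types and fields are assumed.
func applyUpstreamErrorPolicy(mapping *models.GroupAPIKeyMapping, key *models.APIKey, upstreamMsg string, blacklistThreshold int) {
	switch {
	case IsPermanentUpstreamError(upstreamMsg):
		// One strike: the key is definitively invalid.
		mapping.Status = models.StatusBanned
		key.Status = models.MasterStatusRevoked
	case IsTemporaryUpstreamError(upstreamMsg):
		// Recoverable, key-attributable error (quota/billing): count it.
		mapping.ConsecutiveErrorCount++
		if mapping.ConsecutiveErrorCount >= blacklistThreshold {
			mapping.Status = models.StatusDisabled
		}
	default:
		// Upstream instability (e.g. 503): do not punish the key;
		// only LastUsedAt would be updated by the caller.
	}
}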
79
internal/errors/upstream_parser.go
Normal file
@@ -0,0 +1,79 @@
// Filename: internal/errors/upstream_parser.go
package errors

import (
	"encoding/json"
	"strings"
)

const (
	// maxErrorBodyLength defines the maximum length of an error message to be stored or returned.
	maxErrorBodyLength = 2048
)

// standardErrorResponse matches formats like: {"error": {"message": "..."}}
type standardErrorResponse struct {
	Error struct {
		Message string `json:"message"`
	} `json:"error"`
}

// vendorErrorResponse matches formats like: {"error_msg": "..."}
type vendorErrorResponse struct {
	ErrorMsg string `json:"error_msg"`
}

// simpleErrorResponse matches formats like: {"error": "..."}
type simpleErrorResponse struct {
	Error string `json:"error"`
}

// rootMessageErrorResponse matches formats like: {"message": "..."}
type rootMessageErrorResponse struct {
	Message string `json:"message"`
}

// ParseUpstreamError attempts to parse a structured error message from an upstream response body.
func ParseUpstreamError(body []byte) string {
	// 1. Attempt to parse the standard OpenAI/Gemini format.
	var stdErr standardErrorResponse
	if err := json.Unmarshal(body, &stdErr); err == nil {
		if msg := strings.TrimSpace(stdErr.Error.Message); msg != "" {
			return truncateString(msg, maxErrorBodyLength)
		}
	}

	// 2. Attempt to parse vendor-specific format (e.g., Baidu).
	var vendorErr vendorErrorResponse
	if err := json.Unmarshal(body, &vendorErr); err == nil {
		if msg := strings.TrimSpace(vendorErr.ErrorMsg); msg != "" {
			return truncateString(msg, maxErrorBodyLength)
		}
	}

	// 3. Attempt to parse simple error format.
	var simpleErr simpleErrorResponse
	if err := json.Unmarshal(body, &simpleErr); err == nil {
		if msg := strings.TrimSpace(simpleErr.Error); msg != "" {
			return truncateString(msg, maxErrorBodyLength)
		}
	}

	// 4. Attempt to parse root-level message format.
	var rootMsgErr rootMessageErrorResponse
	if err := json.Unmarshal(body, &rootMsgErr); err == nil {
		if msg := strings.TrimSpace(rootMsgErr.Message); msg != "" {
			return truncateString(msg, maxErrorBodyLength)
		}
	}

	// 5. Graceful degradation: if all parsing fails, return the raw (but length-capped) body.
	return truncateString(string(body), maxErrorBodyLength)
}

// truncateString ensures a string does not exceed a maximum length.
func truncateString(s string, maxLength int) string {
	if len(s) > maxLength {
		return s[:maxLength]
	}
	return s
}
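For illustration, how the fallback cascade behaves on a few typical bodies (the inputs are made up; the behavior follows directly from the code above):

ParseUpstreamError([]byte(`{"error": {"message": "API key expired"}}`)) // -> "API key expired" (format 1)
ParseUpstreamError([]byte(`{"error_msg": "quota exceeded"}`))           // -> "quota exceeded"   (format 2)
ParseUpstreamError([]byte(`{"error": "invalid model"}`))                // -> "invalid model"    (format 3)
ParseUpstreamError([]byte(`{"message": "bad gateway"}`))                // -> "bad gateway"      (format 4)
ParseUpstreamError([]byte(`<html>502</html>`))                          // -> raw body, capped at 2048 bytes (step 5)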
50
internal/handlers/api_auth_handler.go
Normal file
@@ -0,0 +1,50 @@
// Filename: internal/handlers/api_auth_handler.go
package handlers

import (
	"gemini-balancer/internal/middleware"
	"gemini-balancer/internal/service"
	"net/http"

	"github.com/gin-gonic/gin"
)

type APIAuthHandler struct {
	securityService *service.SecurityService
}

func NewAPIAuthHandler(securityService *service.SecurityService) *APIAuthHandler {
	return &APIAuthHandler{securityService: securityService}
}

type LoginRequest struct {
	Token string `json:"token" binding:"required"`
}

type LoginResponse struct {
	Token   string `json:"token"`
	Message string `json:"message"`
}

func (h *APIAuthHandler) HandleLogin(c *gin.Context) {
	var req LoginRequest
	if err := c.ShouldBindJSON(&req); err != nil {
		c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid request format: " + err.Error()})
		return
	}

	authToken, err := h.securityService.AuthenticateToken(req.Token)
	// Check both that the token is valid and that it belongs to an administrator.
	// (The || short-circuits, so authToken is only dereferenced when err is nil.)
	if err != nil || !authToken.IsAdmin {
		h.securityService.RecordFailedLoginAttempt(c.Request.Context(), c.ClientIP())
		c.JSON(http.StatusUnauthorized, gin.H{"error": "Invalid or non-administrator token"})
		return
	}

	middleware.SetAdminSessionCookie(c, authToken.Token)

	c.JSON(http.StatusOK, LoginResponse{
		Token:   authToken.Token,
		Message: "Login successful. Welcome, administrator!",
	})
}
408
internal/handlers/apikey_handler.go
Normal file
@@ -0,0 +1,408 @@
// Filename: internal/handlers/apikey_handler.go
package handlers

import (
	"gemini-balancer/internal/errors"
	"gemini-balancer/internal/models"
	"gemini-balancer/internal/response"
	"gemini-balancer/internal/service"
	"gemini-balancer/internal/task"
	"strconv"

	"github.com/gin-gonic/gin"
	"gorm.io/gorm"
)

type APIKeyHandler struct {
	apiKeyService        *service.APIKeyService
	db                   *gorm.DB
	keyImportService     *service.KeyImportService
	keyValidationService *service.KeyValidationService
}

func NewAPIKeyHandler(apiKeyService *service.APIKeyService, db *gorm.DB, keyImportService *service.KeyImportService, keyValidationService *service.KeyValidationService) *APIKeyHandler {
	return &APIKeyHandler{
		apiKeyService:        apiKeyService,
		db:                   db,
		keyImportService:     keyImportService,
		keyValidationService: keyValidationService,
	}
}

// DTOs for API requests
type BulkAddKeysToGroupRequest struct {
	KeyGroupID       uint   `json:"key_group_id" binding:"required"`
	Keys             string `json:"keys" binding:"required"`
	ValidateOnImport bool   `json:"validate_on_import"` // Defaults to false when omitted.
}

type BulkUnlinkKeysFromGroupRequest struct {
	KeyGroupID uint   `json:"key_group_id" binding:"required"`
	Keys       string `json:"keys" binding:"required"`
}

type BulkHardDeleteKeysRequest struct {
	Keys string `json:"keys" binding:"required"`
}

type BulkRestoreKeysRequest struct {
	Keys string `json:"keys" binding:"required"`
}

type UpdateAPIKeyRequest struct {
	// Note: oneof values are space-separated; the original comma-separated tag
	// would have treated the whole list as a single allowed value.
	Status *string `json:"status" binding:"omitempty,oneof=ACTIVE PENDING_VALIDATION COOLDOWN DISABLED BANNED"`
}

type UpdateMappingRequest struct {
	Status models.APIKeyStatus `json:"status" binding:"required,oneof=ACTIVE PENDING_VALIDATION COOLDOWN DISABLED BANNED"`
}

type BulkTestKeysRequest struct {
	KeyGroupID uint   `json:"key_group_id" binding:"required"`
	Keys       string `json:"keys" binding:"required"`
}

type RestoreKeysRequest struct {
	KeyIDs []uint `json:"key_ids" binding:"required,gt=0"`
}

type BulkTestKeysForGroupRequest struct {
	Keys string `json:"keys" binding:"required"`
}

type BulkActionFilter struct {
	Status []string `json:"status"` // A slice, so multiple statuses can be filtered at once.
}

type BulkActionRequest struct {
	Action    string           `json:"action" binding:"required,oneof=revalidate set_status delete"`
	NewStatus string           `json:"new_status" binding:"omitempty,oneof=active disabled cooldown banned"` // For the 'set_status' action.
	Filter    BulkActionFilter `json:"filter" binding:"required"`
}
// --- Handler Methods ---

// AddMultipleKeysToGroup handles adding/linking multiple keys to a specific group.
func (h *APIKeyHandler) AddMultipleKeysToGroup(c *gin.Context) {
	var req BulkAddKeysToGroupRequest
	if err := c.ShouldBindJSON(&req); err != nil {
		response.Error(c, errors.NewAPIError(errors.ErrInvalidJSON, err.Error()))
		return
	}
	taskStatus, err := h.keyImportService.StartAddKeysTask(req.KeyGroupID, req.Keys, req.ValidateOnImport)
	if err != nil {
		response.Error(c, errors.NewAPIError(errors.ErrTaskInProgress, err.Error()))
		return
	}
	response.Success(c, taskStatus)
}

// UnlinkMultipleKeysFromGroup handles unlinking multiple keys from a specific group.
func (h *APIKeyHandler) UnlinkMultipleKeysFromGroup(c *gin.Context) {
	var req BulkUnlinkKeysFromGroupRequest
	if err := c.ShouldBindJSON(&req); err != nil {
		response.Error(c, errors.NewAPIError(errors.ErrInvalidJSON, err.Error()))
		return
	}
	taskStatus, err := h.keyImportService.StartUnlinkKeysTask(req.KeyGroupID, req.Keys)
	if err != nil {
		response.Error(c, errors.NewAPIError(errors.ErrTaskInProgress, err.Error()))
		return
	}
	response.Success(c, taskStatus)
}

// HardDeleteMultipleKeys handles globally deleting multiple key entities.
func (h *APIKeyHandler) HardDeleteMultipleKeys(c *gin.Context) {
	var req BulkHardDeleteKeysRequest
	if err := c.ShouldBindJSON(&req); err != nil {
		response.Error(c, errors.NewAPIError(errors.ErrInvalidJSON, err.Error()))
		return
	}
	taskStatus, err := h.keyImportService.StartHardDeleteKeysTask(req.Keys)
	if err != nil {
		response.Error(c, errors.NewAPIError(errors.ErrTaskInProgress, err.Error()))
		return
	}
	response.Success(c, taskStatus)
}

// RestoreMultipleKeys handles restoring multiple keys to ACTIVE status globally.
func (h *APIKeyHandler) RestoreMultipleKeys(c *gin.Context) {
	var req BulkRestoreKeysRequest
	if err := c.ShouldBindJSON(&req); err != nil {
		response.Error(c, errors.NewAPIError(errors.ErrInvalidJSON, err.Error()))
		return
	}
	taskStatus, err := h.keyImportService.StartRestoreKeysTask(req.Keys)
	if err != nil {
		response.Error(c, errors.NewAPIError(errors.ErrTaskInProgress, err.Error()))
		return
	}
	response.Success(c, taskStatus)
}

func (h *APIKeyHandler) TestMultipleKeys(c *gin.Context) {
	var req BulkTestKeysRequest
	if err := c.ShouldBindJSON(&req); err != nil {
		response.Error(c, errors.NewAPIError(errors.ErrInvalidJSON, err.Error()))
		return
	}
	taskStatus, err := h.keyValidationService.StartTestKeysTask(req.KeyGroupID, req.Keys)
	if err != nil {
		response.Error(c, errors.NewAPIError(errors.ErrTaskInProgress, err.Error()))
		return
	}
	response.Success(c, taskStatus)
}

func (h *APIKeyHandler) ListAPIKeys(c *gin.Context) {
	var params models.APIKeyQueryParams
	if err := c.ShouldBindQuery(&params); err != nil {
		response.Error(c, errors.NewAPIError(errors.ErrBadRequest, err.Error()))
		return
	}
	if params.Page <= 0 {
		params.Page = 1
	}
	if params.PageSize <= 0 {
		params.PageSize = 20
	}
	result, err := h.apiKeyService.ListAPIKeys(&params)
	if err != nil {
		response.Error(c, errors.ParseDBError(err))
		return
	}
	response.Success(c, result)
}

// ListKeysForGroup handles the GET /keygroups/:id/keys request.
func (h *APIKeyHandler) ListKeysForGroup(c *gin.Context) {
	// 1. Manually handle the path parameter.
	groupID, err := strconv.ParseUint(c.Param("id"), 10, 32)
	if err != nil {
		response.Error(c, errors.NewAPIError(errors.ErrBadRequest, "Invalid group ID format"))
		return
	}
	// 2. Bind query parameters using the correctly tagged struct.
	var params models.APIKeyQueryParams
	if err := c.ShouldBindQuery(&params); err != nil {
		response.Error(c, errors.NewAPIError(errors.ErrBadRequest, err.Error()))
		return
	}
	// 3. Set server-side defaults and the path parameter.
	if params.Page <= 0 {
		params.Page = 1
	}
	if params.PageSize <= 0 {
		params.PageSize = 20
	}
	params.KeyGroupID = uint(groupID)
	// 4. Call the service layer.
	paginatedResult, err := h.apiKeyService.ListAPIKeys(&params)
	if err != nil {
		response.Error(c, errors.ParseDBError(err))
		return
	}

	// 5. Return a successful response using the standard response.Success and a gin.H map.
	response.Success(c, gin.H{
		"items": paginatedResult.Items,
		"total": paginatedResult.Total,
		"page":  paginatedResult.Page,
		"pages": paginatedResult.TotalPages,
	})
}

func (h *APIKeyHandler) TestKeysForGroup(c *gin.Context) {
	// The group ID is sourced from the URL path.
	groupID, err := strconv.ParseUint(c.Param("id"), 10, 32)
	if err != nil {
		response.Error(c, errors.NewAPIError(errors.ErrBadRequest, "Invalid group ID format"))
		return
	}
	// The request body only needs the keys.
	var req BulkTestKeysForGroupRequest
	if err := c.ShouldBindJSON(&req); err != nil {
		response.Error(c, errors.NewAPIError(errors.ErrInvalidJSON, err.Error()))
		return
	}
	// Call the same underlying service, but with unambiguous context.
	taskStatus, err := h.keyValidationService.StartTestKeysTask(uint(groupID), req.Keys)
	if err != nil {
		response.Error(c, errors.NewAPIError(errors.ErrTaskInProgress, err.Error()))
		return
	}
	response.Success(c, taskStatus)
}

// UpdateAPIKey is DEPRECATED. Status is now contextual to a group.
func (h *APIKeyHandler) UpdateAPIKey(c *gin.Context) {
	err := errors.NewAPIError(errors.ErrBadRequest, "This endpoint is deprecated. Use 'PUT /keygroups/:id/apikeys/:keyId' to update key status within a group context.")
	response.Error(c, err)
}

// UpdateGroupAPIKeyMapping handles updating a key's status within a specific group.
// Route: PUT /keygroups/:id/apikeys/:keyId
func (h *APIKeyHandler) UpdateGroupAPIKeyMapping(c *gin.Context) {
	groupID, err := strconv.ParseUint(c.Param("id"), 10, 32)
	if err != nil {
		response.Error(c, errors.NewAPIError(errors.ErrBadRequest, "Invalid Group ID format"))
		return
	}
	keyID, err := strconv.ParseUint(c.Param("keyId"), 10, 32)
	if err != nil {
		response.Error(c, errors.NewAPIError(errors.ErrBadRequest, "Invalid Key ID format"))
		return
	}
	var req UpdateMappingRequest
	if err := c.ShouldBindJSON(&req); err != nil {
		response.Error(c, errors.NewAPIError(errors.ErrInvalidJSON, err.Error()))
		return
	}
	// Directly use the service to handle the logic.
	updatedMapping, err := h.apiKeyService.UpdateMappingStatus(uint(groupID), uint(keyID), req.Status)
	if err != nil {
		var apiErr *errors.APIError
		if errors.As(err, &apiErr) {
			response.Error(c, apiErr)
		} else {
			response.Error(c, errors.ParseDBError(err))
		}
		return
	}
	response.Success(c, updatedMapping)
}
// HardDeleteAPIKey handles globally deleting a single key entity.
func (h *APIKeyHandler) HardDeleteAPIKey(c *gin.Context) {
	id, err := strconv.Atoi(c.Param("id"))
	if err != nil {
		response.Error(c, errors.NewAPIError(errors.ErrBadRequest, "Invalid ID format"))
		return
	}
	if err := h.apiKeyService.HardDeleteAPIKeyByID(uint(id)); err != nil {
		response.Error(c, errors.ParseDBError(err))
		return
	}
	response.Success(c, gin.H{"message": "API key globally deleted successfully"})
}

// RestoreKeysInGroup restores the specified keys within a group.
// POST /keygroups/:id/apikeys/restore
func (h *APIKeyHandler) RestoreKeysInGroup(c *gin.Context) {
	groupID, err := strconv.ParseUint(c.Param("id"), 10, 32)
	if err != nil {
		response.Error(c, errors.NewAPIError(errors.ErrBadRequest, "Invalid Group ID format"))
		return
	}
	var req RestoreKeysRequest
	if err := c.ShouldBindJSON(&req); err != nil {
		response.Error(c, errors.NewAPIError(errors.ErrInvalidJSON, err.Error()))
		return
	}
	taskStatus, err := h.apiKeyService.StartRestoreKeysTask(uint(groupID), req.KeyIDs)
	if err != nil {
		var apiErr *errors.APIError
		if errors.As(err, &apiErr) {
			response.Error(c, apiErr)
		} else {
			response.Error(c, errors.NewAPIError(errors.ErrInternalServer, err.Error()))
		}
		return
	}
	response.Success(c, taskStatus)
}

// RestoreAllBannedInGroup restores every banned key in a group in one call.
// POST /keygroups/:id/apikeys/restore-all-banned
func (h *APIKeyHandler) RestoreAllBannedInGroup(c *gin.Context) {
	// "id" matches the documented route and the sibling handlers
	// (the original read "groupId", which the route does not define).
	groupID, err := strconv.ParseUint(c.Param("id"), 10, 32)
	if err != nil {
		response.Error(c, errors.NewAPIError(errors.ErrBadRequest, "Invalid Group ID format"))
		return
	}
	taskStatus, err := h.apiKeyService.StartRestoreAllBannedTask(uint(groupID))
	if err != nil {
		var apiErr *errors.APIError
		if errors.As(err, &apiErr) {
			response.Error(c, apiErr)
		} else {
			response.Error(c, errors.NewAPIError(errors.ErrInternalServer, err.Error()))
		}
		return
	}
	response.Success(c, taskStatus)
}

// HandleBulkAction handles generic bulk actions on a key group based on server-side filters.
// Route: POST /keygroups/:id/bulk-actions
func (h *APIKeyHandler) HandleBulkAction(c *gin.Context) {
	// 1. Parse GroupID from the URL.
	groupID, err := strconv.ParseUint(c.Param("id"), 10, 32)
	if err != nil {
		response.Error(c, errors.NewAPIError(errors.ErrBadRequest, "Invalid Group ID format"))
		return
	}
	// 2. Bind the JSON payload to the DTO.
	var req BulkActionRequest
	if err := c.ShouldBindJSON(&req); err != nil {
		response.Error(c, errors.NewAPIError(errors.ErrInvalidJSON, err.Error()))
		return
	}
	// 3. Central logic: based on the action, call the appropriate service method.
	// (Named taskStatus to avoid shadowing the imported task package.)
	var taskStatus *task.Status
	var apiErr *errors.APIError
	switch req.Action {
	case "revalidate":
		// Re-validate keys matching the status filter.
		taskStatus, err = h.keyValidationService.StartTestKeysByFilterTask(uint(groupID), req.Filter.Status)
	case "set_status":
		if req.NewStatus == "" {
			apiErr = errors.NewAPIError(errors.ErrBadRequest, "new_status is required for set_status action")
			break
		}
		// Update the status of keys matching the filter.
		targetStatus := models.APIKeyStatus(req.NewStatus) // Convert string to the model's status type.
		taskStatus, err = h.apiKeyService.StartUpdateStatusByFilterTask(uint(groupID), req.Filter.Status, targetStatus)
	case "delete":
		// Unlink keys matching the filter from the group.
		taskStatus, err = h.keyImportService.StartUnlinkKeysByFilterTask(uint(groupID), req.Filter.Status)
	default:
		apiErr = errors.NewAPIError(errors.ErrBadRequest, "Unsupported action: "+req.Action)
	}
	// 4. Handle errors from the switch block.
	if apiErr != nil {
		response.Error(c, apiErr)
		return
	}
	if err != nil {
		// Attempt to parse it as a known APIError; otherwise, wrap it.
		var parsedErr *errors.APIError
		if errors.As(err, &parsedErr) {
			response.Error(c, parsedErr)
		} else {
			response.Error(c, errors.NewAPIError(errors.ErrInternalServer, err.Error()))
		}
		return
	}
	// 5. Return the task status on success.
	response.Success(c, taskStatus)
}
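As an illustration, a client payload for the bulk-action endpoint above (the group ID and statuses are made up):

// POST /keygroups/42/bulk-actions — re-validate every key currently in
// COOLDOWN or DISABLED within group 42:
// {
//   "action": "revalidate",
//   "filter": { "status": ["cooldown", "disabled"] }
// }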
// ExportKeysForGroup handles requests to export all keys for a group based on status filters.
// Route: GET /keygroups/:id/apikeys/export
func (h *APIKeyHandler) ExportKeysForGroup(c *gin.Context) {
	groupID, err := strconv.ParseUint(c.Param("id"), 10, 32)
	if err != nil {
		response.Error(c, errors.NewAPIError(errors.ErrBadRequest, "Invalid Group ID format"))
		return
	}
	// Use QueryArray to correctly parse `status[]=active&status[]=cooldown`.
	statuses := c.QueryArray("status")
	keyStrings, err := h.apiKeyService.GetAPIKeyStringsForExport(uint(groupID), statuses)
	if err != nil {
		response.Error(c, errors.ParseDBError(err))
		return
	}
	response.Success(c, keyStrings)
}
62
internal/handlers/dashboard_handler.go
Normal file
@@ -0,0 +1,62 @@
// Filename: internal/handlers/dashboard_handler.go
package handlers

import (
	"gemini-balancer/internal/errors"
	"gemini-balancer/internal/service"
	"net/http"
	"strconv"

	"github.com/gin-gonic/gin"
)

// DashboardHandler serves the global dashboard API endpoints.
type DashboardHandler struct {
	queryService *service.DashboardQueryService
}

func NewDashboardHandler(qs *service.DashboardQueryService) *DashboardHandler {
	return &DashboardHandler{queryService: qs}
}

// GetOverview returns the global statistics cards for the dashboard.
func (h *DashboardHandler) GetOverview(c *gin.Context) {
	stats, err := h.queryService.GetDashboardOverviewData()
	if err != nil {
		apiErr := errors.NewAPIError(errors.ErrInternalServer, err.Error())
		c.JSON(apiErr.HTTPStatus, apiErr)
		return
	}
	c.JSON(http.StatusOK, stats)
}

// GetChart returns the dashboard chart data.
func (h *DashboardHandler) GetChart(c *gin.Context) {
	var groupID *uint
	if groupIDStr := c.Query("groupId"); groupIDStr != "" {
		if id, err := strconv.Atoi(groupIDStr); err == nil {
			uid := uint(id)
			groupID = &uid
		}
	}

	chartData, err := h.queryService.QueryHistoricalChart(groupID)
	if err != nil {
		apiErr := errors.NewAPIError(errors.ErrDatabase, err.Error())
		c.JSON(apiErr.HTTPStatus, apiErr)
		return
	}
	c.JSON(http.StatusOK, chartData)
}

// GetRequestStats handles the "calls during period" overview request.
func (h *DashboardHandler) GetRequestStats(c *gin.Context) {
	period := c.Param("period") // The period comes from the URL path.
	stats, err := h.queryService.GetRequestStatsForPeriod(period)
	if err != nil {
		apiErr := errors.NewAPIError(errors.ErrBadRequest, err.Error())
		c.JSON(apiErr.HTTPStatus, apiErr)
		return
	}
	c.JSON(http.StatusOK, stats)
}
369
internal/handlers/keygroup_handler.go
Normal file
@@ -0,0 +1,369 @@
// Filename: internal/handlers/keygroup_handler.go
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"gemini-balancer/internal/errors"
|
||||
"gemini-balancer/internal/models"
|
||||
"gemini-balancer/internal/response"
|
||||
"gemini-balancer/internal/service"
|
||||
"gemini-balancer/internal/store"
|
||||
"regexp"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/microcosm-cc/bluemonday"
|
||||
)
|
||||
|
||||
type KeyGroupHandler struct {
|
||||
groupManager *service.GroupManager
|
||||
store store.Store
|
||||
queryService *service.DashboardQueryService
|
||||
}
|
||||
|
||||
func NewKeyGroupHandler(gm *service.GroupManager, s store.Store, qs *service.DashboardQueryService) *KeyGroupHandler {
|
||||
return &KeyGroupHandler{
|
||||
groupManager: gm,
|
||||
queryService: qs,
|
||||
store: s,
|
||||
}
|
||||
}
|
||||
|
||||
// DTOs & 辅助函数
|
||||
func isValidGroupName(name string) bool {
|
||||
if name == "" {
|
||||
return false
|
||||
}
|
||||
match, _ := regexp.MatchString("^[a-z0-9_-]{3,30}$", name)
|
||||
return match
|
||||
}
|
||||
|
||||
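// For instance (illustrative): "prod-keys_01" passes the pattern above, while
// "AB" (too short), "Prod Keys" (uppercase/space) or a 31-character name fails.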
// KeyGroupOperationalSettings defines the shared operational settings for a key group.
type KeyGroupOperationalSettings struct {
    EnableKeyCheck          *bool   `json:"enable_key_check"`
    KeyCheckIntervalMinutes *int    `json:"key_check_interval_minutes"`
    KeyBlacklistThreshold   *int    `json:"key_blacklist_threshold"`
    KeyCooldownMinutes      *int    `json:"key_cooldown_minutes"`
    KeyCheckConcurrency     *int    `json:"key_check_concurrency"`
    KeyCheckEndpoint        *string `json:"key_check_endpoint"`
    KeyCheckModel           *string `json:"key_check_model"`
    MaxRetries              *int    `json:"max_retries"`
    EnableSmartGateway      *bool   `json:"enable_smart_gateway"`
}

type CreateKeyGroupRequest struct {
    Name            string `json:"name" binding:"required"`
    DisplayName     string `json:"display_name"`
    Description     string `json:"description"`
    PollingStrategy string `json:"polling_strategy" binding:"omitempty,oneof=sequential random weighted"`
    EnableProxy     bool   `json:"enable_proxy"`
    ChannelType     string `json:"channel_type"`

    // Embed shared operational settings
    KeyGroupOperationalSettings
}

type UpdateKeyGroupRequest struct {
    Name            *string `json:"name"`
    DisplayName     *string `json:"display_name"`
    Description     *string `json:"description"`
    PollingStrategy *string `json:"polling_strategy" binding:"omitempty,oneof=sequential random weighted"`
    EnableProxy     *bool   `json:"enable_proxy"`
    ChannelType     *string `json:"channel_type"`

    // Embed shared operational settings
    KeyGroupOperationalSettings

    // M:N associations
    AllowedUpstreams []string `json:"allowed_upstreams"`
    AllowedModels    []string `json:"allowed_models"`
}

type KeyGroupResponse struct {
    ID               uint                   `json:"id"`
    Name             string                 `json:"name"`
    DisplayName      string                 `json:"display_name"`
    Description      string                 `json:"description"`
    PollingStrategy  models.PollingStrategy `json:"polling_strategy"`
    ChannelType      string                 `json:"channel_type"`
    EnableProxy      bool                   `json:"enable_proxy"`
    APIKeysCount     int64                  `json:"api_keys_count"`
    CreatedAt        time.Time              `json:"created_at"`
    UpdatedAt        time.Time              `json:"updated_at"`
    Order            int                    `json:"order"`
    AllowedModels    []string               `json:"allowed_models"`
    AllowedUpstreams []string               `json:"allowed_upstreams"`
}

// [NEW] Define the detailed response structure for a single group.
type KeyGroupDetailsResponse struct {
    KeyGroupResponse
    Settings      *models.GroupSettings `json:"settings,omitempty"`
    RequestConfig *models.RequestConfig `json:"request_config,omitempty"`
}

// transformModelsToStrings converts a slice of GroupModelMapping pointers to a slice of model names.
func transformModelsToStrings(mappings []*models.GroupModelMapping) []string {
    modelNames := make([]string, 0, len(mappings))
    for _, mapping := range mappings {
        if mapping != nil { // Safety check
            modelNames = append(modelNames, mapping.ModelName)
        }
    }
    return modelNames
}

// transformUpstreamsToStrings converts a slice of UpstreamEndpoint pointers to a slice of URLs.
func transformUpstreamsToStrings(upstreams []*models.UpstreamEndpoint) []string {
    urls := make([]string, 0, len(upstreams))
    for _, upstream := range upstreams {
        if upstream != nil { // Safety check
            urls = append(urls, upstream.URL)
        }
    }
    return urls
}

func (h *KeyGroupHandler) newKeyGroupResponse(group *models.KeyGroup, keyCount int64) KeyGroupResponse {
    return KeyGroupResponse{
        ID:               group.ID,
        Name:             group.Name,
        DisplayName:      group.DisplayName,
        Description:      group.Description,
        PollingStrategy:  group.PollingStrategy,
        ChannelType:      group.ChannelType,
        EnableProxy:      group.EnableProxy,
        APIKeysCount:     keyCount,
        CreatedAt:        group.CreatedAt,
        UpdatedAt:        group.UpdatedAt,
        Order:            group.Order,
        AllowedModels:    transformModelsToStrings(group.AllowedModels),       // Call the new helper
        AllowedUpstreams: transformUpstreamsToStrings(group.AllowedUpstreams), // Call the new helper
    }
}

// packGroupSettings is a helper to convert request-level operational settings
// into the model-level settings struct.
func packGroupSettings(settings KeyGroupOperationalSettings) *models.KeyGroupSettings {
    return &models.KeyGroupSettings{
        EnableKeyCheck:          settings.EnableKeyCheck,
        KeyCheckIntervalMinutes: settings.KeyCheckIntervalMinutes,
        KeyBlacklistThreshold:   settings.KeyBlacklistThreshold,
        KeyCooldownMinutes:      settings.KeyCooldownMinutes,
        KeyCheckConcurrency:     settings.KeyCheckConcurrency,
        KeyCheckEndpoint:        settings.KeyCheckEndpoint,
        KeyCheckModel:           settings.KeyCheckModel,
        MaxRetries:              settings.MaxRetries,
        EnableSmartGateway:      settings.EnableSmartGateway,
    }
}

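// Every operational-settings field is a pointer on purpose: nil means "not
// provided", so partial updates can leave the stored setting untouched. The
// update handlers below follow the same nil-means-unchanged convention.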
func (h *KeyGroupHandler) getGroupFromContext(c *gin.Context) (*models.KeyGroup, *errors.APIError) {
    id, err := strconv.Atoi(c.Param("id"))
    if err != nil {
        return nil, errors.NewAPIError(errors.ErrBadRequest, "Invalid ID format")
    }
    group, ok := h.groupManager.GetGroupByID(uint(id))
    if !ok {
        return nil, errors.NewAPIError(errors.ErrResourceNotFound, "Group not found")
    }
    return group, nil
}

func applyUpdateRequestToGroup(req *UpdateKeyGroupRequest, group *models.KeyGroup) {
    if req.Name != nil {
        group.Name = *req.Name
    }
    p := bluemonday.StripTagsPolicy()
    if req.DisplayName != nil {
        group.DisplayName = p.Sanitize(*req.DisplayName)
    }
    if req.Description != nil {
        group.Description = p.Sanitize(*req.Description)
    }
    if req.PollingStrategy != nil {
        group.PollingStrategy = models.PollingStrategy(*req.PollingStrategy)
    }
    if req.EnableProxy != nil {
        group.EnableProxy = *req.EnableProxy
    }
    if req.ChannelType != nil {
        group.ChannelType = *req.ChannelType
    }
}

// publishGroupChangeEvent encapsulates the logic for marshaling and publishing a group change event.
func (h *KeyGroupHandler) publishGroupChangeEvent(groupID uint, reason string) {
    go func() {
        event := models.KeyStatusChangedEvent{GroupID: groupID, ChangeReason: reason}
        eventData, _ := json.Marshal(event)
        h.store.Publish(models.TopicKeyStatusChanged, eventData)
    }()
}

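// Published payload sketch (field names follow models.KeyStatusChangedEvent):
//
//	{"KeyID":0,"GroupID":7,"OldStatus":"","NewStatus":"","change_reason":"group_updated","changed_at":"0001-01-01T00:00:00Z"}
//
// Only GroupID and ChangeReason are populated here; subscribers should not
// rely on the other fields for this event source.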
// --- Handler methods ---

func (h *KeyGroupHandler) CreateKeyGroup(c *gin.Context) {
    var req CreateKeyGroupRequest
    if err := c.ShouldBindJSON(&req); err != nil {
        response.Error(c, errors.NewAPIError(errors.ErrInvalidJSON, err.Error()))
        return
    }
    if !isValidGroupName(req.Name) {
        response.Error(c, errors.NewAPIError(errors.ErrValidation, "Invalid group name. Must be 3-30 characters, lowercase letters, numbers, hyphens, or underscores."))
        return
    }

    // The core logic remains, as it's specific to creation.
    p := bluemonday.StripTagsPolicy()
    sanitizedDisplayName := p.Sanitize(req.DisplayName)
    sanitizedDescription := p.Sanitize(req.Description)
    keyGroup := &models.KeyGroup{
        Name:            req.Name,
        DisplayName:     sanitizedDisplayName,
        Description:     sanitizedDescription,
        PollingStrategy: models.PollingStrategy(req.PollingStrategy),
        EnableProxy:     req.EnableProxy,
        ChannelType:     req.ChannelType,
    }
    if keyGroup.PollingStrategy == "" {
        keyGroup.PollingStrategy = models.StrategySequential
    }
    if keyGroup.ChannelType == "" {
        keyGroup.ChannelType = "gemini"
    }

    groupSettings := packGroupSettings(req.KeyGroupOperationalSettings)
    if err := h.groupManager.CreateKeyGroup(keyGroup, groupSettings); err != nil {
        response.Error(c, errors.ParseDBError(err))
        return
    }
    h.publishGroupChangeEvent(keyGroup.ID, "group_created")
    response.Created(c, h.newKeyGroupResponse(keyGroup, 0))
}

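// Example create payload (illustrative values; route prefix is an assumption):
//
//	POST /keygroups
//	{"name":"gemini-pro-pool","polling_strategy":"weighted","channel_type":"gemini","enable_key_check":true}
//
// Omitted polling_strategy/channel_type fall back to "sequential" and "gemini".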
// A single handler covers both cases:
// 1. GET /keygroups     - returns the list of all groups
// 2. GET /keygroups/:id - returns the single group with the given ID
func (h *KeyGroupHandler) GetKeyGroups(c *gin.Context) {
    // Case 1: Get a single group
    if idStr := c.Param("id"); idStr != "" {
        group, apiErr := h.getGroupFromContext(c)
        if apiErr != nil {
            response.Error(c, apiErr)
            return
        }
        keyCount := h.groupManager.GetKeyCount(group.ID)
        baseResponse := h.newKeyGroupResponse(group, keyCount)
        detailedResponse := KeyGroupDetailsResponse{
            KeyGroupResponse: baseResponse,
            Settings:         group.Settings,
            RequestConfig:    group.RequestConfig,
        }
        response.Success(c, detailedResponse)
        return
    }
    // Case 2: Get all groups
    allGroups := h.groupManager.GetAllGroups()
    responses := make([]KeyGroupResponse, 0, len(allGroups))
    for _, group := range allGroups {
        keyCount := h.groupManager.GetKeyCount(group.ID)
        responses = append(responses, h.newKeyGroupResponse(group, keyCount))
    }
    response.Success(c, responses)
}

// UpdateKeyGroup
func (h *KeyGroupHandler) UpdateKeyGroup(c *gin.Context) {
    group, apiErr := h.getGroupFromContext(c)
    if apiErr != nil {
        response.Error(c, apiErr)
        return
    }
    var req UpdateKeyGroupRequest
    if err := c.ShouldBindJSON(&req); err != nil {
        response.Error(c, errors.NewAPIError(errors.ErrInvalidJSON, err.Error()))
        return
    }
    if req.Name != nil && !isValidGroupName(*req.Name) {
        response.Error(c, errors.NewAPIError(errors.ErrValidation, "Invalid group name format."))
        return
    }
    applyUpdateRequestToGroup(&req, group)
    groupSettings := packGroupSettings(req.KeyGroupOperationalSettings)
    err := h.groupManager.UpdateKeyGroup(group, groupSettings, req.AllowedUpstreams, req.AllowedModels)
    if err != nil {
        response.Error(c, errors.ParseDBError(err))
        return
    }
    h.publishGroupChangeEvent(group.ID, "group_updated")
    freshGroup, _ := h.groupManager.GetGroupByID(group.ID)
    keyCount := h.groupManager.GetKeyCount(freshGroup.ID)
    response.Success(c, h.newKeyGroupResponse(freshGroup, keyCount))
}

// DeleteKeyGroup
func (h *KeyGroupHandler) DeleteKeyGroup(c *gin.Context) {
    group, apiErr := h.getGroupFromContext(c)
    if apiErr != nil {
        response.Error(c, apiErr)
        return
    }
    groupName := group.Name
    if err := h.groupManager.DeleteKeyGroup(group.ID); err != nil {
        response.Error(c, errors.ParseDBError(err))
        return
    }
    h.publishGroupChangeEvent(group.ID, "group_deleted")
    response.Success(c, gin.H{"message": fmt.Sprintf("Group '%s' and its associated keys deleted successfully", groupName)})
}

// GetKeyGroupStats
func (h *KeyGroupHandler) GetKeyGroupStats(c *gin.Context) {
    group, apiErr := h.getGroupFromContext(c)
    if apiErr != nil {
        response.Error(c, apiErr)
        return
    }
    stats, err := h.queryService.GetGroupStats(group.ID)
    if err != nil {
        response.Error(c, errors.NewAPIError(errors.ErrDatabase, err.Error()))
        return
    }
    response.Success(c, stats)
}

func (h *KeyGroupHandler) CloneKeyGroup(c *gin.Context) {
    group, apiErr := h.getGroupFromContext(c)
    if apiErr != nil {
        response.Error(c, apiErr)
        return
    }
    clonedGroup, err := h.groupManager.CloneKeyGroup(group.ID)
    if err != nil {
        response.Error(c, errors.ParseDBError(err))
        return
    }
    keyCount := int64(len(clonedGroup.Mappings))
    response.Created(c, h.newKeyGroupResponse(clonedGroup, keyCount))
}

// UpdateKeyGroupOrder updates the display order of the groups.
func (h *KeyGroupHandler) UpdateKeyGroupOrder(c *gin.Context) {
    var payload []service.UpdateOrderPayload
    if err := c.ShouldBindJSON(&payload); err != nil {
        response.Error(c, errors.NewAPIError(errors.ErrInvalidJSON, err.Error()))
        return
    }
    if len(payload) == 0 {
        response.Success(c, gin.H{"message": "No order data to update."})
        return
    }
    if err := h.groupManager.UpdateOrder(payload); err != nil {
        response.Error(c, errors.ParseDBError(err))
        return
    }
    response.Success(c, gin.H{"message": "Group order updated successfully."})
}
33
internal/handlers/log_handler.go
Normal file
@@ -0,0 +1,33 @@
// Filename: internal/handlers/log_handler.go
package handlers

import (
    "gemini-balancer/internal/errors"
    "gemini-balancer/internal/models"
    "gemini-balancer/internal/response"
    "gemini-balancer/internal/service"

    "github.com/gin-gonic/gin"
)

// LogHandler handles log-related HTTP requests.
type LogHandler struct {
    logService *service.LogService
}

func NewLogHandler(logService *service.LogService) *LogHandler {
    return &LogHandler{logService: logService}
}

func (h *LogHandler) GetLogs(c *gin.Context) {
    // Pass the Gin context straight to the service layer and let it parse the query parameters itself.
    logs, err := h.logService.GetLogs(c)
    if err != nil {
        response.Error(c, errors.ErrDatabase)
        return
    }
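    // Normalize nil to an empty slice so the JSON body is `[]` rather than
    // `null`, which is friendlier for frontend consumers.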
    if logs == nil {
        logs = []models.RequestLog{}
    }
    response.Success(c, logs)
}
581
internal/handlers/proxy_handler.go
Normal file
@@ -0,0 +1,581 @@
// Filename: internal/handlers/proxy_handler.go
package handlers

import (
    "bytes"
    "compress/gzip"
    "context"
    "encoding/json"
    "fmt"
    "gemini-balancer/internal/channel"
    "gemini-balancer/internal/errors"
    "gemini-balancer/internal/middleware"
    "gemini-balancer/internal/models"
    "gemini-balancer/internal/service"
    "gemini-balancer/internal/settings"
    "gemini-balancer/internal/store"
    "io"
    "net"
    "net/http"
    "net/http/httptest"
    "net/http/httputil"
    "net/url"
    "strings"
    "time"

    "github.com/gin-gonic/gin"
    "github.com/google/uuid"
    "github.com/sirupsen/logrus"
    "gorm.io/datatypes"
)

type proxyErrorKey int

const proxyErrKey proxyErrorKey = 0

type ProxyHandler struct {
    resourceService  *service.ResourceService
    store            store.Store
    settingsManager  *settings.SettingsManager
    groupManager     *service.GroupManager
    channel          channel.ChannelProxy
    logger           *logrus.Entry
    transparentProxy *httputil.ReverseProxy
}

func NewProxyHandler(
    resourceService *service.ResourceService,
    store store.Store,
    sm *settings.SettingsManager,
    gm *service.GroupManager,
    channel channel.ChannelProxy,
    logger *logrus.Logger,
) *ProxyHandler {
    ph := &ProxyHandler{
        resourceService:  resourceService,
        store:            store,
        settingsManager:  sm,
        groupManager:     gm,
        channel:          channel,
        logger:           logger.WithField("component", "ProxyHandler"),
        transparentProxy: &httputil.ReverseProxy{},
    }
    ph.transparentProxy.Transport = &http.Transport{
        Proxy: http.ProxyFromEnvironment,
        DialContext: (&net.Dialer{
            Timeout:   30 * time.Second,
            KeepAlive: 60 * time.Second,
        }).DialContext,
        MaxIdleConns:          100,
        IdleConnTimeout:       90 * time.Second,
        TLSHandshakeTimeout:   10 * time.Second,
        ExpectContinueTimeout: 1 * time.Second,
    }
    ph.transparentProxy.ErrorHandler = ph.transparentProxyErrorHandler
    ph.transparentProxy.BufferPool = &bufferPool{}
    return ph
}

func (h *ProxyHandler) HandleProxy(c *gin.Context) {
    if c.Request.Method == "GET" && (strings.HasSuffix(c.Request.URL.Path, "/models") || strings.HasSuffix(c.Request.URL.Path, "/models/")) {
        h.handleListModelsRequest(c)
        return
    }
    requestBody, err := io.ReadAll(c.Request.Body)
    if err != nil {
        errToJSON(c, uuid.New().String(), errors.NewAPIError(errors.ErrBadRequest, "Failed to read request body"))
        return
    }
    c.Request.Body = io.NopCloser(bytes.NewReader(requestBody))
    c.Request.ContentLength = int64(len(requestBody))
    modelName := h.channel.ExtractModel(c, requestBody)
    groupName := c.Param("group_name")
    isPreciseRouting := groupName != ""
    if !isPreciseRouting && modelName == "" {
        errToJSON(c, uuid.New().String(), errors.NewAPIError(errors.ErrBadRequest, "Model not specified in the request body or URL"))
        return
    }
    initialResources, err := h.getResourcesForRequest(c, modelName, groupName, isPreciseRouting)
    if err != nil {
        if apiErr, ok := err.(*errors.APIError); ok {
            errToJSON(c, uuid.New().String(), apiErr)
        } else {
            errToJSON(c, uuid.New().String(), errors.NewAPIError(errors.ErrNoKeysAvailable, err.Error()))
        }
        return
    }
    finalOpConfig, err := h.groupManager.BuildOperationalConfig(initialResources.KeyGroup)
    if err != nil {
        h.logger.WithError(err).Error("Failed to build operational config.")
        errToJSON(c, uuid.New().String(), errors.NewAPIError(errors.ErrInternalServer, "Failed to build operational configuration"))
        return
    }

    isOpenAICompatible := h.channel.IsOpenAICompatibleRequest(c)
    if isOpenAICompatible {
        h.serveTransparentProxy(c, requestBody, initialResources, finalOpConfig, modelName, groupName, isPreciseRouting)
        return
    }
    isStream := h.channel.IsStreamRequest(c, requestBody)
    systemSettings := h.settingsManager.GetSettings()
    useSmartGateway := finalOpConfig.EnableSmartGateway != nil && *finalOpConfig.EnableSmartGateway
    if useSmartGateway && isStream && systemSettings.EnableStreamingRetry {
        h.serveSmartStream(c, requestBody, initialResources, isPreciseRouting)
    } else {
        h.serveTransparentProxy(c, requestBody, initialResources, finalOpConfig, modelName, groupName, isPreciseRouting)
    }
}

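// Dispatch summary for HandleProxy (as implemented above):
//
//	GET .../models(/)                         → handleListModelsRequest
//	OpenAI-compatible request                 → serveTransparentProxy
//	smart gateway + stream + retries enabled  → serveSmartStream
//	everything else                           → serveTransparentProxy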
func (h *ProxyHandler) serveTransparentProxy(c *gin.Context, requestBody []byte, initialResources *service.RequestResources, finalOpConfig *models.KeyGroupSettings, modelName, groupName string, isPreciseRouting bool) {
    startTime := time.Now()
    correlationID := uuid.New().String()
    var finalRecorder *httptest.ResponseRecorder
    var lastUsedResources *service.RequestResources
    var finalProxyErr *errors.APIError
    var isSuccess bool
    var finalPromptTokens, finalCompletionTokens int
    actualRetries := 0
    defer func() {
        if lastUsedResources == nil {
            return
        }
        finalEvent := h.createLogEvent(c, startTime, correlationID, modelName, lastUsedResources, models.LogTypeFinal, isPreciseRouting)
        finalEvent.LatencyMs = int(time.Since(startTime).Milliseconds())
        finalEvent.IsSuccess = isSuccess
        finalEvent.Retries = actualRetries
        if isSuccess {
            finalEvent.PromptTokens = finalPromptTokens
            finalEvent.CompletionTokens = finalCompletionTokens
        }
        if finalRecorder != nil {
            finalEvent.StatusCode = finalRecorder.Code
        }
        if !isSuccess {
            if finalProxyErr != nil {
                finalEvent.Error = finalProxyErr
                finalEvent.ErrorCode = finalProxyErr.Code
                finalEvent.ErrorMessage = finalProxyErr.Message
            } else if finalRecorder != nil {
                apiErr := errors.NewAPIErrorWithUpstream(finalRecorder.Code, "PROXY_ERROR", "Request failed after all retries.")
                finalEvent.Error = apiErr
                finalEvent.ErrorCode = apiErr.Code
                finalEvent.ErrorMessage = apiErr.Message
            }
        }
        eventData, _ := json.Marshal(finalEvent)
        _ = h.store.Publish(models.TopicRequestFinished, eventData)
    }()
    var maxRetries int
    if isPreciseRouting {
        // For precise routing, use the group's setting. If not set, fall back to the global setting.
        if finalOpConfig.MaxRetries != nil {
            maxRetries = *finalOpConfig.MaxRetries
        } else {
            maxRetries = h.settingsManager.GetSettings().MaxRetries
        }
    } else {
        // For BasePool (intelligent aggregation), *always* use the global setting.
        maxRetries = h.settingsManager.GetSettings().MaxRetries
    }
    totalAttempts := maxRetries + 1
    for attempt := 1; attempt <= totalAttempts; attempt++ {
        if c.Request.Context().Err() != nil {
            h.logger.WithField("id", correlationID).Info("Client disconnected, aborting retry loop.")
            if finalProxyErr == nil {
                finalProxyErr = errors.NewAPIError(errors.ErrBadRequest, "Client connection closed")
            }
            break
        }
        var currentResources *service.RequestResources
        var err error
        if attempt == 1 {
            currentResources = initialResources
        } else {
            actualRetries = attempt - 1
            h.logger.WithField("id", correlationID).Infof("Retrying... getting new resources for attempt %d.", attempt)
            currentResources, err = h.getResourcesForRequest(c, modelName, groupName, isPreciseRouting)
            if err != nil {
                h.logger.WithField("id", correlationID).Errorf("Failed to get new resources for retry, aborting: %v", err)
                finalProxyErr = errors.NewAPIError(errors.ErrNoKeysAvailable, "Failed to get new resources for retry")
                break
            }
        }

        finalRequestConfig := h.buildFinalRequestConfig(h.settingsManager.GetSettings(), currentResources.RequestConfig)
        currentResources.RequestConfig = finalRequestConfig
        lastUsedResources = currentResources
        h.logger.WithField("id", correlationID).Infof("Attempt %d/%d with KeyID %d...", attempt, totalAttempts, currentResources.APIKey.ID)
        var attemptErr *errors.APIError
        var attemptIsSuccess bool
        recorder := httptest.NewRecorder()
        attemptStartTime := time.Now()
        connectTimeout := time.Duration(h.settingsManager.GetSettings().ConnectTimeoutSeconds) * time.Second
        ctx, cancel := context.WithTimeout(c.Request.Context(), connectTimeout)
        defer cancel()
        attemptReq := c.Request.Clone(ctx)
        attemptReq.Body = io.NopCloser(bytes.NewReader(requestBody))
        if currentResources.UpstreamEndpoint == nil || currentResources.UpstreamEndpoint.URL == "" {
            h.logger.WithField("id", correlationID).Errorf("Attempt %d failed: no upstream URL in resources.", attempt)
            isSuccess = false
            finalProxyErr = errors.NewAPIError(errors.ErrInternalServer, "No upstream URL configured for the selected resource")
            continue
        }
        h.transparentProxy.Director = func(req *http.Request) {
            targetURL, _ := url.Parse(currentResources.UpstreamEndpoint.URL)
            req.URL.Scheme = targetURL.Scheme
            req.URL.Host = targetURL.Host
            req.Host = targetURL.Host
            var pureClientPath string
            if isPreciseRouting {
                proxyPrefix := "/proxy/" + groupName
                pureClientPath = strings.TrimPrefix(req.URL.Path, proxyPrefix)
            } else {
                pureClientPath = req.URL.Path
            }
            finalPath := h.channel.RewritePath(targetURL.Path, pureClientPath)
            req.URL.Path = finalPath
            h.logger.WithFields(logrus.Fields{
                "correlation_id":    correlationID,
                "attempt":           attempt,
                "key_id":            currentResources.APIKey.ID,
                "base_upstream_url": currentResources.UpstreamEndpoint.URL,
                "final_request_url": req.URL.String(),
            }).Infof("Director constructed final upstream request URL.")
            req.Header.Del("Authorization")
            h.channel.ModifyRequest(req, currentResources.APIKey)
            req.Header.Set("X-Correlation-ID", correlationID)
            *req = *req.WithContext(context.WithValue(req.Context(), proxyErrKey, &attemptErr))
        }
        transport := h.transparentProxy.Transport.(*http.Transport)
        if currentResources.ProxyConfig != nil {
            proxyURLStr := fmt.Sprintf("%s://%s", currentResources.ProxyConfig.Protocol, currentResources.ProxyConfig.Address)
            proxyURL, err := url.Parse(proxyURLStr)
            if err == nil {
                transport.Proxy = http.ProxyURL(proxyURL)
            }
        } else {
            transport.Proxy = http.ProxyFromEnvironment
        }
        h.transparentProxy.ModifyResponse = func(resp *http.Response) error {
            defer resp.Body.Close()
            var reader io.ReadCloser
            var err error
            isGzipped := resp.Header.Get("Content-Encoding") == "gzip"
            if isGzipped {
                reader, err = gzip.NewReader(resp.Body)
                if err != nil {
                    h.logger.WithError(err).Error("Failed to create gzip reader")
                    reader = resp.Body
                } else {
                    resp.Header.Del("Content-Encoding")
                }
                defer reader.Close()
            } else {
                reader = resp.Body
            }
            bodyBytes, err := io.ReadAll(reader)
            if err != nil {
                attemptErr = errors.NewAPIError(errors.ErrBadGateway, "Failed to read upstream response: "+err.Error())
                resp.Body = io.NopCloser(bytes.NewReader([]byte(attemptErr.Message)))
                return nil
            }
            if resp.StatusCode < 400 {
                attemptIsSuccess = true
                finalPromptTokens, finalCompletionTokens = extractUsage(bodyBytes)
            } else {
                parsedMsg := errors.ParseUpstreamError(bodyBytes)
                attemptErr = errors.NewAPIErrorWithUpstream(resp.StatusCode, fmt.Sprintf("UPSTREAM_%d", resp.StatusCode), parsedMsg)
            }
            resp.Body = io.NopCloser(bytes.NewReader(bodyBytes))
            return nil
        }
        h.transparentProxy.ServeHTTP(recorder, attemptReq)
        finalRecorder = recorder
        finalProxyErr = attemptErr
        isSuccess = attemptIsSuccess
        h.resourceService.ReportRequestResult(currentResources, isSuccess, finalProxyErr)
        if isSuccess {
            break
        }
        isUnretryableError := false
        if finalProxyErr != nil {
            if errors.IsUnretryableRequestError(finalProxyErr.Message) {
                isUnretryableError = true
                h.logger.WithField("id", correlationID).Warnf("Attempt %d failed with unretryable request error. Aborting retries. Message: %s", attempt, finalProxyErr.Message)
            }
        }
        if attempt >= totalAttempts || isUnretryableError {
            break
        }
        retryEvent := h.createLogEvent(c, startTime, correlationID, modelName, currentResources, models.LogTypeRetry, isPreciseRouting)
        retryEvent.LatencyMs = int(time.Since(attemptStartTime).Milliseconds())
        retryEvent.IsSuccess = false
        retryEvent.StatusCode = recorder.Code
        retryEvent.Retries = actualRetries
        if attemptErr != nil {
            retryEvent.Error = attemptErr
            retryEvent.ErrorCode = attemptErr.Code
            retryEvent.ErrorMessage = attemptErr.Message
        }
        eventData, _ := json.Marshal(retryEvent)
        _ = h.store.Publish(models.TopicRequestFinished, eventData)
    }
    if finalRecorder != nil {
        bodyBytes := finalRecorder.Body.Bytes()
        c.Writer.Header().Set("Content-Length", fmt.Sprintf("%d", len(bodyBytes)))
        for k, v := range finalRecorder.Header() {
            if strings.ToLower(k) != "content-length" {
                c.Writer.Header()[k] = v
            }
        }
        c.Writer.WriteHeader(finalRecorder.Code)
        c.Writer.Write(bodyBytes)
    } else {
        errToJSON(c, correlationID, finalProxyErr)
    }
}

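// Retry semantics above: maxRetries = N yields N+1 attempts in total. Each
// failed non-final attempt emits a LogTypeRetry event, and the deferred block
// always emits one LogTypeFinal event for the request as a whole. Note that
// the per-attempt `defer cancel()` runs when the handler returns, not per
// loop iteration, so attempt contexts are held until the function exits.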
func (h *ProxyHandler) serveSmartStream(c *gin.Context, requestBody []byte, resources *service.RequestResources, isPreciseRouting bool) {
    startTime := time.Now()
    correlationID := uuid.New().String()
    log := h.logger.WithField("id", correlationID)
    log.Info("Smart Gateway activated for streaming request.")
    var originalRequest models.GeminiRequest
    if err := json.Unmarshal(requestBody, &originalRequest); err != nil {
        errToJSON(c, correlationID, errors.NewAPIError(errors.ErrInvalidJSON, "Smart Gateway failed: Request body is not a valid Gemini native format. Error: "+err.Error()))
        return
    }
    systemSettings := h.settingsManager.GetSettings()
    modelName := h.channel.ExtractModel(c, requestBody)
    requestFinishedEvent := h.createLogEvent(c, startTime, correlationID, modelName, resources, models.LogTypeFinal, isPreciseRouting)
    defer func() {
        requestFinishedEvent.LatencyMs = int(time.Since(startTime).Milliseconds())
        if c.Writer.Status() > 0 {
            requestFinishedEvent.StatusCode = c.Writer.Status()
        }
        eventData, _ := json.Marshal(requestFinishedEvent)
        _ = h.store.Publish(models.TopicRequestFinished, eventData)
    }()
    params := channel.SmartRequestParams{
        CorrelationID:        correlationID,
        APIKey:               resources.APIKey,
        UpstreamURL:          resources.UpstreamEndpoint.URL,
        RequestBody:          requestBody,
        OriginalRequest:      originalRequest,
        EventLogger:          requestFinishedEvent,
        MaxRetries:           systemSettings.MaxStreamingRetries,
        RetryDelay:           time.Duration(systemSettings.StreamingRetryDelayMs) * time.Millisecond,
        LogTruncationLimit:   systemSettings.LogTruncationLimit,
        StreamingRetryPrompt: systemSettings.StreamingRetryPrompt,
    }
    h.channel.ProcessSmartStreamRequest(c, params)
}

func (h *ProxyHandler) transparentProxyErrorHandler(rw http.ResponseWriter, r *http.Request, err error) {
    correlationID := r.Header.Get("X-Correlation-ID")
    h.logger.WithField("id", correlationID).Errorf("Transparent proxy error: %v", err)
    proxyErrPtr, exists := r.Context().Value(proxyErrKey).(**errors.APIError)
    if !exists || proxyErrPtr == nil {
        h.logger.WithField("id", correlationID).Error("FATAL: proxyErrorKey not found in context for error handler.")
        return
    }
    if errors.IsClientNetworkError(err) {
        *proxyErrPtr = errors.NewAPIError(errors.ErrBadRequest, "Client connection closed")
    } else {
        *proxyErrPtr = errors.NewAPIError(errors.ErrBadGateway, err.Error())
    }
    if _, ok := rw.(*httptest.ResponseRecorder); ok {
        return
    }
    if writer, ok := rw.(interface{ Written() bool }); ok {
        if writer.Written() {
            return
        }
    }
    rw.WriteHeader((*proxyErrPtr).HTTPStatus)
}

func (h *ProxyHandler) createLogEvent(c *gin.Context, startTime time.Time, corrID, modelName string, res *service.RequestResources, logType models.LogType, isPreciseRouting bool) *models.RequestFinishedEvent {
    event := &models.RequestFinishedEvent{
        RequestLog: models.RequestLog{
            RequestTime:   startTime,
            ModelName:     modelName,
            RequestPath:   c.Request.URL.Path,
            UserAgent:     c.Request.UserAgent(),
            CorrelationID: corrID,
            LogType:       logType,
            Metadata:      make(datatypes.JSONMap),
        },
        CorrelationID:    corrID,
        IsPreciseRouting: isPreciseRouting,
    }
    if _, exists := c.Get(middleware.RedactedBodyKey); exists {
        event.RequestLog.Metadata["request_body_present"] = true
    }
    if redactedAuth, exists := c.Get(middleware.RedactedAuthHeaderKey); exists {
        event.RequestLog.Metadata["authorization_header"] = redactedAuth.(string)
    }
    if authTokenValue, exists := c.Get("authToken"); exists {
        if authToken, ok := authTokenValue.(*models.AuthToken); ok {
            event.AuthTokenID = &authToken.ID
        }
    }
    if res != nil {
        event.KeyID = res.APIKey.ID
        event.GroupID = res.KeyGroup.ID
        if res.UpstreamEndpoint != nil {
            event.UpstreamID = &res.UpstreamEndpoint.ID
            event.UpstreamURL = &res.UpstreamEndpoint.URL
        }
        if res.ProxyConfig != nil {
            event.ProxyID = &res.ProxyConfig.ID
        }
    }
    return event
}

func (h *ProxyHandler) getResourcesForRequest(c *gin.Context, modelName string, groupName string, isPreciseRouting bool) (*service.RequestResources, error) {
    authTokenValue, exists := c.Get("authToken")
    if !exists {
        return nil, errors.NewAPIError(errors.ErrUnauthorized, "Auth token not found in context")
    }
    authToken, ok := authTokenValue.(*models.AuthToken)
    if !ok {
        return nil, errors.NewAPIError(errors.ErrInternalServer, "Invalid auth token type in context")
    }
    if isPreciseRouting {
        return h.resourceService.GetResourceFromGroup(authToken, groupName)
    }
    return h.resourceService.GetResourceFromBasePool(authToken, modelName)
}

func errToJSON(c *gin.Context, corrID string, apiErr *errors.APIError) {
    c.JSON(apiErr.HTTPStatus, gin.H{
        "error":          apiErr,
        "correlation_id": corrID,
    })
}

type bufferPool struct{}

func (b *bufferPool) Get() []byte       { return make([]byte, 32*1024) }
func (b *bufferPool) Put(bytes []byte)  {}

func extractUsage(body []byte) (promptTokens int, completionTokens int) {
    var data struct {
        UsageMetadata struct {
            PromptTokenCount     int `json:"promptTokenCount"`
            CandidatesTokenCount int `json:"candidatesTokenCount"`
        } `json:"usageMetadata"`
    }
    if err := json.Unmarshal(body, &data); err == nil {
        return data.UsageMetadata.PromptTokenCount, data.UsageMetadata.CandidatesTokenCount
    }
    return 0, 0
}

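// Example body extractUsage understands (Gemini-style usage metadata):
//
//	{"usageMetadata":{"promptTokenCount":12,"candidatesTokenCount":34}}
//
// Anything that does not unmarshal into that shape counts as zero tokens.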
func (h *ProxyHandler) buildFinalRequestConfig(globalSettings *models.SystemSettings, groupConfig *models.RequestConfig) *models.RequestConfig {
    customHeadersJSON, _ := json.Marshal(globalSettings.CustomHeaders)
    var customHeadersMap datatypes.JSONMap
    _ = json.Unmarshal(customHeadersJSON, &customHeadersMap)
    finalConfig := &models.RequestConfig{
        CustomHeaders:         customHeadersMap,
        EnableStreamOptimizer: globalSettings.EnableStreamOptimizer,
        StreamMinDelay:        globalSettings.StreamMinDelay,
        StreamMaxDelay:        globalSettings.StreamMaxDelay,
        StreamShortTextThresh: globalSettings.StreamShortTextThresh,
        StreamLongTextThresh:  globalSettings.StreamLongTextThresh,
        StreamChunkSize:       globalSettings.StreamChunkSize,
        EnableFakeStream:      globalSettings.EnableFakeStream,
        FakeStreamInterval:    globalSettings.FakeStreamInterval,
    }
    if groupConfig == nil {
        return finalConfig
    }
    groupConfigJSON, err := json.Marshal(groupConfig)
    if err != nil {
        h.logger.WithError(err).Error("Failed to marshal group request config for merging.")
        return finalConfig
    }
    if err := json.Unmarshal(groupConfigJSON, finalConfig); err != nil {
        h.logger.WithError(err).Error("Failed to unmarshal group request config for merging.")
        return finalConfig
    }
    return finalConfig
}

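// Merge strategy above: the group config is overlaid onto the global defaults
// via a JSON round-trip, marshaling the group config and unmarshaling it into
// finalConfig, so any field the group config emits overwrites the default
// while fields absent from its JSON are left alone.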
func (h *ProxyHandler) handleListModelsRequest(c *gin.Context) {
    authTokenValue, exists := c.Get("authToken")
    if !exists {
        errToJSON(c, uuid.New().String(), errors.NewAPIError(errors.ErrUnauthorized, "Auth token not found in context"))
        return
    }
    authToken, ok := authTokenValue.(*models.AuthToken)
    if !ok {
        errToJSON(c, uuid.New().String(), errors.NewAPIError(errors.ErrInternalServer, "Invalid auth token type in context"))
        return
    }
    modelNames := h.resourceService.GetAllowedModelsForToken(authToken)
    if strings.Contains(c.Request.URL.Path, "/v1beta/") {
        h.respondWithGeminiFormat(c, modelNames)
    } else {
        h.respondWithOpenAIFormat(c, modelNames)
    }
}

func (h *ProxyHandler) respondWithOpenAIFormat(c *gin.Context, modelNames []string) {
    type ModelEntry struct {
        ID      string `json:"id"`
        Object  string `json:"object"`
        Created int64  `json:"created"`
        OwnedBy string `json:"owned_by"`
    }
    type ModelListResponse struct {
        Object string       `json:"object"`
        Data   []ModelEntry `json:"data"`
    }
    data := make([]ModelEntry, len(modelNames))
    for i, name := range modelNames {
        data[i] = ModelEntry{
            ID:      name,
            Object:  "model",
            Created: time.Now().Unix(),
            OwnedBy: "gemini-balancer",
        }
    }
    response := ModelListResponse{
        Object: "list",
        Data:   data,
    }
    c.JSON(http.StatusOK, response)
}

func (h *ProxyHandler) respondWithGeminiFormat(c *gin.Context, modelNames []string) {
    type GeminiModelEntry struct {
        Name                       string   `json:"name"`
        Version                    string   `json:"version"`
        DisplayName                string   `json:"displayName"`
        Description                string   `json:"description"`
        SupportedGenerationMethods []string `json:"supportedGenerationMethods"`
        InputTokenLimit            int      `json:"inputTokenLimit"`
        OutputTokenLimit           int      `json:"outputTokenLimit"`
    }
    type GeminiModelListResponse struct {
        Models []GeminiModelEntry `json:"models"`
    }
    models := make([]GeminiModelEntry, len(modelNames))
    for i, name := range modelNames {
        models[i] = GeminiModelEntry{
            Name:                       fmt.Sprintf("models/%s", name),
            Version:                    "1.0.0",
            DisplayName:                name,
            Description:                "Served by Gemini Balancer",
            SupportedGenerationMethods: []string{"generateContent", "streamGenerateContent"},
            InputTokenLimit:            8192,
            OutputTokenLimit:           2048,
        }
    }
    response := GeminiModelListResponse{Models: models}
    c.JSON(http.StatusOK, response)
}
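// Response sketches for the two list formats (values illustrative):
//
//	OpenAI: {"object":"list","data":[{"id":"gemini-pro","object":"model","created":1700000000,"owned_by":"gemini-balancer"}]}
//	Gemini: {"models":[{"name":"models/gemini-pro","displayName":"gemini-pro","version":"1.0.0", ...}]}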
46
internal/handlers/setting_handler.go
Normal file
@@ -0,0 +1,46 @@
// Filename: internal/handlers/setting_handler.go
package handlers

import (
    "gemini-balancer/internal/errors"
    "gemini-balancer/internal/response"
    "gemini-balancer/internal/settings"

    "github.com/gin-gonic/gin"
)

type SettingHandler struct {
    settingsManager *settings.SettingsManager
}

func NewSettingHandler(settingsManager *settings.SettingsManager) *SettingHandler {
    return &SettingHandler{settingsManager: settingsManager}
}

func (h *SettingHandler) GetSettings(c *gin.Context) {
    settings := h.settingsManager.GetSettings()
    response.Success(c, settings)
}

func (h *SettingHandler) UpdateSettings(c *gin.Context) {
    var newSettingsMap map[string]interface{}
    if err := c.ShouldBindJSON(&newSettingsMap); err != nil {
        response.Error(c, errors.NewAPIError(errors.ErrInvalidJSON, err.Error()))
        return
    }
    if err := h.settingsManager.UpdateSettings(newSettingsMap); err != nil {
        // TODO: return a more specific error based on the error type.
        response.Error(c, errors.NewAPIError(errors.ErrInternalServer, err.Error()))
        return
    }
    response.Success(c, gin.H{"message": "Settings update request processed successfully."})
}

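// Example update payload (a sketch; the keys shown are assumptions and must
// match fields the SettingsManager knows how to apply):
//
//	PUT /admin/settings
//	{"max_retries": 3, "enable_streaming_retry": true}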
// ResetSettingsToDefaults resets all settings to their default values.
func (h *SettingHandler) ResetSettingsToDefaults(c *gin.Context) {
    defaultSettings, err := h.settingsManager.ResetAndSaveSettings()
    if err != nil {
        response.Error(c, errors.NewAPIError(errors.ErrInternalServer, "Failed to reset settings: "+err.Error()))
        return
    }
    response.Success(c, defaultSettings)
}
51
internal/handlers/task_handler.go
Normal file
@@ -0,0 +1,51 @@
package handlers

import (
    "encoding/json"
    "gemini-balancer/internal/errors"
    "gemini-balancer/internal/response"
    "gemini-balancer/internal/task"

    "github.com/sirupsen/logrus"

    "github.com/gin-gonic/gin"
)

type TaskHandler struct {
    taskService *task.Task
    logger      *logrus.Entry
}

func NewTaskHandler(taskService *task.Task, logger *logrus.Logger) *TaskHandler {
    return &TaskHandler{
        taskService: taskService,
        logger:      logger.WithField("component", "TaskHandler📦"),
    }
}

// GetTaskStatus
// GET /admin/tasks/:id
func (h *TaskHandler) GetTaskStatus(c *gin.Context) {
    taskID := c.Param("id")
    if taskID == "" {
        response.Error(c, errors.NewAPIError(errors.ErrBadRequest, "task ID is required"))
        return
    }

    taskStatus, err := h.taskService.GetStatus(taskID)
    if err != nil {
        // TODO: handle specific error types returned by the service layer more precisely.
        response.Error(c, errors.NewAPIError(errors.ErrResourceNotFound, err.Error()))
        return
    }
    // [Probe] Log the status object read and parsed from the store before returning it to the frontend.
    loggerWithTaskID := h.logger.WithField("task_id", taskID)
    loggerWithTaskID.Debugf("Status read from store, ABOUT TO BE SENT to frontend: %+v", taskStatus)
    // [Probe] Manually serialize and log.
    if h.logger.Logger.IsLevelEnabled(logrus.DebugLevel) {
        jsonData, _ := json.Marshal(taskStatus)
        loggerWithTaskID.Debugf("Manually marshalled JSON to be sent to frontend: %s", string(jsonData))
    }
    response.Success(c, taskStatus)
}
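// Usage sketch: GET /admin/tasks/<task-id> returns the stored task status as
// JSON; an unknown or expired task ID currently surfaces as ErrResourceNotFound.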
51
internal/handlers/tokens_handler.go
Normal file
@@ -0,0 +1,51 @@
// Filename: internal/handlers/tokens_handler.go

package handlers

import (
    "gemini-balancer/internal/errors"
    "gemini-balancer/internal/models"
    "gemini-balancer/internal/response"
    "gemini-balancer/internal/service"

    "github.com/gin-gonic/gin"
    "gorm.io/gorm"
)

type TokensHandler struct {
    db           *gorm.DB
    tokenManager *service.TokenManager
}

func NewTokensHandler(db *gorm.DB, tm *service.TokenManager) *TokensHandler {
    return &TokensHandler{
        db:           db,
        tokenManager: tm,
    }
}

func (h *TokensHandler) GetAllTokens(c *gin.Context) {
    tokensFromCache := h.tokenManager.GetAllTokens()
    // TODO: as with KeyGroupResponse, introduce a TokenResponse DTO to shape the data.
    response.Success(c, tokensFromCache)
}

func (h *TokensHandler) UpdateTokens(c *gin.Context) {
    var incomingTokens []*models.TokenUpdateRequest
    if err := c.ShouldBindJSON(&incomingTokens); err != nil {
        response.Error(c, errors.NewAPIError(errors.ErrInvalidJSON, err.Error()))
        return
    }

    if err := h.tokenManager.BatchUpdateTokens(incomingTokens); err != nil {
        response.Error(c, errors.NewAPIError(errors.ErrInternalServer, "Failed to update tokens: "+err.Error()))
        return
    }
    response.Success(c, gin.H{"message": "Tokens updated successfully."})
}

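// Example batch payload (illustrative; field names follow models.TokenUpdateRequest):
//
//	PUT /admin/tokens
//	[{"ID":1,"Token":"sk-example","Status":"active","IsAdmin":false,"AllowedGroupIDs":[1,2]}]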
// [TODO]
// func (h *TokensHandler) CreateToken(c *gin.Context) {
//     ... database write ...
//     h.tokenManager.Invalidate() // invalidate the cache immediately after the write
// }
74
internal/logging/logging.go
Normal file
@@ -0,0 +1,74 @@
// Filename: internal/logging/logging.go

package logging

import (
    "gemini-balancer/internal/config"
    "io"
    "os"
    "path/filepath"

    "github.com/sirupsen/logrus"
)

func NewLogger(cfg *config.Config) *logrus.Logger {
    logger := logrus.New()

    // 1. Set the log level
    level, err := logrus.ParseLevel(cfg.Log.Level)
    if err != nil {
        logger.WithField("configured_level", cfg.Log.Level).Warn("Invalid log level specified, defaulting to 'info'.")
        level = logrus.InfoLevel
    }
    logger.SetLevel(level)

    // 2. Set the log format
    if cfg.Log.Format == "json" {
        logger.SetFormatter(&logrus.JSONFormatter{
            TimestampFormat: "2006-01-02T15:04:05.000Z07:00",
            FieldMap: logrus.FieldMap{
                logrus.FieldKeyTime:  "timestamp",
                logrus.FieldKeyLevel: "level",
                logrus.FieldKeyMsg:   "message",
            },
        })
    } else {
        logger.SetFormatter(&logrus.TextFormatter{
            FullTimestamp:   true,
            TimestampFormat: "2006-01-02 15:04:05",
        })
    }

    // 3. Set the log output
    if cfg.Log.EnableFile {
        if cfg.Log.FilePath == "" {
            logger.Warn("Log file is enabled but no file path is specified. Logging to console only.")
            logger.SetOutput(os.Stdout)
            return logger
        }

        logDir := filepath.Dir(cfg.Log.FilePath)
        if err := os.MkdirAll(logDir, 0755); err != nil {
            logger.WithError(err).Warn("Failed to create log directory. Logging to console only.")
            logger.SetOutput(os.Stdout)
            return logger
        }

        logFile, err := os.OpenFile(cfg.Log.FilePath, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0666)
        if err != nil {
            logger.WithError(err).Warn("Failed to open log file. Logging to console only.")
            logger.SetOutput(os.Stdout)
            return logger
        }

        // Output to both console and file
        logger.SetOutput(io.MultiWriter(os.Stdout, logFile))
        logger.WithField("log_file_path", cfg.Log.FilePath).Info("Logging is now configured to output to both console and file.")
    } else {
        // Console only
        logger.SetOutput(os.Stdout)
    }

    logger.Info("Root logger initialized.")
    return logger
}
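// Config sketch this reads from (field names per cfg.Log above; the YAML
// layout itself is an assumption):
//
//	log:
//	  level: "debug"
//	  format: "json"
//	  enable_file: true
//	  file_path: "logs/app.log"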
82
internal/middleware/auth.go
Normal file
@@ -0,0 +1,82 @@
// Filename: internal/middleware/auth.go
package middleware

import (
    "gemini-balancer/internal/service"
    "net/http"
    "strings"

    "github.com/gin-gonic/gin"
)

// === Admin API authentication (/admin/* API routes) ===

func APIAdminAuthMiddleware(securityService *service.SecurityService) gin.HandlerFunc {
    return func(c *gin.Context) {
        tokenValue := extractBearerToken(c)
        if tokenValue == "" {
            c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "Authorization token is missing"})
            return
        }
        authToken, err := securityService.AuthenticateToken(tokenValue)
        if err != nil || !authToken.IsAdmin {
            c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "Invalid or non-admin token"})
            return
        }
        c.Set("adminUser", authToken)
        c.Next()
    }
}

// === /v1 proxy authentication ===

func ProxyAuthMiddleware(securityService *service.SecurityService) gin.HandlerFunc {
    return func(c *gin.Context) {
        tokenValue := extractProxyToken(c)
        if tokenValue == "" {
            c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "API key is missing from request"})
            return
        }
        authToken, err := securityService.AuthenticateToken(tokenValue)
        if err != nil {
            // Keep the message generic to avoid leaking detail.
            c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "Invalid or inactive token provided"})
            return
        }
        c.Set("authToken", authToken)
        c.Next()
    }
}

func extractProxyToken(c *gin.Context) string {
    if key := c.Query("key"); key != "" {
        return key
    }
    authHeader := c.GetHeader("Authorization")
    if authHeader != "" {
        if strings.HasPrefix(authHeader, "Bearer ") {
            return strings.TrimPrefix(authHeader, "Bearer ")
        }
    }
    if key := c.GetHeader("X-Api-Key"); key != "" {
        return key
    }
    if key := c.GetHeader("X-Goog-Api-Key"); key != "" {
        return key
    }
    return ""
}

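// Token lookup order used above: the ?key= query parameter, then
// Authorization: Bearer, then X-Api-Key, then X-Goog-Api-Key.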
// === Helpers ===

func extractBearerToken(c *gin.Context) string {
    authHeader := c.GetHeader("Authorization")
    if authHeader == "" {
        return ""
    }
    parts := strings.Split(authHeader, " ")
    if len(parts) == 2 && parts[0] == "Bearer" {
        return parts[1]
    }
    return ""
}
84
internal/middleware/log.go
Normal file
@@ -0,0 +1,84 @@
// Filename: internal/middleware/log.go
package middleware

import (
    "bytes"
    "io"
    "regexp"
    "time"

    "github.com/gin-gonic/gin"
    "github.com/sirupsen/logrus"
)

const RedactedBodyKey = "redactedBody"
const RedactedAuthHeaderKey = "redactedAuthHeader"
const RedactedValue = `"[REDACTED]"`

func RedactionMiddleware() gin.HandlerFunc {
    // Pre-compile regex for efficiency
    jsonKeyPattern := regexp.MustCompile(`("api_key"|"keys")\s*:\s*"[^"]*"`)
    bearerTokenPattern := regexp.MustCompile(`^(Bearer\s+)\S+$`)
    return func(c *gin.Context) {
        // --- 1. Redact Request Body ---
        if c.Request.Method == "POST" || c.Request.Method == "PUT" || c.Request.Method == "DELETE" {
            if bodyBytes, err := io.ReadAll(c.Request.Body); err == nil {
                c.Request.Body = io.NopCloser(bytes.NewBuffer(bodyBytes))
                bodyString := string(bodyBytes)
                redactedBody := jsonKeyPattern.ReplaceAllString(bodyString, `$1:`+RedactedValue)
                c.Set(RedactedBodyKey, redactedBody)
            }
        }
        // --- 2. Redact Authorization Header ---
        authHeader := c.GetHeader("Authorization")
        if authHeader != "" {
            if bearerTokenPattern.MatchString(authHeader) {
                redactedHeader := bearerTokenPattern.ReplaceAllString(authHeader, `${1}[REDACTED]`)
                c.Set(RedactedAuthHeaderKey, redactedHeader)
            }
        }
        c.Next()
    }
}

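// Redaction examples (illustrative inputs):
//
//	{"api_key":"AIzaSy..."}        → {"api_key":"[REDACTED]"}
//	Authorization: Bearer abc.def  → Authorization: Bearer [REDACTED]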
// LogrusLogger is a Gin middleware that logs requests using a Logrus logger.
// It consumes redacted data prepared by the RedactionMiddleware.
func LogrusLogger(logger *logrus.Logger) gin.HandlerFunc {
    return func(c *gin.Context) {
        start := time.Now()
        path := c.Request.URL.Path

        // Process request
        c.Next()

        // After request, gather data and log
        latency := time.Since(start)
        statusCode := c.Writer.Status()

        entry := logger.WithFields(logrus.Fields{
            "status_code": statusCode,
            "latency_ms":  latency.Milliseconds(),
            "client_ip":   c.ClientIP(),
            "method":      c.Request.Method,
            "path":        path,
        })

        if redactedBody, exists := c.Get(RedactedBodyKey); exists {
            entry = entry.WithField("body", redactedBody)
        }

        if redactedAuth, exists := c.Get(RedactedAuthHeaderKey); exists {
            entry = entry.WithField("authorization", redactedAuth)
        }

        if len(c.Errors) > 0 {
            entry.Error(c.Errors.String())
        } else {
            entry.Info("request handled")
        }
    }
}
31
internal/middleware/security.go
Normal file
@@ -0,0 +1,31 @@
// Filename: internal/middleware/security.go

package middleware

import (
    "gemini-balancer/internal/service"
    "gemini-balancer/internal/settings"
    "net/http"

    "github.com/gin-gonic/gin"
)

func IPBanMiddleware(securityService *service.SecurityService, settingsManager *settings.SettingsManager) gin.HandlerFunc {
    return func(c *gin.Context) {
        if !settingsManager.IsIPBanEnabled() {
            c.Next()
            return
        }
        ip := c.ClientIP()
        isBanned, err := securityService.IsIPBanned(c.Request.Context(), ip)
        if err != nil {
            c.Next()
            return
        }
        if isBanned {
            c.AbortWithStatusJSON(http.StatusForbidden, gin.H{"error": "Your IP has been temporarily banned. Please try again later."})
            return
        }
        c.Next()
    }
}
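// Design note on the error branch above: ban-store lookup failures fail open
// (the request proceeds), so a store outage cannot lock every client out.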
54
internal/middleware/web.go
Normal file
@@ -0,0 +1,54 @@
// Filename: internal/middleware/web.go
package middleware

import (
    "gemini-balancer/internal/service"
    "log"
    "net/http"

    "github.com/gin-gonic/gin"
)

const (
    AdminSessionCookie = "gemini_admin_session"
)

func SetAdminSessionCookie(c *gin.Context, adminToken string) {
    c.SetCookie(AdminSessionCookie, adminToken, 3600*24*7, "/", "", false, true)
}

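// Cookie attributes set above: 7-day max age, path "/", empty domain,
// Secure=false, HttpOnly=true. Secure=false may be worth revisiting when the
// dashboard is served over HTTPS.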
func ClearAdminSessionCookie(c *gin.Context) {
    c.SetCookie(AdminSessionCookie, "", -1, "/", "", false, true)
}

func ExtractTokenFromCookie(c *gin.Context) string {
    cookie, err := c.Cookie(AdminSessionCookie)
    if err != nil {
        return ""
    }
    return cookie
}

func WebAdminAuthMiddleware(authService *service.SecurityService) gin.HandlerFunc {
    return func(c *gin.Context) {
        cookie := ExtractTokenFromCookie(c)
        log.Printf("[WebAuth_Guard] Intercepting request for: %s", c.Request.URL.Path)
        log.Printf("[WebAuth_Guard] Found session cookie value: '%s'", cookie)
        authToken, err := authService.AuthenticateToken(cookie)
        if err != nil {
            log.Printf("[WebAuth_Guard] FATAL: AuthenticateToken FAILED. Error: %v. Redirecting to /login.", err)
        } else if !authToken.IsAdmin {
            log.Printf("[WebAuth_Guard] FATAL: Token validated, but IsAdmin is FALSE. Redirecting to /login.")
        } else {
            log.Printf("[WebAuth_Guard] SUCCESS: Token validated and IsAdmin is TRUE. Allowing access.")
        }
        if err != nil || !authToken.IsAdmin {
            ClearAdminSessionCookie(c)
            c.Redirect(http.StatusFound, "/login")
            c.Abort()
            return
        }
        c.Set("adminUser", authToken)
        c.Next()
    }
}
78
internal/models/dto.go
Normal file
@@ -0,0 +1,78 @@
package models

import "time"

// ========= ViewModel / DTOs for Dashboard API =========
type StatCard struct {
    Value         float64 `json:"value"`
    SubValue      any     `json:"sub_value,omitempty"`
    SubValueTip   string  `json:"sub_value_tip,omitempty"`
    Trend         float64 `json:"trend"`
    TrendIsGrowth bool    `json:"trend_is_growth"`
}
type DashboardStatsResponse struct {
    KeyCount             StatCard                     `json:"key_count"`
    RPM                  StatCard                     `json:"rpm"`
    RequestCount24h      StatCard                     `json:"request_count_24h"`
    ErrorRate24h         StatCard                     `json:"error_rate_24h"`
    KeyStatusCount       map[APIKeyStatus]int64       `json:"key_status_count"`
    MasterStatusCount    map[MasterAPIKeyStatus]int64 `json:"master_status_count"`
    TokenCount           map[string]any               `json:"token_count"`
    UpstreamHealthStatus map[string]string            `json:"upstream_health_status,omitempty"`
    RequestCounts        map[string]int64             `json:"request_counts"`
}
type ChartDataset struct {
    Label string  `json:"label"`
    Data  []int64 `json:"data"`
    Color string  `json:"color"`
}
type ChartData struct {
    Labels   []string       `json:"labels"`
    Datasets []ChartDataset `json:"datasets"`
}

// TokenUpdateRequest DTO for binding the PUT /admin/tokens request
type TokenUpdateRequest struct {
    ID              uint   `json:"ID"`
    Token           string `json:"Token"`
    Description     string `json:"Description"`
    Tag             string `json:"Tag"`
    Status          string `json:"Status"`
    IsAdmin         bool   `json:"IsAdmin"`
    AllowedGroupIDs []uint `json:"AllowedGroupIDs"`
}

// KeyTestResult is a DTO representing the test result for a single key.
// ===================================================================
type KeyTestResult struct {
    Key     string `json:"key"`
    Status  string `json:"status"` // "valid", "invalid", "error"
    Message string `json:"message"`
}

// APIKeyQueryParams defines the parameters for listing/searching keys.
type APIKeyQueryParams struct {
    KeyGroupID uint   `form:"-"`
    Page       int    `form:"page"`
    PageSize   int    `form:"limit"`
    Status     string `form:"status"`
    Keyword    string `form:"keyword"`
}

// APIKeyDetails is a DTO that combines APIKey info with its contextual status from the mapping.
|
||||
type APIKeyDetails struct {
|
||||
// Embedded APIKey fields
|
||||
ID uint `json:"id"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
UpdatedAt time.Time `json:"updated_at"`
|
||||
APIKey string `json:"api_key"`
|
||||
MasterStatus MasterAPIKeyStatus `json:"master_status"`
|
||||
|
||||
// Mapping-specific fields
|
||||
Status APIKeyStatus `json:"status"`
|
||||
LastError string `json:"last_error"`
|
||||
ConsecutiveErrorCount int `json:"consecutive_error_count"`
|
||||
LastUsedAt *time.Time `json:"last_used_at"`
|
||||
CooldownUntil *time.Time `json:"cooldown_until"`
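// NOTE: EncryptedKey carries no json tag, so encoding/json serializes it as "EncryptedKey"; tag it `json:"-"` if it must never reach API clients.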
EncryptedKey string
}
69
internal/models/events.go
Normal file
@@ -0,0 +1,69 @@
// Filename: internal/models/events.go
package models

import (
"gemini-balancer/internal/errors"
"time"
)

// Topic names for the event bus.
const (
TopicRequestFinished = "events:request_finished"
TopicKeyStatusChanged = "events:key_status_changed"
TopicUpstreamHealthChanged = "events:upstream_health_changed"
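// NOTE: unlike its siblings, the next topic omits the "events:" prefix; subscribers must match this literal value.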
TopicMasterKeyStatusChanged = "master_key_status_changed"
TopicImportGroupCompleted = "events:import_group_completed"
)

type RequestFinishedEvent struct {
RequestLog
KeyID uint
GroupID uint
IsSuccess bool
StatusCode int
Error *errors.APIError
CorrelationID string `json:"correlation_id,omitempty"`
UpstreamID *uint `json:"upstream_id"`
UpstreamURL *string `json:"upstream_url,omitempty"`
IsPreciseRouting bool `json:"is_precise_routing"`
}

type KeyStatusChangedEvent struct {
KeyID uint
GroupID uint
OldStatus APIKeyStatus
NewStatus APIKeyStatus
ChangeReason string `json:"change_reason"`
ChangedAt time.Time `json:"changed_at"`
}

type UpstreamHealthChangedEvent struct {
UpstreamID uint `json:"upstream_id"`
UpstreamURL string `json:"upstream_url"`
OldStatus string `json:"old_status"` // e.g., "healthy", "unhealthy"
NewStatus string `json:"new_status"`
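// NOTE: encoding/json marshals time.Duration as integer nanoseconds, so the "latency_ms" name below only holds if the producer converts the value first.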
Latency time.Duration `json:"latency_ms"` // latency in milliseconds
Reason string `json:"reason"` // reason for the status change, e.g. "timeout", "status_503"
CheckedAt time.Time `json:"checked_at"`
}

const TopicProxyStatusChanged = "proxy:status_changed"

type ProxyStatusChangedEvent struct {
ProxyID uint `json:"proxy_id"`
Action string `json:"action"` // "created", "updated", "deleted"
}

type MasterKeyStatusChangedEvent struct {
KeyID uint `json:"key_id"`
OldMasterStatus MasterAPIKeyStatus `json:"old_master_status"`
NewMasterStatus MasterAPIKeyStatus `json:"new_master_status"`
ChangeReason string `json:"change_reason"`
ChangedAt time.Time `json:"changed_at"`
}

type ImportGroupCompletedEvent struct {
GroupID uint `json:"group_id"`
KeyIDs []uint `json:"key_ids"`
CompletedAt time.Time `json:"completed_at"`
}
246
internal/models/models.go
Normal file
@@ -0,0 +1,246 @@
// Filename: internal/models/models.go
package models

import (
"time"

"gorm.io/datatypes"
"gorm.io/gorm"
)

// ========= Custom types and constants =========
type APIKeyStatus string
type MasterAPIKeyStatus string
type PollingStrategy string
type FileProcessingState string
type LogType string

const (
// --- Operational statuses (used in the group-key mapping table) ---
StatusPendingValidation APIKeyStatus = "PENDING_VALIDATION"
StatusActive APIKeyStatus = "ACTIVE"
StatusCooldown APIKeyStatus = "COOLDOWN"
StatusDisabled APIKeyStatus = "DISABLED"
StatusBanned APIKeyStatus = "BANNED"

// --- Identity statuses (used on the APIKey entity) ---
MasterStatusActive MasterAPIKeyStatus = "ACTIVE" // valid
MasterStatusRevoked MasterAPIKeyStatus = "REVOKED" // permanently revoked
MasterStatusManuallyDisabled MasterAPIKeyStatus = "MANUALLY_DISABLED" // manually disabled globally

StrategyWeighted PollingStrategy = "weighted"
StrategySequential PollingStrategy = "sequential"
StrategyRandom PollingStrategy = "random"
FileProcessing FileProcessingState = "PROCESSING"
FileActive FileProcessingState = "ACTIVE"
FileFailed FileProcessingState = "FAILED"

LogTypeFinal LogType = "FINAL" // Represents the final outcome of a request, including all retries.
LogTypeRetry LogType = "RETRY" // Represents a single, failed attempt that triggered a retry.
)

// ========= Core database models =========
type KeyGroup struct {
ID uint `gorm:"primarykey"`
CreatedAt time.Time
UpdatedAt time.Time
Name string `gorm:"type:varchar(100);unique;not null"`
DisplayName string `gorm:"type:varchar(255)"`
Description string `gorm:"type:text"`
PollingStrategy PollingStrategy `gorm:"type:varchar(20);not null;default:sequential"`
EnableProxy bool `gorm:"not null;default:false"`
Sort int `gorm:"default:0"` // business-logic sort order (reserved)
Order int `gorm:"default:0"` // dedicated to drag-and-drop ordering in the UI
LastValidatedAt *time.Time
Mappings []*GroupAPIKeyMapping `gorm:"foreignKey:KeyGroupID"`
AllowedModels []*GroupModelMapping `gorm:"foreignKey:GroupID"`
AllowedUpstreams []*UpstreamEndpoint `gorm:"many2many:group_upstream_access;"`
ChannelType string `gorm:"type:varchar(50);not null;default:'gemini'"`
Settings *GroupSettings `gorm:"foreignKey:GroupID"`
RequestConfigID *uint `json:"request_config_id"`
RequestConfig *RequestConfig `gorm:"foreignKey:RequestConfigID"`
}

type APIKey struct {
ID uint `gorm:"primarykey"`
CreatedAt time.Time
UpdatedAt time.Time
DeletedAt gorm.DeletedAt `gorm:"index"`
EncryptedKey string `gorm:"type:text;not null"`
APIKeyHash string `gorm:"type:varchar(64);unique;not null;index"`
MaskedKey string `gorm:"type:varchar(50);index"`
MasterStatus MasterAPIKeyStatus `gorm:"type:varchar(25);not null;default:ACTIVE;index"`
ProxyID *uint `gorm:"index"`
Mappings []*GroupAPIKeyMapping `gorm:"foreignKey:APIKeyID"`
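// APIKey holds the decrypted plaintext at runtime only; gorm:"-" keeps it out of the database.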
APIKey string `gorm:"-"`
}

// GroupAPIKeyMapping carries the operational status of a key within a group.
type GroupAPIKeyMapping struct {
KeyGroupID uint `gorm:"primaryKey"`
APIKeyID uint `gorm:"primaryKey"`
Status APIKeyStatus `gorm:"type:varchar(25);not null;default:PENDING_VALIDATION;index"`
LastError string `gorm:"type:text"`
ConsecutiveErrorCount int `gorm:"not null;default:0"`
LastUsedAt *time.Time
CooldownUntil *time.Time
CreatedAt time.Time
UpdatedAt time.Time
APIKey *APIKey `gorm:"foreignKey:APIKeyID"`
KeyGroup *KeyGroup `gorm:"foreignKey:KeyGroupID"`
}

type RequestLog struct {
ID uint `gorm:"primarykey"`
RequestTime time.Time `gorm:"index"`
LatencyMs int
IsSuccess bool
StatusCode int
ModelName string `gorm:"type:varchar(100);index"`
GroupID *uint `gorm:"index"`
KeyID *uint `gorm:"index"`
AuthTokenID *uint
UpstreamID *uint
ProxyID *uint
Retries int `gorm:"not null;default:0"`
ErrorCode string `gorm:"type:varchar(50);index"`
ErrorMessage string `gorm:"type:text"`
RequestPath string `gorm:"type:varchar(500)"`
UserAgent string `gorm:"type:varchar(512)"`
PromptTokens int `gorm:"not null;default:0"`
CompletionTokens int `gorm:"not null;default:0"`
CorrelationID string `gorm:"type:varchar(36);index"`
LogType LogType `gorm:"type:varchar(20);index"`
Metadata datatypes.JSONMap `gorm:"type:json"`
}

// GroupModelMapping is the group-to-model relation table.
type GroupModelMapping struct {
ID uint `gorm:"primarykey"`
GroupID uint `gorm:"index;uniqueIndex:idx_group_model_unique"`
ModelName string `gorm:"type:varchar(100);not null;uniqueIndex:idx_group_model_unique"`
}

// AuthToken
type AuthToken struct {
ID uint `gorm:"primarykey"`
CreatedAt time.Time
UpdatedAt time.Time
EncryptedToken string `gorm:"type:text;not null"`
TokenHash string `gorm:"type:varchar(64);unique;not null;index"`
Token string `gorm:"-"`
Description string `gorm:"type:text"`
Tag string `gorm:"type:varchar(100);index"`
IsAdmin bool `gorm:"not null;default:false"`
HasUnrestrictedAccess bool `gorm:"not null;default:false"`
Status string `gorm:"type:varchar(20);not null;default:active"`
AllowedGroups []*KeyGroup `gorm:"many2many:token_group_access;"`
}

// FileRecord
type FileRecord struct {
ID uint `gorm:"primarykey"`
CreatedAt time.Time
UpdatedAt time.Time
KeyID uint `gorm:"not null"`
Name string `gorm:"type:varchar(255);unique;not null"`
DisplayName string `gorm:"type:varchar(255)"`
MimeType string `gorm:"type:varchar(100);not null"`
SizeBytes int64 `gorm:"not null"`
Sha256Hash string `gorm:"type:varchar(64)"`
State FileProcessingState `gorm:"type:varchar(20);not null;default:PROCESSING"`
Uri string `gorm:"type:varchar(500);not null"`
ExpirationTime time.Time `gorm:"not null"`
}

// StatsHourly is the long-term historical store, enabling efficient queries for trend analysis.
type StatsHourly struct {
ID uint `gorm:"primarykey"`
Time time.Time `gorm:"uniqueIndex:idx_stats_hourly_unique"`
GroupID uint `gorm:"uniqueIndex:idx_stats_hourly_unique"`
ModelName string `gorm:"type:varchar(100);uniqueIndex:idx_stats_hourly_unique"`
RequestCount int64 `gorm:"not null;default:0"`
SuccessCount int64 `gorm:"not null;default:0"`
PromptTokens int64 `gorm:"not null;default:0"`
CompletionTokens int64 `gorm:"not null;default:0"`
}

type UpstreamEndpoint struct {
ID uint `gorm:"primarykey"`
CreatedAt time.Time
UpdatedAt time.Time
URL string `gorm:"type:varchar(500);unique;not null"`
Weight int `gorm:"not null;default:100;index"`
Status string `gorm:"type:varchar(20);not null;default:'active';index"`
Description string `gorm:"type:text"`
}
type ProxyConfig struct {
ID uint `gorm:"primarykey"`
CreatedAt time.Time
UpdatedAt time.Time
Address string `gorm:"type:varchar(255);unique;not null"`
Protocol string `gorm:"type:varchar(10);not null"`
Status string `gorm:"type:varchar(20);not null;default:'active';index"`
AssignedKeysCount int `gorm:"not null;default:0;index"`
Description string `gorm:"type:text"`
}
type Setting struct {
Key string `gorm:"primarykey;type:varchar(100)" json:"key"`
Value string `gorm:"type:text" json:"value"`
Name string `gorm:"type:varchar(100)" json:"name"`
Description string `gorm:"type:varchar(255)" json:"description"`
Type string `gorm:"type:varchar(20)" json:"type"`
Category string `gorm:"type:varchar(50);index" json:"category"`
DefaultValue string `gorm:"type:text" json:"default_value"`
}

// GroupSettings stores group-specific configuration overrides.
type GroupSettings struct {
GroupID uint `gorm:"primaryKey"`
SettingsJSON datatypes.JSON `gorm:"type:json"` // a serialized KeyGroupSettings is stored in this field
CreatedAt time.Time
UpdatedAt time.Time
}

// KeyGroupSettings defines every **operational** setting that can be overridden per group.
// It does not map to a database table directly; it is stored as a JSON object in GroupSettings.SettingsJSON.
type KeyGroupSettings struct {
// Health-check settings
EnableKeyCheck *bool `json:"enable_key_check,omitempty"`
KeyCheckModel *string `json:"key_check_model,omitempty"`
KeyCheckEndpoint *string `json:"key_check_endpoint,omitempty"`
KeyCheckConcurrency *int `json:"key_check_concurrency,omitempty"`
KeyCheckIntervalMinutes *int `json:"key_check_interval_minutes,omitempty"`
// Penalty settings
KeyBlacklistThreshold *int `json:"key_blacklist_threshold,omitempty"`
KeyCooldownMinutes *int `json:"key_cooldown_minutes,omitempty"`
MaxRetries *int `json:"max_retries,omitempty"`
// Smart Gateway
EnableSmartGateway *bool `json:"enable_smart_gateway,omitempty"`
}

// RequestConfig encapsulates, as a standalone database model, every parameter that directly shapes requests to the upstream API.
type RequestConfig struct {
ID uint `gorm:"primarykey"`
// Custom Headers
CustomHeaders datatypes.JSONMap `gorm:"type:json" json:"custom_headers"`

// Streaming Optimization
EnableStreamOptimizer bool `json:"enable_stream_optimizer"`
StreamMinDelay int `json:"stream_min_delay"`
StreamMaxDelay int `json:"stream_max_delay"`
StreamShortTextThresh int `json:"stream_short_text_thresh"`
StreamLongTextThresh int `json:"stream_long_text_thresh"`
StreamChunkSize int `json:"stream_chunk_size"`
EnableFakeStream bool `json:"enable_fake_stream"`
FakeStreamInterval int `json:"fake_stream_interval"`

// Model and Safety Settings
ModelSettings datatypes.JSON `gorm:"type:json" json:"model_settings"`

// Generic Overrides for parameters not explicitly defined
ConfigOverrides datatypes.JSONMap `gorm:"type:json" json:"config_overrides"`

CreatedAt time.Time
UpdatedAt time.Time
}
64
internal/models/request.go
Normal file
@@ -0,0 +1,64 @@
// Filename: internal/models/request.go
package models

// GeminiRequest mirrors the JSON request body sent by the client.
type GeminiRequest struct {
Contents []GeminiContent `json:"contents"`
GenerationConfig GenerationConfig `json:"generationConfig,omitempty"`
SafetySettings []SafetySetting `json:"safetySettings,omitempty"`
Tools []Tool `json:"tools,omitempty"`
}
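// Illustrative payload accepted by this struct (field names follow the json tags above):
//
//	{
//	  "contents": [{"role": "user", "parts": [{"text": "Hello"}]}],
//	  "generationConfig": {"temperature": 0.7, "maxOutputTokens": 256}
//	}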
// GeminiContent holds a role and its content parts.
type GeminiContent struct {
Role string `json:"role,omitempty"`
Parts []Part `json:"parts"`
}

// Part is a single piece of content (text or inline data).
type Part struct {
Text string `json:"text,omitempty"`
InlineData *InlineData `json:"inlineData,omitempty"`
}

// InlineData carries multimodal input such as images.
type InlineData struct {
MimeType string `json:"mimeType"`
Data string `json:"data"` // Base64-encoded data
}

// GenerationConfig controls the model's generation behavior.
type GenerationConfig struct {
Temperature float32 `json:"temperature,omitempty"`
TopP float32 `json:"topP,omitempty"`
TopK int `json:"topK,omitempty"`
MaxOutputTokens int `json:"maxOutputTokens,omitempty"`
StopSequences []string `json:"stopSequences,omitempty"`
}

// SafetySetting defines a safety-filter threshold.
type SafetySetting struct {
Category string `json:"category"`
Threshold string `json:"threshold"`
}

// Tool defines an external tool the model may invoke.
type Tool struct {
FunctionDeclarations any `json:"functionDeclarations,omitempty"`
}

// ========= Models for parsing streaming responses in the smart gateway =========
// GeminiSSEPayload parses the data field of an SSE (Server-Sent Events) event.
// It represents a single chunk received from the upstream.
type GeminiSSEPayload struct {
Candidates []*Candidate `json:"candidates"`
}

// Candidate contains the generated content and the reason the stream finished.
type Candidate struct {
// Content holds the concrete text returned in this chunk.
Content *GeminiContent `json:"content"`
// FinishReason tells us why the stream ended, e.g. "STOP" or "MAX_TOKENS".
// It is the core signal for our smart-retry logic.
FinishReason string `json:"finishReason"`
}
117
internal/models/runtime.go
Normal file
@@ -0,0 +1,117 @@
// Filename: internal/models/runtime.go
package models

// ========= Runtime configuration (not a database model; defaults provided via struct tags) =========
type SystemSettings struct {
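// Each field carries its UI metadata inline: `default` seeds the initial value, while
// `name`/`category`/`desc` drive the admin settings page (presumably parsed back into
// typed fields via helpers such as reflectutil.SetFieldFromString).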
DefaultUpstreamURL string `json:"default_upstream_url" default:"https://generativelanguage.googleapis.com/v1beta" name:"全局默认上游URL" category:"请求设置" desc:"当密钥组未指定任何专属上游时,将使用此URL作为最终的兜底。"`
RequestLogRetentionDays int `json:"request_log_retention_days" default:"7" name:"日志保留天数" category:"基础设置" desc:"请求日志在数据库中的保留天数。"`
RequestTimeoutSeconds int `json:"request_timeout_seconds" default:"600" name:"请求超时(秒)" category:"请求设置" desc:"转发请求的完整生命周期超时(秒)。"`
ConnectTimeoutSeconds int `json:"connect_timeout_seconds" default:"15" name:"连接超时(秒)" category:"请求设置" desc:"与上游服务建立新连接的超时时间(秒)。"`
MaxRetries int `json:"max_retries" default:"3" name:"最大重试次数" category:"请求设置" desc:"单个请求使用不同Key的最大重试次数。"`
BlacklistThreshold int `json:"blacklist_threshold" default:"3" name:"拉黑阈值" category:"密钥设置" desc:"一个Key连续失败多少次后进入冷却状态。"`
KeyCooldownMinutes int `json:"key_cooldown_minutes" default:"10" name:"密钥冷却时长(分钟)" category:"密钥设置" desc:"一个Key进入冷却状态后需要等待的时间,单位为分钟。"`
LogFlushIntervalSeconds int `json:"log_flush_interval_seconds" default:"10" name:"日志刷新间隔(秒)" category:"日志设置" desc:"异步日志写入数据库的间隔时间(秒)。"`

PollingStrategy PollingStrategy `json:"polling_strategy" default:"random" name:"全局轮询策略" category:"调度设置" desc:"智能聚合模式下,从所有可用密钥中选择一个的默认策略。可选值: sequential(顺序), random(随机), weighted(加权)。"`

// HealthCheckIntervalSeconds is DEPRECATED. Use the specific intervals below.

UpstreamCheckIntervalSeconds int `json:"upstream_check_interval_seconds" default:"300" name:"上游检查周期(秒)" category:"健康检查" desc:"对所有上游服务进行健康检查的周期。"`
ProxyCheckIntervalSeconds int `json:"proxy_check_interval_seconds" default:"600" name:"代理检查周期(秒)" category:"健康检查" desc:"对所有代理服务进行健康检查的周期。"`

EnableBaseKeyCheck bool `json:"enable_base_key_check" default:"true" name:"启用全局基础Key检查" category:"健康检查" desc:"是否启用全局的、长周期的Key身份状态检查。"`
KeyCheckTimeoutSeconds int `json:"key_check_timeout_seconds" default:"20" name:"Key检查超时(秒)" category:"健康检查" desc:"对单个API Key进行有效性验证时的网络超时时间(全局与分组检查共用)。"`
BaseKeyCheckIntervalMinutes int `json:"base_key_check_interval_minutes" default:"1440" name:"全局Key检查周期(分钟)" category:"健康检查" desc:"对所有ACTIVE状态的Key进行身份检查的周期,建议设置较长时间,例如1天(1440分钟)。"`
BaseKeyCheckConcurrency int `json:"base_key_check_concurrency" default:"5" name:"全局Key检查并发数" category:"健康检查" desc:"执行全局Key身份检查时的并发请求数量。"`
BaseKeyCheckEndpoint string `json:"base_key_check_endpoint" default:"https://generativelanguage.googleapis.com/v1beta/models" name:"全局Key检查端点" category:"健康检查" desc:"用于全局Key身份检查的目标URL。"`
BaseKeyCheckModel string `json:"base_key_check_model" default:"gemini-1.5-flash" name:"默认Key检查模型" category:"健康检查" desc:"用于分组健康检查和手动密钥测试时的默认回退模型。"`

EnableUpstreamCheck bool `json:"enable_upstream_check" default:"true" name:"启用上游检查" category:"健康检查" desc:"是否启用对上游服务(Upstream)的健康检查。"`
UpstreamCheckTimeoutSeconds int `json:"upstream_check_timeout_seconds" default:"20" name:"上游检查超时(秒)" category:"健康检查" desc:"对单个上游服务进行健康检查时的网络超时时间。"`

EnableProxyCheck bool `json:"enable_proxy_check" default:"true" name:"启用代理检查" category:"健康检查" desc:"是否启用对代理(Proxy)的健康检查。"`
ProxyCheckTimeoutSeconds int `json:"proxy_check_timeout_seconds" default:"20" name:"代理检查超时(秒)" category:"健康检查" desc:"通过代理进行连通性测试时的网络超时时间。"`
ProxyCheckConcurrency int `json:"proxy_check_concurrency" default:"5" name:"代理测试并发数" category:"健康检查" desc:"后台手动批量测试代理时的默认并发请求数量。"`
UseProxyHash bool `json:"use_proxy_hash" default:"false" name:"是否开启固定代理策略" category:"API配置" desc:"开启后,对于每一个API_KEY将根据算法从代理列表中选取同一个代理IP,防止一个API_KEY同时被多个IP访问,也同时防止了一个IP访问了过多的API_KEY。"`

AnalyticsFlushIntervalSeconds int `json:"analytics_flush_interval_seconds" default:"60" name:"分析数据落盘间隔(秒)" category:"高级设置" desc:"内存中的统计数据多久写入数据库一次。"`

// Security settings
EnableIPBanning bool `json:"enable_ip_banning" default:"false" name:"启用IP封禁功能" category:"安全设置" desc:"当一个IP连续多次登录失败后,是否自动将其封禁一段时间。"`
MaxLoginAttempts int `json:"max_login_attempts" default:"5" name:"最大登录失败次数" category:"安全设置" desc:"在一个IP被封禁前,允许的连续登录失败次数。"`
IPBanDurationMinutes int `json:"ip_ban_duration_minutes" default:"15" name:"IP封禁时长(分钟)" category:"安全设置" desc:"IP被封禁的时长,单位为分钟。"`

// Smart gateway
LogTruncationLimit int `json:"log_truncation_limit" default:"8000" name:"日志截断长度" category:"日志设置" desc:"在日志中记录上游响应或错误时,保留的最大字符数。0表示不截断。"`
EnableSmartGateway bool `json:"enable_smart_gateway" default:"false" name:"启用智能网关" category:"代理设置" desc:"开启后,系统将对流式请求进行智能中断续传、错误标准化等优化。关闭后,系统将作为一个纯净、无干扰的透明代理。"`
EnableStreamingRetry bool `json:"enable_streaming_retry" default:"true" name:"启用流式重试" category:"代理设置" desc:"当智能网关开启时,是否对流式请求进行智能中断续传。"`
MaxStreamingRetries int `json:"max_streaming_retries" default:"2" name:"最大流式重试次数" category:"代理设置" desc:"对单个流式会话,允许的最大连续重试次数。"`
StreamingRetryDelayMs int `json:"streaming_retry_delay_ms" default:"750" name:"流式重试延迟(毫秒)" category:"代理设置" desc:"流式会话重试之间的等待时间,单位为毫秒。"`

// Low-level HTTP transport settings for the smart gateway
TransportMaxIdleConns int `json:"transport_max_idle_conns" default:"200" name:"最大空闲连接数(总)" category:"高级设置" desc:"HTTP客户端Transport的最大总空闲连接数。"`
TransportMaxIdleConnsPerHost int `json:"transport_max_idle_conns_per_host" default:"100" name:"最大空闲连接数(单主机)" category:"高级设置" desc:"HTTP客户端Transport对单个主机的最大空闲连接数。"`
TransportIdleConnTimeoutSecs int `json:"transport_idle_conn_timeout_secs" default:"90" name:"空闲连接超时(秒)" category:"高级设置" desc:"HTTP客户端Transport中空闲连接被关闭前的等待时间。"`
TransportTLSHandshakeTimeout int `json:"transport_tls_handshake_timeout" default:"10" name:"TLS握手超时(秒)" category:"高级设置" desc:"TLS握手的超时时间。"`

// Custom prompt for smart stream resumption
StreamingRetryPrompt string `json:"streaming_retry_prompt" default:"Continue exactly where you left off, providing the final answer without repeating the previous thinking steps." name:"智能续传提示词" category:"代理设置" desc:"在进行智能中断续传时,向模型发送的指令。"`

// Logging service settings
LogLevel string `json:"log_level" default:"INFO" name:"日志级别" category:"日志配置"`
AutoDeleteErrorLogsEnabled bool `json:"auto_delete_error_logs_enabled" default:"false" name:"自动删除错误日志" category:"日志配置"`
AutoDeleteRequestLogsEnabled bool `json:"auto_delete_request_logs_enabled" default:"false" name:"自动删除请求日志" category:"日志配置"`
LogBufferCapacity int `json:"log_buffer_capacity" default:"1000" name:"日志缓冲区容量" category:"日志设置" desc:"内存中日志缓冲区的最大容量,超过则可能丢弃日志。"`
LogFlushBatchSize int `json:"log_flush_batch_size" default:"100" name:"日志刷新批次大小" category:"日志设置" desc:"每次向数据库批量写入日志的最大数量。"`

// --- API settings ---
CustomHeaders map[string]string `json:"custom_headers" name:"自定义Headers" category:"API配置" ` // defaults to nil

// --- TTS settings (reserved module) ---
TTSModel string `json:"tts_model" name:"TTS模型" category:"TTS配置"`
TTSVoiceName string `json:"tts_voice_name" name:"TTS语音名称" category:"TTS配置"`
TTSSpeed string `json:"tts_speed" name:"TTS语速" category:"TTS配置"`

// --- Image-generation settings (reserved module) ---
PaidKey string `json:"paid_key" name:"付费API密钥" category:"图像生成"`
CreateImageModel string `json:"create_image_model" name:"图像生成模型" category:"图像生成"`
UploadProvider string `json:"upload_provider" name:"上传提供商" category:"图像生成"`
SmmsSecretToken string `json:"smms_secret_token" name:"SM.MS密钥" category:"图像生成"`
PicgoAPIKey string `json:"picgo_api_key" name:"PicGo API密钥" category:"图像生成"`
CloudflareImgbedURL string `json:"cloudflare_imgbed_url" name:"Cloudflare图床URL" category:"图像生成"`
CloudflareImgbedAuthCode string `json:"cloudflare_imgbed_auth_code" name:"Cloudflare认证码" category:"图像生成"`
CloudflareImgbedUploadFolder string `json:"cloudflare_imgbed_upload_folder" name:"Cloudflare上传文件夹" category:"图像生成"`
// --- Streaming-output settings (reserved module) ---
EnableStreamOptimizer bool `json:"enable_stream_optimizer" default:"false" name:"启用流式输出优化" category:"流式输出"`
StreamMinDelay int `json:"stream_min_delay" default:"16" name:"最小延迟(秒)" category:"流式输出"`
StreamMaxDelay int `json:"stream_max_delay" default:"24" name:"最大延迟(秒)" category:"流式输出"`
StreamShortTextThresh int `json:"stream_short_text_thresh" default:"10" name:"短文本阈值" category:"流式输出"`
StreamLongTextThresh int `json:"stream_long_text_thresh" default:"50" name:"长文本阈值" category:"流式输出"`
StreamChunkSize int `json:"stream_chunk_size" default:"5" name:"分块大小" category:"流式输出"`
EnableFakeStream bool `json:"enable_fake_stream" default:"false" name:"启用假流式输出" category:"流式输出"`
FakeStreamInterval int `json:"fake_stream_interval" default:"5" name:"假流式空数据发送间隔(秒)" category:"流式输出"`

// --- Scheduled-task settings ---
Timezone string `json:"timezone" default:"Asia/Shanghai" name:"时区" category:"定时任务"`

// --- [Short-term freeze] "phantom" fields kept only for UI compatibility ---
AllowedTokens []string `json:"-"` // excluded from JSON serialization
Proxies []string `json:"-"` // excluded from JSON serialization

ModelSettings ModelSettings `json:"model_settings" name:"模型配置"`
}

// ModelSettings
type ModelSettings struct {
ImageModels []string `json:"image_models" name:"图像模型列表"`
SearchModels []string `json:"search_models" name:"搜索模型列表"`
FilteredModels []string `json:"filtered_models" name:"过滤模型列表"`
EnableCodeExecutor bool `json:"enable_code_executor" default:"false" name:"启用代码执行工具"`
EnableURLContext bool `json:"enable_url_context" default:"false" name:"启用网址上下文"`
URLContextModels []string `json:"url_context_models" name:"网址上下文模型列表"`
ShowSearchLink bool `json:"show_search_link" default:"false" name:"显示搜索链接"`
ShowThinking bool `json:"show_thinking" default:"false" name:"显示思考过程"`
ThinkingModels []string `json:"thinking_models" name:"思考模型列表"`
ThinkingBudgetMap map[string]int `json:"thinking_budget_map" name:"思考模型预算映射"`
SafetySettings []SafetySetting `json:"safety_settings" name:"安全设置"`
}
81
internal/pkg/reflectutil/structs.go
Normal file
@@ -0,0 +1,81 @@
// Filename: internal/pkg/reflectutil/structs.go
package reflectutil

import (
"fmt"
"reflect"
"strconv"
)

// SetFieldFromString sets a struct field's value from a string.
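// Only int/int64, string, and bool fields are supported; any other kind returns an error.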
func SetFieldFromString(field reflect.Value, value string) error {
if !field.CanSet() {
return fmt.Errorf("cannot set field")
}
switch field.Kind() {
case reflect.Int, reflect.Int64:
intVal, err := strconv.ParseInt(value, 10, 64)
if err != nil {
return err
}
field.SetInt(intVal)
case reflect.String:
field.SetString(value)
case reflect.Bool:
boolVal, err := strconv.ParseBool(value)
if err != nil {
return err
}
field.SetBool(boolVal)
default:
return fmt.Errorf("unsupported field type: %s", field.Type())
}
return nil
}

// FieldTypeToString converts a reflect.Type to a simple string representation.
func FieldTypeToString(t reflect.Type) string {
switch t.Kind() {
case reflect.Int, reflect.Int64:
return "int"
case reflect.String:
return "string"
case reflect.Bool:
return "bool"
default:
return "unknown"
}
}

// MergeNilFields uses reflection to merge non-nil pointer fields from 'override' into 'base'.
// Both base and override must be pointers to structs of the same type.
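// Example (KeyGroupSettings has pointer fields; ptr is a hypothetical helper returning &v):
//
//	base := &models.KeyGroupSettings{MaxRetries: ptr(3)}
//	override := &models.KeyGroupSettings{KeyCooldownMinutes: ptr(30)}
//	_ = MergeNilFields(base, override) // base now has MaxRetries=3 and KeyCooldownMinutes=30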
func MergeNilFields(base, override interface{}) error {
baseVal := reflect.ValueOf(base)
overrideVal := reflect.ValueOf(override)

if baseVal.Kind() != reflect.Ptr || overrideVal.Kind() != reflect.Ptr {
return fmt.Errorf("base and override must be pointers")
}

baseElem := baseVal.Elem()
overrideElem := overrideVal.Elem()

if baseElem.Kind() != reflect.Struct || overrideElem.Kind() != reflect.Struct {
return fmt.Errorf("base and override must be pointers to structs")
}

if baseElem.Type() != overrideElem.Type() {
return fmt.Errorf("base and override must be of the same struct type")
}

for i := 0; i < overrideElem.NumField(); i++ {
overrideField := overrideElem.Field(i)
if overrideField.Kind() == reflect.Ptr && !overrideField.IsNil() {
baseField := baseElem.Field(i)
if baseField.CanSet() {
baseField.Set(overrideField)
}
}
}
return nil
}
57
internal/pkg/stringutil/string.go
Normal file
@@ -0,0 +1,57 @@
// Filename: internal/pkg/stringutil/string.go
package stringutil

import (
"fmt"
"strings"
)

// MaskAPIKey masks an API key for safe logging.
func MaskAPIKey(key string) string {
length := len(key)
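// NOTE: keys of 8 characters or fewer are returned unmasked; callers should treat such values as non-secrets.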
if length <= 8 {
return key
}
return fmt.Sprintf("%s****%s", key[:4], key[length-4:])
}

// TruncateString shortens a string to a maximum length.
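// NOTE: truncation is byte-based, so a multi-byte UTF-8 rune may be split at the cut point.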
func TruncateString(s string, maxLength int) string {
if len(s) > maxLength {
return s[:maxLength]
}
return s
}

// SplitAndTrim splits a string by a separator, trims whitespace from each part, and drops empty parts.
func SplitAndTrim(s string, sep string) []string {
if s == "" {
return []string{}
}

parts := strings.Split(s, sep)
result := make([]string, 0, len(parts))

for _, part := range parts {
trimmed := strings.TrimSpace(part)
if trimmed != "" {
result = append(result, trimmed)
}
}

return result
}

// StringToSet converts a separator-delimited string into a set.
func StringToSet(s string, sep string) map[string]struct{} {
parts := SplitAndTrim(s, sep)
if len(parts) == 0 {
return nil
}

set := make(map[string]struct{}, len(parts))
for _, part := range parts {
set[part] = struct{}{}
}
return set
}
101
internal/pongo/renderer.go
Normal file
@@ -0,0 +1,101 @@
// Filename: internal/pongo/renderer.go

package pongo

import (
"fmt"
"net/http"

"github.com/flosch/pongo2/v6"
"github.com/gin-gonic/gin"
"github.com/gin-gonic/gin/render"
)

type Renderer struct {
Context pongo2.Context
tplSet *pongo2.TemplateSet
}

func New(directory string, isDebug bool) *Renderer {
loader := pongo2.MustNewLocalFileSystemLoader(directory)
tplSet := pongo2.NewSet("gin-pongo-templates", loader)
tplSet.Debug = isDebug
return &Renderer{Context: make(pongo2.Context), tplSet: tplSet}
}

// Instance returns a new render.HTML instance for a single request.
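// Request-scoped data wins on key collisions; values from the shared Context only fill missing keys.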
func (p *Renderer) Instance(name string, data interface{}) render.Render {
var glob pongo2.Context
if p.Context != nil {
glob = p.Context
}

var context pongo2.Context
if data != nil {
if ginContext, ok := data.(gin.H); ok {
context = pongo2.Context(ginContext)
} else if pongoContext, ok := data.(pongo2.Context); ok {
context = pongoContext
} else if m, ok := data.(map[string]interface{}); ok {
context = m
} else {
context = make(pongo2.Context)
}
} else {
context = make(pongo2.Context)
}

for k, v := range glob {
if _, ok := context[k]; !ok {
context[k] = v
}
}

tpl, err := p.tplSet.FromCache(name)
if err != nil {
panic(fmt.Sprintf("Failed to load template '%s': %v", name, err))
}

return &HTML{
p: p,
Template: tpl,
Name: name,
Data: context,
}
}

type HTML struct {
p *Renderer
Template *pongo2.Template
Name string
Data pongo2.Context
}

func (h *HTML) Render(w http.ResponseWriter) error {
h.WriteContentType(w)
bytes, err := h.Template.ExecuteBytes(h.Data)
if err != nil {
return err
}
_, err = w.Write(bytes)
return err
}

func (h *HTML) WriteContentType(w http.ResponseWriter) {
header := w.Header()
if val := header["Content-Type"]; len(val) == 0 {
header["Content-Type"] = []string{"text/html; charset=utf-8"}
}
}

func C(ctx *gin.Context) pongo2.Context {
p, exists := ctx.Get("pongo2")
if exists {
if pCtx, ok := p.(pongo2.Context); ok {
return pCtx
}
}
pCtx := make(pongo2.Context)
ctx.Set("pongo2", pCtx)
return pCtx
}
206
internal/repository/auth_token.go
Normal file
@@ -0,0 +1,206 @@
// Filename: internal/repository/auth_token.go

package repository

import (
"crypto/sha256"
"encoding/hex"
"fmt"
"gemini-balancer/internal/crypto"
"gemini-balancer/internal/models"

"github.com/sirupsen/logrus"
"gorm.io/gorm"
)

// AuthTokenRepository defines the interface for AuthToken data access.
type AuthTokenRepository interface {
GetAllTokensWithGroups() ([]*models.AuthToken, error)
BatchUpdateTokens(updates []*models.TokenUpdateRequest) error
GetTokenByHashedValue(tokenHash string) (*models.AuthToken, error)
SeedAdminToken(encryptedToken, tokenHash string) error // used by the seeder
}

type gormAuthTokenRepository struct {
db *gorm.DB
crypto *crypto.Service
logger *logrus.Entry
}

func NewAuthTokenRepository(db *gorm.DB, crypto *crypto.Service, logger *logrus.Logger) AuthTokenRepository {
return &gormAuthTokenRepository{
db: db,
crypto: crypto,
logger: logger.WithField("component", "repository.authToken🔐"),
}
}

// GetAllTokensWithGroups fetches all tokens and decrypts them for use in services.
func (r *gormAuthTokenRepository) GetAllTokensWithGroups() ([]*models.AuthToken, error) {
var tokens []*models.AuthToken
if err := r.db.Preload("AllowedGroups").Find(&tokens).Error; err != nil {
return nil, err
}
// [CRITICAL] Decrypt all tokens before returning them.
if err := r.decryptTokens(tokens); err != nil {
// Log the error but return the partially decrypted data, as some might be usable.
r.logger.WithError(err).Error("Batch decryption failed for some auth tokens.")
}
return tokens, nil
}

// BatchUpdateTokens provides a transactional way to update all tokens, handling encryption.
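// The transaction runs in four phases: split admin vs. user entries, rewrite the admin
// token in place, upsert the submitted user tokens, then delete user tokens absent from
// the request.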
func (r *gormAuthTokenRepository) BatchUpdateTokens(updates []*models.TokenUpdateRequest) error {
return r.db.Transaction(func(tx *gorm.DB) error {
// 1. Separate admin and user tokens from the request
var adminUpdate *models.TokenUpdateRequest
var userUpdates []*models.TokenUpdateRequest
for _, u := range updates {
if u.IsAdmin {
adminUpdate = u
} else {
userUpdates = append(userUpdates, u)
}
}

// 2. Handle Admin Token Update
if adminUpdate != nil && adminUpdate.Token != "" {
encryptedToken, err := r.crypto.Encrypt(adminUpdate.Token)
if err != nil {
return fmt.Errorf("failed to encrypt admin token: %w", err)
}
hash := sha256.Sum256([]byte(adminUpdate.Token))
tokenHash := hex.EncodeToString(hash[:])

// Update both encrypted value and the hash
updateData := map[string]interface{}{
"encrypted_token": encryptedToken,
"token_hash": tokenHash,
}
if err := tx.Model(&models.AuthToken{}).Where("is_admin = ?", true).Updates(updateData).Error; err != nil {
return fmt.Errorf("failed to update admin token in db: %w", err)
}
}

// 3. Handle User Tokens Upsert
var existingTokens []*models.AuthToken
if err := tx.Where("is_admin = ?", false).Find(&existingTokens).Error; err != nil {
return fmt.Errorf("failed to fetch existing user tokens: %w", err)
}
existingTokenMap := make(map[uint]bool)
for _, t := range existingTokens {
existingTokenMap[t.ID] = true
}

var tokensToUpsert []models.AuthToken
for _, req := range userUpdates {
if req.Token == "" {
continue // Skip tokens with empty values
}
encryptedToken, err := r.crypto.Encrypt(req.Token)
if err != nil {
return fmt.Errorf("failed to encrypt token for upsert (ID: %d): %w", req.ID, err)
}
hash := sha256.Sum256([]byte(req.Token))
tokenHash := hex.EncodeToString(hash[:])

var groups []*models.KeyGroup
if len(req.AllowedGroupIDs) > 0 {
if err := tx.Find(&groups, req.AllowedGroupIDs).Error; err != nil {
return fmt.Errorf("failed to find key groups for token %d: %w", req.ID, err)
}
}
tokensToUpsert = append(tokensToUpsert, models.AuthToken{
ID: req.ID,
EncryptedToken: encryptedToken,
TokenHash: tokenHash,
Description: req.Description,
Tag: req.Tag,
Status: req.Status,
IsAdmin: false,
AllowedGroups: groups,
})
}
if len(tokensToUpsert) > 0 {
if err := tx.Save(&tokensToUpsert).Error; err != nil {
return fmt.Errorf("failed to upsert user tokens: %w", err)
}
}

// 4. Handle Deletions
incomingUserTokenIDs := make(map[uint]bool)
for _, u := range userUpdates {
if u.ID != 0 {
incomingUserTokenIDs[u.ID] = true
}
}
var idsToDelete []uint
for id := range existingTokenMap {
if !incomingUserTokenIDs[id] {
idsToDelete = append(idsToDelete, id)
}
}
if len(idsToDelete) > 0 {
if err := tx.Model(&models.AuthToken{}).Where("id IN ?", idsToDelete).Association("AllowedGroups").Clear(); err != nil {
return fmt.Errorf("failed to clear associations for tokens to be deleted: %w", err)
}
if err := tx.Where("id IN ?", idsToDelete).Delete(&models.AuthToken{}).Error; err != nil {
return fmt.Errorf("failed to delete user tokens: %w", err)
}
}
return nil
})
}

// --- Crypto Helper Functions ---

func (r *gormAuthTokenRepository) decryptToken(token *models.AuthToken) error {
if token == nil || token.EncryptedToken == "" || token.Token != "" {
return nil // Nothing to decrypt or already done
}
plaintext, err := r.crypto.Decrypt(token.EncryptedToken)
if err != nil {
return fmt.Errorf("failed to decrypt auth token ID %d: %w", token.ID, err)
}
token.Token = plaintext
return nil
}

func (r *gormAuthTokenRepository) decryptTokens(tokens []*models.AuthToken) error {
for i := range tokens {
if err := r.decryptToken(tokens[i]); err != nil {
r.logger.Error(err) // Log error but continue for other tokens
}
}
return nil
}

// GetTokenByHashedValue finds a token by its SHA256 hash for authentication.
func (r *gormAuthTokenRepository) GetTokenByHashedValue(tokenHash string) (*models.AuthToken, error) {
var authToken models.AuthToken
// Find the active token by its hash. This is the core of our secure authentication.
err := r.db.Where("token_hash = ? AND status = 'active'", tokenHash).Preload("AllowedGroups").First(&authToken).Error
if err != nil {
return nil, err
}
// [CRITICAL] Decrypt the token before returning it to the service layer.
// This ensures that subsequent logic (like in ResourceService) gets the full, usable object.
if err := r.decryptToken(&authToken); err != nil {
return nil, err
}
return &authToken, nil
}

// SeedAdminToken is a special-purpose function for the seeder to insert the initial admin token.
func (r *gormAuthTokenRepository) SeedAdminToken(encryptedToken, tokenHash string) error {
adminToken := models.AuthToken{
EncryptedToken: encryptedToken,
TokenHash: tokenHash,
Description: "Default Administrator Token",
Tag: "SYSTEM_ADMIN",
IsAdmin: true,
Status: "active", // Ensure the seeded token is active
}
// Using FirstOrCreate to be idempotent. If an admin token already exists, it does nothing.
return r.db.Where(models.AuthToken{IsAdmin: true}).FirstOrCreate(&adminToken).Error
}
37
internal/repository/group_repository.go
Normal file
@@ -0,0 +1,37 @@
// Filename: internal/repository/group_repository.go
package repository

import (
"gemini-balancer/internal/models"

"gorm.io/gorm"
)

func (r *gormGroupRepository) GetGroupByName(name string) (*models.KeyGroup, error) {
var group models.KeyGroup
if err := r.db.Where("name = ?", name).First(&group).Error; err != nil {
return nil, err
}
return &group, nil
}

func (r *gormGroupRepository) GetAllGroups() ([]*models.KeyGroup, error) {
var groups []*models.KeyGroup
if err := r.db.Order("\"order\" asc, id desc").Find(&groups).Error; err != nil {
return nil, err
}
return groups, nil
}

// UpdateOrderInTransaction updates the UI sort order of groups inside a single transaction.
func (r *gormGroupRepository) UpdateOrderInTransaction(orders map[uint]int) error {
return r.db.Transaction(func(tx *gorm.DB) error {
for id, order := range orders {
result := tx.Model(&models.KeyGroup{}).Where("id = ?", id).Update("order", order)
if result.Error != nil {
return result.Error
}
}
return nil
})
}
204
internal/repository/key_cache.go
Normal file
@@ -0,0 +1,204 @@
// Filename: internal/repository/key_cache.go
package repository

import (
"encoding/json"
"fmt"
"gemini-balancer/internal/models"
"strconv"
)

const (
KeyGroup = "group:%d:keys:active"
KeyDetails = "key:%d:details"
KeyMapping = "mapping:%d:%d"
KeyGroupSequential = "group:%d:keys:sequential"
KeyGroupLRU = "group:%d:keys:lru"
KeyGroupRandomMain = "group:%d:keys:random:main"
KeyGroupRandomCooldown = "group:%d:keys:random:cooldown"
BasePoolSequential = "basepool:%s:keys:sequential"
BasePoolLRU = "basepool:%s:keys:lru"
BasePoolRandomMain = "basepool:%s:keys:random:main"
BasePoolRandomCooldown = "basepool:%s:keys:random:cooldown"
)
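// The group:* templates take a group ID; basepool:* take a pool name; KeyDetails takes
// a key ID and KeyMapping takes a (group ID, key ID) pair, as used with fmt.Sprintf below.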

func (r *gormKeyRepository) LoadAllKeysToStore() error {
r.logger.Info("Starting to load all keys and associations into cache, including polling structures...")
var allMappings []*models.GroupAPIKeyMapping
if err := r.db.Preload("APIKey").Find(&allMappings).Error; err != nil {
return fmt.Errorf("failed to load all mappings with APIKeys from DB: %w", err)
}

keyMap := make(map[uint]*models.APIKey)
for _, m := range allMappings {
if m.APIKey != nil {
keyMap[m.APIKey.ID] = m.APIKey
}
}
keysToDecrypt := make([]models.APIKey, 0, len(keyMap))
for _, k := range keyMap {
keysToDecrypt = append(keysToDecrypt, *k)
}
if err := r.decryptKeys(keysToDecrypt); err != nil {
r.logger.WithError(err).Error("Critical error during cache preload: batch decryption failed.")
}
decryptedKeyMap := make(map[uint]models.APIKey)
for _, k := range keysToDecrypt {
decryptedKeyMap[k.ID] = k
}

activeKeysByGroup := make(map[uint][]*models.GroupAPIKeyMapping)
pipe := r.store.Pipeline()
detailsToSet := make(map[string][]byte)
var allGroups []*models.KeyGroup
if err := r.db.Find(&allGroups).Error; err == nil {
for _, group := range allGroups {
pipe.Del(
fmt.Sprintf(KeyGroup, group.ID),
fmt.Sprintf(KeyGroupSequential, group.ID),
fmt.Sprintf(KeyGroupLRU, group.ID),
fmt.Sprintf(KeyGroupRandomMain, group.ID),
fmt.Sprintf(KeyGroupRandomCooldown, group.ID),
)
}
} else {
r.logger.WithError(err).Error("Failed to get all groups for cache cleanup")
}

for _, mapping := range allMappings {
if mapping.APIKey == nil {
continue
}
decryptedKey, ok := decryptedKeyMap[mapping.APIKeyID]
if !ok {
continue
}
keyJSON, _ := json.Marshal(decryptedKey)
detailsToSet[fmt.Sprintf(KeyDetails, decryptedKey.ID)] = keyJSON
mappingJSON, _ := json.Marshal(mapping)
detailsToSet[fmt.Sprintf(KeyMapping, mapping.KeyGroupID, decryptedKey.ID)] = mappingJSON
if mapping.Status == models.StatusActive {
activeKeysByGroup[mapping.KeyGroupID] = append(activeKeysByGroup[mapping.KeyGroupID], mapping)
}
}

for groupID, activeMappings := range activeKeysByGroup {
if len(activeMappings) == 0 {
continue
}
var activeKeyIDs []interface{}
lruMembers := make(map[string]float64)
for _, mapping := range activeMappings {
keyIDStr := strconv.FormatUint(uint64(mapping.APIKeyID), 10)
activeKeyIDs = append(activeKeyIDs, keyIDStr)
var score float64
if mapping.LastUsedAt != nil {
score = float64(mapping.LastUsedAt.UnixMilli())
}
lruMembers[keyIDStr] = score
}
pipe.SAdd(fmt.Sprintf(KeyGroup, groupID), activeKeyIDs...)
pipe.LPush(fmt.Sprintf(KeyGroupSequential, groupID), activeKeyIDs...)
pipe.SAdd(fmt.Sprintf(KeyGroupRandomMain, groupID), activeKeyIDs...)
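// NOTE: the LRU ZAdd below runs in a fire-and-forget goroutine outside the pipeline; its errors are silently dropped.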
go r.store.ZAdd(fmt.Sprintf(KeyGroupLRU, groupID), lruMembers)
}

if err := pipe.Exec(); err != nil {
return fmt.Errorf("failed to execute pipeline for cache rebuild: %w", err)
}
for key, value := range detailsToSet {
if err := r.store.Set(key, value, 0); err != nil {
r.logger.WithError(err).Warnf("Failed to set key detail in cache: %s", key)
}
}

r.logger.Info("Cache rebuild complete, including all polling structures.")
return nil
}

func (r *gormKeyRepository) updateStoreCacheForKey(key *models.APIKey) error {
if err := r.decryptKey(key); err != nil {
return fmt.Errorf("failed to decrypt key %d for cache update: %w", key.ID, err)
}
keyJSON, err := json.Marshal(key)
if err != nil {
return fmt.Errorf("failed to marshal key %d for cache update: %w", key.ID, err)
}
return r.store.Set(fmt.Sprintf(KeyDetails, key.ID), keyJSON, 0)
}

func (r *gormKeyRepository) removeStoreCacheForKey(key *models.APIKey) error {
groupIDs, err := r.GetGroupsForKey(key.ID)
if err != nil {
r.logger.Warnf("failed to get groups for key %d to clean up cache lists: %v", key.ID, err)
}

pipe := r.store.Pipeline()
pipe.Del(fmt.Sprintf(KeyDetails, key.ID))

for _, groupID := range groupIDs {
pipe.Del(fmt.Sprintf(KeyMapping, groupID, key.ID))

keyIDStr := strconv.FormatUint(uint64(key.ID), 10)
pipe.SRem(fmt.Sprintf(KeyGroup, groupID), keyIDStr)
pipe.LRem(fmt.Sprintf(KeyGroupSequential, groupID), 0, keyIDStr)
pipe.SRem(fmt.Sprintf(KeyGroupRandomMain, groupID), keyIDStr)
pipe.SRem(fmt.Sprintf(KeyGroupRandomCooldown, groupID), keyIDStr)
go r.store.ZRem(fmt.Sprintf(KeyGroupLRU, groupID), keyIDStr)
}
return pipe.Exec()
}

func (r *gormKeyRepository) updateStoreCacheForMapping(mapping *models.GroupAPIKeyMapping) error {
pipe := r.store.Pipeline()
activeKeyListKey := fmt.Sprintf("group:%d:keys:active", mapping.KeyGroupID)
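// NOTE: unlike the batch path below, this pushes the numeric APIKeyID rather than its
// decimal-string form; it relies on the store client serializing both identically.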
pipe.LRem(activeKeyListKey, 0, mapping.APIKeyID)
if mapping.Status == models.StatusActive {
pipe.LPush(activeKeyListKey, mapping.APIKeyID)
}
return pipe.Exec()
}

func (r *gormKeyRepository) HandleCacheUpdateEventBatch(mappings []*models.GroupAPIKeyMapping) error {
if len(mappings) == 0 {
return nil
}
groupUpdates := make(map[uint]struct {
ToAdd []interface{}
ToRemove []interface{}
})
for _, mapping := range mappings {
keyIDStr := strconv.FormatUint(uint64(mapping.APIKeyID), 10)
update, ok := groupUpdates[mapping.KeyGroupID]
if !ok {
update = struct {
ToAdd []interface{}
ToRemove []interface{}
}{}
}
if mapping.Status == models.StatusActive {
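// Remove-then-add de-duplicates the entry and promotes the key to the front of the active list.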
update.ToRemove = append(update.ToRemove, keyIDStr)
update.ToAdd = append(update.ToAdd, keyIDStr)
} else {
update.ToRemove = append(update.ToRemove, keyIDStr)
}
groupUpdates[mapping.KeyGroupID] = update
}
pipe := r.store.Pipeline()
var pipelineError error
for groupID, updates := range groupUpdates {
activeKeyListKey := fmt.Sprintf("group:%d:keys:active", groupID)
if len(updates.ToRemove) > 0 {
for _, keyID := range updates.ToRemove {
pipe.LRem(activeKeyListKey, 0, keyID)
}
}
if len(updates.ToAdd) > 0 {
pipe.LPush(activeKeyListKey, updates.ToAdd...)
}
}
if err := pipe.Exec(); err != nil {
pipelineError = fmt.Errorf("redis pipeline execution failed: %w", err)
}
return pipelineError
}
280
internal/repository/key_crud.go
Normal file
@@ -0,0 +1,280 @@
// Filename: internal/repository/key_crud.go
package repository

import (
"crypto/sha256"
"encoding/hex"
"fmt"
"gemini-balancer/internal/models"

"math/rand"
"strings"
"time"

"gorm.io/gorm"
"gorm.io/gorm/clause"
)
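
// AddKeys deduplicates incoming keys by SHA-256 hash, restores soft-deleted matches,
// and inserts only genuinely new keys, all inside one transaction.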
func (r *gormKeyRepository) AddKeys(keys []models.APIKey) ([]models.APIKey, error) {
if len(keys) == 0 {
return []models.APIKey{}, nil
}
keyHashes := make([]string, len(keys))
keyValueToHashMap := make(map[string]string)
for i, k := range keys {
// All incoming keys must have plaintext APIKey
if k.APIKey == "" {
return nil, fmt.Errorf("cannot add key at index %d: plaintext APIKey is empty", i)
}
hash := sha256.Sum256([]byte(k.APIKey))
hashStr := hex.EncodeToString(hash[:])
keyHashes[i] = hashStr
keyValueToHashMap[k.APIKey] = hashStr
}
var finalKeys []models.APIKey
err := r.db.Transaction(func(tx *gorm.DB) error {
var existingKeys []models.APIKey
// [MODIFIED] Query by hash to find existing keys.
if err := tx.Unscoped().Where("api_key_hash IN ?", keyHashes).Find(&existingKeys).Error; err != nil {
return err
}
existingKeyHashMap := make(map[string]models.APIKey)
for _, k := range existingKeys {
existingKeyHashMap[k.APIKeyHash] = k
}
var keysToCreate []models.APIKey
var keysToRestore []uint
for _, keyObj := range keys {
keyVal := keyObj.APIKey
hash := keyValueToHashMap[keyVal]
if ek, found := existingKeyHashMap[hash]; found {
if ek.DeletedAt.Valid {
keysToRestore = append(keysToRestore, ek.ID)
}
} else {
encryptedKey, err := r.crypto.Encrypt(keyVal)
if err != nil {
return fmt.Errorf("failed to encrypt key '%s...': %w", keyVal[:min(4, len(keyVal))], err)
}
keysToCreate = append(keysToCreate, models.APIKey{
EncryptedKey: encryptedKey,
APIKeyHash: hash,
})
}
}
if len(keysToRestore) > 0 {
if err := tx.Model(&models.APIKey{}).Unscoped().Where("id IN ?", keysToRestore).Update("deleted_at", nil).Error; err != nil {
return err
}
}
if len(keysToCreate) > 0 {
// [MODIFIED] Create now only provides encrypted data and hash.
if err := tx.Clauses(clause.OnConflict{DoNothing: true}, clause.Returning{}).Create(&keysToCreate).Error; err != nil {
return err
}
}
// [MODIFIED] Final select uses hashes to retrieve all relevant keys.
if err := tx.Where("api_key_hash IN ?", keyHashes).Find(&finalKeys).Error; err != nil {
return err
}
// [CRITICAL] Decrypt all keys before returning them to the service layer.
return r.decryptKeys(finalKeys)
})
return finalKeys, err
}

func (r *gormKeyRepository) Update(key *models.APIKey) error {
// [CRITICAL] Before saving, check if the plaintext APIKey field was populated.
// This indicates a potential change that needs to be re-encrypted.
if key.APIKey != "" {
encryptedKey, err := r.crypto.Encrypt(key.APIKey)
if err != nil {
return fmt.Errorf("failed to re-encrypt key on update for ID %d: %w", key.ID, err)
}
key.EncryptedKey = encryptedKey
// Recalculate hash as a defensive measure.
hash := sha256.Sum256([]byte(key.APIKey))
key.APIKeyHash = hex.EncodeToString(hash[:])
}
err := r.executeTransactionWithRetry(func(tx *gorm.DB) error {
// GORM automatically ignores `key.APIKey` because of the `gorm:"-"` tag.
return tx.Save(key).Error
})
if err != nil {
return err
}
// For the cache update, we need the plaintext. Decrypt if it's not already populated.
if err := r.decryptKey(key); err != nil {
r.logger.Warnf("DB updated key ID %d, but decryption for cache failed: %v", key.ID, err)
return nil // Continue without cache update if decryption fails.
}
if err := r.updateStoreCacheForKey(key); err != nil {
r.logger.Warnf("DB updated key ID %d, but cache update failed: %v", key.ID, err)
}
return nil
}

func (r *gormKeyRepository) HardDeleteByID(id uint) error {
key, err := r.GetKeyByID(id) // This now returns a decrypted key
if err != nil {
return err
}
err = r.executeTransactionWithRetry(func(tx *gorm.DB) error {
return tx.Unscoped().Delete(&models.APIKey{}, id).Error
})
if err != nil {
return err
}
if err := r.removeStoreCacheForKey(key); err != nil {
r.logger.Warnf("DB deleted key ID %d, but cache removal failed: %v", id, err)
}
return nil
}

func (r *gormKeyRepository) HardDeleteByValues(keyValues []string) (int64, error) {
if len(keyValues) == 0 {
return 0, nil
}
hashes := make([]string, len(keyValues))
for i, v := range keyValues {
hash := sha256.Sum256([]byte(v))
hashes[i] = hex.EncodeToString(hash[:])
}
// Find the full key objects first to update the cache later.
var keysToDelete []models.APIKey
// [MODIFIED] Find by hash.
if err := r.db.Where("api_key_hash IN ?", hashes).Find(&keysToDelete).Error; err != nil {
return 0, err
}
if len(keysToDelete) == 0 {
return 0, nil
}
// Decrypt them to ensure cache has plaintext if needed.
if err := r.decryptKeys(keysToDelete); err != nil {
r.logger.Warnf("Decryption failed for keys before hard delete, cache removal may be impacted: %v", err)
}
var deletedCount int64
err := r.executeTransactionWithRetry(func(tx *gorm.DB) error {
ids := pluckIDs(keysToDelete)
result := tx.Unscoped().Where("id IN ?", ids).Delete(&models.APIKey{})
if result.Error != nil {
return result.Error
}
deletedCount = result.RowsAffected
return nil
})
if err != nil {
return 0, err
}
for i := range keysToDelete {
if err := r.removeStoreCacheForKey(&keysToDelete[i]); err != nil {
r.logger.Warnf("DB deleted key ID %d, but cache removal failed: %v", keysToDelete[i].ID, err)
}
}
return deletedCount, nil
}
|
||||
func (r *gormKeyRepository) GetKeyByID(id uint) (*models.APIKey, error) {
|
||||
var key models.APIKey
|
||||
if err := r.db.First(&key, id).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := r.decryptKey(&key); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &key, nil
|
||||
}
|
||||
|
||||
func (r *gormKeyRepository) GetKeysByIDs(ids []uint) ([]models.APIKey, error) {
|
||||
if len(ids) == 0 {
|
||||
return []models.APIKey{}, nil
|
||||
}
|
||||
var keys []models.APIKey
|
||||
err := r.db.Where("id IN ?", ids).Find(&keys).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// [CRITICAL] Decrypt before returning.
|
||||
return keys, r.decryptKeys(keys)
|
||||
}
|
||||
|
||||
func (r *gormKeyRepository) GetKeyByValue(keyValue string) (*models.APIKey, error) {
|
||||
hash := sha256.Sum256([]byte(keyValue))
|
||||
hashStr := hex.EncodeToString(hash[:])
|
||||
var key models.APIKey
|
||||
if err := r.db.Where("api_key_hash = ?", hashStr).First(&key).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
key.APIKey = keyValue
|
||||
return &key, nil
|
||||
}
|
||||
|
||||
func (r *gormKeyRepository) GetKeysByValues(keyValues []string) ([]models.APIKey, error) {
|
||||
if len(keyValues) == 0 {
|
||||
return []models.APIKey{}, nil
|
||||
}
|
||||
hashes := make([]string, len(keyValues))
|
||||
for i, v := range keyValues {
|
||||
hash := sha256.Sum256([]byte(v))
|
||||
hashes[i] = hex.EncodeToString(hash[:])
|
||||
}
|
||||
var keys []models.APIKey
|
||||
err := r.db.Where("api_key_hash IN ?", hashes).Find(&keys).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return keys, r.decryptKeys(keys)
|
||||
}
|
||||
|
||||
func (r *gormKeyRepository) GetKeysByGroup(groupID uint) ([]models.APIKey, error) {
|
||||
var keys []models.APIKey
|
||||
err := r.db.Joins("JOIN group_api_key_mappings on group_api_key_mappings.api_key_id = api_keys.id").
|
||||
Where("group_api_key_mappings.key_group_id = ?", groupID).
|
||||
Find(&keys).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return keys, r.decryptKeys(keys)
|
||||
}
|
||||
|
||||
func (r *gormKeyRepository) CountByGroup(groupID uint) (int64, error) {
|
||||
var count int64
|
||||
err := r.db.Model(&models.APIKey{}).
|
||||
Joins("JOIN group_api_key_mappings on group_api_key_mappings.api_key_id = api_keys.id").
|
||||
Where("group_api_key_mappings.key_group_id = ?", groupID).
|
||||
Count(&count).Error
|
||||
return count, err
|
||||
}
|
||||
|
||||
// --- Helpers ---
|
||||
|
||||
func (r *gormKeyRepository) executeTransactionWithRetry(operation func(tx *gorm.DB) error) error {
|
||||
const maxRetries = 3
|
||||
const baseDelay = 50 * time.Millisecond
|
||||
const maxJitter = 150 * time.Millisecond
|
||||
var err error
|
||||
|
||||
for i := 0; i < maxRetries; i++ {
|
||||
err = r.db.Transaction(operation)
|
||||
if err == nil {
|
||||
return nil
|
||||
}
|
||||
if strings.Contains(err.Error(), "database is locked") {
|
||||
jitter := time.Duration(rand.Intn(int(maxJitter)))
|
||||
totalDelay := baseDelay + jitter
|
||||
r.logger.Debugf("Database is locked, retrying in %v... (attempt %d/%d)", totalDelay, i+1, maxRetries)
|
||||
time.Sleep(totalDelay)
|
||||
continue
|
||||
}
|
||||
break
|
||||
}
|
||||
return err
|
||||
}
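
// Example usage (a sketch; someID is a placeholder): wrapping a single write
// in the retry helper. The "database is locked" match targets SQLite's busy
// error, so on other drivers a failed transaction simply returns after the
// first attempt.
//
//	err := r.executeTransactionWithRetry(func(tx *gorm.DB) error {
//	    return tx.Model(&models.APIKey{}).
//	        Where("id = ?", someID).
//	        Update("master_status", models.MasterStatusActive).Error
//	})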

func pluckIDs(keys []models.APIKey) []uint {
	ids := make([]uint, 0, len(keys))
	for _, key := range keys {
		ids = append(ids, key.ID)
	}
	return ids
}
62
internal/repository/key_crypto.go
Normal file
@@ -0,0 +1,62 @@
// Filename: internal/repository/key_crypto.go
package repository

import (
	"fmt"
	"gemini-balancer/internal/models"
)

func (r *gormKeyRepository) decryptKey(key *models.APIKey) error {
	if key == nil || key.EncryptedKey == "" {
		return nil // Nothing to decrypt
	}
	// Avoid re-decrypting if plaintext already exists
	if key.APIKey != "" {
		return nil
	}
	plaintext, err := r.crypto.Decrypt(key.EncryptedKey)
	if err != nil {
		return fmt.Errorf("failed to decrypt key ID %d: %w", key.ID, err)
	}
	key.APIKey = plaintext
	return nil
}

func (r *gormKeyRepository) decryptKeys(keys []models.APIKey) error {
	for i := range keys {
		if err := r.decryptKey(&keys[i]); err != nil {
			// In a batch operation, we log the error but allow the rest to proceed.
			r.logger.Errorf("Batch decrypt error for key index %d: %v", i, err)
		}
	}
	return nil
}

// Decrypt implements the KeyRepository interface.
func (r *gormKeyRepository) Decrypt(key *models.APIKey) error {
	if key == nil || len(key.EncryptedKey) == 0 {
		return nil // Nothing to decrypt
	}
	// Avoid re-decrypting if plaintext already exists
	if key.APIKey != "" {
		return nil
	}
	plaintext, err := r.crypto.Decrypt(key.EncryptedKey)
	if err != nil {
		return fmt.Errorf("failed to decrypt key ID %d: %w", key.ID, err)
	}
	key.APIKey = plaintext
	return nil
}

// DecryptBatch implements the KeyRepository interface.
func (r *gormKeyRepository) DecryptBatch(keys []models.APIKey) error {
	for i := range keys {
		// This delegates to the robust single-key decryption logic.
		if err := r.Decrypt(&keys[i]); err != nil {
			// In a batch operation, we log the error but allow the rest to proceed.
			r.logger.Errorf("Batch decrypt error for key index %d (ID: %d): %v", i, keys[i].ID, err)
		}
	}
	return nil
}
169
internal/repository/key_maintenance.go
Normal file
@@ -0,0 +1,169 @@
// Filename: internal/repository/key_maintenance.go
package repository

import (
	"crypto/sha256"
	"encoding/hex"
	"gemini-balancer/internal/models"
	"io"
	"time"

	"gorm.io/gorm"
)

func (r *gormKeyRepository) StreamKeysToWriter(groupID uint, statusFilter string, writer io.Writer) error {
	query := r.db.Model(&models.APIKey{}).
		Joins("JOIN group_api_key_mappings on group_api_key_mappings.api_key_id = api_keys.id").
		Where("group_api_key_mappings.key_group_id = ?", groupID)

	if statusFilter != "" && statusFilter != "all" {
		query = query.Where("group_api_key_mappings.status = ?", statusFilter)
	}
	var batchKeys []models.APIKey
	return query.FindInBatches(&batchKeys, 1000, func(tx *gorm.DB, batch int) error {
		if err := r.decryptKeys(batchKeys); err != nil {
			r.logger.Errorf("Failed to decrypt batch %d for streaming: %v", batch, err)
		}
		for _, key := range batchKeys {
			if key.APIKey != "" {
				if _, err := writer.Write([]byte(key.APIKey + "\n")); err != nil {
					return err
				}
			}
		}
		return nil
	}).Error
}
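
// Example usage (a sketch; the handler and repo wiring are hypothetical):
// StreamKeysToWriter lets an HTTP handler stream a group's keys as a
// plaintext download without loading them all into memory, since gin's
// c.Writer satisfies io.Writer.
//
//	func exportKeys(c *gin.Context, repo KeyRepository) {
//	    c.Header("Content-Type", "text/plain; charset=utf-8")
//	    if err := repo.StreamKeysToWriter(1, "active", c.Writer); err != nil {
//	        c.Status(http.StatusInternalServerError)
//	    }
//	}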

func (r *gormKeyRepository) UpdateMasterStatusByValues(keyValues []string, newStatus models.MasterAPIKeyStatus) (int64, error) {
	if len(keyValues) == 0 {
		return 0, nil
	}
	hashes := make([]string, len(keyValues))
	for i, v := range keyValues {
		hash := sha256.Sum256([]byte(v))
		hashes[i] = hex.EncodeToString(hash[:])
	}
	var result *gorm.DB
	err := r.executeTransactionWithRetry(func(tx *gorm.DB) error {
		result = tx.Model(&models.APIKey{}).
			Where("api_key_hash IN ?", hashes).
			Update("master_status", newStatus)
		return result.Error
	})
	if err != nil {
		return 0, err
	}
	return result.RowsAffected, nil
}

func (r *gormKeyRepository) UpdateMasterStatusByID(keyID uint, newStatus models.MasterAPIKeyStatus) error {
	return r.executeTransactionWithRetry(func(tx *gorm.DB) error {
		result := tx.Model(&models.APIKey{}).
			Where("id = ?", keyID).
			Update("master_status", newStatus)
		if result.Error != nil {
			return result.Error
		}
		if result.RowsAffected == 0 {
			// This ensures that if the key ID doesn't exist, we return a standard "not found" error.
			return gorm.ErrRecordNotFound
		}
		return nil
	})
}

func (r *gormKeyRepository) DeleteOrphanKeys() (int64, error) {
	var deletedCount int64
	err := r.executeTransactionWithRetry(func(tx *gorm.DB) error {
		count, err := r.deleteOrphanKeysLogic(tx)
		if err != nil {
			return err
		}
		deletedCount = count
		return nil
	})
	return deletedCount, err
}

func (r *gormKeyRepository) DeleteOrphanKeysTx(tx *gorm.DB) (int64, error) {
	return r.deleteOrphanKeysLogic(tx)
}

func (r *gormKeyRepository) deleteOrphanKeysLogic(db *gorm.DB) (int64, error) {
	var orphanKeyIDs []uint
	err := db.Raw(`
		SELECT api_keys.id FROM api_keys
		LEFT JOIN group_api_key_mappings ON api_keys.id = group_api_key_mappings.api_key_id
		WHERE group_api_key_mappings.api_key_id IS NULL`).Scan(&orphanKeyIDs).Error
	if err != nil {
		return 0, err
	}

	if len(orphanKeyIDs) == 0 {
		return 0, nil
	}

	var keysToDelete []models.APIKey
	if err := db.Where("id IN ?", orphanKeyIDs).Find(&keysToDelete).Error; err != nil {
		return 0, err
	}

	result := db.Delete(&models.APIKey{}, orphanKeyIDs)
	// result := db.Unscoped().Delete(&models.APIKey{}, orphanKeyIDs)
	if result.Error != nil {
		return 0, result.Error
	}

	for i := range keysToDelete {
		if err := r.removeStoreCacheForKey(&keysToDelete[i]); err != nil {
			r.logger.Warnf("DB deleted orphan key ID %d, but cache removal failed: %v", keysToDelete[i].ID, err)
		}
	}

	return result.RowsAffected, nil
}
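
// Note (inferred from the surrounding code): the active Delete above is a
// GORM soft delete, so orphan keys keep their rows with deleted_at set and
// can still be purged later by HardDeleteSoftDeletedBefore once the
// retention window elapses; the commented-out Unscoped variant would remove
// them immediately and skip that grace period.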

func (r *gormKeyRepository) HardDeleteSoftDeletedBefore(date time.Time) (int64, error) {
	result := r.db.Unscoped().Where("deleted_at < ?", date).Delete(&models.APIKey{})
	return result.RowsAffected, result.Error
}

func (r *gormKeyRepository) GetActiveMasterKeys() ([]*models.APIKey, error) {
	var keys []*models.APIKey
	err := r.db.Where("master_status = ?", models.MasterStatusActive).Find(&keys).Error
	if err != nil {
		return nil, err
	}
	for _, key := range keys {
		if err := r.decryptKey(key); err != nil {
			r.logger.Warnf("Failed to decrypt key ID %d during GetActiveMasterKeys: %v", key.ID, err)
		}
	}

	return keys, nil
}

func (r *gormKeyRepository) UpdateAPIKeyStatus(keyID uint, status models.MasterAPIKeyStatus) error {
	err := r.executeTransactionWithRetry(func(tx *gorm.DB) error {
		result := tx.Model(&models.APIKey{}).
			Where("id = ?", keyID).
			Update("master_status", status)
		if result.Error != nil {
			return result.Error
		}
		if result.RowsAffected == 0 {
			return gorm.ErrRecordNotFound
		}
		return nil
	})
	if err == nil {
		r.logger.Infof("MasterStatus for key ID %d changed, triggering a full cache reload.", keyID)
		go func() {
			if err := r.LoadAllKeysToStore(); err != nil {
				r.logger.Errorf("Failed to reload cache after MasterStatus change for key ID %d: %v", keyID, err)
			}
		}()
	}
	return err
}
289
internal/repository/key_mapping.go
Normal file
@@ -0,0 +1,289 @@
// Filename: internal/repository/key_mapping.go
package repository

import (
	"crypto/sha256"
	"encoding/hex"
	"errors"
	"fmt"
	"gemini-balancer/internal/models"
	"strconv"
	"strings"

	"gorm.io/gorm"
	"gorm.io/gorm/clause"
)

func (r *gormKeyRepository) LinkKeysToGroup(groupID uint, keyIDs []uint) error {
	if len(keyIDs) == 0 {
		return nil
	}
	var mappings []models.GroupAPIKeyMapping
	for _, keyID := range keyIDs {
		mappings = append(mappings, models.GroupAPIKeyMapping{
			KeyGroupID: groupID,
			APIKeyID:   keyID,
		})
	}

	err := r.executeTransactionWithRetry(func(tx *gorm.DB) error {
		return tx.Clauses(clause.OnConflict{DoNothing: true}).Create(&mappings).Error
	})
	if err != nil {
		return err
	}

	for _, keyID := range keyIDs {
		r.store.SAdd(fmt.Sprintf("key:%d:groups", keyID), groupID)
	}
	return nil
}

func (r *gormKeyRepository) UnlinkKeysFromGroup(groupID uint, keyIDs []uint) (int64, error) {
	if len(keyIDs) == 0 {
		return 0, nil
	}

	var unlinkedCount int64
	err := r.executeTransactionWithRetry(func(tx *gorm.DB) error {
		result := tx.Table("group_api_key_mappings").
			Where("key_group_id = ? AND api_key_id IN ?", groupID, keyIDs).
			Delete(nil)

		if result.Error != nil {
			return result.Error
		}
		unlinkedCount = result.RowsAffected
		return nil
	})

	if err != nil {
		return 0, err
	}

	activeKeyListKey := fmt.Sprintf("group:%d:keys:active", groupID)
	for _, keyID := range keyIDs {
		r.store.SRem(fmt.Sprintf("key:%d:groups", keyID), groupID)
		r.store.LRem(activeKeyListKey, 0, strconv.Itoa(int(keyID)))
	}

	return unlinkedCount, nil
}

func (r *gormKeyRepository) GetGroupsForKey(keyID uint) ([]uint, error) {
	cacheKey := fmt.Sprintf("key:%d:groups", keyID)
	strGroupIDs, err := r.store.SMembers(cacheKey)
	if err != nil || len(strGroupIDs) == 0 {
		var groupIDs []uint
		dbErr := r.db.Table("group_api_key_mappings").Where("api_key_id = ?", keyID).Pluck("key_group_id", &groupIDs).Error
		if dbErr != nil {
			return nil, dbErr
		}
		if len(groupIDs) > 0 {
			var interfaceSlice []interface{}
			for _, id := range groupIDs {
				interfaceSlice = append(interfaceSlice, id)
			}
			r.store.SAdd(cacheKey, interfaceSlice...)
		}
		return groupIDs, nil
	}

	var groupIDs []uint
	for _, strID := range strGroupIDs {
		id, _ := strconv.Atoi(strID)
		groupIDs = append(groupIDs, uint(id))
	}
	return groupIDs, nil
}
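
// GetGroupsForKey above is a cache-aside read: it tries the store's set
// first, falls back to the mappings table on a miss or error, then backfills
// the set so subsequent lookups stay on the cache path.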

func (r *gormKeyRepository) GetMapping(groupID, keyID uint) (*models.GroupAPIKeyMapping, error) {
	var mapping models.GroupAPIKeyMapping
	err := r.db.Where("key_group_id = ? AND api_key_id = ?", groupID, keyID).First(&mapping).Error
	return &mapping, err
}

func (r *gormKeyRepository) UpdateMapping(mapping *models.GroupAPIKeyMapping) error {
	err := r.executeTransactionWithRetry(func(tx *gorm.DB) error {
		return tx.Save(mapping).Error
	})
	if err != nil {
		return err
	}
	return r.updateStoreCacheForMapping(mapping)
}

// [MODIFIED & FINAL] This is the final version for the core refactoring.
func (r *gormKeyRepository) GetPaginatedKeysAndMappingsByGroup(params *models.APIKeyQueryParams) ([]*models.APIKeyDetails, int64, error) {
	items := make([]*models.APIKeyDetails, 0)
	var total int64

	query := r.db.Table("api_keys").
		Select(`
			api_keys.id, api_keys.created_at, api_keys.updated_at,
			api_keys.encrypted_key, -- Select encrypted key to be scanned into APIKeyDetails.EncryptedKey
			api_keys.master_status,
			m.status, m.last_error, m.consecutive_error_count, m.last_used_at, m.cooldown_until
		`).
		Joins("JOIN group_api_key_mappings as m ON m.api_key_id = api_keys.id")

	if params.KeyGroupID <= 0 {
		return nil, 0, errors.New("KeyGroupID is required for this query")
	}
	query = query.Where("m.key_group_id = ?", params.KeyGroupID)

	if params.Status != "" {
		query = query.Where("LOWER(m.status) = LOWER(?)", params.Status)
	}

	// Keyword search is now handled by the service layer.
	if params.Keyword != "" {
		r.logger.Warn("DB query is ignoring keyword; service layer will perform in-memory filtering.")
	}

	countQuery := query.Model(&models.APIKey{}) // Use model for count to avoid GORM issues
	err := countQuery.Count(&total).Error
	if err != nil {
		return nil, 0, err
	}
	if total == 0 {
		return items, 0, nil
	}

	offset := (params.Page - 1) * params.PageSize
	err = query.Order("api_keys.id DESC").Limit(params.PageSize).Offset(offset).Scan(&items).Error
	if err != nil {
		return nil, 0, err
	}

	// Decrypt all results before returning. This loop is now valid.
	for i := range items {
		if items[i].EncryptedKey != "" {
			plaintext, err := r.crypto.Decrypt(items[i].EncryptedKey)
			if err == nil {
				items[i].APIKey = plaintext
			} else {
				items[i].APIKey = "[DECRYPTION FAILED]"
				r.logger.Errorf("Failed to decrypt key ID %d for pagination: %v", items[i].ID, err)
			}
		}
	}
	return items, total, nil
}

// [MODIFIED & FINAL] Uses hashes for lookup.
func (r *gormKeyRepository) GetKeysByValuesAndGroupID(values []string, groupID uint) ([]models.APIKey, error) {
	if len(values) == 0 {
		return []models.APIKey{}, nil
	}
	hashes := make([]string, len(values))
	for i, v := range values {
		hash := sha256.Sum256([]byte(v))
		hashes[i] = hex.EncodeToString(hash[:])
	}

	var keys []models.APIKey
	err := r.db.Joins("JOIN group_api_key_mappings ON group_api_key_mappings.api_key_id = api_keys.id").
		Where("group_api_key_mappings.key_group_id = ?", groupID).
		Where("api_keys.api_key_hash IN ?", hashes).
		Find(&keys).Error

	if err != nil {
		return nil, err
	}
	return keys, r.decryptKeys(keys)
}

// [MODIFIED & FINAL] Fetches full objects, decrypts, then extracts strings.
func (r *gormKeyRepository) FindKeyValuesByStatus(groupID uint, statuses []string) ([]string, error) {
	var keys []models.APIKey
	query := r.db.Table("api_keys").
		Select("api_keys.*").
		Joins("JOIN group_api_key_mappings as m ON m.api_key_id = api_keys.id").
		Where("m.key_group_id = ?", groupID)

	if len(statuses) > 0 && !(len(statuses) == 1 && statuses[0] == "all") {
		lowerStatuses := make([]string, len(statuses))
		for i, s := range statuses {
			lowerStatuses[i] = strings.ToLower(s)
		}
		query = query.Where("LOWER(m.status) IN (?)", lowerStatuses)
	}

	if err := query.Find(&keys).Error; err != nil {
		return nil, err
	}
	if err := r.decryptKeys(keys); err != nil {
		return nil, fmt.Errorf("decryption failed during FindKeyValuesByStatus: %w", err)
	}

	keyValues := make([]string, len(keys))
	for i, key := range keys {
		keyValues[i] = key.APIKey
	}
	return keyValues, nil
}

// [MODIFIED & FINAL] Consistent with the new pattern.
func (r *gormKeyRepository) GetKeyStringsByGroupAndStatus(groupID uint, statuses []string) ([]string, error) {
	var keys []models.APIKey
	query := r.db.Table("api_keys").
		Select("api_keys.*").
		Joins("JOIN group_api_key_mappings ON group_api_key_mappings.api_key_id = api_keys.id").
		Where("group_api_key_mappings.key_group_id = ?", groupID)

	if len(statuses) > 0 {
		isAll := false
		for _, s := range statuses {
			if strings.ToLower(s) == "all" {
				isAll = true
				break
			}
		}
		if !isAll {
			lowerStatuses := make([]string, len(statuses))
			for i, s := range statuses {
				lowerStatuses[i] = strings.ToLower(s)
			}
			query = query.Where("LOWER(group_api_key_mappings.status) IN ?", lowerStatuses)
		}
	}

	if err := query.Find(&keys).Error; err != nil {
		return nil, err
	}
	if err := r.decryptKeys(keys); err != nil {
		return nil, fmt.Errorf("decryption failed during GetKeyStringsByGroupAndStatus: %w", err)
	}

	keyStrings := make([]string, len(keys))
	for i, key := range keys {
		keyStrings[i] = key.APIKey
	}
	return keyStrings, nil
}

// FindKeyIDsByStatus remains unchanged as it does not deal with key values.
func (r *gormKeyRepository) FindKeyIDsByStatus(groupID uint, statuses []string) ([]uint, error) {
	var keyIDs []uint
	query := r.db.Table("group_api_key_mappings").
		Select("api_key_id").
		Where("key_group_id = ?", groupID)
	if len(statuses) > 0 && !(len(statuses) == 1 && statuses[0] == "all") {
		lowerStatuses := make([]string, len(statuses))
		for i, s := range statuses {
			lowerStatuses[i] = strings.ToLower(s)
		}
		query = query.Where("LOWER(status) IN (?)", lowerStatuses)
	}
	if err := query.Pluck("api_key_id", &keyIDs).Error; err != nil {
		return nil, err
	}
	return keyIDs, nil
}

func (r *gormKeyRepository) UpdateMappingWithoutCache(mapping *models.GroupAPIKeyMapping) error {
	return r.executeTransactionWithRetry(func(tx *gorm.DB) error {
		return tx.Save(mapping).Error
	})
}
276
internal/repository/key_selector.go
Normal file
@@ -0,0 +1,276 @@
// Filename: internal/repository/key_selector.go
package repository

import (
	"crypto/sha1"
	"encoding/json"
	"errors"
	"fmt"
	"gemini-balancer/internal/models"
	"gemini-balancer/internal/store"
	"io"
	"sort"
	"strconv"
	"time"

	"gorm.io/gorm"
)

const (
	CacheTTL             = 5 * time.Minute
	EmptyPoolPlaceholder = "EMPTY_POOL"
	EmptyCacheTTL        = 1 * time.Minute
)

// SelectOneActiveKey efficiently selects one available API key from the cache
// according to the group's polling strategy.
func (r *gormKeyRepository) SelectOneActiveKey(group *models.KeyGroup) (*models.APIKey, *models.GroupAPIKeyMapping, error) {
	var keyIDStr string
	var err error

	switch group.PollingStrategy {
	case models.StrategySequential:
		sequentialKey := fmt.Sprintf(KeyGroupSequential, group.ID)
		keyIDStr, err = r.store.Rotate(sequentialKey)

	case models.StrategyWeighted:
		lruKey := fmt.Sprintf(KeyGroupLRU, group.ID)
		results, zerr := r.store.ZRange(lruKey, 0, 0)
		if zerr == nil && len(results) > 0 {
			keyIDStr = results[0]
		}
		err = zerr

	case models.StrategyRandom:
		mainPoolKey := fmt.Sprintf(KeyGroupRandomMain, group.ID)
		cooldownPoolKey := fmt.Sprintf(KeyGroupRandomCooldown, group.ID)
		keyIDStr, err = r.store.PopAndCycleSetMember(mainPoolKey, cooldownPoolKey)

	default: // When no strategy is specified, fall back to basic random selection.
		activeKeySetKey := fmt.Sprintf(KeyGroup, group.ID)
		keyIDStr, err = r.store.SRandMember(activeKeySetKey)
	}

	if err != nil {
		if errors.Is(err, store.ErrNotFound) || errors.Is(err, gorm.ErrRecordNotFound) {
			return nil, nil, gorm.ErrRecordNotFound
		}
		r.logger.WithError(err).Errorf("Failed to select key for group %d with strategy %s", group.ID, group.PollingStrategy)
		return nil, nil, err
	}

	if keyIDStr == "" {
		return nil, nil, gorm.ErrRecordNotFound
	}

	keyID, _ := strconv.ParseUint(keyIDStr, 10, 64)

	apiKey, mapping, err := r.getKeyDetailsFromCache(uint(keyID), group.ID)
	if err != nil {
		r.logger.WithError(err).Warnf("Cache inconsistency: Failed to get details for selected key ID %d", keyID)
		// TODO: consider adding retry logic here by calling SelectOneActiveKey(group) again.
		return nil, nil, err
	}

	if group.PollingStrategy == models.StrategyWeighted {
		go r.UpdateKeyUsageTimestamp(group.ID, uint(keyID))
	}

	return apiKey, mapping, nil
}
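
// Strategy-to-structure mapping, as read from the switch above: sequential
// rotates a LIST via Rotate; weighted reads the lowest-scored member of an
// LRU ZSET (ZRange 0 0 yields the least recently used key, since scores are
// last-used timestamps); random cycles a member between a main SET and a
// cooldown SET via PopAndCycleSetMember.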

// SelectOneActiveKeyFromBasePool is the new poller designed for smart aggregation mode.
func (r *gormKeyRepository) SelectOneActiveKeyFromBasePool(pool *BasePool) (*models.APIKey, *models.KeyGroup, error) {
	// Generate a unique pool ID so the rotation state of different request combinations stays isolated.
	poolID := generatePoolID(pool.CandidateGroups)
	log := r.logger.WithField("pool_id", poolID)

	if err := r.ensureBasePoolCacheExists(pool, poolID); err != nil {
		log.WithError(err).Error("Failed to ensure BasePool cache exists.")
		return nil, nil, err
	}

	var keyIDStr string
	var err error

	switch pool.PollingStrategy {
	case models.StrategySequential:
		sequentialKey := fmt.Sprintf(BasePoolSequential, poolID)
		keyIDStr, err = r.store.Rotate(sequentialKey)
	case models.StrategyWeighted:
		lruKey := fmt.Sprintf(BasePoolLRU, poolID)
		results, zerr := r.store.ZRange(lruKey, 0, 0)
		if zerr == nil && len(results) > 0 {
			keyIDStr = results[0]
		}
		err = zerr
	case models.StrategyRandom:
		mainPoolKey := fmt.Sprintf(BasePoolRandomMain, poolID)
		cooldownPoolKey := fmt.Sprintf(BasePoolRandomCooldown, poolID)
		keyIDStr, err = r.store.PopAndCycleSetMember(mainPoolKey, cooldownPoolKey)
	default: // The default strategy should be handled in ensureBasePoolCacheExists; this is a fallback.
		log.Warnf("Default polling strategy triggered inside selection. This should be rare.")

		sequentialKey := fmt.Sprintf(BasePoolSequential, poolID)
		keyIDStr, err = r.store.LIndex(sequentialKey, 0)
	}

	if err != nil {
		if errors.Is(err, store.ErrNotFound) {
			return nil, nil, gorm.ErrRecordNotFound
		}
		log.WithError(err).Errorf("Failed to select key from BasePool with strategy %s", pool.PollingStrategy)
		return nil, nil, err
	}
	if keyIDStr == "" {
		return nil, nil, gorm.ErrRecordNotFound
	}
	keyID, _ := strconv.ParseUint(keyIDStr, 10, 64)

	for _, group := range pool.CandidateGroups {
		apiKey, mapping, cacheErr := r.getKeyDetailsFromCache(uint(keyID), group.ID)
		if cacheErr == nil && apiKey != nil && mapping != nil {
			if pool.PollingStrategy == models.StrategyWeighted {
				go r.updateKeyUsageTimestampForPool(poolID, uint(keyID))
			}
			return apiKey, group, nil
		}
	}

	log.Errorf("Cache inconsistency: Selected KeyID %d from BasePool but could not find its origin group.", keyID)
	return nil, nil, errors.New("cache inconsistency: selected key has no origin group")
}

// ensureBasePoolCacheExists dynamically creates the BasePool structures in Redis.
func (r *gormKeyRepository) ensureBasePoolCacheExists(pool *BasePool, poolID string) error {
	// Use the LIST key as the existence marker.
	listKey := fmt.Sprintf(BasePoolSequential, poolID)
	exists, err := r.store.Exists(listKey)
	if err != nil {
		return err
	}
	if exists {
		val, err := r.store.LIndex(listKey, 0)
		if err == nil && val == EmptyPoolPlaceholder {
			return gorm.ErrRecordNotFound
		}
		return nil
	}

	r.logger.Infof("BasePool cache for pool_id '%s' not found. Building now...", poolID)

	var allActiveKeyIDs []string
	lruMembers := make(map[string]float64)
	for _, group := range pool.CandidateGroups {
		activeKeySetKey := fmt.Sprintf(KeyGroup, group.ID)
		groupKeyIDs, err := r.store.SMembers(activeKeySetKey)
		if err != nil {
			r.logger.WithError(err).Warnf("Failed to get active keys for group %d during BasePool build", group.ID)
			continue
		}
		allActiveKeyIDs = append(allActiveKeyIDs, groupKeyIDs...)

		for _, keyIDStr := range groupKeyIDs {
			keyID, _ := strconv.ParseUint(keyIDStr, 10, 64)
			_, mapping, err := r.getKeyDetailsFromCache(uint(keyID), group.ID)
			if err == nil && mapping != nil {
				var score float64
				if mapping.LastUsedAt != nil {
					score = float64(mapping.LastUsedAt.UnixMilli())
				}
				lruMembers[keyIDStr] = score
			}
		}
	}
	if len(allActiveKeyIDs) == 0 {
		pipe := r.store.Pipeline()
		pipe.LPush(listKey, EmptyPoolPlaceholder)
		pipe.Expire(listKey, EmptyCacheTTL)
		if err := pipe.Exec(); err != nil {
			r.logger.WithError(err).Errorf("Failed to set empty pool placeholder for pool_id '%s'", poolID)
		}
		return gorm.ErrRecordNotFound
	}
	// Populate all rotation structures through a pipeline.
	pipe := r.store.Pipeline()
	// 1. Sequential.
	pipe.LPush(fmt.Sprintf(BasePoolSequential, poolID), toInterfaceSlice(allActiveKeyIDs)...)
	// 2. Random.
	pipe.SAdd(fmt.Sprintf(BasePoolRandomMain, poolID), toInterfaceSlice(allActiveKeyIDs)...)

	// Set a reasonable TTL (e.g. five minutes) to prevent orphaned data.
	pipe.Expire(fmt.Sprintf(BasePoolSequential, poolID), CacheTTL)
	pipe.Expire(fmt.Sprintf(BasePoolRandomMain, poolID), CacheTTL)
	pipe.Expire(fmt.Sprintf(BasePoolRandomCooldown, poolID), CacheTTL)
	pipe.Expire(fmt.Sprintf(BasePoolLRU, poolID), CacheTTL)

	if err := pipe.Exec(); err != nil {
		return err
	}

	if len(lruMembers) > 0 {
		r.store.ZAdd(fmt.Sprintf(BasePoolLRU, poolID), lruMembers)
	}
	return nil
}

// updateKeyUsageTimestampForPool updates the BasePool's LRU ZSET.
func (r *gormKeyRepository) updateKeyUsageTimestampForPool(poolID string, keyID uint) {
	lruKey := fmt.Sprintf(BasePoolLRU, poolID)
	r.store.ZAdd(lruKey, map[string]float64{
		strconv.FormatUint(uint64(keyID), 10): nowMilli(),
	})
}

// generatePoolID derives a stable, unique string ID from the list of candidate group IDs.
func generatePoolID(groups []*models.KeyGroup) string {
	ids := make([]int, len(groups))
	for i, g := range groups {
		ids[i] = int(g.ID)
	}
	sort.Ints(ids)

	h := sha1.New()
	io.WriteString(h, fmt.Sprintf("%v", ids))
	return fmt.Sprintf("%x", h.Sum(nil))
}
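
// Because the IDs are sorted before hashing, generatePoolID is
// order-insensitive: candidate groups {2, 5, 9} and {9, 5, 2} produce the
// same pool ID and therefore share one set of rotation structures in the store.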

// toInterfaceSlice is a type-conversion helper.
func toInterfaceSlice(slice []string) []interface{} {
	result := make([]interface{}, len(slice))
	for i, v := range slice {
		result[i] = v
	}
	return result
}

// nowMilli returns the current Unix timestamp in milliseconds, used by the LRU/weighted strategy.
func nowMilli() float64 {
	return float64(time.Now().UnixMilli())
}

// getKeyDetailsFromCache fetches the key and mapping JSON documents from the cache.
func (r *gormKeyRepository) getKeyDetailsFromCache(keyID, groupID uint) (*models.APIKey, *models.GroupAPIKeyMapping, error) {
	apiKeyJSON, err := r.store.Get(fmt.Sprintf(KeyDetails, keyID))
	if err != nil {
		return nil, nil, fmt.Errorf("failed to get key details for key %d: %w", keyID, err)
	}
	var apiKey models.APIKey
	if err := json.Unmarshal(apiKeyJSON, &apiKey); err != nil {
		return nil, nil, fmt.Errorf("failed to unmarshal api key %d: %w", keyID, err)
	}

	mappingJSON, err := r.store.Get(fmt.Sprintf(KeyMapping, groupID, keyID))
	if err != nil {
		return nil, nil, fmt.Errorf("failed to get mapping details for key %d in group %d: %w", keyID, groupID, err)
	}
	var mapping models.GroupAPIKeyMapping
	if err := json.Unmarshal(mappingJSON, &mapping); err != nil {
		return nil, nil, fmt.Errorf("failed to unmarshal mapping for key %d in group %d: %w", keyID, groupID, err)
	}

	return &apiKey, &mapping, nil
}
77
internal/repository/key_writer.go
Normal file
@@ -0,0 +1,77 @@
// Filename: internal/repository/key_writer.go
package repository

import (
	"fmt"
	"gemini-balancer/internal/errors"
	"gemini-balancer/internal/models"
	"strconv"
	"time"
)

func (r *gormKeyRepository) UpdateKeyUsageTimestamp(groupID, keyID uint) {
	lruKey := fmt.Sprintf(KeyGroupLRU, groupID)
	timestamp := float64(time.Now().UnixMilli())

	members := map[string]float64{
		strconv.FormatUint(uint64(keyID), 10): timestamp,
	}

	if err := r.store.ZAdd(lruKey, members); err != nil {
		r.logger.WithError(err).Warnf("Failed to update usage timestamp for key %d in group %d", keyID, groupID)
	}
}

func (r *gormKeyRepository) SyncKeyStatusInPollingCaches(groupID, keyID uint, newStatus models.APIKeyStatus) {
	r.logger.Infof("SYNC: Directly updating polling caches for G:%d K:%d -> %s", groupID, keyID, newStatus)
	r.updatePollingCachesLogic(groupID, keyID, newStatus)
}

func (r *gormKeyRepository) HandleCacheUpdateEvent(groupID, keyID uint, newStatus models.APIKeyStatus) {
	r.logger.Infof("EVENT: Updating polling caches for G:%d K:%d -> %s from an event", groupID, keyID, newStatus)
	r.updatePollingCachesLogic(groupID, keyID, newStatus)
}

func (r *gormKeyRepository) updatePollingCachesLogic(groupID, keyID uint, newStatus models.APIKeyStatus) {
	keyIDStr := strconv.FormatUint(uint64(keyID), 10)
	sequentialKey := fmt.Sprintf(KeyGroupSequential, groupID)
	lruKey := fmt.Sprintf(KeyGroupLRU, groupID)
	mainPoolKey := fmt.Sprintf(KeyGroupRandomMain, groupID)
	cooldownPoolKey := fmt.Sprintf(KeyGroupRandomCooldown, groupID)

	_ = r.store.LRem(sequentialKey, 0, keyIDStr)
	_ = r.store.ZRem(lruKey, keyIDStr)
	_ = r.store.SRem(mainPoolKey, keyIDStr)
	_ = r.store.SRem(cooldownPoolKey, keyIDStr)

	if newStatus == models.StatusActive {
		if err := r.store.LPush(sequentialKey, keyIDStr); err != nil {
			r.logger.WithError(err).Warnf("Failed to add key %d to sequential list for group %d", keyID, groupID)
		}
		members := map[string]float64{keyIDStr: 0}
		if err := r.store.ZAdd(lruKey, members); err != nil {
			r.logger.WithError(err).Warnf("Failed to add key %d to LRU zset for group %d", keyID, groupID)
		}
		if err := r.store.SAdd(mainPoolKey, keyIDStr); err != nil {
			r.logger.WithError(err).Warnf("Failed to add key %d to random main pool for group %d", keyID, groupID)
		}
	}
}
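
// The remove-then-conditionally-re-add sequence above makes the update
// idempotent: the key is first purged from all four polling structures, and
// only StatusActive re-inserts it, so replaying the same event cannot leave
// duplicate entries in the sequential list.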

// UpdateKeyStatusAfterRequest is the new central hub for handling feedback.
func (r *gormKeyRepository) UpdateKeyStatusAfterRequest(group *models.KeyGroup, key *models.APIKey, success bool, apiErr *errors.APIError) {
	if success {
		if group.PollingStrategy == models.StrategyWeighted {
			go r.UpdateKeyUsageTimestamp(group.ID, key.ID)
		}
		return
	}
	if apiErr == nil {
		r.logger.Warnf("Request failed for KeyID %d in GroupID %d but no specific API error was provided.", key.ID, group.ID)
		return
	}
	r.logger.Warnf("Request failed for KeyID %d in GroupID %d with error: %s. Temporarily removing from active polling caches.", key.ID, group.ID, apiErr.Message)

	// This call is correct. It uses the synchronous, direct method.
	r.SyncKeyStatusInPollingCaches(group.ID, key.ID, models.StatusCooldown)
}
107
internal/repository/repository.go
Normal file
@@ -0,0 +1,107 @@
// Filename: internal/repository/repository.go
package repository

import (
	"gemini-balancer/internal/crypto"
	"gemini-balancer/internal/errors"
	"gemini-balancer/internal/models"
	"gemini-balancer/internal/store"
	"io"
	"time"

	"github.com/sirupsen/logrus"
	"gorm.io/gorm"
)

// BasePool is a virtual, ephemeral resource pool used by smart aggregation mode.
type BasePool struct {
	CandidateGroups []*models.KeyGroup
	PollingStrategy models.PollingStrategy
}

type KeyRepository interface {
	// --- Core selection & scheduling --- key_selector
	SelectOneActiveKey(group *models.KeyGroup) (*models.APIKey, *models.GroupAPIKeyMapping, error)
	SelectOneActiveKeyFromBasePool(pool *BasePool) (*models.APIKey, *models.KeyGroup, error)

	// --- Encryption & decryption --- key_crud
	Decrypt(key *models.APIKey) error
	DecryptBatch(keys []models.APIKey) error

	// --- Basic CRUD --- key_crud
	AddKeys(keys []models.APIKey) ([]models.APIKey, error)
	Update(key *models.APIKey) error
	HardDeleteByID(id uint) error
	HardDeleteByValues(keyValues []string) (int64, error)
	GetKeyByID(id uint) (*models.APIKey, error)
	GetKeyByValue(keyValue string) (*models.APIKey, error)
	GetKeysByValues(keyValues []string) ([]models.APIKey, error)
	GetKeysByIDs(ids []uint) ([]models.APIKey, error) // [NEW] Batch-fetch keys by a set of primary-key IDs.
	GetKeysByGroup(groupID uint) ([]models.APIKey, error)
	CountByGroup(groupID uint) (int64, error)

	// --- Many-to-many relationship management --- key_mapping
	LinkKeysToGroup(groupID uint, keyIDs []uint) error
	UnlinkKeysFromGroup(groupID uint, keyIDs []uint) (unlinkedCount int64, err error)
	GetGroupsForKey(keyID uint) ([]uint, error)
	GetMapping(groupID, keyID uint) (*models.GroupAPIKeyMapping, error)
	UpdateMapping(mapping *models.GroupAPIKeyMapping) error
	GetPaginatedKeysAndMappingsByGroup(params *models.APIKeyQueryParams) ([]*models.APIKeyDetails, int64, error)
	GetKeysByValuesAndGroupID(values []string, groupID uint) ([]models.APIKey, error)
	FindKeyValuesByStatus(groupID uint, statuses []string) ([]string, error)
	FindKeyIDsByStatus(groupID uint, statuses []string) ([]uint, error)
	GetKeyStringsByGroupAndStatus(groupID uint, statuses []string) ([]string, error)
	UpdateMappingWithoutCache(mapping *models.GroupAPIKeyMapping) error

	// --- Cache management --- key_cache
	LoadAllKeysToStore() error
	HandleCacheUpdateEventBatch(mappings []*models.GroupAPIKeyMapping) error

	// --- Maintenance & background tasks --- key_maintenance
	StreamKeysToWriter(groupID uint, statusFilter string, writer io.Writer) error
	UpdateMasterStatusByValues(keyValues []string, newStatus models.MasterAPIKeyStatus) (int64, error)
	UpdateMasterStatusByID(keyID uint, newStatus models.MasterAPIKeyStatus) error
	DeleteOrphanKeys() (int64, error)
	DeleteOrphanKeysTx(tx *gorm.DB) (int64, error)
	GetActiveMasterKeys() ([]*models.APIKey, error)
	UpdateAPIKeyStatus(keyID uint, status models.MasterAPIKeyStatus) error
	HardDeleteSoftDeletedBefore(date time.Time) (int64, error)

	// --- "Write" operations for the polling strategies --- key_writer
	UpdateKeyUsageTimestamp(groupID, keyID uint)
	// Synchronous cache update, for core business use.
	SyncKeyStatusInPollingCaches(groupID, keyID uint, newStatus models.APIKeyStatus)
	// Asynchronous cache update, for event subscribers.
	HandleCacheUpdateEvent(groupID, keyID uint, newStatus models.APIKeyStatus)
	UpdateKeyStatusAfterRequest(group *models.KeyGroup, key *models.APIKey, success bool, apiErr *errors.APIError)
}

type GroupRepository interface {
	GetGroupByName(name string) (*models.KeyGroup, error)
	GetAllGroups() ([]*models.KeyGroup, error)
	UpdateOrderInTransaction(orders map[uint]int) error
}

type gormKeyRepository struct {
	db     *gorm.DB
	store  store.Store
	logger *logrus.Entry
	crypto *crypto.Service
}

type gormGroupRepository struct {
	db *gorm.DB
}

func NewKeyRepository(db *gorm.DB, s store.Store, logger *logrus.Logger, crypto *crypto.Service) KeyRepository {
	return &gormKeyRepository{
		db:     db,
		store:  s,
		logger: logger.WithField("component", "repository.key🔗"),
		crypto: crypto,
	}
}

func NewGroupRepository(db *gorm.DB) GroupRepository {
	return &gormGroupRepository{db: db}
}
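
// Wiring example (a sketch; db, redisStore, logger, and cryptoSvc are assumed
// to be constructed elsewhere, e.g. in the app bootstrap):
//
//	keyRepo := repository.NewKeyRepository(db, redisStore, logger, cryptoSvc)
//	groupRepo := repository.NewGroupRepository(db)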
47
internal/response/response.go
Normal file
@@ -0,0 +1,47 @@
// Filename: internal/response/response.go
package response

import (
	"gemini-balancer/internal/errors"
	"net/http"

	"github.com/gin-gonic/gin"
)

type SuccessResponse struct {
	Success bool        `json:"success"`
	Data    interface{} `json:"data"`
}

type ErrorResponse struct {
	Success bool  `json:"success"`
	Error   gin.H `json:"error"`
}

func Success(c *gin.Context, data interface{}) {
	c.JSON(http.StatusOK, SuccessResponse{
		Success: true,
		Data:    data,
	})
}

func Created(c *gin.Context, data interface{}) {
	c.JSON(http.StatusCreated, SuccessResponse{
		Success: true,
		Data:    data,
	})
}

func NoContent(c *gin.Context) {
	c.Status(http.StatusNoContent)
}

func Error(c *gin.Context, err *errors.APIError) {
	c.JSON(err.HTTPStatus, ErrorResponse{
		Success: false,
		Error: gin.H{
			"code":    err.Code,
			"message": err.Message,
		},
	})
}
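
// Example usage (a sketch; loadThing and ErrNotFound are hypothetical names):
// a handler pairs these helpers so every endpoint emits the same envelope.
//
//	func getThing(c *gin.Context) {
//	    thing, err := loadThing(c.Param("id"))
//	    if err != nil {
//	        Error(c, errors.ErrNotFound) // assumed *errors.APIError value
//	        return
//	    }
//	    Success(c, thing)
//	}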
210
internal/router/router.go
Normal file
@@ -0,0 +1,210 @@
// Filename: internal/router/router.go
package router

import (
	"gemini-balancer/internal/config"
	"gemini-balancer/internal/domain/proxy"
	"gemini-balancer/internal/domain/upstream"
	"gemini-balancer/internal/handlers"
	"gemini-balancer/internal/middleware"
	"gemini-balancer/internal/pongo"
	"gemini-balancer/internal/service"
	"gemini-balancer/internal/settings"
	"gemini-balancer/internal/webhandlers"
	"net/http"
	"os"
	"time"

	"github.com/gin-contrib/cors"
	"github.com/gin-gonic/gin"
)

func NewRouter(
	// Core Services
	cfg *config.Config,
	securityService *service.SecurityService,
	settingsManager *settings.SettingsManager,
	// Core Handlers
	proxyHandler *handlers.ProxyHandler,
	apiAuthHandler *handlers.APIAuthHandler,
	// Admin API Handlers
	keyGroupHandler *handlers.KeyGroupHandler,
	apiKeyHandler *handlers.APIKeyHandler,
	tokensHandler *handlers.TokensHandler,
	logHandler *handlers.LogHandler,
	settingHandler *handlers.SettingHandler,
	dashboardHandler *handlers.DashboardHandler,
	taskHandler *handlers.TaskHandler,
	// Web Page Handlers
	webAuthHandler *webhandlers.WebAuthHandler,
	pageHandler *webhandlers.PageHandler,
	// === Domain Modules ===
	upstreamModule *upstream.Module,
	proxyModule *proxy.Module,
) *gin.Engine {
	if cfg.Log.Level != "debug" {
		gin.SetMode(gin.ReleaseMode)
	}
	router := gin.Default()

	router.Static("/static", "./web/static")
	// CORS configuration (named corsConfig to avoid shadowing the imported config package).
	corsConfig := cors.Config{
		// Origins allowed for the frontend; in production, change this to the real domain.
		AllowOrigins:     []string{"http://localhost:9000"},
		AllowMethods:     []string{"GET", "POST", "PUT", "PATCH", "DELETE", "OPTIONS"},
		AllowHeaders:     []string{"Origin", "Content-Type", "Accept", "Authorization"},
		ExposeHeaders:    []string{"Content-Length"},
		AllowCredentials: true,
		MaxAge:           12 * time.Hour,
	}
	router.Use(cors.New(corsConfig))
	isDebug := gin.Mode() != gin.ReleaseMode
	router.HTMLRender = pongo.New("web/templates", isDebug)

	// --- Infrastructure ---
	router.GET("/", func(c *gin.Context) { c.Redirect(http.StatusMovedPermanently, "/dashboard") })
	router.GET("/health", func(c *gin.Context) { c.JSON(http.StatusOK, gin.H{"status": "ok"}) })
	// --- Unified authentication pipeline ---
	apiAdminAuth := middleware.APIAdminAuthMiddleware(securityService)
	webAdminAuth := middleware.WebAdminAuthMiddleware(securityService)

	router.Use(gin.RecoveryWithWriter(os.Stdout))
	// --- Pass the correct dependencies and middleware pipelines down ---
	registerProxyRoutes(router, proxyHandler, securityService)
	registerAdminRoutes(router, apiAdminAuth, keyGroupHandler, tokensHandler, apiKeyHandler, logHandler, settingHandler, dashboardHandler, taskHandler, upstreamModule, proxyModule)
	registerPublicAPIRoutes(router, apiAuthHandler, securityService, settingsManager)
	registerWebRoutes(router, webAdminAuth, webAuthHandler, pageHandler)
	return router
}

func registerProxyRoutes(
	router *gin.Engine, proxyHandler *handlers.ProxyHandler, securityService *service.SecurityService,
) {
	// Shared proxy authentication middleware.
	proxyAuthMiddleware := middleware.ProxyAuthMiddleware(securityService)
	// --- Mode 1: smart aggregation mode (root paths) ---
	// The /v1 and /v1beta paths are the default entry points, serving the BasePool aggregation logic.
	v1 := router.Group("/v1")
	v1.Use(proxyAuthMiddleware)
	{
		v1.Any("/*path", proxyHandler.HandleProxy)
	}
	v1beta := router.Group("/v1beta")
	v1beta.Use(proxyAuthMiddleware)
	{
		v1beta.Any("/*path", proxyHandler.HandleProxy)
	}
	// --- Mode 2: exact routing mode (/proxy/:group_name) ---
	// A separate, physically isolated route group that routes precisely by group name.
	proxyGroup := router.Group("/proxy/:group_name")
	proxyGroup.Use(proxyAuthMiddleware)
	{
		// Catch all sub-paths (e.g. /v1/chat/completions) and hand them to the same ProxyHandler.
		proxyGroup.Any("/*path", proxyHandler.HandleProxy)
	}
}
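
// Route shapes produced above: aggregation mode serves /v1/*path and
// /v1beta/*path, while exact-routing mode serves /proxy/:group_name/*path,
// e.g. /proxy/team-a/v1/chat/completions for a group named "team-a" (the
// group name here is illustrative).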

// registerAdminRoutes
func registerAdminRoutes(
	router *gin.Engine,
	authMiddleware gin.HandlerFunc,
	keyGroupHandler *handlers.KeyGroupHandler,
	tokensHandler *handlers.TokensHandler,
	apiKeyHandler *handlers.APIKeyHandler,
	logHandler *handlers.LogHandler,
	settingHandler *handlers.SettingHandler,
	dashboardHandler *handlers.DashboardHandler,
	taskHandler *handlers.TaskHandler,
	upstreamModule *upstream.Module,
	proxyModule *proxy.Module,
) {
	admin := router.Group("/admin", authMiddleware)
	{
		// --- KeyGroup Base Routes ---
		admin.POST("/keygroups", keyGroupHandler.CreateKeyGroup)
		admin.GET("/keygroups", keyGroupHandler.GetKeyGroups)
		admin.PUT("/keygroups/order", keyGroupHandler.UpdateKeyGroupOrder)
		// --- KeyGroup Specific Routes (by :id) ---
		admin.GET("/keygroups/:id", keyGroupHandler.GetKeyGroups)
		admin.PUT("/keygroups/:id", keyGroupHandler.UpdateKeyGroup)
		admin.DELETE("/keygroups/:id", keyGroupHandler.DeleteKeyGroup)
		admin.POST("/keygroups/:id/clone", keyGroupHandler.CloneKeyGroup)
		admin.GET("/keygroups/:id/stats", keyGroupHandler.GetKeyGroupStats)
		admin.POST("/keygroups/:id/bulk-actions", apiKeyHandler.HandleBulkAction)
		// --- APIKey Sub-resource Routes under a KeyGroup ---
		keyGroupAPIKeys := admin.Group("/keygroups/:id/apikeys")
		{
			keyGroupAPIKeys.GET("", apiKeyHandler.ListKeysForGroup)
			keyGroupAPIKeys.GET("/export", apiKeyHandler.ExportKeysForGroup)
			keyGroupAPIKeys.POST("/bulk", apiKeyHandler.AddMultipleKeysToGroup)
			keyGroupAPIKeys.DELETE("/bulk", apiKeyHandler.UnlinkMultipleKeysFromGroup)
			keyGroupAPIKeys.POST("/test", apiKeyHandler.TestKeysForGroup)
			keyGroupAPIKeys.PUT("/:keyId", apiKeyHandler.UpdateGroupAPIKeyMapping)
		}

		// Global key operations
		admin.GET("/apikeys", apiKeyHandler.ListAPIKeys)
		// admin.PUT("/apikeys/:id", apiKeyHandler.UpdateAPIKey) // DEPRECATED: Status is now contextual
		admin.POST("/apikeys/test", apiKeyHandler.TestMultipleKeys)           // Test keys globally
		admin.DELETE("/apikeys/:id", apiKeyHandler.HardDeleteAPIKey)          // Hard delete a single key
		admin.DELETE("/apikeys/bulk", apiKeyHandler.HardDeleteMultipleKeys)   // Hard delete multiple keys
		admin.PUT("/apikeys/bulk/restore", apiKeyHandler.RestoreMultipleKeys) // Restore multiple keys globally

		// --- Global Routes ---
		admin.GET("/tokens", tokensHandler.GetAllTokens)
		admin.PUT("/tokens", tokensHandler.UpdateTokens)
		admin.GET("/logs", logHandler.GetLogs)
		admin.GET("/settings", settingHandler.GetSettings)
		admin.PUT("/settings", settingHandler.UpdateSettings)
		admin.PUT("/settings/reset", settingHandler.ResetSettingsToDefaults)

		// Query the status of asynchronous tasks.
		admin.GET("/tasks/:id", taskHandler.GetTaskStatus)

		// Domain modules.
		upstreamModule.RegisterRoutes(admin)
		proxyModule.RegisterRoutes(admin)
		// --- Global dashboard routes ---
		dashboard := admin.Group("/dashboard")
		{
			dashboard.GET("/overview", dashboardHandler.GetOverview)
			dashboard.GET("/chart", dashboardHandler.GetChart)
			dashboard.GET("/stats/:period", dashboardHandler.GetRequestStats) // Per-period detail view.
		}
	}
}

// registerWebRoutes
func registerWebRoutes(
	router *gin.Engine,
	authMiddleware gin.HandlerFunc,
	webAuthHandler *webhandlers.WebAuthHandler,
	pageHandler *webhandlers.PageHandler,
) {
	router.GET("/login", webAuthHandler.ShowLoginPage)
	router.POST("/login", webAuthHandler.HandleLogin)
	router.GET("/logout", webAuthHandler.HandleLogout)
	// For Test only router.Run("127.0.0.1:9000")
	// The protected admin web UI; the middleware is attached once when the group is created.
	webGroup := router.Group("/", authMiddleware)
	{
		webGroup.GET("/keys", pageHandler.ShowKeysPage)
		webGroup.GET("/settings", pageHandler.ShowConfigEditorPage)
		webGroup.GET("/logs", pageHandler.ShowErrorLogsPage)
		webGroup.GET("/dashboard", pageHandler.ShowDashboardPage)
		webGroup.GET("/tasks", pageHandler.ShowTasksPage)
		webGroup.GET("/chat", pageHandler.ShowChatPage)
	}
}

// registerPublicAPIRoutes registers public API routes that do not require an admin login.
func registerPublicAPIRoutes(router *gin.Engine, apiAuthHandler *handlers.APIAuthHandler, securityService *service.SecurityService, settingsManager *settings.SettingsManager) {
	ipBanMiddleware := middleware.IPBanMiddleware(securityService, settingsManager)
	publicAPIGroup := router.Group("/api")
	{
		publicAPIGroup.POST("/login", ipBanMiddleware, apiAuthHandler.HandleLogin)
	}
}
90
internal/scheduler/scheduler.go
Normal file
@@ -0,0 +1,90 @@
// Filename: internal/scheduler/scheduler.go
package scheduler

import (
	"gemini-balancer/internal/repository"
	"gemini-balancer/internal/service"
	"time"

	"github.com/go-co-op/gocron"
	"github.com/sirupsen/logrus"
)

type Scheduler struct {
	gocronScheduler *gocron.Scheduler
	logger          *logrus.Entry
	statsService    *service.StatsService
	keyRepo         repository.KeyRepository
	// healthCheckService *service.HealthCheckService // Reserved for the health-check job.
}

func NewScheduler(statsSvc *service.StatsService, keyRepo repository.KeyRepository, logger *logrus.Logger) *Scheduler {
	s := gocron.NewScheduler(time.UTC)
	s.TagsUnique()
	return &Scheduler{
		gocronScheduler: s,
		logger:          logger.WithField("component", "Scheduler📆"),
		statsService:    statsSvc,
		keyRepo:         keyRepo,
	}
}

func (s *Scheduler) Start() {
	s.logger.Info("Starting scheduler and registering jobs...")

	// --- Job registration ---
	// A CRON expression pins this precisely to "minute 5 of every hour".
	_, err := s.gocronScheduler.Cron("5 * * * *").Tag("stats-aggregation").Do(func() {
		s.logger.Info("Executing hourly request stats aggregation...")
		if err := s.statsService.AggregateHourlyStats(); err != nil {
			s.logger.WithError(err).Error("Hourly stats aggregation failed.")
		} else {
			s.logger.Info("Hourly stats aggregation completed successfully.")
		}
	})
	if err != nil {
		s.logger.Errorf("Failed to schedule [stats-aggregation]: %v", err)
	}

	// Job 2 (reserved): automatic health checks (e.g. every 10 minutes).
	/*
		_, err = s.gocronScheduler.Every(10).Minutes().Tag("auto-health-check").Do(func() {
			s.logger.Info("Executing periodic health check for all groups...")
			// s.healthCheckService.StartGlobalCheckTask() // Pseudocode.
		})
		if err != nil {
			s.logger.Errorf("Failed to schedule [auto-health-check]: %v", err)
		}
	*/
	// [NEW] --- Job 3: clean up soft-deleted API keys ---
	// Executes once daily at 3:15 AM UTC.
	_, err = s.gocronScheduler.Cron("15 3 * * *").Tag("cleanup-soft-deleted-keys").Do(func() {
		s.logger.Info("Executing daily cleanup of soft-deleted API keys...")

		// Let's assume a retention period of 7 days for now.
		// In a real scenario, this should come from settings.
		const retentionDays = 7

		count, err := s.keyRepo.HardDeleteSoftDeletedBefore(time.Now().AddDate(0, 0, -retentionDays))
		if err != nil {
			s.logger.WithError(err).Error("Daily cleanup of soft-deleted keys failed.")
		} else if count > 0 {
			s.logger.Infof("Daily cleanup completed: Permanently deleted %d expired soft-deleted keys.", count)
		} else {
			s.logger.Info("Daily cleanup completed: No expired soft-deleted keys found to delete.")
		}
	})
	if err != nil {
		s.logger.Errorf("Failed to schedule [cleanup-soft-deleted-keys]: %v", err)
	}
	// --- End of job registration ---

	s.gocronScheduler.StartAsync() // Start asynchronously so the application's main thread is not blocked.
	s.logger.Info("Scheduler started.")
}

func (s *Scheduler) Stop() {
	s.logger.Info("Stopping scheduler...")
	s.gocronScheduler.Stop()
	s.logger.Info("Scheduler stopped.")
}
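
// Cron reference for the jobs above: "5 * * * *" fires at minute 5 of every
// hour and "15 3 * * *" at 03:15 daily; both are evaluated in UTC because the
// scheduler is constructed with time.UTC.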
197
internal/service/analytics_service.go
Normal file
@@ -0,0 +1,197 @@
// Filename: internal/service/analytics_service.go

package service

import (
	"encoding/json"
	"fmt"
	"gemini-balancer/internal/db/dialect"
	"gemini-balancer/internal/models"
	"gemini-balancer/internal/store"
	"strconv"
	"strings"
	"sync"
	"time"

	"github.com/sirupsen/logrus"
	"gorm.io/gorm"
)

const (
	flushLoopInterval = 1 * time.Minute
)

type AnalyticsServiceLogger struct{ *logrus.Entry }

type AnalyticsService struct {
	db       *gorm.DB
	store    store.Store
	logger   *logrus.Entry
	stopChan chan struct{}
	wg       sync.WaitGroup
	dialect  dialect.DialectAdapter
}

func NewAnalyticsService(db *gorm.DB, s store.Store, logger *logrus.Logger, d dialect.DialectAdapter) *AnalyticsService {
	return &AnalyticsService{
		db:       db,
		store:    s,
		logger:   logger.WithField("component", "Analytics📊"),
		stopChan: make(chan struct{}),
		dialect:  d,
	}
}

func (s *AnalyticsService) Start() {
	s.wg.Add(2) // Two background goroutines: flushLoop and eventListener.
	go s.flushLoop()
	go s.eventListener()
	s.logger.Info("AnalyticsService (Command Side) started.")
}

func (s *AnalyticsService) Stop() {
	close(s.stopChan)
	s.wg.Wait()
	s.logger.Info("AnalyticsService stopped. Performing final data flush...")
	s.flushToDB() // Flush to the database before shutting down.
	s.logger.Info("AnalyticsService final data flush completed.")
}

func (s *AnalyticsService) eventListener() {
	defer s.wg.Done()
	sub, err := s.store.Subscribe(models.TopicRequestFinished)
	if err != nil {
		s.logger.Fatalf("Failed to subscribe to topic %s: %v", models.TopicRequestFinished, err)
		return
	}
	defer sub.Close()
	s.logger.Info("AnalyticsService subscribed to request events.")

	for {
		select {
		case msg := <-sub.Channel():
			var event models.RequestFinishedEvent
			if err := json.Unmarshal(msg.Payload, &event); err != nil {
				s.logger.Errorf("Failed to unmarshal analytics event: %v", err)
				continue
			}
			s.handleAnalyticsEvent(&event)
		case <-s.stopChan:
			s.logger.Info("AnalyticsService stopping event listener.")
			return
		}
	}
}

func (s *AnalyticsService) handleAnalyticsEvent(event *models.RequestFinishedEvent) {
	if event.GroupID == 0 {
		return
	}
	key := fmt.Sprintf("analytics:hourly:%s", time.Now().UTC().Format("2006-01-02T15"))
	fieldPrefix := fmt.Sprintf("%d:%s", event.GroupID, event.ModelName)

	pipe := s.store.Pipeline()
	pipe.HIncrBy(key, fieldPrefix+":requests", 1)
	if event.IsSuccess {
		pipe.HIncrBy(key, fieldPrefix+":success", 1)
	}
	if event.PromptTokens > 0 {
		pipe.HIncrBy(key, fieldPrefix+":prompt", int64(event.PromptTokens))
	}
	if event.CompletionTokens > 0 {
		pipe.HIncrBy(key, fieldPrefix+":completion", int64(event.CompletionTokens))
	}

	if err := pipe.Exec(); err != nil {
		s.logger.Warnf("[%s] Failed to record analytics event to store for group %d: %v", event.CorrelationID, event.GroupID, err)
	}
}
|
||||
|
||||
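Note that the "groupID:modelName:counter" field scheme assumes model names never contain ':'; a colon in a model name would break the three-way split in parseStatsFromHash below. An illustrative round-trip of the encoding (standalone; the values are made up):

package main

import (
	"fmt"
	"strconv"
	"strings"
)

func main() {
	// Encode: one Redis hash per hour, one field per (group, model, counter).
	field := fmt.Sprintf("%d:%s:%s", 42, "gemini-1.5-pro", "requests")
	fmt.Println(field) // 42:gemini-1.5-pro:requests

	// Decode: the same split performed by parseStatsFromHash.
	parts := strings.Split(field, ":")
	if len(parts) == 3 {
		gid, _ := strconv.Atoi(parts[0])
		fmt.Printf("group=%d model=%s counter=%s\n", gid, parts[1], parts[2])
	}
}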
func (s *AnalyticsService) flushLoop() {
	defer s.wg.Done()
	ticker := time.NewTicker(flushLoopInterval)
	defer ticker.Stop()
	for {
		select {
		case <-ticker.C:
			s.flushToDB()
		case <-s.stopChan:
			return
		}
	}
}

func (s *AnalyticsService) flushToDB() {
	now := time.Now().UTC()
	// Flush both the previous hour's bucket and the current one, attributing
	// each batch to the hour encoded in its key (not to the current hour).
	hoursToFlush := []time.Time{
		now.Add(-1 * time.Hour).Truncate(time.Hour),
		now.Truncate(time.Hour),
	}

	for _, hour := range hoursToFlush {
		key := fmt.Sprintf("analytics:hourly:%s", hour.Format("2006-01-02T15"))
		data, err := s.store.HGetAll(key)
		if err != nil || len(data) == 0 {
			continue
		}

		statsToFlush, parsedFields := s.parseStatsFromHash(hour, data)

		if len(statsToFlush) > 0 {
			upsertClause := s.dialect.OnConflictUpdateAll(
				[]string{"time", "group_id", "model_name"}, // conflict columns
				[]string{"request_count", "success_count", "prompt_tokens", "completion_tokens"}, // update columns
			)
			err := s.db.Clauses(upsertClause).Create(&statsToFlush).Error
			if err != nil {
				s.logger.Errorf("Failed to flush analytics data for key %s: %v", key, err)
			} else {
				s.logger.Infof("Successfully flushed %d records from key %s.", len(statsToFlush), key)
				_ = s.store.HDel(key, parsedFields...)
			}
		}
	}
}

func (s *AnalyticsService) parseStatsFromHash(t time.Time, data map[string]string) ([]models.StatsHourly, []string) {
	tempAggregator := make(map[string]*models.StatsHourly)
	var parsedFields []string
	for field, valueStr := range data {
		parts := strings.Split(field, ":")
		if len(parts) != 3 {
			continue
		}
		groupIDStr, modelName, counterType := parts[0], parts[1], parts[2]

		aggKey := groupIDStr + ":" + modelName
		if _, ok := tempAggregator[aggKey]; !ok {
			gid, err := strconv.Atoi(groupIDStr)
			if err != nil {
				continue
			}
			tempAggregator[aggKey] = &models.StatsHourly{
				Time:      t,
				GroupID:   uint(gid),
				ModelName: modelName,
			}
		}
		val, _ := strconv.ParseInt(valueStr, 10, 64)
		switch counterType {
		case "requests":
			tempAggregator[aggKey].RequestCount = val
		case "success":
			tempAggregator[aggKey].SuccessCount = val
		case "prompt":
			tempAggregator[aggKey].PromptTokens = val
		case "completion":
			tempAggregator[aggKey].CompletionTokens = val
		}
		parsedFields = append(parsedFields, field)
	}
	var result []models.StatsHourly
	for _, stats := range tempAggregator {
		if stats.RequestCount > 0 {
			result = append(result, *stats)
		}
	}
	return result, parsedFields
}
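For readers unfamiliar with the dialect adapter: the clause it builds is presumably equivalent to GORM's standard clause.OnConflict. A hedged standalone sketch of that equivalent (not the project's actual adapter):

package example

import (
	"gorm.io/gorm"
	"gorm.io/gorm/clause"
)

// upsertHourlyStats inserts rows and, on a (time, group_id, model_name)
// conflict, overwrites the four counter columns with the incoming values.
func upsertHourlyStats(db *gorm.DB, rows any) error {
	return db.Clauses(clause.OnConflict{
		Columns: []clause.Column{{Name: "time"}, {Name: "group_id"}, {Name: "model_name"}},
		DoUpdates: clause.AssignmentColumns([]string{
			"request_count", "success_count", "prompt_tokens", "completion_tokens",
		}),
	}).Create(rows).Error
}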
857
internal/service/apikey_service.go
Normal file
@@ -0,0 +1,857 @@
// Filename: internal/service/apikey_service.go

package service

import (
	"encoding/json"
	"errors"
	"fmt"
	"gemini-balancer/internal/channel"
	CustomErrors "gemini-balancer/internal/errors"
	"gemini-balancer/internal/models"
	"gemini-balancer/internal/repository"
	"gemini-balancer/internal/settings"
	"gemini-balancer/internal/store"
	"gemini-balancer/internal/task"
	"math"
	"sort"
	"strings"
	"sync"
	"time"

	"github.com/sirupsen/logrus"
	"gorm.io/gorm"
)

const (
	TaskTypeRestoreAllBannedInGroup = "restore_all_banned_in_group"
	TaskTypeRestoreSpecificKeys     = "restore_specific_keys_in_group"
	TaskTypeUpdateStatusByFilter    = "update_status_by_filter"
)

// DTOs & Constants
const (
	TestEndpoint = "https://generativelanguage.googleapis.com/v1beta/models"
)

type BatchRestoreResult struct {
	RestoredCount int              `json:"restored_count"`
	SkippedCount  int              `json:"skipped_count"`
	SkippedKeys   []SkippedKeyInfo `json:"skipped_keys"`
}

type SkippedKeyInfo struct {
	KeyID  uint   `json:"key_id"`
	Reason string `json:"reason"`
}

type PaginatedAPIKeys struct {
	Items      []*models.APIKeyDetails `json:"items"`
	Total      int64                   `json:"total"`
	Page       int                     `json:"page"`
	PageSize   int                     `json:"page_size"`
	TotalPages int                     `json:"total_pages"`
}

type KeyTestResult struct {
	Key     string `json:"key"`
	Status  string `json:"status"`
	Message string `json:"message"`
}

type APIKeyService struct {
	db                *gorm.DB
	keyRepo           repository.KeyRepository
	channel           channel.ChannelProxy
	store             store.Store
	SettingsManager   *settings.SettingsManager
	taskService       task.Reporter
	logger            *logrus.Entry
	stopChan          chan struct{}
	validationService *KeyValidationService
	groupManager      *GroupManager
}

func NewAPIKeyService(
	db *gorm.DB,
	repo repository.KeyRepository,
	ch channel.ChannelProxy,
	s store.Store,
	sm *settings.SettingsManager,
	ts task.Reporter,
	vs *KeyValidationService,
	gm *GroupManager,
	logger *logrus.Logger,
) *APIKeyService {
	logger.Debugf("[FORENSIC PROBE | INJECTION | APIKeyService] Received a KeyRepository. Fingerprint: %p", repo)
	return &APIKeyService{
		db:                db,
		keyRepo:           repo,
		channel:           ch,
		store:             s,
		SettingsManager:   sm,
		taskService:       ts,
		logger:            logger.WithField("component", "APIKeyService🔑"),
		stopChan:          make(chan struct{}),
		validationService: vs,
		groupManager:      gm,
	}
}

func (s *APIKeyService) Start() {
	requestSub, err := s.store.Subscribe(models.TopicRequestFinished)
	if err != nil {
		s.logger.Fatalf("Failed to subscribe to topic %s: %v", models.TopicRequestFinished, err)
		return
	}
	masterKeySub, err := s.store.Subscribe(models.TopicMasterKeyStatusChanged)
	if err != nil {
		s.logger.Fatalf("Failed to subscribe to topic %s: %v", models.TopicMasterKeyStatusChanged, err)
		return
	}
	keyStatusSub, err := s.store.Subscribe(models.TopicKeyStatusChanged)
	if err != nil {
		s.logger.Fatalf("Failed to subscribe to topic %s: %v", models.TopicKeyStatusChanged, err)
		return
	}
	importSub, err := s.store.Subscribe(models.TopicImportGroupCompleted)
	if err != nil {
		s.logger.Fatalf("Failed to subscribe to topic %s: %v", models.TopicImportGroupCompleted, err)
		return
	}
	s.logger.Info("Started and subscribed to request, master key, health check, and import events.")

	go func() {
		defer requestSub.Close()
		defer masterKeySub.Close()
		defer keyStatusSub.Close()
		defer importSub.Close()
		for {
			select {
			case msg := <-requestSub.Channel():
				var event models.RequestFinishedEvent
				if err := json.Unmarshal(msg.Payload, &event); err != nil {
					s.logger.Errorf("Failed to unmarshal event for key status update: %v", err)
					continue
				}
				s.handleKeyUsageEvent(&event)

			case msg := <-masterKeySub.Channel():
				var event models.MasterKeyStatusChangedEvent
				if err := json.Unmarshal(msg.Payload, &event); err != nil {
					s.logger.Errorf("Failed to unmarshal MasterKeyStatusChangedEvent: %v", err)
					continue
				}
				s.handleMasterKeyStatusChangeEvent(&event)

			case msg := <-keyStatusSub.Channel():
				var event models.KeyStatusChangedEvent
				if err := json.Unmarshal(msg.Payload, &event); err != nil {
					s.logger.Errorf("Failed to unmarshal KeyStatusChangedEvent: %v", err)
					continue
				}
				s.handleKeyStatusChangeEvent(&event)

			case msg := <-importSub.Channel():
				var event models.ImportGroupCompletedEvent
				if err := json.Unmarshal(msg.Payload, &event); err != nil {
					s.logger.WithError(err).Error("Failed to unmarshal ImportGroupCompletedEvent.")
					continue
				}
				s.logger.Infof("Received ImportGroupCompletedEvent for group %d, triggering validation for %d keys.", event.GroupID, len(event.KeyIDs))

				go s.handlePostImportValidation(&event)

			case <-s.stopChan:
				s.logger.Info("Stopping event listener.")
				return
			}
		}
	}()
}

func (s *APIKeyService) Stop() {
	close(s.stopChan)
}
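One contrast worth noting: unlike AnalyticsService and DBLogWriterService elsewhere in this commit, Stop() here only closes the stop channel and does not wait for the listener goroutine to exit. The WaitGroup handshake those sibling services use looks like this in isolation (a reduced sketch, not project code):

package main

import (
	"fmt"
	"sync"
)

type worker struct {
	stopChan chan struct{}
	wg       sync.WaitGroup
}

func (w *worker) Start() {
	w.wg.Add(1)
	go func() {
		defer w.wg.Done()
		<-w.stopChan // stand-in for the select loop over subscriptions
	}()
}

// Stop signals the goroutine and blocks until it has finished, so callers
// can safely release shared resources afterwards.
func (w *worker) Stop() {
	close(w.stopChan)
	w.wg.Wait()
	fmt.Println("worker fully stopped")
}

func main() {
	w := &worker{stopChan: make(chan struct{})}
	w.Start()
	w.Stop()
}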
func (s *APIKeyService) handleKeyUsageEvent(event *models.RequestFinishedEvent) {
	if event.KeyID == 0 || event.GroupID == 0 {
		return
	}
	// Handle success case: key recovery and timestamp update.
	if event.IsSuccess {
		mapping, err := s.keyRepo.GetMapping(event.GroupID, event.KeyID)
		if err != nil {
			// Log if mapping is not found, but don't proceed.
			s.logger.Warnf("[%s] Could not find mapping for G:%d K:%d on successful request: %v", event.CorrelationID, event.GroupID, event.KeyID, err)
			return
		}

		oldStatus := mapping.Status

		// If status was not active, it's a recovery.
		if mapping.Status != models.StatusActive {
			mapping.Status = models.StatusActive
			mapping.ConsecutiveErrorCount = 0
			mapping.LastError = ""
		}
		// LastUsedAt is refreshed on every success, so the mapping is always written.
		now := time.Now()
		mapping.LastUsedAt = &now

		if err := s.keyRepo.UpdateMapping(mapping); err != nil {
			s.logger.Errorf("[%s] Failed to update mapping for G:%d K:%d after successful request: %v", event.CorrelationID, event.GroupID, event.KeyID, err)
		} else if oldStatus != models.StatusActive {
			// Only publish event if status actually changed.
			go s.publishStatusChangeEvent(event.GroupID, event.KeyID, oldStatus, models.StatusActive, "key_recovered_after_use")
		}
		return
	}
	// Handle failure case: delegate to the centralized judgment function.
	if event.Error != nil {
		s.judgeKeyErrors(
			event.CorrelationID,
			event.GroupID,
			event.KeyID,
			event.Error,
			event.IsPreciseRouting,
		)
	}
}

func (s *APIKeyService) handleKeyStatusChangeEvent(event *models.KeyStatusChangedEvent) {
	log := s.logger.WithFields(logrus.Fields{
		"group_id":   event.GroupID,
		"key_id":     event.KeyID,
		"new_status": event.NewStatus,
		"reason":     event.ChangeReason,
	})
	log.Info("Received KeyStatusChangedEvent, will update polling caches.")
	s.keyRepo.HandleCacheUpdateEvent(event.GroupID, event.KeyID, event.NewStatus)
	log.Info("Polling caches updated based on health check event.")
}

func (s *APIKeyService) publishStatusChangeEvent(groupID, keyID uint, oldStatus, newStatus models.APIKeyStatus, reason string) {
	changeEvent := models.KeyStatusChangedEvent{
		KeyID:        keyID,
		GroupID:      groupID,
		OldStatus:    oldStatus,
		NewStatus:    newStatus,
		ChangeReason: reason,
		ChangedAt:    time.Now(),
	}
	eventData, _ := json.Marshal(changeEvent)
	if err := s.store.Publish(models.TopicKeyStatusChanged, eventData); err != nil {
		s.logger.Errorf("Failed to publish key status changed event for group %d: %v", groupID, err)
	}
}

func (s *APIKeyService) ListAPIKeys(params *models.APIKeyQueryParams) (*PaginatedAPIKeys, error) {
	// --- Path 1: High-performance DB pagination (no keyword) ---
	if params.Keyword == "" {
		items, total, err := s.keyRepo.GetPaginatedKeysAndMappingsByGroup(params)
		if err != nil {
			return nil, err
		}
		totalPages := 0
		if total > 0 && params.PageSize > 0 {
			totalPages = int(math.Ceil(float64(total) / float64(params.PageSize)))
		}
		return &PaginatedAPIKeys{
			Items:      items,
			Total:      total,
			Page:       params.Page,
			PageSize:   params.PageSize,
			TotalPages: totalPages,
		}, nil
	}
	// --- Path 2: In-memory search (keyword present) ---
	s.logger.Infof("Performing heavy in-memory search for group %d with keyword '%s'", params.KeyGroupID, params.Keyword)
	// To get all keys, we fetch all IDs first, then get their full details.
	var statusesToFilter []string
	if params.Status != "" {
		statusesToFilter = append(statusesToFilter, params.Status)
	} else {
		statusesToFilter = append(statusesToFilter, "all") // "all" gets every status
	}
	allKeyIDs, err := s.keyRepo.FindKeyIDsByStatus(params.KeyGroupID, statusesToFilter)
	if err != nil {
		return nil, fmt.Errorf("failed to fetch all key IDs for search: %w", err)
	}
	if len(allKeyIDs) == 0 {
		return &PaginatedAPIKeys{Items: []*models.APIKeyDetails{}, Total: 0, Page: 1, PageSize: params.PageSize, TotalPages: 0}, nil
	}

	// This is the heavy operation: getting all keys and decrypting them.
	allKeys, err := s.keyRepo.GetKeysByIDs(allKeyIDs)
	if err != nil {
		return nil, fmt.Errorf("failed to fetch all key details for in-memory search: %w", err)
	}
	// We also need mappings to build the final `APIKeyDetails`.
	var allMappings []models.GroupAPIKeyMapping
	err = s.db.Where("key_group_id = ? AND api_key_id IN ?", params.KeyGroupID, allKeyIDs).Find(&allMappings).Error
	if err != nil {
		return nil, fmt.Errorf("failed to fetch mappings for in-memory search: %w", err)
	}
	mappingMap := make(map[uint]*models.GroupAPIKeyMapping)
	for i := range allMappings {
		mappingMap[allMappings[i].APIKeyID] = &allMappings[i]
	}
	// Filter the results in memory.
	var filteredItems []*models.APIKeyDetails
	for _, key := range allKeys {
		if strings.Contains(key.APIKey, params.Keyword) {
			if mapping, ok := mappingMap[key.ID]; ok {
				filteredItems = append(filteredItems, &models.APIKeyDetails{
					ID:                    key.ID,
					CreatedAt:             key.CreatedAt,
					UpdatedAt:             key.UpdatedAt,
					APIKey:                key.APIKey,
					MasterStatus:          key.MasterStatus,
					Status:                mapping.Status,
					LastError:             mapping.LastError,
					ConsecutiveErrorCount: mapping.ConsecutiveErrorCount,
					LastUsedAt:            mapping.LastUsedAt,
					CooldownUntil:         mapping.CooldownUntil,
				})
			}
		}
	}
	// Sort the filtered results to ensure consistent pagination (by ID descending).
	sort.Slice(filteredItems, func(i, j int) bool {
		return filteredItems[i].ID > filteredItems[j].ID
	})
	// Manually paginate the filtered results.
	total := int64(len(filteredItems))
	start := (params.Page - 1) * params.PageSize
	end := start + params.PageSize
	if start < 0 {
		start = 0
	}
	if start >= len(filteredItems) {
		return &PaginatedAPIKeys{
			Items:      []*models.APIKeyDetails{},
			Total:      total,
			Page:       params.Page,
			PageSize:   params.PageSize,
			TotalPages: int(math.Ceil(float64(total) / float64(params.PageSize))),
		}, nil
	}
	if end > len(filteredItems) {
		end = len(filteredItems)
	}
	paginatedItems := filteredItems[start:end]
	return &PaginatedAPIKeys{
		Items:      paginatedItems,
		Total:      total,
		Page:       params.Page,
		PageSize:   params.PageSize,
		TotalPages: int(math.Ceil(float64(total) / float64(params.PageSize))),
	}, nil
}
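The slice-window arithmetic at the end of ListAPIKeys is worth seeing in isolation. A standalone sketch with 1-based page numbers:

package main

import "fmt"

// pageWindow returns the half-open [start, end) window for one page.
func pageWindow(total, page, pageSize int) (start, end int) {
	start = (page - 1) * pageSize
	end = start + pageSize
	if start < 0 {
		start = 0
	}
	if start >= total {
		return total, total // empty window past the last page
	}
	if end > total {
		end = total
	}
	return start, end
}

func main() {
	items := []string{"a", "b", "c", "d", "e"}
	s, e := pageWindow(len(items), 2, 2)
	fmt.Println(items[s:e]) // [c d]
}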
func (s *APIKeyService) UpdateAPIKey(key *models.APIKey) error {
	go func() {
		var oldKey models.APIKey
		if err := s.db.First(&oldKey, key.ID).Error; err != nil {
			s.logger.Errorf("Failed to find old key state for ID %d before update: %v", key.ID, err)
			return
		}
		if err := s.keyRepo.Update(key); err != nil {
			s.logger.Errorf("Failed to asynchronously update key ID %d: %v", key.ID, err)
			return
		}
	}()
	return nil
}

func (s *APIKeyService) HardDeleteAPIKeyByID(id uint) error {
	// Get all associated groups before deletion to publish correct events
	groups, err := s.keyRepo.GetGroupsForKey(id)
	if err != nil {
		s.logger.Warnf("Could not get groups for key ID %d before hard deletion (key might be orphaned): %v", id, err)
	}

	err = s.keyRepo.HardDeleteByID(id)
	if err == nil {
		// Publish events for each group the key was a part of
		for _, groupID := range groups {
			event := models.KeyStatusChangedEvent{
				KeyID:        id,
				GroupID:      groupID,
				ChangeReason: "key_hard_deleted",
			}
			eventData, _ := json.Marshal(event)
			go s.store.Publish(models.TopicKeyStatusChanged, eventData)
		}
	}
	return err
}

func (s *APIKeyService) UpdateMappingStatus(groupID, keyID uint, newStatus models.APIKeyStatus) (*models.GroupAPIKeyMapping, error) {
	key, err := s.keyRepo.GetKeyByID(keyID)
	if err != nil {
		return nil, CustomErrors.ParseDBError(err)
	}

	if key.MasterStatus != models.MasterStatusActive && newStatus == models.StatusActive {
		return nil, CustomErrors.ErrStateConflictMasterRevoked
	}
	mapping, err := s.keyRepo.GetMapping(groupID, keyID)
	if err != nil {
		return nil, err
	}
	oldStatus := mapping.Status
	if oldStatus == newStatus {
		return mapping, nil
	}
	mapping.Status = newStatus
	if newStatus == models.StatusActive {
		mapping.ConsecutiveErrorCount = 0
		mapping.LastError = ""
	}
	if err := s.keyRepo.UpdateMapping(mapping); err != nil {
		return nil, err
	}
	go s.publishStatusChangeEvent(groupID, keyID, oldStatus, newStatus, "manual_update")
	return mapping, nil
}

func (s *APIKeyService) handleMasterKeyStatusChangeEvent(event *models.MasterKeyStatusChangedEvent) {
	s.logger.Infof("Received MasterKeyStatusChangedEvent for Key ID %d: %s -> %s", event.KeyID, event.OldMasterStatus, event.NewMasterStatus)
	if event.NewMasterStatus != models.MasterStatusRevoked {
		return
	}
	affectedGroupIDs, err := s.keyRepo.GetGroupsForKey(event.KeyID)
	if err != nil {
		s.logger.WithError(err).Errorf("Failed to get groups for key ID %d to propagate revoked status.", event.KeyID)
		return
	}
	if len(affectedGroupIDs) == 0 {
		s.logger.Infof("Key ID %d is revoked, but it's not associated with any group. No action needed.", event.KeyID)
		return
	}
	s.logger.Infof("Propagating REVOKED status for Key ID %d to %d groups.", event.KeyID, len(affectedGroupIDs))
	for _, groupID := range affectedGroupIDs {
		_, err := s.UpdateMappingStatus(groupID, event.KeyID, models.StatusBanned)
		if err != nil {
			if !errors.Is(err, gorm.ErrRecordNotFound) {
				s.logger.WithError(err).Errorf("Failed to update mapping status to BANNED for Key ID %d in Group ID %d.", event.KeyID, groupID)
			}
		}
	}
}

func (s *APIKeyService) StartRestoreKeysTask(groupID uint, keyIDs []uint) (*task.Status, error) {
	if len(keyIDs) == 0 {
		return nil, CustomErrors.NewAPIError(CustomErrors.ErrBadRequest, "No key IDs provided for restoration.")
	}
	resourceID := fmt.Sprintf("group-%d", groupID)
	taskStatus, err := s.taskService.StartTask(groupID, TaskTypeRestoreSpecificKeys, resourceID, len(keyIDs), time.Hour)
	if err != nil {
		return nil, err
	}
	go s.runRestoreKeysTask(taskStatus.ID, resourceID, groupID, keyIDs)
	return taskStatus, nil
}

func (s *APIKeyService) runRestoreKeysTask(taskID string, resourceID string, groupID uint, keyIDs []uint) {
	defer func() {
		if r := recover(); r != nil {
			s.logger.Errorf("Panic recovered in runRestoreKeysTask: %v", r)
			s.taskService.EndTaskByID(taskID, resourceID, nil, fmt.Errorf("panic in restore keys task: %v", r))
		}
	}()
	var mappingsToProcess []models.GroupAPIKeyMapping
	err := s.db.Preload("APIKey").
		Where("key_group_id = ? AND api_key_id IN ?", groupID, keyIDs).
		Find(&mappingsToProcess).Error
	if err != nil {
		s.taskService.EndTaskByID(taskID, resourceID, nil, err)
		return
	}
	result := &BatchRestoreResult{
		SkippedKeys: make([]SkippedKeyInfo, 0),
	}
	var successfulMappings []*models.GroupAPIKeyMapping
	processedCount := 0
	// Index the slice so the pointers collected below stay stable across
	// iterations (ranging by value and taking &mapping would alias the loop
	// variable on Go versions before 1.22).
	for i := range mappingsToProcess {
		mapping := &mappingsToProcess[i]
		processedCount++
		_ = s.taskService.UpdateProgressByID(taskID, processedCount)
		if mapping.APIKey == nil {
			result.SkippedCount++
			result.SkippedKeys = append(result.SkippedKeys, SkippedKeyInfo{KeyID: mapping.APIKeyID, Reason: "Associated APIKey entity not found."})
			continue
		}
		if mapping.APIKey.MasterStatus != models.MasterStatusActive {
			result.SkippedCount++
			result.SkippedKeys = append(result.SkippedKeys, SkippedKeyInfo{KeyID: mapping.APIKeyID, Reason: fmt.Sprintf("Master status is '%s'.", mapping.APIKey.MasterStatus)})
			continue
		}
		oldStatus := mapping.Status
		if oldStatus != models.StatusActive {
			mapping.Status = models.StatusActive
			mapping.ConsecutiveErrorCount = 0
			mapping.LastError = ""
			// Use the version that doesn't trigger individual cache updates.
			if err := s.keyRepo.UpdateMappingWithoutCache(mapping); err != nil {
				s.logger.WithError(err).Warn("Failed to update mapping in DB during restore task.")
				result.SkippedCount++
				result.SkippedKeys = append(result.SkippedKeys, SkippedKeyInfo{KeyID: mapping.APIKeyID, Reason: "DB update failed."})
			} else {
				result.RestoredCount++
				successfulMappings = append(successfulMappings, mapping) // Collect for batch cache update.
				go s.publishStatusChangeEvent(groupID, mapping.APIKeyID, oldStatus, models.StatusActive, "batch_restore")
			}
		} else {
			result.RestoredCount++ // Already active, count as success.
		}
	}
	// After the loop, perform one single, efficient cache update.
	if err := s.keyRepo.HandleCacheUpdateEventBatch(successfulMappings); err != nil {
		s.logger.WithError(err).Error("Failed to perform batch cache update after restore task.")
		// This is not a task-fatal error, so we just log it and continue.
	}
	// Account for keys that were requested but not found in the initial DB query.
	result.SkippedCount += (len(keyIDs) - len(mappingsToProcess))
	s.taskService.EndTaskByID(taskID, resourceID, result, nil)
}

func (s *APIKeyService) StartRestoreAllBannedTask(groupID uint) (*task.Status, error) {
	var bannedKeyIDs []uint
	err := s.db.Model(&models.GroupAPIKeyMapping{}).
		Where("key_group_id = ? AND status = ?", groupID, models.StatusBanned).
		Pluck("api_key_id", &bannedKeyIDs).Error
	if err != nil {
		return nil, CustomErrors.ParseDBError(err)
	}
	if len(bannedKeyIDs) == 0 {
		return nil, CustomErrors.NewAPIError(CustomErrors.ErrBadRequest, "No banned keys to restore in this group.")
	}
	return s.StartRestoreKeysTask(groupID, bannedKeyIDs)
}

func (s *APIKeyService) handlePostImportValidation(event *models.ImportGroupCompletedEvent) {
	group, ok := s.groupManager.GetGroupByID(event.GroupID)
	if !ok {
		s.logger.Errorf("Group with id %d not found during post-import validation, aborting.", event.GroupID)
		return
	}
	opConfig, err := s.groupManager.BuildOperationalConfig(group)
	if err != nil {
		s.logger.Errorf("Failed to build operational config for group %d, aborting validation: %v", event.GroupID, err)
		return
	}
	endpoint, err := s.groupManager.BuildKeyCheckEndpoint(event.GroupID)
	if err != nil {
		s.logger.Errorf("Failed to build key check endpoint for group %d, aborting validation: %v", event.GroupID, err)
		return
	}
	globalSettings := s.SettingsManager.GetSettings()
	concurrency := globalSettings.BaseKeyCheckConcurrency
	if opConfig.KeyCheckConcurrency != nil {
		concurrency = *opConfig.KeyCheckConcurrency
	}
	if concurrency <= 0 {
		concurrency = 10 // Safety fallback
	}
	timeout := time.Duration(globalSettings.KeyCheckTimeoutSeconds) * time.Second
	keysToValidate, err := s.keyRepo.GetKeysByIDs(event.KeyIDs)
	if err != nil {
		s.logger.Errorf("Failed to get key models for validation in group %d: %v", event.GroupID, err)
		return
	}
	s.logger.Infof("Validating %d keys for group %d with concurrency %d against endpoint %s", len(keysToValidate), event.GroupID, concurrency, endpoint)
	var wg sync.WaitGroup
	jobs := make(chan models.APIKey, len(keysToValidate))
	for i := 0; i < concurrency; i++ {
		wg.Add(1)
		go func() {
			defer wg.Done()
			for key := range jobs {
				validationErr := s.validationService.ValidateSingleKey(&key, timeout, endpoint)
				if validationErr == nil {
					s.logger.Infof("Key ID %d PASSED validation. Status -> ACTIVE.", key.ID)
					if _, err := s.UpdateMappingStatus(event.GroupID, key.ID, models.StatusActive); err != nil {
						s.logger.Errorf("Failed to update status to ACTIVE for Key ID %d in group %d: %v", key.ID, event.GroupID, err)
					}
				} else {
					var apiErr *CustomErrors.APIError
					if !CustomErrors.As(validationErr, &apiErr) {
						apiErr = &CustomErrors.APIError{Message: validationErr.Error()}
					}
					s.judgeKeyErrors("", event.GroupID, key.ID, apiErr, false)
				}
			}
		}()
	}
	for _, key := range keysToValidate {
		jobs <- key
	}
	close(jobs)
	wg.Wait()
	s.logger.Infof("Finished post-import validation for group %d.", event.GroupID)
}

// [NEW] StartUpdateStatusByFilterTask starts a background job to update the status of keys
// that match a specific set of source statuses within a group.
func (s *APIKeyService) StartUpdateStatusByFilterTask(groupID uint, sourceStatuses []string, newStatus models.APIKeyStatus) (*task.Status, error) {
	s.logger.Infof("Starting task to update status to '%s' for keys in group %d matching statuses: %v", newStatus, groupID, sourceStatuses)

	// 1. Find key IDs using the new repository method. Using IDs is more efficient for updates.
	keyIDs, err := s.keyRepo.FindKeyIDsByStatus(groupID, sourceStatuses)
	if err != nil {
		return nil, CustomErrors.ParseDBError(err)
	}
	if len(keyIDs) == 0 {
		now := time.Now()
		return &task.Status{
			IsRunning: false, // The "task" is not running.
			Processed: 0,
			Total:     0,
			Result: map[string]string{ // We use the flexible Result field to pass the message.
				"message": "No keys matching the current filter criteria were found.",
			},
			Error:      "", // There is no error.
			StartedAt:  now,
			FinishedAt: &now, // It started and finished at the same time.
		}, nil // Return nil for the error, signaling a 200 OK.
	}
	// 2. Start a new task using the TaskService, following existing patterns.
	resourceID := fmt.Sprintf("group-%d-status-update", groupID)
	taskStatus, err := s.taskService.StartTask(groupID, TaskTypeUpdateStatusByFilter, resourceID, len(keyIDs), 30*time.Minute)
	if err != nil {
		return nil, err // Pass up errors like "task already in progress".
	}

	// 3. Run the core logic in a separate goroutine.
	go s.runUpdateStatusByFilterTask(taskStatus.ID, resourceID, groupID, keyIDs, newStatus)
	return taskStatus, nil
}

// [NEW] runUpdateStatusByFilterTask is the private worker function for the above task.
func (s *APIKeyService) runUpdateStatusByFilterTask(taskID, resourceID string, groupID uint, keyIDs []uint, newStatus models.APIKeyStatus) {
	defer func() {
		if r := recover(); r != nil {
			s.logger.Errorf("Panic recovered in runUpdateStatusByFilterTask: %v", r)
			s.taskService.EndTaskByID(taskID, resourceID, nil, fmt.Errorf("panic in update status task: %v", r))
		}
	}()
	type BatchUpdateResult struct {
		UpdatedCount int `json:"updated_count"`
		SkippedCount int `json:"skipped_count"`
	}
	result := &BatchUpdateResult{}
	var successfulMappings []*models.GroupAPIKeyMapping
	// 1. Fetch all key master statuses in one go. This is efficient.
	keys, err := s.keyRepo.GetKeysByIDs(keyIDs)
	if err != nil {
		s.logger.WithError(err).Error("Failed to fetch keys for bulk status update, aborting task.")
		s.taskService.EndTaskByID(taskID, resourceID, nil, err)
		return
	}
	masterStatusMap := make(map[uint]models.MasterAPIKeyStatus)
	for _, key := range keys {
		masterStatusMap[key.ID] = key.MasterStatus
	}
	// 2. [THE REFINEMENT] Fetch all relevant mappings directly using s.db,
	// avoiding the need for a new repository method. This pattern is
	// already used in other parts of this service.
	var mappings []*models.GroupAPIKeyMapping
	if err := s.db.Where("key_group_id = ? AND api_key_id IN ?", groupID, keyIDs).Find(&mappings).Error; err != nil {
		s.logger.WithError(err).Error("Failed to fetch mappings for bulk status update, aborting task.")
		s.taskService.EndTaskByID(taskID, resourceID, nil, err)
		return
	}
	processedCount := 0
	for _, mapping := range mappings {
		processedCount++
		// The progress update should reflect the number of items *being processed*, not the final count.
		_ = s.taskService.UpdateProgressByID(taskID, processedCount)
		masterStatus, ok := masterStatusMap[mapping.APIKeyID]
		if !ok {
			result.SkippedCount++
			continue
		}
		if masterStatus != models.MasterStatusActive && newStatus == models.StatusActive {
			result.SkippedCount++
			continue
		}
		oldStatus := mapping.Status
		if oldStatus != newStatus {
			mapping.Status = newStatus
			if newStatus == models.StatusActive {
				mapping.ConsecutiveErrorCount = 0
				mapping.LastError = ""
			}
			if err := s.keyRepo.UpdateMappingWithoutCache(mapping); err != nil {
				result.SkippedCount++
			} else {
				result.UpdatedCount++
				successfulMappings = append(successfulMappings, mapping)
				go s.publishStatusChangeEvent(groupID, mapping.APIKeyID, oldStatus, newStatus, "batch_status_update")
			}
		} else {
			result.UpdatedCount++ // Already in desired state, count as success.
		}
	}
	result.SkippedCount += (len(keyIDs) - len(mappings))
	if err := s.keyRepo.HandleCacheUpdateEventBatch(successfulMappings); err != nil {
		s.logger.WithError(err).Error("Failed to perform batch cache update after status update task.")
	}
	s.logger.Infof("Finished bulk status update task '%s' for group %d. Updated: %d, Skipped: %d.", taskID, groupID, result.UpdatedCount, result.SkippedCount)
	s.taskService.EndTaskByID(taskID, resourceID, result, nil)
}

func (s *APIKeyService) HandleRequestResult(group *models.KeyGroup, key *models.APIKey, success bool, apiErr *CustomErrors.APIError) {
	if success {
		if group.PollingStrategy == models.StrategyWeighted {
			go s.keyRepo.UpdateKeyUsageTimestamp(group.ID, key.ID)
		}
		return
	}
	if apiErr == nil {
		s.logger.Warnf("Request failed for KeyID %d in GroupID %d but no specific API error was provided. No action taken.", key.ID, group.ID)
		return
	}
	errMsg := apiErr.Message
	if CustomErrors.IsPermanentUpstreamError(errMsg) || CustomErrors.IsTemporaryUpstreamError(errMsg) {
		s.logger.Warnf("Request for KeyID %d in GroupID %d failed with a key-specific error: %s. Issuing order to move to cooldown.", key.ID, group.ID, errMsg)
		go s.keyRepo.SyncKeyStatusInPollingCaches(group.ID, key.ID, models.StatusCooldown)
	} else {
		s.logger.Infof("Request for KeyID %d in GroupID %d failed with a non-key-specific error. No punitive action taken. Error: %s", key.ID, group.ID, errMsg)
	}
}

// sanitizeForLog truncates long error messages and removes JSON-like content for cleaner Info/Error logs.
func sanitizeForLog(errMsg string) string {
	// Find the start of any potential JSON blob or detailed structure.
	jsonStartIndex := strings.Index(errMsg, "{")
	var cleanMsg string
	if jsonStartIndex != -1 {
		// If a '{' is found, take everything before it as the summary
		// and append a simple placeholder.
		cleanMsg = strings.TrimSpace(errMsg[:jsonStartIndex]) + " {...}"
	} else {
		// If no JSON-like structure is found, use the original message.
		cleanMsg = errMsg
	}
	// Always apply a final length truncation as a safeguard against extremely long non-JSON errors.
	const maxLen = 250
	if len(cleanMsg) > maxLen {
		return cleanMsg[:maxLen] + "..."
	}
	return cleanMsg
}
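sanitizeForLog's two behaviors can be pinned down with a throwaway test in the same package (illustrative only):

package service

import (
	"strings"
	"testing"
)

func TestSanitizeForLog(t *testing.T) {
	// JSON blobs are collapsed to a "{...}" placeholder.
	got := sanitizeForLog(`upstream 429 {"error":{"code":429}}`)
	if got != "upstream 429 {...}" {
		t.Fatalf("unexpected: %q", got)
	}
	// Very long non-JSON messages are truncated to 250 chars plus "...".
	long := strings.Repeat("x", 300)
	if got := sanitizeForLog(long); len(got) != 253 {
		t.Fatalf("unexpected length: %d", len(got))
	}
}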
func (s *APIKeyService) judgeKeyErrors(
	correlationID string,
	groupID, keyID uint,
	apiErr *CustomErrors.APIError,
	isPreciseRouting bool,
) {
	logger := s.logger.WithFields(logrus.Fields{"group_id": groupID, "key_id": keyID, "correlation_id": correlationID})
	mapping, err := s.keyRepo.GetMapping(groupID, keyID)
	if err != nil {
		logger.WithError(err).Warn("Cannot apply consequences for error: mapping not found.")
		return
	}
	now := time.Now()
	mapping.LastUsedAt = &now
	errorMessage := apiErr.Message
	if CustomErrors.IsPermanentUpstreamError(errorMessage) {
		logger.Errorf("Permanent error detected. Banning mapping and revoking master key. Reason: %s", sanitizeForLog(errorMessage))
		logger.WithField("full_error_details", errorMessage).Debug("Full details of the permanent error.")
		if mapping.Status != models.StatusBanned {
			oldStatus := mapping.Status
			mapping.Status = models.StatusBanned
			mapping.LastError = errorMessage
			if err := s.keyRepo.UpdateMapping(mapping); err != nil {
				logger.WithError(err).Error("Failed to update mapping status to BANNED.")
			} else {
				go s.publishStatusChangeEvent(groupID, keyID, oldStatus, mapping.Status, "permanent_error_banned")
				go s.revokeMasterKey(keyID, "permanent_upstream_error")
			}
		}
		return
	}
	if CustomErrors.IsTemporaryUpstreamError(errorMessage) {
		mapping.LastError = errorMessage
		mapping.ConsecutiveErrorCount++
		// Resolve the blacklist threshold: the per-group config in Precise
		// Routing mode when available, otherwise the global setting. The group
		// lookup is checked before building the config, and the per-group
		// threshold pointer is guarded before dereferencing.
		threshold := s.SettingsManager.GetSettings().BlacklistThreshold
		if isPreciseRouting {
			resolved := false
			if group, ok := s.groupManager.GetGroupByID(groupID); ok {
				if opConfig, cfgErr := s.groupManager.BuildOperationalConfig(group); cfgErr == nil && opConfig.KeyBlacklistThreshold != nil {
					threshold = *opConfig.KeyBlacklistThreshold
					resolved = true
				}
			}
			if !resolved {
				logger.Warnf("Could not build operational config for group %d in Precise Routing mode. Falling back to global settings.", groupID)
			}
		}
		logger.Warnf("Temporary error detected. Incrementing error count. New count: %d (Threshold: %d). Reason: %s", mapping.ConsecutiveErrorCount, threshold, sanitizeForLog(errorMessage))
		logger.WithField("full_error_details", errorMessage).Debug("Full details of the temporary error.")
		oldStatus := mapping.Status
		newStatus := oldStatus
		if mapping.ConsecutiveErrorCount >= threshold && oldStatus == models.StatusActive {
			newStatus = models.StatusCooldown
			logger.Errorf("Putting mapping into COOLDOWN due to reaching temporary error threshold (%d)", threshold)
		}
		if oldStatus != newStatus {
			mapping.Status = newStatus
		}
		if err := s.keyRepo.UpdateMapping(mapping); err != nil {
			logger.WithError(err).Error("Failed to update mapping after temporary error.")
			return
		}
		if oldStatus != newStatus {
			go s.publishStatusChangeEvent(groupID, keyID, oldStatus, newStatus, "error_threshold_reached")
		}
		return
	}
	logger.Infof("Ignored truly ignorable upstream error. Only updating LastUsedAt. Reason: %s", sanitizeForLog(errorMessage))
	logger.WithField("full_error_details", errorMessage).Debug("Full details of the ignorable error.")
	if err := s.keyRepo.UpdateMapping(mapping); err != nil {
		logger.WithError(err).Error("Failed to update LastUsedAt for ignorable error.")
	}
}
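The cooldown decision above reduces to a small state function: a mapping moves from ACTIVE to COOLDOWN only once its consecutive temporary-error count reaches the threshold. A toy model with string statuses (not project code):

package main

import "fmt"

func nextStatus(current string, consecutiveErrors, threshold int) string {
	if consecutiveErrors >= threshold && current == "ACTIVE" {
		return "COOLDOWN"
	}
	return current
}

func main() {
	threshold := 3
	status := "ACTIVE"
	for errs := 1; errs <= 4; errs++ {
		status = nextStatus(status, errs, threshold)
		fmt.Printf("errors=%d status=%s\n", errs, status)
	}
	// errors=1 ACTIVE, errors=2 ACTIVE, errors=3 COOLDOWN, errors=4 COOLDOWN
}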
func (s *APIKeyService) revokeMasterKey(keyID uint, reason string) {
	key, err := s.keyRepo.GetKeyByID(keyID)
	if err != nil {
		if errors.Is(err, gorm.ErrRecordNotFound) {
			s.logger.Warnf("Attempted to revoke non-existent key ID %d.", keyID)
		} else {
			s.logger.Errorf("Failed to get key by ID %d for master status revocation: %v", keyID, err)
		}
		return
	}
	if key.MasterStatus == models.MasterStatusRevoked {
		return
	}
	oldMasterStatus := key.MasterStatus
	newMasterStatus := models.MasterStatusRevoked
	if err := s.keyRepo.UpdateMasterStatusByID(keyID, newMasterStatus); err != nil {
		s.logger.Errorf("Failed to update master status to REVOKED for key ID %d: %v", keyID, err)
		return
	}
	masterKeyEvent := models.MasterKeyStatusChangedEvent{
		KeyID:           keyID,
		OldMasterStatus: oldMasterStatus,
		NewMasterStatus: newMasterStatus,
		ChangeReason:    reason,
		ChangedAt:       time.Now(),
	}
	eventData, _ := json.Marshal(masterKeyEvent)
	_ = s.store.Publish(models.TopicMasterKeyStatusChanged, eventData)
}

func (s *APIKeyService) GetAPIKeyStringsForExport(groupID uint, statuses []string) ([]string, error) {
	return s.keyRepo.GetKeyStringsByGroupAndStatus(groupID, statuses)
}
315
internal/service/dashboard_query_service.go
Normal file
@@ -0,0 +1,315 @@
// Filename: internal/service/dashboard_query_service.go

package service

import (
	"fmt"
	"gemini-balancer/internal/models"
	"gemini-balancer/internal/store"
	"gemini-balancer/internal/syncer"
	"strconv"
	"time"

	"github.com/gin-gonic/gin"
	"github.com/sirupsen/logrus"
	"gorm.io/gorm"
)

const overviewCacheChannel = "syncer:cache:dashboard_overview"

// DashboardQueryService handles all dashboard data queries for the frontend.
type DashboardQueryService struct {
	db             *gorm.DB
	store          store.Store
	overviewSyncer *syncer.CacheSyncer[*models.DashboardStatsResponse]
	logger         *logrus.Entry
	stopChan       chan struct{}
}

func NewDashboardQueryService(db *gorm.DB, s store.Store, logger *logrus.Logger) (*DashboardQueryService, error) {
	qs := &DashboardQueryService{
		db:       db,
		store:    s,
		logger:   logger.WithField("component", "DashboardQueryService"),
		stopChan: make(chan struct{}),
	}

	loader := qs.loadOverviewData
	overviewSyncer, err := syncer.NewCacheSyncer(loader, s, overviewCacheChannel)
	if err != nil {
		return nil, fmt.Errorf("failed to create overview cache syncer: %w", err)
	}
	qs.overviewSyncer = overviewSyncer
	return qs, nil
}
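The syncer.CacheSyncer used here is project-internal. As a hedged approximation of its loader/Get/Invalidate contract, a minimal generic read-through cache looks like this; the real syncer additionally fans invalidations out over the store's pub/sub channel so every instance reloads:

package main

import (
	"fmt"
	"sync"
)

type CacheSyncer[T any] struct {
	mu     sync.RWMutex
	value  T
	loader func() (T, error)
}

func NewCacheSyncer[T any](loader func() (T, error)) (*CacheSyncer[T], error) {
	c := &CacheSyncer[T]{loader: loader}
	return c, c.Invalidate() // populate eagerly
}

// Get serves reads from memory without touching the loader.
func (c *CacheSyncer[T]) Get() T {
	c.mu.RLock()
	defer c.mu.RUnlock()
	return c.value
}

// Invalidate reloads synchronously via the loader and swaps the value in.
func (c *CacheSyncer[T]) Invalidate() error {
	v, err := c.loader()
	if err != nil {
		return err
	}
	c.mu.Lock()
	c.value = v
	c.mu.Unlock()
	return nil
}

func main() {
	calls := 0
	c, _ := NewCacheSyncer(func() (int, error) { calls++; return calls, nil })
	fmt.Println(c.Get()) // 1 (loaded once at construction)
	_ = c.Invalidate()
	fmt.Println(c.Get()) // 2 (reloaded)
}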
func (s *DashboardQueryService) Start() {
	go s.eventListener()
	s.logger.Info("DashboardQueryService started and listening for invalidation events.")
}

func (s *DashboardQueryService) Stop() {
	close(s.stopChan)
	s.logger.Info("DashboardQueryService and its CacheSyncer have been stopped.")
}

func (s *DashboardQueryService) GetGroupStats(groupID uint) (map[string]any, error) {
	statsKey := fmt.Sprintf("stats:group:%d", groupID)
	keyStatsMap, err := s.store.HGetAll(statsKey)
	if err != nil {
		s.logger.WithError(err).Errorf("Failed to get key stats from cache for group %d", groupID)
		return nil, fmt.Errorf("failed to get key stats from cache: %w", err)
	}
	keyStats := make(map[string]int64)
	for k, v := range keyStatsMap {
		val, _ := strconv.ParseInt(v, 10, 64)
		keyStats[k] = val
	}
	now := time.Now()
	oneHourAgo := now.Add(-1 * time.Hour)
	twentyFourHoursAgo := now.Add(-24 * time.Hour)
	type requestStatsResult struct {
		TotalRequests   int64
		SuccessRequests int64
	}
	var last1Hour, last24Hours requestStatsResult
	s.db.Model(&models.StatsHourly{}).
		Select("SUM(request_count) as total_requests, SUM(success_count) as success_requests").
		Where("group_id = ? AND time >= ?", groupID, oneHourAgo).
		Scan(&last1Hour)
	s.db.Model(&models.StatsHourly{}).
		Select("SUM(request_count) as total_requests, SUM(success_count) as success_requests").
		Where("group_id = ? AND time >= ?", groupID, twentyFourHoursAgo).
		Scan(&last24Hours)
	failureRate1h := 0.0
	if last1Hour.TotalRequests > 0 {
		failureRate1h = float64(last1Hour.TotalRequests-last1Hour.SuccessRequests) / float64(last1Hour.TotalRequests) * 100
	}
	failureRate24h := 0.0
	if last24Hours.TotalRequests > 0 {
		failureRate24h = float64(last24Hours.TotalRequests-last24Hours.SuccessRequests) / float64(last24Hours.TotalRequests) * 100
	}
	last1HourStats := map[string]any{
		"total_requests":   last1Hour.TotalRequests,
		"success_requests": last1Hour.SuccessRequests,
		"failure_rate":     failureRate1h,
	}
	last24HoursStats := map[string]any{
		"total_requests":   last24Hours.TotalRequests,
		"success_requests": last24Hours.SuccessRequests,
		"failure_rate":     failureRate24h,
	}
	result := map[string]any{
		"key_stats":     keyStats,
		"last_1_hour":   last1HourStats,
		"last_24_hours": last24HoursStats,
	}
	return result, nil
}

func (s *DashboardQueryService) eventListener() {
	// Check subscription errors here, as the other services do; a failed
	// subscribe would otherwise leave a nil subscription and panic on Close.
	keyStatusSub, err := s.store.Subscribe(models.TopicKeyStatusChanged)
	if err != nil {
		s.logger.Fatalf("Failed to subscribe to topic %s: %v", models.TopicKeyStatusChanged, err)
		return
	}
	upstreamStatusSub, err := s.store.Subscribe(models.TopicUpstreamHealthChanged)
	if err != nil {
		s.logger.Fatalf("Failed to subscribe to topic %s: %v", models.TopicUpstreamHealthChanged, err)
		return
	}
	defer keyStatusSub.Close()
	defer upstreamStatusSub.Close()
	for {
		select {
		case <-keyStatusSub.Channel():
			s.logger.Info("Received key status changed event, invalidating overview cache...")
			_ = s.InvalidateOverviewCache()
		case <-upstreamStatusSub.Channel():
			s.logger.Info("Received upstream status changed event, invalidating overview cache...")
			_ = s.InvalidateOverviewCache()
		case <-s.stopChan:
			s.logger.Info("Stopping dashboard event listener.")
			return
		}
	}
}

// GetDashboardOverviewData serves the dashboard overview from the Syncer cache for fast reads.
func (s *DashboardQueryService) GetDashboardOverviewData() (*models.DashboardStatsResponse, error) {
	cachedDataPtr := s.overviewSyncer.Get()
	if cachedDataPtr == nil {
		return &models.DashboardStatsResponse{}, fmt.Errorf("overview cache is not available or still syncing")
	}
	return cachedDataPtr, nil
}

func (s *DashboardQueryService) InvalidateOverviewCache() error {
	return s.overviewSyncer.Invalidate()
}

// QueryHistoricalChart queries historical chart data.
func (s *DashboardQueryService) QueryHistoricalChart(groupID *uint) (*models.ChartData, error) {
	type ChartPoint struct {
		TimeLabel     string `gorm:"column:time_label"`
		ModelName     string `gorm:"column:model_name"`
		TotalRequests int64  `gorm:"column:total_requests"`
	}
	sevenDaysAgo := time.Now().Add(-24 * 7 * time.Hour).Truncate(time.Hour)
	sqlFormat, goFormat := s.buildTimeFormatSelectClause()
	selectClause := fmt.Sprintf("%s as time_label, model_name, SUM(request_count) as total_requests", sqlFormat)
	query := s.db.Model(&models.StatsHourly{}).Select(selectClause).Where("time >= ?", sevenDaysAgo).Group("time_label, model_name").Order("time_label ASC")
	if groupID != nil && *groupID > 0 {
		query = query.Where("group_id = ?", *groupID)
	}
	var points []ChartPoint
	if err := query.Find(&points).Error; err != nil {
		return nil, err
	}
	datasets := make(map[string]map[string]int64)
	for _, p := range points {
		if _, ok := datasets[p.ModelName]; !ok {
			datasets[p.ModelName] = make(map[string]int64)
		}
		datasets[p.ModelName][p.TimeLabel] = p.TotalRequests
	}
	var labels []string
	for t := sevenDaysAgo; t.Before(time.Now()); t = t.Add(time.Hour) {
		labels = append(labels, t.Format(goFormat))
	}
	chartData := &models.ChartData{Labels: labels, Datasets: make([]models.ChartDataset, 0)}
	colorPalette := []string{"#FF6384", "#36A2EB", "#FFCE56", "#4BC0C0", "#9966FF", "#FF9F40"}
	colorIndex := 0
	for modelName, dataPoints := range datasets {
		dataArray := make([]int64, len(labels))
		for i, label := range labels {
			dataArray[i] = dataPoints[label]
		}
		chartData.Datasets = append(chartData.Datasets, models.ChartDataset{
			Label: modelName,
			Data:  dataArray,
			Color: colorPalette[colorIndex%len(colorPalette)],
		})
		colorIndex++
	}
	return chartData, nil
}

func (s *DashboardQueryService) loadOverviewData() (*models.DashboardStatsResponse, error) {
	s.logger.Info("[CacheSyncer] Starting to load overview data from database...")
	startTime := time.Now()
	resp := &models.DashboardStatsResponse{
		KeyStatusCount:       make(map[models.APIKeyStatus]int64),
		MasterStatusCount:    make(map[models.MasterAPIKeyStatus]int64),
		KeyCount:             models.StatCard{}, // ensure KeyCount is an empty struct rather than nil
		RequestCount24h:      models.StatCard{}, // likewise
		TokenCount:           make(map[string]any),
		UpstreamHealthStatus: make(map[string]string),
		RPM:                  models.StatCard{},
		RequestCounts:        make(map[string]int64),
	}
	// --- 1. Aggregate Operational Status from Mappings ---
	type MappingStatusResult struct {
		Status models.APIKeyStatus
		Count  int64
	}
	var mappingStatusResults []MappingStatusResult
	if err := s.db.Model(&models.GroupAPIKeyMapping{}).Select("status, count(*) as count").Group("status").Find(&mappingStatusResults).Error; err != nil {
		return nil, fmt.Errorf("failed to query mapping status stats: %w", err)
	}
	for _, res := range mappingStatusResults {
		resp.KeyStatusCount[res.Status] = res.Count
	}

	// --- 2. Aggregate Master Status from APIKeys ---
	type MasterStatusResult struct {
		Status models.MasterAPIKeyStatus
		Count  int64
	}
	var masterStatusResults []MasterStatusResult
	if err := s.db.Model(&models.APIKey{}).Select("master_status as status, count(*) as count").Group("master_status").Find(&masterStatusResults).Error; err != nil {
		return nil, fmt.Errorf("failed to query master status stats: %w", err)
	}
	var totalKeys, invalidKeys int64
	for _, res := range masterStatusResults {
		resp.MasterStatusCount[res.Status] = res.Count
		totalKeys += res.Count
		if res.Status != models.MasterStatusActive {
			invalidKeys += res.Count
		}
	}
	resp.KeyCount = models.StatCard{Value: float64(totalKeys), SubValue: invalidKeys, SubValueTip: "Number of inactive master keys"}

	now := time.Now()

	// 1. RPM (last minute), RPH (last hour), RPD (today): exact counts from the
	// short-term record (request_logs).
	var count1m, count1h, count1d int64
	// RPM: the one minute ending now.
	s.db.Model(&models.RequestLog{}).Where("request_time >= ?", now.Add(-1*time.Minute)).Count(&count1m)
	// RPH: the one hour ending now.
	s.db.Model(&models.RequestLog{}).Where("request_time >= ?", now.Add(-1*time.Hour)).Count(&count1h)

	// RPD: from midnight today (UTC) until now.
	year, month, day := now.UTC().Date()
	startOfDay := time.Date(year, month, day, 0, 0, 0, 0, time.UTC)
	s.db.Model(&models.RequestLog{}).Where("request_time >= ?", startOfDay).Count(&count1d)
	// 2. RP30D (30 days): served from the long-term aggregate (stats_hourly) to keep this query cheap.
	var count30d int64
	s.db.Model(&models.StatsHourly{}).Where("time >= ?", now.AddDate(0, 0, -30)).Select("COALESCE(SUM(request_count), 0)").Scan(&count30d)

	resp.RequestCounts["1m"] = count1m
	resp.RequestCounts["1h"] = count1h
	resp.RequestCounts["1d"] = count1d
	resp.RequestCounts["30d"] = count30d

	var upstreams []*models.UpstreamEndpoint
	if err := s.db.Find(&upstreams).Error; err != nil {
		s.logger.WithError(err).Warn("Failed to load upstream statuses for dashboard.")
	} else {
		for _, u := range upstreams {
			resp.UpstreamHealthStatus[u.URL] = u.Status
		}
	}

	duration := time.Since(startTime)
	s.logger.Infof("[CacheSyncer] Successfully finished loading overview data in %s.", duration)
	return resp, nil
}

func (s *DashboardQueryService) GetRequestStatsForPeriod(period string) (gin.H, error) {
	var startTime time.Time
	now := time.Now()
	switch period {
	case "1m":
		startTime = now.Add(-1 * time.Minute)
	case "1h":
		startTime = now.Add(-1 * time.Hour)
	case "1d":
		year, month, day := now.UTC().Date()
		startTime = time.Date(year, month, day, 0, 0, 0, 0, time.UTC)
	default:
		return nil, fmt.Errorf("invalid period specified: %s", period)
	}
	var result struct {
		Total   int64
		Success int64
	}

	err := s.db.Model(&models.RequestLog{}).
		Select("count(*) as total, sum(case when is_success = true then 1 else 0 end) as success").
		Where("request_time >= ?", startTime).
		Scan(&result).Error
	if err != nil {
		return nil, err
	}
	return gin.H{
		"total":   result.Total,
		"success": result.Success,
		"failure": result.Total - result.Success,
	}, nil
}

func (s *DashboardQueryService) buildTimeFormatSelectClause() (string, string) {
	dialect := s.db.Dialector.Name()
	switch dialect {
	case "mysql":
		return "DATE_FORMAT(time, '%Y-%m-%d %H:00:00')", "2006-01-02 15:00:00"
	case "sqlite":
		return "strftime('%Y-%m-%d %H:00:00', time)", "2006-01-02 15:00:00"
	default:
		return "strftime('%Y-%m-%d %H:00:00', time)", "2006-01-02 15:00:00"
	}
}
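Both branches of buildTimeFormatSelectClause return the same Go layout, and it works because "00" is not a reference token in Go's time layout language: minutes and seconds render as literal zeros, matching the SQL-side hourly truncation. A quick demonstration:

package main

import (
	"fmt"
	"time"
)

func main() {
	t := time.Date(2024, 5, 17, 9, 42, 13, 0, time.UTC)
	// Same layout as the goFormat returned for both MySQL and SQLite:
	// "15" is the hour token; the trailing "00:00" is literal text.
	fmt.Println(t.Format("2006-01-02 15:00:00")) // 2024-05-17 09:00:00
}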
149
internal/service/db_log_writer_service.go
Normal file
@@ -0,0 +1,149 @@
// Filename: internal/service/db_log_writer_service.go (全新文件)
|
||||
|
||||
package service
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"gemini-balancer/internal/models"
|
||||
"gemini-balancer/internal/settings"
|
||||
"gemini-balancer/internal/store"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
type DBLogWriterService struct {
|
||||
db *gorm.DB
|
||||
store store.Store
|
||||
logger *logrus.Entry
|
||||
logBuffer chan *models.RequestLog
|
||||
stopChan chan struct{}
|
||||
wg sync.WaitGroup
|
||||
SettingsManager *settings.SettingsManager
|
||||
}
|
||||
|
||||
func NewDBLogWriterService(db *gorm.DB, s store.Store, settings *settings.SettingsManager, logger *logrus.Logger) *DBLogWriterService {
|
||||
cfg := settings.GetSettings()
|
||||
bufferCapacity := cfg.LogBufferCapacity
|
||||
if bufferCapacity <= 0 {
|
||||
bufferCapacity = 1000
|
||||
}
|
||||
return &DBLogWriterService{
|
||||
db: db,
|
||||
store: s,
|
||||
SettingsManager: settings,
|
||||
logger: logger.WithField("component", "DBLogWriter📝"),
|
||||
// 使用配置值来创建缓冲区
|
||||
logBuffer: make(chan *models.RequestLog, bufferCapacity),
|
||||
stopChan: make(chan struct{}),
|
||||
}
|
||||
}
|
||||
|
||||
func (s *DBLogWriterService) Start() {
|
||||
s.wg.Add(2) // 一个用于事件监听,一个用于数据库写入
|
||||
|
||||
// 启动事件监听器
|
||||
go s.eventListenerLoop()
|
||||
// 启动数据库写入器
|
||||
go s.dbWriterLoop()
|
||||
|
||||
s.logger.Info("DBLogWriterService started.")
|
||||
}
|
||||
|
||||
func (s *DBLogWriterService) Stop() {
|
||||
s.logger.Info("DBLogWriterService stopping...")
|
||||
close(s.stopChan) // 通知所有goroutine停止
|
||||
s.wg.Wait() // 等待所有goroutine完成
|
||||
s.logger.Info("DBLogWriterService stopped.")
|
||||
}
|
||||
|
||||
// eventListenerLoop 负责从store接收事件并放入内存缓冲区
|
||||
func (s *DBLogWriterService) eventListenerLoop() {
|
||||
defer s.wg.Done()
|
||||
|
||||
sub, err := s.store.Subscribe(models.TopicRequestFinished)
|
||||
if err != nil {
|
||||
s.logger.Fatalf("Failed to subscribe to topic %s: %v", models.TopicRequestFinished, err)
|
||||
return
|
||||
}
|
||||
defer sub.Close()
|
||||
|
||||
s.logger.Info("Subscribed to request events for database logging.")
|
||||
|
||||
for {
|
||||
select {
|
||||
case msg := <-sub.Channel():
|
||||
var event models.RequestFinishedEvent
|
||||
if err := json.Unmarshal(msg.Payload, &event); err != nil {
|
||||
s.logger.Errorf("Failed to unmarshal event for logging: %v", err)
|
||||
continue
|
||||
}
|
||||
|
||||
// 将事件中的日志部分放入缓冲区
|
||||
select {
|
||||
case s.logBuffer <- &event.RequestLog:
|
||||
default:
|
||||
s.logger.Warn("Log buffer is full. A log message might be dropped.")
|
||||
}
|
||||
|
||||
case <-s.stopChan:
|
||||
s.logger.Info("Event listener loop stopping.")
|
||||
// 关闭缓冲区,以通知dbWriterLoop处理完剩余日志后退出
|
||||
close(s.logBuffer)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// dbWriterLoop 负责从内存缓冲区批量读取日志并写入数据库
|
||||
func (s *DBLogWriterService) dbWriterLoop() {
|
||||
defer s.wg.Done()
|
||||
|
||||
// 在启动时获取一次配置
|
||||
cfg := s.SettingsManager.GetSettings()
|
||||
batchSize := cfg.LogFlushBatchSize
|
||||
if batchSize <= 0 {
|
||||
batchSize = 100
|
||||
}
|
||||
|
||||
flushTimeout := time.Duration(cfg.LogFlushIntervalSeconds) * time.Second
|
||||
if flushTimeout <= 0 {
|
||||
flushTimeout = 5 * time.Second
|
||||
}
|
||||
batch := make([]*models.RequestLog, 0, batchSize)
|
||||
ticker := time.NewTicker(flushTimeout)
|
||||
defer ticker.Stop()
|
||||
for {
|
||||
select {
|
||||
case logEntry, ok := <-s.logBuffer:
|
||||
if !ok {
|
||||
if len(batch) > 0 {
|
||||
s.flushBatch(batch)
|
||||
}
|
||||
s.logger.Info("DB writer loop finished.")
|
||||
return
|
||||
}
|
||||
batch = append(batch, logEntry)
|
||||
if len(batch) >= batchSize { // 使用配置的批次大小
|
||||
s.flushBatch(batch)
|
||||
batch = make([]*models.RequestLog, 0, batchSize)
|
||||
}
|
||||
case <-ticker.C:
|
||||
if len(batch) > 0 {
|
||||
s.flushBatch(batch)
|
||||
batch = make([]*models.RequestLog, 0, batchSize)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// flushBatch 将一个批次的日志写入数据库
|
||||
func (s *DBLogWriterService) flushBatch(batch []*models.RequestLog) {
|
||||
if err := s.db.CreateInBatches(batch, len(batch)).Error; err != nil {
|
||||
s.logger.WithField("batch_size", len(batch)).WithError(err).Error("Failed to flush log batch to database.")
|
||||
} else {
|
||||
s.logger.Infof("Successfully flushed %d logs to database.", len(batch))
|
||||
}
|
||||
}
|
||||
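The listener/writer split above hinges on one channel idiom: closing logBuffer is what lets dbWriterLoop drain any remaining entries and exit (the loop above uses a select with a ticker rather than range, but the shutdown mechanics are the same). A self-contained sketch of the drain-on-close pattern, with illustrative names:

	package main

	import "fmt"

	func main() {
		buf := make(chan int, 8)
		done := make(chan struct{})
		go func() { // writer: drains until the channel is closed
			defer close(done)
			for v := range buf {
				fmt.Println("flushed", v)
			}
		}()
		buf <- 1
		buf <- 2
		close(buf) // signal: drain what is left, then exit
		<-done
	}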
596
internal/service/group_manager.go
Normal file
@@ -0,0 +1,596 @@
// Filename: internal/service/group_manager.go (Syncer upgrade)

package service

import (
	"encoding/json"
	"errors"
	"fmt"
	"gemini-balancer/internal/models"
	"gemini-balancer/internal/pkg/reflectutil"
	"gemini-balancer/internal/repository"
	"gemini-balancer/internal/settings"
	"gemini-balancer/internal/syncer"
	"gemini-balancer/internal/utils"
	"net/url"
	"path"
	"sort"
	"strings"
	"time"

	"github.com/sirupsen/logrus"
	"gorm.io/datatypes"
	"gorm.io/gorm"
)

const GroupUpdateChannel = "groups:cache_invalidation"

type GroupManagerCacheData struct {
	Groups          []*models.KeyGroup
	GroupsByName    map[string]*models.KeyGroup
	GroupsByID      map[uint]*models.KeyGroup
	KeyCounts       map[uint]int64                          // GroupID -> Total Key Count
	KeyStatusCounts map[uint]map[models.APIKeyStatus]int64 // GroupID -> Status -> Count
}

type GroupManager struct {
	db              *gorm.DB
	keyRepo         repository.KeyRepository
	groupRepo       repository.GroupRepository
	settingsManager *settings.SettingsManager
	syncer          *syncer.CacheSyncer[GroupManagerCacheData]
	logger          *logrus.Entry
}

type UpdateOrderPayload struct {
	ID    uint `json:"id" binding:"required"`
	Order int  `json:"order"`
}

func NewGroupManagerLoader(db *gorm.DB, logger *logrus.Logger) syncer.LoaderFunc[GroupManagerCacheData] {
	return func() (GroupManagerCacheData, error) {
		logger.Debugf("[GML-LOG 1/5] ---> Entering NewGroupManagerLoader...")
		var groups []*models.KeyGroup
		logger.Debugf("[GML-LOG 2/5] About to execute DB query with Preloads...")

		if err := db.Preload("AllowedUpstreams").
			Preload("AllowedModels").
			Preload("Settings").
			Preload("RequestConfig").
			Find(&groups).Error; err != nil {
			logger.Errorf("[GML-LOG] CRITICAL: DB query for groups failed: %v", err)
			return GroupManagerCacheData{}, fmt.Errorf("failed to load key groups for cache: %w", err)
		}
		logger.Debugf("[GML-LOG 2.1/5] DB query for groups finished. Found %d group records.", len(groups))

		var allMappings []*models.GroupAPIKeyMapping
		if err := db.Find(&allMappings).Error; err != nil {
			logger.Errorf("[GML-LOG] CRITICAL: DB query for mappings failed: %v", err)
			return GroupManagerCacheData{}, fmt.Errorf("failed to load key mappings for cache: %w", err)
		}
		logger.Debugf("[GML-LOG 2.2/5] DB query for mappings finished. Found %d total mapping records.", len(allMappings))

		mappingsByGroupID := make(map[uint][]*models.GroupAPIKeyMapping)
		for i := range allMappings {
			mapping := allMappings[i] // avoid pointer issues with range
			mappingsByGroupID[mapping.KeyGroupID] = append(mappingsByGroupID[mapping.KeyGroupID], mapping)
		}

		for _, group := range groups {
			if mappings, ok := mappingsByGroupID[group.ID]; ok {
				group.Mappings = mappings
			}
		}
		logger.Debugf("[GML-LOG 3/5] Finished manually associating mappings to groups.")

		keyCounts := make(map[uint]int64, len(groups))
		keyStatusCounts := make(map[uint]map[models.APIKeyStatus]int64, len(groups))

		for _, group := range groups {
			keyCounts[group.ID] = int64(len(group.Mappings))
			statusCounts := make(map[models.APIKeyStatus]int64)
			for _, mapping := range group.Mappings {
				statusCounts[mapping.Status]++
			}
			keyStatusCounts[group.ID] = statusCounts
		}
		groupsByName := make(map[string]*models.KeyGroup, len(groups))
		groupsByID := make(map[uint]*models.KeyGroup, len(groups))

		logger.Debugf("[GML-LOG 4/5] Starting to process group records into maps...")
		for i, group := range groups {
			if group == nil {
				logger.Debugf("[GML] CRITICAL: Found a 'nil' group pointer at index %d! This is the most likely cause of the panic.", i)
			} else {
				groupsByName[group.Name] = group
				groupsByID[group.ID] = group
			}
		}
		logger.Debugf("[GML-LOG 5/5] Finished processing records. Building final cache data...")
		return GroupManagerCacheData{
			Groups:          groups,
			GroupsByName:    groupsByName,
			GroupsByID:      groupsByID,
			KeyCounts:       keyCounts,
			KeyStatusCounts: keyStatusCounts,
		}, nil
	}
}
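The loader deliberately fetches all mappings in one bulk query and buckets them in memory instead of letting GORM preload per group, which keeps cache rebuilds at two queries regardless of group count. The generic shape of that pattern; types here are illustrative:

	type Parent struct {
		ID       uint
		Children []*Child
	}

	type Child struct {
		ParentID uint
	}

	func attach(parents []*Parent, children []*Child) {
		byParent := make(map[uint][]*Child, len(parents))
		for _, c := range children {
			byParent[c.ParentID] = append(byParent[c.ParentID], c)
		}
		for _, p := range parents {
			p.Children = byParent[p.ID] // nil slice if the parent has no children
		}
	}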

func NewGroupManager(
	db *gorm.DB,
	keyRepo repository.KeyRepository,
	groupRepo repository.GroupRepository,
	sm *settings.SettingsManager,
	syncer *syncer.CacheSyncer[GroupManagerCacheData],
	logger *logrus.Logger,
) *GroupManager {
	return &GroupManager{
		db:              db,
		keyRepo:         keyRepo,
		groupRepo:       groupRepo,
		settingsManager: sm,
		syncer:          syncer,
		logger:          logger.WithField("component", "GroupManager"),
	}
}

func (gm *GroupManager) GetAllGroups() []*models.KeyGroup {
	cache := gm.syncer.Get()
	if len(cache.Groups) == 0 {
		return []*models.KeyGroup{}
	}
	groupsToOrder := cache.Groups
	sort.Slice(groupsToOrder, func(i, j int) bool {
		if groupsToOrder[i].Order != groupsToOrder[j].Order {
			return groupsToOrder[i].Order < groupsToOrder[j].Order
		}
		return groupsToOrder[i].ID < groupsToOrder[j].ID
	})
	return groupsToOrder
}

func (gm *GroupManager) GetKeyCount(groupID uint) int64 {
	cache := gm.syncer.Get()
	if len(cache.KeyCounts) == 0 {
		return 0
	}
	count := cache.KeyCounts[groupID]
	return count
}

func (gm *GroupManager) GetKeyStatusCount(groupID uint) map[models.APIKeyStatus]int64 {
	cache := gm.syncer.Get()
	if len(cache.KeyStatusCounts) == 0 {
		return make(map[models.APIKeyStatus]int64)
	}
	if counts, ok := cache.KeyStatusCounts[groupID]; ok {
		return counts
	}
	return make(map[models.APIKeyStatus]int64)
}

func (gm *GroupManager) GetGroupByName(name string) (*models.KeyGroup, bool) {
	cache := gm.syncer.Get()
	if len(cache.GroupsByName) == 0 {
		return nil, false
	}
	group, ok := cache.GroupsByName[name]
	return group, ok
}

func (gm *GroupManager) GetGroupByID(id uint) (*models.KeyGroup, bool) {
	cache := gm.syncer.Get()
	if len(cache.GroupsByID) == 0 {
		return nil, false
	}
	group, ok := cache.GroupsByID[id]
	return group, ok
}

func (gm *GroupManager) Stop() {
	gm.syncer.Stop()
}

func (gm *GroupManager) Invalidate() error {
	return gm.syncer.Invalidate()
}

// --- Write Operations ---

// CreateKeyGroup creates a new key group, including its operational settings, and invalidates the cache.
func (gm *GroupManager) CreateKeyGroup(group *models.KeyGroup, settings *models.KeyGroupSettings) error {
	if !utils.IsValidGroupName(group.Name) {
		return errors.New("invalid group name: must contain only lowercase letters, numbers, and hyphens")
	}
	err := gm.db.Transaction(func(tx *gorm.DB) error {
		// 1. Create the group itself to get an ID.
		if err := tx.Create(group).Error; err != nil {
			return err
		}

		// 2. If settings are provided, create the associated GroupSettings record.
		if settings != nil {
			// Only marshal non-nil fields to keep the JSON clean.
			settingsToMarshal := make(map[string]interface{})
			if settings.EnableKeyCheck != nil {
				settingsToMarshal["enable_key_check"] = settings.EnableKeyCheck
			}
			if settings.KeyCheckIntervalMinutes != nil {
				settingsToMarshal["key_check_interval_minutes"] = settings.KeyCheckIntervalMinutes
			}
			if settings.KeyBlacklistThreshold != nil {
				settingsToMarshal["key_blacklist_threshold"] = settings.KeyBlacklistThreshold
			}
			if settings.KeyCooldownMinutes != nil {
				settingsToMarshal["key_cooldown_minutes"] = settings.KeyCooldownMinutes
			}
			if settings.KeyCheckConcurrency != nil {
				settingsToMarshal["key_check_concurrency"] = settings.KeyCheckConcurrency
			}
			if settings.KeyCheckEndpoint != nil {
				settingsToMarshal["key_check_endpoint"] = settings.KeyCheckEndpoint
			}
			if settings.KeyCheckModel != nil {
				settingsToMarshal["key_check_model"] = settings.KeyCheckModel
			}
			if settings.MaxRetries != nil {
				settingsToMarshal["max_retries"] = settings.MaxRetries
			}
			if settings.EnableSmartGateway != nil {
				settingsToMarshal["enable_smart_gateway"] = settings.EnableSmartGateway
			}
			if len(settingsToMarshal) > 0 {
				settingsJSON, err := json.Marshal(settingsToMarshal)
				if err != nil {
					return fmt.Errorf("failed to marshal group settings: %w", err)
				}
				groupSettings := models.GroupSettings{
					GroupID:      group.ID,
					SettingsJSON: datatypes.JSON(settingsJSON),
				}
				if err := tx.Create(&groupSettings).Error; err != nil {
					return fmt.Errorf("failed to save group settings: %w", err)
				}
			}
		}
		return nil
	})

	if err != nil {
		return err
	}

	go gm.Invalidate()
	return nil
}
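Because only non-nil fields are marshaled, the stored SettingsJSON stays sparse and unset fields later fall back to global defaults in BuildOperationalConfig. A tiny runnable sketch of that nil-filtering marshal (field names mirror the map keys above; values are illustrative):

	package main

	import (
		"encoding/json"
		"fmt"
	)

	func main() {
		enable := true
		cooldown := 10
		m := map[string]interface{}{}
		// Mirror of the filtering above: only set fields are added.
		m["enable_key_check"] = &enable
		m["key_cooldown_minutes"] = &cooldown
		b, _ := json.Marshal(m)
		fmt.Println(string(b)) // {"enable_key_check":true,"key_cooldown_minutes":10}
	}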

// UpdateKeyGroup updates an existing key group, its settings, and associations, then invalidates the cache.
func (gm *GroupManager) UpdateKeyGroup(group *models.KeyGroup, newSettings *models.KeyGroupSettings, upstreamURLs []string, modelNames []string) error {
	if !utils.IsValidGroupName(group.Name) {
		return fmt.Errorf("invalid group name: must contain only lowercase letters, numbers, and hyphens")
	}
	uniqueUpstreamURLs := uniqueStrings(upstreamURLs)
	uniqueModelNames := uniqueStrings(modelNames)
	err := gm.db.Transaction(func(tx *gorm.DB) error {
		// --- 1. Update AllowedUpstreams (M:N relationship) ---
		var upstreams []*models.UpstreamEndpoint
		if len(uniqueUpstreamURLs) > 0 {
			if err := tx.Where("url IN ?", uniqueUpstreamURLs).Find(&upstreams).Error; err != nil {
				return err
			}
		}
		if err := tx.Model(group).Association("AllowedUpstreams").Replace(upstreams); err != nil {
			return err
		}

		if err := tx.Model(group).Association("AllowedModels").Clear(); err != nil {
			return err
		}
		if len(uniqueModelNames) > 0 {
			var newMappings []models.GroupModelMapping
			for _, name := range uniqueModelNames {
				newMappings = append(newMappings, models.GroupModelMapping{ModelName: name})
			}
			if err := tx.Model(group).Association("AllowedModels").Append(newMappings); err != nil {
				return err
			}
		}

		if err := tx.Model(group).Updates(group).Error; err != nil {
			return err
		}

		var existingSettings models.GroupSettings
		if err := tx.Where("group_id = ?", group.ID).First(&existingSettings).Error; err != nil && err != gorm.ErrRecordNotFound {
			return err
		}
		var currentSettingsData models.KeyGroupSettings
		if len(existingSettings.SettingsJSON) > 0 {
			if err := json.Unmarshal(existingSettings.SettingsJSON, &currentSettingsData); err != nil {
				return fmt.Errorf("failed to unmarshal existing group settings: %w", err)
			}
		}
		if err := reflectutil.MergeNilFields(&currentSettingsData, newSettings); err != nil {
			return fmt.Errorf("failed to merge group settings: %w", err)
		}
		updatedJSON, err := json.Marshal(currentSettingsData)
		if err != nil {
			return fmt.Errorf("failed to marshal updated group settings: %w", err)
		}
		existingSettings.GroupID = group.ID
		existingSettings.SettingsJSON = datatypes.JSON(updatedJSON)
		return tx.Save(&existingSettings).Error
	})
	if err == nil {
		go gm.Invalidate()
	}
	return err
}

// DeleteKeyGroup deletes a key group and subsequently cleans up any keys that have become orphans.
func (gm *GroupManager) DeleteKeyGroup(id uint) error {
	err := gm.db.Transaction(func(tx *gorm.DB) error {
		gm.logger.Infof("Starting transaction to delete KeyGroup ID: %d", id)
		// Step 1: First, retrieve the group object we are about to delete.
		var group models.KeyGroup
		if err := tx.First(&group, id).Error; err != nil {
			if err == gorm.ErrRecordNotFound {
				gm.logger.Warnf("Attempted to delete a non-existent KeyGroup with ID: %d", id)
				return nil // Don't treat as an error, the group is already gone.
			}
			gm.logger.WithError(err).Errorf("Failed to find KeyGroup with ID: %d for deletion", id)
			return err
		}
		// Step 2: Clear all many-to-many and one-to-many associations using GORM's safe methods.
		if err := tx.Model(&group).Association("AllowedUpstreams").Clear(); err != nil {
			gm.logger.WithError(err).Errorf("Failed to clear 'AllowedUpstreams' association for KeyGroup ID: %d", id)
			return err
		}
		if err := tx.Model(&group).Association("AllowedModels").Clear(); err != nil {
			gm.logger.WithError(err).Errorf("Failed to clear 'AllowedModels' association for KeyGroup ID: %d", id)
			return err
		}
		if err := tx.Model(&group).Association("Mappings").Clear(); err != nil {
			gm.logger.WithError(err).Errorf("Failed to clear 'Mappings' (API Key associations) for KeyGroup ID: %d", id)
			return err
		}
		// Also clear settings if they exist to maintain data integrity.
		if err := tx.Model(&group).Association("Settings").Delete(group.Settings); err != nil {
			gm.logger.WithError(err).Errorf("Failed to delete 'Settings' association for KeyGroup ID: %d", id)
			return err
		}
		// Step 3: Delete the KeyGroup itself.
		if err := tx.Delete(&group).Error; err != nil {
			gm.logger.WithError(err).Errorf("Failed to delete KeyGroup ID: %d", id)
			return err
		}
		gm.logger.Infof("KeyGroup ID %d associations cleared and entity deleted. Triggering orphan key cleanup.", id)
		// Step 4: Trigger the orphan key cleanup (this logic remains the same and is correct).
		deletedCount, err := gm.keyRepo.DeleteOrphanKeysTx(tx)
		if err != nil {
			gm.logger.WithError(err).Error("Failed to clean up orphan keys after deleting group.")
			return err
		}
		if deletedCount > 0 {
			gm.logger.Infof("Successfully cleaned up %d orphan keys.", deletedCount)
		}
		gm.logger.Infof("Transaction for deleting KeyGroup ID: %d completed successfully.", id)
		return nil
	})
	if err == nil {
		go gm.Invalidate()
	}
	return err
}
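DeleteOrphanKeysTx lives in the repository layer and is not shown in this diff. Conceptually it removes keys that no longer appear in any group mapping; a sketch of what such a query could look like in GORM, not the actual implementation:

	func deleteOrphanKeysTx(tx *gorm.DB) (int64, error) {
		// Delete every APIKey whose id has no remaining mapping row.
		res := tx.Where(
			"id NOT IN (?)",
			tx.Model(&models.GroupAPIKeyMapping{}).Select("api_key_id"),
		).Delete(&models.APIKey{})
		return res.RowsAffected, res.Error
	}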

func (gm *GroupManager) CloneKeyGroup(id uint) (*models.KeyGroup, error) {
	var originalGroup models.KeyGroup
	if err := gm.db.
		Preload("RequestConfig").
		Preload("Mappings").
		Preload("AllowedUpstreams").
		Preload("AllowedModels").
		First(&originalGroup, id).Error; err != nil {
		return nil, fmt.Errorf("failed to find original group with id %d: %w", id, err)
	}
	newGroup := originalGroup
	timestamp := time.Now().Unix()
	newGroup.ID = 0
	newGroup.Name = fmt.Sprintf("%s-clone-%d", originalGroup.Name, timestamp)
	newGroup.DisplayName = fmt.Sprintf("%s-clone-%d", originalGroup.DisplayName, timestamp)
	newGroup.CreatedAt = time.Time{}
	newGroup.UpdatedAt = time.Time{}

	newGroup.RequestConfigID = nil
	newGroup.RequestConfig = nil
	newGroup.Mappings = nil
	newGroup.AllowedUpstreams = nil
	newGroup.AllowedModels = nil
	err := gm.db.Transaction(func(tx *gorm.DB) error {

		if err := tx.Create(&newGroup).Error; err != nil {
			return err
		}

		if originalGroup.RequestConfig != nil {
			newRequestConfig := *originalGroup.RequestConfig
			newRequestConfig.ID = 0 // mark as a new record

			if err := tx.Create(&newRequestConfig).Error; err != nil {
				return fmt.Errorf("failed to clone request config: %w", err)
			}

			if err := tx.Model(&newGroup).Update("request_config_id", newRequestConfig.ID).Error; err != nil {
				return fmt.Errorf("failed to link new group to cloned request config: %w", err)
			}
		}

		var originalSettings models.GroupSettings
		err := tx.Where("group_id = ?", originalGroup.ID).First(&originalSettings).Error
		if err == nil && len(originalSettings.SettingsJSON) > 0 {
			newSettings := models.GroupSettings{
				GroupID:      newGroup.ID,
				SettingsJSON: originalSettings.SettingsJSON,
			}
			if err := tx.Create(&newSettings).Error; err != nil {
				return fmt.Errorf("failed to clone group settings: %w", err)
			}
		} else if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) {
			return fmt.Errorf("failed to query original group settings: %w", err)
		}

		if len(originalGroup.Mappings) > 0 {
			newMappings := make([]models.GroupAPIKeyMapping, len(originalGroup.Mappings))
			for i, oldMapping := range originalGroup.Mappings {
				newMappings[i] = models.GroupAPIKeyMapping{
					KeyGroupID:            newGroup.ID,
					APIKeyID:              oldMapping.APIKeyID,
					Status:                oldMapping.Status,
					LastError:             oldMapping.LastError,
					ConsecutiveErrorCount: oldMapping.ConsecutiveErrorCount,
					LastUsedAt:            oldMapping.LastUsedAt,
					CooldownUntil:         oldMapping.CooldownUntil,
				}
			}
			if err := tx.Create(&newMappings).Error; err != nil {
				return fmt.Errorf("failed to clone key group mappings: %w", err)
			}
		}
		if len(originalGroup.AllowedUpstreams) > 0 {
			if err := tx.Model(&newGroup).Association("AllowedUpstreams").Append(originalGroup.AllowedUpstreams); err != nil {
				return err
			}
		}
		if len(originalGroup.AllowedModels) > 0 {
			if err := tx.Model(&newGroup).Association("AllowedModels").Append(originalGroup.AllowedModels); err != nil {
				return err
			}
		}
		return nil
	})

	if err != nil {
		return nil, err
	}

	go gm.Invalidate()

	var finalClonedGroup models.KeyGroup
	if err := gm.db.
		Preload("RequestConfig").
		Preload("Mappings").
		Preload("AllowedUpstreams").
		Preload("AllowedModels").
		First(&finalClonedGroup, newGroup.ID).Error; err != nil {
		return nil, err
	}
	return &finalClonedGroup, nil
}

func (gm *GroupManager) BuildOperationalConfig(group *models.KeyGroup) (*models.KeyGroupSettings, error) {
	globalSettings := gm.settingsManager.GetSettings()
	s := "gemini-1.5-flash" // per user feedback for the default model
	opConfig := &models.KeyGroupSettings{
		EnableKeyCheck:          &globalSettings.EnableBaseKeyCheck,
		KeyCheckConcurrency:     &globalSettings.BaseKeyCheckConcurrency,
		KeyCheckIntervalMinutes: &globalSettings.BaseKeyCheckIntervalMinutes,
		KeyCheckEndpoint:        &globalSettings.DefaultUpstreamURL,
		KeyBlacklistThreshold:   &globalSettings.BlacklistThreshold,
		KeyCooldownMinutes:      &globalSettings.KeyCooldownMinutes,
		KeyCheckModel:           &s,
		MaxRetries:              &globalSettings.MaxRetries,
		EnableSmartGateway:      &globalSettings.EnableSmartGateway,
	}

	if group == nil {
		return opConfig, nil
	}

	var groupSettingsRecord models.GroupSettings
	err := gm.db.Where("group_id = ?", group.ID).First(&groupSettingsRecord).Error
	if err != nil {
		if errors.Is(err, gorm.ErrRecordNotFound) {
			return opConfig, nil
		}

		gm.logger.WithError(err).Errorf("Failed to query group settings for group ID %d", group.ID)
		return nil, err
	}

	if len(groupSettingsRecord.SettingsJSON) == 0 {
		return opConfig, nil
	}

	var groupSpecificSettings models.KeyGroupSettings
	if err := json.Unmarshal(groupSettingsRecord.SettingsJSON, &groupSpecificSettings); err != nil {
		gm.logger.WithError(err).WithField("group_id", group.ID).Warn("Failed to unmarshal group settings JSON.")
		return opConfig, err
	}

	if err := reflectutil.MergeNilFields(opConfig, &groupSpecificSettings); err != nil {
		gm.logger.WithError(err).WithField("group_id", group.ID).Error("Failed to merge group-specific settings over defaults.")
		return opConfig, nil
	}

	return opConfig, nil
}

func (gm *GroupManager) BuildKeyCheckEndpoint(groupID uint) (string, error) {
	group, ok := gm.GetGroupByID(groupID)
	if !ok {
		return "", fmt.Errorf("group with id %d not found", groupID)
	}
	opConfig, err := gm.BuildOperationalConfig(group)
	if err != nil {
		return "", fmt.Errorf("failed to build operational config for group %d: %w", groupID, err)
	}
	globalSettings := gm.settingsManager.GetSettings()
	baseURL := globalSettings.DefaultUpstreamURL
	if opConfig.KeyCheckEndpoint != nil && *opConfig.KeyCheckEndpoint != "" {
		baseURL = *opConfig.KeyCheckEndpoint
	}
	if baseURL == "" {
		return "", fmt.Errorf("no key check endpoint or default upstream URL is configured for group %d", groupID)
	}
	modelName := globalSettings.BaseKeyCheckModel
	if opConfig.KeyCheckModel != nil && *opConfig.KeyCheckModel != "" {
		modelName = *opConfig.KeyCheckModel
	}
	parsedURL, err := url.Parse(baseURL)
	if err != nil {
		return "", fmt.Errorf("failed to parse base URL '%s': %w", baseURL, err)
	}
	cleanedPath := parsedURL.Path
	cleanedPath = strings.TrimSuffix(cleanedPath, "/")
	cleanedPath = strings.TrimSuffix(cleanedPath, "/v1beta")
	parsedURL.Path = path.Join(cleanedPath, "v1beta", "models", modelName)
	finalEndpoint := parsedURL.String()
	return finalEndpoint, nil
}
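A quick worked example of the path normalization above: trimming a trailing slash and any existing /v1beta suffix before re-joining guarantees the segment appears exactly once, whether or not the configured base URL already includes it.

	package main

	import (
		"fmt"
		"net/url"
		"path"
		"strings"
	)

	func main() {
		u, _ := url.Parse("https://generativelanguage.googleapis.com/v1beta/")
		p := strings.TrimSuffix(u.Path, "/")
		p = strings.TrimSuffix(p, "/v1beta")
		u.Path = path.Join(p, "v1beta", "models", "gemini-1.5-flash")
		fmt.Println(u.String())
		// https://generativelanguage.googleapis.com/v1beta/models/gemini-1.5-flash
	}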

func (gm *GroupManager) UpdateOrder(payload []UpdateOrderPayload) error {
	ordersMap := make(map[uint]int, len(payload))
	for _, item := range payload {
		ordersMap[item.ID] = item.Order
	}
	if err := gm.groupRepo.UpdateOrderInTransaction(ordersMap); err != nil {
		gm.logger.WithError(err).Error("Failed to update group order in transaction")
		return fmt.Errorf("database transaction failed: %w", err)
	}
	gm.logger.Info("Group order updated successfully, invalidating cache...")
	go gm.Invalidate()
	return nil
}

func uniqueStrings(slice []string) []string {
	keys := make(map[string]struct{})
	list := []string{}
	for _, entry := range slice {
		if _, ok := keys[entry]; !ok {
			keys[entry] = struct{}{}
			list = append(list, entry)
		}
	}
	return list
}
624
internal/service/healthcheck_service.go
Normal file
@@ -0,0 +1,624 @@
// Filename: internal/service/healthcheck_service.go (final calibrated version)

package service

import (
	"context"
	"encoding/json"
	"gemini-balancer/internal/channel"
	CustomErrors "gemini-balancer/internal/errors"
	"gemini-balancer/internal/models"
	"gemini-balancer/internal/repository"
	"gemini-balancer/internal/settings"
	"gemini-balancer/internal/store"
	"net"
	"net/http"
	"net/url"
	"sync"
	"time"

	"github.com/sirupsen/logrus"
	"golang.org/x/net/proxy"
	"gorm.io/gorm"
)

const (
	ProxyCheckTargetURL   = "https://www.google.com/generate_204"
	DefaultGeminiEndpoint = "https://generativelanguage.googleapis.com/v1beta/models"
	StatusActive          = "active"
	StatusInactive        = "inactive"
)

type HealthCheckServiceLogger struct{ *logrus.Entry }

type HealthCheckService struct {
	db                   *gorm.DB
	SettingsManager      *settings.SettingsManager
	store                store.Store
	keyRepo              repository.KeyRepository
	groupManager         *GroupManager
	channel              channel.ChannelProxy
	keyValidationService *KeyValidationService
	logger               *logrus.Entry
	stopChan             chan struct{}
	wg                   sync.WaitGroup
	lastResultsMutex     sync.RWMutex
	lastResults          map[string]string
	groupCheckTimeMutex  sync.Mutex
	groupNextCheckTime   map[uint]time.Time
}

func NewHealthCheckService(
	db *gorm.DB,
	ss *settings.SettingsManager,
	s store.Store,
	repo repository.KeyRepository,
	gm *GroupManager,
	ch channel.ChannelProxy,
	kvs *KeyValidationService,
	logger *logrus.Logger,
) *HealthCheckService {
	return &HealthCheckService{
		db:                   db,
		SettingsManager:      ss,
		store:                s,
		keyRepo:              repo,
		groupManager:         gm,
		channel:              ch,
		keyValidationService: kvs,
		logger:               logger.WithField("component", "HealthCheck🩺"),
		stopChan:             make(chan struct{}),
		lastResults:          make(map[string]string),
		groupNextCheckTime:   make(map[uint]time.Time),
	}
}
func (s *HealthCheckService) Start() {
	s.logger.Info("Starting HealthCheckService with independent check loops...")
	s.wg.Add(4) // now four loops
	go s.runKeyCheckLoop()
	go s.runUpstreamCheckLoop()
	go s.runProxyCheckLoop()
	go s.runBaseKeyCheckLoop()
}

func (s *HealthCheckService) Stop() {
	s.logger.Info("Stopping HealthCheckService...")
	close(s.stopChan)
	s.wg.Wait()
	s.logger.Info("HealthCheckService stopped gracefully.")
}

func (s *HealthCheckService) GetLastHealthCheckResults() map[string]string {
	s.lastResultsMutex.RLock()
	defer s.lastResultsMutex.RUnlock()
	resultsCopy := make(map[string]string, len(s.lastResults))
	for k, v := range s.lastResults {
		resultsCopy[k] = v
	}
	return resultsCopy
}

func (s *HealthCheckService) runKeyCheckLoop() {
	defer s.wg.Done()
	s.logger.Info("Key check dynamic scheduler loop started.")

	// Main scheduling loop: look for due work once per minute.
	ticker := time.NewTicker(1 * time.Minute)
	defer ticker.Stop()

	for {
		select {
		case <-ticker.C:
			s.scheduleKeyChecks()
		case <-s.stopChan:
			s.logger.Info("Key check scheduler loop stopped.")
			return
		}
	}
}
func (s *HealthCheckService) scheduleKeyChecks() {
	groups := s.groupManager.GetAllGroups()
	now := time.Now()

	s.groupCheckTimeMutex.Lock()
	defer s.groupCheckTimeMutex.Unlock()

	for _, group := range groups {
		// Fetch the group-specific operational config.
		opConfig, err := s.groupManager.BuildOperationalConfig(group)
		if err != nil {
			s.logger.WithError(err).WithField("group_id", group.ID).Error("Failed to build operational config for group, skipping health check scheduling.")
			continue
		}

		if opConfig.EnableKeyCheck == nil || !*opConfig.EnableKeyCheck {
			continue // skip groups with health checks disabled
		}

		var intervalMinutes int
		if opConfig.KeyCheckIntervalMinutes != nil {
			intervalMinutes = *opConfig.KeyCheckIntervalMinutes
		}
		interval := time.Duration(intervalMinutes) * time.Minute
		if interval <= 0 {
			continue // skip invalid check intervals
		}

		if nextCheckTime, ok := s.groupNextCheckTime[group.ID]; !ok || now.After(nextCheckTime) {
			s.logger.Infof("Scheduling key check for group '%s' (ID: %d)", group.Name, group.ID)
			go s.performKeyChecksForGroup(group, opConfig)
			s.groupNextCheckTime[group.ID] = now.Add(interval)
		}
	}
}
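scheduleKeyChecks is effectively a per-group debounce driven by the one-minute master ticker: each group fires only when its own deadline has passed, then pushes its deadline forward by its configured interval. Reduced to its core, with illustrative names (this stripped-down version is unsynchronized; the service guards its map with groupCheckTimeMutex):

	next := make(map[uint]time.Time)

	// due reports whether the item identified by id should run now and,
	// if so, schedules its next run `every` from now.
	due := func(id uint, every time.Duration, now time.Time) bool {
		if t, ok := next[id]; ok && now.Before(t) {
			return false
		}
		next[id] = now.Add(every)
		return true
	}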

func (s *HealthCheckService) runUpstreamCheckLoop() {
	defer s.wg.Done()
	s.logger.Info("Upstream check loop started.")
	if s.SettingsManager.GetSettings().EnableUpstreamCheck {
		s.performUpstreamChecks()
	}

	ticker := time.NewTicker(time.Duration(s.SettingsManager.GetSettings().UpstreamCheckIntervalSeconds) * time.Second)
	defer ticker.Stop()

	for {
		select {
		case <-ticker.C:
			if s.SettingsManager.GetSettings().EnableUpstreamCheck {
				s.logger.Debug("Upstream check ticker fired.")
				s.performUpstreamChecks()
			}
		case <-s.stopChan:
			s.logger.Info("Upstream check loop stopped.")
			return
		}
	}
}

func (s *HealthCheckService) runProxyCheckLoop() {
	defer s.wg.Done()
	s.logger.Info("Proxy check loop started.")
	if s.SettingsManager.GetSettings().EnableProxyCheck {
		s.performProxyChecks()
	}

	ticker := time.NewTicker(time.Duration(s.SettingsManager.GetSettings().ProxyCheckIntervalSeconds) * time.Second)
	defer ticker.Stop()

	for {
		select {
		case <-ticker.C:
			if s.SettingsManager.GetSettings().EnableProxyCheck {
				s.logger.Debug("Proxy check ticker fired.")
				s.performProxyChecks()
			}
		case <-s.stopChan:
			s.logger.Info("Proxy check loop stopped.")
			return
		}
	}
}
func (s *HealthCheckService) performKeyChecksForGroup(group *models.KeyGroup, opConfig *models.KeyGroupSettings) {
	settings := s.SettingsManager.GetSettings()
	timeout := time.Duration(settings.KeyCheckTimeoutSeconds) * time.Second

	endpoint, err := s.groupManager.BuildKeyCheckEndpoint(group.ID)
	if err != nil {
		s.logger.WithError(err).WithField("group_id", group.ID).Error("Failed to build key check endpoint for group, skipping check cycle.")
		return
	}

	log := s.logger.WithFields(logrus.Fields{"group_id": group.ID, "group_name": group.Name})

	log.Infof("Starting key health check cycle.")

	var mappingsToCheck []models.GroupAPIKeyMapping
	err = s.db.Model(&models.GroupAPIKeyMapping{}).
		Joins("JOIN api_keys ON api_keys.id = group_api_key_mappings.api_key_id").
		Where("group_api_key_mappings.key_group_id = ?", group.ID).
		Where("api_keys.master_status = ?", models.MasterStatusActive).
		Where("group_api_key_mappings.status IN ?", []models.APIKeyStatus{models.StatusActive, models.StatusDisabled, models.StatusCooldown}).
		Preload("APIKey").
		Find(&mappingsToCheck).Error

	if err != nil {
		log.WithError(err).Error("Failed to fetch key mappings for health check.")
		return
	}
	if len(mappingsToCheck) == 0 {
		log.Info("No key mappings to check for this group.")
		return
	}

	log.Infof("Starting health check for %d key mappings.", len(mappingsToCheck))
	jobs := make(chan models.GroupAPIKeyMapping, len(mappingsToCheck))
	var wg sync.WaitGroup
	var concurrency int
	if opConfig.KeyCheckConcurrency != nil {
		concurrency = *opConfig.KeyCheckConcurrency
	}
	if concurrency <= 0 {
		concurrency = 1 // guarantee at least one worker
	}
	for w := 1; w <= concurrency; w++ {
		wg.Add(1)
		go func(workerID int) {
			defer wg.Done()
			for mapping := range jobs {
				s.checkAndProcessMapping(&mapping, timeout, endpoint)
			}
		}(w)
	}
	for _, m := range mappingsToCheck {
		jobs <- m
	}
	close(jobs)
	wg.Wait()
	log.Info("Finished key health check cycle.")
}
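The fan-out above is the classic bounded worker pool: a buffered jobs channel sized to the work, N workers ranging over it, then close(jobs) followed by wg.Wait() to join. A self-contained miniature of the same shape:

	package main

	import (
		"fmt"
		"sync"
	)

	func main() {
		jobs := make(chan int, 10)
		var wg sync.WaitGroup
		const workers = 3
		for w := 0; w < workers; w++ {
			wg.Add(1)
			go func() {
				defer wg.Done()
				for j := range jobs { // each worker drains until close(jobs)
					fmt.Println("checked", j)
				}
			}()
		}
		for i := 0; i < 10; i++ {
			jobs <- i
		}
		close(jobs) // no more work; workers exit when drained
		wg.Wait()
	}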
func (s *HealthCheckService) checkAndProcessMapping(mapping *models.GroupAPIKeyMapping, timeout time.Duration, endpoint string) {
	if mapping.APIKey == nil {
		s.logger.Warnf("Skipping check for mapping (G:%d, K:%d) because associated APIKey is nil.", mapping.KeyGroupID, mapping.APIKeyID)
		return
	}
	validationErr := s.keyValidationService.ValidateSingleKey(mapping.APIKey, timeout, endpoint)
	// --- Diagnosis 1: validation succeeded (healthy) ---
	if validationErr == nil {
		if mapping.Status != models.StatusActive {
			s.activateMapping(mapping)
		}
		return
	}
	errorString := validationErr.Error()
	// --- Diagnosis 2: permanent error ---
	if CustomErrors.IsPermanentUpstreamError(errorString) {
		s.revokeMapping(mapping, validationErr)
		return
	}
	// --- Diagnosis 3: temporary error ---
	if CustomErrors.IsTemporaryUpstreamError(errorString) {
		// Log with a higher level (WARN) since this is an actionable, proactive finding.
		s.logger.Warnf("Health check for key %d failed with temporary error, applying penalty. Reason: %v", mapping.APIKeyID, validationErr)
		s.penalizeMapping(mapping, validationErr)
		return
	}
	// --- Diagnosis 4: other unknown or upstream service errors ---

	s.logger.Errorf("Health check for key %d failed with a transient or unknown upstream error: %v. Mapping will not be penalized by health check.", mapping.APIKeyID, validationErr)
}

func (s *HealthCheckService) activateMapping(mapping *models.GroupAPIKeyMapping) {
	oldStatus := mapping.Status
	mapping.Status = models.StatusActive
	mapping.ConsecutiveErrorCount = 0
	mapping.LastError = ""
	if err := s.keyRepo.UpdateMapping(mapping); err != nil {
		s.logger.WithError(err).Errorf("Failed to activate mapping (G:%d, K:%d).", mapping.KeyGroupID, mapping.APIKeyID)
		return
	}
	s.logger.Infof("Mapping (G:%d, K:%d) successfully activated from %s to ACTIVE.", mapping.KeyGroupID, mapping.APIKeyID, oldStatus)
	s.publishKeyStatusChangedEvent(mapping.KeyGroupID, mapping.APIKeyID, oldStatus, mapping.Status)
}
func (s *HealthCheckService) penalizeMapping(mapping *models.GroupAPIKeyMapping, err error) {
	// Re-fetch group-specific operational config to get the correct thresholds.
	group, ok := s.groupManager.GetGroupByID(mapping.KeyGroupID)
	if !ok {
		s.logger.Errorf("Could not find group with ID %d to apply penalty.", mapping.KeyGroupID)
		return
	}
	opConfig, buildErr := s.groupManager.BuildOperationalConfig(group)
	if buildErr != nil {
		s.logger.WithError(buildErr).Errorf("Failed to build operational config for group %d during penalty.", mapping.KeyGroupID)
		return
	}
	oldStatus := mapping.Status
	mapping.LastError = err.Error()
	mapping.ConsecutiveErrorCount++
	// Use the group-specific threshold. BuildOperationalConfig seeds every
	// pointer field from global defaults, so these dereferences are safe.
	threshold := *opConfig.KeyBlacklistThreshold
	if mapping.ConsecutiveErrorCount >= threshold {
		mapping.Status = models.StatusCooldown
		cooldownDuration := time.Duration(*opConfig.KeyCooldownMinutes) * time.Minute
		cooldownTime := time.Now().Add(cooldownDuration)
		mapping.CooldownUntil = &cooldownTime
		s.logger.Warnf("Mapping (G:%d, K:%d) reached error threshold (%d) during health check and is now in COOLDOWN for %v.", mapping.KeyGroupID, mapping.APIKeyID, threshold, cooldownDuration)
	}
	if errDb := s.keyRepo.UpdateMapping(mapping); errDb != nil {
		s.logger.WithError(errDb).Errorf("Failed to penalize mapping (G:%d, K:%d).", mapping.KeyGroupID, mapping.APIKeyID)
		return
	}
	if oldStatus != mapping.Status {
		s.publishKeyStatusChangedEvent(mapping.KeyGroupID, mapping.APIKeyID, oldStatus, mapping.Status)
	}
}

func (s *HealthCheckService) revokeMapping(mapping *models.GroupAPIKeyMapping, err error) {
	oldStatus := mapping.Status
	if oldStatus == models.StatusBanned {
		return // already banned, do nothing
	}

	mapping.Status = models.StatusBanned
	mapping.LastError = "Definitive error: " + err.Error()
	mapping.ConsecutiveErrorCount = 0 // reset the counter; this is a final state for this group

	if errDb := s.keyRepo.UpdateMapping(mapping); errDb != nil {
		s.logger.WithError(errDb).Errorf("Failed to revoke mapping (G:%d, K:%d).", mapping.KeyGroupID, mapping.APIKeyID)
		return
	}

	s.logger.Warnf("Mapping (G:%d, K:%d) has been BANNED due to a definitive error: %v", mapping.KeyGroupID, mapping.APIKeyID, err)
	s.publishKeyStatusChangedEvent(mapping.KeyGroupID, mapping.APIKeyID, oldStatus, mapping.Status)

	s.logger.Infof("Triggering MasterStatus update for definitively failed key ID %d.", mapping.APIKeyID)
	if err := s.keyRepo.UpdateAPIKeyStatus(mapping.APIKeyID, models.MasterStatusRevoked); err != nil {
		s.logger.WithError(err).Errorf("Failed to update master status for key ID %d after group-level ban.", mapping.APIKeyID)
	}
}
func (s *HealthCheckService) performUpstreamChecks() {
	settings := s.SettingsManager.GetSettings()
	timeout := time.Duration(settings.UpstreamCheckTimeoutSeconds) * time.Second
	var upstreams []*models.UpstreamEndpoint
	if err := s.db.Find(&upstreams).Error; err != nil {
		s.logger.WithError(err).Error("Failed to retrieve upstreams.")
		return
	}
	if len(upstreams) == 0 {
		return
	}
	s.logger.Infof("Starting validation for %d upstreams.", len(upstreams))
	var wg sync.WaitGroup
	for _, u := range upstreams {
		wg.Add(1)
		go func(upstream *models.UpstreamEndpoint) {
			defer wg.Done()
			oldStatus := upstream.Status
			isAlive := s.checkEndpoint(upstream.URL, timeout)
			newStatus := StatusInactive
			if isAlive {
				newStatus = StatusActive
			}
			s.lastResultsMutex.Lock()
			s.lastResults[upstream.URL] = newStatus
			s.lastResultsMutex.Unlock()
			if oldStatus != newStatus {
				s.logger.Infof("Upstream '%s' status changed from %s -> %s", upstream.URL, oldStatus, newStatus)
				if err := s.db.Model(upstream).Update("status", newStatus).Error; err != nil {
					s.logger.WithError(err).WithField("upstream_id", upstream.ID).Error("Failed to update upstream status.")
				} else {
					s.publishUpstreamHealthChangedEvent(upstream, oldStatus, newStatus)
				}
			}
		}(u)
	}
	wg.Wait()
}

func (s *HealthCheckService) checkEndpoint(urlStr string, timeout time.Duration) bool {
	client := http.Client{Timeout: timeout}
	resp, err := client.Head(urlStr)
	if err != nil {
		return false
	}
	defer resp.Body.Close()
	return resp.StatusCode < http.StatusInternalServerError
}
func (s *HealthCheckService) performProxyChecks() {
	settings := s.SettingsManager.GetSettings()
	timeout := time.Duration(settings.ProxyCheckTimeoutSeconds) * time.Second
	var proxies []*models.ProxyConfig
	if err := s.db.Find(&proxies).Error; err != nil {
		s.logger.WithError(err).Error("Failed to retrieve proxies.")
		return
	}
	if len(proxies) == 0 {
		return
	}
	s.logger.Infof("Starting validation for %d proxies.", len(proxies))
	var wg sync.WaitGroup
	for _, p := range proxies {
		wg.Add(1)
		go func(proxyCfg *models.ProxyConfig) {
			defer wg.Done()
			isAlive := s.checkProxy(proxyCfg, timeout)
			newStatus := StatusInactive
			if isAlive {
				newStatus = StatusActive
			}
			s.lastResultsMutex.Lock()
			s.lastResults[proxyCfg.Address] = newStatus
			s.lastResultsMutex.Unlock()
			if proxyCfg.Status != newStatus {
				s.logger.Infof("Proxy '%s' status changed from %s -> %s", proxyCfg.Address, proxyCfg.Status, newStatus)
				if err := s.db.Model(proxyCfg).Update("status", newStatus).Error; err != nil {
					s.logger.WithError(err).WithField("proxy_id", proxyCfg.ID).Error("Failed to update proxy status.")
				}
			}
		}(p)
	}
	wg.Wait()
}

func (s *HealthCheckService) checkProxy(proxyCfg *models.ProxyConfig, timeout time.Duration) bool {
	transport := &http.Transport{}
	switch proxyCfg.Protocol {
	case "http", "https":
		proxyUrl, err := url.Parse(proxyCfg.Protocol + "://" + proxyCfg.Address)
		if err != nil {
			s.logger.WithError(err).WithField("proxy_address", proxyCfg.Address).Warn("Invalid proxy URL format.")
			return false
		}
		transport.Proxy = http.ProxyURL(proxyUrl)
	case "socks5":
		dialer, err := proxy.SOCKS5("tcp", proxyCfg.Address, nil, proxy.Direct)
		if err != nil {
			s.logger.WithError(err).WithField("proxy_address", proxyCfg.Address).Warn("Failed to create SOCKS5 dialer.")
			return false
		}
		transport.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) {
			return dialer.Dial(network, addr)
		}
	default:
		s.logger.WithField("protocol", proxyCfg.Protocol).Warn("Unsupported proxy protocol.")
		return false
	}
	client := &http.Client{
		Transport: transport,
		Timeout:   timeout,
	}
	resp, err := client.Get(ProxyCheckTargetURL)
	if err != nil {
		return false
	}
	defer resp.Body.Close()
	return true
}
func (s *HealthCheckService) publishKeyStatusChangedEvent(groupID, keyID uint, oldStatus, newStatus models.APIKeyStatus) {
	event := models.KeyStatusChangedEvent{
		KeyID:        keyID,
		GroupID:      groupID,
		OldStatus:    oldStatus,
		NewStatus:    newStatus,
		ChangeReason: "health_check",
		ChangedAt:    time.Now(),
	}
	payload, err := json.Marshal(event)
	if err != nil {
		s.logger.WithError(err).Errorf("Failed to marshal KeyStatusChangedEvent for group %d.", groupID)
		return
	}
	if err := s.store.Publish(models.TopicKeyStatusChanged, payload); err != nil {
		s.logger.WithError(err).Errorf("Failed to publish KeyStatusChangedEvent for group %d.", groupID)
	}
}

func (s *HealthCheckService) publishUpstreamHealthChangedEvent(upstream *models.UpstreamEndpoint, oldStatus, newStatus string) {
	event := models.UpstreamHealthChangedEvent{
		UpstreamID:  upstream.ID,
		UpstreamURL: upstream.URL,
		OldStatus:   oldStatus,
		NewStatus:   newStatus,
		Latency:     0,
		Reason:      "health_check",
		CheckedAt:   time.Now(),
	}
	payload, err := json.Marshal(event)
	if err != nil {
		s.logger.WithError(err).Error("Failed to marshal UpstreamHealthChangedEvent.")
		return
	}
	if err := s.store.Publish(models.TopicUpstreamHealthChanged, payload); err != nil {
		s.logger.WithError(err).Error("Failed to publish UpstreamHealthChangedEvent.")
	}
}

// =========================================================================
// Global Base Key Check (New Logic)
// =========================================================================
func (s *HealthCheckService) runBaseKeyCheckLoop() {
	defer s.wg.Done()
	s.logger.Info("Global base key check loop started.")
	settings := s.SettingsManager.GetSettings()

	if !settings.EnableBaseKeyCheck {
		s.logger.Info("Global base key check is disabled.")
		return
	}

	// Perform an initial check on startup.
	s.performBaseKeyChecks()

	interval := time.Duration(settings.BaseKeyCheckIntervalMinutes) * time.Minute
	if interval <= 0 {
		s.logger.Warnf("Invalid BaseKeyCheckIntervalMinutes: %d. Disabling base key check loop.", settings.BaseKeyCheckIntervalMinutes)
		return
	}
	ticker := time.NewTicker(interval)
	defer ticker.Stop()

	for {
		select {
		case <-ticker.C:
			s.performBaseKeyChecks()
		case <-s.stopChan:
			s.logger.Info("Global base key check loop stopped.")
			return
		}
	}
}
func (s *HealthCheckService) performBaseKeyChecks() {
	s.logger.Info("Starting global base key check cycle.")
	settings := s.SettingsManager.GetSettings()
	timeout := time.Duration(settings.KeyCheckTimeoutSeconds) * time.Second
	endpoint := settings.BaseKeyCheckEndpoint
	concurrency := settings.BaseKeyCheckConcurrency
	keys, err := s.keyRepo.GetActiveMasterKeys()
	if err != nil {
		s.logger.WithError(err).Error("Failed to fetch active master keys for base check.")
		return
	}
	if len(keys) == 0 {
		s.logger.Info("No active master keys to perform base check on.")
		return
	}
	s.logger.Infof("Performing base check on %d active master keys.", len(keys))
	jobs := make(chan *models.APIKey, len(keys))
	var wg sync.WaitGroup
	if concurrency <= 0 {
		concurrency = 5 // safe default
	}
	for w := 0; w < concurrency; w++ {
		wg.Add(1)
		go func() {
			defer wg.Done()
			for key := range jobs {
				err := s.keyValidationService.ValidateSingleKey(key, timeout, endpoint)
				if err != nil && CustomErrors.IsPermanentUpstreamError(err.Error()) {
					oldStatus := key.MasterStatus
					// Guard the slice so an unusually short key cannot cause an out-of-range panic.
					keyPrefix := key.APIKey
					if len(keyPrefix) > 4 {
						keyPrefix = keyPrefix[:4]
					}
					s.logger.Warnf("Key ID %d (%s...) failed definitive base check. Setting MasterStatus to REVOKED. Reason: %v", key.ID, keyPrefix, err)
					if updateErr := s.keyRepo.UpdateAPIKeyStatus(key.ID, models.MasterStatusRevoked); updateErr != nil {
						s.logger.WithError(updateErr).Errorf("Failed to update master status for key ID %d.", key.ID)
					} else {
						s.publishMasterKeyStatusChangedEvent(key.ID, oldStatus, models.MasterStatusRevoked)
					}
				}
			}
		}()
	}
	for _, key := range keys {
		jobs <- key
	}
	close(jobs)
	wg.Wait()
	s.logger.Info("Global base key check cycle finished.")
}
// Event publishing helper.
func (s *HealthCheckService) publishMasterKeyStatusChangedEvent(keyID uint, oldStatus, newStatus models.MasterAPIKeyStatus) {
	event := models.MasterKeyStatusChangedEvent{
		KeyID:           keyID,
		OldMasterStatus: oldStatus,
		NewMasterStatus: newStatus,
		ChangeReason:    "base_health_check",
		ChangedAt:       time.Now(),
	}
	payload, err := json.Marshal(event)
	if err != nil {
		s.logger.WithError(err).Errorf("Failed to marshal MasterKeyStatusChangedEvent for key %d.", keyID)
		return
	}
	if err := s.store.Publish(models.TopicMasterKeyStatusChanged, payload); err != nil {
		s.logger.WithError(err).Errorf("Failed to publish MasterKeyStatusChangedEvent for key %d.", keyID)
	}
}
397
internal/service/key_import_service.go
Normal file
@@ -0,0 +1,397 @@
// Filename: internal/service/key_import_service.go
package service

import (
	"encoding/json"
	"fmt"
	"gemini-balancer/internal/models"
	"gemini-balancer/internal/repository"
	"gemini-balancer/internal/store"
	"gemini-balancer/internal/task"
	"gemini-balancer/internal/utils"
	"strings"
	"time"

	"github.com/gin-gonic/gin"
	"github.com/sirupsen/logrus"
)

const (
	TaskTypeAddKeysToGroup      = "add_keys_to_group"
	TaskTypeUnlinkKeysFromGroup = "unlink_keys_from_group"
	TaskTypeHardDeleteKeys      = "hard_delete_keys"
	TaskTypeRestoreKeys         = "restore_keys"
	chunkSize                   = 500
)

type KeyImportService struct {
	taskService   task.Reporter
	keyRepo       repository.KeyRepository
	store         store.Store
	logger        *logrus.Entry
	apiKeyService *APIKeyService
}

func NewKeyImportService(ts task.Reporter, kr repository.KeyRepository, s store.Store, as *APIKeyService, logger *logrus.Logger) *KeyImportService {
	return &KeyImportService{
		taskService:   ts,
		keyRepo:       kr,
		store:         s,
		logger:        logger.WithField("component", "KeyImportService🚀"),
		apiKeyService: as,
	}
}
// --- Generic panic-safe task runner ---
func (s *KeyImportService) runTaskWithRecovery(taskID string, resourceID string, taskFunc func()) {
	defer func() {
		if r := recover(); r != nil {
			err := fmt.Errorf("panic recovered in task %s: %v", taskID, r)
			s.logger.Error(err)
			s.taskService.EndTaskByID(taskID, resourceID, nil, err)
		}
	}()
	taskFunc()
}
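runTaskWithRecovery exists so a panicking task goroutine still reports failure through the task service instead of crashing the process. The wrapper's generic shape, with illustrative names (assumes an fmt import):

	func withRecovery(onPanic func(error), fn func()) {
		defer func() {
			if r := recover(); r != nil {
				// recover only works inside a deferred function
				onPanic(fmt.Errorf("panic recovered: %v", r))
			}
		}()
		fn()
	}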
// --- Public Task Starters ---

func (s *KeyImportService) StartAddKeysTask(groupID uint, keysText string, validateOnImport bool) (*task.Status, error) {
	keys := utils.ParseKeysFromText(keysText)
	if len(keys) == 0 {
		return nil, fmt.Errorf("no valid keys found in input text")
	}
	resourceID := fmt.Sprintf("group-%d", groupID)
	taskStatus, err := s.taskService.StartTask(groupID, TaskTypeAddKeysToGroup, resourceID, len(keys), 15*time.Minute)
	if err != nil {
		return nil, err
	}
	go s.runTaskWithRecovery(taskStatus.ID, resourceID, func() {
		s.runAddKeysTask(taskStatus.ID, resourceID, groupID, keys, validateOnImport)
	})
	return taskStatus, nil
}

func (s *KeyImportService) StartUnlinkKeysTask(groupID uint, keysText string) (*task.Status, error) {
	keys := utils.ParseKeysFromText(keysText)
	if len(keys) == 0 {
		return nil, fmt.Errorf("no valid keys found")
	}
	resourceID := fmt.Sprintf("group-%d", groupID)
	taskStatus, err := s.taskService.StartTask(groupID, TaskTypeUnlinkKeysFromGroup, resourceID, len(keys), time.Hour)
	if err != nil {
		return nil, err
	}
	go s.runTaskWithRecovery(taskStatus.ID, resourceID, func() {
		s.runUnlinkKeysTask(taskStatus.ID, resourceID, groupID, keys)
	})
	return taskStatus, nil
}

func (s *KeyImportService) StartHardDeleteKeysTask(keysText string) (*task.Status, error) {
	keys := utils.ParseKeysFromText(keysText)
	if len(keys) == 0 {
		return nil, fmt.Errorf("no valid keys found")
	}
	resourceID := "global_hard_delete" // global lock
	taskStatus, err := s.taskService.StartTask(0, TaskTypeHardDeleteKeys, resourceID, len(keys), time.Hour)
	if err != nil {
		return nil, err
	}
	go s.runTaskWithRecovery(taskStatus.ID, resourceID, func() {
		s.runHardDeleteKeysTask(taskStatus.ID, resourceID, keys)
	})
	return taskStatus, nil
}

func (s *KeyImportService) StartRestoreKeysTask(keysText string) (*task.Status, error) {
	keys := utils.ParseKeysFromText(keysText)
	if len(keys) == 0 {
		return nil, fmt.Errorf("no valid keys found")
	}
	resourceID := "global_restore_keys" // global lock
	taskStatus, err := s.taskService.StartTask(0, TaskTypeRestoreKeys, resourceID, len(keys), time.Hour)
	if err != nil {
		return nil, err
	}
	go s.runTaskWithRecovery(taskStatus.ID, resourceID, func() {
		s.runRestoreKeysTask(taskStatus.ID, resourceID, keys)
	})
	return taskStatus, nil
}
// --- Private Task Runners ---
|
||||
|
||||
func (s *KeyImportService) runAddKeysTask(taskID string, resourceID string, groupID uint, keys []string, validateOnImport bool) {
|
||||
// 步骤 1: 对输入的原始 key 列表进行去重。
|
||||
uniqueKeysMap := make(map[string]struct{})
|
||||
var uniqueKeyStrings []string
|
||||
for _, kStr := range keys {
|
||||
if _, exists := uniqueKeysMap[kStr]; !exists {
|
||||
uniqueKeysMap[kStr] = struct{}{}
|
||||
uniqueKeyStrings = append(uniqueKeyStrings, kStr)
|
||||
}
|
||||
}
|
||||
if len(uniqueKeyStrings) == 0 {
|
||||
s.taskService.EndTaskByID(taskID, resourceID, gin.H{"newly_linked_count": 0, "already_linked_count": 0}, nil)
|
||||
return
|
||||
}
|
||||
// 步骤 2: 确保所有 Key 在主表中存在(创建或恢复),并获取它们完整的实体。
|
||||
keysToEnsure := make([]models.APIKey, len(uniqueKeyStrings))
|
||||
for i, keyStr := range uniqueKeyStrings {
|
||||
keysToEnsure[i] = models.APIKey{APIKey: keyStr}
|
||||
}
|
||||
allKeyModels, err := s.keyRepo.AddKeys(keysToEnsure)
|
||||
if err != nil {
|
||||
s.taskService.EndTaskByID(taskID, resourceID, nil, fmt.Errorf("failed to ensure keys exist: %w", err))
|
||||
return
|
||||
}
|
||||
// 步骤 3: 找出在这些 Key 中,哪些【已经】被链接到了当前分组。
|
||||
alreadyLinkedModels, err := s.keyRepo.GetKeysByValuesAndGroupID(uniqueKeyStrings, groupID)
|
||||
if err != nil {
|
||||
s.taskService.EndTaskByID(taskID, resourceID, nil, fmt.Errorf("failed to check for already linked keys: %w", err))
|
||||
return
|
||||
}
|
||||
alreadyLinkedIDSet := make(map[uint]struct{})
|
||||
for _, key := range alreadyLinkedModels {
|
||||
alreadyLinkedIDSet[key.ID] = struct{}{}
|
||||
}
|
||||
// 步骤 4: 确定【真正需要】被链接到当前分组的 key 列表 (我们的"工作集")。
|
||||
var keysToLink []models.APIKey
|
||||
for _, key := range allKeyModels {
|
||||
if _, exists := alreadyLinkedIDSet[key.ID]; !exists {
|
||||
keysToLink = append(keysToLink, key)
|
||||
}
|
||||
}
|
||||
// 步骤 5: 更新任务的 Total 总量为精确的 "工作集" 大小。
|
||||
if err := s.taskService.UpdateTotalByID(taskID, len(keysToLink)); err != nil {
|
||||
s.logger.WithError(err).Warnf("Failed to update total for task %s", taskID)
|
||||
}
|
||||
// 步骤 6: 分块处理【链接Key到组】的操作,并实时更新进度。
|
||||
if len(keysToLink) > 0 {
|
||||
idsToLink := make([]uint, len(keysToLink))
|
||||
for i, key := range keysToLink {
|
||||
idsToLink[i] = key.ID
|
||||
}
|
||||
for i := 0; i < len(idsToLink); i += chunkSize {
|
||||
end := i + chunkSize
|
||||
if end > len(idsToLink) {
|
||||
end = len(idsToLink)
|
||||
}
|
||||
chunk := idsToLink[i:end]
|
||||
if err := s.keyRepo.LinkKeysToGroup(groupID, chunk); err != nil {
|
||||
s.taskService.EndTaskByID(taskID, resourceID, nil, fmt.Errorf("chunk failed to link keys: %w", err))
|
||||
return
|
||||
}
|
||||
_ = s.taskService.UpdateProgressByID(taskID, i+len(chunk))
|
||||
}
|
||||
}
|
||||
|
||||
// 步骤 7: 准备最终结果并结束任务。
|
||||
result := gin.H{
|
||||
"newly_linked_count": len(keysToLink),
|
||||
"already_linked_count": len(alreadyLinkedIDSet),
|
||||
"total_linked_count": len(allKeyModels),
|
||||
}
|
||||
// 步骤 8: 根据 `validateOnImport` 标志, 发布事件或直接激活 (只对新链接的keys操作)。
|
||||
if len(keysToLink) > 0 {
|
||||
idsToLink := make([]uint, len(keysToLink))
|
||||
for i, key := range keysToLink {
|
||||
idsToLink[i] = key.ID
|
||||
}
|
||||
if validateOnImport {
|
||||
s.publishImportGroupCompletedEvent(groupID, idsToLink)
|
||||
for _, keyID := range idsToLink {
|
||||
s.publishSingleKeyChangeEvent(groupID, keyID, "", models.StatusPendingValidation, "key_linked")
|
||||
}
|
||||
} else {
|
||||
for _, keyID := range idsToLink {
|
||||
if _, err := s.apiKeyService.UpdateMappingStatus(groupID, keyID, models.StatusActive); err != nil {
|
||||
s.logger.Errorf("Failed to directly activate key ID %d in group %d: %v", keyID, groupID, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
s.taskService.EndTaskByID(taskID, resourceID, result, nil)
|
||||
}

// runUnlinkKeysTask
func (s *KeyImportService) runUnlinkKeysTask(taskID, resourceID string, groupID uint, keys []string) {
	uniqueKeysMap := make(map[string]struct{})
	var uniqueKeys []string
	for _, kStr := range keys {
		if _, exists := uniqueKeysMap[kStr]; !exists {
			uniqueKeysMap[kStr] = struct{}{}
			uniqueKeys = append(uniqueKeys, kStr)
		}
	}
	// Step 1: In a single query, find the key entities from the input that actually belong to this group. This is our "working set".
	keysToUnlink, err := s.keyRepo.GetKeysByValuesAndGroupID(uniqueKeys, groupID)
	if err != nil {
		s.taskService.EndTaskByID(taskID, resourceID, nil, fmt.Errorf("failed to find keys to unlink: %w", err))
		return
	}

	if len(keysToUnlink) == 0 {
		result := gin.H{"unlinked_count": 0, "hard_deleted_count": 0, "not_found_count": len(uniqueKeys)}
		s.taskService.EndTaskByID(taskID, resourceID, result, nil)
		return
	}
	idsToUnlink := make([]uint, len(keysToUnlink))
	for i, key := range keysToUnlink {
		idsToUnlink[i] = key.ID
	}
	// Step 2: Update the task's Total to the exact size of the working set.
	if err := s.taskService.UpdateTotalByID(taskID, len(idsToUnlink)); err != nil {
		s.logger.WithError(err).Warnf("Failed to update total for task %s", taskID)
	}
	var totalUnlinked int64
	// Step 3: Unlink keys in chunks, reporting progress.
	for i := 0; i < len(idsToUnlink); i += chunkSize {
		end := i + chunkSize
		if end > len(idsToUnlink) {
			end = len(idsToUnlink)
		}
		chunk := idsToUnlink[i:end]
		unlinked, err := s.keyRepo.UnlinkKeysFromGroup(groupID, chunk)
		if err != nil {
			s.taskService.EndTaskByID(taskID, resourceID, nil, fmt.Errorf("chunk failed: could not unlink keys: %w", err))
			return
		}
		totalUnlinked += unlinked

		for _, keyID := range chunk {
			s.publishSingleKeyChangeEvent(groupID, keyID, models.StatusActive, "", "key_unlinked")
		}

		_ = s.taskService.UpdateProgressByID(taskID, i+len(chunk))
	}

	totalDeleted, err := s.keyRepo.DeleteOrphanKeys()
	if err != nil {
		s.logger.WithError(err).Warn("Orphan key cleanup failed after unlink task.")
	}
	result := gin.H{
		"unlinked_count":     totalUnlinked,
		"hard_deleted_count": totalDeleted,
		"not_found_count":    len(uniqueKeys) - int(totalUnlinked),
	}
	s.taskService.EndTaskByID(taskID, resourceID, result, nil)
}

func (s *KeyImportService) runHardDeleteKeysTask(taskID, resourceID string, keys []string) {
	var totalDeleted int64
	for i := 0; i < len(keys); i += chunkSize {
		end := i + chunkSize
		if end > len(keys) {
			end = len(keys)
		}
		chunk := keys[i:end]

		deleted, err := s.keyRepo.HardDeleteByValues(chunk)
		if err != nil {
			s.taskService.EndTaskByID(taskID, resourceID, nil, fmt.Errorf("failed to hard delete chunk: %w", err))
			return
		}
		totalDeleted += deleted
		_ = s.taskService.UpdateProgressByID(taskID, i+len(chunk))
	}

	result := gin.H{
		"hard_deleted_count": totalDeleted,
		"not_found_count":    int64(len(keys)) - totalDeleted,
	}
	s.taskService.EndTaskByID(taskID, resourceID, result, nil)
	s.publishChangeEvent(0, "keys_hard_deleted") // Global event
}

func (s *KeyImportService) runRestoreKeysTask(taskID, resourceID string, keys []string) {
	var restoredCount int64
	for i := 0; i < len(keys); i += chunkSize {
		end := i + chunkSize
		if end > len(keys) {
			end = len(keys)
		}
		chunk := keys[i:end]

		count, err := s.keyRepo.UpdateMasterStatusByValues(chunk, models.MasterStatusActive)
		if err != nil {
			s.taskService.EndTaskByID(taskID, resourceID, nil, fmt.Errorf("failed to restore chunk: %w", err))
			return
		}
		restoredCount += count
		_ = s.taskService.UpdateProgressByID(taskID, i+len(chunk))
	}
	result := gin.H{
		"restored_count":  restoredCount,
		"not_found_count": int64(len(keys)) - restoredCount,
	}
	s.taskService.EndTaskByID(taskID, resourceID, result, nil)
	s.publishChangeEvent(0, "keys_bulk_restored") // Global event
}

func (s *KeyImportService) publishSingleKeyChangeEvent(groupID, keyID uint, oldStatus, newStatus models.APIKeyStatus, reason string) {
	event := models.KeyStatusChangedEvent{
		GroupID:      groupID,
		KeyID:        keyID,
		OldStatus:    oldStatus,
		NewStatus:    newStatus,
		ChangeReason: reason,
		ChangedAt:    time.Now(),
	}
	eventData, _ := json.Marshal(event)
	if err := s.store.Publish(models.TopicKeyStatusChanged, eventData); err != nil {
		s.logger.WithError(err).WithFields(logrus.Fields{
			"group_id": groupID,
			"key_id":   keyID,
			"reason":   reason,
		}).Error("Failed to publish single key change event.")
	}
}

func (s *KeyImportService) publishChangeEvent(groupID uint, reason string) {
	event := models.KeyStatusChangedEvent{
		GroupID:      groupID,
		ChangeReason: reason,
	}
	eventData, _ := json.Marshal(event)
	_ = s.store.Publish(models.TopicKeyStatusChanged, eventData)
}

func (s *KeyImportService) publishImportGroupCompletedEvent(groupID uint, keyIDs []uint) {
	if len(keyIDs) == 0 {
		return
	}
	event := models.ImportGroupCompletedEvent{
		GroupID:     groupID,
		KeyIDs:      keyIDs,
		CompletedAt: time.Now(),
	}
	eventData, err := json.Marshal(event)
	if err != nil {
		s.logger.WithError(err).Error("Failed to marshal ImportGroupCompletedEvent.")
		return
	}
	if err := s.store.Publish(models.TopicImportGroupCompleted, eventData); err != nil {
		s.logger.WithError(err).Error("Failed to publish ImportGroupCompletedEvent.")
	} else {
		s.logger.Infof("Published ImportGroupCompletedEvent for group %d with %d keys.", groupID, len(keyIDs))
	}
}

// [NEW] StartUnlinkKeysByFilterTask starts a task to unlink keys matching a status filter.
func (s *KeyImportService) StartUnlinkKeysByFilterTask(groupID uint, statuses []string) (*task.Status, error) {
	s.logger.Infof("Starting unlink task for group %d with status filter: %v", groupID, statuses)
	// 1. [New] Find the keys to operate on.
	keyValues, err := s.keyRepo.FindKeyValuesByStatus(groupID, statuses)
	if err != nil {
		return nil, fmt.Errorf("failed to find keys by filter: %w", err)
	}
	if len(keyValues) == 0 {
		return nil, fmt.Errorf("no keys found matching the provided filter")
	}
	// 2. [REUSE] Convert to text and call the existing, robust unlink task logic.
	keysAsText := strings.Join(keyValues, "\n")
	s.logger.Infof("Found %d keys to unlink for group %d.", len(keyValues), groupID)
	return s.StartUnlinkKeysTask(groupID, keysAsText)
}
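
Illustrative sketch (editor addition, not part of this commit): the four task runners above all share the same "de-duplicate, resize the task total, process in chunkSize batches, report progress" shape, guarded by runTaskWithRecovery, whose body is not shown in this excerpt. A minimal standalone version of both patterns might look like the following; processInChunks and safeRun are hypothetical names.

package sketch

import "log"

// processInChunks walks items in fixed-size batches, calling process on each
// batch and report with the running count of handled items.
func processInChunks[T any](items []T, chunkSize int, process func([]T) error, report func(done int)) error {
	for i := 0; i < len(items); i += chunkSize {
		end := i + chunkSize
		if end > len(items) {
			end = len(items)
		}
		if err := process(items[i:end]); err != nil {
			return err // Abort on the first failing chunk, as the runners above do.
		}
		report(end)
	}
	return nil
}

// safeRun mimics what a runTaskWithRecovery-style wrapper presumably does:
// recover from a panic in the task body so one bad task cannot kill the process.
func safeRun(taskID string, fn func()) {
	defer func() {
		if r := recover(); r != nil {
			log.Printf("task %s panicked: %v", taskID, r)
		}
	}()
	fn()
}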

217
internal/service/key_validation_service.go
Normal file
@@ -0,0 +1,217 @@
// Filename: internal/service/key_validation_service.go
package service

import (
	"encoding/json"
	"fmt"
	"gemini-balancer/internal/channel"
	CustomErrors "gemini-balancer/internal/errors"
	"gemini-balancer/internal/models"
	"gemini-balancer/internal/repository"
	"gemini-balancer/internal/settings"
	"gemini-balancer/internal/store"
	"gemini-balancer/internal/task"
	"gemini-balancer/internal/utils"
	"io"
	"net/http"
	"strings"
	"sync"
	"time"

	"github.com/gin-gonic/gin"
	"github.com/sirupsen/logrus"
	"gorm.io/gorm"
)

const (
	TaskTypeTestKeys = "test_keys"
)

type KeyValidationService struct {
	taskService     task.Reporter
	channel         channel.ChannelProxy
	db              *gorm.DB
	SettingsManager *settings.SettingsManager
	groupManager    *GroupManager
	store           store.Store
	keyRepo         repository.KeyRepository
	logger          *logrus.Entry
}

func NewKeyValidationService(ts task.Reporter, ch channel.ChannelProxy, db *gorm.DB, ss *settings.SettingsManager, gm *GroupManager, st store.Store, kr repository.KeyRepository, logger *logrus.Logger) *KeyValidationService {
	return &KeyValidationService{
		taskService:     ts,
		channel:         ch,
		db:              db,
		SettingsManager: ss,
		groupManager:    gm,
		store:           st,
		keyRepo:         kr,
		logger:          logger.WithField("component", "KeyValidationService🧐"),
	}
}

func (s *KeyValidationService) ValidateSingleKey(key *models.APIKey, timeout time.Duration, endpoint string) error {
	if err := s.keyRepo.Decrypt(key); err != nil {
		return fmt.Errorf("failed to decrypt key %d for validation: %w", key.ID, err)
	}
	client := &http.Client{Timeout: timeout}
	req, err := http.NewRequest("GET", endpoint, nil)
	if err != nil {
		s.logger.Errorf("Failed to create request for key validation (ID: %d): %v", key.ID, err)
		return fmt.Errorf("failed to create request: %w", err)
	}

	s.channel.ModifyRequest(req, key) // Use the injected channel to modify the request

	resp, err := client.Do(req)
	if err != nil {
		// This is a network-level error (e.g., timeout, DNS issue)
		return fmt.Errorf("request failed: %w", err)
	}
	defer resp.Body.Close()

	if resp.StatusCode == http.StatusOK {
		return nil // Success
	}

	// Read the body for more error details
	bodyBytes, readErr := io.ReadAll(resp.Body)
	var errorMsg string
	if readErr != nil {
		errorMsg = "Failed to read error response body"
	} else {
		errorMsg = string(bodyBytes)
	}

	// This is a validation failure with a specific HTTP status code
	return &CustomErrors.APIError{
		HTTPStatus: resp.StatusCode,
		Message:    fmt.Sprintf("Validation failed with status %d: %s", resp.StatusCode, errorMsg),
		Code:       "VALIDATION_FAILED",
	}
}

// --- Asynchronous task methods (fully adapted to the new task package) ---
func (s *KeyValidationService) StartTestKeysTask(groupID uint, keysText string) (*task.Status, error) {
	keyStrings := utils.ParseKeysFromText(keysText)
	if len(keyStrings) == 0 {
		return nil, CustomErrors.NewAPIError(CustomErrors.ErrBadRequest, "no valid keys found in input text")
	}
	apiKeyModels, err := s.keyRepo.GetKeysByValues(keyStrings)
	if err != nil {
		return nil, CustomErrors.ParseDBError(err)
	}
	if len(apiKeyModels) == 0 {
		return nil, CustomErrors.NewAPIError(CustomErrors.ErrResourceNotFound, "none of the provided keys were found in the system")
	}
	if err := s.keyRepo.DecryptBatch(apiKeyModels); err != nil {
		s.logger.WithError(err).Error("Failed to batch decrypt keys for validation task.")
		return nil, CustomErrors.NewAPIError(CustomErrors.ErrInternalServer, "Failed to decrypt keys for validation")
	}
	group, ok := s.groupManager.GetGroupByID(groupID)
	if !ok {
		// [FIX] Correctly use the NewAPIError constructor for a missing group.
		return nil, CustomErrors.NewAPIError(CustomErrors.ErrGroupNotFound, fmt.Sprintf("group with id %d not found", groupID))
	}
	opConfig, err := s.groupManager.BuildOperationalConfig(group)
	if err != nil {
		return nil, CustomErrors.NewAPIError(CustomErrors.ErrInternalServer, fmt.Sprintf("failed to build operational config: %v", err))
	}
	resourceID := fmt.Sprintf("group-%d", groupID)
	taskStatus, err := s.taskService.StartTask(groupID, TaskTypeTestKeys, resourceID, len(apiKeyModels), time.Hour)
	if err != nil {
		return nil, err // Pass up the error from task service (e.g., "task already running")
	}
	settings := s.SettingsManager.GetSettings()
	timeout := time.Duration(settings.KeyCheckTimeoutSeconds) * time.Second
	endpoint, err := s.groupManager.BuildKeyCheckEndpoint(groupID)
	if err != nil {
		s.taskService.EndTaskByID(taskStatus.ID, resourceID, nil, err) // End task with error if endpoint fails
		return nil, err
	}
	var concurrency int
	if opConfig.KeyCheckConcurrency != nil {
		concurrency = *opConfig.KeyCheckConcurrency
	} else {
		concurrency = settings.BaseKeyCheckConcurrency
	}
	go s.runTestKeysTask(taskStatus.ID, resourceID, groupID, apiKeyModels, timeout, endpoint, concurrency)
	return taskStatus, nil
}

func (s *KeyValidationService) runTestKeysTask(taskID string, resourceID string, groupID uint, keys []models.APIKey, timeout time.Duration, endpoint string, concurrency int) {
	var wg sync.WaitGroup
	var mu sync.Mutex
	finalResults := make([]models.KeyTestResult, len(keys))
	processedCount := 0
	if concurrency <= 0 {
		concurrency = 10
	}
	type job struct {
		Index int
		Value models.APIKey
	}
	jobs := make(chan job, len(keys))
	for i := 0; i < concurrency; i++ {
		wg.Add(1)
		go func() {
			defer wg.Done()
			for j := range jobs {
				apiKeyModel := j.Value
				keyToValidate := apiKeyModel
				validationErr := s.ValidateSingleKey(&keyToValidate, timeout, endpoint)

				var currentResult models.KeyTestResult
				event := models.RequestFinishedEvent{
					GroupID: groupID,
					KeyID:   apiKeyModel.ID,
				}
				if validationErr == nil {
					currentResult = models.KeyTestResult{Key: apiKeyModel.APIKey, Status: "valid", Message: "Validation successful."}
					event.IsSuccess = true
				} else {
					var apiErr *CustomErrors.APIError
					if CustomErrors.As(validationErr, &apiErr) {
						currentResult = models.KeyTestResult{Key: apiKeyModel.APIKey, Status: "invalid", Message: fmt.Sprintf("Invalid key (HTTP %d): %s", apiErr.HTTPStatus, apiErr.Message)}
						event.Error = apiErr
					} else {
						currentResult = models.KeyTestResult{Key: apiKeyModel.APIKey, Status: "error", Message: "Validation check failed: " + validationErr.Error()}
						event.Error = &CustomErrors.APIError{Message: validationErr.Error()}
					}
					event.IsSuccess = false
				}
				eventData, _ := json.Marshal(event)
				if err := s.store.Publish(models.TopicRequestFinished, eventData); err != nil {
					s.logger.WithError(err).Errorf("Failed to publish RequestFinishedEvent for validation of key ID %d", apiKeyModel.ID)
				}

				mu.Lock()
				finalResults[j.Index] = currentResult
				processedCount++
				_ = s.taskService.UpdateProgressByID(taskID, processedCount)
				mu.Unlock()
			}
		}()
	}
	for i, k := range keys {
		jobs <- job{Index: i, Value: k}
	}
	close(jobs)
	wg.Wait()
	s.taskService.EndTaskByID(taskID, resourceID, gin.H{"results": finalResults}, nil)
}

func (s *KeyValidationService) StartTestKeysByFilterTask(groupID uint, statuses []string) (*task.Status, error) {
	s.logger.Infof("Starting test task for group %d with status filter: %v", groupID, statuses)
	keyValues, err := s.keyRepo.FindKeyValuesByStatus(groupID, statuses)
	if err != nil {
		return nil, CustomErrors.ParseDBError(err)
	}
	if len(keyValues) == 0 {
		return nil, CustomErrors.NewAPIError(CustomErrors.ErrNotFound, "No keys found matching the filter criteria.")
	}
	keysAsText := strings.Join(keyValues, "\n")
	s.logger.Infof("Found %d keys to validate for group %d.", len(keyValues), groupID)
	return s.StartTestKeysTask(groupID, keysAsText)
}
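
Illustrative sketch (editor addition): runTestKeysTask above uses a bounded worker pool — a buffered jobs channel, N workers draining it, and results written back by job index so output order matches input order. Stripped of the service dependencies, the pattern is:

package sketch

import "sync"

// runPool applies work to every input with at most concurrency goroutines,
// preserving input order in the returned slice.
func runPool[In, Out any](inputs []In, concurrency int, work func(In) Out) []Out {
	if concurrency <= 0 {
		concurrency = 10 // Same fallback as runTestKeysTask.
	}
	type job struct {
		index int
		value In
	}
	results := make([]Out, len(inputs))
	jobs := make(chan job, len(inputs))
	var wg sync.WaitGroup
	for i := 0; i < concurrency; i++ {
		wg.Add(1)
		go func() {
			defer wg.Done()
			for j := range jobs {
				results[j.index] = work(j.value) // Each index is written exactly once; no lock needed here.
			}
		}()
	}
	for i, in := range inputs {
		jobs <- job{index: i, value: in}
	}
	close(jobs)
	wg.Wait()
	return results
}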

65
internal/service/log_service.go
Normal file
@@ -0,0 +1,65 @@
package service

import (
	"gemini-balancer/internal/models"
	"strconv"

	"github.com/gin-gonic/gin"
	"gorm.io/gorm"
)

type LogService struct {
	db *gorm.DB
}

func NewLogService(db *gorm.DB) *LogService {
	return &LogService{db: db}
}

// Record writes a single log entry to the database (TODO: keep the simple implementation for now, refactor to async later).
func (s *LogService) Record(log *models.RequestLog) error {
	return s.db.Create(log).Error
}

func (s *LogService) GetLogs(c *gin.Context) ([]models.RequestLog, error) {
	var logs []models.RequestLog

	query := s.db.Model(&models.RequestLog{}).Scopes(s.filtersScope(c)).Order("request_time desc")

	// Simple pagination (TODO: can be made more sophisticated later)
	page, _ := strconv.Atoi(c.DefaultQuery("page", "1"))
	pageSize, _ := strconv.Atoi(c.DefaultQuery("page_size", "20"))
	offset := (page - 1) * pageSize

	// Execute the query
	err := query.Limit(pageSize).Offset(offset).Find(&logs).Error
	if err != nil {
		return nil, err
	}

	return logs, nil
}

func (s *LogService) filtersScope(c *gin.Context) func(db *gorm.DB) *gorm.DB {
	return func(db *gorm.DB) *gorm.DB {
		if modelName := c.Query("model_name"); modelName != "" {
			db = db.Where("model_name = ?", modelName)
		}
		if isSuccessStr := c.Query("is_success"); isSuccessStr != "" {
			if isSuccess, err := strconv.ParseBool(isSuccessStr); err == nil {
				db = db.Where("is_success = ?", isSuccess)
			}
		}
		if statusCodeStr := c.Query("status_code"); statusCodeStr != "" {
			if statusCode, err := strconv.Atoi(statusCodeStr); err == nil {
				db = db.Where("status_code = ?", statusCode)
			}
		}
		if keyIDStr := c.Query("key_id"); keyIDStr != "" {
			if keyID, err := strconv.ParseUint(keyIDStr, 10, 64); err == nil {
				db = db.Where("key_id = ?", keyID)
			}
		}
		return db
	}
}
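
Illustrative sketch (editor addition): filtersScope above returns a closure, which is the standard way to feed query-string filters into GORM's Scopes. A handler wiring one such scope together with a fixed page size might look like this (the route wiring and local RequestLog type are assumptions, not this repo's models):

package sketch

import (
	"github.com/gin-gonic/gin"
	"gorm.io/gorm"
)

type RequestLog struct {
	ID        uint
	ModelName string
	IsSuccess bool
}

// byModel is a reusable scope, analogous to one branch of filtersScope.
func byModel(name string) func(*gorm.DB) *gorm.DB {
	return func(db *gorm.DB) *gorm.DB {
		if name == "" {
			return db // No-op when the filter is absent.
		}
		return db.Where("model_name = ?", name)
	}
}

func listLogs(db *gorm.DB) gin.HandlerFunc {
	return func(c *gin.Context) {
		var logs []RequestLog
		err := db.Model(&RequestLog{}).
			Scopes(byModel(c.Query("model_name"))). // Scopes compose left to right.
			Order("id desc").
			Limit(20).
			Find(&logs).Error
		if err != nil {
			c.JSON(500, gin.H{"error": err.Error()})
			return
		}
		c.JSON(200, logs)
	}
}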

267
internal/service/resource_service.go
Normal file
@@ -0,0 +1,267 @@
// Filename: internal/service/resource_service.go

package service

import (
	"errors"
	apperrors "gemini-balancer/internal/errors"
	"gemini-balancer/internal/models"
	"gemini-balancer/internal/repository"
	"gemini-balancer/internal/settings"
	"sort"
	"sync"
	"time"

	"github.com/sirupsen/logrus"
)

var (
	ErrNoResourceAvailable = errors.New("no available resource found for the request")
)

type RequestResources struct {
	KeyGroup         *models.KeyGroup
	APIKey           *models.APIKey
	UpstreamEndpoint *models.UpstreamEndpoint
	ProxyConfig      *models.ProxyConfig
	RequestConfig    *models.RequestConfig
}

type ResourceService struct {
	settingsManager *settings.SettingsManager
	groupManager    *GroupManager
	keyRepo         repository.KeyRepository
	apiKeyService   *APIKeyService
	logger          *logrus.Entry
	initOnce        sync.Once
}

func NewResourceService(
	sm *settings.SettingsManager,
	gm *GroupManager,
	kr repository.KeyRepository,
	aks *APIKeyService,
	logger *logrus.Logger,
) *ResourceService {
	logger.Debugf("[FORENSIC PROBE | INJECTION | ResourceService] Received 'keyRepo' param. Fingerprint: %p", kr)
	rs := &ResourceService{
		settingsManager: sm,
		groupManager:    gm,
		keyRepo:         kr,
		apiKeyService:   aks,
		logger:          logger.WithField("component", "ResourceService📦️"),
	}

	rs.initOnce.Do(func() {
		go rs.preWarmCache(logger)
	})
	return rs
}

// --- [Mode 1: smart aggregation] ---
func (s *ResourceService) GetResourceFromBasePool(authToken *models.AuthToken, modelName string) (*RequestResources, error) {
	log := s.logger.WithFields(logrus.Fields{"token_id": authToken.ID, "model_name": modelName, "mode": "BasePool"})
	log.Debug("Entering BasePool resource acquisition.")
	// 1. Filter all eligible candidate groups and sort them by priority.
	candidateGroups := s.filterAndSortCandidateGroups(modelName, authToken.AllowedGroups)
	if len(candidateGroups) == 0 {
		log.Warn("No candidate groups found for BasePool construction.")
		return nil, apperrors.ErrNoKeysAvailable
	}
	// 2. Select one key from the BasePool according to the global system strategy.
	basePool := &repository.BasePool{
		CandidateGroups: candidateGroups,
		PollingStrategy: s.settingsManager.GetSettings().PollingStrategy,
	}
	apiKey, selectedGroup, err := s.keyRepo.SelectOneActiveKeyFromBasePool(basePool)
	if err != nil {
		log.WithError(err).Warn("Failed to select a key from the BasePool.")
		return nil, apperrors.ErrNoKeysAvailable
	}
	// 3. Assemble the final resources.
	// [Important] In this mode RequestConfig is always empty, to guarantee transparency.
	resources, err := s.assembleRequestResources(selectedGroup, apiKey)
	if err != nil {
		log.WithError(err).Error("Failed to assemble resources after selecting key from BasePool.")
		return nil, err
	}
	resources.RequestConfig = &models.RequestConfig{} // Forced empty
	log.Infof("Successfully selected KeyID %d from GroupID %d for the BasePool.", apiKey.ID, selectedGroup.ID)
	return resources, nil
}

// --- [Mode 2: precise routing] ---
func (s *ResourceService) GetResourceFromGroup(authToken *models.AuthToken, groupName string) (*RequestResources, error) {
	log := s.logger.WithFields(logrus.Fields{"token_id": authToken.ID, "group_name": groupName, "mode": "PreciseRoute"})
	log.Debug("Entering PreciseRoute resource acquisition.")

	targetGroup, ok := s.groupManager.GetGroupByName(groupName)

	if !ok {
		return nil, apperrors.NewAPIError(apperrors.ErrGroupNotFound, "The specified group does not exist.")
	}

	if !s.isTokenAllowedForGroup(authToken, targetGroup.ID) {
		return nil, apperrors.NewAPIError(apperrors.ErrPermissionDenied, "Token does not have permission to access this group.")
	}

	apiKey, _, err := s.keyRepo.SelectOneActiveKey(targetGroup)
	if err != nil {
		log.WithError(err).Warn("Failed to select a key from the precisely targeted group.")
		return nil, apperrors.ErrNoKeysAvailable
	}

	resources, err := s.assembleRequestResources(targetGroup, apiKey)
	if err != nil {
		log.WithError(err).Error("Failed to assemble resources for precise route.")
		return nil, err
	}
	resources.RequestConfig = targetGroup.RequestConfig

	log.Infof("Successfully selected KeyID %d by precise routing to GroupID %d.", apiKey.ID, targetGroup.ID)
	return resources, nil
}

func (s *ResourceService) GetAllowedModelsForToken(authToken *models.AuthToken) []string {
	allGroups := s.groupManager.GetAllGroups()
	if len(allGroups) == 0 {
		return []string{}
	}
	allowedModelsSet := make(map[string]struct{})
	if authToken.IsAdmin {
		for _, group := range allGroups {
			for _, modelMapping := range group.AllowedModels {
				allowedModelsSet[modelMapping.ModelName] = struct{}{}
			}
		}
	} else {
		allowedGroupIDs := make(map[uint]bool)
		for _, ag := range authToken.AllowedGroups {
			allowedGroupIDs[ag.ID] = true
		}
		for _, group := range allGroups {
			if _, ok := allowedGroupIDs[group.ID]; ok {
				for _, modelMapping := range group.AllowedModels {
					allowedModelsSet[modelMapping.ModelName] = struct{}{}
				}
			}
		}
	}
	result := make([]string, 0, len(allowedModelsSet))
	for modelName := range allowedModelsSet {
		result = append(result, modelName)
	}
	sort.Strings(result)
	return result
}

func (s *ResourceService) assembleRequestResources(group *models.KeyGroup, apiKey *models.APIKey) (*RequestResources, error) {
	selectedUpstream := s.selectUpstreamForGroup(group)
	if selectedUpstream == nil {
		return nil, apperrors.NewAPIError(apperrors.ErrConfigurationError, "Selected group has no valid upstream and no global default is set.")
	}
	var proxyConfig *models.ProxyConfig
	// [Note] The proxy logic needs a proxyModule instance; leave it nil for now. The dependency must be re-injected later.
	// if group.EnableProxy && s.proxyModule != nil {
	//	var err error
	//	proxyConfig, err = s.proxyModule.AssignProxyIfNeeded(apiKey)
	//	if err != nil {
	//		s.logger.WithError(err).Warnf("Failed to assign proxy for API key %d.", apiKey.ID)
	//	}
	// }
	return &RequestResources{
		KeyGroup:         group,
		APIKey:           apiKey,
		UpstreamEndpoint: selectedUpstream,
		ProxyConfig:      proxyConfig,
	}, nil
}

func (s *ResourceService) selectUpstreamForGroup(group *models.KeyGroup) *models.UpstreamEndpoint {
	if len(group.AllowedUpstreams) > 0 {
		return group.AllowedUpstreams[0]
	}
	globalSettings := s.settingsManager.GetSettings()
	if globalSettings.DefaultUpstreamURL != "" {
		return &models.UpstreamEndpoint{URL: globalSettings.DefaultUpstreamURL, Status: "active"}
	}
	return nil
}

func (s *ResourceService) preWarmCache(logger *logrus.Logger) error {
	time.Sleep(2 * time.Second)
	s.logger.Info("Performing initial key cache pre-warming...")
	if err := s.keyRepo.LoadAllKeysToStore(); err != nil {
		logger.WithError(err).Error("Failed to perform initial key cache pre-warming.")
		return err
	}
	s.logger.Info("Initial key cache pre-warming completed successfully.")
	return nil
}

func (s *ResourceService) GetResourcesForRequest(modelName string, allowedGroups []*models.KeyGroup) (*RequestResources, error) {
	return nil, errors.New("GetResourcesForRequest is deprecated; use GetResourceFromBasePool or GetResourceFromGroup")
}

func (s *ResourceService) filterAndSortCandidateGroups(modelName string, allowedGroupsFromToken []*models.KeyGroup) []*models.KeyGroup {
	allGroupsFromCache := s.groupManager.GetAllGroups()
	var candidateGroups []*models.KeyGroup
	// 1. Determine the permission scope.
	allowedGroupIDs := make(map[uint]bool)
	isTokenRestricted := len(allowedGroupsFromToken) > 0
	if isTokenRestricted {
		for _, ag := range allowedGroupsFromToken {
			allowedGroupIDs[ag.ID] = true
		}
	}
	// 2. Filter.
	for _, group := range allGroupsFromCache {
		// Check token permissions.
		if isTokenRestricted && !allowedGroupIDs[group.ID] {
			continue
		}
		// Check whether the model is allowed.
		isModelAllowed := false
		if len(group.AllowedModels) == 0 { // A group with no model restrictions allows everything.
			isModelAllowed = true
		} else {
			for _, m := range group.AllowedModels {
				if m.ModelName == modelName {
					isModelAllowed = true
					break
				}
			}
		}
		if isModelAllowed {
			candidateGroups = append(candidateGroups, group)
		}
	}

	// 3. Sort ascending by the Order field.
	sort.SliceStable(candidateGroups, func(i, j int) bool {
		return candidateGroups[i].Order < candidateGroups[j].Order
	})
	return candidateGroups
}

func (s *ResourceService) isTokenAllowedForGroup(authToken *models.AuthToken, groupID uint) bool {
	if authToken.IsAdmin {
		return true
	}
	for _, allowedGroup := range authToken.AllowedGroups {
		if allowedGroup.ID == groupID {
			return true
		}
	}
	return false
}

func (s *ResourceService) ReportRequestResult(resources *RequestResources, success bool, apiErr *apperrors.APIError) {
	if resources == nil || resources.KeyGroup == nil || resources.APIKey == nil {
		return
	}
	s.apiKeyService.HandleRequestResult(resources.KeyGroup, resources.APIKey, success, apiErr)
}
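
Illustrative sketch (editor addition): the filter-then-sort logic in filterAndSortCandidateGroups is easy to exercise in isolation. A self-contained version with simplified stand-in types shows the two rules — an unrestricted token sees every group, and a group with no model list allows every model — plus the stable sort on Order:

package sketch

import "sort"

type group struct {
	ID     uint
	Order  int
	Models []string // Empty means "no restriction".
}

func candidates(modelName string, allowedIDs map[uint]bool, all []group) []group {
	var out []group
	for _, g := range all {
		if len(allowedIDs) > 0 && !allowedIDs[g.ID] { // The token restriction applies only when the set is non-empty.
			continue
		}
		ok := len(g.Models) == 0
		for _, m := range g.Models {
			if m == modelName {
				ok = true
				break
			}
		}
		if ok {
			out = append(out, g)
		}
	}
	// Stable sort preserves insertion order among groups with equal Order.
	sort.SliceStable(out, func(i, j int) bool { return out[i].Order < out[j].Order })
	return out
}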

83
internal/service/security_service.go
Normal file
@@ -0,0 +1,83 @@
// Filename: internal/service/security_service.go

package service

import (
	"context"
	"crypto/sha256" // [NEW] Import crypto library for hashing
	"encoding/hex"  // [NEW] Import hex encoding
	"fmt"
	"gemini-balancer/internal/models"
	"gemini-balancer/internal/repository" // [NEW] Import repository
	"gemini-balancer/internal/settings"
	"gemini-balancer/internal/store"

	"github.com/sirupsen/logrus"
	"gorm.io/gorm"
)

const loginAttemptsKey = "security:login_attempts"

type SecurityService struct {
	repo            repository.AuthTokenRepository
	store           store.Store
	SettingsManager *settings.SettingsManager
	logger          *logrus.Entry
}

// NewSecurityService signature updated to accept the repository.
func NewSecurityService(repo repository.AuthTokenRepository, store store.Store, settingsManager *settings.SettingsManager, logger *logrus.Logger) *SecurityService {
	return &SecurityService{
		repo:            repo,
		store:           store,
		SettingsManager: settingsManager,
		logger:          logger.WithField("component", "SecurityService🛡️"),
	}
}

// AuthenticateToken is now secure and efficient.
func (s *SecurityService) AuthenticateToken(tokenValue string) (*models.AuthToken, error) {
	if tokenValue == "" {
		return nil, gorm.ErrRecordNotFound
	}
	// [REFACTORED]
	// 1. Hash the incoming plaintext token.
	hash := sha256.Sum256([]byte(tokenValue))
	tokenHash := hex.EncodeToString(hash[:])

	// 2. Delegate the lookup to the repository using the hash.
	return s.repo.GetTokenByHashedValue(tokenHash)
}

// IsIPBanned
func (s *SecurityService) IsIPBanned(ctx context.Context, ip string) (bool, error) {
	banKey := fmt.Sprintf("banned_ip:%s", ip)
	return s.store.Exists(banKey)
}

// RecordFailedLoginAttempt
func (s *SecurityService) RecordFailedLoginAttempt(ctx context.Context, ip string) error {
	if !s.SettingsManager.IsIPBanEnabled() {
		return nil
	}

	count, err := s.store.HIncrBy(loginAttemptsKey, ip, 1)
	if err != nil {
		return err
	}

	maxAttempts := s.SettingsManager.GetMaxLoginAttempts()
	if count >= int64(maxAttempts) {
		banDuration := s.SettingsManager.GetIPBanDuration()
		banKey := fmt.Sprintf("banned_ip:%s", ip)

		if err := s.store.Set(banKey, []byte("1"), banDuration); err != nil {
			return err
		}
		s.logger.Warnf("IP BANNED: IP [%s] has been banned for %v due to excessive failed login attempts.", ip, banDuration)

		s.store.HDel(loginAttemptsKey, ip)
	}

	return nil
}
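
Illustrative sketch (editor addition): AuthenticateToken above stores and looks up only SHA-256 digests, never plaintext, so the issuing side must hash the same way. A minimal round trip, with a plain map standing in for AuthTokenRepository:

package main

import (
	"crypto/sha256"
	"encoding/hex"
	"fmt"
)

func hashToken(plaintext string) string {
	sum := sha256.Sum256([]byte(plaintext))
	return hex.EncodeToString(sum[:])
}

func main() {
	stored := map[string]string{} // hash -> token name; stands in for the repository.

	// Issue: persist only the hash.
	stored[hashToken("sk-demo-123")] = "demo token"

	// Authenticate: hash the presented value and look it up.
	if name, ok := stored[hashToken("sk-demo-123")]; ok {
		fmt.Println("authenticated:", name)
	}
}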

196
internal/service/stats_service.go
Normal file
@@ -0,0 +1,196 @@
// Filename: internal/service/stats_service.go
package service

import (
	"encoding/json"
	"fmt"
	"gemini-balancer/internal/models"
	"gemini-balancer/internal/repository"
	"gemini-balancer/internal/store"
	"time"

	"github.com/sirupsen/logrus"
	"gorm.io/gorm"
	"gorm.io/gorm/clause"
)

type StatsService struct {
	db       *gorm.DB
	store    store.Store
	keyRepo  repository.KeyRepository
	logger   *logrus.Entry
	stopChan chan struct{}
}

func NewStatsService(db *gorm.DB, s store.Store, repo repository.KeyRepository, logger *logrus.Logger) *StatsService {
	return &StatsService{
		db:       db,
		store:    s,
		keyRepo:  repo,
		logger:   logger.WithField("component", "StatsService"),
		stopChan: make(chan struct{}),
	}
}

func (s *StatsService) Start() {
	s.logger.Info("Starting event listener for stats maintenance.")
	sub, err := s.store.Subscribe(models.TopicKeyStatusChanged)
	if err != nil {
		s.logger.Fatalf("Failed to subscribe to topic %s: %v", models.TopicKeyStatusChanged, err)
		return
	}
	go func() {
		defer sub.Close()
		for {
			select {
			case msg := <-sub.Channel():
				var event models.KeyStatusChangedEvent
				if err := json.Unmarshal(msg.Payload, &event); err != nil {
					s.logger.Errorf("Failed to unmarshal KeyStatusChangedEvent: %v", err)
					continue
				}
				s.handleKeyStatusChange(&event)
			case <-s.stopChan:
				s.logger.Info("Stopping stats event listener.")
				return
			}
		}
	}()
}

func (s *StatsService) Stop() {
	close(s.stopChan)
}

func (s *StatsService) handleKeyStatusChange(event *models.KeyStatusChangedEvent) {
	if event.GroupID == 0 {
		s.logger.Warnf("Received KeyStatusChangedEvent with no GroupID. Reason: %s, KeyID: %d. Skipping.", event.ChangeReason, event.KeyID)
		return
	}
	statsKey := fmt.Sprintf("stats:group:%d", event.GroupID)
	s.logger.Infof("Handling key status change for Group %d, KeyID: %d, Reason: %s", event.GroupID, event.KeyID, event.ChangeReason)

	switch event.ChangeReason {
	case "key_unlinked", "key_hard_deleted":
		if event.OldStatus != "" {
			s.store.HIncrBy(statsKey, "total_keys", -1)
			s.store.HIncrBy(statsKey, fmt.Sprintf("%s_keys", event.OldStatus), -1)
		} else {
			s.logger.Warnf("Received '%s' event for group %d without OldStatus, forcing recalculation.", event.ChangeReason, event.GroupID)
			s.RecalculateGroupKeyStats(event.GroupID)
		}
	case "key_linked":
		if event.NewStatus != "" {
			s.store.HIncrBy(statsKey, "total_keys", 1)
			s.store.HIncrBy(statsKey, fmt.Sprintf("%s_keys", event.NewStatus), 1)
		} else {
			s.logger.Warnf("Received 'key_linked' event for group %d without NewStatus, forcing recalculation.", event.GroupID)
			s.RecalculateGroupKeyStats(event.GroupID)
		}
	case "manual_update", "error_threshold_reached", "key_recovered", "invalid_api_key":
		s.store.HIncrBy(statsKey, fmt.Sprintf("%s_keys", event.OldStatus), -1)
		s.store.HIncrBy(statsKey, fmt.Sprintf("%s_keys", event.NewStatus), 1)
	default:
		s.logger.Warnf("Unhandled event reason '%s' for group %d, forcing recalculation.", event.ChangeReason, event.GroupID)
		s.RecalculateGroupKeyStats(event.GroupID)
	}
}

func (s *StatsService) RecalculateGroupKeyStats(groupID uint) error {
	s.logger.Warnf("Performing full recalculation for group %d key stats.", groupID)
	var results []struct {
		Status models.APIKeyStatus
		Count  int64
	}
	if err := s.db.Model(&models.GroupAPIKeyMapping{}).
		Where("key_group_id = ?", groupID).
		Select("status, COUNT(*) as count").
		Group("status").
		Scan(&results).Error; err != nil {
		return err
	}
	statsKey := fmt.Sprintf("stats:group:%d", groupID)

	updates := make(map[string]interface{})
	totalKeys := int64(0)
	for _, res := range results {
		updates[fmt.Sprintf("%s_keys", res.Status)] = res.Count
		totalKeys += res.Count
	}
	updates["total_keys"] = totalKeys

	if err := s.store.Del(statsKey); err != nil {
		s.logger.WithError(err).Warnf("Failed to delete stale stats key for group %d before recalculation.", groupID)
	}
	if err := s.store.HSet(statsKey, updates); err != nil {
		return fmt.Errorf("failed to HSet recalculated stats for group %d: %w", groupID, err)
	}
	s.logger.Infof("Successfully recalculated stats for group %d using HSet.", groupID)
	return nil
}

func (s *StatsService) GetDashboardStats() (*models.DashboardStatsResponse, error) {
	// TODO logic:
	// 1. Fetch every group's key stats from Redis (HGetAll).
	// 2. Fetch the last 24 hours of request counts and error rates from the stats_hourly table.
	// 3. Combine them into a DashboardStatsResponse.
	// ... The concrete implementation can live in DashboardQueryService;
	// here we first make sure StatsService's core duty (maintaining the cache) is covered.
	// Return an empty object for now so the code compiles.

	// Pseudocode:
	// keyCounts, _ := s.store.HGetAll("stats:global:keys")
	// ...

	return &models.DashboardStatsResponse{}, nil
}

func (s *StatsService) AggregateHourlyStats() error {
	s.logger.Info("Starting aggregation of the last hour's request data...")
	now := time.Now()
	endTime := now.Truncate(time.Hour)       // e.g. 15:23 -> 15:00
	startTime := endTime.Add(-1 * time.Hour) // 15:00 -> 14:00

	s.logger.Infof("Aggregating data for time window: [%s, %s)", startTime.Format(time.RFC3339), endTime.Format(time.RFC3339))
	type aggregationResult struct {
		GroupID          uint
		ModelName        string
		RequestCount     int64
		SuccessCount     int64
		PromptTokens     int64
		CompletionTokens int64
	}
	var results []aggregationResult
	err := s.db.Model(&models.RequestLog{}).
		Select("group_id, model_name, COUNT(*) as request_count, SUM(CASE WHEN is_success = true THEN 1 ELSE 0 END) as success_count, SUM(prompt_tokens) as prompt_tokens, SUM(completion_tokens) as completion_tokens").
		Where("request_time >= ? AND request_time < ?", startTime, endTime).
		Group("group_id, model_name").
		Scan(&results).Error
	if err != nil {
		return fmt.Errorf("failed to query aggregation data from request_logs: %w", err)
	}
	if len(results) == 0 {
		s.logger.Info("No request logs found in the last hour to aggregate. Skipping.")
		return nil
	}

	s.logger.Infof("Found %d aggregated data rows to insert/update.", len(results))

	var hourlyStats []models.StatsHourly
	for _, res := range results {
		hourlyStats = append(hourlyStats, models.StatsHourly{
			Time:             startTime, // Every record is stamped with the start of the hour.
			GroupID:          res.GroupID,
			ModelName:        res.ModelName,
			RequestCount:     res.RequestCount,
			SuccessCount:     res.SuccessCount,
			PromptTokens:     res.PromptTokens,
			CompletionTokens: res.CompletionTokens,
		})
	}

	return s.db.Clauses(clause.OnConflict{
		Columns:   []clause.Column{{Name: "time"}, {Name: "group_id"}, {Name: "model_name"}},
		DoUpdates: clause.AssignmentColumns([]string{"request_count", "success_count", "prompt_tokens", "completion_tokens"}),
	}).Create(&hourlyStats).Error
}
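
Illustrative sketch (editor addition): AggregateHourlyStats aggregates the window [endTime-1h, endTime) where endTime is the current time truncated to the hour, so it is meant to run shortly after each hour boundary. How the scheduler invokes it is not shown in this excerpt; one plausible driver waits for the next boundary and then ticks hourly:

package sketch

import "time"

// runHourly fires job just after every hour boundary until stop is closed.
func runHourly(job func() error, stop <-chan struct{}) {
	// Sleep until slightly past the next full hour so the previous hour's logs are complete.
	next := time.Now().Truncate(time.Hour).Add(time.Hour + time.Minute)
	timer := time.NewTimer(time.Until(next))
	defer timer.Stop()
	for {
		select {
		case <-timer.C:
			_ = job() // e.g. statsService.AggregateHourlyStats; errors would be logged in practice.
			timer.Reset(time.Hour)
		case <-stop:
			return
		}
	}
}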

72
internal/service/token_manager.go
Normal file
@@ -0,0 +1,72 @@
// Filename: internal/service/token_manager.go

package service

import (
	"fmt"
	"gemini-balancer/internal/models"
	"gemini-balancer/internal/repository"
	"gemini-balancer/internal/store"
	"gemini-balancer/internal/syncer"

	"github.com/sirupsen/logrus"
)

const TopicTokenChanged = "events:token_changed"

type TokenManager struct {
	repo   repository.AuthTokenRepository
	syncer *syncer.CacheSyncer[[]*models.AuthToken]
	logger *logrus.Entry
}

// NewTokenManager's signature is updated to accept the new repository.
func NewTokenManager(repo repository.AuthTokenRepository, store store.Store, logger *logrus.Logger) (*TokenManager, error) {
	tm := &TokenManager{
		repo:   repo,
		logger: logger.WithField("component", "TokenManager🔐"),
	}

	tokenLoader := func() ([]*models.AuthToken, error) {
		tm.logger.Info("Loading all auth tokens via repository...")
		tokens, err := tm.repo.GetAllTokensWithGroups()
		if err != nil {
			return nil, fmt.Errorf("failed to load auth tokens from repo: %w", err)
		}
		tm.logger.Infof("Successfully loaded and decrypted %d auth tokens into cache.", len(tokens))
		return tokens, nil
	}

	s, err := syncer.NewCacheSyncer(tokenLoader, store, TopicTokenChanged)
	if err != nil {
		return nil, fmt.Errorf("failed to create token manager syncer: %w", err)
	}
	tm.syncer = s

	return tm, nil
}

func (tm *TokenManager) GetAllTokens() []*models.AuthToken {
	return tm.syncer.Get()
}

// BatchUpdateTokens is now a thin wrapper around the repository method.
func (tm *TokenManager) BatchUpdateTokens(incomingTokens []*models.TokenUpdateRequest) error {
	tm.logger.Info("Delegating BatchUpdateTokens to repository...")

	if err := tm.repo.BatchUpdateTokens(incomingTokens); err != nil {
		tm.logger.Errorf("Repository failed to batch update tokens: %v", err)
		return err
	}

	tm.logger.Info("BatchUpdateTokens finished successfully. Invalidating cache.")
	return tm.Invalidate()
}

func (tm *TokenManager) Invalidate() error {
	return tm.syncer.Invalidate()
}

func (tm *TokenManager) Stop() {
	tm.syncer.Stop()
}
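
Illustrative sketch (editor addition): TokenManager delegates all caching to syncer.CacheSyncer, whose implementation is outside this excerpt. Judging only from its call sites here (NewCacheSyncer(loader, store, topic), Get, Invalidate, Stop), its shape is presumably something like the following simplified, store-free version — an assumption, not the real package:

package sketch

import "sync"

// cacheSyncer caches the result of loader and rebuilds it on Invalidate.
// The real CacheSyncer also subscribes to a pub/sub topic so that an
// Invalidate on one node refreshes every node; that part is omitted here.
type cacheSyncer[T any] struct {
	mu     sync.RWMutex
	value  T
	loader func() (T, error)
}

func newCacheSyncer[T any](loader func() (T, error)) (*cacheSyncer[T], error) {
	c := &cacheSyncer[T]{loader: loader}
	if err := c.Invalidate(); err != nil { // Load eagerly so Get never sees a zero value.
		return nil, err
	}
	return c, nil
}

func (c *cacheSyncer[T]) Get() T {
	c.mu.RLock()
	defer c.mu.RUnlock()
	return c.value
}

func (c *cacheSyncer[T]) Invalidate() error {
	v, err := c.loader()
	if err != nil {
		return err
	}
	c.mu.Lock()
	c.value = v
	c.mu.Unlock()
	return nil
}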

107
internal/settings/manager.go
Normal file
@@ -0,0 +1,107 @@
// Filename: internal/settings/manager.go
package settings

import (
	"encoding/json"
	"fmt"
	"gemini-balancer/internal/models"
	"gemini-balancer/internal/pkg/reflectutil"
	"reflect"
	"time"
)

// ========= Convenience Accessors =========
func (sm *SettingsManager) IsIPBanEnabled() bool {
	return sm.GetSettings().EnableIPBanning
}

func (sm *SettingsManager) GetMaxLoginAttempts() int {
	return sm.GetSettings().MaxLoginAttempts
}

func (sm *SettingsManager) GetIPBanDuration() time.Duration {
	minutes := sm.GetSettings().IPBanDurationMinutes
	return time.Duration(minutes) * time.Minute
}

// GetSettingsCopy returns a value copy of the current system settings, guaranteeing thread safety.
func (sm *SettingsManager) GetSettingsCopy() models.SystemSettings {
	// Fetch the current settings pointer from the syncer.
	currentSettingsPtr := sm.GetSettings()
	if currentSettingsPtr == nil {
		// Before the syncer finishes initializing, return a safe default value.
		return *defaultSystemSettings()
	}
	// Return the pointed-to value; Go automatically makes a copy.
	return *currentSettingsPtr
}

// ========= Helpers & Debugging =========
// DisplaySettings now takes a parameter so settings can be printed right after loading.
func (sm *SettingsManager) DisplaySettings(settings *models.SystemSettings) {
	if settings == nil {
		sm.logger.Warn("Cannot display settings, current settings object is nil.")
		return
	}
	sm.logger.Info("")
	sm.logger.Info("========= Runtime System Settings (from SettingsManager) =========")
	sm.logger.Infof(" - Request Timeout: %d seconds", settings.RequestTimeoutSeconds)
	sm.logger.Infof(" - Connect Timeout: %d seconds", settings.ConnectTimeoutSeconds)
	sm.logger.Infof(" - Max Retries: %d", settings.MaxRetries)
	sm.logger.Infof(" - Blacklist Threshold: %d", settings.BlacklistThreshold)
	sm.logger.Info("==================================================================")
	sm.logger.Info("")
}

// defaultSystemSettings
func defaultSystemSettings() *models.SystemSettings {
	settings := &models.SystemSettings{}
	settings.CustomHeaders = make(map[string]string)

	v := reflect.ValueOf(settings).Elem()
	t := v.Type()
	for i := 0; i < t.NumField(); i++ {
		field := t.Field(i)
		fieldValue := v.Field(i)
		// Only primitive fields are handled here.
		kind := fieldValue.Kind()
		if kind == reflect.Int || kind == reflect.Int64 || kind == reflect.String || kind == reflect.Bool {
			defaultValue := field.Tag.Get("default")
			if defaultValue != "" {
				if err := reflectutil.SetFieldFromString(fieldValue, defaultValue); err != nil {
					panic(fmt.Sprintf("FATAL: Invalid default tag for primitive field %s ('%s'): %v", field.Name, defaultValue, err))
				}
			}
		}
	}
	return settings
}

// parseAndSetField parses a string from the database and unpacks it into the correct Go type.
func parseAndSetField(fieldValue reflect.Value, dbValue string) error {
	if !fieldValue.CanSet() {
		return fmt.Errorf("field is not settable")
	}
	// If the database value is empty, do nothing, preserving the default set by defaultSystemSettings.
	if dbValue == "" {
		return nil
	}
	kind := fieldValue.Kind()
	switch kind {
	case reflect.Slice, reflect.Map:
		// For a slice or map, unmarshal from JSON.
		// fieldValue.Addr().Interface() yields a pointer to the field, which is exactly what Unmarshal needs.
		return json.Unmarshal([]byte(dbValue), fieldValue.Addr().Interface())
	default:
		// For all primitive types, trust and reuse the existing reflectutil helper.
		return reflectutil.SetFieldFromString(fieldValue, dbValue)
	}
}

// Stop gracefully stops the SettingsManager's background syncer.
func (sm *SettingsManager) Stop() {
	if sm.syncer != nil {
		sm.logger.Info("Stopping SettingsManager's syncer...")
		sm.syncer.Stop()
	}
}
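
Illustrative sketch (editor addition): defaultSystemSettings above fills primitive fields from `default` struct tags via reflection. The same mechanism in miniature, without the reflectutil dependency (the conf type is a made-up stand-in for SystemSettings):

package main

import (
	"fmt"
	"reflect"
	"strconv"
)

type conf struct {
	Timeout int    `default:"30"`
	Name    string `default:"gemini"`
	Debug   bool   `default:"true"`
}

func fillDefaults(ptr interface{}) {
	v := reflect.ValueOf(ptr).Elem()
	t := v.Type()
	for i := 0; i < t.NumField(); i++ {
		tag := t.Field(i).Tag.Get("default")
		if tag == "" {
			continue
		}
		f := v.Field(i)
		switch f.Kind() { // Primitives only, mirroring defaultSystemSettings.
		case reflect.Int:
			n, _ := strconv.Atoi(tag)
			f.SetInt(int64(n))
		case reflect.String:
			f.SetString(tag)
		case reflect.Bool:
			b, _ := strconv.ParseBool(tag)
			f.SetBool(b)
		}
	}
}

func main() {
	c := conf{}
	fillDefaults(&c)
	fmt.Printf("%+v\n", c) // {Timeout:30 Name:gemini Debug:true}
}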

240
internal/settings/settings.go
Normal file
@@ -0,0 +1,240 @@
// Filename: internal/settings/settings.go
package settings

import (
	"encoding/json"
	"fmt"
	"gemini-balancer/internal/models"
	"gemini-balancer/internal/store"
	"gemini-balancer/internal/syncer"
	"reflect"
	"strconv"
	"strings"

	"github.com/sirupsen/logrus"
	"gorm.io/gorm"
	"gorm.io/gorm/clause"
)

const SettingsUpdateChannel = "system_settings:updated"
const DefaultGeminiEndpoint = "https://generativelanguage.googleapis.com/v1beta/models"

// SettingsManager [core fix] the syncer now caches the correct "blueprint" type.
type SettingsManager struct {
	db              *gorm.DB
	syncer          *syncer.CacheSyncer[*models.SystemSettings]
	logger          *logrus.Entry
	jsonToFieldType map[string]reflect.Type // Maps JSON field names to Go types.
}

func NewSettingsManager(db *gorm.DB, store store.Store, logger *logrus.Logger) (*SettingsManager, error) {
	sm := &SettingsManager{
		db:              db,
		logger:          logger.WithField("component", "SettingsManager⚙️"),
		jsonToFieldType: make(map[string]reflect.Type),
	}
	// Build the JSON-tag-to-type index used when packing updates.
	settingsType := reflect.TypeOf(models.SystemSettings{})
	for i := 0; i < settingsType.NumField(); i++ {
		field := settingsType.Field(i)
		jsonTag := field.Tag.Get("json")
		if jsonTag != "" && jsonTag != "-" {
			sm.jsonToFieldType[jsonTag] = field.Type
		}
	}
	// settingsLoader's job: read the raw key/value "bricks" and assemble them into the typed "blueprint".
	settingsLoader := func() (*models.SystemSettings, error) {
		sm.logger.Info("Loading system settings from database...")
		var dbRecords []models.Setting
		if err := sm.db.Find(&dbRecords).Error; err != nil {
			return nil, fmt.Errorf("failed to load system settings from db: %w", err)
		}
		settingsMap := make(map[string]string)
		for _, record := range dbRecords {
			settingsMap[record.Key] = record.Value
		}
		// Start from a blueprint that already carries all factory defaults.
		settings := defaultSystemSettings()
		v := reflect.ValueOf(settings).Elem()
		t := v.Type()
		// Unpack each stored value into its field.
		for i := 0; i < t.NumField(); i++ {
			field := t.Field(i)
			fieldValue := v.Field(i)
			jsonTag := field.Tag.Get("json")
			if dbValue, ok := settingsMap[jsonTag]; ok {

				if err := parseAndSetField(fieldValue, dbValue); err != nil {
					sm.logger.Warnf("Failed to set config field '%s' from DB value '%s': %v. Using default.", field.Name, dbValue, err)
				}
			}
		}

		if settings.BaseKeyCheckEndpoint == DefaultGeminiEndpoint || settings.BaseKeyCheckEndpoint == "" {
			if settings.DefaultUpstreamURL != "" {
				// If a global upstream URL is configured, derive the key-check endpoint from it.
				originalEndpoint := settings.BaseKeyCheckEndpoint
				derivedEndpoint := strings.TrimSuffix(settings.DefaultUpstreamURL, "/") + "/models"
				settings.BaseKeyCheckEndpoint = derivedEndpoint
				sm.logger.Infof(
					"BaseKeyCheckEndpoint is dynamically derived from DefaultUpstreamURL. Original: '%s', New: '%s'",
					originalEndpoint, derivedEndpoint,
				)
			}
		} else {
			sm.logger.Infof("BaseKeyCheckEndpoint is using a user-defined override: %s", settings.BaseKeyCheckEndpoint)
		}

		sm.logger.Info("System settings loaded and cached.")
		sm.DisplaySettings(settings)
		return settings, nil
	}
	s, err := syncer.NewCacheSyncer(settingsLoader, store, SettingsUpdateChannel)
	if err != nil {
		return nil, fmt.Errorf("failed to create system settings syncer: %w", err)
	}
	sm.syncer = s
	go sm.ensureSettingsInitialized()
	return sm, nil
}

// GetSettings [core fix] now correctly returns the blueprint we need.
func (sm *SettingsManager) GetSettings() *models.SystemSettings {
	return sm.syncer.Get()
}

// UpdateSettings [core fix] receives updates and converts them into "bricks" stored in the database.
func (sm *SettingsManager) UpdateSettings(settingsMap map[string]interface{}) error {
	var settingsToUpdate []models.Setting
	for key, value := range settingsMap {
		fieldType, ok := sm.jsonToFieldType[key]
		if !ok {
			sm.logger.Warnf("Received update for unknown setting key '%s', ignoring.", key)
			continue
		}
		var dbValue string
		// Smart packing:
		// if the field is a slice or map, pack the incoming interface{} into a JSON string.
		kind := fieldType.Kind()
		if kind == reflect.Slice || kind == reflect.Map {
			jsonBytes, marshalErr := json.Marshal(value)
			if marshalErr != nil {
				// Real error handling: if packing fails, log it and skip this broken entry.
				sm.logger.Warnf("Failed to marshal setting '%s' to JSON: %v, skipping update.", key, marshalErr)
				continue // Skip and move on to the next key.
			}
			dbValue = string(jsonBytes)
		} else if kind == reflect.Bool {
			if b, ok := value.(bool); ok {
				dbValue = strconv.FormatBool(b)
			} else {
				dbValue = "false"
			}
		} else {
			dbValue = fmt.Sprintf("%v", value)
		}
		settingsToUpdate = append(settingsToUpdate, models.Setting{
			Key:   key,
			Value: dbValue,
		})
	}
	if len(settingsToUpdate) > 0 {
		err := sm.db.Clauses(clause.OnConflict{
			Columns:   []clause.Column{{Name: "key"}},
			DoUpdates: clause.AssignmentColumns([]string{"value"}),
		}).Create(&settingsToUpdate).Error
		if err != nil {
			return fmt.Errorf("failed to update settings in db: %w", err)
		}
	}
	return sm.syncer.Invalidate()
}

// ensureSettingsInitialized [core fix] makes sure the DB holds a definition for every "brick".
func (sm *SettingsManager) ensureSettingsInitialized() {
	defaults := defaultSystemSettings()
	v := reflect.ValueOf(defaults).Elem()
	t := v.Type()
	for i := 0; i < t.NumField(); i++ {
		field := t.Field(i)
		fieldValue := v.Field(i)
		key := field.Tag.Get("json")
		if key == "" || key == "-" {
			continue
		}
		var existing models.Setting
		if err := sm.db.Where("key = ?", key).First(&existing).Error; err == gorm.ErrRecordNotFound {

			var defaultValue string
			kind := fieldValue.Kind()
			// Smart initialization:
			if kind == reflect.Slice || kind == reflect.Map {
				// For complex types, generate an "empty" JSON string such as "[]" or "{}".
				jsonBytes, _ := json.Marshal(fieldValue.Interface())
				defaultValue = string(jsonBytes)
			} else {
				defaultValue = field.Tag.Get("default")
			}
			setting := models.Setting{
				Key:          key,
				Value:        defaultValue,
				Name:         field.Tag.Get("name"),
				Description:  field.Tag.Get("desc"),
				Category:     field.Tag.Get("category"),
				DefaultValue: field.Tag.Get("default"), // The default in the metadata always comes from the tag.
			}
			if err := sm.db.Create(&setting).Error; err != nil {
				sm.logger.Errorf("Failed to initialize setting '%s': %v", key, err)
			}
		}
	}
}

// ResetAndSaveSettings [core addition] resets every setting to the value defined in its 'default' tag.

func (sm *SettingsManager) ResetAndSaveSettings() (*models.SystemSettings, error) {
	defaults := defaultSystemSettings()
	v := reflect.ValueOf(defaults).Elem()
	t := v.Type()
	var settingsToSave []models.Setting
	for i := 0; i < t.NumField(); i++ {
		field := t.Field(i)
		fieldValue := v.Field(i)
		key := field.Tag.Get("json")
		if key == "" || key == "-" {
			continue
		}

		var defaultValue string
		kind := fieldValue.Kind()
		// Smart reset:
		if kind == reflect.Slice || kind == reflect.Map {
			jsonBytes, _ := json.Marshal(fieldValue.Interface())
			defaultValue = string(jsonBytes)
		} else {
			defaultValue = field.Tag.Get("default")
		}
		setting := models.Setting{
			Key:          key,
			Value:        defaultValue,
			Name:         field.Tag.Get("name"),
			Description:  field.Tag.Get("desc"),
			Category:     field.Tag.Get("category"),
			DefaultValue: field.Tag.Get("default"),
		}
		settingsToSave = append(settingsToSave, setting)
	}
	if len(settingsToSave) > 0 {
		err := sm.db.Clauses(clause.OnConflict{
			Columns:   []clause.Column{{Name: "key"}},
			DoUpdates: clause.AssignmentColumns([]string{"value", "name", "description", "category", "default_value"}),
		}).Create(&settingsToSave).Error
		if err != nil {
			return nil, fmt.Errorf("failed to reset settings in db: %w", err)
		}
	}
	if err := sm.syncer.Invalidate(); err != nil {
		sm.logger.Errorf("Failed to invalidate settings cache after reset: %v", err)
	}
	return defaults, nil
}
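
Illustrative sketch (editor addition): UpdateSettings above flattens heterogeneous values into strings before upserting them — slices and maps become JSON, bools go through strconv, everything else through fmt. The packing rule in isolation (simplified: it switches on the value's kind rather than the declared field type as UpdateSettings does):

package main

import (
	"encoding/json"
	"fmt"
	"reflect"
	"strconv"
)

// packValue converts one settings value to its database string form.
func packValue(v interface{}) (string, error) {
	switch reflect.ValueOf(v).Kind() {
	case reflect.Slice, reflect.Map:
		b, err := json.Marshal(v) // Complex types are stored as JSON text.
		if err != nil {
			return "", err
		}
		return string(b), nil
	case reflect.Bool:
		return strconv.FormatBool(v.(bool)), nil
	default:
		return fmt.Sprintf("%v", v), nil
	}
}

func main() {
	for _, v := range []interface{}{[]string{"a", "b"}, map[string]int{"x": 1}, true, 42} {
		s, _ := packValue(v)
		fmt.Println(s) // ["a","b"] / {"x":1} / true / 42
	}
}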

30
internal/store/factory.go
Normal file
@@ -0,0 +1,30 @@
package store

import (
    "context"
    "fmt"
    "gemini-balancer/internal/config"

    "github.com/redis/go-redis/v9"
    "github.com/sirupsen/logrus"
)

// NewStore creates a new store based on the application configuration.
func NewStore(cfg *config.Config, logger *logrus.Logger) (Store, error) {
    // Check whether a Redis DSN is configured.
    if cfg.Redis.DSN != "" {
        opts, err := redis.ParseURL(cfg.Redis.DSN)
        if err != nil {
            return nil, fmt.Errorf("failed to parse redis DSN: %w", err)
        }
        client := redis.NewClient(opts)
        if err := client.Ping(context.Background()).Err(); err != nil {
            logger.WithError(err).Warnf("Failed to connect to Redis (%s), falling back to in-memory store", cfg.Redis.DSN)
            return NewMemoryStore(logger), nil // on connection failure, fall back to the in-memory store instead of returning an error
        }
        logger.Info("Successfully connected to Redis. Using Redis as store.")
        return NewRedisStore(client), nil
    }
    logger.Info("Redis DSN not configured, falling back to in-memory store.")
    return NewMemoryStore(logger), nil
}
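A hedged sketch of wiring NewStore at startup. The config literal is hypothetical (it assumes config.Config's zero value leaves Redis.DSN empty); the fallback contract itself — empty DSN or unreachable Redis yields the in-memory store with no error — comes from the function above.

    logger := logrus.New()
    st, err := store.NewStore(&config.Config{}, logger) // empty Redis DSN => memory store
    if err != nil {
        logger.Fatal(err)
    }
    defer st.Close()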
642
internal/store/memory_store.go
Normal file
@@ -0,0 +1,642 @@
// Filename: internal/store/memory_store.go (unified-storage refactor)

package store

import (
    "fmt"
    "math/rand"
    "sort"
    "sync"
    "time"

    "github.com/sirupsen/logrus"
)

// ensure memoryStore implements Store interface
var _ Store = (*memoryStore)(nil)

// [Core refactor] memoryStoreItem is now a generic container: it can hold a
// value of any supported type and carries its own expiration time.
type memoryStoreItem struct {
    value    interface{} // may be []byte, []string, map[string]string, map[string]struct{}, or []zsetMember
    expireAt time.Time
}

// isExpired reports whether the item has expired.
func (item *memoryStoreItem) isExpired() bool {
    return !item.expireAt.IsZero() && time.Now().After(item.expireAt)
}

// zsetMember is unchanged.
type zsetMember struct {
    Value string
    Score float64
}

// [Core refactor] memoryStore now keeps all data in a single unified map.
type memoryStore struct {
    items  map[string]*memoryStoreItem // pointers to items so they can be mutated in place
    pubsub map[string][]chan *Message
    mu     sync.RWMutex
    logger *logrus.Entry
}

// NewMemoryStore [core refactor] the constructor is simplified accordingly.
func NewMemoryStore(logger *logrus.Logger) Store {
    return &memoryStore{
        items:  make(map[string]*memoryStoreItem),
        pubsub: make(map[string][]chan *Message),
        logger: logger.WithField("component", "store.memory"),
    }
}

// [Core refactor] getItem is an internal helper that encapsulates the shared
// fetch / expiry-check / delete logic.
func (s *memoryStore) getItem(key string, lockForWrite bool) *memoryStoreItem {
    if !lockForWrite {
        // Read path: check under the read lock first.
        s.mu.RLock()
        item, ok := s.items[key]
        if !ok || item.isExpired() {
            s.mu.RUnlock()
            // Missing or expired: take the write lock to delete the stale entry.
            if ok { // only delete when the item exists but has expired
                s.mu.Lock()
                // Re-check: the state may have changed while we waited for the write lock.
                if item, ok := s.items[key]; ok && item.isExpired() {
                    delete(s.items, key)
                }
                s.mu.Unlock()
            }
            return nil // either way the caller sees "not found"
        }
        s.mu.RUnlock()
        return item
    }

    // Write path: the caller already holds the write lock.
    item, ok := s.items[key]
    if ok && item.isExpired() {
        delete(s.items, key)
        return nil
    }
    return item
}

// --- All interface methods are rewritten on top of the unified structure ---

func (s *memoryStore) Set(key string, value []byte, ttl time.Duration) error {
    s.mu.Lock()
    defer s.mu.Unlock()
    var expireAt time.Time
    if ttl > 0 {
        expireAt = time.Now().Add(ttl)
    }
    s.items[key] = &memoryStoreItem{value: value, expireAt: expireAt}
    return nil
}

func (s *memoryStore) Get(key string) ([]byte, error) {
    item := s.getItem(key, false)
    if item == nil {
        return nil, ErrNotFound
    }
    if value, ok := item.value.([]byte); ok {
        return value, nil
    }
    return nil, ErrNotFound // type mismatch, treat as not found
}

func (s *memoryStore) Del(keys ...string) error {
    s.mu.Lock()
    defer s.mu.Unlock()
    for _, key := range keys {
        delete(s.items, key)
    }
    return nil
}

func (s *memoryStore) delNoLock(keys ...string) {
    for _, key := range keys {
        delete(s.items, key)
    }
}

func (s *memoryStore) Exists(key string) (bool, error) {
    return s.getItem(key, false) != nil, nil
}

func (s *memoryStore) SetNX(key string, value []byte, ttl time.Duration) (bool, error) {
    s.mu.Lock()
    defer s.mu.Unlock()
    if item := s.getItem(key, true); item != nil {
        return false, nil
    }
    var expireAt time.Time
    if ttl > 0 {
        expireAt = time.Now().Add(ttl)
    }
    s.items[key] = &memoryStoreItem{value: value, expireAt: expireAt}
    return true, nil
}

func (s *memoryStore) Close() error { return nil }

func (s *memoryStore) HDel(key string, fields ...string) error {
    s.mu.Lock()
    defer s.mu.Unlock()
    item := s.getItem(key, true)
    if item == nil {
        return nil
    }
    if hash, ok := item.value.(map[string]string); ok {
        for _, field := range fields {
            delete(hash, field)
        }
    }
    return nil
}

func (s *memoryStore) HSet(key string, values map[string]any) error {
    s.mu.Lock()
    defer s.mu.Unlock()
    s.hSetNoLock(key, values)
    return nil
}

func (s *memoryStore) hSetNoLock(key string, values map[string]any) {
    item := s.getItem(key, true)
    if item == nil {
        item = &memoryStoreItem{value: make(map[string]string)}
        s.items[key] = item
    }
    hash, ok := item.value.(map[string]string)
    if !ok { // if the key exists but is not a hash, start a fresh hash
        hash = make(map[string]string)
        item.value = hash
    }
    for field, value := range values {
        hash[field] = fmt.Sprintf("%v", value)
    }
}

func (s *memoryStore) HGetAll(key string) (map[string]string, error) {
    item := s.getItem(key, false)
    if item == nil {
        return make(map[string]string), nil
    }
    s.mu.RLock() // hold the read lock while inspecting and copying the value
    defer s.mu.RUnlock()
    if hash, ok := item.value.(map[string]string); ok {
        result := make(map[string]string, len(hash))
        for k, v := range hash {
            result[k] = v
        }
        return result, nil
    }
    return make(map[string]string), nil
}

func (s *memoryStore) HIncrBy(key, field string, incr int64) (int64, error) {
    s.mu.Lock()
    defer s.mu.Unlock()
    return s.hIncrByNoLock(key, field, incr)
}

func (s *memoryStore) hIncrByNoLock(key, field string, incr int64) (int64, error) {
    item := s.getItem(key, true)
    if item == nil {
        item = &memoryStoreItem{value: make(map[string]string)}
        s.items[key] = item
    }
    hash, ok := item.value.(map[string]string)
    if !ok {
        hash = make(map[string]string)
        item.value = hash
    }
    var currentVal int64
    if valStr, ok := hash[field]; ok {
        fmt.Sscanf(valStr, "%d", &currentVal)
    }
    newVal := currentVal + incr
    hash[field] = fmt.Sprintf("%d", newVal)
    return newVal, nil
}

func (s *memoryStore) LPush(key string, values ...any) error {
    s.mu.Lock()
    defer s.mu.Unlock()
    s.lPushNoLock(key, values...)
    return nil
}

func (s *memoryStore) lPushNoLock(key string, values ...any) {
    item := s.getItem(key, true)
    if item == nil {
        item = &memoryStoreItem{value: make([]string, 0)}
        s.items[key] = item
    }
    list, ok := item.value.([]string)
    if !ok {
        list = make([]string, 0)
    }
    stringValues := make([]string, len(values))
    for i, v := range values {
        stringValues[i] = fmt.Sprintf("%v", v)
    }
    item.value = append(stringValues, list...)
}

func (s *memoryStore) LRem(key string, count int64, value any) error {
    s.mu.Lock()
    defer s.mu.Unlock()
    s.lRemNoLock(key, count, value)
    return nil
}

func (s *memoryStore) lRemNoLock(key string, count int64, value any) {
    item := s.getItem(key, true)
    if item == nil {
        return
    }
    list, ok := item.value.([]string)
    if !ok {
        return
    }
    valToRemove := fmt.Sprintf("%v", value)
    newList := make([]string, 0, len(list))
    removedCount := int64(0)
    for _, v := range list {
        if v == valToRemove && (count == 0 || removedCount < count) {
            removedCount++
        } else {
            newList = append(newList, v)
        }
    }
    item.value = newList
}

func (s *memoryStore) SAdd(key string, members ...any) error {
    s.mu.Lock()
    defer s.mu.Unlock()
    s.sAddNoLock(key, members...)
    return nil
}

func (s *memoryStore) sAddNoLock(key string, members ...any) {
    item := s.getItem(key, true)
    if item == nil {
        item = &memoryStoreItem{value: make(map[string]struct{})}
        s.items[key] = item
    }
    set, ok := item.value.(map[string]struct{})
    if !ok {
        set = make(map[string]struct{})
        item.value = set
    }
    for _, member := range members {
        set[fmt.Sprintf("%v", member)] = struct{}{}
    }
}

func (s *memoryStore) SPopN(key string, count int64) ([]string, error) {
    s.mu.Lock()
    defer s.mu.Unlock()
    item := s.getItem(key, true)
    if item == nil {
        return []string{}, nil
    }
    set, ok := item.value.(map[string]struct{})
    if !ok || len(set) == 0 {
        return []string{}, nil
    }
    if int64(len(set)) < count {
        count = int64(len(set))
    }
    popped := make([]string, 0, count)
    keys := make([]string, 0, len(set))
    for k := range set {
        keys = append(keys, k)
    }
    rand.Shuffle(len(keys), func(i, j int) { keys[i], keys[j] = keys[j], keys[i] })
    for i := int64(0); i < count; i++ {
        poppedKey := keys[i]
        popped = append(popped, poppedKey)
        delete(set, poppedKey)
    }
    return popped, nil
}

func (s *memoryStore) SMembers(key string) ([]string, error) {
    item := s.getItem(key, false)
    if item == nil {
        return []string{}, nil
    }
    s.mu.RLock() // hold the read lock before inspecting and iterating the value
    defer s.mu.RUnlock()
    set, ok := item.value.(map[string]struct{})
    if !ok {
        return []string{}, nil
    }
    members := make([]string, 0, len(set))
    for member := range set {
        members = append(members, member)
    }
    return members, nil
}

func (s *memoryStore) SRem(key string, members ...any) error {
    s.mu.Lock()
    defer s.mu.Unlock()
    s.sRemNoLock(key, members...)
    return nil
}

func (s *memoryStore) sRemNoLock(key string, members ...any) {
    item := s.getItem(key, true)
    if item == nil {
        return
    }
    set, ok := item.value.(map[string]struct{})
    if !ok {
        return
    }
    for _, member := range members {
        delete(set, fmt.Sprintf("%v", member))
    }
}

func (s *memoryStore) SRandMember(key string) (string, error) {
    item := s.getItem(key, false)
    if item == nil {
        return "", ErrNotFound
    }
    s.mu.RLock() // hold the read lock before inspecting and iterating the value
    defer s.mu.RUnlock()
    set, ok := item.value.(map[string]struct{})
    if !ok || len(set) == 0 {
        return "", ErrNotFound
    }
    members := make([]string, 0, len(set))
    for member := range set {
        members = append(members, member)
    }
    if len(members) == 0 {
        return "", ErrNotFound
    }
    return members[rand.Intn(len(members))], nil
}

func (s *memoryStore) Rotate(key string) (string, error) {
    s.mu.Lock()
    defer s.mu.Unlock()
    item := s.getItem(key, true)
    if item == nil {
        return "", ErrNotFound
    }
    list, ok := item.value.([]string)
    if !ok || len(list) == 0 {
        return "", ErrNotFound
    }
    val := list[len(list)-1]
    item.value = append([]string{val}, list[:len(list)-1]...)
    return val, nil
}

func (s *memoryStore) LIndex(key string, index int64) (string, error) {
    item := s.getItem(key, false)
    if item == nil {
        return "", ErrNotFound
    }
    s.mu.RLock() // hold the read lock before inspecting the value
    defer s.mu.RUnlock()
    list, ok := item.value.([]string)
    if !ok {
        return "", ErrNotFound
    }
    l := int64(len(list))
    if index < 0 {
        index += l
    }
    if index < 0 || index >= l {
        return "", ErrNotFound
    }
    return list[index], nil
}

// ZSET methods (ZAdd, ZRange, ZRem)

func (s *memoryStore) ZAdd(key string, members map[string]float64) error {
    s.mu.Lock()
    defer s.mu.Unlock()
    item := s.getItem(key, true)
    if item == nil {
        item = &memoryStoreItem{value: make([]zsetMember, 0)}
        s.items[key] = item
    }
    zset, ok := item.value.([]zsetMember)
    if !ok {
        zset = make([]zsetMember, 0)
    }
    for memberVal, score := range members {
        found := false
        for i := range zset {
            if zset[i].Value == memberVal {
                zset[i].Score = score
                found = true
                break
            }
        }
        if !found {
            zset = append(zset, zsetMember{Value: memberVal, Score: score})
        }
    }
    sort.Slice(zset, func(i, j int) bool {
        if zset[i].Score == zset[j].Score {
            return zset[i].Value < zset[j].Value
        }
        return zset[i].Score < zset[j].Score
    })
    item.value = zset
    return nil
}

func (s *memoryStore) ZRange(key string, start, stop int64) ([]string, error) {
    item := s.getItem(key, false)
    if item == nil {
        return []string{}, nil
    }
    s.mu.RLock() // hold the read lock before inspecting and slicing the value
    defer s.mu.RUnlock()
    zset, ok := item.value.([]zsetMember)
    if !ok {
        return []string{}, nil
    }
    l := int64(len(zset))
    if start < 0 {
        start += l
    }
    if stop < 0 {
        stop += l
    }
    if start < 0 {
        start = 0
    }
    if start > stop || start >= l {
        return []string{}, nil
    }
    if stop >= l {
        stop = l - 1
    }
    result := make([]string, 0, stop-start+1)
    for i := start; i <= stop; i++ {
        result = append(result, zset[i].Value)
    }
    return result, nil
}

func (s *memoryStore) ZRem(key string, members ...any) error {
    s.mu.Lock()
    defer s.mu.Unlock()
    item := s.getItem(key, true)
    if item == nil {
        return nil
    }
    zset, ok := item.value.([]zsetMember)
    if !ok {
        return nil
    }
    membersToRemove := make(map[string]struct{}, len(members))
    for _, m := range members {
        membersToRemove[fmt.Sprintf("%v", m)] = struct{}{}
    }
    newZSet := make([]zsetMember, 0, len(zset))
    for _, z := range zset {
        if _, exists := membersToRemove[z.Value]; !exists {
            newZSet = append(newZSet, z)
        }
    }
    item.value = newZSet
    return nil
}

func (s *memoryStore) PopAndCycleSetMember(mainKey, cooldownKey string) (string, error) {
    s.mu.Lock()
    defer s.mu.Unlock()
    mainItem := s.getItem(mainKey, true)
    // Guard the type assertion so a mismatched value cannot panic.
    var mainSet map[string]struct{}
    if mainItem != nil {
        mainSet, _ = mainItem.value.(map[string]struct{})
    }
    if len(mainSet) == 0 {
        cooldownItem := s.getItem(cooldownKey, true)
        if cooldownItem == nil {
            return "", ErrNotFound
        }
        // "Rename" by moving the value and deleting the old key.
        s.items[mainKey] = cooldownItem
        delete(s.items, cooldownKey)
        mainItem = cooldownItem
        mainSet, _ = mainItem.value.(map[string]struct{})
    }
    if len(mainSet) == 0 {
        return "", ErrNotFound // should not happen after the cycle logic above
    }
    var popped string
    for k := range mainSet {
        popped = k
        break
    }
    delete(mainSet, popped)
    return popped, nil
}

// Pipeline implementation

type memoryPipeliner struct {
    store *memoryStore
    ops   []func()
}

func (s *memoryStore) Pipeline() Pipeliner {
    return &memoryPipeliner{store: s}
}

func (p *memoryPipeliner) Exec() error {
    p.store.mu.Lock()
    defer p.store.mu.Unlock()
    for _, op := range p.ops {
        op()
    }
    return nil
}

// [Core fix] Expire can now correctly set an expiration on any key.
func (p *memoryPipeliner) Expire(key string, expiration time.Duration) {
    p.ops = append(p.ops, func() {
        // This runs inside Exec's lock.
        item := p.store.getItem(key, true)
        if item != nil {
            item.expireAt = time.Now().Add(expiration)
        }
    })
}

// Remaining pipeliner methods.

func (p *memoryPipeliner) HSet(key string, values map[string]any) {
    p.ops = append(p.ops, func() { p.store.hSetNoLock(key, values) })
}

func (p *memoryPipeliner) HIncrBy(key, field string, incr int64) {
    p.ops = append(p.ops, func() { p.store.hIncrByNoLock(key, field, incr) })
}

func (p *memoryPipeliner) Del(keys ...string) {
    p.ops = append(p.ops, func() { p.store.delNoLock(keys...) })
}

func (p *memoryPipeliner) SAdd(key string, members ...any) {
    p.ops = append(p.ops, func() { p.store.sAddNoLock(key, members...) })
}

func (p *memoryPipeliner) SRem(key string, members ...any) {
    p.ops = append(p.ops, func() { p.store.sRemNoLock(key, members...) })
}

func (p *memoryPipeliner) LPush(key string, values ...any) {
    p.ops = append(p.ops, func() { p.store.lPushNoLock(key, values...) })
}

func (p *memoryPipeliner) LRem(key string, count int64, value any) {
    p.ops = append(p.ops, func() { p.store.lRemNoLock(key, count, value) })
}

// Pub/Sub implementation (unchanged, as it is a separate subsystem)

type memorySubscription struct {
    store       *memoryStore
    channelName string
    msgChan     chan *Message
}

func (ms *memorySubscription) Channel() <-chan *Message { return ms.msgChan }

func (ms *memorySubscription) Close() error {
    return ms.store.removeSubscriber(ms.channelName, ms.msgChan)
}

func (s *memoryStore) Publish(channel string, message []byte) error {
    s.mu.RLock()
    defer s.mu.RUnlock()
    subscribers, ok := s.pubsub[channel]
    if !ok {
        return nil
    }
    msg := &Message{Channel: channel, Payload: message}
    for _, ch := range subscribers {
        select {
        case ch <- msg:
        case <-time.After(100 * time.Millisecond):
            s.logger.Warnf("Could not publish to a subscriber on channel '%s' within 100ms", channel)
        }
    }
    return nil
}

func (s *memoryStore) Subscribe(channel string) (Subscription, error) {
    s.mu.Lock()
    defer s.mu.Unlock()
    msgChan := make(chan *Message, 10)
    sub := &memorySubscription{store: s, channelName: channel, msgChan: msgChan}
    s.pubsub[channel] = append(s.pubsub[channel], msgChan)
    return sub, nil
}

func (s *memoryStore) removeSubscriber(channelName string, msgChan chan *Message) error {
    s.mu.Lock()
    defer s.mu.Unlock()
    subscribers, ok := s.pubsub[channelName]
    if !ok {
        return nil
    }
    newSubscribers := make([]chan *Message, 0)
    for _, ch := range subscribers {
        if ch != msgChan {
            newSubscribers = append(newSubscribers, ch)
        }
    }
    if len(newSubscribers) == 0 {
        delete(s.pubsub, channelName)
    } else {
        s.pubsub[channelName] = newSubscribers
    }
    close(msgChan)
    return nil
}
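A minimal test sketch (assumed to live alongside this package; not part of the commit) showing the lazy-expiry behaviour of the unified item map: an expired key reads as missing and is deleted on the next access.

// memory_store_ttl_test.go (hypothetical)
package store

import (
    "errors"
    "testing"
    "time"

    "github.com/sirupsen/logrus"
)

func TestMemoryStoreTTL(t *testing.T) {
    s := NewMemoryStore(logrus.New())
    if err := s.Set("k", []byte("v"), 10*time.Millisecond); err != nil {
        t.Fatal(err)
    }
    if _, err := s.Get("k"); err != nil {
        t.Fatalf("expected hit before expiry, got %v", err)
    }
    time.Sleep(20 * time.Millisecond)
    if _, err := s.Get("k"); !errors.Is(err, ErrNotFound) {
        t.Fatalf("expected ErrNotFound after expiry, got %v", err)
    }
}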
271
internal/store/redis_store.go
Normal file
@@ -0,0 +1,271 @@
package store

import (
    "context"
    "errors"
    "fmt"
    "sync"
    "time"

    "github.com/redis/go-redis/v9"
)

// ensure RedisStore implements Store interface
var _ Store = (*RedisStore)(nil)

// RedisStore is a Redis-backed key-value store.
type RedisStore struct {
    client            *redis.Client
    popAndCycleScript *redis.Script
}

// NewRedisStore creates a new RedisStore instance.
func NewRedisStore(client *redis.Client) Store {
    // Lua script for the atomic pop-and-cycle operation.
    // KEYS[1]: main set key
    // KEYS[2]: cooldown set key
    const script = `
        if redis.call('SCARD', KEYS[1]) == 0 then
            if redis.call('SCARD', KEYS[2]) == 0 then
                return nil
            end
            redis.call('RENAME', KEYS[2], KEYS[1])
        end
        return redis.call('SPOP', KEYS[1])
    `
    return &RedisStore{
        client:            client,
        popAndCycleScript: redis.NewScript(script),
    }
}

func (s *RedisStore) Set(key string, value []byte, ttl time.Duration) error {
    return s.client.Set(context.Background(), key, value, ttl).Err()
}

func (s *RedisStore) Get(key string) ([]byte, error) {
    val, err := s.client.Get(context.Background(), key).Bytes()
    if err != nil {
        if errors.Is(err, redis.Nil) {
            return nil, ErrNotFound
        }
        return nil, err
    }
    return val, nil
}

func (s *RedisStore) Del(keys ...string) error {
    if len(keys) == 0 {
        return nil
    }
    return s.client.Del(context.Background(), keys...).Err()
}

func (s *RedisStore) Exists(key string) (bool, error) {
    val, err := s.client.Exists(context.Background(), key).Result()
    return val > 0, err
}

func (s *RedisStore) SetNX(key string, value []byte, ttl time.Duration) (bool, error) {
    return s.client.SetNX(context.Background(), key, value, ttl).Result()
}

func (s *RedisStore) Close() error {
    return s.client.Close()
}

func (s *RedisStore) HSet(key string, values map[string]any) error {
    return s.client.HSet(context.Background(), key, values).Err()
}

func (s *RedisStore) HGetAll(key string) (map[string]string, error) {
    return s.client.HGetAll(context.Background(), key).Result()
}

func (s *RedisStore) HIncrBy(key, field string, incr int64) (int64, error) {
    return s.client.HIncrBy(context.Background(), key, field, incr).Result()
}

func (s *RedisStore) HDel(key string, fields ...string) error {
    if len(fields) == 0 {
        return nil
    }
    return s.client.HDel(context.Background(), key, fields...).Err()
}

func (s *RedisStore) LPush(key string, values ...any) error {
    return s.client.LPush(context.Background(), key, values...).Err()
}

func (s *RedisStore) LRem(key string, count int64, value any) error {
    return s.client.LRem(context.Background(), key, count, value).Err()
}

func (s *RedisStore) Rotate(key string) (string, error) {
    val, err := s.client.RPopLPush(context.Background(), key, key).Result()
    if err != nil {
        if errors.Is(err, redis.Nil) {
            return "", ErrNotFound
        }
        return "", err
    }
    return val, nil
}

func (s *RedisStore) LIndex(key string, index int64) (string, error) {
    val, err := s.client.LIndex(context.Background(), key, index).Result()
    if err != nil {
        if errors.Is(err, redis.Nil) {
            return "", ErrNotFound
        }
        return "", err
    }
    return val, nil
}

func (s *RedisStore) SAdd(key string, members ...any) error {
    return s.client.SAdd(context.Background(), key, members...).Err()
}

func (s *RedisStore) SPopN(key string, count int64) ([]string, error) {
    return s.client.SPopN(context.Background(), key, count).Result()
}

func (s *RedisStore) SMembers(key string) ([]string, error) {
    return s.client.SMembers(context.Background(), key).Result()
}

func (s *RedisStore) SRem(key string, members ...any) error {
    if len(members) == 0 {
        return nil
    }
    return s.client.SRem(context.Background(), key, members...).Err()
}

func (s *RedisStore) SRandMember(key string) (string, error) {
    member, err := s.client.SRandMember(context.Background(), key).Result()
    if err != nil {
        if errors.Is(err, redis.Nil) {
            return "", ErrNotFound
        }
        return "", err
    }
    return member, nil
}

// === Newly added methods ===

func (s *RedisStore) ZAdd(key string, members map[string]float64) error {
    if len(members) == 0 {
        return nil
    }
    redisMembers := make([]redis.Z, 0, len(members))
    for member, score := range members {
        redisMembers = append(redisMembers, redis.Z{Score: score, Member: member})
    }
    return s.client.ZAdd(context.Background(), key, redisMembers...).Err()
}

func (s *RedisStore) ZRange(key string, start, stop int64) ([]string, error) {
    return s.client.ZRange(context.Background(), key, start, stop).Result()
}

func (s *RedisStore) ZRem(key string, members ...any) error {
    if len(members) == 0 {
        return nil
    }
    return s.client.ZRem(context.Background(), key, members...).Err()
}

func (s *RedisStore) PopAndCycleSetMember(mainKey, cooldownKey string) (string, error) {
    val, err := s.popAndCycleScript.Run(context.Background(), s.client, []string{mainKey, cooldownKey}).Result()
    if err != nil {
        if errors.Is(err, redis.Nil) {
            return "", ErrNotFound
        }
        return "", err
    }
    // The Lua script returns a string, so a type assertion is needed.
    if str, ok := val.(string); ok {
        return str, nil
    }
    return "", ErrNotFound // both sets were empty and the script returned nil
}

func (s *RedisStore) Pipeline() Pipeliner {
    return &redisPipeliner{pipe: s.client.Pipeline()}
}

type redisPipeliner struct{ pipe redis.Pipeliner }

func (p *redisPipeliner) Exec() error {
    _, err := p.pipe.Exec(context.Background())
    return err
}

func (p *redisPipeliner) Del(keys ...string) {
    if len(keys) > 0 {
        p.pipe.Del(context.Background(), keys...)
    }
}

func (p *redisPipeliner) Expire(key string, expiration time.Duration) {
    p.pipe.Expire(context.Background(), key, expiration)
}

func (p *redisPipeliner) HSet(key string, values map[string]any) {
    p.pipe.HSet(context.Background(), key, values)
}

func (p *redisPipeliner) HIncrBy(key, field string, incr int64) {
    p.pipe.HIncrBy(context.Background(), key, field, incr)
}

func (p *redisPipeliner) SAdd(key string, members ...any) {
    p.pipe.SAdd(context.Background(), key, members...)
}

func (p *redisPipeliner) SRem(key string, members ...any) {
    if len(members) > 0 {
        p.pipe.SRem(context.Background(), key, members...)
    }
}

func (p *redisPipeliner) LPush(key string, values ...any) {
    p.pipe.LPush(context.Background(), key, values...)
}

func (p *redisPipeliner) LRem(key string, count int64, value any) {
    p.pipe.LRem(context.Background(), key, count, value)
}

type redisSubscription struct {
    pubsub  *redis.PubSub
    msgChan chan *Message
    once    sync.Once
}

func (rs *redisSubscription) Channel() <-chan *Message {
    rs.once.Do(func() {
        rs.msgChan = make(chan *Message)
        go func() {
            defer close(rs.msgChan)
            for redisMsg := range rs.pubsub.Channel() {
                rs.msgChan <- &Message{
                    Channel: redisMsg.Channel,
                    Payload: []byte(redisMsg.Payload),
                }
            }
        }()
    })
    return rs.msgChan
}

func (rs *redisSubscription) Close() error { return rs.pubsub.Close() }

func (s *RedisStore) Publish(channel string, message []byte) error {
    return s.client.Publish(context.Background(), channel, message).Err()
}

func (s *RedisStore) Subscribe(channel string) (Subscription, error) {
    pubsub := s.client.Subscribe(context.Background(), channel)
    _, err := pubsub.Receive(context.Background())
    if err != nil {
        return nil, fmt.Errorf("failed to subscribe to channel %s: %w", channel, err)
    }
    return &redisSubscription{pubsub: pubsub}, nil
}
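A hedged sketch of the pop-and-cycle contract shared by both backends: when the main set is empty, the cooldown set is promoted wholesale and popping resumes from it. On Redis the Lua script above does this atomically; the in-memory store is used here only for illustration, and the key names are hypothetical.

    s := store.NewMemoryStore(logrus.New())
    _ = s.SAdd("keys:active", "A")
    _ = s.SAdd("keys:cooldown", "B", "C")

    first, _ := s.PopAndCycleSetMember("keys:active", "keys:cooldown")  // "A"
    second, _ := s.PopAndCycleSetMember("keys:active", "keys:cooldown") // "B" or "C": cooldown promoted
    fmt.Println(first, second)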
88
internal/store/store.go
Normal file
@@ -0,0 +1,88 @@
package store

import (
    "errors"
    "time"
)

// ErrNotFound is returned when a key is not found in the store.
var ErrNotFound = errors.New("key not found")

// Message is the struct for received pub/sub messages.
type Message struct {
    Channel string
    Payload []byte
}

// Subscription represents an active subscription to a pub/sub channel.
type Subscription interface {
    Channel() <-chan *Message
    Close() error
}

// Pipeliner defines an interface for executing a batch of commands.
type Pipeliner interface {
    // General
    Del(keys ...string)
    Expire(key string, expiration time.Duration)

    // HASH
    HSet(key string, values map[string]any)
    HIncrBy(key, field string, incr int64)

    // SET
    SAdd(key string, members ...any)
    SRem(key string, members ...any)

    // LIST
    LPush(key string, values ...any)
    LRem(key string, count int64, value any)

    // Execution
    Exec() error
}

// Store is the master interface for our cache service.
type Store interface {
    // Basic K/V operations
    Set(key string, value []byte, ttl time.Duration) error
    Get(key string) ([]byte, error)
    Del(keys ...string) error
    Exists(key string) (bool, error)
    SetNX(key string, value []byte, ttl time.Duration) (bool, error)

    // HASH operations
    HSet(key string, values map[string]any) error
    HGetAll(key string) (map[string]string, error)
    HIncrBy(key, field string, incr int64) (int64, error)
    HDel(key string, fields ...string) error // [new]

    // LIST operations
    LPush(key string, values ...any) error
    LRem(key string, count int64, value any) error
    Rotate(key string) (string, error)
    LIndex(key string, index int64) (string, error)

    // SET operations
    SAdd(key string, members ...any) error
    SPopN(key string, count int64) ([]string, error)
    SMembers(key string) ([]string, error)
    SRem(key string, members ...any) error
    SRandMember(key string) (string, error)

    // Pub/Sub operations
    Publish(channel string, message []byte) error
    Subscribe(channel string) (Subscription, error)

    // Pipeline batching (implemented by both the Redis and in-memory stores)
    Pipeline() Pipeliner

    // Close closes the store and releases any underlying resources.
    Close() error

    // === New methods supporting key-rotation policies ===
    ZAdd(key string, members map[string]float64) error
    ZRange(key string, start, stop int64) ([]string, error)
    ZRem(key string, members ...any) error
    PopAndCycleSetMember(mainKey, cooldownKey string) (string, error)
}
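A hedged sketch of consuming the interface: a round-robin picker built on LPush/Rotate. The function and key names are illustrative only; Rotate's "move the tail to the head and return it" semantics come from the implementations above.

    // seed once, e.g. at startup:
    // _ = s.LPush("upstreams", "key-1", "key-2", "key-3")

    func nextUpstream(s store.Store) (string, error) {
        return s.Rotate("upstreams") // each call yields the next key in cyclic order
    }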
114
internal/syncer/syncer.go
Normal file
@@ -0,0 +1,114 @@
package syncer

import (
    "fmt"
    "gemini-balancer/internal/store"
    "log"
    "sync"
    "time"
)

// LoaderFunc loads a fresh copy of the cached value.
type LoaderFunc[T any] func() (T, error)

// CacheSyncer keeps an in-process cache of T and reloads it whenever an
// invalidation notification arrives on the store's pub/sub channel.
type CacheSyncer[T any] struct {
    mu          sync.RWMutex
    cache       T
    loader      LoaderFunc[T]
    store       store.Store
    channelName string
    stopChan    chan struct{}
    wg          sync.WaitGroup
}

// NewCacheSyncer performs the initial load and starts the update listener.
func NewCacheSyncer[T any](
    loader LoaderFunc[T],
    store store.Store,
    channelName string,
) (*CacheSyncer[T], error) {
    s := &CacheSyncer[T]{
        loader:      loader,
        store:       store,
        channelName: channelName,
        stopChan:    make(chan struct{}),
    }
    if err := s.reload(); err != nil {
        return nil, fmt.Errorf("initial load for %s failed: %w", channelName, err)
    }
    s.wg.Add(1)
    go s.listenForUpdates()
    return s, nil
}

// Get, Invalidate, Stop, and reload methods.

func (s *CacheSyncer[T]) Get() T {
    s.mu.RLock()
    defer s.mu.RUnlock()
    return s.cache
}

func (s *CacheSyncer[T]) Invalidate() error {
    log.Printf("INFO: Publishing invalidation notification on channel '%s'", s.channelName)
    return s.store.Publish(s.channelName, []byte("reload"))
}

func (s *CacheSyncer[T]) Stop() {
    close(s.stopChan)
    s.wg.Wait()
    log.Printf("INFO: CacheSyncer for channel '%s' stopped.", s.channelName)
}

func (s *CacheSyncer[T]) reload() error {
    log.Printf("INFO: Reloading cache for channel '%s'...", s.channelName)
    newData, err := s.loader()
    if err != nil {
        log.Printf("ERROR: Failed to reload cache for '%s': %v", s.channelName, err)
        return err
    }
    s.mu.Lock()
    s.cache = newData
    s.mu.Unlock()
    log.Printf("INFO: Cache for channel '%s' reloaded successfully.", s.channelName)
    return nil
}

// listenForUpdates subscribes to the invalidation channel and reloads the
// cache on every notification, re-subscribing with a short backoff on failure.
func (s *CacheSyncer[T]) listenForUpdates() {
    defer s.wg.Done()
    for {
        select {
        case <-s.stopChan:
            return
        default:
        }

        subscription, err := s.store.Subscribe(s.channelName)
        if err != nil {
            log.Printf("ERROR: Failed to subscribe to '%s', retrying in 5s: %v", s.channelName, err)
            time.Sleep(5 * time.Second)
            continue
        }
        log.Printf("INFO: Subscribed to channel '%s' for cache invalidation.", s.channelName)

    subscriberLoop:
        for {
            select {
            case _, ok := <-subscription.Channel():
                if !ok {
                    log.Printf("WARN: Subscription channel '%s' closed, will re-subscribe.", s.channelName)
                    break subscriberLoop
                }
                log.Printf("INFO: Received invalidation notification on '%s', reloading cache.", s.channelName)
                if err := s.reload(); err != nil {
                    log.Printf("ERROR: Failed to reload cache for '%s' after notification: %v", s.channelName, err)
                }
            case <-s.stopChan:
                subscription.Close()
                return
            }
        }
        subscription.Close()
    }
}
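A hedged usage sketch for CacheSyncer. The loader and channel name are hypothetical; the Get/Invalidate flow and type inference over the generic loader are what the type above provides.

    loadGroups := func() (map[string]int, error) {
        // e.g. read the group table from the database
        return map[string]int{"default": 1}, nil
    }
    sy, err := syncer.NewCacheSyncer(loadGroups, st, "groups:invalidate")
    if err != nil {
        log.Fatal(err)
    }
    defer sy.Stop()

    groups := sy.Get()  // served from the in-process cache
    _ = sy.Invalidate() // broadcasts "reload" to every instance
    _ = groups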
214
internal/task/task.go
Normal file
@@ -0,0 +1,214 @@
// Filename: internal/task/task.go (final calibrated version)
package task

import (
    "encoding/json"
    "errors"
    "fmt"
    "gemini-balancer/internal/store"
    "time"

    "github.com/sirupsen/logrus"
)

const (
    ResultTTL = 60 * time.Minute
)

// Reporter defines how domain code interacts with the task service.
type Reporter interface {
    StartTask(keyGroupID uint, taskType, resourceID string, total int, timeout time.Duration) (*Status, error)
    EndTaskByID(taskID, resourceID string, result any, taskErr error)
    UpdateProgressByID(taskID string, processed int) error
    UpdateTotalByID(taskID string, total int) error
}

// Status represents the full state of a background task.
type Status struct {
    ID              string     `json:"id"`
    TaskType        string     `json:"task_type"`
    IsRunning       bool       `json:"is_running"`
    ResourceID      string     `json:"resource_id,omitempty"`
    Processed       int        `json:"processed"`
    Total           int        `json:"total"`
    Result          any        `json:"result,omitempty"`
    Error           string     `json:"error,omitempty"`
    StartedAt       time.Time  `json:"started_at"`
    FinishedAt      *time.Time `json:"finished_at,omitempty"`
    DurationSeconds float64    `json:"duration_seconds,omitempty"`
}

// Task is the core task-management service.
type Task struct {
    store  store.Store
    logger *logrus.Entry
}

// NewTask is the Task constructor.
func NewTask(store store.Store, logger *logrus.Logger) *Task {
    return &Task{
        store:  store,
        logger: logger.WithField("component", "TaskService📋"),
    }
}

var _ Reporter = (*Task)(nil)

func (s *Task) getResourceLockKey(resourceID string) string {
    return fmt.Sprintf("task:lock:%s", resourceID)
}

func (s *Task) getTaskDataKey(taskID string) string {
    return fmt.Sprintf("task:data:%s", taskID)
}

// getIsRunningFlagKey returns the key of the atomic "running" flag.
func (s *Task) getIsRunningFlagKey(taskID string) string {
    return fmt.Sprintf("task:running:%s", taskID)
}

func (s *Task) StartTask(keyGroupID uint, taskType, resourceID string, total int, timeout time.Duration) (*Status, error) {
    lockKey := s.getResourceLockKey(resourceID)

    if existingTaskID, err := s.store.Get(lockKey); err == nil && len(existingTaskID) > 0 {
        return nil, fmt.Errorf("a task is already running for this resource (ID: %s)", string(existingTaskID))
    }

    taskID := fmt.Sprintf("%d-%d", time.Now().UnixNano(), keyGroupID)
    taskKey := s.getTaskDataKey(taskID)
    runningFlagKey := s.getIsRunningFlagKey(taskID)
    status := &Status{
        ID:         taskID,
        TaskType:   taskType,
        IsRunning:  true,
        ResourceID: resourceID,
        Total:      total,
        StartedAt:  time.Now(),
    }
    statusBytes, err := json.Marshal(status)
    if err != nil {
        return nil, fmt.Errorf("failed to serialize new task status: %w", err)
    }

    if timeout == 0 {
        timeout = ResultTTL * 24
    }

    if err := s.store.Set(lockKey, []byte(taskID), timeout); err != nil {
        return nil, fmt.Errorf("failed to acquire task resource lock: %w", err)
    }
    if err := s.store.Set(taskKey, statusBytes, timeout); err != nil {
        _ = s.store.Del(lockKey)
        return nil, fmt.Errorf("failed to set new task data in store: %w", err)
    }

    // Create a separate "running" flag whose existence can be checked atomically.
    if err := s.store.Set(runningFlagKey, []byte("1"), timeout); err != nil {
        _ = s.store.Del(lockKey)
        _ = s.store.Del(taskKey)
        return nil, fmt.Errorf("failed to set task running flag: %w", err)
    }
    return status, nil
}

func (s *Task) EndTaskByID(taskID, resourceID string, resultData any, taskErr error) {
    lockKey := s.getResourceLockKey(resourceID)
    defer func() {
        if err := s.store.Del(lockKey); err != nil {
            s.logger.WithError(err).Warnf("Failed to release resource lock '%s' for task %s.", lockKey, taskID)
        }
    }()
    runningFlagKey := s.getIsRunningFlagKey(taskID)
    _ = s.store.Del(runningFlagKey)
    status, err := s.GetStatus(taskID)
    if err != nil {
        s.logger.WithError(err).Errorf("Could not get task status for task ID %s during EndTask. Lock has been released, but task data may be stale.", taskID)
        return
    }
    if !status.IsRunning {
        s.logger.Warnf("EndTaskByID called for an already finished task: %s", taskID)
        return
    }
    now := time.Now()
    status.IsRunning = false
    status.FinishedAt = &now
    status.DurationSeconds = now.Sub(status.StartedAt).Seconds()
    if taskErr != nil {
        status.Error = taskErr.Error()
    } else {
        status.Result = resultData
    }
    updatedTaskBytes, _ := json.Marshal(status)
    taskKey := s.getTaskDataKey(taskID)
    if err := s.store.Set(taskKey, updatedTaskBytes, ResultTTL); err != nil {
        s.logger.WithError(err).Errorf("Failed to save final status for task %s.", taskID)
    }
}

// GetStatus looks up a task's status by ID, for external callers such as API handlers.
func (s *Task) GetStatus(taskID string) (*Status, error) {
    taskKey := s.getTaskDataKey(taskID)
    statusBytes, err := s.store.Get(taskKey)
    if err != nil {
        if errors.Is(err, store.ErrNotFound) {
            return nil, errors.New("task not found")
        }
        return nil, fmt.Errorf("failed to get task status from store: %w", err)
    }

    var status Status
    if err := json.Unmarshal(statusBytes, &status); err != nil {
        return nil, fmt.Errorf("corrupted task data in store for ID %s", taskID)
    }

    if !status.IsRunning && status.FinishedAt != nil {
        status.DurationSeconds = status.FinishedAt.Sub(status.StartedAt).Seconds()
    }

    return &status, nil
}

// updateTask applies an updater to a task's status, but only while the task
// is still marked as running.
func (s *Task) updateTask(taskID string, updater func(status *Status)) error {
    runningFlagKey := s.getIsRunningFlagKey(taskID)
    if _, err := s.store.Get(runningFlagKey); err != nil {
        // The task has already finished; returning silently is the expected behavior.
        return nil
    }
    status, err := s.GetStatus(taskID)
    if err != nil {
        s.logger.WithError(err).Warnf("Failed to get task status for update on task %s. Update will not be saved.", taskID)
        return nil
    }
    if !status.IsRunning {
        return nil
    }
    // Apply the caller-supplied updater to mutate the status.
    updater(status)
    statusBytes, marshalErr := json.Marshal(status)
    if marshalErr != nil {
        s.logger.WithError(marshalErr).Errorf("Failed to serialize status for update on task %s.", taskID)
        return nil
    }
    taskKey := s.getTaskDataKey(taskID)
    // Use a longer TTL so a long-running task does not expire prematurely.
    if err := s.store.Set(taskKey, statusBytes, ResultTTL*24); err != nil {
        s.logger.WithError(err).Warnf("Failed to save update for task %s.", taskID)
    }
    return nil
}

// [REFACTORED] UpdateProgressByID is now a thin wrapper around the generic updater.
func (s *Task) UpdateProgressByID(taskID string, processed int) error {
    return s.updateTask(taskID, func(status *Status) {
        status.Processed = processed
    })
}

// [REFACTORED] UpdateTotalByID is likewise a thin wrapper around the generic updater.
func (s *Task) UpdateTotalByID(taskID string, total int) error {
    return s.updateTask(taskID, func(status *Status) {
        status.Total = total
    })
}
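A hedged sketch of the intended task lifecycle (the IDs, counts, and caller function are illustrative): one lock per resource, best-effort progress updates, and EndTaskByID always releasing the lock.

    func runValidation(tk *task.Task) error {
        status, err := tk.StartTask(1, "key-validation", "group-1", 100, 10*time.Minute)
        if err != nil {
            return err // another task already holds the lock for "group-1"
        }
        for i := 1; i <= 100; i++ {
            _ = tk.UpdateProgressByID(status.ID, i) // silently skipped once the task ends
        }
        tk.EndTaskByID(status.ID, "group-1", map[string]int{"ok": 100}, nil)
        return nil
    }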
34
internal/utils/parser.go
Normal file
@@ -0,0 +1,34 @@
// Filename: internal/utils/parser.go
package utils

import (
    "encoding/json"
    "regexp"
    "strings"
)

// ParseKeysFromText is a shared utility that parses a list of keys out of
// text in various formats. It is exported because it is shared across packages.
func ParseKeysFromText(text string) []string {
    var keys []string
    if json.Unmarshal([]byte(text), &keys) == nil && len(keys) > 0 {
        var cleanedKeys []string
        for _, key := range keys {
            trimmed := strings.TrimSpace(key)
            if trimmed != "" {
                cleanedKeys = append(cleanedKeys, trimmed)
            }
        }
        return cleanedKeys
    }
    keys = []string{}
    delimiters := regexp.MustCompile(`[\s,;|\n\r\t]+`)
    splitKeys := delimiters.Split(strings.TrimSpace(text), -1)
    for _, key := range splitKeys {
        trimmed := strings.TrimSpace(key)
        if trimmed != "" {
            keys = append(keys, trimmed)
        }
    }
    return keys
}
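A small behaviour sketch (a hypothetical test, not part of the commit): a JSON array and free-form delimited text both parse to the same trimmed list.

// parser_test.go (hypothetical)
package utils

import (
    "reflect"
    "testing"
)

func TestParseKeysFromText(t *testing.T) {
    if got := ParseKeysFromText(`["a", " b "]`); !reflect.DeepEqual(got, []string{"a", "b"}) {
        t.Fatalf("json input: got %v", got)
    }
    if got := ParseKeysFromText("a, b;c\nd"); !reflect.DeepEqual(got, []string{"a", "b", "c", "d"}) {
        t.Fatalf("delimited input: got %v", got)
    }
}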
20
internal/utils/validators.go
Normal file
@@ -0,0 +1,20 @@
// Filename: internal/utils/validators.go
package utils

import "regexp"

const MaxGroupNameLength = 32

// groupNameRegex validates that a group name consists only of lowercase letters, numbers, and hyphens.
var groupNameRegex = regexp.MustCompile(`^[a-z0-9-]+$`)

// IsValidGroupName checks if the provided string is a valid group name.
func IsValidGroupName(name string) bool {
    if name == "" {
        return false
    }
    if len(name) > MaxGroupNameLength {
        return false
    }
    return groupNameRegex.MatchString(name)
}
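Quick behaviour check for the validator above (example inputs only):

    utils.IsValidGroupName("gemini-pro-1") // true
    utils.IsValidGroupName("Gemini_Pro")   // false: uppercase and underscore
    utils.IsValidGroupName("")             // false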
43
internal/webhandlers/auth_handler.go
Normal file
@@ -0,0 +1,43 @@
// Filename: internal/webhandlers/auth_handler.go (final modernization pass)
package webhandlers

import (
    "gemini-balancer/internal/middleware"
    "gemini-balancer/internal/service" // [core refactor] depends on the service layer
    "net/http"

    "github.com/gin-gonic/gin"
)

// WebAuthHandler [core refactor] dependencies pruned; SecurityService is injected.
type WebAuthHandler struct {
    securityService *service.SecurityService
}

// NewWebAuthHandler [core refactor] updated constructor.
func NewWebAuthHandler(securityService *service.SecurityService) *WebAuthHandler {
    return &WebAuthHandler{
        securityService: securityService,
    }
}

// ShowLoginPage is unchanged.
func (h *WebAuthHandler) ShowLoginPage(c *gin.Context) {
    errMsg := c.Query("error")
    from := c.Query("from") // lets a failed login return to its originating page
    c.HTML(http.StatusOK, "auth.html", gin.H{
        "error": errMsg,
        "from":  from,
    })
}

// HandleLogin [core refactor] authentication is fully delegated to SecurityService;
// this legacy endpoint now only redirects.
func (h *WebAuthHandler) HandleLogin(c *gin.Context) {
    c.Redirect(http.StatusFound, "/login?error=DEPRECATED_LOGIN_METHOD")
}

// HandleLogout is unchanged.
func (h *WebAuthHandler) HandleLogout(c *gin.Context) {
    middleware.ClearAdminSessionCookie(c)
    c.Redirect(http.StatusFound, "/login")
}
119
internal/webhandlers/page_handler.go
Normal file
@@ -0,0 +1,119 @@
// Filename: internal/webhandlers/page_handler.go (pre-testing version)

package webhandlers

import (
    "encoding/json"
    "gemini-balancer/internal/service"
    "net/http"

    "github.com/gin-gonic/gin"
)

// PageHandler renders the server-side pages.
type PageHandler struct {
    queryService *service.DashboardQueryService
}

// NewPageHandler injects the new dependency.
func NewPageHandler(queryService *service.DashboardQueryService) *PageHandler {
    return &PageHandler{queryService: queryService}
}

// ShowDashboardPage renders the modernized monitoring dashboard.
func (h *PageHandler) ShowDashboardPage(c *gin.Context) {
    // [Core refactor] Read from the fast cache to prepare data for the template.
    overviewData, err := h.queryService.GetDashboardOverviewData()
    if err != nil {
        // In SSR mode, if the cache is not ready yet we pass an empty value or
        // defaults so the page can at least render its skeleton instead of failing.
        c.HTML(http.StatusOK, "dashboard.html", gin.H{
            "PageID":    "dashboard",
            "pageTitle": "监控面板",
            "overview":  nil, // or a default empty struct instance
            "error":     "Dashboard data is currently being generated. Please refresh in a moment.",
            "ssr_error": err.Error(),
        })
        return
    }
    c.HTML(http.StatusOK, "dashboard.html", gin.H{
        "PageID":    "dashboard",
        "pageTitle": "监控面板",
        "overview":  overviewData,
        "ssr_error": nil,
    })
}

// ShowConfigEditorPage renders the configuration editor page.
func (h *PageHandler) ShowConfigEditorPage(c *gin.Context) {
    c.HTML(http.StatusOK, "settings.html", gin.H{
        "PageID":      "settings",
        "pageTitle":   "配置编辑",
        "currentPage": "settings",
    })
}

// ShowErrorLogsPage renders the error-log page.
func (h *PageHandler) ShowErrorLogsPage(c *gin.Context) {
    c.HTML(http.StatusOK, "logs.html", gin.H{
        "PageID":    "logs",
        "pageTitle": "错误日志",
    })
}

// ShowKeysPage renders the key-management page.
func (h *PageHandler) ShowKeysPage(c *gin.Context) {
    c.HTML(http.StatusOK, "keys.html", gin.H{
        "PageID":    "keys",
        "pageTitle": "密钥管理",
    })
}

// ShowTasksPage renders the M:N testing page.
func (h *PageHandler) ShowTasksPage(c *gin.Context) {
    c.HTML(http.StatusOK, "tasks.html", gin.H{
        "PageID":    "tasks",
        "pageTitle": "计划任务",
    })
}

// ShowChatPage renders the web chat page.
func (h *PageHandler) ShowChatPage(c *gin.Context) {
    c.HTML(http.StatusOK, "chat.html", gin.H{
        "PageID":    "chat",
        "pageTitle": "Webchat",
    })
}

// ShowDashboardDebugPage only fetches data from DashboardQueryService and
// renders it as a JSON string.
func (h *PageHandler) ShowDashboardDebugPage(c *gin.Context) {
    // 1. Fetch data with exactly the same method as the real dashboard.
    overviewData, err := h.queryService.GetDashboardOverviewData()
    if err != nil {
        // On failure, print the error message directly.
        c.HTML(http.StatusInternalServerError, "dashboard_debug.html", gin.H{
            "PageID":        "dashboard_debug",
            "pageTitle":     "Dashboard 诊断 - 错误",
            "overview_json": "获取数据时发生错误: " + err.Error(),
        })
        return
    }
    // 2. Pretty-print the fetched Go data structure as JSON.
    // MarshalIndent keeps the output indented and readable.
    jsonData, err := json.MarshalIndent(overviewData, "", "  ")
    if err != nil {
        // If JSON serialization fails, print that error as well.
        c.HTML(http.StatusInternalServerError, "dashboard_debug.html", gin.H{
            "PageID":        "dashboard_debug",
            "pageTitle":     "Dashboard 诊断 - 错误",
            "overview_json": "序列化JSON时发生错误: " + err.Error(),
        })
        return
    }
    // 3. Hand the JSON string straight to a minimal HTML template.
    c.HTML(http.StatusOK, "dashboard_debug.html", gin.H{
        "PageID":        "dashboard-debug",
        "pageTitle":     "Dashboard 诊断 - 成功",
        "overview_json": string(jsonData), // convert []byte to string
    })
}