Fix basepool & 优化 repo

This commit is contained in:
XOF
2025-11-23 22:42:58 +08:00
parent 2b0b9b67dc
commit 6c7283d51b
16 changed files with 1312 additions and 723 deletions

View File

@@ -28,12 +28,6 @@ type GeminiChannel struct {
httpClient *http.Client
}
// 用于安全提取信息的本地结构体
type requestMetadata struct {
Model string `json:"model"`
Stream bool `json:"stream"`
}
func NewGeminiChannel(logger *logrus.Logger, cfg *models.SystemSettings) *GeminiChannel {
transport := &http.Transport{
Proxy: http.ProxyFromEnvironment,
@@ -47,38 +41,50 @@ func NewGeminiChannel(logger *logrus.Logger, cfg *models.SystemSettings) *Gemini
logger: logger,
httpClient: &http.Client{
Transport: transport,
Timeout: 0,
Timeout: 0, // Timeout is handled by the request context
},
}
}
// TransformRequest
func (ch *GeminiChannel) TransformRequest(c *gin.Context, requestBody []byte) (newBody []byte, modelName string, err error) {
modelName = ch.ExtractModel(c, requestBody)
return requestBody, modelName, nil
}
// ExtractModel
func (ch *GeminiChannel) ExtractModel(c *gin.Context, bodyBytes []byte) string {
return ch.extractModelFromRequest(c, bodyBytes)
}
// 统一的模型提取逻辑优先从请求体解析失败则回退到从URL路径解析。
func (ch *GeminiChannel) extractModelFromRequest(c *gin.Context, bodyBytes []byte) string {
var p struct {
Model string `json:"model"`
}
_ = json.Unmarshal(requestBody, &p)
modelName = strings.TrimPrefix(p.Model, "models/")
if modelName == "" {
modelName = ch.extractModelFromPath(c.Request.URL.Path)
if json.Unmarshal(bodyBytes, &p) == nil && p.Model != "" {
return strings.TrimPrefix(p.Model, "models/")
}
return requestBody, modelName, nil
return ch.extractModelFromPath(c.Request.URL.Path)
}
func (ch *GeminiChannel) extractModelFromPath(path string) string {
parts := strings.Split(path, "/")
for _, part := range parts {
// 覆盖更多模型名称格式
if strings.HasPrefix(part, "gemini-") || strings.HasPrefix(part, "text-") || strings.HasPrefix(part, "embedding-") {
modelPart := strings.Split(part, ":")[0]
return modelPart
return strings.Split(part, ":")[0]
}
}
return ""
}
// IsOpenAICompatibleRequest 通过纯粹的字符串操作判断,不依赖 gin.Context。
func (ch *GeminiChannel) IsOpenAICompatibleRequest(c *gin.Context) bool {
path := c.Request.URL.Path
return ch.isOpenAIPath(c.Request.URL.Path)
}
func (ch *GeminiChannel) isOpenAIPath(path string) bool {
return strings.Contains(path, "/v1/chat/completions") || strings.Contains(path, "/v1/embeddings")
}
@@ -88,25 +94,28 @@ func (ch *GeminiChannel) ValidateKey(
targetURL string,
timeout time.Duration,
) *CustomErrors.APIError {
client := &http.Client{
Timeout: timeout,
}
client := &http.Client{Timeout: timeout}
req, err := http.NewRequestWithContext(ctx, "GET", targetURL, nil)
if err != nil {
return CustomErrors.NewAPIError(CustomErrors.ErrInternalServer, "failed to create validation request: "+err.Error())
return CustomErrors.NewAPIError(CustomErrors.ErrInternalServer, "failed to create validation request")
}
ch.ModifyRequest(req, apiKey)
resp, err := client.Do(req)
if err != nil {
return CustomErrors.NewAPIError(CustomErrors.ErrBadGateway, "failed to send validation request: "+err.Error())
return CustomErrors.NewAPIError(CustomErrors.ErrBadGateway, "validation request failed: "+err.Error())
}
defer resp.Body.Close()
if resp.StatusCode >= 200 && resp.StatusCode < 300 {
return nil
}
errorBody, _ := io.ReadAll(resp.Body)
parsedMessage := CustomErrors.ParseUpstreamError(errorBody)
return &CustomErrors.APIError{
HTTPStatus: resp.StatusCode,
Code: fmt.Sprintf("UPSTREAM_%d", resp.StatusCode),
@@ -115,10 +124,6 @@ func (ch *GeminiChannel) ValidateKey(
}
func (ch *GeminiChannel) ModifyRequest(req *http.Request, apiKey *models.APIKey) {
// TODO: [Future Refactoring] Decouple auth logic from URL path.
// The authentication method (e.g., Bearer token vs. API key in query) should ideally be a property
// of the UpstreamEndpoint or a new "AuthProfile" entity, rather than being hardcoded based on URL patterns.
// This would make the channel more generic and adaptable to new upstream provider types.
if strings.Contains(req.URL.Path, "/v1beta/openai/") {
req.Header.Set("Authorization", "Bearer "+apiKey.APIKey)
} else {
@@ -133,24 +138,22 @@ func (ch *GeminiChannel) IsStreamRequest(c *gin.Context, bodyBytes []byte) bool
if strings.HasSuffix(c.Request.URL.Path, ":streamGenerateContent") {
return true
}
var meta requestMetadata
if err := json.Unmarshal(bodyBytes, &meta); err == nil {
var meta struct {
Stream bool `json:"stream"`
}
if json.Unmarshal(bodyBytes, &meta) == nil {
return meta.Stream
}
return false
}
func (ch *GeminiChannel) ExtractModel(c *gin.Context, bodyBytes []byte) string {
_, modelName, _ := ch.TransformRequest(c, bodyBytes)
return modelName
}
// RewritePath 使用 url.JoinPath 保证路径拼接的正确性。
func (ch *GeminiChannel) RewritePath(basePath, originalPath string) string {
tempCtx := &gin.Context{Request: &http.Request{URL: &url.URL{Path: originalPath}}}
var rewrittenSegment string
if ch.IsOpenAICompatibleRequest(tempCtx) {
var apiEndpoint string
if ch.isOpenAIPath(originalPath) {
v1Index := strings.LastIndex(originalPath, "/v1/")
var apiEndpoint string
if v1Index != -1 {
apiEndpoint = originalPath[v1Index+len("/v1/"):]
} else {
@@ -158,69 +161,76 @@ func (ch *GeminiChannel) RewritePath(basePath, originalPath string) string {
}
rewrittenSegment = "v1beta/openai/" + apiEndpoint
} else {
tempPath := originalPath
if strings.HasPrefix(tempPath, "/v1/") {
tempPath = "/v1beta/" + strings.TrimPrefix(tempPath, "/v1/")
if strings.HasPrefix(originalPath, "/v1/") {
rewrittenSegment = "v1beta/" + strings.TrimPrefix(originalPath, "/v1/")
} else {
rewrittenSegment = strings.TrimPrefix(originalPath, "/")
}
rewrittenSegment = strings.TrimPrefix(tempPath, "/")
}
trimmedBasePath := strings.TrimSuffix(basePath, "/")
pathToJoin := rewrittenSegment
trimmedBasePath := strings.TrimSuffix(basePath, "/")
// 防止版本号重复拼接,例如 basePath 是 /v1beta而重写段也是 v1beta/..
versionPrefixes := []string{"v1beta", "v1"}
for _, prefix := range versionPrefixes {
if strings.HasSuffix(trimmedBasePath, "/"+prefix) && strings.HasPrefix(pathToJoin, prefix+"/") {
pathToJoin = strings.TrimPrefix(pathToJoin, prefix+"/")
if strings.HasSuffix(trimmedBasePath, "/"+prefix) && strings.HasPrefix(rewrittenSegment, prefix+"/") {
rewrittenSegment = strings.TrimPrefix(rewrittenSegment, prefix+"/")
break
}
}
finalPath, err := url.JoinPath(trimmedBasePath, pathToJoin)
finalPath, err := url.JoinPath(trimmedBasePath, rewrittenSegment)
if err != nil {
return trimmedBasePath + "/" + strings.TrimPrefix(pathToJoin, "/")
// 回退到简单的字符串拼接
return trimmedBasePath + "/" + strings.TrimPrefix(rewrittenSegment, "/")
}
return finalPath
}
func (ch *GeminiChannel) ModifyResponse(resp *http.Response) error {
// 这是一个桩实现,暂时不需要任何逻辑。
return nil
return nil // 桩实现
}
func (ch *GeminiChannel) HandleError(c *gin.Context, err error) {
// 这是一个桩实现,暂时不需要任何逻辑。
// 桩实现
}
// ==========================================================
// ================== “智能路由”的核心引擎 ===================
// ==========================================================
func (ch *GeminiChannel) ProcessSmartStreamRequest(c *gin.Context, params SmartRequestParams) {
log := ch.logger.WithField("correlation_id", params.CorrelationID)
targetURL, err := url.Parse(params.UpstreamURL)
if err != nil {
log.WithError(err).Error("Failed to parse upstream URL")
errToJSON(c, CustomErrors.NewAPIError(CustomErrors.ErrInternalServer, "Invalid upstream URL format"))
log.WithError(err).Error("Invalid upstream URL")
errToJSON(c, CustomErrors.NewAPIError(CustomErrors.ErrInternalServer, "Invalid upstream URL"))
return
}
targetURL.Path = c.Request.URL.Path
targetURL.RawQuery = c.Request.URL.RawQuery
initialReq, err := http.NewRequestWithContext(c.Request.Context(), "POST", targetURL.String(), bytes.NewReader(params.RequestBody))
if err != nil {
log.WithError(err).Error("Failed to create initial smart request")
errToJSON(c, CustomErrors.NewAPIError(CustomErrors.ErrInternalServer, err.Error()))
log.WithError(err).Error("Failed to create initial smart stream request")
errToJSON(c, CustomErrors.NewAPIError(CustomErrors.ErrInternalServer, "Failed to create request"))
return
}
ch.ModifyRequest(initialReq, params.APIKey)
initialReq.Header.Del("Authorization")
resp, err := ch.httpClient.Do(initialReq)
if err != nil {
log.WithError(err).Error("Initial smart request failed")
errToJSON(c, CustomErrors.NewAPIError(CustomErrors.ErrBadGateway, err.Error()))
log.WithError(err).Error("Initial smart stream request failed")
errToJSON(c, CustomErrors.NewAPIError(CustomErrors.ErrBadGateway, "Request to upstream failed"))
return
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
log.Warnf("Initial request received non-200 status: %d", resp.StatusCode)
standardizedResp := ch.standardizeError(resp, params.LogTruncationLimit, log)
defer standardizedResp.Body.Close()
c.Writer.WriteHeader(standardizedResp.StatusCode)
for key, values := range standardizedResp.Header {
for _, value := range values {
@@ -228,45 +238,71 @@ func (ch *GeminiChannel) ProcessSmartStreamRequest(c *gin.Context, params SmartR
}
}
io.Copy(c.Writer, standardizedResp.Body)
params.EventLogger.IsSuccess = false
params.EventLogger.StatusCode = resp.StatusCode
return
}
ch.processStreamAndRetry(c, initialReq.Header, resp.Body, params, log)
}
func (ch *GeminiChannel) processStreamAndRetry(
c *gin.Context, initialRequestHeaders http.Header, initialReader io.ReadCloser, params SmartRequestParams, log *logrus.Entry,
c *gin.Context,
initialRequestHeaders http.Header,
initialReader io.ReadCloser,
params SmartRequestParams,
log *logrus.Entry,
) {
defer initialReader.Close()
c.Writer.Header().Set("Content-Type", "text/event-stream; charset=utf-8")
c.Writer.Header().Set("Cache-Control", "no-cache")
c.Writer.Header().Set("Connection", "keep-alive")
flusher, _ := c.Writer.(http.Flusher)
var accumulatedText strings.Builder
consecutiveRetryCount := 0
currentReader := initialReader
maxRetries := params.MaxRetries
retryDelay := params.RetryDelay
log.Infof("Starting smart stream session. Max retries: %d", maxRetries)
for {
if c.Request.Context().Err() != nil {
log.Info("Client disconnected, stopping stream processing.")
return
}
var interruptionReason string
scanner := bufio.NewScanner(currentReader)
for scanner.Scan() {
if c.Request.Context().Err() != nil {
log.Info("Client disconnected during scan.")
return
}
line := scanner.Text()
if line == "" {
continue
}
fmt.Fprintf(c.Writer, "%s\n\n", line)
flusher.Flush()
if !strings.HasPrefix(line, "data: ") {
continue
}
data := strings.TrimPrefix(line, "data: ")
var payload models.GeminiSSEPayload
if err := json.Unmarshal([]byte(data), &payload); err != nil {
continue
}
if len(payload.Candidates) > 0 {
candidate := payload.Candidates[0]
if candidate.Content != nil && len(candidate.Content.Parts) > 0 {
@@ -285,52 +321,71 @@ func (ch *GeminiChannel) processStreamAndRetry(
}
}
currentReader.Close()
if interruptionReason == "" {
if err := scanner.Err(); err != nil {
log.WithError(err).Warn("Stream scanner encountered an error.")
interruptionReason = "SCANNER_ERROR"
} else {
log.Warn("Stream dropped unexpectedly without a finish reason.")
log.Warn("Stream connection dropped without a finish reason.")
interruptionReason = "CONNECTION_DROP"
}
}
if consecutiveRetryCount >= maxRetries {
log.Errorf("Retry limit exceeded. Last interruption: %s. Sending final error.", interruptionReason)
errData, _ := json.Marshal(map[string]interface{}{"error": map[string]interface{}{"code": http.StatusGatewayTimeout, "status": "DEADLINE_EXCEEDED", "message": fmt.Sprintf("Proxy retry limit exceeded. Last interruption: %s.", interruptionReason)}})
log.Errorf("Retry limit exceeded. Last interruption: %s. Sending final error to client.", interruptionReason)
errData, _ := json.Marshal(map[string]interface{}{
"error": map[string]interface{}{
"code": http.StatusGatewayTimeout,
"status": "DEADLINE_EXCEEDED",
"message": fmt.Sprintf("Proxy retry limit exceeded after multiple interruptions. Last reason: %s", interruptionReason),
},
})
fmt.Fprintf(c.Writer, "event: error\ndata: %s\n\n", string(errData))
flusher.Flush()
return
}
consecutiveRetryCount++
params.EventLogger.Retries = consecutiveRetryCount
log.Infof("Stream interrupted. Attempting retry %d/%d after %v.", consecutiveRetryCount, maxRetries, retryDelay)
time.Sleep(retryDelay)
retryBody, _ := buildRetryRequestBody(params.OriginalRequest, accumulatedText.String())
retryBody := buildRetryRequestBody(params.OriginalRequest, accumulatedText.String())
retryBodyBytes, _ := json.Marshal(retryBody)
retryReq, _ := http.NewRequestWithContext(c.Request.Context(), "POST", params.UpstreamURL, bytes.NewReader(retryBodyBytes))
retryReq, err := http.NewRequestWithContext(c.Request.Context(), "POST", params.UpstreamURL, bytes.NewReader(retryBodyBytes))
if err != nil {
log.WithError(err).Error("Failed to create retry request")
continue
}
retryReq.Header = initialRequestHeaders
ch.ModifyRequest(retryReq, params.APIKey)
retryReq.Header.Del("Authorization")
retryResp, err := ch.httpClient.Do(retryReq)
if err != nil || retryResp.StatusCode != http.StatusOK || retryResp.Body == nil {
if err != nil {
log.WithError(err).Errorf("Retry request failed.")
} else {
log.Errorf("Retry request received non-200 status: %d", retryResp.StatusCode)
if retryResp.Body != nil {
retryResp.Body.Close()
}
}
log.WithError(err).Error("Retry request failed")
continue
}
if retryResp.StatusCode != http.StatusOK {
log.Errorf("Retry request received non-200 status: %d", retryResp.StatusCode)
retryResp.Body.Close()
continue
}
currentReader = retryResp.Body
}
}
func buildRetryRequestBody(originalBody models.GeminiRequest, accumulatedText string) (models.GeminiRequest, error) {
// buildRetryRequestBody 正确处理多轮对话的上下文插入。
func buildRetryRequestBody(originalBody models.GeminiRequest, accumulatedText string) models.GeminiRequest {
retryBody := originalBody
// 找到最后一个 'user' 角色的消息索引
lastUserIndex := -1
for i := len(retryBody.Contents) - 1; i >= 0; i-- {
if retryBody.Contents[i].Role == "user" {
@@ -338,25 +393,26 @@ func buildRetryRequestBody(originalBody models.GeminiRequest, accumulatedText st
break
}
}
history := []models.GeminiContent{
{Role: "model", Parts: []models.Part{{Text: accumulatedText}}},
{Role: "user", Parts: []models.Part{{Text: SmartRetryPrompt}}},
}
if lastUserIndex != -1 {
// 如果找到了 'user' 消息,将历史记录插入到其后
newContents := make([]models.GeminiContent, 0, len(retryBody.Contents)+2)
newContents = append(newContents, retryBody.Contents[:lastUserIndex+1]...)
newContents = append(newContents, history...)
newContents = append(newContents, retryBody.Contents[lastUserIndex+1:]...)
retryBody.Contents = newContents
} else {
// 如果没有 'user' 消息(理论上不应发生),则直接追加
retryBody.Contents = append(retryBody.Contents, history...)
}
return retryBody, nil
}
// ===============================================
// ========= 辅助函数区 (继承并强化) =========
// ===============================================
return retryBody
}
type googleAPIError struct {
Error struct {
@@ -397,25 +453,28 @@ func truncate(s string, n int) string {
return s
}
// standardizeError
func (ch *GeminiChannel) standardizeError(resp *http.Response, truncateLimit int, log *logrus.Entry) *http.Response {
bodyBytes, err := io.ReadAll(resp.Body)
if err != nil {
log.WithError(err).Error("Failed to read upstream error body")
bodyBytes = []byte("Failed to read upstream error body: " + err.Error())
bodyBytes = []byte("Failed to read upstream error body")
}
resp.Body.Close()
log.Errorf("Upstream error body: %s", truncate(string(bodyBytes), truncateLimit))
log.Errorf("Upstream error: %s", truncate(string(bodyBytes), truncateLimit))
var standardizedPayload googleAPIError
// 即使解析失败,也要构建一个标准的错误结构体
if json.Unmarshal(bodyBytes, &standardizedPayload) != nil || standardizedPayload.Error.Code == 0 {
standardizedPayload.Error.Code = resp.StatusCode
standardizedPayload.Error.Message = http.StatusText(resp.StatusCode)
standardizedPayload.Error.Status = statusToGoogleStatus(resp.StatusCode)
standardizedPayload.Error.Details = []interface{}{map[string]string{
"@type": "proxy.upstream.error",
"@type": "proxy.upstream.unparsed.error",
"body": truncate(string(bodyBytes), truncateLimit),
}}
}
newBodyBytes, _ := json.Marshal(standardizedPayload)
newResp := &http.Response{
StatusCode: resp.StatusCode,
@@ -425,10 +484,13 @@ func (ch *GeminiChannel) standardizeError(resp *http.Response, truncateLimit int
}
newResp.Header.Set("Content-Type", "application/json; charset=utf-8")
newResp.Header.Set("Access-Control-Allow-Origin", "*")
return newResp
}
// errToJSON
func errToJSON(c *gin.Context, apiErr *CustomErrors.APIError) {
if c.IsAborted() {
return
}
c.JSON(apiErr.HTTPStatus, gin.H{"error": apiErr})
}

View File

@@ -16,6 +16,7 @@ type Config struct {
Redis RedisConfig `mapstructure:"redis"`
SessionSecret string `mapstructure:"session_secret"`
EncryptionKey string `mapstructure:"encryption_key"`
Repository RepositoryConfig `mapstructure:"repository"`
}
// DatabaseConfig 存储数据库连接信息
@@ -43,19 +44,24 @@ type RedisConfig struct {
DSN string `mapstructure:"dsn"`
}
type RepositoryConfig struct {
BasePoolTTLMinutes int `mapstructure:"base_pool_ttl_minutes"`
BasePoolTTIMinutes int `mapstructure:"base_pool_tti_minutes"`
}
// LoadConfig 从文件和环境变量加载配置
func LoadConfig() (*Config, error) {
// 设置配置文件名和路径
viper.SetConfigName("config")
viper.SetConfigType("yaml")
viper.AddConfigPath(".")
viper.AddConfigPath("/etc/gemini-balancer/") // for production
// 允许从环境变量读取
viper.SetEnvKeyReplacer(strings.NewReplacer(".", "_"))
viper.AutomaticEnv()
// 设置默认值
viper.SetDefault("server.port", "8080")
viper.SetDefault("server.port", "9000")
viper.SetDefault("log.level", "info")
viper.SetDefault("log.format", "text")
viper.SetDefault("log.enable_file", false)
@@ -67,6 +73,9 @@ func LoadConfig() (*Config, error) {
viper.SetDefault("database.conn_max_lifetime", "1h")
viper.SetDefault("encryption_key", "")
viper.SetDefault("repository.base_pool_ttl_minutes", 60)
viper.SetDefault("repository.base_pool_tti_minutes", 10)
// 读取配置文件
if err := viper.ReadInConfig(); err != nil {
if _, ok := err.(viper.ConfigFileNotFoundError); !ok {

View File

@@ -311,3 +311,8 @@ func (m *manager) checkProxyConnectivity(proxyCfg *models.ProxyConfig, timeout t
defer resp.Body.Close()
return true
}
type Manager interface {
AssignProxyIfNeeded(apiKey *models.APIKey) (*models.ProxyConfig, error)
// ... 其他需要暴露给外部服务的方法
}

View File

@@ -44,6 +44,7 @@ var (
ErrGroupNotFound = &APIError{HTTPStatus: http.StatusNotFound, Code: "GROUP_NOT_FOUND", Message: "The specified group was not found."}
ErrPermissionDenied = &APIError{HTTPStatus: http.StatusForbidden, Code: "PERMISSION_DENIED", Message: "Permission denied for this operation."}
ErrConfigurationError = &APIError{HTTPStatus: http.StatusInternalServerError, Code: "CONFIGURATION_ERROR", Message: "A configuration error prevents this request from being processed."}
ErrProxyNotAvailable = &APIError{HTTPStatus: http.StatusNotFound, Code: "PROXY_ERROR", Message: "Required proxy is not available for this request."}
ErrStateConflictMasterRevoked = &APIError{HTTPStatus: http.StatusConflict, Code: "STATE_CONFLICT_MASTER_REVOKED", Message: "Cannot perform this operation on a revoked key."}
ErrNotFound = &APIError{HTTPStatus: http.StatusNotFound, Code: "NOT_FOUND", Message: "Resource not found"}

View File

@@ -29,9 +29,7 @@ import (
"gorm.io/datatypes"
)
type proxyErrorKey int
const proxyErrKey proxyErrorKey = 0
type proxyErrorContextKey struct{}
type ProxyHandler struct {
resourceService *service.ResourceService
@@ -81,45 +79,51 @@ func (h *ProxyHandler) HandleProxy(c *gin.Context) {
h.handleListModelsRequest(c)
return
}
requestBody, err := io.ReadAll(c.Request.Body)
maxBodySize := int64(h.settingsManager.GetSettings().MaxRequestBodySizeMB * 1024 * 1024)
requestBody, err := io.ReadAll(io.LimitReader(c.Request.Body, maxBodySize))
if err != nil {
errToJSON(c, uuid.New().String(), errors.NewAPIError(errors.ErrBadRequest, "Failed to read request body"))
errToJSON(c, uuid.New().String(), errors.NewAPIError(errors.ErrBadRequest, "Request body too large or failed to read"))
return
}
c.Request.Body = io.NopCloser(bytes.NewReader(requestBody))
c.Request.ContentLength = int64(len(requestBody))
modelName := h.channel.ExtractModel(c, requestBody)
groupName := c.Param("group_name")
isPreciseRouting := groupName != ""
if !isPreciseRouting && modelName == "" {
errToJSON(c, uuid.New().String(), errors.NewAPIError(errors.ErrBadRequest, "Model not specified in the request body or URL"))
errToJSON(c, uuid.New().String(), errors.NewAPIError(errors.ErrBadRequest, "Model not specified in request"))
return
}
initialResources, err := h.getResourcesForRequest(c, modelName, groupName, isPreciseRouting)
if err != nil {
if apiErr, ok := err.(*errors.APIError); ok {
errToJSON(c, uuid.New().String(), apiErr)
} else {
errToJSON(c, uuid.New().String(), errors.NewAPIError(errors.ErrNoKeysAvailable, err.Error()))
errToJSON(c, uuid.New().String(), errors.NewAPIError(errors.ErrInternalServer, "Failed to get initial resources"))
}
return
}
finalOpConfig, err := h.groupManager.BuildOperationalConfig(initialResources.KeyGroup)
if err != nil {
h.logger.WithError(err).Error("Failed to build operational config.")
errToJSON(c, uuid.New().String(), errors.NewAPIError(errors.ErrInternalServer, "Failed to build operational configuration"))
errToJSON(c, uuid.New().String(), errors.NewAPIError(errors.ErrInternalServer, "Failed to build operational config"))
return
}
initialResources.RequestConfig = h.buildFinalRequestConfig(h.settingsManager.GetSettings(), initialResources.RequestConfig)
isOpenAICompatible := h.channel.IsOpenAICompatibleRequest(c)
if isOpenAICompatible {
h.serveTransparentProxy(c, requestBody, initialResources, finalOpConfig, modelName, groupName, isPreciseRouting)
return
}
isStream := h.channel.IsStreamRequest(c, requestBody)
systemSettings := h.settingsManager.GetSettings()
useSmartGateway := finalOpConfig.EnableSmartGateway != nil && *finalOpConfig.EnableSmartGateway
if useSmartGateway && isStream && systemSettings.EnableStreamingRetry {
if useSmartGateway && isStream && h.settingsManager.GetSettings().EnableStreamingRetry {
h.serveSmartStream(c, requestBody, initialResources, isPreciseRouting)
} else {
h.serveTransparentProxy(c, requestBody, initialResources, finalOpConfig, modelName, groupName, isPreciseRouting)
@@ -129,219 +133,307 @@ func (h *ProxyHandler) HandleProxy(c *gin.Context) {
func (h *ProxyHandler) serveTransparentProxy(c *gin.Context, requestBody []byte, initialResources *service.RequestResources, finalOpConfig *models.KeyGroupSettings, modelName, groupName string, isPreciseRouting bool) {
startTime := time.Now()
correlationID := uuid.New().String()
var finalRecorder *httptest.ResponseRecorder
var lastUsedResources *service.RequestResources
var finalProxyErr *errors.APIError
var isSuccess bool
var finalPromptTokens, finalCompletionTokens int
var actualRetries int = 0
defer func() {
if lastUsedResources == nil {
h.logger.WithField("id", correlationID).Warn("No resources were used, skipping final log event.")
return
}
finalEvent := h.createLogEvent(c, startTime, correlationID, modelName, lastUsedResources, models.LogTypeFinal, isPreciseRouting)
var finalPromptTokens, finalCompletionTokens, actualRetries int
finalEvent.RequestLog.LatencyMs = int(time.Since(startTime).Milliseconds())
finalEvent.RequestLog.IsSuccess = isSuccess
finalEvent.RequestLog.Retries = actualRetries
if isSuccess {
finalEvent.RequestLog.PromptTokens = finalPromptTokens
finalEvent.RequestLog.CompletionTokens = finalCompletionTokens
}
defer h.publishFinalLogEvent(c, startTime, correlationID, modelName, lastUsedResources,
finalRecorder, finalProxyErr, isSuccess, finalPromptTokens, finalCompletionTokens,
actualRetries, isPreciseRouting)
if finalRecorder != nil {
finalEvent.RequestLog.StatusCode = finalRecorder.Code
}
if !isSuccess {
if finalProxyErr != nil {
finalEvent.Error = finalProxyErr
finalEvent.RequestLog.ErrorCode = finalProxyErr.Code
finalEvent.RequestLog.ErrorMessage = finalProxyErr.Message
} else if finalRecorder != nil {
apiErr := errors.NewAPIErrorWithUpstream(finalRecorder.Code, fmt.Sprintf("UPSTREAM_%d", finalRecorder.Code), "Request failed after all retries.")
finalEvent.Error = apiErr
finalEvent.RequestLog.ErrorCode = apiErr.Code
finalEvent.RequestLog.ErrorMessage = apiErr.Message
}
}
eventData, err := json.Marshal(finalEvent)
if err != nil {
h.logger.WithField("id", correlationID).WithError(err).Error("Failed to marshal final log event.")
return
}
if err := h.store.Publish(context.Background(), models.TopicRequestFinished, eventData); err != nil {
h.logger.WithField("id", correlationID).WithError(err).Error("Failed to publish final log event.")
}
}()
var maxRetries int
if isPreciseRouting {
if finalOpConfig.MaxRetries != nil {
maxRetries = *finalOpConfig.MaxRetries
} else {
maxRetries = h.settingsManager.GetSettings().MaxRetries
}
} else {
maxRetries = h.settingsManager.GetSettings().MaxRetries
}
maxRetries := h.getMaxRetries(isPreciseRouting, finalOpConfig)
totalAttempts := maxRetries + 1
for attempt := 1; attempt <= totalAttempts; attempt++ {
if c.Request.Context().Err() != nil {
h.logger.WithField("id", correlationID).Info("Client disconnected, aborting retry loop.")
if finalProxyErr == nil {
finalProxyErr = errors.NewAPIError(errors.ErrBadRequest, "Client connection closed")
finalProxyErr = errors.NewAPIError(errors.ErrBadRequest, "Client disconnected")
}
break
}
var currentResources *service.RequestResources
var err error
if attempt == 1 {
currentResources = initialResources
} else {
actualRetries = attempt - 1
h.logger.WithField("id", correlationID).Infof("Retrying... getting new resources for attempt %d.", attempt)
currentResources, err = h.getResourcesForRequest(c, modelName, groupName, isPreciseRouting)
if err != nil {
h.logger.WithField("id", correlationID).Errorf("Failed to get new resources for retry, aborting: %v", err)
finalProxyErr = errors.NewAPIError(errors.ErrNoKeysAvailable, "Failed to get new resources for retry")
break
}
}
finalRequestConfig := h.buildFinalRequestConfig(h.settingsManager.GetSettings(), currentResources.RequestConfig)
currentResources.RequestConfig = finalRequestConfig
lastUsedResources = currentResources
h.logger.WithField("id", correlationID).Infof("Attempt %d/%d with KeyID %d...", attempt, totalAttempts, currentResources.APIKey.ID)
var attemptErr *errors.APIError
var attemptIsSuccess bool
resources, err := h.getResourcesForAttempt(c, attempt, initialResources, modelName, groupName, isPreciseRouting, correlationID)
if err != nil {
if apiErr, ok := err.(*errors.APIError); ok {
finalProxyErr = apiErr
} else {
finalProxyErr = errors.NewAPIError(errors.ErrInternalServer, "Failed to get resources for retry")
}
break
}
lastUsedResources = resources
if attempt > 1 {
actualRetries = attempt - 1
}
h.logger.WithField("id", correlationID).Infof("Attempt %d/%d with KeyID %d", attempt, totalAttempts, resources.APIKey.ID)
recorder, attemptErr, attemptSuccess := h.executeProxyAttempt(
c, correlationID, requestBody, resources, isPreciseRouting, groupName,
&finalPromptTokens, &finalCompletionTokens,
)
finalRecorder, finalProxyErr, isSuccess = recorder, attemptErr, attemptSuccess
h.resourceService.ReportRequestResult(resources, isSuccess, finalProxyErr)
if isSuccess {
break
}
if h.shouldStopRetrying(attempt, totalAttempts, finalProxyErr, correlationID) {
break
}
h.publishRetryLogEvent(c, startTime, correlationID, modelName, resources, recorder, attemptErr, actualRetries, isPreciseRouting)
}
h.writeFinalResponse(c, correlationID, finalRecorder, finalProxyErr)
}
func (h *ProxyHandler) executeProxyAttempt(c *gin.Context, corrID string, body []byte, res *service.RequestResources, isPrecise bool, groupName string, pTokens, cTokens *int) (*httptest.ResponseRecorder, *errors.APIError, bool) {
recorder := httptest.NewRecorder()
attemptStartTime := time.Now()
var attemptErr *errors.APIError
var isSuccess bool
connectTimeout := time.Duration(h.settingsManager.GetSettings().ConnectTimeoutSeconds) * time.Second
ctx, cancel := context.WithTimeout(c.Request.Context(), connectTimeout)
defer cancel()
attemptReq := c.Request.Clone(ctx)
attemptReq.Body = io.NopCloser(bytes.NewReader(requestBody))
if currentResources.UpstreamEndpoint == nil || currentResources.UpstreamEndpoint.URL == "" {
h.logger.WithField("id", correlationID).Errorf("Attempt %d failed: no upstream URL in resources.", attempt)
isSuccess = false
finalProxyErr = errors.NewAPIError(errors.ErrInternalServer, "No upstream URL configured for the selected resource")
continue
}
h.transparentProxy.Director = func(req *http.Request) {
targetURL, _ := url.Parse(currentResources.UpstreamEndpoint.URL)
req.URL.Scheme = targetURL.Scheme
req.URL.Host = targetURL.Host
req.Host = targetURL.Host
attemptReq.Body = io.NopCloser(bytes.NewReader(body))
attemptReq.ContentLength = int64(len(body))
h.configureProxy(corrID, res, isPrecise, groupName, &attemptErr, &isSuccess, pTokens, cTokens)
*attemptReq = *attemptReq.WithContext(context.WithValue(attemptReq.Context(), proxyErrorContextKey{}, &attemptErr))
h.transparentProxy.ServeHTTP(recorder, attemptReq)
return recorder, attemptErr, isSuccess
}
func (h *ProxyHandler) configureProxy(corrID string, res *service.RequestResources, isPrecise bool, groupName string, attemptErr **errors.APIError, isSuccess *bool, pTokens, cTokens *int) {
h.transparentProxy.Director = func(r *http.Request) {
targetURL, _ := url.Parse(res.UpstreamEndpoint.URL)
r.URL.Scheme, r.URL.Host, r.Host = targetURL.Scheme, targetURL.Host, targetURL.Host
var pureClientPath string
if isPreciseRouting {
proxyPrefix := "/proxy/" + groupName
pureClientPath = strings.TrimPrefix(req.URL.Path, proxyPrefix)
if isPrecise {
pureClientPath = strings.TrimPrefix(r.URL.Path, "/proxy/"+groupName)
} else {
pureClientPath = req.URL.Path
pureClientPath = r.URL.Path
}
finalPath := h.channel.RewritePath(targetURL.Path, pureClientPath)
req.URL.Path = finalPath
h.logger.WithFields(logrus.Fields{
"correlation_id": correlationID,
"attempt": attempt,
"key_id": currentResources.APIKey.ID,
"base_upstream_url": currentResources.UpstreamEndpoint.URL,
"final_request_url": req.URL.String(),
}).Infof("Director constructed final upstream request URL.")
req.Header.Del("Authorization")
h.channel.ModifyRequest(req, currentResources.APIKey)
req.Header.Set("X-Correlation-ID", correlationID)
*req = *req.WithContext(context.WithValue(req.Context(), proxyErrKey, &attemptErr))
r.URL.Path = h.channel.RewritePath(targetURL.Path, pureClientPath)
r.Header.Del("Authorization")
h.channel.ModifyRequest(r, res.APIKey)
r.Header.Set("X-Correlation-ID", corrID)
}
transport := h.transparentProxy.Transport.(*http.Transport)
if currentResources.ProxyConfig != nil {
proxyURLStr := fmt.Sprintf("%s://%s", currentResources.ProxyConfig.Protocol, currentResources.ProxyConfig.Address)
proxyURL, err := url.Parse(proxyURLStr)
if err == nil {
if res.ProxyConfig != nil {
proxyURLStr := fmt.Sprintf("%s://%s", res.ProxyConfig.Protocol, res.ProxyConfig.Address)
if proxyURL, err := url.Parse(proxyURLStr); err == nil {
transport.Proxy = http.ProxyURL(proxyURL)
} else {
transport.Proxy = http.ProxyFromEnvironment
}
} else {
transport.Proxy = http.ProxyFromEnvironment
}
h.transparentProxy.ModifyResponse = func(resp *http.Response) error {
defer resp.Body.Close()
var reader io.ReadCloser
var err error
isGzipped := resp.Header.Get("Content-Encoding") == "gzip"
if isGzipped {
reader, err = gzip.NewReader(resp.Body)
h.transparentProxy.ModifyResponse = h.createModifyResponseFunc(attemptErr, isSuccess, pTokens, cTokens)
}
func (h *ProxyHandler) createModifyResponseFunc(attemptErr **errors.APIError, isSuccess *bool, pTokens, cTokens *int) func(*http.Response) error {
return func(resp *http.Response) error {
var reader io.ReadCloser = resp.Body
if resp.Header.Get("Content-Encoding") == "gzip" {
gzReader, err := gzip.NewReader(resp.Body)
if err != nil {
h.logger.WithError(err).Error("Failed to create gzip reader")
reader = resp.Body
} else {
reader = gzReader
resp.Header.Del("Content-Encoding")
}
defer reader.Close()
} else {
reader = resp.Body
}
defer reader.Close()
bodyBytes, err := io.ReadAll(reader)
if err != nil {
attemptErr = errors.NewAPIError(errors.ErrBadGateway, "Failed to read upstream response: "+err.Error())
resp.Body = io.NopCloser(bytes.NewReader([]byte(attemptErr.Message)))
*attemptErr = errors.NewAPIError(errors.ErrBadGateway, "Failed to read upstream response")
resp.Body = io.NopCloser(bytes.NewReader([]byte{}))
return nil
}
if resp.StatusCode < 400 {
attemptIsSuccess = true
finalPromptTokens, finalCompletionTokens = extractUsage(bodyBytes)
*isSuccess = true
*pTokens, *cTokens = extractUsage(bodyBytes)
} else {
parsedMsg := errors.ParseUpstreamError(bodyBytes)
attemptErr = errors.NewAPIErrorWithUpstream(resp.StatusCode, fmt.Sprintf("UPSTREAM_%d", resp.StatusCode), parsedMsg)
*attemptErr = errors.NewAPIErrorWithUpstream(resp.StatusCode, fmt.Sprintf("UPSTREAM_%d", resp.StatusCode), parsedMsg)
}
resp.Body = io.NopCloser(bytes.NewReader(bodyBytes))
return nil
}
h.transparentProxy.ServeHTTP(recorder, attemptReq)
finalRecorder = recorder
finalProxyErr = attemptErr
isSuccess = attemptIsSuccess
h.resourceService.ReportRequestResult(currentResources, isSuccess, finalProxyErr)
if isSuccess {
break
}
func (h *ProxyHandler) transparentProxyErrorHandler(rw http.ResponseWriter, r *http.Request, err error) {
corrID := r.Header.Get("X-Correlation-ID")
log := h.logger.WithField("id", corrID)
log.Errorf("Transparent proxy encountered an error: %v", err)
errPtr, ok := r.Context().Value(proxyErrorContextKey{}).(**errors.APIError)
if !ok || errPtr == nil {
log.Error("FATAL: proxyErrorContextKey not found in context for error handler.")
defaultErr := errors.NewAPIError(errors.ErrBadGateway, "An unexpected proxy error occurred")
writeErrorToResponse(rw, defaultErr)
return
}
isUnretryableError := false
if finalProxyErr != nil {
if errors.IsUnretryableRequestError(finalProxyErr.Message) {
isUnretryableError = true
h.logger.WithField("id", correlationID).Warnf("Attempt %d failed with unretryable request error. Aborting retries. Message: %s", attempt, finalProxyErr.Message)
if *errPtr == nil {
if errors.IsClientNetworkError(err) {
*errPtr = errors.NewAPIError(errors.ErrBadRequest, "Client connection closed")
} else {
*errPtr = errors.NewAPIError(errors.ErrBadGateway, err.Error())
}
}
if attempt >= totalAttempts || isUnretryableError {
break
writeErrorToResponse(rw, *errPtr)
}
func (h *ProxyHandler) getResourcesForAttempt(c *gin.Context, attempt int, initialResources *service.RequestResources, modelName, groupName string, isPreciseRouting bool, correlationID string) (*service.RequestResources, error) {
if attempt == 1 {
return initialResources, nil
}
retryEvent := h.createLogEvent(c, startTime, correlationID, modelName, currentResources, models.LogTypeRetry, isPreciseRouting)
retryEvent.LatencyMs = int(time.Since(attemptStartTime).Milliseconds())
retryEvent.IsSuccess = false
retryEvent.StatusCode = recorder.Code
retryEvent.Retries = actualRetries
if attemptErr != nil {
retryEvent.Error = attemptErr
retryEvent.ErrorCode = attemptErr.Code
retryEvent.ErrorMessage = attemptErr.Message
h.logger.WithField("id", correlationID).Infof("Retrying... getting new resources for attempt %d.", attempt)
resources, err := h.getResourcesForRequest(c, modelName, groupName, isPreciseRouting)
if err != nil {
return nil, err
}
eventData, _ := json.Marshal(retryEvent)
_ = h.store.Publish(context.Background(), models.TopicRequestFinished, eventData)
finalRequestConfig := h.buildFinalRequestConfig(h.settingsManager.GetSettings(), resources.RequestConfig)
resources.RequestConfig = finalRequestConfig
return resources, nil
}
func (h *ProxyHandler) shouldStopRetrying(attempt, totalAttempts int, err *errors.APIError, correlationID string) bool {
if attempt >= totalAttempts {
return true
}
if finalRecorder != nil {
bodyBytes := finalRecorder.Body.Bytes()
c.Writer.Header().Set("Content-Length", fmt.Sprintf("%d", len(bodyBytes)))
for k, v := range finalRecorder.Header() {
if strings.ToLower(k) != "content-length" {
if err != nil && errors.IsUnretryableRequestError(err.Message) {
h.logger.WithField("id", correlationID).Warnf("Attempt failed with unretryable request error. Aborting retries. Message: %s", err.Message)
return true
}
return false
}
func (h *ProxyHandler) writeFinalResponse(c *gin.Context, corrID string, rec *httptest.ResponseRecorder, apiErr *errors.APIError) {
if rec != nil {
for k, v := range rec.Header() {
c.Writer.Header()[k] = v
}
}
c.Writer.WriteHeader(finalRecorder.Code)
c.Writer.Write(finalRecorder.Body.Bytes())
c.Writer.WriteHeader(rec.Code)
c.Writer.Write(rec.Body.Bytes())
} else if apiErr != nil {
errToJSON(c, corrID, apiErr)
} else {
errToJSON(c, correlationID, finalProxyErr)
errToJSON(c, corrID, errors.NewAPIError(errors.ErrInternalServer, "An unknown error occurred"))
}
}
func (h *ProxyHandler) publishFinalLogEvent(c *gin.Context, startTime time.Time, corrID, modelName string, res *service.RequestResources, rec *httptest.ResponseRecorder, finalErr *errors.APIError, isSuccess bool, pTokens, cTokens, retries int, isPrecise bool) {
if res == nil {
h.logger.WithField("id", corrID).Warn("No resources were used, skipping final log event.")
return
}
event := h.createLogEvent(c, startTime, corrID, modelName, res, models.LogTypeFinal, isPrecise)
event.RequestLog.LatencyMs = int(time.Since(startTime).Milliseconds())
event.RequestLog.IsSuccess = isSuccess
event.RequestLog.Retries = retries
if isSuccess {
event.RequestLog.PromptTokens, event.RequestLog.CompletionTokens = pTokens, cTokens
}
if rec != nil {
event.RequestLog.StatusCode = rec.Code
}
if !isSuccess {
errToLog := finalErr
if errToLog == nil && rec != nil {
errToLog = errors.NewAPIErrorWithUpstream(rec.Code, fmt.Sprintf("UPSTREAM_%d", rec.Code), "Request failed after all retries.")
}
if errToLog != nil {
event.Error = errToLog
event.RequestLog.ErrorCode, event.RequestLog.ErrorMessage = errToLog.Code, errToLog.Message
}
}
eventData, err := json.Marshal(event)
if err != nil {
h.logger.WithField("id", corrID).WithError(err).Error("Failed to marshal log event")
return
}
if err := h.store.Publish(context.Background(), models.TopicRequestFinished, eventData); err != nil {
h.logger.WithField("id", corrID).WithError(err).Error("Failed to publish log event")
}
}
func (h *ProxyHandler) publishRetryLogEvent(c *gin.Context, startTime time.Time, corrID, modelName string, res *service.RequestResources, rec *httptest.ResponseRecorder, attemptErr *errors.APIError, retries int, isPrecise bool) {
retryEvent := h.createLogEvent(c, startTime, corrID, modelName, res, models.LogTypeRetry, isPrecise)
retryEvent.RequestLog.LatencyMs = int(time.Since(startTime).Milliseconds())
retryEvent.RequestLog.IsSuccess = false
retryEvent.RequestLog.StatusCode = rec.Code
retryEvent.RequestLog.Retries = retries
if attemptErr != nil {
retryEvent.Error = attemptErr
retryEvent.RequestLog.ErrorCode, retryEvent.RequestLog.ErrorMessage = attemptErr.Code, attemptErr.Message
}
eventData, err := json.Marshal(retryEvent)
if err != nil {
h.logger.WithField("id", corrID).WithError(err).Error("Failed to marshal retry log event")
return
}
if err := h.store.Publish(context.Background(), models.TopicRequestFinished, eventData); err != nil {
h.logger.WithField("id", corrID).WithError(err).Error("Failed to publish retry log event")
}
}
func (h *ProxyHandler) buildFinalRequestConfig(globalSettings *models.SystemSettings, groupConfig *models.RequestConfig) *models.RequestConfig {
finalConfig := &models.RequestConfig{
CustomHeaders: make(datatypes.JSONMap),
EnableStreamOptimizer: globalSettings.EnableStreamOptimizer,
StreamMinDelay: globalSettings.StreamMinDelay,
StreamMaxDelay: globalSettings.StreamMaxDelay,
StreamShortTextThresh: globalSettings.StreamShortTextThresh,
StreamLongTextThresh: globalSettings.StreamLongTextThresh,
StreamChunkSize: globalSettings.StreamChunkSize,
EnableFakeStream: globalSettings.EnableFakeStream,
FakeStreamInterval: globalSettings.FakeStreamInterval,
}
for k, v := range globalSettings.CustomHeaders {
finalConfig.CustomHeaders[k] = v
}
if groupConfig == nil {
return finalConfig
}
groupConfigJSON, err := json.Marshal(groupConfig)
if err != nil {
h.logger.WithError(err).Error("Failed to marshal group request config for merging.")
return finalConfig
}
if err := json.Unmarshal(groupConfigJSON, finalConfig); err != nil {
h.logger.WithError(err).Error("Failed to unmarshal group request config for merging.")
}
return finalConfig
}
func writeErrorToResponse(rw http.ResponseWriter, apiErr *errors.APIError) {
if writer, ok := rw.(interface{ Written() bool }); ok && writer.Written() {
return
}
rw.Header().Set("Content-Type", "application/json; charset=utf-8")
rw.WriteHeader(apiErr.HTTPStatus)
json.NewEncoder(rw).Encode(gin.H{"error": apiErr})
}
func (h *ProxyHandler) serveSmartStream(c *gin.Context, requestBody []byte, resources *service.RequestResources, isPreciseRouting bool) {
startTime := time.Now()
correlationID := uuid.New().String()
@@ -349,7 +441,7 @@ func (h *ProxyHandler) serveSmartStream(c *gin.Context, requestBody []byte, reso
log.Info("Smart Gateway activated for streaming request.")
var originalRequest models.GeminiRequest
if err := json.Unmarshal(requestBody, &originalRequest); err != nil {
errToJSON(c, correlationID, errors.NewAPIError(errors.ErrInvalidJSON, "Smart Gateway failed: Request body is not a valid Gemini native format. Error: "+err.Error()))
errToJSON(c, correlationID, errors.NewAPIError(errors.ErrInvalidJSON, "Invalid request format for Smart Gateway"))
return
}
systemSettings := h.settingsManager.GetSettings()
@@ -360,8 +452,14 @@ func (h *ProxyHandler) serveSmartStream(c *gin.Context, requestBody []byte, reso
if c.Writer.Status() > 0 {
requestFinishedEvent.StatusCode = c.Writer.Status()
}
eventData, _ := json.Marshal(requestFinishedEvent)
_ = h.store.Publish(context.Background(), models.TopicRequestFinished, eventData)
eventData, err := json.Marshal(requestFinishedEvent)
if err != nil {
h.logger.WithField("id", correlationID).WithError(err).Error("Failed to marshal final log event for smart stream")
return
}
if err := h.store.Publish(context.Background(), models.TopicRequestFinished, eventData); err != nil {
h.logger.WithField("id", correlationID).WithError(err).Error("Failed to publish final log event for smart stream")
}
}()
params := channel.SmartRequestParams{
CorrelationID: correlationID,
@@ -378,30 +476,6 @@ func (h *ProxyHandler) serveSmartStream(c *gin.Context, requestBody []byte, reso
h.channel.ProcessSmartStreamRequest(c, params)
}
func (h *ProxyHandler) transparentProxyErrorHandler(rw http.ResponseWriter, r *http.Request, err error) {
correlationID := r.Header.Get("X-Correlation-ID")
h.logger.WithField("id", correlationID).Errorf("Transparent proxy error: %v", err)
proxyErrPtr, exists := r.Context().Value(proxyErrKey).(**errors.APIError)
if !exists || proxyErrPtr == nil {
h.logger.WithField("id", correlationID).Error("FATAL: proxyErrorKey not found in context for error handler.")
return
}
if errors.IsClientNetworkError(err) {
*proxyErrPtr = errors.NewAPIError(errors.ErrBadRequest, "Client connection closed")
} else {
*proxyErrPtr = errors.NewAPIError(errors.ErrBadGateway, err.Error())
}
if _, ok := rw.(*httptest.ResponseRecorder); ok {
return
}
if writer, ok := rw.(interface{ Written() bool }); ok {
if writer.Written() {
return
}
}
rw.WriteHeader((*proxyErrPtr).HTTPStatus)
}
func (h *ProxyHandler) createLogEvent(c *gin.Context, startTime time.Time, corrID, modelName string, res *service.RequestResources, logType models.LogType, isPreciseRouting bool) *models.RequestFinishedEvent {
event := &models.RequestFinishedEvent{
RequestLog: models.RequestLog{
@@ -456,12 +530,14 @@ func (h *ProxyHandler) getResourcesForRequest(c *gin.Context, modelName string,
}
if isPreciseRouting {
return h.resourceService.GetResourceFromGroup(c.Request.Context(), authToken, groupName)
} else {
return h.resourceService.GetResourceFromBasePool(c.Request.Context(), authToken, modelName)
}
return h.resourceService.GetResourceFromBasePool(c.Request.Context(), authToken, modelName)
}
func errToJSON(c *gin.Context, corrID string, apiErr *errors.APIError) {
if c.IsAborted() {
return
}
c.JSON(apiErr.HTTPStatus, gin.H{
"error": apiErr,
"correlation_id": corrID,
@@ -471,7 +547,7 @@ func errToJSON(c *gin.Context, corrID string, apiErr *errors.APIError) {
type bufferPool struct{}
func (b *bufferPool) Get() []byte { return make([]byte, 32*1024) }
func (b *bufferPool) Put(bytes []byte) {}
func (b *bufferPool) Put(_ []byte) {}
func extractUsage(body []byte) (promptTokens int, completionTokens int) {
var data struct {
@@ -486,34 +562,11 @@ func extractUsage(body []byte) (promptTokens int, completionTokens int) {
return 0, 0
}
func (h *ProxyHandler) buildFinalRequestConfig(globalSettings *models.SystemSettings, groupConfig *models.RequestConfig) *models.RequestConfig {
customHeadersJSON, _ := json.Marshal(globalSettings.CustomHeaders)
var customHeadersMap datatypes.JSONMap
_ = json.Unmarshal(customHeadersJSON, &customHeadersMap)
finalConfig := &models.RequestConfig{
CustomHeaders: customHeadersMap,
EnableStreamOptimizer: globalSettings.EnableStreamOptimizer,
StreamMinDelay: globalSettings.StreamMinDelay,
StreamMaxDelay: globalSettings.StreamMaxDelay,
StreamShortTextThresh: globalSettings.StreamShortTextThresh,
StreamLongTextThresh: globalSettings.StreamLongTextThresh,
StreamChunkSize: globalSettings.StreamChunkSize,
EnableFakeStream: globalSettings.EnableFakeStream,
FakeStreamInterval: globalSettings.FakeStreamInterval,
func (h *ProxyHandler) getMaxRetries(isPreciseRouting bool, finalOpConfig *models.KeyGroupSettings) int {
if isPreciseRouting && finalOpConfig.MaxRetries != nil {
return *finalOpConfig.MaxRetries
}
if groupConfig == nil {
return finalConfig
}
groupConfigJSON, err := json.Marshal(groupConfig)
if err != nil {
h.logger.WithError(err).Error("Failed to marshal group request config for merging.")
return finalConfig
}
if err := json.Unmarshal(groupConfigJSON, finalConfig); err != nil {
h.logger.WithError(err).Error("Failed to unmarshal group request config for merging.")
return finalConfig
}
return finalConfig
return h.settingsManager.GetSettings().MaxRetries
}
func (h *ProxyHandler) handleListModelsRequest(c *gin.Context) {

View File

@@ -77,3 +77,9 @@ type APIKeyDetails struct {
CooldownUntil *time.Time `json:"cooldown_until"`
EncryptedKey string
}
// SettingsManager 定义了系统设置管理器的抽象接口。
type SettingsManager interface {
GetSettings() *SystemSettings
}

View File

@@ -11,6 +11,7 @@ type SystemSettings struct {
BlacklistThreshold int `json:"blacklist_threshold" default:"3" name:"拉黑阈值" category:"密钥设置" desc:"一个Key连续失败多少次后进入冷却状态。"`
KeyCooldownMinutes int `json:"key_cooldown_minutes" default:"10" name:"密钥冷却时长(分钟)" category:"密钥设置" desc:"一个Key进入冷却状态后需要等待的时间单位为分钟。"`
LogFlushIntervalSeconds int `json:"log_flush_interval_seconds" default:"10" name:"日志刷新间隔(秒)" category:"日志设置" desc:"异步日志写入数据库的间隔时间(秒)。"`
MaxRequestBodySizeMB int `json:"max_request_body_size_mb" default:"10" name:"最大请求体大小 (MB)" category:"请求设置" desc:"允许代理接收的最大请求体大小单位为MB。超过此大小的请求将被拒绝。"`
PollingStrategy PollingStrategy `json:"polling_strategy" default:"random" name:"全局轮询策略" category:"调度设置" desc:"智能聚合模式下,从所有可用密钥中选择一个的默认策略。可选值: sequential(顺序), random(随机), weighted(加权)。"`
@@ -41,6 +42,10 @@ type SystemSettings struct {
MaxLoginAttempts int `json:"max_login_attempts" default:"5" name:"最大登录失败次数" category:"安全设置" desc:"在一个IP被封禁前允许的连续登录失败次数。"`
IPBanDurationMinutes int `json:"ip_ban_duration_minutes" default:"15" name:"IP封禁时长(分钟)" category:"安全设置" desc:"IP被封禁的时长单位为分钟。"`
// BasePool 相关配置
// BasePoolTTLMinutes int `json:"base_pool_ttl_minutes" default:"30" name:"基础资源池最大生存时间(分钟)" category:"基础资源池" desc:"一个动态构建的基础资源池BasePool在Redis中的最大生存时间。到期后即使仍在活跃使用也会被强制重建。"`
// BasePoolTTIMinutes int `json:"base_pool_tti_minutes" default:"10" name:"基础资源池空闲超时(分钟)" category:"基础资源池" desc:"一个基础资源池BasePool在连续无请求后自动销毁的空闲等待时间。"`
//智能网关
LogTruncationLimit int `json:"log_truncation_limit" default:"8000" name:"日志截断长度" category:"日志设置" desc:"在日志中记录上游响应或错误时保留的最大字符数。0表示不截断。"`
EnableSmartGateway bool `json:"enable_smart_gateway" default:"false" name:"启用智能网关" category:"代理设置" desc:"开启后,系统将对流式请求进行智能中断续传、错误标准化等优化。关闭后,系统将作为一个纯净、无干扰的透明代理。"`

View File

@@ -1,4 +1,4 @@
// Filename: internal/repository/key_cache.go
// Filename: internal/repository/key_cache.go (最终定稿)
package repository
import (
@@ -9,6 +9,7 @@ import (
"strconv"
)
// --- Redis Key 常量定义 ---
const (
KeyGroup = "group:%d:keys:active"
KeyDetails = "key:%d:details"
@@ -23,13 +24,16 @@ const (
BasePoolRandomCooldown = "basepool:%s:keys:random:cooldown"
)
// LoadAllKeysToStore 从数据库加载所有密钥和映射关系并完整重建Redis缓存。
func (r *gormKeyRepository) LoadAllKeysToStore(ctx context.Context) error {
r.logger.Info("Starting to load all keys and associations into cache, including polling structures...")
r.logger.Info("Starting full cache rebuild for all keys and polling structures.")
var allMappings []*models.GroupAPIKeyMapping
if err := r.db.Preload("APIKey").Find(&allMappings).Error; err != nil {
return fmt.Errorf("failed to load all mappings with APIKeys from DB: %w", err)
return fmt.Errorf("failed to load mappings with preloaded APIKeys: %w", err)
}
// 1. 批量解密所有涉及的密钥
keyMap := make(map[uint]*models.APIKey)
for _, m := range allMappings {
if m.APIKey != nil {
@@ -41,16 +45,16 @@ func (r *gormKeyRepository) LoadAllKeysToStore(ctx context.Context) error {
keysToDecrypt = append(keysToDecrypt, *k)
}
if err := r.decryptKeys(keysToDecrypt); err != nil {
r.logger.WithError(err).Error("Critical error during cache preload: batch decryption failed.")
r.logger.WithError(err).Error("Batch decryption failed during cache rebuild.")
// 即使解密失败,也继续尝试加载未加密或已解密的部分
}
decryptedKeyMap := make(map[uint]models.APIKey)
for _, k := range keysToDecrypt {
decryptedKeyMap[k.ID] = k
}
activeKeysByGroup := make(map[uint][]*models.GroupAPIKeyMapping)
pipe := r.store.Pipeline(context.Background())
detailsToSet := make(map[string][]byte)
// 2. 清理所有分组的旧轮询结构
pipe := r.store.Pipeline(ctx)
var allGroups []*models.KeyGroup
if err := r.db.Find(&allGroups).Error; err == nil {
for _, group := range allGroups {
@@ -63,26 +67,41 @@ func (r *gormKeyRepository) LoadAllKeysToStore(ctx context.Context) error {
)
}
} else {
r.logger.WithError(err).Error("Failed to get all groups for cache cleanup")
r.logger.WithError(err).Error("Failed to get groups for cache cleanup; proceeding with rebuild.")
}
// 3. 准备批量更新数据
activeKeysByGroup := make(map[uint][]*models.GroupAPIKeyMapping)
detailsToSet := make(map[string]any)
for _, mapping := range allMappings {
if mapping.APIKey == nil {
continue
}
decryptedKey, ok := decryptedKeyMap[mapping.APIKeyID]
if !ok {
continue
continue // 跳过解密失败的密钥
}
// 准备 KeyDetails 和 KeyMapping 的 MSet 数据
keyJSON, _ := json.Marshal(decryptedKey)
detailsToSet[fmt.Sprintf(KeyDetails, decryptedKey.ID)] = keyJSON
mappingJSON, _ := json.Marshal(mapping)
detailsToSet[fmt.Sprintf(KeyMapping, mapping.KeyGroupID, decryptedKey.ID)] = mappingJSON
if mapping.Status == models.StatusActive {
activeKeysByGroup[mapping.KeyGroupID] = append(activeKeysByGroup[mapping.KeyGroupID], mapping)
}
}
// 4. 使用 MSet 批量写入详情和映射缓存
if len(detailsToSet) > 0 {
if err := r.store.MSet(ctx, detailsToSet); err != nil {
r.logger.WithError(err).Error("Failed to MSet key details and mappings during cache rebuild.")
}
}
// 5. 在Pipeline中重建所有分组的轮询结构
for groupID, activeMappings := range activeKeysByGroup {
if len(activeMappings) == 0 {
continue
@@ -101,22 +120,19 @@ func (r *gormKeyRepository) LoadAllKeysToStore(ctx context.Context) error {
pipe.SAdd(fmt.Sprintf(KeyGroup, groupID), activeKeyIDs...)
pipe.LPush(fmt.Sprintf(KeyGroupSequential, groupID), activeKeyIDs...)
pipe.SAdd(fmt.Sprintf(KeyGroupRandomMain, groupID), activeKeyIDs...)
go r.store.ZAdd(context.Background(), fmt.Sprintf(KeyGroupLRU, groupID), lruMembers)
pipe.ZAdd(fmt.Sprintf(KeyGroupLRU, groupID), lruMembers)
}
// 6. 执行Pipeline
if err := pipe.Exec(); err != nil {
return fmt.Errorf("failed to execute pipeline for cache rebuild: %w", err)
}
for key, value := range detailsToSet {
if err := r.store.Set(context.Background(), key, value, 0); err != nil {
r.logger.WithError(err).Warnf("Failed to set key detail in cache: %s", key)
}
return fmt.Errorf("pipeline execution for polling structures failed: %w", err)
}
r.logger.Info("Cache rebuild complete, including all polling structures.")
r.logger.Info("Full cache rebuild completed successfully.")
return nil
}
// updateStoreCacheForKey 更新单个APIKey的详情缓存 (K-V)。
func (r *gormKeyRepository) updateStoreCacheForKey(key *models.APIKey) error {
if err := r.decryptKey(key); err != nil {
return fmt.Errorf("failed to decrypt key %d for cache update: %w", key.ID, err)
@@ -128,78 +144,101 @@ func (r *gormKeyRepository) updateStoreCacheForKey(key *models.APIKey) error {
return r.store.Set(context.Background(), fmt.Sprintf(KeyDetails, key.ID), keyJSON, 0)
}
// removeStoreCacheForKey 从所有缓存结构中彻底移除一个APIKey。
func (r *gormKeyRepository) removeStoreCacheForKey(ctx context.Context, key *models.APIKey) error {
groupIDs, err := r.GetGroupsForKey(ctx, key.ID)
if err != nil {
r.logger.Warnf("failed to get groups for key %d to clean up cache lists: %v", key.ID, err)
r.logger.WithError(err).Warnf("Failed to get groups for key %d during cache removal, cleanup may be partial.", key.ID)
}
pipe := r.store.Pipeline(ctx)
pipe.Del(fmt.Sprintf(KeyDetails, key.ID))
keyIDStr := strconv.FormatUint(uint64(key.ID), 10)
for _, groupID := range groupIDs {
pipe.Del(fmt.Sprintf(KeyMapping, groupID, key.ID))
keyIDStr := strconv.FormatUint(uint64(key.ID), 10)
pipe.SRem(fmt.Sprintf(KeyGroup, groupID), keyIDStr)
pipe.LRem(fmt.Sprintf(KeyGroupSequential, groupID), 0, keyIDStr)
pipe.SRem(fmt.Sprintf(KeyGroupRandomMain, groupID), keyIDStr)
pipe.SRem(fmt.Sprintf(KeyGroupRandomCooldown, groupID), keyIDStr)
go r.store.ZRem(context.Background(), fmt.Sprintf(KeyGroupLRU, groupID), keyIDStr)
pipe.ZRem(fmt.Sprintf(KeyGroupLRU, groupID), keyIDStr)
}
return pipe.Exec()
}
// updateStoreCacheForMapping 根据单个映射关系的状态,原子性地更新所有相关的缓存结构。
func (r *gormKeyRepository) updateStoreCacheForMapping(mapping *models.GroupAPIKeyMapping) error {
pipe := r.store.Pipeline(context.Background())
activeKeyListKey := fmt.Sprintf("group:%d:keys:active", mapping.KeyGroupID)
pipe.LRem(activeKeyListKey, 0, mapping.APIKeyID)
keyIDStr := strconv.FormatUint(uint64(mapping.APIKeyID), 10)
groupID := mapping.KeyGroupID
ctx := context.Background()
pipe := r.store.Pipeline(ctx)
// 统一、无条件地从所有轮询结构中移除,确保状态清洁
pipe.SRem(fmt.Sprintf(KeyGroup, groupID), keyIDStr)
pipe.LRem(fmt.Sprintf(KeyGroupSequential, groupID), 0, keyIDStr)
pipe.SRem(fmt.Sprintf(KeyGroupRandomMain, groupID), keyIDStr)
pipe.SRem(fmt.Sprintf(KeyGroupRandomCooldown, groupID), keyIDStr)
pipe.ZRem(fmt.Sprintf(KeyGroupLRU, groupID), keyIDStr)
// 如果新状态是 Active则重新添加到所有轮询结构中
if mapping.Status == models.StatusActive {
pipe.LPush(activeKeyListKey, mapping.APIKeyID)
pipe.SAdd(fmt.Sprintf(KeyGroup, groupID), keyIDStr)
pipe.LPush(fmt.Sprintf(KeyGroupSequential, groupID), keyIDStr)
pipe.SAdd(fmt.Sprintf(KeyGroupRandomMain, groupID), keyIDStr)
var score float64
if mapping.LastUsedAt != nil {
score = float64(mapping.LastUsedAt.UnixMilli())
}
pipe.ZAdd(fmt.Sprintf(KeyGroupLRU, groupID), map[string]float64{keyIDStr: score})
}
// 无论状态如何,都更新映射详情的 K-V 缓存
mappingJSON, err := json.Marshal(mapping)
if err != nil {
return fmt.Errorf("failed to marshal mapping: %w", err)
}
pipe.Set(fmt.Sprintf(KeyMapping, groupID, mapping.APIKeyID), mappingJSON, 0)
return pipe.Exec()
}
// HandleCacheUpdateEventBatch 批量、原子性地更新多个映射关系的缓存。
func (r *gormKeyRepository) HandleCacheUpdateEventBatch(ctx context.Context, mappings []*models.GroupAPIKeyMapping) error {
if len(mappings) == 0 {
return nil
}
groupUpdates := make(map[uint]struct {
ToAdd []interface{}
ToRemove []interface{}
})
pipe := r.store.Pipeline(ctx)
for _, mapping := range mappings {
keyIDStr := strconv.FormatUint(uint64(mapping.APIKeyID), 10)
update, ok := groupUpdates[mapping.KeyGroupID]
if !ok {
update = struct {
ToAdd []interface{}
ToRemove []interface{}
}{}
}
groupID := mapping.KeyGroupID
// 对于批处理中的每一个mapping都执行完整的、正确的“先删后增”逻辑
pipe.SRem(fmt.Sprintf(KeyGroup, groupID), keyIDStr)
pipe.LRem(fmt.Sprintf(KeyGroupSequential, groupID), 0, keyIDStr)
pipe.SRem(fmt.Sprintf(KeyGroupRandomMain, groupID), keyIDStr)
pipe.SRem(fmt.Sprintf(KeyGroupRandomCooldown, groupID), keyIDStr)
pipe.ZRem(fmt.Sprintf(KeyGroupLRU, groupID), keyIDStr)
if mapping.Status == models.StatusActive {
update.ToRemove = append(update.ToRemove, keyIDStr)
update.ToAdd = append(update.ToAdd, keyIDStr)
} else {
update.ToRemove = append(update.ToRemove, keyIDStr)
pipe.SAdd(fmt.Sprintf(KeyGroup, groupID), keyIDStr)
pipe.LPush(fmt.Sprintf(KeyGroupSequential, groupID), keyIDStr)
pipe.SAdd(fmt.Sprintf(KeyGroupRandomMain, groupID), keyIDStr)
var score float64
if mapping.LastUsedAt != nil {
score = float64(mapping.LastUsedAt.UnixMilli())
}
groupUpdates[mapping.KeyGroupID] = update
pipe.ZAdd(fmt.Sprintf(KeyGroupLRU, groupID), map[string]float64{keyIDStr: score})
}
pipe := r.store.Pipeline(context.Background())
var pipelineError error
for groupID, updates := range groupUpdates {
activeKeyListKey := fmt.Sprintf("group:%d:keys:active", groupID)
if len(updates.ToRemove) > 0 {
for _, keyID := range updates.ToRemove {
pipe.LRem(activeKeyListKey, 0, keyID)
mappingJSON, _ := json.Marshal(mapping) // 在批处理中忽略单个marshal错误以保证大部分更新成功
pipe.Set(fmt.Sprintf(KeyMapping, groupID, mapping.APIKeyID), mappingJSON, 0)
}
}
if len(updates.ToAdd) > 0 {
pipe.LPush(activeKeyListKey, updates.ToAdd...)
}
}
if err := pipe.Exec(); err != nil {
pipelineError = fmt.Errorf("redis pipeline execution failed: %w", err)
}
return pipelineError
return pipe.Exec()
}

View File

@@ -23,7 +23,6 @@ func (r *gormKeyRepository) AddKeys(keys []models.APIKey) ([]models.APIKey, erro
keyHashes := make([]string, len(keys))
keyValueToHashMap := make(map[string]string)
for i, k := range keys {
// All incoming keys must have plaintext APIKey
if k.APIKey == "" {
return nil, fmt.Errorf("cannot add key at index %d: plaintext APIKey is empty", i)
}
@@ -35,7 +34,6 @@ func (r *gormKeyRepository) AddKeys(keys []models.APIKey) ([]models.APIKey, erro
var finalKeys []models.APIKey
err := r.db.Transaction(func(tx *gorm.DB) error {
var existingKeys []models.APIKey
// [MODIFIED] Query by hash to find existing keys.
if err := tx.Unscoped().Where("api_key_hash IN ?", keyHashes).Find(&existingKeys).Error; err != nil {
return err
}
@@ -69,24 +67,20 @@ func (r *gormKeyRepository) AddKeys(keys []models.APIKey) ([]models.APIKey, erro
}
}
if len(keysToCreate) > 0 {
// [MODIFIED] Create now only provides encrypted data and hash.
if err := tx.Clauses(clause.OnConflict{DoNothing: true}, clause.Returning{}).Create(&keysToCreate).Error; err != nil {
return err
}
}
// [MODIFIED] Final select uses hashes to retrieve all relevant keys.
if err := tx.Where("api_key_hash IN ?", keyHashes).Find(&finalKeys).Error; err != nil {
return err
}
// [CRITICAL] Decrypt all keys before returning them to the service layer.
return r.decryptKeys(finalKeys)
})
return finalKeys, err
}
func (r *gormKeyRepository) Update(key *models.APIKey) error {
// [CRITICAL] Before saving, check if the plaintext APIKey field was populated.
// This indicates a potential change that needs to be re-encrypted.
if key.APIKey != "" {
encryptedKey, err := r.crypto.Encrypt(key.APIKey)
if err != nil {
@@ -98,16 +92,16 @@ func (r *gormKeyRepository) Update(key *models.APIKey) error {
key.APIKeyHash = hex.EncodeToString(hash[:])
}
err := r.executeTransactionWithRetry(func(tx *gorm.DB) error {
// GORM automatically ignores `key.APIKey` because of the `gorm:"-"` tag.
return tx.Save(key).Error
})
if err != nil {
return err
}
// For the cache update, we need the plaintext. Decrypt if it's not already populated.
if err := r.decryptKey(key); err != nil {
r.logger.Warnf("DB updated key ID %d, but decryption for cache failed: %v", key.ID, err)
return nil // Continue without cache update if decryption fails.
return nil
}
if err := r.updateStoreCacheForKey(key); err != nil {
r.logger.Warnf("DB updated key ID %d, but cache update failed: %v", key.ID, err)
@@ -192,7 +186,6 @@ func (r *gormKeyRepository) GetKeysByIDs(ids []uint) ([]models.APIKey, error) {
if err != nil {
return nil, err
}
// [CRITICAL] Decrypt before returning.
return keys, r.decryptKeys(keys)
}

View File

@@ -1,4 +1,4 @@
// Filename: internal/repository/key_selector.go (经审查后最终修复版)
// Filename: internal/repository/key_selector.go
package repository
import (
@@ -19,38 +19,39 @@ import (
const (
CacheTTL = 5 * time.Minute
EmptyPoolPlaceholder = "EMPTY_POOL"
EmptyCacheTTL = 1 * time.Minute
)
// SelectOneActiveKey 根据指定的轮询策略,从缓存中高效地选取一个可用的API密钥。
// SelectOneActiveKey 根据指定的轮询策略,从单个密钥组缓存中选取一个可用的API密钥。
func (r *gormKeyRepository) SelectOneActiveKey(ctx context.Context, group *models.KeyGroup) (*models.APIKey, *models.GroupAPIKeyMapping, error) {
if group == nil {
return nil, nil, fmt.Errorf("group cannot be nil")
}
var keyIDStr string
var err error
switch group.PollingStrategy {
case models.StrategySequential:
sequentialKey := fmt.Sprintf(KeyGroupSequential, group.ID)
keyIDStr, err = r.store.Rotate(ctx, sequentialKey)
case models.StrategyWeighted:
lruKey := fmt.Sprintf(KeyGroupLRU, group.ID)
results, zerr := r.store.ZRange(ctx, lruKey, 0, 0)
if zerr == nil && len(results) > 0 {
if zerr == nil {
if len(results) > 0 {
keyIDStr = results[0]
} else {
zerr = gorm.ErrRecordNotFound
}
}
err = zerr
case models.StrategyRandom:
mainPoolKey := fmt.Sprintf(KeyGroupRandomMain, group.ID)
cooldownPoolKey := fmt.Sprintf(KeyGroupRandomCooldown, group.ID)
keyIDStr, err = r.store.PopAndCycleSetMember(ctx, mainPoolKey, cooldownPoolKey)
default: // 默认或未指定策略时,使用基础的随机策略
default:
activeKeySetKey := fmt.Sprintf(KeyGroup, group.ID)
keyIDStr, err = r.store.SRandMember(ctx, activeKeySetKey)
}
if err != nil {
if errors.Is(err, store.ErrNotFound) || errors.Is(err, gorm.ErrRecordNotFound) {
return nil, nil, gorm.ErrRecordNotFound
@@ -58,39 +59,44 @@ func (r *gormKeyRepository) SelectOneActiveKey(ctx context.Context, group *model
r.logger.WithError(err).Errorf("Failed to select key for group %d with strategy %s", group.ID, group.PollingStrategy)
return nil, nil, err
}
if keyIDStr == "" {
return nil, nil, gorm.ErrRecordNotFound
}
keyID, _ := strconv.ParseUint(keyIDStr, 10, 64)
keyID, parseErr := strconv.ParseUint(keyIDStr, 10, 64)
if parseErr != nil {
r.logger.WithError(parseErr).Errorf("Invalid key ID format in group %d cache: %s", group.ID, keyIDStr)
return nil, nil, fmt.Errorf("invalid key ID in cache: %w", parseErr)
}
apiKey, mapping, err := r.getKeyDetailsFromCache(ctx, uint(keyID), group.ID)
if err != nil {
r.logger.WithError(err).Warnf("Cache inconsistency: Failed to get details for selected key ID %d", keyID)
r.logger.WithError(err).Warnf("Cache inconsistency for key ID %d in group %d", keyID, group.ID)
return nil, nil, err
}
if group.PollingStrategy == models.StrategyWeighted {
go r.UpdateKeyUsageTimestamp(context.Background(), group.ID, uint(keyID))
go func() {
updateCtx, cancel := r.withTimeout(5 * time.Second)
defer cancel()
r.UpdateKeyUsageTimestamp(updateCtx, group.ID, uint(keyID))
}()
}
return apiKey, mapping, nil
}
// SelectOneActiveKeyFromBasePool 智能聚合模式设计的全新轮询器
// SelectOneActiveKeyFromBasePool 智能聚合池中选取一个可用Key
func (r *gormKeyRepository) SelectOneActiveKeyFromBasePool(ctx context.Context, pool *BasePool) (*models.APIKey, *models.KeyGroup, error) {
poolID := generatePoolID(pool.CandidateGroups)
if pool == nil || len(pool.CandidateGroups) == 0 {
return nil, nil, fmt.Errorf("invalid or empty base pool configuration")
}
poolID := r.generatePoolID(pool.CandidateGroups)
log := r.logger.WithField("pool_id", poolID)
if err := r.ensureBasePoolCacheExists(ctx, pool, poolID); err != nil {
log.WithError(err).Error("Failed to ensure BasePool cache exists.")
if !errors.Is(err, gorm.ErrRecordNotFound) {
log.WithError(err).Error("Failed to ensure BasePool cache exists")
}
return nil, nil, err
}
var keyIDStr string
var err error
switch pool.PollingStrategy {
case models.StrategySequential:
sequentialKey := fmt.Sprintf(BasePoolSequential, poolID)
@@ -98,8 +104,12 @@ func (r *gormKeyRepository) SelectOneActiveKeyFromBasePool(ctx context.Context,
case models.StrategyWeighted:
lruKey := fmt.Sprintf(BasePoolLRU, poolID)
results, zerr := r.store.ZRange(ctx, lruKey, 0, 0)
if zerr == nil && len(results) > 0 {
if zerr == nil {
if len(results) > 0 {
keyIDStr = results[0]
} else {
zerr = gorm.ErrRecordNotFound
}
}
err = zerr
case models.StrategyRandom:
@@ -107,13 +117,12 @@ func (r *gormKeyRepository) SelectOneActiveKeyFromBasePool(ctx context.Context,
cooldownPoolKey := fmt.Sprintf(BasePoolRandomCooldown, poolID)
keyIDStr, err = r.store.PopAndCycleSetMember(ctx, mainPoolKey, cooldownPoolKey)
default:
log.Warnf("Default polling strategy triggered inside selection. This should be rare.")
log.Warnf("Unknown polling strategy '%s'. Using sequential as fallback.", pool.PollingStrategy)
sequentialKey := fmt.Sprintf(BasePoolSequential, poolID)
keyIDStr, err = r.store.LIndex(ctx, sequentialKey, 0)
keyIDStr, err = r.store.Rotate(ctx, sequentialKey)
}
if err != nil {
if errors.Is(err, store.ErrNotFound) {
if errors.Is(err, store.ErrNotFound) || errors.Is(err, gorm.ErrRecordNotFound) {
return nil, nil, gorm.ErrRecordNotFound
}
log.WithError(err).Errorf("Failed to select key from BasePool with strategy %s", pool.PollingStrategy)
@@ -122,73 +131,224 @@ func (r *gormKeyRepository) SelectOneActiveKeyFromBasePool(ctx context.Context,
if keyIDStr == "" {
return nil, nil, gorm.ErrRecordNotFound
}
keyID, _ := strconv.ParseUint(keyIDStr, 10, 64)
for _, group := range pool.CandidateGroups {
apiKey, mapping, cacheErr := r.getKeyDetailsFromCache(ctx, uint(keyID), group.ID)
if cacheErr == nil && apiKey != nil && mapping != nil {
go func() {
bgCtx, cancel := r.withTimeout(5 * time.Second)
defer cancel()
r.refreshBasePoolHeartbeat(bgCtx, poolID)
}()
keyID, parseErr := strconv.ParseUint(keyIDStr, 10, 64)
if parseErr != nil {
log.WithError(parseErr).Errorf("Invalid key ID format in BasePool cache: %s", keyIDStr)
return nil, nil, fmt.Errorf("invalid key ID in cache: %w", parseErr)
}
keyToGroupMapKey := fmt.Sprintf("basepool:%s:key_to_group", poolID)
groupIDStr, err := r.store.HGet(ctx, keyToGroupMapKey, keyIDStr)
if err != nil {
log.WithError(err).Errorf("Cache inconsistency: KeyID %d found in pool but not in key-to-group map", keyID)
return nil, nil, errors.New("cache inconsistency: key has no origin group mapping")
}
groupID, parseErr := strconv.ParseUint(groupIDStr, 10, 64)
if parseErr != nil {
log.WithError(parseErr).Errorf("Invalid group ID format in key-to-group map for key %d: %s", keyID, groupIDStr)
return nil, nil, errors.New("cache inconsistency: invalid group id in mapping")
}
apiKey, _, err := r.getKeyDetailsFromCache(ctx, uint(keyID), uint(groupID))
if err != nil {
log.WithError(err).Warnf("Cache inconsistency: Failed to get details for key %d in mapped group %d", keyID, groupID)
return nil, nil, err
}
var originGroup *models.KeyGroup
for _, g := range pool.CandidateGroups {
if g.ID == uint(groupID) {
originGroup = g
break
}
}
if originGroup == nil {
log.Errorf("Logic error: Mapped GroupID %d not found in pool's candidate groups list", groupID)
return nil, nil, errors.New("cache inconsistency: mapped group not in candidate list")
}
if pool.PollingStrategy == models.StrategyWeighted {
go r.updateKeyUsageTimestampForPool(context.Background(), poolID, uint(keyID))
go func() {
bgCtx, cancel := r.withTimeout(5 * time.Second)
defer cancel()
r.updateKeyUsageTimestampForPool(bgCtx, poolID, uint(keyID))
}()
}
return apiKey, group, nil
}
}
log.Errorf("Cache inconsistency: Selected KeyID %d from BasePool but could not find its origin group.", keyID)
return nil, nil, errors.New("cache inconsistency: selected key has no origin group")
return apiKey, originGroup, nil
}
// ensureBasePoolCacheExists 动态创建 BasePool 的 Redis 结构
// ensureBasePoolCacheExists 动态创建或验证 BasePool 的 Redis 缓存结构
func (r *gormKeyRepository) ensureBasePoolCacheExists(ctx context.Context, pool *BasePool, poolID string) error {
listKey := fmt.Sprintf(BasePoolSequential, poolID)
exists, err := r.store.Exists(ctx, listKey)
if err != nil {
r.logger.WithError(err).Errorf("Failed to check existence for pool_id '%s'", poolID)
return err
}
if exists {
val, err := r.store.LIndex(ctx, listKey, 0)
if err != nil {
r.logger.WithError(err).Warnf("Cache for pool_id '%s' exists but is unreadable. Forcing rebuild.", poolID)
} else {
if val == EmptyPoolPlaceholder {
heartbeatKey := fmt.Sprintf("basepool:%s:heartbeat", poolID)
emptyMarkerKey := fmt.Sprintf("basepool:empty:%s", poolID)
// 预检查,快速失败
if exists, _ := r.store.Exists(ctx, emptyMarkerKey); exists {
return gorm.ErrRecordNotFound
}
if exists, _ := r.store.Exists(ctx, heartbeatKey); exists {
return nil
}
}
// 获取分布式锁
lockKey := fmt.Sprintf("lock:basepool:%s", poolID)
acquired, err := r.store.SetNX(ctx, lockKey, []byte("1"), 10*time.Second)
if err != nil {
r.logger.WithError(err).Error("Failed to attempt acquiring distributed lock for basepool build.")
return err
if err := r.acquireLock(ctx, lockKey); err != nil {
return err // acquireLock 内部已记录日志并返回明确错误
}
if !acquired {
time.Sleep(100 * time.Millisecond)
return r.ensureBasePoolCacheExists(ctx, pool, poolID)
defer r.releaseLock(context.Background(), lockKey)
// 双重检查锁定
if exists, _ := r.store.Exists(ctx, emptyMarkerKey); exists {
return gorm.ErrRecordNotFound
}
defer r.store.Del(context.Background(), lockKey)
if exists, _ := r.store.Exists(ctx, listKey); exists {
if exists, _ := r.store.Exists(ctx, heartbeatKey); exists {
return nil
}
r.logger.Infof("BasePool cache for pool_id '%s' not found or is unreadable. Building now...", poolID)
var allActiveKeyIDs []string
lruMembers := make(map[string]float64)
// 在执行重度操作前,最后检查一次上下文是否已取消
select {
case <-ctx.Done():
return ctx.Err()
default:
}
r.logger.Infof("Building BasePool cache for pool_id '%s'", poolID)
// 手动聚合所有 Keys 并同时构建 key-to-group 映射
keyToGroupMap := make(map[string]any)
allKeyIDsSet := make(map[string]struct{})
for _, group := range pool.CandidateGroups {
activeKeySetKey := fmt.Sprintf(KeyGroup, group.ID)
groupKeyIDs, err := r.store.SMembers(ctx, activeKeySetKey)
groupKeySet := fmt.Sprintf(KeyGroup, group.ID)
groupKeyIDs, err := r.store.SMembers(ctx, groupKeySet)
if err != nil {
r.logger.WithError(err).Errorf("FATAL: Failed to read active keys for group %d during BasePool build. Aborting build process for pool_id '%s'.", group.ID, poolID)
r.logger.WithError(err).Warnf("Failed to get members for group %d during pool build", group.ID)
continue
}
groupIDStr := strconv.FormatUint(uint64(group.ID), 10)
for _, keyID := range groupKeyIDs {
if _, exists := allKeyIDsSet[keyID]; !exists {
allKeyIDsSet[keyID] = struct{}{}
keyToGroupMap[keyID] = groupIDStr
}
}
}
// 处理空池情况
if len(allKeyIDsSet) == 0 {
emptyCacheTTL := time.Duration(r.config.Repository.BasePoolTTIMinutes) * time.Minute / 2
if emptyCacheTTL < time.Minute {
emptyCacheTTL = time.Minute
}
r.logger.Warnf("No active keys found for pool_id '%s', setting empty marker.", poolID)
if err := r.store.Set(ctx, emptyMarkerKey, []byte("1"), emptyCacheTTL); err != nil {
r.logger.WithError(err).Warnf("Failed to set empty marker for pool_id '%s'", poolID)
}
return gorm.ErrRecordNotFound
}
allActiveKeyIDs := make([]string, 0, len(allKeyIDsSet))
for keyID := range allKeyIDsSet {
allActiveKeyIDs = append(allActiveKeyIDs, keyID)
}
// 使用 Pipeline 原子化构建所有缓存结构
basePoolTTL := time.Duration(r.config.Repository.BasePoolTTLMinutes) * time.Minute
basePoolTTI := time.Duration(r.config.Repository.BasePoolTTIMinutes) * time.Minute
mainPoolKey := fmt.Sprintf(BasePoolRandomMain, poolID)
sequentialKey := fmt.Sprintf(BasePoolSequential, poolID)
cooldownKey := fmt.Sprintf(BasePoolRandomCooldown, poolID)
lruKey := fmt.Sprintf(BasePoolLRU, poolID)
keyToGroupMapKey := fmt.Sprintf("basepool:%s:key_to_group", poolID)
pipe := r.store.Pipeline(ctx)
pipe.Del(mainPoolKey, sequentialKey, cooldownKey, lruKey, emptyMarkerKey, keyToGroupMapKey)
pipe.SAdd(mainPoolKey, r.toInterfaceSlice(allActiveKeyIDs)...)
pipe.LPush(sequentialKey, r.toInterfaceSlice(allActiveKeyIDs)...)
if len(keyToGroupMap) > 0 {
pipe.HSet(keyToGroupMapKey, keyToGroupMap)
pipe.Expire(keyToGroupMapKey, basePoolTTL)
}
pipe.Expire(mainPoolKey, basePoolTTL)
pipe.Expire(sequentialKey, basePoolTTL)
pipe.Expire(cooldownKey, basePoolTTL)
pipe.Expire(lruKey, basePoolTTL)
pipe.Set(heartbeatKey, []byte("1"), basePoolTTI)
if err := pipe.Exec(); err != nil {
r.logger.WithError(err).Errorf("Failed to populate polling structures for pool_id '%s'", poolID)
cleanupCtx, cancel := r.withTimeout(5 * time.Second)
defer cancel()
r.store.Del(cleanupCtx, mainPoolKey, sequentialKey, cooldownKey, lruKey, heartbeatKey, emptyMarkerKey, keyToGroupMapKey)
return err
}
allActiveKeyIDs = append(allActiveKeyIDs, groupKeyIDs...)
for _, keyIDStr := range groupKeyIDs {
// 异步填充 LRU 缓存,并传入已构建好的映射
go r.populateBasePoolLRUCache(context.Background(), poolID, allActiveKeyIDs, keyToGroupMap)
return nil
}
// --- 辅助方法 ---
// acquireLock 封装了带重试和指数退避的分布式锁获取逻辑。
func (r *gormKeyRepository) acquireLock(ctx context.Context, lockKey string) error {
const (
lockTTL = 30 * time.Second
lockMaxRetries = 5
lockBaseBackoff = 50 * time.Millisecond
)
for i := 0; i < lockMaxRetries; i++ {
acquired, err := r.store.SetNX(ctx, lockKey, []byte("1"), lockTTL)
if err != nil {
r.logger.WithError(err).Error("Failed to attempt acquiring distributed lock")
return err
}
if acquired {
return nil
}
time.Sleep(lockBaseBackoff * (1 << i))
}
return fmt.Errorf("failed to acquire lock for key '%s' after %d retries", lockKey, lockMaxRetries)
}
// releaseLock 封装了分布式锁的释放逻辑。
func (r *gormKeyRepository) releaseLock(ctx context.Context, lockKey string) {
if err := r.store.Del(ctx, lockKey); err != nil {
r.logger.WithError(err).Errorf("Failed to release distributed lock for key '%s'", lockKey)
}
}
// withTimeout 是 context.WithTimeout 的一个简单包装,便于测试和模拟。
func (r *gormKeyRepository) withTimeout(duration time.Duration) (context.Context, context.CancelFunc) {
return context.WithTimeout(context.Background(), duration)
}
// refreshBasePoolHeartbeat 异步刷新心跳Key的TTI
func (r *gormKeyRepository) refreshBasePoolHeartbeat(ctx context.Context, poolID string) {
basePoolTTI := time.Duration(r.config.Repository.BasePoolTTIMinutes) * time.Minute
heartbeatKey := fmt.Sprintf("basepool:%s:heartbeat", poolID)
// 使用 EXPIRE 命令来刷新如果Key不存在它什么也不做是安全的
if err := r.store.Expire(ctx, heartbeatKey, basePoolTTI); err != nil {
if ctx.Err() == nil { // 避免在context取消后打印不必要的错误
r.logger.WithError(err).Warnf("Failed to refresh heartbeat for pool_id '%s'", poolID)
}
}
}
// populateBasePoolLRUCache 异步填充 BasePool 的 LRU 缓存结构
func (r *gormKeyRepository) populateBasePoolLRUCache(
parentCtx context.Context,
currentPoolID string,
keys []string,
keyToGroupMap map[string]any,
) {
lruMembers := make(map[string]float64, len(keys))
for _, keyIDStr := range keys {
select {
case <-parentCtx.Done():
return
default:
}
groupIDStr, ok := keyToGroupMap[keyIDStr].(string)
if !ok {
continue
}
keyID, _ := strconv.ParseUint(keyIDStr, 10, 64)
_, mapping, err := r.getKeyDetailsFromCache(ctx, uint(keyID), group.ID)
if err == nil && mapping != nil {
groupID, _ := strconv.ParseUint(groupIDStr, 10, 64)
mappingKey := fmt.Sprintf(KeyMapping, groupID, keyID)
data, err := r.store.Get(parentCtx, mappingKey)
if err != nil {
continue
}
var mapping models.GroupAPIKeyMapping
if json.Unmarshal(data, &mapping) == nil {
var score float64
if mapping.LastUsedAt != nil {
score = float64(mapping.LastUsedAt.UnixMilli())
@@ -196,44 +356,21 @@ func (r *gormKeyRepository) ensureBasePoolCacheExists(ctx context.Context, pool
lruMembers[keyIDStr] = score
}
}
}
if len(allActiveKeyIDs) == 0 {
r.logger.Warnf("No active keys found for any candidate groups for pool_id '%s'. Setting empty pool placeholder.", poolID)
pipe := r.store.Pipeline(ctx)
pipe.LPush(listKey, EmptyPoolPlaceholder)
pipe.Expire(listKey, EmptyCacheTTL)
if err := pipe.Exec(); err != nil {
r.logger.WithError(err).Errorf("Failed to set empty pool placeholder for pool_id '%s'", poolID)
}
return gorm.ErrRecordNotFound
}
pipe := r.store.Pipeline(ctx)
pipe.LPush(fmt.Sprintf(BasePoolSequential, poolID), toInterfaceSlice(allActiveKeyIDs)...)
pipe.SAdd(fmt.Sprintf(BasePoolRandomMain, poolID), toInterfaceSlice(allActiveKeyIDs)...)
pipe.Expire(fmt.Sprintf(BasePoolSequential, poolID), CacheTTL)
pipe.Expire(fmt.Sprintf(BasePoolRandomMain, poolID), CacheTTL)
pipe.Expire(fmt.Sprintf(BasePoolRandomCooldown, poolID), CacheTTL)
pipe.Expire(fmt.Sprintf(BasePoolLRU, poolID), CacheTTL)
if err := pipe.Exec(); err != nil {
return err
}
if len(lruMembers) > 0 {
if err := r.store.ZAdd(ctx, fmt.Sprintf(BasePoolLRU, poolID), lruMembers); err != nil {
r.logger.WithError(err).Warnf("Failed to populate LRU cache for pool_id '%s'", poolID)
lruKey := fmt.Sprintf(BasePoolLRU, currentPoolID)
if err := r.store.ZAdd(parentCtx, lruKey, lruMembers); err != nil {
if parentCtx.Err() == nil {
r.logger.WithError(err).Warnf("Failed to populate LRU cache for pool '%s'", currentPoolID)
}
}
}
return nil
}
// updateKeyUsageTimestampForPool 更新 BasePool 的 LUR ZSET
func (r *gormKeyRepository) updateKeyUsageTimestampForPool(ctx context.Context, poolID string, keyID uint) {
lruKey := fmt.Sprintf(BasePoolLRU, poolID)
err := r.store.ZAdd(ctx, lruKey, map[string]float64{
strconv.FormatUint(uint64(keyID), 10): nowMilli(),
strconv.FormatUint(uint64(keyID), 10): r.nowMilli(),
})
if err != nil {
r.logger.WithError(err).Warnf("Failed to update key usage for pool %s", poolID)
@@ -241,20 +378,19 @@ func (r *gormKeyRepository) updateKeyUsageTimestampForPool(ctx context.Context,
}
// generatePoolID 根据候选组ID列表生成一个稳定的、唯一的字符串ID
func generatePoolID(groups []*models.KeyGroup) string {
func (r *gormKeyRepository) generatePoolID(groups []*models.KeyGroup) string {
ids := make([]int, len(groups))
for i, g := range groups {
ids[i] = int(g.ID)
}
sort.Ints(ids)
h := sha1.New()
io.WriteString(h, fmt.Sprintf("%v", ids))
return fmt.Sprintf("%x", h.Sum(nil))
}
// toInterfaceSlice 类型转换辅助函数
func toInterfaceSlice(slice []string) []interface{} {
func (r *gormKeyRepository) toInterfaceSlice(slice []string) []interface{} {
result := make([]interface{}, len(slice))
for i, v := range slice {
result[i] = v
@@ -263,7 +399,7 @@ func toInterfaceSlice(slice []string) []interface{} {
}
// nowMilli 返回当前的Unix毫秒时间戳用于LRU/Weighted策略
func nowMilli() float64 {
func (r *gormKeyRepository) nowMilli() float64 {
return float64(time.Now().UnixMilli())
}

View File

@@ -1,8 +1,9 @@
// Filename: internal/repository/repository.go (经审查后最终修复版)
// Filename: internal/repository/repository.go
package repository
import (
"context"
"gemini-balancer/internal/config"
"gemini-balancer/internal/crypto"
"gemini-balancer/internal/errors"
"gemini-balancer/internal/models"
@@ -87,18 +88,20 @@ type gormKeyRepository struct {
store store.Store
logger *logrus.Entry
crypto *crypto.Service
config *config.Config
}
type gormGroupRepository struct {
db *gorm.DB
}
func NewKeyRepository(db *gorm.DB, s store.Store, logger *logrus.Logger, crypto *crypto.Service) KeyRepository {
func NewKeyRepository(db *gorm.DB, s store.Store, logger *logrus.Logger, crypto *crypto.Service, cfg *config.Config) KeyRepository {
return &gormKeyRepository{
db: db,
store: s,
logger: logger.WithField("component", "repository.key🔗"),
crypto: crypto,
config: cfg,
}
}

View File

@@ -1,10 +1,9 @@
// Filename: internal/service/resource_service.go
package service
import (
"context"
"errors"
"gemini-balancer/internal/domain/proxy"
apperrors "gemini-balancer/internal/errors"
"gemini-balancer/internal/models"
"gemini-balancer/internal/repository"
@@ -16,10 +15,7 @@ import (
"github.com/sirupsen/logrus"
)
var (
ErrNoResourceAvailable = errors.New("no available resource found for the request")
)
// RequestResources 封装了一次成功请求所需的所有资源。
type RequestResources struct {
KeyGroup *models.KeyGroup
APIKey *models.APIKey
@@ -28,41 +24,51 @@ type RequestResources struct {
RequestConfig *models.RequestConfig
}
// ResourceService 负责根据请求参数和业务规则动态地选择和分配API密钥及相关资源。
type ResourceService struct {
settingsManager *settings.SettingsManager
groupManager *GroupManager
keyRepo repository.KeyRepository
authTokenRepo repository.AuthTokenRepository
apiKeyService *APIKeyService
proxyManager *proxy.Module
logger *logrus.Entry
initOnce sync.Once
}
// NewResourceService 创建并初始化一个新的 ResourceService 实例。
func NewResourceService(
sm *settings.SettingsManager,
gm *GroupManager,
kr repository.KeyRepository,
atr repository.AuthTokenRepository,
aks *APIKeyService,
pm *proxy.Module,
logger *logrus.Logger,
) *ResourceService {
rs := &ResourceService{
settingsManager: sm,
groupManager: gm,
keyRepo: kr,
authTokenRepo: atr,
apiKeyService: aks,
proxyManager: pm,
logger: logger.WithField("component", "ResourceService📦"),
}
// 使用 sync.Once 确保预热任务在服务生命周期内仅执行一次
rs.initOnce.Do(func() {
go rs.preWarmCache(logger)
go rs.preWarmCache()
})
return rs
}
// GetResourceFromBasePool 使用智能聚合池模式获取资源。
func (s *ResourceService) GetResourceFromBasePool(ctx context.Context, authToken *models.AuthToken, modelName string) (*RequestResources, error) {
log := s.logger.WithFields(logrus.Fields{"token_id": authToken.ID, "model_name": modelName, "mode": "BasePool"})
log.Debug("Entering BasePool resource acquisition.")
candidateGroups := s.filterAndSortCandidateGroups(modelName, authToken.AllowedGroups)
candidateGroups := s.filterAndSortCandidateGroups(modelName, authToken)
if len(candidateGroups) == 0 {
log.Warn("No candidate groups found for BasePool construction.")
return nil, apperrors.ErrNoKeysAvailable
@@ -84,17 +90,18 @@ func (s *ResourceService) GetResourceFromBasePool(ctx context.Context, authToken
log.WithError(err).Error("Failed to assemble resources after selecting key from BasePool.")
return nil, err
}
resources.RequestConfig = &models.RequestConfig{}
resources.RequestConfig = &models.RequestConfig{} // BasePool 模式使用默认请求配置
log.Infof("Successfully selected KeyID %d from GroupID %d for the BasePool.", apiKey.ID, selectedGroup.ID)
return resources, nil
}
// GetResourceFromGroup 使用精确路由模式(指定密钥组)获取资源。
func (s *ResourceService) GetResourceFromGroup(ctx context.Context, authToken *models.AuthToken, groupName string) (*RequestResources, error) {
log := s.logger.WithFields(logrus.Fields{"token_id": authToken.ID, "group_name": groupName, "mode": "PreciseRoute"})
log.Debug("Entering PreciseRoute resource acquisition.")
targetGroup, ok := s.groupManager.GetGroupByName(groupName)
if !ok {
return nil, apperrors.NewAPIError(apperrors.ErrGroupNotFound, "The specified group does not exist.")
}
@@ -113,37 +120,39 @@ func (s *ResourceService) GetResourceFromGroup(ctx context.Context, authToken *m
log.WithError(err).Error("Failed to assemble resources for precise route.")
return nil, err
}
resources.RequestConfig = targetGroup.RequestConfig
resources.RequestConfig = targetGroup.RequestConfig // 精确路由使用该组的特定请求配置
log.Infof("Successfully selected KeyID %d by precise routing to GroupID %d.", apiKey.ID, targetGroup.ID)
return resources, nil
}
// GetAllowedModelsForToken 获取指定认证令牌有权访问的所有模型名称列表。
func (s *ResourceService) GetAllowedModelsForToken(authToken *models.AuthToken) []string {
allGroups := s.groupManager.GetAllGroups()
if len(allGroups) == 0 {
return []string{}
}
allowedModelsSet := make(map[string]struct{})
allowedGroupIDs := make(map[uint]bool)
if authToken.IsAdmin {
for _, group := range allGroups {
for _, modelMapping := range group.AllowedModels {
allowedModelsSet[modelMapping.ModelName] = struct{}{}
}
allowedGroupIDs[group.ID] = true
}
} else {
allowedGroupIDs := make(map[uint]bool)
for _, ag := range authToken.AllowedGroups {
allowedGroupIDs[ag.ID] = true
}
}
allowedModelsSet := make(map[string]struct{})
for _, group := range allGroups {
if _, ok := allowedGroupIDs[group.ID]; ok {
if allowedGroupIDs[group.ID] {
for _, modelMapping := range group.AllowedModels {
allowedModelsSet[modelMapping.ModelName] = struct{}{}
}
}
}
}
result := make([]string, 0, len(allowedModelsSet))
for modelName := range allowedModelsSet {
result = append(result, modelName)
@@ -152,12 +161,52 @@ func (s *ResourceService) GetAllowedModelsForToken(authToken *models.AuthToken)
return result
}
// ReportRequestResult 向 APIKeyService 报告请求的最终结果,以便更新密钥状态。
func (s *ResourceService) ReportRequestResult(resources *RequestResources, success bool, apiErr *apperrors.APIError) {
if resources == nil || resources.KeyGroup == nil || resources.APIKey == nil {
return
}
s.apiKeyService.HandleRequestResult(resources.KeyGroup, resources.APIKey, success, apiErr)
}
// --- 私有辅助方法 ---
// preWarmCache 在后台执行一次性的缓存预热任务。
func (s *ResourceService) preWarmCache() {
time.Sleep(2 * time.Second) // 等待其他服务组件可能完成初始化
s.logger.Info("Performing initial key cache pre-warming...")
// 强制加载 GroupManager 缓存
s.logger.Info("Pre-warming GroupManager cache...")
_ = s.groupManager.GetAllGroups()
ctx, cancel := context.WithTimeout(context.Background(), 60*time.Second) // 给予更长的超时
defer cancel()
if err := s.keyRepo.LoadAllKeysToStore(ctx); err != nil {
s.logger.WithError(err).Error("Failed to perform initial key cache pre-warming.")
} else {
s.logger.Info("Initial key cache pre-warming completed successfully.")
}
}
// assembleRequestResources 根据密钥组和API密钥组装最终的资源对象。
func (s *ResourceService) assembleRequestResources(group *models.KeyGroup, apiKey *models.APIKey) (*RequestResources, error) {
selectedUpstream := s.selectUpstreamForGroup(group)
if selectedUpstream == nil {
return nil, apperrors.NewAPIError(apperrors.ErrConfigurationError, "Selected group has no valid upstream and no global default is set.")
}
var proxyConfig *models.ProxyConfig
var err error
// 只有在组明确启用代理时,才为其分配代理
if group.EnableProxy {
proxyConfig, err = s.proxyManager.AssignProxyIfNeeded(apiKey)
if err != nil {
s.logger.WithError(err).Errorf("Group '%s' (ID: %d) requires a proxy, but failed to assign one for KeyID %d", group.Name, group.ID, apiKey.ID)
// 根据业务需求,这里必须返回错误,因为代理是该组的强制要求
return nil, apperrors.NewAPIError(apperrors.ErrProxyNotAvailable, "Required proxy is not available for this request.")
}
}
return &RequestResources{
KeyGroup: group,
APIKey: apiKey,
@@ -166,8 +215,10 @@ func (s *ResourceService) assembleRequestResources(group *models.KeyGroup, apiKe
}, nil
}
// selectUpstreamForGroup 为指定的密钥组选择一个上游端点。
func (s *ResourceService) selectUpstreamForGroup(group *models.KeyGroup) *models.UpstreamEndpoint {
if len(group.AllowedUpstreams) > 0 {
// (未来可扩展负载均衡逻辑)
return group.AllowedUpstreams[0]
}
globalSettings := s.settingsManager.GetSettings()
@@ -177,56 +228,39 @@ func (s *ResourceService) selectUpstreamForGroup(group *models.KeyGroup) *models
return nil
}
func (s *ResourceService) preWarmCache(logger *logrus.Logger) error {
time.Sleep(2 * time.Second)
s.logger.Info("Performing initial key cache pre-warming...")
if err := s.keyRepo.LoadAllKeysToStore(context.Background()); err != nil {
logger.WithError(err).Error("Failed to perform initial key cache pre-warming.")
return err
}
s.logger.Info("Initial key cache pre-warming completed successfully.")
return nil
}
func (s *ResourceService) GetResourcesForRequest(modelName string, allowedGroups []*models.KeyGroup) (*RequestResources, error) {
return nil, errors.New("GetResourcesForRequest is deprecated; use GetResourceFromBasePool or GetResourceFromGroup")
}
func (s *ResourceService) filterAndSortCandidateGroups(modelName string, allowedGroupsFromToken []*models.KeyGroup) []*models.KeyGroup {
// filterAndSortCandidateGroups 根据模型名称和令牌权限,筛选并排序出合格的候选密钥组。
func (s *ResourceService) filterAndSortCandidateGroups(modelName string, authToken *models.AuthToken) []*models.KeyGroup {
allGroupsFromCache := s.groupManager.GetAllGroups()
var candidateGroups []*models.KeyGroup
allowedGroupIDs := make(map[uint]bool)
isTokenRestricted := len(allowedGroupsFromToken) > 0
if isTokenRestricted {
for _, ag := range allowedGroupsFromToken {
allowedGroupIDs[ag.ID] = true
}
}
for _, group := range allGroupsFromCache {
if isTokenRestricted && !allowedGroupIDs[group.ID] {
// 检查令牌权限
if !s.isTokenAllowedForGroup(authToken, group.ID) {
continue
}
isModelAllowed := false
if len(group.AllowedModels) == 0 {
isModelAllowed = true
} else {
for _, m := range group.AllowedModels {
if m.ModelName == modelName {
isModelAllowed = true
break
}
}
}
if isModelAllowed {
// 检查模型支持情况 (如果组内未限制模型,则默认支持所有模型)
if len(group.AllowedModels) == 0 || s.groupSupportsModel(group, modelName) {
candidateGroups = append(candidateGroups, group)
}
}
sort.SliceStable(candidateGroups, func(i, j int) bool {
return candidateGroups[i].Order < candidateGroups[j].Order
})
return candidateGroups
}
// groupSupportsModel 检查指定的密钥组是否支持给定的模型名称。
func (s *ResourceService) groupSupportsModel(group *models.KeyGroup, modelName string) bool {
for _, m := range group.AllowedModels {
if m.ModelName == modelName {
return true
}
}
return false
}
// isTokenAllowedForGroup 检查指定的认证令牌是否有权访问给定的密钥组。
func (s *ResourceService) isTokenAllowedForGroup(authToken *models.AuthToken, groupID uint) bool {
if authToken.IsAdmin {
return true
@@ -238,10 +272,3 @@ func (s *ResourceService) isTokenAllowedForGroup(authToken *models.AuthToken, gr
}
return false
}
func (s *ResourceService) ReportRequestResult(resources *RequestResources, success bool, apiErr *apperrors.APIError) {
if resources == nil || resources.KeyGroup == nil || resources.APIKey == nil {
return
}
s.apiKeyService.HandleRequestResult(resources.KeyGroup, resources.APIKey, success, apiErr)
}

View File

@@ -1,4 +1,4 @@
// file: gemini-balancer\internal\settings\settings.go
// Filename: gemini-balancer/internal/settings/settings.go (最终审计修复版)
package settings
import (
@@ -19,7 +19,9 @@ import (
const SettingsUpdateChannel = "system_settings:updated"
const DefaultGeminiEndpoint = "https://generativelanguage.googleapis.com/v1beta/models"
// SettingsManager [核心修正] syncer现在缓存正确的“蓝图”类型
var _ models.SettingsManager = (*SettingsManager)(nil)
// SettingsManager 负责管理系统的动态设置,包括从数据库加载、缓存同步和更新。
type SettingsManager struct {
db *gorm.DB
syncer *syncer.CacheSyncer[*models.SystemSettings]
@@ -27,13 +29,14 @@ type SettingsManager struct {
jsonToFieldType map[string]reflect.Type // 用于将JSON字段映射到Go类型
}
// NewSettingsManager 创建一个新的 SettingsManager 实例。
func NewSettingsManager(db *gorm.DB, store store.Store, logger *logrus.Logger) (*SettingsManager, error) {
sm := &SettingsManager{
db: db,
logger: logger.WithField("component", "SettingsManager⚙"),
jsonToFieldType: make(map[string]reflect.Type),
}
// settingsLoader 的职责:读取“砖块”,组装并返回“蓝图”
settingsType := reflect.TypeOf(models.SystemSettings{})
for i := 0; i < settingsType.NumField(); i++ {
field := settingsType.Field(i)
@@ -42,102 +45,89 @@ func NewSettingsManager(db *gorm.DB, store store.Store, logger *logrus.Logger) (
sm.jsonToFieldType[jsonTag] = field.Type
}
}
// settingsLoader 的职责:读取“砖块”,智能组装成“蓝图”
settingsLoader := func() (*models.SystemSettings, error) {
sm.logger.Info("Loading system settings from database...")
var dbRecords []models.Setting
if err := sm.db.Find(&dbRecords).Error; err != nil {
return nil, fmt.Errorf("failed to load system settings from db: %w", err)
}
settingsMap := make(map[string]string)
for _, record := range dbRecords {
settingsMap[record.Key] = record.Value
}
// 从一个包含了所有“出厂设置”的“蓝图”开始
settings := defaultSystemSettings()
v := reflect.ValueOf(settings).Elem()
t := v.Type()
// [智能卸货]
for i := 0; i < t.NumField(); i++ {
field := t.Field(i)
for i := 0; i < v.NumField(); i++ {
field := v.Type().Field(i)
fieldValue := v.Field(i)
jsonTag := field.Tag.Get("json")
if dbValue, ok := settingsMap[jsonTag]; ok {
if err := parseAndSetField(fieldValue, dbValue); err != nil {
sm.logger.Warnf("Failed to set config field '%s' from DB value '%s': %v. Using default.", field.Name, dbValue, err)
sm.logger.Warnf("Failed to set field '%s' from DB value '%s': %v. Using default.", field.Name, dbValue, err)
}
}
}
if settings.BaseKeyCheckEndpoint == DefaultGeminiEndpoint || settings.BaseKeyCheckEndpoint == "" {
if settings.DefaultUpstreamURL != "" {
// 如果全局上游URL已设置则基于它构建新的检查端点。
originalEndpoint := settings.BaseKeyCheckEndpoint
// [评估确认] 派生逻辑与原始版本在功能和日志行为上完全一致。
if (settings.BaseKeyCheckEndpoint == DefaultGeminiEndpoint || settings.BaseKeyCheckEndpoint == "") && settings.DefaultUpstreamURL != "" {
derivedEndpoint := strings.TrimSuffix(settings.DefaultUpstreamURL, "/") + "/models"
sm.logger.Infof("BaseKeyCheckEndpoint is dynamically derived from DefaultUpstreamURL: %s", derivedEndpoint)
settings.BaseKeyCheckEndpoint = derivedEndpoint
sm.logger.Infof(
"BaseKeyCheckEndpoint is dynamically derived from DefaultUpstreamURL. Original: '%s', New: '%s'",
originalEndpoint, derivedEndpoint,
)
}
} else {
} else if settings.BaseKeyCheckEndpoint != DefaultGeminiEndpoint && settings.BaseKeyCheckEndpoint != "" {
// 恢复 else 日志,以明确告知用户正在使用自定义覆盖。
sm.logger.Infof("BaseKeyCheckEndpoint is using a user-defined override: %s", settings.BaseKeyCheckEndpoint)
}
sm.logger.Info("System settings loaded and cached.")
sm.DisplaySettings(settings)
return settings, nil
}
s, err := syncer.NewCacheSyncer(settingsLoader, store, SettingsUpdateChannel)
if err != nil {
return nil, fmt.Errorf("failed to create system settings syncer: %w", err)
}
sm.syncer = s
go sm.ensureSettingsInitialized()
if err := sm.ensureSettingsInitialized(); err != nil {
return nil, fmt.Errorf("failed to ensure system settings are initialized: %w", err)
}
return sm, nil
}
// GetSettings [核心修正] 现在它正确地返回我们需要的“蓝图”
// GetSettings 返回当前缓存的系统设置。
func (sm *SettingsManager) GetSettings() *models.SystemSettings {
return sm.syncer.Get()
}
// UpdateSettings [核心修正] 它接收更新,并将它们转换为“砖块”存入数据库
// UpdateSettings 更新一个或多个系统设置。
func (sm *SettingsManager) UpdateSettings(settingsMap map[string]interface{}) error {
var settingsToUpdate []models.Setting
for key, value := range settingsMap {
fieldType, ok := sm.jsonToFieldType[key]
if !ok {
sm.logger.Warnf("Received update for unknown setting key '%s', ignoring.", key)
continue
}
var dbValue string
// [智能打包]
// 如果字段是 slice 或 map我们就将传入的 interface{} “打包”成 JSON string
kind := fieldType.Kind()
if kind == reflect.Slice || kind == reflect.Map {
jsonBytes, marshalErr := json.Marshal(value)
if marshalErr != nil {
// [真正的错误处理] 如果打包失败,我们记录日志,并跳过这个“坏掉的集装箱”。
sm.logger.Warnf("Failed to marshal setting '%s' to JSON: %v, skipping update.", key, marshalErr)
continue // 跳过继续处理下一个key
}
dbValue = string(jsonBytes)
} else if kind == reflect.Bool {
if b, ok := value.(bool); ok {
dbValue = strconv.FormatBool(b)
} else {
dbValue = "false"
}
} else {
dbValue = fmt.Sprintf("%v", value)
dbValue, err := sm.convertToDBValue(key, value, fieldType)
if err != nil {
sm.logger.Warnf("Failed to convert value for setting '%s': %v. Skipping update.", key, err)
continue
}
settingsToUpdate = append(settingsToUpdate, models.Setting{
Key: key,
Value: dbValue,
})
}
if len(settingsToUpdate) > 0 {
err := sm.db.Clauses(clause.OnConflict{
Columns: []clause.Column{{Name: "key"}},
@@ -147,83 +137,20 @@ func (sm *SettingsManager) UpdateSettings(settingsMap map[string]interface{}) er
return fmt.Errorf("failed to update settings in db: %w", err)
}
}
return sm.syncer.Invalidate()
if err := sm.syncer.Invalidate(); err != nil {
sm.logger.Errorf("CRITICAL: Database settings updated, but cache invalidation failed: %v", err)
return fmt.Errorf("settings updated but cache invalidation failed, system may be inconsistent: %w", err)
}
return nil
}
// ensureSettingsInitialized [核心修正] 确保DB中有所有“砖块”的定义
func (sm *SettingsManager) ensureSettingsInitialized() {
defaults := defaultSystemSettings()
v := reflect.ValueOf(defaults).Elem()
t := v.Type()
for i := 0; i < t.NumField(); i++ {
field := t.Field(i)
fieldValue := v.Field(i)
key := field.Tag.Get("json")
if key == "" || key == "-" {
continue
}
var existing models.Setting
if err := sm.db.Where("key = ?", key).First(&existing).Error; err == gorm.ErrRecordNotFound {
var defaultValue string
kind := fieldValue.Kind()
// [智能初始化]
if kind == reflect.Slice || kind == reflect.Map {
// 为复杂类型生成一个“空的”JSON字符串例如 "[]" 或 "{}"
jsonBytes, _ := json.Marshal(fieldValue.Interface())
defaultValue = string(jsonBytes)
} else {
defaultValue = field.Tag.Get("default")
}
setting := models.Setting{
Key: key,
Value: defaultValue,
Name: field.Tag.Get("name"),
Description: field.Tag.Get("desc"),
Category: field.Tag.Get("category"),
DefaultValue: field.Tag.Get("default"), // 元数据中的default永远来自tag
}
if err := sm.db.Create(&setting).Error; err != nil {
sm.logger.Errorf("Failed to initialize setting '%s': %v", key, err)
}
}
}
}
// ResetAndSaveSettings [核心新增] 將所有配置重置為其在 'default' 標籤中定義的值。
// ResetAndSaveSettings 将所有设置重置为其默认值。
func (sm *SettingsManager) ResetAndSaveSettings() (*models.SystemSettings, error) {
defaults := defaultSystemSettings()
v := reflect.ValueOf(defaults).Elem()
t := v.Type()
var settingsToSave []models.Setting
for i := 0; i < t.NumField(); i++ {
field := t.Field(i)
fieldValue := v.Field(i)
key := field.Tag.Get("json")
if key == "" || key == "-" {
continue
}
settingsToSave := sm.buildSettingsFromDefaults(defaults)
var defaultValue string
kind := fieldValue.Kind()
// [智能重置]
if kind == reflect.Slice || kind == reflect.Map {
jsonBytes, _ := json.Marshal(fieldValue.Interface())
defaultValue = string(jsonBytes)
} else {
defaultValue = field.Tag.Get("default")
}
setting := models.Setting{
Key: key,
Value: defaultValue,
Name: field.Tag.Get("name"),
Description: field.Tag.Get("desc"),
Category: field.Tag.Get("category"),
DefaultValue: field.Tag.Get("default"),
}
settingsToSave = append(settingsToSave, setting)
}
if len(settingsToSave) > 0 {
err := sm.db.Clauses(clause.OnConflict{
Columns: []clause.Column{{Name: "key"}},
@@ -233,8 +160,93 @@ func (sm *SettingsManager) ResetAndSaveSettings() (*models.SystemSettings, error
return nil, fmt.Errorf("failed to reset settings in db: %w", err)
}
}
if err := sm.syncer.Invalidate(); err != nil {
sm.logger.Errorf("Failed to invalidate settings cache after reset: %v", err)
sm.logger.Errorf("CRITICAL: Database settings reset, but cache invalidation failed: %v", err)
return nil, fmt.Errorf("settings reset but cache invalidation failed: %w", err)
}
return defaults, nil
}
// --- 私有辅助函数 ---
func (sm *SettingsManager) ensureSettingsInitialized() error {
defaults := defaultSystemSettings()
settingsToCreate := sm.buildSettingsFromDefaults(defaults)
for _, setting := range settingsToCreate {
var existing models.Setting
err := sm.db.Where("key = ?", setting.Key).First(&existing).Error
if err == gorm.ErrRecordNotFound {
sm.logger.Infof("Initializing new setting '%s'", setting.Key)
if createErr := sm.db.Create(&setting).Error; createErr != nil {
return fmt.Errorf("failed to create initial setting '%s': %w", setting.Key, createErr)
}
} else if err != nil {
return fmt.Errorf("failed to check for existing setting '%s': %w", setting.Key, err)
}
}
return nil
}
func (sm *SettingsManager) buildSettingsFromDefaults(defaults *models.SystemSettings) []models.Setting {
v := reflect.ValueOf(defaults).Elem()
t := v.Type()
var settings []models.Setting
for i := 0; i < t.NumField(); i++ {
field := t.Field(i)
fieldValue := v.Field(i)
key := field.Tag.Get("json")
if key == "" || key == "-" {
continue
}
var defaultValue string
kind := fieldValue.Kind()
if kind == reflect.Slice || kind == reflect.Map {
jsonBytes, _ := json.Marshal(fieldValue.Interface())
defaultValue = string(jsonBytes)
} else {
defaultValue = field.Tag.Get("default")
}
settings = append(settings, models.Setting{
Key: key,
Value: defaultValue,
Name: field.Tag.Get("name"),
Description: field.Tag.Get("desc"),
Category: field.Tag.Get("category"),
DefaultValue: field.Tag.Get("default"),
})
}
return settings
}
// [修正] 使用空白标识符 `_` 修复 "unused parameter" 警告。
func (sm *SettingsManager) convertToDBValue(_ string, value interface{}, fieldType reflect.Type) (string, error) {
kind := fieldType.Kind()
switch kind {
case reflect.Slice, reflect.Map:
jsonBytes, err := json.Marshal(value)
if err != nil {
return "", fmt.Errorf("failed to marshal to JSON: %w", err)
}
return string(jsonBytes), nil
case reflect.Bool:
b, ok := value.(bool)
if !ok {
return "", fmt.Errorf("expected bool, but got %T", value)
}
return strconv.FormatBool(b), nil
default:
return fmt.Sprintf("%v", value), nil
}
}

View File

@@ -7,6 +7,7 @@ import (
"fmt"
"math/rand"
"sort"
"strconv"
"sync"
"time"
@@ -65,7 +66,7 @@ func (s *memoryStore) startGCollector() {
}
}
// --- [架构修正] 所有方法签名都增加了 context.Context 参数以匹配接口 ---
// --- 所有方法签名都增加了 context.Context 参数以匹配接口 ---
// --- 内存实现可以忽略该参数,用 _ 接收 ---
func (s *memoryStore) Set(_ context.Context, key string, value []byte, ttl time.Duration) error {
@@ -108,6 +109,17 @@ func (s *memoryStore) Exists(_ context.Context, key string) (bool, error) {
return ok && !item.isExpired(), nil
}
func (s *memoryStore) Expire(_ context.Context, key string, expiration time.Duration) error {
s.mu.Lock()
defer s.mu.Unlock()
item, ok := s.items[key]
if !ok {
return ErrNotFound
}
item.expireAt = time.Now().Add(expiration)
return nil
}
func (s *memoryStore) SetNX(_ context.Context, key string, value []byte, ttl time.Duration) (bool, error) {
s.mu.Lock()
defer s.mu.Unlock()
@@ -159,6 +171,21 @@ func (s *memoryStore) HSet(_ context.Context, key string, values map[string]any)
return nil
}
func (s *memoryStore) HGet(_ context.Context, key, field string) (string, error) {
s.mu.RLock()
defer s.mu.RUnlock()
item, ok := s.items[key]
if !ok || item.isExpired() {
return "", ErrNotFound
}
if hash, ok := item.value.(map[string]string); ok {
if value, exists := hash[field]; exists {
return value, nil
}
}
return "", ErrNotFound
}
func (s *memoryStore) HGetAll(_ context.Context, key string) (map[string]string, error) {
s.mu.RLock()
defer s.mu.RUnlock()
@@ -351,6 +378,26 @@ func (s *memoryStore) SRandMember(_ context.Context, key string) (string, error)
return members[n], nil
}
func (s *memoryStore) SUnionStore(_ context.Context, destination string, keys ...string) (int64, error) {
s.mu.Lock()
defer s.mu.Unlock()
unionSet := make(map[string]struct{})
for _, key := range keys {
item, ok := s.items[key]
if !ok || item.isExpired() {
continue
}
if set, ok := item.value.(map[string]struct{}); ok {
for member := range set {
unionSet[member] = struct{}{}
}
}
}
destItem := &memoryStoreItem{value: unionSet}
s.items[destination] = destItem
return int64(len(unionSet)), nil
}
func (s *memoryStore) Rotate(_ context.Context, key string) (string, error) {
s.mu.Lock()
defer s.mu.Unlock()
@@ -388,6 +435,16 @@ func (s *memoryStore) LIndex(_ context.Context, key string, index int64) (string
return list[index], nil
}
func (s *memoryStore) MSet(ctx context.Context, values map[string]any) error {
s.mu.Lock()
defer s.mu.Unlock()
for key, value := range values {
// 内存存储不支持独立的 TTL因此我们假设永不过期
s.items[key] = &memoryStoreItem{value: value, expireAt: time.Time{}}
}
return nil
}
func (s *memoryStore) ZAdd(_ context.Context, key string, members map[string]float64) error {
s.mu.Lock()
defer s.mu.Unlock()
@@ -556,6 +613,22 @@ func (p *memoryPipeliner) Del(keys ...string) {
}
})
}
func (p *memoryPipeliner) Set(key string, value []byte, expiration time.Duration) {
capturedKey := key
capturedValue := value
p.ops = append(p.ops, func() {
var expireAt time.Time
if expiration > 0 {
expireAt = time.Now().Add(expiration)
}
p.store.items[capturedKey] = &memoryStoreItem{
value: capturedValue,
expireAt: expireAt,
}
})
}
func (p *memoryPipeliner) SAdd(key string, members ...any) {
capturedKey := key
capturedMembers := make([]any, len(members))
@@ -576,6 +649,7 @@ func (p *memoryPipeliner) SAdd(key string, members ...any) {
}
})
}
func (p *memoryPipeliner) SRem(key string, members ...any) {
capturedKey := key
capturedMembers := make([]any, len(members))
@@ -615,11 +689,125 @@ func (p *memoryPipeliner) LPush(key string, values ...any) {
item.value = append(stringValues, list...)
})
}
func (p *memoryPipeliner) LRem(key string, count int64, value any) {}
func (p *memoryPipeliner) HSet(key string, values map[string]any) {}
func (p *memoryPipeliner) HIncrBy(key, field string, incr int64) {}
func (p *memoryPipeliner) ZAdd(key string, members map[string]float64) {}
func (p *memoryPipeliner) ZRem(key string, members ...any) {}
func (p *memoryPipeliner) LRem(key string, count int64, value any) {
capturedKey := key
capturedValue := fmt.Sprintf("%v", value)
p.ops = append(p.ops, func() {
item, ok := p.store.items[capturedKey]
if !ok || item.isExpired() {
return
}
list, ok := item.value.([]string)
if !ok {
return
}
newList := make([]string, 0, len(list))
removed := int64(0)
for _, v := range list {
if count != 0 && v == capturedValue && (count < 0 || removed < count) {
removed++
continue
}
newList = append(newList, v)
}
item.value = newList
})
}
func (p *memoryPipeliner) HSet(key string, values map[string]any) {
capturedKey := key
capturedValues := make(map[string]any, len(values))
for k, v := range values {
capturedValues[k] = v
}
p.ops = append(p.ops, func() {
item, ok := p.store.items[capturedKey]
if !ok || item.isExpired() {
item = &memoryStoreItem{value: make(map[string]string)}
p.store.items[capturedKey] = item
}
hash, ok := item.value.(map[string]string)
if !ok {
hash = make(map[string]string)
item.value = hash
}
for field, value := range capturedValues {
hash[field] = fmt.Sprintf("%v", value)
}
})
}
func (p *memoryPipeliner) HIncrBy(key, field string, incr int64) {
capturedKey := key
capturedField := field
p.ops = append(p.ops, func() {
item, ok := p.store.items[capturedKey]
if !ok || item.isExpired() {
item = &memoryStoreItem{value: make(map[string]string)}
p.store.items[capturedKey] = item
}
hash, ok := item.value.(map[string]string)
if !ok {
hash = make(map[string]string)
item.value = hash
}
current, _ := strconv.ParseInt(hash[capturedField], 10, 64)
hash[capturedField] = strconv.FormatInt(current+incr, 10)
})
}
func (p *memoryPipeliner) ZAdd(key string, members map[string]float64) {
capturedKey := key
capturedMembers := make(map[string]float64, len(members))
for k, v := range members {
capturedMembers[k] = v
}
p.ops = append(p.ops, func() {
item, ok := p.store.items[capturedKey]
if !ok || item.isExpired() {
item = &memoryStoreItem{value: make(map[string]float64)}
p.store.items[capturedKey] = item
}
zset, ok := item.value.(map[string]float64)
if !ok {
zset = make(map[string]float64)
item.value = zset
}
for member, score := range capturedMembers {
zset[member] = score
}
})
}
func (p *memoryPipeliner) ZRem(key string, members ...any) {
capturedKey := key
capturedMembers := make([]any, len(members))
copy(capturedMembers, members)
p.ops = append(p.ops, func() {
item, ok := p.store.items[capturedKey]
if !ok || item.isExpired() {
return
}
zset, ok := item.value.(map[string]float64)
if !ok {
return
}
for _, member := range capturedMembers {
delete(zset, fmt.Sprintf("%v", member))
}
})
}
func (p *memoryPipeliner) MSet(values map[string]any) {
capturedValues := make(map[string]any, len(values))
for k, v := range values {
capturedValues[k] = v
}
p.ops = append(p.ops, func() {
for key, value := range capturedValues {
p.store.items[key] = &memoryStoreItem{
value: value,
expireAt: time.Time{}, // Pipelined MSet 同样假设永不过期
}
}
})
}
type memorySubscription struct {
store *memoryStore

View File

@@ -75,10 +75,24 @@ func (s *RedisStore) Close() error {
return s.client.Close()
}
func (s *RedisStore) Expire(ctx context.Context, key string, expiration time.Duration) error {
return s.client.Expire(ctx, key, expiration).Err()
}
func (s *RedisStore) HSet(ctx context.Context, key string, values map[string]any) error {
return s.client.HSet(ctx, key, values).Err()
}
func (s *RedisStore) HGet(ctx context.Context, key, field string) (string, error) {
val, err := s.client.HGet(ctx, key, field).Result()
if err != nil {
if errors.Is(err, redis.Nil) {
return "", ErrNotFound
}
return "", err
}
return val, nil
}
func (s *RedisStore) HGetAll(ctx context.Context, key string) (map[string]string, error) {
return s.client.HGetAll(ctx, key).Result()
}
@@ -111,6 +125,18 @@ func (s *RedisStore) Rotate(ctx context.Context, key string) (string, error) {
return val, nil
}
func (s *RedisStore) MSet(ctx context.Context, values map[string]any) error {
if len(values) == 0 {
return nil
}
// Redis MSet 命令需要 [key1, value1, key2, value2, ...] 格式的切片
pairs := make([]interface{}, 0, len(values)*2)
for k, v := range values {
pairs = append(pairs, k, v)
}
return s.client.MSet(ctx, pairs...).Err()
}
func (s *RedisStore) SAdd(ctx context.Context, key string, members ...any) error {
return s.client.SAdd(ctx, key, members...).Err()
}
@@ -141,6 +167,13 @@ func (s *RedisStore) SRandMember(ctx context.Context, key string) (string, error
return member, nil
}
func (s *RedisStore) SUnionStore(ctx context.Context, destination string, keys ...string) (int64, error) {
if len(keys) == 0 {
return 0, nil
}
return s.client.SUnionStore(ctx, destination, keys...).Result()
}
func (s *RedisStore) ZAdd(ctx context.Context, key string, members map[string]float64) error {
if len(members) == 0 {
return nil
@@ -216,6 +249,17 @@ func (p *redisPipeliner) LPush(key string, values ...any) { p.pipe.LPush(p.ctx,
func (p *redisPipeliner) LRem(key string, count int64, value any) {
p.pipe.LRem(p.ctx, key, count, value)
}
func (p *redisPipeliner) Set(key string, value []byte, expiration time.Duration) {
p.pipe.Set(p.ctx, key, value, expiration)
}
func (p *redisPipeliner) MSet(values map[string]any) {
if len(values) == 0 {
return
}
p.pipe.MSet(p.ctx, values)
}
func (p *redisPipeliner) SAdd(key string, members ...any) { p.pipe.SAdd(p.ctx, key, members...) }
func (p *redisPipeliner) SRem(key string, members ...any) { p.pipe.SRem(p.ctx, key, members...) }
func (p *redisPipeliner) ZAdd(key string, members map[string]float64) {

View File

@@ -35,6 +35,8 @@ type Pipeliner interface {
HIncrBy(key, field string, incr int64)
// SET
MSet(values map[string]any)
Set(key string, value []byte, expiration time.Duration)
SAdd(key string, members ...any)
SRem(key string, members ...any)
@@ -58,9 +60,11 @@ type Store interface {
Del(ctx context.Context, keys ...string) error
Exists(ctx context.Context, key string) (bool, error)
SetNX(ctx context.Context, key string, value []byte, ttl time.Duration) (bool, error)
MSet(ctx context.Context, values map[string]any) error
// HASH operations
HSet(ctx context.Context, key string, values map[string]any) error
HGet(ctx context.Context, key, field string) (string, error)
HGetAll(ctx context.Context, key string) (map[string]string, error)
HIncrBy(ctx context.Context, key, field string, incr int64) (int64, error)
HDel(ctx context.Context, key string, fields ...string) error
@@ -70,6 +74,7 @@ type Store interface {
LRem(ctx context.Context, key string, count int64, value any) error
Rotate(ctx context.Context, key string) (string, error)
LIndex(ctx context.Context, key string, index int64) (string, error)
Expire(ctx context.Context, key string, expiration time.Duration) error
// SET operations
SAdd(ctx context.Context, key string, members ...any) error
@@ -77,6 +82,7 @@ type Store interface {
SMembers(ctx context.Context, key string) ([]string, error)
SRem(ctx context.Context, key string, members ...any) error
SRandMember(ctx context.Context, key string) (string, error)
SUnionStore(ctx context.Context, destination string, keys ...string) (int64, error)
// Pub/Sub operations
Publish(ctx context.Context, channel string, message []byte) error