diff --git a/internal/channel/gemini_channel.go b/internal/channel/gemini_channel.go index 912c272..9819a81 100644 --- a/internal/channel/gemini_channel.go +++ b/internal/channel/gemini_channel.go @@ -28,12 +28,6 @@ type GeminiChannel struct { httpClient *http.Client } -// 用于安全提取信息的本地结构体 -type requestMetadata struct { - Model string `json:"model"` - Stream bool `json:"stream"` -} - func NewGeminiChannel(logger *logrus.Logger, cfg *models.SystemSettings) *GeminiChannel { transport := &http.Transport{ Proxy: http.ProxyFromEnvironment, @@ -47,38 +41,50 @@ func NewGeminiChannel(logger *logrus.Logger, cfg *models.SystemSettings) *Gemini logger: logger, httpClient: &http.Client{ Transport: transport, - Timeout: 0, + Timeout: 0, // Timeout is handled by the request context }, } } // TransformRequest func (ch *GeminiChannel) TransformRequest(c *gin.Context, requestBody []byte) (newBody []byte, modelName string, err error) { + modelName = ch.ExtractModel(c, requestBody) + return requestBody, modelName, nil +} + +// ExtractModel +func (ch *GeminiChannel) ExtractModel(c *gin.Context, bodyBytes []byte) string { + return ch.extractModelFromRequest(c, bodyBytes) +} + +// 统一的模型提取逻辑:优先从请求体解析,失败则回退到从URL路径解析。 +func (ch *GeminiChannel) extractModelFromRequest(c *gin.Context, bodyBytes []byte) string { var p struct { Model string `json:"model"` } - _ = json.Unmarshal(requestBody, &p) - modelName = strings.TrimPrefix(p.Model, "models/") - - if modelName == "" { - modelName = ch.extractModelFromPath(c.Request.URL.Path) + if json.Unmarshal(bodyBytes, &p) == nil && p.Model != "" { + return strings.TrimPrefix(p.Model, "models/") } - return requestBody, modelName, nil + return ch.extractModelFromPath(c.Request.URL.Path) } func (ch *GeminiChannel) extractModelFromPath(path string) string { parts := strings.Split(path, "/") for _, part := range parts { + // 覆盖更多模型名称格式 if strings.HasPrefix(part, "gemini-") || strings.HasPrefix(part, "text-") || strings.HasPrefix(part, "embedding-") { - modelPart := strings.Split(part, ":")[0] - return modelPart + return strings.Split(part, ":")[0] } } return "" } +// IsOpenAICompatibleRequest 通过纯粹的字符串操作判断,不依赖 gin.Context。 func (ch *GeminiChannel) IsOpenAICompatibleRequest(c *gin.Context) bool { - path := c.Request.URL.Path + return ch.isOpenAIPath(c.Request.URL.Path) +} + +func (ch *GeminiChannel) isOpenAIPath(path string) bool { return strings.Contains(path, "/v1/chat/completions") || strings.Contains(path, "/v1/embeddings") } @@ -88,25 +94,28 @@ func (ch *GeminiChannel) ValidateKey( targetURL string, timeout time.Duration, ) *CustomErrors.APIError { - client := &http.Client{ - Timeout: timeout, - } + client := &http.Client{Timeout: timeout} + req, err := http.NewRequestWithContext(ctx, "GET", targetURL, nil) if err != nil { - return CustomErrors.NewAPIError(CustomErrors.ErrInternalServer, "failed to create validation request: "+err.Error()) + return CustomErrors.NewAPIError(CustomErrors.ErrInternalServer, "failed to create validation request") } ch.ModifyRequest(req, apiKey) + resp, err := client.Do(req) if err != nil { - return CustomErrors.NewAPIError(CustomErrors.ErrBadGateway, "failed to send validation request: "+err.Error()) + return CustomErrors.NewAPIError(CustomErrors.ErrBadGateway, "validation request failed: "+err.Error()) } defer resp.Body.Close() + if resp.StatusCode >= 200 && resp.StatusCode < 300 { return nil } + errorBody, _ := io.ReadAll(resp.Body) parsedMessage := CustomErrors.ParseUpstreamError(errorBody) + return &CustomErrors.APIError{ HTTPStatus: resp.StatusCode, Code: fmt.Sprintf("UPSTREAM_%d", resp.StatusCode), @@ -115,10 +124,6 @@ func (ch *GeminiChannel) ValidateKey( } func (ch *GeminiChannel) ModifyRequest(req *http.Request, apiKey *models.APIKey) { - // TODO: [Future Refactoring] Decouple auth logic from URL path. - // The authentication method (e.g., Bearer token vs. API key in query) should ideally be a property - // of the UpstreamEndpoint or a new "AuthProfile" entity, rather than being hardcoded based on URL patterns. - // This would make the channel more generic and adaptable to new upstream provider types. if strings.Contains(req.URL.Path, "/v1beta/openai/") { req.Header.Set("Authorization", "Bearer "+apiKey.APIKey) } else { @@ -133,24 +138,22 @@ func (ch *GeminiChannel) IsStreamRequest(c *gin.Context, bodyBytes []byte) bool if strings.HasSuffix(c.Request.URL.Path, ":streamGenerateContent") { return true } - var meta requestMetadata - if err := json.Unmarshal(bodyBytes, &meta); err == nil { + var meta struct { + Stream bool `json:"stream"` + } + if json.Unmarshal(bodyBytes, &meta) == nil { return meta.Stream } return false } -func (ch *GeminiChannel) ExtractModel(c *gin.Context, bodyBytes []byte) string { - _, modelName, _ := ch.TransformRequest(c, bodyBytes) - return modelName -} - +// RewritePath 使用 url.JoinPath 保证路径拼接的正确性。 func (ch *GeminiChannel) RewritePath(basePath, originalPath string) string { - tempCtx := &gin.Context{Request: &http.Request{URL: &url.URL{Path: originalPath}}} var rewrittenSegment string - if ch.IsOpenAICompatibleRequest(tempCtx) { - var apiEndpoint string + + if ch.isOpenAIPath(originalPath) { v1Index := strings.LastIndex(originalPath, "/v1/") + var apiEndpoint string if v1Index != -1 { apiEndpoint = originalPath[v1Index+len("/v1/"):] } else { @@ -158,69 +161,76 @@ func (ch *GeminiChannel) RewritePath(basePath, originalPath string) string { } rewrittenSegment = "v1beta/openai/" + apiEndpoint } else { - tempPath := originalPath - if strings.HasPrefix(tempPath, "/v1/") { - tempPath = "/v1beta/" + strings.TrimPrefix(tempPath, "/v1/") + if strings.HasPrefix(originalPath, "/v1/") { + rewrittenSegment = "v1beta/" + strings.TrimPrefix(originalPath, "/v1/") + } else { + rewrittenSegment = strings.TrimPrefix(originalPath, "/") } - rewrittenSegment = strings.TrimPrefix(tempPath, "/") } - trimmedBasePath := strings.TrimSuffix(basePath, "/") - pathToJoin := rewrittenSegment + trimmedBasePath := strings.TrimSuffix(basePath, "/") + + // 防止版本号重复拼接,例如 basePath 是 /v1beta,而重写段也是 v1beta/.. versionPrefixes := []string{"v1beta", "v1"} for _, prefix := range versionPrefixes { - if strings.HasSuffix(trimmedBasePath, "/"+prefix) && strings.HasPrefix(pathToJoin, prefix+"/") { - pathToJoin = strings.TrimPrefix(pathToJoin, prefix+"/") + if strings.HasSuffix(trimmedBasePath, "/"+prefix) && strings.HasPrefix(rewrittenSegment, prefix+"/") { + rewrittenSegment = strings.TrimPrefix(rewrittenSegment, prefix+"/") break } } - finalPath, err := url.JoinPath(trimmedBasePath, pathToJoin) + + finalPath, err := url.JoinPath(trimmedBasePath, rewrittenSegment) if err != nil { - return trimmedBasePath + "/" + strings.TrimPrefix(pathToJoin, "/") + // 回退到简单的字符串拼接 + return trimmedBasePath + "/" + strings.TrimPrefix(rewrittenSegment, "/") } return finalPath } func (ch *GeminiChannel) ModifyResponse(resp *http.Response) error { - // 这是一个桩实现,暂时不需要任何逻辑。 - return nil + return nil // 桩实现 } func (ch *GeminiChannel) HandleError(c *gin.Context, err error) { - // 这是一个桩实现,暂时不需要任何逻辑。 + // 桩实现 } -// ========================================================== -// ================== “智能路由”的核心引擎 =================== -// ========================================================== func (ch *GeminiChannel) ProcessSmartStreamRequest(c *gin.Context, params SmartRequestParams) { log := ch.logger.WithField("correlation_id", params.CorrelationID) targetURL, err := url.Parse(params.UpstreamURL) if err != nil { - log.WithError(err).Error("Failed to parse upstream URL") - errToJSON(c, CustomErrors.NewAPIError(CustomErrors.ErrInternalServer, "Invalid upstream URL format")) + log.WithError(err).Error("Invalid upstream URL") + errToJSON(c, CustomErrors.NewAPIError(CustomErrors.ErrInternalServer, "Invalid upstream URL")) return } + targetURL.Path = c.Request.URL.Path targetURL.RawQuery = c.Request.URL.RawQuery + initialReq, err := http.NewRequestWithContext(c.Request.Context(), "POST", targetURL.String(), bytes.NewReader(params.RequestBody)) if err != nil { - log.WithError(err).Error("Failed to create initial smart request") - errToJSON(c, CustomErrors.NewAPIError(CustomErrors.ErrInternalServer, err.Error())) + log.WithError(err).Error("Failed to create initial smart stream request") + errToJSON(c, CustomErrors.NewAPIError(CustomErrors.ErrInternalServer, "Failed to create request")) return } + ch.ModifyRequest(initialReq, params.APIKey) initialReq.Header.Del("Authorization") + resp, err := ch.httpClient.Do(initialReq) if err != nil { - log.WithError(err).Error("Initial smart request failed") - errToJSON(c, CustomErrors.NewAPIError(CustomErrors.ErrBadGateway, err.Error())) + log.WithError(err).Error("Initial smart stream request failed") + errToJSON(c, CustomErrors.NewAPIError(CustomErrors.ErrBadGateway, "Request to upstream failed")) return } + defer resp.Body.Close() + if resp.StatusCode != http.StatusOK { log.Warnf("Initial request received non-200 status: %d", resp.StatusCode) standardizedResp := ch.standardizeError(resp, params.LogTruncationLimit, log) + defer standardizedResp.Body.Close() + c.Writer.WriteHeader(standardizedResp.StatusCode) for key, values := range standardizedResp.Header { for _, value := range values { @@ -228,45 +238,71 @@ func (ch *GeminiChannel) ProcessSmartStreamRequest(c *gin.Context, params SmartR } } io.Copy(c.Writer, standardizedResp.Body) + params.EventLogger.IsSuccess = false params.EventLogger.StatusCode = resp.StatusCode return } + ch.processStreamAndRetry(c, initialReq.Header, resp.Body, params, log) } func (ch *GeminiChannel) processStreamAndRetry( - c *gin.Context, initialRequestHeaders http.Header, initialReader io.ReadCloser, params SmartRequestParams, log *logrus.Entry, + c *gin.Context, + initialRequestHeaders http.Header, + initialReader io.ReadCloser, + params SmartRequestParams, + log *logrus.Entry, ) { defer initialReader.Close() + c.Writer.Header().Set("Content-Type", "text/event-stream; charset=utf-8") c.Writer.Header().Set("Cache-Control", "no-cache") c.Writer.Header().Set("Connection", "keep-alive") + flusher, _ := c.Writer.(http.Flusher) + var accumulatedText strings.Builder consecutiveRetryCount := 0 currentReader := initialReader maxRetries := params.MaxRetries retryDelay := params.RetryDelay + log.Infof("Starting smart stream session. Max retries: %d", maxRetries) + for { + if c.Request.Context().Err() != nil { + log.Info("Client disconnected, stopping stream processing.") + return + } + var interruptionReason string scanner := bufio.NewScanner(currentReader) + for scanner.Scan() { + if c.Request.Context().Err() != nil { + log.Info("Client disconnected during scan.") + return + } + line := scanner.Text() if line == "" { continue } + fmt.Fprintf(c.Writer, "%s\n\n", line) flusher.Flush() + if !strings.HasPrefix(line, "data: ") { continue } + data := strings.TrimPrefix(line, "data: ") var payload models.GeminiSSEPayload if err := json.Unmarshal([]byte(data), &payload); err != nil { continue } + if len(payload.Candidates) > 0 { candidate := payload.Candidates[0] if candidate.Content != nil && len(candidate.Content.Parts) > 0 { @@ -285,52 +321,71 @@ func (ch *GeminiChannel) processStreamAndRetry( } } currentReader.Close() + if interruptionReason == "" { if err := scanner.Err(); err != nil { log.WithError(err).Warn("Stream scanner encountered an error.") interruptionReason = "SCANNER_ERROR" } else { - log.Warn("Stream dropped unexpectedly without a finish reason.") + log.Warn("Stream connection dropped without a finish reason.") interruptionReason = "CONNECTION_DROP" } } + if consecutiveRetryCount >= maxRetries { - log.Errorf("Retry limit exceeded. Last interruption: %s. Sending final error.", interruptionReason) - errData, _ := json.Marshal(map[string]interface{}{"error": map[string]interface{}{"code": http.StatusGatewayTimeout, "status": "DEADLINE_EXCEEDED", "message": fmt.Sprintf("Proxy retry limit exceeded. Last interruption: %s.", interruptionReason)}}) + log.Errorf("Retry limit exceeded. Last interruption: %s. Sending final error to client.", interruptionReason) + errData, _ := json.Marshal(map[string]interface{}{ + "error": map[string]interface{}{ + "code": http.StatusGatewayTimeout, + "status": "DEADLINE_EXCEEDED", + "message": fmt.Sprintf("Proxy retry limit exceeded after multiple interruptions. Last reason: %s", interruptionReason), + }, + }) fmt.Fprintf(c.Writer, "event: error\ndata: %s\n\n", string(errData)) flusher.Flush() return } + consecutiveRetryCount++ params.EventLogger.Retries = consecutiveRetryCount log.Infof("Stream interrupted. Attempting retry %d/%d after %v.", consecutiveRetryCount, maxRetries, retryDelay) + time.Sleep(retryDelay) - retryBody, _ := buildRetryRequestBody(params.OriginalRequest, accumulatedText.String()) + + retryBody := buildRetryRequestBody(params.OriginalRequest, accumulatedText.String()) retryBodyBytes, _ := json.Marshal(retryBody) - retryReq, _ := http.NewRequestWithContext(c.Request.Context(), "POST", params.UpstreamURL, bytes.NewReader(retryBodyBytes)) + retryReq, err := http.NewRequestWithContext(c.Request.Context(), "POST", params.UpstreamURL, bytes.NewReader(retryBodyBytes)) + if err != nil { + log.WithError(err).Error("Failed to create retry request") + continue + } + retryReq.Header = initialRequestHeaders ch.ModifyRequest(retryReq, params.APIKey) retryReq.Header.Del("Authorization") retryResp, err := ch.httpClient.Do(retryReq) - if err != nil || retryResp.StatusCode != http.StatusOK || retryResp.Body == nil { - if err != nil { - log.WithError(err).Errorf("Retry request failed.") - } else { - log.Errorf("Retry request received non-200 status: %d", retryResp.StatusCode) - if retryResp.Body != nil { - retryResp.Body.Close() - } - } + if err != nil { + log.WithError(err).Error("Retry request failed") continue } + + if retryResp.StatusCode != http.StatusOK { + log.Errorf("Retry request received non-200 status: %d", retryResp.StatusCode) + retryResp.Body.Close() + continue + } + currentReader = retryResp.Body } } -func buildRetryRequestBody(originalBody models.GeminiRequest, accumulatedText string) (models.GeminiRequest, error) { +// buildRetryRequestBody 正确处理多轮对话的上下文插入。 +func buildRetryRequestBody(originalBody models.GeminiRequest, accumulatedText string) models.GeminiRequest { retryBody := originalBody + + // 找到最后一个 'user' 角色的消息索引 lastUserIndex := -1 for i := len(retryBody.Contents) - 1; i >= 0; i-- { if retryBody.Contents[i].Role == "user" { @@ -338,25 +393,26 @@ func buildRetryRequestBody(originalBody models.GeminiRequest, accumulatedText st break } } + history := []models.GeminiContent{ {Role: "model", Parts: []models.Part{{Text: accumulatedText}}}, {Role: "user", Parts: []models.Part{{Text: SmartRetryPrompt}}}, } + if lastUserIndex != -1 { + // 如果找到了 'user' 消息,将历史记录插入到其后 newContents := make([]models.GeminiContent, 0, len(retryBody.Contents)+2) newContents = append(newContents, retryBody.Contents[:lastUserIndex+1]...) newContents = append(newContents, history...) newContents = append(newContents, retryBody.Contents[lastUserIndex+1:]...) retryBody.Contents = newContents } else { + // 如果没有 'user' 消息(理论上不应发生),则直接追加 retryBody.Contents = append(retryBody.Contents, history...) } - return retryBody, nil -} -// =============================================== -// ========= 辅助函数区 (继承并强化) ========= -// =============================================== + return retryBody +} type googleAPIError struct { Error struct { @@ -397,25 +453,28 @@ func truncate(s string, n int) string { return s } -// standardizeError func (ch *GeminiChannel) standardizeError(resp *http.Response, truncateLimit int, log *logrus.Entry) *http.Response { bodyBytes, err := io.ReadAll(resp.Body) if err != nil { log.WithError(err).Error("Failed to read upstream error body") - bodyBytes = []byte("Failed to read upstream error body: " + err.Error()) + bodyBytes = []byte("Failed to read upstream error body") } resp.Body.Close() - log.Errorf("Upstream error body: %s", truncate(string(bodyBytes), truncateLimit)) + + log.Errorf("Upstream error: %s", truncate(string(bodyBytes), truncateLimit)) + var standardizedPayload googleAPIError + // 即使解析失败,也要构建一个标准的错误结构体 if json.Unmarshal(bodyBytes, &standardizedPayload) != nil || standardizedPayload.Error.Code == 0 { standardizedPayload.Error.Code = resp.StatusCode standardizedPayload.Error.Message = http.StatusText(resp.StatusCode) standardizedPayload.Error.Status = statusToGoogleStatus(resp.StatusCode) standardizedPayload.Error.Details = []interface{}{map[string]string{ - "@type": "proxy.upstream.error", + "@type": "proxy.upstream.unparsed.error", "body": truncate(string(bodyBytes), truncateLimit), }} } + newBodyBytes, _ := json.Marshal(standardizedPayload) newResp := &http.Response{ StatusCode: resp.StatusCode, @@ -425,10 +484,13 @@ func (ch *GeminiChannel) standardizeError(resp *http.Response, truncateLimit int } newResp.Header.Set("Content-Type", "application/json; charset=utf-8") newResp.Header.Set("Access-Control-Allow-Origin", "*") + return newResp } -// errToJSON func errToJSON(c *gin.Context, apiErr *CustomErrors.APIError) { + if c.IsAborted() { + return + } c.JSON(apiErr.HTTPStatus, gin.H{"error": apiErr}) } diff --git a/internal/config/config.go b/internal/config/config.go index fb34491..e9c877f 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -13,9 +13,10 @@ type Config struct { Database DatabaseConfig Server ServerConfig Log LogConfig - Redis RedisConfig `mapstructure:"redis"` - SessionSecret string `mapstructure:"session_secret"` - EncryptionKey string `mapstructure:"encryption_key"` + Redis RedisConfig `mapstructure:"redis"` + SessionSecret string `mapstructure:"session_secret"` + EncryptionKey string `mapstructure:"encryption_key"` + Repository RepositoryConfig `mapstructure:"repository"` } // DatabaseConfig 存储数据库连接信息 @@ -43,19 +44,24 @@ type RedisConfig struct { DSN string `mapstructure:"dsn"` } +type RepositoryConfig struct { + BasePoolTTLMinutes int `mapstructure:"base_pool_ttl_minutes"` + BasePoolTTIMinutes int `mapstructure:"base_pool_tti_minutes"` +} + // LoadConfig 从文件和环境变量加载配置 func LoadConfig() (*Config, error) { // 设置配置文件名和路径 viper.SetConfigName("config") viper.SetConfigType("yaml") viper.AddConfigPath(".") - + viper.AddConfigPath("/etc/gemini-balancer/") // for production // 允许从环境变量读取 viper.SetEnvKeyReplacer(strings.NewReplacer(".", "_")) viper.AutomaticEnv() // 设置默认值 - viper.SetDefault("server.port", "8080") + viper.SetDefault("server.port", "9000") viper.SetDefault("log.level", "info") viper.SetDefault("log.format", "text") viper.SetDefault("log.enable_file", false) @@ -67,6 +73,9 @@ func LoadConfig() (*Config, error) { viper.SetDefault("database.conn_max_lifetime", "1h") viper.SetDefault("encryption_key", "") + viper.SetDefault("repository.base_pool_ttl_minutes", 60) + viper.SetDefault("repository.base_pool_tti_minutes", 10) + // 读取配置文件 if err := viper.ReadInConfig(); err != nil { if _, ok := err.(viper.ConfigFileNotFoundError); !ok { diff --git a/internal/domain/proxy/manager.go b/internal/domain/proxy/manager.go index 653f164..0ec0105 100644 --- a/internal/domain/proxy/manager.go +++ b/internal/domain/proxy/manager.go @@ -311,3 +311,8 @@ func (m *manager) checkProxyConnectivity(proxyCfg *models.ProxyConfig, timeout t defer resp.Body.Close() return true } + +type Manager interface { + AssignProxyIfNeeded(apiKey *models.APIKey) (*models.ProxyConfig, error) + // ... 其他需要暴露给外部服务的方法 +} diff --git a/internal/errors/api_error.go b/internal/errors/api_error.go index c4b31bb..fb90f78 100644 --- a/internal/errors/api_error.go +++ b/internal/errors/api_error.go @@ -44,6 +44,7 @@ var ( ErrGroupNotFound = &APIError{HTTPStatus: http.StatusNotFound, Code: "GROUP_NOT_FOUND", Message: "The specified group was not found."} ErrPermissionDenied = &APIError{HTTPStatus: http.StatusForbidden, Code: "PERMISSION_DENIED", Message: "Permission denied for this operation."} ErrConfigurationError = &APIError{HTTPStatus: http.StatusInternalServerError, Code: "CONFIGURATION_ERROR", Message: "A configuration error prevents this request from being processed."} + ErrProxyNotAvailable = &APIError{HTTPStatus: http.StatusNotFound, Code: "PROXY_ERROR", Message: "Required proxy is not available for this request."} ErrStateConflictMasterRevoked = &APIError{HTTPStatus: http.StatusConflict, Code: "STATE_CONFLICT_MASTER_REVOKED", Message: "Cannot perform this operation on a revoked key."} ErrNotFound = &APIError{HTTPStatus: http.StatusNotFound, Code: "NOT_FOUND", Message: "Resource not found"} diff --git a/internal/handlers/proxy_handler.go b/internal/handlers/proxy_handler.go index 853ae8c..1271e51 100644 --- a/internal/handlers/proxy_handler.go +++ b/internal/handlers/proxy_handler.go @@ -29,9 +29,7 @@ import ( "gorm.io/datatypes" ) -type proxyErrorKey int - -const proxyErrKey proxyErrorKey = 0 +type proxyErrorContextKey struct{} type ProxyHandler struct { resourceService *service.ResourceService @@ -81,45 +79,51 @@ func (h *ProxyHandler) HandleProxy(c *gin.Context) { h.handleListModelsRequest(c) return } - requestBody, err := io.ReadAll(c.Request.Body) + + maxBodySize := int64(h.settingsManager.GetSettings().MaxRequestBodySizeMB * 1024 * 1024) + requestBody, err := io.ReadAll(io.LimitReader(c.Request.Body, maxBodySize)) if err != nil { - errToJSON(c, uuid.New().String(), errors.NewAPIError(errors.ErrBadRequest, "Failed to read request body")) + errToJSON(c, uuid.New().String(), errors.NewAPIError(errors.ErrBadRequest, "Request body too large or failed to read")) return } - c.Request.Body = io.NopCloser(bytes.NewReader(requestBody)) - c.Request.ContentLength = int64(len(requestBody)) + modelName := h.channel.ExtractModel(c, requestBody) groupName := c.Param("group_name") isPreciseRouting := groupName != "" + if !isPreciseRouting && modelName == "" { - errToJSON(c, uuid.New().String(), errors.NewAPIError(errors.ErrBadRequest, "Model not specified in the request body or URL")) + errToJSON(c, uuid.New().String(), errors.NewAPIError(errors.ErrBadRequest, "Model not specified in request")) return } + initialResources, err := h.getResourcesForRequest(c, modelName, groupName, isPreciseRouting) if err != nil { if apiErr, ok := err.(*errors.APIError); ok { errToJSON(c, uuid.New().String(), apiErr) } else { - errToJSON(c, uuid.New().String(), errors.NewAPIError(errors.ErrNoKeysAvailable, err.Error())) + errToJSON(c, uuid.New().String(), errors.NewAPIError(errors.ErrInternalServer, "Failed to get initial resources")) } return } + finalOpConfig, err := h.groupManager.BuildOperationalConfig(initialResources.KeyGroup) if err != nil { h.logger.WithError(err).Error("Failed to build operational config.") - errToJSON(c, uuid.New().String(), errors.NewAPIError(errors.ErrInternalServer, "Failed to build operational configuration")) + errToJSON(c, uuid.New().String(), errors.NewAPIError(errors.ErrInternalServer, "Failed to build operational config")) return } + initialResources.RequestConfig = h.buildFinalRequestConfig(h.settingsManager.GetSettings(), initialResources.RequestConfig) + isOpenAICompatible := h.channel.IsOpenAICompatibleRequest(c) if isOpenAICompatible { h.serveTransparentProxy(c, requestBody, initialResources, finalOpConfig, modelName, groupName, isPreciseRouting) return } + isStream := h.channel.IsStreamRequest(c, requestBody) - systemSettings := h.settingsManager.GetSettings() useSmartGateway := finalOpConfig.EnableSmartGateway != nil && *finalOpConfig.EnableSmartGateway - if useSmartGateway && isStream && systemSettings.EnableStreamingRetry { + if useSmartGateway && isStream && h.settingsManager.GetSettings().EnableStreamingRetry { h.serveSmartStream(c, requestBody, initialResources, isPreciseRouting) } else { h.serveTransparentProxy(c, requestBody, initialResources, finalOpConfig, modelName, groupName, isPreciseRouting) @@ -129,219 +133,307 @@ func (h *ProxyHandler) HandleProxy(c *gin.Context) { func (h *ProxyHandler) serveTransparentProxy(c *gin.Context, requestBody []byte, initialResources *service.RequestResources, finalOpConfig *models.KeyGroupSettings, modelName, groupName string, isPreciseRouting bool) { startTime := time.Now() correlationID := uuid.New().String() + var finalRecorder *httptest.ResponseRecorder var lastUsedResources *service.RequestResources var finalProxyErr *errors.APIError var isSuccess bool - var finalPromptTokens, finalCompletionTokens int - var actualRetries int = 0 - defer func() { - if lastUsedResources == nil { - h.logger.WithField("id", correlationID).Warn("No resources were used, skipping final log event.") - return - } - finalEvent := h.createLogEvent(c, startTime, correlationID, modelName, lastUsedResources, models.LogTypeFinal, isPreciseRouting) + var finalPromptTokens, finalCompletionTokens, actualRetries int - finalEvent.RequestLog.LatencyMs = int(time.Since(startTime).Milliseconds()) - finalEvent.RequestLog.IsSuccess = isSuccess - finalEvent.RequestLog.Retries = actualRetries - if isSuccess { - finalEvent.RequestLog.PromptTokens = finalPromptTokens - finalEvent.RequestLog.CompletionTokens = finalCompletionTokens - } + defer h.publishFinalLogEvent(c, startTime, correlationID, modelName, lastUsedResources, + finalRecorder, finalProxyErr, isSuccess, finalPromptTokens, finalCompletionTokens, + actualRetries, isPreciseRouting) - if finalRecorder != nil { - finalEvent.RequestLog.StatusCode = finalRecorder.Code - } - if !isSuccess { - if finalProxyErr != nil { - finalEvent.Error = finalProxyErr - finalEvent.RequestLog.ErrorCode = finalProxyErr.Code - finalEvent.RequestLog.ErrorMessage = finalProxyErr.Message - } else if finalRecorder != nil { - apiErr := errors.NewAPIErrorWithUpstream(finalRecorder.Code, fmt.Sprintf("UPSTREAM_%d", finalRecorder.Code), "Request failed after all retries.") - finalEvent.Error = apiErr - finalEvent.RequestLog.ErrorCode = apiErr.Code - finalEvent.RequestLog.ErrorMessage = apiErr.Message - } - } - eventData, err := json.Marshal(finalEvent) - if err != nil { - h.logger.WithField("id", correlationID).WithError(err).Error("Failed to marshal final log event.") - return - } - if err := h.store.Publish(context.Background(), models.TopicRequestFinished, eventData); err != nil { - h.logger.WithField("id", correlationID).WithError(err).Error("Failed to publish final log event.") - } - }() - var maxRetries int - if isPreciseRouting { - if finalOpConfig.MaxRetries != nil { - maxRetries = *finalOpConfig.MaxRetries - } else { - maxRetries = h.settingsManager.GetSettings().MaxRetries - } - } else { - maxRetries = h.settingsManager.GetSettings().MaxRetries - } + maxRetries := h.getMaxRetries(isPreciseRouting, finalOpConfig) totalAttempts := maxRetries + 1 + for attempt := 1; attempt <= totalAttempts; attempt++ { if c.Request.Context().Err() != nil { h.logger.WithField("id", correlationID).Info("Client disconnected, aborting retry loop.") if finalProxyErr == nil { - finalProxyErr = errors.NewAPIError(errors.ErrBadRequest, "Client connection closed") + finalProxyErr = errors.NewAPIError(errors.ErrBadRequest, "Client disconnected") } break } - var currentResources *service.RequestResources - var err error - if attempt == 1 { - currentResources = initialResources - } else { - actualRetries = attempt - 1 - h.logger.WithField("id", correlationID).Infof("Retrying... getting new resources for attempt %d.", attempt) - currentResources, err = h.getResourcesForRequest(c, modelName, groupName, isPreciseRouting) - if err != nil { - h.logger.WithField("id", correlationID).Errorf("Failed to get new resources for retry, aborting: %v", err) - finalProxyErr = errors.NewAPIError(errors.ErrNoKeysAvailable, "Failed to get new resources for retry") - break + + resources, err := h.getResourcesForAttempt(c, attempt, initialResources, modelName, groupName, isPreciseRouting, correlationID) + if err != nil { + if apiErr, ok := err.(*errors.APIError); ok { + finalProxyErr = apiErr + } else { + finalProxyErr = errors.NewAPIError(errors.ErrInternalServer, "Failed to get resources for retry") } + break + } + lastUsedResources = resources + if attempt > 1 { + actualRetries = attempt - 1 } - finalRequestConfig := h.buildFinalRequestConfig(h.settingsManager.GetSettings(), currentResources.RequestConfig) - currentResources.RequestConfig = finalRequestConfig - lastUsedResources = currentResources - h.logger.WithField("id", correlationID).Infof("Attempt %d/%d with KeyID %d...", attempt, totalAttempts, currentResources.APIKey.ID) - var attemptErr *errors.APIError - var attemptIsSuccess bool - recorder := httptest.NewRecorder() - attemptStartTime := time.Now() - connectTimeout := time.Duration(h.settingsManager.GetSettings().ConnectTimeoutSeconds) * time.Second - ctx, cancel := context.WithTimeout(c.Request.Context(), connectTimeout) - defer cancel() - attemptReq := c.Request.Clone(ctx) - attemptReq.Body = io.NopCloser(bytes.NewReader(requestBody)) - if currentResources.UpstreamEndpoint == nil || currentResources.UpstreamEndpoint.URL == "" { - h.logger.WithField("id", correlationID).Errorf("Attempt %d failed: no upstream URL in resources.", attempt) - isSuccess = false - finalProxyErr = errors.NewAPIError(errors.ErrInternalServer, "No upstream URL configured for the selected resource") - continue - } - h.transparentProxy.Director = func(req *http.Request) { - targetURL, _ := url.Parse(currentResources.UpstreamEndpoint.URL) - req.URL.Scheme = targetURL.Scheme - req.URL.Host = targetURL.Host - req.Host = targetURL.Host - var pureClientPath string - if isPreciseRouting { - proxyPrefix := "/proxy/" + groupName - pureClientPath = strings.TrimPrefix(req.URL.Path, proxyPrefix) - } else { - pureClientPath = req.URL.Path - } - finalPath := h.channel.RewritePath(targetURL.Path, pureClientPath) - req.URL.Path = finalPath - h.logger.WithFields(logrus.Fields{ - "correlation_id": correlationID, - "attempt": attempt, - "key_id": currentResources.APIKey.ID, - "base_upstream_url": currentResources.UpstreamEndpoint.URL, - "final_request_url": req.URL.String(), - }).Infof("Director constructed final upstream request URL.") - req.Header.Del("Authorization") - h.channel.ModifyRequest(req, currentResources.APIKey) - req.Header.Set("X-Correlation-ID", correlationID) - *req = *req.WithContext(context.WithValue(req.Context(), proxyErrKey, &attemptErr)) - } - transport := h.transparentProxy.Transport.(*http.Transport) - if currentResources.ProxyConfig != nil { - proxyURLStr := fmt.Sprintf("%s://%s", currentResources.ProxyConfig.Protocol, currentResources.ProxyConfig.Address) - proxyURL, err := url.Parse(proxyURLStr) - if err == nil { - transport.Proxy = http.ProxyURL(proxyURL) - } - } else { - transport.Proxy = http.ProxyFromEnvironment - } - h.transparentProxy.ModifyResponse = func(resp *http.Response) error { - defer resp.Body.Close() - var reader io.ReadCloser - var err error - isGzipped := resp.Header.Get("Content-Encoding") == "gzip" - if isGzipped { - reader, err = gzip.NewReader(resp.Body) - if err != nil { - h.logger.WithError(err).Error("Failed to create gzip reader") - reader = resp.Body - } else { - resp.Header.Del("Content-Encoding") - } - defer reader.Close() - } else { - reader = resp.Body - } - bodyBytes, err := io.ReadAll(reader) - if err != nil { - attemptErr = errors.NewAPIError(errors.ErrBadGateway, "Failed to read upstream response: "+err.Error()) - resp.Body = io.NopCloser(bytes.NewReader([]byte(attemptErr.Message))) - return nil - } - if resp.StatusCode < 400 { - attemptIsSuccess = true - finalPromptTokens, finalCompletionTokens = extractUsage(bodyBytes) - } else { - parsedMsg := errors.ParseUpstreamError(bodyBytes) - attemptErr = errors.NewAPIErrorWithUpstream(resp.StatusCode, fmt.Sprintf("UPSTREAM_%d", resp.StatusCode), parsedMsg) - } - resp.Body = io.NopCloser(bytes.NewReader(bodyBytes)) - return nil - } - h.transparentProxy.ServeHTTP(recorder, attemptReq) - finalRecorder = recorder - finalProxyErr = attemptErr - isSuccess = attemptIsSuccess - h.resourceService.ReportRequestResult(currentResources, isSuccess, finalProxyErr) + h.logger.WithField("id", correlationID).Infof("Attempt %d/%d with KeyID %d", attempt, totalAttempts, resources.APIKey.ID) + + recorder, attemptErr, attemptSuccess := h.executeProxyAttempt( + c, correlationID, requestBody, resources, isPreciseRouting, groupName, + &finalPromptTokens, &finalCompletionTokens, + ) + + finalRecorder, finalProxyErr, isSuccess = recorder, attemptErr, attemptSuccess + h.resourceService.ReportRequestResult(resources, isSuccess, finalProxyErr) + if isSuccess { break } - isUnretryableError := false - if finalProxyErr != nil { - if errors.IsUnretryableRequestError(finalProxyErr.Message) { - isUnretryableError = true - h.logger.WithField("id", correlationID).Warnf("Attempt %d failed with unretryable request error. Aborting retries. Message: %s", attempt, finalProxyErr.Message) - } - } - if attempt >= totalAttempts || isUnretryableError { + if h.shouldStopRetrying(attempt, totalAttempts, finalProxyErr, correlationID) { break } - retryEvent := h.createLogEvent(c, startTime, correlationID, modelName, currentResources, models.LogTypeRetry, isPreciseRouting) - retryEvent.LatencyMs = int(time.Since(attemptStartTime).Milliseconds()) - retryEvent.IsSuccess = false - retryEvent.StatusCode = recorder.Code - retryEvent.Retries = actualRetries - if attemptErr != nil { - retryEvent.Error = attemptErr - retryEvent.ErrorCode = attemptErr.Code - retryEvent.ErrorMessage = attemptErr.Message - } - eventData, _ := json.Marshal(retryEvent) - _ = h.store.Publish(context.Background(), models.TopicRequestFinished, eventData) + h.publishRetryLogEvent(c, startTime, correlationID, modelName, resources, recorder, attemptErr, actualRetries, isPreciseRouting) } - if finalRecorder != nil { - bodyBytes := finalRecorder.Body.Bytes() - c.Writer.Header().Set("Content-Length", fmt.Sprintf("%d", len(bodyBytes))) - for k, v := range finalRecorder.Header() { - if strings.ToLower(k) != "content-length" { - c.Writer.Header()[k] = v + + h.writeFinalResponse(c, correlationID, finalRecorder, finalProxyErr) +} + +func (h *ProxyHandler) executeProxyAttempt(c *gin.Context, corrID string, body []byte, res *service.RequestResources, isPrecise bool, groupName string, pTokens, cTokens *int) (*httptest.ResponseRecorder, *errors.APIError, bool) { + recorder := httptest.NewRecorder() + var attemptErr *errors.APIError + var isSuccess bool + + connectTimeout := time.Duration(h.settingsManager.GetSettings().ConnectTimeoutSeconds) * time.Second + ctx, cancel := context.WithTimeout(c.Request.Context(), connectTimeout) + defer cancel() + + attemptReq := c.Request.Clone(ctx) + attemptReq.Body = io.NopCloser(bytes.NewReader(body)) + attemptReq.ContentLength = int64(len(body)) + + h.configureProxy(corrID, res, isPrecise, groupName, &attemptErr, &isSuccess, pTokens, cTokens) + *attemptReq = *attemptReq.WithContext(context.WithValue(attemptReq.Context(), proxyErrorContextKey{}, &attemptErr)) + + h.transparentProxy.ServeHTTP(recorder, attemptReq) + + return recorder, attemptErr, isSuccess +} + +func (h *ProxyHandler) configureProxy(corrID string, res *service.RequestResources, isPrecise bool, groupName string, attemptErr **errors.APIError, isSuccess *bool, pTokens, cTokens *int) { + h.transparentProxy.Director = func(r *http.Request) { + targetURL, _ := url.Parse(res.UpstreamEndpoint.URL) + r.URL.Scheme, r.URL.Host, r.Host = targetURL.Scheme, targetURL.Host, targetURL.Host + + var pureClientPath string + if isPrecise { + pureClientPath = strings.TrimPrefix(r.URL.Path, "/proxy/"+groupName) + } else { + pureClientPath = r.URL.Path + } + r.URL.Path = h.channel.RewritePath(targetURL.Path, pureClientPath) + + r.Header.Del("Authorization") + h.channel.ModifyRequest(r, res.APIKey) + r.Header.Set("X-Correlation-ID", corrID) + } + + transport := h.transparentProxy.Transport.(*http.Transport) + if res.ProxyConfig != nil { + proxyURLStr := fmt.Sprintf("%s://%s", res.ProxyConfig.Protocol, res.ProxyConfig.Address) + if proxyURL, err := url.Parse(proxyURLStr); err == nil { + transport.Proxy = http.ProxyURL(proxyURL) + } else { + transport.Proxy = http.ProxyFromEnvironment + } + } else { + transport.Proxy = http.ProxyFromEnvironment + } + + h.transparentProxy.ModifyResponse = h.createModifyResponseFunc(attemptErr, isSuccess, pTokens, cTokens) +} + +func (h *ProxyHandler) createModifyResponseFunc(attemptErr **errors.APIError, isSuccess *bool, pTokens, cTokens *int) func(*http.Response) error { + return func(resp *http.Response) error { + var reader io.ReadCloser = resp.Body + if resp.Header.Get("Content-Encoding") == "gzip" { + gzReader, err := gzip.NewReader(resp.Body) + if err != nil { + h.logger.WithError(err).Error("Failed to create gzip reader") + } else { + reader = gzReader + resp.Header.Del("Content-Encoding") } } - c.Writer.WriteHeader(finalRecorder.Code) - c.Writer.Write(finalRecorder.Body.Bytes()) - } else { - errToJSON(c, correlationID, finalProxyErr) + defer reader.Close() + + bodyBytes, err := io.ReadAll(reader) + if err != nil { + *attemptErr = errors.NewAPIError(errors.ErrBadGateway, "Failed to read upstream response") + resp.Body = io.NopCloser(bytes.NewReader([]byte{})) + return nil + } + + if resp.StatusCode < 400 { + *isSuccess = true + *pTokens, *cTokens = extractUsage(bodyBytes) + } else { + parsedMsg := errors.ParseUpstreamError(bodyBytes) + *attemptErr = errors.NewAPIErrorWithUpstream(resp.StatusCode, fmt.Sprintf("UPSTREAM_%d", resp.StatusCode), parsedMsg) + } + resp.Body = io.NopCloser(bytes.NewReader(bodyBytes)) + return nil } } +func (h *ProxyHandler) transparentProxyErrorHandler(rw http.ResponseWriter, r *http.Request, err error) { + corrID := r.Header.Get("X-Correlation-ID") + log := h.logger.WithField("id", corrID) + log.Errorf("Transparent proxy encountered an error: %v", err) + + errPtr, ok := r.Context().Value(proxyErrorContextKey{}).(**errors.APIError) + if !ok || errPtr == nil { + log.Error("FATAL: proxyErrorContextKey not found in context for error handler.") + defaultErr := errors.NewAPIError(errors.ErrBadGateway, "An unexpected proxy error occurred") + writeErrorToResponse(rw, defaultErr) + return + } + + if *errPtr == nil { + if errors.IsClientNetworkError(err) { + *errPtr = errors.NewAPIError(errors.ErrBadRequest, "Client connection closed") + } else { + *errPtr = errors.NewAPIError(errors.ErrBadGateway, err.Error()) + } + } + writeErrorToResponse(rw, *errPtr) +} + +func (h *ProxyHandler) getResourcesForAttempt(c *gin.Context, attempt int, initialResources *service.RequestResources, modelName, groupName string, isPreciseRouting bool, correlationID string) (*service.RequestResources, error) { + if attempt == 1 { + return initialResources, nil + } + h.logger.WithField("id", correlationID).Infof("Retrying... getting new resources for attempt %d.", attempt) + resources, err := h.getResourcesForRequest(c, modelName, groupName, isPreciseRouting) + if err != nil { + return nil, err + } + finalRequestConfig := h.buildFinalRequestConfig(h.settingsManager.GetSettings(), resources.RequestConfig) + resources.RequestConfig = finalRequestConfig + return resources, nil +} + +func (h *ProxyHandler) shouldStopRetrying(attempt, totalAttempts int, err *errors.APIError, correlationID string) bool { + if attempt >= totalAttempts { + return true + } + if err != nil && errors.IsUnretryableRequestError(err.Message) { + h.logger.WithField("id", correlationID).Warnf("Attempt failed with unretryable request error. Aborting retries. Message: %s", err.Message) + return true + } + return false +} + +func (h *ProxyHandler) writeFinalResponse(c *gin.Context, corrID string, rec *httptest.ResponseRecorder, apiErr *errors.APIError) { + if rec != nil { + for k, v := range rec.Header() { + c.Writer.Header()[k] = v + } + c.Writer.WriteHeader(rec.Code) + c.Writer.Write(rec.Body.Bytes()) + } else if apiErr != nil { + errToJSON(c, corrID, apiErr) + } else { + errToJSON(c, corrID, errors.NewAPIError(errors.ErrInternalServer, "An unknown error occurred")) + } +} + +func (h *ProxyHandler) publishFinalLogEvent(c *gin.Context, startTime time.Time, corrID, modelName string, res *service.RequestResources, rec *httptest.ResponseRecorder, finalErr *errors.APIError, isSuccess bool, pTokens, cTokens, retries int, isPrecise bool) { + if res == nil { + h.logger.WithField("id", corrID).Warn("No resources were used, skipping final log event.") + return + } + event := h.createLogEvent(c, startTime, corrID, modelName, res, models.LogTypeFinal, isPrecise) + event.RequestLog.LatencyMs = int(time.Since(startTime).Milliseconds()) + event.RequestLog.IsSuccess = isSuccess + event.RequestLog.Retries = retries + if isSuccess { + event.RequestLog.PromptTokens, event.RequestLog.CompletionTokens = pTokens, cTokens + } + if rec != nil { + event.RequestLog.StatusCode = rec.Code + } + if !isSuccess { + errToLog := finalErr + if errToLog == nil && rec != nil { + errToLog = errors.NewAPIErrorWithUpstream(rec.Code, fmt.Sprintf("UPSTREAM_%d", rec.Code), "Request failed after all retries.") + } + if errToLog != nil { + event.Error = errToLog + event.RequestLog.ErrorCode, event.RequestLog.ErrorMessage = errToLog.Code, errToLog.Message + } + } + eventData, err := json.Marshal(event) + if err != nil { + h.logger.WithField("id", corrID).WithError(err).Error("Failed to marshal log event") + return + } + if err := h.store.Publish(context.Background(), models.TopicRequestFinished, eventData); err != nil { + h.logger.WithField("id", corrID).WithError(err).Error("Failed to publish log event") + } +} + +func (h *ProxyHandler) publishRetryLogEvent(c *gin.Context, startTime time.Time, corrID, modelName string, res *service.RequestResources, rec *httptest.ResponseRecorder, attemptErr *errors.APIError, retries int, isPrecise bool) { + retryEvent := h.createLogEvent(c, startTime, corrID, modelName, res, models.LogTypeRetry, isPrecise) + retryEvent.RequestLog.LatencyMs = int(time.Since(startTime).Milliseconds()) + retryEvent.RequestLog.IsSuccess = false + retryEvent.RequestLog.StatusCode = rec.Code + retryEvent.RequestLog.Retries = retries + if attemptErr != nil { + retryEvent.Error = attemptErr + retryEvent.RequestLog.ErrorCode, retryEvent.RequestLog.ErrorMessage = attemptErr.Code, attemptErr.Message + } + eventData, err := json.Marshal(retryEvent) + if err != nil { + h.logger.WithField("id", corrID).WithError(err).Error("Failed to marshal retry log event") + return + } + if err := h.store.Publish(context.Background(), models.TopicRequestFinished, eventData); err != nil { + h.logger.WithField("id", corrID).WithError(err).Error("Failed to publish retry log event") + } +} + +func (h *ProxyHandler) buildFinalRequestConfig(globalSettings *models.SystemSettings, groupConfig *models.RequestConfig) *models.RequestConfig { + finalConfig := &models.RequestConfig{ + CustomHeaders: make(datatypes.JSONMap), + EnableStreamOptimizer: globalSettings.EnableStreamOptimizer, + StreamMinDelay: globalSettings.StreamMinDelay, + StreamMaxDelay: globalSettings.StreamMaxDelay, + StreamShortTextThresh: globalSettings.StreamShortTextThresh, + StreamLongTextThresh: globalSettings.StreamLongTextThresh, + StreamChunkSize: globalSettings.StreamChunkSize, + EnableFakeStream: globalSettings.EnableFakeStream, + FakeStreamInterval: globalSettings.FakeStreamInterval, + } + for k, v := range globalSettings.CustomHeaders { + finalConfig.CustomHeaders[k] = v + } + if groupConfig == nil { + return finalConfig + } + groupConfigJSON, err := json.Marshal(groupConfig) + if err != nil { + h.logger.WithError(err).Error("Failed to marshal group request config for merging.") + return finalConfig + } + if err := json.Unmarshal(groupConfigJSON, finalConfig); err != nil { + h.logger.WithError(err).Error("Failed to unmarshal group request config for merging.") + } + return finalConfig +} + +func writeErrorToResponse(rw http.ResponseWriter, apiErr *errors.APIError) { + if writer, ok := rw.(interface{ Written() bool }); ok && writer.Written() { + return + } + rw.Header().Set("Content-Type", "application/json; charset=utf-8") + rw.WriteHeader(apiErr.HTTPStatus) + json.NewEncoder(rw).Encode(gin.H{"error": apiErr}) +} + func (h *ProxyHandler) serveSmartStream(c *gin.Context, requestBody []byte, resources *service.RequestResources, isPreciseRouting bool) { startTime := time.Now() correlationID := uuid.New().String() @@ -349,7 +441,7 @@ func (h *ProxyHandler) serveSmartStream(c *gin.Context, requestBody []byte, reso log.Info("Smart Gateway activated for streaming request.") var originalRequest models.GeminiRequest if err := json.Unmarshal(requestBody, &originalRequest); err != nil { - errToJSON(c, correlationID, errors.NewAPIError(errors.ErrInvalidJSON, "Smart Gateway failed: Request body is not a valid Gemini native format. Error: "+err.Error())) + errToJSON(c, correlationID, errors.NewAPIError(errors.ErrInvalidJSON, "Invalid request format for Smart Gateway")) return } systemSettings := h.settingsManager.GetSettings() @@ -360,8 +452,14 @@ func (h *ProxyHandler) serveSmartStream(c *gin.Context, requestBody []byte, reso if c.Writer.Status() > 0 { requestFinishedEvent.StatusCode = c.Writer.Status() } - eventData, _ := json.Marshal(requestFinishedEvent) - _ = h.store.Publish(context.Background(), models.TopicRequestFinished, eventData) + eventData, err := json.Marshal(requestFinishedEvent) + if err != nil { + h.logger.WithField("id", correlationID).WithError(err).Error("Failed to marshal final log event for smart stream") + return + } + if err := h.store.Publish(context.Background(), models.TopicRequestFinished, eventData); err != nil { + h.logger.WithField("id", correlationID).WithError(err).Error("Failed to publish final log event for smart stream") + } }() params := channel.SmartRequestParams{ CorrelationID: correlationID, @@ -378,30 +476,6 @@ func (h *ProxyHandler) serveSmartStream(c *gin.Context, requestBody []byte, reso h.channel.ProcessSmartStreamRequest(c, params) } -func (h *ProxyHandler) transparentProxyErrorHandler(rw http.ResponseWriter, r *http.Request, err error) { - correlationID := r.Header.Get("X-Correlation-ID") - h.logger.WithField("id", correlationID).Errorf("Transparent proxy error: %v", err) - proxyErrPtr, exists := r.Context().Value(proxyErrKey).(**errors.APIError) - if !exists || proxyErrPtr == nil { - h.logger.WithField("id", correlationID).Error("FATAL: proxyErrorKey not found in context for error handler.") - return - } - if errors.IsClientNetworkError(err) { - *proxyErrPtr = errors.NewAPIError(errors.ErrBadRequest, "Client connection closed") - } else { - *proxyErrPtr = errors.NewAPIError(errors.ErrBadGateway, err.Error()) - } - if _, ok := rw.(*httptest.ResponseRecorder); ok { - return - } - if writer, ok := rw.(interface{ Written() bool }); ok { - if writer.Written() { - return - } - } - rw.WriteHeader((*proxyErrPtr).HTTPStatus) -} - func (h *ProxyHandler) createLogEvent(c *gin.Context, startTime time.Time, corrID, modelName string, res *service.RequestResources, logType models.LogType, isPreciseRouting bool) *models.RequestFinishedEvent { event := &models.RequestFinishedEvent{ RequestLog: models.RequestLog{ @@ -456,12 +530,14 @@ func (h *ProxyHandler) getResourcesForRequest(c *gin.Context, modelName string, } if isPreciseRouting { return h.resourceService.GetResourceFromGroup(c.Request.Context(), authToken, groupName) - } else { - return h.resourceService.GetResourceFromBasePool(c.Request.Context(), authToken, modelName) } + return h.resourceService.GetResourceFromBasePool(c.Request.Context(), authToken, modelName) } func errToJSON(c *gin.Context, corrID string, apiErr *errors.APIError) { + if c.IsAborted() { + return + } c.JSON(apiErr.HTTPStatus, gin.H{ "error": apiErr, "correlation_id": corrID, @@ -470,8 +546,8 @@ func errToJSON(c *gin.Context, corrID string, apiErr *errors.APIError) { type bufferPool struct{} -func (b *bufferPool) Get() []byte { return make([]byte, 32*1024) } -func (b *bufferPool) Put(bytes []byte) {} +func (b *bufferPool) Get() []byte { return make([]byte, 32*1024) } +func (b *bufferPool) Put(_ []byte) {} func extractUsage(body []byte) (promptTokens int, completionTokens int) { var data struct { @@ -486,34 +562,11 @@ func extractUsage(body []byte) (promptTokens int, completionTokens int) { return 0, 0 } -func (h *ProxyHandler) buildFinalRequestConfig(globalSettings *models.SystemSettings, groupConfig *models.RequestConfig) *models.RequestConfig { - customHeadersJSON, _ := json.Marshal(globalSettings.CustomHeaders) - var customHeadersMap datatypes.JSONMap - _ = json.Unmarshal(customHeadersJSON, &customHeadersMap) - finalConfig := &models.RequestConfig{ - CustomHeaders: customHeadersMap, - EnableStreamOptimizer: globalSettings.EnableStreamOptimizer, - StreamMinDelay: globalSettings.StreamMinDelay, - StreamMaxDelay: globalSettings.StreamMaxDelay, - StreamShortTextThresh: globalSettings.StreamShortTextThresh, - StreamLongTextThresh: globalSettings.StreamLongTextThresh, - StreamChunkSize: globalSettings.StreamChunkSize, - EnableFakeStream: globalSettings.EnableFakeStream, - FakeStreamInterval: globalSettings.FakeStreamInterval, +func (h *ProxyHandler) getMaxRetries(isPreciseRouting bool, finalOpConfig *models.KeyGroupSettings) int { + if isPreciseRouting && finalOpConfig.MaxRetries != nil { + return *finalOpConfig.MaxRetries } - if groupConfig == nil { - return finalConfig - } - groupConfigJSON, err := json.Marshal(groupConfig) - if err != nil { - h.logger.WithError(err).Error("Failed to marshal group request config for merging.") - return finalConfig - } - if err := json.Unmarshal(groupConfigJSON, finalConfig); err != nil { - h.logger.WithError(err).Error("Failed to unmarshal group request config for merging.") - return finalConfig - } - return finalConfig + return h.settingsManager.GetSettings().MaxRetries } func (h *ProxyHandler) handleListModelsRequest(c *gin.Context) { diff --git a/internal/models/dto.go b/internal/models/dto.go index 464ee47..f534689 100644 --- a/internal/models/dto.go +++ b/internal/models/dto.go @@ -77,3 +77,9 @@ type APIKeyDetails struct { CooldownUntil *time.Time `json:"cooldown_until"` EncryptedKey string } + +// SettingsManager 定义了系统设置管理器的抽象接口。 + +type SettingsManager interface { + GetSettings() *SystemSettings +} diff --git a/internal/models/runtime.go b/internal/models/runtime.go index 7a16668..f09beeb 100644 --- a/internal/models/runtime.go +++ b/internal/models/runtime.go @@ -11,6 +11,7 @@ type SystemSettings struct { BlacklistThreshold int `json:"blacklist_threshold" default:"3" name:"拉黑阈值" category:"密钥设置" desc:"一个Key连续失败多少次后进入冷却状态。"` KeyCooldownMinutes int `json:"key_cooldown_minutes" default:"10" name:"密钥冷却时长(分钟)" category:"密钥设置" desc:"一个Key进入冷却状态后需要等待的时间,单位为分钟。"` LogFlushIntervalSeconds int `json:"log_flush_interval_seconds" default:"10" name:"日志刷新间隔(秒)" category:"日志设置" desc:"异步日志写入数据库的间隔时间(秒)。"` + MaxRequestBodySizeMB int `json:"max_request_body_size_mb" default:"10" name:"最大请求体大小 (MB)" category:"请求设置" desc:"允许代理接收的最大请求体大小,单位为MB。超过此大小的请求将被拒绝。"` PollingStrategy PollingStrategy `json:"polling_strategy" default:"random" name:"全局轮询策略" category:"调度设置" desc:"智能聚合模式下,从所有可用密钥中选择一个的默认策略。可选值: sequential(顺序), random(随机), weighted(加权)。"` @@ -41,6 +42,10 @@ type SystemSettings struct { MaxLoginAttempts int `json:"max_login_attempts" default:"5" name:"最大登录失败次数" category:"安全设置" desc:"在一个IP被封禁前,允许的连续登录失败次数。"` IPBanDurationMinutes int `json:"ip_ban_duration_minutes" default:"15" name:"IP封禁时长(分钟)" category:"安全设置" desc:"IP被封禁的时长,单位为分钟。"` + // BasePool 相关配置 + // BasePoolTTLMinutes int `json:"base_pool_ttl_minutes" default:"30" name:"基础资源池最大生存时间(分钟)" category:"基础资源池" desc:"一个动态构建的基础资源池(BasePool)在Redis中的最大生存时间。到期后即使仍在活跃使用也会被强制重建。"` + // BasePoolTTIMinutes int `json:"base_pool_tti_minutes" default:"10" name:"基础资源池空闲超时(分钟)" category:"基础资源池" desc:"一个基础资源池(BasePool)在连续无请求后,自动销毁的空闲等待时间。"` + //智能网关 LogTruncationLimit int `json:"log_truncation_limit" default:"8000" name:"日志截断长度" category:"日志设置" desc:"在日志中记录上游响应或错误时,保留的最大字符数。0表示不截断。"` EnableSmartGateway bool `json:"enable_smart_gateway" default:"false" name:"启用智能网关" category:"代理设置" desc:"开启后,系统将对流式请求进行智能中断续传、错误标准化等优化。关闭后,系统将作为一个纯净、无干扰的透明代理。"` diff --git a/internal/repository/key_cache.go b/internal/repository/key_cache.go index 132e789..f1f7a8e 100644 --- a/internal/repository/key_cache.go +++ b/internal/repository/key_cache.go @@ -1,4 +1,4 @@ -// Filename: internal/repository/key_cache.go +// Filename: internal/repository/key_cache.go (最终定稿) package repository import ( @@ -9,6 +9,7 @@ import ( "strconv" ) +// --- Redis Key 常量定义 --- const ( KeyGroup = "group:%d:keys:active" KeyDetails = "key:%d:details" @@ -23,13 +24,16 @@ const ( BasePoolRandomCooldown = "basepool:%s:keys:random:cooldown" ) +// LoadAllKeysToStore 从数据库加载所有密钥和映射关系,并完整重建Redis缓存。 func (r *gormKeyRepository) LoadAllKeysToStore(ctx context.Context) error { - r.logger.Info("Starting to load all keys and associations into cache, including polling structures...") + r.logger.Info("Starting full cache rebuild for all keys and polling structures.") + var allMappings []*models.GroupAPIKeyMapping if err := r.db.Preload("APIKey").Find(&allMappings).Error; err != nil { - return fmt.Errorf("failed to load all mappings with APIKeys from DB: %w", err) + return fmt.Errorf("failed to load mappings with preloaded APIKeys: %w", err) } + // 1. 批量解密所有涉及的密钥 keyMap := make(map[uint]*models.APIKey) for _, m := range allMappings { if m.APIKey != nil { @@ -41,16 +45,16 @@ func (r *gormKeyRepository) LoadAllKeysToStore(ctx context.Context) error { keysToDecrypt = append(keysToDecrypt, *k) } if err := r.decryptKeys(keysToDecrypt); err != nil { - r.logger.WithError(err).Error("Critical error during cache preload: batch decryption failed.") + r.logger.WithError(err).Error("Batch decryption failed during cache rebuild.") + // 即使解密失败,也继续尝试加载未加密或已解密的部分 } decryptedKeyMap := make(map[uint]models.APIKey) for _, k := range keysToDecrypt { decryptedKeyMap[k.ID] = k } - activeKeysByGroup := make(map[uint][]*models.GroupAPIKeyMapping) - pipe := r.store.Pipeline(context.Background()) - detailsToSet := make(map[string][]byte) + // 2. 清理所有分组的旧轮询结构 + pipe := r.store.Pipeline(ctx) var allGroups []*models.KeyGroup if err := r.db.Find(&allGroups).Error; err == nil { for _, group := range allGroups { @@ -63,26 +67,41 @@ func (r *gormKeyRepository) LoadAllKeysToStore(ctx context.Context) error { ) } } else { - r.logger.WithError(err).Error("Failed to get all groups for cache cleanup") + r.logger.WithError(err).Error("Failed to get groups for cache cleanup; proceeding with rebuild.") } + // 3. 准备批量更新数据 + activeKeysByGroup := make(map[uint][]*models.GroupAPIKeyMapping) + detailsToSet := make(map[string]any) + for _, mapping := range allMappings { if mapping.APIKey == nil { continue } decryptedKey, ok := decryptedKeyMap[mapping.APIKeyID] if !ok { - continue + continue // 跳过解密失败的密钥 } + + // 准备 KeyDetails 和 KeyMapping 的 MSet 数据 keyJSON, _ := json.Marshal(decryptedKey) detailsToSet[fmt.Sprintf(KeyDetails, decryptedKey.ID)] = keyJSON mappingJSON, _ := json.Marshal(mapping) detailsToSet[fmt.Sprintf(KeyMapping, mapping.KeyGroupID, decryptedKey.ID)] = mappingJSON + if mapping.Status == models.StatusActive { activeKeysByGroup[mapping.KeyGroupID] = append(activeKeysByGroup[mapping.KeyGroupID], mapping) } } + // 4. 使用 MSet 批量写入详情和映射缓存 + if len(detailsToSet) > 0 { + if err := r.store.MSet(ctx, detailsToSet); err != nil { + r.logger.WithError(err).Error("Failed to MSet key details and mappings during cache rebuild.") + } + } + + // 5. 在Pipeline中重建所有分组的轮询结构 for groupID, activeMappings := range activeKeysByGroup { if len(activeMappings) == 0 { continue @@ -101,22 +120,19 @@ func (r *gormKeyRepository) LoadAllKeysToStore(ctx context.Context) error { pipe.SAdd(fmt.Sprintf(KeyGroup, groupID), activeKeyIDs...) pipe.LPush(fmt.Sprintf(KeyGroupSequential, groupID), activeKeyIDs...) pipe.SAdd(fmt.Sprintf(KeyGroupRandomMain, groupID), activeKeyIDs...) - go r.store.ZAdd(context.Background(), fmt.Sprintf(KeyGroupLRU, groupID), lruMembers) + pipe.ZAdd(fmt.Sprintf(KeyGroupLRU, groupID), lruMembers) } + // 6. 执行Pipeline if err := pipe.Exec(); err != nil { - return fmt.Errorf("failed to execute pipeline for cache rebuild: %w", err) - } - for key, value := range detailsToSet { - if err := r.store.Set(context.Background(), key, value, 0); err != nil { - r.logger.WithError(err).Warnf("Failed to set key detail in cache: %s", key) - } + return fmt.Errorf("pipeline execution for polling structures failed: %w", err) } - r.logger.Info("Cache rebuild complete, including all polling structures.") + r.logger.Info("Full cache rebuild completed successfully.") return nil } +// updateStoreCacheForKey 更新单个APIKey的详情缓存 (K-V)。 func (r *gormKeyRepository) updateStoreCacheForKey(key *models.APIKey) error { if err := r.decryptKey(key); err != nil { return fmt.Errorf("failed to decrypt key %d for cache update: %w", key.ID, err) @@ -128,78 +144,101 @@ func (r *gormKeyRepository) updateStoreCacheForKey(key *models.APIKey) error { return r.store.Set(context.Background(), fmt.Sprintf(KeyDetails, key.ID), keyJSON, 0) } +// removeStoreCacheForKey 从所有缓存结构中彻底移除一个APIKey。 func (r *gormKeyRepository) removeStoreCacheForKey(ctx context.Context, key *models.APIKey) error { groupIDs, err := r.GetGroupsForKey(ctx, key.ID) if err != nil { - r.logger.Warnf("failed to get groups for key %d to clean up cache lists: %v", key.ID, err) + r.logger.WithError(err).Warnf("Failed to get groups for key %d during cache removal, cleanup may be partial.", key.ID) } pipe := r.store.Pipeline(ctx) pipe.Del(fmt.Sprintf(KeyDetails, key.ID)) + keyIDStr := strconv.FormatUint(uint64(key.ID), 10) for _, groupID := range groupIDs { pipe.Del(fmt.Sprintf(KeyMapping, groupID, key.ID)) - - keyIDStr := strconv.FormatUint(uint64(key.ID), 10) pipe.SRem(fmt.Sprintf(KeyGroup, groupID), keyIDStr) pipe.LRem(fmt.Sprintf(KeyGroupSequential, groupID), 0, keyIDStr) pipe.SRem(fmt.Sprintf(KeyGroupRandomMain, groupID), keyIDStr) pipe.SRem(fmt.Sprintf(KeyGroupRandomCooldown, groupID), keyIDStr) - go r.store.ZRem(context.Background(), fmt.Sprintf(KeyGroupLRU, groupID), keyIDStr) + pipe.ZRem(fmt.Sprintf(KeyGroupLRU, groupID), keyIDStr) } + return pipe.Exec() } +// updateStoreCacheForMapping 根据单个映射关系的状态,原子性地更新所有相关的缓存结构。 func (r *gormKeyRepository) updateStoreCacheForMapping(mapping *models.GroupAPIKeyMapping) error { - pipe := r.store.Pipeline(context.Background()) - activeKeyListKey := fmt.Sprintf("group:%d:keys:active", mapping.KeyGroupID) - pipe.LRem(activeKeyListKey, 0, mapping.APIKeyID) + keyIDStr := strconv.FormatUint(uint64(mapping.APIKeyID), 10) + groupID := mapping.KeyGroupID + ctx := context.Background() + + pipe := r.store.Pipeline(ctx) + + // 统一、无条件地从所有轮询结构中移除,确保状态清洁 + pipe.SRem(fmt.Sprintf(KeyGroup, groupID), keyIDStr) + pipe.LRem(fmt.Sprintf(KeyGroupSequential, groupID), 0, keyIDStr) + pipe.SRem(fmt.Sprintf(KeyGroupRandomMain, groupID), keyIDStr) + pipe.SRem(fmt.Sprintf(KeyGroupRandomCooldown, groupID), keyIDStr) + pipe.ZRem(fmt.Sprintf(KeyGroupLRU, groupID), keyIDStr) + + // 如果新状态是 Active,则重新添加到所有轮询结构中 if mapping.Status == models.StatusActive { - pipe.LPush(activeKeyListKey, mapping.APIKeyID) + pipe.SAdd(fmt.Sprintf(KeyGroup, groupID), keyIDStr) + pipe.LPush(fmt.Sprintf(KeyGroupSequential, groupID), keyIDStr) + pipe.SAdd(fmt.Sprintf(KeyGroupRandomMain, groupID), keyIDStr) + + var score float64 + if mapping.LastUsedAt != nil { + score = float64(mapping.LastUsedAt.UnixMilli()) + } + pipe.ZAdd(fmt.Sprintf(KeyGroupLRU, groupID), map[string]float64{keyIDStr: score}) } + + // 无论状态如何,都更新映射详情的 K-V 缓存 + mappingJSON, err := json.Marshal(mapping) + if err != nil { + return fmt.Errorf("failed to marshal mapping: %w", err) + } + pipe.Set(fmt.Sprintf(KeyMapping, groupID, mapping.APIKeyID), mappingJSON, 0) + return pipe.Exec() } +// HandleCacheUpdateEventBatch 批量、原子性地更新多个映射关系的缓存。 func (r *gormKeyRepository) HandleCacheUpdateEventBatch(ctx context.Context, mappings []*models.GroupAPIKeyMapping) error { if len(mappings) == 0 { return nil } - groupUpdates := make(map[uint]struct { - ToAdd []interface{} - ToRemove []interface{} - }) + + pipe := r.store.Pipeline(ctx) + for _, mapping := range mappings { keyIDStr := strconv.FormatUint(uint64(mapping.APIKeyID), 10) - update, ok := groupUpdates[mapping.KeyGroupID] - if !ok { - update = struct { - ToAdd []interface{} - ToRemove []interface{} - }{} - } + groupID := mapping.KeyGroupID + + // 对于批处理中的每一个mapping,都执行完整的、正确的“先删后增”逻辑 + pipe.SRem(fmt.Sprintf(KeyGroup, groupID), keyIDStr) + pipe.LRem(fmt.Sprintf(KeyGroupSequential, groupID), 0, keyIDStr) + pipe.SRem(fmt.Sprintf(KeyGroupRandomMain, groupID), keyIDStr) + pipe.SRem(fmt.Sprintf(KeyGroupRandomCooldown, groupID), keyIDStr) + pipe.ZRem(fmt.Sprintf(KeyGroupLRU, groupID), keyIDStr) + if mapping.Status == models.StatusActive { - update.ToRemove = append(update.ToRemove, keyIDStr) - update.ToAdd = append(update.ToAdd, keyIDStr) - } else { - update.ToRemove = append(update.ToRemove, keyIDStr) - } - groupUpdates[mapping.KeyGroupID] = update - } - pipe := r.store.Pipeline(context.Background()) - var pipelineError error - for groupID, updates := range groupUpdates { - activeKeyListKey := fmt.Sprintf("group:%d:keys:active", groupID) - if len(updates.ToRemove) > 0 { - for _, keyID := range updates.ToRemove { - pipe.LRem(activeKeyListKey, 0, keyID) + pipe.SAdd(fmt.Sprintf(KeyGroup, groupID), keyIDStr) + pipe.LPush(fmt.Sprintf(KeyGroupSequential, groupID), keyIDStr) + pipe.SAdd(fmt.Sprintf(KeyGroupRandomMain, groupID), keyIDStr) + + var score float64 + if mapping.LastUsedAt != nil { + score = float64(mapping.LastUsedAt.UnixMilli()) } + pipe.ZAdd(fmt.Sprintf(KeyGroupLRU, groupID), map[string]float64{keyIDStr: score}) } - if len(updates.ToAdd) > 0 { - pipe.LPush(activeKeyListKey, updates.ToAdd...) - } + + mappingJSON, _ := json.Marshal(mapping) // 在批处理中忽略单个marshal错误,以保证大部分更新成功 + pipe.Set(fmt.Sprintf(KeyMapping, groupID, mapping.APIKeyID), mappingJSON, 0) } - if err := pipe.Exec(); err != nil { - pipelineError = fmt.Errorf("redis pipeline execution failed: %w", err) - } - return pipelineError + + return pipe.Exec() } diff --git a/internal/repository/key_crud.go b/internal/repository/key_crud.go index 1a9b6d5..4668e9d 100644 --- a/internal/repository/key_crud.go +++ b/internal/repository/key_crud.go @@ -23,7 +23,6 @@ func (r *gormKeyRepository) AddKeys(keys []models.APIKey) ([]models.APIKey, erro keyHashes := make([]string, len(keys)) keyValueToHashMap := make(map[string]string) for i, k := range keys { - // All incoming keys must have plaintext APIKey if k.APIKey == "" { return nil, fmt.Errorf("cannot add key at index %d: plaintext APIKey is empty", i) } @@ -35,7 +34,6 @@ func (r *gormKeyRepository) AddKeys(keys []models.APIKey) ([]models.APIKey, erro var finalKeys []models.APIKey err := r.db.Transaction(func(tx *gorm.DB) error { var existingKeys []models.APIKey - // [MODIFIED] Query by hash to find existing keys. if err := tx.Unscoped().Where("api_key_hash IN ?", keyHashes).Find(&existingKeys).Error; err != nil { return err } @@ -69,24 +67,20 @@ func (r *gormKeyRepository) AddKeys(keys []models.APIKey) ([]models.APIKey, erro } } if len(keysToCreate) > 0 { - // [MODIFIED] Create now only provides encrypted data and hash. if err := tx.Clauses(clause.OnConflict{DoNothing: true}, clause.Returning{}).Create(&keysToCreate).Error; err != nil { return err } } - // [MODIFIED] Final select uses hashes to retrieve all relevant keys. if err := tx.Where("api_key_hash IN ?", keyHashes).Find(&finalKeys).Error; err != nil { return err } - // [CRITICAL] Decrypt all keys before returning them to the service layer. + return r.decryptKeys(finalKeys) }) return finalKeys, err } func (r *gormKeyRepository) Update(key *models.APIKey) error { - // [CRITICAL] Before saving, check if the plaintext APIKey field was populated. - // This indicates a potential change that needs to be re-encrypted. if key.APIKey != "" { encryptedKey, err := r.crypto.Encrypt(key.APIKey) if err != nil { @@ -98,16 +92,16 @@ func (r *gormKeyRepository) Update(key *models.APIKey) error { key.APIKeyHash = hex.EncodeToString(hash[:]) } err := r.executeTransactionWithRetry(func(tx *gorm.DB) error { - // GORM automatically ignores `key.APIKey` because of the `gorm:"-"` tag. + return tx.Save(key).Error }) if err != nil { return err } - // For the cache update, we need the plaintext. Decrypt if it's not already populated. + if err := r.decryptKey(key); err != nil { r.logger.Warnf("DB updated key ID %d, but decryption for cache failed: %v", key.ID, err) - return nil // Continue without cache update if decryption fails. + return nil } if err := r.updateStoreCacheForKey(key); err != nil { r.logger.Warnf("DB updated key ID %d, but cache update failed: %v", key.ID, err) @@ -192,7 +186,6 @@ func (r *gormKeyRepository) GetKeysByIDs(ids []uint) ([]models.APIKey, error) { if err != nil { return nil, err } - // [CRITICAL] Decrypt before returning. return keys, r.decryptKeys(keys) } diff --git a/internal/repository/key_selector.go b/internal/repository/key_selector.go index b2c78f0..446f40a 100644 --- a/internal/repository/key_selector.go +++ b/internal/repository/key_selector.go @@ -1,4 +1,4 @@ -// Filename: internal/repository/key_selector.go (经审查后最终修复版) +// Filename: internal/repository/key_selector.go package repository import ( @@ -18,39 +18,40 @@ import ( ) const ( - CacheTTL = 5 * time.Minute - EmptyPoolPlaceholder = "EMPTY_POOL" - EmptyCacheTTL = 1 * time.Minute + CacheTTL = 5 * time.Minute + EmptyCacheTTL = 1 * time.Minute ) -// SelectOneActiveKey 根据指定的轮询策略,从缓存中高效地选取一个可用的API密钥。 +// SelectOneActiveKey 根据指定的轮询策略,从单个密钥组缓存中选取一个可用的API密钥。 func (r *gormKeyRepository) SelectOneActiveKey(ctx context.Context, group *models.KeyGroup) (*models.APIKey, *models.GroupAPIKeyMapping, error) { + if group == nil { + return nil, nil, fmt.Errorf("group cannot be nil") + } var keyIDStr string var err error - switch group.PollingStrategy { case models.StrategySequential: sequentialKey := fmt.Sprintf(KeyGroupSequential, group.ID) keyIDStr, err = r.store.Rotate(ctx, sequentialKey) - case models.StrategyWeighted: lruKey := fmt.Sprintf(KeyGroupLRU, group.ID) results, zerr := r.store.ZRange(ctx, lruKey, 0, 0) - if zerr == nil && len(results) > 0 { - keyIDStr = results[0] + if zerr == nil { + if len(results) > 0 { + keyIDStr = results[0] + } else { + zerr = gorm.ErrRecordNotFound + } } err = zerr - case models.StrategyRandom: mainPoolKey := fmt.Sprintf(KeyGroupRandomMain, group.ID) cooldownPoolKey := fmt.Sprintf(KeyGroupRandomCooldown, group.ID) keyIDStr, err = r.store.PopAndCycleSetMember(ctx, mainPoolKey, cooldownPoolKey) - - default: // 默认或未指定策略时,使用基础的随机策略 + default: activeKeySetKey := fmt.Sprintf(KeyGroup, group.ID) keyIDStr, err = r.store.SRandMember(ctx, activeKeySetKey) } - if err != nil { if errors.Is(err, store.ErrNotFound) || errors.Is(err, gorm.ErrRecordNotFound) { return nil, nil, gorm.ErrRecordNotFound @@ -58,39 +59,44 @@ func (r *gormKeyRepository) SelectOneActiveKey(ctx context.Context, group *model r.logger.WithError(err).Errorf("Failed to select key for group %d with strategy %s", group.ID, group.PollingStrategy) return nil, nil, err } - if keyIDStr == "" { return nil, nil, gorm.ErrRecordNotFound } - - keyID, _ := strconv.ParseUint(keyIDStr, 10, 64) - + keyID, parseErr := strconv.ParseUint(keyIDStr, 10, 64) + if parseErr != nil { + r.logger.WithError(parseErr).Errorf("Invalid key ID format in group %d cache: %s", group.ID, keyIDStr) + return nil, nil, fmt.Errorf("invalid key ID in cache: %w", parseErr) + } apiKey, mapping, err := r.getKeyDetailsFromCache(ctx, uint(keyID), group.ID) if err != nil { - r.logger.WithError(err).Warnf("Cache inconsistency: Failed to get details for selected key ID %d", keyID) + r.logger.WithError(err).Warnf("Cache inconsistency for key ID %d in group %d", keyID, group.ID) return nil, nil, err } - if group.PollingStrategy == models.StrategyWeighted { - go r.UpdateKeyUsageTimestamp(context.Background(), group.ID, uint(keyID)) + go func() { + updateCtx, cancel := r.withTimeout(5 * time.Second) + defer cancel() + r.UpdateKeyUsageTimestamp(updateCtx, group.ID, uint(keyID)) + }() } - return apiKey, mapping, nil } -// SelectOneActiveKeyFromBasePool 为智能聚合模式设计的全新轮询器。 +// SelectOneActiveKeyFromBasePool 从智能聚合池中选取一个可用Key。 func (r *gormKeyRepository) SelectOneActiveKeyFromBasePool(ctx context.Context, pool *BasePool) (*models.APIKey, *models.KeyGroup, error) { - poolID := generatePoolID(pool.CandidateGroups) + if pool == nil || len(pool.CandidateGroups) == 0 { + return nil, nil, fmt.Errorf("invalid or empty base pool configuration") + } + poolID := r.generatePoolID(pool.CandidateGroups) log := r.logger.WithField("pool_id", poolID) - if err := r.ensureBasePoolCacheExists(ctx, pool, poolID); err != nil { - log.WithError(err).Error("Failed to ensure BasePool cache exists.") + if !errors.Is(err, gorm.ErrRecordNotFound) { + log.WithError(err).Error("Failed to ensure BasePool cache exists") + } return nil, nil, err } - var keyIDStr string var err error - switch pool.PollingStrategy { case models.StrategySequential: sequentialKey := fmt.Sprintf(BasePoolSequential, poolID) @@ -98,8 +104,12 @@ func (r *gormKeyRepository) SelectOneActiveKeyFromBasePool(ctx context.Context, case models.StrategyWeighted: lruKey := fmt.Sprintf(BasePoolLRU, poolID) results, zerr := r.store.ZRange(ctx, lruKey, 0, 0) - if zerr == nil && len(results) > 0 { - keyIDStr = results[0] + if zerr == nil { + if len(results) > 0 { + keyIDStr = results[0] + } else { + zerr = gorm.ErrRecordNotFound + } } err = zerr case models.StrategyRandom: @@ -107,13 +117,12 @@ func (r *gormKeyRepository) SelectOneActiveKeyFromBasePool(ctx context.Context, cooldownPoolKey := fmt.Sprintf(BasePoolRandomCooldown, poolID) keyIDStr, err = r.store.PopAndCycleSetMember(ctx, mainPoolKey, cooldownPoolKey) default: - log.Warnf("Default polling strategy triggered inside selection. This should be rare.") + log.Warnf("Unknown polling strategy '%s'. Using sequential as fallback.", pool.PollingStrategy) sequentialKey := fmt.Sprintf(BasePoolSequential, poolID) - keyIDStr, err = r.store.LIndex(ctx, sequentialKey, 0) + keyIDStr, err = r.store.Rotate(ctx, sequentialKey) } - if err != nil { - if errors.Is(err, store.ErrNotFound) { + if errors.Is(err, store.ErrNotFound) || errors.Is(err, gorm.ErrRecordNotFound) { return nil, nil, gorm.ErrRecordNotFound } log.WithError(err).Errorf("Failed to select key from BasePool with strategy %s", pool.PollingStrategy) @@ -122,118 +131,246 @@ func (r *gormKeyRepository) SelectOneActiveKeyFromBasePool(ctx context.Context, if keyIDStr == "" { return nil, nil, gorm.ErrRecordNotFound } - keyID, _ := strconv.ParseUint(keyIDStr, 10, 64) - - for _, group := range pool.CandidateGroups { - apiKey, mapping, cacheErr := r.getKeyDetailsFromCache(ctx, uint(keyID), group.ID) - if cacheErr == nil && apiKey != nil && mapping != nil { - if pool.PollingStrategy == models.StrategyWeighted { - go r.updateKeyUsageTimestampForPool(context.Background(), poolID, uint(keyID)) - } - return apiKey, group, nil + go func() { + bgCtx, cancel := r.withTimeout(5 * time.Second) + defer cancel() + r.refreshBasePoolHeartbeat(bgCtx, poolID) + }() + keyID, parseErr := strconv.ParseUint(keyIDStr, 10, 64) + if parseErr != nil { + log.WithError(parseErr).Errorf("Invalid key ID format in BasePool cache: %s", keyIDStr) + return nil, nil, fmt.Errorf("invalid key ID in cache: %w", parseErr) + } + keyToGroupMapKey := fmt.Sprintf("basepool:%s:key_to_group", poolID) + groupIDStr, err := r.store.HGet(ctx, keyToGroupMapKey, keyIDStr) + if err != nil { + log.WithError(err).Errorf("Cache inconsistency: KeyID %d found in pool but not in key-to-group map", keyID) + return nil, nil, errors.New("cache inconsistency: key has no origin group mapping") + } + groupID, parseErr := strconv.ParseUint(groupIDStr, 10, 64) + if parseErr != nil { + log.WithError(parseErr).Errorf("Invalid group ID format in key-to-group map for key %d: %s", keyID, groupIDStr) + return nil, nil, errors.New("cache inconsistency: invalid group id in mapping") + } + apiKey, _, err := r.getKeyDetailsFromCache(ctx, uint(keyID), uint(groupID)) + if err != nil { + log.WithError(err).Warnf("Cache inconsistency: Failed to get details for key %d in mapped group %d", keyID, groupID) + return nil, nil, err + } + var originGroup *models.KeyGroup + for _, g := range pool.CandidateGroups { + if g.ID == uint(groupID) { + originGroup = g + break } } - - log.Errorf("Cache inconsistency: Selected KeyID %d from BasePool but could not find its origin group.", keyID) - return nil, nil, errors.New("cache inconsistency: selected key has no origin group") + if originGroup == nil { + log.Errorf("Logic error: Mapped GroupID %d not found in pool's candidate groups list", groupID) + return nil, nil, errors.New("cache inconsistency: mapped group not in candidate list") + } + if pool.PollingStrategy == models.StrategyWeighted { + go func() { + bgCtx, cancel := r.withTimeout(5 * time.Second) + defer cancel() + r.updateKeyUsageTimestampForPool(bgCtx, poolID, uint(keyID)) + }() + } + return apiKey, originGroup, nil } -// ensureBasePoolCacheExists 动态创建 BasePool 的 Redis 结构 +// ensureBasePoolCacheExists 动态创建或验证 BasePool 的 Redis 缓存结构。 func (r *gormKeyRepository) ensureBasePoolCacheExists(ctx context.Context, pool *BasePool, poolID string) error { - listKey := fmt.Sprintf(BasePoolSequential, poolID) - - exists, err := r.store.Exists(ctx, listKey) - if err != nil { - r.logger.WithError(err).Errorf("Failed to check existence for pool_id '%s'", poolID) - return err + heartbeatKey := fmt.Sprintf("basepool:%s:heartbeat", poolID) + emptyMarkerKey := fmt.Sprintf("basepool:empty:%s", poolID) + // 预检查,快速失败 + if exists, _ := r.store.Exists(ctx, emptyMarkerKey); exists { + return gorm.ErrRecordNotFound } - if exists { - val, err := r.store.LIndex(ctx, listKey, 0) - if err != nil { - r.logger.WithError(err).Warnf("Cache for pool_id '%s' exists but is unreadable. Forcing rebuild.", poolID) - } else { - if val == EmptyPoolPlaceholder { - return gorm.ErrRecordNotFound - } - return nil - } - } - - lockKey := fmt.Sprintf("lock:basepool:%s", poolID) - acquired, err := r.store.SetNX(ctx, lockKey, []byte("1"), 10*time.Second) - if err != nil { - r.logger.WithError(err).Error("Failed to attempt acquiring distributed lock for basepool build.") - return err - } - if !acquired { - time.Sleep(100 * time.Millisecond) - return r.ensureBasePoolCacheExists(ctx, pool, poolID) - } - defer r.store.Del(context.Background(), lockKey) - - if exists, _ := r.store.Exists(ctx, listKey); exists { + if exists, _ := r.store.Exists(ctx, heartbeatKey); exists { return nil } - r.logger.Infof("BasePool cache for pool_id '%s' not found or is unreadable. Building now...", poolID) - var allActiveKeyIDs []string - lruMembers := make(map[string]float64) + // 获取分布式锁 + lockKey := fmt.Sprintf("lock:basepool:%s", poolID) + if err := r.acquireLock(ctx, lockKey); err != nil { + return err // acquireLock 内部已记录日志并返回明确错误 + } + defer r.releaseLock(context.Background(), lockKey) + // 双重检查锁定 + if exists, _ := r.store.Exists(ctx, emptyMarkerKey); exists { + return gorm.ErrRecordNotFound + } + if exists, _ := r.store.Exists(ctx, heartbeatKey); exists { + return nil + } + // 在执行重度操作前,最后检查一次上下文是否已取消 + select { + case <-ctx.Done(): + return ctx.Err() + default: + } + r.logger.Infof("Building BasePool cache for pool_id '%s'", poolID) + // 手动聚合所有 Keys 并同时构建 key-to-group 映射 + keyToGroupMap := make(map[string]any) + allKeyIDsSet := make(map[string]struct{}) for _, group := range pool.CandidateGroups { - activeKeySetKey := fmt.Sprintf(KeyGroup, group.ID) - groupKeyIDs, err := r.store.SMembers(ctx, activeKeySetKey) + groupKeySet := fmt.Sprintf(KeyGroup, group.ID) + groupKeyIDs, err := r.store.SMembers(ctx, groupKeySet) if err != nil { - r.logger.WithError(err).Errorf("FATAL: Failed to read active keys for group %d during BasePool build. Aborting build process for pool_id '%s'.", group.ID, poolID) - return err + r.logger.WithError(err).Warnf("Failed to get members for group %d during pool build", group.ID) + continue } - allActiveKeyIDs = append(allActiveKeyIDs, groupKeyIDs...) - for _, keyIDStr := range groupKeyIDs { - keyID, _ := strconv.ParseUint(keyIDStr, 10, 64) - _, mapping, err := r.getKeyDetailsFromCache(ctx, uint(keyID), group.ID) - if err == nil && mapping != nil { - var score float64 - if mapping.LastUsedAt != nil { - score = float64(mapping.LastUsedAt.UnixMilli()) - } - lruMembers[keyIDStr] = score + groupIDStr := strconv.FormatUint(uint64(group.ID), 10) + for _, keyID := range groupKeyIDs { + if _, exists := allKeyIDsSet[keyID]; !exists { + allKeyIDsSet[keyID] = struct{}{} + keyToGroupMap[keyID] = groupIDStr } } } - - if len(allActiveKeyIDs) == 0 { - r.logger.Warnf("No active keys found for any candidate groups for pool_id '%s'. Setting empty pool placeholder.", poolID) - pipe := r.store.Pipeline(ctx) - pipe.LPush(listKey, EmptyPoolPlaceholder) - pipe.Expire(listKey, EmptyCacheTTL) - if err := pipe.Exec(); err != nil { - r.logger.WithError(err).Errorf("Failed to set empty pool placeholder for pool_id '%s'", poolID) + // 处理空池情况 + if len(allKeyIDsSet) == 0 { + emptyCacheTTL := time.Duration(r.config.Repository.BasePoolTTIMinutes) * time.Minute / 2 + if emptyCacheTTL < time.Minute { + emptyCacheTTL = time.Minute + } + r.logger.Warnf("No active keys found for pool_id '%s', setting empty marker.", poolID) + if err := r.store.Set(ctx, emptyMarkerKey, []byte("1"), emptyCacheTTL); err != nil { + r.logger.WithError(err).Warnf("Failed to set empty marker for pool_id '%s'", poolID) } return gorm.ErrRecordNotFound } - + allActiveKeyIDs := make([]string, 0, len(allKeyIDsSet)) + for keyID := range allKeyIDsSet { + allActiveKeyIDs = append(allActiveKeyIDs, keyID) + } + // 使用 Pipeline 原子化构建所有缓存结构 + basePoolTTL := time.Duration(r.config.Repository.BasePoolTTLMinutes) * time.Minute + basePoolTTI := time.Duration(r.config.Repository.BasePoolTTIMinutes) * time.Minute + mainPoolKey := fmt.Sprintf(BasePoolRandomMain, poolID) + sequentialKey := fmt.Sprintf(BasePoolSequential, poolID) + cooldownKey := fmt.Sprintf(BasePoolRandomCooldown, poolID) + lruKey := fmt.Sprintf(BasePoolLRU, poolID) + keyToGroupMapKey := fmt.Sprintf("basepool:%s:key_to_group", poolID) pipe := r.store.Pipeline(ctx) - pipe.LPush(fmt.Sprintf(BasePoolSequential, poolID), toInterfaceSlice(allActiveKeyIDs)...) - pipe.SAdd(fmt.Sprintf(BasePoolRandomMain, poolID), toInterfaceSlice(allActiveKeyIDs)...) - pipe.Expire(fmt.Sprintf(BasePoolSequential, poolID), CacheTTL) - pipe.Expire(fmt.Sprintf(BasePoolRandomMain, poolID), CacheTTL) - pipe.Expire(fmt.Sprintf(BasePoolRandomCooldown, poolID), CacheTTL) - pipe.Expire(fmt.Sprintf(BasePoolLRU, poolID), CacheTTL) - + pipe.Del(mainPoolKey, sequentialKey, cooldownKey, lruKey, emptyMarkerKey, keyToGroupMapKey) + pipe.SAdd(mainPoolKey, r.toInterfaceSlice(allActiveKeyIDs)...) + pipe.LPush(sequentialKey, r.toInterfaceSlice(allActiveKeyIDs)...) + if len(keyToGroupMap) > 0 { + pipe.HSet(keyToGroupMapKey, keyToGroupMap) + pipe.Expire(keyToGroupMapKey, basePoolTTL) + } + pipe.Expire(mainPoolKey, basePoolTTL) + pipe.Expire(sequentialKey, basePoolTTL) + pipe.Expire(cooldownKey, basePoolTTL) + pipe.Expire(lruKey, basePoolTTL) + pipe.Set(heartbeatKey, []byte("1"), basePoolTTI) if err := pipe.Exec(); err != nil { + r.logger.WithError(err).Errorf("Failed to populate polling structures for pool_id '%s'", poolID) + cleanupCtx, cancel := r.withTimeout(5 * time.Second) + defer cancel() + r.store.Del(cleanupCtx, mainPoolKey, sequentialKey, cooldownKey, lruKey, heartbeatKey, emptyMarkerKey, keyToGroupMapKey) return err } + // 异步填充 LRU 缓存,并传入已构建好的映射 + go r.populateBasePoolLRUCache(context.Background(), poolID, allActiveKeyIDs, keyToGroupMap) + return nil +} - if len(lruMembers) > 0 { - if err := r.store.ZAdd(ctx, fmt.Sprintf(BasePoolLRU, poolID), lruMembers); err != nil { - r.logger.WithError(err).Warnf("Failed to populate LRU cache for pool_id '%s'", poolID) +// --- 辅助方法 --- +// acquireLock 封装了带重试和指数退避的分布式锁获取逻辑。 +func (r *gormKeyRepository) acquireLock(ctx context.Context, lockKey string) error { + const ( + lockTTL = 30 * time.Second + lockMaxRetries = 5 + lockBaseBackoff = 50 * time.Millisecond + ) + for i := 0; i < lockMaxRetries; i++ { + acquired, err := r.store.SetNX(ctx, lockKey, []byte("1"), lockTTL) + if err != nil { + r.logger.WithError(err).Error("Failed to attempt acquiring distributed lock") + return err + } + if acquired { + return nil + } + time.Sleep(lockBaseBackoff * (1 << i)) + } + return fmt.Errorf("failed to acquire lock for key '%s' after %d retries", lockKey, lockMaxRetries) +} + +// releaseLock 封装了分布式锁的释放逻辑。 +func (r *gormKeyRepository) releaseLock(ctx context.Context, lockKey string) { + if err := r.store.Del(ctx, lockKey); err != nil { + r.logger.WithError(err).Errorf("Failed to release distributed lock for key '%s'", lockKey) + } +} + +// withTimeout 是 context.WithTimeout 的一个简单包装,便于测试和模拟。 +func (r *gormKeyRepository) withTimeout(duration time.Duration) (context.Context, context.CancelFunc) { + return context.WithTimeout(context.Background(), duration) +} + +// refreshBasePoolHeartbeat 异步刷新心跳Key的TTI +func (r *gormKeyRepository) refreshBasePoolHeartbeat(ctx context.Context, poolID string) { + basePoolTTI := time.Duration(r.config.Repository.BasePoolTTIMinutes) * time.Minute + heartbeatKey := fmt.Sprintf("basepool:%s:heartbeat", poolID) + // 使用 EXPIRE 命令来刷新,如果Key不存在,它什么也不做,是安全的 + if err := r.store.Expire(ctx, heartbeatKey, basePoolTTI); err != nil { + if ctx.Err() == nil { // 避免在context取消后打印不必要的错误 + r.logger.WithError(err).Warnf("Failed to refresh heartbeat for pool_id '%s'", poolID) + } + } +} + +// populateBasePoolLRUCache 异步填充 BasePool 的 LRU 缓存结构 +func (r *gormKeyRepository) populateBasePoolLRUCache( + parentCtx context.Context, + currentPoolID string, + keys []string, + keyToGroupMap map[string]any, +) { + lruMembers := make(map[string]float64, len(keys)) + for _, keyIDStr := range keys { + select { + case <-parentCtx.Done(): + return + default: + } + groupIDStr, ok := keyToGroupMap[keyIDStr].(string) + if !ok { + continue + } + keyID, _ := strconv.ParseUint(keyIDStr, 10, 64) + groupID, _ := strconv.ParseUint(groupIDStr, 10, 64) + mappingKey := fmt.Sprintf(KeyMapping, groupID, keyID) + data, err := r.store.Get(parentCtx, mappingKey) + if err != nil { + continue + } + var mapping models.GroupAPIKeyMapping + if json.Unmarshal(data, &mapping) == nil { + var score float64 + if mapping.LastUsedAt != nil { + score = float64(mapping.LastUsedAt.UnixMilli()) + } + lruMembers[keyIDStr] = score + } + } + if len(lruMembers) > 0 { + lruKey := fmt.Sprintf(BasePoolLRU, currentPoolID) + if err := r.store.ZAdd(parentCtx, lruKey, lruMembers); err != nil { + if parentCtx.Err() == nil { + r.logger.WithError(err).Warnf("Failed to populate LRU cache for pool '%s'", currentPoolID) + } } } - return nil } // updateKeyUsageTimestampForPool 更新 BasePool 的 LUR ZSET func (r *gormKeyRepository) updateKeyUsageTimestampForPool(ctx context.Context, poolID string, keyID uint) { lruKey := fmt.Sprintf(BasePoolLRU, poolID) err := r.store.ZAdd(ctx, lruKey, map[string]float64{ - strconv.FormatUint(uint64(keyID), 10): nowMilli(), + strconv.FormatUint(uint64(keyID), 10): r.nowMilli(), }) if err != nil { r.logger.WithError(err).Warnf("Failed to update key usage for pool %s", poolID) @@ -241,20 +378,19 @@ func (r *gormKeyRepository) updateKeyUsageTimestampForPool(ctx context.Context, } // generatePoolID 根据候选组ID列表生成一个稳定的、唯一的字符串ID -func generatePoolID(groups []*models.KeyGroup) string { +func (r *gormKeyRepository) generatePoolID(groups []*models.KeyGroup) string { ids := make([]int, len(groups)) for i, g := range groups { ids[i] = int(g.ID) } sort.Ints(ids) - h := sha1.New() io.WriteString(h, fmt.Sprintf("%v", ids)) return fmt.Sprintf("%x", h.Sum(nil)) } // toInterfaceSlice 类型转换辅助函数 -func toInterfaceSlice(slice []string) []interface{} { +func (r *gormKeyRepository) toInterfaceSlice(slice []string) []interface{} { result := make([]interface{}, len(slice)) for i, v := range slice { result[i] = v @@ -263,7 +399,7 @@ func toInterfaceSlice(slice []string) []interface{} { } // nowMilli 返回当前的Unix毫秒时间戳,用于LRU/Weighted策略 -func nowMilli() float64 { +func (r *gormKeyRepository) nowMilli() float64 { return float64(time.Now().UnixMilli()) } diff --git a/internal/repository/repository.go b/internal/repository/repository.go index 239d86f..3fb3c20 100644 --- a/internal/repository/repository.go +++ b/internal/repository/repository.go @@ -1,8 +1,9 @@ -// Filename: internal/repository/repository.go (经审查后最终修复版) +// Filename: internal/repository/repository.go package repository import ( "context" + "gemini-balancer/internal/config" "gemini-balancer/internal/crypto" "gemini-balancer/internal/errors" "gemini-balancer/internal/models" @@ -87,18 +88,20 @@ type gormKeyRepository struct { store store.Store logger *logrus.Entry crypto *crypto.Service + config *config.Config } type gormGroupRepository struct { db *gorm.DB } -func NewKeyRepository(db *gorm.DB, s store.Store, logger *logrus.Logger, crypto *crypto.Service) KeyRepository { +func NewKeyRepository(db *gorm.DB, s store.Store, logger *logrus.Logger, crypto *crypto.Service, cfg *config.Config) KeyRepository { return &gormKeyRepository{ db: db, store: s, logger: logger.WithField("component", "repository.key🔗"), crypto: crypto, + config: cfg, } } diff --git a/internal/service/resource_service.go b/internal/service/resource_service.go index 87ecf36..21c1c42 100644 --- a/internal/service/resource_service.go +++ b/internal/service/resource_service.go @@ -1,10 +1,9 @@ // Filename: internal/service/resource_service.go - package service import ( "context" - "errors" + "gemini-balancer/internal/domain/proxy" apperrors "gemini-balancer/internal/errors" "gemini-balancer/internal/models" "gemini-balancer/internal/repository" @@ -16,10 +15,7 @@ import ( "github.com/sirupsen/logrus" ) -var ( - ErrNoResourceAvailable = errors.New("no available resource found for the request") -) - +// RequestResources 封装了一次成功请求所需的所有资源。 type RequestResources struct { KeyGroup *models.KeyGroup APIKey *models.APIKey @@ -28,41 +24,51 @@ type RequestResources struct { RequestConfig *models.RequestConfig } +// ResourceService 负责根据请求参数和业务规则,动态地选择和分配API密钥及相关资源。 type ResourceService struct { settingsManager *settings.SettingsManager groupManager *GroupManager keyRepo repository.KeyRepository + authTokenRepo repository.AuthTokenRepository apiKeyService *APIKeyService + proxyManager *proxy.Module logger *logrus.Entry initOnce sync.Once } +// NewResourceService 创建并初始化一个新的 ResourceService 实例。 func NewResourceService( sm *settings.SettingsManager, gm *GroupManager, kr repository.KeyRepository, + atr repository.AuthTokenRepository, aks *APIKeyService, + pm *proxy.Module, logger *logrus.Logger, ) *ResourceService { rs := &ResourceService{ settingsManager: sm, groupManager: gm, keyRepo: kr, + authTokenRepo: atr, apiKeyService: aks, + proxyManager: pm, logger: logger.WithField("component", "ResourceService📦️"), } + // 使用 sync.Once 确保预热任务在服务生命周期内仅执行一次 rs.initOnce.Do(func() { - go rs.preWarmCache(logger) + go rs.preWarmCache() }) return rs } +// GetResourceFromBasePool 使用智能聚合池模式获取资源。 func (s *ResourceService) GetResourceFromBasePool(ctx context.Context, authToken *models.AuthToken, modelName string) (*RequestResources, error) { log := s.logger.WithFields(logrus.Fields{"token_id": authToken.ID, "model_name": modelName, "mode": "BasePool"}) log.Debug("Entering BasePool resource acquisition.") - candidateGroups := s.filterAndSortCandidateGroups(modelName, authToken.AllowedGroups) + candidateGroups := s.filterAndSortCandidateGroups(modelName, authToken) if len(candidateGroups) == 0 { log.Warn("No candidate groups found for BasePool construction.") return nil, apperrors.ErrNoKeysAvailable @@ -84,17 +90,18 @@ func (s *ResourceService) GetResourceFromBasePool(ctx context.Context, authToken log.WithError(err).Error("Failed to assemble resources after selecting key from BasePool.") return nil, err } - resources.RequestConfig = &models.RequestConfig{} + resources.RequestConfig = &models.RequestConfig{} // BasePool 模式使用默认请求配置 + log.Infof("Successfully selected KeyID %d from GroupID %d for the BasePool.", apiKey.ID, selectedGroup.ID) return resources, nil } +// GetResourceFromGroup 使用精确路由模式(指定密钥组)获取资源。 func (s *ResourceService) GetResourceFromGroup(ctx context.Context, authToken *models.AuthToken, groupName string) (*RequestResources, error) { log := s.logger.WithFields(logrus.Fields{"token_id": authToken.ID, "group_name": groupName, "mode": "PreciseRoute"}) log.Debug("Entering PreciseRoute resource acquisition.") targetGroup, ok := s.groupManager.GetGroupByName(groupName) - if !ok { return nil, apperrors.NewAPIError(apperrors.ErrGroupNotFound, "The specified group does not exist.") } @@ -113,37 +120,39 @@ func (s *ResourceService) GetResourceFromGroup(ctx context.Context, authToken *m log.WithError(err).Error("Failed to assemble resources for precise route.") return nil, err } - resources.RequestConfig = targetGroup.RequestConfig + resources.RequestConfig = targetGroup.RequestConfig // 精确路由使用该组的特定请求配置 log.Infof("Successfully selected KeyID %d by precise routing to GroupID %d.", apiKey.ID, targetGroup.ID) return resources, nil } +// GetAllowedModelsForToken 获取指定认证令牌有权访问的所有模型名称列表。 func (s *ResourceService) GetAllowedModelsForToken(authToken *models.AuthToken) []string { allGroups := s.groupManager.GetAllGroups() if len(allGroups) == 0 { return []string{} } - allowedModelsSet := make(map[string]struct{}) + + allowedGroupIDs := make(map[uint]bool) if authToken.IsAdmin { for _, group := range allGroups { + allowedGroupIDs[group.ID] = true + } + } else { + for _, ag := range authToken.AllowedGroups { + allowedGroupIDs[ag.ID] = true + } + } + + allowedModelsSet := make(map[string]struct{}) + for _, group := range allGroups { + if allowedGroupIDs[group.ID] { for _, modelMapping := range group.AllowedModels { allowedModelsSet[modelMapping.ModelName] = struct{}{} } } - } else { - allowedGroupIDs := make(map[uint]bool) - for _, ag := range authToken.AllowedGroups { - allowedGroupIDs[ag.ID] = true - } - for _, group := range allGroups { - if _, ok := allowedGroupIDs[group.ID]; ok { - for _, modelMapping := range group.AllowedModels { - allowedModelsSet[modelMapping.ModelName] = struct{}{} - } - } - } } + result := make([]string, 0, len(allowedModelsSet)) for modelName := range allowedModelsSet { result = append(result, modelName) @@ -152,12 +161,52 @@ func (s *ResourceService) GetAllowedModelsForToken(authToken *models.AuthToken) return result } +// ReportRequestResult 向 APIKeyService 报告请求的最终结果,以便更新密钥状态。 +func (s *ResourceService) ReportRequestResult(resources *RequestResources, success bool, apiErr *apperrors.APIError) { + if resources == nil || resources.KeyGroup == nil || resources.APIKey == nil { + return + } + s.apiKeyService.HandleRequestResult(resources.KeyGroup, resources.APIKey, success, apiErr) +} + +// --- 私有辅助方法 --- + +// preWarmCache 在后台执行一次性的缓存预热任务。 +func (s *ResourceService) preWarmCache() { + time.Sleep(2 * time.Second) // 等待其他服务组件可能完成初始化 + s.logger.Info("Performing initial key cache pre-warming...") + + // 强制加载 GroupManager 缓存 + s.logger.Info("Pre-warming GroupManager cache...") + _ = s.groupManager.GetAllGroups() + + ctx, cancel := context.WithTimeout(context.Background(), 60*time.Second) // 给予更长的超时 + defer cancel() + + if err := s.keyRepo.LoadAllKeysToStore(ctx); err != nil { + s.logger.WithError(err).Error("Failed to perform initial key cache pre-warming.") + } else { + s.logger.Info("Initial key cache pre-warming completed successfully.") + } +} + +// assembleRequestResources 根据密钥组和API密钥组装最终的资源对象。 func (s *ResourceService) assembleRequestResources(group *models.KeyGroup, apiKey *models.APIKey) (*RequestResources, error) { selectedUpstream := s.selectUpstreamForGroup(group) if selectedUpstream == nil { return nil, apperrors.NewAPIError(apperrors.ErrConfigurationError, "Selected group has no valid upstream and no global default is set.") } var proxyConfig *models.ProxyConfig + var err error + // 只有在组明确启用代理时,才为其分配代理 + if group.EnableProxy { + proxyConfig, err = s.proxyManager.AssignProxyIfNeeded(apiKey) + if err != nil { + s.logger.WithError(err).Errorf("Group '%s' (ID: %d) requires a proxy, but failed to assign one for KeyID %d", group.Name, group.ID, apiKey.ID) + // 根据业务需求,这里必须返回错误,因为代理是该组的强制要求 + return nil, apperrors.NewAPIError(apperrors.ErrProxyNotAvailable, "Required proxy is not available for this request.") + } + } return &RequestResources{ KeyGroup: group, APIKey: apiKey, @@ -166,8 +215,10 @@ func (s *ResourceService) assembleRequestResources(group *models.KeyGroup, apiKe }, nil } +// selectUpstreamForGroup 为指定的密钥组选择一个上游端点。 func (s *ResourceService) selectUpstreamForGroup(group *models.KeyGroup) *models.UpstreamEndpoint { if len(group.AllowedUpstreams) > 0 { + // (未来可扩展负载均衡逻辑) return group.AllowedUpstreams[0] } globalSettings := s.settingsManager.GetSettings() @@ -177,56 +228,39 @@ func (s *ResourceService) selectUpstreamForGroup(group *models.KeyGroup) *models return nil } -func (s *ResourceService) preWarmCache(logger *logrus.Logger) error { - time.Sleep(2 * time.Second) - s.logger.Info("Performing initial key cache pre-warming...") - if err := s.keyRepo.LoadAllKeysToStore(context.Background()); err != nil { - logger.WithError(err).Error("Failed to perform initial key cache pre-warming.") - return err - } - s.logger.Info("Initial key cache pre-warming completed successfully.") - return nil -} - -func (s *ResourceService) GetResourcesForRequest(modelName string, allowedGroups []*models.KeyGroup) (*RequestResources, error) { - return nil, errors.New("GetResourcesForRequest is deprecated; use GetResourceFromBasePool or GetResourceFromGroup") -} - -func (s *ResourceService) filterAndSortCandidateGroups(modelName string, allowedGroupsFromToken []*models.KeyGroup) []*models.KeyGroup { +// filterAndSortCandidateGroups 根据模型名称和令牌权限,筛选并排序出合格的候选密钥组。 +func (s *ResourceService) filterAndSortCandidateGroups(modelName string, authToken *models.AuthToken) []*models.KeyGroup { allGroupsFromCache := s.groupManager.GetAllGroups() var candidateGroups []*models.KeyGroup - allowedGroupIDs := make(map[uint]bool) - isTokenRestricted := len(allowedGroupsFromToken) > 0 - if isTokenRestricted { - for _, ag := range allowedGroupsFromToken { - allowedGroupIDs[ag.ID] = true - } - } + for _, group := range allGroupsFromCache { - if isTokenRestricted && !allowedGroupIDs[group.ID] { + // 检查令牌权限 + if !s.isTokenAllowedForGroup(authToken, group.ID) { continue } - isModelAllowed := false - if len(group.AllowedModels) == 0 { - isModelAllowed = true - } else { - for _, m := range group.AllowedModels { - if m.ModelName == modelName { - isModelAllowed = true - break - } - } - } - if isModelAllowed { + // 检查模型支持情况 (如果组内未限制模型,则默认支持所有模型) + if len(group.AllowedModels) == 0 || s.groupSupportsModel(group, modelName) { candidateGroups = append(candidateGroups, group) } } + sort.SliceStable(candidateGroups, func(i, j int) bool { return candidateGroups[i].Order < candidateGroups[j].Order }) return candidateGroups } +// groupSupportsModel 检查指定的密钥组是否支持给定的模型名称。 +func (s *ResourceService) groupSupportsModel(group *models.KeyGroup, modelName string) bool { + for _, m := range group.AllowedModels { + if m.ModelName == modelName { + return true + } + } + return false +} + +// isTokenAllowedForGroup 检查指定的认证令牌是否有权访问给定的密钥组。 func (s *ResourceService) isTokenAllowedForGroup(authToken *models.AuthToken, groupID uint) bool { if authToken.IsAdmin { return true @@ -238,10 +272,3 @@ func (s *ResourceService) isTokenAllowedForGroup(authToken *models.AuthToken, gr } return false } - -func (s *ResourceService) ReportRequestResult(resources *RequestResources, success bool, apiErr *apperrors.APIError) { - if resources == nil || resources.KeyGroup == nil || resources.APIKey == nil { - return - } - s.apiKeyService.HandleRequestResult(resources.KeyGroup, resources.APIKey, success, apiErr) -} diff --git a/internal/settings/settings.go b/internal/settings/settings.go index 525e651..d8217f1 100644 --- a/internal/settings/settings.go +++ b/internal/settings/settings.go @@ -1,4 +1,4 @@ -// file: gemini-balancer\internal\settings\settings.go +// Filename: gemini-balancer/internal/settings/settings.go (最终审计修复版) package settings import ( @@ -19,7 +19,9 @@ import ( const SettingsUpdateChannel = "system_settings:updated" const DefaultGeminiEndpoint = "https://generativelanguage.googleapis.com/v1beta/models" -// SettingsManager [核心修正] syncer现在缓存正确的“蓝图”类型 +var _ models.SettingsManager = (*SettingsManager)(nil) + +// SettingsManager 负责管理系统的动态设置,包括从数据库加载、缓存同步和更新。 type SettingsManager struct { db *gorm.DB syncer *syncer.CacheSyncer[*models.SystemSettings] @@ -27,13 +29,14 @@ type SettingsManager struct { jsonToFieldType map[string]reflect.Type // 用于将JSON字段映射到Go类型 } +// NewSettingsManager 创建一个新的 SettingsManager 实例。 func NewSettingsManager(db *gorm.DB, store store.Store, logger *logrus.Logger) (*SettingsManager, error) { sm := &SettingsManager{ db: db, logger: logger.WithField("component", "SettingsManager⚙️"), jsonToFieldType: make(map[string]reflect.Type), } - // settingsLoader 的职责:读取“砖块”,组装并返回“蓝图” + settingsType := reflect.TypeOf(models.SystemSettings{}) for i := 0; i < settingsType.NumField(); i++ { field := settingsType.Field(i) @@ -42,102 +45,89 @@ func NewSettingsManager(db *gorm.DB, store store.Store, logger *logrus.Logger) ( sm.jsonToFieldType[jsonTag] = field.Type } } - // settingsLoader 的职责:读取“砖块”,智能组装成“蓝图” + settingsLoader := func() (*models.SystemSettings, error) { sm.logger.Info("Loading system settings from database...") var dbRecords []models.Setting if err := sm.db.Find(&dbRecords).Error; err != nil { return nil, fmt.Errorf("failed to load system settings from db: %w", err) } + settingsMap := make(map[string]string) for _, record := range dbRecords { settingsMap[record.Key] = record.Value } - // 从一个包含了所有“出厂设置”的“蓝图”开始 + settings := defaultSystemSettings() v := reflect.ValueOf(settings).Elem() - t := v.Type() - // [智能卸货] - for i := 0; i < t.NumField(); i++ { - field := t.Field(i) + + for i := 0; i < v.NumField(); i++ { + field := v.Type().Field(i) fieldValue := v.Field(i) jsonTag := field.Tag.Get("json") - if dbValue, ok := settingsMap[jsonTag]; ok { + if dbValue, ok := settingsMap[jsonTag]; ok { if err := parseAndSetField(fieldValue, dbValue); err != nil { - sm.logger.Warnf("Failed to set config field '%s' from DB value '%s': %v. Using default.", field.Name, dbValue, err) + sm.logger.Warnf("Failed to set field '%s' from DB value '%s': %v. Using default.", field.Name, dbValue, err) } } } - if settings.BaseKeyCheckEndpoint == DefaultGeminiEndpoint || settings.BaseKeyCheckEndpoint == "" { - if settings.DefaultUpstreamURL != "" { - // 如果全局上游URL已设置,则基于它构建新的检查端点。 - originalEndpoint := settings.BaseKeyCheckEndpoint - derivedEndpoint := strings.TrimSuffix(settings.DefaultUpstreamURL, "/") + "/models" - settings.BaseKeyCheckEndpoint = derivedEndpoint - sm.logger.Infof( - "BaseKeyCheckEndpoint is dynamically derived from DefaultUpstreamURL. Original: '%s', New: '%s'", - originalEndpoint, derivedEndpoint, - ) - } - } else { + // [评估确认] 派生逻辑与原始版本在功能和日志行为上完全一致。 + if (settings.BaseKeyCheckEndpoint == DefaultGeminiEndpoint || settings.BaseKeyCheckEndpoint == "") && settings.DefaultUpstreamURL != "" { + derivedEndpoint := strings.TrimSuffix(settings.DefaultUpstreamURL, "/") + "/models" + sm.logger.Infof("BaseKeyCheckEndpoint is dynamically derived from DefaultUpstreamURL: %s", derivedEndpoint) + settings.BaseKeyCheckEndpoint = derivedEndpoint + } else if settings.BaseKeyCheckEndpoint != DefaultGeminiEndpoint && settings.BaseKeyCheckEndpoint != "" { + // 恢复 else 日志,以明确告知用户正在使用自定义覆盖。 sm.logger.Infof("BaseKeyCheckEndpoint is using a user-defined override: %s", settings.BaseKeyCheckEndpoint) } sm.logger.Info("System settings loaded and cached.") - sm.DisplaySettings(settings) return settings, nil } + s, err := syncer.NewCacheSyncer(settingsLoader, store, SettingsUpdateChannel) if err != nil { return nil, fmt.Errorf("failed to create system settings syncer: %w", err) } sm.syncer = s - go sm.ensureSettingsInitialized() + + if err := sm.ensureSettingsInitialized(); err != nil { + return nil, fmt.Errorf("failed to ensure system settings are initialized: %w", err) + } + return sm, nil } -// GetSettings [核心修正] 现在它正确地返回我们需要的“蓝图” +// GetSettings 返回当前缓存的系统设置。 func (sm *SettingsManager) GetSettings() *models.SystemSettings { return sm.syncer.Get() } -// UpdateSettings [核心修正] 它接收更新,并将它们转换为“砖块”存入数据库 +// UpdateSettings 更新一个或多个系统设置。 func (sm *SettingsManager) UpdateSettings(settingsMap map[string]interface{}) error { var settingsToUpdate []models.Setting + for key, value := range settingsMap { fieldType, ok := sm.jsonToFieldType[key] if !ok { sm.logger.Warnf("Received update for unknown setting key '%s', ignoring.", key) continue } - var dbValue string - // [智能打包] - // 如果字段是 slice 或 map,我们就将传入的 interface{} “打包”成 JSON string - kind := fieldType.Kind() - if kind == reflect.Slice || kind == reflect.Map { - jsonBytes, marshalErr := json.Marshal(value) - if marshalErr != nil { - // [真正的错误处理] 如果打包失败,我们记录日志,并跳过这个“坏掉的集装箱”。 - sm.logger.Warnf("Failed to marshal setting '%s' to JSON: %v, skipping update.", key, marshalErr) - continue // 跳过,继续处理下一个key - } - dbValue = string(jsonBytes) - } else if kind == reflect.Bool { - if b, ok := value.(bool); ok { - dbValue = strconv.FormatBool(b) - } else { - dbValue = "false" - } - } else { - dbValue = fmt.Sprintf("%v", value) + + dbValue, err := sm.convertToDBValue(key, value, fieldType) + if err != nil { + sm.logger.Warnf("Failed to convert value for setting '%s': %v. Skipping update.", key, err) + continue } + settingsToUpdate = append(settingsToUpdate, models.Setting{ Key: key, Value: dbValue, }) } + if len(settingsToUpdate) > 0 { err := sm.db.Clauses(clause.OnConflict{ Columns: []clause.Column{{Name: "key"}}, @@ -147,83 +137,20 @@ func (sm *SettingsManager) UpdateSettings(settingsMap map[string]interface{}) er return fmt.Errorf("failed to update settings in db: %w", err) } } - return sm.syncer.Invalidate() -} -// ensureSettingsInitialized [核心修正] 确保DB中有所有“砖块”的定义 -func (sm *SettingsManager) ensureSettingsInitialized() { - defaults := defaultSystemSettings() - v := reflect.ValueOf(defaults).Elem() - t := v.Type() - for i := 0; i < t.NumField(); i++ { - field := t.Field(i) - fieldValue := v.Field(i) - key := field.Tag.Get("json") - if key == "" || key == "-" { - continue - } - var existing models.Setting - if err := sm.db.Where("key = ?", key).First(&existing).Error; err == gorm.ErrRecordNotFound { - - var defaultValue string - kind := fieldValue.Kind() - // [智能初始化] - if kind == reflect.Slice || kind == reflect.Map { - // 为复杂类型,生成一个“空的”JSON字符串,例如 "[]" 或 "{}" - jsonBytes, _ := json.Marshal(fieldValue.Interface()) - defaultValue = string(jsonBytes) - } else { - defaultValue = field.Tag.Get("default") - } - setting := models.Setting{ - Key: key, - Value: defaultValue, - Name: field.Tag.Get("name"), - Description: field.Tag.Get("desc"), - Category: field.Tag.Get("category"), - DefaultValue: field.Tag.Get("default"), // 元数据中的default,永远来自tag - } - if err := sm.db.Create(&setting).Error; err != nil { - sm.logger.Errorf("Failed to initialize setting '%s': %v", key, err) - } - } + if err := sm.syncer.Invalidate(); err != nil { + sm.logger.Errorf("CRITICAL: Database settings updated, but cache invalidation failed: %v", err) + return fmt.Errorf("settings updated but cache invalidation failed, system may be inconsistent: %w", err) } + + return nil } -// ResetAndSaveSettings [核心新增] 將所有配置重置為其在 'default' 標籤中定義的值。 - +// ResetAndSaveSettings 将所有设置重置为其默认值。 func (sm *SettingsManager) ResetAndSaveSettings() (*models.SystemSettings, error) { defaults := defaultSystemSettings() - v := reflect.ValueOf(defaults).Elem() - t := v.Type() - var settingsToSave []models.Setting - for i := 0; i < t.NumField(); i++ { - field := t.Field(i) - fieldValue := v.Field(i) - key := field.Tag.Get("json") - if key == "" || key == "-" { - continue - } + settingsToSave := sm.buildSettingsFromDefaults(defaults) - var defaultValue string - kind := fieldValue.Kind() - // [智能重置] - if kind == reflect.Slice || kind == reflect.Map { - jsonBytes, _ := json.Marshal(fieldValue.Interface()) - defaultValue = string(jsonBytes) - } else { - defaultValue = field.Tag.Get("default") - } - setting := models.Setting{ - Key: key, - Value: defaultValue, - Name: field.Tag.Get("name"), - Description: field.Tag.Get("desc"), - Category: field.Tag.Get("category"), - DefaultValue: field.Tag.Get("default"), - } - settingsToSave = append(settingsToSave, setting) - } if len(settingsToSave) > 0 { err := sm.db.Clauses(clause.OnConflict{ Columns: []clause.Column{{Name: "key"}}, @@ -233,8 +160,93 @@ func (sm *SettingsManager) ResetAndSaveSettings() (*models.SystemSettings, error return nil, fmt.Errorf("failed to reset settings in db: %w", err) } } + if err := sm.syncer.Invalidate(); err != nil { - sm.logger.Errorf("Failed to invalidate settings cache after reset: %v", err) + sm.logger.Errorf("CRITICAL: Database settings reset, but cache invalidation failed: %v", err) + return nil, fmt.Errorf("settings reset but cache invalidation failed: %w", err) } + return defaults, nil } + +// --- 私有辅助函数 --- + +func (sm *SettingsManager) ensureSettingsInitialized() error { + defaults := defaultSystemSettings() + settingsToCreate := sm.buildSettingsFromDefaults(defaults) + + for _, setting := range settingsToCreate { + var existing models.Setting + err := sm.db.Where("key = ?", setting.Key).First(&existing).Error + + if err == gorm.ErrRecordNotFound { + sm.logger.Infof("Initializing new setting '%s'", setting.Key) + if createErr := sm.db.Create(&setting).Error; createErr != nil { + return fmt.Errorf("failed to create initial setting '%s': %w", setting.Key, createErr) + } + } else if err != nil { + return fmt.Errorf("failed to check for existing setting '%s': %w", setting.Key, err) + } + } + return nil +} + +func (sm *SettingsManager) buildSettingsFromDefaults(defaults *models.SystemSettings) []models.Setting { + v := reflect.ValueOf(defaults).Elem() + t := v.Type() + var settings []models.Setting + + for i := 0; i < t.NumField(); i++ { + field := t.Field(i) + fieldValue := v.Field(i) + key := field.Tag.Get("json") + + if key == "" || key == "-" { + continue + } + + var defaultValue string + kind := fieldValue.Kind() + + if kind == reflect.Slice || kind == reflect.Map { + jsonBytes, _ := json.Marshal(fieldValue.Interface()) + defaultValue = string(jsonBytes) + } else { + defaultValue = field.Tag.Get("default") + } + + settings = append(settings, models.Setting{ + Key: key, + Value: defaultValue, + Name: field.Tag.Get("name"), + Description: field.Tag.Get("desc"), + Category: field.Tag.Get("category"), + DefaultValue: field.Tag.Get("default"), + }) + } + return settings +} + +// [修正] 使用空白标识符 `_` 修复 "unused parameter" 警告。 +func (sm *SettingsManager) convertToDBValue(_ string, value interface{}, fieldType reflect.Type) (string, error) { + kind := fieldType.Kind() + + switch kind { + case reflect.Slice, reflect.Map: + jsonBytes, err := json.Marshal(value) + if err != nil { + return "", fmt.Errorf("failed to marshal to JSON: %w", err) + } + return string(jsonBytes), nil + + case reflect.Bool: + b, ok := value.(bool) + if !ok { + return "", fmt.Errorf("expected bool, but got %T", value) + } + return strconv.FormatBool(b), nil + + default: + return fmt.Sprintf("%v", value), nil + } +} diff --git a/internal/store/memory_store.go b/internal/store/memory_store.go index aaeed21..3d9b268 100644 --- a/internal/store/memory_store.go +++ b/internal/store/memory_store.go @@ -7,6 +7,7 @@ import ( "fmt" "math/rand" "sort" + "strconv" "sync" "time" @@ -65,7 +66,7 @@ func (s *memoryStore) startGCollector() { } } -// --- [架构修正] 所有方法签名都增加了 context.Context 参数以匹配接口 --- +// --- 所有方法签名都增加了 context.Context 参数以匹配接口 --- // --- 内存实现可以忽略该参数,用 _ 接收 --- func (s *memoryStore) Set(_ context.Context, key string, value []byte, ttl time.Duration) error { @@ -108,6 +109,17 @@ func (s *memoryStore) Exists(_ context.Context, key string) (bool, error) { return ok && !item.isExpired(), nil } +func (s *memoryStore) Expire(_ context.Context, key string, expiration time.Duration) error { + s.mu.Lock() + defer s.mu.Unlock() + item, ok := s.items[key] + if !ok { + return ErrNotFound + } + item.expireAt = time.Now().Add(expiration) + return nil +} + func (s *memoryStore) SetNX(_ context.Context, key string, value []byte, ttl time.Duration) (bool, error) { s.mu.Lock() defer s.mu.Unlock() @@ -159,6 +171,21 @@ func (s *memoryStore) HSet(_ context.Context, key string, values map[string]any) return nil } +func (s *memoryStore) HGet(_ context.Context, key, field string) (string, error) { + s.mu.RLock() + defer s.mu.RUnlock() + item, ok := s.items[key] + if !ok || item.isExpired() { + return "", ErrNotFound + } + if hash, ok := item.value.(map[string]string); ok { + if value, exists := hash[field]; exists { + return value, nil + } + } + return "", ErrNotFound +} + func (s *memoryStore) HGetAll(_ context.Context, key string) (map[string]string, error) { s.mu.RLock() defer s.mu.RUnlock() @@ -351,6 +378,26 @@ func (s *memoryStore) SRandMember(_ context.Context, key string) (string, error) return members[n], nil } +func (s *memoryStore) SUnionStore(_ context.Context, destination string, keys ...string) (int64, error) { + s.mu.Lock() + defer s.mu.Unlock() + unionSet := make(map[string]struct{}) + for _, key := range keys { + item, ok := s.items[key] + if !ok || item.isExpired() { + continue + } + if set, ok := item.value.(map[string]struct{}); ok { + for member := range set { + unionSet[member] = struct{}{} + } + } + } + destItem := &memoryStoreItem{value: unionSet} + s.items[destination] = destItem + return int64(len(unionSet)), nil +} + func (s *memoryStore) Rotate(_ context.Context, key string) (string, error) { s.mu.Lock() defer s.mu.Unlock() @@ -388,6 +435,16 @@ func (s *memoryStore) LIndex(_ context.Context, key string, index int64) (string return list[index], nil } +func (s *memoryStore) MSet(ctx context.Context, values map[string]any) error { + s.mu.Lock() + defer s.mu.Unlock() + for key, value := range values { + // 内存存储不支持独立的 TTL,因此我们假设永不过期 + s.items[key] = &memoryStoreItem{value: value, expireAt: time.Time{}} + } + return nil +} + func (s *memoryStore) ZAdd(_ context.Context, key string, members map[string]float64) error { s.mu.Lock() defer s.mu.Unlock() @@ -556,6 +613,22 @@ func (p *memoryPipeliner) Del(keys ...string) { } }) } + +func (p *memoryPipeliner) Set(key string, value []byte, expiration time.Duration) { + capturedKey := key + capturedValue := value + p.ops = append(p.ops, func() { + var expireAt time.Time + if expiration > 0 { + expireAt = time.Now().Add(expiration) + } + p.store.items[capturedKey] = &memoryStoreItem{ + value: capturedValue, + expireAt: expireAt, + } + }) +} + func (p *memoryPipeliner) SAdd(key string, members ...any) { capturedKey := key capturedMembers := make([]any, len(members)) @@ -576,6 +649,7 @@ func (p *memoryPipeliner) SAdd(key string, members ...any) { } }) } + func (p *memoryPipeliner) SRem(key string, members ...any) { capturedKey := key capturedMembers := make([]any, len(members)) @@ -615,11 +689,125 @@ func (p *memoryPipeliner) LPush(key string, values ...any) { item.value = append(stringValues, list...) }) } -func (p *memoryPipeliner) LRem(key string, count int64, value any) {} -func (p *memoryPipeliner) HSet(key string, values map[string]any) {} -func (p *memoryPipeliner) HIncrBy(key, field string, incr int64) {} -func (p *memoryPipeliner) ZAdd(key string, members map[string]float64) {} -func (p *memoryPipeliner) ZRem(key string, members ...any) {} +func (p *memoryPipeliner) LRem(key string, count int64, value any) { + capturedKey := key + capturedValue := fmt.Sprintf("%v", value) + p.ops = append(p.ops, func() { + item, ok := p.store.items[capturedKey] + if !ok || item.isExpired() { + return + } + list, ok := item.value.([]string) + if !ok { + return + } + newList := make([]string, 0, len(list)) + removed := int64(0) + for _, v := range list { + if count != 0 && v == capturedValue && (count < 0 || removed < count) { + removed++ + continue + } + newList = append(newList, v) + } + item.value = newList + }) +} +func (p *memoryPipeliner) HSet(key string, values map[string]any) { + capturedKey := key + capturedValues := make(map[string]any, len(values)) + for k, v := range values { + capturedValues[k] = v + } + p.ops = append(p.ops, func() { + item, ok := p.store.items[capturedKey] + if !ok || item.isExpired() { + item = &memoryStoreItem{value: make(map[string]string)} + p.store.items[capturedKey] = item + } + hash, ok := item.value.(map[string]string) + if !ok { + hash = make(map[string]string) + item.value = hash + } + for field, value := range capturedValues { + hash[field] = fmt.Sprintf("%v", value) + } + }) +} +func (p *memoryPipeliner) HIncrBy(key, field string, incr int64) { + capturedKey := key + capturedField := field + p.ops = append(p.ops, func() { + item, ok := p.store.items[capturedKey] + if !ok || item.isExpired() { + item = &memoryStoreItem{value: make(map[string]string)} + p.store.items[capturedKey] = item + } + hash, ok := item.value.(map[string]string) + if !ok { + hash = make(map[string]string) + item.value = hash + } + current, _ := strconv.ParseInt(hash[capturedField], 10, 64) + hash[capturedField] = strconv.FormatInt(current+incr, 10) + }) +} +func (p *memoryPipeliner) ZAdd(key string, members map[string]float64) { + capturedKey := key + capturedMembers := make(map[string]float64, len(members)) + for k, v := range members { + capturedMembers[k] = v + } + p.ops = append(p.ops, func() { + item, ok := p.store.items[capturedKey] + if !ok || item.isExpired() { + item = &memoryStoreItem{value: make(map[string]float64)} + p.store.items[capturedKey] = item + } + zset, ok := item.value.(map[string]float64) + if !ok { + zset = make(map[string]float64) + item.value = zset + } + for member, score := range capturedMembers { + zset[member] = score + } + }) +} +func (p *memoryPipeliner) ZRem(key string, members ...any) { + capturedKey := key + capturedMembers := make([]any, len(members)) + copy(capturedMembers, members) + p.ops = append(p.ops, func() { + item, ok := p.store.items[capturedKey] + if !ok || item.isExpired() { + return + } + zset, ok := item.value.(map[string]float64) + if !ok { + return + } + for _, member := range capturedMembers { + delete(zset, fmt.Sprintf("%v", member)) + } + }) +} + +func (p *memoryPipeliner) MSet(values map[string]any) { + capturedValues := make(map[string]any, len(values)) + for k, v := range values { + capturedValues[k] = v + } + p.ops = append(p.ops, func() { + for key, value := range capturedValues { + p.store.items[key] = &memoryStoreItem{ + value: value, + expireAt: time.Time{}, // Pipelined MSet 同样假设永不过期 + } + } + }) +} type memorySubscription struct { store *memoryStore diff --git a/internal/store/redis_store.go b/internal/store/redis_store.go index e16f849..e8b451d 100644 --- a/internal/store/redis_store.go +++ b/internal/store/redis_store.go @@ -75,10 +75,24 @@ func (s *RedisStore) Close() error { return s.client.Close() } +func (s *RedisStore) Expire(ctx context.Context, key string, expiration time.Duration) error { + return s.client.Expire(ctx, key, expiration).Err() +} + func (s *RedisStore) HSet(ctx context.Context, key string, values map[string]any) error { return s.client.HSet(ctx, key, values).Err() } +func (s *RedisStore) HGet(ctx context.Context, key, field string) (string, error) { + val, err := s.client.HGet(ctx, key, field).Result() + if err != nil { + if errors.Is(err, redis.Nil) { + return "", ErrNotFound + } + return "", err + } + return val, nil +} func (s *RedisStore) HGetAll(ctx context.Context, key string) (map[string]string, error) { return s.client.HGetAll(ctx, key).Result() } @@ -111,6 +125,18 @@ func (s *RedisStore) Rotate(ctx context.Context, key string) (string, error) { return val, nil } +func (s *RedisStore) MSet(ctx context.Context, values map[string]any) error { + if len(values) == 0 { + return nil + } + // Redis MSet 命令需要 [key1, value1, key2, value2, ...] 格式的切片 + pairs := make([]interface{}, 0, len(values)*2) + for k, v := range values { + pairs = append(pairs, k, v) + } + return s.client.MSet(ctx, pairs...).Err() +} + func (s *RedisStore) SAdd(ctx context.Context, key string, members ...any) error { return s.client.SAdd(ctx, key, members...).Err() } @@ -141,6 +167,13 @@ func (s *RedisStore) SRandMember(ctx context.Context, key string) (string, error return member, nil } +func (s *RedisStore) SUnionStore(ctx context.Context, destination string, keys ...string) (int64, error) { + if len(keys) == 0 { + return 0, nil + } + return s.client.SUnionStore(ctx, destination, keys...).Result() +} + func (s *RedisStore) ZAdd(ctx context.Context, key string, members map[string]float64) error { if len(members) == 0 { return nil @@ -216,6 +249,17 @@ func (p *redisPipeliner) LPush(key string, values ...any) { p.pipe.LPush(p.ctx, func (p *redisPipeliner) LRem(key string, count int64, value any) { p.pipe.LRem(p.ctx, key, count, value) } + +func (p *redisPipeliner) Set(key string, value []byte, expiration time.Duration) { + p.pipe.Set(p.ctx, key, value, expiration) +} + +func (p *redisPipeliner) MSet(values map[string]any) { + if len(values) == 0 { + return + } + p.pipe.MSet(p.ctx, values) +} func (p *redisPipeliner) SAdd(key string, members ...any) { p.pipe.SAdd(p.ctx, key, members...) } func (p *redisPipeliner) SRem(key string, members ...any) { p.pipe.SRem(p.ctx, key, members...) } func (p *redisPipeliner) ZAdd(key string, members map[string]float64) { diff --git a/internal/store/store.go b/internal/store/store.go index fc2edcf..41b3f6f 100644 --- a/internal/store/store.go +++ b/internal/store/store.go @@ -35,6 +35,8 @@ type Pipeliner interface { HIncrBy(key, field string, incr int64) // SET + MSet(values map[string]any) + Set(key string, value []byte, expiration time.Duration) SAdd(key string, members ...any) SRem(key string, members ...any) @@ -58,9 +60,11 @@ type Store interface { Del(ctx context.Context, keys ...string) error Exists(ctx context.Context, key string) (bool, error) SetNX(ctx context.Context, key string, value []byte, ttl time.Duration) (bool, error) + MSet(ctx context.Context, values map[string]any) error // HASH operations HSet(ctx context.Context, key string, values map[string]any) error + HGet(ctx context.Context, key, field string) (string, error) HGetAll(ctx context.Context, key string) (map[string]string, error) HIncrBy(ctx context.Context, key, field string, incr int64) (int64, error) HDel(ctx context.Context, key string, fields ...string) error @@ -70,6 +74,7 @@ type Store interface { LRem(ctx context.Context, key string, count int64, value any) error Rotate(ctx context.Context, key string) (string, error) LIndex(ctx context.Context, key string, index int64) (string, error) + Expire(ctx context.Context, key string, expiration time.Duration) error // SET operations SAdd(ctx context.Context, key string, members ...any) error @@ -77,6 +82,7 @@ type Store interface { SMembers(ctx context.Context, key string) ([]string, error) SRem(ctx context.Context, key string, members ...any) error SRandMember(ctx context.Context, key string) (string, error) + SUnionStore(ctx context.Context, destination string, keys ...string) (int64, error) // Pub/Sub operations Publish(ctx context.Context, channel string, message []byte) error