Fix basepool & optimize repo

This commit is contained in:
XOF
2025-11-23 22:42:58 +08:00
parent 2b0b9b67dc
commit 6c7283d51b
16 changed files with 1312 additions and 723 deletions

View File

@@ -28,12 +28,6 @@ type GeminiChannel struct {
 	httpClient *http.Client
 }
-
-// Local struct for safely extracting request metadata
-type requestMetadata struct {
-	Model  string `json:"model"`
-	Stream bool   `json:"stream"`
-}
 
 func NewGeminiChannel(logger *logrus.Logger, cfg *models.SystemSettings) *GeminiChannel {
 	transport := &http.Transport{
 		Proxy: http.ProxyFromEnvironment,
@@ -47,38 +41,50 @@ func NewGeminiChannel(logger *logrus.Logger, cfg *models.SystemSettings) *Gemini
 		logger: logger,
 		httpClient: &http.Client{
 			Transport: transport,
-			Timeout:   0,
+			Timeout:   0, // Timeout is handled by the request context
 		},
 	}
 }
 
 // TransformRequest
 func (ch *GeminiChannel) TransformRequest(c *gin.Context, requestBody []byte) (newBody []byte, modelName string, err error) {
+	modelName = ch.ExtractModel(c, requestBody)
+	return requestBody, modelName, nil
+}
+
+// ExtractModel
+func (ch *GeminiChannel) ExtractModel(c *gin.Context, bodyBytes []byte) string {
+	return ch.extractModelFromRequest(c, bodyBytes)
+}
+
+// Unified model-extraction logic: parse the request body first; on failure, fall back to the URL path.
+func (ch *GeminiChannel) extractModelFromRequest(c *gin.Context, bodyBytes []byte) string {
 	var p struct {
 		Model string `json:"model"`
 	}
-	_ = json.Unmarshal(requestBody, &p)
-	modelName = strings.TrimPrefix(p.Model, "models/")
-	if modelName == "" {
-		modelName = ch.extractModelFromPath(c.Request.URL.Path)
+	if json.Unmarshal(bodyBytes, &p) == nil && p.Model != "" {
+		return strings.TrimPrefix(p.Model, "models/")
 	}
-	return requestBody, modelName, nil
+	return ch.extractModelFromPath(c.Request.URL.Path)
 }
 
 func (ch *GeminiChannel) extractModelFromPath(path string) string {
 	parts := strings.Split(path, "/")
 	for _, part := range parts {
+		// Cover more model-name formats
 		if strings.HasPrefix(part, "gemini-") || strings.HasPrefix(part, "text-") || strings.HasPrefix(part, "embedding-") {
-			modelPart := strings.Split(part, ":")[0]
-			return modelPart
+			return strings.Split(part, ":")[0]
 		}
 	}
 	return ""
 }
 
+// IsOpenAICompatibleRequest decides via pure string operations, without depending on gin.Context.
 func (ch *GeminiChannel) IsOpenAICompatibleRequest(c *gin.Context) bool {
-	path := c.Request.URL.Path
+	return ch.isOpenAIPath(c.Request.URL.Path)
+}
+
+func (ch *GeminiChannel) isOpenAIPath(path string) bool {
 	return strings.Contains(path, "/v1/chat/completions") || strings.Contains(path, "/v1/embeddings")
 }
@@ -88,25 +94,28 @@ func (ch *GeminiChannel) ValidateKey(
 	targetURL string,
 	timeout time.Duration,
 ) *CustomErrors.APIError {
-	client := &http.Client{
-		Timeout: timeout,
-	}
+	client := &http.Client{Timeout: timeout}
 	req, err := http.NewRequestWithContext(ctx, "GET", targetURL, nil)
 	if err != nil {
-		return CustomErrors.NewAPIError(CustomErrors.ErrInternalServer, "failed to create validation request: "+err.Error())
+		return CustomErrors.NewAPIError(CustomErrors.ErrInternalServer, "failed to create validation request")
 	}
 	ch.ModifyRequest(req, apiKey)
 	resp, err := client.Do(req)
 	if err != nil {
-		return CustomErrors.NewAPIError(CustomErrors.ErrBadGateway, "failed to send validation request: "+err.Error())
+		return CustomErrors.NewAPIError(CustomErrors.ErrBadGateway, "validation request failed: "+err.Error())
 	}
 	defer resp.Body.Close()
 	if resp.StatusCode >= 200 && resp.StatusCode < 300 {
 		return nil
 	}
 	errorBody, _ := io.ReadAll(resp.Body)
 	parsedMessage := CustomErrors.ParseUpstreamError(errorBody)
 	return &CustomErrors.APIError{
 		HTTPStatus: resp.StatusCode,
 		Code:       fmt.Sprintf("UPSTREAM_%d", resp.StatusCode),
@@ -115,10 +124,6 @@ func (ch *GeminiChannel) ValidateKey(
 }
 
 func (ch *GeminiChannel) ModifyRequest(req *http.Request, apiKey *models.APIKey) {
-	// TODO: [Future Refactoring] Decouple auth logic from URL path.
-	// The authentication method (e.g., Bearer token vs. API key in query) should ideally be a property
-	// of the UpstreamEndpoint or a new "AuthProfile" entity, rather than being hardcoded based on URL patterns.
-	// This would make the channel more generic and adaptable to new upstream provider types.
 	if strings.Contains(req.URL.Path, "/v1beta/openai/") {
 		req.Header.Set("Authorization", "Bearer "+apiKey.APIKey)
 	} else {
@@ -133,24 +138,22 @@ func (ch *GeminiChannel) IsStreamRequest(c *gin.Context, bodyBytes []byte) bool
 	if strings.HasSuffix(c.Request.URL.Path, ":streamGenerateContent") {
 		return true
 	}
-	var meta requestMetadata
-	if err := json.Unmarshal(bodyBytes, &meta); err == nil {
+	var meta struct {
+		Stream bool `json:"stream"`
+	}
+	if json.Unmarshal(bodyBytes, &meta) == nil {
 		return meta.Stream
 	}
 	return false
 }
 
-func (ch *GeminiChannel) ExtractModel(c *gin.Context, bodyBytes []byte) string {
-	_, modelName, _ := ch.TransformRequest(c, bodyBytes)
-	return modelName
-}
-
+// RewritePath uses url.JoinPath to guarantee correct path joining.
 func (ch *GeminiChannel) RewritePath(basePath, originalPath string) string {
-	tempCtx := &gin.Context{Request: &http.Request{URL: &url.URL{Path: originalPath}}}
 	var rewrittenSegment string
-	if ch.IsOpenAICompatibleRequest(tempCtx) {
-		var apiEndpoint string
+	if ch.isOpenAIPath(originalPath) {
 		v1Index := strings.LastIndex(originalPath, "/v1/")
+		var apiEndpoint string
 		if v1Index != -1 {
 			apiEndpoint = originalPath[v1Index+len("/v1/"):]
 		} else {
@@ -158,69 +161,76 @@ func (ch *GeminiChannel) RewritePath(basePath, originalPath string) string {
 		}
 		rewrittenSegment = "v1beta/openai/" + apiEndpoint
 	} else {
-		tempPath := originalPath
-		if strings.HasPrefix(tempPath, "/v1/") {
-			tempPath = "/v1beta/" + strings.TrimPrefix(tempPath, "/v1/")
+		if strings.HasPrefix(originalPath, "/v1/") {
+			rewrittenSegment = "v1beta/" + strings.TrimPrefix(originalPath, "/v1/")
+		} else {
+			rewrittenSegment = strings.TrimPrefix(originalPath, "/")
 		}
-		rewrittenSegment = strings.TrimPrefix(tempPath, "/")
 	}
-	trimmedBasePath := strings.TrimSuffix(basePath, "/")
-	pathToJoin := rewrittenSegment
+	trimmedBasePath := strings.TrimSuffix(basePath, "/")
+	// Prevent duplicated version segments, e.g. when basePath is /v1beta and the rewritten segment also starts with v1beta/
 	versionPrefixes := []string{"v1beta", "v1"}
 	for _, prefix := range versionPrefixes {
-		if strings.HasSuffix(trimmedBasePath, "/"+prefix) && strings.HasPrefix(pathToJoin, prefix+"/") {
-			pathToJoin = strings.TrimPrefix(pathToJoin, prefix+"/")
+		if strings.HasSuffix(trimmedBasePath, "/"+prefix) && strings.HasPrefix(rewrittenSegment, prefix+"/") {
+			rewrittenSegment = strings.TrimPrefix(rewrittenSegment, prefix+"/")
 			break
 		}
 	}
-	finalPath, err := url.JoinPath(trimmedBasePath, pathToJoin)
+	finalPath, err := url.JoinPath(trimmedBasePath, rewrittenSegment)
 	if err != nil {
-		return trimmedBasePath + "/" + strings.TrimPrefix(pathToJoin, "/")
+		// Fall back to simple string concatenation
+		return trimmedBasePath + "/" + strings.TrimPrefix(rewrittenSegment, "/")
 	}
 	return finalPath
 }
 
 func (ch *GeminiChannel) ModifyResponse(resp *http.Response) error {
-	// This is a stub implementation; no logic is needed yet.
-	return nil
+	return nil // stub implementation
 }
 
 func (ch *GeminiChannel) HandleError(c *gin.Context, err error) {
-	// This is a stub implementation; no logic is needed yet.
+	// stub implementation
 }
 
-// ==========================================================
-// ============ Core engine of the "smart routing" ==========
-// ==========================================================
 func (ch *GeminiChannel) ProcessSmartStreamRequest(c *gin.Context, params SmartRequestParams) {
 	log := ch.logger.WithField("correlation_id", params.CorrelationID)
 	targetURL, err := url.Parse(params.UpstreamURL)
 	if err != nil {
-		log.WithError(err).Error("Failed to parse upstream URL")
-		errToJSON(c, CustomErrors.NewAPIError(CustomErrors.ErrInternalServer, "Invalid upstream URL format"))
+		log.WithError(err).Error("Invalid upstream URL")
+		errToJSON(c, CustomErrors.NewAPIError(CustomErrors.ErrInternalServer, "Invalid upstream URL"))
 		return
 	}
 	targetURL.Path = c.Request.URL.Path
 	targetURL.RawQuery = c.Request.URL.RawQuery
 	initialReq, err := http.NewRequestWithContext(c.Request.Context(), "POST", targetURL.String(), bytes.NewReader(params.RequestBody))
 	if err != nil {
-		log.WithError(err).Error("Failed to create initial smart request")
-		errToJSON(c, CustomErrors.NewAPIError(CustomErrors.ErrInternalServer, err.Error()))
+		log.WithError(err).Error("Failed to create initial smart stream request")
+		errToJSON(c, CustomErrors.NewAPIError(CustomErrors.ErrInternalServer, "Failed to create request"))
 		return
 	}
 	ch.ModifyRequest(initialReq, params.APIKey)
 	initialReq.Header.Del("Authorization")
 	resp, err := ch.httpClient.Do(initialReq)
 	if err != nil {
-		log.WithError(err).Error("Initial smart request failed")
-		errToJSON(c, CustomErrors.NewAPIError(CustomErrors.ErrBadGateway, err.Error()))
+		log.WithError(err).Error("Initial smart stream request failed")
+		errToJSON(c, CustomErrors.NewAPIError(CustomErrors.ErrBadGateway, "Request to upstream failed"))
 		return
 	}
+	defer resp.Body.Close()
 	if resp.StatusCode != http.StatusOK {
 		log.Warnf("Initial request received non-200 status: %d", resp.StatusCode)
 		standardizedResp := ch.standardizeError(resp, params.LogTruncationLimit, log)
+		defer standardizedResp.Body.Close()
 		c.Writer.WriteHeader(standardizedResp.StatusCode)
 		for key, values := range standardizedResp.Header {
 			for _, value := range values {
@@ -228,45 +238,71 @@ func (ch *GeminiChannel) ProcessSmartStreamRequest(c *gin.Context, params SmartR
 			}
 		}
 		io.Copy(c.Writer, standardizedResp.Body)
 		params.EventLogger.IsSuccess = false
 		params.EventLogger.StatusCode = resp.StatusCode
 		return
 	}
 	ch.processStreamAndRetry(c, initialReq.Header, resp.Body, params, log)
 }
 
 func (ch *GeminiChannel) processStreamAndRetry(
-	c *gin.Context, initialRequestHeaders http.Header, initialReader io.ReadCloser, params SmartRequestParams, log *logrus.Entry,
+	c *gin.Context,
+	initialRequestHeaders http.Header,
+	initialReader io.ReadCloser,
+	params SmartRequestParams,
+	log *logrus.Entry,
 ) {
 	defer initialReader.Close()
 	c.Writer.Header().Set("Content-Type", "text/event-stream; charset=utf-8")
 	c.Writer.Header().Set("Cache-Control", "no-cache")
 	c.Writer.Header().Set("Connection", "keep-alive")
 	flusher, _ := c.Writer.(http.Flusher)
 	var accumulatedText strings.Builder
 	consecutiveRetryCount := 0
 	currentReader := initialReader
 	maxRetries := params.MaxRetries
 	retryDelay := params.RetryDelay
 	log.Infof("Starting smart stream session. Max retries: %d", maxRetries)
 	for {
+		if c.Request.Context().Err() != nil {
+			log.Info("Client disconnected, stopping stream processing.")
+			return
+		}
 		var interruptionReason string
 		scanner := bufio.NewScanner(currentReader)
 		for scanner.Scan() {
+			if c.Request.Context().Err() != nil {
+				log.Info("Client disconnected during scan.")
+				return
+			}
 			line := scanner.Text()
 			if line == "" {
 				continue
 			}
 			fmt.Fprintf(c.Writer, "%s\n\n", line)
 			flusher.Flush()
 			if !strings.HasPrefix(line, "data: ") {
 				continue
 			}
 			data := strings.TrimPrefix(line, "data: ")
 			var payload models.GeminiSSEPayload
 			if err := json.Unmarshal([]byte(data), &payload); err != nil {
 				continue
 			}
 			if len(payload.Candidates) > 0 {
 				candidate := payload.Candidates[0]
 				if candidate.Content != nil && len(candidate.Content.Parts) > 0 {
@@ -285,52 +321,71 @@ func (ch *GeminiChannel) processStreamAndRetry(
 				}
 			}
 		}
 		currentReader.Close()
 		if interruptionReason == "" {
 			if err := scanner.Err(); err != nil {
 				log.WithError(err).Warn("Stream scanner encountered an error.")
 				interruptionReason = "SCANNER_ERROR"
 			} else {
-				log.Warn("Stream dropped unexpectedly without a finish reason.")
+				log.Warn("Stream connection dropped without a finish reason.")
 				interruptionReason = "CONNECTION_DROP"
 			}
 		}
 		if consecutiveRetryCount >= maxRetries {
-			log.Errorf("Retry limit exceeded. Last interruption: %s. Sending final error.", interruptionReason)
-			errData, _ := json.Marshal(map[string]interface{}{"error": map[string]interface{}{"code": http.StatusGatewayTimeout, "status": "DEADLINE_EXCEEDED", "message": fmt.Sprintf("Proxy retry limit exceeded. Last interruption: %s.", interruptionReason)}})
+			log.Errorf("Retry limit exceeded. Last interruption: %s. Sending final error to client.", interruptionReason)
+			errData, _ := json.Marshal(map[string]interface{}{
+				"error": map[string]interface{}{
+					"code":    http.StatusGatewayTimeout,
+					"status":  "DEADLINE_EXCEEDED",
+					"message": fmt.Sprintf("Proxy retry limit exceeded after multiple interruptions. Last reason: %s", interruptionReason),
+				},
+			})
 			fmt.Fprintf(c.Writer, "event: error\ndata: %s\n\n", string(errData))
 			flusher.Flush()
 			return
 		}
 		consecutiveRetryCount++
 		params.EventLogger.Retries = consecutiveRetryCount
 		log.Infof("Stream interrupted. Attempting retry %d/%d after %v.", consecutiveRetryCount, maxRetries, retryDelay)
 		time.Sleep(retryDelay)
-		retryBody, _ := buildRetryRequestBody(params.OriginalRequest, accumulatedText.String())
+		retryBody := buildRetryRequestBody(params.OriginalRequest, accumulatedText.String())
 		retryBodyBytes, _ := json.Marshal(retryBody)
-		retryReq, _ := http.NewRequestWithContext(c.Request.Context(), "POST", params.UpstreamURL, bytes.NewReader(retryBodyBytes))
+		retryReq, err := http.NewRequestWithContext(c.Request.Context(), "POST", params.UpstreamURL, bytes.NewReader(retryBodyBytes))
+		if err != nil {
+			log.WithError(err).Error("Failed to create retry request")
+			continue
+		}
 		retryReq.Header = initialRequestHeaders
 		ch.ModifyRequest(retryReq, params.APIKey)
 		retryReq.Header.Del("Authorization")
 		retryResp, err := ch.httpClient.Do(retryReq)
-		if err != nil || retryResp.StatusCode != http.StatusOK || retryResp.Body == nil {
-			if err != nil {
-				log.WithError(err).Errorf("Retry request failed.")
-			} else {
-				log.Errorf("Retry request received non-200 status: %d", retryResp.StatusCode)
-				if retryResp.Body != nil {
-					retryResp.Body.Close()
-				}
-			}
+		if err != nil {
+			log.WithError(err).Error("Retry request failed")
 			continue
 		}
+		if retryResp.StatusCode != http.StatusOK {
+			log.Errorf("Retry request received non-200 status: %d", retryResp.StatusCode)
+			retryResp.Body.Close()
+			continue
+		}
 		currentReader = retryResp.Body
 	}
 }
 
-func buildRetryRequestBody(originalBody models.GeminiRequest, accumulatedText string) (models.GeminiRequest, error) {
+// buildRetryRequestBody correctly handles inserting context into multi-turn conversations.
+func buildRetryRequestBody(originalBody models.GeminiRequest, accumulatedText string) models.GeminiRequest {
 	retryBody := originalBody
+	// Find the index of the last 'user' message
 	lastUserIndex := -1
 	for i := len(retryBody.Contents) - 1; i >= 0; i-- {
 		if retryBody.Contents[i].Role == "user" {
@@ -338,25 +393,26 @@ func buildRetryRequestBody(originalBody models.GeminiRequest, accumulatedText st
 			break
 		}
 	}
 	history := []models.GeminiContent{
 		{Role: "model", Parts: []models.Part{{Text: accumulatedText}}},
 		{Role: "user", Parts: []models.Part{{Text: SmartRetryPrompt}}},
 	}
 	if lastUserIndex != -1 {
+		// If a 'user' message was found, insert the history right after it
 		newContents := make([]models.GeminiContent, 0, len(retryBody.Contents)+2)
 		newContents = append(newContents, retryBody.Contents[:lastUserIndex+1]...)
 		newContents = append(newContents, history...)
 		newContents = append(newContents, retryBody.Contents[lastUserIndex+1:]...)
 		retryBody.Contents = newContents
 	} else {
+		// If there is no 'user' message (which should not happen in theory), just append
 		retryBody.Contents = append(retryBody.Contents, history...)
 	}
-	return retryBody, nil
-}
-
-// ===============================================
-// ====== Helper functions (inherited & hardened) ======
-// ===============================================
+	return retryBody
+}
 
 type googleAPIError struct {
 	Error struct {
@@ -397,25 +453,28 @@ func truncate(s string, n int) string {
 	return s
 }
 
-// standardizeError
 func (ch *GeminiChannel) standardizeError(resp *http.Response, truncateLimit int, log *logrus.Entry) *http.Response {
 	bodyBytes, err := io.ReadAll(resp.Body)
 	if err != nil {
 		log.WithError(err).Error("Failed to read upstream error body")
-		bodyBytes = []byte("Failed to read upstream error body: " + err.Error())
+		bodyBytes = []byte("Failed to read upstream error body")
 	}
 	resp.Body.Close()
-	log.Errorf("Upstream error body: %s", truncate(string(bodyBytes), truncateLimit))
+	log.Errorf("Upstream error: %s", truncate(string(bodyBytes), truncateLimit))
 	var standardizedPayload googleAPIError
+	// Build a standard error structure even when parsing fails
 	if json.Unmarshal(bodyBytes, &standardizedPayload) != nil || standardizedPayload.Error.Code == 0 {
 		standardizedPayload.Error.Code = resp.StatusCode
 		standardizedPayload.Error.Message = http.StatusText(resp.StatusCode)
 		standardizedPayload.Error.Status = statusToGoogleStatus(resp.StatusCode)
 		standardizedPayload.Error.Details = []interface{}{map[string]string{
-			"@type": "proxy.upstream.error",
+			"@type": "proxy.upstream.unparsed.error",
 			"body":  truncate(string(bodyBytes), truncateLimit),
 		}}
 	}
 	newBodyBytes, _ := json.Marshal(standardizedPayload)
 	newResp := &http.Response{
 		StatusCode: resp.StatusCode,
@@ -425,10 +484,13 @@ func (ch *GeminiChannel) standardizeError(resp *http.Response, truncateLimit int
 	}
 	newResp.Header.Set("Content-Type", "application/json; charset=utf-8")
 	newResp.Header.Set("Access-Control-Allow-Origin", "*")
 	return newResp
 }
 
-// errToJSON
 func errToJSON(c *gin.Context, apiErr *CustomErrors.APIError) {
+	if c.IsAborted() {
+		return
+	}
 	c.JSON(apiErr.HTTPStatus, gin.H{"error": apiErr})
 }
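
Aside (not part of the diff): a table-driven sketch of the joining cases RewritePath now handles. The package name is assumed, and a zero-value GeminiChannel suffices here because the path logic is pure string manipulation.

package channel

import "testing"

func TestRewritePath(t *testing.T) {
	ch := &GeminiChannel{}
	cases := []struct{ base, in, want string }{
		// OpenAI-compatible paths are moved under v1beta/openai/.
		{"https://host/v1beta", "/v1/chat/completions", "https://host/v1beta/openai/chat/completions"},
		// Native /v1/ paths are rewritten to v1beta/.
		{"https://host", "/v1/models/gemini-pro:streamGenerateContent", "https://host/v1beta/models/gemini-pro:streamGenerateContent"},
		// A version prefix already present in basePath is not duplicated.
		{"https://host/v1beta", "/v1beta/models", "https://host/v1beta/models"},
	}
	for _, c := range cases {
		if got := ch.RewritePath(c.base, c.in); got != c.want {
			t.Errorf("RewritePath(%q, %q) = %q, want %q", c.base, c.in, got, c.want)
		}
	}
}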

View File

@@ -16,6 +16,7 @@ type Config struct {
 	Redis         RedisConfig `mapstructure:"redis"`
 	SessionSecret string      `mapstructure:"session_secret"`
 	EncryptionKey string      `mapstructure:"encryption_key"`
+	Repository    RepositoryConfig `mapstructure:"repository"`
 }
 // DatabaseConfig stores database connection information
@@ -43,19 +44,24 @@ type RedisConfig struct {
 	DSN string `mapstructure:"dsn"`
 }
 
+type RepositoryConfig struct {
+	BasePoolTTLMinutes int `mapstructure:"base_pool_ttl_minutes"`
+	BasePoolTTIMinutes int `mapstructure:"base_pool_tti_minutes"`
+}
+
 // LoadConfig loads configuration from file and environment variables
 func LoadConfig() (*Config, error) {
 	// Set the config file name and paths
 	viper.SetConfigName("config")
 	viper.SetConfigType("yaml")
 	viper.AddConfigPath(".")
+	viper.AddConfigPath("/etc/gemini-balancer/") // for production
 	// Allow reading from environment variables
 	viper.SetEnvKeyReplacer(strings.NewReplacer(".", "_"))
 	viper.AutomaticEnv()
 	// Set defaults
-	viper.SetDefault("server.port", "8080")
+	viper.SetDefault("server.port", "9000")
 	viper.SetDefault("log.level", "info")
 	viper.SetDefault("log.format", "text")
 	viper.SetDefault("log.enable_file", false)
@@ -67,6 +73,9 @@ func LoadConfig() (*Config, error) {
 	viper.SetDefault("database.conn_max_lifetime", "1h")
 	viper.SetDefault("encryption_key", "")
+	viper.SetDefault("repository.base_pool_ttl_minutes", 60)
+	viper.SetDefault("repository.base_pool_tti_minutes", 10)
 	// Read the config file
 	if err := viper.ReadInConfig(); err != nil {
 		if _, ok := err.(viper.ConfigFileNotFoundError); !ok {
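
Aside (not part of the diff): a minimal sketch of turning the new minute-based knobs into time.Durations when wiring a TTL/TTI cache around the base pool. poolDurations is an illustrative helper, not part of the commit.

package main

import (
	"fmt"
	"time"
)

// RepositoryConfig mirrors the struct added above.
type RepositoryConfig struct {
	BasePoolTTLMinutes int // hard expiry
	BasePoolTTIMinutes int // idle (time-to-idle) expiry
}

// poolDurations converts the configured minutes into durations.
func poolDurations(rc RepositoryConfig) (ttl, tti time.Duration) {
	return time.Duration(rc.BasePoolTTLMinutes) * time.Minute,
		time.Duration(rc.BasePoolTTIMinutes) * time.Minute
}

func main() {
	ttl, tti := poolDurations(RepositoryConfig{BasePoolTTLMinutes: 60, BasePoolTTIMinutes: 10})
	fmt.Println(ttl, tti) // prints: 1h0m0s 10m0s
}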

View File

@@ -311,3 +311,8 @@ func (m *manager) checkProxyConnectivity(proxyCfg *models.ProxyConfig, timeout t
 	defer resp.Body.Close()
 	return true
 }
+
+type Manager interface {
+	AssignProxyIfNeeded(apiKey *models.APIKey) (*models.ProxyConfig, error)
+	// ... other methods exposed to external services
+}
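
Aside (not part of the diff): extracting a Manager interface makes call sites mockable. A hypothetical test double follows; the name is illustrative, and any methods behind the elided comment above would need stubs as well.

type fakeProxyManager struct {
	proxy *models.ProxyConfig
	err   error
}

// AssignProxyIfNeeded returns canned values regardless of the key.
func (f *fakeProxyManager) AssignProxyIfNeeded(_ *models.APIKey) (*models.ProxyConfig, error) {
	return f.proxy, f.err
}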

View File

@@ -44,6 +44,7 @@ var (
 	ErrGroupNotFound              = &APIError{HTTPStatus: http.StatusNotFound, Code: "GROUP_NOT_FOUND", Message: "The specified group was not found."}
 	ErrPermissionDenied           = &APIError{HTTPStatus: http.StatusForbidden, Code: "PERMISSION_DENIED", Message: "Permission denied for this operation."}
 	ErrConfigurationError         = &APIError{HTTPStatus: http.StatusInternalServerError, Code: "CONFIGURATION_ERROR", Message: "A configuration error prevents this request from being processed."}
+	ErrProxyNotAvailable          = &APIError{HTTPStatus: http.StatusNotFound, Code: "PROXY_ERROR", Message: "Required proxy is not available for this request."}
 	ErrStateConflictMasterRevoked = &APIError{HTTPStatus: http.StatusConflict, Code: "STATE_CONFLICT_MASTER_REVOKED", Message: "Cannot perform this operation on a revoked key."}
 	ErrNotFound                   = &APIError{HTTPStatus: http.StatusNotFound, Code: "NOT_FOUND", Message: "Resource not found"}
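
Aside (not part of the diff): a sketch of how the new sentinel might be surfaced with request-specific context, following the NewAPIError(base, message) pattern used elsewhere in this commit. The surrounding condition and package alias are illustrative.

if proxyCfg == nil {
	return errors.NewAPIError(errors.ErrProxyNotAvailable, "no reachable proxy configured for this key group")
}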

View File

@@ -29,9 +29,7 @@ import (
 	"gorm.io/datatypes"
 )
 
-type proxyErrorKey int
-
-const proxyErrKey proxyErrorKey = 0
+type proxyErrorContextKey struct{}
 
 type ProxyHandler struct {
 	resourceService *service.ResourceService
@@ -81,45 +79,51 @@ func (h *ProxyHandler) HandleProxy(c *gin.Context) {
 		h.handleListModelsRequest(c)
 		return
 	}
-	requestBody, err := io.ReadAll(c.Request.Body)
+	maxBodySize := int64(h.settingsManager.GetSettings().MaxRequestBodySizeMB * 1024 * 1024)
+	requestBody, err := io.ReadAll(io.LimitReader(c.Request.Body, maxBodySize))
 	if err != nil {
-		errToJSON(c, uuid.New().String(), errors.NewAPIError(errors.ErrBadRequest, "Failed to read request body"))
+		errToJSON(c, uuid.New().String(), errors.NewAPIError(errors.ErrBadRequest, "Request body too large or failed to read"))
 		return
 	}
+	c.Request.Body = io.NopCloser(bytes.NewReader(requestBody))
+	c.Request.ContentLength = int64(len(requestBody))
 	modelName := h.channel.ExtractModel(c, requestBody)
 	groupName := c.Param("group_name")
 	isPreciseRouting := groupName != ""
 	if !isPreciseRouting && modelName == "" {
-		errToJSON(c, uuid.New().String(), errors.NewAPIError(errors.ErrBadRequest, "Model not specified in the request body or URL"))
+		errToJSON(c, uuid.New().String(), errors.NewAPIError(errors.ErrBadRequest, "Model not specified in request"))
 		return
 	}
 	initialResources, err := h.getResourcesForRequest(c, modelName, groupName, isPreciseRouting)
 	if err != nil {
 		if apiErr, ok := err.(*errors.APIError); ok {
 			errToJSON(c, uuid.New().String(), apiErr)
 		} else {
-			errToJSON(c, uuid.New().String(), errors.NewAPIError(errors.ErrNoKeysAvailable, err.Error()))
+			errToJSON(c, uuid.New().String(), errors.NewAPIError(errors.ErrInternalServer, "Failed to get initial resources"))
 		}
 		return
 	}
 	finalOpConfig, err := h.groupManager.BuildOperationalConfig(initialResources.KeyGroup)
 	if err != nil {
 		h.logger.WithError(err).Error("Failed to build operational config.")
-		errToJSON(c, uuid.New().String(), errors.NewAPIError(errors.ErrInternalServer, "Failed to build operational configuration"))
+		errToJSON(c, uuid.New().String(), errors.NewAPIError(errors.ErrInternalServer, "Failed to build operational config"))
 		return
 	}
+	initialResources.RequestConfig = h.buildFinalRequestConfig(h.settingsManager.GetSettings(), initialResources.RequestConfig)
 	isOpenAICompatible := h.channel.IsOpenAICompatibleRequest(c)
 	if isOpenAICompatible {
 		h.serveTransparentProxy(c, requestBody, initialResources, finalOpConfig, modelName, groupName, isPreciseRouting)
 		return
 	}
 	isStream := h.channel.IsStreamRequest(c, requestBody)
-	systemSettings := h.settingsManager.GetSettings()
 	useSmartGateway := finalOpConfig.EnableSmartGateway != nil && *finalOpConfig.EnableSmartGateway
-	if useSmartGateway && isStream && systemSettings.EnableStreamingRetry {
+	if useSmartGateway && isStream && h.settingsManager.GetSettings().EnableStreamingRetry {
 		h.serveSmartStream(c, requestBody, initialResources, isPreciseRouting)
 	} else {
 		h.serveTransparentProxy(c, requestBody, initialResources, finalOpConfig, modelName, groupName, isPreciseRouting)
@@ -129,219 +133,307 @@ func (h *ProxyHandler) HandleProxy(c *gin.Context) {
 func (h *ProxyHandler) serveTransparentProxy(c *gin.Context, requestBody []byte, initialResources *service.RequestResources, finalOpConfig *models.KeyGroupSettings, modelName, groupName string, isPreciseRouting bool) {
 	startTime := time.Now()
 	correlationID := uuid.New().String()
 	var finalRecorder *httptest.ResponseRecorder
 	var lastUsedResources *service.RequestResources
 	var finalProxyErr *errors.APIError
 	var isSuccess bool
-	var finalPromptTokens, finalCompletionTokens int
-	var actualRetries int = 0
-	defer func() {
-		if lastUsedResources == nil {
-			h.logger.WithField("id", correlationID).Warn("No resources were used, skipping final log event.")
-			return
-		}
-		finalEvent := h.createLogEvent(c, startTime, correlationID, modelName, lastUsedResources, models.LogTypeFinal, isPreciseRouting)
-		finalEvent.RequestLog.LatencyMs = int(time.Since(startTime).Milliseconds())
-		finalEvent.RequestLog.IsSuccess = isSuccess
-		finalEvent.RequestLog.Retries = actualRetries
-		if isSuccess {
-			finalEvent.RequestLog.PromptTokens = finalPromptTokens
-			finalEvent.RequestLog.CompletionTokens = finalCompletionTokens
-		}
-		if finalRecorder != nil {
-			finalEvent.RequestLog.StatusCode = finalRecorder.Code
-		}
-		if !isSuccess {
-			if finalProxyErr != nil {
-				finalEvent.Error = finalProxyErr
-				finalEvent.RequestLog.ErrorCode = finalProxyErr.Code
-				finalEvent.RequestLog.ErrorMessage = finalProxyErr.Message
-			} else if finalRecorder != nil {
-				apiErr := errors.NewAPIErrorWithUpstream(finalRecorder.Code, fmt.Sprintf("UPSTREAM_%d", finalRecorder.Code), "Request failed after all retries.")
-				finalEvent.Error = apiErr
-				finalEvent.RequestLog.ErrorCode = apiErr.Code
-				finalEvent.RequestLog.ErrorMessage = apiErr.Message
-			}
-		}
-		eventData, err := json.Marshal(finalEvent)
-		if err != nil {
-			h.logger.WithField("id", correlationID).WithError(err).Error("Failed to marshal final log event.")
-			return
-		}
-		if err := h.store.Publish(context.Background(), models.TopicRequestFinished, eventData); err != nil {
-			h.logger.WithField("id", correlationID).WithError(err).Error("Failed to publish final log event.")
-		}
-	}()
-	var maxRetries int
-	if isPreciseRouting {
-		if finalOpConfig.MaxRetries != nil {
-			maxRetries = *finalOpConfig.MaxRetries
-		} else {
-			maxRetries = h.settingsManager.GetSettings().MaxRetries
-		}
-	} else {
-		maxRetries = h.settingsManager.GetSettings().MaxRetries
-	}
+	var finalPromptTokens, finalCompletionTokens, actualRetries int
+
+	defer h.publishFinalLogEvent(c, startTime, correlationID, modelName, lastUsedResources,
+		finalRecorder, finalProxyErr, isSuccess, finalPromptTokens, finalCompletionTokens,
+		actualRetries, isPreciseRouting)
+
+	maxRetries := h.getMaxRetries(isPreciseRouting, finalOpConfig)
 	totalAttempts := maxRetries + 1
 	for attempt := 1; attempt <= totalAttempts; attempt++ {
 		if c.Request.Context().Err() != nil {
 			h.logger.WithField("id", correlationID).Info("Client disconnected, aborting retry loop.")
 			if finalProxyErr == nil {
-				finalProxyErr = errors.NewAPIError(errors.ErrBadRequest, "Client connection closed")
+				finalProxyErr = errors.NewAPIError(errors.ErrBadRequest, "Client disconnected")
 			}
 			break
 		}
-		var currentResources *service.RequestResources
-		var err error
-		if attempt == 1 {
-			currentResources = initialResources
-		} else {
-			actualRetries = attempt - 1
-			h.logger.WithField("id", correlationID).Infof("Retrying... getting new resources for attempt %d.", attempt)
-			currentResources, err = h.getResourcesForRequest(c, modelName, groupName, isPreciseRouting)
-			if err != nil {
-				h.logger.WithField("id", correlationID).Errorf("Failed to get new resources for retry, aborting: %v", err)
-				finalProxyErr = errors.NewAPIError(errors.ErrNoKeysAvailable, "Failed to get new resources for retry")
-				break
-			}
-		}
-		finalRequestConfig := h.buildFinalRequestConfig(h.settingsManager.GetSettings(), currentResources.RequestConfig)
-		currentResources.RequestConfig = finalRequestConfig
-		lastUsedResources = currentResources
-		h.logger.WithField("id", correlationID).Infof("Attempt %d/%d with KeyID %d...", attempt, totalAttempts, currentResources.APIKey.ID)
-		var attemptErr *errors.APIError
-		var attemptIsSuccess bool
+		resources, err := h.getResourcesForAttempt(c, attempt, initialResources, modelName, groupName, isPreciseRouting, correlationID)
+		if err != nil {
+			if apiErr, ok := err.(*errors.APIError); ok {
+				finalProxyErr = apiErr
+			} else {
+				finalProxyErr = errors.NewAPIError(errors.ErrInternalServer, "Failed to get resources for retry")
+			}
+			break
+		}
+		lastUsedResources = resources
+		if attempt > 1 {
+			actualRetries = attempt - 1
+		}
+		h.logger.WithField("id", correlationID).Infof("Attempt %d/%d with KeyID %d", attempt, totalAttempts, resources.APIKey.ID)
+		recorder, attemptErr, attemptSuccess := h.executeProxyAttempt(
+			c, correlationID, requestBody, resources, isPreciseRouting, groupName,
+			&finalPromptTokens, &finalCompletionTokens,
+		)
+		finalRecorder, finalProxyErr, isSuccess = recorder, attemptErr, attemptSuccess
+		h.resourceService.ReportRequestResult(resources, isSuccess, finalProxyErr)
+		if isSuccess {
+			break
+		}
+		if h.shouldStopRetrying(attempt, totalAttempts, finalProxyErr, correlationID) {
+			break
+		}
+		h.publishRetryLogEvent(c, startTime, correlationID, modelName, resources, recorder, attemptErr, actualRetries, isPreciseRouting)
+	}
+	h.writeFinalResponse(c, correlationID, finalRecorder, finalProxyErr)
+}
+
+func (h *ProxyHandler) executeProxyAttempt(c *gin.Context, corrID string, body []byte, res *service.RequestResources, isPrecise bool, groupName string, pTokens, cTokens *int) (*httptest.ResponseRecorder, *errors.APIError, bool) {
 	recorder := httptest.NewRecorder()
-	attemptStartTime := time.Now()
+	var attemptErr *errors.APIError
+	var isSuccess bool
 	connectTimeout := time.Duration(h.settingsManager.GetSettings().ConnectTimeoutSeconds) * time.Second
 	ctx, cancel := context.WithTimeout(c.Request.Context(), connectTimeout)
 	defer cancel()
 	attemptReq := c.Request.Clone(ctx)
-	attemptReq.Body = io.NopCloser(bytes.NewReader(requestBody))
-	if currentResources.UpstreamEndpoint == nil || currentResources.UpstreamEndpoint.URL == "" {
-		h.logger.WithField("id", correlationID).Errorf("Attempt %d failed: no upstream URL in resources.", attempt)
-		isSuccess = false
-		finalProxyErr = errors.NewAPIError(errors.ErrInternalServer, "No upstream URL configured for the selected resource")
-		continue
-	}
-	h.transparentProxy.Director = func(req *http.Request) {
-		targetURL, _ := url.Parse(currentResources.UpstreamEndpoint.URL)
-		req.URL.Scheme = targetURL.Scheme
-		req.URL.Host = targetURL.Host
-		req.Host = targetURL.Host
+	attemptReq.Body = io.NopCloser(bytes.NewReader(body))
+	attemptReq.ContentLength = int64(len(body))
+
+	h.configureProxy(corrID, res, isPrecise, groupName, &attemptErr, &isSuccess, pTokens, cTokens)
+	*attemptReq = *attemptReq.WithContext(context.WithValue(attemptReq.Context(), proxyErrorContextKey{}, &attemptErr))
+
+	h.transparentProxy.ServeHTTP(recorder, attemptReq)
+	return recorder, attemptErr, isSuccess
+}
+
+func (h *ProxyHandler) configureProxy(corrID string, res *service.RequestResources, isPrecise bool, groupName string, attemptErr **errors.APIError, isSuccess *bool, pTokens, cTokens *int) {
+	h.transparentProxy.Director = func(r *http.Request) {
+		targetURL, _ := url.Parse(res.UpstreamEndpoint.URL)
+		r.URL.Scheme, r.URL.Host, r.Host = targetURL.Scheme, targetURL.Host, targetURL.Host
 		var pureClientPath string
-		if isPreciseRouting {
-			proxyPrefix := "/proxy/" + groupName
-			pureClientPath = strings.TrimPrefix(req.URL.Path, proxyPrefix)
+		if isPrecise {
+			pureClientPath = strings.TrimPrefix(r.URL.Path, "/proxy/"+groupName)
 		} else {
-			pureClientPath = req.URL.Path
+			pureClientPath = r.URL.Path
 		}
-		finalPath := h.channel.RewritePath(targetURL.Path, pureClientPath)
-		req.URL.Path = finalPath
-		h.logger.WithFields(logrus.Fields{
-			"correlation_id":    correlationID,
-			"attempt":           attempt,
-			"key_id":            currentResources.APIKey.ID,
-			"base_upstream_url": currentResources.UpstreamEndpoint.URL,
-			"final_request_url": req.URL.String(),
-		}).Infof("Director constructed final upstream request URL.")
-		req.Header.Del("Authorization")
-		h.channel.ModifyRequest(req, currentResources.APIKey)
-		req.Header.Set("X-Correlation-ID", correlationID)
-		*req = *req.WithContext(context.WithValue(req.Context(), proxyErrKey, &attemptErr))
+		r.URL.Path = h.channel.RewritePath(targetURL.Path, pureClientPath)
+		r.Header.Del("Authorization")
+		h.channel.ModifyRequest(r, res.APIKey)
+		r.Header.Set("X-Correlation-ID", corrID)
 	}
 	transport := h.transparentProxy.Transport.(*http.Transport)
-	if currentResources.ProxyConfig != nil {
-		proxyURLStr := fmt.Sprintf("%s://%s", currentResources.ProxyConfig.Protocol, currentResources.ProxyConfig.Address)
-		proxyURL, err := url.Parse(proxyURLStr)
-		if err == nil {
+	if res.ProxyConfig != nil {
+		proxyURLStr := fmt.Sprintf("%s://%s", res.ProxyConfig.Protocol, res.ProxyConfig.Address)
+		if proxyURL, err := url.Parse(proxyURLStr); err == nil {
 			transport.Proxy = http.ProxyURL(proxyURL)
-		} else {
-			transport.Proxy = http.ProxyFromEnvironment
 		}
 	} else {
 		transport.Proxy = http.ProxyFromEnvironment
 	}
-	h.transparentProxy.ModifyResponse = func(resp *http.Response) error {
-		defer resp.Body.Close()
-		var reader io.ReadCloser
-		var err error
-		isGzipped := resp.Header.Get("Content-Encoding") == "gzip"
-		if isGzipped {
-			reader, err = gzip.NewReader(resp.Body)
+
+	h.transparentProxy.ModifyResponse = h.createModifyResponseFunc(attemptErr, isSuccess, pTokens, cTokens)
+}
+
+func (h *ProxyHandler) createModifyResponseFunc(attemptErr **errors.APIError, isSuccess *bool, pTokens, cTokens *int) func(*http.Response) error {
+	return func(resp *http.Response) error {
+		var reader io.ReadCloser = resp.Body
+		if resp.Header.Get("Content-Encoding") == "gzip" {
+			gzReader, err := gzip.NewReader(resp.Body)
 			if err != nil {
 				h.logger.WithError(err).Error("Failed to create gzip reader")
-				reader = resp.Body
 			} else {
+				reader = gzReader
 				resp.Header.Del("Content-Encoding")
 			}
-			defer reader.Close()
-		} else {
-			reader = resp.Body
 		}
+		defer reader.Close()
 		bodyBytes, err := io.ReadAll(reader)
 		if err != nil {
-			attemptErr = errors.NewAPIError(errors.ErrBadGateway, "Failed to read upstream response: "+err.Error())
-			resp.Body = io.NopCloser(bytes.NewReader([]byte(attemptErr.Message)))
+			*attemptErr = errors.NewAPIError(errors.ErrBadGateway, "Failed to read upstream response")
+			resp.Body = io.NopCloser(bytes.NewReader([]byte{}))
 			return nil
 		}
 		if resp.StatusCode < 400 {
-			attemptIsSuccess = true
-			finalPromptTokens, finalCompletionTokens = extractUsage(bodyBytes)
+			*isSuccess = true
+			*pTokens, *cTokens = extractUsage(bodyBytes)
 		} else {
 			parsedMsg := errors.ParseUpstreamError(bodyBytes)
-			attemptErr = errors.NewAPIErrorWithUpstream(resp.StatusCode, fmt.Sprintf("UPSTREAM_%d", resp.StatusCode), parsedMsg)
+			*attemptErr = errors.NewAPIErrorWithUpstream(resp.StatusCode, fmt.Sprintf("UPSTREAM_%d", resp.StatusCode), parsedMsg)
 		}
 		resp.Body = io.NopCloser(bytes.NewReader(bodyBytes))
 		return nil
 	}
-	h.transparentProxy.ServeHTTP(recorder, attemptReq)
-	finalRecorder = recorder
-	finalProxyErr = attemptErr
-	isSuccess = attemptIsSuccess
-	h.resourceService.ReportRequestResult(currentResources, isSuccess, finalProxyErr)
-	if isSuccess {
-		break
-	}
-	isUnretryableError := false
-	if finalProxyErr != nil {
-		if errors.IsUnretryableRequestError(finalProxyErr.Message) {
-			isUnretryableError = true
-			h.logger.WithField("id", correlationID).Warnf("Attempt %d failed with unretryable request error. Aborting retries. Message: %s", attempt, finalProxyErr.Message)
-		}
-	}
-	if attempt >= totalAttempts || isUnretryableError {
-		break
-	}
-	retryEvent := h.createLogEvent(c, startTime, correlationID, modelName, currentResources, models.LogTypeRetry, isPreciseRouting)
-	retryEvent.LatencyMs = int(time.Since(attemptStartTime).Milliseconds())
-	retryEvent.IsSuccess = false
-	retryEvent.StatusCode = recorder.Code
-	retryEvent.Retries = actualRetries
-	if attemptErr != nil {
-		retryEvent.Error = attemptErr
-		retryEvent.ErrorCode = attemptErr.Code
-		retryEvent.ErrorMessage = attemptErr.Message
-	}
-	eventData, _ := json.Marshal(retryEvent)
-	_ = h.store.Publish(context.Background(), models.TopicRequestFinished, eventData)
-	}
-	if finalRecorder != nil {
-		bodyBytes := finalRecorder.Body.Bytes()
-		c.Writer.Header().Set("Content-Length", fmt.Sprintf("%d", len(bodyBytes)))
-		for k, v := range finalRecorder.Header() {
-			if strings.ToLower(k) != "content-length" {
-				c.Writer.Header()[k] = v
-			}
-		}
-		c.Writer.WriteHeader(finalRecorder.Code)
-		c.Writer.Write(finalRecorder.Body.Bytes())
-	} else {
-		errToJSON(c, correlationID, finalProxyErr)
-	}
-}
+}
+
+func (h *ProxyHandler) transparentProxyErrorHandler(rw http.ResponseWriter, r *http.Request, err error) {
+	corrID := r.Header.Get("X-Correlation-ID")
+	log := h.logger.WithField("id", corrID)
+	log.Errorf("Transparent proxy encountered an error: %v", err)
+	errPtr, ok := r.Context().Value(proxyErrorContextKey{}).(**errors.APIError)
+	if !ok || errPtr == nil {
+		log.Error("FATAL: proxyErrorContextKey not found in context for error handler.")
+		defaultErr := errors.NewAPIError(errors.ErrBadGateway, "An unexpected proxy error occurred")
+		writeErrorToResponse(rw, defaultErr)
+		return
+	}
+	if *errPtr == nil {
+		if errors.IsClientNetworkError(err) {
+			*errPtr = errors.NewAPIError(errors.ErrBadRequest, "Client connection closed")
+		} else {
+			*errPtr = errors.NewAPIError(errors.ErrBadGateway, err.Error())
+		}
+	}
+	writeErrorToResponse(rw, *errPtr)
+}
+
+func (h *ProxyHandler) getResourcesForAttempt(c *gin.Context, attempt int, initialResources *service.RequestResources, modelName, groupName string, isPreciseRouting bool, correlationID string) (*service.RequestResources, error) {
+	if attempt == 1 {
+		return initialResources, nil
+	}
+	h.logger.WithField("id", correlationID).Infof("Retrying... getting new resources for attempt %d.", attempt)
+	resources, err := h.getResourcesForRequest(c, modelName, groupName, isPreciseRouting)
+	if err != nil {
+		return nil, err
+	}
+	finalRequestConfig := h.buildFinalRequestConfig(h.settingsManager.GetSettings(), resources.RequestConfig)
+	resources.RequestConfig = finalRequestConfig
+	return resources, nil
+}
+
+func (h *ProxyHandler) shouldStopRetrying(attempt, totalAttempts int, err *errors.APIError, correlationID string) bool {
+	if attempt >= totalAttempts {
+		return true
+	}
+	if err != nil && errors.IsUnretryableRequestError(err.Message) {
+		h.logger.WithField("id", correlationID).Warnf("Attempt failed with unretryable request error. Aborting retries. Message: %s", err.Message)
+		return true
+	}
+	return false
+}
+
+func (h *ProxyHandler) writeFinalResponse(c *gin.Context, corrID string, rec *httptest.ResponseRecorder, apiErr *errors.APIError) {
+	if rec != nil {
+		for k, v := range rec.Header() {
+			c.Writer.Header()[k] = v
+		}
+		c.Writer.WriteHeader(rec.Code)
+		c.Writer.Write(rec.Body.Bytes())
+	} else if apiErr != nil {
+		errToJSON(c, corrID, apiErr)
+	} else {
+		errToJSON(c, corrID, errors.NewAPIError(errors.ErrInternalServer, "An unknown error occurred"))
+	}
+}
+
+func (h *ProxyHandler) publishFinalLogEvent(c *gin.Context, startTime time.Time, corrID, modelName string, res *service.RequestResources, rec *httptest.ResponseRecorder, finalErr *errors.APIError, isSuccess bool, pTokens, cTokens, retries int, isPrecise bool) {
+	if res == nil {
+		h.logger.WithField("id", corrID).Warn("No resources were used, skipping final log event.")
+		return
+	}
+	event := h.createLogEvent(c, startTime, corrID, modelName, res, models.LogTypeFinal, isPrecise)
+	event.RequestLog.LatencyMs = int(time.Since(startTime).Milliseconds())
+	event.RequestLog.IsSuccess = isSuccess
+	event.RequestLog.Retries = retries
+	if isSuccess {
+		event.RequestLog.PromptTokens, event.RequestLog.CompletionTokens = pTokens, cTokens
+	}
+	if rec != nil {
+		event.RequestLog.StatusCode = rec.Code
+	}
+	if !isSuccess {
+		errToLog := finalErr
+		if errToLog == nil && rec != nil {
+			errToLog = errors.NewAPIErrorWithUpstream(rec.Code, fmt.Sprintf("UPSTREAM_%d", rec.Code), "Request failed after all retries.")
+		}
+		if errToLog != nil {
+			event.Error = errToLog
+			event.RequestLog.ErrorCode, event.RequestLog.ErrorMessage = errToLog.Code, errToLog.Message
+		}
+	}
+	eventData, err := json.Marshal(event)
+	if err != nil {
+		h.logger.WithField("id", corrID).WithError(err).Error("Failed to marshal log event")
+		return
+	}
+	if err := h.store.Publish(context.Background(), models.TopicRequestFinished, eventData); err != nil {
+		h.logger.WithField("id", corrID).WithError(err).Error("Failed to publish log event")
+	}
+}
+
+func (h *ProxyHandler) publishRetryLogEvent(c *gin.Context, startTime time.Time, corrID, modelName string, res *service.RequestResources, rec *httptest.ResponseRecorder, attemptErr *errors.APIError, retries int, isPrecise bool) {
+	retryEvent := h.createLogEvent(c, startTime, corrID, modelName, res, models.LogTypeRetry, isPrecise)
+	retryEvent.RequestLog.LatencyMs = int(time.Since(startTime).Milliseconds())
+	retryEvent.RequestLog.IsSuccess = false
+	retryEvent.RequestLog.StatusCode = rec.Code
+	retryEvent.RequestLog.Retries = retries
+	if attemptErr != nil {
+		retryEvent.Error = attemptErr
+		retryEvent.RequestLog.ErrorCode, retryEvent.RequestLog.ErrorMessage = attemptErr.Code, attemptErr.Message
+	}
+	eventData, err := json.Marshal(retryEvent)
+	if err != nil {
+		h.logger.WithField("id", corrID).WithError(err).Error("Failed to marshal retry log event")
+		return
+	}
+	if err := h.store.Publish(context.Background(), models.TopicRequestFinished, eventData); err != nil {
+		h.logger.WithField("id", corrID).WithError(err).Error("Failed to publish retry log event")
+	}
+}
+
+func (h *ProxyHandler) buildFinalRequestConfig(globalSettings *models.SystemSettings, groupConfig *models.RequestConfig) *models.RequestConfig {
+	finalConfig := &models.RequestConfig{
+		CustomHeaders:         make(datatypes.JSONMap),
+		EnableStreamOptimizer: globalSettings.EnableStreamOptimizer,
+		StreamMinDelay:        globalSettings.StreamMinDelay,
+		StreamMaxDelay:        globalSettings.StreamMaxDelay,
+		StreamShortTextThresh: globalSettings.StreamShortTextThresh,
+		StreamLongTextThresh:  globalSettings.StreamLongTextThresh,
+		StreamChunkSize:       globalSettings.StreamChunkSize,
+		EnableFakeStream:      globalSettings.EnableFakeStream,
+		FakeStreamInterval:    globalSettings.FakeStreamInterval,
+	}
+	for k, v := range globalSettings.CustomHeaders {
+		finalConfig.CustomHeaders[k] = v
+	}
+	if groupConfig == nil {
+		return finalConfig
+	}
+	groupConfigJSON, err := json.Marshal(groupConfig)
+	if err != nil {
+		h.logger.WithError(err).Error("Failed to marshal group request config for merging.")
+		return finalConfig
+	}
+	if err := json.Unmarshal(groupConfigJSON, finalConfig); err != nil {
+		h.logger.WithError(err).Error("Failed to unmarshal group request config for merging.")
+	}
+	return finalConfig
+}
+
+func writeErrorToResponse(rw http.ResponseWriter, apiErr *errors.APIError) {
+	if writer, ok := rw.(interface{ Written() bool }); ok && writer.Written() {
+		return
+	}
+	rw.Header().Set("Content-Type", "application/json; charset=utf-8")
+	rw.WriteHeader(apiErr.HTTPStatus)
+	json.NewEncoder(rw).Encode(gin.H{"error": apiErr})
+}
 
 func (h *ProxyHandler) serveSmartStream(c *gin.Context, requestBody []byte, resources *service.RequestResources, isPreciseRouting bool) {
 	startTime := time.Now()
 	correlationID := uuid.New().String()
@@ -349,7 +441,7 @@ func (h *ProxyHandler) serveSmartStream(c *gin.Context, requestBody []byte, reso
 	log.Info("Smart Gateway activated for streaming request.")
 	var originalRequest models.GeminiRequest
 	if err := json.Unmarshal(requestBody, &originalRequest); err != nil {
-		errToJSON(c, correlationID, errors.NewAPIError(errors.ErrInvalidJSON, "Smart Gateway failed: Request body is not a valid Gemini native format. Error: "+err.Error()))
+		errToJSON(c, correlationID, errors.NewAPIError(errors.ErrInvalidJSON, "Invalid request format for Smart Gateway"))
 		return
 	}
 	systemSettings := h.settingsManager.GetSettings()
@@ -360,8 +452,14 @@ func (h *ProxyHandler) serveSmartStream(c *gin.Context, requestBody []byte, reso
 		if c.Writer.Status() > 0 {
 			requestFinishedEvent.StatusCode = c.Writer.Status()
 		}
-		eventData, _ := json.Marshal(requestFinishedEvent)
-		_ = h.store.Publish(context.Background(), models.TopicRequestFinished, eventData)
+		eventData, err := json.Marshal(requestFinishedEvent)
+		if err != nil {
+			h.logger.WithField("id", correlationID).WithError(err).Error("Failed to marshal final log event for smart stream")
+			return
+		}
+		if err := h.store.Publish(context.Background(), models.TopicRequestFinished, eventData); err != nil {
+			h.logger.WithField("id", correlationID).WithError(err).Error("Failed to publish final log event for smart stream")
+		}
 	}()
 	params := channel.SmartRequestParams{
 		CorrelationID: correlationID,
@@ -378,30 +476,6 @@ func (h *ProxyHandler) serveSmartStream(c *gin.Context, requestBody []byte, reso
 	h.channel.ProcessSmartStreamRequest(c, params)
 }
 
-func (h *ProxyHandler) transparentProxyErrorHandler(rw http.ResponseWriter, r *http.Request, err error) {
-	correlationID := r.Header.Get("X-Correlation-ID")
-	h.logger.WithField("id", correlationID).Errorf("Transparent proxy error: %v", err)
-	proxyErrPtr, exists := r.Context().Value(proxyErrKey).(**errors.APIError)
-	if !exists || proxyErrPtr == nil {
-		h.logger.WithField("id", correlationID).Error("FATAL: proxyErrorKey not found in context for error handler.")
-		return
-	}
-	if errors.IsClientNetworkError(err) {
-		*proxyErrPtr = errors.NewAPIError(errors.ErrBadRequest, "Client connection closed")
-	} else {
-		*proxyErrPtr = errors.NewAPIError(errors.ErrBadGateway, err.Error())
-	}
-	if _, ok := rw.(*httptest.ResponseRecorder); ok {
-		return
-	}
-	if writer, ok := rw.(interface{ Written() bool }); ok {
-		if writer.Written() {
-			return
-		}
-	}
-	rw.WriteHeader((*proxyErrPtr).HTTPStatus)
-}
 
 func (h *ProxyHandler) createLogEvent(c *gin.Context, startTime time.Time, corrID, modelName string, res *service.RequestResources, logType models.LogType, isPreciseRouting bool) *models.RequestFinishedEvent {
 	event := &models.RequestFinishedEvent{
 		RequestLog: models.RequestLog{
@@ -456,12 +530,14 @@ func (h *ProxyHandler) getResourcesForRequest(c *gin.Context, modelName string,
 	}
 	if isPreciseRouting {
 		return h.resourceService.GetResourceFromGroup(c.Request.Context(), authToken, groupName)
-	} else {
-		return h.resourceService.GetResourceFromBasePool(c.Request.Context(), authToken, modelName)
 	}
+	return h.resourceService.GetResourceFromBasePool(c.Request.Context(), authToken, modelName)
 }
 
 func errToJSON(c *gin.Context, corrID string, apiErr *errors.APIError) {
+	if c.IsAborted() {
+		return
+	}
 	c.JSON(apiErr.HTTPStatus, gin.H{
 		"error":          apiErr,
 		"correlation_id": corrID,
@@ -471,7 +547,7 @@ func errToJSON(c *gin.Context, corrID string, apiErr *errors.APIError) {
type bufferPool struct{} type bufferPool struct{}
func (b *bufferPool) Get() []byte { return make([]byte, 32*1024) } func (b *bufferPool) Get() []byte { return make([]byte, 32*1024) }
func (b *bufferPool) Put(bytes []byte) {} func (b *bufferPool) Put(_ []byte) {}
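The Get/Put pair above satisfies httputil.ReverseProxy's BufferPool interface, but it allocates a fresh 32 KB slice on every Get and discards it on Put. A sync.Pool-backed variant would actually reuse buffers; the sketch below is illustrative only and is not what this commit implements:

package proxyutil

import (
	"net/http/httputil"
	"sync"
)

// pooledBuffers is a hypothetical sync.Pool-backed alternative to the
// commit's bufferPool, reusing 32 KB copy buffers across requests.
type pooledBuffers struct{ pool sync.Pool }

func newPooledBuffers() *pooledBuffers {
	return &pooledBuffers{pool: sync.Pool{
		New: func() any { return make([]byte, 32*1024) },
	}}
}

func (p *pooledBuffers) Get() []byte  { return p.pool.Get().([]byte) }
func (p *pooledBuffers) Put(b []byte) { p.pool.Put(b) }

// Compile-time check against the stdlib interface ReverseProxy accepts.
var _ httputil.BufferPool = (*pooledBuffers)(nil)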
func extractUsage(body []byte) (promptTokens int, completionTokens int) { func extractUsage(body []byte) (promptTokens int, completionTokens int) {
var data struct { var data struct {
@@ -486,34 +562,11 @@ func extractUsage(body []byte) (promptTokens int, completionTokens int) {
return 0, 0 return 0, 0
} }
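The hunk above elides most of extractUsage; only the fallback return is visible. For orientation, here is a minimal reconstruction assuming the body carries a Gemini-style usageMetadata object (field names taken from the Gemini API; the repo's actual struct may differ, and encoding/json is already imported in this file):

// extractUsageSketch is a hypothetical reconstruction of the elided body.
func extractUsageSketch(body []byte) (promptTokens, completionTokens int) {
	var data struct {
		UsageMetadata struct {
			PromptTokenCount     int `json:"promptTokenCount"`
			CandidatesTokenCount int `json:"candidatesTokenCount"`
		} `json:"usageMetadata"`
	}
	if err := json.Unmarshal(body, &data); err != nil {
		return 0, 0
	}
	return data.UsageMetadata.PromptTokenCount, data.UsageMetadata.CandidatesTokenCount
}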
func (h *ProxyHandler) buildFinalRequestConfig(globalSettings *models.SystemSettings, groupConfig *models.RequestConfig) *models.RequestConfig { func (h *ProxyHandler) getMaxRetries(isPreciseRouting bool, finalOpConfig *models.KeyGroupSettings) int {
customHeadersJSON, _ := json.Marshal(globalSettings.CustomHeaders) if isPreciseRouting && finalOpConfig.MaxRetries != nil {
var customHeadersMap datatypes.JSONMap return *finalOpConfig.MaxRetries
_ = json.Unmarshal(customHeadersJSON, &customHeadersMap)
finalConfig := &models.RequestConfig{
CustomHeaders: customHeadersMap,
EnableStreamOptimizer: globalSettings.EnableStreamOptimizer,
StreamMinDelay: globalSettings.StreamMinDelay,
StreamMaxDelay: globalSettings.StreamMaxDelay,
StreamShortTextThresh: globalSettings.StreamShortTextThresh,
StreamLongTextThresh: globalSettings.StreamLongTextThresh,
StreamChunkSize: globalSettings.StreamChunkSize,
EnableFakeStream: globalSettings.EnableFakeStream,
FakeStreamInterval: globalSettings.FakeStreamInterval,
} }
if groupConfig == nil { return h.settingsManager.GetSettings().MaxRetries
return finalConfig
}
groupConfigJSON, err := json.Marshal(groupConfig)
if err != nil {
h.logger.WithError(err).Error("Failed to marshal group request config for merging.")
return finalConfig
}
if err := json.Unmarshal(groupConfigJSON, finalConfig); err != nil {
h.logger.WithError(err).Error("Failed to unmarshal group request config for merging.")
return finalConfig
}
return finalConfig
} }
func (h *ProxyHandler) handleListModelsRequest(c *gin.Context) { func (h *ProxyHandler) handleListModelsRequest(c *gin.Context) {
View File
@@ -77,3 +77,9 @@ type APIKeyDetails struct {
CooldownUntil *time.Time `json:"cooldown_until"` CooldownUntil *time.Time `json:"cooldown_until"`
EncryptedKey string EncryptedKey string
} }
// SettingsManager defines the abstract interface for the system settings manager.
type SettingsManager interface {
GetSettings() *SystemSettings
}
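Extracting this one-method interface lets call sites such as getMaxRetries depend on models.SettingsManager instead of the concrete settings.SettingsManager, so tests can stub it. A minimal in-package test double under that assumption (names hypothetical):

// fakeSettings is a hypothetical stub satisfying models.SettingsManager.
type fakeSettings struct{ settings *SystemSettings }

func (f *fakeSettings) GetSettings() *SystemSettings { return f.settings }

var _ SettingsManager = (*fakeSettings)(nil)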
View File
@@ -11,6 +11,7 @@ type SystemSettings struct {
BlacklistThreshold int `json:"blacklist_threshold" default:"3" name:"拉黑阈值" category:"密钥设置" desc:"一个Key连续失败多少次后进入冷却状态。"` BlacklistThreshold int `json:"blacklist_threshold" default:"3" name:"拉黑阈值" category:"密钥设置" desc:"一个Key连续失败多少次后进入冷却状态。"`
KeyCooldownMinutes int `json:"key_cooldown_minutes" default:"10" name:"密钥冷却时长(分钟)" category:"密钥设置" desc:"一个Key进入冷却状态后需要等待的时间单位为分钟。"` KeyCooldownMinutes int `json:"key_cooldown_minutes" default:"10" name:"密钥冷却时长(分钟)" category:"密钥设置" desc:"一个Key进入冷却状态后需要等待的时间单位为分钟。"`
LogFlushIntervalSeconds int `json:"log_flush_interval_seconds" default:"10" name:"日志刷新间隔(秒)" category:"日志设置" desc:"异步日志写入数据库的间隔时间(秒)。"` LogFlushIntervalSeconds int `json:"log_flush_interval_seconds" default:"10" name:"日志刷新间隔(秒)" category:"日志设置" desc:"异步日志写入数据库的间隔时间(秒)。"`
MaxRequestBodySizeMB int `json:"max_request_body_size_mb" default:"10" name:"最大请求体大小 (MB)" category:"请求设置" desc:"允许代理接收的最大请求体大小单位为MB。超过此大小的请求将被拒绝。"`
PollingStrategy PollingStrategy `json:"polling_strategy" default:"random" name:"全局轮询策略" category:"调度设置" desc:"智能聚合模式下,从所有可用密钥中选择一个的默认策略。可选值: sequential(顺序), random(随机), weighted(加权)。"` PollingStrategy PollingStrategy `json:"polling_strategy" default:"random" name:"全局轮询策略" category:"调度设置" desc:"智能聚合模式下,从所有可用密钥中选择一个的默认策略。可选值: sequential(顺序), random(随机), weighted(加权)。"`
@@ -41,6 +42,10 @@ type SystemSettings struct {
MaxLoginAttempts int `json:"max_login_attempts" default:"5" name:"最大登录失败次数" category:"安全设置" desc:"在一个IP被封禁前允许的连续登录失败次数。"` MaxLoginAttempts int `json:"max_login_attempts" default:"5" name:"最大登录失败次数" category:"安全设置" desc:"在一个IP被封禁前允许的连续登录失败次数。"`
IPBanDurationMinutes int `json:"ip_ban_duration_minutes" default:"15" name:"IP封禁时长(分钟)" category:"安全设置" desc:"IP被封禁的时长单位为分钟。"` IPBanDurationMinutes int `json:"ip_ban_duration_minutes" default:"15" name:"IP封禁时长(分钟)" category:"安全设置" desc:"IP被封禁的时长单位为分钟。"`
// BasePool related configuration
// BasePoolTTLMinutes int `json:"base_pool_ttl_minutes" default:"30" name:"基础资源池最大生存时间(分钟)" category:"基础资源池" desc:"一个动态构建的基础资源池BasePool在Redis中的最大生存时间。到期后即使仍在活跃使用也会被强制重建。"`
// BasePoolTTIMinutes int `json:"base_pool_tti_minutes" default:"10" name:"基础资源池空闲超时(分钟)" category:"基础资源池" desc:"一个基础资源池BasePool在连续无请求后自动销毁的空闲等待时间。"`
// Smart Gateway // Smart Gateway
LogTruncationLimit int `json:"log_truncation_limit" default:"8000" name:"日志截断长度" category:"日志设置" desc:"在日志中记录上游响应或错误时保留的最大字符数。0表示不截断。"` LogTruncationLimit int `json:"log_truncation_limit" default:"8000" name:"日志截断长度" category:"日志设置" desc:"在日志中记录上游响应或错误时保留的最大字符数。0表示不截断。"`
EnableSmartGateway bool `json:"enable_smart_gateway" default:"false" name:"启用智能网关" category:"代理设置" desc:"开启后,系统将对流式请求进行智能中断续传、错误标准化等优化。关闭后,系统将作为一个纯净、无干扰的透明代理。"` EnableSmartGateway bool `json:"enable_smart_gateway" default:"false" name:"启用智能网关" category:"代理设置" desc:"开启后,系统将对流式请求进行智能中断续传、错误标准化等优化。关闭后,系统将作为一个纯净、无干扰的透明代理。"`
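MaxRequestBodySizeMB is new in this commit, and the enforcement site is not part of this diff. A common Gin-side approach — shown here as an assumed sketch, not necessarily this repo's wiring — is http.MaxBytesReader, which makes reads past the limit fail:

package middleware

import (
	"net/http"

	"github.com/gin-gonic/gin"
)

// LimitBodySize is a hypothetical middleware enforcing max_request_body_size_mb.
func LimitBodySize(maxMB int) gin.HandlerFunc {
	limit := int64(maxMB) * 1024 * 1024
	return func(c *gin.Context) {
		c.Request.Body = http.MaxBytesReader(c.Writer, c.Request.Body, limit)
		c.Next()
	}
}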
View File
@@ -1,4 +1,4 @@
// Filename: internal/repository/key_cache.go // Filename: internal/repository/key_cache.go (final version)
package repository package repository
import ( import (
@@ -9,6 +9,7 @@ import (
"strconv" "strconv"
) )
// --- Redis key constant definitions ---
const ( const (
KeyGroup = "group:%d:keys:active" KeyGroup = "group:%d:keys:active"
KeyDetails = "key:%d:details" KeyDetails = "key:%d:details"
@@ -23,13 +24,16 @@ const (
BasePoolRandomCooldown = "basepool:%s:keys:random:cooldown" BasePoolRandomCooldown = "basepool:%s:keys:random:cooldown"
) )
// LoadAllKeysToStore loads all keys and mappings from the database and fully rebuilds the Redis cache.
func (r *gormKeyRepository) LoadAllKeysToStore(ctx context.Context) error { func (r *gormKeyRepository) LoadAllKeysToStore(ctx context.Context) error {
r.logger.Info("Starting to load all keys and associations into cache, including polling structures...") r.logger.Info("Starting full cache rebuild for all keys and polling structures.")
var allMappings []*models.GroupAPIKeyMapping var allMappings []*models.GroupAPIKeyMapping
if err := r.db.Preload("APIKey").Find(&allMappings).Error; err != nil { if err := r.db.Preload("APIKey").Find(&allMappings).Error; err != nil {
return fmt.Errorf("failed to load all mappings with APIKeys from DB: %w", err) return fmt.Errorf("failed to load mappings with preloaded APIKeys: %w", err)
} }
// 1. Batch-decrypt all keys involved
keyMap := make(map[uint]*models.APIKey) keyMap := make(map[uint]*models.APIKey)
for _, m := range allMappings { for _, m := range allMappings {
if m.APIKey != nil { if m.APIKey != nil {
@@ -41,16 +45,16 @@ func (r *gormKeyRepository) LoadAllKeysToStore(ctx context.Context) error {
keysToDecrypt = append(keysToDecrypt, *k) keysToDecrypt = append(keysToDecrypt, *k)
} }
if err := r.decryptKeys(keysToDecrypt); err != nil { if err := r.decryptKeys(keysToDecrypt); err != nil {
r.logger.WithError(err).Error("Critical error during cache preload: batch decryption failed.") r.logger.WithError(err).Error("Batch decryption failed during cache rebuild.")
// Even if decryption fails, keep loading the unencrypted or already-decrypted parts
} }
decryptedKeyMap := make(map[uint]models.APIKey) decryptedKeyMap := make(map[uint]models.APIKey)
for _, k := range keysToDecrypt { for _, k := range keysToDecrypt {
decryptedKeyMap[k.ID] = k decryptedKeyMap[k.ID] = k
} }
activeKeysByGroup := make(map[uint][]*models.GroupAPIKeyMapping) // 2. Clear every group's old polling structures
pipe := r.store.Pipeline(context.Background()) pipe := r.store.Pipeline(ctx)
detailsToSet := make(map[string][]byte)
var allGroups []*models.KeyGroup var allGroups []*models.KeyGroup
if err := r.db.Find(&allGroups).Error; err == nil { if err := r.db.Find(&allGroups).Error; err == nil {
for _, group := range allGroups { for _, group := range allGroups {
@@ -63,26 +67,41 @@ func (r *gormKeyRepository) LoadAllKeysToStore(ctx context.Context) error {
) )
} }
} else { } else {
r.logger.WithError(err).Error("Failed to get all groups for cache cleanup") r.logger.WithError(err).Error("Failed to get groups for cache cleanup; proceeding with rebuild.")
} }
// 3. Prepare the batch update data
activeKeysByGroup := make(map[uint][]*models.GroupAPIKeyMapping)
detailsToSet := make(map[string]any)
for _, mapping := range allMappings { for _, mapping := range allMappings {
if mapping.APIKey == nil { if mapping.APIKey == nil {
continue continue
} }
decryptedKey, ok := decryptedKeyMap[mapping.APIKeyID] decryptedKey, ok := decryptedKeyMap[mapping.APIKeyID]
if !ok { if !ok {
continue continue // Skip keys that failed to decrypt
} }
// Prepare the MSet payload for KeyDetails and KeyMapping
keyJSON, _ := json.Marshal(decryptedKey) keyJSON, _ := json.Marshal(decryptedKey)
detailsToSet[fmt.Sprintf(KeyDetails, decryptedKey.ID)] = keyJSON detailsToSet[fmt.Sprintf(KeyDetails, decryptedKey.ID)] = keyJSON
mappingJSON, _ := json.Marshal(mapping) mappingJSON, _ := json.Marshal(mapping)
detailsToSet[fmt.Sprintf(KeyMapping, mapping.KeyGroupID, decryptedKey.ID)] = mappingJSON detailsToSet[fmt.Sprintf(KeyMapping, mapping.KeyGroupID, decryptedKey.ID)] = mappingJSON
if mapping.Status == models.StatusActive { if mapping.Status == models.StatusActive {
activeKeysByGroup[mapping.KeyGroupID] = append(activeKeysByGroup[mapping.KeyGroupID], mapping) activeKeysByGroup[mapping.KeyGroupID] = append(activeKeysByGroup[mapping.KeyGroupID], mapping)
} }
} }
// 4. Bulk-write the detail and mapping caches via MSet
if len(detailsToSet) > 0 {
if err := r.store.MSet(ctx, detailsToSet); err != nil {
r.logger.WithError(err).Error("Failed to MSet key details and mappings during cache rebuild.")
}
}
// 5. Rebuild every group's polling structures inside the pipeline
for groupID, activeMappings := range activeKeysByGroup { for groupID, activeMappings := range activeKeysByGroup {
if len(activeMappings) == 0 { if len(activeMappings) == 0 {
continue continue
@@ -101,22 +120,19 @@ func (r *gormKeyRepository) LoadAllKeysToStore(ctx context.Context) error {
pipe.SAdd(fmt.Sprintf(KeyGroup, groupID), activeKeyIDs...) pipe.SAdd(fmt.Sprintf(KeyGroup, groupID), activeKeyIDs...)
pipe.LPush(fmt.Sprintf(KeyGroupSequential, groupID), activeKeyIDs...) pipe.LPush(fmt.Sprintf(KeyGroupSequential, groupID), activeKeyIDs...)
pipe.SAdd(fmt.Sprintf(KeyGroupRandomMain, groupID), activeKeyIDs...) pipe.SAdd(fmt.Sprintf(KeyGroupRandomMain, groupID), activeKeyIDs...)
go r.store.ZAdd(context.Background(), fmt.Sprintf(KeyGroupLRU, groupID), lruMembers) pipe.ZAdd(fmt.Sprintf(KeyGroupLRU, groupID), lruMembers)
} }
// 6. Execute the pipeline
if err := pipe.Exec(); err != nil { if err := pipe.Exec(); err != nil {
return fmt.Errorf("failed to execute pipeline for cache rebuild: %w", err) return fmt.Errorf("pipeline execution for polling structures failed: %w", err)
}
for key, value := range detailsToSet {
if err := r.store.Set(context.Background(), key, value, 0); err != nil {
r.logger.WithError(err).Warnf("Failed to set key detail in cache: %s", key)
}
} }
r.logger.Info("Cache rebuild complete, including all polling structures.") r.logger.Info("Full cache rebuild completed successfully.")
return nil return nil
} }
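The rebuild now batches all detail/mapping writes into one MSet and all polling-structure writes into one pipeline, replacing the old per-key Set loop and the fire-and-forget ZAdd goroutine. The shape, expressed against go-redis purely for illustration (the repo goes through its own store.Store abstraction):

package cache

import (
	"context"

	"github.com/redis/go-redis/v9"
)

// rebuildSketch mirrors LoadAllKeysToStore: MSet for the K-V details,
// then a single pipeline for every group's polling structures.
func rebuildSketch(ctx context.Context, rdb *redis.Client,
	details map[string]any, activeIDsBySetKey map[string][]any) error {
	if len(details) > 0 {
		if err := rdb.MSet(ctx, details).Err(); err != nil {
			return err
		}
	}
	pipe := rdb.Pipeline()
	for setKey, ids := range activeIDsBySetKey {
		pipe.SAdd(ctx, setKey, ids...)
		pipe.LPush(ctx, setKey+":sequential", ids...)
	}
	_, err := pipe.Exec(ctx)
	return err
}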
// updateStoreCacheForKey updates the detail cache (K-V) for a single APIKey.
func (r *gormKeyRepository) updateStoreCacheForKey(key *models.APIKey) error { func (r *gormKeyRepository) updateStoreCacheForKey(key *models.APIKey) error {
if err := r.decryptKey(key); err != nil { if err := r.decryptKey(key); err != nil {
return fmt.Errorf("failed to decrypt key %d for cache update: %w", key.ID, err) return fmt.Errorf("failed to decrypt key %d for cache update: %w", key.ID, err)
@@ -128,78 +144,101 @@ func (r *gormKeyRepository) updateStoreCacheForKey(key *models.APIKey) error {
return r.store.Set(context.Background(), fmt.Sprintf(KeyDetails, key.ID), keyJSON, 0) return r.store.Set(context.Background(), fmt.Sprintf(KeyDetails, key.ID), keyJSON, 0)
} }
// removeStoreCacheForKey removes an APIKey from every cache structure.
func (r *gormKeyRepository) removeStoreCacheForKey(ctx context.Context, key *models.APIKey) error { func (r *gormKeyRepository) removeStoreCacheForKey(ctx context.Context, key *models.APIKey) error {
groupIDs, err := r.GetGroupsForKey(ctx, key.ID) groupIDs, err := r.GetGroupsForKey(ctx, key.ID)
if err != nil { if err != nil {
r.logger.Warnf("failed to get groups for key %d to clean up cache lists: %v", key.ID, err) r.logger.WithError(err).Warnf("Failed to get groups for key %d during cache removal, cleanup may be partial.", key.ID)
} }
pipe := r.store.Pipeline(ctx) pipe := r.store.Pipeline(ctx)
pipe.Del(fmt.Sprintf(KeyDetails, key.ID)) pipe.Del(fmt.Sprintf(KeyDetails, key.ID))
keyIDStr := strconv.FormatUint(uint64(key.ID), 10)
for _, groupID := range groupIDs { for _, groupID := range groupIDs {
pipe.Del(fmt.Sprintf(KeyMapping, groupID, key.ID)) pipe.Del(fmt.Sprintf(KeyMapping, groupID, key.ID))
keyIDStr := strconv.FormatUint(uint64(key.ID), 10)
pipe.SRem(fmt.Sprintf(KeyGroup, groupID), keyIDStr) pipe.SRem(fmt.Sprintf(KeyGroup, groupID), keyIDStr)
pipe.LRem(fmt.Sprintf(KeyGroupSequential, groupID), 0, keyIDStr) pipe.LRem(fmt.Sprintf(KeyGroupSequential, groupID), 0, keyIDStr)
pipe.SRem(fmt.Sprintf(KeyGroupRandomMain, groupID), keyIDStr) pipe.SRem(fmt.Sprintf(KeyGroupRandomMain, groupID), keyIDStr)
pipe.SRem(fmt.Sprintf(KeyGroupRandomCooldown, groupID), keyIDStr) pipe.SRem(fmt.Sprintf(KeyGroupRandomCooldown, groupID), keyIDStr)
go r.store.ZRem(context.Background(), fmt.Sprintf(KeyGroupLRU, groupID), keyIDStr) pipe.ZRem(fmt.Sprintf(KeyGroupLRU, groupID), keyIDStr)
} }
return pipe.Exec() return pipe.Exec()
} }
// updateStoreCacheForMapping atomically updates all related cache structures based on a single mapping's status.
func (r *gormKeyRepository) updateStoreCacheForMapping(mapping *models.GroupAPIKeyMapping) error { func (r *gormKeyRepository) updateStoreCacheForMapping(mapping *models.GroupAPIKeyMapping) error {
pipe := r.store.Pipeline(context.Background()) keyIDStr := strconv.FormatUint(uint64(mapping.APIKeyID), 10)
activeKeyListKey := fmt.Sprintf("group:%d:keys:active", mapping.KeyGroupID) groupID := mapping.KeyGroupID
pipe.LRem(activeKeyListKey, 0, mapping.APIKeyID) ctx := context.Background()
pipe := r.store.Pipeline(ctx)
// Remove from all polling structures unconditionally first, guaranteeing a clean state
pipe.SRem(fmt.Sprintf(KeyGroup, groupID), keyIDStr)
pipe.LRem(fmt.Sprintf(KeyGroupSequential, groupID), 0, keyIDStr)
pipe.SRem(fmt.Sprintf(KeyGroupRandomMain, groupID), keyIDStr)
pipe.SRem(fmt.Sprintf(KeyGroupRandomCooldown, groupID), keyIDStr)
pipe.ZRem(fmt.Sprintf(KeyGroupLRU, groupID), keyIDStr)
// If the new status is Active, re-add the key to all polling structures
if mapping.Status == models.StatusActive { if mapping.Status == models.StatusActive {
pipe.LPush(activeKeyListKey, mapping.APIKeyID) pipe.SAdd(fmt.Sprintf(KeyGroup, groupID), keyIDStr)
pipe.LPush(fmt.Sprintf(KeyGroupSequential, groupID), keyIDStr)
pipe.SAdd(fmt.Sprintf(KeyGroupRandomMain, groupID), keyIDStr)
var score float64
if mapping.LastUsedAt != nil {
score = float64(mapping.LastUsedAt.UnixMilli())
} }
pipe.ZAdd(fmt.Sprintf(KeyGroupLRU, groupID), map[string]float64{keyIDStr: score})
}
// Regardless of status, refresh the mapping's K-V detail cache
mappingJSON, err := json.Marshal(mapping)
if err != nil {
return fmt.Errorf("failed to marshal mapping: %w", err)
}
pipe.Set(fmt.Sprintf(KeyMapping, groupID, mapping.APIKeyID), mappingJSON, 0)
return pipe.Exec() return pipe.Exec()
} }
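The ZSET score written above is LastUsedAt in Unix milliseconds, with 0 for a never-used key, so the weighted strategy's ZRange 0 0 always surfaces the least-recently-used member. A tiny go-redis illustration of that ordering (client and key names assumed):

package cache

import (
	"context"
	"time"

	"github.com/redis/go-redis/v9"
)

// lruPickSketch: the lowest score is least recently used, so key "41"
// (never used, score 0) is returned ahead of key "42" (just used).
func lruPickSketch(ctx context.Context, rdb *redis.Client) (string, error) {
	if err := rdb.ZAdd(ctx, "group:1:keys:lru",
		redis.Z{Score: 0, Member: "41"},
		redis.Z{Score: float64(time.Now().UnixMilli()), Member: "42"},
	).Err(); err != nil {
		return "", err
	}
	ids, err := rdb.ZRange(ctx, "group:1:keys:lru", 0, 0).Result()
	if err != nil || len(ids) == 0 {
		return "", err
	}
	return ids[0], nil // "41"
}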
// HandleCacheUpdateEventBatch atomically updates the caches for multiple mappings in one batch.
func (r *gormKeyRepository) HandleCacheUpdateEventBatch(ctx context.Context, mappings []*models.GroupAPIKeyMapping) error { func (r *gormKeyRepository) HandleCacheUpdateEventBatch(ctx context.Context, mappings []*models.GroupAPIKeyMapping) error {
if len(mappings) == 0 { if len(mappings) == 0 {
return nil return nil
} }
groupUpdates := make(map[uint]struct {
ToAdd []interface{} pipe := r.store.Pipeline(ctx)
ToRemove []interface{}
})
for _, mapping := range mappings { for _, mapping := range mappings {
keyIDStr := strconv.FormatUint(uint64(mapping.APIKeyID), 10) keyIDStr := strconv.FormatUint(uint64(mapping.APIKeyID), 10)
update, ok := groupUpdates[mapping.KeyGroupID] groupID := mapping.KeyGroupID
if !ok {
update = struct { // For every mapping in the batch, run the full, correct remove-then-add logic
ToAdd []interface{} pipe.SRem(fmt.Sprintf(KeyGroup, groupID), keyIDStr)
ToRemove []interface{} pipe.LRem(fmt.Sprintf(KeyGroupSequential, groupID), 0, keyIDStr)
}{} pipe.SRem(fmt.Sprintf(KeyGroupRandomMain, groupID), keyIDStr)
} pipe.SRem(fmt.Sprintf(KeyGroupRandomCooldown, groupID), keyIDStr)
pipe.ZRem(fmt.Sprintf(KeyGroupLRU, groupID), keyIDStr)
if mapping.Status == models.StatusActive { if mapping.Status == models.StatusActive {
update.ToRemove = append(update.ToRemove, keyIDStr) pipe.SAdd(fmt.Sprintf(KeyGroup, groupID), keyIDStr)
update.ToAdd = append(update.ToAdd, keyIDStr) pipe.LPush(fmt.Sprintf(KeyGroupSequential, groupID), keyIDStr)
} else { pipe.SAdd(fmt.Sprintf(KeyGroupRandomMain, groupID), keyIDStr)
update.ToRemove = append(update.ToRemove, keyIDStr)
var score float64
if mapping.LastUsedAt != nil {
score = float64(mapping.LastUsedAt.UnixMilli())
} }
groupUpdates[mapping.KeyGroupID] = update pipe.ZAdd(fmt.Sprintf(KeyGroupLRU, groupID), map[string]float64{keyIDStr: score})
} }
pipe := r.store.Pipeline(context.Background())
var pipelineError error mappingJSON, _ := json.Marshal(mapping) // Ignore individual marshal errors in the batch so most updates still succeed
for groupID, updates := range groupUpdates { pipe.Set(fmt.Sprintf(KeyMapping, groupID, mapping.APIKeyID), mappingJSON, 0)
activeKeyListKey := fmt.Sprintf("group:%d:keys:active", groupID)
if len(updates.ToRemove) > 0 {
for _, keyID := range updates.ToRemove {
pipe.LRem(activeKeyListKey, 0, keyID)
} }
}
if len(updates.ToAdd) > 0 { return pipe.Exec()
pipe.LPush(activeKeyListKey, updates.ToAdd...)
}
}
if err := pipe.Exec(); err != nil {
pipelineError = fmt.Errorf("redis pipeline execution failed: %w", err)
}
return pipelineError
} }
View File
@@ -23,7 +23,6 @@ func (r *gormKeyRepository) AddKeys(keys []models.APIKey) ([]models.APIKey, erro
keyHashes := make([]string, len(keys)) keyHashes := make([]string, len(keys))
keyValueToHashMap := make(map[string]string) keyValueToHashMap := make(map[string]string)
for i, k := range keys { for i, k := range keys {
// All incoming keys must have plaintext APIKey
if k.APIKey == "" { if k.APIKey == "" {
return nil, fmt.Errorf("cannot add key at index %d: plaintext APIKey is empty", i) return nil, fmt.Errorf("cannot add key at index %d: plaintext APIKey is empty", i)
} }
@@ -35,7 +34,6 @@ func (r *gormKeyRepository) AddKeys(keys []models.APIKey) ([]models.APIKey, erro
var finalKeys []models.APIKey var finalKeys []models.APIKey
err := r.db.Transaction(func(tx *gorm.DB) error { err := r.db.Transaction(func(tx *gorm.DB) error {
var existingKeys []models.APIKey var existingKeys []models.APIKey
// [MODIFIED] Query by hash to find existing keys.
if err := tx.Unscoped().Where("api_key_hash IN ?", keyHashes).Find(&existingKeys).Error; err != nil { if err := tx.Unscoped().Where("api_key_hash IN ?", keyHashes).Find(&existingKeys).Error; err != nil {
return err return err
} }
@@ -69,24 +67,20 @@ func (r *gormKeyRepository) AddKeys(keys []models.APIKey) ([]models.APIKey, erro
} }
} }
if len(keysToCreate) > 0 { if len(keysToCreate) > 0 {
// [MODIFIED] Create now only provides encrypted data and hash.
if err := tx.Clauses(clause.OnConflict{DoNothing: true}, clause.Returning{}).Create(&keysToCreate).Error; err != nil { if err := tx.Clauses(clause.OnConflict{DoNothing: true}, clause.Returning{}).Create(&keysToCreate).Error; err != nil {
return err return err
} }
} }
// [MODIFIED] Final select uses hashes to retrieve all relevant keys.
if err := tx.Where("api_key_hash IN ?", keyHashes).Find(&finalKeys).Error; err != nil { if err := tx.Where("api_key_hash IN ?", keyHashes).Find(&finalKeys).Error; err != nil {
return err return err
} }
// [CRITICAL] Decrypt all keys before returning them to the service layer.
return r.decryptKeys(finalKeys) return r.decryptKeys(finalKeys)
}) })
return finalKeys, err return finalKeys, err
} }
func (r *gormKeyRepository) Update(key *models.APIKey) error { func (r *gormKeyRepository) Update(key *models.APIKey) error {
// [CRITICAL] Before saving, check if the plaintext APIKey field was populated.
// This indicates a potential change that needs to be re-encrypted.
if key.APIKey != "" { if key.APIKey != "" {
encryptedKey, err := r.crypto.Encrypt(key.APIKey) encryptedKey, err := r.crypto.Encrypt(key.APIKey)
if err != nil { if err != nil {
@@ -98,16 +92,16 @@ func (r *gormKeyRepository) Update(key *models.APIKey) error {
key.APIKeyHash = hex.EncodeToString(hash[:]) key.APIKeyHash = hex.EncodeToString(hash[:])
} }
err := r.executeTransactionWithRetry(func(tx *gorm.DB) error { err := r.executeTransactionWithRetry(func(tx *gorm.DB) error {
// GORM automatically ignores `key.APIKey` because of the `gorm:"-"` tag.
return tx.Save(key).Error return tx.Save(key).Error
}) })
if err != nil { if err != nil {
return err return err
} }
// For the cache update, we need the plaintext. Decrypt if it's not already populated.
if err := r.decryptKey(key); err != nil { if err := r.decryptKey(key); err != nil {
r.logger.Warnf("DB updated key ID %d, but decryption for cache failed: %v", key.ID, err) r.logger.Warnf("DB updated key ID %d, but decryption for cache failed: %v", key.ID, err)
return nil // Continue without cache update if decryption fails. return nil
} }
if err := r.updateStoreCacheForKey(key); err != nil { if err := r.updateStoreCacheForKey(key); err != nil {
r.logger.Warnf("DB updated key ID %d, but cache update failed: %v", key.ID, err) r.logger.Warnf("DB updated key ID %d, but cache update failed: %v", key.ID, err)
@@ -192,7 +186,6 @@ func (r *gormKeyRepository) GetKeysByIDs(ids []uint) ([]models.APIKey, error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
// [CRITICAL] Decrypt before returning.
return keys, r.decryptKeys(keys) return keys, r.decryptKeys(keys)
} }
View File
@@ -1,4 +1,4 @@
// Filename: internal/repository/key_selector.go (final fix after review) // Filename: internal/repository/key_selector.go
package repository package repository
import ( import (
@@ -19,38 +19,39 @@ import (
const ( const (
CacheTTL = 5 * time.Minute CacheTTL = 5 * time.Minute
EmptyPoolPlaceholder = "EMPTY_POOL"
EmptyCacheTTL = 1 * time.Minute EmptyCacheTTL = 1 * time.Minute
) )
// SelectOneActiveKey efficiently selects an available API key from the cache by the given polling strategy. // SelectOneActiveKey selects an available API key from a single key group's cache by the given polling strategy.
func (r *gormKeyRepository) SelectOneActiveKey(ctx context.Context, group *models.KeyGroup) (*models.APIKey, *models.GroupAPIKeyMapping, error) { func (r *gormKeyRepository) SelectOneActiveKey(ctx context.Context, group *models.KeyGroup) (*models.APIKey, *models.GroupAPIKeyMapping, error) {
if group == nil {
return nil, nil, fmt.Errorf("group cannot be nil")
}
var keyIDStr string var keyIDStr string
var err error var err error
switch group.PollingStrategy { switch group.PollingStrategy {
case models.StrategySequential: case models.StrategySequential:
sequentialKey := fmt.Sprintf(KeyGroupSequential, group.ID) sequentialKey := fmt.Sprintf(KeyGroupSequential, group.ID)
keyIDStr, err = r.store.Rotate(ctx, sequentialKey) keyIDStr, err = r.store.Rotate(ctx, sequentialKey)
case models.StrategyWeighted: case models.StrategyWeighted:
lruKey := fmt.Sprintf(KeyGroupLRU, group.ID) lruKey := fmt.Sprintf(KeyGroupLRU, group.ID)
results, zerr := r.store.ZRange(ctx, lruKey, 0, 0) results, zerr := r.store.ZRange(ctx, lruKey, 0, 0)
if zerr == nil && len(results) > 0 { if zerr == nil {
if len(results) > 0 {
keyIDStr = results[0] keyIDStr = results[0]
} else {
zerr = gorm.ErrRecordNotFound
}
} }
err = zerr err = zerr
case models.StrategyRandom: case models.StrategyRandom:
mainPoolKey := fmt.Sprintf(KeyGroupRandomMain, group.ID) mainPoolKey := fmt.Sprintf(KeyGroupRandomMain, group.ID)
cooldownPoolKey := fmt.Sprintf(KeyGroupRandomCooldown, group.ID) cooldownPoolKey := fmt.Sprintf(KeyGroupRandomCooldown, group.ID)
keyIDStr, err = r.store.PopAndCycleSetMember(ctx, mainPoolKey, cooldownPoolKey) keyIDStr, err = r.store.PopAndCycleSetMember(ctx, mainPoolKey, cooldownPoolKey)
default:
default: // When the strategy is default or unspecified, use the basic random strategy
activeKeySetKey := fmt.Sprintf(KeyGroup, group.ID) activeKeySetKey := fmt.Sprintf(KeyGroup, group.ID)
keyIDStr, err = r.store.SRandMember(ctx, activeKeySetKey) keyIDStr, err = r.store.SRandMember(ctx, activeKeySetKey)
} }
if err != nil { if err != nil {
if errors.Is(err, store.ErrNotFound) || errors.Is(err, gorm.ErrRecordNotFound) { if errors.Is(err, store.ErrNotFound) || errors.Is(err, gorm.ErrRecordNotFound) {
return nil, nil, gorm.ErrRecordNotFound return nil, nil, gorm.ErrRecordNotFound
@@ -58,39 +59,44 @@ func (r *gormKeyRepository) SelectOneActiveKey(ctx context.Context, group *model
r.logger.WithError(err).Errorf("Failed to select key for group %d with strategy %s", group.ID, group.PollingStrategy) r.logger.WithError(err).Errorf("Failed to select key for group %d with strategy %s", group.ID, group.PollingStrategy)
return nil, nil, err return nil, nil, err
} }
if keyIDStr == "" { if keyIDStr == "" {
return nil, nil, gorm.ErrRecordNotFound return nil, nil, gorm.ErrRecordNotFound
} }
keyID, parseErr := strconv.ParseUint(keyIDStr, 10, 64)
keyID, _ := strconv.ParseUint(keyIDStr, 10, 64) if parseErr != nil {
r.logger.WithError(parseErr).Errorf("Invalid key ID format in group %d cache: %s", group.ID, keyIDStr)
return nil, nil, fmt.Errorf("invalid key ID in cache: %w", parseErr)
}
apiKey, mapping, err := r.getKeyDetailsFromCache(ctx, uint(keyID), group.ID) apiKey, mapping, err := r.getKeyDetailsFromCache(ctx, uint(keyID), group.ID)
if err != nil { if err != nil {
r.logger.WithError(err).Warnf("Cache inconsistency: Failed to get details for selected key ID %d", keyID) r.logger.WithError(err).Warnf("Cache inconsistency for key ID %d in group %d", keyID, group.ID)
return nil, nil, err return nil, nil, err
} }
if group.PollingStrategy == models.StrategyWeighted { if group.PollingStrategy == models.StrategyWeighted {
go r.UpdateKeyUsageTimestamp(context.Background(), group.ID, uint(keyID)) go func() {
updateCtx, cancel := r.withTimeout(5 * time.Second)
defer cancel()
r.UpdateKeyUsageTimestamp(updateCtx, group.ID, uint(keyID))
}()
} }
return apiKey, mapping, nil return apiKey, mapping, nil
} }
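The random strategy delegates to the store primitive PopAndCycleSetMember, whose implementation is not in this diff. Judging from the main/cooldown key pair, its presumed semantics are: pop a random member from the main set, park it in cooldown, and recycle cooldown back into main once main drains. A non-atomic sketch of that presumption (a real implementation would need a Lua script to be race-free):

package cache

import (
	"context"

	"github.com/redis/go-redis/v9"
)

// popAndCycleSketch illustrates presumed semantics only; it is not the
// repo's implementation.
func popAndCycleSketch(ctx context.Context, rdb *redis.Client, main, cooldown string) (string, error) {
	member, err := rdb.SPop(ctx, main).Result()
	if err == redis.Nil {
		// Main pool drained: recycle the cooldown pool and retry once.
		if err := rdb.Rename(ctx, cooldown, main).Err(); err != nil {
			return "", err // both pools are empty
		}
		member, err = rdb.SPop(ctx, main).Result()
	}
	if err != nil {
		return "", err
	}
	return member, rdb.SAdd(ctx, cooldown, member).Err()
}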
// SelectOneActiveKeyFromBasePool is the new poller designed for smart aggregation mode // SelectOneActiveKeyFromBasePool selects an available key from the smart aggregation pool
func (r *gormKeyRepository) SelectOneActiveKeyFromBasePool(ctx context.Context, pool *BasePool) (*models.APIKey, *models.KeyGroup, error) { func (r *gormKeyRepository) SelectOneActiveKeyFromBasePool(ctx context.Context, pool *BasePool) (*models.APIKey, *models.KeyGroup, error) {
poolID := generatePoolID(pool.CandidateGroups) if pool == nil || len(pool.CandidateGroups) == 0 {
return nil, nil, fmt.Errorf("invalid or empty base pool configuration")
}
poolID := r.generatePoolID(pool.CandidateGroups)
log := r.logger.WithField("pool_id", poolID) log := r.logger.WithField("pool_id", poolID)
if err := r.ensureBasePoolCacheExists(ctx, pool, poolID); err != nil { if err := r.ensureBasePoolCacheExists(ctx, pool, poolID); err != nil {
log.WithError(err).Error("Failed to ensure BasePool cache exists.") if !errors.Is(err, gorm.ErrRecordNotFound) {
log.WithError(err).Error("Failed to ensure BasePool cache exists")
}
return nil, nil, err return nil, nil, err
} }
var keyIDStr string var keyIDStr string
var err error var err error
switch pool.PollingStrategy { switch pool.PollingStrategy {
case models.StrategySequential: case models.StrategySequential:
sequentialKey := fmt.Sprintf(BasePoolSequential, poolID) sequentialKey := fmt.Sprintf(BasePoolSequential, poolID)
@@ -98,8 +104,12 @@ func (r *gormKeyRepository) SelectOneActiveKeyFromBasePool(ctx context.Context,
case models.StrategyWeighted: case models.StrategyWeighted:
lruKey := fmt.Sprintf(BasePoolLRU, poolID) lruKey := fmt.Sprintf(BasePoolLRU, poolID)
results, zerr := r.store.ZRange(ctx, lruKey, 0, 0) results, zerr := r.store.ZRange(ctx, lruKey, 0, 0)
if zerr == nil && len(results) > 0 { if zerr == nil {
if len(results) > 0 {
keyIDStr = results[0] keyIDStr = results[0]
} else {
zerr = gorm.ErrRecordNotFound
}
} }
err = zerr err = zerr
case models.StrategyRandom: case models.StrategyRandom:
@@ -107,13 +117,12 @@ func (r *gormKeyRepository) SelectOneActiveKeyFromBasePool(ctx context.Context,
cooldownPoolKey := fmt.Sprintf(BasePoolRandomCooldown, poolID) cooldownPoolKey := fmt.Sprintf(BasePoolRandomCooldown, poolID)
keyIDStr, err = r.store.PopAndCycleSetMember(ctx, mainPoolKey, cooldownPoolKey) keyIDStr, err = r.store.PopAndCycleSetMember(ctx, mainPoolKey, cooldownPoolKey)
default: default:
log.Warnf("Default polling strategy triggered inside selection. This should be rare.") log.Warnf("Unknown polling strategy '%s'. Using sequential as fallback.", pool.PollingStrategy)
sequentialKey := fmt.Sprintf(BasePoolSequential, poolID) sequentialKey := fmt.Sprintf(BasePoolSequential, poolID)
keyIDStr, err = r.store.LIndex(ctx, sequentialKey, 0) keyIDStr, err = r.store.Rotate(ctx, sequentialKey)
} }
if err != nil { if err != nil {
if errors.Is(err, store.ErrNotFound) { if errors.Is(err, store.ErrNotFound) || errors.Is(err, gorm.ErrRecordNotFound) {
return nil, nil, gorm.ErrRecordNotFound return nil, nil, gorm.ErrRecordNotFound
} }
log.WithError(err).Errorf("Failed to select key from BasePool with strategy %s", pool.PollingStrategy) log.WithError(err).Errorf("Failed to select key from BasePool with strategy %s", pool.PollingStrategy)
@@ -122,73 +131,224 @@ func (r *gormKeyRepository) SelectOneActiveKeyFromBasePool(ctx context.Context,
if keyIDStr == "" { if keyIDStr == "" {
return nil, nil, gorm.ErrRecordNotFound return nil, nil, gorm.ErrRecordNotFound
} }
keyID, _ := strconv.ParseUint(keyIDStr, 10, 64) go func() {
bgCtx, cancel := r.withTimeout(5 * time.Second)
for _, group := range pool.CandidateGroups { defer cancel()
apiKey, mapping, cacheErr := r.getKeyDetailsFromCache(ctx, uint(keyID), group.ID) r.refreshBasePoolHeartbeat(bgCtx, poolID)
if cacheErr == nil && apiKey != nil && mapping != nil { }()
keyID, parseErr := strconv.ParseUint(keyIDStr, 10, 64)
if parseErr != nil {
log.WithError(parseErr).Errorf("Invalid key ID format in BasePool cache: %s", keyIDStr)
return nil, nil, fmt.Errorf("invalid key ID in cache: %w", parseErr)
}
keyToGroupMapKey := fmt.Sprintf("basepool:%s:key_to_group", poolID)
groupIDStr, err := r.store.HGet(ctx, keyToGroupMapKey, keyIDStr)
if err != nil {
log.WithError(err).Errorf("Cache inconsistency: KeyID %d found in pool but not in key-to-group map", keyID)
return nil, nil, errors.New("cache inconsistency: key has no origin group mapping")
}
groupID, parseErr := strconv.ParseUint(groupIDStr, 10, 64)
if parseErr != nil {
log.WithError(parseErr).Errorf("Invalid group ID format in key-to-group map for key %d: %s", keyID, groupIDStr)
return nil, nil, errors.New("cache inconsistency: invalid group id in mapping")
}
apiKey, _, err := r.getKeyDetailsFromCache(ctx, uint(keyID), uint(groupID))
if err != nil {
log.WithError(err).Warnf("Cache inconsistency: Failed to get details for key %d in mapped group %d", keyID, groupID)
return nil, nil, err
}
var originGroup *models.KeyGroup
for _, g := range pool.CandidateGroups {
if g.ID == uint(groupID) {
originGroup = g
break
}
}
if originGroup == nil {
log.Errorf("Logic error: Mapped GroupID %d not found in pool's candidate groups list", groupID)
return nil, nil, errors.New("cache inconsistency: mapped group not in candidate list")
}
if pool.PollingStrategy == models.StrategyWeighted { if pool.PollingStrategy == models.StrategyWeighted {
go r.updateKeyUsageTimestampForPool(context.Background(), poolID, uint(keyID)) go func() {
} bgCtx, cancel := r.withTimeout(5 * time.Second)
return apiKey, group, nil defer cancel()
r.updateKeyUsageTimestampForPool(bgCtx, poolID, uint(keyID))
}()
} }
return apiKey, originGroup, nil
} }
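Where the old code probed every candidate group's cache for the selected key, the basepool:<id>:key_to_group hash written at build time reduces origin-group resolution to one HGET. An illustrative go-redis sketch (same assumed imports as the sketches above, plus fmt):

// originGroupSketch resolves a key's origin group in O(1) via the
// key_to_group hash, replacing the old per-group probing loop.
func originGroupSketch(ctx context.Context, rdb *redis.Client, poolID, keyID string) (string, error) {
	mapKey := fmt.Sprintf("basepool:%s:key_to_group", poolID)
	groupID, err := rdb.HGet(ctx, mapKey, keyID).Result()
	if err == redis.Nil {
		return "", fmt.Errorf("cache inconsistency: key %s has no origin group", keyID)
	}
	return groupID, err
}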
log.Errorf("Cache inconsistency: Selected KeyID %d from BasePool but could not find its origin group.", keyID) // ensureBasePoolCacheExists creates or verifies the BasePool's Redis cache structures on demand.
return nil, nil, errors.New("cache inconsistency: selected key has no origin group")
}
// ensureBasePoolCacheExists builds the BasePool's Redis structures on demand
func (r *gormKeyRepository) ensureBasePoolCacheExists(ctx context.Context, pool *BasePool, poolID string) error { func (r *gormKeyRepository) ensureBasePoolCacheExists(ctx context.Context, pool *BasePool, poolID string) error {
listKey := fmt.Sprintf(BasePoolSequential, poolID) heartbeatKey := fmt.Sprintf("basepool:%s:heartbeat", poolID)
emptyMarkerKey := fmt.Sprintf("basepool:empty:%s", poolID)
exists, err := r.store.Exists(ctx, listKey) // Pre-check for fast failure
if err != nil { if exists, _ := r.store.Exists(ctx, emptyMarkerKey); exists {
r.logger.WithError(err).Errorf("Failed to check existence for pool_id '%s'", poolID)
return err
}
if exists {
val, err := r.store.LIndex(ctx, listKey, 0)
if err != nil {
r.logger.WithError(err).Warnf("Cache for pool_id '%s' exists but is unreadable. Forcing rebuild.", poolID)
} else {
if val == EmptyPoolPlaceholder {
return gorm.ErrRecordNotFound return gorm.ErrRecordNotFound
} }
if exists, _ := r.store.Exists(ctx, heartbeatKey); exists {
return nil return nil
} }
} // Acquire the distributed lock
lockKey := fmt.Sprintf("lock:basepool:%s", poolID) lockKey := fmt.Sprintf("lock:basepool:%s", poolID)
acquired, err := r.store.SetNX(ctx, lockKey, []byte("1"), 10*time.Second) if err := r.acquireLock(ctx, lockKey); err != nil {
if err != nil { return err // acquireLock has already logged and returned a clear error
r.logger.WithError(err).Error("Failed to attempt acquiring distributed lock for basepool build.")
return err
} }
if !acquired { defer r.releaseLock(context.Background(), lockKey)
time.Sleep(100 * time.Millisecond) // Double-checked locking
return r.ensureBasePoolCacheExists(ctx, pool, poolID) if exists, _ := r.store.Exists(ctx, emptyMarkerKey); exists {
return gorm.ErrRecordNotFound
} }
defer r.store.Del(context.Background(), lockKey) if exists, _ := r.store.Exists(ctx, heartbeatKey); exists {
if exists, _ := r.store.Exists(ctx, listKey); exists {
return nil return nil
} }
r.logger.Infof("BasePool cache for pool_id '%s' not found or is unreadable. Building now...", poolID) // One final context-cancellation check before the heavy work
var allActiveKeyIDs []string select {
lruMembers := make(map[string]float64) case <-ctx.Done():
return ctx.Err()
default:
}
r.logger.Infof("Building BasePool cache for pool_id '%s'", poolID)
// Manually aggregate all keys while building the key-to-group map in one pass
keyToGroupMap := make(map[string]any)
allKeyIDsSet := make(map[string]struct{})
for _, group := range pool.CandidateGroups { for _, group := range pool.CandidateGroups {
activeKeySetKey := fmt.Sprintf(KeyGroup, group.ID) groupKeySet := fmt.Sprintf(KeyGroup, group.ID)
groupKeyIDs, err := r.store.SMembers(ctx, activeKeySetKey) groupKeyIDs, err := r.store.SMembers(ctx, groupKeySet)
if err != nil { if err != nil {
r.logger.WithError(err).Errorf("FATAL: Failed to read active keys for group %d during BasePool build. Aborting build process for pool_id '%s'.", group.ID, poolID) r.logger.WithError(err).Warnf("Failed to get members for group %d during pool build", group.ID)
continue
}
groupIDStr := strconv.FormatUint(uint64(group.ID), 10)
for _, keyID := range groupKeyIDs {
if _, exists := allKeyIDsSet[keyID]; !exists {
allKeyIDsSet[keyID] = struct{}{}
keyToGroupMap[keyID] = groupIDStr
}
}
}
// Handle the empty-pool case
if len(allKeyIDsSet) == 0 {
emptyCacheTTL := time.Duration(r.config.Repository.BasePoolTTIMinutes) * time.Minute / 2
if emptyCacheTTL < time.Minute {
emptyCacheTTL = time.Minute
}
r.logger.Warnf("No active keys found for pool_id '%s', setting empty marker.", poolID)
if err := r.store.Set(ctx, emptyMarkerKey, []byte("1"), emptyCacheTTL); err != nil {
r.logger.WithError(err).Warnf("Failed to set empty marker for pool_id '%s'", poolID)
}
return gorm.ErrRecordNotFound
}
allActiveKeyIDs := make([]string, 0, len(allKeyIDsSet))
for keyID := range allKeyIDsSet {
allActiveKeyIDs = append(allActiveKeyIDs, keyID)
}
// Build all cache structures atomically with a pipeline
basePoolTTL := time.Duration(r.config.Repository.BasePoolTTLMinutes) * time.Minute
basePoolTTI := time.Duration(r.config.Repository.BasePoolTTIMinutes) * time.Minute
mainPoolKey := fmt.Sprintf(BasePoolRandomMain, poolID)
sequentialKey := fmt.Sprintf(BasePoolSequential, poolID)
cooldownKey := fmt.Sprintf(BasePoolRandomCooldown, poolID)
lruKey := fmt.Sprintf(BasePoolLRU, poolID)
keyToGroupMapKey := fmt.Sprintf("basepool:%s:key_to_group", poolID)
pipe := r.store.Pipeline(ctx)
pipe.Del(mainPoolKey, sequentialKey, cooldownKey, lruKey, emptyMarkerKey, keyToGroupMapKey)
pipe.SAdd(mainPoolKey, r.toInterfaceSlice(allActiveKeyIDs)...)
pipe.LPush(sequentialKey, r.toInterfaceSlice(allActiveKeyIDs)...)
if len(keyToGroupMap) > 0 {
pipe.HSet(keyToGroupMapKey, keyToGroupMap)
pipe.Expire(keyToGroupMapKey, basePoolTTL)
}
pipe.Expire(mainPoolKey, basePoolTTL)
pipe.Expire(sequentialKey, basePoolTTL)
pipe.Expire(cooldownKey, basePoolTTL)
pipe.Expire(lruKey, basePoolTTL)
pipe.Set(heartbeatKey, []byte("1"), basePoolTTI)
if err := pipe.Exec(); err != nil {
r.logger.WithError(err).Errorf("Failed to populate polling structures for pool_id '%s'", poolID)
cleanupCtx, cancel := r.withTimeout(5 * time.Second)
defer cancel()
r.store.Del(cleanupCtx, mainPoolKey, sequentialKey, cooldownKey, lruKey, heartbeatKey, emptyMarkerKey, keyToGroupMapKey)
return err return err
} }
allActiveKeyIDs = append(allActiveKeyIDs, groupKeyIDs...) // Asynchronously populate the LRU cache, passing in the prebuilt mapping
for _, keyIDStr := range groupKeyIDs { go r.populateBasePoolLRUCache(context.Background(), poolID, allActiveKeyIDs, keyToGroupMap)
return nil
}
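Stripped of the Redis specifics, ensureBasePoolCacheExists now follows the classic check / lock / re-check / build discipline. A condensed skeleton of that shape (names hypothetical):

// buildOnceSketch condenses the double-checked locking above: a lock-free
// fast path, then a re-check under the lock before the expensive build.
func buildOnceSketch(exists func() bool, lock func() (unlock func(), err error), build func() error) error {
	if exists() { // fast path: another instance already built the pool
		return nil
	}
	unlock, err := lock()
	if err != nil {
		return err
	}
	defer unlock()
	if exists() { // re-check: it may have been built while we waited
		return nil
	}
	return build()
}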
// --- Helper methods ---
// acquireLock wraps distributed-lock acquisition with retry and exponential backoff.
func (r *gormKeyRepository) acquireLock(ctx context.Context, lockKey string) error {
const (
lockTTL = 30 * time.Second
lockMaxRetries = 5
lockBaseBackoff = 50 * time.Millisecond
)
for i := 0; i < lockMaxRetries; i++ {
acquired, err := r.store.SetNX(ctx, lockKey, []byte("1"), lockTTL)
if err != nil {
r.logger.WithError(err).Error("Failed to attempt acquiring distributed lock")
return err
}
if acquired {
return nil
}
time.Sleep(lockBaseBackoff * (1 << i))
}
return fmt.Errorf("failed to acquire lock for key '%s' after %d retries", lockKey, lockMaxRetries)
}
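With lockBaseBackoff = 50ms and lockMaxRetries = 5, the successive waits are 50, 100, 200, 400, and 800 ms, so a contender gives up after roughly 1.55 s in total — comfortably inside the 30 s lockTTL. Note that the final 800 ms sleep is spent even though no further attempt follows it.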
// releaseLock wraps the release of the distributed lock.
func (r *gormKeyRepository) releaseLock(ctx context.Context, lockKey string) {
if err := r.store.Del(ctx, lockKey); err != nil {
r.logger.WithError(err).Errorf("Failed to release distributed lock for key '%s'", lockKey)
}
}
// withTimeout is a thin wrapper around context.WithTimeout, for easier testing and mocking.
func (r *gormKeyRepository) withTimeout(duration time.Duration) (context.Context, context.CancelFunc) {
return context.WithTimeout(context.Background(), duration)
}
// refreshBasePoolHeartbeat asynchronously refreshes the heartbeat key's TTI
func (r *gormKeyRepository) refreshBasePoolHeartbeat(ctx context.Context, poolID string) {
basePoolTTI := time.Duration(r.config.Repository.BasePoolTTIMinutes) * time.Minute
heartbeatKey := fmt.Sprintf("basepool:%s:heartbeat", poolID)
// Refresh via EXPIRE; if the key does not exist, this is a safe no-op
if err := r.store.Expire(ctx, heartbeatKey, basePoolTTI); err != nil {
if ctx.Err() == nil { // Avoid noisy errors after the context has been cancelled
r.logger.WithError(err).Warnf("Failed to refresh heartbeat for pool_id '%s'", poolID)
}
}
}
// populateBasePoolLRUCache asynchronously populates the BasePool's LRU cache structure
func (r *gormKeyRepository) populateBasePoolLRUCache(
parentCtx context.Context,
currentPoolID string,
keys []string,
keyToGroupMap map[string]any,
) {
lruMembers := make(map[string]float64, len(keys))
for _, keyIDStr := range keys {
select {
case <-parentCtx.Done():
return
default:
}
groupIDStr, ok := keyToGroupMap[keyIDStr].(string)
if !ok {
continue
}
keyID, _ := strconv.ParseUint(keyIDStr, 10, 64) keyID, _ := strconv.ParseUint(keyIDStr, 10, 64)
_, mapping, err := r.getKeyDetailsFromCache(ctx, uint(keyID), group.ID) groupID, _ := strconv.ParseUint(groupIDStr, 10, 64)
if err == nil && mapping != nil { mappingKey := fmt.Sprintf(KeyMapping, groupID, keyID)
data, err := r.store.Get(parentCtx, mappingKey)
if err != nil {
continue
}
var mapping models.GroupAPIKeyMapping
if json.Unmarshal(data, &mapping) == nil {
var score float64 var score float64
if mapping.LastUsedAt != nil { if mapping.LastUsedAt != nil {
score = float64(mapping.LastUsedAt.UnixMilli()) score = float64(mapping.LastUsedAt.UnixMilli())
@@ -196,44 +356,21 @@ func (r *gormKeyRepository) ensureBasePoolCacheExists(ctx context.Context, pool
lruMembers[keyIDStr] = score lruMembers[keyIDStr] = score
} }
} }
}
if len(allActiveKeyIDs) == 0 {
r.logger.Warnf("No active keys found for any candidate groups for pool_id '%s'. Setting empty pool placeholder.", poolID)
pipe := r.store.Pipeline(ctx)
pipe.LPush(listKey, EmptyPoolPlaceholder)
pipe.Expire(listKey, EmptyCacheTTL)
if err := pipe.Exec(); err != nil {
r.logger.WithError(err).Errorf("Failed to set empty pool placeholder for pool_id '%s'", poolID)
}
return gorm.ErrRecordNotFound
}
pipe := r.store.Pipeline(ctx)
pipe.LPush(fmt.Sprintf(BasePoolSequential, poolID), toInterfaceSlice(allActiveKeyIDs)...)
pipe.SAdd(fmt.Sprintf(BasePoolRandomMain, poolID), toInterfaceSlice(allActiveKeyIDs)...)
pipe.Expire(fmt.Sprintf(BasePoolSequential, poolID), CacheTTL)
pipe.Expire(fmt.Sprintf(BasePoolRandomMain, poolID), CacheTTL)
pipe.Expire(fmt.Sprintf(BasePoolRandomCooldown, poolID), CacheTTL)
pipe.Expire(fmt.Sprintf(BasePoolLRU, poolID), CacheTTL)
if err := pipe.Exec(); err != nil {
return err
}
if len(lruMembers) > 0 { if len(lruMembers) > 0 {
if err := r.store.ZAdd(ctx, fmt.Sprintf(BasePoolLRU, poolID), lruMembers); err != nil { lruKey := fmt.Sprintf(BasePoolLRU, currentPoolID)
r.logger.WithError(err).Warnf("Failed to populate LRU cache for pool_id '%s'", poolID) if err := r.store.ZAdd(parentCtx, lruKey, lruMembers); err != nil {
if parentCtx.Err() == nil {
r.logger.WithError(err).Warnf("Failed to populate LRU cache for pool '%s'", currentPoolID)
}
} }
} }
return nil
} }
// updateKeyUsageTimestampForPool updates the BasePool's LRU ZSET // updateKeyUsageTimestampForPool updates the BasePool's LRU ZSET
func (r *gormKeyRepository) updateKeyUsageTimestampForPool(ctx context.Context, poolID string, keyID uint) { func (r *gormKeyRepository) updateKeyUsageTimestampForPool(ctx context.Context, poolID string, keyID uint) {
lruKey := fmt.Sprintf(BasePoolLRU, poolID) lruKey := fmt.Sprintf(BasePoolLRU, poolID)
err := r.store.ZAdd(ctx, lruKey, map[string]float64{ err := r.store.ZAdd(ctx, lruKey, map[string]float64{
strconv.FormatUint(uint64(keyID), 10): nowMilli(), strconv.FormatUint(uint64(keyID), 10): r.nowMilli(),
}) })
if err != nil { if err != nil {
r.logger.WithError(err).Warnf("Failed to update key usage for pool %s", poolID) r.logger.WithError(err).Warnf("Failed to update key usage for pool %s", poolID)
@@ -241,20 +378,19 @@ func (r *gormKeyRepository) updateKeyUsageTimestampForPool(ctx context.Context,
} }
// generatePoolID derives a stable, unique string ID from the candidate group ID list // generatePoolID derives a stable, unique string ID from the candidate group ID list
func generatePoolID(groups []*models.KeyGroup) string { func (r *gormKeyRepository) generatePoolID(groups []*models.KeyGroup) string {
ids := make([]int, len(groups)) ids := make([]int, len(groups))
for i, g := range groups { for i, g := range groups {
ids[i] = int(g.ID) ids[i] = int(g.ID)
} }
sort.Ints(ids) sort.Ints(ids)
h := sha1.New() h := sha1.New()
io.WriteString(h, fmt.Sprintf("%v", ids)) io.WriteString(h, fmt.Sprintf("%v", ids))
return fmt.Sprintf("%x", h.Sum(nil)) return fmt.Sprintf("%x", h.Sum(nil))
} }
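Because the IDs are sorted before hashing, the pool ID is order-insensitive: candidate sets {3,1,2} and {2,3,1} both serialize to "[1 2 3]" and hash to the same SHA-1. A standalone reproduction of the computation:

package cache

import (
	"crypto/sha1"
	"fmt"
	"io"
	"sort"
)

// poolIDExample reproduces generatePoolID's computation for groups {3, 1, 2}.
func poolIDExample() string {
	ids := []int{3, 1, 2}
	sort.Ints(ids) // [1 2 3]: any permutation of the same IDs sorts identically
	h := sha1.New()
	io.WriteString(h, fmt.Sprintf("%v", ids)) // hashes the string "[1 2 3]"
	return fmt.Sprintf("%x", h.Sum(nil))
}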
// toInterfaceSlice is a type conversion helper // toInterfaceSlice is a type conversion helper
func toInterfaceSlice(slice []string) []interface{} { func (r *gormKeyRepository) toInterfaceSlice(slice []string) []interface{} {
result := make([]interface{}, len(slice)) result := make([]interface{}, len(slice))
for i, v := range slice { for i, v := range slice {
result[i] = v result[i] = v
@@ -263,7 +399,7 @@ func toInterfaceSlice(slice []string) []interface{} {
} }
// nowMilli returns the current Unix millisecond timestamp, used by the LRU/Weighted strategies // nowMilli returns the current Unix millisecond timestamp, used by the LRU/Weighted strategies
func nowMilli() float64 { func (r *gormKeyRepository) nowMilli() float64 {
return float64(time.Now().UnixMilli()) return float64(time.Now().UnixMilli())
} }
View File
@@ -1,8 +1,9 @@
// Filename: internal/repository/repository.go (final fix after review) // Filename: internal/repository/repository.go
package repository package repository
import ( import (
"context" "context"
"gemini-balancer/internal/config"
"gemini-balancer/internal/crypto" "gemini-balancer/internal/crypto"
"gemini-balancer/internal/errors" "gemini-balancer/internal/errors"
"gemini-balancer/internal/models" "gemini-balancer/internal/models"
@@ -87,18 +88,20 @@ type gormKeyRepository struct {
store store.Store store store.Store
logger *logrus.Entry logger *logrus.Entry
crypto *crypto.Service crypto *crypto.Service
config *config.Config
} }
type gormGroupRepository struct { type gormGroupRepository struct {
db *gorm.DB db *gorm.DB
} }
func NewKeyRepository(db *gorm.DB, s store.Store, logger *logrus.Logger, crypto *crypto.Service) KeyRepository { func NewKeyRepository(db *gorm.DB, s store.Store, logger *logrus.Logger, crypto *crypto.Service, cfg *config.Config) KeyRepository {
return &gormKeyRepository{ return &gormKeyRepository{
db: db, db: db,
store: s, store: s,
logger: logger.WithField("component", "repository.key🔗"), logger: logger.WithField("component", "repository.key🔗"),
crypto: crypto, crypto: crypto,
config: cfg,
} }
} }
View File
@@ -1,10 +1,9 @@
// Filename: internal/service/resource_service.go // Filename: internal/service/resource_service.go
package service package service
import ( import (
"context" "context"
"errors" "gemini-balancer/internal/domain/proxy"
apperrors "gemini-balancer/internal/errors" apperrors "gemini-balancer/internal/errors"
"gemini-balancer/internal/models" "gemini-balancer/internal/models"
"gemini-balancer/internal/repository" "gemini-balancer/internal/repository"
@@ -16,10 +15,7 @@ import (
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
) )
var ( // RequestResources bundles all resources required by one successful request.
ErrNoResourceAvailable = errors.New("no available resource found for the request")
)
type RequestResources struct { type RequestResources struct {
KeyGroup *models.KeyGroup KeyGroup *models.KeyGroup
APIKey *models.APIKey APIKey *models.APIKey
@@ -28,41 +24,51 @@ type RequestResources struct {
RequestConfig *models.RequestConfig RequestConfig *models.RequestConfig
} }
// ResourceService dynamically selects and allocates API keys and related resources based on request parameters and business rules.
type ResourceService struct { type ResourceService struct {
settingsManager *settings.SettingsManager settingsManager *settings.SettingsManager
groupManager *GroupManager groupManager *GroupManager
keyRepo repository.KeyRepository keyRepo repository.KeyRepository
authTokenRepo repository.AuthTokenRepository
apiKeyService *APIKeyService apiKeyService *APIKeyService
proxyManager *proxy.Module
logger *logrus.Entry logger *logrus.Entry
initOnce sync.Once initOnce sync.Once
} }
// NewResourceService creates and initializes a new ResourceService instance.
func NewResourceService( func NewResourceService(
sm *settings.SettingsManager, sm *settings.SettingsManager,
gm *GroupManager, gm *GroupManager,
kr repository.KeyRepository, kr repository.KeyRepository,
atr repository.AuthTokenRepository,
aks *APIKeyService, aks *APIKeyService,
pm *proxy.Module,
logger *logrus.Logger, logger *logrus.Logger,
) *ResourceService { ) *ResourceService {
rs := &ResourceService{ rs := &ResourceService{
settingsManager: sm, settingsManager: sm,
groupManager: gm, groupManager: gm,
keyRepo: kr, keyRepo: kr,
authTokenRepo: atr,
apiKeyService: aks, apiKeyService: aks,
proxyManager: pm,
logger: logger.WithField("component", "ResourceService📦"), logger: logger.WithField("component", "ResourceService📦"),
} }
// sync.Once guarantees the pre-warm task runs only once in the service's lifetime
rs.initOnce.Do(func() { rs.initOnce.Do(func() {
go rs.preWarmCache(logger) go rs.preWarmCache()
}) })
return rs return rs
} }
// GetResourceFromBasePool acquires resources in smart aggregation pool mode.
func (s *ResourceService) GetResourceFromBasePool(ctx context.Context, authToken *models.AuthToken, modelName string) (*RequestResources, error) { func (s *ResourceService) GetResourceFromBasePool(ctx context.Context, authToken *models.AuthToken, modelName string) (*RequestResources, error) {
log := s.logger.WithFields(logrus.Fields{"token_id": authToken.ID, "model_name": modelName, "mode": "BasePool"}) log := s.logger.WithFields(logrus.Fields{"token_id": authToken.ID, "model_name": modelName, "mode": "BasePool"})
log.Debug("Entering BasePool resource acquisition.") log.Debug("Entering BasePool resource acquisition.")
candidateGroups := s.filterAndSortCandidateGroups(modelName, authToken.AllowedGroups) candidateGroups := s.filterAndSortCandidateGroups(modelName, authToken)
if len(candidateGroups) == 0 { if len(candidateGroups) == 0 {
log.Warn("No candidate groups found for BasePool construction.") log.Warn("No candidate groups found for BasePool construction.")
return nil, apperrors.ErrNoKeysAvailable return nil, apperrors.ErrNoKeysAvailable
@@ -84,17 +90,18 @@ func (s *ResourceService) GetResourceFromBasePool(ctx context.Context, authToken
log.WithError(err).Error("Failed to assemble resources after selecting key from BasePool.") log.WithError(err).Error("Failed to assemble resources after selecting key from BasePool.")
return nil, err return nil, err
} }
resources.RequestConfig = &models.RequestConfig{} resources.RequestConfig = &models.RequestConfig{} // BasePool mode uses the default request config
log.Infof("Successfully selected KeyID %d from GroupID %d for the BasePool.", apiKey.ID, selectedGroup.ID) log.Infof("Successfully selected KeyID %d from GroupID %d for the BasePool.", apiKey.ID, selectedGroup.ID)
return resources, nil return resources, nil
} }
// GetResourceFromGroup acquires resources in precise-routing mode (a named key group).
func (s *ResourceService) GetResourceFromGroup(ctx context.Context, authToken *models.AuthToken, groupName string) (*RequestResources, error) { func (s *ResourceService) GetResourceFromGroup(ctx context.Context, authToken *models.AuthToken, groupName string) (*RequestResources, error) {
log := s.logger.WithFields(logrus.Fields{"token_id": authToken.ID, "group_name": groupName, "mode": "PreciseRoute"}) log := s.logger.WithFields(logrus.Fields{"token_id": authToken.ID, "group_name": groupName, "mode": "PreciseRoute"})
log.Debug("Entering PreciseRoute resource acquisition.") log.Debug("Entering PreciseRoute resource acquisition.")
targetGroup, ok := s.groupManager.GetGroupByName(groupName) targetGroup, ok := s.groupManager.GetGroupByName(groupName)
if !ok { if !ok {
return nil, apperrors.NewAPIError(apperrors.ErrGroupNotFound, "The specified group does not exist.") return nil, apperrors.NewAPIError(apperrors.ErrGroupNotFound, "The specified group does not exist.")
} }
@@ -113,37 +120,39 @@ func (s *ResourceService) GetResourceFromGroup(ctx context.Context, authToken *m
log.WithError(err).Error("Failed to assemble resources for precise route.") log.WithError(err).Error("Failed to assemble resources for precise route.")
return nil, err return nil, err
} }
resources.RequestConfig = targetGroup.RequestConfig resources.RequestConfig = targetGroup.RequestConfig // Precise routing uses that group's specific request config
log.Infof("Successfully selected KeyID %d by precise routing to GroupID %d.", apiKey.ID, targetGroup.ID) log.Infof("Successfully selected KeyID %d by precise routing to GroupID %d.", apiKey.ID, targetGroup.ID)
return resources, nil return resources, nil
} }
// GetAllowedModelsForToken returns the names of all models the given auth token may access.
func (s *ResourceService) GetAllowedModelsForToken(authToken *models.AuthToken) []string { func (s *ResourceService) GetAllowedModelsForToken(authToken *models.AuthToken) []string {
allGroups := s.groupManager.GetAllGroups() allGroups := s.groupManager.GetAllGroups()
if len(allGroups) == 0 { if len(allGroups) == 0 {
return []string{} return []string{}
} }
allowedModelsSet := make(map[string]struct{})
allowedGroupIDs := make(map[uint]bool)
if authToken.IsAdmin { if authToken.IsAdmin {
for _, group := range allGroups { for _, group := range allGroups {
for _, modelMapping := range group.AllowedModels { allowedGroupIDs[group.ID] = true
allowedModelsSet[modelMapping.ModelName] = struct{}{}
}
} }
} else { } else {
allowedGroupIDs := make(map[uint]bool)
for _, ag := range authToken.AllowedGroups { for _, ag := range authToken.AllowedGroups {
allowedGroupIDs[ag.ID] = true allowedGroupIDs[ag.ID] = true
} }
}
allowedModelsSet := make(map[string]struct{})
for _, group := range allGroups { for _, group := range allGroups {
if _, ok := allowedGroupIDs[group.ID]; ok { if allowedGroupIDs[group.ID] {
for _, modelMapping := range group.AllowedModels { for _, modelMapping := range group.AllowedModels {
allowedModelsSet[modelMapping.ModelName] = struct{}{} allowedModelsSet[modelMapping.ModelName] = struct{}{}
} }
} }
} }
}
result := make([]string, 0, len(allowedModelsSet)) result := make([]string, 0, len(allowedModelsSet))
for modelName := range allowedModelsSet { for modelName := range allowedModelsSet {
result = append(result, modelName) result = append(result, modelName)
@@ -152,12 +161,52 @@ func (s *ResourceService) GetAllowedModelsForToken(authToken *models.AuthToken)
return result return result
} }
// ReportRequestResult reports a request's final outcome to APIKeyService so key state can be updated.
func (s *ResourceService) ReportRequestResult(resources *RequestResources, success bool, apiErr *apperrors.APIError) {
if resources == nil || resources.KeyGroup == nil || resources.APIKey == nil {
return
}
s.apiKeyService.HandleRequestResult(resources.KeyGroup, resources.APIKey, success, apiErr)
}
// --- Private helpers ---
// preWarmCache runs the one-time cache pre-warming task in the background.
func (s *ResourceService) preWarmCache() {
time.Sleep(2 * time.Second) // Give other service components a moment to finish initializing
s.logger.Info("Performing initial key cache pre-warming...")
// Force-load the GroupManager cache
s.logger.Info("Pre-warming GroupManager cache...")
_ = s.groupManager.GetAllGroups()
ctx, cancel := context.WithTimeout(context.Background(), 60*time.Second) // Allow a more generous timeout
defer cancel()
if err := s.keyRepo.LoadAllKeysToStore(ctx); err != nil {
s.logger.WithError(err).Error("Failed to perform initial key cache pre-warming.")
} else {
s.logger.Info("Initial key cache pre-warming completed successfully.")
}
}
// assembleRequestResources assembles the final resource object from a key group and an API key.
func (s *ResourceService) assembleRequestResources(group *models.KeyGroup, apiKey *models.APIKey) (*RequestResources, error) { func (s *ResourceService) assembleRequestResources(group *models.KeyGroup, apiKey *models.APIKey) (*RequestResources, error) {
selectedUpstream := s.selectUpstreamForGroup(group) selectedUpstream := s.selectUpstreamForGroup(group)
if selectedUpstream == nil { if selectedUpstream == nil {
return nil, apperrors.NewAPIError(apperrors.ErrConfigurationError, "Selected group has no valid upstream and no global default is set.") return nil, apperrors.NewAPIError(apperrors.ErrConfigurationError, "Selected group has no valid upstream and no global default is set.")
} }
var proxyConfig *models.ProxyConfig var proxyConfig *models.ProxyConfig
var err error
// Only assign a proxy when the group explicitly enables one
if group.EnableProxy {
proxyConfig, err = s.proxyManager.AssignProxyIfNeeded(apiKey)
if err != nil {
s.logger.WithError(err).Errorf("Group '%s' (ID: %d) requires a proxy, but failed to assign one for KeyID %d", group.Name, group.ID, apiKey.ID)
// Business rules require an error here: the proxy is mandatory for this group
return nil, apperrors.NewAPIError(apperrors.ErrProxyNotAvailable, "Required proxy is not available for this request.")
}
}
return &RequestResources{ return &RequestResources{
KeyGroup: group, KeyGroup: group,
APIKey: apiKey, APIKey: apiKey,
@@ -166,8 +215,10 @@ func (s *ResourceService) assembleRequestResources(group *models.KeyGroup, apiKe
}, nil }, nil
} }
// selectUpstreamForGroup picks an upstream endpoint for the given key group.
func (s *ResourceService) selectUpstreamForGroup(group *models.KeyGroup) *models.UpstreamEndpoint { func (s *ResourceService) selectUpstreamForGroup(group *models.KeyGroup) *models.UpstreamEndpoint {
if len(group.AllowedUpstreams) > 0 { if len(group.AllowedUpstreams) > 0 {
// (load-balancing logic can be added here later)
return group.AllowedUpstreams[0] return group.AllowedUpstreams[0]
} }
globalSettings := s.settingsManager.GetSettings() globalSettings := s.settingsManager.GetSettings()
@@ -177,56 +228,39 @@ func (s *ResourceService) selectUpstreamForGroup(group *models.KeyGroup) *models
return nil return nil
} }
func (s *ResourceService) preWarmCache(logger *logrus.Logger) error { // filterAndSortCandidateGroups filters and sorts the qualifying candidate key groups by model name and token permissions.
time.Sleep(2 * time.Second) func (s *ResourceService) filterAndSortCandidateGroups(modelName string, authToken *models.AuthToken) []*models.KeyGroup {
s.logger.Info("Performing initial key cache pre-warming...")
if err := s.keyRepo.LoadAllKeysToStore(context.Background()); err != nil {
logger.WithError(err).Error("Failed to perform initial key cache pre-warming.")
return err
}
s.logger.Info("Initial key cache pre-warming completed successfully.")
return nil
}
func (s *ResourceService) GetResourcesForRequest(modelName string, allowedGroups []*models.KeyGroup) (*RequestResources, error) {
return nil, errors.New("GetResourcesForRequest is deprecated; use GetResourceFromBasePool or GetResourceFromGroup")
}
func (s *ResourceService) filterAndSortCandidateGroups(modelName string, allowedGroupsFromToken []*models.KeyGroup) []*models.KeyGroup {
allGroupsFromCache := s.groupManager.GetAllGroups() allGroupsFromCache := s.groupManager.GetAllGroups()
var candidateGroups []*models.KeyGroup var candidateGroups []*models.KeyGroup
allowedGroupIDs := make(map[uint]bool)
isTokenRestricted := len(allowedGroupsFromToken) > 0
if isTokenRestricted {
for _, ag := range allowedGroupsFromToken {
allowedGroupIDs[ag.ID] = true
}
}
for _, group := range allGroupsFromCache { for _, group := range allGroupsFromCache {
if isTokenRestricted && !allowedGroupIDs[group.ID] { // 检查令牌权限
if !s.isTokenAllowedForGroup(authToken, group.ID) {
continue continue
} }
isModelAllowed := false // 检查模型支持情况 (如果组内未限制模型,则默认支持所有模型)
if len(group.AllowedModels) == 0 { if len(group.AllowedModels) == 0 || s.groupSupportsModel(group, modelName) {
isModelAllowed = true
} else {
for _, m := range group.AllowedModels {
if m.ModelName == modelName {
isModelAllowed = true
break
}
}
}
if isModelAllowed {
candidateGroups = append(candidateGroups, group) candidateGroups = append(candidateGroups, group)
} }
} }
sort.SliceStable(candidateGroups, func(i, j int) bool { sort.SliceStable(candidateGroups, func(i, j int) bool {
return candidateGroups[i].Order < candidateGroups[j].Order return candidateGroups[i].Order < candidateGroups[j].Order
}) })
return candidateGroups return candidateGroups
} }
// groupSupportsModel reports whether the given key group supports the given model name.
func (s *ResourceService) groupSupportsModel(group *models.KeyGroup, modelName string) bool {
for _, m := range group.AllowedModels {
if m.ModelName == modelName {
return true
}
}
return false
}
// isTokenAllowedForGroup reports whether the given auth token may access the given key group.
func (s *ResourceService) isTokenAllowedForGroup(authToken *models.AuthToken, groupID uint) bool {
if authToken.IsAdmin {
return true
@@ -238,10 +272,3 @@ func (s *ResourceService) isTokenAllowedForGroup(authToken *models.AuthToken, gr
}
return false
}
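Taken together, filterAndSortCandidateGroups, groupSupportsModel, and isTokenAllowedForGroup implement a filter-then-stable-sort selection over the cached groups. Below is a minimal, self-contained sketch of the same pattern using simplified stand-in types; keyGroup, authToken, and filterAndSort are illustrative only, not the real models package:

package main

import (
    "fmt"
    "sort"
)

// Hypothetical, simplified stand-ins for models.KeyGroup / models.AuthToken.
type keyGroup struct {
    ID            uint
    Name          string
    Order         int
    AllowedModels []string // empty = supports every model
}

type authToken struct {
    IsAdmin         bool
    AllowedGroupIDs map[uint]bool
}

func filterAndSort(groups []keyGroup, modelName string, tok authToken) []keyGroup {
    var out []keyGroup
    for _, g := range groups {
        if !tok.IsAdmin && !tok.AllowedGroupIDs[g.ID] {
            continue // token may not access this group
        }
        ok := len(g.AllowedModels) == 0
        for _, m := range g.AllowedModels {
            if m == modelName {
                ok = true
                break
            }
        }
        if ok {
            out = append(out, g)
        }
    }
    // Stable sort preserves insertion order among groups with equal Order.
    sort.SliceStable(out, func(i, j int) bool { return out[i].Order < out[j].Order })
    return out
}

func main() {
    groups := []keyGroup{
        {ID: 1, Name: "fallback", Order: 10},
        {ID: 2, Name: "primary", Order: 1, AllowedModels: []string{"gemini-1.5-pro"}},
    }
    tok := authToken{AllowedGroupIDs: map[uint]bool{1: true, 2: true}}
    for _, g := range filterAndSort(groups, "gemini-1.5-pro", tok) {
        fmt.Println(g.Name) // primary, then fallback
    }
}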

View File

@@ -1,4 +1,4 @@
// Filename: gemini-balancer/internal/settings/settings.go (final audit-fix revision)
package settings
import (
@@ -19,7 +19,9 @@ import (
const SettingsUpdateChannel = "system_settings:updated"
const DefaultGeminiEndpoint = "https://generativelanguage.googleapis.com/v1beta/models"
var _ models.SettingsManager = (*SettingsManager)(nil)
// SettingsManager manages the system's dynamic settings: loading from the database, cache synchronization, and updates.
type SettingsManager struct {
db *gorm.DB
syncer *syncer.CacheSyncer[*models.SystemSettings]
@@ -27,13 +29,14 @@ type SettingsManager struct {
jsonToFieldType map[string]reflect.Type // maps JSON field names to Go types
}
// NewSettingsManager creates a new SettingsManager instance.
func NewSettingsManager(db *gorm.DB, store store.Store, logger *logrus.Logger) (*SettingsManager, error) {
sm := &SettingsManager{
db: db,
logger: logger.WithField("component", "SettingsManager⚙"),
jsonToFieldType: make(map[string]reflect.Type),
}
settingsType := reflect.TypeOf(models.SystemSettings{})
for i := 0; i < settingsType.NumField(); i++ {
field := settingsType.Field(i)
@@ -42,102 +45,89 @@ func NewSettingsManager(db *gorm.DB, store store.Store, logger *logrus.Logger) (
sm.jsonToFieldType[jsonTag] = field.Type
}
}
// settingsLoader reads the raw key/value rows and assembles them into a full settings struct.
settingsLoader := func() (*models.SystemSettings, error) {
sm.logger.Info("Loading system settings from database...")
var dbRecords []models.Setting
if err := sm.db.Find(&dbRecords).Error; err != nil {
return nil, fmt.Errorf("failed to load system settings from db: %w", err)
}
settingsMap := make(map[string]string)
for _, record := range dbRecords {
settingsMap[record.Key] = record.Value
}
// Start from a struct populated with all factory defaults.
settings := defaultSystemSettings()
v := reflect.ValueOf(settings).Elem()
for i := 0; i < v.NumField(); i++ {
field := v.Type().Field(i)
fieldValue := v.Field(i)
jsonTag := field.Tag.Get("json")
if dbValue, ok := settingsMap[jsonTag]; ok {
if err := parseAndSetField(fieldValue, dbValue); err != nil {
sm.logger.Warnf("Failed to set field '%s' from DB value '%s': %v. Using default.", field.Name, dbValue, err)
}
}
}
// The derivation below is functionally and log-for-log identical to the original version.
if (settings.BaseKeyCheckEndpoint == DefaultGeminiEndpoint || settings.BaseKeyCheckEndpoint == "") && settings.DefaultUpstreamURL != "" {
// When a global upstream URL is configured, derive the key-check endpoint from it.
derivedEndpoint := strings.TrimSuffix(settings.DefaultUpstreamURL, "/") + "/models"
sm.logger.Infof("BaseKeyCheckEndpoint is dynamically derived from DefaultUpstreamURL: %s", derivedEndpoint)
settings.BaseKeyCheckEndpoint = derivedEndpoint
} else if settings.BaseKeyCheckEndpoint != DefaultGeminiEndpoint && settings.BaseKeyCheckEndpoint != "" {
// Keep the else-branch log so users know a custom override is in effect.
sm.logger.Infof("BaseKeyCheckEndpoint is using a user-defined override: %s", settings.BaseKeyCheckEndpoint)
}
sm.logger.Info("System settings loaded and cached.")
return settings, nil
}
s, err := syncer.NewCacheSyncer(settingsLoader, store, SettingsUpdateChannel)
if err != nil {
return nil, fmt.Errorf("failed to create system settings syncer: %w", err)
}
sm.syncer = s
if err := sm.ensureSettingsInitialized(); err != nil {
return nil, fmt.Errorf("failed to ensure system settings are initialized: %w", err)
}
return sm, nil
}
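The loader above drives everything from struct tags: each DB row is keyed by a field's json tag, and reflection writes the string value back into the typed struct. A stripped-down, runnable sketch of that tag-driven mapping follows; the settings type and setField helper are hypothetical simplifications, not the project's parseAndSetField:

package main

import (
    "fmt"
    "reflect"
    "strconv"
)

// Hypothetical miniature of models.SystemSettings.
type settings struct {
    MaxRetries int    `json:"max_retries" default:"3"`
    BaseURL    string `json:"base_url" default:"https://example.invalid"`
}

// setField assigns a raw DB string into a struct field according to its kind.
func setField(f reflect.Value, raw string) error {
    switch f.Kind() {
    case reflect.Int:
        n, err := strconv.Atoi(raw)
        if err != nil {
            return err
        }
        f.SetInt(int64(n))
    case reflect.String:
        f.SetString(raw)
    default:
        return fmt.Errorf("unsupported kind %s", f.Kind())
    }
    return nil
}

func main() {
    db := map[string]string{"max_retries": "5"} // rows keyed by json tag
    s := settings{MaxRetries: 3, BaseURL: "https://example.invalid"} // factory defaults
    v := reflect.ValueOf(&s).Elem()
    for i := 0; i < v.NumField(); i++ {
        tag := v.Type().Field(i).Tag.Get("json")
        if raw, ok := db[tag]; ok {
            if err := setField(v.Field(i), raw); err != nil {
                fmt.Printf("keeping default for %s: %v\n", tag, err)
            }
        }
    }
    fmt.Printf("%+v\n", s) // {MaxRetries:5 BaseURL:https://example.invalid}
}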
// GetSettings returns the currently cached system settings.
func (sm *SettingsManager) GetSettings() *models.SystemSettings {
return sm.syncer.Get()
}
// UpdateSettings updates one or more system settings.
func (sm *SettingsManager) UpdateSettings(settingsMap map[string]interface{}) error {
var settingsToUpdate []models.Setting
for key, value := range settingsMap {
fieldType, ok := sm.jsonToFieldType[key]
if !ok {
sm.logger.Warnf("Received update for unknown setting key '%s', ignoring.", key)
continue
}
dbValue, err := sm.convertToDBValue(key, value, fieldType)
if err != nil {
sm.logger.Warnf("Failed to convert value for setting '%s': %v. Skipping update.", key, err)
continue
}
settingsToUpdate = append(settingsToUpdate, models.Setting{
Key: key,
Value: dbValue,
})
}
if len(settingsToUpdate) > 0 {
err := sm.db.Clauses(clause.OnConflict{
Columns: []clause.Column{{Name: "key"}},
@@ -147,83 +137,20 @@ func (sm *SettingsManager) UpdateSettings(settingsMap map[string]interface{}) er
return fmt.Errorf("failed to update settings in db: %w", err)
}
}
if err := sm.syncer.Invalidate(); err != nil {
sm.logger.Errorf("CRITICAL: Database settings updated, but cache invalidation failed: %v", err)
return fmt.Errorf("settings updated but cache invalidation failed, system may be inconsistent: %w", err)
}
return nil
}
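The write path relies on GORM's OnConflict clause to upsert key/value rows; the portion elided at the @@ marker presumably supplies the DoUpdates assignment. A minimal sketch of that upsert pattern against an in-memory SQLite database; the driver choice, table, and column names here are assumptions for illustration, not the project's schema:

package main

import (
    "gorm.io/driver/sqlite"
    "gorm.io/gorm"
    "gorm.io/gorm/clause"
)

type Setting struct {
    Key   string `gorm:"primaryKey"`
    Value string
}

func main() {
    db, err := gorm.Open(sqlite.Open("file::memory:?cache=shared"), &gorm.Config{})
    if err != nil {
        panic(err)
    }
    if err := db.AutoMigrate(&Setting{}); err != nil {
        panic(err)
    }
    rows := []Setting{
        {Key: "max_retries", Value: "5"},
        {Key: "base_url", Value: "https://example.invalid"},
    }
    // Insert new keys; on key conflict, overwrite only the value column.
    err = db.Clauses(clause.OnConflict{
        Columns:   []clause.Column{{Name: "key"}},
        DoUpdates: clause.AssignmentColumns([]string{"value"}),
    }).Create(&rows).Error
    if err != nil {
        panic(err)
    }
}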
// ResetAndSaveSettings resets all settings to the defaults defined in their 'default' struct tags.
func (sm *SettingsManager) ResetAndSaveSettings() (*models.SystemSettings, error) {
defaults := defaultSystemSettings()
settingsToSave := sm.buildSettingsFromDefaults(defaults)
if len(settingsToSave) > 0 {
err := sm.db.Clauses(clause.OnConflict{
Columns: []clause.Column{{Name: "key"}},
@@ -233,8 +160,93 @@ func (sm *SettingsManager) ResetAndSaveSettings() (*models.SystemSettings, error
return nil, fmt.Errorf("failed to reset settings in db: %w", err)
}
}
if err := sm.syncer.Invalidate(); err != nil {
sm.logger.Errorf("CRITICAL: Database settings reset, but cache invalidation failed: %v", err)
return nil, fmt.Errorf("settings reset but cache invalidation failed: %w", err)
}
return defaults, nil
}
// --- Private helper functions ---
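// ensureSettingsInitialized makes sure every known setting key has a row in the DB, creating any missing rows from the defaults.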
func (sm *SettingsManager) ensureSettingsInitialized() error {
defaults := defaultSystemSettings()
settingsToCreate := sm.buildSettingsFromDefaults(defaults)
for _, setting := range settingsToCreate {
var existing models.Setting
err := sm.db.Where("key = ?", setting.Key).First(&existing).Error
if err == gorm.ErrRecordNotFound {
sm.logger.Infof("Initializing new setting '%s'", setting.Key)
if createErr := sm.db.Create(&setting).Error; createErr != nil {
return fmt.Errorf("failed to create initial setting '%s': %w", setting.Key, createErr)
}
} else if err != nil {
return fmt.Errorf("failed to check for existing setting '%s': %w", setting.Key, err)
}
}
return nil
}
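// buildSettingsFromDefaults flattens a defaults struct into []models.Setting rows: slice/map fields are serialized to JSON, scalar fields take the value of their 'default' struct tag.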
func (sm *SettingsManager) buildSettingsFromDefaults(defaults *models.SystemSettings) []models.Setting {
v := reflect.ValueOf(defaults).Elem()
t := v.Type()
var settings []models.Setting
for i := 0; i < t.NumField(); i++ {
field := t.Field(i)
fieldValue := v.Field(i)
key := field.Tag.Get("json")
if key == "" || key == "-" {
continue
}
var defaultValue string
kind := fieldValue.Kind()
if kind == reflect.Slice || kind == reflect.Map {
jsonBytes, _ := json.Marshal(fieldValue.Interface())
defaultValue = string(jsonBytes)
} else {
defaultValue = field.Tag.Get("default")
}
settings = append(settings, models.Setting{
Key: key,
Value: defaultValue,
Name: field.Tag.Get("name"),
Description: field.Tag.Get("desc"),
Category: field.Tag.Get("category"),
DefaultValue: field.Tag.Get("default"),
})
}
return settings
}
// convertToDBValue serializes a setting value for DB storage; the blank identifier `_` for the unused key parameter silences the "unused parameter" warning.
func (sm *SettingsManager) convertToDBValue(_ string, value interface{}, fieldType reflect.Type) (string, error) {
kind := fieldType.Kind()
switch kind {
case reflect.Slice, reflect.Map:
jsonBytes, err := json.Marshal(value)
if err != nil {
return "", fmt.Errorf("failed to marshal to JSON: %w", err)
}
return string(jsonBytes), nil
case reflect.Bool:
b, ok := value.(bool)
if !ok {
return "", fmt.Errorf("expected bool, but got %T", value)
}
return strconv.FormatBool(b), nil
default:
return fmt.Sprintf("%v", value), nil
}
}
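The switch above defines the serialization contract: slices and maps are stored as JSON text, bools via strconv, and everything else through fmt. A free-standing sketch of the same rules, handy for eyeballing what lands in the Value column; toDBValue is a hypothetical extraction, not part of the package:

package main

import (
    "encoding/json"
    "fmt"
    "reflect"
    "strconv"
)

// toDBValue mirrors convertToDBValue's rules as a free function.
func toDBValue(value interface{}, fieldType reflect.Type) (string, error) {
    switch fieldType.Kind() {
    case reflect.Slice, reflect.Map:
        b, err := json.Marshal(value) // complex types are stored as JSON text
        if err != nil {
            return "", fmt.Errorf("failed to marshal to JSON: %w", err)
        }
        return string(b), nil
    case reflect.Bool:
        b, ok := value.(bool)
        if !ok {
            return "", fmt.Errorf("expected bool, but got %T", value)
        }
        return strconv.FormatBool(b), nil
    default:
        return fmt.Sprintf("%v", value), nil // ints, strings, floats, ...
    }
}

func main() {
    s, _ := toDBValue([]string{"a", "b"}, reflect.TypeOf([]string{}))
    fmt.Println(s) // ["a","b"]
    s, _ = toDBValue(true, reflect.TypeOf(true))
    fmt.Println(s) // true
    s, _ = toDBValue(42, reflect.TypeOf(0))
    fmt.Println(s) // 42
}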

View File

@@ -7,6 +7,7 @@ import (
"fmt" "fmt"
"math/rand" "math/rand"
"sort" "sort"
"strconv"
"sync" "sync"
"time" "time"
@@ -65,7 +66,7 @@ func (s *memoryStore) startGCollector() {
}
}
// --- All method signatures take a context.Context parameter to match the interface ---
// --- The in-memory implementation can ignore it, receiving it as _ ---
func (s *memoryStore) Set(_ context.Context, key string, value []byte, ttl time.Duration) error {
@@ -108,6 +109,17 @@ func (s *memoryStore) Exists(_ context.Context, key string) (bool, error) {
return ok && !item.isExpired(), nil
}
func (s *memoryStore) Expire(_ context.Context, key string, expiration time.Duration) error {
s.mu.Lock()
defer s.mu.Unlock()
item, ok := s.items[key]
if !ok {
return ErrNotFound
}
item.expireAt = time.Now().Add(expiration)
return nil
}
func (s *memoryStore) SetNX(_ context.Context, key string, value []byte, ttl time.Duration) (bool, error) {
s.mu.Lock()
defer s.mu.Unlock()
@@ -159,6 +171,21 @@ func (s *memoryStore) HSet(_ context.Context, key string, values map[string]any)
return nil
}
func (s *memoryStore) HGet(_ context.Context, key, field string) (string, error) {
s.mu.RLock()
defer s.mu.RUnlock()
item, ok := s.items[key]
if !ok || item.isExpired() {
return "", ErrNotFound
}
if hash, ok := item.value.(map[string]string); ok {
if value, exists := hash[field]; exists {
return value, nil
}
}
return "", ErrNotFound
}
func (s *memoryStore) HGetAll(_ context.Context, key string) (map[string]string, error) {
s.mu.RLock()
defer s.mu.RUnlock()
@@ -351,6 +378,26 @@ func (s *memoryStore) SRandMember(_ context.Context, key string) (string, error)
return members[n], nil
}
func (s *memoryStore) SUnionStore(_ context.Context, destination string, keys ...string) (int64, error) {
s.mu.Lock()
defer s.mu.Unlock()
unionSet := make(map[string]struct{})
for _, key := range keys {
item, ok := s.items[key]
if !ok || item.isExpired() {
continue
}
if set, ok := item.value.(map[string]struct{}); ok {
for member := range set {
unionSet[member] = struct{}{}
}
}
}
destItem := &memoryStoreItem{value: unionSet}
s.items[destination] = destItem
return int64(len(unionSet)), nil
}
func (s *memoryStore) Rotate(_ context.Context, key string) (string, error) {
s.mu.Lock()
defer s.mu.Unlock()
@@ -388,6 +435,16 @@ func (s *memoryStore) LIndex(_ context.Context, key string, index int64) (string
return list[index], nil
}
func (s *memoryStore) MSet(_ context.Context, values map[string]any) error {
s.mu.Lock()
defer s.mu.Unlock()
for key, value := range values {
// The in-memory store has no per-key TTL here, so entries never expire.
s.items[key] = &memoryStoreItem{value: value, expireAt: time.Time{}}
}
return nil
}
func (s *memoryStore) ZAdd(_ context.Context, key string, members map[string]float64) error {
s.mu.Lock()
defer s.mu.Unlock()
@@ -556,6 +613,22 @@ func (p *memoryPipeliner) Del(keys ...string) {
}
})
}
func (p *memoryPipeliner) Set(key string, value []byte, expiration time.Duration) {
capturedKey := key
capturedValue := value
p.ops = append(p.ops, func() {
var expireAt time.Time
if expiration > 0 {
expireAt = time.Now().Add(expiration)
}
p.store.items[capturedKey] = &memoryStoreItem{
value: capturedValue,
expireAt: expireAt,
}
})
}
func (p *memoryPipeliner) SAdd(key string, members ...any) {
capturedKey := key
capturedMembers := make([]any, len(members))
@@ -576,6 +649,7 @@ func (p *memoryPipeliner) SAdd(key string, members ...any) {
}
})
}
func (p *memoryPipeliner) SRem(key string, members ...any) {
capturedKey := key
capturedMembers := make([]any, len(members))
@@ -615,11 +689,125 @@ func (p *memoryPipeliner) LPush(key string, values ...any) {
item.value = append(stringValues, list...)
})
}
func (p *memoryPipeliner) LRem(key string, count int64, value any) {
capturedKey := key
capturedValue := fmt.Sprintf("%v", value)
p.ops = append(p.ops, func() {
item, ok := p.store.items[capturedKey]
if !ok || item.isExpired() {
return
}
list, ok := item.value.([]string)
if !ok {
return
}
// Mirror Redis LREM semantics: count > 0 removes the first count matches,
// count < 0 removes |count| matches, and count == 0 removes all matches.
// (This scans head-to-tail in every case; Redis scans tail-first for
// negative counts, which only differs when more matches exist than |count|.)
limit := count
if limit < 0 {
limit = -limit
}
newList := make([]string, 0, len(list))
removed := int64(0)
for _, v := range list {
if v == capturedValue && (count == 0 || removed < limit) {
removed++
continue
}
newList = append(newList, v)
}
item.value = newList
})
}
func (p *memoryPipeliner) HSet(key string, values map[string]any) {
capturedKey := key
capturedValues := make(map[string]any, len(values))
for k, v := range values {
capturedValues[k] = v
}
p.ops = append(p.ops, func() {
item, ok := p.store.items[capturedKey]
if !ok || item.isExpired() {
item = &memoryStoreItem{value: make(map[string]string)}
p.store.items[capturedKey] = item
}
hash, ok := item.value.(map[string]string)
if !ok {
hash = make(map[string]string)
item.value = hash
}
for field, value := range capturedValues {
hash[field] = fmt.Sprintf("%v", value)
}
})
}
func (p *memoryPipeliner) HIncrBy(key, field string, incr int64) {
capturedKey := key
capturedField := field
p.ops = append(p.ops, func() {
item, ok := p.store.items[capturedKey]
if !ok || item.isExpired() {
item = &memoryStoreItem{value: make(map[string]string)}
p.store.items[capturedKey] = item
}
hash, ok := item.value.(map[string]string)
if !ok {
hash = make(map[string]string)
item.value = hash
}
current, _ := strconv.ParseInt(hash[capturedField], 10, 64)
hash[capturedField] = strconv.FormatInt(current+incr, 10)
})
}
func (p *memoryPipeliner) ZAdd(key string, members map[string]float64) {
capturedKey := key
capturedMembers := make(map[string]float64, len(members))
for k, v := range members {
capturedMembers[k] = v
}
p.ops = append(p.ops, func() {
item, ok := p.store.items[capturedKey]
if !ok || item.isExpired() {
item = &memoryStoreItem{value: make(map[string]float64)}
p.store.items[capturedKey] = item
}
zset, ok := item.value.(map[string]float64)
if !ok {
zset = make(map[string]float64)
item.value = zset
}
for member, score := range capturedMembers {
zset[member] = score
}
})
}
func (p *memoryPipeliner) ZRem(key string, members ...any) {
capturedKey := key
capturedMembers := make([]any, len(members))
copy(capturedMembers, members)
p.ops = append(p.ops, func() {
item, ok := p.store.items[capturedKey]
if !ok || item.isExpired() {
return
}
zset, ok := item.value.(map[string]float64)
if !ok {
return
}
for _, member := range capturedMembers {
delete(zset, fmt.Sprintf("%v", member))
}
})
}
func (p *memoryPipeliner) MSet(values map[string]any) {
capturedValues := make(map[string]any, len(values))
for k, v := range values {
capturedValues[k] = v
}
p.ops = append(p.ops, func() {
for key, value := range capturedValues {
p.store.items[key] = &memoryStoreItem{
value: value,
expireAt: time.Time{}, // pipelined MSet likewise assumes entries never expire
}
}
})
}
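All of these pipeliner methods follow one pattern: capture the arguments, append a closure to p.ops, and let Exec later apply every queued operation in order. A condensed, runnable sketch of that deferred-operations pattern; miniStore/miniPipeliner are illustrative stand-ins, and the single-lock Exec is an assumption suggested by the captures rather than code shown here:

package main

import (
    "fmt"
    "sync"
)

type miniStore struct {
    mu    sync.Mutex
    items map[string]string
}

type miniPipeliner struct {
    store *miniStore
    ops   []func()
}

// Set queues a write; nothing touches the store until Exec.
func (p *miniPipeliner) Set(key, value string) {
    capturedKey, capturedValue := key, value // copy args into the closure
    p.ops = append(p.ops, func() { p.store.items[capturedKey] = capturedValue })
}

// Exec applies all queued operations under one lock acquisition.
func (p *miniPipeliner) Exec() {
    p.store.mu.Lock()
    defer p.store.mu.Unlock()
    for _, op := range p.ops {
        op()
    }
    p.ops = nil
}

func main() {
    s := &miniStore{items: make(map[string]string)}
    p := &miniPipeliner{store: s}
    p.Set("a", "1")
    p.Set("b", "2")
    p.Exec()
    fmt.Println(s.items) // map[a:1 b:2]
}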
type memorySubscription struct {
store *memoryStore

View File

@@ -75,10 +75,24 @@ func (s *RedisStore) Close() error {
return s.client.Close()
}
func (s *RedisStore) Expire(ctx context.Context, key string, expiration time.Duration) error {
return s.client.Expire(ctx, key, expiration).Err()
}
func (s *RedisStore) HSet(ctx context.Context, key string, values map[string]any) error {
return s.client.HSet(ctx, key, values).Err()
}
func (s *RedisStore) HGet(ctx context.Context, key, field string) (string, error) {
val, err := s.client.HGet(ctx, key, field).Result()
if err != nil {
if errors.Is(err, redis.Nil) {
return "", ErrNotFound
}
return "", err
}
return val, nil
}
func (s *RedisStore) HGetAll(ctx context.Context, key string) (map[string]string, error) {
return s.client.HGetAll(ctx, key).Result()
}
@@ -111,6 +125,18 @@ func (s *RedisStore) Rotate(ctx context.Context, key string) (string, error) {
return val, nil
}
func (s *RedisStore) MSet(ctx context.Context, values map[string]any) error {
if len(values) == 0 {
return nil
}
// The Redis MSET command takes a flat [key1, value1, key2, value2, ...] slice.
pairs := make([]interface{}, 0, len(values)*2)
for k, v := range values {
pairs = append(pairs, k, v)
}
return s.client.MSet(ctx, pairs...).Err()
}
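The pair-building loop above matches the shape Redis MSET expects: a flat, alternating key/value argument list. A quick standalone check of the transformation, no Redis connection needed:

package main

import "fmt"

func main() {
    values := map[string]any{"k1": "v1", "k2": "v2"}
    // Flatten the map into [key1, value1, key2, value2, ...].
    pairs := make([]interface{}, 0, len(values)*2)
    for k, v := range values {
        pairs = append(pairs, k, v)
    }
    fmt.Println(pairs...) // e.g. k1 v1 k2 v2 (map iteration order is random)
}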
func (s *RedisStore) SAdd(ctx context.Context, key string, members ...any) error {
return s.client.SAdd(ctx, key, members...).Err()
}
@@ -141,6 +167,13 @@ func (s *RedisStore) SRandMember(ctx context.Context, key string) (string, error
return member, nil
}
func (s *RedisStore) SUnionStore(ctx context.Context, destination string, keys ...string) (int64, error) {
if len(keys) == 0 {
return 0, nil
}
return s.client.SUnionStore(ctx, destination, keys...).Result()
}
func (s *RedisStore) ZAdd(ctx context.Context, key string, members map[string]float64) error {
if len(members) == 0 {
return nil
@@ -216,6 +249,17 @@ func (p *redisPipeliner) LPush(key string, values ...any) { p.pipe.LPush(p.ctx,
func (p *redisPipeliner) LRem(key string, count int64, value any) {
p.pipe.LRem(p.ctx, key, count, value)
}
func (p *redisPipeliner) Set(key string, value []byte, expiration time.Duration) {
p.pipe.Set(p.ctx, key, value, expiration)
}
func (p *redisPipeliner) MSet(values map[string]any) {
if len(values) == 0 {
return
}
p.pipe.MSet(p.ctx, values)
}
func (p *redisPipeliner) SAdd(key string, members ...any) { p.pipe.SAdd(p.ctx, key, members...) }
func (p *redisPipeliner) SRem(key string, members ...any) { p.pipe.SRem(p.ctx, key, members...) }
func (p *redisPipeliner) ZAdd(key string, members map[string]float64) {

View File

@@ -35,6 +35,8 @@ type Pipeliner interface {
HIncrBy(key, field string, incr int64)
// SET
MSet(values map[string]any)
Set(key string, value []byte, expiration time.Duration)
SAdd(key string, members ...any)
SRem(key string, members ...any)
@@ -58,9 +60,11 @@ type Store interface {
Del(ctx context.Context, keys ...string) error
Exists(ctx context.Context, key string) (bool, error)
SetNX(ctx context.Context, key string, value []byte, ttl time.Duration) (bool, error)
MSet(ctx context.Context, values map[string]any) error
// HASH operations
HSet(ctx context.Context, key string, values map[string]any) error
HGet(ctx context.Context, key, field string) (string, error)
HGetAll(ctx context.Context, key string) (map[string]string, error)
HIncrBy(ctx context.Context, key, field string, incr int64) (int64, error)
HDel(ctx context.Context, key string, fields ...string) error
@@ -70,6 +74,7 @@ type Store interface {
LRem(ctx context.Context, key string, count int64, value any) error
Rotate(ctx context.Context, key string) (string, error)
LIndex(ctx context.Context, key string, index int64) (string, error)
Expire(ctx context.Context, key string, expiration time.Duration) error
// SET operations
SAdd(ctx context.Context, key string, members ...any) error
@@ -77,6 +82,7 @@ type Store interface {
SMembers(ctx context.Context, key string) ([]string, error)
SRem(ctx context.Context, key string, members ...any) error
SRandMember(ctx context.Context, key string) (string, error)
SUnionStore(ctx context.Context, destination string, keys ...string) (int64, error)
// Pub/Sub operations
Publish(ctx context.Context, channel string, message []byte) error