Fix basepool & 优化 repo

This commit is contained in:
XOF
2025-11-23 22:42:58 +08:00
parent 2b0b9b67dc
commit 6c7283d51b
16 changed files with 1312 additions and 723 deletions

View File

@@ -28,12 +28,6 @@ type GeminiChannel struct {
httpClient *http.Client
}
// 用于安全提取信息的本地结构体
type requestMetadata struct {
Model string `json:"model"`
Stream bool `json:"stream"`
}
func NewGeminiChannel(logger *logrus.Logger, cfg *models.SystemSettings) *GeminiChannel {
transport := &http.Transport{
Proxy: http.ProxyFromEnvironment,
@@ -47,38 +41,50 @@ func NewGeminiChannel(logger *logrus.Logger, cfg *models.SystemSettings) *Gemini
logger: logger,
httpClient: &http.Client{
Transport: transport,
Timeout: 0,
Timeout: 0, // Timeout is handled by the request context
},
}
}
// TransformRequest
func (ch *GeminiChannel) TransformRequest(c *gin.Context, requestBody []byte) (newBody []byte, modelName string, err error) {
modelName = ch.ExtractModel(c, requestBody)
return requestBody, modelName, nil
}
// ExtractModel
func (ch *GeminiChannel) ExtractModel(c *gin.Context, bodyBytes []byte) string {
return ch.extractModelFromRequest(c, bodyBytes)
}
// 统一的模型提取逻辑优先从请求体解析失败则回退到从URL路径解析。
func (ch *GeminiChannel) extractModelFromRequest(c *gin.Context, bodyBytes []byte) string {
var p struct {
Model string `json:"model"`
}
_ = json.Unmarshal(requestBody, &p)
modelName = strings.TrimPrefix(p.Model, "models/")
if modelName == "" {
modelName = ch.extractModelFromPath(c.Request.URL.Path)
if json.Unmarshal(bodyBytes, &p) == nil && p.Model != "" {
return strings.TrimPrefix(p.Model, "models/")
}
return requestBody, modelName, nil
return ch.extractModelFromPath(c.Request.URL.Path)
}
func (ch *GeminiChannel) extractModelFromPath(path string) string {
parts := strings.Split(path, "/")
for _, part := range parts {
// 覆盖更多模型名称格式
if strings.HasPrefix(part, "gemini-") || strings.HasPrefix(part, "text-") || strings.HasPrefix(part, "embedding-") {
modelPart := strings.Split(part, ":")[0]
return modelPart
return strings.Split(part, ":")[0]
}
}
return ""
}
// IsOpenAICompatibleRequest 通过纯粹的字符串操作判断,不依赖 gin.Context。
func (ch *GeminiChannel) IsOpenAICompatibleRequest(c *gin.Context) bool {
path := c.Request.URL.Path
return ch.isOpenAIPath(c.Request.URL.Path)
}
func (ch *GeminiChannel) isOpenAIPath(path string) bool {
return strings.Contains(path, "/v1/chat/completions") || strings.Contains(path, "/v1/embeddings")
}
@@ -88,25 +94,28 @@ func (ch *GeminiChannel) ValidateKey(
targetURL string,
timeout time.Duration,
) *CustomErrors.APIError {
client := &http.Client{
Timeout: timeout,
}
client := &http.Client{Timeout: timeout}
req, err := http.NewRequestWithContext(ctx, "GET", targetURL, nil)
if err != nil {
return CustomErrors.NewAPIError(CustomErrors.ErrInternalServer, "failed to create validation request: "+err.Error())
return CustomErrors.NewAPIError(CustomErrors.ErrInternalServer, "failed to create validation request")
}
ch.ModifyRequest(req, apiKey)
resp, err := client.Do(req)
if err != nil {
return CustomErrors.NewAPIError(CustomErrors.ErrBadGateway, "failed to send validation request: "+err.Error())
return CustomErrors.NewAPIError(CustomErrors.ErrBadGateway, "validation request failed: "+err.Error())
}
defer resp.Body.Close()
if resp.StatusCode >= 200 && resp.StatusCode < 300 {
return nil
}
errorBody, _ := io.ReadAll(resp.Body)
parsedMessage := CustomErrors.ParseUpstreamError(errorBody)
return &CustomErrors.APIError{
HTTPStatus: resp.StatusCode,
Code: fmt.Sprintf("UPSTREAM_%d", resp.StatusCode),
@@ -115,10 +124,6 @@ func (ch *GeminiChannel) ValidateKey(
}
func (ch *GeminiChannel) ModifyRequest(req *http.Request, apiKey *models.APIKey) {
// TODO: [Future Refactoring] Decouple auth logic from URL path.
// The authentication method (e.g., Bearer token vs. API key in query) should ideally be a property
// of the UpstreamEndpoint or a new "AuthProfile" entity, rather than being hardcoded based on URL patterns.
// This would make the channel more generic and adaptable to new upstream provider types.
if strings.Contains(req.URL.Path, "/v1beta/openai/") {
req.Header.Set("Authorization", "Bearer "+apiKey.APIKey)
} else {
@@ -133,24 +138,22 @@ func (ch *GeminiChannel) IsStreamRequest(c *gin.Context, bodyBytes []byte) bool
if strings.HasSuffix(c.Request.URL.Path, ":streamGenerateContent") {
return true
}
var meta requestMetadata
if err := json.Unmarshal(bodyBytes, &meta); err == nil {
var meta struct {
Stream bool `json:"stream"`
}
if json.Unmarshal(bodyBytes, &meta) == nil {
return meta.Stream
}
return false
}
func (ch *GeminiChannel) ExtractModel(c *gin.Context, bodyBytes []byte) string {
_, modelName, _ := ch.TransformRequest(c, bodyBytes)
return modelName
}
// RewritePath 使用 url.JoinPath 保证路径拼接的正确性。
func (ch *GeminiChannel) RewritePath(basePath, originalPath string) string {
tempCtx := &gin.Context{Request: &http.Request{URL: &url.URL{Path: originalPath}}}
var rewrittenSegment string
if ch.IsOpenAICompatibleRequest(tempCtx) {
var apiEndpoint string
if ch.isOpenAIPath(originalPath) {
v1Index := strings.LastIndex(originalPath, "/v1/")
var apiEndpoint string
if v1Index != -1 {
apiEndpoint = originalPath[v1Index+len("/v1/"):]
} else {
@@ -158,69 +161,76 @@ func (ch *GeminiChannel) RewritePath(basePath, originalPath string) string {
}
rewrittenSegment = "v1beta/openai/" + apiEndpoint
} else {
tempPath := originalPath
if strings.HasPrefix(tempPath, "/v1/") {
tempPath = "/v1beta/" + strings.TrimPrefix(tempPath, "/v1/")
if strings.HasPrefix(originalPath, "/v1/") {
rewrittenSegment = "v1beta/" + strings.TrimPrefix(originalPath, "/v1/")
} else {
rewrittenSegment = strings.TrimPrefix(originalPath, "/")
}
rewrittenSegment = strings.TrimPrefix(tempPath, "/")
}
trimmedBasePath := strings.TrimSuffix(basePath, "/")
pathToJoin := rewrittenSegment
trimmedBasePath := strings.TrimSuffix(basePath, "/")
// 防止版本号重复拼接,例如 basePath 是 /v1beta而重写段也是 v1beta/..
versionPrefixes := []string{"v1beta", "v1"}
for _, prefix := range versionPrefixes {
if strings.HasSuffix(trimmedBasePath, "/"+prefix) && strings.HasPrefix(pathToJoin, prefix+"/") {
pathToJoin = strings.TrimPrefix(pathToJoin, prefix+"/")
if strings.HasSuffix(trimmedBasePath, "/"+prefix) && strings.HasPrefix(rewrittenSegment, prefix+"/") {
rewrittenSegment = strings.TrimPrefix(rewrittenSegment, prefix+"/")
break
}
}
finalPath, err := url.JoinPath(trimmedBasePath, pathToJoin)
finalPath, err := url.JoinPath(trimmedBasePath, rewrittenSegment)
if err != nil {
return trimmedBasePath + "/" + strings.TrimPrefix(pathToJoin, "/")
// 回退到简单的字符串拼接
return trimmedBasePath + "/" + strings.TrimPrefix(rewrittenSegment, "/")
}
return finalPath
}
func (ch *GeminiChannel) ModifyResponse(resp *http.Response) error {
// 这是一个桩实现,暂时不需要任何逻辑。
return nil
return nil // 桩实现
}
func (ch *GeminiChannel) HandleError(c *gin.Context, err error) {
// 这是一个桩实现,暂时不需要任何逻辑。
// 桩实现
}
// ==========================================================
// ================== “智能路由”的核心引擎 ===================
// ==========================================================
func (ch *GeminiChannel) ProcessSmartStreamRequest(c *gin.Context, params SmartRequestParams) {
log := ch.logger.WithField("correlation_id", params.CorrelationID)
targetURL, err := url.Parse(params.UpstreamURL)
if err != nil {
log.WithError(err).Error("Failed to parse upstream URL")
errToJSON(c, CustomErrors.NewAPIError(CustomErrors.ErrInternalServer, "Invalid upstream URL format"))
log.WithError(err).Error("Invalid upstream URL")
errToJSON(c, CustomErrors.NewAPIError(CustomErrors.ErrInternalServer, "Invalid upstream URL"))
return
}
targetURL.Path = c.Request.URL.Path
targetURL.RawQuery = c.Request.URL.RawQuery
initialReq, err := http.NewRequestWithContext(c.Request.Context(), "POST", targetURL.String(), bytes.NewReader(params.RequestBody))
if err != nil {
log.WithError(err).Error("Failed to create initial smart request")
errToJSON(c, CustomErrors.NewAPIError(CustomErrors.ErrInternalServer, err.Error()))
log.WithError(err).Error("Failed to create initial smart stream request")
errToJSON(c, CustomErrors.NewAPIError(CustomErrors.ErrInternalServer, "Failed to create request"))
return
}
ch.ModifyRequest(initialReq, params.APIKey)
initialReq.Header.Del("Authorization")
resp, err := ch.httpClient.Do(initialReq)
if err != nil {
log.WithError(err).Error("Initial smart request failed")
errToJSON(c, CustomErrors.NewAPIError(CustomErrors.ErrBadGateway, err.Error()))
log.WithError(err).Error("Initial smart stream request failed")
errToJSON(c, CustomErrors.NewAPIError(CustomErrors.ErrBadGateway, "Request to upstream failed"))
return
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
log.Warnf("Initial request received non-200 status: %d", resp.StatusCode)
standardizedResp := ch.standardizeError(resp, params.LogTruncationLimit, log)
defer standardizedResp.Body.Close()
c.Writer.WriteHeader(standardizedResp.StatusCode)
for key, values := range standardizedResp.Header {
for _, value := range values {
@@ -228,45 +238,71 @@ func (ch *GeminiChannel) ProcessSmartStreamRequest(c *gin.Context, params SmartR
}
}
io.Copy(c.Writer, standardizedResp.Body)
params.EventLogger.IsSuccess = false
params.EventLogger.StatusCode = resp.StatusCode
return
}
ch.processStreamAndRetry(c, initialReq.Header, resp.Body, params, log)
}
func (ch *GeminiChannel) processStreamAndRetry(
c *gin.Context, initialRequestHeaders http.Header, initialReader io.ReadCloser, params SmartRequestParams, log *logrus.Entry,
c *gin.Context,
initialRequestHeaders http.Header,
initialReader io.ReadCloser,
params SmartRequestParams,
log *logrus.Entry,
) {
defer initialReader.Close()
c.Writer.Header().Set("Content-Type", "text/event-stream; charset=utf-8")
c.Writer.Header().Set("Cache-Control", "no-cache")
c.Writer.Header().Set("Connection", "keep-alive")
flusher, _ := c.Writer.(http.Flusher)
var accumulatedText strings.Builder
consecutiveRetryCount := 0
currentReader := initialReader
maxRetries := params.MaxRetries
retryDelay := params.RetryDelay
log.Infof("Starting smart stream session. Max retries: %d", maxRetries)
for {
if c.Request.Context().Err() != nil {
log.Info("Client disconnected, stopping stream processing.")
return
}
var interruptionReason string
scanner := bufio.NewScanner(currentReader)
for scanner.Scan() {
if c.Request.Context().Err() != nil {
log.Info("Client disconnected during scan.")
return
}
line := scanner.Text()
if line == "" {
continue
}
fmt.Fprintf(c.Writer, "%s\n\n", line)
flusher.Flush()
if !strings.HasPrefix(line, "data: ") {
continue
}
data := strings.TrimPrefix(line, "data: ")
var payload models.GeminiSSEPayload
if err := json.Unmarshal([]byte(data), &payload); err != nil {
continue
}
if len(payload.Candidates) > 0 {
candidate := payload.Candidates[0]
if candidate.Content != nil && len(candidate.Content.Parts) > 0 {
@@ -285,52 +321,71 @@ func (ch *GeminiChannel) processStreamAndRetry(
}
}
currentReader.Close()
if interruptionReason == "" {
if err := scanner.Err(); err != nil {
log.WithError(err).Warn("Stream scanner encountered an error.")
interruptionReason = "SCANNER_ERROR"
} else {
log.Warn("Stream dropped unexpectedly without a finish reason.")
log.Warn("Stream connection dropped without a finish reason.")
interruptionReason = "CONNECTION_DROP"
}
}
if consecutiveRetryCount >= maxRetries {
log.Errorf("Retry limit exceeded. Last interruption: %s. Sending final error.", interruptionReason)
errData, _ := json.Marshal(map[string]interface{}{"error": map[string]interface{}{"code": http.StatusGatewayTimeout, "status": "DEADLINE_EXCEEDED", "message": fmt.Sprintf("Proxy retry limit exceeded. Last interruption: %s.", interruptionReason)}})
log.Errorf("Retry limit exceeded. Last interruption: %s. Sending final error to client.", interruptionReason)
errData, _ := json.Marshal(map[string]interface{}{
"error": map[string]interface{}{
"code": http.StatusGatewayTimeout,
"status": "DEADLINE_EXCEEDED",
"message": fmt.Sprintf("Proxy retry limit exceeded after multiple interruptions. Last reason: %s", interruptionReason),
},
})
fmt.Fprintf(c.Writer, "event: error\ndata: %s\n\n", string(errData))
flusher.Flush()
return
}
consecutiveRetryCount++
params.EventLogger.Retries = consecutiveRetryCount
log.Infof("Stream interrupted. Attempting retry %d/%d after %v.", consecutiveRetryCount, maxRetries, retryDelay)
time.Sleep(retryDelay)
retryBody, _ := buildRetryRequestBody(params.OriginalRequest, accumulatedText.String())
retryBody := buildRetryRequestBody(params.OriginalRequest, accumulatedText.String())
retryBodyBytes, _ := json.Marshal(retryBody)
retryReq, _ := http.NewRequestWithContext(c.Request.Context(), "POST", params.UpstreamURL, bytes.NewReader(retryBodyBytes))
retryReq, err := http.NewRequestWithContext(c.Request.Context(), "POST", params.UpstreamURL, bytes.NewReader(retryBodyBytes))
if err != nil {
log.WithError(err).Error("Failed to create retry request")
continue
}
retryReq.Header = initialRequestHeaders
ch.ModifyRequest(retryReq, params.APIKey)
retryReq.Header.Del("Authorization")
retryResp, err := ch.httpClient.Do(retryReq)
if err != nil || retryResp.StatusCode != http.StatusOK || retryResp.Body == nil {
if err != nil {
log.WithError(err).Errorf("Retry request failed.")
} else {
log.Errorf("Retry request received non-200 status: %d", retryResp.StatusCode)
if retryResp.Body != nil {
retryResp.Body.Close()
}
}
if err != nil {
log.WithError(err).Error("Retry request failed")
continue
}
if retryResp.StatusCode != http.StatusOK {
log.Errorf("Retry request received non-200 status: %d", retryResp.StatusCode)
retryResp.Body.Close()
continue
}
currentReader = retryResp.Body
}
}
func buildRetryRequestBody(originalBody models.GeminiRequest, accumulatedText string) (models.GeminiRequest, error) {
// buildRetryRequestBody 正确处理多轮对话的上下文插入。
func buildRetryRequestBody(originalBody models.GeminiRequest, accumulatedText string) models.GeminiRequest {
retryBody := originalBody
// 找到最后一个 'user' 角色的消息索引
lastUserIndex := -1
for i := len(retryBody.Contents) - 1; i >= 0; i-- {
if retryBody.Contents[i].Role == "user" {
@@ -338,25 +393,26 @@ func buildRetryRequestBody(originalBody models.GeminiRequest, accumulatedText st
break
}
}
history := []models.GeminiContent{
{Role: "model", Parts: []models.Part{{Text: accumulatedText}}},
{Role: "user", Parts: []models.Part{{Text: SmartRetryPrompt}}},
}
if lastUserIndex != -1 {
// 如果找到了 'user' 消息,将历史记录插入到其后
newContents := make([]models.GeminiContent, 0, len(retryBody.Contents)+2)
newContents = append(newContents, retryBody.Contents[:lastUserIndex+1]...)
newContents = append(newContents, history...)
newContents = append(newContents, retryBody.Contents[lastUserIndex+1:]...)
retryBody.Contents = newContents
} else {
// 如果没有 'user' 消息(理论上不应发生),则直接追加
retryBody.Contents = append(retryBody.Contents, history...)
}
return retryBody, nil
}
// ===============================================
// ========= 辅助函数区 (继承并强化) =========
// ===============================================
return retryBody
}
type googleAPIError struct {
Error struct {
@@ -397,25 +453,28 @@ func truncate(s string, n int) string {
return s
}
// standardizeError
func (ch *GeminiChannel) standardizeError(resp *http.Response, truncateLimit int, log *logrus.Entry) *http.Response {
bodyBytes, err := io.ReadAll(resp.Body)
if err != nil {
log.WithError(err).Error("Failed to read upstream error body")
bodyBytes = []byte("Failed to read upstream error body: " + err.Error())
bodyBytes = []byte("Failed to read upstream error body")
}
resp.Body.Close()
log.Errorf("Upstream error body: %s", truncate(string(bodyBytes), truncateLimit))
log.Errorf("Upstream error: %s", truncate(string(bodyBytes), truncateLimit))
var standardizedPayload googleAPIError
// 即使解析失败,也要构建一个标准的错误结构体
if json.Unmarshal(bodyBytes, &standardizedPayload) != nil || standardizedPayload.Error.Code == 0 {
standardizedPayload.Error.Code = resp.StatusCode
standardizedPayload.Error.Message = http.StatusText(resp.StatusCode)
standardizedPayload.Error.Status = statusToGoogleStatus(resp.StatusCode)
standardizedPayload.Error.Details = []interface{}{map[string]string{
"@type": "proxy.upstream.error",
"@type": "proxy.upstream.unparsed.error",
"body": truncate(string(bodyBytes), truncateLimit),
}}
}
newBodyBytes, _ := json.Marshal(standardizedPayload)
newResp := &http.Response{
StatusCode: resp.StatusCode,
@@ -425,10 +484,13 @@ func (ch *GeminiChannel) standardizeError(resp *http.Response, truncateLimit int
}
newResp.Header.Set("Content-Type", "application/json; charset=utf-8")
newResp.Header.Set("Access-Control-Allow-Origin", "*")
return newResp
}
// errToJSON
func errToJSON(c *gin.Context, apiErr *CustomErrors.APIError) {
if c.IsAborted() {
return
}
c.JSON(apiErr.HTTPStatus, gin.H{"error": apiErr})
}