优化流式传输&fix bugs

This commit is contained in:
XOF
2025-11-25 16:58:15 +08:00
parent e026d8f324
commit ad1e6180cf
18 changed files with 1135 additions and 156 deletions

View File

@@ -133,16 +133,25 @@ func (h *ProxyHandler) HandleProxy(c *gin.Context) {
func (h *ProxyHandler) serveTransparentProxy(c *gin.Context, requestBody []byte, initialResources *service.RequestResources, finalOpConfig *models.KeyGroupSettings, modelName, groupName string, isPreciseRouting bool) {
startTime := time.Now()
correlationID := uuid.New().String()
// ✅ 检查是否是流式请求
isStreamRequest := h.channel.IsStreamRequest(c, requestBody)
// ✅ 流式请求也支持重试
if isStreamRequest {
h.serveStreamWithRetry(c, requestBody, initialResources, finalOpConfig, modelName, groupName, isPreciseRouting, correlationID, startTime)
return
}
var finalRecorder *httptest.ResponseRecorder
var lastUsedResources *service.RequestResources
var finalProxyErr *errors.APIError
var isSuccess bool
var finalPromptTokens, finalCompletionTokens, actualRetries int
defer h.publishFinalLogEvent(c, startTime, correlationID, modelName, lastUsedResources,
finalRecorder, finalProxyErr, isSuccess, finalPromptTokens, finalCompletionTokens,
actualRetries, isPreciseRouting)
defer func() {
h.publishFinalLogEvent(c, startTime, correlationID, modelName, lastUsedResources,
finalRecorder, finalProxyErr, isSuccess, finalPromptTokens, finalCompletionTokens,
actualRetries, isPreciseRouting)
}()
maxRetries := h.getMaxRetries(isPreciseRouting, finalOpConfig)
totalAttempts := maxRetries + 1
@@ -158,6 +167,7 @@ func (h *ProxyHandler) serveTransparentProxy(c *gin.Context, requestBody []byte,
resources, err := h.getResourcesForAttempt(c, attempt, initialResources, modelName, groupName, isPreciseRouting, correlationID)
if err != nil {
h.logger.WithField("id", correlationID).Errorf("❌ getResourcesForAttempt failed: %v", err)
if apiErr, ok := err.(*errors.APIError); ok {
finalProxyErr = apiErr
} else {
@@ -165,7 +175,9 @@ func (h *ProxyHandler) serveTransparentProxy(c *gin.Context, requestBody []byte,
}
break
}
lastUsedResources = resources
h.logger.WithField("id", correlationID).Infof("✅ Got resources: KeyID=%d", resources.APIKey.ID)
// lastUsedResources = resources
if attempt > 1 {
actualRetries = attempt - 1
}
@@ -176,8 +188,15 @@ func (h *ProxyHandler) serveTransparentProxy(c *gin.Context, requestBody []byte,
c, correlationID, requestBody, resources, isPreciseRouting, groupName,
&finalPromptTokens, &finalCompletionTokens,
)
h.logger.WithField("id", correlationID).Infof("✅ Before assignment: lastUsedResources=%v", lastUsedResources)
finalRecorder, finalProxyErr, isSuccess = recorder, attemptErr, attemptSuccess
// ✅ 修正 isSuccess
if finalProxyErr != nil || (finalRecorder != nil && finalRecorder.Code >= 400) {
isSuccess = false
}
lastUsedResources = resources
h.logger.WithField("id", correlationID).Infof("✅ After assignment: lastUsedResources=%v", lastUsedResources)
h.resourceService.ReportRequestResult(resources, isSuccess, finalProxyErr)
if isSuccess {
@@ -192,10 +211,307 @@ func (h *ProxyHandler) serveTransparentProxy(c *gin.Context, requestBody []byte,
h.writeFinalResponse(c, correlationID, finalRecorder, finalProxyErr)
}
// ✅ 修改 serveStreamWithRetry添加 nil 检查
func (h *ProxyHandler) serveStreamWithRetry(c *gin.Context, requestBody []byte, initialResources *service.RequestResources, finalOpConfig *models.KeyGroupSettings, modelName, groupName string, isPreciseRouting bool, correlationID string, startTime time.Time) {
initialResources.RequestConfig = h.buildFinalRequestConfig(
h.settingsManager.GetSettings(),
initialResources.RequestConfig,
)
h.logger.WithField("id", correlationID).Info("🌊 Serving stream request with retry support")
var lastUsedResources *service.RequestResources
var finalProxyErr *errors.APIError
var isSuccess bool
var actualRetries int
defer func() {
h.publishFinalLogEvent(c, startTime, correlationID, modelName, lastUsedResources,
nil, finalProxyErr, isSuccess, 0, 0, actualRetries, isPreciseRouting)
}()
maxRetries := h.getMaxRetries(isPreciseRouting, finalOpConfig)
totalAttempts := maxRetries + 1
for attempt := 1; attempt <= totalAttempts; attempt++ {
// ✅ 检查客户端是否断开连接
if c.Request.Context().Err() != nil {
h.logger.WithField("id", correlationID).Info("Client disconnected, aborting retry loop.")
if finalProxyErr == nil {
finalProxyErr = errors.NewAPIError(errors.ErrBadRequest, "Client disconnected")
}
break
}
// ✅ 获取资源(第一次使用 initialResources后续重试获取新资源
resources, err := h.getResourcesForAttempt(c, attempt, initialResources, modelName, groupName, isPreciseRouting, correlationID)
if err != nil {
h.logger.WithField("id", correlationID).Errorf("❌ Failed to get resources: %v", err)
if apiErr, ok := err.(*errors.APIError); ok {
finalProxyErr = apiErr
} else {
finalProxyErr = errors.NewAPIError(errors.ErrInternalServer, "Failed to get resources for retry")
}
break
}
if attempt > 1 {
actualRetries = attempt - 1
}
h.logger.WithField("id", correlationID).Infof("🔄 Stream attempt %d/%d (KeyID=%d, GroupID=%d)",
attempt, totalAttempts, resources.APIKey.ID, resources.KeyGroup.ID)
// ✅ 执行流式代理请求
attemptErr, attemptSuccess := h.executeStreamAttempt(
c, correlationID, requestBody, resources, groupName, isPreciseRouting,
)
finalProxyErr, isSuccess = attemptErr, attemptSuccess
lastUsedResources = resources
// ✅ 报告结果
h.resourceService.ReportRequestResult(resources, isSuccess, finalProxyErr)
// ✅ 成功则退出
if isSuccess {
h.logger.WithField("id", correlationID).Info("✅ Stream request succeeded")
break
}
// ✅ 判断是否应该停止重试
if h.shouldStopRetrying(attempt, totalAttempts, finalProxyErr, correlationID) {
// ✅ 安全地记录错误信息(添加 nil 检查)
if finalProxyErr != nil {
h.logger.WithField("id", correlationID).Warnf("⛔ Stopping retry: %s", finalProxyErr.Message)
} else {
h.logger.WithField("id", correlationID).Warn("⛔ Stopping retry: unknown error")
}
break
}
// ✅ 发布重试日志事件
h.publishStreamRetryLogEvent(c, startTime, correlationID, modelName, resources, attemptErr, actualRetries, isPreciseRouting)
// ✅ 简化重试日志
if attempt < totalAttempts {
h.logger.WithField("id", correlationID).Infof("🔁 Retrying... (%d/%d)", attempt, totalAttempts-1)
}
}
// ✅ 如果所有尝试都失败,写入错误响应
if !isSuccess && finalProxyErr != nil {
h.logger.WithField("id", correlationID).Warnf("❌ All stream attempts failed: %s (code=%s)",
finalProxyErr.Message, finalProxyErr.Code)
// ✅ 检查是否已经写入响应头
if !c.Writer.Written() {
errToJSON(c, correlationID, finalProxyErr)
} else {
h.logger.WithField("id", correlationID).Warn("⚠️ Cannot write error, response already started")
}
}
}
// 执行单次流式代理请求
func (h *ProxyHandler) executeStreamAttempt(
c *gin.Context,
correlationID string,
requestBody []byte,
resources *service.RequestResources,
groupName string,
isPreciseRouting bool,
) (finalErr *errors.APIError, finalSuccess bool) { // ✅ 使用命名返回值
// ✅ 捕获 ReverseProxy 的 ErrAbortHandler panic
defer func() {
if r := recover(); r != nil {
// ✅ 如果是 http.ErrAbortHandler说明流式响应已成功完成
if r == http.ErrAbortHandler {
h.logger.WithField("id", correlationID).Debug("✅ Stream completed (ErrAbortHandler caught)")
// ✅ 修改命名返回值,确保返回成功状态
finalErr = nil
finalSuccess = true
return
}
// ✅ 其他 panic 继续抛出
h.logger.WithField("id", correlationID).Errorf("❌ Unexpected panic in stream: %v", r)
panic(r)
}
}()
var attemptErr *errors.APIError
var isSuccess bool
requestTimeout := time.Duration(h.settingsManager.GetSettings().RequestTimeoutSeconds) * time.Second
ctx, cancel := context.WithTimeout(c.Request.Context(), requestTimeout)
defer cancel()
attemptReq := c.Request.Clone(ctx)
attemptReq.Body = io.NopCloser(bytes.NewReader(requestBody))
attemptReq.ContentLength = int64(len(requestBody))
// ✅ 创建独立的 ReverseProxy
streamProxy := &httputil.ReverseProxy{
Transport: h.transparentProxy.Transport,
BufferPool: h.transparentProxy.BufferPool,
}
streamProxy.Director = func(r *http.Request) {
targetURL, _ := url.Parse(resources.UpstreamEndpoint.URL)
r.URL.Scheme, r.URL.Host, r.Host = targetURL.Scheme, targetURL.Host, targetURL.Host
var pureClientPath string
if isPreciseRouting {
pureClientPath = strings.TrimPrefix(r.URL.Path, "/proxy/"+groupName)
} else {
pureClientPath = r.URL.Path
}
r.URL.Path = h.channel.RewritePath(targetURL.Path, pureClientPath)
r.Header.Del("Authorization")
h.channel.ModifyRequest(r, resources.APIKey)
r.Header.Set("X-Correlation-ID", correlationID)
// ✅ 添加:应用自定义请求头
if resources.RequestConfig != nil {
for k, v := range resources.RequestConfig.CustomHeaders {
if strVal, ok := v.(string); ok {
r.Header.Set(k, strVal)
}
}
}
}
// ✅ 配置 Transport
transport := streamProxy.Transport.(*http.Transport)
if resources.ProxyConfig != nil {
proxyURLStr := fmt.Sprintf("%s://%s", resources.ProxyConfig.Protocol, resources.ProxyConfig.Address)
if proxyURL, err := url.Parse(proxyURLStr); err == nil {
transportCopy := transport.Clone()
transportCopy.Proxy = http.ProxyURL(proxyURL)
streamProxy.Transport = transportCopy
h.logger.WithField("id", correlationID).Infof("🔀 Using proxy: %s", proxyURLStr)
}
}
// ✅ 配置 ModifyResponse
streamProxy.ModifyResponse = func(resp *http.Response) error {
h.logger.WithField("id", correlationID).Infof("📨 Stream response: status=%d, contentType=%s",
resp.StatusCode, resp.Header.Get("Content-Type"))
// ✅ 处理 gzip 解压
if resp.Header.Get("Content-Encoding") == "gzip" {
gzReader, err := gzip.NewReader(resp.Body)
if err != nil {
h.logger.WithField("id", correlationID).WithError(err).Error("Failed to create gzip reader")
attemptErr = errors.NewAPIError(errors.ErrBadGateway, "Failed to decompress response")
isSuccess = false
return fmt.Errorf("gzip decompression failed: %w", err)
}
resp.Body = gzReader
resp.Header.Del("Content-Encoding")
}
// ✅ 成功响应:直接透传
if resp.StatusCode < 400 {
isSuccess = true
h.logger.WithField("id", correlationID).Info("✅ Stream response marked as success")
return nil
}
// ✅ 错误响应:读取错误信息(用于重试判断)
isSuccess = false
// ✅ 读取错误响应体
bodyBytes, err := io.ReadAll(resp.Body)
if err != nil {
h.logger.WithField("id", correlationID).WithError(err).Error("Failed to read error response")
attemptErr = errors.NewAPIError(errors.ErrBadGateway, "Failed to read upstream error")
} else {
// ✅ 根据状态码决定是否输出详细错误信息
shouldLogErrorBody := h.shouldLogErrorBody(resp.StatusCode)
if shouldLogErrorBody {
h.logger.WithField("id", correlationID).Errorf("❌ Stream error: status=%d, body=%s",
resp.StatusCode, string(bodyBytes))
} else {
// ✅ 对于常见错误429、403等只记录简要信息
errorSummary := h.extractErrorSummary(bodyBytes)
h.logger.WithField("id", correlationID).Warnf("⚠️ Stream error: status=%d, summary=%s",
resp.StatusCode, errorSummary)
}
attemptErr = errors.NewAPIErrorWithUpstream(resp.StatusCode,
fmt.Sprintf("UPSTREAM_%d", resp.StatusCode), bodyBytes)
}
// ✅ 返回错误,触发 ErrorHandler但不写入响应因为可能需要重试
return fmt.Errorf("upstream error: status %d", resp.StatusCode)
}
// ✅ 配置 ErrorHandler
streamProxy.ErrorHandler = func(w http.ResponseWriter, r *http.Request, err error) {
h.logger.WithField("id", correlationID).Debugf("Stream proxy error handler triggered: %v", err)
// ✅ 如果 attemptErr 未设置,根据错误类型创建
if attemptErr == nil {
isSuccess = false
if err == context.DeadlineExceeded {
attemptErr = errors.NewAPIError(errors.ErrGatewayTimeout, "Request timeout")
} else if err == context.Canceled {
attemptErr = errors.NewAPIError(errors.ErrBadRequest, "Request canceled")
} else if errors.IsClientNetworkError(err) {
attemptErr = errors.NewAPIError(errors.ErrBadRequest, "Client connection closed")
} else {
attemptErr = errors.NewAPIError(errors.ErrBadGateway, err.Error())
}
}
// ✅ 不在这里写入响应,让外层重试逻辑决定
}
// ✅ 执行代理请求(可能抛出 ErrAbortHandler
streamProxy.ServeHTTP(c.Writer, attemptReq)
// ✅ 正常返回(如果没有 panic
return attemptErr, isSuccess
}
// ✅ 新增:判断是否应该记录详细错误体
func (h *ProxyHandler) shouldLogErrorBody(statusCode int) bool {
// ✅ 对于常见的客户端错误和限流错误,不记录详细错误体
commonErrors := map[int]bool{
400: true, // Bad Request
401: true, // Unauthorized
403: true, // Forbidden
404: true, // Not Found
429: true, // Too Many Requests
}
return !commonErrors[statusCode]
}
// ✅ 新增:从错误响应中提取简要信息
func (h *ProxyHandler) extractErrorSummary(bodyBytes []byte) string {
// ✅ 尝试解析 JSON 错误响应
var errorResp struct {
Error struct {
Message string `json:"message"`
Code int `json:"code"`
Status string `json:"status"`
} `json:"error"`
}
if err := json.Unmarshal(bodyBytes, &errorResp); err == nil && errorResp.Error.Message != "" {
// ✅ 截取错误消息的前100个字符
message := errorResp.Error.Message
if len(message) > 100 {
message = message[:100] + "..."
}
if errorResp.Error.Status != "" {
return fmt.Sprintf("%s: %s", errorResp.Error.Status, message)
}
return message
}
// ✅ 如果无法解析 JSON返回前100个字符
if len(bodyBytes) > 100 {
return string(bodyBytes[:100]) + "..."
}
return string(bodyBytes)
}
// ✅ 新增:发布流式重试日志事件
func (h *ProxyHandler) publishStreamRetryLogEvent(c *gin.Context, startTime time.Time, corrID, modelName string, res *service.RequestResources, attemptErr *errors.APIError, retries int, isPrecise bool) {
retryEvent := h.createLogEvent(c, startTime, corrID, modelName, res, models.LogTypeRetry, isPrecise)
retryEvent.RequestLog.LatencyMs = int(time.Since(startTime).Milliseconds())
retryEvent.RequestLog.IsSuccess = false
retryEvent.RequestLog.Retries = retries
if attemptErr != nil {
retryEvent.Error = attemptErr
retryEvent.RequestLog.ErrorCode = attemptErr.Code
retryEvent.RequestLog.ErrorMessage = attemptErr.Message
retryEvent.RequestLog.Status = attemptErr.Status
retryEvent.RequestLog.StatusCode = attemptErr.HTTPStatus
}
eventData, err := json.Marshal(retryEvent)
if err != nil {
h.logger.WithField("id", corrID).WithError(err).Error("Failed to marshal stream retry log event")
return
}
if err := h.store.Publish(context.Background(), models.TopicRequestFinished, eventData); err != nil {
h.logger.WithField("id", corrID).WithError(err).Error("Failed to publish stream retry log event")
}
}
func (h *ProxyHandler) executeProxyAttempt(c *gin.Context, corrID string, body []byte, res *service.RequestResources, isPrecise bool, groupName string, pTokens, cTokens *int) (*httptest.ResponseRecorder, *errors.APIError, bool) {
recorder := httptest.NewRecorder()
var attemptErr *errors.APIError
var isSuccess bool
isSuccess := false
requestTimeout := time.Duration(h.settingsManager.GetSettings().RequestTimeoutSeconds) * time.Second
ctx, cancel := context.WithTimeout(c.Request.Context(), requestTimeout)
@@ -205,11 +521,26 @@ func (h *ProxyHandler) executeProxyAttempt(c *gin.Context, corrID string, body [
attemptReq.Body = io.NopCloser(bytes.NewReader(body))
attemptReq.ContentLength = int64(len(body))
h.logger.WithField("id", corrID).Infof("🚀 Starting proxy attempt with KeyID=%d", res.APIKey.ID)
h.configureProxy(corrID, res, isPrecise, groupName, &attemptErr, &isSuccess, pTokens, cTokens)
*attemptReq = *attemptReq.WithContext(context.WithValue(attemptReq.Context(), proxyErrorContextKey{}, &attemptErr))
h.transparentProxy.ServeHTTP(recorder, attemptReq)
h.logger.WithField("id", corrID).Infof("📥 Proxy returned: status=%d, bodyLen=%d, err=%v, success=%v",
recorder.Code, recorder.Body.Len(), attemptErr, isSuccess)
// 调试检查 ✅
if recorder.Code == 0 && attemptErr != nil {
h.logger.WithField("id", corrID).Warnf("⚠️ Fixing zero status code to %d", attemptErr.HTTPStatus)
recorder.Code = attemptErr.HTTPStatus
if recorder.Body.Len() == 0 {
errJSON, _ := json.Marshal(gin.H{"error": attemptErr})
recorder.Body.Write(errJSON)
}
}
return recorder, attemptErr, isSuccess
}
@@ -248,21 +579,71 @@ func (h *ProxyHandler) configureProxy(corrID string, res *service.RequestResourc
func (h *ProxyHandler) createModifyResponseFunc(attemptErr **errors.APIError, isSuccess *bool, pTokens, cTokens *int) func(*http.Response) error {
return func(resp *http.Response) error {
corrID := resp.Request.Header.Get("X-Correlation-ID")
h.logger.WithField("id", corrID).Infof("📨 Upstream response: status=%d, contentType=%s",
resp.StatusCode, resp.Header.Get("Content-Type"))
// 检查是否是流式响应
isStream := strings.Contains(resp.Header.Get("Content-Type"), "text/event-stream")
// 处理 gzip 压缩
var reader io.ReadCloser = resp.Body
if resp.Header.Get("Content-Encoding") == "gzip" {
gzReader, err := gzip.NewReader(resp.Body)
if err != nil {
h.logger.WithError(err).Error("Failed to create gzip reader")
} else {
reader = gzReader
resp.Header.Del("Content-Encoding")
h.logger.WithField("id", corrID).WithError(err).Error("Failed to create gzip reader")
*attemptErr = errors.NewAPIError(errors.ErrBadGateway, "Failed to decompress response")
*isSuccess = false
resp.Body = io.NopCloser(bytes.NewReader([]byte{}))
return nil
}
reader = gzReader
resp.Header.Del("Content-Encoding")
// ✅ 对于流式响应,需要替换 resp.Body 为解压后的 reader
if isStream {
resp.Body = reader
}
}
defer reader.Close()
if isStream {
h.logger.WithField("id", corrID).Info("📡 Processing stream response")
if resp.StatusCode < 400 {
*isSuccess = true
h.logger.WithField("id", corrID).Info("✅ Stream response marked as success, passing through")
// 不关闭 reader让它继续流式传输
return nil
} else {
// 错误响应才读取完整内容
bodyBytes, err := io.ReadAll(reader)
reader.Close()
if err != nil {
h.logger.WithField("id", corrID).WithError(err).Error("Failed to read error response")
*attemptErr = errors.NewAPIError(errors.ErrBadGateway, "Failed to read upstream error")
} else {
h.logger.WithField("id", corrID).Errorf("❌ Stream error: status=%d, body=%s",
resp.StatusCode, string(bodyBytes))
*attemptErr = errors.NewAPIErrorWithUpstream(resp.StatusCode,
fmt.Sprintf("UPSTREAM_%d", resp.StatusCode), bodyBytes)
}
*isSuccess = false
resp.Body = io.NopCloser(bytes.NewReader(bodyBytes))
return nil
}
}
// 非流式响应:读取完整内容
h.logger.WithField("id", corrID).Info("📄 Processing non-stream response")
bodyBytes, err := io.ReadAll(reader)
reader.Close()
if err != nil {
*attemptErr = errors.NewAPIErrorWithUpstream(http.StatusBadGateway, "UPSTREAM_GATEWAY_ERROR", nil)
h.logger.WithField("id", corrID).WithError(err).Error("Failed to read response body")
*attemptErr = errors.NewAPIError(errors.ErrBadGateway, "Failed to read upstream response")
*isSuccess = false
resp.Body = io.NopCloser(bytes.NewReader([]byte{}))
return nil
}
@@ -270,9 +651,16 @@ func (h *ProxyHandler) createModifyResponseFunc(attemptErr **errors.APIError, is
if resp.StatusCode < 400 {
*isSuccess = true
*pTokens, *cTokens = extractUsage(bodyBytes)
h.logger.WithField("id", corrID).Infof("✅ Success: bytes=%d, pTokens=%d, cTokens=%d",
len(bodyBytes), *pTokens, *cTokens)
} else {
*attemptErr = errors.NewAPIErrorWithUpstream(resp.StatusCode, fmt.Sprintf("UPSTREAM_%d", resp.StatusCode), bodyBytes)
h.logger.WithField("id", corrID).Errorf("❌ Error: status=%d, body=%s",
resp.StatusCode, string(bodyBytes))
*attemptErr = errors.NewAPIErrorWithUpstream(resp.StatusCode,
fmt.Sprintf("UPSTREAM_%d", resp.StatusCode), bodyBytes)
*isSuccess = false
}
resp.Body = io.NopCloser(bytes.NewReader(bodyBytes))
return nil
}
@@ -319,10 +707,34 @@ func (h *ProxyHandler) shouldStopRetrying(attempt, totalAttempts int, err *error
if attempt >= totalAttempts {
return true
}
if err != nil && errors.IsUnretryableRequestError(err.Message) {
h.logger.WithField("id", correlationID).Warnf("Attempt failed with unretryable request error. Aborting retries. Message: %s", err.Message)
if err == nil {
return false
}
// ✅ 不可重试的请求错误:立即停止
if errors.IsUnretryableRequestError(err.Message) {
h.logger.WithField("id", correlationID).Warnf("Unretryable request error, aborting: %s", err.Message)
return true
}
// ✅ 永久性上游错误立即停止Key 已失效)
if errors.IsPermanentUpstreamError(err.Message) {
h.logger.WithField("id", correlationID).Warnf("Permanent upstream error, aborting: %s", err.Message)
return false
}
// ✅ 可重试的网络错误:继续重试
if errors.IsRetryableNetworkError(err.Message) {
return false
}
// ✅ 临时性错误(配额等):继续重试
if errors.IsTemporaryUpstreamError(err.Message) {
return false
}
// ✅ 其他未分类错误:继续重试
return false
}