// Filename: internal/handlers/proxy_handler.go package handlers import ( "bytes" "compress/gzip" "context" "encoding/json" "fmt" "gemini-balancer/internal/channel" "gemini-balancer/internal/errors" "gemini-balancer/internal/middleware" "gemini-balancer/internal/models" "gemini-balancer/internal/service" "gemini-balancer/internal/settings" "gemini-balancer/internal/store" "io" "net" "net/http" "net/http/httptest" "net/http/httputil" "net/url" "strings" "time" "github.com/gin-gonic/gin" "github.com/google/uuid" "github.com/sirupsen/logrus" "gorm.io/datatypes" ) type proxyErrorContextKey struct{} type ProxyHandler struct { resourceService *service.ResourceService store store.Store settingsManager *settings.SettingsManager groupManager *service.GroupManager channel channel.ChannelProxy logger *logrus.Entry transparentProxy *httputil.ReverseProxy } func NewProxyHandler( resourceService *service.ResourceService, store store.Store, sm *settings.SettingsManager, gm *service.GroupManager, channel channel.ChannelProxy, logger *logrus.Logger, ) *ProxyHandler { ph := &ProxyHandler{ resourceService: resourceService, store: store, settingsManager: sm, groupManager: gm, channel: channel, logger: logger.WithField("component", "ProxyHandler"), transparentProxy: &httputil.ReverseProxy{}, } ph.transparentProxy.Transport = &http.Transport{ Proxy: http.ProxyFromEnvironment, DialContext: (&net.Dialer{ Timeout: 30 * time.Second, KeepAlive: 60 * time.Second, }).DialContext, MaxIdleConns: 100, IdleConnTimeout: 90 * time.Second, TLSHandshakeTimeout: 10 * time.Second, ExpectContinueTimeout: 1 * time.Second, } ph.transparentProxy.ErrorHandler = ph.transparentProxyErrorHandler ph.transparentProxy.BufferPool = &bufferPool{} return ph } func (h *ProxyHandler) HandleProxy(c *gin.Context) { if c.Request.Method == "GET" && (strings.HasSuffix(c.Request.URL.Path, "/models") || strings.HasSuffix(c.Request.URL.Path, "/models/")) { h.handleListModelsRequest(c) return } maxBodySize := int64(h.settingsManager.GetSettings().MaxRequestBodySizeMB * 1024 * 1024) requestBody, err := io.ReadAll(io.LimitReader(c.Request.Body, maxBodySize)) if err != nil { errToJSON(c, uuid.New().String(), errors.NewAPIError(errors.ErrBadRequest, "Request body too large or failed to read")) return } modelName := h.channel.ExtractModel(c, requestBody) groupName := c.Param("group_name") isPreciseRouting := groupName != "" if !isPreciseRouting && modelName == "" { errToJSON(c, uuid.New().String(), errors.NewAPIError(errors.ErrBadRequest, "Model not specified in request")) return } initialResources, err := h.getResourcesForRequest(c, modelName, groupName, isPreciseRouting) if err != nil { if apiErr, ok := err.(*errors.APIError); ok { errToJSON(c, uuid.New().String(), apiErr) } else { errToJSON(c, uuid.New().String(), errors.NewAPIError(errors.ErrInternalServer, "Failed to get initial resources")) } return } finalOpConfig, err := h.groupManager.BuildOperationalConfig(initialResources.KeyGroup) if err != nil { h.logger.WithError(err).Error("Failed to build operational config.") errToJSON(c, uuid.New().String(), errors.NewAPIError(errors.ErrInternalServer, "Failed to build operational config")) return } initialResources.RequestConfig = h.buildFinalRequestConfig(h.settingsManager.GetSettings(), initialResources.RequestConfig) isOpenAICompatible := h.channel.IsOpenAICompatibleRequest(c) if isOpenAICompatible { h.serveTransparentProxy(c, requestBody, initialResources, finalOpConfig, modelName, groupName, isPreciseRouting) return } isStream := h.channel.IsStreamRequest(c, requestBody) useSmartGateway := finalOpConfig.EnableSmartGateway != nil && *finalOpConfig.EnableSmartGateway if useSmartGateway && isStream && h.settingsManager.GetSettings().EnableStreamingRetry { h.serveSmartStream(c, requestBody, initialResources, isPreciseRouting) } else { h.serveTransparentProxy(c, requestBody, initialResources, finalOpConfig, modelName, groupName, isPreciseRouting) } } func (h *ProxyHandler) serveTransparentProxy(c *gin.Context, requestBody []byte, initialResources *service.RequestResources, finalOpConfig *models.KeyGroupSettings, modelName, groupName string, isPreciseRouting bool) { startTime := time.Now() correlationID := uuid.New().String() // ✅ 检查是否是流式请求 isStreamRequest := h.channel.IsStreamRequest(c, requestBody) // ✅ 流式请求也支持重试 if isStreamRequest { h.serveStreamWithRetry(c, requestBody, initialResources, finalOpConfig, modelName, groupName, isPreciseRouting, correlationID, startTime) return } var finalRecorder *httptest.ResponseRecorder var lastUsedResources *service.RequestResources var finalProxyErr *errors.APIError var isSuccess bool var finalPromptTokens, finalCompletionTokens, actualRetries int defer func() { h.publishFinalLogEvent(c, startTime, correlationID, modelName, lastUsedResources, finalRecorder, finalProxyErr, isSuccess, finalPromptTokens, finalCompletionTokens, actualRetries, isPreciseRouting) }() maxRetries := h.getMaxRetries(isPreciseRouting, finalOpConfig) totalAttempts := maxRetries + 1 for attempt := 1; attempt <= totalAttempts; attempt++ { if c.Request.Context().Err() != nil { h.logger.WithField("id", correlationID).Info("Client disconnected, aborting retry loop.") if finalProxyErr == nil { finalProxyErr = errors.NewAPIError(errors.ErrBadRequest, "Client disconnected") } break } resources, err := h.getResourcesForAttempt(c, attempt, initialResources, modelName, groupName, isPreciseRouting, correlationID) if err != nil { h.logger.WithField("id", correlationID).Errorf("❌ getResourcesForAttempt failed: %v", err) if apiErr, ok := err.(*errors.APIError); ok { finalProxyErr = apiErr } else { finalProxyErr = errors.NewAPIError(errors.ErrInternalServer, "Failed to get resources for retry") } break } h.logger.WithField("id", correlationID).Infof("✅ Got resources: KeyID=%d", resources.APIKey.ID) // lastUsedResources = resources if attempt > 1 { actualRetries = attempt - 1 } h.logger.WithField("id", correlationID).Infof("Attempt %d/%d with KeyID %d", attempt, totalAttempts, resources.APIKey.ID) recorder, attemptErr, attemptSuccess := h.executeProxyAttempt( c, correlationID, requestBody, resources, isPreciseRouting, groupName, &finalPromptTokens, &finalCompletionTokens, ) h.logger.WithField("id", correlationID).Infof("✅ Before assignment: lastUsedResources=%v", lastUsedResources) finalRecorder, finalProxyErr, isSuccess = recorder, attemptErr, attemptSuccess // ✅ 修正 isSuccess if finalProxyErr != nil || (finalRecorder != nil && finalRecorder.Code >= 400) { isSuccess = false } lastUsedResources = resources h.logger.WithField("id", correlationID).Infof("✅ After assignment: lastUsedResources=%v", lastUsedResources) h.resourceService.ReportRequestResult(resources, isSuccess, finalProxyErr) if isSuccess { break } if h.shouldStopRetrying(attempt, totalAttempts, finalProxyErr, correlationID) { break } h.publishRetryLogEvent(c, startTime, correlationID, modelName, resources, recorder, attemptErr, actualRetries, isPreciseRouting) } h.writeFinalResponse(c, correlationID, finalRecorder, finalProxyErr) } // ✅ 修改 serveStreamWithRetry,添加 nil 检查 func (h *ProxyHandler) serveStreamWithRetry(c *gin.Context, requestBody []byte, initialResources *service.RequestResources, finalOpConfig *models.KeyGroupSettings, modelName, groupName string, isPreciseRouting bool, correlationID string, startTime time.Time) { initialResources.RequestConfig = h.buildFinalRequestConfig( h.settingsManager.GetSettings(), initialResources.RequestConfig, ) h.logger.WithField("id", correlationID).Info("🌊 Serving stream request with retry support") var lastUsedResources *service.RequestResources var finalProxyErr *errors.APIError var isSuccess bool var actualRetries int defer func() { h.publishFinalLogEvent(c, startTime, correlationID, modelName, lastUsedResources, nil, finalProxyErr, isSuccess, 0, 0, actualRetries, isPreciseRouting) }() maxRetries := h.getMaxRetries(isPreciseRouting, finalOpConfig) totalAttempts := maxRetries + 1 for attempt := 1; attempt <= totalAttempts; attempt++ { // ✅ 检查客户端是否断开连接 if c.Request.Context().Err() != nil { h.logger.WithField("id", correlationID).Info("Client disconnected, aborting retry loop.") if finalProxyErr == nil { finalProxyErr = errors.NewAPIError(errors.ErrBadRequest, "Client disconnected") } break } // ✅ 获取资源(第一次使用 initialResources,后续重试获取新资源) resources, err := h.getResourcesForAttempt(c, attempt, initialResources, modelName, groupName, isPreciseRouting, correlationID) if err != nil { h.logger.WithField("id", correlationID).Errorf("❌ Failed to get resources: %v", err) if apiErr, ok := err.(*errors.APIError); ok { finalProxyErr = apiErr } else { finalProxyErr = errors.NewAPIError(errors.ErrInternalServer, "Failed to get resources for retry") } break } if attempt > 1 { actualRetries = attempt - 1 } h.logger.WithField("id", correlationID).Infof("🔄 Stream attempt %d/%d (KeyID=%d, GroupID=%d)", attempt, totalAttempts, resources.APIKey.ID, resources.KeyGroup.ID) // ✅ 执行流式代理请求 attemptErr, attemptSuccess := h.executeStreamAttempt( c, correlationID, requestBody, resources, groupName, isPreciseRouting, ) finalProxyErr, isSuccess = attemptErr, attemptSuccess lastUsedResources = resources // ✅ 报告结果 h.resourceService.ReportRequestResult(resources, isSuccess, finalProxyErr) // ✅ 成功则退出 if isSuccess { h.logger.WithField("id", correlationID).Info("✅ Stream request succeeded") break } // ✅ 判断是否应该停止重试 if h.shouldStopRetrying(attempt, totalAttempts, finalProxyErr, correlationID) { // ✅ 安全地记录错误信息(添加 nil 检查) if finalProxyErr != nil { h.logger.WithField("id", correlationID).Warnf("⛔ Stopping retry: %s", finalProxyErr.Message) } else { h.logger.WithField("id", correlationID).Warn("⛔ Stopping retry: unknown error") } break } // ✅ 发布重试日志事件 h.publishStreamRetryLogEvent(c, startTime, correlationID, modelName, resources, attemptErr, actualRetries, isPreciseRouting) // ✅ 简化重试日志 if attempt < totalAttempts { h.logger.WithField("id", correlationID).Infof("🔁 Retrying... (%d/%d)", attempt, totalAttempts-1) } } // ✅ 如果所有尝试都失败,写入错误响应 if !isSuccess && finalProxyErr != nil { h.logger.WithField("id", correlationID).Warnf("❌ All stream attempts failed: %s (code=%s)", finalProxyErr.Message, finalProxyErr.Code) // ✅ 检查是否已经写入响应头 if !c.Writer.Written() { errToJSON(c, correlationID, finalProxyErr) } else { h.logger.WithField("id", correlationID).Warn("⚠️ Cannot write error, response already started") } } } // 执行单次流式代理请求 func (h *ProxyHandler) executeStreamAttempt( c *gin.Context, correlationID string, requestBody []byte, resources *service.RequestResources, groupName string, isPreciseRouting bool, ) (finalErr *errors.APIError, finalSuccess bool) { // ✅ 使用命名返回值 // ✅ 捕获 ReverseProxy 的 ErrAbortHandler panic defer func() { if r := recover(); r != nil { // ✅ 如果是 http.ErrAbortHandler,说明流式响应已成功完成 if r == http.ErrAbortHandler { h.logger.WithField("id", correlationID).Debug("✅ Stream completed (ErrAbortHandler caught)") // ✅ 修改命名返回值,确保返回成功状态 finalErr = nil finalSuccess = true return } // ✅ 其他 panic 继续抛出 h.logger.WithField("id", correlationID).Errorf("❌ Unexpected panic in stream: %v", r) panic(r) } }() var attemptErr *errors.APIError var isSuccess bool requestTimeout := time.Duration(h.settingsManager.GetSettings().RequestTimeoutSeconds) * time.Second ctx, cancel := context.WithTimeout(c.Request.Context(), requestTimeout) defer cancel() attemptReq := c.Request.Clone(ctx) attemptReq.Body = io.NopCloser(bytes.NewReader(requestBody)) attemptReq.ContentLength = int64(len(requestBody)) // ✅ 创建独立的 ReverseProxy streamProxy := &httputil.ReverseProxy{ Transport: h.transparentProxy.Transport, BufferPool: h.transparentProxy.BufferPool, } streamProxy.Director = func(r *http.Request) { targetURL, _ := url.Parse(resources.UpstreamEndpoint.URL) r.URL.Scheme, r.URL.Host, r.Host = targetURL.Scheme, targetURL.Host, targetURL.Host var pureClientPath string if isPreciseRouting { pureClientPath = strings.TrimPrefix(r.URL.Path, "/proxy/"+groupName) } else { pureClientPath = r.URL.Path } r.URL.Path = h.channel.RewritePath(targetURL.Path, pureClientPath) r.Header.Del("Authorization") h.channel.ModifyRequest(r, resources.APIKey) r.Header.Set("X-Correlation-ID", correlationID) // ✅ 添加:应用自定义请求头 if resources.RequestConfig != nil { for k, v := range resources.RequestConfig.CustomHeaders { if strVal, ok := v.(string); ok { r.Header.Set(k, strVal) } } } } // ✅ 配置 Transport transport := streamProxy.Transport.(*http.Transport) if resources.ProxyConfig != nil { proxyURLStr := fmt.Sprintf("%s://%s", resources.ProxyConfig.Protocol, resources.ProxyConfig.Address) if proxyURL, err := url.Parse(proxyURLStr); err == nil { transportCopy := transport.Clone() transportCopy.Proxy = http.ProxyURL(proxyURL) streamProxy.Transport = transportCopy h.logger.WithField("id", correlationID).Infof("🔀 Using proxy: %s", proxyURLStr) } } // ✅ 配置 ModifyResponse streamProxy.ModifyResponse = func(resp *http.Response) error { h.logger.WithField("id", correlationID).Infof("📨 Stream response: status=%d, contentType=%s", resp.StatusCode, resp.Header.Get("Content-Type")) // ✅ 处理 gzip 解压 if resp.Header.Get("Content-Encoding") == "gzip" { gzReader, err := gzip.NewReader(resp.Body) if err != nil { h.logger.WithField("id", correlationID).WithError(err).Error("Failed to create gzip reader") attemptErr = errors.NewAPIError(errors.ErrBadGateway, "Failed to decompress response") isSuccess = false return fmt.Errorf("gzip decompression failed: %w", err) } resp.Body = gzReader resp.Header.Del("Content-Encoding") } // ✅ 成功响应:直接透传 if resp.StatusCode < 400 { isSuccess = true h.logger.WithField("id", correlationID).Info("✅ Stream response marked as success") return nil } // ✅ 错误响应:读取错误信息(用于重试判断) isSuccess = false // ✅ 读取错误响应体 bodyBytes, err := io.ReadAll(resp.Body) if err != nil { h.logger.WithField("id", correlationID).WithError(err).Error("Failed to read error response") attemptErr = errors.NewAPIError(errors.ErrBadGateway, "Failed to read upstream error") } else { // ✅ 根据状态码决定是否输出详细错误信息 shouldLogErrorBody := h.shouldLogErrorBody(resp.StatusCode) if shouldLogErrorBody { h.logger.WithField("id", correlationID).Errorf("❌ Stream error: status=%d, body=%s", resp.StatusCode, string(bodyBytes)) } else { // ✅ 对于常见错误(429、403等),只记录简要信息 errorSummary := h.extractErrorSummary(bodyBytes) h.logger.WithField("id", correlationID).Warnf("⚠️ Stream error: status=%d, summary=%s", resp.StatusCode, errorSummary) } attemptErr = errors.NewAPIErrorWithUpstream(resp.StatusCode, fmt.Sprintf("UPSTREAM_%d", resp.StatusCode), bodyBytes) } // ✅ 返回错误,触发 ErrorHandler(但不写入响应,因为可能需要重试) return fmt.Errorf("upstream error: status %d", resp.StatusCode) } // ✅ 配置 ErrorHandler streamProxy.ErrorHandler = func(w http.ResponseWriter, r *http.Request, err error) { h.logger.WithField("id", correlationID).Debugf("Stream proxy error handler triggered: %v", err) // ✅ 如果 attemptErr 未设置,根据错误类型创建 if attemptErr == nil { isSuccess = false if err == context.DeadlineExceeded { attemptErr = errors.NewAPIError(errors.ErrGatewayTimeout, "Request timeout") } else if err == context.Canceled { attemptErr = errors.NewAPIError(errors.ErrBadRequest, "Request canceled") } else if errors.IsClientNetworkError(err) { attemptErr = errors.NewAPIError(errors.ErrBadRequest, "Client connection closed") } else { attemptErr = errors.NewAPIError(errors.ErrBadGateway, err.Error()) } } // ✅ 不在这里写入响应,让外层重试逻辑决定 } // ✅ 执行代理请求(可能抛出 ErrAbortHandler) streamProxy.ServeHTTP(c.Writer, attemptReq) // ✅ 正常返回(如果没有 panic) return attemptErr, isSuccess } // ✅ 新增:判断是否应该记录详细错误体 func (h *ProxyHandler) shouldLogErrorBody(statusCode int) bool { // ✅ 对于常见的客户端错误和限流错误,不记录详细错误体 commonErrors := map[int]bool{ 400: true, // Bad Request 401: true, // Unauthorized 403: true, // Forbidden 404: true, // Not Found 429: true, // Too Many Requests } return !commonErrors[statusCode] } // ✅ 新增:从错误响应中提取简要信息 func (h *ProxyHandler) extractErrorSummary(bodyBytes []byte) string { // ✅ 尝试解析 JSON 错误响应 var errorResp struct { Error struct { Message string `json:"message"` Code int `json:"code"` Status string `json:"status"` } `json:"error"` } if err := json.Unmarshal(bodyBytes, &errorResp); err == nil && errorResp.Error.Message != "" { // ✅ 截取错误消息的前100个字符 message := errorResp.Error.Message if len(message) > 100 { message = message[:100] + "..." } if errorResp.Error.Status != "" { return fmt.Sprintf("%s: %s", errorResp.Error.Status, message) } return message } // ✅ 如果无法解析 JSON,返回前100个字符 if len(bodyBytes) > 100 { return string(bodyBytes[:100]) + "..." } return string(bodyBytes) } // ✅ 新增:发布流式重试日志事件 func (h *ProxyHandler) publishStreamRetryLogEvent(c *gin.Context, startTime time.Time, corrID, modelName string, res *service.RequestResources, attemptErr *errors.APIError, retries int, isPrecise bool) { retryEvent := h.createLogEvent(c, startTime, corrID, modelName, res, models.LogTypeRetry, isPrecise) retryEvent.RequestLog.LatencyMs = int(time.Since(startTime).Milliseconds()) retryEvent.RequestLog.IsSuccess = false retryEvent.RequestLog.Retries = retries if attemptErr != nil { retryEvent.Error = attemptErr retryEvent.RequestLog.ErrorCode = attemptErr.Code retryEvent.RequestLog.ErrorMessage = attemptErr.Message retryEvent.RequestLog.Status = attemptErr.Status retryEvent.RequestLog.StatusCode = attemptErr.HTTPStatus } eventData, err := json.Marshal(retryEvent) if err != nil { h.logger.WithField("id", corrID).WithError(err).Error("Failed to marshal stream retry log event") return } if err := h.store.Publish(context.Background(), models.TopicRequestFinished, eventData); err != nil { h.logger.WithField("id", corrID).WithError(err).Error("Failed to publish stream retry log event") } } func (h *ProxyHandler) executeProxyAttempt(c *gin.Context, corrID string, body []byte, res *service.RequestResources, isPrecise bool, groupName string, pTokens, cTokens *int) (*httptest.ResponseRecorder, *errors.APIError, bool) { recorder := httptest.NewRecorder() var attemptErr *errors.APIError isSuccess := false requestTimeout := time.Duration(h.settingsManager.GetSettings().RequestTimeoutSeconds) * time.Second ctx, cancel := context.WithTimeout(c.Request.Context(), requestTimeout) defer cancel() attemptReq := c.Request.Clone(ctx) attemptReq.Body = io.NopCloser(bytes.NewReader(body)) attemptReq.ContentLength = int64(len(body)) h.logger.WithField("id", corrID).Infof("🚀 Starting proxy attempt with KeyID=%d", res.APIKey.ID) h.configureProxy(corrID, res, isPrecise, groupName, &attemptErr, &isSuccess, pTokens, cTokens) *attemptReq = *attemptReq.WithContext(context.WithValue(attemptReq.Context(), proxyErrorContextKey{}, &attemptErr)) h.transparentProxy.ServeHTTP(recorder, attemptReq) h.logger.WithField("id", corrID).Infof("📥 Proxy returned: status=%d, bodyLen=%d, err=%v, success=%v", recorder.Code, recorder.Body.Len(), attemptErr, isSuccess) // 调试检查 ✅ if recorder.Code == 0 && attemptErr != nil { h.logger.WithField("id", corrID).Warnf("⚠️ Fixing zero status code to %d", attemptErr.HTTPStatus) recorder.Code = attemptErr.HTTPStatus if recorder.Body.Len() == 0 { errJSON, _ := json.Marshal(gin.H{"error": attemptErr}) recorder.Body.Write(errJSON) } } return recorder, attemptErr, isSuccess } func (h *ProxyHandler) configureProxy(corrID string, res *service.RequestResources, isPrecise bool, groupName string, attemptErr **errors.APIError, isSuccess *bool, pTokens, cTokens *int) { h.transparentProxy.Director = func(r *http.Request) { targetURL, _ := url.Parse(res.UpstreamEndpoint.URL) r.URL.Scheme, r.URL.Host, r.Host = targetURL.Scheme, targetURL.Host, targetURL.Host var pureClientPath string if isPrecise { pureClientPath = strings.TrimPrefix(r.URL.Path, "/proxy/"+groupName) } else { pureClientPath = r.URL.Path } r.URL.Path = h.channel.RewritePath(targetURL.Path, pureClientPath) r.Header.Del("Authorization") h.channel.ModifyRequest(r, res.APIKey) r.Header.Set("X-Correlation-ID", corrID) } transport := h.transparentProxy.Transport.(*http.Transport) if res.ProxyConfig != nil { proxyURLStr := fmt.Sprintf("%s://%s", res.ProxyConfig.Protocol, res.ProxyConfig.Address) if proxyURL, err := url.Parse(proxyURLStr); err == nil { transport.Proxy = http.ProxyURL(proxyURL) } else { transport.Proxy = http.ProxyFromEnvironment } } else { transport.Proxy = http.ProxyFromEnvironment } h.transparentProxy.ModifyResponse = h.createModifyResponseFunc(attemptErr, isSuccess, pTokens, cTokens) } func (h *ProxyHandler) createModifyResponseFunc(attemptErr **errors.APIError, isSuccess *bool, pTokens, cTokens *int) func(*http.Response) error { return func(resp *http.Response) error { corrID := resp.Request.Header.Get("X-Correlation-ID") h.logger.WithField("id", corrID).Infof("📨 Upstream response: status=%d, contentType=%s", resp.StatusCode, resp.Header.Get("Content-Type")) // 检查是否是流式响应 isStream := strings.Contains(resp.Header.Get("Content-Type"), "text/event-stream") // 处理 gzip 压缩 var reader io.ReadCloser = resp.Body if resp.Header.Get("Content-Encoding") == "gzip" { gzReader, err := gzip.NewReader(resp.Body) if err != nil { h.logger.WithField("id", corrID).WithError(err).Error("Failed to create gzip reader") *attemptErr = errors.NewAPIError(errors.ErrBadGateway, "Failed to decompress response") *isSuccess = false resp.Body = io.NopCloser(bytes.NewReader([]byte{})) return nil } reader = gzReader resp.Header.Del("Content-Encoding") // ✅ 对于流式响应,需要替换 resp.Body 为解压后的 reader if isStream { resp.Body = reader } } if isStream { h.logger.WithField("id", corrID).Info("📡 Processing stream response") if resp.StatusCode < 400 { *isSuccess = true h.logger.WithField("id", corrID).Info("✅ Stream response marked as success, passing through") // 不关闭 reader,让它继续流式传输 return nil } else { // 错误响应才读取完整内容 bodyBytes, err := io.ReadAll(reader) reader.Close() if err != nil { h.logger.WithField("id", corrID).WithError(err).Error("Failed to read error response") *attemptErr = errors.NewAPIError(errors.ErrBadGateway, "Failed to read upstream error") } else { h.logger.WithField("id", corrID).Errorf("❌ Stream error: status=%d, body=%s", resp.StatusCode, string(bodyBytes)) *attemptErr = errors.NewAPIErrorWithUpstream(resp.StatusCode, fmt.Sprintf("UPSTREAM_%d", resp.StatusCode), bodyBytes) } *isSuccess = false resp.Body = io.NopCloser(bytes.NewReader(bodyBytes)) return nil } } // 非流式响应:读取完整内容 h.logger.WithField("id", corrID).Info("📄 Processing non-stream response") bodyBytes, err := io.ReadAll(reader) reader.Close() if err != nil { h.logger.WithField("id", corrID).WithError(err).Error("Failed to read response body") *attemptErr = errors.NewAPIError(errors.ErrBadGateway, "Failed to read upstream response") *isSuccess = false resp.Body = io.NopCloser(bytes.NewReader([]byte{})) return nil } if resp.StatusCode < 400 { *isSuccess = true *pTokens, *cTokens = extractUsage(bodyBytes) h.logger.WithField("id", corrID).Infof("✅ Success: bytes=%d, pTokens=%d, cTokens=%d", len(bodyBytes), *pTokens, *cTokens) } else { h.logger.WithField("id", corrID).Errorf("❌ Error: status=%d, body=%s", resp.StatusCode, string(bodyBytes)) *attemptErr = errors.NewAPIErrorWithUpstream(resp.StatusCode, fmt.Sprintf("UPSTREAM_%d", resp.StatusCode), bodyBytes) *isSuccess = false } resp.Body = io.NopCloser(bytes.NewReader(bodyBytes)) return nil } } func (h *ProxyHandler) transparentProxyErrorHandler(rw http.ResponseWriter, r *http.Request, err error) { corrID := r.Header.Get("X-Correlation-ID") log := h.logger.WithField("id", corrID) log.Errorf("Transparent proxy encountered an error: %v", err) errPtr, ok := r.Context().Value(proxyErrorContextKey{}).(**errors.APIError) if !ok || errPtr == nil { log.Error("FATAL: proxyErrorContextKey not found in context for error handler.") defaultErr := errors.NewAPIError(errors.ErrBadGateway, "An unexpected proxy error occurred") writeErrorToResponse(rw, defaultErr) return } if *errPtr == nil { if errors.IsClientNetworkError(err) { *errPtr = errors.NewAPIError(errors.ErrBadRequest, "Client connection closed") } else { *errPtr = errors.NewAPIError(errors.ErrBadGateway, err.Error()) } } writeErrorToResponse(rw, *errPtr) } func (h *ProxyHandler) getResourcesForAttempt(c *gin.Context, attempt int, initialResources *service.RequestResources, modelName, groupName string, isPreciseRouting bool, correlationID string) (*service.RequestResources, error) { if attempt == 1 { return initialResources, nil } h.logger.WithField("id", correlationID).Infof("Retrying... getting new resources for attempt %d.", attempt) resources, err := h.getResourcesForRequest(c, modelName, groupName, isPreciseRouting) if err != nil { return nil, err } finalRequestConfig := h.buildFinalRequestConfig(h.settingsManager.GetSettings(), resources.RequestConfig) resources.RequestConfig = finalRequestConfig return resources, nil } func (h *ProxyHandler) shouldStopRetrying(attempt, totalAttempts int, err *errors.APIError, correlationID string) bool { if attempt >= totalAttempts { return true } if err == nil { return false } // ✅ 不可重试的请求错误:立即停止 if errors.IsUnretryableRequestError(err.Message) { h.logger.WithField("id", correlationID).Warnf("Unretryable request error, aborting: %s", err.Message) return true } // ✅ 永久性上游错误:立即停止(Key 已失效) if errors.IsPermanentUpstreamError(err.Message) { h.logger.WithField("id", correlationID).Warnf("Permanent upstream error, aborting: %s", err.Message) return false } // ✅ 可重试的网络错误:继续重试 if errors.IsRetryableNetworkError(err.Message) { return false } // ✅ 临时性错误(配额等):继续重试 if errors.IsTemporaryUpstreamError(err.Message) { return false } // ✅ 其他未分类错误:继续重试 return false } func (h *ProxyHandler) writeFinalResponse(c *gin.Context, corrID string, rec *httptest.ResponseRecorder, apiErr *errors.APIError) { if rec != nil { for k, v := range rec.Header() { c.Writer.Header()[k] = v } c.Writer.WriteHeader(rec.Code) c.Writer.Write(rec.Body.Bytes()) } else if apiErr != nil { errToJSON(c, corrID, apiErr) } else { errToJSON(c, corrID, errors.NewAPIError(errors.ErrInternalServer, "An unknown error occurred")) } } func (h *ProxyHandler) publishFinalLogEvent(c *gin.Context, startTime time.Time, corrID, modelName string, res *service.RequestResources, rec *httptest.ResponseRecorder, finalErr *errors.APIError, isSuccess bool, pTokens, cTokens, retries int, isPrecise bool) { if res == nil { h.logger.WithField("id", corrID).Warn("No resources were used, skipping final log event.") return } event := h.createLogEvent(c, startTime, corrID, modelName, res, models.LogTypeFinal, isPrecise) event.RequestLog.LatencyMs = int(time.Since(startTime).Milliseconds()) event.RequestLog.IsSuccess = isSuccess event.RequestLog.Retries = retries if isSuccess { event.RequestLog.PromptTokens, event.RequestLog.CompletionTokens = pTokens, cTokens } if rec != nil { event.RequestLog.StatusCode = rec.Code } if !isSuccess { errToLog := finalErr if errToLog == nil && rec != nil { errToLog = errors.NewAPIErrorWithUpstream(rec.Code, fmt.Sprintf("UPSTREAM_%d", rec.Code), rec.Body.Bytes()) } if errToLog != nil { if errToLog.Code == "" && errToLog.HTTPStatus >= 400 { errToLog.Code = fmt.Sprintf("UPSTREAM_%d", errToLog.HTTPStatus) } event.Error = errToLog event.RequestLog.ErrorCode, event.RequestLog.ErrorMessage = errToLog.Code, errToLog.Message event.RequestLog.Status = errToLog.Status } } eventData, err := json.Marshal(event) if err != nil { h.logger.WithField("id", corrID).WithError(err).Error("Failed to marshal log event") return } if err := h.store.Publish(context.Background(), models.TopicRequestFinished, eventData); err != nil { h.logger.WithField("id", corrID).WithError(err).Error("Failed to publish log event") } } func (h *ProxyHandler) publishRetryLogEvent(c *gin.Context, startTime time.Time, corrID, modelName string, res *service.RequestResources, rec *httptest.ResponseRecorder, attemptErr *errors.APIError, retries int, isPrecise bool) { retryEvent := h.createLogEvent(c, startTime, corrID, modelName, res, models.LogTypeRetry, isPrecise) retryEvent.RequestLog.LatencyMs = int(time.Since(startTime).Milliseconds()) retryEvent.RequestLog.IsSuccess = false retryEvent.RequestLog.StatusCode = rec.Code retryEvent.RequestLog.Retries = retries if attemptErr != nil { retryEvent.Error = attemptErr retryEvent.RequestLog.ErrorCode, retryEvent.RequestLog.ErrorMessage = attemptErr.Code, attemptErr.Message retryEvent.RequestLog.Status = attemptErr.Status } eventData, err := json.Marshal(retryEvent) if err != nil { h.logger.WithField("id", corrID).WithError(err).Error("Failed to marshal retry log event") return } if err := h.store.Publish(context.Background(), models.TopicRequestFinished, eventData); err != nil { h.logger.WithField("id", corrID).WithError(err).Error("Failed to publish retry log event") } } func (h *ProxyHandler) buildFinalRequestConfig(globalSettings *models.SystemSettings, groupConfig *models.RequestConfig) *models.RequestConfig { finalConfig := &models.RequestConfig{ CustomHeaders: make(datatypes.JSONMap), EnableStreamOptimizer: globalSettings.EnableStreamOptimizer, StreamMinDelay: globalSettings.StreamMinDelay, StreamMaxDelay: globalSettings.StreamMaxDelay, StreamShortTextThresh: globalSettings.StreamShortTextThresh, StreamLongTextThresh: globalSettings.StreamLongTextThresh, StreamChunkSize: globalSettings.StreamChunkSize, EnableFakeStream: globalSettings.EnableFakeStream, FakeStreamInterval: globalSettings.FakeStreamInterval, } for k, v := range globalSettings.CustomHeaders { finalConfig.CustomHeaders[k] = v } if groupConfig == nil { return finalConfig } groupConfigJSON, err := json.Marshal(groupConfig) if err != nil { h.logger.WithError(err).Error("Failed to marshal group request config for merging.") return finalConfig } if err := json.Unmarshal(groupConfigJSON, finalConfig); err != nil { h.logger.WithError(err).Error("Failed to unmarshal group request config for merging.") } return finalConfig } func writeErrorToResponse(rw http.ResponseWriter, apiErr *errors.APIError) { if writer, ok := rw.(interface{ Written() bool }); ok && writer.Written() { return } rw.Header().Set("Content-Type", "application/json; charset=utf-8") rw.WriteHeader(apiErr.HTTPStatus) json.NewEncoder(rw).Encode(gin.H{"error": apiErr}) } func (h *ProxyHandler) serveSmartStream(c *gin.Context, requestBody []byte, resources *service.RequestResources, isPreciseRouting bool) { startTime := time.Now() correlationID := uuid.New().String() log := h.logger.WithField("id", correlationID) log.Info("Smart Gateway activated for streaming request.") var originalRequest models.GeminiRequest if err := json.Unmarshal(requestBody, &originalRequest); err != nil { errToJSON(c, correlationID, errors.NewAPIError(errors.ErrInvalidJSON, "Invalid request format for Smart Gateway")) return } systemSettings := h.settingsManager.GetSettings() modelName := h.channel.ExtractModel(c, requestBody) requestFinishedEvent := h.createLogEvent(c, startTime, correlationID, modelName, resources, models.LogTypeFinal, isPreciseRouting) defer func() { requestFinishedEvent.LatencyMs = int(time.Since(startTime).Milliseconds()) if c.Writer.Status() > 0 { requestFinishedEvent.StatusCode = c.Writer.Status() } eventData, err := json.Marshal(requestFinishedEvent) if err != nil { h.logger.WithField("id", correlationID).WithError(err).Error("Failed to marshal final log event for smart stream") return } if err := h.store.Publish(context.Background(), models.TopicRequestFinished, eventData); err != nil { h.logger.WithField("id", correlationID).WithError(err).Error("Failed to publish final log event for smart stream") } }() params := channel.SmartRequestParams{ CorrelationID: correlationID, APIKey: resources.APIKey, UpstreamURL: resources.UpstreamEndpoint.URL, RequestBody: requestBody, OriginalRequest: originalRequest, EventLogger: requestFinishedEvent, MaxRetries: systemSettings.MaxStreamingRetries, RetryDelay: time.Duration(systemSettings.StreamingRetryDelayMs) * time.Millisecond, LogTruncationLimit: systemSettings.LogTruncationLimit, StreamingRetryPrompt: systemSettings.StreamingRetryPrompt, } h.channel.ProcessSmartStreamRequest(c, params) } func (h *ProxyHandler) createLogEvent(c *gin.Context, startTime time.Time, corrID, modelName string, res *service.RequestResources, logType models.LogType, isPreciseRouting bool) *models.RequestFinishedEvent { event := &models.RequestFinishedEvent{ RequestLog: models.RequestLog{ RequestTime: startTime, ModelName: modelName, RequestPath: c.Request.URL.Path, UserAgent: c.Request.UserAgent(), CorrelationID: corrID, LogType: logType, Metadata: make(datatypes.JSONMap), }, CorrelationID: corrID, IsPreciseRouting: isPreciseRouting, } if _, exists := c.Get(middleware.RedactedBodyKey); exists { event.RequestLog.Metadata["request_body_present"] = true } if redactedAuth, exists := c.Get(middleware.RedactedAuthHeaderKey); exists { event.RequestLog.Metadata["authorization_header"] = redactedAuth.(string) } if authTokenValue, exists := c.Get("authToken"); exists { if authToken, ok := authTokenValue.(*models.AuthToken); ok { event.RequestLog.AuthTokenID = &authToken.ID } } if res != nil { if res.APIKey != nil { event.RequestLog.KeyID = &res.APIKey.ID } if res.KeyGroup != nil { event.RequestLog.GroupID = &res.KeyGroup.ID } if res.UpstreamEndpoint != nil { event.RequestLog.UpstreamID = &res.UpstreamEndpoint.ID event.UpstreamURL = &res.UpstreamEndpoint.URL } if res.ProxyConfig != nil { event.RequestLog.ProxyID = &res.ProxyConfig.ID } } return event } func (h *ProxyHandler) getResourcesForRequest(c *gin.Context, modelName string, groupName string, isPreciseRouting bool) (*service.RequestResources, error) { authTokenValue, exists := c.Get("authToken") if !exists { return nil, errors.NewAPIError(errors.ErrUnauthorized, "Auth token not found in context") } authToken, ok := authTokenValue.(*models.AuthToken) if !ok { return nil, errors.NewAPIError(errors.ErrInternalServer, "Invalid auth token type in context") } if isPreciseRouting { return h.resourceService.GetResourceFromGroup(c.Request.Context(), authToken, groupName) } return h.resourceService.GetResourceFromBasePool(c.Request.Context(), authToken, modelName) } func errToJSON(c *gin.Context, corrID string, apiErr *errors.APIError) { if c.IsAborted() { return } c.JSON(apiErr.HTTPStatus, gin.H{ "error": apiErr, "correlation_id": corrID, }) } type bufferPool struct{} func (b *bufferPool) Get() []byte { return make([]byte, 32*1024) } func (b *bufferPool) Put(_ []byte) {} func extractUsage(body []byte) (promptTokens int, completionTokens int) { var data struct { UsageMetadata struct { PromptTokenCount int `json:"promptTokenCount"` CandidatesTokenCount int `json:"candidatesTokenCount"` } `json:"usageMetadata"` } if err := json.Unmarshal(body, &data); err == nil { return data.UsageMetadata.PromptTokenCount, data.UsageMetadata.CandidatesTokenCount } return 0, 0 } func (h *ProxyHandler) getMaxRetries(isPreciseRouting bool, finalOpConfig *models.KeyGroupSettings) int { if isPreciseRouting && finalOpConfig.MaxRetries != nil { return *finalOpConfig.MaxRetries } return h.settingsManager.GetSettings().MaxRetries } func (h *ProxyHandler) handleListModelsRequest(c *gin.Context) { authTokenValue, exists := c.Get("authToken") if !exists { errToJSON(c, uuid.New().String(), errors.NewAPIError(errors.ErrUnauthorized, "Auth token not found in context")) return } authToken, ok := authTokenValue.(*models.AuthToken) if !ok { errToJSON(c, uuid.New().String(), errors.NewAPIError(errors.ErrInternalServer, "Invalid auth token type in context")) return } modelNames := h.resourceService.GetAllowedModelsForToken(authToken) if strings.Contains(c.Request.URL.Path, "/v1beta/") { h.respondWithGeminiFormat(c, modelNames) } else { h.respondWithOpenAIFormat(c, modelNames) } } func (h *ProxyHandler) respondWithOpenAIFormat(c *gin.Context, modelNames []string) { type ModelEntry struct { ID string `json:"id"` Object string `json:"object"` Created int64 `json:"created"` OwnedBy string `json:"owned_by"` } type ModelListResponse struct { Object string `json:"object"` Data []ModelEntry `json:"data"` } data := make([]ModelEntry, len(modelNames)) for i, name := range modelNames { data[i] = ModelEntry{ ID: name, Object: "model", Created: time.Now().Unix(), OwnedBy: "gemini-balancer", } } response := ModelListResponse{ Object: "list", Data: data, } c.JSON(http.StatusOK, response) } func (h *ProxyHandler) respondWithGeminiFormat(c *gin.Context, modelNames []string) { type GeminiModelEntry struct { Name string `json:"name"` Version string `json:"version"` DisplayName string `json:"displayName"` Description string `json:"description"` SupportedGenerationMethods []string `json:"supportedGenerationMethods"` InputTokenLimit int `json:"inputTokenLimit"` OutputTokenLimit int `json:"outputTokenLimit"` } type GeminiModelListResponse struct { Models []GeminiModelEntry `json:"models"` } models := make([]GeminiModelEntry, len(modelNames)) for i, name := range modelNames { models[i] = GeminiModelEntry{ Name: fmt.Sprintf("models/%s", name), Version: "1.0.0", DisplayName: name, Description: "Served by Gemini Balancer", SupportedGenerationMethods: []string{"generateContent", "streamGenerateContent"}, InputTokenLimit: 8192, OutputTokenLimit: 2048, } } response := GeminiModelListResponse{Models: models} c.JSON(http.StatusOK, response) }