// Filename: internal/handlers/proxy_handler.go
package handlers
import (
"bytes"
"compress/gzip"
"context"
"encoding/json"
"fmt"
"gemini-balancer/internal/channel"
"gemini-balancer/internal/errors"
"gemini-balancer/internal/middleware"
"gemini-balancer/internal/models"
"gemini-balancer/internal/service"
"gemini-balancer/internal/settings"
"gemini-balancer/internal/store"
"io"
"net"
"net/http"
"net/http/httptest"
"net/http/httputil"
"net/url"
"strings"
"time"
"github.com/gin-gonic/gin"
"github.com/google/uuid"
"github.com/sirupsen/logrus"
"gorm.io/datatypes"
)
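// proxyErrorContextKey is the context key under which executeProxyAttempt stores a
// **errors.APIError so the transparent proxy's ErrorHandler can report failures back
// to the attempt that issued the request.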
type proxyErrorContextKey struct{}
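// ProxyHandler routes incoming API requests to upstream endpoints, handling key/group
// selection, retries, streaming, and request logging.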
type ProxyHandler struct {
resourceService *service.ResourceService
store store.Store
settingsManager *settings.SettingsManager
groupManager *service.GroupManager
channel channel.ChannelProxy
logger *logrus.Entry
transparentProxy *httputil.ReverseProxy
}
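// NewProxyHandler wires the handler's dependencies and initializes the shared
// transparent ReverseProxy (transport timeouts, error handler, and buffer pool).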
func NewProxyHandler(
resourceService *service.ResourceService,
store store.Store,
sm *settings.SettingsManager,
gm *service.GroupManager,
channel channel.ChannelProxy,
logger *logrus.Logger,
) *ProxyHandler {
ph := &ProxyHandler{
resourceService: resourceService,
store: store,
settingsManager: sm,
groupManager: gm,
channel: channel,
logger: logger.WithField("component", "ProxyHandler"),
transparentProxy: &httputil.ReverseProxy{},
}
ph.transparentProxy.Transport = &http.Transport{
Proxy: http.ProxyFromEnvironment,
DialContext: (&net.Dialer{
Timeout: 30 * time.Second,
KeepAlive: 60 * time.Second,
}).DialContext,
MaxIdleConns: 100,
IdleConnTimeout: 90 * time.Second,
TLSHandshakeTimeout: 10 * time.Second,
ExpectContinueTimeout: 1 * time.Second,
}
ph.transparentProxy.ErrorHandler = ph.transparentProxyErrorHandler
ph.transparentProxy.BufferPool = &bufferPool{}
return ph
}
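// HandleProxy is the entry point for proxied requests: model-list GETs are answered
// locally, the body is read with a size limit, the model name is extracted, resources
// are resolved (precise group routing when a group name is present, otherwise the base
// pool by model), and the request is dispatched to the smart stream gateway or the
// transparent proxy.
//
// Illustrative wiring only (the actual route registration lives outside this file; the
// paths below are assumptions inferred from c.Param("group_name") and the
// "/proxy/<group>" prefix handling, not taken from this file):
//
//	handler := NewProxyHandler(resourceSvc, st, settingsMgr, groupMgr, geminiChannel, logrus.New())
//	router.Any("/proxy/:group_name/*path", handler.HandleProxy) // precise routing to a key group
//	router.Any("/v1beta/*path", handler.HandleProxy)            // base-pool routing by model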
func (h *ProxyHandler) HandleProxy(c *gin.Context) {
if c.Request.Method == "GET" && (strings.HasSuffix(c.Request.URL.Path, "/models") || strings.HasSuffix(c.Request.URL.Path, "/models/")) {
h.handleListModelsRequest(c)
return
}
maxBodySize := int64(h.settingsManager.GetSettings().MaxRequestBodySizeMB * 1024 * 1024)
// Read at most maxBodySize+1 bytes so an oversized body is detected instead of being silently truncated by the limit.
requestBody, err := io.ReadAll(io.LimitReader(c.Request.Body, maxBodySize+1))
if err != nil {
errToJSON(c, uuid.New().String(), errors.NewAPIError(errors.ErrBadRequest, "Failed to read request body"))
return
}
if int64(len(requestBody)) > maxBodySize {
errToJSON(c, uuid.New().String(), errors.NewAPIError(errors.ErrBadRequest, "Request body too large"))
return
}
modelName := h.channel.ExtractModel(c, requestBody)
groupName := c.Param("group_name")
isPreciseRouting := groupName != ""
if !isPreciseRouting && modelName == "" {
errToJSON(c, uuid.New().String(), errors.NewAPIError(errors.ErrBadRequest, "Model not specified in request"))
return
}
initialResources, err := h.getResourcesForRequest(c, modelName, groupName, isPreciseRouting)
if err != nil {
if apiErr, ok := err.(*errors.APIError); ok {
errToJSON(c, uuid.New().String(), apiErr)
} else {
errToJSON(c, uuid.New().String(), errors.NewAPIError(errors.ErrInternalServer, "Failed to get initial resources"))
}
return
}
finalOpConfig, err := h.groupManager.BuildOperationalConfig(initialResources.KeyGroup)
if err != nil {
h.logger.WithError(err).Error("Failed to build operational config.")
errToJSON(c, uuid.New().String(), errors.NewAPIError(errors.ErrInternalServer, "Failed to build operational config"))
return
}
initialResources.RequestConfig = h.buildFinalRequestConfig(h.settingsManager.GetSettings(), initialResources.RequestConfig)
isOpenAICompatible := h.channel.IsOpenAICompatibleRequest(c)
if isOpenAICompatible {
h.serveTransparentProxy(c, requestBody, initialResources, finalOpConfig, modelName, groupName, isPreciseRouting)
return
}
isStream := h.channel.IsStreamRequest(c, requestBody)
useSmartGateway := finalOpConfig.EnableSmartGateway != nil && *finalOpConfig.EnableSmartGateway
if useSmartGateway && isStream && h.settingsManager.GetSettings().EnableStreamingRetry {
h.serveSmartStream(c, requestBody, initialResources, isPreciseRouting)
} else {
h.serveTransparentProxy(c, requestBody, initialResources, finalOpConfig, modelName, groupName, isPreciseRouting)
}
}
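// serveTransparentProxy drives the retry loop for non-streaming requests (streaming
// requests are delegated to serveStreamWithRetry). Each attempt is captured in a
// ResponseRecorder and reported to the resource service; the final outcome is written
// to the client and published as a log event.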
func (h *ProxyHandler) serveTransparentProxy(c *gin.Context, requestBody []byte, initialResources *service.RequestResources, finalOpConfig *models.KeyGroupSettings, modelName, groupName string, isPreciseRouting bool) {
startTime := time.Now()
correlationID := uuid.New().String()
// ✅ Check whether this is a streaming request
isStreamRequest := h.channel.IsStreamRequest(c, requestBody)
// ✅ Streaming requests also support retries
if isStreamRequest {
h.serveStreamWithRetry(c, requestBody, initialResources, finalOpConfig, modelName, groupName, isPreciseRouting, correlationID, startTime)
return
}
var finalRecorder *httptest.ResponseRecorder
var lastUsedResources *service.RequestResources
var finalProxyErr *errors.APIError
var isSuccess bool
var finalPromptTokens, finalCompletionTokens, actualRetries int
defer func() {
h.publishFinalLogEvent(c, startTime, correlationID, modelName, lastUsedResources,
finalRecorder, finalProxyErr, isSuccess, finalPromptTokens, finalCompletionTokens,
actualRetries, isPreciseRouting)
}()
maxRetries := h.getMaxRetries(isPreciseRouting, finalOpConfig)
totalAttempts := maxRetries + 1
for attempt := 1; attempt <= totalAttempts; attempt++ {
if c.Request.Context().Err() != nil {
h.logger.WithField("id", correlationID).Info("Client disconnected, aborting retry loop.")
if finalProxyErr == nil {
finalProxyErr = errors.NewAPIError(errors.ErrBadRequest, "Client disconnected")
}
break
}
resources, err := h.getResourcesForAttempt(c, attempt, initialResources, modelName, groupName, isPreciseRouting, correlationID)
if err != nil {
h.logger.WithField("id", correlationID).Errorf("❌ getResourcesForAttempt failed: %v", err)
if apiErr, ok := err.(*errors.APIError); ok {
finalProxyErr = apiErr
} else {
finalProxyErr = errors.NewAPIError(errors.ErrInternalServer, "Failed to get resources for retry")
}
break
}
h.logger.WithField("id", correlationID).Infof("✅ Got resources: KeyID=%d", resources.APIKey.ID)
// lastUsedResources = resources
if attempt > 1 {
actualRetries = attempt - 1
}
h.logger.WithField("id", correlationID).Infof("Attempt %d/%d with KeyID %d", attempt, totalAttempts, resources.APIKey.ID)
recorder, attemptErr, attemptSuccess := h.executeProxyAttempt(
c, correlationID, requestBody, resources, isPreciseRouting, groupName,
&finalPromptTokens, &finalCompletionTokens,
)
h.logger.WithField("id", correlationID).Infof("✅ Before assignment: lastUsedResources=%v", lastUsedResources)
finalRecorder, finalProxyErr, isSuccess = recorder, attemptErr, attemptSuccess
// ✅ Correct isSuccess based on the recorded error and status code
if finalProxyErr != nil || (finalRecorder != nil && finalRecorder.Code >= 400) {
isSuccess = false
}
lastUsedResources = resources
h.logger.WithField("id", correlationID).Infof("✅ After assignment: lastUsedResources=%v", lastUsedResources)
h.resourceService.ReportRequestResult(resources, isSuccess, finalProxyErr)
if isSuccess {
break
}
if h.shouldStopRetrying(attempt, totalAttempts, finalProxyErr, correlationID) {
break
}
h.publishRetryLogEvent(c, startTime, correlationID, modelName, resources, recorder, attemptErr, actualRetries, isPreciseRouting)
}
h.writeFinalResponse(c, correlationID, finalRecorder, finalProxyErr)
}
// serveStreamWithRetry serves a streaming request through a per-attempt ReverseProxy with retry support.
func (h *ProxyHandler) serveStreamWithRetry(c *gin.Context, requestBody []byte, initialResources *service.RequestResources, finalOpConfig *models.KeyGroupSettings, modelName, groupName string, isPreciseRouting bool, correlationID string, startTime time.Time) {
initialResources.RequestConfig = h.buildFinalRequestConfig(
h.settingsManager.GetSettings(),
initialResources.RequestConfig,
)
h.logger.WithField("id", correlationID).Info("🌊 Serving stream request with retry support")
var lastUsedResources *service.RequestResources
var finalProxyErr *errors.APIError
var isSuccess bool
var actualRetries int
defer func() {
h.publishFinalLogEvent(c, startTime, correlationID, modelName, lastUsedResources,
nil, finalProxyErr, isSuccess, 0, 0, actualRetries, isPreciseRouting)
}()
maxRetries := h.getMaxRetries(isPreciseRouting, finalOpConfig)
totalAttempts := maxRetries + 1
for attempt := 1; attempt <= totalAttempts; attempt++ {
// ✅ Check whether the client has disconnected
if c.Request.Context().Err() != nil {
h.logger.WithField("id", correlationID).Info("Client disconnected, aborting retry loop.")
if finalProxyErr == nil {
finalProxyErr = errors.NewAPIError(errors.ErrBadRequest, "Client disconnected")
}
break
}
// ✅ Acquire resources (the first attempt uses initialResources; later retries fetch fresh resources)
resources, err := h.getResourcesForAttempt(c, attempt, initialResources, modelName, groupName, isPreciseRouting, correlationID)
if err != nil {
h.logger.WithField("id", correlationID).Errorf("❌ Failed to get resources: %v", err)
if apiErr, ok := err.(*errors.APIError); ok {
finalProxyErr = apiErr
} else {
finalProxyErr = errors.NewAPIError(errors.ErrInternalServer, "Failed to get resources for retry")
}
break
}
if attempt > 1 {
actualRetries = attempt - 1
}
h.logger.WithField("id", correlationID).Infof("🔄 Stream attempt %d/%d (KeyID=%d, GroupID=%d)",
attempt, totalAttempts, resources.APIKey.ID, resources.KeyGroup.ID)
// ✅ Execute the streaming proxy attempt
attemptErr, attemptSuccess := h.executeStreamAttempt(
c, correlationID, requestBody, resources, groupName, isPreciseRouting,
)
finalProxyErr, isSuccess = attemptErr, attemptSuccess
lastUsedResources = resources
// ✅ Report the result
h.resourceService.ReportRequestResult(resources, isSuccess, finalProxyErr)
// ✅ Exit on success
if isSuccess {
h.logger.WithField("id", correlationID).Info("✅ Stream request succeeded")
break
}
// ✅ Decide whether to stop retrying
if h.shouldStopRetrying(attempt, totalAttempts, finalProxyErr, correlationID) {
// ✅ Log the error safely (nil check added)
if finalProxyErr != nil {
h.logger.WithField("id", correlationID).Warnf("⛔ Stopping retry: %s", finalProxyErr.Message)
} else {
h.logger.WithField("id", correlationID).Warn("⛔ Stopping retry: unknown error")
}
break
}
// ✅ Publish the retry log event
h.publishStreamRetryLogEvent(c, startTime, correlationID, modelName, resources, attemptErr, actualRetries, isPreciseRouting)
// ✅ Keep the retry log concise
if attempt < totalAttempts {
h.logger.WithField("id", correlationID).Infof("🔁 Retrying... (%d/%d)", attempt, totalAttempts-1)
}
}
// ✅ If all attempts failed, write the error response
if !isSuccess && finalProxyErr != nil {
h.logger.WithField("id", correlationID).Warnf("❌ All stream attempts failed: %s (code=%s)",
finalProxyErr.Message, finalProxyErr.Code)
// ✅ Check whether response headers have already been written
if !c.Writer.Written() {
errToJSON(c, correlationID, finalProxyErr)
} else {
h.logger.WithField("id", correlationID).Warn("⚠️ Cannot write error, response already started")
}
}
}
// executeStreamAttempt performs a single streaming proxy attempt and reports its outcome.
func (h *ProxyHandler) executeStreamAttempt(
c *gin.Context,
correlationID string,
requestBody []byte,
resources *service.RequestResources,
groupName string,
isPreciseRouting bool,
) (finalErr *errors.APIError, finalSuccess bool) { // ✅ Named return values so the deferred recover can overwrite them
// ✅ Recover from the ReverseProxy's ErrAbortHandler panic
defer func() {
if r := recover(); r != nil {
// ✅ http.ErrAbortHandler means the streaming response has already completed successfully
if r == http.ErrAbortHandler {
h.logger.WithField("id", correlationID).Debug("✅ Stream completed (ErrAbortHandler caught)")
// ✅ Overwrite the named return values to report success
finalErr = nil
finalSuccess = true
return
}
// ✅ Re-panic on anything else
h.logger.WithField("id", correlationID).Errorf("❌ Unexpected panic in stream: %v", r)
panic(r)
}
}()
var attemptErr *errors.APIError
var isSuccess bool
requestTimeout := time.Duration(h.settingsManager.GetSettings().RequestTimeoutSeconds) * time.Second
ctx, cancel := context.WithTimeout(c.Request.Context(), requestTimeout)
defer cancel()
attemptReq := c.Request.Clone(ctx)
attemptReq.Body = io.NopCloser(bytes.NewReader(requestBody))
attemptReq.ContentLength = int64(len(requestBody))
// ✅ Create an independent ReverseProxy for this attempt
streamProxy := &httputil.ReverseProxy{
Transport: h.transparentProxy.Transport,
BufferPool: h.transparentProxy.BufferPool,
}
streamProxy.Director = func(r *http.Request) {
targetURL, _ := url.Parse(resources.UpstreamEndpoint.URL)
r.URL.Scheme, r.URL.Host, r.Host = targetURL.Scheme, targetURL.Host, targetURL.Host
var pureClientPath string
if isPreciseRouting {
pureClientPath = strings.TrimPrefix(r.URL.Path, "/proxy/"+groupName)
} else {
pureClientPath = r.URL.Path
}
r.URL.Path = h.channel.RewritePath(targetURL.Path, pureClientPath)
r.Header.Del("Authorization")
h.channel.ModifyRequest(r, resources.APIKey)
r.Header.Set("X-Correlation-ID", correlationID)
// ✅ Apply custom request headers from the request config
if resources.RequestConfig != nil {
for k, v := range resources.RequestConfig.CustomHeaders {
if strVal, ok := v.(string); ok {
r.Header.Set(k, strVal)
}
}
}
}
// ✅ Configure the Transport
transport := streamProxy.Transport.(*http.Transport)
if resources.ProxyConfig != nil {
proxyURLStr := fmt.Sprintf("%s://%s", resources.ProxyConfig.Protocol, resources.ProxyConfig.Address)
if proxyURL, err := url.Parse(proxyURLStr); err == nil {
transportCopy := transport.Clone()
transportCopy.Proxy = http.ProxyURL(proxyURL)
streamProxy.Transport = transportCopy
h.logger.WithField("id", correlationID).Infof("🔀 Using proxy: %s", proxyURLStr)
}
}
// ✅ Configure ModifyResponse
streamProxy.ModifyResponse = func(resp *http.Response) error {
h.logger.WithField("id", correlationID).Infof("📨 Stream response: status=%d, contentType=%s",
resp.StatusCode, resp.Header.Get("Content-Type"))
// ✅ Handle gzip decompression
if resp.Header.Get("Content-Encoding") == "gzip" {
gzReader, err := gzip.NewReader(resp.Body)
if err != nil {
h.logger.WithField("id", correlationID).WithError(err).Error("Failed to create gzip reader")
attemptErr = errors.NewAPIError(errors.ErrBadGateway, "Failed to decompress response")
isSuccess = false
return fmt.Errorf("gzip decompression failed: %w", err)
}
resp.Body = gzReader
resp.Header.Del("Content-Encoding")
}
// ✅ Success response: pass it straight through
if resp.StatusCode < 400 {
isSuccess = true
h.logger.WithField("id", correlationID).Info("✅ Stream response marked as success")
return nil
}
// ✅ Error response: read the error details (used to decide on retries)
isSuccess = false
// ✅ Read the error response body
bodyBytes, err := io.ReadAll(resp.Body)
if err != nil {
h.logger.WithField("id", correlationID).WithError(err).Error("Failed to read error response")
attemptErr = errors.NewAPIError(errors.ErrBadGateway, "Failed to read upstream error")
} else {
// ✅ Decide from the status code whether to log the full error body
shouldLogErrorBody := h.shouldLogErrorBody(resp.StatusCode)
if shouldLogErrorBody {
h.logger.WithField("id", correlationID).Errorf("❌ Stream error: status=%d, body=%s",
resp.StatusCode, string(bodyBytes))
} else {
// ✅ For common errors (e.g. 429, 403) log only a brief summary
errorSummary := h.extractErrorSummary(bodyBytes)
h.logger.WithField("id", correlationID).Warnf("⚠️ Stream error: status=%d, summary=%s",
resp.StatusCode, errorSummary)
}
attemptErr = errors.NewAPIErrorWithUpstream(resp.StatusCode,
fmt.Sprintf("UPSTREAM_%d", resp.StatusCode), bodyBytes)
}
// ✅ Return an error to trigger the ErrorHandler, but do not write a response here (a retry may still follow)
return fmt.Errorf("upstream error: status %d", resp.StatusCode)
}
// ✅ Configure the ErrorHandler
streamProxy.ErrorHandler = func(w http.ResponseWriter, r *http.Request, err error) {
h.logger.WithField("id", correlationID).Debugf("Stream proxy error handler triggered: %v", err)
// ✅ If attemptErr has not been set yet, derive it from the error type
if attemptErr == nil {
isSuccess = false
if err == context.DeadlineExceeded {
attemptErr = errors.NewAPIError(errors.ErrGatewayTimeout, "Request timeout")
} else if err == context.Canceled {
attemptErr = errors.NewAPIError(errors.ErrBadRequest, "Request canceled")
} else if errors.IsClientNetworkError(err) {
attemptErr = errors.NewAPIError(errors.ErrBadRequest, "Client connection closed")
} else {
attemptErr = errors.NewAPIError(errors.ErrBadGateway, err.Error())
}
}
// ✅ Do not write a response here; the outer retry loop decides what to send
}
// ✅ Run the proxy request (may panic with http.ErrAbortHandler)
streamProxy.ServeHTTP(c.Writer, attemptReq)
// ✅ Normal return path (no panic occurred)
return attemptErr, isSuccess
}
// shouldLogErrorBody reports whether the full upstream error body should be logged for the given status code.
func (h *ProxyHandler) shouldLogErrorBody(statusCode int) bool {
// ✅ Skip full error bodies for common client and rate-limit errors
commonErrors := map[int]bool{
400: true, // Bad Request
401: true, // Unauthorized
403: true, // Forbidden
404: true, // Not Found
429: true, // Too Many Requests
}
return !commonErrors[statusCode]
}
// extractErrorSummary extracts a short, human-readable summary from an upstream error response body.
func (h *ProxyHandler) extractErrorSummary(bodyBytes []byte) string {
// ✅ Try to parse a JSON error response first
var errorResp struct {
Error struct {
Message string `json:"message"`
Code int `json:"code"`
Status string `json:"status"`
} `json:"error"`
}
if err := json.Unmarshal(bodyBytes, &errorResp); err == nil && errorResp.Error.Message != "" {
// ✅ Truncate the error message to its first 100 characters
message := errorResp.Error.Message
if len(message) > 100 {
message = message[:100] + "..."
}
if errorResp.Error.Status != "" {
return fmt.Sprintf("%s: %s", errorResp.Error.Status, message)
}
return message
}
// ✅ If the body is not parseable JSON, return its first 100 characters
if len(bodyBytes) > 100 {
return string(bodyBytes[:100]) + "..."
}
return string(bodyBytes)
}
// publishStreamRetryLogEvent publishes a retry log event for a failed streaming attempt.
func (h *ProxyHandler) publishStreamRetryLogEvent(c *gin.Context, startTime time.Time, corrID, modelName string, res *service.RequestResources, attemptErr *errors.APIError, retries int, isPrecise bool) {
retryEvent := h.createLogEvent(c, startTime, corrID, modelName, res, models.LogTypeRetry, isPrecise)
retryEvent.RequestLog.LatencyMs = int(time.Since(startTime).Milliseconds())
retryEvent.RequestLog.IsSuccess = false
retryEvent.RequestLog.Retries = retries
if attemptErr != nil {
retryEvent.Error = attemptErr
retryEvent.RequestLog.ErrorCode = attemptErr.Code
retryEvent.RequestLog.ErrorMessage = attemptErr.Message
retryEvent.RequestLog.Status = attemptErr.Status
retryEvent.RequestLog.StatusCode = attemptErr.HTTPStatus
}
eventData, err := json.Marshal(retryEvent)
if err != nil {
h.logger.WithField("id", corrID).WithError(err).Error("Failed to marshal stream retry log event")
return
}
if err := h.store.Publish(context.Background(), models.TopicRequestFinished, eventData); err != nil {
h.logger.WithField("id", corrID).WithError(err).Error("Failed to publish stream retry log event")
}
}
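// executeProxyAttempt performs a single non-streaming attempt through the shared
// transparent proxy, capturing the upstream response in a recorder and returning it
// together with any API error and a success flag.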
func (h *ProxyHandler) executeProxyAttempt(c *gin.Context, corrID string, body []byte, res *service.RequestResources, isPrecise bool, groupName string, pTokens, cTokens *int) (*httptest.ResponseRecorder, *errors.APIError, bool) {
recorder := httptest.NewRecorder()
var attemptErr *errors.APIError
isSuccess := false
requestTimeout := time.Duration(h.settingsManager.GetSettings().RequestTimeoutSeconds) * time.Second
ctx, cancel := context.WithTimeout(c.Request.Context(), requestTimeout)
defer cancel()
attemptReq := c.Request.Clone(ctx)
attemptReq.Body = io.NopCloser(bytes.NewReader(body))
attemptReq.ContentLength = int64(len(body))
h.logger.WithField("id", corrID).Infof("🚀 Starting proxy attempt with KeyID=%d", res.APIKey.ID)
h.configureProxy(corrID, res, isPrecise, groupName, &attemptErr, &isSuccess, pTokens, cTokens)
*attemptReq = *attemptReq.WithContext(context.WithValue(attemptReq.Context(), proxyErrorContextKey{}, &attemptErr))
h.transparentProxy.ServeHTTP(recorder, attemptReq)
h.logger.WithField("id", corrID).Infof("📥 Proxy returned: status=%d, bodyLen=%d, err=%v, success=%v",
recorder.Code, recorder.Body.Len(), attemptErr, isSuccess)
// ✅ Safety check: backfill the status and body if the attempt failed without writing anything
if recorder.Code == 0 && attemptErr != nil {
h.logger.WithField("id", corrID).Warnf("⚠️ Fixing zero status code to %d", attemptErr.HTTPStatus)
recorder.Code = attemptErr.HTTPStatus
if recorder.Body.Len() == 0 {
errJSON, _ := json.Marshal(gin.H{"error": attemptErr})
recorder.Body.Write(errJSON)
}
}
return recorder, attemptErr, isSuccess
}
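// configureProxy prepares the shared transparent ReverseProxy for the current attempt:
// target rewriting and auth injection in the Director, the outbound proxy on the
// Transport, and the ModifyResponse callback.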
func (h *ProxyHandler) configureProxy(corrID string, res *service.RequestResources, isPrecise bool, groupName string, attemptErr **errors.APIError, isSuccess *bool, pTokens, cTokens *int) {
h.transparentProxy.Director = func(r *http.Request) {
targetURL, _ := url.Parse(res.UpstreamEndpoint.URL)
r.URL.Scheme, r.URL.Host, r.Host = targetURL.Scheme, targetURL.Host, targetURL.Host
var pureClientPath string
if isPrecise {
pureClientPath = strings.TrimPrefix(r.URL.Path, "/proxy/"+groupName)
} else {
pureClientPath = r.URL.Path
}
r.URL.Path = h.channel.RewritePath(targetURL.Path, pureClientPath)
r.Header.Del("Authorization")
h.channel.ModifyRequest(r, res.APIKey)
r.Header.Set("X-Correlation-ID", corrID)
}
transport := h.transparentProxy.Transport.(*http.Transport)
if res.ProxyConfig != nil {
proxyURLStr := fmt.Sprintf("%s://%s", res.ProxyConfig.Protocol, res.ProxyConfig.Address)
if proxyURL, err := url.Parse(proxyURLStr); err == nil {
transport.Proxy = http.ProxyURL(proxyURL)
} else {
transport.Proxy = http.ProxyFromEnvironment
}
} else {
transport.Proxy = http.ProxyFromEnvironment
}
h.transparentProxy.ModifyResponse = h.createModifyResponseFunc(attemptErr, isSuccess, pTokens, cTokens)
}
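// createModifyResponseFunc returns a ModifyResponse callback that decompresses gzip
// bodies, passes streaming success responses through untouched, extracts token usage
// from non-streaming successes, and converts upstream errors into *errors.APIError via
// the supplied out-pointers.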
func (h *ProxyHandler) createModifyResponseFunc(attemptErr **errors.APIError, isSuccess *bool, pTokens, cTokens *int) func(*http.Response) error {
return func(resp *http.Response) error {
corrID := resp.Request.Header.Get("X-Correlation-ID")
h.logger.WithField("id", corrID).Infof("📨 Upstream response: status=%d, contentType=%s",
resp.StatusCode, resp.Header.Get("Content-Type"))
// Check whether this is a streaming (SSE) response
isStream := strings.Contains(resp.Header.Get("Content-Type"), "text/event-stream")
// Handle gzip-compressed bodies
var reader io.ReadCloser = resp.Body
if resp.Header.Get("Content-Encoding") == "gzip" {
gzReader, err := gzip.NewReader(resp.Body)
if err != nil {
h.logger.WithField("id", corrID).WithError(err).Error("Failed to create gzip reader")
*attemptErr = errors.NewAPIError(errors.ErrBadGateway, "Failed to decompress response")
*isSuccess = false
resp.Body = io.NopCloser(bytes.NewReader([]byte{}))
return nil
}
reader = gzReader
resp.Header.Del("Content-Encoding")
// ✅ For streaming responses, swap resp.Body for the decompressed reader
if isStream {
resp.Body = reader
}
}
if isStream {
h.logger.WithField("id", corrID).Info("📡 Processing stream response")
if resp.StatusCode < 400 {
*isSuccess = true
h.logger.WithField("id", corrID).Info("✅ Stream response marked as success, passing through")
// Do not close the reader; it keeps streaming to the client
return nil
} else {
// Only error responses have their full body read
bodyBytes, err := io.ReadAll(reader)
reader.Close()
if err != nil {
h.logger.WithField("id", corrID).WithError(err).Error("Failed to read error response")
*attemptErr = errors.NewAPIError(errors.ErrBadGateway, "Failed to read upstream error")
} else {
h.logger.WithField("id", corrID).Errorf("❌ Stream error: status=%d, body=%s",
resp.StatusCode, string(bodyBytes))
*attemptErr = errors.NewAPIErrorWithUpstream(resp.StatusCode,
fmt.Sprintf("UPSTREAM_%d", resp.StatusCode), bodyBytes)
}
*isSuccess = false
resp.Body = io.NopCloser(bytes.NewReader(bodyBytes))
return nil
}
}
// Non-streaming response: read the full body
h.logger.WithField("id", corrID).Info("📄 Processing non-stream response")
bodyBytes, err := io.ReadAll(reader)
reader.Close()
if err != nil {
h.logger.WithField("id", corrID).WithError(err).Error("Failed to read response body")
*attemptErr = errors.NewAPIError(errors.ErrBadGateway, "Failed to read upstream response")
*isSuccess = false
resp.Body = io.NopCloser(bytes.NewReader([]byte{}))
return nil
}
if resp.StatusCode < 400 {
*isSuccess = true
*pTokens, *cTokens = extractUsage(bodyBytes)
h.logger.WithField("id", corrID).Infof("✅ Success: bytes=%d, pTokens=%d, cTokens=%d",
len(bodyBytes), *pTokens, *cTokens)
} else {
h.logger.WithField("id", corrID).Errorf("❌ Error: status=%d, body=%s",
resp.StatusCode, string(bodyBytes))
*attemptErr = errors.NewAPIErrorWithUpstream(resp.StatusCode,
fmt.Sprintf("UPSTREAM_%d", resp.StatusCode), bodyBytes)
*isSuccess = false
}
resp.Body = io.NopCloser(bytes.NewReader(bodyBytes))
return nil
}
}
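// transparentProxyErrorHandler is the ReverseProxy ErrorHandler: it records the error
// through the pointer stashed in the request context (proxyErrorContextKey) and writes
// a JSON error if nothing has been written to the response yet.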
func (h *ProxyHandler) transparentProxyErrorHandler(rw http.ResponseWriter, r *http.Request, err error) {
corrID := r.Header.Get("X-Correlation-ID")
log := h.logger.WithField("id", corrID)
log.Errorf("Transparent proxy encountered an error: %v", err)
errPtr, ok := r.Context().Value(proxyErrorContextKey{}).(**errors.APIError)
if !ok || errPtr == nil {
log.Error("FATAL: proxyErrorContextKey not found in context for error handler.")
defaultErr := errors.NewAPIError(errors.ErrBadGateway, "An unexpected proxy error occurred")
writeErrorToResponse(rw, defaultErr)
return
}
if *errPtr == nil {
if errors.IsClientNetworkError(err) {
*errPtr = errors.NewAPIError(errors.ErrBadRequest, "Client connection closed")
} else {
*errPtr = errors.NewAPIError(errors.ErrBadGateway, err.Error())
}
}
writeErrorToResponse(rw, *errPtr)
}
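// getResourcesForAttempt returns the initially selected resources on the first attempt
// and acquires fresh resources (with a merged request config) for every retry.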
func (h *ProxyHandler) getResourcesForAttempt(c *gin.Context, attempt int, initialResources *service.RequestResources, modelName, groupName string, isPreciseRouting bool, correlationID string) (*service.RequestResources, error) {
if attempt == 1 {
return initialResources, nil
}
h.logger.WithField("id", correlationID).Infof("Retrying... getting new resources for attempt %d.", attempt)
resources, err := h.getResourcesForRequest(c, modelName, groupName, isPreciseRouting)
if err != nil {
return nil, err
}
finalRequestConfig := h.buildFinalRequestConfig(h.settingsManager.GetSettings(), resources.RequestConfig)
resources.RequestConfig = finalRequestConfig
return resources, nil
}
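// shouldStopRetrying reports whether the retry loop should stop, based on the attempt
// count and the classification of the last error.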
func (h *ProxyHandler) shouldStopRetrying(attempt, totalAttempts int, err *errors.APIError, correlationID string) bool {
if attempt >= totalAttempts {
return true
}
if err == nil {
return false
}
// ✅ Unretryable request error: stop immediately
if errors.IsUnretryableRequestError(err.Message) {
h.logger.WithField("id", correlationID).Warnf("Unretryable request error, aborting: %s", err.Message)
return true
}
// ✅ Permanent upstream error (the key has been invalidated): keep retrying so the next attempt can rotate to a different key
if errors.IsPermanentUpstreamError(err.Message) {
h.logger.WithField("id", correlationID).Warnf("Permanent upstream error, rotating key and retrying: %s", err.Message)
return false
}
// ✅ Retryable network error: keep retrying
if errors.IsRetryableNetworkError(err.Message) {
return false
}
// ✅ Temporary upstream error (quota, etc.): keep retrying
if errors.IsTemporaryUpstreamError(err.Message) {
return false
}
// ✅ Any other unclassified error: keep retrying
return false
}
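// writeFinalResponse replays the last recorded attempt (headers, status code, body) to
// the client, falling back to a JSON error when no attempt was recorded.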
func (h *ProxyHandler) writeFinalResponse(c *gin.Context, corrID string, rec *httptest.ResponseRecorder, apiErr *errors.APIError) {
if rec != nil {
for k, v := range rec.Header() {
c.Writer.Header()[k] = v
}
c.Writer.WriteHeader(rec.Code)
c.Writer.Write(rec.Body.Bytes())
} else if apiErr != nil {
errToJSON(c, corrID, apiErr)
} else {
errToJSON(c, corrID, errors.NewAPIError(errors.ErrInternalServer, "An unknown error occurred"))
}
}
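// publishFinalLogEvent assembles the final request log event (latency, tokens, retries,
// error details) and publishes it to TopicRequestFinished; it is skipped when no
// resources were ever acquired.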
func (h *ProxyHandler) publishFinalLogEvent(c *gin.Context, startTime time.Time, corrID, modelName string, res *service.RequestResources, rec *httptest.ResponseRecorder, finalErr *errors.APIError, isSuccess bool, pTokens, cTokens, retries int, isPrecise bool) {
if res == nil {
h.logger.WithField("id", corrID).Warn("No resources were used, skipping final log event.")
return
}
event := h.createLogEvent(c, startTime, corrID, modelName, res, models.LogTypeFinal, isPrecise)
event.RequestLog.LatencyMs = int(time.Since(startTime).Milliseconds())
event.RequestLog.IsSuccess = isSuccess
event.RequestLog.Retries = retries
if isSuccess {
event.RequestLog.PromptTokens, event.RequestLog.CompletionTokens = pTokens, cTokens
}
if rec != nil {
event.RequestLog.StatusCode = rec.Code
}
if !isSuccess {
errToLog := finalErr
if errToLog == nil && rec != nil {
errToLog = errors.NewAPIErrorWithUpstream(rec.Code, fmt.Sprintf("UPSTREAM_%d", rec.Code), rec.Body.Bytes())
}
if errToLog != nil {
if errToLog.Code == "" && errToLog.HTTPStatus >= 400 {
errToLog.Code = fmt.Sprintf("UPSTREAM_%d", errToLog.HTTPStatus)
}
event.Error = errToLog
event.RequestLog.ErrorCode, event.RequestLog.ErrorMessage = errToLog.Code, errToLog.Message
event.RequestLog.Status = errToLog.Status
}
}
eventData, err := json.Marshal(event)
if err != nil {
h.logger.WithField("id", corrID).WithError(err).Error("Failed to marshal log event")
return
}
if err := h.store.Publish(context.Background(), models.TopicRequestFinished, eventData); err != nil {
h.logger.WithField("id", corrID).WithError(err).Error("Failed to publish log event")
}
}
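// publishRetryLogEvent publishes a retry log event for a failed non-streaming attempt.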
func (h *ProxyHandler) publishRetryLogEvent(c *gin.Context, startTime time.Time, corrID, modelName string, res *service.RequestResources, rec *httptest.ResponseRecorder, attemptErr *errors.APIError, retries int, isPrecise bool) {
retryEvent := h.createLogEvent(c, startTime, corrID, modelName, res, models.LogTypeRetry, isPrecise)
retryEvent.RequestLog.LatencyMs = int(time.Since(startTime).Milliseconds())
retryEvent.RequestLog.IsSuccess = false
retryEvent.RequestLog.StatusCode = rec.Code
retryEvent.RequestLog.Retries = retries
if attemptErr != nil {
retryEvent.Error = attemptErr
retryEvent.RequestLog.ErrorCode, retryEvent.RequestLog.ErrorMessage = attemptErr.Code, attemptErr.Message
retryEvent.RequestLog.Status = attemptErr.Status
}
eventData, err := json.Marshal(retryEvent)
if err != nil {
h.logger.WithField("id", corrID).WithError(err).Error("Failed to marshal retry log event")
return
}
if err := h.store.Publish(context.Background(), models.TopicRequestFinished, eventData); err != nil {
h.logger.WithField("id", corrID).WithError(err).Error("Failed to publish retry log event")
}
}
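// buildFinalRequestConfig starts from the global system settings and overlays the
// group-level RequestConfig by round-tripping it through JSON, so fields present in the
// group config's JSON encoding take precedence.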
func (h *ProxyHandler) buildFinalRequestConfig(globalSettings *models.SystemSettings, groupConfig *models.RequestConfig) *models.RequestConfig {
finalConfig := &models.RequestConfig{
CustomHeaders: make(datatypes.JSONMap),
EnableStreamOptimizer: globalSettings.EnableStreamOptimizer,
StreamMinDelay: globalSettings.StreamMinDelay,
StreamMaxDelay: globalSettings.StreamMaxDelay,
StreamShortTextThresh: globalSettings.StreamShortTextThresh,
StreamLongTextThresh: globalSettings.StreamLongTextThresh,
StreamChunkSize: globalSettings.StreamChunkSize,
EnableFakeStream: globalSettings.EnableFakeStream,
FakeStreamInterval: globalSettings.FakeStreamInterval,
}
for k, v := range globalSettings.CustomHeaders {
finalConfig.CustomHeaders[k] = v
}
if groupConfig == nil {
return finalConfig
}
groupConfigJSON, err := json.Marshal(groupConfig)
if err != nil {
h.logger.WithError(err).Error("Failed to marshal group request config for merging.")
return finalConfig
}
if err := json.Unmarshal(groupConfigJSON, finalConfig); err != nil {
h.logger.WithError(err).Error("Failed to unmarshal group request config for merging.")
}
return finalConfig
}
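// writeErrorToResponse writes an API error as JSON unless the writer reports that a
// response has already been written.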
func writeErrorToResponse(rw http.ResponseWriter, apiErr *errors.APIError) {
if writer, ok := rw.(interface{ Written() bool }); ok && writer.Written() {
return
}
rw.Header().Set("Content-Type", "application/json; charset=utf-8")
rw.WriteHeader(apiErr.HTTPStatus)
json.NewEncoder(rw).Encode(gin.H{"error": apiErr})
}
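// serveSmartStream handles a streaming request through the channel's Smart Gateway: it
// parses the original Gemini request, builds SmartRequestParams from the system
// settings, delegates to channel.ProcessSmartStreamRequest, and publishes the final log
// event when done.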
func (h *ProxyHandler) serveSmartStream(c *gin.Context, requestBody []byte, resources *service.RequestResources, isPreciseRouting bool) {
startTime := time.Now()
correlationID := uuid.New().String()
log := h.logger.WithField("id", correlationID)
log.Info("Smart Gateway activated for streaming request.")
var originalRequest models.GeminiRequest
if err := json.Unmarshal(requestBody, &originalRequest); err != nil {
errToJSON(c, correlationID, errors.NewAPIError(errors.ErrInvalidJSON, "Invalid request format for Smart Gateway"))
return
}
systemSettings := h.settingsManager.GetSettings()
modelName := h.channel.ExtractModel(c, requestBody)
requestFinishedEvent := h.createLogEvent(c, startTime, correlationID, modelName, resources, models.LogTypeFinal, isPreciseRouting)
defer func() {
requestFinishedEvent.LatencyMs = int(time.Since(startTime).Milliseconds())
if c.Writer.Status() > 0 {
requestFinishedEvent.StatusCode = c.Writer.Status()
}
eventData, err := json.Marshal(requestFinishedEvent)
if err != nil {
h.logger.WithField("id", correlationID).WithError(err).Error("Failed to marshal final log event for smart stream")
return
}
if err := h.store.Publish(context.Background(), models.TopicRequestFinished, eventData); err != nil {
h.logger.WithField("id", correlationID).WithError(err).Error("Failed to publish final log event for smart stream")
}
}()
params := channel.SmartRequestParams{
CorrelationID: correlationID,
APIKey: resources.APIKey,
UpstreamURL: resources.UpstreamEndpoint.URL,
RequestBody: requestBody,
OriginalRequest: originalRequest,
EventLogger: requestFinishedEvent,
MaxRetries: systemSettings.MaxStreamingRetries,
RetryDelay: time.Duration(systemSettings.StreamingRetryDelayMs) * time.Millisecond,
LogTruncationLimit: systemSettings.LogTruncationLimit,
StreamingRetryPrompt: systemSettings.StreamingRetryPrompt,
}
h.channel.ProcessSmartStreamRequest(c, params)
}
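// createLogEvent builds a RequestFinishedEvent pre-populated with request metadata,
// auth-token information, and the IDs of the resources that served the attempt.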
func (h *ProxyHandler) createLogEvent(c *gin.Context, startTime time.Time, corrID, modelName string, res *service.RequestResources, logType models.LogType, isPreciseRouting bool) *models.RequestFinishedEvent {
event := &models.RequestFinishedEvent{
RequestLog: models.RequestLog{
RequestTime: startTime,
ModelName: modelName,
RequestPath: c.Request.URL.Path,
UserAgent: c.Request.UserAgent(),
CorrelationID: corrID,
LogType: logType,
Metadata: make(datatypes.JSONMap),
},
CorrelationID: corrID,
IsPreciseRouting: isPreciseRouting,
}
if _, exists := c.Get(middleware.RedactedBodyKey); exists {
event.RequestLog.Metadata["request_body_present"] = true
}
if redactedAuth, exists := c.Get(middleware.RedactedAuthHeaderKey); exists {
event.RequestLog.Metadata["authorization_header"] = redactedAuth.(string)
}
if authTokenValue, exists := c.Get("authToken"); exists {
if authToken, ok := authTokenValue.(*models.AuthToken); ok {
event.RequestLog.AuthTokenID = &authToken.ID
}
}
if res != nil {
if res.APIKey != nil {
event.RequestLog.KeyID = &res.APIKey.ID
}
if res.KeyGroup != nil {
event.RequestLog.GroupID = &res.KeyGroup.ID
}
if res.UpstreamEndpoint != nil {
event.RequestLog.UpstreamID = &res.UpstreamEndpoint.ID
event.UpstreamURL = &res.UpstreamEndpoint.URL
}
if res.ProxyConfig != nil {
event.RequestLog.ProxyID = &res.ProxyConfig.ID
}
}
return event
}
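// getResourcesForRequest resolves resources from the named group under precise routing,
// or from the base pool by model name otherwise, using the auth token stored in the
// Gin context.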
func (h *ProxyHandler) getResourcesForRequest(c *gin.Context, modelName string, groupName string, isPreciseRouting bool) (*service.RequestResources, error) {
authTokenValue, exists := c.Get("authToken")
if !exists {
return nil, errors.NewAPIError(errors.ErrUnauthorized, "Auth token not found in context")
}
authToken, ok := authTokenValue.(*models.AuthToken)
if !ok {
return nil, errors.NewAPIError(errors.ErrInternalServer, "Invalid auth token type in context")
}
if isPreciseRouting {
return h.resourceService.GetResourceFromGroup(c.Request.Context(), authToken, groupName)
}
return h.resourceService.GetResourceFromBasePool(c.Request.Context(), authToken, modelName)
}
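// errToJSON writes the API error and correlation ID as a JSON response unless the
// request has already been aborted.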
func errToJSON(c *gin.Context, corrID string, apiErr *errors.APIError) {
if c.IsAborted() {
return
}
c.JSON(apiErr.HTTPStatus, gin.H{
"error": apiErr,
"correlation_id": corrID,
})
}
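// bufferPool is a minimal httputil.BufferPool implementation: Get allocates a fresh
// 32 KiB buffer and Put discards it, so no buffers are actually reused.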
type bufferPool struct{}
func (b *bufferPool) Get() []byte { return make([]byte, 32*1024) }
func (b *bufferPool) Put(_ []byte) {}
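// extractUsage parses Gemini-style usage metadata from a response body,
// e.g. {"usageMetadata":{"promptTokenCount":12,"candidatesTokenCount":34}},
// and returns zero counts if the body does not match that shape.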
func extractUsage(body []byte) (promptTokens int, completionTokens int) {
var data struct {
UsageMetadata struct {
PromptTokenCount int `json:"promptTokenCount"`
CandidatesTokenCount int `json:"candidatesTokenCount"`
} `json:"usageMetadata"`
}
if err := json.Unmarshal(body, &data); err == nil {
return data.UsageMetadata.PromptTokenCount, data.UsageMetadata.CandidatesTokenCount
}
return 0, 0
}
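// getMaxRetries prefers the group-level MaxRetries under precise routing and falls back
// to the global setting otherwise.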
func (h *ProxyHandler) getMaxRetries(isPreciseRouting bool, finalOpConfig *models.KeyGroupSettings) int {
if isPreciseRouting && finalOpConfig.MaxRetries != nil {
return *finalOpConfig.MaxRetries
}
return h.settingsManager.GetSettings().MaxRetries
}
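// handleListModelsRequest answers model-list requests locally from the models allowed
// for the auth token, using the Gemini format for /v1beta/ paths and the OpenAI format
// otherwise.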
func (h *ProxyHandler) handleListModelsRequest(c *gin.Context) {
authTokenValue, exists := c.Get("authToken")
if !exists {
errToJSON(c, uuid.New().String(), errors.NewAPIError(errors.ErrUnauthorized, "Auth token not found in context"))
return
}
authToken, ok := authTokenValue.(*models.AuthToken)
if !ok {
errToJSON(c, uuid.New().String(), errors.NewAPIError(errors.ErrInternalServer, "Invalid auth token type in context"))
return
}
modelNames := h.resourceService.GetAllowedModelsForToken(authToken)
if strings.Contains(c.Request.URL.Path, "/v1beta/") {
h.respondWithGeminiFormat(c, modelNames)
} else {
h.respondWithOpenAIFormat(c, modelNames)
}
}
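// respondWithOpenAIFormat returns the allowed models as an OpenAI-style model list.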
func (h *ProxyHandler) respondWithOpenAIFormat(c *gin.Context, modelNames []string) {
type ModelEntry struct {
ID string `json:"id"`
Object string `json:"object"`
Created int64 `json:"created"`
OwnedBy string `json:"owned_by"`
}
type ModelListResponse struct {
Object string `json:"object"`
Data []ModelEntry `json:"data"`
}
data := make([]ModelEntry, len(modelNames))
for i, name := range modelNames {
data[i] = ModelEntry{
ID: name,
Object: "model",
Created: time.Now().Unix(),
OwnedBy: "gemini-balancer",
}
}
response := ModelListResponse{
Object: "list",
Data: data,
}
c.JSON(http.StatusOK, response)
}
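// respondWithGeminiFormat returns the allowed models as a Gemini-style model list with
// fixed placeholder token limits.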
func (h *ProxyHandler) respondWithGeminiFormat(c *gin.Context, modelNames []string) {
type GeminiModelEntry struct {
Name string `json:"name"`
Version string `json:"version"`
DisplayName string `json:"displayName"`
Description string `json:"description"`
SupportedGenerationMethods []string `json:"supportedGenerationMethods"`
InputTokenLimit int `json:"inputTokenLimit"`
OutputTokenLimit int `json:"outputTokenLimit"`
}
type GeminiModelListResponse struct {
Models []GeminiModelEntry `json:"models"`
}
models := make([]GeminiModelEntry, len(modelNames))
for i, name := range modelNames {
models[i] = GeminiModelEntry{
Name: fmt.Sprintf("models/%s", name),
Version: "1.0.0",
DisplayName: name,
Description: "Served by Gemini Balancer",
SupportedGenerationMethods: []string{"generateContent", "streamGenerateContent"},
InputTokenLimit: 8192,
OutputTokenLimit: 2048,
}
}
response := GeminiModelListResponse{Models: models}
c.JSON(http.StatusOK, response)
}