1062 lines
40 KiB
Go
1062 lines
40 KiB
Go
// Filename: internal/handlers/proxy_handler.go
|
||
package handlers
|
||
|
||
import (
|
||
"bytes"
|
||
"compress/gzip"
|
||
"context"
|
||
"encoding/json"
|
||
"fmt"
|
||
"gemini-balancer/internal/channel"
|
||
"gemini-balancer/internal/errors"
|
||
"gemini-balancer/internal/middleware"
|
||
"gemini-balancer/internal/models"
|
||
"gemini-balancer/internal/service"
|
||
"gemini-balancer/internal/settings"
|
||
"gemini-balancer/internal/store"
|
||
"io"
|
||
"net"
|
||
"net/http"
|
||
"net/http/httptest"
|
||
"net/http/httputil"
|
||
"net/url"
|
||
"strings"
|
||
"time"
|
||
|
||
"github.com/gin-gonic/gin"
|
||
"github.com/google/uuid"
|
||
"github.com/sirupsen/logrus"
|
||
"gorm.io/datatypes"
|
||
)
|
||
|
||
// proxyErrorContextKey is the context key under which a **errors.APIError is
// stashed on the outbound request, so the shared ReverseProxy ErrorHandler can
// report a classified failure back to the attempt that initiated the request.
type proxyErrorContextKey struct{}
|
||
|
||
// ProxyHandler fronts all upstream proxying: it resolves per-request resources
// (API key, key group, endpoint, optional egress proxy), forwards requests
// through a reverse proxy with retry support, and publishes request log events.
type ProxyHandler struct {
	// resourceService selects keys/endpoints per attempt and receives
	// success/failure reports after each attempt.
	resourceService *service.ResourceService
	// store is the pub/sub sink for request log events.
	store store.Store
	settingsManager *settings.SettingsManager
	groupManager *service.GroupManager
	// channel performs protocol-specific request inspection and rewriting.
	channel channel.ChannelProxy
	logger *logrus.Entry
	// transparentProxy is shared across requests. NOTE(review): configureProxy
	// mutates its Director/Transport/ModifyResponse per attempt, which races
	// under concurrent requests — see the per-request proxy pattern in
	// executeStreamAttempt for the safer alternative.
	transparentProxy *httputil.ReverseProxy
}
|
||
|
||
func NewProxyHandler(
|
||
resourceService *service.ResourceService,
|
||
store store.Store,
|
||
sm *settings.SettingsManager,
|
||
gm *service.GroupManager,
|
||
channel channel.ChannelProxy,
|
||
logger *logrus.Logger,
|
||
) *ProxyHandler {
|
||
ph := &ProxyHandler{
|
||
resourceService: resourceService,
|
||
store: store,
|
||
settingsManager: sm,
|
||
groupManager: gm,
|
||
channel: channel,
|
||
logger: logger.WithField("component", "ProxyHandler"),
|
||
transparentProxy: &httputil.ReverseProxy{},
|
||
}
|
||
ph.transparentProxy.Transport = &http.Transport{
|
||
Proxy: http.ProxyFromEnvironment,
|
||
DialContext: (&net.Dialer{
|
||
Timeout: 30 * time.Second,
|
||
KeepAlive: 60 * time.Second,
|
||
}).DialContext,
|
||
MaxIdleConns: 100,
|
||
IdleConnTimeout: 90 * time.Second,
|
||
TLSHandshakeTimeout: 10 * time.Second,
|
||
ExpectContinueTimeout: 1 * time.Second,
|
||
}
|
||
ph.transparentProxy.ErrorHandler = ph.transparentProxyErrorHandler
|
||
ph.transparentProxy.BufferPool = &bufferPool{}
|
||
return ph
|
||
}
|
||
|
||
func (h *ProxyHandler) HandleProxy(c *gin.Context) {
|
||
if c.Request.Method == "GET" && (strings.HasSuffix(c.Request.URL.Path, "/models") || strings.HasSuffix(c.Request.URL.Path, "/models/")) {
|
||
h.handleListModelsRequest(c)
|
||
return
|
||
}
|
||
|
||
maxBodySize := int64(h.settingsManager.GetSettings().MaxRequestBodySizeMB * 1024 * 1024)
|
||
requestBody, err := io.ReadAll(io.LimitReader(c.Request.Body, maxBodySize))
|
||
if err != nil {
|
||
errToJSON(c, uuid.New().String(), errors.NewAPIError(errors.ErrBadRequest, "Request body too large or failed to read"))
|
||
return
|
||
}
|
||
|
||
modelName := h.channel.ExtractModel(c, requestBody)
|
||
groupName := c.Param("group_name")
|
||
isPreciseRouting := groupName != ""
|
||
|
||
if !isPreciseRouting && modelName == "" {
|
||
errToJSON(c, uuid.New().String(), errors.NewAPIError(errors.ErrBadRequest, "Model not specified in request"))
|
||
return
|
||
}
|
||
|
||
initialResources, err := h.getResourcesForRequest(c, modelName, groupName, isPreciseRouting)
|
||
if err != nil {
|
||
if apiErr, ok := err.(*errors.APIError); ok {
|
||
errToJSON(c, uuid.New().String(), apiErr)
|
||
} else {
|
||
errToJSON(c, uuid.New().String(), errors.NewAPIError(errors.ErrInternalServer, "Failed to get initial resources"))
|
||
}
|
||
return
|
||
}
|
||
|
||
finalOpConfig, err := h.groupManager.BuildOperationalConfig(initialResources.KeyGroup)
|
||
if err != nil {
|
||
h.logger.WithError(err).Error("Failed to build operational config.")
|
||
errToJSON(c, uuid.New().String(), errors.NewAPIError(errors.ErrInternalServer, "Failed to build operational config"))
|
||
return
|
||
}
|
||
|
||
initialResources.RequestConfig = h.buildFinalRequestConfig(h.settingsManager.GetSettings(), initialResources.RequestConfig)
|
||
|
||
isOpenAICompatible := h.channel.IsOpenAICompatibleRequest(c)
|
||
if isOpenAICompatible {
|
||
h.serveTransparentProxy(c, requestBody, initialResources, finalOpConfig, modelName, groupName, isPreciseRouting)
|
||
return
|
||
}
|
||
|
||
isStream := h.channel.IsStreamRequest(c, requestBody)
|
||
useSmartGateway := finalOpConfig.EnableSmartGateway != nil && *finalOpConfig.EnableSmartGateway
|
||
if useSmartGateway && isStream && h.settingsManager.GetSettings().EnableStreamingRetry {
|
||
h.serveSmartStream(c, requestBody, initialResources, isPreciseRouting)
|
||
} else {
|
||
h.serveTransparentProxy(c, requestBody, initialResources, finalOpConfig, modelName, groupName, isPreciseRouting)
|
||
}
|
||
}
|
||
|
||
// serveTransparentProxy forwards the request through the shared reverse proxy
// with a retry loop. Stream requests are delegated to serveStreamWithRetry;
// non-stream attempts are buffered in an httptest.ResponseRecorder so a failed
// attempt can be retried with a fresh key before anything reaches the client.
// A final log event is always published via the deferred call, and the last
// recorded response (or a JSON error) is written after the loop exits.
func (h *ProxyHandler) serveTransparentProxy(c *gin.Context, requestBody []byte, initialResources *service.RequestResources, finalOpConfig *models.KeyGroupSettings, modelName, groupName string, isPreciseRouting bool) {
	startTime := time.Now()
	correlationID := uuid.New().String()
	// Stream requests take a separate path that retries without buffering
	// the whole response in memory.
	isStreamRequest := h.channel.IsStreamRequest(c, requestBody)

	if isStreamRequest {
		h.serveStreamWithRetry(c, requestBody, initialResources, finalOpConfig, modelName, groupName, isPreciseRouting, correlationID, startTime)
		return
	}
	var finalRecorder *httptest.ResponseRecorder
	var lastUsedResources *service.RequestResources
	var finalProxyErr *errors.APIError
	var isSuccess bool
	var finalPromptTokens, finalCompletionTokens, actualRetries int

	// Publish the definitive outcome regardless of how the loop exits.
	defer func() {
		h.publishFinalLogEvent(c, startTime, correlationID, modelName, lastUsedResources,
			finalRecorder, finalProxyErr, isSuccess, finalPromptTokens, finalCompletionTokens,
			actualRetries, isPreciseRouting)
	}()

	maxRetries := h.getMaxRetries(isPreciseRouting, finalOpConfig)
	totalAttempts := maxRetries + 1

	for attempt := 1; attempt <= totalAttempts; attempt++ {
		// Stop early if the client has already gone away.
		if c.Request.Context().Err() != nil {
			h.logger.WithField("id", correlationID).Info("Client disconnected, aborting retry loop.")
			if finalProxyErr == nil {
				finalProxyErr = errors.NewAPIError(errors.ErrBadRequest, "Client disconnected")
			}
			break
		}

		// Attempt 1 uses the pre-resolved resources; retries select fresh ones.
		resources, err := h.getResourcesForAttempt(c, attempt, initialResources, modelName, groupName, isPreciseRouting, correlationID)
		if err != nil {
			h.logger.WithField("id", correlationID).Errorf("❌ getResourcesForAttempt failed: %v", err)
			if apiErr, ok := err.(*errors.APIError); ok {
				finalProxyErr = apiErr
			} else {
				finalProxyErr = errors.NewAPIError(errors.ErrInternalServer, "Failed to get resources for retry")
			}
			break
		}

		h.logger.WithField("id", correlationID).Infof("✅ Got resources: KeyID=%d", resources.APIKey.ID)
		if attempt > 1 {
			actualRetries = attempt - 1
		}

		h.logger.WithField("id", correlationID).Infof("Attempt %d/%d with KeyID %d", attempt, totalAttempts, resources.APIKey.ID)

		recorder, attemptErr, attemptSuccess := h.executeProxyAttempt(
			c, correlationID, requestBody, resources, isPreciseRouting, groupName,
			&finalPromptTokens, &finalCompletionTokens,
		)
		h.logger.WithField("id", correlationID).Infof("✅ Before assignment: lastUsedResources=%v", lastUsedResources)

		finalRecorder, finalProxyErr, isSuccess = recorder, attemptErr, attemptSuccess
		// Treat any recorded 4xx/5xx as a failure even when the attempt did
		// not surface an explicit error.
		if finalProxyErr != nil || (finalRecorder != nil && finalRecorder.Code >= 400) {
			isSuccess = false
		}
		lastUsedResources = resources
		h.logger.WithField("id", correlationID).Infof("✅ After assignment: lastUsedResources=%v", lastUsedResources)
		// Feed the outcome back so key health/quota tracking stays accurate.
		h.resourceService.ReportRequestResult(resources, isSuccess, finalProxyErr)

		if isSuccess {
			break
		}
		if h.shouldStopRetrying(attempt, totalAttempts, finalProxyErr, correlationID) {
			break
		}
		// Publish an intermediate retry log event for this failed attempt.
		h.publishRetryLogEvent(c, startTime, correlationID, modelName, resources, recorder, attemptErr, actualRetries, isPreciseRouting)
	}

	h.writeFinalResponse(c, correlationID, finalRecorder, finalProxyErr)
}
|
||
|
||
// serveStreamWithRetry proxies a streaming request with retry support. Each
// attempt streams directly to the client, so retries are only safe until the
// first byte has been written; once the response has started, failures can no
// longer be converted into a JSON error (see the Written() check at the end).
func (h *ProxyHandler) serveStreamWithRetry(c *gin.Context, requestBody []byte, initialResources *service.RequestResources, finalOpConfig *models.KeyGroupSettings, modelName, groupName string, isPreciseRouting bool, correlationID string, startTime time.Time) {
	// NOTE(review): HandleProxy already applied buildFinalRequestConfig to
	// initialResources before reaching this point; re-applying here looks
	// redundant — confirm the merge is idempotent.
	initialResources.RequestConfig = h.buildFinalRequestConfig(
		h.settingsManager.GetSettings(),
		initialResources.RequestConfig,
	)
	h.logger.WithField("id", correlationID).Info("🌊 Serving stream request with retry support")
	var lastUsedResources *service.RequestResources
	var finalProxyErr *errors.APIError
	var isSuccess bool
	var actualRetries int
	// Always publish a final log event. Token counts are not tracked for
	// streams, hence the zero prompt/completion values and nil recorder.
	defer func() {
		h.publishFinalLogEvent(c, startTime, correlationID, modelName, lastUsedResources,
			nil, finalProxyErr, isSuccess, 0, 0, actualRetries, isPreciseRouting)
	}()
	maxRetries := h.getMaxRetries(isPreciseRouting, finalOpConfig)
	totalAttempts := maxRetries + 1
	for attempt := 1; attempt <= totalAttempts; attempt++ {
		// Abort early if the client has disconnected.
		if c.Request.Context().Err() != nil {
			h.logger.WithField("id", correlationID).Info("Client disconnected, aborting retry loop.")
			if finalProxyErr == nil {
				finalProxyErr = errors.NewAPIError(errors.ErrBadRequest, "Client disconnected")
			}
			break
		}
		// Attempt 1 uses the pre-resolved resources; retries select fresh ones.
		resources, err := h.getResourcesForAttempt(c, attempt, initialResources, modelName, groupName, isPreciseRouting, correlationID)
		if err != nil {
			h.logger.WithField("id", correlationID).Errorf("❌ Failed to get resources: %v", err)
			if apiErr, ok := err.(*errors.APIError); ok {
				finalProxyErr = apiErr
			} else {
				finalProxyErr = errors.NewAPIError(errors.ErrInternalServer, "Failed to get resources for retry")
			}
			break
		}
		if attempt > 1 {
			actualRetries = attempt - 1
		}
		h.logger.WithField("id", correlationID).Infof("🔄 Stream attempt %d/%d (KeyID=%d, GroupID=%d)",
			attempt, totalAttempts, resources.APIKey.ID, resources.KeyGroup.ID)
		// Run one streaming proxy attempt.
		attemptErr, attemptSuccess := h.executeStreamAttempt(
			c, correlationID, requestBody, resources, groupName, isPreciseRouting,
		)
		finalProxyErr, isSuccess = attemptErr, attemptSuccess
		lastUsedResources = resources
		// Feed the outcome back so key health/quota tracking stays accurate.
		h.resourceService.ReportRequestResult(resources, isSuccess, finalProxyErr)
		if isSuccess {
			h.logger.WithField("id", correlationID).Info("✅ Stream request succeeded")
			break
		}
		if h.shouldStopRetrying(attempt, totalAttempts, finalProxyErr, correlationID) {
			// Log the stop reason defensively (finalProxyErr may be nil).
			if finalProxyErr != nil {
				h.logger.WithField("id", correlationID).Warnf("⛔ Stopping retry: %s", finalProxyErr.Message)
			} else {
				h.logger.WithField("id", correlationID).Warn("⛔ Stopping retry: unknown error")
			}
			break
		}
		// Publish an intermediate retry log event for this failed attempt.
		h.publishStreamRetryLogEvent(c, startTime, correlationID, modelName, resources, attemptErr, actualRetries, isPreciseRouting)
		if attempt < totalAttempts {
			h.logger.WithField("id", correlationID).Infof("🔁 Retrying... (%d/%d)", attempt, totalAttempts-1)
		}
	}
	// All attempts failed: emit a JSON error only if nothing has been written
	// to the client yet.
	if !isSuccess && finalProxyErr != nil {
		h.logger.WithField("id", correlationID).Warnf("❌ All stream attempts failed: %s (code=%s)",
			finalProxyErr.Message, finalProxyErr.Code)
		if !c.Writer.Written() {
			errToJSON(c, correlationID, finalProxyErr)
		} else {
			h.logger.WithField("id", correlationID).Warn("⚠️ Cannot write error, response already started")
		}
	}
}
|
||
|
||
// executeStreamAttempt performs a single streaming proxy attempt directly to
// the client's ResponseWriter using a per-request ReverseProxy (so the shared
// proxy's hooks are never touched). Named returns are REQUIRED here: the
// deferred recover() rewrites them when ReverseProxy panics with
// http.ErrAbortHandler, which this code treats as a successfully completed
// stream. The closures below mutate attemptErr/isSuccess captured by
// reference; their values are only read after ServeHTTP returns.
func (h *ProxyHandler) executeStreamAttempt(
	c *gin.Context,
	correlationID string,
	requestBody []byte,
	resources *service.RequestResources,
	groupName string,
	isPreciseRouting bool,
) (finalErr *errors.APIError, finalSuccess bool) {

	// ReverseProxy panics with http.ErrAbortHandler when the copy to the
	// client aborts; for an already-flowing stream this is treated as success.
	defer func() {
		if r := recover(); r != nil {
			if r == http.ErrAbortHandler {
				h.logger.WithField("id", correlationID).Debug("✅ Stream completed (ErrAbortHandler caught)")
				// Overwrite the named returns so the caller sees success.
				finalErr = nil
				finalSuccess = true
				return
			}
			// Any other panic is unexpected: re-raise it.
			h.logger.WithField("id", correlationID).Errorf("❌ Unexpected panic in stream: %v", r)
			panic(r)
		}
	}()
	var attemptErr *errors.APIError
	var isSuccess bool
	requestTimeout := time.Duration(h.settingsManager.GetSettings().RequestTimeoutSeconds) * time.Second
	ctx, cancel := context.WithTimeout(c.Request.Context(), requestTimeout)
	defer cancel()
	// Clone the request with a replayable body so retries can reuse it.
	attemptReq := c.Request.Clone(ctx)
	attemptReq.Body = io.NopCloser(bytes.NewReader(requestBody))
	attemptReq.ContentLength = int64(len(requestBody))
	// Build a per-request proxy sharing the transport/buffer pool; this avoids
	// mutating the shared h.transparentProxy (unlike configureProxy).
	streamProxy := &httputil.ReverseProxy{
		Transport:  h.transparentProxy.Transport,
		BufferPool: h.transparentProxy.BufferPool,
	}
	streamProxy.Director = func(r *http.Request) {
		// NOTE(review): the url.Parse error is ignored; a malformed endpoint
		// URL would nil-deref here — confirm endpoints are validated upstream.
		targetURL, _ := url.Parse(resources.UpstreamEndpoint.URL)
		r.URL.Scheme, r.URL.Host, r.Host = targetURL.Scheme, targetURL.Host, targetURL.Host
		// Strip the precise-routing prefix so upstream sees the bare path.
		var pureClientPath string
		if isPreciseRouting {
			pureClientPath = strings.TrimPrefix(r.URL.Path, "/proxy/"+groupName)
		} else {
			pureClientPath = r.URL.Path
		}
		r.URL.Path = h.channel.RewritePath(targetURL.Path, pureClientPath)
		// Drop the client's credentials; the channel injects the upstream key.
		r.Header.Del("Authorization")
		h.channel.ModifyRequest(r, resources.APIKey)
		r.Header.Set("X-Correlation-ID", correlationID)

		// Apply configured custom headers (only string values are honored).
		if resources.RequestConfig != nil {
			for k, v := range resources.RequestConfig.CustomHeaders {
				if strVal, ok := v.(string); ok {
					r.Header.Set(k, strVal)
				}
			}
		}
	}
	// NOTE(review): unchecked type assertion — panics if the shared transport
	// is ever something other than *http.Transport.
	transport := streamProxy.Transport.(*http.Transport)
	if resources.ProxyConfig != nil {
		proxyURLStr := fmt.Sprintf("%s://%s", resources.ProxyConfig.Protocol, resources.ProxyConfig.Address)
		if proxyURL, err := url.Parse(proxyURLStr); err == nil {
			// Clone the transport before overriding Proxy so the shared
			// transport is left untouched.
			transportCopy := transport.Clone()
			transportCopy.Proxy = http.ProxyURL(proxyURL)
			streamProxy.Transport = transportCopy
			h.logger.WithField("id", correlationID).Infof("🔀 Using proxy: %s", proxyURLStr)
		}
	}
	streamProxy.ModifyResponse = func(resp *http.Response) error {
		h.logger.WithField("id", correlationID).Infof("📨 Stream response: status=%d, contentType=%s",
			resp.StatusCode, resp.Header.Get("Content-Type"))
		// Transparently decompress gzip so the client gets plain bytes.
		if resp.Header.Get("Content-Encoding") == "gzip" {
			gzReader, err := gzip.NewReader(resp.Body)
			if err != nil {
				h.logger.WithField("id", correlationID).WithError(err).Error("Failed to create gzip reader")
				attemptErr = errors.NewAPIError(errors.ErrBadGateway, "Failed to decompress response")
				isSuccess = false
				return fmt.Errorf("gzip decompression failed: %w", err)
			}
			resp.Body = gzReader
			resp.Header.Del("Content-Encoding")
		}
		// Success: pass the body straight through to the client.
		if resp.StatusCode < 400 {
			isSuccess = true
			h.logger.WithField("id", correlationID).Info("✅ Stream response marked as success")
			return nil
		}
		// Error response: consume the body so the retry logic can classify it.
		isSuccess = false
		bodyBytes, err := io.ReadAll(resp.Body)
		if err != nil {
			h.logger.WithField("id", correlationID).WithError(err).Error("Failed to read error response")
			attemptErr = errors.NewAPIError(errors.ErrBadGateway, "Failed to read upstream error")
		} else {
			// Common errors (429/403/...) get a one-line summary; anything
			// unusual logs the full body.
			shouldLogErrorBody := h.shouldLogErrorBody(resp.StatusCode)
			if shouldLogErrorBody {
				h.logger.WithField("id", correlationID).Errorf("❌ Stream error: status=%d, body=%s",
					resp.StatusCode, string(bodyBytes))
			} else {
				errorSummary := h.extractErrorSummary(bodyBytes)
				h.logger.WithField("id", correlationID).Warnf("⚠️ Stream error: status=%d, summary=%s",
					resp.StatusCode, errorSummary)
			}
			attemptErr = errors.NewAPIErrorWithUpstream(resp.StatusCode,
				fmt.Sprintf("UPSTREAM_%d", resp.StatusCode), bodyBytes)
		}
		// Returning an error routes control to ErrorHandler WITHOUT writing a
		// response, leaving the retry decision to the caller.
		return fmt.Errorf("upstream error: status %d", resp.StatusCode)
	}
	streamProxy.ErrorHandler = func(w http.ResponseWriter, r *http.Request, err error) {
		h.logger.WithField("id", correlationID).Debugf("Stream proxy error handler triggered: %v", err)
		// Only classify if ModifyResponse has not already set attemptErr.
		if attemptErr == nil {
			isSuccess = false
			if err == context.DeadlineExceeded {
				attemptErr = errors.NewAPIError(errors.ErrGatewayTimeout, "Request timeout")
			} else if err == context.Canceled {
				attemptErr = errors.NewAPIError(errors.ErrBadRequest, "Request canceled")
			} else if errors.IsClientNetworkError(err) {
				attemptErr = errors.NewAPIError(errors.ErrBadRequest, "Client connection closed")
			} else {
				attemptErr = errors.NewAPIError(errors.ErrBadGateway, err.Error())
			}
		}
		// Deliberately do not write a response here: the outer retry loop
		// decides whether to retry or emit the error.
	}
	// May panic with http.ErrAbortHandler — handled by the deferred recover.
	streamProxy.ServeHTTP(c.Writer, attemptReq)
	// Normal (non-panic) return path.
	return attemptErr, isSuccess
}
|
||
|
||
// ✅ 新增:判断是否应该记录详细错误体
|
||
func (h *ProxyHandler) shouldLogErrorBody(statusCode int) bool {
|
||
// ✅ 对于常见的客户端错误和限流错误,不记录详细错误体
|
||
commonErrors := map[int]bool{
|
||
400: true, // Bad Request
|
||
401: true, // Unauthorized
|
||
403: true, // Forbidden
|
||
404: true, // Not Found
|
||
429: true, // Too Many Requests
|
||
}
|
||
|
||
return !commonErrors[statusCode]
|
||
}
|
||
|
||
// ✅ 新增:从错误响应中提取简要信息
|
||
func (h *ProxyHandler) extractErrorSummary(bodyBytes []byte) string {
|
||
// ✅ 尝试解析 JSON 错误响应
|
||
var errorResp struct {
|
||
Error struct {
|
||
Message string `json:"message"`
|
||
Code int `json:"code"`
|
||
Status string `json:"status"`
|
||
} `json:"error"`
|
||
}
|
||
if err := json.Unmarshal(bodyBytes, &errorResp); err == nil && errorResp.Error.Message != "" {
|
||
// ✅ 截取错误消息的前100个字符
|
||
message := errorResp.Error.Message
|
||
if len(message) > 100 {
|
||
message = message[:100] + "..."
|
||
}
|
||
if errorResp.Error.Status != "" {
|
||
return fmt.Sprintf("%s: %s", errorResp.Error.Status, message)
|
||
}
|
||
return message
|
||
}
|
||
// ✅ 如果无法解析 JSON,返回前100个字符
|
||
if len(bodyBytes) > 100 {
|
||
return string(bodyBytes[:100]) + "..."
|
||
}
|
||
return string(bodyBytes)
|
||
}
|
||
|
||
// ✅ 新增:发布流式重试日志事件
|
||
func (h *ProxyHandler) publishStreamRetryLogEvent(c *gin.Context, startTime time.Time, corrID, modelName string, res *service.RequestResources, attemptErr *errors.APIError, retries int, isPrecise bool) {
|
||
retryEvent := h.createLogEvent(c, startTime, corrID, modelName, res, models.LogTypeRetry, isPrecise)
|
||
retryEvent.RequestLog.LatencyMs = int(time.Since(startTime).Milliseconds())
|
||
retryEvent.RequestLog.IsSuccess = false
|
||
retryEvent.RequestLog.Retries = retries
|
||
|
||
if attemptErr != nil {
|
||
retryEvent.Error = attemptErr
|
||
retryEvent.RequestLog.ErrorCode = attemptErr.Code
|
||
retryEvent.RequestLog.ErrorMessage = attemptErr.Message
|
||
retryEvent.RequestLog.Status = attemptErr.Status
|
||
retryEvent.RequestLog.StatusCode = attemptErr.HTTPStatus
|
||
}
|
||
|
||
eventData, err := json.Marshal(retryEvent)
|
||
if err != nil {
|
||
h.logger.WithField("id", corrID).WithError(err).Error("Failed to marshal stream retry log event")
|
||
return
|
||
}
|
||
|
||
if err := h.store.Publish(context.Background(), models.TopicRequestFinished, eventData); err != nil {
|
||
h.logger.WithField("id", corrID).WithError(err).Error("Failed to publish stream retry log event")
|
||
}
|
||
}
|
||
|
||
// executeProxyAttempt performs one buffered (non-stream) proxy attempt and
// returns the recorded response, any classified error, and a success flag.
// The response is captured in an httptest.ResponseRecorder so the caller can
// retry without having written anything to the client.
func (h *ProxyHandler) executeProxyAttempt(c *gin.Context, corrID string, body []byte, res *service.RequestResources, isPrecise bool, groupName string, pTokens, cTokens *int) (*httptest.ResponseRecorder, *errors.APIError, bool) {
	recorder := httptest.NewRecorder()
	var attemptErr *errors.APIError
	isSuccess := false

	requestTimeout := time.Duration(h.settingsManager.GetSettings().RequestTimeoutSeconds) * time.Second
	ctx, cancel := context.WithTimeout(c.Request.Context(), requestTimeout)
	defer cancel()

	// Clone the request with a replayable body so later retries can reuse it.
	attemptReq := c.Request.Clone(ctx)
	attemptReq.Body = io.NopCloser(bytes.NewReader(body))
	attemptReq.ContentLength = int64(len(body))

	h.logger.WithField("id", corrID).Infof("🚀 Starting proxy attempt with KeyID=%d", res.APIKey.ID)

	// NOTE(review): configureProxy mutates the SHARED h.transparentProxy
	// (Director/Transport/ModifyResponse); concurrent requests would race on
	// those fields — see the per-request proxy in executeStreamAttempt for
	// the safer pattern.
	h.configureProxy(corrID, res, isPrecise, groupName, &attemptErr, &isSuccess, pTokens, cTokens)
	// Stash a pointer to attemptErr in the request context so the shared
	// ErrorHandler (transparentProxyErrorHandler) can report back to this
	// specific attempt.
	*attemptReq = *attemptReq.WithContext(context.WithValue(attemptReq.Context(), proxyErrorContextKey{}, &attemptErr))

	h.transparentProxy.ServeHTTP(recorder, attemptReq)

	h.logger.WithField("id", corrID).Infof("📥 Proxy returned: status=%d, bodyLen=%d, err=%v, success=%v",
		recorder.Code, recorder.Body.Len(), attemptErr, isSuccess)

	// Defensive fix-up for a recorder that was never written to.
	// NOTE(review): httptest.NewRecorder initializes Code to 200, so this
	// branch may be unreachable — confirm before relying on it.
	if recorder.Code == 0 && attemptErr != nil {
		h.logger.WithField("id", corrID).Warnf("⚠️ Fixing zero status code to %d", attemptErr.HTTPStatus)
		recorder.Code = attemptErr.HTTPStatus
		if recorder.Body.Len() == 0 {
			errJSON, _ := json.Marshal(gin.H{"error": attemptErr})
			recorder.Body.Write(errJSON)
		}
	}

	return recorder, attemptErr, isSuccess
}
|
||
|
||
// configureProxy installs per-attempt Director/Transport/ModifyResponse hooks
// on the shared transparent proxy for the given resources.
//
// NOTE(review): this mutates h.transparentProxy, which is shared by every
// in-flight request — concurrent requests race on Director, the transport's
// Proxy field, and ModifyResponse. Consider building a per-request
// ReverseProxy as executeStreamAttempt does.
func (h *ProxyHandler) configureProxy(corrID string, res *service.RequestResources, isPrecise bool, groupName string, attemptErr **errors.APIError, isSuccess *bool, pTokens, cTokens *int) {
	h.transparentProxy.Director = func(r *http.Request) {
		// NOTE(review): the url.Parse error is ignored; a malformed endpoint
		// URL would nil-deref below — confirm endpoints are validated upstream.
		targetURL, _ := url.Parse(res.UpstreamEndpoint.URL)
		r.URL.Scheme, r.URL.Host, r.Host = targetURL.Scheme, targetURL.Host, targetURL.Host

		// Strip the precise-routing prefix so upstream sees the bare path.
		var pureClientPath string
		if isPrecise {
			pureClientPath = strings.TrimPrefix(r.URL.Path, "/proxy/"+groupName)
		} else {
			pureClientPath = r.URL.Path
		}
		r.URL.Path = h.channel.RewritePath(targetURL.Path, pureClientPath)

		// Drop the client's credentials; the channel injects the upstream key.
		r.Header.Del("Authorization")
		h.channel.ModifyRequest(r, res.APIKey)
		r.Header.Set("X-Correlation-ID", corrID)
	}

	// NOTE(review): unchecked type assertion — panics if the shared transport
	// is ever something other than *http.Transport. Unlike executeStreamAttempt
	// this does NOT clone the transport before overriding Proxy.
	transport := h.transparentProxy.Transport.(*http.Transport)
	if res.ProxyConfig != nil {
		proxyURLStr := fmt.Sprintf("%s://%s", res.ProxyConfig.Protocol, res.ProxyConfig.Address)
		if proxyURL, err := url.Parse(proxyURLStr); err == nil {
			transport.Proxy = http.ProxyURL(proxyURL)
		} else {
			// Malformed proxy config: fall back to environment settings.
			transport.Proxy = http.ProxyFromEnvironment
		}
	} else {
		transport.Proxy = http.ProxyFromEnvironment
	}

	h.transparentProxy.ModifyResponse = h.createModifyResponseFunc(attemptErr, isSuccess, pTokens, cTokens)
}
|
||
|
||
// createModifyResponseFunc builds the ModifyResponse hook for a buffered proxy
// attempt. The returned closure inspects the upstream response, transparently
// decompresses gzip, distinguishes stream (text/event-stream) from non-stream
// bodies, extracts token usage on success, and reports the outcome through the
// attemptErr/isSuccess out-pointers supplied by the caller.
func (h *ProxyHandler) createModifyResponseFunc(attemptErr **errors.APIError, isSuccess *bool, pTokens, cTokens *int) func(*http.Response) error {
	return func(resp *http.Response) error {
		corrID := resp.Request.Header.Get("X-Correlation-ID")

		h.logger.WithField("id", corrID).Infof("📨 Upstream response: status=%d, contentType=%s",
			resp.StatusCode, resp.Header.Get("Content-Type"))

		// Stream responses are passed through; non-stream bodies are buffered.
		isStream := strings.Contains(resp.Header.Get("Content-Type"), "text/event-stream")

		// Wrap the body in a gzip reader when the upstream compressed it.
		var reader io.ReadCloser = resp.Body
		if resp.Header.Get("Content-Encoding") == "gzip" {
			gzReader, err := gzip.NewReader(resp.Body)
			if err != nil {
				h.logger.WithField("id", corrID).WithError(err).Error("Failed to create gzip reader")
				*attemptErr = errors.NewAPIError(errors.ErrBadGateway, "Failed to decompress response")
				*isSuccess = false
				resp.Body = io.NopCloser(bytes.NewReader([]byte{}))
				return nil
			}
			reader = gzReader
			resp.Header.Del("Content-Encoding")

			// For streams the decompressing reader must replace resp.Body so
			// the client receives decompressed bytes as they arrive.
			if isStream {
				resp.Body = reader
			}
		}

		if isStream {
			h.logger.WithField("id", corrID).Info("📡 Processing stream response")

			if resp.StatusCode < 400 {
				*isSuccess = true
				h.logger.WithField("id", corrID).Info("✅ Stream response marked as success, passing through")
				// Leave the reader open: the proxy streams it to the client.
				return nil
			} else {
				// A failed stream is fully read so the error can be classified.
				bodyBytes, err := io.ReadAll(reader)
				reader.Close()
				if err != nil {
					h.logger.WithField("id", corrID).WithError(err).Error("Failed to read error response")
					*attemptErr = errors.NewAPIError(errors.ErrBadGateway, "Failed to read upstream error")
				} else {
					h.logger.WithField("id", corrID).Errorf("❌ Stream error: status=%d, body=%s",
						resp.StatusCode, string(bodyBytes))
					*attemptErr = errors.NewAPIErrorWithUpstream(resp.StatusCode,
						fmt.Sprintf("UPSTREAM_%d", resp.StatusCode), bodyBytes)
				}
				*isSuccess = false
				resp.Body = io.NopCloser(bytes.NewReader(bodyBytes))
				return nil
			}
		}

		// Non-stream response: buffer the whole body.
		h.logger.WithField("id", corrID).Info("📄 Processing non-stream response")

		bodyBytes, err := io.ReadAll(reader)
		reader.Close()

		if err != nil {
			h.logger.WithField("id", corrID).WithError(err).Error("Failed to read response body")
			*attemptErr = errors.NewAPIError(errors.ErrBadGateway, "Failed to read upstream response")
			*isSuccess = false
			resp.Body = io.NopCloser(bytes.NewReader([]byte{}))
			return nil
		}

		if resp.StatusCode < 400 {
			*isSuccess = true
			// Pull token usage out of the successful body for logging.
			*pTokens, *cTokens = extractUsage(bodyBytes)
			h.logger.WithField("id", corrID).Infof("✅ Success: bytes=%d, pTokens=%d, cTokens=%d",
				len(bodyBytes), *pTokens, *cTokens)
		} else {
			h.logger.WithField("id", corrID).Errorf("❌ Error: status=%d, body=%s",
				resp.StatusCode, string(bodyBytes))
			*attemptErr = errors.NewAPIErrorWithUpstream(resp.StatusCode,
				fmt.Sprintf("UPSTREAM_%d", resp.StatusCode), bodyBytes)
			*isSuccess = false
		}

		// Restore a readable body so the recorder receives the bytes.
		resp.Body = io.NopCloser(bytes.NewReader(bodyBytes))
		return nil
	}
}
|
||
|
||
func (h *ProxyHandler) transparentProxyErrorHandler(rw http.ResponseWriter, r *http.Request, err error) {
|
||
corrID := r.Header.Get("X-Correlation-ID")
|
||
log := h.logger.WithField("id", corrID)
|
||
log.Errorf("Transparent proxy encountered an error: %v", err)
|
||
|
||
errPtr, ok := r.Context().Value(proxyErrorContextKey{}).(**errors.APIError)
|
||
if !ok || errPtr == nil {
|
||
log.Error("FATAL: proxyErrorContextKey not found in context for error handler.")
|
||
defaultErr := errors.NewAPIError(errors.ErrBadGateway, "An unexpected proxy error occurred")
|
||
writeErrorToResponse(rw, defaultErr)
|
||
return
|
||
}
|
||
|
||
if *errPtr == nil {
|
||
if errors.IsClientNetworkError(err) {
|
||
*errPtr = errors.NewAPIError(errors.ErrBadRequest, "Client connection closed")
|
||
} else {
|
||
*errPtr = errors.NewAPIError(errors.ErrBadGateway, err.Error())
|
||
}
|
||
}
|
||
writeErrorToResponse(rw, *errPtr)
|
||
}
|
||
|
||
func (h *ProxyHandler) getResourcesForAttempt(c *gin.Context, attempt int, initialResources *service.RequestResources, modelName, groupName string, isPreciseRouting bool, correlationID string) (*service.RequestResources, error) {
|
||
if attempt == 1 {
|
||
return initialResources, nil
|
||
}
|
||
h.logger.WithField("id", correlationID).Infof("Retrying... getting new resources for attempt %d.", attempt)
|
||
resources, err := h.getResourcesForRequest(c, modelName, groupName, isPreciseRouting)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
finalRequestConfig := h.buildFinalRequestConfig(h.settingsManager.GetSettings(), resources.RequestConfig)
|
||
resources.RequestConfig = finalRequestConfig
|
||
return resources, nil
|
||
}
|
||
|
||
// shouldStopRetrying decides whether the retry loop should give up after the
// given attempt. Attempt exhaustion and unretryable request errors stop the
// loop; network blips, quota errors, and unclassified errors allow another
// attempt (each retry selects a fresh key via getResourcesForAttempt).
func (h *ProxyHandler) shouldStopRetrying(attempt, totalAttempts int, err *errors.APIError, correlationID string) bool {
	if attempt >= totalAttempts {
		return true
	}

	if err == nil {
		return false
	}

	// Unretryable request errors (the request itself is bad): stop immediately.
	if errors.IsUnretryableRequestError(err.Message) {
		h.logger.WithField("id", correlationID).Warnf("Unretryable request error, aborting: %s", err.Message)
		return true
	}

	// Permanent upstream error (e.g. a dead key): KEEP retrying, since each
	// retry rotates to a different key.
	// NOTE(review): the log message says "aborting" but the code returns
	// false (continue) — confirm which behavior is intended.
	if errors.IsPermanentUpstreamError(err.Message) {
		h.logger.WithField("id", correlationID).Warnf("Permanent upstream error, aborting: %s", err.Message)
		return false
	}

	// Retryable network errors: try again.
	if errors.IsRetryableNetworkError(err.Message) {
		return false
	}

	// Temporary upstream errors (quota, rate limits): try again.
	if errors.IsTemporaryUpstreamError(err.Message) {
		return false
	}

	// Unclassified errors: default to retrying.
	return false
}
|
||
|
||
func (h *ProxyHandler) writeFinalResponse(c *gin.Context, corrID string, rec *httptest.ResponseRecorder, apiErr *errors.APIError) {
|
||
if rec != nil {
|
||
for k, v := range rec.Header() {
|
||
c.Writer.Header()[k] = v
|
||
}
|
||
c.Writer.WriteHeader(rec.Code)
|
||
c.Writer.Write(rec.Body.Bytes())
|
||
} else if apiErr != nil {
|
||
errToJSON(c, corrID, apiErr)
|
||
} else {
|
||
errToJSON(c, corrID, errors.NewAPIError(errors.ErrInternalServer, "An unknown error occurred"))
|
||
}
|
||
}
|
||
|
||
// publishFinalLogEvent publishes the definitive request log for one proxied
// request on TopicRequestFinished. Token counts are only recorded on success.
// On failure, when the attempt produced no explicit error, one is synthesized
// from the recorded upstream response so the log always carries an error code.
// The event is skipped entirely when no resources were ever acquired, since
// there is nothing to attribute the request to.
func (h *ProxyHandler) publishFinalLogEvent(c *gin.Context, startTime time.Time, corrID, modelName string, res *service.RequestResources, rec *httptest.ResponseRecorder, finalErr *errors.APIError, isSuccess bool, pTokens, cTokens, retries int, isPrecise bool) {
	if res == nil {
		h.logger.WithField("id", corrID).Warn("No resources were used, skipping final log event.")
		return
	}
	event := h.createLogEvent(c, startTime, corrID, modelName, res, models.LogTypeFinal, isPrecise)
	event.RequestLog.LatencyMs = int(time.Since(startTime).Milliseconds())
	event.RequestLog.IsSuccess = isSuccess
	event.RequestLog.Retries = retries
	if isSuccess {
		event.RequestLog.PromptTokens, event.RequestLog.CompletionTokens = pTokens, cTokens
	}
	// rec is nil for stream requests (no buffered response exists).
	if rec != nil {
		event.RequestLog.StatusCode = rec.Code
	}
	if !isSuccess {
		// Prefer the explicit error; otherwise synthesize one from the
		// recorded upstream response.
		errToLog := finalErr
		if errToLog == nil && rec != nil {
			errToLog = errors.NewAPIErrorWithUpstream(rec.Code, fmt.Sprintf("UPSTREAM_%d", rec.Code), rec.Body.Bytes())
		}
		if errToLog != nil {
			// Backfill a code from the HTTP status when none was classified.
			if errToLog.Code == "" && errToLog.HTTPStatus >= 400 {
				errToLog.Code = fmt.Sprintf("UPSTREAM_%d", errToLog.HTTPStatus)
			}
			event.Error = errToLog
			event.RequestLog.ErrorCode, event.RequestLog.ErrorMessage = errToLog.Code, errToLog.Message
			event.RequestLog.Status = errToLog.Status
		}
	}
	eventData, err := json.Marshal(event)
	if err != nil {
		h.logger.WithField("id", corrID).WithError(err).Error("Failed to marshal log event")
		return
	}
	if err := h.store.Publish(context.Background(), models.TopicRequestFinished, eventData); err != nil {
		h.logger.WithField("id", corrID).WithError(err).Error("Failed to publish log event")
	}
}
|
||
|
||
func (h *ProxyHandler) publishRetryLogEvent(c *gin.Context, startTime time.Time, corrID, modelName string, res *service.RequestResources, rec *httptest.ResponseRecorder, attemptErr *errors.APIError, retries int, isPrecise bool) {
|
||
retryEvent := h.createLogEvent(c, startTime, corrID, modelName, res, models.LogTypeRetry, isPrecise)
|
||
retryEvent.RequestLog.LatencyMs = int(time.Since(startTime).Milliseconds())
|
||
retryEvent.RequestLog.IsSuccess = false
|
||
retryEvent.RequestLog.StatusCode = rec.Code
|
||
retryEvent.RequestLog.Retries = retries
|
||
if attemptErr != nil {
|
||
retryEvent.Error = attemptErr
|
||
retryEvent.RequestLog.ErrorCode, retryEvent.RequestLog.ErrorMessage = attemptErr.Code, attemptErr.Message
|
||
retryEvent.RequestLog.Status = attemptErr.Status
|
||
}
|
||
eventData, err := json.Marshal(retryEvent)
|
||
if err != nil {
|
||
h.logger.WithField("id", corrID).WithError(err).Error("Failed to marshal retry log event")
|
||
return
|
||
}
|
||
if err := h.store.Publish(context.Background(), models.TopicRequestFinished, eventData); err != nil {
|
||
h.logger.WithField("id", corrID).WithError(err).Error("Failed to publish retry log event")
|
||
}
|
||
}
|
||
|
||
func (h *ProxyHandler) buildFinalRequestConfig(globalSettings *models.SystemSettings, groupConfig *models.RequestConfig) *models.RequestConfig {
|
||
finalConfig := &models.RequestConfig{
|
||
CustomHeaders: make(datatypes.JSONMap),
|
||
EnableStreamOptimizer: globalSettings.EnableStreamOptimizer,
|
||
StreamMinDelay: globalSettings.StreamMinDelay,
|
||
StreamMaxDelay: globalSettings.StreamMaxDelay,
|
||
StreamShortTextThresh: globalSettings.StreamShortTextThresh,
|
||
StreamLongTextThresh: globalSettings.StreamLongTextThresh,
|
||
StreamChunkSize: globalSettings.StreamChunkSize,
|
||
EnableFakeStream: globalSettings.EnableFakeStream,
|
||
FakeStreamInterval: globalSettings.FakeStreamInterval,
|
||
}
|
||
for k, v := range globalSettings.CustomHeaders {
|
||
finalConfig.CustomHeaders[k] = v
|
||
}
|
||
if groupConfig == nil {
|
||
return finalConfig
|
||
}
|
||
groupConfigJSON, err := json.Marshal(groupConfig)
|
||
if err != nil {
|
||
h.logger.WithError(err).Error("Failed to marshal group request config for merging.")
|
||
return finalConfig
|
||
}
|
||
if err := json.Unmarshal(groupConfigJSON, finalConfig); err != nil {
|
||
h.logger.WithError(err).Error("Failed to unmarshal group request config for merging.")
|
||
}
|
||
return finalConfig
|
||
}
|
||
|
||
func writeErrorToResponse(rw http.ResponseWriter, apiErr *errors.APIError) {
|
||
if writer, ok := rw.(interface{ Written() bool }); ok && writer.Written() {
|
||
return
|
||
}
|
||
rw.Header().Set("Content-Type", "application/json; charset=utf-8")
|
||
rw.WriteHeader(apiErr.HTTPStatus)
|
||
json.NewEncoder(rw).Encode(gin.H{"error": apiErr})
|
||
}
|
||
|
||
func (h *ProxyHandler) serveSmartStream(c *gin.Context, requestBody []byte, resources *service.RequestResources, isPreciseRouting bool) {
|
||
startTime := time.Now()
|
||
correlationID := uuid.New().String()
|
||
log := h.logger.WithField("id", correlationID)
|
||
log.Info("Smart Gateway activated for streaming request.")
|
||
var originalRequest models.GeminiRequest
|
||
if err := json.Unmarshal(requestBody, &originalRequest); err != nil {
|
||
errToJSON(c, correlationID, errors.NewAPIError(errors.ErrInvalidJSON, "Invalid request format for Smart Gateway"))
|
||
return
|
||
}
|
||
systemSettings := h.settingsManager.GetSettings()
|
||
modelName := h.channel.ExtractModel(c, requestBody)
|
||
requestFinishedEvent := h.createLogEvent(c, startTime, correlationID, modelName, resources, models.LogTypeFinal, isPreciseRouting)
|
||
defer func() {
|
||
requestFinishedEvent.LatencyMs = int(time.Since(startTime).Milliseconds())
|
||
if c.Writer.Status() > 0 {
|
||
requestFinishedEvent.StatusCode = c.Writer.Status()
|
||
}
|
||
eventData, err := json.Marshal(requestFinishedEvent)
|
||
if err != nil {
|
||
h.logger.WithField("id", correlationID).WithError(err).Error("Failed to marshal final log event for smart stream")
|
||
return
|
||
}
|
||
if err := h.store.Publish(context.Background(), models.TopicRequestFinished, eventData); err != nil {
|
||
h.logger.WithField("id", correlationID).WithError(err).Error("Failed to publish final log event for smart stream")
|
||
}
|
||
}()
|
||
params := channel.SmartRequestParams{
|
||
CorrelationID: correlationID,
|
||
APIKey: resources.APIKey,
|
||
UpstreamURL: resources.UpstreamEndpoint.URL,
|
||
RequestBody: requestBody,
|
||
OriginalRequest: originalRequest,
|
||
EventLogger: requestFinishedEvent,
|
||
MaxRetries: systemSettings.MaxStreamingRetries,
|
||
RetryDelay: time.Duration(systemSettings.StreamingRetryDelayMs) * time.Millisecond,
|
||
LogTruncationLimit: systemSettings.LogTruncationLimit,
|
||
StreamingRetryPrompt: systemSettings.StreamingRetryPrompt,
|
||
}
|
||
h.channel.ProcessSmartStreamRequest(c, params)
|
||
}
|
||
|
||
func (h *ProxyHandler) createLogEvent(c *gin.Context, startTime time.Time, corrID, modelName string, res *service.RequestResources, logType models.LogType, isPreciseRouting bool) *models.RequestFinishedEvent {
|
||
event := &models.RequestFinishedEvent{
|
||
RequestLog: models.RequestLog{
|
||
RequestTime: startTime,
|
||
ModelName: modelName,
|
||
RequestPath: c.Request.URL.Path,
|
||
UserAgent: c.Request.UserAgent(),
|
||
CorrelationID: corrID,
|
||
LogType: logType,
|
||
Metadata: make(datatypes.JSONMap),
|
||
},
|
||
CorrelationID: corrID,
|
||
IsPreciseRouting: isPreciseRouting,
|
||
}
|
||
if _, exists := c.Get(middleware.RedactedBodyKey); exists {
|
||
event.RequestLog.Metadata["request_body_present"] = true
|
||
}
|
||
if redactedAuth, exists := c.Get(middleware.RedactedAuthHeaderKey); exists {
|
||
event.RequestLog.Metadata["authorization_header"] = redactedAuth.(string)
|
||
}
|
||
if authTokenValue, exists := c.Get("authToken"); exists {
|
||
if authToken, ok := authTokenValue.(*models.AuthToken); ok {
|
||
event.RequestLog.AuthTokenID = &authToken.ID
|
||
}
|
||
}
|
||
if res != nil {
|
||
if res.APIKey != nil {
|
||
event.RequestLog.KeyID = &res.APIKey.ID
|
||
}
|
||
if res.KeyGroup != nil {
|
||
event.RequestLog.GroupID = &res.KeyGroup.ID
|
||
}
|
||
if res.UpstreamEndpoint != nil {
|
||
event.RequestLog.UpstreamID = &res.UpstreamEndpoint.ID
|
||
event.UpstreamURL = &res.UpstreamEndpoint.URL
|
||
}
|
||
if res.ProxyConfig != nil {
|
||
event.RequestLog.ProxyID = &res.ProxyConfig.ID
|
||
}
|
||
}
|
||
return event
|
||
}
|
||
|
||
func (h *ProxyHandler) getResourcesForRequest(c *gin.Context, modelName string, groupName string, isPreciseRouting bool) (*service.RequestResources, error) {
|
||
authTokenValue, exists := c.Get("authToken")
|
||
if !exists {
|
||
return nil, errors.NewAPIError(errors.ErrUnauthorized, "Auth token not found in context")
|
||
}
|
||
authToken, ok := authTokenValue.(*models.AuthToken)
|
||
if !ok {
|
||
return nil, errors.NewAPIError(errors.ErrInternalServer, "Invalid auth token type in context")
|
||
}
|
||
if isPreciseRouting {
|
||
return h.resourceService.GetResourceFromGroup(c.Request.Context(), authToken, groupName)
|
||
}
|
||
return h.resourceService.GetResourceFromBasePool(c.Request.Context(), authToken, modelName)
|
||
}
|
||
|
||
func errToJSON(c *gin.Context, corrID string, apiErr *errors.APIError) {
|
||
if c.IsAborted() {
|
||
return
|
||
}
|
||
c.JSON(apiErr.HTTPStatus, gin.H{
|
||
"error": apiErr,
|
||
"correlation_id": corrID,
|
||
})
|
||
}
|
||
|
||
type bufferPool struct{}
|
||
|
||
func (b *bufferPool) Get() []byte { return make([]byte, 32*1024) }
|
||
func (b *bufferPool) Put(_ []byte) {}
|
||
|
||
func extractUsage(body []byte) (promptTokens int, completionTokens int) {
|
||
var data struct {
|
||
UsageMetadata struct {
|
||
PromptTokenCount int `json:"promptTokenCount"`
|
||
CandidatesTokenCount int `json:"candidatesTokenCount"`
|
||
} `json:"usageMetadata"`
|
||
}
|
||
if err := json.Unmarshal(body, &data); err == nil {
|
||
return data.UsageMetadata.PromptTokenCount, data.UsageMetadata.CandidatesTokenCount
|
||
}
|
||
return 0, 0
|
||
}
|
||
|
||
func (h *ProxyHandler) getMaxRetries(isPreciseRouting bool, finalOpConfig *models.KeyGroupSettings) int {
|
||
if isPreciseRouting && finalOpConfig.MaxRetries != nil {
|
||
return *finalOpConfig.MaxRetries
|
||
}
|
||
return h.settingsManager.GetSettings().MaxRetries
|
||
}
|
||
|
||
func (h *ProxyHandler) handleListModelsRequest(c *gin.Context) {
|
||
authTokenValue, exists := c.Get("authToken")
|
||
if !exists {
|
||
errToJSON(c, uuid.New().String(), errors.NewAPIError(errors.ErrUnauthorized, "Auth token not found in context"))
|
||
return
|
||
}
|
||
authToken, ok := authTokenValue.(*models.AuthToken)
|
||
if !ok {
|
||
errToJSON(c, uuid.New().String(), errors.NewAPIError(errors.ErrInternalServer, "Invalid auth token type in context"))
|
||
return
|
||
}
|
||
modelNames := h.resourceService.GetAllowedModelsForToken(authToken)
|
||
if strings.Contains(c.Request.URL.Path, "/v1beta/") {
|
||
h.respondWithGeminiFormat(c, modelNames)
|
||
} else {
|
||
h.respondWithOpenAIFormat(c, modelNames)
|
||
}
|
||
}
|
||
|
||
func (h *ProxyHandler) respondWithOpenAIFormat(c *gin.Context, modelNames []string) {
|
||
type ModelEntry struct {
|
||
ID string `json:"id"`
|
||
Object string `json:"object"`
|
||
Created int64 `json:"created"`
|
||
OwnedBy string `json:"owned_by"`
|
||
}
|
||
type ModelListResponse struct {
|
||
Object string `json:"object"`
|
||
Data []ModelEntry `json:"data"`
|
||
}
|
||
data := make([]ModelEntry, len(modelNames))
|
||
for i, name := range modelNames {
|
||
data[i] = ModelEntry{
|
||
ID: name,
|
||
Object: "model",
|
||
Created: time.Now().Unix(),
|
||
OwnedBy: "gemini-balancer",
|
||
}
|
||
}
|
||
response := ModelListResponse{
|
||
Object: "list",
|
||
Data: data,
|
||
}
|
||
c.JSON(http.StatusOK, response)
|
||
}
|
||
|
||
func (h *ProxyHandler) respondWithGeminiFormat(c *gin.Context, modelNames []string) {
|
||
type GeminiModelEntry struct {
|
||
Name string `json:"name"`
|
||
Version string `json:"version"`
|
||
DisplayName string `json:"displayName"`
|
||
Description string `json:"description"`
|
||
SupportedGenerationMethods []string `json:"supportedGenerationMethods"`
|
||
InputTokenLimit int `json:"inputTokenLimit"`
|
||
OutputTokenLimit int `json:"outputTokenLimit"`
|
||
}
|
||
type GeminiModelListResponse struct {
|
||
Models []GeminiModelEntry `json:"models"`
|
||
}
|
||
models := make([]GeminiModelEntry, len(modelNames))
|
||
for i, name := range modelNames {
|
||
models[i] = GeminiModelEntry{
|
||
Name: fmt.Sprintf("models/%s", name),
|
||
Version: "1.0.0",
|
||
DisplayName: name,
|
||
Description: "Served by Gemini Balancer",
|
||
SupportedGenerationMethods: []string{"generateContent", "streamGenerateContent"},
|
||
InputTokenLimit: 8192,
|
||
OutputTokenLimit: 2048,
|
||
}
|
||
}
|
||
response := GeminiModelListResponse{Models: models}
|
||
c.JSON(http.StatusOK, response)
|
||
}
|