Fix basepool & 优化 repo
This commit is contained in:
@@ -29,9 +29,7 @@ import (
|
||||
"gorm.io/datatypes"
|
||||
)
|
||||
|
||||
type proxyErrorKey int
|
||||
|
||||
const proxyErrKey proxyErrorKey = 0
|
||||
type proxyErrorContextKey struct{}
|
||||
|
||||
type ProxyHandler struct {
|
||||
resourceService *service.ResourceService
|
||||
@@ -81,45 +79,51 @@ func (h *ProxyHandler) HandleProxy(c *gin.Context) {
|
||||
h.handleListModelsRequest(c)
|
||||
return
|
||||
}
|
||||
requestBody, err := io.ReadAll(c.Request.Body)
|
||||
|
||||
maxBodySize := int64(h.settingsManager.GetSettings().MaxRequestBodySizeMB * 1024 * 1024)
|
||||
requestBody, err := io.ReadAll(io.LimitReader(c.Request.Body, maxBodySize))
|
||||
if err != nil {
|
||||
errToJSON(c, uuid.New().String(), errors.NewAPIError(errors.ErrBadRequest, "Failed to read request body"))
|
||||
errToJSON(c, uuid.New().String(), errors.NewAPIError(errors.ErrBadRequest, "Request body too large or failed to read"))
|
||||
return
|
||||
}
|
||||
c.Request.Body = io.NopCloser(bytes.NewReader(requestBody))
|
||||
c.Request.ContentLength = int64(len(requestBody))
|
||||
|
||||
modelName := h.channel.ExtractModel(c, requestBody)
|
||||
groupName := c.Param("group_name")
|
||||
isPreciseRouting := groupName != ""
|
||||
|
||||
if !isPreciseRouting && modelName == "" {
|
||||
errToJSON(c, uuid.New().String(), errors.NewAPIError(errors.ErrBadRequest, "Model not specified in the request body or URL"))
|
||||
errToJSON(c, uuid.New().String(), errors.NewAPIError(errors.ErrBadRequest, "Model not specified in request"))
|
||||
return
|
||||
}
|
||||
|
||||
initialResources, err := h.getResourcesForRequest(c, modelName, groupName, isPreciseRouting)
|
||||
if err != nil {
|
||||
if apiErr, ok := err.(*errors.APIError); ok {
|
||||
errToJSON(c, uuid.New().String(), apiErr)
|
||||
} else {
|
||||
errToJSON(c, uuid.New().String(), errors.NewAPIError(errors.ErrNoKeysAvailable, err.Error()))
|
||||
errToJSON(c, uuid.New().String(), errors.NewAPIError(errors.ErrInternalServer, "Failed to get initial resources"))
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
finalOpConfig, err := h.groupManager.BuildOperationalConfig(initialResources.KeyGroup)
|
||||
if err != nil {
|
||||
h.logger.WithError(err).Error("Failed to build operational config.")
|
||||
errToJSON(c, uuid.New().String(), errors.NewAPIError(errors.ErrInternalServer, "Failed to build operational configuration"))
|
||||
errToJSON(c, uuid.New().String(), errors.NewAPIError(errors.ErrInternalServer, "Failed to build operational config"))
|
||||
return
|
||||
}
|
||||
|
||||
initialResources.RequestConfig = h.buildFinalRequestConfig(h.settingsManager.GetSettings(), initialResources.RequestConfig)
|
||||
|
||||
isOpenAICompatible := h.channel.IsOpenAICompatibleRequest(c)
|
||||
if isOpenAICompatible {
|
||||
h.serveTransparentProxy(c, requestBody, initialResources, finalOpConfig, modelName, groupName, isPreciseRouting)
|
||||
return
|
||||
}
|
||||
|
||||
isStream := h.channel.IsStreamRequest(c, requestBody)
|
||||
systemSettings := h.settingsManager.GetSettings()
|
||||
useSmartGateway := finalOpConfig.EnableSmartGateway != nil && *finalOpConfig.EnableSmartGateway
|
||||
if useSmartGateway && isStream && systemSettings.EnableStreamingRetry {
|
||||
if useSmartGateway && isStream && h.settingsManager.GetSettings().EnableStreamingRetry {
|
||||
h.serveSmartStream(c, requestBody, initialResources, isPreciseRouting)
|
||||
} else {
|
||||
h.serveTransparentProxy(c, requestBody, initialResources, finalOpConfig, modelName, groupName, isPreciseRouting)
|
||||
@@ -129,219 +133,307 @@ func (h *ProxyHandler) HandleProxy(c *gin.Context) {
|
||||
func (h *ProxyHandler) serveTransparentProxy(c *gin.Context, requestBody []byte, initialResources *service.RequestResources, finalOpConfig *models.KeyGroupSettings, modelName, groupName string, isPreciseRouting bool) {
|
||||
startTime := time.Now()
|
||||
correlationID := uuid.New().String()
|
||||
|
||||
var finalRecorder *httptest.ResponseRecorder
|
||||
var lastUsedResources *service.RequestResources
|
||||
var finalProxyErr *errors.APIError
|
||||
var isSuccess bool
|
||||
var finalPromptTokens, finalCompletionTokens int
|
||||
var actualRetries int = 0
|
||||
defer func() {
|
||||
if lastUsedResources == nil {
|
||||
h.logger.WithField("id", correlationID).Warn("No resources were used, skipping final log event.")
|
||||
return
|
||||
}
|
||||
finalEvent := h.createLogEvent(c, startTime, correlationID, modelName, lastUsedResources, models.LogTypeFinal, isPreciseRouting)
|
||||
var finalPromptTokens, finalCompletionTokens, actualRetries int
|
||||
|
||||
finalEvent.RequestLog.LatencyMs = int(time.Since(startTime).Milliseconds())
|
||||
finalEvent.RequestLog.IsSuccess = isSuccess
|
||||
finalEvent.RequestLog.Retries = actualRetries
|
||||
if isSuccess {
|
||||
finalEvent.RequestLog.PromptTokens = finalPromptTokens
|
||||
finalEvent.RequestLog.CompletionTokens = finalCompletionTokens
|
||||
}
|
||||
defer h.publishFinalLogEvent(c, startTime, correlationID, modelName, lastUsedResources,
|
||||
finalRecorder, finalProxyErr, isSuccess, finalPromptTokens, finalCompletionTokens,
|
||||
actualRetries, isPreciseRouting)
|
||||
|
||||
if finalRecorder != nil {
|
||||
finalEvent.RequestLog.StatusCode = finalRecorder.Code
|
||||
}
|
||||
if !isSuccess {
|
||||
if finalProxyErr != nil {
|
||||
finalEvent.Error = finalProxyErr
|
||||
finalEvent.RequestLog.ErrorCode = finalProxyErr.Code
|
||||
finalEvent.RequestLog.ErrorMessage = finalProxyErr.Message
|
||||
} else if finalRecorder != nil {
|
||||
apiErr := errors.NewAPIErrorWithUpstream(finalRecorder.Code, fmt.Sprintf("UPSTREAM_%d", finalRecorder.Code), "Request failed after all retries.")
|
||||
finalEvent.Error = apiErr
|
||||
finalEvent.RequestLog.ErrorCode = apiErr.Code
|
||||
finalEvent.RequestLog.ErrorMessage = apiErr.Message
|
||||
}
|
||||
}
|
||||
eventData, err := json.Marshal(finalEvent)
|
||||
if err != nil {
|
||||
h.logger.WithField("id", correlationID).WithError(err).Error("Failed to marshal final log event.")
|
||||
return
|
||||
}
|
||||
if err := h.store.Publish(context.Background(), models.TopicRequestFinished, eventData); err != nil {
|
||||
h.logger.WithField("id", correlationID).WithError(err).Error("Failed to publish final log event.")
|
||||
}
|
||||
}()
|
||||
var maxRetries int
|
||||
if isPreciseRouting {
|
||||
if finalOpConfig.MaxRetries != nil {
|
||||
maxRetries = *finalOpConfig.MaxRetries
|
||||
} else {
|
||||
maxRetries = h.settingsManager.GetSettings().MaxRetries
|
||||
}
|
||||
} else {
|
||||
maxRetries = h.settingsManager.GetSettings().MaxRetries
|
||||
}
|
||||
maxRetries := h.getMaxRetries(isPreciseRouting, finalOpConfig)
|
||||
totalAttempts := maxRetries + 1
|
||||
|
||||
for attempt := 1; attempt <= totalAttempts; attempt++ {
|
||||
if c.Request.Context().Err() != nil {
|
||||
h.logger.WithField("id", correlationID).Info("Client disconnected, aborting retry loop.")
|
||||
if finalProxyErr == nil {
|
||||
finalProxyErr = errors.NewAPIError(errors.ErrBadRequest, "Client connection closed")
|
||||
finalProxyErr = errors.NewAPIError(errors.ErrBadRequest, "Client disconnected")
|
||||
}
|
||||
break
|
||||
}
|
||||
var currentResources *service.RequestResources
|
||||
var err error
|
||||
if attempt == 1 {
|
||||
currentResources = initialResources
|
||||
} else {
|
||||
actualRetries = attempt - 1
|
||||
h.logger.WithField("id", correlationID).Infof("Retrying... getting new resources for attempt %d.", attempt)
|
||||
currentResources, err = h.getResourcesForRequest(c, modelName, groupName, isPreciseRouting)
|
||||
if err != nil {
|
||||
h.logger.WithField("id", correlationID).Errorf("Failed to get new resources for retry, aborting: %v", err)
|
||||
finalProxyErr = errors.NewAPIError(errors.ErrNoKeysAvailable, "Failed to get new resources for retry")
|
||||
break
|
||||
|
||||
resources, err := h.getResourcesForAttempt(c, attempt, initialResources, modelName, groupName, isPreciseRouting, correlationID)
|
||||
if err != nil {
|
||||
if apiErr, ok := err.(*errors.APIError); ok {
|
||||
finalProxyErr = apiErr
|
||||
} else {
|
||||
finalProxyErr = errors.NewAPIError(errors.ErrInternalServer, "Failed to get resources for retry")
|
||||
}
|
||||
break
|
||||
}
|
||||
lastUsedResources = resources
|
||||
if attempt > 1 {
|
||||
actualRetries = attempt - 1
|
||||
}
|
||||
|
||||
finalRequestConfig := h.buildFinalRequestConfig(h.settingsManager.GetSettings(), currentResources.RequestConfig)
|
||||
currentResources.RequestConfig = finalRequestConfig
|
||||
lastUsedResources = currentResources
|
||||
h.logger.WithField("id", correlationID).Infof("Attempt %d/%d with KeyID %d...", attempt, totalAttempts, currentResources.APIKey.ID)
|
||||
var attemptErr *errors.APIError
|
||||
var attemptIsSuccess bool
|
||||
recorder := httptest.NewRecorder()
|
||||
attemptStartTime := time.Now()
|
||||
connectTimeout := time.Duration(h.settingsManager.GetSettings().ConnectTimeoutSeconds) * time.Second
|
||||
ctx, cancel := context.WithTimeout(c.Request.Context(), connectTimeout)
|
||||
defer cancel()
|
||||
attemptReq := c.Request.Clone(ctx)
|
||||
attemptReq.Body = io.NopCloser(bytes.NewReader(requestBody))
|
||||
if currentResources.UpstreamEndpoint == nil || currentResources.UpstreamEndpoint.URL == "" {
|
||||
h.logger.WithField("id", correlationID).Errorf("Attempt %d failed: no upstream URL in resources.", attempt)
|
||||
isSuccess = false
|
||||
finalProxyErr = errors.NewAPIError(errors.ErrInternalServer, "No upstream URL configured for the selected resource")
|
||||
continue
|
||||
}
|
||||
h.transparentProxy.Director = func(req *http.Request) {
|
||||
targetURL, _ := url.Parse(currentResources.UpstreamEndpoint.URL)
|
||||
req.URL.Scheme = targetURL.Scheme
|
||||
req.URL.Host = targetURL.Host
|
||||
req.Host = targetURL.Host
|
||||
var pureClientPath string
|
||||
if isPreciseRouting {
|
||||
proxyPrefix := "/proxy/" + groupName
|
||||
pureClientPath = strings.TrimPrefix(req.URL.Path, proxyPrefix)
|
||||
} else {
|
||||
pureClientPath = req.URL.Path
|
||||
}
|
||||
finalPath := h.channel.RewritePath(targetURL.Path, pureClientPath)
|
||||
req.URL.Path = finalPath
|
||||
h.logger.WithFields(logrus.Fields{
|
||||
"correlation_id": correlationID,
|
||||
"attempt": attempt,
|
||||
"key_id": currentResources.APIKey.ID,
|
||||
"base_upstream_url": currentResources.UpstreamEndpoint.URL,
|
||||
"final_request_url": req.URL.String(),
|
||||
}).Infof("Director constructed final upstream request URL.")
|
||||
req.Header.Del("Authorization")
|
||||
h.channel.ModifyRequest(req, currentResources.APIKey)
|
||||
req.Header.Set("X-Correlation-ID", correlationID)
|
||||
*req = *req.WithContext(context.WithValue(req.Context(), proxyErrKey, &attemptErr))
|
||||
}
|
||||
transport := h.transparentProxy.Transport.(*http.Transport)
|
||||
if currentResources.ProxyConfig != nil {
|
||||
proxyURLStr := fmt.Sprintf("%s://%s", currentResources.ProxyConfig.Protocol, currentResources.ProxyConfig.Address)
|
||||
proxyURL, err := url.Parse(proxyURLStr)
|
||||
if err == nil {
|
||||
transport.Proxy = http.ProxyURL(proxyURL)
|
||||
}
|
||||
} else {
|
||||
transport.Proxy = http.ProxyFromEnvironment
|
||||
}
|
||||
h.transparentProxy.ModifyResponse = func(resp *http.Response) error {
|
||||
defer resp.Body.Close()
|
||||
var reader io.ReadCloser
|
||||
var err error
|
||||
isGzipped := resp.Header.Get("Content-Encoding") == "gzip"
|
||||
if isGzipped {
|
||||
reader, err = gzip.NewReader(resp.Body)
|
||||
if err != nil {
|
||||
h.logger.WithError(err).Error("Failed to create gzip reader")
|
||||
reader = resp.Body
|
||||
} else {
|
||||
resp.Header.Del("Content-Encoding")
|
||||
}
|
||||
defer reader.Close()
|
||||
} else {
|
||||
reader = resp.Body
|
||||
}
|
||||
bodyBytes, err := io.ReadAll(reader)
|
||||
if err != nil {
|
||||
attemptErr = errors.NewAPIError(errors.ErrBadGateway, "Failed to read upstream response: "+err.Error())
|
||||
resp.Body = io.NopCloser(bytes.NewReader([]byte(attemptErr.Message)))
|
||||
return nil
|
||||
}
|
||||
if resp.StatusCode < 400 {
|
||||
attemptIsSuccess = true
|
||||
finalPromptTokens, finalCompletionTokens = extractUsage(bodyBytes)
|
||||
} else {
|
||||
parsedMsg := errors.ParseUpstreamError(bodyBytes)
|
||||
attemptErr = errors.NewAPIErrorWithUpstream(resp.StatusCode, fmt.Sprintf("UPSTREAM_%d", resp.StatusCode), parsedMsg)
|
||||
}
|
||||
resp.Body = io.NopCloser(bytes.NewReader(bodyBytes))
|
||||
return nil
|
||||
}
|
||||
h.transparentProxy.ServeHTTP(recorder, attemptReq)
|
||||
finalRecorder = recorder
|
||||
finalProxyErr = attemptErr
|
||||
isSuccess = attemptIsSuccess
|
||||
h.resourceService.ReportRequestResult(currentResources, isSuccess, finalProxyErr)
|
||||
h.logger.WithField("id", correlationID).Infof("Attempt %d/%d with KeyID %d", attempt, totalAttempts, resources.APIKey.ID)
|
||||
|
||||
recorder, attemptErr, attemptSuccess := h.executeProxyAttempt(
|
||||
c, correlationID, requestBody, resources, isPreciseRouting, groupName,
|
||||
&finalPromptTokens, &finalCompletionTokens,
|
||||
)
|
||||
|
||||
finalRecorder, finalProxyErr, isSuccess = recorder, attemptErr, attemptSuccess
|
||||
h.resourceService.ReportRequestResult(resources, isSuccess, finalProxyErr)
|
||||
|
||||
if isSuccess {
|
||||
break
|
||||
}
|
||||
isUnretryableError := false
|
||||
if finalProxyErr != nil {
|
||||
if errors.IsUnretryableRequestError(finalProxyErr.Message) {
|
||||
isUnretryableError = true
|
||||
h.logger.WithField("id", correlationID).Warnf("Attempt %d failed with unretryable request error. Aborting retries. Message: %s", attempt, finalProxyErr.Message)
|
||||
}
|
||||
}
|
||||
if attempt >= totalAttempts || isUnretryableError {
|
||||
if h.shouldStopRetrying(attempt, totalAttempts, finalProxyErr, correlationID) {
|
||||
break
|
||||
}
|
||||
retryEvent := h.createLogEvent(c, startTime, correlationID, modelName, currentResources, models.LogTypeRetry, isPreciseRouting)
|
||||
retryEvent.LatencyMs = int(time.Since(attemptStartTime).Milliseconds())
|
||||
retryEvent.IsSuccess = false
|
||||
retryEvent.StatusCode = recorder.Code
|
||||
retryEvent.Retries = actualRetries
|
||||
if attemptErr != nil {
|
||||
retryEvent.Error = attemptErr
|
||||
retryEvent.ErrorCode = attemptErr.Code
|
||||
retryEvent.ErrorMessage = attemptErr.Message
|
||||
}
|
||||
eventData, _ := json.Marshal(retryEvent)
|
||||
_ = h.store.Publish(context.Background(), models.TopicRequestFinished, eventData)
|
||||
h.publishRetryLogEvent(c, startTime, correlationID, modelName, resources, recorder, attemptErr, actualRetries, isPreciseRouting)
|
||||
}
|
||||
if finalRecorder != nil {
|
||||
bodyBytes := finalRecorder.Body.Bytes()
|
||||
c.Writer.Header().Set("Content-Length", fmt.Sprintf("%d", len(bodyBytes)))
|
||||
for k, v := range finalRecorder.Header() {
|
||||
if strings.ToLower(k) != "content-length" {
|
||||
c.Writer.Header()[k] = v
|
||||
|
||||
h.writeFinalResponse(c, correlationID, finalRecorder, finalProxyErr)
|
||||
}
|
||||
|
||||
func (h *ProxyHandler) executeProxyAttempt(c *gin.Context, corrID string, body []byte, res *service.RequestResources, isPrecise bool, groupName string, pTokens, cTokens *int) (*httptest.ResponseRecorder, *errors.APIError, bool) {
|
||||
recorder := httptest.NewRecorder()
|
||||
var attemptErr *errors.APIError
|
||||
var isSuccess bool
|
||||
|
||||
connectTimeout := time.Duration(h.settingsManager.GetSettings().ConnectTimeoutSeconds) * time.Second
|
||||
ctx, cancel := context.WithTimeout(c.Request.Context(), connectTimeout)
|
||||
defer cancel()
|
||||
|
||||
attemptReq := c.Request.Clone(ctx)
|
||||
attemptReq.Body = io.NopCloser(bytes.NewReader(body))
|
||||
attemptReq.ContentLength = int64(len(body))
|
||||
|
||||
h.configureProxy(corrID, res, isPrecise, groupName, &attemptErr, &isSuccess, pTokens, cTokens)
|
||||
*attemptReq = *attemptReq.WithContext(context.WithValue(attemptReq.Context(), proxyErrorContextKey{}, &attemptErr))
|
||||
|
||||
h.transparentProxy.ServeHTTP(recorder, attemptReq)
|
||||
|
||||
return recorder, attemptErr, isSuccess
|
||||
}
|
||||
|
||||
func (h *ProxyHandler) configureProxy(corrID string, res *service.RequestResources, isPrecise bool, groupName string, attemptErr **errors.APIError, isSuccess *bool, pTokens, cTokens *int) {
|
||||
h.transparentProxy.Director = func(r *http.Request) {
|
||||
targetURL, _ := url.Parse(res.UpstreamEndpoint.URL)
|
||||
r.URL.Scheme, r.URL.Host, r.Host = targetURL.Scheme, targetURL.Host, targetURL.Host
|
||||
|
||||
var pureClientPath string
|
||||
if isPrecise {
|
||||
pureClientPath = strings.TrimPrefix(r.URL.Path, "/proxy/"+groupName)
|
||||
} else {
|
||||
pureClientPath = r.URL.Path
|
||||
}
|
||||
r.URL.Path = h.channel.RewritePath(targetURL.Path, pureClientPath)
|
||||
|
||||
r.Header.Del("Authorization")
|
||||
h.channel.ModifyRequest(r, res.APIKey)
|
||||
r.Header.Set("X-Correlation-ID", corrID)
|
||||
}
|
||||
|
||||
transport := h.transparentProxy.Transport.(*http.Transport)
|
||||
if res.ProxyConfig != nil {
|
||||
proxyURLStr := fmt.Sprintf("%s://%s", res.ProxyConfig.Protocol, res.ProxyConfig.Address)
|
||||
if proxyURL, err := url.Parse(proxyURLStr); err == nil {
|
||||
transport.Proxy = http.ProxyURL(proxyURL)
|
||||
} else {
|
||||
transport.Proxy = http.ProxyFromEnvironment
|
||||
}
|
||||
} else {
|
||||
transport.Proxy = http.ProxyFromEnvironment
|
||||
}
|
||||
|
||||
h.transparentProxy.ModifyResponse = h.createModifyResponseFunc(attemptErr, isSuccess, pTokens, cTokens)
|
||||
}
|
||||
|
||||
func (h *ProxyHandler) createModifyResponseFunc(attemptErr **errors.APIError, isSuccess *bool, pTokens, cTokens *int) func(*http.Response) error {
|
||||
return func(resp *http.Response) error {
|
||||
var reader io.ReadCloser = resp.Body
|
||||
if resp.Header.Get("Content-Encoding") == "gzip" {
|
||||
gzReader, err := gzip.NewReader(resp.Body)
|
||||
if err != nil {
|
||||
h.logger.WithError(err).Error("Failed to create gzip reader")
|
||||
} else {
|
||||
reader = gzReader
|
||||
resp.Header.Del("Content-Encoding")
|
||||
}
|
||||
}
|
||||
c.Writer.WriteHeader(finalRecorder.Code)
|
||||
c.Writer.Write(finalRecorder.Body.Bytes())
|
||||
} else {
|
||||
errToJSON(c, correlationID, finalProxyErr)
|
||||
defer reader.Close()
|
||||
|
||||
bodyBytes, err := io.ReadAll(reader)
|
||||
if err != nil {
|
||||
*attemptErr = errors.NewAPIError(errors.ErrBadGateway, "Failed to read upstream response")
|
||||
resp.Body = io.NopCloser(bytes.NewReader([]byte{}))
|
||||
return nil
|
||||
}
|
||||
|
||||
if resp.StatusCode < 400 {
|
||||
*isSuccess = true
|
||||
*pTokens, *cTokens = extractUsage(bodyBytes)
|
||||
} else {
|
||||
parsedMsg := errors.ParseUpstreamError(bodyBytes)
|
||||
*attemptErr = errors.NewAPIErrorWithUpstream(resp.StatusCode, fmt.Sprintf("UPSTREAM_%d", resp.StatusCode), parsedMsg)
|
||||
}
|
||||
resp.Body = io.NopCloser(bytes.NewReader(bodyBytes))
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func (h *ProxyHandler) transparentProxyErrorHandler(rw http.ResponseWriter, r *http.Request, err error) {
|
||||
corrID := r.Header.Get("X-Correlation-ID")
|
||||
log := h.logger.WithField("id", corrID)
|
||||
log.Errorf("Transparent proxy encountered an error: %v", err)
|
||||
|
||||
errPtr, ok := r.Context().Value(proxyErrorContextKey{}).(**errors.APIError)
|
||||
if !ok || errPtr == nil {
|
||||
log.Error("FATAL: proxyErrorContextKey not found in context for error handler.")
|
||||
defaultErr := errors.NewAPIError(errors.ErrBadGateway, "An unexpected proxy error occurred")
|
||||
writeErrorToResponse(rw, defaultErr)
|
||||
return
|
||||
}
|
||||
|
||||
if *errPtr == nil {
|
||||
if errors.IsClientNetworkError(err) {
|
||||
*errPtr = errors.NewAPIError(errors.ErrBadRequest, "Client connection closed")
|
||||
} else {
|
||||
*errPtr = errors.NewAPIError(errors.ErrBadGateway, err.Error())
|
||||
}
|
||||
}
|
||||
writeErrorToResponse(rw, *errPtr)
|
||||
}
|
||||
|
||||
func (h *ProxyHandler) getResourcesForAttempt(c *gin.Context, attempt int, initialResources *service.RequestResources, modelName, groupName string, isPreciseRouting bool, correlationID string) (*service.RequestResources, error) {
|
||||
if attempt == 1 {
|
||||
return initialResources, nil
|
||||
}
|
||||
h.logger.WithField("id", correlationID).Infof("Retrying... getting new resources for attempt %d.", attempt)
|
||||
resources, err := h.getResourcesForRequest(c, modelName, groupName, isPreciseRouting)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
finalRequestConfig := h.buildFinalRequestConfig(h.settingsManager.GetSettings(), resources.RequestConfig)
|
||||
resources.RequestConfig = finalRequestConfig
|
||||
return resources, nil
|
||||
}
|
||||
|
||||
func (h *ProxyHandler) shouldStopRetrying(attempt, totalAttempts int, err *errors.APIError, correlationID string) bool {
|
||||
if attempt >= totalAttempts {
|
||||
return true
|
||||
}
|
||||
if err != nil && errors.IsUnretryableRequestError(err.Message) {
|
||||
h.logger.WithField("id", correlationID).Warnf("Attempt failed with unretryable request error. Aborting retries. Message: %s", err.Message)
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (h *ProxyHandler) writeFinalResponse(c *gin.Context, corrID string, rec *httptest.ResponseRecorder, apiErr *errors.APIError) {
|
||||
if rec != nil {
|
||||
for k, v := range rec.Header() {
|
||||
c.Writer.Header()[k] = v
|
||||
}
|
||||
c.Writer.WriteHeader(rec.Code)
|
||||
c.Writer.Write(rec.Body.Bytes())
|
||||
} else if apiErr != nil {
|
||||
errToJSON(c, corrID, apiErr)
|
||||
} else {
|
||||
errToJSON(c, corrID, errors.NewAPIError(errors.ErrInternalServer, "An unknown error occurred"))
|
||||
}
|
||||
}
|
||||
|
||||
func (h *ProxyHandler) publishFinalLogEvent(c *gin.Context, startTime time.Time, corrID, modelName string, res *service.RequestResources, rec *httptest.ResponseRecorder, finalErr *errors.APIError, isSuccess bool, pTokens, cTokens, retries int, isPrecise bool) {
|
||||
if res == nil {
|
||||
h.logger.WithField("id", corrID).Warn("No resources were used, skipping final log event.")
|
||||
return
|
||||
}
|
||||
event := h.createLogEvent(c, startTime, corrID, modelName, res, models.LogTypeFinal, isPrecise)
|
||||
event.RequestLog.LatencyMs = int(time.Since(startTime).Milliseconds())
|
||||
event.RequestLog.IsSuccess = isSuccess
|
||||
event.RequestLog.Retries = retries
|
||||
if isSuccess {
|
||||
event.RequestLog.PromptTokens, event.RequestLog.CompletionTokens = pTokens, cTokens
|
||||
}
|
||||
if rec != nil {
|
||||
event.RequestLog.StatusCode = rec.Code
|
||||
}
|
||||
if !isSuccess {
|
||||
errToLog := finalErr
|
||||
if errToLog == nil && rec != nil {
|
||||
errToLog = errors.NewAPIErrorWithUpstream(rec.Code, fmt.Sprintf("UPSTREAM_%d", rec.Code), "Request failed after all retries.")
|
||||
}
|
||||
if errToLog != nil {
|
||||
event.Error = errToLog
|
||||
event.RequestLog.ErrorCode, event.RequestLog.ErrorMessage = errToLog.Code, errToLog.Message
|
||||
}
|
||||
}
|
||||
eventData, err := json.Marshal(event)
|
||||
if err != nil {
|
||||
h.logger.WithField("id", corrID).WithError(err).Error("Failed to marshal log event")
|
||||
return
|
||||
}
|
||||
if err := h.store.Publish(context.Background(), models.TopicRequestFinished, eventData); err != nil {
|
||||
h.logger.WithField("id", corrID).WithError(err).Error("Failed to publish log event")
|
||||
}
|
||||
}
|
||||
|
||||
func (h *ProxyHandler) publishRetryLogEvent(c *gin.Context, startTime time.Time, corrID, modelName string, res *service.RequestResources, rec *httptest.ResponseRecorder, attemptErr *errors.APIError, retries int, isPrecise bool) {
|
||||
retryEvent := h.createLogEvent(c, startTime, corrID, modelName, res, models.LogTypeRetry, isPrecise)
|
||||
retryEvent.RequestLog.LatencyMs = int(time.Since(startTime).Milliseconds())
|
||||
retryEvent.RequestLog.IsSuccess = false
|
||||
retryEvent.RequestLog.StatusCode = rec.Code
|
||||
retryEvent.RequestLog.Retries = retries
|
||||
if attemptErr != nil {
|
||||
retryEvent.Error = attemptErr
|
||||
retryEvent.RequestLog.ErrorCode, retryEvent.RequestLog.ErrorMessage = attemptErr.Code, attemptErr.Message
|
||||
}
|
||||
eventData, err := json.Marshal(retryEvent)
|
||||
if err != nil {
|
||||
h.logger.WithField("id", corrID).WithError(err).Error("Failed to marshal retry log event")
|
||||
return
|
||||
}
|
||||
if err := h.store.Publish(context.Background(), models.TopicRequestFinished, eventData); err != nil {
|
||||
h.logger.WithField("id", corrID).WithError(err).Error("Failed to publish retry log event")
|
||||
}
|
||||
}
|
||||
|
||||
func (h *ProxyHandler) buildFinalRequestConfig(globalSettings *models.SystemSettings, groupConfig *models.RequestConfig) *models.RequestConfig {
|
||||
finalConfig := &models.RequestConfig{
|
||||
CustomHeaders: make(datatypes.JSONMap),
|
||||
EnableStreamOptimizer: globalSettings.EnableStreamOptimizer,
|
||||
StreamMinDelay: globalSettings.StreamMinDelay,
|
||||
StreamMaxDelay: globalSettings.StreamMaxDelay,
|
||||
StreamShortTextThresh: globalSettings.StreamShortTextThresh,
|
||||
StreamLongTextThresh: globalSettings.StreamLongTextThresh,
|
||||
StreamChunkSize: globalSettings.StreamChunkSize,
|
||||
EnableFakeStream: globalSettings.EnableFakeStream,
|
||||
FakeStreamInterval: globalSettings.FakeStreamInterval,
|
||||
}
|
||||
for k, v := range globalSettings.CustomHeaders {
|
||||
finalConfig.CustomHeaders[k] = v
|
||||
}
|
||||
if groupConfig == nil {
|
||||
return finalConfig
|
||||
}
|
||||
groupConfigJSON, err := json.Marshal(groupConfig)
|
||||
if err != nil {
|
||||
h.logger.WithError(err).Error("Failed to marshal group request config for merging.")
|
||||
return finalConfig
|
||||
}
|
||||
if err := json.Unmarshal(groupConfigJSON, finalConfig); err != nil {
|
||||
h.logger.WithError(err).Error("Failed to unmarshal group request config for merging.")
|
||||
}
|
||||
return finalConfig
|
||||
}
|
||||
|
||||
func writeErrorToResponse(rw http.ResponseWriter, apiErr *errors.APIError) {
|
||||
if writer, ok := rw.(interface{ Written() bool }); ok && writer.Written() {
|
||||
return
|
||||
}
|
||||
rw.Header().Set("Content-Type", "application/json; charset=utf-8")
|
||||
rw.WriteHeader(apiErr.HTTPStatus)
|
||||
json.NewEncoder(rw).Encode(gin.H{"error": apiErr})
|
||||
}
|
||||
|
||||
func (h *ProxyHandler) serveSmartStream(c *gin.Context, requestBody []byte, resources *service.RequestResources, isPreciseRouting bool) {
|
||||
startTime := time.Now()
|
||||
correlationID := uuid.New().String()
|
||||
@@ -349,7 +441,7 @@ func (h *ProxyHandler) serveSmartStream(c *gin.Context, requestBody []byte, reso
|
||||
log.Info("Smart Gateway activated for streaming request.")
|
||||
var originalRequest models.GeminiRequest
|
||||
if err := json.Unmarshal(requestBody, &originalRequest); err != nil {
|
||||
errToJSON(c, correlationID, errors.NewAPIError(errors.ErrInvalidJSON, "Smart Gateway failed: Request body is not a valid Gemini native format. Error: "+err.Error()))
|
||||
errToJSON(c, correlationID, errors.NewAPIError(errors.ErrInvalidJSON, "Invalid request format for Smart Gateway"))
|
||||
return
|
||||
}
|
||||
systemSettings := h.settingsManager.GetSettings()
|
||||
@@ -360,8 +452,14 @@ func (h *ProxyHandler) serveSmartStream(c *gin.Context, requestBody []byte, reso
|
||||
if c.Writer.Status() > 0 {
|
||||
requestFinishedEvent.StatusCode = c.Writer.Status()
|
||||
}
|
||||
eventData, _ := json.Marshal(requestFinishedEvent)
|
||||
_ = h.store.Publish(context.Background(), models.TopicRequestFinished, eventData)
|
||||
eventData, err := json.Marshal(requestFinishedEvent)
|
||||
if err != nil {
|
||||
h.logger.WithField("id", correlationID).WithError(err).Error("Failed to marshal final log event for smart stream")
|
||||
return
|
||||
}
|
||||
if err := h.store.Publish(context.Background(), models.TopicRequestFinished, eventData); err != nil {
|
||||
h.logger.WithField("id", correlationID).WithError(err).Error("Failed to publish final log event for smart stream")
|
||||
}
|
||||
}()
|
||||
params := channel.SmartRequestParams{
|
||||
CorrelationID: correlationID,
|
||||
@@ -378,30 +476,6 @@ func (h *ProxyHandler) serveSmartStream(c *gin.Context, requestBody []byte, reso
|
||||
h.channel.ProcessSmartStreamRequest(c, params)
|
||||
}
|
||||
|
||||
func (h *ProxyHandler) transparentProxyErrorHandler(rw http.ResponseWriter, r *http.Request, err error) {
|
||||
correlationID := r.Header.Get("X-Correlation-ID")
|
||||
h.logger.WithField("id", correlationID).Errorf("Transparent proxy error: %v", err)
|
||||
proxyErrPtr, exists := r.Context().Value(proxyErrKey).(**errors.APIError)
|
||||
if !exists || proxyErrPtr == nil {
|
||||
h.logger.WithField("id", correlationID).Error("FATAL: proxyErrorKey not found in context for error handler.")
|
||||
return
|
||||
}
|
||||
if errors.IsClientNetworkError(err) {
|
||||
*proxyErrPtr = errors.NewAPIError(errors.ErrBadRequest, "Client connection closed")
|
||||
} else {
|
||||
*proxyErrPtr = errors.NewAPIError(errors.ErrBadGateway, err.Error())
|
||||
}
|
||||
if _, ok := rw.(*httptest.ResponseRecorder); ok {
|
||||
return
|
||||
}
|
||||
if writer, ok := rw.(interface{ Written() bool }); ok {
|
||||
if writer.Written() {
|
||||
return
|
||||
}
|
||||
}
|
||||
rw.WriteHeader((*proxyErrPtr).HTTPStatus)
|
||||
}
|
||||
|
||||
func (h *ProxyHandler) createLogEvent(c *gin.Context, startTime time.Time, corrID, modelName string, res *service.RequestResources, logType models.LogType, isPreciseRouting bool) *models.RequestFinishedEvent {
|
||||
event := &models.RequestFinishedEvent{
|
||||
RequestLog: models.RequestLog{
|
||||
@@ -456,12 +530,14 @@ func (h *ProxyHandler) getResourcesForRequest(c *gin.Context, modelName string,
|
||||
}
|
||||
if isPreciseRouting {
|
||||
return h.resourceService.GetResourceFromGroup(c.Request.Context(), authToken, groupName)
|
||||
} else {
|
||||
return h.resourceService.GetResourceFromBasePool(c.Request.Context(), authToken, modelName)
|
||||
}
|
||||
return h.resourceService.GetResourceFromBasePool(c.Request.Context(), authToken, modelName)
|
||||
}
|
||||
|
||||
func errToJSON(c *gin.Context, corrID string, apiErr *errors.APIError) {
|
||||
if c.IsAborted() {
|
||||
return
|
||||
}
|
||||
c.JSON(apiErr.HTTPStatus, gin.H{
|
||||
"error": apiErr,
|
||||
"correlation_id": corrID,
|
||||
@@ -470,8 +546,8 @@ func errToJSON(c *gin.Context, corrID string, apiErr *errors.APIError) {
|
||||
|
||||
type bufferPool struct{}
|
||||
|
||||
func (b *bufferPool) Get() []byte { return make([]byte, 32*1024) }
|
||||
func (b *bufferPool) Put(bytes []byte) {}
|
||||
func (b *bufferPool) Get() []byte { return make([]byte, 32*1024) }
|
||||
func (b *bufferPool) Put(_ []byte) {}
|
||||
|
||||
func extractUsage(body []byte) (promptTokens int, completionTokens int) {
|
||||
var data struct {
|
||||
@@ -486,34 +562,11 @@ func extractUsage(body []byte) (promptTokens int, completionTokens int) {
|
||||
return 0, 0
|
||||
}
|
||||
|
||||
func (h *ProxyHandler) buildFinalRequestConfig(globalSettings *models.SystemSettings, groupConfig *models.RequestConfig) *models.RequestConfig {
|
||||
customHeadersJSON, _ := json.Marshal(globalSettings.CustomHeaders)
|
||||
var customHeadersMap datatypes.JSONMap
|
||||
_ = json.Unmarshal(customHeadersJSON, &customHeadersMap)
|
||||
finalConfig := &models.RequestConfig{
|
||||
CustomHeaders: customHeadersMap,
|
||||
EnableStreamOptimizer: globalSettings.EnableStreamOptimizer,
|
||||
StreamMinDelay: globalSettings.StreamMinDelay,
|
||||
StreamMaxDelay: globalSettings.StreamMaxDelay,
|
||||
StreamShortTextThresh: globalSettings.StreamShortTextThresh,
|
||||
StreamLongTextThresh: globalSettings.StreamLongTextThresh,
|
||||
StreamChunkSize: globalSettings.StreamChunkSize,
|
||||
EnableFakeStream: globalSettings.EnableFakeStream,
|
||||
FakeStreamInterval: globalSettings.FakeStreamInterval,
|
||||
func (h *ProxyHandler) getMaxRetries(isPreciseRouting bool, finalOpConfig *models.KeyGroupSettings) int {
|
||||
if isPreciseRouting && finalOpConfig.MaxRetries != nil {
|
||||
return *finalOpConfig.MaxRetries
|
||||
}
|
||||
if groupConfig == nil {
|
||||
return finalConfig
|
||||
}
|
||||
groupConfigJSON, err := json.Marshal(groupConfig)
|
||||
if err != nil {
|
||||
h.logger.WithError(err).Error("Failed to marshal group request config for merging.")
|
||||
return finalConfig
|
||||
}
|
||||
if err := json.Unmarshal(groupConfigJSON, finalConfig); err != nil {
|
||||
h.logger.WithError(err).Error("Failed to unmarshal group request config for merging.")
|
||||
return finalConfig
|
||||
}
|
||||
return finalConfig
|
||||
return h.settingsManager.GetSettings().MaxRetries
|
||||
}
|
||||
|
||||
func (h *ProxyHandler) handleListModelsRequest(c *gin.Context) {
|
||||
|
||||
Reference in New Issue
Block a user