593 lines
22 KiB
Go
593 lines
22 KiB
Go
// Filename: internal/handlers/proxy_handler.go
|
|
package handlers
|
|
|
|
import (
	"bytes"
	"compress/gzip"
	"context"
	"encoding/json"
	"fmt"
	"io"
	"net"
	"net/http"
	"net/http/httptest"
	"net/http/httputil"
	"net/url"
	"strings"
	"sync"
	"time"

	"gemini-balancer/internal/channel"
	"gemini-balancer/internal/errors"
	"gemini-balancer/internal/middleware"
	"gemini-balancer/internal/models"
	"gemini-balancer/internal/service"
	"gemini-balancer/internal/settings"
	"gemini-balancer/internal/store"

	"github.com/gin-gonic/gin"
	"github.com/google/uuid"
	"github.com/sirupsen/logrus"
	"gorm.io/datatypes"
)
|
|
|
|
// proxyErrorKey is an unexported context-key type so the proxy-error slot
// stored below cannot collide with context values set by other packages.
type proxyErrorKey int

// proxyErrKey is the context key under which the Director stores a
// **errors.APIError slot (a pointer to the per-attempt error variable),
// letting transparentProxyErrorHandler report failures back to the
// retry loop in serveTransparentProxy.
const proxyErrKey proxyErrorKey = 0
|
|
|
|
// ProxyHandler serves proxied upstream requests: it resolves API-key /
// upstream / proxy resources for each request, forwards traffic through a
// transparent reverse proxy with retries (or the Smart Gateway streaming
// path), and publishes request log events to the store.
type ProxyHandler struct {
	resourceService *service.ResourceService  // resource selection and result reporting
	store           store.Store               // event bus for request-finished events
	settingsManager *settings.SettingsManager // live system settings (retries, timeouts, stream options)
	groupManager    *service.GroupManager     // builds per-group operational config
	channel         channel.ChannelProxy      // channel-specific request/response adaptation
	logger          *logrus.Entry
	// transparentProxy is shared by all requests; serveTransparentProxy
	// reassigns its Director/ModifyResponse per attempt (see note there).
	transparentProxy *httputil.ReverseProxy
}
|
|
|
|
func NewProxyHandler(
|
|
resourceService *service.ResourceService,
|
|
store store.Store,
|
|
sm *settings.SettingsManager,
|
|
gm *service.GroupManager,
|
|
channel channel.ChannelProxy,
|
|
logger *logrus.Logger,
|
|
) *ProxyHandler {
|
|
ph := &ProxyHandler{
|
|
resourceService: resourceService,
|
|
store: store,
|
|
settingsManager: sm,
|
|
groupManager: gm,
|
|
channel: channel,
|
|
logger: logger.WithField("component", "ProxyHandler"),
|
|
transparentProxy: &httputil.ReverseProxy{},
|
|
}
|
|
ph.transparentProxy.Transport = &http.Transport{
|
|
Proxy: http.ProxyFromEnvironment,
|
|
DialContext: (&net.Dialer{
|
|
Timeout: 30 * time.Second,
|
|
KeepAlive: 60 * time.Second,
|
|
}).DialContext,
|
|
MaxIdleConns: 100,
|
|
IdleConnTimeout: 90 * time.Second,
|
|
TLSHandshakeTimeout: 10 * time.Second,
|
|
ExpectContinueTimeout: 1 * time.Second,
|
|
}
|
|
ph.transparentProxy.ErrorHandler = ph.transparentProxyErrorHandler
|
|
ph.transparentProxy.BufferPool = &bufferPool{}
|
|
return ph
|
|
}
|
|
|
|
func (h *ProxyHandler) HandleProxy(c *gin.Context) {
|
|
if c.Request.Method == "GET" && (strings.HasSuffix(c.Request.URL.Path, "/models") || strings.HasSuffix(c.Request.URL.Path, "/models/")) {
|
|
h.handleListModelsRequest(c)
|
|
return
|
|
}
|
|
requestBody, err := io.ReadAll(c.Request.Body)
|
|
if err != nil {
|
|
errToJSON(c, uuid.New().String(), errors.NewAPIError(errors.ErrBadRequest, "Failed to read request body"))
|
|
return
|
|
}
|
|
c.Request.Body = io.NopCloser(bytes.NewReader(requestBody))
|
|
c.Request.ContentLength = int64(len(requestBody))
|
|
modelName := h.channel.ExtractModel(c, requestBody)
|
|
groupName := c.Param("group_name")
|
|
isPreciseRouting := groupName != ""
|
|
if !isPreciseRouting && modelName == "" {
|
|
errToJSON(c, uuid.New().String(), errors.NewAPIError(errors.ErrBadRequest, "Model not specified in the request body or URL"))
|
|
return
|
|
}
|
|
initialResources, err := h.getResourcesForRequest(c, modelName, groupName, isPreciseRouting)
|
|
if err != nil {
|
|
if apiErr, ok := err.(*errors.APIError); ok {
|
|
errToJSON(c, uuid.New().String(), apiErr)
|
|
} else {
|
|
errToJSON(c, uuid.New().String(), errors.NewAPIError(errors.ErrNoKeysAvailable, err.Error()))
|
|
}
|
|
return
|
|
}
|
|
finalOpConfig, err := h.groupManager.BuildOperationalConfig(initialResources.KeyGroup)
|
|
if err != nil {
|
|
h.logger.WithError(err).Error("Failed to build operational config.")
|
|
errToJSON(c, uuid.New().String(), errors.NewAPIError(errors.ErrInternalServer, "Failed to build operational configuration"))
|
|
return
|
|
}
|
|
|
|
isOpenAICompatible := h.channel.IsOpenAICompatibleRequest(c)
|
|
if isOpenAICompatible {
|
|
h.serveTransparentProxy(c, requestBody, initialResources, finalOpConfig, modelName, groupName, isPreciseRouting)
|
|
return
|
|
}
|
|
isStream := h.channel.IsStreamRequest(c, requestBody)
|
|
systemSettings := h.settingsManager.GetSettings()
|
|
useSmartGateway := finalOpConfig.EnableSmartGateway != nil && *finalOpConfig.EnableSmartGateway
|
|
if useSmartGateway && isStream && systemSettings.EnableStreamingRetry {
|
|
h.serveSmartStream(c, requestBody, initialResources, isPreciseRouting)
|
|
} else {
|
|
h.serveTransparentProxy(c, requestBody, initialResources, finalOpConfig, modelName, groupName, isPreciseRouting)
|
|
}
|
|
}
|
|
|
|
// serveTransparentProxy forwards the buffered request to the selected
// upstream through h.transparentProxy, retrying with freshly selected
// resources on failure. Each attempt's response is captured in an httptest
// recorder; only the last attempt's response is written back to the client.
// A per-retry log event is published after each failed (retryable) attempt
// and a final event is always published from the deferred block.
//
// NOTE(review): h.transparentProxy is a single shared ReverseProxy whose
// Director, ModifyResponse, and Transport.Proxy fields are reassigned per
// attempt below. If two requests are in flight concurrently, one request's
// closures can be observed by the other — this looks like a data race;
// confirm the handler is serialized, or build a per-request ReverseProxy.
func (h *ProxyHandler) serveTransparentProxy(c *gin.Context, requestBody []byte, initialResources *service.RequestResources, finalOpConfig *models.KeyGroupSettings, modelName, groupName string, isPreciseRouting bool) {
	startTime := time.Now()
	correlationID := uuid.New().String()
	// State carried across attempts and read by the deferred final-event
	// publisher below.
	var finalRecorder *httptest.ResponseRecorder
	var lastUsedResources *service.RequestResources
	var finalProxyErr *errors.APIError
	var isSuccess bool
	var finalPromptTokens, finalCompletionTokens int
	var actualRetries int = 0
	// Always publish a final log event describing the overall outcome,
	// unless no resources were ever used (nothing meaningful to log).
	defer func() {
		if lastUsedResources == nil {
			h.logger.WithField("id", correlationID).Warn("No resources were used, skipping final log event.")
			return
		}
		finalEvent := h.createLogEvent(c, startTime, correlationID, modelName, lastUsedResources, models.LogTypeFinal, isPreciseRouting)

		finalEvent.RequestLog.LatencyMs = int(time.Since(startTime).Milliseconds())
		finalEvent.RequestLog.IsSuccess = isSuccess
		finalEvent.RequestLog.Retries = actualRetries
		if isSuccess {
			finalEvent.RequestLog.PromptTokens = finalPromptTokens
			finalEvent.RequestLog.CompletionTokens = finalCompletionTokens
		}

		if finalRecorder != nil {
			finalEvent.RequestLog.StatusCode = finalRecorder.Code
		}
		if !isSuccess {
			if finalProxyErr != nil {
				finalEvent.Error = finalProxyErr
				finalEvent.RequestLog.ErrorCode = finalProxyErr.Code
				finalEvent.RequestLog.ErrorMessage = finalProxyErr.Message
			} else if finalRecorder != nil {
				// No structured error was captured; synthesize one from the
				// last upstream status code.
				apiErr := errors.NewAPIErrorWithUpstream(finalRecorder.Code, fmt.Sprintf("UPSTREAM_%d", finalRecorder.Code), "Request failed after all retries.")
				finalEvent.Error = apiErr
				finalEvent.RequestLog.ErrorCode = apiErr.Code
				finalEvent.RequestLog.ErrorMessage = apiErr.Message
			}
		}
		eventData, err := json.Marshal(finalEvent)
		if err != nil {
			h.logger.WithField("id", correlationID).WithError(err).Error("Failed to marshal final log event.")
			return
		}
		if err := h.store.Publish(context.Background(), models.TopicRequestFinished, eventData); err != nil {
			h.logger.WithField("id", correlationID).WithError(err).Error("Failed to publish final log event.")
		}
	}()
	// Retry budget: a group-level override applies only under precise
	// routing; otherwise the system-wide setting is used.
	var maxRetries int
	if isPreciseRouting {
		if finalOpConfig.MaxRetries != nil {
			maxRetries = *finalOpConfig.MaxRetries
		} else {
			maxRetries = h.settingsManager.GetSettings().MaxRetries
		}
	} else {
		maxRetries = h.settingsManager.GetSettings().MaxRetries
	}
	totalAttempts := maxRetries + 1
	for attempt := 1; attempt <= totalAttempts; attempt++ {
		// Stop early if the client has gone away.
		if c.Request.Context().Err() != nil {
			h.logger.WithField("id", correlationID).Info("Client disconnected, aborting retry loop.")
			if finalProxyErr == nil {
				finalProxyErr = errors.NewAPIError(errors.ErrBadRequest, "Client connection closed")
			}
			break
		}
		var currentResources *service.RequestResources
		var err error
		if attempt == 1 {
			currentResources = initialResources
		} else {
			// Each retry selects a fresh key/upstream/proxy bundle.
			actualRetries = attempt - 1
			h.logger.WithField("id", correlationID).Infof("Retrying... getting new resources for attempt %d.", attempt)
			currentResources, err = h.getResourcesForRequest(c, modelName, groupName, isPreciseRouting)
			if err != nil {
				h.logger.WithField("id", correlationID).Errorf("Failed to get new resources for retry, aborting: %v", err)
				finalProxyErr = errors.NewAPIError(errors.ErrNoKeysAvailable, "Failed to get new resources for retry")
				break
			}
		}

		// Merge group request config over global defaults for this attempt.
		finalRequestConfig := h.buildFinalRequestConfig(h.settingsManager.GetSettings(), currentResources.RequestConfig)
		currentResources.RequestConfig = finalRequestConfig
		lastUsedResources = currentResources
		h.logger.WithField("id", correlationID).Infof("Attempt %d/%d with KeyID %d...", attempt, totalAttempts, currentResources.APIKey.ID)
		var attemptErr *errors.APIError
		var attemptIsSuccess bool
		recorder := httptest.NewRecorder()
		attemptStartTime := time.Now()
		connectTimeout := time.Duration(h.settingsManager.GetSettings().ConnectTimeoutSeconds) * time.Second
		ctx, cancel := context.WithTimeout(c.Request.Context(), connectTimeout)
		// NOTE(review): this defer is inside the loop, so cancels accumulate
		// and only run when the function returns — each attempt's timeout
		// context stays alive until then. Consider cancelling per iteration.
		defer cancel()
		attemptReq := c.Request.Clone(ctx)
		attemptReq.Body = io.NopCloser(bytes.NewReader(requestBody))
		if currentResources.UpstreamEndpoint == nil || currentResources.UpstreamEndpoint.URL == "" {
			h.logger.WithField("id", correlationID).Errorf("Attempt %d failed: no upstream URL in resources.", attempt)
			isSuccess = false
			finalProxyErr = errors.NewAPIError(errors.ErrInternalServer, "No upstream URL configured for the selected resource")
			continue
		}
		// Director rewrites the outbound request for this attempt's upstream:
		// scheme/host/path rewrite, auth header swap, correlation ID, and an
		// error slot in the context for the ErrorHandler to fill.
		h.transparentProxy.Director = func(req *http.Request) {
			targetURL, _ := url.Parse(currentResources.UpstreamEndpoint.URL)
			req.URL.Scheme = targetURL.Scheme
			req.URL.Host = targetURL.Host
			req.Host = targetURL.Host
			var pureClientPath string
			if isPreciseRouting {
				// Strip the /proxy/<group> routing prefix before the
				// channel-specific path rewrite.
				proxyPrefix := "/proxy/" + groupName
				pureClientPath = strings.TrimPrefix(req.URL.Path, proxyPrefix)
			} else {
				pureClientPath = req.URL.Path
			}
			finalPath := h.channel.RewritePath(targetURL.Path, pureClientPath)
			req.URL.Path = finalPath
			h.logger.WithFields(logrus.Fields{
				"correlation_id":    correlationID,
				"attempt":           attempt,
				"key_id":            currentResources.APIKey.ID,
				"base_upstream_url": currentResources.UpstreamEndpoint.URL,
				"final_request_url": req.URL.String(),
			}).Infof("Director constructed final upstream request URL.")
			// Drop the client's auth; the channel injects the upstream key.
			req.Header.Del("Authorization")
			h.channel.ModifyRequest(req, currentResources.APIKey)
			req.Header.Set("X-Correlation-ID", correlationID)
			*req = *req.WithContext(context.WithValue(req.Context(), proxyErrKey, &attemptErr))
		}
		// Per-attempt outbound proxy selection on the shared transport.
		transport := h.transparentProxy.Transport.(*http.Transport)
		if currentResources.ProxyConfig != nil {
			proxyURLStr := fmt.Sprintf("%s://%s", currentResources.ProxyConfig.Protocol, currentResources.ProxyConfig.Address)
			proxyURL, err := url.Parse(proxyURLStr)
			if err == nil {
				transport.Proxy = http.ProxyURL(proxyURL)
			}
		} else {
			transport.Proxy = http.ProxyFromEnvironment
		}
		// ModifyResponse buffers the whole upstream body (decompressing gzip
		// if needed) to extract token usage on success or a structured error
		// on HTTP >= 400, then restores the body for the recorder.
		h.transparentProxy.ModifyResponse = func(resp *http.Response) error {
			defer resp.Body.Close()
			var reader io.ReadCloser
			var err error
			isGzipped := resp.Header.Get("Content-Encoding") == "gzip"
			if isGzipped {
				reader, err = gzip.NewReader(resp.Body)
				if err != nil {
					// Fall back to the raw body if gzip decoding fails.
					h.logger.WithError(err).Error("Failed to create gzip reader")
					reader = resp.Body
				} else {
					// Body is re-exposed decompressed below.
					resp.Header.Del("Content-Encoding")
				}
				defer reader.Close()
			} else {
				reader = resp.Body
			}
			bodyBytes, err := io.ReadAll(reader)
			if err != nil {
				attemptErr = errors.NewAPIError(errors.ErrBadGateway, "Failed to read upstream response: "+err.Error())
				resp.Body = io.NopCloser(bytes.NewReader([]byte(attemptErr.Message)))
				return nil
			}
			if resp.StatusCode < 400 {
				attemptIsSuccess = true
				finalPromptTokens, finalCompletionTokens = extractUsage(bodyBytes)
			} else {
				parsedMsg := errors.ParseUpstreamError(bodyBytes)
				attemptErr = errors.NewAPIErrorWithUpstream(resp.StatusCode, fmt.Sprintf("UPSTREAM_%d", resp.StatusCode), parsedMsg)
			}
			resp.Body = io.NopCloser(bytes.NewReader(bodyBytes))
			return nil
		}
		h.transparentProxy.ServeHTTP(recorder, attemptReq)
		finalRecorder = recorder
		finalProxyErr = attemptErr
		isSuccess = attemptIsSuccess
		// Feed the outcome back so key/upstream health tracking updates.
		h.resourceService.ReportRequestResult(currentResources, isSuccess, finalProxyErr)
		if isSuccess {
			break
		}
		// Some request errors (e.g. client-side problems) will fail on every
		// key; do not burn retries on them.
		isUnretryableError := false
		if finalProxyErr != nil {
			if errors.IsUnretryableRequestError(finalProxyErr.Message) {
				isUnretryableError = true
				h.logger.WithField("id", correlationID).Warnf("Attempt %d failed with unretryable request error. Aborting retries. Message: %s", attempt, finalProxyErr.Message)
			}
		}
		if attempt >= totalAttempts || isUnretryableError {
			break
		}
		// Publish a per-attempt retry event before looping again.
		retryEvent := h.createLogEvent(c, startTime, correlationID, modelName, currentResources, models.LogTypeRetry, isPreciseRouting)
		retryEvent.LatencyMs = int(time.Since(attemptStartTime).Milliseconds())
		retryEvent.IsSuccess = false
		retryEvent.StatusCode = recorder.Code
		retryEvent.Retries = actualRetries
		if attemptErr != nil {
			retryEvent.Error = attemptErr
			retryEvent.ErrorCode = attemptErr.Code
			retryEvent.ErrorMessage = attemptErr.Message
		}
		eventData, _ := json.Marshal(retryEvent)
		_ = h.store.Publish(context.Background(), models.TopicRequestFinished, eventData)
	}
	// Replay the last recorded attempt to the real client, or emit a JSON
	// error if no attempt ever reached ServeHTTP.
	if finalRecorder != nil {
		bodyBytes := finalRecorder.Body.Bytes()
		// Content-Length is recomputed because the body may have been
		// decompressed in ModifyResponse.
		c.Writer.Header().Set("Content-Length", fmt.Sprintf("%d", len(bodyBytes)))
		for k, v := range finalRecorder.Header() {
			if strings.ToLower(k) != "content-length" {
				c.Writer.Header()[k] = v
			}
		}
		c.Writer.WriteHeader(finalRecorder.Code)
		c.Writer.Write(finalRecorder.Body.Bytes())
	} else {
		// NOTE(review): finalProxyErr appears to be set on every path that
		// leaves finalRecorder nil, but errToJSON would panic if it were
		// nil — verify, or guard in errToJSON.
		errToJSON(c, correlationID, finalProxyErr)
	}
}
|
|
|
|
func (h *ProxyHandler) serveSmartStream(c *gin.Context, requestBody []byte, resources *service.RequestResources, isPreciseRouting bool) {
|
|
startTime := time.Now()
|
|
correlationID := uuid.New().String()
|
|
log := h.logger.WithField("id", correlationID)
|
|
log.Info("Smart Gateway activated for streaming request.")
|
|
var originalRequest models.GeminiRequest
|
|
if err := json.Unmarshal(requestBody, &originalRequest); err != nil {
|
|
errToJSON(c, correlationID, errors.NewAPIError(errors.ErrInvalidJSON, "Smart Gateway failed: Request body is not a valid Gemini native format. Error: "+err.Error()))
|
|
return
|
|
}
|
|
systemSettings := h.settingsManager.GetSettings()
|
|
modelName := h.channel.ExtractModel(c, requestBody)
|
|
requestFinishedEvent := h.createLogEvent(c, startTime, correlationID, modelName, resources, models.LogTypeFinal, isPreciseRouting)
|
|
defer func() {
|
|
requestFinishedEvent.LatencyMs = int(time.Since(startTime).Milliseconds())
|
|
if c.Writer.Status() > 0 {
|
|
requestFinishedEvent.StatusCode = c.Writer.Status()
|
|
}
|
|
eventData, _ := json.Marshal(requestFinishedEvent)
|
|
_ = h.store.Publish(context.Background(), models.TopicRequestFinished, eventData)
|
|
}()
|
|
params := channel.SmartRequestParams{
|
|
CorrelationID: correlationID,
|
|
APIKey: resources.APIKey,
|
|
UpstreamURL: resources.UpstreamEndpoint.URL,
|
|
RequestBody: requestBody,
|
|
OriginalRequest: originalRequest,
|
|
EventLogger: requestFinishedEvent,
|
|
MaxRetries: systemSettings.MaxStreamingRetries,
|
|
RetryDelay: time.Duration(systemSettings.StreamingRetryDelayMs) * time.Millisecond,
|
|
LogTruncationLimit: systemSettings.LogTruncationLimit,
|
|
StreamingRetryPrompt: systemSettings.StreamingRetryPrompt,
|
|
}
|
|
h.channel.ProcessSmartStreamRequest(c, params)
|
|
}
|
|
|
|
func (h *ProxyHandler) transparentProxyErrorHandler(rw http.ResponseWriter, r *http.Request, err error) {
|
|
correlationID := r.Header.Get("X-Correlation-ID")
|
|
h.logger.WithField("id", correlationID).Errorf("Transparent proxy error: %v", err)
|
|
proxyErrPtr, exists := r.Context().Value(proxyErrKey).(**errors.APIError)
|
|
if !exists || proxyErrPtr == nil {
|
|
h.logger.WithField("id", correlationID).Error("FATAL: proxyErrorKey not found in context for error handler.")
|
|
return
|
|
}
|
|
if errors.IsClientNetworkError(err) {
|
|
*proxyErrPtr = errors.NewAPIError(errors.ErrBadRequest, "Client connection closed")
|
|
} else {
|
|
*proxyErrPtr = errors.NewAPIError(errors.ErrBadGateway, err.Error())
|
|
}
|
|
if _, ok := rw.(*httptest.ResponseRecorder); ok {
|
|
return
|
|
}
|
|
if writer, ok := rw.(interface{ Written() bool }); ok {
|
|
if writer.Written() {
|
|
return
|
|
}
|
|
}
|
|
rw.WriteHeader((*proxyErrPtr).HTTPStatus)
|
|
}
|
|
|
|
func (h *ProxyHandler) createLogEvent(c *gin.Context, startTime time.Time, corrID, modelName string, res *service.RequestResources, logType models.LogType, isPreciseRouting bool) *models.RequestFinishedEvent {
|
|
event := &models.RequestFinishedEvent{
|
|
RequestLog: models.RequestLog{
|
|
RequestTime: startTime,
|
|
ModelName: modelName,
|
|
RequestPath: c.Request.URL.Path,
|
|
UserAgent: c.Request.UserAgent(),
|
|
CorrelationID: corrID,
|
|
LogType: logType,
|
|
Metadata: make(datatypes.JSONMap),
|
|
},
|
|
CorrelationID: corrID,
|
|
IsPreciseRouting: isPreciseRouting,
|
|
}
|
|
if _, exists := c.Get(middleware.RedactedBodyKey); exists {
|
|
event.RequestLog.Metadata["request_body_present"] = true
|
|
}
|
|
if redactedAuth, exists := c.Get(middleware.RedactedAuthHeaderKey); exists {
|
|
event.RequestLog.Metadata["authorization_header"] = redactedAuth.(string)
|
|
}
|
|
if authTokenValue, exists := c.Get("authToken"); exists {
|
|
if authToken, ok := authTokenValue.(*models.AuthToken); ok {
|
|
event.RequestLog.AuthTokenID = &authToken.ID
|
|
}
|
|
}
|
|
if res != nil {
|
|
if res.APIKey != nil {
|
|
event.RequestLog.KeyID = &res.APIKey.ID
|
|
}
|
|
if res.KeyGroup != nil {
|
|
event.RequestLog.GroupID = &res.KeyGroup.ID
|
|
}
|
|
if res.UpstreamEndpoint != nil {
|
|
event.RequestLog.UpstreamID = &res.UpstreamEndpoint.ID
|
|
event.UpstreamURL = &res.UpstreamEndpoint.URL
|
|
}
|
|
if res.ProxyConfig != nil {
|
|
event.RequestLog.ProxyID = &res.ProxyConfig.ID
|
|
}
|
|
}
|
|
return event
|
|
}
|
|
|
|
func (h *ProxyHandler) getResourcesForRequest(c *gin.Context, modelName string, groupName string, isPreciseRouting bool) (*service.RequestResources, error) {
|
|
authTokenValue, exists := c.Get("authToken")
|
|
if !exists {
|
|
return nil, errors.NewAPIError(errors.ErrUnauthorized, "Auth token not found in context")
|
|
}
|
|
authToken, ok := authTokenValue.(*models.AuthToken)
|
|
if !ok {
|
|
return nil, errors.NewAPIError(errors.ErrInternalServer, "Invalid auth token type in context")
|
|
}
|
|
if isPreciseRouting {
|
|
return h.resourceService.GetResourceFromGroup(c.Request.Context(), authToken, groupName)
|
|
} else {
|
|
return h.resourceService.GetResourceFromBasePool(c.Request.Context(), authToken, modelName)
|
|
}
|
|
}
|
|
|
|
func errToJSON(c *gin.Context, corrID string, apiErr *errors.APIError) {
|
|
c.JSON(apiErr.HTTPStatus, gin.H{
|
|
"error": apiErr,
|
|
"correlation_id": corrID,
|
|
})
|
|
}
|
|
|
|
// bufferPool adapts a sync.Pool to httputil.BufferPool so ReverseProxy body
// copies reuse 32 KiB scratch buffers. The previous implementation allocated
// a fresh buffer on every Get and discarded buffers on Put, so no pooling
// actually happened; its Put parameter also shadowed the imported bytes
// package.
type bufferPool struct{}

// proxyBufPool stores *[]byte (not []byte) to avoid an extra allocation when
// boxing the slice header into interface{} (staticcheck SA6002).
var proxyBufPool = sync.Pool{
	New: func() interface{} {
		b := make([]byte, 32*1024)
		return &b
	},
}

// Get returns a 32 KiB scratch buffer, reusing a pooled one when available.
func (b *bufferPool) Get() []byte { return *proxyBufPool.Get().(*[]byte) }

// Put returns a buffer to the pool for reuse.
func (b *bufferPool) Put(buf []byte) { proxyBufPool.Put(&buf) }
|
|
|
|
// extractUsage pulls Gemini usage metadata (prompt and candidate token
// counts) out of a JSON response body. Bodies that are not valid JSON, or
// that carry no usageMetadata object, yield (0, 0).
func extractUsage(body []byte) (promptTokens int, completionTokens int) {
	var payload struct {
		UsageMetadata struct {
			PromptTokenCount     int `json:"promptTokenCount"`
			CandidatesTokenCount int `json:"candidatesTokenCount"`
		} `json:"usageMetadata"`
	}
	if unmarshalErr := json.Unmarshal(body, &payload); unmarshalErr != nil {
		return 0, 0
	}
	usage := payload.UsageMetadata
	return usage.PromptTokenCount, usage.CandidatesTokenCount
}
|
|
|
|
// buildFinalRequestConfig merges the group-level request config over the
// global defaults. The merge works by JSON round-tripping the group config
// onto a struct pre-filled from globalSettings, so only fields the group
// config actually serializes will override the defaults.
//
// NOTE(review): the exact override behavior depends on models.RequestConfig's
// JSON tags — fields with omitempty at their zero value presumably do NOT
// override the global default; confirm against the model definition.
func (h *ProxyHandler) buildFinalRequestConfig(globalSettings *models.SystemSettings, groupConfig *models.RequestConfig) *models.RequestConfig {
	// Convert the global custom headers into a datatypes.JSONMap via a JSON
	// round-trip; errors are deliberately ignored and simply leave
	// customHeadersMap empty.
	customHeadersJSON, _ := json.Marshal(globalSettings.CustomHeaders)
	var customHeadersMap datatypes.JSONMap
	_ = json.Unmarshal(customHeadersJSON, &customHeadersMap)
	// Seed the result with the global streaming defaults.
	finalConfig := &models.RequestConfig{
		CustomHeaders:         customHeadersMap,
		EnableStreamOptimizer: globalSettings.EnableStreamOptimizer,
		StreamMinDelay:        globalSettings.StreamMinDelay,
		StreamMaxDelay:        globalSettings.StreamMaxDelay,
		StreamShortTextThresh: globalSettings.StreamShortTextThresh,
		StreamLongTextThresh:  globalSettings.StreamLongTextThresh,
		StreamChunkSize:       globalSettings.StreamChunkSize,
		EnableFakeStream:      globalSettings.EnableFakeStream,
		FakeStreamInterval:    globalSettings.FakeStreamInterval,
	}
	if groupConfig == nil {
		return finalConfig
	}
	groupConfigJSON, err := json.Marshal(groupConfig)
	if err != nil {
		// On marshal failure, fall back to the global defaults.
		h.logger.WithError(err).Error("Failed to marshal group request config for merging.")
		return finalConfig
	}
	// Overlay: unmarshal the group config into the pre-filled struct.
	if err := json.Unmarshal(groupConfigJSON, finalConfig); err != nil {
		h.logger.WithError(err).Error("Failed to unmarshal group request config for merging.")
		return finalConfig
	}
	return finalConfig
}
|
|
|
|
func (h *ProxyHandler) handleListModelsRequest(c *gin.Context) {
|
|
authTokenValue, exists := c.Get("authToken")
|
|
if !exists {
|
|
errToJSON(c, uuid.New().String(), errors.NewAPIError(errors.ErrUnauthorized, "Auth token not found in context"))
|
|
return
|
|
}
|
|
authToken, ok := authTokenValue.(*models.AuthToken)
|
|
if !ok {
|
|
errToJSON(c, uuid.New().String(), errors.NewAPIError(errors.ErrInternalServer, "Invalid auth token type in context"))
|
|
return
|
|
}
|
|
modelNames := h.resourceService.GetAllowedModelsForToken(authToken)
|
|
if strings.Contains(c.Request.URL.Path, "/v1beta/") {
|
|
h.respondWithGeminiFormat(c, modelNames)
|
|
} else {
|
|
h.respondWithOpenAIFormat(c, modelNames)
|
|
}
|
|
}
|
|
|
|
func (h *ProxyHandler) respondWithOpenAIFormat(c *gin.Context, modelNames []string) {
|
|
type ModelEntry struct {
|
|
ID string `json:"id"`
|
|
Object string `json:"object"`
|
|
Created int64 `json:"created"`
|
|
OwnedBy string `json:"owned_by"`
|
|
}
|
|
type ModelListResponse struct {
|
|
Object string `json:"object"`
|
|
Data []ModelEntry `json:"data"`
|
|
}
|
|
data := make([]ModelEntry, len(modelNames))
|
|
for i, name := range modelNames {
|
|
data[i] = ModelEntry{
|
|
ID: name,
|
|
Object: "model",
|
|
Created: time.Now().Unix(),
|
|
OwnedBy: "gemini-balancer",
|
|
}
|
|
}
|
|
response := ModelListResponse{
|
|
Object: "list",
|
|
Data: data,
|
|
}
|
|
c.JSON(http.StatusOK, response)
|
|
}
|
|
|
|
func (h *ProxyHandler) respondWithGeminiFormat(c *gin.Context, modelNames []string) {
|
|
type GeminiModelEntry struct {
|
|
Name string `json:"name"`
|
|
Version string `json:"version"`
|
|
DisplayName string `json:"displayName"`
|
|
Description string `json:"description"`
|
|
SupportedGenerationMethods []string `json:"supportedGenerationMethods"`
|
|
InputTokenLimit int `json:"inputTokenLimit"`
|
|
OutputTokenLimit int `json:"outputTokenLimit"`
|
|
}
|
|
type GeminiModelListResponse struct {
|
|
Models []GeminiModelEntry `json:"models"`
|
|
}
|
|
models := make([]GeminiModelEntry, len(modelNames))
|
|
for i, name := range modelNames {
|
|
models[i] = GeminiModelEntry{
|
|
Name: fmt.Sprintf("models/%s", name),
|
|
Version: "1.0.0",
|
|
DisplayName: name,
|
|
Description: "Served by Gemini Balancer",
|
|
SupportedGenerationMethods: []string{"generateContent", "streamGenerateContent"},
|
|
InputTokenLimit: 8192,
|
|
OutputTokenLimit: 2048,
|
|
}
|
|
}
|
|
response := GeminiModelListResponse{Models: models}
|
|
c.JSON(http.StatusOK, response)
|
|
}
|