fix:path rewriting & model list
This commit is contained in:
@@ -62,10 +62,13 @@ func (ch *GeminiChannel) extractModelFromRequest(c *gin.Context, bodyBytes []byt
|
||||
var p struct {
|
||||
Model string `json:"model"`
|
||||
}
|
||||
if json.Unmarshal(bodyBytes, &p) == nil && p.Model != "" {
|
||||
return strings.TrimPrefix(p.Model, "models/")
|
||||
_ = json.Unmarshal(bodyBytes, &p)
|
||||
modelName := strings.TrimPrefix(p.Model, "models/")
|
||||
|
||||
if modelName == "" {
|
||||
modelName = ch.extractModelFromPath(c.Request.URL.Path)
|
||||
}
|
||||
return ch.extractModelFromPath(c.Request.URL.Path)
|
||||
return modelName
|
||||
}
|
||||
|
||||
func (ch *GeminiChannel) extractModelFromPath(path string) string {
|
||||
@@ -85,7 +88,11 @@ func (ch *GeminiChannel) IsOpenAICompatibleRequest(c *gin.Context) bool {
|
||||
}
|
||||
|
||||
func (ch *GeminiChannel) isOpenAIPath(path string) bool {
|
||||
return strings.Contains(path, "/v1/chat/completions") || strings.Contains(path, "/v1/embeddings")
|
||||
return strings.Contains(path, "/v1/chat/completions") ||
|
||||
strings.Contains(path, "/v1/completions") ||
|
||||
strings.Contains(path, "/v1/embeddings") ||
|
||||
strings.Contains(path, "/v1/models") ||
|
||||
strings.Contains(path, "/v1/audio/")
|
||||
}
|
||||
|
||||
func (ch *GeminiChannel) ValidateKey(
|
||||
|
||||
@@ -986,6 +986,7 @@ func (h *ProxyHandler) getMaxRetries(isPreciseRouting bool, finalOpConfig *model
|
||||
}
|
||||
|
||||
func (h *ProxyHandler) handleListModelsRequest(c *gin.Context) {
|
||||
|
||||
authTokenValue, exists := c.Get("authToken")
|
||||
if !exists {
|
||||
errToJSON(c, uuid.New().String(), errors.NewAPIError(errors.ErrUnauthorized, "Auth token not found in context"))
|
||||
@@ -996,7 +997,61 @@ func (h *ProxyHandler) handleListModelsRequest(c *gin.Context) {
|
||||
errToJSON(c, uuid.New().String(), errors.NewAPIError(errors.ErrInternalServer, "Invalid auth token type in context"))
|
||||
return
|
||||
}
|
||||
modelNames := h.resourceService.GetAllowedModelsForToken(authToken)
|
||||
|
||||
groupName := c.Param("group_name")
|
||||
h.logger.Infof("List models request: path=%s, groupName=%s", c.Request.URL.Path, groupName)
|
||||
isPreciseRouting := groupName != ""
|
||||
|
||||
var modelNames []string
|
||||
|
||||
if isPreciseRouting {
|
||||
group, ok := h.groupManager.GetGroupByName(groupName)
|
||||
if !ok {
|
||||
errToJSON(c, uuid.New().String(), errors.NewAPIError(errors.ErrNotFound, "Group not found"))
|
||||
return
|
||||
}
|
||||
|
||||
for _, modelMapping := range group.AllowedModels {
|
||||
modelNames = append(modelNames, modelMapping.ModelName)
|
||||
}
|
||||
|
||||
if len(modelNames) == 0 {
|
||||
h.logger.Infof("Triggering passthrough for model list")
|
||||
initialResources, err := h.resourceService.GetResourceFromGroup(c.Request.Context(), authToken, groupName)
|
||||
if err != nil {
|
||||
errToJSON(c, uuid.New().String(), errors.NewAPIError(errors.ErrInternalServer, "Failed to get resources"))
|
||||
return
|
||||
}
|
||||
|
||||
targetURL, _ := url.Parse(initialResources.UpstreamEndpoint.URL)
|
||||
apiPath := strings.TrimPrefix(c.Request.URL.Path, "/proxy/"+groupName)
|
||||
targetURL.Path = h.channel.RewritePath(targetURL.Path, apiPath)
|
||||
h.logger.Infof("Final upstream path: %s", targetURL.String())
|
||||
targetURL.RawQuery = c.Request.URL.RawQuery
|
||||
|
||||
req, _ := http.NewRequestWithContext(c.Request.Context(), "GET", targetURL.String(), nil)
|
||||
h.channel.ModifyRequest(req, initialResources.APIKey)
|
||||
|
||||
client := &http.Client{Transport: h.transparentProxy.Transport}
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
errToJSON(c, uuid.New().String(), errors.NewAPIError(errors.ErrBadGateway, "Failed to fetch models"))
|
||||
return
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
c.Writer.WriteHeader(resp.StatusCode)
|
||||
for k, v := range resp.Header {
|
||||
c.Writer.Header()[k] = v
|
||||
}
|
||||
io.Copy(c.Writer, resp.Body)
|
||||
h.logger.Infof("Passthrough response sent")
|
||||
return
|
||||
}
|
||||
} else {
|
||||
modelNames = h.resourceService.GetAllowedModelsForToken(authToken)
|
||||
}
|
||||
|
||||
if strings.Contains(c.Request.URL.Path, "/v1beta/") {
|
||||
h.respondWithGeminiFormat(c, modelNames)
|
||||
} else {
|
||||
|
||||
Reference in New Issue
Block a user