// 435 lines, 14 KiB, Go
// Filename: internal/channel/gemini_channel.go
|
|
package channel
|
|
|
|
import (
|
|
"bufio"
|
|
"bytes"
|
|
"context"
|
|
"encoding/json"
|
|
"fmt"
|
|
CustomErrors "gemini-balancer/internal/errors"
|
|
"gemini-balancer/internal/models"
|
|
"io"
|
|
"net/http"
|
|
"net/url"
|
|
"strings"
|
|
"time"
|
|
|
|
"github.com/gin-gonic/gin"
|
|
"github.com/sirupsen/logrus"
|
|
)
|
|
|
|
// SmartRetryPrompt is the synthetic user message injected when resuming an
// interrupted stream, asking the model to continue its previous answer.
const SmartRetryPrompt = "Continue exactly where you left off..."

// Compile-time check that *GeminiChannel implements ChannelProxy.
var _ ChannelProxy = (*GeminiChannel)(nil)
|
|
|
|
// GeminiChannel proxies requests to the Gemini upstream API. It implements
// ChannelProxy and owns a shared, connection-pooled HTTP client for all
// upstream calls.
type GeminiChannel struct {
	logger     *logrus.Logger // structured logger shared by all methods
	httpClient *http.Client   // pooled transport; no global timeout (used for streaming)
}
|
|
|
|
// requestMetadata is a local structure for safely extracting the few request
// fields this channel cares about from an arbitrary JSON body.
type requestMetadata struct {
	Model  string `json:"model"`  // model identifier, possibly "models/"-prefixed
	Stream bool   `json:"stream"` // OpenAI-style streaming flag
}
|
|
|
|
func NewGeminiChannel(logger *logrus.Logger, cfg *models.SystemSettings) *GeminiChannel {
|
|
transport := &http.Transport{
|
|
Proxy: http.ProxyFromEnvironment,
|
|
MaxIdleConns: cfg.TransportMaxIdleConns,
|
|
MaxIdleConnsPerHost: cfg.TransportMaxIdleConnsPerHost,
|
|
IdleConnTimeout: time.Duration(cfg.TransportIdleConnTimeoutSecs) * time.Second,
|
|
TLSHandshakeTimeout: time.Duration(cfg.TransportTLSHandshakeTimeout) * time.Second,
|
|
ExpectContinueTimeout: 1 * time.Second,
|
|
}
|
|
return &GeminiChannel{
|
|
logger: logger,
|
|
httpClient: &http.Client{
|
|
Transport: transport,
|
|
Timeout: 0,
|
|
},
|
|
}
|
|
}
|
|
|
|
// TransformRequest
|
|
func (ch *GeminiChannel) TransformRequest(c *gin.Context, requestBody []byte) (newBody []byte, modelName string, err error) {
|
|
var p struct {
|
|
Model string `json:"model"`
|
|
}
|
|
_ = json.Unmarshal(requestBody, &p)
|
|
modelName = strings.TrimPrefix(p.Model, "models/")
|
|
|
|
if modelName == "" {
|
|
modelName = ch.extractModelFromPath(c.Request.URL.Path)
|
|
}
|
|
return requestBody, modelName, nil
|
|
}
|
|
|
|
func (ch *GeminiChannel) extractModelFromPath(path string) string {
|
|
parts := strings.Split(path, "/")
|
|
for _, part := range parts {
|
|
if strings.HasPrefix(part, "gemini-") || strings.HasPrefix(part, "text-") || strings.HasPrefix(part, "embedding-") {
|
|
modelPart := strings.Split(part, ":")[0]
|
|
return modelPart
|
|
}
|
|
}
|
|
return ""
|
|
}
|
|
|
|
func (ch *GeminiChannel) IsOpenAICompatibleRequest(c *gin.Context) bool {
|
|
path := c.Request.URL.Path
|
|
return strings.Contains(path, "/v1/chat/completions") || strings.Contains(path, "/v1/embeddings")
|
|
}
|
|
|
|
func (ch *GeminiChannel) ValidateKey(
|
|
ctx context.Context,
|
|
apiKey *models.APIKey,
|
|
targetURL string,
|
|
timeout time.Duration,
|
|
) *CustomErrors.APIError {
|
|
client := &http.Client{
|
|
Timeout: timeout,
|
|
}
|
|
req, err := http.NewRequestWithContext(ctx, "GET", targetURL, nil)
|
|
if err != nil {
|
|
return CustomErrors.NewAPIError(CustomErrors.ErrInternalServer, "failed to create validation request: "+err.Error())
|
|
}
|
|
|
|
ch.ModifyRequest(req, apiKey)
|
|
resp, err := client.Do(req)
|
|
if err != nil {
|
|
return CustomErrors.NewAPIError(CustomErrors.ErrBadGateway, "failed to send validation request: "+err.Error())
|
|
}
|
|
defer resp.Body.Close()
|
|
if resp.StatusCode >= 200 && resp.StatusCode < 300 {
|
|
return nil
|
|
}
|
|
errorBody, _ := io.ReadAll(resp.Body)
|
|
parsedMessage := CustomErrors.ParseUpstreamError(errorBody)
|
|
return &CustomErrors.APIError{
|
|
HTTPStatus: resp.StatusCode,
|
|
Code: fmt.Sprintf("UPSTREAM_%d", resp.StatusCode),
|
|
Message: parsedMessage,
|
|
}
|
|
}
|
|
|
|
func (ch *GeminiChannel) ModifyRequest(req *http.Request, apiKey *models.APIKey) {
|
|
// TODO: [Future Refactoring] Decouple auth logic from URL path.
|
|
// The authentication method (e.g., Bearer token vs. API key in query) should ideally be a property
|
|
// of the UpstreamEndpoint or a new "AuthProfile" entity, rather than being hardcoded based on URL patterns.
|
|
// This would make the channel more generic and adaptable to new upstream provider types.
|
|
if strings.Contains(req.URL.Path, "/v1beta/openai/") {
|
|
req.Header.Set("Authorization", "Bearer "+apiKey.APIKey)
|
|
} else {
|
|
req.Header.Del("Authorization")
|
|
q := req.URL.Query()
|
|
q.Set("key", apiKey.APIKey)
|
|
req.URL.RawQuery = q.Encode()
|
|
}
|
|
}
|
|
|
|
func (ch *GeminiChannel) IsStreamRequest(c *gin.Context, bodyBytes []byte) bool {
|
|
if strings.HasSuffix(c.Request.URL.Path, ":streamGenerateContent") {
|
|
return true
|
|
}
|
|
var meta requestMetadata
|
|
if err := json.Unmarshal(bodyBytes, &meta); err == nil {
|
|
return meta.Stream
|
|
}
|
|
return false
|
|
}
|
|
|
|
func (ch *GeminiChannel) ExtractModel(c *gin.Context, bodyBytes []byte) string {
|
|
_, modelName, _ := ch.TransformRequest(c, bodyBytes)
|
|
return modelName
|
|
}
|
|
|
|
func (ch *GeminiChannel) RewritePath(basePath, originalPath string) string {
|
|
tempCtx := &gin.Context{Request: &http.Request{URL: &url.URL{Path: originalPath}}}
|
|
var rewrittenSegment string
|
|
if ch.IsOpenAICompatibleRequest(tempCtx) {
|
|
var apiEndpoint string
|
|
v1Index := strings.LastIndex(originalPath, "/v1/")
|
|
if v1Index != -1 {
|
|
apiEndpoint = originalPath[v1Index+len("/v1/"):]
|
|
} else {
|
|
apiEndpoint = strings.TrimPrefix(originalPath, "/")
|
|
}
|
|
rewrittenSegment = "v1beta/openai/" + apiEndpoint
|
|
} else {
|
|
tempPath := originalPath
|
|
if strings.HasPrefix(tempPath, "/v1/") {
|
|
tempPath = "/v1beta/" + strings.TrimPrefix(tempPath, "/v1/")
|
|
}
|
|
rewrittenSegment = strings.TrimPrefix(tempPath, "/")
|
|
}
|
|
trimmedBasePath := strings.TrimSuffix(basePath, "/")
|
|
pathToJoin := rewrittenSegment
|
|
|
|
versionPrefixes := []string{"v1beta", "v1"}
|
|
for _, prefix := range versionPrefixes {
|
|
if strings.HasSuffix(trimmedBasePath, "/"+prefix) && strings.HasPrefix(pathToJoin, prefix+"/") {
|
|
pathToJoin = strings.TrimPrefix(pathToJoin, prefix+"/")
|
|
break
|
|
}
|
|
}
|
|
finalPath, err := url.JoinPath(trimmedBasePath, pathToJoin)
|
|
if err != nil {
|
|
return trimmedBasePath + "/" + strings.TrimPrefix(pathToJoin, "/")
|
|
}
|
|
return finalPath
|
|
}
|
|
|
|
// ModifyResponse implements ChannelProxy. It is a stub: no response rewriting
// is currently needed for this channel.
func (ch *GeminiChannel) ModifyResponse(resp *http.Response) error {
	// Stub implementation; no logic needed for now.
	return nil
}
|
|
|
|
// HandleError implements ChannelProxy. It is a stub: no error-handling logic
// is needed here yet.
func (ch *GeminiChannel) HandleError(c *gin.Context, err error) {
	// Stub implementation; no logic needed for now.
}
|
|
|
|
// ==========================================================
// ========= Core engine of the "smart routing" =============
// ==========================================================
|
|
func (ch *GeminiChannel) ProcessSmartStreamRequest(c *gin.Context, params SmartRequestParams) {
|
|
log := ch.logger.WithField("correlation_id", params.CorrelationID)
|
|
|
|
targetURL, err := url.Parse(params.UpstreamURL)
|
|
if err != nil {
|
|
log.WithError(err).Error("Failed to parse upstream URL")
|
|
errToJSON(c, CustomErrors.NewAPIError(CustomErrors.ErrInternalServer, "Invalid upstream URL format"))
|
|
return
|
|
}
|
|
targetURL.Path = c.Request.URL.Path
|
|
targetURL.RawQuery = c.Request.URL.RawQuery
|
|
initialReq, err := http.NewRequestWithContext(c.Request.Context(), "POST", targetURL.String(), bytes.NewReader(params.RequestBody))
|
|
if err != nil {
|
|
log.WithError(err).Error("Failed to create initial smart request")
|
|
errToJSON(c, CustomErrors.NewAPIError(CustomErrors.ErrInternalServer, err.Error()))
|
|
return
|
|
}
|
|
ch.ModifyRequest(initialReq, params.APIKey)
|
|
initialReq.Header.Del("Authorization")
|
|
resp, err := ch.httpClient.Do(initialReq)
|
|
if err != nil {
|
|
log.WithError(err).Error("Initial smart request failed")
|
|
errToJSON(c, CustomErrors.NewAPIError(CustomErrors.ErrBadGateway, err.Error()))
|
|
return
|
|
}
|
|
if resp.StatusCode != http.StatusOK {
|
|
log.Warnf("Initial request received non-200 status: %d", resp.StatusCode)
|
|
standardizedResp := ch.standardizeError(resp, params.LogTruncationLimit, log)
|
|
c.Writer.WriteHeader(standardizedResp.StatusCode)
|
|
for key, values := range standardizedResp.Header {
|
|
for _, value := range values {
|
|
c.Writer.Header().Add(key, value)
|
|
}
|
|
}
|
|
io.Copy(c.Writer, standardizedResp.Body)
|
|
params.EventLogger.IsSuccess = false
|
|
params.EventLogger.StatusCode = resp.StatusCode
|
|
return
|
|
}
|
|
ch.processStreamAndRetry(c, initialReq.Header, resp.Body, params, log)
|
|
}
|
|
|
|
// processStreamAndRetry relays an upstream SSE stream to the client line by
// line and, when the stream is cut short (abnormal finish reason, scanner
// error, or a silent connection drop), transparently retries the upstream
// request — replaying the text accumulated so far plus SmartRetryPrompt — up
// to params.MaxRetries times. Because lines are forwarded as they arrive, a
// retry appends to (rather than replaces) what the client already received.
// Returns when the model finishes with "STOP" or the retry budget is spent.
func (ch *GeminiChannel) processStreamAndRetry(
	c *gin.Context, initialRequestHeaders http.Header, initialReader io.ReadCloser, params SmartRequestParams, log *logrus.Entry,
) {
	defer initialReader.Close()
	c.Writer.Header().Set("Content-Type", "text/event-stream; charset=utf-8")
	c.Writer.Header().Set("Cache-Control", "no-cache")
	c.Writer.Header().Set("Connection", "keep-alive")
	// NOTE(review): the type assertion is unchecked; a writer that is not an
	// http.Flusher would leave flusher nil and panic on Flush below —
	// presumably gin's ResponseWriter always implements Flusher, but confirm
	// for any custom writer wrappers.
	flusher, _ := c.Writer.(http.Flusher)
	// Running concatenation of all model text seen so far; replayed on retry.
	var accumulatedText strings.Builder
	consecutiveRetryCount := 0
	currentReader := initialReader
	maxRetries := params.MaxRetries
	retryDelay := params.RetryDelay
	log.Infof("Starting smart stream session. Max retries: %d", maxRetries)
	for {
		var interruptionReason string
		scanner := bufio.NewScanner(currentReader)
		for scanner.Scan() {
			line := scanner.Text()
			if line == "" {
				continue
			}
			// Forward every non-empty line immediately, then inspect it.
			fmt.Fprintf(c.Writer, "%s\n\n", line)
			flusher.Flush()
			if !strings.HasPrefix(line, "data: ") {
				continue
			}
			data := strings.TrimPrefix(line, "data: ")
			var payload models.GeminiSSEPayload
			if err := json.Unmarshal([]byte(data), &payload); err != nil {
				// Non-JSON data lines are relayed but not interpreted.
				continue
			}
			if len(payload.Candidates) > 0 {
				candidate := payload.Candidates[0]
				if candidate.Content != nil && len(candidate.Content.Parts) > 0 {
					accumulatedText.WriteString(candidate.Content.Parts[0].Text)
				}
				if candidate.FinishReason == "STOP" {
					// Clean completion: the only successful exit of this loop.
					log.Info("Stream finished successfully with STOP reason.")
					params.EventLogger.IsSuccess = true
					return
				}
				if candidate.FinishReason != "" {
					// Any other finish reason (e.g. safety/length) triggers a retry.
					log.Warnf("Stream interrupted with abnormal finish reason: %s", candidate.FinishReason)
					interruptionReason = candidate.FinishReason
					break
				}
			}
		}
		currentReader.Close()
		// The scan loop ended without an explicit finish reason: classify why.
		if interruptionReason == "" {
			if err := scanner.Err(); err != nil {
				log.WithError(err).Warn("Stream scanner encountered an error.")
				interruptionReason = "SCANNER_ERROR"
			} else {
				log.Warn("Stream dropped unexpectedly without a finish reason.")
				interruptionReason = "CONNECTION_DROP"
			}
		}
		if consecutiveRetryCount >= maxRetries {
			// Budget exhausted: emit a terminal SSE error event and stop.
			log.Errorf("Retry limit exceeded. Last interruption: %s. Sending final error.", interruptionReason)
			errData, _ := json.Marshal(map[string]interface{}{"error": map[string]interface{}{"code": http.StatusGatewayTimeout, "status": "DEADLINE_EXCEEDED", "message": fmt.Sprintf("Proxy retry limit exceeded. Last interruption: %s.", interruptionReason)}})
			fmt.Fprintf(c.Writer, "event: error\ndata: %s\n\n", string(errData))
			flusher.Flush()
			return
		}
		consecutiveRetryCount++
		params.EventLogger.Retries = consecutiveRetryCount
		log.Infof("Stream interrupted. Attempting retry %d/%d after %v.", consecutiveRetryCount, maxRetries, retryDelay)
		time.Sleep(retryDelay)
		// Rebuild the request with the accumulated text and a "continue" prompt.
		retryBody, _ := buildRetryRequestBody(params.OriginalRequest, accumulatedText.String())
		retryBodyBytes, _ := json.Marshal(retryBody)

		retryReq, _ := http.NewRequestWithContext(c.Request.Context(), "POST", params.UpstreamURL, bytes.NewReader(retryBodyBytes))
		retryReq.Header = initialRequestHeaders
		ch.ModifyRequest(retryReq, params.APIKey)
		retryReq.Header.Del("Authorization")

		retryResp, err := ch.httpClient.Do(retryReq)
		if err != nil || retryResp.StatusCode != http.StatusOK || retryResp.Body == nil {
			if err != nil {
				log.WithError(err).Errorf("Retry request failed.")
			} else {
				log.Errorf("Retry request received non-200 status: %d", retryResp.StatusCode)
				if retryResp.Body != nil {
					retryResp.Body.Close()
				}
			}
			// A failed retry still consumes a retry slot; loop and try again.
			continue
		}
		currentReader = retryResp.Body
	}
}
|
|
|
|
func buildRetryRequestBody(originalBody models.GeminiRequest, accumulatedText string) (models.GeminiRequest, error) {
|
|
retryBody := originalBody
|
|
lastUserIndex := -1
|
|
for i := len(retryBody.Contents) - 1; i >= 0; i-- {
|
|
if retryBody.Contents[i].Role == "user" {
|
|
lastUserIndex = i
|
|
break
|
|
}
|
|
}
|
|
history := []models.GeminiContent{
|
|
{Role: "model", Parts: []models.Part{{Text: accumulatedText}}},
|
|
{Role: "user", Parts: []models.Part{{Text: SmartRetryPrompt}}},
|
|
}
|
|
if lastUserIndex != -1 {
|
|
newContents := make([]models.GeminiContent, 0, len(retryBody.Contents)+2)
|
|
newContents = append(newContents, retryBody.Contents[:lastUserIndex+1]...)
|
|
newContents = append(newContents, history...)
|
|
newContents = append(newContents, retryBody.Contents[lastUserIndex+1:]...)
|
|
retryBody.Contents = newContents
|
|
} else {
|
|
retryBody.Contents = append(retryBody.Contents, history...)
|
|
}
|
|
return retryBody, nil
|
|
}
|
|
|
|
// ===================================================
// ==== Helper functions (inherited and hardened) ====
// ===================================================
|
|
|
|
// googleAPIError mirrors the standard Google API error envelope
// ({"error": {code, message, status, details}}) so upstream error bodies can
// be parsed and re-emitted in a consistent shape.
type googleAPIError struct {
	Error struct {
		Code    int           `json:"code"`    // HTTP-like numeric code
		Message string        `json:"message"` // human-readable description
		Status  string        `json:"status"`  // google.rpc status name, e.g. "NOT_FOUND"
		Details []interface{} `json:"details,omitempty"`
	} `json:"error"`
}
|
|
|
|
// statusToGoogleStatus maps an HTTP status code to the matching google.rpc
// status name; unrecognized codes map to "UNKNOWN".
func statusToGoogleStatus(code int) string {
	switch code {
	case http.StatusBadRequest:
		return "INVALID_ARGUMENT"
	case http.StatusUnauthorized:
		return "UNAUTHENTICATED"
	case http.StatusForbidden:
		return "PERMISSION_DENIED"
	case http.StatusNotFound:
		return "NOT_FOUND"
	case http.StatusTooManyRequests:
		return "RESOURCE_EXHAUSTED"
	case http.StatusInternalServerError:
		return "INTERNAL"
	case http.StatusServiceUnavailable:
		return "UNAVAILABLE"
	case http.StatusGatewayTimeout:
		return "DEADLINE_EXCEEDED"
	default:
		return "UNKNOWN"
	}
}
|
|
|
|
// truncate shortens s to its first n bytes, appending a marker with the
// number of bytes dropped. n <= 0 disables truncation entirely.
// NOTE: the cut is byte-based, so a multi-byte UTF-8 rune at the boundary may
// be split — acceptable here since the result is only used for logging.
func truncate(s string, n int) string {
	if n <= 0 || len(s) <= n {
		return s
	}
	return fmt.Sprintf("%s... [truncated %d chars]", s[:n], len(s)-n)
}
|
|
|
|
// standardizeError
|
|
func (ch *GeminiChannel) standardizeError(resp *http.Response, truncateLimit int, log *logrus.Entry) *http.Response {
|
|
bodyBytes, err := io.ReadAll(resp.Body)
|
|
if err != nil {
|
|
log.WithError(err).Error("Failed to read upstream error body")
|
|
bodyBytes = []byte("Failed to read upstream error body: " + err.Error())
|
|
}
|
|
resp.Body.Close()
|
|
log.Errorf("Upstream error body: %s", truncate(string(bodyBytes), truncateLimit))
|
|
var standardizedPayload googleAPIError
|
|
if json.Unmarshal(bodyBytes, &standardizedPayload) != nil || standardizedPayload.Error.Code == 0 {
|
|
standardizedPayload.Error.Code = resp.StatusCode
|
|
standardizedPayload.Error.Message = http.StatusText(resp.StatusCode)
|
|
standardizedPayload.Error.Status = statusToGoogleStatus(resp.StatusCode)
|
|
standardizedPayload.Error.Details = []interface{}{map[string]string{
|
|
"@type": "proxy.upstream.error",
|
|
"body": truncate(string(bodyBytes), truncateLimit),
|
|
}}
|
|
}
|
|
newBodyBytes, _ := json.Marshal(standardizedPayload)
|
|
newResp := &http.Response{
|
|
StatusCode: resp.StatusCode,
|
|
Status: resp.Status,
|
|
Header: http.Header{},
|
|
Body: io.NopCloser(bytes.NewReader(newBodyBytes)),
|
|
}
|
|
newResp.Header.Set("Content-Type", "application/json; charset=utf-8")
|
|
newResp.Header.Set("Access-Control-Allow-Origin", "*")
|
|
return newResp
|
|
}
|
|
|
|
// errToJSON
|
|
func errToJSON(c *gin.Context, apiErr *CustomErrors.APIError) {
|
|
c.JSON(apiErr.HTTPStatus, gin.H{"error": apiErr})
|
|
}
|