// Filename: internal/channel/gemini_channel.go
package channel

import (
	"bufio"
	"bytes"
	"context"
	"encoding/json"
	"fmt"
	CustomErrors "gemini-balancer/internal/errors"
	"gemini-balancer/internal/models"
	"io"
	"net/http"
	"net/url"
	"strings"
	"time"

	"github.com/gin-gonic/gin"
	"github.com/sirupsen/logrus"
)

// SmartRetryPrompt is injected as a user turn when a broken stream is resumed.
const SmartRetryPrompt = "Continue exactly where you left off..."

var _ ChannelProxy = (*GeminiChannel)(nil)

// GeminiChannel proxies requests to the Gemini API and implements ChannelProxy.
type GeminiChannel struct {
	logger     *logrus.Logger
	httpClient *http.Client
}

// NewGeminiChannel builds a GeminiChannel with an HTTP transport tuned from the system settings.
func NewGeminiChannel(logger *logrus.Logger, cfg *models.SystemSettings) *GeminiChannel {
	transport := &http.Transport{
		Proxy:                 http.ProxyFromEnvironment,
		MaxIdleConns:          cfg.TransportMaxIdleConns,
		MaxIdleConnsPerHost:   cfg.TransportMaxIdleConnsPerHost,
		IdleConnTimeout:       time.Duration(cfg.TransportIdleConnTimeoutSecs) * time.Second,
		TLSHandshakeTimeout:   time.Duration(cfg.TransportTLSHandshakeTimeout) * time.Second,
		ExpectContinueTimeout: 1 * time.Second,
	}
	return &GeminiChannel{
		logger: logger,
		httpClient: &http.Client{
			Transport: transport,
			Timeout:   0, // Timeout is handled by the request context
		},
	}
}
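
// Illustrative construction (not a call site from this file; the settings value
// is hypothetical):
//
//	logger := logrus.New()
//	ch := NewGeminiChannel(logger, settings) // settings is a *models.SystemSettings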

// TransformRequest passes the request body through unchanged and returns the
// model name resolved from the request.
func (ch *GeminiChannel) TransformRequest(c *gin.Context, requestBody []byte) (newBody []byte, modelName string, err error) {
	modelName = ch.ExtractModel(c, requestBody)
	return requestBody, modelName, nil
}

// ExtractModel returns the model name referenced by the incoming request.
func (ch *GeminiChannel) ExtractModel(c *gin.Context, bodyBytes []byte) string {
	return ch.extractModelFromRequest(c, bodyBytes)
}

// Unified model extraction: prefer the request body, and fall back to the URL path if that fails.
func (ch *GeminiChannel) extractModelFromRequest(c *gin.Context, bodyBytes []byte) string {
	var p struct {
		Model string `json:"model"`
	}
	if json.Unmarshal(bodyBytes, &p) == nil && p.Model != "" {
		return strings.TrimPrefix(p.Model, "models/")
	}
	return ch.extractModelFromPath(c.Request.URL.Path)
}

func (ch *GeminiChannel) extractModelFromPath(path string) string {
	parts := strings.Split(path, "/")
	for _, part := range parts {
		// Cover the common model-name prefixes.
		if strings.HasPrefix(part, "gemini-") || strings.HasPrefix(part, "text-") || strings.HasPrefix(part, "embedding-") {
			return strings.Split(part, ":")[0]
		}
	}
	return ""
}
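
// Illustrative mappings for the extraction above (derived from the logic in
// extractModelFromRequest/extractModelFromPath; the example paths are hypothetical):
//
//	body {"model": "models/gemini-1.5-pro"}                      -> "gemini-1.5-pro"
//	path /v1beta/models/gemini-1.5-flash:streamGenerateContent   -> "gemini-1.5-flash"
//	path /v1beta/models/embedding-001:embedContent               -> "embedding-001"
//	anything else                                                -> "" (model unknown)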

// IsOpenAICompatibleRequest decides via pure string matching on the path, without relying on gin.Context state.
func (ch *GeminiChannel) IsOpenAICompatibleRequest(c *gin.Context) bool {
	return ch.isOpenAIPath(c.Request.URL.Path)
}

func (ch *GeminiChannel) isOpenAIPath(path string) bool {
	return strings.Contains(path, "/v1/chat/completions") || strings.Contains(path, "/v1/embeddings")
}

func (ch *GeminiChannel) ValidateKey(
	ctx context.Context,
	apiKey *models.APIKey,
	targetURL string,
	timeout time.Duration,
) *CustomErrors.APIError {
	client := &http.Client{Timeout: timeout}

	req, err := http.NewRequestWithContext(ctx, "GET", targetURL, nil)
	if err != nil {
		return CustomErrors.NewAPIError(CustomErrors.ErrInternalServer, "failed to create validation request")
	}

	ch.ModifyRequest(req, apiKey)

	resp, err := client.Do(req)
	if err != nil {
		return CustomErrors.NewAPIError(CustomErrors.ErrBadGateway, "validation request failed: "+err.Error())
	}
	defer resp.Body.Close()

	if resp.StatusCode >= 200 && resp.StatusCode < 300 {
		return nil
	}

	errorBody, _ := io.ReadAll(resp.Body)
	parsedMessage := CustomErrors.ParseUpstreamError(errorBody)

	return &CustomErrors.APIError{
		HTTPStatus: resp.StatusCode,
		Code:       fmt.Sprintf("UPSTREAM_%d", resp.StatusCode),
		Message:    parsedMessage,
	}
}

// ModifyRequest attaches the API key: OpenAI-compatible paths authenticate with
// a Bearer token, native Gemini paths with the "key" query parameter.
func (ch *GeminiChannel) ModifyRequest(req *http.Request, apiKey *models.APIKey) {
	if strings.Contains(req.URL.Path, "/v1beta/openai/") {
		req.Header.Set("Authorization", "Bearer "+apiKey.APIKey)
	} else {
		req.Header.Del("Authorization")
		q := req.URL.Query()
		q.Set("key", apiKey.APIKey)
		req.URL.RawQuery = q.Encode()
	}
}
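
// Illustrative effect of ModifyRequest (example URLs are hypothetical):
//
//	POST https://upstream/v1beta/openai/chat/completions
//	    -> gains the header "Authorization: Bearer <apiKey>"
//	POST https://upstream/v1beta/models/gemini-1.5-pro:generateContent
//	    -> URL becomes ...:generateContent?key=<apiKey>; any Authorization header is dropped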

// IsStreamRequest reports whether the request targets a streaming endpoint,
// either via the :streamGenerateContent suffix or a {"stream": true} body field.
func (ch *GeminiChannel) IsStreamRequest(c *gin.Context, bodyBytes []byte) bool {
	if strings.HasSuffix(c.Request.URL.Path, ":streamGenerateContent") {
		return true
	}
	var meta struct {
		Stream bool `json:"stream"`
	}
	if json.Unmarshal(bodyBytes, &meta) == nil {
		return meta.Stream
	}
	return false
}

// RewritePath uses url.JoinPath so that path segments are joined correctly.
func (ch *GeminiChannel) RewritePath(basePath, originalPath string) string {
	var rewrittenSegment string

	if ch.isOpenAIPath(originalPath) {
		v1Index := strings.LastIndex(originalPath, "/v1/")
		var apiEndpoint string
		if v1Index != -1 {
			apiEndpoint = originalPath[v1Index+len("/v1/"):]
		} else {
			apiEndpoint = strings.TrimPrefix(originalPath, "/")
		}
		rewrittenSegment = "v1beta/openai/" + apiEndpoint
	} else {
		if strings.HasPrefix(originalPath, "/v1/") {
			rewrittenSegment = "v1beta/" + strings.TrimPrefix(originalPath, "/v1/")
		} else {
			rewrittenSegment = strings.TrimPrefix(originalPath, "/")
		}
	}

	trimmedBasePath := strings.TrimSuffix(basePath, "/")

	// Avoid duplicating the version segment, e.g. when basePath is /v1beta and the rewritten segment also starts with v1beta/...
	versionPrefixes := []string{"v1beta", "v1"}
	for _, prefix := range versionPrefixes {
		if strings.HasSuffix(trimmedBasePath, "/"+prefix) && strings.HasPrefix(rewrittenSegment, prefix+"/") {
			rewrittenSegment = strings.TrimPrefix(rewrittenSegment, prefix+"/")
			break
		}
	}

	finalPath, err := url.JoinPath(trimmedBasePath, rewrittenSegment)
	if err != nil {
		// Fall back to plain string concatenation.
		return trimmedBasePath + "/" + strings.TrimPrefix(rewrittenSegment, "/")
	}
	return finalPath
}
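
// Illustrative rewrites (hypothetical base paths, derived from the branches above):
//
//	RewritePath("/api", "/v1/chat/completions")                         -> "/api/v1beta/openai/chat/completions"
//	RewritePath("/api", "/v1/models/gemini-1.5-pro:generateContent")    -> "/api/v1beta/models/gemini-1.5-pro:generateContent"
//	RewritePath("/v1beta", "/v1/models/gemini-1.5-pro:generateContent") -> "/v1beta/models/gemini-1.5-pro:generateContent"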

func (ch *GeminiChannel) ModifyResponse(resp *http.Response) error {
	return nil // stub implementation
}

func (ch *GeminiChannel) HandleError(c *gin.Context, err error) {
	// stub implementation
}

// ProcessSmartStreamRequest issues the initial upstream streaming request and
// hands the response off to processStreamAndRetry for interruption handling.
func (ch *GeminiChannel) ProcessSmartStreamRequest(c *gin.Context, params SmartRequestParams) {
	log := ch.logger.WithField("correlation_id", params.CorrelationID)

	targetURL, err := url.Parse(params.UpstreamURL)
	if err != nil {
		log.WithError(err).Error("Invalid upstream URL")
		errToJSON(c, CustomErrors.NewAPIError(CustomErrors.ErrInternalServer, "Invalid upstream URL"))
		return
	}

	targetURL.Path = c.Request.URL.Path
	targetURL.RawQuery = c.Request.URL.RawQuery

	initialReq, err := http.NewRequestWithContext(c.Request.Context(), "POST", targetURL.String(), bytes.NewReader(params.RequestBody))
	if err != nil {
		log.WithError(err).Error("Failed to create initial smart stream request")
		errToJSON(c, CustomErrors.NewAPIError(CustomErrors.ErrInternalServer, "Failed to create request"))
		return
	}

	ch.ModifyRequest(initialReq, params.APIKey)
	initialReq.Header.Del("Authorization")

	resp, err := ch.httpClient.Do(initialReq)
	if err != nil {
		log.WithError(err).Error("Initial smart stream request failed")
		errToJSON(c, CustomErrors.NewAPIError(CustomErrors.ErrBadGateway, "Request to upstream failed"))
		return
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		log.Warnf("Initial request received non-200 status: %d", resp.StatusCode)
		standardizedResp := ch.standardizeError(resp, params.LogTruncationLimit, log)
		defer standardizedResp.Body.Close()

		// Headers must be set before WriteHeader, otherwise they are silently dropped.
		for key, values := range standardizedResp.Header {
			for _, value := range values {
				c.Writer.Header().Add(key, value)
			}
		}
		c.Writer.WriteHeader(standardizedResp.StatusCode)
		io.Copy(c.Writer, standardizedResp.Body)

		params.EventLogger.IsSuccess = false
		params.EventLogger.StatusCode = resp.StatusCode
		return
	}

	ch.processStreamAndRetry(c, initialReq.Header, resp.Body, params, log)
}

// processStreamAndRetry relays the upstream SSE stream to the client and, if the
// stream is interrupted before a STOP finish reason, transparently retries the
// upstream request with the accumulated text as context.
func (ch *GeminiChannel) processStreamAndRetry(
	c *gin.Context,
	initialRequestHeaders http.Header,
	initialReader io.ReadCloser,
	params SmartRequestParams,
	log *logrus.Entry,
) {
	defer initialReader.Close()

	c.Writer.Header().Set("Content-Type", "text/event-stream; charset=utf-8")
	c.Writer.Header().Set("Cache-Control", "no-cache")
	c.Writer.Header().Set("Connection", "keep-alive")

	// gin.ResponseWriter implements http.Flusher, so this assertion always succeeds.
	flusher, _ := c.Writer.(http.Flusher)

	var accumulatedText strings.Builder
	consecutiveRetryCount := 0
	currentReader := initialReader
	maxRetries := params.MaxRetries
	retryDelay := params.RetryDelay

	log.Infof("Starting smart stream session. Max retries: %d", maxRetries)

	for {
		if c.Request.Context().Err() != nil {
			log.Info("Client disconnected, stopping stream processing.")
			return
		}

		var interruptionReason string
		scanner := bufio.NewScanner(currentReader)
		// Allow SSE data lines larger than bufio.Scanner's 64 KiB default.
		scanner.Buffer(make([]byte, 0, 64*1024), 1024*1024)

		for scanner.Scan() {
			if c.Request.Context().Err() != nil {
				log.Info("Client disconnected during scan.")
				return
			}

			line := scanner.Text()
			if line == "" {
				continue
			}

			fmt.Fprintf(c.Writer, "%s\n\n", line)
			flusher.Flush()

			if !strings.HasPrefix(line, "data: ") {
				continue
			}

			data := strings.TrimPrefix(line, "data: ")
			var payload models.GeminiSSEPayload
			if err := json.Unmarshal([]byte(data), &payload); err != nil {
				continue
			}

			if len(payload.Candidates) > 0 {
				candidate := payload.Candidates[0]
				if candidate.Content != nil && len(candidate.Content.Parts) > 0 {
					accumulatedText.WriteString(candidate.Content.Parts[0].Text)
				}
				if candidate.FinishReason == "STOP" {
					log.Info("Stream finished successfully with STOP reason.")
					params.EventLogger.IsSuccess = true
					return
				}
				if candidate.FinishReason != "" {
					log.Warnf("Stream interrupted with abnormal finish reason: %s", candidate.FinishReason)
					interruptionReason = candidate.FinishReason
					break
				}
			}
		}
		currentReader.Close()

		if interruptionReason == "" {
			if err := scanner.Err(); err != nil {
				log.WithError(err).Warn("Stream scanner encountered an error.")
				interruptionReason = "SCANNER_ERROR"
			} else {
				log.Warn("Stream connection dropped without a finish reason.")
				interruptionReason = "CONNECTION_DROP"
			}
		}

		if consecutiveRetryCount >= maxRetries {
			log.Errorf("Retry limit exceeded. Last interruption: %s. Sending final error to client.", interruptionReason)
			errData, _ := json.Marshal(map[string]interface{}{
				"error": map[string]interface{}{
					"code":    http.StatusGatewayTimeout,
					"status":  "DEADLINE_EXCEEDED",
					"message": fmt.Sprintf("Proxy retry limit exceeded after multiple interruptions. Last reason: %s", interruptionReason),
				},
			})
			fmt.Fprintf(c.Writer, "event: error\ndata: %s\n\n", string(errData))
			flusher.Flush()
			return
		}

		consecutiveRetryCount++
		params.EventLogger.Retries = consecutiveRetryCount
		log.Infof("Stream interrupted. Attempting retry %d/%d after %v.", consecutiveRetryCount, maxRetries, retryDelay)

		time.Sleep(retryDelay)

		retryBody := buildRetryRequestBody(params.OriginalRequest, accumulatedText.String())
		retryBodyBytes, _ := json.Marshal(retryBody)

		retryReq, err := http.NewRequestWithContext(c.Request.Context(), "POST", params.UpstreamURL, bytes.NewReader(retryBodyBytes))
		if err != nil {
			log.WithError(err).Error("Failed to create retry request")
			continue
		}

		retryReq.Header = initialRequestHeaders
		ch.ModifyRequest(retryReq, params.APIKey)
		retryReq.Header.Del("Authorization")

		retryResp, err := ch.httpClient.Do(retryReq)
		if err != nil {
			log.WithError(err).Error("Retry request failed")
			continue
		}

		if retryResp.StatusCode != http.StatusOK {
			log.Errorf("Retry request received non-200 status: %d", retryResp.StatusCode)
			retryResp.Body.Close()
			continue
		}

		currentReader = retryResp.Body
	}
}
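
// The retry loop above behaves, in outline, like this (descriptive summary only):
//
//	relay SSE lines to the client, accumulating candidate text
//	  - finishReason == "STOP"        -> success, stop
//	  - abnormal finishReason / EOF   -> interruption
//	on interruption:
//	  - retries exhausted             -> emit an SSE "error" event (504 / DEADLINE_EXCEEDED) and stop
//	  - otherwise                     -> wait retryDelay, resend via buildRetryRequestBody, keep streaming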

// buildRetryRequestBody inserts the retry context while preserving multi-turn conversations.
func buildRetryRequestBody(originalBody models.GeminiRequest, accumulatedText string) models.GeminiRequest {
	retryBody := originalBody

	// Find the index of the last 'user' message.
	lastUserIndex := -1
	for i := len(retryBody.Contents) - 1; i >= 0; i-- {
		if retryBody.Contents[i].Role == "user" {
			lastUserIndex = i
			break
		}
	}

	history := []models.GeminiContent{
		{Role: "model", Parts: []models.Part{{Text: accumulatedText}}},
		{Role: "user", Parts: []models.Part{{Text: SmartRetryPrompt}}},
	}

	if lastUserIndex != -1 {
		// A 'user' message was found: insert the history right after it.
		newContents := make([]models.GeminiContent, 0, len(retryBody.Contents)+2)
		newContents = append(newContents, retryBody.Contents[:lastUserIndex+1]...)
		newContents = append(newContents, history...)
		newContents = append(newContents, retryBody.Contents[lastUserIndex+1:]...)
		retryBody.Contents = newContents
	} else {
		// No 'user' message (should not happen in practice): just append.
		retryBody.Contents = append(retryBody.Contents, history...)
	}

	return retryBody
}
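
// Illustrative splice performed by buildRetryRequestBody (roles only, text elided):
//
//	before: [user, model, user]
//	after:  [user, model, user, model(<accumulated text>), user(SmartRetryPrompt)]
//
// Any contents that followed the last user turn in the original request are
// re-appended after the injected pair.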

type googleAPIError struct {
	Error struct {
		Code    int           `json:"code"`
		Message string        `json:"message"`
		Status  string        `json:"status"`
		Details []interface{} `json:"details,omitempty"`
	} `json:"error"`
}

func statusToGoogleStatus(code int) string {
	switch code {
	case 400:
		return "INVALID_ARGUMENT"
	case 401:
		return "UNAUTHENTICATED"
	case 403:
		return "PERMISSION_DENIED"
	case 404:
		return "NOT_FOUND"
	case 429:
		return "RESOURCE_EXHAUSTED"
	case 500:
		return "INTERNAL"
	case 503:
		return "UNAVAILABLE"
	case 504:
		return "DEADLINE_EXCEEDED"
	default:
		return "UNKNOWN"
	}
}

func truncate(s string, n int) string {
	if n > 0 && len(s) > n {
		return fmt.Sprintf("%s... [truncated %d chars]", s[:n], len(s)-n)
	}
	return s
}

// standardizeError converts an upstream error response into a Google-style JSON
// error payload, synthesizing one when the upstream body cannot be parsed.
func (ch *GeminiChannel) standardizeError(resp *http.Response, truncateLimit int, log *logrus.Entry) *http.Response {
	bodyBytes, err := io.ReadAll(resp.Body)
	if err != nil {
		log.WithError(err).Error("Failed to read upstream error body")
		bodyBytes = []byte("Failed to read upstream error body")
	}
	resp.Body.Close()

	log.Errorf("Upstream error: %s", truncate(string(bodyBytes), truncateLimit))

	var standardizedPayload googleAPIError
	// Even if parsing fails, build a standard error structure.
	if json.Unmarshal(bodyBytes, &standardizedPayload) != nil || standardizedPayload.Error.Code == 0 {
		standardizedPayload.Error.Code = resp.StatusCode
		standardizedPayload.Error.Message = http.StatusText(resp.StatusCode)
		standardizedPayload.Error.Status = statusToGoogleStatus(resp.StatusCode)
		standardizedPayload.Error.Details = []interface{}{map[string]string{
			"@type": "proxy.upstream.unparsed.error",
			"body":  truncate(string(bodyBytes), truncateLimit),
		}}
	}

	newBodyBytes, _ := json.Marshal(standardizedPayload)
	newResp := &http.Response{
		StatusCode: resp.StatusCode,
		Status:     resp.Status,
		Header:     http.Header{},
		Body:       io.NopCloser(bytes.NewReader(newBodyBytes)),
	}
	newResp.Header.Set("Content-Type", "application/json; charset=utf-8")
	newResp.Header.Set("Access-Control-Allow-Origin", "*")

	return newResp
}
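
// Example of the synthesized body when the upstream payload is unparseable
// (shape follows the googleAPIError struct above; values are illustrative):
//
//	{
//	  "error": {
//	    "code": 503,
//	    "message": "Service Unavailable",
//	    "status": "UNAVAILABLE",
//	    "details": [{"@type": "proxy.upstream.unparsed.error", "body": "<truncated upstream body>"}]
//	  }
//	}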

func errToJSON(c *gin.Context, apiErr *CustomErrors.APIError) {
	if c.IsAborted() {
		return
	}
	c.JSON(apiErr.HTTPStatus, gin.H{"error": apiErr})
}