Files
gemini-banlancer/internal/channel/gemini_channel.go
2025-11-23 22:42:58 +08:00

497 lines
15 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
// Filename: internal/channel/gemini_channel.go
package channel
import (
"bufio"
"bytes"
"context"
"encoding/json"
"fmt"
CustomErrors "gemini-balancer/internal/errors"
"gemini-balancer/internal/models"
"io"
"net/http"
"net/url"
"strings"
"time"
"github.com/gin-gonic/gin"
"github.com/sirupsen/logrus"
)
// SmartRetryPrompt is the text of the synthetic user turn that
// buildRetryRequestBody adds when an interrupted smart stream is retried.
const SmartRetryPrompt = "Continue exactly where you left off..."
// Compile-time assertion that *GeminiChannel satisfies ChannelProxy.
var _ ChannelProxy = (*GeminiChannel)(nil)
// GeminiChannel is the ChannelProxy implementation declared in this file.
type GeminiChannel struct {
// logger is the base logger; ProcessSmartStreamRequest derives a
// per-request entry from it.
logger *logrus.Logger
// httpClient performs the smart-stream upstream requests. NewGeminiChannel
// configures it without a client-level timeout.
httpClient *http.Client
}
// NewGeminiChannel builds a GeminiChannel whose HTTP client uses a transport
// configured from cfg (idle-connection limits, idle-connection timeout and
// TLS-handshake timeout) and the proxy settings found in the environment.
//
// The client has no timeout of its own; deadlines are expected to come from
// the context attached to each request.
func NewGeminiChannel(logger *logrus.Logger, cfg *models.SystemSettings) *GeminiChannel {
	idleTimeout := time.Duration(cfg.TransportIdleConnTimeoutSecs) * time.Second
	handshakeTimeout := time.Duration(cfg.TransportTLSHandshakeTimeout) * time.Second

	client := &http.Client{
		Transport: &http.Transport{
			Proxy:                 http.ProxyFromEnvironment,
			MaxIdleConns:          cfg.TransportMaxIdleConns,
			MaxIdleConnsPerHost:   cfg.TransportMaxIdleConnsPerHost,
			IdleConnTimeout:       idleTimeout,
			TLSHandshakeTimeout:   handshakeTimeout,
			ExpectContinueTimeout: time.Second,
		},
		// Timeout stays at its zero value: it is handled by the request context.
	}

	return &GeminiChannel{logger: logger, httpClient: client}
}
// TransformRequest returns requestBody unchanged together with the model name
// resolved by ExtractModel. The returned error is always nil.
func (ch *GeminiChannel) TransformRequest(c *gin.Context, requestBody []byte) (newBody []byte, modelName string, err error) {
	newBody = requestBody
	modelName = ch.ExtractModel(c, requestBody)
	return newBody, modelName, err
}
// ExtractModel reports the model name targeted by the request, looking at the
// request body first and at the URL path second (see extractModelFromRequest).
func (ch *GeminiChannel) ExtractModel(c *gin.Context, bodyBytes []byte) string {
	model := ch.extractModelFromRequest(c, bodyBytes)
	return model
}
// extractModelFromRequest is the single place where the model name is
// resolved. It prefers the "model" field of the JSON request body (with any
// "models/" prefix removed) and falls back to parsing the URL path when the
// body cannot be decoded or carries no model.
func (ch *GeminiChannel) extractModelFromRequest(c *gin.Context, bodyBytes []byte) string {
	var body struct {
		Model string `json:"model"`
	}

	err := json.Unmarshal(bodyBytes, &body)
	if err != nil || body.Model == "" {
		return ch.extractModelFromPath(c.Request.URL.Path)
	}
	return strings.TrimPrefix(body.Model, "models/")
}
// extractModelFromPath scans the "/"-separated segments of path and returns
// the first one that looks like a model name (it starts with "gemini-",
// "text-" or "embedding-"), with any ":action" suffix removed. It returns ""
// when no segment matches.
func (ch *GeminiChannel) extractModelFromPath(path string) string {
	// Prefixes that identify a model-name segment.
	modelPrefixes := [...]string{"gemini-", "text-", "embedding-"}

	for _, segment := range strings.Split(path, "/") {
		for _, prefix := range modelPrefixes {
			if strings.HasPrefix(segment, prefix) {
				name, _, _ := strings.Cut(segment, ":")
				return name
			}
		}
	}
	return ""
}
// IsOpenAICompatibleRequest reports whether the request path targets one of
// the OpenAI-compatible endpoints. Only the URL path of c is consulted; the
// decision itself is a plain string check done by isOpenAIPath.
func (ch *GeminiChannel) IsOpenAICompatibleRequest(c *gin.Context) bool {
	requestPath := c.Request.URL.Path
	return ch.isOpenAIPath(requestPath)
}
// isOpenAIPath reports whether path contains one of the OpenAI-style endpoint
// fragments ("/v1/chat/completions" or "/v1/embeddings").
func (ch *GeminiChannel) isOpenAIPath(path string) bool {
	for _, marker := range [...]string{"/v1/chat/completions", "/v1/embeddings"} {
		if strings.Contains(path, marker) {
			return true
		}
	}
	return false
}
// ValidateKey checks an API key by issuing a GET to targetURL with the key
// attached by ModifyRequest.
//
// It returns nil when the upstream answers with a 2xx status. Otherwise it
// returns an APIError: ErrInternalServer if the request cannot be built,
// ErrBadGateway if the request cannot be completed, or an "UPSTREAM_<status>"
// error carrying the upstream status code and the message parsed from the
// upstream error body.
//
// timeout bounds the whole exchange; ctx can cancel it earlier.
func (ch *GeminiChannel) ValidateKey(
	ctx context.Context,
	apiKey *models.APIKey,
	targetURL string,
	timeout time.Duration,
) *CustomErrors.APIError {
	// Upper bound on how much of an upstream error body is read into memory.
	const maxErrorBodyBytes = 1 << 20

	req, err := http.NewRequestWithContext(ctx, http.MethodGet, targetURL, nil)
	if err != nil {
		return CustomErrors.NewAPIError(CustomErrors.ErrInternalServer, "failed to create validation request")
	}
	ch.ModifyRequest(req, apiKey)

	client := &http.Client{Timeout: timeout}
	resp, err := client.Do(req)
	if err != nil {
		// Client.Do reports failures as *url.Error, whose text embeds the full
		// request URL. ModifyRequest may have put the API key in the query
		// string, so only the underlying cause is exposed in the message.
		cause := err
		if urlErr, ok := err.(*url.Error); ok && urlErr.Err != nil {
			cause = urlErr.Err
		}
		return CustomErrors.NewAPIError(CustomErrors.ErrBadGateway, "validation request failed: "+cause.Error())
	}
	defer resp.Body.Close()

	if resp.StatusCode >= 200 && resp.StatusCode < 300 {
		return nil
	}

	// Best effort: a partial or unreadable body still goes through the parser,
	// so the read error is intentionally ignored.
	errorBody, _ := io.ReadAll(io.LimitReader(resp.Body, maxErrorBodyBytes))
	return &CustomErrors.APIError{
		HTTPStatus: resp.StatusCode,
		Code:       fmt.Sprintf("UPSTREAM_%d", resp.StatusCode),
		Message:    CustomErrors.ParseUpstreamError(errorBody),
	}
}
// ModifyRequest attaches the API key to req in the form the target endpoint
// expects: a Bearer Authorization header for paths containing
// "/v1beta/openai/", and a "key" query parameter (with any Authorization
// header removed) for every other path.
func (ch *GeminiChannel) ModifyRequest(req *http.Request, apiKey *models.APIKey) {
	isOpenAIEndpoint := strings.Contains(req.URL.Path, "/v1beta/openai/")
	if isOpenAIEndpoint {
		req.Header.Set("Authorization", "Bearer "+apiKey.APIKey)
		return
	}

	req.Header.Del("Authorization")
	query := req.URL.Query()
	query.Set("key", apiKey.APIKey)
	req.URL.RawQuery = query.Encode()
}
// IsStreamRequest reports whether the request asks for a streamed response:
// either the URL path ends in ":streamGenerateContent", or the JSON body
// decodes successfully and has "stream" set to true.
func (ch *GeminiChannel) IsStreamRequest(c *gin.Context, bodyBytes []byte) bool {
	if strings.HasSuffix(c.Request.URL.Path, ":streamGenerateContent") {
		return true
	}

	var body struct {
		Stream bool `json:"stream"`
	}
	if err := json.Unmarshal(bodyBytes, &body); err != nil {
		return false
	}
	return body.Stream
}
// RewritePath maps an incoming request path onto the upstream base path.
//
// OpenAI-style paths are rewritten to "v1beta/openai/<endpoint>", other paths
// starting with "/v1/" are rewritten to "v1beta/<rest>", and anything else is
// kept as is. The result is joined onto basePath with url.JoinPath so that
// slashes are handled consistently.
func (ch *GeminiChannel) RewritePath(basePath, originalPath string) string {
	const v1Marker = "/v1/"

	var segment string
	switch {
	case ch.isOpenAIPath(originalPath):
		endpoint := strings.TrimPrefix(originalPath, "/")
		if idx := strings.LastIndex(originalPath, v1Marker); idx != -1 {
			endpoint = originalPath[idx+len(v1Marker):]
		}
		segment = "v1beta/openai/" + endpoint
	case strings.HasPrefix(originalPath, v1Marker):
		segment = "v1beta/" + originalPath[len(v1Marker):]
	default:
		segment = strings.TrimPrefix(originalPath, "/")
	}

	base := strings.TrimSuffix(basePath, "/")

	// Avoid repeating the version, e.g. when basePath already ends in
	// "/v1beta" and the rewritten segment starts with "v1beta/".
	for _, version := range [...]string{"v1beta", "v1"} {
		if strings.HasSuffix(base, "/"+version) && strings.HasPrefix(segment, version+"/") {
			segment = segment[len(version)+1:]
			break
		}
	}

	joined, err := url.JoinPath(base, segment)
	if err != nil {
		// Fall back to plain string concatenation.
		return base + "/" + strings.TrimPrefix(segment, "/")
	}
	return joined
}
// ModifyResponse is currently a stub: it does not touch resp and always
// returns nil.
func (ch *GeminiChannel) ModifyResponse(resp *http.Response) error {
return nil // Stub implementation.
}
// HandleError is currently a stub: it ignores err and writes nothing to c.
func (ch *GeminiChannel) HandleError(c *gin.Context, err error) {
// Stub implementation.
}
// ProcessSmartStreamRequest sends the initial streaming request upstream and
// relays the result to the client.
//
// The target URL is params.UpstreamURL with its path and query replaced by
// those of the incoming request. On a 200 response the stream is handed to
// processStreamAndRetry, which relays it and retries on interruption. On any
// other status the upstream error is normalised by standardizeError, written
// to the client, and recorded on params.EventLogger. Failures that happen
// before a response is available are reported to the client as JSON errors.
func (ch *GeminiChannel) ProcessSmartStreamRequest(c *gin.Context, params SmartRequestParams) {
	log := ch.logger.WithField("correlation_id", params.CorrelationID)

	targetURL, err := url.Parse(params.UpstreamURL)
	if err != nil {
		log.WithError(err).Error("Invalid upstream URL")
		errToJSON(c, CustomErrors.NewAPIError(CustomErrors.ErrInternalServer, "Invalid upstream URL"))
		return
	}
	targetURL.Path = c.Request.URL.Path
	targetURL.RawQuery = c.Request.URL.RawQuery

	initialReq, err := http.NewRequestWithContext(c.Request.Context(), "POST", targetURL.String(), bytes.NewReader(params.RequestBody))
	if err != nil {
		log.WithError(err).Error("Failed to create initial smart stream request")
		errToJSON(c, CustomErrors.NewAPIError(CustomErrors.ErrInternalServer, "Failed to create request"))
		return
	}
	ch.ModifyRequest(initialReq, params.APIKey)
	initialReq.Header.Del("Authorization")

	resp, err := ch.httpClient.Do(initialReq)
	if err != nil {
		// Client.Do reports failures as *url.Error, whose text embeds the full
		// request URL including the "key" query parameter. Log only the
		// underlying cause so the API key does not end up in the logs.
		cause := err
		if urlErr, ok := err.(*url.Error); ok && urlErr.Err != nil {
			cause = urlErr.Err
		}
		log.WithError(cause).Error("Initial smart stream request failed")
		errToJSON(c, CustomErrors.NewAPIError(CustomErrors.ErrBadGateway, "Request to upstream failed"))
		return
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		log.Warnf("Initial request received non-200 status: %d", resp.StatusCode)
		standardizedResp := ch.standardizeError(resp, params.LogTruncationLimit, log)
		defer standardizedResp.Body.Close()

		// Populate the headers before committing the status line, so the
		// result does not depend on the writer deferring WriteHeader.
		header := c.Writer.Header()
		for key, values := range standardizedResp.Header {
			for _, value := range values {
				header.Add(key, value)
			}
		}
		c.Writer.WriteHeader(standardizedResp.StatusCode)
		if _, err := io.Copy(c.Writer, standardizedResp.Body); err != nil {
			log.WithError(err).Warn("Failed to write upstream error response to client")
		}

		params.EventLogger.IsSuccess = false
		params.EventLogger.StatusCode = resp.StatusCode
		return
	}

	ch.processStreamAndRetry(c, initialReq.Header, resp.Body, params, log)
}
// processStreamAndRetry relays an upstream SSE stream to the client line by
// line and transparently retries when the stream is interrupted.
//
// Every non-empty upstream line is forwarded as its own SSE block. For
// "data: " lines that decode as a Gemini payload, the text of the first part
// of the first candidate is accumulated. The session ends when:
//   - a candidate reports finish reason "STOP" (EventLogger.IsSuccess is set),
//   - the client's request context is cancelled, or
//   - the number of retries reaches params.MaxRetries, in which case an
//     "event: error" SSE block is sent to the client.
//
// Any other finish reason, a scanner error, or a stream that ends without a
// finish reason counts as an interruption. After params.RetryDelay a new
// request is issued whose body is built by buildRetryRequestBody from the
// original request and the text accumulated so far. A retry that fails to
// yield a readable 200 stream also consumes one retry slot.
//
// initialRequestHeaders are copied onto every retry request. initialReader and
// every retry response body are closed before the function returns.
func (ch *GeminiChannel) processStreamAndRetry(
	c *gin.Context,
	initialRequestHeaders http.Header,
	initialReader io.ReadCloser,
	params SmartRequestParams,
	log *logrus.Entry,
) {
	// bufio.Scanner rejects lines longer than 64 KiB by default, and a single
	// SSE data line can exceed that; allow lines up to this size instead.
	const maxSSELineBytes = 8 << 20

	ctx := c.Request.Context()

	// currentReader is the stream being relayed, or nil while no stream is
	// open. The deferred close covers every return path, including those
	// taken while a retry response body is active.
	currentReader := initialReader
	defer func() {
		if currentReader != nil {
			currentReader.Close()
		}
	}()

	c.Writer.Header().Set("Content-Type", "text/event-stream; charset=utf-8")
	c.Writer.Header().Set("Cache-Control", "no-cache")
	c.Writer.Header().Set("Connection", "keep-alive")

	flush := func() {}
	if f, ok := c.Writer.(http.Flusher); ok {
		flush = f.Flush
	}

	var accumulatedText strings.Builder
	retryCount := 0
	maxRetries := params.MaxRetries
	retryDelay := params.RetryDelay

	// pendingReason carries the failure of a retry attempt that never produced
	// a stream, so the next iteration can account for it without scanning.
	pendingReason := ""

	log.Infof("Starting smart stream session. Max retries: %d", maxRetries)

	for {
		if ctx.Err() != nil {
			log.Info("Client disconnected, stopping stream processing.")
			return
		}

		interruptionReason := pendingReason
		pendingReason = ""

		if currentReader != nil {
			scanner := bufio.NewScanner(currentReader)
			scanner.Buffer(make([]byte, 0, 64*1024), maxSSELineBytes)

			for scanner.Scan() {
				if ctx.Err() != nil {
					log.Info("Client disconnected during scan.")
					return
				}
				line := scanner.Text()
				if line == "" {
					continue
				}
				fmt.Fprintf(c.Writer, "%s\n\n", line)
				flush()

				if !strings.HasPrefix(line, "data: ") {
					continue
				}
				var payload models.GeminiSSEPayload
				if err := json.Unmarshal([]byte(strings.TrimPrefix(line, "data: ")), &payload); err != nil {
					// Not a payload we understand; it was already forwarded.
					continue
				}
				if len(payload.Candidates) == 0 {
					continue
				}

				candidate := payload.Candidates[0]
				if candidate.Content != nil && len(candidate.Content.Parts) > 0 {
					accumulatedText.WriteString(candidate.Content.Parts[0].Text)
				}
				if candidate.FinishReason == "STOP" {
					log.Info("Stream finished successfully with STOP reason.")
					params.EventLogger.IsSuccess = true
					return
				}
				if candidate.FinishReason != "" {
					log.Warnf("Stream interrupted with abnormal finish reason: %s", candidate.FinishReason)
					interruptionReason = candidate.FinishReason
					break
				}
			}

			if interruptionReason == "" {
				if err := scanner.Err(); err != nil {
					log.WithError(err).Warn("Stream scanner encountered an error.")
					interruptionReason = "SCANNER_ERROR"
				} else {
					log.Warn("Stream connection dropped without a finish reason.")
					interruptionReason = "CONNECTION_DROP"
				}
			}

			currentReader.Close()
			currentReader = nil

			// A cancelled client context surfaces as a read error; do not
			// treat it as an upstream interruption worth retrying.
			if ctx.Err() != nil {
				log.Info("Client disconnected, stopping stream processing.")
				return
			}
		}

		if retryCount >= maxRetries {
			log.Errorf("Retry limit exceeded. Last interruption: %s. Sending final error to client.", interruptionReason)
			// Marshalling a map of strings and ints cannot fail.
			errData, _ := json.Marshal(map[string]interface{}{
				"error": map[string]interface{}{
					"code":    http.StatusGatewayTimeout,
					"status":  "DEADLINE_EXCEEDED",
					"message": fmt.Sprintf("Proxy retry limit exceeded after multiple interruptions. Last reason: %s", interruptionReason),
				},
			})
			fmt.Fprintf(c.Writer, "event: error\ndata: %s\n\n", string(errData))
			flush()
			return
		}

		retryCount++
		params.EventLogger.Retries = retryCount
		log.Infof("Stream interrupted. Attempting retry %d/%d after %v.", retryCount, maxRetries, retryDelay)

		// Wait before retrying, but give up immediately if the client leaves.
		if retryDelay > 0 {
			timer := time.NewTimer(retryDelay)
			select {
			case <-ctx.Done():
				timer.Stop()
				log.Info("Client disconnected, stopping stream processing.")
				return
			case <-timer.C:
			}
		}

		retryBody := buildRetryRequestBody(params.OriginalRequest, accumulatedText.String())
		retryBodyBytes, err := json.Marshal(retryBody)
		if err != nil {
			log.WithError(err).Error("Failed to encode retry request body")
			pendingReason = "RETRY_BODY_ENCODE_FAILED"
			continue
		}

		// NOTE(review): the retry targets params.UpstreamURL as given, whereas
		// the initial request overrides its path and query with those of the
		// incoming request — confirm both resolve to the same endpoint.
		retryReq, err := http.NewRequestWithContext(ctx, "POST", params.UpstreamURL, bytes.NewReader(retryBodyBytes))
		if err != nil {
			log.WithError(err).Error("Failed to create retry request")
			pendingReason = "RETRY_REQUEST_BUILD_FAILED"
			continue
		}

		// Clone so that per-request header edits never touch the map that
		// belongs to the initial request.
		retryReq.Header = initialRequestHeaders.Clone()
		if retryReq.Header == nil {
			retryReq.Header = http.Header{}
		}
		ch.ModifyRequest(retryReq, params.APIKey)
		retryReq.Header.Del("Authorization")

		retryResp, err := ch.httpClient.Do(retryReq)
		if err != nil {
			// Client.Do reports failures as *url.Error, whose text embeds the
			// request URL including the "key" query parameter; log only the
			// underlying cause.
			cause := err
			if urlErr, ok := err.(*url.Error); ok && urlErr.Err != nil {
				cause = urlErr.Err
			}
			log.WithError(cause).Error("Retry request failed")
			pendingReason = "RETRY_REQUEST_FAILED"
			continue
		}
		if retryResp.StatusCode != http.StatusOK {
			log.Errorf("Retry request received non-200 status: %d", retryResp.StatusCode)
			retryResp.Body.Close()
			pendingReason = fmt.Sprintf("RETRY_HTTP_%d", retryResp.StatusCode)
			continue
		}

		currentReader = retryResp.Body
	}
}
// buildRetryRequestBody derives the request used to resume an interrupted
// generation.
//
// It returns a copy of originalBody whose Contents has two extra turns
// inserted right after the last "user" turn (or appended at the end when there
// is no "user" turn): a "model" turn holding accumulatedText and a "user" turn
// holding SmartRetryPrompt. Any turns that followed the last "user" turn are
// kept after the inserted ones.
//
// When accumulatedText is empty there is nothing to continue from, so the
// original request is returned unchanged and the retry simply replays it.
//
// The Contents slice of the result never shares its backing array with
// originalBody, so building a retry body cannot modify the caller's request.
func buildRetryRequestBody(originalBody models.GeminiRequest, accumulatedText string) models.GeminiRequest {
	retryBody := originalBody
	if accumulatedText == "" {
		return retryBody
	}

	// Insert after the last "user" turn; default to the end when none exists.
	insertAt := len(originalBody.Contents)
	for i := len(originalBody.Contents) - 1; i >= 0; i-- {
		if originalBody.Contents[i].Role == "user" {
			insertAt = i + 1
			break
		}
	}

	contents := make([]models.GeminiContent, 0, len(originalBody.Contents)+2)
	contents = append(contents, originalBody.Contents[:insertAt]...)
	contents = append(contents,
		models.GeminiContent{Role: "model", Parts: []models.Part{{Text: accumulatedText}}},
		models.GeminiContent{Role: "user", Parts: []models.Part{{Text: SmartRetryPrompt}}},
	)
	contents = append(contents, originalBody.Contents[insertAt:]...)

	retryBody.Contents = contents
	return retryBody
}
// googleAPIError is the JSON error envelope produced by standardizeError:
// a single "error" object carrying a numeric code, a message, a status string
// and optional free-form details.
type googleAPIError struct {
	Error struct {
		// Code is the numeric error code.
		Code int `json:"code"`
		// Message is the human-readable description.
		Message string `json:"message"`
		// Status is the symbolic status name, e.g. "NOT_FOUND".
		Status string `json:"status"`
		// Details holds arbitrary extra payloads; omitted from JSON when empty.
		Details []any `json:"details,omitempty"`
	} `json:"error"`
}
// statusToGoogleStatus translates an HTTP status code into the symbolic status
// name used in Google-style error payloads. Codes without a dedicated mapping
// yield "UNKNOWN".
func statusToGoogleStatus(code int) string {
	status := "UNKNOWN"
	switch code {
	case http.StatusBadRequest:
		status = "INVALID_ARGUMENT"
	case http.StatusUnauthorized:
		status = "UNAUTHENTICATED"
	case http.StatusForbidden:
		status = "PERMISSION_DENIED"
	case http.StatusNotFound:
		status = "NOT_FOUND"
	case http.StatusTooManyRequests:
		status = "RESOURCE_EXHAUSTED"
	case http.StatusInternalServerError:
		status = "INTERNAL"
	case http.StatusServiceUnavailable:
		status = "UNAVAILABLE"
	case http.StatusGatewayTimeout:
		status = "DEADLINE_EXCEEDED"
	}
	return status
}
// truncate shortens s to at most n bytes for logging and appends a marker
// stating how many characters were dropped. s is returned unchanged when n is
// not positive or s already fits within n bytes.
//
// The cut never falls inside a multi-byte UTF-8 sequence: if byte n is a
// continuation byte the cut moves back to the start of that character, so the
// kept prefix may be up to three bytes shorter than n. The reported count is
// in characters (runes), not bytes.
func truncate(s string, n int) string {
	if n <= 0 || len(s) <= n {
		return s
	}

	// A UTF-8 character has at most three continuation bytes (10xxxxxx), so
	// stepping back more than that would only eat into malformed input.
	cut := n
	for back := 0; back < 3 && cut > 0 && s[cut]&0xC0 == 0x80; back++ {
		cut--
	}

	omitted := 0
	for range s[cut:] {
		omitted++
	}
	return fmt.Sprintf("%s... [truncated %d chars]", s[:cut], omitted)
}
// standardizeError consumes and closes the body of an upstream error response
// and returns a new response with the same status whose body is a
// googleAPIError JSON document.
//
// If the upstream body already decodes into that shape with a non-zero code it
// is re-encoded as is. Otherwise a payload is synthesised from the HTTP status
// and the (truncated) raw body is attached as a detail entry. The upstream
// body is also logged, truncated to truncateLimit.
//
// The returned response carries only a JSON Content-Type header and a
// permissive Access-Control-Allow-Origin header.
func (ch *GeminiChannel) standardizeError(resp *http.Response, truncateLimit int, log *logrus.Entry) *http.Response {
	rawBody, readErr := io.ReadAll(resp.Body)
	if readErr != nil {
		log.WithError(readErr).Error("Failed to read upstream error body")
		rawBody = []byte("Failed to read upstream error body")
	}
	resp.Body.Close()

	bodyText := string(rawBody)
	log.Errorf("Upstream error: %s", truncate(bodyText, truncateLimit))

	// Always end up with a well-formed error payload, even when the upstream
	// body is not parseable.
	var payload googleAPIError
	parseErr := json.Unmarshal(rawBody, &payload)
	if parseErr != nil || payload.Error.Code == 0 {
		status := resp.StatusCode
		payload.Error.Code = status
		payload.Error.Message = http.StatusText(status)
		payload.Error.Status = statusToGoogleStatus(status)
		payload.Error.Details = []interface{}{
			map[string]string{
				"@type": "proxy.upstream.unparsed.error",
				"body":  truncate(bodyText, truncateLimit),
			},
		}
	}

	// The marshal error is discarded, matching the best-effort nature of this
	// path.
	encoded, _ := json.Marshal(payload)

	header := make(http.Header)
	header.Set("Content-Type", "application/json; charset=utf-8")
	header.Set("Access-Control-Allow-Origin", "*")

	return &http.Response{
		StatusCode: resp.StatusCode,
		Status:     resp.Status,
		Header:     header,
		Body:       io.NopCloser(bytes.NewReader(encoded)),
	}
}
// errToJSON writes apiErr to the client as {"error": ...} with the error's
// HTTP status, unless the context has already been aborted.
func errToJSON(c *gin.Context, apiErr *CustomErrors.APIError) {
	if !c.IsAborted() {
		c.JSON(apiErr.HTTPStatus, gin.H{"error": apiErr})
	}
}