// Filename: internal/channel/gemini_channel.go
package channel

import (
    "bufio"
    "bytes"
    "context"
    "encoding/json"
    "fmt"
    CustomErrors "gemini-balancer/internal/errors"
    "gemini-balancer/internal/models"
    "io"
    "net/http"
    "net/url"
    "strings"
    "time"

    "github.com/gin-gonic/gin"
    "github.com/sirupsen/logrus"
)

const SmartRetryPrompt = "Continue exactly where you left off..."

var _ ChannelProxy = (*GeminiChannel)(nil)

type GeminiChannel struct {
    logger     *logrus.Logger
    httpClient *http.Client
}

func NewGeminiChannel(logger *logrus.Logger, cfg *models.SystemSettings) *GeminiChannel {
    transport := &http.Transport{
        Proxy:                 http.ProxyFromEnvironment,
        MaxIdleConns:          cfg.TransportMaxIdleConns,
        MaxIdleConnsPerHost:   cfg.TransportMaxIdleConnsPerHost,
        IdleConnTimeout:       time.Duration(cfg.TransportIdleConnTimeoutSecs) * time.Second,
        TLSHandshakeTimeout:   time.Duration(cfg.TransportTLSHandshakeTimeout) * time.Second,
        ExpectContinueTimeout: 1 * time.Second,
    }
    return &GeminiChannel{
        logger: logger,
        httpClient: &http.Client{
            Transport: transport,
            Timeout:   0, // Timeout is handled by the request context
        },
    }
}

// TransformRequest passes the Gemini request body through unchanged and reports the resolved model name.
func (ch *GeminiChannel) TransformRequest(c *gin.Context, requestBody []byte) (newBody []byte, modelName string, err error) {
    modelName = ch.ExtractModel(c, requestBody)
    return requestBody, modelName, nil
}

// ExtractModel returns the model name, parsed from the request body or, failing that, from the URL path.
func (ch *GeminiChannel) ExtractModel(c *gin.Context, bodyBytes []byte) string {
    return ch.extractModelFromRequest(c, bodyBytes)
}

// extractModelFromRequest holds the unified model-extraction logic: parse the
// request body first and fall back to the URL path if that fails.
func (ch *GeminiChannel) extractModelFromRequest(c *gin.Context, bodyBytes []byte) string {
    var p struct {
        Model string `json:"model"`
    }
    if json.Unmarshal(bodyBytes, &p) == nil && p.Model != "" {
        return strings.TrimPrefix(p.Model, "models/")
    }
    return ch.extractModelFromPath(c.Request.URL.Path)
}

func (ch *GeminiChannel) extractModelFromPath(path string) string {
    parts := strings.Split(path, "/")
    for _, part := range parts {
        // Cover the common model-name formats.
        if strings.HasPrefix(part, "gemini-") || strings.HasPrefix(part, "text-") || strings.HasPrefix(part, "embedding-") {
            return strings.Split(part, ":")[0]
        }
    }
    return ""
}

// IsOpenAICompatibleRequest detects OpenAI-compatible requests purely by string
// matching on the path; the underlying check (isOpenAIPath) does not depend on gin.Context.
func (ch *GeminiChannel) IsOpenAICompatibleRequest(c *gin.Context) bool {
    return ch.isOpenAIPath(c.Request.URL.Path)
}

func (ch *GeminiChannel) isOpenAIPath(path string) bool {
    return strings.Contains(path, "/v1/chat/completions") || strings.Contains(path, "/v1/embeddings")
}

func (ch *GeminiChannel) ValidateKey(
    ctx context.Context,
    apiKey *models.APIKey,
    targetURL string,
    timeout time.Duration,
) *CustomErrors.APIError {
    client := &http.Client{Timeout: timeout}
    req, err := http.NewRequestWithContext(ctx, "GET", targetURL, nil)
    if err != nil {
        return CustomErrors.NewAPIError(CustomErrors.ErrInternalServer, "failed to create validation request")
    }
    ch.ModifyRequest(req, apiKey)
    resp, err := client.Do(req)
    if err != nil {
        return CustomErrors.NewAPIError(CustomErrors.ErrBadGateway, "validation request failed: "+err.Error())
    }
    defer resp.Body.Close()
    if resp.StatusCode >= 200 && resp.StatusCode < 300 {
        return nil
    }
    errorBody, _ := io.ReadAll(resp.Body)
    parsedMessage := CustomErrors.ParseUpstreamError(errorBody)
    return &CustomErrors.APIError{
        HTTPStatus: resp.StatusCode,
        Code:       fmt.Sprintf("UPSTREAM_%d", resp.StatusCode),
        Message:    parsedMessage,
    }
}

func (ch *GeminiChannel) ModifyRequest(req *http.Request, apiKey *models.APIKey) {
    if strings.Contains(req.URL.Path, "/v1beta/openai/") {
        req.Header.Set("Authorization", "Bearer "+apiKey.APIKey)
    } else {
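        // Native Gemini endpoints authenticate via the ?key= query parameter
        // rather than a Bearer token, so any Authorization header is removed.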
req.Header.Del("Authorization") q := req.URL.Query() q.Set("key", apiKey.APIKey) req.URL.RawQuery = q.Encode() } } func (ch *GeminiChannel) IsStreamRequest(c *gin.Context, bodyBytes []byte) bool { if strings.HasSuffix(c.Request.URL.Path, ":streamGenerateContent") { return true } var meta struct { Stream bool `json:"stream"` } if json.Unmarshal(bodyBytes, &meta) == nil { return meta.Stream } return false } // RewritePath 使用 url.JoinPath 保证路径拼接的正确性。 func (ch *GeminiChannel) RewritePath(basePath, originalPath string) string { var rewrittenSegment string if ch.isOpenAIPath(originalPath) { v1Index := strings.LastIndex(originalPath, "/v1/") var apiEndpoint string if v1Index != -1 { apiEndpoint = originalPath[v1Index+len("/v1/"):] } else { apiEndpoint = strings.TrimPrefix(originalPath, "/") } rewrittenSegment = "v1beta/openai/" + apiEndpoint } else { if strings.HasPrefix(originalPath, "/v1/") { rewrittenSegment = "v1beta/" + strings.TrimPrefix(originalPath, "/v1/") } else { rewrittenSegment = strings.TrimPrefix(originalPath, "/") } } trimmedBasePath := strings.TrimSuffix(basePath, "/") // 防止版本号重复拼接,例如 basePath 是 /v1beta,而重写段也是 v1beta/.. versionPrefixes := []string{"v1beta", "v1"} for _, prefix := range versionPrefixes { if strings.HasSuffix(trimmedBasePath, "/"+prefix) && strings.HasPrefix(rewrittenSegment, prefix+"/") { rewrittenSegment = strings.TrimPrefix(rewrittenSegment, prefix+"/") break } } finalPath, err := url.JoinPath(trimmedBasePath, rewrittenSegment) if err != nil { // 回退到简单的字符串拼接 return trimmedBasePath + "/" + strings.TrimPrefix(rewrittenSegment, "/") } return finalPath } func (ch *GeminiChannel) ModifyResponse(resp *http.Response) error { return nil // 桩实现 } func (ch *GeminiChannel) HandleError(c *gin.Context, err error) { // 桩实现 } func (ch *GeminiChannel) ProcessSmartStreamRequest(c *gin.Context, params SmartRequestParams) { log := ch.logger.WithField("correlation_id", params.CorrelationID) targetURL, err := url.Parse(params.UpstreamURL) if err != nil { log.WithError(err).Error("Invalid upstream URL") errToJSON(c, CustomErrors.NewAPIError(CustomErrors.ErrInternalServer, "Invalid upstream URL")) return } targetURL.Path = c.Request.URL.Path targetURL.RawQuery = c.Request.URL.RawQuery initialReq, err := http.NewRequestWithContext(c.Request.Context(), "POST", targetURL.String(), bytes.NewReader(params.RequestBody)) if err != nil { log.WithError(err).Error("Failed to create initial smart stream request") errToJSON(c, CustomErrors.NewAPIError(CustomErrors.ErrInternalServer, "Failed to create request")) return } ch.ModifyRequest(initialReq, params.APIKey) initialReq.Header.Del("Authorization") resp, err := ch.httpClient.Do(initialReq) if err != nil { log.WithError(err).Error("Initial smart stream request failed") errToJSON(c, CustomErrors.NewAPIError(CustomErrors.ErrBadGateway, "Request to upstream failed")) return } defer resp.Body.Close() if resp.StatusCode != http.StatusOK { log.Warnf("Initial request received non-200 status: %d", resp.StatusCode) standardizedResp := ch.standardizeError(resp, params.LogTruncationLimit, log) defer standardizedResp.Body.Close() c.Writer.WriteHeader(standardizedResp.StatusCode) for key, values := range standardizedResp.Header { for _, value := range values { c.Writer.Header().Add(key, value) } } io.Copy(c.Writer, standardizedResp.Body) params.EventLogger.IsSuccess = false params.EventLogger.StatusCode = resp.StatusCode return } ch.processStreamAndRetry(c, initialReq.Header, resp.Body, params, log) } func (ch *GeminiChannel) processStreamAndRetry( c 
func (ch *GeminiChannel) processStreamAndRetry(
    c *gin.Context,
    initialRequestHeaders http.Header,
    initialReader io.ReadCloser,
    params SmartRequestParams,
    log *logrus.Entry,
) {
    defer initialReader.Close()

    c.Writer.Header().Set("Content-Type", "text/event-stream; charset=utf-8")
    c.Writer.Header().Set("Cache-Control", "no-cache")
    c.Writer.Header().Set("Connection", "keep-alive")
    flusher, _ := c.Writer.(http.Flusher)

    var accumulatedText strings.Builder
    consecutiveRetryCount := 0
    currentReader := initialReader
    maxRetries := params.MaxRetries
    retryDelay := params.RetryDelay

    log.Infof("Starting smart stream session. Max retries: %d", maxRetries)

    for {
        if c.Request.Context().Err() != nil {
            log.Info("Client disconnected, stopping stream processing.")
            return
        }

        var interruptionReason string
        scanner := bufio.NewScanner(currentReader)
        for scanner.Scan() {
            if c.Request.Context().Err() != nil {
                log.Info("Client disconnected during scan.")
                return
            }
            line := scanner.Text()
            if line == "" {
                continue
            }
            fmt.Fprintf(c.Writer, "%s\n\n", line)
            flusher.Flush()

            if !strings.HasPrefix(line, "data: ") {
                continue
            }
            data := strings.TrimPrefix(line, "data: ")
            var payload models.GeminiSSEPayload
            if err := json.Unmarshal([]byte(data), &payload); err != nil {
                continue
            }
            if len(payload.Candidates) > 0 {
                candidate := payload.Candidates[0]
                if candidate.Content != nil && len(candidate.Content.Parts) > 0 {
                    accumulatedText.WriteString(candidate.Content.Parts[0].Text)
                }
                if candidate.FinishReason == "STOP" {
                    log.Info("Stream finished successfully with STOP reason.")
                    params.EventLogger.IsSuccess = true
                    return
                }
                if candidate.FinishReason != "" {
                    log.Warnf("Stream interrupted with abnormal finish reason: %s", candidate.FinishReason)
                    interruptionReason = candidate.FinishReason
                    break
                }
            }
        }
        currentReader.Close()

        if interruptionReason == "" {
            if err := scanner.Err(); err != nil {
                log.WithError(err).Warn("Stream scanner encountered an error.")
                interruptionReason = "SCANNER_ERROR"
            } else {
                log.Warn("Stream connection dropped without a finish reason.")
                interruptionReason = "CONNECTION_DROP"
            }
        }

        if consecutiveRetryCount >= maxRetries {
            log.Errorf("Retry limit exceeded. Last interruption: %s. Sending final error to client.", interruptionReason)
            errData, _ := json.Marshal(map[string]interface{}{
                "error": map[string]interface{}{
                    "code":    http.StatusGatewayTimeout,
                    "status":  "DEADLINE_EXCEEDED",
                    "message": fmt.Sprintf("Proxy retry limit exceeded after multiple interruptions. Last reason: %s", interruptionReason),
                },
            })
            fmt.Fprintf(c.Writer, "event: error\ndata: %s\n\n", string(errData))
            flusher.Flush()
            return
        }

        consecutiveRetryCount++
        params.EventLogger.Retries = consecutiveRetryCount
        log.Infof("Stream interrupted. Attempting retry %d/%d after %v.", consecutiveRetryCount, maxRetries, retryDelay)
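        // Wait out the configured delay, then rebuild the request with
        // everything generated so far plus the continuation prompt and retry
        // against the same upstream URL.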
        time.Sleep(retryDelay)

        retryBody := buildRetryRequestBody(params.OriginalRequest, accumulatedText.String())
        retryBodyBytes, _ := json.Marshal(retryBody)

        retryReq, err := http.NewRequestWithContext(c.Request.Context(), "POST", params.UpstreamURL, bytes.NewReader(retryBodyBytes))
        if err != nil {
            log.WithError(err).Error("Failed to create retry request")
            continue
        }
        retryReq.Header = initialRequestHeaders
        ch.ModifyRequest(retryReq, params.APIKey)
        retryReq.Header.Del("Authorization")

        retryResp, err := ch.httpClient.Do(retryReq)
        if err != nil {
            log.WithError(err).Error("Retry request failed")
            continue
        }
        if retryResp.StatusCode != http.StatusOK {
            log.Errorf("Retry request received non-200 status: %d", retryResp.StatusCode)
            retryResp.Body.Close()
            continue
        }
        currentReader = retryResp.Body
    }
}

// buildRetryRequestBody inserts the retry context correctly into a multi-turn conversation.
func buildRetryRequestBody(originalBody models.GeminiRequest, accumulatedText string) models.GeminiRequest {
    retryBody := originalBody

    // Find the index of the last message with the 'user' role.
    lastUserIndex := -1
    for i := len(retryBody.Contents) - 1; i >= 0; i-- {
        if retryBody.Contents[i].Role == "user" {
            lastUserIndex = i
            break
        }
    }

    history := []models.GeminiContent{
        {Role: "model", Parts: []models.Part{{Text: accumulatedText}}},
        {Role: "user", Parts: []models.Part{{Text: SmartRetryPrompt}}},
    }

    if lastUserIndex != -1 {
        // A 'user' message was found: insert the history right after it.
        newContents := make([]models.GeminiContent, 0, len(retryBody.Contents)+2)
        newContents = append(newContents, retryBody.Contents[:lastUserIndex+1]...)
        newContents = append(newContents, history...)
        newContents = append(newContents, retryBody.Contents[lastUserIndex+1:]...)
        retryBody.Contents = newContents
    } else {
        // No 'user' message was found (which should not happen in practice): append instead.
        retryBody.Contents = append(retryBody.Contents, history...)
    }
    return retryBody
}

type googleAPIError struct {
    Error struct {
        Code    int           `json:"code"`
        Message string        `json:"message"`
        Status  string        `json:"status"`
        Details []interface{} `json:"details,omitempty"`
    } `json:"error"`
}

func statusToGoogleStatus(code int) string {
    switch code {
    case 400:
        return "INVALID_ARGUMENT"
    case 401:
        return "UNAUTHENTICATED"
    case 403:
        return "PERMISSION_DENIED"
    case 404:
        return "NOT_FOUND"
    case 429:
        return "RESOURCE_EXHAUSTED"
    case 500:
        return "INTERNAL"
    case 503:
        return "UNAVAILABLE"
    case 504:
        return "DEADLINE_EXCEEDED"
    default:
        return "UNKNOWN"
    }
}

func truncate(s string, n int) string {
    if n > 0 && len(s) > n {
        return fmt.Sprintf("%s... [truncated %d chars]", s[:n], len(s)-n)
    }
    return s
}
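// standardizeError drains the upstream error response and re-wraps it as a
// Google-style error payload (code/message/status). Bodies that cannot be
// parsed are preserved, truncated, inside a details entry so the original
// upstream text is not lost.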
func (ch *GeminiChannel) standardizeError(resp *http.Response, truncateLimit int, log *logrus.Entry) *http.Response {
    bodyBytes, err := io.ReadAll(resp.Body)
    if err != nil {
        log.WithError(err).Error("Failed to read upstream error body")
        bodyBytes = []byte("Failed to read upstream error body")
    }
    resp.Body.Close()
    log.Errorf("Upstream error: %s", truncate(string(bodyBytes), truncateLimit))

    var standardizedPayload googleAPIError
    // Even if parsing fails, still build a standard error structure.
    if json.Unmarshal(bodyBytes, &standardizedPayload) != nil || standardizedPayload.Error.Code == 0 {
        standardizedPayload.Error.Code = resp.StatusCode
        standardizedPayload.Error.Message = http.StatusText(resp.StatusCode)
        standardizedPayload.Error.Status = statusToGoogleStatus(resp.StatusCode)
        standardizedPayload.Error.Details = []interface{}{map[string]string{
            "@type": "proxy.upstream.unparsed.error",
            "body":  truncate(string(bodyBytes), truncateLimit),
        }}
    }
    newBodyBytes, _ := json.Marshal(standardizedPayload)

    newResp := &http.Response{
        StatusCode: resp.StatusCode,
        Status:     resp.Status,
        Header:     http.Header{},
        Body:       io.NopCloser(bytes.NewReader(newBodyBytes)),
    }
    newResp.Header.Set("Content-Type", "application/json; charset=utf-8")
    newResp.Header.Set("Access-Control-Allow-Origin", "*")
    return newResp
}

func errToJSON(c *gin.Context, apiErr *CustomErrors.APIError) {
    if c.IsAborted() {
        return
    }
    c.JSON(apiErr.HTTPStatus, gin.H{"error": apiErr})
}
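// exampleBuildUpstreamRequest is a minimal, illustrative sketch (not called by
// the proxy) of how RewritePath and ModifyRequest compose. The base path and
// upstream host below are placeholders, not real configuration.
func exampleBuildUpstreamRequest(ch *GeminiChannel, key *models.APIKey) (*http.Request, error) {
    // An OpenAI-compatible client path is rewritten under /v1beta/openai/:
    //   "/v1/chat/completions" -> "/v1beta/openai/chat/completions"
    // A native path such as "/v1/models/gemini-pro:generateContent" maps to
    // "/v1beta/models/gemini-pro:generateContent" instead.
    path := ch.RewritePath("/v1beta", "/v1/chat/completions")

    req, err := http.NewRequest(http.MethodPost, "https://generativelanguage.googleapis.com"+path, nil)
    if err != nil {
        return nil, err
    }

    // For the OpenAI-compatible surface the key is sent as a Bearer token; for
    // native Gemini paths ModifyRequest appends it as a ?key= query parameter.
    ch.ModifyRequest(req, key)
    return req, nil
}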