// Filename: internal/channel/gemini_channel.go package channel import ( "bufio" "bytes" "context" "encoding/json" "fmt" CustomErrors "gemini-balancer/internal/errors" "gemini-balancer/internal/models" "io" "net/http" "net/url" "strings" "time" "github.com/gin-gonic/gin" "github.com/sirupsen/logrus" ) const SmartRetryPrompt = "Continue exactly where you left off..." var _ ChannelProxy = (*GeminiChannel)(nil) type GeminiChannel struct { logger *logrus.Logger httpClient *http.Client } // 用于安全提取信息的本地结构体 type requestMetadata struct { Model string `json:"model"` Stream bool `json:"stream"` } func NewGeminiChannel(logger *logrus.Logger, cfg *models.SystemSettings) *GeminiChannel { transport := &http.Transport{ Proxy: http.ProxyFromEnvironment, MaxIdleConns: cfg.TransportMaxIdleConns, MaxIdleConnsPerHost: cfg.TransportMaxIdleConnsPerHost, IdleConnTimeout: time.Duration(cfg.TransportIdleConnTimeoutSecs) * time.Second, TLSHandshakeTimeout: time.Duration(cfg.TransportTLSHandshakeTimeout) * time.Second, ExpectContinueTimeout: 1 * time.Second, } return &GeminiChannel{ logger: logger, httpClient: &http.Client{ Transport: transport, Timeout: 0, }, } } // TransformRequest func (ch *GeminiChannel) TransformRequest(c *gin.Context, requestBody []byte) (newBody []byte, modelName string, err error) { var p struct { Model string `json:"model"` } _ = json.Unmarshal(requestBody, &p) modelName = strings.TrimPrefix(p.Model, "models/") if modelName == "" { modelName = ch.extractModelFromPath(c.Request.URL.Path) } return requestBody, modelName, nil } func (ch *GeminiChannel) extractModelFromPath(path string) string { parts := strings.Split(path, "/") for _, part := range parts { if strings.HasPrefix(part, "gemini-") || strings.HasPrefix(part, "text-") || strings.HasPrefix(part, "embedding-") { modelPart := strings.Split(part, ":")[0] return modelPart } } return "" } func (ch *GeminiChannel) IsOpenAICompatibleRequest(c *gin.Context) bool { path := c.Request.URL.Path return strings.Contains(path, "/v1/chat/completions") || strings.Contains(path, "/v1/embeddings") } func (ch *GeminiChannel) ValidateKey( ctx context.Context, apiKey *models.APIKey, targetURL string, timeout time.Duration, ) *CustomErrors.APIError { client := &http.Client{ Timeout: timeout, } req, err := http.NewRequestWithContext(ctx, "GET", targetURL, nil) if err != nil { return CustomErrors.NewAPIError(CustomErrors.ErrInternalServer, "failed to create validation request: "+err.Error()) } ch.ModifyRequest(req, apiKey) resp, err := client.Do(req) if err != nil { return CustomErrors.NewAPIError(CustomErrors.ErrBadGateway, "failed to send validation request: "+err.Error()) } defer resp.Body.Close() if resp.StatusCode >= 200 && resp.StatusCode < 300 { return nil } errorBody, _ := io.ReadAll(resp.Body) parsedMessage := CustomErrors.ParseUpstreamError(errorBody) return &CustomErrors.APIError{ HTTPStatus: resp.StatusCode, Code: fmt.Sprintf("UPSTREAM_%d", resp.StatusCode), Message: parsedMessage, } } func (ch *GeminiChannel) ModifyRequest(req *http.Request, apiKey *models.APIKey) { // TODO: [Future Refactoring] Decouple auth logic from URL path. // The authentication method (e.g., Bearer token vs. API key in query) should ideally be a property // of the UpstreamEndpoint or a new "AuthProfile" entity, rather than being hardcoded based on URL patterns. // This would make the channel more generic and adaptable to new upstream provider types. if strings.Contains(req.URL.Path, "/v1beta/openai/") { req.Header.Set("Authorization", "Bearer "+apiKey.APIKey) } else { req.Header.Del("Authorization") q := req.URL.Query() q.Set("key", apiKey.APIKey) req.URL.RawQuery = q.Encode() } } func (ch *GeminiChannel) IsStreamRequest(c *gin.Context, bodyBytes []byte) bool { if strings.HasSuffix(c.Request.URL.Path, ":streamGenerateContent") { return true } var meta requestMetadata if err := json.Unmarshal(bodyBytes, &meta); err == nil { return meta.Stream } return false } func (ch *GeminiChannel) ExtractModel(c *gin.Context, bodyBytes []byte) string { _, modelName, _ := ch.TransformRequest(c, bodyBytes) return modelName } func (ch *GeminiChannel) RewritePath(basePath, originalPath string) string { tempCtx := &gin.Context{Request: &http.Request{URL: &url.URL{Path: originalPath}}} var rewrittenSegment string if ch.IsOpenAICompatibleRequest(tempCtx) { var apiEndpoint string v1Index := strings.LastIndex(originalPath, "/v1/") if v1Index != -1 { apiEndpoint = originalPath[v1Index+len("/v1/"):] } else { apiEndpoint = strings.TrimPrefix(originalPath, "/") } rewrittenSegment = "v1beta/openai/" + apiEndpoint } else { tempPath := originalPath if strings.HasPrefix(tempPath, "/v1/") { tempPath = "/v1beta/" + strings.TrimPrefix(tempPath, "/v1/") } rewrittenSegment = strings.TrimPrefix(tempPath, "/") } trimmedBasePath := strings.TrimSuffix(basePath, "/") pathToJoin := rewrittenSegment versionPrefixes := []string{"v1beta", "v1"} for _, prefix := range versionPrefixes { if strings.HasSuffix(trimmedBasePath, "/"+prefix) && strings.HasPrefix(pathToJoin, prefix+"/") { pathToJoin = strings.TrimPrefix(pathToJoin, prefix+"/") break } } finalPath, err := url.JoinPath(trimmedBasePath, pathToJoin) if err != nil { return trimmedBasePath + "/" + strings.TrimPrefix(pathToJoin, "/") } return finalPath } func (ch *GeminiChannel) ModifyResponse(resp *http.Response) error { // 这是一个桩实现,暂时不需要任何逻辑。 return nil } func (ch *GeminiChannel) HandleError(c *gin.Context, err error) { // 这是一个桩实现,暂时不需要任何逻辑。 } // ========================================================== // ================== “智能路由”的核心引擎 =================== // ========================================================== func (ch *GeminiChannel) ProcessSmartStreamRequest(c *gin.Context, params SmartRequestParams) { log := ch.logger.WithField("correlation_id", params.CorrelationID) targetURL, err := url.Parse(params.UpstreamURL) if err != nil { log.WithError(err).Error("Failed to parse upstream URL") errToJSON(c, CustomErrors.NewAPIError(CustomErrors.ErrInternalServer, "Invalid upstream URL format")) return } targetURL.Path = c.Request.URL.Path targetURL.RawQuery = c.Request.URL.RawQuery initialReq, err := http.NewRequestWithContext(c.Request.Context(), "POST", targetURL.String(), bytes.NewReader(params.RequestBody)) if err != nil { log.WithError(err).Error("Failed to create initial smart request") errToJSON(c, CustomErrors.NewAPIError(CustomErrors.ErrInternalServer, err.Error())) return } ch.ModifyRequest(initialReq, params.APIKey) initialReq.Header.Del("Authorization") resp, err := ch.httpClient.Do(initialReq) if err != nil { log.WithError(err).Error("Initial smart request failed") errToJSON(c, CustomErrors.NewAPIError(CustomErrors.ErrBadGateway, err.Error())) return } if resp.StatusCode != http.StatusOK { log.Warnf("Initial request received non-200 status: %d", resp.StatusCode) standardizedResp := ch.standardizeError(resp, params.LogTruncationLimit, log) c.Writer.WriteHeader(standardizedResp.StatusCode) for key, values := range standardizedResp.Header { for _, value := range values { c.Writer.Header().Add(key, value) } } io.Copy(c.Writer, standardizedResp.Body) params.EventLogger.IsSuccess = false params.EventLogger.StatusCode = resp.StatusCode return } ch.processStreamAndRetry(c, initialReq.Header, resp.Body, params, log) } func (ch *GeminiChannel) processStreamAndRetry( c *gin.Context, initialRequestHeaders http.Header, initialReader io.ReadCloser, params SmartRequestParams, log *logrus.Entry, ) { defer initialReader.Close() c.Writer.Header().Set("Content-Type", "text/event-stream; charset=utf-8") c.Writer.Header().Set("Cache-Control", "no-cache") c.Writer.Header().Set("Connection", "keep-alive") flusher, _ := c.Writer.(http.Flusher) var accumulatedText strings.Builder consecutiveRetryCount := 0 currentReader := initialReader maxRetries := params.MaxRetries retryDelay := params.RetryDelay log.Infof("Starting smart stream session. Max retries: %d", maxRetries) for { var interruptionReason string scanner := bufio.NewScanner(currentReader) for scanner.Scan() { line := scanner.Text() if line == "" { continue } fmt.Fprintf(c.Writer, "%s\n\n", line) flusher.Flush() if !strings.HasPrefix(line, "data: ") { continue } data := strings.TrimPrefix(line, "data: ") var payload models.GeminiSSEPayload if err := json.Unmarshal([]byte(data), &payload); err != nil { continue } if len(payload.Candidates) > 0 { candidate := payload.Candidates[0] if candidate.Content != nil && len(candidate.Content.Parts) > 0 { accumulatedText.WriteString(candidate.Content.Parts[0].Text) } if candidate.FinishReason == "STOP" { log.Info("Stream finished successfully with STOP reason.") params.EventLogger.IsSuccess = true return } if candidate.FinishReason != "" { log.Warnf("Stream interrupted with abnormal finish reason: %s", candidate.FinishReason) interruptionReason = candidate.FinishReason break } } } currentReader.Close() if interruptionReason == "" { if err := scanner.Err(); err != nil { log.WithError(err).Warn("Stream scanner encountered an error.") interruptionReason = "SCANNER_ERROR" } else { log.Warn("Stream dropped unexpectedly without a finish reason.") interruptionReason = "CONNECTION_DROP" } } if consecutiveRetryCount >= maxRetries { log.Errorf("Retry limit exceeded. Last interruption: %s. Sending final error.", interruptionReason) errData, _ := json.Marshal(map[string]interface{}{"error": map[string]interface{}{"code": http.StatusGatewayTimeout, "status": "DEADLINE_EXCEEDED", "message": fmt.Sprintf("Proxy retry limit exceeded. Last interruption: %s.", interruptionReason)}}) fmt.Fprintf(c.Writer, "event: error\ndata: %s\n\n", string(errData)) flusher.Flush() return } consecutiveRetryCount++ params.EventLogger.Retries = consecutiveRetryCount log.Infof("Stream interrupted. Attempting retry %d/%d after %v.", consecutiveRetryCount, maxRetries, retryDelay) time.Sleep(retryDelay) retryBody, _ := buildRetryRequestBody(params.OriginalRequest, accumulatedText.String()) retryBodyBytes, _ := json.Marshal(retryBody) retryReq, _ := http.NewRequestWithContext(c.Request.Context(), "POST", params.UpstreamURL, bytes.NewReader(retryBodyBytes)) retryReq.Header = initialRequestHeaders ch.ModifyRequest(retryReq, params.APIKey) retryReq.Header.Del("Authorization") retryResp, err := ch.httpClient.Do(retryReq) if err != nil || retryResp.StatusCode != http.StatusOK || retryResp.Body == nil { if err != nil { log.WithError(err).Errorf("Retry request failed.") } else { log.Errorf("Retry request received non-200 status: %d", retryResp.StatusCode) if retryResp.Body != nil { retryResp.Body.Close() } } continue } currentReader = retryResp.Body } } func buildRetryRequestBody(originalBody models.GeminiRequest, accumulatedText string) (models.GeminiRequest, error) { retryBody := originalBody lastUserIndex := -1 for i := len(retryBody.Contents) - 1; i >= 0; i-- { if retryBody.Contents[i].Role == "user" { lastUserIndex = i break } } history := []models.GeminiContent{ {Role: "model", Parts: []models.Part{{Text: accumulatedText}}}, {Role: "user", Parts: []models.Part{{Text: SmartRetryPrompt}}}, } if lastUserIndex != -1 { newContents := make([]models.GeminiContent, 0, len(retryBody.Contents)+2) newContents = append(newContents, retryBody.Contents[:lastUserIndex+1]...) newContents = append(newContents, history...) newContents = append(newContents, retryBody.Contents[lastUserIndex+1:]...) retryBody.Contents = newContents } else { retryBody.Contents = append(retryBody.Contents, history...) } return retryBody, nil } // =============================================== // ========= 辅助函数区 (继承并强化) ========= // =============================================== type googleAPIError struct { Error struct { Code int `json:"code"` Message string `json:"message"` Status string `json:"status"` Details []interface{} `json:"details,omitempty"` } `json:"error"` } func statusToGoogleStatus(code int) string { switch code { case 400: return "INVALID_ARGUMENT" case 401: return "UNAUTHENTICATED" case 403: return "PERMISSION_DENIED" case 404: return "NOT_FOUND" case 429: return "RESOURCE_EXHAUSTED" case 500: return "INTERNAL" case 503: return "UNAVAILABLE" case 504: return "DEADLINE_EXCEEDED" default: return "UNKNOWN" } } func truncate(s string, n int) string { if n > 0 && len(s) > n { return fmt.Sprintf("%s... [truncated %d chars]", s[:n], len(s)-n) } return s } // standardizeError func (ch *GeminiChannel) standardizeError(resp *http.Response, truncateLimit int, log *logrus.Entry) *http.Response { bodyBytes, err := io.ReadAll(resp.Body) if err != nil { log.WithError(err).Error("Failed to read upstream error body") bodyBytes = []byte("Failed to read upstream error body: " + err.Error()) } resp.Body.Close() log.Errorf("Upstream error body: %s", truncate(string(bodyBytes), truncateLimit)) var standardizedPayload googleAPIError if json.Unmarshal(bodyBytes, &standardizedPayload) != nil || standardizedPayload.Error.Code == 0 { standardizedPayload.Error.Code = resp.StatusCode standardizedPayload.Error.Message = http.StatusText(resp.StatusCode) standardizedPayload.Error.Status = statusToGoogleStatus(resp.StatusCode) standardizedPayload.Error.Details = []interface{}{map[string]string{ "@type": "proxy.upstream.error", "body": truncate(string(bodyBytes), truncateLimit), }} } newBodyBytes, _ := json.Marshal(standardizedPayload) newResp := &http.Response{ StatusCode: resp.StatusCode, Status: resp.Status, Header: http.Header{}, Body: io.NopCloser(bytes.NewReader(newBodyBytes)), } newResp.Header.Set("Content-Type", "application/json; charset=utf-8") newResp.Header.Set("Access-Control-Allow-Origin", "*") return newResp } // errToJSON func errToJSON(c *gin.Context, apiErr *CustomErrors.APIError) { c.JSON(apiErr.HTTPStatus, gin.H{"error": apiErr}) }