internal/channel/channel.go (new file, +47)
@@ -0,0 +1,47 @@
// Filename: internal/channel/channel.go
package channel

import (
    "context"
    "gemini-balancer/internal/errors"
    "gemini-balancer/internal/models"
    "net/http"
    "time"

    "github.com/gin-gonic/gin"
)

// ChannelProxy is the unified interface that every protocol channel must implement.
type ChannelProxy interface {
    // RewritePath is the core path-rewriting hook; it converts the client path into the path the upstream expects.
    RewritePath(basePath, originalPath string) string

    // TransformRequest prepares the request body for the upstream and resolves the model name.
    TransformRequest(c *gin.Context, requestBody []byte) (newBody []byte, modelName string, err error)
    // IsStreamRequest reports whether the request is a streaming request.
    IsStreamRequest(c *gin.Context, bodyBytes []byte) bool
    // IsOpenAICompatibleRequest reports whether the request uses an OpenAI-compatible path.
    IsOpenAICompatibleRequest(c *gin.Context) bool
    // ExtractModel extracts the model name from the request.
    ExtractModel(c *gin.Context, bodyBytes []byte) string
    // ValidateKey checks whether the API key is valid against the target URL.
    ValidateKey(ctx context.Context, apiKey *models.APIKey, targetURL string, timeout time.Duration) *errors.APIError
    // ModifyRequest mutates the request before it is sent upstream (e.g. attaches authentication).
    ModifyRequest(req *http.Request, apiKey *models.APIKey)
    // ProcessSmartStreamRequest handles the core streaming proxy request.
    ProcessSmartStreamRequest(c *gin.Context, params SmartRequestParams)
    // Other non-streaming handlers, etc.
}

// SmartRequestParams is a parameter container that passes all high-level dependencies
// down to the lower layers in a single, clean step.
type SmartRequestParams struct {
    CorrelationID        string
    APIKey               *models.APIKey
    UpstreamURL          string
    RequestBody          []byte
    OriginalRequest      models.GeminiRequest
    EventLogger          *models.RequestFinishedEvent
    MaxRetries           int
    RetryDelay           time.Duration
    LogTruncationLimit   int
    StreamingRetryPrompt string // <--- carries the stream-continuation prompt
}
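
A minimal usage sketch (not part of this commit): it shows how a higher-level proxy handler might drive a ChannelProxy implementation through SmartRequestParams. The handler name proxyHandler, the hardcoded upstream base URL, the retry values, and the "correlation_id" context key are assumptions for illustration; only the interface and struct fields come from the file above.

// Illustrative sketch only, assumed to live alongside channel.go in package channel.
package channel

import (
    "encoding/json"
    "io"
    "time"

    "gemini-balancer/internal/models"

    "github.com/gin-gonic/gin"
)

// proxyHandler is a hypothetical caller: it reads the body once, captures the
// original request for retry bookkeeping, and hands streaming requests to the channel.
func proxyHandler(ch ChannelProxy, key *models.APIKey) gin.HandlerFunc {
    return func(c *gin.Context) {
        body, _ := io.ReadAll(c.Request.Body)

        var original models.GeminiRequest
        _ = json.Unmarshal(body, &original)

        if !ch.IsStreamRequest(c, body) {
            return // non-streaming handling is out of scope for this sketch
        }

        ch.ProcessSmartStreamRequest(c, SmartRequestParams{
            CorrelationID:        c.GetString("correlation_id"), // assumed middleware key
            APIKey:               key,
            UpstreamURL:          ch.RewritePath("https://generativelanguage.googleapis.com", c.Request.URL.Path),
            RequestBody:          body,
            OriginalRequest:      original,
            EventLogger:          &models.RequestFinishedEvent{},
            MaxRetries:           3,
            RetryDelay:           2 * time.Second,
            LogTruncationLimit:   1024,
            StreamingRetryPrompt: "Continue exactly where you left off...",
        })
    }
}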

internal/channel/gemini_channel.go (new file, +434)
@@ -0,0 +1,434 @@
// Filename: internal/channel/gemini_channel.go
package channel

import (
    "bufio"
    "bytes"
    "context"
    "encoding/json"
    "fmt"
    CustomErrors "gemini-balancer/internal/errors"
    "gemini-balancer/internal/models"
    "io"
    "net/http"
    "net/url"
    "strings"
    "time"

    "github.com/gin-gonic/gin"
    "github.com/sirupsen/logrus"
)

const SmartRetryPrompt = "Continue exactly where you left off..."

var _ ChannelProxy = (*GeminiChannel)(nil)

type GeminiChannel struct {
    logger     *logrus.Logger
    httpClient *http.Client
}

// Local struct used to safely extract request metadata.
type requestMetadata struct {
    Model  string `json:"model"`
    Stream bool   `json:"stream"`
}

// NewGeminiChannel builds a GeminiChannel with a shared HTTP transport configured from the system settings.
func NewGeminiChannel(logger *logrus.Logger, cfg *models.SystemSettings) *GeminiChannel {
    transport := &http.Transport{
        Proxy:                 http.ProxyFromEnvironment,
        MaxIdleConns:          cfg.TransportMaxIdleConns,
        MaxIdleConnsPerHost:   cfg.TransportMaxIdleConnsPerHost,
        IdleConnTimeout:       time.Duration(cfg.TransportIdleConnTimeoutSecs) * time.Second,
        TLSHandshakeTimeout:   time.Duration(cfg.TransportTLSHandshakeTimeout) * time.Second,
        ExpectContinueTimeout: 1 * time.Second,
    }
    return &GeminiChannel{
        logger: logger,
        httpClient: &http.Client{
            Transport: transport,
            Timeout:   0,
        },
    }
}

// TransformRequest resolves the model name from the request body, falling back to the URL path;
// the body itself is returned unchanged.
func (ch *GeminiChannel) TransformRequest(c *gin.Context, requestBody []byte) (newBody []byte, modelName string, err error) {
    var p struct {
        Model string `json:"model"`
    }
    _ = json.Unmarshal(requestBody, &p)
    modelName = strings.TrimPrefix(p.Model, "models/")

    if modelName == "" {
        modelName = ch.extractModelFromPath(c.Request.URL.Path)
    }
    return requestBody, modelName, nil
}

func (ch *GeminiChannel) extractModelFromPath(path string) string {
    parts := strings.Split(path, "/")
    for _, part := range parts {
        if strings.HasPrefix(part, "gemini-") || strings.HasPrefix(part, "text-") || strings.HasPrefix(part, "embedding-") {
            modelPart := strings.Split(part, ":")[0]
            return modelPart
        }
    }
    return ""
}

func (ch *GeminiChannel) IsOpenAICompatibleRequest(c *gin.Context) bool {
    path := c.Request.URL.Path
    return strings.Contains(path, "/v1/chat/completions") || strings.Contains(path, "/v1/embeddings")
}

func (ch *GeminiChannel) ValidateKey(
    ctx context.Context,
    apiKey *models.APIKey,
    targetURL string,
    timeout time.Duration,
) *CustomErrors.APIError {
    client := &http.Client{
        Timeout: timeout,
    }
    req, err := http.NewRequestWithContext(ctx, "GET", targetURL, nil)
    if err != nil {
        return CustomErrors.NewAPIError(CustomErrors.ErrInternalServer, "failed to create validation request: "+err.Error())
    }

    ch.ModifyRequest(req, apiKey)
    resp, err := client.Do(req)
    if err != nil {
        return CustomErrors.NewAPIError(CustomErrors.ErrBadGateway, "failed to send validation request: "+err.Error())
    }
    defer resp.Body.Close()
    if resp.StatusCode >= 200 && resp.StatusCode < 300 {
        return nil
    }
    errorBody, _ := io.ReadAll(resp.Body)
    parsedMessage := CustomErrors.ParseUpstreamError(errorBody)
    return &CustomErrors.APIError{
        HTTPStatus: resp.StatusCode,
        Code:       fmt.Sprintf("UPSTREAM_%d", resp.StatusCode),
        Message:    parsedMessage,
    }
}

func (ch *GeminiChannel) ModifyRequest(req *http.Request, apiKey *models.APIKey) {
    // TODO: [Future Refactoring] Decouple auth logic from URL path.
    // The authentication method (e.g., Bearer token vs. API key in query) should ideally be a property
    // of the UpstreamEndpoint or a new "AuthProfile" entity, rather than being hardcoded based on URL patterns.
    // This would make the channel more generic and adaptable to new upstream provider types.
    if strings.Contains(req.URL.Path, "/v1beta/openai/") {
        req.Header.Set("Authorization", "Bearer "+apiKey.APIKey)
    } else {
        req.Header.Del("Authorization")
        q := req.URL.Query()
        q.Set("key", apiKey.APIKey)
        req.URL.RawQuery = q.Encode()
    }
}
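
// Illustrative behaviour of ModifyRequest (the concrete paths below are assumed
// examples, not taken from this commit):
//
//     /v1beta/openai/chat/completions            -> sets "Authorization: Bearer <key>"
//     /v1beta/models/gemini-pro:generateContent  -> removes any Authorization header and
//                                                   appends "?key=<key>" to the query string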

func (ch *GeminiChannel) IsStreamRequest(c *gin.Context, bodyBytes []byte) bool {
    if strings.HasSuffix(c.Request.URL.Path, ":streamGenerateContent") {
        return true
    }
    var meta requestMetadata
    if err := json.Unmarshal(bodyBytes, &meta); err == nil {
        return meta.Stream
    }
    return false
}

func (ch *GeminiChannel) ExtractModel(c *gin.Context, bodyBytes []byte) string {
    _, modelName, _ := ch.TransformRequest(c, bodyBytes)
    return modelName
}

func (ch *GeminiChannel) RewritePath(basePath, originalPath string) string {
    tempCtx := &gin.Context{Request: &http.Request{URL: &url.URL{Path: originalPath}}}
    var rewrittenSegment string
    if ch.IsOpenAICompatibleRequest(tempCtx) {
        var apiEndpoint string
        v1Index := strings.LastIndex(originalPath, "/v1/")
        if v1Index != -1 {
            apiEndpoint = originalPath[v1Index+len("/v1/"):]
        } else {
            apiEndpoint = strings.TrimPrefix(originalPath, "/")
        }
        rewrittenSegment = "v1beta/openai/" + apiEndpoint
    } else {
        tempPath := originalPath
        if strings.HasPrefix(tempPath, "/v1/") {
            tempPath = "/v1beta/" + strings.TrimPrefix(tempPath, "/v1/")
        }
        rewrittenSegment = strings.TrimPrefix(tempPath, "/")
    }
    trimmedBasePath := strings.TrimSuffix(basePath, "/")
    pathToJoin := rewrittenSegment

    versionPrefixes := []string{"v1beta", "v1"}
    for _, prefix := range versionPrefixes {
        if strings.HasSuffix(trimmedBasePath, "/"+prefix) && strings.HasPrefix(pathToJoin, prefix+"/") {
            pathToJoin = strings.TrimPrefix(pathToJoin, prefix+"/")
            break
        }
    }
    finalPath, err := url.JoinPath(trimmedBasePath, pathToJoin)
    if err != nil {
        return trimmedBasePath + "/" + strings.TrimPrefix(pathToJoin, "/")
    }
    return finalPath
}
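
// Illustrative path mappings for RewritePath (assumed examples, derived by tracing
// the function above; the base URL is an assumption):
//
//     RewritePath("https://generativelanguage.googleapis.com", "/v1/chat/completions")
//         -> "https://generativelanguage.googleapis.com/v1beta/openai/chat/completions"
//     RewritePath("https://generativelanguage.googleapis.com", "/v1/models/gemini-pro:streamGenerateContent")
//         -> "https://generativelanguage.googleapis.com/v1beta/models/gemini-pro:streamGenerateContent"
//
// If the base path already ends in "/v1beta", the duplicated version segment is trimmed:
//
//     RewritePath("https://generativelanguage.googleapis.com/v1beta", "/v1/embeddings")
//         -> "https://generativelanguage.googleapis.com/v1beta/openai/embeddings"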

func (ch *GeminiChannel) ModifyResponse(resp *http.Response) error {
    // Stub implementation; no logic is needed yet.
    return nil
}

func (ch *GeminiChannel) HandleError(c *gin.Context, err error) {
    // Stub implementation; no logic is needed yet.
}

// ==========================================================
// ============ Core engine of the "smart routing" ==========
// ==========================================================
func (ch *GeminiChannel) ProcessSmartStreamRequest(c *gin.Context, params SmartRequestParams) {
    log := ch.logger.WithField("correlation_id", params.CorrelationID)

    targetURL, err := url.Parse(params.UpstreamURL)
    if err != nil {
        log.WithError(err).Error("Failed to parse upstream URL")
        errToJSON(c, CustomErrors.NewAPIError(CustomErrors.ErrInternalServer, "Invalid upstream URL format"))
        return
    }
    targetURL.Path = c.Request.URL.Path
    targetURL.RawQuery = c.Request.URL.RawQuery
    initialReq, err := http.NewRequestWithContext(c.Request.Context(), "POST", targetURL.String(), bytes.NewReader(params.RequestBody))
    if err != nil {
        log.WithError(err).Error("Failed to create initial smart request")
        errToJSON(c, CustomErrors.NewAPIError(CustomErrors.ErrInternalServer, err.Error()))
        return
    }
    ch.ModifyRequest(initialReq, params.APIKey)
    initialReq.Header.Del("Authorization")
    resp, err := ch.httpClient.Do(initialReq)
    if err != nil {
        log.WithError(err).Error("Initial smart request failed")
        errToJSON(c, CustomErrors.NewAPIError(CustomErrors.ErrBadGateway, err.Error()))
        return
    }
    if resp.StatusCode != http.StatusOK {
        log.Warnf("Initial request received non-200 status: %d", resp.StatusCode)
        standardizedResp := ch.standardizeError(resp, params.LogTruncationLimit, log)
        // Headers must be copied before WriteHeader, otherwise they are silently dropped.
        for key, values := range standardizedResp.Header {
            for _, value := range values {
                c.Writer.Header().Add(key, value)
            }
        }
        c.Writer.WriteHeader(standardizedResp.StatusCode)
        io.Copy(c.Writer, standardizedResp.Body)
        params.EventLogger.IsSuccess = false
        params.EventLogger.StatusCode = resp.StatusCode
        return
    }
    ch.processStreamAndRetry(c, initialReq.Header, resp.Body, params, log)
}

func (ch *GeminiChannel) processStreamAndRetry(
    c *gin.Context, initialRequestHeaders http.Header, initialReader io.ReadCloser, params SmartRequestParams, log *logrus.Entry,
) {
    defer initialReader.Close()
    c.Writer.Header().Set("Content-Type", "text/event-stream; charset=utf-8")
    c.Writer.Header().Set("Cache-Control", "no-cache")
    c.Writer.Header().Set("Connection", "keep-alive")
    flusher, _ := c.Writer.(http.Flusher)
    var accumulatedText strings.Builder
    consecutiveRetryCount := 0
    currentReader := initialReader
    maxRetries := params.MaxRetries
    retryDelay := params.RetryDelay
    log.Infof("Starting smart stream session. Max retries: %d", maxRetries)
    for {
        var interruptionReason string
        scanner := bufio.NewScanner(currentReader)
        for scanner.Scan() {
            line := scanner.Text()
            if line == "" {
                continue
            }
            fmt.Fprintf(c.Writer, "%s\n\n", line)
            flusher.Flush()
            if !strings.HasPrefix(line, "data: ") {
                continue
            }
            data := strings.TrimPrefix(line, "data: ")
            var payload models.GeminiSSEPayload
            if err := json.Unmarshal([]byte(data), &payload); err != nil {
                continue
            }
            if len(payload.Candidates) > 0 {
                candidate := payload.Candidates[0]
                if candidate.Content != nil && len(candidate.Content.Parts) > 0 {
                    accumulatedText.WriteString(candidate.Content.Parts[0].Text)
                }
                if candidate.FinishReason == "STOP" {
                    log.Info("Stream finished successfully with STOP reason.")
                    params.EventLogger.IsSuccess = true
                    return
                }
                if candidate.FinishReason != "" {
                    log.Warnf("Stream interrupted with abnormal finish reason: %s", candidate.FinishReason)
                    interruptionReason = candidate.FinishReason
                    break
                }
            }
        }
        currentReader.Close()
        if interruptionReason == "" {
            if err := scanner.Err(); err != nil {
                log.WithError(err).Warn("Stream scanner encountered an error.")
                interruptionReason = "SCANNER_ERROR"
            } else {
                log.Warn("Stream dropped unexpectedly without a finish reason.")
                interruptionReason = "CONNECTION_DROP"
            }
        }
        if consecutiveRetryCount >= maxRetries {
            log.Errorf("Retry limit exceeded. Last interruption: %s. Sending final error.", interruptionReason)
            errData, _ := json.Marshal(map[string]interface{}{"error": map[string]interface{}{"code": http.StatusGatewayTimeout, "status": "DEADLINE_EXCEEDED", "message": fmt.Sprintf("Proxy retry limit exceeded. Last interruption: %s.", interruptionReason)}})
            fmt.Fprintf(c.Writer, "event: error\ndata: %s\n\n", string(errData))
            flusher.Flush()
            return
        }
        consecutiveRetryCount++
        params.EventLogger.Retries = consecutiveRetryCount
        log.Infof("Stream interrupted. Attempting retry %d/%d after %v.", consecutiveRetryCount, maxRetries, retryDelay)
        time.Sleep(retryDelay)
        retryBody, _ := buildRetryRequestBody(params.OriginalRequest, accumulatedText.String())
        retryBodyBytes, _ := json.Marshal(retryBody)

        retryReq, _ := http.NewRequestWithContext(c.Request.Context(), "POST", params.UpstreamURL, bytes.NewReader(retryBodyBytes))
        retryReq.Header = initialRequestHeaders
        ch.ModifyRequest(retryReq, params.APIKey)
        retryReq.Header.Del("Authorization")

        retryResp, err := ch.httpClient.Do(retryReq)
        if err != nil || retryResp.StatusCode != http.StatusOK || retryResp.Body == nil {
            if err != nil {
                log.WithError(err).Errorf("Retry request failed.")
            } else {
                log.Errorf("Retry request received non-200 status: %d", retryResp.StatusCode)
                if retryResp.Body != nil {
                    retryResp.Body.Close()
                }
            }
            continue
        }
        currentReader = retryResp.Body
    }
}

func buildRetryRequestBody(originalBody models.GeminiRequest, accumulatedText string) (models.GeminiRequest, error) {
    retryBody := originalBody
    lastUserIndex := -1
    for i := len(retryBody.Contents) - 1; i >= 0; i-- {
        if retryBody.Contents[i].Role == "user" {
            lastUserIndex = i
            break
        }
    }
    history := []models.GeminiContent{
        {Role: "model", Parts: []models.Part{{Text: accumulatedText}}},
        {Role: "user", Parts: []models.Part{{Text: SmartRetryPrompt}}},
    }
    if lastUserIndex != -1 {
        newContents := make([]models.GeminiContent, 0, len(retryBody.Contents)+2)
        newContents = append(newContents, retryBody.Contents[:lastUserIndex+1]...)
        newContents = append(newContents, history...)
        newContents = append(newContents, retryBody.Contents[lastUserIndex+1:]...)
        retryBody.Contents = newContents
    } else {
        retryBody.Contents = append(retryBody.Contents, history...)
    }
    return retryBody, nil
}
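
// Illustrative example of buildRetryRequestBody (assumed request contents, not
// taken from this commit): given an original request whose contents are
//
//     [ {role: "user", text: "Write a long story"} ]
//
// and an accumulated partial answer "Once upon a time...", the retry body becomes
//
//     [ {role: "user",  text: "Write a long story"},
//       {role: "model", text: "Once upon a time..."},
//       {role: "user",  text: SmartRetryPrompt} ]
//
// so the upstream model is asked to continue exactly where the stream was cut off.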

// ===============================================
// ==== Helper functions (inherited and hardened) ====
// ===============================================

type googleAPIError struct {
    Error struct {
        Code    int           `json:"code"`
        Message string        `json:"message"`
        Status  string        `json:"status"`
        Details []interface{} `json:"details,omitempty"`
    } `json:"error"`
}

func statusToGoogleStatus(code int) string {
    switch code {
    case 400:
        return "INVALID_ARGUMENT"
    case 401:
        return "UNAUTHENTICATED"
    case 403:
        return "PERMISSION_DENIED"
    case 404:
        return "NOT_FOUND"
    case 429:
        return "RESOURCE_EXHAUSTED"
    case 500:
        return "INTERNAL"
    case 503:
        return "UNAVAILABLE"
    case 504:
        return "DEADLINE_EXCEEDED"
    default:
        return "UNKNOWN"
    }
}

func truncate(s string, n int) string {
    if n > 0 && len(s) > n {
        return fmt.Sprintf("%s... [truncated %d chars]", s[:n], len(s)-n)
    }
    return s
}

// standardizeError converts an arbitrary upstream error response into a Google-style JSON error payload.
func (ch *GeminiChannel) standardizeError(resp *http.Response, truncateLimit int, log *logrus.Entry) *http.Response {
    bodyBytes, err := io.ReadAll(resp.Body)
    if err != nil {
        log.WithError(err).Error("Failed to read upstream error body")
        bodyBytes = []byte("Failed to read upstream error body: " + err.Error())
    }
    resp.Body.Close()
    log.Errorf("Upstream error body: %s", truncate(string(bodyBytes), truncateLimit))
    var standardizedPayload googleAPIError
    if json.Unmarshal(bodyBytes, &standardizedPayload) != nil || standardizedPayload.Error.Code == 0 {
        standardizedPayload.Error.Code = resp.StatusCode
        standardizedPayload.Error.Message = http.StatusText(resp.StatusCode)
        standardizedPayload.Error.Status = statusToGoogleStatus(resp.StatusCode)
        standardizedPayload.Error.Details = []interface{}{map[string]string{
            "@type": "proxy.upstream.error",
            "body":  truncate(string(bodyBytes), truncateLimit),
        }}
    }
    newBodyBytes, _ := json.Marshal(standardizedPayload)
    newResp := &http.Response{
        StatusCode: resp.StatusCode,
        Status:     resp.Status,
        Header:     http.Header{},
        Body:       io.NopCloser(bytes.NewReader(newBodyBytes)),
    }
    newResp.Header.Set("Content-Type", "application/json; charset=utf-8")
    newResp.Header.Set("Access-Control-Allow-Origin", "*")
    return newResp
}

// errToJSON writes an APIError to the client as a JSON error response.
func errToJSON(c *gin.Context, apiErr *CustomErrors.APIError) {
    c.JSON(apiErr.HTTPStatus, gin.H{"error": apiErr})
}
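
A minimal construction sketch (not part of this commit), assumed to sit in the same package as the file above: it only shows that GeminiChannel satisfies ChannelProxy and how the transport settings used by NewGeminiChannel would be supplied. The helper name newDefaultChannel and the concrete values are assumptions.

// newDefaultChannel wires a GeminiChannel behind the ChannelProxy interface
// using illustrative transport settings.
func newDefaultChannel() ChannelProxy {
    cfg := &models.SystemSettings{
        TransportMaxIdleConns:        100,
        TransportMaxIdleConnsPerHost: 10,
        TransportIdleConnTimeoutSecs: 90,
        TransportTLSHandshakeTimeout: 10,
    }
    return NewGeminiChannel(logrus.New(), cfg)
}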