// Filename: internal/handlers/proxy_handler.go package handlers import ( "bytes" "compress/gzip" "context" "encoding/json" "fmt" "gemini-balancer/internal/channel" "gemini-balancer/internal/errors" "gemini-balancer/internal/middleware" "gemini-balancer/internal/models" "gemini-balancer/internal/service" "gemini-balancer/internal/settings" "gemini-balancer/internal/store" "io" "net" "net/http" "net/http/httptest" "net/http/httputil" "net/url" "strings" "time" "github.com/gin-gonic/gin" "github.com/google/uuid" "github.com/sirupsen/logrus" "gorm.io/datatypes" ) type proxyErrorKey int const proxyErrKey proxyErrorKey = 0 type ProxyHandler struct { resourceService *service.ResourceService store store.Store settingsManager *settings.SettingsManager groupManager *service.GroupManager channel channel.ChannelProxy logger *logrus.Entry transparentProxy *httputil.ReverseProxy } func NewProxyHandler( resourceService *service.ResourceService, store store.Store, sm *settings.SettingsManager, gm *service.GroupManager, channel channel.ChannelProxy, logger *logrus.Logger, ) *ProxyHandler { ph := &ProxyHandler{ resourceService: resourceService, store: store, settingsManager: sm, groupManager: gm, channel: channel, logger: logger.WithField("component", "ProxyHandler"), transparentProxy: &httputil.ReverseProxy{}, } ph.transparentProxy.Transport = &http.Transport{ Proxy: http.ProxyFromEnvironment, DialContext: (&net.Dialer{ Timeout: 30 * time.Second, KeepAlive: 60 * time.Second, }).DialContext, MaxIdleConns: 100, IdleConnTimeout: 90 * time.Second, TLSHandshakeTimeout: 10 * time.Second, ExpectContinueTimeout: 1 * time.Second, } ph.transparentProxy.ErrorHandler = ph.transparentProxyErrorHandler ph.transparentProxy.BufferPool = &bufferPool{} return ph } func (h *ProxyHandler) HandleProxy(c *gin.Context) { if c.Request.Method == "GET" && (strings.HasSuffix(c.Request.URL.Path, "/models") || strings.HasSuffix(c.Request.URL.Path, "/models/")) { h.handleListModelsRequest(c) return } requestBody, err := io.ReadAll(c.Request.Body) if err != nil { errToJSON(c, uuid.New().String(), errors.NewAPIError(errors.ErrBadRequest, "Failed to read request body")) return } c.Request.Body = io.NopCloser(bytes.NewReader(requestBody)) c.Request.ContentLength = int64(len(requestBody)) modelName := h.channel.ExtractModel(c, requestBody) groupName := c.Param("group_name") isPreciseRouting := groupName != "" if !isPreciseRouting && modelName == "" { errToJSON(c, uuid.New().String(), errors.NewAPIError(errors.ErrBadRequest, "Model not specified in the request body or URL")) return } initialResources, err := h.getResourcesForRequest(c, modelName, groupName, isPreciseRouting) if err != nil { if apiErr, ok := err.(*errors.APIError); ok { errToJSON(c, uuid.New().String(), apiErr) } else { errToJSON(c, uuid.New().String(), errors.NewAPIError(errors.ErrNoKeysAvailable, err.Error())) } return } finalOpConfig, err := h.groupManager.BuildOperationalConfig(initialResources.KeyGroup) if err != nil { h.logger.WithError(err).Error("Failed to build operational config.") errToJSON(c, uuid.New().String(), errors.NewAPIError(errors.ErrInternalServer, "Failed to build operational configuration")) return } isOpenAICompatible := h.channel.IsOpenAICompatibleRequest(c) if isOpenAICompatible { h.serveTransparentProxy(c, requestBody, initialResources, finalOpConfig, modelName, groupName, isPreciseRouting) return } isStream := h.channel.IsStreamRequest(c, requestBody) systemSettings := h.settingsManager.GetSettings() useSmartGateway := finalOpConfig.EnableSmartGateway != nil && *finalOpConfig.EnableSmartGateway if useSmartGateway && isStream && systemSettings.EnableStreamingRetry { h.serveSmartStream(c, requestBody, initialResources, isPreciseRouting) } else { h.serveTransparentProxy(c, requestBody, initialResources, finalOpConfig, modelName, groupName, isPreciseRouting) } } func (h *ProxyHandler) serveTransparentProxy(c *gin.Context, requestBody []byte, initialResources *service.RequestResources, finalOpConfig *models.KeyGroupSettings, modelName, groupName string, isPreciseRouting bool) { startTime := time.Now() correlationID := uuid.New().String() var finalRecorder *httptest.ResponseRecorder var lastUsedResources *service.RequestResources var finalProxyErr *errors.APIError var isSuccess bool var finalPromptTokens, finalCompletionTokens int var actualRetries int = 0 defer func() { if lastUsedResources == nil { h.logger.WithField("id", correlationID).Warn("No resources were used, skipping final log event.") return } finalEvent := h.createLogEvent(c, startTime, correlationID, modelName, lastUsedResources, models.LogTypeFinal, isPreciseRouting) finalEvent.RequestLog.LatencyMs = int(time.Since(startTime).Milliseconds()) finalEvent.RequestLog.IsSuccess = isSuccess finalEvent.RequestLog.Retries = actualRetries if isSuccess { finalEvent.RequestLog.PromptTokens = finalPromptTokens finalEvent.RequestLog.CompletionTokens = finalCompletionTokens } if finalRecorder != nil { finalEvent.RequestLog.StatusCode = finalRecorder.Code } if !isSuccess { if finalProxyErr != nil { finalEvent.Error = finalProxyErr finalEvent.RequestLog.ErrorCode = finalProxyErr.Code finalEvent.RequestLog.ErrorMessage = finalProxyErr.Message } else if finalRecorder != nil { apiErr := errors.NewAPIErrorWithUpstream(finalRecorder.Code, fmt.Sprintf("UPSTREAM_%d", finalRecorder.Code), "Request failed after all retries.") finalEvent.Error = apiErr finalEvent.RequestLog.ErrorCode = apiErr.Code finalEvent.RequestLog.ErrorMessage = apiErr.Message } } eventData, err := json.Marshal(finalEvent) if err != nil { h.logger.WithField("id", correlationID).WithError(err).Error("Failed to marshal final log event.") return } if err := h.store.Publish(context.Background(), models.TopicRequestFinished, eventData); err != nil { h.logger.WithField("id", correlationID).WithError(err).Error("Failed to publish final log event.") } }() var maxRetries int if isPreciseRouting { if finalOpConfig.MaxRetries != nil { maxRetries = *finalOpConfig.MaxRetries } else { maxRetries = h.settingsManager.GetSettings().MaxRetries } } else { maxRetries = h.settingsManager.GetSettings().MaxRetries } totalAttempts := maxRetries + 1 for attempt := 1; attempt <= totalAttempts; attempt++ { if c.Request.Context().Err() != nil { h.logger.WithField("id", correlationID).Info("Client disconnected, aborting retry loop.") if finalProxyErr == nil { finalProxyErr = errors.NewAPIError(errors.ErrBadRequest, "Client connection closed") } break } var currentResources *service.RequestResources var err error if attempt == 1 { currentResources = initialResources } else { actualRetries = attempt - 1 h.logger.WithField("id", correlationID).Infof("Retrying... getting new resources for attempt %d.", attempt) currentResources, err = h.getResourcesForRequest(c, modelName, groupName, isPreciseRouting) if err != nil { h.logger.WithField("id", correlationID).Errorf("Failed to get new resources for retry, aborting: %v", err) finalProxyErr = errors.NewAPIError(errors.ErrNoKeysAvailable, "Failed to get new resources for retry") break } } finalRequestConfig := h.buildFinalRequestConfig(h.settingsManager.GetSettings(), currentResources.RequestConfig) currentResources.RequestConfig = finalRequestConfig lastUsedResources = currentResources h.logger.WithField("id", correlationID).Infof("Attempt %d/%d with KeyID %d...", attempt, totalAttempts, currentResources.APIKey.ID) var attemptErr *errors.APIError var attemptIsSuccess bool recorder := httptest.NewRecorder() attemptStartTime := time.Now() connectTimeout := time.Duration(h.settingsManager.GetSettings().ConnectTimeoutSeconds) * time.Second ctx, cancel := context.WithTimeout(c.Request.Context(), connectTimeout) defer cancel() attemptReq := c.Request.Clone(ctx) attemptReq.Body = io.NopCloser(bytes.NewReader(requestBody)) if currentResources.UpstreamEndpoint == nil || currentResources.UpstreamEndpoint.URL == "" { h.logger.WithField("id", correlationID).Errorf("Attempt %d failed: no upstream URL in resources.", attempt) isSuccess = false finalProxyErr = errors.NewAPIError(errors.ErrInternalServer, "No upstream URL configured for the selected resource") continue } h.transparentProxy.Director = func(req *http.Request) { targetURL, _ := url.Parse(currentResources.UpstreamEndpoint.URL) req.URL.Scheme = targetURL.Scheme req.URL.Host = targetURL.Host req.Host = targetURL.Host var pureClientPath string if isPreciseRouting { proxyPrefix := "/proxy/" + groupName pureClientPath = strings.TrimPrefix(req.URL.Path, proxyPrefix) } else { pureClientPath = req.URL.Path } finalPath := h.channel.RewritePath(targetURL.Path, pureClientPath) req.URL.Path = finalPath h.logger.WithFields(logrus.Fields{ "correlation_id": correlationID, "attempt": attempt, "key_id": currentResources.APIKey.ID, "base_upstream_url": currentResources.UpstreamEndpoint.URL, "final_request_url": req.URL.String(), }).Infof("Director constructed final upstream request URL.") req.Header.Del("Authorization") h.channel.ModifyRequest(req, currentResources.APIKey) req.Header.Set("X-Correlation-ID", correlationID) *req = *req.WithContext(context.WithValue(req.Context(), proxyErrKey, &attemptErr)) } transport := h.transparentProxy.Transport.(*http.Transport) if currentResources.ProxyConfig != nil { proxyURLStr := fmt.Sprintf("%s://%s", currentResources.ProxyConfig.Protocol, currentResources.ProxyConfig.Address) proxyURL, err := url.Parse(proxyURLStr) if err == nil { transport.Proxy = http.ProxyURL(proxyURL) } } else { transport.Proxy = http.ProxyFromEnvironment } h.transparentProxy.ModifyResponse = func(resp *http.Response) error { defer resp.Body.Close() var reader io.ReadCloser var err error isGzipped := resp.Header.Get("Content-Encoding") == "gzip" if isGzipped { reader, err = gzip.NewReader(resp.Body) if err != nil { h.logger.WithError(err).Error("Failed to create gzip reader") reader = resp.Body } else { resp.Header.Del("Content-Encoding") } defer reader.Close() } else { reader = resp.Body } bodyBytes, err := io.ReadAll(reader) if err != nil { attemptErr = errors.NewAPIError(errors.ErrBadGateway, "Failed to read upstream response: "+err.Error()) resp.Body = io.NopCloser(bytes.NewReader([]byte(attemptErr.Message))) return nil } if resp.StatusCode < 400 { attemptIsSuccess = true finalPromptTokens, finalCompletionTokens = extractUsage(bodyBytes) } else { parsedMsg := errors.ParseUpstreamError(bodyBytes) attemptErr = errors.NewAPIErrorWithUpstream(resp.StatusCode, fmt.Sprintf("UPSTREAM_%d", resp.StatusCode), parsedMsg) } resp.Body = io.NopCloser(bytes.NewReader(bodyBytes)) return nil } h.transparentProxy.ServeHTTP(recorder, attemptReq) finalRecorder = recorder finalProxyErr = attemptErr isSuccess = attemptIsSuccess h.resourceService.ReportRequestResult(currentResources, isSuccess, finalProxyErr) if isSuccess { break } isUnretryableError := false if finalProxyErr != nil { if errors.IsUnretryableRequestError(finalProxyErr.Message) { isUnretryableError = true h.logger.WithField("id", correlationID).Warnf("Attempt %d failed with unretryable request error. Aborting retries. Message: %s", attempt, finalProxyErr.Message) } } if attempt >= totalAttempts || isUnretryableError { break } retryEvent := h.createLogEvent(c, startTime, correlationID, modelName, currentResources, models.LogTypeRetry, isPreciseRouting) retryEvent.LatencyMs = int(time.Since(attemptStartTime).Milliseconds()) retryEvent.IsSuccess = false retryEvent.StatusCode = recorder.Code retryEvent.Retries = actualRetries if attemptErr != nil { retryEvent.Error = attemptErr retryEvent.ErrorCode = attemptErr.Code retryEvent.ErrorMessage = attemptErr.Message } eventData, _ := json.Marshal(retryEvent) _ = h.store.Publish(context.Background(), models.TopicRequestFinished, eventData) } if finalRecorder != nil { bodyBytes := finalRecorder.Body.Bytes() c.Writer.Header().Set("Content-Length", fmt.Sprintf("%d", len(bodyBytes))) for k, v := range finalRecorder.Header() { if strings.ToLower(k) != "content-length" { c.Writer.Header()[k] = v } } c.Writer.WriteHeader(finalRecorder.Code) c.Writer.Write(finalRecorder.Body.Bytes()) } else { errToJSON(c, correlationID, finalProxyErr) } } func (h *ProxyHandler) serveSmartStream(c *gin.Context, requestBody []byte, resources *service.RequestResources, isPreciseRouting bool) { startTime := time.Now() correlationID := uuid.New().String() log := h.logger.WithField("id", correlationID) log.Info("Smart Gateway activated for streaming request.") var originalRequest models.GeminiRequest if err := json.Unmarshal(requestBody, &originalRequest); err != nil { errToJSON(c, correlationID, errors.NewAPIError(errors.ErrInvalidJSON, "Smart Gateway failed: Request body is not a valid Gemini native format. Error: "+err.Error())) return } systemSettings := h.settingsManager.GetSettings() modelName := h.channel.ExtractModel(c, requestBody) requestFinishedEvent := h.createLogEvent(c, startTime, correlationID, modelName, resources, models.LogTypeFinal, isPreciseRouting) defer func() { requestFinishedEvent.LatencyMs = int(time.Since(startTime).Milliseconds()) if c.Writer.Status() > 0 { requestFinishedEvent.StatusCode = c.Writer.Status() } eventData, _ := json.Marshal(requestFinishedEvent) _ = h.store.Publish(context.Background(), models.TopicRequestFinished, eventData) }() params := channel.SmartRequestParams{ CorrelationID: correlationID, APIKey: resources.APIKey, UpstreamURL: resources.UpstreamEndpoint.URL, RequestBody: requestBody, OriginalRequest: originalRequest, EventLogger: requestFinishedEvent, MaxRetries: systemSettings.MaxStreamingRetries, RetryDelay: time.Duration(systemSettings.StreamingRetryDelayMs) * time.Millisecond, LogTruncationLimit: systemSettings.LogTruncationLimit, StreamingRetryPrompt: systemSettings.StreamingRetryPrompt, } h.channel.ProcessSmartStreamRequest(c, params) } func (h *ProxyHandler) transparentProxyErrorHandler(rw http.ResponseWriter, r *http.Request, err error) { correlationID := r.Header.Get("X-Correlation-ID") h.logger.WithField("id", correlationID).Errorf("Transparent proxy error: %v", err) proxyErrPtr, exists := r.Context().Value(proxyErrKey).(**errors.APIError) if !exists || proxyErrPtr == nil { h.logger.WithField("id", correlationID).Error("FATAL: proxyErrorKey not found in context for error handler.") return } if errors.IsClientNetworkError(err) { *proxyErrPtr = errors.NewAPIError(errors.ErrBadRequest, "Client connection closed") } else { *proxyErrPtr = errors.NewAPIError(errors.ErrBadGateway, err.Error()) } if _, ok := rw.(*httptest.ResponseRecorder); ok { return } if writer, ok := rw.(interface{ Written() bool }); ok { if writer.Written() { return } } rw.WriteHeader((*proxyErrPtr).HTTPStatus) } func (h *ProxyHandler) createLogEvent(c *gin.Context, startTime time.Time, corrID, modelName string, res *service.RequestResources, logType models.LogType, isPreciseRouting bool) *models.RequestFinishedEvent { event := &models.RequestFinishedEvent{ RequestLog: models.RequestLog{ RequestTime: startTime, ModelName: modelName, RequestPath: c.Request.URL.Path, UserAgent: c.Request.UserAgent(), CorrelationID: corrID, LogType: logType, Metadata: make(datatypes.JSONMap), }, CorrelationID: corrID, IsPreciseRouting: isPreciseRouting, } if _, exists := c.Get(middleware.RedactedBodyKey); exists { event.RequestLog.Metadata["request_body_present"] = true } if redactedAuth, exists := c.Get(middleware.RedactedAuthHeaderKey); exists { event.RequestLog.Metadata["authorization_header"] = redactedAuth.(string) } if authTokenValue, exists := c.Get("authToken"); exists { if authToken, ok := authTokenValue.(*models.AuthToken); ok { event.RequestLog.AuthTokenID = &authToken.ID } } if res != nil { if res.APIKey != nil { event.RequestLog.KeyID = &res.APIKey.ID } if res.KeyGroup != nil { event.RequestLog.GroupID = &res.KeyGroup.ID } if res.UpstreamEndpoint != nil { event.RequestLog.UpstreamID = &res.UpstreamEndpoint.ID event.UpstreamURL = &res.UpstreamEndpoint.URL } if res.ProxyConfig != nil { event.RequestLog.ProxyID = &res.ProxyConfig.ID } } return event } func (h *ProxyHandler) getResourcesForRequest(c *gin.Context, modelName string, groupName string, isPreciseRouting bool) (*service.RequestResources, error) { authTokenValue, exists := c.Get("authToken") if !exists { return nil, errors.NewAPIError(errors.ErrUnauthorized, "Auth token not found in context") } authToken, ok := authTokenValue.(*models.AuthToken) if !ok { return nil, errors.NewAPIError(errors.ErrInternalServer, "Invalid auth token type in context") } if isPreciseRouting { return h.resourceService.GetResourceFromGroup(c.Request.Context(), authToken, groupName) } else { return h.resourceService.GetResourceFromBasePool(c.Request.Context(), authToken, modelName) } } func errToJSON(c *gin.Context, corrID string, apiErr *errors.APIError) { c.JSON(apiErr.HTTPStatus, gin.H{ "error": apiErr, "correlation_id": corrID, }) } type bufferPool struct{} func (b *bufferPool) Get() []byte { return make([]byte, 32*1024) } func (b *bufferPool) Put(bytes []byte) {} func extractUsage(body []byte) (promptTokens int, completionTokens int) { var data struct { UsageMetadata struct { PromptTokenCount int `json:"promptTokenCount"` CandidatesTokenCount int `json:"candidatesTokenCount"` } `json:"usageMetadata"` } if err := json.Unmarshal(body, &data); err == nil { return data.UsageMetadata.PromptTokenCount, data.UsageMetadata.CandidatesTokenCount } return 0, 0 } func (h *ProxyHandler) buildFinalRequestConfig(globalSettings *models.SystemSettings, groupConfig *models.RequestConfig) *models.RequestConfig { customHeadersJSON, _ := json.Marshal(globalSettings.CustomHeaders) var customHeadersMap datatypes.JSONMap _ = json.Unmarshal(customHeadersJSON, &customHeadersMap) finalConfig := &models.RequestConfig{ CustomHeaders: customHeadersMap, EnableStreamOptimizer: globalSettings.EnableStreamOptimizer, StreamMinDelay: globalSettings.StreamMinDelay, StreamMaxDelay: globalSettings.StreamMaxDelay, StreamShortTextThresh: globalSettings.StreamShortTextThresh, StreamLongTextThresh: globalSettings.StreamLongTextThresh, StreamChunkSize: globalSettings.StreamChunkSize, EnableFakeStream: globalSettings.EnableFakeStream, FakeStreamInterval: globalSettings.FakeStreamInterval, } if groupConfig == nil { return finalConfig } groupConfigJSON, err := json.Marshal(groupConfig) if err != nil { h.logger.WithError(err).Error("Failed to marshal group request config for merging.") return finalConfig } if err := json.Unmarshal(groupConfigJSON, finalConfig); err != nil { h.logger.WithError(err).Error("Failed to unmarshal group request config for merging.") return finalConfig } return finalConfig } func (h *ProxyHandler) handleListModelsRequest(c *gin.Context) { authTokenValue, exists := c.Get("authToken") if !exists { errToJSON(c, uuid.New().String(), errors.NewAPIError(errors.ErrUnauthorized, "Auth token not found in context")) return } authToken, ok := authTokenValue.(*models.AuthToken) if !ok { errToJSON(c, uuid.New().String(), errors.NewAPIError(errors.ErrInternalServer, "Invalid auth token type in context")) return } modelNames := h.resourceService.GetAllowedModelsForToken(authToken) if strings.Contains(c.Request.URL.Path, "/v1beta/") { h.respondWithGeminiFormat(c, modelNames) } else { h.respondWithOpenAIFormat(c, modelNames) } } func (h *ProxyHandler) respondWithOpenAIFormat(c *gin.Context, modelNames []string) { type ModelEntry struct { ID string `json:"id"` Object string `json:"object"` Created int64 `json:"created"` OwnedBy string `json:"owned_by"` } type ModelListResponse struct { Object string `json:"object"` Data []ModelEntry `json:"data"` } data := make([]ModelEntry, len(modelNames)) for i, name := range modelNames { data[i] = ModelEntry{ ID: name, Object: "model", Created: time.Now().Unix(), OwnedBy: "gemini-balancer", } } response := ModelListResponse{ Object: "list", Data: data, } c.JSON(http.StatusOK, response) } func (h *ProxyHandler) respondWithGeminiFormat(c *gin.Context, modelNames []string) { type GeminiModelEntry struct { Name string `json:"name"` Version string `json:"version"` DisplayName string `json:"displayName"` Description string `json:"description"` SupportedGenerationMethods []string `json:"supportedGenerationMethods"` InputTokenLimit int `json:"inputTokenLimit"` OutputTokenLimit int `json:"outputTokenLimit"` } type GeminiModelListResponse struct { Models []GeminiModelEntry `json:"models"` } models := make([]GeminiModelEntry, len(modelNames)) for i, name := range modelNames { models[i] = GeminiModelEntry{ Name: fmt.Sprintf("models/%s", name), Version: "1.0.0", DisplayName: name, Description: "Served by Gemini Balancer", SupportedGenerationMethods: []string{"generateContent", "streamGenerateContent"}, InputTokenLimit: 8192, OutputTokenLimit: 2048, } } response := GeminiModelListResponse{Models: models} c.JSON(http.StatusOK, response) }