Initial commit
50
internal/handlers/api_auth_handler.go
Normal file
@@ -0,0 +1,50 @@
// Filename: internal/handlers/api_auth_handler.go
package handlers

import (
    "gemini-balancer/internal/middleware"
    "gemini-balancer/internal/service"
    "net/http"

    "github.com/gin-gonic/gin"
)

type APIAuthHandler struct {
    securityService *service.SecurityService
}

func NewAPIAuthHandler(securityService *service.SecurityService) *APIAuthHandler {
    return &APIAuthHandler{securityService: securityService}
}

type LoginRequest struct {
    Token string `json:"token" binding:"required"`
}

type LoginResponse struct {
    Token   string `json:"token"`
    Message string `json:"message"`
}

func (h *APIAuthHandler) HandleLogin(c *gin.Context) {
    var req LoginRequest
    if err := c.ShouldBindJSON(&req); err != nil {
        c.JSON(http.StatusBadRequest, gin.H{"error": "请求格式错误: " + err.Error()})
        return
    }

    authToken, err := h.securityService.AuthenticateToken(req.Token)
    // Check both that the token is valid and that it belongs to an admin.
    if err != nil || !authToken.IsAdmin {
        h.securityService.RecordFailedLoginAttempt(c.Request.Context(), c.ClientIP())
        c.JSON(http.StatusUnauthorized, gin.H{"error": "无效或非管理员Token"})
        return
    }

    middleware.SetAdminSessionCookie(c, authToken.Token)

    c.JSON(http.StatusOK, LoginResponse{
        Token:   authToken.Token,
        Message: "登录成功,欢迎管理员!",
    })
}
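A minimal wiring sketch, assuming the handler is mounted from a router setup elsewhere in this module; the package name and the /api/auth/login route path are assumptions for illustration, not part of this commit.

// Hypothetical route registration sketch; path and package name are assumptions.
package router

import (
    "gemini-balancer/internal/handlers"
    "gemini-balancer/internal/service"

    "github.com/gin-gonic/gin"
)

// registerAuthRoutes shows how the login handler could be mounted on a Gin engine.
func registerAuthRoutes(r *gin.Engine, securityService *service.SecurityService) {
    authHandler := handlers.NewAPIAuthHandler(securityService)
    r.POST("/api/auth/login", authHandler.HandleLogin) // assumed route path
}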
408
internal/handlers/apikey_handler.go
Normal file
@@ -0,0 +1,408 @@
// Filename: internal/handlers/apikey_handler.go
package handlers

import (
    "gemini-balancer/internal/errors"
    "gemini-balancer/internal/models"
    "gemini-balancer/internal/response"
    "gemini-balancer/internal/service"
    "gemini-balancer/internal/task"
    "strconv"

    "github.com/gin-gonic/gin"
    "gorm.io/gorm"
)

type APIKeyHandler struct {
    apiKeyService        *service.APIKeyService
    db                   *gorm.DB
    keyImportService     *service.KeyImportService
    keyValidationService *service.KeyValidationService
}

func NewAPIKeyHandler(apiKeyService *service.APIKeyService, db *gorm.DB, keyImportService *service.KeyImportService, keyValidationService *service.KeyValidationService) *APIKeyHandler {
    return &APIKeyHandler{
        apiKeyService:        apiKeyService,
        db:                   db,
        keyImportService:     keyImportService,
        keyValidationService: keyValidationService,
    }
}

// DTOs for API requests
type BulkAddKeysToGroupRequest struct {
    KeyGroupID       uint   `json:"key_group_id" binding:"required"`
    Keys             string `json:"keys" binding:"required"`
    ValidateOnImport bool   `json:"validate_on_import"` // OmitEmpty/default is false
}

type BulkUnlinkKeysFromGroupRequest struct {
    KeyGroupID uint   `json:"key_group_id" binding:"required"`
    Keys       string `json:"keys" binding:"required"`
}

type BulkHardDeleteKeysRequest struct {
    Keys string `json:"keys" binding:"required"`
}

type BulkRestoreKeysRequest struct {
    Keys string `json:"keys" binding:"required"`
}

type UpdateAPIKeyRequest struct {
    Status *string `json:"status" binding:"omitempty,oneof=ACTIVE PENDING_VALIDATION COOLDOWN DISABLED BANNED"`
}

type UpdateMappingRequest struct {
    Status models.APIKeyStatus `json:"status" binding:"required,oneof=ACTIVE PENDING_VALIDATION COOLDOWN DISABLED BANNED"`
}

type BulkTestKeysRequest struct {
    KeyGroupID uint   `json:"key_group_id" binding:"required"`
    Keys       string `json:"keys" binding:"required"`
}

type RestoreKeysRequest struct {
    KeyIDs []uint `json:"key_ids" binding:"required,gt=0"`
}

type BulkTestKeysForGroupRequest struct {
    Keys string `json:"keys" binding:"required"`
}

type BulkActionFilter struct {
    Status []string `json:"status"` // Changed to slice to accept multiple statuses
}

type BulkActionRequest struct {
    Action    string           `json:"action" binding:"required,oneof=revalidate set_status delete"`
    NewStatus string           `json:"new_status" binding:"omitempty,oneof=active disabled cooldown banned"` // For 'set_status' action
    Filter    BulkActionFilter `json:"filter" binding:"required"`
}

// --- Handler Methods ---

// AddMultipleKeysToGroup handles adding/linking multiple keys to a specific group.
func (h *APIKeyHandler) AddMultipleKeysToGroup(c *gin.Context) {
    var req BulkAddKeysToGroupRequest
    if err := c.ShouldBindJSON(&req); err != nil {
        response.Error(c, errors.NewAPIError(errors.ErrInvalidJSON, err.Error()))
        return
    }
    taskStatus, err := h.keyImportService.StartAddKeysTask(req.KeyGroupID, req.Keys, req.ValidateOnImport)
    if err != nil {
        response.Error(c, errors.NewAPIError(errors.ErrTaskInProgress, err.Error()))
        return
    }
    response.Success(c, taskStatus)
}

// UnlinkMultipleKeysFromGroup handles unlinking multiple keys from a specific group.
func (h *APIKeyHandler) UnlinkMultipleKeysFromGroup(c *gin.Context) {
    var req BulkUnlinkKeysFromGroupRequest
    if err := c.ShouldBindJSON(&req); err != nil {
        response.Error(c, errors.NewAPIError(errors.ErrInvalidJSON, err.Error()))
        return
    }
    taskStatus, err := h.keyImportService.StartUnlinkKeysTask(req.KeyGroupID, req.Keys)
    if err != nil {
        response.Error(c, errors.NewAPIError(errors.ErrTaskInProgress, err.Error()))
        return
    }
    response.Success(c, taskStatus)
}

// HardDeleteMultipleKeys handles globally deleting multiple key entities.
func (h *APIKeyHandler) HardDeleteMultipleKeys(c *gin.Context) {
    var req BulkHardDeleteKeysRequest
    if err := c.ShouldBindJSON(&req); err != nil {
        response.Error(c, errors.NewAPIError(errors.ErrInvalidJSON, err.Error()))
        return
    }
    taskStatus, err := h.keyImportService.StartHardDeleteKeysTask(req.Keys)
    if err != nil {
        response.Error(c, errors.NewAPIError(errors.ErrTaskInProgress, err.Error()))
        return
    }
    response.Success(c, taskStatus)
}

// RestoreMultipleKeys handles restoring multiple keys to ACTIVE status globally.
func (h *APIKeyHandler) RestoreMultipleKeys(c *gin.Context) {
    var req BulkRestoreKeysRequest
    if err := c.ShouldBindJSON(&req); err != nil {
        response.Error(c, errors.NewAPIError(errors.ErrInvalidJSON, err.Error()))
        return
    }
    taskStatus, err := h.keyImportService.StartRestoreKeysTask(req.Keys)
    if err != nil {
        response.Error(c, errors.NewAPIError(errors.ErrTaskInProgress, err.Error()))
        return
    }
    response.Success(c, taskStatus)
}

func (h *APIKeyHandler) TestMultipleKeys(c *gin.Context) {
    var req BulkTestKeysRequest
    if err := c.ShouldBindJSON(&req); err != nil {
        response.Error(c, errors.NewAPIError(errors.ErrInvalidJSON, err.Error()))
        return
    }
    taskStatus, err := h.keyValidationService.StartTestKeysTask(req.KeyGroupID, req.Keys)
    if err != nil {
        response.Error(c, errors.NewAPIError(errors.ErrTaskInProgress, err.Error()))
        return
    }
    response.Success(c, taskStatus)
}

func (h *APIKeyHandler) ListAPIKeys(c *gin.Context) {
    var params models.APIKeyQueryParams
    if err := c.ShouldBindQuery(&params); err != nil {
        response.Error(c, errors.NewAPIError(errors.ErrBadRequest, err.Error()))
        return
    }
    if params.Page <= 0 {
        params.Page = 1
    }
    if params.PageSize <= 0 {
        params.PageSize = 20
    }
    result, err := h.apiKeyService.ListAPIKeys(&params)
    if err != nil {
        response.Error(c, errors.ParseDBError(err))
        return
    }
    response.Success(c, result)
}

// ListKeysForGroup handles the GET /keygroups/:id/keys request.
func (h *APIKeyHandler) ListKeysForGroup(c *gin.Context) {
    // 1. Manually handle the path parameter.
    groupID, err := strconv.ParseUint(c.Param("id"), 10, 32)
    if err != nil {
        response.Error(c, errors.NewAPIError(errors.ErrBadRequest, "Invalid group ID format"))
        return
    }
    // 2. Bind query parameters using the correctly tagged struct.
    var params models.APIKeyQueryParams
    if err := c.ShouldBindQuery(&params); err != nil {
        response.Error(c, errors.NewAPIError(errors.ErrBadRequest, err.Error()))
        return
    }
    // 3. Set server-side defaults and the path parameter.
    if params.Page <= 0 {
        params.Page = 1
    }
    if params.PageSize <= 0 {
        params.PageSize = 20
    }
    params.KeyGroupID = uint(groupID)
    // 4. Call the service layer.
    paginatedResult, err := h.apiKeyService.ListAPIKeys(&params)
    if err != nil {
        response.Error(c, errors.ParseDBError(err))
        return
    }

    // 5. Return a successful response using the standard response.Success helper and a gin.H map.
    response.Success(c, gin.H{
        "items": paginatedResult.Items,
        "total": paginatedResult.Total,
        "page":  paginatedResult.Page,
        "pages": paginatedResult.TotalPages,
    })
}

func (h *APIKeyHandler) TestKeysForGroup(c *gin.Context) {
    // The group ID is sourced from the URL path.
    groupID, err := strconv.ParseUint(c.Param("id"), 10, 32)
    if err != nil {
        response.Error(c, errors.NewAPIError(errors.ErrBadRequest, "Invalid group ID format"))
        return
    }
    // The request body only needs the keys.
    var req BulkTestKeysForGroupRequest
    if err := c.ShouldBindJSON(&req); err != nil {
        response.Error(c, errors.NewAPIError(errors.ErrInvalidJSON, err.Error()))
        return
    }
    // Call the same underlying service, but with unambiguous context.
    taskStatus, err := h.keyValidationService.StartTestKeysTask(uint(groupID), req.Keys)
    if err != nil {
        response.Error(c, errors.NewAPIError(errors.ErrTaskInProgress, err.Error()))
        return
    }
    response.Success(c, taskStatus)
}

// UpdateAPIKey is DEPRECATED. Status is now contextual to a group.
func (h *APIKeyHandler) UpdateAPIKey(c *gin.Context) {
    err := errors.NewAPIError(errors.ErrBadRequest, "This endpoint is deprecated. Use 'PUT /keygroups/:id/apikeys/:keyId' to update key status within a group context.")
    response.Error(c, err)
}

// UpdateGroupAPIKeyMapping handles updating a key's status within a specific group.
// Route: PUT /keygroups/:id/apikeys/:keyId
func (h *APIKeyHandler) UpdateGroupAPIKeyMapping(c *gin.Context) {
    groupID, err := strconv.ParseUint(c.Param("id"), 10, 32)
    if err != nil {
        response.Error(c, errors.NewAPIError(errors.ErrBadRequest, "Invalid Group ID format"))
        return
    }
    keyID, err := strconv.ParseUint(c.Param("keyId"), 10, 32)
    if err != nil {
        response.Error(c, errors.NewAPIError(errors.ErrBadRequest, "Invalid Key ID format"))
        return
    }
    var req UpdateMappingRequest
    if err := c.ShouldBindJSON(&req); err != nil {
        response.Error(c, errors.NewAPIError(errors.ErrInvalidJSON, err.Error()))
        return
    }
    // Directly use the service to handle the logic
    updatedMapping, err := h.apiKeyService.UpdateMappingStatus(uint(groupID), uint(keyID), req.Status)
    if err != nil {
        var apiErr *errors.APIError
        if errors.As(err, &apiErr) {
            response.Error(c, apiErr)
        } else {
            response.Error(c, errors.ParseDBError(err))
        }
        return
    }
    response.Success(c, updatedMapping)
}

// HardDeleteAPIKey handles globally deleting a single key entity.
func (h *APIKeyHandler) HardDeleteAPIKey(c *gin.Context) {
    id, err := strconv.Atoi(c.Param("id"))
    if err != nil {
        response.Error(c, errors.NewAPIError(errors.ErrBadRequest, "Invalid ID format"))
        return
    }
    if err := h.apiKeyService.HardDeleteAPIKeyByID(uint(id)); err != nil {
        response.Error(c, errors.ParseDBError(err))
        return
    }
    response.Success(c, gin.H{"message": "API key globally deleted successfully"})
}

// RestoreKeysInGroup restores the specified keys within a group.
// POST /keygroups/:id/apikeys/restore
func (h *APIKeyHandler) RestoreKeysInGroup(c *gin.Context) {
    groupID, err := strconv.ParseUint(c.Param("id"), 10, 32)
    if err != nil {
        response.Error(c, errors.NewAPIError(errors.ErrBadRequest, "Invalid Group ID format"))
        return
    }
    var req RestoreKeysRequest
    if err := c.ShouldBindJSON(&req); err != nil {
        response.Error(c, errors.NewAPIError(errors.ErrInvalidJSON, err.Error()))
        return
    }
    taskStatus, err := h.apiKeyService.StartRestoreKeysTask(uint(groupID), req.KeyIDs)
    if err != nil {
        var apiErr *errors.APIError
        if errors.As(err, &apiErr) {
            response.Error(c, apiErr)
        } else {
            response.Error(c, errors.NewAPIError(errors.ErrInternalServer, err.Error()))
        }
        return
    }
    response.Success(c, taskStatus)
}

// RestoreAllBannedInGroup restores all BANNED keys in a group in a single call.
// POST /keygroups/:id/apikeys/restore-all-banned
func (h *APIKeyHandler) RestoreAllBannedInGroup(c *gin.Context) {
    groupID, err := strconv.ParseUint(c.Param("groupId"), 10, 32)
    if err != nil {
        response.Error(c, errors.NewAPIError(errors.ErrBadRequest, "Invalid Group ID format"))
        return
    }
    taskStatus, err := h.apiKeyService.StartRestoreAllBannedTask(uint(groupID))
    if err != nil {
        var apiErr *errors.APIError
        if errors.As(err, &apiErr) {
            response.Error(c, apiErr)
        } else {
            response.Error(c, errors.NewAPIError(errors.ErrInternalServer, err.Error()))
        }
        return
    }
    response.Success(c, taskStatus)
}

// HandleBulkAction handles generic bulk actions on a key group based on server-side filters.
// Route: POST /keygroups/:id/bulk-actions
func (h *APIKeyHandler) HandleBulkAction(c *gin.Context) {
    // 1. Parse GroupID from URL
    groupID, err := strconv.ParseUint(c.Param("id"), 10, 32)
    if err != nil {
        response.Error(c, errors.NewAPIError(errors.ErrBadRequest, "Invalid Group ID format"))
        return
    }
    // 2. Bind the JSON payload to our new DTO
    var req BulkActionRequest
    if err := c.ShouldBindJSON(&req); err != nil {
        response.Error(c, errors.NewAPIError(errors.ErrInvalidJSON, err.Error()))
        return
    }
    // 3. Central logic: based on the action, call the appropriate service method.
    var task *task.Status
    var apiErr *errors.APIError
    switch req.Action {
    case "revalidate":
        // Assume keyValidationService has a method that accepts a filter
        task, err = h.keyValidationService.StartTestKeysByFilterTask(uint(groupID), req.Filter.Status)
    case "set_status":
        if req.NewStatus == "" {
            apiErr = errors.NewAPIError(errors.ErrBadRequest, "new_status is required for set_status action")
            break
        }
        // Assume apiKeyService has a method to update status by filter
        targetStatus := models.APIKeyStatus(req.NewStatus) // Convert string to your model's type
        task, err = h.apiKeyService.StartUpdateStatusByFilterTask(uint(groupID), req.Filter.Status, targetStatus)
    case "delete":
        // Assume keyImportService has a method to unlink by filter
        task, err = h.keyImportService.StartUnlinkKeysByFilterTask(uint(groupID), req.Filter.Status)
    default:
        apiErr = errors.NewAPIError(errors.ErrBadRequest, "Unsupported action: "+req.Action)
    }
    // 4. Handle errors from the switch block
    if apiErr != nil {
        response.Error(c, apiErr)
        return
    }
    if err != nil {
        // Attempt to parse it as a known APIError, otherwise, wrap it.
        var parsedErr *errors.APIError
        if errors.As(err, &parsedErr) {
            response.Error(c, parsedErr)
        } else {
            response.Error(c, errors.NewAPIError(errors.ErrInternalServer, err.Error()))
        }
        return
    }
    // 5. Return the task status on success
    response.Success(c, task)
}

// ExportKeysForGroup handles requests to export all keys for a group based on status filters.
// Route: GET /keygroups/:id/apikeys/export
func (h *APIKeyHandler) ExportKeysForGroup(c *gin.Context) {
    groupID, err := strconv.ParseUint(c.Param("id"), 10, 32)
    if err != nil {
        response.Error(c, errors.NewAPIError(errors.ErrBadRequest, "Invalid Group ID format"))
        return
    }
    // Use QueryArray to correctly parse repeated query values such as `status=active&status=cooldown`
    statuses := c.QueryArray("status")
    keyStrings, err := h.apiKeyService.GetAPIKeyStringsForExport(uint(groupID), statuses)
    if err != nil {
        response.Error(c, errors.ParseDBError(err))
        return
    }
    response.Success(c, keyStrings)
}
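A minimal sketch (not part of the commit) of the JSON body HandleBulkAction expects, assuming the package can be imported as "gemini-balancer/internal/handlers" from within this module; the field values are illustrative.

package main

import (
    "encoding/json"
    "fmt"

    "gemini-balancer/internal/handlers"
)

func main() {
    // Build the bulk-action DTO defined in apikey_handler.go.
    req := handlers.BulkActionRequest{
        Action:    "set_status",
        NewStatus: "disabled",
        Filter:    handlers.BulkActionFilter{Status: []string{"cooldown", "banned"}},
    }
    body, _ := json.MarshalIndent(req, "", "  ")
    // POST this body to /keygroups/:id/bulk-actions (route per the handler comment).
    fmt.Println(string(body))
}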
62
internal/handlers/dashboard_handler.go
Normal file
@@ -0,0 +1,62 @@
// Filename: internal/handlers/dashboard_handler.go
package handlers

import (
    "gemini-balancer/internal/errors"
    "gemini-balancer/internal/service"
    "net/http"
    "strconv"

    "github.com/gin-gonic/gin"
)

// DashboardHandler handles API requests for the global dashboard.
type DashboardHandler struct {
    queryService *service.DashboardQueryService
}

func NewDashboardHandler(qs *service.DashboardQueryService) *DashboardHandler {
    return &DashboardHandler{queryService: qs}
}

// GetOverview returns the global statistics card data for the dashboard.
func (h *DashboardHandler) GetOverview(c *gin.Context) {
    stats, err := h.queryService.GetDashboardOverviewData()
    if err != nil {
        apiErr := errors.NewAPIError(errors.ErrInternalServer, err.Error())
        c.JSON(apiErr.HTTPStatus, apiErr)
        return
    }
    c.JSON(http.StatusOK, stats)
}

// GetChart returns the dashboard chart data.
func (h *DashboardHandler) GetChart(c *gin.Context) {
    var groupID *uint
    if groupIDStr := c.Query("groupId"); groupIDStr != "" {
        if id, err := strconv.Atoi(groupIDStr); err == nil {
            uid := uint(id)
            groupID = &uid
        }
    }

    chartData, err := h.queryService.QueryHistoricalChart(groupID)
    if err != nil {
        apiErr := errors.NewAPIError(errors.ErrDatabase, err.Error())
        c.JSON(apiErr.HTTPStatus, apiErr)
        return
    }
    c.JSON(http.StatusOK, chartData)
}

// GetRequestStats handles requests for the per-period call overview.
func (h *DashboardHandler) GetRequestStats(c *gin.Context) {
    period := c.Param("period") // read the period from the URL path
    stats, err := h.queryService.GetRequestStatsForPeriod(period)
    if err != nil {
        apiErr := errors.NewAPIError(errors.ErrBadRequest, err.Error())
        c.JSON(apiErr.HTTPStatus, apiErr)
        return
    }
    c.JSON(http.StatusOK, stats)
}
369
internal/handlers/keygroup_handler.go
Normal file
@@ -0,0 +1,369 @@
// Filename: internal/handlers/keygroup_handler.go
package handlers

import (
    "encoding/json"
    "fmt"
    "gemini-balancer/internal/errors"
    "gemini-balancer/internal/models"
    "gemini-balancer/internal/response"
    "gemini-balancer/internal/service"
    "gemini-balancer/internal/store"
    "regexp"
    "strconv"
    "time"

    "github.com/gin-gonic/gin"
    "github.com/microcosm-cc/bluemonday"
)

type KeyGroupHandler struct {
    groupManager *service.GroupManager
    store        store.Store
    queryService *service.DashboardQueryService
}

func NewKeyGroupHandler(gm *service.GroupManager, s store.Store, qs *service.DashboardQueryService) *KeyGroupHandler {
    return &KeyGroupHandler{
        groupManager: gm,
        queryService: qs,
        store:        s,
    }
}

// DTOs & helper functions
func isValidGroupName(name string) bool {
    if name == "" {
        return false
    }
    match, _ := regexp.MatchString("^[a-z0-9_-]{3,30}$", name)
    return match
}

// KeyGroupOperationalSettings defines the shared operational settings for a key group.
type KeyGroupOperationalSettings struct {
    EnableKeyCheck          *bool   `json:"enable_key_check"`
    KeyCheckIntervalMinutes *int    `json:"key_check_interval_minutes"`
    KeyBlacklistThreshold   *int    `json:"key_blacklist_threshold"`
    KeyCooldownMinutes      *int    `json:"key_cooldown_minutes"`
    KeyCheckConcurrency     *int    `json:"key_check_concurrency"`
    KeyCheckEndpoint        *string `json:"key_check_endpoint"`
    KeyCheckModel           *string `json:"key_check_model"`
    MaxRetries              *int    `json:"max_retries"`
    EnableSmartGateway      *bool   `json:"enable_smart_gateway"`
}

type CreateKeyGroupRequest struct {
    Name            string `json:"name" binding:"required"`
    DisplayName     string `json:"display_name"`
    Description     string `json:"description"`
    PollingStrategy string `json:"polling_strategy" binding:"omitempty,oneof=sequential random weighted"`
    EnableProxy     bool   `json:"enable_proxy"`
    ChannelType     string `json:"channel_type"`

    // Embed shared operational settings
    KeyGroupOperationalSettings
}

type UpdateKeyGroupRequest struct {
    Name            *string `json:"name"`
    DisplayName     *string `json:"display_name"`
    Description     *string `json:"description"`
    PollingStrategy *string `json:"polling_strategy" binding:"omitempty,oneof=sequential random weighted"`
    EnableProxy     *bool   `json:"enable_proxy"`
    ChannelType     *string `json:"channel_type"`

    // Embed shared operational settings
    KeyGroupOperationalSettings

    // M:N associations
    AllowedUpstreams []string `json:"allowed_upstreams"`
    AllowedModels    []string `json:"allowed_models"`
}

type KeyGroupResponse struct {
    ID               uint                   `json:"id"`
    Name             string                 `json:"name"`
    DisplayName      string                 `json:"display_name"`
    Description      string                 `json:"description"`
    PollingStrategy  models.PollingStrategy `json:"polling_strategy"`
    ChannelType      string                 `json:"channel_type"`
    EnableProxy      bool                   `json:"enable_proxy"`
    APIKeysCount     int64                  `json:"api_keys_count"`
    CreatedAt        time.Time              `json:"created_at"`
    UpdatedAt        time.Time              `json:"updated_at"`
    Order            int                    `json:"order"`
    AllowedModels    []string               `json:"allowed_models"`
    AllowedUpstreams []string               `json:"allowed_upstreams"`
}

// KeyGroupDetailsResponse defines the detailed response structure for a single group.
type KeyGroupDetailsResponse struct {
    KeyGroupResponse
    Settings      *models.GroupSettings `json:"settings,omitempty"`
    RequestConfig *models.RequestConfig `json:"request_config,omitempty"`
}

// transformModelsToStrings converts a slice of GroupModelMapping pointers to a slice of model names.
func transformModelsToStrings(mappings []*models.GroupModelMapping) []string {
    modelNames := make([]string, 0, len(mappings))
    for _, mapping := range mappings {
        if mapping != nil { // Safety check
            modelNames = append(modelNames, mapping.ModelName)
        }
    }
    return modelNames
}

// transformUpstreamsToStrings converts a slice of UpstreamEndpoint pointers to a slice of URLs.
func transformUpstreamsToStrings(upstreams []*models.UpstreamEndpoint) []string {
    urls := make([]string, 0, len(upstreams))
    for _, upstream := range upstreams {
        if upstream != nil { // Safety check
            urls = append(urls, upstream.URL)
        }
    }
    return urls
}

func (h *KeyGroupHandler) newKeyGroupResponse(group *models.KeyGroup, keyCount int64) KeyGroupResponse {
    return KeyGroupResponse{
        ID:               group.ID,
        Name:             group.Name,
        DisplayName:      group.DisplayName,
        Description:      group.Description,
        PollingStrategy:  group.PollingStrategy,
        ChannelType:      group.ChannelType,
        EnableProxy:      group.EnableProxy,
        APIKeysCount:     keyCount,
        CreatedAt:        group.CreatedAt,
        UpdatedAt:        group.UpdatedAt,
        Order:            group.Order,
        AllowedModels:    transformModelsToStrings(group.AllowedModels),       // Call the new helper
        AllowedUpstreams: transformUpstreamsToStrings(group.AllowedUpstreams), // Call the new helper
    }
}

// packGroupSettings is a helper to convert request-level operational settings
// into the model-level settings struct.
func packGroupSettings(settings KeyGroupOperationalSettings) *models.KeyGroupSettings {
    return &models.KeyGroupSettings{
        EnableKeyCheck:          settings.EnableKeyCheck,
        KeyCheckIntervalMinutes: settings.KeyCheckIntervalMinutes,
        KeyBlacklistThreshold:   settings.KeyBlacklistThreshold,
        KeyCooldownMinutes:      settings.KeyCooldownMinutes,
        KeyCheckConcurrency:     settings.KeyCheckConcurrency,
        KeyCheckEndpoint:        settings.KeyCheckEndpoint,
        KeyCheckModel:           settings.KeyCheckModel,
        MaxRetries:              settings.MaxRetries,
        EnableSmartGateway:      settings.EnableSmartGateway,
    }
}

func (h *KeyGroupHandler) getGroupFromContext(c *gin.Context) (*models.KeyGroup, *errors.APIError) {
    id, err := strconv.Atoi(c.Param("id"))
    if err != nil {
        return nil, errors.NewAPIError(errors.ErrBadRequest, "Invalid ID format")
    }
    group, ok := h.groupManager.GetGroupByID(uint(id))
    if !ok {
        return nil, errors.NewAPIError(errors.ErrResourceNotFound, "Group not found")
    }
    return group, nil
}

func applyUpdateRequestToGroup(req *UpdateKeyGroupRequest, group *models.KeyGroup) {
    if req.Name != nil {
        group.Name = *req.Name
    }
    p := bluemonday.StripTagsPolicy()
    if req.DisplayName != nil {
        group.DisplayName = p.Sanitize(*req.DisplayName)
    }
    if req.Description != nil {
        group.Description = p.Sanitize(*req.Description)
    }
    if req.PollingStrategy != nil {
        group.PollingStrategy = models.PollingStrategy(*req.PollingStrategy)
    }
    if req.EnableProxy != nil {
        group.EnableProxy = *req.EnableProxy
    }
    if req.ChannelType != nil {
        group.ChannelType = *req.ChannelType
    }
}

// publishGroupChangeEvent encapsulates the logic for marshaling and publishing a group change event.
func (h *KeyGroupHandler) publishGroupChangeEvent(groupID uint, reason string) {
    go func() {
        event := models.KeyStatusChangedEvent{GroupID: groupID, ChangeReason: reason}
        eventData, _ := json.Marshal(event)
        h.store.Publish(models.TopicKeyStatusChanged, eventData)
    }()
}

// --- Handler methods ---

func (h *KeyGroupHandler) CreateKeyGroup(c *gin.Context) {
    var req CreateKeyGroupRequest
    if err := c.ShouldBindJSON(&req); err != nil {
        response.Error(c, errors.NewAPIError(errors.ErrInvalidJSON, err.Error()))
        return
    }
    if !isValidGroupName(req.Name) {
        response.Error(c, errors.NewAPIError(errors.ErrValidation, "Invalid group name. Must be 3-30 characters, lowercase letters, numbers, hyphens, or underscores."))
        return
    }

    // The core logic remains, as it's specific to creation.
    p := bluemonday.StripTagsPolicy()
    sanitizedDisplayName := p.Sanitize(req.DisplayName)
    sanitizedDescription := p.Sanitize(req.Description)
    keyGroup := &models.KeyGroup{
        Name:            req.Name,
        DisplayName:     sanitizedDisplayName,
        Description:     sanitizedDescription,
        PollingStrategy: models.PollingStrategy(req.PollingStrategy),
        EnableProxy:     req.EnableProxy,
        ChannelType:     req.ChannelType,
    }
    if keyGroup.PollingStrategy == "" {
        keyGroup.PollingStrategy = models.StrategySequential
    }
    if keyGroup.ChannelType == "" {
        keyGroup.ChannelType = "gemini"
    }

    groupSettings := packGroupSettings(req.KeyGroupOperationalSettings)
    if err := h.groupManager.CreateKeyGroup(keyGroup, groupSettings); err != nil {
        response.Error(c, errors.ParseDBError(err))
        return
    }
    h.publishGroupChangeEvent(keyGroup.ID, "group_created")
    response.Created(c, h.newKeyGroupResponse(keyGroup, 0))
}

// A single handler covers both cases:
// 1. GET /keygroups - returns the list of all groups
// 2. GET /keygroups/:id - returns the single group with the given ID
func (h *KeyGroupHandler) GetKeyGroups(c *gin.Context) {
    // Case 1: Get a single group
    if idStr := c.Param("id"); idStr != "" {
        group, apiErr := h.getGroupFromContext(c)
        if apiErr != nil {
            response.Error(c, apiErr)
            return
        }
        keyCount := h.groupManager.GetKeyCount(group.ID)
        baseResponse := h.newKeyGroupResponse(group, keyCount)
        detailedResponse := KeyGroupDetailsResponse{
            KeyGroupResponse: baseResponse,
            Settings:         group.Settings,
            RequestConfig:    group.RequestConfig,
        }
        response.Success(c, detailedResponse)
        return
    }
    // Case 2: Get all groups
    allGroups := h.groupManager.GetAllGroups()
    responses := make([]KeyGroupResponse, 0, len(allGroups))
    for _, group := range allGroups {
        keyCount := h.groupManager.GetKeyCount(group.ID)
        responses = append(responses, h.newKeyGroupResponse(group, keyCount))
    }
    response.Success(c, responses)
}

// UpdateKeyGroup
func (h *KeyGroupHandler) UpdateKeyGroup(c *gin.Context) {
    group, apiErr := h.getGroupFromContext(c)
    if apiErr != nil {
        response.Error(c, apiErr)
        return
    }
    var req UpdateKeyGroupRequest
    if err := c.ShouldBindJSON(&req); err != nil {
        response.Error(c, errors.NewAPIError(errors.ErrInvalidJSON, err.Error()))
        return
    }
    if req.Name != nil && !isValidGroupName(*req.Name) {
        response.Error(c, errors.NewAPIError(errors.ErrValidation, "Invalid group name format."))
        return
    }
    applyUpdateRequestToGroup(&req, group)
    groupSettings := packGroupSettings(req.KeyGroupOperationalSettings)
    err := h.groupManager.UpdateKeyGroup(group, groupSettings, req.AllowedUpstreams, req.AllowedModels)
    if err != nil {
        response.Error(c, errors.ParseDBError(err))
        return
    }
    h.publishGroupChangeEvent(group.ID, "group_updated")
    freshGroup, _ := h.groupManager.GetGroupByID(group.ID)
    keyCount := h.groupManager.GetKeyCount(freshGroup.ID)
    response.Success(c, h.newKeyGroupResponse(freshGroup, keyCount))
}

// DeleteKeyGroup
func (h *KeyGroupHandler) DeleteKeyGroup(c *gin.Context) {
    group, apiErr := h.getGroupFromContext(c)
    if apiErr != nil {
        response.Error(c, apiErr)
        return
    }
    groupName := group.Name
    if err := h.groupManager.DeleteKeyGroup(group.ID); err != nil {
        response.Error(c, errors.ParseDBError(err))
        return
    }
    h.publishGroupChangeEvent(group.ID, "group_deleted")
    response.Success(c, gin.H{"message": fmt.Sprintf("Group '%s' and its associated keys deleted successfully", groupName)})
}

// GetKeyGroupStats
func (h *KeyGroupHandler) GetKeyGroupStats(c *gin.Context) {
    group, apiErr := h.getGroupFromContext(c)
    if apiErr != nil {
        response.Error(c, apiErr)
        return
    }
    stats, err := h.queryService.GetGroupStats(group.ID)
    if err != nil {
        response.Error(c, errors.NewAPIError(errors.ErrDatabase, err.Error()))
        return
    }
    response.Success(c, stats)
}

func (h *KeyGroupHandler) CloneKeyGroup(c *gin.Context) {
    group, apiErr := h.getGroupFromContext(c)
    if apiErr != nil {
        response.Error(c, apiErr)
        return
    }
    clonedGroup, err := h.groupManager.CloneKeyGroup(group.ID)
    if err != nil {
        response.Error(c, errors.ParseDBError(err))
        return
    }
    keyCount := int64(len(clonedGroup.Mappings))
    response.Created(c, h.newKeyGroupResponse(clonedGroup, keyCount))
}

// UpdateKeyGroupOrder updates the display order of the key groups.
func (h *KeyGroupHandler) UpdateKeyGroupOrder(c *gin.Context) {
    var payload []service.UpdateOrderPayload
    if err := c.ShouldBindJSON(&payload); err != nil {
        response.Error(c, errors.NewAPIError(errors.ErrInvalidJSON, err.Error()))
        return
    }
    if len(payload) == 0 {
        response.Success(c, gin.H{"message": "No order data to update."})
        return
    }
    if err := h.groupManager.UpdateOrder(payload); err != nil {
        response.Error(c, errors.ParseDBError(err))
        return
    }
    response.Success(c, gin.H{"message": "Group order updated successfully."})
}
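A minimal sketch (not part of the commit), assuming the package is importable as "gemini-balancer/internal/handlers" within this module: because KeyGroupOperationalSettings is embedded without a JSON tag, its fields are promoted and appear as top-level keys in the create-group request body. The field values below are illustrative.

package main

import (
    "encoding/json"
    "fmt"

    "gemini-balancer/internal/handlers"
)

func main() {
    enableCheck := true
    maxRetries := 2
    req := handlers.CreateKeyGroupRequest{
        Name:            "gemini-pro-pool", // must match ^[a-z0-9_-]{3,30}$
        DisplayName:     "Gemini Pro Pool",
        PollingStrategy: "sequential",
        ChannelType:     "gemini",
        KeyGroupOperationalSettings: handlers.KeyGroupOperationalSettings{
            EnableKeyCheck: &enableCheck,
            MaxRetries:     &maxRetries,
        },
    }
    body, _ := json.MarshalIndent(req, "", "  ")
    // "enable_key_check" and "max_retries" end up at the top level of the JSON,
    // alongside "name" and "channel_type".
    fmt.Println(string(body))
}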
33
internal/handlers/log_handler.go
Normal file
@@ -0,0 +1,33 @@
// Filename: internal/handlers/log_handler.go
package handlers

import (
    "gemini-balancer/internal/errors"
    "gemini-balancer/internal/models"
    "gemini-balancer/internal/response"
    "gemini-balancer/internal/service"

    "github.com/gin-gonic/gin"
)

// LogHandler handles log-related HTTP requests.
type LogHandler struct {
    logService *service.LogService
}

func NewLogHandler(logService *service.LogService) *LogHandler {
    return &LogHandler{logService: logService}
}

func (h *LogHandler) GetLogs(c *gin.Context) {
    // Pass the Gin context straight to the service layer and let it parse the query parameters itself.
    logs, err := h.logService.GetLogs(c)
    if err != nil {
        response.Error(c, errors.ErrDatabase)
        return
    }
    if logs == nil {
        logs = []models.RequestLog{}
    }
    response.Success(c, logs)
}
581
internal/handlers/proxy_handler.go
Normal file
@@ -0,0 +1,581 @@
|
||||
// Filename: internal/handlers/proxy_handler.go
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"compress/gzip"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"gemini-balancer/internal/channel"
|
||||
"gemini-balancer/internal/errors"
|
||||
"gemini-balancer/internal/middleware"
|
||||
"gemini-balancer/internal/models"
|
||||
"gemini-balancer/internal/service"
|
||||
"gemini-balancer/internal/settings"
|
||||
"gemini-balancer/internal/store"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/http/httputil"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/google/uuid"
|
||||
"github.com/sirupsen/logrus"
|
||||
"gorm.io/datatypes"
|
||||
)
|
||||
|
||||
type proxyErrorKey int
|
||||
|
||||
const proxyErrKey proxyErrorKey = 0
|
||||
|
||||
type ProxyHandler struct {
|
||||
resourceService *service.ResourceService
|
||||
store store.Store
|
||||
settingsManager *settings.SettingsManager
|
||||
groupManager *service.GroupManager
|
||||
channel channel.ChannelProxy
|
||||
logger *logrus.Entry
|
||||
transparentProxy *httputil.ReverseProxy
|
||||
}
|
||||
|
||||
func NewProxyHandler(
|
||||
resourceService *service.ResourceService,
|
||||
store store.Store,
|
||||
sm *settings.SettingsManager,
|
||||
gm *service.GroupManager,
|
||||
channel channel.ChannelProxy,
|
||||
logger *logrus.Logger,
|
||||
) *ProxyHandler {
|
||||
ph := &ProxyHandler{
|
||||
resourceService: resourceService,
|
||||
store: store,
|
||||
settingsManager: sm,
|
||||
groupManager: gm,
|
||||
channel: channel,
|
||||
logger: logger.WithField("component", "ProxyHandler"),
|
||||
transparentProxy: &httputil.ReverseProxy{},
|
||||
}
|
||||
ph.transparentProxy.Transport = &http.Transport{
|
||||
Proxy: http.ProxyFromEnvironment,
|
||||
DialContext: (&net.Dialer{
|
||||
Timeout: 30 * time.Second,
|
||||
KeepAlive: 60 * time.Second,
|
||||
}).DialContext,
|
||||
MaxIdleConns: 100,
|
||||
IdleConnTimeout: 90 * time.Second,
|
||||
TLSHandshakeTimeout: 10 * time.Second,
|
||||
ExpectContinueTimeout: 1 * time.Second,
|
||||
}
|
||||
ph.transparentProxy.ErrorHandler = ph.transparentProxyErrorHandler
|
||||
ph.transparentProxy.BufferPool = &bufferPool{}
|
||||
return ph
|
||||
}
|
||||
|
||||
func (h *ProxyHandler) HandleProxy(c *gin.Context) {
|
||||
if c.Request.Method == "GET" && (strings.HasSuffix(c.Request.URL.Path, "/models") || strings.HasSuffix(c.Request.URL.Path, "/models/")) {
|
||||
h.handleListModelsRequest(c)
|
||||
return
|
||||
}
|
||||
requestBody, err := io.ReadAll(c.Request.Body)
|
||||
if err != nil {
|
||||
errToJSON(c, uuid.New().String(), errors.NewAPIError(errors.ErrBadRequest, "Failed to read request body"))
|
||||
return
|
||||
}
|
||||
c.Request.Body = io.NopCloser(bytes.NewReader(requestBody))
|
||||
c.Request.ContentLength = int64(len(requestBody))
|
||||
modelName := h.channel.ExtractModel(c, requestBody)
|
||||
groupName := c.Param("group_name")
|
||||
isPreciseRouting := groupName != ""
|
||||
if !isPreciseRouting && modelName == "" {
|
||||
errToJSON(c, uuid.New().String(), errors.NewAPIError(errors.ErrBadRequest, "Model not specified in the request body or URL"))
|
||||
return
|
||||
}
|
||||
initialResources, err := h.getResourcesForRequest(c, modelName, groupName, isPreciseRouting)
|
||||
if err != nil {
|
||||
if apiErr, ok := err.(*errors.APIError); ok {
|
||||
errToJSON(c, uuid.New().String(), apiErr)
|
||||
} else {
|
||||
errToJSON(c, uuid.New().String(), errors.NewAPIError(errors.ErrNoKeysAvailable, err.Error()))
|
||||
}
|
||||
return
|
||||
}
|
||||
finalOpConfig, err := h.groupManager.BuildOperationalConfig(initialResources.KeyGroup)
|
||||
if err != nil {
|
||||
h.logger.WithError(err).Error("Failed to build operational config.")
|
||||
errToJSON(c, uuid.New().String(), errors.NewAPIError(errors.ErrInternalServer, "Failed to build operational configuration"))
|
||||
return
|
||||
}
|
||||
|
||||
isOpenAICompatible := h.channel.IsOpenAICompatibleRequest(c)
|
||||
if isOpenAICompatible {
|
||||
h.serveTransparentProxy(c, requestBody, initialResources, finalOpConfig, modelName, groupName, isPreciseRouting)
|
||||
return
|
||||
}
|
||||
isStream := h.channel.IsStreamRequest(c, requestBody)
|
||||
systemSettings := h.settingsManager.GetSettings()
|
||||
useSmartGateway := finalOpConfig.EnableSmartGateway != nil && *finalOpConfig.EnableSmartGateway
|
||||
if useSmartGateway && isStream && systemSettings.EnableStreamingRetry {
|
||||
h.serveSmartStream(c, requestBody, initialResources, isPreciseRouting)
|
||||
} else {
|
||||
h.serveTransparentProxy(c, requestBody, initialResources, finalOpConfig, modelName, groupName, isPreciseRouting)
|
||||
}
|
||||
}
|
||||
|
||||
func (h *ProxyHandler) serveTransparentProxy(c *gin.Context, requestBody []byte, initialResources *service.RequestResources, finalOpConfig *models.KeyGroupSettings, modelName, groupName string, isPreciseRouting bool) {
|
||||
startTime := time.Now()
|
||||
correlationID := uuid.New().String()
|
||||
var finalRecorder *httptest.ResponseRecorder
|
||||
var lastUsedResources *service.RequestResources
|
||||
var finalProxyErr *errors.APIError
|
||||
var isSuccess bool
|
||||
var finalPromptTokens, finalCompletionTokens int
|
||||
var actualRetries int = 0
|
||||
defer func() {
|
||||
if lastUsedResources == nil {
|
||||
return
|
||||
}
|
||||
finalEvent := h.createLogEvent(c, startTime, correlationID, modelName, lastUsedResources, models.LogTypeFinal, isPreciseRouting)
|
||||
finalEvent.LatencyMs = int(time.Since(startTime).Milliseconds())
|
||||
finalEvent.IsSuccess = isSuccess
|
||||
finalEvent.Retries = actualRetries
|
||||
if isSuccess {
|
||||
finalEvent.PromptTokens = finalPromptTokens
|
||||
finalEvent.CompletionTokens = finalCompletionTokens
|
||||
}
|
||||
if finalRecorder != nil {
|
||||
finalEvent.StatusCode = finalRecorder.Code
|
||||
}
|
||||
if !isSuccess {
|
||||
if finalProxyErr != nil {
|
||||
finalEvent.Error = finalProxyErr
|
||||
finalEvent.ErrorCode = finalProxyErr.Code
|
||||
finalEvent.ErrorMessage = finalProxyErr.Message
|
||||
} else if finalRecorder != nil {
|
||||
apiErr := errors.NewAPIErrorWithUpstream(finalRecorder.Code, "PROXY_ERROR", "Request failed after all retries.")
|
||||
finalEvent.Error = apiErr
|
||||
finalEvent.ErrorCode = apiErr.Code
|
||||
finalEvent.ErrorMessage = apiErr.Message
|
||||
}
|
||||
}
|
||||
eventData, _ := json.Marshal(finalEvent)
|
||||
_ = h.store.Publish(models.TopicRequestFinished, eventData)
|
||||
}()
|
||||
var maxRetries int
|
||||
if isPreciseRouting {
|
||||
// For precise routing, use the group's setting. If not set, fall back to the global setting.
|
||||
if finalOpConfig.MaxRetries != nil {
|
||||
maxRetries = *finalOpConfig.MaxRetries
|
||||
} else {
|
||||
maxRetries = h.settingsManager.GetSettings().MaxRetries
|
||||
}
|
||||
} else {
|
||||
// For BasePool (intelligent aggregation), *always* use the global setting.
|
||||
maxRetries = h.settingsManager.GetSettings().MaxRetries
|
||||
}
|
||||
totalAttempts := maxRetries + 1
|
||||
for attempt := 1; attempt <= totalAttempts; attempt++ {
|
||||
if c.Request.Context().Err() != nil {
|
||||
h.logger.WithField("id", correlationID).Info("Client disconnected, aborting retry loop.")
|
||||
if finalProxyErr == nil {
|
||||
finalProxyErr = errors.NewAPIError(errors.ErrBadRequest, "Client connection closed")
|
||||
}
|
||||
break
|
||||
}
|
||||
var currentResources *service.RequestResources
|
||||
var err error
|
||||
if attempt == 1 {
|
||||
currentResources = initialResources
|
||||
} else {
|
||||
actualRetries = attempt - 1
|
||||
h.logger.WithField("id", correlationID).Infof("Retrying... getting new resources for attempt %d.", attempt)
|
||||
currentResources, err = h.getResourcesForRequest(c, modelName, groupName, isPreciseRouting)
|
||||
if err != nil {
|
||||
h.logger.WithField("id", correlationID).Errorf("Failed to get new resources for retry, aborting: %v", err)
|
||||
finalProxyErr = errors.NewAPIError(errors.ErrNoKeysAvailable, "Failed to get new resources for retry")
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
finalRequestConfig := h.buildFinalRequestConfig(h.settingsManager.GetSettings(), currentResources.RequestConfig)
|
||||
currentResources.RequestConfig = finalRequestConfig
|
||||
lastUsedResources = currentResources
|
||||
h.logger.WithField("id", correlationID).Infof("Attempt %d/%d with KeyID %d...", attempt, totalAttempts, currentResources.APIKey.ID)
|
||||
var attemptErr *errors.APIError
|
||||
var attemptIsSuccess bool
|
||||
recorder := httptest.NewRecorder()
|
||||
attemptStartTime := time.Now()
|
||||
connectTimeout := time.Duration(h.settingsManager.GetSettings().ConnectTimeoutSeconds) * time.Second
|
||||
ctx, cancel := context.WithTimeout(c.Request.Context(), connectTimeout)
|
||||
defer cancel()
|
||||
attemptReq := c.Request.Clone(ctx)
|
||||
attemptReq.Body = io.NopCloser(bytes.NewReader(requestBody))
|
||||
if currentResources.UpstreamEndpoint == nil || currentResources.UpstreamEndpoint.URL == "" {
|
||||
h.logger.WithField("id", correlationID).Errorf("Attempt %d failed: no upstream URL in resources.", attempt)
|
||||
isSuccess = false
|
||||
finalProxyErr = errors.NewAPIError(errors.ErrInternalServer, "No upstream URL configured for the selected resource")
|
||||
continue
|
||||
}
|
||||
h.transparentProxy.Director = func(req *http.Request) {
|
||||
targetURL, _ := url.Parse(currentResources.UpstreamEndpoint.URL)
|
||||
req.URL.Scheme = targetURL.Scheme
|
||||
req.URL.Host = targetURL.Host
|
||||
req.Host = targetURL.Host
|
||||
var pureClientPath string
|
||||
if isPreciseRouting {
|
||||
proxyPrefix := "/proxy/" + groupName
|
||||
pureClientPath = strings.TrimPrefix(req.URL.Path, proxyPrefix)
|
||||
} else {
|
||||
pureClientPath = req.URL.Path
|
||||
}
|
||||
finalPath := h.channel.RewritePath(targetURL.Path, pureClientPath)
|
||||
req.URL.Path = finalPath
|
||||
h.logger.WithFields(logrus.Fields{
|
||||
"correlation_id": correlationID,
|
||||
"attempt": attempt,
|
||||
"key_id": currentResources.APIKey.ID,
|
||||
"base_upstream_url": currentResources.UpstreamEndpoint.URL,
|
||||
"final_request_url": req.URL.String(),
|
||||
}).Infof("Director constructed final upstream request URL.")
|
||||
req.Header.Del("Authorization")
|
||||
h.channel.ModifyRequest(req, currentResources.APIKey)
|
||||
req.Header.Set("X-Correlation-ID", correlationID)
|
||||
*req = *req.WithContext(context.WithValue(req.Context(), proxyErrKey, &attemptErr))
|
||||
}
|
||||
transport := h.transparentProxy.Transport.(*http.Transport)
|
||||
if currentResources.ProxyConfig != nil {
|
||||
proxyURLStr := fmt.Sprintf("%s://%s", currentResources.ProxyConfig.Protocol, currentResources.ProxyConfig.Address)
|
||||
proxyURL, err := url.Parse(proxyURLStr)
|
||||
if err == nil {
|
||||
transport.Proxy = http.ProxyURL(proxyURL)
|
||||
}
|
||||
} else {
|
||||
transport.Proxy = http.ProxyFromEnvironment
|
||||
}
|
||||
h.transparentProxy.ModifyResponse = func(resp *http.Response) error {
|
||||
defer resp.Body.Close()
|
||||
var reader io.ReadCloser
|
||||
var err error
|
||||
isGzipped := resp.Header.Get("Content-Encoding") == "gzip"
|
||||
if isGzipped {
|
||||
reader, err = gzip.NewReader(resp.Body)
|
||||
if err != nil {
|
||||
h.logger.WithError(err).Error("Failed to create gzip reader")
|
||||
reader = resp.Body
|
||||
} else {
|
||||
resp.Header.Del("Content-Encoding")
|
||||
}
|
||||
defer reader.Close()
|
||||
} else {
|
||||
reader = resp.Body
|
||||
}
|
||||
bodyBytes, err := io.ReadAll(reader)
|
||||
if err != nil {
|
||||
attemptErr = errors.NewAPIError(errors.ErrBadGateway, "Failed to read upstream response: "+err.Error())
|
||||
resp.Body = io.NopCloser(bytes.NewReader([]byte(attemptErr.Message)))
|
||||
return nil
|
||||
}
|
||||
if resp.StatusCode < 400 {
|
||||
attemptIsSuccess = true
|
||||
finalPromptTokens, finalCompletionTokens = extractUsage(bodyBytes)
|
||||
} else {
|
||||
parsedMsg := errors.ParseUpstreamError(bodyBytes)
|
||||
attemptErr = errors.NewAPIErrorWithUpstream(resp.StatusCode, fmt.Sprintf("UPSTREAM_%d", resp.StatusCode), parsedMsg)
|
||||
}
|
||||
resp.Body = io.NopCloser(bytes.NewReader(bodyBytes))
|
||||
return nil
|
||||
}
|
||||
h.transparentProxy.ServeHTTP(recorder, attemptReq)
|
||||
finalRecorder = recorder
|
||||
finalProxyErr = attemptErr
|
||||
isSuccess = attemptIsSuccess
|
||||
h.resourceService.ReportRequestResult(currentResources, isSuccess, finalProxyErr)
|
||||
if isSuccess {
|
||||
break
|
||||
}
|
||||
isUnretryableError := false
|
||||
if finalProxyErr != nil {
|
||||
if errors.IsUnretryableRequestError(finalProxyErr.Message) {
|
||||
isUnretryableError = true
|
||||
h.logger.WithField("id", correlationID).Warnf("Attempt %d failed with unretryable request error. Aborting retries. Message: %s", attempt, finalProxyErr.Message)
|
||||
}
|
||||
}
|
||||
if attempt >= totalAttempts || isUnretryableError {
|
||||
break
|
||||
}
|
||||
retryEvent := h.createLogEvent(c, startTime, correlationID, modelName, currentResources, models.LogTypeRetry, isPreciseRouting)
|
||||
retryEvent.LatencyMs = int(time.Since(attemptStartTime).Milliseconds())
|
||||
retryEvent.IsSuccess = false
|
||||
retryEvent.StatusCode = recorder.Code
|
||||
retryEvent.Retries = actualRetries
|
||||
if attemptErr != nil {
|
||||
retryEvent.Error = attemptErr
|
||||
retryEvent.ErrorCode = attemptErr.Code
|
||||
retryEvent.ErrorMessage = attemptErr.Message
|
||||
}
|
||||
eventData, _ := json.Marshal(retryEvent)
|
||||
_ = h.store.Publish(models.TopicRequestFinished, eventData)
|
||||
}
|
||||
if finalRecorder != nil {
|
||||
bodyBytes := finalRecorder.Body.Bytes()
|
||||
c.Writer.Header().Set("Content-Length", fmt.Sprintf("%d", len(bodyBytes)))
|
||||
for k, v := range finalRecorder.Header() {
|
||||
if strings.ToLower(k) != "content-length" {
|
||||
c.Writer.Header()[k] = v
|
||||
}
|
||||
}
|
||||
c.Writer.WriteHeader(finalRecorder.Code)
|
||||
c.Writer.Write(finalRecorder.Body.Bytes())
|
||||
} else {
|
||||
errToJSON(c, correlationID, finalProxyErr)
|
||||
}
|
||||
}
|
||||
|
||||
func (h *ProxyHandler) serveSmartStream(c *gin.Context, requestBody []byte, resources *service.RequestResources, isPreciseRouting bool) {
|
||||
startTime := time.Now()
|
||||
correlationID := uuid.New().String()
|
||||
log := h.logger.WithField("id", correlationID)
|
||||
log.Info("Smart Gateway activated for streaming request.")
|
||||
var originalRequest models.GeminiRequest
|
||||
if err := json.Unmarshal(requestBody, &originalRequest); err != nil {
|
||||
errToJSON(c, correlationID, errors.NewAPIError(errors.ErrInvalidJSON, "Smart Gateway failed: Request body is not a valid Gemini native format. Error: "+err.Error()))
|
||||
return
|
||||
}
|
||||
systemSettings := h.settingsManager.GetSettings()
|
||||
modelName := h.channel.ExtractModel(c, requestBody)
|
||||
requestFinishedEvent := h.createLogEvent(c, startTime, correlationID, modelName, resources, models.LogTypeFinal, isPreciseRouting)
|
||||
defer func() {
|
||||
requestFinishedEvent.LatencyMs = int(time.Since(startTime).Milliseconds())
|
||||
if c.Writer.Status() > 0 {
|
||||
requestFinishedEvent.StatusCode = c.Writer.Status()
|
||||
}
|
||||
eventData, _ := json.Marshal(requestFinishedEvent)
|
||||
_ = h.store.Publish(models.TopicRequestFinished, eventData)
|
||||
}()
|
||||
params := channel.SmartRequestParams{
|
||||
CorrelationID: correlationID,
|
||||
APIKey: resources.APIKey,
|
||||
UpstreamURL: resources.UpstreamEndpoint.URL,
|
||||
RequestBody: requestBody,
|
||||
OriginalRequest: originalRequest,
|
||||
EventLogger: requestFinishedEvent,
|
||||
MaxRetries: systemSettings.MaxStreamingRetries,
|
||||
RetryDelay: time.Duration(systemSettings.StreamingRetryDelayMs) * time.Millisecond,
|
||||
LogTruncationLimit: systemSettings.LogTruncationLimit,
|
||||
StreamingRetryPrompt: systemSettings.StreamingRetryPrompt,
|
||||
}
|
||||
h.channel.ProcessSmartStreamRequest(c, params)
|
||||
}
|
||||
|
||||
func (h *ProxyHandler) transparentProxyErrorHandler(rw http.ResponseWriter, r *http.Request, err error) {
|
||||
correlationID := r.Header.Get("X-Correlation-ID")
|
||||
h.logger.WithField("id", correlationID).Errorf("Transparent proxy error: %v", err)
|
||||
proxyErrPtr, exists := r.Context().Value(proxyErrKey).(**errors.APIError)
|
||||
if !exists || proxyErrPtr == nil {
|
||||
h.logger.WithField("id", correlationID).Error("FATAL: proxyErrorKey not found in context for error handler.")
|
||||
return
|
||||
}
|
||||
if errors.IsClientNetworkError(err) {
|
||||
*proxyErrPtr = errors.NewAPIError(errors.ErrBadRequest, "Client connection closed")
|
||||
} else {
|
||||
*proxyErrPtr = errors.NewAPIError(errors.ErrBadGateway, err.Error())
|
||||
}
|
||||
if _, ok := rw.(*httptest.ResponseRecorder); ok {
|
||||
return
|
||||
}
|
||||
if writer, ok := rw.(interface{ Written() bool }); ok {
|
||||
if writer.Written() {
|
||||
return
|
||||
}
|
||||
}
|
||||
rw.WriteHeader((*proxyErrPtr).HTTPStatus)
|
||||
}
|
||||
|
||||
func (h *ProxyHandler) createLogEvent(c *gin.Context, startTime time.Time, corrID, modelName string, res *service.RequestResources, logType models.LogType, isPreciseRouting bool) *models.RequestFinishedEvent {
|
||||
event := &models.RequestFinishedEvent{
|
||||
RequestLog: models.RequestLog{
|
||||
RequestTime: startTime,
|
||||
ModelName: modelName,
|
||||
RequestPath: c.Request.URL.Path,
|
||||
UserAgent: c.Request.UserAgent(),
|
||||
CorrelationID: corrID,
|
||||
LogType: logType,
|
||||
Metadata: make(datatypes.JSONMap),
|
||||
},
|
||||
CorrelationID: corrID,
|
||||
IsPreciseRouting: isPreciseRouting,
|
||||
}
|
||||
if _, exists := c.Get(middleware.RedactedBodyKey); exists {
|
||||
event.RequestLog.Metadata["request_body_present"] = true
|
||||
}
|
||||
if redactedAuth, exists := c.Get(middleware.RedactedAuthHeaderKey); exists {
|
||||
event.RequestLog.Metadata["authorization_header"] = redactedAuth.(string)
|
||||
}
|
||||
if authTokenValue, exists := c.Get("authToken"); exists {
|
||||
if authToken, ok := authTokenValue.(*models.AuthToken); ok {
|
||||
event.AuthTokenID = &authToken.ID
|
||||
}
|
||||
}
|
||||
if res != nil {
|
||||
event.KeyID = res.APIKey.ID
|
||||
event.GroupID = res.KeyGroup.ID
|
||||
if res.UpstreamEndpoint != nil {
|
||||
event.UpstreamID = &res.UpstreamEndpoint.ID
|
||||
event.UpstreamURL = &res.UpstreamEndpoint.URL
|
||||
}
|
||||
if res.ProxyConfig != nil {
|
||||
event.ProxyID = &res.ProxyConfig.ID
|
||||
}
|
||||
}
|
||||
return event
|
||||
}
|
||||
|
||||
func (h *ProxyHandler) getResourcesForRequest(c *gin.Context, modelName string, groupName string, isPreciseRouting bool) (*service.RequestResources, error) {
|
||||
authTokenValue, exists := c.Get("authToken")
|
||||
if !exists {
|
||||
return nil, errors.NewAPIError(errors.ErrUnauthorized, "Auth token not found in context")
|
||||
}
|
||||
authToken, ok := authTokenValue.(*models.AuthToken)
|
||||
if !ok {
|
||||
return nil, errors.NewAPIError(errors.ErrInternalServer, "Invalid auth token type in context")
|
||||
}
|
||||
if isPreciseRouting {
|
||||
return h.resourceService.GetResourceFromGroup(authToken, groupName)
|
||||
} else {
|
||||
return h.resourceService.GetResourceFromBasePool(authToken, modelName)
|
||||
}
|
||||
}
|
||||
|
||||
func errToJSON(c *gin.Context, corrID string, apiErr *errors.APIError) {
|
||||
c.JSON(apiErr.HTTPStatus, gin.H{
|
||||
"error": apiErr,
|
||||
"correlation_id": corrID,
|
||||
})
|
||||
}
|
||||
|
||||
type bufferPool struct{}
|
||||
|
||||
func (b *bufferPool) Get() []byte { return make([]byte, 32*1024) }
|
||||
func (b *bufferPool) Put(bytes []byte) {}
|
||||
|
||||
func extractUsage(body []byte) (promptTokens int, completionTokens int) {
|
||||
var data struct {
|
||||
UsageMetadata struct {
|
||||
PromptTokenCount int `json:"promptTokenCount"`
|
||||
CandidatesTokenCount int `json:"candidatesTokenCount"`
|
||||
} `json:"usageMetadata"`
|
||||
}
|
||||
if err := json.Unmarshal(body, &data); err == nil {
|
||||
return data.UsageMetadata.PromptTokenCount, data.UsageMetadata.CandidatesTokenCount
|
||||
}
|
||||
return 0, 0
|
||||
}
|
||||
|
||||
func (h *ProxyHandler) buildFinalRequestConfig(globalSettings *models.SystemSettings, groupConfig *models.RequestConfig) *models.RequestConfig {
	customHeadersJSON, _ := json.Marshal(globalSettings.CustomHeaders)
	var customHeadersMap datatypes.JSONMap
	_ = json.Unmarshal(customHeadersJSON, &customHeadersMap)
	finalConfig := &models.RequestConfig{
		CustomHeaders:         customHeadersMap,
		EnableStreamOptimizer: globalSettings.EnableStreamOptimizer,
		StreamMinDelay:        globalSettings.StreamMinDelay,
		StreamMaxDelay:        globalSettings.StreamMaxDelay,
		StreamShortTextThresh: globalSettings.StreamShortTextThresh,
		StreamLongTextThresh:  globalSettings.StreamLongTextThresh,
		StreamChunkSize:       globalSettings.StreamChunkSize,
		EnableFakeStream:      globalSettings.EnableFakeStream,
		FakeStreamInterval:    globalSettings.FakeStreamInterval,
	}
	if groupConfig == nil {
		return finalConfig
	}
	groupConfigJSON, err := json.Marshal(groupConfig)
	if err != nil {
		h.logger.WithError(err).Error("Failed to marshal group request config for merging.")
		return finalConfig
	}
	if err := json.Unmarshal(groupConfigJSON, finalConfig); err != nil {
		h.logger.WithError(err).Error("Failed to unmarshal group request config for merging.")
		return finalConfig
	}
	return finalConfig
}

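// Illustrative note (not part of the original file): the merge above works by
// JSON round-tripping the group config over the globally seeded finalConfig, so
// a group-level field overrides the global default only if RequestConfig's JSON
// tags actually emit it. Whether a zero-valued group field "wins" therefore
// depends on omitempty on those tags — an assumption worth checking against the
// models package. Rough behaviour, assuming omitempty is set:
//
//	global: EnableStreamOptimizer=true, StreamChunkSize=256
//	group:  StreamChunkSize=64 (all other fields zero / omitted)
//	merged: EnableStreamOptimizer=true, StreamChunkSize=64
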
func (h *ProxyHandler) handleListModelsRequest(c *gin.Context) {
	authTokenValue, exists := c.Get("authToken")
	if !exists {
		errToJSON(c, uuid.New().String(), errors.NewAPIError(errors.ErrUnauthorized, "Auth token not found in context"))
		return
	}
	authToken, ok := authTokenValue.(*models.AuthToken)
	if !ok {
		errToJSON(c, uuid.New().String(), errors.NewAPIError(errors.ErrInternalServer, "Invalid auth token type in context"))
		return
	}
	modelNames := h.resourceService.GetAllowedModelsForToken(authToken)
	if strings.Contains(c.Request.URL.Path, "/v1beta/") {
		h.respondWithGeminiFormat(c, modelNames)
	} else {
		h.respondWithOpenAIFormat(c, modelNames)
	}
}

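// Illustrative note (response shapes follow the two respond* helpers below; the
// exact mounted routes are an assumption):
//
//	path containing "/v1beta/"            -> Gemini-style list:
//	    {"models":[{"name":"models/<model>","displayName":"<model>",...}]}
//	any other path (e.g. OpenAI-style /v1/models) -> OpenAI-style list:
//	    {"object":"list","data":[{"id":"<model>","object":"model",...}]}
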
func (h *ProxyHandler) respondWithOpenAIFormat(c *gin.Context, modelNames []string) {
	type ModelEntry struct {
		ID      string `json:"id"`
		Object  string `json:"object"`
		Created int64  `json:"created"`
		OwnedBy string `json:"owned_by"`
	}
	type ModelListResponse struct {
		Object string       `json:"object"`
		Data   []ModelEntry `json:"data"`
	}
	data := make([]ModelEntry, len(modelNames))
	for i, name := range modelNames {
		data[i] = ModelEntry{
			ID:      name,
			Object:  "model",
			Created: time.Now().Unix(),
			OwnedBy: "gemini-balancer",
		}
	}
	response := ModelListResponse{
		Object: "list",
		Data:   data,
	}
	c.JSON(http.StatusOK, response)
}

func (h *ProxyHandler) respondWithGeminiFormat(c *gin.Context, modelNames []string) {
	type GeminiModelEntry struct {
		Name                       string   `json:"name"`
		Version                    string   `json:"version"`
		DisplayName                string   `json:"displayName"`
		Description                string   `json:"description"`
		SupportedGenerationMethods []string `json:"supportedGenerationMethods"`
		InputTokenLimit            int      `json:"inputTokenLimit"`
		OutputTokenLimit           int      `json:"outputTokenLimit"`
	}
	type GeminiModelListResponse struct {
		Models []GeminiModelEntry `json:"models"`
	}
	models := make([]GeminiModelEntry, len(modelNames))
	for i, name := range modelNames {
		models[i] = GeminiModelEntry{
			Name:                       fmt.Sprintf("models/%s", name),
			Version:                    "1.0.0",
			DisplayName:                name,
			Description:                "Served by Gemini Balancer",
			SupportedGenerationMethods: []string{"generateContent", "streamGenerateContent"},
			InputTokenLimit:            8192,
			OutputTokenLimit:           2048,
		}
	}
	response := GeminiModelListResponse{Models: models}
	c.JSON(http.StatusOK, response)
}
46
internal/handlers/setting_handler.go
Normal file
@@ -0,0 +1,46 @@
// file: gemini-balancer/internal/handlers/setting_handler.go
package handlers

import (
	"gemini-balancer/internal/errors"
	"gemini-balancer/internal/response"
	"gemini-balancer/internal/settings"

	"github.com/gin-gonic/gin"
)

type SettingHandler struct {
	settingsManager *settings.SettingsManager
}

func NewSettingHandler(settingsManager *settings.SettingsManager) *SettingHandler {
	return &SettingHandler{settingsManager: settingsManager}
}

func (h *SettingHandler) GetSettings(c *gin.Context) {
	settings := h.settingsManager.GetSettings()
	response.Success(c, settings)
}

func (h *SettingHandler) UpdateSettings(c *gin.Context) {
	var newSettingsMap map[string]interface{}
	if err := c.ShouldBindJSON(&newSettingsMap); err != nil {
		response.Error(c, errors.NewAPIError(errors.ErrInvalidJSON, err.Error()))
		return
	}
	if err := h.settingsManager.UpdateSettings(newSettingsMap); err != nil {
		// TODO: return a more specific error based on the error type.
		response.Error(c, errors.NewAPIError(errors.ErrInternalServer, err.Error()))
		return
	}
	response.Success(c, gin.H{"message": "Settings update request processed successfully."})
}

// ResetSettingsToDefaults resets all settings to their default values.
func (h *SettingHandler) ResetSettingsToDefaults(c *gin.Context) {
	defaultSettings, err := h.settingsManager.ResetAndSaveSettings()
	if err != nil {
		response.Error(c, errors.NewAPIError(errors.ErrInternalServer, "Failed to reset settings: "+err.Error()))
		return
	}
	response.Success(c, defaultSettings)
}

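// Illustrative wiring sketch (an assumption — the actual router setup lives
// elsewhere in this commit and the paths/HTTP verbs may differ):
//
//	admin := r.Group("/admin")
//	admin.GET("/settings", settingHandler.GetSettings)
//	admin.PUT("/settings", settingHandler.UpdateSettings)
//	admin.POST("/settings/reset", settingHandler.ResetSettingsToDefaults)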
51
internal/handlers/task_handler.go
Normal file
@@ -0,0 +1,51 @@
package handlers

import (
	"encoding/json"
	"gemini-balancer/internal/errors"
	"gemini-balancer/internal/response"
	"gemini-balancer/internal/task"

	"github.com/gin-gonic/gin"
	"github.com/sirupsen/logrus"
)

type TaskHandler struct {
	taskService *task.Task
	logger      *logrus.Entry
}

func NewTaskHandler(taskService *task.Task, logger *logrus.Logger) *TaskHandler {
	return &TaskHandler{
		taskService: taskService,
		logger:      logger.WithField("component", "TaskHandler📦"),
	}
}

// GetTaskStatus handles GET /admin/tasks/:id.
func (h *TaskHandler) GetTaskStatus(c *gin.Context) {
	taskID := c.Param("id")
	if taskID == "" {
		response.Error(c, errors.NewAPIError(errors.ErrBadRequest, "task ID is required"))
		return
	}

	taskStatus, err := h.taskService.GetStatus(taskID)
	if err != nil {
		// TODO: handle the specific error types returned by the service layer more precisely.
		response.Error(c, errors.NewAPIError(errors.ErrResourceNotFound, err.Error()))
		return
	}
	// [Probe] Log the status object read from the store and parsed, right before it is sent to the frontend.
	loggerWithTaskID := h.logger.WithField("task_id", taskID)
	loggerWithTaskID.Debugf("Status read from store, ABOUT TO BE SENT to frontend: %+v", taskStatus)
	// [Probe] Manually marshal and log the payload.
	if h.logger.Logger.IsLevelEnabled(logrus.DebugLevel) {
		jsonData, _ := json.Marshal(taskStatus)
		loggerWithTaskID.Debugf("Manually marshalled JSON to be sent to frontend: %s", string(jsonData))
	}
	response.Success(c, taskStatus)
}

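// Illustrative usage (the route comes from the handler comment above; the exact
// envelope produced by response.Success/response.Error is an assumption):
//
//	GET /admin/tasks/<task-id>   -> 200 with the stored task status payload
//	GET /admin/tasks/<unknown>   -> error mapped to ErrResourceNotFound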
51
internal/handlers/tokens_handler.go
Normal file
@@ -0,0 +1,51 @@
// Filename: internal/handlers/tokens_handler.go

package handlers

import (
	"gemini-balancer/internal/errors"
	"gemini-balancer/internal/models"
	"gemini-balancer/internal/response"
	"gemini-balancer/internal/service"

	"github.com/gin-gonic/gin"
	"gorm.io/gorm"
)

type TokensHandler struct {
	db           *gorm.DB
	tokenManager *service.TokenManager
}

func NewTokensHandler(db *gorm.DB, tm *service.TokenManager) *TokensHandler {
	return &TokensHandler{
		db:           db,
		tokenManager: tm,
	}
}

func (h *TokensHandler) GetAllTokens(c *gin.Context) {
	tokensFromCache := h.tokenManager.GetAllTokens()
	// TODO: like KeyGroupResponse, a TokenResponse DTO could be introduced to shape this data.
	response.Success(c, tokensFromCache)
}

func (h *TokensHandler) UpdateTokens(c *gin.Context) {
	var incomingTokens []*models.TokenUpdateRequest
	if err := c.ShouldBindJSON(&incomingTokens); err != nil {
		response.Error(c, errors.NewAPIError(errors.ErrInvalidJSON, err.Error()))
		return
	}

	if err := h.tokenManager.BatchUpdateTokens(incomingTokens); err != nil {
		response.Error(c, errors.NewAPIError(errors.ErrInternalServer, "Failed to update tokens: "+err.Error()))
		return
	}
	response.Success(c, gin.H{"message": "Tokens updated successfully."})
}

// [TODO]
// func (h *TokensHandler) CreateToken(c *gin.Context) {
// 	... database write ...
// 	h.tokenManager.Invalidate() // invalidate the cache immediately after the write
// }
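// Illustrative sketch of the TODO above (assumptions: the request/DTO shape,
// that tokens are stored as models.AuthToken rows, and that
// TokenManager.Invalidate exists as referenced in the comment):
//
//	func (h *TokensHandler) CreateToken(c *gin.Context) {
//		var req models.AuthToken
//		if err := c.ShouldBindJSON(&req); err != nil {
//			response.Error(c, errors.NewAPIError(errors.ErrInvalidJSON, err.Error()))
//			return
//		}
//		if err := h.db.Create(&req).Error; err != nil {
//			response.Error(c, errors.NewAPIError(errors.ErrInternalServer, err.Error()))
//			return
//		}
//		h.tokenManager.Invalidate() // drop the cache right after the write
//		response.Success(c, req)
//	}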