Update Context for store

This commit is contained in:
XOF
2025-11-22 14:20:05 +08:00
parent ac0e0a8275
commit 2b0b9b67dc
31 changed files with 817 additions and 1016 deletions

View File

@@ -1,4 +1,4 @@
// Filename: internal/handlers/apikey_handler.go
// Filename: internal/handlers/apikey_handler.go (最终决战版)
package handlers
import (
@@ -31,11 +31,10 @@ func NewAPIKeyHandler(apiKeyService *service.APIKeyService, db *gorm.DB, keyImpo
}
}
// DTOs for API requests
type BulkAddKeysToGroupRequest struct {
KeyGroupID uint `json:"key_group_id" binding:"required"`
Keys string `json:"keys" binding:"required"`
ValidateOnImport bool `json:"validate_on_import"` // OmitEmpty/default is false
ValidateOnImport bool `json:"validate_on_import"`
}
type BulkUnlinkKeysFromGroupRequest struct {
@@ -72,11 +71,11 @@ type BulkTestKeysForGroupRequest struct {
}
type BulkActionFilter struct {
Status []string `json:"status"` // Changed to slice to accept multiple statuses
Status []string `json:"status"`
}
type BulkActionRequest struct {
Action string `json:"action" binding:"required,oneof=revalidate set_status delete"`
NewStatus string `json:"new_status" binding:"omitempty,oneof=active disabled cooldown banned"` // For 'set_status' action
NewStatus string `json:"new_status" binding:"omitempty,oneof=active disabled cooldown banned"`
Filter BulkActionFilter `json:"filter" binding:"required"`
}
@@ -89,7 +88,8 @@ func (h *APIKeyHandler) AddMultipleKeysToGroup(c *gin.Context) {
response.Error(c, errors.NewAPIError(errors.ErrInvalidJSON, err.Error()))
return
}
taskStatus, err := h.keyImportService.StartAddKeysTask(req.KeyGroupID, req.Keys, req.ValidateOnImport)
// [修正] 将请求的 context 传递给 service 层
taskStatus, err := h.keyImportService.StartAddKeysTask(c.Request.Context(), req.KeyGroupID, req.Keys, req.ValidateOnImport)
if err != nil {
response.Error(c, errors.NewAPIError(errors.ErrTaskInProgress, err.Error()))
return
@@ -104,7 +104,8 @@ func (h *APIKeyHandler) UnlinkMultipleKeysFromGroup(c *gin.Context) {
response.Error(c, errors.NewAPIError(errors.ErrInvalidJSON, err.Error()))
return
}
taskStatus, err := h.keyImportService.StartUnlinkKeysTask(req.KeyGroupID, req.Keys)
// [修正] 将请求的 context 传递给 service 层
taskStatus, err := h.keyImportService.StartUnlinkKeysTask(c.Request.Context(), req.KeyGroupID, req.Keys)
if err != nil {
response.Error(c, errors.NewAPIError(errors.ErrTaskInProgress, err.Error()))
return
@@ -119,7 +120,8 @@ func (h *APIKeyHandler) HardDeleteMultipleKeys(c *gin.Context) {
response.Error(c, errors.NewAPIError(errors.ErrInvalidJSON, err.Error()))
return
}
taskStatus, err := h.keyImportService.StartHardDeleteKeysTask(req.Keys)
// [修正] 将请求的 context 传递给 service 层
taskStatus, err := h.keyImportService.StartHardDeleteKeysTask(c.Request.Context(), req.Keys)
if err != nil {
response.Error(c, errors.NewAPIError(errors.ErrTaskInProgress, err.Error()))
return
@@ -134,7 +136,8 @@ func (h *APIKeyHandler) RestoreMultipleKeys(c *gin.Context) {
response.Error(c, errors.NewAPIError(errors.ErrInvalidJSON, err.Error()))
return
}
taskStatus, err := h.keyImportService.StartRestoreKeysTask(req.Keys)
// [修正] 将请求的 context 传递给 service 层
taskStatus, err := h.keyImportService.StartRestoreKeysTask(c.Request.Context(), req.Keys)
if err != nil {
response.Error(c, errors.NewAPIError(errors.ErrTaskInProgress, err.Error()))
return
@@ -148,7 +151,8 @@ func (h *APIKeyHandler) TestMultipleKeys(c *gin.Context) {
response.Error(c, errors.NewAPIError(errors.ErrInvalidJSON, err.Error()))
return
}
taskStatus, err := h.keyValidationService.StartTestKeysTask(req.KeyGroupID, req.Keys)
// [修正] 将请求的 context 传递给 service 层
taskStatus, err := h.keyValidationService.StartTestKeysTask(c.Request.Context(), req.KeyGroupID, req.Keys)
if err != nil {
response.Error(c, errors.NewAPIError(errors.ErrTaskInProgress, err.Error()))
return
@@ -172,7 +176,7 @@ func (h *APIKeyHandler) ListAPIKeys(c *gin.Context) {
}
}
if len(ids) > 0 {
keys, err := h.apiKeyService.GetKeysByIds(ids)
keys, err := h.apiKeyService.GetKeysByIds(c.Request.Context(), ids)
if err != nil {
response.Error(c, &errors.APIError{
HTTPStatus: http.StatusInternalServerError,
@@ -191,7 +195,7 @@ func (h *APIKeyHandler) ListAPIKeys(c *gin.Context) {
if params.PageSize <= 0 {
params.PageSize = 20
}
result, err := h.apiKeyService.ListAPIKeys(&params)
result, err := h.apiKeyService.ListAPIKeys(c.Request.Context(), &params)
if err != nil {
response.Error(c, errors.ParseDBError(err))
return
@@ -201,19 +205,16 @@ func (h *APIKeyHandler) ListAPIKeys(c *gin.Context) {
// ListKeysForGroup handles the GET /keygroups/:id/keys request.
func (h *APIKeyHandler) ListKeysForGroup(c *gin.Context) {
// 1. Manually handle the path parameter.
groupID, err := strconv.ParseUint(c.Param("id"), 10, 32)
if err != nil {
response.Error(c, errors.NewAPIError(errors.ErrBadRequest, "Invalid group ID format"))
return
}
// 2. Bind query parameters using the correctly tagged struct.
var params models.APIKeyQueryParams
if err := c.ShouldBindQuery(&params); err != nil {
response.Error(c, errors.NewAPIError(errors.ErrBadRequest, err.Error()))
return
}
// 3. Set server-side defaults and the path parameter.
if params.Page <= 0 {
params.Page = 1
}
@@ -221,15 +222,11 @@ func (h *APIKeyHandler) ListKeysForGroup(c *gin.Context) {
params.PageSize = 20
}
params.KeyGroupID = uint(groupID)
// 4. Call the service layer.
paginatedResult, err := h.apiKeyService.ListAPIKeys(&params)
paginatedResult, err := h.apiKeyService.ListAPIKeys(c.Request.Context(), &params)
if err != nil {
response.Error(c, errors.ParseDBError(err))
return
}
// 5. [THE FIX] Return a successful response using the standard `response.Success`
// and a gin.H map, as confirmed to exist in your project.
response.Success(c, gin.H{
"items": paginatedResult.Items,
"total": paginatedResult.Total,
@@ -239,20 +236,18 @@ func (h *APIKeyHandler) ListKeysForGroup(c *gin.Context) {
}
func (h *APIKeyHandler) TestKeysForGroup(c *gin.Context) {
// Group ID is now correctly sourced from the URL path.
groupID, err := strconv.ParseUint(c.Param("id"), 10, 32)
if err != nil {
response.Error(c, errors.NewAPIError(errors.ErrBadRequest, "Invalid group ID format"))
return
}
// The request body is now simpler, only needing the keys.
var req BulkTestKeysForGroupRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.Error(c, errors.NewAPIError(errors.ErrInvalidJSON, err.Error()))
return
}
// Call the same underlying service, but with unambiguous context.
taskStatus, err := h.keyValidationService.StartTestKeysTask(uint(groupID), req.Keys)
// [修正] 将请求的 context 传递给 service 层
taskStatus, err := h.keyValidationService.StartTestKeysTask(c.Request.Context(), uint(groupID), req.Keys)
if err != nil {
response.Error(c, errors.NewAPIError(errors.ErrTaskInProgress, err.Error()))
return
@@ -267,7 +262,6 @@ func (h *APIKeyHandler) UpdateAPIKey(c *gin.Context) {
}
// UpdateGroupAPIKeyMapping handles updating a key's status within a specific group.
// Route: PUT /keygroups/:id/apikeys/:keyId
func (h *APIKeyHandler) UpdateGroupAPIKeyMapping(c *gin.Context) {
groupID, err := strconv.ParseUint(c.Param("id"), 10, 32)
if err != nil {
@@ -284,8 +278,7 @@ func (h *APIKeyHandler) UpdateGroupAPIKeyMapping(c *gin.Context) {
response.Error(c, errors.NewAPIError(errors.ErrInvalidJSON, err.Error()))
return
}
// Directly use the service to handle the logic
updatedMapping, err := h.apiKeyService.UpdateMappingStatus(uint(groupID), uint(keyID), req.Status)
updatedMapping, err := h.apiKeyService.UpdateMappingStatus(c.Request.Context(), uint(groupID), uint(keyID), req.Status)
if err != nil {
var apiErr *errors.APIError
if errors.As(err, &apiErr) {
@@ -305,7 +298,7 @@ func (h *APIKeyHandler) HardDeleteAPIKey(c *gin.Context) {
response.Error(c, errors.NewAPIError(errors.ErrBadRequest, "Invalid ID format"))
return
}
if err := h.apiKeyService.HardDeleteAPIKeyByID(uint(id)); err != nil {
if err := h.apiKeyService.HardDeleteAPIKeyByID(c.Request.Context(), uint(id)); err != nil {
response.Error(c, errors.ParseDBError(err))
return
}
@@ -313,7 +306,6 @@ func (h *APIKeyHandler) HardDeleteAPIKey(c *gin.Context) {
}
// RestoreKeysInGroup 恢复指定Key的接口
// POST /keygroups/:id/apikeys/restore
func (h *APIKeyHandler) RestoreKeysInGroup(c *gin.Context) {
groupID, err := strconv.ParseUint(c.Param("id"), 10, 32)
if err != nil {
@@ -325,7 +317,7 @@ func (h *APIKeyHandler) RestoreKeysInGroup(c *gin.Context) {
response.Error(c, errors.NewAPIError(errors.ErrInvalidJSON, err.Error()))
return
}
taskStatus, err := h.apiKeyService.StartRestoreKeysTask(uint(groupID), req.KeyIDs)
taskStatus, err := h.apiKeyService.StartRestoreKeysTask(c.Request.Context(), uint(groupID), req.KeyIDs)
if err != nil {
var apiErr *errors.APIError
if errors.As(err, &apiErr) {
@@ -339,14 +331,13 @@ func (h *APIKeyHandler) RestoreKeysInGroup(c *gin.Context) {
}
// RestoreAllBannedInGroup 一键恢复所有Banned Key的接口
// POST /keygroups/:id/apikeys/restore-all-banned
func (h *APIKeyHandler) RestoreAllBannedInGroup(c *gin.Context) {
groupID, err := strconv.ParseUint(c.Param("groupId"), 10, 32)
if err != nil {
response.Error(c, errors.NewAPIError(errors.ErrBadRequest, "Invalid Group ID format"))
return
}
taskStatus, err := h.apiKeyService.StartRestoreAllBannedTask(uint(groupID))
taskStatus, err := h.apiKeyService.StartRestoreAllBannedTask(c.Request.Context(), uint(groupID))
if err != nil {
var apiErr *errors.APIError
if errors.As(err, &apiErr) {
@@ -360,48 +351,41 @@ func (h *APIKeyHandler) RestoreAllBannedInGroup(c *gin.Context) {
}
// HandleBulkAction handles generic bulk actions on a key group based on server-side filters.
// Route: POST /keygroups/:id/bulk-actions
func (h *APIKeyHandler) HandleBulkAction(c *gin.Context) {
// 1. Parse GroupID from URL
groupID, err := strconv.ParseUint(c.Param("id"), 10, 32)
if err != nil {
response.Error(c, errors.NewAPIError(errors.ErrBadRequest, "Invalid Group ID format"))
return
}
// 2. Bind the JSON payload to our new DTO
var req BulkActionRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.Error(c, errors.NewAPIError(errors.ErrInvalidJSON, err.Error()))
return
}
// 3. Central logic: based on the action, call the appropriate service method.
var task *task.Status
var apiErr *errors.APIError
switch req.Action {
case "revalidate":
// Assume keyValidationService has a method that accepts a filter
task, err = h.keyValidationService.StartTestKeysByFilterTask(uint(groupID), req.Filter.Status)
// [修正] 将请求的 context 传递给 service 层
task, err = h.keyValidationService.StartTestKeysByFilterTask(c.Request.Context(), uint(groupID), req.Filter.Status)
case "set_status":
if req.NewStatus == "" {
apiErr = errors.NewAPIError(errors.ErrBadRequest, "new_status is required for set_status action")
break
}
// Assume apiKeyService has a method to update status by filter
targetStatus := models.APIKeyStatus(req.NewStatus) // Convert string to your model's type
task, err = h.apiKeyService.StartUpdateStatusByFilterTask(uint(groupID), req.Filter.Status, targetStatus)
targetStatus := models.APIKeyStatus(req.NewStatus)
task, err = h.apiKeyService.StartUpdateStatusByFilterTask(c.Request.Context(), uint(groupID), req.Filter.Status, targetStatus)
case "delete":
// Assume keyImportService has a method to unlink by filter
task, err = h.keyImportService.StartUnlinkKeysByFilterTask(uint(groupID), req.Filter.Status)
// [修正] 将请求的 context 传递给 service 层
task, err = h.keyImportService.StartUnlinkKeysByFilterTask(c.Request.Context(), uint(groupID), req.Filter.Status)
default:
apiErr = errors.NewAPIError(errors.ErrBadRequest, "Unsupported action: "+req.Action)
}
// 4. Handle errors from the switch block
if apiErr != nil {
response.Error(c, apiErr)
return
}
if err != nil {
// Attempt to parse it as a known APIError, otherwise, wrap it.
var parsedErr *errors.APIError
if errors.As(err, &parsedErr) {
response.Error(c, parsedErr)
@@ -410,21 +394,18 @@ func (h *APIKeyHandler) HandleBulkAction(c *gin.Context) {
}
return
}
// 5. Return the task status on success
response.Success(c, task)
}
// ExportKeysForGroup handles requests to export all keys for a group based on status filters.
// Route: GET /keygroups/:id/apikeys/export
func (h *APIKeyHandler) ExportKeysForGroup(c *gin.Context) {
groupID, err := strconv.ParseUint(c.Param("id"), 10, 32)
if err != nil {
response.Error(c, errors.NewAPIError(errors.ErrBadRequest, "Invalid Group ID format"))
return
}
// Use QueryArray to correctly parse `status[]=active&status[]=cooldown`
statuses := c.QueryArray("status")
keyStrings, err := h.apiKeyService.GetAPIKeyStringsForExport(uint(groupID), statuses)
keyStrings, err := h.apiKeyService.GetAPIKeyStringsForExport(c.Request.Context(), uint(groupID), statuses)
if err != nil {
response.Error(c, errors.ParseDBError(err))
return

View File

@@ -30,7 +30,7 @@ func (h *DashboardHandler) GetOverview(c *gin.Context) {
c.JSON(http.StatusOK, stats)
}
// GetChart 获取仪表盘的图表数据
// GetChart
func (h *DashboardHandler) GetChart(c *gin.Context) {
var groupID *uint
if groupIDStr := c.Query("groupId"); groupIDStr != "" {
@@ -40,7 +40,7 @@ func (h *DashboardHandler) GetChart(c *gin.Context) {
}
}
chartData, err := h.queryService.QueryHistoricalChart(groupID)
chartData, err := h.queryService.QueryHistoricalChart(c.Request.Context(), groupID)
if err != nil {
apiErr := errors.NewAPIError(errors.ErrDatabase, err.Error())
c.JSON(apiErr.HTTPStatus, apiErr)
@@ -49,10 +49,10 @@ func (h *DashboardHandler) GetChart(c *gin.Context) {
c.JSON(http.StatusOK, chartData)
}
// GetRequestStats 处理对“期间调用概览”的请求
// GetRequestStats
func (h *DashboardHandler) GetRequestStats(c *gin.Context) {
period := c.Param("period") // 从 URL 路径中获取 period
stats, err := h.queryService.GetRequestStatsForPeriod(period)
period := c.Param("period")
stats, err := h.queryService.GetRequestStatsForPeriod(c.Request.Context(), period)
if err != nil {
apiErr := errors.NewAPIError(errors.ErrBadRequest, err.Error())
c.JSON(apiErr.HTTPStatus, apiErr)

View File

@@ -2,6 +2,7 @@
package handlers
import (
"context"
"encoding/json"
"fmt"
"gemini-balancer/internal/errors"
@@ -31,7 +32,6 @@ func NewKeyGroupHandler(gm *service.GroupManager, s store.Store, qs *service.Das
}
}
// DTOs & 辅助函数
func isValidGroupName(name string) bool {
if name == "" {
return false
@@ -40,7 +40,6 @@ func isValidGroupName(name string) bool {
return match
}
// KeyGroupOperationalSettings defines the shared operational settings for a key group.
type KeyGroupOperationalSettings struct {
EnableKeyCheck *bool `json:"enable_key_check"`
KeyCheckIntervalMinutes *int `json:"key_check_interval_minutes"`
@@ -52,7 +51,6 @@ type KeyGroupOperationalSettings struct {
MaxRetries *int `json:"max_retries"`
EnableSmartGateway *bool `json:"enable_smart_gateway"`
}
type CreateKeyGroupRequest struct {
Name string `json:"name" binding:"required"`
DisplayName string `json:"display_name"`
@@ -60,11 +58,8 @@ type CreateKeyGroupRequest struct {
PollingStrategy string `json:"polling_strategy" binding:"omitempty,oneof=sequential random weighted"`
EnableProxy bool `json:"enable_proxy"`
ChannelType string `json:"channel_type"`
// Embed shared operational settings
KeyGroupOperationalSettings
}
type UpdateKeyGroupRequest struct {
Name *string `json:"name"`
DisplayName *string `json:"display_name"`
@@ -72,15 +67,10 @@ type UpdateKeyGroupRequest struct {
PollingStrategy *string `json:"polling_strategy" binding:"omitempty,oneof=sequential random weighted"`
EnableProxy *bool `json:"enable_proxy"`
ChannelType *string `json:"channel_type"`
// Embed shared operational settings
KeyGroupOperationalSettings
// M:N associations
AllowedUpstreams []string `json:"allowed_upstreams"`
AllowedModels []string `json:"allowed_models"`
}
type KeyGroupResponse struct {
ID uint `json:"id"`
Name string `json:"name"`
@@ -96,36 +86,30 @@ type KeyGroupResponse struct {
AllowedModels []string `json:"allowed_models"`
AllowedUpstreams []string `json:"allowed_upstreams"`
}
// [NEW] Define the detailed response structure for a single group.
type KeyGroupDetailsResponse struct {
KeyGroupResponse
Settings *models.GroupSettings `json:"settings,omitempty"`
RequestConfig *models.RequestConfig `json:"request_config,omitempty"`
}
// transformModelsToStrings converts a slice of GroupModelMapping pointers to a slice of model names.
func transformModelsToStrings(mappings []*models.GroupModelMapping) []string {
modelNames := make([]string, 0, len(mappings))
for _, mapping := range mappings {
if mapping != nil { // Safety check
if mapping != nil {
modelNames = append(modelNames, mapping.ModelName)
}
}
return modelNames
}
// transformUpstreamsToStrings converts a slice of UpstreamEndpoint pointers to a slice of URLs.
func transformUpstreamsToStrings(upstreams []*models.UpstreamEndpoint) []string {
urls := make([]string, 0, len(upstreams))
for _, upstream := range upstreams {
if upstream != nil { // Safety check
if upstream != nil {
urls = append(urls, upstream.URL)
}
}
return urls
}
func (h *KeyGroupHandler) newKeyGroupResponse(group *models.KeyGroup, keyCount int64) KeyGroupResponse {
return KeyGroupResponse{
ID: group.ID,
@@ -139,13 +123,10 @@ func (h *KeyGroupHandler) newKeyGroupResponse(group *models.KeyGroup, keyCount i
CreatedAt: group.CreatedAt,
UpdatedAt: group.UpdatedAt,
Order: group.Order,
AllowedModels: transformModelsToStrings(group.AllowedModels), // Call the new helper
AllowedUpstreams: transformUpstreamsToStrings(group.AllowedUpstreams), // Call the new helper
AllowedModels: transformModelsToStrings(group.AllowedModels),
AllowedUpstreams: transformUpstreamsToStrings(group.AllowedUpstreams),
}
}
// packGroupSettings is a helper to convert request-level operational settings
// into the model-level settings struct.
func packGroupSettings(settings KeyGroupOperationalSettings) *models.KeyGroupSettings {
return &models.KeyGroupSettings{
EnableKeyCheck: settings.EnableKeyCheck,
@@ -159,7 +140,6 @@ func packGroupSettings(settings KeyGroupOperationalSettings) *models.KeyGroupSet
EnableSmartGateway: settings.EnableSmartGateway,
}
}
func (h *KeyGroupHandler) getGroupFromContext(c *gin.Context) (*models.KeyGroup, *errors.APIError) {
id, err := strconv.Atoi(c.Param("id"))
if err != nil {
@@ -171,7 +151,6 @@ func (h *KeyGroupHandler) getGroupFromContext(c *gin.Context) (*models.KeyGroup,
}
return group, nil
}
func applyUpdateRequestToGroup(req *UpdateKeyGroupRequest, group *models.KeyGroup) {
if req.Name != nil {
group.Name = *req.Name
@@ -197,9 +176,10 @@ func applyUpdateRequestToGroup(req *UpdateKeyGroupRequest, group *models.KeyGrou
// publishGroupChangeEvent encapsulates the logic for marshaling and publishing a group change event.
func (h *KeyGroupHandler) publishGroupChangeEvent(groupID uint, reason string) {
go func() {
ctx := context.Background()
event := models.KeyStatusChangedEvent{GroupID: groupID, ChangeReason: reason}
eventData, _ := json.Marshal(event)
h.store.Publish(models.TopicKeyStatusChanged, eventData)
_ = h.store.Publish(ctx, models.TopicKeyStatusChanged, eventData)
}()
}
@@ -216,7 +196,6 @@ func (h *KeyGroupHandler) CreateKeyGroup(c *gin.Context) {
return
}
// The core logic remains, as it's specific to creation.
p := bluemonday.StripTagsPolicy()
sanitizedDisplayName := p.Sanitize(req.DisplayName)
sanitizedDescription := p.Sanitize(req.Description)
@@ -244,11 +223,9 @@ func (h *KeyGroupHandler) CreateKeyGroup(c *gin.Context) {
response.Created(c, h.newKeyGroupResponse(keyGroup, 0))
}
// 统一的处理器可以处理两种情况:
// 1. GET /keygroups - 返回所有组的列表
// 2. GET /keygroups/:id - 返回指定ID的单个组
func (h *KeyGroupHandler) GetKeyGroups(c *gin.Context) {
// Case 1: Get a single group
if idStr := c.Param("id"); idStr != "" {
group, apiErr := h.getGroupFromContext(c)
if apiErr != nil {
@@ -265,7 +242,6 @@ func (h *KeyGroupHandler) GetKeyGroups(c *gin.Context) {
response.Success(c, detailedResponse)
return
}
// Case 2: Get all groups
allGroups := h.groupManager.GetAllGroups()
responses := make([]KeyGroupResponse, 0, len(allGroups))
for _, group := range allGroups {
@@ -275,7 +251,6 @@ func (h *KeyGroupHandler) GetKeyGroups(c *gin.Context) {
response.Success(c, responses)
}
// UpdateKeyGroup
func (h *KeyGroupHandler) UpdateKeyGroup(c *gin.Context) {
group, apiErr := h.getGroupFromContext(c)
if apiErr != nil {
@@ -304,7 +279,6 @@ func (h *KeyGroupHandler) UpdateKeyGroup(c *gin.Context) {
response.Success(c, h.newKeyGroupResponse(freshGroup, keyCount))
}
// DeleteKeyGroup
func (h *KeyGroupHandler) DeleteKeyGroup(c *gin.Context) {
group, apiErr := h.getGroupFromContext(c)
if apiErr != nil {
@@ -320,14 +294,14 @@ func (h *KeyGroupHandler) DeleteKeyGroup(c *gin.Context) {
response.Success(c, gin.H{"message": fmt.Sprintf("Group '%s' and its associated keys deleted successfully", groupName)})
}
// GetKeyGroupStats
func (h *KeyGroupHandler) GetKeyGroupStats(c *gin.Context) {
group, apiErr := h.getGroupFromContext(c)
if apiErr != nil {
response.Error(c, apiErr)
return
}
stats, err := h.queryService.GetGroupStats(group.ID)
stats, err := h.queryService.GetGroupStats(c.Request.Context(), group.ID)
if err != nil {
response.Error(c, errors.NewAPIError(errors.ErrDatabase, err.Error()))
return
@@ -350,7 +324,6 @@ func (h *KeyGroupHandler) CloneKeyGroup(c *gin.Context) {
response.Created(c, h.newKeyGroupResponse(clonedGroup, keyCount))
}
// 更新分组排序
func (h *KeyGroupHandler) UpdateKeyGroupOrder(c *gin.Context) {
var payload []service.UpdateOrderPayload
if err := c.ShouldBindJSON(&payload); err != nil {

View File

@@ -136,7 +136,6 @@ func (h *ProxyHandler) serveTransparentProxy(c *gin.Context, requestBody []byte,
var finalPromptTokens, finalCompletionTokens int
var actualRetries int = 0
defer func() {
// 如果一次尝试都未成功(例如,在第一次获取资源时就失败),则不记录日志
if lastUsedResources == nil {
h.logger.WithField("id", correlationID).Warn("No resources were used, skipping final log event.")
return
@@ -151,44 +150,38 @@ func (h *ProxyHandler) serveTransparentProxy(c *gin.Context, requestBody []byte,
finalEvent.RequestLog.CompletionTokens = finalCompletionTokens
}
// 确保即使在成功的情况下如果recorder存在也记录最终的状态码
if finalRecorder != nil {
finalEvent.RequestLog.StatusCode = finalRecorder.Code
}
if !isSuccess {
// 将 finalProxyErr 的信息填充到 RequestLog 中
if finalProxyErr != nil {
finalEvent.Error = finalProxyErr // Error 字段用于事件传递,不会被序列化到数据库
finalEvent.Error = finalProxyErr
finalEvent.RequestLog.ErrorCode = finalProxyErr.Code
finalEvent.RequestLog.ErrorMessage = finalProxyErr.Message
} else if finalRecorder != nil {
// 降级处理:如果 finalProxyErr 为空但 recorder 存在且失败
apiErr := errors.NewAPIErrorWithUpstream(finalRecorder.Code, fmt.Sprintf("UPSTREAM_%d", finalRecorder.Code), "Request failed after all retries.")
finalEvent.Error = apiErr
finalEvent.RequestLog.ErrorCode = apiErr.Code
finalEvent.RequestLog.ErrorMessage = apiErr.Message
}
}
// 将完整的事件发布
eventData, err := json.Marshal(finalEvent)
if err != nil {
h.logger.WithField("id", correlationID).WithError(err).Error("Failed to marshal final log event.")
return
}
if err := h.store.Publish(models.TopicRequestFinished, eventData); err != nil {
if err := h.store.Publish(context.Background(), models.TopicRequestFinished, eventData); err != nil {
h.logger.WithField("id", correlationID).WithError(err).Error("Failed to publish final log event.")
}
}()
var maxRetries int
if isPreciseRouting {
// For precise routing, use the group's setting. If not set, fall back to the global setting.
if finalOpConfig.MaxRetries != nil {
maxRetries = *finalOpConfig.MaxRetries
} else {
maxRetries = h.settingsManager.GetSettings().MaxRetries
}
} else {
// For BasePool (intelligent aggregation), *always* use the global setting.
maxRetries = h.settingsManager.GetSettings().MaxRetries
}
totalAttempts := maxRetries + 1
@@ -332,7 +325,7 @@ func (h *ProxyHandler) serveTransparentProxy(c *gin.Context, requestBody []byte,
retryEvent.ErrorMessage = attemptErr.Message
}
eventData, _ := json.Marshal(retryEvent)
_ = h.store.Publish(models.TopicRequestFinished, eventData)
_ = h.store.Publish(context.Background(), models.TopicRequestFinished, eventData)
}
if finalRecorder != nil {
bodyBytes := finalRecorder.Body.Bytes()
@@ -368,7 +361,7 @@ func (h *ProxyHandler) serveSmartStream(c *gin.Context, requestBody []byte, reso
requestFinishedEvent.StatusCode = c.Writer.Status()
}
eventData, _ := json.Marshal(requestFinishedEvent)
_ = h.store.Publish(models.TopicRequestFinished, eventData)
_ = h.store.Publish(context.Background(), models.TopicRequestFinished, eventData)
}()
params := channel.SmartRequestParams{
CorrelationID: correlationID,
@@ -435,7 +428,6 @@ func (h *ProxyHandler) createLogEvent(c *gin.Context, startTime time.Time, corrI
}
}
if res != nil {
// [核心修正] 填充到内嵌的 RequestLog 结构体中
if res.APIKey != nil {
event.RequestLog.KeyID = &res.APIKey.ID
}
@@ -444,7 +436,6 @@ func (h *ProxyHandler) createLogEvent(c *gin.Context, startTime time.Time, corrI
}
if res.UpstreamEndpoint != nil {
event.RequestLog.UpstreamID = &res.UpstreamEndpoint.ID
// UpstreamURL 是事件传递字段,不是数据库字段,所以在这里赋值是正确的
event.UpstreamURL = &res.UpstreamEndpoint.URL
}
if res.ProxyConfig != nil {
@@ -464,9 +455,9 @@ func (h *ProxyHandler) getResourcesForRequest(c *gin.Context, modelName string,
return nil, errors.NewAPIError(errors.ErrInternalServer, "Invalid auth token type in context")
}
if isPreciseRouting {
return h.resourceService.GetResourceFromGroup(authToken, groupName)
return h.resourceService.GetResourceFromGroup(c.Request.Context(), authToken, groupName)
} else {
return h.resourceService.GetResourceFromBasePool(authToken, modelName)
return h.resourceService.GetResourceFromBasePool(c.Request.Context(), authToken, modelName)
}
}

View File

@@ -33,7 +33,7 @@ func (h *TaskHandler) GetTaskStatus(c *gin.Context) {
return
}
taskStatus, err := h.taskService.GetStatus(taskID)
taskStatus, err := h.taskService.GetStatus(c.Request.Context(), taskID)
if err != nil {
// TODO 可以根据 service 层返回的具体错误类型进行更精细的处理
response.Error(c, errors.NewAPIError(errors.ErrResourceNotFound, err.Error()))