diff --git a/internal/domain/proxy/handler.go b/internal/domain/proxy/handler.go index 0967ad4..8b98b2f 100644 --- a/internal/domain/proxy/handler.go +++ b/internal/domain/proxy/handler.go @@ -2,6 +2,7 @@ package proxy import ( + "context" "encoding/json" "gemini-balancer/internal/errors" "gemini-balancer/internal/models" @@ -49,7 +50,6 @@ func (h *handler) registerRoutes(rg *gin.RouterGroup) { } } -// --- 请求 DTO --- type CreateProxyConfigRequest struct { Address string `json:"address" binding:"required"` Protocol string `json:"protocol" binding:"required,oneof=http https socks5"` @@ -64,12 +64,10 @@ type UpdateProxyConfigRequest struct { Description *string `json:"description"` } -// 单个检测的请求体 (与前端JS对齐) type CheckSingleProxyRequest struct { Proxy string `json:"proxy" binding:"required"` } -// 批量检测的请求体 type CheckAllProxiesRequest struct { Proxies []string `json:"proxies" binding:"required"` } @@ -84,7 +82,7 @@ func (h *handler) CreateProxyConfig(c *gin.Context) { } if req.Status == "" { - req.Status = "active" // 默认状态 + req.Status = "active" } proxyConfig := models.ProxyConfig{ @@ -98,7 +96,6 @@ func (h *handler) CreateProxyConfig(c *gin.Context) { response.Error(c, errors.ParseDBError(err)) return } - // 写操作后,发布事件并使缓存失效 h.publishAndInvalidate(proxyConfig.ID, "created") response.Created(c, proxyConfig) } @@ -199,17 +196,16 @@ func (h *handler) DeleteProxyConfig(c *gin.Context) { response.NoContent(c) } -// publishAndInvalidate 统一事件发布和缓存失效逻辑 func (h *handler) publishAndInvalidate(proxyID uint, action string) { go h.manager.invalidate() go func() { + ctx := context.Background() event := models.ProxyStatusChangedEvent{ProxyID: proxyID, Action: action} eventData, _ := json.Marshal(event) - _ = h.store.Publish(models.TopicProxyStatusChanged, eventData) + _ = h.store.Publish(ctx, models.TopicProxyStatusChanged, eventData) }() } -// 新的 Handler 方法和 DTO type SyncProxiesRequest struct { Proxies []string `json:"proxies"` } @@ -220,14 +216,12 @@ func (h *handler) SyncProxies(c *gin.Context) { response.Error(c, errors.NewAPIError(errors.ErrInvalidJSON, err.Error())) return } - taskStatus, err := h.manager.SyncProxiesInBackground(req.Proxies) + + taskStatus, err := h.manager.SyncProxiesInBackground(c.Request.Context(), req.Proxies) if err != nil { - if errors.Is(err, ErrTaskConflict) { - response.Error(c, errors.NewAPIError(errors.ErrTaskInProgress, err.Error())) } else { - response.Error(c, errors.NewAPIError(errors.ErrInternalServer, err.Error())) } return @@ -262,7 +256,7 @@ func (h *handler) CheckAllProxies(c *gin.Context) { concurrency := cfg.ProxyCheckConcurrency if concurrency <= 0 { - concurrency = 5 // 如果配置不合法,提供一个安全的默认值 + concurrency = 5 } results := h.manager.CheckMultipleProxies(req.Proxies, timeout, concurrency) response.Success(c, results) diff --git a/internal/domain/proxy/manager.go b/internal/domain/proxy/manager.go index 418c284..653f164 100644 --- a/internal/domain/proxy/manager.go +++ b/internal/domain/proxy/manager.go @@ -2,14 +2,13 @@ package proxy import ( + "context" "encoding/json" "fmt" "gemini-balancer/internal/models" "gemini-balancer/internal/store" "gemini-balancer/internal/syncer" "gemini-balancer/internal/task" - - "context" "net" "net/http" "net/url" @@ -25,7 +24,7 @@ import ( const ( TaskTypeProxySync = "proxy_sync" - proxyChunkSize = 200 // 代理同步的批量大小 + proxyChunkSize = 200 ) type ProxyCheckResult struct { @@ -35,13 +34,11 @@ type ProxyCheckResult struct { ErrorMessage string `json:"error_message"` } -// managerCacheData type managerCacheData struct { ActiveProxies 
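The proxy handler above publishes its change event from a goroutine, so it builds a fresh context.Background() instead of reusing the request context, which is cancelled as soon as the HTTP response is written. A minimal, self-contained sketch of that detachment pattern; the Publisher interface, topic name and timeout below are illustrative stand-ins, not the project's store API:

```go
package main

import (
	"context"
	"encoding/json"
	"fmt"
	"time"
)

// Publisher is a stand-in for the store's publish capability; the real
// interface in this codebase now takes a context as its first argument.
type Publisher interface {
	Publish(ctx context.Context, topic string, payload []byte) error
}

type logPublisher struct{}

func (logPublisher) Publish(ctx context.Context, topic string, payload []byte) error {
	fmt.Printf("published to %s: %s\n", topic, payload)
	return nil
}

type proxyEvent struct {
	ProxyID uint   `json:"proxy_id"`
	Action  string `json:"action"`
}

// publishAndInvalidate fires the event asynchronously. The goroutine must not
// inherit the request context: once the handler returns, that context is
// cancelled and the publish could fail spuriously.
func publishAndInvalidate(p Publisher, proxyID uint, action string) {
	go func() {
		// An explicit timeout is an illustrative bound; the key point is
		// detaching from the request lifetime with context.Background().
		ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
		defer cancel()
		data, _ := json.Marshal(proxyEvent{ProxyID: proxyID, Action: action})
		_ = p.Publish(ctx, "proxy.status.changed", data)
	}()
}

func main() {
	publishAndInvalidate(logPublisher{}, 42, "created")
	time.Sleep(100 * time.Millisecond) // let the goroutine finish in this demo
}
```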
[]*models.ProxyConfig ProxiesByID map[uint]*models.ProxyConfig } -// manager结构体 type manager struct { db *gorm.DB syncer *syncer.CacheSyncer[managerCacheData] @@ -80,21 +77,21 @@ func newManager(db *gorm.DB, syncer *syncer.CacheSyncer[managerCacheData], taskR } } -func (m *manager) SyncProxiesInBackground(proxyStrings []string) (*task.Status, error) { +func (m *manager) SyncProxiesInBackground(ctx context.Context, proxyStrings []string) (*task.Status, error) { resourceID := "global_proxy_sync" - taskStatus, err := m.task.StartTask(0, TaskTypeProxySync, resourceID, len(proxyStrings), 0) + taskStatus, err := m.task.StartTask(ctx, 0, TaskTypeProxySync, resourceID, len(proxyStrings), 0) if err != nil { return nil, ErrTaskConflict } - go m.runProxySyncTask(taskStatus.ID, proxyStrings) + go m.runProxySyncTask(context.Background(), taskStatus.ID, proxyStrings) return taskStatus, nil } -func (m *manager) runProxySyncTask(taskID string, finalProxyStrings []string) { +func (m *manager) runProxySyncTask(ctx context.Context, taskID string, finalProxyStrings []string) { resourceID := "global_proxy_sync" var allProxies []models.ProxyConfig if err := m.db.Find(&allProxies).Error; err != nil { - m.task.EndTaskByID(taskID, resourceID, nil, fmt.Errorf("failed to fetch current proxies: %w", err)) + m.task.EndTaskByID(ctx, taskID, resourceID, nil, fmt.Errorf("failed to fetch current proxies: %w", err)) return } currentProxyMap := make(map[string]uint) @@ -125,19 +122,19 @@ func (m *manager) runProxySyncTask(taskID string, finalProxyStrings []string) { } if len(idsToDelete) > 0 { if err := m.bulkDeleteByIDs(idsToDelete); err != nil { - m.task.EndTaskByID(taskID, resourceID, nil, fmt.Errorf("failed during proxy deletion: %w", err)) + m.task.EndTaskByID(ctx, taskID, resourceID, nil, fmt.Errorf("failed during proxy deletion: %w", err)) return } } if len(proxiesToAdd) > 0 { if err := m.bulkAdd(proxiesToAdd); err != nil { - m.task.EndTaskByID(taskID, resourceID, nil, fmt.Errorf("failed during proxy addition: %w", err)) + m.task.EndTaskByID(ctx, taskID, resourceID, nil, fmt.Errorf("failed during proxy addition: %w", err)) return } } result := gin.H{"added": len(proxiesToAdd), "deleted": len(idsToDelete), "final_total": len(finalProxyMap)} - m.task.EndTaskByID(taskID, resourceID, result, nil) - m.publishChangeEvent("proxies_synced") + m.task.EndTaskByID(ctx, taskID, resourceID, result, nil) + m.publishChangeEvent(ctx, "proxies_synced") go m.invalidate() } @@ -184,14 +181,15 @@ func (m *manager) bulkDeleteByIDs(ids []uint) error { } return nil } + func (m *manager) bulkAdd(proxies []models.ProxyConfig) error { return m.db.CreateInBatches(proxies, proxyChunkSize).Error } -func (m *manager) publishChangeEvent(reason string) { +func (m *manager) publishChangeEvent(ctx context.Context, reason string) { event := models.ProxyStatusChangedEvent{Action: reason} eventData, _ := json.Marshal(event) - _ = m.store.Publish(models.TopicProxyStatusChanged, eventData) + _ = m.store.Publish(ctx, models.TopicProxyStatusChanged, eventData) } func (m *manager) assignProxyIfNeeded(apiKey *models.APIKey) (*models.ProxyConfig, error) { diff --git a/internal/handlers/apikey_handler.go b/internal/handlers/apikey_handler.go index 878243c..60b8bf3 100644 --- a/internal/handlers/apikey_handler.go +++ b/internal/handlers/apikey_handler.go @@ -1,4 +1,4 @@ -// Filename: internal/handlers/apikey_handler.go +// Filename: internal/handlers/apikey_handler.go (最终决战版) package handlers import ( @@ -31,11 +31,10 @@ func NewAPIKeyHandler(apiKeyService 
*service.APIKeyService, db *gorm.DB, keyImpo } } -// DTOs for API requests type BulkAddKeysToGroupRequest struct { KeyGroupID uint `json:"key_group_id" binding:"required"` Keys string `json:"keys" binding:"required"` - ValidateOnImport bool `json:"validate_on_import"` // OmitEmpty/default is false + ValidateOnImport bool `json:"validate_on_import"` } type BulkUnlinkKeysFromGroupRequest struct { @@ -72,11 +71,11 @@ type BulkTestKeysForGroupRequest struct { } type BulkActionFilter struct { - Status []string `json:"status"` // Changed to slice to accept multiple statuses + Status []string `json:"status"` } type BulkActionRequest struct { Action string `json:"action" binding:"required,oneof=revalidate set_status delete"` - NewStatus string `json:"new_status" binding:"omitempty,oneof=active disabled cooldown banned"` // For 'set_status' action + NewStatus string `json:"new_status" binding:"omitempty,oneof=active disabled cooldown banned"` Filter BulkActionFilter `json:"filter" binding:"required"` } @@ -89,7 +88,8 @@ func (h *APIKeyHandler) AddMultipleKeysToGroup(c *gin.Context) { response.Error(c, errors.NewAPIError(errors.ErrInvalidJSON, err.Error())) return } - taskStatus, err := h.keyImportService.StartAddKeysTask(req.KeyGroupID, req.Keys, req.ValidateOnImport) + // [修正] 将请求的 context 传递给 service 层 + taskStatus, err := h.keyImportService.StartAddKeysTask(c.Request.Context(), req.KeyGroupID, req.Keys, req.ValidateOnImport) if err != nil { response.Error(c, errors.NewAPIError(errors.ErrTaskInProgress, err.Error())) return @@ -104,7 +104,8 @@ func (h *APIKeyHandler) UnlinkMultipleKeysFromGroup(c *gin.Context) { response.Error(c, errors.NewAPIError(errors.ErrInvalidJSON, err.Error())) return } - taskStatus, err := h.keyImportService.StartUnlinkKeysTask(req.KeyGroupID, req.Keys) + // [修正] 将请求的 context 传递给 service 层 + taskStatus, err := h.keyImportService.StartUnlinkKeysTask(c.Request.Context(), req.KeyGroupID, req.Keys) if err != nil { response.Error(c, errors.NewAPIError(errors.ErrTaskInProgress, err.Error())) return @@ -119,7 +120,8 @@ func (h *APIKeyHandler) HardDeleteMultipleKeys(c *gin.Context) { response.Error(c, errors.NewAPIError(errors.ErrInvalidJSON, err.Error())) return } - taskStatus, err := h.keyImportService.StartHardDeleteKeysTask(req.Keys) + // [修正] 将请求的 context 传递给 service 层 + taskStatus, err := h.keyImportService.StartHardDeleteKeysTask(c.Request.Context(), req.Keys) if err != nil { response.Error(c, errors.NewAPIError(errors.ErrTaskInProgress, err.Error())) return @@ -134,7 +136,8 @@ func (h *APIKeyHandler) RestoreMultipleKeys(c *gin.Context) { response.Error(c, errors.NewAPIError(errors.ErrInvalidJSON, err.Error())) return } - taskStatus, err := h.keyImportService.StartRestoreKeysTask(req.Keys) + // [修正] 将请求的 context 传递给 service 层 + taskStatus, err := h.keyImportService.StartRestoreKeysTask(c.Request.Context(), req.Keys) if err != nil { response.Error(c, errors.NewAPIError(errors.ErrTaskInProgress, err.Error())) return @@ -148,7 +151,8 @@ func (h *APIKeyHandler) TestMultipleKeys(c *gin.Context) { response.Error(c, errors.NewAPIError(errors.ErrInvalidJSON, err.Error())) return } - taskStatus, err := h.keyValidationService.StartTestKeysTask(req.KeyGroupID, req.Keys) + // [修正] 将请求的 context 传递给 service 层 + taskStatus, err := h.keyValidationService.StartTestKeysTask(c.Request.Context(), req.KeyGroupID, req.Keys) if err != nil { response.Error(c, errors.NewAPIError(errors.ErrTaskInProgress, err.Error())) return @@ -172,7 +176,7 @@ func (h *APIKeyHandler) ListAPIKeys(c *gin.Context) { 
} } if len(ids) > 0 { - keys, err := h.apiKeyService.GetKeysByIds(ids) + keys, err := h.apiKeyService.GetKeysByIds(c.Request.Context(), ids) if err != nil { response.Error(c, &errors.APIError{ HTTPStatus: http.StatusInternalServerError, @@ -191,7 +195,7 @@ func (h *APIKeyHandler) ListAPIKeys(c *gin.Context) { if params.PageSize <= 0 { params.PageSize = 20 } - result, err := h.apiKeyService.ListAPIKeys(¶ms) + result, err := h.apiKeyService.ListAPIKeys(c.Request.Context(), ¶ms) if err != nil { response.Error(c, errors.ParseDBError(err)) return @@ -201,19 +205,16 @@ func (h *APIKeyHandler) ListAPIKeys(c *gin.Context) { // ListKeysForGroup handles the GET /keygroups/:id/keys request. func (h *APIKeyHandler) ListKeysForGroup(c *gin.Context) { - // 1. Manually handle the path parameter. groupID, err := strconv.ParseUint(c.Param("id"), 10, 32) if err != nil { response.Error(c, errors.NewAPIError(errors.ErrBadRequest, "Invalid group ID format")) return } - // 2. Bind query parameters using the correctly tagged struct. var params models.APIKeyQueryParams if err := c.ShouldBindQuery(¶ms); err != nil { response.Error(c, errors.NewAPIError(errors.ErrBadRequest, err.Error())) return } - // 3. Set server-side defaults and the path parameter. if params.Page <= 0 { params.Page = 1 } @@ -221,15 +222,11 @@ func (h *APIKeyHandler) ListKeysForGroup(c *gin.Context) { params.PageSize = 20 } params.KeyGroupID = uint(groupID) - // 4. Call the service layer. - paginatedResult, err := h.apiKeyService.ListAPIKeys(¶ms) + paginatedResult, err := h.apiKeyService.ListAPIKeys(c.Request.Context(), ¶ms) if err != nil { response.Error(c, errors.ParseDBError(err)) return } - - // 5. [THE FIX] Return a successful response using the standard `response.Success` - // and a gin.H map, as confirmed to exist in your project. response.Success(c, gin.H{ "items": paginatedResult.Items, "total": paginatedResult.Total, @@ -239,20 +236,18 @@ func (h *APIKeyHandler) ListKeysForGroup(c *gin.Context) { } func (h *APIKeyHandler) TestKeysForGroup(c *gin.Context) { - // Group ID is now correctly sourced from the URL path. groupID, err := strconv.ParseUint(c.Param("id"), 10, 32) if err != nil { response.Error(c, errors.NewAPIError(errors.ErrBadRequest, "Invalid group ID format")) return } - // The request body is now simpler, only needing the keys. var req BulkTestKeysForGroupRequest if err := c.ShouldBindJSON(&req); err != nil { response.Error(c, errors.NewAPIError(errors.ErrInvalidJSON, err.Error())) return } - // Call the same underlying service, but with unambiguous context. - taskStatus, err := h.keyValidationService.StartTestKeysTask(uint(groupID), req.Keys) + // [修正] 将请求的 context 传递给 service 层 + taskStatus, err := h.keyValidationService.StartTestKeysTask(c.Request.Context(), uint(groupID), req.Keys) if err != nil { response.Error(c, errors.NewAPIError(errors.ErrTaskInProgress, err.Error())) return @@ -267,7 +262,6 @@ func (h *APIKeyHandler) UpdateAPIKey(c *gin.Context) { } // UpdateGroupAPIKeyMapping handles updating a key's status within a specific group. 
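The API key handlers above now hand c.Request.Context() to the service layer so in-flight work can stop when the client disconnects or the request is cancelled. A minimal sketch of that wiring with gin; the keyService type and its ListKeys method are hypothetical, not the project's real service:

```go
package main

import (
	"context"
	"net/http"
	"strconv"
	"time"

	"github.com/gin-gonic/gin"
)

type keyService struct{}

// ListKeys stands in for a service call that respects cancellation: it stops
// waiting as soon as the request context is done (client gone, timeout, ...).
func (s *keyService) ListKeys(ctx context.Context, groupID uint) ([]string, error) {
	select {
	case <-time.After(50 * time.Millisecond): // stand-in for a real DB query
		return []string{"key-1", "key-2"}, nil
	case <-ctx.Done():
		return nil, ctx.Err()
	}
}

func main() {
	svc := &keyService{}
	r := gin.Default()
	r.GET("/keygroups/:id/keys", func(c *gin.Context) {
		groupID, err := strconv.ParseUint(c.Param("id"), 10, 32)
		if err != nil {
			c.JSON(http.StatusBadRequest, gin.H{"error": "invalid group id"})
			return
		}
		// c.Request.Context() is cancelled when the client disconnects, so the
		// service (and any DB or cache work it does) can stop early.
		keys, err := svc.ListKeys(c.Request.Context(), uint(groupID))
		if err != nil {
			c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
			return
		}
		c.JSON(http.StatusOK, gin.H{"items": keys, "total": len(keys)})
	})
	_ = r.Run(":8080")
}
```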
-// Route: PUT /keygroups/:id/apikeys/:keyId func (h *APIKeyHandler) UpdateGroupAPIKeyMapping(c *gin.Context) { groupID, err := strconv.ParseUint(c.Param("id"), 10, 32) if err != nil { @@ -284,8 +278,7 @@ func (h *APIKeyHandler) UpdateGroupAPIKeyMapping(c *gin.Context) { response.Error(c, errors.NewAPIError(errors.ErrInvalidJSON, err.Error())) return } - // Directly use the service to handle the logic - updatedMapping, err := h.apiKeyService.UpdateMappingStatus(uint(groupID), uint(keyID), req.Status) + updatedMapping, err := h.apiKeyService.UpdateMappingStatus(c.Request.Context(), uint(groupID), uint(keyID), req.Status) if err != nil { var apiErr *errors.APIError if errors.As(err, &apiErr) { @@ -305,7 +298,7 @@ func (h *APIKeyHandler) HardDeleteAPIKey(c *gin.Context) { response.Error(c, errors.NewAPIError(errors.ErrBadRequest, "Invalid ID format")) return } - if err := h.apiKeyService.HardDeleteAPIKeyByID(uint(id)); err != nil { + if err := h.apiKeyService.HardDeleteAPIKeyByID(c.Request.Context(), uint(id)); err != nil { response.Error(c, errors.ParseDBError(err)) return } @@ -313,7 +306,6 @@ func (h *APIKeyHandler) HardDeleteAPIKey(c *gin.Context) { } // RestoreKeysInGroup 恢复指定Key的接口 -// POST /keygroups/:id/apikeys/restore func (h *APIKeyHandler) RestoreKeysInGroup(c *gin.Context) { groupID, err := strconv.ParseUint(c.Param("id"), 10, 32) if err != nil { @@ -325,7 +317,7 @@ func (h *APIKeyHandler) RestoreKeysInGroup(c *gin.Context) { response.Error(c, errors.NewAPIError(errors.ErrInvalidJSON, err.Error())) return } - taskStatus, err := h.apiKeyService.StartRestoreKeysTask(uint(groupID), req.KeyIDs) + taskStatus, err := h.apiKeyService.StartRestoreKeysTask(c.Request.Context(), uint(groupID), req.KeyIDs) if err != nil { var apiErr *errors.APIError if errors.As(err, &apiErr) { @@ -339,14 +331,13 @@ func (h *APIKeyHandler) RestoreKeysInGroup(c *gin.Context) { } // RestoreAllBannedInGroup 一键恢复所有Banned Key的接口 -// POST /keygroups/:id/apikeys/restore-all-banned func (h *APIKeyHandler) RestoreAllBannedInGroup(c *gin.Context) { groupID, err := strconv.ParseUint(c.Param("groupId"), 10, 32) if err != nil { response.Error(c, errors.NewAPIError(errors.ErrBadRequest, "Invalid Group ID format")) return } - taskStatus, err := h.apiKeyService.StartRestoreAllBannedTask(uint(groupID)) + taskStatus, err := h.apiKeyService.StartRestoreAllBannedTask(c.Request.Context(), uint(groupID)) if err != nil { var apiErr *errors.APIError if errors.As(err, &apiErr) { @@ -360,48 +351,41 @@ func (h *APIKeyHandler) RestoreAllBannedInGroup(c *gin.Context) { } // HandleBulkAction handles generic bulk actions on a key group based on server-side filters. -// Route: POST /keygroups/:id/bulk-actions func (h *APIKeyHandler) HandleBulkAction(c *gin.Context) { - // 1. Parse GroupID from URL groupID, err := strconv.ParseUint(c.Param("id"), 10, 32) if err != nil { response.Error(c, errors.NewAPIError(errors.ErrBadRequest, "Invalid Group ID format")) return } - // 2. Bind the JSON payload to our new DTO var req BulkActionRequest if err := c.ShouldBindJSON(&req); err != nil { response.Error(c, errors.NewAPIError(errors.ErrInvalidJSON, err.Error())) return } - // 3. Central logic: based on the action, call the appropriate service method. 
var task *task.Status var apiErr *errors.APIError switch req.Action { case "revalidate": - // Assume keyValidationService has a method that accepts a filter - task, err = h.keyValidationService.StartTestKeysByFilterTask(uint(groupID), req.Filter.Status) + // [Fix] Pass the request context to the service layer + task, err = h.keyValidationService.StartTestKeysByFilterTask(c.Request.Context(), uint(groupID), req.Filter.Status) case "set_status": if req.NewStatus == "" { apiErr = errors.NewAPIError(errors.ErrBadRequest, "new_status is required for set_status action") break } - // Assume apiKeyService has a method to update status by filter - targetStatus := models.APIKeyStatus(req.NewStatus) // Convert string to your model's type - task, err = h.apiKeyService.StartUpdateStatusByFilterTask(uint(groupID), req.Filter.Status, targetStatus) + targetStatus := models.APIKeyStatus(req.NewStatus) + task, err = h.apiKeyService.StartUpdateStatusByFilterTask(c.Request.Context(), uint(groupID), req.Filter.Status, targetStatus) case "delete": - // Assume keyImportService has a method to unlink by filter - task, err = h.keyImportService.StartUnlinkKeysByFilterTask(uint(groupID), req.Filter.Status) + // [Fix] Pass the request context to the service layer + task, err = h.keyImportService.StartUnlinkKeysByFilterTask(c.Request.Context(), uint(groupID), req.Filter.Status) default: apiErr = errors.NewAPIError(errors.ErrBadRequest, "Unsupported action: "+req.Action) } - // 4. Handle errors from the switch block if apiErr != nil { response.Error(c, apiErr) return } if err != nil { - // Attempt to parse it as a known APIError, otherwise, wrap it. var parsedErr *errors.APIError if errors.As(err, &parsedErr) { response.Error(c, parsedErr) } else { @@ -410,21 +394,18 @@ func (h *APIKeyHandler) HandleBulkAction(c *gin.Context) { } return } - // 5. Return the task status on success response.Success(c, task) } // ExportKeysForGroup handles requests to export all keys for a group based on status filters.
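Several handlers in this file unwrap service errors with errors.As and return the typed APIError when one is present, falling back to a generic error otherwise. A small self-contained sketch of that idiom, using a simplified APIError rather than the project's own type:

```go
package main

import (
	"errors"
	"fmt"
	"net/http"
)

// APIError is a simplified stand-in for the project's errors.APIError.
type APIError struct {
	HTTPStatus int
	Code       string
	Message    string
}

func (e *APIError) Error() string { return fmt.Sprintf("%s: %s", e.Code, e.Message) }

// respond picks the typed error if one is wrapped anywhere in the chain,
// otherwise falls back to a generic 500.
func respond(err error) (int, string) {
	var apiErr *APIError
	if errors.As(err, &apiErr) {
		return apiErr.HTTPStatus, apiErr.Message
	}
	return http.StatusInternalServerError, "internal server error"
}

func main() {
	typed := fmt.Errorf("service failed: %w",
		&APIError{HTTPStatus: 409, Code: "TASK_IN_PROGRESS", Message: "a task is already running"})
	fmt.Println(respond(typed))              // 409 a task is already running
	fmt.Println(respond(errors.New("boom"))) // 500 internal server error
}
```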
-// Route: GET /keygroups/:id/apikeys/export func (h *APIKeyHandler) ExportKeysForGroup(c *gin.Context) { groupID, err := strconv.ParseUint(c.Param("id"), 10, 32) if err != nil { response.Error(c, errors.NewAPIError(errors.ErrBadRequest, "Invalid Group ID format")) return } - // Use QueryArray to correctly parse `status[]=active&status[]=cooldown` statuses := c.QueryArray("status") - keyStrings, err := h.apiKeyService.GetAPIKeyStringsForExport(uint(groupID), statuses) + keyStrings, err := h.apiKeyService.GetAPIKeyStringsForExport(c.Request.Context(), uint(groupID), statuses) if err != nil { response.Error(c, errors.ParseDBError(err)) return diff --git a/internal/handlers/dashboard_handler.go b/internal/handlers/dashboard_handler.go index f989290..768c47e 100644 --- a/internal/handlers/dashboard_handler.go +++ b/internal/handlers/dashboard_handler.go @@ -30,7 +30,7 @@ func (h *DashboardHandler) GetOverview(c *gin.Context) { c.JSON(http.StatusOK, stats) } -// GetChart 获取仪表盘的图表数据 +// GetChart func (h *DashboardHandler) GetChart(c *gin.Context) { var groupID *uint if groupIDStr := c.Query("groupId"); groupIDStr != "" { @@ -40,7 +40,7 @@ func (h *DashboardHandler) GetChart(c *gin.Context) { } } - chartData, err := h.queryService.QueryHistoricalChart(groupID) + chartData, err := h.queryService.QueryHistoricalChart(c.Request.Context(), groupID) if err != nil { apiErr := errors.NewAPIError(errors.ErrDatabase, err.Error()) c.JSON(apiErr.HTTPStatus, apiErr) @@ -49,10 +49,10 @@ func (h *DashboardHandler) GetChart(c *gin.Context) { c.JSON(http.StatusOK, chartData) } -// GetRequestStats 处理对“期间调用概览”的请求 +// GetRequestStats func (h *DashboardHandler) GetRequestStats(c *gin.Context) { - period := c.Param("period") // 从 URL 路径中获取 period - stats, err := h.queryService.GetRequestStatsForPeriod(period) + period := c.Param("period") + stats, err := h.queryService.GetRequestStatsForPeriod(c.Request.Context(), period) if err != nil { apiErr := errors.NewAPIError(errors.ErrBadRequest, err.Error()) c.JSON(apiErr.HTTPStatus, apiErr) diff --git a/internal/handlers/keygroup_handler.go b/internal/handlers/keygroup_handler.go index 20a3107..d53d67f 100644 --- a/internal/handlers/keygroup_handler.go +++ b/internal/handlers/keygroup_handler.go @@ -2,6 +2,7 @@ package handlers import ( + "context" "encoding/json" "fmt" "gemini-balancer/internal/errors" @@ -31,7 +32,6 @@ func NewKeyGroupHandler(gm *service.GroupManager, s store.Store, qs *service.Das } } -// DTOs & 辅助函数 func isValidGroupName(name string) bool { if name == "" { return false @@ -40,7 +40,6 @@ func isValidGroupName(name string) bool { return match } -// KeyGroupOperationalSettings defines the shared operational settings for a key group. 
type KeyGroupOperationalSettings struct { EnableKeyCheck *bool `json:"enable_key_check"` KeyCheckIntervalMinutes *int `json:"key_check_interval_minutes"` @@ -52,7 +51,6 @@ type KeyGroupOperationalSettings struct { MaxRetries *int `json:"max_retries"` EnableSmartGateway *bool `json:"enable_smart_gateway"` } - type CreateKeyGroupRequest struct { Name string `json:"name" binding:"required"` DisplayName string `json:"display_name"` @@ -60,11 +58,8 @@ type CreateKeyGroupRequest struct { PollingStrategy string `json:"polling_strategy" binding:"omitempty,oneof=sequential random weighted"` EnableProxy bool `json:"enable_proxy"` ChannelType string `json:"channel_type"` - - // Embed shared operational settings KeyGroupOperationalSettings } - type UpdateKeyGroupRequest struct { Name *string `json:"name"` DisplayName *string `json:"display_name"` @@ -72,15 +67,10 @@ type UpdateKeyGroupRequest struct { PollingStrategy *string `json:"polling_strategy" binding:"omitempty,oneof=sequential random weighted"` EnableProxy *bool `json:"enable_proxy"` ChannelType *string `json:"channel_type"` - - // Embed shared operational settings KeyGroupOperationalSettings - - // M:N associations AllowedUpstreams []string `json:"allowed_upstreams"` AllowedModels []string `json:"allowed_models"` } - type KeyGroupResponse struct { ID uint `json:"id"` Name string `json:"name"` @@ -96,36 +86,30 @@ type KeyGroupResponse struct { AllowedModels []string `json:"allowed_models"` AllowedUpstreams []string `json:"allowed_upstreams"` } - -// [NEW] Define the detailed response structure for a single group. type KeyGroupDetailsResponse struct { KeyGroupResponse Settings *models.GroupSettings `json:"settings,omitempty"` RequestConfig *models.RequestConfig `json:"request_config,omitempty"` } -// transformModelsToStrings converts a slice of GroupModelMapping pointers to a slice of model names. func transformModelsToStrings(mappings []*models.GroupModelMapping) []string { modelNames := make([]string, 0, len(mappings)) for _, mapping := range mappings { - if mapping != nil { // Safety check + if mapping != nil { modelNames = append(modelNames, mapping.ModelName) } } return modelNames } - -// transformUpstreamsToStrings converts a slice of UpstreamEndpoint pointers to a slice of URLs. func transformUpstreamsToStrings(upstreams []*models.UpstreamEndpoint) []string { urls := make([]string, 0, len(upstreams)) for _, upstream := range upstreams { - if upstream != nil { // Safety check + if upstream != nil { urls = append(urls, upstream.URL) } } return urls } - func (h *KeyGroupHandler) newKeyGroupResponse(group *models.KeyGroup, keyCount int64) KeyGroupResponse { return KeyGroupResponse{ ID: group.ID, @@ -139,13 +123,10 @@ func (h *KeyGroupHandler) newKeyGroupResponse(group *models.KeyGroup, keyCount i CreatedAt: group.CreatedAt, UpdatedAt: group.UpdatedAt, Order: group.Order, - AllowedModels: transformModelsToStrings(group.AllowedModels), // Call the new helper - AllowedUpstreams: transformUpstreamsToStrings(group.AllowedUpstreams), // Call the new helper + AllowedModels: transformModelsToStrings(group.AllowedModels), + AllowedUpstreams: transformUpstreamsToStrings(group.AllowedUpstreams), } } - -// packGroupSettings is a helper to convert request-level operational settings -// into the model-level settings struct. 
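The operational settings above use pointer fields so the API can tell an omitted field (nil) apart from an explicit zero value, and so a group-level value can fall back to a global default when unset, as the MaxRetries handling later in this patch does. A brief sketch of that resolution logic; the struct and field names here are illustrative:

```go
package main

import "fmt"

// GroupOverrides mirrors the pointer-field style of the request DTOs above:
// nil means "not provided", so a zero value can still be set explicitly.
// The real structs carry many such fields; one is enough to show the idea.
type GroupOverrides struct {
	MaxRetries *int
}

type GlobalSettings struct {
	MaxRetries int
}

// effectiveMaxRetries prefers the group override when present, otherwise
// falls back to the global default.
func effectiveMaxRetries(g *GroupOverrides, global GlobalSettings) int {
	if g != nil && g.MaxRetries != nil {
		return *g.MaxRetries
	}
	return global.MaxRetries
}

func main() {
	global := GlobalSettings{MaxRetries: 3}

	zero := 0
	explicit := GroupOverrides{MaxRetries: &zero} // explicitly 0, not omitted

	fmt.Println(effectiveMaxRetries(nil, global))               // 3: no overrides at all
	fmt.Println(effectiveMaxRetries(&GroupOverrides{}, global)) // 3: field omitted
	fmt.Println(effectiveMaxRetries(&explicit, global))         // 0: explicit zero wins
}
```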
func packGroupSettings(settings KeyGroupOperationalSettings) *models.KeyGroupSettings { return &models.KeyGroupSettings{ EnableKeyCheck: settings.EnableKeyCheck, @@ -159,7 +140,6 @@ func packGroupSettings(settings KeyGroupOperationalSettings) *models.KeyGroupSet EnableSmartGateway: settings.EnableSmartGateway, } } - func (h *KeyGroupHandler) getGroupFromContext(c *gin.Context) (*models.KeyGroup, *errors.APIError) { id, err := strconv.Atoi(c.Param("id")) if err != nil { @@ -171,7 +151,6 @@ func (h *KeyGroupHandler) getGroupFromContext(c *gin.Context) (*models.KeyGroup, } return group, nil } - func applyUpdateRequestToGroup(req *UpdateKeyGroupRequest, group *models.KeyGroup) { if req.Name != nil { group.Name = *req.Name @@ -197,9 +176,10 @@ func applyUpdateRequestToGroup(req *UpdateKeyGroupRequest, group *models.KeyGrou // publishGroupChangeEvent encapsulates the logic for marshaling and publishing a group change event. func (h *KeyGroupHandler) publishGroupChangeEvent(groupID uint, reason string) { go func() { + ctx := context.Background() event := models.KeyStatusChangedEvent{GroupID: groupID, ChangeReason: reason} eventData, _ := json.Marshal(event) - h.store.Publish(models.TopicKeyStatusChanged, eventData) + _ = h.store.Publish(ctx, models.TopicKeyStatusChanged, eventData) }() } @@ -216,7 +196,6 @@ func (h *KeyGroupHandler) CreateKeyGroup(c *gin.Context) { return } - // The core logic remains, as it's specific to creation. p := bluemonday.StripTagsPolicy() sanitizedDisplayName := p.Sanitize(req.DisplayName) sanitizedDescription := p.Sanitize(req.Description) @@ -244,11 +223,9 @@ func (h *KeyGroupHandler) CreateKeyGroup(c *gin.Context) { response.Created(c, h.newKeyGroupResponse(keyGroup, 0)) } -// 统一的处理器可以处理两种情况: // 1. GET /keygroups - 返回所有组的列表 // 2. 
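CreateKeyGroup above strips HTML from user-supplied display fields with bluemonday's StripTagsPolicy before persisting them, removing markup rather than escaping it. A minimal usage sketch:

```go
package main

import (
	"fmt"

	"github.com/microcosm-cc/bluemonday"
)

func main() {
	// StripTagsPolicy removes every HTML element and attribute, keeping only text.
	p := bluemonday.StripTagsPolicy()

	dirty := "<b>Primary</b> <i>group</i>"
	fmt.Println(p.Sanitize(dirty)) // Primary group
}
```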
GET /keygroups/:id - 返回指定ID的单个组 func (h *KeyGroupHandler) GetKeyGroups(c *gin.Context) { - // Case 1: Get a single group if idStr := c.Param("id"); idStr != "" { group, apiErr := h.getGroupFromContext(c) if apiErr != nil { @@ -265,7 +242,6 @@ func (h *KeyGroupHandler) GetKeyGroups(c *gin.Context) { response.Success(c, detailedResponse) return } - // Case 2: Get all groups allGroups := h.groupManager.GetAllGroups() responses := make([]KeyGroupResponse, 0, len(allGroups)) for _, group := range allGroups { @@ -275,7 +251,6 @@ func (h *KeyGroupHandler) GetKeyGroups(c *gin.Context) { response.Success(c, responses) } -// UpdateKeyGroup func (h *KeyGroupHandler) UpdateKeyGroup(c *gin.Context) { group, apiErr := h.getGroupFromContext(c) if apiErr != nil { @@ -304,7 +279,6 @@ func (h *KeyGroupHandler) UpdateKeyGroup(c *gin.Context) { response.Success(c, h.newKeyGroupResponse(freshGroup, keyCount)) } -// DeleteKeyGroup func (h *KeyGroupHandler) DeleteKeyGroup(c *gin.Context) { group, apiErr := h.getGroupFromContext(c) if apiErr != nil { @@ -320,14 +294,14 @@ func (h *KeyGroupHandler) DeleteKeyGroup(c *gin.Context) { response.Success(c, gin.H{"message": fmt.Sprintf("Group '%s' and its associated keys deleted successfully", groupName)}) } -// GetKeyGroupStats func (h *KeyGroupHandler) GetKeyGroupStats(c *gin.Context) { group, apiErr := h.getGroupFromContext(c) if apiErr != nil { response.Error(c, apiErr) return } - stats, err := h.queryService.GetGroupStats(group.ID) + + stats, err := h.queryService.GetGroupStats(c.Request.Context(), group.ID) if err != nil { response.Error(c, errors.NewAPIError(errors.ErrDatabase, err.Error())) return @@ -350,7 +324,6 @@ func (h *KeyGroupHandler) CloneKeyGroup(c *gin.Context) { response.Created(c, h.newKeyGroupResponse(clonedGroup, keyCount)) } -// 更新分组排序 func (h *KeyGroupHandler) UpdateKeyGroupOrder(c *gin.Context) { var payload []service.UpdateOrderPayload if err := c.ShouldBindJSON(&payload); err != nil { diff --git a/internal/handlers/proxy_handler.go b/internal/handlers/proxy_handler.go index 0c4a2fc..853ae8c 100644 --- a/internal/handlers/proxy_handler.go +++ b/internal/handlers/proxy_handler.go @@ -136,7 +136,6 @@ func (h *ProxyHandler) serveTransparentProxy(c *gin.Context, requestBody []byte, var finalPromptTokens, finalCompletionTokens int var actualRetries int = 0 defer func() { - // 如果一次尝试都未成功(例如,在第一次获取资源时就失败),则不记录日志 if lastUsedResources == nil { h.logger.WithField("id", correlationID).Warn("No resources were used, skipping final log event.") return @@ -151,44 +150,38 @@ func (h *ProxyHandler) serveTransparentProxy(c *gin.Context, requestBody []byte, finalEvent.RequestLog.CompletionTokens = finalCompletionTokens } - // 确保即使在成功的情况下,如果recorder存在,也记录最终的状态码 if finalRecorder != nil { finalEvent.RequestLog.StatusCode = finalRecorder.Code } if !isSuccess { - // 将 finalProxyErr 的信息填充到 RequestLog 中 if finalProxyErr != nil { - finalEvent.Error = finalProxyErr // Error 字段用于事件传递,不会被序列化到数据库 + finalEvent.Error = finalProxyErr finalEvent.RequestLog.ErrorCode = finalProxyErr.Code finalEvent.RequestLog.ErrorMessage = finalProxyErr.Message } else if finalRecorder != nil { - // 降级处理:如果 finalProxyErr 为空但 recorder 存在且失败 apiErr := errors.NewAPIErrorWithUpstream(finalRecorder.Code, fmt.Sprintf("UPSTREAM_%d", finalRecorder.Code), "Request failed after all retries.") finalEvent.Error = apiErr finalEvent.RequestLog.ErrorCode = apiErr.Code finalEvent.RequestLog.ErrorMessage = apiErr.Message } } - // 将完整的事件发布 eventData, err := json.Marshal(finalEvent) if err != nil { 
h.logger.WithField("id", correlationID).WithError(err).Error("Failed to marshal final log event.") return } - if err := h.store.Publish(models.TopicRequestFinished, eventData); err != nil { + if err := h.store.Publish(context.Background(), models.TopicRequestFinished, eventData); err != nil { h.logger.WithField("id", correlationID).WithError(err).Error("Failed to publish final log event.") } }() var maxRetries int if isPreciseRouting { - // For precise routing, use the group's setting. If not set, fall back to the global setting. if finalOpConfig.MaxRetries != nil { maxRetries = *finalOpConfig.MaxRetries } else { maxRetries = h.settingsManager.GetSettings().MaxRetries } } else { - // For BasePool (intelligent aggregation), *always* use the global setting. maxRetries = h.settingsManager.GetSettings().MaxRetries } totalAttempts := maxRetries + 1 @@ -332,7 +325,7 @@ func (h *ProxyHandler) serveTransparentProxy(c *gin.Context, requestBody []byte, retryEvent.ErrorMessage = attemptErr.Message } eventData, _ := json.Marshal(retryEvent) - _ = h.store.Publish(models.TopicRequestFinished, eventData) + _ = h.store.Publish(context.Background(), models.TopicRequestFinished, eventData) } if finalRecorder != nil { bodyBytes := finalRecorder.Body.Bytes() @@ -368,7 +361,7 @@ func (h *ProxyHandler) serveSmartStream(c *gin.Context, requestBody []byte, reso requestFinishedEvent.StatusCode = c.Writer.Status() } eventData, _ := json.Marshal(requestFinishedEvent) - _ = h.store.Publish(models.TopicRequestFinished, eventData) + _ = h.store.Publish(context.Background(), models.TopicRequestFinished, eventData) }() params := channel.SmartRequestParams{ CorrelationID: correlationID, @@ -435,7 +428,6 @@ func (h *ProxyHandler) createLogEvent(c *gin.Context, startTime time.Time, corrI } } if res != nil { - // [核心修正] 填充到内嵌的 RequestLog 结构体中 if res.APIKey != nil { event.RequestLog.KeyID = &res.APIKey.ID } @@ -444,7 +436,6 @@ func (h *ProxyHandler) createLogEvent(c *gin.Context, startTime time.Time, corrI } if res.UpstreamEndpoint != nil { event.RequestLog.UpstreamID = &res.UpstreamEndpoint.ID - // UpstreamURL 是事件传递字段,不是数据库字段,所以在这里赋值是正确的 event.UpstreamURL = &res.UpstreamEndpoint.URL } if res.ProxyConfig != nil { @@ -464,9 +455,9 @@ func (h *ProxyHandler) getResourcesForRequest(c *gin.Context, modelName string, return nil, errors.NewAPIError(errors.ErrInternalServer, "Invalid auth token type in context") } if isPreciseRouting { - return h.resourceService.GetResourceFromGroup(authToken, groupName) + return h.resourceService.GetResourceFromGroup(c.Request.Context(), authToken, groupName) } else { - return h.resourceService.GetResourceFromBasePool(authToken, modelName) + return h.resourceService.GetResourceFromBasePool(c.Request.Context(), authToken, modelName) } } diff --git a/internal/handlers/task_handler.go b/internal/handlers/task_handler.go index 71c36d1..0f2d665 100644 --- a/internal/handlers/task_handler.go +++ b/internal/handlers/task_handler.go @@ -33,7 +33,7 @@ func (h *TaskHandler) GetTaskStatus(c *gin.Context) { return } - taskStatus, err := h.taskService.GetStatus(taskID) + taskStatus, err := h.taskService.GetStatus(c.Request.Context(), taskID) if err != nil { // TODO 可以根据 service 层返回的具体错误类型进行更精细的处理 response.Error(c, errors.NewAPIError(errors.ErrResourceNotFound, err.Error())) diff --git a/internal/repository/key_cache.go b/internal/repository/key_cache.go index a05f10c..132e789 100644 --- a/internal/repository/key_cache.go +++ b/internal/repository/key_cache.go @@ -2,6 +2,7 @@ package repository import ( + 
"context" "encoding/json" "fmt" "gemini-balancer/internal/models" @@ -22,7 +23,7 @@ const ( BasePoolRandomCooldown = "basepool:%s:keys:random:cooldown" ) -func (r *gormKeyRepository) LoadAllKeysToStore() error { +func (r *gormKeyRepository) LoadAllKeysToStore(ctx context.Context) error { r.logger.Info("Starting to load all keys and associations into cache, including polling structures...") var allMappings []*models.GroupAPIKeyMapping if err := r.db.Preload("APIKey").Find(&allMappings).Error; err != nil { @@ -48,7 +49,7 @@ func (r *gormKeyRepository) LoadAllKeysToStore() error { } activeKeysByGroup := make(map[uint][]*models.GroupAPIKeyMapping) - pipe := r.store.Pipeline() + pipe := r.store.Pipeline(context.Background()) detailsToSet := make(map[string][]byte) var allGroups []*models.KeyGroup if err := r.db.Find(&allGroups).Error; err == nil { @@ -100,14 +101,14 @@ func (r *gormKeyRepository) LoadAllKeysToStore() error { pipe.SAdd(fmt.Sprintf(KeyGroup, groupID), activeKeyIDs...) pipe.LPush(fmt.Sprintf(KeyGroupSequential, groupID), activeKeyIDs...) pipe.SAdd(fmt.Sprintf(KeyGroupRandomMain, groupID), activeKeyIDs...) - go r.store.ZAdd(fmt.Sprintf(KeyGroupLRU, groupID), lruMembers) + go r.store.ZAdd(context.Background(), fmt.Sprintf(KeyGroupLRU, groupID), lruMembers) } if err := pipe.Exec(); err != nil { return fmt.Errorf("failed to execute pipeline for cache rebuild: %w", err) } for key, value := range detailsToSet { - if err := r.store.Set(key, value, 0); err != nil { + if err := r.store.Set(context.Background(), key, value, 0); err != nil { r.logger.WithError(err).Warnf("Failed to set key detail in cache: %s", key) } } @@ -124,16 +125,16 @@ func (r *gormKeyRepository) updateStoreCacheForKey(key *models.APIKey) error { if err != nil { return fmt.Errorf("failed to marshal key %d for cache update: %w", key.ID, err) } - return r.store.Set(fmt.Sprintf(KeyDetails, key.ID), keyJSON, 0) + return r.store.Set(context.Background(), fmt.Sprintf(KeyDetails, key.ID), keyJSON, 0) } -func (r *gormKeyRepository) removeStoreCacheForKey(key *models.APIKey) error { - groupIDs, err := r.GetGroupsForKey(key.ID) +func (r *gormKeyRepository) removeStoreCacheForKey(ctx context.Context, key *models.APIKey) error { + groupIDs, err := r.GetGroupsForKey(ctx, key.ID) if err != nil { r.logger.Warnf("failed to get groups for key %d to clean up cache lists: %v", key.ID, err) } - pipe := r.store.Pipeline() + pipe := r.store.Pipeline(ctx) pipe.Del(fmt.Sprintf(KeyDetails, key.ID)) for _, groupID := range groupIDs { @@ -144,13 +145,13 @@ func (r *gormKeyRepository) removeStoreCacheForKey(key *models.APIKey) error { pipe.LRem(fmt.Sprintf(KeyGroupSequential, groupID), 0, keyIDStr) pipe.SRem(fmt.Sprintf(KeyGroupRandomMain, groupID), keyIDStr) pipe.SRem(fmt.Sprintf(KeyGroupRandomCooldown, groupID), keyIDStr) - go r.store.ZRem(fmt.Sprintf(KeyGroupLRU, groupID), keyIDStr) + go r.store.ZRem(context.Background(), fmt.Sprintf(KeyGroupLRU, groupID), keyIDStr) } return pipe.Exec() } func (r *gormKeyRepository) updateStoreCacheForMapping(mapping *models.GroupAPIKeyMapping) error { - pipe := r.store.Pipeline() + pipe := r.store.Pipeline(context.Background()) activeKeyListKey := fmt.Sprintf("group:%d:keys:active", mapping.KeyGroupID) pipe.LRem(activeKeyListKey, 0, mapping.APIKeyID) if mapping.Status == models.StatusActive { @@ -159,7 +160,7 @@ func (r *gormKeyRepository) updateStoreCacheForMapping(mapping *models.GroupAPIK return pipe.Exec() } -func (r *gormKeyRepository) HandleCacheUpdateEventBatch(mappings []*models.GroupAPIKeyMapping) 
error { +func (r *gormKeyRepository) HandleCacheUpdateEventBatch(ctx context.Context, mappings []*models.GroupAPIKeyMapping) error { if len(mappings) == 0 { return nil } @@ -184,7 +185,7 @@ func (r *gormKeyRepository) HandleCacheUpdateEventBatch(mappings []*models.Group } groupUpdates[mapping.KeyGroupID] = update } - pipe := r.store.Pipeline() + pipe := r.store.Pipeline(context.Background()) var pipelineError error for groupID, updates := range groupUpdates { activeKeyListKey := fmt.Sprintf("group:%d:keys:active", groupID) diff --git a/internal/repository/key_crud.go b/internal/repository/key_crud.go index 00df938..1a9b6d5 100644 --- a/internal/repository/key_crud.go +++ b/internal/repository/key_crud.go @@ -7,6 +7,7 @@ import ( "fmt" "gemini-balancer/internal/models" + "context" "math/rand" "strings" "time" @@ -115,7 +116,7 @@ func (r *gormKeyRepository) Update(key *models.APIKey) error { } func (r *gormKeyRepository) HardDeleteByID(id uint) error { - key, err := r.GetKeyByID(id) // This now returns a decrypted key + key, err := r.GetKeyByID(id) if err != nil { return err } @@ -125,7 +126,7 @@ func (r *gormKeyRepository) HardDeleteByID(id uint) error { if err != nil { return err } - if err := r.removeStoreCacheForKey(key); err != nil { + if err := r.removeStoreCacheForKey(context.Background(), key); err != nil { r.logger.Warnf("DB deleted key ID %d, but cache removal failed: %v", id, err) } return nil @@ -140,16 +141,13 @@ func (r *gormKeyRepository) HardDeleteByValues(keyValues []string) (int64, error hash := sha256.Sum256([]byte(v)) hashes[i] = hex.EncodeToString(hash[:]) } - // Find the full key objects first to update the cache later. var keysToDelete []models.APIKey - // [MODIFIED] Find by hash. if err := r.db.Where("api_key_hash IN ?", hashes).Find(&keysToDelete).Error; err != nil { return 0, err } if len(keysToDelete) == 0 { return 0, nil } - // Decrypt them to ensure cache has plaintext if needed. 
if err := r.decryptKeys(keysToDelete); err != nil { r.logger.Warnf("Decryption failed for keys before hard delete, cache removal may be impacted: %v", err) } @@ -167,7 +165,7 @@ func (r *gormKeyRepository) HardDeleteByValues(keyValues []string) (int64, error return 0, err } for i := range keysToDelete { - if err := r.removeStoreCacheForKey(&keysToDelete[i]); err != nil { + if err := r.removeStoreCacheForKey(context.Background(), &keysToDelete[i]); err != nil { r.logger.Warnf("DB deleted key ID %d, but cache removal failed: %v", keysToDelete[i].ID, err) } } diff --git a/internal/repository/key_maintenance.go b/internal/repository/key_maintenance.go index e2b840c..0effe6c 100644 --- a/internal/repository/key_maintenance.go +++ b/internal/repository/key_maintenance.go @@ -2,6 +2,7 @@ package repository import ( + "context" "crypto/sha256" "encoding/hex" "gemini-balancer/internal/models" @@ -110,13 +111,13 @@ func (r *gormKeyRepository) deleteOrphanKeysLogic(db *gorm.DB) (int64, error) { } result := db.Delete(&models.APIKey{}, orphanKeyIDs) - //result := db.Unscoped().Delete(&models.APIKey{}, orphanKeyIDs) if result.Error != nil { return 0, result.Error } for i := range keysToDelete { - if err := r.removeStoreCacheForKey(&keysToDelete[i]); err != nil { + // [修正] 使用 context.Background() 调用已更新的缓存清理函数 + if err := r.removeStoreCacheForKey(context.Background(), &keysToDelete[i]); err != nil { r.logger.Warnf("DB deleted orphan key ID %d, but cache removal failed: %v", keysToDelete[i].ID, err) } } @@ -144,7 +145,7 @@ func (r *gormKeyRepository) GetActiveMasterKeys() ([]*models.APIKey, error) { return keys, nil } -func (r *gormKeyRepository) UpdateAPIKeyStatus(keyID uint, status models.MasterAPIKeyStatus) error { +func (r *gormKeyRepository) UpdateAPIKeyStatus(ctx context.Context, keyID uint, status models.MasterAPIKeyStatus) error { err := r.executeTransactionWithRetry(func(tx *gorm.DB) error { result := tx.Model(&models.APIKey{}). Where("id = ?", keyID). 
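Most of this patch threads a context.Context through the store as the first argument of every call. The store interface itself is not part of the diff, so the following is only an illustrative sketch of what a ctx-aware subset might look like, with an in-memory implementation to show the call sites; the method names and signatures are assumptions:

```go
package main

import (
	"context"
	"fmt"
	"sync"
	"time"
)

// Store is an illustrative, ctx-first slice of the cache interface this
// patch migrates to; the real method set in the project is much larger.
type Store interface {
	Set(ctx context.Context, key string, value []byte, ttl time.Duration) error
	Del(ctx context.Context, keys ...string) error
}

type memoryStore struct {
	mu   sync.Mutex
	data map[string][]byte
}

func newMemoryStore() *memoryStore { return &memoryStore{data: map[string][]byte{}} }

func (m *memoryStore) Set(ctx context.Context, key string, value []byte, _ time.Duration) error {
	if err := ctx.Err(); err != nil { // respect cancellation before doing work
		return err
	}
	m.mu.Lock()
	defer m.mu.Unlock()
	m.data[key] = value
	return nil
}

func (m *memoryStore) Del(ctx context.Context, keys ...string) error {
	if err := ctx.Err(); err != nil {
		return err
	}
	m.mu.Lock()
	defer m.mu.Unlock()
	for _, k := range keys {
		delete(m.data, k)
	}
	return nil
}

func main() {
	var s Store = newMemoryStore()
	ctx := context.Background()
	_ = s.Set(ctx, "key:1:details", []byte(`{"id":1}`), 0)
	_ = s.Del(ctx, "key:1:details")
	fmt.Println("ok")
}
```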
@@ -160,7 +161,7 @@ func (r *gormKeyRepository) UpdateAPIKeyStatus(keyID uint, status models.MasterA if err == nil { r.logger.Infof("MasterStatus for key ID %d changed, triggering a full cache reload.", keyID) go func() { - if err := r.LoadAllKeysToStore(); err != nil { + if err := r.LoadAllKeysToStore(context.Background()); err != nil { r.logger.Errorf("Failed to reload cache after MasterStatus change for key ID %d: %v", keyID, err) } }() diff --git a/internal/repository/key_mapping.go b/internal/repository/key_mapping.go index 79d6723..edda449 100644 --- a/internal/repository/key_mapping.go +++ b/internal/repository/key_mapping.go @@ -2,6 +2,7 @@ package repository import ( + "context" "crypto/sha256" "encoding/hex" "errors" @@ -14,7 +15,7 @@ import ( "gorm.io/gorm/clause" ) -func (r *gormKeyRepository) LinkKeysToGroup(groupID uint, keyIDs []uint) error { +func (r *gormKeyRepository) LinkKeysToGroup(ctx context.Context, groupID uint, keyIDs []uint) error { if len(keyIDs) == 0 { return nil } @@ -34,12 +35,12 @@ func (r *gormKeyRepository) LinkKeysToGroup(groupID uint, keyIDs []uint) error { } for _, keyID := range keyIDs { - r.store.SAdd(fmt.Sprintf("key:%d:groups", keyID), groupID) + r.store.SAdd(context.Background(), fmt.Sprintf("key:%d:groups", keyID), groupID) } return nil } -func (r *gormKeyRepository) UnlinkKeysFromGroup(groupID uint, keyIDs []uint) (int64, error) { +func (r *gormKeyRepository) UnlinkKeysFromGroup(ctx context.Context, groupID uint, keyIDs []uint) (int64, error) { if len(keyIDs) == 0 { return 0, nil } @@ -63,16 +64,16 @@ func (r *gormKeyRepository) UnlinkKeysFromGroup(groupID uint, keyIDs []uint) (in activeKeyListKey := fmt.Sprintf("group:%d:keys:active", groupID) for _, keyID := range keyIDs { - r.store.SRem(fmt.Sprintf("key:%d:groups", keyID), groupID) - r.store.LRem(activeKeyListKey, 0, strconv.Itoa(int(keyID))) + r.store.SRem(context.Background(), fmt.Sprintf("key:%d:groups", keyID), groupID) + r.store.LRem(context.Background(), activeKeyListKey, 0, strconv.Itoa(int(keyID))) } return unlinkedCount, nil } -func (r *gormKeyRepository) GetGroupsForKey(keyID uint) ([]uint, error) { +func (r *gormKeyRepository) GetGroupsForKey(ctx context.Context, keyID uint) ([]uint, error) { cacheKey := fmt.Sprintf("key:%d:groups", keyID) - strGroupIDs, err := r.store.SMembers(cacheKey) + strGroupIDs, err := r.store.SMembers(context.Background(), cacheKey) if err != nil || len(strGroupIDs) == 0 { var groupIDs []uint dbErr := r.db.Table("group_api_key_mappings").Where("api_key_id = ?", keyID).Pluck("key_group_id", &groupIDs).Error @@ -84,7 +85,7 @@ func (r *gormKeyRepository) GetGroupsForKey(keyID uint) ([]uint, error) { for _, id := range groupIDs { interfaceSlice = append(interfaceSlice, id) } - r.store.SAdd(cacheKey, interfaceSlice...) + r.store.SAdd(context.Background(), cacheKey, interfaceSlice...) 
} return groupIDs, nil } @@ -103,7 +104,7 @@ func (r *gormKeyRepository) GetMapping(groupID, keyID uint) (*models.GroupAPIKey return &mapping, err } -func (r *gormKeyRepository) UpdateMapping(mapping *models.GroupAPIKeyMapping) error { +func (r *gormKeyRepository) UpdateMapping(ctx context.Context, mapping *models.GroupAPIKeyMapping) error { err := r.executeTransactionWithRetry(func(tx *gorm.DB) error { return tx.Save(mapping).Error }) diff --git a/internal/repository/key_selector.go b/internal/repository/key_selector.go index a49ce21..b2c78f0 100644 --- a/internal/repository/key_selector.go +++ b/internal/repository/key_selector.go @@ -1,7 +1,8 @@ -// Filename: internal/repository/key_selector.go +// Filename: internal/repository/key_selector.go (经审查后最终修复版) package repository import ( + "context" "crypto/sha1" "encoding/json" "errors" @@ -23,19 +24,18 @@ const ( ) // SelectOneActiveKey 根据指定的轮询策略,从缓存中高效地选取一个可用的API密钥。 - -func (r *gormKeyRepository) SelectOneActiveKey(group *models.KeyGroup) (*models.APIKey, *models.GroupAPIKeyMapping, error) { +func (r *gormKeyRepository) SelectOneActiveKey(ctx context.Context, group *models.KeyGroup) (*models.APIKey, *models.GroupAPIKeyMapping, error) { var keyIDStr string var err error switch group.PollingStrategy { case models.StrategySequential: sequentialKey := fmt.Sprintf(KeyGroupSequential, group.ID) - keyIDStr, err = r.store.Rotate(sequentialKey) + keyIDStr, err = r.store.Rotate(ctx, sequentialKey) case models.StrategyWeighted: lruKey := fmt.Sprintf(KeyGroupLRU, group.ID) - results, zerr := r.store.ZRange(lruKey, 0, 0) + results, zerr := r.store.ZRange(ctx, lruKey, 0, 0) if zerr == nil && len(results) > 0 { keyIDStr = results[0] } @@ -44,11 +44,11 @@ func (r *gormKeyRepository) SelectOneActiveKey(group *models.KeyGroup) (*models. case models.StrategyRandom: mainPoolKey := fmt.Sprintf(KeyGroupRandomMain, group.ID) cooldownPoolKey := fmt.Sprintf(KeyGroupRandomCooldown, group.ID) - keyIDStr, err = r.store.PopAndCycleSetMember(mainPoolKey, cooldownPoolKey) + keyIDStr, err = r.store.PopAndCycleSetMember(ctx, mainPoolKey, cooldownPoolKey) default: // 默认或未指定策略时,使用基础的随机策略 activeKeySetKey := fmt.Sprintf(KeyGroup, group.ID) - keyIDStr, err = r.store.SRandMember(activeKeySetKey) + keyIDStr, err = r.store.SRandMember(ctx, activeKeySetKey) } if err != nil { @@ -65,27 +65,25 @@ func (r *gormKeyRepository) SelectOneActiveKey(group *models.KeyGroup) (*models. 
keyID, _ := strconv.ParseUint(keyIDStr, 10, 64) - apiKey, mapping, err := r.getKeyDetailsFromCache(uint(keyID), group.ID) + apiKey, mapping, err := r.getKeyDetailsFromCache(ctx, uint(keyID), group.ID) if err != nil { r.logger.WithError(err).Warnf("Cache inconsistency: Failed to get details for selected key ID %d", keyID) - // TODO 可以在此加入重试逻辑,再次调用 SelectOneActiveKey(group) return nil, nil, err } if group.PollingStrategy == models.StrategyWeighted { - go r.UpdateKeyUsageTimestamp(group.ID, uint(keyID)) + go r.UpdateKeyUsageTimestamp(context.Background(), group.ID, uint(keyID)) } return apiKey, mapping, nil } // SelectOneActiveKeyFromBasePool 为智能聚合模式设计的全新轮询器。 -func (r *gormKeyRepository) SelectOneActiveKeyFromBasePool(pool *BasePool) (*models.APIKey, *models.KeyGroup, error) { - // 生成唯一的池ID,确保不同请求组合的轮询状态相互隔离 +func (r *gormKeyRepository) SelectOneActiveKeyFromBasePool(ctx context.Context, pool *BasePool) (*models.APIKey, *models.KeyGroup, error) { poolID := generatePoolID(pool.CandidateGroups) log := r.logger.WithField("pool_id", poolID) - if err := r.ensureBasePoolCacheExists(pool, poolID); err != nil { + if err := r.ensureBasePoolCacheExists(ctx, pool, poolID); err != nil { log.WithError(err).Error("Failed to ensure BasePool cache exists.") return nil, nil, err } @@ -96,10 +94,10 @@ func (r *gormKeyRepository) SelectOneActiveKeyFromBasePool(pool *BasePool) (*mod switch pool.PollingStrategy { case models.StrategySequential: sequentialKey := fmt.Sprintf(BasePoolSequential, poolID) - keyIDStr, err = r.store.Rotate(sequentialKey) + keyIDStr, err = r.store.Rotate(ctx, sequentialKey) case models.StrategyWeighted: lruKey := fmt.Sprintf(BasePoolLRU, poolID) - results, zerr := r.store.ZRange(lruKey, 0, 0) + results, zerr := r.store.ZRange(ctx, lruKey, 0, 0) if zerr == nil && len(results) > 0 { keyIDStr = results[0] } @@ -107,12 +105,11 @@ func (r *gormKeyRepository) SelectOneActiveKeyFromBasePool(pool *BasePool) (*mod case models.StrategyRandom: mainPoolKey := fmt.Sprintf(BasePoolRandomMain, poolID) cooldownPoolKey := fmt.Sprintf(BasePoolRandomCooldown, poolID) - keyIDStr, err = r.store.PopAndCycleSetMember(mainPoolKey, cooldownPoolKey) - default: // 默认策略,应该在 ensureCache 中处理,但作为降级方案 + keyIDStr, err = r.store.PopAndCycleSetMember(ctx, mainPoolKey, cooldownPoolKey) + default: log.Warnf("Default polling strategy triggered inside selection. 
This should be rare.") - sequentialKey := fmt.Sprintf(BasePoolSequential, poolID) - keyIDStr, err = r.store.LIndex(sequentialKey, 0) + keyIDStr, err = r.store.LIndex(ctx, sequentialKey, 0) } if err != nil { @@ -128,12 +125,10 @@ func (r *gormKeyRepository) SelectOneActiveKeyFromBasePool(pool *BasePool) (*mod keyID, _ := strconv.ParseUint(keyIDStr, 10, 64) for _, group := range pool.CandidateGroups { - apiKey, mapping, cacheErr := r.getKeyDetailsFromCache(uint(keyID), group.ID) + apiKey, mapping, cacheErr := r.getKeyDetailsFromCache(ctx, uint(keyID), group.ID) if cacheErr == nil && apiKey != nil && mapping != nil { - if pool.PollingStrategy == models.StrategyWeighted { - - go r.updateKeyUsageTimestampForPool(poolID, uint(keyID)) + go r.updateKeyUsageTimestampForPool(context.Background(), poolID, uint(keyID)) } return apiKey, group, nil } @@ -144,42 +139,39 @@ func (r *gormKeyRepository) SelectOneActiveKeyFromBasePool(pool *BasePool) (*mod } // ensureBasePoolCacheExists 动态创建 BasePool 的 Redis 结构 -func (r *gormKeyRepository) ensureBasePoolCacheExists(pool *BasePool, poolID string) error { +func (r *gormKeyRepository) ensureBasePoolCacheExists(ctx context.Context, pool *BasePool, poolID string) error { listKey := fmt.Sprintf(BasePoolSequential, poolID) - // --- [逻辑优化] 提前处理“毒丸”,让逻辑更清晰 --- - exists, err := r.store.Exists(listKey) + exists, err := r.store.Exists(ctx, listKey) if err != nil { r.logger.WithError(err).Errorf("Failed to check existence for pool_id '%s'", poolID) - return err // 直接返回读取错误 + return err } if exists { - val, err := r.store.LIndex(listKey, 0) + val, err := r.store.LIndex(ctx, listKey, 0) if err != nil { - // 如果连 LIndex 都失败,说明缓存可能已损坏,允许重建 r.logger.WithError(err).Warnf("Cache for pool_id '%s' exists but is unreadable. Forcing rebuild.", poolID) } else { if val == EmptyPoolPlaceholder { - return gorm.ErrRecordNotFound // 已知为空,直接返回 + return gorm.ErrRecordNotFound } - return nil // 缓存有效,直接返回 + return nil } } - // --- [锁机制优化] 增加分布式锁,防止并发构建时的惊群效应 --- + lockKey := fmt.Sprintf("lock:basepool:%s", poolID) - acquired, err := r.store.SetNX(lockKey, []byte("1"), 10*time.Second) // 10秒锁超时 + acquired, err := r.store.SetNX(ctx, lockKey, []byte("1"), 10*time.Second) if err != nil { r.logger.WithError(err).Error("Failed to attempt acquiring distributed lock for basepool build.") return err } if !acquired { - // 未获取到锁,等待一小段时间后重试,让持有锁的协程完成构建 time.Sleep(100 * time.Millisecond) - return r.ensureBasePoolCacheExists(pool, poolID) + return r.ensureBasePoolCacheExists(ctx, pool, poolID) } - defer r.store.Del(lockKey) // 确保在函数退出时释放锁 - // 双重检查,防止在获取锁的间隙,已有其他协程完成了构建 - if exists, _ := r.store.Exists(listKey); exists { + defer r.store.Del(context.Background(), lockKey) + + if exists, _ := r.store.Exists(ctx, listKey); exists { return nil } r.logger.Infof("BasePool cache for pool_id '%s' not found or is unreadable. Building now...", poolID) @@ -187,22 +179,15 @@ func (r *gormKeyRepository) ensureBasePoolCacheExists(pool *BasePool, poolID str lruMembers := make(map[string]float64) for _, group := range pool.CandidateGroups { activeKeySetKey := fmt.Sprintf(KeyGroup, group.ID) - groupKeyIDs, err := r.store.SMembers(activeKeySetKey) - - // --- [核心修正] --- - // 这是整个问题的根源。我们绝不能在读取失败时,默默地`continue`。 - // 任何读取源数据的失败,都必须被视为一次构建过程的彻底失败,并立即中止。 + groupKeyIDs, err := r.store.SMembers(ctx, activeKeySetKey) if err != nil { r.logger.WithError(err).Errorf("FATAL: Failed to read active keys for group %d during BasePool build. 
Aborting build process for pool_id '%s'.", group.ID, poolID) - // 返回这个瞬时错误。这会导致本次请求失败,但绝不会写入“毒丸”, - // 从而给了下一次请求一个全新的、成功的机会。 return err } - // 只有在 SMembers 成功时,才继续处理 allActiveKeyIDs = append(allActiveKeyIDs, groupKeyIDs...) for _, keyIDStr := range groupKeyIDs { keyID, _ := strconv.ParseUint(keyIDStr, 10, 64) - _, mapping, err := r.getKeyDetailsFromCache(uint(keyID), group.ID) + _, mapping, err := r.getKeyDetailsFromCache(ctx, uint(keyID), group.ID) if err == nil && mapping != nil { var score float64 if mapping.LastUsedAt != nil { @@ -213,12 +198,9 @@ func (r *gormKeyRepository) ensureBasePoolCacheExists(pool *BasePool, poolID str } } - // --- [逻辑修正] --- - // 只有在“我们成功读取了所有数据,但发现数据本身是空的”这种情况下, - // 才允许写入“毒丸”。 if len(allActiveKeyIDs) == 0 { r.logger.Warnf("No active keys found for any candidate groups for pool_id '%s'. Setting empty pool placeholder.", poolID) - pipe := r.store.Pipeline() + pipe := r.store.Pipeline(ctx) pipe.LPush(listKey, EmptyPoolPlaceholder) pipe.Expire(listKey, EmptyCacheTTL) if err := pipe.Exec(); err != nil { @@ -226,14 +208,10 @@ func (r *gormKeyRepository) ensureBasePoolCacheExists(pool *BasePool, poolID str } return gorm.ErrRecordNotFound } - // 使用管道填充所有轮询结构 - pipe := r.store.Pipeline() - // 1. 顺序 - pipe.LPush(fmt.Sprintf(BasePoolSequential, poolID), toInterfaceSlice(allActiveKeyIDs)...) - // 2. 随机 - pipe.SAdd(fmt.Sprintf(BasePoolRandomMain, poolID), toInterfaceSlice(allActiveKeyIDs)...) - // 设置合理的过期时间,例如5分钟,以防止孤儿数据 + pipe := r.store.Pipeline(ctx) + pipe.LPush(fmt.Sprintf(BasePoolSequential, poolID), toInterfaceSlice(allActiveKeyIDs)...) + pipe.SAdd(fmt.Sprintf(BasePoolRandomMain, poolID), toInterfaceSlice(allActiveKeyIDs)...) pipe.Expire(fmt.Sprintf(BasePoolSequential, poolID), CacheTTL) pipe.Expire(fmt.Sprintf(BasePoolRandomMain, poolID), CacheTTL) pipe.Expire(fmt.Sprintf(BasePoolRandomCooldown, poolID), CacheTTL) @@ -244,17 +222,22 @@ func (r *gormKeyRepository) ensureBasePoolCacheExists(pool *BasePool, poolID str } if len(lruMembers) > 0 { - r.store.ZAdd(fmt.Sprintf(BasePoolLRU, poolID), lruMembers) + if err := r.store.ZAdd(ctx, fmt.Sprintf(BasePoolLRU, poolID), lruMembers); err != nil { + r.logger.WithError(err).Warnf("Failed to populate LRU cache for pool_id '%s'", poolID) + } } return nil } // updateKeyUsageTimestampForPool 更新 BasePool 的 LUR ZSET -func (r *gormKeyRepository) updateKeyUsageTimestampForPool(poolID string, keyID uint) { +func (r *gormKeyRepository) updateKeyUsageTimestampForPool(ctx context.Context, poolID string, keyID uint) { lruKey := fmt.Sprintf(BasePoolLRU, poolID) - r.store.ZAdd(lruKey, map[string]float64{ + err := r.store.ZAdd(ctx, lruKey, map[string]float64{ strconv.FormatUint(uint64(keyID), 10): nowMilli(), }) + if err != nil { + r.logger.WithError(err).Warnf("Failed to update key usage for pool %s", poolID) + } } // generatePoolID 根据候选组ID列表生成一个稳定的、唯一的字符串ID @@ -285,8 +268,8 @@ func nowMilli() float64 { } // getKeyDetailsFromCache 从缓存中获取Key和Mapping的JSON数据。 -func (r *gormKeyRepository) getKeyDetailsFromCache(keyID, groupID uint) (*models.APIKey, *models.GroupAPIKeyMapping, error) { - apiKeyJSON, err := r.store.Get(fmt.Sprintf(KeyDetails, keyID)) +func (r *gormKeyRepository) getKeyDetailsFromCache(ctx context.Context, keyID, groupID uint) (*models.APIKey, *models.GroupAPIKeyMapping, error) { + apiKeyJSON, err := r.store.Get(ctx, fmt.Sprintf(KeyDetails, keyID)) if err != nil { return nil, nil, fmt.Errorf("failed to get key details for key %d: %w", keyID, err) } @@ -295,7 +278,7 @@ func (r *gormKeyRepository) getKeyDetailsFromCache(keyID, 
groupID uint) (*models return nil, nil, fmt.Errorf("failed to unmarshal api key %d: %w", keyID, err) } - mappingJSON, err := r.store.Get(fmt.Sprintf(KeyMapping, groupID, keyID)) + mappingJSON, err := r.store.Get(ctx, fmt.Sprintf(KeyMapping, groupID, keyID)) if err != nil { return nil, nil, fmt.Errorf("failed to get mapping details for key %d in group %d: %w", keyID, groupID, err) } diff --git a/internal/repository/key_writer.go b/internal/repository/key_writer.go index 85c7728..d2f1d6e 100644 --- a/internal/repository/key_writer.go +++ b/internal/repository/key_writer.go @@ -1,7 +1,9 @@ // Filename: internal/repository/key_writer.go + package repository import ( + "context" "fmt" "gemini-balancer/internal/errors" "gemini-balancer/internal/models" @@ -9,7 +11,7 @@ import ( "time" ) -func (r *gormKeyRepository) UpdateKeyUsageTimestamp(groupID, keyID uint) { +func (r *gormKeyRepository) UpdateKeyUsageTimestamp(ctx context.Context, groupID, keyID uint) { lruKey := fmt.Sprintf(KeyGroupLRU, groupID) timestamp := float64(time.Now().UnixMilli()) @@ -17,52 +19,51 @@ func (r *gormKeyRepository) UpdateKeyUsageTimestamp(groupID, keyID uint) { strconv.FormatUint(uint64(keyID), 10): timestamp, } - if err := r.store.ZAdd(lruKey, members); err != nil { + if err := r.store.ZAdd(ctx, lruKey, members); err != nil { r.logger.WithError(err).Warnf("Failed to update usage timestamp for key %d in group %d", keyID, groupID) } } -func (r *gormKeyRepository) SyncKeyStatusInPollingCaches(groupID, keyID uint, newStatus models.APIKeyStatus) { +func (r *gormKeyRepository) SyncKeyStatusInPollingCaches(ctx context.Context, groupID, keyID uint, newStatus models.APIKeyStatus) { r.logger.Infof("SYNC: Directly updating polling caches for G:%d K:%d -> %s", groupID, keyID, newStatus) - r.updatePollingCachesLogic(groupID, keyID, newStatus) + r.updatePollingCachesLogic(ctx, groupID, keyID, newStatus) } -func (r *gormKeyRepository) HandleCacheUpdateEvent(groupID, keyID uint, newStatus models.APIKeyStatus) { +func (r *gormKeyRepository) HandleCacheUpdateEvent(ctx context.Context, groupID, keyID uint, newStatus models.APIKeyStatus) { r.logger.Infof("EVENT: Updating polling caches for G:%d K:%d -> %s from an event", groupID, keyID, newStatus) - r.updatePollingCachesLogic(groupID, keyID, newStatus) + r.updatePollingCachesLogic(ctx, groupID, keyID, newStatus) } -func (r *gormKeyRepository) updatePollingCachesLogic(groupID, keyID uint, newStatus models.APIKeyStatus) { +func (r *gormKeyRepository) updatePollingCachesLogic(ctx context.Context, groupID, keyID uint, newStatus models.APIKeyStatus) { keyIDStr := strconv.FormatUint(uint64(keyID), 10) sequentialKey := fmt.Sprintf(KeyGroupSequential, groupID) lruKey := fmt.Sprintf(KeyGroupLRU, groupID) mainPoolKey := fmt.Sprintf(KeyGroupRandomMain, groupID) cooldownPoolKey := fmt.Sprintf(KeyGroupRandomCooldown, groupID) - _ = r.store.LRem(sequentialKey, 0, keyIDStr) - _ = r.store.ZRem(lruKey, keyIDStr) - _ = r.store.SRem(mainPoolKey, keyIDStr) - _ = r.store.SRem(cooldownPoolKey, keyIDStr) + _ = r.store.LRem(ctx, sequentialKey, 0, keyIDStr) + _ = r.store.ZRem(ctx, lruKey, keyIDStr) + _ = r.store.SRem(ctx, mainPoolKey, keyIDStr) + _ = r.store.SRem(ctx, cooldownPoolKey, keyIDStr) if newStatus == models.StatusActive { - if err := r.store.LPush(sequentialKey, keyIDStr); err != nil { + if err := r.store.LPush(ctx, sequentialKey, keyIDStr); err != nil { r.logger.WithError(err).Warnf("Failed to add key %d to sequential list for group %d", keyID, groupID) } members := map[string]float64{keyIDStr: 0} - if 
err := r.store.ZAdd(lruKey, members); err != nil { + if err := r.store.ZAdd(ctx, lruKey, members); err != nil { r.logger.WithError(err).Warnf("Failed to add key %d to LRU zset for group %d", keyID, groupID) } - if err := r.store.SAdd(mainPoolKey, keyIDStr); err != nil { + if err := r.store.SAdd(ctx, mainPoolKey, keyIDStr); err != nil { r.logger.WithError(err).Warnf("Failed to add key %d to random main pool for group %d", keyID, groupID) } } } -// UpdateKeyStatusAfterRequest is the new central hub for handling feedback. -func (r *gormKeyRepository) UpdateKeyStatusAfterRequest(group *models.KeyGroup, key *models.APIKey, success bool, apiErr *errors.APIError) { +func (r *gormKeyRepository) UpdateKeyStatusAfterRequest(ctx context.Context, group *models.KeyGroup, key *models.APIKey, success bool, apiErr *errors.APIError) { if success { if group.PollingStrategy == models.StrategyWeighted { - go r.UpdateKeyUsageTimestamp(group.ID, key.ID) + go r.UpdateKeyUsageTimestamp(context.Background(), group.ID, key.ID) } return } @@ -72,6 +73,5 @@ func (r *gormKeyRepository) UpdateKeyStatusAfterRequest(group *models.KeyGroup, } r.logger.Warnf("Request failed for KeyID %d in GroupID %d with error: %s. Temporarily removing from active polling caches.", key.ID, group.ID, apiErr.Message) - // This call is correct. It uses the synchronous, direct method. - r.SyncKeyStatusInPollingCaches(group.ID, key.ID, models.StatusCooldown) + r.SyncKeyStatusInPollingCaches(ctx, group.ID, key.ID, models.StatusCooldown) } diff --git a/internal/repository/repository.go b/internal/repository/repository.go index c62daf1..239d86f 100644 --- a/internal/repository/repository.go +++ b/internal/repository/repository.go @@ -1,7 +1,8 @@ -// Filename: internal/repository/repository.go +// Filename: internal/repository/repository.go (经审查后最终修复版) package repository import ( + "context" "gemini-balancer/internal/crypto" "gemini-balancer/internal/errors" "gemini-balancer/internal/models" @@ -22,8 +23,8 @@ type BasePool struct { type KeyRepository interface { // --- 核心选取与调度 --- key_selector - SelectOneActiveKey(group *models.KeyGroup) (*models.APIKey, *models.GroupAPIKeyMapping, error) - SelectOneActiveKeyFromBasePool(pool *BasePool) (*models.APIKey, *models.KeyGroup, error) + SelectOneActiveKey(ctx context.Context, group *models.KeyGroup) (*models.APIKey, *models.GroupAPIKeyMapping, error) + SelectOneActiveKeyFromBasePool(ctx context.Context, pool *BasePool) (*models.APIKey, *models.KeyGroup, error) // --- 加密与解密 --- key_crud Decrypt(key *models.APIKey) error @@ -37,16 +38,16 @@ type KeyRepository interface { GetKeyByID(id uint) (*models.APIKey, error) GetKeyByValue(keyValue string) (*models.APIKey, error) GetKeysByValues(keyValues []string) ([]models.APIKey, error) - GetKeysByIDs(ids []uint) ([]models.APIKey, error) // [新增] 根据一组主键ID批量获取Key + GetKeysByIDs(ids []uint) ([]models.APIKey, error) GetKeysByGroup(groupID uint) ([]models.APIKey, error) CountByGroup(groupID uint) (int64, error) // --- 多对多关系管理 --- key_mapping - LinkKeysToGroup(groupID uint, keyIDs []uint) error - UnlinkKeysFromGroup(groupID uint, keyIDs []uint) (unlinkedCount int64, err error) - GetGroupsForKey(keyID uint) ([]uint, error) + LinkKeysToGroup(ctx context.Context, groupID uint, keyIDs []uint) error + UnlinkKeysFromGroup(ctx context.Context, groupID uint, keyIDs []uint) (unlinkedCount int64, err error) + GetGroupsForKey(ctx context.Context, keyID uint) ([]uint, error) GetMapping(groupID, keyID uint) (*models.GroupAPIKeyMapping, error) - UpdateMapping(mapping 
*models.GroupAPIKeyMapping) error + UpdateMapping(ctx context.Context, mapping *models.GroupAPIKeyMapping) error GetPaginatedKeysAndMappingsByGroup(params *models.APIKeyQueryParams) ([]*models.APIKeyDetails, int64, error) GetKeysByValuesAndGroupID(values []string, groupID uint) ([]models.APIKey, error) FindKeyValuesByStatus(groupID uint, statuses []string) ([]string, error) @@ -55,8 +56,8 @@ type KeyRepository interface { UpdateMappingWithoutCache(mapping *models.GroupAPIKeyMapping) error // --- 缓存管理 --- key_cache - LoadAllKeysToStore() error - HandleCacheUpdateEventBatch(mappings []*models.GroupAPIKeyMapping) error + LoadAllKeysToStore(ctx context.Context) error + HandleCacheUpdateEventBatch(ctx context.Context, mappings []*models.GroupAPIKeyMapping) error // --- 维护与后台任务 --- key_maintenance StreamKeysToWriter(groupID uint, statusFilter string, writer io.Writer) error @@ -65,16 +66,14 @@ type KeyRepository interface { DeleteOrphanKeys() (int64, error) DeleteOrphanKeysTx(tx *gorm.DB) (int64, error) GetActiveMasterKeys() ([]*models.APIKey, error) - UpdateAPIKeyStatus(keyID uint, status models.MasterAPIKeyStatus) error + UpdateAPIKeyStatus(ctx context.Context, keyID uint, status models.MasterAPIKeyStatus) error HardDeleteSoftDeletedBefore(date time.Time) (int64, error) // --- 轮询策略的"写"操作 --- key_writer - UpdateKeyUsageTimestamp(groupID, keyID uint) - // 同步更新缓存,供核心业务使用 - SyncKeyStatusInPollingCaches(groupID, keyID uint, newStatus models.APIKeyStatus) - // 异步更新缓存,供事件订阅者使用 - HandleCacheUpdateEvent(groupID, keyID uint, newStatus models.APIKeyStatus) - UpdateKeyStatusAfterRequest(group *models.KeyGroup, key *models.APIKey, success bool, apiErr *errors.APIError) + UpdateKeyUsageTimestamp(ctx context.Context, groupID, keyID uint) + SyncKeyStatusInPollingCaches(ctx context.Context, groupID, keyID uint, newStatus models.APIKeyStatus) + HandleCacheUpdateEvent(ctx context.Context, groupID, keyID uint, newStatus models.APIKeyStatus) + UpdateKeyStatusAfterRequest(ctx context.Context, group *models.KeyGroup, key *models.APIKey, success bool, apiErr *errors.APIError) } type GroupRepository interface { diff --git a/internal/scheduler/scheduler.go b/internal/scheduler/scheduler.go index 6beff2f..fd058f9 100644 --- a/internal/scheduler/scheduler.go +++ b/internal/scheduler/scheduler.go @@ -2,6 +2,7 @@ package scheduler import ( + "context" "gemini-balancer/internal/repository" "gemini-balancer/internal/service" "time" @@ -15,7 +16,6 @@ type Scheduler struct { logger *logrus.Entry statsService *service.StatsService keyRepo repository.KeyRepository - // healthCheckService *service.HealthCheckService // 健康检查任务预留 } func NewScheduler(statsSvc *service.StatsService, keyRepo repository.KeyRepository, logger *logrus.Logger) *Scheduler { @@ -32,11 +32,13 @@ func NewScheduler(statsSvc *service.StatsService, keyRepo repository.KeyReposito func (s *Scheduler) Start() { s.logger.Info("Starting scheduler and registering jobs...") - // --- 任务注册 --- + // 任务一:每小时执行一次的统计聚合 // 使用CRON表达式,精确定义“每小时的第5分钟”执行 _, err := s.gocronScheduler.Cron("5 * * * *").Tag("stats-aggregation").Do(func() { s.logger.Info("Executing hourly request stats aggregation...") - if err := s.statsService.AggregateHourlyStats(); err != nil { + // 为后台定时任务创建一个新的、空的 context + ctx := context.Background() + if err := s.statsService.AggregateHourlyStats(ctx); err != nil { s.logger.WithError(err).Error("Hourly stats aggregation failed.") } else { s.logger.Info("Hourly stats aggregation completed successfully.") @@ -46,23 +48,14 @@ func (s *Scheduler) Start() { 
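// The scheduler hunks here give every cron-triggered job its own fresh context
// instead of reusing a request context. A minimal stdlib-only sketch of that
// pattern, assuming nothing from this repo; wrapJob, the job name and the
// timeout value are illustrative, not the project's API.
package main

import (
	"context"
	"log"
	"time"
)

// wrapJob adapts a context-aware job to the no-argument callback signature
// cron schedulers expect, giving each run a bounded background context.
func wrapJob(name string, timeout time.Duration, job func(context.Context) error) func() {
	return func() {
		ctx, cancel := context.WithTimeout(context.Background(), timeout)
		defer cancel()
		if err := job(ctx); err != nil {
			log.Printf("scheduled job %q failed: %v", name, err)
			return
		}
		log.Printf("scheduled job %q completed", name)
	}
}

func main() {
	run := wrapJob("stats-aggregation", 5*time.Minute, func(ctx context.Context) error {
		// Stand-in for a call like statsService.AggregateHourlyStats(ctx).
		return nil
	})
	run()
}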
s.logger.Errorf("Failed to schedule [stats-aggregation]: %v", err) } - // 任务二:(预留) 自动健康检查 (例如:每10分钟一次) - /* - _, err = s.gocronScheduler.Every(10).Minutes().Tag("auto-health-check").Do(func() { - s.logger.Info("Executing periodic health check for all groups...") - // s.healthCheckService.StartGlobalCheckTask() // 伪代码 - }) - if err != nil { - s.logger.Errorf("Failed to schedule [auto-health-check]: %v", err) - } - */ - // [NEW] --- 任务三: 清理软删除的API Keys --- + // 任务二:(预留) 自动健康检查 + + // 任务三:每日执行一次的软删除Key清理 // Executes once daily at 3:15 AM UTC. _, err = s.gocronScheduler.Cron("15 3 * * *").Tag("cleanup-soft-deleted-keys").Do(func() { s.logger.Info("Executing daily cleanup of soft-deleted API keys...") - // Let's assume a retention period of 7 days for now. - // In a real scenario, this should come from settings. + // [假设保留7天,实际应来自配置 const retentionDays = 7 count, err := s.keyRepo.HardDeleteSoftDeletedBefore(time.Now().AddDate(0, 0, -retentionDays)) @@ -77,9 +70,8 @@ func (s *Scheduler) Start() { if err != nil { s.logger.Errorf("Failed to schedule [cleanup-soft-deleted-keys]: %v", err) } - // --- 任务注册结束 --- - s.gocronScheduler.StartAsync() // 异步启动,不阻塞应用主线程 + s.gocronScheduler.StartAsync() s.logger.Info("Scheduler started.") } diff --git a/internal/service/analytics_service.go b/internal/service/analytics_service.go index a3375c9..3c0cad9 100644 --- a/internal/service/analytics_service.go +++ b/internal/service/analytics_service.go @@ -1,8 +1,8 @@ // Filename: internal/service/analytics_service.go - package service import ( + "context" "encoding/json" "fmt" "gemini-balancer/internal/db/dialect" @@ -43,7 +43,7 @@ func NewAnalyticsService(db *gorm.DB, s store.Store, logger *logrus.Logger, d di } func (s *AnalyticsService) Start() { - s.wg.Add(2) // 2 (flushLoop, eventListener) + s.wg.Add(2) go s.flushLoop() go s.eventListener() s.logger.Info("AnalyticsService (Command Side) started.") @@ -53,13 +53,13 @@ func (s *AnalyticsService) Stop() { close(s.stopChan) s.wg.Wait() s.logger.Info("AnalyticsService stopped. 
Performing final data flush...") - s.flushToDB() // 停止前刷盘 + s.flushToDB() s.logger.Info("AnalyticsService final data flush completed.") } func (s *AnalyticsService) eventListener() { defer s.wg.Done() - sub, err := s.store.Subscribe(models.TopicRequestFinished) + sub, err := s.store.Subscribe(context.Background(), models.TopicRequestFinished) if err != nil { s.logger.Fatalf("Failed to subscribe to topic %s: %v", models.TopicRequestFinished, err) return @@ -87,9 +87,10 @@ func (s *AnalyticsService) handleAnalyticsEvent(event *models.RequestFinishedEve if event.RequestLog.GroupID == nil { return } + ctx := context.Background() key := fmt.Sprintf("analytics:hourly:%s", time.Now().UTC().Format("2006-01-02T15")) fieldPrefix := fmt.Sprintf("%d:%s", *event.RequestLog.GroupID, event.RequestLog.ModelName) - pipe := s.store.Pipeline() + pipe := s.store.Pipeline(ctx) pipe.HIncrBy(key, fieldPrefix+":requests", 1) if event.RequestLog.IsSuccess { pipe.HIncrBy(key, fieldPrefix+":success", 1) @@ -120,6 +121,7 @@ func (s *AnalyticsService) flushLoop() { } func (s *AnalyticsService) flushToDB() { + ctx := context.Background() now := time.Now().UTC() keysToFlush := []string{ fmt.Sprintf("analytics:hourly:%s", now.Add(-1*time.Hour).Format("2006-01-02T15")), @@ -127,7 +129,7 @@ func (s *AnalyticsService) flushToDB() { } for _, key := range keysToFlush { - data, err := s.store.HGetAll(key) + data, err := s.store.HGetAll(ctx, key) if err != nil || len(data) == 0 { continue } @@ -136,15 +138,15 @@ func (s *AnalyticsService) flushToDB() { if len(statsToFlush) > 0 { upsertClause := s.dialect.OnConflictUpdateAll( - []string{"time", "group_id", "model_name"}, // conflict columns - []string{"request_count", "success_count", "prompt_tokens", "completion_tokens"}, // update columns + []string{"time", "group_id", "model_name"}, + []string{"request_count", "success_count", "prompt_tokens", "completion_tokens"}, ) - err := s.db.Clauses(upsertClause).Create(&statsToFlush).Error + err := s.db.WithContext(ctx).Clauses(upsertClause).Create(&statsToFlush).Error if err != nil { s.logger.Errorf("Failed to flush analytics data for key %s: %v", key, err) } else { s.logger.Infof("Successfully flushed %d records from key %s.", len(statsToFlush), key) - _ = s.store.HDel(key, parsedFields...) + _ = s.store.HDel(ctx, key, parsedFields...) } } } diff --git a/internal/service/apikey_service.go b/internal/service/apikey_service.go index 63a6385..c75f20f 100644 --- a/internal/service/apikey_service.go +++ b/internal/service/apikey_service.go @@ -1,8 +1,8 @@ // Filename: internal/service/apikey_service.go - package service import ( + "context" "encoding/json" "errors" "fmt" @@ -29,7 +29,6 @@ const ( TaskTypeUpdateStatusByFilter = "update_status_by_filter" ) -// DTOs & Constants const ( TestEndpoint = "https://generativelanguage.googleapis.com/v1beta/models" ) @@ -83,7 +82,6 @@ func NewAPIKeyService( gm *GroupManager, logger *logrus.Logger, ) *APIKeyService { - logger.Debugf("[FORENSIC PROBE | INJECTION | APIKeyService] Received a KeyRepository. 
Fingerprint: %p", repo) return &APIKeyService{ db: db, keyRepo: repo, @@ -99,22 +97,22 @@ func NewAPIKeyService( } func (s *APIKeyService) Start() { - requestSub, err := s.store.Subscribe(models.TopicRequestFinished) + requestSub, err := s.store.Subscribe(context.Background(), models.TopicRequestFinished) if err != nil { s.logger.Fatalf("Failed to subscribe to topic %s: %v", models.TopicRequestFinished, err) return } - masterKeySub, err := s.store.Subscribe(models.TopicMasterKeyStatusChanged) + masterKeySub, err := s.store.Subscribe(context.Background(), models.TopicMasterKeyStatusChanged) if err != nil { s.logger.Fatalf("Failed to subscribe to topic %s: %v", models.TopicMasterKeyStatusChanged, err) return } - keyStatusSub, err := s.store.Subscribe(models.TopicKeyStatusChanged) + keyStatusSub, err := s.store.Subscribe(context.Background(), models.TopicKeyStatusChanged) if err != nil { s.logger.Fatalf("Failed to subscribe to topic %s: %v", models.TopicKeyStatusChanged, err) return } - importSub, err := s.store.Subscribe(models.TopicImportGroupCompleted) + importSub, err := s.store.Subscribe(context.Background(), models.TopicImportGroupCompleted) if err != nil { s.logger.Fatalf("Failed to subscribe to topic %s: %v", models.TopicImportGroupCompleted, err) return @@ -177,6 +175,7 @@ func (s *APIKeyService) handleKeyUsageEvent(event *models.RequestFinishedEvent) if event.RequestLog.KeyID == nil || event.RequestLog.GroupID == nil { return } + ctx := context.Background() if event.RequestLog.IsSuccess { mapping, err := s.keyRepo.GetMapping(*event.RequestLog.GroupID, *event.RequestLog.KeyID) if err != nil { @@ -194,17 +193,18 @@ func (s *APIKeyService) handleKeyUsageEvent(event *models.RequestFinishedEvent) now := time.Now() mapping.LastUsedAt = &now - if err := s.keyRepo.UpdateMapping(mapping); err != nil { + if err := s.keyRepo.UpdateMapping(ctx, mapping); err != nil { s.logger.Errorf("[%s] Failed to update mapping for G:%d K:%d after successful request: %v", event.CorrelationID, *event.RequestLog.GroupID, *event.RequestLog.KeyID, err) return } if statusChanged { - go s.publishStatusChangeEvent(*event.RequestLog.GroupID, *event.RequestLog.KeyID, oldStatus, models.StatusActive, "key_recovered_after_use") + go s.publishStatusChangeEvent(ctx, *event.RequestLog.GroupID, *event.RequestLog.KeyID, oldStatus, models.StatusActive, "key_recovered_after_use") } return } if event.Error != nil { s.judgeKeyErrors( + ctx, event.CorrelationID, *event.RequestLog.GroupID, *event.RequestLog.KeyID, @@ -215,6 +215,7 @@ func (s *APIKeyService) handleKeyUsageEvent(event *models.RequestFinishedEvent) } func (s *APIKeyService) handleKeyStatusChangeEvent(event *models.KeyStatusChangedEvent) { + ctx := context.Background() log := s.logger.WithFields(logrus.Fields{ "group_id": event.GroupID, "key_id": event.KeyID, @@ -222,11 +223,11 @@ func (s *APIKeyService) handleKeyStatusChangeEvent(event *models.KeyStatusChange "reason": event.ChangeReason, }) log.Info("Received KeyStatusChangedEvent, will update polling caches.") - s.keyRepo.HandleCacheUpdateEvent(event.GroupID, event.KeyID, event.NewStatus) + s.keyRepo.HandleCacheUpdateEvent(ctx, event.GroupID, event.KeyID, event.NewStatus) log.Info("Polling caches updated based on health check event.") } -func (s *APIKeyService) publishStatusChangeEvent(groupID, keyID uint, oldStatus, newStatus models.APIKeyStatus, reason string) { +func (s *APIKeyService) publishStatusChangeEvent(ctx context.Context, groupID, keyID uint, oldStatus, newStatus models.APIKeyStatus, reason string) { 
changeEvent := models.KeyStatusChangedEvent{ KeyID: keyID, GroupID: groupID, @@ -236,13 +237,12 @@ func (s *APIKeyService) publishStatusChangeEvent(groupID, keyID uint, oldStatus, ChangedAt: time.Now(), } eventData, _ := json.Marshal(changeEvent) - if err := s.store.Publish(models.TopicKeyStatusChanged, eventData); err != nil { + if err := s.store.Publish(ctx, models.TopicKeyStatusChanged, eventData); err != nil { s.logger.Errorf("Failed to publish key status changed event for group %d: %v", groupID, err) } } -func (s *APIKeyService) ListAPIKeys(params *models.APIKeyQueryParams) (*PaginatedAPIKeys, error) { - // --- Path 1: High-performance DB pagination (no keyword) --- +func (s *APIKeyService) ListAPIKeys(ctx context.Context, params *models.APIKeyQueryParams) (*PaginatedAPIKeys, error) { if params.Keyword == "" { items, total, err := s.keyRepo.GetPaginatedKeysAndMappingsByGroup(params) if err != nil { @@ -260,14 +260,12 @@ func (s *APIKeyService) ListAPIKeys(params *models.APIKeyQueryParams) (*Paginate TotalPages: totalPages, }, nil } - // --- Path 2: In-memory search (keyword present) --- s.logger.Infof("Performing heavy in-memory search for group %d with keyword '%s'", params.KeyGroupID, params.Keyword) - // To get all keys, we fetch all IDs first, then get their full details. var statusesToFilter []string if params.Status != "" { statusesToFilter = append(statusesToFilter, params.Status) } else { - statusesToFilter = append(statusesToFilter, "all") // "all" gets every status + statusesToFilter = append(statusesToFilter, "all") } allKeyIDs, err := s.keyRepo.FindKeyIDsByStatus(params.KeyGroupID, statusesToFilter) if err != nil { @@ -277,14 +275,12 @@ func (s *APIKeyService) ListAPIKeys(params *models.APIKeyQueryParams) (*Paginate return &PaginatedAPIKeys{Items: []*models.APIKeyDetails{}, Total: 0, Page: 1, PageSize: params.PageSize, TotalPages: 0}, nil } - // This is the heavy operation: getting all keys and decrypting them. allKeys, err := s.keyRepo.GetKeysByIDs(allKeyIDs) if err != nil { return nil, fmt.Errorf("failed to fetch all key details for in-memory search: %w", err) } - // We also need mappings to build the final `APIKeyDetails`. var allMappings []models.GroupAPIKeyMapping - err = s.db.Where("key_group_id = ? AND api_key_id IN ?", params.KeyGroupID, allKeyIDs).Find(&allMappings).Error + err = s.db.WithContext(ctx).Where("key_group_id = ? AND api_key_id IN ?", params.KeyGroupID, allKeyIDs).Find(&allMappings).Error if err != nil { return nil, fmt.Errorf("failed to fetch mappings for in-memory search: %w", err) } @@ -292,7 +288,6 @@ func (s *APIKeyService) ListAPIKeys(params *models.APIKeyQueryParams) (*Paginate for i := range allMappings { mappingMap[allMappings[i].APIKeyID] = &allMappings[i] } - // Filter the results in memory. var filteredItems []*models.APIKeyDetails for _, key := range allKeys { if strings.Contains(key.APIKey, params.Keyword) { @@ -312,11 +307,9 @@ func (s *APIKeyService) ListAPIKeys(params *models.APIKeyQueryParams) (*Paginate } } } - // Sort the filtered results to ensure consistent pagination (by ID descending). sort.Slice(filteredItems, func(i, j int) bool { return filteredItems[i].ID > filteredItems[j].ID }) - // Manually paginate the filtered results. 
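// The keyword path above filters and sorts in memory, and the code that follows
// slices the result by hand. A small generic sketch of that clamp-and-slice
// step, with illustrative names only (not this repo's API).
package main

import "fmt"

// paginate returns the page-th slice of items (1-based) plus the total page
// count, clamping out-of-range bounds the same way the service code must.
func paginate[T any](items []T, page, pageSize int) ([]T, int) {
	if pageSize <= 0 {
		pageSize = 10
	}
	if page <= 0 {
		page = 1
	}
	total := len(items)
	totalPages := (total + pageSize - 1) / pageSize
	start := (page - 1) * pageSize
	if start >= total {
		return []T{}, totalPages
	}
	end := start + pageSize
	if end > total {
		end = total
	}
	return items[start:end], totalPages
}

func main() {
	ids := []uint{9, 8, 7, 6, 5, 4, 3, 2, 1}
	pageItems, totalPages := paginate(ids, 2, 4)
	fmt.Println(pageItems, totalPages) // [5 4 3 2] 3
}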
total := int64(len(filteredItems)) start := (params.Page - 1) * params.PageSize end := start + params.PageSize @@ -345,14 +338,15 @@ func (s *APIKeyService) ListAPIKeys(params *models.APIKeyQueryParams) (*Paginate }, nil } -func (s *APIKeyService) GetKeysByIds(ids []uint) ([]models.APIKey, error) { +func (s *APIKeyService) GetKeysByIds(ctx context.Context, ids []uint) ([]models.APIKey, error) { return s.keyRepo.GetKeysByIDs(ids) } -func (s *APIKeyService) UpdateAPIKey(key *models.APIKey) error { +func (s *APIKeyService) UpdateAPIKey(ctx context.Context, key *models.APIKey) error { go func() { + bgCtx := context.Background() var oldKey models.APIKey - if err := s.db.First(&oldKey, key.ID).Error; err != nil { + if err := s.db.WithContext(bgCtx).First(&oldKey, key.ID).Error; err != nil { s.logger.Errorf("Failed to find old key state for ID %d before update: %v", key.ID, err) return } @@ -364,16 +358,14 @@ func (s *APIKeyService) UpdateAPIKey(key *models.APIKey) error { return nil } -func (s *APIKeyService) HardDeleteAPIKeyByID(id uint) error { - // Get all associated groups before deletion to publish correct events - groups, err := s.keyRepo.GetGroupsForKey(id) +func (s *APIKeyService) HardDeleteAPIKeyByID(ctx context.Context, id uint) error { + groups, err := s.keyRepo.GetGroupsForKey(ctx, id) if err != nil { s.logger.Warnf("Could not get groups for key ID %d before hard deletion (key might be orphaned): %v", id, err) } err = s.keyRepo.HardDeleteByID(id) if err == nil { - // Publish events for each group the key was a part of for _, groupID := range groups { event := models.KeyStatusChangedEvent{ KeyID: id, @@ -381,13 +373,13 @@ func (s *APIKeyService) HardDeleteAPIKeyByID(id uint) error { ChangeReason: "key_hard_deleted", } eventData, _ := json.Marshal(event) - go s.store.Publish(models.TopicKeyStatusChanged, eventData) + go s.store.Publish(context.Background(), models.TopicKeyStatusChanged, eventData) } } return err } -func (s *APIKeyService) UpdateMappingStatus(groupID, keyID uint, newStatus models.APIKeyStatus) (*models.GroupAPIKeyMapping, error) { +func (s *APIKeyService) UpdateMappingStatus(ctx context.Context, groupID, keyID uint, newStatus models.APIKeyStatus) (*models.GroupAPIKeyMapping, error) { key, err := s.keyRepo.GetKeyByID(keyID) if err != nil { return nil, CustomErrors.ParseDBError(err) @@ -409,19 +401,20 @@ func (s *APIKeyService) UpdateMappingStatus(groupID, keyID uint, newStatus model mapping.ConsecutiveErrorCount = 0 mapping.LastError = "" } - if err := s.keyRepo.UpdateMapping(mapping); err != nil { + if err := s.keyRepo.UpdateMapping(ctx, mapping); err != nil { return nil, err } - go s.publishStatusChangeEvent(groupID, keyID, oldStatus, newStatus, "manual_update") + go s.publishStatusChangeEvent(ctx, groupID, keyID, oldStatus, newStatus, "manual_update") return mapping, nil } func (s *APIKeyService) handleMasterKeyStatusChangeEvent(event *models.MasterKeyStatusChangedEvent) { + ctx := context.Background() s.logger.Infof("Received MasterKeyStatusChangedEvent for Key ID %d: %s -> %s", event.KeyID, event.OldMasterStatus, event.NewMasterStatus) if event.NewMasterStatus != models.MasterStatusRevoked { return } - affectedGroupIDs, err := s.keyRepo.GetGroupsForKey(event.KeyID) + affectedGroupIDs, err := s.keyRepo.GetGroupsForKey(ctx, event.KeyID) if err != nil { s.logger.WithError(err).Errorf("Failed to get groups for key ID %d to propagate revoked status.", event.KeyID) return @@ -432,7 +425,7 @@ func (s *APIKeyService) handleMasterKeyStatusChangeEvent(event *models.MasterKey 
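// The hunk continuing below fans a master-key revocation out to every group
// that still references the key and treats a missing mapping as non-fatal. A
// stdlib-only sketch of that tolerance pattern, assuming a sentinel not-found
// error; errNotFound, banMapping and the IDs are illustrative only.
package main

import (
	"errors"
	"log"
)

var errNotFound = errors.New("mapping not found")

// banMapping stands in for a repository call that may legitimately find no
// mapping for a (group, key) pair that was already unlinked.
func banMapping(groupID, keyID uint) error {
	if groupID == 3 {
		return errNotFound
	}
	return nil
}

func propagateRevocation(keyID uint, groupIDs []uint) {
	for _, groupID := range groupIDs {
		if err := banMapping(groupID, keyID); err != nil {
			if errors.Is(err, errNotFound) {
				continue // already gone from this group; nothing to ban
			}
			log.Printf("failed to ban key %d in group %d: %v", keyID, groupID, err)
		}
	}
}

func main() {
	propagateRevocation(42, []uint{1, 2, 3})
}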
} s.logger.Infof("Propagating REVOKED astatus for Key ID %d to %d groups.", event.KeyID, len(affectedGroupIDs)) for _, groupID := range affectedGroupIDs { - _, err := s.UpdateMappingStatus(groupID, event.KeyID, models.StatusBanned) + _, err := s.UpdateMappingStatus(ctx, groupID, event.KeyID, models.StatusBanned) if err != nil { if !errors.Is(err, gorm.ErrRecordNotFound) { s.logger.WithError(err).Errorf("Failed to update mapping status to BANNED for Key ID %d in Group ID %d.", event.KeyID, groupID) @@ -441,32 +434,32 @@ func (s *APIKeyService) handleMasterKeyStatusChangeEvent(event *models.MasterKey } } -func (s *APIKeyService) StartRestoreKeysTask(groupID uint, keyIDs []uint) (*task.Status, error) { +func (s *APIKeyService) StartRestoreKeysTask(ctx context.Context, groupID uint, keyIDs []uint) (*task.Status, error) { if len(keyIDs) == 0 { return nil, CustomErrors.NewAPIError(CustomErrors.ErrBadRequest, "No key IDs provided for restoration.") } resourceID := fmt.Sprintf("group-%d", groupID) - taskStatus, err := s.taskService.StartTask(groupID, TaskTypeRestoreSpecificKeys, resourceID, len(keyIDs), time.Hour) + taskStatus, err := s.taskService.StartTask(ctx, groupID, TaskTypeRestoreSpecificKeys, resourceID, len(keyIDs), time.Hour) if err != nil { return nil, err } - go s.runRestoreKeysTask(taskStatus.ID, resourceID, groupID, keyIDs) + go s.runRestoreKeysTask(ctx, taskStatus.ID, resourceID, groupID, keyIDs) return taskStatus, nil } -func (s *APIKeyService) runRestoreKeysTask(taskID string, resourceID string, groupID uint, keyIDs []uint) { +func (s *APIKeyService) runRestoreKeysTask(ctx context.Context, taskID string, resourceID string, groupID uint, keyIDs []uint) { defer func() { if r := recover(); r != nil { s.logger.Errorf("Panic recovered in runRestoreKeysTask: %v", r) - s.taskService.EndTaskByID(taskID, resourceID, nil, fmt.Errorf("panic in restore keys task: %v", r)) + s.taskService.EndTaskByID(ctx, taskID, resourceID, nil, fmt.Errorf("panic in restore keys task: %v", r)) } }() var mappingsToProcess []models.GroupAPIKeyMapping - err := s.db.Preload("APIKey"). + err := s.db.WithContext(ctx).Preload("APIKey"). Where("key_group_id = ? AND api_key_id IN ?", groupID, keyIDs). Find(&mappingsToProcess).Error if err != nil { - s.taskService.EndTaskByID(taskID, resourceID, nil, err) + s.taskService.EndTaskByID(ctx, taskID, resourceID, nil, err) return } result := &BatchRestoreResult{ @@ -476,7 +469,7 @@ func (s *APIKeyService) runRestoreKeysTask(taskID string, resourceID string, gro processedCount := 0 for _, mapping := range mappingsToProcess { processedCount++ - _ = s.taskService.UpdateProgressByID(taskID, processedCount) + _ = s.taskService.UpdateProgressByID(ctx, taskID, processedCount) if mapping.APIKey == nil { result.SkippedCount++ result.SkippedKeys = append(result.SkippedKeys, SkippedKeyInfo{KeyID: mapping.APIKeyID, Reason: "Associated APIKey entity not found."}) @@ -492,33 +485,29 @@ func (s *APIKeyService) runRestoreKeysTask(taskID string, resourceID string, gro mapping.Status = models.StatusActive mapping.ConsecutiveErrorCount = 0 mapping.LastError = "" - // Use the version that doesn't trigger individual cache updates. 
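// The restore task above runs in its own goroutine, converts a panic into a
// failed-task result via a deferred recover, and reports per-item progress. A
// compact sketch of that shape, assuming a tracker with UpdateProgress and End
// methods; tracker, taskID and the work items are illustrative, not the
// project's task service.
package main

import (
	"fmt"
	"log"
)

type tracker interface {
	UpdateProgress(taskID string, processed int)
	End(taskID string, result any, err error)
}

func runBackgroundTask(t tracker, taskID string, items []string, work func(string) error) {
	defer func() {
		if r := recover(); r != nil {
			t.End(taskID, nil, fmt.Errorf("panic in task %s: %v", taskID, r))
		}
	}()
	processed := 0
	failed := 0
	for _, item := range items {
		processed++
		t.UpdateProgress(taskID, processed)
		if err := work(item); err != nil {
			failed++ // item-level failures are tallied, not fatal to the task
		}
	}
	t.End(taskID, map[string]int{"processed": processed, "failed": failed}, nil)
}

// logTracker is a trivial stand-in so the helper can be exercised directly.
type logTracker struct{}

func (logTracker) UpdateProgress(id string, n int) { log.Printf("task %s: %d done", id, n) }
func (logTracker) End(id string, result any, err error) {
	log.Printf("task %s ended: %v %v", id, result, err)
}

func main() {
	runBackgroundTask(logTracker{}, "restore-demo", []string{"a", "b"}, func(string) error { return nil })
}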
if err := s.keyRepo.UpdateMappingWithoutCache(&mapping); err != nil { s.logger.WithError(err).Warn("Failed to update mapping in DB during restore task.") result.SkippedCount++ result.SkippedKeys = append(result.SkippedKeys, SkippedKeyInfo{KeyID: mapping.APIKeyID, Reason: "DB update failed."}) } else { result.RestoredCount++ - successfulMappings = append(successfulMappings, &mapping) // Collect for batch cache update. - go s.publishStatusChangeEvent(groupID, mapping.APIKeyID, oldStatus, models.StatusActive, "batch_restore") + successfulMappings = append(successfulMappings, &mapping) + go s.publishStatusChangeEvent(ctx, groupID, mapping.APIKeyID, oldStatus, models.StatusActive, "batch_restore") } } else { - result.RestoredCount++ // Already active, count as success. + result.RestoredCount++ } } - // After the loop, perform one single, efficient cache update. - if err := s.keyRepo.HandleCacheUpdateEventBatch(successfulMappings); err != nil { + if err := s.keyRepo.HandleCacheUpdateEventBatch(ctx, successfulMappings); err != nil { s.logger.WithError(err).Error("Failed to perform batch cache update after restore task.") - // This is not a task-fatal error, so we just log it and continue. } - // Account for keys that were requested but not found in the initial DB query. result.SkippedCount += (len(keyIDs) - len(mappingsToProcess)) - s.taskService.EndTaskByID(taskID, resourceID, result, nil) + s.taskService.EndTaskByID(ctx, taskID, resourceID, result, nil) } -func (s *APIKeyService) StartRestoreAllBannedTask(groupID uint) (*task.Status, error) { +func (s *APIKeyService) StartRestoreAllBannedTask(ctx context.Context, groupID uint) (*task.Status, error) { var bannedKeyIDs []uint - err := s.db.Model(&models.GroupAPIKeyMapping{}). + err := s.db.WithContext(ctx).Model(&models.GroupAPIKeyMapping{}). Where("key_group_id = ? AND status = ?", groupID, models.StatusBanned). Pluck("api_key_id", &bannedKeyIDs).Error if err != nil { @@ -527,10 +516,11 @@ func (s *APIKeyService) StartRestoreAllBannedTask(groupID uint) (*task.Status, e if len(bannedKeyIDs) == 0 { return nil, CustomErrors.NewAPIError(CustomErrors.ErrBadRequest, "No banned keys to restore in this group.") } - return s.StartRestoreKeysTask(groupID, bannedKeyIDs) + return s.StartRestoreKeysTask(ctx, groupID, bannedKeyIDs) } func (s *APIKeyService) handlePostImportValidation(event *models.ImportGroupCompletedEvent) { + ctx := context.Background() group, ok := s.groupManager.GetGroupByID(event.GroupID) if !ok { s.logger.Errorf("Group with id %d not found during post-import validation, aborting.", event.GroupID) @@ -552,7 +542,7 @@ func (s *APIKeyService) handlePostImportValidation(event *models.ImportGroupComp concurrency = *opConfig.KeyCheckConcurrency } if concurrency <= 0 { - concurrency = 10 // Safety fallback + concurrency = 10 } timeout := time.Duration(globalSettings.KeyCheckTimeoutSeconds) * time.Second keysToValidate, err := s.keyRepo.GetKeysByIDs(event.KeyIDs) @@ -571,7 +561,7 @@ func (s *APIKeyService) handlePostImportValidation(event *models.ImportGroupComp validationErr := s.validationService.ValidateSingleKey(&key, timeout, endpoint) if validationErr == nil { s.logger.Infof("Key ID %d PASSED validation. 
Status -> ACTIVE.", key.ID) - if _, err := s.UpdateMappingStatus(event.GroupID, key.ID, models.StatusActive); err != nil { + if _, err := s.UpdateMappingStatus(ctx, event.GroupID, key.ID, models.StatusActive); err != nil { s.logger.Errorf("Failed to update status to ACTIVE for Key ID %d in group %d: %v", key.ID, event.GroupID, err) } } else { @@ -579,7 +569,7 @@ func (s *APIKeyService) handlePostImportValidation(event *models.ImportGroupComp if !CustomErrors.As(validationErr, &apiErr) { apiErr = &CustomErrors.APIError{Message: validationErr.Error()} } - s.judgeKeyErrors("", event.GroupID, key.ID, apiErr, false) + s.judgeKeyErrors(ctx, "", event.GroupID, key.ID, apiErr, false) } } }() @@ -592,12 +582,9 @@ func (s *APIKeyService) handlePostImportValidation(event *models.ImportGroupComp s.logger.Infof("Finished post-import validation for group %d.", event.GroupID) } -// [NEW] StartUpdateStatusByFilterTask starts a background job to update the status of keys -// that match a specific set of source statuses within a group. -func (s *APIKeyService) StartUpdateStatusByFilterTask(groupID uint, sourceStatuses []string, newStatus models.APIKeyStatus) (*task.Status, error) { +func (s *APIKeyService) StartUpdateStatusByFilterTask(ctx context.Context, groupID uint, sourceStatuses []string, newStatus models.APIKeyStatus) (*task.Status, error) { s.logger.Infof("Starting task to update status to '%s' for keys in group %d matching statuses: %v", newStatus, groupID, sourceStatuses) - // 1. Find key IDs using the new repository method. Using IDs is more efficient for updates. keyIDs, err := s.keyRepo.FindKeyIDsByStatus(groupID, sourceStatuses) if err != nil { return nil, CustomErrors.ParseDBError(err) @@ -605,35 +592,32 @@ func (s *APIKeyService) StartUpdateStatusByFilterTask(groupID uint, sourceStatus if len(keyIDs) == 0 { now := time.Now() return &task.Status{ - IsRunning: false, // The "task" is not running. + IsRunning: false, Processed: 0, Total: 0, - Result: map[string]string{ // We use the flexible Result field to pass the message. + Result: map[string]string{ "message": "没有找到任何符合当前过滤条件的Key可供操作。", }, - Error: "", // There is no error. + Error: "", StartedAt: now, - FinishedAt: &now, // It started and finished at the same time. - }, nil // Return nil for the error, signaling a 200 OK. + FinishedAt: &now, + }, nil } - // 2. Start a new task using the TaskService, following existing patterns. resourceID := fmt.Sprintf("group-%d-status-update", groupID) - taskStatus, err := s.taskService.StartTask(groupID, TaskTypeUpdateStatusByFilter, resourceID, len(keyIDs), 30*time.Minute) + taskStatus, err := s.taskService.StartTask(ctx, groupID, TaskTypeUpdateStatusByFilter, resourceID, len(keyIDs), 30*time.Minute) if err != nil { - return nil, err // Pass up errors like "task already in progress". + return nil, err } - // 3. Run the core logic in a separate goroutine. - go s.runUpdateStatusByFilterTask(taskStatus.ID, resourceID, groupID, keyIDs, newStatus) + go s.runUpdateStatusByFilterTask(ctx, taskStatus.ID, resourceID, groupID, keyIDs, newStatus) return taskStatus, nil } -// [NEW] runUpdateStatusByFilterTask is the private worker function for the above task. 
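// The post-import validation above drains the imported keys through a bounded
// worker pool (a jobs channel consumed by N goroutines, with a safety fallback
// when the configured concurrency is invalid). A stdlib-only sketch of that
// pool shape; validate and the key values are illustrative placeholders.
package main

import (
	"log"
	"sync"
	"time"
)

func validateAll(keys []string, concurrency int, validate func(key string, timeout time.Duration) error) {
	if concurrency <= 0 {
		concurrency = 10 // same kind of safety fallback the service applies
	}
	jobs := make(chan string, len(keys))
	var wg sync.WaitGroup
	for w := 0; w < concurrency; w++ {
		wg.Add(1)
		go func() {
			defer wg.Done()
			for key := range jobs {
				if err := validate(key, 10*time.Second); err != nil {
					log.Printf("key %q failed validation: %v", key, err)
				}
			}
		}()
	}
	for _, k := range keys {
		jobs <- k
	}
	close(jobs)
	wg.Wait()
}

func main() {
	validateAll([]string{"k1", "k2", "k3"}, 2, func(string, time.Duration) error { return nil })
}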
-func (s *APIKeyService) runUpdateStatusByFilterTask(taskID, resourceID string, groupID uint, keyIDs []uint, newStatus models.APIKeyStatus) { +func (s *APIKeyService) runUpdateStatusByFilterTask(ctx context.Context, taskID, resourceID string, groupID uint, keyIDs []uint, newStatus models.APIKeyStatus) { defer func() { if r := recover(); r != nil { s.logger.Errorf("Panic recovered in runUpdateStatusByFilterTask: %v", r) - s.taskService.EndTaskByID(taskID, resourceID, nil, fmt.Errorf("panic in update status task: %v", r)) + s.taskService.EndTaskByID(ctx, taskID, resourceID, nil, fmt.Errorf("panic in update status task: %v", r)) } }() type BatchUpdateResult struct { @@ -642,31 +626,27 @@ func (s *APIKeyService) runUpdateStatusByFilterTask(taskID, resourceID string, g } result := &BatchUpdateResult{} var successfulMappings []*models.GroupAPIKeyMapping - // 1. Fetch all key master statuses in one go. This is efficient. + keys, err := s.keyRepo.GetKeysByIDs(keyIDs) if err != nil { s.logger.WithError(err).Error("Failed to fetch keys for bulk status update, aborting task.") - s.taskService.EndTaskByID(taskID, resourceID, nil, err) + s.taskService.EndTaskByID(ctx, taskID, resourceID, nil, err) return } masterStatusMap := make(map[uint]models.MasterAPIKeyStatus) for _, key := range keys { masterStatusMap[key.ID] = key.MasterStatus } - // 2. [THE REFINEMENT] Fetch all relevant mappings directly using s.db, - // avoiding the need for a new repository method. This pattern is - // already used in other parts of this service. var mappings []*models.GroupAPIKeyMapping - if err := s.db.Where("key_group_id = ? AND api_key_id IN ?", groupID, keyIDs).Find(&mappings).Error; err != nil { + if err := s.db.WithContext(ctx).Where("key_group_id = ? AND api_key_id IN ?", groupID, keyIDs).Find(&mappings).Error; err != nil { s.logger.WithError(err).Error("Failed to fetch mappings for bulk status update, aborting task.") - s.taskService.EndTaskByID(taskID, resourceID, nil, err) + s.taskService.EndTaskByID(ctx, taskID, resourceID, nil, err) return } processedCount := 0 for _, mapping := range mappings { processedCount++ - // The progress update should reflect the number of items *being processed*, not the final count. - _ = s.taskService.UpdateProgressByID(taskID, processedCount) + _ = s.taskService.UpdateProgressByID(ctx, taskID, processedCount) masterStatus, ok := masterStatusMap[mapping.APIKeyID] if !ok { result.SkippedCount++ @@ -688,24 +668,25 @@ func (s *APIKeyService) runUpdateStatusByFilterTask(taskID, resourceID string, g } else { result.UpdatedCount++ successfulMappings = append(successfulMappings, mapping) - go s.publishStatusChangeEvent(groupID, mapping.APIKeyID, oldStatus, newStatus, "batch_status_update") + go s.publishStatusChangeEvent(ctx, groupID, mapping.APIKeyID, oldStatus, newStatus, "batch_status_update") } } else { - result.UpdatedCount++ // Already in desired state, count as success. + result.UpdatedCount++ } } result.SkippedCount += (len(keyIDs) - len(mappings)) - if err := s.keyRepo.HandleCacheUpdateEventBatch(successfulMappings); err != nil { + if err := s.keyRepo.HandleCacheUpdateEventBatch(ctx, successfulMappings); err != nil { s.logger.WithError(err).Error("Failed to perform batch cache update after status update task.") } s.logger.Infof("Finished bulk status update task '%s' for group %d. 
Updated: %d, Skipped: %d.", taskID, groupID, result.UpdatedCount, result.SkippedCount) - s.taskService.EndTaskByID(taskID, resourceID, result, nil) + s.taskService.EndTaskByID(ctx, taskID, resourceID, result, nil) } func (s *APIKeyService) HandleRequestResult(group *models.KeyGroup, key *models.APIKey, success bool, apiErr *CustomErrors.APIError) { + ctx := context.Background() if success { if group.PollingStrategy == models.StrategyWeighted { - go s.keyRepo.UpdateKeyUsageTimestamp(group.ID, key.ID) + go s.keyRepo.UpdateKeyUsageTimestamp(ctx, group.ID, key.ID) } return } @@ -716,26 +697,20 @@ func (s *APIKeyService) HandleRequestResult(group *models.KeyGroup, key *models. errMsg := apiErr.Message if CustomErrors.IsPermanentUpstreamError(errMsg) || CustomErrors.IsTemporaryUpstreamError(errMsg) { s.logger.Warnf("Request for KeyID %d in GroupID %d failed with a key-specific error: %s. Issuing order to move to cooldown.", key.ID, group.ID, errMsg) - go s.keyRepo.SyncKeyStatusInPollingCaches(group.ID, key.ID, models.StatusCooldown) + go s.keyRepo.SyncKeyStatusInPollingCaches(ctx, group.ID, key.ID, models.StatusCooldown) } else { s.logger.Infof("Request for KeyID %d in GroupID %d failed with a non-key-specific error. No punitive action taken. Error: %s", key.ID, group.ID, errMsg) } } -// sanitizeForLog truncates long error messages and removes JSON-like content for cleaner Info/Error logs. func sanitizeForLog(errMsg string) string { - // Find the start of any potential JSON blob or detailed structure. jsonStartIndex := strings.Index(errMsg, "{") var cleanMsg string if jsonStartIndex != -1 { - // If a '{' is found, take everything before it as the summary - // and append a simple placeholder. cleanMsg = strings.TrimSpace(errMsg[:jsonStartIndex]) + " {...}" } else { - // If no JSON-like structure is found, use the original message. cleanMsg = errMsg } - // Always apply a final length truncation as a safeguard against extremely long non-JSON errors. const maxLen = 250 if len(cleanMsg) > maxLen { return cleanMsg[:maxLen] + "..." 
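// sanitizeForLog above keeps only the text before the first '{', appends a
// placeholder, and caps the length. A tiny standalone mirror of that logic to
// show the intended effect; the input string is a made-up example, not a real
// upstream error.
package main

import (
	"fmt"
	"strings"
)

func sanitize(msg string) string {
	const maxLen = 250
	if i := strings.Index(msg, "{"); i != -1 {
		msg = strings.TrimSpace(msg[:i]) + " {...}"
	}
	if len(msg) > maxLen {
		return msg[:maxLen] + "..."
	}
	return msg
}

func main() {
	fmt.Println(sanitize(`upstream returned 429: {"error":{"code":429,"status":"RESOURCE_EXHAUSTED"}}`))
	// Prints: upstream returned 429: {...}
}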
@@ -744,6 +719,7 @@ func sanitizeForLog(errMsg string) string { } func (s *APIKeyService) judgeKeyErrors( + ctx context.Context, correlationID string, groupID, keyID uint, apiErr *CustomErrors.APIError, @@ -765,11 +741,11 @@ func (s *APIKeyService) judgeKeyErrors( oldStatus := mapping.Status mapping.Status = models.StatusBanned mapping.LastError = errorMessage - if err := s.keyRepo.UpdateMapping(mapping); err != nil { + if err := s.keyRepo.UpdateMapping(ctx, mapping); err != nil { logger.WithError(err).Error("Failed to update mapping status to BANNED.") } else { - go s.publishStatusChangeEvent(groupID, keyID, oldStatus, mapping.Status, "permanent_error_banned") - go s.revokeMasterKey(keyID, "permanent_upstream_error") + go s.publishStatusChangeEvent(ctx, groupID, keyID, oldStatus, mapping.Status, "permanent_error_banned") + go s.revokeMasterKey(ctx, keyID, "permanent_upstream_error") } } return @@ -801,23 +777,23 @@ func (s *APIKeyService) judgeKeyErrors( if oldStatus != newStatus { mapping.Status = newStatus } - if err := s.keyRepo.UpdateMapping(mapping); err != nil { + if err := s.keyRepo.UpdateMapping(ctx, mapping); err != nil { logger.WithError(err).Error("Failed to update mapping after temporary error.") return } if oldStatus != newStatus { - go s.publishStatusChangeEvent(groupID, keyID, oldStatus, newStatus, "error_threshold_reached") + go s.publishStatusChangeEvent(ctx, groupID, keyID, oldStatus, newStatus, "error_threshold_reached") } return } logger.Infof("Ignored truly ignorable upstream error. Only updating LastUsedAt. Reason: %s", sanitizeForLog(errorMessage)) logger.WithField("full_error_details", errorMessage).Debug("Full details of the ignorable error.") - if err := s.keyRepo.UpdateMapping(mapping); err != nil { + if err := s.keyRepo.UpdateMapping(ctx, mapping); err != nil { logger.WithError(err).Error("Failed to update LastUsedAt for ignorable error.") } } -func (s *APIKeyService) revokeMasterKey(keyID uint, reason string) { +func (s *APIKeyService) revokeMasterKey(ctx context.Context, keyID uint, reason string) { key, err := s.keyRepo.GetKeyByID(keyID) if err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { @@ -832,7 +808,7 @@ func (s *APIKeyService) revokeMasterKey(keyID uint, reason string) { } oldMasterStatus := key.MasterStatus newMasterStatus := models.MasterStatusRevoked - if err := s.keyRepo.UpdateMasterStatusByID(keyID, newMasterStatus); err != nil { + if err := s.keyRepo.UpdateAPIKeyStatus(ctx, keyID, newMasterStatus); err != nil { s.logger.Errorf("Failed to update master status to REVOKED for key ID %d: %v", keyID, err) return } @@ -844,9 +820,9 @@ func (s *APIKeyService) revokeMasterKey(keyID uint, reason string) { ChangedAt: time.Now(), } eventData, _ := json.Marshal(masterKeyEvent) - _ = s.store.Publish(models.TopicMasterKeyStatusChanged, eventData) + _ = s.store.Publish(ctx, models.TopicMasterKeyStatusChanged, eventData) } -func (s *APIKeyService) GetAPIKeyStringsForExport(groupID uint, statuses []string) ([]string, error) { +func (s *APIKeyService) GetAPIKeyStringsForExport(ctx context.Context, groupID uint, statuses []string) ([]string, error) { return s.keyRepo.GetKeyStringsByGroupAndStatus(groupID, statuses) } diff --git a/internal/service/dashboard_query_service.go b/internal/service/dashboard_query_service.go index 65c044e..d1a6764 100644 --- a/internal/service/dashboard_query_service.go +++ b/internal/service/dashboard_query_service.go @@ -1,8 +1,8 @@ // Filename: internal/service/dashboard_query_service.go - package service import ( + "context" 
"fmt" "gemini-balancer/internal/models" "gemini-balancer/internal/store" @@ -17,8 +17,6 @@ import ( const overviewCacheChannel = "syncer:cache:dashboard_overview" -// DashboardQueryService 负责所有面向前端的仪表盘数据查询。 - type DashboardQueryService struct { db *gorm.DB store store.Store @@ -54,9 +52,9 @@ func (s *DashboardQueryService) Stop() { s.logger.Info("DashboardQueryService and its CacheSyncer have been stopped.") } -func (s *DashboardQueryService) GetGroupStats(groupID uint) (map[string]any, error) { +func (s *DashboardQueryService) GetGroupStats(ctx context.Context, groupID uint) (map[string]any, error) { statsKey := fmt.Sprintf("stats:group:%d", groupID) - keyStatsMap, err := s.store.HGetAll(statsKey) + keyStatsMap, err := s.store.HGetAll(ctx, statsKey) if err != nil { s.logger.WithError(err).Errorf("Failed to get key stats from cache for group %d", groupID) return nil, fmt.Errorf("failed to get key stats from cache: %w", err) @@ -74,11 +72,11 @@ func (s *DashboardQueryService) GetGroupStats(groupID uint) (map[string]any, err SuccessRequests int64 } var last1Hour, last24Hours requestStatsResult - s.db.Model(&models.StatsHourly{}). + s.db.WithContext(ctx).Model(&models.StatsHourly{}). Select("SUM(request_count) as total_requests, SUM(success_count) as success_requests"). Where("group_id = ? AND time >= ?", groupID, oneHourAgo). Scan(&last1Hour) - s.db.Model(&models.StatsHourly{}). + s.db.WithContext(ctx).Model(&models.StatsHourly{}). Select("SUM(request_count) as total_requests, SUM(success_count) as success_requests"). Where("group_id = ? AND time >= ?", groupID, twentyFourHoursAgo). Scan(&last24Hours) @@ -109,8 +107,9 @@ func (s *DashboardQueryService) GetGroupStats(groupID uint) (map[string]any, err } func (s *DashboardQueryService) eventListener() { - keyStatusSub, _ := s.store.Subscribe(models.TopicKeyStatusChanged) - upstreamStatusSub, _ := s.store.Subscribe(models.TopicUpstreamHealthChanged) + ctx := context.Background() + keyStatusSub, _ := s.store.Subscribe(ctx, models.TopicKeyStatusChanged) + upstreamStatusSub, _ := s.store.Subscribe(ctx, models.TopicUpstreamHealthChanged) defer keyStatusSub.Close() defer upstreamStatusSub.Close() for { @@ -128,7 +127,6 @@ func (s *DashboardQueryService) eventListener() { } } -// GetDashboardOverviewData 从 Syncer 缓存中高速获取仪表盘概览数据。 func (s *DashboardQueryService) GetDashboardOverviewData() (*models.DashboardStatsResponse, error) { cachedDataPtr := s.overviewSyncer.Get() if cachedDataPtr == nil { @@ -141,8 +139,7 @@ func (s *DashboardQueryService) InvalidateOverviewCache() error { return s.overviewSyncer.Invalidate() } -// QueryHistoricalChart 查询历史图表数据。 -func (s *DashboardQueryService) QueryHistoricalChart(groupID *uint) (*models.ChartData, error) { +func (s *DashboardQueryService) QueryHistoricalChart(ctx context.Context, groupID *uint) (*models.ChartData, error) { type ChartPoint struct { TimeLabel string `gorm:"column:time_label"` ModelName string `gorm:"column:model_name"` @@ -151,7 +148,7 @@ func (s *DashboardQueryService) QueryHistoricalChart(groupID *uint) (*models.Cha sevenDaysAgo := time.Now().Add(-24 * 7 * time.Hour).Truncate(time.Hour) sqlFormat, goFormat := s.buildTimeFormatSelectClause() selectClause := fmt.Sprintf("%s as time_label, model_name, SUM(request_count) as total_requests", sqlFormat) - query := s.db.Model(&models.StatsHourly{}).Select(selectClause).Where("time >= ?", sevenDaysAgo).Group("time_label, model_name").Order("time_label ASC") + query := s.db.WithContext(ctx).Model(&models.StatsHourly{}).Select(selectClause).Where("time 
>= ?", sevenDaysAgo).Group("time_label, model_name").Order("time_label ASC") if groupID != nil && *groupID > 0 { query = query.Where("group_id = ?", *groupID) } @@ -189,38 +186,38 @@ func (s *DashboardQueryService) QueryHistoricalChart(groupID *uint) (*models.Cha } func (s *DashboardQueryService) loadOverviewData() (*models.DashboardStatsResponse, error) { + ctx := context.Background() s.logger.Info("[CacheSyncer] Starting to load overview data from database...") startTime := time.Now() resp := &models.DashboardStatsResponse{ KeyStatusCount: make(map[models.APIKeyStatus]int64), MasterStatusCount: make(map[models.MasterAPIKeyStatus]int64), - KeyCount: models.StatCard{}, // 确保KeyCount是一个空的结构体,而不是nil - RequestCount24h: models.StatCard{}, // 同上 + KeyCount: models.StatCard{}, + RequestCount24h: models.StatCard{}, TokenCount: make(map[string]any), UpstreamHealthStatus: make(map[string]string), RPM: models.StatCard{}, RequestCounts: make(map[string]int64), } - // --- 1. Aggregate Operational Status from Mappings --- + type MappingStatusResult struct { Status models.APIKeyStatus Count int64 } var mappingStatusResults []MappingStatusResult - if err := s.db.Model(&models.GroupAPIKeyMapping{}).Select("status, count(*) as count").Group("status").Find(&mappingStatusResults).Error; err != nil { + if err := s.db.WithContext(ctx).Model(&models.GroupAPIKeyMapping{}).Select("status, count(*) as count").Group("status").Find(&mappingStatusResults).Error; err != nil { return nil, fmt.Errorf("failed to query mapping status stats: %w", err) } for _, res := range mappingStatusResults { resp.KeyStatusCount[res.Status] = res.Count } - // --- 2. Aggregate Master Status from APIKeys --- type MasterStatusResult struct { Status models.MasterAPIKeyStatus Count int64 } var masterStatusResults []MasterStatusResult - if err := s.db.Model(&models.APIKey{}).Select("master_status as status, count(*) as count").Group("master_status").Find(&masterStatusResults).Error; err != nil { + if err := s.db.WithContext(ctx).Model(&models.APIKey{}).Select("master_status as status, count(*) as count").Group("master_status").Find(&masterStatusResults).Error; err != nil { return nil, fmt.Errorf("failed to query master status stats: %w", err) } var totalKeys, invalidKeys int64 @@ -235,20 +232,15 @@ func (s *DashboardQueryService) loadOverviewData() (*models.DashboardStatsRespon now := time.Now() - // 1. RPM (1分钟), RPH (1小时), RPD (今日): 从“瞬时记忆”(request_logs)中精确查询 var count1m, count1h, count1d int64 - // RPM: 从此刻倒推1分钟 - s.db.Model(&models.RequestLog{}).Where("request_time >= ?", now.Add(-1*time.Minute)).Count(&count1m) - // RPH: 从此刻倒推1小时 - s.db.Model(&models.RequestLog{}).Where("request_time >= ?", now.Add(-1*time.Hour)).Count(&count1h) - - // RPD: 从今天零点 (UTC) 到此刻 + s.db.WithContext(ctx).Model(&models.RequestLog{}).Where("request_time >= ?", now.Add(-1*time.Minute)).Count(&count1m) + s.db.WithContext(ctx).Model(&models.RequestLog{}).Where("request_time >= ?", now.Add(-1*time.Hour)).Count(&count1h) year, month, day := now.UTC().Date() startOfDay := time.Date(year, month, day, 0, 0, 0, 0, time.UTC) - s.db.Model(&models.RequestLog{}).Where("request_time >= ?", startOfDay).Count(&count1d) - // 2. 
RP30D (30天): 从“长期记忆”(stats_hourly)中高效查询,以保证性能 + s.db.WithContext(ctx).Model(&models.RequestLog{}).Where("request_time >= ?", startOfDay).Count(&count1d) + var count30d int64 - s.db.Model(&models.StatsHourly{}).Where("time >= ?", now.AddDate(0, 0, -30)).Select("COALESCE(SUM(request_count), 0)").Scan(&count30d) + s.db.WithContext(ctx).Model(&models.StatsHourly{}).Where("time >= ?", now.AddDate(0, 0, -30)).Select("COALESCE(SUM(request_count), 0)").Scan(&count30d) resp.RequestCounts["1m"] = count1m resp.RequestCounts["1h"] = count1h @@ -256,7 +248,7 @@ func (s *DashboardQueryService) loadOverviewData() (*models.DashboardStatsRespon resp.RequestCounts["30d"] = count30d var upstreams []*models.UpstreamEndpoint - if err := s.db.Find(&upstreams).Error; err != nil { + if err := s.db.WithContext(ctx).Find(&upstreams).Error; err != nil { s.logger.WithError(err).Warn("Failed to load upstream statuses for dashboard.") } else { for _, u := range upstreams { @@ -269,7 +261,7 @@ func (s *DashboardQueryService) loadOverviewData() (*models.DashboardStatsRespon return resp, nil } -func (s *DashboardQueryService) GetRequestStatsForPeriod(period string) (gin.H, error) { +func (s *DashboardQueryService) GetRequestStatsForPeriod(ctx context.Context, period string) (gin.H, error) { var startTime time.Time now := time.Now() switch period { @@ -288,7 +280,7 @@ func (s *DashboardQueryService) GetRequestStatsForPeriod(period string) (gin.H, Success int64 } - err := s.db.Model(&models.RequestLog{}). + err := s.db.WithContext(ctx).Model(&models.RequestLog{}). Select("count(*) as total, sum(case when is_success = true then 1 else 0 end) as success"). Where("request_time >= ?", startTime). Scan(&result).Error diff --git a/internal/service/db_log_writer_service.go b/internal/service/db_log_writer_service.go index 892c4ee..b80904a 100644 --- a/internal/service/db_log_writer_service.go +++ b/internal/service/db_log_writer_service.go @@ -1,8 +1,8 @@ // Filename: internal/service/db_log_writer_service.go - package service import ( + "context" "encoding/json" "gemini-balancer/internal/models" "gemini-balancer/internal/settings" @@ -35,35 +35,30 @@ func NewDBLogWriterService(db *gorm.DB, s store.Store, settings *settings.Settin store: s, SettingsManager: settings, logger: logger.WithField("component", "DBLogWriter📝"), - // 使用配置值来创建缓冲区 - logBuffer: make(chan *models.RequestLog, bufferCapacity), - stopChan: make(chan struct{}), + logBuffer: make(chan *models.RequestLog, bufferCapacity), + stopChan: make(chan struct{}), } } func (s *DBLogWriterService) Start() { - s.wg.Add(2) // 一个用于事件监听,一个用于数据库写入 - - // 启动事件监听器 + s.wg.Add(2) go s.eventListenerLoop() - // 启动数据库写入器 go s.dbWriterLoop() - s.logger.Info("DBLogWriterService started.") } func (s *DBLogWriterService) Stop() { s.logger.Info("DBLogWriterService stopping...") - close(s.stopChan) // 通知所有goroutine停止 - s.wg.Wait() // 等待所有goroutine完成 + close(s.stopChan) + s.wg.Wait() s.logger.Info("DBLogWriterService stopped.") } -// eventListenerLoop 负责从store接收事件并放入内存缓冲区 func (s *DBLogWriterService) eventListenerLoop() { defer s.wg.Done() - sub, err := s.store.Subscribe(models.TopicRequestFinished) + ctx := context.Background() + sub, err := s.store.Subscribe(ctx, models.TopicRequestFinished) if err != nil { s.logger.Fatalf("Failed to subscribe to topic %s: %v", models.TopicRequestFinished, err) return @@ -80,34 +75,27 @@ func (s *DBLogWriterService) eventListenerLoop() { s.logger.Errorf("Failed to unmarshal event for logging: %v", err) continue } - - // 将事件中的日志部分放入缓冲区 select { case 
s.logBuffer <- &event.RequestLog: default: s.logger.Warn("Log buffer is full. A log message might be dropped.") } - case <-s.stopChan: s.logger.Info("Event listener loop stopping.") - // 关闭缓冲区,以通知dbWriterLoop处理完剩余日志后退出 close(s.logBuffer) return } } } -// dbWriterLoop 负责从内存缓冲区批量读取日志并写入数据库 func (s *DBLogWriterService) dbWriterLoop() { defer s.wg.Done() - // 在启动时获取一次配置 cfg := s.SettingsManager.GetSettings() batchSize := cfg.LogFlushBatchSize if batchSize <= 0 { batchSize = 100 } - flushTimeout := time.Duration(cfg.LogFlushIntervalSeconds) * time.Second if flushTimeout <= 0 { flushTimeout = 5 * time.Second @@ -126,7 +114,7 @@ func (s *DBLogWriterService) dbWriterLoop() { return } batch = append(batch, logEntry) - if len(batch) >= batchSize { // 使用配置的批次大小 + if len(batch) >= batchSize { s.flushBatch(batch) batch = make([]*models.RequestLog, 0, batchSize) } @@ -139,7 +127,6 @@ func (s *DBLogWriterService) dbWriterLoop() { } } -// flushBatch 将一个批次的日志写入数据库 func (s *DBLogWriterService) flushBatch(batch []*models.RequestLog) { if err := s.db.CreateInBatches(batch, len(batch)).Error; err != nil { s.logger.WithField("batch_size", len(batch)).WithError(err).Error("Failed to flush log batch to database.") diff --git a/internal/service/healthcheck_service.go b/internal/service/healthcheck_service.go index c09ff98..36030a2 100644 --- a/internal/service/healthcheck_service.go +++ b/internal/service/healthcheck_service.go @@ -75,7 +75,7 @@ func NewHealthCheckService( func (s *HealthCheckService) Start() { s.logger.Info("Starting HealthCheckService with independent check loops...") - s.wg.Add(4) // Now four loops + s.wg.Add(4) go s.runKeyCheckLoop() go s.runUpstreamCheckLoop() go s.runProxyCheckLoop() @@ -102,8 +102,6 @@ func (s *HealthCheckService) GetLastHealthCheckResults() map[string]string { func (s *HealthCheckService) runKeyCheckLoop() { defer s.wg.Done() s.logger.Info("Key check dynamic scheduler loop started.") - - // 主调度循环,每分钟检查一次任务 ticker := time.NewTicker(1 * time.Minute) defer ticker.Stop() @@ -126,26 +124,22 @@ func (s *HealthCheckService) scheduleKeyChecks() { defer s.groupCheckTimeMutex.Unlock() for _, group := range groups { - // 获取特定于组的运营配置 opConfig, err := s.groupManager.BuildOperationalConfig(group) if err != nil { s.logger.WithError(err).WithField("group_id", group.ID).Error("Failed to build operational config for group, skipping health check scheduling.") continue } - if opConfig.EnableKeyCheck == nil || !*opConfig.EnableKeyCheck { - continue // 跳过禁用了健康检查的组 + continue } - var intervalMinutes int if opConfig.KeyCheckIntervalMinutes != nil { intervalMinutes = *opConfig.KeyCheckIntervalMinutes } interval := time.Duration(intervalMinutes) * time.Minute if interval <= 0 { - continue // 跳过无效的检查周期 + continue } - if nextCheckTime, ok := s.groupNextCheckTime[group.ID]; !ok || now.After(nextCheckTime) { s.logger.Infof("Scheduling key check for group '%s' (ID: %d)", group.Name, group.ID) go s.performKeyChecksForGroup(group, opConfig) @@ -160,7 +154,6 @@ func (s *HealthCheckService) runUpstreamCheckLoop() { if s.SettingsManager.GetSettings().EnableUpstreamCheck { s.performUpstreamChecks() } - ticker := time.NewTicker(time.Duration(s.SettingsManager.GetSettings().UpstreamCheckIntervalSeconds) * time.Second) defer ticker.Stop() @@ -184,7 +177,6 @@ func (s *HealthCheckService) runProxyCheckLoop() { if s.SettingsManager.GetSettings().EnableProxyCheck { s.performProxyChecks() } - ticker := time.NewTicker(time.Duration(s.SettingsManager.GetSettings().ProxyCheckIntervalSeconds) * time.Second) defer 
ticker.Stop() @@ -203,6 +195,7 @@ func (s *HealthCheckService) runProxyCheckLoop() { } func (s *HealthCheckService) performKeyChecksForGroup(group *models.KeyGroup, opConfig *models.KeyGroupSettings) { + ctx := context.Background() settings := s.SettingsManager.GetSettings() timeout := time.Duration(settings.KeyCheckTimeoutSeconds) * time.Second @@ -213,11 +206,9 @@ func (s *HealthCheckService) performKeyChecksForGroup(group *models.KeyGroup, op } log := s.logger.WithFields(logrus.Fields{"group_id": group.ID, "group_name": group.Name}) - log.Infof("Starting key health check cycle.") - var mappingsToCheck []models.GroupAPIKeyMapping - err = s.db.Model(&models.GroupAPIKeyMapping{}). + err = s.db.WithContext(ctx).Model(&models.GroupAPIKeyMapping{}). Joins("JOIN api_keys ON api_keys.id = group_api_key_mappings.api_key_id"). Where("group_api_key_mappings.key_group_id = ?", group.ID). Where("api_keys.master_status = ?", models.MasterStatusActive). @@ -233,7 +224,6 @@ func (s *HealthCheckService) performKeyChecksForGroup(group *models.KeyGroup, op log.Info("No key mappings to check for this group.") return } - log.Infof("Starting health check for %d key mappings.", len(mappingsToCheck)) jobs := make(chan models.GroupAPIKeyMapping, len(mappingsToCheck)) var wg sync.WaitGroup @@ -242,14 +232,14 @@ func (s *HealthCheckService) performKeyChecksForGroup(group *models.KeyGroup, op concurrency = *opConfig.KeyCheckConcurrency } if concurrency <= 0 { - concurrency = 1 // 保证至少有一个 worker + concurrency = 1 } for w := 1; w <= concurrency; w++ { wg.Add(1) go func(workerID int) { defer wg.Done() for mapping := range jobs { - s.checkAndProcessMapping(&mapping, timeout, endpoint) + s.checkAndProcessMapping(ctx, &mapping, timeout, endpoint) } }(w) } @@ -261,52 +251,46 @@ func (s *HealthCheckService) performKeyChecksForGroup(group *models.KeyGroup, op log.Info("Finished key health check cycle.") } -func (s *HealthCheckService) checkAndProcessMapping(mapping *models.GroupAPIKeyMapping, timeout time.Duration, endpoint string) { +func (s *HealthCheckService) checkAndProcessMapping(ctx context.Context, mapping *models.GroupAPIKeyMapping, timeout time.Duration, endpoint string) { if mapping.APIKey == nil { s.logger.Warnf("Skipping check for mapping (G:%d, K:%d) because associated APIKey is nil.", mapping.KeyGroupID, mapping.APIKeyID) return } validationErr := s.keyValidationService.ValidateSingleKey(mapping.APIKey, timeout, endpoint) - // --- 诊断一:验证成功 (健康) --- if validationErr == nil { if mapping.Status != models.StatusActive { - s.activateMapping(mapping) + s.activateMapping(ctx, mapping) } return } errorString := validationErr.Error() - // --- 诊断二:永久性错误 --- if CustomErrors.IsPermanentUpstreamError(errorString) { - s.revokeMapping(mapping, validationErr) + s.revokeMapping(ctx, mapping, validationErr) return } - // --- 诊断三:暂时性错误 --- if CustomErrors.IsTemporaryUpstreamError(errorString) { - // Log with a higher level (WARN) since this is an actionable, proactive finding. s.logger.Warnf("Health check for key %d failed with temporary error, applying penalty. Reason: %v", mapping.APIKeyID, validationErr) - s.penalizeMapping(mapping, validationErr) + s.penalizeMapping(ctx, mapping, validationErr) return } - // --- 诊断四:其他未知或上游服务错误 --- - s.logger.Errorf("Health check for key %d failed with a transient or unknown upstream error: %v. 
Mapping will not be penalized by health check.", mapping.APIKeyID, validationErr) } -func (s *HealthCheckService) activateMapping(mapping *models.GroupAPIKeyMapping) { +func (s *HealthCheckService) activateMapping(ctx context.Context, mapping *models.GroupAPIKeyMapping) { oldStatus := mapping.Status mapping.Status = models.StatusActive mapping.ConsecutiveErrorCount = 0 mapping.LastError = "" - if err := s.keyRepo.UpdateMapping(mapping); err != nil { + + if err := s.keyRepo.UpdateMapping(ctx, mapping); err != nil { s.logger.WithError(err).Errorf("Failed to activate mapping (G:%d, K:%d).", mapping.KeyGroupID, mapping.APIKeyID) return } s.logger.Infof("Mapping (G:%d, K:%d) successfully activated from %s to ACTIVE.", mapping.KeyGroupID, mapping.APIKeyID, oldStatus) - s.publishKeyStatusChangedEvent(mapping.KeyGroupID, mapping.APIKeyID, oldStatus, mapping.Status) + s.publishKeyStatusChangedEvent(ctx, mapping.KeyGroupID, mapping.APIKeyID, oldStatus, mapping.Status) } -func (s *HealthCheckService) penalizeMapping(mapping *models.GroupAPIKeyMapping, err error) { - // Re-fetch group-specific operational config to get the correct thresholds +func (s *HealthCheckService) penalizeMapping(ctx context.Context, mapping *models.GroupAPIKeyMapping, err error) { group, ok := s.groupManager.GetGroupByID(mapping.KeyGroupID) if !ok { s.logger.Errorf("Could not find group with ID %d to apply penalty.", mapping.KeyGroupID) @@ -320,7 +304,6 @@ func (s *HealthCheckService) penalizeMapping(mapping *models.GroupAPIKeyMapping, oldStatus := mapping.Status mapping.LastError = err.Error() mapping.ConsecutiveErrorCount++ - // Use the group-specific threshold threshold := *opConfig.KeyBlacklistThreshold if mapping.ConsecutiveErrorCount >= threshold { mapping.Status = models.StatusCooldown @@ -329,44 +312,41 @@ func (s *HealthCheckService) penalizeMapping(mapping *models.GroupAPIKeyMapping, mapping.CooldownUntil = &cooldownTime s.logger.Warnf("Mapping (G:%d, K:%d) reached error threshold (%d) during health check and is now in COOLDOWN for %v.", mapping.KeyGroupID, mapping.APIKeyID, threshold, cooldownDuration) } - if errDb := s.keyRepo.UpdateMapping(mapping); errDb != nil { + if errDb := s.keyRepo.UpdateMapping(ctx, mapping); errDb != nil { s.logger.WithError(errDb).Errorf("Failed to penalize mapping (G:%d, K:%d).", mapping.KeyGroupID, mapping.APIKeyID) return } if oldStatus != mapping.Status { - s.publishKeyStatusChangedEvent(mapping.KeyGroupID, mapping.APIKeyID, oldStatus, mapping.Status) + s.publishKeyStatusChangedEvent(ctx, mapping.KeyGroupID, mapping.APIKeyID, oldStatus, mapping.Status) } } -func (s *HealthCheckService) revokeMapping(mapping *models.GroupAPIKeyMapping, err error) { +func (s *HealthCheckService) revokeMapping(ctx context.Context, mapping *models.GroupAPIKeyMapping, err error) { oldStatus := mapping.Status if oldStatus == models.StatusBanned { - return // Already banned, do nothing. 
+ return } - mapping.Status = models.StatusBanned mapping.LastError = "Definitive error: " + err.Error() - mapping.ConsecutiveErrorCount = 0 // Reset counter as this is a final state for this group - - if errDb := s.keyRepo.UpdateMapping(mapping); errDb != nil { + mapping.ConsecutiveErrorCount = 0 + if errDb := s.keyRepo.UpdateMapping(ctx, mapping); errDb != nil { s.logger.WithError(errDb).Errorf("Failed to revoke mapping (G:%d, K:%d).", mapping.KeyGroupID, mapping.APIKeyID) return } - s.logger.Warnf("Mapping (G:%d, K:%d) has been BANNED due to a definitive error: %v", mapping.KeyGroupID, mapping.APIKeyID, err) - s.publishKeyStatusChangedEvent(mapping.KeyGroupID, mapping.APIKeyID, oldStatus, mapping.Status) - + s.publishKeyStatusChangedEvent(ctx, mapping.KeyGroupID, mapping.APIKeyID, oldStatus, mapping.Status) s.logger.Infof("Triggering MasterStatus update for definitively failed key ID %d.", mapping.APIKeyID) - if err := s.keyRepo.UpdateAPIKeyStatus(mapping.APIKeyID, models.MasterStatusRevoked); err != nil { + if err := s.keyRepo.UpdateAPIKeyStatus(ctx, mapping.APIKeyID, models.MasterStatusRevoked); err != nil { s.logger.WithError(err).Errorf("Failed to update master status for key ID %d after group-level ban.", mapping.APIKeyID) } } func (s *HealthCheckService) performUpstreamChecks() { + ctx := context.Background() settings := s.SettingsManager.GetSettings() timeout := time.Duration(settings.UpstreamCheckTimeoutSeconds) * time.Second var upstreams []*models.UpstreamEndpoint - if err := s.db.Find(&upstreams).Error; err != nil { + if err := s.db.WithContext(ctx).Find(&upstreams).Error; err != nil { s.logger.WithError(err).Error("Failed to retrieve upstreams.") return } @@ -390,10 +370,10 @@ func (s *HealthCheckService) performUpstreamChecks() { s.lastResultsMutex.Unlock() if oldStatus != newStatus { s.logger.Infof("Upstream '%s' status changed from %s -> %s", upstream.URL, oldStatus, newStatus) - if err := s.db.Model(upstream).Update("status", newStatus).Error; err != nil { + if err := s.db.WithContext(ctx).Model(upstream).Update("status", newStatus).Error; err != nil { s.logger.WithError(err).WithField("upstream_id", upstream.ID).Error("Failed to update upstream status.") } else { - s.publishUpstreamHealthChangedEvent(upstream, oldStatus, newStatus) + s.publishUpstreamHealthChangedEvent(ctx, upstream, oldStatus, newStatus) } } }(u) @@ -412,10 +392,11 @@ func (s *HealthCheckService) checkEndpoint(urlStr string, timeout time.Duration) } func (s *HealthCheckService) performProxyChecks() { + ctx := context.Background() settings := s.SettingsManager.GetSettings() timeout := time.Duration(settings.ProxyCheckTimeoutSeconds) * time.Second var proxies []*models.ProxyConfig - if err := s.db.Find(&proxies).Error; err != nil { + if err := s.db.WithContext(ctx).Find(&proxies).Error; err != nil { s.logger.WithError(err).Error("Failed to retrieve proxies.") return } @@ -438,7 +419,7 @@ func (s *HealthCheckService) performProxyChecks() { s.lastResultsMutex.Unlock() if proxyCfg.Status != newStatus { s.logger.Infof("Proxy '%s' status changed from %s -> %s", proxyCfg.Address, proxyCfg.Status, newStatus) - if err := s.db.Model(proxyCfg).Update("status", newStatus).Error; err != nil { + if err := s.db.WithContext(ctx).Model(proxyCfg).Update("status", newStatus).Error; err != nil { s.logger.WithError(err).WithField("proxy_id", proxyCfg.ID).Error("Failed to update proxy status.") } } @@ -482,7 +463,7 @@ func (s *HealthCheckService) checkProxy(proxyCfg *models.ProxyConfig, timeout ti return true } -func (s 
*HealthCheckService) publishKeyStatusChangedEvent(groupID, keyID uint, oldStatus, newStatus models.APIKeyStatus) { +func (s *HealthCheckService) publishKeyStatusChangedEvent(ctx context.Context, groupID, keyID uint, oldStatus, newStatus models.APIKeyStatus) { event := models.KeyStatusChangedEvent{ KeyID: keyID, GroupID: groupID, @@ -496,12 +477,12 @@ func (s *HealthCheckService) publishKeyStatusChangedEvent(groupID, keyID uint, o s.logger.WithError(err).Errorf("Failed to marshal KeyStatusChangedEvent for group %d.", groupID) return } - if err := s.store.Publish(models.TopicKeyStatusChanged, payload); err != nil { + if err := s.store.Publish(ctx, models.TopicKeyStatusChanged, payload); err != nil { s.logger.WithError(err).Errorf("Failed to publish KeyStatusChangedEvent for group %d.", groupID) } } -func (s *HealthCheckService) publishUpstreamHealthChangedEvent(upstream *models.UpstreamEndpoint, oldStatus, newStatus string) { +func (s *HealthCheckService) publishUpstreamHealthChangedEvent(ctx context.Context, upstream *models.UpstreamEndpoint, oldStatus, newStatus string) { event := models.UpstreamHealthChangedEvent{ UpstreamID: upstream.ID, UpstreamURL: upstream.URL, @@ -516,28 +497,20 @@ func (s *HealthCheckService) publishUpstreamHealthChangedEvent(upstream *models. s.logger.WithError(err).Error("Failed to marshal UpstreamHealthChangedEvent.") return } - if err := s.store.Publish(models.TopicUpstreamHealthChanged, payload); err != nil { + if err := s.store.Publish(ctx, models.TopicUpstreamHealthChanged, payload); err != nil { s.logger.WithError(err).Error("Failed to publish UpstreamHealthChangedEvent.") } } -// ========================================================================= -// Global Base Key Check (New Logic) -// ========================================================================= - func (s *HealthCheckService) runBaseKeyCheckLoop() { defer s.wg.Done() s.logger.Info("Global base key check loop started.") settings := s.SettingsManager.GetSettings() - if !settings.EnableBaseKeyCheck { s.logger.Info("Global base key check is disabled.") return } - - // Perform an initial check on startup s.performBaseKeyChecks() - interval := time.Duration(settings.BaseKeyCheckIntervalMinutes) * time.Minute if interval <= 0 { s.logger.Warnf("Invalid BaseKeyCheckIntervalMinutes: %d. Disabling base key check loop.", settings.BaseKeyCheckIntervalMinutes) @@ -558,6 +531,7 @@ func (s *HealthCheckService) runBaseKeyCheckLoop() { } func (s *HealthCheckService) performBaseKeyChecks() { + ctx := context.Background() s.logger.Info("Starting global base key check cycle.") settings := s.SettingsManager.GetSettings() timeout := time.Duration(settings.KeyCheckTimeoutSeconds) * time.Second @@ -576,7 +550,7 @@ func (s *HealthCheckService) performBaseKeyChecks() { jobs := make(chan *models.APIKey, len(keys)) var wg sync.WaitGroup if concurrency <= 0 { - concurrency = 5 // Safe default + concurrency = 5 } for w := 0; w < concurrency; w++ { wg.Add(1) @@ -587,10 +561,10 @@ func (s *HealthCheckService) performBaseKeyChecks() { if err != nil && CustomErrors.IsPermanentUpstreamError(err.Error()) { oldStatus := key.MasterStatus s.logger.Warnf("Key ID %d (%s...) failed definitive base check. Setting MasterStatus to REVOKED. 
Reason: %v", key.ID, key.APIKey[:4], err) - if updateErr := s.keyRepo.UpdateAPIKeyStatus(key.ID, models.MasterStatusRevoked); updateErr != nil { + if updateErr := s.keyRepo.UpdateAPIKeyStatus(ctx, key.ID, models.MasterStatusRevoked); updateErr != nil { s.logger.WithError(updateErr).Errorf("Failed to update master status for key ID %d.", key.ID) } else { - s.publishMasterKeyStatusChangedEvent(key.ID, oldStatus, models.MasterStatusRevoked) + s.publishMasterKeyStatusChangedEvent(ctx, key.ID, oldStatus, models.MasterStatusRevoked) } } } @@ -604,8 +578,7 @@ func (s *HealthCheckService) performBaseKeyChecks() { s.logger.Info("Global base key check cycle finished.") } -// 事件发布辅助函数 -func (s *HealthCheckService) publishMasterKeyStatusChangedEvent(keyID uint, oldStatus, newStatus models.MasterAPIKeyStatus) { +func (s *HealthCheckService) publishMasterKeyStatusChangedEvent(ctx context.Context, keyID uint, oldStatus, newStatus models.MasterAPIKeyStatus) { event := models.MasterKeyStatusChangedEvent{ KeyID: keyID, OldMasterStatus: oldStatus, @@ -618,7 +591,7 @@ func (s *HealthCheckService) publishMasterKeyStatusChangedEvent(keyID uint, oldS s.logger.WithError(err).Errorf("Failed to marshal MasterKeyStatusChangedEvent for key %d.", keyID) return } - if err := s.store.Publish(models.TopicMasterKeyStatusChanged, payload); err != nil { + if err := s.store.Publish(ctx, models.TopicMasterKeyStatusChanged, payload); err != nil { s.logger.WithError(err).Errorf("Failed to publish MasterKeyStatusChangedEvent for key %d.", keyID) } } diff --git a/internal/service/key_import_service.go b/internal/service/key_import_service.go index 3b7f567..fca6c7c 100644 --- a/internal/service/key_import_service.go +++ b/internal/service/key_import_service.go @@ -2,6 +2,7 @@ package service import ( + "context" "encoding/json" "fmt" "gemini-balancer/internal/models" @@ -42,88 +43,84 @@ func NewKeyImportService(ts task.Reporter, kr repository.KeyRepository, s store. 
} } -// --- 通用的 Panic-Safe 任務執行器 --- -func (s *KeyImportService) runTaskWithRecovery(taskID string, resourceID string, taskFunc func()) { +func (s *KeyImportService) runTaskWithRecovery(ctx context.Context, taskID string, resourceID string, taskFunc func()) { defer func() { if r := recover(); r != nil { err := fmt.Errorf("panic recovered in task %s: %v", taskID, r) s.logger.Error(err) - s.taskService.EndTaskByID(taskID, resourceID, nil, err) + s.taskService.EndTaskByID(ctx, taskID, resourceID, nil, err) } }() taskFunc() } -// --- Public Task Starters --- - -func (s *KeyImportService) StartAddKeysTask(groupID uint, keysText string, validateOnImport bool) (*task.Status, error) { +func (s *KeyImportService) StartAddKeysTask(ctx context.Context, groupID uint, keysText string, validateOnImport bool) (*task.Status, error) { keys := utils.ParseKeysFromText(keysText) if len(keys) == 0 { return nil, fmt.Errorf("no valid keys found in input text") } resourceID := fmt.Sprintf("group-%d", groupID) - taskStatus, err := s.taskService.StartTask(uint(groupID), TaskTypeAddKeysToGroup, resourceID, len(keys), 15*time.Minute) + + taskStatus, err := s.taskService.StartTask(ctx, uint(groupID), TaskTypeAddKeysToGroup, resourceID, len(keys), 15*time.Minute) if err != nil { return nil, err } - go s.runTaskWithRecovery(taskStatus.ID, resourceID, func() { - s.runAddKeysTask(taskStatus.ID, resourceID, groupID, keys, validateOnImport) + go s.runTaskWithRecovery(context.Background(), taskStatus.ID, resourceID, func() { + s.runAddKeysTask(context.Background(), taskStatus.ID, resourceID, groupID, keys, validateOnImport) }) return taskStatus, nil } -func (s *KeyImportService) StartUnlinkKeysTask(groupID uint, keysText string) (*task.Status, error) { +func (s *KeyImportService) StartUnlinkKeysTask(ctx context.Context, groupID uint, keysText string) (*task.Status, error) { keys := utils.ParseKeysFromText(keysText) if len(keys) == 0 { return nil, fmt.Errorf("no valid keys found") } resourceID := fmt.Sprintf("group-%d", groupID) - taskStatus, err := s.taskService.StartTask(uint(groupID), TaskTypeUnlinkKeysFromGroup, resourceID, len(keys), time.Hour) + taskStatus, err := s.taskService.StartTask(ctx, uint(groupID), TaskTypeUnlinkKeysFromGroup, resourceID, len(keys), time.Hour) if err != nil { return nil, err } - go s.runTaskWithRecovery(taskStatus.ID, resourceID, func() { - s.runUnlinkKeysTask(taskStatus.ID, resourceID, groupID, keys) + go s.runTaskWithRecovery(context.Background(), taskStatus.ID, resourceID, func() { + s.runUnlinkKeysTask(context.Background(), taskStatus.ID, resourceID, groupID, keys) }) return taskStatus, nil } -func (s *KeyImportService) StartHardDeleteKeysTask(keysText string) (*task.Status, error) { +func (s *KeyImportService) StartHardDeleteKeysTask(ctx context.Context, keysText string) (*task.Status, error) { keys := utils.ParseKeysFromText(keysText) if len(keys) == 0 { return nil, fmt.Errorf("no valid keys found") } - resourceID := "global_hard_delete" // Global lock - taskStatus, err := s.taskService.StartTask(0, TaskTypeHardDeleteKeys, resourceID, len(keys), time.Hour) + resourceID := "global_hard_delete" + taskStatus, err := s.taskService.StartTask(ctx, 0, TaskTypeHardDeleteKeys, resourceID, len(keys), time.Hour) if err != nil { return nil, err } - go s.runTaskWithRecovery(taskStatus.ID, resourceID, func() { - s.runHardDeleteKeysTask(taskStatus.ID, resourceID, keys) + go s.runTaskWithRecovery(context.Background(), taskStatus.ID, resourceID, func() { + s.runHardDeleteKeysTask(context.Background(), 
taskStatus.ID, resourceID, keys) }) return taskStatus, nil } -func (s *KeyImportService) StartRestoreKeysTask(keysText string) (*task.Status, error) { +func (s *KeyImportService) StartRestoreKeysTask(ctx context.Context, keysText string) (*task.Status, error) { keys := utils.ParseKeysFromText(keysText) if len(keys) == 0 { return nil, fmt.Errorf("no valid keys found") } - resourceID := "global_restore_keys" // Global lock - taskStatus, err := s.taskService.StartTask(0, TaskTypeRestoreKeys, resourceID, len(keys), time.Hour) + resourceID := "global_restore_keys" + + taskStatus, err := s.taskService.StartTask(ctx, 0, TaskTypeRestoreKeys, resourceID, len(keys), time.Hour) if err != nil { return nil, err } - go s.runTaskWithRecovery(taskStatus.ID, resourceID, func() { - s.runRestoreKeysTask(taskStatus.ID, resourceID, keys) + go s.runTaskWithRecovery(context.Background(), taskStatus.ID, resourceID, func() { + s.runRestoreKeysTask(context.Background(), taskStatus.ID, resourceID, keys) }) return taskStatus, nil } -// --- Private Task Runners --- - -func (s *KeyImportService) runAddKeysTask(taskID string, resourceID string, groupID uint, keys []string, validateOnImport bool) { - // 步骤 1: 对输入的原始 key 列表进行去重。 +func (s *KeyImportService) runAddKeysTask(ctx context.Context, taskID string, resourceID string, groupID uint, keys []string, validateOnImport bool) { uniqueKeysMap := make(map[string]struct{}) var uniqueKeyStrings []string for _, kStr := range keys { @@ -133,41 +130,37 @@ func (s *KeyImportService) runAddKeysTask(taskID string, resourceID string, grou } } if len(uniqueKeyStrings) == 0 { - s.taskService.EndTaskByID(taskID, resourceID, gin.H{"newly_linked_count": 0, "already_linked_count": 0}, nil) + s.taskService.EndTaskByID(ctx, taskID, resourceID, gin.H{"newly_linked_count": 0, "already_linked_count": 0}, nil) return } - // 步骤 2: 确保所有 Key 在主表中存在(创建或恢复),并获取它们完整的实体。 keysToEnsure := make([]models.APIKey, len(uniqueKeyStrings)) for i, keyStr := range uniqueKeyStrings { keysToEnsure[i] = models.APIKey{APIKey: keyStr} } allKeyModels, err := s.keyRepo.AddKeys(keysToEnsure) if err != nil { - s.taskService.EndTaskByID(taskID, resourceID, nil, fmt.Errorf("failed to ensure keys exist: %w", err)) + s.taskService.EndTaskByID(ctx, taskID, resourceID, nil, fmt.Errorf("failed to ensure keys exist: %w", err)) return } - // 步骤 3: 找出在这些 Key 中,哪些【已经】被链接到了当前分组。 alreadyLinkedModels, err := s.keyRepo.GetKeysByValuesAndGroupID(uniqueKeyStrings, groupID) if err != nil { - s.taskService.EndTaskByID(taskID, resourceID, nil, fmt.Errorf("failed to check for already linked keys: %w", err)) + s.taskService.EndTaskByID(ctx, taskID, resourceID, nil, fmt.Errorf("failed to check for already linked keys: %w", err)) return } alreadyLinkedIDSet := make(map[uint]struct{}) for _, key := range alreadyLinkedModels { alreadyLinkedIDSet[key.ID] = struct{}{} } - // 步骤 4: 确定【真正需要】被链接到当前分组的 key 列表 (我们的"工作集")。 var keysToLink []models.APIKey for _, key := range allKeyModels { if _, exists := alreadyLinkedIDSet[key.ID]; !exists { keysToLink = append(keysToLink, key) } } - // 步骤 5: 更新任务的 Total 总量为精确的 "工作集" 大小。 - if err := s.taskService.UpdateTotalByID(taskID, len(keysToLink)); err != nil { + + if err := s.taskService.UpdateTotalByID(ctx, taskID, len(keysToLink)); err != nil { s.logger.WithError(err).Warnf("Failed to update total for task %s", taskID) } - // 步骤 6: 分块处理【链接Key到组】的操作,并实时更新进度。 if len(keysToLink) > 0 { idsToLink := make([]uint, len(keysToLink)) for i, key := range keysToLink { @@ -179,44 +172,41 @@ func (s *KeyImportService) 
runAddKeysTask(taskID string, resourceID string, grou end = len(idsToLink) } chunk := idsToLink[i:end] - if err := s.keyRepo.LinkKeysToGroup(groupID, chunk); err != nil { - s.taskService.EndTaskByID(taskID, resourceID, nil, fmt.Errorf("chunk failed to link keys: %w", err)) + if err := s.keyRepo.LinkKeysToGroup(ctx, groupID, chunk); err != nil { + s.taskService.EndTaskByID(ctx, taskID, resourceID, nil, fmt.Errorf("chunk failed to link keys: %w", err)) return } - _ = s.taskService.UpdateProgressByID(taskID, i+len(chunk)) + _ = s.taskService.UpdateProgressByID(ctx, taskID, i+len(chunk)) } } - // 步骤 7: 准备最终结果并结束任务。 result := gin.H{ "newly_linked_count": len(keysToLink), "already_linked_count": len(alreadyLinkedIDSet), "total_linked_count": len(allKeyModels), } - // 步骤 8: 根据 `validateOnImport` 标志, 发布事件或直接激活 (只对新链接的keys操作)。 if len(keysToLink) > 0 { idsToLink := make([]uint, len(keysToLink)) for i, key := range keysToLink { idsToLink[i] = key.ID } if validateOnImport { - s.publishImportGroupCompletedEvent(groupID, idsToLink) + s.publishImportGroupCompletedEvent(ctx, groupID, idsToLink) for _, keyID := range idsToLink { - s.publishSingleKeyChangeEvent(groupID, keyID, "", models.StatusPendingValidation, "key_linked") + s.publishSingleKeyChangeEvent(ctx, groupID, keyID, "", models.StatusPendingValidation, "key_linked") } } else { for _, keyID := range idsToLink { - if _, err := s.apiKeyService.UpdateMappingStatus(groupID, keyID, models.StatusActive); err != nil { + if _, err := s.apiKeyService.UpdateMappingStatus(ctx, groupID, keyID, models.StatusActive); err != nil { s.logger.Errorf("Failed to directly activate key ID %d in group %d: %v", keyID, groupID, err) } } } } - s.taskService.EndTaskByID(taskID, resourceID, result, nil) + s.taskService.EndTaskByID(ctx, taskID, resourceID, result, nil) } -// runUnlinkKeysTask -func (s *KeyImportService) runUnlinkKeysTask(taskID, resourceID string, groupID uint, keys []string) { +func (s *KeyImportService) runUnlinkKeysTask(ctx context.Context, taskID, resourceID string, groupID uint, keys []string) { uniqueKeysMap := make(map[string]struct{}) var uniqueKeys []string for _, kStr := range keys { @@ -225,46 +215,42 @@ func (s *KeyImportService) runUnlinkKeysTask(taskID, resourceID string, groupID uniqueKeys = append(uniqueKeys, kStr) } } - // 步骤 1: 一次性找出所有输入 Key 中,实际存在于本组的 Key 实体。这是我们的"工作集"。 keysToUnlink, err := s.keyRepo.GetKeysByValuesAndGroupID(uniqueKeys, groupID) if err != nil { - s.taskService.EndTaskByID(taskID, resourceID, nil, fmt.Errorf("failed to find keys to unlink: %w", err)) + s.taskService.EndTaskByID(ctx, taskID, resourceID, nil, fmt.Errorf("failed to find keys to unlink: %w", err)) return } if len(keysToUnlink) == 0 { result := gin.H{"unlinked_count": 0, "hard_deleted_count": 0, "not_found_count": len(uniqueKeys)} - s.taskService.EndTaskByID(taskID, resourceID, result, nil) + s.taskService.EndTaskByID(ctx, taskID, resourceID, result, nil) return } idsToUnlink := make([]uint, len(keysToUnlink)) for i, key := range keysToUnlink { idsToUnlink[i] = key.ID } - // 步骤 2: 更新任务的 Total 总量为精确的 "工作集" 大小。 - if err := s.taskService.UpdateTotalByID(taskID, len(idsToUnlink)); err != nil { + + if err := s.taskService.UpdateTotalByID(ctx, taskID, len(idsToUnlink)); err != nil { s.logger.WithError(err).Warnf("Failed to update total for task %s", taskID) } var totalUnlinked int64 - // 步骤 3: 分块处理【解绑Key】的操作,并上报进度。 for i := 0; i < len(idsToUnlink); i += chunkSize { end := i + chunkSize if end > len(idsToUnlink) { end = len(idsToUnlink) } chunk := idsToUnlink[i:end] - 
unlinked, err := s.keyRepo.UnlinkKeysFromGroup(groupID, chunk) + unlinked, err := s.keyRepo.UnlinkKeysFromGroup(ctx, groupID, chunk) if err != nil { - s.taskService.EndTaskByID(taskID, resourceID, nil, fmt.Errorf("chunk failed: could not unlink keys: %w", err)) + s.taskService.EndTaskByID(ctx, taskID, resourceID, nil, fmt.Errorf("chunk failed: could not unlink keys: %w", err)) return } totalUnlinked += unlinked - for _, keyID := range chunk { - s.publishSingleKeyChangeEvent(groupID, keyID, models.StatusActive, "", "key_unlinked") + s.publishSingleKeyChangeEvent(ctx, groupID, keyID, models.StatusActive, "", "key_unlinked") } - - _ = s.taskService.UpdateProgressByID(taskID, i+len(chunk)) + _ = s.taskService.UpdateProgressByID(ctx, taskID, i+len(chunk)) } totalDeleted, err := s.keyRepo.DeleteOrphanKeys() @@ -276,10 +262,10 @@ func (s *KeyImportService) runUnlinkKeysTask(taskID, resourceID string, groupID "hard_deleted_count": totalDeleted, "not_found_count": len(uniqueKeys) - int(totalUnlinked), } - s.taskService.EndTaskByID(taskID, resourceID, result, nil) + s.taskService.EndTaskByID(ctx, taskID, resourceID, result, nil) } -func (s *KeyImportService) runHardDeleteKeysTask(taskID, resourceID string, keys []string) { +func (s *KeyImportService) runHardDeleteKeysTask(ctx context.Context, taskID, resourceID string, keys []string) { var totalDeleted int64 for i := 0; i < len(keys); i += chunkSize { end := i + chunkSize @@ -290,22 +276,21 @@ func (s *KeyImportService) runHardDeleteKeysTask(taskID, resourceID string, keys deleted, err := s.keyRepo.HardDeleteByValues(chunk) if err != nil { - s.taskService.EndTaskByID(taskID, resourceID, nil, fmt.Errorf("failed to hard delete chunk: %w", err)) + s.taskService.EndTaskByID(ctx, taskID, resourceID, nil, fmt.Errorf("failed to hard delete chunk: %w", err)) return } totalDeleted += deleted - _ = s.taskService.UpdateProgressByID(taskID, i+len(chunk)) + _ = s.taskService.UpdateProgressByID(ctx, taskID, i+len(chunk)) } - result := gin.H{ "hard_deleted_count": totalDeleted, "not_found_count": int64(len(keys)) - totalDeleted, } - s.taskService.EndTaskByID(taskID, resourceID, result, nil) - s.publishChangeEvent(0, "keys_hard_deleted") // Global event + s.taskService.EndTaskByID(ctx, taskID, resourceID, result, nil) + s.publishChangeEvent(ctx, 0, "keys_hard_deleted") } -func (s *KeyImportService) runRestoreKeysTask(taskID, resourceID string, keys []string) { +func (s *KeyImportService) runRestoreKeysTask(ctx context.Context, taskID, resourceID string, keys []string) { var restoredCount int64 for i := 0; i < len(keys); i += chunkSize { end := i + chunkSize @@ -316,21 +301,21 @@ func (s *KeyImportService) runRestoreKeysTask(taskID, resourceID string, keys [] count, err := s.keyRepo.UpdateMasterStatusByValues(chunk, models.MasterStatusActive) if err != nil { - s.taskService.EndTaskByID(taskID, resourceID, nil, fmt.Errorf("failed to restore chunk: %w", err)) + s.taskService.EndTaskByID(ctx, taskID, resourceID, nil, fmt.Errorf("failed to restore chunk: %w", err)) return } restoredCount += count - _ = s.taskService.UpdateProgressByID(taskID, i+len(chunk)) + _ = s.taskService.UpdateProgressByID(ctx, taskID, i+len(chunk)) } result := gin.H{ "restored_count": restoredCount, "not_found_count": int64(len(keys)) - restoredCount, } - s.taskService.EndTaskByID(taskID, resourceID, result, nil) - s.publishChangeEvent(0, "keys_bulk_restored") // Global event + s.taskService.EndTaskByID(ctx, taskID, resourceID, result, nil) + s.publishChangeEvent(ctx, 0, "keys_bulk_restored") } 
-func (s *KeyImportService) publishSingleKeyChangeEvent(groupID, keyID uint, oldStatus, newStatus models.APIKeyStatus, reason string) { +func (s *KeyImportService) publishSingleKeyChangeEvent(ctx context.Context, groupID, keyID uint, oldStatus, newStatus models.APIKeyStatus, reason string) { event := models.KeyStatusChangedEvent{ GroupID: groupID, KeyID: keyID, @@ -340,7 +325,7 @@ func (s *KeyImportService) publishSingleKeyChangeEvent(groupID, keyID uint, oldS ChangedAt: time.Now(), } eventData, _ := json.Marshal(event) - if err := s.store.Publish(models.TopicKeyStatusChanged, eventData); err != nil { + if err := s.store.Publish(ctx, models.TopicKeyStatusChanged, eventData); err != nil { s.logger.WithError(err).WithFields(logrus.Fields{ "group_id": groupID, "key_id": keyID, @@ -349,16 +334,16 @@ func (s *KeyImportService) publishSingleKeyChangeEvent(groupID, keyID uint, oldS } } -func (s *KeyImportService) publishChangeEvent(groupID uint, reason string) { +func (s *KeyImportService) publishChangeEvent(ctx context.Context, groupID uint, reason string) { event := models.KeyStatusChangedEvent{ GroupID: groupID, ChangeReason: reason, } eventData, _ := json.Marshal(event) - _ = s.store.Publish(models.TopicKeyStatusChanged, eventData) + _ = s.store.Publish(ctx, models.TopicKeyStatusChanged, eventData) } -func (s *KeyImportService) publishImportGroupCompletedEvent(groupID uint, keyIDs []uint) { +func (s *KeyImportService) publishImportGroupCompletedEvent(ctx context.Context, groupID uint, keyIDs []uint) { if len(keyIDs) == 0 { return } @@ -372,17 +357,15 @@ func (s *KeyImportService) publishImportGroupCompletedEvent(groupID uint, keyIDs s.logger.WithError(err).Error("Failed to marshal ImportGroupCompletedEvent.") return } - if err := s.store.Publish(models.TopicImportGroupCompleted, eventData); err != nil { + if err := s.store.Publish(ctx, models.TopicImportGroupCompleted, eventData); err != nil { s.logger.WithError(err).Error("Failed to publish ImportGroupCompletedEvent.") } else { s.logger.Infof("Published ImportGroupCompletedEvent for group %d with %d keys.", groupID, len(keyIDs)) } } -// [NEW] StartUnlinkKeysByFilterTask starts a task to unlink keys matching a status filter. -func (s *KeyImportService) StartUnlinkKeysByFilterTask(groupID uint, statuses []string) (*task.Status, error) { +func (s *KeyImportService) StartUnlinkKeysByFilterTask(ctx context.Context, groupID uint, statuses []string) (*task.Status, error) { s.logger.Infof("Starting unlink task for group %d with status filter: %v", groupID, statuses) - // 1. [New] Find the keys to operate on. keyValues, err := s.keyRepo.FindKeyValuesByStatus(groupID, statuses) if err != nil { return nil, fmt.Errorf("failed to find keys by filter: %w", err) @@ -390,8 +373,7 @@ func (s *KeyImportService) StartUnlinkKeysByFilterTask(groupID uint, statuses [] if len(keyValues) == 0 { return nil, fmt.Errorf("no keys found matching the provided filter") } - // 2. [REUSE] Convert to text and call the existing, robust unlink task logic. 
keysAsText := strings.Join(keyValues, "\n") s.logger.Infof("Found %d keys to unlink for group %d.", len(keyValues), groupID) - return s.StartUnlinkKeysTask(groupID, keysAsText) -} \ No newline at end of file + return s.StartUnlinkKeysTask(ctx, groupID, keysAsText) +} diff --git a/internal/service/key_validation_service.go b/internal/service/key_validation_service.go index 36901dc..74e6678 100644 --- a/internal/service/key_validation_service.go +++ b/internal/service/key_validation_service.go @@ -2,6 +2,7 @@ package service import ( + "context" "encoding/json" "fmt" "gemini-balancer/internal/channel" @@ -62,20 +63,18 @@ func (s *KeyValidationService) ValidateSingleKey(key *models.APIKey, timeout tim return fmt.Errorf("failed to create request: %w", err) } - s.channel.ModifyRequest(req, key) // Use the injected channel to modify the request + s.channel.ModifyRequest(req, key) resp, err := client.Do(req) if err != nil { - // This is a network-level error (e.g., timeout, DNS issue) return fmt.Errorf("request failed: %w", err) } defer resp.Body.Close() if resp.StatusCode == http.StatusOK { - return nil // Success + return nil } - // Read the body for more error details bodyBytes, readErr := io.ReadAll(resp.Body) var errorMsg string if readErr != nil { @@ -84,7 +83,6 @@ func (s *KeyValidationService) ValidateSingleKey(key *models.APIKey, timeout tim errorMsg = string(bodyBytes) } - // This is a validation failure with a specific HTTP status code return &CustomErrors.APIError{ HTTPStatus: resp.StatusCode, Message: fmt.Sprintf("Validation failed with status %d: %s", resp.StatusCode, errorMsg), @@ -92,8 +90,7 @@ func (s *KeyValidationService) ValidateSingleKey(key *models.APIKey, timeout tim } } -// --- 异步任务方法 (全面适配新task包) --- -func (s *KeyValidationService) StartTestKeysTask(groupID uint, keysText string) (*task.Status, error) { +func (s *KeyValidationService) StartTestKeysTask(ctx context.Context, groupID uint, keysText string) (*task.Status, error) { keyStrings := utils.ParseKeysFromText(keysText) if len(keyStrings) == 0 { return nil, CustomErrors.NewAPIError(CustomErrors.ErrBadRequest, "no valid keys found in input text") @@ -111,7 +108,6 @@ func (s *KeyValidationService) StartTestKeysTask(groupID uint, keysText string) } group, ok := s.groupManager.GetGroupByID(groupID) if !ok { - // [FIX] Correctly use the NewAPIError constructor for a missing group. 
return nil, CustomErrors.NewAPIError(CustomErrors.ErrGroupNotFound, fmt.Sprintf("group with id %d not found", groupID)) } opConfig, err := s.groupManager.BuildOperationalConfig(group) @@ -119,15 +115,15 @@ func (s *KeyValidationService) StartTestKeysTask(groupID uint, keysText string) return nil, CustomErrors.NewAPIError(CustomErrors.ErrInternalServer, fmt.Sprintf("failed to build operational config: %v", err)) } resourceID := fmt.Sprintf("group-%d", groupID) - taskStatus, err := s.taskService.StartTask(groupID, TaskTypeTestKeys, resourceID, len(apiKeyModels), time.Hour) + taskStatus, err := s.taskService.StartTask(ctx, groupID, TaskTypeTestKeys, resourceID, len(apiKeyModels), time.Hour) if err != nil { - return nil, err // Pass up the error from task service (e.g., "task already running") + return nil, err } settings := s.SettingsManager.GetSettings() timeout := time.Duration(settings.KeyCheckTimeoutSeconds) * time.Second endpoint, err := s.groupManager.BuildKeyCheckEndpoint(groupID) if err != nil { - s.taskService.EndTaskByID(taskStatus.ID, resourceID, nil, err) // End task with error if endpoint fails + s.taskService.EndTaskByID(ctx, taskStatus.ID, resourceID, nil, err) return nil, err } var concurrency int @@ -136,11 +132,11 @@ func (s *KeyValidationService) StartTestKeysTask(groupID uint, keysText string) } else { concurrency = settings.BaseKeyCheckConcurrency } - go s.runTestKeysTask(taskStatus.ID, resourceID, groupID, apiKeyModels, timeout, endpoint, concurrency) + go s.runTestKeysTask(context.Background(), taskStatus.ID, resourceID, groupID, apiKeyModels, timeout, endpoint, concurrency) return taskStatus, nil } -func (s *KeyValidationService) runTestKeysTask(taskID string, resourceID string, groupID uint, keys []models.APIKey, timeout time.Duration, endpoint string, concurrency int) { +func (s *KeyValidationService) runTestKeysTask(ctx context.Context, taskID string, resourceID string, groupID uint, keys []models.APIKey, timeout time.Duration, endpoint string, concurrency int) { var wg sync.WaitGroup var mu sync.Mutex finalResults := make([]models.KeyTestResult, len(keys)) @@ -165,7 +161,6 @@ func (s *KeyValidationService) runTestKeysTask(taskID string, resourceID string, var currentResult models.KeyTestResult event := models.RequestFinishedEvent{ RequestLog: models.RequestLog{ - // GroupID 和 KeyID 在 RequestLog 模型中是指针,需要取地址 GroupID: &groupID, KeyID: &apiKeyModel.ID, }, @@ -185,14 +180,15 @@ func (s *KeyValidationService) runTestKeysTask(taskID string, resourceID string, event.RequestLog.IsSuccess = false } eventData, _ := json.Marshal(event) - if err := s.store.Publish(models.TopicRequestFinished, eventData); err != nil { + + if err := s.store.Publish(ctx, models.TopicRequestFinished, eventData); err != nil { s.logger.WithError(err).Errorf("Failed to publish RequestFinishedEvent for validation of key ID %d", apiKeyModel.ID) } mu.Lock() finalResults[j.Index] = currentResult processedCount++ - _ = s.taskService.UpdateProgressByID(taskID, processedCount) + _ = s.taskService.UpdateProgressByID(ctx, taskID, processedCount) mu.Unlock() } }() @@ -202,10 +198,10 @@ func (s *KeyValidationService) runTestKeysTask(taskID string, resourceID string, } close(jobs) wg.Wait() - s.taskService.EndTaskByID(taskID, resourceID, gin.H{"results": finalResults}, nil) + s.taskService.EndTaskByID(ctx, taskID, resourceID, gin.H{"results": finalResults}, nil) } -func (s *KeyValidationService) StartTestKeysByFilterTask(groupID uint, statuses []string) (*task.Status, error) { +func (s *KeyValidationService) 
StartTestKeysByFilterTask(ctx context.Context, groupID uint, statuses []string) (*task.Status, error) { s.logger.Infof("Starting test task for group %d with status filter: %v", groupID, statuses) keyValues, err := s.keyRepo.FindKeyValuesByStatus(groupID, statuses) if err != nil { @@ -216,5 +212,5 @@ func (s *KeyValidationService) StartTestKeysByFilterTask(groupID uint, statuses } keysAsText := strings.Join(keyValues, "\n") s.logger.Infof("Found %d keys to validate for group %d.", len(keyValues), groupID) - return s.StartTestKeysTask(groupID, keysAsText) + return s.StartTestKeysTask(ctx, groupID, keysAsText) } diff --git a/internal/service/resource_service.go b/internal/service/resource_service.go index 793e7c8..87ecf36 100644 --- a/internal/service/resource_service.go +++ b/internal/service/resource_service.go @@ -3,6 +3,7 @@ package service import ( + "context" "errors" apperrors "gemini-balancer/internal/errors" "gemini-balancer/internal/models" @@ -43,7 +44,6 @@ func NewResourceService( aks *APIKeyService, logger *logrus.Logger, ) *ResourceService { - logger.Debugf("[FORENSIC PROBE | INJECTION | ResourceService] Received 'keyRepo' param. Fingerprint: %p", kr) rs := &ResourceService{ settingsManager: sm, groupManager: gm, @@ -56,43 +56,40 @@ func NewResourceService( go rs.preWarmCache(logger) }) return rs - } -// --- [模式一:智能聚合模式] --- -func (s *ResourceService) GetResourceFromBasePool(authToken *models.AuthToken, modelName string) (*RequestResources, error) { +func (s *ResourceService) GetResourceFromBasePool(ctx context.Context, authToken *models.AuthToken, modelName string) (*RequestResources, error) { log := s.logger.WithFields(logrus.Fields{"token_id": authToken.ID, "model_name": modelName, "mode": "BasePool"}) log.Debug("Entering BasePool resource acquisition.") - // 1.筛选出所有符合条件的候选组,并按优先级排序 + candidateGroups := s.filterAndSortCandidateGroups(modelName, authToken.AllowedGroups) if len(candidateGroups) == 0 { log.Warn("No candidate groups found for BasePool construction.") return nil, apperrors.ErrNoKeysAvailable } - // 2.从 BasePool中,根据系统全局策略选择一个Key + basePool := &repository.BasePool{ CandidateGroups: candidateGroups, PollingStrategy: s.settingsManager.GetSettings().PollingStrategy, } - apiKey, selectedGroup, err := s.keyRepo.SelectOneActiveKeyFromBasePool(basePool) + + apiKey, selectedGroup, err := s.keyRepo.SelectOneActiveKeyFromBasePool(ctx, basePool) if err != nil { log.WithError(err).Warn("Failed to select a key from the BasePool.") return nil, apperrors.ErrNoKeysAvailable } - // 3. 
组装最终资源 - // [关键] 在此模式下,RequestConfig 永远是空的,以保证透明性。 + resources, err := s.assembleRequestResources(selectedGroup, apiKey) if err != nil { log.WithError(err).Error("Failed to assemble resources after selecting key from BasePool.") return nil, err } - resources.RequestConfig = &models.RequestConfig{} // 强制为空 + resources.RequestConfig = &models.RequestConfig{} log.Infof("Successfully selected KeyID %d from GroupID %d for the BasePool.", apiKey.ID, selectedGroup.ID) return resources, nil } -// --- [模式二:精确路由模式] --- -func (s *ResourceService) GetResourceFromGroup(authToken *models.AuthToken, groupName string) (*RequestResources, error) { +func (s *ResourceService) GetResourceFromGroup(ctx context.Context, authToken *models.AuthToken, groupName string) (*RequestResources, error) { log := s.logger.WithFields(logrus.Fields{"token_id": authToken.ID, "group_name": groupName, "mode": "PreciseRoute"}) log.Debug("Entering PreciseRoute resource acquisition.") @@ -101,12 +98,11 @@ func (s *ResourceService) GetResourceFromGroup(authToken *models.AuthToken, grou if !ok { return nil, apperrors.NewAPIError(apperrors.ErrGroupNotFound, "The specified group does not exist.") } - if !s.isTokenAllowedForGroup(authToken, targetGroup.ID) { return nil, apperrors.NewAPIError(apperrors.ErrPermissionDenied, "Token does not have permission to access this group.") } - apiKey, _, err := s.keyRepo.SelectOneActiveKey(targetGroup) + apiKey, _, err := s.keyRepo.SelectOneActiveKey(ctx, targetGroup) if err != nil { log.WithError(err).Warn("Failed to select a key from the precisely targeted group.") return nil, apperrors.ErrNoKeysAvailable @@ -132,7 +128,6 @@ func (s *ResourceService) GetAllowedModelsForToken(authToken *models.AuthToken) if authToken.IsAdmin { for _, group := range allGroups { for _, modelMapping := range group.AllowedModels { - allowedModelsSet[modelMapping.ModelName] = struct{}{} } } @@ -144,7 +139,6 @@ func (s *ResourceService) GetAllowedModelsForToken(authToken *models.AuthToken) for _, group := range allGroups { if _, ok := allowedGroupIDs[group.ID]; ok { for _, modelMapping := range group.AllowedModels { - allowedModelsSet[modelMapping.ModelName] = struct{}{} } } @@ -164,14 +158,6 @@ func (s *ResourceService) assembleRequestResources(group *models.KeyGroup, apiKe return nil, apperrors.NewAPIError(apperrors.ErrConfigurationError, "Selected group has no valid upstream and no global default is set.") } var proxyConfig *models.ProxyConfig - // [注意] 代理逻辑需要一个 proxyModule 实例,我们暂时置空。后续需要重新注入依赖。 - // if group.EnableProxy && s.proxyModule != nil { - // var err error - // proxyConfig, err = s.proxyModule.AssignProxyIfNeeded(apiKey) - // if err != nil { - // s.logger.WithError(err).Warnf("Failed to assign proxy for API key %d.", apiKey.ID) - // } - // } return &RequestResources{ KeyGroup: group, APIKey: apiKey, @@ -194,7 +180,7 @@ func (s *ResourceService) selectUpstreamForGroup(group *models.KeyGroup) *models func (s *ResourceService) preWarmCache(logger *logrus.Logger) error { time.Sleep(2 * time.Second) s.logger.Info("Performing initial key cache pre-warming...") - if err := s.keyRepo.LoadAllKeysToStore(); err != nil { + if err := s.keyRepo.LoadAllKeysToStore(context.Background()); err != nil { logger.WithError(err).Error("Failed to perform initial key cache pre-warming.") return err } @@ -209,7 +195,6 @@ func (s *ResourceService) GetResourcesForRequest(modelName string, allowedGroups func (s *ResourceService) filterAndSortCandidateGroups(modelName string, allowedGroupsFromToken []*models.KeyGroup) []*models.KeyGroup 
{ allGroupsFromCache := s.groupManager.GetAllGroups() var candidateGroups []*models.KeyGroup - // 1. 确定权限范围 allowedGroupIDs := make(map[uint]bool) isTokenRestricted := len(allowedGroupsFromToken) > 0 if isTokenRestricted { @@ -217,15 +202,12 @@ func (s *ResourceService) filterAndSortCandidateGroups(modelName string, allowed allowedGroupIDs[ag.ID] = true } } - // 2. 筛选 for _, group := range allGroupsFromCache { - // 检查Token权限 if isTokenRestricted && !allowedGroupIDs[group.ID] { continue } - // 检查模型是否被允许 isModelAllowed := false - if len(group.AllowedModels) == 0 { // 如果组不限制模型,则允许 + if len(group.AllowedModels) == 0 { isModelAllowed = true } else { for _, m := range group.AllowedModels { @@ -239,8 +221,6 @@ func (s *ResourceService) filterAndSortCandidateGroups(modelName string, allowed candidateGroups = append(candidateGroups, group) } } - - // 3.按 Order 字段升序排序 sort.SliceStable(candidateGroups, func(i, j int) bool { return candidateGroups[i].Order < candidateGroups[j].Order }) diff --git a/internal/service/security_service.go b/internal/service/security_service.go index cfa979f..eccb033 100644 --- a/internal/service/security_service.go +++ b/internal/service/security_service.go @@ -52,7 +52,7 @@ func (s *SecurityService) AuthenticateToken(tokenValue string) (*models.AuthToke // IsIPBanned func (s *SecurityService) IsIPBanned(ctx context.Context, ip string) (bool, error) { banKey := fmt.Sprintf("banned_ip:%s", ip) - return s.store.Exists(banKey) + return s.store.Exists(ctx, banKey) } // RecordFailedLoginAttempt @@ -61,7 +61,7 @@ func (s *SecurityService) RecordFailedLoginAttempt(ctx context.Context, ip strin return nil } - count, err := s.store.HIncrBy(loginAttemptsKey, ip, 1) + count, err := s.store.HIncrBy(ctx, loginAttemptsKey, ip, 1) if err != nil { return err } @@ -71,12 +71,12 @@ func (s *SecurityService) RecordFailedLoginAttempt(ctx context.Context, ip strin banDuration := s.SettingsManager.GetIPBanDuration() banKey := fmt.Sprintf("banned_ip:%s", ip) - if err := s.store.Set(banKey, []byte("1"), banDuration); err != nil { + if err := s.store.Set(ctx, banKey, []byte("1"), banDuration); err != nil { return err } s.logger.Warnf("IP BANNED: IP [%s] has been banned for %v due to excessive failed login attempts.", ip, banDuration) - s.store.HDel(loginAttemptsKey, ip) + s.store.HDel(ctx, loginAttemptsKey, ip) } return nil diff --git a/internal/service/stats_service.go b/internal/service/stats_service.go index 48496e7..50f7b9c 100644 --- a/internal/service/stats_service.go +++ b/internal/service/stats_service.go @@ -2,6 +2,7 @@ package service import ( + "context" "encoding/json" "fmt" "gemini-balancer/internal/models" @@ -34,7 +35,7 @@ func NewStatsService(db *gorm.DB, s store.Store, repo repository.KeyRepository, func (s *StatsService) Start() { s.logger.Info("Starting event listener for stats maintenance.") - sub, err := s.store.Subscribe(models.TopicKeyStatusChanged) + sub, err := s.store.Subscribe(context.Background(), models.TopicKeyStatusChanged) if err != nil { s.logger.Fatalf("Failed to subscribe to topic %s: %v", models.TopicKeyStatusChanged, err) return @@ -67,42 +68,43 @@ func (s *StatsService) handleKeyStatusChange(event *models.KeyStatusChangedEvent s.logger.Warnf("Received KeyStatusChangedEvent with no GroupID. Reason: %s, KeyID: %d. 
Skipping.", event.ChangeReason, event.KeyID) return } + ctx := context.Background() statsKey := fmt.Sprintf("stats:group:%d", event.GroupID) s.logger.Infof("Handling key status change for Group %d, KeyID: %d, Reason: %s", event.GroupID, event.KeyID, event.ChangeReason) switch event.ChangeReason { case "key_unlinked", "key_hard_deleted": if event.OldStatus != "" { - s.store.HIncrBy(statsKey, "total_keys", -1) - s.store.HIncrBy(statsKey, fmt.Sprintf("%s_keys", event.OldStatus), -1) + s.store.HIncrBy(ctx, statsKey, "total_keys", -1) + s.store.HIncrBy(ctx, statsKey, fmt.Sprintf("%s_keys", event.OldStatus), -1) } else { s.logger.Warnf("Received '%s' event for group %d without OldStatus, forcing recalculation.", event.ChangeReason, event.GroupID) - s.RecalculateGroupKeyStats(event.GroupID) + s.RecalculateGroupKeyStats(ctx, event.GroupID) } case "key_linked": if event.NewStatus != "" { - s.store.HIncrBy(statsKey, "total_keys", 1) - s.store.HIncrBy(statsKey, fmt.Sprintf("%s_keys", event.NewStatus), 1) + s.store.HIncrBy(ctx, statsKey, "total_keys", 1) + s.store.HIncrBy(ctx, statsKey, fmt.Sprintf("%s_keys", event.NewStatus), 1) } else { s.logger.Warnf("Received 'key_linked' event for group %d without NewStatus, forcing recalculation.", event.GroupID) - s.RecalculateGroupKeyStats(event.GroupID) + s.RecalculateGroupKeyStats(ctx, event.GroupID) } case "manual_update", "error_threshold_reached", "key_recovered", "invalid_api_key": - s.store.HIncrBy(statsKey, fmt.Sprintf("%s_keys", event.OldStatus), -1) - s.store.HIncrBy(statsKey, fmt.Sprintf("%s_keys", event.NewStatus), 1) + s.store.HIncrBy(ctx, statsKey, fmt.Sprintf("%s_keys", event.OldStatus), -1) + s.store.HIncrBy(ctx, statsKey, fmt.Sprintf("%s_keys", event.NewStatus), 1) default: s.logger.Warnf("Unhandled event reason '%s' for group %d, forcing recalculation.", event.ChangeReason, event.GroupID) - s.RecalculateGroupKeyStats(event.GroupID) + s.RecalculateGroupKeyStats(ctx, event.GroupID) } } -func (s *StatsService) RecalculateGroupKeyStats(groupID uint) error { +func (s *StatsService) RecalculateGroupKeyStats(ctx context.Context, groupID uint) error { s.logger.Warnf("Performing full recalculation for group %d key stats.", groupID) var results []struct { Status models.APIKeyStatus Count int64 } - if err := s.db.Model(&models.GroupAPIKeyMapping{}). + if err := s.db.WithContext(ctx).Model(&models.GroupAPIKeyMapping{}). Where("key_group_id = ?", groupID). Select("status, COUNT(*) as count"). Group("status"). @@ -119,37 +121,25 @@ func (s *StatsService) RecalculateGroupKeyStats(groupID uint) error { } updates["total_keys"] = totalKeys - if err := s.store.Del(statsKey); err != nil { + if err := s.store.Del(ctx, statsKey); err != nil { s.logger.WithError(err).Warnf("Failed to delete stale stats key for group %d before recalculation.", groupID) } - if err := s.store.HSet(statsKey, updates); err != nil { + if err := s.store.HSet(ctx, statsKey, updates); err != nil { return fmt.Errorf("failed to HSet recalculated stats for group %d: %w", groupID, err) } s.logger.Infof("Successfully recalculated stats for group %d using HSet.", groupID) return nil } -func (s *StatsService) GetDashboardStats() (*models.DashboardStatsResponse, error) { - // TODO 逻辑: - // 1. 从Redis中获取所有分组的Key统计 (HGetAll) - // 2. 从 stats_hourly 表中获取过去24小时的请求数和错误率 - // 3. 组合成 DashboardStatsResponse - // ... 这个方法的具体实现,我们可以在DashboardQueryService中完成, - // 这里我们先确保StatsService的核心职责(维护缓存)已经完成。 - // 为了编译通过,我们先返回一个空对象。 - - // 伪代码: - // keyCounts, _ := s.store.HGetAll("stats:global:keys") - // ... 
- +func (s *StatsService) GetDashboardStats(ctx context.Context) (*models.DashboardStatsResponse, error) { return &models.DashboardStatsResponse{}, nil } -func (s *StatsService) AggregateHourlyStats() error { +func (s *StatsService) AggregateHourlyStats(ctx context.Context) error { s.logger.Info("Starting aggregation of the last hour's request data...") now := time.Now() - endTime := now.Truncate(time.Hour) // 例如:15:23 -> 15:00 - startTime := endTime.Add(-1 * time.Hour) // 15:00 -> 14:00 + endTime := now.Truncate(time.Hour) + startTime := endTime.Add(-1 * time.Hour) s.logger.Infof("Aggregating data for time window: [%s, %s)", startTime.Format(time.RFC3339), endTime.Format(time.RFC3339)) type aggregationResult struct { @@ -161,7 +151,8 @@ func (s *StatsService) AggregateHourlyStats() error { CompletionTokens int64 } var results []aggregationResult - err := s.db.Model(&models.RequestLog{}). + + err := s.db.WithContext(ctx).Model(&models.RequestLog{}). Select("group_id, model_name, COUNT(*) as request_count, SUM(CASE WHEN is_success = true THEN 1 ELSE 0 END) as success_count, SUM(prompt_tokens) as prompt_tokens, SUM(completion_tokens) as completion_tokens"). Where("request_time >= ? AND request_time < ?", startTime, endTime). Group("group_id, model_name"). @@ -179,7 +170,7 @@ func (s *StatsService) AggregateHourlyStats() error { var hourlyStats []models.StatsHourly for _, res := range results { hourlyStats = append(hourlyStats, models.StatsHourly{ - Time: startTime, // 所有记录的时间戳都是该小时的起点 + Time: startTime, GroupID: res.GroupID, ModelName: res.ModelName, RequestCount: res.RequestCount, @@ -189,7 +180,7 @@ func (s *StatsService) AggregateHourlyStats() error { }) } - return s.db.Clauses(clause.OnConflict{ + return s.db.WithContext(ctx).Clauses(clause.OnConflict{ Columns: []clause.Column{{Name: "time"}, {Name: "group_id"}, {Name: "model_name"}}, DoUpdates: clause.AssignmentColumns([]string{"request_count", "success_count", "prompt_tokens", "completion_tokens"}), }).Create(&hourlyStats).Error diff --git a/internal/store/factory.go b/internal/store/factory.go index 6b5e454..7fe3950 100644 --- a/internal/store/factory.go +++ b/internal/store/factory.go @@ -1,3 +1,4 @@ +// Filename: internal/store/factory.go package store import ( @@ -11,7 +12,6 @@ import ( // NewStore creates a new store based on the application configuration. func NewStore(cfg *config.Config, logger *logrus.Logger) (Store, error) { - // 检查是否有Redis配置 if cfg.Redis.DSN != "" { opts, err := redis.ParseURL(cfg.Redis.DSN) if err != nil { @@ -20,10 +20,10 @@ func NewStore(cfg *config.Config, logger *logrus.Logger) (Store, error) { client := redis.NewClient(opts) if err := client.Ping(context.Background()).Err(); err != nil { logger.WithError(err).Warnf("WARN: Failed to connect to Redis (%s), falling back to in-memory store. Error: %v", cfg.Redis.DSN, err) - return NewMemoryStore(logger), nil // 连接失败,也回退到内存模式,但不返回错误 + return NewMemoryStore(logger), nil } logger.Info("Successfully connected to Redis. 
Using Redis as store.") - return NewRedisStore(client), nil + return NewRedisStore(client, logger), nil } logger.Info("INFO: Redis DSN not configured, falling back to in-memory store.") return NewMemoryStore(logger), nil diff --git a/internal/store/memory_store.go b/internal/store/memory_store.go index d7ee2cd..aaeed21 100644 --- a/internal/store/memory_store.go +++ b/internal/store/memory_store.go @@ -1,8 +1,9 @@ -// Filename: internal/store/memory_store.go (经同行审查后最终修复版) +// Filename: internal/store/memory_store.go package store import ( + "context" "fmt" "math/rand" "sort" @@ -12,6 +13,7 @@ import ( "github.com/sirupsen/logrus" ) +// ensure memoryStore implements Store interface var _ Store = (*memoryStore)(nil) type memoryStoreItem struct { @@ -32,7 +34,6 @@ type memoryStore struct { items map[string]*memoryStoreItem pubsub map[string][]chan *Message mu sync.RWMutex - // [USER SUGGESTION APPLIED] 使用带锁的随机数源以保证并发安全 rng *rand.Rand rngMu sync.Mutex logger *logrus.Entry @@ -42,7 +43,6 @@ func NewMemoryStore(logger *logrus.Logger) Store { store := &memoryStore{ items: make(map[string]*memoryStoreItem), pubsub: make(map[string][]chan *Message), - // 使用当前时间作为种子,创建一个新的随机数源 rng: rand.New(rand.NewSource(time.Now().UnixNano())), logger: logger.WithField("component", "store.memory 🗱"), } @@ -50,13 +50,12 @@ func NewMemoryStore(logger *logrus.Logger) Store { return store } -// [USER SUGGESTION INCORPORATED] Fix #1: 使用 now := time.Now() 进行原子性检查 func (s *memoryStore) startGCollector() { ticker := time.NewTicker(5 * time.Minute) defer ticker.Stop() for range ticker.C { s.mu.Lock() - now := time.Now() // 避免在循环中重复调用 + now := time.Now() for key, item := range s.items { if !item.expireAt.IsZero() && now.After(item.expireAt) { delete(s.items, key) @@ -66,92 +65,10 @@ func (s *memoryStore) startGCollector() { } } -// [USER SUGGESTION INCORPORATED] Fix #2 & #3: 修复了致命的nil检查和类型断言问题 -func (s *memoryStore) PopAndCycleSetMember(mainKey, cooldownKey string) (string, error) { - s.mu.Lock() - defer s.mu.Unlock() +// --- [架构修正] 所有方法签名都增加了 context.Context 参数以匹配接口 --- +// --- 内存实现可以忽略该参数,用 _ 接收 --- - mainItem, mainOk := s.items[mainKey] - var mainSet map[string]struct{} - - if mainOk && !mainItem.isExpired() { - // 安全地进行类型断言 - mainSet, mainOk = mainItem.value.(map[string]struct{}) - // 确保断言成功且集合不为空 - mainOk = mainOk && len(mainSet) > 0 - } else { - mainOk = false - } - - if !mainOk { - cooldownItem, cooldownOk := s.items[cooldownKey] - if !cooldownOk || cooldownItem.isExpired() { - return "", ErrNotFound - } - // 安全地进行类型断言 - cooldownSet, cooldownSetOk := cooldownItem.value.(map[string]struct{}) - if !cooldownSetOk || len(cooldownSet) == 0 { - return "", ErrNotFound - } - - s.items[mainKey] = cooldownItem - delete(s.items, cooldownKey) - mainSet = cooldownSet - } - - var popped string - for k := range mainSet { - popped = k - break - } - delete(mainSet, popped) - - cooldownItem, cooldownOk := s.items[cooldownKey] - if !cooldownOk || cooldownItem.isExpired() { - cooldownItem = &memoryStoreItem{value: make(map[string]struct{})} - s.items[cooldownKey] = cooldownItem - } - // 安全地处理冷却池 - cooldownSet, ok := cooldownItem.value.(map[string]struct{}) - if !ok { - cooldownSet = make(map[string]struct{}) - cooldownItem.value = cooldownSet - } - cooldownSet[popped] = struct{}{} - - return popped, nil -} - -// SRandMember [并发修复版] 使用带锁的rng -func (s *memoryStore) SRandMember(key string) (string, error) { - s.mu.RLock() - defer s.mu.RUnlock() - item, ok := s.items[key] - if !ok || item.isExpired() { - return "", ErrNotFound - } - set, ok 
:= item.value.(map[string]struct{}) - if !ok || len(set) == 0 { - return "", ErrNotFound - } - members := make([]string, 0, len(set)) - for member := range set { - members = append(members, member) - } - if len(members) == 0 { - return "", ErrNotFound - } - - s.rngMu.Lock() - n := s.rng.Intn(len(members)) - s.rngMu.Unlock() - - return members[n], nil -} - -// --- 以下是其余函数的最终版本,它们都遵循了安全、原子的锁策略 --- - -func (s *memoryStore) Set(key string, value []byte, ttl time.Duration) error { +func (s *memoryStore) Set(_ context.Context, key string, value []byte, ttl time.Duration) error { s.mu.Lock() defer s.mu.Unlock() var expireAt time.Time @@ -162,7 +79,7 @@ func (s *memoryStore) Set(key string, value []byte, ttl time.Duration) error { return nil } -func (s *memoryStore) Get(key string) ([]byte, error) { +func (s *memoryStore) Get(_ context.Context, key string) ([]byte, error) { s.mu.RLock() defer s.mu.RUnlock() item, ok := s.items[key] @@ -175,7 +92,7 @@ func (s *memoryStore) Get(key string) ([]byte, error) { return nil, ErrNotFound } -func (s *memoryStore) Del(keys ...string) error { +func (s *memoryStore) Del(_ context.Context, keys ...string) error { s.mu.Lock() defer s.mu.Unlock() for _, key := range keys { @@ -184,14 +101,14 @@ func (s *memoryStore) Del(keys ...string) error { return nil } -func (s *memoryStore) Exists(key string) (bool, error) { +func (s *memoryStore) Exists(_ context.Context, key string) (bool, error) { s.mu.RLock() defer s.mu.RUnlock() item, ok := s.items[key] return ok && !item.isExpired(), nil } -func (s *memoryStore) SetNX(key string, value []byte, ttl time.Duration) (bool, error) { +func (s *memoryStore) SetNX(_ context.Context, key string, value []byte, ttl time.Duration) (bool, error) { s.mu.Lock() defer s.mu.Unlock() item, ok := s.items[key] @@ -208,7 +125,7 @@ func (s *memoryStore) SetNX(key string, value []byte, ttl time.Duration) (bool, func (s *memoryStore) Close() error { return nil } -func (s *memoryStore) HDel(key string, fields ...string) error { +func (s *memoryStore) HDel(_ context.Context, key string, fields ...string) error { s.mu.Lock() defer s.mu.Unlock() item, ok := s.items[key] @@ -223,7 +140,7 @@ func (s *memoryStore) HDel(key string, fields ...string) error { return nil } -func (s *memoryStore) HSet(key string, values map[string]any) error { +func (s *memoryStore) HSet(_ context.Context, key string, values map[string]any) error { s.mu.Lock() defer s.mu.Unlock() item, ok := s.items[key] @@ -242,7 +159,7 @@ func (s *memoryStore) HSet(key string, values map[string]any) error { return nil } -func (s *memoryStore) HGetAll(key string) (map[string]string, error) { +func (s *memoryStore) HGetAll(_ context.Context, key string) (map[string]string, error) { s.mu.RLock() defer s.mu.RUnlock() item, ok := s.items[key] @@ -259,7 +176,7 @@ func (s *memoryStore) HGetAll(key string) (map[string]string, error) { return make(map[string]string), nil } -func (s *memoryStore) HIncrBy(key, field string, incr int64) (int64, error) { +func (s *memoryStore) HIncrBy(_ context.Context, key, field string, incr int64) (int64, error) { s.mu.Lock() defer s.mu.Unlock() item, ok := s.items[key] @@ -281,7 +198,7 @@ func (s *memoryStore) HIncrBy(key, field string, incr int64) (int64, error) { return newVal, nil } -func (s *memoryStore) LPush(key string, values ...any) error { +func (s *memoryStore) LPush(_ context.Context, key string, values ...any) error { s.mu.Lock() defer s.mu.Unlock() item, ok := s.items[key] @@ -301,7 +218,7 @@ func (s *memoryStore) LPush(key string, values ...any) 
error { return nil } -func (s *memoryStore) LRem(key string, count int64, value any) error { +func (s *memoryStore) LRem(_ context.Context, key string, count int64, value any) error { s.mu.Lock() defer s.mu.Unlock() item, ok := s.items[key] @@ -326,7 +243,7 @@ func (s *memoryStore) LRem(key string, count int64, value any) error { return nil } -func (s *memoryStore) SAdd(key string, members ...any) error { +func (s *memoryStore) SAdd(_ context.Context, key string, members ...any) error { s.mu.Lock() defer s.mu.Unlock() item, ok := s.items[key] @@ -345,7 +262,7 @@ func (s *memoryStore) SAdd(key string, members ...any) error { return nil } -func (s *memoryStore) SPopN(key string, count int64) ([]string, error) { +func (s *memoryStore) SPopN(_ context.Context, key string, count int64) ([]string, error) { s.mu.Lock() defer s.mu.Unlock() item, ok := s.items[key] @@ -375,7 +292,7 @@ func (s *memoryStore) SPopN(key string, count int64) ([]string, error) { return popped, nil } -func (s *memoryStore) SMembers(key string) ([]string, error) { +func (s *memoryStore) SMembers(_ context.Context, key string) ([]string, error) { s.mu.RLock() defer s.mu.RUnlock() item, ok := s.items[key] @@ -393,7 +310,7 @@ func (s *memoryStore) SMembers(key string) ([]string, error) { return members, nil } -func (s *memoryStore) SRem(key string, members ...any) error { +func (s *memoryStore) SRem(_ context.Context, key string, members ...any) error { s.mu.Lock() defer s.mu.Unlock() item, ok := s.items[key] @@ -410,7 +327,31 @@ func (s *memoryStore) SRem(key string, members ...any) error { return nil } -func (s *memoryStore) Rotate(key string) (string, error) { +func (s *memoryStore) SRandMember(_ context.Context, key string) (string, error) { + s.mu.RLock() + defer s.mu.RUnlock() + item, ok := s.items[key] + if !ok || item.isExpired() { + return "", ErrNotFound + } + set, ok := item.value.(map[string]struct{}) + if !ok || len(set) == 0 { + return "", ErrNotFound + } + members := make([]string, 0, len(set)) + for member := range set { + members = append(members, member) + } + if len(members) == 0 { + return "", ErrNotFound + } + s.rngMu.Lock() + n := s.rng.Intn(len(members)) + s.rngMu.Unlock() + return members[n], nil +} + +func (s *memoryStore) Rotate(_ context.Context, key string) (string, error) { s.mu.Lock() defer s.mu.Unlock() item, ok := s.items[key] @@ -426,7 +367,7 @@ func (s *memoryStore) Rotate(key string) (string, error) { return val, nil } -func (s *memoryStore) LIndex(key string, index int64) (string, error) { +func (s *memoryStore) LIndex(_ context.Context, key string, index int64) (string, error) { s.mu.RLock() defer s.mu.RUnlock() item, ok := s.items[key] @@ -447,8 +388,7 @@ func (s *memoryStore) LIndex(key string, index int64) (string, error) { return list[index], nil } -// Zset methods... (ZAdd, ZRange, ZRem) -func (s *memoryStore) ZAdd(key string, members map[string]float64) error { +func (s *memoryStore) ZAdd(_ context.Context, key string, members map[string]float64) error { s.mu.Lock() defer s.mu.Unlock() item, ok := s.items[key] @@ -471,8 +411,6 @@ func (s *memoryStore) ZAdd(key string, members map[string]float64) error { for val, score := range membersMap { newZSet = append(newZSet, zsetMember{Value: val, Score: score}) } - // NOTE: This ZSet implementation is simple but not performant for large sets. - // A production implementation would use a skip list or a balanced tree. 
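+	// NOTE: the slice is rebuilt and fully re-sorted on every ZAdd; simple, but O(n log n) per call and not suited to large sets.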
sort.Slice(newZSet, func(i, j int) bool { if newZSet[i].Score == newZSet[j].Score { return newZSet[i].Value < newZSet[j].Value @@ -482,7 +420,7 @@ func (s *memoryStore) ZAdd(key string, members map[string]float64) error { item.value = newZSet return nil } -func (s *memoryStore) ZRange(key string, start, stop int64) ([]string, error) { +func (s *memoryStore) ZRange(_ context.Context, key string, start, stop int64) ([]string, error) { s.mu.RLock() defer s.mu.RUnlock() item, ok := s.items[key] @@ -515,7 +453,7 @@ func (s *memoryStore) ZRange(key string, start, stop int64) ([]string, error) { } return result, nil } -func (s *memoryStore) ZRem(key string, members ...any) error { +func (s *memoryStore) ZRem(_ context.Context, key string, members ...any) error { s.mu.Lock() defer s.mu.Unlock() item, ok := s.items[key] @@ -540,13 +478,56 @@ func (s *memoryStore) ZRem(key string, members ...any) error { return nil } -// Pipeline implementation +func (s *memoryStore) PopAndCycleSetMember(_ context.Context, mainKey, cooldownKey string) (string, error) { + s.mu.Lock() + defer s.mu.Unlock() + mainItem, mainOk := s.items[mainKey] + var mainSet map[string]struct{} + if mainOk && !mainItem.isExpired() { + mainSet, mainOk = mainItem.value.(map[string]struct{}) + mainOk = mainOk && len(mainSet) > 0 + } else { + mainOk = false + } + if !mainOk { + cooldownItem, cooldownOk := s.items[cooldownKey] + if !cooldownOk || cooldownItem.isExpired() { + return "", ErrNotFound + } + cooldownSet, cooldownSetOk := cooldownItem.value.(map[string]struct{}) + if !cooldownSetOk || len(cooldownSet) == 0 { + return "", ErrNotFound + } + s.items[mainKey] = cooldownItem + delete(s.items, cooldownKey) + mainSet = cooldownSet + } + var popped string + for k := range mainSet { + popped = k + break + } + delete(mainSet, popped) + cooldownItem, cooldownOk := s.items[cooldownKey] + if !cooldownOk || cooldownItem.isExpired() { + cooldownItem = &memoryStoreItem{value: make(map[string]struct{})} + s.items[cooldownKey] = cooldownItem + } + cooldownSet, ok := cooldownItem.value.(map[string]struct{}) + if !ok { + cooldownSet = make(map[string]struct{}) + cooldownItem.value = cooldownSet + } + cooldownSet[popped] = struct{}{} + return popped, nil +} + type memoryPipeliner struct { store *memoryStore ops []func() } -func (s *memoryStore) Pipeline() Pipeliner { +func (s *memoryStore) Pipeline(_ context.Context) Pipeliner { return &memoryPipeliner{store: s} } func (p *memoryPipeliner) Exec() error { @@ -559,7 +540,6 @@ func (p *memoryPipeliner) Exec() error { } func (p *memoryPipeliner) Expire(key string, expiration time.Duration) { - // [USER SUGGESTION APPLIED] Fix #4: Capture value, not reference capturedKey := key p.ops = append(p.ops, func() { if item, ok := p.store.items[capturedKey]; ok { @@ -596,7 +576,6 @@ func (p *memoryPipeliner) SAdd(key string, members ...any) { } }) } - func (p *memoryPipeliner) SRem(key string, members ...any) { capturedKey := key capturedMembers := make([]any, len(members)) @@ -615,7 +594,6 @@ func (p *memoryPipeliner) SRem(key string, members ...any) { } }) } - func (p *memoryPipeliner) LPush(key string, values ...any) { capturedKey := key capturedValues := make([]any, len(values)) @@ -637,11 +615,12 @@ func (p *memoryPipeliner) LPush(key string, values ...any) { item.value = append(stringValues, list...) 
}) } -func (p *memoryPipeliner) LRem(key string, count int64, value any) {} -func (p *memoryPipeliner) HSet(key string, values map[string]any) {} -func (p *memoryPipeliner) HIncrBy(key, field string, incr int64) {} +func (p *memoryPipeliner) LRem(key string, count int64, value any) {} +func (p *memoryPipeliner) HSet(key string, values map[string]any) {} +func (p *memoryPipeliner) HIncrBy(key, field string, incr int64) {} +func (p *memoryPipeliner) ZAdd(key string, members map[string]float64) {} +func (p *memoryPipeliner) ZRem(key string, members ...any) {} -// --- Pub/Sub implementation (remains unchanged) --- type memorySubscription struct { store *memoryStore channelName string @@ -649,10 +628,11 @@ type memorySubscription struct { } func (ms *memorySubscription) Channel() <-chan *Message { return ms.msgChan } +func (ms *memorySubscription) ChannelName() string { return ms.channelName } func (ms *memorySubscription) Close() error { return ms.store.removeSubscriber(ms.channelName, ms.msgChan) } -func (s *memoryStore) Publish(channel string, message []byte) error { +func (s *memoryStore) Publish(_ context.Context, channel string, message []byte) error { s.mu.RLock() defer s.mu.RUnlock() subscribers, ok := s.pubsub[channel] @@ -669,7 +649,7 @@ func (s *memoryStore) Publish(channel string, message []byte) error { } return nil } -func (s *memoryStore) Subscribe(channel string) (Subscription, error) { +func (s *memoryStore) Subscribe(_ context.Context, channel string) (Subscription, error) { s.mu.Lock() defer s.mu.Unlock() msgChan := make(chan *Message, 10) diff --git a/internal/store/redis_store.go b/internal/store/redis_store.go index 0ef5601..e16f849 100644 --- a/internal/store/redis_store.go +++ b/internal/store/redis_store.go @@ -1,3 +1,5 @@ +// Filename: internal/store/redis_store.go + package store import ( @@ -8,22 +10,20 @@ import ( "time" "github.com/redis/go-redis/v9" + "github.com/sirupsen/logrus" ) // ensure RedisStore implements Store interface var _ Store = (*RedisStore)(nil) -// RedisStore is a Redis-backed key-value store. type RedisStore struct { client *redis.Client popAndCycleScript *redis.Script + logger *logrus.Entry } // NewRedisStore creates a new RedisStore instance. -func NewRedisStore(client *redis.Client) Store { - // Lua script for atomic pop-and-cycle operation. 
- // KEYS[1]: main set key - // KEYS[2]: cooldown set key +func NewRedisStore(client *redis.Client, logger *logrus.Logger) Store { const script = ` if redis.call('SCARD', KEYS[1]) == 0 then if redis.call('SCARD', KEYS[2]) == 0 then @@ -36,15 +36,16 @@ func NewRedisStore(client *redis.Client) Store { return &RedisStore{ client: client, popAndCycleScript: redis.NewScript(script), + logger: logger.WithField("component", "store.redis 🗄️"), } } -func (s *RedisStore) Set(key string, value []byte, ttl time.Duration) error { - return s.client.Set(context.Background(), key, value, ttl).Err() +func (s *RedisStore) Set(ctx context.Context, key string, value []byte, ttl time.Duration) error { + return s.client.Set(ctx, key, value, ttl).Err() } -func (s *RedisStore) Get(key string) ([]byte, error) { - val, err := s.client.Get(context.Background(), key).Bytes() +func (s *RedisStore) Get(ctx context.Context, key string) ([]byte, error) { + val, err := s.client.Get(ctx, key).Bytes() if err != nil { if errors.Is(err, redis.Nil) { return nil, ErrNotFound @@ -54,53 +55,53 @@ func (s *RedisStore) Get(key string) ([]byte, error) { return val, nil } -func (s *RedisStore) Del(keys ...string) error { +func (s *RedisStore) Del(ctx context.Context, keys ...string) error { if len(keys) == 0 { return nil } - return s.client.Del(context.Background(), keys...).Err() + return s.client.Del(ctx, keys...).Err() } -func (s *RedisStore) Exists(key string) (bool, error) { - val, err := s.client.Exists(context.Background(), key).Result() +func (s *RedisStore) Exists(ctx context.Context, key string) (bool, error) { + val, err := s.client.Exists(ctx, key).Result() return val > 0, err } -func (s *RedisStore) SetNX(key string, value []byte, ttl time.Duration) (bool, error) { - return s.client.SetNX(context.Background(), key, value, ttl).Result() +func (s *RedisStore) SetNX(ctx context.Context, key string, value []byte, ttl time.Duration) (bool, error) { + return s.client.SetNX(ctx, key, value, ttl).Result() } func (s *RedisStore) Close() error { return s.client.Close() } -func (s *RedisStore) HSet(key string, values map[string]any) error { - return s.client.HSet(context.Background(), key, values).Err() +func (s *RedisStore) HSet(ctx context.Context, key string, values map[string]any) error { + return s.client.HSet(ctx, key, values).Err() } -func (s *RedisStore) HGetAll(key string) (map[string]string, error) { - return s.client.HGetAll(context.Background(), key).Result() +func (s *RedisStore) HGetAll(ctx context.Context, key string) (map[string]string, error) { + return s.client.HGetAll(ctx, key).Result() } -func (s *RedisStore) HIncrBy(key, field string, incr int64) (int64, error) { - return s.client.HIncrBy(context.Background(), key, field, incr).Result() +func (s *RedisStore) HIncrBy(ctx context.Context, key, field string, incr int64) (int64, error) { + return s.client.HIncrBy(ctx, key, field, incr).Result() } -func (s *RedisStore) HDel(key string, fields ...string) error { +func (s *RedisStore) HDel(ctx context.Context, key string, fields ...string) error { if len(fields) == 0 { return nil } - return s.client.HDel(context.Background(), key, fields...).Err() + return s.client.HDel(ctx, key, fields...).Err() } -func (s *RedisStore) LPush(key string, values ...any) error { - return s.client.LPush(context.Background(), key, values...).Err() +func (s *RedisStore) LPush(ctx context.Context, key string, values ...any) error { + return s.client.LPush(ctx, key, values...).Err() } -func (s *RedisStore) LRem(key string, count int64, value 
any) error { - return s.client.LRem(context.Background(), key, count, value).Err() +func (s *RedisStore) LRem(ctx context.Context, key string, count int64, value any) error { + return s.client.LRem(ctx, key, count, value).Err() } -func (s *RedisStore) Rotate(key string) (string, error) { - val, err := s.client.RPopLPush(context.Background(), key, key).Result() +func (s *RedisStore) Rotate(ctx context.Context, key string) (string, error) { + val, err := s.client.RPopLPush(ctx, key, key).Result() if err != nil { if errors.Is(err, redis.Nil) { return "", ErrNotFound @@ -110,29 +111,28 @@ func (s *RedisStore) Rotate(key string) (string, error) { return val, nil } -func (s *RedisStore) SAdd(key string, members ...any) error { - return s.client.SAdd(context.Background(), key, members...).Err() +func (s *RedisStore) SAdd(ctx context.Context, key string, members ...any) error { + return s.client.SAdd(ctx, key, members...).Err() } -func (s *RedisStore) SPopN(key string, count int64) ([]string, error) { - return s.client.SPopN(context.Background(), key, count).Result() +func (s *RedisStore) SPopN(ctx context.Context, key string, count int64) ([]string, error) { + return s.client.SPopN(ctx, key, count).Result() } -func (s *RedisStore) SMembers(key string) ([]string, error) { - return s.client.SMembers(context.Background(), key).Result() +func (s *RedisStore) SMembers(ctx context.Context, key string) ([]string, error) { + return s.client.SMembers(ctx, key).Result() } -func (s *RedisStore) SRem(key string, members ...any) error { +func (s *RedisStore) SRem(ctx context.Context, key string, members ...any) error { if len(members) == 0 { return nil } - return s.client.SRem(context.Background(), key, members...).Err() + return s.client.SRem(ctx, key, members...).Err() } -func (s *RedisStore) SRandMember(key string) (string, error) { - member, err := s.client.SRandMember(context.Background(), key).Result() +func (s *RedisStore) SRandMember(ctx context.Context, key string) (string, error) { + member, err := s.client.SRandMember(ctx, key).Result() if err != nil { - if errors.Is(err, redis.Nil) { return "", ErrNotFound } @@ -141,81 +141,43 @@ func (s *RedisStore) SRandMember(key string) (string, error) { return member, nil } -// === 新增方法实现 === - -func (s *RedisStore) ZAdd(key string, members map[string]float64) error { +func (s *RedisStore) ZAdd(ctx context.Context, key string, members map[string]float64) error { if len(members) == 0 { return nil } - redisMembers := make([]redis.Z, 0, len(members)) + redisMembers := make([]redis.Z, len(members)) + i := 0 for member, score := range members { - redisMembers = append(redisMembers, redis.Z{Score: score, Member: member}) + redisMembers[i] = redis.Z{Score: score, Member: member} + i++ } - return s.client.ZAdd(context.Background(), key, redisMembers...).Err() + return s.client.ZAdd(ctx, key, redisMembers...).Err() } -func (s *RedisStore) ZRange(key string, start, stop int64) ([]string, error) { - return s.client.ZRange(context.Background(), key, start, stop).Result() +func (s *RedisStore) ZRange(ctx context.Context, key string, start, stop int64) ([]string, error) { + return s.client.ZRange(ctx, key, start, stop).Result() } -func (s *RedisStore) ZRem(key string, members ...any) error { +func (s *RedisStore) ZRem(ctx context.Context, key string, members ...any) error { if len(members) == 0 { return nil } - return s.client.ZRem(context.Background(), key, members...).Err() + return s.client.ZRem(ctx, key, members...).Err() } -func (s *RedisStore) 
PopAndCycleSetMember(mainKey, cooldownKey string) (string, error) { - val, err := s.popAndCycleScript.Run(context.Background(), s.client, []string{mainKey, cooldownKey}).Result() +func (s *RedisStore) PopAndCycleSetMember(ctx context.Context, mainKey, cooldownKey string) (string, error) { + val, err := s.popAndCycleScript.Run(ctx, s.client, []string{mainKey, cooldownKey}).Result() if err != nil { if errors.Is(err, redis.Nil) { return "", ErrNotFound } return "", err } - // Lua script returns a string, so we need to type assert if str, ok := val.(string); ok { return str, nil } - return "", ErrNotFound // This happens if both sets were empty and the script returned nil + return "", ErrNotFound } -type redisPipeliner struct{ pipe redis.Pipeliner } - -func (p *redisPipeliner) HSet(key string, values map[string]any) { - p.pipe.HSet(context.Background(), key, values) -} -func (p *redisPipeliner) HIncrBy(key, field string, incr int64) { - p.pipe.HIncrBy(context.Background(), key, field, incr) -} -func (p *redisPipeliner) Exec() error { - _, err := p.pipe.Exec(context.Background()) - return err -} - -func (p *redisPipeliner) Del(keys ...string) { - if len(keys) > 0 { - p.pipe.Del(context.Background(), keys...) - } -} - -func (p *redisPipeliner) SAdd(key string, members ...any) { - p.pipe.SAdd(context.Background(), key, members...) -} - -func (p *redisPipeliner) SRem(key string, members ...any) { - if len(members) > 0 { - p.pipe.SRem(context.Background(), key, members...) - } -} - -func (p *redisPipeliner) LPush(key string, values ...any) { - p.pipe.LPush(context.Background(), key, values...) -} - -func (p *redisPipeliner) LRem(key string, count int64, value any) { - p.pipe.LRem(context.Background(), key, count, value) -} - -func (s *RedisStore) LIndex(key string, index int64) (string, error) { - val, err := s.client.LIndex(context.Background(), key, index).Result() +func (s *RedisStore) LIndex(ctx context.Context, key string, index int64) (string, error) { + val, err := s.client.LIndex(ctx, key, index).Result() if err != nil { if errors.Is(err, redis.Nil) { return "", ErrNotFound @@ -225,47 +187,120 @@ func (s *RedisStore) LIndex(key string, index int64) (string, error) { return val, nil } -func (p *redisPipeliner) Expire(key string, expiration time.Duration) { - p.pipe.Expire(context.Background(), key, expiration) +type redisPipeliner struct { + pipe redis.Pipeliner + ctx context.Context } -func (s *RedisStore) Pipeline() Pipeliner { - return &redisPipeliner{pipe: s.client.Pipeline()} +func (s *RedisStore) Pipeline(ctx context.Context) Pipeliner { + return &redisPipeliner{ + pipe: s.client.Pipeline(), + ctx: ctx, + } } +func (p *redisPipeliner) Exec() error { + _, err := p.pipe.Exec(p.ctx) + return err +} + +func (p *redisPipeliner) Del(keys ...string) { p.pipe.Del(p.ctx, keys...) } +func (p *redisPipeliner) Expire(key string, expiration time.Duration) { + p.pipe.Expire(p.ctx, key, expiration) +} +func (p *redisPipeliner) HSet(key string, values map[string]any) { p.pipe.HSet(p.ctx, key, values) } +func (p *redisPipeliner) HIncrBy(key, field string, incr int64) { + p.pipe.HIncrBy(p.ctx, key, field, incr) +} +func (p *redisPipeliner) LPush(key string, values ...any) { p.pipe.LPush(p.ctx, key, values...) } +func (p *redisPipeliner) LRem(key string, count int64, value any) { + p.pipe.LRem(p.ctx, key, count, value) +} +func (p *redisPipeliner) SAdd(key string, members ...any) { p.pipe.SAdd(p.ctx, key, members...) 
} +func (p *redisPipeliner) SRem(key string, members ...any) { p.pipe.SRem(p.ctx, key, members...) } +func (p *redisPipeliner) ZAdd(key string, members map[string]float64) { + if len(members) == 0 { + return + } + redisMembers := make([]redis.Z, len(members)) + i := 0 + for member, score := range members { + redisMembers[i] = redis.Z{Score: score, Member: member} + i++ + } + p.pipe.ZAdd(p.ctx, key, redisMembers...) +} +func (p *redisPipeliner) ZRem(key string, members ...any) { p.pipe.ZRem(p.ctx, key, members...) } + type redisSubscription struct { - pubsub *redis.PubSub - msgChan chan *Message - once sync.Once + pubsub *redis.PubSub + msgChan chan *Message + logger *logrus.Entry + wg sync.WaitGroup + close context.CancelFunc + channelName string +} + +func (s *RedisStore) Subscribe(ctx context.Context, channel string) (Subscription, error) { + pubsub := s.client.Subscribe(ctx, channel) + _, err := pubsub.Receive(ctx) + if err != nil { + _ = pubsub.Close() + return nil, fmt.Errorf("failed to subscribe to channel %s: %w", channel, err) + } + subCtx, cancel := context.WithCancel(context.Background()) + sub := &redisSubscription{ + pubsub: pubsub, + msgChan: make(chan *Message, 10), + logger: s.logger, + close: cancel, + channelName: channel, + } + sub.wg.Add(1) + go sub.bridge(subCtx) + return sub, nil +} + +func (rs *redisSubscription) bridge(ctx context.Context) { + defer rs.wg.Done() + defer close(rs.msgChan) + redisCh := rs.pubsub.Channel() + for { + select { + case <-ctx.Done(): + return + case redisMsg, ok := <-redisCh: + if !ok { + return + } + msg := &Message{ + Channel: redisMsg.Channel, + Payload: []byte(redisMsg.Payload), + } + select { + case rs.msgChan <- msg: + default: + rs.logger.Warnf("Message dropped for channel '%s' due to slow consumer.", rs.channelName) + } + } + } } func (rs *redisSubscription) Channel() <-chan *Message { - rs.once.Do(func() { - rs.msgChan = make(chan *Message) - go func() { - defer close(rs.msgChan) - for redisMsg := range rs.pubsub.Channel() { - rs.msgChan <- &Message{ - Channel: redisMsg.Channel, - Payload: []byte(redisMsg.Payload), - } - } - }() - }) return rs.msgChan } -func (rs *redisSubscription) Close() error { return rs.pubsub.Close() } - -func (s *RedisStore) Publish(channel string, message []byte) error { - return s.client.Publish(context.Background(), channel, message).Err() +func (rs *redisSubscription) ChannelName() string { + return rs.channelName } -func (s *RedisStore) Subscribe(channel string) (Subscription, error) { - pubsub := s.client.Subscribe(context.Background(), channel) - _, err := pubsub.Receive(context.Background()) - if err != nil { - return nil, fmt.Errorf("failed to subscribe to channel %s: %w", channel, err) - } - return &redisSubscription{pubsub: pubsub}, nil +func (rs *redisSubscription) Close() error { + rs.close() + err := rs.pubsub.Close() + rs.wg.Wait() + return err +} + +func (s *RedisStore) Publish(ctx context.Context, channel string, message []byte) error { + return s.client.Publish(ctx, channel, message).Err() } diff --git a/internal/store/store.go b/internal/store/store.go index 93b9eb2..fc2edcf 100644 --- a/internal/store/store.go +++ b/internal/store/store.go @@ -1,6 +1,9 @@ +// Filename: internal/store/store.go + package store import ( + "context" "errors" "time" ) @@ -17,6 +20,7 @@ type Message struct { // Subscription represents an active subscription to a pub/sub channel. 
type Subscription interface { Channel() <-chan *Message + ChannelName() string Close() error } @@ -38,6 +42,10 @@ type Pipeliner interface { LPush(key string, values ...any) LRem(key string, count int64, value any) + // ZSET + ZAdd(key string, members map[string]float64) + ZRem(key string, members ...any) + // Execution Exec() error } @@ -45,44 +53,44 @@ type Pipeliner interface { // Store is the master interface for our cache service. type Store interface { // Basic K/V operations - Set(key string, value []byte, ttl time.Duration) error - Get(key string) ([]byte, error) - Del(keys ...string) error - Exists(key string) (bool, error) - SetNX(key string, value []byte, ttl time.Duration) (bool, error) + Set(ctx context.Context, key string, value []byte, ttl time.Duration) error + Get(ctx context.Context, key string) ([]byte, error) + Del(ctx context.Context, keys ...string) error + Exists(ctx context.Context, key string) (bool, error) + SetNX(ctx context.Context, key string, value []byte, ttl time.Duration) (bool, error) // HASH operations - HSet(key string, values map[string]any) error - HGetAll(key string) (map[string]string, error) - HIncrBy(key, field string, incr int64) (int64, error) - HDel(key string, fields ...string) error // [新增] + HSet(ctx context.Context, key string, values map[string]any) error + HGetAll(ctx context.Context, key string) (map[string]string, error) + HIncrBy(ctx context.Context, key, field string, incr int64) (int64, error) + HDel(ctx context.Context, key string, fields ...string) error // LIST operations - LPush(key string, values ...any) error - LRem(key string, count int64, value any) error - Rotate(key string) (string, error) - LIndex(key string, index int64) (string, error) + LPush(ctx context.Context, key string, values ...any) error + LRem(ctx context.Context, key string, count int64, value any) error + Rotate(ctx context.Context, key string) (string, error) + LIndex(ctx context.Context, key string, index int64) (string, error) // SET operations - SAdd(key string, members ...any) error - SPopN(key string, count int64) ([]string, error) - SMembers(key string) ([]string, error) - SRem(key string, members ...any) error - SRandMember(key string) (string, error) + SAdd(ctx context.Context, key string, members ...any) error + SPopN(ctx context.Context, key string, count int64) ([]string, error) + SMembers(ctx context.Context, key string) ([]string, error) + SRem(ctx context.Context, key string, members ...any) error + SRandMember(ctx context.Context, key string) (string, error) // Pub/Sub operations - Publish(channel string, message []byte) error - Subscribe(channel string) (Subscription, error) + Publish(ctx context.Context, channel string, message []byte) error + Subscribe(ctx context.Context, channel string) (Subscription, error) - // Pipeline (optional) - 我们在redis实现它,内存版暂时不实现 - Pipeline() Pipeliner + // Pipeline + Pipeline(ctx context.Context) Pipeliner // Close closes the store and releases any underlying resources. 
Close() error - // === 新增方法,支持轮询策略 === - ZAdd(key string, members map[string]float64) error - ZRange(key string, start, stop int64) ([]string, error) - ZRem(key string, members ...any) error - PopAndCycleSetMember(mainKey, cooldownKey string) (string, error) + // ZSET operations + ZAdd(ctx context.Context, key string, members map[string]float64) error + ZRange(ctx context.Context, key string, start, stop int64) ([]string, error) + ZRem(ctx context.Context, key string, members ...any) error + PopAndCycleSetMember(ctx context.Context, mainKey, cooldownKey string) (string, error) } diff --git a/internal/syncer/syncer.go b/internal/syncer/syncer.go index 1ea7da5..66bdaf5 100644 --- a/internal/syncer/syncer.go +++ b/internal/syncer/syncer.go @@ -1,6 +1,7 @@ package syncer import ( + "context" "fmt" "gemini-balancer/internal/store" "log" @@ -51,7 +52,7 @@ func (s *CacheSyncer[T]) Get() T { func (s *CacheSyncer[T]) Invalidate() error { log.Printf("INFO: Publishing invalidation notification on channel '%s'", s.channelName) - return s.store.Publish(s.channelName, []byte("reload")) + return s.store.Publish(context.Background(), s.channelName, []byte("reload")) } func (s *CacheSyncer[T]) Stop() { @@ -84,7 +85,7 @@ func (s *CacheSyncer[T]) listenForUpdates() { default: } - subscription, err := s.store.Subscribe(s.channelName) + subscription, err := s.store.Subscribe(context.Background(), s.channelName) if err != nil { log.Printf("ERROR: Failed to subscribe to '%s', retrying in 5s: %v", s.channelName, err) time.Sleep(5 * time.Second) diff --git a/internal/task/task.go b/internal/task/task.go index bbc5662..aa204c8 100644 --- a/internal/task/task.go +++ b/internal/task/task.go @@ -1,7 +1,8 @@ -// Filename: internal/task/task.go (最终校准版) +// Filename: internal/task/task.go package task import ( + "context" "encoding/json" "errors" "fmt" @@ -15,15 +16,13 @@ const ( ResultTTL = 60 * time.Minute ) -// Reporter 接口,定义了领域如何与任务服务交互。 type Reporter interface { - StartTask(keyGroupID uint, taskType, resourceID string, total int, timeout time.Duration) (*Status, error) - EndTaskByID(taskID, resourceID string, result any, taskErr error) - UpdateProgressByID(taskID string, processed int) error - UpdateTotalByID(taskID string, total int) error + StartTask(ctx context.Context, keyGroupID uint, taskType, resourceID string, total int, timeout time.Duration) (*Status, error) + EndTaskByID(ctx context.Context, taskID, resourceID string, result any, taskErr error) + UpdateProgressByID(ctx context.Context, taskID string, processed int) error + UpdateTotalByID(ctx context.Context, taskID string, total int) error } -// Status 代表一个后台任务的完整状态 type Status struct { ID string `json:"id"` TaskType string `json:"task_type"` @@ -38,13 +37,11 @@ type Status struct { DurationSeconds float64 `json:"duration_seconds,omitempty"` } -// Task 是任务管理的核心服务 type Task struct { store store.Store logger *logrus.Entry } -// NewTask 是 Task 的构造函数 func NewTask(store store.Store, logger *logrus.Logger) *Task { return &Task{ store: store, @@ -62,15 +59,14 @@ func (s *Task) getTaskDataKey(taskID string) string { return fmt.Sprintf("task:data:%s", taskID) } -// --- 新增的輔助函數,用於獲取原子標記的鍵 --- func (s *Task) getIsRunningFlagKey(taskID string) string { return fmt.Sprintf("task:running:%s", taskID) } -func (s *Task) StartTask(keyGroupID uint, taskType, resourceID string, total int, timeout time.Duration) (*Status, error) { +func (s *Task) StartTask(ctx context.Context, keyGroupID uint, taskType, resourceID string, total int, timeout time.Duration) (*Status, error) { 
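+	// Only one task may run per resourceID: a store-level lock guards the resource, and a separate
+	// running flag lets later progress updates detect that the task has already finished.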
lockKey := s.getResourceLockKey(resourceID) - if existingTaskID, err := s.store.Get(lockKey); err == nil && len(existingTaskID) > 0 { + if existingTaskID, err := s.store.Get(ctx, lockKey); err == nil && len(existingTaskID) > 0 { return nil, fmt.Errorf("a task is already running for this resource (ID: %s)", string(existingTaskID)) } @@ -94,35 +90,34 @@ func (s *Task) StartTask(keyGroupID uint, taskType, resourceID string, total int timeout = ResultTTL * 24 } - if err := s.store.Set(lockKey, []byte(taskID), timeout); err != nil { + if err := s.store.Set(ctx, lockKey, []byte(taskID), timeout); err != nil { return nil, fmt.Errorf("failed to acquire task resource lock: %w", err) } - if err := s.store.Set(taskKey, statusBytes, timeout); err != nil { - _ = s.store.Del(lockKey) + if err := s.store.Set(ctx, taskKey, statusBytes, timeout); err != nil { + _ = s.store.Del(ctx, lockKey) return nil, fmt.Errorf("failed to set new task data in store: %w", err) } - // 創建一個獨立的“運行中”標記,它的存在與否是原子性的 - if err := s.store.Set(runningFlagKey, []byte("1"), timeout); err != nil { - _ = s.store.Del(lockKey) - _ = s.store.Del(taskKey) + if err := s.store.Set(ctx, runningFlagKey, []byte("1"), timeout); err != nil { + _ = s.store.Del(ctx, lockKey) + _ = s.store.Del(ctx, taskKey) return nil, fmt.Errorf("failed to set task running flag: %w", err) } return status, nil } -func (s *Task) EndTaskByID(taskID, resourceID string, resultData any, taskErr error) { +func (s *Task) EndTaskByID(ctx context.Context, taskID, resourceID string, resultData any, taskErr error) { lockKey := s.getResourceLockKey(resourceID) defer func() { - if err := s.store.Del(lockKey); err != nil { + if err := s.store.Del(ctx, lockKey); err != nil { s.logger.WithError(err).Warnf("Failed to release resource lock '%s' for task %s.", lockKey, taskID) } }() runningFlagKey := s.getIsRunningFlagKey(taskID) - _ = s.store.Del(runningFlagKey) - status, err := s.GetStatus(taskID) - if err != nil { + _ = s.store.Del(ctx, runningFlagKey) + status, err := s.GetStatus(ctx, taskID) + if err != nil { s.logger.WithError(err).Errorf("Could not get task status for task ID %s during EndTask. 
Lock has been released, but task data may be stale.", taskID) return } @@ -141,15 +136,14 @@ func (s *Task) EndTaskByID(taskID, resourceID string, resultData any, taskErr er } updatedTaskBytes, _ := json.Marshal(status) taskKey := s.getTaskDataKey(taskID) - if err := s.store.Set(taskKey, updatedTaskBytes, ResultTTL); err != nil { + if err := s.store.Set(ctx, taskKey, updatedTaskBytes, ResultTTL); err != nil { s.logger.WithError(err).Errorf("Failed to save final status for task %s.", taskID) } } -// GetStatus 通过ID获取任务状态,供外部(如API Handler)调用 -func (s *Task) GetStatus(taskID string) (*Status, error) { +func (s *Task) GetStatus(ctx context.Context, taskID string) (*Status, error) { taskKey := s.getTaskDataKey(taskID) - statusBytes, err := s.store.Get(taskKey) + statusBytes, err := s.store.Get(ctx, taskKey) if err != nil { if errors.Is(err, store.ErrNotFound) { return nil, errors.New("task not found") @@ -161,22 +155,18 @@ func (s *Task) GetStatus(taskID string) (*Status, error) { if err := json.Unmarshal(statusBytes, &status); err != nil { return nil, fmt.Errorf("corrupted task data in store for ID %s", taskID) } - if !status.IsRunning && status.FinishedAt != nil { status.DurationSeconds = status.FinishedAt.Sub(status.StartedAt).Seconds() } - return &status, nil } -// UpdateProgressByID 通过ID更新任务进度 -func (s *Task) updateTask(taskID string, updater func(status *Status)) error { +func (s *Task) updateTask(ctx context.Context, taskID string, updater func(status *Status)) error { runningFlagKey := s.getIsRunningFlagKey(taskID) - if _, err := s.store.Get(runningFlagKey); err != nil { - // 任务已结束,静默返回是预期行为。 + if _, err := s.store.Get(ctx, runningFlagKey); err != nil { return nil } - status, err := s.GetStatus(taskID) + status, err := s.GetStatus(ctx, taskID) if err != nil { s.logger.WithError(err).Warnf("Failed to get task status for update on task %s. Update will not be saved.", taskID) return nil @@ -184,7 +174,6 @@ func (s *Task) updateTask(taskID string, updater func(status *Status)) error { if !status.IsRunning { return nil } - // 调用传入的 updater 函数来修改 status updater(status) statusBytes, marshalErr := json.Marshal(status) if marshalErr != nil { @@ -192,23 +181,20 @@ func (s *Task) updateTask(taskID string, updater func(status *Status)) error { return nil } taskKey := s.getTaskDataKey(taskID) - // 使用更长的TTL,确保运行中的任务不会过早过期 - if err := s.store.Set(taskKey, statusBytes, ResultTTL*24); err != nil { + if err := s.store.Set(ctx, taskKey, statusBytes, ResultTTL*24); err != nil { s.logger.WithError(err).Warnf("Failed to save update for task %s.", taskID) } return nil } -// [REFACTORED] UpdateProgressByID 现在是一个简单的、调用通用更新器的包装器。 -func (s *Task) UpdateProgressByID(taskID string, processed int) error { - return s.updateTask(taskID, func(status *Status) { +func (s *Task) UpdateProgressByID(ctx context.Context, taskID string, processed int) error { + return s.updateTask(ctx, taskID, func(status *Status) { status.Processed = processed }) } -// [REFACTORED] UpdateTotalByID 现在也是一个简单的、调用通用更新器的包装器。 -func (s *Task) UpdateTotalByID(taskID string, total int) error { - return s.updateTask(taskID, func(status *Status) { +func (s *Task) UpdateTotalByID(ctx context.Context, taskID string, total int) error { + return s.updateTask(ctx, taskID, func(status *Status) { status.Total = total }) }
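For reference, here is a minimal caller-side sketch of how the new context-aware Store signatures are intended to be used. The example package, the helper functions, and the "stats:" key naming below are illustrative only and are not part of this patch.

// illustrative_usage.go (hypothetical; not included in this change)
package example

import (
	"context"
	"time"

	"gemini-balancer/internal/store"
)

// recordUsage batches a counter increment and a TTL refresh into one pipeline,
// all bound to the caller's request-scoped context.
func recordUsage(ctx context.Context, s store.Store, keyID string) error {
	pipe := s.Pipeline(ctx)
	pipe.HIncrBy("stats:"+keyID, "requests", 1)
	pipe.Expire("stats:"+keyID, 24*time.Hour)
	return pipe.Exec()
}

// watchChannel consumes a pub/sub subscription until ctx is cancelled.
func watchChannel(ctx context.Context, s store.Store, channel string) error {
	sub, err := s.Subscribe(ctx, channel)
	if err != nil {
		return err
	}
	defer sub.Close()
	for {
		select {
		case <-ctx.Done():
			return ctx.Err()
		case msg, ok := <-sub.Channel():
			if !ok {
				return nil
			}
			_ = msg.Payload // handle the message here
		}
	}
}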