Update Context for store
This commit is contained in:
@@ -1,7 +1,8 @@
|
||||
// Filename: internal/task/task.go (最终校准版)
|
||||
// Filename: internal/task/task.go
|
||||
package task
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
@@ -15,15 +16,13 @@ const (
|
||||
ResultTTL = 60 * time.Minute
|
||||
)
|
||||
|
||||
// Reporter 接口,定义了领域如何与任务服务交互。
|
||||
type Reporter interface {
|
||||
StartTask(keyGroupID uint, taskType, resourceID string, total int, timeout time.Duration) (*Status, error)
|
||||
EndTaskByID(taskID, resourceID string, result any, taskErr error)
|
||||
UpdateProgressByID(taskID string, processed int) error
|
||||
UpdateTotalByID(taskID string, total int) error
|
||||
StartTask(ctx context.Context, keyGroupID uint, taskType, resourceID string, total int, timeout time.Duration) (*Status, error)
|
||||
EndTaskByID(ctx context.Context, taskID, resourceID string, result any, taskErr error)
|
||||
UpdateProgressByID(ctx context.Context, taskID string, processed int) error
|
||||
UpdateTotalByID(ctx context.Context, taskID string, total int) error
|
||||
}
|
||||
|
||||
// Status 代表一个后台任务的完整状态
|
||||
type Status struct {
|
||||
ID string `json:"id"`
|
||||
TaskType string `json:"task_type"`
|
||||
@@ -38,13 +37,11 @@ type Status struct {
|
||||
DurationSeconds float64 `json:"duration_seconds,omitempty"`
|
||||
}
|
||||
|
||||
// Task 是任务管理的核心服务
|
||||
type Task struct {
|
||||
store store.Store
|
||||
logger *logrus.Entry
|
||||
}
|
||||
|
||||
// NewTask 是 Task 的构造函数
|
||||
func NewTask(store store.Store, logger *logrus.Logger) *Task {
|
||||
return &Task{
|
||||
store: store,
|
||||
@@ -62,15 +59,14 @@ func (s *Task) getTaskDataKey(taskID string) string {
|
||||
return fmt.Sprintf("task:data:%s", taskID)
|
||||
}
|
||||
|
||||
// --- 新增的輔助函數,用於獲取原子標記的鍵 ---
|
||||
func (s *Task) getIsRunningFlagKey(taskID string) string {
|
||||
return fmt.Sprintf("task:running:%s", taskID)
|
||||
}
|
||||
|
||||
func (s *Task) StartTask(keyGroupID uint, taskType, resourceID string, total int, timeout time.Duration) (*Status, error) {
|
||||
func (s *Task) StartTask(ctx context.Context, keyGroupID uint, taskType, resourceID string, total int, timeout time.Duration) (*Status, error) {
|
||||
lockKey := s.getResourceLockKey(resourceID)
|
||||
|
||||
if existingTaskID, err := s.store.Get(lockKey); err == nil && len(existingTaskID) > 0 {
|
||||
if existingTaskID, err := s.store.Get(ctx, lockKey); err == nil && len(existingTaskID) > 0 {
|
||||
return nil, fmt.Errorf("a task is already running for this resource (ID: %s)", string(existingTaskID))
|
||||
}
|
||||
|
||||
@@ -94,35 +90,34 @@ func (s *Task) StartTask(keyGroupID uint, taskType, resourceID string, total int
|
||||
timeout = ResultTTL * 24
|
||||
}
|
||||
|
||||
if err := s.store.Set(lockKey, []byte(taskID), timeout); err != nil {
|
||||
if err := s.store.Set(ctx, lockKey, []byte(taskID), timeout); err != nil {
|
||||
return nil, fmt.Errorf("failed to acquire task resource lock: %w", err)
|
||||
}
|
||||
if err := s.store.Set(taskKey, statusBytes, timeout); err != nil {
|
||||
_ = s.store.Del(lockKey)
|
||||
if err := s.store.Set(ctx, taskKey, statusBytes, timeout); err != nil {
|
||||
_ = s.store.Del(ctx, lockKey)
|
||||
return nil, fmt.Errorf("failed to set new task data in store: %w", err)
|
||||
}
|
||||
|
||||
// 創建一個獨立的“運行中”標記,它的存在與否是原子性的
|
||||
if err := s.store.Set(runningFlagKey, []byte("1"), timeout); err != nil {
|
||||
_ = s.store.Del(lockKey)
|
||||
_ = s.store.Del(taskKey)
|
||||
if err := s.store.Set(ctx, runningFlagKey, []byte("1"), timeout); err != nil {
|
||||
_ = s.store.Del(ctx, lockKey)
|
||||
_ = s.store.Del(ctx, taskKey)
|
||||
return nil, fmt.Errorf("failed to set task running flag: %w", err)
|
||||
}
|
||||
return status, nil
|
||||
}
|
||||
|
||||
func (s *Task) EndTaskByID(taskID, resourceID string, resultData any, taskErr error) {
|
||||
func (s *Task) EndTaskByID(ctx context.Context, taskID, resourceID string, resultData any, taskErr error) {
|
||||
lockKey := s.getResourceLockKey(resourceID)
|
||||
defer func() {
|
||||
if err := s.store.Del(lockKey); err != nil {
|
||||
if err := s.store.Del(ctx, lockKey); err != nil {
|
||||
s.logger.WithError(err).Warnf("Failed to release resource lock '%s' for task %s.", lockKey, taskID)
|
||||
}
|
||||
}()
|
||||
runningFlagKey := s.getIsRunningFlagKey(taskID)
|
||||
_ = s.store.Del(runningFlagKey)
|
||||
status, err := s.GetStatus(taskID)
|
||||
if err != nil {
|
||||
_ = s.store.Del(ctx, runningFlagKey)
|
||||
|
||||
status, err := s.GetStatus(ctx, taskID)
|
||||
if err != nil {
|
||||
s.logger.WithError(err).Errorf("Could not get task status for task ID %s during EndTask. Lock has been released, but task data may be stale.", taskID)
|
||||
return
|
||||
}
|
||||
@@ -141,15 +136,14 @@ func (s *Task) EndTaskByID(taskID, resourceID string, resultData any, taskErr er
|
||||
}
|
||||
updatedTaskBytes, _ := json.Marshal(status)
|
||||
taskKey := s.getTaskDataKey(taskID)
|
||||
if err := s.store.Set(taskKey, updatedTaskBytes, ResultTTL); err != nil {
|
||||
if err := s.store.Set(ctx, taskKey, updatedTaskBytes, ResultTTL); err != nil {
|
||||
s.logger.WithError(err).Errorf("Failed to save final status for task %s.", taskID)
|
||||
}
|
||||
}
|
||||
|
||||
// GetStatus 通过ID获取任务状态,供外部(如API Handler)调用
|
||||
func (s *Task) GetStatus(taskID string) (*Status, error) {
|
||||
func (s *Task) GetStatus(ctx context.Context, taskID string) (*Status, error) {
|
||||
taskKey := s.getTaskDataKey(taskID)
|
||||
statusBytes, err := s.store.Get(taskKey)
|
||||
statusBytes, err := s.store.Get(ctx, taskKey)
|
||||
if err != nil {
|
||||
if errors.Is(err, store.ErrNotFound) {
|
||||
return nil, errors.New("task not found")
|
||||
@@ -161,22 +155,18 @@ func (s *Task) GetStatus(taskID string) (*Status, error) {
|
||||
if err := json.Unmarshal(statusBytes, &status); err != nil {
|
||||
return nil, fmt.Errorf("corrupted task data in store for ID %s", taskID)
|
||||
}
|
||||
|
||||
if !status.IsRunning && status.FinishedAt != nil {
|
||||
status.DurationSeconds = status.FinishedAt.Sub(status.StartedAt).Seconds()
|
||||
}
|
||||
|
||||
return &status, nil
|
||||
}
|
||||
|
||||
// UpdateProgressByID 通过ID更新任务进度
|
||||
func (s *Task) updateTask(taskID string, updater func(status *Status)) error {
|
||||
func (s *Task) updateTask(ctx context.Context, taskID string, updater func(status *Status)) error {
|
||||
runningFlagKey := s.getIsRunningFlagKey(taskID)
|
||||
if _, err := s.store.Get(runningFlagKey); err != nil {
|
||||
// 任务已结束,静默返回是预期行为。
|
||||
if _, err := s.store.Get(ctx, runningFlagKey); err != nil {
|
||||
return nil
|
||||
}
|
||||
status, err := s.GetStatus(taskID)
|
||||
status, err := s.GetStatus(ctx, taskID)
|
||||
if err != nil {
|
||||
s.logger.WithError(err).Warnf("Failed to get task status for update on task %s. Update will not be saved.", taskID)
|
||||
return nil
|
||||
@@ -184,7 +174,6 @@ func (s *Task) updateTask(taskID string, updater func(status *Status)) error {
|
||||
if !status.IsRunning {
|
||||
return nil
|
||||
}
|
||||
// 调用传入的 updater 函数来修改 status
|
||||
updater(status)
|
||||
statusBytes, marshalErr := json.Marshal(status)
|
||||
if marshalErr != nil {
|
||||
@@ -192,23 +181,20 @@ func (s *Task) updateTask(taskID string, updater func(status *Status)) error {
|
||||
return nil
|
||||
}
|
||||
taskKey := s.getTaskDataKey(taskID)
|
||||
// 使用更长的TTL,确保运行中的任务不会过早过期
|
||||
if err := s.store.Set(taskKey, statusBytes, ResultTTL*24); err != nil {
|
||||
if err := s.store.Set(ctx, taskKey, statusBytes, ResultTTL*24); err != nil {
|
||||
s.logger.WithError(err).Warnf("Failed to save update for task %s.", taskID)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// [REFACTORED] UpdateProgressByID 现在是一个简单的、调用通用更新器的包装器。
|
||||
func (s *Task) UpdateProgressByID(taskID string, processed int) error {
|
||||
return s.updateTask(taskID, func(status *Status) {
|
||||
func (s *Task) UpdateProgressByID(ctx context.Context, taskID string, processed int) error {
|
||||
return s.updateTask(ctx, taskID, func(status *Status) {
|
||||
status.Processed = processed
|
||||
})
|
||||
}
|
||||
|
||||
// [REFACTORED] UpdateTotalByID 现在也是一个简单的、调用通用更新器的包装器。
|
||||
func (s *Task) UpdateTotalByID(taskID string, total int) error {
|
||||
return s.updateTask(taskID, func(status *Status) {
|
||||
func (s *Task) UpdateTotalByID(ctx context.Context, taskID string, total int) error {
|
||||
return s.updateTask(ctx, taskID, func(status *Status) {
|
||||
status.Total = total
|
||||
})
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user