// Filename: internal/task/task.go (最终校准版) package task import ( "encoding/json" "errors" "fmt" "gemini-balancer/internal/store" "time" "github.com/sirupsen/logrus" ) const ( ResultTTL = 60 * time.Minute ) // Reporter 接口,定义了领域如何与任务服务交互。 type Reporter interface { StartTask(keyGroupID uint, taskType, resourceID string, total int, timeout time.Duration) (*Status, error) EndTaskByID(taskID, resourceID string, result any, taskErr error) UpdateProgressByID(taskID string, processed int) error UpdateTotalByID(taskID string, total int) error } // Status 代表一个后台任务的完整状态 type Status struct { ID string `json:"id"` TaskType string `json:"task_type"` IsRunning bool `json:"is_running"` ResourceID string `json:"resource_id,omitempty"` Processed int `json:"processed"` Total int `json:"total"` Result any `json:"result,omitempty"` Error string `json:"error,omitempty"` StartedAt time.Time `json:"started_at"` FinishedAt *time.Time `json:"finished_at,omitempty"` DurationSeconds float64 `json:"duration_seconds,omitempty"` } // Task 是任务管理的核心服务 type Task struct { store store.Store logger *logrus.Entry } // NewTask 是 Task 的构造函数 func NewTask(store store.Store, logger *logrus.Logger) *Task { return &Task{ store: store, logger: logger.WithField("component", "TaskService📋"), } } var _ Reporter = (*Task)(nil) func (s *Task) getResourceLockKey(resourceID string) string { return fmt.Sprintf("task:lock:%s", resourceID) } func (s *Task) getTaskDataKey(taskID string) string { return fmt.Sprintf("task:data:%s", taskID) } // --- 新增的輔助函數,用於獲取原子標記的鍵 --- func (s *Task) getIsRunningFlagKey(taskID string) string { return fmt.Sprintf("task:running:%s", taskID) } func (s *Task) StartTask(keyGroupID uint, taskType, resourceID string, total int, timeout time.Duration) (*Status, error) { lockKey := s.getResourceLockKey(resourceID) if existingTaskID, err := s.store.Get(lockKey); err == nil && len(existingTaskID) > 0 { return nil, fmt.Errorf("a task is already running for this resource (ID: %s)", string(existingTaskID)) } taskID := fmt.Sprintf("%d-%d", time.Now().UnixNano(), keyGroupID) taskKey := s.getTaskDataKey(taskID) runningFlagKey := s.getIsRunningFlagKey(taskID) status := &Status{ ID: taskID, TaskType: taskType, IsRunning: true, ResourceID: resourceID, Total: total, StartedAt: time.Now(), } statusBytes, err := json.Marshal(status) if err != nil { return nil, fmt.Errorf("failed to serialize new task status: %w", err) } if timeout == 0 { timeout = ResultTTL * 24 } if err := s.store.Set(lockKey, []byte(taskID), timeout); err != nil { return nil, fmt.Errorf("failed to acquire task resource lock: %w", err) } if err := s.store.Set(taskKey, statusBytes, timeout); err != nil { _ = s.store.Del(lockKey) return nil, fmt.Errorf("failed to set new task data in store: %w", err) } // 創建一個獨立的“運行中”標記,它的存在與否是原子性的 if err := s.store.Set(runningFlagKey, []byte("1"), timeout); err != nil { _ = s.store.Del(lockKey) _ = s.store.Del(taskKey) return nil, fmt.Errorf("failed to set task running flag: %w", err) } return status, nil } func (s *Task) EndTaskByID(taskID, resourceID string, resultData any, taskErr error) { lockKey := s.getResourceLockKey(resourceID) defer func() { if err := s.store.Del(lockKey); err != nil { s.logger.WithError(err).Warnf("Failed to release resource lock '%s' for task %s.", lockKey, taskID) } }() runningFlagKey := s.getIsRunningFlagKey(taskID) _ = s.store.Del(runningFlagKey) status, err := s.GetStatus(taskID) if err != nil { s.logger.WithError(err).Errorf("Could not get task status for task ID %s during EndTask. Lock has been released, but task data may be stale.", taskID) return } if !status.IsRunning { s.logger.Warnf("EndTaskByID called for an already finished task: %s", taskID) return } now := time.Now() status.IsRunning = false status.FinishedAt = &now status.DurationSeconds = now.Sub(status.StartedAt).Seconds() if taskErr != nil { status.Error = taskErr.Error() } else { status.Result = resultData } updatedTaskBytes, _ := json.Marshal(status) taskKey := s.getTaskDataKey(taskID) if err := s.store.Set(taskKey, updatedTaskBytes, ResultTTL); err != nil { s.logger.WithError(err).Errorf("Failed to save final status for task %s.", taskID) } } // GetStatus 通过ID获取任务状态,供外部(如API Handler)调用 func (s *Task) GetStatus(taskID string) (*Status, error) { taskKey := s.getTaskDataKey(taskID) statusBytes, err := s.store.Get(taskKey) if err != nil { if errors.Is(err, store.ErrNotFound) { return nil, errors.New("task not found") } return nil, fmt.Errorf("failed to get task status from store: %w", err) } var status Status if err := json.Unmarshal(statusBytes, &status); err != nil { return nil, fmt.Errorf("corrupted task data in store for ID %s", taskID) } if !status.IsRunning && status.FinishedAt != nil { status.DurationSeconds = status.FinishedAt.Sub(status.StartedAt).Seconds() } return &status, nil } // UpdateProgressByID 通过ID更新任务进度 func (s *Task) updateTask(taskID string, updater func(status *Status)) error { runningFlagKey := s.getIsRunningFlagKey(taskID) if _, err := s.store.Get(runningFlagKey); err != nil { // 任务已结束,静默返回是预期行为。 return nil } status, err := s.GetStatus(taskID) if err != nil { s.logger.WithError(err).Warnf("Failed to get task status for update on task %s. Update will not be saved.", taskID) return nil } if !status.IsRunning { return nil } // 调用传入的 updater 函数来修改 status updater(status) statusBytes, marshalErr := json.Marshal(status) if marshalErr != nil { s.logger.WithError(marshalErr).Errorf("Failed to serialize status for update on task %s.", taskID) return nil } taskKey := s.getTaskDataKey(taskID) // 使用更长的TTL,确保运行中的任务不会过早过期 if err := s.store.Set(taskKey, statusBytes, ResultTTL*24); err != nil { s.logger.WithError(err).Warnf("Failed to save update for task %s.", taskID) } return nil } // [REFACTORED] UpdateProgressByID 现在是一个简单的、调用通用更新器的包装器。 func (s *Task) UpdateProgressByID(taskID string, processed int) error { return s.updateTask(taskID, func(status *Status) { status.Processed = processed }) } // [REFACTORED] UpdateTotalByID 现在也是一个简单的、调用通用更新器的包装器。 func (s *Task) UpdateTotalByID(taskID string, total int) error { return s.updateTask(taskID, func(status *Status) { status.Total = total }) }