// Filename: internal/task/task.go (final calibrated version)
package task
|
||
|
||
import (
|
||
"encoding/json"
|
||
"errors"
|
||
"fmt"
|
||
"gemini-balancer/internal/store"
|
||
"time"
|
||
|
||
"github.com/sirupsen/logrus"
|
||
)
|
||
|
||
// ResultTTL is how long EndTaskByID keeps a finished task's final status in
// the store. Running tasks use ResultTTL*24 as their default retention period
// (see StartTask and updateTask).
const ResultTTL = time.Hour
|
||
|
||
// Reporter 接口,定义了领域如何与任务服务交互。
|
||
type Reporter interface {
|
||
StartTask(keyGroupID uint, taskType, resourceID string, total int, timeout time.Duration) (*Status, error)
|
||
EndTaskByID(taskID, resourceID string, result any, taskErr error)
|
||
UpdateProgressByID(taskID string, processed int) error
|
||
UpdateTotalByID(taskID string, total int) error
|
||
}
|
||
|
||
// Status is the complete state of one background task. Task stores it as
// JSON, so the json tags below define the persisted (and API-visible) encoding.
type Status struct {
	// ID identifies the task; StartTask builds it as "<unix-nanoseconds>-<keyGroupID>".
	ID string `json:"id"`
	// TaskType is the caller-supplied kind of task.
	TaskType string `json:"task_type"`
	// IsRunning is true from StartTask until EndTaskByID records the outcome.
	IsRunning bool `json:"is_running"`
	// ResourceID names the resource whose lock the task holds while running.
	ResourceID string `json:"resource_id,omitempty"`
	// Processed is the number of work items reported via UpdateProgressByID.
	Processed int `json:"processed"`
	// Total is the expected number of work items (set at start, adjustable via UpdateTotalByID).
	Total int `json:"total"`
	// Result is set by EndTaskByID when the task succeeded; it stays empty on failure.
	Result any `json:"result,omitempty"`
	// Error holds the failure message when the task ended with an error.
	Error string `json:"error,omitempty"`
	// StartedAt is when StartTask created the task.
	StartedAt time.Time `json:"started_at"`
	// FinishedAt is nil while the task runs and is set by EndTaskByID.
	FinishedAt *time.Time `json:"finished_at,omitempty"`
	// DurationSeconds is FinishedAt minus StartedAt in seconds; it is zero (and
	// therefore omitted from JSON) while the task is still running.
	DurationSeconds float64 `json:"duration_seconds,omitempty"`
}
|
||
|
||
// Task is the core task-management service. It keeps task status, resource
// locks and running flags in a key-value store and implements Reporter.
type Task struct {
	// store holds all task keys ("task:lock:*", "task:data:*", "task:running:*").
	store store.Store
	// logger is tagged with component=TaskService📋 by NewTask.
	logger *logrus.Entry
}
|
||
|
||
// NewTask 是 Task 的构造函数
|
||
func NewTask(store store.Store, logger *logrus.Logger) *Task {
|
||
return &Task{
|
||
store: store,
|
||
logger: logger.WithField("component", "TaskService📋"),
|
||
}
|
||
}
|
||
|
||
// Compile-time check that *Task satisfies the Reporter interface.
var _ Reporter = (*Task)(nil)
|
||
|
||
func (s *Task) getResourceLockKey(resourceID string) string {
|
||
return fmt.Sprintf("task:lock:%s", resourceID)
|
||
}
|
||
|
||
func (s *Task) getTaskDataKey(taskID string) string {
|
||
return fmt.Sprintf("task:data:%s", taskID)
|
||
}
|
||
|
||
// --- 新增的輔助函數,用於獲取原子標記的鍵 ---
|
||
func (s *Task) getIsRunningFlagKey(taskID string) string {
|
||
return fmt.Sprintf("task:running:%s", taskID)
|
||
}
|
||
|
||
func (s *Task) StartTask(keyGroupID uint, taskType, resourceID string, total int, timeout time.Duration) (*Status, error) {
|
||
lockKey := s.getResourceLockKey(resourceID)
|
||
|
||
if existingTaskID, err := s.store.Get(lockKey); err == nil && len(existingTaskID) > 0 {
|
||
return nil, fmt.Errorf("a task is already running for this resource (ID: %s)", string(existingTaskID))
|
||
}
|
||
|
||
taskID := fmt.Sprintf("%d-%d", time.Now().UnixNano(), keyGroupID)
|
||
taskKey := s.getTaskDataKey(taskID)
|
||
runningFlagKey := s.getIsRunningFlagKey(taskID)
|
||
status := &Status{
|
||
ID: taskID,
|
||
TaskType: taskType,
|
||
IsRunning: true,
|
||
ResourceID: resourceID,
|
||
Total: total,
|
||
StartedAt: time.Now(),
|
||
}
|
||
statusBytes, err := json.Marshal(status)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("failed to serialize new task status: %w", err)
|
||
}
|
||
|
||
if timeout == 0 {
|
||
timeout = ResultTTL * 24
|
||
}
|
||
|
||
if err := s.store.Set(lockKey, []byte(taskID), timeout); err != nil {
|
||
return nil, fmt.Errorf("failed to acquire task resource lock: %w", err)
|
||
}
|
||
if err := s.store.Set(taskKey, statusBytes, timeout); err != nil {
|
||
_ = s.store.Del(lockKey)
|
||
return nil, fmt.Errorf("failed to set new task data in store: %w", err)
|
||
}
|
||
|
||
// 創建一個獨立的“運行中”標記,它的存在與否是原子性的
|
||
if err := s.store.Set(runningFlagKey, []byte("1"), timeout); err != nil {
|
||
_ = s.store.Del(lockKey)
|
||
_ = s.store.Del(taskKey)
|
||
return nil, fmt.Errorf("failed to set task running flag: %w", err)
|
||
}
|
||
return status, nil
|
||
}
|
||
|
||
func (s *Task) EndTaskByID(taskID, resourceID string, resultData any, taskErr error) {
|
||
lockKey := s.getResourceLockKey(resourceID)
|
||
defer func() {
|
||
if err := s.store.Del(lockKey); err != nil {
|
||
s.logger.WithError(err).Warnf("Failed to release resource lock '%s' for task %s.", lockKey, taskID)
|
||
}
|
||
}()
|
||
runningFlagKey := s.getIsRunningFlagKey(taskID)
|
||
_ = s.store.Del(runningFlagKey)
|
||
status, err := s.GetStatus(taskID)
|
||
if err != nil {
|
||
|
||
s.logger.WithError(err).Errorf("Could not get task status for task ID %s during EndTask. Lock has been released, but task data may be stale.", taskID)
|
||
return
|
||
}
|
||
if !status.IsRunning {
|
||
s.logger.Warnf("EndTaskByID called for an already finished task: %s", taskID)
|
||
return
|
||
}
|
||
now := time.Now()
|
||
status.IsRunning = false
|
||
status.FinishedAt = &now
|
||
status.DurationSeconds = now.Sub(status.StartedAt).Seconds()
|
||
if taskErr != nil {
|
||
status.Error = taskErr.Error()
|
||
} else {
|
||
status.Result = resultData
|
||
}
|
||
updatedTaskBytes, _ := json.Marshal(status)
|
||
taskKey := s.getTaskDataKey(taskID)
|
||
if err := s.store.Set(taskKey, updatedTaskBytes, ResultTTL); err != nil {
|
||
s.logger.WithError(err).Errorf("Failed to save final status for task %s.", taskID)
|
||
}
|
||
}
|
||
|
||
// GetStatus 通过ID获取任务状态,供外部(如API Handler)调用
|
||
func (s *Task) GetStatus(taskID string) (*Status, error) {
|
||
taskKey := s.getTaskDataKey(taskID)
|
||
statusBytes, err := s.store.Get(taskKey)
|
||
if err != nil {
|
||
if errors.Is(err, store.ErrNotFound) {
|
||
return nil, errors.New("task not found")
|
||
}
|
||
return nil, fmt.Errorf("failed to get task status from store: %w", err)
|
||
}
|
||
|
||
var status Status
|
||
if err := json.Unmarshal(statusBytes, &status); err != nil {
|
||
return nil, fmt.Errorf("corrupted task data in store for ID %s", taskID)
|
||
}
|
||
|
||
if !status.IsRunning && status.FinishedAt != nil {
|
||
status.DurationSeconds = status.FinishedAt.Sub(status.StartedAt).Seconds()
|
||
}
|
||
|
||
return &status, nil
|
||
}
|
||
|
||
// UpdateProgressByID 通过ID更新任务进度
|
||
func (s *Task) updateTask(taskID string, updater func(status *Status)) error {
|
||
runningFlagKey := s.getIsRunningFlagKey(taskID)
|
||
if _, err := s.store.Get(runningFlagKey); err != nil {
|
||
// 任务已结束,静默返回是预期行为。
|
||
return nil
|
||
}
|
||
status, err := s.GetStatus(taskID)
|
||
if err != nil {
|
||
s.logger.WithError(err).Warnf("Failed to get task status for update on task %s. Update will not be saved.", taskID)
|
||
return nil
|
||
}
|
||
if !status.IsRunning {
|
||
return nil
|
||
}
|
||
// 调用传入的 updater 函数来修改 status
|
||
updater(status)
|
||
statusBytes, marshalErr := json.Marshal(status)
|
||
if marshalErr != nil {
|
||
s.logger.WithError(marshalErr).Errorf("Failed to serialize status for update on task %s.", taskID)
|
||
return nil
|
||
}
|
||
taskKey := s.getTaskDataKey(taskID)
|
||
// 使用更长的TTL,确保运行中的任务不会过早过期
|
||
if err := s.store.Set(taskKey, statusBytes, ResultTTL*24); err != nil {
|
||
s.logger.WithError(err).Warnf("Failed to save update for task %s.", taskID)
|
||
}
|
||
return nil
|
||
}
|
||
|
||
// [REFACTORED] UpdateProgressByID 现在是一个简单的、调用通用更新器的包装器。
|
||
func (s *Task) UpdateProgressByID(taskID string, processed int) error {
|
||
return s.updateTask(taskID, func(status *Status) {
|
||
status.Processed = processed
|
||
})
|
||
}
|
||
|
||
// [REFACTORED] UpdateTotalByID 现在也是一个简单的、调用通用更新器的包装器。
|
||
func (s *Task) UpdateTotalByID(taskID string, total int) error {
|
||
return s.updateTask(taskID, func(status *Status) {
|
||
status.Total = total
|
||
})
|
||
}
|