// Package task records the lifecycle of background tasks (start, progress,
// completion) in a shared store.Store and uses a store-backed lock so that at
// most one task holds a given resource at a time.
package task

import (
	"context"
	"encoding/json"
	"errors"
	"fmt"
	"time"

	"gemini-balancer/internal/store"

	"github.com/sirupsen/logrus"
)

const (
	// ResultTTL is how long a finished task's final status is kept in the store.
	ResultTTL = 60 * time.Minute
	// DefaultTimeout is the TTL of a task's status and running flag when
	// StartTask is given a non-positive timeout. It is also the TTL written on
	// every progress/total update.
	DefaultTimeout = 24 * time.Hour
	// LockTTL is the expiry of the per-resource lock acquired by StartTask.
	LockTTL = 30 * time.Minute
)

var (
	// ErrTaskNotFound is returned by GetStatus when no status is stored for
	// the given task ID.
	ErrTaskNotFound = errors.New("task not found")
	// ErrTaskNotRunning is returned by progress/total updates when the task's
	// running flag is gone or its status is no longer marked running.
	ErrTaskNotRunning = errors.New("task is not running")
)

// Reporter is the interface used to start, update and finish tasks.
type Reporter interface {
	StartTask(ctx context.Context, keyGroupID uint, taskType, resourceID string, total int, timeout time.Duration) (*Status, error)
	EndTaskByID(ctx context.Context, taskID, resourceID string, result any, taskErr error)
	UpdateProgressByID(ctx context.Context, taskID string, processed int) error
	UpdateTotalByID(ctx context.Context, taskID string, total int) error
}

// Status is the JSON-serialized state of a task as stored under its data key.
type Status struct {
	ID              string     `json:"id"`
	TaskType        string     `json:"task_type"`
	IsRunning       bool       `json:"is_running"`
	ResourceID      string     `json:"resource_id,omitempty"`
	Processed       int        `json:"processed"`
	Total           int        `json:"total"`
	Result          any        `json:"result,omitempty"`
	Error           string     `json:"error,omitempty"`
	StartedAt       time.Time  `json:"started_at"`
	FinishedAt      *time.Time `json:"finished_at,omitempty"`
	DurationSeconds float64    `json:"duration_seconds,omitempty"`
}

// Task implements Reporter on top of a store.Store.
type Task struct {
	store  store.Store
	logger *logrus.Entry
}

// NewTask returns a Task that persists state in st and logs through logger
// tagged with a "component" field.
func NewTask(st store.Store, logger *logrus.Logger) *Task {
	return &Task{
		store:  st,
		logger: logger.WithField("component", "TaskService📋"),
	}
}

var _ Reporter = (*Task)(nil)

// getResourceLockKey returns the store key of the lock guarding resourceID.
func (s *Task) getResourceLockKey(resourceID string) string {
	return fmt.Sprintf("task:lock:%s", resourceID)
}

// getTaskDataKey returns the store key holding the JSON Status of taskID.
func (s *Task) getTaskDataKey(taskID string) string {
	return fmt.Sprintf("task:data:%s", taskID)
}

// getIsRunningFlagKey returns the store key whose presence marks taskID as running.
func (s *Task) getIsRunningFlagKey(taskID string) string {
	return fmt.Sprintf("task:running:%s", taskID)
}

// StartTask registers a new running task for resourceID.
//
// It acquires the per-resource lock (expiring after LockTTL) with the new task
// ID as its value, then stores the initial Status and a running flag, both
// expiring after timeout (DefaultTimeout when timeout is not positive).
// If the lock is already held, an error naming the holder's task ID is
// returned. If a later step fails, the keys written so far are deleted so
// the resource is not left locked.
func (s *Task) StartTask(ctx context.Context, keyGroupID uint, taskType, resourceID string, total int, timeout time.Duration) (*Status, error) {
	lockKey := s.getResourceLockKey(resourceID)
	// The task ID is stored as the lock value so the owner can be reported to
	// competing callers and verified when the lock is released.
	taskID := fmt.Sprintf("%d-%d", time.Now().UnixNano(), keyGroupID)

	locked, err := s.store.SetNX(ctx, lockKey, []byte(taskID), LockTTL)
	if err != nil {
		return nil, fmt.Errorf("failed to acquire task lock: %w", err)
	}
	if !locked {
		// Best effort only: the holder may have released the lock since the
		// SetNX, in which case the reported ID is empty.
		existingTaskID, _ := s.store.Get(ctx, lockKey)
		return nil, fmt.Errorf("a task is already running for this resource (ID: %s)", string(existingTaskID))
	}

	// A negative duration is not a meaningful TTL; treat it like "unset".
	if timeout <= 0 {
		timeout = DefaultTimeout
	}

	status := &Status{
		ID:         taskID,
		TaskType:   taskType,
		IsRunning:  true,
		ResourceID: resourceID,
		Total:      total,
		StartedAt:  time.Now(),
	}
	if err := s.saveStatus(ctx, taskID, status, timeout); err != nil {
		// Roll back the lock we just acquired; the save error is what the
		// caller needs, so cleanup errors are deliberately ignored.
		_ = s.store.Del(ctx, lockKey)
		return nil, fmt.Errorf("failed to save task status: %w", err)
	}

	runningFlagKey := s.getIsRunningFlagKey(taskID)
	if err := s.store.Set(ctx, runningFlagKey, []byte("1"), timeout); err != nil {
		// Roll back both the lock and the status written above (best effort).
		_ = s.store.Del(ctx, lockKey)
		_ = s.store.Del(ctx, s.getTaskDataKey(taskID))
		return nil, fmt.Errorf("failed to set running flag: %w", err)
	}
	return status, nil
}

// EndTaskByID marks taskID as finished and records either taskErr (when
// non-nil) or resultData, keeping the final status for ResultTTL.
//
// Regardless of outcome, it deletes the task's running flag and releases the
// resource lock if this task still owns it. Failures are logged, not returned.
func (s *Task) EndTaskByID(ctx context.Context, taskID, resourceID string, resultData any, taskErr error) {
	lockKey := s.getResourceLockKey(resourceID)
	runningFlagKey := s.getIsRunningFlagKey(taskID)
	defer func() {
		s.releaseLock(ctx, lockKey, taskID)
		_ = s.store.Del(ctx, runningFlagKey)
	}()

	status, err := s.GetStatus(ctx, taskID)
	if err != nil {
		s.logger.WithError(err).Errorf("Failed to get task status for %s during EndTask", taskID)
		return
	}
	if !status.IsRunning {
		s.logger.Warnf("EndTaskByID called for already finished task: %s", taskID)
		return
	}

	now := time.Now()
	status.IsRunning = false
	status.FinishedAt = &now
	status.DurationSeconds = now.Sub(status.StartedAt).Seconds()
	if taskErr != nil {
		status.Error = taskErr.Error()
	} else {
		status.Result = resultData
	}
	if err := s.saveStatus(ctx, taskID, status, ResultTTL); err != nil {
		s.logger.WithError(err).Errorf("Failed to save final status for task %s", taskID)
	}
}

// releaseLock deletes lockKey only while it still holds taskID.
//
// The lock expires after LockTTL, which can be shorter than a task's
// lifetime. After expiry another task may acquire the same lock, and an
// unconditional delete here would release that other task's lock. The Get and
// Del are two separate store calls, so a narrow race remains, but a stale
// owner no longer clears a lock it has lost. If ownership cannot be read, the
// lock is left to expire rather than risk deleting someone else's.
func (s *Task) releaseLock(ctx context.Context, lockKey, taskID string) {
	owner, err := s.store.Get(ctx, lockKey)
	switch {
	case errors.Is(err, store.ErrNotFound):
		return // already expired or released
	case err != nil:
		s.logger.WithError(err).Warnf("Failed to read lock %s for task %s; leaving it to expire", lockKey, taskID)
		return
	case string(owner) != taskID:
		s.logger.Warnf("Lock %s is held by task %s; not releasing it for task %s", lockKey, string(owner), taskID)
		return
	}
	if err := s.store.Del(ctx, lockKey); err != nil {
		s.logger.WithError(err).Warnf("Failed to release lock %s for task %s", lockKey, taskID)
	}
}

// GetStatus loads and decodes the stored Status of taskID.
//
// It returns ErrTaskNotFound when no status is stored. For finished tasks,
// DurationSeconds is recomputed from StartedAt and FinishedAt.
func (s *Task) GetStatus(ctx context.Context, taskID string) (*Status, error) {
	statusBytes, err := s.store.Get(ctx, s.getTaskDataKey(taskID))
	if err != nil {
		if errors.Is(err, store.ErrNotFound) {
			return nil, ErrTaskNotFound
		}
		return nil, fmt.Errorf("failed to get task status: %w", err)
	}

	var status Status
	if err := json.Unmarshal(statusBytes, &status); err != nil {
		return nil, fmt.Errorf("corrupted task data for ID %s: %w", taskID, err)
	}
	if !status.IsRunning && status.FinishedAt != nil {
		status.DurationSeconds = status.FinishedAt.Sub(status.StartedAt).Seconds()
	}
	return &status, nil
}

// updateTask applies updater to the status of a running task and saves it.
//
// It returns ErrTaskNotRunning if the running flag is missing or the stored
// status is not marked running. The load/modify/save sequence is not atomic,
// so concurrent updates to the same task can overwrite each other.
// NOTE(review): the status is re-saved with DefaultTimeout, not the timeout
// given to StartTask (which is not recorded), so a custom timeout's TTL is
// replaced on every update.
func (s *Task) updateTask(ctx context.Context, taskID string, updater func(status *Status)) error {
	if _, err := s.store.Get(ctx, s.getIsRunningFlagKey(taskID)); err != nil {
		if errors.Is(err, store.ErrNotFound) {
			return ErrTaskNotRunning
		}
		return fmt.Errorf("failed to check running flag: %w", err)
	}

	status, err := s.GetStatus(ctx, taskID)
	if err != nil {
		return fmt.Errorf("failed to get task status: %w", err)
	}
	if !status.IsRunning {
		return ErrTaskNotRunning
	}

	updater(status)
	return s.saveStatus(ctx, taskID, status, DefaultTimeout)
}

// UpdateProgressByID sets the processed count of a running task.
func (s *Task) UpdateProgressByID(ctx context.Context, taskID string, processed int) error {
	return s.updateTask(ctx, taskID, func(status *Status) {
		status.Processed = processed
	})
}

// UpdateTotalByID sets the total count of a running task.
func (s *Task) UpdateTotalByID(ctx context.Context, taskID string, total int) error {
	return s.updateTask(ctx, taskID, func(status *Status) {
		status.Total = total
	})
}

// saveStatus JSON-encodes status and writes it to taskID's data key with the
// given TTL.
func (s *Task) saveStatus(ctx context.Context, taskID string, status *Status, ttl time.Duration) error {
	statusBytes, err := json.Marshal(status)
	if err != nil {
		return fmt.Errorf("failed to serialize status: %w", err)
	}
	if err := s.store.Set(ctx, s.getTaskDataKey(taskID), statusBytes, ttl); err != nil {
		return fmt.Errorf("failed to save status: %w", err)
	}
	return nil
}