// Filename: internal/task/task.go package task import ( "context" "encoding/json" "errors" "fmt" "gemini-balancer/internal/store" "time" "github.com/sirupsen/logrus" ) const ( ResultTTL = 60 * time.Minute ) type Reporter interface { StartTask(ctx context.Context, keyGroupID uint, taskType, resourceID string, total int, timeout time.Duration) (*Status, error) EndTaskByID(ctx context.Context, taskID, resourceID string, result any, taskErr error) UpdateProgressByID(ctx context.Context, taskID string, processed int) error UpdateTotalByID(ctx context.Context, taskID string, total int) error } type Status struct { ID string `json:"id"` TaskType string `json:"task_type"` IsRunning bool `json:"is_running"` ResourceID string `json:"resource_id,omitempty"` Processed int `json:"processed"` Total int `json:"total"` Result any `json:"result,omitempty"` Error string `json:"error,omitempty"` StartedAt time.Time `json:"started_at"` FinishedAt *time.Time `json:"finished_at,omitempty"` DurationSeconds float64 `json:"duration_seconds,omitempty"` } type Task struct { store store.Store logger *logrus.Entry } func NewTask(store store.Store, logger *logrus.Logger) *Task { return &Task{ store: store, logger: logger.WithField("component", "TaskServiceđź“‹"), } } var _ Reporter = (*Task)(nil) func (s *Task) getResourceLockKey(resourceID string) string { return fmt.Sprintf("task:lock:%s", resourceID) } func (s *Task) getTaskDataKey(taskID string) string { return fmt.Sprintf("task:data:%s", taskID) } func (s *Task) getIsRunningFlagKey(taskID string) string { return fmt.Sprintf("task:running:%s", taskID) } func (s *Task) StartTask(ctx context.Context, keyGroupID uint, taskType, resourceID string, total int, timeout time.Duration) (*Status, error) { lockKey := s.getResourceLockKey(resourceID) if existingTaskID, err := s.store.Get(ctx, lockKey); err == nil && len(existingTaskID) > 0 { return nil, fmt.Errorf("a task is already running for this resource (ID: %s)", string(existingTaskID)) } taskID := fmt.Sprintf("%d-%d", time.Now().UnixNano(), keyGroupID) taskKey := s.getTaskDataKey(taskID) runningFlagKey := s.getIsRunningFlagKey(taskID) status := &Status{ ID: taskID, TaskType: taskType, IsRunning: true, ResourceID: resourceID, Total: total, StartedAt: time.Now(), } statusBytes, err := json.Marshal(status) if err != nil { return nil, fmt.Errorf("failed to serialize new task status: %w", err) } if timeout == 0 { timeout = ResultTTL * 24 } if err := s.store.Set(ctx, lockKey, []byte(taskID), timeout); err != nil { return nil, fmt.Errorf("failed to acquire task resource lock: %w", err) } if err := s.store.Set(ctx, taskKey, statusBytes, timeout); err != nil { _ = s.store.Del(ctx, lockKey) return nil, fmt.Errorf("failed to set new task data in store: %w", err) } if err := s.store.Set(ctx, runningFlagKey, []byte("1"), timeout); err != nil { _ = s.store.Del(ctx, lockKey) _ = s.store.Del(ctx, taskKey) return nil, fmt.Errorf("failed to set task running flag: %w", err) } return status, nil } func (s *Task) EndTaskByID(ctx context.Context, taskID, resourceID string, resultData any, taskErr error) { lockKey := s.getResourceLockKey(resourceID) defer func() { if err := s.store.Del(ctx, lockKey); err != nil { s.logger.WithError(err).Warnf("Failed to release resource lock '%s' for task %s.", lockKey, taskID) } }() runningFlagKey := s.getIsRunningFlagKey(taskID) _ = s.store.Del(ctx, runningFlagKey) status, err := s.GetStatus(ctx, taskID) if err != nil { s.logger.WithError(err).Errorf("Could not get task status for task ID %s during EndTask. Lock has been released, but task data may be stale.", taskID) return } if !status.IsRunning { s.logger.Warnf("EndTaskByID called for an already finished task: %s", taskID) return } now := time.Now() status.IsRunning = false status.FinishedAt = &now status.DurationSeconds = now.Sub(status.StartedAt).Seconds() if taskErr != nil { status.Error = taskErr.Error() } else { status.Result = resultData } updatedTaskBytes, _ := json.Marshal(status) taskKey := s.getTaskDataKey(taskID) if err := s.store.Set(ctx, taskKey, updatedTaskBytes, ResultTTL); err != nil { s.logger.WithError(err).Errorf("Failed to save final status for task %s.", taskID) } } func (s *Task) GetStatus(ctx context.Context, taskID string) (*Status, error) { taskKey := s.getTaskDataKey(taskID) statusBytes, err := s.store.Get(ctx, taskKey) if err != nil { if errors.Is(err, store.ErrNotFound) { return nil, errors.New("task not found") } return nil, fmt.Errorf("failed to get task status from store: %w", err) } var status Status if err := json.Unmarshal(statusBytes, &status); err != nil { return nil, fmt.Errorf("corrupted task data in store for ID %s", taskID) } if !status.IsRunning && status.FinishedAt != nil { status.DurationSeconds = status.FinishedAt.Sub(status.StartedAt).Seconds() } return &status, nil } func (s *Task) updateTask(ctx context.Context, taskID string, updater func(status *Status)) error { runningFlagKey := s.getIsRunningFlagKey(taskID) if _, err := s.store.Get(ctx, runningFlagKey); err != nil { return nil } status, err := s.GetStatus(ctx, taskID) if err != nil { s.logger.WithError(err).Warnf("Failed to get task status for update on task %s. Update will not be saved.", taskID) return nil } if !status.IsRunning { return nil } updater(status) statusBytes, marshalErr := json.Marshal(status) if marshalErr != nil { s.logger.WithError(marshalErr).Errorf("Failed to serialize status for update on task %s.", taskID) return nil } taskKey := s.getTaskDataKey(taskID) if err := s.store.Set(ctx, taskKey, statusBytes, ResultTTL*24); err != nil { s.logger.WithError(err).Warnf("Failed to save update for task %s.", taskID) } return nil } func (s *Task) UpdateProgressByID(ctx context.Context, taskID string, processed int) error { return s.updateTask(ctx, taskID, func(status *Status) { status.Processed = processed }) } func (s *Task) UpdateTotalByID(ctx context.Context, taskID string, total int) error { return s.updateTask(ctx, taskID, func(status *Status) { status.Total = total }) }