214 lines
6.0 KiB
Go
214 lines
6.0 KiB
Go
package task
|
|
|
|
import (
|
|
"context"
|
|
"encoding/json"
|
|
"errors"
|
|
"fmt"
|
|
"gemini-balancer/internal/store"
|
|
"time"
|
|
|
|
"github.com/sirupsen/logrus"
|
|
)
|
|
|
|
// Lifetimes of the keys Task writes to the store.
const (
	// ResultTTL is how long a finished task's status stays stored after
	// EndTaskByID records its outcome.
	ResultTTL = time.Hour

	// DefaultTimeout is the TTL applied to a running task's status and running
	// flag when StartTask receives no timeout; updateTask also re-saves the
	// status with this TTL.
	DefaultTimeout = 24 * time.Hour

	// LockTTL is the TTL of the per-resource lock acquired by StartTask.
	LockTTL = 30 * time.Minute
)
|
|
|
|
type Reporter interface {
|
|
StartTask(ctx context.Context, keyGroupID uint, taskType, resourceID string, total int, timeout time.Duration) (*Status, error)
|
|
EndTaskByID(ctx context.Context, taskID, resourceID string, result any, taskErr error)
|
|
UpdateProgressByID(ctx context.Context, taskID string, processed int) error
|
|
UpdateTotalByID(ctx context.Context, taskID string, total int) error
|
|
}
|
|
|
|
// Status is the JSON snapshot of a task that Task persists under the task's
// data key. Fields tagged omitempty are dropped from the JSON when unset.
// Field order determines the JSON key order.
type Status struct {
	// Identity: which task this is and which resource it holds.
	ID         string `json:"id"`
	TaskType   string `json:"task_type"`
	IsRunning  bool   `json:"is_running"`
	ResourceID string `json:"resource_id,omitempty"`

	// Progress counters.
	Processed int `json:"processed"`
	Total     int `json:"total"`

	// Outcome: Result on success, Error (the error's message) on failure.
	Result any    `json:"result,omitempty"`
	Error  string `json:"error,omitempty"`

	// Timing. FinishedAt stays nil while the task is running.
	StartedAt       time.Time  `json:"started_at"`
	FinishedAt      *time.Time `json:"finished_at,omitempty"`
	DurationSeconds float64    `json:"duration_seconds,omitempty"`
}
|
|
|
|
type Task struct {
|
|
store store.Store
|
|
logger *logrus.Entry
|
|
}
|
|
|
|
func NewTask(store store.Store, logger *logrus.Logger) *Task {
|
|
return &Task{
|
|
store: store,
|
|
logger: logger.WithField("component", "TaskService📋"),
|
|
}
|
|
}
|
|
|
|
// Compile-time proof that *Task satisfies Reporter; the typed nil costs nothing.
var (
	_ Reporter = (*Task)(nil)
)
|
|
|
|
func (s *Task) getResourceLockKey(resourceID string) string {
|
|
return fmt.Sprintf("task:lock:%s", resourceID)
|
|
}
|
|
|
|
func (s *Task) getTaskDataKey(taskID string) string {
|
|
return fmt.Sprintf("task:data:%s", taskID)
|
|
}
|
|
|
|
func (s *Task) getIsRunningFlagKey(taskID string) string {
|
|
return fmt.Sprintf("task:running:%s", taskID)
|
|
}
|
|
|
|
func (s *Task) StartTask(ctx context.Context, keyGroupID uint, taskType, resourceID string, total int, timeout time.Duration) (*Status, error) {
|
|
lockKey := s.getResourceLockKey(resourceID)
|
|
taskID := fmt.Sprintf("%d-%d", time.Now().UnixNano(), keyGroupID)
|
|
|
|
locked, err := s.store.SetNX(ctx, lockKey, []byte(taskID), LockTTL)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to acquire task lock: %w", err)
|
|
}
|
|
if !locked {
|
|
existingTaskID, _ := s.store.Get(ctx, lockKey)
|
|
return nil, fmt.Errorf("a task is already running for this resource (ID: %s)", string(existingTaskID))
|
|
}
|
|
|
|
if timeout == 0 {
|
|
timeout = DefaultTimeout
|
|
}
|
|
|
|
status := &Status{
|
|
ID: taskID,
|
|
TaskType: taskType,
|
|
IsRunning: true,
|
|
ResourceID: resourceID,
|
|
Total: total,
|
|
StartedAt: time.Now(),
|
|
}
|
|
|
|
if err := s.saveStatus(ctx, taskID, status, timeout); err != nil {
|
|
_ = s.store.Del(ctx, lockKey)
|
|
return nil, fmt.Errorf("failed to save task status: %w", err)
|
|
}
|
|
|
|
runningFlagKey := s.getIsRunningFlagKey(taskID)
|
|
if err := s.store.Set(ctx, runningFlagKey, []byte("1"), timeout); err != nil {
|
|
_ = s.store.Del(ctx, lockKey)
|
|
_ = s.store.Del(ctx, s.getTaskDataKey(taskID))
|
|
return nil, fmt.Errorf("failed to set running flag: %w", err)
|
|
}
|
|
|
|
return status, nil
|
|
}
|
|
|
|
func (s *Task) EndTaskByID(ctx context.Context, taskID, resourceID string, resultData any, taskErr error) {
|
|
lockKey := s.getResourceLockKey(resourceID)
|
|
runningFlagKey := s.getIsRunningFlagKey(taskID)
|
|
|
|
defer func() {
|
|
_ = s.store.Del(ctx, lockKey)
|
|
_ = s.store.Del(ctx, runningFlagKey)
|
|
}()
|
|
|
|
status, err := s.GetStatus(ctx, taskID)
|
|
if err != nil {
|
|
s.logger.WithError(err).Errorf("Failed to get task status for %s during EndTask", taskID)
|
|
return
|
|
}
|
|
|
|
if !status.IsRunning {
|
|
s.logger.Warnf("EndTaskByID called for already finished task: %s", taskID)
|
|
return
|
|
}
|
|
|
|
now := time.Now()
|
|
status.IsRunning = false
|
|
status.FinishedAt = &now
|
|
status.DurationSeconds = now.Sub(status.StartedAt).Seconds()
|
|
|
|
if taskErr != nil {
|
|
status.Error = taskErr.Error()
|
|
} else {
|
|
status.Result = resultData
|
|
}
|
|
|
|
if err := s.saveStatus(ctx, taskID, status, ResultTTL); err != nil {
|
|
s.logger.WithError(err).Errorf("Failed to save final status for task %s", taskID)
|
|
}
|
|
}
|
|
|
|
func (s *Task) GetStatus(ctx context.Context, taskID string) (*Status, error) {
|
|
taskKey := s.getTaskDataKey(taskID)
|
|
statusBytes, err := s.store.Get(ctx, taskKey)
|
|
if err != nil {
|
|
if errors.Is(err, store.ErrNotFound) {
|
|
return nil, errors.New("task not found")
|
|
}
|
|
return nil, fmt.Errorf("failed to get task status: %w", err)
|
|
}
|
|
|
|
var status Status
|
|
if err := json.Unmarshal(statusBytes, &status); err != nil {
|
|
return nil, fmt.Errorf("corrupted task data for ID %s: %w", taskID, err)
|
|
}
|
|
|
|
if !status.IsRunning && status.FinishedAt != nil {
|
|
status.DurationSeconds = status.FinishedAt.Sub(status.StartedAt).Seconds()
|
|
}
|
|
|
|
return &status, nil
|
|
}
|
|
|
|
func (s *Task) updateTask(ctx context.Context, taskID string, updater func(status *Status)) error {
|
|
runningFlagKey := s.getIsRunningFlagKey(taskID)
|
|
if _, err := s.store.Get(ctx, runningFlagKey); err != nil {
|
|
if errors.Is(err, store.ErrNotFound) {
|
|
return errors.New("task is not running")
|
|
}
|
|
return fmt.Errorf("failed to check running flag: %w", err)
|
|
}
|
|
|
|
status, err := s.GetStatus(ctx, taskID)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to get task status: %w", err)
|
|
}
|
|
|
|
if !status.IsRunning {
|
|
return errors.New("task is not running")
|
|
}
|
|
|
|
updater(status)
|
|
|
|
return s.saveStatus(ctx, taskID, status, DefaultTimeout)
|
|
}
|
|
|
|
func (s *Task) UpdateProgressByID(ctx context.Context, taskID string, processed int) error {
|
|
return s.updateTask(ctx, taskID, func(status *Status) {
|
|
status.Processed = processed
|
|
})
|
|
}
|
|
|
|
func (s *Task) UpdateTotalByID(ctx context.Context, taskID string, total int) error {
|
|
return s.updateTask(ctx, taskID, func(status *Status) {
|
|
status.Total = total
|
|
})
|
|
}
|
|
|
|
func (s *Task) saveStatus(ctx context.Context, taskID string, status *Status, ttl time.Duration) error {
|
|
statusBytes, err := json.Marshal(status)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to serialize status: %w", err)
|
|
}
|
|
|
|
taskKey := s.getTaskDataKey(taskID)
|
|
if err := s.store.Set(ctx, taskKey, statusBytes, ttl); err != nil {
|
|
return fmt.Errorf("failed to save status: %w", err)
|
|
}
|
|
|
|
return nil
|
|
}
|