201 lines
6.2 KiB
Go
201 lines
6.2 KiB
Go
// Filename: internal/task/task.go
|
|
package task
|
|
|
|
import (
|
|
"context"
|
|
"encoding/json"
|
|
"errors"
|
|
"fmt"
|
|
"gemini-balancer/internal/store"
|
|
"time"
|
|
|
|
"github.com/sirupsen/logrus"
|
|
)
|
|
|
|
// ResultTTL is how long a finished task's status is kept in the store once
// EndTaskByID has saved it. Twenty-four times this value is the default
// lifetime of a running task's keys and the TTL applied when its progress is
// updated.
const ResultTTL = 60 * time.Minute
|
|
|
|
type Reporter interface {
|
|
StartTask(ctx context.Context, keyGroupID uint, taskType, resourceID string, total int, timeout time.Duration) (*Status, error)
|
|
EndTaskByID(ctx context.Context, taskID, resourceID string, result any, taskErr error)
|
|
UpdateProgressByID(ctx context.Context, taskID string, processed int) error
|
|
UpdateTotalByID(ctx context.Context, taskID string, total int) error
|
|
}
|
|
|
|
// Status is the JSON document stored for each task and returned by GetStatus.
//
// IsRunning is true from StartTask until EndTaskByID. Processed and Total are
// the progress counters. On completion FinishedAt is set and either Result
// (on success) or Error (on failure) is filled in. DurationSeconds is the time
// between StartedAt and FinishedAt; it stays zero, and is omitted from the
// JSON, while the task runs.
type Status struct {
	ID              string     `json:"id"`
	TaskType        string     `json:"task_type"`
	IsRunning       bool       `json:"is_running"`
	ResourceID      string     `json:"resource_id,omitempty"`
	Processed       int        `json:"processed"`
	Total           int        `json:"total"`
	Result          any        `json:"result,omitempty"`
	Error           string     `json:"error,omitempty"`
	StartedAt       time.Time  `json:"started_at"`
	FinishedAt      *time.Time `json:"finished_at,omitempty"`
	DurationSeconds float64    `json:"duration_seconds,omitempty"`
}
|
|
|
|
type Task struct {
|
|
store store.Store
|
|
logger *logrus.Entry
|
|
}
|
|
|
|
func NewTask(store store.Store, logger *logrus.Logger) *Task {
|
|
return &Task{
|
|
store: store,
|
|
logger: logger.WithField("component", "TaskService📋"),
|
|
}
|
|
}
|
|
|
|
var _ Reporter = (*Task)(nil)
|
|
|
|
func (s *Task) getResourceLockKey(resourceID string) string {
|
|
return fmt.Sprintf("task:lock:%s", resourceID)
|
|
}
|
|
|
|
func (s *Task) getTaskDataKey(taskID string) string {
|
|
return fmt.Sprintf("task:data:%s", taskID)
|
|
}
|
|
|
|
func (s *Task) getIsRunningFlagKey(taskID string) string {
|
|
return fmt.Sprintf("task:running:%s", taskID)
|
|
}
|
|
|
|
func (s *Task) StartTask(ctx context.Context, keyGroupID uint, taskType, resourceID string, total int, timeout time.Duration) (*Status, error) {
|
|
lockKey := s.getResourceLockKey(resourceID)
|
|
|
|
if existingTaskID, err := s.store.Get(ctx, lockKey); err == nil && len(existingTaskID) > 0 {
|
|
return nil, fmt.Errorf("a task is already running for this resource (ID: %s)", string(existingTaskID))
|
|
}
|
|
|
|
taskID := fmt.Sprintf("%d-%d", time.Now().UnixNano(), keyGroupID)
|
|
taskKey := s.getTaskDataKey(taskID)
|
|
runningFlagKey := s.getIsRunningFlagKey(taskID)
|
|
status := &Status{
|
|
ID: taskID,
|
|
TaskType: taskType,
|
|
IsRunning: true,
|
|
ResourceID: resourceID,
|
|
Total: total,
|
|
StartedAt: time.Now(),
|
|
}
|
|
statusBytes, err := json.Marshal(status)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to serialize new task status: %w", err)
|
|
}
|
|
|
|
if timeout == 0 {
|
|
timeout = ResultTTL * 24
|
|
}
|
|
|
|
if err := s.store.Set(ctx, lockKey, []byte(taskID), timeout); err != nil {
|
|
return nil, fmt.Errorf("failed to acquire task resource lock: %w", err)
|
|
}
|
|
if err := s.store.Set(ctx, taskKey, statusBytes, timeout); err != nil {
|
|
_ = s.store.Del(ctx, lockKey)
|
|
return nil, fmt.Errorf("failed to set new task data in store: %w", err)
|
|
}
|
|
|
|
if err := s.store.Set(ctx, runningFlagKey, []byte("1"), timeout); err != nil {
|
|
_ = s.store.Del(ctx, lockKey)
|
|
_ = s.store.Del(ctx, taskKey)
|
|
return nil, fmt.Errorf("failed to set task running flag: %w", err)
|
|
}
|
|
return status, nil
|
|
}
|
|
|
|
func (s *Task) EndTaskByID(ctx context.Context, taskID, resourceID string, resultData any, taskErr error) {
|
|
lockKey := s.getResourceLockKey(resourceID)
|
|
defer func() {
|
|
if err := s.store.Del(ctx, lockKey); err != nil {
|
|
s.logger.WithError(err).Warnf("Failed to release resource lock '%s' for task %s.", lockKey, taskID)
|
|
}
|
|
}()
|
|
runningFlagKey := s.getIsRunningFlagKey(taskID)
|
|
_ = s.store.Del(ctx, runningFlagKey)
|
|
|
|
status, err := s.GetStatus(ctx, taskID)
|
|
if err != nil {
|
|
s.logger.WithError(err).Errorf("Could not get task status for task ID %s during EndTask. Lock has been released, but task data may be stale.", taskID)
|
|
return
|
|
}
|
|
if !status.IsRunning {
|
|
s.logger.Warnf("EndTaskByID called for an already finished task: %s", taskID)
|
|
return
|
|
}
|
|
now := time.Now()
|
|
status.IsRunning = false
|
|
status.FinishedAt = &now
|
|
status.DurationSeconds = now.Sub(status.StartedAt).Seconds()
|
|
if taskErr != nil {
|
|
status.Error = taskErr.Error()
|
|
} else {
|
|
status.Result = resultData
|
|
}
|
|
updatedTaskBytes, _ := json.Marshal(status)
|
|
taskKey := s.getTaskDataKey(taskID)
|
|
if err := s.store.Set(ctx, taskKey, updatedTaskBytes, ResultTTL); err != nil {
|
|
s.logger.WithError(err).Errorf("Failed to save final status for task %s.", taskID)
|
|
}
|
|
}
|
|
|
|
func (s *Task) GetStatus(ctx context.Context, taskID string) (*Status, error) {
|
|
taskKey := s.getTaskDataKey(taskID)
|
|
statusBytes, err := s.store.Get(ctx, taskKey)
|
|
if err != nil {
|
|
if errors.Is(err, store.ErrNotFound) {
|
|
return nil, errors.New("task not found")
|
|
}
|
|
return nil, fmt.Errorf("failed to get task status from store: %w", err)
|
|
}
|
|
|
|
var status Status
|
|
if err := json.Unmarshal(statusBytes, &status); err != nil {
|
|
return nil, fmt.Errorf("corrupted task data in store for ID %s", taskID)
|
|
}
|
|
if !status.IsRunning && status.FinishedAt != nil {
|
|
status.DurationSeconds = status.FinishedAt.Sub(status.StartedAt).Seconds()
|
|
}
|
|
return &status, nil
|
|
}
|
|
|
|
func (s *Task) updateTask(ctx context.Context, taskID string, updater func(status *Status)) error {
|
|
runningFlagKey := s.getIsRunningFlagKey(taskID)
|
|
if _, err := s.store.Get(ctx, runningFlagKey); err != nil {
|
|
return nil
|
|
}
|
|
status, err := s.GetStatus(ctx, taskID)
|
|
if err != nil {
|
|
s.logger.WithError(err).Warnf("Failed to get task status for update on task %s. Update will not be saved.", taskID)
|
|
return nil
|
|
}
|
|
if !status.IsRunning {
|
|
return nil
|
|
}
|
|
updater(status)
|
|
statusBytes, marshalErr := json.Marshal(status)
|
|
if marshalErr != nil {
|
|
s.logger.WithError(marshalErr).Errorf("Failed to serialize status for update on task %s.", taskID)
|
|
return nil
|
|
}
|
|
taskKey := s.getTaskDataKey(taskID)
|
|
if err := s.store.Set(ctx, taskKey, statusBytes, ResultTTL*24); err != nil {
|
|
s.logger.WithError(err).Warnf("Failed to save update for task %s.", taskID)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (s *Task) UpdateProgressByID(ctx context.Context, taskID string, processed int) error {
|
|
return s.updateTask(ctx, taskID, func(status *Status) {
|
|
status.Processed = processed
|
|
})
|
|
}
|
|
|
|
func (s *Task) UpdateTotalByID(ctx context.Context, taskID string, total int) error {
|
|
return s.updateTask(ctx, taskID, func(status *Status) {
|
|
status.Total = total
|
|
})
|
|
}
|