Files
gemini-banlancer/internal/task/task.go
2025-11-22 14:20:05 +08:00

201 lines
6.2 KiB
Go

// Filename: internal/task/task.go
package task
import (
"context"
"encoding/json"
"errors"
"fmt"
"gemini-balancer/internal/store"
"time"
"github.com/sirupsen/logrus"
)
// ResultTTL is how long a finished task's status stays in the store after
// EndTaskByID writes it. StartTask's default timeout and the writes made by
// progress updates use 24 × ResultTTL instead.
const ResultTTL = time.Hour
// Reporter is the task-lifecycle API: start a task bound to a resource,
// report its progress, and finish it with a result or an error.
type Reporter interface {
	// StartTask registers a new running task for resourceID with an initial
	// total and a timeout, returning its initial status.
	StartTask(ctx context.Context, keyGroupID uint, taskType, resourceID string, total int, timeout time.Duration) (*Status, error)

	// EndTaskByID finishes the task taskID, recording result or taskErr, and
	// releases the lock held for resourceID.
	EndTaskByID(ctx context.Context, taskID, resourceID string, result any, taskErr error)

	// UpdateProgressByID sets the number of processed items of a running task.
	UpdateProgressByID(ctx context.Context, taskID string, processed int) error

	// UpdateTotalByID sets the total number of items of a running task.
	UpdateTotalByID(ctx context.Context, taskID string, total int) error
}
// Status is the JSON-serialized state of a task as kept in the store.
// Field order matters: it is the key order of the encoded JSON.
type Status struct {
	// ID is the task identifier generated by StartTask.
	ID string `json:"id"`
	// TaskType is the caller-supplied kind of task.
	TaskType string `json:"task_type"`
	// IsRunning is true from StartTask until EndTaskByID.
	IsRunning bool `json:"is_running"`
	// ResourceID names the resource whose lock the task holds.
	ResourceID string `json:"resource_id,omitempty"`
	// Processed is the latest reported progress count.
	Processed int `json:"processed"`
	// Total is the expected number of items to process.
	Total int `json:"total"`
	// Result is set on success; omitted when nil.
	Result any `json:"result,omitempty"`
	// Error holds the failure message when the task ended with an error.
	Error string `json:"error,omitempty"`
	// StartedAt is when the task was started.
	StartedAt time.Time `json:"started_at"`
	// FinishedAt is set once the task has ended.
	FinishedAt *time.Time `json:"finished_at,omitempty"`
	// DurationSeconds is FinishedAt − StartedAt for finished tasks.
	DurationSeconds float64 `json:"duration_seconds,omitempty"`
}
// Task implements Reporter by persisting task state in a store.Store.
type Task struct {
	// store holds the lock, status, and running-flag keys of every task.
	store store.Store
	// logger is pre-tagged with this service's component name.
	logger *logrus.Entry
}
// NewTask returns a Task backed by store, logging through logger with a
// "component" field identifying the task service.
func NewTask(store store.Store, logger *logrus.Logger) *Task {
	componentLogger := logger.WithField("component", "TaskService📋")
	svc := &Task{logger: componentLogger}
	svc.store = store
	return svc
}
// Compile-time check that *Task satisfies Reporter.
var (
	_ Reporter = (*Task)(nil)
)
// getResourceLockKey returns the store key of the per-resource lock; its
// value is the ID of the task that acquired the resource in StartTask.
func (s *Task) getResourceLockKey(resourceID string) string {
	const prefix = "task:lock:"
	return prefix + resourceID
}
// getTaskDataKey returns the store key holding the JSON-encoded Status of
// taskID.
func (s *Task) getTaskDataKey(taskID string) string {
	const prefix = "task:data:"
	return prefix + taskID
}
// getIsRunningFlagKey returns the store key whose presence marks taskID as
// still running.
func (s *Task) getIsRunningFlagKey(taskID string) string {
	const prefix = "task:running:"
	return prefix + taskID
}
// StartTask registers a new running task for resourceID and returns its
// initial status.
//
// It fails if the resource lock for resourceID already holds a task ID. On
// success it writes three keys, each expiring after timeout:
//   - the resource lock, whose value is the new task ID;
//   - the JSON-encoded Status;
//   - the task's running flag.
//
// A non-positive timeout falls back to 24 × ResultTTL. If a later write
// fails, the keys written so far are deleted (best effort) before the error
// is returned.
//
// NOTE(review): the lock check (Get) and acquisition (Set) are separate store
// calls, so two concurrent StartTask calls for the same resource can both
// succeed. Closing that race needs an atomic set-if-absent operation on
// store.Store — confirm whether one is available.
func (s *Task) StartTask(ctx context.Context, keyGroupID uint, taskType, resourceID string, total int, timeout time.Duration) (*Status, error) {
	lockKey := s.getResourceLockKey(resourceID)
	// Any Get error (including not-found) is treated as "no lock held".
	if existingTaskID, err := s.store.Get(ctx, lockKey); err == nil && len(existingTaskID) > 0 {
		return nil, fmt.Errorf("a task is already running for this resource (ID: %s)", string(existingTaskID))
	}

	// One timestamp for both the ID and StartedAt keeps them consistent.
	now := time.Now()
	taskID := fmt.Sprintf("%d-%d", now.UnixNano(), keyGroupID)
	taskKey := s.getTaskDataKey(taskID)
	runningFlagKey := s.getIsRunningFlagKey(taskID)
	status := &Status{
		ID:         taskID,
		TaskType:   taskType,
		IsRunning:  true,
		ResourceID: resourceID,
		Total:      total,
		StartedAt:  now,
	}
	statusBytes, err := json.Marshal(status)
	if err != nil {
		return nil, fmt.Errorf("failed to serialize new task status: %w", err)
	}

	// A zero or negative TTL has store-specific meaning (no expiry, or
	// immediate expiry), either of which would break the lock; use a
	// generous default instead.
	if timeout <= 0 {
		timeout = ResultTTL * 24
	}
	if err := s.store.Set(ctx, lockKey, []byte(taskID), timeout); err != nil {
		return nil, fmt.Errorf("failed to acquire task resource lock: %w", err)
	}
	if err := s.store.Set(ctx, taskKey, statusBytes, timeout); err != nil {
		// Best-effort rollback; the Set error is what the caller needs.
		_ = s.store.Del(ctx, lockKey)
		return nil, fmt.Errorf("failed to set new task data in store: %w", err)
	}
	if err := s.store.Set(ctx, runningFlagKey, []byte("1"), timeout); err != nil {
		// Best-effort rollback; the Set error is what the caller needs.
		_ = s.store.Del(ctx, lockKey)
		_ = s.store.Del(ctx, taskKey)
		return nil, fmt.Errorf("failed to set task running flag: %w", err)
	}
	return status, nil
}
// EndTaskByID marks the task taskID as finished and releases the resource
// lock for resourceID.
//
// Steps:
//  1. Delete the task's running flag; updateTask skips tasks whose flag is
//     absent.
//  2. Rewrite the stored status with IsRunning=false, FinishedAt,
//     DurationSeconds, and either taskErr's message (when taskErr is non-nil)
//     or resultData. This final record expires after ResultTTL.
//
// If resultData cannot be JSON-encoded, the status is saved without it and
// the encoding failure is recorded in Error. For a task that is already
// finished, it only logs a warning.
//
// The resource lock is released on every return path, unless it now holds a
// different task ID. That happens when this task outlived its lock timeout
// and another task acquired the resource; the lock is then left alone.
// Failures are logged, not returned.
//
// NOTE(review): every store call uses ctx. If callers pass a context that is
// already cancelled when the task ends, these writes may fail and the lock is
// freed only by its TTL — confirm how callers invoke this.
func (s *Task) EndTaskByID(ctx context.Context, taskID, resourceID string, resultData any, taskErr error) {
	lockKey := s.getResourceLockKey(resourceID)
	defer func() {
		// Never delete a lock owned by another task. Check-then-delete is not
		// atomic, but it closes the common expired-lock case.
		if owner, err := s.store.Get(ctx, lockKey); err == nil && len(owner) > 0 && string(owner) != taskID {
			s.logger.Warnf("Resource lock '%s' is held by task %s, not %s; leaving it in place.", lockKey, string(owner), taskID)
			return
		}
		if err := s.store.Del(ctx, lockKey); err != nil {
			s.logger.WithError(err).Warnf("Failed to release resource lock '%s' for task %s.", lockKey, taskID)
		}
	}()

	// Best effort: even if the flag survives, updateTask refuses to modify a
	// status whose IsRunning is false, which is written below.
	runningFlagKey := s.getIsRunningFlagKey(taskID)
	_ = s.store.Del(ctx, runningFlagKey)

	status, err := s.GetStatus(ctx, taskID)
	if err != nil {
		s.logger.WithError(err).Errorf("Could not get task status for task ID %s during EndTask; task data may be stale.", taskID)
		return
	}
	if !status.IsRunning {
		s.logger.Warnf("EndTaskByID called for an already finished task: %s", taskID)
		return
	}

	now := time.Now()
	status.IsRunning = false
	status.FinishedAt = &now
	status.DurationSeconds = now.Sub(status.StartedAt).Seconds()
	if taskErr != nil {
		status.Error = taskErr.Error()
	} else {
		status.Result = resultData
	}

	updatedTaskBytes, err := json.Marshal(status)
	if err != nil {
		// resultData is caller-supplied and may not be JSON-encodable. Storing
		// the nil bytes Marshal returned would leave data GetStatus rejects as
		// corrupted, so persist the terminal state without the result.
		s.logger.WithError(err).Errorf("Failed to serialize final status for task %s; saving it without the result.", taskID)
		status.Result = nil
		if status.Error == "" {
			status.Error = fmt.Sprintf("failed to serialize task result: %v", err)
		}
		if updatedTaskBytes, err = json.Marshal(status); err != nil {
			s.logger.WithError(err).Errorf("Failed to serialize final status for task %s.", taskID)
			return
		}
	}

	taskKey := s.getTaskDataKey(taskID)
	if err := s.store.Set(ctx, taskKey, updatedTaskBytes, ResultTTL); err != nil {
		s.logger.WithError(err).Errorf("Failed to save final status for task %s.", taskID)
	}
}
// GetStatus loads and decodes the stored Status of taskID.
//
// Errors:
//   - "task not found" when the store reports store.ErrNotFound for the task
//     key;
//   - a wrapped store error for any other lookup failure;
//   - a wrapped JSON decoding error when the stored bytes are not a valid
//     Status.
//
// For finished tasks (IsRunning false with FinishedAt set), DurationSeconds
// is recomputed from StartedAt and FinishedAt.
func (s *Task) GetStatus(ctx context.Context, taskID string) (*Status, error) {
	statusBytes, err := s.store.Get(ctx, s.getTaskDataKey(taskID))
	if err != nil {
		if errors.Is(err, store.ErrNotFound) {
			return nil, errors.New("task not found")
		}
		return nil, fmt.Errorf("failed to get task status from store: %w", err)
	}

	var status Status
	if err := json.Unmarshal(statusBytes, &status); err != nil {
		// Keep the decoder's error so the cause of the corruption is visible.
		return nil, fmt.Errorf("corrupted task data in store for ID %s: %w", taskID, err)
	}
	if !status.IsRunning && status.FinishedAt != nil {
		status.DurationSeconds = status.FinishedAt.Sub(status.StartedAt).Seconds()
	}
	return &status, nil
}
// updateTask applies updater to the stored status of a running task and
// writes it back.
//
// It is best effort and always returns nil:
//   - it silently skips tasks whose running flag is absent or unreadable, or
//     whose stored status is no longer running;
//   - it logs read, serialization, and write failures.
//
// This way progress reporting never aborts the caller's work.
//
// NOTE(review): the read-modify-write is not atomic. Concurrent updates to one
// task can lose each other's changes. A final status written by EndTaskByID
// can still be overwritten if it lands between the flag re-check and the Set
// below; closing that gap needs a compare-and-set primitive from the store.
// The write also resets the data key's TTL to 24 × ResultTTL regardless of the
// timeout passed to StartTask — confirm that is intended.
func (s *Task) updateTask(ctx context.Context, taskID string, updater func(status *Status)) error {
	runningFlagKey := s.getIsRunningFlagKey(taskID)
	if _, err := s.store.Get(ctx, runningFlagKey); err != nil {
		return nil
	}
	status, err := s.GetStatus(ctx, taskID)
	if err != nil {
		s.logger.WithError(err).Warnf("Failed to get task status for update on task %s. Update will not be saved.", taskID)
		return nil
	}
	if !status.IsRunning {
		return nil
	}

	updater(status)
	statusBytes, marshalErr := json.Marshal(status)
	if marshalErr != nil {
		s.logger.WithError(marshalErr).Errorf("Failed to serialize status for update on task %s.", taskID)
		return nil
	}

	// EndTaskByID deletes the running flag before writing the final status.
	// Re-checking right before the write keeps a slow update from clobbering
	// a finished status with this stale "running" copy in almost all cases.
	if _, err := s.store.Get(ctx, runningFlagKey); err != nil {
		return nil
	}
	taskKey := s.getTaskDataKey(taskID)
	if err := s.store.Set(ctx, taskKey, statusBytes, ResultTTL*24); err != nil {
		s.logger.WithError(err).Warnf("Failed to save update for task %s.", taskID)
	}
	return nil
}
// UpdateProgressByID records processed as the task's progress count if the
// task is still running. It delegates to updateTask, which is best effort.
func (s *Task) UpdateProgressByID(ctx context.Context, taskID string, processed int) error {
	setProcessed := func(st *Status) { st.Processed = processed }
	return s.updateTask(ctx, taskID, setProcessed)
}
// UpdateTotalByID replaces the task's expected item count with total if the
// task is still running. It delegates to updateTask, which is best effort.
func (s *Task) UpdateTotalByID(ctx context.Context, taskID string, total int) error {
	setTotal := func(st *Status) { st.Total = total }
	return s.updateTask(ctx, taskID, setTotal)
}