Files
gemini-banlancer/internal/task/task.go
2025-11-20 12:24:05 +08:00

215 lines
6.8 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
// Filename: internal/task/task.go (final calibrated version)
package task
import (
"encoding/json"
"errors"
"fmt"
"gemini-balancer/internal/store"
"time"
"github.com/sirupsen/logrus"
)
// ResultTTL is how long the final status of a finished task is kept in the
// store. Multiples of it are also used as the default lifetime of running
// tasks.
const ResultTTL = time.Hour
// Reporter defines how domain code interacts with the task service.
type Reporter interface {
	// StartTask registers a new task of the given type for resourceID and
	// returns its initial status.
	StartTask(keyGroupID uint, taskType, resourceID string, total int, timeout time.Duration) (*Status, error)

	// EndTaskByID marks the task as finished, recording either result or
	// taskErr as its outcome.
	EndTaskByID(taskID, resourceID string, result any, taskErr error)

	// UpdateProgressByID records how many items the task has processed.
	UpdateProgressByID(taskID string, processed int) error

	// UpdateTotalByID records the total number of items the task will process.
	UpdateTotalByID(taskID string, total int) error
}
// Status is the complete state of one background task as it is serialized
// to JSON and kept in the store.
type Status struct {
	// ID is the unique task identifier.
	ID string `json:"id"`
	// TaskType is the caller-supplied kind of task.
	TaskType string `json:"task_type"`
	// IsRunning is true from task start until the task is ended.
	IsRunning bool `json:"is_running"`
	// ResourceID identifies the resource the task holds a lock on.
	ResourceID string `json:"resource_id,omitempty"`
	// Processed is the number of items handled so far.
	Processed int `json:"processed"`
	// Total is the number of items the task is expected to handle.
	Total int `json:"total"`
	// Result is the payload recorded when the task ends without an error.
	Result any `json:"result,omitempty"`
	// Error is the message of the error the task ended with, if any.
	Error string `json:"error,omitempty"`
	// StartedAt is the time the task was created.
	StartedAt time.Time `json:"started_at"`
	// FinishedAt is the time the task ended; nil while it is still running.
	FinishedAt *time.Time `json:"finished_at,omitempty"`
	// DurationSeconds is FinishedAt minus StartedAt, in seconds.
	DurationSeconds float64 `json:"duration_seconds,omitempty"`
}
// Task is the core task-management service. It keeps task state, resource
// locks and running flags in the backing store.
type Task struct {
	// store holds the lock, running-flag and task-data keys.
	store store.Store
	// logger is the component-scoped log entry used by all methods.
	logger *logrus.Entry
}
// NewTask constructs a Task backed by the given store. Log output is emitted
// through logger with a "component" field identifying the task service.
func NewTask(store store.Store, logger *logrus.Logger) *Task {
	svc := new(Task)
	svc.store = store
	svc.logger = logger.WithField("component", "TaskService📋")
	return svc
}
// Compile-time assertion that *Task implements the Reporter interface.
var _ Reporter = (*Task)(nil)
// getResourceLockKey returns the store key of the lock guarding resourceID,
// in the form "task:lock:<resourceID>".
func (s *Task) getResourceLockKey(resourceID string) string {
	const layout = "task:lock:%s"
	return fmt.Sprintf(layout, resourceID)
}
// getTaskDataKey returns the store key under which the serialized Status of
// taskID is kept, in the form "task:data:<taskID>".
func (s *Task) getTaskDataKey(taskID string) string {
	const layout = "task:data:%s"
	return fmt.Sprintf(layout, taskID)
}
// getIsRunningFlagKey returns the store key of the standalone "running" flag
// for taskID, in the form "task:running:<taskID>". The flag's presence, not
// its value, is what marks a task as still running.
func (s *Task) getIsRunningFlagKey(taskID string) string {
	const layout = "task:running:%s"
	return fmt.Sprintf(layout, taskID)
}
// StartTask registers a new task of taskType for resourceID and returns its
// initial status.
//
// It fails if the resource lock is already held by another task. Otherwise
// it writes three keys, all expiring after timeout (or ResultTTL*24 when
// timeout is zero): the resource lock, the serialized task status, and a
// standalone "running" flag. If a later write fails, the keys written so far
// are removed on a best-effort basis before the error is returned.
//
// The task ID is built from the current Unix time in nanoseconds and
// keyGroupID.
func (s *Task) StartTask(keyGroupID uint, taskType, resourceID string, total int, timeout time.Duration) (*Status, error) {
	lockKey := s.getResourceLockKey(resourceID)

	// Refuse to start while another task still holds the resource lock.
	// NOTE(review): this read and the Set below are separate store calls, so
	// two concurrent callers could both pass the check — confirm whether the
	// store offers an atomic set-if-absent.
	holder, lookupErr := s.store.Get(lockKey)
	if lookupErr == nil && len(holder) > 0 {
		return nil, fmt.Errorf("a task is already running for this resource (ID: %s)", string(holder))
	}

	taskID := fmt.Sprintf("%d-%d", time.Now().UnixNano(), keyGroupID)
	status := &Status{
		ID:         taskID,
		TaskType:   taskType,
		IsRunning:  true,
		ResourceID: resourceID,
		Total:      total,
		StartedAt:  time.Now(),
	}
	payload, err := json.Marshal(status)
	if err != nil {
		return nil, fmt.Errorf("failed to serialize new task status: %w", err)
	}

	ttl := timeout
	if ttl == 0 {
		ttl = ResultTTL * 24
	}

	if err = s.store.Set(lockKey, []byte(taskID), ttl); err != nil {
		return nil, fmt.Errorf("failed to acquire task resource lock: %w", err)
	}

	taskKey := s.getTaskDataKey(taskID)
	if err = s.store.Set(taskKey, payload, ttl); err != nil {
		// Best-effort rollback; the original error is what the caller needs.
		_ = s.store.Del(lockKey)
		return nil, fmt.Errorf("failed to set new task data in store: %w", err)
	}

	// The running flag is a separate key so that "is this task still
	// running?" can be answered by a single existence check.
	if err = s.store.Set(s.getIsRunningFlagKey(taskID), []byte("1"), ttl); err != nil {
		// Best-effort rollback of both keys written above.
		_ = s.store.Del(lockKey)
		_ = s.store.Del(taskKey)
		return nil, fmt.Errorf("failed to set task running flag: %w", err)
	}

	return status, nil
}
// EndTaskByID marks the task taskID as finished, stores its final status for
// ResultTTL, and releases the lock on resourceID.
//
// When taskErr is non-nil its message is recorded in Status.Error; otherwise
// resultData is recorded in Status.Result. If resultData cannot be encoded as
// JSON, the status is stored without a result and with an error message
// describing the encoding failure, so the task record stays readable.
//
// The resource lock is released on every path, including early returns,
// unless the lock is found to be held by a different task ID. That happens
// when this task's lock expired and a newer task acquired it; deleting it
// would let a third task start alongside the newer one.
//
// The method returns nothing; all failures are logged.
func (s *Task) EndTaskByID(taskID, resourceID string, resultData any, taskErr error) {
	lockKey := s.getResourceLockKey(resourceID)
	defer func() {
		// Skip the release only when the lock is positively known to belong
		// to another task. On a read error (including "not found") fall
		// through to Del, as before.
		if owner, err := s.store.Get(lockKey); err == nil && len(owner) > 0 && string(owner) != taskID {
			s.logger.Warnf("Resource lock '%s' is held by task %s; task %s will not release it.", lockKey, string(owner), taskID)
			return
		}
		if err := s.store.Del(lockKey); err != nil {
			s.logger.WithError(err).Warnf("Failed to release resource lock '%s' for task %s.", lockKey, taskID)
		}
	}()

	// Remove the running flag first so concurrent progress updates stop
	// early. Best effort: the flag also expires on its own, so a failure
	// here is deliberately ignored.
	_ = s.store.Del(s.getIsRunningFlagKey(taskID))

	status, err := s.GetStatus(taskID)
	if err != nil {
		s.logger.WithError(err).Errorf("Could not get task status for task ID %s during EndTask. Lock has been released, but task data may be stale.", taskID)
		return
	}
	if !status.IsRunning {
		s.logger.Warnf("EndTaskByID called for an already finished task: %s", taskID)
		return
	}

	now := time.Now()
	status.IsRunning = false
	status.FinishedAt = &now
	status.DurationSeconds = now.Sub(status.StartedAt).Seconds()
	if taskErr != nil {
		status.Error = taskErr.Error()
	} else {
		status.Result = resultData
	}

	updatedTaskBytes, err := json.Marshal(status)
	if err != nil {
		// resultData is caller-supplied and may not be JSON-encodable.
		// Writing the nil bytes Marshal returns on failure would replace the
		// task record with unreadable data, so store the status without the
		// payload instead.
		s.logger.WithError(err).Errorf("Failed to serialize final status for task %s; storing it without the result payload.", taskID)
		status.Result = nil
		if status.Error == "" {
			status.Error = "failed to serialize task result: " + err.Error()
		}
		updatedTaskBytes, err = json.Marshal(status)
		if err != nil {
			s.logger.WithError(err).Errorf("Failed to serialize final status for task %s even without the result payload.", taskID)
			return
		}
	}

	taskKey := s.getTaskDataKey(taskID)
	if err := s.store.Set(taskKey, updatedTaskBytes, ResultTTL); err != nil {
		s.logger.WithError(err).Errorf("Failed to save final status for task %s.", taskID)
	}
}
// GetStatus loads the status of the task taskID from the store. It is
// intended for external callers such as API handlers.
//
// It returns a "task not found" error when the store reports
// store.ErrNotFound, a wrapped store error for any other read failure, and a
// "corrupted task data" error, wrapping the JSON decoding error, when the
// stored bytes cannot be decoded.
//
// For a finished task, DurationSeconds is recomputed from FinishedAt and
// StartedAt rather than taken from the stored value.
func (s *Task) GetStatus(taskID string) (*Status, error) {
	taskKey := s.getTaskDataKey(taskID)
	statusBytes, err := s.store.Get(taskKey)
	if err != nil {
		if errors.Is(err, store.ErrNotFound) {
			return nil, errors.New("task not found")
		}
		return nil, fmt.Errorf("failed to get task status from store: %w", err)
	}

	var status Status
	if err := json.Unmarshal(statusBytes, &status); err != nil {
		// Keep the decoding error in the chain so the root cause is visible
		// in logs and reachable through errors.Is / errors.As.
		return nil, fmt.Errorf("corrupted task data in store for ID %s: %w", taskID, err)
	}

	if !status.IsRunning && status.FinishedAt != nil {
		status.DurationSeconds = status.FinishedAt.Sub(status.StartedAt).Seconds()
	}
	return &status, nil
}
// updateTask is the shared read-modify-write helper behind the public update
// methods. It loads the status of taskID, lets updater mutate it, and writes
// it back with a TTL of ResultTTL*24.
//
// The update is skipped when the running flag cannot be read or when the
// stored status is no longer marked as running. Every failure is logged
// rather than returned, so the function always returns nil.
func (s *Task) updateTask(taskID string, updater func(status *Status)) error {
	// Any error reading the running flag is treated as "task already
	// finished"; returning silently is the expected behaviour in that case.
	if _, flagErr := s.store.Get(s.getIsRunningFlagKey(taskID)); flagErr != nil {
		return nil
	}

	current, err := s.GetStatus(taskID)
	switch {
	case err != nil:
		s.logger.WithError(err).Warnf("Failed to get task status for update on task %s. Update will not be saved.", taskID)
		return nil
	case !current.IsRunning:
		return nil
	}

	// Apply the caller's mutation to the loaded status.
	updater(current)

	payload, err := json.Marshal(current)
	if err != nil {
		s.logger.WithError(err).Errorf("Failed to serialize status for update on task %s.", taskID)
		return nil
	}

	// Use a long TTL so a running task's data does not expire prematurely.
	if err := s.store.Set(s.getTaskDataKey(taskID), payload, ResultTTL*24); err != nil {
		s.logger.WithError(err).Warnf("Failed to save update for task %s.", taskID)
	}
	return nil
}
// UpdateProgressByID sets the processed-item count of the task taskID. It is
// a thin wrapper that hands a field setter to updateTask and returns
// whatever updateTask returns.
func (s *Task) UpdateProgressByID(taskID string, processed int) error {
	setProcessed := func(st *Status) {
		st.Processed = processed
	}
	return s.updateTask(taskID, setProcessed)
}
// UpdateTotalByID sets the total-item count of the task taskID. It is a thin
// wrapper that hands a field setter to updateTask and returns whatever
// updateTask returns.
func (s *Task) UpdateTotalByID(taskID string, total int) error {
	setTotal := func(st *Status) {
		st.Total = total
	}
	return s.updateTask(taskID, setTotal)
}