464 lines
13 KiB
Go
464 lines
13 KiB
Go
// Filename: internal/service/key_validation_service.go
|
|
package service
|
|
|
|
import (
|
|
"context"
|
|
"encoding/json"
|
|
"fmt"
|
|
"gemini-balancer/internal/channel"
|
|
CustomErrors "gemini-balancer/internal/errors"
|
|
"gemini-balancer/internal/models"
|
|
"gemini-balancer/internal/repository"
|
|
"gemini-balancer/internal/settings"
|
|
"gemini-balancer/internal/store"
|
|
"gemini-balancer/internal/task"
|
|
"gemini-balancer/internal/utils"
|
|
"io"
|
|
"net/http"
|
|
"strings"
|
|
"sync"
|
|
"time"
|
|
|
|
"github.com/gin-gonic/gin"
|
|
"github.com/sirupsen/logrus"
|
|
"gorm.io/gorm"
|
|
)
|
|
|
|
// Task identifiers and tuning limits used by key validation.
const (
	// TaskTypeTestKeys is the task type under which batch key-validation
	// tasks are registered with the task service.
	TaskTypeTestKeys = "test_keys"

	// defaultConcurrency is the fallback worker count used when neither the
	// group settings nor the global settings provide a positive value.
	defaultConcurrency = 10

	// maxValidationConcurrency is the hard upper bound on the worker count,
	// applied after the configured value has been resolved.
	maxValidationConcurrency = 100

	// validationTaskTimeout is the timeout handed to the task service when a
	// validation task is started.
	validationTaskTimeout = 60 * time.Minute
)
|
|
|
|
type KeyValidationService struct {
|
|
taskService task.Reporter
|
|
channel channel.ChannelProxy
|
|
db *gorm.DB
|
|
settingsManager *settings.SettingsManager
|
|
groupManager *GroupManager
|
|
store store.Store
|
|
keyRepo repository.KeyRepository
|
|
logger *logrus.Entry
|
|
}
|
|
|
|
func NewKeyValidationService(
|
|
ts task.Reporter,
|
|
ch channel.ChannelProxy,
|
|
db *gorm.DB,
|
|
ss *settings.SettingsManager,
|
|
gm *GroupManager,
|
|
st store.Store,
|
|
kr repository.KeyRepository,
|
|
logger *logrus.Logger,
|
|
) *KeyValidationService {
|
|
return &KeyValidationService{
|
|
taskService: ts,
|
|
channel: ch,
|
|
db: db,
|
|
settingsManager: ss,
|
|
groupManager: gm,
|
|
store: st,
|
|
keyRepo: kr,
|
|
logger: logger.WithField("component", "KeyValidationService🧐"),
|
|
}
|
|
}
|
|
|
|
// ==================== 公开接口 ====================
|
|
|
|
// ValidateSingleKey 验证单个密钥
|
|
func (s *KeyValidationService) ValidateSingleKey(key *models.APIKey, timeout time.Duration, endpoint string) error {
|
|
// 1. 解密密钥
|
|
if err := s.keyRepo.Decrypt(key); err != nil {
|
|
return fmt.Errorf("failed to decrypt key %d for validation: %w", key.ID, err)
|
|
}
|
|
|
|
// 2. 创建 HTTP 客户端和请求
|
|
client := &http.Client{Timeout: timeout}
|
|
req, err := http.NewRequest("GET", endpoint, nil)
|
|
if err != nil {
|
|
s.logger.WithFields(logrus.Fields{
|
|
"key_id": key.ID,
|
|
"endpoint": endpoint,
|
|
}).Error("Failed to create validation request")
|
|
return fmt.Errorf("failed to create request: %w", err)
|
|
}
|
|
|
|
// 3. 修改请求(添加密钥认证头)
|
|
s.channel.ModifyRequest(req, key)
|
|
|
|
// 4. 执行请求
|
|
resp, err := client.Do(req)
|
|
if err != nil {
|
|
return fmt.Errorf("request failed: %w", err)
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
// 5. 检查响应状态
|
|
if resp.StatusCode == http.StatusOK {
|
|
return nil
|
|
}
|
|
|
|
// 6. 处理错误响应
|
|
return s.buildValidationError(resp)
|
|
}
|
|
|
|
// StartTestKeysTask 启动批量密钥测试任务
|
|
func (s *KeyValidationService) StartTestKeysTask(ctx context.Context, groupID uint, keysText string) (*task.Status, error) {
|
|
// 1. 解析和验证输入
|
|
keyStrings := utils.ParseKeysFromText(keysText)
|
|
if len(keyStrings) == 0 {
|
|
return nil, CustomErrors.NewAPIError(CustomErrors.ErrBadRequest, "no valid keys found in input text")
|
|
}
|
|
|
|
// 2. 查询密钥模型
|
|
apiKeyModels, err := s.keyRepo.GetKeysByValues(keyStrings)
|
|
if err != nil {
|
|
return nil, CustomErrors.ParseDBError(err)
|
|
}
|
|
if len(apiKeyModels) == 0 {
|
|
return nil, CustomErrors.NewAPIError(CustomErrors.ErrResourceNotFound, "none of the provided keys were found in the system")
|
|
}
|
|
|
|
// 3. 批量解密密钥
|
|
if err := s.keyRepo.DecryptBatch(apiKeyModels); err != nil {
|
|
s.logger.WithError(err).Error("Failed to batch decrypt keys for validation task")
|
|
return nil, CustomErrors.NewAPIError(CustomErrors.ErrInternalServer, "Failed to decrypt keys for validation")
|
|
}
|
|
|
|
// 4. 获取组配置
|
|
group, ok := s.groupManager.GetGroupByID(groupID)
|
|
if !ok {
|
|
return nil, CustomErrors.NewAPIError(CustomErrors.ErrGroupNotFound, fmt.Sprintf("group with id %d not found", groupID))
|
|
}
|
|
|
|
opConfig, err := s.groupManager.BuildOperationalConfig(group)
|
|
if err != nil {
|
|
return nil, CustomErrors.NewAPIError(CustomErrors.ErrInternalServer, fmt.Sprintf("failed to build operational config: %v", err))
|
|
}
|
|
|
|
// 5. 构建验证端点
|
|
endpoint, err := s.groupManager.BuildKeyCheckEndpoint(groupID)
|
|
if err != nil {
|
|
return nil, CustomErrors.NewAPIError(CustomErrors.ErrInternalServer, fmt.Sprintf("failed to build endpoint: %v", err))
|
|
}
|
|
|
|
// 6. 创建任务
|
|
resourceID := fmt.Sprintf("group-%d", groupID)
|
|
taskStatus, err := s.taskService.StartTask(ctx, groupID, TaskTypeTestKeys, resourceID, len(apiKeyModels), validationTaskTimeout)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
// 7. 准备任务参数
|
|
params := s.buildValidationParams(opConfig)
|
|
|
|
// 8. 启动异步验证任务
|
|
go s.runTestKeysTaskWithRecovery(context.Background(), taskStatus.ID, resourceID, groupID, apiKeyModels, params, endpoint)
|
|
|
|
return taskStatus, nil
|
|
}
|
|
|
|
// StartTestKeysByFilterTask 根据状态过滤启动批量测试任务
|
|
func (s *KeyValidationService) StartTestKeysByFilterTask(ctx context.Context, groupID uint, statuses []string) (*task.Status, error) {
|
|
s.logger.WithFields(logrus.Fields{
|
|
"group_id": groupID,
|
|
"statuses": statuses,
|
|
}).Info("Starting test task with status filter")
|
|
|
|
// 1. 根据过滤条件查询密钥
|
|
keyValues, err := s.keyRepo.FindKeyValuesByStatus(groupID, statuses)
|
|
if err != nil {
|
|
return nil, CustomErrors.ParseDBError(err)
|
|
}
|
|
|
|
if len(keyValues) == 0 {
|
|
return nil, CustomErrors.NewAPIError(CustomErrors.ErrNotFound, "No keys found matching the filter criteria")
|
|
}
|
|
|
|
// 2. 转换为文本格式并启动任务
|
|
keysAsText := strings.Join(keyValues, "\n")
|
|
s.logger.Infof("Found %d keys to validate for group %d", len(keyValues), groupID)
|
|
|
|
return s.StartTestKeysTask(ctx, groupID, keysAsText)
|
|
}
|
|
|
|
// ==================== 核心任务执行逻辑 ====================
|
|
|
|
// validationParams carries the per-run settings a validation task uses.
type validationParams struct {
	timeout     time.Duration // HTTP timeout applied to each key check
	concurrency int           // number of worker goroutines
}
|
|
|
|
// buildValidationParams 构建验证参数
|
|
func (s *KeyValidationService) buildValidationParams(opConfig *models.KeyGroupSettings) validationParams {
|
|
settings := s.settingsManager.GetSettings()
|
|
// 从配置读取超时时间(而非硬编码)
|
|
timeout := time.Duration(settings.KeyCheckTimeoutSeconds) * time.Second
|
|
if timeout <= 0 {
|
|
timeout = 30 * time.Second // 仅在配置无效时使用默认值
|
|
}
|
|
// 从配置读取并发数(优先级:组配置 > 全局配置 > 兜底默认值)
|
|
var concurrency int
|
|
if opConfig.KeyCheckConcurrency != nil && *opConfig.KeyCheckConcurrency > 0 {
|
|
concurrency = *opConfig.KeyCheckConcurrency
|
|
} else if settings.BaseKeyCheckConcurrency > 0 {
|
|
concurrency = settings.BaseKeyCheckConcurrency
|
|
} else {
|
|
concurrency = defaultConcurrency // 兜底默认值
|
|
}
|
|
// 限制最大并发数(防护措施)
|
|
if concurrency > maxValidationConcurrency {
|
|
concurrency = maxValidationConcurrency
|
|
}
|
|
return validationParams{
|
|
timeout: timeout,
|
|
concurrency: concurrency,
|
|
}
|
|
}
|
|
|
|
// runTestKeysTaskWithRecovery 带恢复机制的任务执行包装器
|
|
func (s *KeyValidationService) runTestKeysTaskWithRecovery(
|
|
ctx context.Context,
|
|
taskID string,
|
|
resourceID string,
|
|
groupID uint,
|
|
keys []models.APIKey,
|
|
params validationParams,
|
|
endpoint string,
|
|
) {
|
|
defer func() {
|
|
if r := recover(); r != nil {
|
|
err := fmt.Errorf("panic recovered in validation task %s: %v", taskID, r)
|
|
s.logger.WithField("task_id", taskID).Error(err)
|
|
s.taskService.EndTaskByID(ctx, taskID, resourceID, nil, err)
|
|
}
|
|
}()
|
|
|
|
s.runTestKeysTask(ctx, taskID, resourceID, groupID, keys, params, endpoint)
|
|
}
|
|
|
|
// runTestKeysTask 执行批量密钥验证任务
|
|
func (s *KeyValidationService) runTestKeysTask(
|
|
ctx context.Context,
|
|
taskID string,
|
|
resourceID string,
|
|
groupID uint,
|
|
keys []models.APIKey,
|
|
params validationParams,
|
|
endpoint string,
|
|
) {
|
|
s.logger.WithFields(logrus.Fields{
|
|
"task_id": taskID,
|
|
"group_id": groupID,
|
|
"key_count": len(keys),
|
|
"concurrency": params.concurrency,
|
|
"timeout": params.timeout,
|
|
}).Info("Starting validation task")
|
|
|
|
// 1. 初始化结果收集
|
|
results := make([]models.KeyTestResult, len(keys))
|
|
|
|
// 2. 创建任务分发器
|
|
dispatcher := newValidationDispatcher(
|
|
keys,
|
|
params.concurrency,
|
|
s,
|
|
ctx,
|
|
taskID,
|
|
groupID,
|
|
endpoint,
|
|
params.timeout,
|
|
)
|
|
|
|
// 3. 执行并发验证
|
|
dispatcher.run(results)
|
|
|
|
// 4. 完成任务
|
|
s.taskService.EndTaskByID(ctx, taskID, resourceID, gin.H{"results": results}, nil)
|
|
|
|
s.logger.WithFields(logrus.Fields{
|
|
"task_id": taskID,
|
|
"group_id": groupID,
|
|
"processed": len(results),
|
|
}).Info("Validation task completed")
|
|
}
|
|
|
|
// ==================== 验证调度器 ====================
|
|
|
|
// validationJob 验证作业
|
|
type validationJob struct {
|
|
index int
|
|
key models.APIKey
|
|
}
|
|
|
|
// validationDispatcher 验证任务分发器
|
|
type validationDispatcher struct {
|
|
keys []models.APIKey
|
|
concurrency int
|
|
service *KeyValidationService
|
|
ctx context.Context
|
|
taskID string
|
|
groupID uint
|
|
endpoint string
|
|
timeout time.Duration
|
|
|
|
mu sync.Mutex
|
|
processedCount int
|
|
}
|
|
|
|
// newValidationDispatcher 创建验证分发器
|
|
func newValidationDispatcher(
|
|
keys []models.APIKey,
|
|
concurrency int,
|
|
service *KeyValidationService,
|
|
ctx context.Context,
|
|
taskID string,
|
|
groupID uint,
|
|
endpoint string,
|
|
timeout time.Duration,
|
|
) *validationDispatcher {
|
|
return &validationDispatcher{
|
|
keys: keys,
|
|
concurrency: concurrency,
|
|
service: service,
|
|
ctx: ctx,
|
|
taskID: taskID,
|
|
groupID: groupID,
|
|
endpoint: endpoint,
|
|
timeout: timeout,
|
|
}
|
|
}
|
|
|
|
// run 执行并发验证
|
|
func (d *validationDispatcher) run(results []models.KeyTestResult) {
|
|
var wg sync.WaitGroup
|
|
jobs := make(chan validationJob, len(d.keys))
|
|
|
|
// 启动 worker pool
|
|
for i := 0; i < d.concurrency; i++ {
|
|
wg.Add(1)
|
|
go d.worker(&wg, jobs, results)
|
|
}
|
|
|
|
// 分发任务
|
|
for i, key := range d.keys {
|
|
jobs <- validationJob{index: i, key: key}
|
|
}
|
|
close(jobs)
|
|
|
|
// 等待所有 worker 完成
|
|
wg.Wait()
|
|
}
|
|
|
|
// worker 验证工作协程
|
|
func (d *validationDispatcher) worker(wg *sync.WaitGroup, jobs <-chan validationJob, results []models.KeyTestResult) {
|
|
defer wg.Done()
|
|
|
|
for job := range jobs {
|
|
result := d.validateKey(job.key)
|
|
|
|
d.mu.Lock()
|
|
results[job.index] = result
|
|
d.processedCount++
|
|
_ = d.service.taskService.UpdateProgressByID(d.ctx, d.taskID, d.processedCount)
|
|
d.mu.Unlock()
|
|
}
|
|
}
|
|
|
|
// validateKey 验证单个密钥并返回结果
|
|
func (d *validationDispatcher) validateKey(key models.APIKey) models.KeyTestResult {
|
|
// 1. 执行验证
|
|
validationErr := d.service.ValidateSingleKey(&key, d.timeout, d.endpoint)
|
|
|
|
// 2. 构建结果和事件
|
|
result, event := d.buildResultAndEvent(key, validationErr)
|
|
|
|
// 3. 发布验证事件
|
|
d.publishValidationEvent(key.ID, event)
|
|
|
|
return result
|
|
}
|
|
|
|
// buildResultAndEvent 构建验证结果和事件
|
|
func (d *validationDispatcher) buildResultAndEvent(key models.APIKey, validationErr error) (models.KeyTestResult, models.RequestFinishedEvent) {
|
|
event := models.RequestFinishedEvent{
|
|
RequestLog: models.RequestLog{
|
|
GroupID: &d.groupID,
|
|
KeyID: &key.ID,
|
|
},
|
|
}
|
|
|
|
if validationErr == nil {
|
|
// 验证成功
|
|
event.RequestLog.IsSuccess = true
|
|
return models.KeyTestResult{
|
|
Key: key.APIKey,
|
|
Status: "valid",
|
|
Message: "Validation successful",
|
|
}, event
|
|
}
|
|
|
|
// 验证失败
|
|
event.RequestLog.IsSuccess = false
|
|
|
|
var apiErr *CustomErrors.APIError
|
|
if CustomErrors.As(validationErr, &apiErr) {
|
|
event.Error = apiErr
|
|
return models.KeyTestResult{
|
|
Key: key.APIKey,
|
|
Status: "invalid",
|
|
Message: fmt.Sprintf("Invalid key (HTTP %d): %s", apiErr.HTTPStatus, apiErr.Message),
|
|
}, event
|
|
}
|
|
|
|
// 其他错误
|
|
event.Error = &CustomErrors.APIError{Message: validationErr.Error()}
|
|
return models.KeyTestResult{
|
|
Key: key.APIKey,
|
|
Status: "error",
|
|
Message: "Validation check failed: " + validationErr.Error(),
|
|
}, event
|
|
}
|
|
|
|
// publishValidationEvent 发布验证事件
|
|
func (d *validationDispatcher) publishValidationEvent(keyID uint, event models.RequestFinishedEvent) {
|
|
eventData, err := json.Marshal(event)
|
|
if err != nil {
|
|
d.service.logger.WithFields(logrus.Fields{
|
|
"key_id": keyID,
|
|
"group_id": d.groupID,
|
|
}).WithError(err).Error("Failed to marshal validation event")
|
|
return
|
|
}
|
|
|
|
if err := d.service.store.Publish(d.ctx, models.TopicRequestFinished, eventData); err != nil {
|
|
d.service.logger.WithFields(logrus.Fields{
|
|
"key_id": keyID,
|
|
"group_id": d.groupID,
|
|
}).WithError(err).Error("Failed to publish validation event")
|
|
}
|
|
}
|
|
|
|
// ==================== 辅助方法 ====================
|
|
|
|
// buildValidationError 构建验证错误
|
|
func (s *KeyValidationService) buildValidationError(resp *http.Response) error {
|
|
bodyBytes, readErr := io.ReadAll(resp.Body)
|
|
|
|
var errorMsg string
|
|
if readErr != nil {
|
|
errorMsg = "Failed to read error response body"
|
|
s.logger.WithError(readErr).Warn("Failed to read validation error response")
|
|
} else {
|
|
errorMsg = string(bodyBytes)
|
|
}
|
|
|
|
return &CustomErrors.APIError{
|
|
HTTPStatus: resp.StatusCode,
|
|
Message: fmt.Sprintf("Validation failed with status %d: %s", resp.StatusCode, errorMsg),
|
|
Code: "VALIDATION_FAILED",
|
|
}
|
|
}
|