Update Context for store

This commit is contained in:
XOF
2025-11-22 14:20:05 +08:00
parent ac0e0a8275
commit 2b0b9b67dc
31 changed files with 817 additions and 1016 deletions

View File

@@ -2,6 +2,7 @@
package proxy package proxy
import ( import (
"context"
"encoding/json" "encoding/json"
"gemini-balancer/internal/errors" "gemini-balancer/internal/errors"
"gemini-balancer/internal/models" "gemini-balancer/internal/models"
@@ -49,7 +50,6 @@ func (h *handler) registerRoutes(rg *gin.RouterGroup) {
} }
} }
// --- 请求 DTO ---
type CreateProxyConfigRequest struct { type CreateProxyConfigRequest struct {
Address string `json:"address" binding:"required"` Address string `json:"address" binding:"required"`
Protocol string `json:"protocol" binding:"required,oneof=http https socks5"` Protocol string `json:"protocol" binding:"required,oneof=http https socks5"`
@@ -64,12 +64,10 @@ type UpdateProxyConfigRequest struct {
Description *string `json:"description"` Description *string `json:"description"`
} }
// 单个检测的请求体 (与前端JS对齐)
type CheckSingleProxyRequest struct { type CheckSingleProxyRequest struct {
Proxy string `json:"proxy" binding:"required"` Proxy string `json:"proxy" binding:"required"`
} }
// 批量检测的请求体
type CheckAllProxiesRequest struct { type CheckAllProxiesRequest struct {
Proxies []string `json:"proxies" binding:"required"` Proxies []string `json:"proxies" binding:"required"`
} }
@@ -84,7 +82,7 @@ func (h *handler) CreateProxyConfig(c *gin.Context) {
} }
if req.Status == "" { if req.Status == "" {
req.Status = "active" // 默认状态 req.Status = "active"
} }
proxyConfig := models.ProxyConfig{ proxyConfig := models.ProxyConfig{
@@ -98,7 +96,6 @@ func (h *handler) CreateProxyConfig(c *gin.Context) {
response.Error(c, errors.ParseDBError(err)) response.Error(c, errors.ParseDBError(err))
return return
} }
// 写操作后,发布事件并使缓存失效
h.publishAndInvalidate(proxyConfig.ID, "created") h.publishAndInvalidate(proxyConfig.ID, "created")
response.Created(c, proxyConfig) response.Created(c, proxyConfig)
} }
@@ -199,17 +196,16 @@ func (h *handler) DeleteProxyConfig(c *gin.Context) {
response.NoContent(c) response.NoContent(c)
} }
// publishAndInvalidate 统一事件发布和缓存失效逻辑
func (h *handler) publishAndInvalidate(proxyID uint, action string) { func (h *handler) publishAndInvalidate(proxyID uint, action string) {
go h.manager.invalidate() go h.manager.invalidate()
go func() { go func() {
ctx := context.Background()
event := models.ProxyStatusChangedEvent{ProxyID: proxyID, Action: action} event := models.ProxyStatusChangedEvent{ProxyID: proxyID, Action: action}
eventData, _ := json.Marshal(event) eventData, _ := json.Marshal(event)
_ = h.store.Publish(models.TopicProxyStatusChanged, eventData) _ = h.store.Publish(ctx, models.TopicProxyStatusChanged, eventData)
}() }()
} }
// 新的 Handler 方法和 DTO
type SyncProxiesRequest struct { type SyncProxiesRequest struct {
Proxies []string `json:"proxies"` Proxies []string `json:"proxies"`
} }
@@ -220,14 +216,12 @@ func (h *handler) SyncProxies(c *gin.Context) {
response.Error(c, errors.NewAPIError(errors.ErrInvalidJSON, err.Error())) response.Error(c, errors.NewAPIError(errors.ErrInvalidJSON, err.Error()))
return return
} }
taskStatus, err := h.manager.SyncProxiesInBackground(req.Proxies)
taskStatus, err := h.manager.SyncProxiesInBackground(c.Request.Context(), req.Proxies)
if err != nil { if err != nil {
if errors.Is(err, ErrTaskConflict) { if errors.Is(err, ErrTaskConflict) {
response.Error(c, errors.NewAPIError(errors.ErrTaskInProgress, err.Error())) response.Error(c, errors.NewAPIError(errors.ErrTaskInProgress, err.Error()))
} else { } else {
response.Error(c, errors.NewAPIError(errors.ErrInternalServer, err.Error())) response.Error(c, errors.NewAPIError(errors.ErrInternalServer, err.Error()))
} }
return return
@@ -262,7 +256,7 @@ func (h *handler) CheckAllProxies(c *gin.Context) {
concurrency := cfg.ProxyCheckConcurrency concurrency := cfg.ProxyCheckConcurrency
if concurrency <= 0 { if concurrency <= 0 {
concurrency = 5 // 如果配置不合法,提供一个安全的默认值 concurrency = 5
} }
results := h.manager.CheckMultipleProxies(req.Proxies, timeout, concurrency) results := h.manager.CheckMultipleProxies(req.Proxies, timeout, concurrency)
response.Success(c, results) response.Success(c, results)

View File

@@ -2,14 +2,13 @@
package proxy package proxy
import ( import (
"context"
"encoding/json" "encoding/json"
"fmt" "fmt"
"gemini-balancer/internal/models" "gemini-balancer/internal/models"
"gemini-balancer/internal/store" "gemini-balancer/internal/store"
"gemini-balancer/internal/syncer" "gemini-balancer/internal/syncer"
"gemini-balancer/internal/task" "gemini-balancer/internal/task"
"context"
"net" "net"
"net/http" "net/http"
"net/url" "net/url"
@@ -25,7 +24,7 @@ import (
const ( const (
TaskTypeProxySync = "proxy_sync" TaskTypeProxySync = "proxy_sync"
proxyChunkSize = 200 // 代理同步的批量大小 proxyChunkSize = 200
) )
type ProxyCheckResult struct { type ProxyCheckResult struct {
@@ -35,13 +34,11 @@ type ProxyCheckResult struct {
ErrorMessage string `json:"error_message"` ErrorMessage string `json:"error_message"`
} }
// managerCacheData
type managerCacheData struct { type managerCacheData struct {
ActiveProxies []*models.ProxyConfig ActiveProxies []*models.ProxyConfig
ProxiesByID map[uint]*models.ProxyConfig ProxiesByID map[uint]*models.ProxyConfig
} }
// manager结构体
type manager struct { type manager struct {
db *gorm.DB db *gorm.DB
syncer *syncer.CacheSyncer[managerCacheData] syncer *syncer.CacheSyncer[managerCacheData]
@@ -80,21 +77,21 @@ func newManager(db *gorm.DB, syncer *syncer.CacheSyncer[managerCacheData], taskR
} }
} }
func (m *manager) SyncProxiesInBackground(proxyStrings []string) (*task.Status, error) { func (m *manager) SyncProxiesInBackground(ctx context.Context, proxyStrings []string) (*task.Status, error) {
resourceID := "global_proxy_sync" resourceID := "global_proxy_sync"
taskStatus, err := m.task.StartTask(0, TaskTypeProxySync, resourceID, len(proxyStrings), 0) taskStatus, err := m.task.StartTask(ctx, 0, TaskTypeProxySync, resourceID, len(proxyStrings), 0)
if err != nil { if err != nil {
return nil, ErrTaskConflict return nil, ErrTaskConflict
} }
go m.runProxySyncTask(taskStatus.ID, proxyStrings) go m.runProxySyncTask(context.Background(), taskStatus.ID, proxyStrings)
return taskStatus, nil return taskStatus, nil
} }
func (m *manager) runProxySyncTask(taskID string, finalProxyStrings []string) { func (m *manager) runProxySyncTask(ctx context.Context, taskID string, finalProxyStrings []string) {
resourceID := "global_proxy_sync" resourceID := "global_proxy_sync"
var allProxies []models.ProxyConfig var allProxies []models.ProxyConfig
if err := m.db.Find(&allProxies).Error; err != nil { if err := m.db.Find(&allProxies).Error; err != nil {
m.task.EndTaskByID(taskID, resourceID, nil, fmt.Errorf("failed to fetch current proxies: %w", err)) m.task.EndTaskByID(ctx, taskID, resourceID, nil, fmt.Errorf("failed to fetch current proxies: %w", err))
return return
} }
currentProxyMap := make(map[string]uint) currentProxyMap := make(map[string]uint)
@@ -125,19 +122,19 @@ func (m *manager) runProxySyncTask(taskID string, finalProxyStrings []string) {
} }
if len(idsToDelete) > 0 { if len(idsToDelete) > 0 {
if err := m.bulkDeleteByIDs(idsToDelete); err != nil { if err := m.bulkDeleteByIDs(idsToDelete); err != nil {
m.task.EndTaskByID(taskID, resourceID, nil, fmt.Errorf("failed during proxy deletion: %w", err)) m.task.EndTaskByID(ctx, taskID, resourceID, nil, fmt.Errorf("failed during proxy deletion: %w", err))
return return
} }
} }
if len(proxiesToAdd) > 0 { if len(proxiesToAdd) > 0 {
if err := m.bulkAdd(proxiesToAdd); err != nil { if err := m.bulkAdd(proxiesToAdd); err != nil {
m.task.EndTaskByID(taskID, resourceID, nil, fmt.Errorf("failed during proxy addition: %w", err)) m.task.EndTaskByID(ctx, taskID, resourceID, nil, fmt.Errorf("failed during proxy addition: %w", err))
return return
} }
} }
result := gin.H{"added": len(proxiesToAdd), "deleted": len(idsToDelete), "final_total": len(finalProxyMap)} result := gin.H{"added": len(proxiesToAdd), "deleted": len(idsToDelete), "final_total": len(finalProxyMap)}
m.task.EndTaskByID(taskID, resourceID, result, nil) m.task.EndTaskByID(ctx, taskID, resourceID, result, nil)
m.publishChangeEvent("proxies_synced") m.publishChangeEvent(ctx, "proxies_synced")
go m.invalidate() go m.invalidate()
} }
@@ -184,14 +181,15 @@ func (m *manager) bulkDeleteByIDs(ids []uint) error {
} }
return nil return nil
} }
func (m *manager) bulkAdd(proxies []models.ProxyConfig) error { func (m *manager) bulkAdd(proxies []models.ProxyConfig) error {
return m.db.CreateInBatches(proxies, proxyChunkSize).Error return m.db.CreateInBatches(proxies, proxyChunkSize).Error
} }
func (m *manager) publishChangeEvent(reason string) { func (m *manager) publishChangeEvent(ctx context.Context, reason string) {
event := models.ProxyStatusChangedEvent{Action: reason} event := models.ProxyStatusChangedEvent{Action: reason}
eventData, _ := json.Marshal(event) eventData, _ := json.Marshal(event)
_ = m.store.Publish(models.TopicProxyStatusChanged, eventData) _ = m.store.Publish(ctx, models.TopicProxyStatusChanged, eventData)
} }
func (m *manager) assignProxyIfNeeded(apiKey *models.APIKey) (*models.ProxyConfig, error) { func (m *manager) assignProxyIfNeeded(apiKey *models.APIKey) (*models.ProxyConfig, error) {

View File

@@ -1,4 +1,4 @@
// Filename: internal/handlers/apikey_handler.go // Filename: internal/handlers/apikey_handler.go (最终决战版)
package handlers package handlers
import ( import (
@@ -31,11 +31,10 @@ func NewAPIKeyHandler(apiKeyService *service.APIKeyService, db *gorm.DB, keyImpo
} }
} }
// DTOs for API requests
type BulkAddKeysToGroupRequest struct { type BulkAddKeysToGroupRequest struct {
KeyGroupID uint `json:"key_group_id" binding:"required"` KeyGroupID uint `json:"key_group_id" binding:"required"`
Keys string `json:"keys" binding:"required"` Keys string `json:"keys" binding:"required"`
ValidateOnImport bool `json:"validate_on_import"` // OmitEmpty/default is false ValidateOnImport bool `json:"validate_on_import"`
} }
type BulkUnlinkKeysFromGroupRequest struct { type BulkUnlinkKeysFromGroupRequest struct {
@@ -72,11 +71,11 @@ type BulkTestKeysForGroupRequest struct {
} }
type BulkActionFilter struct { type BulkActionFilter struct {
Status []string `json:"status"` // Changed to slice to accept multiple statuses Status []string `json:"status"`
} }
type BulkActionRequest struct { type BulkActionRequest struct {
Action string `json:"action" binding:"required,oneof=revalidate set_status delete"` Action string `json:"action" binding:"required,oneof=revalidate set_status delete"`
NewStatus string `json:"new_status" binding:"omitempty,oneof=active disabled cooldown banned"` // For 'set_status' action NewStatus string `json:"new_status" binding:"omitempty,oneof=active disabled cooldown banned"`
Filter BulkActionFilter `json:"filter" binding:"required"` Filter BulkActionFilter `json:"filter" binding:"required"`
} }
@@ -89,7 +88,8 @@ func (h *APIKeyHandler) AddMultipleKeysToGroup(c *gin.Context) {
response.Error(c, errors.NewAPIError(errors.ErrInvalidJSON, err.Error())) response.Error(c, errors.NewAPIError(errors.ErrInvalidJSON, err.Error()))
return return
} }
taskStatus, err := h.keyImportService.StartAddKeysTask(req.KeyGroupID, req.Keys, req.ValidateOnImport) // [修正] 将请求的 context 传递给 service 层
taskStatus, err := h.keyImportService.StartAddKeysTask(c.Request.Context(), req.KeyGroupID, req.Keys, req.ValidateOnImport)
if err != nil { if err != nil {
response.Error(c, errors.NewAPIError(errors.ErrTaskInProgress, err.Error())) response.Error(c, errors.NewAPIError(errors.ErrTaskInProgress, err.Error()))
return return
@@ -104,7 +104,8 @@ func (h *APIKeyHandler) UnlinkMultipleKeysFromGroup(c *gin.Context) {
response.Error(c, errors.NewAPIError(errors.ErrInvalidJSON, err.Error())) response.Error(c, errors.NewAPIError(errors.ErrInvalidJSON, err.Error()))
return return
} }
taskStatus, err := h.keyImportService.StartUnlinkKeysTask(req.KeyGroupID, req.Keys) // [修正] 将请求的 context 传递给 service 层
taskStatus, err := h.keyImportService.StartUnlinkKeysTask(c.Request.Context(), req.KeyGroupID, req.Keys)
if err != nil { if err != nil {
response.Error(c, errors.NewAPIError(errors.ErrTaskInProgress, err.Error())) response.Error(c, errors.NewAPIError(errors.ErrTaskInProgress, err.Error()))
return return
@@ -119,7 +120,8 @@ func (h *APIKeyHandler) HardDeleteMultipleKeys(c *gin.Context) {
response.Error(c, errors.NewAPIError(errors.ErrInvalidJSON, err.Error())) response.Error(c, errors.NewAPIError(errors.ErrInvalidJSON, err.Error()))
return return
} }
taskStatus, err := h.keyImportService.StartHardDeleteKeysTask(req.Keys) // [修正] 将请求的 context 传递给 service 层
taskStatus, err := h.keyImportService.StartHardDeleteKeysTask(c.Request.Context(), req.Keys)
if err != nil { if err != nil {
response.Error(c, errors.NewAPIError(errors.ErrTaskInProgress, err.Error())) response.Error(c, errors.NewAPIError(errors.ErrTaskInProgress, err.Error()))
return return
@@ -134,7 +136,8 @@ func (h *APIKeyHandler) RestoreMultipleKeys(c *gin.Context) {
response.Error(c, errors.NewAPIError(errors.ErrInvalidJSON, err.Error())) response.Error(c, errors.NewAPIError(errors.ErrInvalidJSON, err.Error()))
return return
} }
taskStatus, err := h.keyImportService.StartRestoreKeysTask(req.Keys) // [修正] 将请求的 context 传递给 service 层
taskStatus, err := h.keyImportService.StartRestoreKeysTask(c.Request.Context(), req.Keys)
if err != nil { if err != nil {
response.Error(c, errors.NewAPIError(errors.ErrTaskInProgress, err.Error())) response.Error(c, errors.NewAPIError(errors.ErrTaskInProgress, err.Error()))
return return
@@ -148,7 +151,8 @@ func (h *APIKeyHandler) TestMultipleKeys(c *gin.Context) {
response.Error(c, errors.NewAPIError(errors.ErrInvalidJSON, err.Error())) response.Error(c, errors.NewAPIError(errors.ErrInvalidJSON, err.Error()))
return return
} }
taskStatus, err := h.keyValidationService.StartTestKeysTask(req.KeyGroupID, req.Keys) // [修正] 将请求的 context 传递给 service 层
taskStatus, err := h.keyValidationService.StartTestKeysTask(c.Request.Context(), req.KeyGroupID, req.Keys)
if err != nil { if err != nil {
response.Error(c, errors.NewAPIError(errors.ErrTaskInProgress, err.Error())) response.Error(c, errors.NewAPIError(errors.ErrTaskInProgress, err.Error()))
return return
@@ -172,7 +176,7 @@ func (h *APIKeyHandler) ListAPIKeys(c *gin.Context) {
} }
} }
if len(ids) > 0 { if len(ids) > 0 {
keys, err := h.apiKeyService.GetKeysByIds(ids) keys, err := h.apiKeyService.GetKeysByIds(c.Request.Context(), ids)
if err != nil { if err != nil {
response.Error(c, &errors.APIError{ response.Error(c, &errors.APIError{
HTTPStatus: http.StatusInternalServerError, HTTPStatus: http.StatusInternalServerError,
@@ -191,7 +195,7 @@ func (h *APIKeyHandler) ListAPIKeys(c *gin.Context) {
if params.PageSize <= 0 { if params.PageSize <= 0 {
params.PageSize = 20 params.PageSize = 20
} }
result, err := h.apiKeyService.ListAPIKeys(&params) result, err := h.apiKeyService.ListAPIKeys(c.Request.Context(), &params)
if err != nil { if err != nil {
response.Error(c, errors.ParseDBError(err)) response.Error(c, errors.ParseDBError(err))
return return
@@ -201,19 +205,16 @@ func (h *APIKeyHandler) ListAPIKeys(c *gin.Context) {
// ListKeysForGroup handles the GET /keygroups/:id/keys request. // ListKeysForGroup handles the GET /keygroups/:id/keys request.
func (h *APIKeyHandler) ListKeysForGroup(c *gin.Context) { func (h *APIKeyHandler) ListKeysForGroup(c *gin.Context) {
// 1. Manually handle the path parameter.
groupID, err := strconv.ParseUint(c.Param("id"), 10, 32) groupID, err := strconv.ParseUint(c.Param("id"), 10, 32)
if err != nil { if err != nil {
response.Error(c, errors.NewAPIError(errors.ErrBadRequest, "Invalid group ID format")) response.Error(c, errors.NewAPIError(errors.ErrBadRequest, "Invalid group ID format"))
return return
} }
// 2. Bind query parameters using the correctly tagged struct.
var params models.APIKeyQueryParams var params models.APIKeyQueryParams
if err := c.ShouldBindQuery(&params); err != nil { if err := c.ShouldBindQuery(&params); err != nil {
response.Error(c, errors.NewAPIError(errors.ErrBadRequest, err.Error())) response.Error(c, errors.NewAPIError(errors.ErrBadRequest, err.Error()))
return return
} }
// 3. Set server-side defaults and the path parameter.
if params.Page <= 0 { if params.Page <= 0 {
params.Page = 1 params.Page = 1
} }
@@ -221,15 +222,11 @@ func (h *APIKeyHandler) ListKeysForGroup(c *gin.Context) {
params.PageSize = 20 params.PageSize = 20
} }
params.KeyGroupID = uint(groupID) params.KeyGroupID = uint(groupID)
// 4. Call the service layer. paginatedResult, err := h.apiKeyService.ListAPIKeys(c.Request.Context(), &params)
paginatedResult, err := h.apiKeyService.ListAPIKeys(&params)
if err != nil { if err != nil {
response.Error(c, errors.ParseDBError(err)) response.Error(c, errors.ParseDBError(err))
return return
} }
// 5. [THE FIX] Return a successful response using the standard `response.Success`
// and a gin.H map, as confirmed to exist in your project.
response.Success(c, gin.H{ response.Success(c, gin.H{
"items": paginatedResult.Items, "items": paginatedResult.Items,
"total": paginatedResult.Total, "total": paginatedResult.Total,
@@ -239,20 +236,18 @@ func (h *APIKeyHandler) ListKeysForGroup(c *gin.Context) {
} }
func (h *APIKeyHandler) TestKeysForGroup(c *gin.Context) { func (h *APIKeyHandler) TestKeysForGroup(c *gin.Context) {
// Group ID is now correctly sourced from the URL path.
groupID, err := strconv.ParseUint(c.Param("id"), 10, 32) groupID, err := strconv.ParseUint(c.Param("id"), 10, 32)
if err != nil { if err != nil {
response.Error(c, errors.NewAPIError(errors.ErrBadRequest, "Invalid group ID format")) response.Error(c, errors.NewAPIError(errors.ErrBadRequest, "Invalid group ID format"))
return return
} }
// The request body is now simpler, only needing the keys.
var req BulkTestKeysForGroupRequest var req BulkTestKeysForGroupRequest
if err := c.ShouldBindJSON(&req); err != nil { if err := c.ShouldBindJSON(&req); err != nil {
response.Error(c, errors.NewAPIError(errors.ErrInvalidJSON, err.Error())) response.Error(c, errors.NewAPIError(errors.ErrInvalidJSON, err.Error()))
return return
} }
// Call the same underlying service, but with unambiguous context. // [修正] 将请求的 context 传递给 service 层
taskStatus, err := h.keyValidationService.StartTestKeysTask(uint(groupID), req.Keys) taskStatus, err := h.keyValidationService.StartTestKeysTask(c.Request.Context(), uint(groupID), req.Keys)
if err != nil { if err != nil {
response.Error(c, errors.NewAPIError(errors.ErrTaskInProgress, err.Error())) response.Error(c, errors.NewAPIError(errors.ErrTaskInProgress, err.Error()))
return return
@@ -267,7 +262,6 @@ func (h *APIKeyHandler) UpdateAPIKey(c *gin.Context) {
} }
// UpdateGroupAPIKeyMapping handles updating a key's status within a specific group. // UpdateGroupAPIKeyMapping handles updating a key's status within a specific group.
// Route: PUT /keygroups/:id/apikeys/:keyId
func (h *APIKeyHandler) UpdateGroupAPIKeyMapping(c *gin.Context) { func (h *APIKeyHandler) UpdateGroupAPIKeyMapping(c *gin.Context) {
groupID, err := strconv.ParseUint(c.Param("id"), 10, 32) groupID, err := strconv.ParseUint(c.Param("id"), 10, 32)
if err != nil { if err != nil {
@@ -284,8 +278,7 @@ func (h *APIKeyHandler) UpdateGroupAPIKeyMapping(c *gin.Context) {
response.Error(c, errors.NewAPIError(errors.ErrInvalidJSON, err.Error())) response.Error(c, errors.NewAPIError(errors.ErrInvalidJSON, err.Error()))
return return
} }
// Directly use the service to handle the logic updatedMapping, err := h.apiKeyService.UpdateMappingStatus(c.Request.Context(), uint(groupID), uint(keyID), req.Status)
updatedMapping, err := h.apiKeyService.UpdateMappingStatus(uint(groupID), uint(keyID), req.Status)
if err != nil { if err != nil {
var apiErr *errors.APIError var apiErr *errors.APIError
if errors.As(err, &apiErr) { if errors.As(err, &apiErr) {
@@ -305,7 +298,7 @@ func (h *APIKeyHandler) HardDeleteAPIKey(c *gin.Context) {
response.Error(c, errors.NewAPIError(errors.ErrBadRequest, "Invalid ID format")) response.Error(c, errors.NewAPIError(errors.ErrBadRequest, "Invalid ID format"))
return return
} }
if err := h.apiKeyService.HardDeleteAPIKeyByID(uint(id)); err != nil { if err := h.apiKeyService.HardDeleteAPIKeyByID(c.Request.Context(), uint(id)); err != nil {
response.Error(c, errors.ParseDBError(err)) response.Error(c, errors.ParseDBError(err))
return return
} }
@@ -313,7 +306,6 @@ func (h *APIKeyHandler) HardDeleteAPIKey(c *gin.Context) {
} }
// RestoreKeysInGroup 恢复指定Key的接口 // RestoreKeysInGroup 恢复指定Key的接口
// POST /keygroups/:id/apikeys/restore
func (h *APIKeyHandler) RestoreKeysInGroup(c *gin.Context) { func (h *APIKeyHandler) RestoreKeysInGroup(c *gin.Context) {
groupID, err := strconv.ParseUint(c.Param("id"), 10, 32) groupID, err := strconv.ParseUint(c.Param("id"), 10, 32)
if err != nil { if err != nil {
@@ -325,7 +317,7 @@ func (h *APIKeyHandler) RestoreKeysInGroup(c *gin.Context) {
response.Error(c, errors.NewAPIError(errors.ErrInvalidJSON, err.Error())) response.Error(c, errors.NewAPIError(errors.ErrInvalidJSON, err.Error()))
return return
} }
taskStatus, err := h.apiKeyService.StartRestoreKeysTask(uint(groupID), req.KeyIDs) taskStatus, err := h.apiKeyService.StartRestoreKeysTask(c.Request.Context(), uint(groupID), req.KeyIDs)
if err != nil { if err != nil {
var apiErr *errors.APIError var apiErr *errors.APIError
if errors.As(err, &apiErr) { if errors.As(err, &apiErr) {
@@ -339,14 +331,13 @@ func (h *APIKeyHandler) RestoreKeysInGroup(c *gin.Context) {
} }
// RestoreAllBannedInGroup 一键恢复所有Banned Key的接口 // RestoreAllBannedInGroup 一键恢复所有Banned Key的接口
// POST /keygroups/:id/apikeys/restore-all-banned
func (h *APIKeyHandler) RestoreAllBannedInGroup(c *gin.Context) { func (h *APIKeyHandler) RestoreAllBannedInGroup(c *gin.Context) {
groupID, err := strconv.ParseUint(c.Param("groupId"), 10, 32) groupID, err := strconv.ParseUint(c.Param("groupId"), 10, 32)
if err != nil { if err != nil {
response.Error(c, errors.NewAPIError(errors.ErrBadRequest, "Invalid Group ID format")) response.Error(c, errors.NewAPIError(errors.ErrBadRequest, "Invalid Group ID format"))
return return
} }
taskStatus, err := h.apiKeyService.StartRestoreAllBannedTask(uint(groupID)) taskStatus, err := h.apiKeyService.StartRestoreAllBannedTask(c.Request.Context(), uint(groupID))
if err != nil { if err != nil {
var apiErr *errors.APIError var apiErr *errors.APIError
if errors.As(err, &apiErr) { if errors.As(err, &apiErr) {
@@ -360,48 +351,41 @@ func (h *APIKeyHandler) RestoreAllBannedInGroup(c *gin.Context) {
} }
// HandleBulkAction handles generic bulk actions on a key group based on server-side filters. // HandleBulkAction handles generic bulk actions on a key group based on server-side filters.
// Route: POST /keygroups/:id/bulk-actions
func (h *APIKeyHandler) HandleBulkAction(c *gin.Context) { func (h *APIKeyHandler) HandleBulkAction(c *gin.Context) {
// 1. Parse GroupID from URL
groupID, err := strconv.ParseUint(c.Param("id"), 10, 32) groupID, err := strconv.ParseUint(c.Param("id"), 10, 32)
if err != nil { if err != nil {
response.Error(c, errors.NewAPIError(errors.ErrBadRequest, "Invalid Group ID format")) response.Error(c, errors.NewAPIError(errors.ErrBadRequest, "Invalid Group ID format"))
return return
} }
// 2. Bind the JSON payload to our new DTO
var req BulkActionRequest var req BulkActionRequest
if err := c.ShouldBindJSON(&req); err != nil { if err := c.ShouldBindJSON(&req); err != nil {
response.Error(c, errors.NewAPIError(errors.ErrInvalidJSON, err.Error())) response.Error(c, errors.NewAPIError(errors.ErrInvalidJSON, err.Error()))
return return
} }
// 3. Central logic: based on the action, call the appropriate service method.
var task *task.Status var task *task.Status
var apiErr *errors.APIError var apiErr *errors.APIError
switch req.Action { switch req.Action {
case "revalidate": case "revalidate":
// Assume keyValidationService has a method that accepts a filter // [修正] 将请求的 context 传递给 service 层
task, err = h.keyValidationService.StartTestKeysByFilterTask(uint(groupID), req.Filter.Status) task, err = h.keyValidationService.StartTestKeysByFilterTask(c.Request.Context(), uint(groupID), req.Filter.Status)
case "set_status": case "set_status":
if req.NewStatus == "" { if req.NewStatus == "" {
apiErr = errors.NewAPIError(errors.ErrBadRequest, "new_status is required for set_status action") apiErr = errors.NewAPIError(errors.ErrBadRequest, "new_status is required for set_status action")
break break
} }
// Assume apiKeyService has a method to update status by filter targetStatus := models.APIKeyStatus(req.NewStatus)
targetStatus := models.APIKeyStatus(req.NewStatus) // Convert string to your model's type task, err = h.apiKeyService.StartUpdateStatusByFilterTask(c.Request.Context(), uint(groupID), req.Filter.Status, targetStatus)
task, err = h.apiKeyService.StartUpdateStatusByFilterTask(uint(groupID), req.Filter.Status, targetStatus)
case "delete": case "delete":
// Assume keyImportService has a method to unlink by filter // [修正] 将请求的 context 传递给 service 层
task, err = h.keyImportService.StartUnlinkKeysByFilterTask(uint(groupID), req.Filter.Status) task, err = h.keyImportService.StartUnlinkKeysByFilterTask(c.Request.Context(), uint(groupID), req.Filter.Status)
default: default:
apiErr = errors.NewAPIError(errors.ErrBadRequest, "Unsupported action: "+req.Action) apiErr = errors.NewAPIError(errors.ErrBadRequest, "Unsupported action: "+req.Action)
} }
// 4. Handle errors from the switch block
if apiErr != nil { if apiErr != nil {
response.Error(c, apiErr) response.Error(c, apiErr)
return return
} }
if err != nil { if err != nil {
// Attempt to parse it as a known APIError, otherwise, wrap it.
var parsedErr *errors.APIError var parsedErr *errors.APIError
if errors.As(err, &parsedErr) { if errors.As(err, &parsedErr) {
response.Error(c, parsedErr) response.Error(c, parsedErr)
@@ -410,21 +394,18 @@ func (h *APIKeyHandler) HandleBulkAction(c *gin.Context) {
} }
return return
} }
// 5. Return the task status on success
response.Success(c, task) response.Success(c, task)
} }
// ExportKeysForGroup handles requests to export all keys for a group based on status filters. // ExportKeysForGroup handles requests to export all keys for a group based on status filters.
// Route: GET /keygroups/:id/apikeys/export
func (h *APIKeyHandler) ExportKeysForGroup(c *gin.Context) { func (h *APIKeyHandler) ExportKeysForGroup(c *gin.Context) {
groupID, err := strconv.ParseUint(c.Param("id"), 10, 32) groupID, err := strconv.ParseUint(c.Param("id"), 10, 32)
if err != nil { if err != nil {
response.Error(c, errors.NewAPIError(errors.ErrBadRequest, "Invalid Group ID format")) response.Error(c, errors.NewAPIError(errors.ErrBadRequest, "Invalid Group ID format"))
return return
} }
// Use QueryArray to correctly parse `status[]=active&status[]=cooldown`
statuses := c.QueryArray("status") statuses := c.QueryArray("status")
keyStrings, err := h.apiKeyService.GetAPIKeyStringsForExport(uint(groupID), statuses) keyStrings, err := h.apiKeyService.GetAPIKeyStringsForExport(c.Request.Context(), uint(groupID), statuses)
if err != nil { if err != nil {
response.Error(c, errors.ParseDBError(err)) response.Error(c, errors.ParseDBError(err))
return return

View File

@@ -30,7 +30,7 @@ func (h *DashboardHandler) GetOverview(c *gin.Context) {
c.JSON(http.StatusOK, stats) c.JSON(http.StatusOK, stats)
} }
// GetChart 获取仪表盘的图表数据 // GetChart
func (h *DashboardHandler) GetChart(c *gin.Context) { func (h *DashboardHandler) GetChart(c *gin.Context) {
var groupID *uint var groupID *uint
if groupIDStr := c.Query("groupId"); groupIDStr != "" { if groupIDStr := c.Query("groupId"); groupIDStr != "" {
@@ -40,7 +40,7 @@ func (h *DashboardHandler) GetChart(c *gin.Context) {
} }
} }
chartData, err := h.queryService.QueryHistoricalChart(groupID) chartData, err := h.queryService.QueryHistoricalChart(c.Request.Context(), groupID)
if err != nil { if err != nil {
apiErr := errors.NewAPIError(errors.ErrDatabase, err.Error()) apiErr := errors.NewAPIError(errors.ErrDatabase, err.Error())
c.JSON(apiErr.HTTPStatus, apiErr) c.JSON(apiErr.HTTPStatus, apiErr)
@@ -49,10 +49,10 @@ func (h *DashboardHandler) GetChart(c *gin.Context) {
c.JSON(http.StatusOK, chartData) c.JSON(http.StatusOK, chartData)
} }
// GetRequestStats 处理对“期间调用概览”的请求 // GetRequestStats
func (h *DashboardHandler) GetRequestStats(c *gin.Context) { func (h *DashboardHandler) GetRequestStats(c *gin.Context) {
period := c.Param("period") // 从 URL 路径中获取 period period := c.Param("period")
stats, err := h.queryService.GetRequestStatsForPeriod(period) stats, err := h.queryService.GetRequestStatsForPeriod(c.Request.Context(), period)
if err != nil { if err != nil {
apiErr := errors.NewAPIError(errors.ErrBadRequest, err.Error()) apiErr := errors.NewAPIError(errors.ErrBadRequest, err.Error())
c.JSON(apiErr.HTTPStatus, apiErr) c.JSON(apiErr.HTTPStatus, apiErr)

View File

@@ -2,6 +2,7 @@
package handlers package handlers
import ( import (
"context"
"encoding/json" "encoding/json"
"fmt" "fmt"
"gemini-balancer/internal/errors" "gemini-balancer/internal/errors"
@@ -31,7 +32,6 @@ func NewKeyGroupHandler(gm *service.GroupManager, s store.Store, qs *service.Das
} }
} }
// DTOs & 辅助函数
func isValidGroupName(name string) bool { func isValidGroupName(name string) bool {
if name == "" { if name == "" {
return false return false
@@ -40,7 +40,6 @@ func isValidGroupName(name string) bool {
return match return match
} }
// KeyGroupOperationalSettings defines the shared operational settings for a key group.
type KeyGroupOperationalSettings struct { type KeyGroupOperationalSettings struct {
EnableKeyCheck *bool `json:"enable_key_check"` EnableKeyCheck *bool `json:"enable_key_check"`
KeyCheckIntervalMinutes *int `json:"key_check_interval_minutes"` KeyCheckIntervalMinutes *int `json:"key_check_interval_minutes"`
@@ -52,7 +51,6 @@ type KeyGroupOperationalSettings struct {
MaxRetries *int `json:"max_retries"` MaxRetries *int `json:"max_retries"`
EnableSmartGateway *bool `json:"enable_smart_gateway"` EnableSmartGateway *bool `json:"enable_smart_gateway"`
} }
type CreateKeyGroupRequest struct { type CreateKeyGroupRequest struct {
Name string `json:"name" binding:"required"` Name string `json:"name" binding:"required"`
DisplayName string `json:"display_name"` DisplayName string `json:"display_name"`
@@ -60,11 +58,8 @@ type CreateKeyGroupRequest struct {
PollingStrategy string `json:"polling_strategy" binding:"omitempty,oneof=sequential random weighted"` PollingStrategy string `json:"polling_strategy" binding:"omitempty,oneof=sequential random weighted"`
EnableProxy bool `json:"enable_proxy"` EnableProxy bool `json:"enable_proxy"`
ChannelType string `json:"channel_type"` ChannelType string `json:"channel_type"`
// Embed shared operational settings
KeyGroupOperationalSettings KeyGroupOperationalSettings
} }
type UpdateKeyGroupRequest struct { type UpdateKeyGroupRequest struct {
Name *string `json:"name"` Name *string `json:"name"`
DisplayName *string `json:"display_name"` DisplayName *string `json:"display_name"`
@@ -72,15 +67,10 @@ type UpdateKeyGroupRequest struct {
PollingStrategy *string `json:"polling_strategy" binding:"omitempty,oneof=sequential random weighted"` PollingStrategy *string `json:"polling_strategy" binding:"omitempty,oneof=sequential random weighted"`
EnableProxy *bool `json:"enable_proxy"` EnableProxy *bool `json:"enable_proxy"`
ChannelType *string `json:"channel_type"` ChannelType *string `json:"channel_type"`
// Embed shared operational settings
KeyGroupOperationalSettings KeyGroupOperationalSettings
// M:N associations
AllowedUpstreams []string `json:"allowed_upstreams"` AllowedUpstreams []string `json:"allowed_upstreams"`
AllowedModels []string `json:"allowed_models"` AllowedModels []string `json:"allowed_models"`
} }
type KeyGroupResponse struct { type KeyGroupResponse struct {
ID uint `json:"id"` ID uint `json:"id"`
Name string `json:"name"` Name string `json:"name"`
@@ -96,36 +86,30 @@ type KeyGroupResponse struct {
AllowedModels []string `json:"allowed_models"` AllowedModels []string `json:"allowed_models"`
AllowedUpstreams []string `json:"allowed_upstreams"` AllowedUpstreams []string `json:"allowed_upstreams"`
} }
// [NEW] Define the detailed response structure for a single group.
type KeyGroupDetailsResponse struct { type KeyGroupDetailsResponse struct {
KeyGroupResponse KeyGroupResponse
Settings *models.GroupSettings `json:"settings,omitempty"` Settings *models.GroupSettings `json:"settings,omitempty"`
RequestConfig *models.RequestConfig `json:"request_config,omitempty"` RequestConfig *models.RequestConfig `json:"request_config,omitempty"`
} }
// transformModelsToStrings converts a slice of GroupModelMapping pointers to a slice of model names.
func transformModelsToStrings(mappings []*models.GroupModelMapping) []string { func transformModelsToStrings(mappings []*models.GroupModelMapping) []string {
modelNames := make([]string, 0, len(mappings)) modelNames := make([]string, 0, len(mappings))
for _, mapping := range mappings { for _, mapping := range mappings {
if mapping != nil { // Safety check if mapping != nil {
modelNames = append(modelNames, mapping.ModelName) modelNames = append(modelNames, mapping.ModelName)
} }
} }
return modelNames return modelNames
} }
// transformUpstreamsToStrings converts a slice of UpstreamEndpoint pointers to a slice of URLs.
func transformUpstreamsToStrings(upstreams []*models.UpstreamEndpoint) []string { func transformUpstreamsToStrings(upstreams []*models.UpstreamEndpoint) []string {
urls := make([]string, 0, len(upstreams)) urls := make([]string, 0, len(upstreams))
for _, upstream := range upstreams { for _, upstream := range upstreams {
if upstream != nil { // Safety check if upstream != nil {
urls = append(urls, upstream.URL) urls = append(urls, upstream.URL)
} }
} }
return urls return urls
} }
func (h *KeyGroupHandler) newKeyGroupResponse(group *models.KeyGroup, keyCount int64) KeyGroupResponse { func (h *KeyGroupHandler) newKeyGroupResponse(group *models.KeyGroup, keyCount int64) KeyGroupResponse {
return KeyGroupResponse{ return KeyGroupResponse{
ID: group.ID, ID: group.ID,
@@ -139,13 +123,10 @@ func (h *KeyGroupHandler) newKeyGroupResponse(group *models.KeyGroup, keyCount i
CreatedAt: group.CreatedAt, CreatedAt: group.CreatedAt,
UpdatedAt: group.UpdatedAt, UpdatedAt: group.UpdatedAt,
Order: group.Order, Order: group.Order,
AllowedModels: transformModelsToStrings(group.AllowedModels), // Call the new helper AllowedModels: transformModelsToStrings(group.AllowedModels),
AllowedUpstreams: transformUpstreamsToStrings(group.AllowedUpstreams), // Call the new helper AllowedUpstreams: transformUpstreamsToStrings(group.AllowedUpstreams),
} }
} }
// packGroupSettings is a helper to convert request-level operational settings
// into the model-level settings struct.
func packGroupSettings(settings KeyGroupOperationalSettings) *models.KeyGroupSettings { func packGroupSettings(settings KeyGroupOperationalSettings) *models.KeyGroupSettings {
return &models.KeyGroupSettings{ return &models.KeyGroupSettings{
EnableKeyCheck: settings.EnableKeyCheck, EnableKeyCheck: settings.EnableKeyCheck,
@@ -159,7 +140,6 @@ func packGroupSettings(settings KeyGroupOperationalSettings) *models.KeyGroupSet
EnableSmartGateway: settings.EnableSmartGateway, EnableSmartGateway: settings.EnableSmartGateway,
} }
} }
func (h *KeyGroupHandler) getGroupFromContext(c *gin.Context) (*models.KeyGroup, *errors.APIError) { func (h *KeyGroupHandler) getGroupFromContext(c *gin.Context) (*models.KeyGroup, *errors.APIError) {
id, err := strconv.Atoi(c.Param("id")) id, err := strconv.Atoi(c.Param("id"))
if err != nil { if err != nil {
@@ -171,7 +151,6 @@ func (h *KeyGroupHandler) getGroupFromContext(c *gin.Context) (*models.KeyGroup,
} }
return group, nil return group, nil
} }
func applyUpdateRequestToGroup(req *UpdateKeyGroupRequest, group *models.KeyGroup) { func applyUpdateRequestToGroup(req *UpdateKeyGroupRequest, group *models.KeyGroup) {
if req.Name != nil { if req.Name != nil {
group.Name = *req.Name group.Name = *req.Name
@@ -197,9 +176,10 @@ func applyUpdateRequestToGroup(req *UpdateKeyGroupRequest, group *models.KeyGrou
// publishGroupChangeEvent encapsulates the logic for marshaling and publishing a group change event. // publishGroupChangeEvent encapsulates the logic for marshaling and publishing a group change event.
func (h *KeyGroupHandler) publishGroupChangeEvent(groupID uint, reason string) { func (h *KeyGroupHandler) publishGroupChangeEvent(groupID uint, reason string) {
go func() { go func() {
ctx := context.Background()
event := models.KeyStatusChangedEvent{GroupID: groupID, ChangeReason: reason} event := models.KeyStatusChangedEvent{GroupID: groupID, ChangeReason: reason}
eventData, _ := json.Marshal(event) eventData, _ := json.Marshal(event)
h.store.Publish(models.TopicKeyStatusChanged, eventData) _ = h.store.Publish(ctx, models.TopicKeyStatusChanged, eventData)
}() }()
} }
@@ -216,7 +196,6 @@ func (h *KeyGroupHandler) CreateKeyGroup(c *gin.Context) {
return return
} }
// The core logic remains, as it's specific to creation.
p := bluemonday.StripTagsPolicy() p := bluemonday.StripTagsPolicy()
sanitizedDisplayName := p.Sanitize(req.DisplayName) sanitizedDisplayName := p.Sanitize(req.DisplayName)
sanitizedDescription := p.Sanitize(req.Description) sanitizedDescription := p.Sanitize(req.Description)
@@ -244,11 +223,9 @@ func (h *KeyGroupHandler) CreateKeyGroup(c *gin.Context) {
response.Created(c, h.newKeyGroupResponse(keyGroup, 0)) response.Created(c, h.newKeyGroupResponse(keyGroup, 0))
} }
// 统一的处理器可以处理两种情况:
// 1. GET /keygroups - 返回所有组的列表 // 1. GET /keygroups - 返回所有组的列表
// 2. GET /keygroups/:id - 返回指定ID的单个组 // 2. GET /keygroups/:id - 返回指定ID的单个组
func (h *KeyGroupHandler) GetKeyGroups(c *gin.Context) { func (h *KeyGroupHandler) GetKeyGroups(c *gin.Context) {
// Case 1: Get a single group
if idStr := c.Param("id"); idStr != "" { if idStr := c.Param("id"); idStr != "" {
group, apiErr := h.getGroupFromContext(c) group, apiErr := h.getGroupFromContext(c)
if apiErr != nil { if apiErr != nil {
@@ -265,7 +242,6 @@ func (h *KeyGroupHandler) GetKeyGroups(c *gin.Context) {
response.Success(c, detailedResponse) response.Success(c, detailedResponse)
return return
} }
// Case 2: Get all groups
allGroups := h.groupManager.GetAllGroups() allGroups := h.groupManager.GetAllGroups()
responses := make([]KeyGroupResponse, 0, len(allGroups)) responses := make([]KeyGroupResponse, 0, len(allGroups))
for _, group := range allGroups { for _, group := range allGroups {
@@ -275,7 +251,6 @@ func (h *KeyGroupHandler) GetKeyGroups(c *gin.Context) {
response.Success(c, responses) response.Success(c, responses)
} }
// UpdateKeyGroup
func (h *KeyGroupHandler) UpdateKeyGroup(c *gin.Context) { func (h *KeyGroupHandler) UpdateKeyGroup(c *gin.Context) {
group, apiErr := h.getGroupFromContext(c) group, apiErr := h.getGroupFromContext(c)
if apiErr != nil { if apiErr != nil {
@@ -304,7 +279,6 @@ func (h *KeyGroupHandler) UpdateKeyGroup(c *gin.Context) {
response.Success(c, h.newKeyGroupResponse(freshGroup, keyCount)) response.Success(c, h.newKeyGroupResponse(freshGroup, keyCount))
} }
// DeleteKeyGroup
func (h *KeyGroupHandler) DeleteKeyGroup(c *gin.Context) { func (h *KeyGroupHandler) DeleteKeyGroup(c *gin.Context) {
group, apiErr := h.getGroupFromContext(c) group, apiErr := h.getGroupFromContext(c)
if apiErr != nil { if apiErr != nil {
@@ -320,14 +294,14 @@ func (h *KeyGroupHandler) DeleteKeyGroup(c *gin.Context) {
response.Success(c, gin.H{"message": fmt.Sprintf("Group '%s' and its associated keys deleted successfully", groupName)}) response.Success(c, gin.H{"message": fmt.Sprintf("Group '%s' and its associated keys deleted successfully", groupName)})
} }
// GetKeyGroupStats
func (h *KeyGroupHandler) GetKeyGroupStats(c *gin.Context) { func (h *KeyGroupHandler) GetKeyGroupStats(c *gin.Context) {
group, apiErr := h.getGroupFromContext(c) group, apiErr := h.getGroupFromContext(c)
if apiErr != nil { if apiErr != nil {
response.Error(c, apiErr) response.Error(c, apiErr)
return return
} }
stats, err := h.queryService.GetGroupStats(group.ID)
stats, err := h.queryService.GetGroupStats(c.Request.Context(), group.ID)
if err != nil { if err != nil {
response.Error(c, errors.NewAPIError(errors.ErrDatabase, err.Error())) response.Error(c, errors.NewAPIError(errors.ErrDatabase, err.Error()))
return return
@@ -350,7 +324,6 @@ func (h *KeyGroupHandler) CloneKeyGroup(c *gin.Context) {
response.Created(c, h.newKeyGroupResponse(clonedGroup, keyCount)) response.Created(c, h.newKeyGroupResponse(clonedGroup, keyCount))
} }
// 更新分组排序
func (h *KeyGroupHandler) UpdateKeyGroupOrder(c *gin.Context) { func (h *KeyGroupHandler) UpdateKeyGroupOrder(c *gin.Context) {
var payload []service.UpdateOrderPayload var payload []service.UpdateOrderPayload
if err := c.ShouldBindJSON(&payload); err != nil { if err := c.ShouldBindJSON(&payload); err != nil {

View File

@@ -136,7 +136,6 @@ func (h *ProxyHandler) serveTransparentProxy(c *gin.Context, requestBody []byte,
var finalPromptTokens, finalCompletionTokens int var finalPromptTokens, finalCompletionTokens int
var actualRetries int = 0 var actualRetries int = 0
defer func() { defer func() {
// 如果一次尝试都未成功(例如,在第一次获取资源时就失败),则不记录日志
if lastUsedResources == nil { if lastUsedResources == nil {
h.logger.WithField("id", correlationID).Warn("No resources were used, skipping final log event.") h.logger.WithField("id", correlationID).Warn("No resources were used, skipping final log event.")
return return
@@ -151,44 +150,38 @@ func (h *ProxyHandler) serveTransparentProxy(c *gin.Context, requestBody []byte,
finalEvent.RequestLog.CompletionTokens = finalCompletionTokens finalEvent.RequestLog.CompletionTokens = finalCompletionTokens
} }
// 确保即使在成功的情况下如果recorder存在也记录最终的状态码
if finalRecorder != nil { if finalRecorder != nil {
finalEvent.RequestLog.StatusCode = finalRecorder.Code finalEvent.RequestLog.StatusCode = finalRecorder.Code
} }
if !isSuccess { if !isSuccess {
// 将 finalProxyErr 的信息填充到 RequestLog 中
if finalProxyErr != nil { if finalProxyErr != nil {
finalEvent.Error = finalProxyErr // Error 字段用于事件传递,不会被序列化到数据库 finalEvent.Error = finalProxyErr
finalEvent.RequestLog.ErrorCode = finalProxyErr.Code finalEvent.RequestLog.ErrorCode = finalProxyErr.Code
finalEvent.RequestLog.ErrorMessage = finalProxyErr.Message finalEvent.RequestLog.ErrorMessage = finalProxyErr.Message
} else if finalRecorder != nil { } else if finalRecorder != nil {
// 降级处理:如果 finalProxyErr 为空但 recorder 存在且失败
apiErr := errors.NewAPIErrorWithUpstream(finalRecorder.Code, fmt.Sprintf("UPSTREAM_%d", finalRecorder.Code), "Request failed after all retries.") apiErr := errors.NewAPIErrorWithUpstream(finalRecorder.Code, fmt.Sprintf("UPSTREAM_%d", finalRecorder.Code), "Request failed after all retries.")
finalEvent.Error = apiErr finalEvent.Error = apiErr
finalEvent.RequestLog.ErrorCode = apiErr.Code finalEvent.RequestLog.ErrorCode = apiErr.Code
finalEvent.RequestLog.ErrorMessage = apiErr.Message finalEvent.RequestLog.ErrorMessage = apiErr.Message
} }
} }
// 将完整的事件发布
eventData, err := json.Marshal(finalEvent) eventData, err := json.Marshal(finalEvent)
if err != nil { if err != nil {
h.logger.WithField("id", correlationID).WithError(err).Error("Failed to marshal final log event.") h.logger.WithField("id", correlationID).WithError(err).Error("Failed to marshal final log event.")
return return
} }
if err := h.store.Publish(models.TopicRequestFinished, eventData); err != nil { if err := h.store.Publish(context.Background(), models.TopicRequestFinished, eventData); err != nil {
h.logger.WithField("id", correlationID).WithError(err).Error("Failed to publish final log event.") h.logger.WithField("id", correlationID).WithError(err).Error("Failed to publish final log event.")
} }
}() }()
var maxRetries int var maxRetries int
if isPreciseRouting { if isPreciseRouting {
// For precise routing, use the group's setting. If not set, fall back to the global setting.
if finalOpConfig.MaxRetries != nil { if finalOpConfig.MaxRetries != nil {
maxRetries = *finalOpConfig.MaxRetries maxRetries = *finalOpConfig.MaxRetries
} else { } else {
maxRetries = h.settingsManager.GetSettings().MaxRetries maxRetries = h.settingsManager.GetSettings().MaxRetries
} }
} else { } else {
// For BasePool (intelligent aggregation), *always* use the global setting.
maxRetries = h.settingsManager.GetSettings().MaxRetries maxRetries = h.settingsManager.GetSettings().MaxRetries
} }
totalAttempts := maxRetries + 1 totalAttempts := maxRetries + 1
@@ -332,7 +325,7 @@ func (h *ProxyHandler) serveTransparentProxy(c *gin.Context, requestBody []byte,
retryEvent.ErrorMessage = attemptErr.Message retryEvent.ErrorMessage = attemptErr.Message
} }
eventData, _ := json.Marshal(retryEvent) eventData, _ := json.Marshal(retryEvent)
_ = h.store.Publish(models.TopicRequestFinished, eventData) _ = h.store.Publish(context.Background(), models.TopicRequestFinished, eventData)
} }
if finalRecorder != nil { if finalRecorder != nil {
bodyBytes := finalRecorder.Body.Bytes() bodyBytes := finalRecorder.Body.Bytes()
@@ -368,7 +361,7 @@ func (h *ProxyHandler) serveSmartStream(c *gin.Context, requestBody []byte, reso
requestFinishedEvent.StatusCode = c.Writer.Status() requestFinishedEvent.StatusCode = c.Writer.Status()
} }
eventData, _ := json.Marshal(requestFinishedEvent) eventData, _ := json.Marshal(requestFinishedEvent)
_ = h.store.Publish(models.TopicRequestFinished, eventData) _ = h.store.Publish(context.Background(), models.TopicRequestFinished, eventData)
}() }()
params := channel.SmartRequestParams{ params := channel.SmartRequestParams{
CorrelationID: correlationID, CorrelationID: correlationID,
@@ -435,7 +428,6 @@ func (h *ProxyHandler) createLogEvent(c *gin.Context, startTime time.Time, corrI
} }
} }
if res != nil { if res != nil {
// [核心修正] 填充到内嵌的 RequestLog 结构体中
if res.APIKey != nil { if res.APIKey != nil {
event.RequestLog.KeyID = &res.APIKey.ID event.RequestLog.KeyID = &res.APIKey.ID
} }
@@ -444,7 +436,6 @@ func (h *ProxyHandler) createLogEvent(c *gin.Context, startTime time.Time, corrI
} }
if res.UpstreamEndpoint != nil { if res.UpstreamEndpoint != nil {
event.RequestLog.UpstreamID = &res.UpstreamEndpoint.ID event.RequestLog.UpstreamID = &res.UpstreamEndpoint.ID
// UpstreamURL 是事件传递字段,不是数据库字段,所以在这里赋值是正确的
event.UpstreamURL = &res.UpstreamEndpoint.URL event.UpstreamURL = &res.UpstreamEndpoint.URL
} }
if res.ProxyConfig != nil { if res.ProxyConfig != nil {
@@ -464,9 +455,9 @@ func (h *ProxyHandler) getResourcesForRequest(c *gin.Context, modelName string,
return nil, errors.NewAPIError(errors.ErrInternalServer, "Invalid auth token type in context") return nil, errors.NewAPIError(errors.ErrInternalServer, "Invalid auth token type in context")
} }
if isPreciseRouting { if isPreciseRouting {
return h.resourceService.GetResourceFromGroup(authToken, groupName) return h.resourceService.GetResourceFromGroup(c.Request.Context(), authToken, groupName)
} else { } else {
return h.resourceService.GetResourceFromBasePool(authToken, modelName) return h.resourceService.GetResourceFromBasePool(c.Request.Context(), authToken, modelName)
} }
} }

View File

@@ -33,7 +33,7 @@ func (h *TaskHandler) GetTaskStatus(c *gin.Context) {
return return
} }
taskStatus, err := h.taskService.GetStatus(taskID) taskStatus, err := h.taskService.GetStatus(c.Request.Context(), taskID)
if err != nil { if err != nil {
// TODO 可以根据 service 层返回的具体错误类型进行更精细的处理 // TODO 可以根据 service 层返回的具体错误类型进行更精细的处理
response.Error(c, errors.NewAPIError(errors.ErrResourceNotFound, err.Error())) response.Error(c, errors.NewAPIError(errors.ErrResourceNotFound, err.Error()))

View File

@@ -2,6 +2,7 @@
package repository package repository
import ( import (
"context"
"encoding/json" "encoding/json"
"fmt" "fmt"
"gemini-balancer/internal/models" "gemini-balancer/internal/models"
@@ -22,7 +23,7 @@ const (
BasePoolRandomCooldown = "basepool:%s:keys:random:cooldown" BasePoolRandomCooldown = "basepool:%s:keys:random:cooldown"
) )
func (r *gormKeyRepository) LoadAllKeysToStore() error { func (r *gormKeyRepository) LoadAllKeysToStore(ctx context.Context) error {
r.logger.Info("Starting to load all keys and associations into cache, including polling structures...") r.logger.Info("Starting to load all keys and associations into cache, including polling structures...")
var allMappings []*models.GroupAPIKeyMapping var allMappings []*models.GroupAPIKeyMapping
if err := r.db.Preload("APIKey").Find(&allMappings).Error; err != nil { if err := r.db.Preload("APIKey").Find(&allMappings).Error; err != nil {
@@ -48,7 +49,7 @@ func (r *gormKeyRepository) LoadAllKeysToStore() error {
} }
activeKeysByGroup := make(map[uint][]*models.GroupAPIKeyMapping) activeKeysByGroup := make(map[uint][]*models.GroupAPIKeyMapping)
pipe := r.store.Pipeline() pipe := r.store.Pipeline(context.Background())
detailsToSet := make(map[string][]byte) detailsToSet := make(map[string][]byte)
var allGroups []*models.KeyGroup var allGroups []*models.KeyGroup
if err := r.db.Find(&allGroups).Error; err == nil { if err := r.db.Find(&allGroups).Error; err == nil {
@@ -100,14 +101,14 @@ func (r *gormKeyRepository) LoadAllKeysToStore() error {
pipe.SAdd(fmt.Sprintf(KeyGroup, groupID), activeKeyIDs...) pipe.SAdd(fmt.Sprintf(KeyGroup, groupID), activeKeyIDs...)
pipe.LPush(fmt.Sprintf(KeyGroupSequential, groupID), activeKeyIDs...) pipe.LPush(fmt.Sprintf(KeyGroupSequential, groupID), activeKeyIDs...)
pipe.SAdd(fmt.Sprintf(KeyGroupRandomMain, groupID), activeKeyIDs...) pipe.SAdd(fmt.Sprintf(KeyGroupRandomMain, groupID), activeKeyIDs...)
go r.store.ZAdd(fmt.Sprintf(KeyGroupLRU, groupID), lruMembers) go r.store.ZAdd(context.Background(), fmt.Sprintf(KeyGroupLRU, groupID), lruMembers)
} }
if err := pipe.Exec(); err != nil { if err := pipe.Exec(); err != nil {
return fmt.Errorf("failed to execute pipeline for cache rebuild: %w", err) return fmt.Errorf("failed to execute pipeline for cache rebuild: %w", err)
} }
for key, value := range detailsToSet { for key, value := range detailsToSet {
if err := r.store.Set(key, value, 0); err != nil { if err := r.store.Set(context.Background(), key, value, 0); err != nil {
r.logger.WithError(err).Warnf("Failed to set key detail in cache: %s", key) r.logger.WithError(err).Warnf("Failed to set key detail in cache: %s", key)
} }
} }
@@ -124,16 +125,16 @@ func (r *gormKeyRepository) updateStoreCacheForKey(key *models.APIKey) error {
if err != nil { if err != nil {
return fmt.Errorf("failed to marshal key %d for cache update: %w", key.ID, err) return fmt.Errorf("failed to marshal key %d for cache update: %w", key.ID, err)
} }
return r.store.Set(fmt.Sprintf(KeyDetails, key.ID), keyJSON, 0) return r.store.Set(context.Background(), fmt.Sprintf(KeyDetails, key.ID), keyJSON, 0)
} }
func (r *gormKeyRepository) removeStoreCacheForKey(key *models.APIKey) error { func (r *gormKeyRepository) removeStoreCacheForKey(ctx context.Context, key *models.APIKey) error {
groupIDs, err := r.GetGroupsForKey(key.ID) groupIDs, err := r.GetGroupsForKey(ctx, key.ID)
if err != nil { if err != nil {
r.logger.Warnf("failed to get groups for key %d to clean up cache lists: %v", key.ID, err) r.logger.Warnf("failed to get groups for key %d to clean up cache lists: %v", key.ID, err)
} }
pipe := r.store.Pipeline() pipe := r.store.Pipeline(ctx)
pipe.Del(fmt.Sprintf(KeyDetails, key.ID)) pipe.Del(fmt.Sprintf(KeyDetails, key.ID))
for _, groupID := range groupIDs { for _, groupID := range groupIDs {
@@ -144,13 +145,13 @@ func (r *gormKeyRepository) removeStoreCacheForKey(key *models.APIKey) error {
pipe.LRem(fmt.Sprintf(KeyGroupSequential, groupID), 0, keyIDStr) pipe.LRem(fmt.Sprintf(KeyGroupSequential, groupID), 0, keyIDStr)
pipe.SRem(fmt.Sprintf(KeyGroupRandomMain, groupID), keyIDStr) pipe.SRem(fmt.Sprintf(KeyGroupRandomMain, groupID), keyIDStr)
pipe.SRem(fmt.Sprintf(KeyGroupRandomCooldown, groupID), keyIDStr) pipe.SRem(fmt.Sprintf(KeyGroupRandomCooldown, groupID), keyIDStr)
go r.store.ZRem(fmt.Sprintf(KeyGroupLRU, groupID), keyIDStr) go r.store.ZRem(context.Background(), fmt.Sprintf(KeyGroupLRU, groupID), keyIDStr)
} }
return pipe.Exec() return pipe.Exec()
} }
func (r *gormKeyRepository) updateStoreCacheForMapping(mapping *models.GroupAPIKeyMapping) error { func (r *gormKeyRepository) updateStoreCacheForMapping(mapping *models.GroupAPIKeyMapping) error {
pipe := r.store.Pipeline() pipe := r.store.Pipeline(context.Background())
activeKeyListKey := fmt.Sprintf("group:%d:keys:active", mapping.KeyGroupID) activeKeyListKey := fmt.Sprintf("group:%d:keys:active", mapping.KeyGroupID)
pipe.LRem(activeKeyListKey, 0, mapping.APIKeyID) pipe.LRem(activeKeyListKey, 0, mapping.APIKeyID)
if mapping.Status == models.StatusActive { if mapping.Status == models.StatusActive {
@@ -159,7 +160,7 @@ func (r *gormKeyRepository) updateStoreCacheForMapping(mapping *models.GroupAPIK
return pipe.Exec() return pipe.Exec()
} }
func (r *gormKeyRepository) HandleCacheUpdateEventBatch(mappings []*models.GroupAPIKeyMapping) error { func (r *gormKeyRepository) HandleCacheUpdateEventBatch(ctx context.Context, mappings []*models.GroupAPIKeyMapping) error {
if len(mappings) == 0 { if len(mappings) == 0 {
return nil return nil
} }
@@ -184,7 +185,7 @@ func (r *gormKeyRepository) HandleCacheUpdateEventBatch(mappings []*models.Group
} }
groupUpdates[mapping.KeyGroupID] = update groupUpdates[mapping.KeyGroupID] = update
} }
pipe := r.store.Pipeline() pipe := r.store.Pipeline(context.Background())
var pipelineError error var pipelineError error
for groupID, updates := range groupUpdates { for groupID, updates := range groupUpdates {
activeKeyListKey := fmt.Sprintf("group:%d:keys:active", groupID) activeKeyListKey := fmt.Sprintf("group:%d:keys:active", groupID)

View File

@@ -7,6 +7,7 @@ import (
"fmt" "fmt"
"gemini-balancer/internal/models" "gemini-balancer/internal/models"
"context"
"math/rand" "math/rand"
"strings" "strings"
"time" "time"
@@ -115,7 +116,7 @@ func (r *gormKeyRepository) Update(key *models.APIKey) error {
} }
func (r *gormKeyRepository) HardDeleteByID(id uint) error { func (r *gormKeyRepository) HardDeleteByID(id uint) error {
key, err := r.GetKeyByID(id) // This now returns a decrypted key key, err := r.GetKeyByID(id)
if err != nil { if err != nil {
return err return err
} }
@@ -125,7 +126,7 @@ func (r *gormKeyRepository) HardDeleteByID(id uint) error {
if err != nil { if err != nil {
return err return err
} }
if err := r.removeStoreCacheForKey(key); err != nil { if err := r.removeStoreCacheForKey(context.Background(), key); err != nil {
r.logger.Warnf("DB deleted key ID %d, but cache removal failed: %v", id, err) r.logger.Warnf("DB deleted key ID %d, but cache removal failed: %v", id, err)
} }
return nil return nil
@@ -140,16 +141,13 @@ func (r *gormKeyRepository) HardDeleteByValues(keyValues []string) (int64, error
hash := sha256.Sum256([]byte(v)) hash := sha256.Sum256([]byte(v))
hashes[i] = hex.EncodeToString(hash[:]) hashes[i] = hex.EncodeToString(hash[:])
} }
// Find the full key objects first to update the cache later.
var keysToDelete []models.APIKey var keysToDelete []models.APIKey
// [MODIFIED] Find by hash.
if err := r.db.Where("api_key_hash IN ?", hashes).Find(&keysToDelete).Error; err != nil { if err := r.db.Where("api_key_hash IN ?", hashes).Find(&keysToDelete).Error; err != nil {
return 0, err return 0, err
} }
if len(keysToDelete) == 0 { if len(keysToDelete) == 0 {
return 0, nil return 0, nil
} }
// Decrypt them to ensure cache has plaintext if needed.
if err := r.decryptKeys(keysToDelete); err != nil { if err := r.decryptKeys(keysToDelete); err != nil {
r.logger.Warnf("Decryption failed for keys before hard delete, cache removal may be impacted: %v", err) r.logger.Warnf("Decryption failed for keys before hard delete, cache removal may be impacted: %v", err)
} }
@@ -167,7 +165,7 @@ func (r *gormKeyRepository) HardDeleteByValues(keyValues []string) (int64, error
return 0, err return 0, err
} }
for i := range keysToDelete { for i := range keysToDelete {
if err := r.removeStoreCacheForKey(&keysToDelete[i]); err != nil { if err := r.removeStoreCacheForKey(context.Background(), &keysToDelete[i]); err != nil {
r.logger.Warnf("DB deleted key ID %d, but cache removal failed: %v", keysToDelete[i].ID, err) r.logger.Warnf("DB deleted key ID %d, but cache removal failed: %v", keysToDelete[i].ID, err)
} }
} }

View File

@@ -2,6 +2,7 @@
package repository package repository
import ( import (
"context"
"crypto/sha256" "crypto/sha256"
"encoding/hex" "encoding/hex"
"gemini-balancer/internal/models" "gemini-balancer/internal/models"
@@ -110,13 +111,13 @@ func (r *gormKeyRepository) deleteOrphanKeysLogic(db *gorm.DB) (int64, error) {
} }
result := db.Delete(&models.APIKey{}, orphanKeyIDs) result := db.Delete(&models.APIKey{}, orphanKeyIDs)
//result := db.Unscoped().Delete(&models.APIKey{}, orphanKeyIDs)
if result.Error != nil { if result.Error != nil {
return 0, result.Error return 0, result.Error
} }
for i := range keysToDelete { for i := range keysToDelete {
if err := r.removeStoreCacheForKey(&keysToDelete[i]); err != nil { // [修正] 使用 context.Background() 调用已更新的缓存清理函数
if err := r.removeStoreCacheForKey(context.Background(), &keysToDelete[i]); err != nil {
r.logger.Warnf("DB deleted orphan key ID %d, but cache removal failed: %v", keysToDelete[i].ID, err) r.logger.Warnf("DB deleted orphan key ID %d, but cache removal failed: %v", keysToDelete[i].ID, err)
} }
} }
@@ -144,7 +145,7 @@ func (r *gormKeyRepository) GetActiveMasterKeys() ([]*models.APIKey, error) {
return keys, nil return keys, nil
} }
func (r *gormKeyRepository) UpdateAPIKeyStatus(keyID uint, status models.MasterAPIKeyStatus) error { func (r *gormKeyRepository) UpdateAPIKeyStatus(ctx context.Context, keyID uint, status models.MasterAPIKeyStatus) error {
err := r.executeTransactionWithRetry(func(tx *gorm.DB) error { err := r.executeTransactionWithRetry(func(tx *gorm.DB) error {
result := tx.Model(&models.APIKey{}). result := tx.Model(&models.APIKey{}).
Where("id = ?", keyID). Where("id = ?", keyID).
@@ -160,7 +161,7 @@ func (r *gormKeyRepository) UpdateAPIKeyStatus(keyID uint, status models.MasterA
if err == nil { if err == nil {
r.logger.Infof("MasterStatus for key ID %d changed, triggering a full cache reload.", keyID) r.logger.Infof("MasterStatus for key ID %d changed, triggering a full cache reload.", keyID)
go func() { go func() {
if err := r.LoadAllKeysToStore(); err != nil { if err := r.LoadAllKeysToStore(context.Background()); err != nil {
r.logger.Errorf("Failed to reload cache after MasterStatus change for key ID %d: %v", keyID, err) r.logger.Errorf("Failed to reload cache after MasterStatus change for key ID %d: %v", keyID, err)
} }
}() }()

View File

@@ -2,6 +2,7 @@
package repository package repository
import ( import (
"context"
"crypto/sha256" "crypto/sha256"
"encoding/hex" "encoding/hex"
"errors" "errors"
@@ -14,7 +15,7 @@ import (
"gorm.io/gorm/clause" "gorm.io/gorm/clause"
) )
func (r *gormKeyRepository) LinkKeysToGroup(groupID uint, keyIDs []uint) error { func (r *gormKeyRepository) LinkKeysToGroup(ctx context.Context, groupID uint, keyIDs []uint) error {
if len(keyIDs) == 0 { if len(keyIDs) == 0 {
return nil return nil
} }
@@ -34,12 +35,12 @@ func (r *gormKeyRepository) LinkKeysToGroup(groupID uint, keyIDs []uint) error {
} }
for _, keyID := range keyIDs { for _, keyID := range keyIDs {
r.store.SAdd(fmt.Sprintf("key:%d:groups", keyID), groupID) r.store.SAdd(context.Background(), fmt.Sprintf("key:%d:groups", keyID), groupID)
} }
return nil return nil
} }
func (r *gormKeyRepository) UnlinkKeysFromGroup(groupID uint, keyIDs []uint) (int64, error) { func (r *gormKeyRepository) UnlinkKeysFromGroup(ctx context.Context, groupID uint, keyIDs []uint) (int64, error) {
if len(keyIDs) == 0 { if len(keyIDs) == 0 {
return 0, nil return 0, nil
} }
@@ -63,16 +64,16 @@ func (r *gormKeyRepository) UnlinkKeysFromGroup(groupID uint, keyIDs []uint) (in
activeKeyListKey := fmt.Sprintf("group:%d:keys:active", groupID) activeKeyListKey := fmt.Sprintf("group:%d:keys:active", groupID)
for _, keyID := range keyIDs { for _, keyID := range keyIDs {
r.store.SRem(fmt.Sprintf("key:%d:groups", keyID), groupID) r.store.SRem(context.Background(), fmt.Sprintf("key:%d:groups", keyID), groupID)
r.store.LRem(activeKeyListKey, 0, strconv.Itoa(int(keyID))) r.store.LRem(context.Background(), activeKeyListKey, 0, strconv.Itoa(int(keyID)))
} }
return unlinkedCount, nil return unlinkedCount, nil
} }
func (r *gormKeyRepository) GetGroupsForKey(keyID uint) ([]uint, error) { func (r *gormKeyRepository) GetGroupsForKey(ctx context.Context, keyID uint) ([]uint, error) {
cacheKey := fmt.Sprintf("key:%d:groups", keyID) cacheKey := fmt.Sprintf("key:%d:groups", keyID)
strGroupIDs, err := r.store.SMembers(cacheKey) strGroupIDs, err := r.store.SMembers(context.Background(), cacheKey)
if err != nil || len(strGroupIDs) == 0 { if err != nil || len(strGroupIDs) == 0 {
var groupIDs []uint var groupIDs []uint
dbErr := r.db.Table("group_api_key_mappings").Where("api_key_id = ?", keyID).Pluck("key_group_id", &groupIDs).Error dbErr := r.db.Table("group_api_key_mappings").Where("api_key_id = ?", keyID).Pluck("key_group_id", &groupIDs).Error
@@ -84,7 +85,7 @@ func (r *gormKeyRepository) GetGroupsForKey(keyID uint) ([]uint, error) {
for _, id := range groupIDs { for _, id := range groupIDs {
interfaceSlice = append(interfaceSlice, id) interfaceSlice = append(interfaceSlice, id)
} }
r.store.SAdd(cacheKey, interfaceSlice...) r.store.SAdd(context.Background(), cacheKey, interfaceSlice...)
} }
return groupIDs, nil return groupIDs, nil
} }
@@ -103,7 +104,7 @@ func (r *gormKeyRepository) GetMapping(groupID, keyID uint) (*models.GroupAPIKey
return &mapping, err return &mapping, err
} }
func (r *gormKeyRepository) UpdateMapping(mapping *models.GroupAPIKeyMapping) error { func (r *gormKeyRepository) UpdateMapping(ctx context.Context, mapping *models.GroupAPIKeyMapping) error {
err := r.executeTransactionWithRetry(func(tx *gorm.DB) error { err := r.executeTransactionWithRetry(func(tx *gorm.DB) error {
return tx.Save(mapping).Error return tx.Save(mapping).Error
}) })

View File

@@ -1,7 +1,8 @@
// Filename: internal/repository/key_selector.go // Filename: internal/repository/key_selector.go (经审查后最终修复版)
package repository package repository
import ( import (
"context"
"crypto/sha1" "crypto/sha1"
"encoding/json" "encoding/json"
"errors" "errors"
@@ -23,19 +24,18 @@ const (
) )
// SelectOneActiveKey 根据指定的轮询策略从缓存中高效地选取一个可用的API密钥。 // SelectOneActiveKey 根据指定的轮询策略从缓存中高效地选取一个可用的API密钥。
func (r *gormKeyRepository) SelectOneActiveKey(ctx context.Context, group *models.KeyGroup) (*models.APIKey, *models.GroupAPIKeyMapping, error) {
func (r *gormKeyRepository) SelectOneActiveKey(group *models.KeyGroup) (*models.APIKey, *models.GroupAPIKeyMapping, error) {
var keyIDStr string var keyIDStr string
var err error var err error
switch group.PollingStrategy { switch group.PollingStrategy {
case models.StrategySequential: case models.StrategySequential:
sequentialKey := fmt.Sprintf(KeyGroupSequential, group.ID) sequentialKey := fmt.Sprintf(KeyGroupSequential, group.ID)
keyIDStr, err = r.store.Rotate(sequentialKey) keyIDStr, err = r.store.Rotate(ctx, sequentialKey)
case models.StrategyWeighted: case models.StrategyWeighted:
lruKey := fmt.Sprintf(KeyGroupLRU, group.ID) lruKey := fmt.Sprintf(KeyGroupLRU, group.ID)
results, zerr := r.store.ZRange(lruKey, 0, 0) results, zerr := r.store.ZRange(ctx, lruKey, 0, 0)
if zerr == nil && len(results) > 0 { if zerr == nil && len(results) > 0 {
keyIDStr = results[0] keyIDStr = results[0]
} }
@@ -44,11 +44,11 @@ func (r *gormKeyRepository) SelectOneActiveKey(group *models.KeyGroup) (*models.
case models.StrategyRandom: case models.StrategyRandom:
mainPoolKey := fmt.Sprintf(KeyGroupRandomMain, group.ID) mainPoolKey := fmt.Sprintf(KeyGroupRandomMain, group.ID)
cooldownPoolKey := fmt.Sprintf(KeyGroupRandomCooldown, group.ID) cooldownPoolKey := fmt.Sprintf(KeyGroupRandomCooldown, group.ID)
keyIDStr, err = r.store.PopAndCycleSetMember(mainPoolKey, cooldownPoolKey) keyIDStr, err = r.store.PopAndCycleSetMember(ctx, mainPoolKey, cooldownPoolKey)
default: // 默认或未指定策略时,使用基础的随机策略 default: // 默认或未指定策略时,使用基础的随机策略
activeKeySetKey := fmt.Sprintf(KeyGroup, group.ID) activeKeySetKey := fmt.Sprintf(KeyGroup, group.ID)
keyIDStr, err = r.store.SRandMember(activeKeySetKey) keyIDStr, err = r.store.SRandMember(ctx, activeKeySetKey)
} }
if err != nil { if err != nil {
@@ -65,27 +65,25 @@ func (r *gormKeyRepository) SelectOneActiveKey(group *models.KeyGroup) (*models.
keyID, _ := strconv.ParseUint(keyIDStr, 10, 64) keyID, _ := strconv.ParseUint(keyIDStr, 10, 64)
apiKey, mapping, err := r.getKeyDetailsFromCache(uint(keyID), group.ID) apiKey, mapping, err := r.getKeyDetailsFromCache(ctx, uint(keyID), group.ID)
if err != nil { if err != nil {
r.logger.WithError(err).Warnf("Cache inconsistency: Failed to get details for selected key ID %d", keyID) r.logger.WithError(err).Warnf("Cache inconsistency: Failed to get details for selected key ID %d", keyID)
// TODO 可以在此加入重试逻辑,再次调用 SelectOneActiveKey(group)
return nil, nil, err return nil, nil, err
} }
if group.PollingStrategy == models.StrategyWeighted { if group.PollingStrategy == models.StrategyWeighted {
go r.UpdateKeyUsageTimestamp(group.ID, uint(keyID)) go r.UpdateKeyUsageTimestamp(context.Background(), group.ID, uint(keyID))
} }
return apiKey, mapping, nil return apiKey, mapping, nil
} }
// SelectOneActiveKeyFromBasePool 为智能聚合模式设计的全新轮询器。 // SelectOneActiveKeyFromBasePool 为智能聚合模式设计的全新轮询器。
func (r *gormKeyRepository) SelectOneActiveKeyFromBasePool(pool *BasePool) (*models.APIKey, *models.KeyGroup, error) { func (r *gormKeyRepository) SelectOneActiveKeyFromBasePool(ctx context.Context, pool *BasePool) (*models.APIKey, *models.KeyGroup, error) {
// 生成唯一的池ID确保不同请求组合的轮询状态相互隔离
poolID := generatePoolID(pool.CandidateGroups) poolID := generatePoolID(pool.CandidateGroups)
log := r.logger.WithField("pool_id", poolID) log := r.logger.WithField("pool_id", poolID)
if err := r.ensureBasePoolCacheExists(pool, poolID); err != nil { if err := r.ensureBasePoolCacheExists(ctx, pool, poolID); err != nil {
log.WithError(err).Error("Failed to ensure BasePool cache exists.") log.WithError(err).Error("Failed to ensure BasePool cache exists.")
return nil, nil, err return nil, nil, err
} }
@@ -96,10 +94,10 @@ func (r *gormKeyRepository) SelectOneActiveKeyFromBasePool(pool *BasePool) (*mod
switch pool.PollingStrategy { switch pool.PollingStrategy {
case models.StrategySequential: case models.StrategySequential:
sequentialKey := fmt.Sprintf(BasePoolSequential, poolID) sequentialKey := fmt.Sprintf(BasePoolSequential, poolID)
keyIDStr, err = r.store.Rotate(sequentialKey) keyIDStr, err = r.store.Rotate(ctx, sequentialKey)
case models.StrategyWeighted: case models.StrategyWeighted:
lruKey := fmt.Sprintf(BasePoolLRU, poolID) lruKey := fmt.Sprintf(BasePoolLRU, poolID)
results, zerr := r.store.ZRange(lruKey, 0, 0) results, zerr := r.store.ZRange(ctx, lruKey, 0, 0)
if zerr == nil && len(results) > 0 { if zerr == nil && len(results) > 0 {
keyIDStr = results[0] keyIDStr = results[0]
} }
@@ -107,12 +105,11 @@ func (r *gormKeyRepository) SelectOneActiveKeyFromBasePool(pool *BasePool) (*mod
case models.StrategyRandom: case models.StrategyRandom:
mainPoolKey := fmt.Sprintf(BasePoolRandomMain, poolID) mainPoolKey := fmt.Sprintf(BasePoolRandomMain, poolID)
cooldownPoolKey := fmt.Sprintf(BasePoolRandomCooldown, poolID) cooldownPoolKey := fmt.Sprintf(BasePoolRandomCooldown, poolID)
keyIDStr, err = r.store.PopAndCycleSetMember(mainPoolKey, cooldownPoolKey) keyIDStr, err = r.store.PopAndCycleSetMember(ctx, mainPoolKey, cooldownPoolKey)
default: // 默认策略,应该在 ensureCache 中处理,但作为降级方案 default:
log.Warnf("Default polling strategy triggered inside selection. This should be rare.") log.Warnf("Default polling strategy triggered inside selection. This should be rare.")
sequentialKey := fmt.Sprintf(BasePoolSequential, poolID) sequentialKey := fmt.Sprintf(BasePoolSequential, poolID)
keyIDStr, err = r.store.LIndex(sequentialKey, 0) keyIDStr, err = r.store.LIndex(ctx, sequentialKey, 0)
} }
if err != nil { if err != nil {
@@ -128,12 +125,10 @@ func (r *gormKeyRepository) SelectOneActiveKeyFromBasePool(pool *BasePool) (*mod
keyID, _ := strconv.ParseUint(keyIDStr, 10, 64) keyID, _ := strconv.ParseUint(keyIDStr, 10, 64)
for _, group := range pool.CandidateGroups { for _, group := range pool.CandidateGroups {
apiKey, mapping, cacheErr := r.getKeyDetailsFromCache(uint(keyID), group.ID) apiKey, mapping, cacheErr := r.getKeyDetailsFromCache(ctx, uint(keyID), group.ID)
if cacheErr == nil && apiKey != nil && mapping != nil { if cacheErr == nil && apiKey != nil && mapping != nil {
if pool.PollingStrategy == models.StrategyWeighted { if pool.PollingStrategy == models.StrategyWeighted {
go r.updateKeyUsageTimestampForPool(context.Background(), poolID, uint(keyID))
go r.updateKeyUsageTimestampForPool(poolID, uint(keyID))
} }
return apiKey, group, nil return apiKey, group, nil
} }
@@ -144,42 +139,39 @@ func (r *gormKeyRepository) SelectOneActiveKeyFromBasePool(pool *BasePool) (*mod
} }
// ensureBasePoolCacheExists 动态创建 BasePool 的 Redis 结构 // ensureBasePoolCacheExists 动态创建 BasePool 的 Redis 结构
func (r *gormKeyRepository) ensureBasePoolCacheExists(pool *BasePool, poolID string) error { func (r *gormKeyRepository) ensureBasePoolCacheExists(ctx context.Context, pool *BasePool, poolID string) error {
listKey := fmt.Sprintf(BasePoolSequential, poolID) listKey := fmt.Sprintf(BasePoolSequential, poolID)
// --- [逻辑优化] 提前处理“毒丸”,让逻辑更清晰 --- exists, err := r.store.Exists(ctx, listKey)
exists, err := r.store.Exists(listKey)
if err != nil { if err != nil {
r.logger.WithError(err).Errorf("Failed to check existence for pool_id '%s'", poolID) r.logger.WithError(err).Errorf("Failed to check existence for pool_id '%s'", poolID)
return err // 直接返回读取错误 return err
} }
if exists { if exists {
val, err := r.store.LIndex(listKey, 0) val, err := r.store.LIndex(ctx, listKey, 0)
if err != nil { if err != nil {
// 如果连 LIndex 都失败,说明缓存可能已损坏,允许重建
r.logger.WithError(err).Warnf("Cache for pool_id '%s' exists but is unreadable. Forcing rebuild.", poolID) r.logger.WithError(err).Warnf("Cache for pool_id '%s' exists but is unreadable. Forcing rebuild.", poolID)
} else { } else {
if val == EmptyPoolPlaceholder { if val == EmptyPoolPlaceholder {
return gorm.ErrRecordNotFound // 已知为空,直接返回 return gorm.ErrRecordNotFound
} }
return nil // 缓存有效,直接返回 return nil
} }
} }
// --- [锁机制优化] 增加分布式锁,防止并发构建时的惊群效应 ---
lockKey := fmt.Sprintf("lock:basepool:%s", poolID) lockKey := fmt.Sprintf("lock:basepool:%s", poolID)
acquired, err := r.store.SetNX(lockKey, []byte("1"), 10*time.Second) // 10秒锁超时 acquired, err := r.store.SetNX(ctx, lockKey, []byte("1"), 10*time.Second)
if err != nil { if err != nil {
r.logger.WithError(err).Error("Failed to attempt acquiring distributed lock for basepool build.") r.logger.WithError(err).Error("Failed to attempt acquiring distributed lock for basepool build.")
return err return err
} }
if !acquired { if !acquired {
// 未获取到锁,等待一小段时间后重试,让持有锁的协程完成构建
time.Sleep(100 * time.Millisecond) time.Sleep(100 * time.Millisecond)
return r.ensureBasePoolCacheExists(pool, poolID) return r.ensureBasePoolCacheExists(ctx, pool, poolID)
} }
defer r.store.Del(lockKey) // 确保在函数退出时释放锁 defer r.store.Del(context.Background(), lockKey)
// 双重检查,防止在获取锁的间隙,已有其他协程完成了构建
if exists, _ := r.store.Exists(listKey); exists { if exists, _ := r.store.Exists(ctx, listKey); exists {
return nil return nil
} }
r.logger.Infof("BasePool cache for pool_id '%s' not found or is unreadable. Building now...", poolID) r.logger.Infof("BasePool cache for pool_id '%s' not found or is unreadable. Building now...", poolID)
@@ -187,22 +179,15 @@ func (r *gormKeyRepository) ensureBasePoolCacheExists(pool *BasePool, poolID str
lruMembers := make(map[string]float64) lruMembers := make(map[string]float64)
for _, group := range pool.CandidateGroups { for _, group := range pool.CandidateGroups {
activeKeySetKey := fmt.Sprintf(KeyGroup, group.ID) activeKeySetKey := fmt.Sprintf(KeyGroup, group.ID)
groupKeyIDs, err := r.store.SMembers(activeKeySetKey) groupKeyIDs, err := r.store.SMembers(ctx, activeKeySetKey)
// --- [核心修正] ---
// 这是整个问题的根源。我们绝不能在读取失败时,默默地`continue`。
// 任何读取源数据的失败,都必须被视为一次构建过程的彻底失败,并立即中止。
if err != nil { if err != nil {
r.logger.WithError(err).Errorf("FATAL: Failed to read active keys for group %d during BasePool build. Aborting build process for pool_id '%s'.", group.ID, poolID) r.logger.WithError(err).Errorf("FATAL: Failed to read active keys for group %d during BasePool build. Aborting build process for pool_id '%s'.", group.ID, poolID)
// 返回这个瞬时错误。这会导致本次请求失败,但绝不会写入“毒丸”,
// 从而给了下一次请求一个全新的、成功的机会。
return err return err
} }
// 只有在 SMembers 成功时,才继续处理
allActiveKeyIDs = append(allActiveKeyIDs, groupKeyIDs...) allActiveKeyIDs = append(allActiveKeyIDs, groupKeyIDs...)
for _, keyIDStr := range groupKeyIDs { for _, keyIDStr := range groupKeyIDs {
keyID, _ := strconv.ParseUint(keyIDStr, 10, 64) keyID, _ := strconv.ParseUint(keyIDStr, 10, 64)
_, mapping, err := r.getKeyDetailsFromCache(uint(keyID), group.ID) _, mapping, err := r.getKeyDetailsFromCache(ctx, uint(keyID), group.ID)
if err == nil && mapping != nil { if err == nil && mapping != nil {
var score float64 var score float64
if mapping.LastUsedAt != nil { if mapping.LastUsedAt != nil {
@@ -213,12 +198,9 @@ func (r *gormKeyRepository) ensureBasePoolCacheExists(pool *BasePool, poolID str
} }
} }
// --- [逻辑修正] ---
// 只有在“我们成功读取了所有数据,但发现数据本身是空的”这种情况下,
// 才允许写入“毒丸”。
if len(allActiveKeyIDs) == 0 { if len(allActiveKeyIDs) == 0 {
r.logger.Warnf("No active keys found for any candidate groups for pool_id '%s'. Setting empty pool placeholder.", poolID) r.logger.Warnf("No active keys found for any candidate groups for pool_id '%s'. Setting empty pool placeholder.", poolID)
pipe := r.store.Pipeline() pipe := r.store.Pipeline(ctx)
pipe.LPush(listKey, EmptyPoolPlaceholder) pipe.LPush(listKey, EmptyPoolPlaceholder)
pipe.Expire(listKey, EmptyCacheTTL) pipe.Expire(listKey, EmptyCacheTTL)
if err := pipe.Exec(); err != nil { if err := pipe.Exec(); err != nil {
@@ -226,14 +208,10 @@ func (r *gormKeyRepository) ensureBasePoolCacheExists(pool *BasePool, poolID str
} }
return gorm.ErrRecordNotFound return gorm.ErrRecordNotFound
} }
// 使用管道填充所有轮询结构
pipe := r.store.Pipeline()
// 1. 顺序
pipe.LPush(fmt.Sprintf(BasePoolSequential, poolID), toInterfaceSlice(allActiveKeyIDs)...)
// 2. 随机
pipe.SAdd(fmt.Sprintf(BasePoolRandomMain, poolID), toInterfaceSlice(allActiveKeyIDs)...)
// 设置合理的过期时间例如5分钟以防止孤儿数据 pipe := r.store.Pipeline(ctx)
pipe.LPush(fmt.Sprintf(BasePoolSequential, poolID), toInterfaceSlice(allActiveKeyIDs)...)
pipe.SAdd(fmt.Sprintf(BasePoolRandomMain, poolID), toInterfaceSlice(allActiveKeyIDs)...)
pipe.Expire(fmt.Sprintf(BasePoolSequential, poolID), CacheTTL) pipe.Expire(fmt.Sprintf(BasePoolSequential, poolID), CacheTTL)
pipe.Expire(fmt.Sprintf(BasePoolRandomMain, poolID), CacheTTL) pipe.Expire(fmt.Sprintf(BasePoolRandomMain, poolID), CacheTTL)
pipe.Expire(fmt.Sprintf(BasePoolRandomCooldown, poolID), CacheTTL) pipe.Expire(fmt.Sprintf(BasePoolRandomCooldown, poolID), CacheTTL)
@@ -244,17 +222,22 @@ func (r *gormKeyRepository) ensureBasePoolCacheExists(pool *BasePool, poolID str
} }
if len(lruMembers) > 0 { if len(lruMembers) > 0 {
r.store.ZAdd(fmt.Sprintf(BasePoolLRU, poolID), lruMembers) if err := r.store.ZAdd(ctx, fmt.Sprintf(BasePoolLRU, poolID), lruMembers); err != nil {
r.logger.WithError(err).Warnf("Failed to populate LRU cache for pool_id '%s'", poolID)
}
} }
return nil return nil
} }
// updateKeyUsageTimestampForPool 更新 BasePool 的 LUR ZSET // updateKeyUsageTimestampForPool 更新 BasePool 的 LUR ZSET
func (r *gormKeyRepository) updateKeyUsageTimestampForPool(poolID string, keyID uint) { func (r *gormKeyRepository) updateKeyUsageTimestampForPool(ctx context.Context, poolID string, keyID uint) {
lruKey := fmt.Sprintf(BasePoolLRU, poolID) lruKey := fmt.Sprintf(BasePoolLRU, poolID)
r.store.ZAdd(lruKey, map[string]float64{ err := r.store.ZAdd(ctx, lruKey, map[string]float64{
strconv.FormatUint(uint64(keyID), 10): nowMilli(), strconv.FormatUint(uint64(keyID), 10): nowMilli(),
}) })
if err != nil {
r.logger.WithError(err).Warnf("Failed to update key usage for pool %s", poolID)
}
} }
// generatePoolID 根据候选组ID列表生成一个稳定的、唯一的字符串ID // generatePoolID 根据候选组ID列表生成一个稳定的、唯一的字符串ID
@@ -285,8 +268,8 @@ func nowMilli() float64 {
} }
// getKeyDetailsFromCache 从缓存中获取Key和Mapping的JSON数据。 // getKeyDetailsFromCache 从缓存中获取Key和Mapping的JSON数据。
func (r *gormKeyRepository) getKeyDetailsFromCache(keyID, groupID uint) (*models.APIKey, *models.GroupAPIKeyMapping, error) { func (r *gormKeyRepository) getKeyDetailsFromCache(ctx context.Context, keyID, groupID uint) (*models.APIKey, *models.GroupAPIKeyMapping, error) {
apiKeyJSON, err := r.store.Get(fmt.Sprintf(KeyDetails, keyID)) apiKeyJSON, err := r.store.Get(ctx, fmt.Sprintf(KeyDetails, keyID))
if err != nil { if err != nil {
return nil, nil, fmt.Errorf("failed to get key details for key %d: %w", keyID, err) return nil, nil, fmt.Errorf("failed to get key details for key %d: %w", keyID, err)
} }
@@ -295,7 +278,7 @@ func (r *gormKeyRepository) getKeyDetailsFromCache(keyID, groupID uint) (*models
return nil, nil, fmt.Errorf("failed to unmarshal api key %d: %w", keyID, err) return nil, nil, fmt.Errorf("failed to unmarshal api key %d: %w", keyID, err)
} }
mappingJSON, err := r.store.Get(fmt.Sprintf(KeyMapping, groupID, keyID)) mappingJSON, err := r.store.Get(ctx, fmt.Sprintf(KeyMapping, groupID, keyID))
if err != nil { if err != nil {
return nil, nil, fmt.Errorf("failed to get mapping details for key %d in group %d: %w", keyID, groupID, err) return nil, nil, fmt.Errorf("failed to get mapping details for key %d in group %d: %w", keyID, groupID, err)
} }

View File

@@ -1,7 +1,9 @@
// Filename: internal/repository/key_writer.go // Filename: internal/repository/key_writer.go
package repository package repository
import ( import (
"context"
"fmt" "fmt"
"gemini-balancer/internal/errors" "gemini-balancer/internal/errors"
"gemini-balancer/internal/models" "gemini-balancer/internal/models"
@@ -9,7 +11,7 @@ import (
"time" "time"
) )
func (r *gormKeyRepository) UpdateKeyUsageTimestamp(groupID, keyID uint) { func (r *gormKeyRepository) UpdateKeyUsageTimestamp(ctx context.Context, groupID, keyID uint) {
lruKey := fmt.Sprintf(KeyGroupLRU, groupID) lruKey := fmt.Sprintf(KeyGroupLRU, groupID)
timestamp := float64(time.Now().UnixMilli()) timestamp := float64(time.Now().UnixMilli())
@@ -17,52 +19,51 @@ func (r *gormKeyRepository) UpdateKeyUsageTimestamp(groupID, keyID uint) {
strconv.FormatUint(uint64(keyID), 10): timestamp, strconv.FormatUint(uint64(keyID), 10): timestamp,
} }
if err := r.store.ZAdd(lruKey, members); err != nil { if err := r.store.ZAdd(ctx, lruKey, members); err != nil {
r.logger.WithError(err).Warnf("Failed to update usage timestamp for key %d in group %d", keyID, groupID) r.logger.WithError(err).Warnf("Failed to update usage timestamp for key %d in group %d", keyID, groupID)
} }
} }
func (r *gormKeyRepository) SyncKeyStatusInPollingCaches(groupID, keyID uint, newStatus models.APIKeyStatus) { func (r *gormKeyRepository) SyncKeyStatusInPollingCaches(ctx context.Context, groupID, keyID uint, newStatus models.APIKeyStatus) {
r.logger.Infof("SYNC: Directly updating polling caches for G:%d K:%d -> %s", groupID, keyID, newStatus) r.logger.Infof("SYNC: Directly updating polling caches for G:%d K:%d -> %s", groupID, keyID, newStatus)
r.updatePollingCachesLogic(groupID, keyID, newStatus) r.updatePollingCachesLogic(ctx, groupID, keyID, newStatus)
} }
func (r *gormKeyRepository) HandleCacheUpdateEvent(groupID, keyID uint, newStatus models.APIKeyStatus) { func (r *gormKeyRepository) HandleCacheUpdateEvent(ctx context.Context, groupID, keyID uint, newStatus models.APIKeyStatus) {
r.logger.Infof("EVENT: Updating polling caches for G:%d K:%d -> %s from an event", groupID, keyID, newStatus) r.logger.Infof("EVENT: Updating polling caches for G:%d K:%d -> %s from an event", groupID, keyID, newStatus)
r.updatePollingCachesLogic(groupID, keyID, newStatus) r.updatePollingCachesLogic(ctx, groupID, keyID, newStatus)
} }
func (r *gormKeyRepository) updatePollingCachesLogic(groupID, keyID uint, newStatus models.APIKeyStatus) { func (r *gormKeyRepository) updatePollingCachesLogic(ctx context.Context, groupID, keyID uint, newStatus models.APIKeyStatus) {
keyIDStr := strconv.FormatUint(uint64(keyID), 10) keyIDStr := strconv.FormatUint(uint64(keyID), 10)
sequentialKey := fmt.Sprintf(KeyGroupSequential, groupID) sequentialKey := fmt.Sprintf(KeyGroupSequential, groupID)
lruKey := fmt.Sprintf(KeyGroupLRU, groupID) lruKey := fmt.Sprintf(KeyGroupLRU, groupID)
mainPoolKey := fmt.Sprintf(KeyGroupRandomMain, groupID) mainPoolKey := fmt.Sprintf(KeyGroupRandomMain, groupID)
cooldownPoolKey := fmt.Sprintf(KeyGroupRandomCooldown, groupID) cooldownPoolKey := fmt.Sprintf(KeyGroupRandomCooldown, groupID)
_ = r.store.LRem(sequentialKey, 0, keyIDStr) _ = r.store.LRem(ctx, sequentialKey, 0, keyIDStr)
_ = r.store.ZRem(lruKey, keyIDStr) _ = r.store.ZRem(ctx, lruKey, keyIDStr)
_ = r.store.SRem(mainPoolKey, keyIDStr) _ = r.store.SRem(ctx, mainPoolKey, keyIDStr)
_ = r.store.SRem(cooldownPoolKey, keyIDStr) _ = r.store.SRem(ctx, cooldownPoolKey, keyIDStr)
if newStatus == models.StatusActive { if newStatus == models.StatusActive {
if err := r.store.LPush(sequentialKey, keyIDStr); err != nil { if err := r.store.LPush(ctx, sequentialKey, keyIDStr); err != nil {
r.logger.WithError(err).Warnf("Failed to add key %d to sequential list for group %d", keyID, groupID) r.logger.WithError(err).Warnf("Failed to add key %d to sequential list for group %d", keyID, groupID)
} }
members := map[string]float64{keyIDStr: 0} members := map[string]float64{keyIDStr: 0}
if err := r.store.ZAdd(lruKey, members); err != nil { if err := r.store.ZAdd(ctx, lruKey, members); err != nil {
r.logger.WithError(err).Warnf("Failed to add key %d to LRU zset for group %d", keyID, groupID) r.logger.WithError(err).Warnf("Failed to add key %d to LRU zset for group %d", keyID, groupID)
} }
if err := r.store.SAdd(mainPoolKey, keyIDStr); err != nil { if err := r.store.SAdd(ctx, mainPoolKey, keyIDStr); err != nil {
r.logger.WithError(err).Warnf("Failed to add key %d to random main pool for group %d", keyID, groupID) r.logger.WithError(err).Warnf("Failed to add key %d to random main pool for group %d", keyID, groupID)
} }
} }
} }
// UpdateKeyStatusAfterRequest is the new central hub for handling feedback. func (r *gormKeyRepository) UpdateKeyStatusAfterRequest(ctx context.Context, group *models.KeyGroup, key *models.APIKey, success bool, apiErr *errors.APIError) {
func (r *gormKeyRepository) UpdateKeyStatusAfterRequest(group *models.KeyGroup, key *models.APIKey, success bool, apiErr *errors.APIError) {
if success { if success {
if group.PollingStrategy == models.StrategyWeighted { if group.PollingStrategy == models.StrategyWeighted {
go r.UpdateKeyUsageTimestamp(group.ID, key.ID) go r.UpdateKeyUsageTimestamp(context.Background(), group.ID, key.ID)
} }
return return
} }
@@ -72,6 +73,5 @@ func (r *gormKeyRepository) UpdateKeyStatusAfterRequest(group *models.KeyGroup,
} }
r.logger.Warnf("Request failed for KeyID %d in GroupID %d with error: %s. Temporarily removing from active polling caches.", key.ID, group.ID, apiErr.Message) r.logger.Warnf("Request failed for KeyID %d in GroupID %d with error: %s. Temporarily removing from active polling caches.", key.ID, group.ID, apiErr.Message)
// This call is correct. It uses the synchronous, direct method. r.SyncKeyStatusInPollingCaches(ctx, group.ID, key.ID, models.StatusCooldown)
r.SyncKeyStatusInPollingCaches(group.ID, key.ID, models.StatusCooldown)
} }

View File

@@ -1,7 +1,8 @@
// Filename: internal/repository/repository.go // Filename: internal/repository/repository.go (经审查后最终修复版)
package repository package repository
import ( import (
"context"
"gemini-balancer/internal/crypto" "gemini-balancer/internal/crypto"
"gemini-balancer/internal/errors" "gemini-balancer/internal/errors"
"gemini-balancer/internal/models" "gemini-balancer/internal/models"
@@ -22,8 +23,8 @@ type BasePool struct {
type KeyRepository interface { type KeyRepository interface {
// --- 核心选取与调度 --- key_selector // --- 核心选取与调度 --- key_selector
SelectOneActiveKey(group *models.KeyGroup) (*models.APIKey, *models.GroupAPIKeyMapping, error) SelectOneActiveKey(ctx context.Context, group *models.KeyGroup) (*models.APIKey, *models.GroupAPIKeyMapping, error)
SelectOneActiveKeyFromBasePool(pool *BasePool) (*models.APIKey, *models.KeyGroup, error) SelectOneActiveKeyFromBasePool(ctx context.Context, pool *BasePool) (*models.APIKey, *models.KeyGroup, error)
// --- 加密与解密 --- key_crud // --- 加密与解密 --- key_crud
Decrypt(key *models.APIKey) error Decrypt(key *models.APIKey) error
@@ -37,16 +38,16 @@ type KeyRepository interface {
GetKeyByID(id uint) (*models.APIKey, error) GetKeyByID(id uint) (*models.APIKey, error)
GetKeyByValue(keyValue string) (*models.APIKey, error) GetKeyByValue(keyValue string) (*models.APIKey, error)
GetKeysByValues(keyValues []string) ([]models.APIKey, error) GetKeysByValues(keyValues []string) ([]models.APIKey, error)
GetKeysByIDs(ids []uint) ([]models.APIKey, error) // [新增] 根据一组主键ID批量获取Key GetKeysByIDs(ids []uint) ([]models.APIKey, error)
GetKeysByGroup(groupID uint) ([]models.APIKey, error) GetKeysByGroup(groupID uint) ([]models.APIKey, error)
CountByGroup(groupID uint) (int64, error) CountByGroup(groupID uint) (int64, error)
// --- 多对多关系管理 --- key_mapping // --- 多对多关系管理 --- key_mapping
LinkKeysToGroup(groupID uint, keyIDs []uint) error LinkKeysToGroup(ctx context.Context, groupID uint, keyIDs []uint) error
UnlinkKeysFromGroup(groupID uint, keyIDs []uint) (unlinkedCount int64, err error) UnlinkKeysFromGroup(ctx context.Context, groupID uint, keyIDs []uint) (unlinkedCount int64, err error)
GetGroupsForKey(keyID uint) ([]uint, error) GetGroupsForKey(ctx context.Context, keyID uint) ([]uint, error)
GetMapping(groupID, keyID uint) (*models.GroupAPIKeyMapping, error) GetMapping(groupID, keyID uint) (*models.GroupAPIKeyMapping, error)
UpdateMapping(mapping *models.GroupAPIKeyMapping) error UpdateMapping(ctx context.Context, mapping *models.GroupAPIKeyMapping) error
GetPaginatedKeysAndMappingsByGroup(params *models.APIKeyQueryParams) ([]*models.APIKeyDetails, int64, error) GetPaginatedKeysAndMappingsByGroup(params *models.APIKeyQueryParams) ([]*models.APIKeyDetails, int64, error)
GetKeysByValuesAndGroupID(values []string, groupID uint) ([]models.APIKey, error) GetKeysByValuesAndGroupID(values []string, groupID uint) ([]models.APIKey, error)
FindKeyValuesByStatus(groupID uint, statuses []string) ([]string, error) FindKeyValuesByStatus(groupID uint, statuses []string) ([]string, error)
@@ -55,8 +56,8 @@ type KeyRepository interface {
UpdateMappingWithoutCache(mapping *models.GroupAPIKeyMapping) error UpdateMappingWithoutCache(mapping *models.GroupAPIKeyMapping) error
// --- 缓存管理 --- key_cache // --- 缓存管理 --- key_cache
LoadAllKeysToStore() error LoadAllKeysToStore(ctx context.Context) error
HandleCacheUpdateEventBatch(mappings []*models.GroupAPIKeyMapping) error HandleCacheUpdateEventBatch(ctx context.Context, mappings []*models.GroupAPIKeyMapping) error
// --- 维护与后台任务 --- key_maintenance // --- 维护与后台任务 --- key_maintenance
StreamKeysToWriter(groupID uint, statusFilter string, writer io.Writer) error StreamKeysToWriter(groupID uint, statusFilter string, writer io.Writer) error
@@ -65,16 +66,14 @@ type KeyRepository interface {
DeleteOrphanKeys() (int64, error) DeleteOrphanKeys() (int64, error)
DeleteOrphanKeysTx(tx *gorm.DB) (int64, error) DeleteOrphanKeysTx(tx *gorm.DB) (int64, error)
GetActiveMasterKeys() ([]*models.APIKey, error) GetActiveMasterKeys() ([]*models.APIKey, error)
UpdateAPIKeyStatus(keyID uint, status models.MasterAPIKeyStatus) error UpdateAPIKeyStatus(ctx context.Context, keyID uint, status models.MasterAPIKeyStatus) error
HardDeleteSoftDeletedBefore(date time.Time) (int64, error) HardDeleteSoftDeletedBefore(date time.Time) (int64, error)
// --- 轮询策略的"写"操作 --- key_writer // --- 轮询策略的"写"操作 --- key_writer
UpdateKeyUsageTimestamp(groupID, keyID uint) UpdateKeyUsageTimestamp(ctx context.Context, groupID, keyID uint)
// 同步更新缓存,供核心业务使用 SyncKeyStatusInPollingCaches(ctx context.Context, groupID, keyID uint, newStatus models.APIKeyStatus)
SyncKeyStatusInPollingCaches(groupID, keyID uint, newStatus models.APIKeyStatus) HandleCacheUpdateEvent(ctx context.Context, groupID, keyID uint, newStatus models.APIKeyStatus)
// 异步更新缓存,供事件订阅者使用 UpdateKeyStatusAfterRequest(ctx context.Context, group *models.KeyGroup, key *models.APIKey, success bool, apiErr *errors.APIError)
HandleCacheUpdateEvent(groupID, keyID uint, newStatus models.APIKeyStatus)
UpdateKeyStatusAfterRequest(group *models.KeyGroup, key *models.APIKey, success bool, apiErr *errors.APIError)
} }
type GroupRepository interface { type GroupRepository interface {

View File

@@ -2,6 +2,7 @@
package scheduler package scheduler
import ( import (
"context"
"gemini-balancer/internal/repository" "gemini-balancer/internal/repository"
"gemini-balancer/internal/service" "gemini-balancer/internal/service"
"time" "time"
@@ -15,7 +16,6 @@ type Scheduler struct {
logger *logrus.Entry logger *logrus.Entry
statsService *service.StatsService statsService *service.StatsService
keyRepo repository.KeyRepository keyRepo repository.KeyRepository
// healthCheckService *service.HealthCheckService // 健康检查任务预留
} }
func NewScheduler(statsSvc *service.StatsService, keyRepo repository.KeyRepository, logger *logrus.Logger) *Scheduler { func NewScheduler(statsSvc *service.StatsService, keyRepo repository.KeyRepository, logger *logrus.Logger) *Scheduler {
@@ -32,11 +32,13 @@ func NewScheduler(statsSvc *service.StatsService, keyRepo repository.KeyReposito
func (s *Scheduler) Start() { func (s *Scheduler) Start() {
s.logger.Info("Starting scheduler and registering jobs...") s.logger.Info("Starting scheduler and registering jobs...")
// --- 任务注册 --- // 任务一:每小时执行一次的统计聚合
// 使用CRON表达式精确定义“每小时的第5分钟”执行 // 使用CRON表达式精确定义“每小时的第5分钟”执行
_, err := s.gocronScheduler.Cron("5 * * * *").Tag("stats-aggregation").Do(func() { _, err := s.gocronScheduler.Cron("5 * * * *").Tag("stats-aggregation").Do(func() {
s.logger.Info("Executing hourly request stats aggregation...") s.logger.Info("Executing hourly request stats aggregation...")
if err := s.statsService.AggregateHourlyStats(); err != nil { // 为后台定时任务创建一个新的、空的 context
ctx := context.Background()
if err := s.statsService.AggregateHourlyStats(ctx); err != nil {
s.logger.WithError(err).Error("Hourly stats aggregation failed.") s.logger.WithError(err).Error("Hourly stats aggregation failed.")
} else { } else {
s.logger.Info("Hourly stats aggregation completed successfully.") s.logger.Info("Hourly stats aggregation completed successfully.")
@@ -46,23 +48,14 @@ func (s *Scheduler) Start() {
s.logger.Errorf("Failed to schedule [stats-aggregation]: %v", err) s.logger.Errorf("Failed to schedule [stats-aggregation]: %v", err)
} }
// 任务二:(预留) 自动健康检查 (例如每10分钟一次) // 任务二:(预留) 自动健康检查
/*
_, err = s.gocronScheduler.Every(10).Minutes().Tag("auto-health-check").Do(func() { // 任务三每日执行一次的软删除Key清理
s.logger.Info("Executing periodic health check for all groups...")
// s.healthCheckService.StartGlobalCheckTask() // 伪代码
})
if err != nil {
s.logger.Errorf("Failed to schedule [auto-health-check]: %v", err)
}
*/
// [NEW] --- 任务三: 清理软删除的API Keys ---
// Executes once daily at 3:15 AM UTC. // Executes once daily at 3:15 AM UTC.
_, err = s.gocronScheduler.Cron("15 3 * * *").Tag("cleanup-soft-deleted-keys").Do(func() { _, err = s.gocronScheduler.Cron("15 3 * * *").Tag("cleanup-soft-deleted-keys").Do(func() {
s.logger.Info("Executing daily cleanup of soft-deleted API keys...") s.logger.Info("Executing daily cleanup of soft-deleted API keys...")
// Let's assume a retention period of 7 days for now. // [假设保留7天实际应来自配置
// In a real scenario, this should come from settings.
const retentionDays = 7 const retentionDays = 7
count, err := s.keyRepo.HardDeleteSoftDeletedBefore(time.Now().AddDate(0, 0, -retentionDays)) count, err := s.keyRepo.HardDeleteSoftDeletedBefore(time.Now().AddDate(0, 0, -retentionDays))
@@ -77,9 +70,8 @@ func (s *Scheduler) Start() {
if err != nil { if err != nil {
s.logger.Errorf("Failed to schedule [cleanup-soft-deleted-keys]: %v", err) s.logger.Errorf("Failed to schedule [cleanup-soft-deleted-keys]: %v", err)
} }
// --- 任务注册结束 ---
s.gocronScheduler.StartAsync() // 异步启动,不阻塞应用主线程 s.gocronScheduler.StartAsync()
s.logger.Info("Scheduler started.") s.logger.Info("Scheduler started.")
} }

View File

@@ -1,8 +1,8 @@
// Filename: internal/service/analytics_service.go // Filename: internal/service/analytics_service.go
package service package service
import ( import (
"context"
"encoding/json" "encoding/json"
"fmt" "fmt"
"gemini-balancer/internal/db/dialect" "gemini-balancer/internal/db/dialect"
@@ -43,7 +43,7 @@ func NewAnalyticsService(db *gorm.DB, s store.Store, logger *logrus.Logger, d di
} }
func (s *AnalyticsService) Start() { func (s *AnalyticsService) Start() {
s.wg.Add(2) // 2 (flushLoop, eventListener) s.wg.Add(2)
go s.flushLoop() go s.flushLoop()
go s.eventListener() go s.eventListener()
s.logger.Info("AnalyticsService (Command Side) started.") s.logger.Info("AnalyticsService (Command Side) started.")
@@ -53,13 +53,13 @@ func (s *AnalyticsService) Stop() {
close(s.stopChan) close(s.stopChan)
s.wg.Wait() s.wg.Wait()
s.logger.Info("AnalyticsService stopped. Performing final data flush...") s.logger.Info("AnalyticsService stopped. Performing final data flush...")
s.flushToDB() // 停止前刷盘 s.flushToDB()
s.logger.Info("AnalyticsService final data flush completed.") s.logger.Info("AnalyticsService final data flush completed.")
} }
func (s *AnalyticsService) eventListener() { func (s *AnalyticsService) eventListener() {
defer s.wg.Done() defer s.wg.Done()
sub, err := s.store.Subscribe(models.TopicRequestFinished) sub, err := s.store.Subscribe(context.Background(), models.TopicRequestFinished)
if err != nil { if err != nil {
s.logger.Fatalf("Failed to subscribe to topic %s: %v", models.TopicRequestFinished, err) s.logger.Fatalf("Failed to subscribe to topic %s: %v", models.TopicRequestFinished, err)
return return
@@ -87,9 +87,10 @@ func (s *AnalyticsService) handleAnalyticsEvent(event *models.RequestFinishedEve
if event.RequestLog.GroupID == nil { if event.RequestLog.GroupID == nil {
return return
} }
ctx := context.Background()
key := fmt.Sprintf("analytics:hourly:%s", time.Now().UTC().Format("2006-01-02T15")) key := fmt.Sprintf("analytics:hourly:%s", time.Now().UTC().Format("2006-01-02T15"))
fieldPrefix := fmt.Sprintf("%d:%s", *event.RequestLog.GroupID, event.RequestLog.ModelName) fieldPrefix := fmt.Sprintf("%d:%s", *event.RequestLog.GroupID, event.RequestLog.ModelName)
pipe := s.store.Pipeline() pipe := s.store.Pipeline(ctx)
pipe.HIncrBy(key, fieldPrefix+":requests", 1) pipe.HIncrBy(key, fieldPrefix+":requests", 1)
if event.RequestLog.IsSuccess { if event.RequestLog.IsSuccess {
pipe.HIncrBy(key, fieldPrefix+":success", 1) pipe.HIncrBy(key, fieldPrefix+":success", 1)
@@ -120,6 +121,7 @@ func (s *AnalyticsService) flushLoop() {
} }
func (s *AnalyticsService) flushToDB() { func (s *AnalyticsService) flushToDB() {
ctx := context.Background()
now := time.Now().UTC() now := time.Now().UTC()
keysToFlush := []string{ keysToFlush := []string{
fmt.Sprintf("analytics:hourly:%s", now.Add(-1*time.Hour).Format("2006-01-02T15")), fmt.Sprintf("analytics:hourly:%s", now.Add(-1*time.Hour).Format("2006-01-02T15")),
@@ -127,7 +129,7 @@ func (s *AnalyticsService) flushToDB() {
} }
for _, key := range keysToFlush { for _, key := range keysToFlush {
data, err := s.store.HGetAll(key) data, err := s.store.HGetAll(ctx, key)
if err != nil || len(data) == 0 { if err != nil || len(data) == 0 {
continue continue
} }
@@ -136,15 +138,15 @@ func (s *AnalyticsService) flushToDB() {
if len(statsToFlush) > 0 { if len(statsToFlush) > 0 {
upsertClause := s.dialect.OnConflictUpdateAll( upsertClause := s.dialect.OnConflictUpdateAll(
[]string{"time", "group_id", "model_name"}, // conflict columns []string{"time", "group_id", "model_name"},
[]string{"request_count", "success_count", "prompt_tokens", "completion_tokens"}, // update columns []string{"request_count", "success_count", "prompt_tokens", "completion_tokens"},
) )
err := s.db.Clauses(upsertClause).Create(&statsToFlush).Error err := s.db.WithContext(ctx).Clauses(upsertClause).Create(&statsToFlush).Error
if err != nil { if err != nil {
s.logger.Errorf("Failed to flush analytics data for key %s: %v", key, err) s.logger.Errorf("Failed to flush analytics data for key %s: %v", key, err)
} else { } else {
s.logger.Infof("Successfully flushed %d records from key %s.", len(statsToFlush), key) s.logger.Infof("Successfully flushed %d records from key %s.", len(statsToFlush), key)
_ = s.store.HDel(key, parsedFields...) _ = s.store.HDel(ctx, key, parsedFields...)
} }
} }
} }

View File

@@ -1,8 +1,8 @@
// Filename: internal/service/apikey_service.go // Filename: internal/service/apikey_service.go
package service package service
import ( import (
"context"
"encoding/json" "encoding/json"
"errors" "errors"
"fmt" "fmt"
@@ -29,7 +29,6 @@ const (
TaskTypeUpdateStatusByFilter = "update_status_by_filter" TaskTypeUpdateStatusByFilter = "update_status_by_filter"
) )
// DTOs & Constants
const ( const (
TestEndpoint = "https://generativelanguage.googleapis.com/v1beta/models" TestEndpoint = "https://generativelanguage.googleapis.com/v1beta/models"
) )
@@ -83,7 +82,6 @@ func NewAPIKeyService(
gm *GroupManager, gm *GroupManager,
logger *logrus.Logger, logger *logrus.Logger,
) *APIKeyService { ) *APIKeyService {
logger.Debugf("[FORENSIC PROBE | INJECTION | APIKeyService] Received a KeyRepository. Fingerprint: %p", repo)
return &APIKeyService{ return &APIKeyService{
db: db, db: db,
keyRepo: repo, keyRepo: repo,
@@ -99,22 +97,22 @@ func NewAPIKeyService(
} }
func (s *APIKeyService) Start() { func (s *APIKeyService) Start() {
requestSub, err := s.store.Subscribe(models.TopicRequestFinished) requestSub, err := s.store.Subscribe(context.Background(), models.TopicRequestFinished)
if err != nil { if err != nil {
s.logger.Fatalf("Failed to subscribe to topic %s: %v", models.TopicRequestFinished, err) s.logger.Fatalf("Failed to subscribe to topic %s: %v", models.TopicRequestFinished, err)
return return
} }
masterKeySub, err := s.store.Subscribe(models.TopicMasterKeyStatusChanged) masterKeySub, err := s.store.Subscribe(context.Background(), models.TopicMasterKeyStatusChanged)
if err != nil { if err != nil {
s.logger.Fatalf("Failed to subscribe to topic %s: %v", models.TopicMasterKeyStatusChanged, err) s.logger.Fatalf("Failed to subscribe to topic %s: %v", models.TopicMasterKeyStatusChanged, err)
return return
} }
keyStatusSub, err := s.store.Subscribe(models.TopicKeyStatusChanged) keyStatusSub, err := s.store.Subscribe(context.Background(), models.TopicKeyStatusChanged)
if err != nil { if err != nil {
s.logger.Fatalf("Failed to subscribe to topic %s: %v", models.TopicKeyStatusChanged, err) s.logger.Fatalf("Failed to subscribe to topic %s: %v", models.TopicKeyStatusChanged, err)
return return
} }
importSub, err := s.store.Subscribe(models.TopicImportGroupCompleted) importSub, err := s.store.Subscribe(context.Background(), models.TopicImportGroupCompleted)
if err != nil { if err != nil {
s.logger.Fatalf("Failed to subscribe to topic %s: %v", models.TopicImportGroupCompleted, err) s.logger.Fatalf("Failed to subscribe to topic %s: %v", models.TopicImportGroupCompleted, err)
return return
@@ -177,6 +175,7 @@ func (s *APIKeyService) handleKeyUsageEvent(event *models.RequestFinishedEvent)
if event.RequestLog.KeyID == nil || event.RequestLog.GroupID == nil { if event.RequestLog.KeyID == nil || event.RequestLog.GroupID == nil {
return return
} }
ctx := context.Background()
if event.RequestLog.IsSuccess { if event.RequestLog.IsSuccess {
mapping, err := s.keyRepo.GetMapping(*event.RequestLog.GroupID, *event.RequestLog.KeyID) mapping, err := s.keyRepo.GetMapping(*event.RequestLog.GroupID, *event.RequestLog.KeyID)
if err != nil { if err != nil {
@@ -194,17 +193,18 @@ func (s *APIKeyService) handleKeyUsageEvent(event *models.RequestFinishedEvent)
now := time.Now() now := time.Now()
mapping.LastUsedAt = &now mapping.LastUsedAt = &now
if err := s.keyRepo.UpdateMapping(mapping); err != nil { if err := s.keyRepo.UpdateMapping(ctx, mapping); err != nil {
s.logger.Errorf("[%s] Failed to update mapping for G:%d K:%d after successful request: %v", event.CorrelationID, *event.RequestLog.GroupID, *event.RequestLog.KeyID, err) s.logger.Errorf("[%s] Failed to update mapping for G:%d K:%d after successful request: %v", event.CorrelationID, *event.RequestLog.GroupID, *event.RequestLog.KeyID, err)
return return
} }
if statusChanged { if statusChanged {
go s.publishStatusChangeEvent(*event.RequestLog.GroupID, *event.RequestLog.KeyID, oldStatus, models.StatusActive, "key_recovered_after_use") go s.publishStatusChangeEvent(ctx, *event.RequestLog.GroupID, *event.RequestLog.KeyID, oldStatus, models.StatusActive, "key_recovered_after_use")
} }
return return
} }
if event.Error != nil { if event.Error != nil {
s.judgeKeyErrors( s.judgeKeyErrors(
ctx,
event.CorrelationID, event.CorrelationID,
*event.RequestLog.GroupID, *event.RequestLog.GroupID,
*event.RequestLog.KeyID, *event.RequestLog.KeyID,
@@ -215,6 +215,7 @@ func (s *APIKeyService) handleKeyUsageEvent(event *models.RequestFinishedEvent)
} }
func (s *APIKeyService) handleKeyStatusChangeEvent(event *models.KeyStatusChangedEvent) { func (s *APIKeyService) handleKeyStatusChangeEvent(event *models.KeyStatusChangedEvent) {
ctx := context.Background()
log := s.logger.WithFields(logrus.Fields{ log := s.logger.WithFields(logrus.Fields{
"group_id": event.GroupID, "group_id": event.GroupID,
"key_id": event.KeyID, "key_id": event.KeyID,
@@ -222,11 +223,11 @@ func (s *APIKeyService) handleKeyStatusChangeEvent(event *models.KeyStatusChange
"reason": event.ChangeReason, "reason": event.ChangeReason,
}) })
log.Info("Received KeyStatusChangedEvent, will update polling caches.") log.Info("Received KeyStatusChangedEvent, will update polling caches.")
s.keyRepo.HandleCacheUpdateEvent(event.GroupID, event.KeyID, event.NewStatus) s.keyRepo.HandleCacheUpdateEvent(ctx, event.GroupID, event.KeyID, event.NewStatus)
log.Info("Polling caches updated based on health check event.") log.Info("Polling caches updated based on health check event.")
} }
func (s *APIKeyService) publishStatusChangeEvent(groupID, keyID uint, oldStatus, newStatus models.APIKeyStatus, reason string) { func (s *APIKeyService) publishStatusChangeEvent(ctx context.Context, groupID, keyID uint, oldStatus, newStatus models.APIKeyStatus, reason string) {
changeEvent := models.KeyStatusChangedEvent{ changeEvent := models.KeyStatusChangedEvent{
KeyID: keyID, KeyID: keyID,
GroupID: groupID, GroupID: groupID,
@@ -236,13 +237,12 @@ func (s *APIKeyService) publishStatusChangeEvent(groupID, keyID uint, oldStatus,
ChangedAt: time.Now(), ChangedAt: time.Now(),
} }
eventData, _ := json.Marshal(changeEvent) eventData, _ := json.Marshal(changeEvent)
if err := s.store.Publish(models.TopicKeyStatusChanged, eventData); err != nil { if err := s.store.Publish(ctx, models.TopicKeyStatusChanged, eventData); err != nil {
s.logger.Errorf("Failed to publish key status changed event for group %d: %v", groupID, err) s.logger.Errorf("Failed to publish key status changed event for group %d: %v", groupID, err)
} }
} }
func (s *APIKeyService) ListAPIKeys(params *models.APIKeyQueryParams) (*PaginatedAPIKeys, error) { func (s *APIKeyService) ListAPIKeys(ctx context.Context, params *models.APIKeyQueryParams) (*PaginatedAPIKeys, error) {
// --- Path 1: High-performance DB pagination (no keyword) ---
if params.Keyword == "" { if params.Keyword == "" {
items, total, err := s.keyRepo.GetPaginatedKeysAndMappingsByGroup(params) items, total, err := s.keyRepo.GetPaginatedKeysAndMappingsByGroup(params)
if err != nil { if err != nil {
@@ -260,14 +260,12 @@ func (s *APIKeyService) ListAPIKeys(params *models.APIKeyQueryParams) (*Paginate
TotalPages: totalPages, TotalPages: totalPages,
}, nil }, nil
} }
// --- Path 2: In-memory search (keyword present) ---
s.logger.Infof("Performing heavy in-memory search for group %d with keyword '%s'", params.KeyGroupID, params.Keyword) s.logger.Infof("Performing heavy in-memory search for group %d with keyword '%s'", params.KeyGroupID, params.Keyword)
// To get all keys, we fetch all IDs first, then get their full details.
var statusesToFilter []string var statusesToFilter []string
if params.Status != "" { if params.Status != "" {
statusesToFilter = append(statusesToFilter, params.Status) statusesToFilter = append(statusesToFilter, params.Status)
} else { } else {
statusesToFilter = append(statusesToFilter, "all") // "all" gets every status statusesToFilter = append(statusesToFilter, "all")
} }
allKeyIDs, err := s.keyRepo.FindKeyIDsByStatus(params.KeyGroupID, statusesToFilter) allKeyIDs, err := s.keyRepo.FindKeyIDsByStatus(params.KeyGroupID, statusesToFilter)
if err != nil { if err != nil {
@@ -277,14 +275,12 @@ func (s *APIKeyService) ListAPIKeys(params *models.APIKeyQueryParams) (*Paginate
return &PaginatedAPIKeys{Items: []*models.APIKeyDetails{}, Total: 0, Page: 1, PageSize: params.PageSize, TotalPages: 0}, nil return &PaginatedAPIKeys{Items: []*models.APIKeyDetails{}, Total: 0, Page: 1, PageSize: params.PageSize, TotalPages: 0}, nil
} }
// This is the heavy operation: getting all keys and decrypting them.
allKeys, err := s.keyRepo.GetKeysByIDs(allKeyIDs) allKeys, err := s.keyRepo.GetKeysByIDs(allKeyIDs)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to fetch all key details for in-memory search: %w", err) return nil, fmt.Errorf("failed to fetch all key details for in-memory search: %w", err)
} }
// We also need mappings to build the final `APIKeyDetails`.
var allMappings []models.GroupAPIKeyMapping var allMappings []models.GroupAPIKeyMapping
err = s.db.Where("key_group_id = ? AND api_key_id IN ?", params.KeyGroupID, allKeyIDs).Find(&allMappings).Error err = s.db.WithContext(ctx).Where("key_group_id = ? AND api_key_id IN ?", params.KeyGroupID, allKeyIDs).Find(&allMappings).Error
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to fetch mappings for in-memory search: %w", err) return nil, fmt.Errorf("failed to fetch mappings for in-memory search: %w", err)
} }
@@ -292,7 +288,6 @@ func (s *APIKeyService) ListAPIKeys(params *models.APIKeyQueryParams) (*Paginate
for i := range allMappings { for i := range allMappings {
mappingMap[allMappings[i].APIKeyID] = &allMappings[i] mappingMap[allMappings[i].APIKeyID] = &allMappings[i]
} }
// Filter the results in memory.
var filteredItems []*models.APIKeyDetails var filteredItems []*models.APIKeyDetails
for _, key := range allKeys { for _, key := range allKeys {
if strings.Contains(key.APIKey, params.Keyword) { if strings.Contains(key.APIKey, params.Keyword) {
@@ -312,11 +307,9 @@ func (s *APIKeyService) ListAPIKeys(params *models.APIKeyQueryParams) (*Paginate
} }
} }
} }
// Sort the filtered results to ensure consistent pagination (by ID descending).
sort.Slice(filteredItems, func(i, j int) bool { sort.Slice(filteredItems, func(i, j int) bool {
return filteredItems[i].ID > filteredItems[j].ID return filteredItems[i].ID > filteredItems[j].ID
}) })
// Manually paginate the filtered results.
total := int64(len(filteredItems)) total := int64(len(filteredItems))
start := (params.Page - 1) * params.PageSize start := (params.Page - 1) * params.PageSize
end := start + params.PageSize end := start + params.PageSize
@@ -345,14 +338,15 @@ func (s *APIKeyService) ListAPIKeys(params *models.APIKeyQueryParams) (*Paginate
}, nil }, nil
} }
func (s *APIKeyService) GetKeysByIds(ids []uint) ([]models.APIKey, error) { func (s *APIKeyService) GetKeysByIds(ctx context.Context, ids []uint) ([]models.APIKey, error) {
return s.keyRepo.GetKeysByIDs(ids) return s.keyRepo.GetKeysByIDs(ids)
} }
func (s *APIKeyService) UpdateAPIKey(key *models.APIKey) error { func (s *APIKeyService) UpdateAPIKey(ctx context.Context, key *models.APIKey) error {
go func() { go func() {
bgCtx := context.Background()
var oldKey models.APIKey var oldKey models.APIKey
if err := s.db.First(&oldKey, key.ID).Error; err != nil { if err := s.db.WithContext(bgCtx).First(&oldKey, key.ID).Error; err != nil {
s.logger.Errorf("Failed to find old key state for ID %d before update: %v", key.ID, err) s.logger.Errorf("Failed to find old key state for ID %d before update: %v", key.ID, err)
return return
} }
@@ -364,16 +358,14 @@ func (s *APIKeyService) UpdateAPIKey(key *models.APIKey) error {
return nil return nil
} }
func (s *APIKeyService) HardDeleteAPIKeyByID(id uint) error { func (s *APIKeyService) HardDeleteAPIKeyByID(ctx context.Context, id uint) error {
// Get all associated groups before deletion to publish correct events groups, err := s.keyRepo.GetGroupsForKey(ctx, id)
groups, err := s.keyRepo.GetGroupsForKey(id)
if err != nil { if err != nil {
s.logger.Warnf("Could not get groups for key ID %d before hard deletion (key might be orphaned): %v", id, err) s.logger.Warnf("Could not get groups for key ID %d before hard deletion (key might be orphaned): %v", id, err)
} }
err = s.keyRepo.HardDeleteByID(id) err = s.keyRepo.HardDeleteByID(id)
if err == nil { if err == nil {
// Publish events for each group the key was a part of
for _, groupID := range groups { for _, groupID := range groups {
event := models.KeyStatusChangedEvent{ event := models.KeyStatusChangedEvent{
KeyID: id, KeyID: id,
@@ -381,13 +373,13 @@ func (s *APIKeyService) HardDeleteAPIKeyByID(id uint) error {
ChangeReason: "key_hard_deleted", ChangeReason: "key_hard_deleted",
} }
eventData, _ := json.Marshal(event) eventData, _ := json.Marshal(event)
go s.store.Publish(models.TopicKeyStatusChanged, eventData) go s.store.Publish(context.Background(), models.TopicKeyStatusChanged, eventData)
} }
} }
return err return err
} }
func (s *APIKeyService) UpdateMappingStatus(groupID, keyID uint, newStatus models.APIKeyStatus) (*models.GroupAPIKeyMapping, error) { func (s *APIKeyService) UpdateMappingStatus(ctx context.Context, groupID, keyID uint, newStatus models.APIKeyStatus) (*models.GroupAPIKeyMapping, error) {
key, err := s.keyRepo.GetKeyByID(keyID) key, err := s.keyRepo.GetKeyByID(keyID)
if err != nil { if err != nil {
return nil, CustomErrors.ParseDBError(err) return nil, CustomErrors.ParseDBError(err)
@@ -409,19 +401,20 @@ func (s *APIKeyService) UpdateMappingStatus(groupID, keyID uint, newStatus model
mapping.ConsecutiveErrorCount = 0 mapping.ConsecutiveErrorCount = 0
mapping.LastError = "" mapping.LastError = ""
} }
if err := s.keyRepo.UpdateMapping(mapping); err != nil { if err := s.keyRepo.UpdateMapping(ctx, mapping); err != nil {
return nil, err return nil, err
} }
go s.publishStatusChangeEvent(groupID, keyID, oldStatus, newStatus, "manual_update") go s.publishStatusChangeEvent(ctx, groupID, keyID, oldStatus, newStatus, "manual_update")
return mapping, nil return mapping, nil
} }
func (s *APIKeyService) handleMasterKeyStatusChangeEvent(event *models.MasterKeyStatusChangedEvent) { func (s *APIKeyService) handleMasterKeyStatusChangeEvent(event *models.MasterKeyStatusChangedEvent) {
ctx := context.Background()
s.logger.Infof("Received MasterKeyStatusChangedEvent for Key ID %d: %s -> %s", event.KeyID, event.OldMasterStatus, event.NewMasterStatus) s.logger.Infof("Received MasterKeyStatusChangedEvent for Key ID %d: %s -> %s", event.KeyID, event.OldMasterStatus, event.NewMasterStatus)
if event.NewMasterStatus != models.MasterStatusRevoked { if event.NewMasterStatus != models.MasterStatusRevoked {
return return
} }
affectedGroupIDs, err := s.keyRepo.GetGroupsForKey(event.KeyID) affectedGroupIDs, err := s.keyRepo.GetGroupsForKey(ctx, event.KeyID)
if err != nil { if err != nil {
s.logger.WithError(err).Errorf("Failed to get groups for key ID %d to propagate revoked status.", event.KeyID) s.logger.WithError(err).Errorf("Failed to get groups for key ID %d to propagate revoked status.", event.KeyID)
return return
@@ -432,7 +425,7 @@ func (s *APIKeyService) handleMasterKeyStatusChangeEvent(event *models.MasterKey
} }
s.logger.Infof("Propagating REVOKED astatus for Key ID %d to %d groups.", event.KeyID, len(affectedGroupIDs)) s.logger.Infof("Propagating REVOKED astatus for Key ID %d to %d groups.", event.KeyID, len(affectedGroupIDs))
for _, groupID := range affectedGroupIDs { for _, groupID := range affectedGroupIDs {
_, err := s.UpdateMappingStatus(groupID, event.KeyID, models.StatusBanned) _, err := s.UpdateMappingStatus(ctx, groupID, event.KeyID, models.StatusBanned)
if err != nil { if err != nil {
if !errors.Is(err, gorm.ErrRecordNotFound) { if !errors.Is(err, gorm.ErrRecordNotFound) {
s.logger.WithError(err).Errorf("Failed to update mapping status to BANNED for Key ID %d in Group ID %d.", event.KeyID, groupID) s.logger.WithError(err).Errorf("Failed to update mapping status to BANNED for Key ID %d in Group ID %d.", event.KeyID, groupID)
@@ -441,32 +434,32 @@ func (s *APIKeyService) handleMasterKeyStatusChangeEvent(event *models.MasterKey
} }
} }
func (s *APIKeyService) StartRestoreKeysTask(groupID uint, keyIDs []uint) (*task.Status, error) { func (s *APIKeyService) StartRestoreKeysTask(ctx context.Context, groupID uint, keyIDs []uint) (*task.Status, error) {
if len(keyIDs) == 0 { if len(keyIDs) == 0 {
return nil, CustomErrors.NewAPIError(CustomErrors.ErrBadRequest, "No key IDs provided for restoration.") return nil, CustomErrors.NewAPIError(CustomErrors.ErrBadRequest, "No key IDs provided for restoration.")
} }
resourceID := fmt.Sprintf("group-%d", groupID) resourceID := fmt.Sprintf("group-%d", groupID)
taskStatus, err := s.taskService.StartTask(groupID, TaskTypeRestoreSpecificKeys, resourceID, len(keyIDs), time.Hour) taskStatus, err := s.taskService.StartTask(ctx, groupID, TaskTypeRestoreSpecificKeys, resourceID, len(keyIDs), time.Hour)
if err != nil { if err != nil {
return nil, err return nil, err
} }
go s.runRestoreKeysTask(taskStatus.ID, resourceID, groupID, keyIDs) go s.runRestoreKeysTask(ctx, taskStatus.ID, resourceID, groupID, keyIDs)
return taskStatus, nil return taskStatus, nil
} }
func (s *APIKeyService) runRestoreKeysTask(taskID string, resourceID string, groupID uint, keyIDs []uint) { func (s *APIKeyService) runRestoreKeysTask(ctx context.Context, taskID string, resourceID string, groupID uint, keyIDs []uint) {
defer func() { defer func() {
if r := recover(); r != nil { if r := recover(); r != nil {
s.logger.Errorf("Panic recovered in runRestoreKeysTask: %v", r) s.logger.Errorf("Panic recovered in runRestoreKeysTask: %v", r)
s.taskService.EndTaskByID(taskID, resourceID, nil, fmt.Errorf("panic in restore keys task: %v", r)) s.taskService.EndTaskByID(ctx, taskID, resourceID, nil, fmt.Errorf("panic in restore keys task: %v", r))
} }
}() }()
var mappingsToProcess []models.GroupAPIKeyMapping var mappingsToProcess []models.GroupAPIKeyMapping
err := s.db.Preload("APIKey"). err := s.db.WithContext(ctx).Preload("APIKey").
Where("key_group_id = ? AND api_key_id IN ?", groupID, keyIDs). Where("key_group_id = ? AND api_key_id IN ?", groupID, keyIDs).
Find(&mappingsToProcess).Error Find(&mappingsToProcess).Error
if err != nil { if err != nil {
s.taskService.EndTaskByID(taskID, resourceID, nil, err) s.taskService.EndTaskByID(ctx, taskID, resourceID, nil, err)
return return
} }
result := &BatchRestoreResult{ result := &BatchRestoreResult{
@@ -476,7 +469,7 @@ func (s *APIKeyService) runRestoreKeysTask(taskID string, resourceID string, gro
processedCount := 0 processedCount := 0
for _, mapping := range mappingsToProcess { for _, mapping := range mappingsToProcess {
processedCount++ processedCount++
_ = s.taskService.UpdateProgressByID(taskID, processedCount) _ = s.taskService.UpdateProgressByID(ctx, taskID, processedCount)
if mapping.APIKey == nil { if mapping.APIKey == nil {
result.SkippedCount++ result.SkippedCount++
result.SkippedKeys = append(result.SkippedKeys, SkippedKeyInfo{KeyID: mapping.APIKeyID, Reason: "Associated APIKey entity not found."}) result.SkippedKeys = append(result.SkippedKeys, SkippedKeyInfo{KeyID: mapping.APIKeyID, Reason: "Associated APIKey entity not found."})
@@ -492,33 +485,29 @@ func (s *APIKeyService) runRestoreKeysTask(taskID string, resourceID string, gro
mapping.Status = models.StatusActive mapping.Status = models.StatusActive
mapping.ConsecutiveErrorCount = 0 mapping.ConsecutiveErrorCount = 0
mapping.LastError = "" mapping.LastError = ""
// Use the version that doesn't trigger individual cache updates.
if err := s.keyRepo.UpdateMappingWithoutCache(&mapping); err != nil { if err := s.keyRepo.UpdateMappingWithoutCache(&mapping); err != nil {
s.logger.WithError(err).Warn("Failed to update mapping in DB during restore task.") s.logger.WithError(err).Warn("Failed to update mapping in DB during restore task.")
result.SkippedCount++ result.SkippedCount++
result.SkippedKeys = append(result.SkippedKeys, SkippedKeyInfo{KeyID: mapping.APIKeyID, Reason: "DB update failed."}) result.SkippedKeys = append(result.SkippedKeys, SkippedKeyInfo{KeyID: mapping.APIKeyID, Reason: "DB update failed."})
} else { } else {
result.RestoredCount++ result.RestoredCount++
successfulMappings = append(successfulMappings, &mapping) // Collect for batch cache update. successfulMappings = append(successfulMappings, &mapping)
go s.publishStatusChangeEvent(groupID, mapping.APIKeyID, oldStatus, models.StatusActive, "batch_restore") go s.publishStatusChangeEvent(ctx, groupID, mapping.APIKeyID, oldStatus, models.StatusActive, "batch_restore")
} }
} else { } else {
result.RestoredCount++ // Already active, count as success. result.RestoredCount++
} }
} }
// After the loop, perform one single, efficient cache update. if err := s.keyRepo.HandleCacheUpdateEventBatch(ctx, successfulMappings); err != nil {
if err := s.keyRepo.HandleCacheUpdateEventBatch(successfulMappings); err != nil {
s.logger.WithError(err).Error("Failed to perform batch cache update after restore task.") s.logger.WithError(err).Error("Failed to perform batch cache update after restore task.")
// This is not a task-fatal error, so we just log it and continue.
} }
// Account for keys that were requested but not found in the initial DB query.
result.SkippedCount += (len(keyIDs) - len(mappingsToProcess)) result.SkippedCount += (len(keyIDs) - len(mappingsToProcess))
s.taskService.EndTaskByID(taskID, resourceID, result, nil) s.taskService.EndTaskByID(ctx, taskID, resourceID, result, nil)
} }
func (s *APIKeyService) StartRestoreAllBannedTask(groupID uint) (*task.Status, error) { func (s *APIKeyService) StartRestoreAllBannedTask(ctx context.Context, groupID uint) (*task.Status, error) {
var bannedKeyIDs []uint var bannedKeyIDs []uint
err := s.db.Model(&models.GroupAPIKeyMapping{}). err := s.db.WithContext(ctx).Model(&models.GroupAPIKeyMapping{}).
Where("key_group_id = ? AND status = ?", groupID, models.StatusBanned). Where("key_group_id = ? AND status = ?", groupID, models.StatusBanned).
Pluck("api_key_id", &bannedKeyIDs).Error Pluck("api_key_id", &bannedKeyIDs).Error
if err != nil { if err != nil {
@@ -527,10 +516,11 @@ func (s *APIKeyService) StartRestoreAllBannedTask(groupID uint) (*task.Status, e
if len(bannedKeyIDs) == 0 { if len(bannedKeyIDs) == 0 {
return nil, CustomErrors.NewAPIError(CustomErrors.ErrBadRequest, "No banned keys to restore in this group.") return nil, CustomErrors.NewAPIError(CustomErrors.ErrBadRequest, "No banned keys to restore in this group.")
} }
return s.StartRestoreKeysTask(groupID, bannedKeyIDs) return s.StartRestoreKeysTask(ctx, groupID, bannedKeyIDs)
} }
func (s *APIKeyService) handlePostImportValidation(event *models.ImportGroupCompletedEvent) { func (s *APIKeyService) handlePostImportValidation(event *models.ImportGroupCompletedEvent) {
ctx := context.Background()
group, ok := s.groupManager.GetGroupByID(event.GroupID) group, ok := s.groupManager.GetGroupByID(event.GroupID)
if !ok { if !ok {
s.logger.Errorf("Group with id %d not found during post-import validation, aborting.", event.GroupID) s.logger.Errorf("Group with id %d not found during post-import validation, aborting.", event.GroupID)
@@ -552,7 +542,7 @@ func (s *APIKeyService) handlePostImportValidation(event *models.ImportGroupComp
concurrency = *opConfig.KeyCheckConcurrency concurrency = *opConfig.KeyCheckConcurrency
} }
if concurrency <= 0 { if concurrency <= 0 {
concurrency = 10 // Safety fallback concurrency = 10
} }
timeout := time.Duration(globalSettings.KeyCheckTimeoutSeconds) * time.Second timeout := time.Duration(globalSettings.KeyCheckTimeoutSeconds) * time.Second
keysToValidate, err := s.keyRepo.GetKeysByIDs(event.KeyIDs) keysToValidate, err := s.keyRepo.GetKeysByIDs(event.KeyIDs)
@@ -571,7 +561,7 @@ func (s *APIKeyService) handlePostImportValidation(event *models.ImportGroupComp
validationErr := s.validationService.ValidateSingleKey(&key, timeout, endpoint) validationErr := s.validationService.ValidateSingleKey(&key, timeout, endpoint)
if validationErr == nil { if validationErr == nil {
s.logger.Infof("Key ID %d PASSED validation. Status -> ACTIVE.", key.ID) s.logger.Infof("Key ID %d PASSED validation. Status -> ACTIVE.", key.ID)
if _, err := s.UpdateMappingStatus(event.GroupID, key.ID, models.StatusActive); err != nil { if _, err := s.UpdateMappingStatus(ctx, event.GroupID, key.ID, models.StatusActive); err != nil {
s.logger.Errorf("Failed to update status to ACTIVE for Key ID %d in group %d: %v", key.ID, event.GroupID, err) s.logger.Errorf("Failed to update status to ACTIVE for Key ID %d in group %d: %v", key.ID, event.GroupID, err)
} }
} else { } else {
@@ -579,7 +569,7 @@ func (s *APIKeyService) handlePostImportValidation(event *models.ImportGroupComp
if !CustomErrors.As(validationErr, &apiErr) { if !CustomErrors.As(validationErr, &apiErr) {
apiErr = &CustomErrors.APIError{Message: validationErr.Error()} apiErr = &CustomErrors.APIError{Message: validationErr.Error()}
} }
s.judgeKeyErrors("", event.GroupID, key.ID, apiErr, false) s.judgeKeyErrors(ctx, "", event.GroupID, key.ID, apiErr, false)
} }
} }
}() }()
@@ -592,12 +582,9 @@ func (s *APIKeyService) handlePostImportValidation(event *models.ImportGroupComp
s.logger.Infof("Finished post-import validation for group %d.", event.GroupID) s.logger.Infof("Finished post-import validation for group %d.", event.GroupID)
} }
// [NEW] StartUpdateStatusByFilterTask starts a background job to update the status of keys func (s *APIKeyService) StartUpdateStatusByFilterTask(ctx context.Context, groupID uint, sourceStatuses []string, newStatus models.APIKeyStatus) (*task.Status, error) {
// that match a specific set of source statuses within a group.
func (s *APIKeyService) StartUpdateStatusByFilterTask(groupID uint, sourceStatuses []string, newStatus models.APIKeyStatus) (*task.Status, error) {
s.logger.Infof("Starting task to update status to '%s' for keys in group %d matching statuses: %v", newStatus, groupID, sourceStatuses) s.logger.Infof("Starting task to update status to '%s' for keys in group %d matching statuses: %v", newStatus, groupID, sourceStatuses)
// 1. Find key IDs using the new repository method. Using IDs is more efficient for updates.
keyIDs, err := s.keyRepo.FindKeyIDsByStatus(groupID, sourceStatuses) keyIDs, err := s.keyRepo.FindKeyIDsByStatus(groupID, sourceStatuses)
if err != nil { if err != nil {
return nil, CustomErrors.ParseDBError(err) return nil, CustomErrors.ParseDBError(err)
@@ -605,35 +592,32 @@ func (s *APIKeyService) StartUpdateStatusByFilterTask(groupID uint, sourceStatus
if len(keyIDs) == 0 { if len(keyIDs) == 0 {
now := time.Now() now := time.Now()
return &task.Status{ return &task.Status{
IsRunning: false, // The "task" is not running. IsRunning: false,
Processed: 0, Processed: 0,
Total: 0, Total: 0,
Result: map[string]string{ // We use the flexible Result field to pass the message. Result: map[string]string{
"message": "没有找到任何符合当前过滤条件的Key可供操作。", "message": "没有找到任何符合当前过滤条件的Key可供操作。",
}, },
Error: "", // There is no error. Error: "",
StartedAt: now, StartedAt: now,
FinishedAt: &now, // It started and finished at the same time. FinishedAt: &now,
}, nil // Return nil for the error, signaling a 200 OK. }, nil
} }
// 2. Start a new task using the TaskService, following existing patterns.
resourceID := fmt.Sprintf("group-%d-status-update", groupID) resourceID := fmt.Sprintf("group-%d-status-update", groupID)
taskStatus, err := s.taskService.StartTask(groupID, TaskTypeUpdateStatusByFilter, resourceID, len(keyIDs), 30*time.Minute) taskStatus, err := s.taskService.StartTask(ctx, groupID, TaskTypeUpdateStatusByFilter, resourceID, len(keyIDs), 30*time.Minute)
if err != nil { if err != nil {
return nil, err // Pass up errors like "task already in progress". return nil, err
} }
// 3. Run the core logic in a separate goroutine. go s.runUpdateStatusByFilterTask(ctx, taskStatus.ID, resourceID, groupID, keyIDs, newStatus)
go s.runUpdateStatusByFilterTask(taskStatus.ID, resourceID, groupID, keyIDs, newStatus)
return taskStatus, nil return taskStatus, nil
} }
// [NEW] runUpdateStatusByFilterTask is the private worker function for the above task. func (s *APIKeyService) runUpdateStatusByFilterTask(ctx context.Context, taskID, resourceID string, groupID uint, keyIDs []uint, newStatus models.APIKeyStatus) {
func (s *APIKeyService) runUpdateStatusByFilterTask(taskID, resourceID string, groupID uint, keyIDs []uint, newStatus models.APIKeyStatus) {
defer func() { defer func() {
if r := recover(); r != nil { if r := recover(); r != nil {
s.logger.Errorf("Panic recovered in runUpdateStatusByFilterTask: %v", r) s.logger.Errorf("Panic recovered in runUpdateStatusByFilterTask: %v", r)
s.taskService.EndTaskByID(taskID, resourceID, nil, fmt.Errorf("panic in update status task: %v", r)) s.taskService.EndTaskByID(ctx, taskID, resourceID, nil, fmt.Errorf("panic in update status task: %v", r))
} }
}() }()
type BatchUpdateResult struct { type BatchUpdateResult struct {
@@ -642,31 +626,27 @@ func (s *APIKeyService) runUpdateStatusByFilterTask(taskID, resourceID string, g
} }
result := &BatchUpdateResult{} result := &BatchUpdateResult{}
var successfulMappings []*models.GroupAPIKeyMapping var successfulMappings []*models.GroupAPIKeyMapping
// 1. Fetch all key master statuses in one go. This is efficient.
keys, err := s.keyRepo.GetKeysByIDs(keyIDs) keys, err := s.keyRepo.GetKeysByIDs(keyIDs)
if err != nil { if err != nil {
s.logger.WithError(err).Error("Failed to fetch keys for bulk status update, aborting task.") s.logger.WithError(err).Error("Failed to fetch keys for bulk status update, aborting task.")
s.taskService.EndTaskByID(taskID, resourceID, nil, err) s.taskService.EndTaskByID(ctx, taskID, resourceID, nil, err)
return return
} }
masterStatusMap := make(map[uint]models.MasterAPIKeyStatus) masterStatusMap := make(map[uint]models.MasterAPIKeyStatus)
for _, key := range keys { for _, key := range keys {
masterStatusMap[key.ID] = key.MasterStatus masterStatusMap[key.ID] = key.MasterStatus
} }
// 2. [THE REFINEMENT] Fetch all relevant mappings directly using s.db,
// avoiding the need for a new repository method. This pattern is
// already used in other parts of this service.
var mappings []*models.GroupAPIKeyMapping var mappings []*models.GroupAPIKeyMapping
if err := s.db.Where("key_group_id = ? AND api_key_id IN ?", groupID, keyIDs).Find(&mappings).Error; err != nil { if err := s.db.WithContext(ctx).Where("key_group_id = ? AND api_key_id IN ?", groupID, keyIDs).Find(&mappings).Error; err != nil {
s.logger.WithError(err).Error("Failed to fetch mappings for bulk status update, aborting task.") s.logger.WithError(err).Error("Failed to fetch mappings for bulk status update, aborting task.")
s.taskService.EndTaskByID(taskID, resourceID, nil, err) s.taskService.EndTaskByID(ctx, taskID, resourceID, nil, err)
return return
} }
processedCount := 0 processedCount := 0
for _, mapping := range mappings { for _, mapping := range mappings {
processedCount++ processedCount++
// The progress update should reflect the number of items *being processed*, not the final count. _ = s.taskService.UpdateProgressByID(ctx, taskID, processedCount)
_ = s.taskService.UpdateProgressByID(taskID, processedCount)
masterStatus, ok := masterStatusMap[mapping.APIKeyID] masterStatus, ok := masterStatusMap[mapping.APIKeyID]
if !ok { if !ok {
result.SkippedCount++ result.SkippedCount++
@@ -688,24 +668,25 @@ func (s *APIKeyService) runUpdateStatusByFilterTask(taskID, resourceID string, g
} else { } else {
result.UpdatedCount++ result.UpdatedCount++
successfulMappings = append(successfulMappings, mapping) successfulMappings = append(successfulMappings, mapping)
go s.publishStatusChangeEvent(groupID, mapping.APIKeyID, oldStatus, newStatus, "batch_status_update") go s.publishStatusChangeEvent(ctx, groupID, mapping.APIKeyID, oldStatus, newStatus, "batch_status_update")
} }
} else { } else {
result.UpdatedCount++ // Already in desired state, count as success. result.UpdatedCount++
} }
} }
result.SkippedCount += (len(keyIDs) - len(mappings)) result.SkippedCount += (len(keyIDs) - len(mappings))
if err := s.keyRepo.HandleCacheUpdateEventBatch(successfulMappings); err != nil { if err := s.keyRepo.HandleCacheUpdateEventBatch(ctx, successfulMappings); err != nil {
s.logger.WithError(err).Error("Failed to perform batch cache update after status update task.") s.logger.WithError(err).Error("Failed to perform batch cache update after status update task.")
} }
s.logger.Infof("Finished bulk status update task '%s' for group %d. Updated: %d, Skipped: %d.", taskID, groupID, result.UpdatedCount, result.SkippedCount) s.logger.Infof("Finished bulk status update task '%s' for group %d. Updated: %d, Skipped: %d.", taskID, groupID, result.UpdatedCount, result.SkippedCount)
s.taskService.EndTaskByID(taskID, resourceID, result, nil) s.taskService.EndTaskByID(ctx, taskID, resourceID, result, nil)
} }
func (s *APIKeyService) HandleRequestResult(group *models.KeyGroup, key *models.APIKey, success bool, apiErr *CustomErrors.APIError) { func (s *APIKeyService) HandleRequestResult(group *models.KeyGroup, key *models.APIKey, success bool, apiErr *CustomErrors.APIError) {
ctx := context.Background()
if success { if success {
if group.PollingStrategy == models.StrategyWeighted { if group.PollingStrategy == models.StrategyWeighted {
go s.keyRepo.UpdateKeyUsageTimestamp(group.ID, key.ID) go s.keyRepo.UpdateKeyUsageTimestamp(ctx, group.ID, key.ID)
} }
return return
} }
@@ -716,26 +697,20 @@ func (s *APIKeyService) HandleRequestResult(group *models.KeyGroup, key *models.
errMsg := apiErr.Message errMsg := apiErr.Message
if CustomErrors.IsPermanentUpstreamError(errMsg) || CustomErrors.IsTemporaryUpstreamError(errMsg) { if CustomErrors.IsPermanentUpstreamError(errMsg) || CustomErrors.IsTemporaryUpstreamError(errMsg) {
s.logger.Warnf("Request for KeyID %d in GroupID %d failed with a key-specific error: %s. Issuing order to move to cooldown.", key.ID, group.ID, errMsg) s.logger.Warnf("Request for KeyID %d in GroupID %d failed with a key-specific error: %s. Issuing order to move to cooldown.", key.ID, group.ID, errMsg)
go s.keyRepo.SyncKeyStatusInPollingCaches(group.ID, key.ID, models.StatusCooldown) go s.keyRepo.SyncKeyStatusInPollingCaches(ctx, group.ID, key.ID, models.StatusCooldown)
} else { } else {
s.logger.Infof("Request for KeyID %d in GroupID %d failed with a non-key-specific error. No punitive action taken. Error: %s", key.ID, group.ID, errMsg) s.logger.Infof("Request for KeyID %d in GroupID %d failed with a non-key-specific error. No punitive action taken. Error: %s", key.ID, group.ID, errMsg)
} }
} }
// sanitizeForLog truncates long error messages and removes JSON-like content for cleaner Info/Error logs.
func sanitizeForLog(errMsg string) string { func sanitizeForLog(errMsg string) string {
// Find the start of any potential JSON blob or detailed structure.
jsonStartIndex := strings.Index(errMsg, "{") jsonStartIndex := strings.Index(errMsg, "{")
var cleanMsg string var cleanMsg string
if jsonStartIndex != -1 { if jsonStartIndex != -1 {
// If a '{' is found, take everything before it as the summary
// and append a simple placeholder.
cleanMsg = strings.TrimSpace(errMsg[:jsonStartIndex]) + " {...}" cleanMsg = strings.TrimSpace(errMsg[:jsonStartIndex]) + " {...}"
} else { } else {
// If no JSON-like structure is found, use the original message.
cleanMsg = errMsg cleanMsg = errMsg
} }
// Always apply a final length truncation as a safeguard against extremely long non-JSON errors.
const maxLen = 250 const maxLen = 250
if len(cleanMsg) > maxLen { if len(cleanMsg) > maxLen {
return cleanMsg[:maxLen] + "..." return cleanMsg[:maxLen] + "..."
@@ -744,6 +719,7 @@ func sanitizeForLog(errMsg string) string {
} }
func (s *APIKeyService) judgeKeyErrors( func (s *APIKeyService) judgeKeyErrors(
ctx context.Context,
correlationID string, correlationID string,
groupID, keyID uint, groupID, keyID uint,
apiErr *CustomErrors.APIError, apiErr *CustomErrors.APIError,
@@ -765,11 +741,11 @@ func (s *APIKeyService) judgeKeyErrors(
oldStatus := mapping.Status oldStatus := mapping.Status
mapping.Status = models.StatusBanned mapping.Status = models.StatusBanned
mapping.LastError = errorMessage mapping.LastError = errorMessage
if err := s.keyRepo.UpdateMapping(mapping); err != nil { if err := s.keyRepo.UpdateMapping(ctx, mapping); err != nil {
logger.WithError(err).Error("Failed to update mapping status to BANNED.") logger.WithError(err).Error("Failed to update mapping status to BANNED.")
} else { } else {
go s.publishStatusChangeEvent(groupID, keyID, oldStatus, mapping.Status, "permanent_error_banned") go s.publishStatusChangeEvent(ctx, groupID, keyID, oldStatus, mapping.Status, "permanent_error_banned")
go s.revokeMasterKey(keyID, "permanent_upstream_error") go s.revokeMasterKey(ctx, keyID, "permanent_upstream_error")
} }
} }
return return
@@ -801,23 +777,23 @@ func (s *APIKeyService) judgeKeyErrors(
if oldStatus != newStatus { if oldStatus != newStatus {
mapping.Status = newStatus mapping.Status = newStatus
} }
if err := s.keyRepo.UpdateMapping(mapping); err != nil { if err := s.keyRepo.UpdateMapping(ctx, mapping); err != nil {
logger.WithError(err).Error("Failed to update mapping after temporary error.") logger.WithError(err).Error("Failed to update mapping after temporary error.")
return return
} }
if oldStatus != newStatus { if oldStatus != newStatus {
go s.publishStatusChangeEvent(groupID, keyID, oldStatus, newStatus, "error_threshold_reached") go s.publishStatusChangeEvent(ctx, groupID, keyID, oldStatus, newStatus, "error_threshold_reached")
} }
return return
} }
logger.Infof("Ignored truly ignorable upstream error. Only updating LastUsedAt. Reason: %s", sanitizeForLog(errorMessage)) logger.Infof("Ignored truly ignorable upstream error. Only updating LastUsedAt. Reason: %s", sanitizeForLog(errorMessage))
logger.WithField("full_error_details", errorMessage).Debug("Full details of the ignorable error.") logger.WithField("full_error_details", errorMessage).Debug("Full details of the ignorable error.")
if err := s.keyRepo.UpdateMapping(mapping); err != nil { if err := s.keyRepo.UpdateMapping(ctx, mapping); err != nil {
logger.WithError(err).Error("Failed to update LastUsedAt for ignorable error.") logger.WithError(err).Error("Failed to update LastUsedAt for ignorable error.")
} }
} }
func (s *APIKeyService) revokeMasterKey(keyID uint, reason string) { func (s *APIKeyService) revokeMasterKey(ctx context.Context, keyID uint, reason string) {
key, err := s.keyRepo.GetKeyByID(keyID) key, err := s.keyRepo.GetKeyByID(keyID)
if err != nil { if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) { if errors.Is(err, gorm.ErrRecordNotFound) {
@@ -832,7 +808,7 @@ func (s *APIKeyService) revokeMasterKey(keyID uint, reason string) {
} }
oldMasterStatus := key.MasterStatus oldMasterStatus := key.MasterStatus
newMasterStatus := models.MasterStatusRevoked newMasterStatus := models.MasterStatusRevoked
if err := s.keyRepo.UpdateMasterStatusByID(keyID, newMasterStatus); err != nil { if err := s.keyRepo.UpdateAPIKeyStatus(ctx, keyID, newMasterStatus); err != nil {
s.logger.Errorf("Failed to update master status to REVOKED for key ID %d: %v", keyID, err) s.logger.Errorf("Failed to update master status to REVOKED for key ID %d: %v", keyID, err)
return return
} }
@@ -844,9 +820,9 @@ func (s *APIKeyService) revokeMasterKey(keyID uint, reason string) {
ChangedAt: time.Now(), ChangedAt: time.Now(),
} }
eventData, _ := json.Marshal(masterKeyEvent) eventData, _ := json.Marshal(masterKeyEvent)
_ = s.store.Publish(models.TopicMasterKeyStatusChanged, eventData) _ = s.store.Publish(ctx, models.TopicMasterKeyStatusChanged, eventData)
} }
func (s *APIKeyService) GetAPIKeyStringsForExport(groupID uint, statuses []string) ([]string, error) { func (s *APIKeyService) GetAPIKeyStringsForExport(ctx context.Context, groupID uint, statuses []string) ([]string, error) {
return s.keyRepo.GetKeyStringsByGroupAndStatus(groupID, statuses) return s.keyRepo.GetKeyStringsByGroupAndStatus(groupID, statuses)
} }

View File

@@ -1,8 +1,8 @@
// Filename: internal/service/dashboard_query_service.go // Filename: internal/service/dashboard_query_service.go
package service package service
import ( import (
"context"
"fmt" "fmt"
"gemini-balancer/internal/models" "gemini-balancer/internal/models"
"gemini-balancer/internal/store" "gemini-balancer/internal/store"
@@ -17,8 +17,6 @@ import (
const overviewCacheChannel = "syncer:cache:dashboard_overview" const overviewCacheChannel = "syncer:cache:dashboard_overview"
// DashboardQueryService 负责所有面向前端的仪表盘数据查询。
type DashboardQueryService struct { type DashboardQueryService struct {
db *gorm.DB db *gorm.DB
store store.Store store store.Store
@@ -54,9 +52,9 @@ func (s *DashboardQueryService) Stop() {
s.logger.Info("DashboardQueryService and its CacheSyncer have been stopped.") s.logger.Info("DashboardQueryService and its CacheSyncer have been stopped.")
} }
func (s *DashboardQueryService) GetGroupStats(groupID uint) (map[string]any, error) { func (s *DashboardQueryService) GetGroupStats(ctx context.Context, groupID uint) (map[string]any, error) {
statsKey := fmt.Sprintf("stats:group:%d", groupID) statsKey := fmt.Sprintf("stats:group:%d", groupID)
keyStatsMap, err := s.store.HGetAll(statsKey) keyStatsMap, err := s.store.HGetAll(ctx, statsKey)
if err != nil { if err != nil {
s.logger.WithError(err).Errorf("Failed to get key stats from cache for group %d", groupID) s.logger.WithError(err).Errorf("Failed to get key stats from cache for group %d", groupID)
return nil, fmt.Errorf("failed to get key stats from cache: %w", err) return nil, fmt.Errorf("failed to get key stats from cache: %w", err)
@@ -74,11 +72,11 @@ func (s *DashboardQueryService) GetGroupStats(groupID uint) (map[string]any, err
SuccessRequests int64 SuccessRequests int64
} }
var last1Hour, last24Hours requestStatsResult var last1Hour, last24Hours requestStatsResult
s.db.Model(&models.StatsHourly{}). s.db.WithContext(ctx).Model(&models.StatsHourly{}).
Select("SUM(request_count) as total_requests, SUM(success_count) as success_requests"). Select("SUM(request_count) as total_requests, SUM(success_count) as success_requests").
Where("group_id = ? AND time >= ?", groupID, oneHourAgo). Where("group_id = ? AND time >= ?", groupID, oneHourAgo).
Scan(&last1Hour) Scan(&last1Hour)
s.db.Model(&models.StatsHourly{}). s.db.WithContext(ctx).Model(&models.StatsHourly{}).
Select("SUM(request_count) as total_requests, SUM(success_count) as success_requests"). Select("SUM(request_count) as total_requests, SUM(success_count) as success_requests").
Where("group_id = ? AND time >= ?", groupID, twentyFourHoursAgo). Where("group_id = ? AND time >= ?", groupID, twentyFourHoursAgo).
Scan(&last24Hours) Scan(&last24Hours)
@@ -109,8 +107,9 @@ func (s *DashboardQueryService) GetGroupStats(groupID uint) (map[string]any, err
} }
func (s *DashboardQueryService) eventListener() { func (s *DashboardQueryService) eventListener() {
keyStatusSub, _ := s.store.Subscribe(models.TopicKeyStatusChanged) ctx := context.Background()
upstreamStatusSub, _ := s.store.Subscribe(models.TopicUpstreamHealthChanged) keyStatusSub, _ := s.store.Subscribe(ctx, models.TopicKeyStatusChanged)
upstreamStatusSub, _ := s.store.Subscribe(ctx, models.TopicUpstreamHealthChanged)
defer keyStatusSub.Close() defer keyStatusSub.Close()
defer upstreamStatusSub.Close() defer upstreamStatusSub.Close()
for { for {
@@ -128,7 +127,6 @@ func (s *DashboardQueryService) eventListener() {
} }
} }
// GetDashboardOverviewData 从 Syncer 缓存中高速获取仪表盘概览数据。
func (s *DashboardQueryService) GetDashboardOverviewData() (*models.DashboardStatsResponse, error) { func (s *DashboardQueryService) GetDashboardOverviewData() (*models.DashboardStatsResponse, error) {
cachedDataPtr := s.overviewSyncer.Get() cachedDataPtr := s.overviewSyncer.Get()
if cachedDataPtr == nil { if cachedDataPtr == nil {
@@ -141,8 +139,7 @@ func (s *DashboardQueryService) InvalidateOverviewCache() error {
return s.overviewSyncer.Invalidate() return s.overviewSyncer.Invalidate()
} }
// QueryHistoricalChart 查询历史图表数据。 func (s *DashboardQueryService) QueryHistoricalChart(ctx context.Context, groupID *uint) (*models.ChartData, error) {
func (s *DashboardQueryService) QueryHistoricalChart(groupID *uint) (*models.ChartData, error) {
type ChartPoint struct { type ChartPoint struct {
TimeLabel string `gorm:"column:time_label"` TimeLabel string `gorm:"column:time_label"`
ModelName string `gorm:"column:model_name"` ModelName string `gorm:"column:model_name"`
@@ -151,7 +148,7 @@ func (s *DashboardQueryService) QueryHistoricalChart(groupID *uint) (*models.Cha
sevenDaysAgo := time.Now().Add(-24 * 7 * time.Hour).Truncate(time.Hour) sevenDaysAgo := time.Now().Add(-24 * 7 * time.Hour).Truncate(time.Hour)
sqlFormat, goFormat := s.buildTimeFormatSelectClause() sqlFormat, goFormat := s.buildTimeFormatSelectClause()
selectClause := fmt.Sprintf("%s as time_label, model_name, SUM(request_count) as total_requests", sqlFormat) selectClause := fmt.Sprintf("%s as time_label, model_name, SUM(request_count) as total_requests", sqlFormat)
query := s.db.Model(&models.StatsHourly{}).Select(selectClause).Where("time >= ?", sevenDaysAgo).Group("time_label, model_name").Order("time_label ASC") query := s.db.WithContext(ctx).Model(&models.StatsHourly{}).Select(selectClause).Where("time >= ?", sevenDaysAgo).Group("time_label, model_name").Order("time_label ASC")
if groupID != nil && *groupID > 0 { if groupID != nil && *groupID > 0 {
query = query.Where("group_id = ?", *groupID) query = query.Where("group_id = ?", *groupID)
} }
@@ -189,38 +186,38 @@ func (s *DashboardQueryService) QueryHistoricalChart(groupID *uint) (*models.Cha
} }
func (s *DashboardQueryService) loadOverviewData() (*models.DashboardStatsResponse, error) { func (s *DashboardQueryService) loadOverviewData() (*models.DashboardStatsResponse, error) {
ctx := context.Background()
s.logger.Info("[CacheSyncer] Starting to load overview data from database...") s.logger.Info("[CacheSyncer] Starting to load overview data from database...")
startTime := time.Now() startTime := time.Now()
resp := &models.DashboardStatsResponse{ resp := &models.DashboardStatsResponse{
KeyStatusCount: make(map[models.APIKeyStatus]int64), KeyStatusCount: make(map[models.APIKeyStatus]int64),
MasterStatusCount: make(map[models.MasterAPIKeyStatus]int64), MasterStatusCount: make(map[models.MasterAPIKeyStatus]int64),
KeyCount: models.StatCard{}, // 确保KeyCount是一个空的结构体而不是nil KeyCount: models.StatCard{},
RequestCount24h: models.StatCard{}, // 同上 RequestCount24h: models.StatCard{},
TokenCount: make(map[string]any), TokenCount: make(map[string]any),
UpstreamHealthStatus: make(map[string]string), UpstreamHealthStatus: make(map[string]string),
RPM: models.StatCard{}, RPM: models.StatCard{},
RequestCounts: make(map[string]int64), RequestCounts: make(map[string]int64),
} }
// --- 1. Aggregate Operational Status from Mappings ---
type MappingStatusResult struct { type MappingStatusResult struct {
Status models.APIKeyStatus Status models.APIKeyStatus
Count int64 Count int64
} }
var mappingStatusResults []MappingStatusResult var mappingStatusResults []MappingStatusResult
if err := s.db.Model(&models.GroupAPIKeyMapping{}).Select("status, count(*) as count").Group("status").Find(&mappingStatusResults).Error; err != nil { if err := s.db.WithContext(ctx).Model(&models.GroupAPIKeyMapping{}).Select("status, count(*) as count").Group("status").Find(&mappingStatusResults).Error; err != nil {
return nil, fmt.Errorf("failed to query mapping status stats: %w", err) return nil, fmt.Errorf("failed to query mapping status stats: %w", err)
} }
for _, res := range mappingStatusResults { for _, res := range mappingStatusResults {
resp.KeyStatusCount[res.Status] = res.Count resp.KeyStatusCount[res.Status] = res.Count
} }
// --- 2. Aggregate Master Status from APIKeys ---
type MasterStatusResult struct { type MasterStatusResult struct {
Status models.MasterAPIKeyStatus Status models.MasterAPIKeyStatus
Count int64 Count int64
} }
var masterStatusResults []MasterStatusResult var masterStatusResults []MasterStatusResult
if err := s.db.Model(&models.APIKey{}).Select("master_status as status, count(*) as count").Group("master_status").Find(&masterStatusResults).Error; err != nil { if err := s.db.WithContext(ctx).Model(&models.APIKey{}).Select("master_status as status, count(*) as count").Group("master_status").Find(&masterStatusResults).Error; err != nil {
return nil, fmt.Errorf("failed to query master status stats: %w", err) return nil, fmt.Errorf("failed to query master status stats: %w", err)
} }
var totalKeys, invalidKeys int64 var totalKeys, invalidKeys int64
@@ -235,20 +232,15 @@ func (s *DashboardQueryService) loadOverviewData() (*models.DashboardStatsRespon
now := time.Now() now := time.Now()
// 1. RPM (1分钟), RPH (1小时), RPD (今日): 从“瞬时记忆”(request_logs)中精确查询
var count1m, count1h, count1d int64 var count1m, count1h, count1d int64
// RPM: 从此刻倒推1分钟 s.db.WithContext(ctx).Model(&models.RequestLog{}).Where("request_time >= ?", now.Add(-1*time.Minute)).Count(&count1m)
s.db.Model(&models.RequestLog{}).Where("request_time >= ?", now.Add(-1*time.Minute)).Count(&count1m) s.db.WithContext(ctx).Model(&models.RequestLog{}).Where("request_time >= ?", now.Add(-1*time.Hour)).Count(&count1h)
// RPH: 从此刻倒推1小时
s.db.Model(&models.RequestLog{}).Where("request_time >= ?", now.Add(-1*time.Hour)).Count(&count1h)
// RPD: 从今天零点 (UTC) 到此刻
year, month, day := now.UTC().Date() year, month, day := now.UTC().Date()
startOfDay := time.Date(year, month, day, 0, 0, 0, 0, time.UTC) startOfDay := time.Date(year, month, day, 0, 0, 0, 0, time.UTC)
s.db.Model(&models.RequestLog{}).Where("request_time >= ?", startOfDay).Count(&count1d) s.db.WithContext(ctx).Model(&models.RequestLog{}).Where("request_time >= ?", startOfDay).Count(&count1d)
// 2. RP30D (30天): 从“长期记忆”(stats_hourly)中高效查询,以保证性能
var count30d int64 var count30d int64
s.db.Model(&models.StatsHourly{}).Where("time >= ?", now.AddDate(0, 0, -30)).Select("COALESCE(SUM(request_count), 0)").Scan(&count30d) s.db.WithContext(ctx).Model(&models.StatsHourly{}).Where("time >= ?", now.AddDate(0, 0, -30)).Select("COALESCE(SUM(request_count), 0)").Scan(&count30d)
resp.RequestCounts["1m"] = count1m resp.RequestCounts["1m"] = count1m
resp.RequestCounts["1h"] = count1h resp.RequestCounts["1h"] = count1h
@@ -256,7 +248,7 @@ func (s *DashboardQueryService) loadOverviewData() (*models.DashboardStatsRespon
resp.RequestCounts["30d"] = count30d resp.RequestCounts["30d"] = count30d
var upstreams []*models.UpstreamEndpoint var upstreams []*models.UpstreamEndpoint
if err := s.db.Find(&upstreams).Error; err != nil { if err := s.db.WithContext(ctx).Find(&upstreams).Error; err != nil {
s.logger.WithError(err).Warn("Failed to load upstream statuses for dashboard.") s.logger.WithError(err).Warn("Failed to load upstream statuses for dashboard.")
} else { } else {
for _, u := range upstreams { for _, u := range upstreams {
@@ -269,7 +261,7 @@ func (s *DashboardQueryService) loadOverviewData() (*models.DashboardStatsRespon
return resp, nil return resp, nil
} }
func (s *DashboardQueryService) GetRequestStatsForPeriod(period string) (gin.H, error) { func (s *DashboardQueryService) GetRequestStatsForPeriod(ctx context.Context, period string) (gin.H, error) {
var startTime time.Time var startTime time.Time
now := time.Now() now := time.Now()
switch period { switch period {
@@ -288,7 +280,7 @@ func (s *DashboardQueryService) GetRequestStatsForPeriod(period string) (gin.H,
Success int64 Success int64
} }
err := s.db.Model(&models.RequestLog{}). err := s.db.WithContext(ctx).Model(&models.RequestLog{}).
Select("count(*) as total, sum(case when is_success = true then 1 else 0 end) as success"). Select("count(*) as total, sum(case when is_success = true then 1 else 0 end) as success").
Where("request_time >= ?", startTime). Where("request_time >= ?", startTime).
Scan(&result).Error Scan(&result).Error

View File

@@ -1,8 +1,8 @@
// Filename: internal/service/db_log_writer_service.go // Filename: internal/service/db_log_writer_service.go
package service package service
import ( import (
"context"
"encoding/json" "encoding/json"
"gemini-balancer/internal/models" "gemini-balancer/internal/models"
"gemini-balancer/internal/settings" "gemini-balancer/internal/settings"
@@ -35,35 +35,30 @@ func NewDBLogWriterService(db *gorm.DB, s store.Store, settings *settings.Settin
store: s, store: s,
SettingsManager: settings, SettingsManager: settings,
logger: logger.WithField("component", "DBLogWriter📝"), logger: logger.WithField("component", "DBLogWriter📝"),
// 使用配置值来创建缓冲区 logBuffer: make(chan *models.RequestLog, bufferCapacity),
logBuffer: make(chan *models.RequestLog, bufferCapacity), stopChan: make(chan struct{}),
stopChan: make(chan struct{}),
} }
} }
func (s *DBLogWriterService) Start() { func (s *DBLogWriterService) Start() {
s.wg.Add(2) // 一个用于事件监听,一个用于数据库写入 s.wg.Add(2)
// 启动事件监听器
go s.eventListenerLoop() go s.eventListenerLoop()
// 启动数据库写入器
go s.dbWriterLoop() go s.dbWriterLoop()
s.logger.Info("DBLogWriterService started.") s.logger.Info("DBLogWriterService started.")
} }
func (s *DBLogWriterService) Stop() { func (s *DBLogWriterService) Stop() {
s.logger.Info("DBLogWriterService stopping...") s.logger.Info("DBLogWriterService stopping...")
close(s.stopChan) // 通知所有goroutine停止 close(s.stopChan)
s.wg.Wait() // 等待所有goroutine完成 s.wg.Wait()
s.logger.Info("DBLogWriterService stopped.") s.logger.Info("DBLogWriterService stopped.")
} }
// eventListenerLoop 负责从store接收事件并放入内存缓冲区
func (s *DBLogWriterService) eventListenerLoop() { func (s *DBLogWriterService) eventListenerLoop() {
defer s.wg.Done() defer s.wg.Done()
sub, err := s.store.Subscribe(models.TopicRequestFinished) ctx := context.Background()
sub, err := s.store.Subscribe(ctx, models.TopicRequestFinished)
if err != nil { if err != nil {
s.logger.Fatalf("Failed to subscribe to topic %s: %v", models.TopicRequestFinished, err) s.logger.Fatalf("Failed to subscribe to topic %s: %v", models.TopicRequestFinished, err)
return return
@@ -80,34 +75,27 @@ func (s *DBLogWriterService) eventListenerLoop() {
s.logger.Errorf("Failed to unmarshal event for logging: %v", err) s.logger.Errorf("Failed to unmarshal event for logging: %v", err)
continue continue
} }
// 将事件中的日志部分放入缓冲区
select { select {
case s.logBuffer <- &event.RequestLog: case s.logBuffer <- &event.RequestLog:
default: default:
s.logger.Warn("Log buffer is full. A log message might be dropped.") s.logger.Warn("Log buffer is full. A log message might be dropped.")
} }
case <-s.stopChan: case <-s.stopChan:
s.logger.Info("Event listener loop stopping.") s.logger.Info("Event listener loop stopping.")
// 关闭缓冲区以通知dbWriterLoop处理完剩余日志后退出
close(s.logBuffer) close(s.logBuffer)
return return
} }
} }
} }
// dbWriterLoop 负责从内存缓冲区批量读取日志并写入数据库
func (s *DBLogWriterService) dbWriterLoop() { func (s *DBLogWriterService) dbWriterLoop() {
defer s.wg.Done() defer s.wg.Done()
// 在启动时获取一次配置
cfg := s.SettingsManager.GetSettings() cfg := s.SettingsManager.GetSettings()
batchSize := cfg.LogFlushBatchSize batchSize := cfg.LogFlushBatchSize
if batchSize <= 0 { if batchSize <= 0 {
batchSize = 100 batchSize = 100
} }
flushTimeout := time.Duration(cfg.LogFlushIntervalSeconds) * time.Second flushTimeout := time.Duration(cfg.LogFlushIntervalSeconds) * time.Second
if flushTimeout <= 0 { if flushTimeout <= 0 {
flushTimeout = 5 * time.Second flushTimeout = 5 * time.Second
@@ -126,7 +114,7 @@ func (s *DBLogWriterService) dbWriterLoop() {
return return
} }
batch = append(batch, logEntry) batch = append(batch, logEntry)
if len(batch) >= batchSize { // 使用配置的批次大小 if len(batch) >= batchSize {
s.flushBatch(batch) s.flushBatch(batch)
batch = make([]*models.RequestLog, 0, batchSize) batch = make([]*models.RequestLog, 0, batchSize)
} }
@@ -139,7 +127,6 @@ func (s *DBLogWriterService) dbWriterLoop() {
} }
} }
// flushBatch 将一个批次的日志写入数据库
func (s *DBLogWriterService) flushBatch(batch []*models.RequestLog) { func (s *DBLogWriterService) flushBatch(batch []*models.RequestLog) {
if err := s.db.CreateInBatches(batch, len(batch)).Error; err != nil { if err := s.db.CreateInBatches(batch, len(batch)).Error; err != nil {
s.logger.WithField("batch_size", len(batch)).WithError(err).Error("Failed to flush log batch to database.") s.logger.WithField("batch_size", len(batch)).WithError(err).Error("Failed to flush log batch to database.")

View File

@@ -75,7 +75,7 @@ func NewHealthCheckService(
func (s *HealthCheckService) Start() { func (s *HealthCheckService) Start() {
s.logger.Info("Starting HealthCheckService with independent check loops...") s.logger.Info("Starting HealthCheckService with independent check loops...")
s.wg.Add(4) // Now four loops s.wg.Add(4)
go s.runKeyCheckLoop() go s.runKeyCheckLoop()
go s.runUpstreamCheckLoop() go s.runUpstreamCheckLoop()
go s.runProxyCheckLoop() go s.runProxyCheckLoop()
@@ -102,8 +102,6 @@ func (s *HealthCheckService) GetLastHealthCheckResults() map[string]string {
func (s *HealthCheckService) runKeyCheckLoop() { func (s *HealthCheckService) runKeyCheckLoop() {
defer s.wg.Done() defer s.wg.Done()
s.logger.Info("Key check dynamic scheduler loop started.") s.logger.Info("Key check dynamic scheduler loop started.")
// 主调度循环,每分钟检查一次任务
ticker := time.NewTicker(1 * time.Minute) ticker := time.NewTicker(1 * time.Minute)
defer ticker.Stop() defer ticker.Stop()
@@ -126,26 +124,22 @@ func (s *HealthCheckService) scheduleKeyChecks() {
defer s.groupCheckTimeMutex.Unlock() defer s.groupCheckTimeMutex.Unlock()
for _, group := range groups { for _, group := range groups {
// 获取特定于组的运营配置
opConfig, err := s.groupManager.BuildOperationalConfig(group) opConfig, err := s.groupManager.BuildOperationalConfig(group)
if err != nil { if err != nil {
s.logger.WithError(err).WithField("group_id", group.ID).Error("Failed to build operational config for group, skipping health check scheduling.") s.logger.WithError(err).WithField("group_id", group.ID).Error("Failed to build operational config for group, skipping health check scheduling.")
continue continue
} }
if opConfig.EnableKeyCheck == nil || !*opConfig.EnableKeyCheck { if opConfig.EnableKeyCheck == nil || !*opConfig.EnableKeyCheck {
continue // 跳过禁用了健康检查的组 continue
} }
var intervalMinutes int var intervalMinutes int
if opConfig.KeyCheckIntervalMinutes != nil { if opConfig.KeyCheckIntervalMinutes != nil {
intervalMinutes = *opConfig.KeyCheckIntervalMinutes intervalMinutes = *opConfig.KeyCheckIntervalMinutes
} }
interval := time.Duration(intervalMinutes) * time.Minute interval := time.Duration(intervalMinutes) * time.Minute
if interval <= 0 { if interval <= 0 {
continue // 跳过无效的检查周期 continue
} }
if nextCheckTime, ok := s.groupNextCheckTime[group.ID]; !ok || now.After(nextCheckTime) { if nextCheckTime, ok := s.groupNextCheckTime[group.ID]; !ok || now.After(nextCheckTime) {
s.logger.Infof("Scheduling key check for group '%s' (ID: %d)", group.Name, group.ID) s.logger.Infof("Scheduling key check for group '%s' (ID: %d)", group.Name, group.ID)
go s.performKeyChecksForGroup(group, opConfig) go s.performKeyChecksForGroup(group, opConfig)
@@ -160,7 +154,6 @@ func (s *HealthCheckService) runUpstreamCheckLoop() {
if s.SettingsManager.GetSettings().EnableUpstreamCheck { if s.SettingsManager.GetSettings().EnableUpstreamCheck {
s.performUpstreamChecks() s.performUpstreamChecks()
} }
ticker := time.NewTicker(time.Duration(s.SettingsManager.GetSettings().UpstreamCheckIntervalSeconds) * time.Second) ticker := time.NewTicker(time.Duration(s.SettingsManager.GetSettings().UpstreamCheckIntervalSeconds) * time.Second)
defer ticker.Stop() defer ticker.Stop()
@@ -184,7 +177,6 @@ func (s *HealthCheckService) runProxyCheckLoop() {
if s.SettingsManager.GetSettings().EnableProxyCheck { if s.SettingsManager.GetSettings().EnableProxyCheck {
s.performProxyChecks() s.performProxyChecks()
} }
ticker := time.NewTicker(time.Duration(s.SettingsManager.GetSettings().ProxyCheckIntervalSeconds) * time.Second) ticker := time.NewTicker(time.Duration(s.SettingsManager.GetSettings().ProxyCheckIntervalSeconds) * time.Second)
defer ticker.Stop() defer ticker.Stop()
@@ -203,6 +195,7 @@ func (s *HealthCheckService) runProxyCheckLoop() {
} }
func (s *HealthCheckService) performKeyChecksForGroup(group *models.KeyGroup, opConfig *models.KeyGroupSettings) { func (s *HealthCheckService) performKeyChecksForGroup(group *models.KeyGroup, opConfig *models.KeyGroupSettings) {
ctx := context.Background()
settings := s.SettingsManager.GetSettings() settings := s.SettingsManager.GetSettings()
timeout := time.Duration(settings.KeyCheckTimeoutSeconds) * time.Second timeout := time.Duration(settings.KeyCheckTimeoutSeconds) * time.Second
@@ -213,11 +206,9 @@ func (s *HealthCheckService) performKeyChecksForGroup(group *models.KeyGroup, op
} }
log := s.logger.WithFields(logrus.Fields{"group_id": group.ID, "group_name": group.Name}) log := s.logger.WithFields(logrus.Fields{"group_id": group.ID, "group_name": group.Name})
log.Infof("Starting key health check cycle.") log.Infof("Starting key health check cycle.")
var mappingsToCheck []models.GroupAPIKeyMapping var mappingsToCheck []models.GroupAPIKeyMapping
err = s.db.Model(&models.GroupAPIKeyMapping{}). err = s.db.WithContext(ctx).Model(&models.GroupAPIKeyMapping{}).
Joins("JOIN api_keys ON api_keys.id = group_api_key_mappings.api_key_id"). Joins("JOIN api_keys ON api_keys.id = group_api_key_mappings.api_key_id").
Where("group_api_key_mappings.key_group_id = ?", group.ID). Where("group_api_key_mappings.key_group_id = ?", group.ID).
Where("api_keys.master_status = ?", models.MasterStatusActive). Where("api_keys.master_status = ?", models.MasterStatusActive).
@@ -233,7 +224,6 @@ func (s *HealthCheckService) performKeyChecksForGroup(group *models.KeyGroup, op
log.Info("No key mappings to check for this group.") log.Info("No key mappings to check for this group.")
return return
} }
log.Infof("Starting health check for %d key mappings.", len(mappingsToCheck)) log.Infof("Starting health check for %d key mappings.", len(mappingsToCheck))
jobs := make(chan models.GroupAPIKeyMapping, len(mappingsToCheck)) jobs := make(chan models.GroupAPIKeyMapping, len(mappingsToCheck))
var wg sync.WaitGroup var wg sync.WaitGroup
@@ -242,14 +232,14 @@ func (s *HealthCheckService) performKeyChecksForGroup(group *models.KeyGroup, op
concurrency = *opConfig.KeyCheckConcurrency concurrency = *opConfig.KeyCheckConcurrency
} }
if concurrency <= 0 { if concurrency <= 0 {
concurrency = 1 // 保证至少有一个 worker concurrency = 1
} }
for w := 1; w <= concurrency; w++ { for w := 1; w <= concurrency; w++ {
wg.Add(1) wg.Add(1)
go func(workerID int) { go func(workerID int) {
defer wg.Done() defer wg.Done()
for mapping := range jobs { for mapping := range jobs {
s.checkAndProcessMapping(&mapping, timeout, endpoint) s.checkAndProcessMapping(ctx, &mapping, timeout, endpoint)
} }
}(w) }(w)
} }
@@ -261,52 +251,46 @@ func (s *HealthCheckService) performKeyChecksForGroup(group *models.KeyGroup, op
log.Info("Finished key health check cycle.") log.Info("Finished key health check cycle.")
} }
func (s *HealthCheckService) checkAndProcessMapping(mapping *models.GroupAPIKeyMapping, timeout time.Duration, endpoint string) { func (s *HealthCheckService) checkAndProcessMapping(ctx context.Context, mapping *models.GroupAPIKeyMapping, timeout time.Duration, endpoint string) {
if mapping.APIKey == nil { if mapping.APIKey == nil {
s.logger.Warnf("Skipping check for mapping (G:%d, K:%d) because associated APIKey is nil.", mapping.KeyGroupID, mapping.APIKeyID) s.logger.Warnf("Skipping check for mapping (G:%d, K:%d) because associated APIKey is nil.", mapping.KeyGroupID, mapping.APIKeyID)
return return
} }
validationErr := s.keyValidationService.ValidateSingleKey(mapping.APIKey, timeout, endpoint) validationErr := s.keyValidationService.ValidateSingleKey(mapping.APIKey, timeout, endpoint)
// --- 诊断一:验证成功 (健康) ---
if validationErr == nil { if validationErr == nil {
if mapping.Status != models.StatusActive { if mapping.Status != models.StatusActive {
s.activateMapping(mapping) s.activateMapping(ctx, mapping)
} }
return return
} }
errorString := validationErr.Error() errorString := validationErr.Error()
// --- 诊断二:永久性错误 ---
if CustomErrors.IsPermanentUpstreamError(errorString) { if CustomErrors.IsPermanentUpstreamError(errorString) {
s.revokeMapping(mapping, validationErr) s.revokeMapping(ctx, mapping, validationErr)
return return
} }
// --- 诊断三:暂时性错误 ---
if CustomErrors.IsTemporaryUpstreamError(errorString) { if CustomErrors.IsTemporaryUpstreamError(errorString) {
// Log with a higher level (WARN) since this is an actionable, proactive finding.
s.logger.Warnf("Health check for key %d failed with temporary error, applying penalty. Reason: %v", mapping.APIKeyID, validationErr) s.logger.Warnf("Health check for key %d failed with temporary error, applying penalty. Reason: %v", mapping.APIKeyID, validationErr)
s.penalizeMapping(mapping, validationErr) s.penalizeMapping(ctx, mapping, validationErr)
return return
} }
// --- 诊断四:其他未知或上游服务错误 ---
s.logger.Errorf("Health check for key %d failed with a transient or unknown upstream error: %v. Mapping will not be penalized by health check.", mapping.APIKeyID, validationErr) s.logger.Errorf("Health check for key %d failed with a transient or unknown upstream error: %v. Mapping will not be penalized by health check.", mapping.APIKeyID, validationErr)
} }
func (s *HealthCheckService) activateMapping(mapping *models.GroupAPIKeyMapping) { func (s *HealthCheckService) activateMapping(ctx context.Context, mapping *models.GroupAPIKeyMapping) {
oldStatus := mapping.Status oldStatus := mapping.Status
mapping.Status = models.StatusActive mapping.Status = models.StatusActive
mapping.ConsecutiveErrorCount = 0 mapping.ConsecutiveErrorCount = 0
mapping.LastError = "" mapping.LastError = ""
if err := s.keyRepo.UpdateMapping(mapping); err != nil {
if err := s.keyRepo.UpdateMapping(ctx, mapping); err != nil {
s.logger.WithError(err).Errorf("Failed to activate mapping (G:%d, K:%d).", mapping.KeyGroupID, mapping.APIKeyID) s.logger.WithError(err).Errorf("Failed to activate mapping (G:%d, K:%d).", mapping.KeyGroupID, mapping.APIKeyID)
return return
} }
s.logger.Infof("Mapping (G:%d, K:%d) successfully activated from %s to ACTIVE.", mapping.KeyGroupID, mapping.APIKeyID, oldStatus) s.logger.Infof("Mapping (G:%d, K:%d) successfully activated from %s to ACTIVE.", mapping.KeyGroupID, mapping.APIKeyID, oldStatus)
s.publishKeyStatusChangedEvent(mapping.KeyGroupID, mapping.APIKeyID, oldStatus, mapping.Status) s.publishKeyStatusChangedEvent(ctx, mapping.KeyGroupID, mapping.APIKeyID, oldStatus, mapping.Status)
} }
func (s *HealthCheckService) penalizeMapping(mapping *models.GroupAPIKeyMapping, err error) { func (s *HealthCheckService) penalizeMapping(ctx context.Context, mapping *models.GroupAPIKeyMapping, err error) {
// Re-fetch group-specific operational config to get the correct thresholds
group, ok := s.groupManager.GetGroupByID(mapping.KeyGroupID) group, ok := s.groupManager.GetGroupByID(mapping.KeyGroupID)
if !ok { if !ok {
s.logger.Errorf("Could not find group with ID %d to apply penalty.", mapping.KeyGroupID) s.logger.Errorf("Could not find group with ID %d to apply penalty.", mapping.KeyGroupID)
@@ -320,7 +304,6 @@ func (s *HealthCheckService) penalizeMapping(mapping *models.GroupAPIKeyMapping,
oldStatus := mapping.Status oldStatus := mapping.Status
mapping.LastError = err.Error() mapping.LastError = err.Error()
mapping.ConsecutiveErrorCount++ mapping.ConsecutiveErrorCount++
// Use the group-specific threshold
threshold := *opConfig.KeyBlacklistThreshold threshold := *opConfig.KeyBlacklistThreshold
if mapping.ConsecutiveErrorCount >= threshold { if mapping.ConsecutiveErrorCount >= threshold {
mapping.Status = models.StatusCooldown mapping.Status = models.StatusCooldown
@@ -329,44 +312,41 @@ func (s *HealthCheckService) penalizeMapping(mapping *models.GroupAPIKeyMapping,
mapping.CooldownUntil = &cooldownTime mapping.CooldownUntil = &cooldownTime
s.logger.Warnf("Mapping (G:%d, K:%d) reached error threshold (%d) during health check and is now in COOLDOWN for %v.", mapping.KeyGroupID, mapping.APIKeyID, threshold, cooldownDuration) s.logger.Warnf("Mapping (G:%d, K:%d) reached error threshold (%d) during health check and is now in COOLDOWN for %v.", mapping.KeyGroupID, mapping.APIKeyID, threshold, cooldownDuration)
} }
if errDb := s.keyRepo.UpdateMapping(mapping); errDb != nil { if errDb := s.keyRepo.UpdateMapping(ctx, mapping); errDb != nil {
s.logger.WithError(errDb).Errorf("Failed to penalize mapping (G:%d, K:%d).", mapping.KeyGroupID, mapping.APIKeyID) s.logger.WithError(errDb).Errorf("Failed to penalize mapping (G:%d, K:%d).", mapping.KeyGroupID, mapping.APIKeyID)
return return
} }
if oldStatus != mapping.Status { if oldStatus != mapping.Status {
s.publishKeyStatusChangedEvent(mapping.KeyGroupID, mapping.APIKeyID, oldStatus, mapping.Status) s.publishKeyStatusChangedEvent(ctx, mapping.KeyGroupID, mapping.APIKeyID, oldStatus, mapping.Status)
} }
} }
func (s *HealthCheckService) revokeMapping(mapping *models.GroupAPIKeyMapping, err error) { func (s *HealthCheckService) revokeMapping(ctx context.Context, mapping *models.GroupAPIKeyMapping, err error) {
oldStatus := mapping.Status oldStatus := mapping.Status
if oldStatus == models.StatusBanned { if oldStatus == models.StatusBanned {
return // Already banned, do nothing. return
} }
mapping.Status = models.StatusBanned mapping.Status = models.StatusBanned
mapping.LastError = "Definitive error: " + err.Error() mapping.LastError = "Definitive error: " + err.Error()
mapping.ConsecutiveErrorCount = 0 // Reset counter as this is a final state for this group mapping.ConsecutiveErrorCount = 0
if errDb := s.keyRepo.UpdateMapping(ctx, mapping); errDb != nil {
if errDb := s.keyRepo.UpdateMapping(mapping); errDb != nil {
s.logger.WithError(errDb).Errorf("Failed to revoke mapping (G:%d, K:%d).", mapping.KeyGroupID, mapping.APIKeyID) s.logger.WithError(errDb).Errorf("Failed to revoke mapping (G:%d, K:%d).", mapping.KeyGroupID, mapping.APIKeyID)
return return
} }
s.logger.Warnf("Mapping (G:%d, K:%d) has been BANNED due to a definitive error: %v", mapping.KeyGroupID, mapping.APIKeyID, err) s.logger.Warnf("Mapping (G:%d, K:%d) has been BANNED due to a definitive error: %v", mapping.KeyGroupID, mapping.APIKeyID, err)
s.publishKeyStatusChangedEvent(mapping.KeyGroupID, mapping.APIKeyID, oldStatus, mapping.Status) s.publishKeyStatusChangedEvent(ctx, mapping.KeyGroupID, mapping.APIKeyID, oldStatus, mapping.Status)
s.logger.Infof("Triggering MasterStatus update for definitively failed key ID %d.", mapping.APIKeyID) s.logger.Infof("Triggering MasterStatus update for definitively failed key ID %d.", mapping.APIKeyID)
if err := s.keyRepo.UpdateAPIKeyStatus(mapping.APIKeyID, models.MasterStatusRevoked); err != nil { if err := s.keyRepo.UpdateAPIKeyStatus(ctx, mapping.APIKeyID, models.MasterStatusRevoked); err != nil {
s.logger.WithError(err).Errorf("Failed to update master status for key ID %d after group-level ban.", mapping.APIKeyID) s.logger.WithError(err).Errorf("Failed to update master status for key ID %d after group-level ban.", mapping.APIKeyID)
} }
} }
func (s *HealthCheckService) performUpstreamChecks() { func (s *HealthCheckService) performUpstreamChecks() {
ctx := context.Background()
settings := s.SettingsManager.GetSettings() settings := s.SettingsManager.GetSettings()
timeout := time.Duration(settings.UpstreamCheckTimeoutSeconds) * time.Second timeout := time.Duration(settings.UpstreamCheckTimeoutSeconds) * time.Second
var upstreams []*models.UpstreamEndpoint var upstreams []*models.UpstreamEndpoint
if err := s.db.Find(&upstreams).Error; err != nil { if err := s.db.WithContext(ctx).Find(&upstreams).Error; err != nil {
s.logger.WithError(err).Error("Failed to retrieve upstreams.") s.logger.WithError(err).Error("Failed to retrieve upstreams.")
return return
} }
@@ -390,10 +370,10 @@ func (s *HealthCheckService) performUpstreamChecks() {
s.lastResultsMutex.Unlock() s.lastResultsMutex.Unlock()
if oldStatus != newStatus { if oldStatus != newStatus {
s.logger.Infof("Upstream '%s' status changed from %s -> %s", upstream.URL, oldStatus, newStatus) s.logger.Infof("Upstream '%s' status changed from %s -> %s", upstream.URL, oldStatus, newStatus)
if err := s.db.Model(upstream).Update("status", newStatus).Error; err != nil { if err := s.db.WithContext(ctx).Model(upstream).Update("status", newStatus).Error; err != nil {
s.logger.WithError(err).WithField("upstream_id", upstream.ID).Error("Failed to update upstream status.") s.logger.WithError(err).WithField("upstream_id", upstream.ID).Error("Failed to update upstream status.")
} else { } else {
s.publishUpstreamHealthChangedEvent(upstream, oldStatus, newStatus) s.publishUpstreamHealthChangedEvent(ctx, upstream, oldStatus, newStatus)
} }
} }
}(u) }(u)
@@ -412,10 +392,11 @@ func (s *HealthCheckService) checkEndpoint(urlStr string, timeout time.Duration)
} }
func (s *HealthCheckService) performProxyChecks() { func (s *HealthCheckService) performProxyChecks() {
ctx := context.Background()
settings := s.SettingsManager.GetSettings() settings := s.SettingsManager.GetSettings()
timeout := time.Duration(settings.ProxyCheckTimeoutSeconds) * time.Second timeout := time.Duration(settings.ProxyCheckTimeoutSeconds) * time.Second
var proxies []*models.ProxyConfig var proxies []*models.ProxyConfig
if err := s.db.Find(&proxies).Error; err != nil { if err := s.db.WithContext(ctx).Find(&proxies).Error; err != nil {
s.logger.WithError(err).Error("Failed to retrieve proxies.") s.logger.WithError(err).Error("Failed to retrieve proxies.")
return return
} }
@@ -438,7 +419,7 @@ func (s *HealthCheckService) performProxyChecks() {
s.lastResultsMutex.Unlock() s.lastResultsMutex.Unlock()
if proxyCfg.Status != newStatus { if proxyCfg.Status != newStatus {
s.logger.Infof("Proxy '%s' status changed from %s -> %s", proxyCfg.Address, proxyCfg.Status, newStatus) s.logger.Infof("Proxy '%s' status changed from %s -> %s", proxyCfg.Address, proxyCfg.Status, newStatus)
if err := s.db.Model(proxyCfg).Update("status", newStatus).Error; err != nil { if err := s.db.WithContext(ctx).Model(proxyCfg).Update("status", newStatus).Error; err != nil {
s.logger.WithError(err).WithField("proxy_id", proxyCfg.ID).Error("Failed to update proxy status.") s.logger.WithError(err).WithField("proxy_id", proxyCfg.ID).Error("Failed to update proxy status.")
} }
} }
@@ -482,7 +463,7 @@ func (s *HealthCheckService) checkProxy(proxyCfg *models.ProxyConfig, timeout ti
return true return true
} }
func (s *HealthCheckService) publishKeyStatusChangedEvent(groupID, keyID uint, oldStatus, newStatus models.APIKeyStatus) { func (s *HealthCheckService) publishKeyStatusChangedEvent(ctx context.Context, groupID, keyID uint, oldStatus, newStatus models.APIKeyStatus) {
event := models.KeyStatusChangedEvent{ event := models.KeyStatusChangedEvent{
KeyID: keyID, KeyID: keyID,
GroupID: groupID, GroupID: groupID,
@@ -496,12 +477,12 @@ func (s *HealthCheckService) publishKeyStatusChangedEvent(groupID, keyID uint, o
s.logger.WithError(err).Errorf("Failed to marshal KeyStatusChangedEvent for group %d.", groupID) s.logger.WithError(err).Errorf("Failed to marshal KeyStatusChangedEvent for group %d.", groupID)
return return
} }
if err := s.store.Publish(models.TopicKeyStatusChanged, payload); err != nil { if err := s.store.Publish(ctx, models.TopicKeyStatusChanged, payload); err != nil {
s.logger.WithError(err).Errorf("Failed to publish KeyStatusChangedEvent for group %d.", groupID) s.logger.WithError(err).Errorf("Failed to publish KeyStatusChangedEvent for group %d.", groupID)
} }
} }
func (s *HealthCheckService) publishUpstreamHealthChangedEvent(upstream *models.UpstreamEndpoint, oldStatus, newStatus string) { func (s *HealthCheckService) publishUpstreamHealthChangedEvent(ctx context.Context, upstream *models.UpstreamEndpoint, oldStatus, newStatus string) {
event := models.UpstreamHealthChangedEvent{ event := models.UpstreamHealthChangedEvent{
UpstreamID: upstream.ID, UpstreamID: upstream.ID,
UpstreamURL: upstream.URL, UpstreamURL: upstream.URL,
@@ -516,28 +497,20 @@ func (s *HealthCheckService) publishUpstreamHealthChangedEvent(upstream *models.
s.logger.WithError(err).Error("Failed to marshal UpstreamHealthChangedEvent.") s.logger.WithError(err).Error("Failed to marshal UpstreamHealthChangedEvent.")
return return
} }
if err := s.store.Publish(models.TopicUpstreamHealthChanged, payload); err != nil { if err := s.store.Publish(ctx, models.TopicUpstreamHealthChanged, payload); err != nil {
s.logger.WithError(err).Error("Failed to publish UpstreamHealthChangedEvent.") s.logger.WithError(err).Error("Failed to publish UpstreamHealthChangedEvent.")
} }
} }
// =========================================================================
// Global Base Key Check (New Logic)
// =========================================================================
func (s *HealthCheckService) runBaseKeyCheckLoop() { func (s *HealthCheckService) runBaseKeyCheckLoop() {
defer s.wg.Done() defer s.wg.Done()
s.logger.Info("Global base key check loop started.") s.logger.Info("Global base key check loop started.")
settings := s.SettingsManager.GetSettings() settings := s.SettingsManager.GetSettings()
if !settings.EnableBaseKeyCheck { if !settings.EnableBaseKeyCheck {
s.logger.Info("Global base key check is disabled.") s.logger.Info("Global base key check is disabled.")
return return
} }
// Perform an initial check on startup
s.performBaseKeyChecks() s.performBaseKeyChecks()
interval := time.Duration(settings.BaseKeyCheckIntervalMinutes) * time.Minute interval := time.Duration(settings.BaseKeyCheckIntervalMinutes) * time.Minute
if interval <= 0 { if interval <= 0 {
s.logger.Warnf("Invalid BaseKeyCheckIntervalMinutes: %d. Disabling base key check loop.", settings.BaseKeyCheckIntervalMinutes) s.logger.Warnf("Invalid BaseKeyCheckIntervalMinutes: %d. Disabling base key check loop.", settings.BaseKeyCheckIntervalMinutes)
@@ -558,6 +531,7 @@ func (s *HealthCheckService) runBaseKeyCheckLoop() {
} }
func (s *HealthCheckService) performBaseKeyChecks() { func (s *HealthCheckService) performBaseKeyChecks() {
ctx := context.Background()
s.logger.Info("Starting global base key check cycle.") s.logger.Info("Starting global base key check cycle.")
settings := s.SettingsManager.GetSettings() settings := s.SettingsManager.GetSettings()
timeout := time.Duration(settings.KeyCheckTimeoutSeconds) * time.Second timeout := time.Duration(settings.KeyCheckTimeoutSeconds) * time.Second
@@ -576,7 +550,7 @@ func (s *HealthCheckService) performBaseKeyChecks() {
jobs := make(chan *models.APIKey, len(keys)) jobs := make(chan *models.APIKey, len(keys))
var wg sync.WaitGroup var wg sync.WaitGroup
if concurrency <= 0 { if concurrency <= 0 {
concurrency = 5 // Safe default concurrency = 5
} }
for w := 0; w < concurrency; w++ { for w := 0; w < concurrency; w++ {
wg.Add(1) wg.Add(1)
@@ -587,10 +561,10 @@ func (s *HealthCheckService) performBaseKeyChecks() {
if err != nil && CustomErrors.IsPermanentUpstreamError(err.Error()) { if err != nil && CustomErrors.IsPermanentUpstreamError(err.Error()) {
oldStatus := key.MasterStatus oldStatus := key.MasterStatus
s.logger.Warnf("Key ID %d (%s...) failed definitive base check. Setting MasterStatus to REVOKED. Reason: %v", key.ID, key.APIKey[:4], err) s.logger.Warnf("Key ID %d (%s...) failed definitive base check. Setting MasterStatus to REVOKED. Reason: %v", key.ID, key.APIKey[:4], err)
if updateErr := s.keyRepo.UpdateAPIKeyStatus(key.ID, models.MasterStatusRevoked); updateErr != nil { if updateErr := s.keyRepo.UpdateAPIKeyStatus(ctx, key.ID, models.MasterStatusRevoked); updateErr != nil {
s.logger.WithError(updateErr).Errorf("Failed to update master status for key ID %d.", key.ID) s.logger.WithError(updateErr).Errorf("Failed to update master status for key ID %d.", key.ID)
} else { } else {
s.publishMasterKeyStatusChangedEvent(key.ID, oldStatus, models.MasterStatusRevoked) s.publishMasterKeyStatusChangedEvent(ctx, key.ID, oldStatus, models.MasterStatusRevoked)
} }
} }
} }
@@ -604,8 +578,7 @@ func (s *HealthCheckService) performBaseKeyChecks() {
s.logger.Info("Global base key check cycle finished.") s.logger.Info("Global base key check cycle finished.")
} }
// 事件发布辅助函数 func (s *HealthCheckService) publishMasterKeyStatusChangedEvent(ctx context.Context, keyID uint, oldStatus, newStatus models.MasterAPIKeyStatus) {
func (s *HealthCheckService) publishMasterKeyStatusChangedEvent(keyID uint, oldStatus, newStatus models.MasterAPIKeyStatus) {
event := models.MasterKeyStatusChangedEvent{ event := models.MasterKeyStatusChangedEvent{
KeyID: keyID, KeyID: keyID,
OldMasterStatus: oldStatus, OldMasterStatus: oldStatus,
@@ -618,7 +591,7 @@ func (s *HealthCheckService) publishMasterKeyStatusChangedEvent(keyID uint, oldS
s.logger.WithError(err).Errorf("Failed to marshal MasterKeyStatusChangedEvent for key %d.", keyID) s.logger.WithError(err).Errorf("Failed to marshal MasterKeyStatusChangedEvent for key %d.", keyID)
return return
} }
if err := s.store.Publish(models.TopicMasterKeyStatusChanged, payload); err != nil { if err := s.store.Publish(ctx, models.TopicMasterKeyStatusChanged, payload); err != nil {
s.logger.WithError(err).Errorf("Failed to publish MasterKeyStatusChangedEvent for key %d.", keyID) s.logger.WithError(err).Errorf("Failed to publish MasterKeyStatusChangedEvent for key %d.", keyID)
} }
} }

View File

@@ -2,6 +2,7 @@
package service package service
import ( import (
"context"
"encoding/json" "encoding/json"
"fmt" "fmt"
"gemini-balancer/internal/models" "gemini-balancer/internal/models"
@@ -42,88 +43,84 @@ func NewKeyImportService(ts task.Reporter, kr repository.KeyRepository, s store.
} }
} }
// --- 通用的 Panic-Safe 任務執行器 --- func (s *KeyImportService) runTaskWithRecovery(ctx context.Context, taskID string, resourceID string, taskFunc func()) {
func (s *KeyImportService) runTaskWithRecovery(taskID string, resourceID string, taskFunc func()) {
defer func() { defer func() {
if r := recover(); r != nil { if r := recover(); r != nil {
err := fmt.Errorf("panic recovered in task %s: %v", taskID, r) err := fmt.Errorf("panic recovered in task %s: %v", taskID, r)
s.logger.Error(err) s.logger.Error(err)
s.taskService.EndTaskByID(taskID, resourceID, nil, err) s.taskService.EndTaskByID(ctx, taskID, resourceID, nil, err)
} }
}() }()
taskFunc() taskFunc()
} }
// --- Public Task Starters --- func (s *KeyImportService) StartAddKeysTask(ctx context.Context, groupID uint, keysText string, validateOnImport bool) (*task.Status, error) {
func (s *KeyImportService) StartAddKeysTask(groupID uint, keysText string, validateOnImport bool) (*task.Status, error) {
keys := utils.ParseKeysFromText(keysText) keys := utils.ParseKeysFromText(keysText)
if len(keys) == 0 { if len(keys) == 0 {
return nil, fmt.Errorf("no valid keys found in input text") return nil, fmt.Errorf("no valid keys found in input text")
} }
resourceID := fmt.Sprintf("group-%d", groupID) resourceID := fmt.Sprintf("group-%d", groupID)
taskStatus, err := s.taskService.StartTask(uint(groupID), TaskTypeAddKeysToGroup, resourceID, len(keys), 15*time.Minute)
taskStatus, err := s.taskService.StartTask(ctx, uint(groupID), TaskTypeAddKeysToGroup, resourceID, len(keys), 15*time.Minute)
if err != nil { if err != nil {
return nil, err return nil, err
} }
go s.runTaskWithRecovery(taskStatus.ID, resourceID, func() { go s.runTaskWithRecovery(context.Background(), taskStatus.ID, resourceID, func() {
s.runAddKeysTask(taskStatus.ID, resourceID, groupID, keys, validateOnImport) s.runAddKeysTask(context.Background(), taskStatus.ID, resourceID, groupID, keys, validateOnImport)
}) })
return taskStatus, nil return taskStatus, nil
} }
func (s *KeyImportService) StartUnlinkKeysTask(groupID uint, keysText string) (*task.Status, error) { func (s *KeyImportService) StartUnlinkKeysTask(ctx context.Context, groupID uint, keysText string) (*task.Status, error) {
keys := utils.ParseKeysFromText(keysText) keys := utils.ParseKeysFromText(keysText)
if len(keys) == 0 { if len(keys) == 0 {
return nil, fmt.Errorf("no valid keys found") return nil, fmt.Errorf("no valid keys found")
} }
resourceID := fmt.Sprintf("group-%d", groupID) resourceID := fmt.Sprintf("group-%d", groupID)
taskStatus, err := s.taskService.StartTask(uint(groupID), TaskTypeUnlinkKeysFromGroup, resourceID, len(keys), time.Hour) taskStatus, err := s.taskService.StartTask(ctx, uint(groupID), TaskTypeUnlinkKeysFromGroup, resourceID, len(keys), time.Hour)
if err != nil { if err != nil {
return nil, err return nil, err
} }
go s.runTaskWithRecovery(taskStatus.ID, resourceID, func() { go s.runTaskWithRecovery(context.Background(), taskStatus.ID, resourceID, func() {
s.runUnlinkKeysTask(taskStatus.ID, resourceID, groupID, keys) s.runUnlinkKeysTask(context.Background(), taskStatus.ID, resourceID, groupID, keys)
}) })
return taskStatus, nil return taskStatus, nil
} }
func (s *KeyImportService) StartHardDeleteKeysTask(keysText string) (*task.Status, error) { func (s *KeyImportService) StartHardDeleteKeysTask(ctx context.Context, keysText string) (*task.Status, error) {
keys := utils.ParseKeysFromText(keysText) keys := utils.ParseKeysFromText(keysText)
if len(keys) == 0 { if len(keys) == 0 {
return nil, fmt.Errorf("no valid keys found") return nil, fmt.Errorf("no valid keys found")
} }
resourceID := "global_hard_delete" // Global lock resourceID := "global_hard_delete"
taskStatus, err := s.taskService.StartTask(0, TaskTypeHardDeleteKeys, resourceID, len(keys), time.Hour) taskStatus, err := s.taskService.StartTask(ctx, 0, TaskTypeHardDeleteKeys, resourceID, len(keys), time.Hour)
if err != nil { if err != nil {
return nil, err return nil, err
} }
go s.runTaskWithRecovery(taskStatus.ID, resourceID, func() { go s.runTaskWithRecovery(context.Background(), taskStatus.ID, resourceID, func() {
s.runHardDeleteKeysTask(taskStatus.ID, resourceID, keys) s.runHardDeleteKeysTask(context.Background(), taskStatus.ID, resourceID, keys)
}) })
return taskStatus, nil return taskStatus, nil
} }
func (s *KeyImportService) StartRestoreKeysTask(keysText string) (*task.Status, error) { func (s *KeyImportService) StartRestoreKeysTask(ctx context.Context, keysText string) (*task.Status, error) {
keys := utils.ParseKeysFromText(keysText) keys := utils.ParseKeysFromText(keysText)
if len(keys) == 0 { if len(keys) == 0 {
return nil, fmt.Errorf("no valid keys found") return nil, fmt.Errorf("no valid keys found")
} }
resourceID := "global_restore_keys" // Global lock resourceID := "global_restore_keys"
taskStatus, err := s.taskService.StartTask(0, TaskTypeRestoreKeys, resourceID, len(keys), time.Hour)
taskStatus, err := s.taskService.StartTask(ctx, 0, TaskTypeRestoreKeys, resourceID, len(keys), time.Hour)
if err != nil { if err != nil {
return nil, err return nil, err
} }
go s.runTaskWithRecovery(taskStatus.ID, resourceID, func() { go s.runTaskWithRecovery(context.Background(), taskStatus.ID, resourceID, func() {
s.runRestoreKeysTask(taskStatus.ID, resourceID, keys) s.runRestoreKeysTask(context.Background(), taskStatus.ID, resourceID, keys)
}) })
return taskStatus, nil return taskStatus, nil
} }
// --- Private Task Runners --- func (s *KeyImportService) runAddKeysTask(ctx context.Context, taskID string, resourceID string, groupID uint, keys []string, validateOnImport bool) {
func (s *KeyImportService) runAddKeysTask(taskID string, resourceID string, groupID uint, keys []string, validateOnImport bool) {
// 步骤 1: 对输入的原始 key 列表进行去重。
uniqueKeysMap := make(map[string]struct{}) uniqueKeysMap := make(map[string]struct{})
var uniqueKeyStrings []string var uniqueKeyStrings []string
for _, kStr := range keys { for _, kStr := range keys {
@@ -133,41 +130,37 @@ func (s *KeyImportService) runAddKeysTask(taskID string, resourceID string, grou
} }
} }
if len(uniqueKeyStrings) == 0 { if len(uniqueKeyStrings) == 0 {
s.taskService.EndTaskByID(taskID, resourceID, gin.H{"newly_linked_count": 0, "already_linked_count": 0}, nil) s.taskService.EndTaskByID(ctx, taskID, resourceID, gin.H{"newly_linked_count": 0, "already_linked_count": 0}, nil)
return return
} }
// 步骤 2: 确保所有 Key 在主表中存在(创建或恢复),并获取它们完整的实体。
keysToEnsure := make([]models.APIKey, len(uniqueKeyStrings)) keysToEnsure := make([]models.APIKey, len(uniqueKeyStrings))
for i, keyStr := range uniqueKeyStrings { for i, keyStr := range uniqueKeyStrings {
keysToEnsure[i] = models.APIKey{APIKey: keyStr} keysToEnsure[i] = models.APIKey{APIKey: keyStr}
} }
allKeyModels, err := s.keyRepo.AddKeys(keysToEnsure) allKeyModels, err := s.keyRepo.AddKeys(keysToEnsure)
if err != nil { if err != nil {
s.taskService.EndTaskByID(taskID, resourceID, nil, fmt.Errorf("failed to ensure keys exist: %w", err)) s.taskService.EndTaskByID(ctx, taskID, resourceID, nil, fmt.Errorf("failed to ensure keys exist: %w", err))
return return
} }
// 步骤 3: 找出在这些 Key 中,哪些【已经】被链接到了当前分组。
alreadyLinkedModels, err := s.keyRepo.GetKeysByValuesAndGroupID(uniqueKeyStrings, groupID) alreadyLinkedModels, err := s.keyRepo.GetKeysByValuesAndGroupID(uniqueKeyStrings, groupID)
if err != nil { if err != nil {
s.taskService.EndTaskByID(taskID, resourceID, nil, fmt.Errorf("failed to check for already linked keys: %w", err)) s.taskService.EndTaskByID(ctx, taskID, resourceID, nil, fmt.Errorf("failed to check for already linked keys: %w", err))
return return
} }
alreadyLinkedIDSet := make(map[uint]struct{}) alreadyLinkedIDSet := make(map[uint]struct{})
for _, key := range alreadyLinkedModels { for _, key := range alreadyLinkedModels {
alreadyLinkedIDSet[key.ID] = struct{}{} alreadyLinkedIDSet[key.ID] = struct{}{}
} }
// 步骤 4: 确定【真正需要】被链接到当前分组的 key 列表 (我们的"工作集")。
var keysToLink []models.APIKey var keysToLink []models.APIKey
for _, key := range allKeyModels { for _, key := range allKeyModels {
if _, exists := alreadyLinkedIDSet[key.ID]; !exists { if _, exists := alreadyLinkedIDSet[key.ID]; !exists {
keysToLink = append(keysToLink, key) keysToLink = append(keysToLink, key)
} }
} }
// 步骤 5: 更新任务的 Total 总量为精确的 "工作集" 大小。
if err := s.taskService.UpdateTotalByID(taskID, len(keysToLink)); err != nil { if err := s.taskService.UpdateTotalByID(ctx, taskID, len(keysToLink)); err != nil {
s.logger.WithError(err).Warnf("Failed to update total for task %s", taskID) s.logger.WithError(err).Warnf("Failed to update total for task %s", taskID)
} }
// 步骤 6: 分块处理【链接Key到组】的操作并实时更新进度。
if len(keysToLink) > 0 { if len(keysToLink) > 0 {
idsToLink := make([]uint, len(keysToLink)) idsToLink := make([]uint, len(keysToLink))
for i, key := range keysToLink { for i, key := range keysToLink {
@@ -179,44 +172,41 @@ func (s *KeyImportService) runAddKeysTask(taskID string, resourceID string, grou
end = len(idsToLink) end = len(idsToLink)
} }
chunk := idsToLink[i:end] chunk := idsToLink[i:end]
if err := s.keyRepo.LinkKeysToGroup(groupID, chunk); err != nil { if err := s.keyRepo.LinkKeysToGroup(ctx, groupID, chunk); err != nil {
s.taskService.EndTaskByID(taskID, resourceID, nil, fmt.Errorf("chunk failed to link keys: %w", err)) s.taskService.EndTaskByID(ctx, taskID, resourceID, nil, fmt.Errorf("chunk failed to link keys: %w", err))
return return
} }
_ = s.taskService.UpdateProgressByID(taskID, i+len(chunk)) _ = s.taskService.UpdateProgressByID(ctx, taskID, i+len(chunk))
} }
} }
// 步骤 7: 准备最终结果并结束任务。
result := gin.H{ result := gin.H{
"newly_linked_count": len(keysToLink), "newly_linked_count": len(keysToLink),
"already_linked_count": len(alreadyLinkedIDSet), "already_linked_count": len(alreadyLinkedIDSet),
"total_linked_count": len(allKeyModels), "total_linked_count": len(allKeyModels),
} }
// 步骤 8: 根据 `validateOnImport` 标志, 发布事件或直接激活 (只对新链接的keys操作)。
if len(keysToLink) > 0 { if len(keysToLink) > 0 {
idsToLink := make([]uint, len(keysToLink)) idsToLink := make([]uint, len(keysToLink))
for i, key := range keysToLink { for i, key := range keysToLink {
idsToLink[i] = key.ID idsToLink[i] = key.ID
} }
if validateOnImport { if validateOnImport {
s.publishImportGroupCompletedEvent(groupID, idsToLink) s.publishImportGroupCompletedEvent(ctx, groupID, idsToLink)
for _, keyID := range idsToLink { for _, keyID := range idsToLink {
s.publishSingleKeyChangeEvent(groupID, keyID, "", models.StatusPendingValidation, "key_linked") s.publishSingleKeyChangeEvent(ctx, groupID, keyID, "", models.StatusPendingValidation, "key_linked")
} }
} else { } else {
for _, keyID := range idsToLink { for _, keyID := range idsToLink {
if _, err := s.apiKeyService.UpdateMappingStatus(groupID, keyID, models.StatusActive); err != nil { if _, err := s.apiKeyService.UpdateMappingStatus(ctx, groupID, keyID, models.StatusActive); err != nil {
s.logger.Errorf("Failed to directly activate key ID %d in group %d: %v", keyID, groupID, err) s.logger.Errorf("Failed to directly activate key ID %d in group %d: %v", keyID, groupID, err)
} }
} }
} }
} }
s.taskService.EndTaskByID(taskID, resourceID, result, nil) s.taskService.EndTaskByID(ctx, taskID, resourceID, result, nil)
} }
// runUnlinkKeysTask func (s *KeyImportService) runUnlinkKeysTask(ctx context.Context, taskID, resourceID string, groupID uint, keys []string) {
func (s *KeyImportService) runUnlinkKeysTask(taskID, resourceID string, groupID uint, keys []string) {
uniqueKeysMap := make(map[string]struct{}) uniqueKeysMap := make(map[string]struct{})
var uniqueKeys []string var uniqueKeys []string
for _, kStr := range keys { for _, kStr := range keys {
@@ -225,46 +215,42 @@ func (s *KeyImportService) runUnlinkKeysTask(taskID, resourceID string, groupID
uniqueKeys = append(uniqueKeys, kStr) uniqueKeys = append(uniqueKeys, kStr)
} }
} }
// 步骤 1: 一次性找出所有输入 Key 中,实际存在于本组的 Key 实体。这是我们的"工作集"。
keysToUnlink, err := s.keyRepo.GetKeysByValuesAndGroupID(uniqueKeys, groupID) keysToUnlink, err := s.keyRepo.GetKeysByValuesAndGroupID(uniqueKeys, groupID)
if err != nil { if err != nil {
s.taskService.EndTaskByID(taskID, resourceID, nil, fmt.Errorf("failed to find keys to unlink: %w", err)) s.taskService.EndTaskByID(ctx, taskID, resourceID, nil, fmt.Errorf("failed to find keys to unlink: %w", err))
return return
} }
if len(keysToUnlink) == 0 { if len(keysToUnlink) == 0 {
result := gin.H{"unlinked_count": 0, "hard_deleted_count": 0, "not_found_count": len(uniqueKeys)} result := gin.H{"unlinked_count": 0, "hard_deleted_count": 0, "not_found_count": len(uniqueKeys)}
s.taskService.EndTaskByID(taskID, resourceID, result, nil) s.taskService.EndTaskByID(ctx, taskID, resourceID, result, nil)
return return
} }
idsToUnlink := make([]uint, len(keysToUnlink)) idsToUnlink := make([]uint, len(keysToUnlink))
for i, key := range keysToUnlink { for i, key := range keysToUnlink {
idsToUnlink[i] = key.ID idsToUnlink[i] = key.ID
} }
// 步骤 2: 更新任务的 Total 总量为精确的 "工作集" 大小。
if err := s.taskService.UpdateTotalByID(taskID, len(idsToUnlink)); err != nil { if err := s.taskService.UpdateTotalByID(ctx, taskID, len(idsToUnlink)); err != nil {
s.logger.WithError(err).Warnf("Failed to update total for task %s", taskID) s.logger.WithError(err).Warnf("Failed to update total for task %s", taskID)
} }
var totalUnlinked int64 var totalUnlinked int64
// 步骤 3: 分块处理【解绑Key】的操作并上报进度。
for i := 0; i < len(idsToUnlink); i += chunkSize { for i := 0; i < len(idsToUnlink); i += chunkSize {
end := i + chunkSize end := i + chunkSize
if end > len(idsToUnlink) { if end > len(idsToUnlink) {
end = len(idsToUnlink) end = len(idsToUnlink)
} }
chunk := idsToUnlink[i:end] chunk := idsToUnlink[i:end]
unlinked, err := s.keyRepo.UnlinkKeysFromGroup(groupID, chunk) unlinked, err := s.keyRepo.UnlinkKeysFromGroup(ctx, groupID, chunk)
if err != nil { if err != nil {
s.taskService.EndTaskByID(taskID, resourceID, nil, fmt.Errorf("chunk failed: could not unlink keys: %w", err)) s.taskService.EndTaskByID(ctx, taskID, resourceID, nil, fmt.Errorf("chunk failed: could not unlink keys: %w", err))
return return
} }
totalUnlinked += unlinked totalUnlinked += unlinked
for _, keyID := range chunk { for _, keyID := range chunk {
s.publishSingleKeyChangeEvent(groupID, keyID, models.StatusActive, "", "key_unlinked") s.publishSingleKeyChangeEvent(ctx, groupID, keyID, models.StatusActive, "", "key_unlinked")
} }
_ = s.taskService.UpdateProgressByID(ctx, taskID, i+len(chunk))
_ = s.taskService.UpdateProgressByID(taskID, i+len(chunk))
} }
totalDeleted, err := s.keyRepo.DeleteOrphanKeys() totalDeleted, err := s.keyRepo.DeleteOrphanKeys()
@@ -276,10 +262,10 @@ func (s *KeyImportService) runUnlinkKeysTask(taskID, resourceID string, groupID
"hard_deleted_count": totalDeleted, "hard_deleted_count": totalDeleted,
"not_found_count": len(uniqueKeys) - int(totalUnlinked), "not_found_count": len(uniqueKeys) - int(totalUnlinked),
} }
s.taskService.EndTaskByID(taskID, resourceID, result, nil) s.taskService.EndTaskByID(ctx, taskID, resourceID, result, nil)
} }
func (s *KeyImportService) runHardDeleteKeysTask(taskID, resourceID string, keys []string) { func (s *KeyImportService) runHardDeleteKeysTask(ctx context.Context, taskID, resourceID string, keys []string) {
var totalDeleted int64 var totalDeleted int64
for i := 0; i < len(keys); i += chunkSize { for i := 0; i < len(keys); i += chunkSize {
end := i + chunkSize end := i + chunkSize
@@ -290,22 +276,21 @@ func (s *KeyImportService) runHardDeleteKeysTask(taskID, resourceID string, keys
deleted, err := s.keyRepo.HardDeleteByValues(chunk) deleted, err := s.keyRepo.HardDeleteByValues(chunk)
if err != nil { if err != nil {
s.taskService.EndTaskByID(taskID, resourceID, nil, fmt.Errorf("failed to hard delete chunk: %w", err)) s.taskService.EndTaskByID(ctx, taskID, resourceID, nil, fmt.Errorf("failed to hard delete chunk: %w", err))
return return
} }
totalDeleted += deleted totalDeleted += deleted
_ = s.taskService.UpdateProgressByID(taskID, i+len(chunk)) _ = s.taskService.UpdateProgressByID(ctx, taskID, i+len(chunk))
} }
result := gin.H{ result := gin.H{
"hard_deleted_count": totalDeleted, "hard_deleted_count": totalDeleted,
"not_found_count": int64(len(keys)) - totalDeleted, "not_found_count": int64(len(keys)) - totalDeleted,
} }
s.taskService.EndTaskByID(taskID, resourceID, result, nil) s.taskService.EndTaskByID(ctx, taskID, resourceID, result, nil)
s.publishChangeEvent(0, "keys_hard_deleted") // Global event s.publishChangeEvent(ctx, 0, "keys_hard_deleted")
} }
func (s *KeyImportService) runRestoreKeysTask(taskID, resourceID string, keys []string) { func (s *KeyImportService) runRestoreKeysTask(ctx context.Context, taskID, resourceID string, keys []string) {
var restoredCount int64 var restoredCount int64
for i := 0; i < len(keys); i += chunkSize { for i := 0; i < len(keys); i += chunkSize {
end := i + chunkSize end := i + chunkSize
@@ -316,21 +301,21 @@ func (s *KeyImportService) runRestoreKeysTask(taskID, resourceID string, keys []
count, err := s.keyRepo.UpdateMasterStatusByValues(chunk, models.MasterStatusActive) count, err := s.keyRepo.UpdateMasterStatusByValues(chunk, models.MasterStatusActive)
if err != nil { if err != nil {
s.taskService.EndTaskByID(taskID, resourceID, nil, fmt.Errorf("failed to restore chunk: %w", err)) s.taskService.EndTaskByID(ctx, taskID, resourceID, nil, fmt.Errorf("failed to restore chunk: %w", err))
return return
} }
restoredCount += count restoredCount += count
_ = s.taskService.UpdateProgressByID(taskID, i+len(chunk)) _ = s.taskService.UpdateProgressByID(ctx, taskID, i+len(chunk))
} }
result := gin.H{ result := gin.H{
"restored_count": restoredCount, "restored_count": restoredCount,
"not_found_count": int64(len(keys)) - restoredCount, "not_found_count": int64(len(keys)) - restoredCount,
} }
s.taskService.EndTaskByID(taskID, resourceID, result, nil) s.taskService.EndTaskByID(ctx, taskID, resourceID, result, nil)
s.publishChangeEvent(0, "keys_bulk_restored") // Global event s.publishChangeEvent(ctx, 0, "keys_bulk_restored")
} }
func (s *KeyImportService) publishSingleKeyChangeEvent(groupID, keyID uint, oldStatus, newStatus models.APIKeyStatus, reason string) { func (s *KeyImportService) publishSingleKeyChangeEvent(ctx context.Context, groupID, keyID uint, oldStatus, newStatus models.APIKeyStatus, reason string) {
event := models.KeyStatusChangedEvent{ event := models.KeyStatusChangedEvent{
GroupID: groupID, GroupID: groupID,
KeyID: keyID, KeyID: keyID,
@@ -340,7 +325,7 @@ func (s *KeyImportService) publishSingleKeyChangeEvent(groupID, keyID uint, oldS
ChangedAt: time.Now(), ChangedAt: time.Now(),
} }
eventData, _ := json.Marshal(event) eventData, _ := json.Marshal(event)
if err := s.store.Publish(models.TopicKeyStatusChanged, eventData); err != nil { if err := s.store.Publish(ctx, models.TopicKeyStatusChanged, eventData); err != nil {
s.logger.WithError(err).WithFields(logrus.Fields{ s.logger.WithError(err).WithFields(logrus.Fields{
"group_id": groupID, "group_id": groupID,
"key_id": keyID, "key_id": keyID,
@@ -349,16 +334,16 @@ func (s *KeyImportService) publishSingleKeyChangeEvent(groupID, keyID uint, oldS
} }
} }
func (s *KeyImportService) publishChangeEvent(groupID uint, reason string) { func (s *KeyImportService) publishChangeEvent(ctx context.Context, groupID uint, reason string) {
event := models.KeyStatusChangedEvent{ event := models.KeyStatusChangedEvent{
GroupID: groupID, GroupID: groupID,
ChangeReason: reason, ChangeReason: reason,
} }
eventData, _ := json.Marshal(event) eventData, _ := json.Marshal(event)
_ = s.store.Publish(models.TopicKeyStatusChanged, eventData) _ = s.store.Publish(ctx, models.TopicKeyStatusChanged, eventData)
} }
func (s *KeyImportService) publishImportGroupCompletedEvent(groupID uint, keyIDs []uint) { func (s *KeyImportService) publishImportGroupCompletedEvent(ctx context.Context, groupID uint, keyIDs []uint) {
if len(keyIDs) == 0 { if len(keyIDs) == 0 {
return return
} }
@@ -372,17 +357,15 @@ func (s *KeyImportService) publishImportGroupCompletedEvent(groupID uint, keyIDs
s.logger.WithError(err).Error("Failed to marshal ImportGroupCompletedEvent.") s.logger.WithError(err).Error("Failed to marshal ImportGroupCompletedEvent.")
return return
} }
if err := s.store.Publish(models.TopicImportGroupCompleted, eventData); err != nil { if err := s.store.Publish(ctx, models.TopicImportGroupCompleted, eventData); err != nil {
s.logger.WithError(err).Error("Failed to publish ImportGroupCompletedEvent.") s.logger.WithError(err).Error("Failed to publish ImportGroupCompletedEvent.")
} else { } else {
s.logger.Infof("Published ImportGroupCompletedEvent for group %d with %d keys.", groupID, len(keyIDs)) s.logger.Infof("Published ImportGroupCompletedEvent for group %d with %d keys.", groupID, len(keyIDs))
} }
} }
// [NEW] StartUnlinkKeysByFilterTask starts a task to unlink keys matching a status filter. func (s *KeyImportService) StartUnlinkKeysByFilterTask(ctx context.Context, groupID uint, statuses []string) (*task.Status, error) {
func (s *KeyImportService) StartUnlinkKeysByFilterTask(groupID uint, statuses []string) (*task.Status, error) {
s.logger.Infof("Starting unlink task for group %d with status filter: %v", groupID, statuses) s.logger.Infof("Starting unlink task for group %d with status filter: %v", groupID, statuses)
// 1. [New] Find the keys to operate on.
keyValues, err := s.keyRepo.FindKeyValuesByStatus(groupID, statuses) keyValues, err := s.keyRepo.FindKeyValuesByStatus(groupID, statuses)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to find keys by filter: %w", err) return nil, fmt.Errorf("failed to find keys by filter: %w", err)
@@ -390,8 +373,7 @@ func (s *KeyImportService) StartUnlinkKeysByFilterTask(groupID uint, statuses []
if len(keyValues) == 0 { if len(keyValues) == 0 {
return nil, fmt.Errorf("no keys found matching the provided filter") return nil, fmt.Errorf("no keys found matching the provided filter")
} }
// 2. [REUSE] Convert to text and call the existing, robust unlink task logic.
keysAsText := strings.Join(keyValues, "\n") keysAsText := strings.Join(keyValues, "\n")
s.logger.Infof("Found %d keys to unlink for group %d.", len(keyValues), groupID) s.logger.Infof("Found %d keys to unlink for group %d.", len(keyValues), groupID)
return s.StartUnlinkKeysTask(groupID, keysAsText) return s.StartUnlinkKeysTask(ctx, groupID, keysAsText)
} }

View File

@@ -2,6 +2,7 @@
package service package service
import ( import (
"context"
"encoding/json" "encoding/json"
"fmt" "fmt"
"gemini-balancer/internal/channel" "gemini-balancer/internal/channel"
@@ -62,20 +63,18 @@ func (s *KeyValidationService) ValidateSingleKey(key *models.APIKey, timeout tim
return fmt.Errorf("failed to create request: %w", err) return fmt.Errorf("failed to create request: %w", err)
} }
s.channel.ModifyRequest(req, key) // Use the injected channel to modify the request s.channel.ModifyRequest(req, key)
resp, err := client.Do(req) resp, err := client.Do(req)
if err != nil { if err != nil {
// This is a network-level error (e.g., timeout, DNS issue)
return fmt.Errorf("request failed: %w", err) return fmt.Errorf("request failed: %w", err)
} }
defer resp.Body.Close() defer resp.Body.Close()
if resp.StatusCode == http.StatusOK { if resp.StatusCode == http.StatusOK {
return nil // Success return nil
} }
// Read the body for more error details
bodyBytes, readErr := io.ReadAll(resp.Body) bodyBytes, readErr := io.ReadAll(resp.Body)
var errorMsg string var errorMsg string
if readErr != nil { if readErr != nil {
@@ -84,7 +83,6 @@ func (s *KeyValidationService) ValidateSingleKey(key *models.APIKey, timeout tim
errorMsg = string(bodyBytes) errorMsg = string(bodyBytes)
} }
// This is a validation failure with a specific HTTP status code
return &CustomErrors.APIError{ return &CustomErrors.APIError{
HTTPStatus: resp.StatusCode, HTTPStatus: resp.StatusCode,
Message: fmt.Sprintf("Validation failed with status %d: %s", resp.StatusCode, errorMsg), Message: fmt.Sprintf("Validation failed with status %d: %s", resp.StatusCode, errorMsg),
@@ -92,8 +90,7 @@ func (s *KeyValidationService) ValidateSingleKey(key *models.APIKey, timeout tim
} }
} }
// --- 异步任务方法 (全面适配新task包) --- func (s *KeyValidationService) StartTestKeysTask(ctx context.Context, groupID uint, keysText string) (*task.Status, error) {
func (s *KeyValidationService) StartTestKeysTask(groupID uint, keysText string) (*task.Status, error) {
keyStrings := utils.ParseKeysFromText(keysText) keyStrings := utils.ParseKeysFromText(keysText)
if len(keyStrings) == 0 { if len(keyStrings) == 0 {
return nil, CustomErrors.NewAPIError(CustomErrors.ErrBadRequest, "no valid keys found in input text") return nil, CustomErrors.NewAPIError(CustomErrors.ErrBadRequest, "no valid keys found in input text")
@@ -111,7 +108,6 @@ func (s *KeyValidationService) StartTestKeysTask(groupID uint, keysText string)
} }
group, ok := s.groupManager.GetGroupByID(groupID) group, ok := s.groupManager.GetGroupByID(groupID)
if !ok { if !ok {
// [FIX] Correctly use the NewAPIError constructor for a missing group.
return nil, CustomErrors.NewAPIError(CustomErrors.ErrGroupNotFound, fmt.Sprintf("group with id %d not found", groupID)) return nil, CustomErrors.NewAPIError(CustomErrors.ErrGroupNotFound, fmt.Sprintf("group with id %d not found", groupID))
} }
opConfig, err := s.groupManager.BuildOperationalConfig(group) opConfig, err := s.groupManager.BuildOperationalConfig(group)
@@ -119,15 +115,15 @@ func (s *KeyValidationService) StartTestKeysTask(groupID uint, keysText string)
return nil, CustomErrors.NewAPIError(CustomErrors.ErrInternalServer, fmt.Sprintf("failed to build operational config: %v", err)) return nil, CustomErrors.NewAPIError(CustomErrors.ErrInternalServer, fmt.Sprintf("failed to build operational config: %v", err))
} }
resourceID := fmt.Sprintf("group-%d", groupID) resourceID := fmt.Sprintf("group-%d", groupID)
taskStatus, err := s.taskService.StartTask(groupID, TaskTypeTestKeys, resourceID, len(apiKeyModels), time.Hour) taskStatus, err := s.taskService.StartTask(ctx, groupID, TaskTypeTestKeys, resourceID, len(apiKeyModels), time.Hour)
if err != nil { if err != nil {
return nil, err // Pass up the error from task service (e.g., "task already running") return nil, err
} }
settings := s.SettingsManager.GetSettings() settings := s.SettingsManager.GetSettings()
timeout := time.Duration(settings.KeyCheckTimeoutSeconds) * time.Second timeout := time.Duration(settings.KeyCheckTimeoutSeconds) * time.Second
endpoint, err := s.groupManager.BuildKeyCheckEndpoint(groupID) endpoint, err := s.groupManager.BuildKeyCheckEndpoint(groupID)
if err != nil { if err != nil {
s.taskService.EndTaskByID(taskStatus.ID, resourceID, nil, err) // End task with error if endpoint fails s.taskService.EndTaskByID(ctx, taskStatus.ID, resourceID, nil, err)
return nil, err return nil, err
} }
var concurrency int var concurrency int
@@ -136,11 +132,11 @@ func (s *KeyValidationService) StartTestKeysTask(groupID uint, keysText string)
} else { } else {
concurrency = settings.BaseKeyCheckConcurrency concurrency = settings.BaseKeyCheckConcurrency
} }
go s.runTestKeysTask(taskStatus.ID, resourceID, groupID, apiKeyModels, timeout, endpoint, concurrency) go s.runTestKeysTask(context.Background(), taskStatus.ID, resourceID, groupID, apiKeyModels, timeout, endpoint, concurrency)
return taskStatus, nil return taskStatus, nil
} }
func (s *KeyValidationService) runTestKeysTask(taskID string, resourceID string, groupID uint, keys []models.APIKey, timeout time.Duration, endpoint string, concurrency int) { func (s *KeyValidationService) runTestKeysTask(ctx context.Context, taskID string, resourceID string, groupID uint, keys []models.APIKey, timeout time.Duration, endpoint string, concurrency int) {
var wg sync.WaitGroup var wg sync.WaitGroup
var mu sync.Mutex var mu sync.Mutex
finalResults := make([]models.KeyTestResult, len(keys)) finalResults := make([]models.KeyTestResult, len(keys))
@@ -165,7 +161,6 @@ func (s *KeyValidationService) runTestKeysTask(taskID string, resourceID string,
var currentResult models.KeyTestResult var currentResult models.KeyTestResult
event := models.RequestFinishedEvent{ event := models.RequestFinishedEvent{
RequestLog: models.RequestLog{ RequestLog: models.RequestLog{
// GroupID 和 KeyID 在 RequestLog 模型中是指针,需要取地址
GroupID: &groupID, GroupID: &groupID,
KeyID: &apiKeyModel.ID, KeyID: &apiKeyModel.ID,
}, },
@@ -185,14 +180,15 @@ func (s *KeyValidationService) runTestKeysTask(taskID string, resourceID string,
event.RequestLog.IsSuccess = false event.RequestLog.IsSuccess = false
} }
eventData, _ := json.Marshal(event) eventData, _ := json.Marshal(event)
if err := s.store.Publish(models.TopicRequestFinished, eventData); err != nil {
if err := s.store.Publish(ctx, models.TopicRequestFinished, eventData); err != nil {
s.logger.WithError(err).Errorf("Failed to publish RequestFinishedEvent for validation of key ID %d", apiKeyModel.ID) s.logger.WithError(err).Errorf("Failed to publish RequestFinishedEvent for validation of key ID %d", apiKeyModel.ID)
} }
mu.Lock() mu.Lock()
finalResults[j.Index] = currentResult finalResults[j.Index] = currentResult
processedCount++ processedCount++
_ = s.taskService.UpdateProgressByID(taskID, processedCount) _ = s.taskService.UpdateProgressByID(ctx, taskID, processedCount)
mu.Unlock() mu.Unlock()
} }
}() }()
@@ -202,10 +198,10 @@ func (s *KeyValidationService) runTestKeysTask(taskID string, resourceID string,
} }
close(jobs) close(jobs)
wg.Wait() wg.Wait()
s.taskService.EndTaskByID(taskID, resourceID, gin.H{"results": finalResults}, nil) s.taskService.EndTaskByID(ctx, taskID, resourceID, gin.H{"results": finalResults}, nil)
} }
func (s *KeyValidationService) StartTestKeysByFilterTask(groupID uint, statuses []string) (*task.Status, error) { func (s *KeyValidationService) StartTestKeysByFilterTask(ctx context.Context, groupID uint, statuses []string) (*task.Status, error) {
s.logger.Infof("Starting test task for group %d with status filter: %v", groupID, statuses) s.logger.Infof("Starting test task for group %d with status filter: %v", groupID, statuses)
keyValues, err := s.keyRepo.FindKeyValuesByStatus(groupID, statuses) keyValues, err := s.keyRepo.FindKeyValuesByStatus(groupID, statuses)
if err != nil { if err != nil {
@@ -216,5 +212,5 @@ func (s *KeyValidationService) StartTestKeysByFilterTask(groupID uint, statuses
} }
keysAsText := strings.Join(keyValues, "\n") keysAsText := strings.Join(keyValues, "\n")
s.logger.Infof("Found %d keys to validate for group %d.", len(keyValues), groupID) s.logger.Infof("Found %d keys to validate for group %d.", len(keyValues), groupID)
return s.StartTestKeysTask(groupID, keysAsText) return s.StartTestKeysTask(ctx, groupID, keysAsText)
} }

View File

@@ -3,6 +3,7 @@
package service package service
import ( import (
"context"
"errors" "errors"
apperrors "gemini-balancer/internal/errors" apperrors "gemini-balancer/internal/errors"
"gemini-balancer/internal/models" "gemini-balancer/internal/models"
@@ -43,7 +44,6 @@ func NewResourceService(
aks *APIKeyService, aks *APIKeyService,
logger *logrus.Logger, logger *logrus.Logger,
) *ResourceService { ) *ResourceService {
logger.Debugf("[FORENSIC PROBE | INJECTION | ResourceService] Received 'keyRepo' param. Fingerprint: %p", kr)
rs := &ResourceService{ rs := &ResourceService{
settingsManager: sm, settingsManager: sm,
groupManager: gm, groupManager: gm,
@@ -56,43 +56,40 @@ func NewResourceService(
go rs.preWarmCache(logger) go rs.preWarmCache(logger)
}) })
return rs return rs
} }
// --- [模式一:智能聚合模式] --- func (s *ResourceService) GetResourceFromBasePool(ctx context.Context, authToken *models.AuthToken, modelName string) (*RequestResources, error) {
func (s *ResourceService) GetResourceFromBasePool(authToken *models.AuthToken, modelName string) (*RequestResources, error) {
log := s.logger.WithFields(logrus.Fields{"token_id": authToken.ID, "model_name": modelName, "mode": "BasePool"}) log := s.logger.WithFields(logrus.Fields{"token_id": authToken.ID, "model_name": modelName, "mode": "BasePool"})
log.Debug("Entering BasePool resource acquisition.") log.Debug("Entering BasePool resource acquisition.")
// 1.筛选出所有符合条件的候选组,并按优先级排序
candidateGroups := s.filterAndSortCandidateGroups(modelName, authToken.AllowedGroups) candidateGroups := s.filterAndSortCandidateGroups(modelName, authToken.AllowedGroups)
if len(candidateGroups) == 0 { if len(candidateGroups) == 0 {
log.Warn("No candidate groups found for BasePool construction.") log.Warn("No candidate groups found for BasePool construction.")
return nil, apperrors.ErrNoKeysAvailable return nil, apperrors.ErrNoKeysAvailable
} }
// 2.从 BasePool中根据系统全局策略选择一个Key
basePool := &repository.BasePool{ basePool := &repository.BasePool{
CandidateGroups: candidateGroups, CandidateGroups: candidateGroups,
PollingStrategy: s.settingsManager.GetSettings().PollingStrategy, PollingStrategy: s.settingsManager.GetSettings().PollingStrategy,
} }
apiKey, selectedGroup, err := s.keyRepo.SelectOneActiveKeyFromBasePool(basePool)
apiKey, selectedGroup, err := s.keyRepo.SelectOneActiveKeyFromBasePool(ctx, basePool)
if err != nil { if err != nil {
log.WithError(err).Warn("Failed to select a key from the BasePool.") log.WithError(err).Warn("Failed to select a key from the BasePool.")
return nil, apperrors.ErrNoKeysAvailable return nil, apperrors.ErrNoKeysAvailable
} }
// 3. 组装最终资源
// [关键] 在此模式下RequestConfig 永远是空的,以保证透明性。
resources, err := s.assembleRequestResources(selectedGroup, apiKey) resources, err := s.assembleRequestResources(selectedGroup, apiKey)
if err != nil { if err != nil {
log.WithError(err).Error("Failed to assemble resources after selecting key from BasePool.") log.WithError(err).Error("Failed to assemble resources after selecting key from BasePool.")
return nil, err return nil, err
} }
resources.RequestConfig = &models.RequestConfig{} // 强制为空 resources.RequestConfig = &models.RequestConfig{}
log.Infof("Successfully selected KeyID %d from GroupID %d for the BasePool.", apiKey.ID, selectedGroup.ID) log.Infof("Successfully selected KeyID %d from GroupID %d for the BasePool.", apiKey.ID, selectedGroup.ID)
return resources, nil return resources, nil
} }
// --- [模式二:精确路由模式] --- func (s *ResourceService) GetResourceFromGroup(ctx context.Context, authToken *models.AuthToken, groupName string) (*RequestResources, error) {
func (s *ResourceService) GetResourceFromGroup(authToken *models.AuthToken, groupName string) (*RequestResources, error) {
log := s.logger.WithFields(logrus.Fields{"token_id": authToken.ID, "group_name": groupName, "mode": "PreciseRoute"}) log := s.logger.WithFields(logrus.Fields{"token_id": authToken.ID, "group_name": groupName, "mode": "PreciseRoute"})
log.Debug("Entering PreciseRoute resource acquisition.") log.Debug("Entering PreciseRoute resource acquisition.")
@@ -101,12 +98,11 @@ func (s *ResourceService) GetResourceFromGroup(authToken *models.AuthToken, grou
if !ok { if !ok {
return nil, apperrors.NewAPIError(apperrors.ErrGroupNotFound, "The specified group does not exist.") return nil, apperrors.NewAPIError(apperrors.ErrGroupNotFound, "The specified group does not exist.")
} }
if !s.isTokenAllowedForGroup(authToken, targetGroup.ID) { if !s.isTokenAllowedForGroup(authToken, targetGroup.ID) {
return nil, apperrors.NewAPIError(apperrors.ErrPermissionDenied, "Token does not have permission to access this group.") return nil, apperrors.NewAPIError(apperrors.ErrPermissionDenied, "Token does not have permission to access this group.")
} }
apiKey, _, err := s.keyRepo.SelectOneActiveKey(targetGroup) apiKey, _, err := s.keyRepo.SelectOneActiveKey(ctx, targetGroup)
if err != nil { if err != nil {
log.WithError(err).Warn("Failed to select a key from the precisely targeted group.") log.WithError(err).Warn("Failed to select a key from the precisely targeted group.")
return nil, apperrors.ErrNoKeysAvailable return nil, apperrors.ErrNoKeysAvailable
@@ -132,7 +128,6 @@ func (s *ResourceService) GetAllowedModelsForToken(authToken *models.AuthToken)
if authToken.IsAdmin { if authToken.IsAdmin {
for _, group := range allGroups { for _, group := range allGroups {
for _, modelMapping := range group.AllowedModels { for _, modelMapping := range group.AllowedModels {
allowedModelsSet[modelMapping.ModelName] = struct{}{} allowedModelsSet[modelMapping.ModelName] = struct{}{}
} }
} }
@@ -144,7 +139,6 @@ func (s *ResourceService) GetAllowedModelsForToken(authToken *models.AuthToken)
for _, group := range allGroups { for _, group := range allGroups {
if _, ok := allowedGroupIDs[group.ID]; ok { if _, ok := allowedGroupIDs[group.ID]; ok {
for _, modelMapping := range group.AllowedModels { for _, modelMapping := range group.AllowedModels {
allowedModelsSet[modelMapping.ModelName] = struct{}{} allowedModelsSet[modelMapping.ModelName] = struct{}{}
} }
} }
@@ -164,14 +158,6 @@ func (s *ResourceService) assembleRequestResources(group *models.KeyGroup, apiKe
return nil, apperrors.NewAPIError(apperrors.ErrConfigurationError, "Selected group has no valid upstream and no global default is set.") return nil, apperrors.NewAPIError(apperrors.ErrConfigurationError, "Selected group has no valid upstream and no global default is set.")
} }
var proxyConfig *models.ProxyConfig var proxyConfig *models.ProxyConfig
// [注意] 代理逻辑需要一个 proxyModule 实例,我们暂时置空。后续需要重新注入依赖。
// if group.EnableProxy && s.proxyModule != nil {
// var err error
// proxyConfig, err = s.proxyModule.AssignProxyIfNeeded(apiKey)
// if err != nil {
// s.logger.WithError(err).Warnf("Failed to assign proxy for API key %d.", apiKey.ID)
// }
// }
return &RequestResources{ return &RequestResources{
KeyGroup: group, KeyGroup: group,
APIKey: apiKey, APIKey: apiKey,
@@ -194,7 +180,7 @@ func (s *ResourceService) selectUpstreamForGroup(group *models.KeyGroup) *models
func (s *ResourceService) preWarmCache(logger *logrus.Logger) error { func (s *ResourceService) preWarmCache(logger *logrus.Logger) error {
time.Sleep(2 * time.Second) time.Sleep(2 * time.Second)
s.logger.Info("Performing initial key cache pre-warming...") s.logger.Info("Performing initial key cache pre-warming...")
if err := s.keyRepo.LoadAllKeysToStore(); err != nil { if err := s.keyRepo.LoadAllKeysToStore(context.Background()); err != nil {
logger.WithError(err).Error("Failed to perform initial key cache pre-warming.") logger.WithError(err).Error("Failed to perform initial key cache pre-warming.")
return err return err
} }
@@ -209,7 +195,6 @@ func (s *ResourceService) GetResourcesForRequest(modelName string, allowedGroups
func (s *ResourceService) filterAndSortCandidateGroups(modelName string, allowedGroupsFromToken []*models.KeyGroup) []*models.KeyGroup { func (s *ResourceService) filterAndSortCandidateGroups(modelName string, allowedGroupsFromToken []*models.KeyGroup) []*models.KeyGroup {
allGroupsFromCache := s.groupManager.GetAllGroups() allGroupsFromCache := s.groupManager.GetAllGroups()
var candidateGroups []*models.KeyGroup var candidateGroups []*models.KeyGroup
// 1. 确定权限范围
allowedGroupIDs := make(map[uint]bool) allowedGroupIDs := make(map[uint]bool)
isTokenRestricted := len(allowedGroupsFromToken) > 0 isTokenRestricted := len(allowedGroupsFromToken) > 0
if isTokenRestricted { if isTokenRestricted {
@@ -217,15 +202,12 @@ func (s *ResourceService) filterAndSortCandidateGroups(modelName string, allowed
allowedGroupIDs[ag.ID] = true allowedGroupIDs[ag.ID] = true
} }
} }
// 2. 筛选
for _, group := range allGroupsFromCache { for _, group := range allGroupsFromCache {
// 检查Token权限
if isTokenRestricted && !allowedGroupIDs[group.ID] { if isTokenRestricted && !allowedGroupIDs[group.ID] {
continue continue
} }
// 检查模型是否被允许
isModelAllowed := false isModelAllowed := false
if len(group.AllowedModels) == 0 { // 如果组不限制模型,则允许 if len(group.AllowedModels) == 0 {
isModelAllowed = true isModelAllowed = true
} else { } else {
for _, m := range group.AllowedModels { for _, m := range group.AllowedModels {
@@ -239,8 +221,6 @@ func (s *ResourceService) filterAndSortCandidateGroups(modelName string, allowed
candidateGroups = append(candidateGroups, group) candidateGroups = append(candidateGroups, group)
} }
} }
// 3.按 Order 字段升序排序
sort.SliceStable(candidateGroups, func(i, j int) bool { sort.SliceStable(candidateGroups, func(i, j int) bool {
return candidateGroups[i].Order < candidateGroups[j].Order return candidateGroups[i].Order < candidateGroups[j].Order
}) })

View File

@@ -52,7 +52,7 @@ func (s *SecurityService) AuthenticateToken(tokenValue string) (*models.AuthToke
// IsIPBanned // IsIPBanned
func (s *SecurityService) IsIPBanned(ctx context.Context, ip string) (bool, error) { func (s *SecurityService) IsIPBanned(ctx context.Context, ip string) (bool, error) {
banKey := fmt.Sprintf("banned_ip:%s", ip) banKey := fmt.Sprintf("banned_ip:%s", ip)
return s.store.Exists(banKey) return s.store.Exists(ctx, banKey)
} }
// RecordFailedLoginAttempt // RecordFailedLoginAttempt
@@ -61,7 +61,7 @@ func (s *SecurityService) RecordFailedLoginAttempt(ctx context.Context, ip strin
return nil return nil
} }
count, err := s.store.HIncrBy(loginAttemptsKey, ip, 1) count, err := s.store.HIncrBy(ctx, loginAttemptsKey, ip, 1)
if err != nil { if err != nil {
return err return err
} }
@@ -71,12 +71,12 @@ func (s *SecurityService) RecordFailedLoginAttempt(ctx context.Context, ip strin
banDuration := s.SettingsManager.GetIPBanDuration() banDuration := s.SettingsManager.GetIPBanDuration()
banKey := fmt.Sprintf("banned_ip:%s", ip) banKey := fmt.Sprintf("banned_ip:%s", ip)
if err := s.store.Set(banKey, []byte("1"), banDuration); err != nil { if err := s.store.Set(ctx, banKey, []byte("1"), banDuration); err != nil {
return err return err
} }
s.logger.Warnf("IP BANNED: IP [%s] has been banned for %v due to excessive failed login attempts.", ip, banDuration) s.logger.Warnf("IP BANNED: IP [%s] has been banned for %v due to excessive failed login attempts.", ip, banDuration)
s.store.HDel(loginAttemptsKey, ip) s.store.HDel(ctx, loginAttemptsKey, ip)
} }
return nil return nil

View File

@@ -2,6 +2,7 @@
package service package service
import ( import (
"context"
"encoding/json" "encoding/json"
"fmt" "fmt"
"gemini-balancer/internal/models" "gemini-balancer/internal/models"
@@ -34,7 +35,7 @@ func NewStatsService(db *gorm.DB, s store.Store, repo repository.KeyRepository,
func (s *StatsService) Start() { func (s *StatsService) Start() {
s.logger.Info("Starting event listener for stats maintenance.") s.logger.Info("Starting event listener for stats maintenance.")
sub, err := s.store.Subscribe(models.TopicKeyStatusChanged) sub, err := s.store.Subscribe(context.Background(), models.TopicKeyStatusChanged)
if err != nil { if err != nil {
s.logger.Fatalf("Failed to subscribe to topic %s: %v", models.TopicKeyStatusChanged, err) s.logger.Fatalf("Failed to subscribe to topic %s: %v", models.TopicKeyStatusChanged, err)
return return
@@ -67,42 +68,43 @@ func (s *StatsService) handleKeyStatusChange(event *models.KeyStatusChangedEvent
s.logger.Warnf("Received KeyStatusChangedEvent with no GroupID. Reason: %s, KeyID: %d. Skipping.", event.ChangeReason, event.KeyID) s.logger.Warnf("Received KeyStatusChangedEvent with no GroupID. Reason: %s, KeyID: %d. Skipping.", event.ChangeReason, event.KeyID)
return return
} }
ctx := context.Background()
statsKey := fmt.Sprintf("stats:group:%d", event.GroupID) statsKey := fmt.Sprintf("stats:group:%d", event.GroupID)
s.logger.Infof("Handling key status change for Group %d, KeyID: %d, Reason: %s", event.GroupID, event.KeyID, event.ChangeReason) s.logger.Infof("Handling key status change for Group %d, KeyID: %d, Reason: %s", event.GroupID, event.KeyID, event.ChangeReason)
switch event.ChangeReason { switch event.ChangeReason {
case "key_unlinked", "key_hard_deleted": case "key_unlinked", "key_hard_deleted":
if event.OldStatus != "" { if event.OldStatus != "" {
s.store.HIncrBy(statsKey, "total_keys", -1) s.store.HIncrBy(ctx, statsKey, "total_keys", -1)
s.store.HIncrBy(statsKey, fmt.Sprintf("%s_keys", event.OldStatus), -1) s.store.HIncrBy(ctx, statsKey, fmt.Sprintf("%s_keys", event.OldStatus), -1)
} else { } else {
s.logger.Warnf("Received '%s' event for group %d without OldStatus, forcing recalculation.", event.ChangeReason, event.GroupID) s.logger.Warnf("Received '%s' event for group %d without OldStatus, forcing recalculation.", event.ChangeReason, event.GroupID)
s.RecalculateGroupKeyStats(event.GroupID) s.RecalculateGroupKeyStats(ctx, event.GroupID)
} }
case "key_linked": case "key_linked":
if event.NewStatus != "" { if event.NewStatus != "" {
s.store.HIncrBy(statsKey, "total_keys", 1) s.store.HIncrBy(ctx, statsKey, "total_keys", 1)
s.store.HIncrBy(statsKey, fmt.Sprintf("%s_keys", event.NewStatus), 1) s.store.HIncrBy(ctx, statsKey, fmt.Sprintf("%s_keys", event.NewStatus), 1)
} else { } else {
s.logger.Warnf("Received 'key_linked' event for group %d without NewStatus, forcing recalculation.", event.GroupID) s.logger.Warnf("Received 'key_linked' event for group %d without NewStatus, forcing recalculation.", event.GroupID)
s.RecalculateGroupKeyStats(event.GroupID) s.RecalculateGroupKeyStats(ctx, event.GroupID)
} }
case "manual_update", "error_threshold_reached", "key_recovered", "invalid_api_key": case "manual_update", "error_threshold_reached", "key_recovered", "invalid_api_key":
s.store.HIncrBy(statsKey, fmt.Sprintf("%s_keys", event.OldStatus), -1) s.store.HIncrBy(ctx, statsKey, fmt.Sprintf("%s_keys", event.OldStatus), -1)
s.store.HIncrBy(statsKey, fmt.Sprintf("%s_keys", event.NewStatus), 1) s.store.HIncrBy(ctx, statsKey, fmt.Sprintf("%s_keys", event.NewStatus), 1)
default: default:
s.logger.Warnf("Unhandled event reason '%s' for group %d, forcing recalculation.", event.ChangeReason, event.GroupID) s.logger.Warnf("Unhandled event reason '%s' for group %d, forcing recalculation.", event.ChangeReason, event.GroupID)
s.RecalculateGroupKeyStats(event.GroupID) s.RecalculateGroupKeyStats(ctx, event.GroupID)
} }
} }
func (s *StatsService) RecalculateGroupKeyStats(groupID uint) error { func (s *StatsService) RecalculateGroupKeyStats(ctx context.Context, groupID uint) error {
s.logger.Warnf("Performing full recalculation for group %d key stats.", groupID) s.logger.Warnf("Performing full recalculation for group %d key stats.", groupID)
var results []struct { var results []struct {
Status models.APIKeyStatus Status models.APIKeyStatus
Count int64 Count int64
} }
if err := s.db.Model(&models.GroupAPIKeyMapping{}). if err := s.db.WithContext(ctx).Model(&models.GroupAPIKeyMapping{}).
Where("key_group_id = ?", groupID). Where("key_group_id = ?", groupID).
Select("status, COUNT(*) as count"). Select("status, COUNT(*) as count").
Group("status"). Group("status").
@@ -119,37 +121,25 @@ func (s *StatsService) RecalculateGroupKeyStats(groupID uint) error {
} }
updates["total_keys"] = totalKeys updates["total_keys"] = totalKeys
if err := s.store.Del(statsKey); err != nil { if err := s.store.Del(ctx, statsKey); err != nil {
s.logger.WithError(err).Warnf("Failed to delete stale stats key for group %d before recalculation.", groupID) s.logger.WithError(err).Warnf("Failed to delete stale stats key for group %d before recalculation.", groupID)
} }
if err := s.store.HSet(statsKey, updates); err != nil { if err := s.store.HSet(ctx, statsKey, updates); err != nil {
return fmt.Errorf("failed to HSet recalculated stats for group %d: %w", groupID, err) return fmt.Errorf("failed to HSet recalculated stats for group %d: %w", groupID, err)
} }
s.logger.Infof("Successfully recalculated stats for group %d using HSet.", groupID) s.logger.Infof("Successfully recalculated stats for group %d using HSet.", groupID)
return nil return nil
} }
func (s *StatsService) GetDashboardStats() (*models.DashboardStatsResponse, error) { func (s *StatsService) GetDashboardStats(ctx context.Context) (*models.DashboardStatsResponse, error) {
// TODO 逻辑:
// 1. 从Redis中获取所有分组的Key统计 (HGetAll)
// 2. 从 stats_hourly 表中获取过去24小时的请求数和错误率
// 3. 组合成 DashboardStatsResponse
// ... 这个方法的具体实现我们可以在DashboardQueryService中完成
// 这里我们先确保StatsService的核心职责维护缓存已经完成。
// 为了编译通过,我们先返回一个空对象。
// 伪代码:
// keyCounts, _ := s.store.HGetAll("stats:global:keys")
// ...
return &models.DashboardStatsResponse{}, nil return &models.DashboardStatsResponse{}, nil
} }
func (s *StatsService) AggregateHourlyStats() error { func (s *StatsService) AggregateHourlyStats(ctx context.Context) error {
s.logger.Info("Starting aggregation of the last hour's request data...") s.logger.Info("Starting aggregation of the last hour's request data...")
now := time.Now() now := time.Now()
endTime := now.Truncate(time.Hour) // 例如15:23 -> 15:00 endTime := now.Truncate(time.Hour)
startTime := endTime.Add(-1 * time.Hour) // 15:00 -> 14:00 startTime := endTime.Add(-1 * time.Hour)
s.logger.Infof("Aggregating data for time window: [%s, %s)", startTime.Format(time.RFC3339), endTime.Format(time.RFC3339)) s.logger.Infof("Aggregating data for time window: [%s, %s)", startTime.Format(time.RFC3339), endTime.Format(time.RFC3339))
type aggregationResult struct { type aggregationResult struct {
@@ -161,7 +151,8 @@ func (s *StatsService) AggregateHourlyStats() error {
CompletionTokens int64 CompletionTokens int64
} }
var results []aggregationResult var results []aggregationResult
err := s.db.Model(&models.RequestLog{}).
err := s.db.WithContext(ctx).Model(&models.RequestLog{}).
Select("group_id, model_name, COUNT(*) as request_count, SUM(CASE WHEN is_success = true THEN 1 ELSE 0 END) as success_count, SUM(prompt_tokens) as prompt_tokens, SUM(completion_tokens) as completion_tokens"). Select("group_id, model_name, COUNT(*) as request_count, SUM(CASE WHEN is_success = true THEN 1 ELSE 0 END) as success_count, SUM(prompt_tokens) as prompt_tokens, SUM(completion_tokens) as completion_tokens").
Where("request_time >= ? AND request_time < ?", startTime, endTime). Where("request_time >= ? AND request_time < ?", startTime, endTime).
Group("group_id, model_name"). Group("group_id, model_name").
@@ -179,7 +170,7 @@ func (s *StatsService) AggregateHourlyStats() error {
var hourlyStats []models.StatsHourly var hourlyStats []models.StatsHourly
for _, res := range results { for _, res := range results {
hourlyStats = append(hourlyStats, models.StatsHourly{ hourlyStats = append(hourlyStats, models.StatsHourly{
Time: startTime, // 所有记录的时间戳都是该小时的起点 Time: startTime,
GroupID: res.GroupID, GroupID: res.GroupID,
ModelName: res.ModelName, ModelName: res.ModelName,
RequestCount: res.RequestCount, RequestCount: res.RequestCount,
@@ -189,7 +180,7 @@ func (s *StatsService) AggregateHourlyStats() error {
}) })
} }
return s.db.Clauses(clause.OnConflict{ return s.db.WithContext(ctx).Clauses(clause.OnConflict{
Columns: []clause.Column{{Name: "time"}, {Name: "group_id"}, {Name: "model_name"}}, Columns: []clause.Column{{Name: "time"}, {Name: "group_id"}, {Name: "model_name"}},
DoUpdates: clause.AssignmentColumns([]string{"request_count", "success_count", "prompt_tokens", "completion_tokens"}), DoUpdates: clause.AssignmentColumns([]string{"request_count", "success_count", "prompt_tokens", "completion_tokens"}),
}).Create(&hourlyStats).Error }).Create(&hourlyStats).Error

View File

@@ -1,3 +1,4 @@
// Filename: internal/store/factory.go
package store package store
import ( import (
@@ -11,7 +12,6 @@ import (
// NewStore creates a new store based on the application configuration. // NewStore creates a new store based on the application configuration.
func NewStore(cfg *config.Config, logger *logrus.Logger) (Store, error) { func NewStore(cfg *config.Config, logger *logrus.Logger) (Store, error) {
// 检查是否有Redis配置
if cfg.Redis.DSN != "" { if cfg.Redis.DSN != "" {
opts, err := redis.ParseURL(cfg.Redis.DSN) opts, err := redis.ParseURL(cfg.Redis.DSN)
if err != nil { if err != nil {
@@ -20,10 +20,10 @@ func NewStore(cfg *config.Config, logger *logrus.Logger) (Store, error) {
client := redis.NewClient(opts) client := redis.NewClient(opts)
if err := client.Ping(context.Background()).Err(); err != nil { if err := client.Ping(context.Background()).Err(); err != nil {
logger.WithError(err).Warnf("WARN: Failed to connect to Redis (%s), falling back to in-memory store. Error: %v", cfg.Redis.DSN, err) logger.WithError(err).Warnf("WARN: Failed to connect to Redis (%s), falling back to in-memory store. Error: %v", cfg.Redis.DSN, err)
return NewMemoryStore(logger), nil // 连接失败,也回退到内存模式,但不返回错误 return NewMemoryStore(logger), nil
} }
logger.Info("Successfully connected to Redis. Using Redis as store.") logger.Info("Successfully connected to Redis. Using Redis as store.")
return NewRedisStore(client), nil return NewRedisStore(client, logger), nil
} }
logger.Info("INFO: Redis DSN not configured, falling back to in-memory store.") logger.Info("INFO: Redis DSN not configured, falling back to in-memory store.")
return NewMemoryStore(logger), nil return NewMemoryStore(logger), nil

View File

@@ -1,8 +1,9 @@
// Filename: internal/store/memory_store.go (经同行审查后最终修复版) // Filename: internal/store/memory_store.go
package store package store
import ( import (
"context"
"fmt" "fmt"
"math/rand" "math/rand"
"sort" "sort"
@@ -12,6 +13,7 @@ import (
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
) )
// ensure memoryStore implements Store interface
var _ Store = (*memoryStore)(nil) var _ Store = (*memoryStore)(nil)
type memoryStoreItem struct { type memoryStoreItem struct {
@@ -32,7 +34,6 @@ type memoryStore struct {
items map[string]*memoryStoreItem items map[string]*memoryStoreItem
pubsub map[string][]chan *Message pubsub map[string][]chan *Message
mu sync.RWMutex mu sync.RWMutex
// [USER SUGGESTION APPLIED] 使用带锁的随机数源以保证并发安全
rng *rand.Rand rng *rand.Rand
rngMu sync.Mutex rngMu sync.Mutex
logger *logrus.Entry logger *logrus.Entry
@@ -42,7 +43,6 @@ func NewMemoryStore(logger *logrus.Logger) Store {
store := &memoryStore{ store := &memoryStore{
items: make(map[string]*memoryStoreItem), items: make(map[string]*memoryStoreItem),
pubsub: make(map[string][]chan *Message), pubsub: make(map[string][]chan *Message),
// 使用当前时间作为种子,创建一个新的随机数源
rng: rand.New(rand.NewSource(time.Now().UnixNano())), rng: rand.New(rand.NewSource(time.Now().UnixNano())),
logger: logger.WithField("component", "store.memory 🗱"), logger: logger.WithField("component", "store.memory 🗱"),
} }
@@ -50,13 +50,12 @@ func NewMemoryStore(logger *logrus.Logger) Store {
return store return store
} }
// [USER SUGGESTION INCORPORATED] Fix #1: 使用 now := time.Now() 进行原子性检查
func (s *memoryStore) startGCollector() { func (s *memoryStore) startGCollector() {
ticker := time.NewTicker(5 * time.Minute) ticker := time.NewTicker(5 * time.Minute)
defer ticker.Stop() defer ticker.Stop()
for range ticker.C { for range ticker.C {
s.mu.Lock() s.mu.Lock()
now := time.Now() // 避免在循环中重复调用 now := time.Now()
for key, item := range s.items { for key, item := range s.items {
if !item.expireAt.IsZero() && now.After(item.expireAt) { if !item.expireAt.IsZero() && now.After(item.expireAt) {
delete(s.items, key) delete(s.items, key)
@@ -66,92 +65,10 @@ func (s *memoryStore) startGCollector() {
} }
} }
// [USER SUGGESTION INCORPORATED] Fix #2 & #3: 修复了致命的nil检查和类型断言问题 // --- [架构修正] 所有方法签名都增加了 context.Context 参数以匹配接口 ---
func (s *memoryStore) PopAndCycleSetMember(mainKey, cooldownKey string) (string, error) { // --- 内存实现可以忽略该参数,用 _ 接收 ---
s.mu.Lock()
defer s.mu.Unlock()
mainItem, mainOk := s.items[mainKey] func (s *memoryStore) Set(_ context.Context, key string, value []byte, ttl time.Duration) error {
var mainSet map[string]struct{}
if mainOk && !mainItem.isExpired() {
// 安全地进行类型断言
mainSet, mainOk = mainItem.value.(map[string]struct{})
// 确保断言成功且集合不为空
mainOk = mainOk && len(mainSet) > 0
} else {
mainOk = false
}
if !mainOk {
cooldownItem, cooldownOk := s.items[cooldownKey]
if !cooldownOk || cooldownItem.isExpired() {
return "", ErrNotFound
}
// 安全地进行类型断言
cooldownSet, cooldownSetOk := cooldownItem.value.(map[string]struct{})
if !cooldownSetOk || len(cooldownSet) == 0 {
return "", ErrNotFound
}
s.items[mainKey] = cooldownItem
delete(s.items, cooldownKey)
mainSet = cooldownSet
}
var popped string
for k := range mainSet {
popped = k
break
}
delete(mainSet, popped)
cooldownItem, cooldownOk := s.items[cooldownKey]
if !cooldownOk || cooldownItem.isExpired() {
cooldownItem = &memoryStoreItem{value: make(map[string]struct{})}
s.items[cooldownKey] = cooldownItem
}
// 安全地处理冷却池
cooldownSet, ok := cooldownItem.value.(map[string]struct{})
if !ok {
cooldownSet = make(map[string]struct{})
cooldownItem.value = cooldownSet
}
cooldownSet[popped] = struct{}{}
return popped, nil
}
// SRandMember [并发修复版] 使用带锁的rng
func (s *memoryStore) SRandMember(key string) (string, error) {
s.mu.RLock()
defer s.mu.RUnlock()
item, ok := s.items[key]
if !ok || item.isExpired() {
return "", ErrNotFound
}
set, ok := item.value.(map[string]struct{})
if !ok || len(set) == 0 {
return "", ErrNotFound
}
members := make([]string, 0, len(set))
for member := range set {
members = append(members, member)
}
if len(members) == 0 {
return "", ErrNotFound
}
s.rngMu.Lock()
n := s.rng.Intn(len(members))
s.rngMu.Unlock()
return members[n], nil
}
// --- 以下是其余函数的最终版本,它们都遵循了安全、原子的锁策略 ---
func (s *memoryStore) Set(key string, value []byte, ttl time.Duration) error {
s.mu.Lock() s.mu.Lock()
defer s.mu.Unlock() defer s.mu.Unlock()
var expireAt time.Time var expireAt time.Time
@@ -162,7 +79,7 @@ func (s *memoryStore) Set(key string, value []byte, ttl time.Duration) error {
return nil return nil
} }
func (s *memoryStore) Get(key string) ([]byte, error) { func (s *memoryStore) Get(_ context.Context, key string) ([]byte, error) {
s.mu.RLock() s.mu.RLock()
defer s.mu.RUnlock() defer s.mu.RUnlock()
item, ok := s.items[key] item, ok := s.items[key]
@@ -175,7 +92,7 @@ func (s *memoryStore) Get(key string) ([]byte, error) {
return nil, ErrNotFound return nil, ErrNotFound
} }
func (s *memoryStore) Del(keys ...string) error { func (s *memoryStore) Del(_ context.Context, keys ...string) error {
s.mu.Lock() s.mu.Lock()
defer s.mu.Unlock() defer s.mu.Unlock()
for _, key := range keys { for _, key := range keys {
@@ -184,14 +101,14 @@ func (s *memoryStore) Del(keys ...string) error {
return nil return nil
} }
func (s *memoryStore) Exists(key string) (bool, error) { func (s *memoryStore) Exists(_ context.Context, key string) (bool, error) {
s.mu.RLock() s.mu.RLock()
defer s.mu.RUnlock() defer s.mu.RUnlock()
item, ok := s.items[key] item, ok := s.items[key]
return ok && !item.isExpired(), nil return ok && !item.isExpired(), nil
} }
func (s *memoryStore) SetNX(key string, value []byte, ttl time.Duration) (bool, error) { func (s *memoryStore) SetNX(_ context.Context, key string, value []byte, ttl time.Duration) (bool, error) {
s.mu.Lock() s.mu.Lock()
defer s.mu.Unlock() defer s.mu.Unlock()
item, ok := s.items[key] item, ok := s.items[key]
@@ -208,7 +125,7 @@ func (s *memoryStore) SetNX(key string, value []byte, ttl time.Duration) (bool,
func (s *memoryStore) Close() error { return nil } func (s *memoryStore) Close() error { return nil }
func (s *memoryStore) HDel(key string, fields ...string) error { func (s *memoryStore) HDel(_ context.Context, key string, fields ...string) error {
s.mu.Lock() s.mu.Lock()
defer s.mu.Unlock() defer s.mu.Unlock()
item, ok := s.items[key] item, ok := s.items[key]
@@ -223,7 +140,7 @@ func (s *memoryStore) HDel(key string, fields ...string) error {
return nil return nil
} }
func (s *memoryStore) HSet(key string, values map[string]any) error { func (s *memoryStore) HSet(_ context.Context, key string, values map[string]any) error {
s.mu.Lock() s.mu.Lock()
defer s.mu.Unlock() defer s.mu.Unlock()
item, ok := s.items[key] item, ok := s.items[key]
@@ -242,7 +159,7 @@ func (s *memoryStore) HSet(key string, values map[string]any) error {
return nil return nil
} }
func (s *memoryStore) HGetAll(key string) (map[string]string, error) { func (s *memoryStore) HGetAll(_ context.Context, key string) (map[string]string, error) {
s.mu.RLock() s.mu.RLock()
defer s.mu.RUnlock() defer s.mu.RUnlock()
item, ok := s.items[key] item, ok := s.items[key]
@@ -259,7 +176,7 @@ func (s *memoryStore) HGetAll(key string) (map[string]string, error) {
return make(map[string]string), nil return make(map[string]string), nil
} }
func (s *memoryStore) HIncrBy(key, field string, incr int64) (int64, error) { func (s *memoryStore) HIncrBy(_ context.Context, key, field string, incr int64) (int64, error) {
s.mu.Lock() s.mu.Lock()
defer s.mu.Unlock() defer s.mu.Unlock()
item, ok := s.items[key] item, ok := s.items[key]
@@ -281,7 +198,7 @@ func (s *memoryStore) HIncrBy(key, field string, incr int64) (int64, error) {
return newVal, nil return newVal, nil
} }
func (s *memoryStore) LPush(key string, values ...any) error { func (s *memoryStore) LPush(_ context.Context, key string, values ...any) error {
s.mu.Lock() s.mu.Lock()
defer s.mu.Unlock() defer s.mu.Unlock()
item, ok := s.items[key] item, ok := s.items[key]
@@ -301,7 +218,7 @@ func (s *memoryStore) LPush(key string, values ...any) error {
return nil return nil
} }
func (s *memoryStore) LRem(key string, count int64, value any) error { func (s *memoryStore) LRem(_ context.Context, key string, count int64, value any) error {
s.mu.Lock() s.mu.Lock()
defer s.mu.Unlock() defer s.mu.Unlock()
item, ok := s.items[key] item, ok := s.items[key]
@@ -326,7 +243,7 @@ func (s *memoryStore) LRem(key string, count int64, value any) error {
return nil return nil
} }
func (s *memoryStore) SAdd(key string, members ...any) error { func (s *memoryStore) SAdd(_ context.Context, key string, members ...any) error {
s.mu.Lock() s.mu.Lock()
defer s.mu.Unlock() defer s.mu.Unlock()
item, ok := s.items[key] item, ok := s.items[key]
@@ -345,7 +262,7 @@ func (s *memoryStore) SAdd(key string, members ...any) error {
return nil return nil
} }
func (s *memoryStore) SPopN(key string, count int64) ([]string, error) { func (s *memoryStore) SPopN(_ context.Context, key string, count int64) ([]string, error) {
s.mu.Lock() s.mu.Lock()
defer s.mu.Unlock() defer s.mu.Unlock()
item, ok := s.items[key] item, ok := s.items[key]
@@ -375,7 +292,7 @@ func (s *memoryStore) SPopN(key string, count int64) ([]string, error) {
return popped, nil return popped, nil
} }
func (s *memoryStore) SMembers(key string) ([]string, error) { func (s *memoryStore) SMembers(_ context.Context, key string) ([]string, error) {
s.mu.RLock() s.mu.RLock()
defer s.mu.RUnlock() defer s.mu.RUnlock()
item, ok := s.items[key] item, ok := s.items[key]
@@ -393,7 +310,7 @@ func (s *memoryStore) SMembers(key string) ([]string, error) {
return members, nil return members, nil
} }
func (s *memoryStore) SRem(key string, members ...any) error { func (s *memoryStore) SRem(_ context.Context, key string, members ...any) error {
s.mu.Lock() s.mu.Lock()
defer s.mu.Unlock() defer s.mu.Unlock()
item, ok := s.items[key] item, ok := s.items[key]
@@ -410,7 +327,31 @@ func (s *memoryStore) SRem(key string, members ...any) error {
return nil return nil
} }
func (s *memoryStore) Rotate(key string) (string, error) { func (s *memoryStore) SRandMember(_ context.Context, key string) (string, error) {
s.mu.RLock()
defer s.mu.RUnlock()
item, ok := s.items[key]
if !ok || item.isExpired() {
return "", ErrNotFound
}
set, ok := item.value.(map[string]struct{})
if !ok || len(set) == 0 {
return "", ErrNotFound
}
members := make([]string, 0, len(set))
for member := range set {
members = append(members, member)
}
if len(members) == 0 {
return "", ErrNotFound
}
s.rngMu.Lock()
n := s.rng.Intn(len(members))
s.rngMu.Unlock()
return members[n], nil
}
func (s *memoryStore) Rotate(_ context.Context, key string) (string, error) {
s.mu.Lock() s.mu.Lock()
defer s.mu.Unlock() defer s.mu.Unlock()
item, ok := s.items[key] item, ok := s.items[key]
@@ -426,7 +367,7 @@ func (s *memoryStore) Rotate(key string) (string, error) {
return val, nil return val, nil
} }
func (s *memoryStore) LIndex(key string, index int64) (string, error) { func (s *memoryStore) LIndex(_ context.Context, key string, index int64) (string, error) {
s.mu.RLock() s.mu.RLock()
defer s.mu.RUnlock() defer s.mu.RUnlock()
item, ok := s.items[key] item, ok := s.items[key]
@@ -447,8 +388,7 @@ func (s *memoryStore) LIndex(key string, index int64) (string, error) {
return list[index], nil return list[index], nil
} }
// Zset methods... (ZAdd, ZRange, ZRem) func (s *memoryStore) ZAdd(_ context.Context, key string, members map[string]float64) error {
func (s *memoryStore) ZAdd(key string, members map[string]float64) error {
s.mu.Lock() s.mu.Lock()
defer s.mu.Unlock() defer s.mu.Unlock()
item, ok := s.items[key] item, ok := s.items[key]
@@ -471,8 +411,6 @@ func (s *memoryStore) ZAdd(key string, members map[string]float64) error {
for val, score := range membersMap { for val, score := range membersMap {
newZSet = append(newZSet, zsetMember{Value: val, Score: score}) newZSet = append(newZSet, zsetMember{Value: val, Score: score})
} }
// NOTE: This ZSet implementation is simple but not performant for large sets.
// A production implementation would use a skip list or a balanced tree.
sort.Slice(newZSet, func(i, j int) bool { sort.Slice(newZSet, func(i, j int) bool {
if newZSet[i].Score == newZSet[j].Score { if newZSet[i].Score == newZSet[j].Score {
return newZSet[i].Value < newZSet[j].Value return newZSet[i].Value < newZSet[j].Value
@@ -482,7 +420,7 @@ func (s *memoryStore) ZAdd(key string, members map[string]float64) error {
item.value = newZSet item.value = newZSet
return nil return nil
} }
func (s *memoryStore) ZRange(key string, start, stop int64) ([]string, error) { func (s *memoryStore) ZRange(_ context.Context, key string, start, stop int64) ([]string, error) {
s.mu.RLock() s.mu.RLock()
defer s.mu.RUnlock() defer s.mu.RUnlock()
item, ok := s.items[key] item, ok := s.items[key]
@@ -515,7 +453,7 @@ func (s *memoryStore) ZRange(key string, start, stop int64) ([]string, error) {
} }
return result, nil return result, nil
} }
func (s *memoryStore) ZRem(key string, members ...any) error { func (s *memoryStore) ZRem(_ context.Context, key string, members ...any) error {
s.mu.Lock() s.mu.Lock()
defer s.mu.Unlock() defer s.mu.Unlock()
item, ok := s.items[key] item, ok := s.items[key]
@@ -540,13 +478,56 @@ func (s *memoryStore) ZRem(key string, members ...any) error {
return nil return nil
} }
// Pipeline implementation func (s *memoryStore) PopAndCycleSetMember(_ context.Context, mainKey, cooldownKey string) (string, error) {
s.mu.Lock()
defer s.mu.Unlock()
mainItem, mainOk := s.items[mainKey]
var mainSet map[string]struct{}
if mainOk && !mainItem.isExpired() {
mainSet, mainOk = mainItem.value.(map[string]struct{})
mainOk = mainOk && len(mainSet) > 0
} else {
mainOk = false
}
if !mainOk {
cooldownItem, cooldownOk := s.items[cooldownKey]
if !cooldownOk || cooldownItem.isExpired() {
return "", ErrNotFound
}
cooldownSet, cooldownSetOk := cooldownItem.value.(map[string]struct{})
if !cooldownSetOk || len(cooldownSet) == 0 {
return "", ErrNotFound
}
s.items[mainKey] = cooldownItem
delete(s.items, cooldownKey)
mainSet = cooldownSet
}
var popped string
for k := range mainSet {
popped = k
break
}
delete(mainSet, popped)
cooldownItem, cooldownOk := s.items[cooldownKey]
if !cooldownOk || cooldownItem.isExpired() {
cooldownItem = &memoryStoreItem{value: make(map[string]struct{})}
s.items[cooldownKey] = cooldownItem
}
cooldownSet, ok := cooldownItem.value.(map[string]struct{})
if !ok {
cooldownSet = make(map[string]struct{})
cooldownItem.value = cooldownSet
}
cooldownSet[popped] = struct{}{}
return popped, nil
}
type memoryPipeliner struct { type memoryPipeliner struct {
store *memoryStore store *memoryStore
ops []func() ops []func()
} }
func (s *memoryStore) Pipeline() Pipeliner { func (s *memoryStore) Pipeline(_ context.Context) Pipeliner {
return &memoryPipeliner{store: s} return &memoryPipeliner{store: s}
} }
func (p *memoryPipeliner) Exec() error { func (p *memoryPipeliner) Exec() error {
@@ -559,7 +540,6 @@ func (p *memoryPipeliner) Exec() error {
} }
func (p *memoryPipeliner) Expire(key string, expiration time.Duration) { func (p *memoryPipeliner) Expire(key string, expiration time.Duration) {
// [USER SUGGESTION APPLIED] Fix #4: Capture value, not reference
capturedKey := key capturedKey := key
p.ops = append(p.ops, func() { p.ops = append(p.ops, func() {
if item, ok := p.store.items[capturedKey]; ok { if item, ok := p.store.items[capturedKey]; ok {
@@ -596,7 +576,6 @@ func (p *memoryPipeliner) SAdd(key string, members ...any) {
} }
}) })
} }
func (p *memoryPipeliner) SRem(key string, members ...any) { func (p *memoryPipeliner) SRem(key string, members ...any) {
capturedKey := key capturedKey := key
capturedMembers := make([]any, len(members)) capturedMembers := make([]any, len(members))
@@ -615,7 +594,6 @@ func (p *memoryPipeliner) SRem(key string, members ...any) {
} }
}) })
} }
func (p *memoryPipeliner) LPush(key string, values ...any) { func (p *memoryPipeliner) LPush(key string, values ...any) {
capturedKey := key capturedKey := key
capturedValues := make([]any, len(values)) capturedValues := make([]any, len(values))
@@ -637,11 +615,12 @@ func (p *memoryPipeliner) LPush(key string, values ...any) {
item.value = append(stringValues, list...) item.value = append(stringValues, list...)
}) })
} }
func (p *memoryPipeliner) LRem(key string, count int64, value any) {} func (p *memoryPipeliner) LRem(key string, count int64, value any) {}
func (p *memoryPipeliner) HSet(key string, values map[string]any) {} func (p *memoryPipeliner) HSet(key string, values map[string]any) {}
func (p *memoryPipeliner) HIncrBy(key, field string, incr int64) {} func (p *memoryPipeliner) HIncrBy(key, field string, incr int64) {}
func (p *memoryPipeliner) ZAdd(key string, members map[string]float64) {}
func (p *memoryPipeliner) ZRem(key string, members ...any) {}
// --- Pub/Sub implementation (remains unchanged) ---
type memorySubscription struct { type memorySubscription struct {
store *memoryStore store *memoryStore
channelName string channelName string
@@ -649,10 +628,11 @@ type memorySubscription struct {
} }
func (ms *memorySubscription) Channel() <-chan *Message { return ms.msgChan } func (ms *memorySubscription) Channel() <-chan *Message { return ms.msgChan }
func (ms *memorySubscription) ChannelName() string { return ms.channelName }
func (ms *memorySubscription) Close() error { func (ms *memorySubscription) Close() error {
return ms.store.removeSubscriber(ms.channelName, ms.msgChan) return ms.store.removeSubscriber(ms.channelName, ms.msgChan)
} }
func (s *memoryStore) Publish(channel string, message []byte) error { func (s *memoryStore) Publish(_ context.Context, channel string, message []byte) error {
s.mu.RLock() s.mu.RLock()
defer s.mu.RUnlock() defer s.mu.RUnlock()
subscribers, ok := s.pubsub[channel] subscribers, ok := s.pubsub[channel]
@@ -669,7 +649,7 @@ func (s *memoryStore) Publish(channel string, message []byte) error {
} }
return nil return nil
} }
func (s *memoryStore) Subscribe(channel string) (Subscription, error) { func (s *memoryStore) Subscribe(_ context.Context, channel string) (Subscription, error) {
s.mu.Lock() s.mu.Lock()
defer s.mu.Unlock() defer s.mu.Unlock()
msgChan := make(chan *Message, 10) msgChan := make(chan *Message, 10)

View File

@@ -1,3 +1,5 @@
// Filename: internal/store/redis_store.go
package store package store
import ( import (
@@ -8,22 +10,20 @@ import (
"time" "time"
"github.com/redis/go-redis/v9" "github.com/redis/go-redis/v9"
"github.com/sirupsen/logrus"
) )
// ensure RedisStore implements Store interface // ensure RedisStore implements Store interface
var _ Store = (*RedisStore)(nil) var _ Store = (*RedisStore)(nil)
// RedisStore is a Redis-backed key-value store.
type RedisStore struct { type RedisStore struct {
client *redis.Client client *redis.Client
popAndCycleScript *redis.Script popAndCycleScript *redis.Script
logger *logrus.Entry
} }
// NewRedisStore creates a new RedisStore instance. // NewRedisStore creates a new RedisStore instance.
func NewRedisStore(client *redis.Client) Store { func NewRedisStore(client *redis.Client, logger *logrus.Logger) Store {
// Lua script for atomic pop-and-cycle operation.
// KEYS[1]: main set key
// KEYS[2]: cooldown set key
const script = ` const script = `
if redis.call('SCARD', KEYS[1]) == 0 then if redis.call('SCARD', KEYS[1]) == 0 then
if redis.call('SCARD', KEYS[2]) == 0 then if redis.call('SCARD', KEYS[2]) == 0 then
@@ -36,15 +36,16 @@ func NewRedisStore(client *redis.Client) Store {
return &RedisStore{ return &RedisStore{
client: client, client: client,
popAndCycleScript: redis.NewScript(script), popAndCycleScript: redis.NewScript(script),
logger: logger.WithField("component", "store.redis 🗄️"),
} }
} }
func (s *RedisStore) Set(key string, value []byte, ttl time.Duration) error { func (s *RedisStore) Set(ctx context.Context, key string, value []byte, ttl time.Duration) error {
return s.client.Set(context.Background(), key, value, ttl).Err() return s.client.Set(ctx, key, value, ttl).Err()
} }
func (s *RedisStore) Get(key string) ([]byte, error) { func (s *RedisStore) Get(ctx context.Context, key string) ([]byte, error) {
val, err := s.client.Get(context.Background(), key).Bytes() val, err := s.client.Get(ctx, key).Bytes()
if err != nil { if err != nil {
if errors.Is(err, redis.Nil) { if errors.Is(err, redis.Nil) {
return nil, ErrNotFound return nil, ErrNotFound
@@ -54,53 +55,53 @@ func (s *RedisStore) Get(key string) ([]byte, error) {
return val, nil return val, nil
} }
func (s *RedisStore) Del(keys ...string) error { func (s *RedisStore) Del(ctx context.Context, keys ...string) error {
if len(keys) == 0 { if len(keys) == 0 {
return nil return nil
} }
return s.client.Del(context.Background(), keys...).Err() return s.client.Del(ctx, keys...).Err()
} }
func (s *RedisStore) Exists(key string) (bool, error) { func (s *RedisStore) Exists(ctx context.Context, key string) (bool, error) {
val, err := s.client.Exists(context.Background(), key).Result() val, err := s.client.Exists(ctx, key).Result()
return val > 0, err return val > 0, err
} }
func (s *RedisStore) SetNX(key string, value []byte, ttl time.Duration) (bool, error) { func (s *RedisStore) SetNX(ctx context.Context, key string, value []byte, ttl time.Duration) (bool, error) {
return s.client.SetNX(context.Background(), key, value, ttl).Result() return s.client.SetNX(ctx, key, value, ttl).Result()
} }
func (s *RedisStore) Close() error { func (s *RedisStore) Close() error {
return s.client.Close() return s.client.Close()
} }
func (s *RedisStore) HSet(key string, values map[string]any) error { func (s *RedisStore) HSet(ctx context.Context, key string, values map[string]any) error {
return s.client.HSet(context.Background(), key, values).Err() return s.client.HSet(ctx, key, values).Err()
} }
func (s *RedisStore) HGetAll(key string) (map[string]string, error) { func (s *RedisStore) HGetAll(ctx context.Context, key string) (map[string]string, error) {
return s.client.HGetAll(context.Background(), key).Result() return s.client.HGetAll(ctx, key).Result()
} }
func (s *RedisStore) HIncrBy(key, field string, incr int64) (int64, error) { func (s *RedisStore) HIncrBy(ctx context.Context, key, field string, incr int64) (int64, error) {
return s.client.HIncrBy(context.Background(), key, field, incr).Result() return s.client.HIncrBy(ctx, key, field, incr).Result()
} }
func (s *RedisStore) HDel(key string, fields ...string) error { func (s *RedisStore) HDel(ctx context.Context, key string, fields ...string) error {
if len(fields) == 0 { if len(fields) == 0 {
return nil return nil
} }
return s.client.HDel(context.Background(), key, fields...).Err() return s.client.HDel(ctx, key, fields...).Err()
} }
func (s *RedisStore) LPush(key string, values ...any) error { func (s *RedisStore) LPush(ctx context.Context, key string, values ...any) error {
return s.client.LPush(context.Background(), key, values...).Err() return s.client.LPush(ctx, key, values...).Err()
} }
func (s *RedisStore) LRem(key string, count int64, value any) error { func (s *RedisStore) LRem(ctx context.Context, key string, count int64, value any) error {
return s.client.LRem(context.Background(), key, count, value).Err() return s.client.LRem(ctx, key, count, value).Err()
} }
func (s *RedisStore) Rotate(key string) (string, error) { func (s *RedisStore) Rotate(ctx context.Context, key string) (string, error) {
val, err := s.client.RPopLPush(context.Background(), key, key).Result() val, err := s.client.RPopLPush(ctx, key, key).Result()
if err != nil { if err != nil {
if errors.Is(err, redis.Nil) { if errors.Is(err, redis.Nil) {
return "", ErrNotFound return "", ErrNotFound
@@ -110,29 +111,28 @@ func (s *RedisStore) Rotate(key string) (string, error) {
return val, nil return val, nil
} }
func (s *RedisStore) SAdd(key string, members ...any) error { func (s *RedisStore) SAdd(ctx context.Context, key string, members ...any) error {
return s.client.SAdd(context.Background(), key, members...).Err() return s.client.SAdd(ctx, key, members...).Err()
} }
func (s *RedisStore) SPopN(key string, count int64) ([]string, error) { func (s *RedisStore) SPopN(ctx context.Context, key string, count int64) ([]string, error) {
return s.client.SPopN(context.Background(), key, count).Result() return s.client.SPopN(ctx, key, count).Result()
} }
func (s *RedisStore) SMembers(key string) ([]string, error) { func (s *RedisStore) SMembers(ctx context.Context, key string) ([]string, error) {
return s.client.SMembers(context.Background(), key).Result() return s.client.SMembers(ctx, key).Result()
} }
func (s *RedisStore) SRem(key string, members ...any) error { func (s *RedisStore) SRem(ctx context.Context, key string, members ...any) error {
if len(members) == 0 { if len(members) == 0 {
return nil return nil
} }
return s.client.SRem(context.Background(), key, members...).Err() return s.client.SRem(ctx, key, members...).Err()
} }
func (s *RedisStore) SRandMember(key string) (string, error) { func (s *RedisStore) SRandMember(ctx context.Context, key string) (string, error) {
member, err := s.client.SRandMember(context.Background(), key).Result() member, err := s.client.SRandMember(ctx, key).Result()
if err != nil { if err != nil {
if errors.Is(err, redis.Nil) { if errors.Is(err, redis.Nil) {
return "", ErrNotFound return "", ErrNotFound
} }
@@ -141,81 +141,43 @@ func (s *RedisStore) SRandMember(key string) (string, error) {
return member, nil return member, nil
} }
// === 新增方法实现 === func (s *RedisStore) ZAdd(ctx context.Context, key string, members map[string]float64) error {
func (s *RedisStore) ZAdd(key string, members map[string]float64) error {
if len(members) == 0 { if len(members) == 0 {
return nil return nil
} }
redisMembers := make([]redis.Z, 0, len(members)) redisMembers := make([]redis.Z, len(members))
i := 0
for member, score := range members { for member, score := range members {
redisMembers = append(redisMembers, redis.Z{Score: score, Member: member}) redisMembers[i] = redis.Z{Score: score, Member: member}
i++
} }
return s.client.ZAdd(context.Background(), key, redisMembers...).Err() return s.client.ZAdd(ctx, key, redisMembers...).Err()
} }
func (s *RedisStore) ZRange(key string, start, stop int64) ([]string, error) { func (s *RedisStore) ZRange(ctx context.Context, key string, start, stop int64) ([]string, error) {
return s.client.ZRange(context.Background(), key, start, stop).Result() return s.client.ZRange(ctx, key, start, stop).Result()
} }
func (s *RedisStore) ZRem(key string, members ...any) error { func (s *RedisStore) ZRem(ctx context.Context, key string, members ...any) error {
if len(members) == 0 { if len(members) == 0 {
return nil return nil
} }
return s.client.ZRem(context.Background(), key, members...).Err() return s.client.ZRem(ctx, key, members...).Err()
} }
func (s *RedisStore) PopAndCycleSetMember(mainKey, cooldownKey string) (string, error) { func (s *RedisStore) PopAndCycleSetMember(ctx context.Context, mainKey, cooldownKey string) (string, error) {
val, err := s.popAndCycleScript.Run(context.Background(), s.client, []string{mainKey, cooldownKey}).Result() val, err := s.popAndCycleScript.Run(ctx, s.client, []string{mainKey, cooldownKey}).Result()
if err != nil { if err != nil {
if errors.Is(err, redis.Nil) { if errors.Is(err, redis.Nil) {
return "", ErrNotFound return "", ErrNotFound
} }
return "", err return "", err
} }
// Lua script returns a string, so we need to type assert
if str, ok := val.(string); ok { if str, ok := val.(string); ok {
return str, nil return str, nil
} }
return "", ErrNotFound // This happens if both sets were empty and the script returned nil return "", ErrNotFound
} }
type redisPipeliner struct{ pipe redis.Pipeliner } func (s *RedisStore) LIndex(ctx context.Context, key string, index int64) (string, error) {
val, err := s.client.LIndex(ctx, key, index).Result()
func (p *redisPipeliner) HSet(key string, values map[string]any) {
p.pipe.HSet(context.Background(), key, values)
}
func (p *redisPipeliner) HIncrBy(key, field string, incr int64) {
p.pipe.HIncrBy(context.Background(), key, field, incr)
}
func (p *redisPipeliner) Exec() error {
_, err := p.pipe.Exec(context.Background())
return err
}
func (p *redisPipeliner) Del(keys ...string) {
if len(keys) > 0 {
p.pipe.Del(context.Background(), keys...)
}
}
func (p *redisPipeliner) SAdd(key string, members ...any) {
p.pipe.SAdd(context.Background(), key, members...)
}
func (p *redisPipeliner) SRem(key string, members ...any) {
if len(members) > 0 {
p.pipe.SRem(context.Background(), key, members...)
}
}
func (p *redisPipeliner) LPush(key string, values ...any) {
p.pipe.LPush(context.Background(), key, values...)
}
func (p *redisPipeliner) LRem(key string, count int64, value any) {
p.pipe.LRem(context.Background(), key, count, value)
}
func (s *RedisStore) LIndex(key string, index int64) (string, error) {
val, err := s.client.LIndex(context.Background(), key, index).Result()
if err != nil { if err != nil {
if errors.Is(err, redis.Nil) { if errors.Is(err, redis.Nil) {
return "", ErrNotFound return "", ErrNotFound
@@ -225,47 +187,120 @@ func (s *RedisStore) LIndex(key string, index int64) (string, error) {
return val, nil return val, nil
} }
func (p *redisPipeliner) Expire(key string, expiration time.Duration) { type redisPipeliner struct {
p.pipe.Expire(context.Background(), key, expiration) pipe redis.Pipeliner
ctx context.Context
} }
func (s *RedisStore) Pipeline() Pipeliner { func (s *RedisStore) Pipeline(ctx context.Context) Pipeliner {
return &redisPipeliner{pipe: s.client.Pipeline()} return &redisPipeliner{
pipe: s.client.Pipeline(),
ctx: ctx,
}
} }
func (p *redisPipeliner) Exec() error {
_, err := p.pipe.Exec(p.ctx)
return err
}
func (p *redisPipeliner) Del(keys ...string) { p.pipe.Del(p.ctx, keys...) }
func (p *redisPipeliner) Expire(key string, expiration time.Duration) {
p.pipe.Expire(p.ctx, key, expiration)
}
func (p *redisPipeliner) HSet(key string, values map[string]any) { p.pipe.HSet(p.ctx, key, values) }
func (p *redisPipeliner) HIncrBy(key, field string, incr int64) {
p.pipe.HIncrBy(p.ctx, key, field, incr)
}
func (p *redisPipeliner) LPush(key string, values ...any) { p.pipe.LPush(p.ctx, key, values...) }
func (p *redisPipeliner) LRem(key string, count int64, value any) {
p.pipe.LRem(p.ctx, key, count, value)
}
func (p *redisPipeliner) SAdd(key string, members ...any) { p.pipe.SAdd(p.ctx, key, members...) }
func (p *redisPipeliner) SRem(key string, members ...any) { p.pipe.SRem(p.ctx, key, members...) }
func (p *redisPipeliner) ZAdd(key string, members map[string]float64) {
if len(members) == 0 {
return
}
redisMembers := make([]redis.Z, len(members))
i := 0
for member, score := range members {
redisMembers[i] = redis.Z{Score: score, Member: member}
i++
}
p.pipe.ZAdd(p.ctx, key, redisMembers...)
}
func (p *redisPipeliner) ZRem(key string, members ...any) { p.pipe.ZRem(p.ctx, key, members...) }
type redisSubscription struct { type redisSubscription struct {
pubsub *redis.PubSub pubsub *redis.PubSub
msgChan chan *Message msgChan chan *Message
once sync.Once logger *logrus.Entry
wg sync.WaitGroup
close context.CancelFunc
channelName string
}
func (s *RedisStore) Subscribe(ctx context.Context, channel string) (Subscription, error) {
pubsub := s.client.Subscribe(ctx, channel)
_, err := pubsub.Receive(ctx)
if err != nil {
_ = pubsub.Close()
return nil, fmt.Errorf("failed to subscribe to channel %s: %w", channel, err)
}
subCtx, cancel := context.WithCancel(context.Background())
sub := &redisSubscription{
pubsub: pubsub,
msgChan: make(chan *Message, 10),
logger: s.logger,
close: cancel,
channelName: channel,
}
sub.wg.Add(1)
go sub.bridge(subCtx)
return sub, nil
}
func (rs *redisSubscription) bridge(ctx context.Context) {
defer rs.wg.Done()
defer close(rs.msgChan)
redisCh := rs.pubsub.Channel()
for {
select {
case <-ctx.Done():
return
case redisMsg, ok := <-redisCh:
if !ok {
return
}
msg := &Message{
Channel: redisMsg.Channel,
Payload: []byte(redisMsg.Payload),
}
select {
case rs.msgChan <- msg:
default:
rs.logger.Warnf("Message dropped for channel '%s' due to slow consumer.", rs.channelName)
}
}
}
} }
func (rs *redisSubscription) Channel() <-chan *Message { func (rs *redisSubscription) Channel() <-chan *Message {
rs.once.Do(func() {
rs.msgChan = make(chan *Message)
go func() {
defer close(rs.msgChan)
for redisMsg := range rs.pubsub.Channel() {
rs.msgChan <- &Message{
Channel: redisMsg.Channel,
Payload: []byte(redisMsg.Payload),
}
}
}()
})
return rs.msgChan return rs.msgChan
} }
func (rs *redisSubscription) Close() error { return rs.pubsub.Close() } func (rs *redisSubscription) ChannelName() string {
return rs.channelName
func (s *RedisStore) Publish(channel string, message []byte) error {
return s.client.Publish(context.Background(), channel, message).Err()
} }
func (s *RedisStore) Subscribe(channel string) (Subscription, error) { func (rs *redisSubscription) Close() error {
pubsub := s.client.Subscribe(context.Background(), channel) rs.close()
_, err := pubsub.Receive(context.Background()) err := rs.pubsub.Close()
if err != nil { rs.wg.Wait()
return nil, fmt.Errorf("failed to subscribe to channel %s: %w", channel, err) return err
} }
return &redisSubscription{pubsub: pubsub}, nil
func (s *RedisStore) Publish(ctx context.Context, channel string, message []byte) error {
return s.client.Publish(ctx, channel, message).Err()
} }

View File

@@ -1,6 +1,9 @@
// Filename: internal/store/store.go
package store package store
import ( import (
"context"
"errors" "errors"
"time" "time"
) )
@@ -17,6 +20,7 @@ type Message struct {
// Subscription represents an active subscription to a pub/sub channel. // Subscription represents an active subscription to a pub/sub channel.
type Subscription interface { type Subscription interface {
Channel() <-chan *Message Channel() <-chan *Message
ChannelName() string
Close() error Close() error
} }
@@ -38,6 +42,10 @@ type Pipeliner interface {
LPush(key string, values ...any) LPush(key string, values ...any)
LRem(key string, count int64, value any) LRem(key string, count int64, value any)
// ZSET
ZAdd(key string, members map[string]float64)
ZRem(key string, members ...any)
// Execution // Execution
Exec() error Exec() error
} }
@@ -45,44 +53,44 @@ type Pipeliner interface {
// Store is the master interface for our cache service. // Store is the master interface for our cache service.
type Store interface { type Store interface {
// Basic K/V operations // Basic K/V operations
Set(key string, value []byte, ttl time.Duration) error Set(ctx context.Context, key string, value []byte, ttl time.Duration) error
Get(key string) ([]byte, error) Get(ctx context.Context, key string) ([]byte, error)
Del(keys ...string) error Del(ctx context.Context, keys ...string) error
Exists(key string) (bool, error) Exists(ctx context.Context, key string) (bool, error)
SetNX(key string, value []byte, ttl time.Duration) (bool, error) SetNX(ctx context.Context, key string, value []byte, ttl time.Duration) (bool, error)
// HASH operations // HASH operations
HSet(key string, values map[string]any) error HSet(ctx context.Context, key string, values map[string]any) error
HGetAll(key string) (map[string]string, error) HGetAll(ctx context.Context, key string) (map[string]string, error)
HIncrBy(key, field string, incr int64) (int64, error) HIncrBy(ctx context.Context, key, field string, incr int64) (int64, error)
HDel(key string, fields ...string) error // [新增] HDel(ctx context.Context, key string, fields ...string) error
// LIST operations // LIST operations
LPush(key string, values ...any) error LPush(ctx context.Context, key string, values ...any) error
LRem(key string, count int64, value any) error LRem(ctx context.Context, key string, count int64, value any) error
Rotate(key string) (string, error) Rotate(ctx context.Context, key string) (string, error)
LIndex(key string, index int64) (string, error) LIndex(ctx context.Context, key string, index int64) (string, error)
// SET operations // SET operations
SAdd(key string, members ...any) error SAdd(ctx context.Context, key string, members ...any) error
SPopN(key string, count int64) ([]string, error) SPopN(ctx context.Context, key string, count int64) ([]string, error)
SMembers(key string) ([]string, error) SMembers(ctx context.Context, key string) ([]string, error)
SRem(key string, members ...any) error SRem(ctx context.Context, key string, members ...any) error
SRandMember(key string) (string, error) SRandMember(ctx context.Context, key string) (string, error)
// Pub/Sub operations // Pub/Sub operations
Publish(channel string, message []byte) error Publish(ctx context.Context, channel string, message []byte) error
Subscribe(channel string) (Subscription, error) Subscribe(ctx context.Context, channel string) (Subscription, error)
// Pipeline (optional) - 我们在redis实现它内存版暂时不实现 // Pipeline
Pipeline() Pipeliner Pipeline(ctx context.Context) Pipeliner
// Close closes the store and releases any underlying resources. // Close closes the store and releases any underlying resources.
Close() error Close() error
// === 新增方法,支持轮询策略 === // ZSET operations
ZAdd(key string, members map[string]float64) error ZAdd(ctx context.Context, key string, members map[string]float64) error
ZRange(key string, start, stop int64) ([]string, error) ZRange(ctx context.Context, key string, start, stop int64) ([]string, error)
ZRem(key string, members ...any) error ZRem(ctx context.Context, key string, members ...any) error
PopAndCycleSetMember(mainKey, cooldownKey string) (string, error) PopAndCycleSetMember(ctx context.Context, mainKey, cooldownKey string) (string, error)
} }

View File

@@ -1,6 +1,7 @@
package syncer package syncer
import ( import (
"context"
"fmt" "fmt"
"gemini-balancer/internal/store" "gemini-balancer/internal/store"
"log" "log"
@@ -51,7 +52,7 @@ func (s *CacheSyncer[T]) Get() T {
func (s *CacheSyncer[T]) Invalidate() error { func (s *CacheSyncer[T]) Invalidate() error {
log.Printf("INFO: Publishing invalidation notification on channel '%s'", s.channelName) log.Printf("INFO: Publishing invalidation notification on channel '%s'", s.channelName)
return s.store.Publish(s.channelName, []byte("reload")) return s.store.Publish(context.Background(), s.channelName, []byte("reload"))
} }
func (s *CacheSyncer[T]) Stop() { func (s *CacheSyncer[T]) Stop() {
@@ -84,7 +85,7 @@ func (s *CacheSyncer[T]) listenForUpdates() {
default: default:
} }
subscription, err := s.store.Subscribe(s.channelName) subscription, err := s.store.Subscribe(context.Background(), s.channelName)
if err != nil { if err != nil {
log.Printf("ERROR: Failed to subscribe to '%s', retrying in 5s: %v", s.channelName, err) log.Printf("ERROR: Failed to subscribe to '%s', retrying in 5s: %v", s.channelName, err)
time.Sleep(5 * time.Second) time.Sleep(5 * time.Second)

View File

@@ -1,7 +1,8 @@
// Filename: internal/task/task.go (最终校准版) // Filename: internal/task/task.go
package task package task
import ( import (
"context"
"encoding/json" "encoding/json"
"errors" "errors"
"fmt" "fmt"
@@ -15,15 +16,13 @@ const (
ResultTTL = 60 * time.Minute ResultTTL = 60 * time.Minute
) )
// Reporter 接口,定义了领域如何与任务服务交互。
type Reporter interface { type Reporter interface {
StartTask(keyGroupID uint, taskType, resourceID string, total int, timeout time.Duration) (*Status, error) StartTask(ctx context.Context, keyGroupID uint, taskType, resourceID string, total int, timeout time.Duration) (*Status, error)
EndTaskByID(taskID, resourceID string, result any, taskErr error) EndTaskByID(ctx context.Context, taskID, resourceID string, result any, taskErr error)
UpdateProgressByID(taskID string, processed int) error UpdateProgressByID(ctx context.Context, taskID string, processed int) error
UpdateTotalByID(taskID string, total int) error UpdateTotalByID(ctx context.Context, taskID string, total int) error
} }
// Status 代表一个后台任务的完整状态
type Status struct { type Status struct {
ID string `json:"id"` ID string `json:"id"`
TaskType string `json:"task_type"` TaskType string `json:"task_type"`
@@ -38,13 +37,11 @@ type Status struct {
DurationSeconds float64 `json:"duration_seconds,omitempty"` DurationSeconds float64 `json:"duration_seconds,omitempty"`
} }
// Task 是任务管理的核心服务
type Task struct { type Task struct {
store store.Store store store.Store
logger *logrus.Entry logger *logrus.Entry
} }
// NewTask 是 Task 的构造函数
func NewTask(store store.Store, logger *logrus.Logger) *Task { func NewTask(store store.Store, logger *logrus.Logger) *Task {
return &Task{ return &Task{
store: store, store: store,
@@ -62,15 +59,14 @@ func (s *Task) getTaskDataKey(taskID string) string {
return fmt.Sprintf("task:data:%s", taskID) return fmt.Sprintf("task:data:%s", taskID)
} }
// --- 新增的輔助函數,用於獲取原子標記的鍵 ---
func (s *Task) getIsRunningFlagKey(taskID string) string { func (s *Task) getIsRunningFlagKey(taskID string) string {
return fmt.Sprintf("task:running:%s", taskID) return fmt.Sprintf("task:running:%s", taskID)
} }
func (s *Task) StartTask(keyGroupID uint, taskType, resourceID string, total int, timeout time.Duration) (*Status, error) { func (s *Task) StartTask(ctx context.Context, keyGroupID uint, taskType, resourceID string, total int, timeout time.Duration) (*Status, error) {
lockKey := s.getResourceLockKey(resourceID) lockKey := s.getResourceLockKey(resourceID)
if existingTaskID, err := s.store.Get(lockKey); err == nil && len(existingTaskID) > 0 { if existingTaskID, err := s.store.Get(ctx, lockKey); err == nil && len(existingTaskID) > 0 {
return nil, fmt.Errorf("a task is already running for this resource (ID: %s)", string(existingTaskID)) return nil, fmt.Errorf("a task is already running for this resource (ID: %s)", string(existingTaskID))
} }
@@ -94,35 +90,34 @@ func (s *Task) StartTask(keyGroupID uint, taskType, resourceID string, total int
timeout = ResultTTL * 24 timeout = ResultTTL * 24
} }
if err := s.store.Set(lockKey, []byte(taskID), timeout); err != nil { if err := s.store.Set(ctx, lockKey, []byte(taskID), timeout); err != nil {
return nil, fmt.Errorf("failed to acquire task resource lock: %w", err) return nil, fmt.Errorf("failed to acquire task resource lock: %w", err)
} }
if err := s.store.Set(taskKey, statusBytes, timeout); err != nil { if err := s.store.Set(ctx, taskKey, statusBytes, timeout); err != nil {
_ = s.store.Del(lockKey) _ = s.store.Del(ctx, lockKey)
return nil, fmt.Errorf("failed to set new task data in store: %w", err) return nil, fmt.Errorf("failed to set new task data in store: %w", err)
} }
// 創建一個獨立的“運行中”標記,它的存在與否是原子性的 if err := s.store.Set(ctx, runningFlagKey, []byte("1"), timeout); err != nil {
if err := s.store.Set(runningFlagKey, []byte("1"), timeout); err != nil { _ = s.store.Del(ctx, lockKey)
_ = s.store.Del(lockKey) _ = s.store.Del(ctx, taskKey)
_ = s.store.Del(taskKey)
return nil, fmt.Errorf("failed to set task running flag: %w", err) return nil, fmt.Errorf("failed to set task running flag: %w", err)
} }
return status, nil return status, nil
} }
func (s *Task) EndTaskByID(taskID, resourceID string, resultData any, taskErr error) { func (s *Task) EndTaskByID(ctx context.Context, taskID, resourceID string, resultData any, taskErr error) {
lockKey := s.getResourceLockKey(resourceID) lockKey := s.getResourceLockKey(resourceID)
defer func() { defer func() {
if err := s.store.Del(lockKey); err != nil { if err := s.store.Del(ctx, lockKey); err != nil {
s.logger.WithError(err).Warnf("Failed to release resource lock '%s' for task %s.", lockKey, taskID) s.logger.WithError(err).Warnf("Failed to release resource lock '%s' for task %s.", lockKey, taskID)
} }
}() }()
runningFlagKey := s.getIsRunningFlagKey(taskID) runningFlagKey := s.getIsRunningFlagKey(taskID)
_ = s.store.Del(runningFlagKey) _ = s.store.Del(ctx, runningFlagKey)
status, err := s.GetStatus(taskID)
if err != nil {
status, err := s.GetStatus(ctx, taskID)
if err != nil {
s.logger.WithError(err).Errorf("Could not get task status for task ID %s during EndTask. Lock has been released, but task data may be stale.", taskID) s.logger.WithError(err).Errorf("Could not get task status for task ID %s during EndTask. Lock has been released, but task data may be stale.", taskID)
return return
} }
@@ -141,15 +136,14 @@ func (s *Task) EndTaskByID(taskID, resourceID string, resultData any, taskErr er
} }
updatedTaskBytes, _ := json.Marshal(status) updatedTaskBytes, _ := json.Marshal(status)
taskKey := s.getTaskDataKey(taskID) taskKey := s.getTaskDataKey(taskID)
if err := s.store.Set(taskKey, updatedTaskBytes, ResultTTL); err != nil { if err := s.store.Set(ctx, taskKey, updatedTaskBytes, ResultTTL); err != nil {
s.logger.WithError(err).Errorf("Failed to save final status for task %s.", taskID) s.logger.WithError(err).Errorf("Failed to save final status for task %s.", taskID)
} }
} }
// GetStatus 通过ID获取任务状态供外部如API Handler调用 func (s *Task) GetStatus(ctx context.Context, taskID string) (*Status, error) {
func (s *Task) GetStatus(taskID string) (*Status, error) {
taskKey := s.getTaskDataKey(taskID) taskKey := s.getTaskDataKey(taskID)
statusBytes, err := s.store.Get(taskKey) statusBytes, err := s.store.Get(ctx, taskKey)
if err != nil { if err != nil {
if errors.Is(err, store.ErrNotFound) { if errors.Is(err, store.ErrNotFound) {
return nil, errors.New("task not found") return nil, errors.New("task not found")
@@ -161,22 +155,18 @@ func (s *Task) GetStatus(taskID string) (*Status, error) {
if err := json.Unmarshal(statusBytes, &status); err != nil { if err := json.Unmarshal(statusBytes, &status); err != nil {
return nil, fmt.Errorf("corrupted task data in store for ID %s", taskID) return nil, fmt.Errorf("corrupted task data in store for ID %s", taskID)
} }
if !status.IsRunning && status.FinishedAt != nil { if !status.IsRunning && status.FinishedAt != nil {
status.DurationSeconds = status.FinishedAt.Sub(status.StartedAt).Seconds() status.DurationSeconds = status.FinishedAt.Sub(status.StartedAt).Seconds()
} }
return &status, nil return &status, nil
} }
// UpdateProgressByID 通过ID更新任务进度 func (s *Task) updateTask(ctx context.Context, taskID string, updater func(status *Status)) error {
func (s *Task) updateTask(taskID string, updater func(status *Status)) error {
runningFlagKey := s.getIsRunningFlagKey(taskID) runningFlagKey := s.getIsRunningFlagKey(taskID)
if _, err := s.store.Get(runningFlagKey); err != nil { if _, err := s.store.Get(ctx, runningFlagKey); err != nil {
// 任务已结束,静默返回是预期行为。
return nil return nil
} }
status, err := s.GetStatus(taskID) status, err := s.GetStatus(ctx, taskID)
if err != nil { if err != nil {
s.logger.WithError(err).Warnf("Failed to get task status for update on task %s. Update will not be saved.", taskID) s.logger.WithError(err).Warnf("Failed to get task status for update on task %s. Update will not be saved.", taskID)
return nil return nil
@@ -184,7 +174,6 @@ func (s *Task) updateTask(taskID string, updater func(status *Status)) error {
if !status.IsRunning { if !status.IsRunning {
return nil return nil
} }
// 调用传入的 updater 函数来修改 status
updater(status) updater(status)
statusBytes, marshalErr := json.Marshal(status) statusBytes, marshalErr := json.Marshal(status)
if marshalErr != nil { if marshalErr != nil {
@@ -192,23 +181,20 @@ func (s *Task) updateTask(taskID string, updater func(status *Status)) error {
return nil return nil
} }
taskKey := s.getTaskDataKey(taskID) taskKey := s.getTaskDataKey(taskID)
// 使用更长的TTL确保运行中的任务不会过早过期 if err := s.store.Set(ctx, taskKey, statusBytes, ResultTTL*24); err != nil {
if err := s.store.Set(taskKey, statusBytes, ResultTTL*24); err != nil {
s.logger.WithError(err).Warnf("Failed to save update for task %s.", taskID) s.logger.WithError(err).Warnf("Failed to save update for task %s.", taskID)
} }
return nil return nil
} }
// [REFACTORED] UpdateProgressByID 现在是一个简单的、调用通用更新器的包装器。 func (s *Task) UpdateProgressByID(ctx context.Context, taskID string, processed int) error {
func (s *Task) UpdateProgressByID(taskID string, processed int) error { return s.updateTask(ctx, taskID, func(status *Status) {
return s.updateTask(taskID, func(status *Status) {
status.Processed = processed status.Processed = processed
}) })
} }
// [REFACTORED] UpdateTotalByID 现在也是一个简单的、调用通用更新器的包装器。 func (s *Task) UpdateTotalByID(ctx context.Context, taskID string, total int) error {
func (s *Task) UpdateTotalByID(taskID string, total int) error { return s.updateTask(ctx, taskID, func(status *Status) {
return s.updateTask(taskID, func(status *Status) {
status.Total = total status.Total = total
}) })
} }