Update Context for store

XOF
2025-11-22 14:20:05 +08:00
parent ac0e0a8275
commit 2b0b9b67dc
31 changed files with 817 additions and 1016 deletions

View File

@@ -2,6 +2,7 @@
package proxy
import (
"context"
"encoding/json"
"gemini-balancer/internal/errors"
"gemini-balancer/internal/models"
@@ -49,7 +50,6 @@ func (h *handler) registerRoutes(rg *gin.RouterGroup) {
}
}
// --- Request DTOs ---
type CreateProxyConfigRequest struct {
Address string `json:"address" binding:"required"`
Protocol string `json:"protocol" binding:"required,oneof=http https socks5"`
@@ -64,12 +64,10 @@ type UpdateProxyConfigRequest struct {
Description *string `json:"description"`
}
// Request body for a single proxy check (aligned with the frontend JS)
type CheckSingleProxyRequest struct {
Proxy string `json:"proxy" binding:"required"`
}
// Request body for batch proxy checks
type CheckAllProxiesRequest struct {
Proxies []string `json:"proxies" binding:"required"`
}
@@ -84,7 +82,7 @@ func (h *handler) CreateProxyConfig(c *gin.Context) {
}
if req.Status == "" {
req.Status = "active" // 默认状态
req.Status = "active"
}
proxyConfig := models.ProxyConfig{
@@ -98,7 +96,6 @@ func (h *handler) CreateProxyConfig(c *gin.Context) {
response.Error(c, errors.ParseDBError(err))
return
}
// After a write operation, publish an event and invalidate the cache
h.publishAndInvalidate(proxyConfig.ID, "created")
response.Created(c, proxyConfig)
}
@@ -199,17 +196,16 @@ func (h *handler) DeleteProxyConfig(c *gin.Context) {
response.NoContent(c)
}
// publishAndInvalidate centralizes event publishing and cache invalidation
func (h *handler) publishAndInvalidate(proxyID uint, action string) {
go h.manager.invalidate()
go func() {
ctx := context.Background()
event := models.ProxyStatusChangedEvent{ProxyID: proxyID, Action: action}
eventData, _ := json.Marshal(event)
_ = h.store.Publish(models.TopicProxyStatusChanged, eventData)
_ = h.store.Publish(ctx, models.TopicProxyStatusChanged, eventData)
}()
}
// New handler methods and DTOs
type SyncProxiesRequest struct {
Proxies []string `json:"proxies"`
}
@@ -220,14 +216,12 @@ func (h *handler) SyncProxies(c *gin.Context) {
response.Error(c, errors.NewAPIError(errors.ErrInvalidJSON, err.Error()))
return
}
taskStatus, err := h.manager.SyncProxiesInBackground(req.Proxies)
taskStatus, err := h.manager.SyncProxiesInBackground(c.Request.Context(), req.Proxies)
if err != nil {
if errors.Is(err, ErrTaskConflict) {
response.Error(c, errors.NewAPIError(errors.ErrTaskInProgress, err.Error()))
} else {
response.Error(c, errors.NewAPIError(errors.ErrInternalServer, err.Error()))
}
return
@@ -262,7 +256,7 @@ func (h *handler) CheckAllProxies(c *gin.Context) {
concurrency := cfg.ProxyCheckConcurrency
if concurrency <= 0 {
concurrency = 5 // safe default if the configured value is invalid
concurrency = 5
}
results := h.manager.CheckMultipleProxies(req.Proxies, timeout, concurrency)
response.Success(c, results)
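
Note: every store call in this commit gains a leading context.Context. The store.Store interface itself is not part of this diff, but the call sites imply a shape like the following sketch (method subset and return types are inferred, not the repository's actual definition):

	package store

	import (
		"context"
		"time"
	)

	// Store, as implied by the call sites in this commit. Only methods
	// visible in the diff are listed; return types are assumptions.
	type Store interface {
		Publish(ctx context.Context, topic string, data []byte) error
		Get(ctx context.Context, key string) ([]byte, error)
		Set(ctx context.Context, key string, value []byte, ttl time.Duration) error
		SetNX(ctx context.Context, key string, value []byte, ttl time.Duration) (bool, error)
		Del(ctx context.Context, keys ...string) error
		Exists(ctx context.Context, key string) (bool, error)
		SMembers(ctx context.Context, key string) ([]string, error)
		ZAdd(ctx context.Context, key string, members map[string]float64) error
		Rotate(ctx context.Context, key string) (string, error)
	}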

View File

@@ -2,14 +2,13 @@
package proxy
import (
"context"
"encoding/json"
"fmt"
"gemini-balancer/internal/models"
"gemini-balancer/internal/store"
"gemini-balancer/internal/syncer"
"gemini-balancer/internal/task"
"context"
"net"
"net/http"
"net/url"
@@ -25,7 +24,7 @@ import (
const (
TaskTypeProxySync = "proxy_sync"
proxyChunkSize = 200 // batch size for proxy sync
proxyChunkSize = 200
)
type ProxyCheckResult struct {
@@ -35,13 +34,11 @@ type ProxyCheckResult struct {
ErrorMessage string `json:"error_message"`
}
// managerCacheData
type managerCacheData struct {
ActiveProxies []*models.ProxyConfig
ProxiesByID map[uint]*models.ProxyConfig
}
// manager struct
type manager struct {
db *gorm.DB
syncer *syncer.CacheSyncer[managerCacheData]
@@ -80,21 +77,21 @@ func newManager(db *gorm.DB, syncer *syncer.CacheSyncer[managerCacheData], taskR
}
}
func (m *manager) SyncProxiesInBackground(proxyStrings []string) (*task.Status, error) {
func (m *manager) SyncProxiesInBackground(ctx context.Context, proxyStrings []string) (*task.Status, error) {
resourceID := "global_proxy_sync"
taskStatus, err := m.task.StartTask(0, TaskTypeProxySync, resourceID, len(proxyStrings), 0)
taskStatus, err := m.task.StartTask(ctx, 0, TaskTypeProxySync, resourceID, len(proxyStrings), 0)
if err != nil {
return nil, ErrTaskConflict
}
go m.runProxySyncTask(taskStatus.ID, proxyStrings)
go m.runProxySyncTask(context.Background(), taskStatus.ID, proxyStrings)
return taskStatus, nil
}
func (m *manager) runProxySyncTask(taskID string, finalProxyStrings []string) {
func (m *manager) runProxySyncTask(ctx context.Context, taskID string, finalProxyStrings []string) {
resourceID := "global_proxy_sync"
var allProxies []models.ProxyConfig
if err := m.db.Find(&allProxies).Error; err != nil {
m.task.EndTaskByID(taskID, resourceID, nil, fmt.Errorf("failed to fetch current proxies: %w", err))
m.task.EndTaskByID(ctx, taskID, resourceID, nil, fmt.Errorf("failed to fetch current proxies: %w", err))
return
}
currentProxyMap := make(map[string]uint)
@@ -125,19 +122,19 @@ func (m *manager) runProxySyncTask(taskID string, finalProxyStrings []string) {
}
if len(idsToDelete) > 0 {
if err := m.bulkDeleteByIDs(idsToDelete); err != nil {
m.task.EndTaskByID(taskID, resourceID, nil, fmt.Errorf("failed during proxy deletion: %w", err))
m.task.EndTaskByID(ctx, taskID, resourceID, nil, fmt.Errorf("failed during proxy deletion: %w", err))
return
}
}
if len(proxiesToAdd) > 0 {
if err := m.bulkAdd(proxiesToAdd); err != nil {
m.task.EndTaskByID(taskID, resourceID, nil, fmt.Errorf("failed during proxy addition: %w", err))
m.task.EndTaskByID(ctx, taskID, resourceID, nil, fmt.Errorf("failed during proxy addition: %w", err))
return
}
}
result := gin.H{"added": len(proxiesToAdd), "deleted": len(idsToDelete), "final_total": len(finalProxyMap)}
m.task.EndTaskByID(taskID, resourceID, result, nil)
m.publishChangeEvent("proxies_synced")
m.task.EndTaskByID(ctx, taskID, resourceID, result, nil)
m.publishChangeEvent(ctx, "proxies_synced")
go m.invalidate()
}
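
Note the split at the goroutine boundary in SyncProxiesInBackground above: the request-scoped ctx is used for StartTask, so registration is bounded by the caller, while the worker goroutine receives context.Background() so the sync survives the HTTP response. On Go 1.21+, context.WithoutCancel would keep request-scoped values (trace IDs, auth metadata) while still detaching from cancellation; a sketch of that alternative, not what this commit does:

	package proxy

	import "context"

	// detach keeps the values of the request context but drops its
	// cancellation, so a background worker is not killed when the
	// handler returns. Requires Go 1.21+; illustrative only.
	func detach(reqCtx context.Context) context.Context {
		return context.WithoutCancel(reqCtx)
	}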
@@ -184,14 +181,15 @@ func (m *manager) bulkDeleteByIDs(ids []uint) error {
}
return nil
}
func (m *manager) bulkAdd(proxies []models.ProxyConfig) error {
return m.db.CreateInBatches(proxies, proxyChunkSize).Error
}
func (m *manager) publishChangeEvent(reason string) {
func (m *manager) publishChangeEvent(ctx context.Context, reason string) {
event := models.ProxyStatusChangedEvent{Action: reason}
eventData, _ := json.Marshal(event)
_ = m.store.Publish(models.TopicProxyStatusChanged, eventData)
_ = m.store.Publish(ctx, models.TopicProxyStatusChanged, eventData)
}
func (m *manager) assignProxyIfNeeded(apiKey *models.APIKey) (*models.ProxyConfig, error) {

View File

@@ -1,4 +1,4 @@
// Filename: internal/handlers/apikey_handler.go
// Filename: internal/handlers/apikey_handler.go (final showdown edition)
package handlers
import (
@@ -31,11 +31,10 @@ func NewAPIKeyHandler(apiKeyService *service.APIKeyService, db *gorm.DB, keyImpo
}
}
// DTOs for API requests
type BulkAddKeysToGroupRequest struct {
KeyGroupID uint `json:"key_group_id" binding:"required"`
Keys string `json:"keys" binding:"required"`
ValidateOnImport bool `json:"validate_on_import"` // OmitEmpty/default is false
ValidateOnImport bool `json:"validate_on_import"`
}
type BulkUnlinkKeysFromGroupRequest struct {
@@ -72,11 +71,11 @@ type BulkTestKeysForGroupRequest struct {
}
type BulkActionFilter struct {
Status []string `json:"status"` // Changed to slice to accept multiple statuses
Status []string `json:"status"`
}
type BulkActionRequest struct {
Action string `json:"action" binding:"required,oneof=revalidate set_status delete"`
NewStatus string `json:"new_status" binding:"omitempty,oneof=active disabled cooldown banned"` // For 'set_status' action
NewStatus string `json:"new_status" binding:"omitempty,oneof=active disabled cooldown banned"`
Filter BulkActionFilter `json:"filter" binding:"required"`
}
@@ -89,7 +88,8 @@ func (h *APIKeyHandler) AddMultipleKeysToGroup(c *gin.Context) {
response.Error(c, errors.NewAPIError(errors.ErrInvalidJSON, err.Error()))
return
}
taskStatus, err := h.keyImportService.StartAddKeysTask(req.KeyGroupID, req.Keys, req.ValidateOnImport)
// [Fix] Pass the request context through to the service layer
taskStatus, err := h.keyImportService.StartAddKeysTask(c.Request.Context(), req.KeyGroupID, req.Keys, req.ValidateOnImport)
if err != nil {
response.Error(c, errors.NewAPIError(errors.ErrTaskInProgress, err.Error()))
return
@@ -104,7 +104,8 @@ func (h *APIKeyHandler) UnlinkMultipleKeysFromGroup(c *gin.Context) {
response.Error(c, errors.NewAPIError(errors.ErrInvalidJSON, err.Error()))
return
}
taskStatus, err := h.keyImportService.StartUnlinkKeysTask(req.KeyGroupID, req.Keys)
// [Fix] Pass the request context through to the service layer
taskStatus, err := h.keyImportService.StartUnlinkKeysTask(c.Request.Context(), req.KeyGroupID, req.Keys)
if err != nil {
response.Error(c, errors.NewAPIError(errors.ErrTaskInProgress, err.Error()))
return
@@ -119,7 +120,8 @@ func (h *APIKeyHandler) HardDeleteMultipleKeys(c *gin.Context) {
response.Error(c, errors.NewAPIError(errors.ErrInvalidJSON, err.Error()))
return
}
taskStatus, err := h.keyImportService.StartHardDeleteKeysTask(req.Keys)
// [Fix] Pass the request context through to the service layer
taskStatus, err := h.keyImportService.StartHardDeleteKeysTask(c.Request.Context(), req.Keys)
if err != nil {
response.Error(c, errors.NewAPIError(errors.ErrTaskInProgress, err.Error()))
return
@@ -134,7 +136,8 @@ func (h *APIKeyHandler) RestoreMultipleKeys(c *gin.Context) {
response.Error(c, errors.NewAPIError(errors.ErrInvalidJSON, err.Error()))
return
}
taskStatus, err := h.keyImportService.StartRestoreKeysTask(req.Keys)
// [Fix] Pass the request context through to the service layer
taskStatus, err := h.keyImportService.StartRestoreKeysTask(c.Request.Context(), req.Keys)
if err != nil {
response.Error(c, errors.NewAPIError(errors.ErrTaskInProgress, err.Error()))
return
@@ -148,7 +151,8 @@ func (h *APIKeyHandler) TestMultipleKeys(c *gin.Context) {
response.Error(c, errors.NewAPIError(errors.ErrInvalidJSON, err.Error()))
return
}
taskStatus, err := h.keyValidationService.StartTestKeysTask(req.KeyGroupID, req.Keys)
// [Fix] Pass the request context through to the service layer
taskStatus, err := h.keyValidationService.StartTestKeysTask(c.Request.Context(), req.KeyGroupID, req.Keys)
if err != nil {
response.Error(c, errors.NewAPIError(errors.ErrTaskInProgress, err.Error()))
return
@@ -172,7 +176,7 @@ func (h *APIKeyHandler) ListAPIKeys(c *gin.Context) {
}
}
if len(ids) > 0 {
keys, err := h.apiKeyService.GetKeysByIds(ids)
keys, err := h.apiKeyService.GetKeysByIds(c.Request.Context(), ids)
if err != nil {
response.Error(c, &errors.APIError{
HTTPStatus: http.StatusInternalServerError,
@@ -191,7 +195,7 @@ func (h *APIKeyHandler) ListAPIKeys(c *gin.Context) {
if params.PageSize <= 0 {
params.PageSize = 20
}
result, err := h.apiKeyService.ListAPIKeys(&params)
result, err := h.apiKeyService.ListAPIKeys(c.Request.Context(), &params)
if err != nil {
response.Error(c, errors.ParseDBError(err))
return
@@ -201,19 +205,16 @@ func (h *APIKeyHandler) ListAPIKeys(c *gin.Context) {
// ListKeysForGroup handles the GET /keygroups/:id/keys request.
func (h *APIKeyHandler) ListKeysForGroup(c *gin.Context) {
// 1. Manually handle the path parameter.
groupID, err := strconv.ParseUint(c.Param("id"), 10, 32)
if err != nil {
response.Error(c, errors.NewAPIError(errors.ErrBadRequest, "Invalid group ID format"))
return
}
// 2. Bind query parameters using the correctly tagged struct.
var params models.APIKeyQueryParams
if err := c.ShouldBindQuery(&params); err != nil {
response.Error(c, errors.NewAPIError(errors.ErrBadRequest, err.Error()))
return
}
// 3. Set server-side defaults and the path parameter.
if params.Page <= 0 {
params.Page = 1
}
@@ -221,15 +222,11 @@ func (h *APIKeyHandler) ListKeysForGroup(c *gin.Context) {
params.PageSize = 20
}
params.KeyGroupID = uint(groupID)
// 4. Call the service layer.
paginatedResult, err := h.apiKeyService.ListAPIKeys(&params)
paginatedResult, err := h.apiKeyService.ListAPIKeys(c.Request.Context(), &params)
if err != nil {
response.Error(c, errors.ParseDBError(err))
return
}
// 5. [THE FIX] Return a successful response using the standard `response.Success`
// and a gin.H map, as confirmed to exist in your project.
response.Success(c, gin.H{
"items": paginatedResult.Items,
"total": paginatedResult.Total,
@@ -239,20 +236,18 @@ func (h *APIKeyHandler) ListKeysForGroup(c *gin.Context) {
}
func (h *APIKeyHandler) TestKeysForGroup(c *gin.Context) {
// Group ID is now correctly sourced from the URL path.
groupID, err := strconv.ParseUint(c.Param("id"), 10, 32)
if err != nil {
response.Error(c, errors.NewAPIError(errors.ErrBadRequest, "Invalid group ID format"))
return
}
// The request body is now simpler, only needing the keys.
var req BulkTestKeysForGroupRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.Error(c, errors.NewAPIError(errors.ErrInvalidJSON, err.Error()))
return
}
// Call the same underlying service, but with unambiguous context.
taskStatus, err := h.keyValidationService.StartTestKeysTask(uint(groupID), req.Keys)
// [Fix] Pass the request context through to the service layer
taskStatus, err := h.keyValidationService.StartTestKeysTask(c.Request.Context(), uint(groupID), req.Keys)
if err != nil {
response.Error(c, errors.NewAPIError(errors.ErrTaskInProgress, err.Error()))
return
@@ -267,7 +262,6 @@ func (h *APIKeyHandler) UpdateAPIKey(c *gin.Context) {
}
// UpdateGroupAPIKeyMapping handles updating a key's status within a specific group.
// Route: PUT /keygroups/:id/apikeys/:keyId
func (h *APIKeyHandler) UpdateGroupAPIKeyMapping(c *gin.Context) {
groupID, err := strconv.ParseUint(c.Param("id"), 10, 32)
if err != nil {
@@ -284,8 +278,7 @@ func (h *APIKeyHandler) UpdateGroupAPIKeyMapping(c *gin.Context) {
response.Error(c, errors.NewAPIError(errors.ErrInvalidJSON, err.Error()))
return
}
// Directly use the service to handle the logic
updatedMapping, err := h.apiKeyService.UpdateMappingStatus(uint(groupID), uint(keyID), req.Status)
updatedMapping, err := h.apiKeyService.UpdateMappingStatus(c.Request.Context(), uint(groupID), uint(keyID), req.Status)
if err != nil {
var apiErr *errors.APIError
if errors.As(err, &apiErr) {
@@ -305,7 +298,7 @@ func (h *APIKeyHandler) HardDeleteAPIKey(c *gin.Context) {
response.Error(c, errors.NewAPIError(errors.ErrBadRequest, "Invalid ID format"))
return
}
if err := h.apiKeyService.HardDeleteAPIKeyByID(uint(id)); err != nil {
if err := h.apiKeyService.HardDeleteAPIKeyByID(c.Request.Context(), uint(id)); err != nil {
response.Error(c, errors.ParseDBError(err))
return
}
@@ -313,7 +306,6 @@ func (h *APIKeyHandler) HardDeleteAPIKey(c *gin.Context) {
}
// RestoreKeysInGroup is the endpoint for restoring the specified keys
// POST /keygroups/:id/apikeys/restore
func (h *APIKeyHandler) RestoreKeysInGroup(c *gin.Context) {
groupID, err := strconv.ParseUint(c.Param("id"), 10, 32)
if err != nil {
@@ -325,7 +317,7 @@ func (h *APIKeyHandler) RestoreKeysInGroup(c *gin.Context) {
response.Error(c, errors.NewAPIError(errors.ErrInvalidJSON, err.Error()))
return
}
taskStatus, err := h.apiKeyService.StartRestoreKeysTask(uint(groupID), req.KeyIDs)
taskStatus, err := h.apiKeyService.StartRestoreKeysTask(c.Request.Context(), uint(groupID), req.KeyIDs)
if err != nil {
var apiErr *errors.APIError
if errors.As(err, &apiErr) {
@@ -339,14 +331,13 @@ func (h *APIKeyHandler) RestoreKeysInGroup(c *gin.Context) {
}
// RestoreAllBannedInGroup is the endpoint for restoring all banned keys at once
// POST /keygroups/:id/apikeys/restore-all-banned
func (h *APIKeyHandler) RestoreAllBannedInGroup(c *gin.Context) {
groupID, err := strconv.ParseUint(c.Param("groupId"), 10, 32)
if err != nil {
response.Error(c, errors.NewAPIError(errors.ErrBadRequest, "Invalid Group ID format"))
return
}
taskStatus, err := h.apiKeyService.StartRestoreAllBannedTask(uint(groupID))
taskStatus, err := h.apiKeyService.StartRestoreAllBannedTask(c.Request.Context(), uint(groupID))
if err != nil {
var apiErr *errors.APIError
if errors.As(err, &apiErr) {
@@ -360,48 +351,41 @@ func (h *APIKeyHandler) RestoreAllBannedInGroup(c *gin.Context) {
}
// HandleBulkAction handles generic bulk actions on a key group based on server-side filters.
// Route: POST /keygroups/:id/bulk-actions
func (h *APIKeyHandler) HandleBulkAction(c *gin.Context) {
// 1. Parse GroupID from URL
groupID, err := strconv.ParseUint(c.Param("id"), 10, 32)
if err != nil {
response.Error(c, errors.NewAPIError(errors.ErrBadRequest, "Invalid Group ID format"))
return
}
// 2. Bind the JSON payload to our new DTO
var req BulkActionRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.Error(c, errors.NewAPIError(errors.ErrInvalidJSON, err.Error()))
return
}
// 3. Central logic: based on the action, call the appropriate service method.
var task *task.Status
var apiErr *errors.APIError
switch req.Action {
case "revalidate":
// Assume keyValidationService has a method that accepts a filter
task, err = h.keyValidationService.StartTestKeysByFilterTask(uint(groupID), req.Filter.Status)
// [Fix] Pass the request context through to the service layer
task, err = h.keyValidationService.StartTestKeysByFilterTask(c.Request.Context(), uint(groupID), req.Filter.Status)
case "set_status":
if req.NewStatus == "" {
apiErr = errors.NewAPIError(errors.ErrBadRequest, "new_status is required for set_status action")
break
}
// Assume apiKeyService has a method to update status by filter
targetStatus := models.APIKeyStatus(req.NewStatus) // Convert string to your model's type
task, err = h.apiKeyService.StartUpdateStatusByFilterTask(uint(groupID), req.Filter.Status, targetStatus)
targetStatus := models.APIKeyStatus(req.NewStatus)
task, err = h.apiKeyService.StartUpdateStatusByFilterTask(c.Request.Context(), uint(groupID), req.Filter.Status, targetStatus)
case "delete":
// Assume keyImportService has a method to unlink by filter
task, err = h.keyImportService.StartUnlinkKeysByFilterTask(uint(groupID), req.Filter.Status)
// [Fix] Pass the request context through to the service layer
task, err = h.keyImportService.StartUnlinkKeysByFilterTask(c.Request.Context(), uint(groupID), req.Filter.Status)
default:
apiErr = errors.NewAPIError(errors.ErrBadRequest, "Unsupported action: "+req.Action)
}
// 4. Handle errors from the switch block
if apiErr != nil {
response.Error(c, apiErr)
return
}
if err != nil {
// Attempt to parse it as a known APIError, otherwise, wrap it.
var parsedErr *errors.APIError
if errors.As(err, &parsedErr) {
response.Error(c, parsedErr)
@@ -410,21 +394,18 @@ func (h *APIKeyHandler) HandleBulkAction(c *gin.Context) {
}
return
}
// 5. Return the task status on success
response.Success(c, task)
}
// ExportKeysForGroup handles requests to export all keys for a group based on status filters.
// Route: GET /keygroups/:id/apikeys/export
func (h *APIKeyHandler) ExportKeysForGroup(c *gin.Context) {
groupID, err := strconv.ParseUint(c.Param("id"), 10, 32)
if err != nil {
response.Error(c, errors.NewAPIError(errors.ErrBadRequest, "Invalid Group ID format"))
return
}
// Use QueryArray to correctly parse `status[]=active&status[]=cooldown`
statuses := c.QueryArray("status")
keyStrings, err := h.apiKeyService.GetAPIKeyStringsForExport(uint(groupID), statuses)
keyStrings, err := h.apiKeyService.GetAPIKeyStringsForExport(c.Request.Context(), uint(groupID), statuses)
if err != nil {
response.Error(c, errors.ParseDBError(err))
return
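
Every change in this handler file is the same mechanical move: c.Request.Context() becomes the first argument of the service call. A condensed sketch of the pattern (taskStarter is a hypothetical stand-in for keyImportService and friends; gin's c.Request.Context() is the real API):

	package handlers

	import (
		"context"
		"net/http"

		"github.com/gin-gonic/gin"
	)

	// taskStarter is a hypothetical stand-in for the key services above.
	type taskStarter interface {
		Start(ctx context.Context) error
	}

	func runTask(c *gin.Context, svc taskStarter) {
		// c.Request.Context() is canceled when the client disconnects or
		// the server shuts down, so downstream DB/store calls can abort early.
		if err := svc.Start(c.Request.Context()); err != nil {
			c.AbortWithStatus(http.StatusInternalServerError)
			return
		}
		c.Status(http.StatusOK)
	}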

View File

@@ -30,7 +30,7 @@ func (h *DashboardHandler) GetOverview(c *gin.Context) {
c.JSON(http.StatusOK, stats)
}
// GetChart returns chart data for the dashboard
// GetChart
func (h *DashboardHandler) GetChart(c *gin.Context) {
var groupID *uint
if groupIDStr := c.Query("groupId"); groupIDStr != "" {
@@ -40,7 +40,7 @@ func (h *DashboardHandler) GetChart(c *gin.Context) {
}
}
chartData, err := h.queryService.QueryHistoricalChart(groupID)
chartData, err := h.queryService.QueryHistoricalChart(c.Request.Context(), groupID)
if err != nil {
apiErr := errors.NewAPIError(errors.ErrDatabase, err.Error())
c.JSON(apiErr.HTTPStatus, apiErr)
@@ -49,10 +49,10 @@ func (h *DashboardHandler) GetChart(c *gin.Context) {
c.JSON(http.StatusOK, chartData)
}
// GetRequestStats handles requests for the "period call overview"
// GetRequestStats
func (h *DashboardHandler) GetRequestStats(c *gin.Context) {
period := c.Param("period") // 从 URL 路径中获取 period
stats, err := h.queryService.GetRequestStatsForPeriod(period)
period := c.Param("period")
stats, err := h.queryService.GetRequestStatsForPeriod(c.Request.Context(), period)
if err != nil {
apiErr := errors.NewAPIError(errors.ErrBadRequest, err.Error())
c.JSON(apiErr.HTTPStatus, apiErr)

View File

@@ -2,6 +2,7 @@
package handlers
import (
"context"
"encoding/json"
"fmt"
"gemini-balancer/internal/errors"
@@ -31,7 +32,6 @@ func NewKeyGroupHandler(gm *service.GroupManager, s store.Store, qs *service.Das
}
}
// DTOs & 辅助函数
func isValidGroupName(name string) bool {
if name == "" {
return false
@@ -40,7 +40,6 @@ func isValidGroupName(name string) bool {
return match
}
// KeyGroupOperationalSettings defines the shared operational settings for a key group.
type KeyGroupOperationalSettings struct {
EnableKeyCheck *bool `json:"enable_key_check"`
KeyCheckIntervalMinutes *int `json:"key_check_interval_minutes"`
@@ -52,7 +51,6 @@ type KeyGroupOperationalSettings struct {
MaxRetries *int `json:"max_retries"`
EnableSmartGateway *bool `json:"enable_smart_gateway"`
}
type CreateKeyGroupRequest struct {
Name string `json:"name" binding:"required"`
DisplayName string `json:"display_name"`
@@ -60,11 +58,8 @@ type CreateKeyGroupRequest struct {
PollingStrategy string `json:"polling_strategy" binding:"omitempty,oneof=sequential random weighted"`
EnableProxy bool `json:"enable_proxy"`
ChannelType string `json:"channel_type"`
// Embed shared operational settings
KeyGroupOperationalSettings
}
type UpdateKeyGroupRequest struct {
Name *string `json:"name"`
DisplayName *string `json:"display_name"`
@@ -72,15 +67,10 @@ type UpdateKeyGroupRequest struct {
PollingStrategy *string `json:"polling_strategy" binding:"omitempty,oneof=sequential random weighted"`
EnableProxy *bool `json:"enable_proxy"`
ChannelType *string `json:"channel_type"`
// Embed shared operational settings
KeyGroupOperationalSettings
// M:N associations
AllowedUpstreams []string `json:"allowed_upstreams"`
AllowedModels []string `json:"allowed_models"`
}
type KeyGroupResponse struct {
ID uint `json:"id"`
Name string `json:"name"`
@@ -96,36 +86,30 @@ type KeyGroupResponse struct {
AllowedModels []string `json:"allowed_models"`
AllowedUpstreams []string `json:"allowed_upstreams"`
}
// [NEW] Define the detailed response structure for a single group.
type KeyGroupDetailsResponse struct {
KeyGroupResponse
Settings *models.GroupSettings `json:"settings,omitempty"`
RequestConfig *models.RequestConfig `json:"request_config,omitempty"`
}
// transformModelsToStrings converts a slice of GroupModelMapping pointers to a slice of model names.
func transformModelsToStrings(mappings []*models.GroupModelMapping) []string {
modelNames := make([]string, 0, len(mappings))
for _, mapping := range mappings {
if mapping != nil { // Safety check
if mapping != nil {
modelNames = append(modelNames, mapping.ModelName)
}
}
return modelNames
}
// transformUpstreamsToStrings converts a slice of UpstreamEndpoint pointers to a slice of URLs.
func transformUpstreamsToStrings(upstreams []*models.UpstreamEndpoint) []string {
urls := make([]string, 0, len(upstreams))
for _, upstream := range upstreams {
if upstream != nil { // Safety check
if upstream != nil {
urls = append(urls, upstream.URL)
}
}
return urls
}
func (h *KeyGroupHandler) newKeyGroupResponse(group *models.KeyGroup, keyCount int64) KeyGroupResponse {
return KeyGroupResponse{
ID: group.ID,
@@ -139,13 +123,10 @@ func (h *KeyGroupHandler) newKeyGroupResponse(group *models.KeyGroup, keyCount i
CreatedAt: group.CreatedAt,
UpdatedAt: group.UpdatedAt,
Order: group.Order,
AllowedModels: transformModelsToStrings(group.AllowedModels), // Call the new helper
AllowedUpstreams: transformUpstreamsToStrings(group.AllowedUpstreams), // Call the new helper
AllowedModels: transformModelsToStrings(group.AllowedModels),
AllowedUpstreams: transformUpstreamsToStrings(group.AllowedUpstreams),
}
}
// packGroupSettings is a helper to convert request-level operational settings
// into the model-level settings struct.
func packGroupSettings(settings KeyGroupOperationalSettings) *models.KeyGroupSettings {
return &models.KeyGroupSettings{
EnableKeyCheck: settings.EnableKeyCheck,
@@ -159,7 +140,6 @@ func packGroupSettings(settings KeyGroupOperationalSettings) *models.KeyGroupSet
EnableSmartGateway: settings.EnableSmartGateway,
}
}
func (h *KeyGroupHandler) getGroupFromContext(c *gin.Context) (*models.KeyGroup, *errors.APIError) {
id, err := strconv.Atoi(c.Param("id"))
if err != nil {
@@ -171,7 +151,6 @@ func (h *KeyGroupHandler) getGroupFromContext(c *gin.Context) (*models.KeyGroup,
}
return group, nil
}
func applyUpdateRequestToGroup(req *UpdateKeyGroupRequest, group *models.KeyGroup) {
if req.Name != nil {
group.Name = *req.Name
@@ -197,9 +176,10 @@ func applyUpdateRequestToGroup(req *UpdateKeyGroupRequest, group *models.KeyGrou
// publishGroupChangeEvent encapsulates the logic for marshaling and publishing a group change event.
func (h *KeyGroupHandler) publishGroupChangeEvent(groupID uint, reason string) {
go func() {
ctx := context.Background()
event := models.KeyStatusChangedEvent{GroupID: groupID, ChangeReason: reason}
eventData, _ := json.Marshal(event)
h.store.Publish(models.TopicKeyStatusChanged, eventData)
_ = h.store.Publish(ctx, models.TopicKeyStatusChanged, eventData)
}()
}
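
publishGroupChangeEvent stays fire-and-forget: a detached goroutine on context.Background(), with the marshal error discarded. If publish failures ever need to be bounded or observed, a variant like this sketch would do it (the 5-second budget and the logging are assumptions, not this repo's behavior):

	package handlers

	import (
		"context"
		"encoding/json"
		"log"
		"time"
	)

	// publishAsync is a sketch of a bounded fire-and-forget publish;
	// publish stands in for h.store.Publish.
	func publishAsync(publish func(context.Context, string, []byte) error, topic string, event any) {
		go func() {
			ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
			defer cancel()
			data, err := json.Marshal(event)
			if err != nil {
				log.Printf("marshal %s event: %v", topic, err)
				return
			}
			if err := publish(ctx, topic, data); err != nil {
				log.Printf("publish %s event: %v", topic, err)
			}
		}()
	}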
@@ -216,7 +196,6 @@ func (h *KeyGroupHandler) CreateKeyGroup(c *gin.Context) {
return
}
// The core logic remains, as it's specific to creation.
p := bluemonday.StripTagsPolicy()
sanitizedDisplayName := p.Sanitize(req.DisplayName)
sanitizedDescription := p.Sanitize(req.Description)
@@ -244,11 +223,9 @@ func (h *KeyGroupHandler) CreateKeyGroup(c *gin.Context) {
response.Created(c, h.newKeyGroupResponse(keyGroup, 0))
}
// This unified handler covers both cases:
// 1. GET /keygroups - returns the list of all groups
// 2. GET /keygroups/:id - returns the single group with the given ID
func (h *KeyGroupHandler) GetKeyGroups(c *gin.Context) {
// Case 1: Get a single group
if idStr := c.Param("id"); idStr != "" {
group, apiErr := h.getGroupFromContext(c)
if apiErr != nil {
@@ -265,7 +242,6 @@ func (h *KeyGroupHandler) GetKeyGroups(c *gin.Context) {
response.Success(c, detailedResponse)
return
}
// Case 2: Get all groups
allGroups := h.groupManager.GetAllGroups()
responses := make([]KeyGroupResponse, 0, len(allGroups))
for _, group := range allGroups {
@@ -275,7 +251,6 @@ func (h *KeyGroupHandler) GetKeyGroups(c *gin.Context) {
response.Success(c, responses)
}
// UpdateKeyGroup
func (h *KeyGroupHandler) UpdateKeyGroup(c *gin.Context) {
group, apiErr := h.getGroupFromContext(c)
if apiErr != nil {
@@ -304,7 +279,6 @@ func (h *KeyGroupHandler) UpdateKeyGroup(c *gin.Context) {
response.Success(c, h.newKeyGroupResponse(freshGroup, keyCount))
}
// DeleteKeyGroup
func (h *KeyGroupHandler) DeleteKeyGroup(c *gin.Context) {
group, apiErr := h.getGroupFromContext(c)
if apiErr != nil {
@@ -320,14 +294,14 @@ func (h *KeyGroupHandler) DeleteKeyGroup(c *gin.Context) {
response.Success(c, gin.H{"message": fmt.Sprintf("Group '%s' and its associated keys deleted successfully", groupName)})
}
// GetKeyGroupStats
func (h *KeyGroupHandler) GetKeyGroupStats(c *gin.Context) {
group, apiErr := h.getGroupFromContext(c)
if apiErr != nil {
response.Error(c, apiErr)
return
}
stats, err := h.queryService.GetGroupStats(group.ID)
stats, err := h.queryService.GetGroupStats(c.Request.Context(), group.ID)
if err != nil {
response.Error(c, errors.NewAPIError(errors.ErrDatabase, err.Error()))
return
@@ -350,7 +324,6 @@ func (h *KeyGroupHandler) CloneKeyGroup(c *gin.Context) {
response.Created(c, h.newKeyGroupResponse(clonedGroup, keyCount))
}
// Update group ordering
func (h *KeyGroupHandler) UpdateKeyGroupOrder(c *gin.Context) {
var payload []service.UpdateOrderPayload
if err := c.ShouldBindJSON(&payload); err != nil {

View File

@@ -136,7 +136,6 @@ func (h *ProxyHandler) serveTransparentProxy(c *gin.Context, requestBody []byte,
var finalPromptTokens, finalCompletionTokens int
var actualRetries int = 0
defer func() {
// If no attempt succeeded at all (e.g. the very first resource acquisition failed), skip logging.
if lastUsedResources == nil {
h.logger.WithField("id", correlationID).Warn("No resources were used, skipping final log event.")
return
@@ -151,44 +150,38 @@ func (h *ProxyHandler) serveTransparentProxy(c *gin.Context, requestBody []byte,
finalEvent.RequestLog.CompletionTokens = finalCompletionTokens
}
// Ensure the final status code is recorded whenever the recorder exists, even on success.
if finalRecorder != nil {
finalEvent.RequestLog.StatusCode = finalRecorder.Code
}
if !isSuccess {
// Populate RequestLog with the details from finalProxyErr.
if finalProxyErr != nil {
finalEvent.Error = finalProxyErr // the Error field is for event propagation only and is never serialized to the database
finalEvent.Error = finalProxyErr
finalEvent.RequestLog.ErrorCode = finalProxyErr.Code
finalEvent.RequestLog.ErrorMessage = finalProxyErr.Message
} else if finalRecorder != nil {
// Fallback: finalProxyErr is nil, but a recorder exists and reports failure.
apiErr := errors.NewAPIErrorWithUpstream(finalRecorder.Code, fmt.Sprintf("UPSTREAM_%d", finalRecorder.Code), "Request failed after all retries.")
finalEvent.Error = apiErr
finalEvent.RequestLog.ErrorCode = apiErr.Code
finalEvent.RequestLog.ErrorMessage = apiErr.Message
}
}
// Publish the complete event.
eventData, err := json.Marshal(finalEvent)
if err != nil {
h.logger.WithField("id", correlationID).WithError(err).Error("Failed to marshal final log event.")
return
}
if err := h.store.Publish(models.TopicRequestFinished, eventData); err != nil {
if err := h.store.Publish(context.Background(), models.TopicRequestFinished, eventData); err != nil {
h.logger.WithField("id", correlationID).WithError(err).Error("Failed to publish final log event.")
}
}()
var maxRetries int
if isPreciseRouting {
// For precise routing, use the group's setting. If not set, fall back to the global setting.
if finalOpConfig.MaxRetries != nil {
maxRetries = *finalOpConfig.MaxRetries
} else {
maxRetries = h.settingsManager.GetSettings().MaxRetries
}
} else {
// For BasePool (intelligent aggregation), *always* use the global setting.
maxRetries = h.settingsManager.GetSettings().MaxRetries
}
totalAttempts := maxRetries + 1
@@ -332,7 +325,7 @@ func (h *ProxyHandler) serveTransparentProxy(c *gin.Context, requestBody []byte,
retryEvent.ErrorMessage = attemptErr.Message
}
eventData, _ := json.Marshal(retryEvent)
_ = h.store.Publish(models.TopicRequestFinished, eventData)
_ = h.store.Publish(context.Background(), models.TopicRequestFinished, eventData)
}
if finalRecorder != nil {
bodyBytes := finalRecorder.Body.Bytes()
@@ -368,7 +361,7 @@ func (h *ProxyHandler) serveSmartStream(c *gin.Context, requestBody []byte, reso
requestFinishedEvent.StatusCode = c.Writer.Status()
}
eventData, _ := json.Marshal(requestFinishedEvent)
_ = h.store.Publish(models.TopicRequestFinished, eventData)
_ = h.store.Publish(context.Background(), models.TopicRequestFinished, eventData)
}()
params := channel.SmartRequestParams{
CorrelationID: correlationID,
@@ -435,7 +428,6 @@ func (h *ProxyHandler) createLogEvent(c *gin.Context, startTime time.Time, corrI
}
}
if res != nil {
// [Core fix] Populate the embedded RequestLog struct.
if res.APIKey != nil {
event.RequestLog.KeyID = &res.APIKey.ID
}
@@ -444,7 +436,6 @@ func (h *ProxyHandler) createLogEvent(c *gin.Context, startTime time.Time, corrI
}
if res.UpstreamEndpoint != nil {
event.RequestLog.UpstreamID = &res.UpstreamEndpoint.ID
// UpstreamURL is an event-propagation field, not a database column, so assigning it here is correct
event.UpstreamURL = &res.UpstreamEndpoint.URL
}
if res.ProxyConfig != nil {
@@ -464,9 +455,9 @@ func (h *ProxyHandler) getResourcesForRequest(c *gin.Context, modelName string,
return nil, errors.NewAPIError(errors.ErrInternalServer, "Invalid auth token type in context")
}
if isPreciseRouting {
return h.resourceService.GetResourceFromGroup(authToken, groupName)
return h.resourceService.GetResourceFromGroup(c.Request.Context(), authToken, groupName)
} else {
return h.resourceService.GetResourceFromBasePool(authToken, modelName)
return h.resourceService.GetResourceFromBasePool(c.Request.Context(), authToken, modelName)
}
}

View File

@@ -33,7 +33,7 @@ func (h *TaskHandler) GetTaskStatus(c *gin.Context) {
return
}
taskStatus, err := h.taskService.GetStatus(taskID)
taskStatus, err := h.taskService.GetStatus(c.Request.Context(), taskID)
if err != nil {
// TODO: handle specific error types returned by the service layer with finer granularity
response.Error(c, errors.NewAPIError(errors.ErrResourceNotFound, err.Error()))

View File

@@ -2,6 +2,7 @@
package repository
import (
"context"
"encoding/json"
"fmt"
"gemini-balancer/internal/models"
@@ -22,7 +23,7 @@ const (
BasePoolRandomCooldown = "basepool:%s:keys:random:cooldown"
)
func (r *gormKeyRepository) LoadAllKeysToStore() error {
func (r *gormKeyRepository) LoadAllKeysToStore(ctx context.Context) error {
r.logger.Info("Starting to load all keys and associations into cache, including polling structures...")
var allMappings []*models.GroupAPIKeyMapping
if err := r.db.Preload("APIKey").Find(&allMappings).Error; err != nil {
@@ -48,7 +49,7 @@ func (r *gormKeyRepository) LoadAllKeysToStore() error {
}
activeKeysByGroup := make(map[uint][]*models.GroupAPIKeyMapping)
pipe := r.store.Pipeline()
pipe := r.store.Pipeline(context.Background())
detailsToSet := make(map[string][]byte)
var allGroups []*models.KeyGroup
if err := r.db.Find(&allGroups).Error; err == nil {
@@ -100,14 +101,14 @@ func (r *gormKeyRepository) LoadAllKeysToStore() error {
pipe.SAdd(fmt.Sprintf(KeyGroup, groupID), activeKeyIDs...)
pipe.LPush(fmt.Sprintf(KeyGroupSequential, groupID), activeKeyIDs...)
pipe.SAdd(fmt.Sprintf(KeyGroupRandomMain, groupID), activeKeyIDs...)
go r.store.ZAdd(fmt.Sprintf(KeyGroupLRU, groupID), lruMembers)
go r.store.ZAdd(context.Background(), fmt.Sprintf(KeyGroupLRU, groupID), lruMembers)
}
if err := pipe.Exec(); err != nil {
return fmt.Errorf("failed to execute pipeline for cache rebuild: %w", err)
}
for key, value := range detailsToSet {
if err := r.store.Set(key, value, 0); err != nil {
if err := r.store.Set(context.Background(), key, value, 0); err != nil {
r.logger.WithError(err).Warnf("Failed to set key detail in cache: %s", key)
}
}
@@ -124,16 +125,16 @@ func (r *gormKeyRepository) updateStoreCacheForKey(key *models.APIKey) error {
if err != nil {
return fmt.Errorf("failed to marshal key %d for cache update: %w", key.ID, err)
}
return r.store.Set(fmt.Sprintf(KeyDetails, key.ID), keyJSON, 0)
return r.store.Set(context.Background(), fmt.Sprintf(KeyDetails, key.ID), keyJSON, 0)
}
func (r *gormKeyRepository) removeStoreCacheForKey(key *models.APIKey) error {
groupIDs, err := r.GetGroupsForKey(key.ID)
func (r *gormKeyRepository) removeStoreCacheForKey(ctx context.Context, key *models.APIKey) error {
groupIDs, err := r.GetGroupsForKey(ctx, key.ID)
if err != nil {
r.logger.Warnf("failed to get groups for key %d to clean up cache lists: %v", key.ID, err)
}
pipe := r.store.Pipeline()
pipe := r.store.Pipeline(ctx)
pipe.Del(fmt.Sprintf(KeyDetails, key.ID))
for _, groupID := range groupIDs {
@@ -144,13 +145,13 @@ func (r *gormKeyRepository) removeStoreCacheForKey(key *models.APIKey) error {
pipe.LRem(fmt.Sprintf(KeyGroupSequential, groupID), 0, keyIDStr)
pipe.SRem(fmt.Sprintf(KeyGroupRandomMain, groupID), keyIDStr)
pipe.SRem(fmt.Sprintf(KeyGroupRandomCooldown, groupID), keyIDStr)
go r.store.ZRem(fmt.Sprintf(KeyGroupLRU, groupID), keyIDStr)
go r.store.ZRem(context.Background(), fmt.Sprintf(KeyGroupLRU, groupID), keyIDStr)
}
return pipe.Exec()
}
func (r *gormKeyRepository) updateStoreCacheForMapping(mapping *models.GroupAPIKeyMapping) error {
pipe := r.store.Pipeline()
pipe := r.store.Pipeline(context.Background())
activeKeyListKey := fmt.Sprintf("group:%d:keys:active", mapping.KeyGroupID)
pipe.LRem(activeKeyListKey, 0, mapping.APIKeyID)
if mapping.Status == models.StatusActive {
@@ -159,7 +160,7 @@ func (r *gormKeyRepository) updateStoreCacheForMapping(mapping *models.GroupAPIK
return pipe.Exec()
}
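
Pipelines are bound to a context once, at creation: r.store.Pipeline() becomes r.store.Pipeline(ctx), and the queued commands keep their context-free signatures until Exec. The implied surface, sketched from the calls in this file (return shapes are assumptions):

	package repository

	import (
		"context"
		"time"
	)

	// Pipeliner, as implied by this file: commands queue without a
	// per-call context and Exec runs the batch under the context
	// captured by Pipeline(ctx).
	type Pipeliner interface {
		SAdd(key string, members ...interface{})
		SRem(key string, members ...interface{})
		LPush(key string, values ...interface{})
		LRem(key string, count int64, value interface{})
		Del(keys ...string)
		Expire(key string, ttl time.Duration)
		Exec() error
	}

	type pipelineOpener interface {
		Pipeline(ctx context.Context) Pipeliner
	}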
func (r *gormKeyRepository) HandleCacheUpdateEventBatch(mappings []*models.GroupAPIKeyMapping) error {
func (r *gormKeyRepository) HandleCacheUpdateEventBatch(ctx context.Context, mappings []*models.GroupAPIKeyMapping) error {
if len(mappings) == 0 {
return nil
}
@@ -184,7 +185,7 @@ func (r *gormKeyRepository) HandleCacheUpdateEventBatch(mappings []*models.Group
}
groupUpdates[mapping.KeyGroupID] = update
}
pipe := r.store.Pipeline()
pipe := r.store.Pipeline(context.Background())
var pipelineError error
for groupID, updates := range groupUpdates {
activeKeyListKey := fmt.Sprintf("group:%d:keys:active", groupID)

View File

@@ -7,6 +7,7 @@ import (
"fmt"
"gemini-balancer/internal/models"
"context"
"math/rand"
"strings"
"time"
@@ -115,7 +116,7 @@ func (r *gormKeyRepository) Update(key *models.APIKey) error {
}
func (r *gormKeyRepository) HardDeleteByID(id uint) error {
key, err := r.GetKeyByID(id) // This now returns a decrypted key
key, err := r.GetKeyByID(id)
if err != nil {
return err
}
@@ -125,7 +126,7 @@ func (r *gormKeyRepository) HardDeleteByID(id uint) error {
if err != nil {
return err
}
if err := r.removeStoreCacheForKey(key); err != nil {
if err := r.removeStoreCacheForKey(context.Background(), key); err != nil {
r.logger.Warnf("DB deleted key ID %d, but cache removal failed: %v", id, err)
}
return nil
@@ -140,16 +141,13 @@ func (r *gormKeyRepository) HardDeleteByValues(keyValues []string) (int64, error
hash := sha256.Sum256([]byte(v))
hashes[i] = hex.EncodeToString(hash[:])
}
// Find the full key objects first to update the cache later.
var keysToDelete []models.APIKey
// [MODIFIED] Find by hash.
if err := r.db.Where("api_key_hash IN ?", hashes).Find(&keysToDelete).Error; err != nil {
return 0, err
}
if len(keysToDelete) == 0 {
return 0, nil
}
// Decrypt them to ensure cache has plaintext if needed.
if err := r.decryptKeys(keysToDelete); err != nil {
r.logger.Warnf("Decryption failed for keys before hard delete, cache removal may be impacted: %v", err)
}
@@ -167,7 +165,7 @@ func (r *gormKeyRepository) HardDeleteByValues(keyValues []string) (int64, error
return 0, err
}
for i := range keysToDelete {
if err := r.removeStoreCacheForKey(&keysToDelete[i]); err != nil {
if err := r.removeStoreCacheForKey(context.Background(), &keysToDelete[i]); err != nil {
r.logger.Warnf("DB deleted key ID %d, but cache removal failed: %v", keysToDelete[i].ID, err)
}
}

View File

@@ -2,6 +2,7 @@
package repository
import (
"context"
"crypto/sha256"
"encoding/hex"
"gemini-balancer/internal/models"
@@ -110,13 +111,13 @@ func (r *gormKeyRepository) deleteOrphanKeysLogic(db *gorm.DB) (int64, error) {
}
result := db.Delete(&models.APIKey{}, orphanKeyIDs)
//result := db.Unscoped().Delete(&models.APIKey{}, orphanKeyIDs)
if result.Error != nil {
return 0, result.Error
}
for i := range keysToDelete {
if err := r.removeStoreCacheForKey(&keysToDelete[i]); err != nil {
// [Fix] Call the updated cache-cleanup function with context.Background()
if err := r.removeStoreCacheForKey(context.Background(), &keysToDelete[i]); err != nil {
r.logger.Warnf("DB deleted orphan key ID %d, but cache removal failed: %v", keysToDelete[i].ID, err)
}
}
@@ -144,7 +145,7 @@ func (r *gormKeyRepository) GetActiveMasterKeys() ([]*models.APIKey, error) {
return keys, nil
}
func (r *gormKeyRepository) UpdateAPIKeyStatus(keyID uint, status models.MasterAPIKeyStatus) error {
func (r *gormKeyRepository) UpdateAPIKeyStatus(ctx context.Context, keyID uint, status models.MasterAPIKeyStatus) error {
err := r.executeTransactionWithRetry(func(tx *gorm.DB) error {
result := tx.Model(&models.APIKey{}).
Where("id = ?", keyID).
@@ -160,7 +161,7 @@ func (r *gormKeyRepository) UpdateAPIKeyStatus(keyID uint, status models.MasterA
if err == nil {
r.logger.Infof("MasterStatus for key ID %d changed, triggering a full cache reload.", keyID)
go func() {
if err := r.LoadAllKeysToStore(); err != nil {
if err := r.LoadAllKeysToStore(context.Background()); err != nil {
r.logger.Errorf("Failed to reload cache after MasterStatus change for key ID %d: %v", keyID, err)
}
}()

View File

@@ -2,6 +2,7 @@
package repository
import (
"context"
"crypto/sha256"
"encoding/hex"
"errors"
@@ -14,7 +15,7 @@ import (
"gorm.io/gorm/clause"
)
func (r *gormKeyRepository) LinkKeysToGroup(groupID uint, keyIDs []uint) error {
func (r *gormKeyRepository) LinkKeysToGroup(ctx context.Context, groupID uint, keyIDs []uint) error {
if len(keyIDs) == 0 {
return nil
}
@@ -34,12 +35,12 @@ func (r *gormKeyRepository) LinkKeysToGroup(groupID uint, keyIDs []uint) error {
}
for _, keyID := range keyIDs {
r.store.SAdd(fmt.Sprintf("key:%d:groups", keyID), groupID)
r.store.SAdd(context.Background(), fmt.Sprintf("key:%d:groups", keyID), groupID)
}
return nil
}
func (r *gormKeyRepository) UnlinkKeysFromGroup(groupID uint, keyIDs []uint) (int64, error) {
func (r *gormKeyRepository) UnlinkKeysFromGroup(ctx context.Context, groupID uint, keyIDs []uint) (int64, error) {
if len(keyIDs) == 0 {
return 0, nil
}
@@ -63,16 +64,16 @@ func (r *gormKeyRepository) UnlinkKeysFromGroup(groupID uint, keyIDs []uint) (in
activeKeyListKey := fmt.Sprintf("group:%d:keys:active", groupID)
for _, keyID := range keyIDs {
r.store.SRem(fmt.Sprintf("key:%d:groups", keyID), groupID)
r.store.LRem(activeKeyListKey, 0, strconv.Itoa(int(keyID)))
r.store.SRem(context.Background(), fmt.Sprintf("key:%d:groups", keyID), groupID)
r.store.LRem(context.Background(), activeKeyListKey, 0, strconv.Itoa(int(keyID)))
}
return unlinkedCount, nil
}
func (r *gormKeyRepository) GetGroupsForKey(keyID uint) ([]uint, error) {
func (r *gormKeyRepository) GetGroupsForKey(ctx context.Context, keyID uint) ([]uint, error) {
cacheKey := fmt.Sprintf("key:%d:groups", keyID)
strGroupIDs, err := r.store.SMembers(cacheKey)
strGroupIDs, err := r.store.SMembers(context.Background(), cacheKey)
if err != nil || len(strGroupIDs) == 0 {
var groupIDs []uint
dbErr := r.db.Table("group_api_key_mappings").Where("api_key_id = ?", keyID).Pluck("key_group_id", &groupIDs).Error
@@ -84,7 +85,7 @@ func (r *gormKeyRepository) GetGroupsForKey(keyID uint) ([]uint, error) {
for _, id := range groupIDs {
interfaceSlice = append(interfaceSlice, id)
}
r.store.SAdd(cacheKey, interfaceSlice...)
r.store.SAdd(context.Background(), cacheKey, interfaceSlice...)
}
return groupIDs, nil
}
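
GetGroupsForKey above is a textbook cache-aside read: try the set in the store, fall back to the database on a miss or error, then backfill the set so the next read is cheap. The shape, condensed (getOrLoad is a hypothetical helper, not repo code):

	package repository

	import "context"

	// getOrLoad sketches the cache-aside pattern of GetGroupsForKey:
	// cache first, source of truth on miss, best-effort backfill after.
	func getOrLoad(
		ctx context.Context,
		fromCache func(context.Context) ([]uint, error),
		fromDB func() ([]uint, error),
		backfill func(context.Context, []uint),
	) ([]uint, error) {
		if ids, err := fromCache(ctx); err == nil && len(ids) > 0 {
			return ids, nil
		}
		ids, err := fromDB()
		if err != nil {
			return nil, err
		}
		// Best-effort: a backfill failure only costs the next read a DB hit.
		backfill(ctx, ids)
		return ids, nil
	}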
@@ -103,7 +104,7 @@ func (r *gormKeyRepository) GetMapping(groupID, keyID uint) (*models.GroupAPIKey
return &mapping, err
}
func (r *gormKeyRepository) UpdateMapping(mapping *models.GroupAPIKeyMapping) error {
func (r *gormKeyRepository) UpdateMapping(ctx context.Context, mapping *models.GroupAPIKeyMapping) error {
err := r.executeTransactionWithRetry(func(tx *gorm.DB) error {
return tx.Save(mapping).Error
})

View File

@@ -1,7 +1,8 @@
// Filename: internal/repository/key_selector.go
// Filename: internal/repository/key_selector.go (final fix after review)
package repository
import (
"context"
"crypto/sha1"
"encoding/json"
"errors"
@@ -23,19 +24,18 @@ const (
)
// SelectOneActiveKey efficiently selects one available API key from the cache according to the configured polling strategy.
func (r *gormKeyRepository) SelectOneActiveKey(group *models.KeyGroup) (*models.APIKey, *models.GroupAPIKeyMapping, error) {
func (r *gormKeyRepository) SelectOneActiveKey(ctx context.Context, group *models.KeyGroup) (*models.APIKey, *models.GroupAPIKeyMapping, error) {
var keyIDStr string
var err error
switch group.PollingStrategy {
case models.StrategySequential:
sequentialKey := fmt.Sprintf(KeyGroupSequential, group.ID)
keyIDStr, err = r.store.Rotate(sequentialKey)
keyIDStr, err = r.store.Rotate(ctx, sequentialKey)
case models.StrategyWeighted:
lruKey := fmt.Sprintf(KeyGroupLRU, group.ID)
results, zerr := r.store.ZRange(lruKey, 0, 0)
results, zerr := r.store.ZRange(ctx, lruKey, 0, 0)
if zerr == nil && len(results) > 0 {
keyIDStr = results[0]
}
@@ -44,11 +44,11 @@ func (r *gormKeyRepository) SelectOneActiveKey(group *models.KeyGroup) (*models.
case models.StrategyRandom:
mainPoolKey := fmt.Sprintf(KeyGroupRandomMain, group.ID)
cooldownPoolKey := fmt.Sprintf(KeyGroupRandomCooldown, group.ID)
keyIDStr, err = r.store.PopAndCycleSetMember(mainPoolKey, cooldownPoolKey)
keyIDStr, err = r.store.PopAndCycleSetMember(ctx, mainPoolKey, cooldownPoolKey)
default: // with no strategy (or an unknown one) specified, fall back to basic random selection
activeKeySetKey := fmt.Sprintf(KeyGroup, group.ID)
keyIDStr, err = r.store.SRandMember(activeKeySetKey)
keyIDStr, err = r.store.SRandMember(ctx, activeKeySetKey)
}
if err != nil {
@@ -65,27 +65,25 @@ func (r *gormKeyRepository) SelectOneActiveKey(group *models.KeyGroup) (*models.
keyID, _ := strconv.ParseUint(keyIDStr, 10, 64)
apiKey, mapping, err := r.getKeyDetailsFromCache(uint(keyID), group.ID)
apiKey, mapping, err := r.getKeyDetailsFromCache(ctx, uint(keyID), group.ID)
if err != nil {
r.logger.WithError(err).Warnf("Cache inconsistency: Failed to get details for selected key ID %d", keyID)
// TODO: retry logic could be added here by calling SelectOneActiveKey(group) again
return nil, nil, err
}
if group.PollingStrategy == models.StrategyWeighted {
go r.UpdateKeyUsageTimestamp(group.ID, uint(keyID))
go r.UpdateKeyUsageTimestamp(context.Background(), group.ID, uint(keyID))
}
return apiKey, mapping, nil
}
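
The weighted strategy above is an LRU realized with a sorted set: every use writes the key ID with the current Unix-millisecond timestamp as its score, so ZRange(lruKey, 0, 0) yields the member with the smallest score, i.e. the key idle the longest. A minimal sketch of that two-call cycle (signatures assumed from the calls in this diff):

	package repository

	import (
		"context"
		"time"
	)

	type lruStore interface {
		ZAdd(ctx context.Context, key string, members map[string]float64) error
		ZRange(ctx context.Context, key string, start, stop int64) ([]string, error)
	}

	// pickLRU returns the least-recently-used member and touches it so
	// it moves to the back of the queue.
	func pickLRU(ctx context.Context, s lruStore, lruKey string) (string, error) {
		ids, err := s.ZRange(ctx, lruKey, 0, 0) // smallest score = oldest use
		if err != nil || len(ids) == 0 {
			return "", err
		}
		_ = s.ZAdd(ctx, lruKey, map[string]float64{
			ids[0]: float64(time.Now().UnixMilli()),
		})
		return ids[0], nil
	}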
// SelectOneActiveKeyFromBasePool is a new poller designed for the smart aggregation mode.
func (r *gormKeyRepository) SelectOneActiveKeyFromBasePool(pool *BasePool) (*models.APIKey, *models.KeyGroup, error) {
// Generate a unique pool ID so the polling state of different request combinations stays isolated.
func (r *gormKeyRepository) SelectOneActiveKeyFromBasePool(ctx context.Context, pool *BasePool) (*models.APIKey, *models.KeyGroup, error) {
poolID := generatePoolID(pool.CandidateGroups)
log := r.logger.WithField("pool_id", poolID)
if err := r.ensureBasePoolCacheExists(pool, poolID); err != nil {
if err := r.ensureBasePoolCacheExists(ctx, pool, poolID); err != nil {
log.WithError(err).Error("Failed to ensure BasePool cache exists.")
return nil, nil, err
}
@@ -96,10 +94,10 @@ func (r *gormKeyRepository) SelectOneActiveKeyFromBasePool(pool *BasePool) (*mod
switch pool.PollingStrategy {
case models.StrategySequential:
sequentialKey := fmt.Sprintf(BasePoolSequential, poolID)
keyIDStr, err = r.store.Rotate(sequentialKey)
keyIDStr, err = r.store.Rotate(ctx, sequentialKey)
case models.StrategyWeighted:
lruKey := fmt.Sprintf(BasePoolLRU, poolID)
results, zerr := r.store.ZRange(lruKey, 0, 0)
results, zerr := r.store.ZRange(ctx, lruKey, 0, 0)
if zerr == nil && len(results) > 0 {
keyIDStr = results[0]
}
@@ -107,12 +105,11 @@ func (r *gormKeyRepository) SelectOneActiveKeyFromBasePool(pool *BasePool) (*mod
case models.StrategyRandom:
mainPoolKey := fmt.Sprintf(BasePoolRandomMain, poolID)
cooldownPoolKey := fmt.Sprintf(BasePoolRandomCooldown, poolID)
keyIDStr, err = r.store.PopAndCycleSetMember(mainPoolKey, cooldownPoolKey)
default: // the default strategy should be handled in ensureCache; this is a fallback
keyIDStr, err = r.store.PopAndCycleSetMember(ctx, mainPoolKey, cooldownPoolKey)
default:
log.Warnf("Default polling strategy triggered inside selection. This should be rare.")
sequentialKey := fmt.Sprintf(BasePoolSequential, poolID)
keyIDStr, err = r.store.LIndex(sequentialKey, 0)
keyIDStr, err = r.store.LIndex(ctx, sequentialKey, 0)
}
if err != nil {
@@ -128,12 +125,10 @@ func (r *gormKeyRepository) SelectOneActiveKeyFromBasePool(pool *BasePool) (*mod
keyID, _ := strconv.ParseUint(keyIDStr, 10, 64)
for _, group := range pool.CandidateGroups {
apiKey, mapping, cacheErr := r.getKeyDetailsFromCache(uint(keyID), group.ID)
apiKey, mapping, cacheErr := r.getKeyDetailsFromCache(ctx, uint(keyID), group.ID)
if cacheErr == nil && apiKey != nil && mapping != nil {
if pool.PollingStrategy == models.StrategyWeighted {
go r.updateKeyUsageTimestampForPool(poolID, uint(keyID))
go r.updateKeyUsageTimestampForPool(context.Background(), poolID, uint(keyID))
}
return apiKey, group, nil
}
@@ -144,42 +139,39 @@ func (r *gormKeyRepository) SelectOneActiveKeyFromBasePool(pool *BasePool) (*mod
}
// ensureBasePoolCacheExists dynamically creates the Redis structures for a BasePool.
func (r *gormKeyRepository) ensureBasePoolCacheExists(pool *BasePool, poolID string) error {
func (r *gormKeyRepository) ensureBasePoolCacheExists(ctx context.Context, pool *BasePool, poolID string) error {
listKey := fmt.Sprintf(BasePoolSequential, poolID)
// --- [Logic improvement] Handle the "poison pill" up front to keep the flow clear ---
exists, err := r.store.Exists(listKey)
exists, err := r.store.Exists(ctx, listKey)
if err != nil {
r.logger.WithError(err).Errorf("Failed to check existence for pool_id '%s'", poolID)
return err // surface the read error directly
return err
}
if exists {
val, err := r.store.LIndex(listKey, 0)
val, err := r.store.LIndex(ctx, listKey, 0)
if err != nil {
// If even LIndex fails, the cache is likely corrupted; allow a rebuild.
r.logger.WithError(err).Warnf("Cache for pool_id '%s' exists but is unreadable. Forcing rebuild.", poolID)
} else {
if val == EmptyPoolPlaceholder {
return gorm.ErrRecordNotFound // known to be empty; return immediately
return gorm.ErrRecordNotFound
}
return nil // cache is valid; return immediately
return nil
}
}
// --- [Locking improvement] Add a distributed lock to prevent a thundering herd during concurrent builds ---
lockKey := fmt.Sprintf("lock:basepool:%s", poolID)
acquired, err := r.store.SetNX(lockKey, []byte("1"), 10*time.Second) // 10-second lock timeout
acquired, err := r.store.SetNX(ctx, lockKey, []byte("1"), 10*time.Second)
if err != nil {
r.logger.WithError(err).Error("Failed to attempt acquiring distributed lock for basepool build.")
return err
}
if !acquired {
// Lock not acquired: wait briefly and retry, letting the goroutine holding the lock finish the build.
time.Sleep(100 * time.Millisecond)
return r.ensureBasePoolCacheExists(pool, poolID)
return r.ensureBasePoolCacheExists(ctx, pool, poolID)
}
defer r.store.Del(lockKey) // make sure the lock is released when the function exits
// Double-check: another goroutine may have finished the build while we were acquiring the lock.
if exists, _ := r.store.Exists(listKey); exists {
defer r.store.Del(context.Background(), lockKey)
if exists, _ := r.store.Exists(ctx, listKey); exists {
return nil
}
r.logger.Infof("BasePool cache for pool_id '%s' not found or is unreadable. Building now...", poolID)
@@ -187,22 +179,15 @@ func (r *gormKeyRepository) ensureBasePoolCacheExists(pool *BasePool, poolID str
lruMembers := make(map[string]float64)
for _, group := range pool.CandidateGroups {
activeKeySetKey := fmt.Sprintf(KeyGroup, group.ID)
groupKeyIDs, err := r.store.SMembers(activeKeySetKey)
// --- [Core fix] ---
// This was the root cause of the whole problem: we must never silently `continue` on a read failure.
// Any failure to read the source data must be treated as a total failure of the build and abort immediately.
groupKeyIDs, err := r.store.SMembers(ctx, activeKeySetKey)
if err != nil {
r.logger.WithError(err).Errorf("FATAL: Failed to read active keys for group %d during BasePool build. Aborting build process for pool_id '%s'.", group.ID, poolID)
// Return this transient error. The current request fails, but no "poison pill" is ever written,
// giving the next request a fresh, clean chance to succeed.
return err
}
// Only proceed when SMembers succeeded.
allActiveKeyIDs = append(allActiveKeyIDs, groupKeyIDs...)
for _, keyIDStr := range groupKeyIDs {
keyID, _ := strconv.ParseUint(keyIDStr, 10, 64)
_, mapping, err := r.getKeyDetailsFromCache(uint(keyID), group.ID)
_, mapping, err := r.getKeyDetailsFromCache(ctx, uint(keyID), group.ID)
if err == nil && mapping != nil {
var score float64
if mapping.LastUsedAt != nil {
@@ -213,12 +198,9 @@ func (r *gormKeyRepository) ensureBasePoolCacheExists(pool *BasePool, poolID str
}
}
// --- [Logic fix] ---
// Only when all data was read successfully and turned out to be genuinely empty
// are we allowed to write the "poison pill".
if len(allActiveKeyIDs) == 0 {
r.logger.Warnf("No active keys found for any candidate groups for pool_id '%s'. Setting empty pool placeholder.", poolID)
pipe := r.store.Pipeline()
pipe := r.store.Pipeline(ctx)
pipe.LPush(listKey, EmptyPoolPlaceholder)
pipe.Expire(listKey, EmptyCacheTTL)
if err := pipe.Exec(); err != nil {
@@ -226,14 +208,10 @@ func (r *gormKeyRepository) ensureBasePoolCacheExists(pool *BasePool, poolID str
}
return gorm.ErrRecordNotFound
}
// Populate all polling structures via a pipeline.
pipe := r.store.Pipeline()
// 1. Sequential
pipe.LPush(fmt.Sprintf(BasePoolSequential, poolID), toInterfaceSlice(allActiveKeyIDs)...)
// 2. Random
pipe.SAdd(fmt.Sprintf(BasePoolRandomMain, poolID), toInterfaceSlice(allActiveKeyIDs)...)
// Set a reasonable TTL (e.g. 5 minutes) to prevent orphaned data.
pipe := r.store.Pipeline(ctx)
pipe.LPush(fmt.Sprintf(BasePoolSequential, poolID), toInterfaceSlice(allActiveKeyIDs)...)
pipe.SAdd(fmt.Sprintf(BasePoolRandomMain, poolID), toInterfaceSlice(allActiveKeyIDs)...)
pipe.Expire(fmt.Sprintf(BasePoolSequential, poolID), CacheTTL)
pipe.Expire(fmt.Sprintf(BasePoolRandomMain, poolID), CacheTTL)
pipe.Expire(fmt.Sprintf(BasePoolRandomCooldown, poolID), CacheTTL)
@@ -244,17 +222,22 @@ func (r *gormKeyRepository) ensureBasePoolCacheExists(pool *BasePool, poolID str
}
if len(lruMembers) > 0 {
r.store.ZAdd(fmt.Sprintf(BasePoolLRU, poolID), lruMembers)
if err := r.store.ZAdd(ctx, fmt.Sprintf(BasePoolLRU, poolID), lruMembers); err != nil {
r.logger.WithError(err).Warnf("Failed to populate LRU cache for pool_id '%s'", poolID)
}
}
return nil
}
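
ensureBasePoolCacheExists combines three guards: a fast existence check, a SetNX lock with a TTL, and a double-check after the lock is acquired. The one subtlety is the sleep-and-recurse retry while the lock is held; an iterative equivalent with a retry cap looks like this (sketch only; the lock key, TTL, and backoff match the diff, the cap is an assumption):

	package repository

	import (
		"context"
		"errors"
		"fmt"
		"time"
	)

	var errBuildContended = errors.New("basepool build lock contended")

	// cacheLocker stands in for the store's SetNX/Del as used above.
	type cacheLocker interface {
		SetNX(ctx context.Context, key string, value []byte, ttl time.Duration) (bool, error)
		Del(ctx context.Context, key string) error
	}

	// buildOnce restates the lock discipline iteratively:
	// check, lock, double-check, build.
	func buildOnce(ctx context.Context, l cacheLocker, poolID string, exists func() (bool, error), build func() error) error {
		lockKey := fmt.Sprintf("lock:basepool:%s", poolID)
		for attempt := 0; attempt < 50; attempt++ { // ~5s worst case at 100ms steps
			if ok, err := exists(); err != nil || ok {
				return err
			}
			acquired, err := l.SetNX(ctx, lockKey, []byte("1"), 10*time.Second)
			if err != nil {
				return err
			}
			if !acquired {
				time.Sleep(100 * time.Millisecond) // holder is building; retry
				continue
			}
			defer l.Del(context.Background(), lockKey) // release on exit
			if ok, _ := exists(); ok {
				return nil // double-check: someone else finished the build
			}
			return build()
		}
		return errBuildContended
	}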
// updateKeyUsageTimestampForPool updates the BasePool's LRU ZSET.
func (r *gormKeyRepository) updateKeyUsageTimestampForPool(poolID string, keyID uint) {
func (r *gormKeyRepository) updateKeyUsageTimestampForPool(ctx context.Context, poolID string, keyID uint) {
lruKey := fmt.Sprintf(BasePoolLRU, poolID)
r.store.ZAdd(lruKey, map[string]float64{
err := r.store.ZAdd(ctx, lruKey, map[string]float64{
strconv.FormatUint(uint64(keyID), 10): nowMilli(),
})
if err != nil {
r.logger.WithError(err).Warnf("Failed to update key usage for pool %s", poolID)
}
}
// generatePoolID derives a stable, unique string ID from the list of candidate group IDs.
@@ -285,8 +268,8 @@ func nowMilli() float64 {
}
// getKeyDetailsFromCache fetches the key's and mapping's JSON data from the cache.
func (r *gormKeyRepository) getKeyDetailsFromCache(keyID, groupID uint) (*models.APIKey, *models.GroupAPIKeyMapping, error) {
apiKeyJSON, err := r.store.Get(fmt.Sprintf(KeyDetails, keyID))
func (r *gormKeyRepository) getKeyDetailsFromCache(ctx context.Context, keyID, groupID uint) (*models.APIKey, *models.GroupAPIKeyMapping, error) {
apiKeyJSON, err := r.store.Get(ctx, fmt.Sprintf(KeyDetails, keyID))
if err != nil {
return nil, nil, fmt.Errorf("failed to get key details for key %d: %w", keyID, err)
}
@@ -295,7 +278,7 @@ func (r *gormKeyRepository) getKeyDetailsFromCache(keyID, groupID uint) (*models
return nil, nil, fmt.Errorf("failed to unmarshal api key %d: %w", keyID, err)
}
mappingJSON, err := r.store.Get(fmt.Sprintf(KeyMapping, groupID, keyID))
mappingJSON, err := r.store.Get(ctx, fmt.Sprintf(KeyMapping, groupID, keyID))
if err != nil {
return nil, nil, fmt.Errorf("failed to get mapping details for key %d in group %d: %w", keyID, groupID, err)
}

View File

@@ -1,7 +1,9 @@
// Filename: internal/repository/key_writer.go
package repository
import (
"context"
"fmt"
"gemini-balancer/internal/errors"
"gemini-balancer/internal/models"
@@ -9,7 +11,7 @@ import (
"time"
)
func (r *gormKeyRepository) UpdateKeyUsageTimestamp(groupID, keyID uint) {
func (r *gormKeyRepository) UpdateKeyUsageTimestamp(ctx context.Context, groupID, keyID uint) {
lruKey := fmt.Sprintf(KeyGroupLRU, groupID)
timestamp := float64(time.Now().UnixMilli())
@@ -17,52 +19,51 @@ func (r *gormKeyRepository) UpdateKeyUsageTimestamp(groupID, keyID uint) {
strconv.FormatUint(uint64(keyID), 10): timestamp,
}
if err := r.store.ZAdd(lruKey, members); err != nil {
if err := r.store.ZAdd(ctx, lruKey, members); err != nil {
r.logger.WithError(err).Warnf("Failed to update usage timestamp for key %d in group %d", keyID, groupID)
}
}
func (r *gormKeyRepository) SyncKeyStatusInPollingCaches(groupID, keyID uint, newStatus models.APIKeyStatus) {
func (r *gormKeyRepository) SyncKeyStatusInPollingCaches(ctx context.Context, groupID, keyID uint, newStatus models.APIKeyStatus) {
r.logger.Infof("SYNC: Directly updating polling caches for G:%d K:%d -> %s", groupID, keyID, newStatus)
r.updatePollingCachesLogic(groupID, keyID, newStatus)
r.updatePollingCachesLogic(ctx, groupID, keyID, newStatus)
}
func (r *gormKeyRepository) HandleCacheUpdateEvent(groupID, keyID uint, newStatus models.APIKeyStatus) {
func (r *gormKeyRepository) HandleCacheUpdateEvent(ctx context.Context, groupID, keyID uint, newStatus models.APIKeyStatus) {
r.logger.Infof("EVENT: Updating polling caches for G:%d K:%d -> %s from an event", groupID, keyID, newStatus)
r.updatePollingCachesLogic(groupID, keyID, newStatus)
r.updatePollingCachesLogic(ctx, groupID, keyID, newStatus)
}
func (r *gormKeyRepository) updatePollingCachesLogic(groupID, keyID uint, newStatus models.APIKeyStatus) {
func (r *gormKeyRepository) updatePollingCachesLogic(ctx context.Context, groupID, keyID uint, newStatus models.APIKeyStatus) {
keyIDStr := strconv.FormatUint(uint64(keyID), 10)
sequentialKey := fmt.Sprintf(KeyGroupSequential, groupID)
lruKey := fmt.Sprintf(KeyGroupLRU, groupID)
mainPoolKey := fmt.Sprintf(KeyGroupRandomMain, groupID)
cooldownPoolKey := fmt.Sprintf(KeyGroupRandomCooldown, groupID)
_ = r.store.LRem(sequentialKey, 0, keyIDStr)
_ = r.store.ZRem(lruKey, keyIDStr)
_ = r.store.SRem(mainPoolKey, keyIDStr)
_ = r.store.SRem(cooldownPoolKey, keyIDStr)
_ = r.store.LRem(ctx, sequentialKey, 0, keyIDStr)
_ = r.store.ZRem(ctx, lruKey, keyIDStr)
_ = r.store.SRem(ctx, mainPoolKey, keyIDStr)
_ = r.store.SRem(ctx, cooldownPoolKey, keyIDStr)
if newStatus == models.StatusActive {
if err := r.store.LPush(sequentialKey, keyIDStr); err != nil {
if err := r.store.LPush(ctx, sequentialKey, keyIDStr); err != nil {
r.logger.WithError(err).Warnf("Failed to add key %d to sequential list for group %d", keyID, groupID)
}
members := map[string]float64{keyIDStr: 0}
if err := r.store.ZAdd(lruKey, members); err != nil {
if err := r.store.ZAdd(ctx, lruKey, members); err != nil {
r.logger.WithError(err).Warnf("Failed to add key %d to LRU zset for group %d", keyID, groupID)
}
if err := r.store.SAdd(mainPoolKey, keyIDStr); err != nil {
if err := r.store.SAdd(ctx, mainPoolKey, keyIDStr); err != nil {
r.logger.WithError(err).Warnf("Failed to add key %d to random main pool for group %d", keyID, groupID)
}
}
}
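If the store is Redis-backed (an assumption — go-redis v9 and the key names below are used purely for illustration), the remove-then-conditionally-re-add sequence above translates to roughly:

package pools

import (
	"context"
	"fmt"

	"github.com/redis/go-redis/v9"
)

// resetKeyInPools mirrors updatePollingCachesLogic against a raw go-redis v9
// client: remove the key ID from every polling structure, then re-add it only
// if the key is active, so the whole operation stays idempotent.
func resetKeyInPools(ctx context.Context, rdb *redis.Client, groupID uint, keyID string, active bool) error {
	seq := fmt.Sprintf("group:%d:sequential", groupID) // hypothetical key names
	lru := fmt.Sprintf("group:%d:lru", groupID)
	mainPool := fmt.Sprintf("group:%d:random:main", groupID)

	rdb.LRem(ctx, seq, 0, keyID)
	rdb.ZRem(ctx, lru, keyID)
	rdb.SRem(ctx, mainPool, keyID)

	if !active {
		return nil
	}
	if err := rdb.LPush(ctx, seq, keyID).Err(); err != nil {
		return err
	}
	if err := rdb.ZAdd(ctx, lru, redis.Z{Score: 0, Member: keyID}).Err(); err != nil {
		return err
	}
	return rdb.SAdd(ctx, mainPool, keyID).Err()
}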
// UpdateKeyStatusAfterRequest is the new central hub for handling feedback.
func (r *gormKeyRepository) UpdateKeyStatusAfterRequest(group *models.KeyGroup, key *models.APIKey, success bool, apiErr *errors.APIError) {
func (r *gormKeyRepository) UpdateKeyStatusAfterRequest(ctx context.Context, group *models.KeyGroup, key *models.APIKey, success bool, apiErr *errors.APIError) {
if success {
if group.PollingStrategy == models.StrategyWeighted {
go r.UpdateKeyUsageTimestamp(group.ID, key.ID)
go r.UpdateKeyUsageTimestamp(context.Background(), group.ID, key.ID)
}
return
}
@@ -72,6 +73,5 @@ func (r *gormKeyRepository) UpdateKeyStatusAfterRequest(group *models.KeyGroup,
}
r.logger.Warnf("Request failed for KeyID %d in GroupID %d with error: %s. Temporarily removing from active polling caches.", key.ID, group.ID, apiErr.Message)
// This call is correct. It uses the synchronous, direct method.
r.SyncKeyStatusInPollingCaches(group.ID, key.ID, models.StatusCooldown)
r.SyncKeyStatusInPollingCaches(ctx, group.ID, key.ID, models.StatusCooldown)
}
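Note the context.Background() handed to the goroutine above: if the fire-and-forget write inherited the request context, it would be canceled the moment the handler returned. A minimal sketch of that design choice (the helper and its 5-second timeout are illustrative, not from the source):

package detach

import (
	"context"
	"time"
)

// FireAndForget runs a write on a context deliberately detached from any
// request context, optionally bounded by its own timeout, so the background
// work survives the handler returning.
func FireAndForget(write func(context.Context) error) {
	go func() {
		ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) // illustrative bound
		defer cancel()
		_ = write(ctx)
	}()
}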

View File

@@ -1,7 +1,8 @@
// Filename: internal/repository/repository.go
// Filename: internal/repository/repository.go (final version after review)
package repository
import (
"context"
"gemini-balancer/internal/crypto"
"gemini-balancer/internal/errors"
"gemini-balancer/internal/models"
@@ -22,8 +23,8 @@ type BasePool struct {
type KeyRepository interface {
// --- Core key selection & scheduling --- key_selector
SelectOneActiveKey(group *models.KeyGroup) (*models.APIKey, *models.GroupAPIKeyMapping, error)
SelectOneActiveKeyFromBasePool(pool *BasePool) (*models.APIKey, *models.KeyGroup, error)
SelectOneActiveKey(ctx context.Context, group *models.KeyGroup) (*models.APIKey, *models.GroupAPIKeyMapping, error)
SelectOneActiveKeyFromBasePool(ctx context.Context, pool *BasePool) (*models.APIKey, *models.KeyGroup, error)
// --- Encryption & decryption --- key_crud
Decrypt(key *models.APIKey) error
@@ -37,16 +38,16 @@ type KeyRepository interface {
GetKeyByID(id uint) (*models.APIKey, error)
GetKeyByValue(keyValue string) (*models.APIKey, error)
GetKeysByValues(keyValues []string) ([]models.APIKey, error)
GetKeysByIDs(ids []uint) ([]models.APIKey, error) // [NEW] Batch-fetch keys by a set of primary-key IDs
GetKeysByIDs(ids []uint) ([]models.APIKey, error)
GetKeysByGroup(groupID uint) ([]models.APIKey, error)
CountByGroup(groupID uint) (int64, error)
// --- Many-to-many relationship management --- key_mapping
LinkKeysToGroup(groupID uint, keyIDs []uint) error
UnlinkKeysFromGroup(groupID uint, keyIDs []uint) (unlinkedCount int64, err error)
GetGroupsForKey(keyID uint) ([]uint, error)
LinkKeysToGroup(ctx context.Context, groupID uint, keyIDs []uint) error
UnlinkKeysFromGroup(ctx context.Context, groupID uint, keyIDs []uint) (unlinkedCount int64, err error)
GetGroupsForKey(ctx context.Context, keyID uint) ([]uint, error)
GetMapping(groupID, keyID uint) (*models.GroupAPIKeyMapping, error)
UpdateMapping(mapping *models.GroupAPIKeyMapping) error
UpdateMapping(ctx context.Context, mapping *models.GroupAPIKeyMapping) error
GetPaginatedKeysAndMappingsByGroup(params *models.APIKeyQueryParams) ([]*models.APIKeyDetails, int64, error)
GetKeysByValuesAndGroupID(values []string, groupID uint) ([]models.APIKey, error)
FindKeyValuesByStatus(groupID uint, statuses []string) ([]string, error)
@@ -55,8 +56,8 @@ type KeyRepository interface {
UpdateMappingWithoutCache(mapping *models.GroupAPIKeyMapping) error
// --- Cache management --- key_cache
LoadAllKeysToStore() error
HandleCacheUpdateEventBatch(mappings []*models.GroupAPIKeyMapping) error
LoadAllKeysToStore(ctx context.Context) error
HandleCacheUpdateEventBatch(ctx context.Context, mappings []*models.GroupAPIKeyMapping) error
// --- Maintenance & background tasks --- key_maintenance
StreamKeysToWriter(groupID uint, statusFilter string, writer io.Writer) error
@@ -65,16 +66,14 @@ type KeyRepository interface {
DeleteOrphanKeys() (int64, error)
DeleteOrphanKeysTx(tx *gorm.DB) (int64, error)
GetActiveMasterKeys() ([]*models.APIKey, error)
UpdateAPIKeyStatus(keyID uint, status models.MasterAPIKeyStatus) error
UpdateAPIKeyStatus(ctx context.Context, keyID uint, status models.MasterAPIKeyStatus) error
HardDeleteSoftDeletedBefore(date time.Time) (int64, error)
// --- "Write" operations for the polling strategies --- key_writer
UpdateKeyUsageTimestamp(groupID, keyID uint)
// Synchronous cache update, for core business logic
SyncKeyStatusInPollingCaches(groupID, keyID uint, newStatus models.APIKeyStatus)
// Asynchronous cache update, for event subscribers
HandleCacheUpdateEvent(groupID, keyID uint, newStatus models.APIKeyStatus)
UpdateKeyStatusAfterRequest(group *models.KeyGroup, key *models.APIKey, success bool, apiErr *errors.APIError)
UpdateKeyUsageTimestamp(ctx context.Context, groupID, keyID uint)
SyncKeyStatusInPollingCaches(ctx context.Context, groupID, keyID uint, newStatus models.APIKeyStatus)
HandleCacheUpdateEvent(ctx context.Context, groupID, keyID uint, newStatus models.APIKeyStatus)
UpdateKeyStatusAfterRequest(ctx context.Context, group *models.KeyGroup, key *models.APIKey, success bool, apiErr *errors.APIError)
}
type GroupRepository interface {

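A caller-side sketch of the new contract — the narrowed interface and wiring below are hypothetical, but they show how a deadline now flows from the edge into every cache and DB call the repository makes:

package callers

import (
	"context"
	"time"
)

// KeySelector is a hypothetical, narrowed view of KeyRepository: just enough
// to show that cancellation and deadlines now reach the storage layer.
type KeySelector interface {
	UpdateKeyUsageTimestamp(ctx context.Context, groupID, keyID uint)
}

func RecordUsage(sel KeySelector, groupID, keyID uint) {
	// The deadline travels with ctx into every store call the method makes.
	ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) // illustrative bound
	defer cancel()
	sel.UpdateKeyUsageTimestamp(ctx, groupID, keyID)
}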
View File

@@ -2,6 +2,7 @@
package scheduler
import (
"context"
"gemini-balancer/internal/repository"
"gemini-balancer/internal/service"
"time"
@@ -15,7 +16,6 @@ type Scheduler struct {
logger *logrus.Entry
statsService *service.StatsService
keyRepo repository.KeyRepository
// healthCheckService *service.HealthCheckService // reserved for health check jobs
}
func NewScheduler(statsSvc *service.StatsService, keyRepo repository.KeyRepository, logger *logrus.Logger) *Scheduler {
@@ -32,11 +32,13 @@ func NewScheduler(statsSvc *service.StatsService, keyRepo repository.KeyReposito
func (s *Scheduler) Start() {
s.logger.Info("Starting scheduler and registering jobs...")
// --- Job registration ---
// Job 1: hourly stats aggregation.
// The CRON expression pins execution to "minute 5 of every hour".
_, err := s.gocronScheduler.Cron("5 * * * *").Tag("stats-aggregation").Do(func() {
s.logger.Info("Executing hourly request stats aggregation...")
if err := s.statsService.AggregateHourlyStats(); err != nil {
// Create a fresh, empty context for the background scheduled job
ctx := context.Background()
if err := s.statsService.AggregateHourlyStats(ctx); err != nil {
s.logger.WithError(err).Error("Hourly stats aggregation failed.")
} else {
s.logger.Info("Hourly stats aggregation completed successfully.")
@@ -46,23 +48,14 @@ func (s *Scheduler) Start() {
s.logger.Errorf("Failed to schedule [stats-aggregation]: %v", err)
}
// Job 2: (reserved) automatic health check (e.g. every 10 minutes)
/*
_, err = s.gocronScheduler.Every(10).Minutes().Tag("auto-health-check").Do(func() {
s.logger.Info("Executing periodic health check for all groups...")
// s.healthCheckService.StartGlobalCheckTask() // pseudocode
})
if err != nil {
s.logger.Errorf("Failed to schedule [auto-health-check]: %v", err)
}
*/
// [NEW] --- Job 3: clean up soft-deleted API keys ---
// Job 2: (reserved) automatic health check
// Job 3: daily cleanup of soft-deleted keys
// Executes once daily at 3:15 AM UTC.
_, err = s.gocronScheduler.Cron("15 3 * * *").Tag("cleanup-soft-deleted-keys").Do(func() {
s.logger.Info("Executing daily cleanup of soft-deleted API keys...")
// Let's assume a retention period of 7 days for now.
// In a real scenario, this should come from settings.
// [Assumes a 7-day retention; in practice this should come from settings]
const retentionDays = 7
count, err := s.keyRepo.HardDeleteSoftDeletedBefore(time.Now().AddDate(0, 0, -retentionDays))
@@ -77,9 +70,8 @@ func (s *Scheduler) Start() {
if err != nil {
s.logger.Errorf("Failed to schedule [cleanup-soft-deleted-keys]: %v", err)
}
// --- End of job registration ---
s.gocronScheduler.StartAsync() // start asynchronously; does not block the app's main thread
s.gocronScheduler.StartAsync()
s.logger.Info("Scheduler started.")
}
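For reference, a minimal sketch of this scheduling pattern, assuming go-co-op/gocron v1 (the library these call shapes suggest): each run of a background job builds its own context.Background(), since there is no inbound request to inherit one from.

package main

import (
	"context"
	"log"
	"time"

	"github.com/go-co-op/gocron"
)

func main() {
	s := gocron.NewScheduler(time.UTC)
	_, err := s.Cron("5 * * * *").Tag("stats-aggregation").Do(func() {
		ctx := context.Background() // a fresh, empty context per run
		_ = ctx                     // would be passed into the service call, e.g. AggregateHourlyStats(ctx)
	})
	if err != nil {
		log.Fatal(err)
	}
	s.StartAsync() // non-blocking: the main thread continues
	select {}      // keep the demo process alive
}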

View File

@@ -1,8 +1,8 @@
// Filename: internal/service/analytics_service.go
package service
import (
"context"
"encoding/json"
"fmt"
"gemini-balancer/internal/db/dialect"
@@ -43,7 +43,7 @@ func NewAnalyticsService(db *gorm.DB, s store.Store, logger *logrus.Logger, d di
}
func (s *AnalyticsService) Start() {
s.wg.Add(2) // 2 (flushLoop, eventListener)
s.wg.Add(2)
go s.flushLoop()
go s.eventListener()
s.logger.Info("AnalyticsService (Command Side) started.")
@@ -53,13 +53,13 @@ func (s *AnalyticsService) Stop() {
close(s.stopChan)
s.wg.Wait()
s.logger.Info("AnalyticsService stopped. Performing final data flush...")
s.flushToDB() // flush to the DB before stopping
s.flushToDB()
s.logger.Info("AnalyticsService final data flush completed.")
}
func (s *AnalyticsService) eventListener() {
defer s.wg.Done()
sub, err := s.store.Subscribe(models.TopicRequestFinished)
sub, err := s.store.Subscribe(context.Background(), models.TopicRequestFinished)
if err != nil {
s.logger.Fatalf("Failed to subscribe to topic %s: %v", models.TopicRequestFinished, err)
return
@@ -87,9 +87,10 @@ func (s *AnalyticsService) handleAnalyticsEvent(event *models.RequestFinishedEve
if event.RequestLog.GroupID == nil {
return
}
ctx := context.Background()
key := fmt.Sprintf("analytics:hourly:%s", time.Now().UTC().Format("2006-01-02T15"))
fieldPrefix := fmt.Sprintf("%d:%s", *event.RequestLog.GroupID, event.RequestLog.ModelName)
pipe := s.store.Pipeline()
pipe := s.store.Pipeline(ctx)
pipe.HIncrBy(key, fieldPrefix+":requests", 1)
if event.RequestLog.IsSuccess {
pipe.HIncrBy(key, fieldPrefix+":success", 1)
@@ -120,6 +121,7 @@ func (s *AnalyticsService) flushLoop() {
}
func (s *AnalyticsService) flushToDB() {
ctx := context.Background()
now := time.Now().UTC()
keysToFlush := []string{
fmt.Sprintf("analytics:hourly:%s", now.Add(-1*time.Hour).Format("2006-01-02T15")),
@@ -127,7 +129,7 @@ func (s *AnalyticsService) flushToDB() {
}
for _, key := range keysToFlush {
data, err := s.store.HGetAll(key)
data, err := s.store.HGetAll(ctx, key)
if err != nil || len(data) == 0 {
continue
}
@@ -136,15 +138,15 @@ func (s *AnalyticsService) flushToDB() {
if len(statsToFlush) > 0 {
upsertClause := s.dialect.OnConflictUpdateAll(
[]string{"time", "group_id", "model_name"}, // conflict columns
[]string{"request_count", "success_count", "prompt_tokens", "completion_tokens"}, // update columns
[]string{"time", "group_id", "model_name"},
[]string{"request_count", "success_count", "prompt_tokens", "completion_tokens"},
)
err := s.db.Clauses(upsertClause).Create(&statsToFlush).Error
err := s.db.WithContext(ctx).Clauses(upsertClause).Create(&statsToFlush).Error
if err != nil {
s.logger.Errorf("Failed to flush analytics data for key %s: %v", key, err)
} else {
s.logger.Infof("Successfully flushed %d records from key %s.", len(statsToFlush), key)
_ = s.store.HDel(key, parsedFields...)
_ = s.store.HDel(ctx, key, parsedFields...)
}
}
}
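The flush targets two hourly buckets: the previous hour (now closed) and, presumably, the current hour — the second keysToFlush entry is cut off by the hunk above. A tiny sketch of the key derivation:

package main

import (
	"fmt"
	"time"
)

// Derive the two hourly analytics bucket keys, matching the Format layout
// used in flushToDB above.
func main() {
	now := time.Now().UTC()
	prev := fmt.Sprintf("analytics:hourly:%s", now.Add(-1*time.Hour).Format("2006-01-02T15"))
	curr := fmt.Sprintf("analytics:hourly:%s", now.Format("2006-01-02T15"))
	fmt.Println(prev) // e.g. analytics:hourly:2025-11-22T05
	fmt.Println(curr) // e.g. analytics:hourly:2025-11-22T06
}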

View File

@@ -1,8 +1,8 @@
// Filename: internal/service/apikey_service.go
package service
import (
"context"
"encoding/json"
"errors"
"fmt"
@@ -29,7 +29,6 @@ const (
TaskTypeUpdateStatusByFilter = "update_status_by_filter"
)
// DTOs & Constants
const (
TestEndpoint = "https://generativelanguage.googleapis.com/v1beta/models"
)
@@ -83,7 +82,6 @@ func NewAPIKeyService(
gm *GroupManager,
logger *logrus.Logger,
) *APIKeyService {
logger.Debugf("[FORENSIC PROBE | INJECTION | APIKeyService] Received a KeyRepository. Fingerprint: %p", repo)
return &APIKeyService{
db: db,
keyRepo: repo,
@@ -99,22 +97,22 @@ func NewAPIKeyService(
}
func (s *APIKeyService) Start() {
requestSub, err := s.store.Subscribe(models.TopicRequestFinished)
requestSub, err := s.store.Subscribe(context.Background(), models.TopicRequestFinished)
if err != nil {
s.logger.Fatalf("Failed to subscribe to topic %s: %v", models.TopicRequestFinished, err)
return
}
masterKeySub, err := s.store.Subscribe(models.TopicMasterKeyStatusChanged)
masterKeySub, err := s.store.Subscribe(context.Background(), models.TopicMasterKeyStatusChanged)
if err != nil {
s.logger.Fatalf("Failed to subscribe to topic %s: %v", models.TopicMasterKeyStatusChanged, err)
return
}
keyStatusSub, err := s.store.Subscribe(models.TopicKeyStatusChanged)
keyStatusSub, err := s.store.Subscribe(context.Background(), models.TopicKeyStatusChanged)
if err != nil {
s.logger.Fatalf("Failed to subscribe to topic %s: %v", models.TopicKeyStatusChanged, err)
return
}
importSub, err := s.store.Subscribe(models.TopicImportGroupCompleted)
importSub, err := s.store.Subscribe(context.Background(), models.TopicImportGroupCompleted)
if err != nil {
s.logger.Fatalf("Failed to subscribe to topic %s: %v", models.TopicImportGroupCompleted, err)
return
@@ -177,6 +175,7 @@ func (s *APIKeyService) handleKeyUsageEvent(event *models.RequestFinishedEvent)
if event.RequestLog.KeyID == nil || event.RequestLog.GroupID == nil {
return
}
ctx := context.Background()
if event.RequestLog.IsSuccess {
mapping, err := s.keyRepo.GetMapping(*event.RequestLog.GroupID, *event.RequestLog.KeyID)
if err != nil {
@@ -194,17 +193,18 @@ func (s *APIKeyService) handleKeyUsageEvent(event *models.RequestFinishedEvent)
now := time.Now()
mapping.LastUsedAt = &now
if err := s.keyRepo.UpdateMapping(mapping); err != nil {
if err := s.keyRepo.UpdateMapping(ctx, mapping); err != nil {
s.logger.Errorf("[%s] Failed to update mapping for G:%d K:%d after successful request: %v", event.CorrelationID, *event.RequestLog.GroupID, *event.RequestLog.KeyID, err)
return
}
if statusChanged {
go s.publishStatusChangeEvent(*event.RequestLog.GroupID, *event.RequestLog.KeyID, oldStatus, models.StatusActive, "key_recovered_after_use")
go s.publishStatusChangeEvent(ctx, *event.RequestLog.GroupID, *event.RequestLog.KeyID, oldStatus, models.StatusActive, "key_recovered_after_use")
}
return
}
if event.Error != nil {
s.judgeKeyErrors(
ctx,
event.CorrelationID,
*event.RequestLog.GroupID,
*event.RequestLog.KeyID,
@@ -215,6 +215,7 @@ func (s *APIKeyService) handleKeyUsageEvent(event *models.RequestFinishedEvent)
}
func (s *APIKeyService) handleKeyStatusChangeEvent(event *models.KeyStatusChangedEvent) {
ctx := context.Background()
log := s.logger.WithFields(logrus.Fields{
"group_id": event.GroupID,
"key_id": event.KeyID,
@@ -222,11 +223,11 @@ func (s *APIKeyService) handleKeyStatusChangeEvent(event *models.KeyStatusChange
"reason": event.ChangeReason,
})
log.Info("Received KeyStatusChangedEvent, will update polling caches.")
s.keyRepo.HandleCacheUpdateEvent(event.GroupID, event.KeyID, event.NewStatus)
s.keyRepo.HandleCacheUpdateEvent(ctx, event.GroupID, event.KeyID, event.NewStatus)
log.Info("Polling caches updated based on health check event.")
}
func (s *APIKeyService) publishStatusChangeEvent(groupID, keyID uint, oldStatus, newStatus models.APIKeyStatus, reason string) {
func (s *APIKeyService) publishStatusChangeEvent(ctx context.Context, groupID, keyID uint, oldStatus, newStatus models.APIKeyStatus, reason string) {
changeEvent := models.KeyStatusChangedEvent{
KeyID: keyID,
GroupID: groupID,
@@ -236,13 +237,12 @@ func (s *APIKeyService) publishStatusChangeEvent(groupID, keyID uint, oldStatus,
ChangedAt: time.Now(),
}
eventData, _ := json.Marshal(changeEvent)
if err := s.store.Publish(models.TopicKeyStatusChanged, eventData); err != nil {
if err := s.store.Publish(ctx, models.TopicKeyStatusChanged, eventData); err != nil {
s.logger.Errorf("Failed to publish key status changed event for group %d: %v", groupID, err)
}
}
func (s *APIKeyService) ListAPIKeys(params *models.APIKeyQueryParams) (*PaginatedAPIKeys, error) {
// --- Path 1: High-performance DB pagination (no keyword) ---
func (s *APIKeyService) ListAPIKeys(ctx context.Context, params *models.APIKeyQueryParams) (*PaginatedAPIKeys, error) {
if params.Keyword == "" {
items, total, err := s.keyRepo.GetPaginatedKeysAndMappingsByGroup(params)
if err != nil {
@@ -260,14 +260,12 @@ func (s *APIKeyService) ListAPIKeys(params *models.APIKeyQueryParams) (*Paginate
TotalPages: totalPages,
}, nil
}
// --- Path 2: In-memory search (keyword present) ---
s.logger.Infof("Performing heavy in-memory search for group %d with keyword '%s'", params.KeyGroupID, params.Keyword)
// To get all keys, we fetch all IDs first, then get their full details.
var statusesToFilter []string
if params.Status != "" {
statusesToFilter = append(statusesToFilter, params.Status)
} else {
statusesToFilter = append(statusesToFilter, "all") // "all" gets every status
statusesToFilter = append(statusesToFilter, "all")
}
allKeyIDs, err := s.keyRepo.FindKeyIDsByStatus(params.KeyGroupID, statusesToFilter)
if err != nil {
@@ -277,14 +275,12 @@ func (s *APIKeyService) ListAPIKeys(params *models.APIKeyQueryParams) (*Paginate
return &PaginatedAPIKeys{Items: []*models.APIKeyDetails{}, Total: 0, Page: 1, PageSize: params.PageSize, TotalPages: 0}, nil
}
// This is the heavy operation: getting all keys and decrypting them.
allKeys, err := s.keyRepo.GetKeysByIDs(allKeyIDs)
if err != nil {
return nil, fmt.Errorf("failed to fetch all key details for in-memory search: %w", err)
}
// We also need mappings to build the final `APIKeyDetails`.
var allMappings []models.GroupAPIKeyMapping
err = s.db.Where("key_group_id = ? AND api_key_id IN ?", params.KeyGroupID, allKeyIDs).Find(&allMappings).Error
err = s.db.WithContext(ctx).Where("key_group_id = ? AND api_key_id IN ?", params.KeyGroupID, allKeyIDs).Find(&allMappings).Error
if err != nil {
return nil, fmt.Errorf("failed to fetch mappings for in-memory search: %w", err)
}
@@ -292,7 +288,6 @@ func (s *APIKeyService) ListAPIKeys(params *models.APIKeyQueryParams) (*Paginate
for i := range allMappings {
mappingMap[allMappings[i].APIKeyID] = &allMappings[i]
}
// Filter the results in memory.
var filteredItems []*models.APIKeyDetails
for _, key := range allKeys {
if strings.Contains(key.APIKey, params.Keyword) {
@@ -312,11 +307,9 @@ func (s *APIKeyService) ListAPIKeys(params *models.APIKeyQueryParams) (*Paginate
}
}
}
// Sort the filtered results to ensure consistent pagination (by ID descending).
sort.Slice(filteredItems, func(i, j int) bool {
return filteredItems[i].ID > filteredItems[j].ID
})
// Manually paginate the filtered results.
total := int64(len(filteredItems))
start := (params.Page - 1) * params.PageSize
end := start + params.PageSize
@@ -345,14 +338,15 @@ func (s *APIKeyService) ListAPIKeys(params *models.APIKeyQueryParams) (*Paginate
}, nil
}
func (s *APIKeyService) GetKeysByIds(ids []uint) ([]models.APIKey, error) {
func (s *APIKeyService) GetKeysByIds(ctx context.Context, ids []uint) ([]models.APIKey, error) {
return s.keyRepo.GetKeysByIDs(ids)
}
func (s *APIKeyService) UpdateAPIKey(key *models.APIKey) error {
func (s *APIKeyService) UpdateAPIKey(ctx context.Context, key *models.APIKey) error {
go func() {
bgCtx := context.Background()
var oldKey models.APIKey
if err := s.db.First(&oldKey, key.ID).Error; err != nil {
if err := s.db.WithContext(bgCtx).First(&oldKey, key.ID).Error; err != nil {
s.logger.Errorf("Failed to find old key state for ID %d before update: %v", key.ID, err)
return
}
@@ -364,16 +358,14 @@ func (s *APIKeyService) UpdateAPIKey(key *models.APIKey) error {
return nil
}
func (s *APIKeyService) HardDeleteAPIKeyByID(id uint) error {
// Get all associated groups before deletion to publish correct events
groups, err := s.keyRepo.GetGroupsForKey(id)
func (s *APIKeyService) HardDeleteAPIKeyByID(ctx context.Context, id uint) error {
groups, err := s.keyRepo.GetGroupsForKey(ctx, id)
if err != nil {
s.logger.Warnf("Could not get groups for key ID %d before hard deletion (key might be orphaned): %v", id, err)
}
err = s.keyRepo.HardDeleteByID(id)
if err == nil {
// Publish events for each group the key was a part of
for _, groupID := range groups {
event := models.KeyStatusChangedEvent{
KeyID: id,
@@ -381,13 +373,13 @@ func (s *APIKeyService) HardDeleteAPIKeyByID(id uint) error {
ChangeReason: "key_hard_deleted",
}
eventData, _ := json.Marshal(event)
go s.store.Publish(models.TopicKeyStatusChanged, eventData)
go s.store.Publish(context.Background(), models.TopicKeyStatusChanged, eventData)
}
}
return err
}
func (s *APIKeyService) UpdateMappingStatus(groupID, keyID uint, newStatus models.APIKeyStatus) (*models.GroupAPIKeyMapping, error) {
func (s *APIKeyService) UpdateMappingStatus(ctx context.Context, groupID, keyID uint, newStatus models.APIKeyStatus) (*models.GroupAPIKeyMapping, error) {
key, err := s.keyRepo.GetKeyByID(keyID)
if err != nil {
return nil, CustomErrors.ParseDBError(err)
@@ -409,19 +401,20 @@ func (s *APIKeyService) UpdateMappingStatus(groupID, keyID uint, newStatus model
mapping.ConsecutiveErrorCount = 0
mapping.LastError = ""
}
if err := s.keyRepo.UpdateMapping(mapping); err != nil {
if err := s.keyRepo.UpdateMapping(ctx, mapping); err != nil {
return nil, err
}
go s.publishStatusChangeEvent(groupID, keyID, oldStatus, newStatus, "manual_update")
go s.publishStatusChangeEvent(ctx, groupID, keyID, oldStatus, newStatus, "manual_update")
return mapping, nil
}
func (s *APIKeyService) handleMasterKeyStatusChangeEvent(event *models.MasterKeyStatusChangedEvent) {
ctx := context.Background()
s.logger.Infof("Received MasterKeyStatusChangedEvent for Key ID %d: %s -> %s", event.KeyID, event.OldMasterStatus, event.NewMasterStatus)
if event.NewMasterStatus != models.MasterStatusRevoked {
return
}
affectedGroupIDs, err := s.keyRepo.GetGroupsForKey(event.KeyID)
affectedGroupIDs, err := s.keyRepo.GetGroupsForKey(ctx, event.KeyID)
if err != nil {
s.logger.WithError(err).Errorf("Failed to get groups for key ID %d to propagate revoked status.", event.KeyID)
return
@@ -432,7 +425,7 @@ func (s *APIKeyService) handleMasterKeyStatusChangeEvent(event *models.MasterKey
}
s.logger.Infof("Propagating REVOKED astatus for Key ID %d to %d groups.", event.KeyID, len(affectedGroupIDs))
for _, groupID := range affectedGroupIDs {
_, err := s.UpdateMappingStatus(groupID, event.KeyID, models.StatusBanned)
_, err := s.UpdateMappingStatus(ctx, groupID, event.KeyID, models.StatusBanned)
if err != nil {
if !errors.Is(err, gorm.ErrRecordNotFound) {
s.logger.WithError(err).Errorf("Failed to update mapping status to BANNED for Key ID %d in Group ID %d.", event.KeyID, groupID)
@@ -441,32 +434,32 @@ func (s *APIKeyService) handleMasterKeyStatusChangeEvent(event *models.MasterKey
}
}
func (s *APIKeyService) StartRestoreKeysTask(groupID uint, keyIDs []uint) (*task.Status, error) {
func (s *APIKeyService) StartRestoreKeysTask(ctx context.Context, groupID uint, keyIDs []uint) (*task.Status, error) {
if len(keyIDs) == 0 {
return nil, CustomErrors.NewAPIError(CustomErrors.ErrBadRequest, "No key IDs provided for restoration.")
}
resourceID := fmt.Sprintf("group-%d", groupID)
taskStatus, err := s.taskService.StartTask(groupID, TaskTypeRestoreSpecificKeys, resourceID, len(keyIDs), time.Hour)
taskStatus, err := s.taskService.StartTask(ctx, groupID, TaskTypeRestoreSpecificKeys, resourceID, len(keyIDs), time.Hour)
if err != nil {
return nil, err
}
go s.runRestoreKeysTask(taskStatus.ID, resourceID, groupID, keyIDs)
go s.runRestoreKeysTask(ctx, taskStatus.ID, resourceID, groupID, keyIDs)
return taskStatus, nil
}
func (s *APIKeyService) runRestoreKeysTask(taskID string, resourceID string, groupID uint, keyIDs []uint) {
func (s *APIKeyService) runRestoreKeysTask(ctx context.Context, taskID string, resourceID string, groupID uint, keyIDs []uint) {
defer func() {
if r := recover(); r != nil {
s.logger.Errorf("Panic recovered in runRestoreKeysTask: %v", r)
s.taskService.EndTaskByID(taskID, resourceID, nil, fmt.Errorf("panic in restore keys task: %v", r))
s.taskService.EndTaskByID(ctx, taskID, resourceID, nil, fmt.Errorf("panic in restore keys task: %v", r))
}
}()
var mappingsToProcess []models.GroupAPIKeyMapping
err := s.db.Preload("APIKey").
err := s.db.WithContext(ctx).Preload("APIKey").
Where("key_group_id = ? AND api_key_id IN ?", groupID, keyIDs).
Find(&mappingsToProcess).Error
if err != nil {
s.taskService.EndTaskByID(taskID, resourceID, nil, err)
s.taskService.EndTaskByID(ctx, taskID, resourceID, nil, err)
return
}
result := &BatchRestoreResult{
@@ -476,7 +469,7 @@ func (s *APIKeyService) runRestoreKeysTask(taskID string, resourceID string, gro
processedCount := 0
for _, mapping := range mappingsToProcess {
processedCount++
_ = s.taskService.UpdateProgressByID(taskID, processedCount)
_ = s.taskService.UpdateProgressByID(ctx, taskID, processedCount)
if mapping.APIKey == nil {
result.SkippedCount++
result.SkippedKeys = append(result.SkippedKeys, SkippedKeyInfo{KeyID: mapping.APIKeyID, Reason: "Associated APIKey entity not found."})
@@ -492,33 +485,29 @@ func (s *APIKeyService) runRestoreKeysTask(taskID string, resourceID string, gro
mapping.Status = models.StatusActive
mapping.ConsecutiveErrorCount = 0
mapping.LastError = ""
// Use the version that doesn't trigger individual cache updates.
if err := s.keyRepo.UpdateMappingWithoutCache(&mapping); err != nil {
s.logger.WithError(err).Warn("Failed to update mapping in DB during restore task.")
result.SkippedCount++
result.SkippedKeys = append(result.SkippedKeys, SkippedKeyInfo{KeyID: mapping.APIKeyID, Reason: "DB update failed."})
} else {
result.RestoredCount++
successfulMappings = append(successfulMappings, &mapping) // Collect for batch cache update.
go s.publishStatusChangeEvent(groupID, mapping.APIKeyID, oldStatus, models.StatusActive, "batch_restore")
successfulMappings = append(successfulMappings, &mapping)
go s.publishStatusChangeEvent(ctx, groupID, mapping.APIKeyID, oldStatus, models.StatusActive, "batch_restore")
}
} else {
result.RestoredCount++ // Already active, count as success.
result.RestoredCount++
}
}
// After the loop, perform one single, efficient cache update.
if err := s.keyRepo.HandleCacheUpdateEventBatch(successfulMappings); err != nil {
if err := s.keyRepo.HandleCacheUpdateEventBatch(ctx, successfulMappings); err != nil {
s.logger.WithError(err).Error("Failed to perform batch cache update after restore task.")
// This is not a task-fatal error, so we just log it and continue.
}
// Account for keys that were requested but not found in the initial DB query.
result.SkippedCount += (len(keyIDs) - len(mappingsToProcess))
s.taskService.EndTaskByID(taskID, resourceID, result, nil)
s.taskService.EndTaskByID(ctx, taskID, resourceID, result, nil)
}
func (s *APIKeyService) StartRestoreAllBannedTask(groupID uint) (*task.Status, error) {
func (s *APIKeyService) StartRestoreAllBannedTask(ctx context.Context, groupID uint) (*task.Status, error) {
var bannedKeyIDs []uint
err := s.db.Model(&models.GroupAPIKeyMapping{}).
err := s.db.WithContext(ctx).Model(&models.GroupAPIKeyMapping{}).
Where("key_group_id = ? AND status = ?", groupID, models.StatusBanned).
Pluck("api_key_id", &bannedKeyIDs).Error
if err != nil {
@@ -527,10 +516,11 @@ func (s *APIKeyService) StartRestoreAllBannedTask(groupID uint) (*task.Status, e
if len(bannedKeyIDs) == 0 {
return nil, CustomErrors.NewAPIError(CustomErrors.ErrBadRequest, "No banned keys to restore in this group.")
}
return s.StartRestoreKeysTask(groupID, bannedKeyIDs)
return s.StartRestoreKeysTask(ctx, groupID, bannedKeyIDs)
}
func (s *APIKeyService) handlePostImportValidation(event *models.ImportGroupCompletedEvent) {
ctx := context.Background()
group, ok := s.groupManager.GetGroupByID(event.GroupID)
if !ok {
s.logger.Errorf("Group with id %d not found during post-import validation, aborting.", event.GroupID)
@@ -552,7 +542,7 @@ func (s *APIKeyService) handlePostImportValidation(event *models.ImportGroupComp
concurrency = *opConfig.KeyCheckConcurrency
}
if concurrency <= 0 {
concurrency = 10 // Safety fallback
concurrency = 10
}
timeout := time.Duration(globalSettings.KeyCheckTimeoutSeconds) * time.Second
keysToValidate, err := s.keyRepo.GetKeysByIDs(event.KeyIDs)
@@ -571,7 +561,7 @@ func (s *APIKeyService) handlePostImportValidation(event *models.ImportGroupComp
validationErr := s.validationService.ValidateSingleKey(&key, timeout, endpoint)
if validationErr == nil {
s.logger.Infof("Key ID %d PASSED validation. Status -> ACTIVE.", key.ID)
if _, err := s.UpdateMappingStatus(event.GroupID, key.ID, models.StatusActive); err != nil {
if _, err := s.UpdateMappingStatus(ctx, event.GroupID, key.ID, models.StatusActive); err != nil {
s.logger.Errorf("Failed to update status to ACTIVE for Key ID %d in group %d: %v", key.ID, event.GroupID, err)
}
} else {
@@ -579,7 +569,7 @@ func (s *APIKeyService) handlePostImportValidation(event *models.ImportGroupComp
if !CustomErrors.As(validationErr, &apiErr) {
apiErr = &CustomErrors.APIError{Message: validationErr.Error()}
}
s.judgeKeyErrors("", event.GroupID, key.ID, apiErr, false)
s.judgeKeyErrors(ctx, "", event.GroupID, key.ID, apiErr, false)
}
}
}()
@@ -592,12 +582,9 @@ func (s *APIKeyService) handlePostImportValidation(event *models.ImportGroupComp
s.logger.Infof("Finished post-import validation for group %d.", event.GroupID)
}
// [NEW] StartUpdateStatusByFilterTask starts a background job to update the status of keys
// that match a specific set of source statuses within a group.
func (s *APIKeyService) StartUpdateStatusByFilterTask(groupID uint, sourceStatuses []string, newStatus models.APIKeyStatus) (*task.Status, error) {
func (s *APIKeyService) StartUpdateStatusByFilterTask(ctx context.Context, groupID uint, sourceStatuses []string, newStatus models.APIKeyStatus) (*task.Status, error) {
s.logger.Infof("Starting task to update status to '%s' for keys in group %d matching statuses: %v", newStatus, groupID, sourceStatuses)
// 1. Find key IDs using the new repository method. Using IDs is more efficient for updates.
keyIDs, err := s.keyRepo.FindKeyIDsByStatus(groupID, sourceStatuses)
if err != nil {
return nil, CustomErrors.ParseDBError(err)
@@ -605,35 +592,32 @@ func (s *APIKeyService) StartUpdateStatusByFilterTask(groupID uint, sourceStatus
if len(keyIDs) == 0 {
now := time.Now()
return &task.Status{
IsRunning: false, // The "task" is not running.
IsRunning: false,
Processed: 0,
Total: 0,
Result: map[string]string{ // We use the flexible Result field to pass the message.
Result: map[string]string{
"message": "没有找到任何符合当前过滤条件的Key可供操作。",
},
Error: "", // There is no error.
Error: "",
StartedAt: now,
FinishedAt: &now, // It started and finished at the same time.
}, nil // Return nil for the error, signaling a 200 OK.
FinishedAt: &now,
}, nil
}
// 2. Start a new task using the TaskService, following existing patterns.
resourceID := fmt.Sprintf("group-%d-status-update", groupID)
taskStatus, err := s.taskService.StartTask(groupID, TaskTypeUpdateStatusByFilter, resourceID, len(keyIDs), 30*time.Minute)
taskStatus, err := s.taskService.StartTask(ctx, groupID, TaskTypeUpdateStatusByFilter, resourceID, len(keyIDs), 30*time.Minute)
if err != nil {
return nil, err // Pass up errors like "task already in progress".
return nil, err
}
// 3. Run the core logic in a separate goroutine.
go s.runUpdateStatusByFilterTask(taskStatus.ID, resourceID, groupID, keyIDs, newStatus)
go s.runUpdateStatusByFilterTask(ctx, taskStatus.ID, resourceID, groupID, keyIDs, newStatus)
return taskStatus, nil
}
// [NEW] runUpdateStatusByFilterTask is the private worker function for the above task.
func (s *APIKeyService) runUpdateStatusByFilterTask(taskID, resourceID string, groupID uint, keyIDs []uint, newStatus models.APIKeyStatus) {
func (s *APIKeyService) runUpdateStatusByFilterTask(ctx context.Context, taskID, resourceID string, groupID uint, keyIDs []uint, newStatus models.APIKeyStatus) {
defer func() {
if r := recover(); r != nil {
s.logger.Errorf("Panic recovered in runUpdateStatusByFilterTask: %v", r)
s.taskService.EndTaskByID(taskID, resourceID, nil, fmt.Errorf("panic in update status task: %v", r))
s.taskService.EndTaskByID(ctx, taskID, resourceID, nil, fmt.Errorf("panic in update status task: %v", r))
}
}()
type BatchUpdateResult struct {
@@ -642,31 +626,27 @@ func (s *APIKeyService) runUpdateStatusByFilterTask(taskID, resourceID string, g
}
result := &BatchUpdateResult{}
var successfulMappings []*models.GroupAPIKeyMapping
// 1. Fetch all key master statuses in one go. This is efficient.
keys, err := s.keyRepo.GetKeysByIDs(keyIDs)
if err != nil {
s.logger.WithError(err).Error("Failed to fetch keys for bulk status update, aborting task.")
s.taskService.EndTaskByID(taskID, resourceID, nil, err)
s.taskService.EndTaskByID(ctx, taskID, resourceID, nil, err)
return
}
masterStatusMap := make(map[uint]models.MasterAPIKeyStatus)
for _, key := range keys {
masterStatusMap[key.ID] = key.MasterStatus
}
// 2. [THE REFINEMENT] Fetch all relevant mappings directly using s.db,
// avoiding the need for a new repository method. This pattern is
// already used in other parts of this service.
var mappings []*models.GroupAPIKeyMapping
if err := s.db.Where("key_group_id = ? AND api_key_id IN ?", groupID, keyIDs).Find(&mappings).Error; err != nil {
if err := s.db.WithContext(ctx).Where("key_group_id = ? AND api_key_id IN ?", groupID, keyIDs).Find(&mappings).Error; err != nil {
s.logger.WithError(err).Error("Failed to fetch mappings for bulk status update, aborting task.")
s.taskService.EndTaskByID(taskID, resourceID, nil, err)
s.taskService.EndTaskByID(ctx, taskID, resourceID, nil, err)
return
}
processedCount := 0
for _, mapping := range mappings {
processedCount++
// The progress update should reflect the number of items *being processed*, not the final count.
_ = s.taskService.UpdateProgressByID(taskID, processedCount)
_ = s.taskService.UpdateProgressByID(ctx, taskID, processedCount)
masterStatus, ok := masterStatusMap[mapping.APIKeyID]
if !ok {
result.SkippedCount++
@@ -688,24 +668,25 @@ func (s *APIKeyService) runUpdateStatusByFilterTask(taskID, resourceID string, g
} else {
result.UpdatedCount++
successfulMappings = append(successfulMappings, mapping)
go s.publishStatusChangeEvent(groupID, mapping.APIKeyID, oldStatus, newStatus, "batch_status_update")
go s.publishStatusChangeEvent(ctx, groupID, mapping.APIKeyID, oldStatus, newStatus, "batch_status_update")
}
} else {
result.UpdatedCount++ // Already in desired state, count as success.
result.UpdatedCount++
}
}
result.SkippedCount += (len(keyIDs) - len(mappings))
if err := s.keyRepo.HandleCacheUpdateEventBatch(successfulMappings); err != nil {
if err := s.keyRepo.HandleCacheUpdateEventBatch(ctx, successfulMappings); err != nil {
s.logger.WithError(err).Error("Failed to perform batch cache update after status update task.")
}
s.logger.Infof("Finished bulk status update task '%s' for group %d. Updated: %d, Skipped: %d.", taskID, groupID, result.UpdatedCount, result.SkippedCount)
s.taskService.EndTaskByID(taskID, resourceID, result, nil)
s.taskService.EndTaskByID(ctx, taskID, resourceID, result, nil)
}
func (s *APIKeyService) HandleRequestResult(group *models.KeyGroup, key *models.APIKey, success bool, apiErr *CustomErrors.APIError) {
ctx := context.Background()
if success {
if group.PollingStrategy == models.StrategyWeighted {
go s.keyRepo.UpdateKeyUsageTimestamp(group.ID, key.ID)
go s.keyRepo.UpdateKeyUsageTimestamp(ctx, group.ID, key.ID)
}
return
}
@@ -716,26 +697,20 @@ func (s *APIKeyService) HandleRequestResult(group *models.KeyGroup, key *models.
errMsg := apiErr.Message
if CustomErrors.IsPermanentUpstreamError(errMsg) || CustomErrors.IsTemporaryUpstreamError(errMsg) {
s.logger.Warnf("Request for KeyID %d in GroupID %d failed with a key-specific error: %s. Issuing order to move to cooldown.", key.ID, group.ID, errMsg)
go s.keyRepo.SyncKeyStatusInPollingCaches(group.ID, key.ID, models.StatusCooldown)
go s.keyRepo.SyncKeyStatusInPollingCaches(ctx, group.ID, key.ID, models.StatusCooldown)
} else {
s.logger.Infof("Request for KeyID %d in GroupID %d failed with a non-key-specific error. No punitive action taken. Error: %s", key.ID, group.ID, errMsg)
}
}
// sanitizeForLog truncates long error messages and removes JSON-like content for cleaner Info/Error logs.
func sanitizeForLog(errMsg string) string {
// Find the start of any potential JSON blob or detailed structure.
jsonStartIndex := strings.Index(errMsg, "{")
var cleanMsg string
if jsonStartIndex != -1 {
// If a '{' is found, take everything before it as the summary
// and append a simple placeholder.
cleanMsg = strings.TrimSpace(errMsg[:jsonStartIndex]) + " {...}"
} else {
// If no JSON-like structure is found, use the original message.
cleanMsg = errMsg
}
// Always apply a final length truncation as a safeguard against extremely long non-JSON errors.
const maxLen = 250
if len(cleanMsg) > maxLen {
return cleanMsg[:maxLen] + "..."
@@ -744,6 +719,7 @@ func sanitizeForLog(errMsg string) string {
}
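A usage sketch of sanitizeForLog. The body is copied from the function above; only the final return is reconstructed, since the diff hunk cuts the function off there:

package main

import (
	"fmt"
	"strings"
)

// sanitizeForLog: summary before any '{' plus a placeholder, then a hard
// length cap as a safeguard against very long non-JSON errors.
func sanitizeForLog(errMsg string) string {
	jsonStartIndex := strings.Index(errMsg, "{")
	var cleanMsg string
	if jsonStartIndex != -1 {
		cleanMsg = strings.TrimSpace(errMsg[:jsonStartIndex]) + " {...}"
	} else {
		cleanMsg = errMsg
	}
	const maxLen = 250
	if len(cleanMsg) > maxLen {
		return cleanMsg[:maxLen] + "..."
	}
	return cleanMsg // reconstructed: the hunk above truncates the function here
}

func main() {
	fmt.Println(sanitizeForLog(`upstream 429 {"error":{"status":"RESOURCE_EXHAUSTED"}}`))
	// Output: upstream 429 {...}
}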
func (s *APIKeyService) judgeKeyErrors(
ctx context.Context,
correlationID string,
groupID, keyID uint,
apiErr *CustomErrors.APIError,
@@ -765,11 +741,11 @@ func (s *APIKeyService) judgeKeyErrors(
oldStatus := mapping.Status
mapping.Status = models.StatusBanned
mapping.LastError = errorMessage
if err := s.keyRepo.UpdateMapping(mapping); err != nil {
if err := s.keyRepo.UpdateMapping(ctx, mapping); err != nil {
logger.WithError(err).Error("Failed to update mapping status to BANNED.")
} else {
go s.publishStatusChangeEvent(groupID, keyID, oldStatus, mapping.Status, "permanent_error_banned")
go s.revokeMasterKey(keyID, "permanent_upstream_error")
go s.publishStatusChangeEvent(ctx, groupID, keyID, oldStatus, mapping.Status, "permanent_error_banned")
go s.revokeMasterKey(ctx, keyID, "permanent_upstream_error")
}
}
return
@@ -801,23 +777,23 @@ func (s *APIKeyService) judgeKeyErrors(
if oldStatus != newStatus {
mapping.Status = newStatus
}
if err := s.keyRepo.UpdateMapping(mapping); err != nil {
if err := s.keyRepo.UpdateMapping(ctx, mapping); err != nil {
logger.WithError(err).Error("Failed to update mapping after temporary error.")
return
}
if oldStatus != newStatus {
go s.publishStatusChangeEvent(groupID, keyID, oldStatus, newStatus, "error_threshold_reached")
go s.publishStatusChangeEvent(ctx, groupID, keyID, oldStatus, newStatus, "error_threshold_reached")
}
return
}
logger.Infof("Ignored truly ignorable upstream error. Only updating LastUsedAt. Reason: %s", sanitizeForLog(errorMessage))
logger.WithField("full_error_details", errorMessage).Debug("Full details of the ignorable error.")
if err := s.keyRepo.UpdateMapping(mapping); err != nil {
if err := s.keyRepo.UpdateMapping(ctx, mapping); err != nil {
logger.WithError(err).Error("Failed to update LastUsedAt for ignorable error.")
}
}
func (s *APIKeyService) revokeMasterKey(keyID uint, reason string) {
func (s *APIKeyService) revokeMasterKey(ctx context.Context, keyID uint, reason string) {
key, err := s.keyRepo.GetKeyByID(keyID)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
@@ -832,7 +808,7 @@ func (s *APIKeyService) revokeMasterKey(keyID uint, reason string) {
}
oldMasterStatus := key.MasterStatus
newMasterStatus := models.MasterStatusRevoked
if err := s.keyRepo.UpdateMasterStatusByID(keyID, newMasterStatus); err != nil {
if err := s.keyRepo.UpdateAPIKeyStatus(ctx, keyID, newMasterStatus); err != nil {
s.logger.Errorf("Failed to update master status to REVOKED for key ID %d: %v", keyID, err)
return
}
@@ -844,9 +820,9 @@ func (s *APIKeyService) revokeMasterKey(keyID uint, reason string) {
ChangedAt: time.Now(),
}
eventData, _ := json.Marshal(masterKeyEvent)
_ = s.store.Publish(models.TopicMasterKeyStatusChanged, eventData)
_ = s.store.Publish(ctx, models.TopicMasterKeyStatusChanged, eventData)
}
func (s *APIKeyService) GetAPIKeyStringsForExport(groupID uint, statuses []string) ([]string, error) {
func (s *APIKeyService) GetAPIKeyStringsForExport(ctx context.Context, groupID uint, statuses []string) ([]string, error) {
return s.keyRepo.GetKeyStringsByGroupAndStatus(groupID, statuses)
}
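The marshal-then-Publish shape recurs throughout this service. A self-contained sketch, with Publisher standing in for the store (an assumption to keep it compilable) and an illustrative topic string:

package events

import (
	"context"
	"encoding/json"
)

// Publisher is a stand-in for the store's Publish method.
type Publisher interface {
	Publish(ctx context.Context, topic string, payload []byte) error
}

type KeyStatusChangedEvent struct {
	KeyID        uint   `json:"key_id"`
	GroupID      uint   `json:"group_id"`
	ChangeReason string `json:"change_reason"`
}

// PublishKeyStatusChanged marshals a typed event and publishes it on a topic,
// carrying the caller's ctx all the way to the transport.
func PublishKeyStatusChanged(ctx context.Context, p Publisher, e KeyStatusChangedEvent) error {
	data, err := json.Marshal(e)
	if err != nil {
		return err
	}
	return p.Publish(ctx, "key.status.changed", data) // topic name is illustrative
}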

View File

@@ -1,8 +1,8 @@
// Filename: internal/service/dashboard_query_service.go
package service
import (
"context"
"fmt"
"gemini-balancer/internal/models"
"gemini-balancer/internal/store"
@@ -17,8 +17,6 @@ import (
const overviewCacheChannel = "syncer:cache:dashboard_overview"
// DashboardQueryService handles all dashboard data queries served to the frontend.
type DashboardQueryService struct {
db *gorm.DB
store store.Store
@@ -54,9 +52,9 @@ func (s *DashboardQueryService) Stop() {
s.logger.Info("DashboardQueryService and its CacheSyncer have been stopped.")
}
func (s *DashboardQueryService) GetGroupStats(groupID uint) (map[string]any, error) {
func (s *DashboardQueryService) GetGroupStats(ctx context.Context, groupID uint) (map[string]any, error) {
statsKey := fmt.Sprintf("stats:group:%d", groupID)
keyStatsMap, err := s.store.HGetAll(statsKey)
keyStatsMap, err := s.store.HGetAll(ctx, statsKey)
if err != nil {
s.logger.WithError(err).Errorf("Failed to get key stats from cache for group %d", groupID)
return nil, fmt.Errorf("failed to get key stats from cache: %w", err)
@@ -74,11 +72,11 @@ func (s *DashboardQueryService) GetGroupStats(groupID uint) (map[string]any, err
SuccessRequests int64
}
var last1Hour, last24Hours requestStatsResult
s.db.Model(&models.StatsHourly{}).
s.db.WithContext(ctx).Model(&models.StatsHourly{}).
Select("SUM(request_count) as total_requests, SUM(success_count) as success_requests").
Where("group_id = ? AND time >= ?", groupID, oneHourAgo).
Scan(&last1Hour)
s.db.Model(&models.StatsHourly{}).
s.db.WithContext(ctx).Model(&models.StatsHourly{}).
Select("SUM(request_count) as total_requests, SUM(success_count) as success_requests").
Where("group_id = ? AND time >= ?", groupID, twentyFourHoursAgo).
Scan(&last24Hours)
@@ -109,8 +107,9 @@ func (s *DashboardQueryService) GetGroupStats(groupID uint) (map[string]any, err
}
func (s *DashboardQueryService) eventListener() {
keyStatusSub, _ := s.store.Subscribe(models.TopicKeyStatusChanged)
upstreamStatusSub, _ := s.store.Subscribe(models.TopicUpstreamHealthChanged)
ctx := context.Background()
keyStatusSub, _ := s.store.Subscribe(ctx, models.TopicKeyStatusChanged)
upstreamStatusSub, _ := s.store.Subscribe(ctx, models.TopicUpstreamHealthChanged)
defer keyStatusSub.Close()
defer upstreamStatusSub.Close()
for {
@@ -128,7 +127,6 @@ func (s *DashboardQueryService) eventListener() {
}
}
// GetDashboardOverviewData reads dashboard overview data, at high speed, from the Syncer cache.
func (s *DashboardQueryService) GetDashboardOverviewData() (*models.DashboardStatsResponse, error) {
cachedDataPtr := s.overviewSyncer.Get()
if cachedDataPtr == nil {
@@ -141,8 +139,7 @@ func (s *DashboardQueryService) InvalidateOverviewCache() error {
return s.overviewSyncer.Invalidate()
}
// QueryHistoricalChart queries historical chart data.
func (s *DashboardQueryService) QueryHistoricalChart(groupID *uint) (*models.ChartData, error) {
func (s *DashboardQueryService) QueryHistoricalChart(ctx context.Context, groupID *uint) (*models.ChartData, error) {
type ChartPoint struct {
TimeLabel string `gorm:"column:time_label"`
ModelName string `gorm:"column:model_name"`
@@ -151,7 +148,7 @@ func (s *DashboardQueryService) QueryHistoricalChart(groupID *uint) (*models.Cha
sevenDaysAgo := time.Now().Add(-24 * 7 * time.Hour).Truncate(time.Hour)
sqlFormat, goFormat := s.buildTimeFormatSelectClause()
selectClause := fmt.Sprintf("%s as time_label, model_name, SUM(request_count) as total_requests", sqlFormat)
query := s.db.Model(&models.StatsHourly{}).Select(selectClause).Where("time >= ?", sevenDaysAgo).Group("time_label, model_name").Order("time_label ASC")
query := s.db.WithContext(ctx).Model(&models.StatsHourly{}).Select(selectClause).Where("time >= ?", sevenDaysAgo).Group("time_label, model_name").Order("time_label ASC")
if groupID != nil && *groupID > 0 {
query = query.Where("group_id = ?", *groupID)
}
@@ -189,38 +186,38 @@ func (s *DashboardQueryService) QueryHistoricalChart(groupID *uint) (*models.Cha
}
func (s *DashboardQueryService) loadOverviewData() (*models.DashboardStatsResponse, error) {
ctx := context.Background()
s.logger.Info("[CacheSyncer] Starting to load overview data from database...")
startTime := time.Now()
resp := &models.DashboardStatsResponse{
KeyStatusCount: make(map[models.APIKeyStatus]int64),
MasterStatusCount: make(map[models.MasterAPIKeyStatus]int64),
KeyCount: models.StatCard{}, // ensure KeyCount is an empty struct rather than nil
RequestCount24h: models.StatCard{}, // same as above
KeyCount: models.StatCard{},
RequestCount24h: models.StatCard{},
TokenCount: make(map[string]any),
UpstreamHealthStatus: make(map[string]string),
RPM: models.StatCard{},
RequestCounts: make(map[string]int64),
}
// --- 1. Aggregate Operational Status from Mappings ---
type MappingStatusResult struct {
Status models.APIKeyStatus
Count int64
}
var mappingStatusResults []MappingStatusResult
if err := s.db.Model(&models.GroupAPIKeyMapping{}).Select("status, count(*) as count").Group("status").Find(&mappingStatusResults).Error; err != nil {
if err := s.db.WithContext(ctx).Model(&models.GroupAPIKeyMapping{}).Select("status, count(*) as count").Group("status").Find(&mappingStatusResults).Error; err != nil {
return nil, fmt.Errorf("failed to query mapping status stats: %w", err)
}
for _, res := range mappingStatusResults {
resp.KeyStatusCount[res.Status] = res.Count
}
// --- 2. Aggregate Master Status from APIKeys ---
type MasterStatusResult struct {
Status models.MasterAPIKeyStatus
Count int64
}
var masterStatusResults []MasterStatusResult
if err := s.db.Model(&models.APIKey{}).Select("master_status as status, count(*) as count").Group("master_status").Find(&masterStatusResults).Error; err != nil {
if err := s.db.WithContext(ctx).Model(&models.APIKey{}).Select("master_status as status, count(*) as count").Group("master_status").Find(&masterStatusResults).Error; err != nil {
return nil, fmt.Errorf("failed to query master status stats: %w", err)
}
var totalKeys, invalidKeys int64
@@ -235,20 +232,15 @@ func (s *DashboardQueryService) loadOverviewData() (*models.DashboardStatsRespon
now := time.Now()
// 1. RPM (1 min), RPH (1 hour), RPD (today): precise counts from the "short-term memory" (request_logs)
var count1m, count1h, count1d int64
// RPM: the last 1 minute, counted back from now
s.db.Model(&models.RequestLog{}).Where("request_time >= ?", now.Add(-1*time.Minute)).Count(&count1m)
// RPH: the last 1 hour, counted back from now
s.db.Model(&models.RequestLog{}).Where("request_time >= ?", now.Add(-1*time.Hour)).Count(&count1h)
// RPD: from midnight today (UTC) until now
s.db.WithContext(ctx).Model(&models.RequestLog{}).Where("request_time >= ?", now.Add(-1*time.Minute)).Count(&count1m)
s.db.WithContext(ctx).Model(&models.RequestLog{}).Where("request_time >= ?", now.Add(-1*time.Hour)).Count(&count1h)
year, month, day := now.UTC().Date()
startOfDay := time.Date(year, month, day, 0, 0, 0, 0, time.UTC)
s.db.Model(&models.RequestLog{}).Where("request_time >= ?", startOfDay).Count(&count1d)
// 2. RP30D (30 days): an efficient query against the "long-term memory" (stats_hourly) to preserve performance
s.db.WithContext(ctx).Model(&models.RequestLog{}).Where("request_time >= ?", startOfDay).Count(&count1d)
var count30d int64
s.db.Model(&models.StatsHourly{}).Where("time >= ?", now.AddDate(0, 0, -30)).Select("COALESCE(SUM(request_count), 0)").Scan(&count30d)
s.db.WithContext(ctx).Model(&models.StatsHourly{}).Where("time >= ?", now.AddDate(0, 0, -30)).Select("COALESCE(SUM(request_count), 0)").Scan(&count30d)
resp.RequestCounts["1m"] = count1m
resp.RequestCounts["1h"] = count1h
@@ -256,7 +248,7 @@ func (s *DashboardQueryService) loadOverviewData() (*models.DashboardStatsRespon
resp.RequestCounts["30d"] = count30d
var upstreams []*models.UpstreamEndpoint
if err := s.db.Find(&upstreams).Error; err != nil {
if err := s.db.WithContext(ctx).Find(&upstreams).Error; err != nil {
s.logger.WithError(err).Warn("Failed to load upstream statuses for dashboard.")
} else {
for _, u := range upstreams {
@@ -269,7 +261,7 @@ func (s *DashboardQueryService) loadOverviewData() (*models.DashboardStatsRespon
return resp, nil
}
func (s *DashboardQueryService) GetRequestStatsForPeriod(period string) (gin.H, error) {
func (s *DashboardQueryService) GetRequestStatsForPeriod(ctx context.Context, period string) (gin.H, error) {
var startTime time.Time
now := time.Now()
switch period {
@@ -288,7 +280,7 @@ func (s *DashboardQueryService) GetRequestStatsForPeriod(period string) (gin.H,
Success int64
}
err := s.db.Model(&models.RequestLog{}).
err := s.db.WithContext(ctx).Model(&models.RequestLog{}).
Select("count(*) as total, sum(case when is_success = true then 1 else 0 end) as success").
Where("request_time >= ?", startTime).
Scan(&result).Error
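Every query chain above now begins with db.WithContext(ctx). A minimal sketch of why that matters — the context's deadline rides along into the SQL driver, so a slow aggregate can be cut off by the caller. The trimmed RequestLog struct and the 3-second timeout are assumptions:

package dashboards

import (
	"context"
	"time"

	"gorm.io/gorm"
)

// RequestLog is a trimmed stand-in for models.RequestLog.
type RequestLog struct {
	ID          uint
	RequestTime time.Time
	IsSuccess   bool
}

// CountRecent counts logs in the trailing window, bounded by a per-query deadline.
func CountRecent(db *gorm.DB, window time.Duration) (int64, error) {
	ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) // illustrative bound
	defer cancel()
	var n int64
	err := db.WithContext(ctx).
		Model(&RequestLog{}).
		Where("request_time >= ?", time.Now().Add(-window)).
		Count(&n).Error
	return n, err
}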

View File

@@ -1,8 +1,8 @@
// Filename: internal/service/db_log_writer_service.go
package service
import (
"context"
"encoding/json"
"gemini-balancer/internal/models"
"gemini-balancer/internal/settings"
@@ -35,35 +35,30 @@ func NewDBLogWriterService(db *gorm.DB, s store.Store, settings *settings.Settin
store: s,
SettingsManager: settings,
logger: logger.WithField("component", "DBLogWriter📝"),
// Size the buffer from the configured value
logBuffer: make(chan *models.RequestLog, bufferCapacity),
stopChan: make(chan struct{}),
logBuffer: make(chan *models.RequestLog, bufferCapacity),
stopChan: make(chan struct{}),
}
}
func (s *DBLogWriterService) Start() {
s.wg.Add(2) // one for the event listener, one for the DB writer
// Start the event listener
s.wg.Add(2)
go s.eventListenerLoop()
// Start the database writer
go s.dbWriterLoop()
s.logger.Info("DBLogWriterService started.")
}
func (s *DBLogWriterService) Stop() {
s.logger.Info("DBLogWriterService stopping...")
close(s.stopChan) // signal all goroutines to stop
s.wg.Wait() // wait for all goroutines to finish
close(s.stopChan)
s.wg.Wait()
s.logger.Info("DBLogWriterService stopped.")
}
// eventListenerLoop receives events from the store and places them into the in-memory buffer
func (s *DBLogWriterService) eventListenerLoop() {
defer s.wg.Done()
sub, err := s.store.Subscribe(models.TopicRequestFinished)
ctx := context.Background()
sub, err := s.store.Subscribe(ctx, models.TopicRequestFinished)
if err != nil {
s.logger.Fatalf("Failed to subscribe to topic %s: %v", models.TopicRequestFinished, err)
return
@@ -80,34 +75,27 @@ func (s *DBLogWriterService) eventListenerLoop() {
s.logger.Errorf("Failed to unmarshal event for logging: %v", err)
continue
}
// Put the log portion of the event into the buffer
select {
case s.logBuffer <- &event.RequestLog:
default:
s.logger.Warn("Log buffer is full. A log message might be dropped.")
}
case <-s.stopChan:
s.logger.Info("Event listener loop stopping.")
// Close the buffer so dbWriterLoop drains the remaining logs and exits
close(s.logBuffer)
return
}
}
}
// dbWriterLoop reads logs from the in-memory buffer in batches and writes them to the database
func (s *DBLogWriterService) dbWriterLoop() {
defer s.wg.Done()
// Read the configuration once at startup
cfg := s.SettingsManager.GetSettings()
batchSize := cfg.LogFlushBatchSize
if batchSize <= 0 {
batchSize = 100
}
flushTimeout := time.Duration(cfg.LogFlushIntervalSeconds) * time.Second
if flushTimeout <= 0 {
flushTimeout = 5 * time.Second
@@ -126,7 +114,7 @@ func (s *DBLogWriterService) dbWriterLoop() {
return
}
batch = append(batch, logEntry)
if len(batch) >= batchSize { // use the configured batch size
if len(batch) >= batchSize {
s.flushBatch(batch)
batch = make([]*models.RequestLog, 0, batchSize)
}
@@ -139,7 +127,6 @@ func (s *DBLogWriterService) dbWriterLoop() {
}
}
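dbWriterLoop follows a standard batch-drain shape: flush when the batch fills or a timer fires, and do a final drain when the channel closes. A generic sketch of that loop, with flush standing in for flushBatch:

package main

import (
	"fmt"
	"time"
)

// batchWriter drains a buffered channel into fixed-size batches, flushing on
// size, on a periodic tick, and once more on channel close before exiting.
func batchWriter[T any](in <-chan T, batchSize int, flushEvery time.Duration, flush func([]T)) {
	batch := make([]T, 0, batchSize)
	ticker := time.NewTicker(flushEvery)
	defer ticker.Stop()
	for {
		select {
		case item, ok := <-in:
			if !ok { // channel closed: final drain, then exit
				if len(batch) > 0 {
					flush(batch)
				}
				return
			}
			batch = append(batch, item)
			if len(batch) >= batchSize {
				flush(batch)
				batch = make([]T, 0, batchSize)
			}
		case <-ticker.C:
			if len(batch) > 0 {
				flush(batch)
				batch = make([]T, 0, batchSize)
			}
		}
	}
}

func main() {
	ch := make(chan int, 8)
	for i := 1; i <= 5; i++ {
		ch <- i
	}
	close(ch)
	batchWriter(ch, 2, time.Second, func(b []int) { fmt.Println("flush", b) })
}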
// flushBatch writes one batch of logs to the database
func (s *DBLogWriterService) flushBatch(batch []*models.RequestLog) {
if err := s.db.CreateInBatches(batch, len(batch)).Error; err != nil {
s.logger.WithField("batch_size", len(batch)).WithError(err).Error("Failed to flush log batch to database.")

View File

@@ -75,7 +75,7 @@ func NewHealthCheckService(
func (s *HealthCheckService) Start() {
s.logger.Info("Starting HealthCheckService with independent check loops...")
s.wg.Add(4) // Now four loops
s.wg.Add(4)
go s.runKeyCheckLoop()
go s.runUpstreamCheckLoop()
go s.runProxyCheckLoop()
@@ -102,8 +102,6 @@ func (s *HealthCheckService) GetLastHealthCheckResults() map[string]string {
func (s *HealthCheckService) runKeyCheckLoop() {
defer s.wg.Done()
s.logger.Info("Key check dynamic scheduler loop started.")
// Main scheduling loop: checks for due jobs once per minute
ticker := time.NewTicker(1 * time.Minute)
defer ticker.Stop()
@@ -126,26 +124,22 @@ func (s *HealthCheckService) scheduleKeyChecks() {
defer s.groupCheckTimeMutex.Unlock()
for _, group := range groups {
// Fetch the group-specific operational config
opConfig, err := s.groupManager.BuildOperationalConfig(group)
if err != nil {
s.logger.WithError(err).WithField("group_id", group.ID).Error("Failed to build operational config for group, skipping health check scheduling.")
continue
}
if opConfig.EnableKeyCheck == nil || !*opConfig.EnableKeyCheck {
continue // skip groups with health checks disabled
continue
}
var intervalMinutes int
if opConfig.KeyCheckIntervalMinutes != nil {
intervalMinutes = *opConfig.KeyCheckIntervalMinutes
}
interval := time.Duration(intervalMinutes) * time.Minute
if interval <= 0 {
continue // skip invalid check intervals
continue
}
if nextCheckTime, ok := s.groupNextCheckTime[group.ID]; !ok || now.After(nextCheckTime) {
s.logger.Infof("Scheduling key check for group '%s' (ID: %d)", group.Name, group.ID)
go s.performKeyChecksForGroup(group, opConfig)
@@ -160,7 +154,6 @@ func (s *HealthCheckService) runUpstreamCheckLoop() {
if s.SettingsManager.GetSettings().EnableUpstreamCheck {
s.performUpstreamChecks()
}
ticker := time.NewTicker(time.Duration(s.SettingsManager.GetSettings().UpstreamCheckIntervalSeconds) * time.Second)
defer ticker.Stop()
@@ -184,7 +177,6 @@ func (s *HealthCheckService) runProxyCheckLoop() {
if s.SettingsManager.GetSettings().EnableProxyCheck {
s.performProxyChecks()
}
ticker := time.NewTicker(time.Duration(s.SettingsManager.GetSettings().ProxyCheckIntervalSeconds) * time.Second)
defer ticker.Stop()
@@ -203,6 +195,7 @@ func (s *HealthCheckService) runProxyCheckLoop() {
}
func (s *HealthCheckService) performKeyChecksForGroup(group *models.KeyGroup, opConfig *models.KeyGroupSettings) {
ctx := context.Background()
settings := s.SettingsManager.GetSettings()
timeout := time.Duration(settings.KeyCheckTimeoutSeconds) * time.Second
@@ -213,11 +206,9 @@ func (s *HealthCheckService) performKeyChecksForGroup(group *models.KeyGroup, op
}
log := s.logger.WithFields(logrus.Fields{"group_id": group.ID, "group_name": group.Name})
log.Infof("Starting key health check cycle.")
var mappingsToCheck []models.GroupAPIKeyMapping
err = s.db.Model(&models.GroupAPIKeyMapping{}).
err = s.db.WithContext(ctx).Model(&models.GroupAPIKeyMapping{}).
Joins("JOIN api_keys ON api_keys.id = group_api_key_mappings.api_key_id").
Where("group_api_key_mappings.key_group_id = ?", group.ID).
Where("api_keys.master_status = ?", models.MasterStatusActive).
@@ -233,7 +224,6 @@ func (s *HealthCheckService) performKeyChecksForGroup(group *models.KeyGroup, op
log.Info("No key mappings to check for this group.")
return
}
log.Infof("Starting health check for %d key mappings.", len(mappingsToCheck))
jobs := make(chan models.GroupAPIKeyMapping, len(mappingsToCheck))
var wg sync.WaitGroup
@@ -242,14 +232,14 @@ func (s *HealthCheckService) performKeyChecksForGroup(group *models.KeyGroup, op
concurrency = *opConfig.KeyCheckConcurrency
}
if concurrency <= 0 {
concurrency = 1 // guarantee at least one worker
concurrency = 1
}
for w := 1; w <= concurrency; w++ {
wg.Add(1)
go func(workerID int) {
defer wg.Done()
for mapping := range jobs {
s.checkAndProcessMapping(&mapping, timeout, endpoint)
s.checkAndProcessMapping(ctx, &mapping, timeout, endpoint)
}
}(w)
}
@@ -261,52 +251,46 @@ func (s *HealthCheckService) performKeyChecksForGroup(group *models.KeyGroup, op
log.Info("Finished key health check cycle.")
}
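The fan-out above is the classic bounded worker pool: a jobs channel sized to the work list, a fixed number of workers, and a WaitGroup joined before declaring the cycle finished. A stripped-down sketch, with check standing in for checkAndProcessMapping:

package main

import (
	"fmt"
	"sync"
)

// runPool fans jobs out to a bounded set of workers and waits for all of
// them to drain the channel before returning.
func runPool(jobs []int, concurrency int, check func(int)) {
	ch := make(chan int, len(jobs))
	var wg sync.WaitGroup
	for w := 0; w < concurrency; w++ {
		wg.Add(1)
		go func() {
			defer wg.Done()
			for j := range ch {
				check(j)
			}
		}()
	}
	for _, j := range jobs {
		ch <- j
	}
	close(ch)
	wg.Wait()
}

func main() {
	runPool([]int{1, 2, 3, 4, 5}, 2, func(j int) { fmt.Println("checked", j) })
}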
func (s *HealthCheckService) checkAndProcessMapping(mapping *models.GroupAPIKeyMapping, timeout time.Duration, endpoint string) {
func (s *HealthCheckService) checkAndProcessMapping(ctx context.Context, mapping *models.GroupAPIKeyMapping, timeout time.Duration, endpoint string) {
if mapping.APIKey == nil {
s.logger.Warnf("Skipping check for mapping (G:%d, K:%d) because associated APIKey is nil.", mapping.KeyGroupID, mapping.APIKeyID)
return
}
validationErr := s.keyValidationService.ValidateSingleKey(mapping.APIKey, timeout, endpoint)
// --- Diagnosis 1: validation succeeded (healthy) ---
if validationErr == nil {
if mapping.Status != models.StatusActive {
s.activateMapping(mapping)
s.activateMapping(ctx, mapping)
}
return
}
errorString := validationErr.Error()
// --- Diagnosis 2: permanent error ---
if CustomErrors.IsPermanentUpstreamError(errorString) {
s.revokeMapping(mapping, validationErr)
s.revokeMapping(ctx, mapping, validationErr)
return
}
// --- Diagnosis 3: temporary error ---
if CustomErrors.IsTemporaryUpstreamError(errorString) {
// Log with a higher level (WARN) since this is an actionable, proactive finding.
s.logger.Warnf("Health check for key %d failed with temporary error, applying penalty. Reason: %v", mapping.APIKeyID, validationErr)
s.penalizeMapping(mapping, validationErr)
s.penalizeMapping(ctx, mapping, validationErr)
return
}
// --- Diagnosis 4: other unknown or upstream-service errors ---
s.logger.Errorf("Health check for key %d failed with a transient or unknown upstream error: %v. Mapping will not be penalized by health check.", mapping.APIKeyID, validationErr)
}
func (s *HealthCheckService) activateMapping(mapping *models.GroupAPIKeyMapping) {
func (s *HealthCheckService) activateMapping(ctx context.Context, mapping *models.GroupAPIKeyMapping) {
oldStatus := mapping.Status
mapping.Status = models.StatusActive
mapping.ConsecutiveErrorCount = 0
mapping.LastError = ""
if err := s.keyRepo.UpdateMapping(mapping); err != nil {
if err := s.keyRepo.UpdateMapping(ctx, mapping); err != nil {
s.logger.WithError(err).Errorf("Failed to activate mapping (G:%d, K:%d).", mapping.KeyGroupID, mapping.APIKeyID)
return
}
s.logger.Infof("Mapping (G:%d, K:%d) successfully activated from %s to ACTIVE.", mapping.KeyGroupID, mapping.APIKeyID, oldStatus)
s.publishKeyStatusChangedEvent(mapping.KeyGroupID, mapping.APIKeyID, oldStatus, mapping.Status)
s.publishKeyStatusChangedEvent(ctx, mapping.KeyGroupID, mapping.APIKeyID, oldStatus, mapping.Status)
}
func (s *HealthCheckService) penalizeMapping(mapping *models.GroupAPIKeyMapping, err error) {
// Re-fetch group-specific operational config to get the correct thresholds
func (s *HealthCheckService) penalizeMapping(ctx context.Context, mapping *models.GroupAPIKeyMapping, err error) {
group, ok := s.groupManager.GetGroupByID(mapping.KeyGroupID)
if !ok {
s.logger.Errorf("Could not find group with ID %d to apply penalty.", mapping.KeyGroupID)
@@ -320,7 +304,6 @@ func (s *HealthCheckService) penalizeMapping(mapping *models.GroupAPIKeyMapping,
oldStatus := mapping.Status
mapping.LastError = err.Error()
mapping.ConsecutiveErrorCount++
// Use the group-specific threshold
threshold := *opConfig.KeyBlacklistThreshold
if mapping.ConsecutiveErrorCount >= threshold {
mapping.Status = models.StatusCooldown
@@ -329,44 +312,41 @@ func (s *HealthCheckService) penalizeMapping(mapping *models.GroupAPIKeyMapping,
mapping.CooldownUntil = &cooldownTime
s.logger.Warnf("Mapping (G:%d, K:%d) reached error threshold (%d) during health check and is now in COOLDOWN for %v.", mapping.KeyGroupID, mapping.APIKeyID, threshold, cooldownDuration)
}
if errDb := s.keyRepo.UpdateMapping(mapping); errDb != nil {
if errDb := s.keyRepo.UpdateMapping(ctx, mapping); errDb != nil {
s.logger.WithError(errDb).Errorf("Failed to penalize mapping (G:%d, K:%d).", mapping.KeyGroupID, mapping.APIKeyID)
return
}
if oldStatus != mapping.Status {
s.publishKeyStatusChangedEvent(mapping.KeyGroupID, mapping.APIKeyID, oldStatus, mapping.Status)
s.publishKeyStatusChangedEvent(ctx, mapping.KeyGroupID, mapping.APIKeyID, oldStatus, mapping.Status)
}
}
func (s *HealthCheckService) revokeMapping(mapping *models.GroupAPIKeyMapping, err error) {
func (s *HealthCheckService) revokeMapping(ctx context.Context, mapping *models.GroupAPIKeyMapping, err error) {
oldStatus := mapping.Status
if oldStatus == models.StatusBanned {
return // Already banned, do nothing.
return
}
mapping.Status = models.StatusBanned
mapping.LastError = "Definitive error: " + err.Error()
mapping.ConsecutiveErrorCount = 0 // Reset counter as this is a final state for this group
if errDb := s.keyRepo.UpdateMapping(mapping); errDb != nil {
mapping.ConsecutiveErrorCount = 0
if errDb := s.keyRepo.UpdateMapping(ctx, mapping); errDb != nil {
s.logger.WithError(errDb).Errorf("Failed to revoke mapping (G:%d, K:%d).", mapping.KeyGroupID, mapping.APIKeyID)
return
}
s.logger.Warnf("Mapping (G:%d, K:%d) has been BANNED due to a definitive error: %v", mapping.KeyGroupID, mapping.APIKeyID, err)
s.publishKeyStatusChangedEvent(mapping.KeyGroupID, mapping.APIKeyID, oldStatus, mapping.Status)
s.publishKeyStatusChangedEvent(ctx, mapping.KeyGroupID, mapping.APIKeyID, oldStatus, mapping.Status)
s.logger.Infof("Triggering MasterStatus update for definitively failed key ID %d.", mapping.APIKeyID)
if err := s.keyRepo.UpdateAPIKeyStatus(mapping.APIKeyID, models.MasterStatusRevoked); err != nil {
if err := s.keyRepo.UpdateAPIKeyStatus(ctx, mapping.APIKeyID, models.MasterStatusRevoked); err != nil {
s.logger.WithError(err).Errorf("Failed to update master status for key ID %d after group-level ban.", mapping.APIKeyID)
}
}
func (s *HealthCheckService) performUpstreamChecks() {
ctx := context.Background()
settings := s.SettingsManager.GetSettings()
timeout := time.Duration(settings.UpstreamCheckTimeoutSeconds) * time.Second
var upstreams []*models.UpstreamEndpoint
if err := s.db.Find(&upstreams).Error; err != nil {
if err := s.db.WithContext(ctx).Find(&upstreams).Error; err != nil {
s.logger.WithError(err).Error("Failed to retrieve upstreams.")
return
}
@@ -390,10 +370,10 @@ func (s *HealthCheckService) performUpstreamChecks() {
s.lastResultsMutex.Unlock()
if oldStatus != newStatus {
s.logger.Infof("Upstream '%s' status changed from %s -> %s", upstream.URL, oldStatus, newStatus)
if err := s.db.Model(upstream).Update("status", newStatus).Error; err != nil {
if err := s.db.WithContext(ctx).Model(upstream).Update("status", newStatus).Error; err != nil {
s.logger.WithError(err).WithField("upstream_id", upstream.ID).Error("Failed to update upstream status.")
} else {
s.publishUpstreamHealthChangedEvent(upstream, oldStatus, newStatus)
s.publishUpstreamHealthChangedEvent(ctx, upstream, oldStatus, newStatus)
}
}
}(u)
@@ -412,10 +392,11 @@ func (s *HealthCheckService) checkEndpoint(urlStr string, timeout time.Duration)
}
func (s *HealthCheckService) performProxyChecks() {
ctx := context.Background()
settings := s.SettingsManager.GetSettings()
timeout := time.Duration(settings.ProxyCheckTimeoutSeconds) * time.Second
var proxies []*models.ProxyConfig
if err := s.db.Find(&proxies).Error; err != nil {
if err := s.db.WithContext(ctx).Find(&proxies).Error; err != nil {
s.logger.WithError(err).Error("Failed to retrieve proxies.")
return
}
@@ -438,7 +419,7 @@ func (s *HealthCheckService) performProxyChecks() {
s.lastResultsMutex.Unlock()
if proxyCfg.Status != newStatus {
s.logger.Infof("Proxy '%s' status changed from %s -> %s", proxyCfg.Address, proxyCfg.Status, newStatus)
if err := s.db.Model(proxyCfg).Update("status", newStatus).Error; err != nil {
if err := s.db.WithContext(ctx).Model(proxyCfg).Update("status", newStatus).Error; err != nil {
s.logger.WithError(err).WithField("proxy_id", proxyCfg.ID).Error("Failed to update proxy status.")
}
}
@@ -482,7 +463,7 @@ func (s *HealthCheckService) checkProxy(proxyCfg *models.ProxyConfig, timeout ti
return true
}
func (s *HealthCheckService) publishKeyStatusChangedEvent(groupID, keyID uint, oldStatus, newStatus models.APIKeyStatus) {
func (s *HealthCheckService) publishKeyStatusChangedEvent(ctx context.Context, groupID, keyID uint, oldStatus, newStatus models.APIKeyStatus) {
event := models.KeyStatusChangedEvent{
KeyID: keyID,
GroupID: groupID,
@@ -496,12 +477,12 @@ func (s *HealthCheckService) publishKeyStatusChangedEvent(groupID, keyID uint, o
s.logger.WithError(err).Errorf("Failed to marshal KeyStatusChangedEvent for group %d.", groupID)
return
}
if err := s.store.Publish(models.TopicKeyStatusChanged, payload); err != nil {
if err := s.store.Publish(ctx, models.TopicKeyStatusChanged, payload); err != nil {
s.logger.WithError(err).Errorf("Failed to publish KeyStatusChangedEvent for group %d.", groupID)
}
}
func (s *HealthCheckService) publishUpstreamHealthChangedEvent(upstream *models.UpstreamEndpoint, oldStatus, newStatus string) {
func (s *HealthCheckService) publishUpstreamHealthChangedEvent(ctx context.Context, upstream *models.UpstreamEndpoint, oldStatus, newStatus string) {
event := models.UpstreamHealthChangedEvent{
UpstreamID: upstream.ID,
UpstreamURL: upstream.URL,
@@ -516,28 +497,20 @@ func (s *HealthCheckService) publishUpstreamHealthChangedEvent(upstream *models.
s.logger.WithError(err).Error("Failed to marshal UpstreamHealthChangedEvent.")
return
}
if err := s.store.Publish(models.TopicUpstreamHealthChanged, payload); err != nil {
if err := s.store.Publish(ctx, models.TopicUpstreamHealthChanged, payload); err != nil {
s.logger.WithError(err).Error("Failed to publish UpstreamHealthChangedEvent.")
}
}
// =========================================================================
// Global Base Key Check (New Logic)
// =========================================================================
func (s *HealthCheckService) runBaseKeyCheckLoop() {
defer s.wg.Done()
s.logger.Info("Global base key check loop started.")
settings := s.SettingsManager.GetSettings()
if !settings.EnableBaseKeyCheck {
s.logger.Info("Global base key check is disabled.")
return
}
// Perform an initial check on startup
s.performBaseKeyChecks()
interval := time.Duration(settings.BaseKeyCheckIntervalMinutes) * time.Minute
if interval <= 0 {
s.logger.Warnf("Invalid BaseKeyCheckIntervalMinutes: %d. Disabling base key check loop.", settings.BaseKeyCheckIntervalMinutes)
@@ -558,6 +531,7 @@ func (s *HealthCheckService) runBaseKeyCheckLoop() {
}
func (s *HealthCheckService) performBaseKeyChecks() {
ctx := context.Background()
s.logger.Info("Starting global base key check cycle.")
settings := s.SettingsManager.GetSettings()
timeout := time.Duration(settings.KeyCheckTimeoutSeconds) * time.Second
@@ -576,7 +550,7 @@ func (s *HealthCheckService) performBaseKeyChecks() {
jobs := make(chan *models.APIKey, len(keys))
var wg sync.WaitGroup
if concurrency <= 0 {
concurrency = 5 // Safe default
concurrency = 5
}
for w := 0; w < concurrency; w++ {
wg.Add(1)
@@ -587,10 +561,10 @@ func (s *HealthCheckService) performBaseKeyChecks() {
if err != nil && CustomErrors.IsPermanentUpstreamError(err.Error()) {
oldStatus := key.MasterStatus
s.logger.Warnf("Key ID %d (%s...) failed definitive base check. Setting MasterStatus to REVOKED. Reason: %v", key.ID, key.APIKey[:4], err)
if updateErr := s.keyRepo.UpdateAPIKeyStatus(key.ID, models.MasterStatusRevoked); updateErr != nil {
if updateErr := s.keyRepo.UpdateAPIKeyStatus(ctx, key.ID, models.MasterStatusRevoked); updateErr != nil {
s.logger.WithError(updateErr).Errorf("Failed to update master status for key ID %d.", key.ID)
} else {
s.publishMasterKeyStatusChangedEvent(key.ID, oldStatus, models.MasterStatusRevoked)
s.publishMasterKeyStatusChangedEvent(ctx, key.ID, oldStatus, models.MasterStatusRevoked)
}
}
}
@@ -604,8 +578,7 @@ func (s *HealthCheckService) performBaseKeyChecks() {
s.logger.Info("Global base key check cycle finished.")
}
// Event-publishing helper
func (s *HealthCheckService) publishMasterKeyStatusChangedEvent(keyID uint, oldStatus, newStatus models.MasterAPIKeyStatus) {
func (s *HealthCheckService) publishMasterKeyStatusChangedEvent(ctx context.Context, keyID uint, oldStatus, newStatus models.MasterAPIKeyStatus) {
event := models.MasterKeyStatusChangedEvent{
KeyID: keyID,
OldMasterStatus: oldStatus,
@@ -618,7 +591,7 @@ func (s *HealthCheckService) publishMasterKeyStatusChangedEvent(keyID uint, oldS
s.logger.WithError(err).Errorf("Failed to marshal MasterKeyStatusChangedEvent for key %d.", keyID)
return
}
if err := s.store.Publish(models.TopicMasterKeyStatusChanged, payload); err != nil {
if err := s.store.Publish(ctx, models.TopicMasterKeyStatusChanged, payload); err != nil {
s.logger.WithError(err).Errorf("Failed to publish MasterKeyStatusChangedEvent for key %d.", keyID)
}
}
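Taken together, the changes in this file follow one pattern: each background check cycle creates a single root context and threads it through both the GORM reads and the store publishes. A minimal illustrative sketch (the method name and event values below are placeholders, not part of the commit):

func (s *HealthCheckService) exampleCycle() {
    ctx := context.Background() // one root context per check cycle
    var proxies []*models.ProxyConfig
    if err := s.db.WithContext(ctx).Find(&proxies).Error; err != nil {
        s.logger.WithError(err).Error("Failed to retrieve proxies.")
        return
    }
    // The same ctx reaches event publishing, so a caller could later cancel
    // an entire cycle, DB reads and pub/sub included.
    payload, _ := json.Marshal(models.KeyStatusChangedEvent{ChangeReason: "example"})
    _ = s.store.Publish(ctx, models.TopicKeyStatusChanged, payload)
}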

View File

@@ -2,6 +2,7 @@
package service
import (
"context"
"encoding/json"
"fmt"
"gemini-balancer/internal/models"
@@ -42,88 +43,84 @@ func NewKeyImportService(ts task.Reporter, kr repository.KeyRepository, s store.
}
}
// --- Generic panic-safe task runner ---
func (s *KeyImportService) runTaskWithRecovery(taskID string, resourceID string, taskFunc func()) {
func (s *KeyImportService) runTaskWithRecovery(ctx context.Context, taskID string, resourceID string, taskFunc func()) {
defer func() {
if r := recover(); r != nil {
err := fmt.Errorf("panic recovered in task %s: %v", taskID, r)
s.logger.Error(err)
s.taskService.EndTaskByID(taskID, resourceID, nil, err)
s.taskService.EndTaskByID(ctx, taskID, resourceID, nil, err)
}
}()
taskFunc()
}
// --- Public Task Starters ---
func (s *KeyImportService) StartAddKeysTask(groupID uint, keysText string, validateOnImport bool) (*task.Status, error) {
func (s *KeyImportService) StartAddKeysTask(ctx context.Context, groupID uint, keysText string, validateOnImport bool) (*task.Status, error) {
keys := utils.ParseKeysFromText(keysText)
if len(keys) == 0 {
return nil, fmt.Errorf("no valid keys found in input text")
}
resourceID := fmt.Sprintf("group-%d", groupID)
taskStatus, err := s.taskService.StartTask(uint(groupID), TaskTypeAddKeysToGroup, resourceID, len(keys), 15*time.Minute)
taskStatus, err := s.taskService.StartTask(ctx, uint(groupID), TaskTypeAddKeysToGroup, resourceID, len(keys), 15*time.Minute)
if err != nil {
return nil, err
}
go s.runTaskWithRecovery(taskStatus.ID, resourceID, func() {
s.runAddKeysTask(taskStatus.ID, resourceID, groupID, keys, validateOnImport)
go s.runTaskWithRecovery(context.Background(), taskStatus.ID, resourceID, func() {
s.runAddKeysTask(context.Background(), taskStatus.ID, resourceID, groupID, keys, validateOnImport)
})
return taskStatus, nil
}
func (s *KeyImportService) StartUnlinkKeysTask(groupID uint, keysText string) (*task.Status, error) {
func (s *KeyImportService) StartUnlinkKeysTask(ctx context.Context, groupID uint, keysText string) (*task.Status, error) {
keys := utils.ParseKeysFromText(keysText)
if len(keys) == 0 {
return nil, fmt.Errorf("no valid keys found")
}
resourceID := fmt.Sprintf("group-%d", groupID)
taskStatus, err := s.taskService.StartTask(uint(groupID), TaskTypeUnlinkKeysFromGroup, resourceID, len(keys), time.Hour)
taskStatus, err := s.taskService.StartTask(ctx, uint(groupID), TaskTypeUnlinkKeysFromGroup, resourceID, len(keys), time.Hour)
if err != nil {
return nil, err
}
go s.runTaskWithRecovery(taskStatus.ID, resourceID, func() {
s.runUnlinkKeysTask(taskStatus.ID, resourceID, groupID, keys)
go s.runTaskWithRecovery(context.Background(), taskStatus.ID, resourceID, func() {
s.runUnlinkKeysTask(context.Background(), taskStatus.ID, resourceID, groupID, keys)
})
return taskStatus, nil
}
func (s *KeyImportService) StartHardDeleteKeysTask(keysText string) (*task.Status, error) {
func (s *KeyImportService) StartHardDeleteKeysTask(ctx context.Context, keysText string) (*task.Status, error) {
keys := utils.ParseKeysFromText(keysText)
if len(keys) == 0 {
return nil, fmt.Errorf("no valid keys found")
}
resourceID := "global_hard_delete" // Global lock
taskStatus, err := s.taskService.StartTask(0, TaskTypeHardDeleteKeys, resourceID, len(keys), time.Hour)
resourceID := "global_hard_delete"
taskStatus, err := s.taskService.StartTask(ctx, 0, TaskTypeHardDeleteKeys, resourceID, len(keys), time.Hour)
if err != nil {
return nil, err
}
go s.runTaskWithRecovery(taskStatus.ID, resourceID, func() {
s.runHardDeleteKeysTask(taskStatus.ID, resourceID, keys)
go s.runTaskWithRecovery(context.Background(), taskStatus.ID, resourceID, func() {
s.runHardDeleteKeysTask(context.Background(), taskStatus.ID, resourceID, keys)
})
return taskStatus, nil
}
func (s *KeyImportService) StartRestoreKeysTask(keysText string) (*task.Status, error) {
func (s *KeyImportService) StartRestoreKeysTask(ctx context.Context, keysText string) (*task.Status, error) {
keys := utils.ParseKeysFromText(keysText)
if len(keys) == 0 {
return nil, fmt.Errorf("no valid keys found")
}
resourceID := "global_restore_keys" // Global lock
taskStatus, err := s.taskService.StartTask(0, TaskTypeRestoreKeys, resourceID, len(keys), time.Hour)
resourceID := "global_restore_keys"
taskStatus, err := s.taskService.StartTask(ctx, 0, TaskTypeRestoreKeys, resourceID, len(keys), time.Hour)
if err != nil {
return nil, err
}
go s.runTaskWithRecovery(taskStatus.ID, resourceID, func() {
s.runRestoreKeysTask(taskStatus.ID, resourceID, keys)
go s.runTaskWithRecovery(context.Background(), taskStatus.ID, resourceID, func() {
s.runRestoreKeysTask(context.Background(), taskStatus.ID, resourceID, keys)
})
return taskStatus, nil
}
// --- Private Task Runners ---
func (s *KeyImportService) runAddKeysTask(taskID string, resourceID string, groupID uint, keys []string, validateOnImport bool) {
// Step 1: de-duplicate the raw input key list.
func (s *KeyImportService) runAddKeysTask(ctx context.Context, taskID string, resourceID string, groupID uint, keys []string, validateOnImport bool) {
uniqueKeysMap := make(map[string]struct{})
var uniqueKeyStrings []string
for _, kStr := range keys {
@@ -133,41 +130,37 @@ func (s *KeyImportService) runAddKeysTask(taskID string, resourceID string, grou
}
}
if len(uniqueKeyStrings) == 0 {
s.taskService.EndTaskByID(taskID, resourceID, gin.H{"newly_linked_count": 0, "already_linked_count": 0}, nil)
s.taskService.EndTaskByID(ctx, taskID, resourceID, gin.H{"newly_linked_count": 0, "already_linked_count": 0}, nil)
return
}
// Step 2: ensure every key exists in the master table (create or restore) and fetch the complete entities.
keysToEnsure := make([]models.APIKey, len(uniqueKeyStrings))
for i, keyStr := range uniqueKeyStrings {
keysToEnsure[i] = models.APIKey{APIKey: keyStr}
}
allKeyModels, err := s.keyRepo.AddKeys(keysToEnsure)
if err != nil {
s.taskService.EndTaskByID(taskID, resourceID, nil, fmt.Errorf("failed to ensure keys exist: %w", err))
s.taskService.EndTaskByID(ctx, taskID, resourceID, nil, fmt.Errorf("failed to ensure keys exist: %w", err))
return
}
// Step 3: find which of these keys are already linked to the current group.
alreadyLinkedModels, err := s.keyRepo.GetKeysByValuesAndGroupID(uniqueKeyStrings, groupID)
if err != nil {
s.taskService.EndTaskByID(taskID, resourceID, nil, fmt.Errorf("failed to check for already linked keys: %w", err))
s.taskService.EndTaskByID(ctx, taskID, resourceID, nil, fmt.Errorf("failed to check for already linked keys: %w", err))
return
}
alreadyLinkedIDSet := make(map[uint]struct{})
for _, key := range alreadyLinkedModels {
alreadyLinkedIDSet[key.ID] = struct{}{}
}
// Step 4: determine the keys that actually still need to be linked to this group (our "working set").
var keysToLink []models.APIKey
for _, key := range allKeyModels {
if _, exists := alreadyLinkedIDSet[key.ID]; !exists {
keysToLink = append(keysToLink, key)
}
}
// Step 5: update the task's Total to the exact size of the working set.
if err := s.taskService.UpdateTotalByID(taskID, len(keysToLink)); err != nil {
if err := s.taskService.UpdateTotalByID(ctx, taskID, len(keysToLink)); err != nil {
s.logger.WithError(err).Warnf("Failed to update total for task %s", taskID)
}
// Step 6: link keys to the group in chunks, updating progress in real time.
if len(keysToLink) > 0 {
idsToLink := make([]uint, len(keysToLink))
for i, key := range keysToLink {
@@ -179,44 +172,41 @@ func (s *KeyImportService) runAddKeysTask(taskID string, resourceID string, grou
end = len(idsToLink)
}
chunk := idsToLink[i:end]
if err := s.keyRepo.LinkKeysToGroup(groupID, chunk); err != nil {
s.taskService.EndTaskByID(taskID, resourceID, nil, fmt.Errorf("chunk failed to link keys: %w", err))
if err := s.keyRepo.LinkKeysToGroup(ctx, groupID, chunk); err != nil {
s.taskService.EndTaskByID(ctx, taskID, resourceID, nil, fmt.Errorf("chunk failed to link keys: %w", err))
return
}
_ = s.taskService.UpdateProgressByID(taskID, i+len(chunk))
_ = s.taskService.UpdateProgressByID(ctx, taskID, i+len(chunk))
}
}
// Step 7: assemble the final result and end the task.
result := gin.H{
"newly_linked_count": len(keysToLink),
"already_linked_count": len(alreadyLinkedIDSet),
"total_linked_count": len(allKeyModels),
}
// Step 8: depending on the `validateOnImport` flag, publish events or activate directly (newly linked keys only).
if len(keysToLink) > 0 {
idsToLink := make([]uint, len(keysToLink))
for i, key := range keysToLink {
idsToLink[i] = key.ID
}
if validateOnImport {
s.publishImportGroupCompletedEvent(groupID, idsToLink)
s.publishImportGroupCompletedEvent(ctx, groupID, idsToLink)
for _, keyID := range idsToLink {
s.publishSingleKeyChangeEvent(groupID, keyID, "", models.StatusPendingValidation, "key_linked")
s.publishSingleKeyChangeEvent(ctx, groupID, keyID, "", models.StatusPendingValidation, "key_linked")
}
} else {
for _, keyID := range idsToLink {
if _, err := s.apiKeyService.UpdateMappingStatus(groupID, keyID, models.StatusActive); err != nil {
if _, err := s.apiKeyService.UpdateMappingStatus(ctx, groupID, keyID, models.StatusActive); err != nil {
s.logger.Errorf("Failed to directly activate key ID %d in group %d: %v", keyID, groupID, err)
}
}
}
}
s.taskService.EndTaskByID(taskID, resourceID, result, nil)
s.taskService.EndTaskByID(ctx, taskID, resourceID, result, nil)
}
// runUnlinkKeysTask
func (s *KeyImportService) runUnlinkKeysTask(taskID, resourceID string, groupID uint, keys []string) {
func (s *KeyImportService) runUnlinkKeysTask(ctx context.Context, taskID, resourceID string, groupID uint, keys []string) {
uniqueKeysMap := make(map[string]struct{})
var uniqueKeys []string
for _, kStr := range keys {
@@ -225,46 +215,42 @@ func (s *KeyImportService) runUnlinkKeysTask(taskID, resourceID string, groupID
uniqueKeys = append(uniqueKeys, kStr)
}
}
// Step 1: in a single query, find the key entities from the input that actually exist in this group. This is our "working set".
keysToUnlink, err := s.keyRepo.GetKeysByValuesAndGroupID(uniqueKeys, groupID)
if err != nil {
s.taskService.EndTaskByID(taskID, resourceID, nil, fmt.Errorf("failed to find keys to unlink: %w", err))
s.taskService.EndTaskByID(ctx, taskID, resourceID, nil, fmt.Errorf("failed to find keys to unlink: %w", err))
return
}
if len(keysToUnlink) == 0 {
result := gin.H{"unlinked_count": 0, "hard_deleted_count": 0, "not_found_count": len(uniqueKeys)}
s.taskService.EndTaskByID(taskID, resourceID, result, nil)
s.taskService.EndTaskByID(ctx, taskID, resourceID, result, nil)
return
}
idsToUnlink := make([]uint, len(keysToUnlink))
for i, key := range keysToUnlink {
idsToUnlink[i] = key.ID
}
// Step 2: update the task's Total to the exact size of the working set.
if err := s.taskService.UpdateTotalByID(taskID, len(idsToUnlink)); err != nil {
if err := s.taskService.UpdateTotalByID(ctx, taskID, len(idsToUnlink)); err != nil {
s.logger.WithError(err).Warnf("Failed to update total for task %s", taskID)
}
var totalUnlinked int64
// Step 3: unlink keys in chunks and report progress.
for i := 0; i < len(idsToUnlink); i += chunkSize {
end := i + chunkSize
if end > len(idsToUnlink) {
end = len(idsToUnlink)
}
chunk := idsToUnlink[i:end]
unlinked, err := s.keyRepo.UnlinkKeysFromGroup(groupID, chunk)
unlinked, err := s.keyRepo.UnlinkKeysFromGroup(ctx, groupID, chunk)
if err != nil {
s.taskService.EndTaskByID(taskID, resourceID, nil, fmt.Errorf("chunk failed: could not unlink keys: %w", err))
s.taskService.EndTaskByID(ctx, taskID, resourceID, nil, fmt.Errorf("chunk failed: could not unlink keys: %w", err))
return
}
totalUnlinked += unlinked
for _, keyID := range chunk {
s.publishSingleKeyChangeEvent(groupID, keyID, models.StatusActive, "", "key_unlinked")
s.publishSingleKeyChangeEvent(ctx, groupID, keyID, models.StatusActive, "", "key_unlinked")
}
_ = s.taskService.UpdateProgressByID(taskID, i+len(chunk))
_ = s.taskService.UpdateProgressByID(ctx, taskID, i+len(chunk))
}
totalDeleted, err := s.keyRepo.DeleteOrphanKeys()
@@ -276,10 +262,10 @@ func (s *KeyImportService) runUnlinkKeysTask(taskID, resourceID string, groupID
"hard_deleted_count": totalDeleted,
"not_found_count": len(uniqueKeys) - int(totalUnlinked),
}
s.taskService.EndTaskByID(taskID, resourceID, result, nil)
s.taskService.EndTaskByID(ctx, taskID, resourceID, result, nil)
}
func (s *KeyImportService) runHardDeleteKeysTask(taskID, resourceID string, keys []string) {
func (s *KeyImportService) runHardDeleteKeysTask(ctx context.Context, taskID, resourceID string, keys []string) {
var totalDeleted int64
for i := 0; i < len(keys); i += chunkSize {
end := i + chunkSize
@@ -290,22 +276,21 @@ func (s *KeyImportService) runHardDeleteKeysTask(taskID, resourceID string, keys
deleted, err := s.keyRepo.HardDeleteByValues(chunk)
if err != nil {
s.taskService.EndTaskByID(taskID, resourceID, nil, fmt.Errorf("failed to hard delete chunk: %w", err))
s.taskService.EndTaskByID(ctx, taskID, resourceID, nil, fmt.Errorf("failed to hard delete chunk: %w", err))
return
}
totalDeleted += deleted
_ = s.taskService.UpdateProgressByID(taskID, i+len(chunk))
_ = s.taskService.UpdateProgressByID(ctx, taskID, i+len(chunk))
}
result := gin.H{
"hard_deleted_count": totalDeleted,
"not_found_count": int64(len(keys)) - totalDeleted,
}
s.taskService.EndTaskByID(taskID, resourceID, result, nil)
s.publishChangeEvent(0, "keys_hard_deleted") // Global event
s.taskService.EndTaskByID(ctx, taskID, resourceID, result, nil)
s.publishChangeEvent(ctx, 0, "keys_hard_deleted")
}
func (s *KeyImportService) runRestoreKeysTask(taskID, resourceID string, keys []string) {
func (s *KeyImportService) runRestoreKeysTask(ctx context.Context, taskID, resourceID string, keys []string) {
var restoredCount int64
for i := 0; i < len(keys); i += chunkSize {
end := i + chunkSize
@@ -316,21 +301,21 @@ func (s *KeyImportService) runRestoreKeysTask(taskID, resourceID string, keys []
count, err := s.keyRepo.UpdateMasterStatusByValues(chunk, models.MasterStatusActive)
if err != nil {
s.taskService.EndTaskByID(taskID, resourceID, nil, fmt.Errorf("failed to restore chunk: %w", err))
s.taskService.EndTaskByID(ctx, taskID, resourceID, nil, fmt.Errorf("failed to restore chunk: %w", err))
return
}
restoredCount += count
_ = s.taskService.UpdateProgressByID(taskID, i+len(chunk))
_ = s.taskService.UpdateProgressByID(ctx, taskID, i+len(chunk))
}
result := gin.H{
"restored_count": restoredCount,
"not_found_count": int64(len(keys)) - restoredCount,
}
s.taskService.EndTaskByID(taskID, resourceID, result, nil)
s.publishChangeEvent(0, "keys_bulk_restored") // Global event
s.taskService.EndTaskByID(ctx, taskID, resourceID, result, nil)
s.publishChangeEvent(ctx, 0, "keys_bulk_restored")
}
func (s *KeyImportService) publishSingleKeyChangeEvent(groupID, keyID uint, oldStatus, newStatus models.APIKeyStatus, reason string) {
func (s *KeyImportService) publishSingleKeyChangeEvent(ctx context.Context, groupID, keyID uint, oldStatus, newStatus models.APIKeyStatus, reason string) {
event := models.KeyStatusChangedEvent{
GroupID: groupID,
KeyID: keyID,
@@ -340,7 +325,7 @@ func (s *KeyImportService) publishSingleKeyChangeEvent(groupID, keyID uint, oldS
ChangedAt: time.Now(),
}
eventData, _ := json.Marshal(event)
if err := s.store.Publish(models.TopicKeyStatusChanged, eventData); err != nil {
if err := s.store.Publish(ctx, models.TopicKeyStatusChanged, eventData); err != nil {
s.logger.WithError(err).WithFields(logrus.Fields{
"group_id": groupID,
"key_id": keyID,
@@ -349,16 +334,16 @@ func (s *KeyImportService) publishSingleKeyChangeEvent(groupID, keyID uint, oldS
}
}
func (s *KeyImportService) publishChangeEvent(groupID uint, reason string) {
func (s *KeyImportService) publishChangeEvent(ctx context.Context, groupID uint, reason string) {
event := models.KeyStatusChangedEvent{
GroupID: groupID,
ChangeReason: reason,
}
eventData, _ := json.Marshal(event)
_ = s.store.Publish(models.TopicKeyStatusChanged, eventData)
_ = s.store.Publish(ctx, models.TopicKeyStatusChanged, eventData)
}
func (s *KeyImportService) publishImportGroupCompletedEvent(groupID uint, keyIDs []uint) {
func (s *KeyImportService) publishImportGroupCompletedEvent(ctx context.Context, groupID uint, keyIDs []uint) {
if len(keyIDs) == 0 {
return
}
@@ -372,17 +357,15 @@ func (s *KeyImportService) publishImportGroupCompletedEvent(groupID uint, keyIDs
s.logger.WithError(err).Error("Failed to marshal ImportGroupCompletedEvent.")
return
}
if err := s.store.Publish(models.TopicImportGroupCompleted, eventData); err != nil {
if err := s.store.Publish(ctx, models.TopicImportGroupCompleted, eventData); err != nil {
s.logger.WithError(err).Error("Failed to publish ImportGroupCompletedEvent.")
} else {
s.logger.Infof("Published ImportGroupCompletedEvent for group %d with %d keys.", groupID, len(keyIDs))
}
}
// [NEW] StartUnlinkKeysByFilterTask starts a task to unlink keys matching a status filter.
func (s *KeyImportService) StartUnlinkKeysByFilterTask(groupID uint, statuses []string) (*task.Status, error) {
func (s *KeyImportService) StartUnlinkKeysByFilterTask(ctx context.Context, groupID uint, statuses []string) (*task.Status, error) {
s.logger.Infof("Starting unlink task for group %d with status filter: %v", groupID, statuses)
// 1. [New] Find the keys to operate on.
keyValues, err := s.keyRepo.FindKeyValuesByStatus(groupID, statuses)
if err != nil {
return nil, fmt.Errorf("failed to find keys by filter: %w", err)
@@ -390,8 +373,7 @@ func (s *KeyImportService) StartUnlinkKeysByFilterTask(groupID uint, statuses []
if len(keyValues) == 0 {
return nil, fmt.Errorf("no keys found matching the provided filter")
}
// 2. [REUSE] Convert to text and call the existing, robust unlink task logic.
keysAsText := strings.Join(keyValues, "\n")
s.logger.Infof("Found %d keys to unlink for group %d.", len(keyValues), groupID)
return s.StartUnlinkKeysTask(groupID, keysAsText)
return s.StartUnlinkKeysTask(ctx, groupID, keysAsText)
}
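All four task runners above share the same chunk-and-report loop. Distilled into one helper for clarity (a sketch only; runChunked and processChunk are placeholders, not part of this commit):

func (s *KeyImportService) runChunked(ctx context.Context, taskID, resourceID string,
    ids []uint, processChunk func(context.Context, []uint) error) {
    for i := 0; i < len(ids); i += chunkSize {
        end := i + chunkSize
        if end > len(ids) {
            end = len(ids)
        }
        if err := processChunk(ctx, ids[i:end]); err != nil {
            // Any failed chunk aborts the task and records the error.
            s.taskService.EndTaskByID(ctx, taskID, resourceID, nil, err)
            return
        }
        _ = s.taskService.UpdateProgressByID(ctx, taskID, end)
    }
}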

View File

@@ -2,6 +2,7 @@
package service
import (
"context"
"encoding/json"
"fmt"
"gemini-balancer/internal/channel"
@@ -62,20 +63,18 @@ func (s *KeyValidationService) ValidateSingleKey(key *models.APIKey, timeout tim
return fmt.Errorf("failed to create request: %w", err)
}
s.channel.ModifyRequest(req, key) // Use the injected channel to modify the request
s.channel.ModifyRequest(req, key)
resp, err := client.Do(req)
if err != nil {
// This is a network-level error (e.g., timeout, DNS issue)
return fmt.Errorf("request failed: %w", err)
}
defer resp.Body.Close()
if resp.StatusCode == http.StatusOK {
return nil // Success
return nil
}
// Read the body for more error details
bodyBytes, readErr := io.ReadAll(resp.Body)
var errorMsg string
if readErr != nil {
@@ -84,7 +83,6 @@ func (s *KeyValidationService) ValidateSingleKey(key *models.APIKey, timeout tim
errorMsg = string(bodyBytes)
}
// This is a validation failure with a specific HTTP status code
return &CustomErrors.APIError{
HTTPStatus: resp.StatusCode,
Message: fmt.Sprintf("Validation failed with status %d: %s", resp.StatusCode, errorMsg),
@@ -92,8 +90,7 @@ func (s *KeyValidationService) ValidateSingleKey(key *models.APIKey, timeout tim
}
}
// --- Async task methods (fully adapted to the new task package) ---
func (s *KeyValidationService) StartTestKeysTask(groupID uint, keysText string) (*task.Status, error) {
func (s *KeyValidationService) StartTestKeysTask(ctx context.Context, groupID uint, keysText string) (*task.Status, error) {
keyStrings := utils.ParseKeysFromText(keysText)
if len(keyStrings) == 0 {
return nil, CustomErrors.NewAPIError(CustomErrors.ErrBadRequest, "no valid keys found in input text")
@@ -111,7 +108,6 @@ func (s *KeyValidationService) StartTestKeysTask(groupID uint, keysText string)
}
group, ok := s.groupManager.GetGroupByID(groupID)
if !ok {
// [FIX] Correctly use the NewAPIError constructor for a missing group.
return nil, CustomErrors.NewAPIError(CustomErrors.ErrGroupNotFound, fmt.Sprintf("group with id %d not found", groupID))
}
opConfig, err := s.groupManager.BuildOperationalConfig(group)
@@ -119,15 +115,15 @@ func (s *KeyValidationService) StartTestKeysTask(groupID uint, keysText string)
return nil, CustomErrors.NewAPIError(CustomErrors.ErrInternalServer, fmt.Sprintf("failed to build operational config: %v", err))
}
resourceID := fmt.Sprintf("group-%d", groupID)
taskStatus, err := s.taskService.StartTask(groupID, TaskTypeTestKeys, resourceID, len(apiKeyModels), time.Hour)
taskStatus, err := s.taskService.StartTask(ctx, groupID, TaskTypeTestKeys, resourceID, len(apiKeyModels), time.Hour)
if err != nil {
return nil, err // Pass up the error from task service (e.g., "task already running")
return nil, err
}
settings := s.SettingsManager.GetSettings()
timeout := time.Duration(settings.KeyCheckTimeoutSeconds) * time.Second
endpoint, err := s.groupManager.BuildKeyCheckEndpoint(groupID)
if err != nil {
s.taskService.EndTaskByID(taskStatus.ID, resourceID, nil, err) // End task with error if endpoint fails
s.taskService.EndTaskByID(ctx, taskStatus.ID, resourceID, nil, err)
return nil, err
}
var concurrency int
@@ -136,11 +132,11 @@ func (s *KeyValidationService) StartTestKeysTask(groupID uint, keysText string)
} else {
concurrency = settings.BaseKeyCheckConcurrency
}
go s.runTestKeysTask(taskStatus.ID, resourceID, groupID, apiKeyModels, timeout, endpoint, concurrency)
go s.runTestKeysTask(context.Background(), taskStatus.ID, resourceID, groupID, apiKeyModels, timeout, endpoint, concurrency)
return taskStatus, nil
}
func (s *KeyValidationService) runTestKeysTask(taskID string, resourceID string, groupID uint, keys []models.APIKey, timeout time.Duration, endpoint string, concurrency int) {
func (s *KeyValidationService) runTestKeysTask(ctx context.Context, taskID string, resourceID string, groupID uint, keys []models.APIKey, timeout time.Duration, endpoint string, concurrency int) {
var wg sync.WaitGroup
var mu sync.Mutex
finalResults := make([]models.KeyTestResult, len(keys))
@@ -165,7 +161,6 @@ func (s *KeyValidationService) runTestKeysTask(taskID string, resourceID string,
var currentResult models.KeyTestResult
event := models.RequestFinishedEvent{
RequestLog: models.RequestLog{
// GroupID and KeyID are pointers in the RequestLog model, so take their addresses.
GroupID: &groupID,
KeyID: &apiKeyModel.ID,
},
@@ -185,14 +180,15 @@ func (s *KeyValidationService) runTestKeysTask(taskID string, resourceID string,
event.RequestLog.IsSuccess = false
}
eventData, _ := json.Marshal(event)
if err := s.store.Publish(models.TopicRequestFinished, eventData); err != nil {
if err := s.store.Publish(ctx, models.TopicRequestFinished, eventData); err != nil {
s.logger.WithError(err).Errorf("Failed to publish RequestFinishedEvent for validation of key ID %d", apiKeyModel.ID)
}
mu.Lock()
finalResults[j.Index] = currentResult
processedCount++
_ = s.taskService.UpdateProgressByID(taskID, processedCount)
_ = s.taskService.UpdateProgressByID(ctx, taskID, processedCount)
mu.Unlock()
}
}()
@@ -202,10 +198,10 @@ func (s *KeyValidationService) runTestKeysTask(taskID string, resourceID string,
}
close(jobs)
wg.Wait()
s.taskService.EndTaskByID(taskID, resourceID, gin.H{"results": finalResults}, nil)
s.taskService.EndTaskByID(ctx, taskID, resourceID, gin.H{"results": finalResults}, nil)
}
func (s *KeyValidationService) StartTestKeysByFilterTask(groupID uint, statuses []string) (*task.Status, error) {
func (s *KeyValidationService) StartTestKeysByFilterTask(ctx context.Context, groupID uint, statuses []string) (*task.Status, error) {
s.logger.Infof("Starting test task for group %d with status filter: %v", groupID, statuses)
keyValues, err := s.keyRepo.FindKeyValuesByStatus(groupID, statuses)
if err != nil {
@@ -216,5 +212,5 @@ func (s *KeyValidationService) StartTestKeysByFilterTask(groupID uint, statuses
}
keysAsText := strings.Join(keyValues, "\n")
s.logger.Infof("Found %d keys to validate for group %d.", len(keyValues), groupID)
return s.StartTestKeysTask(groupID, keysAsText)
return s.StartTestKeysTask(ctx, groupID, keysAsText)
}
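Stripped of its result bookkeeping, runTestKeysTask is a standard bounded worker pool. Its skeleton, with handle standing in for the validate/publish/progress body:

jobs := make(chan int, len(keys))
var wg sync.WaitGroup
for w := 0; w < concurrency; w++ {
    wg.Add(1)
    go func() {
        defer wg.Done()
        for idx := range jobs {
            handle(idx) // validate keys[idx], publish RequestFinished, bump progress
        }
    }()
}
for i := range keys {
    jobs <- i
}
close(jobs)
wg.Wait()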

View File

@@ -3,6 +3,7 @@
package service
import (
"context"
"errors"
apperrors "gemini-balancer/internal/errors"
"gemini-balancer/internal/models"
@@ -43,7 +44,6 @@ func NewResourceService(
aks *APIKeyService,
logger *logrus.Logger,
) *ResourceService {
logger.Debugf("[FORENSIC PROBE | INJECTION | ResourceService] Received 'keyRepo' param. Fingerprint: %p", kr)
rs := &ResourceService{
settingsManager: sm,
groupManager: gm,
@@ -56,43 +56,40 @@ func NewResourceService(
go rs.preWarmCache(logger)
})
return rs
}
// --- [Mode 1: smart aggregation] ---
func (s *ResourceService) GetResourceFromBasePool(authToken *models.AuthToken, modelName string) (*RequestResources, error) {
func (s *ResourceService) GetResourceFromBasePool(ctx context.Context, authToken *models.AuthToken, modelName string) (*RequestResources, error) {
log := s.logger.WithFields(logrus.Fields{"token_id": authToken.ID, "model_name": modelName, "mode": "BasePool"})
log.Debug("Entering BasePool resource acquisition.")
// 1. Filter all eligible candidate groups and sort them by priority.
candidateGroups := s.filterAndSortCandidateGroups(modelName, authToken.AllowedGroups)
if len(candidateGroups) == 0 {
log.Warn("No candidate groups found for BasePool construction.")
return nil, apperrors.ErrNoKeysAvailable
}
// 2. Select one key from the BasePool according to the global polling strategy.
basePool := &repository.BasePool{
CandidateGroups: candidateGroups,
PollingStrategy: s.settingsManager.GetSettings().PollingStrategy,
}
apiKey, selectedGroup, err := s.keyRepo.SelectOneActiveKeyFromBasePool(basePool)
apiKey, selectedGroup, err := s.keyRepo.SelectOneActiveKeyFromBasePool(ctx, basePool)
if err != nil {
log.WithError(err).Warn("Failed to select a key from the BasePool.")
return nil, apperrors.ErrNoKeysAvailable
}
// 3. Assemble the final resources.
// [Key point] In this mode RequestConfig is always empty, to preserve transparency.
resources, err := s.assembleRequestResources(selectedGroup, apiKey)
if err != nil {
log.WithError(err).Error("Failed to assemble resources after selecting key from BasePool.")
return nil, err
}
resources.RequestConfig = &models.RequestConfig{} // force empty
resources.RequestConfig = &models.RequestConfig{}
log.Infof("Successfully selected KeyID %d from GroupID %d for the BasePool.", apiKey.ID, selectedGroup.ID)
return resources, nil
}
// --- [Mode 2: precise routing] ---
func (s *ResourceService) GetResourceFromGroup(authToken *models.AuthToken, groupName string) (*RequestResources, error) {
func (s *ResourceService) GetResourceFromGroup(ctx context.Context, authToken *models.AuthToken, groupName string) (*RequestResources, error) {
log := s.logger.WithFields(logrus.Fields{"token_id": authToken.ID, "group_name": groupName, "mode": "PreciseRoute"})
log.Debug("Entering PreciseRoute resource acquisition.")
@@ -101,12 +98,11 @@ func (s *ResourceService) GetResourceFromGroup(authToken *models.AuthToken, grou
if !ok {
return nil, apperrors.NewAPIError(apperrors.ErrGroupNotFound, "The specified group does not exist.")
}
if !s.isTokenAllowedForGroup(authToken, targetGroup.ID) {
return nil, apperrors.NewAPIError(apperrors.ErrPermissionDenied, "Token does not have permission to access this group.")
}
apiKey, _, err := s.keyRepo.SelectOneActiveKey(targetGroup)
apiKey, _, err := s.keyRepo.SelectOneActiveKey(ctx, targetGroup)
if err != nil {
log.WithError(err).Warn("Failed to select a key from the precisely targeted group.")
return nil, apperrors.ErrNoKeysAvailable
@@ -132,7 +128,6 @@ func (s *ResourceService) GetAllowedModelsForToken(authToken *models.AuthToken)
if authToken.IsAdmin {
for _, group := range allGroups {
for _, modelMapping := range group.AllowedModels {
allowedModelsSet[modelMapping.ModelName] = struct{}{}
}
}
@@ -144,7 +139,6 @@ func (s *ResourceService) GetAllowedModelsForToken(authToken *models.AuthToken)
for _, group := range allGroups {
if _, ok := allowedGroupIDs[group.ID]; ok {
for _, modelMapping := range group.AllowedModels {
allowedModelsSet[modelMapping.ModelName] = struct{}{}
}
}
@@ -164,14 +158,6 @@ func (s *ResourceService) assembleRequestResources(group *models.KeyGroup, apiKe
return nil, apperrors.NewAPIError(apperrors.ErrConfigurationError, "Selected group has no valid upstream and no global default is set.")
}
var proxyConfig *models.ProxyConfig
// [Note] The proxy logic needs a proxyModule instance, which is nil for now; the dependency must be re-injected later.
// if group.EnableProxy && s.proxyModule != nil {
// var err error
// proxyConfig, err = s.proxyModule.AssignProxyIfNeeded(apiKey)
// if err != nil {
// s.logger.WithError(err).Warnf("Failed to assign proxy for API key %d.", apiKey.ID)
// }
// }
return &RequestResources{
KeyGroup: group,
APIKey: apiKey,
@@ -194,7 +180,7 @@ func (s *ResourceService) selectUpstreamForGroup(group *models.KeyGroup) *models
func (s *ResourceService) preWarmCache(logger *logrus.Logger) error {
time.Sleep(2 * time.Second)
s.logger.Info("Performing initial key cache pre-warming...")
if err := s.keyRepo.LoadAllKeysToStore(); err != nil {
if err := s.keyRepo.LoadAllKeysToStore(context.Background()); err != nil {
logger.WithError(err).Error("Failed to perform initial key cache pre-warming.")
return err
}
@@ -209,7 +195,6 @@ func (s *ResourceService) GetResourcesForRequest(modelName string, allowedGroups
func (s *ResourceService) filterAndSortCandidateGroups(modelName string, allowedGroupsFromToken []*models.KeyGroup) []*models.KeyGroup {
allGroupsFromCache := s.groupManager.GetAllGroups()
var candidateGroups []*models.KeyGroup
// 1. Determine the permission scope.
allowedGroupIDs := make(map[uint]bool)
isTokenRestricted := len(allowedGroupsFromToken) > 0
if isTokenRestricted {
@@ -217,15 +202,12 @@ func (s *ResourceService) filterAndSortCandidateGroups(modelName string, allowed
allowedGroupIDs[ag.ID] = true
}
}
// 2. Filter.
for _, group := range allGroupsFromCache {
// Check token permission.
if isTokenRestricted && !allowedGroupIDs[group.ID] {
continue
}
// Check whether the model is allowed.
isModelAllowed := false
if len(group.AllowedModels) == 0 { // a group with no model restrictions allows all models
if len(group.AllowedModels) == 0 {
isModelAllowed = true
} else {
for _, m := range group.AllowedModels {
@@ -239,8 +221,6 @@ func (s *ResourceService) filterAndSortCandidateGroups(modelName string, allowed
candidateGroups = append(candidateGroups, group)
}
}
// 3. Sort ascending by the Order field.
sort.SliceStable(candidateGroups, func(i, j int) bool {
return candidateGroups[i].Order < candidateGroups[j].Order
})
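Both acquisition modes now take ctx first and collapse selection failures into apperrors.ErrNoKeysAvailable, so a caller can branch on a single sentinel. A hypothetical call site (the handler-side names are assumed, not shown in this commit):

var res *RequestResources
var err error
if groupName != "" {
    res, err = resourceService.GetResourceFromGroup(ctx, authToken, groupName)
} else {
    res, err = resourceService.GetResourceFromBasePool(ctx, authToken, modelName)
}
if errors.Is(err, apperrors.ErrNoKeysAvailable) {
    // no usable key in any permitted group
}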

View File

@@ -52,7 +52,7 @@ func (s *SecurityService) AuthenticateToken(tokenValue string) (*models.AuthToke
// IsIPBanned
func (s *SecurityService) IsIPBanned(ctx context.Context, ip string) (bool, error) {
banKey := fmt.Sprintf("banned_ip:%s", ip)
return s.store.Exists(banKey)
return s.store.Exists(ctx, banKey)
}
// RecordFailedLoginAttempt
@@ -61,7 +61,7 @@ func (s *SecurityService) RecordFailedLoginAttempt(ctx context.Context, ip strin
return nil
}
count, err := s.store.HIncrBy(loginAttemptsKey, ip, 1)
count, err := s.store.HIncrBy(ctx, loginAttemptsKey, ip, 1)
if err != nil {
return err
}
@@ -71,12 +71,12 @@ func (s *SecurityService) RecordFailedLoginAttempt(ctx context.Context, ip strin
banDuration := s.SettingsManager.GetIPBanDuration()
banKey := fmt.Sprintf("banned_ip:%s", ip)
if err := s.store.Set(banKey, []byte("1"), banDuration); err != nil {
if err := s.store.Set(ctx, banKey, []byte("1"), banDuration); err != nil {
return err
}
s.logger.Warnf("IP BANNED: IP [%s] has been banned for %v due to excessive failed login attempts.", ip, banDuration)
s.store.HDel(loginAttemptsKey, ip)
s.store.HDel(ctx, loginAttemptsKey, ip)
}
return nil
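With both methods context-aware, request-scoped wiring becomes straightforward. A hypothetical gin middleware sketch (not part of this commit; the gin and net/http usage here is assumed):

func IPBanGuard(sec *SecurityService) gin.HandlerFunc {
    return func(c *gin.Context) {
        banned, err := sec.IsIPBanned(c.Request.Context(), c.ClientIP())
        if err == nil && banned {
            c.AbortWithStatus(http.StatusForbidden)
            return
        }
        c.Next()
    }
}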

View File

@@ -2,6 +2,7 @@
package service
import (
"context"
"encoding/json"
"fmt"
"gemini-balancer/internal/models"
@@ -34,7 +35,7 @@ func NewStatsService(db *gorm.DB, s store.Store, repo repository.KeyRepository,
func (s *StatsService) Start() {
s.logger.Info("Starting event listener for stats maintenance.")
sub, err := s.store.Subscribe(models.TopicKeyStatusChanged)
sub, err := s.store.Subscribe(context.Background(), models.TopicKeyStatusChanged)
if err != nil {
s.logger.Fatalf("Failed to subscribe to topic %s: %v", models.TopicKeyStatusChanged, err)
return
@@ -67,42 +68,43 @@ func (s *StatsService) handleKeyStatusChange(event *models.KeyStatusChangedEvent
s.logger.Warnf("Received KeyStatusChangedEvent with no GroupID. Reason: %s, KeyID: %d. Skipping.", event.ChangeReason, event.KeyID)
return
}
ctx := context.Background()
statsKey := fmt.Sprintf("stats:group:%d", event.GroupID)
s.logger.Infof("Handling key status change for Group %d, KeyID: %d, Reason: %s", event.GroupID, event.KeyID, event.ChangeReason)
switch event.ChangeReason {
case "key_unlinked", "key_hard_deleted":
if event.OldStatus != "" {
s.store.HIncrBy(statsKey, "total_keys", -1)
s.store.HIncrBy(statsKey, fmt.Sprintf("%s_keys", event.OldStatus), -1)
s.store.HIncrBy(ctx, statsKey, "total_keys", -1)
s.store.HIncrBy(ctx, statsKey, fmt.Sprintf("%s_keys", event.OldStatus), -1)
} else {
s.logger.Warnf("Received '%s' event for group %d without OldStatus, forcing recalculation.", event.ChangeReason, event.GroupID)
s.RecalculateGroupKeyStats(event.GroupID)
s.RecalculateGroupKeyStats(ctx, event.GroupID)
}
case "key_linked":
if event.NewStatus != "" {
s.store.HIncrBy(statsKey, "total_keys", 1)
s.store.HIncrBy(statsKey, fmt.Sprintf("%s_keys", event.NewStatus), 1)
s.store.HIncrBy(ctx, statsKey, "total_keys", 1)
s.store.HIncrBy(ctx, statsKey, fmt.Sprintf("%s_keys", event.NewStatus), 1)
} else {
s.logger.Warnf("Received 'key_linked' event for group %d without NewStatus, forcing recalculation.", event.GroupID)
s.RecalculateGroupKeyStats(event.GroupID)
s.RecalculateGroupKeyStats(ctx, event.GroupID)
}
case "manual_update", "error_threshold_reached", "key_recovered", "invalid_api_key":
s.store.HIncrBy(statsKey, fmt.Sprintf("%s_keys", event.OldStatus), -1)
s.store.HIncrBy(statsKey, fmt.Sprintf("%s_keys", event.NewStatus), 1)
s.store.HIncrBy(ctx, statsKey, fmt.Sprintf("%s_keys", event.OldStatus), -1)
s.store.HIncrBy(ctx, statsKey, fmt.Sprintf("%s_keys", event.NewStatus), 1)
default:
s.logger.Warnf("Unhandled event reason '%s' for group %d, forcing recalculation.", event.ChangeReason, event.GroupID)
s.RecalculateGroupKeyStats(event.GroupID)
s.RecalculateGroupKeyStats(ctx, event.GroupID)
}
}
func (s *StatsService) RecalculateGroupKeyStats(groupID uint) error {
func (s *StatsService) RecalculateGroupKeyStats(ctx context.Context, groupID uint) error {
s.logger.Warnf("Performing full recalculation for group %d key stats.", groupID)
var results []struct {
Status models.APIKeyStatus
Count int64
}
if err := s.db.Model(&models.GroupAPIKeyMapping{}).
if err := s.db.WithContext(ctx).Model(&models.GroupAPIKeyMapping{}).
Where("key_group_id = ?", groupID).
Select("status, COUNT(*) as count").
Group("status").
@@ -119,37 +121,25 @@ func (s *StatsService) RecalculateGroupKeyStats(groupID uint) error {
}
updates["total_keys"] = totalKeys
if err := s.store.Del(statsKey); err != nil {
if err := s.store.Del(ctx, statsKey); err != nil {
s.logger.WithError(err).Warnf("Failed to delete stale stats key for group %d before recalculation.", groupID)
}
if err := s.store.HSet(statsKey, updates); err != nil {
if err := s.store.HSet(ctx, statsKey, updates); err != nil {
return fmt.Errorf("failed to HSet recalculated stats for group %d: %w", groupID, err)
}
s.logger.Infof("Successfully recalculated stats for group %d using HSet.", groupID)
return nil
}
func (s *StatsService) GetDashboardStats() (*models.DashboardStatsResponse, error) {
// TODO logic:
// 1. Fetch per-group key stats from Redis (HGetAll).
// 2. Fetch the last 24 hours of request counts and error rates from the stats_hourly table.
// 3. Combine them into a DashboardStatsResponse.
// ... the concrete implementation of this method can be done in DashboardQueryService;
// here we first make sure StatsService's core duty, cache maintenance, is complete.
// To keep the build compiling, return an empty object for now.
// Pseudocode:
// keyCounts, _ := s.store.HGetAll("stats:global:keys")
// ...
func (s *StatsService) GetDashboardStats(ctx context.Context) (*models.DashboardStatsResponse, error) {
return &models.DashboardStatsResponse{}, nil
}
func (s *StatsService) AggregateHourlyStats() error {
func (s *StatsService) AggregateHourlyStats(ctx context.Context) error {
s.logger.Info("Starting aggregation of the last hour's request data...")
now := time.Now()
endTime := now.Truncate(time.Hour) // e.g. 15:23 -> 15:00
startTime := endTime.Add(-1 * time.Hour) // 15:00 -> 14:00
endTime := now.Truncate(time.Hour)
startTime := endTime.Add(-1 * time.Hour)
s.logger.Infof("Aggregating data for time window: [%s, %s)", startTime.Format(time.RFC3339), endTime.Format(time.RFC3339))
type aggregationResult struct {
@@ -161,7 +151,8 @@ func (s *StatsService) AggregateHourlyStats() error {
CompletionTokens int64
}
var results []aggregationResult
err := s.db.Model(&models.RequestLog{}).
err := s.db.WithContext(ctx).Model(&models.RequestLog{}).
Select("group_id, model_name, COUNT(*) as request_count, SUM(CASE WHEN is_success = true THEN 1 ELSE 0 END) as success_count, SUM(prompt_tokens) as prompt_tokens, SUM(completion_tokens) as completion_tokens").
Where("request_time >= ? AND request_time < ?", startTime, endTime).
Group("group_id, model_name").
@@ -179,7 +170,7 @@ func (s *StatsService) AggregateHourlyStats() error {
var hourlyStats []models.StatsHourly
for _, res := range results {
hourlyStats = append(hourlyStats, models.StatsHourly{
Time: startTime, // every record is stamped with the start of that hour
Time: startTime,
GroupID: res.GroupID,
ModelName: res.ModelName,
RequestCount: res.RequestCount,
@@ -189,7 +180,7 @@ func (s *StatsService) AggregateHourlyStats() error {
})
}
return s.db.Clauses(clause.OnConflict{
return s.db.WithContext(ctx).Clauses(clause.OnConflict{
Columns: []clause.Column{{Name: "time"}, {Name: "group_id"}, {Name: "model_name"}},
DoUpdates: clause.AssignmentColumns([]string{"request_count", "success_count", "prompt_tokens", "completion_tokens"}),
}).Create(&hourlyStats).Error
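AggregateHourlyStats now accepts a context, so whatever schedules it can bound each run. A hypothetical hourly trigger (the scheduler itself lies outside this diff):

go func() {
    ticker := time.NewTicker(time.Hour)
    defer ticker.Stop()
    for range ticker.C {
        ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute)
        if err := statsService.AggregateHourlyStats(ctx); err != nil {
            logger.WithError(err).Error("Hourly stats aggregation failed.")
        }
        cancel()
    }
}()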

View File

@@ -1,3 +1,4 @@
// Filename: internal/store/factory.go
package store
import (
@@ -11,7 +12,6 @@ import (
// NewStore creates a new store based on the application configuration.
func NewStore(cfg *config.Config, logger *logrus.Logger) (Store, error) {
// Check whether Redis is configured.
if cfg.Redis.DSN != "" {
opts, err := redis.ParseURL(cfg.Redis.DSN)
if err != nil {
@@ -20,10 +20,10 @@ func NewStore(cfg *config.Config, logger *logrus.Logger) (Store, error) {
client := redis.NewClient(opts)
if err := client.Ping(context.Background()).Err(); err != nil {
logger.WithError(err).Warnf("WARN: Failed to connect to Redis (%s), falling back to in-memory store. Error: %v", cfg.Redis.DSN, err)
return NewMemoryStore(logger), nil // on connection failure, fall back to the in-memory store without returning an error
return NewMemoryStore(logger), nil
}
logger.Info("Successfully connected to Redis. Using Redis as store.")
return NewRedisStore(client), nil
return NewRedisStore(client, logger), nil
}
logger.Info("INFO: Redis DSN not configured, falling back to in-memory store.")
return NewMemoryStore(logger), nil
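The factory returns the Store interface, whose context-first surface can be inferred from the memory implementation below. An assumed excerpt (the real definition in internal/store may list more methods; pub/sub and pipeline types are elided):

type Store interface {
    Set(ctx context.Context, key string, value []byte, ttl time.Duration) error
    Get(ctx context.Context, key string) ([]byte, error)
    Del(ctx context.Context, keys ...string) error
    Exists(ctx context.Context, key string) (bool, error)
    SetNX(ctx context.Context, key string, value []byte, ttl time.Duration) (bool, error)
    HSet(ctx context.Context, key string, values map[string]any) error
    HGetAll(ctx context.Context, key string) (map[string]string, error)
    HIncrBy(ctx context.Context, key, field string, incr int64) (int64, error)
    Publish(ctx context.Context, topic string, payload []byte) error
    Close() error
    // ... plus the list/set/zset methods, Rotate, Pipeline and Subscribe seen below.
}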

View File

@@ -1,8 +1,9 @@
// Filename: internal/store/memory_store.go (final version after peer review)
// Filename: internal/store/memory_store.go
package store
import (
"context"
"fmt"
"math/rand"
"sort"
@@ -12,6 +13,7 @@ import (
"github.com/sirupsen/logrus"
)
// ensure memoryStore implements Store interface
var _ Store = (*memoryStore)(nil)
type memoryStoreItem struct {
@@ -32,7 +34,6 @@ type memoryStore struct {
items map[string]*memoryStoreItem
pubsub map[string][]chan *Message
mu sync.RWMutex
// [USER SUGGESTION APPLIED] Use a lock-guarded random source to guarantee concurrency safety.
rng *rand.Rand
rngMu sync.Mutex
logger *logrus.Entry
@@ -42,7 +43,6 @@ func NewMemoryStore(logger *logrus.Logger) Store {
store := &memoryStore{
items: make(map[string]*memoryStoreItem),
pubsub: make(map[string][]chan *Message),
// Create a new random source seeded with the current time.
rng: rand.New(rand.NewSource(time.Now().UnixNano())),
logger: logger.WithField("component", "store.memory 🗱"),
}
@@ -50,13 +50,12 @@ func NewMemoryStore(logger *logrus.Logger) Store {
return store
}
// [USER SUGGESTION INCORPORATED] Fix #1: use now := time.Now() for an atomic check.
func (s *memoryStore) startGCollector() {
ticker := time.NewTicker(5 * time.Minute)
defer ticker.Stop()
for range ticker.C {
s.mu.Lock()
now := time.Now() // avoid repeated calls inside the loop
now := time.Now()
for key, item := range s.items {
if !item.expireAt.IsZero() && now.After(item.expireAt) {
delete(s.items, key)
@@ -66,92 +65,10 @@ func (s *memoryStore) startGCollector() {
}
}
// [USER SUGGESTION INCORPORATED] Fixes #2 & #3: fixed the fatal nil-check and type-assertion bugs.
func (s *memoryStore) PopAndCycleSetMember(mainKey, cooldownKey string) (string, error) {
s.mu.Lock()
defer s.mu.Unlock()
// --- [Architecture fix] Every method signature gains a context.Context parameter to match the interface ---
// --- The in-memory implementation may ignore it, receiving it as _ ---
mainItem, mainOk := s.items[mainKey]
var mainSet map[string]struct{}
if mainOk && !mainItem.isExpired() {
// Type-assert safely.
mainSet, mainOk = mainItem.value.(map[string]struct{})
// Make sure the assertion succeeded and the set is non-empty.
mainOk = mainOk && len(mainSet) > 0
} else {
mainOk = false
}
if !mainOk {
cooldownItem, cooldownOk := s.items[cooldownKey]
if !cooldownOk || cooldownItem.isExpired() {
return "", ErrNotFound
}
// Type-assert safely.
cooldownSet, cooldownSetOk := cooldownItem.value.(map[string]struct{})
if !cooldownSetOk || len(cooldownSet) == 0 {
return "", ErrNotFound
}
s.items[mainKey] = cooldownItem
delete(s.items, cooldownKey)
mainSet = cooldownSet
}
var popped string
for k := range mainSet {
popped = k
break
}
delete(mainSet, popped)
cooldownItem, cooldownOk := s.items[cooldownKey]
if !cooldownOk || cooldownItem.isExpired() {
cooldownItem = &memoryStoreItem{value: make(map[string]struct{})}
s.items[cooldownKey] = cooldownItem
}
// Handle the cooldown pool safely.
cooldownSet, ok := cooldownItem.value.(map[string]struct{})
if !ok {
cooldownSet = make(map[string]struct{})
cooldownItem.value = cooldownSet
}
cooldownSet[popped] = struct{}{}
return popped, nil
}
// SRandMember [concurrency-fixed version] uses the lock-guarded rng.
func (s *memoryStore) SRandMember(key string) (string, error) {
s.mu.RLock()
defer s.mu.RUnlock()
item, ok := s.items[key]
if !ok || item.isExpired() {
return "", ErrNotFound
}
set, ok := item.value.(map[string]struct{})
if !ok || len(set) == 0 {
return "", ErrNotFound
}
members := make([]string, 0, len(set))
for member := range set {
members = append(members, member)
}
if len(members) == 0 {
return "", ErrNotFound
}
s.rngMu.Lock()
n := s.rng.Intn(len(members))
s.rngMu.Unlock()
return members[n], nil
}
// --- Final versions of the remaining functions; all follow a safe, atomic locking strategy ---
func (s *memoryStore) Set(key string, value []byte, ttl time.Duration) error {
func (s *memoryStore) Set(_ context.Context, key string, value []byte, ttl time.Duration) error {
s.mu.Lock()
defer s.mu.Unlock()
var expireAt time.Time
@@ -162,7 +79,7 @@ func (s *memoryStore) Set(key string, value []byte, ttl time.Duration) error {
return nil
}
func (s *memoryStore) Get(key string) ([]byte, error) {
func (s *memoryStore) Get(_ context.Context, key string) ([]byte, error) {
s.mu.RLock()
defer s.mu.RUnlock()
item, ok := s.items[key]
@@ -175,7 +92,7 @@ func (s *memoryStore) Get(key string) ([]byte, error) {
return nil, ErrNotFound
}
func (s *memoryStore) Del(keys ...string) error {
func (s *memoryStore) Del(_ context.Context, keys ...string) error {
s.mu.Lock()
defer s.mu.Unlock()
for _, key := range keys {
@@ -184,14 +101,14 @@ func (s *memoryStore) Del(keys ...string) error {
return nil
}
func (s *memoryStore) Exists(key string) (bool, error) {
func (s *memoryStore) Exists(_ context.Context, key string) (bool, error) {
s.mu.RLock()
defer s.mu.RUnlock()
item, ok := s.items[key]
return ok && !item.isExpired(), nil
}
func (s *memoryStore) SetNX(key string, value []byte, ttl time.Duration) (bool, error) {
func (s *memoryStore) SetNX(_ context.Context, key string, value []byte, ttl time.Duration) (bool, error) {
s.mu.Lock()
defer s.mu.Unlock()
item, ok := s.items[key]
@@ -208,7 +125,7 @@ func (s *memoryStore) SetNX(key string, value []byte, ttl time.Duration) (bool,
func (s *memoryStore) Close() error { return nil }
func (s *memoryStore) HDel(key string, fields ...string) error {
func (s *memoryStore) HDel(_ context.Context, key string, fields ...string) error {
s.mu.Lock()
defer s.mu.Unlock()
item, ok := s.items[key]
@@ -223,7 +140,7 @@ func (s *memoryStore) HDel(key string, fields ...string) error {
return nil
}
func (s *memoryStore) HSet(key string, values map[string]any) error {
func (s *memoryStore) HSet(_ context.Context, key string, values map[string]any) error {
s.mu.Lock()
defer s.mu.Unlock()
item, ok := s.items[key]
@@ -242,7 +159,7 @@ func (s *memoryStore) HSet(key string, values map[string]any) error {
return nil
}
func (s *memoryStore) HGetAll(key string) (map[string]string, error) {
func (s *memoryStore) HGetAll(_ context.Context, key string) (map[string]string, error) {
s.mu.RLock()
defer s.mu.RUnlock()
item, ok := s.items[key]
@@ -259,7 +176,7 @@ func (s *memoryStore) HGetAll(key string) (map[string]string, error) {
return make(map[string]string), nil
}
func (s *memoryStore) HIncrBy(key, field string, incr int64) (int64, error) {
func (s *memoryStore) HIncrBy(_ context.Context, key, field string, incr int64) (int64, error) {
s.mu.Lock()
defer s.mu.Unlock()
item, ok := s.items[key]
@@ -281,7 +198,7 @@ func (s *memoryStore) HIncrBy(key, field string, incr int64) (int64, error) {
return newVal, nil
}
func (s *memoryStore) LPush(key string, values ...any) error {
func (s *memoryStore) LPush(_ context.Context, key string, values ...any) error {
s.mu.Lock()
defer s.mu.Unlock()
item, ok := s.items[key]
@@ -301,7 +218,7 @@ func (s *memoryStore) LPush(key string, values ...any) error {
return nil
}
func (s *memoryStore) LRem(key string, count int64, value any) error {
func (s *memoryStore) LRem(_ context.Context, key string, count int64, value any) error {
s.mu.Lock()
defer s.mu.Unlock()
item, ok := s.items[key]
@@ -326,7 +243,7 @@ func (s *memoryStore) LRem(key string, count int64, value any) error {
return nil
}
func (s *memoryStore) SAdd(key string, members ...any) error {
func (s *memoryStore) SAdd(_ context.Context, key string, members ...any) error {
s.mu.Lock()
defer s.mu.Unlock()
item, ok := s.items[key]
@@ -345,7 +262,7 @@ func (s *memoryStore) SAdd(key string, members ...any) error {
return nil
}
func (s *memoryStore) SPopN(key string, count int64) ([]string, error) {
func (s *memoryStore) SPopN(_ context.Context, key string, count int64) ([]string, error) {
s.mu.Lock()
defer s.mu.Unlock()
item, ok := s.items[key]
@@ -375,7 +292,7 @@ func (s *memoryStore) SPopN(key string, count int64) ([]string, error) {
return popped, nil
}
func (s *memoryStore) SMembers(key string) ([]string, error) {
func (s *memoryStore) SMembers(_ context.Context, key string) ([]string, error) {
s.mu.RLock()
defer s.mu.RUnlock()
item, ok := s.items[key]
@@ -393,7 +310,7 @@ func (s *memoryStore) SMembers(key string) ([]string, error) {
return members, nil
}
func (s *memoryStore) SRem(key string, members ...any) error {
func (s *memoryStore) SRem(_ context.Context, key string, members ...any) error {
s.mu.Lock()
defer s.mu.Unlock()
item, ok := s.items[key]
@@ -410,7 +327,31 @@ func (s *memoryStore) SRem(key string, members ...any) error {
return nil
}
func (s *memoryStore) Rotate(key string) (string, error) {
func (s *memoryStore) SRandMember(_ context.Context, key string) (string, error) {
s.mu.RLock()
defer s.mu.RUnlock()
item, ok := s.items[key]
if !ok || item.isExpired() {
return "", ErrNotFound
}
set, ok := item.value.(map[string]struct{})
if !ok || len(set) == 0 {
return "", ErrNotFound
}
members := make([]string, 0, len(set))
for member := range set {
members = append(members, member)
}
if len(members) == 0 {
return "", ErrNotFound
}
s.rngMu.Lock()
n := s.rng.Intn(len(members))
s.rngMu.Unlock()
return members[n], nil
}
func (s *memoryStore) Rotate(_ context.Context, key string) (string, error) {
s.mu.Lock()
defer s.mu.Unlock()
item, ok := s.items[key]
@@ -426,7 +367,7 @@ func (s *memoryStore) Rotate(key string) (string, error) {
return val, nil
}
func (s *memoryStore) LIndex(key string, index int64) (string, error) {
func (s *memoryStore) LIndex(_ context.Context, key string, index int64) (string, error) {
s.mu.RLock()
defer s.mu.RUnlock()
item, ok := s.items[key]
@@ -447,8 +388,7 @@ func (s *memoryStore) LIndex(key string, index int64) (string, error) {
return list[index], nil
}
// ZSet methods (ZAdd, ZRange, ZRem)
func (s *memoryStore) ZAdd(key string, members map[string]float64) error {
func (s *memoryStore) ZAdd(_ context.Context, key string, members map[string]float64) error {
s.mu.Lock()
defer s.mu.Unlock()
item, ok := s.items[key]
@@ -471,8 +411,6 @@ func (s *memoryStore) ZAdd(key string, members map[string]float64) error {
for val, score := range membersMap {
newZSet = append(newZSet, zsetMember{Value: val, Score: score})
}
// NOTE: This ZSet implementation is simple but not performant for large sets.
// A production implementation would use a skip list or a balanced tree.
sort.Slice(newZSet, func(i, j int) bool {
if newZSet[i].Score == newZSet[j].Score {
return newZSet[i].Value < newZSet[j].Value
@@ -482,7 +420,7 @@ func (s *memoryStore) ZAdd(key string, members map[string]float64) error {
item.value = newZSet
return nil
}
func (s *memoryStore) ZRange(key string, start, stop int64) ([]string, error) {
func (s *memoryStore) ZRange(_ context.Context, key string, start, stop int64) ([]string, error) {
s.mu.RLock()
defer s.mu.RUnlock()
item, ok := s.items[key]
@@ -515,7 +453,7 @@ func (s *memoryStore) ZRange(key string, start, stop int64) ([]string, error) {
}
return result, nil
}
func (s *memoryStore) ZRem(key string, members ...any) error {
func (s *memoryStore) ZRem(_ context.Context, key string, members ...any) error {
s.mu.Lock()
defer s.mu.Unlock()
item, ok := s.items[key]
@@ -540,13 +478,56 @@ func (s *memoryStore) ZRem(key string, members ...any) error {
return nil
}
// Pipeline implementation
func (s *memoryStore) PopAndCycleSetMember(_ context.Context, mainKey, cooldownKey string) (string, error) {
s.mu.Lock()
defer s.mu.Unlock()
mainItem, mainOk := s.items[mainKey]
var mainSet map[string]struct{}
if mainOk && !mainItem.isExpired() {
mainSet, mainOk = mainItem.value.(map[string]struct{})
mainOk = mainOk && len(mainSet) > 0
} else {
mainOk = false
}
if !mainOk {
cooldownItem, cooldownOk := s.items[cooldownKey]
if !cooldownOk || cooldownItem.isExpired() {
return "", ErrNotFound
}
cooldownSet, cooldownSetOk := cooldownItem.value.(map[string]struct{})
if !cooldownSetOk || len(cooldownSet) == 0 {
return "", ErrNotFound
}
s.items[mainKey] = cooldownItem
delete(s.items, cooldownKey)
mainSet = cooldownSet
}
var popped string
for k := range mainSet {
popped = k
break
}
delete(mainSet, popped)
cooldownItem, cooldownOk := s.items[cooldownKey]
if !cooldownOk || cooldownItem.isExpired() {
cooldownItem = &memoryStoreItem{value: make(map[string]struct{})}
s.items[cooldownKey] = cooldownItem
}
cooldownSet, ok := cooldownItem.value.(map[string]struct{})
if !ok {
cooldownSet = make(map[string]struct{})
cooldownItem.value = cooldownSet
}
cooldownSet[popped] = struct{}{}
return popped, nil
}
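Taken together, the two sets implement cooldown-based rotation: each call pops one member from the main set and parks it in the cooldown set, and once the main set drains, the cooldown set is promoted back wholesale, restarting the cycle. A hedged usage sketch (key names and the usePicked helper are illustrative):
// Sketch: one rotation step; over N calls every member is visited once.
member, err := st.PopAndCycleSetMember(ctx, "proxies:ready", "proxies:cooldown")
switch {
case errors.Is(err, store.ErrNotFound):
	// Both sets are empty: nothing registered for rotation.
case err != nil:
	// Unexpected store error.
default:
	usePicked(member)
}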
type memoryPipeliner struct {
store *memoryStore
ops []func()
}
func (s *memoryStore) Pipeline() Pipeliner {
func (s *memoryStore) Pipeline(_ context.Context) Pipeliner {
return &memoryPipeliner{store: s}
}
func (p *memoryPipeliner) Exec() error {
@@ -559,7 +540,6 @@ func (p *memoryPipeliner) Exec() error {
}
func (p *memoryPipeliner) Expire(key string, expiration time.Duration) {
// [USER SUGGESTION APPLIED] Fix #4: Capture value, not reference
capturedKey := key
p.ops = append(p.ops, func() {
if item, ok := p.store.items[capturedKey]; ok {
@@ -596,7 +576,6 @@ func (p *memoryPipeliner) SAdd(key string, members ...any) {
}
})
}
func (p *memoryPipeliner) SRem(key string, members ...any) {
capturedKey := key
capturedMembers := make([]any, len(members))
@@ -615,7 +594,6 @@ func (p *memoryPipeliner) SRem(key string, members ...any) {
}
})
}
func (p *memoryPipeliner) LPush(key string, values ...any) {
capturedKey := key
capturedValues := make([]any, len(values))
@@ -637,11 +615,12 @@ func (p *memoryPipeliner) LPush(key string, values ...any) {
item.value = append(stringValues, list...)
})
}
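The captured copies in these ops follow the "capture value, not reference" fix noted above: a variadic call shares its backing array with the caller, so queuing the raw slice would let mutations made before Exec leak into the deferred op. A minimal illustration, assuming a pipeliner p obtained from Pipeline(ctx):
buf := []any{"a"}
p.SAdd("set:demo", buf...) // the op snapshots "a" at enqueue time
buf[0] = "b"               // caller reuses the buffer before Exec...
_ = p.Exec()               // ...but the queued op still adds "a"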
func (p *memoryPipeliner) LRem(key string, count int64, value any) {}
func (p *memoryPipeliner) HSet(key string, values map[string]any) {}
func (p *memoryPipeliner) HIncrBy(key, field string, incr int64) {}
func (p *memoryPipeliner) LRem(key string, count int64, value any) {}
func (p *memoryPipeliner) HSet(key string, values map[string]any) {}
func (p *memoryPipeliner) HIncrBy(key, field string, incr int64) {}
func (p *memoryPipeliner) ZAdd(key string, members map[string]float64) {}
func (p *memoryPipeliner) ZRem(key string, members ...any) {}
// --- Pub/Sub implementation (remains unchanged) ---
type memorySubscription struct {
store *memoryStore
channelName string
@@ -649,10 +628,11 @@ type memorySubscription struct {
}
func (ms *memorySubscription) Channel() <-chan *Message { return ms.msgChan }
func (ms *memorySubscription) ChannelName() string { return ms.channelName }
func (ms *memorySubscription) Close() error {
return ms.store.removeSubscriber(ms.channelName, ms.msgChan)
}
func (s *memoryStore) Publish(channel string, message []byte) error {
func (s *memoryStore) Publish(_ context.Context, channel string, message []byte) error {
s.mu.RLock()
defer s.mu.RUnlock()
subscribers, ok := s.pubsub[channel]
@@ -669,7 +649,7 @@ func (s *memoryStore) Publish(channel string, message []byte) error {
}
return nil
}
func (s *memoryStore) Subscribe(channel string) (Subscription, error) {
func (s *memoryStore) Subscribe(_ context.Context, channel string) (Subscription, error) {
s.mu.Lock()
defer s.mu.Unlock()
msgChan := make(chan *Message, 10)

View File

@@ -1,3 +1,5 @@
// Filename: internal/store/redis_store.go
package store
import (
@@ -8,22 +10,20 @@ import (
"time"
"github.com/redis/go-redis/v9"
"github.com/sirupsen/logrus"
)
// ensure RedisStore implements Store interface
var _ Store = (*RedisStore)(nil)
// RedisStore is a Redis-backed key-value store.
type RedisStore struct {
client *redis.Client
popAndCycleScript *redis.Script
logger *logrus.Entry
}
// NewRedisStore creates a new RedisStore instance.
func NewRedisStore(client *redis.Client) Store {
// Lua script for atomic pop-and-cycle operation.
// KEYS[1]: main set key
// KEYS[2]: cooldown set key
func NewRedisStore(client *redis.Client, logger *logrus.Logger) Store {
const script = `
if redis.call('SCARD', KEYS[1]) == 0 then
if redis.call('SCARD', KEYS[2]) == 0 then
@@ -36,15 +36,16 @@ func NewRedisStore(client *redis.Client) Store {
return &RedisStore{
client: client,
popAndCycleScript: redis.NewScript(script),
logger: logger.WithField("component", "store.redis 🗄️"),
}
}
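For orientation, a hedged wiring sketch of the new constructor signature, which now also tags log entries with a component field (the address and option values are placeholders, not taken from this repo):
client := redis.NewClient(&redis.Options{Addr: "localhost:6379"})
logger := logrus.New()
var st store.Store = store.NewRedisStore(client, logger)
defer st.Close()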
func (s *RedisStore) Set(key string, value []byte, ttl time.Duration) error {
return s.client.Set(context.Background(), key, value, ttl).Err()
func (s *RedisStore) Set(ctx context.Context, key string, value []byte, ttl time.Duration) error {
return s.client.Set(ctx, key, value, ttl).Err()
}
func (s *RedisStore) Get(key string) ([]byte, error) {
val, err := s.client.Get(context.Background(), key).Bytes()
func (s *RedisStore) Get(ctx context.Context, key string) ([]byte, error) {
val, err := s.client.Get(ctx, key).Bytes()
if err != nil {
if errors.Is(err, redis.Nil) {
return nil, ErrNotFound
@@ -54,53 +55,53 @@ func (s *RedisStore) Get(key string) ([]byte, error) {
return val, nil
}
func (s *RedisStore) Del(keys ...string) error {
func (s *RedisStore) Del(ctx context.Context, keys ...string) error {
if len(keys) == 0 {
return nil
}
return s.client.Del(context.Background(), keys...).Err()
return s.client.Del(ctx, keys...).Err()
}
func (s *RedisStore) Exists(key string) (bool, error) {
val, err := s.client.Exists(context.Background(), key).Result()
func (s *RedisStore) Exists(ctx context.Context, key string) (bool, error) {
val, err := s.client.Exists(ctx, key).Result()
return val > 0, err
}
func (s *RedisStore) SetNX(key string, value []byte, ttl time.Duration) (bool, error) {
return s.client.SetNX(context.Background(), key, value, ttl).Result()
func (s *RedisStore) SetNX(ctx context.Context, key string, value []byte, ttl time.Duration) (bool, error) {
return s.client.SetNX(ctx, key, value, ttl).Result()
}
func (s *RedisStore) Close() error {
return s.client.Close()
}
func (s *RedisStore) HSet(key string, values map[string]any) error {
return s.client.HSet(context.Background(), key, values).Err()
func (s *RedisStore) HSet(ctx context.Context, key string, values map[string]any) error {
return s.client.HSet(ctx, key, values).Err()
}
func (s *RedisStore) HGetAll(key string) (map[string]string, error) {
return s.client.HGetAll(context.Background(), key).Result()
func (s *RedisStore) HGetAll(ctx context.Context, key string) (map[string]string, error) {
return s.client.HGetAll(ctx, key).Result()
}
func (s *RedisStore) HIncrBy(key, field string, incr int64) (int64, error) {
return s.client.HIncrBy(context.Background(), key, field, incr).Result()
func (s *RedisStore) HIncrBy(ctx context.Context, key, field string, incr int64) (int64, error) {
return s.client.HIncrBy(ctx, key, field, incr).Result()
}
func (s *RedisStore) HDel(key string, fields ...string) error {
func (s *RedisStore) HDel(ctx context.Context, key string, fields ...string) error {
if len(fields) == 0 {
return nil
}
return s.client.HDel(context.Background(), key, fields...).Err()
return s.client.HDel(ctx, key, fields...).Err()
}
func (s *RedisStore) LPush(key string, values ...any) error {
return s.client.LPush(context.Background(), key, values...).Err()
func (s *RedisStore) LPush(ctx context.Context, key string, values ...any) error {
return s.client.LPush(ctx, key, values...).Err()
}
func (s *RedisStore) LRem(key string, count int64, value any) error {
return s.client.LRem(context.Background(), key, count, value).Err()
func (s *RedisStore) LRem(ctx context.Context, key string, count int64, value any) error {
return s.client.LRem(ctx, key, count, value).Err()
}
func (s *RedisStore) Rotate(key string) (string, error) {
val, err := s.client.RPopLPush(context.Background(), key, key).Result()
func (s *RedisStore) Rotate(ctx context.Context, key string) (string, error) {
val, err := s.client.RPopLPush(ctx, key, key).Result()
if err != nil {
if errors.Is(err, redis.Nil) {
return "", ErrNotFound
@@ -110,29 +111,28 @@ func (s *RedisStore) Rotate(key string) (string, error) {
return val, nil
}
func (s *RedisStore) SAdd(key string, members ...any) error {
return s.client.SAdd(context.Background(), key, members...).Err()
func (s *RedisStore) SAdd(ctx context.Context, key string, members ...any) error {
return s.client.SAdd(ctx, key, members...).Err()
}
func (s *RedisStore) SPopN(key string, count int64) ([]string, error) {
return s.client.SPopN(context.Background(), key, count).Result()
func (s *RedisStore) SPopN(ctx context.Context, key string, count int64) ([]string, error) {
return s.client.SPopN(ctx, key, count).Result()
}
func (s *RedisStore) SMembers(key string) ([]string, error) {
return s.client.SMembers(context.Background(), key).Result()
func (s *RedisStore) SMembers(ctx context.Context, key string) ([]string, error) {
return s.client.SMembers(ctx, key).Result()
}
func (s *RedisStore) SRem(key string, members ...any) error {
func (s *RedisStore) SRem(ctx context.Context, key string, members ...any) error {
if len(members) == 0 {
return nil
}
return s.client.SRem(context.Background(), key, members...).Err()
return s.client.SRem(ctx, key, members...).Err()
}
func (s *RedisStore) SRandMember(key string) (string, error) {
member, err := s.client.SRandMember(context.Background(), key).Result()
func (s *RedisStore) SRandMember(ctx context.Context, key string) (string, error) {
member, err := s.client.SRandMember(ctx, key).Result()
if err != nil {
if errors.Is(err, redis.Nil) {
return "", ErrNotFound
}
@@ -141,81 +141,43 @@ func (s *RedisStore) SRandMember(key string) (string, error) {
return member, nil
}
// === New method implementations ===
func (s *RedisStore) ZAdd(key string, members map[string]float64) error {
func (s *RedisStore) ZAdd(ctx context.Context, key string, members map[string]float64) error {
if len(members) == 0 {
return nil
}
redisMembers := make([]redis.Z, 0, len(members))
redisMembers := make([]redis.Z, len(members))
i := 0
for member, score := range members {
redisMembers = append(redisMembers, redis.Z{Score: score, Member: member})
redisMembers[i] = redis.Z{Score: score, Member: member}
i++
}
return s.client.ZAdd(context.Background(), key, redisMembers...).Err()
return s.client.ZAdd(ctx, key, redisMembers...).Err()
}
func (s *RedisStore) ZRange(key string, start, stop int64) ([]string, error) {
return s.client.ZRange(context.Background(), key, start, stop).Result()
func (s *RedisStore) ZRange(ctx context.Context, key string, start, stop int64) ([]string, error) {
return s.client.ZRange(ctx, key, start, stop).Result()
}
func (s *RedisStore) ZRem(key string, members ...any) error {
func (s *RedisStore) ZRem(ctx context.Context, key string, members ...any) error {
if len(members) == 0 {
return nil
}
return s.client.ZRem(context.Background(), key, members...).Err()
return s.client.ZRem(ctx, key, members...).Err()
}
func (s *RedisStore) PopAndCycleSetMember(mainKey, cooldownKey string) (string, error) {
val, err := s.popAndCycleScript.Run(context.Background(), s.client, []string{mainKey, cooldownKey}).Result()
func (s *RedisStore) PopAndCycleSetMember(ctx context.Context, mainKey, cooldownKey string) (string, error) {
val, err := s.popAndCycleScript.Run(ctx, s.client, []string{mainKey, cooldownKey}).Result()
if err != nil {
if errors.Is(err, redis.Nil) {
return "", ErrNotFound
}
return "", err
}
// Lua script returns a string, so we need to type assert
if str, ok := val.(string); ok {
return str, nil
}
return "", ErrNotFound // This happens if both sets were empty and the script returned nil
return "", ErrNotFound
}
type redisPipeliner struct{ pipe redis.Pipeliner }
func (p *redisPipeliner) HSet(key string, values map[string]any) {
p.pipe.HSet(context.Background(), key, values)
}
func (p *redisPipeliner) HIncrBy(key, field string, incr int64) {
p.pipe.HIncrBy(context.Background(), key, field, incr)
}
func (p *redisPipeliner) Exec() error {
_, err := p.pipe.Exec(context.Background())
return err
}
func (p *redisPipeliner) Del(keys ...string) {
if len(keys) > 0 {
p.pipe.Del(context.Background(), keys...)
}
}
func (p *redisPipeliner) SAdd(key string, members ...any) {
p.pipe.SAdd(context.Background(), key, members...)
}
func (p *redisPipeliner) SRem(key string, members ...any) {
if len(members) > 0 {
p.pipe.SRem(context.Background(), key, members...)
}
}
func (p *redisPipeliner) LPush(key string, values ...any) {
p.pipe.LPush(context.Background(), key, values...)
}
func (p *redisPipeliner) LRem(key string, count int64, value any) {
p.pipe.LRem(context.Background(), key, count, value)
}
func (s *RedisStore) LIndex(key string, index int64) (string, error) {
val, err := s.client.LIndex(context.Background(), key, index).Result()
func (s *RedisStore) LIndex(ctx context.Context, key string, index int64) (string, error) {
val, err := s.client.LIndex(ctx, key, index).Result()
if err != nil {
if errors.Is(err, redis.Nil) {
return "", ErrNotFound
@@ -225,47 +187,120 @@ func (s *RedisStore) LIndex(key string, index int64) (string, error) {
return val, nil
}
func (p *redisPipeliner) Expire(key string, expiration time.Duration) {
p.pipe.Expire(context.Background(), key, expiration)
type redisPipeliner struct {
pipe redis.Pipeliner
ctx context.Context
}
func (s *RedisStore) Pipeline() Pipeliner {
return &redisPipeliner{pipe: s.client.Pipeline()}
func (s *RedisStore) Pipeline(ctx context.Context) Pipeliner {
return &redisPipeliner{
pipe: s.client.Pipeline(),
ctx: ctx,
}
}
func (p *redisPipeliner) Exec() error {
_, err := p.pipe.Exec(p.ctx)
return err
}
func (p *redisPipeliner) Del(keys ...string) { p.pipe.Del(p.ctx, keys...) }
func (p *redisPipeliner) Expire(key string, expiration time.Duration) {
p.pipe.Expire(p.ctx, key, expiration)
}
func (p *redisPipeliner) HSet(key string, values map[string]any) { p.pipe.HSet(p.ctx, key, values) }
func (p *redisPipeliner) HIncrBy(key, field string, incr int64) {
p.pipe.HIncrBy(p.ctx, key, field, incr)
}
func (p *redisPipeliner) LPush(key string, values ...any) { p.pipe.LPush(p.ctx, key, values...) }
func (p *redisPipeliner) LRem(key string, count int64, value any) {
p.pipe.LRem(p.ctx, key, count, value)
}
func (p *redisPipeliner) SAdd(key string, members ...any) { p.pipe.SAdd(p.ctx, key, members...) }
func (p *redisPipeliner) SRem(key string, members ...any) { p.pipe.SRem(p.ctx, key, members...) }
func (p *redisPipeliner) ZAdd(key string, members map[string]float64) {
if len(members) == 0 {
return
}
redisMembers := make([]redis.Z, len(members))
i := 0
for member, score := range members {
redisMembers[i] = redis.Z{Score: score, Member: member}
i++
}
p.pipe.ZAdd(p.ctx, key, redisMembers...)
}
func (p *redisPipeliner) ZRem(key string, members ...any) { p.pipe.ZRem(p.ctx, key, members...) }
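With the context now carried by the pipeliner itself, call sites pass it once at creation and queue freely afterwards. A short batching sketch (keys and values are examples):
pipe := st.Pipeline(ctx)
pipe.HIncrBy("stats:proxy:1", "requests", 1)
pipe.SAdd("proxies:ready", "proxy-1")
pipe.Expire("proxies:ready", 10*time.Minute)
if err := pipe.Exec(); err != nil {
	// One network round trip; a failure here covers the whole batch.
}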
type redisSubscription struct {
pubsub *redis.PubSub
msgChan chan *Message
once sync.Once
pubsub *redis.PubSub
msgChan chan *Message
logger *logrus.Entry
wg sync.WaitGroup
close context.CancelFunc
channelName string
}
func (s *RedisStore) Subscribe(ctx context.Context, channel string) (Subscription, error) {
pubsub := s.client.Subscribe(ctx, channel)
_, err := pubsub.Receive(ctx)
if err != nil {
_ = pubsub.Close()
return nil, fmt.Errorf("failed to subscribe to channel %s: %w", channel, err)
}
subCtx, cancel := context.WithCancel(context.Background())
sub := &redisSubscription{
pubsub: pubsub,
msgChan: make(chan *Message, 10),
logger: s.logger,
close: cancel,
channelName: channel,
}
sub.wg.Add(1)
go sub.bridge(subCtx)
return sub, nil
}
func (rs *redisSubscription) bridge(ctx context.Context) {
defer rs.wg.Done()
defer close(rs.msgChan)
redisCh := rs.pubsub.Channel()
for {
select {
case <-ctx.Done():
return
case redisMsg, ok := <-redisCh:
if !ok {
return
}
msg := &Message{
Channel: redisMsg.Channel,
Payload: []byte(redisMsg.Payload),
}
select {
case rs.msgChan <- msg:
default:
rs.logger.Warnf("Message dropped for channel '%s' due to slow consumer.", rs.channelName)
}
}
}
}
func (rs *redisSubscription) Channel() <-chan *Message {
rs.once.Do(func() {
rs.msgChan = make(chan *Message)
go func() {
defer close(rs.msgChan)
for redisMsg := range rs.pubsub.Channel() {
rs.msgChan <- &Message{
Channel: redisMsg.Channel,
Payload: []byte(redisMsg.Payload),
}
}
}()
})
return rs.msgChan
}
func (rs *redisSubscription) Close() error { return rs.pubsub.Close() }
func (s *RedisStore) Publish(channel string, message []byte) error {
return s.client.Publish(context.Background(), channel, message).Err()
func (rs *redisSubscription) ChannelName() string {
return rs.channelName
}
func (s *RedisStore) Subscribe(channel string) (Subscription, error) {
pubsub := s.client.Subscribe(context.Background(), channel)
_, err := pubsub.Receive(context.Background())
if err != nil {
return nil, fmt.Errorf("failed to subscribe to channel %s: %w", channel, err)
}
return &redisSubscription{pubsub: pubsub}, nil
func (rs *redisSubscription) Close() error {
rs.close()
err := rs.pubsub.Close()
rs.wg.Wait()
return err
}
func (s *RedisStore) Publish(ctx context.Context, channel string, message []byte) error {
return s.client.Publish(ctx, channel, message).Err()
}
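A consumption sketch for the rebuilt subscription (the channel name and handle helper are examples): Close cancels the bridge goroutine and closes the underlying PubSub, the bridge then closes msgChan, so a plain range loop terminates cleanly.
sub, err := st.Subscribe(ctx, "proxy.status.changed")
if err != nil {
	return err
}
defer sub.Close()
for msg := range sub.Channel() {
	handle(msg.Channel, msg.Payload)
}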

View File

@@ -1,6 +1,9 @@
// Filename: internal/store/store.go
package store
import (
"context"
"errors"
"time"
)
@@ -17,6 +20,7 @@ type Message struct {
// Subscription represents an active subscription to a pub/sub channel.
type Subscription interface {
Channel() <-chan *Message
ChannelName() string
Close() error
}
@@ -38,6 +42,10 @@ type Pipeliner interface {
LPush(key string, values ...any)
LRem(key string, count int64, value any)
// ZSET
ZAdd(key string, members map[string]float64)
ZRem(key string, members ...any)
// Execution
Exec() error
}
@@ -45,44 +53,44 @@ type Pipeliner interface {
// Store is the master interface for our cache service.
type Store interface {
// Basic K/V operations
Set(key string, value []byte, ttl time.Duration) error
Get(key string) ([]byte, error)
Del(keys ...string) error
Exists(key string) (bool, error)
SetNX(key string, value []byte, ttl time.Duration) (bool, error)
Set(ctx context.Context, key string, value []byte, ttl time.Duration) error
Get(ctx context.Context, key string) ([]byte, error)
Del(ctx context.Context, keys ...string) error
Exists(ctx context.Context, key string) (bool, error)
SetNX(ctx context.Context, key string, value []byte, ttl time.Duration) (bool, error)
// HASH operations
HSet(key string, values map[string]any) error
HGetAll(key string) (map[string]string, error)
HIncrBy(key, field string, incr int64) (int64, error)
HDel(key string, fields ...string) error // [added]
HSet(ctx context.Context, key string, values map[string]any) error
HGetAll(ctx context.Context, key string) (map[string]string, error)
HIncrBy(ctx context.Context, key, field string, incr int64) (int64, error)
HDel(ctx context.Context, key string, fields ...string) error
// LIST operations
LPush(key string, values ...any) error
LRem(key string, count int64, value any) error
Rotate(key string) (string, error)
LIndex(key string, index int64) (string, error)
LPush(ctx context.Context, key string, values ...any) error
LRem(ctx context.Context, key string, count int64, value any) error
Rotate(ctx context.Context, key string) (string, error)
LIndex(ctx context.Context, key string, index int64) (string, error)
// SET operations
SAdd(key string, members ...any) error
SPopN(key string, count int64) ([]string, error)
SMembers(key string) ([]string, error)
SRem(key string, members ...any) error
SRandMember(key string) (string, error)
SAdd(ctx context.Context, key string, members ...any) error
SPopN(ctx context.Context, key string, count int64) ([]string, error)
SMembers(ctx context.Context, key string) ([]string, error)
SRem(ctx context.Context, key string, members ...any) error
SRandMember(ctx context.Context, key string) (string, error)
// Pub/Sub operations
Publish(channel string, message []byte) error
Subscribe(channel string) (Subscription, error)
Publish(ctx context.Context, channel string, message []byte) error
Subscribe(ctx context.Context, channel string) (Subscription, error)
// Pipeline (optional) - implemented for Redis; the in-memory version does not implement it yet
Pipeline() Pipeliner
// Pipeline
Pipeline(ctx context.Context) Pipeliner
// Close closes the store and releases any underlying resources.
Close() error
// === New methods supporting the rotation strategy ===
ZAdd(key string, members map[string]float64) error
ZRange(key string, start, stop int64) ([]string, error)
ZRem(key string, members ...any) error
PopAndCycleSetMember(mainKey, cooldownKey string) (string, error)
// ZSET operations
ZAdd(ctx context.Context, key string, members map[string]float64) error
ZRange(ctx context.Context, key string, start, stop int64) ([]string, error)
ZRem(ctx context.Context, key string, members ...any) error
PopAndCycleSetMember(ctx context.Context, mainKey, cooldownKey string) (string, error)
}
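Since every read path in the interface reports misses through the package's ErrNotFound sentinel, callers can centralize fallback handling. A small hypothetical helper (not part of this commit):
func getOrDefault(ctx context.Context, st Store, key string, def []byte) ([]byte, error) {
	val, err := st.Get(ctx, key)
	if errors.Is(err, ErrNotFound) {
		return def, nil
	}
	return val, err
}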

View File

@@ -1,6 +1,7 @@
package syncer
import (
"context"
"fmt"
"gemini-balancer/internal/store"
"log"
@@ -51,7 +52,7 @@ func (s *CacheSyncer[T]) Get() T {
func (s *CacheSyncer[T]) Invalidate() error {
log.Printf("INFO: Publishing invalidation notification on channel '%s'", s.channelName)
return s.store.Publish(s.channelName, []byte("reload"))
return s.store.Publish(context.Background(), s.channelName, []byte("reload"))
}
func (s *CacheSyncer[T]) Stop() {
@@ -84,7 +85,7 @@ func (s *CacheSyncer[T]) listenForUpdates() {
default:
}
subscription, err := s.store.Subscribe(s.channelName)
subscription, err := s.store.Subscribe(context.Background(), s.channelName)
if err != nil {
log.Printf("ERROR: Failed to subscribe to '%s', retrying in 5s: %v", s.channelName, err)
time.Sleep(5 * time.Second)
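Worth noting: the syncer keeps context.Background() on purpose, since both calls outlive any single request; the subscription lasts for the syncer's lifetime and invalidation is fire-and-forget. A hedged call-site sketch (variable names are illustrative):
// After a successful write, broadcast so every instance reloads its cache.
if err := db.Save(&proxyConfig).Error; err == nil {
	_ = cacheSyncer.Invalidate()
}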

View File

@@ -1,7 +1,8 @@
// Filename: internal/task/task.go (final calibrated version)
// Filename: internal/task/task.go
package task
import (
"context"
"encoding/json"
"errors"
"fmt"
@@ -15,15 +16,13 @@ const (
ResultTTL = 60 * time.Minute
)
// Reporter defines how domain code interacts with the task service.
type Reporter interface {
StartTask(keyGroupID uint, taskType, resourceID string, total int, timeout time.Duration) (*Status, error)
EndTaskByID(taskID, resourceID string, result any, taskErr error)
UpdateProgressByID(taskID string, processed int) error
UpdateTotalByID(taskID string, total int) error
StartTask(ctx context.Context, keyGroupID uint, taskType, resourceID string, total int, timeout time.Duration) (*Status, error)
EndTaskByID(ctx context.Context, taskID, resourceID string, result any, taskErr error)
UpdateProgressByID(ctx context.Context, taskID string, processed int) error
UpdateTotalByID(ctx context.Context, taskID string, total int) error
}
// Status represents the complete state of a background task.
type Status struct {
ID string `json:"id"`
TaskType string `json:"task_type"`
@@ -38,13 +37,11 @@ type Status struct {
DurationSeconds float64 `json:"duration_seconds,omitempty"`
}
// Task is the core task-management service.
type Task struct {
store store.Store
logger *logrus.Entry
}
// NewTask constructs a Task.
func NewTask(store store.Store, logger *logrus.Logger) *Task {
return &Task{
store: store,
@@ -62,15 +59,14 @@ func (s *Task) getTaskDataKey(taskID string) string {
return fmt.Sprintf("task:data:%s", taskID)
}
// --- New helper for deriving the atomic running-flag key ---
func (s *Task) getIsRunningFlagKey(taskID string) string {
return fmt.Sprintf("task:running:%s", taskID)
}
func (s *Task) StartTask(keyGroupID uint, taskType, resourceID string, total int, timeout time.Duration) (*Status, error) {
func (s *Task) StartTask(ctx context.Context, keyGroupID uint, taskType, resourceID string, total int, timeout time.Duration) (*Status, error) {
lockKey := s.getResourceLockKey(resourceID)
if existingTaskID, err := s.store.Get(lockKey); err == nil && len(existingTaskID) > 0 {
if existingTaskID, err := s.store.Get(ctx, lockKey); err == nil && len(existingTaskID) > 0 {
return nil, fmt.Errorf("a task is already running for this resource (ID: %s)", string(existingTaskID))
}
@@ -94,35 +90,34 @@ func (s *Task) StartTask(keyGroupID uint, taskType, resourceID string, total int
timeout = ResultTTL * 24
}
if err := s.store.Set(lockKey, []byte(taskID), timeout); err != nil {
if err := s.store.Set(ctx, lockKey, []byte(taskID), timeout); err != nil {
return nil, fmt.Errorf("failed to acquire task resource lock: %w", err)
}
if err := s.store.Set(taskKey, statusBytes, timeout); err != nil {
_ = s.store.Del(lockKey)
if err := s.store.Set(ctx, taskKey, statusBytes, timeout); err != nil {
_ = s.store.Del(ctx, lockKey)
return nil, fmt.Errorf("failed to set new task data in store: %w", err)
}
// Create a standalone "running" flag whose presence can be checked atomically.
if err := s.store.Set(runningFlagKey, []byte("1"), timeout); err != nil {
_ = s.store.Del(lockKey)
_ = s.store.Del(taskKey)
if err := s.store.Set(ctx, runningFlagKey, []byte("1"), timeout); err != nil {
_ = s.store.Del(ctx, lockKey)
_ = s.store.Del(ctx, taskKey)
return nil, fmt.Errorf("failed to set task running flag: %w", err)
}
return status, nil
}
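End to end, the contract reads: StartTask acquires the resource lock and running flag, progress updates are best-effort while the flag exists, and EndTaskByID persists the final status and releases the lock. A hedged lifecycle sketch (task type, resource ID, timeout, and the process helper are examples):
status, err := reporter.StartTask(ctx, groupID, "proxy_sync", "proxies", len(items), 10*time.Minute)
if err != nil {
	return err // another task already holds this resource's lock
}
go func() {
	bg := context.Background() // detached: the work outlives the request
	var taskErr error
	for i, item := range items {
		process(item)
		_ = reporter.UpdateProgressByID(bg, status.ID, i+1)
	}
	reporter.EndTaskByID(bg, status.ID, "proxies", map[string]int{"synced": len(items)}, taskErr)
}()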
func (s *Task) EndTaskByID(taskID, resourceID string, resultData any, taskErr error) {
func (s *Task) EndTaskByID(ctx context.Context, taskID, resourceID string, resultData any, taskErr error) {
lockKey := s.getResourceLockKey(resourceID)
defer func() {
if err := s.store.Del(lockKey); err != nil {
if err := s.store.Del(ctx, lockKey); err != nil {
s.logger.WithError(err).Warnf("Failed to release resource lock '%s' for task %s.", lockKey, taskID)
}
}()
runningFlagKey := s.getIsRunningFlagKey(taskID)
_ = s.store.Del(runningFlagKey)
status, err := s.GetStatus(taskID)
if err != nil {
_ = s.store.Del(ctx, runningFlagKey)
status, err := s.GetStatus(ctx, taskID)
if err != nil {
s.logger.WithError(err).Errorf("Could not get task status for task ID %s during EndTask. Lock has been released, but task data may be stale.", taskID)
return
}
@@ -141,15 +136,14 @@ func (s *Task) EndTaskByID(taskID, resourceID string, resultData any, taskErr er
}
updatedTaskBytes, _ := json.Marshal(status)
taskKey := s.getTaskDataKey(taskID)
if err := s.store.Set(taskKey, updatedTaskBytes, ResultTTL); err != nil {
if err := s.store.Set(ctx, taskKey, updatedTaskBytes, ResultTTL); err != nil {
s.logger.WithError(err).Errorf("Failed to save final status for task %s.", taskID)
}
}
// GetStatus fetches a task's status by ID, for external callers such as API handlers.
func (s *Task) GetStatus(taskID string) (*Status, error) {
func (s *Task) GetStatus(ctx context.Context, taskID string) (*Status, error) {
taskKey := s.getTaskDataKey(taskID)
statusBytes, err := s.store.Get(taskKey)
statusBytes, err := s.store.Get(ctx, taskKey)
if err != nil {
if errors.Is(err, store.ErrNotFound) {
return nil, errors.New("task not found")
@@ -161,22 +155,18 @@ func (s *Task) GetStatus(taskID string) (*Status, error) {
if err := json.Unmarshal(statusBytes, &status); err != nil {
return nil, fmt.Errorf("corrupted task data in store for ID %s", taskID)
}
if !status.IsRunning && status.FinishedAt != nil {
status.DurationSeconds = status.FinishedAt.Sub(status.StartedAt).Seconds()
}
return &status, nil
}
// UpdateProgressByID updates a task's progress by ID.
func (s *Task) updateTask(taskID string, updater func(status *Status)) error {
func (s *Task) updateTask(ctx context.Context, taskID string, updater func(status *Status)) error {
runningFlagKey := s.getIsRunningFlagKey(taskID)
if _, err := s.store.Get(runningFlagKey); err != nil {
// The task has already finished; returning silently is expected.
if _, err := s.store.Get(ctx, runningFlagKey); err != nil {
return nil
}
status, err := s.GetStatus(taskID)
status, err := s.GetStatus(ctx, taskID)
if err != nil {
s.logger.WithError(err).Warnf("Failed to get task status for update on task %s. Update will not be saved.", taskID)
return nil
@@ -184,7 +174,6 @@ func (s *Task) updateTask(taskID string, updater func(status *Status)) error {
if !status.IsRunning {
return nil
}
// Apply the supplied updater to mutate the status.
updater(status)
statusBytes, marshalErr := json.Marshal(status)
if marshalErr != nil {
@@ -192,23 +181,20 @@ func (s *Task) updateTask(taskID string, updater func(status *Status)) error {
return nil
}
taskKey := s.getTaskDataKey(taskID)
// Use a longer TTL so a running task does not expire prematurely.
if err := s.store.Set(taskKey, statusBytes, ResultTTL*24); err != nil {
if err := s.store.Set(ctx, taskKey, statusBytes, ResultTTL*24); err != nil {
s.logger.WithError(err).Warnf("Failed to save update for task %s.", taskID)
}
return nil
}
// [REFACTORED] UpdateProgressByID is now a thin wrapper around the generic updater.
func (s *Task) UpdateProgressByID(taskID string, processed int) error {
return s.updateTask(taskID, func(status *Status) {
func (s *Task) UpdateProgressByID(ctx context.Context, taskID string, processed int) error {
return s.updateTask(ctx, taskID, func(status *Status) {
status.Processed = processed
})
}
// [REFACTORED] UpdateTotalByID is likewise a thin wrapper around the generic updater.
func (s *Task) UpdateTotalByID(taskID string, total int) error {
return s.updateTask(taskID, func(status *Status) {
func (s *Task) UpdateTotalByID(ctx context.Context, taskID string, total int) error {
return s.updateTask(ctx, taskID, func(status *Status) {
status.Total = total
})
}
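The two wrappers above show the extension pattern: a new mutator only supplies a closure, while updateTask handles the running-flag check, fetch, marshal, and TTL refresh. A hypothetical example in the same shape (not part of this commit):
func (s *Task) AddProcessedByID(ctx context.Context, taskID string, delta int) error {
	return s.updateTask(ctx, taskID, func(status *Status) {
		status.Processed += delta
	})
}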