Compare commits

...

2 Commits

Author SHA1 Message Date
6c7283d51b Fix basepool & optimize repo 2025-11-23 22:42:58 +08:00
2b0b9b67dc Update Context for store 2025-11-22 14:20:05 +08:00
37 changed files with 2077 additions and 1687 deletions

View File

@@ -28,12 +28,6 @@ type GeminiChannel struct {
httpClient *http.Client
}
// Local struct for safely extracting request metadata
type requestMetadata struct {
Model string `json:"model"`
Stream bool `json:"stream"`
}
func NewGeminiChannel(logger *logrus.Logger, cfg *models.SystemSettings) *GeminiChannel {
transport := &http.Transport{
Proxy: http.ProxyFromEnvironment,
@@ -47,38 +41,50 @@ func NewGeminiChannel(logger *logrus.Logger, cfg *models.SystemSettings) *Gemini
logger: logger,
httpClient: &http.Client{
Transport: transport,
Timeout: 0,
Timeout: 0, // Timeout is handled by the request context
},
}
}
// TransformRequest returns the request body unchanged along with the extracted model name.
func (ch *GeminiChannel) TransformRequest(c *gin.Context, requestBody []byte) (newBody []byte, modelName string, err error) {
modelName = ch.ExtractModel(c, requestBody)
return requestBody, modelName, nil
}
// ExtractModel returns the model name from the request body, falling back to the URL path.
func (ch *GeminiChannel) ExtractModel(c *gin.Context, bodyBytes []byte) string {
return ch.extractModelFromRequest(c, bodyBytes)
}
// Unified model extraction: try the request body first, then fall back to parsing the URL path.
func (ch *GeminiChannel) extractModelFromRequest(c *gin.Context, bodyBytes []byte) string {
var p struct {
Model string `json:"model"`
}
_ = json.Unmarshal(requestBody, &p)
modelName = strings.TrimPrefix(p.Model, "models/")
if modelName == "" {
modelName = ch.extractModelFromPath(c.Request.URL.Path)
if json.Unmarshal(bodyBytes, &p) == nil && p.Model != "" {
return strings.TrimPrefix(p.Model, "models/")
}
return requestBody, modelName, nil
return ch.extractModelFromPath(c.Request.URL.Path)
}
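// extractModelFromPath pulls the model name out of a Gemini-style URL path;
// e.g. "/v1beta/models/gemini-1.5-pro:streamGenerateContent" yields "gemini-1.5-pro".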
func (ch *GeminiChannel) extractModelFromPath(path string) string {
parts := strings.Split(path, "/")
for _, part := range parts {
// Cover more model name prefixes
if strings.HasPrefix(part, "gemini-") || strings.HasPrefix(part, "text-") || strings.HasPrefix(part, "embedding-") {
modelPart := strings.Split(part, ":")[0]
return modelPart
return strings.Split(part, ":")[0]
}
}
return ""
}
// IsOpenAICompatibleRequest is decided by pure string matching on the path; the check itself does not depend on gin.Context.
func (ch *GeminiChannel) IsOpenAICompatibleRequest(c *gin.Context) bool {
path := c.Request.URL.Path
return ch.isOpenAIPath(c.Request.URL.Path)
}
func (ch *GeminiChannel) isOpenAIPath(path string) bool {
return strings.Contains(path, "/v1/chat/completions") || strings.Contains(path, "/v1/embeddings")
}
@@ -88,25 +94,28 @@ func (ch *GeminiChannel) ValidateKey(
targetURL string,
timeout time.Duration,
) *CustomErrors.APIError {
client := &http.Client{
Timeout: timeout,
}
client := &http.Client{Timeout: timeout}
req, err := http.NewRequestWithContext(ctx, "GET", targetURL, nil)
if err != nil {
return CustomErrors.NewAPIError(CustomErrors.ErrInternalServer, "failed to create validation request: "+err.Error())
return CustomErrors.NewAPIError(CustomErrors.ErrInternalServer, "failed to create validation request")
}
ch.ModifyRequest(req, apiKey)
resp, err := client.Do(req)
if err != nil {
return CustomErrors.NewAPIError(CustomErrors.ErrBadGateway, "failed to send validation request: "+err.Error())
return CustomErrors.NewAPIError(CustomErrors.ErrBadGateway, "validation request failed: "+err.Error())
}
defer resp.Body.Close()
if resp.StatusCode >= 200 && resp.StatusCode < 300 {
return nil
}
errorBody, _ := io.ReadAll(resp.Body)
parsedMessage := CustomErrors.ParseUpstreamError(errorBody)
return &CustomErrors.APIError{
HTTPStatus: resp.StatusCode,
Code: fmt.Sprintf("UPSTREAM_%d", resp.StatusCode),
@@ -115,10 +124,6 @@ func (ch *GeminiChannel) ValidateKey(
}
func (ch *GeminiChannel) ModifyRequest(req *http.Request, apiKey *models.APIKey) {
// TODO: [Future Refactoring] Decouple auth logic from URL path.
// The authentication method (e.g., Bearer token vs. API key in query) should ideally be a property
// of the UpstreamEndpoint or a new "AuthProfile" entity, rather than being hardcoded based on URL patterns.
// This would make the channel more generic and adaptable to new upstream provider types.
if strings.Contains(req.URL.Path, "/v1beta/openai/") {
req.Header.Set("Authorization", "Bearer "+apiKey.APIKey)
} else {
@@ -133,24 +138,22 @@ func (ch *GeminiChannel) IsStreamRequest(c *gin.Context, bodyBytes []byte) bool
if strings.HasSuffix(c.Request.URL.Path, ":streamGenerateContent") {
return true
}
var meta requestMetadata
if err := json.Unmarshal(bodyBytes, &meta); err == nil {
var meta struct {
Stream bool `json:"stream"`
}
if json.Unmarshal(bodyBytes, &meta) == nil {
return meta.Stream
}
return false
}
func (ch *GeminiChannel) ExtractModel(c *gin.Context, bodyBytes []byte) string {
_, modelName, _ := ch.TransformRequest(c, bodyBytes)
return modelName
}
// RewritePath uses url.JoinPath to guarantee correct path joining.
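// For example, an OpenAI-compatible path like "/v1/chat/completions" is rewritten to
// "v1beta/openai/chat/completions", while a native path like
// "/v1/models/gemini-pro:generateContent" becomes "v1beta/models/gemini-pro:generateContent";
// if basePath already ends in "/v1beta", the duplicate version prefix is stripped before joining.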
func (ch *GeminiChannel) RewritePath(basePath, originalPath string) string {
tempCtx := &gin.Context{Request: &http.Request{URL: &url.URL{Path: originalPath}}}
var rewrittenSegment string
if ch.IsOpenAICompatibleRequest(tempCtx) {
var apiEndpoint string
if ch.isOpenAIPath(originalPath) {
v1Index := strings.LastIndex(originalPath, "/v1/")
var apiEndpoint string
if v1Index != -1 {
apiEndpoint = originalPath[v1Index+len("/v1/"):]
} else {
@@ -158,69 +161,76 @@ func (ch *GeminiChannel) RewritePath(basePath, originalPath string) string {
}
rewrittenSegment = "v1beta/openai/" + apiEndpoint
} else {
tempPath := originalPath
if strings.HasPrefix(tempPath, "/v1/") {
tempPath = "/v1beta/" + strings.TrimPrefix(tempPath, "/v1/")
if strings.HasPrefix(originalPath, "/v1/") {
rewrittenSegment = "v1beta/" + strings.TrimPrefix(originalPath, "/v1/")
} else {
rewrittenSegment = strings.TrimPrefix(originalPath, "/")
}
rewrittenSegment = strings.TrimPrefix(tempPath, "/")
}
trimmedBasePath := strings.TrimSuffix(basePath, "/")
pathToJoin := rewrittenSegment
trimmedBasePath := strings.TrimSuffix(basePath, "/")
// Avoid duplicating the version segment, e.g. when basePath is /v1beta and the rewritten segment also starts with v1beta/..
versionPrefixes := []string{"v1beta", "v1"}
for _, prefix := range versionPrefixes {
if strings.HasSuffix(trimmedBasePath, "/"+prefix) && strings.HasPrefix(pathToJoin, prefix+"/") {
pathToJoin = strings.TrimPrefix(pathToJoin, prefix+"/")
if strings.HasSuffix(trimmedBasePath, "/"+prefix) && strings.HasPrefix(rewrittenSegment, prefix+"/") {
rewrittenSegment = strings.TrimPrefix(rewrittenSegment, prefix+"/")
break
}
}
finalPath, err := url.JoinPath(trimmedBasePath, pathToJoin)
finalPath, err := url.JoinPath(trimmedBasePath, rewrittenSegment)
if err != nil {
return trimmedBasePath + "/" + strings.TrimPrefix(pathToJoin, "/")
// Fall back to simple string concatenation
return trimmedBasePath + "/" + strings.TrimPrefix(rewrittenSegment, "/")
}
return finalPath
}
func (ch *GeminiChannel) ModifyResponse(resp *http.Response) error {
// Stub implementation; no logic is needed yet.
return nil
return nil // stub implementation
}
func (ch *GeminiChannel) HandleError(c *gin.Context, err error) {
// Stub implementation; no logic is needed yet.
// stub implementation
}
// ==========================================================
// ================== Core engine of the "smart routing" ==================
// ==========================================================
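// ProcessSmartStreamRequest sends the initial streaming request upstream and hands the
// response body to processStreamAndRetry, which relays SSE lines to the client while
// accumulating the generated text; if the stream drops before a finish reason arrives,
// the request is rebuilt with the accumulated text plus a continuation prompt and
// retried, up to params.MaxRetries consecutive times.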
func (ch *GeminiChannel) ProcessSmartStreamRequest(c *gin.Context, params SmartRequestParams) {
log := ch.logger.WithField("correlation_id", params.CorrelationID)
targetURL, err := url.Parse(params.UpstreamURL)
if err != nil {
log.WithError(err).Error("Failed to parse upstream URL")
errToJSON(c, CustomErrors.NewAPIError(CustomErrors.ErrInternalServer, "Invalid upstream URL format"))
log.WithError(err).Error("Invalid upstream URL")
errToJSON(c, CustomErrors.NewAPIError(CustomErrors.ErrInternalServer, "Invalid upstream URL"))
return
}
targetURL.Path = c.Request.URL.Path
targetURL.RawQuery = c.Request.URL.RawQuery
initialReq, err := http.NewRequestWithContext(c.Request.Context(), "POST", targetURL.String(), bytes.NewReader(params.RequestBody))
if err != nil {
log.WithError(err).Error("Failed to create initial smart request")
errToJSON(c, CustomErrors.NewAPIError(CustomErrors.ErrInternalServer, err.Error()))
log.WithError(err).Error("Failed to create initial smart stream request")
errToJSON(c, CustomErrors.NewAPIError(CustomErrors.ErrInternalServer, "Failed to create request"))
return
}
ch.ModifyRequest(initialReq, params.APIKey)
initialReq.Header.Del("Authorization")
resp, err := ch.httpClient.Do(initialReq)
if err != nil {
log.WithError(err).Error("Initial smart request failed")
errToJSON(c, CustomErrors.NewAPIError(CustomErrors.ErrBadGateway, err.Error()))
log.WithError(err).Error("Initial smart stream request failed")
errToJSON(c, CustomErrors.NewAPIError(CustomErrors.ErrBadGateway, "Request to upstream failed"))
return
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
log.Warnf("Initial request received non-200 status: %d", resp.StatusCode)
standardizedResp := ch.standardizeError(resp, params.LogTruncationLimit, log)
defer standardizedResp.Body.Close()
c.Writer.WriteHeader(standardizedResp.StatusCode)
for key, values := range standardizedResp.Header {
for _, value := range values {
@@ -228,45 +238,71 @@ func (ch *GeminiChannel) ProcessSmartStreamRequest(c *gin.Context, params SmartR
}
}
io.Copy(c.Writer, standardizedResp.Body)
params.EventLogger.IsSuccess = false
params.EventLogger.StatusCode = resp.StatusCode
return
}
ch.processStreamAndRetry(c, initialReq.Header, resp.Body, params, log)
}
func (ch *GeminiChannel) processStreamAndRetry(
c *gin.Context, initialRequestHeaders http.Header, initialReader io.ReadCloser, params SmartRequestParams, log *logrus.Entry,
c *gin.Context,
initialRequestHeaders http.Header,
initialReader io.ReadCloser,
params SmartRequestParams,
log *logrus.Entry,
) {
defer initialReader.Close()
c.Writer.Header().Set("Content-Type", "text/event-stream; charset=utf-8")
c.Writer.Header().Set("Cache-Control", "no-cache")
c.Writer.Header().Set("Connection", "keep-alive")
flusher, _ := c.Writer.(http.Flusher)
var accumulatedText strings.Builder
consecutiveRetryCount := 0
currentReader := initialReader
maxRetries := params.MaxRetries
retryDelay := params.RetryDelay
log.Infof("Starting smart stream session. Max retries: %d", maxRetries)
for {
if c.Request.Context().Err() != nil {
log.Info("Client disconnected, stopping stream processing.")
return
}
var interruptionReason string
scanner := bufio.NewScanner(currentReader)
for scanner.Scan() {
if c.Request.Context().Err() != nil {
log.Info("Client disconnected during scan.")
return
}
line := scanner.Text()
if line == "" {
continue
}
fmt.Fprintf(c.Writer, "%s\n\n", line)
flusher.Flush()
if !strings.HasPrefix(line, "data: ") {
continue
}
data := strings.TrimPrefix(line, "data: ")
var payload models.GeminiSSEPayload
if err := json.Unmarshal([]byte(data), &payload); err != nil {
continue
}
if len(payload.Candidates) > 0 {
candidate := payload.Candidates[0]
if candidate.Content != nil && len(candidate.Content.Parts) > 0 {
@@ -285,52 +321,71 @@ func (ch *GeminiChannel) processStreamAndRetry(
}
}
currentReader.Close()
if interruptionReason == "" {
if err := scanner.Err(); err != nil {
log.WithError(err).Warn("Stream scanner encountered an error.")
interruptionReason = "SCANNER_ERROR"
} else {
log.Warn("Stream dropped unexpectedly without a finish reason.")
log.Warn("Stream connection dropped without a finish reason.")
interruptionReason = "CONNECTION_DROP"
}
}
if consecutiveRetryCount >= maxRetries {
log.Errorf("Retry limit exceeded. Last interruption: %s. Sending final error.", interruptionReason)
errData, _ := json.Marshal(map[string]interface{}{"error": map[string]interface{}{"code": http.StatusGatewayTimeout, "status": "DEADLINE_EXCEEDED", "message": fmt.Sprintf("Proxy retry limit exceeded. Last interruption: %s.", interruptionReason)}})
log.Errorf("Retry limit exceeded. Last interruption: %s. Sending final error to client.", interruptionReason)
errData, _ := json.Marshal(map[string]interface{}{
"error": map[string]interface{}{
"code": http.StatusGatewayTimeout,
"status": "DEADLINE_EXCEEDED",
"message": fmt.Sprintf("Proxy retry limit exceeded after multiple interruptions. Last reason: %s", interruptionReason),
},
})
fmt.Fprintf(c.Writer, "event: error\ndata: %s\n\n", string(errData))
flusher.Flush()
return
}
consecutiveRetryCount++
params.EventLogger.Retries = consecutiveRetryCount
log.Infof("Stream interrupted. Attempting retry %d/%d after %v.", consecutiveRetryCount, maxRetries, retryDelay)
time.Sleep(retryDelay)
retryBody, _ := buildRetryRequestBody(params.OriginalRequest, accumulatedText.String())
retryBody := buildRetryRequestBody(params.OriginalRequest, accumulatedText.String())
retryBodyBytes, _ := json.Marshal(retryBody)
retryReq, _ := http.NewRequestWithContext(c.Request.Context(), "POST", params.UpstreamURL, bytes.NewReader(retryBodyBytes))
retryReq, err := http.NewRequestWithContext(c.Request.Context(), "POST", params.UpstreamURL, bytes.NewReader(retryBodyBytes))
if err != nil {
log.WithError(err).Error("Failed to create retry request")
continue
}
retryReq.Header = initialRequestHeaders
ch.ModifyRequest(retryReq, params.APIKey)
retryReq.Header.Del("Authorization")
retryResp, err := ch.httpClient.Do(retryReq)
if err != nil || retryResp.StatusCode != http.StatusOK || retryResp.Body == nil {
if err != nil {
log.WithError(err).Errorf("Retry request failed.")
} else {
log.Errorf("Retry request received non-200 status: %d", retryResp.StatusCode)
if retryResp.Body != nil {
retryResp.Body.Close()
}
}
if err != nil {
log.WithError(err).Error("Retry request failed")
continue
}
if retryResp.StatusCode != http.StatusOK {
log.Errorf("Retry request received non-200 status: %d", retryResp.StatusCode)
retryResp.Body.Close()
continue
}
currentReader = retryResp.Body
}
}
func buildRetryRequestBody(originalBody models.GeminiRequest, accumulatedText string) (models.GeminiRequest, error) {
// buildRetryRequestBody correctly handles context insertion for multi-turn conversations.
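// The accumulated model output and the SmartRetryPrompt are inserted as a model/user
// pair directly after the last 'user' message, so the upstream model can resume from
// where the interrupted stream stopped.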
func buildRetryRequestBody(originalBody models.GeminiRequest, accumulatedText string) models.GeminiRequest {
retryBody := originalBody
// Find the index of the last message with the 'user' role
lastUserIndex := -1
for i := len(retryBody.Contents) - 1; i >= 0; i-- {
if retryBody.Contents[i].Role == "user" {
@@ -338,25 +393,26 @@ func buildRetryRequestBody(originalBody models.GeminiRequest, accumulatedText st
break
}
}
history := []models.GeminiContent{
{Role: "model", Parts: []models.Part{{Text: accumulatedText}}},
{Role: "user", Parts: []models.Part{{Text: SmartRetryPrompt}}},
}
if lastUserIndex != -1 {
// If a 'user' message was found, insert the history right after it
newContents := make([]models.GeminiContent, 0, len(retryBody.Contents)+2)
newContents = append(newContents, retryBody.Contents[:lastUserIndex+1]...)
newContents = append(newContents, history...)
newContents = append(newContents, retryBody.Contents[lastUserIndex+1:]...)
retryBody.Contents = newContents
} else {
// If there is no 'user' message (which should not happen in theory), simply append
retryBody.Contents = append(retryBody.Contents, history...)
}
return retryBody, nil
}
// ===============================================
// ========= Helper functions (inherited & hardened) =========
// ===============================================
return retryBody
}
type googleAPIError struct {
Error struct {
@@ -397,25 +453,28 @@ func truncate(s string, n int) string {
return s
}
// standardizeError normalizes an upstream error response into Google's standard error envelope.
func (ch *GeminiChannel) standardizeError(resp *http.Response, truncateLimit int, log *logrus.Entry) *http.Response {
bodyBytes, err := io.ReadAll(resp.Body)
if err != nil {
log.WithError(err).Error("Failed to read upstream error body")
bodyBytes = []byte("Failed to read upstream error body: " + err.Error())
bodyBytes = []byte("Failed to read upstream error body")
}
resp.Body.Close()
log.Errorf("Upstream error body: %s", truncate(string(bodyBytes), truncateLimit))
log.Errorf("Upstream error: %s", truncate(string(bodyBytes), truncateLimit))
var standardizedPayload googleAPIError
// Even if parsing fails, build a standard error payload
if json.Unmarshal(bodyBytes, &standardizedPayload) != nil || standardizedPayload.Error.Code == 0 {
standardizedPayload.Error.Code = resp.StatusCode
standardizedPayload.Error.Message = http.StatusText(resp.StatusCode)
standardizedPayload.Error.Status = statusToGoogleStatus(resp.StatusCode)
standardizedPayload.Error.Details = []interface{}{map[string]string{
"@type": "proxy.upstream.error",
"@type": "proxy.upstream.unparsed.error",
"body": truncate(string(bodyBytes), truncateLimit),
}}
}
newBodyBytes, _ := json.Marshal(standardizedPayload)
newResp := &http.Response{
StatusCode: resp.StatusCode,
@@ -425,10 +484,13 @@ func (ch *GeminiChannel) standardizeError(resp *http.Response, truncateLimit int
}
newResp.Header.Set("Content-Type", "application/json; charset=utf-8")
newResp.Header.Set("Access-Control-Allow-Origin", "*")
return newResp
}
// errToJSON writes an APIError as a JSON response unless the context is already aborted.
func errToJSON(c *gin.Context, apiErr *CustomErrors.APIError) {
if c.IsAborted() {
return
}
c.JSON(apiErr.HTTPStatus, gin.H{"error": apiErr})
}

View File

@@ -13,9 +13,10 @@ type Config struct {
Database DatabaseConfig
Server ServerConfig
Log LogConfig
Redis RedisConfig `mapstructure:"redis"`
SessionSecret string `mapstructure:"session_secret"`
EncryptionKey string `mapstructure:"encryption_key"`
Redis RedisConfig `mapstructure:"redis"`
SessionSecret string `mapstructure:"session_secret"`
EncryptionKey string `mapstructure:"encryption_key"`
Repository RepositoryConfig `mapstructure:"repository"`
}
// DatabaseConfig stores database connection settings
@@ -43,19 +44,24 @@ type RedisConfig struct {
DSN string `mapstructure:"dsn"`
}
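// RepositoryConfig controls caching of the repository layer's base pool.
// TTL is the absolute lifetime of a cached entry; TTI is presumably the
// time-to-idle, i.e. expiry after a period without access (an assumption
// based on the field names).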
type RepositoryConfig struct {
BasePoolTTLMinutes int `mapstructure:"base_pool_ttl_minutes"`
BasePoolTTIMinutes int `mapstructure:"base_pool_tti_minutes"`
}
// LoadConfig loads configuration from file and environment variables
func LoadConfig() (*Config, error) {
// Set the config file name and search paths
viper.SetConfigName("config")
viper.SetConfigType("yaml")
viper.AddConfigPath(".")
viper.AddConfigPath("/etc/gemini-balancer/") // for production
// Allow overrides from environment variables
viper.SetEnvKeyReplacer(strings.NewReplacer(".", "_"))
viper.AutomaticEnv()
// Set defaults
viper.SetDefault("server.port", "8080")
viper.SetDefault("server.port", "9000")
viper.SetDefault("log.level", "info")
viper.SetDefault("log.format", "text")
viper.SetDefault("log.enable_file", false)
@@ -67,6 +73,9 @@ func LoadConfig() (*Config, error) {
viper.SetDefault("database.conn_max_lifetime", "1h")
viper.SetDefault("encryption_key", "")
viper.SetDefault("repository.base_pool_ttl_minutes", 60)
viper.SetDefault("repository.base_pool_tti_minutes", 10)
// Read the config file
if err := viper.ReadInConfig(); err != nil {
if _, ok := err.(viper.ConfigFileNotFoundError); !ok {

View File

@@ -2,6 +2,7 @@
package proxy
import (
"context"
"encoding/json"
"gemini-balancer/internal/errors"
"gemini-balancer/internal/models"
@@ -49,7 +50,6 @@ func (h *handler) registerRoutes(rg *gin.RouterGroup) {
}
}
// --- Request DTOs ---
type CreateProxyConfigRequest struct {
Address string `json:"address" binding:"required"`
Protocol string `json:"protocol" binding:"required,oneof=http https socks5"`
@@ -64,12 +64,10 @@ type UpdateProxyConfigRequest struct {
Description *string `json:"description"`
}
// Request body for a single proxy check (aligned with the frontend JS)
type CheckSingleProxyRequest struct {
Proxy string `json:"proxy" binding:"required"`
}
// Request body for batch proxy checks
type CheckAllProxiesRequest struct {
Proxies []string `json:"proxies" binding:"required"`
}
@@ -84,7 +82,7 @@ func (h *handler) CreateProxyConfig(c *gin.Context) {
}
if req.Status == "" {
req.Status = "active" // default status
req.Status = "active"
}
proxyConfig := models.ProxyConfig{
@@ -98,7 +96,6 @@ func (h *handler) CreateProxyConfig(c *gin.Context) {
response.Error(c, errors.ParseDBError(err))
return
}
// After a write, publish the event and invalidate the cache
h.publishAndInvalidate(proxyConfig.ID, "created")
response.Created(c, proxyConfig)
}
@@ -199,17 +196,16 @@ func (h *handler) DeleteProxyConfig(c *gin.Context) {
response.NoContent(c)
}
// publishAndInvalidate unifies event publishing and cache invalidation
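// Both the cache invalidation and the event publish run in their own goroutines so the
// HTTP response is never blocked on them.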
func (h *handler) publishAndInvalidate(proxyID uint, action string) {
go h.manager.invalidate()
go func() {
ctx := context.Background()
event := models.ProxyStatusChangedEvent{ProxyID: proxyID, Action: action}
eventData, _ := json.Marshal(event)
_ = h.store.Publish(models.TopicProxyStatusChanged, eventData)
_ = h.store.Publish(ctx, models.TopicProxyStatusChanged, eventData)
}()
}
// New handler methods and DTOs
type SyncProxiesRequest struct {
Proxies []string `json:"proxies"`
}
@@ -220,14 +216,12 @@ func (h *handler) SyncProxies(c *gin.Context) {
response.Error(c, errors.NewAPIError(errors.ErrInvalidJSON, err.Error()))
return
}
taskStatus, err := h.manager.SyncProxiesInBackground(req.Proxies)
taskStatus, err := h.manager.SyncProxiesInBackground(c.Request.Context(), req.Proxies)
if err != nil {
if errors.Is(err, ErrTaskConflict) {
response.Error(c, errors.NewAPIError(errors.ErrTaskInProgress, err.Error()))
} else {
response.Error(c, errors.NewAPIError(errors.ErrInternalServer, err.Error()))
}
return
@@ -262,7 +256,7 @@ func (h *handler) CheckAllProxies(c *gin.Context) {
concurrency := cfg.ProxyCheckConcurrency
if concurrency <= 0 {
concurrency = 5 // safe default when the configured value is invalid
concurrency = 5
}
results := h.manager.CheckMultipleProxies(req.Proxies, timeout, concurrency)
response.Success(c, results)

View File

@@ -2,14 +2,13 @@
package proxy
import (
"context"
"encoding/json"
"fmt"
"gemini-balancer/internal/models"
"gemini-balancer/internal/store"
"gemini-balancer/internal/syncer"
"gemini-balancer/internal/task"
"context"
"net"
"net/http"
"net/url"
@@ -25,7 +24,7 @@ import (
const (
TaskTypeProxySync = "proxy_sync"
proxyChunkSize = 200 // batch size for proxy sync
proxyChunkSize = 200
)
type ProxyCheckResult struct {
@@ -35,13 +34,11 @@ type ProxyCheckResult struct {
ErrorMessage string `json:"error_message"`
}
// managerCacheData
type managerCacheData struct {
ActiveProxies []*models.ProxyConfig
ProxiesByID map[uint]*models.ProxyConfig
}
// manager struct
type manager struct {
db *gorm.DB
syncer *syncer.CacheSyncer[managerCacheData]
@@ -80,21 +77,21 @@ func newManager(db *gorm.DB, syncer *syncer.CacheSyncer[managerCacheData], taskR
}
}
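// SyncProxiesInBackground reconciles the stored proxies with the supplied list: it
// registers a sync task, then diffs, deletes, and adds proxies in a background goroutine
// and publishes a change event when the task finishes.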
func (m *manager) SyncProxiesInBackground(proxyStrings []string) (*task.Status, error) {
func (m *manager) SyncProxiesInBackground(ctx context.Context, proxyStrings []string) (*task.Status, error) {
resourceID := "global_proxy_sync"
taskStatus, err := m.task.StartTask(0, TaskTypeProxySync, resourceID, len(proxyStrings), 0)
taskStatus, err := m.task.StartTask(ctx, 0, TaskTypeProxySync, resourceID, len(proxyStrings), 0)
if err != nil {
return nil, ErrTaskConflict
}
go m.runProxySyncTask(taskStatus.ID, proxyStrings)
go m.runProxySyncTask(context.Background(), taskStatus.ID, proxyStrings)
return taskStatus, nil
}
func (m *manager) runProxySyncTask(taskID string, finalProxyStrings []string) {
func (m *manager) runProxySyncTask(ctx context.Context, taskID string, finalProxyStrings []string) {
resourceID := "global_proxy_sync"
var allProxies []models.ProxyConfig
if err := m.db.Find(&allProxies).Error; err != nil {
m.task.EndTaskByID(taskID, resourceID, nil, fmt.Errorf("failed to fetch current proxies: %w", err))
m.task.EndTaskByID(ctx, taskID, resourceID, nil, fmt.Errorf("failed to fetch current proxies: %w", err))
return
}
currentProxyMap := make(map[string]uint)
@@ -125,19 +122,19 @@ func (m *manager) runProxySyncTask(taskID string, finalProxyStrings []string) {
}
if len(idsToDelete) > 0 {
if err := m.bulkDeleteByIDs(idsToDelete); err != nil {
m.task.EndTaskByID(taskID, resourceID, nil, fmt.Errorf("failed during proxy deletion: %w", err))
m.task.EndTaskByID(ctx, taskID, resourceID, nil, fmt.Errorf("failed during proxy deletion: %w", err))
return
}
}
if len(proxiesToAdd) > 0 {
if err := m.bulkAdd(proxiesToAdd); err != nil {
m.task.EndTaskByID(taskID, resourceID, nil, fmt.Errorf("failed during proxy addition: %w", err))
m.task.EndTaskByID(ctx, taskID, resourceID, nil, fmt.Errorf("failed during proxy addition: %w", err))
return
}
}
result := gin.H{"added": len(proxiesToAdd), "deleted": len(idsToDelete), "final_total": len(finalProxyMap)}
m.task.EndTaskByID(taskID, resourceID, result, nil)
m.publishChangeEvent("proxies_synced")
m.task.EndTaskByID(ctx, taskID, resourceID, result, nil)
m.publishChangeEvent(ctx, "proxies_synced")
go m.invalidate()
}
@@ -184,14 +181,15 @@ func (m *manager) bulkDeleteByIDs(ids []uint) error {
}
return nil
}
func (m *manager) bulkAdd(proxies []models.ProxyConfig) error {
return m.db.CreateInBatches(proxies, proxyChunkSize).Error
}
func (m *manager) publishChangeEvent(reason string) {
func (m *manager) publishChangeEvent(ctx context.Context, reason string) {
event := models.ProxyStatusChangedEvent{Action: reason}
eventData, _ := json.Marshal(event)
_ = m.store.Publish(models.TopicProxyStatusChanged, eventData)
_ = m.store.Publish(ctx, models.TopicProxyStatusChanged, eventData)
}
func (m *manager) assignProxyIfNeeded(apiKey *models.APIKey) (*models.ProxyConfig, error) {
@@ -313,3 +311,8 @@ func (m *manager) checkProxyConnectivity(proxyCfg *models.ProxyConfig, timeout t
defer resp.Body.Close()
return true
}
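// Manager is the narrow surface intended for other packages; only the methods needed
// externally are listed here.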
type Manager interface {
AssignProxyIfNeeded(apiKey *models.APIKey) (*models.ProxyConfig, error)
// ... other methods exposed to external consumers
}

View File

@@ -44,6 +44,7 @@ var (
ErrGroupNotFound = &APIError{HTTPStatus: http.StatusNotFound, Code: "GROUP_NOT_FOUND", Message: "The specified group was not found."}
ErrPermissionDenied = &APIError{HTTPStatus: http.StatusForbidden, Code: "PERMISSION_DENIED", Message: "Permission denied for this operation."}
ErrConfigurationError = &APIError{HTTPStatus: http.StatusInternalServerError, Code: "CONFIGURATION_ERROR", Message: "A configuration error prevents this request from being processed."}
ErrProxyNotAvailable = &APIError{HTTPStatus: http.StatusNotFound, Code: "PROXY_ERROR", Message: "Required proxy is not available for this request."}
ErrStateConflictMasterRevoked = &APIError{HTTPStatus: http.StatusConflict, Code: "STATE_CONFLICT_MASTER_REVOKED", Message: "Cannot perform this operation on a revoked key."}
ErrNotFound = &APIError{HTTPStatus: http.StatusNotFound, Code: "NOT_FOUND", Message: "Resource not found"}

View File

@@ -1,4 +1,4 @@
// Filename: internal/handlers/apikey_handler.go
// Filename: internal/handlers/apikey_handler.go (final version)
package handlers
import (
@@ -31,11 +31,10 @@ func NewAPIKeyHandler(apiKeyService *service.APIKeyService, db *gorm.DB, keyImpo
}
}
// DTOs for API requests
type BulkAddKeysToGroupRequest struct {
KeyGroupID uint `json:"key_group_id" binding:"required"`
Keys string `json:"keys" binding:"required"`
ValidateOnImport bool `json:"validate_on_import"` // OmitEmpty/default is false
ValidateOnImport bool `json:"validate_on_import"`
}
type BulkUnlinkKeysFromGroupRequest struct {
@@ -72,11 +71,11 @@ type BulkTestKeysForGroupRequest struct {
}
type BulkActionFilter struct {
Status []string `json:"status"` // Changed to slice to accept multiple statuses
Status []string `json:"status"`
}
type BulkActionRequest struct {
Action string `json:"action" binding:"required,oneof=revalidate set_status delete"`
NewStatus string `json:"new_status" binding:"omitempty,oneof=active disabled cooldown banned"` // For 'set_status' action
NewStatus string `json:"new_status" binding:"omitempty,oneof=active disabled cooldown banned"`
Filter BulkActionFilter `json:"filter" binding:"required"`
}
@@ -89,7 +88,8 @@ func (h *APIKeyHandler) AddMultipleKeysToGroup(c *gin.Context) {
response.Error(c, errors.NewAPIError(errors.ErrInvalidJSON, err.Error()))
return
}
taskStatus, err := h.keyImportService.StartAddKeysTask(req.KeyGroupID, req.Keys, req.ValidateOnImport)
// [Fix] Pass the request context down to the service layer
taskStatus, err := h.keyImportService.StartAddKeysTask(c.Request.Context(), req.KeyGroupID, req.Keys, req.ValidateOnImport)
if err != nil {
response.Error(c, errors.NewAPIError(errors.ErrTaskInProgress, err.Error()))
return
@@ -104,7 +104,8 @@ func (h *APIKeyHandler) UnlinkMultipleKeysFromGroup(c *gin.Context) {
response.Error(c, errors.NewAPIError(errors.ErrInvalidJSON, err.Error()))
return
}
taskStatus, err := h.keyImportService.StartUnlinkKeysTask(req.KeyGroupID, req.Keys)
// [Fix] Pass the request context down to the service layer
taskStatus, err := h.keyImportService.StartUnlinkKeysTask(c.Request.Context(), req.KeyGroupID, req.Keys)
if err != nil {
response.Error(c, errors.NewAPIError(errors.ErrTaskInProgress, err.Error()))
return
@@ -119,7 +120,8 @@ func (h *APIKeyHandler) HardDeleteMultipleKeys(c *gin.Context) {
response.Error(c, errors.NewAPIError(errors.ErrInvalidJSON, err.Error()))
return
}
taskStatus, err := h.keyImportService.StartHardDeleteKeysTask(req.Keys)
// [Fix] Pass the request context down to the service layer
taskStatus, err := h.keyImportService.StartHardDeleteKeysTask(c.Request.Context(), req.Keys)
if err != nil {
response.Error(c, errors.NewAPIError(errors.ErrTaskInProgress, err.Error()))
return
@@ -134,7 +136,8 @@ func (h *APIKeyHandler) RestoreMultipleKeys(c *gin.Context) {
response.Error(c, errors.NewAPIError(errors.ErrInvalidJSON, err.Error()))
return
}
taskStatus, err := h.keyImportService.StartRestoreKeysTask(req.Keys)
// [Fix] Pass the request context down to the service layer
taskStatus, err := h.keyImportService.StartRestoreKeysTask(c.Request.Context(), req.Keys)
if err != nil {
response.Error(c, errors.NewAPIError(errors.ErrTaskInProgress, err.Error()))
return
@@ -148,7 +151,8 @@ func (h *APIKeyHandler) TestMultipleKeys(c *gin.Context) {
response.Error(c, errors.NewAPIError(errors.ErrInvalidJSON, err.Error()))
return
}
taskStatus, err := h.keyValidationService.StartTestKeysTask(req.KeyGroupID, req.Keys)
// [Fix] Pass the request context down to the service layer
taskStatus, err := h.keyValidationService.StartTestKeysTask(c.Request.Context(), req.KeyGroupID, req.Keys)
if err != nil {
response.Error(c, errors.NewAPIError(errors.ErrTaskInProgress, err.Error()))
return
@@ -172,7 +176,7 @@ func (h *APIKeyHandler) ListAPIKeys(c *gin.Context) {
}
}
if len(ids) > 0 {
keys, err := h.apiKeyService.GetKeysByIds(ids)
keys, err := h.apiKeyService.GetKeysByIds(c.Request.Context(), ids)
if err != nil {
response.Error(c, &errors.APIError{
HTTPStatus: http.StatusInternalServerError,
@@ -191,7 +195,7 @@ func (h *APIKeyHandler) ListAPIKeys(c *gin.Context) {
if params.PageSize <= 0 {
params.PageSize = 20
}
result, err := h.apiKeyService.ListAPIKeys(&params)
result, err := h.apiKeyService.ListAPIKeys(c.Request.Context(), &params)
if err != nil {
response.Error(c, errors.ParseDBError(err))
return
@@ -201,19 +205,16 @@ func (h *APIKeyHandler) ListAPIKeys(c *gin.Context) {
// ListKeysForGroup handles the GET /keygroups/:id/keys request.
func (h *APIKeyHandler) ListKeysForGroup(c *gin.Context) {
// 1. Manually handle the path parameter.
groupID, err := strconv.ParseUint(c.Param("id"), 10, 32)
if err != nil {
response.Error(c, errors.NewAPIError(errors.ErrBadRequest, "Invalid group ID format"))
return
}
// 2. Bind query parameters using the correctly tagged struct.
var params models.APIKeyQueryParams
if err := c.ShouldBindQuery(&params); err != nil {
response.Error(c, errors.NewAPIError(errors.ErrBadRequest, err.Error()))
return
}
// 3. Set server-side defaults and the path parameter.
if params.Page <= 0 {
params.Page = 1
}
@@ -221,15 +222,11 @@ func (h *APIKeyHandler) ListKeysForGroup(c *gin.Context) {
params.PageSize = 20
}
params.KeyGroupID = uint(groupID)
// 4. Call the service layer.
paginatedResult, err := h.apiKeyService.ListAPIKeys(&params)
paginatedResult, err := h.apiKeyService.ListAPIKeys(c.Request.Context(), &params)
if err != nil {
response.Error(c, errors.ParseDBError(err))
return
}
// 5. [THE FIX] Return a successful response using the standard `response.Success`
// and a gin.H map, as confirmed to exist in your project.
response.Success(c, gin.H{
"items": paginatedResult.Items,
"total": paginatedResult.Total,
@@ -239,20 +236,18 @@ func (h *APIKeyHandler) ListKeysForGroup(c *gin.Context) {
}
func (h *APIKeyHandler) TestKeysForGroup(c *gin.Context) {
// Group ID is now correctly sourced from the URL path.
groupID, err := strconv.ParseUint(c.Param("id"), 10, 32)
if err != nil {
response.Error(c, errors.NewAPIError(errors.ErrBadRequest, "Invalid group ID format"))
return
}
// The request body is now simpler, only needing the keys.
var req BulkTestKeysForGroupRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.Error(c, errors.NewAPIError(errors.ErrInvalidJSON, err.Error()))
return
}
// Call the same underlying service, but with unambiguous context.
taskStatus, err := h.keyValidationService.StartTestKeysTask(uint(groupID), req.Keys)
// [Fix] Pass the request context down to the service layer
taskStatus, err := h.keyValidationService.StartTestKeysTask(c.Request.Context(), uint(groupID), req.Keys)
if err != nil {
response.Error(c, errors.NewAPIError(errors.ErrTaskInProgress, err.Error()))
return
@@ -267,7 +262,6 @@ func (h *APIKeyHandler) UpdateAPIKey(c *gin.Context) {
}
// UpdateGroupAPIKeyMapping handles updating a key's status within a specific group.
// Route: PUT /keygroups/:id/apikeys/:keyId
func (h *APIKeyHandler) UpdateGroupAPIKeyMapping(c *gin.Context) {
groupID, err := strconv.ParseUint(c.Param("id"), 10, 32)
if err != nil {
@@ -284,8 +278,7 @@ func (h *APIKeyHandler) UpdateGroupAPIKeyMapping(c *gin.Context) {
response.Error(c, errors.NewAPIError(errors.ErrInvalidJSON, err.Error()))
return
}
// Directly use the service to handle the logic
updatedMapping, err := h.apiKeyService.UpdateMappingStatus(uint(groupID), uint(keyID), req.Status)
updatedMapping, err := h.apiKeyService.UpdateMappingStatus(c.Request.Context(), uint(groupID), uint(keyID), req.Status)
if err != nil {
var apiErr *errors.APIError
if errors.As(err, &apiErr) {
@@ -305,7 +298,7 @@ func (h *APIKeyHandler) HardDeleteAPIKey(c *gin.Context) {
response.Error(c, errors.NewAPIError(errors.ErrBadRequest, "Invalid ID format"))
return
}
if err := h.apiKeyService.HardDeleteAPIKeyByID(uint(id)); err != nil {
if err := h.apiKeyService.HardDeleteAPIKeyByID(c.Request.Context(), uint(id)); err != nil {
response.Error(c, errors.ParseDBError(err))
return
}
@@ -313,7 +306,6 @@ func (h *APIKeyHandler) HardDeleteAPIKey(c *gin.Context) {
}
// RestoreKeysInGroup is the endpoint for restoring the specified keys
// POST /keygroups/:id/apikeys/restore
func (h *APIKeyHandler) RestoreKeysInGroup(c *gin.Context) {
groupID, err := strconv.ParseUint(c.Param("id"), 10, 32)
if err != nil {
@@ -325,7 +317,7 @@ func (h *APIKeyHandler) RestoreKeysInGroup(c *gin.Context) {
response.Error(c, errors.NewAPIError(errors.ErrInvalidJSON, err.Error()))
return
}
taskStatus, err := h.apiKeyService.StartRestoreKeysTask(uint(groupID), req.KeyIDs)
taskStatus, err := h.apiKeyService.StartRestoreKeysTask(c.Request.Context(), uint(groupID), req.KeyIDs)
if err != nil {
var apiErr *errors.APIError
if errors.As(err, &apiErr) {
@@ -339,14 +331,13 @@ func (h *APIKeyHandler) RestoreKeysInGroup(c *gin.Context) {
}
// RestoreAllBannedInGroup is the endpoint for restoring all banned keys in one action
// POST /keygroups/:id/apikeys/restore-all-banned
func (h *APIKeyHandler) RestoreAllBannedInGroup(c *gin.Context) {
groupID, err := strconv.ParseUint(c.Param("groupId"), 10, 32)
if err != nil {
response.Error(c, errors.NewAPIError(errors.ErrBadRequest, "Invalid Group ID format"))
return
}
taskStatus, err := h.apiKeyService.StartRestoreAllBannedTask(uint(groupID))
taskStatus, err := h.apiKeyService.StartRestoreAllBannedTask(c.Request.Context(), uint(groupID))
if err != nil {
var apiErr *errors.APIError
if errors.As(err, &apiErr) {
@@ -360,48 +351,41 @@ func (h *APIKeyHandler) RestoreAllBannedInGroup(c *gin.Context) {
}
// HandleBulkAction handles generic bulk actions on a key group based on server-side filters.
// Route: POST /keygroups/:id/bulk-actions
func (h *APIKeyHandler) HandleBulkAction(c *gin.Context) {
// 1. Parse GroupID from URL
groupID, err := strconv.ParseUint(c.Param("id"), 10, 32)
if err != nil {
response.Error(c, errors.NewAPIError(errors.ErrBadRequest, "Invalid Group ID format"))
return
}
// 2. Bind the JSON payload to our new DTO
var req BulkActionRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.Error(c, errors.NewAPIError(errors.ErrInvalidJSON, err.Error()))
return
}
// 3. Central logic: based on the action, call the appropriate service method.
var task *task.Status
var apiErr *errors.APIError
switch req.Action {
case "revalidate":
// Assume keyValidationService has a method that accepts a filter
task, err = h.keyValidationService.StartTestKeysByFilterTask(uint(groupID), req.Filter.Status)
// [Fix] Pass the request context down to the service layer
task, err = h.keyValidationService.StartTestKeysByFilterTask(c.Request.Context(), uint(groupID), req.Filter.Status)
case "set_status":
if req.NewStatus == "" {
apiErr = errors.NewAPIError(errors.ErrBadRequest, "new_status is required for set_status action")
break
}
// Assume apiKeyService has a method to update status by filter
targetStatus := models.APIKeyStatus(req.NewStatus) // Convert string to your model's type
task, err = h.apiKeyService.StartUpdateStatusByFilterTask(uint(groupID), req.Filter.Status, targetStatus)
targetStatus := models.APIKeyStatus(req.NewStatus)
task, err = h.apiKeyService.StartUpdateStatusByFilterTask(c.Request.Context(), uint(groupID), req.Filter.Status, targetStatus)
case "delete":
// Assume keyImportService has a method to unlink by filter
task, err = h.keyImportService.StartUnlinkKeysByFilterTask(uint(groupID), req.Filter.Status)
// [Fix] Pass the request context down to the service layer
task, err = h.keyImportService.StartUnlinkKeysByFilterTask(c.Request.Context(), uint(groupID), req.Filter.Status)
default:
apiErr = errors.NewAPIError(errors.ErrBadRequest, "Unsupported action: "+req.Action)
}
// 4. Handle errors from the switch block
if apiErr != nil {
response.Error(c, apiErr)
return
}
if err != nil {
// Attempt to parse it as a known APIError, otherwise, wrap it.
var parsedErr *errors.APIError
if errors.As(err, &parsedErr) {
response.Error(c, parsedErr)
@@ -410,21 +394,18 @@ func (h *APIKeyHandler) HandleBulkAction(c *gin.Context) {
}
return
}
// 5. Return the task status on success
response.Success(c, task)
}
// ExportKeysForGroup handles requests to export all keys for a group based on status filters.
// Route: GET /keygroups/:id/apikeys/export
func (h *APIKeyHandler) ExportKeysForGroup(c *gin.Context) {
groupID, err := strconv.ParseUint(c.Param("id"), 10, 32)
if err != nil {
response.Error(c, errors.NewAPIError(errors.ErrBadRequest, "Invalid Group ID format"))
return
}
// Use QueryArray to correctly parse `status[]=active&status[]=cooldown`
statuses := c.QueryArray("status")
keyStrings, err := h.apiKeyService.GetAPIKeyStringsForExport(uint(groupID), statuses)
keyStrings, err := h.apiKeyService.GetAPIKeyStringsForExport(c.Request.Context(), uint(groupID), statuses)
if err != nil {
response.Error(c, errors.ParseDBError(err))
return

View File

@@ -30,7 +30,7 @@ func (h *DashboardHandler) GetOverview(c *gin.Context) {
c.JSON(http.StatusOK, stats)
}
// GetChart returns chart data for the dashboard
// GetChart
func (h *DashboardHandler) GetChart(c *gin.Context) {
var groupID *uint
if groupIDStr := c.Query("groupId"); groupIDStr != "" {
@@ -40,7 +40,7 @@ func (h *DashboardHandler) GetChart(c *gin.Context) {
}
}
chartData, err := h.queryService.QueryHistoricalChart(groupID)
chartData, err := h.queryService.QueryHistoricalChart(c.Request.Context(), groupID)
if err != nil {
apiErr := errors.NewAPIError(errors.ErrDatabase, err.Error())
c.JSON(apiErr.HTTPStatus, apiErr)
@@ -49,10 +49,10 @@ func (h *DashboardHandler) GetChart(c *gin.Context) {
c.JSON(http.StatusOK, chartData)
}
// GetRequestStats handles requests for the "period call overview"
// GetRequestStats
func (h *DashboardHandler) GetRequestStats(c *gin.Context) {
period := c.Param("period") // read period from the URL path
stats, err := h.queryService.GetRequestStatsForPeriod(period)
period := c.Param("period")
stats, err := h.queryService.GetRequestStatsForPeriod(c.Request.Context(), period)
if err != nil {
apiErr := errors.NewAPIError(errors.ErrBadRequest, err.Error())
c.JSON(apiErr.HTTPStatus, apiErr)

View File

@@ -2,6 +2,7 @@
package handlers
import (
"context"
"encoding/json"
"fmt"
"gemini-balancer/internal/errors"
@@ -31,7 +32,6 @@ func NewKeyGroupHandler(gm *service.GroupManager, s store.Store, qs *service.Das
}
}
// DTOs & helper functions
func isValidGroupName(name string) bool {
if name == "" {
return false
@@ -40,7 +40,6 @@ func isValidGroupName(name string) bool {
return match
}
// KeyGroupOperationalSettings defines the shared operational settings for a key group.
type KeyGroupOperationalSettings struct {
EnableKeyCheck *bool `json:"enable_key_check"`
KeyCheckIntervalMinutes *int `json:"key_check_interval_minutes"`
@@ -52,7 +51,6 @@ type KeyGroupOperationalSettings struct {
MaxRetries *int `json:"max_retries"`
EnableSmartGateway *bool `json:"enable_smart_gateway"`
}
type CreateKeyGroupRequest struct {
Name string `json:"name" binding:"required"`
DisplayName string `json:"display_name"`
@@ -60,11 +58,8 @@ type CreateKeyGroupRequest struct {
PollingStrategy string `json:"polling_strategy" binding:"omitempty,oneof=sequential random weighted"`
EnableProxy bool `json:"enable_proxy"`
ChannelType string `json:"channel_type"`
// Embed shared operational settings
KeyGroupOperationalSettings
}
type UpdateKeyGroupRequest struct {
Name *string `json:"name"`
DisplayName *string `json:"display_name"`
@@ -72,15 +67,10 @@ type UpdateKeyGroupRequest struct {
PollingStrategy *string `json:"polling_strategy" binding:"omitempty,oneof=sequential random weighted"`
EnableProxy *bool `json:"enable_proxy"`
ChannelType *string `json:"channel_type"`
// Embed shared operational settings
KeyGroupOperationalSettings
// M:N associations
AllowedUpstreams []string `json:"allowed_upstreams"`
AllowedModels []string `json:"allowed_models"`
}
type KeyGroupResponse struct {
ID uint `json:"id"`
Name string `json:"name"`
@@ -96,36 +86,30 @@ type KeyGroupResponse struct {
AllowedModels []string `json:"allowed_models"`
AllowedUpstreams []string `json:"allowed_upstreams"`
}
// [NEW] Define the detailed response structure for a single group.
type KeyGroupDetailsResponse struct {
KeyGroupResponse
Settings *models.GroupSettings `json:"settings,omitempty"`
RequestConfig *models.RequestConfig `json:"request_config,omitempty"`
}
// transformModelsToStrings converts a slice of GroupModelMapping pointers to a slice of model names.
func transformModelsToStrings(mappings []*models.GroupModelMapping) []string {
modelNames := make([]string, 0, len(mappings))
for _, mapping := range mappings {
if mapping != nil { // Safety check
if mapping != nil {
modelNames = append(modelNames, mapping.ModelName)
}
}
return modelNames
}
// transformUpstreamsToStrings converts a slice of UpstreamEndpoint pointers to a slice of URLs.
func transformUpstreamsToStrings(upstreams []*models.UpstreamEndpoint) []string {
urls := make([]string, 0, len(upstreams))
for _, upstream := range upstreams {
if upstream != nil { // Safety check
if upstream != nil {
urls = append(urls, upstream.URL)
}
}
return urls
}
func (h *KeyGroupHandler) newKeyGroupResponse(group *models.KeyGroup, keyCount int64) KeyGroupResponse {
return KeyGroupResponse{
ID: group.ID,
@@ -139,13 +123,10 @@ func (h *KeyGroupHandler) newKeyGroupResponse(group *models.KeyGroup, keyCount i
CreatedAt: group.CreatedAt,
UpdatedAt: group.UpdatedAt,
Order: group.Order,
AllowedModels: transformModelsToStrings(group.AllowedModels), // Call the new helper
AllowedUpstreams: transformUpstreamsToStrings(group.AllowedUpstreams), // Call the new helper
AllowedModels: transformModelsToStrings(group.AllowedModels),
AllowedUpstreams: transformUpstreamsToStrings(group.AllowedUpstreams),
}
}
// packGroupSettings is a helper to convert request-level operational settings
// into the model-level settings struct.
func packGroupSettings(settings KeyGroupOperationalSettings) *models.KeyGroupSettings {
return &models.KeyGroupSettings{
EnableKeyCheck: settings.EnableKeyCheck,
@@ -159,7 +140,6 @@ func packGroupSettings(settings KeyGroupOperationalSettings) *models.KeyGroupSet
EnableSmartGateway: settings.EnableSmartGateway,
}
}
func (h *KeyGroupHandler) getGroupFromContext(c *gin.Context) (*models.KeyGroup, *errors.APIError) {
id, err := strconv.Atoi(c.Param("id"))
if err != nil {
@@ -171,7 +151,6 @@ func (h *KeyGroupHandler) getGroupFromContext(c *gin.Context) (*models.KeyGroup,
}
return group, nil
}
func applyUpdateRequestToGroup(req *UpdateKeyGroupRequest, group *models.KeyGroup) {
if req.Name != nil {
group.Name = *req.Name
@@ -197,9 +176,10 @@ func applyUpdateRequestToGroup(req *UpdateKeyGroupRequest, group *models.KeyGrou
// publishGroupChangeEvent encapsulates the logic for marshaling and publishing a group change event.
func (h *KeyGroupHandler) publishGroupChangeEvent(groupID uint, reason string) {
go func() {
ctx := context.Background()
event := models.KeyStatusChangedEvent{GroupID: groupID, ChangeReason: reason}
eventData, _ := json.Marshal(event)
h.store.Publish(models.TopicKeyStatusChanged, eventData)
_ = h.store.Publish(ctx, models.TopicKeyStatusChanged, eventData)
}()
}
@@ -216,7 +196,6 @@ func (h *KeyGroupHandler) CreateKeyGroup(c *gin.Context) {
return
}
// The core logic remains, as it's specific to creation.
p := bluemonday.StripTagsPolicy()
sanitizedDisplayName := p.Sanitize(req.DisplayName)
sanitizedDescription := p.Sanitize(req.Description)
@@ -244,11 +223,9 @@ func (h *KeyGroupHandler) CreateKeyGroup(c *gin.Context) {
response.Created(c, h.newKeyGroupResponse(keyGroup, 0))
}
// A single handler covers both cases:
// 1. GET /keygroups - returns the list of all groups
// 2. GET /keygroups/:id - returns the single group with the given ID
func (h *KeyGroupHandler) GetKeyGroups(c *gin.Context) {
// Case 1: Get a single group
if idStr := c.Param("id"); idStr != "" {
group, apiErr := h.getGroupFromContext(c)
if apiErr != nil {
@@ -265,7 +242,6 @@ func (h *KeyGroupHandler) GetKeyGroups(c *gin.Context) {
response.Success(c, detailedResponse)
return
}
// Case 2: Get all groups
allGroups := h.groupManager.GetAllGroups()
responses := make([]KeyGroupResponse, 0, len(allGroups))
for _, group := range allGroups {
@@ -275,7 +251,6 @@ func (h *KeyGroupHandler) GetKeyGroups(c *gin.Context) {
response.Success(c, responses)
}
// UpdateKeyGroup
func (h *KeyGroupHandler) UpdateKeyGroup(c *gin.Context) {
group, apiErr := h.getGroupFromContext(c)
if apiErr != nil {
@@ -304,7 +279,6 @@ func (h *KeyGroupHandler) UpdateKeyGroup(c *gin.Context) {
response.Success(c, h.newKeyGroupResponse(freshGroup, keyCount))
}
// DeleteKeyGroup
func (h *KeyGroupHandler) DeleteKeyGroup(c *gin.Context) {
group, apiErr := h.getGroupFromContext(c)
if apiErr != nil {
@@ -320,14 +294,14 @@ func (h *KeyGroupHandler) DeleteKeyGroup(c *gin.Context) {
response.Success(c, gin.H{"message": fmt.Sprintf("Group '%s' and its associated keys deleted successfully", groupName)})
}
// GetKeyGroupStats
func (h *KeyGroupHandler) GetKeyGroupStats(c *gin.Context) {
group, apiErr := h.getGroupFromContext(c)
if apiErr != nil {
response.Error(c, apiErr)
return
}
stats, err := h.queryService.GetGroupStats(group.ID)
stats, err := h.queryService.GetGroupStats(c.Request.Context(), group.ID)
if err != nil {
response.Error(c, errors.NewAPIError(errors.ErrDatabase, err.Error()))
return
@@ -350,7 +324,6 @@ func (h *KeyGroupHandler) CloneKeyGroup(c *gin.Context) {
response.Created(c, h.newKeyGroupResponse(clonedGroup, keyCount))
}
// Update group ordering
func (h *KeyGroupHandler) UpdateKeyGroupOrder(c *gin.Context) {
var payload []service.UpdateOrderPayload
if err := c.ShouldBindJSON(&payload); err != nil {

View File

@@ -29,9 +29,7 @@ import (
"gorm.io/datatypes"
)
type proxyErrorKey int
const proxyErrKey proxyErrorKey = 0
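// An unexported struct type as the context key cannot collide with keys defined by other packages.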
type proxyErrorContextKey struct{}
type ProxyHandler struct {
resourceService *service.ResourceService
@@ -81,45 +79,51 @@ func (h *ProxyHandler) HandleProxy(c *gin.Context) {
h.handleListModelsRequest(c)
return
}
requestBody, err := io.ReadAll(c.Request.Body)
maxBodySize := int64(h.settingsManager.GetSettings().MaxRequestBodySizeMB * 1024 * 1024)
requestBody, err := io.ReadAll(io.LimitReader(c.Request.Body, maxBodySize))
if err != nil {
errToJSON(c, uuid.New().String(), errors.NewAPIError(errors.ErrBadRequest, "Failed to read request body"))
errToJSON(c, uuid.New().String(), errors.NewAPIError(errors.ErrBadRequest, "Request body too large or failed to read"))
return
}
c.Request.Body = io.NopCloser(bytes.NewReader(requestBody))
c.Request.ContentLength = int64(len(requestBody))
modelName := h.channel.ExtractModel(c, requestBody)
groupName := c.Param("group_name")
isPreciseRouting := groupName != ""
if !isPreciseRouting && modelName == "" {
errToJSON(c, uuid.New().String(), errors.NewAPIError(errors.ErrBadRequest, "Model not specified in the request body or URL"))
errToJSON(c, uuid.New().String(), errors.NewAPIError(errors.ErrBadRequest, "Model not specified in request"))
return
}
initialResources, err := h.getResourcesForRequest(c, modelName, groupName, isPreciseRouting)
if err != nil {
if apiErr, ok := err.(*errors.APIError); ok {
errToJSON(c, uuid.New().String(), apiErr)
} else {
errToJSON(c, uuid.New().String(), errors.NewAPIError(errors.ErrNoKeysAvailable, err.Error()))
errToJSON(c, uuid.New().String(), errors.NewAPIError(errors.ErrInternalServer, "Failed to get initial resources"))
}
return
}
finalOpConfig, err := h.groupManager.BuildOperationalConfig(initialResources.KeyGroup)
if err != nil {
h.logger.WithError(err).Error("Failed to build operational config.")
errToJSON(c, uuid.New().String(), errors.NewAPIError(errors.ErrInternalServer, "Failed to build operational configuration"))
errToJSON(c, uuid.New().String(), errors.NewAPIError(errors.ErrInternalServer, "Failed to build operational config"))
return
}
initialResources.RequestConfig = h.buildFinalRequestConfig(h.settingsManager.GetSettings(), initialResources.RequestConfig)
isOpenAICompatible := h.channel.IsOpenAICompatibleRequest(c)
if isOpenAICompatible {
h.serveTransparentProxy(c, requestBody, initialResources, finalOpConfig, modelName, groupName, isPreciseRouting)
return
}
isStream := h.channel.IsStreamRequest(c, requestBody)
systemSettings := h.settingsManager.GetSettings()
useSmartGateway := finalOpConfig.EnableSmartGateway != nil && *finalOpConfig.EnableSmartGateway
if useSmartGateway && isStream && systemSettings.EnableStreamingRetry {
if useSmartGateway && isStream && h.settingsManager.GetSettings().EnableStreamingRetry {
h.serveSmartStream(c, requestBody, initialResources, isPreciseRouting)
} else {
h.serveTransparentProxy(c, requestBody, initialResources, finalOpConfig, modelName, groupName, isPreciseRouting)
@@ -129,226 +133,307 @@ func (h *ProxyHandler) HandleProxy(c *gin.Context) {
func (h *ProxyHandler) serveTransparentProxy(c *gin.Context, requestBody []byte, initialResources *service.RequestResources, finalOpConfig *models.KeyGroupSettings, modelName, groupName string, isPreciseRouting bool) {
startTime := time.Now()
correlationID := uuid.New().String()
var finalRecorder *httptest.ResponseRecorder
var lastUsedResources *service.RequestResources
var finalProxyErr *errors.APIError
var isSuccess bool
var finalPromptTokens, finalCompletionTokens int
var actualRetries int = 0
defer func() {
// If no attempt ever ran (e.g. fetching resources failed on the first try), skip the final log event
if lastUsedResources == nil {
h.logger.WithField("id", correlationID).Warn("No resources were used, skipping final log event.")
return
}
finalEvent := h.createLogEvent(c, startTime, correlationID, modelName, lastUsedResources, models.LogTypeFinal, isPreciseRouting)
var finalPromptTokens, finalCompletionTokens, actualRetries int
finalEvent.RequestLog.LatencyMs = int(time.Since(startTime).Milliseconds())
finalEvent.RequestLog.IsSuccess = isSuccess
finalEvent.RequestLog.Retries = actualRetries
if isSuccess {
finalEvent.RequestLog.PromptTokens = finalPromptTokens
finalEvent.RequestLog.CompletionTokens = finalCompletionTokens
}
defer h.publishFinalLogEvent(c, startTime, correlationID, modelName, lastUsedResources,
finalRecorder, finalProxyErr, isSuccess, finalPromptTokens, finalCompletionTokens,
actualRetries, isPreciseRouting)
// Record the final status code whenever a recorder exists, even on success
if finalRecorder != nil {
finalEvent.RequestLog.StatusCode = finalRecorder.Code
}
if !isSuccess {
// Copy finalProxyErr details into the RequestLog
if finalProxyErr != nil {
finalEvent.Error = finalProxyErr // the Error field is for event propagation only and is not serialized to the database
finalEvent.RequestLog.ErrorCode = finalProxyErr.Code
finalEvent.RequestLog.ErrorMessage = finalProxyErr.Message
} else if finalRecorder != nil {
// Degraded handling: finalProxyErr is nil but the recorder exists and indicates failure
apiErr := errors.NewAPIErrorWithUpstream(finalRecorder.Code, fmt.Sprintf("UPSTREAM_%d", finalRecorder.Code), "Request failed after all retries.")
finalEvent.Error = apiErr
finalEvent.RequestLog.ErrorCode = apiErr.Code
finalEvent.RequestLog.ErrorMessage = apiErr.Message
}
}
// Publish the complete event
eventData, err := json.Marshal(finalEvent)
if err != nil {
h.logger.WithField("id", correlationID).WithError(err).Error("Failed to marshal final log event.")
return
}
if err := h.store.Publish(models.TopicRequestFinished, eventData); err != nil {
h.logger.WithField("id", correlationID).WithError(err).Error("Failed to publish final log event.")
}
}()
var maxRetries int
if isPreciseRouting {
// For precise routing, use the group's setting. If not set, fall back to the global setting.
if finalOpConfig.MaxRetries != nil {
maxRetries = *finalOpConfig.MaxRetries
} else {
maxRetries = h.settingsManager.GetSettings().MaxRetries
}
} else {
// For BasePool (intelligent aggregation), *always* use the global setting.
maxRetries = h.settingsManager.GetSettings().MaxRetries
}
maxRetries := h.getMaxRetries(isPreciseRouting, finalOpConfig)
totalAttempts := maxRetries + 1
for attempt := 1; attempt <= totalAttempts; attempt++ {
if c.Request.Context().Err() != nil {
h.logger.WithField("id", correlationID).Info("Client disconnected, aborting retry loop.")
if finalProxyErr == nil {
finalProxyErr = errors.NewAPIError(errors.ErrBadRequest, "Client connection closed")
finalProxyErr = errors.NewAPIError(errors.ErrBadRequest, "Client disconnected")
}
break
}
var currentResources *service.RequestResources
var err error
if attempt == 1 {
currentResources = initialResources
} else {
actualRetries = attempt - 1
h.logger.WithField("id", correlationID).Infof("Retrying... getting new resources for attempt %d.", attempt)
currentResources, err = h.getResourcesForRequest(c, modelName, groupName, isPreciseRouting)
if err != nil {
h.logger.WithField("id", correlationID).Errorf("Failed to get new resources for retry, aborting: %v", err)
finalProxyErr = errors.NewAPIError(errors.ErrNoKeysAvailable, "Failed to get new resources for retry")
break
resources, err := h.getResourcesForAttempt(c, attempt, initialResources, modelName, groupName, isPreciseRouting, correlationID)
if err != nil {
if apiErr, ok := err.(*errors.APIError); ok {
finalProxyErr = apiErr
} else {
finalProxyErr = errors.NewAPIError(errors.ErrInternalServer, "Failed to get resources for retry")
}
break
}
lastUsedResources = resources
if attempt > 1 {
actualRetries = attempt - 1
}
finalRequestConfig := h.buildFinalRequestConfig(h.settingsManager.GetSettings(), currentResources.RequestConfig)
currentResources.RequestConfig = finalRequestConfig
lastUsedResources = currentResources
h.logger.WithField("id", correlationID).Infof("Attempt %d/%d with KeyID %d...", attempt, totalAttempts, currentResources.APIKey.ID)
var attemptErr *errors.APIError
var attemptIsSuccess bool
recorder := httptest.NewRecorder()
attemptStartTime := time.Now()
connectTimeout := time.Duration(h.settingsManager.GetSettings().ConnectTimeoutSeconds) * time.Second
ctx, cancel := context.WithTimeout(c.Request.Context(), connectTimeout)
defer cancel()
attemptReq := c.Request.Clone(ctx)
attemptReq.Body = io.NopCloser(bytes.NewReader(requestBody))
if currentResources.UpstreamEndpoint == nil || currentResources.UpstreamEndpoint.URL == "" {
h.logger.WithField("id", correlationID).Errorf("Attempt %d failed: no upstream URL in resources.", attempt)
isSuccess = false
finalProxyErr = errors.NewAPIError(errors.ErrInternalServer, "No upstream URL configured for the selected resource")
continue
}
h.transparentProxy.Director = func(req *http.Request) {
targetURL, _ := url.Parse(currentResources.UpstreamEndpoint.URL)
req.URL.Scheme = targetURL.Scheme
req.URL.Host = targetURL.Host
req.Host = targetURL.Host
var pureClientPath string
if isPreciseRouting {
proxyPrefix := "/proxy/" + groupName
pureClientPath = strings.TrimPrefix(req.URL.Path, proxyPrefix)
} else {
pureClientPath = req.URL.Path
}
finalPath := h.channel.RewritePath(targetURL.Path, pureClientPath)
req.URL.Path = finalPath
h.logger.WithFields(logrus.Fields{
"correlation_id": correlationID,
"attempt": attempt,
"key_id": currentResources.APIKey.ID,
"base_upstream_url": currentResources.UpstreamEndpoint.URL,
"final_request_url": req.URL.String(),
}).Infof("Director constructed final upstream request URL.")
req.Header.Del("Authorization")
h.channel.ModifyRequest(req, currentResources.APIKey)
req.Header.Set("X-Correlation-ID", correlationID)
*req = *req.WithContext(context.WithValue(req.Context(), proxyErrKey, &attemptErr))
}
transport := h.transparentProxy.Transport.(*http.Transport)
if currentResources.ProxyConfig != nil {
proxyURLStr := fmt.Sprintf("%s://%s", currentResources.ProxyConfig.Protocol, currentResources.ProxyConfig.Address)
proxyURL, err := url.Parse(proxyURLStr)
if err == nil {
transport.Proxy = http.ProxyURL(proxyURL)
}
} else {
transport.Proxy = http.ProxyFromEnvironment
}
h.transparentProxy.ModifyResponse = func(resp *http.Response) error {
defer resp.Body.Close()
var reader io.ReadCloser
var err error
isGzipped := resp.Header.Get("Content-Encoding") == "gzip"
if isGzipped {
reader, err = gzip.NewReader(resp.Body)
if err != nil {
h.logger.WithError(err).Error("Failed to create gzip reader")
reader = resp.Body
} else {
resp.Header.Del("Content-Encoding")
}
defer reader.Close()
} else {
reader = resp.Body
}
bodyBytes, err := io.ReadAll(reader)
if err != nil {
attemptErr = errors.NewAPIError(errors.ErrBadGateway, "Failed to read upstream response: "+err.Error())
resp.Body = io.NopCloser(bytes.NewReader([]byte(attemptErr.Message)))
return nil
}
if resp.StatusCode < 400 {
attemptIsSuccess = true
finalPromptTokens, finalCompletionTokens = extractUsage(bodyBytes)
} else {
parsedMsg := errors.ParseUpstreamError(bodyBytes)
attemptErr = errors.NewAPIErrorWithUpstream(resp.StatusCode, fmt.Sprintf("UPSTREAM_%d", resp.StatusCode), parsedMsg)
}
resp.Body = io.NopCloser(bytes.NewReader(bodyBytes))
return nil
}
h.transparentProxy.ServeHTTP(recorder, attemptReq)
finalRecorder = recorder
finalProxyErr = attemptErr
isSuccess = attemptIsSuccess
h.resourceService.ReportRequestResult(currentResources, isSuccess, finalProxyErr)
h.logger.WithField("id", correlationID).Infof("Attempt %d/%d with KeyID %d", attempt, totalAttempts, resources.APIKey.ID)
recorder, attemptErr, attemptSuccess := h.executeProxyAttempt(
c, correlationID, requestBody, resources, isPreciseRouting, groupName,
&finalPromptTokens, &finalCompletionTokens,
)
finalRecorder, finalProxyErr, isSuccess = recorder, attemptErr, attemptSuccess
h.resourceService.ReportRequestResult(resources, isSuccess, finalProxyErr)
if isSuccess {
break
}
isUnretryableError := false
if finalProxyErr != nil {
if errors.IsUnretryableRequestError(finalProxyErr.Message) {
isUnretryableError = true
h.logger.WithField("id", correlationID).Warnf("Attempt %d failed with unretryable request error. Aborting retries. Message: %s", attempt, finalProxyErr.Message)
}
}
if attempt >= totalAttempts || isUnretryableError {
if h.shouldStopRetrying(attempt, totalAttempts, finalProxyErr, correlationID) {
break
}
retryEvent := h.createLogEvent(c, startTime, correlationID, modelName, currentResources, models.LogTypeRetry, isPreciseRouting)
retryEvent.LatencyMs = int(time.Since(attemptStartTime).Milliseconds())
retryEvent.IsSuccess = false
retryEvent.StatusCode = recorder.Code
retryEvent.Retries = actualRetries
if attemptErr != nil {
retryEvent.Error = attemptErr
retryEvent.ErrorCode = attemptErr.Code
retryEvent.ErrorMessage = attemptErr.Message
}
eventData, _ := json.Marshal(retryEvent)
_ = h.store.Publish(models.TopicRequestFinished, eventData)
h.publishRetryLogEvent(c, startTime, correlationID, modelName, resources, recorder, attemptErr, actualRetries, isPreciseRouting)
}
if finalRecorder != nil {
bodyBytes := finalRecorder.Body.Bytes()
c.Writer.Header().Set("Content-Length", fmt.Sprintf("%d", len(bodyBytes)))
for k, v := range finalRecorder.Header() {
if strings.ToLower(k) != "content-length" {
c.Writer.Header()[k] = v
h.writeFinalResponse(c, correlationID, finalRecorder, finalProxyErr)
}
func (h *ProxyHandler) executeProxyAttempt(c *gin.Context, corrID string, body []byte, res *service.RequestResources, isPrecise bool, groupName string, pTokens, cTokens *int) (*httptest.ResponseRecorder, *errors.APIError, bool) {
recorder := httptest.NewRecorder()
var attemptErr *errors.APIError
var isSuccess bool
connectTimeout := time.Duration(h.settingsManager.GetSettings().ConnectTimeoutSeconds) * time.Second
ctx, cancel := context.WithTimeout(c.Request.Context(), connectTimeout)
defer cancel()
attemptReq := c.Request.Clone(ctx)
attemptReq.Body = io.NopCloser(bytes.NewReader(body))
attemptReq.ContentLength = int64(len(body))
h.configureProxy(corrID, res, isPrecise, groupName, &attemptErr, &isSuccess, pTokens, cTokens)
*attemptReq = *attemptReq.WithContext(context.WithValue(attemptReq.Context(), proxyErrorContextKey{}, &attemptErr))
h.transparentProxy.ServeHTTP(recorder, attemptReq)
return recorder, attemptErr, isSuccess
}
func (h *ProxyHandler) configureProxy(corrID string, res *service.RequestResources, isPrecise bool, groupName string, attemptErr **errors.APIError, isSuccess *bool, pTokens, cTokens *int) {
h.transparentProxy.Director = func(r *http.Request) {
targetURL, _ := url.Parse(res.UpstreamEndpoint.URL)
r.URL.Scheme, r.URL.Host, r.Host = targetURL.Scheme, targetURL.Host, targetURL.Host
var pureClientPath string
if isPrecise {
pureClientPath = strings.TrimPrefix(r.URL.Path, "/proxy/"+groupName)
} else {
pureClientPath = r.URL.Path
}
r.URL.Path = h.channel.RewritePath(targetURL.Path, pureClientPath)
r.Header.Del("Authorization")
h.channel.ModifyRequest(r, res.APIKey)
r.Header.Set("X-Correlation-ID", corrID)
}
transport := h.transparentProxy.Transport.(*http.Transport)
if res.ProxyConfig != nil {
proxyURLStr := fmt.Sprintf("%s://%s", res.ProxyConfig.Protocol, res.ProxyConfig.Address)
if proxyURL, err := url.Parse(proxyURLStr); err == nil {
transport.Proxy = http.ProxyURL(proxyURL)
} else {
transport.Proxy = http.ProxyFromEnvironment
}
} else {
transport.Proxy = http.ProxyFromEnvironment
}
h.transparentProxy.ModifyResponse = h.createModifyResponseFunc(attemptErr, isSuccess, pTokens, cTokens)
}
func (h *ProxyHandler) createModifyResponseFunc(attemptErr **errors.APIError, isSuccess *bool, pTokens, cTokens *int) func(*http.Response) error {
return func(resp *http.Response) error {
var reader io.ReadCloser = resp.Body
if resp.Header.Get("Content-Encoding") == "gzip" {
gzReader, err := gzip.NewReader(resp.Body)
if err != nil {
h.logger.WithError(err).Error("Failed to create gzip reader")
} else {
reader = gzReader
resp.Header.Del("Content-Encoding")
}
}
c.Writer.WriteHeader(finalRecorder.Code)
c.Writer.Write(finalRecorder.Body.Bytes())
} else {
errToJSON(c, correlationID, finalProxyErr)
defer reader.Close()
bodyBytes, err := io.ReadAll(reader)
if err != nil {
*attemptErr = errors.NewAPIError(errors.ErrBadGateway, "Failed to read upstream response")
resp.Body = io.NopCloser(bytes.NewReader([]byte{}))
return nil
}
if resp.StatusCode < 400 {
*isSuccess = true
*pTokens, *cTokens = extractUsage(bodyBytes)
} else {
parsedMsg := errors.ParseUpstreamError(bodyBytes)
*attemptErr = errors.NewAPIErrorWithUpstream(resp.StatusCode, fmt.Sprintf("UPSTREAM_%d", resp.StatusCode), parsedMsg)
}
resp.Body = io.NopCloser(bytes.NewReader(bodyBytes))
return nil
}
}
func (h *ProxyHandler) transparentProxyErrorHandler(rw http.ResponseWriter, r *http.Request, err error) {
corrID := r.Header.Get("X-Correlation-ID")
log := h.logger.WithField("id", corrID)
log.Errorf("Transparent proxy encountered an error: %v", err)
errPtr, ok := r.Context().Value(proxyErrorContextKey{}).(**errors.APIError)
if !ok || errPtr == nil {
log.Error("FATAL: proxyErrorContextKey not found in context for error handler.")
defaultErr := errors.NewAPIError(errors.ErrBadGateway, "An unexpected proxy error occurred")
writeErrorToResponse(rw, defaultErr)
return
}
if *errPtr == nil {
if errors.IsClientNetworkError(err) {
*errPtr = errors.NewAPIError(errors.ErrBadRequest, "Client connection closed")
} else {
*errPtr = errors.NewAPIError(errors.ErrBadGateway, err.Error())
}
}
writeErrorToResponse(rw, *errPtr)
}
func (h *ProxyHandler) getResourcesForAttempt(c *gin.Context, attempt int, initialResources *service.RequestResources, modelName, groupName string, isPreciseRouting bool, correlationID string) (*service.RequestResources, error) {
if attempt == 1 {
return initialResources, nil
}
h.logger.WithField("id", correlationID).Infof("Retrying... getting new resources for attempt %d.", attempt)
resources, err := h.getResourcesForRequest(c, modelName, groupName, isPreciseRouting)
if err != nil {
return nil, err
}
finalRequestConfig := h.buildFinalRequestConfig(h.settingsManager.GetSettings(), resources.RequestConfig)
resources.RequestConfig = finalRequestConfig
return resources, nil
}
func (h *ProxyHandler) shouldStopRetrying(attempt, totalAttempts int, err *errors.APIError, correlationID string) bool {
if attempt >= totalAttempts {
return true
}
if err != nil && errors.IsUnretryableRequestError(err.Message) {
h.logger.WithField("id", correlationID).Warnf("Attempt failed with unretryable request error. Aborting retries. Message: %s", err.Message)
return true
}
return false
}
func (h *ProxyHandler) writeFinalResponse(c *gin.Context, corrID string, rec *httptest.ResponseRecorder, apiErr *errors.APIError) {
if rec != nil {
for k, v := range rec.Header() {
c.Writer.Header()[k] = v
}
c.Writer.WriteHeader(rec.Code)
c.Writer.Write(rec.Body.Bytes())
} else if apiErr != nil {
errToJSON(c, corrID, apiErr)
} else {
errToJSON(c, corrID, errors.NewAPIError(errors.ErrInternalServer, "An unknown error occurred"))
}
}
func (h *ProxyHandler) publishFinalLogEvent(c *gin.Context, startTime time.Time, corrID, modelName string, res *service.RequestResources, rec *httptest.ResponseRecorder, finalErr *errors.APIError, isSuccess bool, pTokens, cTokens, retries int, isPrecise bool) {
if res == nil {
h.logger.WithField("id", corrID).Warn("No resources were used, skipping final log event.")
return
}
event := h.createLogEvent(c, startTime, corrID, modelName, res, models.LogTypeFinal, isPrecise)
event.RequestLog.LatencyMs = int(time.Since(startTime).Milliseconds())
event.RequestLog.IsSuccess = isSuccess
event.RequestLog.Retries = retries
if isSuccess {
event.RequestLog.PromptTokens, event.RequestLog.CompletionTokens = pTokens, cTokens
}
if rec != nil {
event.RequestLog.StatusCode = rec.Code
}
if !isSuccess {
errToLog := finalErr
if errToLog == nil && rec != nil {
errToLog = errors.NewAPIErrorWithUpstream(rec.Code, fmt.Sprintf("UPSTREAM_%d", rec.Code), "Request failed after all retries.")
}
if errToLog != nil {
event.Error = errToLog
event.RequestLog.ErrorCode, event.RequestLog.ErrorMessage = errToLog.Code, errToLog.Message
}
}
eventData, err := json.Marshal(event)
if err != nil {
h.logger.WithField("id", corrID).WithError(err).Error("Failed to marshal log event")
return
}
if err := h.store.Publish(context.Background(), models.TopicRequestFinished, eventData); err != nil {
h.logger.WithField("id", corrID).WithError(err).Error("Failed to publish log event")
}
}
func (h *ProxyHandler) publishRetryLogEvent(c *gin.Context, startTime time.Time, corrID, modelName string, res *service.RequestResources, rec *httptest.ResponseRecorder, attemptErr *errors.APIError, retries int, isPrecise bool) {
retryEvent := h.createLogEvent(c, startTime, corrID, modelName, res, models.LogTypeRetry, isPrecise)
retryEvent.RequestLog.LatencyMs = int(time.Since(startTime).Milliseconds())
retryEvent.RequestLog.IsSuccess = false
retryEvent.RequestLog.StatusCode = rec.Code
retryEvent.RequestLog.Retries = retries
if attemptErr != nil {
retryEvent.Error = attemptErr
retryEvent.RequestLog.ErrorCode, retryEvent.RequestLog.ErrorMessage = attemptErr.Code, attemptErr.Message
}
eventData, err := json.Marshal(retryEvent)
if err != nil {
h.logger.WithField("id", corrID).WithError(err).Error("Failed to marshal retry log event")
return
}
if err := h.store.Publish(context.Background(), models.TopicRequestFinished, eventData); err != nil {
h.logger.WithField("id", corrID).WithError(err).Error("Failed to publish retry log event")
}
}
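// Illustrative sketch (not part of this change): how the refactored attempt loop is
// expected to compose the helpers above. Variable names such as totalAttempts,
// requestBody, finalPromptTokens, finalCompletionTokens and actualRetries are
// assumptions taken from the surrounding handler, not a verbatim excerpt.
//
//	for attempt := 1; attempt <= totalAttempts; attempt++ {
//		resources, err := h.getResourcesForAttempt(c, attempt, initialResources, modelName, groupName, isPreciseRouting, correlationID)
//		if err != nil { /* convert to *errors.APIError, set finalProxyErr */ break }
//		lastUsedResources = resources
//		recorder, attemptErr, ok := h.executeProxyAttempt(c, correlationID, requestBody, resources, isPreciseRouting, groupName, &finalPromptTokens, &finalCompletionTokens)
//		finalRecorder, finalProxyErr, isSuccess = recorder, attemptErr, ok
//		h.resourceService.ReportRequestResult(resources, isSuccess, finalProxyErr)
//		if isSuccess {
//			break
//		}
//		if h.shouldStopRetrying(attempt, totalAttempts, finalProxyErr, correlationID) {
//			break
//		}
//		h.publishRetryLogEvent(c, startTime, correlationID, modelName, resources, recorder, attemptErr, actualRetries, isPreciseRouting)
//	}
//	h.writeFinalResponse(c, correlationID, finalRecorder, finalProxyErr)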
func (h *ProxyHandler) buildFinalRequestConfig(globalSettings *models.SystemSettings, groupConfig *models.RequestConfig) *models.RequestConfig {
finalConfig := &models.RequestConfig{
CustomHeaders: make(datatypes.JSONMap),
EnableStreamOptimizer: globalSettings.EnableStreamOptimizer,
StreamMinDelay: globalSettings.StreamMinDelay,
StreamMaxDelay: globalSettings.StreamMaxDelay,
StreamShortTextThresh: globalSettings.StreamShortTextThresh,
StreamLongTextThresh: globalSettings.StreamLongTextThresh,
StreamChunkSize: globalSettings.StreamChunkSize,
EnableFakeStream: globalSettings.EnableFakeStream,
FakeStreamInterval: globalSettings.FakeStreamInterval,
}
for k, v := range globalSettings.CustomHeaders {
finalConfig.CustomHeaders[k] = v
}
if groupConfig == nil {
return finalConfig
}
groupConfigJSON, err := json.Marshal(groupConfig)
if err != nil {
h.logger.WithError(err).Error("Failed to marshal group request config for merging.")
return finalConfig
}
if err := json.Unmarshal(groupConfigJSON, finalConfig); err != nil {
h.logger.WithError(err).Error("Failed to unmarshal group request config for merging.")
}
return finalConfig
}
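// Illustrative sketch (not part of this change): the merge above relies on encoding/json's
// rule that Unmarshal only overwrites fields present in the input, so group-level values
// replace the global defaults while absent fields keep them. Note that if the group's
// fields are marshalled with omitempty, zero values (false/0) will not override a default.
// The struct and field names below are stand-ins for demonstration, not the real
// models.RequestConfig.
package main

import (
	"encoding/json"
	"fmt"
)

type demoConfig struct {
	StreamChunkSize  int  `json:"stream_chunk_size,omitempty"`
	EnableFakeStream bool `json:"enable_fake_stream,omitempty"`
}

func main() {
	final := demoConfig{StreamChunkSize: 256, EnableFakeStream: true} // global defaults
	groupJSON := []byte(`{"stream_chunk_size": 64}`)                  // group overrides one field
	if err := json.Unmarshal(groupJSON, &final); err != nil {
		panic(err)
	}
	fmt.Printf("%+v\n", final) // prints {StreamChunkSize:64 EnableFakeStream:true}
}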
func writeErrorToResponse(rw http.ResponseWriter, apiErr *errors.APIError) {
if writer, ok := rw.(interface{ Written() bool }); ok && writer.Written() {
return
}
rw.Header().Set("Content-Type", "application/json; charset=utf-8")
rw.WriteHeader(apiErr.HTTPStatus)
json.NewEncoder(rw).Encode(gin.H{"error": apiErr})
}
func (h *ProxyHandler) serveSmartStream(c *gin.Context, requestBody []byte, resources *service.RequestResources, isPreciseRouting bool) {
startTime := time.Now()
correlationID := uuid.New().String()
@@ -356,7 +441,7 @@ func (h *ProxyHandler) serveSmartStream(c *gin.Context, requestBody []byte, reso
log.Info("Smart Gateway activated for streaming request.")
var originalRequest models.GeminiRequest
if err := json.Unmarshal(requestBody, &originalRequest); err != nil {
errToJSON(c, correlationID, errors.NewAPIError(errors.ErrInvalidJSON, "Smart Gateway failed: Request body is not a valid Gemini native format. Error: "+err.Error()))
errToJSON(c, correlationID, errors.NewAPIError(errors.ErrInvalidJSON, "Invalid request format for Smart Gateway"))
return
}
systemSettings := h.settingsManager.GetSettings()
@@ -367,8 +452,14 @@ func (h *ProxyHandler) serveSmartStream(c *gin.Context, requestBody []byte, reso
if c.Writer.Status() > 0 {
requestFinishedEvent.StatusCode = c.Writer.Status()
}
eventData, _ := json.Marshal(requestFinishedEvent)
_ = h.store.Publish(models.TopicRequestFinished, eventData)
eventData, err := json.Marshal(requestFinishedEvent)
if err != nil {
h.logger.WithField("id", correlationID).WithError(err).Error("Failed to marshal final log event for smart stream")
return
}
if err := h.store.Publish(context.Background(), models.TopicRequestFinished, eventData); err != nil {
h.logger.WithField("id", correlationID).WithError(err).Error("Failed to publish final log event for smart stream")
}
}()
params := channel.SmartRequestParams{
CorrelationID: correlationID,
@@ -385,30 +476,6 @@ func (h *ProxyHandler) serveSmartStream(c *gin.Context, requestBody []byte, reso
h.channel.ProcessSmartStreamRequest(c, params)
}
func (h *ProxyHandler) transparentProxyErrorHandler(rw http.ResponseWriter, r *http.Request, err error) {
correlationID := r.Header.Get("X-Correlation-ID")
h.logger.WithField("id", correlationID).Errorf("Transparent proxy error: %v", err)
proxyErrPtr, exists := r.Context().Value(proxyErrKey).(**errors.APIError)
if !exists || proxyErrPtr == nil {
h.logger.WithField("id", correlationID).Error("FATAL: proxyErrorKey not found in context for error handler.")
return
}
if errors.IsClientNetworkError(err) {
*proxyErrPtr = errors.NewAPIError(errors.ErrBadRequest, "Client connection closed")
} else {
*proxyErrPtr = errors.NewAPIError(errors.ErrBadGateway, err.Error())
}
if _, ok := rw.(*httptest.ResponseRecorder); ok {
return
}
if writer, ok := rw.(interface{ Written() bool }); ok {
if writer.Written() {
return
}
}
rw.WriteHeader((*proxyErrPtr).HTTPStatus)
}
func (h *ProxyHandler) createLogEvent(c *gin.Context, startTime time.Time, corrID, modelName string, res *service.RequestResources, logType models.LogType, isPreciseRouting bool) *models.RequestFinishedEvent {
event := &models.RequestFinishedEvent{
RequestLog: models.RequestLog{
@@ -435,7 +502,6 @@ func (h *ProxyHandler) createLogEvent(c *gin.Context, startTime time.Time, corrI
}
}
if res != nil {
// Populate the embedded RequestLog struct
if res.APIKey != nil {
event.RequestLog.KeyID = &res.APIKey.ID
}
@@ -444,7 +510,6 @@ func (h *ProxyHandler) createLogEvent(c *gin.Context, startTime time.Time, corrI
}
if res.UpstreamEndpoint != nil {
event.RequestLog.UpstreamID = &res.UpstreamEndpoint.ID
// UpstreamURL is an event-only field, not a database column, so assigning it here is correct
event.UpstreamURL = &res.UpstreamEndpoint.URL
}
if res.ProxyConfig != nil {
@@ -464,13 +529,15 @@ func (h *ProxyHandler) getResourcesForRequest(c *gin.Context, modelName string,
return nil, errors.NewAPIError(errors.ErrInternalServer, "Invalid auth token type in context")
}
if isPreciseRouting {
return h.resourceService.GetResourceFromGroup(authToken, groupName)
} else {
return h.resourceService.GetResourceFromBasePool(authToken, modelName)
return h.resourceService.GetResourceFromGroup(c.Request.Context(), authToken, groupName)
}
return h.resourceService.GetResourceFromBasePool(c.Request.Context(), authToken, modelName)
}
func errToJSON(c *gin.Context, corrID string, apiErr *errors.APIError) {
if c.IsAborted() {
return
}
c.JSON(apiErr.HTTPStatus, gin.H{
"error": apiErr,
"correlation_id": corrID,
@@ -479,8 +546,8 @@ func errToJSON(c *gin.Context, corrID string, apiErr *errors.APIError) {
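// bufferPool satisfies the Get/Put buffer-pool interface that net/http/httputil's
// ReverseProxy uses for its copy buffers (presumably wired into the transparent proxy);
// this implementation simply hands out a fresh 32KB slice per Get and discards it on Put.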
type bufferPool struct{}
func (b *bufferPool) Get() []byte { return make([]byte, 32*1024) }
func (b *bufferPool) Put(bytes []byte) {}
func (b *bufferPool) Get() []byte { return make([]byte, 32*1024) }
func (b *bufferPool) Put(_ []byte) {}
func extractUsage(body []byte) (promptTokens int, completionTokens int) {
var data struct {
@@ -495,34 +562,11 @@ func extractUsage(body []byte) (promptTokens int, completionTokens int) {
return 0, 0
}
func (h *ProxyHandler) buildFinalRequestConfig(globalSettings *models.SystemSettings, groupConfig *models.RequestConfig) *models.RequestConfig {
customHeadersJSON, _ := json.Marshal(globalSettings.CustomHeaders)
var customHeadersMap datatypes.JSONMap
_ = json.Unmarshal(customHeadersJSON, &customHeadersMap)
finalConfig := &models.RequestConfig{
CustomHeaders: customHeadersMap,
EnableStreamOptimizer: globalSettings.EnableStreamOptimizer,
StreamMinDelay: globalSettings.StreamMinDelay,
StreamMaxDelay: globalSettings.StreamMaxDelay,
StreamShortTextThresh: globalSettings.StreamShortTextThresh,
StreamLongTextThresh: globalSettings.StreamLongTextThresh,
StreamChunkSize: globalSettings.StreamChunkSize,
EnableFakeStream: globalSettings.EnableFakeStream,
FakeStreamInterval: globalSettings.FakeStreamInterval,
func (h *ProxyHandler) getMaxRetries(isPreciseRouting bool, finalOpConfig *models.KeyGroupSettings) int {
if isPreciseRouting && finalOpConfig.MaxRetries != nil {
return *finalOpConfig.MaxRetries
}
if groupConfig == nil {
return finalConfig
}
groupConfigJSON, err := json.Marshal(groupConfig)
if err != nil {
h.logger.WithError(err).Error("Failed to marshal group request config for merging.")
return finalConfig
}
if err := json.Unmarshal(groupConfigJSON, finalConfig); err != nil {
h.logger.WithError(err).Error("Failed to unmarshal group request config for merging.")
return finalConfig
}
return finalConfig
return h.settingsManager.GetSettings().MaxRetries
}
func (h *ProxyHandler) handleListModelsRequest(c *gin.Context) {

View File

@@ -33,7 +33,7 @@ func (h *TaskHandler) GetTaskStatus(c *gin.Context) {
return
}
taskStatus, err := h.taskService.GetStatus(taskID)
taskStatus, err := h.taskService.GetStatus(c.Request.Context(), taskID)
if err != nil {
// TODO: handle specific error types returned by the service layer with finer granularity
response.Error(c, errors.NewAPIError(errors.ErrResourceNotFound, err.Error()))

View File

@@ -77,3 +77,9 @@ type APIKeyDetails struct {
CooldownUntil *time.Time `json:"cooldown_until"`
EncryptedKey string
}
// SettingsManager defines the abstraction over the system settings manager.
type SettingsManager interface {
GetSettings() *SystemSettings
}
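// Illustrative sketch (not part of this change): with SettingsManager as an interface,
// tests can inject fixed settings instead of the real settings store. fixedSettings is
// a hypothetical test double, not a type defined in this repository.
type fixedSettings struct{ settings *SystemSettings }

func (f fixedSettings) GetSettings() *SystemSettings { return f.settings }

// Compile-time assertion that the test double satisfies the interface.
var _ SettingsManager = fixedSettings{}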

View File

@@ -11,6 +11,7 @@ type SystemSettings struct {
BlacklistThreshold int `json:"blacklist_threshold" default:"3" name:"拉黑阈值" category:"密钥设置" desc:"一个Key连续失败多少次后进入冷却状态。"`
KeyCooldownMinutes int `json:"key_cooldown_minutes" default:"10" name:"密钥冷却时长(分钟)" category:"密钥设置" desc:"一个Key进入冷却状态后需要等待的时间单位为分钟。"`
LogFlushIntervalSeconds int `json:"log_flush_interval_seconds" default:"10" name:"日志刷新间隔(秒)" category:"日志设置" desc:"异步日志写入数据库的间隔时间(秒)。"`
MaxRequestBodySizeMB int `json:"max_request_body_size_mb" default:"10" name:"最大请求体大小 (MB)" category:"请求设置" desc:"允许代理接收的最大请求体大小单位为MB。超过此大小的请求将被拒绝。"`
PollingStrategy PollingStrategy `json:"polling_strategy" default:"random" name:"全局轮询策略" category:"调度设置" desc:"智能聚合模式下,从所有可用密钥中选择一个的默认策略。可选值: sequential(顺序), random(随机), weighted(加权)。"`
@@ -41,6 +42,10 @@ type SystemSettings struct {
MaxLoginAttempts int `json:"max_login_attempts" default:"5" name:"最大登录失败次数" category:"安全设置" desc:"在一个IP被封禁前允许的连续登录失败次数。"`
IPBanDurationMinutes int `json:"ip_ban_duration_minutes" default:"15" name:"IP封禁时长(分钟)" category:"安全设置" desc:"IP被封禁的时长单位为分钟。"`
// BasePool related settings
// BasePoolTTLMinutes int `json:"base_pool_ttl_minutes" default:"30" name:"基础资源池最大生存时间(分钟)" category:"基础资源池" desc:"一个动态构建的基础资源池BasePool在Redis中的最大生存时间。到期后即使仍在活跃使用也会被强制重建。"`
// BasePoolTTIMinutes int `json:"base_pool_tti_minutes" default:"10" name:"基础资源池空闲超时(分钟)" category:"基础资源池" desc:"一个基础资源池BasePool在连续无请求后自动销毁的空闲等待时间。"`
// Smart gateway
LogTruncationLimit int `json:"log_truncation_limit" default:"8000" name:"日志截断长度" category:"日志设置" desc:"在日志中记录上游响应或错误时保留的最大字符数。0表示不截断。"`
EnableSmartGateway bool `json:"enable_smart_gateway" default:"false" name:"启用智能网关" category:"代理设置" desc:"开启后,系统将对流式请求进行智能中断续传、错误标准化等优化。关闭后,系统将作为一个纯净、无干扰的透明代理。"`

View File

@@ -1,13 +1,15 @@
// Filename: internal/repository/key_cache.go
// Filename: internal/repository/key_cache.go (final version)
package repository
import (
"context"
"encoding/json"
"fmt"
"gemini-balancer/internal/models"
"strconv"
)
// --- Redis key constants ---
const (
KeyGroup = "group:%d:keys:active"
KeyDetails = "key:%d:details"
@@ -22,13 +24,16 @@ const (
BasePoolRandomCooldown = "basepool:%s:keys:random:cooldown"
)
func (r *gormKeyRepository) LoadAllKeysToStore() error {
r.logger.Info("Starting to load all keys and associations into cache, including polling structures...")
// LoadAllKeysToStore loads all keys and mappings from the database and fully rebuilds the Redis cache.
func (r *gormKeyRepository) LoadAllKeysToStore(ctx context.Context) error {
r.logger.Info("Starting full cache rebuild for all keys and polling structures.")
var allMappings []*models.GroupAPIKeyMapping
if err := r.db.Preload("APIKey").Find(&allMappings).Error; err != nil {
return fmt.Errorf("failed to load all mappings with APIKeys from DB: %w", err)
return fmt.Errorf("failed to load mappings with preloaded APIKeys: %w", err)
}
// 1. Batch-decrypt every key involved
keyMap := make(map[uint]*models.APIKey)
for _, m := range allMappings {
if m.APIKey != nil {
@@ -40,16 +45,16 @@ func (r *gormKeyRepository) LoadAllKeysToStore() error {
keysToDecrypt = append(keysToDecrypt, *k)
}
if err := r.decryptKeys(keysToDecrypt); err != nil {
r.logger.WithError(err).Error("Critical error during cache preload: batch decryption failed.")
r.logger.WithError(err).Error("Batch decryption failed during cache rebuild.")
// Even if decryption fails, keep loading the parts that are unencrypted or already decrypted
}
decryptedKeyMap := make(map[uint]models.APIKey)
for _, k := range keysToDecrypt {
decryptedKeyMap[k.ID] = k
}
activeKeysByGroup := make(map[uint][]*models.GroupAPIKeyMapping)
pipe := r.store.Pipeline()
detailsToSet := make(map[string][]byte)
// 2. Clear the stale polling structures of every group
pipe := r.store.Pipeline(ctx)
var allGroups []*models.KeyGroup
if err := r.db.Find(&allGroups).Error; err == nil {
for _, group := range allGroups {
@@ -62,26 +67,41 @@ func (r *gormKeyRepository) LoadAllKeysToStore() error {
)
}
} else {
r.logger.WithError(err).Error("Failed to get all groups for cache cleanup")
r.logger.WithError(err).Error("Failed to get groups for cache cleanup; proceeding with rebuild.")
}
// 3. Prepare the batch update payload
activeKeysByGroup := make(map[uint][]*models.GroupAPIKeyMapping)
detailsToSet := make(map[string]any)
for _, mapping := range allMappings {
if mapping.APIKey == nil {
continue
}
decryptedKey, ok := decryptedKeyMap[mapping.APIKeyID]
if !ok {
continue
continue // Skip keys whose decryption failed
}
// Prepare the MSet payload for KeyDetails and KeyMapping
keyJSON, _ := json.Marshal(decryptedKey)
detailsToSet[fmt.Sprintf(KeyDetails, decryptedKey.ID)] = keyJSON
mappingJSON, _ := json.Marshal(mapping)
detailsToSet[fmt.Sprintf(KeyMapping, mapping.KeyGroupID, decryptedKey.ID)] = mappingJSON
if mapping.Status == models.StatusActive {
activeKeysByGroup[mapping.KeyGroupID] = append(activeKeysByGroup[mapping.KeyGroupID], mapping)
}
}
// 4. Bulk-write details and mappings to the cache with MSet
if len(detailsToSet) > 0 {
if err := r.store.MSet(ctx, detailsToSet); err != nil {
r.logger.WithError(err).Error("Failed to MSet key details and mappings during cache rebuild.")
}
}
// 5. Rebuild every group's polling structures inside the pipeline
for groupID, activeMappings := range activeKeysByGroup {
if len(activeMappings) == 0 {
continue
@@ -100,22 +120,19 @@ func (r *gormKeyRepository) LoadAllKeysToStore() error {
pipe.SAdd(fmt.Sprintf(KeyGroup, groupID), activeKeyIDs...)
pipe.LPush(fmt.Sprintf(KeyGroupSequential, groupID), activeKeyIDs...)
pipe.SAdd(fmt.Sprintf(KeyGroupRandomMain, groupID), activeKeyIDs...)
go r.store.ZAdd(fmt.Sprintf(KeyGroupLRU, groupID), lruMembers)
pipe.ZAdd(fmt.Sprintf(KeyGroupLRU, groupID), lruMembers)
}
// 6. Execute the pipeline
if err := pipe.Exec(); err != nil {
return fmt.Errorf("failed to execute pipeline for cache rebuild: %w", err)
}
for key, value := range detailsToSet {
if err := r.store.Set(key, value, 0); err != nil {
r.logger.WithError(err).Warnf("Failed to set key detail in cache: %s", key)
}
return fmt.Errorf("pipeline execution for polling structures failed: %w", err)
}
r.logger.Info("Cache rebuild complete, including all polling structures.")
r.logger.Info("Full cache rebuild completed successfully.")
return nil
}
// updateStoreCacheForKey refreshes the details (K-V) cache of a single APIKey.
func (r *gormKeyRepository) updateStoreCacheForKey(key *models.APIKey) error {
if err := r.decryptKey(key); err != nil {
return fmt.Errorf("failed to decrypt key %d for cache update: %w", key.ID, err)
@@ -124,81 +141,104 @@ func (r *gormKeyRepository) updateStoreCacheForKey(key *models.APIKey) error {
if err != nil {
return fmt.Errorf("failed to marshal key %d for cache update: %w", key.ID, err)
}
return r.store.Set(fmt.Sprintf(KeyDetails, key.ID), keyJSON, 0)
return r.store.Set(context.Background(), fmt.Sprintf(KeyDetails, key.ID), keyJSON, 0)
}
func (r *gormKeyRepository) removeStoreCacheForKey(key *models.APIKey) error {
groupIDs, err := r.GetGroupsForKey(key.ID)
// removeStoreCacheForKey removes an APIKey from every cache structure.
func (r *gormKeyRepository) removeStoreCacheForKey(ctx context.Context, key *models.APIKey) error {
groupIDs, err := r.GetGroupsForKey(ctx, key.ID)
if err != nil {
r.logger.Warnf("failed to get groups for key %d to clean up cache lists: %v", key.ID, err)
r.logger.WithError(err).Warnf("Failed to get groups for key %d during cache removal, cleanup may be partial.", key.ID)
}
pipe := r.store.Pipeline()
pipe := r.store.Pipeline(ctx)
pipe.Del(fmt.Sprintf(KeyDetails, key.ID))
keyIDStr := strconv.FormatUint(uint64(key.ID), 10)
for _, groupID := range groupIDs {
pipe.Del(fmt.Sprintf(KeyMapping, groupID, key.ID))
keyIDStr := strconv.FormatUint(uint64(key.ID), 10)
pipe.SRem(fmt.Sprintf(KeyGroup, groupID), keyIDStr)
pipe.LRem(fmt.Sprintf(KeyGroupSequential, groupID), 0, keyIDStr)
pipe.SRem(fmt.Sprintf(KeyGroupRandomMain, groupID), keyIDStr)
pipe.SRem(fmt.Sprintf(KeyGroupRandomCooldown, groupID), keyIDStr)
go r.store.ZRem(fmt.Sprintf(KeyGroupLRU, groupID), keyIDStr)
pipe.ZRem(fmt.Sprintf(KeyGroupLRU, groupID), keyIDStr)
}
return pipe.Exec()
}
// updateStoreCacheForMapping atomically updates all related cache structures based on a single mapping's status.
func (r *gormKeyRepository) updateStoreCacheForMapping(mapping *models.GroupAPIKeyMapping) error {
pipe := r.store.Pipeline()
activeKeyListKey := fmt.Sprintf("group:%d:keys:active", mapping.KeyGroupID)
pipe.LRem(activeKeyListKey, 0, mapping.APIKeyID)
keyIDStr := strconv.FormatUint(uint64(mapping.APIKeyID), 10)
groupID := mapping.KeyGroupID
ctx := context.Background()
pipe := r.store.Pipeline(ctx)
// Unconditionally remove the key from every polling structure first to guarantee a clean state
pipe.SRem(fmt.Sprintf(KeyGroup, groupID), keyIDStr)
pipe.LRem(fmt.Sprintf(KeyGroupSequential, groupID), 0, keyIDStr)
pipe.SRem(fmt.Sprintf(KeyGroupRandomMain, groupID), keyIDStr)
pipe.SRem(fmt.Sprintf(KeyGroupRandomCooldown, groupID), keyIDStr)
pipe.ZRem(fmt.Sprintf(KeyGroupLRU, groupID), keyIDStr)
// If the new status is Active, re-add the key to every polling structure
if mapping.Status == models.StatusActive {
pipe.LPush(activeKeyListKey, mapping.APIKeyID)
pipe.SAdd(fmt.Sprintf(KeyGroup, groupID), keyIDStr)
pipe.LPush(fmt.Sprintf(KeyGroupSequential, groupID), keyIDStr)
pipe.SAdd(fmt.Sprintf(KeyGroupRandomMain, groupID), keyIDStr)
var score float64
if mapping.LastUsedAt != nil {
score = float64(mapping.LastUsedAt.UnixMilli())
}
pipe.ZAdd(fmt.Sprintf(KeyGroupLRU, groupID), map[string]float64{keyIDStr: score})
}
// Regardless of status, refresh the mapping's K-V detail cache
mappingJSON, err := json.Marshal(mapping)
if err != nil {
return fmt.Errorf("failed to marshal mapping: %w", err)
}
pipe.Set(fmt.Sprintf(KeyMapping, groupID, mapping.APIKeyID), mappingJSON, 0)
return pipe.Exec()
}
func (r *gormKeyRepository) HandleCacheUpdateEventBatch(mappings []*models.GroupAPIKeyMapping) error {
// HandleCacheUpdateEventBatch atomically updates the cache for a batch of mappings.
func (r *gormKeyRepository) HandleCacheUpdateEventBatch(ctx context.Context, mappings []*models.GroupAPIKeyMapping) error {
if len(mappings) == 0 {
return nil
}
groupUpdates := make(map[uint]struct {
ToAdd []interface{}
ToRemove []interface{}
})
pipe := r.store.Pipeline(ctx)
for _, mapping := range mappings {
keyIDStr := strconv.FormatUint(uint64(mapping.APIKeyID), 10)
update, ok := groupUpdates[mapping.KeyGroupID]
if !ok {
update = struct {
ToAdd []interface{}
ToRemove []interface{}
}{}
}
groupID := mapping.KeyGroupID
// Apply the full remove-then-add logic to every mapping in the batch
pipe.SRem(fmt.Sprintf(KeyGroup, groupID), keyIDStr)
pipe.LRem(fmt.Sprintf(KeyGroupSequential, groupID), 0, keyIDStr)
pipe.SRem(fmt.Sprintf(KeyGroupRandomMain, groupID), keyIDStr)
pipe.SRem(fmt.Sprintf(KeyGroupRandomCooldown, groupID), keyIDStr)
pipe.ZRem(fmt.Sprintf(KeyGroupLRU, groupID), keyIDStr)
if mapping.Status == models.StatusActive {
update.ToRemove = append(update.ToRemove, keyIDStr)
update.ToAdd = append(update.ToAdd, keyIDStr)
} else {
update.ToRemove = append(update.ToRemove, keyIDStr)
}
groupUpdates[mapping.KeyGroupID] = update
}
pipe := r.store.Pipeline()
var pipelineError error
for groupID, updates := range groupUpdates {
activeKeyListKey := fmt.Sprintf("group:%d:keys:active", groupID)
if len(updates.ToRemove) > 0 {
for _, keyID := range updates.ToRemove {
pipe.LRem(activeKeyListKey, 0, keyID)
pipe.SAdd(fmt.Sprintf(KeyGroup, groupID), keyIDStr)
pipe.LPush(fmt.Sprintf(KeyGroupSequential, groupID), keyIDStr)
pipe.SAdd(fmt.Sprintf(KeyGroupRandomMain, groupID), keyIDStr)
var score float64
if mapping.LastUsedAt != nil {
score = float64(mapping.LastUsedAt.UnixMilli())
}
pipe.ZAdd(fmt.Sprintf(KeyGroupLRU, groupID), map[string]float64{keyIDStr: score})
}
if len(updates.ToAdd) > 0 {
pipe.LPush(activeKeyListKey, updates.ToAdd...)
}
mappingJSON, _ := json.Marshal(mapping) // Ignore individual marshal errors in the batch so most updates still succeed
pipe.Set(fmt.Sprintf(KeyMapping, groupID, mapping.APIKeyID), mappingJSON, 0)
}
if err := pipe.Exec(); err != nil {
pipelineError = fmt.Errorf("redis pipeline execution failed: %w", err)
}
return pipelineError
return pipe.Exec()
}

View File

@@ -7,6 +7,7 @@ import (
"fmt"
"gemini-balancer/internal/models"
"context"
"math/rand"
"strings"
"time"
@@ -22,7 +23,6 @@ func (r *gormKeyRepository) AddKeys(keys []models.APIKey) ([]models.APIKey, erro
keyHashes := make([]string, len(keys))
keyValueToHashMap := make(map[string]string)
for i, k := range keys {
// All incoming keys must have plaintext APIKey
if k.APIKey == "" {
return nil, fmt.Errorf("cannot add key at index %d: plaintext APIKey is empty", i)
}
@@ -34,7 +34,6 @@ func (r *gormKeyRepository) AddKeys(keys []models.APIKey) ([]models.APIKey, erro
var finalKeys []models.APIKey
err := r.db.Transaction(func(tx *gorm.DB) error {
var existingKeys []models.APIKey
// [MODIFIED] Query by hash to find existing keys.
if err := tx.Unscoped().Where("api_key_hash IN ?", keyHashes).Find(&existingKeys).Error; err != nil {
return err
}
@@ -68,24 +67,20 @@ func (r *gormKeyRepository) AddKeys(keys []models.APIKey) ([]models.APIKey, erro
}
}
if len(keysToCreate) > 0 {
// [MODIFIED] Create now only provides encrypted data and hash.
if err := tx.Clauses(clause.OnConflict{DoNothing: true}, clause.Returning{}).Create(&keysToCreate).Error; err != nil {
return err
}
}
// [MODIFIED] Final select uses hashes to retrieve all relevant keys.
if err := tx.Where("api_key_hash IN ?", keyHashes).Find(&finalKeys).Error; err != nil {
return err
}
// [CRITICAL] Decrypt all keys before returning them to the service layer.
return r.decryptKeys(finalKeys)
})
return finalKeys, err
}
func (r *gormKeyRepository) Update(key *models.APIKey) error {
// [CRITICAL] Before saving, check if the plaintext APIKey field was populated.
// This indicates a potential change that needs to be re-encrypted.
if key.APIKey != "" {
encryptedKey, err := r.crypto.Encrypt(key.APIKey)
if err != nil {
@@ -97,16 +92,16 @@ func (r *gormKeyRepository) Update(key *models.APIKey) error {
key.APIKeyHash = hex.EncodeToString(hash[:])
}
err := r.executeTransactionWithRetry(func(tx *gorm.DB) error {
// GORM automatically ignores `key.APIKey` because of the `gorm:"-"` tag.
return tx.Save(key).Error
})
if err != nil {
return err
}
// For the cache update, we need the plaintext. Decrypt if it's not already populated.
if err := r.decryptKey(key); err != nil {
r.logger.Warnf("DB updated key ID %d, but decryption for cache failed: %v", key.ID, err)
return nil // Continue without cache update if decryption fails.
return nil
}
if err := r.updateStoreCacheForKey(key); err != nil {
r.logger.Warnf("DB updated key ID %d, but cache update failed: %v", key.ID, err)
@@ -115,7 +110,7 @@ func (r *gormKeyRepository) Update(key *models.APIKey) error {
}
func (r *gormKeyRepository) HardDeleteByID(id uint) error {
key, err := r.GetKeyByID(id) // This now returns a decrypted key
key, err := r.GetKeyByID(id)
if err != nil {
return err
}
@@ -125,7 +120,7 @@ func (r *gormKeyRepository) HardDeleteByID(id uint) error {
if err != nil {
return err
}
if err := r.removeStoreCacheForKey(key); err != nil {
if err := r.removeStoreCacheForKey(context.Background(), key); err != nil {
r.logger.Warnf("DB deleted key ID %d, but cache removal failed: %v", id, err)
}
return nil
@@ -140,16 +135,13 @@ func (r *gormKeyRepository) HardDeleteByValues(keyValues []string) (int64, error
hash := sha256.Sum256([]byte(v))
hashes[i] = hex.EncodeToString(hash[:])
}
// Find the full key objects first to update the cache later.
var keysToDelete []models.APIKey
// [MODIFIED] Find by hash.
if err := r.db.Where("api_key_hash IN ?", hashes).Find(&keysToDelete).Error; err != nil {
return 0, err
}
if len(keysToDelete) == 0 {
return 0, nil
}
// Decrypt them to ensure cache has plaintext if needed.
if err := r.decryptKeys(keysToDelete); err != nil {
r.logger.Warnf("Decryption failed for keys before hard delete, cache removal may be impacted: %v", err)
}
@@ -167,7 +159,7 @@ func (r *gormKeyRepository) HardDeleteByValues(keyValues []string) (int64, error
return 0, err
}
for i := range keysToDelete {
if err := r.removeStoreCacheForKey(&keysToDelete[i]); err != nil {
if err := r.removeStoreCacheForKey(context.Background(), &keysToDelete[i]); err != nil {
r.logger.Warnf("DB deleted key ID %d, but cache removal failed: %v", keysToDelete[i].ID, err)
}
}
@@ -194,7 +186,6 @@ func (r *gormKeyRepository) GetKeysByIDs(ids []uint) ([]models.APIKey, error) {
if err != nil {
return nil, err
}
// [CRITICAL] Decrypt before returning.
return keys, r.decryptKeys(keys)
}

View File

@@ -2,6 +2,7 @@
package repository
import (
"context"
"crypto/sha256"
"encoding/hex"
"gemini-balancer/internal/models"
@@ -110,13 +111,13 @@ func (r *gormKeyRepository) deleteOrphanKeysLogic(db *gorm.DB) (int64, error) {
}
result := db.Delete(&models.APIKey{}, orphanKeyIDs)
//result := db.Unscoped().Delete(&models.APIKey{}, orphanKeyIDs)
if result.Error != nil {
return 0, result.Error
}
for i := range keysToDelete {
if err := r.removeStoreCacheForKey(&keysToDelete[i]); err != nil {
// Call the updated cache cleanup function with context.Background()
if err := r.removeStoreCacheForKey(context.Background(), &keysToDelete[i]); err != nil {
r.logger.Warnf("DB deleted orphan key ID %d, but cache removal failed: %v", keysToDelete[i].ID, err)
}
}
@@ -144,7 +145,7 @@ func (r *gormKeyRepository) GetActiveMasterKeys() ([]*models.APIKey, error) {
return keys, nil
}
func (r *gormKeyRepository) UpdateAPIKeyStatus(keyID uint, status models.MasterAPIKeyStatus) error {
func (r *gormKeyRepository) UpdateAPIKeyStatus(ctx context.Context, keyID uint, status models.MasterAPIKeyStatus) error {
err := r.executeTransactionWithRetry(func(tx *gorm.DB) error {
result := tx.Model(&models.APIKey{}).
Where("id = ?", keyID).
@@ -160,7 +161,7 @@ func (r *gormKeyRepository) UpdateAPIKeyStatus(keyID uint, status models.MasterA
if err == nil {
r.logger.Infof("MasterStatus for key ID %d changed, triggering a full cache reload.", keyID)
go func() {
if err := r.LoadAllKeysToStore(); err != nil {
if err := r.LoadAllKeysToStore(context.Background()); err != nil {
r.logger.Errorf("Failed to reload cache after MasterStatus change for key ID %d: %v", keyID, err)
}
}()

View File

@@ -2,6 +2,7 @@
package repository
import (
"context"
"crypto/sha256"
"encoding/hex"
"errors"
@@ -14,7 +15,7 @@ import (
"gorm.io/gorm/clause"
)
func (r *gormKeyRepository) LinkKeysToGroup(groupID uint, keyIDs []uint) error {
func (r *gormKeyRepository) LinkKeysToGroup(ctx context.Context, groupID uint, keyIDs []uint) error {
if len(keyIDs) == 0 {
return nil
}
@@ -34,12 +35,12 @@ func (r *gormKeyRepository) LinkKeysToGroup(groupID uint, keyIDs []uint) error {
}
for _, keyID := range keyIDs {
r.store.SAdd(fmt.Sprintf("key:%d:groups", keyID), groupID)
r.store.SAdd(context.Background(), fmt.Sprintf("key:%d:groups", keyID), groupID)
}
return nil
}
func (r *gormKeyRepository) UnlinkKeysFromGroup(groupID uint, keyIDs []uint) (int64, error) {
func (r *gormKeyRepository) UnlinkKeysFromGroup(ctx context.Context, groupID uint, keyIDs []uint) (int64, error) {
if len(keyIDs) == 0 {
return 0, nil
}
@@ -63,16 +64,16 @@ func (r *gormKeyRepository) UnlinkKeysFromGroup(groupID uint, keyIDs []uint) (in
activeKeyListKey := fmt.Sprintf("group:%d:keys:active", groupID)
for _, keyID := range keyIDs {
r.store.SRem(fmt.Sprintf("key:%d:groups", keyID), groupID)
r.store.LRem(activeKeyListKey, 0, strconv.Itoa(int(keyID)))
r.store.SRem(context.Background(), fmt.Sprintf("key:%d:groups", keyID), groupID)
r.store.LRem(context.Background(), activeKeyListKey, 0, strconv.Itoa(int(keyID)))
}
return unlinkedCount, nil
}
func (r *gormKeyRepository) GetGroupsForKey(keyID uint) ([]uint, error) {
func (r *gormKeyRepository) GetGroupsForKey(ctx context.Context, keyID uint) ([]uint, error) {
cacheKey := fmt.Sprintf("key:%d:groups", keyID)
strGroupIDs, err := r.store.SMembers(cacheKey)
strGroupIDs, err := r.store.SMembers(context.Background(), cacheKey)
if err != nil || len(strGroupIDs) == 0 {
var groupIDs []uint
dbErr := r.db.Table("group_api_key_mappings").Where("api_key_id = ?", keyID).Pluck("key_group_id", &groupIDs).Error
@@ -84,7 +85,7 @@ func (r *gormKeyRepository) GetGroupsForKey(keyID uint) ([]uint, error) {
for _, id := range groupIDs {
interfaceSlice = append(interfaceSlice, id)
}
r.store.SAdd(cacheKey, interfaceSlice...)
r.store.SAdd(context.Background(), cacheKey, interfaceSlice...)
}
return groupIDs, nil
}
@@ -103,7 +104,7 @@ func (r *gormKeyRepository) GetMapping(groupID, keyID uint) (*models.GroupAPIKey
return &mapping, err
}
func (r *gormKeyRepository) UpdateMapping(mapping *models.GroupAPIKeyMapping) error {
func (r *gormKeyRepository) UpdateMapping(ctx context.Context, mapping *models.GroupAPIKeyMapping) error {
err := r.executeTransactionWithRetry(func(tx *gorm.DB) error {
return tx.Save(mapping).Error
})

View File

@@ -2,6 +2,7 @@
package repository
import (
"context"
"crypto/sha1"
"encoding/json"
"errors"
@@ -17,40 +18,40 @@ import (
)
const (
CacheTTL = 5 * time.Minute
EmptyPoolPlaceholder = "EMPTY_POOL"
EmptyCacheTTL = 1 * time.Minute
CacheTTL = 5 * time.Minute
EmptyCacheTTL = 1 * time.Minute
)
// SelectOneActiveKey efficiently selects an available API key from the cache according to the configured polling strategy.
func (r *gormKeyRepository) SelectOneActiveKey(group *models.KeyGroup) (*models.APIKey, *models.GroupAPIKeyMapping, error) {
// SelectOneActiveKey selects an available API key from a single key group's cache according to the configured polling strategy.
func (r *gormKeyRepository) SelectOneActiveKey(ctx context.Context, group *models.KeyGroup) (*models.APIKey, *models.GroupAPIKeyMapping, error) {
if group == nil {
return nil, nil, fmt.Errorf("group cannot be nil")
}
var keyIDStr string
var err error
switch group.PollingStrategy {
case models.StrategySequential:
sequentialKey := fmt.Sprintf(KeyGroupSequential, group.ID)
keyIDStr, err = r.store.Rotate(sequentialKey)
keyIDStr, err = r.store.Rotate(ctx, sequentialKey)
case models.StrategyWeighted:
lruKey := fmt.Sprintf(KeyGroupLRU, group.ID)
results, zerr := r.store.ZRange(lruKey, 0, 0)
if zerr == nil && len(results) > 0 {
keyIDStr = results[0]
results, zerr := r.store.ZRange(ctx, lruKey, 0, 0)
if zerr == nil {
if len(results) > 0 {
keyIDStr = results[0]
} else {
zerr = gorm.ErrRecordNotFound
}
}
err = zerr
case models.StrategyRandom:
mainPoolKey := fmt.Sprintf(KeyGroupRandomMain, group.ID)
cooldownPoolKey := fmt.Sprintf(KeyGroupRandomCooldown, group.ID)
keyIDStr, err = r.store.PopAndCycleSetMember(mainPoolKey, cooldownPoolKey)
default: // When no strategy is specified, fall back to basic random selection
keyIDStr, err = r.store.PopAndCycleSetMember(ctx, mainPoolKey, cooldownPoolKey)
default:
activeKeySetKey := fmt.Sprintf(KeyGroup, group.ID)
keyIDStr, err = r.store.SRandMember(activeKeySetKey)
keyIDStr, err = r.store.SRandMember(ctx, activeKeySetKey)
}
if err != nil {
if errors.Is(err, store.ErrNotFound) || errors.Is(err, gorm.ErrRecordNotFound) {
return nil, nil, gorm.ErrRecordNotFound
@@ -58,65 +59,70 @@ func (r *gormKeyRepository) SelectOneActiveKey(group *models.KeyGroup) (*models.
r.logger.WithError(err).Errorf("Failed to select key for group %d with strategy %s", group.ID, group.PollingStrategy)
return nil, nil, err
}
if keyIDStr == "" {
return nil, nil, gorm.ErrRecordNotFound
}
keyID, _ := strconv.ParseUint(keyIDStr, 10, 64)
apiKey, mapping, err := r.getKeyDetailsFromCache(uint(keyID), group.ID)
keyID, parseErr := strconv.ParseUint(keyIDStr, 10, 64)
if parseErr != nil {
r.logger.WithError(parseErr).Errorf("Invalid key ID format in group %d cache: %s", group.ID, keyIDStr)
return nil, nil, fmt.Errorf("invalid key ID in cache: %w", parseErr)
}
apiKey, mapping, err := r.getKeyDetailsFromCache(ctx, uint(keyID), group.ID)
if err != nil {
r.logger.WithError(err).Warnf("Cache inconsistency: Failed to get details for selected key ID %d", keyID)
// TODO: add retry logic here by calling SelectOneActiveKey(group) again
r.logger.WithError(err).Warnf("Cache inconsistency for key ID %d in group %d", keyID, group.ID)
return nil, nil, err
}
if group.PollingStrategy == models.StrategyWeighted {
go r.UpdateKeyUsageTimestamp(group.ID, uint(keyID))
go func() {
updateCtx, cancel := r.withTimeout(5 * time.Second)
defer cancel()
r.UpdateKeyUsageTimestamp(updateCtx, group.ID, uint(keyID))
}()
}
return apiKey, mapping, nil
}
// SelectOneActiveKeyFromBasePool is the poller designed for smart aggregation mode
func (r *gormKeyRepository) SelectOneActiveKeyFromBasePool(pool *BasePool) (*models.APIKey, *models.KeyGroup, error) {
// Generate a unique pool ID so that polling state is isolated between different request combinations
poolID := generatePoolID(pool.CandidateGroups)
// SelectOneActiveKeyFromBasePool selects an available key from the smart aggregation base pool.
func (r *gormKeyRepository) SelectOneActiveKeyFromBasePool(ctx context.Context, pool *BasePool) (*models.APIKey, *models.KeyGroup, error) {
if pool == nil || len(pool.CandidateGroups) == 0 {
return nil, nil, fmt.Errorf("invalid or empty base pool configuration")
}
poolID := r.generatePoolID(pool.CandidateGroups)
log := r.logger.WithField("pool_id", poolID)
if err := r.ensureBasePoolCacheExists(pool, poolID); err != nil {
log.WithError(err).Error("Failed to ensure BasePool cache exists.")
if err := r.ensureBasePoolCacheExists(ctx, pool, poolID); err != nil {
if !errors.Is(err, gorm.ErrRecordNotFound) {
log.WithError(err).Error("Failed to ensure BasePool cache exists")
}
return nil, nil, err
}
var keyIDStr string
var err error
switch pool.PollingStrategy {
case models.StrategySequential:
sequentialKey := fmt.Sprintf(BasePoolSequential, poolID)
keyIDStr, err = r.store.Rotate(sequentialKey)
keyIDStr, err = r.store.Rotate(ctx, sequentialKey)
case models.StrategyWeighted:
lruKey := fmt.Sprintf(BasePoolLRU, poolID)
results, zerr := r.store.ZRange(lruKey, 0, 0)
if zerr == nil && len(results) > 0 {
keyIDStr = results[0]
results, zerr := r.store.ZRange(ctx, lruKey, 0, 0)
if zerr == nil {
if len(results) > 0 {
keyIDStr = results[0]
} else {
zerr = gorm.ErrRecordNotFound
}
}
err = zerr
case models.StrategyRandom:
mainPoolKey := fmt.Sprintf(BasePoolRandomMain, poolID)
cooldownPoolKey := fmt.Sprintf(BasePoolRandomCooldown, poolID)
keyIDStr, err = r.store.PopAndCycleSetMember(mainPoolKey, cooldownPoolKey)
default: // The default strategy should be handled in ensureCache; kept here as a fallback
log.Warnf("Default polling strategy triggered inside selection. This should be rare.")
keyIDStr, err = r.store.PopAndCycleSetMember(ctx, mainPoolKey, cooldownPoolKey)
default:
log.Warnf("Unknown polling strategy '%s'. Using sequential as fallback.", pool.PollingStrategy)
sequentialKey := fmt.Sprintf(BasePoolSequential, poolID)
keyIDStr, err = r.store.LIndex(sequentialKey, 0)
keyIDStr, err = r.store.Rotate(ctx, sequentialKey)
}
if err != nil {
if errors.Is(err, store.ErrNotFound) {
if errors.Is(err, store.ErrNotFound) || errors.Is(err, gorm.ErrRecordNotFound) {
return nil, nil, gorm.ErrRecordNotFound
}
log.WithError(err).Errorf("Failed to select key from BasePool with strategy %s", pool.PollingStrategy)
@@ -125,153 +131,266 @@ func (r *gormKeyRepository) SelectOneActiveKeyFromBasePool(pool *BasePool) (*mod
if keyIDStr == "" {
return nil, nil, gorm.ErrRecordNotFound
}
keyID, _ := strconv.ParseUint(keyIDStr, 10, 64)
for _, group := range pool.CandidateGroups {
apiKey, mapping, cacheErr := r.getKeyDetailsFromCache(uint(keyID), group.ID)
if cacheErr == nil && apiKey != nil && mapping != nil {
if pool.PollingStrategy == models.StrategyWeighted {
go r.updateKeyUsageTimestampForPool(poolID, uint(keyID))
}
return apiKey, group, nil
go func() {
bgCtx, cancel := r.withTimeout(5 * time.Second)
defer cancel()
r.refreshBasePoolHeartbeat(bgCtx, poolID)
}()
keyID, parseErr := strconv.ParseUint(keyIDStr, 10, 64)
if parseErr != nil {
log.WithError(parseErr).Errorf("Invalid key ID format in BasePool cache: %s", keyIDStr)
return nil, nil, fmt.Errorf("invalid key ID in cache: %w", parseErr)
}
keyToGroupMapKey := fmt.Sprintf("basepool:%s:key_to_group", poolID)
groupIDStr, err := r.store.HGet(ctx, keyToGroupMapKey, keyIDStr)
if err != nil {
log.WithError(err).Errorf("Cache inconsistency: KeyID %d found in pool but not in key-to-group map", keyID)
return nil, nil, errors.New("cache inconsistency: key has no origin group mapping")
}
groupID, parseErr := strconv.ParseUint(groupIDStr, 10, 64)
if parseErr != nil {
log.WithError(parseErr).Errorf("Invalid group ID format in key-to-group map for key %d: %s", keyID, groupIDStr)
return nil, nil, errors.New("cache inconsistency: invalid group id in mapping")
}
apiKey, _, err := r.getKeyDetailsFromCache(ctx, uint(keyID), uint(groupID))
if err != nil {
log.WithError(err).Warnf("Cache inconsistency: Failed to get details for key %d in mapped group %d", keyID, groupID)
return nil, nil, err
}
var originGroup *models.KeyGroup
for _, g := range pool.CandidateGroups {
if g.ID == uint(groupID) {
originGroup = g
break
}
}
log.Errorf("Cache inconsistency: Selected KeyID %d from BasePool but could not find its origin group.", keyID)
return nil, nil, errors.New("cache inconsistency: selected key has no origin group")
if originGroup == nil {
log.Errorf("Logic error: Mapped GroupID %d not found in pool's candidate groups list", groupID)
return nil, nil, errors.New("cache inconsistency: mapped group not in candidate list")
}
if pool.PollingStrategy == models.StrategyWeighted {
go func() {
bgCtx, cancel := r.withTimeout(5 * time.Second)
defer cancel()
r.updateKeyUsageTimestampForPool(bgCtx, poolID, uint(keyID))
}()
}
return apiKey, originGroup, nil
}
// ensureBasePoolCacheExists dynamically creates the BasePool's Redis structures
func (r *gormKeyRepository) ensureBasePoolCacheExists(pool *BasePool, poolID string) error {
listKey := fmt.Sprintf(BasePoolSequential, poolID)
// --- Handle the "poison pill" up front to keep the logic clear ---
exists, err := r.store.Exists(listKey)
if err != nil {
r.logger.WithError(err).Errorf("Failed to check existence for pool_id '%s'", poolID)
return err // Return the read error directly
// ensureBasePoolCacheExists dynamically creates or validates the BasePool's Redis cache structures
func (r *gormKeyRepository) ensureBasePoolCacheExists(ctx context.Context, pool *BasePool, poolID string) error {
heartbeatKey := fmt.Sprintf("basepool:%s:heartbeat", poolID)
emptyMarkerKey := fmt.Sprintf("basepool:empty:%s", poolID)
// Pre-check: fail fast
if exists, _ := r.store.Exists(ctx, emptyMarkerKey); exists {
return gorm.ErrRecordNotFound
}
if exists {
val, err := r.store.LIndex(listKey, 0)
if err != nil {
// If even LIndex fails, the cache is likely corrupted; allow a rebuild
r.logger.WithError(err).Warnf("Cache for pool_id '%s' exists but is unreadable. Forcing rebuild.", poolID)
} else {
if val == EmptyPoolPlaceholder {
return gorm.ErrRecordNotFound // Known empty pool, return immediately
}
return nil // Cache is valid, return immediately
}
}
// --- Add a distributed lock to avoid a thundering herd during concurrent builds ---
lockKey := fmt.Sprintf("lock:basepool:%s", poolID)
acquired, err := r.store.SetNX(lockKey, []byte("1"), 10*time.Second) // 10-second lock timeout
if err != nil {
r.logger.WithError(err).Error("Failed to attempt acquiring distributed lock for basepool build.")
return err
}
if !acquired {
// Lock not acquired: wait briefly and retry so the lock holder can finish the build
time.Sleep(100 * time.Millisecond)
return r.ensureBasePoolCacheExists(pool, poolID)
}
defer r.store.Del(lockKey) // Release the lock when the function exits
// Double check: another goroutine may have completed the build while we were acquiring the lock
if exists, _ := r.store.Exists(listKey); exists {
if exists, _ := r.store.Exists(ctx, heartbeatKey); exists {
return nil
}
r.logger.Infof("BasePool cache for pool_id '%s' not found or is unreadable. Building now...", poolID)
var allActiveKeyIDs []string
lruMembers := make(map[string]float64)
// Acquire the distributed lock
lockKey := fmt.Sprintf("lock:basepool:%s", poolID)
if err := r.acquireLock(ctx, lockKey); err != nil {
return err // acquireLock already logged and returned a specific error
}
defer r.releaseLock(context.Background(), lockKey)
// Double-checked locking
if exists, _ := r.store.Exists(ctx, emptyMarkerKey); exists {
return gorm.ErrRecordNotFound
}
if exists, _ := r.store.Exists(ctx, heartbeatKey); exists {
return nil
}
// Before the heavy work, check one last time whether the context has been cancelled
select {
case <-ctx.Done():
return ctx.Err()
default:
}
r.logger.Infof("Building BasePool cache for pool_id '%s'", poolID)
// Manually aggregate all keys while building the key-to-group mapping
keyToGroupMap := make(map[string]any)
allKeyIDsSet := make(map[string]struct{})
for _, group := range pool.CandidateGroups {
activeKeySetKey := fmt.Sprintf(KeyGroup, group.ID)
groupKeyIDs, err := r.store.SMembers(activeKeySetKey)
// --- [Core fix] ---
// This was the root of the problem: we must never silently `continue` when a read fails.
// Any failure to read the source data must abort the whole build immediately.
groupKeySet := fmt.Sprintf(KeyGroup, group.ID)
groupKeyIDs, err := r.store.SMembers(ctx, groupKeySet)
if err != nil {
r.logger.WithError(err).Errorf("FATAL: Failed to read active keys for group %d during BasePool build. Aborting build process for pool_id '%s'.", group.ID, poolID)
// Return this transient error. The current request fails, but no "poison pill" is written,
// so the next request gets a fresh chance to succeed.
return err
r.logger.WithError(err).Warnf("Failed to get members for group %d during pool build", group.ID)
continue
}
// Only continue processing when SMembers succeeds
allActiveKeyIDs = append(allActiveKeyIDs, groupKeyIDs...)
for _, keyIDStr := range groupKeyIDs {
keyID, _ := strconv.ParseUint(keyIDStr, 10, 64)
_, mapping, err := r.getKeyDetailsFromCache(uint(keyID), group.ID)
if err == nil && mapping != nil {
var score float64
if mapping.LastUsedAt != nil {
score = float64(mapping.LastUsedAt.UnixMilli())
}
lruMembers[keyIDStr] = score
groupIDStr := strconv.FormatUint(uint64(group.ID), 10)
for _, keyID := range groupKeyIDs {
if _, exists := allKeyIDsSet[keyID]; !exists {
allKeyIDsSet[keyID] = struct{}{}
keyToGroupMap[keyID] = groupIDStr
}
}
}
// --- [Logic fix] ---
// Only write the "poison pill" when we successfully read all the data
// and found it genuinely empty.
if len(allActiveKeyIDs) == 0 {
r.logger.Warnf("No active keys found for any candidate groups for pool_id '%s'. Setting empty pool placeholder.", poolID)
pipe := r.store.Pipeline()
pipe.LPush(listKey, EmptyPoolPlaceholder)
pipe.Expire(listKey, EmptyCacheTTL)
if err := pipe.Exec(); err != nil {
r.logger.WithError(err).Errorf("Failed to set empty pool placeholder for pool_id '%s'", poolID)
// Handle the empty-pool case
if len(allKeyIDsSet) == 0 {
emptyCacheTTL := time.Duration(r.config.Repository.BasePoolTTIMinutes) * time.Minute / 2
if emptyCacheTTL < time.Minute {
emptyCacheTTL = time.Minute
}
r.logger.Warnf("No active keys found for pool_id '%s', setting empty marker.", poolID)
if err := r.store.Set(ctx, emptyMarkerKey, []byte("1"), emptyCacheTTL); err != nil {
r.logger.WithError(err).Warnf("Failed to set empty marker for pool_id '%s'", poolID)
}
return gorm.ErrRecordNotFound
}
// Populate all polling structures through a pipeline
pipe := r.store.Pipeline()
// 1. Sequential
pipe.LPush(fmt.Sprintf(BasePoolSequential, poolID), toInterfaceSlice(allActiveKeyIDs)...)
// 2. Random
pipe.SAdd(fmt.Sprintf(BasePoolRandomMain, poolID), toInterfaceSlice(allActiveKeyIDs)...)
// Set a reasonable expiry (e.g. 5 minutes) to avoid orphaned data
pipe.Expire(fmt.Sprintf(BasePoolSequential, poolID), CacheTTL)
pipe.Expire(fmt.Sprintf(BasePoolRandomMain, poolID), CacheTTL)
pipe.Expire(fmt.Sprintf(BasePoolRandomCooldown, poolID), CacheTTL)
pipe.Expire(fmt.Sprintf(BasePoolLRU, poolID), CacheTTL)
allActiveKeyIDs := make([]string, 0, len(allKeyIDsSet))
for keyID := range allKeyIDsSet {
allActiveKeyIDs = append(allActiveKeyIDs, keyID)
}
// Build all cache structures atomically with a pipeline
basePoolTTL := time.Duration(r.config.Repository.BasePoolTTLMinutes) * time.Minute
basePoolTTI := time.Duration(r.config.Repository.BasePoolTTIMinutes) * time.Minute
mainPoolKey := fmt.Sprintf(BasePoolRandomMain, poolID)
sequentialKey := fmt.Sprintf(BasePoolSequential, poolID)
cooldownKey := fmt.Sprintf(BasePoolRandomCooldown, poolID)
lruKey := fmt.Sprintf(BasePoolLRU, poolID)
keyToGroupMapKey := fmt.Sprintf("basepool:%s:key_to_group", poolID)
pipe := r.store.Pipeline(ctx)
pipe.Del(mainPoolKey, sequentialKey, cooldownKey, lruKey, emptyMarkerKey, keyToGroupMapKey)
pipe.SAdd(mainPoolKey, r.toInterfaceSlice(allActiveKeyIDs)...)
pipe.LPush(sequentialKey, r.toInterfaceSlice(allActiveKeyIDs)...)
if len(keyToGroupMap) > 0 {
pipe.HSet(keyToGroupMapKey, keyToGroupMap)
pipe.Expire(keyToGroupMapKey, basePoolTTL)
}
pipe.Expire(mainPoolKey, basePoolTTL)
pipe.Expire(sequentialKey, basePoolTTL)
pipe.Expire(cooldownKey, basePoolTTL)
pipe.Expire(lruKey, basePoolTTL)
pipe.Set(heartbeatKey, []byte("1"), basePoolTTI)
if err := pipe.Exec(); err != nil {
r.logger.WithError(err).Errorf("Failed to populate polling structures for pool_id '%s'", poolID)
cleanupCtx, cancel := r.withTimeout(5 * time.Second)
defer cancel()
r.store.Del(cleanupCtx, mainPoolKey, sequentialKey, cooldownKey, lruKey, heartbeatKey, emptyMarkerKey, keyToGroupMapKey)
return err
}
if len(lruMembers) > 0 {
r.store.ZAdd(fmt.Sprintf(BasePoolLRU, poolID), lruMembers)
}
// Populate the LRU cache asynchronously, passing in the mapping built above
go r.populateBasePoolLRUCache(context.Background(), poolID, allActiveKeyIDs, keyToGroupMap)
return nil
}
// --- Helpers ---
// acquireLock wraps distributed-lock acquisition with retries and exponential backoff.
func (r *gormKeyRepository) acquireLock(ctx context.Context, lockKey string) error {
const (
lockTTL = 30 * time.Second
lockMaxRetries = 5
lockBaseBackoff = 50 * time.Millisecond
)
for i := 0; i < lockMaxRetries; i++ {
acquired, err := r.store.SetNX(ctx, lockKey, []byte("1"), lockTTL)
if err != nil {
r.logger.WithError(err).Error("Failed to attempt acquiring distributed lock")
return err
}
if acquired {
return nil
}
time.Sleep(lockBaseBackoff * (1 << i))
}
return fmt.Errorf("failed to acquire lock for key '%s' after %d retries", lockKey, lockMaxRetries)
}
// releaseLock releases the distributed lock.
func (r *gormKeyRepository) releaseLock(ctx context.Context, lockKey string) {
if err := r.store.Del(ctx, lockKey); err != nil {
r.logger.WithError(err).Errorf("Failed to release distributed lock for key '%s'", lockKey)
}
}
// withTimeout is a thin wrapper around context.WithTimeout to ease testing and mocking.
func (r *gormKeyRepository) withTimeout(duration time.Duration) (context.Context, context.CancelFunc) {
return context.WithTimeout(context.Background(), duration)
}
// refreshBasePoolHeartbeat asynchronously refreshes the heartbeat key's TTI
func (r *gormKeyRepository) refreshBasePoolHeartbeat(ctx context.Context, poolID string) {
basePoolTTI := time.Duration(r.config.Repository.BasePoolTTIMinutes) * time.Minute
heartbeatKey := fmt.Sprintf("basepool:%s:heartbeat", poolID)
// Refresh via EXPIRE; if the key does not exist this is a safe no-op
if err := r.store.Expire(ctx, heartbeatKey, basePoolTTI); err != nil {
if ctx.Err() == nil { // Avoid logging noise after the context has been cancelled
r.logger.WithError(err).Warnf("Failed to refresh heartbeat for pool_id '%s'", poolID)
}
}
}
// populateBasePoolLRUCache asynchronously populates the BasePool's LRU cache structure
func (r *gormKeyRepository) populateBasePoolLRUCache(
parentCtx context.Context,
currentPoolID string,
keys []string,
keyToGroupMap map[string]any,
) {
lruMembers := make(map[string]float64, len(keys))
for _, keyIDStr := range keys {
select {
case <-parentCtx.Done():
return
default:
}
groupIDStr, ok := keyToGroupMap[keyIDStr].(string)
if !ok {
continue
}
keyID, _ := strconv.ParseUint(keyIDStr, 10, 64)
groupID, _ := strconv.ParseUint(groupIDStr, 10, 64)
mappingKey := fmt.Sprintf(KeyMapping, groupID, keyID)
data, err := r.store.Get(parentCtx, mappingKey)
if err != nil {
continue
}
var mapping models.GroupAPIKeyMapping
if json.Unmarshal(data, &mapping) == nil {
var score float64
if mapping.LastUsedAt != nil {
score = float64(mapping.LastUsedAt.UnixMilli())
}
lruMembers[keyIDStr] = score
}
}
if len(lruMembers) > 0 {
lruKey := fmt.Sprintf(BasePoolLRU, currentPoolID)
if err := r.store.ZAdd(parentCtx, lruKey, lruMembers); err != nil {
if parentCtx.Err() == nil {
r.logger.WithError(err).Warnf("Failed to populate LRU cache for pool '%s'", currentPoolID)
}
}
}
}
// updateKeyUsageTimestampForPool updates the BasePool's LRU ZSET
func (r *gormKeyRepository) updateKeyUsageTimestampForPool(poolID string, keyID uint) {
func (r *gormKeyRepository) updateKeyUsageTimestampForPool(ctx context.Context, poolID string, keyID uint) {
lruKey := fmt.Sprintf(BasePoolLRU, poolID)
r.store.ZAdd(lruKey, map[string]float64{
strconv.FormatUint(uint64(keyID), 10): nowMilli(),
err := r.store.ZAdd(ctx, lruKey, map[string]float64{
strconv.FormatUint(uint64(keyID), 10): r.nowMilli(),
})
if err != nil {
r.logger.WithError(err).Warnf("Failed to update key usage for pool %s", poolID)
}
}
// generatePoolID derives a stable, unique string ID from the list of candidate group IDs
func generatePoolID(groups []*models.KeyGroup) string {
func (r *gormKeyRepository) generatePoolID(groups []*models.KeyGroup) string {
ids := make([]int, len(groups))
for i, g := range groups {
ids[i] = int(g.ID)
}
sort.Ints(ids)
h := sha1.New()
io.WriteString(h, fmt.Sprintf("%v", ids))
return fmt.Sprintf("%x", h.Sum(nil))
}
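An illustrative property of generatePoolID: because the group IDs are sorted before hashing, the resulting pool ID is independent of the order of the candidate groups (the ID values below are example data):
a, b := &models.KeyGroup{}, &models.KeyGroup{}
a.ID, b.ID = 3, 1
idAB := r.generatePoolID([]*models.KeyGroup{a, b})
idBA := r.generatePoolID([]*models.KeyGroup{b, a})
fmt.Println(idAB == idBA) // true: both hash the sorted ID list [1 3]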
// toInterfaceSlice is a type-conversion helper
func toInterfaceSlice(slice []string) []interface{} {
func (r *gormKeyRepository) toInterfaceSlice(slice []string) []interface{} {
result := make([]interface{}, len(slice))
for i, v := range slice {
result[i] = v
@@ -280,13 +399,13 @@ func toInterfaceSlice(slice []string) []interface{} {
}
// nowMilli returns the current Unix millisecond timestamp, used by the LRU/Weighted strategies
func nowMilli() float64 {
func (r *gormKeyRepository) nowMilli() float64 {
return float64(time.Now().UnixMilli())
}
// getKeyDetailsFromCache fetches the key and mapping JSON data from the cache.
func (r *gormKeyRepository) getKeyDetailsFromCache(keyID, groupID uint) (*models.APIKey, *models.GroupAPIKeyMapping, error) {
apiKeyJSON, err := r.store.Get(fmt.Sprintf(KeyDetails, keyID))
func (r *gormKeyRepository) getKeyDetailsFromCache(ctx context.Context, keyID, groupID uint) (*models.APIKey, *models.GroupAPIKeyMapping, error) {
apiKeyJSON, err := r.store.Get(ctx, fmt.Sprintf(KeyDetails, keyID))
if err != nil {
return nil, nil, fmt.Errorf("failed to get key details for key %d: %w", keyID, err)
}
@@ -295,7 +414,7 @@ func (r *gormKeyRepository) getKeyDetailsFromCache(keyID, groupID uint) (*models
return nil, nil, fmt.Errorf("failed to unmarshal api key %d: %w", keyID, err)
}
mappingJSON, err := r.store.Get(fmt.Sprintf(KeyMapping, groupID, keyID))
mappingJSON, err := r.store.Get(ctx, fmt.Sprintf(KeyMapping, groupID, keyID))
if err != nil {
return nil, nil, fmt.Errorf("failed to get mapping details for key %d in group %d: %w", keyID, groupID, err)
}

View File

@@ -1,7 +1,9 @@
// Filename: internal/repository/key_writer.go
package repository
import (
"context"
"fmt"
"gemini-balancer/internal/errors"
"gemini-balancer/internal/models"
@@ -9,7 +11,7 @@ import (
"time"
)
func (r *gormKeyRepository) UpdateKeyUsageTimestamp(groupID, keyID uint) {
func (r *gormKeyRepository) UpdateKeyUsageTimestamp(ctx context.Context, groupID, keyID uint) {
lruKey := fmt.Sprintf(KeyGroupLRU, groupID)
timestamp := float64(time.Now().UnixMilli())
@@ -17,52 +19,51 @@ func (r *gormKeyRepository) UpdateKeyUsageTimestamp(groupID, keyID uint) {
strconv.FormatUint(uint64(keyID), 10): timestamp,
}
if err := r.store.ZAdd(lruKey, members); err != nil {
if err := r.store.ZAdd(ctx, lruKey, members); err != nil {
r.logger.WithError(err).Warnf("Failed to update usage timestamp for key %d in group %d", keyID, groupID)
}
}
func (r *gormKeyRepository) SyncKeyStatusInPollingCaches(groupID, keyID uint, newStatus models.APIKeyStatus) {
func (r *gormKeyRepository) SyncKeyStatusInPollingCaches(ctx context.Context, groupID, keyID uint, newStatus models.APIKeyStatus) {
r.logger.Infof("SYNC: Directly updating polling caches for G:%d K:%d -> %s", groupID, keyID, newStatus)
r.updatePollingCachesLogic(groupID, keyID, newStatus)
r.updatePollingCachesLogic(ctx, groupID, keyID, newStatus)
}
func (r *gormKeyRepository) HandleCacheUpdateEvent(groupID, keyID uint, newStatus models.APIKeyStatus) {
func (r *gormKeyRepository) HandleCacheUpdateEvent(ctx context.Context, groupID, keyID uint, newStatus models.APIKeyStatus) {
r.logger.Infof("EVENT: Updating polling caches for G:%d K:%d -> %s from an event", groupID, keyID, newStatus)
r.updatePollingCachesLogic(groupID, keyID, newStatus)
r.updatePollingCachesLogic(ctx, groupID, keyID, newStatus)
}
func (r *gormKeyRepository) updatePollingCachesLogic(groupID, keyID uint, newStatus models.APIKeyStatus) {
func (r *gormKeyRepository) updatePollingCachesLogic(ctx context.Context, groupID, keyID uint, newStatus models.APIKeyStatus) {
keyIDStr := strconv.FormatUint(uint64(keyID), 10)
sequentialKey := fmt.Sprintf(KeyGroupSequential, groupID)
lruKey := fmt.Sprintf(KeyGroupLRU, groupID)
mainPoolKey := fmt.Sprintf(KeyGroupRandomMain, groupID)
cooldownPoolKey := fmt.Sprintf(KeyGroupRandomCooldown, groupID)
_ = r.store.LRem(sequentialKey, 0, keyIDStr)
_ = r.store.ZRem(lruKey, keyIDStr)
_ = r.store.SRem(mainPoolKey, keyIDStr)
_ = r.store.SRem(cooldownPoolKey, keyIDStr)
_ = r.store.LRem(ctx, sequentialKey, 0, keyIDStr)
_ = r.store.ZRem(ctx, lruKey, keyIDStr)
_ = r.store.SRem(ctx, mainPoolKey, keyIDStr)
_ = r.store.SRem(ctx, cooldownPoolKey, keyIDStr)
if newStatus == models.StatusActive {
if err := r.store.LPush(sequentialKey, keyIDStr); err != nil {
if err := r.store.LPush(ctx, sequentialKey, keyIDStr); err != nil {
r.logger.WithError(err).Warnf("Failed to add key %d to sequential list for group %d", keyID, groupID)
}
members := map[string]float64{keyIDStr: 0}
if err := r.store.ZAdd(lruKey, members); err != nil {
if err := r.store.ZAdd(ctx, lruKey, members); err != nil {
r.logger.WithError(err).Warnf("Failed to add key %d to LRU zset for group %d", keyID, groupID)
}
if err := r.store.SAdd(mainPoolKey, keyIDStr); err != nil {
if err := r.store.SAdd(ctx, mainPoolKey, keyIDStr); err != nil {
r.logger.WithError(err).Warnf("Failed to add key %d to random main pool for group %d", keyID, groupID)
}
}
}
// UpdateKeyStatusAfterRequest is the new central hub for handling feedback.
func (r *gormKeyRepository) UpdateKeyStatusAfterRequest(group *models.KeyGroup, key *models.APIKey, success bool, apiErr *errors.APIError) {
func (r *gormKeyRepository) UpdateKeyStatusAfterRequest(ctx context.Context, group *models.KeyGroup, key *models.APIKey, success bool, apiErr *errors.APIError) {
if success {
if group.PollingStrategy == models.StrategyWeighted {
go r.UpdateKeyUsageTimestamp(group.ID, key.ID)
go r.UpdateKeyUsageTimestamp(context.Background(), group.ID, key.ID)
}
return
}
@@ -72,6 +73,5 @@ func (r *gormKeyRepository) UpdateKeyStatusAfterRequest(group *models.KeyGroup,
}
r.logger.Warnf("Request failed for KeyID %d in GroupID %d with error: %s. Temporarily removing from active polling caches.", key.ID, group.ID, apiErr.Message)
// This call is correct. It uses the synchronous, direct method.
r.SyncKeyStatusInPollingCaches(group.ID, key.ID, models.StatusCooldown)
r.SyncKeyStatusInPollingCaches(ctx, group.ID, key.ID, models.StatusCooldown)
}

View File

@@ -2,6 +2,8 @@
package repository
import (
"context"
"gemini-balancer/internal/config"
"gemini-balancer/internal/crypto"
"gemini-balancer/internal/errors"
"gemini-balancer/internal/models"
@@ -22,8 +24,8 @@ type BasePool struct {
type KeyRepository interface {
// --- Core selection & scheduling --- key_selector
SelectOneActiveKey(group *models.KeyGroup) (*models.APIKey, *models.GroupAPIKeyMapping, error)
SelectOneActiveKeyFromBasePool(pool *BasePool) (*models.APIKey, *models.KeyGroup, error)
SelectOneActiveKey(ctx context.Context, group *models.KeyGroup) (*models.APIKey, *models.GroupAPIKeyMapping, error)
SelectOneActiveKeyFromBasePool(ctx context.Context, pool *BasePool) (*models.APIKey, *models.KeyGroup, error)
// --- Encryption & decryption --- key_crud
Decrypt(key *models.APIKey) error
@@ -37,16 +39,16 @@ type KeyRepository interface {
GetKeyByID(id uint) (*models.APIKey, error)
GetKeyByValue(keyValue string) (*models.APIKey, error)
GetKeysByValues(keyValues []string) ([]models.APIKey, error)
GetKeysByIDs(ids []uint) ([]models.APIKey, error) // [New] Batch-fetch keys by a set of primary-key IDs
GetKeysByIDs(ids []uint) ([]models.APIKey, error)
GetKeysByGroup(groupID uint) ([]models.APIKey, error)
CountByGroup(groupID uint) (int64, error)
// --- Many-to-many relationship management --- key_mapping
LinkKeysToGroup(groupID uint, keyIDs []uint) error
UnlinkKeysFromGroup(groupID uint, keyIDs []uint) (unlinkedCount int64, err error)
GetGroupsForKey(keyID uint) ([]uint, error)
LinkKeysToGroup(ctx context.Context, groupID uint, keyIDs []uint) error
UnlinkKeysFromGroup(ctx context.Context, groupID uint, keyIDs []uint) (unlinkedCount int64, err error)
GetGroupsForKey(ctx context.Context, keyID uint) ([]uint, error)
GetMapping(groupID, keyID uint) (*models.GroupAPIKeyMapping, error)
UpdateMapping(mapping *models.GroupAPIKeyMapping) error
UpdateMapping(ctx context.Context, mapping *models.GroupAPIKeyMapping) error
GetPaginatedKeysAndMappingsByGroup(params *models.APIKeyQueryParams) ([]*models.APIKeyDetails, int64, error)
GetKeysByValuesAndGroupID(values []string, groupID uint) ([]models.APIKey, error)
FindKeyValuesByStatus(groupID uint, statuses []string) ([]string, error)
@@ -55,8 +57,8 @@ type KeyRepository interface {
UpdateMappingWithoutCache(mapping *models.GroupAPIKeyMapping) error
// --- Cache management --- key_cache
LoadAllKeysToStore() error
HandleCacheUpdateEventBatch(mappings []*models.GroupAPIKeyMapping) error
LoadAllKeysToStore(ctx context.Context) error
HandleCacheUpdateEventBatch(ctx context.Context, mappings []*models.GroupAPIKeyMapping) error
// --- Maintenance & background tasks --- key_maintenance
StreamKeysToWriter(groupID uint, statusFilter string, writer io.Writer) error
@@ -65,16 +67,14 @@ type KeyRepository interface {
DeleteOrphanKeys() (int64, error)
DeleteOrphanKeysTx(tx *gorm.DB) (int64, error)
GetActiveMasterKeys() ([]*models.APIKey, error)
UpdateAPIKeyStatus(keyID uint, status models.MasterAPIKeyStatus) error
UpdateAPIKeyStatus(ctx context.Context, keyID uint, status models.MasterAPIKeyStatus) error
HardDeleteSoftDeletedBefore(date time.Time) (int64, error)
// --- Polling-strategy "write" operations --- key_writer
UpdateKeyUsageTimestamp(groupID, keyID uint)
// Synchronously update the cache, for core business paths
SyncKeyStatusInPollingCaches(groupID, keyID uint, newStatus models.APIKeyStatus)
// Asynchronously update the cache, for event subscribers
HandleCacheUpdateEvent(groupID, keyID uint, newStatus models.APIKeyStatus)
UpdateKeyStatusAfterRequest(group *models.KeyGroup, key *models.APIKey, success bool, apiErr *errors.APIError)
UpdateKeyUsageTimestamp(ctx context.Context, groupID, keyID uint)
SyncKeyStatusInPollingCaches(ctx context.Context, groupID, keyID uint, newStatus models.APIKeyStatus)
HandleCacheUpdateEvent(ctx context.Context, groupID, keyID uint, newStatus models.APIKeyStatus)
UpdateKeyStatusAfterRequest(ctx context.Context, group *models.KeyGroup, key *models.APIKey, success bool, apiErr *errors.APIError)
}
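A minimal caller sketch showing how the request context now threads through the selector API; the pickKey helper and its shape are hypothetical:
func pickKey(ctx context.Context, repo KeyRepository, group *models.KeyGroup) (*models.APIKey, error) {
	key, _, err := repo.SelectOneActiveKey(ctx, group)
	if err != nil {
		return nil, err
	}
	return key, nil
}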
type GroupRepository interface {
@@ -88,18 +88,20 @@ type gormKeyRepository struct {
store store.Store
logger *logrus.Entry
crypto *crypto.Service
config *config.Config
}
type gormGroupRepository struct {
db *gorm.DB
}
func NewKeyRepository(db *gorm.DB, s store.Store, logger *logrus.Logger, crypto *crypto.Service) KeyRepository {
func NewKeyRepository(db *gorm.DB, s store.Store, logger *logrus.Logger, crypto *crypto.Service, cfg *config.Config) KeyRepository {
return &gormKeyRepository{
db: db,
store: s,
logger: logger.WithField("component", "repository.key🔗"),
crypto: crypto,
config: cfg,
}
}

View File

@@ -2,6 +2,7 @@
package scheduler
import (
"context"
"gemini-balancer/internal/repository"
"gemini-balancer/internal/service"
"time"
@@ -15,7 +16,6 @@ type Scheduler struct {
logger *logrus.Entry
statsService *service.StatsService
keyRepo repository.KeyRepository
// healthCheckService *service.HealthCheckService // reserved for the health-check job
}
func NewScheduler(statsSvc *service.StatsService, keyRepo repository.KeyRepository, logger *logrus.Logger) *Scheduler {
@@ -32,11 +32,13 @@ func NewScheduler(statsSvc *service.StatsService, keyRepo repository.KeyReposito
func (s *Scheduler) Start() {
s.logger.Info("Starting scheduler and registering jobs...")
// --- Job registration ---
// Job 1: hourly stats aggregation
// The cron expression pins execution to minute 5 of every hour
_, err := s.gocronScheduler.Cron("5 * * * *").Tag("stats-aggregation").Do(func() {
s.logger.Info("Executing hourly request stats aggregation...")
if err := s.statsService.AggregateHourlyStats(); err != nil {
// Create a fresh, empty context for this background scheduled job
ctx := context.Background()
if err := s.statsService.AggregateHourlyStats(ctx); err != nil {
s.logger.WithError(err).Error("Hourly stats aggregation failed.")
} else {
s.logger.Info("Hourly stats aggregation completed successfully.")
@@ -46,23 +48,14 @@ func (s *Scheduler) Start() {
s.logger.Errorf("Failed to schedule [stats-aggregation]: %v", err)
}
// Job 2: (reserved) automatic health check (e.g. every 10 minutes)
/*
_, err = s.gocronScheduler.Every(10).Minutes().Tag("auto-health-check").Do(func() {
s.logger.Info("Executing periodic health check for all groups...")
// s.healthCheckService.StartGlobalCheckTask() // pseudocode
})
if err != nil {
s.logger.Errorf("Failed to schedule [auto-health-check]: %v", err)
}
*/
// [NEW] --- Job 3: clean up soft-deleted API keys ---
// Job 2: (reserved) automatic health check
// Job 3: daily cleanup of soft-deleted keys
// Executes once daily at 3:15 AM UTC.
_, err = s.gocronScheduler.Cron("15 3 * * *").Tag("cleanup-soft-deleted-keys").Do(func() {
s.logger.Info("Executing daily cleanup of soft-deleted API keys...")
// Let's assume a retention period of 7 days for now.
// In a real scenario, this should come from settings.
// [Assumption] Retain for 7 days; in practice this should come from configuration
const retentionDays = 7
count, err := s.keyRepo.HardDeleteSoftDeletedBefore(time.Now().AddDate(0, 0, -retentionDays))
@@ -77,9 +70,8 @@ func (s *Scheduler) Start() {
if err != nil {
s.logger.Errorf("Failed to schedule [cleanup-soft-deleted-keys]: %v", err)
}
// --- End of job registration ---
s.gocronScheduler.StartAsync() // start asynchronously; does not block the application's main thread
s.gocronScheduler.StartAsync()
s.logger.Info("Scheduler started.")
}
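For reference, the two cron expressions registered above read, in standard five-field order (minute hour day-of-month month day-of-week): "5 * * * *" fires at minute 5 of every hour, and "15 3 * * *" fires daily at 03:15 UTC. The reserved health-check job (job 2) could later be wired up with the same gocron pattern already used here; a sketch based on the commented-out block this diff removes:
_, err = s.gocronScheduler.Every(10).Minutes().Tag("auto-health-check").Do(func() {
	s.logger.Info("Executing periodic health check for all groups...")
	// s.healthCheckService.StartGlobalCheckTask() // pseudocode, service not wired in yet
})
if err != nil {
	s.logger.Errorf("Failed to schedule [auto-health-check]: %v", err)
}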

View File

@@ -1,8 +1,8 @@
// Filename: internal/service/analytics_service.go
package service
import (
"context"
"encoding/json"
"fmt"
"gemini-balancer/internal/db/dialect"
@@ -43,7 +43,7 @@ func NewAnalyticsService(db *gorm.DB, s store.Store, logger *logrus.Logger, d di
}
func (s *AnalyticsService) Start() {
s.wg.Add(2) // 2 (flushLoop, eventListener)
s.wg.Add(2)
go s.flushLoop()
go s.eventListener()
s.logger.Info("AnalyticsService (Command Side) started.")
@@ -53,13 +53,13 @@ func (s *AnalyticsService) Stop() {
close(s.stopChan)
s.wg.Wait()
s.logger.Info("AnalyticsService stopped. Performing final data flush...")
s.flushToDB() // flush to the DB before stopping
s.flushToDB()
s.logger.Info("AnalyticsService final data flush completed.")
}
func (s *AnalyticsService) eventListener() {
defer s.wg.Done()
sub, err := s.store.Subscribe(models.TopicRequestFinished)
sub, err := s.store.Subscribe(context.Background(), models.TopicRequestFinished)
if err != nil {
s.logger.Fatalf("Failed to subscribe to topic %s: %v", models.TopicRequestFinished, err)
return
@@ -87,9 +87,10 @@ func (s *AnalyticsService) handleAnalyticsEvent(event *models.RequestFinishedEve
if event.RequestLog.GroupID == nil {
return
}
ctx := context.Background()
key := fmt.Sprintf("analytics:hourly:%s", time.Now().UTC().Format("2006-01-02T15"))
fieldPrefix := fmt.Sprintf("%d:%s", *event.RequestLog.GroupID, event.RequestLog.ModelName)
pipe := s.store.Pipeline()
pipe := s.store.Pipeline(ctx)
pipe.HIncrBy(key, fieldPrefix+":requests", 1)
if event.RequestLog.IsSuccess {
pipe.HIncrBy(key, fieldPrefix+":success", 1)
@@ -120,6 +121,7 @@ func (s *AnalyticsService) flushLoop() {
}
func (s *AnalyticsService) flushToDB() {
ctx := context.Background()
now := time.Now().UTC()
keysToFlush := []string{
fmt.Sprintf("analytics:hourly:%s", now.Add(-1*time.Hour).Format("2006-01-02T15")),
@@ -127,7 +129,7 @@ func (s *AnalyticsService) flushToDB() {
}
for _, key := range keysToFlush {
data, err := s.store.HGetAll(key)
data, err := s.store.HGetAll(ctx, key)
if err != nil || len(data) == 0 {
continue
}
@@ -136,15 +138,15 @@ func (s *AnalyticsService) flushToDB() {
if len(statsToFlush) > 0 {
upsertClause := s.dialect.OnConflictUpdateAll(
[]string{"time", "group_id", "model_name"}, // conflict columns
[]string{"request_count", "success_count", "prompt_tokens", "completion_tokens"}, // update columns
[]string{"time", "group_id", "model_name"},
[]string{"request_count", "success_count", "prompt_tokens", "completion_tokens"},
)
err := s.db.Clauses(upsertClause).Create(&statsToFlush).Error
err := s.db.WithContext(ctx).Clauses(upsertClause).Create(&statsToFlush).Error
if err != nil {
s.logger.Errorf("Failed to flush analytics data for key %s: %v", key, err)
} else {
s.logger.Infof("Successfully flushed %d records from key %s.", len(statsToFlush), key)
_ = s.store.HDel(key, parsedFields...)
_ = s.store.HDel(ctx, key, parsedFields...)
}
}
}
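For illustration, the hourly hash key and field names produced by handleAnalyticsEvent and consumed by flushToDB above; the group ID and model name are example values:
groupID, model := uint(7), "gemini-1.5-pro" // example values
key := fmt.Sprintf("analytics:hourly:%s", time.Now().UTC().Format("2006-01-02T15"))
field := fmt.Sprintf("%d:%s:requests", groupID, model)
fmt.Println(key, field) // e.g. "analytics:hourly:2025-11-22T14" "7:gemini-1.5-pro:requests"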

View File

@@ -1,8 +1,8 @@
// Filename: internal/service/apikey_service.go
package service
import (
"context"
"encoding/json"
"errors"
"fmt"
@@ -29,7 +29,6 @@ const (
TaskTypeUpdateStatusByFilter = "update_status_by_filter"
)
// DTOs & Constants
const (
TestEndpoint = "https://generativelanguage.googleapis.com/v1beta/models"
)
@@ -83,7 +82,6 @@ func NewAPIKeyService(
gm *GroupManager,
logger *logrus.Logger,
) *APIKeyService {
logger.Debugf("[FORENSIC PROBE | INJECTION | APIKeyService] Received a KeyRepository. Fingerprint: %p", repo)
return &APIKeyService{
db: db,
keyRepo: repo,
@@ -99,22 +97,22 @@ func NewAPIKeyService(
}
func (s *APIKeyService) Start() {
requestSub, err := s.store.Subscribe(models.TopicRequestFinished)
requestSub, err := s.store.Subscribe(context.Background(), models.TopicRequestFinished)
if err != nil {
s.logger.Fatalf("Failed to subscribe to topic %s: %v", models.TopicRequestFinished, err)
return
}
masterKeySub, err := s.store.Subscribe(models.TopicMasterKeyStatusChanged)
masterKeySub, err := s.store.Subscribe(context.Background(), models.TopicMasterKeyStatusChanged)
if err != nil {
s.logger.Fatalf("Failed to subscribe to topic %s: %v", models.TopicMasterKeyStatusChanged, err)
return
}
keyStatusSub, err := s.store.Subscribe(models.TopicKeyStatusChanged)
keyStatusSub, err := s.store.Subscribe(context.Background(), models.TopicKeyStatusChanged)
if err != nil {
s.logger.Fatalf("Failed to subscribe to topic %s: %v", models.TopicKeyStatusChanged, err)
return
}
importSub, err := s.store.Subscribe(models.TopicImportGroupCompleted)
importSub, err := s.store.Subscribe(context.Background(), models.TopicImportGroupCompleted)
if err != nil {
s.logger.Fatalf("Failed to subscribe to topic %s: %v", models.TopicImportGroupCompleted, err)
return
@@ -177,6 +175,7 @@ func (s *APIKeyService) handleKeyUsageEvent(event *models.RequestFinishedEvent)
if event.RequestLog.KeyID == nil || event.RequestLog.GroupID == nil {
return
}
ctx := context.Background()
if event.RequestLog.IsSuccess {
mapping, err := s.keyRepo.GetMapping(*event.RequestLog.GroupID, *event.RequestLog.KeyID)
if err != nil {
@@ -194,17 +193,18 @@ func (s *APIKeyService) handleKeyUsageEvent(event *models.RequestFinishedEvent)
now := time.Now()
mapping.LastUsedAt = &now
if err := s.keyRepo.UpdateMapping(mapping); err != nil {
if err := s.keyRepo.UpdateMapping(ctx, mapping); err != nil {
s.logger.Errorf("[%s] Failed to update mapping for G:%d K:%d after successful request: %v", event.CorrelationID, *event.RequestLog.GroupID, *event.RequestLog.KeyID, err)
return
}
if statusChanged {
go s.publishStatusChangeEvent(*event.RequestLog.GroupID, *event.RequestLog.KeyID, oldStatus, models.StatusActive, "key_recovered_after_use")
go s.publishStatusChangeEvent(ctx, *event.RequestLog.GroupID, *event.RequestLog.KeyID, oldStatus, models.StatusActive, "key_recovered_after_use")
}
return
}
if event.Error != nil {
s.judgeKeyErrors(
ctx,
event.CorrelationID,
*event.RequestLog.GroupID,
*event.RequestLog.KeyID,
@@ -215,6 +215,7 @@ func (s *APIKeyService) handleKeyUsageEvent(event *models.RequestFinishedEvent)
}
func (s *APIKeyService) handleKeyStatusChangeEvent(event *models.KeyStatusChangedEvent) {
ctx := context.Background()
log := s.logger.WithFields(logrus.Fields{
"group_id": event.GroupID,
"key_id": event.KeyID,
@@ -222,11 +223,11 @@ func (s *APIKeyService) handleKeyStatusChangeEvent(event *models.KeyStatusChange
"reason": event.ChangeReason,
})
log.Info("Received KeyStatusChangedEvent, will update polling caches.")
s.keyRepo.HandleCacheUpdateEvent(event.GroupID, event.KeyID, event.NewStatus)
s.keyRepo.HandleCacheUpdateEvent(ctx, event.GroupID, event.KeyID, event.NewStatus)
log.Info("Polling caches updated based on health check event.")
}
func (s *APIKeyService) publishStatusChangeEvent(groupID, keyID uint, oldStatus, newStatus models.APIKeyStatus, reason string) {
func (s *APIKeyService) publishStatusChangeEvent(ctx context.Context, groupID, keyID uint, oldStatus, newStatus models.APIKeyStatus, reason string) {
changeEvent := models.KeyStatusChangedEvent{
KeyID: keyID,
GroupID: groupID,
@@ -236,13 +237,12 @@ func (s *APIKeyService) publishStatusChangeEvent(groupID, keyID uint, oldStatus,
ChangedAt: time.Now(),
}
eventData, _ := json.Marshal(changeEvent)
if err := s.store.Publish(models.TopicKeyStatusChanged, eventData); err != nil {
if err := s.store.Publish(ctx, models.TopicKeyStatusChanged, eventData); err != nil {
s.logger.Errorf("Failed to publish key status changed event for group %d: %v", groupID, err)
}
}
func (s *APIKeyService) ListAPIKeys(params *models.APIKeyQueryParams) (*PaginatedAPIKeys, error) {
// --- Path 1: High-performance DB pagination (no keyword) ---
func (s *APIKeyService) ListAPIKeys(ctx context.Context, params *models.APIKeyQueryParams) (*PaginatedAPIKeys, error) {
if params.Keyword == "" {
items, total, err := s.keyRepo.GetPaginatedKeysAndMappingsByGroup(params)
if err != nil {
@@ -260,14 +260,12 @@ func (s *APIKeyService) ListAPIKeys(params *models.APIKeyQueryParams) (*Paginate
TotalPages: totalPages,
}, nil
}
// --- Path 2: In-memory search (keyword present) ---
s.logger.Infof("Performing heavy in-memory search for group %d with keyword '%s'", params.KeyGroupID, params.Keyword)
// To get all keys, we fetch all IDs first, then get their full details.
var statusesToFilter []string
if params.Status != "" {
statusesToFilter = append(statusesToFilter, params.Status)
} else {
statusesToFilter = append(statusesToFilter, "all") // "all" gets every status
statusesToFilter = append(statusesToFilter, "all")
}
allKeyIDs, err := s.keyRepo.FindKeyIDsByStatus(params.KeyGroupID, statusesToFilter)
if err != nil {
@@ -277,14 +275,12 @@ func (s *APIKeyService) ListAPIKeys(params *models.APIKeyQueryParams) (*Paginate
return &PaginatedAPIKeys{Items: []*models.APIKeyDetails{}, Total: 0, Page: 1, PageSize: params.PageSize, TotalPages: 0}, nil
}
// This is the heavy operation: getting all keys and decrypting them.
allKeys, err := s.keyRepo.GetKeysByIDs(allKeyIDs)
if err != nil {
return nil, fmt.Errorf("failed to fetch all key details for in-memory search: %w", err)
}
// We also need mappings to build the final `APIKeyDetails`.
var allMappings []models.GroupAPIKeyMapping
err = s.db.Where("key_group_id = ? AND api_key_id IN ?", params.KeyGroupID, allKeyIDs).Find(&allMappings).Error
err = s.db.WithContext(ctx).Where("key_group_id = ? AND api_key_id IN ?", params.KeyGroupID, allKeyIDs).Find(&allMappings).Error
if err != nil {
return nil, fmt.Errorf("failed to fetch mappings for in-memory search: %w", err)
}
@@ -292,7 +288,6 @@ func (s *APIKeyService) ListAPIKeys(params *models.APIKeyQueryParams) (*Paginate
for i := range allMappings {
mappingMap[allMappings[i].APIKeyID] = &allMappings[i]
}
// Filter the results in memory.
var filteredItems []*models.APIKeyDetails
for _, key := range allKeys {
if strings.Contains(key.APIKey, params.Keyword) {
@@ -312,11 +307,9 @@ func (s *APIKeyService) ListAPIKeys(params *models.APIKeyQueryParams) (*Paginate
}
}
}
// Sort the filtered results to ensure consistent pagination (by ID descending).
sort.Slice(filteredItems, func(i, j int) bool {
return filteredItems[i].ID > filteredItems[j].ID
})
// Manually paginate the filtered results.
total := int64(len(filteredItems))
start := (params.Page - 1) * params.PageSize
end := start + params.PageSize
@@ -345,14 +338,15 @@ func (s *APIKeyService) ListAPIKeys(params *models.APIKeyQueryParams) (*Paginate
}, nil
}
func (s *APIKeyService) GetKeysByIds(ids []uint) ([]models.APIKey, error) {
func (s *APIKeyService) GetKeysByIds(ctx context.Context, ids []uint) ([]models.APIKey, error) {
return s.keyRepo.GetKeysByIDs(ids)
}
func (s *APIKeyService) UpdateAPIKey(key *models.APIKey) error {
func (s *APIKeyService) UpdateAPIKey(ctx context.Context, key *models.APIKey) error {
go func() {
bgCtx := context.Background()
var oldKey models.APIKey
if err := s.db.First(&oldKey, key.ID).Error; err != nil {
if err := s.db.WithContext(bgCtx).First(&oldKey, key.ID).Error; err != nil {
s.logger.Errorf("Failed to find old key state for ID %d before update: %v", key.ID, err)
return
}
@@ -364,16 +358,14 @@ func (s *APIKeyService) UpdateAPIKey(key *models.APIKey) error {
return nil
}
func (s *APIKeyService) HardDeleteAPIKeyByID(id uint) error {
// Get all associated groups before deletion to publish correct events
groups, err := s.keyRepo.GetGroupsForKey(id)
func (s *APIKeyService) HardDeleteAPIKeyByID(ctx context.Context, id uint) error {
groups, err := s.keyRepo.GetGroupsForKey(ctx, id)
if err != nil {
s.logger.Warnf("Could not get groups for key ID %d before hard deletion (key might be orphaned): %v", id, err)
}
err = s.keyRepo.HardDeleteByID(id)
if err == nil {
// Publish events for each group the key was a part of
for _, groupID := range groups {
event := models.KeyStatusChangedEvent{
KeyID: id,
@@ -381,13 +373,13 @@ func (s *APIKeyService) HardDeleteAPIKeyByID(id uint) error {
ChangeReason: "key_hard_deleted",
}
eventData, _ := json.Marshal(event)
go s.store.Publish(models.TopicKeyStatusChanged, eventData)
go s.store.Publish(context.Background(), models.TopicKeyStatusChanged, eventData)
}
}
return err
}
func (s *APIKeyService) UpdateMappingStatus(groupID, keyID uint, newStatus models.APIKeyStatus) (*models.GroupAPIKeyMapping, error) {
func (s *APIKeyService) UpdateMappingStatus(ctx context.Context, groupID, keyID uint, newStatus models.APIKeyStatus) (*models.GroupAPIKeyMapping, error) {
key, err := s.keyRepo.GetKeyByID(keyID)
if err != nil {
return nil, CustomErrors.ParseDBError(err)
@@ -409,19 +401,20 @@ func (s *APIKeyService) UpdateMappingStatus(groupID, keyID uint, newStatus model
mapping.ConsecutiveErrorCount = 0
mapping.LastError = ""
}
if err := s.keyRepo.UpdateMapping(mapping); err != nil {
if err := s.keyRepo.UpdateMapping(ctx, mapping); err != nil {
return nil, err
}
go s.publishStatusChangeEvent(groupID, keyID, oldStatus, newStatus, "manual_update")
go s.publishStatusChangeEvent(ctx, groupID, keyID, oldStatus, newStatus, "manual_update")
return mapping, nil
}
func (s *APIKeyService) handleMasterKeyStatusChangeEvent(event *models.MasterKeyStatusChangedEvent) {
ctx := context.Background()
s.logger.Infof("Received MasterKeyStatusChangedEvent for Key ID %d: %s -> %s", event.KeyID, event.OldMasterStatus, event.NewMasterStatus)
if event.NewMasterStatus != models.MasterStatusRevoked {
return
}
affectedGroupIDs, err := s.keyRepo.GetGroupsForKey(event.KeyID)
affectedGroupIDs, err := s.keyRepo.GetGroupsForKey(ctx, event.KeyID)
if err != nil {
s.logger.WithError(err).Errorf("Failed to get groups for key ID %d to propagate revoked status.", event.KeyID)
return
@@ -432,7 +425,7 @@ func (s *APIKeyService) handleMasterKeyStatusChangeEvent(event *models.MasterKey
}
s.logger.Infof("Propagating REVOKED astatus for Key ID %d to %d groups.", event.KeyID, len(affectedGroupIDs))
for _, groupID := range affectedGroupIDs {
_, err := s.UpdateMappingStatus(groupID, event.KeyID, models.StatusBanned)
_, err := s.UpdateMappingStatus(ctx, groupID, event.KeyID, models.StatusBanned)
if err != nil {
if !errors.Is(err, gorm.ErrRecordNotFound) {
s.logger.WithError(err).Errorf("Failed to update mapping status to BANNED for Key ID %d in Group ID %d.", event.KeyID, groupID)
@@ -441,32 +434,32 @@ func (s *APIKeyService) handleMasterKeyStatusChangeEvent(event *models.MasterKey
}
}
func (s *APIKeyService) StartRestoreKeysTask(groupID uint, keyIDs []uint) (*task.Status, error) {
func (s *APIKeyService) StartRestoreKeysTask(ctx context.Context, groupID uint, keyIDs []uint) (*task.Status, error) {
if len(keyIDs) == 0 {
return nil, CustomErrors.NewAPIError(CustomErrors.ErrBadRequest, "No key IDs provided for restoration.")
}
resourceID := fmt.Sprintf("group-%d", groupID)
taskStatus, err := s.taskService.StartTask(groupID, TaskTypeRestoreSpecificKeys, resourceID, len(keyIDs), time.Hour)
taskStatus, err := s.taskService.StartTask(ctx, groupID, TaskTypeRestoreSpecificKeys, resourceID, len(keyIDs), time.Hour)
if err != nil {
return nil, err
}
go s.runRestoreKeysTask(taskStatus.ID, resourceID, groupID, keyIDs)
go s.runRestoreKeysTask(ctx, taskStatus.ID, resourceID, groupID, keyIDs)
return taskStatus, nil
}
func (s *APIKeyService) runRestoreKeysTask(taskID string, resourceID string, groupID uint, keyIDs []uint) {
func (s *APIKeyService) runRestoreKeysTask(ctx context.Context, taskID string, resourceID string, groupID uint, keyIDs []uint) {
defer func() {
if r := recover(); r != nil {
s.logger.Errorf("Panic recovered in runRestoreKeysTask: %v", r)
s.taskService.EndTaskByID(taskID, resourceID, nil, fmt.Errorf("panic in restore keys task: %v", r))
s.taskService.EndTaskByID(ctx, taskID, resourceID, nil, fmt.Errorf("panic in restore keys task: %v", r))
}
}()
var mappingsToProcess []models.GroupAPIKeyMapping
err := s.db.Preload("APIKey").
err := s.db.WithContext(ctx).Preload("APIKey").
Where("key_group_id = ? AND api_key_id IN ?", groupID, keyIDs).
Find(&mappingsToProcess).Error
if err != nil {
s.taskService.EndTaskByID(taskID, resourceID, nil, err)
s.taskService.EndTaskByID(ctx, taskID, resourceID, nil, err)
return
}
result := &BatchRestoreResult{
@@ -476,7 +469,7 @@ func (s *APIKeyService) runRestoreKeysTask(taskID string, resourceID string, gro
processedCount := 0
for _, mapping := range mappingsToProcess {
processedCount++
_ = s.taskService.UpdateProgressByID(taskID, processedCount)
_ = s.taskService.UpdateProgressByID(ctx, taskID, processedCount)
if mapping.APIKey == nil {
result.SkippedCount++
result.SkippedKeys = append(result.SkippedKeys, SkippedKeyInfo{KeyID: mapping.APIKeyID, Reason: "Associated APIKey entity not found."})
@@ -492,33 +485,29 @@ func (s *APIKeyService) runRestoreKeysTask(taskID string, resourceID string, gro
mapping.Status = models.StatusActive
mapping.ConsecutiveErrorCount = 0
mapping.LastError = ""
// Use the version that doesn't trigger individual cache updates.
if err := s.keyRepo.UpdateMappingWithoutCache(&mapping); err != nil {
s.logger.WithError(err).Warn("Failed to update mapping in DB during restore task.")
result.SkippedCount++
result.SkippedKeys = append(result.SkippedKeys, SkippedKeyInfo{KeyID: mapping.APIKeyID, Reason: "DB update failed."})
} else {
result.RestoredCount++
successfulMappings = append(successfulMappings, &mapping) // Collect for batch cache update.
go s.publishStatusChangeEvent(groupID, mapping.APIKeyID, oldStatus, models.StatusActive, "batch_restore")
successfulMappings = append(successfulMappings, &mapping)
go s.publishStatusChangeEvent(ctx, groupID, mapping.APIKeyID, oldStatus, models.StatusActive, "batch_restore")
}
} else {
result.RestoredCount++ // Already active, count as success.
result.RestoredCount++
}
}
// After the loop, perform one single, efficient cache update.
if err := s.keyRepo.HandleCacheUpdateEventBatch(successfulMappings); err != nil {
if err := s.keyRepo.HandleCacheUpdateEventBatch(ctx, successfulMappings); err != nil {
s.logger.WithError(err).Error("Failed to perform batch cache update after restore task.")
// This is not a task-fatal error, so we just log it and continue.
}
// Account for keys that were requested but not found in the initial DB query.
result.SkippedCount += (len(keyIDs) - len(mappingsToProcess))
s.taskService.EndTaskByID(taskID, resourceID, result, nil)
s.taskService.EndTaskByID(ctx, taskID, resourceID, result, nil)
}
func (s *APIKeyService) StartRestoreAllBannedTask(groupID uint) (*task.Status, error) {
func (s *APIKeyService) StartRestoreAllBannedTask(ctx context.Context, groupID uint) (*task.Status, error) {
var bannedKeyIDs []uint
err := s.db.Model(&models.GroupAPIKeyMapping{}).
err := s.db.WithContext(ctx).Model(&models.GroupAPIKeyMapping{}).
Where("key_group_id = ? AND status = ?", groupID, models.StatusBanned).
Pluck("api_key_id", &bannedKeyIDs).Error
if err != nil {
@@ -527,10 +516,11 @@ func (s *APIKeyService) StartRestoreAllBannedTask(groupID uint) (*task.Status, e
if len(bannedKeyIDs) == 0 {
return nil, CustomErrors.NewAPIError(CustomErrors.ErrBadRequest, "No banned keys to restore in this group.")
}
return s.StartRestoreKeysTask(groupID, bannedKeyIDs)
return s.StartRestoreKeysTask(ctx, groupID, bannedKeyIDs)
}
func (s *APIKeyService) handlePostImportValidation(event *models.ImportGroupCompletedEvent) {
ctx := context.Background()
group, ok := s.groupManager.GetGroupByID(event.GroupID)
if !ok {
s.logger.Errorf("Group with id %d not found during post-import validation, aborting.", event.GroupID)
@@ -552,7 +542,7 @@ func (s *APIKeyService) handlePostImportValidation(event *models.ImportGroupComp
concurrency = *opConfig.KeyCheckConcurrency
}
if concurrency <= 0 {
concurrency = 10 // Safety fallback
concurrency = 10
}
timeout := time.Duration(globalSettings.KeyCheckTimeoutSeconds) * time.Second
keysToValidate, err := s.keyRepo.GetKeysByIDs(event.KeyIDs)
@@ -571,7 +561,7 @@ func (s *APIKeyService) handlePostImportValidation(event *models.ImportGroupComp
validationErr := s.validationService.ValidateSingleKey(&key, timeout, endpoint)
if validationErr == nil {
s.logger.Infof("Key ID %d PASSED validation. Status -> ACTIVE.", key.ID)
if _, err := s.UpdateMappingStatus(event.GroupID, key.ID, models.StatusActive); err != nil {
if _, err := s.UpdateMappingStatus(ctx, event.GroupID, key.ID, models.StatusActive); err != nil {
s.logger.Errorf("Failed to update status to ACTIVE for Key ID %d in group %d: %v", key.ID, event.GroupID, err)
}
} else {
@@ -579,7 +569,7 @@ func (s *APIKeyService) handlePostImportValidation(event *models.ImportGroupComp
if !CustomErrors.As(validationErr, &apiErr) {
apiErr = &CustomErrors.APIError{Message: validationErr.Error()}
}
s.judgeKeyErrors("", event.GroupID, key.ID, apiErr, false)
s.judgeKeyErrors(ctx, "", event.GroupID, key.ID, apiErr, false)
}
}
}()
@@ -592,12 +582,9 @@ func (s *APIKeyService) handlePostImportValidation(event *models.ImportGroupComp
s.logger.Infof("Finished post-import validation for group %d.", event.GroupID)
}
// [NEW] StartUpdateStatusByFilterTask starts a background job to update the status of keys
// that match a specific set of source statuses within a group.
func (s *APIKeyService) StartUpdateStatusByFilterTask(groupID uint, sourceStatuses []string, newStatus models.APIKeyStatus) (*task.Status, error) {
func (s *APIKeyService) StartUpdateStatusByFilterTask(ctx context.Context, groupID uint, sourceStatuses []string, newStatus models.APIKeyStatus) (*task.Status, error) {
s.logger.Infof("Starting task to update status to '%s' for keys in group %d matching statuses: %v", newStatus, groupID, sourceStatuses)
// 1. Find key IDs using the new repository method. Using IDs is more efficient for updates.
keyIDs, err := s.keyRepo.FindKeyIDsByStatus(groupID, sourceStatuses)
if err != nil {
return nil, CustomErrors.ParseDBError(err)
@@ -605,35 +592,32 @@ func (s *APIKeyService) StartUpdateStatusByFilterTask(groupID uint, sourceStatus
if len(keyIDs) == 0 {
now := time.Now()
return &task.Status{
IsRunning: false, // The "task" is not running.
IsRunning: false,
Processed: 0,
Total: 0,
Result: map[string]string{ // We use the flexible Result field to pass the message.
Result: map[string]string{
"message": "没有找到任何符合当前过滤条件的Key可供操作。",
},
Error: "", // There is no error.
Error: "",
StartedAt: now,
FinishedAt: &now, // It started and finished at the same time.
}, nil // Return nil for the error, signaling a 200 OK.
FinishedAt: &now,
}, nil
}
// 2. Start a new task using the TaskService, following existing patterns.
resourceID := fmt.Sprintf("group-%d-status-update", groupID)
taskStatus, err := s.taskService.StartTask(groupID, TaskTypeUpdateStatusByFilter, resourceID, len(keyIDs), 30*time.Minute)
taskStatus, err := s.taskService.StartTask(ctx, groupID, TaskTypeUpdateStatusByFilter, resourceID, len(keyIDs), 30*time.Minute)
if err != nil {
return nil, err // Pass up errors like "task already in progress".
return nil, err
}
// 3. Run the core logic in a separate goroutine.
go s.runUpdateStatusByFilterTask(taskStatus.ID, resourceID, groupID, keyIDs, newStatus)
go s.runUpdateStatusByFilterTask(ctx, taskStatus.ID, resourceID, groupID, keyIDs, newStatus)
return taskStatus, nil
}
// [NEW] runUpdateStatusByFilterTask is the private worker function for the above task.
func (s *APIKeyService) runUpdateStatusByFilterTask(taskID, resourceID string, groupID uint, keyIDs []uint, newStatus models.APIKeyStatus) {
func (s *APIKeyService) runUpdateStatusByFilterTask(ctx context.Context, taskID, resourceID string, groupID uint, keyIDs []uint, newStatus models.APIKeyStatus) {
defer func() {
if r := recover(); r != nil {
s.logger.Errorf("Panic recovered in runUpdateStatusByFilterTask: %v", r)
s.taskService.EndTaskByID(taskID, resourceID, nil, fmt.Errorf("panic in update status task: %v", r))
s.taskService.EndTaskByID(ctx, taskID, resourceID, nil, fmt.Errorf("panic in update status task: %v", r))
}
}()
type BatchUpdateResult struct {
@@ -642,31 +626,27 @@ func (s *APIKeyService) runUpdateStatusByFilterTask(taskID, resourceID string, g
}
result := &BatchUpdateResult{}
var successfulMappings []*models.GroupAPIKeyMapping
// 1. Fetch all key master statuses in one go. This is efficient.
keys, err := s.keyRepo.GetKeysByIDs(keyIDs)
if err != nil {
s.logger.WithError(err).Error("Failed to fetch keys for bulk status update, aborting task.")
s.taskService.EndTaskByID(taskID, resourceID, nil, err)
s.taskService.EndTaskByID(ctx, taskID, resourceID, nil, err)
return
}
masterStatusMap := make(map[uint]models.MasterAPIKeyStatus)
for _, key := range keys {
masterStatusMap[key.ID] = key.MasterStatus
}
// 2. [THE REFINEMENT] Fetch all relevant mappings directly using s.db,
// avoiding the need for a new repository method. This pattern is
// already used in other parts of this service.
var mappings []*models.GroupAPIKeyMapping
if err := s.db.Where("key_group_id = ? AND api_key_id IN ?", groupID, keyIDs).Find(&mappings).Error; err != nil {
if err := s.db.WithContext(ctx).Where("key_group_id = ? AND api_key_id IN ?", groupID, keyIDs).Find(&mappings).Error; err != nil {
s.logger.WithError(err).Error("Failed to fetch mappings for bulk status update, aborting task.")
s.taskService.EndTaskByID(taskID, resourceID, nil, err)
s.taskService.EndTaskByID(ctx, taskID, resourceID, nil, err)
return
}
processedCount := 0
for _, mapping := range mappings {
processedCount++
// The progress update should reflect the number of items *being processed*, not the final count.
_ = s.taskService.UpdateProgressByID(taskID, processedCount)
_ = s.taskService.UpdateProgressByID(ctx, taskID, processedCount)
masterStatus, ok := masterStatusMap[mapping.APIKeyID]
if !ok {
result.SkippedCount++
@@ -688,24 +668,25 @@ func (s *APIKeyService) runUpdateStatusByFilterTask(taskID, resourceID string, g
} else {
result.UpdatedCount++
successfulMappings = append(successfulMappings, mapping)
go s.publishStatusChangeEvent(groupID, mapping.APIKeyID, oldStatus, newStatus, "batch_status_update")
go s.publishStatusChangeEvent(ctx, groupID, mapping.APIKeyID, oldStatus, newStatus, "batch_status_update")
}
} else {
result.UpdatedCount++ // Already in desired state, count as success.
result.UpdatedCount++
}
}
result.SkippedCount += (len(keyIDs) - len(mappings))
if err := s.keyRepo.HandleCacheUpdateEventBatch(successfulMappings); err != nil {
if err := s.keyRepo.HandleCacheUpdateEventBatch(ctx, successfulMappings); err != nil {
s.logger.WithError(err).Error("Failed to perform batch cache update after status update task.")
}
s.logger.Infof("Finished bulk status update task '%s' for group %d. Updated: %d, Skipped: %d.", taskID, groupID, result.UpdatedCount, result.SkippedCount)
s.taskService.EndTaskByID(taskID, resourceID, result, nil)
s.taskService.EndTaskByID(ctx, taskID, resourceID, result, nil)
}
func (s *APIKeyService) HandleRequestResult(group *models.KeyGroup, key *models.APIKey, success bool, apiErr *CustomErrors.APIError) {
ctx := context.Background()
if success {
if group.PollingStrategy == models.StrategyWeighted {
go s.keyRepo.UpdateKeyUsageTimestamp(group.ID, key.ID)
go s.keyRepo.UpdateKeyUsageTimestamp(ctx, group.ID, key.ID)
}
return
}
@@ -716,26 +697,20 @@ func (s *APIKeyService) HandleRequestResult(group *models.KeyGroup, key *models.
errMsg := apiErr.Message
if CustomErrors.IsPermanentUpstreamError(errMsg) || CustomErrors.IsTemporaryUpstreamError(errMsg) {
s.logger.Warnf("Request for KeyID %d in GroupID %d failed with a key-specific error: %s. Issuing order to move to cooldown.", key.ID, group.ID, errMsg)
go s.keyRepo.SyncKeyStatusInPollingCaches(group.ID, key.ID, models.StatusCooldown)
go s.keyRepo.SyncKeyStatusInPollingCaches(ctx, group.ID, key.ID, models.StatusCooldown)
} else {
s.logger.Infof("Request for KeyID %d in GroupID %d failed with a non-key-specific error. No punitive action taken. Error: %s", key.ID, group.ID, errMsg)
}
}
// sanitizeForLog truncates long error messages and removes JSON-like content for cleaner Info/Error logs.
func sanitizeForLog(errMsg string) string {
// Find the start of any potential JSON blob or detailed structure.
jsonStartIndex := strings.Index(errMsg, "{")
var cleanMsg string
if jsonStartIndex != -1 {
// If a '{' is found, take everything before it as the summary
// and append a simple placeholder.
cleanMsg = strings.TrimSpace(errMsg[:jsonStartIndex]) + " {...}"
} else {
// If no JSON-like structure is found, use the original message.
cleanMsg = errMsg
}
// Always apply a final length truncation as a safeguard against extremely long non-JSON errors.
const maxLen = 250
if len(cleanMsg) > maxLen {
return cleanMsg[:maxLen] + "..."
@@ -744,6 +719,7 @@ func sanitizeForLog(errMsg string) string {
}
func (s *APIKeyService) judgeKeyErrors(
ctx context.Context,
correlationID string,
groupID, keyID uint,
apiErr *CustomErrors.APIError,
@@ -765,11 +741,11 @@ func (s *APIKeyService) judgeKeyErrors(
oldStatus := mapping.Status
mapping.Status = models.StatusBanned
mapping.LastError = errorMessage
if err := s.keyRepo.UpdateMapping(mapping); err != nil {
if err := s.keyRepo.UpdateMapping(ctx, mapping); err != nil {
logger.WithError(err).Error("Failed to update mapping status to BANNED.")
} else {
go s.publishStatusChangeEvent(groupID, keyID, oldStatus, mapping.Status, "permanent_error_banned")
go s.revokeMasterKey(keyID, "permanent_upstream_error")
go s.publishStatusChangeEvent(ctx, groupID, keyID, oldStatus, mapping.Status, "permanent_error_banned")
go s.revokeMasterKey(ctx, keyID, "permanent_upstream_error")
}
}
return
@@ -801,23 +777,23 @@ func (s *APIKeyService) judgeKeyErrors(
if oldStatus != newStatus {
mapping.Status = newStatus
}
if err := s.keyRepo.UpdateMapping(mapping); err != nil {
if err := s.keyRepo.UpdateMapping(ctx, mapping); err != nil {
logger.WithError(err).Error("Failed to update mapping after temporary error.")
return
}
if oldStatus != newStatus {
go s.publishStatusChangeEvent(groupID, keyID, oldStatus, newStatus, "error_threshold_reached")
go s.publishStatusChangeEvent(ctx, groupID, keyID, oldStatus, newStatus, "error_threshold_reached")
}
return
}
logger.Infof("Ignored truly ignorable upstream error. Only updating LastUsedAt. Reason: %s", sanitizeForLog(errorMessage))
logger.WithField("full_error_details", errorMessage).Debug("Full details of the ignorable error.")
if err := s.keyRepo.UpdateMapping(mapping); err != nil {
if err := s.keyRepo.UpdateMapping(ctx, mapping); err != nil {
logger.WithError(err).Error("Failed to update LastUsedAt for ignorable error.")
}
}
func (s *APIKeyService) revokeMasterKey(keyID uint, reason string) {
func (s *APIKeyService) revokeMasterKey(ctx context.Context, keyID uint, reason string) {
key, err := s.keyRepo.GetKeyByID(keyID)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
@@ -832,7 +808,7 @@ func (s *APIKeyService) revokeMasterKey(keyID uint, reason string) {
}
oldMasterStatus := key.MasterStatus
newMasterStatus := models.MasterStatusRevoked
if err := s.keyRepo.UpdateMasterStatusByID(keyID, newMasterStatus); err != nil {
if err := s.keyRepo.UpdateAPIKeyStatus(ctx, keyID, newMasterStatus); err != nil {
s.logger.Errorf("Failed to update master status to REVOKED for key ID %d: %v", keyID, err)
return
}
@@ -844,9 +820,9 @@ func (s *APIKeyService) revokeMasterKey(keyID uint, reason string) {
ChangedAt: time.Now(),
}
eventData, _ := json.Marshal(masterKeyEvent)
_ = s.store.Publish(models.TopicMasterKeyStatusChanged, eventData)
_ = s.store.Publish(ctx, models.TopicMasterKeyStatusChanged, eventData)
}
func (s *APIKeyService) GetAPIKeyStringsForExport(groupID uint, statuses []string) ([]string, error) {
func (s *APIKeyService) GetAPIKeyStringsForExport(ctx context.Context, groupID uint, statuses []string) ([]string, error) {
return s.keyRepo.GetKeyStringsByGroupAndStatus(groupID, statuses)
}

View File

@@ -1,8 +1,8 @@
// Filename: internal/service/dashboard_query_service.go
package service
import (
"context"
"fmt"
"gemini-balancer/internal/models"
"gemini-balancer/internal/store"
@@ -17,8 +17,6 @@ import (
const overviewCacheChannel = "syncer:cache:dashboard_overview"
// DashboardQueryService handles all dashboard data queries served to the frontend.
type DashboardQueryService struct {
db *gorm.DB
store store.Store
@@ -54,9 +52,9 @@ func (s *DashboardQueryService) Stop() {
s.logger.Info("DashboardQueryService and its CacheSyncer have been stopped.")
}
func (s *DashboardQueryService) GetGroupStats(groupID uint) (map[string]any, error) {
func (s *DashboardQueryService) GetGroupStats(ctx context.Context, groupID uint) (map[string]any, error) {
statsKey := fmt.Sprintf("stats:group:%d", groupID)
keyStatsMap, err := s.store.HGetAll(statsKey)
keyStatsMap, err := s.store.HGetAll(ctx, statsKey)
if err != nil {
s.logger.WithError(err).Errorf("Failed to get key stats from cache for group %d", groupID)
return nil, fmt.Errorf("failed to get key stats from cache: %w", err)
@@ -74,11 +72,11 @@ func (s *DashboardQueryService) GetGroupStats(groupID uint) (map[string]any, err
SuccessRequests int64
}
var last1Hour, last24Hours requestStatsResult
s.db.Model(&models.StatsHourly{}).
s.db.WithContext(ctx).Model(&models.StatsHourly{}).
Select("SUM(request_count) as total_requests, SUM(success_count) as success_requests").
Where("group_id = ? AND time >= ?", groupID, oneHourAgo).
Scan(&last1Hour)
s.db.Model(&models.StatsHourly{}).
s.db.WithContext(ctx).Model(&models.StatsHourly{}).
Select("SUM(request_count) as total_requests, SUM(success_count) as success_requests").
Where("group_id = ? AND time >= ?", groupID, twentyFourHoursAgo).
Scan(&last24Hours)
@@ -109,8 +107,9 @@ func (s *DashboardQueryService) GetGroupStats(groupID uint) (map[string]any, err
}
func (s *DashboardQueryService) eventListener() {
keyStatusSub, _ := s.store.Subscribe(models.TopicKeyStatusChanged)
upstreamStatusSub, _ := s.store.Subscribe(models.TopicUpstreamHealthChanged)
ctx := context.Background()
keyStatusSub, _ := s.store.Subscribe(ctx, models.TopicKeyStatusChanged)
upstreamStatusSub, _ := s.store.Subscribe(ctx, models.TopicUpstreamHealthChanged)
defer keyStatusSub.Close()
defer upstreamStatusSub.Close()
for {
@@ -128,7 +127,6 @@ func (s *DashboardQueryService) eventListener() {
}
}
// GetDashboardOverviewData fetches dashboard overview data quickly from the Syncer cache.
func (s *DashboardQueryService) GetDashboardOverviewData() (*models.DashboardStatsResponse, error) {
cachedDataPtr := s.overviewSyncer.Get()
if cachedDataPtr == nil {
@@ -141,8 +139,7 @@ func (s *DashboardQueryService) InvalidateOverviewCache() error {
return s.overviewSyncer.Invalidate()
}
// QueryHistoricalChart queries historical chart data.
func (s *DashboardQueryService) QueryHistoricalChart(groupID *uint) (*models.ChartData, error) {
func (s *DashboardQueryService) QueryHistoricalChart(ctx context.Context, groupID *uint) (*models.ChartData, error) {
type ChartPoint struct {
TimeLabel string `gorm:"column:time_label"`
ModelName string `gorm:"column:model_name"`
@@ -151,7 +148,7 @@ func (s *DashboardQueryService) QueryHistoricalChart(groupID *uint) (*models.Cha
sevenDaysAgo := time.Now().Add(-24 * 7 * time.Hour).Truncate(time.Hour)
sqlFormat, goFormat := s.buildTimeFormatSelectClause()
selectClause := fmt.Sprintf("%s as time_label, model_name, SUM(request_count) as total_requests", sqlFormat)
query := s.db.Model(&models.StatsHourly{}).Select(selectClause).Where("time >= ?", sevenDaysAgo).Group("time_label, model_name").Order("time_label ASC")
query := s.db.WithContext(ctx).Model(&models.StatsHourly{}).Select(selectClause).Where("time >= ?", sevenDaysAgo).Group("time_label, model_name").Order("time_label ASC")
if groupID != nil && *groupID > 0 {
query = query.Where("group_id = ?", *groupID)
}
@@ -189,38 +186,38 @@ func (s *DashboardQueryService) QueryHistoricalChart(groupID *uint) (*models.Cha
}
func (s *DashboardQueryService) loadOverviewData() (*models.DashboardStatsResponse, error) {
ctx := context.Background()
s.logger.Info("[CacheSyncer] Starting to load overview data from database...")
startTime := time.Now()
resp := &models.DashboardStatsResponse{
KeyStatusCount: make(map[models.APIKeyStatus]int64),
MasterStatusCount: make(map[models.MasterAPIKeyStatus]int64),
KeyCount: models.StatCard{}, // ensure KeyCount is an empty struct rather than nil
RequestCount24h: models.StatCard{}, // same as above
KeyCount: models.StatCard{},
RequestCount24h: models.StatCard{},
TokenCount: make(map[string]any),
UpstreamHealthStatus: make(map[string]string),
RPM: models.StatCard{},
RequestCounts: make(map[string]int64),
}
// --- 1. Aggregate Operational Status from Mappings ---
type MappingStatusResult struct {
Status models.APIKeyStatus
Count int64
}
var mappingStatusResults []MappingStatusResult
if err := s.db.Model(&models.GroupAPIKeyMapping{}).Select("status, count(*) as count").Group("status").Find(&mappingStatusResults).Error; err != nil {
if err := s.db.WithContext(ctx).Model(&models.GroupAPIKeyMapping{}).Select("status, count(*) as count").Group("status").Find(&mappingStatusResults).Error; err != nil {
return nil, fmt.Errorf("failed to query mapping status stats: %w", err)
}
for _, res := range mappingStatusResults {
resp.KeyStatusCount[res.Status] = res.Count
}
// --- 2. Aggregate Master Status from APIKeys ---
type MasterStatusResult struct {
Status models.MasterAPIKeyStatus
Count int64
}
var masterStatusResults []MasterStatusResult
if err := s.db.Model(&models.APIKey{}).Select("master_status as status, count(*) as count").Group("master_status").Find(&masterStatusResults).Error; err != nil {
if err := s.db.WithContext(ctx).Model(&models.APIKey{}).Select("master_status as status, count(*) as count").Group("master_status").Find(&masterStatusResults).Error; err != nil {
return nil, fmt.Errorf("failed to query master status stats: %w", err)
}
var totalKeys, invalidKeys int64
@@ -235,20 +232,15 @@ func (s *DashboardQueryService) loadOverviewData() (*models.DashboardStatsRespon
now := time.Now()
// 1. RPM (1 minute), RPH (1 hour), RPD (today): queried precisely from "short-term memory" (request_logs)
var count1m, count1h, count1d int64
// RPM: the last 1 minute counting back from now
s.db.Model(&models.RequestLog{}).Where("request_time >= ?", now.Add(-1*time.Minute)).Count(&count1m)
// RPH: the last 1 hour counting back from now
s.db.Model(&models.RequestLog{}).Where("request_time >= ?", now.Add(-1*time.Hour)).Count(&count1h)
// RPD: from today's midnight (UTC) until now
s.db.WithContext(ctx).Model(&models.RequestLog{}).Where("request_time >= ?", now.Add(-1*time.Minute)).Count(&count1m)
s.db.WithContext(ctx).Model(&models.RequestLog{}).Where("request_time >= ?", now.Add(-1*time.Hour)).Count(&count1h)
year, month, day := now.UTC().Date()
startOfDay := time.Date(year, month, day, 0, 0, 0, 0, time.UTC)
s.db.Model(&models.RequestLog{}).Where("request_time >= ?", startOfDay).Count(&count1d)
// 2. RP30D (30 days): queried efficiently from "long-term memory" (stats_hourly) to preserve performance
s.db.WithContext(ctx).Model(&models.RequestLog{}).Where("request_time >= ?", startOfDay).Count(&count1d)
var count30d int64
s.db.Model(&models.StatsHourly{}).Where("time >= ?", now.AddDate(0, 0, -30)).Select("COALESCE(SUM(request_count), 0)").Scan(&count30d)
s.db.WithContext(ctx).Model(&models.StatsHourly{}).Where("time >= ?", now.AddDate(0, 0, -30)).Select("COALESCE(SUM(request_count), 0)").Scan(&count30d)
resp.RequestCounts["1m"] = count1m
resp.RequestCounts["1h"] = count1h
@@ -256,7 +248,7 @@ func (s *DashboardQueryService) loadOverviewData() (*models.DashboardStatsRespon
resp.RequestCounts["30d"] = count30d
var upstreams []*models.UpstreamEndpoint
if err := s.db.Find(&upstreams).Error; err != nil {
if err := s.db.WithContext(ctx).Find(&upstreams).Error; err != nil {
s.logger.WithError(err).Warn("Failed to load upstream statuses for dashboard.")
} else {
for _, u := range upstreams {
@@ -269,7 +261,7 @@ func (s *DashboardQueryService) loadOverviewData() (*models.DashboardStatsRespon
return resp, nil
}
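The recurring pattern in this file is to bind the caller's context to the GORM session so cancellation and deadlines propagate into the SQL queries; a minimal hypothetical helper following that pattern:
func (s *DashboardQueryService) countLogsSince(ctx context.Context, since time.Time) (int64, error) {
	var total int64
	err := s.db.WithContext(ctx).
		Model(&models.RequestLog{}).
		Where("request_time >= ?", since).
		Count(&total).Error
	return total, err
}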
func (s *DashboardQueryService) GetRequestStatsForPeriod(period string) (gin.H, error) {
func (s *DashboardQueryService) GetRequestStatsForPeriod(ctx context.Context, period string) (gin.H, error) {
var startTime time.Time
now := time.Now()
switch period {
@@ -288,7 +280,7 @@ func (s *DashboardQueryService) GetRequestStatsForPeriod(period string) (gin.H,
Success int64
}
err := s.db.Model(&models.RequestLog{}).
err := s.db.WithContext(ctx).Model(&models.RequestLog{}).
Select("count(*) as total, sum(case when is_success = true then 1 else 0 end) as success").
Where("request_time >= ?", startTime).
Scan(&result).Error

View File

@@ -1,8 +1,8 @@
// Filename: internal/service/db_log_writer_service.go
package service
import (
"context"
"encoding/json"
"gemini-balancer/internal/models"
"gemini-balancer/internal/settings"
@@ -35,35 +35,30 @@ func NewDBLogWriterService(db *gorm.DB, s store.Store, settings *settings.Settin
store: s,
SettingsManager: settings,
logger: logger.WithField("component", "DBLogWriter📝"),
// Size the buffer from the configured value
logBuffer: make(chan *models.RequestLog, bufferCapacity),
stopChan: make(chan struct{}),
logBuffer: make(chan *models.RequestLog, bufferCapacity),
stopChan: make(chan struct{}),
}
}
func (s *DBLogWriterService) Start() {
s.wg.Add(2) // one for the event listener, one for the database writer
// Start the event listener
s.wg.Add(2)
go s.eventListenerLoop()
// Start the database writer
go s.dbWriterLoop()
s.logger.Info("DBLogWriterService started.")
}
func (s *DBLogWriterService) Stop() {
s.logger.Info("DBLogWriterService stopping...")
close(s.stopChan) // signal all goroutines to stop
s.wg.Wait() // wait for all goroutines to finish
close(s.stopChan)
s.wg.Wait()
s.logger.Info("DBLogWriterService stopped.")
}
// eventListenerLoop receives events from the store and places them into the in-memory buffer
func (s *DBLogWriterService) eventListenerLoop() {
defer s.wg.Done()
sub, err := s.store.Subscribe(models.TopicRequestFinished)
ctx := context.Background()
sub, err := s.store.Subscribe(ctx, models.TopicRequestFinished)
if err != nil {
s.logger.Fatalf("Failed to subscribe to topic %s: %v", models.TopicRequestFinished, err)
return
@@ -80,34 +75,27 @@ func (s *DBLogWriterService) eventListenerLoop() {
s.logger.Errorf("Failed to unmarshal event for logging: %v", err)
continue
}
// Put the log portion of the event into the buffer
select {
case s.logBuffer <- &event.RequestLog:
default:
s.logger.Warn("Log buffer is full. A log message might be dropped.")
}
case <-s.stopChan:
s.logger.Info("Event listener loop stopping.")
// Close the buffer so dbWriterLoop drains the remaining logs and then exits
close(s.logBuffer)
return
}
}
}
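A standalone sketch of the non-blocking enqueue used in eventListenerLoop above: when the buffer is full the entry is dropped (with a warning in the real code) rather than blocking the subscriber. The capacity below is illustrative; the real buffer is sized from configuration:
buf := make(chan *models.RequestLog, 1024) // illustrative capacity
entry := &models.RequestLog{}
select {
case buf <- entry:
	// enqueued for the batch writer
default:
	// buffer full: drop the entry instead of blocking the event listener
}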
// dbWriterLoop reads logs from the in-memory buffer in batches and writes them to the database
func (s *DBLogWriterService) dbWriterLoop() {
defer s.wg.Done()
// Read the configuration once at startup
cfg := s.SettingsManager.GetSettings()
batchSize := cfg.LogFlushBatchSize
if batchSize <= 0 {
batchSize = 100
}
flushTimeout := time.Duration(cfg.LogFlushIntervalSeconds) * time.Second
if flushTimeout <= 0 {
flushTimeout = 5 * time.Second
@@ -126,7 +114,7 @@ func (s *DBLogWriterService) dbWriterLoop() {
return
}
batch = append(batch, logEntry)
if len(batch) >= batchSize { // use the configured batch size
if len(batch) >= batchSize {
s.flushBatch(batch)
batch = make([]*models.RequestLog, 0, batchSize)
}
@@ -139,7 +127,6 @@ func (s *DBLogWriterService) dbWriterLoop() {
}
}
// flushBatch writes one batch of logs to the database
func (s *DBLogWriterService) flushBatch(batch []*models.RequestLog) {
if err := s.db.CreateInBatches(batch, len(batch)).Error; err != nil {
s.logger.WithField("batch_size", len(batch)).WithError(err).Error("Failed to flush log batch to database.")

View File

@@ -75,7 +75,7 @@ func NewHealthCheckService(
func (s *HealthCheckService) Start() {
s.logger.Info("Starting HealthCheckService with independent check loops...")
s.wg.Add(4) // Now four loops
s.wg.Add(4)
go s.runKeyCheckLoop()
go s.runUpstreamCheckLoop()
go s.runProxyCheckLoop()
@@ -102,8 +102,6 @@ func (s *HealthCheckService) GetLastHealthCheckResults() map[string]string {
func (s *HealthCheckService) runKeyCheckLoop() {
defer s.wg.Done()
s.logger.Info("Key check dynamic scheduler loop started.")
// Main scheduling loop: checks pending tasks once per minute
ticker := time.NewTicker(1 * time.Minute)
defer ticker.Stop()
@@ -126,26 +124,22 @@ func (s *HealthCheckService) scheduleKeyChecks() {
defer s.groupCheckTimeMutex.Unlock()
for _, group := range groups {
// Fetch the group-specific operational config
opConfig, err := s.groupManager.BuildOperationalConfig(group)
if err != nil {
s.logger.WithError(err).WithField("group_id", group.ID).Error("Failed to build operational config for group, skipping health check scheduling.")
continue
}
if opConfig.EnableKeyCheck == nil || !*opConfig.EnableKeyCheck {
continue // skip groups with key health checks disabled
continue
}
var intervalMinutes int
if opConfig.KeyCheckIntervalMinutes != nil {
intervalMinutes = *opConfig.KeyCheckIntervalMinutes
}
interval := time.Duration(intervalMinutes) * time.Minute
if interval <= 0 {
continue // skip invalid check intervals
continue
}
if nextCheckTime, ok := s.groupNextCheckTime[group.ID]; !ok || now.After(nextCheckTime) {
s.logger.Infof("Scheduling key check for group '%s' (ID: %d)", group.Name, group.ID)
go s.performKeyChecksForGroup(group, opConfig)
@@ -160,7 +154,6 @@ func (s *HealthCheckService) runUpstreamCheckLoop() {
if s.SettingsManager.GetSettings().EnableUpstreamCheck {
s.performUpstreamChecks()
}
ticker := time.NewTicker(time.Duration(s.SettingsManager.GetSettings().UpstreamCheckIntervalSeconds) * time.Second)
defer ticker.Stop()
@@ -184,7 +177,6 @@ func (s *HealthCheckService) runProxyCheckLoop() {
if s.SettingsManager.GetSettings().EnableProxyCheck {
s.performProxyChecks()
}
ticker := time.NewTicker(time.Duration(s.SettingsManager.GetSettings().ProxyCheckIntervalSeconds) * time.Second)
defer ticker.Stop()
@@ -203,6 +195,7 @@ func (s *HealthCheckService) runProxyCheckLoop() {
}
func (s *HealthCheckService) performKeyChecksForGroup(group *models.KeyGroup, opConfig *models.KeyGroupSettings) {
ctx := context.Background()
settings := s.SettingsManager.GetSettings()
timeout := time.Duration(settings.KeyCheckTimeoutSeconds) * time.Second
@@ -213,11 +206,9 @@ func (s *HealthCheckService) performKeyChecksForGroup(group *models.KeyGroup, op
}
log := s.logger.WithFields(logrus.Fields{"group_id": group.ID, "group_name": group.Name})
log.Infof("Starting key health check cycle.")
var mappingsToCheck []models.GroupAPIKeyMapping
err = s.db.Model(&models.GroupAPIKeyMapping{}).
err = s.db.WithContext(ctx).Model(&models.GroupAPIKeyMapping{}).
Joins("JOIN api_keys ON api_keys.id = group_api_key_mappings.api_key_id").
Where("group_api_key_mappings.key_group_id = ?", group.ID).
Where("api_keys.master_status = ?", models.MasterStatusActive).
@@ -233,7 +224,6 @@ func (s *HealthCheckService) performKeyChecksForGroup(group *models.KeyGroup, op
log.Info("No key mappings to check for this group.")
return
}
log.Infof("Starting health check for %d key mappings.", len(mappingsToCheck))
jobs := make(chan models.GroupAPIKeyMapping, len(mappingsToCheck))
var wg sync.WaitGroup
@@ -242,14 +232,14 @@ func (s *HealthCheckService) performKeyChecksForGroup(group *models.KeyGroup, op
concurrency = *opConfig.KeyCheckConcurrency
}
if concurrency <= 0 {
concurrency = 1 // guarantee at least one worker
concurrency = 1
}
for w := 1; w <= concurrency; w++ {
wg.Add(1)
go func(workerID int) {
defer wg.Done()
for mapping := range jobs {
s.checkAndProcessMapping(&mapping, timeout, endpoint)
s.checkAndProcessMapping(ctx, &mapping, timeout, endpoint)
}
}(w)
}
@@ -261,52 +251,46 @@ func (s *HealthCheckService) performKeyChecksForGroup(group *models.KeyGroup, op
log.Info("Finished key health check cycle.")
}
func (s *HealthCheckService) checkAndProcessMapping(mapping *models.GroupAPIKeyMapping, timeout time.Duration, endpoint string) {
func (s *HealthCheckService) checkAndProcessMapping(ctx context.Context, mapping *models.GroupAPIKeyMapping, timeout time.Duration, endpoint string) {
if mapping.APIKey == nil {
s.logger.Warnf("Skipping check for mapping (G:%d, K:%d) because associated APIKey is nil.", mapping.KeyGroupID, mapping.APIKeyID)
return
}
validationErr := s.keyValidationService.ValidateSingleKey(mapping.APIKey, timeout, endpoint)
// --- Diagnosis 1: validation succeeded (healthy) ---
if validationErr == nil {
if mapping.Status != models.StatusActive {
s.activateMapping(mapping)
s.activateMapping(ctx, mapping)
}
return
}
errorString := validationErr.Error()
// --- Diagnosis 2: permanent error ---
if CustomErrors.IsPermanentUpstreamError(errorString) {
s.revokeMapping(mapping, validationErr)
s.revokeMapping(ctx, mapping, validationErr)
return
}
// --- Diagnosis 3: temporary error ---
if CustomErrors.IsTemporaryUpstreamError(errorString) {
// Log with a higher level (WARN) since this is an actionable, proactive finding.
s.logger.Warnf("Health check for key %d failed with temporary error, applying penalty. Reason: %v", mapping.APIKeyID, validationErr)
s.penalizeMapping(mapping, validationErr)
s.penalizeMapping(ctx, mapping, validationErr)
return
}
// --- Diagnosis 4: other unknown or upstream service errors ---
s.logger.Errorf("Health check for key %d failed with a transient or unknown upstream error: %v. Mapping will not be penalized by health check.", mapping.APIKeyID, validationErr)
}
func (s *HealthCheckService) activateMapping(mapping *models.GroupAPIKeyMapping) {
func (s *HealthCheckService) activateMapping(ctx context.Context, mapping *models.GroupAPIKeyMapping) {
oldStatus := mapping.Status
mapping.Status = models.StatusActive
mapping.ConsecutiveErrorCount = 0
mapping.LastError = ""
if err := s.keyRepo.UpdateMapping(mapping); err != nil {
if err := s.keyRepo.UpdateMapping(ctx, mapping); err != nil {
s.logger.WithError(err).Errorf("Failed to activate mapping (G:%d, K:%d).", mapping.KeyGroupID, mapping.APIKeyID)
return
}
s.logger.Infof("Mapping (G:%d, K:%d) successfully activated from %s to ACTIVE.", mapping.KeyGroupID, mapping.APIKeyID, oldStatus)
s.publishKeyStatusChangedEvent(mapping.KeyGroupID, mapping.APIKeyID, oldStatus, mapping.Status)
s.publishKeyStatusChangedEvent(ctx, mapping.KeyGroupID, mapping.APIKeyID, oldStatus, mapping.Status)
}
func (s *HealthCheckService) penalizeMapping(mapping *models.GroupAPIKeyMapping, err error) {
// Re-fetch group-specific operational config to get the correct thresholds
func (s *HealthCheckService) penalizeMapping(ctx context.Context, mapping *models.GroupAPIKeyMapping, err error) {
group, ok := s.groupManager.GetGroupByID(mapping.KeyGroupID)
if !ok {
s.logger.Errorf("Could not find group with ID %d to apply penalty.", mapping.KeyGroupID)
@@ -320,7 +304,6 @@ func (s *HealthCheckService) penalizeMapping(mapping *models.GroupAPIKeyMapping,
oldStatus := mapping.Status
mapping.LastError = err.Error()
mapping.ConsecutiveErrorCount++
// Use the group-specific threshold
threshold := *opConfig.KeyBlacklistThreshold
if mapping.ConsecutiveErrorCount >= threshold {
mapping.Status = models.StatusCooldown
@@ -329,44 +312,41 @@ func (s *HealthCheckService) penalizeMapping(mapping *models.GroupAPIKeyMapping,
mapping.CooldownUntil = &cooldownTime
s.logger.Warnf("Mapping (G:%d, K:%d) reached error threshold (%d) during health check and is now in COOLDOWN for %v.", mapping.KeyGroupID, mapping.APIKeyID, threshold, cooldownDuration)
}
if errDb := s.keyRepo.UpdateMapping(mapping); errDb != nil {
if errDb := s.keyRepo.UpdateMapping(ctx, mapping); errDb != nil {
s.logger.WithError(errDb).Errorf("Failed to penalize mapping (G:%d, K:%d).", mapping.KeyGroupID, mapping.APIKeyID)
return
}
if oldStatus != mapping.Status {
s.publishKeyStatusChangedEvent(mapping.KeyGroupID, mapping.APIKeyID, oldStatus, mapping.Status)
s.publishKeyStatusChangedEvent(ctx, mapping.KeyGroupID, mapping.APIKeyID, oldStatus, mapping.Status)
}
}
func (s *HealthCheckService) revokeMapping(mapping *models.GroupAPIKeyMapping, err error) {
func (s *HealthCheckService) revokeMapping(ctx context.Context, mapping *models.GroupAPIKeyMapping, err error) {
oldStatus := mapping.Status
if oldStatus == models.StatusBanned {
return // Already banned, do nothing.
return
}
mapping.Status = models.StatusBanned
mapping.LastError = "Definitive error: " + err.Error()
mapping.ConsecutiveErrorCount = 0 // Reset counter as this is a final state for this group
if errDb := s.keyRepo.UpdateMapping(mapping); errDb != nil {
mapping.ConsecutiveErrorCount = 0
if errDb := s.keyRepo.UpdateMapping(ctx, mapping); errDb != nil {
s.logger.WithError(errDb).Errorf("Failed to revoke mapping (G:%d, K:%d).", mapping.KeyGroupID, mapping.APIKeyID)
return
}
s.logger.Warnf("Mapping (G:%d, K:%d) has been BANNED due to a definitive error: %v", mapping.KeyGroupID, mapping.APIKeyID, err)
s.publishKeyStatusChangedEvent(mapping.KeyGroupID, mapping.APIKeyID, oldStatus, mapping.Status)
s.publishKeyStatusChangedEvent(ctx, mapping.KeyGroupID, mapping.APIKeyID, oldStatus, mapping.Status)
s.logger.Infof("Triggering MasterStatus update for definitively failed key ID %d.", mapping.APIKeyID)
if err := s.keyRepo.UpdateAPIKeyStatus(mapping.APIKeyID, models.MasterStatusRevoked); err != nil {
if err := s.keyRepo.UpdateAPIKeyStatus(ctx, mapping.APIKeyID, models.MasterStatusRevoked); err != nil {
s.logger.WithError(err).Errorf("Failed to update master status for key ID %d after group-level ban.", mapping.APIKeyID)
}
}
func (s *HealthCheckService) performUpstreamChecks() {
ctx := context.Background()
settings := s.SettingsManager.GetSettings()
timeout := time.Duration(settings.UpstreamCheckTimeoutSeconds) * time.Second
var upstreams []*models.UpstreamEndpoint
if err := s.db.Find(&upstreams).Error; err != nil {
if err := s.db.WithContext(ctx).Find(&upstreams).Error; err != nil {
s.logger.WithError(err).Error("Failed to retrieve upstreams.")
return
}
@@ -390,10 +370,10 @@ func (s *HealthCheckService) performUpstreamChecks() {
s.lastResultsMutex.Unlock()
if oldStatus != newStatus {
s.logger.Infof("Upstream '%s' status changed from %s -> %s", upstream.URL, oldStatus, newStatus)
if err := s.db.Model(upstream).Update("status", newStatus).Error; err != nil {
if err := s.db.WithContext(ctx).Model(upstream).Update("status", newStatus).Error; err != nil {
s.logger.WithError(err).WithField("upstream_id", upstream.ID).Error("Failed to update upstream status.")
} else {
s.publishUpstreamHealthChangedEvent(upstream, oldStatus, newStatus)
s.publishUpstreamHealthChangedEvent(ctx, upstream, oldStatus, newStatus)
}
}
}(u)
@@ -412,10 +392,11 @@ func (s *HealthCheckService) checkEndpoint(urlStr string, timeout time.Duration)
}
func (s *HealthCheckService) performProxyChecks() {
ctx := context.Background()
settings := s.SettingsManager.GetSettings()
timeout := time.Duration(settings.ProxyCheckTimeoutSeconds) * time.Second
var proxies []*models.ProxyConfig
if err := s.db.Find(&proxies).Error; err != nil {
if err := s.db.WithContext(ctx).Find(&proxies).Error; err != nil {
s.logger.WithError(err).Error("Failed to retrieve proxies.")
return
}
@@ -438,7 +419,7 @@ func (s *HealthCheckService) performProxyChecks() {
s.lastResultsMutex.Unlock()
if proxyCfg.Status != newStatus {
s.logger.Infof("Proxy '%s' status changed from %s -> %s", proxyCfg.Address, proxyCfg.Status, newStatus)
if err := s.db.Model(proxyCfg).Update("status", newStatus).Error; err != nil {
if err := s.db.WithContext(ctx).Model(proxyCfg).Update("status", newStatus).Error; err != nil {
s.logger.WithError(err).WithField("proxy_id", proxyCfg.ID).Error("Failed to update proxy status.")
}
}
@@ -482,7 +463,7 @@ func (s *HealthCheckService) checkProxy(proxyCfg *models.ProxyConfig, timeout ti
return true
}
func (s *HealthCheckService) publishKeyStatusChangedEvent(groupID, keyID uint, oldStatus, newStatus models.APIKeyStatus) {
func (s *HealthCheckService) publishKeyStatusChangedEvent(ctx context.Context, groupID, keyID uint, oldStatus, newStatus models.APIKeyStatus) {
event := models.KeyStatusChangedEvent{
KeyID: keyID,
GroupID: groupID,
@@ -496,12 +477,12 @@ func (s *HealthCheckService) publishKeyStatusChangedEvent(groupID, keyID uint, o
s.logger.WithError(err).Errorf("Failed to marshal KeyStatusChangedEvent for group %d.", groupID)
return
}
if err := s.store.Publish(models.TopicKeyStatusChanged, payload); err != nil {
if err := s.store.Publish(ctx, models.TopicKeyStatusChanged, payload); err != nil {
s.logger.WithError(err).Errorf("Failed to publish KeyStatusChangedEvent for group %d.", groupID)
}
}
func (s *HealthCheckService) publishUpstreamHealthChangedEvent(upstream *models.UpstreamEndpoint, oldStatus, newStatus string) {
func (s *HealthCheckService) publishUpstreamHealthChangedEvent(ctx context.Context, upstream *models.UpstreamEndpoint, oldStatus, newStatus string) {
event := models.UpstreamHealthChangedEvent{
UpstreamID: upstream.ID,
UpstreamURL: upstream.URL,
@@ -516,28 +497,20 @@ func (s *HealthCheckService) publishUpstreamHealthChangedEvent(upstream *models.
s.logger.WithError(err).Error("Failed to marshal UpstreamHealthChangedEvent.")
return
}
if err := s.store.Publish(models.TopicUpstreamHealthChanged, payload); err != nil {
if err := s.store.Publish(ctx, models.TopicUpstreamHealthChanged, payload); err != nil {
s.logger.WithError(err).Error("Failed to publish UpstreamHealthChangedEvent.")
}
}
// =========================================================================
// Global Base Key Check (New Logic)
// =========================================================================
func (s *HealthCheckService) runBaseKeyCheckLoop() {
defer s.wg.Done()
s.logger.Info("Global base key check loop started.")
settings := s.SettingsManager.GetSettings()
if !settings.EnableBaseKeyCheck {
s.logger.Info("Global base key check is disabled.")
return
}
// Perform an initial check on startup
s.performBaseKeyChecks()
interval := time.Duration(settings.BaseKeyCheckIntervalMinutes) * time.Minute
if interval <= 0 {
s.logger.Warnf("Invalid BaseKeyCheckIntervalMinutes: %d. Disabling base key check loop.", settings.BaseKeyCheckIntervalMinutes)
@@ -558,6 +531,7 @@ func (s *HealthCheckService) runBaseKeyCheckLoop() {
}
func (s *HealthCheckService) performBaseKeyChecks() {
ctx := context.Background()
s.logger.Info("Starting global base key check cycle.")
settings := s.SettingsManager.GetSettings()
timeout := time.Duration(settings.KeyCheckTimeoutSeconds) * time.Second
@@ -576,7 +550,7 @@ func (s *HealthCheckService) performBaseKeyChecks() {
jobs := make(chan *models.APIKey, len(keys))
var wg sync.WaitGroup
if concurrency <= 0 {
concurrency = 5 // Safe default
concurrency = 5
}
for w := 0; w < concurrency; w++ {
wg.Add(1)
@@ -587,10 +561,10 @@ func (s *HealthCheckService) performBaseKeyChecks() {
if err != nil && CustomErrors.IsPermanentUpstreamError(err.Error()) {
oldStatus := key.MasterStatus
s.logger.Warnf("Key ID %d (%s...) failed definitive base check. Setting MasterStatus to REVOKED. Reason: %v", key.ID, key.APIKey[:4], err)
if updateErr := s.keyRepo.UpdateAPIKeyStatus(key.ID, models.MasterStatusRevoked); updateErr != nil {
if updateErr := s.keyRepo.UpdateAPIKeyStatus(ctx, key.ID, models.MasterStatusRevoked); updateErr != nil {
s.logger.WithError(updateErr).Errorf("Failed to update master status for key ID %d.", key.ID)
} else {
s.publishMasterKeyStatusChangedEvent(key.ID, oldStatus, models.MasterStatusRevoked)
s.publishMasterKeyStatusChangedEvent(ctx, key.ID, oldStatus, models.MasterStatusRevoked)
}
}
}
@@ -604,8 +578,7 @@ func (s *HealthCheckService) performBaseKeyChecks() {
s.logger.Info("Global base key check cycle finished.")
}
// Event publishing helper
func (s *HealthCheckService) publishMasterKeyStatusChangedEvent(keyID uint, oldStatus, newStatus models.MasterAPIKeyStatus) {
func (s *HealthCheckService) publishMasterKeyStatusChangedEvent(ctx context.Context, keyID uint, oldStatus, newStatus models.MasterAPIKeyStatus) {
event := models.MasterKeyStatusChangedEvent{
KeyID: keyID,
OldMasterStatus: oldStatus,
@@ -618,7 +591,7 @@ func (s *HealthCheckService) publishMasterKeyStatusChangedEvent(keyID uint, oldS
s.logger.WithError(err).Errorf("Failed to marshal MasterKeyStatusChangedEvent for key %d.", keyID)
return
}
if err := s.store.Publish(models.TopicMasterKeyStatusChanged, payload); err != nil {
if err := s.store.Publish(ctx, models.TopicMasterKeyStatusChanged, payload); err != nil {
s.logger.WithError(err).Errorf("Failed to publish MasterKeyStatusChangedEvent for key %d.", keyID)
}
}
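
Both key-check paths in this file fan work out over a small worker pool: a buffered jobs channel, a configurable concurrency with a floor of one worker, and a WaitGroup so the check cycle can block until every key has been processed. A minimal sketch of that shape with illustrative types (not the service's real signatures):

```go
package example

import "sync"

type Mapping struct{ GroupID, KeyID uint }

// CheckAll runs check for every mapping using at most `concurrency` workers.
func CheckAll(mappings []Mapping, concurrency int, check func(Mapping)) {
	if concurrency <= 0 {
		concurrency = 1 // same guard as the service: at least one worker
	}

	jobs := make(chan Mapping, len(mappings))
	var wg sync.WaitGroup

	for w := 0; w < concurrency; w++ {
		wg.Add(1)
		go func() {
			defer wg.Done()
			for m := range jobs {
				check(m)
			}
		}()
	}

	for _, m := range mappings {
		jobs <- m
	}
	close(jobs)
	wg.Wait()
}
```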

View File

@@ -2,6 +2,7 @@
package service
import (
"context"
"encoding/json"
"fmt"
"gemini-balancer/internal/models"
@@ -42,88 +43,84 @@ func NewKeyImportService(ts task.Reporter, kr repository.KeyRepository, s store.
}
}
// --- Generic panic-safe task runner ---
func (s *KeyImportService) runTaskWithRecovery(taskID string, resourceID string, taskFunc func()) {
func (s *KeyImportService) runTaskWithRecovery(ctx context.Context, taskID string, resourceID string, taskFunc func()) {
defer func() {
if r := recover(); r != nil {
err := fmt.Errorf("panic recovered in task %s: %v", taskID, r)
s.logger.Error(err)
s.taskService.EndTaskByID(taskID, resourceID, nil, err)
s.taskService.EndTaskByID(ctx, taskID, resourceID, nil, err)
}
}()
taskFunc()
}
// --- Public Task Starters ---
func (s *KeyImportService) StartAddKeysTask(groupID uint, keysText string, validateOnImport bool) (*task.Status, error) {
func (s *KeyImportService) StartAddKeysTask(ctx context.Context, groupID uint, keysText string, validateOnImport bool) (*task.Status, error) {
keys := utils.ParseKeysFromText(keysText)
if len(keys) == 0 {
return nil, fmt.Errorf("no valid keys found in input text")
}
resourceID := fmt.Sprintf("group-%d", groupID)
taskStatus, err := s.taskService.StartTask(uint(groupID), TaskTypeAddKeysToGroup, resourceID, len(keys), 15*time.Minute)
taskStatus, err := s.taskService.StartTask(ctx, uint(groupID), TaskTypeAddKeysToGroup, resourceID, len(keys), 15*time.Minute)
if err != nil {
return nil, err
}
go s.runTaskWithRecovery(taskStatus.ID, resourceID, func() {
s.runAddKeysTask(taskStatus.ID, resourceID, groupID, keys, validateOnImport)
go s.runTaskWithRecovery(context.Background(), taskStatus.ID, resourceID, func() {
s.runAddKeysTask(context.Background(), taskStatus.ID, resourceID, groupID, keys, validateOnImport)
})
return taskStatus, nil
}
func (s *KeyImportService) StartUnlinkKeysTask(groupID uint, keysText string) (*task.Status, error) {
func (s *KeyImportService) StartUnlinkKeysTask(ctx context.Context, groupID uint, keysText string) (*task.Status, error) {
keys := utils.ParseKeysFromText(keysText)
if len(keys) == 0 {
return nil, fmt.Errorf("no valid keys found")
}
resourceID := fmt.Sprintf("group-%d", groupID)
taskStatus, err := s.taskService.StartTask(uint(groupID), TaskTypeUnlinkKeysFromGroup, resourceID, len(keys), time.Hour)
taskStatus, err := s.taskService.StartTask(ctx, uint(groupID), TaskTypeUnlinkKeysFromGroup, resourceID, len(keys), time.Hour)
if err != nil {
return nil, err
}
go s.runTaskWithRecovery(taskStatus.ID, resourceID, func() {
s.runUnlinkKeysTask(taskStatus.ID, resourceID, groupID, keys)
go s.runTaskWithRecovery(context.Background(), taskStatus.ID, resourceID, func() {
s.runUnlinkKeysTask(context.Background(), taskStatus.ID, resourceID, groupID, keys)
})
return taskStatus, nil
}
func (s *KeyImportService) StartHardDeleteKeysTask(keysText string) (*task.Status, error) {
func (s *KeyImportService) StartHardDeleteKeysTask(ctx context.Context, keysText string) (*task.Status, error) {
keys := utils.ParseKeysFromText(keysText)
if len(keys) == 0 {
return nil, fmt.Errorf("no valid keys found")
}
resourceID := "global_hard_delete" // Global lock
taskStatus, err := s.taskService.StartTask(0, TaskTypeHardDeleteKeys, resourceID, len(keys), time.Hour)
resourceID := "global_hard_delete"
taskStatus, err := s.taskService.StartTask(ctx, 0, TaskTypeHardDeleteKeys, resourceID, len(keys), time.Hour)
if err != nil {
return nil, err
}
go s.runTaskWithRecovery(taskStatus.ID, resourceID, func() {
s.runHardDeleteKeysTask(taskStatus.ID, resourceID, keys)
go s.runTaskWithRecovery(context.Background(), taskStatus.ID, resourceID, func() {
s.runHardDeleteKeysTask(context.Background(), taskStatus.ID, resourceID, keys)
})
return taskStatus, nil
}
func (s *KeyImportService) StartRestoreKeysTask(keysText string) (*task.Status, error) {
func (s *KeyImportService) StartRestoreKeysTask(ctx context.Context, keysText string) (*task.Status, error) {
keys := utils.ParseKeysFromText(keysText)
if len(keys) == 0 {
return nil, fmt.Errorf("no valid keys found")
}
resourceID := "global_restore_keys" // Global lock
taskStatus, err := s.taskService.StartTask(0, TaskTypeRestoreKeys, resourceID, len(keys), time.Hour)
resourceID := "global_restore_keys"
taskStatus, err := s.taskService.StartTask(ctx, 0, TaskTypeRestoreKeys, resourceID, len(keys), time.Hour)
if err != nil {
return nil, err
}
go s.runTaskWithRecovery(taskStatus.ID, resourceID, func() {
s.runRestoreKeysTask(taskStatus.ID, resourceID, keys)
go s.runTaskWithRecovery(context.Background(), taskStatus.ID, resourceID, func() {
s.runRestoreKeysTask(context.Background(), taskStatus.ID, resourceID, keys)
})
return taskStatus, nil
}
// --- Private Task Runners ---
func (s *KeyImportService) runAddKeysTask(taskID string, resourceID string, groupID uint, keys []string, validateOnImport bool) {
// Step 1: de-duplicate the raw input key list.
func (s *KeyImportService) runAddKeysTask(ctx context.Context, taskID string, resourceID string, groupID uint, keys []string, validateOnImport bool) {
uniqueKeysMap := make(map[string]struct{})
var uniqueKeyStrings []string
for _, kStr := range keys {
@@ -133,41 +130,37 @@ func (s *KeyImportService) runAddKeysTask(taskID string, resourceID string, grou
}
}
if len(uniqueKeyStrings) == 0 {
s.taskService.EndTaskByID(taskID, resourceID, gin.H{"newly_linked_count": 0, "already_linked_count": 0}, nil)
s.taskService.EndTaskByID(ctx, taskID, resourceID, gin.H{"newly_linked_count": 0, "already_linked_count": 0}, nil)
return
}
// Step 2: ensure every key exists in the master table (create or restore) and fetch their full entities.
keysToEnsure := make([]models.APIKey, len(uniqueKeyStrings))
for i, keyStr := range uniqueKeyStrings {
keysToEnsure[i] = models.APIKey{APIKey: keyStr}
}
allKeyModels, err := s.keyRepo.AddKeys(keysToEnsure)
if err != nil {
s.taskService.EndTaskByID(taskID, resourceID, nil, fmt.Errorf("failed to ensure keys exist: %w", err))
s.taskService.EndTaskByID(ctx, taskID, resourceID, nil, fmt.Errorf("failed to ensure keys exist: %w", err))
return
}
// Step 3: determine which of these keys are already linked to the current group.
alreadyLinkedModels, err := s.keyRepo.GetKeysByValuesAndGroupID(uniqueKeyStrings, groupID)
if err != nil {
s.taskService.EndTaskByID(taskID, resourceID, nil, fmt.Errorf("failed to check for already linked keys: %w", err))
s.taskService.EndTaskByID(ctx, taskID, resourceID, nil, fmt.Errorf("failed to check for already linked keys: %w", err))
return
}
alreadyLinkedIDSet := make(map[uint]struct{})
for _, key := range alreadyLinkedModels {
alreadyLinkedIDSet[key.ID] = struct{}{}
}
// Step 4: determine the keys that actually need to be linked to the current group (our working set).
var keysToLink []models.APIKey
for _, key := range allKeyModels {
if _, exists := alreadyLinkedIDSet[key.ID]; !exists {
keysToLink = append(keysToLink, key)
}
}
// Step 5: update the task's Total to the exact size of the working set.
if err := s.taskService.UpdateTotalByID(taskID, len(keysToLink)); err != nil {
if err := s.taskService.UpdateTotalByID(ctx, taskID, len(keysToLink)); err != nil {
s.logger.WithError(err).Warnf("Failed to update total for task %s", taskID)
}
// Step 6: link the keys to the group in chunks, updating progress as we go.
if len(keysToLink) > 0 {
idsToLink := make([]uint, len(keysToLink))
for i, key := range keysToLink {
@@ -179,44 +172,41 @@ func (s *KeyImportService) runAddKeysTask(taskID string, resourceID string, grou
end = len(idsToLink)
}
chunk := idsToLink[i:end]
if err := s.keyRepo.LinkKeysToGroup(groupID, chunk); err != nil {
s.taskService.EndTaskByID(taskID, resourceID, nil, fmt.Errorf("chunk failed to link keys: %w", err))
if err := s.keyRepo.LinkKeysToGroup(ctx, groupID, chunk); err != nil {
s.taskService.EndTaskByID(ctx, taskID, resourceID, nil, fmt.Errorf("chunk failed to link keys: %w", err))
return
}
_ = s.taskService.UpdateProgressByID(taskID, i+len(chunk))
_ = s.taskService.UpdateProgressByID(ctx, taskID, i+len(chunk))
}
}
// Step 7: prepare the final result and end the task.
result := gin.H{
"newly_linked_count": len(keysToLink),
"already_linked_count": len(alreadyLinkedIDSet),
"total_linked_count": len(allKeyModels),
}
// Step 8: based on the `validateOnImport` flag, either publish validation events or activate directly (newly linked keys only).
if len(keysToLink) > 0 {
idsToLink := make([]uint, len(keysToLink))
for i, key := range keysToLink {
idsToLink[i] = key.ID
}
if validateOnImport {
s.publishImportGroupCompletedEvent(groupID, idsToLink)
s.publishImportGroupCompletedEvent(ctx, groupID, idsToLink)
for _, keyID := range idsToLink {
s.publishSingleKeyChangeEvent(groupID, keyID, "", models.StatusPendingValidation, "key_linked")
s.publishSingleKeyChangeEvent(ctx, groupID, keyID, "", models.StatusPendingValidation, "key_linked")
}
} else {
for _, keyID := range idsToLink {
if _, err := s.apiKeyService.UpdateMappingStatus(groupID, keyID, models.StatusActive); err != nil {
if _, err := s.apiKeyService.UpdateMappingStatus(ctx, groupID, keyID, models.StatusActive); err != nil {
s.logger.Errorf("Failed to directly activate key ID %d in group %d: %v", keyID, groupID, err)
}
}
}
}
s.taskService.EndTaskByID(taskID, resourceID, result, nil)
s.taskService.EndTaskByID(ctx, taskID, resourceID, result, nil)
}
// runUnlinkKeysTask
func (s *KeyImportService) runUnlinkKeysTask(taskID, resourceID string, groupID uint, keys []string) {
func (s *KeyImportService) runUnlinkKeysTask(ctx context.Context, taskID, resourceID string, groupID uint, keys []string) {
uniqueKeysMap := make(map[string]struct{})
var uniqueKeys []string
for _, kStr := range keys {
@@ -225,46 +215,42 @@ func (s *KeyImportService) runUnlinkKeysTask(taskID, resourceID string, groupID
uniqueKeys = append(uniqueKeys, kStr)
}
}
// Step 1: in one pass, find the key entities from the input that actually exist in this group. This is our working set.
keysToUnlink, err := s.keyRepo.GetKeysByValuesAndGroupID(uniqueKeys, groupID)
if err != nil {
s.taskService.EndTaskByID(taskID, resourceID, nil, fmt.Errorf("failed to find keys to unlink: %w", err))
s.taskService.EndTaskByID(ctx, taskID, resourceID, nil, fmt.Errorf("failed to find keys to unlink: %w", err))
return
}
if len(keysToUnlink) == 0 {
result := gin.H{"unlinked_count": 0, "hard_deleted_count": 0, "not_found_count": len(uniqueKeys)}
s.taskService.EndTaskByID(taskID, resourceID, result, nil)
s.taskService.EndTaskByID(ctx, taskID, resourceID, result, nil)
return
}
idsToUnlink := make([]uint, len(keysToUnlink))
for i, key := range keysToUnlink {
idsToUnlink[i] = key.ID
}
// Step 2: update the task's Total to the exact size of the working set.
if err := s.taskService.UpdateTotalByID(taskID, len(idsToUnlink)); err != nil {
if err := s.taskService.UpdateTotalByID(ctx, taskID, len(idsToUnlink)); err != nil {
s.logger.WithError(err).Warnf("Failed to update total for task %s", taskID)
}
var totalUnlinked int64
// Step 3: unlink the keys in chunks and report progress.
for i := 0; i < len(idsToUnlink); i += chunkSize {
end := i + chunkSize
if end > len(idsToUnlink) {
end = len(idsToUnlink)
}
chunk := idsToUnlink[i:end]
unlinked, err := s.keyRepo.UnlinkKeysFromGroup(groupID, chunk)
unlinked, err := s.keyRepo.UnlinkKeysFromGroup(ctx, groupID, chunk)
if err != nil {
s.taskService.EndTaskByID(taskID, resourceID, nil, fmt.Errorf("chunk failed: could not unlink keys: %w", err))
s.taskService.EndTaskByID(ctx, taskID, resourceID, nil, fmt.Errorf("chunk failed: could not unlink keys: %w", err))
return
}
totalUnlinked += unlinked
for _, keyID := range chunk {
s.publishSingleKeyChangeEvent(groupID, keyID, models.StatusActive, "", "key_unlinked")
s.publishSingleKeyChangeEvent(ctx, groupID, keyID, models.StatusActive, "", "key_unlinked")
}
_ = s.taskService.UpdateProgressByID(taskID, i+len(chunk))
_ = s.taskService.UpdateProgressByID(ctx, taskID, i+len(chunk))
}
totalDeleted, err := s.keyRepo.DeleteOrphanKeys()
@@ -276,10 +262,10 @@ func (s *KeyImportService) runUnlinkKeysTask(taskID, resourceID string, groupID
"hard_deleted_count": totalDeleted,
"not_found_count": len(uniqueKeys) - int(totalUnlinked),
}
s.taskService.EndTaskByID(taskID, resourceID, result, nil)
s.taskService.EndTaskByID(ctx, taskID, resourceID, result, nil)
}
func (s *KeyImportService) runHardDeleteKeysTask(taskID, resourceID string, keys []string) {
func (s *KeyImportService) runHardDeleteKeysTask(ctx context.Context, taskID, resourceID string, keys []string) {
var totalDeleted int64
for i := 0; i < len(keys); i += chunkSize {
end := i + chunkSize
@@ -290,22 +276,21 @@ func (s *KeyImportService) runHardDeleteKeysTask(taskID, resourceID string, keys
deleted, err := s.keyRepo.HardDeleteByValues(chunk)
if err != nil {
s.taskService.EndTaskByID(taskID, resourceID, nil, fmt.Errorf("failed to hard delete chunk: %w", err))
s.taskService.EndTaskByID(ctx, taskID, resourceID, nil, fmt.Errorf("failed to hard delete chunk: %w", err))
return
}
totalDeleted += deleted
_ = s.taskService.UpdateProgressByID(taskID, i+len(chunk))
_ = s.taskService.UpdateProgressByID(ctx, taskID, i+len(chunk))
}
result := gin.H{
"hard_deleted_count": totalDeleted,
"not_found_count": int64(len(keys)) - totalDeleted,
}
s.taskService.EndTaskByID(taskID, resourceID, result, nil)
s.publishChangeEvent(0, "keys_hard_deleted") // Global event
s.taskService.EndTaskByID(ctx, taskID, resourceID, result, nil)
s.publishChangeEvent(ctx, 0, "keys_hard_deleted")
}
func (s *KeyImportService) runRestoreKeysTask(taskID, resourceID string, keys []string) {
func (s *KeyImportService) runRestoreKeysTask(ctx context.Context, taskID, resourceID string, keys []string) {
var restoredCount int64
for i := 0; i < len(keys); i += chunkSize {
end := i + chunkSize
@@ -316,21 +301,21 @@ func (s *KeyImportService) runRestoreKeysTask(taskID, resourceID string, keys []
count, err := s.keyRepo.UpdateMasterStatusByValues(chunk, models.MasterStatusActive)
if err != nil {
s.taskService.EndTaskByID(taskID, resourceID, nil, fmt.Errorf("failed to restore chunk: %w", err))
s.taskService.EndTaskByID(ctx, taskID, resourceID, nil, fmt.Errorf("failed to restore chunk: %w", err))
return
}
restoredCount += count
_ = s.taskService.UpdateProgressByID(taskID, i+len(chunk))
_ = s.taskService.UpdateProgressByID(ctx, taskID, i+len(chunk))
}
result := gin.H{
"restored_count": restoredCount,
"not_found_count": int64(len(keys)) - restoredCount,
}
s.taskService.EndTaskByID(taskID, resourceID, result, nil)
s.publishChangeEvent(0, "keys_bulk_restored") // Global event
s.taskService.EndTaskByID(ctx, taskID, resourceID, result, nil)
s.publishChangeEvent(ctx, 0, "keys_bulk_restored")
}
func (s *KeyImportService) publishSingleKeyChangeEvent(groupID, keyID uint, oldStatus, newStatus models.APIKeyStatus, reason string) {
func (s *KeyImportService) publishSingleKeyChangeEvent(ctx context.Context, groupID, keyID uint, oldStatus, newStatus models.APIKeyStatus, reason string) {
event := models.KeyStatusChangedEvent{
GroupID: groupID,
KeyID: keyID,
@@ -340,7 +325,7 @@ func (s *KeyImportService) publishSingleKeyChangeEvent(groupID, keyID uint, oldS
ChangedAt: time.Now(),
}
eventData, _ := json.Marshal(event)
if err := s.store.Publish(models.TopicKeyStatusChanged, eventData); err != nil {
if err := s.store.Publish(ctx, models.TopicKeyStatusChanged, eventData); err != nil {
s.logger.WithError(err).WithFields(logrus.Fields{
"group_id": groupID,
"key_id": keyID,
@@ -349,16 +334,16 @@ func (s *KeyImportService) publishSingleKeyChangeEvent(groupID, keyID uint, oldS
}
}
func (s *KeyImportService) publishChangeEvent(groupID uint, reason string) {
func (s *KeyImportService) publishChangeEvent(ctx context.Context, groupID uint, reason string) {
event := models.KeyStatusChangedEvent{
GroupID: groupID,
ChangeReason: reason,
}
eventData, _ := json.Marshal(event)
_ = s.store.Publish(models.TopicKeyStatusChanged, eventData)
_ = s.store.Publish(ctx, models.TopicKeyStatusChanged, eventData)
}
func (s *KeyImportService) publishImportGroupCompletedEvent(groupID uint, keyIDs []uint) {
func (s *KeyImportService) publishImportGroupCompletedEvent(ctx context.Context, groupID uint, keyIDs []uint) {
if len(keyIDs) == 0 {
return
}
@@ -372,17 +357,15 @@ func (s *KeyImportService) publishImportGroupCompletedEvent(groupID uint, keyIDs
s.logger.WithError(err).Error("Failed to marshal ImportGroupCompletedEvent.")
return
}
if err := s.store.Publish(models.TopicImportGroupCompleted, eventData); err != nil {
if err := s.store.Publish(ctx, models.TopicImportGroupCompleted, eventData); err != nil {
s.logger.WithError(err).Error("Failed to publish ImportGroupCompletedEvent.")
} else {
s.logger.Infof("Published ImportGroupCompletedEvent for group %d with %d keys.", groupID, len(keyIDs))
}
}
// [NEW] StartUnlinkKeysByFilterTask starts a task to unlink keys matching a status filter.
func (s *KeyImportService) StartUnlinkKeysByFilterTask(groupID uint, statuses []string) (*task.Status, error) {
func (s *KeyImportService) StartUnlinkKeysByFilterTask(ctx context.Context, groupID uint, statuses []string) (*task.Status, error) {
s.logger.Infof("Starting unlink task for group %d with status filter: %v", groupID, statuses)
// 1. [New] Find the keys to operate on.
keyValues, err := s.keyRepo.FindKeyValuesByStatus(groupID, statuses)
if err != nil {
return nil, fmt.Errorf("failed to find keys by filter: %w", err)
@@ -390,8 +373,7 @@ func (s *KeyImportService) StartUnlinkKeysByFilterTask(groupID uint, statuses []
if len(keyValues) == 0 {
return nil, fmt.Errorf("no keys found matching the provided filter")
}
// 2. [REUSE] Convert to text and call the existing, robust unlink task logic.
keysAsText := strings.Join(keyValues, "\n")
s.logger.Infof("Found %d keys to unlink for group %d.", len(keyValues), groupID)
return s.StartUnlinkKeysTask(groupID, keysAsText)
}
return s.StartUnlinkKeysTask(ctx, groupID, keysAsText)
}
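
All of the import tasks above share one chunking pattern: slice the working set into fixed-size chunks, apply the repository operation per chunk, report progress after each one, and end the task on the first failed chunk. A hedged sketch with illustrative names (the real chunk size and task reporter live in the service):

```go
package example

import "fmt"

const chunkSize = 500 // assumption; the actual value is defined by the service

// Reporter is a hypothetical stand-in for the task progress interface.
type Reporter interface {
	UpdateProgress(taskID string, done int) error
}

// ProcessInChunks applies `apply` to fixed-size slices of ids and reports
// progress after each chunk; the first failing chunk aborts the whole run.
func ProcessInChunks(taskID string, ids []uint, r Reporter, apply func([]uint) error) error {
	for i := 0; i < len(ids); i += chunkSize {
		end := i + chunkSize
		if end > len(ids) {
			end = len(ids)
		}
		if err := apply(ids[i:end]); err != nil {
			return fmt.Errorf("chunk %d-%d failed: %w", i, end, err)
		}
		_ = r.UpdateProgress(taskID, end) // best-effort progress update
	}
	return nil
}
```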

View File

@@ -2,6 +2,7 @@
package service
import (
"context"
"encoding/json"
"fmt"
"gemini-balancer/internal/channel"
@@ -62,20 +63,18 @@ func (s *KeyValidationService) ValidateSingleKey(key *models.APIKey, timeout tim
return fmt.Errorf("failed to create request: %w", err)
}
s.channel.ModifyRequest(req, key) // Use the injected channel to modify the request
s.channel.ModifyRequest(req, key)
resp, err := client.Do(req)
if err != nil {
// This is a network-level error (e.g., timeout, DNS issue)
return fmt.Errorf("request failed: %w", err)
}
defer resp.Body.Close()
if resp.StatusCode == http.StatusOK {
return nil // Success
return nil
}
// Read the body for more error details
bodyBytes, readErr := io.ReadAll(resp.Body)
var errorMsg string
if readErr != nil {
@@ -84,7 +83,6 @@ func (s *KeyValidationService) ValidateSingleKey(key *models.APIKey, timeout tim
errorMsg = string(bodyBytes)
}
// This is a validation failure with a specific HTTP status code
return &CustomErrors.APIError{
HTTPStatus: resp.StatusCode,
Message: fmt.Sprintf("Validation failed with status %d: %s", resp.StatusCode, errorMsg),
@@ -92,8 +90,7 @@ func (s *KeyValidationService) ValidateSingleKey(key *models.APIKey, timeout tim
}
}
// --- Async task methods (fully adapted to the new task package) ---
func (s *KeyValidationService) StartTestKeysTask(groupID uint, keysText string) (*task.Status, error) {
func (s *KeyValidationService) StartTestKeysTask(ctx context.Context, groupID uint, keysText string) (*task.Status, error) {
keyStrings := utils.ParseKeysFromText(keysText)
if len(keyStrings) == 0 {
return nil, CustomErrors.NewAPIError(CustomErrors.ErrBadRequest, "no valid keys found in input text")
@@ -111,7 +108,6 @@ func (s *KeyValidationService) StartTestKeysTask(groupID uint, keysText string)
}
group, ok := s.groupManager.GetGroupByID(groupID)
if !ok {
// [FIX] Correctly use the NewAPIError constructor for a missing group.
return nil, CustomErrors.NewAPIError(CustomErrors.ErrGroupNotFound, fmt.Sprintf("group with id %d not found", groupID))
}
opConfig, err := s.groupManager.BuildOperationalConfig(group)
@@ -119,15 +115,15 @@ func (s *KeyValidationService) StartTestKeysTask(groupID uint, keysText string)
return nil, CustomErrors.NewAPIError(CustomErrors.ErrInternalServer, fmt.Sprintf("failed to build operational config: %v", err))
}
resourceID := fmt.Sprintf("group-%d", groupID)
taskStatus, err := s.taskService.StartTask(groupID, TaskTypeTestKeys, resourceID, len(apiKeyModels), time.Hour)
taskStatus, err := s.taskService.StartTask(ctx, groupID, TaskTypeTestKeys, resourceID, len(apiKeyModels), time.Hour)
if err != nil {
return nil, err // Pass up the error from task service (e.g., "task already running")
return nil, err
}
settings := s.SettingsManager.GetSettings()
timeout := time.Duration(settings.KeyCheckTimeoutSeconds) * time.Second
endpoint, err := s.groupManager.BuildKeyCheckEndpoint(groupID)
if err != nil {
s.taskService.EndTaskByID(taskStatus.ID, resourceID, nil, err) // End task with error if endpoint fails
s.taskService.EndTaskByID(ctx, taskStatus.ID, resourceID, nil, err)
return nil, err
}
var concurrency int
@@ -136,11 +132,11 @@ func (s *KeyValidationService) StartTestKeysTask(groupID uint, keysText string)
} else {
concurrency = settings.BaseKeyCheckConcurrency
}
go s.runTestKeysTask(taskStatus.ID, resourceID, groupID, apiKeyModels, timeout, endpoint, concurrency)
go s.runTestKeysTask(context.Background(), taskStatus.ID, resourceID, groupID, apiKeyModels, timeout, endpoint, concurrency)
return taskStatus, nil
}
func (s *KeyValidationService) runTestKeysTask(taskID string, resourceID string, groupID uint, keys []models.APIKey, timeout time.Duration, endpoint string, concurrency int) {
func (s *KeyValidationService) runTestKeysTask(ctx context.Context, taskID string, resourceID string, groupID uint, keys []models.APIKey, timeout time.Duration, endpoint string, concurrency int) {
var wg sync.WaitGroup
var mu sync.Mutex
finalResults := make([]models.KeyTestResult, len(keys))
@@ -165,7 +161,6 @@ func (s *KeyValidationService) runTestKeysTask(taskID string, resourceID string,
var currentResult models.KeyTestResult
event := models.RequestFinishedEvent{
RequestLog: models.RequestLog{
// GroupID and KeyID are pointers on the RequestLog model, so take their addresses
GroupID: &groupID,
KeyID: &apiKeyModel.ID,
},
@@ -185,14 +180,15 @@ func (s *KeyValidationService) runTestKeysTask(taskID string, resourceID string,
event.RequestLog.IsSuccess = false
}
eventData, _ := json.Marshal(event)
if err := s.store.Publish(models.TopicRequestFinished, eventData); err != nil {
if err := s.store.Publish(ctx, models.TopicRequestFinished, eventData); err != nil {
s.logger.WithError(err).Errorf("Failed to publish RequestFinishedEvent for validation of key ID %d", apiKeyModel.ID)
}
mu.Lock()
finalResults[j.Index] = currentResult
processedCount++
_ = s.taskService.UpdateProgressByID(taskID, processedCount)
_ = s.taskService.UpdateProgressByID(ctx, taskID, processedCount)
mu.Unlock()
}
}()
@@ -202,10 +198,10 @@ func (s *KeyValidationService) runTestKeysTask(taskID string, resourceID string,
}
close(jobs)
wg.Wait()
s.taskService.EndTaskByID(taskID, resourceID, gin.H{"results": finalResults}, nil)
s.taskService.EndTaskByID(ctx, taskID, resourceID, gin.H{"results": finalResults}, nil)
}
func (s *KeyValidationService) StartTestKeysByFilterTask(groupID uint, statuses []string) (*task.Status, error) {
func (s *KeyValidationService) StartTestKeysByFilterTask(ctx context.Context, groupID uint, statuses []string) (*task.Status, error) {
s.logger.Infof("Starting test task for group %d with status filter: %v", groupID, statuses)
keyValues, err := s.keyRepo.FindKeyValuesByStatus(groupID, statuses)
if err != nil {
@@ -216,5 +212,5 @@ func (s *KeyValidationService) StartTestKeysByFilterTask(groupID uint, statuses
}
keysAsText := strings.Join(keyValues, "\n")
s.logger.Infof("Found %d keys to validate for group %d.", len(keyValues), groupID)
return s.StartTestKeysTask(groupID, keysAsText)
return s.StartTestKeysTask(ctx, groupID, keysAsText)
}
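
Validation itself stays a plain HTTP probe: build a GET against the group's check endpoint, inject the key, treat HTTP 200 as healthy, and fold any other status plus the response body into the returned error. A self-contained sketch under those assumptions (the header name and endpoint are placeholders, not confirmed by this diff):

```go
package example

import (
	"context"
	"fmt"
	"io"
	"net/http"
	"time"
)

// ProbeKey issues one validation request; a nil return means the key is healthy.
func ProbeKey(ctx context.Context, endpoint, apiKey string, timeout time.Duration) error {
	client := &http.Client{Timeout: timeout}

	req, err := http.NewRequestWithContext(ctx, http.MethodGet, endpoint, nil)
	if err != nil {
		return fmt.Errorf("failed to create request: %w", err)
	}
	req.Header.Set("x-goog-api-key", apiKey) // assumption: Gemini-style key header

	resp, err := client.Do(req)
	if err != nil {
		return fmt.Errorf("request failed: %w", err) // network-level failure
	}
	defer resp.Body.Close()

	if resp.StatusCode == http.StatusOK {
		return nil
	}
	body, _ := io.ReadAll(resp.Body)
	return fmt.Errorf("validation failed with status %d: %s", resp.StatusCode, body)
}
```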

View File

@@ -1,9 +1,9 @@
// Filename: internal/service/resource_service.go
package service
import (
"errors"
"context"
"gemini-balancer/internal/domain/proxy"
apperrors "gemini-balancer/internal/errors"
"gemini-balancer/internal/models"
"gemini-balancer/internal/repository"
@@ -15,10 +15,7 @@ import (
"github.com/sirupsen/logrus"
)
var (
ErrNoResourceAvailable = errors.New("no available resource found for the request")
)
// RequestResources bundles all resources required for a single successful request.
type RequestResources struct {
KeyGroup *models.KeyGroup
APIKey *models.APIKey
@@ -27,86 +24,92 @@ type RequestResources struct {
RequestConfig *models.RequestConfig
}
// ResourceService dynamically selects and allocates API keys and related resources based on request parameters and business rules.
type ResourceService struct {
settingsManager *settings.SettingsManager
groupManager *GroupManager
keyRepo repository.KeyRepository
authTokenRepo repository.AuthTokenRepository
apiKeyService *APIKeyService
proxyManager *proxy.Module
logger *logrus.Entry
initOnce sync.Once
}
// NewResourceService creates and initializes a new ResourceService instance.
func NewResourceService(
sm *settings.SettingsManager,
gm *GroupManager,
kr repository.KeyRepository,
atr repository.AuthTokenRepository,
aks *APIKeyService,
pm *proxy.Module,
logger *logrus.Logger,
) *ResourceService {
logger.Debugf("[FORENSIC PROBE | INJECTION | ResourceService] Received 'keyRepo' param. Fingerprint: %p", kr)
rs := &ResourceService{
settingsManager: sm,
groupManager: gm,
keyRepo: kr,
authTokenRepo: atr,
apiKeyService: aks,
proxyManager: pm,
logger: logger.WithField("component", "ResourceService📦"),
}
// Use sync.Once to ensure the pre-warming task runs only once during the service lifetime
rs.initOnce.Do(func() {
go rs.preWarmCache(logger)
go rs.preWarmCache()
})
return rs
}
// --- [Mode 1: smart aggregation] ---
func (s *ResourceService) GetResourceFromBasePool(authToken *models.AuthToken, modelName string) (*RequestResources, error) {
// GetResourceFromBasePool acquires resources using the smart aggregation (BasePool) mode.
func (s *ResourceService) GetResourceFromBasePool(ctx context.Context, authToken *models.AuthToken, modelName string) (*RequestResources, error) {
log := s.logger.WithFields(logrus.Fields{"token_id": authToken.ID, "model_name": modelName, "mode": "BasePool"})
log.Debug("Entering BasePool resource acquisition.")
// 1. Filter out all eligible candidate groups and sort them by priority
candidateGroups := s.filterAndSortCandidateGroups(modelName, authToken.AllowedGroups)
candidateGroups := s.filterAndSortCandidateGroups(modelName, authToken)
if len(candidateGroups) == 0 {
log.Warn("No candidate groups found for BasePool construction.")
return nil, apperrors.ErrNoKeysAvailable
}
// 2. Select one key from the BasePool using the system-wide polling strategy
basePool := &repository.BasePool{
CandidateGroups: candidateGroups,
PollingStrategy: s.settingsManager.GetSettings().PollingStrategy,
}
apiKey, selectedGroup, err := s.keyRepo.SelectOneActiveKeyFromBasePool(basePool)
apiKey, selectedGroup, err := s.keyRepo.SelectOneActiveKeyFromBasePool(ctx, basePool)
if err != nil {
log.WithError(err).Warn("Failed to select a key from the BasePool.")
return nil, apperrors.ErrNoKeysAvailable
}
// 3. Assemble the final resources
// [Key point] In this mode RequestConfig is always empty, to keep the forwarding transparent.
resources, err := s.assembleRequestResources(selectedGroup, apiKey)
if err != nil {
log.WithError(err).Error("Failed to assemble resources after selecting key from BasePool.")
return nil, err
}
resources.RequestConfig = &models.RequestConfig{} // force empty
resources.RequestConfig = &models.RequestConfig{} // BasePool mode uses the default request config
log.Infof("Successfully selected KeyID %d from GroupID %d for the BasePool.", apiKey.ID, selectedGroup.ID)
return resources, nil
}
// --- [Mode 2: precise routing] ---
func (s *ResourceService) GetResourceFromGroup(authToken *models.AuthToken, groupName string) (*RequestResources, error) {
// GetResourceFromGroup acquires resources via precise routing to a specified key group.
func (s *ResourceService) GetResourceFromGroup(ctx context.Context, authToken *models.AuthToken, groupName string) (*RequestResources, error) {
log := s.logger.WithFields(logrus.Fields{"token_id": authToken.ID, "group_name": groupName, "mode": "PreciseRoute"})
log.Debug("Entering PreciseRoute resource acquisition.")
targetGroup, ok := s.groupManager.GetGroupByName(groupName)
if !ok {
return nil, apperrors.NewAPIError(apperrors.ErrGroupNotFound, "The specified group does not exist.")
}
if !s.isTokenAllowedForGroup(authToken, targetGroup.ID) {
return nil, apperrors.NewAPIError(apperrors.ErrPermissionDenied, "Token does not have permission to access this group.")
}
apiKey, _, err := s.keyRepo.SelectOneActiveKey(targetGroup)
apiKey, _, err := s.keyRepo.SelectOneActiveKey(ctx, targetGroup)
if err != nil {
log.WithError(err).Warn("Failed to select a key from the precisely targeted group.")
return nil, apperrors.ErrNoKeysAvailable
@@ -117,39 +120,39 @@ func (s *ResourceService) GetResourceFromGroup(authToken *models.AuthToken, grou
log.WithError(err).Error("Failed to assemble resources for precise route.")
return nil, err
}
resources.RequestConfig = targetGroup.RequestConfig
resources.RequestConfig = targetGroup.RequestConfig // precise routing uses the group's specific request config
log.Infof("Successfully selected KeyID %d by precise routing to GroupID %d.", apiKey.ID, targetGroup.ID)
return resources, nil
}
// GetAllowedModelsForToken returns the list of model names the given auth token is allowed to access.
func (s *ResourceService) GetAllowedModelsForToken(authToken *models.AuthToken) []string {
allGroups := s.groupManager.GetAllGroups()
if len(allGroups) == 0 {
return []string{}
}
allowedModelsSet := make(map[string]struct{})
allowedGroupIDs := make(map[uint]bool)
if authToken.IsAdmin {
for _, group := range allGroups {
for _, modelMapping := range group.AllowedModels {
allowedModelsSet[modelMapping.ModelName] = struct{}{}
}
allowedGroupIDs[group.ID] = true
}
} else {
allowedGroupIDs := make(map[uint]bool)
for _, ag := range authToken.AllowedGroups {
allowedGroupIDs[ag.ID] = true
}
for _, group := range allGroups {
if _, ok := allowedGroupIDs[group.ID]; ok {
for _, modelMapping := range group.AllowedModels {
}
allowedModelsSet[modelMapping.ModelName] = struct{}{}
}
allowedModelsSet := make(map[string]struct{})
for _, group := range allGroups {
if allowedGroupIDs[group.ID] {
for _, modelMapping := range group.AllowedModels {
allowedModelsSet[modelMapping.ModelName] = struct{}{}
}
}
}
result := make([]string, 0, len(allowedModelsSet))
for modelName := range allowedModelsSet {
result = append(result, modelName)
@@ -158,20 +161,52 @@ func (s *ResourceService) GetAllowedModelsForToken(authToken *models.AuthToken)
return result
}
// ReportRequestResult reports the final outcome of a request to APIKeyService so key status can be updated.
func (s *ResourceService) ReportRequestResult(resources *RequestResources, success bool, apiErr *apperrors.APIError) {
if resources == nil || resources.KeyGroup == nil || resources.APIKey == nil {
return
}
s.apiKeyService.HandleRequestResult(resources.KeyGroup, resources.APIKey, success, apiErr)
}
// --- Private helpers ---
// preWarmCache performs a one-time cache pre-warming task in the background.
func (s *ResourceService) preWarmCache() {
time.Sleep(2 * time.Second) // give other service components time to finish initializing
s.logger.Info("Performing initial key cache pre-warming...")
// Force-load the GroupManager cache
s.logger.Info("Pre-warming GroupManager cache...")
_ = s.groupManager.GetAllGroups()
ctx, cancel := context.WithTimeout(context.Background(), 60*time.Second) // allow a generous timeout
defer cancel()
if err := s.keyRepo.LoadAllKeysToStore(ctx); err != nil {
s.logger.WithError(err).Error("Failed to perform initial key cache pre-warming.")
} else {
s.logger.Info("Initial key cache pre-warming completed successfully.")
}
}
// assembleRequestResources assembles the final resource object from the key group and API key.
func (s *ResourceService) assembleRequestResources(group *models.KeyGroup, apiKey *models.APIKey) (*RequestResources, error) {
selectedUpstream := s.selectUpstreamForGroup(group)
if selectedUpstream == nil {
return nil, apperrors.NewAPIError(apperrors.ErrConfigurationError, "Selected group has no valid upstream and no global default is set.")
}
var proxyConfig *models.ProxyConfig
// [Note] The proxy logic needs a proxyModule instance; it is left nil for now and the dependency will be re-injected later.
// if group.EnableProxy && s.proxyModule != nil {
// var err error
// proxyConfig, err = s.proxyModule.AssignProxyIfNeeded(apiKey)
// if err != nil {
// s.logger.WithError(err).Warnf("Failed to assign proxy for API key %d.", apiKey.ID)
// }
// }
var err error
// Only assign a proxy when the group explicitly enables it
if group.EnableProxy {
proxyConfig, err = s.proxyManager.AssignProxyIfNeeded(apiKey)
if err != nil {
s.logger.WithError(err).Errorf("Group '%s' (ID: %d) requires a proxy, but failed to assign one for KeyID %d", group.Name, group.ID, apiKey.ID)
// Per business requirements this must return an error, because a proxy is mandatory for this group
return nil, apperrors.NewAPIError(apperrors.ErrProxyNotAvailable, "Required proxy is not available for this request.")
}
}
return &RequestResources{
KeyGroup: group,
APIKey: apiKey,
@@ -180,8 +215,10 @@ func (s *ResourceService) assembleRequestResources(group *models.KeyGroup, apiKe
}, nil
}
// selectUpstreamForGroup selects an upstream endpoint for the given key group.
func (s *ResourceService) selectUpstreamForGroup(group *models.KeyGroup) *models.UpstreamEndpoint {
if len(group.AllowedUpstreams) > 0 {
// (load-balancing logic can be added here later)
return group.AllowedUpstreams[0]
}
globalSettings := s.settingsManager.GetSettings()
@@ -191,62 +228,39 @@ func (s *ResourceService) selectUpstreamForGroup(group *models.KeyGroup) *models
return nil
}
func (s *ResourceService) preWarmCache(logger *logrus.Logger) error {
time.Sleep(2 * time.Second)
s.logger.Info("Performing initial key cache pre-warming...")
if err := s.keyRepo.LoadAllKeysToStore(); err != nil {
logger.WithError(err).Error("Failed to perform initial key cache pre-warming.")
return err
}
s.logger.Info("Initial key cache pre-warming completed successfully.")
return nil
}
func (s *ResourceService) GetResourcesForRequest(modelName string, allowedGroups []*models.KeyGroup) (*RequestResources, error) {
return nil, errors.New("GetResourcesForRequest is deprecated; use GetResourceFromBasePool or GetResourceFromGroup")
}
func (s *ResourceService) filterAndSortCandidateGroups(modelName string, allowedGroupsFromToken []*models.KeyGroup) []*models.KeyGroup {
// filterAndSortCandidateGroups filters and sorts eligible candidate key groups by model name and token permissions.
func (s *ResourceService) filterAndSortCandidateGroups(modelName string, authToken *models.AuthToken) []*models.KeyGroup {
allGroupsFromCache := s.groupManager.GetAllGroups()
var candidateGroups []*models.KeyGroup
// 1. Determine the permission scope
allowedGroupIDs := make(map[uint]bool)
isTokenRestricted := len(allowedGroupsFromToken) > 0
if isTokenRestricted {
for _, ag := range allowedGroupsFromToken {
allowedGroupIDs[ag.ID] = true
}
}
// 2. Filter
for _, group := range allGroupsFromCache {
// Check Token permissions
if isTokenRestricted && !allowedGroupIDs[group.ID] {
// Check token permissions
if !s.isTokenAllowedForGroup(authToken, group.ID) {
continue
}
// Check whether the model is allowed
isModelAllowed := false
if len(group.AllowedModels) == 0 { // allow if the group does not restrict models
isModelAllowed = true
} else {
for _, m := range group.AllowedModels {
if m.ModelName == modelName {
isModelAllowed = true
break
}
}
}
if isModelAllowed {
// Check model support (if the group does not restrict models, all models are supported by default)
if len(group.AllowedModels) == 0 || s.groupSupportsModel(group, modelName) {
candidateGroups = append(candidateGroups, group)
}
}
// 3. Sort ascending by the Order field
sort.SliceStable(candidateGroups, func(i, j int) bool {
return candidateGroups[i].Order < candidateGroups[j].Order
})
return candidateGroups
}
// groupSupportsModel reports whether the given key group supports the given model name.
func (s *ResourceService) groupSupportsModel(group *models.KeyGroup, modelName string) bool {
for _, m := range group.AllowedModels {
if m.ModelName == modelName {
return true
}
}
return false
}
// isTokenAllowedForGroup reports whether the given auth token is allowed to access the given key group.
func (s *ResourceService) isTokenAllowedForGroup(authToken *models.AuthToken, groupID uint) bool {
if authToken.IsAdmin {
return true
@@ -258,10 +272,3 @@ func (s *ResourceService) isTokenAllowedForGroup(authToken *models.AuthToken, gr
}
return false
}
func (s *ResourceService) ReportRequestResult(resources *RequestResources, success bool, apiErr *apperrors.APIError) {
if resources == nil || resources.KeyGroup == nil || resources.APIKey == nil {
return
}
s.apiKeyService.HandleRequestResult(resources.KeyGroup, resources.APIKey, success, apiErr)
}
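
BasePool candidate selection boils down to two filters and a sort: drop groups the token may not use, drop groups that do not support the requested model (an empty allow-list means every model is allowed), and order the survivors by their Order field. A simplified sketch with stand-in types (not the project's real models):

```go
package example

import "sort"

type Group struct {
	ID            uint
	Order         int
	AllowedModels []string // empty slice means "all models allowed"
}

// FilterAndSort keeps only groups the token may use and that support the model,
// then sorts them ascending by Order.
func FilterAndSort(groups []*Group, model string, allowed func(groupID uint) bool) []*Group {
	var candidates []*Group
	for _, g := range groups {
		if !allowed(g.ID) {
			continue
		}
		if len(g.AllowedModels) == 0 || contains(g.AllowedModels, model) {
			candidates = append(candidates, g)
		}
	}
	sort.SliceStable(candidates, func(i, j int) bool {
		return candidates[i].Order < candidates[j].Order
	})
	return candidates
}

func contains(list []string, v string) bool {
	for _, s := range list {
		if s == v {
			return true
		}
	}
	return false
}
```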

View File

@@ -52,7 +52,7 @@ func (s *SecurityService) AuthenticateToken(tokenValue string) (*models.AuthToke
// IsIPBanned
func (s *SecurityService) IsIPBanned(ctx context.Context, ip string) (bool, error) {
banKey := fmt.Sprintf("banned_ip:%s", ip)
return s.store.Exists(banKey)
return s.store.Exists(ctx, banKey)
}
// RecordFailedLoginAttempt
@@ -61,7 +61,7 @@ func (s *SecurityService) RecordFailedLoginAttempt(ctx context.Context, ip strin
return nil
}
count, err := s.store.HIncrBy(loginAttemptsKey, ip, 1)
count, err := s.store.HIncrBy(ctx, loginAttemptsKey, ip, 1)
if err != nil {
return err
}
@@ -71,12 +71,12 @@ func (s *SecurityService) RecordFailedLoginAttempt(ctx context.Context, ip strin
banDuration := s.SettingsManager.GetIPBanDuration()
banKey := fmt.Sprintf("banned_ip:%s", ip)
if err := s.store.Set(banKey, []byte("1"), banDuration); err != nil {
if err := s.store.Set(ctx, banKey, []byte("1"), banDuration); err != nil {
return err
}
s.logger.Warnf("IP BANNED: IP [%s] has been banned for %v due to excessive failed login attempts.", ip, banDuration)
s.store.HDel(loginAttemptsKey, ip)
s.store.HDel(ctx, loginAttemptsKey, ip)
}
return nil
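
The ban flow above is a counter plus a TTL key: failed attempts accumulate in a hash, crossing the threshold writes a banned_ip entry with an expiry, and the counter is cleared. A hedged sketch against a generic store interface (the interface below only mirrors the calls visible in this diff and is not the project's actual store.Store definition):

```go
package example

import (
	"context"
	"fmt"
	"time"
)

// Store is a hypothetical key-value abstraction with the operations used above.
type Store interface {
	HIncrBy(ctx context.Context, key, field string, delta int64) (int64, error)
	Set(ctx context.Context, key string, value []byte, ttl time.Duration) error
	HDel(ctx context.Context, key string, fields ...string) error
	Exists(ctx context.Context, key string) (bool, error)
}

// RecordFailedLogin increments the per-IP counter and bans the IP once the
// threshold is reached, clearing the counter afterwards.
func RecordFailedLogin(ctx context.Context, s Store, ip string, maxAttempts int64, banFor time.Duration) error {
	count, err := s.HIncrBy(ctx, "login_attempts", ip, 1)
	if err != nil {
		return err
	}
	if count < maxAttempts {
		return nil
	}
	banKey := fmt.Sprintf("banned_ip:%s", ip)
	if err := s.Set(ctx, banKey, []byte("1"), banFor); err != nil {
		return err
	}
	return s.HDel(ctx, "login_attempts", ip) // reset the counter once banned
}
```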

View File

@@ -2,6 +2,7 @@
package service
import (
"context"
"encoding/json"
"fmt"
"gemini-balancer/internal/models"
@@ -34,7 +35,7 @@ func NewStatsService(db *gorm.DB, s store.Store, repo repository.KeyRepository,
func (s *StatsService) Start() {
s.logger.Info("Starting event listener for stats maintenance.")
sub, err := s.store.Subscribe(models.TopicKeyStatusChanged)
sub, err := s.store.Subscribe(context.Background(), models.TopicKeyStatusChanged)
if err != nil {
s.logger.Fatalf("Failed to subscribe to topic %s: %v", models.TopicKeyStatusChanged, err)
return
@@ -67,42 +68,43 @@ func (s *StatsService) handleKeyStatusChange(event *models.KeyStatusChangedEvent
s.logger.Warnf("Received KeyStatusChangedEvent with no GroupID. Reason: %s, KeyID: %d. Skipping.", event.ChangeReason, event.KeyID)
return
}
ctx := context.Background()
statsKey := fmt.Sprintf("stats:group:%d", event.GroupID)
s.logger.Infof("Handling key status change for Group %d, KeyID: %d, Reason: %s", event.GroupID, event.KeyID, event.ChangeReason)
switch event.ChangeReason {
case "key_unlinked", "key_hard_deleted":
if event.OldStatus != "" {
s.store.HIncrBy(statsKey, "total_keys", -1)
s.store.HIncrBy(statsKey, fmt.Sprintf("%s_keys", event.OldStatus), -1)
s.store.HIncrBy(ctx, statsKey, "total_keys", -1)
s.store.HIncrBy(ctx, statsKey, fmt.Sprintf("%s_keys", event.OldStatus), -1)
} else {
s.logger.Warnf("Received '%s' event for group %d without OldStatus, forcing recalculation.", event.ChangeReason, event.GroupID)
s.RecalculateGroupKeyStats(event.GroupID)
s.RecalculateGroupKeyStats(ctx, event.GroupID)
}
case "key_linked":
if event.NewStatus != "" {
s.store.HIncrBy(statsKey, "total_keys", 1)
s.store.HIncrBy(statsKey, fmt.Sprintf("%s_keys", event.NewStatus), 1)
s.store.HIncrBy(ctx, statsKey, "total_keys", 1)
s.store.HIncrBy(ctx, statsKey, fmt.Sprintf("%s_keys", event.NewStatus), 1)
} else {
s.logger.Warnf("Received 'key_linked' event for group %d without NewStatus, forcing recalculation.", event.GroupID)
s.RecalculateGroupKeyStats(event.GroupID)
s.RecalculateGroupKeyStats(ctx, event.GroupID)
}
case "manual_update", "error_threshold_reached", "key_recovered", "invalid_api_key":
s.store.HIncrBy(statsKey, fmt.Sprintf("%s_keys", event.OldStatus), -1)
s.store.HIncrBy(statsKey, fmt.Sprintf("%s_keys", event.NewStatus), 1)
s.store.HIncrBy(ctx, statsKey, fmt.Sprintf("%s_keys", event.OldStatus), -1)
s.store.HIncrBy(ctx, statsKey, fmt.Sprintf("%s_keys", event.NewStatus), 1)
default:
s.logger.Warnf("Unhandled event reason '%s' for group %d, forcing recalculation.", event.ChangeReason, event.GroupID)
s.RecalculateGroupKeyStats(event.GroupID)
s.RecalculateGroupKeyStats(ctx, event.GroupID)
}
}
func (s *StatsService) RecalculateGroupKeyStats(groupID uint) error {
func (s *StatsService) RecalculateGroupKeyStats(ctx context.Context, groupID uint) error {
s.logger.Warnf("Performing full recalculation for group %d key stats.", groupID)
var results []struct {
Status models.APIKeyStatus
Count int64
}
if err := s.db.Model(&models.GroupAPIKeyMapping{}).
if err := s.db.WithContext(ctx).Model(&models.GroupAPIKeyMapping{}).
Where("key_group_id = ?", groupID).
Select("status, COUNT(*) as count").
Group("status").
@@ -119,37 +121,25 @@ func (s *StatsService) RecalculateGroupKeyStats(groupID uint) error {
}
updates["total_keys"] = totalKeys
if err := s.store.Del(statsKey); err != nil {
if err := s.store.Del(ctx, statsKey); err != nil {
s.logger.WithError(err).Warnf("Failed to delete stale stats key for group %d before recalculation.", groupID)
}
if err := s.store.HSet(statsKey, updates); err != nil {
if err := s.store.HSet(ctx, statsKey, updates); err != nil {
return fmt.Errorf("failed to HSet recalculated stats for group %d: %w", groupID, err)
}
s.logger.Infof("Successfully recalculated stats for group %d using HSet.", groupID)
return nil
}
func (s *StatsService) GetDashboardStats() (*models.DashboardStatsResponse, error) {
// TODO logic:
// 1. Fetch key statistics for all groups from Redis (HGetAll)
// 2. Fetch the last 24 hours of request counts and error rates from the stats_hourly table
// 3. Combine them into a DashboardStatsResponse
// ... The concrete implementation can live in DashboardQueryService;
// here we only make sure StatsService's core responsibility (maintaining the cache) is covered.
// To keep the build compiling, return an empty object for now.
// Pseudocode:
// keyCounts, _ := s.store.HGetAll("stats:global:keys")
// ...
func (s *StatsService) GetDashboardStats(ctx context.Context) (*models.DashboardStatsResponse, error) {
return &models.DashboardStatsResponse{}, nil
}
func (s *StatsService) AggregateHourlyStats() error {
func (s *StatsService) AggregateHourlyStats(ctx context.Context) error {
s.logger.Info("Starting aggregation of the last hour's request data...")
now := time.Now()
endTime := now.Truncate(time.Hour) // e.g. 15:23 -> 15:00
startTime := endTime.Add(-1 * time.Hour) // 15:00 -> 14:00
endTime := now.Truncate(time.Hour)
startTime := endTime.Add(-1 * time.Hour)
s.logger.Infof("Aggregating data for time window: [%s, %s)", startTime.Format(time.RFC3339), endTime.Format(time.RFC3339))
type aggregationResult struct {
@@ -161,7 +151,8 @@ func (s *StatsService) AggregateHourlyStats() error {
CompletionTokens int64
}
var results []aggregationResult
err := s.db.Model(&models.RequestLog{}).
err := s.db.WithContext(ctx).Model(&models.RequestLog{}).
Select("group_id, model_name, COUNT(*) as request_count, SUM(CASE WHEN is_success = true THEN 1 ELSE 0 END) as success_count, SUM(prompt_tokens) as prompt_tokens, SUM(completion_tokens) as completion_tokens").
Where("request_time >= ? AND request_time < ?", startTime, endTime).
Group("group_id, model_name").
@@ -179,7 +170,7 @@ func (s *StatsService) AggregateHourlyStats() error {
var hourlyStats []models.StatsHourly
for _, res := range results {
hourlyStats = append(hourlyStats, models.StatsHourly{
Time: startTime, // every record in this batch is stamped with the start of the hour
Time: startTime,
GroupID: res.GroupID,
ModelName: res.ModelName,
RequestCount: res.RequestCount,
@@ -189,7 +180,7 @@ func (s *StatsService) AggregateHourlyStats() error {
})
}
return s.db.Clauses(clause.OnConflict{
return s.db.WithContext(ctx).Clauses(clause.OnConflict{
Columns: []clause.Column{{Name: "time"}, {Name: "group_id"}, {Name: "model_name"}},
DoUpdates: clause.AssignmentColumns([]string{"request_count", "success_count", "prompt_tokens", "completion_tokens"}),
}).Create(&hourlyStats).Error
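A minimal sketch (not part of this diff) of how a scheduler might drive the hourly rollup now that it takes a context; the 5-minute deadline is illustrative.
func runHourlyAggregation(stats *StatsService) error {
	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute)
	defer cancel()
	return stats.AggregateHourlyStats(ctx)
}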

View File

@@ -1,4 +1,4 @@
// file: gemini-balancer\internal\settings\settings.go
// Filename: gemini-balancer/internal/settings/settings.go (final audited revision)
package settings
import (
@@ -19,7 +19,9 @@ import (
const SettingsUpdateChannel = "system_settings:updated"
const DefaultGeminiEndpoint = "https://generativelanguage.googleapis.com/v1beta/models"
// SettingsManager [core fix] the syncer now caches the correct "blueprint" type
var _ models.SettingsManager = (*SettingsManager)(nil)
// SettingsManager manages the system's dynamic settings: loading from the database, cache synchronization, and updates.
type SettingsManager struct {
db *gorm.DB
syncer *syncer.CacheSyncer[*models.SystemSettings]
@@ -27,13 +29,14 @@ type SettingsManager struct {
jsonToFieldType map[string]reflect.Type // maps JSON tags to Go field types
}
// NewSettingsManager creates a new SettingsManager instance.
func NewSettingsManager(db *gorm.DB, store store.Store, logger *logrus.Logger) (*SettingsManager, error) {
sm := &SettingsManager{
db: db,
logger: logger.WithField("component", "SettingsManager⚙"),
jsonToFieldType: make(map[string]reflect.Type),
}
// settingsLoader's job: read the "bricks", then assemble and return the "blueprint"
settingsType := reflect.TypeOf(models.SystemSettings{})
for i := 0; i < settingsType.NumField(); i++ {
field := settingsType.Field(i)
@@ -42,102 +45,89 @@ func NewSettingsManager(db *gorm.DB, store store.Store, logger *logrus.Logger) (
sm.jsonToFieldType[jsonTag] = field.Type
}
}
// settingsLoader's job: read the "bricks" and assemble them into the "blueprint"
settingsLoader := func() (*models.SystemSettings, error) {
sm.logger.Info("Loading system settings from database...")
var dbRecords []models.Setting
if err := sm.db.Find(&dbRecords).Error; err != nil {
return nil, fmt.Errorf("failed to load system settings from db: %w", err)
}
settingsMap := make(map[string]string)
for _, record := range dbRecords {
settingsMap[record.Key] = record.Value
}
// Start from a "blueprint" pre-filled with all factory defaults
settings := defaultSystemSettings()
v := reflect.ValueOf(settings).Elem()
t := v.Type()
// [smart unpacking]
for i := 0; i < t.NumField(); i++ {
field := t.Field(i)
for i := 0; i < v.NumField(); i++ {
field := v.Type().Field(i)
fieldValue := v.Field(i)
jsonTag := field.Tag.Get("json")
if dbValue, ok := settingsMap[jsonTag]; ok {
if dbValue, ok := settingsMap[jsonTag]; ok {
if err := parseAndSetField(fieldValue, dbValue); err != nil {
sm.logger.Warnf("Failed to set config field '%s' from DB value '%s': %v. Using default.", field.Name, dbValue, err)
sm.logger.Warnf("Failed to set field '%s' from DB value '%s': %v. Using default.", field.Name, dbValue, err)
}
}
}
if settings.BaseKeyCheckEndpoint == DefaultGeminiEndpoint || settings.BaseKeyCheckEndpoint == "" {
if settings.DefaultUpstreamURL != "" {
// If the global upstream URL is set, build the new check endpoint from it.
originalEndpoint := settings.BaseKeyCheckEndpoint
derivedEndpoint := strings.TrimSuffix(settings.DefaultUpstreamURL, "/") + "/models"
settings.BaseKeyCheckEndpoint = derivedEndpoint
sm.logger.Infof(
"BaseKeyCheckEndpoint is dynamically derived from DefaultUpstreamURL. Original: '%s', New: '%s'",
originalEndpoint, derivedEndpoint,
)
}
} else {
// [review note] The derivation logic matches the original version's behavior and logging exactly.
if (settings.BaseKeyCheckEndpoint == DefaultGeminiEndpoint || settings.BaseKeyCheckEndpoint == "") && settings.DefaultUpstreamURL != "" {
derivedEndpoint := strings.TrimSuffix(settings.DefaultUpstreamURL, "/") + "/models"
sm.logger.Infof("BaseKeyCheckEndpoint is dynamically derived from DefaultUpstreamURL: %s", derivedEndpoint)
settings.BaseKeyCheckEndpoint = derivedEndpoint
} else if settings.BaseKeyCheckEndpoint != DefaultGeminiEndpoint && settings.BaseKeyCheckEndpoint != "" {
// Keep the else-branch log so the user knows a custom override is in effect.
sm.logger.Infof("BaseKeyCheckEndpoint is using a user-defined override: %s", settings.BaseKeyCheckEndpoint)
}
sm.logger.Info("System settings loaded and cached.")
sm.DisplaySettings(settings)
return settings, nil
}
s, err := syncer.NewCacheSyncer(settingsLoader, store, SettingsUpdateChannel)
if err != nil {
return nil, fmt.Errorf("failed to create system settings syncer: %w", err)
}
sm.syncer = s
go sm.ensureSettingsInitialized()
if err := sm.ensureSettingsInitialized(); err != nil {
return nil, fmt.Errorf("failed to ensure system settings are initialized: %w", err)
}
return sm, nil
}
// GetSettings [core fix] it now correctly returns the "blueprint" we need
// GetSettings returns the currently cached system settings.
func (sm *SettingsManager) GetSettings() *models.SystemSettings {
return sm.syncer.Get()
}
// UpdateSettings [core fix] it takes the updates and converts them into "bricks" stored in the database
// UpdateSettings updates one or more system settings.
func (sm *SettingsManager) UpdateSettings(settingsMap map[string]interface{}) error {
var settingsToUpdate []models.Setting
for key, value := range settingsMap {
fieldType, ok := sm.jsonToFieldType[key]
if !ok {
sm.logger.Warnf("Received update for unknown setting key '%s', ignoring.", key)
continue
}
var dbValue string
// [smart packing]
// If the field is a slice or map, "pack" the incoming interface{} into a JSON string
kind := fieldType.Kind()
if kind == reflect.Slice || kind == reflect.Map {
jsonBytes, marshalErr := json.Marshal(value)
if marshalErr != nil {
// [real error handling] If packing fails, log it and skip this "broken container".
sm.logger.Warnf("Failed to marshal setting '%s' to JSON: %v, skipping update.", key, marshalErr)
continue // skip this key and move on to the next one
}
dbValue = string(jsonBytes)
} else if kind == reflect.Bool {
if b, ok := value.(bool); ok {
dbValue = strconv.FormatBool(b)
} else {
dbValue = "false"
}
} else {
dbValue = fmt.Sprintf("%v", value)
dbValue, err := sm.convertToDBValue(key, value, fieldType)
if err != nil {
sm.logger.Warnf("Failed to convert value for setting '%s': %v. Skipping update.", key, err)
continue
}
settingsToUpdate = append(settingsToUpdate, models.Setting{
Key: key,
Value: dbValue,
})
}
if len(settingsToUpdate) > 0 {
err := sm.db.Clauses(clause.OnConflict{
Columns: []clause.Column{{Name: "key"}},
@@ -147,83 +137,20 @@ func (sm *SettingsManager) UpdateSettings(settingsMap map[string]interface{}) er
return fmt.Errorf("failed to update settings in db: %w", err)
}
}
return sm.syncer.Invalidate()
}
// ensureSettingsInitialized [core fix] makes sure the DB has a definition for every "brick"
func (sm *SettingsManager) ensureSettingsInitialized() {
defaults := defaultSystemSettings()
v := reflect.ValueOf(defaults).Elem()
t := v.Type()
for i := 0; i < t.NumField(); i++ {
field := t.Field(i)
fieldValue := v.Field(i)
key := field.Tag.Get("json")
if key == "" || key == "-" {
continue
}
var existing models.Setting
if err := sm.db.Where("key = ?", key).First(&existing).Error; err == gorm.ErrRecordNotFound {
var defaultValue string
kind := fieldValue.Kind()
// [smart initialization]
if kind == reflect.Slice || kind == reflect.Map {
// For complex types, generate an "empty" JSON string, e.g. "[]" or "{}"
jsonBytes, _ := json.Marshal(fieldValue.Interface())
defaultValue = string(jsonBytes)
} else {
defaultValue = field.Tag.Get("default")
}
setting := models.Setting{
Key: key,
Value: defaultValue,
Name: field.Tag.Get("name"),
Description: field.Tag.Get("desc"),
Category: field.Tag.Get("category"),
DefaultValue: field.Tag.Get("default"), // the default in the metadata always comes from the tag
}
if err := sm.db.Create(&setting).Error; err != nil {
sm.logger.Errorf("Failed to initialize setting '%s': %v", key, err)
}
}
if err := sm.syncer.Invalidate(); err != nil {
sm.logger.Errorf("CRITICAL: Database settings updated, but cache invalidation failed: %v", err)
return fmt.Errorf("settings updated but cache invalidation failed, system may be inconsistent: %w", err)
}
return nil
}
// ResetAndSaveSettings [new] resets every setting to the value defined in its 'default' tag.
// ResetAndSaveSettings resets all settings to their default values.
func (sm *SettingsManager) ResetAndSaveSettings() (*models.SystemSettings, error) {
defaults := defaultSystemSettings()
v := reflect.ValueOf(defaults).Elem()
t := v.Type()
var settingsToSave []models.Setting
for i := 0; i < t.NumField(); i++ {
field := t.Field(i)
fieldValue := v.Field(i)
key := field.Tag.Get("json")
if key == "" || key == "-" {
continue
}
settingsToSave := sm.buildSettingsFromDefaults(defaults)
var defaultValue string
kind := fieldValue.Kind()
// [smart reset]
if kind == reflect.Slice || kind == reflect.Map {
jsonBytes, _ := json.Marshal(fieldValue.Interface())
defaultValue = string(jsonBytes)
} else {
defaultValue = field.Tag.Get("default")
}
setting := models.Setting{
Key: key,
Value: defaultValue,
Name: field.Tag.Get("name"),
Description: field.Tag.Get("desc"),
Category: field.Tag.Get("category"),
DefaultValue: field.Tag.Get("default"),
}
settingsToSave = append(settingsToSave, setting)
}
if len(settingsToSave) > 0 {
err := sm.db.Clauses(clause.OnConflict{
Columns: []clause.Column{{Name: "key"}},
@@ -233,8 +160,93 @@ func (sm *SettingsManager) ResetAndSaveSettings() (*models.SystemSettings, error
return nil, fmt.Errorf("failed to reset settings in db: %w", err)
}
}
if err := sm.syncer.Invalidate(); err != nil {
sm.logger.Errorf("Failed to invalidate settings cache after reset: %v", err)
sm.logger.Errorf("CRITICAL: Database settings reset, but cache invalidation failed: %v", err)
return nil, fmt.Errorf("settings reset but cache invalidation failed: %w", err)
}
return defaults, nil
}
// --- private helpers ---
func (sm *SettingsManager) ensureSettingsInitialized() error {
defaults := defaultSystemSettings()
settingsToCreate := sm.buildSettingsFromDefaults(defaults)
for _, setting := range settingsToCreate {
var existing models.Setting
err := sm.db.Where("key = ?", setting.Key).First(&existing).Error
if err == gorm.ErrRecordNotFound {
sm.logger.Infof("Initializing new setting '%s'", setting.Key)
if createErr := sm.db.Create(&setting).Error; createErr != nil {
return fmt.Errorf("failed to create initial setting '%s': %w", setting.Key, createErr)
}
} else if err != nil {
return fmt.Errorf("failed to check for existing setting '%s': %w", setting.Key, err)
}
}
return nil
}
func (sm *SettingsManager) buildSettingsFromDefaults(defaults *models.SystemSettings) []models.Setting {
v := reflect.ValueOf(defaults).Elem()
t := v.Type()
var settings []models.Setting
for i := 0; i < t.NumField(); i++ {
field := t.Field(i)
fieldValue := v.Field(i)
key := field.Tag.Get("json")
if key == "" || key == "-" {
continue
}
var defaultValue string
kind := fieldValue.Kind()
if kind == reflect.Slice || kind == reflect.Map {
jsonBytes, _ := json.Marshal(fieldValue.Interface())
defaultValue = string(jsonBytes)
} else {
defaultValue = field.Tag.Get("default")
}
settings = append(settings, models.Setting{
Key: key,
Value: defaultValue,
Name: field.Tag.Get("name"),
Description: field.Tag.Get("desc"),
Category: field.Tag.Get("category"),
DefaultValue: field.Tag.Get("default"),
})
}
return settings
}
// [Fix] Use the blank identifier `_` to resolve the "unused parameter" warning.
func (sm *SettingsManager) convertToDBValue(_ string, value interface{}, fieldType reflect.Type) (string, error) {
kind := fieldType.Kind()
switch kind {
case reflect.Slice, reflect.Map:
jsonBytes, err := json.Marshal(value)
if err != nil {
return "", fmt.Errorf("failed to marshal to JSON: %w", err)
}
return string(jsonBytes), nil
case reflect.Bool:
b, ok := value.(bool)
if !ok {
return "", fmt.Errorf("expected bool, but got %T", value)
}
return strconv.FormatBool(b), nil
default:
return fmt.Sprintf("%v", value), nil
}
}
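A minimal sketch (not part of this diff) of the conversion paths convertToDBValue handles when called through UpdateSettings; the setting keys shown are illustrative, not actual keys from this codebase.
func applyExampleSettings(sm *SettingsManager) error {
	return sm.UpdateSettings(map[string]interface{}{
		"proxy_keys":       []string{"key-a", "key-b"}, // slice -> marshaled to a JSON string
		"enable_key_check": true,                       // bool  -> stored as "true"
		"max_retries":      3,                          // other kinds -> fmt.Sprintf("%v", ...)
	})
}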

View File

@@ -1,3 +1,4 @@
// Filename: internal/store/factory.go
package store
import (
@@ -11,7 +12,6 @@ import (
// NewStore creates a new store based on the application configuration.
func NewStore(cfg *config.Config, logger *logrus.Logger) (Store, error) {
// Check whether a Redis configuration is present
if cfg.Redis.DSN != "" {
opts, err := redis.ParseURL(cfg.Redis.DSN)
if err != nil {
@@ -20,10 +20,10 @@ func NewStore(cfg *config.Config, logger *logrus.Logger) (Store, error) {
client := redis.NewClient(opts)
if err := client.Ping(context.Background()).Err(); err != nil {
logger.WithError(err).Warnf("WARN: Failed to connect to Redis (%s), falling back to in-memory store. Error: %v", cfg.Redis.DSN, err)
return NewMemoryStore(logger), nil // on connection failure, also fall back to the in-memory store without returning an error
return NewMemoryStore(logger), nil
}
logger.Info("Successfully connected to Redis. Using Redis as store.")
return NewRedisStore(client), nil
return NewRedisStore(client, logger), nil
}
logger.Info("INFO: Redis DSN not configured, falling back to in-memory store.")
return NewMemoryStore(logger), nil
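A minimal sketch (not part of this diff) of the factory's contract as shown above: an empty Redis DSN, or an unreachable Redis, still yields a working in-memory Store rather than an error; only a malformed DSN fails.
func mustBuildStore(cfg *config.Config, logger *logrus.Logger) Store {
	s, err := NewStore(cfg, logger) // only malformed DSNs produce an error here
	if err != nil {
		logger.WithError(err).Fatal("failed to build store")
	}
	return s
}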

View File

@@ -1,17 +1,20 @@
// Filename: internal/store/memory_store.go (final revision after peer review)
// Filename: internal/store/memory_store.go
package store
import (
"context"
"fmt"
"math/rand"
"sort"
"strconv"
"sync"
"time"
"github.com/sirupsen/logrus"
)
// ensure memoryStore implements Store interface
var _ Store = (*memoryStore)(nil)
type memoryStoreItem struct {
@@ -32,7 +35,6 @@ type memoryStore struct {
items map[string]*memoryStoreItem
pubsub map[string][]chan *Message
mu sync.RWMutex
// [USER SUGGESTION APPLIED] Use a mutex-guarded random source to guarantee concurrency safety
rng *rand.Rand
rngMu sync.Mutex
logger *logrus.Entry
@@ -42,7 +44,6 @@ func NewMemoryStore(logger *logrus.Logger) Store {
store := &memoryStore{
items: make(map[string]*memoryStoreItem),
pubsub: make(map[string][]chan *Message),
// Create a new random source seeded with the current time
rng: rand.New(rand.NewSource(time.Now().UnixNano())),
logger: logger.WithField("component", "store.memory 🗱"),
}
@@ -50,13 +51,12 @@ func NewMemoryStore(logger *logrus.Logger) Store {
return store
}
// [USER SUGGESTION INCORPORATED] Fix #1: use now := time.Now() for an atomic expiry check
func (s *memoryStore) startGCollector() {
ticker := time.NewTicker(5 * time.Minute)
defer ticker.Stop()
for range ticker.C {
s.mu.Lock()
now := time.Now() // avoid calling time.Now() repeatedly inside the loop
now := time.Now()
for key, item := range s.items {
if !item.expireAt.IsZero() && now.After(item.expireAt) {
delete(s.items, key)
@@ -66,92 +66,10 @@ func (s *memoryStore) startGCollector() {
}
}
// [USER SUGGESTION INCORPORATED] Fix #2 & #3: fixed the fatal nil-check and type-assertion issues
func (s *memoryStore) PopAndCycleSetMember(mainKey, cooldownKey string) (string, error) {
s.mu.Lock()
defer s.mu.Unlock()
// --- All method signatures now take a context.Context parameter to match the interface ---
// --- The in-memory implementation can ignore it and receive it as _ ---
mainItem, mainOk := s.items[mainKey]
var mainSet map[string]struct{}
if mainOk && !mainItem.isExpired() {
// perform the type assertion safely
mainSet, mainOk = mainItem.value.(map[string]struct{})
// make sure the assertion succeeded and the set is non-empty
mainOk = mainOk && len(mainSet) > 0
} else {
mainOk = false
}
if !mainOk {
cooldownItem, cooldownOk := s.items[cooldownKey]
if !cooldownOk || cooldownItem.isExpired() {
return "", ErrNotFound
}
// perform the type assertion safely
cooldownSet, cooldownSetOk := cooldownItem.value.(map[string]struct{})
if !cooldownSetOk || len(cooldownSet) == 0 {
return "", ErrNotFound
}
s.items[mainKey] = cooldownItem
delete(s.items, cooldownKey)
mainSet = cooldownSet
}
var popped string
for k := range mainSet {
popped = k
break
}
delete(mainSet, popped)
cooldownItem, cooldownOk := s.items[cooldownKey]
if !cooldownOk || cooldownItem.isExpired() {
cooldownItem = &memoryStoreItem{value: make(map[string]struct{})}
s.items[cooldownKey] = cooldownItem
}
// handle the cooldown set safely
cooldownSet, ok := cooldownItem.value.(map[string]struct{})
if !ok {
cooldownSet = make(map[string]struct{})
cooldownItem.value = cooldownSet
}
cooldownSet[popped] = struct{}{}
return popped, nil
}
// SRandMember [concurrency fix] uses the mutex-guarded rng
func (s *memoryStore) SRandMember(key string) (string, error) {
s.mu.RLock()
defer s.mu.RUnlock()
item, ok := s.items[key]
if !ok || item.isExpired() {
return "", ErrNotFound
}
set, ok := item.value.(map[string]struct{})
if !ok || len(set) == 0 {
return "", ErrNotFound
}
members := make([]string, 0, len(set))
for member := range set {
members = append(members, member)
}
if len(members) == 0 {
return "", ErrNotFound
}
s.rngMu.Lock()
n := s.rng.Intn(len(members))
s.rngMu.Unlock()
return members[n], nil
}
// --- The remaining functions below are final versions; they all follow a safe, atomic locking strategy ---
func (s *memoryStore) Set(key string, value []byte, ttl time.Duration) error {
func (s *memoryStore) Set(_ context.Context, key string, value []byte, ttl time.Duration) error {
s.mu.Lock()
defer s.mu.Unlock()
var expireAt time.Time
@@ -162,7 +80,7 @@ func (s *memoryStore) Set(key string, value []byte, ttl time.Duration) error {
return nil
}
func (s *memoryStore) Get(key string) ([]byte, error) {
func (s *memoryStore) Get(_ context.Context, key string) ([]byte, error) {
s.mu.RLock()
defer s.mu.RUnlock()
item, ok := s.items[key]
@@ -175,7 +93,7 @@ func (s *memoryStore) Get(key string) ([]byte, error) {
return nil, ErrNotFound
}
func (s *memoryStore) Del(keys ...string) error {
func (s *memoryStore) Del(_ context.Context, keys ...string) error {
s.mu.Lock()
defer s.mu.Unlock()
for _, key := range keys {
@@ -184,14 +102,25 @@ func (s *memoryStore) Del(keys ...string) error {
return nil
}
func (s *memoryStore) Exists(key string) (bool, error) {
func (s *memoryStore) Exists(_ context.Context, key string) (bool, error) {
s.mu.RLock()
defer s.mu.RUnlock()
item, ok := s.items[key]
return ok && !item.isExpired(), nil
}
func (s *memoryStore) SetNX(key string, value []byte, ttl time.Duration) (bool, error) {
func (s *memoryStore) Expire(_ context.Context, key string, expiration time.Duration) error {
s.mu.Lock()
defer s.mu.Unlock()
item, ok := s.items[key]
if !ok {
return ErrNotFound
}
item.expireAt = time.Now().Add(expiration)
return nil
}
func (s *memoryStore) SetNX(_ context.Context, key string, value []byte, ttl time.Duration) (bool, error) {
s.mu.Lock()
defer s.mu.Unlock()
item, ok := s.items[key]
@@ -208,7 +137,7 @@ func (s *memoryStore) SetNX(key string, value []byte, ttl time.Duration) (bool,
func (s *memoryStore) Close() error { return nil }
func (s *memoryStore) HDel(key string, fields ...string) error {
func (s *memoryStore) HDel(_ context.Context, key string, fields ...string) error {
s.mu.Lock()
defer s.mu.Unlock()
item, ok := s.items[key]
@@ -223,7 +152,7 @@ func (s *memoryStore) HDel(key string, fields ...string) error {
return nil
}
func (s *memoryStore) HSet(key string, values map[string]any) error {
func (s *memoryStore) HSet(_ context.Context, key string, values map[string]any) error {
s.mu.Lock()
defer s.mu.Unlock()
item, ok := s.items[key]
@@ -242,7 +171,22 @@ func (s *memoryStore) HSet(key string, values map[string]any) error {
return nil
}
func (s *memoryStore) HGetAll(key string) (map[string]string, error) {
func (s *memoryStore) HGet(_ context.Context, key, field string) (string, error) {
s.mu.RLock()
defer s.mu.RUnlock()
item, ok := s.items[key]
if !ok || item.isExpired() {
return "", ErrNotFound
}
if hash, ok := item.value.(map[string]string); ok {
if value, exists := hash[field]; exists {
return value, nil
}
}
return "", ErrNotFound
}
func (s *memoryStore) HGetAll(_ context.Context, key string) (map[string]string, error) {
s.mu.RLock()
defer s.mu.RUnlock()
item, ok := s.items[key]
@@ -259,7 +203,7 @@ func (s *memoryStore) HGetAll(key string) (map[string]string, error) {
return make(map[string]string), nil
}
func (s *memoryStore) HIncrBy(key, field string, incr int64) (int64, error) {
func (s *memoryStore) HIncrBy(_ context.Context, key, field string, incr int64) (int64, error) {
s.mu.Lock()
defer s.mu.Unlock()
item, ok := s.items[key]
@@ -281,7 +225,7 @@ func (s *memoryStore) HIncrBy(key, field string, incr int64) (int64, error) {
return newVal, nil
}
func (s *memoryStore) LPush(key string, values ...any) error {
func (s *memoryStore) LPush(_ context.Context, key string, values ...any) error {
s.mu.Lock()
defer s.mu.Unlock()
item, ok := s.items[key]
@@ -301,7 +245,7 @@ func (s *memoryStore) LPush(key string, values ...any) error {
return nil
}
func (s *memoryStore) LRem(key string, count int64, value any) error {
func (s *memoryStore) LRem(_ context.Context, key string, count int64, value any) error {
s.mu.Lock()
defer s.mu.Unlock()
item, ok := s.items[key]
@@ -326,7 +270,7 @@ func (s *memoryStore) LRem(key string, count int64, value any) error {
return nil
}
func (s *memoryStore) SAdd(key string, members ...any) error {
func (s *memoryStore) SAdd(_ context.Context, key string, members ...any) error {
s.mu.Lock()
defer s.mu.Unlock()
item, ok := s.items[key]
@@ -345,7 +289,7 @@ func (s *memoryStore) SAdd(key string, members ...any) error {
return nil
}
func (s *memoryStore) SPopN(key string, count int64) ([]string, error) {
func (s *memoryStore) SPopN(_ context.Context, key string, count int64) ([]string, error) {
s.mu.Lock()
defer s.mu.Unlock()
item, ok := s.items[key]
@@ -375,7 +319,7 @@ func (s *memoryStore) SPopN(key string, count int64) ([]string, error) {
return popped, nil
}
func (s *memoryStore) SMembers(key string) ([]string, error) {
func (s *memoryStore) SMembers(_ context.Context, key string) ([]string, error) {
s.mu.RLock()
defer s.mu.RUnlock()
item, ok := s.items[key]
@@ -393,7 +337,7 @@ func (s *memoryStore) SMembers(key string) ([]string, error) {
return members, nil
}
func (s *memoryStore) SRem(key string, members ...any) error {
func (s *memoryStore) SRem(_ context.Context, key string, members ...any) error {
s.mu.Lock()
defer s.mu.Unlock()
item, ok := s.items[key]
@@ -410,7 +354,51 @@ func (s *memoryStore) SRem(key string, members ...any) error {
return nil
}
func (s *memoryStore) Rotate(key string) (string, error) {
func (s *memoryStore) SRandMember(_ context.Context, key string) (string, error) {
s.mu.RLock()
defer s.mu.RUnlock()
item, ok := s.items[key]
if !ok || item.isExpired() {
return "", ErrNotFound
}
set, ok := item.value.(map[string]struct{})
if !ok || len(set) == 0 {
return "", ErrNotFound
}
members := make([]string, 0, len(set))
for member := range set {
members = append(members, member)
}
if len(members) == 0 {
return "", ErrNotFound
}
s.rngMu.Lock()
n := s.rng.Intn(len(members))
s.rngMu.Unlock()
return members[n], nil
}
func (s *memoryStore) SUnionStore(_ context.Context, destination string, keys ...string) (int64, error) {
s.mu.Lock()
defer s.mu.Unlock()
unionSet := make(map[string]struct{})
for _, key := range keys {
item, ok := s.items[key]
if !ok || item.isExpired() {
continue
}
if set, ok := item.value.(map[string]struct{}); ok {
for member := range set {
unionSet[member] = struct{}{}
}
}
}
destItem := &memoryStoreItem{value: unionSet}
s.items[destination] = destItem
return int64(len(unionSet)), nil
}
func (s *memoryStore) Rotate(_ context.Context, key string) (string, error) {
s.mu.Lock()
defer s.mu.Unlock()
item, ok := s.items[key]
@@ -426,7 +414,7 @@ func (s *memoryStore) Rotate(key string) (string, error) {
return val, nil
}
func (s *memoryStore) LIndex(key string, index int64) (string, error) {
func (s *memoryStore) LIndex(_ context.Context, key string, index int64) (string, error) {
s.mu.RLock()
defer s.mu.RUnlock()
item, ok := s.items[key]
@@ -447,8 +435,17 @@ func (s *memoryStore) LIndex(key string, index int64) (string, error) {
return list[index], nil
}
// Zset methods... (ZAdd, ZRange, ZRem)
func (s *memoryStore) ZAdd(key string, members map[string]float64) error {
func (s *memoryStore) MSet(ctx context.Context, values map[string]any) error {
s.mu.Lock()
defer s.mu.Unlock()
for key, value := range values {
// The in-memory store has no independent per-key TTL here, so entries are assumed never to expire
s.items[key] = &memoryStoreItem{value: value, expireAt: time.Time{}}
}
return nil
}
func (s *memoryStore) ZAdd(_ context.Context, key string, members map[string]float64) error {
s.mu.Lock()
defer s.mu.Unlock()
item, ok := s.items[key]
@@ -471,8 +468,6 @@ func (s *memoryStore) ZAdd(key string, members map[string]float64) error {
for val, score := range membersMap {
newZSet = append(newZSet, zsetMember{Value: val, Score: score})
}
// NOTE: This ZSet implementation is simple but not performant for large sets.
// A production implementation would use a skip list or a balanced tree.
sort.Slice(newZSet, func(i, j int) bool {
if newZSet[i].Score == newZSet[j].Score {
return newZSet[i].Value < newZSet[j].Value
@@ -482,7 +477,7 @@ func (s *memoryStore) ZAdd(key string, members map[string]float64) error {
item.value = newZSet
return nil
}
func (s *memoryStore) ZRange(key string, start, stop int64) ([]string, error) {
func (s *memoryStore) ZRange(_ context.Context, key string, start, stop int64) ([]string, error) {
s.mu.RLock()
defer s.mu.RUnlock()
item, ok := s.items[key]
@@ -515,7 +510,7 @@ func (s *memoryStore) ZRange(key string, start, stop int64) ([]string, error) {
}
return result, nil
}
func (s *memoryStore) ZRem(key string, members ...any) error {
func (s *memoryStore) ZRem(_ context.Context, key string, members ...any) error {
s.mu.Lock()
defer s.mu.Unlock()
item, ok := s.items[key]
@@ -540,13 +535,56 @@ func (s *memoryStore) ZRem(key string, members ...any) error {
return nil
}
// Pipeline implementation
func (s *memoryStore) PopAndCycleSetMember(_ context.Context, mainKey, cooldownKey string) (string, error) {
s.mu.Lock()
defer s.mu.Unlock()
mainItem, mainOk := s.items[mainKey]
var mainSet map[string]struct{}
if mainOk && !mainItem.isExpired() {
mainSet, mainOk = mainItem.value.(map[string]struct{})
mainOk = mainOk && len(mainSet) > 0
} else {
mainOk = false
}
if !mainOk {
cooldownItem, cooldownOk := s.items[cooldownKey]
if !cooldownOk || cooldownItem.isExpired() {
return "", ErrNotFound
}
cooldownSet, cooldownSetOk := cooldownItem.value.(map[string]struct{})
if !cooldownSetOk || len(cooldownSet) == 0 {
return "", ErrNotFound
}
s.items[mainKey] = cooldownItem
delete(s.items, cooldownKey)
mainSet = cooldownSet
}
var popped string
for k := range mainSet {
popped = k
break
}
delete(mainSet, popped)
cooldownItem, cooldownOk := s.items[cooldownKey]
if !cooldownOk || cooldownItem.isExpired() {
cooldownItem = &memoryStoreItem{value: make(map[string]struct{})}
s.items[cooldownKey] = cooldownItem
}
cooldownSet, ok := cooldownItem.value.(map[string]struct{})
if !ok {
cooldownSet = make(map[string]struct{})
cooldownItem.value = cooldownSet
}
cooldownSet[popped] = struct{}{}
return popped, nil
}
type memoryPipeliner struct {
store *memoryStore
ops []func()
}
func (s *memoryStore) Pipeline() Pipeliner {
func (s *memoryStore) Pipeline(_ context.Context) Pipeliner {
return &memoryPipeliner{store: s}
}
func (p *memoryPipeliner) Exec() error {
@@ -559,7 +597,6 @@ func (p *memoryPipeliner) Exec() error {
}
func (p *memoryPipeliner) Expire(key string, expiration time.Duration) {
// [USER SUGGESTION APPLIED] Fix #4: Capture value, not reference
capturedKey := key
p.ops = append(p.ops, func() {
if item, ok := p.store.items[capturedKey]; ok {
@@ -576,6 +613,22 @@ func (p *memoryPipeliner) Del(keys ...string) {
}
})
}
func (p *memoryPipeliner) Set(key string, value []byte, expiration time.Duration) {
capturedKey := key
capturedValue := value
p.ops = append(p.ops, func() {
var expireAt time.Time
if expiration > 0 {
expireAt = time.Now().Add(expiration)
}
p.store.items[capturedKey] = &memoryStoreItem{
value: capturedValue,
expireAt: expireAt,
}
})
}
func (p *memoryPipeliner) SAdd(key string, members ...any) {
capturedKey := key
capturedMembers := make([]any, len(members))
@@ -615,7 +668,6 @@ func (p *memoryPipeliner) SRem(key string, members ...any) {
}
})
}
func (p *memoryPipeliner) LPush(key string, values ...any) {
capturedKey := key
capturedValues := make([]any, len(values))
@@ -637,11 +689,126 @@ func (p *memoryPipeliner) LPush(key string, values ...any) {
item.value = append(stringValues, list...)
})
}
func (p *memoryPipeliner) LRem(key string, count int64, value any) {}
func (p *memoryPipeliner) HSet(key string, values map[string]any) {}
func (p *memoryPipeliner) HIncrBy(key, field string, incr int64) {}
func (p *memoryPipeliner) LRem(key string, count int64, value any) {
capturedKey := key
capturedValue := fmt.Sprintf("%v", value)
p.ops = append(p.ops, func() {
item, ok := p.store.items[capturedKey]
if !ok || item.isExpired() {
return
}
list, ok := item.value.([]string)
if !ok {
return
}
newList := make([]string, 0, len(list))
removed := int64(0)
for _, v := range list {
if count != 0 && v == capturedValue && (count < 0 || removed < count) {
removed++
continue
}
newList = append(newList, v)
}
item.value = newList
})
}
func (p *memoryPipeliner) HSet(key string, values map[string]any) {
capturedKey := key
capturedValues := make(map[string]any, len(values))
for k, v := range values {
capturedValues[k] = v
}
p.ops = append(p.ops, func() {
item, ok := p.store.items[capturedKey]
if !ok || item.isExpired() {
item = &memoryStoreItem{value: make(map[string]string)}
p.store.items[capturedKey] = item
}
hash, ok := item.value.(map[string]string)
if !ok {
hash = make(map[string]string)
item.value = hash
}
for field, value := range capturedValues {
hash[field] = fmt.Sprintf("%v", value)
}
})
}
func (p *memoryPipeliner) HIncrBy(key, field string, incr int64) {
capturedKey := key
capturedField := field
p.ops = append(p.ops, func() {
item, ok := p.store.items[capturedKey]
if !ok || item.isExpired() {
item = &memoryStoreItem{value: make(map[string]string)}
p.store.items[capturedKey] = item
}
hash, ok := item.value.(map[string]string)
if !ok {
hash = make(map[string]string)
item.value = hash
}
current, _ := strconv.ParseInt(hash[capturedField], 10, 64)
hash[capturedField] = strconv.FormatInt(current+incr, 10)
})
}
func (p *memoryPipeliner) ZAdd(key string, members map[string]float64) {
capturedKey := key
capturedMembers := make(map[string]float64, len(members))
for k, v := range members {
capturedMembers[k] = v
}
p.ops = append(p.ops, func() {
item, ok := p.store.items[capturedKey]
if !ok || item.isExpired() {
item = &memoryStoreItem{value: make(map[string]float64)}
p.store.items[capturedKey] = item
}
zset, ok := item.value.(map[string]float64)
if !ok {
zset = make(map[string]float64)
item.value = zset
}
for member, score := range capturedMembers {
zset[member] = score
}
})
}
func (p *memoryPipeliner) ZRem(key string, members ...any) {
capturedKey := key
capturedMembers := make([]any, len(members))
copy(capturedMembers, members)
p.ops = append(p.ops, func() {
item, ok := p.store.items[capturedKey]
if !ok || item.isExpired() {
return
}
zset, ok := item.value.(map[string]float64)
if !ok {
return
}
for _, member := range capturedMembers {
delete(zset, fmt.Sprintf("%v", member))
}
})
}
func (p *memoryPipeliner) MSet(values map[string]any) {
capturedValues := make(map[string]any, len(values))
for k, v := range values {
capturedValues[k] = v
}
p.ops = append(p.ops, func() {
for key, value := range capturedValues {
p.store.items[key] = &memoryStoreItem{
value: value,
expireAt: time.Time{}, // pipelined MSet likewise assumes entries never expire
}
}
})
}
// --- Pub/Sub implementation (remains unchanged) ---
type memorySubscription struct {
store *memoryStore
channelName string
@@ -649,10 +816,11 @@ type memorySubscription struct {
}
func (ms *memorySubscription) Channel() <-chan *Message { return ms.msgChan }
func (ms *memorySubscription) ChannelName() string { return ms.channelName }
func (ms *memorySubscription) Close() error {
return ms.store.removeSubscriber(ms.channelName, ms.msgChan)
}
func (s *memoryStore) Publish(channel string, message []byte) error {
func (s *memoryStore) Publish(_ context.Context, channel string, message []byte) error {
s.mu.RLock()
defer s.mu.RUnlock()
subscribers, ok := s.pubsub[channel]
@@ -669,7 +837,7 @@ func (s *memoryStore) Publish(channel string, message []byte) error {
}
return nil
}
func (s *memoryStore) Subscribe(channel string) (Subscription, error) {
func (s *memoryStore) Subscribe(_ context.Context, channel string) (Subscription, error) {
s.mu.Lock()
defer s.mu.Unlock()
msgChan := make(chan *Message, 10)
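A minimal sketch (not part of this diff) of using the pipeline against this in-memory implementation: each call queues a closure that captures its arguments by value, and nothing is applied until Exec. The key name is illustrative.
func pipelineExample(ctx context.Context, s Store) error {
	p := s.Pipeline(ctx)
	p.HSet("stats:group:1", map[string]any{"total_keys": 10})
	p.HIncrBy("stats:group:1", "active_keys", 1)
	p.Expire("stats:group:1", time.Hour)
	return p.Exec() // queued operations run here
}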

View File

@@ -1,3 +1,5 @@
// Filename: internal/store/redis_store.go
package store
import (
@@ -8,22 +10,20 @@ import (
"time"
"github.com/redis/go-redis/v9"
"github.com/sirupsen/logrus"
)
// ensure RedisStore implements Store interface
var _ Store = (*RedisStore)(nil)
// RedisStore is a Redis-backed key-value store.
type RedisStore struct {
client *redis.Client
popAndCycleScript *redis.Script
logger *logrus.Entry
}
// NewRedisStore creates a new RedisStore instance.
func NewRedisStore(client *redis.Client) Store {
// Lua script for atomic pop-and-cycle operation.
// KEYS[1]: main set key
// KEYS[2]: cooldown set key
func NewRedisStore(client *redis.Client, logger *logrus.Logger) Store {
const script = `
if redis.call('SCARD', KEYS[1]) == 0 then
if redis.call('SCARD', KEYS[2]) == 0 then
@@ -36,15 +36,16 @@ func NewRedisStore(client *redis.Client) Store {
return &RedisStore{
client: client,
popAndCycleScript: redis.NewScript(script),
logger: logger.WithField("component", "store.redis 🗄️"),
}
}
func (s *RedisStore) Set(key string, value []byte, ttl time.Duration) error {
return s.client.Set(context.Background(), key, value, ttl).Err()
func (s *RedisStore) Set(ctx context.Context, key string, value []byte, ttl time.Duration) error {
return s.client.Set(ctx, key, value, ttl).Err()
}
func (s *RedisStore) Get(key string) ([]byte, error) {
val, err := s.client.Get(context.Background(), key).Bytes()
func (s *RedisStore) Get(ctx context.Context, key string) ([]byte, error) {
val, err := s.client.Get(ctx, key).Bytes()
if err != nil {
if errors.Is(err, redis.Nil) {
return nil, ErrNotFound
@@ -54,53 +55,67 @@ func (s *RedisStore) Get(key string) ([]byte, error) {
return val, nil
}
func (s *RedisStore) Del(keys ...string) error {
func (s *RedisStore) Del(ctx context.Context, keys ...string) error {
if len(keys) == 0 {
return nil
}
return s.client.Del(context.Background(), keys...).Err()
return s.client.Del(ctx, keys...).Err()
}
func (s *RedisStore) Exists(key string) (bool, error) {
val, err := s.client.Exists(context.Background(), key).Result()
func (s *RedisStore) Exists(ctx context.Context, key string) (bool, error) {
val, err := s.client.Exists(ctx, key).Result()
return val > 0, err
}
func (s *RedisStore) SetNX(key string, value []byte, ttl time.Duration) (bool, error) {
return s.client.SetNX(context.Background(), key, value, ttl).Result()
func (s *RedisStore) SetNX(ctx context.Context, key string, value []byte, ttl time.Duration) (bool, error) {
return s.client.SetNX(ctx, key, value, ttl).Result()
}
func (s *RedisStore) Close() error {
return s.client.Close()
}
func (s *RedisStore) HSet(key string, values map[string]any) error {
return s.client.HSet(context.Background(), key, values).Err()
func (s *RedisStore) Expire(ctx context.Context, key string, expiration time.Duration) error {
return s.client.Expire(ctx, key, expiration).Err()
}
func (s *RedisStore) HGetAll(key string) (map[string]string, error) {
return s.client.HGetAll(context.Background(), key).Result()
func (s *RedisStore) HSet(ctx context.Context, key string, values map[string]any) error {
return s.client.HSet(ctx, key, values).Err()
}
func (s *RedisStore) HIncrBy(key, field string, incr int64) (int64, error) {
return s.client.HIncrBy(context.Background(), key, field, incr).Result()
func (s *RedisStore) HGet(ctx context.Context, key, field string) (string, error) {
val, err := s.client.HGet(ctx, key, field).Result()
if err != nil {
if errors.Is(err, redis.Nil) {
return "", ErrNotFound
}
return "", err
}
return val, nil
}
func (s *RedisStore) HDel(key string, fields ...string) error {
func (s *RedisStore) HGetAll(ctx context.Context, key string) (map[string]string, error) {
return s.client.HGetAll(ctx, key).Result()
}
func (s *RedisStore) HIncrBy(ctx context.Context, key, field string, incr int64) (int64, error) {
return s.client.HIncrBy(ctx, key, field, incr).Result()
}
func (s *RedisStore) HDel(ctx context.Context, key string, fields ...string) error {
if len(fields) == 0 {
return nil
}
return s.client.HDel(context.Background(), key, fields...).Err()
return s.client.HDel(ctx, key, fields...).Err()
}
func (s *RedisStore) LPush(key string, values ...any) error {
return s.client.LPush(context.Background(), key, values...).Err()
func (s *RedisStore) LPush(ctx context.Context, key string, values ...any) error {
return s.client.LPush(ctx, key, values...).Err()
}
func (s *RedisStore) LRem(key string, count int64, value any) error {
return s.client.LRem(context.Background(), key, count, value).Err()
func (s *RedisStore) LRem(ctx context.Context, key string, count int64, value any) error {
return s.client.LRem(ctx, key, count, value).Err()
}
func (s *RedisStore) Rotate(key string) (string, error) {
val, err := s.client.RPopLPush(context.Background(), key, key).Result()
func (s *RedisStore) Rotate(ctx context.Context, key string) (string, error) {
val, err := s.client.RPopLPush(ctx, key, key).Result()
if err != nil {
if errors.Is(err, redis.Nil) {
return "", ErrNotFound
@@ -110,29 +125,40 @@ func (s *RedisStore) Rotate(key string) (string, error) {
return val, nil
}
func (s *RedisStore) SAdd(key string, members ...any) error {
return s.client.SAdd(context.Background(), key, members...).Err()
func (s *RedisStore) MSet(ctx context.Context, values map[string]any) error {
if len(values) == 0 {
return nil
}
// The Redis MSET command expects a flat slice in the form [key1, value1, key2, value2, ...]
pairs := make([]interface{}, 0, len(values)*2)
for k, v := range values {
pairs = append(pairs, k, v)
}
return s.client.MSet(ctx, pairs...).Err()
}
func (s *RedisStore) SPopN(key string, count int64) ([]string, error) {
return s.client.SPopN(context.Background(), key, count).Result()
func (s *RedisStore) SAdd(ctx context.Context, key string, members ...any) error {
return s.client.SAdd(ctx, key, members...).Err()
}
func (s *RedisStore) SMembers(key string) ([]string, error) {
return s.client.SMembers(context.Background(), key).Result()
func (s *RedisStore) SPopN(ctx context.Context, key string, count int64) ([]string, error) {
return s.client.SPopN(ctx, key, count).Result()
}
func (s *RedisStore) SRem(key string, members ...any) error {
func (s *RedisStore) SMembers(ctx context.Context, key string) ([]string, error) {
return s.client.SMembers(ctx, key).Result()
}
func (s *RedisStore) SRem(ctx context.Context, key string, members ...any) error {
if len(members) == 0 {
return nil
}
return s.client.SRem(context.Background(), key, members...).Err()
return s.client.SRem(ctx, key, members...).Err()
}
func (s *RedisStore) SRandMember(key string) (string, error) {
member, err := s.client.SRandMember(context.Background(), key).Result()
func (s *RedisStore) SRandMember(ctx context.Context, key string) (string, error) {
member, err := s.client.SRandMember(ctx, key).Result()
if err != nil {
if errors.Is(err, redis.Nil) {
return "", ErrNotFound
}
@@ -141,81 +167,50 @@ func (s *RedisStore) SRandMember(key string) (string, error) {
return member, nil
}
// === new method implementations ===
func (s *RedisStore) SUnionStore(ctx context.Context, destination string, keys ...string) (int64, error) {
if len(keys) == 0 {
return 0, nil
}
return s.client.SUnionStore(ctx, destination, keys...).Result()
}
func (s *RedisStore) ZAdd(key string, members map[string]float64) error {
func (s *RedisStore) ZAdd(ctx context.Context, key string, members map[string]float64) error {
if len(members) == 0 {
return nil
}
redisMembers := make([]redis.Z, 0, len(members))
redisMembers := make([]redis.Z, len(members))
i := 0
for member, score := range members {
redisMembers = append(redisMembers, redis.Z{Score: score, Member: member})
redisMembers[i] = redis.Z{Score: score, Member: member}
i++
}
return s.client.ZAdd(context.Background(), key, redisMembers...).Err()
return s.client.ZAdd(ctx, key, redisMembers...).Err()
}
func (s *RedisStore) ZRange(key string, start, stop int64) ([]string, error) {
return s.client.ZRange(context.Background(), key, start, stop).Result()
func (s *RedisStore) ZRange(ctx context.Context, key string, start, stop int64) ([]string, error) {
return s.client.ZRange(ctx, key, start, stop).Result()
}
func (s *RedisStore) ZRem(key string, members ...any) error {
func (s *RedisStore) ZRem(ctx context.Context, key string, members ...any) error {
if len(members) == 0 {
return nil
}
return s.client.ZRem(context.Background(), key, members...).Err()
return s.client.ZRem(ctx, key, members...).Err()
}
func (s *RedisStore) PopAndCycleSetMember(mainKey, cooldownKey string) (string, error) {
val, err := s.popAndCycleScript.Run(context.Background(), s.client, []string{mainKey, cooldownKey}).Result()
func (s *RedisStore) PopAndCycleSetMember(ctx context.Context, mainKey, cooldownKey string) (string, error) {
val, err := s.popAndCycleScript.Run(ctx, s.client, []string{mainKey, cooldownKey}).Result()
if err != nil {
if errors.Is(err, redis.Nil) {
return "", ErrNotFound
}
return "", err
}
// Lua script returns a string, so we need to type assert
if str, ok := val.(string); ok {
return str, nil
}
return "", ErrNotFound // This happens if both sets were empty and the script returned nil
return "", ErrNotFound
}
type redisPipeliner struct{ pipe redis.Pipeliner }
func (p *redisPipeliner) HSet(key string, values map[string]any) {
p.pipe.HSet(context.Background(), key, values)
}
func (p *redisPipeliner) HIncrBy(key, field string, incr int64) {
p.pipe.HIncrBy(context.Background(), key, field, incr)
}
func (p *redisPipeliner) Exec() error {
_, err := p.pipe.Exec(context.Background())
return err
}
func (p *redisPipeliner) Del(keys ...string) {
if len(keys) > 0 {
p.pipe.Del(context.Background(), keys...)
}
}
func (p *redisPipeliner) SAdd(key string, members ...any) {
p.pipe.SAdd(context.Background(), key, members...)
}
func (p *redisPipeliner) SRem(key string, members ...any) {
if len(members) > 0 {
p.pipe.SRem(context.Background(), key, members...)
}
}
func (p *redisPipeliner) LPush(key string, values ...any) {
p.pipe.LPush(context.Background(), key, values...)
}
func (p *redisPipeliner) LRem(key string, count int64, value any) {
p.pipe.LRem(context.Background(), key, count, value)
}
func (s *RedisStore) LIndex(key string, index int64) (string, error) {
val, err := s.client.LIndex(context.Background(), key, index).Result()
func (s *RedisStore) LIndex(ctx context.Context, key string, index int64) (string, error) {
val, err := s.client.LIndex(ctx, key, index).Result()
if err != nil {
if errors.Is(err, redis.Nil) {
return "", ErrNotFound
@@ -225,47 +220,131 @@ func (s *RedisStore) LIndex(key string, index int64) (string, error) {
return val, nil
}
func (p *redisPipeliner) Expire(key string, expiration time.Duration) {
p.pipe.Expire(context.Background(), key, expiration)
type redisPipeliner struct {
pipe redis.Pipeliner
ctx context.Context
}
func (s *RedisStore) Pipeline() Pipeliner {
return &redisPipeliner{pipe: s.client.Pipeline()}
func (s *RedisStore) Pipeline(ctx context.Context) Pipeliner {
return &redisPipeliner{
pipe: s.client.Pipeline(),
ctx: ctx,
}
}
func (p *redisPipeliner) Exec() error {
_, err := p.pipe.Exec(p.ctx)
return err
}
func (p *redisPipeliner) Del(keys ...string) { p.pipe.Del(p.ctx, keys...) }
func (p *redisPipeliner) Expire(key string, expiration time.Duration) {
p.pipe.Expire(p.ctx, key, expiration)
}
func (p *redisPipeliner) HSet(key string, values map[string]any) { p.pipe.HSet(p.ctx, key, values) }
func (p *redisPipeliner) HIncrBy(key, field string, incr int64) {
p.pipe.HIncrBy(p.ctx, key, field, incr)
}
func (p *redisPipeliner) LPush(key string, values ...any) { p.pipe.LPush(p.ctx, key, values...) }
func (p *redisPipeliner) LRem(key string, count int64, value any) {
p.pipe.LRem(p.ctx, key, count, value)
}
func (p *redisPipeliner) Set(key string, value []byte, expiration time.Duration) {
p.pipe.Set(p.ctx, key, value, expiration)
}
func (p *redisPipeliner) MSet(values map[string]any) {
if len(values) == 0 {
return
}
p.pipe.MSet(p.ctx, values)
}
func (p *redisPipeliner) SAdd(key string, members ...any) { p.pipe.SAdd(p.ctx, key, members...) }
func (p *redisPipeliner) SRem(key string, members ...any) { p.pipe.SRem(p.ctx, key, members...) }
func (p *redisPipeliner) ZAdd(key string, members map[string]float64) {
if len(members) == 0 {
return
}
redisMembers := make([]redis.Z, len(members))
i := 0
for member, score := range members {
redisMembers[i] = redis.Z{Score: score, Member: member}
i++
}
p.pipe.ZAdd(p.ctx, key, redisMembers...)
}
func (p *redisPipeliner) ZRem(key string, members ...any) { p.pipe.ZRem(p.ctx, key, members...) }
type redisSubscription struct {
pubsub *redis.PubSub
msgChan chan *Message
once sync.Once
pubsub *redis.PubSub
msgChan chan *Message
logger *logrus.Entry
wg sync.WaitGroup
close context.CancelFunc
channelName string
}
func (s *RedisStore) Subscribe(ctx context.Context, channel string) (Subscription, error) {
pubsub := s.client.Subscribe(ctx, channel)
_, err := pubsub.Receive(ctx)
if err != nil {
_ = pubsub.Close()
return nil, fmt.Errorf("failed to subscribe to channel %s: %w", channel, err)
}
subCtx, cancel := context.WithCancel(context.Background())
sub := &redisSubscription{
pubsub: pubsub,
msgChan: make(chan *Message, 10),
logger: s.logger,
close: cancel,
channelName: channel,
}
sub.wg.Add(1)
go sub.bridge(subCtx)
return sub, nil
}
func (rs *redisSubscription) bridge(ctx context.Context) {
defer rs.wg.Done()
defer close(rs.msgChan)
redisCh := rs.pubsub.Channel()
for {
select {
case <-ctx.Done():
return
case redisMsg, ok := <-redisCh:
if !ok {
return
}
msg := &Message{
Channel: redisMsg.Channel,
Payload: []byte(redisMsg.Payload),
}
select {
case rs.msgChan <- msg:
default:
rs.logger.Warnf("Message dropped for channel '%s' due to slow consumer.", rs.channelName)
}
}
}
}
func (rs *redisSubscription) Channel() <-chan *Message {
rs.once.Do(func() {
rs.msgChan = make(chan *Message)
go func() {
defer close(rs.msgChan)
for redisMsg := range rs.pubsub.Channel() {
rs.msgChan <- &Message{
Channel: redisMsg.Channel,
Payload: []byte(redisMsg.Payload),
}
}
}()
})
return rs.msgChan
}
func (rs *redisSubscription) Close() error { return rs.pubsub.Close() }
func (s *RedisStore) Publish(channel string, message []byte) error {
return s.client.Publish(context.Background(), channel, message).Err()
func (rs *redisSubscription) ChannelName() string {
return rs.channelName
}
func (s *RedisStore) Subscribe(channel string) (Subscription, error) {
pubsub := s.client.Subscribe(context.Background(), channel)
_, err := pubsub.Receive(context.Background())
if err != nil {
return nil, fmt.Errorf("failed to subscribe to channel %s: %w", channel, err)
}
return &redisSubscription{pubsub: pubsub}, nil
func (rs *redisSubscription) Close() error {
rs.close()
err := rs.pubsub.Close()
rs.wg.Wait()
return err
}
func (s *RedisStore) Publish(ctx context.Context, channel string, message []byte) error {
return s.client.Publish(ctx, channel, message).Err()
}
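A minimal sketch (not part of this diff) of consuming a subscription; because the bridge goroutine drops messages when the consumer lags, the loop should drain promptly. The channel name is the settings-update channel defined earlier in this diff.
func consumeSettingsUpdates(ctx context.Context, s Store) error {
	sub, err := s.Subscribe(ctx, "system_settings:updated")
	if err != nil {
		return err
	}
	defer sub.Close()
	for msg := range sub.Channel() {
		fmt.Printf("reload signal %q on %s\n", msg.Payload, msg.Channel)
	}
	return nil
}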

View File

@@ -1,6 +1,9 @@
// Filename: internal/store/store.go
package store
import (
"context"
"errors"
"time"
)
@@ -17,6 +20,7 @@ type Message struct {
// Subscription represents an active subscription to a pub/sub channel.
type Subscription interface {
Channel() <-chan *Message
ChannelName() string
Close() error
}
@@ -31,6 +35,8 @@ type Pipeliner interface {
HIncrBy(key, field string, incr int64)
// SET
MSet(values map[string]any)
Set(key string, value []byte, expiration time.Duration)
SAdd(key string, members ...any)
SRem(key string, members ...any)
@@ -38,6 +44,10 @@ type Pipeliner interface {
LPush(key string, values ...any)
LRem(key string, count int64, value any)
// ZSET
ZAdd(key string, members map[string]float64)
ZRem(key string, members ...any)
// Execution
Exec() error
}
@@ -45,44 +55,48 @@ type Pipeliner interface {
// Store is the master interface for our cache service.
type Store interface {
// Basic K/V operations
Set(key string, value []byte, ttl time.Duration) error
Get(key string) ([]byte, error)
Del(keys ...string) error
Exists(key string) (bool, error)
SetNX(key string, value []byte, ttl time.Duration) (bool, error)
Set(ctx context.Context, key string, value []byte, ttl time.Duration) error
Get(ctx context.Context, key string) ([]byte, error)
Del(ctx context.Context, keys ...string) error
Exists(ctx context.Context, key string) (bool, error)
SetNX(ctx context.Context, key string, value []byte, ttl time.Duration) (bool, error)
MSet(ctx context.Context, values map[string]any) error
// HASH operations
HSet(key string, values map[string]any) error
HGetAll(key string) (map[string]string, error)
HIncrBy(key, field string, incr int64) (int64, error)
HDel(key string, fields ...string) error // [new]
HSet(ctx context.Context, key string, values map[string]any) error
HGet(ctx context.Context, key, field string) (string, error)
HGetAll(ctx context.Context, key string) (map[string]string, error)
HIncrBy(ctx context.Context, key, field string, incr int64) (int64, error)
HDel(ctx context.Context, key string, fields ...string) error
// LIST operations
LPush(key string, values ...any) error
LRem(key string, count int64, value any) error
Rotate(key string) (string, error)
LIndex(key string, index int64) (string, error)
LPush(ctx context.Context, key string, values ...any) error
LRem(ctx context.Context, key string, count int64, value any) error
Rotate(ctx context.Context, key string) (string, error)
LIndex(ctx context.Context, key string, index int64) (string, error)
Expire(ctx context.Context, key string, expiration time.Duration) error
// SET operations
SAdd(key string, members ...any) error
SPopN(key string, count int64) ([]string, error)
SMembers(key string) ([]string, error)
SRem(key string, members ...any) error
SRandMember(key string) (string, error)
SAdd(ctx context.Context, key string, members ...any) error
SPopN(ctx context.Context, key string, count int64) ([]string, error)
SMembers(ctx context.Context, key string) ([]string, error)
SRem(ctx context.Context, key string, members ...any) error
SRandMember(ctx context.Context, key string) (string, error)
SUnionStore(ctx context.Context, destination string, keys ...string) (int64, error)
// Pub/Sub operations
Publish(channel string, message []byte) error
Subscribe(channel string) (Subscription, error)
Publish(ctx context.Context, channel string, message []byte) error
Subscribe(ctx context.Context, channel string) (Subscription, error)
// Pipeline (optional) - implemented for Redis; the in-memory version does not implement it yet
Pipeline() Pipeliner
// Pipeline
Pipeline(ctx context.Context) Pipeliner
// Close closes the store and releases any underlying resources.
Close() error
// === new methods supporting the rotation strategies ===
ZAdd(key string, members map[string]float64) error
ZRange(key string, start, stop int64) ([]string, error)
ZRem(key string, members ...any) error
PopAndCycleSetMember(mainKey, cooldownKey string) (string, error)
// ZSET operations
ZAdd(ctx context.Context, key string, members map[string]float64) error
ZRange(ctx context.Context, key string, start, stop int64) ([]string, error)
ZRem(ctx context.Context, key string, members ...any) error
PopAndCycleSetMember(ctx context.Context, mainKey, cooldownKey string) (string, error)
}
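A minimal sketch (not part of this diff) of the intended calling pattern for the context-aware interface: a request-scoped context bounds each store call, so cancelling the request also cancels the underlying Redis round-trip. Key and TTL are illustrative.
func cacheResponse(ctx context.Context, s Store, key string, body []byte) error {
	ctx, cancel := context.WithTimeout(ctx, 2*time.Second)
	defer cancel()
	return s.Set(ctx, key, body, 5*time.Minute)
}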

View File

@@ -1,6 +1,7 @@
package syncer
import (
"context"
"fmt"
"gemini-balancer/internal/store"
"log"
@@ -51,7 +52,7 @@ func (s *CacheSyncer[T]) Get() T {
func (s *CacheSyncer[T]) Invalidate() error {
log.Printf("INFO: Publishing invalidation notification on channel '%s'", s.channelName)
return s.store.Publish(s.channelName, []byte("reload"))
return s.store.Publish(context.Background(), s.channelName, []byte("reload"))
}
func (s *CacheSyncer[T]) Stop() {
@@ -84,7 +85,7 @@ func (s *CacheSyncer[T]) listenForUpdates() {
default:
}
subscription, err := s.store.Subscribe(s.channelName)
subscription, err := s.store.Subscribe(context.Background(), s.channelName)
if err != nil {
log.Printf("ERROR: Failed to subscribe to '%s', retrying in 5s: %v", s.channelName, err)
time.Sleep(5 * time.Second)

View File

@@ -1,7 +1,8 @@
// Filename: internal/task/task.go (final calibrated revision)
// Filename: internal/task/task.go
package task
import (
"context"
"encoding/json"
"errors"
"fmt"
@@ -15,15 +16,13 @@ const (
ResultTTL = 60 * time.Minute
)
// Reporter defines how domain code interacts with the task service.
type Reporter interface {
StartTask(keyGroupID uint, taskType, resourceID string, total int, timeout time.Duration) (*Status, error)
EndTaskByID(taskID, resourceID string, result any, taskErr error)
UpdateProgressByID(taskID string, processed int) error
UpdateTotalByID(taskID string, total int) error
StartTask(ctx context.Context, keyGroupID uint, taskType, resourceID string, total int, timeout time.Duration) (*Status, error)
EndTaskByID(ctx context.Context, taskID, resourceID string, result any, taskErr error)
UpdateProgressByID(ctx context.Context, taskID string, processed int) error
UpdateTotalByID(ctx context.Context, taskID string, total int) error
}
// Status represents the complete state of a background task
type Status struct {
ID string `json:"id"`
TaskType string `json:"task_type"`
@@ -38,13 +37,11 @@ type Status struct {
DurationSeconds float64 `json:"duration_seconds,omitempty"`
}
// Task is the core task-management service
type Task struct {
store store.Store
logger *logrus.Entry
}
// NewTask is the constructor for Task
func NewTask(store store.Store, logger *logrus.Logger) *Task {
return &Task{
store: store,
@@ -62,15 +59,14 @@ func (s *Task) getTaskDataKey(taskID string) string {
return fmt.Sprintf("task:data:%s", taskID)
}
// --- new helper function for building the atomic "running" flag key ---
func (s *Task) getIsRunningFlagKey(taskID string) string {
return fmt.Sprintf("task:running:%s", taskID)
}
func (s *Task) StartTask(keyGroupID uint, taskType, resourceID string, total int, timeout time.Duration) (*Status, error) {
func (s *Task) StartTask(ctx context.Context, keyGroupID uint, taskType, resourceID string, total int, timeout time.Duration) (*Status, error) {
lockKey := s.getResourceLockKey(resourceID)
if existingTaskID, err := s.store.Get(lockKey); err == nil && len(existingTaskID) > 0 {
if existingTaskID, err := s.store.Get(ctx, lockKey); err == nil && len(existingTaskID) > 0 {
return nil, fmt.Errorf("a task is already running for this resource (ID: %s)", string(existingTaskID))
}
@@ -94,35 +90,34 @@ func (s *Task) StartTask(keyGroupID uint, taskType, resourceID string, total int
timeout = ResultTTL * 24
}
if err := s.store.Set(lockKey, []byte(taskID), timeout); err != nil {
if err := s.store.Set(ctx, lockKey, []byte(taskID), timeout); err != nil {
return nil, fmt.Errorf("failed to acquire task resource lock: %w", err)
}
if err := s.store.Set(taskKey, statusBytes, timeout); err != nil {
_ = s.store.Del(lockKey)
if err := s.store.Set(ctx, taskKey, statusBytes, timeout); err != nil {
_ = s.store.Del(ctx, lockKey)
return nil, fmt.Errorf("failed to set new task data in store: %w", err)
}
// Create a standalone "running" flag; its presence or absence is checked atomically
if err := s.store.Set(runningFlagKey, []byte("1"), timeout); err != nil {
_ = s.store.Del(lockKey)
_ = s.store.Del(taskKey)
if err := s.store.Set(ctx, runningFlagKey, []byte("1"), timeout); err != nil {
_ = s.store.Del(ctx, lockKey)
_ = s.store.Del(ctx, taskKey)
return nil, fmt.Errorf("failed to set task running flag: %w", err)
}
return status, nil
}
func (s *Task) EndTaskByID(taskID, resourceID string, resultData any, taskErr error) {
func (s *Task) EndTaskByID(ctx context.Context, taskID, resourceID string, resultData any, taskErr error) {
lockKey := s.getResourceLockKey(resourceID)
defer func() {
if err := s.store.Del(lockKey); err != nil {
if err := s.store.Del(ctx, lockKey); err != nil {
s.logger.WithError(err).Warnf("Failed to release resource lock '%s' for task %s.", lockKey, taskID)
}
}()
runningFlagKey := s.getIsRunningFlagKey(taskID)
_ = s.store.Del(runningFlagKey)
status, err := s.GetStatus(taskID)
if err != nil {
_ = s.store.Del(ctx, runningFlagKey)
status, err := s.GetStatus(ctx, taskID)
if err != nil {
s.logger.WithError(err).Errorf("Could not get task status for task ID %s during EndTask. Lock has been released, but task data may be stale.", taskID)
return
}
@@ -141,15 +136,14 @@ func (s *Task) EndTaskByID(taskID, resourceID string, resultData any, taskErr er
}
updatedTaskBytes, _ := json.Marshal(status)
taskKey := s.getTaskDataKey(taskID)
if err := s.store.Set(taskKey, updatedTaskBytes, ResultTTL); err != nil {
if err := s.store.Set(ctx, taskKey, updatedTaskBytes, ResultTTL); err != nil {
s.logger.WithError(err).Errorf("Failed to save final status for task %s.", taskID)
}
}
// GetStatus fetches a task's status by ID for external callers such as API handlers
func (s *Task) GetStatus(taskID string) (*Status, error) {
func (s *Task) GetStatus(ctx context.Context, taskID string) (*Status, error) {
taskKey := s.getTaskDataKey(taskID)
statusBytes, err := s.store.Get(taskKey)
statusBytes, err := s.store.Get(ctx, taskKey)
if err != nil {
if errors.Is(err, store.ErrNotFound) {
return nil, errors.New("task not found")
@@ -161,22 +155,18 @@ func (s *Task) GetStatus(taskID string) (*Status, error) {
if err := json.Unmarshal(statusBytes, &status); err != nil {
return nil, fmt.Errorf("corrupted task data in store for ID %s", taskID)
}
if !status.IsRunning && status.FinishedAt != nil {
status.DurationSeconds = status.FinishedAt.Sub(status.StartedAt).Seconds()
}
return &status, nil
}
// UpdateProgressByID updates a task's progress by ID
func (s *Task) updateTask(taskID string, updater func(status *Status)) error {
func (s *Task) updateTask(ctx context.Context, taskID string, updater func(status *Status)) error {
runningFlagKey := s.getIsRunningFlagKey(taskID)
if _, err := s.store.Get(runningFlagKey); err != nil {
// The task has already finished; returning silently is the expected behavior.
if _, err := s.store.Get(ctx, runningFlagKey); err != nil {
return nil
}
status, err := s.GetStatus(taskID)
status, err := s.GetStatus(ctx, taskID)
if err != nil {
s.logger.WithError(err).Warnf("Failed to get task status for update on task %s. Update will not be saved.", taskID)
return nil
@@ -184,7 +174,6 @@ func (s *Task) updateTask(taskID string, updater func(status *Status)) error {
if !status.IsRunning {
return nil
}
// Apply the supplied updater function to the status
updater(status)
statusBytes, marshalErr := json.Marshal(status)
if marshalErr != nil {
@@ -192,23 +181,20 @@ func (s *Task) updateTask(taskID string, updater func(status *Status)) error {
return nil
}
taskKey := s.getTaskDataKey(taskID)
// Use a longer TTL so a running task does not expire prematurely
if err := s.store.Set(taskKey, statusBytes, ResultTTL*24); err != nil {
if err := s.store.Set(ctx, taskKey, statusBytes, ResultTTL*24); err != nil {
s.logger.WithError(err).Warnf("Failed to save update for task %s.", taskID)
}
return nil
}
// [REFACTORED] UpdateProgressByID is now a thin wrapper around the generic updater.
func (s *Task) UpdateProgressByID(taskID string, processed int) error {
return s.updateTask(taskID, func(status *Status) {
func (s *Task) UpdateProgressByID(ctx context.Context, taskID string, processed int) error {
return s.updateTask(ctx, taskID, func(status *Status) {
status.Processed = processed
})
}
// [REFACTORED] UpdateTotalByID is likewise a thin wrapper around the generic updater.
func (s *Task) UpdateTotalByID(taskID string, total int) error {
return s.updateTask(taskID, func(status *Status) {
func (s *Task) UpdateTotalByID(ctx context.Context, taskID string, total int) error {
return s.updateTask(ctx, taskID, func(status *Status) {
status.Total = total
})
}
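
As a companion to the refactored API above, the following is a hedged usage sketch (not from this repository) of how a worker might drive the context-aware task service: start a task, report progress with the request-scoped context, and always end the task so the resource lock is released. The status and reporter types below are simplified stand-ins that mirror only a subset of the Reporter signatures; the "validate_keys" task type, the "group:%d" resource-ID format, and noopReporter are hypothetical and exist only so the sketch compiles and runs.

package main

import (
	"context"
	"fmt"
	"time"
)

// status is a cut-down stand-in for task.Status; only the ID field is needed here.
type status struct{ ID string }

// reporter mirrors a subset of the context-aware task.Reporter signatures above.
type reporter interface {
	StartTask(ctx context.Context, keyGroupID uint, taskType, resourceID string, total int, timeout time.Duration) (*status, error)
	UpdateProgressByID(ctx context.Context, taskID string, processed int) error
	EndTaskByID(ctx context.Context, taskID, resourceID string, result any, taskErr error)
}

// noopReporter exists only so this sketch compiles and runs on its own.
type noopReporter struct{}

func (noopReporter) StartTask(context.Context, uint, string, string, int, time.Duration) (*status, error) {
	return &status{ID: "task-123"}, nil
}
func (noopReporter) UpdateProgressByID(context.Context, string, int) error   { return nil }
func (noopReporter) EndTaskByID(context.Context, string, string, any, error) {}

// validateKeys shows the calling pattern: start, report progress, always end.
func validateKeys(ctx context.Context, r reporter, groupID uint, keys []string) error {
	resourceID := fmt.Sprintf("group:%d", groupID) // hypothetical resource-ID format
	st, err := r.StartTask(ctx, groupID, "validate_keys", resourceID, len(keys), 10*time.Minute)
	if err != nil {
		return err // another task already holds the lock for this resource
	}
	var taskErr error
	defer func() { r.EndTaskByID(ctx, st.ID, resourceID, nil, taskErr) }()

	for i := range keys {
		if taskErr = ctx.Err(); taskErr != nil {
			return taskErr // stop promptly once the caller cancels
		}
		_ = r.UpdateProgressByID(ctx, st.ID, i+1)
	}
	return nil
}

func main() {
	keys := []string{"key-a", "key-b"}
	if err := validateKeys(context.Background(), noopReporter{}, 1, keys); err != nil {
		fmt.Println("task failed:", err)
		return
	}
	fmt.Println("task completed")
}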