Compare commits
2 Commits
ac0e0a8275
...
6c7283d51b
| Author | SHA1 | Date | |
|---|---|---|---|
| 6c7283d51b | |||
| 2b0b9b67dc |
@@ -28,12 +28,6 @@ type GeminiChannel struct {
|
||||
httpClient *http.Client
|
||||
}
|
||||
|
||||
// 用于安全提取信息的本地结构体
|
||||
type requestMetadata struct {
|
||||
Model string `json:"model"`
|
||||
Stream bool `json:"stream"`
|
||||
}
|
||||
|
||||
func NewGeminiChannel(logger *logrus.Logger, cfg *models.SystemSettings) *GeminiChannel {
|
||||
transport := &http.Transport{
|
||||
Proxy: http.ProxyFromEnvironment,
|
||||
@@ -47,38 +41,50 @@ func NewGeminiChannel(logger *logrus.Logger, cfg *models.SystemSettings) *Gemini
|
||||
logger: logger,
|
||||
httpClient: &http.Client{
|
||||
Transport: transport,
|
||||
Timeout: 0,
|
||||
Timeout: 0, // Timeout is handled by the request context
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// TransformRequest
|
||||
func (ch *GeminiChannel) TransformRequest(c *gin.Context, requestBody []byte) (newBody []byte, modelName string, err error) {
|
||||
modelName = ch.ExtractModel(c, requestBody)
|
||||
return requestBody, modelName, nil
|
||||
}
|
||||
|
||||
// ExtractModel
|
||||
func (ch *GeminiChannel) ExtractModel(c *gin.Context, bodyBytes []byte) string {
|
||||
return ch.extractModelFromRequest(c, bodyBytes)
|
||||
}
|
||||
|
||||
// 统一的模型提取逻辑:优先从请求体解析,失败则回退到从URL路径解析。
|
||||
func (ch *GeminiChannel) extractModelFromRequest(c *gin.Context, bodyBytes []byte) string {
|
||||
var p struct {
|
||||
Model string `json:"model"`
|
||||
}
|
||||
_ = json.Unmarshal(requestBody, &p)
|
||||
modelName = strings.TrimPrefix(p.Model, "models/")
|
||||
|
||||
if modelName == "" {
|
||||
modelName = ch.extractModelFromPath(c.Request.URL.Path)
|
||||
if json.Unmarshal(bodyBytes, &p) == nil && p.Model != "" {
|
||||
return strings.TrimPrefix(p.Model, "models/")
|
||||
}
|
||||
return requestBody, modelName, nil
|
||||
return ch.extractModelFromPath(c.Request.URL.Path)
|
||||
}
|
||||
|
||||
func (ch *GeminiChannel) extractModelFromPath(path string) string {
|
||||
parts := strings.Split(path, "/")
|
||||
for _, part := range parts {
|
||||
// 覆盖更多模型名称格式
|
||||
if strings.HasPrefix(part, "gemini-") || strings.HasPrefix(part, "text-") || strings.HasPrefix(part, "embedding-") {
|
||||
modelPart := strings.Split(part, ":")[0]
|
||||
return modelPart
|
||||
return strings.Split(part, ":")[0]
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// IsOpenAICompatibleRequest 通过纯粹的字符串操作判断,不依赖 gin.Context。
|
||||
func (ch *GeminiChannel) IsOpenAICompatibleRequest(c *gin.Context) bool {
|
||||
path := c.Request.URL.Path
|
||||
return ch.isOpenAIPath(c.Request.URL.Path)
|
||||
}
|
||||
|
||||
func (ch *GeminiChannel) isOpenAIPath(path string) bool {
|
||||
return strings.Contains(path, "/v1/chat/completions") || strings.Contains(path, "/v1/embeddings")
|
||||
}
|
||||
|
||||
@@ -88,25 +94,28 @@ func (ch *GeminiChannel) ValidateKey(
|
||||
targetURL string,
|
||||
timeout time.Duration,
|
||||
) *CustomErrors.APIError {
|
||||
client := &http.Client{
|
||||
Timeout: timeout,
|
||||
}
|
||||
client := &http.Client{Timeout: timeout}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "GET", targetURL, nil)
|
||||
if err != nil {
|
||||
return CustomErrors.NewAPIError(CustomErrors.ErrInternalServer, "failed to create validation request: "+err.Error())
|
||||
return CustomErrors.NewAPIError(CustomErrors.ErrInternalServer, "failed to create validation request")
|
||||
}
|
||||
|
||||
ch.ModifyRequest(req, apiKey)
|
||||
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return CustomErrors.NewAPIError(CustomErrors.ErrBadGateway, "failed to send validation request: "+err.Error())
|
||||
return CustomErrors.NewAPIError(CustomErrors.ErrBadGateway, "validation request failed: "+err.Error())
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode >= 200 && resp.StatusCode < 300 {
|
||||
return nil
|
||||
}
|
||||
|
||||
errorBody, _ := io.ReadAll(resp.Body)
|
||||
parsedMessage := CustomErrors.ParseUpstreamError(errorBody)
|
||||
|
||||
return &CustomErrors.APIError{
|
||||
HTTPStatus: resp.StatusCode,
|
||||
Code: fmt.Sprintf("UPSTREAM_%d", resp.StatusCode),
|
||||
@@ -115,10 +124,6 @@ func (ch *GeminiChannel) ValidateKey(
|
||||
}
|
||||
|
||||
func (ch *GeminiChannel) ModifyRequest(req *http.Request, apiKey *models.APIKey) {
|
||||
// TODO: [Future Refactoring] Decouple auth logic from URL path.
|
||||
// The authentication method (e.g., Bearer token vs. API key in query) should ideally be a property
|
||||
// of the UpstreamEndpoint or a new "AuthProfile" entity, rather than being hardcoded based on URL patterns.
|
||||
// This would make the channel more generic and adaptable to new upstream provider types.
|
||||
if strings.Contains(req.URL.Path, "/v1beta/openai/") {
|
||||
req.Header.Set("Authorization", "Bearer "+apiKey.APIKey)
|
||||
} else {
|
||||
@@ -133,24 +138,22 @@ func (ch *GeminiChannel) IsStreamRequest(c *gin.Context, bodyBytes []byte) bool
|
||||
if strings.HasSuffix(c.Request.URL.Path, ":streamGenerateContent") {
|
||||
return true
|
||||
}
|
||||
var meta requestMetadata
|
||||
if err := json.Unmarshal(bodyBytes, &meta); err == nil {
|
||||
var meta struct {
|
||||
Stream bool `json:"stream"`
|
||||
}
|
||||
if json.Unmarshal(bodyBytes, &meta) == nil {
|
||||
return meta.Stream
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (ch *GeminiChannel) ExtractModel(c *gin.Context, bodyBytes []byte) string {
|
||||
_, modelName, _ := ch.TransformRequest(c, bodyBytes)
|
||||
return modelName
|
||||
}
|
||||
|
||||
// RewritePath 使用 url.JoinPath 保证路径拼接的正确性。
|
||||
func (ch *GeminiChannel) RewritePath(basePath, originalPath string) string {
|
||||
tempCtx := &gin.Context{Request: &http.Request{URL: &url.URL{Path: originalPath}}}
|
||||
var rewrittenSegment string
|
||||
if ch.IsOpenAICompatibleRequest(tempCtx) {
|
||||
var apiEndpoint string
|
||||
|
||||
if ch.isOpenAIPath(originalPath) {
|
||||
v1Index := strings.LastIndex(originalPath, "/v1/")
|
||||
var apiEndpoint string
|
||||
if v1Index != -1 {
|
||||
apiEndpoint = originalPath[v1Index+len("/v1/"):]
|
||||
} else {
|
||||
@@ -158,69 +161,76 @@ func (ch *GeminiChannel) RewritePath(basePath, originalPath string) string {
|
||||
}
|
||||
rewrittenSegment = "v1beta/openai/" + apiEndpoint
|
||||
} else {
|
||||
tempPath := originalPath
|
||||
if strings.HasPrefix(tempPath, "/v1/") {
|
||||
tempPath = "/v1beta/" + strings.TrimPrefix(tempPath, "/v1/")
|
||||
if strings.HasPrefix(originalPath, "/v1/") {
|
||||
rewrittenSegment = "v1beta/" + strings.TrimPrefix(originalPath, "/v1/")
|
||||
} else {
|
||||
rewrittenSegment = strings.TrimPrefix(originalPath, "/")
|
||||
}
|
||||
rewrittenSegment = strings.TrimPrefix(tempPath, "/")
|
||||
}
|
||||
trimmedBasePath := strings.TrimSuffix(basePath, "/")
|
||||
pathToJoin := rewrittenSegment
|
||||
|
||||
trimmedBasePath := strings.TrimSuffix(basePath, "/")
|
||||
|
||||
// 防止版本号重复拼接,例如 basePath 是 /v1beta,而重写段也是 v1beta/..
|
||||
versionPrefixes := []string{"v1beta", "v1"}
|
||||
for _, prefix := range versionPrefixes {
|
||||
if strings.HasSuffix(trimmedBasePath, "/"+prefix) && strings.HasPrefix(pathToJoin, prefix+"/") {
|
||||
pathToJoin = strings.TrimPrefix(pathToJoin, prefix+"/")
|
||||
if strings.HasSuffix(trimmedBasePath, "/"+prefix) && strings.HasPrefix(rewrittenSegment, prefix+"/") {
|
||||
rewrittenSegment = strings.TrimPrefix(rewrittenSegment, prefix+"/")
|
||||
break
|
||||
}
|
||||
}
|
||||
finalPath, err := url.JoinPath(trimmedBasePath, pathToJoin)
|
||||
|
||||
finalPath, err := url.JoinPath(trimmedBasePath, rewrittenSegment)
|
||||
if err != nil {
|
||||
return trimmedBasePath + "/" + strings.TrimPrefix(pathToJoin, "/")
|
||||
// 回退到简单的字符串拼接
|
||||
return trimmedBasePath + "/" + strings.TrimPrefix(rewrittenSegment, "/")
|
||||
}
|
||||
return finalPath
|
||||
}
|
||||
|
||||
func (ch *GeminiChannel) ModifyResponse(resp *http.Response) error {
|
||||
// 这是一个桩实现,暂时不需要任何逻辑。
|
||||
return nil
|
||||
return nil // 桩实现
|
||||
}
|
||||
|
||||
func (ch *GeminiChannel) HandleError(c *gin.Context, err error) {
|
||||
// 这是一个桩实现,暂时不需要任何逻辑。
|
||||
// 桩实现
|
||||
}
|
||||
|
||||
// ==========================================================
|
||||
// ================== “智能路由”的核心引擎 ===================
|
||||
// ==========================================================
|
||||
func (ch *GeminiChannel) ProcessSmartStreamRequest(c *gin.Context, params SmartRequestParams) {
|
||||
log := ch.logger.WithField("correlation_id", params.CorrelationID)
|
||||
|
||||
targetURL, err := url.Parse(params.UpstreamURL)
|
||||
if err != nil {
|
||||
log.WithError(err).Error("Failed to parse upstream URL")
|
||||
errToJSON(c, CustomErrors.NewAPIError(CustomErrors.ErrInternalServer, "Invalid upstream URL format"))
|
||||
log.WithError(err).Error("Invalid upstream URL")
|
||||
errToJSON(c, CustomErrors.NewAPIError(CustomErrors.ErrInternalServer, "Invalid upstream URL"))
|
||||
return
|
||||
}
|
||||
|
||||
targetURL.Path = c.Request.URL.Path
|
||||
targetURL.RawQuery = c.Request.URL.RawQuery
|
||||
|
||||
initialReq, err := http.NewRequestWithContext(c.Request.Context(), "POST", targetURL.String(), bytes.NewReader(params.RequestBody))
|
||||
if err != nil {
|
||||
log.WithError(err).Error("Failed to create initial smart request")
|
||||
errToJSON(c, CustomErrors.NewAPIError(CustomErrors.ErrInternalServer, err.Error()))
|
||||
log.WithError(err).Error("Failed to create initial smart stream request")
|
||||
errToJSON(c, CustomErrors.NewAPIError(CustomErrors.ErrInternalServer, "Failed to create request"))
|
||||
return
|
||||
}
|
||||
|
||||
ch.ModifyRequest(initialReq, params.APIKey)
|
||||
initialReq.Header.Del("Authorization")
|
||||
|
||||
resp, err := ch.httpClient.Do(initialReq)
|
||||
if err != nil {
|
||||
log.WithError(err).Error("Initial smart request failed")
|
||||
errToJSON(c, CustomErrors.NewAPIError(CustomErrors.ErrBadGateway, err.Error()))
|
||||
log.WithError(err).Error("Initial smart stream request failed")
|
||||
errToJSON(c, CustomErrors.NewAPIError(CustomErrors.ErrBadGateway, "Request to upstream failed"))
|
||||
return
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
log.Warnf("Initial request received non-200 status: %d", resp.StatusCode)
|
||||
standardizedResp := ch.standardizeError(resp, params.LogTruncationLimit, log)
|
||||
defer standardizedResp.Body.Close()
|
||||
|
||||
c.Writer.WriteHeader(standardizedResp.StatusCode)
|
||||
for key, values := range standardizedResp.Header {
|
||||
for _, value := range values {
|
||||
@@ -228,45 +238,71 @@ func (ch *GeminiChannel) ProcessSmartStreamRequest(c *gin.Context, params SmartR
|
||||
}
|
||||
}
|
||||
io.Copy(c.Writer, standardizedResp.Body)
|
||||
|
||||
params.EventLogger.IsSuccess = false
|
||||
params.EventLogger.StatusCode = resp.StatusCode
|
||||
return
|
||||
}
|
||||
|
||||
ch.processStreamAndRetry(c, initialReq.Header, resp.Body, params, log)
|
||||
}
|
||||
|
||||
func (ch *GeminiChannel) processStreamAndRetry(
|
||||
c *gin.Context, initialRequestHeaders http.Header, initialReader io.ReadCloser, params SmartRequestParams, log *logrus.Entry,
|
||||
c *gin.Context,
|
||||
initialRequestHeaders http.Header,
|
||||
initialReader io.ReadCloser,
|
||||
params SmartRequestParams,
|
||||
log *logrus.Entry,
|
||||
) {
|
||||
defer initialReader.Close()
|
||||
|
||||
c.Writer.Header().Set("Content-Type", "text/event-stream; charset=utf-8")
|
||||
c.Writer.Header().Set("Cache-Control", "no-cache")
|
||||
c.Writer.Header().Set("Connection", "keep-alive")
|
||||
|
||||
flusher, _ := c.Writer.(http.Flusher)
|
||||
|
||||
var accumulatedText strings.Builder
|
||||
consecutiveRetryCount := 0
|
||||
currentReader := initialReader
|
||||
maxRetries := params.MaxRetries
|
||||
retryDelay := params.RetryDelay
|
||||
|
||||
log.Infof("Starting smart stream session. Max retries: %d", maxRetries)
|
||||
|
||||
for {
|
||||
if c.Request.Context().Err() != nil {
|
||||
log.Info("Client disconnected, stopping stream processing.")
|
||||
return
|
||||
}
|
||||
|
||||
var interruptionReason string
|
||||
scanner := bufio.NewScanner(currentReader)
|
||||
|
||||
for scanner.Scan() {
|
||||
if c.Request.Context().Err() != nil {
|
||||
log.Info("Client disconnected during scan.")
|
||||
return
|
||||
}
|
||||
|
||||
line := scanner.Text()
|
||||
if line == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
fmt.Fprintf(c.Writer, "%s\n\n", line)
|
||||
flusher.Flush()
|
||||
|
||||
if !strings.HasPrefix(line, "data: ") {
|
||||
continue
|
||||
}
|
||||
|
||||
data := strings.TrimPrefix(line, "data: ")
|
||||
var payload models.GeminiSSEPayload
|
||||
if err := json.Unmarshal([]byte(data), &payload); err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
if len(payload.Candidates) > 0 {
|
||||
candidate := payload.Candidates[0]
|
||||
if candidate.Content != nil && len(candidate.Content.Parts) > 0 {
|
||||
@@ -285,52 +321,71 @@ func (ch *GeminiChannel) processStreamAndRetry(
|
||||
}
|
||||
}
|
||||
currentReader.Close()
|
||||
|
||||
if interruptionReason == "" {
|
||||
if err := scanner.Err(); err != nil {
|
||||
log.WithError(err).Warn("Stream scanner encountered an error.")
|
||||
interruptionReason = "SCANNER_ERROR"
|
||||
} else {
|
||||
log.Warn("Stream dropped unexpectedly without a finish reason.")
|
||||
log.Warn("Stream connection dropped without a finish reason.")
|
||||
interruptionReason = "CONNECTION_DROP"
|
||||
}
|
||||
}
|
||||
|
||||
if consecutiveRetryCount >= maxRetries {
|
||||
log.Errorf("Retry limit exceeded. Last interruption: %s. Sending final error.", interruptionReason)
|
||||
errData, _ := json.Marshal(map[string]interface{}{"error": map[string]interface{}{"code": http.StatusGatewayTimeout, "status": "DEADLINE_EXCEEDED", "message": fmt.Sprintf("Proxy retry limit exceeded. Last interruption: %s.", interruptionReason)}})
|
||||
log.Errorf("Retry limit exceeded. Last interruption: %s. Sending final error to client.", interruptionReason)
|
||||
errData, _ := json.Marshal(map[string]interface{}{
|
||||
"error": map[string]interface{}{
|
||||
"code": http.StatusGatewayTimeout,
|
||||
"status": "DEADLINE_EXCEEDED",
|
||||
"message": fmt.Sprintf("Proxy retry limit exceeded after multiple interruptions. Last reason: %s", interruptionReason),
|
||||
},
|
||||
})
|
||||
fmt.Fprintf(c.Writer, "event: error\ndata: %s\n\n", string(errData))
|
||||
flusher.Flush()
|
||||
return
|
||||
}
|
||||
|
||||
consecutiveRetryCount++
|
||||
params.EventLogger.Retries = consecutiveRetryCount
|
||||
log.Infof("Stream interrupted. Attempting retry %d/%d after %v.", consecutiveRetryCount, maxRetries, retryDelay)
|
||||
|
||||
time.Sleep(retryDelay)
|
||||
retryBody, _ := buildRetryRequestBody(params.OriginalRequest, accumulatedText.String())
|
||||
|
||||
retryBody := buildRetryRequestBody(params.OriginalRequest, accumulatedText.String())
|
||||
retryBodyBytes, _ := json.Marshal(retryBody)
|
||||
|
||||
retryReq, _ := http.NewRequestWithContext(c.Request.Context(), "POST", params.UpstreamURL, bytes.NewReader(retryBodyBytes))
|
||||
retryReq, err := http.NewRequestWithContext(c.Request.Context(), "POST", params.UpstreamURL, bytes.NewReader(retryBodyBytes))
|
||||
if err != nil {
|
||||
log.WithError(err).Error("Failed to create retry request")
|
||||
continue
|
||||
}
|
||||
|
||||
retryReq.Header = initialRequestHeaders
|
||||
ch.ModifyRequest(retryReq, params.APIKey)
|
||||
retryReq.Header.Del("Authorization")
|
||||
|
||||
retryResp, err := ch.httpClient.Do(retryReq)
|
||||
if err != nil || retryResp.StatusCode != http.StatusOK || retryResp.Body == nil {
|
||||
if err != nil {
|
||||
log.WithError(err).Errorf("Retry request failed.")
|
||||
} else {
|
||||
log.Errorf("Retry request received non-200 status: %d", retryResp.StatusCode)
|
||||
if retryResp.Body != nil {
|
||||
retryResp.Body.Close()
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
log.WithError(err).Error("Retry request failed")
|
||||
continue
|
||||
}
|
||||
|
||||
if retryResp.StatusCode != http.StatusOK {
|
||||
log.Errorf("Retry request received non-200 status: %d", retryResp.StatusCode)
|
||||
retryResp.Body.Close()
|
||||
continue
|
||||
}
|
||||
|
||||
currentReader = retryResp.Body
|
||||
}
|
||||
}
|
||||
|
||||
func buildRetryRequestBody(originalBody models.GeminiRequest, accumulatedText string) (models.GeminiRequest, error) {
|
||||
// buildRetryRequestBody 正确处理多轮对话的上下文插入。
|
||||
func buildRetryRequestBody(originalBody models.GeminiRequest, accumulatedText string) models.GeminiRequest {
|
||||
retryBody := originalBody
|
||||
|
||||
// 找到最后一个 'user' 角色的消息索引
|
||||
lastUserIndex := -1
|
||||
for i := len(retryBody.Contents) - 1; i >= 0; i-- {
|
||||
if retryBody.Contents[i].Role == "user" {
|
||||
@@ -338,25 +393,26 @@ func buildRetryRequestBody(originalBody models.GeminiRequest, accumulatedText st
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
history := []models.GeminiContent{
|
||||
{Role: "model", Parts: []models.Part{{Text: accumulatedText}}},
|
||||
{Role: "user", Parts: []models.Part{{Text: SmartRetryPrompt}}},
|
||||
}
|
||||
|
||||
if lastUserIndex != -1 {
|
||||
// 如果找到了 'user' 消息,将历史记录插入到其后
|
||||
newContents := make([]models.GeminiContent, 0, len(retryBody.Contents)+2)
|
||||
newContents = append(newContents, retryBody.Contents[:lastUserIndex+1]...)
|
||||
newContents = append(newContents, history...)
|
||||
newContents = append(newContents, retryBody.Contents[lastUserIndex+1:]...)
|
||||
retryBody.Contents = newContents
|
||||
} else {
|
||||
// 如果没有 'user' 消息(理论上不应发生),则直接追加
|
||||
retryBody.Contents = append(retryBody.Contents, history...)
|
||||
}
|
||||
return retryBody, nil
|
||||
}
|
||||
|
||||
// ===============================================
|
||||
// ========= 辅助函数区 (继承并强化) =========
|
||||
// ===============================================
|
||||
return retryBody
|
||||
}
|
||||
|
||||
type googleAPIError struct {
|
||||
Error struct {
|
||||
@@ -397,25 +453,28 @@ func truncate(s string, n int) string {
|
||||
return s
|
||||
}
|
||||
|
||||
// standardizeError
|
||||
func (ch *GeminiChannel) standardizeError(resp *http.Response, truncateLimit int, log *logrus.Entry) *http.Response {
|
||||
bodyBytes, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
log.WithError(err).Error("Failed to read upstream error body")
|
||||
bodyBytes = []byte("Failed to read upstream error body: " + err.Error())
|
||||
bodyBytes = []byte("Failed to read upstream error body")
|
||||
}
|
||||
resp.Body.Close()
|
||||
log.Errorf("Upstream error body: %s", truncate(string(bodyBytes), truncateLimit))
|
||||
|
||||
log.Errorf("Upstream error: %s", truncate(string(bodyBytes), truncateLimit))
|
||||
|
||||
var standardizedPayload googleAPIError
|
||||
// 即使解析失败,也要构建一个标准的错误结构体
|
||||
if json.Unmarshal(bodyBytes, &standardizedPayload) != nil || standardizedPayload.Error.Code == 0 {
|
||||
standardizedPayload.Error.Code = resp.StatusCode
|
||||
standardizedPayload.Error.Message = http.StatusText(resp.StatusCode)
|
||||
standardizedPayload.Error.Status = statusToGoogleStatus(resp.StatusCode)
|
||||
standardizedPayload.Error.Details = []interface{}{map[string]string{
|
||||
"@type": "proxy.upstream.error",
|
||||
"@type": "proxy.upstream.unparsed.error",
|
||||
"body": truncate(string(bodyBytes), truncateLimit),
|
||||
}}
|
||||
}
|
||||
|
||||
newBodyBytes, _ := json.Marshal(standardizedPayload)
|
||||
newResp := &http.Response{
|
||||
StatusCode: resp.StatusCode,
|
||||
@@ -425,10 +484,13 @@ func (ch *GeminiChannel) standardizeError(resp *http.Response, truncateLimit int
|
||||
}
|
||||
newResp.Header.Set("Content-Type", "application/json; charset=utf-8")
|
||||
newResp.Header.Set("Access-Control-Allow-Origin", "*")
|
||||
|
||||
return newResp
|
||||
}
|
||||
|
||||
// errToJSON
|
||||
func errToJSON(c *gin.Context, apiErr *CustomErrors.APIError) {
|
||||
if c.IsAborted() {
|
||||
return
|
||||
}
|
||||
c.JSON(apiErr.HTTPStatus, gin.H{"error": apiErr})
|
||||
}
|
||||
|
||||
@@ -13,9 +13,10 @@ type Config struct {
|
||||
Database DatabaseConfig
|
||||
Server ServerConfig
|
||||
Log LogConfig
|
||||
Redis RedisConfig `mapstructure:"redis"`
|
||||
SessionSecret string `mapstructure:"session_secret"`
|
||||
EncryptionKey string `mapstructure:"encryption_key"`
|
||||
Redis RedisConfig `mapstructure:"redis"`
|
||||
SessionSecret string `mapstructure:"session_secret"`
|
||||
EncryptionKey string `mapstructure:"encryption_key"`
|
||||
Repository RepositoryConfig `mapstructure:"repository"`
|
||||
}
|
||||
|
||||
// DatabaseConfig 存储数据库连接信息
|
||||
@@ -43,19 +44,24 @@ type RedisConfig struct {
|
||||
DSN string `mapstructure:"dsn"`
|
||||
}
|
||||
|
||||
type RepositoryConfig struct {
|
||||
BasePoolTTLMinutes int `mapstructure:"base_pool_ttl_minutes"`
|
||||
BasePoolTTIMinutes int `mapstructure:"base_pool_tti_minutes"`
|
||||
}
|
||||
|
||||
// LoadConfig 从文件和环境变量加载配置
|
||||
func LoadConfig() (*Config, error) {
|
||||
// 设置配置文件名和路径
|
||||
viper.SetConfigName("config")
|
||||
viper.SetConfigType("yaml")
|
||||
viper.AddConfigPath(".")
|
||||
|
||||
viper.AddConfigPath("/etc/gemini-balancer/") // for production
|
||||
// 允许从环境变量读取
|
||||
viper.SetEnvKeyReplacer(strings.NewReplacer(".", "_"))
|
||||
viper.AutomaticEnv()
|
||||
|
||||
// 设置默认值
|
||||
viper.SetDefault("server.port", "8080")
|
||||
viper.SetDefault("server.port", "9000")
|
||||
viper.SetDefault("log.level", "info")
|
||||
viper.SetDefault("log.format", "text")
|
||||
viper.SetDefault("log.enable_file", false)
|
||||
@@ -67,6 +73,9 @@ func LoadConfig() (*Config, error) {
|
||||
viper.SetDefault("database.conn_max_lifetime", "1h")
|
||||
viper.SetDefault("encryption_key", "")
|
||||
|
||||
viper.SetDefault("repository.base_pool_ttl_minutes", 60)
|
||||
viper.SetDefault("repository.base_pool_tti_minutes", 10)
|
||||
|
||||
// 读取配置文件
|
||||
if err := viper.ReadInConfig(); err != nil {
|
||||
if _, ok := err.(viper.ConfigFileNotFoundError); !ok {
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"gemini-balancer/internal/errors"
|
||||
"gemini-balancer/internal/models"
|
||||
@@ -49,7 +50,6 @@ func (h *handler) registerRoutes(rg *gin.RouterGroup) {
|
||||
}
|
||||
}
|
||||
|
||||
// --- 请求 DTO ---
|
||||
type CreateProxyConfigRequest struct {
|
||||
Address string `json:"address" binding:"required"`
|
||||
Protocol string `json:"protocol" binding:"required,oneof=http https socks5"`
|
||||
@@ -64,12 +64,10 @@ type UpdateProxyConfigRequest struct {
|
||||
Description *string `json:"description"`
|
||||
}
|
||||
|
||||
// 单个检测的请求体 (与前端JS对齐)
|
||||
type CheckSingleProxyRequest struct {
|
||||
Proxy string `json:"proxy" binding:"required"`
|
||||
}
|
||||
|
||||
// 批量检测的请求体
|
||||
type CheckAllProxiesRequest struct {
|
||||
Proxies []string `json:"proxies" binding:"required"`
|
||||
}
|
||||
@@ -84,7 +82,7 @@ func (h *handler) CreateProxyConfig(c *gin.Context) {
|
||||
}
|
||||
|
||||
if req.Status == "" {
|
||||
req.Status = "active" // 默认状态
|
||||
req.Status = "active"
|
||||
}
|
||||
|
||||
proxyConfig := models.ProxyConfig{
|
||||
@@ -98,7 +96,6 @@ func (h *handler) CreateProxyConfig(c *gin.Context) {
|
||||
response.Error(c, errors.ParseDBError(err))
|
||||
return
|
||||
}
|
||||
// 写操作后,发布事件并使缓存失效
|
||||
h.publishAndInvalidate(proxyConfig.ID, "created")
|
||||
response.Created(c, proxyConfig)
|
||||
}
|
||||
@@ -199,17 +196,16 @@ func (h *handler) DeleteProxyConfig(c *gin.Context) {
|
||||
response.NoContent(c)
|
||||
}
|
||||
|
||||
// publishAndInvalidate 统一事件发布和缓存失效逻辑
|
||||
func (h *handler) publishAndInvalidate(proxyID uint, action string) {
|
||||
go h.manager.invalidate()
|
||||
go func() {
|
||||
ctx := context.Background()
|
||||
event := models.ProxyStatusChangedEvent{ProxyID: proxyID, Action: action}
|
||||
eventData, _ := json.Marshal(event)
|
||||
_ = h.store.Publish(models.TopicProxyStatusChanged, eventData)
|
||||
_ = h.store.Publish(ctx, models.TopicProxyStatusChanged, eventData)
|
||||
}()
|
||||
}
|
||||
|
||||
// 新的 Handler 方法和 DTO
|
||||
type SyncProxiesRequest struct {
|
||||
Proxies []string `json:"proxies"`
|
||||
}
|
||||
@@ -220,14 +216,12 @@ func (h *handler) SyncProxies(c *gin.Context) {
|
||||
response.Error(c, errors.NewAPIError(errors.ErrInvalidJSON, err.Error()))
|
||||
return
|
||||
}
|
||||
taskStatus, err := h.manager.SyncProxiesInBackground(req.Proxies)
|
||||
|
||||
taskStatus, err := h.manager.SyncProxiesInBackground(c.Request.Context(), req.Proxies)
|
||||
if err != nil {
|
||||
|
||||
if errors.Is(err, ErrTaskConflict) {
|
||||
|
||||
response.Error(c, errors.NewAPIError(errors.ErrTaskInProgress, err.Error()))
|
||||
} else {
|
||||
|
||||
response.Error(c, errors.NewAPIError(errors.ErrInternalServer, err.Error()))
|
||||
}
|
||||
return
|
||||
@@ -262,7 +256,7 @@ func (h *handler) CheckAllProxies(c *gin.Context) {
|
||||
|
||||
concurrency := cfg.ProxyCheckConcurrency
|
||||
if concurrency <= 0 {
|
||||
concurrency = 5 // 如果配置不合法,提供一个安全的默认值
|
||||
concurrency = 5
|
||||
}
|
||||
results := h.manager.CheckMultipleProxies(req.Proxies, timeout, concurrency)
|
||||
response.Success(c, results)
|
||||
|
||||
@@ -2,14 +2,13 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"gemini-balancer/internal/models"
|
||||
"gemini-balancer/internal/store"
|
||||
"gemini-balancer/internal/syncer"
|
||||
"gemini-balancer/internal/task"
|
||||
|
||||
"context"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
@@ -25,7 +24,7 @@ import (
|
||||
|
||||
const (
|
||||
TaskTypeProxySync = "proxy_sync"
|
||||
proxyChunkSize = 200 // 代理同步的批量大小
|
||||
proxyChunkSize = 200
|
||||
)
|
||||
|
||||
type ProxyCheckResult struct {
|
||||
@@ -35,13 +34,11 @@ type ProxyCheckResult struct {
|
||||
ErrorMessage string `json:"error_message"`
|
||||
}
|
||||
|
||||
// managerCacheData
|
||||
type managerCacheData struct {
|
||||
ActiveProxies []*models.ProxyConfig
|
||||
ProxiesByID map[uint]*models.ProxyConfig
|
||||
}
|
||||
|
||||
// manager结构体
|
||||
type manager struct {
|
||||
db *gorm.DB
|
||||
syncer *syncer.CacheSyncer[managerCacheData]
|
||||
@@ -80,21 +77,21 @@ func newManager(db *gorm.DB, syncer *syncer.CacheSyncer[managerCacheData], taskR
|
||||
}
|
||||
}
|
||||
|
||||
func (m *manager) SyncProxiesInBackground(proxyStrings []string) (*task.Status, error) {
|
||||
func (m *manager) SyncProxiesInBackground(ctx context.Context, proxyStrings []string) (*task.Status, error) {
|
||||
resourceID := "global_proxy_sync"
|
||||
taskStatus, err := m.task.StartTask(0, TaskTypeProxySync, resourceID, len(proxyStrings), 0)
|
||||
taskStatus, err := m.task.StartTask(ctx, 0, TaskTypeProxySync, resourceID, len(proxyStrings), 0)
|
||||
if err != nil {
|
||||
return nil, ErrTaskConflict
|
||||
}
|
||||
go m.runProxySyncTask(taskStatus.ID, proxyStrings)
|
||||
go m.runProxySyncTask(context.Background(), taskStatus.ID, proxyStrings)
|
||||
return taskStatus, nil
|
||||
}
|
||||
|
||||
func (m *manager) runProxySyncTask(taskID string, finalProxyStrings []string) {
|
||||
func (m *manager) runProxySyncTask(ctx context.Context, taskID string, finalProxyStrings []string) {
|
||||
resourceID := "global_proxy_sync"
|
||||
var allProxies []models.ProxyConfig
|
||||
if err := m.db.Find(&allProxies).Error; err != nil {
|
||||
m.task.EndTaskByID(taskID, resourceID, nil, fmt.Errorf("failed to fetch current proxies: %w", err))
|
||||
m.task.EndTaskByID(ctx, taskID, resourceID, nil, fmt.Errorf("failed to fetch current proxies: %w", err))
|
||||
return
|
||||
}
|
||||
currentProxyMap := make(map[string]uint)
|
||||
@@ -125,19 +122,19 @@ func (m *manager) runProxySyncTask(taskID string, finalProxyStrings []string) {
|
||||
}
|
||||
if len(idsToDelete) > 0 {
|
||||
if err := m.bulkDeleteByIDs(idsToDelete); err != nil {
|
||||
m.task.EndTaskByID(taskID, resourceID, nil, fmt.Errorf("failed during proxy deletion: %w", err))
|
||||
m.task.EndTaskByID(ctx, taskID, resourceID, nil, fmt.Errorf("failed during proxy deletion: %w", err))
|
||||
return
|
||||
}
|
||||
}
|
||||
if len(proxiesToAdd) > 0 {
|
||||
if err := m.bulkAdd(proxiesToAdd); err != nil {
|
||||
m.task.EndTaskByID(taskID, resourceID, nil, fmt.Errorf("failed during proxy addition: %w", err))
|
||||
m.task.EndTaskByID(ctx, taskID, resourceID, nil, fmt.Errorf("failed during proxy addition: %w", err))
|
||||
return
|
||||
}
|
||||
}
|
||||
result := gin.H{"added": len(proxiesToAdd), "deleted": len(idsToDelete), "final_total": len(finalProxyMap)}
|
||||
m.task.EndTaskByID(taskID, resourceID, result, nil)
|
||||
m.publishChangeEvent("proxies_synced")
|
||||
m.task.EndTaskByID(ctx, taskID, resourceID, result, nil)
|
||||
m.publishChangeEvent(ctx, "proxies_synced")
|
||||
go m.invalidate()
|
||||
}
|
||||
|
||||
@@ -184,14 +181,15 @@ func (m *manager) bulkDeleteByIDs(ids []uint) error {
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *manager) bulkAdd(proxies []models.ProxyConfig) error {
|
||||
return m.db.CreateInBatches(proxies, proxyChunkSize).Error
|
||||
}
|
||||
|
||||
func (m *manager) publishChangeEvent(reason string) {
|
||||
func (m *manager) publishChangeEvent(ctx context.Context, reason string) {
|
||||
event := models.ProxyStatusChangedEvent{Action: reason}
|
||||
eventData, _ := json.Marshal(event)
|
||||
_ = m.store.Publish(models.TopicProxyStatusChanged, eventData)
|
||||
_ = m.store.Publish(ctx, models.TopicProxyStatusChanged, eventData)
|
||||
}
|
||||
|
||||
func (m *manager) assignProxyIfNeeded(apiKey *models.APIKey) (*models.ProxyConfig, error) {
|
||||
@@ -313,3 +311,8 @@ func (m *manager) checkProxyConnectivity(proxyCfg *models.ProxyConfig, timeout t
|
||||
defer resp.Body.Close()
|
||||
return true
|
||||
}
|
||||
|
||||
type Manager interface {
|
||||
AssignProxyIfNeeded(apiKey *models.APIKey) (*models.ProxyConfig, error)
|
||||
// ... 其他需要暴露给外部服务的方法
|
||||
}
|
||||
|
||||
@@ -44,6 +44,7 @@ var (
|
||||
ErrGroupNotFound = &APIError{HTTPStatus: http.StatusNotFound, Code: "GROUP_NOT_FOUND", Message: "The specified group was not found."}
|
||||
ErrPermissionDenied = &APIError{HTTPStatus: http.StatusForbidden, Code: "PERMISSION_DENIED", Message: "Permission denied for this operation."}
|
||||
ErrConfigurationError = &APIError{HTTPStatus: http.StatusInternalServerError, Code: "CONFIGURATION_ERROR", Message: "A configuration error prevents this request from being processed."}
|
||||
ErrProxyNotAvailable = &APIError{HTTPStatus: http.StatusNotFound, Code: "PROXY_ERROR", Message: "Required proxy is not available for this request."}
|
||||
|
||||
ErrStateConflictMasterRevoked = &APIError{HTTPStatus: http.StatusConflict, Code: "STATE_CONFLICT_MASTER_REVOKED", Message: "Cannot perform this operation on a revoked key."}
|
||||
ErrNotFound = &APIError{HTTPStatus: http.StatusNotFound, Code: "NOT_FOUND", Message: "Resource not found"}
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
// Filename: internal/handlers/apikey_handler.go
|
||||
// Filename: internal/handlers/apikey_handler.go (最终决战版)
|
||||
package handlers
|
||||
|
||||
import (
|
||||
@@ -31,11 +31,10 @@ func NewAPIKeyHandler(apiKeyService *service.APIKeyService, db *gorm.DB, keyImpo
|
||||
}
|
||||
}
|
||||
|
||||
// DTOs for API requests
|
||||
type BulkAddKeysToGroupRequest struct {
|
||||
KeyGroupID uint `json:"key_group_id" binding:"required"`
|
||||
Keys string `json:"keys" binding:"required"`
|
||||
ValidateOnImport bool `json:"validate_on_import"` // OmitEmpty/default is false
|
||||
ValidateOnImport bool `json:"validate_on_import"`
|
||||
}
|
||||
|
||||
type BulkUnlinkKeysFromGroupRequest struct {
|
||||
@@ -72,11 +71,11 @@ type BulkTestKeysForGroupRequest struct {
|
||||
}
|
||||
|
||||
type BulkActionFilter struct {
|
||||
Status []string `json:"status"` // Changed to slice to accept multiple statuses
|
||||
Status []string `json:"status"`
|
||||
}
|
||||
type BulkActionRequest struct {
|
||||
Action string `json:"action" binding:"required,oneof=revalidate set_status delete"`
|
||||
NewStatus string `json:"new_status" binding:"omitempty,oneof=active disabled cooldown banned"` // For 'set_status' action
|
||||
NewStatus string `json:"new_status" binding:"omitempty,oneof=active disabled cooldown banned"`
|
||||
Filter BulkActionFilter `json:"filter" binding:"required"`
|
||||
}
|
||||
|
||||
@@ -89,7 +88,8 @@ func (h *APIKeyHandler) AddMultipleKeysToGroup(c *gin.Context) {
|
||||
response.Error(c, errors.NewAPIError(errors.ErrInvalidJSON, err.Error()))
|
||||
return
|
||||
}
|
||||
taskStatus, err := h.keyImportService.StartAddKeysTask(req.KeyGroupID, req.Keys, req.ValidateOnImport)
|
||||
// [修正] 将请求的 context 传递给 service 层
|
||||
taskStatus, err := h.keyImportService.StartAddKeysTask(c.Request.Context(), req.KeyGroupID, req.Keys, req.ValidateOnImport)
|
||||
if err != nil {
|
||||
response.Error(c, errors.NewAPIError(errors.ErrTaskInProgress, err.Error()))
|
||||
return
|
||||
@@ -104,7 +104,8 @@ func (h *APIKeyHandler) UnlinkMultipleKeysFromGroup(c *gin.Context) {
|
||||
response.Error(c, errors.NewAPIError(errors.ErrInvalidJSON, err.Error()))
|
||||
return
|
||||
}
|
||||
taskStatus, err := h.keyImportService.StartUnlinkKeysTask(req.KeyGroupID, req.Keys)
|
||||
// [修正] 将请求的 context 传递给 service 层
|
||||
taskStatus, err := h.keyImportService.StartUnlinkKeysTask(c.Request.Context(), req.KeyGroupID, req.Keys)
|
||||
if err != nil {
|
||||
response.Error(c, errors.NewAPIError(errors.ErrTaskInProgress, err.Error()))
|
||||
return
|
||||
@@ -119,7 +120,8 @@ func (h *APIKeyHandler) HardDeleteMultipleKeys(c *gin.Context) {
|
||||
response.Error(c, errors.NewAPIError(errors.ErrInvalidJSON, err.Error()))
|
||||
return
|
||||
}
|
||||
taskStatus, err := h.keyImportService.StartHardDeleteKeysTask(req.Keys)
|
||||
// [修正] 将请求的 context 传递给 service 层
|
||||
taskStatus, err := h.keyImportService.StartHardDeleteKeysTask(c.Request.Context(), req.Keys)
|
||||
if err != nil {
|
||||
response.Error(c, errors.NewAPIError(errors.ErrTaskInProgress, err.Error()))
|
||||
return
|
||||
@@ -134,7 +136,8 @@ func (h *APIKeyHandler) RestoreMultipleKeys(c *gin.Context) {
|
||||
response.Error(c, errors.NewAPIError(errors.ErrInvalidJSON, err.Error()))
|
||||
return
|
||||
}
|
||||
taskStatus, err := h.keyImportService.StartRestoreKeysTask(req.Keys)
|
||||
// [修正] 将请求的 context 传递给 service 层
|
||||
taskStatus, err := h.keyImportService.StartRestoreKeysTask(c.Request.Context(), req.Keys)
|
||||
if err != nil {
|
||||
response.Error(c, errors.NewAPIError(errors.ErrTaskInProgress, err.Error()))
|
||||
return
|
||||
@@ -148,7 +151,8 @@ func (h *APIKeyHandler) TestMultipleKeys(c *gin.Context) {
|
||||
response.Error(c, errors.NewAPIError(errors.ErrInvalidJSON, err.Error()))
|
||||
return
|
||||
}
|
||||
taskStatus, err := h.keyValidationService.StartTestKeysTask(req.KeyGroupID, req.Keys)
|
||||
// [修正] 将请求的 context 传递给 service 层
|
||||
taskStatus, err := h.keyValidationService.StartTestKeysTask(c.Request.Context(), req.KeyGroupID, req.Keys)
|
||||
if err != nil {
|
||||
response.Error(c, errors.NewAPIError(errors.ErrTaskInProgress, err.Error()))
|
||||
return
|
||||
@@ -172,7 +176,7 @@ func (h *APIKeyHandler) ListAPIKeys(c *gin.Context) {
|
||||
}
|
||||
}
|
||||
if len(ids) > 0 {
|
||||
keys, err := h.apiKeyService.GetKeysByIds(ids)
|
||||
keys, err := h.apiKeyService.GetKeysByIds(c.Request.Context(), ids)
|
||||
if err != nil {
|
||||
response.Error(c, &errors.APIError{
|
||||
HTTPStatus: http.StatusInternalServerError,
|
||||
@@ -191,7 +195,7 @@ func (h *APIKeyHandler) ListAPIKeys(c *gin.Context) {
|
||||
if params.PageSize <= 0 {
|
||||
params.PageSize = 20
|
||||
}
|
||||
result, err := h.apiKeyService.ListAPIKeys(¶ms)
|
||||
result, err := h.apiKeyService.ListAPIKeys(c.Request.Context(), ¶ms)
|
||||
if err != nil {
|
||||
response.Error(c, errors.ParseDBError(err))
|
||||
return
|
||||
@@ -201,19 +205,16 @@ func (h *APIKeyHandler) ListAPIKeys(c *gin.Context) {
|
||||
|
||||
// ListKeysForGroup handles the GET /keygroups/:id/keys request.
|
||||
func (h *APIKeyHandler) ListKeysForGroup(c *gin.Context) {
|
||||
// 1. Manually handle the path parameter.
|
||||
groupID, err := strconv.ParseUint(c.Param("id"), 10, 32)
|
||||
if err != nil {
|
||||
response.Error(c, errors.NewAPIError(errors.ErrBadRequest, "Invalid group ID format"))
|
||||
return
|
||||
}
|
||||
// 2. Bind query parameters using the correctly tagged struct.
|
||||
var params models.APIKeyQueryParams
|
||||
if err := c.ShouldBindQuery(¶ms); err != nil {
|
||||
response.Error(c, errors.NewAPIError(errors.ErrBadRequest, err.Error()))
|
||||
return
|
||||
}
|
||||
// 3. Set server-side defaults and the path parameter.
|
||||
if params.Page <= 0 {
|
||||
params.Page = 1
|
||||
}
|
||||
@@ -221,15 +222,11 @@ func (h *APIKeyHandler) ListKeysForGroup(c *gin.Context) {
|
||||
params.PageSize = 20
|
||||
}
|
||||
params.KeyGroupID = uint(groupID)
|
||||
// 4. Call the service layer.
|
||||
paginatedResult, err := h.apiKeyService.ListAPIKeys(¶ms)
|
||||
paginatedResult, err := h.apiKeyService.ListAPIKeys(c.Request.Context(), ¶ms)
|
||||
if err != nil {
|
||||
response.Error(c, errors.ParseDBError(err))
|
||||
return
|
||||
}
|
||||
|
||||
// 5. [THE FIX] Return a successful response using the standard `response.Success`
|
||||
// and a gin.H map, as confirmed to exist in your project.
|
||||
response.Success(c, gin.H{
|
||||
"items": paginatedResult.Items,
|
||||
"total": paginatedResult.Total,
|
||||
@@ -239,20 +236,18 @@ func (h *APIKeyHandler) ListKeysForGroup(c *gin.Context) {
|
||||
}
|
||||
|
||||
func (h *APIKeyHandler) TestKeysForGroup(c *gin.Context) {
|
||||
// Group ID is now correctly sourced from the URL path.
|
||||
groupID, err := strconv.ParseUint(c.Param("id"), 10, 32)
|
||||
if err != nil {
|
||||
response.Error(c, errors.NewAPIError(errors.ErrBadRequest, "Invalid group ID format"))
|
||||
return
|
||||
}
|
||||
// The request body is now simpler, only needing the keys.
|
||||
var req BulkTestKeysForGroupRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.Error(c, errors.NewAPIError(errors.ErrInvalidJSON, err.Error()))
|
||||
return
|
||||
}
|
||||
// Call the same underlying service, but with unambiguous context.
|
||||
taskStatus, err := h.keyValidationService.StartTestKeysTask(uint(groupID), req.Keys)
|
||||
// [修正] 将请求的 context 传递给 service 层
|
||||
taskStatus, err := h.keyValidationService.StartTestKeysTask(c.Request.Context(), uint(groupID), req.Keys)
|
||||
if err != nil {
|
||||
response.Error(c, errors.NewAPIError(errors.ErrTaskInProgress, err.Error()))
|
||||
return
|
||||
@@ -267,7 +262,6 @@ func (h *APIKeyHandler) UpdateAPIKey(c *gin.Context) {
|
||||
}
|
||||
|
||||
// UpdateGroupAPIKeyMapping handles updating a key's status within a specific group.
|
||||
// Route: PUT /keygroups/:id/apikeys/:keyId
|
||||
func (h *APIKeyHandler) UpdateGroupAPIKeyMapping(c *gin.Context) {
|
||||
groupID, err := strconv.ParseUint(c.Param("id"), 10, 32)
|
||||
if err != nil {
|
||||
@@ -284,8 +278,7 @@ func (h *APIKeyHandler) UpdateGroupAPIKeyMapping(c *gin.Context) {
|
||||
response.Error(c, errors.NewAPIError(errors.ErrInvalidJSON, err.Error()))
|
||||
return
|
||||
}
|
||||
// Directly use the service to handle the logic
|
||||
updatedMapping, err := h.apiKeyService.UpdateMappingStatus(uint(groupID), uint(keyID), req.Status)
|
||||
updatedMapping, err := h.apiKeyService.UpdateMappingStatus(c.Request.Context(), uint(groupID), uint(keyID), req.Status)
|
||||
if err != nil {
|
||||
var apiErr *errors.APIError
|
||||
if errors.As(err, &apiErr) {
|
||||
@@ -305,7 +298,7 @@ func (h *APIKeyHandler) HardDeleteAPIKey(c *gin.Context) {
|
||||
response.Error(c, errors.NewAPIError(errors.ErrBadRequest, "Invalid ID format"))
|
||||
return
|
||||
}
|
||||
if err := h.apiKeyService.HardDeleteAPIKeyByID(uint(id)); err != nil {
|
||||
if err := h.apiKeyService.HardDeleteAPIKeyByID(c.Request.Context(), uint(id)); err != nil {
|
||||
response.Error(c, errors.ParseDBError(err))
|
||||
return
|
||||
}
|
||||
@@ -313,7 +306,6 @@ func (h *APIKeyHandler) HardDeleteAPIKey(c *gin.Context) {
|
||||
}
|
||||
|
||||
// RestoreKeysInGroup 恢复指定Key的接口
|
||||
// POST /keygroups/:id/apikeys/restore
|
||||
func (h *APIKeyHandler) RestoreKeysInGroup(c *gin.Context) {
|
||||
groupID, err := strconv.ParseUint(c.Param("id"), 10, 32)
|
||||
if err != nil {
|
||||
@@ -325,7 +317,7 @@ func (h *APIKeyHandler) RestoreKeysInGroup(c *gin.Context) {
|
||||
response.Error(c, errors.NewAPIError(errors.ErrInvalidJSON, err.Error()))
|
||||
return
|
||||
}
|
||||
taskStatus, err := h.apiKeyService.StartRestoreKeysTask(uint(groupID), req.KeyIDs)
|
||||
taskStatus, err := h.apiKeyService.StartRestoreKeysTask(c.Request.Context(), uint(groupID), req.KeyIDs)
|
||||
if err != nil {
|
||||
var apiErr *errors.APIError
|
||||
if errors.As(err, &apiErr) {
|
||||
@@ -339,14 +331,13 @@ func (h *APIKeyHandler) RestoreKeysInGroup(c *gin.Context) {
|
||||
}
|
||||
|
||||
// RestoreAllBannedInGroup 一键恢复所有Banned Key的接口
|
||||
// POST /keygroups/:id/apikeys/restore-all-banned
|
||||
func (h *APIKeyHandler) RestoreAllBannedInGroup(c *gin.Context) {
|
||||
groupID, err := strconv.ParseUint(c.Param("groupId"), 10, 32)
|
||||
if err != nil {
|
||||
response.Error(c, errors.NewAPIError(errors.ErrBadRequest, "Invalid Group ID format"))
|
||||
return
|
||||
}
|
||||
taskStatus, err := h.apiKeyService.StartRestoreAllBannedTask(uint(groupID))
|
||||
taskStatus, err := h.apiKeyService.StartRestoreAllBannedTask(c.Request.Context(), uint(groupID))
|
||||
if err != nil {
|
||||
var apiErr *errors.APIError
|
||||
if errors.As(err, &apiErr) {
|
||||
@@ -360,48 +351,41 @@ func (h *APIKeyHandler) RestoreAllBannedInGroup(c *gin.Context) {
|
||||
}
|
||||
|
||||
// HandleBulkAction handles generic bulk actions on a key group based on server-side filters.
|
||||
// Route: POST /keygroups/:id/bulk-actions
|
||||
func (h *APIKeyHandler) HandleBulkAction(c *gin.Context) {
|
||||
// 1. Parse GroupID from URL
|
||||
groupID, err := strconv.ParseUint(c.Param("id"), 10, 32)
|
||||
if err != nil {
|
||||
response.Error(c, errors.NewAPIError(errors.ErrBadRequest, "Invalid Group ID format"))
|
||||
return
|
||||
}
|
||||
// 2. Bind the JSON payload to our new DTO
|
||||
var req BulkActionRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.Error(c, errors.NewAPIError(errors.ErrInvalidJSON, err.Error()))
|
||||
return
|
||||
}
|
||||
// 3. Central logic: based on the action, call the appropriate service method.
|
||||
var task *task.Status
|
||||
var apiErr *errors.APIError
|
||||
switch req.Action {
|
||||
case "revalidate":
|
||||
// Assume keyValidationService has a method that accepts a filter
|
||||
task, err = h.keyValidationService.StartTestKeysByFilterTask(uint(groupID), req.Filter.Status)
|
||||
// [修正] 将请求的 context 传递给 service 层
|
||||
task, err = h.keyValidationService.StartTestKeysByFilterTask(c.Request.Context(), uint(groupID), req.Filter.Status)
|
||||
case "set_status":
|
||||
if req.NewStatus == "" {
|
||||
apiErr = errors.NewAPIError(errors.ErrBadRequest, "new_status is required for set_status action")
|
||||
break
|
||||
}
|
||||
// Assume apiKeyService has a method to update status by filter
|
||||
targetStatus := models.APIKeyStatus(req.NewStatus) // Convert string to your model's type
|
||||
task, err = h.apiKeyService.StartUpdateStatusByFilterTask(uint(groupID), req.Filter.Status, targetStatus)
|
||||
targetStatus := models.APIKeyStatus(req.NewStatus)
|
||||
task, err = h.apiKeyService.StartUpdateStatusByFilterTask(c.Request.Context(), uint(groupID), req.Filter.Status, targetStatus)
|
||||
case "delete":
|
||||
// Assume keyImportService has a method to unlink by filter
|
||||
task, err = h.keyImportService.StartUnlinkKeysByFilterTask(uint(groupID), req.Filter.Status)
|
||||
// [修正] 将请求的 context 传递给 service 层
|
||||
task, err = h.keyImportService.StartUnlinkKeysByFilterTask(c.Request.Context(), uint(groupID), req.Filter.Status)
|
||||
default:
|
||||
apiErr = errors.NewAPIError(errors.ErrBadRequest, "Unsupported action: "+req.Action)
|
||||
}
|
||||
// 4. Handle errors from the switch block
|
||||
if apiErr != nil {
|
||||
response.Error(c, apiErr)
|
||||
return
|
||||
}
|
||||
if err != nil {
|
||||
// Attempt to parse it as a known APIError, otherwise, wrap it.
|
||||
var parsedErr *errors.APIError
|
||||
if errors.As(err, &parsedErr) {
|
||||
response.Error(c, parsedErr)
|
||||
@@ -410,21 +394,18 @@ func (h *APIKeyHandler) HandleBulkAction(c *gin.Context) {
|
||||
}
|
||||
return
|
||||
}
|
||||
// 5. Return the task status on success
|
||||
response.Success(c, task)
|
||||
}
|
||||
|
||||
// ExportKeysForGroup handles requests to export all keys for a group based on status filters.
|
||||
// Route: GET /keygroups/:id/apikeys/export
|
||||
func (h *APIKeyHandler) ExportKeysForGroup(c *gin.Context) {
|
||||
groupID, err := strconv.ParseUint(c.Param("id"), 10, 32)
|
||||
if err != nil {
|
||||
response.Error(c, errors.NewAPIError(errors.ErrBadRequest, "Invalid Group ID format"))
|
||||
return
|
||||
}
|
||||
// Use QueryArray to correctly parse `status[]=active&status[]=cooldown`
|
||||
statuses := c.QueryArray("status")
|
||||
keyStrings, err := h.apiKeyService.GetAPIKeyStringsForExport(uint(groupID), statuses)
|
||||
keyStrings, err := h.apiKeyService.GetAPIKeyStringsForExport(c.Request.Context(), uint(groupID), statuses)
|
||||
if err != nil {
|
||||
response.Error(c, errors.ParseDBError(err))
|
||||
return
|
||||
|
||||
@@ -30,7 +30,7 @@ func (h *DashboardHandler) GetOverview(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, stats)
|
||||
}
|
||||
|
||||
// GetChart 获取仪表盘的图表数据
|
||||
// GetChart
|
||||
func (h *DashboardHandler) GetChart(c *gin.Context) {
|
||||
var groupID *uint
|
||||
if groupIDStr := c.Query("groupId"); groupIDStr != "" {
|
||||
@@ -40,7 +40,7 @@ func (h *DashboardHandler) GetChart(c *gin.Context) {
|
||||
}
|
||||
}
|
||||
|
||||
chartData, err := h.queryService.QueryHistoricalChart(groupID)
|
||||
chartData, err := h.queryService.QueryHistoricalChart(c.Request.Context(), groupID)
|
||||
if err != nil {
|
||||
apiErr := errors.NewAPIError(errors.ErrDatabase, err.Error())
|
||||
c.JSON(apiErr.HTTPStatus, apiErr)
|
||||
@@ -49,10 +49,10 @@ func (h *DashboardHandler) GetChart(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, chartData)
|
||||
}
|
||||
|
||||
// GetRequestStats 处理对“期间调用概览”的请求
|
||||
// GetRequestStats
|
||||
func (h *DashboardHandler) GetRequestStats(c *gin.Context) {
|
||||
period := c.Param("period") // 从 URL 路径中获取 period
|
||||
stats, err := h.queryService.GetRequestStatsForPeriod(period)
|
||||
period := c.Param("period")
|
||||
stats, err := h.queryService.GetRequestStatsForPeriod(c.Request.Context(), period)
|
||||
if err != nil {
|
||||
apiErr := errors.NewAPIError(errors.ErrBadRequest, err.Error())
|
||||
c.JSON(apiErr.HTTPStatus, apiErr)
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"gemini-balancer/internal/errors"
|
||||
@@ -31,7 +32,6 @@ func NewKeyGroupHandler(gm *service.GroupManager, s store.Store, qs *service.Das
|
||||
}
|
||||
}
|
||||
|
||||
// DTOs & 辅助函数
|
||||
func isValidGroupName(name string) bool {
|
||||
if name == "" {
|
||||
return false
|
||||
@@ -40,7 +40,6 @@ func isValidGroupName(name string) bool {
|
||||
return match
|
||||
}
|
||||
|
||||
// KeyGroupOperationalSettings defines the shared operational settings for a key group.
|
||||
type KeyGroupOperationalSettings struct {
|
||||
EnableKeyCheck *bool `json:"enable_key_check"`
|
||||
KeyCheckIntervalMinutes *int `json:"key_check_interval_minutes"`
|
||||
@@ -52,7 +51,6 @@ type KeyGroupOperationalSettings struct {
|
||||
MaxRetries *int `json:"max_retries"`
|
||||
EnableSmartGateway *bool `json:"enable_smart_gateway"`
|
||||
}
|
||||
|
||||
type CreateKeyGroupRequest struct {
|
||||
Name string `json:"name" binding:"required"`
|
||||
DisplayName string `json:"display_name"`
|
||||
@@ -60,11 +58,8 @@ type CreateKeyGroupRequest struct {
|
||||
PollingStrategy string `json:"polling_strategy" binding:"omitempty,oneof=sequential random weighted"`
|
||||
EnableProxy bool `json:"enable_proxy"`
|
||||
ChannelType string `json:"channel_type"`
|
||||
|
||||
// Embed shared operational settings
|
||||
KeyGroupOperationalSettings
|
||||
}
|
||||
|
||||
type UpdateKeyGroupRequest struct {
|
||||
Name *string `json:"name"`
|
||||
DisplayName *string `json:"display_name"`
|
||||
@@ -72,15 +67,10 @@ type UpdateKeyGroupRequest struct {
|
||||
PollingStrategy *string `json:"polling_strategy" binding:"omitempty,oneof=sequential random weighted"`
|
||||
EnableProxy *bool `json:"enable_proxy"`
|
||||
ChannelType *string `json:"channel_type"`
|
||||
|
||||
// Embed shared operational settings
|
||||
KeyGroupOperationalSettings
|
||||
|
||||
// M:N associations
|
||||
AllowedUpstreams []string `json:"allowed_upstreams"`
|
||||
AllowedModels []string `json:"allowed_models"`
|
||||
}
|
||||
|
||||
type KeyGroupResponse struct {
|
||||
ID uint `json:"id"`
|
||||
Name string `json:"name"`
|
||||
@@ -96,36 +86,30 @@ type KeyGroupResponse struct {
|
||||
AllowedModels []string `json:"allowed_models"`
|
||||
AllowedUpstreams []string `json:"allowed_upstreams"`
|
||||
}
|
||||
|
||||
// [NEW] Define the detailed response structure for a single group.
|
||||
type KeyGroupDetailsResponse struct {
|
||||
KeyGroupResponse
|
||||
Settings *models.GroupSettings `json:"settings,omitempty"`
|
||||
RequestConfig *models.RequestConfig `json:"request_config,omitempty"`
|
||||
}
|
||||
|
||||
// transformModelsToStrings converts a slice of GroupModelMapping pointers to a slice of model names.
|
||||
func transformModelsToStrings(mappings []*models.GroupModelMapping) []string {
|
||||
modelNames := make([]string, 0, len(mappings))
|
||||
for _, mapping := range mappings {
|
||||
if mapping != nil { // Safety check
|
||||
if mapping != nil {
|
||||
modelNames = append(modelNames, mapping.ModelName)
|
||||
}
|
||||
}
|
||||
return modelNames
|
||||
}
|
||||
|
||||
// transformUpstreamsToStrings converts a slice of UpstreamEndpoint pointers to a slice of URLs.
|
||||
func transformUpstreamsToStrings(upstreams []*models.UpstreamEndpoint) []string {
|
||||
urls := make([]string, 0, len(upstreams))
|
||||
for _, upstream := range upstreams {
|
||||
if upstream != nil { // Safety check
|
||||
if upstream != nil {
|
||||
urls = append(urls, upstream.URL)
|
||||
}
|
||||
}
|
||||
return urls
|
||||
}
|
||||
|
||||
func (h *KeyGroupHandler) newKeyGroupResponse(group *models.KeyGroup, keyCount int64) KeyGroupResponse {
|
||||
return KeyGroupResponse{
|
||||
ID: group.ID,
|
||||
@@ -139,13 +123,10 @@ func (h *KeyGroupHandler) newKeyGroupResponse(group *models.KeyGroup, keyCount i
|
||||
CreatedAt: group.CreatedAt,
|
||||
UpdatedAt: group.UpdatedAt,
|
||||
Order: group.Order,
|
||||
AllowedModels: transformModelsToStrings(group.AllowedModels), // Call the new helper
|
||||
AllowedUpstreams: transformUpstreamsToStrings(group.AllowedUpstreams), // Call the new helper
|
||||
AllowedModels: transformModelsToStrings(group.AllowedModels),
|
||||
AllowedUpstreams: transformUpstreamsToStrings(group.AllowedUpstreams),
|
||||
}
|
||||
}
|
||||
|
||||
// packGroupSettings is a helper to convert request-level operational settings
|
||||
// into the model-level settings struct.
|
||||
func packGroupSettings(settings KeyGroupOperationalSettings) *models.KeyGroupSettings {
|
||||
return &models.KeyGroupSettings{
|
||||
EnableKeyCheck: settings.EnableKeyCheck,
|
||||
@@ -159,7 +140,6 @@ func packGroupSettings(settings KeyGroupOperationalSettings) *models.KeyGroupSet
|
||||
EnableSmartGateway: settings.EnableSmartGateway,
|
||||
}
|
||||
}
|
||||
|
||||
func (h *KeyGroupHandler) getGroupFromContext(c *gin.Context) (*models.KeyGroup, *errors.APIError) {
|
||||
id, err := strconv.Atoi(c.Param("id"))
|
||||
if err != nil {
|
||||
@@ -171,7 +151,6 @@ func (h *KeyGroupHandler) getGroupFromContext(c *gin.Context) (*models.KeyGroup,
|
||||
}
|
||||
return group, nil
|
||||
}
|
||||
|
||||
func applyUpdateRequestToGroup(req *UpdateKeyGroupRequest, group *models.KeyGroup) {
|
||||
if req.Name != nil {
|
||||
group.Name = *req.Name
|
||||
@@ -197,9 +176,10 @@ func applyUpdateRequestToGroup(req *UpdateKeyGroupRequest, group *models.KeyGrou
|
||||
// publishGroupChangeEvent encapsulates the logic for marshaling and publishing a group change event.
|
||||
func (h *KeyGroupHandler) publishGroupChangeEvent(groupID uint, reason string) {
|
||||
go func() {
|
||||
ctx := context.Background()
|
||||
event := models.KeyStatusChangedEvent{GroupID: groupID, ChangeReason: reason}
|
||||
eventData, _ := json.Marshal(event)
|
||||
h.store.Publish(models.TopicKeyStatusChanged, eventData)
|
||||
_ = h.store.Publish(ctx, models.TopicKeyStatusChanged, eventData)
|
||||
}()
|
||||
}
|
||||
|
||||
@@ -216,7 +196,6 @@ func (h *KeyGroupHandler) CreateKeyGroup(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
// The core logic remains, as it's specific to creation.
|
||||
p := bluemonday.StripTagsPolicy()
|
||||
sanitizedDisplayName := p.Sanitize(req.DisplayName)
|
||||
sanitizedDescription := p.Sanitize(req.Description)
|
||||
@@ -244,11 +223,9 @@ func (h *KeyGroupHandler) CreateKeyGroup(c *gin.Context) {
|
||||
response.Created(c, h.newKeyGroupResponse(keyGroup, 0))
|
||||
}
|
||||
|
||||
// 统一的处理器可以处理两种情况:
|
||||
// 1. GET /keygroups - 返回所有组的列表
|
||||
// 2. GET /keygroups/:id - 返回指定ID的单个组
|
||||
func (h *KeyGroupHandler) GetKeyGroups(c *gin.Context) {
|
||||
// Case 1: Get a single group
|
||||
if idStr := c.Param("id"); idStr != "" {
|
||||
group, apiErr := h.getGroupFromContext(c)
|
||||
if apiErr != nil {
|
||||
@@ -265,7 +242,6 @@ func (h *KeyGroupHandler) GetKeyGroups(c *gin.Context) {
|
||||
response.Success(c, detailedResponse)
|
||||
return
|
||||
}
|
||||
// Case 2: Get all groups
|
||||
allGroups := h.groupManager.GetAllGroups()
|
||||
responses := make([]KeyGroupResponse, 0, len(allGroups))
|
||||
for _, group := range allGroups {
|
||||
@@ -275,7 +251,6 @@ func (h *KeyGroupHandler) GetKeyGroups(c *gin.Context) {
|
||||
response.Success(c, responses)
|
||||
}
|
||||
|
||||
// UpdateKeyGroup
|
||||
func (h *KeyGroupHandler) UpdateKeyGroup(c *gin.Context) {
|
||||
group, apiErr := h.getGroupFromContext(c)
|
||||
if apiErr != nil {
|
||||
@@ -304,7 +279,6 @@ func (h *KeyGroupHandler) UpdateKeyGroup(c *gin.Context) {
|
||||
response.Success(c, h.newKeyGroupResponse(freshGroup, keyCount))
|
||||
}
|
||||
|
||||
// DeleteKeyGroup
|
||||
func (h *KeyGroupHandler) DeleteKeyGroup(c *gin.Context) {
|
||||
group, apiErr := h.getGroupFromContext(c)
|
||||
if apiErr != nil {
|
||||
@@ -320,14 +294,14 @@ func (h *KeyGroupHandler) DeleteKeyGroup(c *gin.Context) {
|
||||
response.Success(c, gin.H{"message": fmt.Sprintf("Group '%s' and its associated keys deleted successfully", groupName)})
|
||||
}
|
||||
|
||||
// GetKeyGroupStats
|
||||
func (h *KeyGroupHandler) GetKeyGroupStats(c *gin.Context) {
|
||||
group, apiErr := h.getGroupFromContext(c)
|
||||
if apiErr != nil {
|
||||
response.Error(c, apiErr)
|
||||
return
|
||||
}
|
||||
stats, err := h.queryService.GetGroupStats(group.ID)
|
||||
|
||||
stats, err := h.queryService.GetGroupStats(c.Request.Context(), group.ID)
|
||||
if err != nil {
|
||||
response.Error(c, errors.NewAPIError(errors.ErrDatabase, err.Error()))
|
||||
return
|
||||
@@ -350,7 +324,6 @@ func (h *KeyGroupHandler) CloneKeyGroup(c *gin.Context) {
|
||||
response.Created(c, h.newKeyGroupResponse(clonedGroup, keyCount))
|
||||
}
|
||||
|
||||
// 更新分组排序
|
||||
func (h *KeyGroupHandler) UpdateKeyGroupOrder(c *gin.Context) {
|
||||
var payload []service.UpdateOrderPayload
|
||||
if err := c.ShouldBindJSON(&payload); err != nil {
|
||||
|
||||
@@ -29,9 +29,7 @@ import (
|
||||
"gorm.io/datatypes"
|
||||
)
|
||||
|
||||
type proxyErrorKey int
|
||||
|
||||
const proxyErrKey proxyErrorKey = 0
|
||||
type proxyErrorContextKey struct{}
|
||||
|
||||
type ProxyHandler struct {
|
||||
resourceService *service.ResourceService
|
||||
@@ -81,45 +79,51 @@ func (h *ProxyHandler) HandleProxy(c *gin.Context) {
|
||||
h.handleListModelsRequest(c)
|
||||
return
|
||||
}
|
||||
requestBody, err := io.ReadAll(c.Request.Body)
|
||||
|
||||
maxBodySize := int64(h.settingsManager.GetSettings().MaxRequestBodySizeMB * 1024 * 1024)
|
||||
requestBody, err := io.ReadAll(io.LimitReader(c.Request.Body, maxBodySize))
|
||||
if err != nil {
|
||||
errToJSON(c, uuid.New().String(), errors.NewAPIError(errors.ErrBadRequest, "Failed to read request body"))
|
||||
errToJSON(c, uuid.New().String(), errors.NewAPIError(errors.ErrBadRequest, "Request body too large or failed to read"))
|
||||
return
|
||||
}
|
||||
c.Request.Body = io.NopCloser(bytes.NewReader(requestBody))
|
||||
c.Request.ContentLength = int64(len(requestBody))
|
||||
|
||||
modelName := h.channel.ExtractModel(c, requestBody)
|
||||
groupName := c.Param("group_name")
|
||||
isPreciseRouting := groupName != ""
|
||||
|
||||
if !isPreciseRouting && modelName == "" {
|
||||
errToJSON(c, uuid.New().String(), errors.NewAPIError(errors.ErrBadRequest, "Model not specified in the request body or URL"))
|
||||
errToJSON(c, uuid.New().String(), errors.NewAPIError(errors.ErrBadRequest, "Model not specified in request"))
|
||||
return
|
||||
}
|
||||
|
||||
initialResources, err := h.getResourcesForRequest(c, modelName, groupName, isPreciseRouting)
|
||||
if err != nil {
|
||||
if apiErr, ok := err.(*errors.APIError); ok {
|
||||
errToJSON(c, uuid.New().String(), apiErr)
|
||||
} else {
|
||||
errToJSON(c, uuid.New().String(), errors.NewAPIError(errors.ErrNoKeysAvailable, err.Error()))
|
||||
errToJSON(c, uuid.New().String(), errors.NewAPIError(errors.ErrInternalServer, "Failed to get initial resources"))
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
finalOpConfig, err := h.groupManager.BuildOperationalConfig(initialResources.KeyGroup)
|
||||
if err != nil {
|
||||
h.logger.WithError(err).Error("Failed to build operational config.")
|
||||
errToJSON(c, uuid.New().String(), errors.NewAPIError(errors.ErrInternalServer, "Failed to build operational configuration"))
|
||||
errToJSON(c, uuid.New().String(), errors.NewAPIError(errors.ErrInternalServer, "Failed to build operational config"))
|
||||
return
|
||||
}
|
||||
|
||||
initialResources.RequestConfig = h.buildFinalRequestConfig(h.settingsManager.GetSettings(), initialResources.RequestConfig)
|
||||
|
||||
isOpenAICompatible := h.channel.IsOpenAICompatibleRequest(c)
|
||||
if isOpenAICompatible {
|
||||
h.serveTransparentProxy(c, requestBody, initialResources, finalOpConfig, modelName, groupName, isPreciseRouting)
|
||||
return
|
||||
}
|
||||
|
||||
isStream := h.channel.IsStreamRequest(c, requestBody)
|
||||
systemSettings := h.settingsManager.GetSettings()
|
||||
useSmartGateway := finalOpConfig.EnableSmartGateway != nil && *finalOpConfig.EnableSmartGateway
|
||||
if useSmartGateway && isStream && systemSettings.EnableStreamingRetry {
|
||||
if useSmartGateway && isStream && h.settingsManager.GetSettings().EnableStreamingRetry {
|
||||
h.serveSmartStream(c, requestBody, initialResources, isPreciseRouting)
|
||||
} else {
|
||||
h.serveTransparentProxy(c, requestBody, initialResources, finalOpConfig, modelName, groupName, isPreciseRouting)
|
||||
@@ -129,226 +133,307 @@ func (h *ProxyHandler) HandleProxy(c *gin.Context) {
|
||||
func (h *ProxyHandler) serveTransparentProxy(c *gin.Context, requestBody []byte, initialResources *service.RequestResources, finalOpConfig *models.KeyGroupSettings, modelName, groupName string, isPreciseRouting bool) {
|
||||
startTime := time.Now()
|
||||
correlationID := uuid.New().String()
|
||||
|
||||
var finalRecorder *httptest.ResponseRecorder
|
||||
var lastUsedResources *service.RequestResources
|
||||
var finalProxyErr *errors.APIError
|
||||
var isSuccess bool
|
||||
var finalPromptTokens, finalCompletionTokens int
|
||||
var actualRetries int = 0
|
||||
defer func() {
|
||||
// 如果一次尝试都未成功(例如,在第一次获取资源时就失败),则不记录日志
|
||||
if lastUsedResources == nil {
|
||||
h.logger.WithField("id", correlationID).Warn("No resources were used, skipping final log event.")
|
||||
return
|
||||
}
|
||||
finalEvent := h.createLogEvent(c, startTime, correlationID, modelName, lastUsedResources, models.LogTypeFinal, isPreciseRouting)
|
||||
var finalPromptTokens, finalCompletionTokens, actualRetries int
|
||||
|
||||
finalEvent.RequestLog.LatencyMs = int(time.Since(startTime).Milliseconds())
|
||||
finalEvent.RequestLog.IsSuccess = isSuccess
|
||||
finalEvent.RequestLog.Retries = actualRetries
|
||||
if isSuccess {
|
||||
finalEvent.RequestLog.PromptTokens = finalPromptTokens
|
||||
finalEvent.RequestLog.CompletionTokens = finalCompletionTokens
|
||||
}
|
||||
defer h.publishFinalLogEvent(c, startTime, correlationID, modelName, lastUsedResources,
|
||||
finalRecorder, finalProxyErr, isSuccess, finalPromptTokens, finalCompletionTokens,
|
||||
actualRetries, isPreciseRouting)
|
||||
|
||||
// 确保即使在成功的情况下,如果recorder存在,也记录最终的状态码
|
||||
if finalRecorder != nil {
|
||||
finalEvent.RequestLog.StatusCode = finalRecorder.Code
|
||||
}
|
||||
if !isSuccess {
|
||||
// 将 finalProxyErr 的信息填充到 RequestLog 中
|
||||
if finalProxyErr != nil {
|
||||
finalEvent.Error = finalProxyErr // Error 字段用于事件传递,不会被序列化到数据库
|
||||
finalEvent.RequestLog.ErrorCode = finalProxyErr.Code
|
||||
finalEvent.RequestLog.ErrorMessage = finalProxyErr.Message
|
||||
} else if finalRecorder != nil {
|
||||
// 降级处理:如果 finalProxyErr 为空但 recorder 存在且失败
|
||||
apiErr := errors.NewAPIErrorWithUpstream(finalRecorder.Code, fmt.Sprintf("UPSTREAM_%d", finalRecorder.Code), "Request failed after all retries.")
|
||||
finalEvent.Error = apiErr
|
||||
finalEvent.RequestLog.ErrorCode = apiErr.Code
|
||||
finalEvent.RequestLog.ErrorMessage = apiErr.Message
|
||||
}
|
||||
}
|
||||
// 将完整的事件发布
|
||||
eventData, err := json.Marshal(finalEvent)
|
||||
if err != nil {
|
||||
h.logger.WithField("id", correlationID).WithError(err).Error("Failed to marshal final log event.")
|
||||
return
|
||||
}
|
||||
if err := h.store.Publish(models.TopicRequestFinished, eventData); err != nil {
|
||||
h.logger.WithField("id", correlationID).WithError(err).Error("Failed to publish final log event.")
|
||||
}
|
||||
}()
|
||||
var maxRetries int
|
||||
if isPreciseRouting {
|
||||
// For precise routing, use the group's setting. If not set, fall back to the global setting.
|
||||
if finalOpConfig.MaxRetries != nil {
|
||||
maxRetries = *finalOpConfig.MaxRetries
|
||||
} else {
|
||||
maxRetries = h.settingsManager.GetSettings().MaxRetries
|
||||
}
|
||||
} else {
|
||||
// For BasePool (intelligent aggregation), *always* use the global setting.
|
||||
maxRetries = h.settingsManager.GetSettings().MaxRetries
|
||||
}
|
||||
maxRetries := h.getMaxRetries(isPreciseRouting, finalOpConfig)
|
||||
totalAttempts := maxRetries + 1
|
||||
|
||||
for attempt := 1; attempt <= totalAttempts; attempt++ {
|
||||
if c.Request.Context().Err() != nil {
|
||||
h.logger.WithField("id", correlationID).Info("Client disconnected, aborting retry loop.")
|
||||
if finalProxyErr == nil {
|
||||
finalProxyErr = errors.NewAPIError(errors.ErrBadRequest, "Client connection closed")
|
||||
finalProxyErr = errors.NewAPIError(errors.ErrBadRequest, "Client disconnected")
|
||||
}
|
||||
break
|
||||
}
|
||||
var currentResources *service.RequestResources
|
||||
var err error
|
||||
if attempt == 1 {
|
||||
currentResources = initialResources
|
||||
} else {
|
||||
actualRetries = attempt - 1
|
||||
h.logger.WithField("id", correlationID).Infof("Retrying... getting new resources for attempt %d.", attempt)
|
||||
currentResources, err = h.getResourcesForRequest(c, modelName, groupName, isPreciseRouting)
|
||||
if err != nil {
|
||||
h.logger.WithField("id", correlationID).Errorf("Failed to get new resources for retry, aborting: %v", err)
|
||||
finalProxyErr = errors.NewAPIError(errors.ErrNoKeysAvailable, "Failed to get new resources for retry")
|
||||
break
|
||||
|
||||
resources, err := h.getResourcesForAttempt(c, attempt, initialResources, modelName, groupName, isPreciseRouting, correlationID)
|
||||
if err != nil {
|
||||
if apiErr, ok := err.(*errors.APIError); ok {
|
||||
finalProxyErr = apiErr
|
||||
} else {
|
||||
finalProxyErr = errors.NewAPIError(errors.ErrInternalServer, "Failed to get resources for retry")
|
||||
}
|
||||
break
|
||||
}
|
||||
lastUsedResources = resources
|
||||
if attempt > 1 {
|
||||
actualRetries = attempt - 1
|
||||
}
|
||||
|
||||
finalRequestConfig := h.buildFinalRequestConfig(h.settingsManager.GetSettings(), currentResources.RequestConfig)
|
||||
currentResources.RequestConfig = finalRequestConfig
|
||||
lastUsedResources = currentResources
|
||||
h.logger.WithField("id", correlationID).Infof("Attempt %d/%d with KeyID %d...", attempt, totalAttempts, currentResources.APIKey.ID)
|
||||
var attemptErr *errors.APIError
|
||||
var attemptIsSuccess bool
|
||||
recorder := httptest.NewRecorder()
|
||||
attemptStartTime := time.Now()
|
||||
connectTimeout := time.Duration(h.settingsManager.GetSettings().ConnectTimeoutSeconds) * time.Second
|
||||
ctx, cancel := context.WithTimeout(c.Request.Context(), connectTimeout)
|
||||
defer cancel()
|
||||
attemptReq := c.Request.Clone(ctx)
|
||||
attemptReq.Body = io.NopCloser(bytes.NewReader(requestBody))
|
||||
if currentResources.UpstreamEndpoint == nil || currentResources.UpstreamEndpoint.URL == "" {
|
||||
h.logger.WithField("id", correlationID).Errorf("Attempt %d failed: no upstream URL in resources.", attempt)
|
||||
isSuccess = false
|
||||
finalProxyErr = errors.NewAPIError(errors.ErrInternalServer, "No upstream URL configured for the selected resource")
|
||||
continue
|
||||
}
|
||||
h.transparentProxy.Director = func(req *http.Request) {
|
||||
targetURL, _ := url.Parse(currentResources.UpstreamEndpoint.URL)
|
||||
req.URL.Scheme = targetURL.Scheme
|
||||
req.URL.Host = targetURL.Host
|
||||
req.Host = targetURL.Host
|
||||
var pureClientPath string
|
||||
if isPreciseRouting {
|
||||
proxyPrefix := "/proxy/" + groupName
|
||||
pureClientPath = strings.TrimPrefix(req.URL.Path, proxyPrefix)
|
||||
} else {
|
||||
pureClientPath = req.URL.Path
|
||||
}
|
||||
finalPath := h.channel.RewritePath(targetURL.Path, pureClientPath)
|
||||
req.URL.Path = finalPath
|
||||
h.logger.WithFields(logrus.Fields{
|
||||
"correlation_id": correlationID,
|
||||
"attempt": attempt,
|
||||
"key_id": currentResources.APIKey.ID,
|
||||
"base_upstream_url": currentResources.UpstreamEndpoint.URL,
|
||||
"final_request_url": req.URL.String(),
|
||||
}).Infof("Director constructed final upstream request URL.")
|
||||
req.Header.Del("Authorization")
|
||||
h.channel.ModifyRequest(req, currentResources.APIKey)
|
||||
req.Header.Set("X-Correlation-ID", correlationID)
|
||||
*req = *req.WithContext(context.WithValue(req.Context(), proxyErrKey, &attemptErr))
|
||||
}
|
||||
transport := h.transparentProxy.Transport.(*http.Transport)
|
||||
if currentResources.ProxyConfig != nil {
|
||||
proxyURLStr := fmt.Sprintf("%s://%s", currentResources.ProxyConfig.Protocol, currentResources.ProxyConfig.Address)
|
||||
proxyURL, err := url.Parse(proxyURLStr)
|
||||
if err == nil {
|
||||
transport.Proxy = http.ProxyURL(proxyURL)
|
||||
}
|
||||
} else {
|
||||
transport.Proxy = http.ProxyFromEnvironment
|
||||
}
|
||||
h.transparentProxy.ModifyResponse = func(resp *http.Response) error {
|
||||
defer resp.Body.Close()
|
||||
var reader io.ReadCloser
|
||||
var err error
|
||||
isGzipped := resp.Header.Get("Content-Encoding") == "gzip"
|
||||
if isGzipped {
|
||||
reader, err = gzip.NewReader(resp.Body)
|
||||
if err != nil {
|
||||
h.logger.WithError(err).Error("Failed to create gzip reader")
|
||||
reader = resp.Body
|
||||
} else {
|
||||
resp.Header.Del("Content-Encoding")
|
||||
}
|
||||
defer reader.Close()
|
||||
} else {
|
||||
reader = resp.Body
|
||||
}
|
||||
bodyBytes, err := io.ReadAll(reader)
|
||||
if err != nil {
|
||||
attemptErr = errors.NewAPIError(errors.ErrBadGateway, "Failed to read upstream response: "+err.Error())
|
||||
resp.Body = io.NopCloser(bytes.NewReader([]byte(attemptErr.Message)))
|
||||
return nil
|
||||
}
|
||||
if resp.StatusCode < 400 {
|
||||
attemptIsSuccess = true
|
||||
finalPromptTokens, finalCompletionTokens = extractUsage(bodyBytes)
|
||||
} else {
|
||||
parsedMsg := errors.ParseUpstreamError(bodyBytes)
|
||||
attemptErr = errors.NewAPIErrorWithUpstream(resp.StatusCode, fmt.Sprintf("UPSTREAM_%d", resp.StatusCode), parsedMsg)
|
||||
}
|
||||
resp.Body = io.NopCloser(bytes.NewReader(bodyBytes))
|
||||
return nil
|
||||
}
|
||||
h.transparentProxy.ServeHTTP(recorder, attemptReq)
|
||||
finalRecorder = recorder
|
||||
finalProxyErr = attemptErr
|
||||
isSuccess = attemptIsSuccess
|
||||
h.resourceService.ReportRequestResult(currentResources, isSuccess, finalProxyErr)
|
||||
h.logger.WithField("id", correlationID).Infof("Attempt %d/%d with KeyID %d", attempt, totalAttempts, resources.APIKey.ID)
|
||||
|
||||
recorder, attemptErr, attemptSuccess := h.executeProxyAttempt(
|
||||
c, correlationID, requestBody, resources, isPreciseRouting, groupName,
|
||||
&finalPromptTokens, &finalCompletionTokens,
|
||||
)
|
||||
|
||||
finalRecorder, finalProxyErr, isSuccess = recorder, attemptErr, attemptSuccess
|
||||
h.resourceService.ReportRequestResult(resources, isSuccess, finalProxyErr)
|
||||
|
||||
if isSuccess {
|
||||
break
|
||||
}
|
||||
isUnretryableError := false
|
||||
if finalProxyErr != nil {
|
||||
if errors.IsUnretryableRequestError(finalProxyErr.Message) {
|
||||
isUnretryableError = true
|
||||
h.logger.WithField("id", correlationID).Warnf("Attempt %d failed with unretryable request error. Aborting retries. Message: %s", attempt, finalProxyErr.Message)
|
||||
}
|
||||
}
|
||||
if attempt >= totalAttempts || isUnretryableError {
|
||||
if h.shouldStopRetrying(attempt, totalAttempts, finalProxyErr, correlationID) {
|
||||
break
|
||||
}
|
||||
retryEvent := h.createLogEvent(c, startTime, correlationID, modelName, currentResources, models.LogTypeRetry, isPreciseRouting)
|
||||
retryEvent.LatencyMs = int(time.Since(attemptStartTime).Milliseconds())
|
||||
retryEvent.IsSuccess = false
|
||||
retryEvent.StatusCode = recorder.Code
|
||||
retryEvent.Retries = actualRetries
|
||||
if attemptErr != nil {
|
||||
retryEvent.Error = attemptErr
|
||||
retryEvent.ErrorCode = attemptErr.Code
|
||||
retryEvent.ErrorMessage = attemptErr.Message
|
||||
}
|
||||
eventData, _ := json.Marshal(retryEvent)
|
||||
_ = h.store.Publish(models.TopicRequestFinished, eventData)
|
||||
h.publishRetryLogEvent(c, startTime, correlationID, modelName, resources, recorder, attemptErr, actualRetries, isPreciseRouting)
|
||||
}
|
||||
if finalRecorder != nil {
|
||||
bodyBytes := finalRecorder.Body.Bytes()
|
||||
c.Writer.Header().Set("Content-Length", fmt.Sprintf("%d", len(bodyBytes)))
|
||||
for k, v := range finalRecorder.Header() {
|
||||
if strings.ToLower(k) != "content-length" {
|
||||
c.Writer.Header()[k] = v
|
||||
|
||||
h.writeFinalResponse(c, correlationID, finalRecorder, finalProxyErr)
|
||||
}
|
||||
|
||||
func (h *ProxyHandler) executeProxyAttempt(c *gin.Context, corrID string, body []byte, res *service.RequestResources, isPrecise bool, groupName string, pTokens, cTokens *int) (*httptest.ResponseRecorder, *errors.APIError, bool) {
|
||||
recorder := httptest.NewRecorder()
|
||||
var attemptErr *errors.APIError
|
||||
var isSuccess bool
|
||||
|
||||
connectTimeout := time.Duration(h.settingsManager.GetSettings().ConnectTimeoutSeconds) * time.Second
|
||||
ctx, cancel := context.WithTimeout(c.Request.Context(), connectTimeout)
|
||||
defer cancel()
|
||||
|
||||
attemptReq := c.Request.Clone(ctx)
|
||||
attemptReq.Body = io.NopCloser(bytes.NewReader(body))
|
||||
attemptReq.ContentLength = int64(len(body))
|
||||
|
||||
h.configureProxy(corrID, res, isPrecise, groupName, &attemptErr, &isSuccess, pTokens, cTokens)
|
||||
*attemptReq = *attemptReq.WithContext(context.WithValue(attemptReq.Context(), proxyErrorContextKey{}, &attemptErr))
|
||||
|
||||
h.transparentProxy.ServeHTTP(recorder, attemptReq)
|
||||
|
||||
return recorder, attemptErr, isSuccess
|
||||
}
|
||||
|
||||
func (h *ProxyHandler) configureProxy(corrID string, res *service.RequestResources, isPrecise bool, groupName string, attemptErr **errors.APIError, isSuccess *bool, pTokens, cTokens *int) {
|
||||
h.transparentProxy.Director = func(r *http.Request) {
|
||||
targetURL, _ := url.Parse(res.UpstreamEndpoint.URL)
|
||||
r.URL.Scheme, r.URL.Host, r.Host = targetURL.Scheme, targetURL.Host, targetURL.Host
|
||||
|
||||
var pureClientPath string
|
||||
if isPrecise {
|
||||
pureClientPath = strings.TrimPrefix(r.URL.Path, "/proxy/"+groupName)
|
||||
} else {
|
||||
pureClientPath = r.URL.Path
|
||||
}
|
||||
r.URL.Path = h.channel.RewritePath(targetURL.Path, pureClientPath)
|
||||
|
||||
r.Header.Del("Authorization")
|
||||
h.channel.ModifyRequest(r, res.APIKey)
|
||||
r.Header.Set("X-Correlation-ID", corrID)
|
||||
}
|
||||
|
||||
transport := h.transparentProxy.Transport.(*http.Transport)
|
||||
if res.ProxyConfig != nil {
|
||||
proxyURLStr := fmt.Sprintf("%s://%s", res.ProxyConfig.Protocol, res.ProxyConfig.Address)
|
||||
if proxyURL, err := url.Parse(proxyURLStr); err == nil {
|
||||
transport.Proxy = http.ProxyURL(proxyURL)
|
||||
} else {
|
||||
transport.Proxy = http.ProxyFromEnvironment
|
||||
}
|
||||
} else {
|
||||
transport.Proxy = http.ProxyFromEnvironment
|
||||
}
|
||||
|
||||
h.transparentProxy.ModifyResponse = h.createModifyResponseFunc(attemptErr, isSuccess, pTokens, cTokens)
|
||||
}
|
||||
|
||||
func (h *ProxyHandler) createModifyResponseFunc(attemptErr **errors.APIError, isSuccess *bool, pTokens, cTokens *int) func(*http.Response) error {
|
||||
return func(resp *http.Response) error {
|
||||
var reader io.ReadCloser = resp.Body
|
||||
if resp.Header.Get("Content-Encoding") == "gzip" {
|
||||
gzReader, err := gzip.NewReader(resp.Body)
|
||||
if err != nil {
|
||||
h.logger.WithError(err).Error("Failed to create gzip reader")
|
||||
} else {
|
||||
reader = gzReader
|
||||
resp.Header.Del("Content-Encoding")
|
||||
}
|
||||
}
|
||||
c.Writer.WriteHeader(finalRecorder.Code)
|
||||
c.Writer.Write(finalRecorder.Body.Bytes())
|
||||
} else {
|
||||
errToJSON(c, correlationID, finalProxyErr)
|
||||
defer reader.Close()
|
||||
|
||||
bodyBytes, err := io.ReadAll(reader)
|
||||
if err != nil {
|
||||
*attemptErr = errors.NewAPIError(errors.ErrBadGateway, "Failed to read upstream response")
|
||||
resp.Body = io.NopCloser(bytes.NewReader([]byte{}))
|
||||
return nil
|
||||
}
|
||||
|
||||
if resp.StatusCode < 400 {
|
||||
*isSuccess = true
|
||||
*pTokens, *cTokens = extractUsage(bodyBytes)
|
||||
} else {
|
||||
parsedMsg := errors.ParseUpstreamError(bodyBytes)
|
||||
*attemptErr = errors.NewAPIErrorWithUpstream(resp.StatusCode, fmt.Sprintf("UPSTREAM_%d", resp.StatusCode), parsedMsg)
|
||||
}
|
||||
resp.Body = io.NopCloser(bytes.NewReader(bodyBytes))
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func (h *ProxyHandler) transparentProxyErrorHandler(rw http.ResponseWriter, r *http.Request, err error) {
|
||||
corrID := r.Header.Get("X-Correlation-ID")
|
||||
log := h.logger.WithField("id", corrID)
|
||||
log.Errorf("Transparent proxy encountered an error: %v", err)
|
||||
|
||||
errPtr, ok := r.Context().Value(proxyErrorContextKey{}).(**errors.APIError)
|
||||
if !ok || errPtr == nil {
|
||||
log.Error("FATAL: proxyErrorContextKey not found in context for error handler.")
|
||||
defaultErr := errors.NewAPIError(errors.ErrBadGateway, "An unexpected proxy error occurred")
|
||||
writeErrorToResponse(rw, defaultErr)
|
||||
return
|
||||
}
|
||||
|
||||
if *errPtr == nil {
|
||||
if errors.IsClientNetworkError(err) {
|
||||
*errPtr = errors.NewAPIError(errors.ErrBadRequest, "Client connection closed")
|
||||
} else {
|
||||
*errPtr = errors.NewAPIError(errors.ErrBadGateway, err.Error())
|
||||
}
|
||||
}
|
||||
writeErrorToResponse(rw, *errPtr)
|
||||
}
|
||||
|
||||
func (h *ProxyHandler) getResourcesForAttempt(c *gin.Context, attempt int, initialResources *service.RequestResources, modelName, groupName string, isPreciseRouting bool, correlationID string) (*service.RequestResources, error) {
|
||||
if attempt == 1 {
|
||||
return initialResources, nil
|
||||
}
|
||||
h.logger.WithField("id", correlationID).Infof("Retrying... getting new resources for attempt %d.", attempt)
|
||||
resources, err := h.getResourcesForRequest(c, modelName, groupName, isPreciseRouting)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
finalRequestConfig := h.buildFinalRequestConfig(h.settingsManager.GetSettings(), resources.RequestConfig)
|
||||
resources.RequestConfig = finalRequestConfig
|
||||
return resources, nil
|
||||
}
|
||||
|
||||
func (h *ProxyHandler) shouldStopRetrying(attempt, totalAttempts int, err *errors.APIError, correlationID string) bool {
|
||||
if attempt >= totalAttempts {
|
||||
return true
|
||||
}
|
||||
if err != nil && errors.IsUnretryableRequestError(err.Message) {
|
||||
h.logger.WithField("id", correlationID).Warnf("Attempt failed with unretryable request error. Aborting retries. Message: %s", err.Message)
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (h *ProxyHandler) writeFinalResponse(c *gin.Context, corrID string, rec *httptest.ResponseRecorder, apiErr *errors.APIError) {
|
||||
if rec != nil {
|
||||
for k, v := range rec.Header() {
|
||||
c.Writer.Header()[k] = v
|
||||
}
|
||||
c.Writer.WriteHeader(rec.Code)
|
||||
c.Writer.Write(rec.Body.Bytes())
|
||||
} else if apiErr != nil {
|
||||
errToJSON(c, corrID, apiErr)
|
||||
} else {
|
||||
errToJSON(c, corrID, errors.NewAPIError(errors.ErrInternalServer, "An unknown error occurred"))
|
||||
}
|
||||
}
|
||||
|
||||
func (h *ProxyHandler) publishFinalLogEvent(c *gin.Context, startTime time.Time, corrID, modelName string, res *service.RequestResources, rec *httptest.ResponseRecorder, finalErr *errors.APIError, isSuccess bool, pTokens, cTokens, retries int, isPrecise bool) {
|
||||
if res == nil {
|
||||
h.logger.WithField("id", corrID).Warn("No resources were used, skipping final log event.")
|
||||
return
|
||||
}
|
||||
event := h.createLogEvent(c, startTime, corrID, modelName, res, models.LogTypeFinal, isPrecise)
|
||||
event.RequestLog.LatencyMs = int(time.Since(startTime).Milliseconds())
|
||||
event.RequestLog.IsSuccess = isSuccess
|
||||
event.RequestLog.Retries = retries
|
||||
if isSuccess {
|
||||
event.RequestLog.PromptTokens, event.RequestLog.CompletionTokens = pTokens, cTokens
|
||||
}
|
||||
if rec != nil {
|
||||
event.RequestLog.StatusCode = rec.Code
|
||||
}
|
||||
if !isSuccess {
|
||||
errToLog := finalErr
|
||||
if errToLog == nil && rec != nil {
|
||||
errToLog = errors.NewAPIErrorWithUpstream(rec.Code, fmt.Sprintf("UPSTREAM_%d", rec.Code), "Request failed after all retries.")
|
||||
}
|
||||
if errToLog != nil {
|
||||
event.Error = errToLog
|
||||
event.RequestLog.ErrorCode, event.RequestLog.ErrorMessage = errToLog.Code, errToLog.Message
|
||||
}
|
||||
}
|
||||
eventData, err := json.Marshal(event)
|
||||
if err != nil {
|
||||
h.logger.WithField("id", corrID).WithError(err).Error("Failed to marshal log event")
|
||||
return
|
||||
}
|
||||
if err := h.store.Publish(context.Background(), models.TopicRequestFinished, eventData); err != nil {
|
||||
h.logger.WithField("id", corrID).WithError(err).Error("Failed to publish log event")
|
||||
}
|
||||
}
|
||||
|
||||
func (h *ProxyHandler) publishRetryLogEvent(c *gin.Context, startTime time.Time, corrID, modelName string, res *service.RequestResources, rec *httptest.ResponseRecorder, attemptErr *errors.APIError, retries int, isPrecise bool) {
|
||||
retryEvent := h.createLogEvent(c, startTime, corrID, modelName, res, models.LogTypeRetry, isPrecise)
|
||||
retryEvent.RequestLog.LatencyMs = int(time.Since(startTime).Milliseconds())
|
||||
retryEvent.RequestLog.IsSuccess = false
|
||||
retryEvent.RequestLog.StatusCode = rec.Code
|
||||
retryEvent.RequestLog.Retries = retries
|
||||
if attemptErr != nil {
|
||||
retryEvent.Error = attemptErr
|
||||
retryEvent.RequestLog.ErrorCode, retryEvent.RequestLog.ErrorMessage = attemptErr.Code, attemptErr.Message
|
||||
}
|
||||
eventData, err := json.Marshal(retryEvent)
|
||||
if err != nil {
|
||||
h.logger.WithField("id", corrID).WithError(err).Error("Failed to marshal retry log event")
|
||||
return
|
||||
}
|
||||
if err := h.store.Publish(context.Background(), models.TopicRequestFinished, eventData); err != nil {
|
||||
h.logger.WithField("id", corrID).WithError(err).Error("Failed to publish retry log event")
|
||||
}
|
||||
}
|
||||
|
||||
func (h *ProxyHandler) buildFinalRequestConfig(globalSettings *models.SystemSettings, groupConfig *models.RequestConfig) *models.RequestConfig {
|
||||
finalConfig := &models.RequestConfig{
|
||||
CustomHeaders: make(datatypes.JSONMap),
|
||||
EnableStreamOptimizer: globalSettings.EnableStreamOptimizer,
|
||||
StreamMinDelay: globalSettings.StreamMinDelay,
|
||||
StreamMaxDelay: globalSettings.StreamMaxDelay,
|
||||
StreamShortTextThresh: globalSettings.StreamShortTextThresh,
|
||||
StreamLongTextThresh: globalSettings.StreamLongTextThresh,
|
||||
StreamChunkSize: globalSettings.StreamChunkSize,
|
||||
EnableFakeStream: globalSettings.EnableFakeStream,
|
||||
FakeStreamInterval: globalSettings.FakeStreamInterval,
|
||||
}
|
||||
for k, v := range globalSettings.CustomHeaders {
|
||||
finalConfig.CustomHeaders[k] = v
|
||||
}
|
||||
if groupConfig == nil {
|
||||
return finalConfig
|
||||
}
|
||||
groupConfigJSON, err := json.Marshal(groupConfig)
|
||||
if err != nil {
|
||||
h.logger.WithError(err).Error("Failed to marshal group request config for merging.")
|
||||
return finalConfig
|
||||
}
|
||||
if err := json.Unmarshal(groupConfigJSON, finalConfig); err != nil {
|
||||
h.logger.WithError(err).Error("Failed to unmarshal group request config for merging.")
|
||||
}
|
||||
return finalConfig
|
||||
}
|
||||
|
||||
func writeErrorToResponse(rw http.ResponseWriter, apiErr *errors.APIError) {
|
||||
if writer, ok := rw.(interface{ Written() bool }); ok && writer.Written() {
|
||||
return
|
||||
}
|
||||
rw.Header().Set("Content-Type", "application/json; charset=utf-8")
|
||||
rw.WriteHeader(apiErr.HTTPStatus)
|
||||
json.NewEncoder(rw).Encode(gin.H{"error": apiErr})
|
||||
}
|
||||
|
||||
func (h *ProxyHandler) serveSmartStream(c *gin.Context, requestBody []byte, resources *service.RequestResources, isPreciseRouting bool) {
|
||||
startTime := time.Now()
|
||||
correlationID := uuid.New().String()
|
||||
@@ -356,7 +441,7 @@ func (h *ProxyHandler) serveSmartStream(c *gin.Context, requestBody []byte, reso
|
||||
log.Info("Smart Gateway activated for streaming request.")
|
||||
var originalRequest models.GeminiRequest
|
||||
if err := json.Unmarshal(requestBody, &originalRequest); err != nil {
|
||||
errToJSON(c, correlationID, errors.NewAPIError(errors.ErrInvalidJSON, "Smart Gateway failed: Request body is not a valid Gemini native format. Error: "+err.Error()))
|
||||
errToJSON(c, correlationID, errors.NewAPIError(errors.ErrInvalidJSON, "Invalid request format for Smart Gateway"))
|
||||
return
|
||||
}
|
||||
systemSettings := h.settingsManager.GetSettings()
|
||||
@@ -367,8 +452,14 @@ func (h *ProxyHandler) serveSmartStream(c *gin.Context, requestBody []byte, reso
|
||||
if c.Writer.Status() > 0 {
|
||||
requestFinishedEvent.StatusCode = c.Writer.Status()
|
||||
}
|
||||
eventData, _ := json.Marshal(requestFinishedEvent)
|
||||
_ = h.store.Publish(models.TopicRequestFinished, eventData)
|
||||
eventData, err := json.Marshal(requestFinishedEvent)
|
||||
if err != nil {
|
||||
h.logger.WithField("id", correlationID).WithError(err).Error("Failed to marshal final log event for smart stream")
|
||||
return
|
||||
}
|
||||
if err := h.store.Publish(context.Background(), models.TopicRequestFinished, eventData); err != nil {
|
||||
h.logger.WithField("id", correlationID).WithError(err).Error("Failed to publish final log event for smart stream")
|
||||
}
|
||||
}()
|
||||
params := channel.SmartRequestParams{
|
||||
CorrelationID: correlationID,
|
||||
@@ -385,30 +476,6 @@ func (h *ProxyHandler) serveSmartStream(c *gin.Context, requestBody []byte, reso
|
||||
h.channel.ProcessSmartStreamRequest(c, params)
|
||||
}
|
||||
|
||||
func (h *ProxyHandler) transparentProxyErrorHandler(rw http.ResponseWriter, r *http.Request, err error) {
|
||||
correlationID := r.Header.Get("X-Correlation-ID")
|
||||
h.logger.WithField("id", correlationID).Errorf("Transparent proxy error: %v", err)
|
||||
proxyErrPtr, exists := r.Context().Value(proxyErrKey).(**errors.APIError)
|
||||
if !exists || proxyErrPtr == nil {
|
||||
h.logger.WithField("id", correlationID).Error("FATAL: proxyErrorKey not found in context for error handler.")
|
||||
return
|
||||
}
|
||||
if errors.IsClientNetworkError(err) {
|
||||
*proxyErrPtr = errors.NewAPIError(errors.ErrBadRequest, "Client connection closed")
|
||||
} else {
|
||||
*proxyErrPtr = errors.NewAPIError(errors.ErrBadGateway, err.Error())
|
||||
}
|
||||
if _, ok := rw.(*httptest.ResponseRecorder); ok {
|
||||
return
|
||||
}
|
||||
if writer, ok := rw.(interface{ Written() bool }); ok {
|
||||
if writer.Written() {
|
||||
return
|
||||
}
|
||||
}
|
||||
rw.WriteHeader((*proxyErrPtr).HTTPStatus)
|
||||
}
|
||||
|
||||
func (h *ProxyHandler) createLogEvent(c *gin.Context, startTime time.Time, corrID, modelName string, res *service.RequestResources, logType models.LogType, isPreciseRouting bool) *models.RequestFinishedEvent {
|
||||
event := &models.RequestFinishedEvent{
|
||||
RequestLog: models.RequestLog{
|
||||
@@ -435,7 +502,6 @@ func (h *ProxyHandler) createLogEvent(c *gin.Context, startTime time.Time, corrI
|
||||
}
|
||||
}
|
||||
if res != nil {
|
||||
// [核心修正] 填充到内嵌的 RequestLog 结构体中
|
||||
if res.APIKey != nil {
|
||||
event.RequestLog.KeyID = &res.APIKey.ID
|
||||
}
|
||||
@@ -444,7 +510,6 @@ func (h *ProxyHandler) createLogEvent(c *gin.Context, startTime time.Time, corrI
|
||||
}
|
||||
if res.UpstreamEndpoint != nil {
|
||||
event.RequestLog.UpstreamID = &res.UpstreamEndpoint.ID
|
||||
// UpstreamURL 是事件传递字段,不是数据库字段,所以在这里赋值是正确的
|
||||
event.UpstreamURL = &res.UpstreamEndpoint.URL
|
||||
}
|
||||
if res.ProxyConfig != nil {
|
||||
@@ -464,13 +529,15 @@ func (h *ProxyHandler) getResourcesForRequest(c *gin.Context, modelName string,
|
||||
return nil, errors.NewAPIError(errors.ErrInternalServer, "Invalid auth token type in context")
|
||||
}
|
||||
if isPreciseRouting {
|
||||
return h.resourceService.GetResourceFromGroup(authToken, groupName)
|
||||
} else {
|
||||
return h.resourceService.GetResourceFromBasePool(authToken, modelName)
|
||||
return h.resourceService.GetResourceFromGroup(c.Request.Context(), authToken, groupName)
|
||||
}
|
||||
return h.resourceService.GetResourceFromBasePool(c.Request.Context(), authToken, modelName)
|
||||
}
|
||||
|
||||
func errToJSON(c *gin.Context, corrID string, apiErr *errors.APIError) {
|
||||
if c.IsAborted() {
|
||||
return
|
||||
}
|
||||
c.JSON(apiErr.HTTPStatus, gin.H{
|
||||
"error": apiErr,
|
||||
"correlation_id": corrID,
|
||||
@@ -479,8 +546,8 @@ func errToJSON(c *gin.Context, corrID string, apiErr *errors.APIError) {
|
||||
|
||||
type bufferPool struct{}
|
||||
|
||||
func (b *bufferPool) Get() []byte { return make([]byte, 32*1024) }
|
||||
func (b *bufferPool) Put(bytes []byte) {}
|
||||
func (b *bufferPool) Get() []byte { return make([]byte, 32*1024) }
|
||||
func (b *bufferPool) Put(_ []byte) {}
|
||||
|
||||
func extractUsage(body []byte) (promptTokens int, completionTokens int) {
|
||||
var data struct {
|
||||
@@ -495,34 +562,11 @@ func extractUsage(body []byte) (promptTokens int, completionTokens int) {
|
||||
return 0, 0
|
||||
}
|
||||
|
||||
func (h *ProxyHandler) buildFinalRequestConfig(globalSettings *models.SystemSettings, groupConfig *models.RequestConfig) *models.RequestConfig {
|
||||
customHeadersJSON, _ := json.Marshal(globalSettings.CustomHeaders)
|
||||
var customHeadersMap datatypes.JSONMap
|
||||
_ = json.Unmarshal(customHeadersJSON, &customHeadersMap)
|
||||
finalConfig := &models.RequestConfig{
|
||||
CustomHeaders: customHeadersMap,
|
||||
EnableStreamOptimizer: globalSettings.EnableStreamOptimizer,
|
||||
StreamMinDelay: globalSettings.StreamMinDelay,
|
||||
StreamMaxDelay: globalSettings.StreamMaxDelay,
|
||||
StreamShortTextThresh: globalSettings.StreamShortTextThresh,
|
||||
StreamLongTextThresh: globalSettings.StreamLongTextThresh,
|
||||
StreamChunkSize: globalSettings.StreamChunkSize,
|
||||
EnableFakeStream: globalSettings.EnableFakeStream,
|
||||
FakeStreamInterval: globalSettings.FakeStreamInterval,
|
||||
func (h *ProxyHandler) getMaxRetries(isPreciseRouting bool, finalOpConfig *models.KeyGroupSettings) int {
|
||||
if isPreciseRouting && finalOpConfig.MaxRetries != nil {
|
||||
return *finalOpConfig.MaxRetries
|
||||
}
|
||||
if groupConfig == nil {
|
||||
return finalConfig
|
||||
}
|
||||
groupConfigJSON, err := json.Marshal(groupConfig)
|
||||
if err != nil {
|
||||
h.logger.WithError(err).Error("Failed to marshal group request config for merging.")
|
||||
return finalConfig
|
||||
}
|
||||
if err := json.Unmarshal(groupConfigJSON, finalConfig); err != nil {
|
||||
h.logger.WithError(err).Error("Failed to unmarshal group request config for merging.")
|
||||
return finalConfig
|
||||
}
|
||||
return finalConfig
|
||||
return h.settingsManager.GetSettings().MaxRetries
|
||||
}
|
||||
|
||||
func (h *ProxyHandler) handleListModelsRequest(c *gin.Context) {
|
||||
|
||||
@@ -33,7 +33,7 @@ func (h *TaskHandler) GetTaskStatus(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
taskStatus, err := h.taskService.GetStatus(taskID)
|
||||
taskStatus, err := h.taskService.GetStatus(c.Request.Context(), taskID)
|
||||
if err != nil {
|
||||
// TODO 可以根据 service 层返回的具体错误类型进行更精细的处理
|
||||
response.Error(c, errors.NewAPIError(errors.ErrResourceNotFound, err.Error()))
|
||||
|
||||
@@ -77,3 +77,9 @@ type APIKeyDetails struct {
|
||||
CooldownUntil *time.Time `json:"cooldown_until"`
|
||||
EncryptedKey string
|
||||
}
|
||||
|
||||
// SettingsManager 定义了系统设置管理器的抽象接口。
|
||||
|
||||
type SettingsManager interface {
|
||||
GetSettings() *SystemSettings
|
||||
}
|
||||
|
||||
@@ -11,6 +11,7 @@ type SystemSettings struct {
|
||||
BlacklistThreshold int `json:"blacklist_threshold" default:"3" name:"拉黑阈值" category:"密钥设置" desc:"一个Key连续失败多少次后进入冷却状态。"`
|
||||
KeyCooldownMinutes int `json:"key_cooldown_minutes" default:"10" name:"密钥冷却时长(分钟)" category:"密钥设置" desc:"一个Key进入冷却状态后需要等待的时间,单位为分钟。"`
|
||||
LogFlushIntervalSeconds int `json:"log_flush_interval_seconds" default:"10" name:"日志刷新间隔(秒)" category:"日志设置" desc:"异步日志写入数据库的间隔时间(秒)。"`
|
||||
MaxRequestBodySizeMB int `json:"max_request_body_size_mb" default:"10" name:"最大请求体大小 (MB)" category:"请求设置" desc:"允许代理接收的最大请求体大小,单位为MB。超过此大小的请求将被拒绝。"`
|
||||
|
||||
PollingStrategy PollingStrategy `json:"polling_strategy" default:"random" name:"全局轮询策略" category:"调度设置" desc:"智能聚合模式下,从所有可用密钥中选择一个的默认策略。可选值: sequential(顺序), random(随机), weighted(加权)。"`
|
||||
|
||||
@@ -41,6 +42,10 @@ type SystemSettings struct {
|
||||
MaxLoginAttempts int `json:"max_login_attempts" default:"5" name:"最大登录失败次数" category:"安全设置" desc:"在一个IP被封禁前,允许的连续登录失败次数。"`
|
||||
IPBanDurationMinutes int `json:"ip_ban_duration_minutes" default:"15" name:"IP封禁时长(分钟)" category:"安全设置" desc:"IP被封禁的时长,单位为分钟。"`
|
||||
|
||||
// BasePool 相关配置
|
||||
// BasePoolTTLMinutes int `json:"base_pool_ttl_minutes" default:"30" name:"基础资源池最大生存时间(分钟)" category:"基础资源池" desc:"一个动态构建的基础资源池(BasePool)在Redis中的最大生存时间。到期后即使仍在活跃使用也会被强制重建。"`
|
||||
// BasePoolTTIMinutes int `json:"base_pool_tti_minutes" default:"10" name:"基础资源池空闲超时(分钟)" category:"基础资源池" desc:"一个基础资源池(BasePool)在连续无请求后,自动销毁的空闲等待时间。"`
|
||||
|
||||
//智能网关
|
||||
LogTruncationLimit int `json:"log_truncation_limit" default:"8000" name:"日志截断长度" category:"日志设置" desc:"在日志中记录上游响应或错误时,保留的最大字符数。0表示不截断。"`
|
||||
EnableSmartGateway bool `json:"enable_smart_gateway" default:"false" name:"启用智能网关" category:"代理设置" desc:"开启后,系统将对流式请求进行智能中断续传、错误标准化等优化。关闭后,系统将作为一个纯净、无干扰的透明代理。"`
|
||||
|
||||
@@ -1,13 +1,15 @@
|
||||
// Filename: internal/repository/key_cache.go
|
||||
// Filename: internal/repository/key_cache.go (最终定稿)
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"gemini-balancer/internal/models"
|
||||
"strconv"
|
||||
)
|
||||
|
||||
// --- Redis Key 常量定义 ---
|
||||
const (
|
||||
KeyGroup = "group:%d:keys:active"
|
||||
KeyDetails = "key:%d:details"
|
||||
@@ -22,13 +24,16 @@ const (
|
||||
BasePoolRandomCooldown = "basepool:%s:keys:random:cooldown"
|
||||
)
|
||||
|
||||
func (r *gormKeyRepository) LoadAllKeysToStore() error {
|
||||
r.logger.Info("Starting to load all keys and associations into cache, including polling structures...")
|
||||
// LoadAllKeysToStore 从数据库加载所有密钥和映射关系,并完整重建Redis缓存。
|
||||
func (r *gormKeyRepository) LoadAllKeysToStore(ctx context.Context) error {
|
||||
r.logger.Info("Starting full cache rebuild for all keys and polling structures.")
|
||||
|
||||
var allMappings []*models.GroupAPIKeyMapping
|
||||
if err := r.db.Preload("APIKey").Find(&allMappings).Error; err != nil {
|
||||
return fmt.Errorf("failed to load all mappings with APIKeys from DB: %w", err)
|
||||
return fmt.Errorf("failed to load mappings with preloaded APIKeys: %w", err)
|
||||
}
|
||||
|
||||
// 1. 批量解密所有涉及的密钥
|
||||
keyMap := make(map[uint]*models.APIKey)
|
||||
for _, m := range allMappings {
|
||||
if m.APIKey != nil {
|
||||
@@ -40,16 +45,16 @@ func (r *gormKeyRepository) LoadAllKeysToStore() error {
|
||||
keysToDecrypt = append(keysToDecrypt, *k)
|
||||
}
|
||||
if err := r.decryptKeys(keysToDecrypt); err != nil {
|
||||
r.logger.WithError(err).Error("Critical error during cache preload: batch decryption failed.")
|
||||
r.logger.WithError(err).Error("Batch decryption failed during cache rebuild.")
|
||||
// 即使解密失败,也继续尝试加载未加密或已解密的部分
|
||||
}
|
||||
decryptedKeyMap := make(map[uint]models.APIKey)
|
||||
for _, k := range keysToDecrypt {
|
||||
decryptedKeyMap[k.ID] = k
|
||||
}
|
||||
|
||||
activeKeysByGroup := make(map[uint][]*models.GroupAPIKeyMapping)
|
||||
pipe := r.store.Pipeline()
|
||||
detailsToSet := make(map[string][]byte)
|
||||
// 2. 清理所有分组的旧轮询结构
|
||||
pipe := r.store.Pipeline(ctx)
|
||||
var allGroups []*models.KeyGroup
|
||||
if err := r.db.Find(&allGroups).Error; err == nil {
|
||||
for _, group := range allGroups {
|
||||
@@ -62,26 +67,41 @@ func (r *gormKeyRepository) LoadAllKeysToStore() error {
|
||||
)
|
||||
}
|
||||
} else {
|
||||
r.logger.WithError(err).Error("Failed to get all groups for cache cleanup")
|
||||
r.logger.WithError(err).Error("Failed to get groups for cache cleanup; proceeding with rebuild.")
|
||||
}
|
||||
|
||||
// 3. 准备批量更新数据
|
||||
activeKeysByGroup := make(map[uint][]*models.GroupAPIKeyMapping)
|
||||
detailsToSet := make(map[string]any)
|
||||
|
||||
for _, mapping := range allMappings {
|
||||
if mapping.APIKey == nil {
|
||||
continue
|
||||
}
|
||||
decryptedKey, ok := decryptedKeyMap[mapping.APIKeyID]
|
||||
if !ok {
|
||||
continue
|
||||
continue // 跳过解密失败的密钥
|
||||
}
|
||||
|
||||
// 准备 KeyDetails 和 KeyMapping 的 MSet 数据
|
||||
keyJSON, _ := json.Marshal(decryptedKey)
|
||||
detailsToSet[fmt.Sprintf(KeyDetails, decryptedKey.ID)] = keyJSON
|
||||
mappingJSON, _ := json.Marshal(mapping)
|
||||
detailsToSet[fmt.Sprintf(KeyMapping, mapping.KeyGroupID, decryptedKey.ID)] = mappingJSON
|
||||
|
||||
if mapping.Status == models.StatusActive {
|
||||
activeKeysByGroup[mapping.KeyGroupID] = append(activeKeysByGroup[mapping.KeyGroupID], mapping)
|
||||
}
|
||||
}
|
||||
|
||||
// 4. 使用 MSet 批量写入详情和映射缓存
|
||||
if len(detailsToSet) > 0 {
|
||||
if err := r.store.MSet(ctx, detailsToSet); err != nil {
|
||||
r.logger.WithError(err).Error("Failed to MSet key details and mappings during cache rebuild.")
|
||||
}
|
||||
}
|
||||
|
||||
// 5. 在Pipeline中重建所有分组的轮询结构
|
||||
for groupID, activeMappings := range activeKeysByGroup {
|
||||
if len(activeMappings) == 0 {
|
||||
continue
|
||||
@@ -100,22 +120,19 @@ func (r *gormKeyRepository) LoadAllKeysToStore() error {
|
||||
pipe.SAdd(fmt.Sprintf(KeyGroup, groupID), activeKeyIDs...)
|
||||
pipe.LPush(fmt.Sprintf(KeyGroupSequential, groupID), activeKeyIDs...)
|
||||
pipe.SAdd(fmt.Sprintf(KeyGroupRandomMain, groupID), activeKeyIDs...)
|
||||
go r.store.ZAdd(fmt.Sprintf(KeyGroupLRU, groupID), lruMembers)
|
||||
pipe.ZAdd(fmt.Sprintf(KeyGroupLRU, groupID), lruMembers)
|
||||
}
|
||||
|
||||
// 6. 执行Pipeline
|
||||
if err := pipe.Exec(); err != nil {
|
||||
return fmt.Errorf("failed to execute pipeline for cache rebuild: %w", err)
|
||||
}
|
||||
for key, value := range detailsToSet {
|
||||
if err := r.store.Set(key, value, 0); err != nil {
|
||||
r.logger.WithError(err).Warnf("Failed to set key detail in cache: %s", key)
|
||||
}
|
||||
return fmt.Errorf("pipeline execution for polling structures failed: %w", err)
|
||||
}
|
||||
|
||||
r.logger.Info("Cache rebuild complete, including all polling structures.")
|
||||
r.logger.Info("Full cache rebuild completed successfully.")
|
||||
return nil
|
||||
}
|
||||
|
||||
// updateStoreCacheForKey 更新单个APIKey的详情缓存 (K-V)。
|
||||
func (r *gormKeyRepository) updateStoreCacheForKey(key *models.APIKey) error {
|
||||
if err := r.decryptKey(key); err != nil {
|
||||
return fmt.Errorf("failed to decrypt key %d for cache update: %w", key.ID, err)
|
||||
@@ -124,81 +141,104 @@ func (r *gormKeyRepository) updateStoreCacheForKey(key *models.APIKey) error {
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal key %d for cache update: %w", key.ID, err)
|
||||
}
|
||||
return r.store.Set(fmt.Sprintf(KeyDetails, key.ID), keyJSON, 0)
|
||||
return r.store.Set(context.Background(), fmt.Sprintf(KeyDetails, key.ID), keyJSON, 0)
|
||||
}
|
||||
|
||||
func (r *gormKeyRepository) removeStoreCacheForKey(key *models.APIKey) error {
|
||||
groupIDs, err := r.GetGroupsForKey(key.ID)
|
||||
// removeStoreCacheForKey 从所有缓存结构中彻底移除一个APIKey。
|
||||
func (r *gormKeyRepository) removeStoreCacheForKey(ctx context.Context, key *models.APIKey) error {
|
||||
groupIDs, err := r.GetGroupsForKey(ctx, key.ID)
|
||||
if err != nil {
|
||||
r.logger.Warnf("failed to get groups for key %d to clean up cache lists: %v", key.ID, err)
|
||||
r.logger.WithError(err).Warnf("Failed to get groups for key %d during cache removal, cleanup may be partial.", key.ID)
|
||||
}
|
||||
|
||||
pipe := r.store.Pipeline()
|
||||
pipe := r.store.Pipeline(ctx)
|
||||
pipe.Del(fmt.Sprintf(KeyDetails, key.ID))
|
||||
|
||||
keyIDStr := strconv.FormatUint(uint64(key.ID), 10)
|
||||
for _, groupID := range groupIDs {
|
||||
pipe.Del(fmt.Sprintf(KeyMapping, groupID, key.ID))
|
||||
|
||||
keyIDStr := strconv.FormatUint(uint64(key.ID), 10)
|
||||
pipe.SRem(fmt.Sprintf(KeyGroup, groupID), keyIDStr)
|
||||
pipe.LRem(fmt.Sprintf(KeyGroupSequential, groupID), 0, keyIDStr)
|
||||
pipe.SRem(fmt.Sprintf(KeyGroupRandomMain, groupID), keyIDStr)
|
||||
pipe.SRem(fmt.Sprintf(KeyGroupRandomCooldown, groupID), keyIDStr)
|
||||
go r.store.ZRem(fmt.Sprintf(KeyGroupLRU, groupID), keyIDStr)
|
||||
pipe.ZRem(fmt.Sprintf(KeyGroupLRU, groupID), keyIDStr)
|
||||
}
|
||||
|
||||
return pipe.Exec()
|
||||
}
|
||||
|
||||
// updateStoreCacheForMapping 根据单个映射关系的状态,原子性地更新所有相关的缓存结构。
|
||||
func (r *gormKeyRepository) updateStoreCacheForMapping(mapping *models.GroupAPIKeyMapping) error {
|
||||
pipe := r.store.Pipeline()
|
||||
activeKeyListKey := fmt.Sprintf("group:%d:keys:active", mapping.KeyGroupID)
|
||||
pipe.LRem(activeKeyListKey, 0, mapping.APIKeyID)
|
||||
keyIDStr := strconv.FormatUint(uint64(mapping.APIKeyID), 10)
|
||||
groupID := mapping.KeyGroupID
|
||||
ctx := context.Background()
|
||||
|
||||
pipe := r.store.Pipeline(ctx)
|
||||
|
||||
// 统一、无条件地从所有轮询结构中移除,确保状态清洁
|
||||
pipe.SRem(fmt.Sprintf(KeyGroup, groupID), keyIDStr)
|
||||
pipe.LRem(fmt.Sprintf(KeyGroupSequential, groupID), 0, keyIDStr)
|
||||
pipe.SRem(fmt.Sprintf(KeyGroupRandomMain, groupID), keyIDStr)
|
||||
pipe.SRem(fmt.Sprintf(KeyGroupRandomCooldown, groupID), keyIDStr)
|
||||
pipe.ZRem(fmt.Sprintf(KeyGroupLRU, groupID), keyIDStr)
|
||||
|
||||
// 如果新状态是 Active,则重新添加到所有轮询结构中
|
||||
if mapping.Status == models.StatusActive {
|
||||
pipe.LPush(activeKeyListKey, mapping.APIKeyID)
|
||||
pipe.SAdd(fmt.Sprintf(KeyGroup, groupID), keyIDStr)
|
||||
pipe.LPush(fmt.Sprintf(KeyGroupSequential, groupID), keyIDStr)
|
||||
pipe.SAdd(fmt.Sprintf(KeyGroupRandomMain, groupID), keyIDStr)
|
||||
|
||||
var score float64
|
||||
if mapping.LastUsedAt != nil {
|
||||
score = float64(mapping.LastUsedAt.UnixMilli())
|
||||
}
|
||||
pipe.ZAdd(fmt.Sprintf(KeyGroupLRU, groupID), map[string]float64{keyIDStr: score})
|
||||
}
|
||||
|
||||
// 无论状态如何,都更新映射详情的 K-V 缓存
|
||||
mappingJSON, err := json.Marshal(mapping)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal mapping: %w", err)
|
||||
}
|
||||
pipe.Set(fmt.Sprintf(KeyMapping, groupID, mapping.APIKeyID), mappingJSON, 0)
|
||||
|
||||
return pipe.Exec()
|
||||
}
|
||||
|
||||
func (r *gormKeyRepository) HandleCacheUpdateEventBatch(mappings []*models.GroupAPIKeyMapping) error {
|
||||
// HandleCacheUpdateEventBatch 批量、原子性地更新多个映射关系的缓存。
|
||||
func (r *gormKeyRepository) HandleCacheUpdateEventBatch(ctx context.Context, mappings []*models.GroupAPIKeyMapping) error {
|
||||
if len(mappings) == 0 {
|
||||
return nil
|
||||
}
|
||||
groupUpdates := make(map[uint]struct {
|
||||
ToAdd []interface{}
|
||||
ToRemove []interface{}
|
||||
})
|
||||
|
||||
pipe := r.store.Pipeline(ctx)
|
||||
|
||||
for _, mapping := range mappings {
|
||||
keyIDStr := strconv.FormatUint(uint64(mapping.APIKeyID), 10)
|
||||
update, ok := groupUpdates[mapping.KeyGroupID]
|
||||
if !ok {
|
||||
update = struct {
|
||||
ToAdd []interface{}
|
||||
ToRemove []interface{}
|
||||
}{}
|
||||
}
|
||||
groupID := mapping.KeyGroupID
|
||||
|
||||
// 对于批处理中的每一个mapping,都执行完整的、正确的“先删后增”逻辑
|
||||
pipe.SRem(fmt.Sprintf(KeyGroup, groupID), keyIDStr)
|
||||
pipe.LRem(fmt.Sprintf(KeyGroupSequential, groupID), 0, keyIDStr)
|
||||
pipe.SRem(fmt.Sprintf(KeyGroupRandomMain, groupID), keyIDStr)
|
||||
pipe.SRem(fmt.Sprintf(KeyGroupRandomCooldown, groupID), keyIDStr)
|
||||
pipe.ZRem(fmt.Sprintf(KeyGroupLRU, groupID), keyIDStr)
|
||||
|
||||
if mapping.Status == models.StatusActive {
|
||||
update.ToRemove = append(update.ToRemove, keyIDStr)
|
||||
update.ToAdd = append(update.ToAdd, keyIDStr)
|
||||
} else {
|
||||
update.ToRemove = append(update.ToRemove, keyIDStr)
|
||||
}
|
||||
groupUpdates[mapping.KeyGroupID] = update
|
||||
}
|
||||
pipe := r.store.Pipeline()
|
||||
var pipelineError error
|
||||
for groupID, updates := range groupUpdates {
|
||||
activeKeyListKey := fmt.Sprintf("group:%d:keys:active", groupID)
|
||||
if len(updates.ToRemove) > 0 {
|
||||
for _, keyID := range updates.ToRemove {
|
||||
pipe.LRem(activeKeyListKey, 0, keyID)
|
||||
pipe.SAdd(fmt.Sprintf(KeyGroup, groupID), keyIDStr)
|
||||
pipe.LPush(fmt.Sprintf(KeyGroupSequential, groupID), keyIDStr)
|
||||
pipe.SAdd(fmt.Sprintf(KeyGroupRandomMain, groupID), keyIDStr)
|
||||
|
||||
var score float64
|
||||
if mapping.LastUsedAt != nil {
|
||||
score = float64(mapping.LastUsedAt.UnixMilli())
|
||||
}
|
||||
pipe.ZAdd(fmt.Sprintf(KeyGroupLRU, groupID), map[string]float64{keyIDStr: score})
|
||||
}
|
||||
if len(updates.ToAdd) > 0 {
|
||||
pipe.LPush(activeKeyListKey, updates.ToAdd...)
|
||||
}
|
||||
|
||||
mappingJSON, _ := json.Marshal(mapping) // 在批处理中忽略单个marshal错误,以保证大部分更新成功
|
||||
pipe.Set(fmt.Sprintf(KeyMapping, groupID, mapping.APIKeyID), mappingJSON, 0)
|
||||
}
|
||||
if err := pipe.Exec(); err != nil {
|
||||
pipelineError = fmt.Errorf("redis pipeline execution failed: %w", err)
|
||||
}
|
||||
return pipelineError
|
||||
|
||||
return pipe.Exec()
|
||||
}
|
||||
|
||||
@@ -7,6 +7,7 @@ import (
|
||||
"fmt"
|
||||
"gemini-balancer/internal/models"
|
||||
|
||||
"context"
|
||||
"math/rand"
|
||||
"strings"
|
||||
"time"
|
||||
@@ -22,7 +23,6 @@ func (r *gormKeyRepository) AddKeys(keys []models.APIKey) ([]models.APIKey, erro
|
||||
keyHashes := make([]string, len(keys))
|
||||
keyValueToHashMap := make(map[string]string)
|
||||
for i, k := range keys {
|
||||
// All incoming keys must have plaintext APIKey
|
||||
if k.APIKey == "" {
|
||||
return nil, fmt.Errorf("cannot add key at index %d: plaintext APIKey is empty", i)
|
||||
}
|
||||
@@ -34,7 +34,6 @@ func (r *gormKeyRepository) AddKeys(keys []models.APIKey) ([]models.APIKey, erro
|
||||
var finalKeys []models.APIKey
|
||||
err := r.db.Transaction(func(tx *gorm.DB) error {
|
||||
var existingKeys []models.APIKey
|
||||
// [MODIFIED] Query by hash to find existing keys.
|
||||
if err := tx.Unscoped().Where("api_key_hash IN ?", keyHashes).Find(&existingKeys).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -68,24 +67,20 @@ func (r *gormKeyRepository) AddKeys(keys []models.APIKey) ([]models.APIKey, erro
|
||||
}
|
||||
}
|
||||
if len(keysToCreate) > 0 {
|
||||
// [MODIFIED] Create now only provides encrypted data and hash.
|
||||
if err := tx.Clauses(clause.OnConflict{DoNothing: true}, clause.Returning{}).Create(&keysToCreate).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
// [MODIFIED] Final select uses hashes to retrieve all relevant keys.
|
||||
if err := tx.Where("api_key_hash IN ?", keyHashes).Find(&finalKeys).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
// [CRITICAL] Decrypt all keys before returning them to the service layer.
|
||||
|
||||
return r.decryptKeys(finalKeys)
|
||||
})
|
||||
return finalKeys, err
|
||||
}
|
||||
|
||||
func (r *gormKeyRepository) Update(key *models.APIKey) error {
|
||||
// [CRITICAL] Before saving, check if the plaintext APIKey field was populated.
|
||||
// This indicates a potential change that needs to be re-encrypted.
|
||||
if key.APIKey != "" {
|
||||
encryptedKey, err := r.crypto.Encrypt(key.APIKey)
|
||||
if err != nil {
|
||||
@@ -97,16 +92,16 @@ func (r *gormKeyRepository) Update(key *models.APIKey) error {
|
||||
key.APIKeyHash = hex.EncodeToString(hash[:])
|
||||
}
|
||||
err := r.executeTransactionWithRetry(func(tx *gorm.DB) error {
|
||||
// GORM automatically ignores `key.APIKey` because of the `gorm:"-"` tag.
|
||||
|
||||
return tx.Save(key).Error
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
// For the cache update, we need the plaintext. Decrypt if it's not already populated.
|
||||
|
||||
if err := r.decryptKey(key); err != nil {
|
||||
r.logger.Warnf("DB updated key ID %d, but decryption for cache failed: %v", key.ID, err)
|
||||
return nil // Continue without cache update if decryption fails.
|
||||
return nil
|
||||
}
|
||||
if err := r.updateStoreCacheForKey(key); err != nil {
|
||||
r.logger.Warnf("DB updated key ID %d, but cache update failed: %v", key.ID, err)
|
||||
@@ -115,7 +110,7 @@ func (r *gormKeyRepository) Update(key *models.APIKey) error {
|
||||
}
|
||||
|
||||
func (r *gormKeyRepository) HardDeleteByID(id uint) error {
|
||||
key, err := r.GetKeyByID(id) // This now returns a decrypted key
|
||||
key, err := r.GetKeyByID(id)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -125,7 +120,7 @@ func (r *gormKeyRepository) HardDeleteByID(id uint) error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := r.removeStoreCacheForKey(key); err != nil {
|
||||
if err := r.removeStoreCacheForKey(context.Background(), key); err != nil {
|
||||
r.logger.Warnf("DB deleted key ID %d, but cache removal failed: %v", id, err)
|
||||
}
|
||||
return nil
|
||||
@@ -140,16 +135,13 @@ func (r *gormKeyRepository) HardDeleteByValues(keyValues []string) (int64, error
|
||||
hash := sha256.Sum256([]byte(v))
|
||||
hashes[i] = hex.EncodeToString(hash[:])
|
||||
}
|
||||
// Find the full key objects first to update the cache later.
|
||||
var keysToDelete []models.APIKey
|
||||
// [MODIFIED] Find by hash.
|
||||
if err := r.db.Where("api_key_hash IN ?", hashes).Find(&keysToDelete).Error; err != nil {
|
||||
return 0, err
|
||||
}
|
||||
if len(keysToDelete) == 0 {
|
||||
return 0, nil
|
||||
}
|
||||
// Decrypt them to ensure cache has plaintext if needed.
|
||||
if err := r.decryptKeys(keysToDelete); err != nil {
|
||||
r.logger.Warnf("Decryption failed for keys before hard delete, cache removal may be impacted: %v", err)
|
||||
}
|
||||
@@ -167,7 +159,7 @@ func (r *gormKeyRepository) HardDeleteByValues(keyValues []string) (int64, error
|
||||
return 0, err
|
||||
}
|
||||
for i := range keysToDelete {
|
||||
if err := r.removeStoreCacheForKey(&keysToDelete[i]); err != nil {
|
||||
if err := r.removeStoreCacheForKey(context.Background(), &keysToDelete[i]); err != nil {
|
||||
r.logger.Warnf("DB deleted key ID %d, but cache removal failed: %v", keysToDelete[i].ID, err)
|
||||
}
|
||||
}
|
||||
@@ -194,7 +186,6 @@ func (r *gormKeyRepository) GetKeysByIDs(ids []uint) ([]models.APIKey, error) {
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// [CRITICAL] Decrypt before returning.
|
||||
return keys, r.decryptKeys(keys)
|
||||
}
|
||||
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"gemini-balancer/internal/models"
|
||||
@@ -110,13 +111,13 @@ func (r *gormKeyRepository) deleteOrphanKeysLogic(db *gorm.DB) (int64, error) {
|
||||
}
|
||||
|
||||
result := db.Delete(&models.APIKey{}, orphanKeyIDs)
|
||||
//result := db.Unscoped().Delete(&models.APIKey{}, orphanKeyIDs)
|
||||
if result.Error != nil {
|
||||
return 0, result.Error
|
||||
}
|
||||
|
||||
for i := range keysToDelete {
|
||||
if err := r.removeStoreCacheForKey(&keysToDelete[i]); err != nil {
|
||||
// [修正] 使用 context.Background() 调用已更新的缓存清理函数
|
||||
if err := r.removeStoreCacheForKey(context.Background(), &keysToDelete[i]); err != nil {
|
||||
r.logger.Warnf("DB deleted orphan key ID %d, but cache removal failed: %v", keysToDelete[i].ID, err)
|
||||
}
|
||||
}
|
||||
@@ -144,7 +145,7 @@ func (r *gormKeyRepository) GetActiveMasterKeys() ([]*models.APIKey, error) {
|
||||
return keys, nil
|
||||
}
|
||||
|
||||
func (r *gormKeyRepository) UpdateAPIKeyStatus(keyID uint, status models.MasterAPIKeyStatus) error {
|
||||
func (r *gormKeyRepository) UpdateAPIKeyStatus(ctx context.Context, keyID uint, status models.MasterAPIKeyStatus) error {
|
||||
err := r.executeTransactionWithRetry(func(tx *gorm.DB) error {
|
||||
result := tx.Model(&models.APIKey{}).
|
||||
Where("id = ?", keyID).
|
||||
@@ -160,7 +161,7 @@ func (r *gormKeyRepository) UpdateAPIKeyStatus(keyID uint, status models.MasterA
|
||||
if err == nil {
|
||||
r.logger.Infof("MasterStatus for key ID %d changed, triggering a full cache reload.", keyID)
|
||||
go func() {
|
||||
if err := r.LoadAllKeysToStore(); err != nil {
|
||||
if err := r.LoadAllKeysToStore(context.Background()); err != nil {
|
||||
r.logger.Errorf("Failed to reload cache after MasterStatus change for key ID %d: %v", keyID, err)
|
||||
}
|
||||
}()
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"errors"
|
||||
@@ -14,7 +15,7 @@ import (
|
||||
"gorm.io/gorm/clause"
|
||||
)
|
||||
|
||||
func (r *gormKeyRepository) LinkKeysToGroup(groupID uint, keyIDs []uint) error {
|
||||
func (r *gormKeyRepository) LinkKeysToGroup(ctx context.Context, groupID uint, keyIDs []uint) error {
|
||||
if len(keyIDs) == 0 {
|
||||
return nil
|
||||
}
|
||||
@@ -34,12 +35,12 @@ func (r *gormKeyRepository) LinkKeysToGroup(groupID uint, keyIDs []uint) error {
|
||||
}
|
||||
|
||||
for _, keyID := range keyIDs {
|
||||
r.store.SAdd(fmt.Sprintf("key:%d:groups", keyID), groupID)
|
||||
r.store.SAdd(context.Background(), fmt.Sprintf("key:%d:groups", keyID), groupID)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *gormKeyRepository) UnlinkKeysFromGroup(groupID uint, keyIDs []uint) (int64, error) {
|
||||
func (r *gormKeyRepository) UnlinkKeysFromGroup(ctx context.Context, groupID uint, keyIDs []uint) (int64, error) {
|
||||
if len(keyIDs) == 0 {
|
||||
return 0, nil
|
||||
}
|
||||
@@ -63,16 +64,16 @@ func (r *gormKeyRepository) UnlinkKeysFromGroup(groupID uint, keyIDs []uint) (in
|
||||
|
||||
activeKeyListKey := fmt.Sprintf("group:%d:keys:active", groupID)
|
||||
for _, keyID := range keyIDs {
|
||||
r.store.SRem(fmt.Sprintf("key:%d:groups", keyID), groupID)
|
||||
r.store.LRem(activeKeyListKey, 0, strconv.Itoa(int(keyID)))
|
||||
r.store.SRem(context.Background(), fmt.Sprintf("key:%d:groups", keyID), groupID)
|
||||
r.store.LRem(context.Background(), activeKeyListKey, 0, strconv.Itoa(int(keyID)))
|
||||
}
|
||||
|
||||
return unlinkedCount, nil
|
||||
}
|
||||
|
||||
func (r *gormKeyRepository) GetGroupsForKey(keyID uint) ([]uint, error) {
|
||||
func (r *gormKeyRepository) GetGroupsForKey(ctx context.Context, keyID uint) ([]uint, error) {
|
||||
cacheKey := fmt.Sprintf("key:%d:groups", keyID)
|
||||
strGroupIDs, err := r.store.SMembers(cacheKey)
|
||||
strGroupIDs, err := r.store.SMembers(context.Background(), cacheKey)
|
||||
if err != nil || len(strGroupIDs) == 0 {
|
||||
var groupIDs []uint
|
||||
dbErr := r.db.Table("group_api_key_mappings").Where("api_key_id = ?", keyID).Pluck("key_group_id", &groupIDs).Error
|
||||
@@ -84,7 +85,7 @@ func (r *gormKeyRepository) GetGroupsForKey(keyID uint) ([]uint, error) {
|
||||
for _, id := range groupIDs {
|
||||
interfaceSlice = append(interfaceSlice, id)
|
||||
}
|
||||
r.store.SAdd(cacheKey, interfaceSlice...)
|
||||
r.store.SAdd(context.Background(), cacheKey, interfaceSlice...)
|
||||
}
|
||||
return groupIDs, nil
|
||||
}
|
||||
@@ -103,7 +104,7 @@ func (r *gormKeyRepository) GetMapping(groupID, keyID uint) (*models.GroupAPIKey
|
||||
return &mapping, err
|
||||
}
|
||||
|
||||
func (r *gormKeyRepository) UpdateMapping(mapping *models.GroupAPIKeyMapping) error {
|
||||
func (r *gormKeyRepository) UpdateMapping(ctx context.Context, mapping *models.GroupAPIKeyMapping) error {
|
||||
err := r.executeTransactionWithRetry(func(tx *gorm.DB) error {
|
||||
return tx.Save(mapping).Error
|
||||
})
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/sha1"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
@@ -17,40 +18,40 @@ import (
|
||||
)
|
||||
|
||||
const (
|
||||
CacheTTL = 5 * time.Minute
|
||||
EmptyPoolPlaceholder = "EMPTY_POOL"
|
||||
EmptyCacheTTL = 1 * time.Minute
|
||||
CacheTTL = 5 * time.Minute
|
||||
EmptyCacheTTL = 1 * time.Minute
|
||||
)
|
||||
|
||||
// SelectOneActiveKey 根据指定的轮询策略,从缓存中高效地选取一个可用的API密钥。
|
||||
|
||||
func (r *gormKeyRepository) SelectOneActiveKey(group *models.KeyGroup) (*models.APIKey, *models.GroupAPIKeyMapping, error) {
|
||||
// SelectOneActiveKey 根据指定的轮询策略,从单个密钥组缓存中选取一个可用的API密钥。
|
||||
func (r *gormKeyRepository) SelectOneActiveKey(ctx context.Context, group *models.KeyGroup) (*models.APIKey, *models.GroupAPIKeyMapping, error) {
|
||||
if group == nil {
|
||||
return nil, nil, fmt.Errorf("group cannot be nil")
|
||||
}
|
||||
var keyIDStr string
|
||||
var err error
|
||||
|
||||
switch group.PollingStrategy {
|
||||
case models.StrategySequential:
|
||||
sequentialKey := fmt.Sprintf(KeyGroupSequential, group.ID)
|
||||
keyIDStr, err = r.store.Rotate(sequentialKey)
|
||||
|
||||
keyIDStr, err = r.store.Rotate(ctx, sequentialKey)
|
||||
case models.StrategyWeighted:
|
||||
lruKey := fmt.Sprintf(KeyGroupLRU, group.ID)
|
||||
results, zerr := r.store.ZRange(lruKey, 0, 0)
|
||||
if zerr == nil && len(results) > 0 {
|
||||
keyIDStr = results[0]
|
||||
results, zerr := r.store.ZRange(ctx, lruKey, 0, 0)
|
||||
if zerr == nil {
|
||||
if len(results) > 0 {
|
||||
keyIDStr = results[0]
|
||||
} else {
|
||||
zerr = gorm.ErrRecordNotFound
|
||||
}
|
||||
}
|
||||
err = zerr
|
||||
|
||||
case models.StrategyRandom:
|
||||
mainPoolKey := fmt.Sprintf(KeyGroupRandomMain, group.ID)
|
||||
cooldownPoolKey := fmt.Sprintf(KeyGroupRandomCooldown, group.ID)
|
||||
keyIDStr, err = r.store.PopAndCycleSetMember(mainPoolKey, cooldownPoolKey)
|
||||
|
||||
default: // 默认或未指定策略时,使用基础的随机策略
|
||||
keyIDStr, err = r.store.PopAndCycleSetMember(ctx, mainPoolKey, cooldownPoolKey)
|
||||
default:
|
||||
activeKeySetKey := fmt.Sprintf(KeyGroup, group.ID)
|
||||
keyIDStr, err = r.store.SRandMember(activeKeySetKey)
|
||||
keyIDStr, err = r.store.SRandMember(ctx, activeKeySetKey)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
if errors.Is(err, store.ErrNotFound) || errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, nil, gorm.ErrRecordNotFound
|
||||
@@ -58,65 +59,70 @@ func (r *gormKeyRepository) SelectOneActiveKey(group *models.KeyGroup) (*models.
|
||||
r.logger.WithError(err).Errorf("Failed to select key for group %d with strategy %s", group.ID, group.PollingStrategy)
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
if keyIDStr == "" {
|
||||
return nil, nil, gorm.ErrRecordNotFound
|
||||
}
|
||||
|
||||
keyID, _ := strconv.ParseUint(keyIDStr, 10, 64)
|
||||
|
||||
apiKey, mapping, err := r.getKeyDetailsFromCache(uint(keyID), group.ID)
|
||||
keyID, parseErr := strconv.ParseUint(keyIDStr, 10, 64)
|
||||
if parseErr != nil {
|
||||
r.logger.WithError(parseErr).Errorf("Invalid key ID format in group %d cache: %s", group.ID, keyIDStr)
|
||||
return nil, nil, fmt.Errorf("invalid key ID in cache: %w", parseErr)
|
||||
}
|
||||
apiKey, mapping, err := r.getKeyDetailsFromCache(ctx, uint(keyID), group.ID)
|
||||
if err != nil {
|
||||
r.logger.WithError(err).Warnf("Cache inconsistency: Failed to get details for selected key ID %d", keyID)
|
||||
// TODO 可以在此加入重试逻辑,再次调用 SelectOneActiveKey(group)
|
||||
r.logger.WithError(err).Warnf("Cache inconsistency for key ID %d in group %d", keyID, group.ID)
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
if group.PollingStrategy == models.StrategyWeighted {
|
||||
go r.UpdateKeyUsageTimestamp(group.ID, uint(keyID))
|
||||
go func() {
|
||||
updateCtx, cancel := r.withTimeout(5 * time.Second)
|
||||
defer cancel()
|
||||
r.UpdateKeyUsageTimestamp(updateCtx, group.ID, uint(keyID))
|
||||
}()
|
||||
}
|
||||
|
||||
return apiKey, mapping, nil
|
||||
}
|
||||
|
||||
// SelectOneActiveKeyFromBasePool 为智能聚合模式设计的全新轮询器。
|
||||
func (r *gormKeyRepository) SelectOneActiveKeyFromBasePool(pool *BasePool) (*models.APIKey, *models.KeyGroup, error) {
|
||||
// 生成唯一的池ID,确保不同请求组合的轮询状态相互隔离
|
||||
poolID := generatePoolID(pool.CandidateGroups)
|
||||
// SelectOneActiveKeyFromBasePool 从智能聚合池中选取一个可用Key。
|
||||
func (r *gormKeyRepository) SelectOneActiveKeyFromBasePool(ctx context.Context, pool *BasePool) (*models.APIKey, *models.KeyGroup, error) {
|
||||
if pool == nil || len(pool.CandidateGroups) == 0 {
|
||||
return nil, nil, fmt.Errorf("invalid or empty base pool configuration")
|
||||
}
|
||||
poolID := r.generatePoolID(pool.CandidateGroups)
|
||||
log := r.logger.WithField("pool_id", poolID)
|
||||
|
||||
if err := r.ensureBasePoolCacheExists(pool, poolID); err != nil {
|
||||
log.WithError(err).Error("Failed to ensure BasePool cache exists.")
|
||||
if err := r.ensureBasePoolCacheExists(ctx, pool, poolID); err != nil {
|
||||
if !errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
log.WithError(err).Error("Failed to ensure BasePool cache exists")
|
||||
}
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
var keyIDStr string
|
||||
var err error
|
||||
|
||||
switch pool.PollingStrategy {
|
||||
case models.StrategySequential:
|
||||
sequentialKey := fmt.Sprintf(BasePoolSequential, poolID)
|
||||
keyIDStr, err = r.store.Rotate(sequentialKey)
|
||||
keyIDStr, err = r.store.Rotate(ctx, sequentialKey)
|
||||
case models.StrategyWeighted:
|
||||
lruKey := fmt.Sprintf(BasePoolLRU, poolID)
|
||||
results, zerr := r.store.ZRange(lruKey, 0, 0)
|
||||
if zerr == nil && len(results) > 0 {
|
||||
keyIDStr = results[0]
|
||||
results, zerr := r.store.ZRange(ctx, lruKey, 0, 0)
|
||||
if zerr == nil {
|
||||
if len(results) > 0 {
|
||||
keyIDStr = results[0]
|
||||
} else {
|
||||
zerr = gorm.ErrRecordNotFound
|
||||
}
|
||||
}
|
||||
err = zerr
|
||||
case models.StrategyRandom:
|
||||
mainPoolKey := fmt.Sprintf(BasePoolRandomMain, poolID)
|
||||
cooldownPoolKey := fmt.Sprintf(BasePoolRandomCooldown, poolID)
|
||||
keyIDStr, err = r.store.PopAndCycleSetMember(mainPoolKey, cooldownPoolKey)
|
||||
default: // 默认策略,应该在 ensureCache 中处理,但作为降级方案
|
||||
log.Warnf("Default polling strategy triggered inside selection. This should be rare.")
|
||||
|
||||
keyIDStr, err = r.store.PopAndCycleSetMember(ctx, mainPoolKey, cooldownPoolKey)
|
||||
default:
|
||||
log.Warnf("Unknown polling strategy '%s'. Using sequential as fallback.", pool.PollingStrategy)
|
||||
sequentialKey := fmt.Sprintf(BasePoolSequential, poolID)
|
||||
keyIDStr, err = r.store.LIndex(sequentialKey, 0)
|
||||
keyIDStr, err = r.store.Rotate(ctx, sequentialKey)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
if errors.Is(err, store.ErrNotFound) {
|
||||
if errors.Is(err, store.ErrNotFound) || errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, nil, gorm.ErrRecordNotFound
|
||||
}
|
||||
log.WithError(err).Errorf("Failed to select key from BasePool with strategy %s", pool.PollingStrategy)
|
||||
@@ -125,153 +131,266 @@ func (r *gormKeyRepository) SelectOneActiveKeyFromBasePool(pool *BasePool) (*mod
|
||||
if keyIDStr == "" {
|
||||
return nil, nil, gorm.ErrRecordNotFound
|
||||
}
|
||||
keyID, _ := strconv.ParseUint(keyIDStr, 10, 64)
|
||||
|
||||
for _, group := range pool.CandidateGroups {
|
||||
apiKey, mapping, cacheErr := r.getKeyDetailsFromCache(uint(keyID), group.ID)
|
||||
if cacheErr == nil && apiKey != nil && mapping != nil {
|
||||
|
||||
if pool.PollingStrategy == models.StrategyWeighted {
|
||||
|
||||
go r.updateKeyUsageTimestampForPool(poolID, uint(keyID))
|
||||
}
|
||||
return apiKey, group, nil
|
||||
go func() {
|
||||
bgCtx, cancel := r.withTimeout(5 * time.Second)
|
||||
defer cancel()
|
||||
r.refreshBasePoolHeartbeat(bgCtx, poolID)
|
||||
}()
|
||||
keyID, parseErr := strconv.ParseUint(keyIDStr, 10, 64)
|
||||
if parseErr != nil {
|
||||
log.WithError(parseErr).Errorf("Invalid key ID format in BasePool cache: %s", keyIDStr)
|
||||
return nil, nil, fmt.Errorf("invalid key ID in cache: %w", parseErr)
|
||||
}
|
||||
keyToGroupMapKey := fmt.Sprintf("basepool:%s:key_to_group", poolID)
|
||||
groupIDStr, err := r.store.HGet(ctx, keyToGroupMapKey, keyIDStr)
|
||||
if err != nil {
|
||||
log.WithError(err).Errorf("Cache inconsistency: KeyID %d found in pool but not in key-to-group map", keyID)
|
||||
return nil, nil, errors.New("cache inconsistency: key has no origin group mapping")
|
||||
}
|
||||
groupID, parseErr := strconv.ParseUint(groupIDStr, 10, 64)
|
||||
if parseErr != nil {
|
||||
log.WithError(parseErr).Errorf("Invalid group ID format in key-to-group map for key %d: %s", keyID, groupIDStr)
|
||||
return nil, nil, errors.New("cache inconsistency: invalid group id in mapping")
|
||||
}
|
||||
apiKey, _, err := r.getKeyDetailsFromCache(ctx, uint(keyID), uint(groupID))
|
||||
if err != nil {
|
||||
log.WithError(err).Warnf("Cache inconsistency: Failed to get details for key %d in mapped group %d", keyID, groupID)
|
||||
return nil, nil, err
|
||||
}
|
||||
var originGroup *models.KeyGroup
|
||||
for _, g := range pool.CandidateGroups {
|
||||
if g.ID == uint(groupID) {
|
||||
originGroup = g
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
log.Errorf("Cache inconsistency: Selected KeyID %d from BasePool but could not find its origin group.", keyID)
|
||||
return nil, nil, errors.New("cache inconsistency: selected key has no origin group")
|
||||
if originGroup == nil {
|
||||
log.Errorf("Logic error: Mapped GroupID %d not found in pool's candidate groups list", groupID)
|
||||
return nil, nil, errors.New("cache inconsistency: mapped group not in candidate list")
|
||||
}
|
||||
if pool.PollingStrategy == models.StrategyWeighted {
|
||||
go func() {
|
||||
bgCtx, cancel := r.withTimeout(5 * time.Second)
|
||||
defer cancel()
|
||||
r.updateKeyUsageTimestampForPool(bgCtx, poolID, uint(keyID))
|
||||
}()
|
||||
}
|
||||
return apiKey, originGroup, nil
|
||||
}
|
||||
|
||||
// ensureBasePoolCacheExists 动态创建 BasePool 的 Redis 结构
|
||||
func (r *gormKeyRepository) ensureBasePoolCacheExists(pool *BasePool, poolID string) error {
|
||||
listKey := fmt.Sprintf(BasePoolSequential, poolID)
|
||||
|
||||
// --- [逻辑优化] 提前处理“毒丸”,让逻辑更清晰 ---
|
||||
exists, err := r.store.Exists(listKey)
|
||||
if err != nil {
|
||||
r.logger.WithError(err).Errorf("Failed to check existence for pool_id '%s'", poolID)
|
||||
return err // 直接返回读取错误
|
||||
// ensureBasePoolCacheExists 动态创建或验证 BasePool 的 Redis 缓存结构。
|
||||
func (r *gormKeyRepository) ensureBasePoolCacheExists(ctx context.Context, pool *BasePool, poolID string) error {
|
||||
heartbeatKey := fmt.Sprintf("basepool:%s:heartbeat", poolID)
|
||||
emptyMarkerKey := fmt.Sprintf("basepool:empty:%s", poolID)
|
||||
// 预检查,快速失败
|
||||
if exists, _ := r.store.Exists(ctx, emptyMarkerKey); exists {
|
||||
return gorm.ErrRecordNotFound
|
||||
}
|
||||
if exists {
|
||||
val, err := r.store.LIndex(listKey, 0)
|
||||
if err != nil {
|
||||
// 如果连 LIndex 都失败,说明缓存可能已损坏,允许重建
|
||||
r.logger.WithError(err).Warnf("Cache for pool_id '%s' exists but is unreadable. Forcing rebuild.", poolID)
|
||||
} else {
|
||||
if val == EmptyPoolPlaceholder {
|
||||
return gorm.ErrRecordNotFound // 已知为空,直接返回
|
||||
}
|
||||
return nil // 缓存有效,直接返回
|
||||
}
|
||||
}
|
||||
// --- [锁机制优化] 增加分布式锁,防止并发构建时的惊群效应 ---
|
||||
lockKey := fmt.Sprintf("lock:basepool:%s", poolID)
|
||||
acquired, err := r.store.SetNX(lockKey, []byte("1"), 10*time.Second) // 10秒锁超时
|
||||
if err != nil {
|
||||
r.logger.WithError(err).Error("Failed to attempt acquiring distributed lock for basepool build.")
|
||||
return err
|
||||
}
|
||||
if !acquired {
|
||||
// 未获取到锁,等待一小段时间后重试,让持有锁的协程完成构建
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
return r.ensureBasePoolCacheExists(pool, poolID)
|
||||
}
|
||||
defer r.store.Del(lockKey) // 确保在函数退出时释放锁
|
||||
// 双重检查,防止在获取锁的间隙,已有其他协程完成了构建
|
||||
if exists, _ := r.store.Exists(listKey); exists {
|
||||
if exists, _ := r.store.Exists(ctx, heartbeatKey); exists {
|
||||
return nil
|
||||
}
|
||||
r.logger.Infof("BasePool cache for pool_id '%s' not found or is unreadable. Building now...", poolID)
|
||||
var allActiveKeyIDs []string
|
||||
lruMembers := make(map[string]float64)
|
||||
// 获取分布式锁
|
||||
lockKey := fmt.Sprintf("lock:basepool:%s", poolID)
|
||||
if err := r.acquireLock(ctx, lockKey); err != nil {
|
||||
return err // acquireLock 内部已记录日志并返回明确错误
|
||||
}
|
||||
defer r.releaseLock(context.Background(), lockKey)
|
||||
// 双重检查锁定
|
||||
if exists, _ := r.store.Exists(ctx, emptyMarkerKey); exists {
|
||||
return gorm.ErrRecordNotFound
|
||||
}
|
||||
if exists, _ := r.store.Exists(ctx, heartbeatKey); exists {
|
||||
return nil
|
||||
}
|
||||
// 在执行重度操作前,最后检查一次上下文是否已取消
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
default:
|
||||
}
|
||||
r.logger.Infof("Building BasePool cache for pool_id '%s'", poolID)
|
||||
// 手动聚合所有 Keys 并同时构建 key-to-group 映射
|
||||
keyToGroupMap := make(map[string]any)
|
||||
allKeyIDsSet := make(map[string]struct{})
|
||||
for _, group := range pool.CandidateGroups {
|
||||
activeKeySetKey := fmt.Sprintf(KeyGroup, group.ID)
|
||||
groupKeyIDs, err := r.store.SMembers(activeKeySetKey)
|
||||
|
||||
// --- [核心修正] ---
|
||||
// 这是整个问题的根源。我们绝不能在读取失败时,默默地`continue`。
|
||||
// 任何读取源数据的失败,都必须被视为一次构建过程的彻底失败,并立即中止。
|
||||
groupKeySet := fmt.Sprintf(KeyGroup, group.ID)
|
||||
groupKeyIDs, err := r.store.SMembers(ctx, groupKeySet)
|
||||
if err != nil {
|
||||
r.logger.WithError(err).Errorf("FATAL: Failed to read active keys for group %d during BasePool build. Aborting build process for pool_id '%s'.", group.ID, poolID)
|
||||
// 返回这个瞬时错误。这会导致本次请求失败,但绝不会写入“毒丸”,
|
||||
// 从而给了下一次请求一个全新的、成功的机会。
|
||||
return err
|
||||
r.logger.WithError(err).Warnf("Failed to get members for group %d during pool build", group.ID)
|
||||
continue
|
||||
}
|
||||
// 只有在 SMembers 成功时,才继续处理
|
||||
allActiveKeyIDs = append(allActiveKeyIDs, groupKeyIDs...)
|
||||
for _, keyIDStr := range groupKeyIDs {
|
||||
keyID, _ := strconv.ParseUint(keyIDStr, 10, 64)
|
||||
_, mapping, err := r.getKeyDetailsFromCache(uint(keyID), group.ID)
|
||||
if err == nil && mapping != nil {
|
||||
var score float64
|
||||
if mapping.LastUsedAt != nil {
|
||||
score = float64(mapping.LastUsedAt.UnixMilli())
|
||||
}
|
||||
lruMembers[keyIDStr] = score
|
||||
groupIDStr := strconv.FormatUint(uint64(group.ID), 10)
|
||||
for _, keyID := range groupKeyIDs {
|
||||
if _, exists := allKeyIDsSet[keyID]; !exists {
|
||||
allKeyIDsSet[keyID] = struct{}{}
|
||||
keyToGroupMap[keyID] = groupIDStr
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// --- [逻辑修正] ---
|
||||
// 只有在“我们成功读取了所有数据,但发现数据本身是空的”这种情况下,
|
||||
// 才允许写入“毒丸”。
|
||||
if len(allActiveKeyIDs) == 0 {
|
||||
r.logger.Warnf("No active keys found for any candidate groups for pool_id '%s'. Setting empty pool placeholder.", poolID)
|
||||
pipe := r.store.Pipeline()
|
||||
pipe.LPush(listKey, EmptyPoolPlaceholder)
|
||||
pipe.Expire(listKey, EmptyCacheTTL)
|
||||
if err := pipe.Exec(); err != nil {
|
||||
r.logger.WithError(err).Errorf("Failed to set empty pool placeholder for pool_id '%s'", poolID)
|
||||
// 处理空池情况
|
||||
if len(allKeyIDsSet) == 0 {
|
||||
emptyCacheTTL := time.Duration(r.config.Repository.BasePoolTTIMinutes) * time.Minute / 2
|
||||
if emptyCacheTTL < time.Minute {
|
||||
emptyCacheTTL = time.Minute
|
||||
}
|
||||
r.logger.Warnf("No active keys found for pool_id '%s', setting empty marker.", poolID)
|
||||
if err := r.store.Set(ctx, emptyMarkerKey, []byte("1"), emptyCacheTTL); err != nil {
|
||||
r.logger.WithError(err).Warnf("Failed to set empty marker for pool_id '%s'", poolID)
|
||||
}
|
||||
return gorm.ErrRecordNotFound
|
||||
}
|
||||
// 使用管道填充所有轮询结构
|
||||
pipe := r.store.Pipeline()
|
||||
// 1. 顺序
|
||||
pipe.LPush(fmt.Sprintf(BasePoolSequential, poolID), toInterfaceSlice(allActiveKeyIDs)...)
|
||||
// 2. 随机
|
||||
pipe.SAdd(fmt.Sprintf(BasePoolRandomMain, poolID), toInterfaceSlice(allActiveKeyIDs)...)
|
||||
|
||||
// 设置合理的过期时间,例如5分钟,以防止孤儿数据
|
||||
pipe.Expire(fmt.Sprintf(BasePoolSequential, poolID), CacheTTL)
|
||||
pipe.Expire(fmt.Sprintf(BasePoolRandomMain, poolID), CacheTTL)
|
||||
pipe.Expire(fmt.Sprintf(BasePoolRandomCooldown, poolID), CacheTTL)
|
||||
pipe.Expire(fmt.Sprintf(BasePoolLRU, poolID), CacheTTL)
|
||||
|
||||
allActiveKeyIDs := make([]string, 0, len(allKeyIDsSet))
|
||||
for keyID := range allKeyIDsSet {
|
||||
allActiveKeyIDs = append(allActiveKeyIDs, keyID)
|
||||
}
|
||||
// 使用 Pipeline 原子化构建所有缓存结构
|
||||
basePoolTTL := time.Duration(r.config.Repository.BasePoolTTLMinutes) * time.Minute
|
||||
basePoolTTI := time.Duration(r.config.Repository.BasePoolTTIMinutes) * time.Minute
|
||||
mainPoolKey := fmt.Sprintf(BasePoolRandomMain, poolID)
|
||||
sequentialKey := fmt.Sprintf(BasePoolSequential, poolID)
|
||||
cooldownKey := fmt.Sprintf(BasePoolRandomCooldown, poolID)
|
||||
lruKey := fmt.Sprintf(BasePoolLRU, poolID)
|
||||
keyToGroupMapKey := fmt.Sprintf("basepool:%s:key_to_group", poolID)
|
||||
pipe := r.store.Pipeline(ctx)
|
||||
pipe.Del(mainPoolKey, sequentialKey, cooldownKey, lruKey, emptyMarkerKey, keyToGroupMapKey)
|
||||
pipe.SAdd(mainPoolKey, r.toInterfaceSlice(allActiveKeyIDs)...)
|
||||
pipe.LPush(sequentialKey, r.toInterfaceSlice(allActiveKeyIDs)...)
|
||||
if len(keyToGroupMap) > 0 {
|
||||
pipe.HSet(keyToGroupMapKey, keyToGroupMap)
|
||||
pipe.Expire(keyToGroupMapKey, basePoolTTL)
|
||||
}
|
||||
pipe.Expire(mainPoolKey, basePoolTTL)
|
||||
pipe.Expire(sequentialKey, basePoolTTL)
|
||||
pipe.Expire(cooldownKey, basePoolTTL)
|
||||
pipe.Expire(lruKey, basePoolTTL)
|
||||
pipe.Set(heartbeatKey, []byte("1"), basePoolTTI)
|
||||
if err := pipe.Exec(); err != nil {
|
||||
r.logger.WithError(err).Errorf("Failed to populate polling structures for pool_id '%s'", poolID)
|
||||
cleanupCtx, cancel := r.withTimeout(5 * time.Second)
|
||||
defer cancel()
|
||||
r.store.Del(cleanupCtx, mainPoolKey, sequentialKey, cooldownKey, lruKey, heartbeatKey, emptyMarkerKey, keyToGroupMapKey)
|
||||
return err
|
||||
}
|
||||
|
||||
if len(lruMembers) > 0 {
|
||||
r.store.ZAdd(fmt.Sprintf(BasePoolLRU, poolID), lruMembers)
|
||||
}
|
||||
// 异步填充 LRU 缓存,并传入已构建好的映射
|
||||
go r.populateBasePoolLRUCache(context.Background(), poolID, allActiveKeyIDs, keyToGroupMap)
|
||||
return nil
|
||||
}
|
||||
|
||||
// --- 辅助方法 ---
|
||||
// acquireLock 封装了带重试和指数退避的分布式锁获取逻辑。
|
||||
func (r *gormKeyRepository) acquireLock(ctx context.Context, lockKey string) error {
|
||||
const (
|
||||
lockTTL = 30 * time.Second
|
||||
lockMaxRetries = 5
|
||||
lockBaseBackoff = 50 * time.Millisecond
|
||||
)
|
||||
for i := 0; i < lockMaxRetries; i++ {
|
||||
acquired, err := r.store.SetNX(ctx, lockKey, []byte("1"), lockTTL)
|
||||
if err != nil {
|
||||
r.logger.WithError(err).Error("Failed to attempt acquiring distributed lock")
|
||||
return err
|
||||
}
|
||||
if acquired {
|
||||
return nil
|
||||
}
|
||||
time.Sleep(lockBaseBackoff * (1 << i))
|
||||
}
|
||||
return fmt.Errorf("failed to acquire lock for key '%s' after %d retries", lockKey, lockMaxRetries)
|
||||
}
|
||||
|
||||
// releaseLock 封装了分布式锁的释放逻辑。
|
||||
func (r *gormKeyRepository) releaseLock(ctx context.Context, lockKey string) {
|
||||
if err := r.store.Del(ctx, lockKey); err != nil {
|
||||
r.logger.WithError(err).Errorf("Failed to release distributed lock for key '%s'", lockKey)
|
||||
}
|
||||
}
|
||||
|
||||
// withTimeout 是 context.WithTimeout 的一个简单包装,便于测试和模拟。
|
||||
func (r *gormKeyRepository) withTimeout(duration time.Duration) (context.Context, context.CancelFunc) {
|
||||
return context.WithTimeout(context.Background(), duration)
|
||||
}
|
||||
|
||||
// refreshBasePoolHeartbeat 异步刷新心跳Key的TTI
|
||||
func (r *gormKeyRepository) refreshBasePoolHeartbeat(ctx context.Context, poolID string) {
|
||||
basePoolTTI := time.Duration(r.config.Repository.BasePoolTTIMinutes) * time.Minute
|
||||
heartbeatKey := fmt.Sprintf("basepool:%s:heartbeat", poolID)
|
||||
// 使用 EXPIRE 命令来刷新,如果Key不存在,它什么也不做,是安全的
|
||||
if err := r.store.Expire(ctx, heartbeatKey, basePoolTTI); err != nil {
|
||||
if ctx.Err() == nil { // 避免在context取消后打印不必要的错误
|
||||
r.logger.WithError(err).Warnf("Failed to refresh heartbeat for pool_id '%s'", poolID)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// populateBasePoolLRUCache 异步填充 BasePool 的 LRU 缓存结构
|
||||
func (r *gormKeyRepository) populateBasePoolLRUCache(
|
||||
parentCtx context.Context,
|
||||
currentPoolID string,
|
||||
keys []string,
|
||||
keyToGroupMap map[string]any,
|
||||
) {
|
||||
lruMembers := make(map[string]float64, len(keys))
|
||||
for _, keyIDStr := range keys {
|
||||
select {
|
||||
case <-parentCtx.Done():
|
||||
return
|
||||
default:
|
||||
}
|
||||
groupIDStr, ok := keyToGroupMap[keyIDStr].(string)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
keyID, _ := strconv.ParseUint(keyIDStr, 10, 64)
|
||||
groupID, _ := strconv.ParseUint(groupIDStr, 10, 64)
|
||||
mappingKey := fmt.Sprintf(KeyMapping, groupID, keyID)
|
||||
data, err := r.store.Get(parentCtx, mappingKey)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
var mapping models.GroupAPIKeyMapping
|
||||
if json.Unmarshal(data, &mapping) == nil {
|
||||
var score float64
|
||||
if mapping.LastUsedAt != nil {
|
||||
score = float64(mapping.LastUsedAt.UnixMilli())
|
||||
}
|
||||
lruMembers[keyIDStr] = score
|
||||
}
|
||||
}
|
||||
if len(lruMembers) > 0 {
|
||||
lruKey := fmt.Sprintf(BasePoolLRU, currentPoolID)
|
||||
if err := r.store.ZAdd(parentCtx, lruKey, lruMembers); err != nil {
|
||||
if parentCtx.Err() == nil {
|
||||
r.logger.WithError(err).Warnf("Failed to populate LRU cache for pool '%s'", currentPoolID)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// updateKeyUsageTimestampForPool 更新 BasePool 的 LUR ZSET
|
||||
func (r *gormKeyRepository) updateKeyUsageTimestampForPool(poolID string, keyID uint) {
|
||||
func (r *gormKeyRepository) updateKeyUsageTimestampForPool(ctx context.Context, poolID string, keyID uint) {
|
||||
lruKey := fmt.Sprintf(BasePoolLRU, poolID)
|
||||
r.store.ZAdd(lruKey, map[string]float64{
|
||||
strconv.FormatUint(uint64(keyID), 10): nowMilli(),
|
||||
err := r.store.ZAdd(ctx, lruKey, map[string]float64{
|
||||
strconv.FormatUint(uint64(keyID), 10): r.nowMilli(),
|
||||
})
|
||||
if err != nil {
|
||||
r.logger.WithError(err).Warnf("Failed to update key usage for pool %s", poolID)
|
||||
}
|
||||
}
|
||||
|
||||
// generatePoolID 根据候选组ID列表生成一个稳定的、唯一的字符串ID
|
||||
func generatePoolID(groups []*models.KeyGroup) string {
|
||||
func (r *gormKeyRepository) generatePoolID(groups []*models.KeyGroup) string {
|
||||
ids := make([]int, len(groups))
|
||||
for i, g := range groups {
|
||||
ids[i] = int(g.ID)
|
||||
}
|
||||
sort.Ints(ids)
|
||||
|
||||
h := sha1.New()
|
||||
io.WriteString(h, fmt.Sprintf("%v", ids))
|
||||
return fmt.Sprintf("%x", h.Sum(nil))
|
||||
}
|
||||
|
||||
// toInterfaceSlice 类型转换辅助函数
|
||||
func toInterfaceSlice(slice []string) []interface{} {
|
||||
func (r *gormKeyRepository) toInterfaceSlice(slice []string) []interface{} {
|
||||
result := make([]interface{}, len(slice))
|
||||
for i, v := range slice {
|
||||
result[i] = v
|
||||
@@ -280,13 +399,13 @@ func toInterfaceSlice(slice []string) []interface{} {
|
||||
}
|
||||
|
||||
// nowMilli 返回当前的Unix毫秒时间戳,用于LRU/Weighted策略
|
||||
func nowMilli() float64 {
|
||||
func (r *gormKeyRepository) nowMilli() float64 {
|
||||
return float64(time.Now().UnixMilli())
|
||||
}
|
||||
|
||||
// getKeyDetailsFromCache 从缓存中获取Key和Mapping的JSON数据。
|
||||
func (r *gormKeyRepository) getKeyDetailsFromCache(keyID, groupID uint) (*models.APIKey, *models.GroupAPIKeyMapping, error) {
|
||||
apiKeyJSON, err := r.store.Get(fmt.Sprintf(KeyDetails, keyID))
|
||||
func (r *gormKeyRepository) getKeyDetailsFromCache(ctx context.Context, keyID, groupID uint) (*models.APIKey, *models.GroupAPIKeyMapping, error) {
|
||||
apiKeyJSON, err := r.store.Get(ctx, fmt.Sprintf(KeyDetails, keyID))
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("failed to get key details for key %d: %w", keyID, err)
|
||||
}
|
||||
@@ -295,7 +414,7 @@ func (r *gormKeyRepository) getKeyDetailsFromCache(keyID, groupID uint) (*models
|
||||
return nil, nil, fmt.Errorf("failed to unmarshal api key %d: %w", keyID, err)
|
||||
}
|
||||
|
||||
mappingJSON, err := r.store.Get(fmt.Sprintf(KeyMapping, groupID, keyID))
|
||||
mappingJSON, err := r.store.Get(ctx, fmt.Sprintf(KeyMapping, groupID, keyID))
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("failed to get mapping details for key %d in group %d: %w", keyID, groupID, err)
|
||||
}
|
||||
|
||||
@@ -1,7 +1,9 @@
|
||||
// Filename: internal/repository/key_writer.go
|
||||
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"gemini-balancer/internal/errors"
|
||||
"gemini-balancer/internal/models"
|
||||
@@ -9,7 +11,7 @@ import (
|
||||
"time"
|
||||
)
|
||||
|
||||
func (r *gormKeyRepository) UpdateKeyUsageTimestamp(groupID, keyID uint) {
|
||||
func (r *gormKeyRepository) UpdateKeyUsageTimestamp(ctx context.Context, groupID, keyID uint) {
|
||||
lruKey := fmt.Sprintf(KeyGroupLRU, groupID)
|
||||
timestamp := float64(time.Now().UnixMilli())
|
||||
|
||||
@@ -17,52 +19,51 @@ func (r *gormKeyRepository) UpdateKeyUsageTimestamp(groupID, keyID uint) {
|
||||
strconv.FormatUint(uint64(keyID), 10): timestamp,
|
||||
}
|
||||
|
||||
if err := r.store.ZAdd(lruKey, members); err != nil {
|
||||
if err := r.store.ZAdd(ctx, lruKey, members); err != nil {
|
||||
r.logger.WithError(err).Warnf("Failed to update usage timestamp for key %d in group %d", keyID, groupID)
|
||||
}
|
||||
}
|
||||
|
||||
func (r *gormKeyRepository) SyncKeyStatusInPollingCaches(groupID, keyID uint, newStatus models.APIKeyStatus) {
|
||||
func (r *gormKeyRepository) SyncKeyStatusInPollingCaches(ctx context.Context, groupID, keyID uint, newStatus models.APIKeyStatus) {
|
||||
r.logger.Infof("SYNC: Directly updating polling caches for G:%d K:%d -> %s", groupID, keyID, newStatus)
|
||||
r.updatePollingCachesLogic(groupID, keyID, newStatus)
|
||||
r.updatePollingCachesLogic(ctx, groupID, keyID, newStatus)
|
||||
}
|
||||
|
||||
func (r *gormKeyRepository) HandleCacheUpdateEvent(groupID, keyID uint, newStatus models.APIKeyStatus) {
|
||||
func (r *gormKeyRepository) HandleCacheUpdateEvent(ctx context.Context, groupID, keyID uint, newStatus models.APIKeyStatus) {
|
||||
r.logger.Infof("EVENT: Updating polling caches for G:%d K:%d -> %s from an event", groupID, keyID, newStatus)
|
||||
r.updatePollingCachesLogic(groupID, keyID, newStatus)
|
||||
r.updatePollingCachesLogic(ctx, groupID, keyID, newStatus)
|
||||
}
|
||||
|
||||
func (r *gormKeyRepository) updatePollingCachesLogic(groupID, keyID uint, newStatus models.APIKeyStatus) {
|
||||
func (r *gormKeyRepository) updatePollingCachesLogic(ctx context.Context, groupID, keyID uint, newStatus models.APIKeyStatus) {
|
||||
keyIDStr := strconv.FormatUint(uint64(keyID), 10)
|
||||
sequentialKey := fmt.Sprintf(KeyGroupSequential, groupID)
|
||||
lruKey := fmt.Sprintf(KeyGroupLRU, groupID)
|
||||
mainPoolKey := fmt.Sprintf(KeyGroupRandomMain, groupID)
|
||||
cooldownPoolKey := fmt.Sprintf(KeyGroupRandomCooldown, groupID)
|
||||
|
||||
_ = r.store.LRem(sequentialKey, 0, keyIDStr)
|
||||
_ = r.store.ZRem(lruKey, keyIDStr)
|
||||
_ = r.store.SRem(mainPoolKey, keyIDStr)
|
||||
_ = r.store.SRem(cooldownPoolKey, keyIDStr)
|
||||
_ = r.store.LRem(ctx, sequentialKey, 0, keyIDStr)
|
||||
_ = r.store.ZRem(ctx, lruKey, keyIDStr)
|
||||
_ = r.store.SRem(ctx, mainPoolKey, keyIDStr)
|
||||
_ = r.store.SRem(ctx, cooldownPoolKey, keyIDStr)
|
||||
|
||||
if newStatus == models.StatusActive {
|
||||
if err := r.store.LPush(sequentialKey, keyIDStr); err != nil {
|
||||
if err := r.store.LPush(ctx, sequentialKey, keyIDStr); err != nil {
|
||||
r.logger.WithError(err).Warnf("Failed to add key %d to sequential list for group %d", keyID, groupID)
|
||||
}
|
||||
members := map[string]float64{keyIDStr: 0}
|
||||
if err := r.store.ZAdd(lruKey, members); err != nil {
|
||||
if err := r.store.ZAdd(ctx, lruKey, members); err != nil {
|
||||
r.logger.WithError(err).Warnf("Failed to add key %d to LRU zset for group %d", keyID, groupID)
|
||||
}
|
||||
if err := r.store.SAdd(mainPoolKey, keyIDStr); err != nil {
|
||||
if err := r.store.SAdd(ctx, mainPoolKey, keyIDStr); err != nil {
|
||||
r.logger.WithError(err).Warnf("Failed to add key %d to random main pool for group %d", keyID, groupID)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// UpdateKeyStatusAfterRequest is the new central hub for handling feedback.
|
||||
func (r *gormKeyRepository) UpdateKeyStatusAfterRequest(group *models.KeyGroup, key *models.APIKey, success bool, apiErr *errors.APIError) {
|
||||
func (r *gormKeyRepository) UpdateKeyStatusAfterRequest(ctx context.Context, group *models.KeyGroup, key *models.APIKey, success bool, apiErr *errors.APIError) {
|
||||
if success {
|
||||
if group.PollingStrategy == models.StrategyWeighted {
|
||||
go r.UpdateKeyUsageTimestamp(group.ID, key.ID)
|
||||
go r.UpdateKeyUsageTimestamp(context.Background(), group.ID, key.ID)
|
||||
}
|
||||
return
|
||||
}
|
||||
@@ -72,6 +73,5 @@ func (r *gormKeyRepository) UpdateKeyStatusAfterRequest(group *models.KeyGroup,
|
||||
}
|
||||
r.logger.Warnf("Request failed for KeyID %d in GroupID %d with error: %s. Temporarily removing from active polling caches.", key.ID, group.ID, apiErr.Message)
|
||||
|
||||
// This call is correct. It uses the synchronous, direct method.
|
||||
r.SyncKeyStatusInPollingCaches(group.ID, key.ID, models.StatusCooldown)
|
||||
r.SyncKeyStatusInPollingCaches(ctx, group.ID, key.ID, models.StatusCooldown)
|
||||
}
|
||||
|
||||
@@ -2,6 +2,8 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"gemini-balancer/internal/config"
|
||||
"gemini-balancer/internal/crypto"
|
||||
"gemini-balancer/internal/errors"
|
||||
"gemini-balancer/internal/models"
|
||||
@@ -22,8 +24,8 @@ type BasePool struct {
|
||||
|
||||
type KeyRepository interface {
|
||||
// --- 核心选取与调度 --- key_selector
|
||||
SelectOneActiveKey(group *models.KeyGroup) (*models.APIKey, *models.GroupAPIKeyMapping, error)
|
||||
SelectOneActiveKeyFromBasePool(pool *BasePool) (*models.APIKey, *models.KeyGroup, error)
|
||||
SelectOneActiveKey(ctx context.Context, group *models.KeyGroup) (*models.APIKey, *models.GroupAPIKeyMapping, error)
|
||||
SelectOneActiveKeyFromBasePool(ctx context.Context, pool *BasePool) (*models.APIKey, *models.KeyGroup, error)
|
||||
|
||||
// --- 加密与解密 --- key_crud
|
||||
Decrypt(key *models.APIKey) error
|
||||
@@ -37,16 +39,16 @@ type KeyRepository interface {
|
||||
GetKeyByID(id uint) (*models.APIKey, error)
|
||||
GetKeyByValue(keyValue string) (*models.APIKey, error)
|
||||
GetKeysByValues(keyValues []string) ([]models.APIKey, error)
|
||||
GetKeysByIDs(ids []uint) ([]models.APIKey, error) // [新增] 根据一组主键ID批量获取Key
|
||||
GetKeysByIDs(ids []uint) ([]models.APIKey, error)
|
||||
GetKeysByGroup(groupID uint) ([]models.APIKey, error)
|
||||
CountByGroup(groupID uint) (int64, error)
|
||||
|
||||
// --- 多对多关系管理 --- key_mapping
|
||||
LinkKeysToGroup(groupID uint, keyIDs []uint) error
|
||||
UnlinkKeysFromGroup(groupID uint, keyIDs []uint) (unlinkedCount int64, err error)
|
||||
GetGroupsForKey(keyID uint) ([]uint, error)
|
||||
LinkKeysToGroup(ctx context.Context, groupID uint, keyIDs []uint) error
|
||||
UnlinkKeysFromGroup(ctx context.Context, groupID uint, keyIDs []uint) (unlinkedCount int64, err error)
|
||||
GetGroupsForKey(ctx context.Context, keyID uint) ([]uint, error)
|
||||
GetMapping(groupID, keyID uint) (*models.GroupAPIKeyMapping, error)
|
||||
UpdateMapping(mapping *models.GroupAPIKeyMapping) error
|
||||
UpdateMapping(ctx context.Context, mapping *models.GroupAPIKeyMapping) error
|
||||
GetPaginatedKeysAndMappingsByGroup(params *models.APIKeyQueryParams) ([]*models.APIKeyDetails, int64, error)
|
||||
GetKeysByValuesAndGroupID(values []string, groupID uint) ([]models.APIKey, error)
|
||||
FindKeyValuesByStatus(groupID uint, statuses []string) ([]string, error)
|
||||
@@ -55,8 +57,8 @@ type KeyRepository interface {
|
||||
UpdateMappingWithoutCache(mapping *models.GroupAPIKeyMapping) error
|
||||
|
||||
// --- 缓存管理 --- key_cache
|
||||
LoadAllKeysToStore() error
|
||||
HandleCacheUpdateEventBatch(mappings []*models.GroupAPIKeyMapping) error
|
||||
LoadAllKeysToStore(ctx context.Context) error
|
||||
HandleCacheUpdateEventBatch(ctx context.Context, mappings []*models.GroupAPIKeyMapping) error
|
||||
|
||||
// --- 维护与后台任务 --- key_maintenance
|
||||
StreamKeysToWriter(groupID uint, statusFilter string, writer io.Writer) error
|
||||
@@ -65,16 +67,14 @@ type KeyRepository interface {
|
||||
DeleteOrphanKeys() (int64, error)
|
||||
DeleteOrphanKeysTx(tx *gorm.DB) (int64, error)
|
||||
GetActiveMasterKeys() ([]*models.APIKey, error)
|
||||
UpdateAPIKeyStatus(keyID uint, status models.MasterAPIKeyStatus) error
|
||||
UpdateAPIKeyStatus(ctx context.Context, keyID uint, status models.MasterAPIKeyStatus) error
|
||||
HardDeleteSoftDeletedBefore(date time.Time) (int64, error)
|
||||
|
||||
// --- 轮询策略的"写"操作 --- key_writer
|
||||
UpdateKeyUsageTimestamp(groupID, keyID uint)
|
||||
// 同步更新缓存,供核心业务使用
|
||||
SyncKeyStatusInPollingCaches(groupID, keyID uint, newStatus models.APIKeyStatus)
|
||||
// 异步更新缓存,供事件订阅者使用
|
||||
HandleCacheUpdateEvent(groupID, keyID uint, newStatus models.APIKeyStatus)
|
||||
UpdateKeyStatusAfterRequest(group *models.KeyGroup, key *models.APIKey, success bool, apiErr *errors.APIError)
|
||||
UpdateKeyUsageTimestamp(ctx context.Context, groupID, keyID uint)
|
||||
SyncKeyStatusInPollingCaches(ctx context.Context, groupID, keyID uint, newStatus models.APIKeyStatus)
|
||||
HandleCacheUpdateEvent(ctx context.Context, groupID, keyID uint, newStatus models.APIKeyStatus)
|
||||
UpdateKeyStatusAfterRequest(ctx context.Context, group *models.KeyGroup, key *models.APIKey, success bool, apiErr *errors.APIError)
|
||||
}
|
||||
|
||||
type GroupRepository interface {
|
||||
@@ -88,18 +88,20 @@ type gormKeyRepository struct {
|
||||
store store.Store
|
||||
logger *logrus.Entry
|
||||
crypto *crypto.Service
|
||||
config *config.Config
|
||||
}
|
||||
|
||||
type gormGroupRepository struct {
|
||||
db *gorm.DB
|
||||
}
|
||||
|
||||
func NewKeyRepository(db *gorm.DB, s store.Store, logger *logrus.Logger, crypto *crypto.Service) KeyRepository {
|
||||
func NewKeyRepository(db *gorm.DB, s store.Store, logger *logrus.Logger, crypto *crypto.Service, cfg *config.Config) KeyRepository {
|
||||
return &gormKeyRepository{
|
||||
db: db,
|
||||
store: s,
|
||||
logger: logger.WithField("component", "repository.key🔗"),
|
||||
crypto: crypto,
|
||||
config: cfg,
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
package scheduler
|
||||
|
||||
import (
|
||||
"context"
|
||||
"gemini-balancer/internal/repository"
|
||||
"gemini-balancer/internal/service"
|
||||
"time"
|
||||
@@ -15,7 +16,6 @@ type Scheduler struct {
|
||||
logger *logrus.Entry
|
||||
statsService *service.StatsService
|
||||
keyRepo repository.KeyRepository
|
||||
// healthCheckService *service.HealthCheckService // 健康检查任务预留
|
||||
}
|
||||
|
||||
func NewScheduler(statsSvc *service.StatsService, keyRepo repository.KeyRepository, logger *logrus.Logger) *Scheduler {
|
||||
@@ -32,11 +32,13 @@ func NewScheduler(statsSvc *service.StatsService, keyRepo repository.KeyReposito
|
||||
func (s *Scheduler) Start() {
|
||||
s.logger.Info("Starting scheduler and registering jobs...")
|
||||
|
||||
// --- 任务注册 ---
|
||||
// 任务一:每小时执行一次的统计聚合
|
||||
// 使用CRON表达式,精确定义“每小时的第5分钟”执行
|
||||
_, err := s.gocronScheduler.Cron("5 * * * *").Tag("stats-aggregation").Do(func() {
|
||||
s.logger.Info("Executing hourly request stats aggregation...")
|
||||
if err := s.statsService.AggregateHourlyStats(); err != nil {
|
||||
// 为后台定时任务创建一个新的、空的 context
|
||||
ctx := context.Background()
|
||||
if err := s.statsService.AggregateHourlyStats(ctx); err != nil {
|
||||
s.logger.WithError(err).Error("Hourly stats aggregation failed.")
|
||||
} else {
|
||||
s.logger.Info("Hourly stats aggregation completed successfully.")
|
||||
@@ -46,23 +48,14 @@ func (s *Scheduler) Start() {
|
||||
s.logger.Errorf("Failed to schedule [stats-aggregation]: %v", err)
|
||||
}
|
||||
|
||||
// 任务二:(预留) 自动健康检查 (例如:每10分钟一次)
|
||||
/*
|
||||
_, err = s.gocronScheduler.Every(10).Minutes().Tag("auto-health-check").Do(func() {
|
||||
s.logger.Info("Executing periodic health check for all groups...")
|
||||
// s.healthCheckService.StartGlobalCheckTask() // 伪代码
|
||||
})
|
||||
if err != nil {
|
||||
s.logger.Errorf("Failed to schedule [auto-health-check]: %v", err)
|
||||
}
|
||||
*/
|
||||
// [NEW] --- 任务三: 清理软删除的API Keys ---
|
||||
// 任务二:(预留) 自动健康检查
|
||||
|
||||
// 任务三:每日执行一次的软删除Key清理
|
||||
// Executes once daily at 3:15 AM UTC.
|
||||
_, err = s.gocronScheduler.Cron("15 3 * * *").Tag("cleanup-soft-deleted-keys").Do(func() {
|
||||
s.logger.Info("Executing daily cleanup of soft-deleted API keys...")
|
||||
|
||||
// Let's assume a retention period of 7 days for now.
|
||||
// In a real scenario, this should come from settings.
|
||||
// [假设保留7天,实际应来自配置
|
||||
const retentionDays = 7
|
||||
|
||||
count, err := s.keyRepo.HardDeleteSoftDeletedBefore(time.Now().AddDate(0, 0, -retentionDays))
|
||||
@@ -77,9 +70,8 @@ func (s *Scheduler) Start() {
|
||||
if err != nil {
|
||||
s.logger.Errorf("Failed to schedule [cleanup-soft-deleted-keys]: %v", err)
|
||||
}
|
||||
// --- 任务注册结束 ---
|
||||
|
||||
s.gocronScheduler.StartAsync() // 异步启动,不阻塞应用主线程
|
||||
s.gocronScheduler.StartAsync()
|
||||
s.logger.Info("Scheduler started.")
|
||||
}
|
||||
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
// Filename: internal/service/analytics_service.go
|
||||
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"gemini-balancer/internal/db/dialect"
|
||||
@@ -43,7 +43,7 @@ func NewAnalyticsService(db *gorm.DB, s store.Store, logger *logrus.Logger, d di
|
||||
}
|
||||
|
||||
func (s *AnalyticsService) Start() {
|
||||
s.wg.Add(2) // 2 (flushLoop, eventListener)
|
||||
s.wg.Add(2)
|
||||
go s.flushLoop()
|
||||
go s.eventListener()
|
||||
s.logger.Info("AnalyticsService (Command Side) started.")
|
||||
@@ -53,13 +53,13 @@ func (s *AnalyticsService) Stop() {
|
||||
close(s.stopChan)
|
||||
s.wg.Wait()
|
||||
s.logger.Info("AnalyticsService stopped. Performing final data flush...")
|
||||
s.flushToDB() // 停止前刷盘
|
||||
s.flushToDB()
|
||||
s.logger.Info("AnalyticsService final data flush completed.")
|
||||
}
|
||||
|
||||
func (s *AnalyticsService) eventListener() {
|
||||
defer s.wg.Done()
|
||||
sub, err := s.store.Subscribe(models.TopicRequestFinished)
|
||||
sub, err := s.store.Subscribe(context.Background(), models.TopicRequestFinished)
|
||||
if err != nil {
|
||||
s.logger.Fatalf("Failed to subscribe to topic %s: %v", models.TopicRequestFinished, err)
|
||||
return
|
||||
@@ -87,9 +87,10 @@ func (s *AnalyticsService) handleAnalyticsEvent(event *models.RequestFinishedEve
|
||||
if event.RequestLog.GroupID == nil {
|
||||
return
|
||||
}
|
||||
ctx := context.Background()
|
||||
key := fmt.Sprintf("analytics:hourly:%s", time.Now().UTC().Format("2006-01-02T15"))
|
||||
fieldPrefix := fmt.Sprintf("%d:%s", *event.RequestLog.GroupID, event.RequestLog.ModelName)
|
||||
pipe := s.store.Pipeline()
|
||||
pipe := s.store.Pipeline(ctx)
|
||||
pipe.HIncrBy(key, fieldPrefix+":requests", 1)
|
||||
if event.RequestLog.IsSuccess {
|
||||
pipe.HIncrBy(key, fieldPrefix+":success", 1)
|
||||
@@ -120,6 +121,7 @@ func (s *AnalyticsService) flushLoop() {
|
||||
}
|
||||
|
||||
func (s *AnalyticsService) flushToDB() {
|
||||
ctx := context.Background()
|
||||
now := time.Now().UTC()
|
||||
keysToFlush := []string{
|
||||
fmt.Sprintf("analytics:hourly:%s", now.Add(-1*time.Hour).Format("2006-01-02T15")),
|
||||
@@ -127,7 +129,7 @@ func (s *AnalyticsService) flushToDB() {
|
||||
}
|
||||
|
||||
for _, key := range keysToFlush {
|
||||
data, err := s.store.HGetAll(key)
|
||||
data, err := s.store.HGetAll(ctx, key)
|
||||
if err != nil || len(data) == 0 {
|
||||
continue
|
||||
}
|
||||
@@ -136,15 +138,15 @@ func (s *AnalyticsService) flushToDB() {
|
||||
|
||||
if len(statsToFlush) > 0 {
|
||||
upsertClause := s.dialect.OnConflictUpdateAll(
|
||||
[]string{"time", "group_id", "model_name"}, // conflict columns
|
||||
[]string{"request_count", "success_count", "prompt_tokens", "completion_tokens"}, // update columns
|
||||
[]string{"time", "group_id", "model_name"},
|
||||
[]string{"request_count", "success_count", "prompt_tokens", "completion_tokens"},
|
||||
)
|
||||
err := s.db.Clauses(upsertClause).Create(&statsToFlush).Error
|
||||
err := s.db.WithContext(ctx).Clauses(upsertClause).Create(&statsToFlush).Error
|
||||
if err != nil {
|
||||
s.logger.Errorf("Failed to flush analytics data for key %s: %v", key, err)
|
||||
} else {
|
||||
s.logger.Infof("Successfully flushed %d records from key %s.", len(statsToFlush), key)
|
||||
_ = s.store.HDel(key, parsedFields...)
|
||||
_ = s.store.HDel(ctx, key, parsedFields...)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
// Filename: internal/service/apikey_service.go
|
||||
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
@@ -29,7 +29,6 @@ const (
|
||||
TaskTypeUpdateStatusByFilter = "update_status_by_filter"
|
||||
)
|
||||
|
||||
// DTOs & Constants
|
||||
const (
|
||||
TestEndpoint = "https://generativelanguage.googleapis.com/v1beta/models"
|
||||
)
|
||||
@@ -83,7 +82,6 @@ func NewAPIKeyService(
|
||||
gm *GroupManager,
|
||||
logger *logrus.Logger,
|
||||
) *APIKeyService {
|
||||
logger.Debugf("[FORENSIC PROBE | INJECTION | APIKeyService] Received a KeyRepository. Fingerprint: %p", repo)
|
||||
return &APIKeyService{
|
||||
db: db,
|
||||
keyRepo: repo,
|
||||
@@ -99,22 +97,22 @@ func NewAPIKeyService(
|
||||
}
|
||||
|
||||
func (s *APIKeyService) Start() {
|
||||
requestSub, err := s.store.Subscribe(models.TopicRequestFinished)
|
||||
requestSub, err := s.store.Subscribe(context.Background(), models.TopicRequestFinished)
|
||||
if err != nil {
|
||||
s.logger.Fatalf("Failed to subscribe to topic %s: %v", models.TopicRequestFinished, err)
|
||||
return
|
||||
}
|
||||
masterKeySub, err := s.store.Subscribe(models.TopicMasterKeyStatusChanged)
|
||||
masterKeySub, err := s.store.Subscribe(context.Background(), models.TopicMasterKeyStatusChanged)
|
||||
if err != nil {
|
||||
s.logger.Fatalf("Failed to subscribe to topic %s: %v", models.TopicMasterKeyStatusChanged, err)
|
||||
return
|
||||
}
|
||||
keyStatusSub, err := s.store.Subscribe(models.TopicKeyStatusChanged)
|
||||
keyStatusSub, err := s.store.Subscribe(context.Background(), models.TopicKeyStatusChanged)
|
||||
if err != nil {
|
||||
s.logger.Fatalf("Failed to subscribe to topic %s: %v", models.TopicKeyStatusChanged, err)
|
||||
return
|
||||
}
|
||||
importSub, err := s.store.Subscribe(models.TopicImportGroupCompleted)
|
||||
importSub, err := s.store.Subscribe(context.Background(), models.TopicImportGroupCompleted)
|
||||
if err != nil {
|
||||
s.logger.Fatalf("Failed to subscribe to topic %s: %v", models.TopicImportGroupCompleted, err)
|
||||
return
|
||||
@@ -177,6 +175,7 @@ func (s *APIKeyService) handleKeyUsageEvent(event *models.RequestFinishedEvent)
|
||||
if event.RequestLog.KeyID == nil || event.RequestLog.GroupID == nil {
|
||||
return
|
||||
}
|
||||
ctx := context.Background()
|
||||
if event.RequestLog.IsSuccess {
|
||||
mapping, err := s.keyRepo.GetMapping(*event.RequestLog.GroupID, *event.RequestLog.KeyID)
|
||||
if err != nil {
|
||||
@@ -194,17 +193,18 @@ func (s *APIKeyService) handleKeyUsageEvent(event *models.RequestFinishedEvent)
|
||||
|
||||
now := time.Now()
|
||||
mapping.LastUsedAt = &now
|
||||
if err := s.keyRepo.UpdateMapping(mapping); err != nil {
|
||||
if err := s.keyRepo.UpdateMapping(ctx, mapping); err != nil {
|
||||
s.logger.Errorf("[%s] Failed to update mapping for G:%d K:%d after successful request: %v", event.CorrelationID, *event.RequestLog.GroupID, *event.RequestLog.KeyID, err)
|
||||
return
|
||||
}
|
||||
if statusChanged {
|
||||
go s.publishStatusChangeEvent(*event.RequestLog.GroupID, *event.RequestLog.KeyID, oldStatus, models.StatusActive, "key_recovered_after_use")
|
||||
go s.publishStatusChangeEvent(ctx, *event.RequestLog.GroupID, *event.RequestLog.KeyID, oldStatus, models.StatusActive, "key_recovered_after_use")
|
||||
}
|
||||
return
|
||||
}
|
||||
if event.Error != nil {
|
||||
s.judgeKeyErrors(
|
||||
ctx,
|
||||
event.CorrelationID,
|
||||
*event.RequestLog.GroupID,
|
||||
*event.RequestLog.KeyID,
|
||||
@@ -215,6 +215,7 @@ func (s *APIKeyService) handleKeyUsageEvent(event *models.RequestFinishedEvent)
|
||||
}
|
||||
|
||||
func (s *APIKeyService) handleKeyStatusChangeEvent(event *models.KeyStatusChangedEvent) {
|
||||
ctx := context.Background()
|
||||
log := s.logger.WithFields(logrus.Fields{
|
||||
"group_id": event.GroupID,
|
||||
"key_id": event.KeyID,
|
||||
@@ -222,11 +223,11 @@ func (s *APIKeyService) handleKeyStatusChangeEvent(event *models.KeyStatusChange
|
||||
"reason": event.ChangeReason,
|
||||
})
|
||||
log.Info("Received KeyStatusChangedEvent, will update polling caches.")
|
||||
s.keyRepo.HandleCacheUpdateEvent(event.GroupID, event.KeyID, event.NewStatus)
|
||||
s.keyRepo.HandleCacheUpdateEvent(ctx, event.GroupID, event.KeyID, event.NewStatus)
|
||||
log.Info("Polling caches updated based on health check event.")
|
||||
}
|
||||
|
||||
func (s *APIKeyService) publishStatusChangeEvent(groupID, keyID uint, oldStatus, newStatus models.APIKeyStatus, reason string) {
|
||||
func (s *APIKeyService) publishStatusChangeEvent(ctx context.Context, groupID, keyID uint, oldStatus, newStatus models.APIKeyStatus, reason string) {
|
||||
changeEvent := models.KeyStatusChangedEvent{
|
||||
KeyID: keyID,
|
||||
GroupID: groupID,
|
||||
@@ -236,13 +237,12 @@ func (s *APIKeyService) publishStatusChangeEvent(groupID, keyID uint, oldStatus,
|
||||
ChangedAt: time.Now(),
|
||||
}
|
||||
eventData, _ := json.Marshal(changeEvent)
|
||||
if err := s.store.Publish(models.TopicKeyStatusChanged, eventData); err != nil {
|
||||
if err := s.store.Publish(ctx, models.TopicKeyStatusChanged, eventData); err != nil {
|
||||
s.logger.Errorf("Failed to publish key status changed event for group %d: %v", groupID, err)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *APIKeyService) ListAPIKeys(params *models.APIKeyQueryParams) (*PaginatedAPIKeys, error) {
|
||||
// --- Path 1: High-performance DB pagination (no keyword) ---
|
||||
func (s *APIKeyService) ListAPIKeys(ctx context.Context, params *models.APIKeyQueryParams) (*PaginatedAPIKeys, error) {
|
||||
if params.Keyword == "" {
|
||||
items, total, err := s.keyRepo.GetPaginatedKeysAndMappingsByGroup(params)
|
||||
if err != nil {
|
||||
@@ -260,14 +260,12 @@ func (s *APIKeyService) ListAPIKeys(params *models.APIKeyQueryParams) (*Paginate
|
||||
TotalPages: totalPages,
|
||||
}, nil
|
||||
}
|
||||
// --- Path 2: In-memory search (keyword present) ---
|
||||
s.logger.Infof("Performing heavy in-memory search for group %d with keyword '%s'", params.KeyGroupID, params.Keyword)
|
||||
// To get all keys, we fetch all IDs first, then get their full details.
|
||||
var statusesToFilter []string
|
||||
if params.Status != "" {
|
||||
statusesToFilter = append(statusesToFilter, params.Status)
|
||||
} else {
|
||||
statusesToFilter = append(statusesToFilter, "all") // "all" gets every status
|
||||
statusesToFilter = append(statusesToFilter, "all")
|
||||
}
|
||||
allKeyIDs, err := s.keyRepo.FindKeyIDsByStatus(params.KeyGroupID, statusesToFilter)
|
||||
if err != nil {
|
||||
@@ -277,14 +275,12 @@ func (s *APIKeyService) ListAPIKeys(params *models.APIKeyQueryParams) (*Paginate
|
||||
return &PaginatedAPIKeys{Items: []*models.APIKeyDetails{}, Total: 0, Page: 1, PageSize: params.PageSize, TotalPages: 0}, nil
|
||||
}
|
||||
|
||||
// This is the heavy operation: getting all keys and decrypting them.
|
||||
allKeys, err := s.keyRepo.GetKeysByIDs(allKeyIDs)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to fetch all key details for in-memory search: %w", err)
|
||||
}
|
||||
// We also need mappings to build the final `APIKeyDetails`.
|
||||
var allMappings []models.GroupAPIKeyMapping
|
||||
err = s.db.Where("key_group_id = ? AND api_key_id IN ?", params.KeyGroupID, allKeyIDs).Find(&allMappings).Error
|
||||
err = s.db.WithContext(ctx).Where("key_group_id = ? AND api_key_id IN ?", params.KeyGroupID, allKeyIDs).Find(&allMappings).Error
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to fetch mappings for in-memory search: %w", err)
|
||||
}
|
||||
@@ -292,7 +288,6 @@ func (s *APIKeyService) ListAPIKeys(params *models.APIKeyQueryParams) (*Paginate
|
||||
for i := range allMappings {
|
||||
mappingMap[allMappings[i].APIKeyID] = &allMappings[i]
|
||||
}
|
||||
// Filter the results in memory.
|
||||
var filteredItems []*models.APIKeyDetails
|
||||
for _, key := range allKeys {
|
||||
if strings.Contains(key.APIKey, params.Keyword) {
|
||||
@@ -312,11 +307,9 @@ func (s *APIKeyService) ListAPIKeys(params *models.APIKeyQueryParams) (*Paginate
|
||||
}
|
||||
}
|
||||
}
|
||||
// Sort the filtered results to ensure consistent pagination (by ID descending).
|
||||
sort.Slice(filteredItems, func(i, j int) bool {
|
||||
return filteredItems[i].ID > filteredItems[j].ID
|
||||
})
|
||||
// Manually paginate the filtered results.
|
||||
total := int64(len(filteredItems))
|
||||
start := (params.Page - 1) * params.PageSize
|
||||
end := start + params.PageSize
|
||||
@@ -345,14 +338,15 @@ func (s *APIKeyService) ListAPIKeys(params *models.APIKeyQueryParams) (*Paginate
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *APIKeyService) GetKeysByIds(ids []uint) ([]models.APIKey, error) {
|
||||
func (s *APIKeyService) GetKeysByIds(ctx context.Context, ids []uint) ([]models.APIKey, error) {
|
||||
return s.keyRepo.GetKeysByIDs(ids)
|
||||
}
|
||||
|
||||
func (s *APIKeyService) UpdateAPIKey(key *models.APIKey) error {
|
||||
func (s *APIKeyService) UpdateAPIKey(ctx context.Context, key *models.APIKey) error {
|
||||
go func() {
|
||||
bgCtx := context.Background()
|
||||
var oldKey models.APIKey
|
||||
if err := s.db.First(&oldKey, key.ID).Error; err != nil {
|
||||
if err := s.db.WithContext(bgCtx).First(&oldKey, key.ID).Error; err != nil {
|
||||
s.logger.Errorf("Failed to find old key state for ID %d before update: %v", key.ID, err)
|
||||
return
|
||||
}
|
||||
@@ -364,16 +358,14 @@ func (s *APIKeyService) UpdateAPIKey(key *models.APIKey) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *APIKeyService) HardDeleteAPIKeyByID(id uint) error {
|
||||
// Get all associated groups before deletion to publish correct events
|
||||
groups, err := s.keyRepo.GetGroupsForKey(id)
|
||||
func (s *APIKeyService) HardDeleteAPIKeyByID(ctx context.Context, id uint) error {
|
||||
groups, err := s.keyRepo.GetGroupsForKey(ctx, id)
|
||||
if err != nil {
|
||||
s.logger.Warnf("Could not get groups for key ID %d before hard deletion (key might be orphaned): %v", id, err)
|
||||
}
|
||||
|
||||
err = s.keyRepo.HardDeleteByID(id)
|
||||
if err == nil {
|
||||
// Publish events for each group the key was a part of
|
||||
for _, groupID := range groups {
|
||||
event := models.KeyStatusChangedEvent{
|
||||
KeyID: id,
|
||||
@@ -381,13 +373,13 @@ func (s *APIKeyService) HardDeleteAPIKeyByID(id uint) error {
|
||||
ChangeReason: "key_hard_deleted",
|
||||
}
|
||||
eventData, _ := json.Marshal(event)
|
||||
go s.store.Publish(models.TopicKeyStatusChanged, eventData)
|
||||
go s.store.Publish(context.Background(), models.TopicKeyStatusChanged, eventData)
|
||||
}
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *APIKeyService) UpdateMappingStatus(groupID, keyID uint, newStatus models.APIKeyStatus) (*models.GroupAPIKeyMapping, error) {
|
||||
func (s *APIKeyService) UpdateMappingStatus(ctx context.Context, groupID, keyID uint, newStatus models.APIKeyStatus) (*models.GroupAPIKeyMapping, error) {
|
||||
key, err := s.keyRepo.GetKeyByID(keyID)
|
||||
if err != nil {
|
||||
return nil, CustomErrors.ParseDBError(err)
|
||||
@@ -409,19 +401,20 @@ func (s *APIKeyService) UpdateMappingStatus(groupID, keyID uint, newStatus model
|
||||
mapping.ConsecutiveErrorCount = 0
|
||||
mapping.LastError = ""
|
||||
}
|
||||
if err := s.keyRepo.UpdateMapping(mapping); err != nil {
|
||||
if err := s.keyRepo.UpdateMapping(ctx, mapping); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
go s.publishStatusChangeEvent(groupID, keyID, oldStatus, newStatus, "manual_update")
|
||||
go s.publishStatusChangeEvent(ctx, groupID, keyID, oldStatus, newStatus, "manual_update")
|
||||
return mapping, nil
|
||||
}
|
||||
|
||||
func (s *APIKeyService) handleMasterKeyStatusChangeEvent(event *models.MasterKeyStatusChangedEvent) {
|
||||
ctx := context.Background()
|
||||
s.logger.Infof("Received MasterKeyStatusChangedEvent for Key ID %d: %s -> %s", event.KeyID, event.OldMasterStatus, event.NewMasterStatus)
|
||||
if event.NewMasterStatus != models.MasterStatusRevoked {
|
||||
return
|
||||
}
|
||||
affectedGroupIDs, err := s.keyRepo.GetGroupsForKey(event.KeyID)
|
||||
affectedGroupIDs, err := s.keyRepo.GetGroupsForKey(ctx, event.KeyID)
|
||||
if err != nil {
|
||||
s.logger.WithError(err).Errorf("Failed to get groups for key ID %d to propagate revoked status.", event.KeyID)
|
||||
return
|
||||
@@ -432,7 +425,7 @@ func (s *APIKeyService) handleMasterKeyStatusChangeEvent(event *models.MasterKey
|
||||
}
|
||||
s.logger.Infof("Propagating REVOKED astatus for Key ID %d to %d groups.", event.KeyID, len(affectedGroupIDs))
|
||||
for _, groupID := range affectedGroupIDs {
|
||||
_, err := s.UpdateMappingStatus(groupID, event.KeyID, models.StatusBanned)
|
||||
_, err := s.UpdateMappingStatus(ctx, groupID, event.KeyID, models.StatusBanned)
|
||||
if err != nil {
|
||||
if !errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
s.logger.WithError(err).Errorf("Failed to update mapping status to BANNED for Key ID %d in Group ID %d.", event.KeyID, groupID)
|
||||
@@ -441,32 +434,32 @@ func (s *APIKeyService) handleMasterKeyStatusChangeEvent(event *models.MasterKey
|
||||
}
|
||||
}
|
||||
|
||||
func (s *APIKeyService) StartRestoreKeysTask(groupID uint, keyIDs []uint) (*task.Status, error) {
|
||||
func (s *APIKeyService) StartRestoreKeysTask(ctx context.Context, groupID uint, keyIDs []uint) (*task.Status, error) {
|
||||
if len(keyIDs) == 0 {
|
||||
return nil, CustomErrors.NewAPIError(CustomErrors.ErrBadRequest, "No key IDs provided for restoration.")
|
||||
}
|
||||
resourceID := fmt.Sprintf("group-%d", groupID)
|
||||
taskStatus, err := s.taskService.StartTask(groupID, TaskTypeRestoreSpecificKeys, resourceID, len(keyIDs), time.Hour)
|
||||
taskStatus, err := s.taskService.StartTask(ctx, groupID, TaskTypeRestoreSpecificKeys, resourceID, len(keyIDs), time.Hour)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
go s.runRestoreKeysTask(taskStatus.ID, resourceID, groupID, keyIDs)
|
||||
go s.runRestoreKeysTask(ctx, taskStatus.ID, resourceID, groupID, keyIDs)
|
||||
return taskStatus, nil
|
||||
}
|
||||
|
||||
func (s *APIKeyService) runRestoreKeysTask(taskID string, resourceID string, groupID uint, keyIDs []uint) {
|
||||
func (s *APIKeyService) runRestoreKeysTask(ctx context.Context, taskID string, resourceID string, groupID uint, keyIDs []uint) {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
s.logger.Errorf("Panic recovered in runRestoreKeysTask: %v", r)
|
||||
s.taskService.EndTaskByID(taskID, resourceID, nil, fmt.Errorf("panic in restore keys task: %v", r))
|
||||
s.taskService.EndTaskByID(ctx, taskID, resourceID, nil, fmt.Errorf("panic in restore keys task: %v", r))
|
||||
}
|
||||
}()
|
||||
var mappingsToProcess []models.GroupAPIKeyMapping
|
||||
err := s.db.Preload("APIKey").
|
||||
err := s.db.WithContext(ctx).Preload("APIKey").
|
||||
Where("key_group_id = ? AND api_key_id IN ?", groupID, keyIDs).
|
||||
Find(&mappingsToProcess).Error
|
||||
if err != nil {
|
||||
s.taskService.EndTaskByID(taskID, resourceID, nil, err)
|
||||
s.taskService.EndTaskByID(ctx, taskID, resourceID, nil, err)
|
||||
return
|
||||
}
|
||||
result := &BatchRestoreResult{
|
||||
@@ -476,7 +469,7 @@ func (s *APIKeyService) runRestoreKeysTask(taskID string, resourceID string, gro
|
||||
processedCount := 0
|
||||
for _, mapping := range mappingsToProcess {
|
||||
processedCount++
|
||||
_ = s.taskService.UpdateProgressByID(taskID, processedCount)
|
||||
_ = s.taskService.UpdateProgressByID(ctx, taskID, processedCount)
|
||||
if mapping.APIKey == nil {
|
||||
result.SkippedCount++
|
||||
result.SkippedKeys = append(result.SkippedKeys, SkippedKeyInfo{KeyID: mapping.APIKeyID, Reason: "Associated APIKey entity not found."})
|
||||
@@ -492,33 +485,29 @@ func (s *APIKeyService) runRestoreKeysTask(taskID string, resourceID string, gro
|
||||
mapping.Status = models.StatusActive
|
||||
mapping.ConsecutiveErrorCount = 0
|
||||
mapping.LastError = ""
|
||||
// Use the version that doesn't trigger individual cache updates.
|
||||
if err := s.keyRepo.UpdateMappingWithoutCache(&mapping); err != nil {
|
||||
s.logger.WithError(err).Warn("Failed to update mapping in DB during restore task.")
|
||||
result.SkippedCount++
|
||||
result.SkippedKeys = append(result.SkippedKeys, SkippedKeyInfo{KeyID: mapping.APIKeyID, Reason: "DB update failed."})
|
||||
} else {
|
||||
result.RestoredCount++
|
||||
successfulMappings = append(successfulMappings, &mapping) // Collect for batch cache update.
|
||||
go s.publishStatusChangeEvent(groupID, mapping.APIKeyID, oldStatus, models.StatusActive, "batch_restore")
|
||||
successfulMappings = append(successfulMappings, &mapping)
|
||||
go s.publishStatusChangeEvent(ctx, groupID, mapping.APIKeyID, oldStatus, models.StatusActive, "batch_restore")
|
||||
}
|
||||
} else {
|
||||
result.RestoredCount++ // Already active, count as success.
|
||||
result.RestoredCount++
|
||||
}
|
||||
}
|
||||
// After the loop, perform one single, efficient cache update.
|
||||
if err := s.keyRepo.HandleCacheUpdateEventBatch(successfulMappings); err != nil {
|
||||
if err := s.keyRepo.HandleCacheUpdateEventBatch(ctx, successfulMappings); err != nil {
|
||||
s.logger.WithError(err).Error("Failed to perform batch cache update after restore task.")
|
||||
// This is not a task-fatal error, so we just log it and continue.
|
||||
}
|
||||
// Account for keys that were requested but not found in the initial DB query.
|
||||
result.SkippedCount += (len(keyIDs) - len(mappingsToProcess))
|
||||
s.taskService.EndTaskByID(taskID, resourceID, result, nil)
|
||||
s.taskService.EndTaskByID(ctx, taskID, resourceID, result, nil)
|
||||
}
|
||||
|
||||
func (s *APIKeyService) StartRestoreAllBannedTask(groupID uint) (*task.Status, error) {
|
||||
func (s *APIKeyService) StartRestoreAllBannedTask(ctx context.Context, groupID uint) (*task.Status, error) {
|
||||
var bannedKeyIDs []uint
|
||||
err := s.db.Model(&models.GroupAPIKeyMapping{}).
|
||||
err := s.db.WithContext(ctx).Model(&models.GroupAPIKeyMapping{}).
|
||||
Where("key_group_id = ? AND status = ?", groupID, models.StatusBanned).
|
||||
Pluck("api_key_id", &bannedKeyIDs).Error
|
||||
if err != nil {
|
||||
@@ -527,10 +516,11 @@ func (s *APIKeyService) StartRestoreAllBannedTask(groupID uint) (*task.Status, e
|
||||
if len(bannedKeyIDs) == 0 {
|
||||
return nil, CustomErrors.NewAPIError(CustomErrors.ErrBadRequest, "No banned keys to restore in this group.")
|
||||
}
|
||||
return s.StartRestoreKeysTask(groupID, bannedKeyIDs)
|
||||
return s.StartRestoreKeysTask(ctx, groupID, bannedKeyIDs)
|
||||
}
|
||||
|
||||
func (s *APIKeyService) handlePostImportValidation(event *models.ImportGroupCompletedEvent) {
|
||||
ctx := context.Background()
|
||||
group, ok := s.groupManager.GetGroupByID(event.GroupID)
|
||||
if !ok {
|
||||
s.logger.Errorf("Group with id %d not found during post-import validation, aborting.", event.GroupID)
|
||||
@@ -552,7 +542,7 @@ func (s *APIKeyService) handlePostImportValidation(event *models.ImportGroupComp
|
||||
concurrency = *opConfig.KeyCheckConcurrency
|
||||
}
|
||||
if concurrency <= 0 {
|
||||
concurrency = 10 // Safety fallback
|
||||
concurrency = 10
|
||||
}
|
||||
timeout := time.Duration(globalSettings.KeyCheckTimeoutSeconds) * time.Second
|
||||
keysToValidate, err := s.keyRepo.GetKeysByIDs(event.KeyIDs)
|
||||
@@ -571,7 +561,7 @@ func (s *APIKeyService) handlePostImportValidation(event *models.ImportGroupComp
|
||||
validationErr := s.validationService.ValidateSingleKey(&key, timeout, endpoint)
|
||||
if validationErr == nil {
|
||||
s.logger.Infof("Key ID %d PASSED validation. Status -> ACTIVE.", key.ID)
|
||||
if _, err := s.UpdateMappingStatus(event.GroupID, key.ID, models.StatusActive); err != nil {
|
||||
if _, err := s.UpdateMappingStatus(ctx, event.GroupID, key.ID, models.StatusActive); err != nil {
|
||||
s.logger.Errorf("Failed to update status to ACTIVE for Key ID %d in group %d: %v", key.ID, event.GroupID, err)
|
||||
}
|
||||
} else {
|
||||
@@ -579,7 +569,7 @@ func (s *APIKeyService) handlePostImportValidation(event *models.ImportGroupComp
|
||||
if !CustomErrors.As(validationErr, &apiErr) {
|
||||
apiErr = &CustomErrors.APIError{Message: validationErr.Error()}
|
||||
}
|
||||
s.judgeKeyErrors("", event.GroupID, key.ID, apiErr, false)
|
||||
s.judgeKeyErrors(ctx, "", event.GroupID, key.ID, apiErr, false)
|
||||
}
|
||||
}
|
||||
}()
|
||||
@@ -592,12 +582,9 @@ func (s *APIKeyService) handlePostImportValidation(event *models.ImportGroupComp
|
||||
s.logger.Infof("Finished post-import validation for group %d.", event.GroupID)
|
||||
}
|
||||
|
||||
// [NEW] StartUpdateStatusByFilterTask starts a background job to update the status of keys
|
||||
// that match a specific set of source statuses within a group.
|
||||
func (s *APIKeyService) StartUpdateStatusByFilterTask(groupID uint, sourceStatuses []string, newStatus models.APIKeyStatus) (*task.Status, error) {
|
||||
func (s *APIKeyService) StartUpdateStatusByFilterTask(ctx context.Context, groupID uint, sourceStatuses []string, newStatus models.APIKeyStatus) (*task.Status, error) {
|
||||
s.logger.Infof("Starting task to update status to '%s' for keys in group %d matching statuses: %v", newStatus, groupID, sourceStatuses)
|
||||
|
||||
// 1. Find key IDs using the new repository method. Using IDs is more efficient for updates.
|
||||
keyIDs, err := s.keyRepo.FindKeyIDsByStatus(groupID, sourceStatuses)
|
||||
if err != nil {
|
||||
return nil, CustomErrors.ParseDBError(err)
|
||||
@@ -605,35 +592,32 @@ func (s *APIKeyService) StartUpdateStatusByFilterTask(groupID uint, sourceStatus
|
||||
if len(keyIDs) == 0 {
|
||||
now := time.Now()
|
||||
return &task.Status{
|
||||
IsRunning: false, // The "task" is not running.
|
||||
IsRunning: false,
|
||||
Processed: 0,
|
||||
Total: 0,
|
||||
Result: map[string]string{ // We use the flexible Result field to pass the message.
|
||||
Result: map[string]string{
|
||||
"message": "没有找到任何符合当前过滤条件的Key可供操作。",
|
||||
},
|
||||
Error: "", // There is no error.
|
||||
Error: "",
|
||||
StartedAt: now,
|
||||
FinishedAt: &now, // It started and finished at the same time.
|
||||
}, nil // Return nil for the error, signaling a 200 OK.
|
||||
FinishedAt: &now,
|
||||
}, nil
|
||||
}
|
||||
// 2. Start a new task using the TaskService, following existing patterns.
|
||||
resourceID := fmt.Sprintf("group-%d-status-update", groupID)
|
||||
taskStatus, err := s.taskService.StartTask(groupID, TaskTypeUpdateStatusByFilter, resourceID, len(keyIDs), 30*time.Minute)
|
||||
taskStatus, err := s.taskService.StartTask(ctx, groupID, TaskTypeUpdateStatusByFilter, resourceID, len(keyIDs), 30*time.Minute)
|
||||
if err != nil {
|
||||
return nil, err // Pass up errors like "task already in progress".
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 3. Run the core logic in a separate goroutine.
|
||||
go s.runUpdateStatusByFilterTask(taskStatus.ID, resourceID, groupID, keyIDs, newStatus)
|
||||
go s.runUpdateStatusByFilterTask(ctx, taskStatus.ID, resourceID, groupID, keyIDs, newStatus)
|
||||
return taskStatus, nil
|
||||
}
|
||||
|
||||
// [NEW] runUpdateStatusByFilterTask is the private worker function for the above task.
|
||||
func (s *APIKeyService) runUpdateStatusByFilterTask(taskID, resourceID string, groupID uint, keyIDs []uint, newStatus models.APIKeyStatus) {
|
||||
func (s *APIKeyService) runUpdateStatusByFilterTask(ctx context.Context, taskID, resourceID string, groupID uint, keyIDs []uint, newStatus models.APIKeyStatus) {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
s.logger.Errorf("Panic recovered in runUpdateStatusByFilterTask: %v", r)
|
||||
s.taskService.EndTaskByID(taskID, resourceID, nil, fmt.Errorf("panic in update status task: %v", r))
|
||||
s.taskService.EndTaskByID(ctx, taskID, resourceID, nil, fmt.Errorf("panic in update status task: %v", r))
|
||||
}
|
||||
}()
|
||||
type BatchUpdateResult struct {
|
||||
@@ -642,31 +626,27 @@ func (s *APIKeyService) runUpdateStatusByFilterTask(taskID, resourceID string, g
|
||||
}
|
||||
result := &BatchUpdateResult{}
|
||||
var successfulMappings []*models.GroupAPIKeyMapping
|
||||
// 1. Fetch all key master statuses in one go. This is efficient.
|
||||
|
||||
keys, err := s.keyRepo.GetKeysByIDs(keyIDs)
|
||||
if err != nil {
|
||||
s.logger.WithError(err).Error("Failed to fetch keys for bulk status update, aborting task.")
|
||||
s.taskService.EndTaskByID(taskID, resourceID, nil, err)
|
||||
s.taskService.EndTaskByID(ctx, taskID, resourceID, nil, err)
|
||||
return
|
||||
}
|
||||
masterStatusMap := make(map[uint]models.MasterAPIKeyStatus)
|
||||
for _, key := range keys {
|
||||
masterStatusMap[key.ID] = key.MasterStatus
|
||||
}
|
||||
// 2. [THE REFINEMENT] Fetch all relevant mappings directly using s.db,
|
||||
// avoiding the need for a new repository method. This pattern is
|
||||
// already used in other parts of this service.
|
||||
var mappings []*models.GroupAPIKeyMapping
|
||||
if err := s.db.Where("key_group_id = ? AND api_key_id IN ?", groupID, keyIDs).Find(&mappings).Error; err != nil {
|
||||
if err := s.db.WithContext(ctx).Where("key_group_id = ? AND api_key_id IN ?", groupID, keyIDs).Find(&mappings).Error; err != nil {
|
||||
s.logger.WithError(err).Error("Failed to fetch mappings for bulk status update, aborting task.")
|
||||
s.taskService.EndTaskByID(taskID, resourceID, nil, err)
|
||||
s.taskService.EndTaskByID(ctx, taskID, resourceID, nil, err)
|
||||
return
|
||||
}
|
||||
processedCount := 0
|
||||
for _, mapping := range mappings {
|
||||
processedCount++
|
||||
// The progress update should reflect the number of items *being processed*, not the final count.
|
||||
_ = s.taskService.UpdateProgressByID(taskID, processedCount)
|
||||
_ = s.taskService.UpdateProgressByID(ctx, taskID, processedCount)
|
||||
masterStatus, ok := masterStatusMap[mapping.APIKeyID]
|
||||
if !ok {
|
||||
result.SkippedCount++
|
||||
@@ -688,24 +668,25 @@ func (s *APIKeyService) runUpdateStatusByFilterTask(taskID, resourceID string, g
|
||||
} else {
|
||||
result.UpdatedCount++
|
||||
successfulMappings = append(successfulMappings, mapping)
|
||||
go s.publishStatusChangeEvent(groupID, mapping.APIKeyID, oldStatus, newStatus, "batch_status_update")
|
||||
go s.publishStatusChangeEvent(ctx, groupID, mapping.APIKeyID, oldStatus, newStatus, "batch_status_update")
|
||||
}
|
||||
} else {
|
||||
result.UpdatedCount++ // Already in desired state, count as success.
|
||||
result.UpdatedCount++
|
||||
}
|
||||
}
|
||||
result.SkippedCount += (len(keyIDs) - len(mappings))
|
||||
if err := s.keyRepo.HandleCacheUpdateEventBatch(successfulMappings); err != nil {
|
||||
if err := s.keyRepo.HandleCacheUpdateEventBatch(ctx, successfulMappings); err != nil {
|
||||
s.logger.WithError(err).Error("Failed to perform batch cache update after status update task.")
|
||||
}
|
||||
s.logger.Infof("Finished bulk status update task '%s' for group %d. Updated: %d, Skipped: %d.", taskID, groupID, result.UpdatedCount, result.SkippedCount)
|
||||
s.taskService.EndTaskByID(taskID, resourceID, result, nil)
|
||||
s.taskService.EndTaskByID(ctx, taskID, resourceID, result, nil)
|
||||
}
|
||||
|
||||
func (s *APIKeyService) HandleRequestResult(group *models.KeyGroup, key *models.APIKey, success bool, apiErr *CustomErrors.APIError) {
|
||||
ctx := context.Background()
|
||||
if success {
|
||||
if group.PollingStrategy == models.StrategyWeighted {
|
||||
go s.keyRepo.UpdateKeyUsageTimestamp(group.ID, key.ID)
|
||||
go s.keyRepo.UpdateKeyUsageTimestamp(ctx, group.ID, key.ID)
|
||||
}
|
||||
return
|
||||
}
|
||||
@@ -716,26 +697,20 @@ func (s *APIKeyService) HandleRequestResult(group *models.KeyGroup, key *models.
|
||||
errMsg := apiErr.Message
|
||||
if CustomErrors.IsPermanentUpstreamError(errMsg) || CustomErrors.IsTemporaryUpstreamError(errMsg) {
|
||||
s.logger.Warnf("Request for KeyID %d in GroupID %d failed with a key-specific error: %s. Issuing order to move to cooldown.", key.ID, group.ID, errMsg)
|
||||
go s.keyRepo.SyncKeyStatusInPollingCaches(group.ID, key.ID, models.StatusCooldown)
|
||||
go s.keyRepo.SyncKeyStatusInPollingCaches(ctx, group.ID, key.ID, models.StatusCooldown)
|
||||
} else {
|
||||
s.logger.Infof("Request for KeyID %d in GroupID %d failed with a non-key-specific error. No punitive action taken. Error: %s", key.ID, group.ID, errMsg)
|
||||
}
|
||||
}
|
||||
|
||||
// sanitizeForLog truncates long error messages and removes JSON-like content for cleaner Info/Error logs.
|
||||
func sanitizeForLog(errMsg string) string {
|
||||
// Find the start of any potential JSON blob or detailed structure.
|
||||
jsonStartIndex := strings.Index(errMsg, "{")
|
||||
var cleanMsg string
|
||||
if jsonStartIndex != -1 {
|
||||
// If a '{' is found, take everything before it as the summary
|
||||
// and append a simple placeholder.
|
||||
cleanMsg = strings.TrimSpace(errMsg[:jsonStartIndex]) + " {...}"
|
||||
} else {
|
||||
// If no JSON-like structure is found, use the original message.
|
||||
cleanMsg = errMsg
|
||||
}
|
||||
// Always apply a final length truncation as a safeguard against extremely long non-JSON errors.
|
||||
const maxLen = 250
|
||||
if len(cleanMsg) > maxLen {
|
||||
return cleanMsg[:maxLen] + "..."
|
||||
@@ -744,6 +719,7 @@ func sanitizeForLog(errMsg string) string {
|
||||
}
|
||||
|
||||
func (s *APIKeyService) judgeKeyErrors(
|
||||
ctx context.Context,
|
||||
correlationID string,
|
||||
groupID, keyID uint,
|
||||
apiErr *CustomErrors.APIError,
|
||||
@@ -765,11 +741,11 @@ func (s *APIKeyService) judgeKeyErrors(
|
||||
oldStatus := mapping.Status
|
||||
mapping.Status = models.StatusBanned
|
||||
mapping.LastError = errorMessage
|
||||
if err := s.keyRepo.UpdateMapping(mapping); err != nil {
|
||||
if err := s.keyRepo.UpdateMapping(ctx, mapping); err != nil {
|
||||
logger.WithError(err).Error("Failed to update mapping status to BANNED.")
|
||||
} else {
|
||||
go s.publishStatusChangeEvent(groupID, keyID, oldStatus, mapping.Status, "permanent_error_banned")
|
||||
go s.revokeMasterKey(keyID, "permanent_upstream_error")
|
||||
go s.publishStatusChangeEvent(ctx, groupID, keyID, oldStatus, mapping.Status, "permanent_error_banned")
|
||||
go s.revokeMasterKey(ctx, keyID, "permanent_upstream_error")
|
||||
}
|
||||
}
|
||||
return
|
||||
@@ -801,23 +777,23 @@ func (s *APIKeyService) judgeKeyErrors(
|
||||
if oldStatus != newStatus {
|
||||
mapping.Status = newStatus
|
||||
}
|
||||
if err := s.keyRepo.UpdateMapping(mapping); err != nil {
|
||||
if err := s.keyRepo.UpdateMapping(ctx, mapping); err != nil {
|
||||
logger.WithError(err).Error("Failed to update mapping after temporary error.")
|
||||
return
|
||||
}
|
||||
if oldStatus != newStatus {
|
||||
go s.publishStatusChangeEvent(groupID, keyID, oldStatus, newStatus, "error_threshold_reached")
|
||||
go s.publishStatusChangeEvent(ctx, groupID, keyID, oldStatus, newStatus, "error_threshold_reached")
|
||||
}
|
||||
return
|
||||
}
|
||||
logger.Infof("Ignored truly ignorable upstream error. Only updating LastUsedAt. Reason: %s", sanitizeForLog(errorMessage))
|
||||
logger.WithField("full_error_details", errorMessage).Debug("Full details of the ignorable error.")
|
||||
if err := s.keyRepo.UpdateMapping(mapping); err != nil {
|
||||
if err := s.keyRepo.UpdateMapping(ctx, mapping); err != nil {
|
||||
logger.WithError(err).Error("Failed to update LastUsedAt for ignorable error.")
|
||||
}
|
||||
}
|
||||
|
||||
func (s *APIKeyService) revokeMasterKey(keyID uint, reason string) {
|
||||
func (s *APIKeyService) revokeMasterKey(ctx context.Context, keyID uint, reason string) {
|
||||
key, err := s.keyRepo.GetKeyByID(keyID)
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
@@ -832,7 +808,7 @@ func (s *APIKeyService) revokeMasterKey(keyID uint, reason string) {
|
||||
}
|
||||
oldMasterStatus := key.MasterStatus
|
||||
newMasterStatus := models.MasterStatusRevoked
|
||||
if err := s.keyRepo.UpdateMasterStatusByID(keyID, newMasterStatus); err != nil {
|
||||
if err := s.keyRepo.UpdateAPIKeyStatus(ctx, keyID, newMasterStatus); err != nil {
|
||||
s.logger.Errorf("Failed to update master status to REVOKED for key ID %d: %v", keyID, err)
|
||||
return
|
||||
}
|
||||
@@ -844,9 +820,9 @@ func (s *APIKeyService) revokeMasterKey(keyID uint, reason string) {
|
||||
ChangedAt: time.Now(),
|
||||
}
|
||||
eventData, _ := json.Marshal(masterKeyEvent)
|
||||
_ = s.store.Publish(models.TopicMasterKeyStatusChanged, eventData)
|
||||
_ = s.store.Publish(ctx, models.TopicMasterKeyStatusChanged, eventData)
|
||||
}
|
||||
|
||||
func (s *APIKeyService) GetAPIKeyStringsForExport(groupID uint, statuses []string) ([]string, error) {
|
||||
func (s *APIKeyService) GetAPIKeyStringsForExport(ctx context.Context, groupID uint, statuses []string) ([]string, error) {
|
||||
return s.keyRepo.GetKeyStringsByGroupAndStatus(groupID, statuses)
|
||||
}
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
// Filename: internal/service/dashboard_query_service.go
|
||||
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"gemini-balancer/internal/models"
|
||||
"gemini-balancer/internal/store"
|
||||
@@ -17,8 +17,6 @@ import (
|
||||
|
||||
const overviewCacheChannel = "syncer:cache:dashboard_overview"
|
||||
|
||||
// DashboardQueryService 负责所有面向前端的仪表盘数据查询。
|
||||
|
||||
type DashboardQueryService struct {
|
||||
db *gorm.DB
|
||||
store store.Store
|
||||
@@ -54,9 +52,9 @@ func (s *DashboardQueryService) Stop() {
|
||||
s.logger.Info("DashboardQueryService and its CacheSyncer have been stopped.")
|
||||
}
|
||||
|
||||
func (s *DashboardQueryService) GetGroupStats(groupID uint) (map[string]any, error) {
|
||||
func (s *DashboardQueryService) GetGroupStats(ctx context.Context, groupID uint) (map[string]any, error) {
|
||||
statsKey := fmt.Sprintf("stats:group:%d", groupID)
|
||||
keyStatsMap, err := s.store.HGetAll(statsKey)
|
||||
keyStatsMap, err := s.store.HGetAll(ctx, statsKey)
|
||||
if err != nil {
|
||||
s.logger.WithError(err).Errorf("Failed to get key stats from cache for group %d", groupID)
|
||||
return nil, fmt.Errorf("failed to get key stats from cache: %w", err)
|
||||
@@ -74,11 +72,11 @@ func (s *DashboardQueryService) GetGroupStats(groupID uint) (map[string]any, err
|
||||
SuccessRequests int64
|
||||
}
|
||||
var last1Hour, last24Hours requestStatsResult
|
||||
s.db.Model(&models.StatsHourly{}).
|
||||
s.db.WithContext(ctx).Model(&models.StatsHourly{}).
|
||||
Select("SUM(request_count) as total_requests, SUM(success_count) as success_requests").
|
||||
Where("group_id = ? AND time >= ?", groupID, oneHourAgo).
|
||||
Scan(&last1Hour)
|
||||
s.db.Model(&models.StatsHourly{}).
|
||||
s.db.WithContext(ctx).Model(&models.StatsHourly{}).
|
||||
Select("SUM(request_count) as total_requests, SUM(success_count) as success_requests").
|
||||
Where("group_id = ? AND time >= ?", groupID, twentyFourHoursAgo).
|
||||
Scan(&last24Hours)
|
||||
@@ -109,8 +107,9 @@ func (s *DashboardQueryService) GetGroupStats(groupID uint) (map[string]any, err
|
||||
}
|
||||
|
||||
func (s *DashboardQueryService) eventListener() {
|
||||
keyStatusSub, _ := s.store.Subscribe(models.TopicKeyStatusChanged)
|
||||
upstreamStatusSub, _ := s.store.Subscribe(models.TopicUpstreamHealthChanged)
|
||||
ctx := context.Background()
|
||||
keyStatusSub, _ := s.store.Subscribe(ctx, models.TopicKeyStatusChanged)
|
||||
upstreamStatusSub, _ := s.store.Subscribe(ctx, models.TopicUpstreamHealthChanged)
|
||||
defer keyStatusSub.Close()
|
||||
defer upstreamStatusSub.Close()
|
||||
for {
|
||||
@@ -128,7 +127,6 @@ func (s *DashboardQueryService) eventListener() {
|
||||
}
|
||||
}
|
||||
|
||||
// GetDashboardOverviewData 从 Syncer 缓存中高速获取仪表盘概览数据。
|
||||
func (s *DashboardQueryService) GetDashboardOverviewData() (*models.DashboardStatsResponse, error) {
|
||||
cachedDataPtr := s.overviewSyncer.Get()
|
||||
if cachedDataPtr == nil {
|
||||
@@ -141,8 +139,7 @@ func (s *DashboardQueryService) InvalidateOverviewCache() error {
|
||||
return s.overviewSyncer.Invalidate()
|
||||
}
|
||||
|
||||
// QueryHistoricalChart 查询历史图表数据。
|
||||
func (s *DashboardQueryService) QueryHistoricalChart(groupID *uint) (*models.ChartData, error) {
|
||||
func (s *DashboardQueryService) QueryHistoricalChart(ctx context.Context, groupID *uint) (*models.ChartData, error) {
|
||||
type ChartPoint struct {
|
||||
TimeLabel string `gorm:"column:time_label"`
|
||||
ModelName string `gorm:"column:model_name"`
|
||||
@@ -151,7 +148,7 @@ func (s *DashboardQueryService) QueryHistoricalChart(groupID *uint) (*models.Cha
|
||||
sevenDaysAgo := time.Now().Add(-24 * 7 * time.Hour).Truncate(time.Hour)
|
||||
sqlFormat, goFormat := s.buildTimeFormatSelectClause()
|
||||
selectClause := fmt.Sprintf("%s as time_label, model_name, SUM(request_count) as total_requests", sqlFormat)
|
||||
query := s.db.Model(&models.StatsHourly{}).Select(selectClause).Where("time >= ?", sevenDaysAgo).Group("time_label, model_name").Order("time_label ASC")
|
||||
query := s.db.WithContext(ctx).Model(&models.StatsHourly{}).Select(selectClause).Where("time >= ?", sevenDaysAgo).Group("time_label, model_name").Order("time_label ASC")
|
||||
if groupID != nil && *groupID > 0 {
|
||||
query = query.Where("group_id = ?", *groupID)
|
||||
}
|
||||
@@ -189,38 +186,38 @@ func (s *DashboardQueryService) QueryHistoricalChart(groupID *uint) (*models.Cha
|
||||
}
|
||||
|
||||
func (s *DashboardQueryService) loadOverviewData() (*models.DashboardStatsResponse, error) {
|
||||
ctx := context.Background()
|
||||
s.logger.Info("[CacheSyncer] Starting to load overview data from database...")
|
||||
startTime := time.Now()
|
||||
resp := &models.DashboardStatsResponse{
|
||||
KeyStatusCount: make(map[models.APIKeyStatus]int64),
|
||||
MasterStatusCount: make(map[models.MasterAPIKeyStatus]int64),
|
||||
KeyCount: models.StatCard{}, // 确保KeyCount是一个空的结构体,而不是nil
|
||||
RequestCount24h: models.StatCard{}, // 同上
|
||||
KeyCount: models.StatCard{},
|
||||
RequestCount24h: models.StatCard{},
|
||||
TokenCount: make(map[string]any),
|
||||
UpstreamHealthStatus: make(map[string]string),
|
||||
RPM: models.StatCard{},
|
||||
RequestCounts: make(map[string]int64),
|
||||
}
|
||||
// --- 1. Aggregate Operational Status from Mappings ---
|
||||
|
||||
type MappingStatusResult struct {
|
||||
Status models.APIKeyStatus
|
||||
Count int64
|
||||
}
|
||||
var mappingStatusResults []MappingStatusResult
|
||||
if err := s.db.Model(&models.GroupAPIKeyMapping{}).Select("status, count(*) as count").Group("status").Find(&mappingStatusResults).Error; err != nil {
|
||||
if err := s.db.WithContext(ctx).Model(&models.GroupAPIKeyMapping{}).Select("status, count(*) as count").Group("status").Find(&mappingStatusResults).Error; err != nil {
|
||||
return nil, fmt.Errorf("failed to query mapping status stats: %w", err)
|
||||
}
|
||||
for _, res := range mappingStatusResults {
|
||||
resp.KeyStatusCount[res.Status] = res.Count
|
||||
}
|
||||
|
||||
// --- 2. Aggregate Master Status from APIKeys ---
|
||||
type MasterStatusResult struct {
|
||||
Status models.MasterAPIKeyStatus
|
||||
Count int64
|
||||
}
|
||||
var masterStatusResults []MasterStatusResult
|
||||
if err := s.db.Model(&models.APIKey{}).Select("master_status as status, count(*) as count").Group("master_status").Find(&masterStatusResults).Error; err != nil {
|
||||
if err := s.db.WithContext(ctx).Model(&models.APIKey{}).Select("master_status as status, count(*) as count").Group("master_status").Find(&masterStatusResults).Error; err != nil {
|
||||
return nil, fmt.Errorf("failed to query master status stats: %w", err)
|
||||
}
|
||||
var totalKeys, invalidKeys int64
|
||||
@@ -235,20 +232,15 @@ func (s *DashboardQueryService) loadOverviewData() (*models.DashboardStatsRespon
|
||||
|
||||
now := time.Now()
|
||||
|
||||
// 1. RPM (1分钟), RPH (1小时), RPD (今日): 从“瞬时记忆”(request_logs)中精确查询
|
||||
var count1m, count1h, count1d int64
|
||||
// RPM: 从此刻倒推1分钟
|
||||
s.db.Model(&models.RequestLog{}).Where("request_time >= ?", now.Add(-1*time.Minute)).Count(&count1m)
|
||||
// RPH: 从此刻倒推1小时
|
||||
s.db.Model(&models.RequestLog{}).Where("request_time >= ?", now.Add(-1*time.Hour)).Count(&count1h)
|
||||
|
||||
// RPD: 从今天零点 (UTC) 到此刻
|
||||
s.db.WithContext(ctx).Model(&models.RequestLog{}).Where("request_time >= ?", now.Add(-1*time.Minute)).Count(&count1m)
|
||||
s.db.WithContext(ctx).Model(&models.RequestLog{}).Where("request_time >= ?", now.Add(-1*time.Hour)).Count(&count1h)
|
||||
year, month, day := now.UTC().Date()
|
||||
startOfDay := time.Date(year, month, day, 0, 0, 0, 0, time.UTC)
|
||||
s.db.Model(&models.RequestLog{}).Where("request_time >= ?", startOfDay).Count(&count1d)
|
||||
// 2. RP30D (30天): 从“长期记忆”(stats_hourly)中高效查询,以保证性能
|
||||
s.db.WithContext(ctx).Model(&models.RequestLog{}).Where("request_time >= ?", startOfDay).Count(&count1d)
|
||||
|
||||
var count30d int64
|
||||
s.db.Model(&models.StatsHourly{}).Where("time >= ?", now.AddDate(0, 0, -30)).Select("COALESCE(SUM(request_count), 0)").Scan(&count30d)
|
||||
s.db.WithContext(ctx).Model(&models.StatsHourly{}).Where("time >= ?", now.AddDate(0, 0, -30)).Select("COALESCE(SUM(request_count), 0)").Scan(&count30d)
|
||||
|
||||
resp.RequestCounts["1m"] = count1m
|
||||
resp.RequestCounts["1h"] = count1h
|
||||
@@ -256,7 +248,7 @@ func (s *DashboardQueryService) loadOverviewData() (*models.DashboardStatsRespon
|
||||
resp.RequestCounts["30d"] = count30d
|
||||
|
||||
var upstreams []*models.UpstreamEndpoint
|
||||
if err := s.db.Find(&upstreams).Error; err != nil {
|
||||
if err := s.db.WithContext(ctx).Find(&upstreams).Error; err != nil {
|
||||
s.logger.WithError(err).Warn("Failed to load upstream statuses for dashboard.")
|
||||
} else {
|
||||
for _, u := range upstreams {
|
||||
@@ -269,7 +261,7 @@ func (s *DashboardQueryService) loadOverviewData() (*models.DashboardStatsRespon
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
func (s *DashboardQueryService) GetRequestStatsForPeriod(period string) (gin.H, error) {
|
||||
func (s *DashboardQueryService) GetRequestStatsForPeriod(ctx context.Context, period string) (gin.H, error) {
|
||||
var startTime time.Time
|
||||
now := time.Now()
|
||||
switch period {
|
||||
@@ -288,7 +280,7 @@ func (s *DashboardQueryService) GetRequestStatsForPeriod(period string) (gin.H,
|
||||
Success int64
|
||||
}
|
||||
|
||||
err := s.db.Model(&models.RequestLog{}).
|
||||
err := s.db.WithContext(ctx).Model(&models.RequestLog{}).
|
||||
Select("count(*) as total, sum(case when is_success = true then 1 else 0 end) as success").
|
||||
Where("request_time >= ?", startTime).
|
||||
Scan(&result).Error
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
// Filename: internal/service/db_log_writer_service.go
|
||||
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"gemini-balancer/internal/models"
|
||||
"gemini-balancer/internal/settings"
|
||||
@@ -35,35 +35,30 @@ func NewDBLogWriterService(db *gorm.DB, s store.Store, settings *settings.Settin
|
||||
store: s,
|
||||
SettingsManager: settings,
|
||||
logger: logger.WithField("component", "DBLogWriter📝"),
|
||||
// 使用配置值来创建缓冲区
|
||||
logBuffer: make(chan *models.RequestLog, bufferCapacity),
|
||||
stopChan: make(chan struct{}),
|
||||
logBuffer: make(chan *models.RequestLog, bufferCapacity),
|
||||
stopChan: make(chan struct{}),
|
||||
}
|
||||
}
|
||||
|
||||
func (s *DBLogWriterService) Start() {
|
||||
s.wg.Add(2) // 一个用于事件监听,一个用于数据库写入
|
||||
|
||||
// 启动事件监听器
|
||||
s.wg.Add(2)
|
||||
go s.eventListenerLoop()
|
||||
// 启动数据库写入器
|
||||
go s.dbWriterLoop()
|
||||
|
||||
s.logger.Info("DBLogWriterService started.")
|
||||
}
|
||||
|
||||
func (s *DBLogWriterService) Stop() {
|
||||
s.logger.Info("DBLogWriterService stopping...")
|
||||
close(s.stopChan) // 通知所有goroutine停止
|
||||
s.wg.Wait() // 等待所有goroutine完成
|
||||
close(s.stopChan)
|
||||
s.wg.Wait()
|
||||
s.logger.Info("DBLogWriterService stopped.")
|
||||
}
|
||||
|
||||
// eventListenerLoop 负责从store接收事件并放入内存缓冲区
|
||||
func (s *DBLogWriterService) eventListenerLoop() {
|
||||
defer s.wg.Done()
|
||||
|
||||
sub, err := s.store.Subscribe(models.TopicRequestFinished)
|
||||
ctx := context.Background()
|
||||
sub, err := s.store.Subscribe(ctx, models.TopicRequestFinished)
|
||||
if err != nil {
|
||||
s.logger.Fatalf("Failed to subscribe to topic %s: %v", models.TopicRequestFinished, err)
|
||||
return
|
||||
@@ -80,34 +75,27 @@ func (s *DBLogWriterService) eventListenerLoop() {
|
||||
s.logger.Errorf("Failed to unmarshal event for logging: %v", err)
|
||||
continue
|
||||
}
|
||||
|
||||
// 将事件中的日志部分放入缓冲区
|
||||
select {
|
||||
case s.logBuffer <- &event.RequestLog:
|
||||
default:
|
||||
s.logger.Warn("Log buffer is full. A log message might be dropped.")
|
||||
}
|
||||
|
||||
case <-s.stopChan:
|
||||
s.logger.Info("Event listener loop stopping.")
|
||||
// 关闭缓冲区,以通知dbWriterLoop处理完剩余日志后退出
|
||||
close(s.logBuffer)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// dbWriterLoop 负责从内存缓冲区批量读取日志并写入数据库
|
||||
func (s *DBLogWriterService) dbWriterLoop() {
|
||||
defer s.wg.Done()
|
||||
|
||||
// 在启动时获取一次配置
|
||||
cfg := s.SettingsManager.GetSettings()
|
||||
batchSize := cfg.LogFlushBatchSize
|
||||
if batchSize <= 0 {
|
||||
batchSize = 100
|
||||
}
|
||||
|
||||
flushTimeout := time.Duration(cfg.LogFlushIntervalSeconds) * time.Second
|
||||
if flushTimeout <= 0 {
|
||||
flushTimeout = 5 * time.Second
|
||||
@@ -126,7 +114,7 @@ func (s *DBLogWriterService) dbWriterLoop() {
|
||||
return
|
||||
}
|
||||
batch = append(batch, logEntry)
|
||||
if len(batch) >= batchSize { // 使用配置的批次大小
|
||||
if len(batch) >= batchSize {
|
||||
s.flushBatch(batch)
|
||||
batch = make([]*models.RequestLog, 0, batchSize)
|
||||
}
|
||||
@@ -139,7 +127,6 @@ func (s *DBLogWriterService) dbWriterLoop() {
|
||||
}
|
||||
}
|
||||
|
||||
// flushBatch 将一个批次的日志写入数据库
|
||||
func (s *DBLogWriterService) flushBatch(batch []*models.RequestLog) {
|
||||
if err := s.db.CreateInBatches(batch, len(batch)).Error; err != nil {
|
||||
s.logger.WithField("batch_size", len(batch)).WithError(err).Error("Failed to flush log batch to database.")
|
||||
|
||||
@@ -75,7 +75,7 @@ func NewHealthCheckService(
|
||||
|
||||
func (s *HealthCheckService) Start() {
|
||||
s.logger.Info("Starting HealthCheckService with independent check loops...")
|
||||
s.wg.Add(4) // Now four loops
|
||||
s.wg.Add(4)
|
||||
go s.runKeyCheckLoop()
|
||||
go s.runUpstreamCheckLoop()
|
||||
go s.runProxyCheckLoop()
|
||||
@@ -102,8 +102,6 @@ func (s *HealthCheckService) GetLastHealthCheckResults() map[string]string {
|
||||
func (s *HealthCheckService) runKeyCheckLoop() {
|
||||
defer s.wg.Done()
|
||||
s.logger.Info("Key check dynamic scheduler loop started.")
|
||||
|
||||
// 主调度循环,每分钟检查一次任务
|
||||
ticker := time.NewTicker(1 * time.Minute)
|
||||
defer ticker.Stop()
|
||||
|
||||
@@ -126,26 +124,22 @@ func (s *HealthCheckService) scheduleKeyChecks() {
|
||||
defer s.groupCheckTimeMutex.Unlock()
|
||||
|
||||
for _, group := range groups {
|
||||
// 获取特定于组的运营配置
|
||||
opConfig, err := s.groupManager.BuildOperationalConfig(group)
|
||||
if err != nil {
|
||||
s.logger.WithError(err).WithField("group_id", group.ID).Error("Failed to build operational config for group, skipping health check scheduling.")
|
||||
continue
|
||||
}
|
||||
|
||||
if opConfig.EnableKeyCheck == nil || !*opConfig.EnableKeyCheck {
|
||||
continue // 跳过禁用了健康检查的组
|
||||
continue
|
||||
}
|
||||
|
||||
var intervalMinutes int
|
||||
if opConfig.KeyCheckIntervalMinutes != nil {
|
||||
intervalMinutes = *opConfig.KeyCheckIntervalMinutes
|
||||
}
|
||||
interval := time.Duration(intervalMinutes) * time.Minute
|
||||
if interval <= 0 {
|
||||
continue // 跳过无效的检查周期
|
||||
continue
|
||||
}
|
||||
|
||||
if nextCheckTime, ok := s.groupNextCheckTime[group.ID]; !ok || now.After(nextCheckTime) {
|
||||
s.logger.Infof("Scheduling key check for group '%s' (ID: %d)", group.Name, group.ID)
|
||||
go s.performKeyChecksForGroup(group, opConfig)
|
||||
@@ -160,7 +154,6 @@ func (s *HealthCheckService) runUpstreamCheckLoop() {
|
||||
if s.SettingsManager.GetSettings().EnableUpstreamCheck {
|
||||
s.performUpstreamChecks()
|
||||
}
|
||||
|
||||
ticker := time.NewTicker(time.Duration(s.SettingsManager.GetSettings().UpstreamCheckIntervalSeconds) * time.Second)
|
||||
defer ticker.Stop()
|
||||
|
||||
@@ -184,7 +177,6 @@ func (s *HealthCheckService) runProxyCheckLoop() {
|
||||
if s.SettingsManager.GetSettings().EnableProxyCheck {
|
||||
s.performProxyChecks()
|
||||
}
|
||||
|
||||
ticker := time.NewTicker(time.Duration(s.SettingsManager.GetSettings().ProxyCheckIntervalSeconds) * time.Second)
|
||||
defer ticker.Stop()
|
||||
|
||||
@@ -203,6 +195,7 @@ func (s *HealthCheckService) runProxyCheckLoop() {
|
||||
}
|
||||
|
||||
func (s *HealthCheckService) performKeyChecksForGroup(group *models.KeyGroup, opConfig *models.KeyGroupSettings) {
|
||||
ctx := context.Background()
|
||||
settings := s.SettingsManager.GetSettings()
|
||||
timeout := time.Duration(settings.KeyCheckTimeoutSeconds) * time.Second
|
||||
|
||||
@@ -213,11 +206,9 @@ func (s *HealthCheckService) performKeyChecksForGroup(group *models.KeyGroup, op
|
||||
}
|
||||
|
||||
log := s.logger.WithFields(logrus.Fields{"group_id": group.ID, "group_name": group.Name})
|
||||
|
||||
log.Infof("Starting key health check cycle.")
|
||||
|
||||
var mappingsToCheck []models.GroupAPIKeyMapping
|
||||
err = s.db.Model(&models.GroupAPIKeyMapping{}).
|
||||
err = s.db.WithContext(ctx).Model(&models.GroupAPIKeyMapping{}).
|
||||
Joins("JOIN api_keys ON api_keys.id = group_api_key_mappings.api_key_id").
|
||||
Where("group_api_key_mappings.key_group_id = ?", group.ID).
|
||||
Where("api_keys.master_status = ?", models.MasterStatusActive).
|
||||
@@ -233,7 +224,6 @@ func (s *HealthCheckService) performKeyChecksForGroup(group *models.KeyGroup, op
|
||||
log.Info("No key mappings to check for this group.")
|
||||
return
|
||||
}
|
||||
|
||||
log.Infof("Starting health check for %d key mappings.", len(mappingsToCheck))
|
||||
jobs := make(chan models.GroupAPIKeyMapping, len(mappingsToCheck))
|
||||
var wg sync.WaitGroup
|
||||
@@ -242,14 +232,14 @@ func (s *HealthCheckService) performKeyChecksForGroup(group *models.KeyGroup, op
|
||||
concurrency = *opConfig.KeyCheckConcurrency
|
||||
}
|
||||
if concurrency <= 0 {
|
||||
concurrency = 1 // 保证至少有一个 worker
|
||||
concurrency = 1
|
||||
}
|
||||
for w := 1; w <= concurrency; w++ {
|
||||
wg.Add(1)
|
||||
go func(workerID int) {
|
||||
defer wg.Done()
|
||||
for mapping := range jobs {
|
||||
s.checkAndProcessMapping(&mapping, timeout, endpoint)
|
||||
s.checkAndProcessMapping(ctx, &mapping, timeout, endpoint)
|
||||
}
|
||||
}(w)
|
||||
}
|
||||
@@ -261,52 +251,46 @@ func (s *HealthCheckService) performKeyChecksForGroup(group *models.KeyGroup, op
|
||||
log.Info("Finished key health check cycle.")
|
||||
}
|
||||
|
||||
func (s *HealthCheckService) checkAndProcessMapping(mapping *models.GroupAPIKeyMapping, timeout time.Duration, endpoint string) {
|
||||
func (s *HealthCheckService) checkAndProcessMapping(ctx context.Context, mapping *models.GroupAPIKeyMapping, timeout time.Duration, endpoint string) {
|
||||
if mapping.APIKey == nil {
|
||||
s.logger.Warnf("Skipping check for mapping (G:%d, K:%d) because associated APIKey is nil.", mapping.KeyGroupID, mapping.APIKeyID)
|
||||
return
|
||||
}
|
||||
validationErr := s.keyValidationService.ValidateSingleKey(mapping.APIKey, timeout, endpoint)
|
||||
// --- 诊断一:验证成功 (健康) ---
|
||||
if validationErr == nil {
|
||||
if mapping.Status != models.StatusActive {
|
||||
s.activateMapping(mapping)
|
||||
s.activateMapping(ctx, mapping)
|
||||
}
|
||||
return
|
||||
}
|
||||
errorString := validationErr.Error()
|
||||
// --- 诊断二:永久性错误 ---
|
||||
if CustomErrors.IsPermanentUpstreamError(errorString) {
|
||||
s.revokeMapping(mapping, validationErr)
|
||||
s.revokeMapping(ctx, mapping, validationErr)
|
||||
return
|
||||
}
|
||||
// --- 诊断三:暂时性错误 ---
|
||||
if CustomErrors.IsTemporaryUpstreamError(errorString) {
|
||||
// Log with a higher level (WARN) since this is an actionable, proactive finding.
|
||||
s.logger.Warnf("Health check for key %d failed with temporary error, applying penalty. Reason: %v", mapping.APIKeyID, validationErr)
|
||||
s.penalizeMapping(mapping, validationErr)
|
||||
s.penalizeMapping(ctx, mapping, validationErr)
|
||||
return
|
||||
}
|
||||
// --- 诊断四:其他未知或上游服务错误 ---
|
||||
|
||||
s.logger.Errorf("Health check for key %d failed with a transient or unknown upstream error: %v. Mapping will not be penalized by health check.", mapping.APIKeyID, validationErr)
|
||||
}
|
||||
|
||||
func (s *HealthCheckService) activateMapping(mapping *models.GroupAPIKeyMapping) {
|
||||
func (s *HealthCheckService) activateMapping(ctx context.Context, mapping *models.GroupAPIKeyMapping) {
|
||||
oldStatus := mapping.Status
|
||||
mapping.Status = models.StatusActive
|
||||
mapping.ConsecutiveErrorCount = 0
|
||||
mapping.LastError = ""
|
||||
if err := s.keyRepo.UpdateMapping(mapping); err != nil {
|
||||
|
||||
if err := s.keyRepo.UpdateMapping(ctx, mapping); err != nil {
|
||||
s.logger.WithError(err).Errorf("Failed to activate mapping (G:%d, K:%d).", mapping.KeyGroupID, mapping.APIKeyID)
|
||||
return
|
||||
}
|
||||
s.logger.Infof("Mapping (G:%d, K:%d) successfully activated from %s to ACTIVE.", mapping.KeyGroupID, mapping.APIKeyID, oldStatus)
|
||||
s.publishKeyStatusChangedEvent(mapping.KeyGroupID, mapping.APIKeyID, oldStatus, mapping.Status)
|
||||
s.publishKeyStatusChangedEvent(ctx, mapping.KeyGroupID, mapping.APIKeyID, oldStatus, mapping.Status)
|
||||
}
|
||||
|
||||
func (s *HealthCheckService) penalizeMapping(mapping *models.GroupAPIKeyMapping, err error) {
|
||||
// Re-fetch group-specific operational config to get the correct thresholds
|
||||
func (s *HealthCheckService) penalizeMapping(ctx context.Context, mapping *models.GroupAPIKeyMapping, err error) {
|
||||
group, ok := s.groupManager.GetGroupByID(mapping.KeyGroupID)
|
||||
if !ok {
|
||||
s.logger.Errorf("Could not find group with ID %d to apply penalty.", mapping.KeyGroupID)
|
||||
@@ -320,7 +304,6 @@ func (s *HealthCheckService) penalizeMapping(mapping *models.GroupAPIKeyMapping,
|
||||
oldStatus := mapping.Status
|
||||
mapping.LastError = err.Error()
|
||||
mapping.ConsecutiveErrorCount++
|
||||
// Use the group-specific threshold
|
||||
threshold := *opConfig.KeyBlacklistThreshold
|
||||
if mapping.ConsecutiveErrorCount >= threshold {
|
||||
mapping.Status = models.StatusCooldown
|
||||
@@ -329,44 +312,41 @@ func (s *HealthCheckService) penalizeMapping(mapping *models.GroupAPIKeyMapping,
|
||||
mapping.CooldownUntil = &cooldownTime
|
||||
s.logger.Warnf("Mapping (G:%d, K:%d) reached error threshold (%d) during health check and is now in COOLDOWN for %v.", mapping.KeyGroupID, mapping.APIKeyID, threshold, cooldownDuration)
|
||||
}
|
||||
if errDb := s.keyRepo.UpdateMapping(mapping); errDb != nil {
|
||||
if errDb := s.keyRepo.UpdateMapping(ctx, mapping); errDb != nil {
|
||||
s.logger.WithError(errDb).Errorf("Failed to penalize mapping (G:%d, K:%d).", mapping.KeyGroupID, mapping.APIKeyID)
|
||||
return
|
||||
}
|
||||
if oldStatus != mapping.Status {
|
||||
s.publishKeyStatusChangedEvent(mapping.KeyGroupID, mapping.APIKeyID, oldStatus, mapping.Status)
|
||||
s.publishKeyStatusChangedEvent(ctx, mapping.KeyGroupID, mapping.APIKeyID, oldStatus, mapping.Status)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *HealthCheckService) revokeMapping(mapping *models.GroupAPIKeyMapping, err error) {
|
||||
func (s *HealthCheckService) revokeMapping(ctx context.Context, mapping *models.GroupAPIKeyMapping, err error) {
|
||||
oldStatus := mapping.Status
|
||||
if oldStatus == models.StatusBanned {
|
||||
return // Already banned, do nothing.
|
||||
return
|
||||
}
|
||||
|
||||
mapping.Status = models.StatusBanned
|
||||
mapping.LastError = "Definitive error: " + err.Error()
|
||||
mapping.ConsecutiveErrorCount = 0 // Reset counter as this is a final state for this group
|
||||
|
||||
if errDb := s.keyRepo.UpdateMapping(mapping); errDb != nil {
|
||||
mapping.ConsecutiveErrorCount = 0
|
||||
if errDb := s.keyRepo.UpdateMapping(ctx, mapping); errDb != nil {
|
||||
s.logger.WithError(errDb).Errorf("Failed to revoke mapping (G:%d, K:%d).", mapping.KeyGroupID, mapping.APIKeyID)
|
||||
return
|
||||
}
|
||||
|
||||
s.logger.Warnf("Mapping (G:%d, K:%d) has been BANNED due to a definitive error: %v", mapping.KeyGroupID, mapping.APIKeyID, err)
|
||||
s.publishKeyStatusChangedEvent(mapping.KeyGroupID, mapping.APIKeyID, oldStatus, mapping.Status)
|
||||
|
||||
s.publishKeyStatusChangedEvent(ctx, mapping.KeyGroupID, mapping.APIKeyID, oldStatus, mapping.Status)
|
||||
s.logger.Infof("Triggering MasterStatus update for definitively failed key ID %d.", mapping.APIKeyID)
|
||||
if err := s.keyRepo.UpdateAPIKeyStatus(mapping.APIKeyID, models.MasterStatusRevoked); err != nil {
|
||||
if err := s.keyRepo.UpdateAPIKeyStatus(ctx, mapping.APIKeyID, models.MasterStatusRevoked); err != nil {
|
||||
s.logger.WithError(err).Errorf("Failed to update master status for key ID %d after group-level ban.", mapping.APIKeyID)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *HealthCheckService) performUpstreamChecks() {
|
||||
ctx := context.Background()
|
||||
settings := s.SettingsManager.GetSettings()
|
||||
timeout := time.Duration(settings.UpstreamCheckTimeoutSeconds) * time.Second
|
||||
var upstreams []*models.UpstreamEndpoint
|
||||
if err := s.db.Find(&upstreams).Error; err != nil {
|
||||
if err := s.db.WithContext(ctx).Find(&upstreams).Error; err != nil {
|
||||
s.logger.WithError(err).Error("Failed to retrieve upstreams.")
|
||||
return
|
||||
}
|
||||
@@ -390,10 +370,10 @@ func (s *HealthCheckService) performUpstreamChecks() {
|
||||
s.lastResultsMutex.Unlock()
|
||||
if oldStatus != newStatus {
|
||||
s.logger.Infof("Upstream '%s' status changed from %s -> %s", upstream.URL, oldStatus, newStatus)
|
||||
if err := s.db.Model(upstream).Update("status", newStatus).Error; err != nil {
|
||||
if err := s.db.WithContext(ctx).Model(upstream).Update("status", newStatus).Error; err != nil {
|
||||
s.logger.WithError(err).WithField("upstream_id", upstream.ID).Error("Failed to update upstream status.")
|
||||
} else {
|
||||
s.publishUpstreamHealthChangedEvent(upstream, oldStatus, newStatus)
|
||||
s.publishUpstreamHealthChangedEvent(ctx, upstream, oldStatus, newStatus)
|
||||
}
|
||||
}
|
||||
}(u)
|
||||
@@ -412,10 +392,11 @@ func (s *HealthCheckService) checkEndpoint(urlStr string, timeout time.Duration)
|
||||
}
|
||||
|
||||
func (s *HealthCheckService) performProxyChecks() {
|
||||
ctx := context.Background()
|
||||
settings := s.SettingsManager.GetSettings()
|
||||
timeout := time.Duration(settings.ProxyCheckTimeoutSeconds) * time.Second
|
||||
var proxies []*models.ProxyConfig
|
||||
if err := s.db.Find(&proxies).Error; err != nil {
|
||||
if err := s.db.WithContext(ctx).Find(&proxies).Error; err != nil {
|
||||
s.logger.WithError(err).Error("Failed to retrieve proxies.")
|
||||
return
|
||||
}
|
||||
@@ -438,7 +419,7 @@ func (s *HealthCheckService) performProxyChecks() {
|
||||
s.lastResultsMutex.Unlock()
|
||||
if proxyCfg.Status != newStatus {
|
||||
s.logger.Infof("Proxy '%s' status changed from %s -> %s", proxyCfg.Address, proxyCfg.Status, newStatus)
|
||||
if err := s.db.Model(proxyCfg).Update("status", newStatus).Error; err != nil {
|
||||
if err := s.db.WithContext(ctx).Model(proxyCfg).Update("status", newStatus).Error; err != nil {
|
||||
s.logger.WithError(err).WithField("proxy_id", proxyCfg.ID).Error("Failed to update proxy status.")
|
||||
}
|
||||
}
|
||||
@@ -482,7 +463,7 @@ func (s *HealthCheckService) checkProxy(proxyCfg *models.ProxyConfig, timeout ti
|
||||
return true
|
||||
}
|
||||
|
||||
func (s *HealthCheckService) publishKeyStatusChangedEvent(groupID, keyID uint, oldStatus, newStatus models.APIKeyStatus) {
|
||||
func (s *HealthCheckService) publishKeyStatusChangedEvent(ctx context.Context, groupID, keyID uint, oldStatus, newStatus models.APIKeyStatus) {
|
||||
event := models.KeyStatusChangedEvent{
|
||||
KeyID: keyID,
|
||||
GroupID: groupID,
|
||||
@@ -496,12 +477,12 @@ func (s *HealthCheckService) publishKeyStatusChangedEvent(groupID, keyID uint, o
|
||||
s.logger.WithError(err).Errorf("Failed to marshal KeyStatusChangedEvent for group %d.", groupID)
|
||||
return
|
||||
}
|
||||
if err := s.store.Publish(models.TopicKeyStatusChanged, payload); err != nil {
|
||||
if err := s.store.Publish(ctx, models.TopicKeyStatusChanged, payload); err != nil {
|
||||
s.logger.WithError(err).Errorf("Failed to publish KeyStatusChangedEvent for group %d.", groupID)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *HealthCheckService) publishUpstreamHealthChangedEvent(upstream *models.UpstreamEndpoint, oldStatus, newStatus string) {
|
||||
func (s *HealthCheckService) publishUpstreamHealthChangedEvent(ctx context.Context, upstream *models.UpstreamEndpoint, oldStatus, newStatus string) {
|
||||
event := models.UpstreamHealthChangedEvent{
|
||||
UpstreamID: upstream.ID,
|
||||
UpstreamURL: upstream.URL,
|
||||
@@ -516,28 +497,20 @@ func (s *HealthCheckService) publishUpstreamHealthChangedEvent(upstream *models.
|
||||
s.logger.WithError(err).Error("Failed to marshal UpstreamHealthChangedEvent.")
|
||||
return
|
||||
}
|
||||
if err := s.store.Publish(models.TopicUpstreamHealthChanged, payload); err != nil {
|
||||
if err := s.store.Publish(ctx, models.TopicUpstreamHealthChanged, payload); err != nil {
|
||||
s.logger.WithError(err).Error("Failed to publish UpstreamHealthChangedEvent.")
|
||||
}
|
||||
}
|
||||
|
||||
// =========================================================================
|
||||
// Global Base Key Check (New Logic)
|
||||
// =========================================================================
|
||||
|
||||
func (s *HealthCheckService) runBaseKeyCheckLoop() {
|
||||
defer s.wg.Done()
|
||||
s.logger.Info("Global base key check loop started.")
|
||||
settings := s.SettingsManager.GetSettings()
|
||||
|
||||
if !settings.EnableBaseKeyCheck {
|
||||
s.logger.Info("Global base key check is disabled.")
|
||||
return
|
||||
}
|
||||
|
||||
// Perform an initial check on startup
|
||||
s.performBaseKeyChecks()
|
||||
|
||||
interval := time.Duration(settings.BaseKeyCheckIntervalMinutes) * time.Minute
|
||||
if interval <= 0 {
|
||||
s.logger.Warnf("Invalid BaseKeyCheckIntervalMinutes: %d. Disabling base key check loop.", settings.BaseKeyCheckIntervalMinutes)
|
||||
@@ -558,6 +531,7 @@ func (s *HealthCheckService) runBaseKeyCheckLoop() {
|
||||
}
|
||||
|
||||
func (s *HealthCheckService) performBaseKeyChecks() {
|
||||
ctx := context.Background()
|
||||
s.logger.Info("Starting global base key check cycle.")
|
||||
settings := s.SettingsManager.GetSettings()
|
||||
timeout := time.Duration(settings.KeyCheckTimeoutSeconds) * time.Second
|
||||
@@ -576,7 +550,7 @@ func (s *HealthCheckService) performBaseKeyChecks() {
|
||||
jobs := make(chan *models.APIKey, len(keys))
|
||||
var wg sync.WaitGroup
|
||||
if concurrency <= 0 {
|
||||
concurrency = 5 // Safe default
|
||||
concurrency = 5
|
||||
}
|
||||
for w := 0; w < concurrency; w++ {
|
||||
wg.Add(1)
|
||||
@@ -587,10 +561,10 @@ func (s *HealthCheckService) performBaseKeyChecks() {
|
||||
if err != nil && CustomErrors.IsPermanentUpstreamError(err.Error()) {
|
||||
oldStatus := key.MasterStatus
|
||||
s.logger.Warnf("Key ID %d (%s...) failed definitive base check. Setting MasterStatus to REVOKED. Reason: %v", key.ID, key.APIKey[:4], err)
|
||||
if updateErr := s.keyRepo.UpdateAPIKeyStatus(key.ID, models.MasterStatusRevoked); updateErr != nil {
|
||||
if updateErr := s.keyRepo.UpdateAPIKeyStatus(ctx, key.ID, models.MasterStatusRevoked); updateErr != nil {
|
||||
s.logger.WithError(updateErr).Errorf("Failed to update master status for key ID %d.", key.ID)
|
||||
} else {
|
||||
s.publishMasterKeyStatusChangedEvent(key.ID, oldStatus, models.MasterStatusRevoked)
|
||||
s.publishMasterKeyStatusChangedEvent(ctx, key.ID, oldStatus, models.MasterStatusRevoked)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -604,8 +578,7 @@ func (s *HealthCheckService) performBaseKeyChecks() {
|
||||
s.logger.Info("Global base key check cycle finished.")
|
||||
}
|
||||
|
||||
// 事件发布辅助函数
|
||||
func (s *HealthCheckService) publishMasterKeyStatusChangedEvent(keyID uint, oldStatus, newStatus models.MasterAPIKeyStatus) {
|
||||
func (s *HealthCheckService) publishMasterKeyStatusChangedEvent(ctx context.Context, keyID uint, oldStatus, newStatus models.MasterAPIKeyStatus) {
|
||||
event := models.MasterKeyStatusChangedEvent{
|
||||
KeyID: keyID,
|
||||
OldMasterStatus: oldStatus,
|
||||
@@ -618,7 +591,7 @@ func (s *HealthCheckService) publishMasterKeyStatusChangedEvent(keyID uint, oldS
|
||||
s.logger.WithError(err).Errorf("Failed to marshal MasterKeyStatusChangedEvent for key %d.", keyID)
|
||||
return
|
||||
}
|
||||
if err := s.store.Publish(models.TopicMasterKeyStatusChanged, payload); err != nil {
|
||||
if err := s.store.Publish(ctx, models.TopicMasterKeyStatusChanged, payload); err != nil {
|
||||
s.logger.WithError(err).Errorf("Failed to publish MasterKeyStatusChangedEvent for key %d.", keyID)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"gemini-balancer/internal/models"
|
||||
@@ -42,88 +43,84 @@ func NewKeyImportService(ts task.Reporter, kr repository.KeyRepository, s store.
|
||||
}
|
||||
}
|
||||
|
||||
// --- 通用的 Panic-Safe 任務執行器 ---
|
||||
func (s *KeyImportService) runTaskWithRecovery(taskID string, resourceID string, taskFunc func()) {
|
||||
func (s *KeyImportService) runTaskWithRecovery(ctx context.Context, taskID string, resourceID string, taskFunc func()) {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
err := fmt.Errorf("panic recovered in task %s: %v", taskID, r)
|
||||
s.logger.Error(err)
|
||||
s.taskService.EndTaskByID(taskID, resourceID, nil, err)
|
||||
s.taskService.EndTaskByID(ctx, taskID, resourceID, nil, err)
|
||||
}
|
||||
}()
|
||||
taskFunc()
|
||||
}
|
||||
|
||||
// --- Public Task Starters ---
|
||||
|
||||
func (s *KeyImportService) StartAddKeysTask(groupID uint, keysText string, validateOnImport bool) (*task.Status, error) {
|
||||
func (s *KeyImportService) StartAddKeysTask(ctx context.Context, groupID uint, keysText string, validateOnImport bool) (*task.Status, error) {
|
||||
keys := utils.ParseKeysFromText(keysText)
|
||||
if len(keys) == 0 {
|
||||
return nil, fmt.Errorf("no valid keys found in input text")
|
||||
}
|
||||
resourceID := fmt.Sprintf("group-%d", groupID)
|
||||
taskStatus, err := s.taskService.StartTask(uint(groupID), TaskTypeAddKeysToGroup, resourceID, len(keys), 15*time.Minute)
|
||||
|
||||
taskStatus, err := s.taskService.StartTask(ctx, uint(groupID), TaskTypeAddKeysToGroup, resourceID, len(keys), 15*time.Minute)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
go s.runTaskWithRecovery(taskStatus.ID, resourceID, func() {
|
||||
s.runAddKeysTask(taskStatus.ID, resourceID, groupID, keys, validateOnImport)
|
||||
go s.runTaskWithRecovery(context.Background(), taskStatus.ID, resourceID, func() {
|
||||
s.runAddKeysTask(context.Background(), taskStatus.ID, resourceID, groupID, keys, validateOnImport)
|
||||
})
|
||||
return taskStatus, nil
|
||||
}
|
||||
|
||||
func (s *KeyImportService) StartUnlinkKeysTask(groupID uint, keysText string) (*task.Status, error) {
|
||||
func (s *KeyImportService) StartUnlinkKeysTask(ctx context.Context, groupID uint, keysText string) (*task.Status, error) {
|
||||
keys := utils.ParseKeysFromText(keysText)
|
||||
if len(keys) == 0 {
|
||||
return nil, fmt.Errorf("no valid keys found")
|
||||
}
|
||||
resourceID := fmt.Sprintf("group-%d", groupID)
|
||||
taskStatus, err := s.taskService.StartTask(uint(groupID), TaskTypeUnlinkKeysFromGroup, resourceID, len(keys), time.Hour)
|
||||
taskStatus, err := s.taskService.StartTask(ctx, uint(groupID), TaskTypeUnlinkKeysFromGroup, resourceID, len(keys), time.Hour)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
go s.runTaskWithRecovery(taskStatus.ID, resourceID, func() {
|
||||
s.runUnlinkKeysTask(taskStatus.ID, resourceID, groupID, keys)
|
||||
go s.runTaskWithRecovery(context.Background(), taskStatus.ID, resourceID, func() {
|
||||
s.runUnlinkKeysTask(context.Background(), taskStatus.ID, resourceID, groupID, keys)
|
||||
})
|
||||
return taskStatus, nil
|
||||
}
|
||||
|
||||
func (s *KeyImportService) StartHardDeleteKeysTask(keysText string) (*task.Status, error) {
|
||||
func (s *KeyImportService) StartHardDeleteKeysTask(ctx context.Context, keysText string) (*task.Status, error) {
|
||||
keys := utils.ParseKeysFromText(keysText)
|
||||
if len(keys) == 0 {
|
||||
return nil, fmt.Errorf("no valid keys found")
|
||||
}
|
||||
resourceID := "global_hard_delete" // Global lock
|
||||
taskStatus, err := s.taskService.StartTask(0, TaskTypeHardDeleteKeys, resourceID, len(keys), time.Hour)
|
||||
resourceID := "global_hard_delete"
|
||||
taskStatus, err := s.taskService.StartTask(ctx, 0, TaskTypeHardDeleteKeys, resourceID, len(keys), time.Hour)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
go s.runTaskWithRecovery(taskStatus.ID, resourceID, func() {
|
||||
s.runHardDeleteKeysTask(taskStatus.ID, resourceID, keys)
|
||||
go s.runTaskWithRecovery(context.Background(), taskStatus.ID, resourceID, func() {
|
||||
s.runHardDeleteKeysTask(context.Background(), taskStatus.ID, resourceID, keys)
|
||||
})
|
||||
return taskStatus, nil
|
||||
}
|
||||
|
||||
func (s *KeyImportService) StartRestoreKeysTask(keysText string) (*task.Status, error) {
|
||||
func (s *KeyImportService) StartRestoreKeysTask(ctx context.Context, keysText string) (*task.Status, error) {
|
||||
keys := utils.ParseKeysFromText(keysText)
|
||||
if len(keys) == 0 {
|
||||
return nil, fmt.Errorf("no valid keys found")
|
||||
}
|
||||
resourceID := "global_restore_keys" // Global lock
|
||||
taskStatus, err := s.taskService.StartTask(0, TaskTypeRestoreKeys, resourceID, len(keys), time.Hour)
|
||||
resourceID := "global_restore_keys"
|
||||
|
||||
taskStatus, err := s.taskService.StartTask(ctx, 0, TaskTypeRestoreKeys, resourceID, len(keys), time.Hour)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
go s.runTaskWithRecovery(taskStatus.ID, resourceID, func() {
|
||||
s.runRestoreKeysTask(taskStatus.ID, resourceID, keys)
|
||||
go s.runTaskWithRecovery(context.Background(), taskStatus.ID, resourceID, func() {
|
||||
s.runRestoreKeysTask(context.Background(), taskStatus.ID, resourceID, keys)
|
||||
})
|
||||
return taskStatus, nil
|
||||
}
|
||||
|
||||
// --- Private Task Runners ---
|
||||
|
||||
func (s *KeyImportService) runAddKeysTask(taskID string, resourceID string, groupID uint, keys []string, validateOnImport bool) {
|
||||
// 步骤 1: 对输入的原始 key 列表进行去重。
|
||||
func (s *KeyImportService) runAddKeysTask(ctx context.Context, taskID string, resourceID string, groupID uint, keys []string, validateOnImport bool) {
|
||||
uniqueKeysMap := make(map[string]struct{})
|
||||
var uniqueKeyStrings []string
|
||||
for _, kStr := range keys {
|
||||
@@ -133,41 +130,37 @@ func (s *KeyImportService) runAddKeysTask(taskID string, resourceID string, grou
|
||||
}
|
||||
}
|
||||
if len(uniqueKeyStrings) == 0 {
|
||||
s.taskService.EndTaskByID(taskID, resourceID, gin.H{"newly_linked_count": 0, "already_linked_count": 0}, nil)
|
||||
s.taskService.EndTaskByID(ctx, taskID, resourceID, gin.H{"newly_linked_count": 0, "already_linked_count": 0}, nil)
|
||||
return
|
||||
}
|
||||
// 步骤 2: 确保所有 Key 在主表中存在(创建或恢复),并获取它们完整的实体。
|
||||
keysToEnsure := make([]models.APIKey, len(uniqueKeyStrings))
|
||||
for i, keyStr := range uniqueKeyStrings {
|
||||
keysToEnsure[i] = models.APIKey{APIKey: keyStr}
|
||||
}
|
||||
allKeyModels, err := s.keyRepo.AddKeys(keysToEnsure)
|
||||
if err != nil {
|
||||
s.taskService.EndTaskByID(taskID, resourceID, nil, fmt.Errorf("failed to ensure keys exist: %w", err))
|
||||
s.taskService.EndTaskByID(ctx, taskID, resourceID, nil, fmt.Errorf("failed to ensure keys exist: %w", err))
|
||||
return
|
||||
}
|
||||
// 步骤 3: 找出在这些 Key 中,哪些【已经】被链接到了当前分组。
|
||||
alreadyLinkedModels, err := s.keyRepo.GetKeysByValuesAndGroupID(uniqueKeyStrings, groupID)
|
||||
if err != nil {
|
||||
s.taskService.EndTaskByID(taskID, resourceID, nil, fmt.Errorf("failed to check for already linked keys: %w", err))
|
||||
s.taskService.EndTaskByID(ctx, taskID, resourceID, nil, fmt.Errorf("failed to check for already linked keys: %w", err))
|
||||
return
|
||||
}
|
||||
alreadyLinkedIDSet := make(map[uint]struct{})
|
||||
for _, key := range alreadyLinkedModels {
|
||||
alreadyLinkedIDSet[key.ID] = struct{}{}
|
||||
}
|
||||
// 步骤 4: 确定【真正需要】被链接到当前分组的 key 列表 (我们的"工作集")。
|
||||
var keysToLink []models.APIKey
|
||||
for _, key := range allKeyModels {
|
||||
if _, exists := alreadyLinkedIDSet[key.ID]; !exists {
|
||||
keysToLink = append(keysToLink, key)
|
||||
}
|
||||
}
|
||||
// 步骤 5: 更新任务的 Total 总量为精确的 "工作集" 大小。
|
||||
if err := s.taskService.UpdateTotalByID(taskID, len(keysToLink)); err != nil {
|
||||
|
||||
if err := s.taskService.UpdateTotalByID(ctx, taskID, len(keysToLink)); err != nil {
|
||||
s.logger.WithError(err).Warnf("Failed to update total for task %s", taskID)
|
||||
}
|
||||
// 步骤 6: 分块处理【链接Key到组】的操作,并实时更新进度。
|
||||
if len(keysToLink) > 0 {
|
||||
idsToLink := make([]uint, len(keysToLink))
|
||||
for i, key := range keysToLink {
|
||||
@@ -179,44 +172,41 @@ func (s *KeyImportService) runAddKeysTask(taskID string, resourceID string, grou
|
||||
end = len(idsToLink)
|
||||
}
|
||||
chunk := idsToLink[i:end]
|
||||
if err := s.keyRepo.LinkKeysToGroup(groupID, chunk); err != nil {
|
||||
s.taskService.EndTaskByID(taskID, resourceID, nil, fmt.Errorf("chunk failed to link keys: %w", err))
|
||||
if err := s.keyRepo.LinkKeysToGroup(ctx, groupID, chunk); err != nil {
|
||||
s.taskService.EndTaskByID(ctx, taskID, resourceID, nil, fmt.Errorf("chunk failed to link keys: %w", err))
|
||||
return
|
||||
}
|
||||
_ = s.taskService.UpdateProgressByID(taskID, i+len(chunk))
|
||||
_ = s.taskService.UpdateProgressByID(ctx, taskID, i+len(chunk))
|
||||
}
|
||||
}
|
||||
|
||||
// 步骤 7: 准备最终结果并结束任务。
|
||||
result := gin.H{
|
||||
"newly_linked_count": len(keysToLink),
|
||||
"already_linked_count": len(alreadyLinkedIDSet),
|
||||
"total_linked_count": len(allKeyModels),
|
||||
}
|
||||
// 步骤 8: 根据 `validateOnImport` 标志, 发布事件或直接激活 (只对新链接的keys操作)。
|
||||
if len(keysToLink) > 0 {
|
||||
idsToLink := make([]uint, len(keysToLink))
|
||||
for i, key := range keysToLink {
|
||||
idsToLink[i] = key.ID
|
||||
}
|
||||
if validateOnImport {
|
||||
s.publishImportGroupCompletedEvent(groupID, idsToLink)
|
||||
s.publishImportGroupCompletedEvent(ctx, groupID, idsToLink)
|
||||
for _, keyID := range idsToLink {
|
||||
s.publishSingleKeyChangeEvent(groupID, keyID, "", models.StatusPendingValidation, "key_linked")
|
||||
s.publishSingleKeyChangeEvent(ctx, groupID, keyID, "", models.StatusPendingValidation, "key_linked")
|
||||
}
|
||||
} else {
|
||||
for _, keyID := range idsToLink {
|
||||
if _, err := s.apiKeyService.UpdateMappingStatus(groupID, keyID, models.StatusActive); err != nil {
|
||||
if _, err := s.apiKeyService.UpdateMappingStatus(ctx, groupID, keyID, models.StatusActive); err != nil {
|
||||
s.logger.Errorf("Failed to directly activate key ID %d in group %d: %v", keyID, groupID, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
s.taskService.EndTaskByID(taskID, resourceID, result, nil)
|
||||
s.taskService.EndTaskByID(ctx, taskID, resourceID, result, nil)
|
||||
}
|
||||
|
||||
// runUnlinkKeysTask
|
||||
func (s *KeyImportService) runUnlinkKeysTask(taskID, resourceID string, groupID uint, keys []string) {
|
||||
func (s *KeyImportService) runUnlinkKeysTask(ctx context.Context, taskID, resourceID string, groupID uint, keys []string) {
|
||||
uniqueKeysMap := make(map[string]struct{})
|
||||
var uniqueKeys []string
|
||||
for _, kStr := range keys {
|
||||
@@ -225,46 +215,42 @@ func (s *KeyImportService) runUnlinkKeysTask(taskID, resourceID string, groupID
|
||||
uniqueKeys = append(uniqueKeys, kStr)
|
||||
}
|
||||
}
|
||||
// 步骤 1: 一次性找出所有输入 Key 中,实际存在于本组的 Key 实体。这是我们的"工作集"。
|
||||
keysToUnlink, err := s.keyRepo.GetKeysByValuesAndGroupID(uniqueKeys, groupID)
|
||||
if err != nil {
|
||||
s.taskService.EndTaskByID(taskID, resourceID, nil, fmt.Errorf("failed to find keys to unlink: %w", err))
|
||||
s.taskService.EndTaskByID(ctx, taskID, resourceID, nil, fmt.Errorf("failed to find keys to unlink: %w", err))
|
||||
return
|
||||
}
|
||||
|
||||
if len(keysToUnlink) == 0 {
|
||||
result := gin.H{"unlinked_count": 0, "hard_deleted_count": 0, "not_found_count": len(uniqueKeys)}
|
||||
s.taskService.EndTaskByID(taskID, resourceID, result, nil)
|
||||
s.taskService.EndTaskByID(ctx, taskID, resourceID, result, nil)
|
||||
return
|
||||
}
|
||||
idsToUnlink := make([]uint, len(keysToUnlink))
|
||||
for i, key := range keysToUnlink {
|
||||
idsToUnlink[i] = key.ID
|
||||
}
|
||||
// 步骤 2: 更新任务的 Total 总量为精确的 "工作集" 大小。
|
||||
if err := s.taskService.UpdateTotalByID(taskID, len(idsToUnlink)); err != nil {
|
||||
|
||||
if err := s.taskService.UpdateTotalByID(ctx, taskID, len(idsToUnlink)); err != nil {
|
||||
s.logger.WithError(err).Warnf("Failed to update total for task %s", taskID)
|
||||
}
|
||||
var totalUnlinked int64
|
||||
// 步骤 3: 分块处理【解绑Key】的操作,并上报进度。
|
||||
for i := 0; i < len(idsToUnlink); i += chunkSize {
|
||||
end := i + chunkSize
|
||||
if end > len(idsToUnlink) {
|
||||
end = len(idsToUnlink)
|
||||
}
|
||||
chunk := idsToUnlink[i:end]
|
||||
unlinked, err := s.keyRepo.UnlinkKeysFromGroup(groupID, chunk)
|
||||
unlinked, err := s.keyRepo.UnlinkKeysFromGroup(ctx, groupID, chunk)
|
||||
if err != nil {
|
||||
s.taskService.EndTaskByID(taskID, resourceID, nil, fmt.Errorf("chunk failed: could not unlink keys: %w", err))
|
||||
s.taskService.EndTaskByID(ctx, taskID, resourceID, nil, fmt.Errorf("chunk failed: could not unlink keys: %w", err))
|
||||
return
|
||||
}
|
||||
totalUnlinked += unlinked
|
||||
|
||||
for _, keyID := range chunk {
|
||||
s.publishSingleKeyChangeEvent(groupID, keyID, models.StatusActive, "", "key_unlinked")
|
||||
s.publishSingleKeyChangeEvent(ctx, groupID, keyID, models.StatusActive, "", "key_unlinked")
|
||||
}
|
||||
|
||||
_ = s.taskService.UpdateProgressByID(taskID, i+len(chunk))
|
||||
_ = s.taskService.UpdateProgressByID(ctx, taskID, i+len(chunk))
|
||||
}
|
||||
|
||||
totalDeleted, err := s.keyRepo.DeleteOrphanKeys()
|
||||
@@ -276,10 +262,10 @@ func (s *KeyImportService) runUnlinkKeysTask(taskID, resourceID string, groupID
|
||||
"hard_deleted_count": totalDeleted,
|
||||
"not_found_count": len(uniqueKeys) - int(totalUnlinked),
|
||||
}
|
||||
s.taskService.EndTaskByID(taskID, resourceID, result, nil)
|
||||
s.taskService.EndTaskByID(ctx, taskID, resourceID, result, nil)
|
||||
}
|
||||
|
||||
func (s *KeyImportService) runHardDeleteKeysTask(taskID, resourceID string, keys []string) {
|
||||
func (s *KeyImportService) runHardDeleteKeysTask(ctx context.Context, taskID, resourceID string, keys []string) {
|
||||
var totalDeleted int64
|
||||
for i := 0; i < len(keys); i += chunkSize {
|
||||
end := i + chunkSize
|
||||
@@ -290,22 +276,21 @@ func (s *KeyImportService) runHardDeleteKeysTask(taskID, resourceID string, keys
|
||||
|
||||
deleted, err := s.keyRepo.HardDeleteByValues(chunk)
|
||||
if err != nil {
|
||||
s.taskService.EndTaskByID(taskID, resourceID, nil, fmt.Errorf("failed to hard delete chunk: %w", err))
|
||||
s.taskService.EndTaskByID(ctx, taskID, resourceID, nil, fmt.Errorf("failed to hard delete chunk: %w", err))
|
||||
return
|
||||
}
|
||||
totalDeleted += deleted
|
||||
_ = s.taskService.UpdateProgressByID(taskID, i+len(chunk))
|
||||
_ = s.taskService.UpdateProgressByID(ctx, taskID, i+len(chunk))
|
||||
}
|
||||
|
||||
result := gin.H{
|
||||
"hard_deleted_count": totalDeleted,
|
||||
"not_found_count": int64(len(keys)) - totalDeleted,
|
||||
}
|
||||
s.taskService.EndTaskByID(taskID, resourceID, result, nil)
|
||||
s.publishChangeEvent(0, "keys_hard_deleted") // Global event
|
||||
s.taskService.EndTaskByID(ctx, taskID, resourceID, result, nil)
|
||||
s.publishChangeEvent(ctx, 0, "keys_hard_deleted")
|
||||
}
|
||||
|
||||
func (s *KeyImportService) runRestoreKeysTask(taskID, resourceID string, keys []string) {
|
||||
func (s *KeyImportService) runRestoreKeysTask(ctx context.Context, taskID, resourceID string, keys []string) {
|
||||
var restoredCount int64
|
||||
for i := 0; i < len(keys); i += chunkSize {
|
||||
end := i + chunkSize
|
||||
@@ -316,21 +301,21 @@ func (s *KeyImportService) runRestoreKeysTask(taskID, resourceID string, keys []
|
||||
|
||||
count, err := s.keyRepo.UpdateMasterStatusByValues(chunk, models.MasterStatusActive)
|
||||
if err != nil {
|
||||
s.taskService.EndTaskByID(taskID, resourceID, nil, fmt.Errorf("failed to restore chunk: %w", err))
|
||||
s.taskService.EndTaskByID(ctx, taskID, resourceID, nil, fmt.Errorf("failed to restore chunk: %w", err))
|
||||
return
|
||||
}
|
||||
restoredCount += count
|
||||
_ = s.taskService.UpdateProgressByID(taskID, i+len(chunk))
|
||||
_ = s.taskService.UpdateProgressByID(ctx, taskID, i+len(chunk))
|
||||
}
|
||||
result := gin.H{
|
||||
"restored_count": restoredCount,
|
||||
"not_found_count": int64(len(keys)) - restoredCount,
|
||||
}
|
||||
s.taskService.EndTaskByID(taskID, resourceID, result, nil)
|
||||
s.publishChangeEvent(0, "keys_bulk_restored") // Global event
|
||||
s.taskService.EndTaskByID(ctx, taskID, resourceID, result, nil)
|
||||
s.publishChangeEvent(ctx, 0, "keys_bulk_restored")
|
||||
}
|
||||
|
||||
func (s *KeyImportService) publishSingleKeyChangeEvent(groupID, keyID uint, oldStatus, newStatus models.APIKeyStatus, reason string) {
|
||||
func (s *KeyImportService) publishSingleKeyChangeEvent(ctx context.Context, groupID, keyID uint, oldStatus, newStatus models.APIKeyStatus, reason string) {
|
||||
event := models.KeyStatusChangedEvent{
|
||||
GroupID: groupID,
|
||||
KeyID: keyID,
|
||||
@@ -340,7 +325,7 @@ func (s *KeyImportService) publishSingleKeyChangeEvent(groupID, keyID uint, oldS
|
||||
ChangedAt: time.Now(),
|
||||
}
|
||||
eventData, _ := json.Marshal(event)
|
||||
if err := s.store.Publish(models.TopicKeyStatusChanged, eventData); err != nil {
|
||||
if err := s.store.Publish(ctx, models.TopicKeyStatusChanged, eventData); err != nil {
|
||||
s.logger.WithError(err).WithFields(logrus.Fields{
|
||||
"group_id": groupID,
|
||||
"key_id": keyID,
|
||||
@@ -349,16 +334,16 @@ func (s *KeyImportService) publishSingleKeyChangeEvent(groupID, keyID uint, oldS
|
||||
}
|
||||
}
|
||||
|
||||
func (s *KeyImportService) publishChangeEvent(groupID uint, reason string) {
|
||||
func (s *KeyImportService) publishChangeEvent(ctx context.Context, groupID uint, reason string) {
|
||||
event := models.KeyStatusChangedEvent{
|
||||
GroupID: groupID,
|
||||
ChangeReason: reason,
|
||||
}
|
||||
eventData, _ := json.Marshal(event)
|
||||
_ = s.store.Publish(models.TopicKeyStatusChanged, eventData)
|
||||
_ = s.store.Publish(ctx, models.TopicKeyStatusChanged, eventData)
|
||||
}
|
||||
|
||||
func (s *KeyImportService) publishImportGroupCompletedEvent(groupID uint, keyIDs []uint) {
|
||||
func (s *KeyImportService) publishImportGroupCompletedEvent(ctx context.Context, groupID uint, keyIDs []uint) {
|
||||
if len(keyIDs) == 0 {
|
||||
return
|
||||
}
|
||||
@@ -372,17 +357,15 @@ func (s *KeyImportService) publishImportGroupCompletedEvent(groupID uint, keyIDs
|
||||
s.logger.WithError(err).Error("Failed to marshal ImportGroupCompletedEvent.")
|
||||
return
|
||||
}
|
||||
if err := s.store.Publish(models.TopicImportGroupCompleted, eventData); err != nil {
|
||||
if err := s.store.Publish(ctx, models.TopicImportGroupCompleted, eventData); err != nil {
|
||||
s.logger.WithError(err).Error("Failed to publish ImportGroupCompletedEvent.")
|
||||
} else {
|
||||
s.logger.Infof("Published ImportGroupCompletedEvent for group %d with %d keys.", groupID, len(keyIDs))
|
||||
}
|
||||
}
|
||||
|
||||
// [NEW] StartUnlinkKeysByFilterTask starts a task to unlink keys matching a status filter.
|
||||
func (s *KeyImportService) StartUnlinkKeysByFilterTask(groupID uint, statuses []string) (*task.Status, error) {
|
||||
func (s *KeyImportService) StartUnlinkKeysByFilterTask(ctx context.Context, groupID uint, statuses []string) (*task.Status, error) {
|
||||
s.logger.Infof("Starting unlink task for group %d with status filter: %v", groupID, statuses)
|
||||
// 1. [New] Find the keys to operate on.
|
||||
keyValues, err := s.keyRepo.FindKeyValuesByStatus(groupID, statuses)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to find keys by filter: %w", err)
|
||||
@@ -390,8 +373,7 @@ func (s *KeyImportService) StartUnlinkKeysByFilterTask(groupID uint, statuses []
|
||||
if len(keyValues) == 0 {
|
||||
return nil, fmt.Errorf("no keys found matching the provided filter")
|
||||
}
|
||||
// 2. [REUSE] Convert to text and call the existing, robust unlink task logic.
|
||||
keysAsText := strings.Join(keyValues, "\n")
|
||||
s.logger.Infof("Found %d keys to unlink for group %d.", len(keyValues), groupID)
|
||||
return s.StartUnlinkKeysTask(groupID, keysAsText)
|
||||
}
|
||||
return s.StartUnlinkKeysTask(ctx, groupID, keysAsText)
|
||||
}
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"gemini-balancer/internal/channel"
|
||||
@@ -62,20 +63,18 @@ func (s *KeyValidationService) ValidateSingleKey(key *models.APIKey, timeout tim
|
||||
return fmt.Errorf("failed to create request: %w", err)
|
||||
}
|
||||
|
||||
s.channel.ModifyRequest(req, key) // Use the injected channel to modify the request
|
||||
s.channel.ModifyRequest(req, key)
|
||||
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
// This is a network-level error (e.g., timeout, DNS issue)
|
||||
return fmt.Errorf("request failed: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode == http.StatusOK {
|
||||
return nil // Success
|
||||
return nil
|
||||
}
|
||||
|
||||
// Read the body for more error details
|
||||
bodyBytes, readErr := io.ReadAll(resp.Body)
|
||||
var errorMsg string
|
||||
if readErr != nil {
|
||||
@@ -84,7 +83,6 @@ func (s *KeyValidationService) ValidateSingleKey(key *models.APIKey, timeout tim
|
||||
errorMsg = string(bodyBytes)
|
||||
}
|
||||
|
||||
// This is a validation failure with a specific HTTP status code
|
||||
return &CustomErrors.APIError{
|
||||
HTTPStatus: resp.StatusCode,
|
||||
Message: fmt.Sprintf("Validation failed with status %d: %s", resp.StatusCode, errorMsg),
|
||||
@@ -92,8 +90,7 @@ func (s *KeyValidationService) ValidateSingleKey(key *models.APIKey, timeout tim
|
||||
}
|
||||
}
|
||||
|
||||
// --- 异步任务方法 (全面适配新task包) ---
|
||||
func (s *KeyValidationService) StartTestKeysTask(groupID uint, keysText string) (*task.Status, error) {
|
||||
func (s *KeyValidationService) StartTestKeysTask(ctx context.Context, groupID uint, keysText string) (*task.Status, error) {
|
||||
keyStrings := utils.ParseKeysFromText(keysText)
|
||||
if len(keyStrings) == 0 {
|
||||
return nil, CustomErrors.NewAPIError(CustomErrors.ErrBadRequest, "no valid keys found in input text")
|
||||
@@ -111,7 +108,6 @@ func (s *KeyValidationService) StartTestKeysTask(groupID uint, keysText string)
|
||||
}
|
||||
group, ok := s.groupManager.GetGroupByID(groupID)
|
||||
if !ok {
|
||||
// [FIX] Correctly use the NewAPIError constructor for a missing group.
|
||||
return nil, CustomErrors.NewAPIError(CustomErrors.ErrGroupNotFound, fmt.Sprintf("group with id %d not found", groupID))
|
||||
}
|
||||
opConfig, err := s.groupManager.BuildOperationalConfig(group)
|
||||
@@ -119,15 +115,15 @@ func (s *KeyValidationService) StartTestKeysTask(groupID uint, keysText string)
|
||||
return nil, CustomErrors.NewAPIError(CustomErrors.ErrInternalServer, fmt.Sprintf("failed to build operational config: %v", err))
|
||||
}
|
||||
resourceID := fmt.Sprintf("group-%d", groupID)
|
||||
taskStatus, err := s.taskService.StartTask(groupID, TaskTypeTestKeys, resourceID, len(apiKeyModels), time.Hour)
|
||||
taskStatus, err := s.taskService.StartTask(ctx, groupID, TaskTypeTestKeys, resourceID, len(apiKeyModels), time.Hour)
|
||||
if err != nil {
|
||||
return nil, err // Pass up the error from task service (e.g., "task already running")
|
||||
return nil, err
|
||||
}
|
||||
settings := s.SettingsManager.GetSettings()
|
||||
timeout := time.Duration(settings.KeyCheckTimeoutSeconds) * time.Second
|
||||
endpoint, err := s.groupManager.BuildKeyCheckEndpoint(groupID)
|
||||
if err != nil {
|
||||
s.taskService.EndTaskByID(taskStatus.ID, resourceID, nil, err) // End task with error if endpoint fails
|
||||
s.taskService.EndTaskByID(ctx, taskStatus.ID, resourceID, nil, err)
|
||||
return nil, err
|
||||
}
|
||||
var concurrency int
|
||||
@@ -136,11 +132,11 @@ func (s *KeyValidationService) StartTestKeysTask(groupID uint, keysText string)
|
||||
} else {
|
||||
concurrency = settings.BaseKeyCheckConcurrency
|
||||
}
|
||||
go s.runTestKeysTask(taskStatus.ID, resourceID, groupID, apiKeyModels, timeout, endpoint, concurrency)
|
||||
go s.runTestKeysTask(context.Background(), taskStatus.ID, resourceID, groupID, apiKeyModels, timeout, endpoint, concurrency)
|
||||
return taskStatus, nil
|
||||
}
|
||||
|
||||
func (s *KeyValidationService) runTestKeysTask(taskID string, resourceID string, groupID uint, keys []models.APIKey, timeout time.Duration, endpoint string, concurrency int) {
|
||||
func (s *KeyValidationService) runTestKeysTask(ctx context.Context, taskID string, resourceID string, groupID uint, keys []models.APIKey, timeout time.Duration, endpoint string, concurrency int) {
|
||||
var wg sync.WaitGroup
|
||||
var mu sync.Mutex
|
||||
finalResults := make([]models.KeyTestResult, len(keys))
|
||||
@@ -165,7 +161,6 @@ func (s *KeyValidationService) runTestKeysTask(taskID string, resourceID string,
|
||||
var currentResult models.KeyTestResult
|
||||
event := models.RequestFinishedEvent{
|
||||
RequestLog: models.RequestLog{
|
||||
// GroupID 和 KeyID 在 RequestLog 模型中是指针,需要取地址
|
||||
GroupID: &groupID,
|
||||
KeyID: &apiKeyModel.ID,
|
||||
},
|
||||
@@ -185,14 +180,15 @@ func (s *KeyValidationService) runTestKeysTask(taskID string, resourceID string,
|
||||
event.RequestLog.IsSuccess = false
|
||||
}
|
||||
eventData, _ := json.Marshal(event)
|
||||
if err := s.store.Publish(models.TopicRequestFinished, eventData); err != nil {
|
||||
|
||||
if err := s.store.Publish(ctx, models.TopicRequestFinished, eventData); err != nil {
|
||||
s.logger.WithError(err).Errorf("Failed to publish RequestFinishedEvent for validation of key ID %d", apiKeyModel.ID)
|
||||
}
|
||||
|
||||
mu.Lock()
|
||||
finalResults[j.Index] = currentResult
|
||||
processedCount++
|
||||
_ = s.taskService.UpdateProgressByID(taskID, processedCount)
|
||||
_ = s.taskService.UpdateProgressByID(ctx, taskID, processedCount)
|
||||
mu.Unlock()
|
||||
}
|
||||
}()
|
||||
@@ -202,10 +198,10 @@ func (s *KeyValidationService) runTestKeysTask(taskID string, resourceID string,
|
||||
}
|
||||
close(jobs)
|
||||
wg.Wait()
|
||||
s.taskService.EndTaskByID(taskID, resourceID, gin.H{"results": finalResults}, nil)
|
||||
s.taskService.EndTaskByID(ctx, taskID, resourceID, gin.H{"results": finalResults}, nil)
|
||||
}
|
||||
|
||||
func (s *KeyValidationService) StartTestKeysByFilterTask(groupID uint, statuses []string) (*task.Status, error) {
|
||||
func (s *KeyValidationService) StartTestKeysByFilterTask(ctx context.Context, groupID uint, statuses []string) (*task.Status, error) {
|
||||
s.logger.Infof("Starting test task for group %d with status filter: %v", groupID, statuses)
|
||||
keyValues, err := s.keyRepo.FindKeyValuesByStatus(groupID, statuses)
|
||||
if err != nil {
|
||||
@@ -216,5 +212,5 @@ func (s *KeyValidationService) StartTestKeysByFilterTask(groupID uint, statuses
|
||||
}
|
||||
keysAsText := strings.Join(keyValues, "\n")
|
||||
s.logger.Infof("Found %d keys to validate for group %d.", len(keyValues), groupID)
|
||||
return s.StartTestKeysTask(groupID, keysAsText)
|
||||
return s.StartTestKeysTask(ctx, groupID, keysAsText)
|
||||
}
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
// Filename: internal/service/resource_service.go
|
||||
|
||||
package service
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"context"
|
||||
"gemini-balancer/internal/domain/proxy"
|
||||
apperrors "gemini-balancer/internal/errors"
|
||||
"gemini-balancer/internal/models"
|
||||
"gemini-balancer/internal/repository"
|
||||
@@ -15,10 +15,7 @@ import (
|
||||
"github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrNoResourceAvailable = errors.New("no available resource found for the request")
|
||||
)
|
||||
|
||||
// RequestResources 封装了一次成功请求所需的所有资源。
|
||||
type RequestResources struct {
|
||||
KeyGroup *models.KeyGroup
|
||||
APIKey *models.APIKey
|
||||
@@ -27,86 +24,92 @@ type RequestResources struct {
|
||||
RequestConfig *models.RequestConfig
|
||||
}
|
||||
|
||||
// ResourceService 负责根据请求参数和业务规则,动态地选择和分配API密钥及相关资源。
|
||||
type ResourceService struct {
|
||||
settingsManager *settings.SettingsManager
|
||||
groupManager *GroupManager
|
||||
keyRepo repository.KeyRepository
|
||||
authTokenRepo repository.AuthTokenRepository
|
||||
apiKeyService *APIKeyService
|
||||
proxyManager *proxy.Module
|
||||
logger *logrus.Entry
|
||||
initOnce sync.Once
|
||||
}
|
||||
|
||||
// NewResourceService 创建并初始化一个新的 ResourceService 实例。
|
||||
func NewResourceService(
|
||||
sm *settings.SettingsManager,
|
||||
gm *GroupManager,
|
||||
kr repository.KeyRepository,
|
||||
atr repository.AuthTokenRepository,
|
||||
aks *APIKeyService,
|
||||
pm *proxy.Module,
|
||||
logger *logrus.Logger,
|
||||
) *ResourceService {
|
||||
logger.Debugf("[FORENSIC PROBE | INJECTION | ResourceService] Received 'keyRepo' param. Fingerprint: %p", kr)
|
||||
rs := &ResourceService{
|
||||
settingsManager: sm,
|
||||
groupManager: gm,
|
||||
keyRepo: kr,
|
||||
authTokenRepo: atr,
|
||||
apiKeyService: aks,
|
||||
proxyManager: pm,
|
||||
logger: logger.WithField("component", "ResourceService📦️"),
|
||||
}
|
||||
|
||||
// 使用 sync.Once 确保预热任务在服务生命周期内仅执行一次
|
||||
rs.initOnce.Do(func() {
|
||||
go rs.preWarmCache(logger)
|
||||
go rs.preWarmCache()
|
||||
})
|
||||
return rs
|
||||
|
||||
}
|
||||
|
||||
// --- [模式一:智能聚合模式] ---
|
||||
func (s *ResourceService) GetResourceFromBasePool(authToken *models.AuthToken, modelName string) (*RequestResources, error) {
|
||||
// GetResourceFromBasePool 使用智能聚合池模式获取资源。
|
||||
func (s *ResourceService) GetResourceFromBasePool(ctx context.Context, authToken *models.AuthToken, modelName string) (*RequestResources, error) {
|
||||
log := s.logger.WithFields(logrus.Fields{"token_id": authToken.ID, "model_name": modelName, "mode": "BasePool"})
|
||||
log.Debug("Entering BasePool resource acquisition.")
|
||||
// 1.筛选出所有符合条件的候选组,并按优先级排序
|
||||
candidateGroups := s.filterAndSortCandidateGroups(modelName, authToken.AllowedGroups)
|
||||
|
||||
candidateGroups := s.filterAndSortCandidateGroups(modelName, authToken)
|
||||
if len(candidateGroups) == 0 {
|
||||
log.Warn("No candidate groups found for BasePool construction.")
|
||||
return nil, apperrors.ErrNoKeysAvailable
|
||||
}
|
||||
// 2.从 BasePool中,根据系统全局策略选择一个Key
|
||||
|
||||
basePool := &repository.BasePool{
|
||||
CandidateGroups: candidateGroups,
|
||||
PollingStrategy: s.settingsManager.GetSettings().PollingStrategy,
|
||||
}
|
||||
apiKey, selectedGroup, err := s.keyRepo.SelectOneActiveKeyFromBasePool(basePool)
|
||||
|
||||
apiKey, selectedGroup, err := s.keyRepo.SelectOneActiveKeyFromBasePool(ctx, basePool)
|
||||
if err != nil {
|
||||
log.WithError(err).Warn("Failed to select a key from the BasePool.")
|
||||
return nil, apperrors.ErrNoKeysAvailable
|
||||
}
|
||||
// 3. 组装最终资源
|
||||
// [关键] 在此模式下,RequestConfig 永远是空的,以保证透明性。
|
||||
|
||||
resources, err := s.assembleRequestResources(selectedGroup, apiKey)
|
||||
if err != nil {
|
||||
log.WithError(err).Error("Failed to assemble resources after selecting key from BasePool.")
|
||||
return nil, err
|
||||
}
|
||||
resources.RequestConfig = &models.RequestConfig{} // 强制为空
|
||||
resources.RequestConfig = &models.RequestConfig{} // BasePool 模式使用默认请求配置
|
||||
|
||||
log.Infof("Successfully selected KeyID %d from GroupID %d for the BasePool.", apiKey.ID, selectedGroup.ID)
|
||||
return resources, nil
|
||||
}
|
||||
|
||||
// --- [模式二:精确路由模式] ---
|
||||
func (s *ResourceService) GetResourceFromGroup(authToken *models.AuthToken, groupName string) (*RequestResources, error) {
|
||||
// GetResourceFromGroup 使用精确路由模式(指定密钥组)获取资源。
|
||||
func (s *ResourceService) GetResourceFromGroup(ctx context.Context, authToken *models.AuthToken, groupName string) (*RequestResources, error) {
|
||||
log := s.logger.WithFields(logrus.Fields{"token_id": authToken.ID, "group_name": groupName, "mode": "PreciseRoute"})
|
||||
log.Debug("Entering PreciseRoute resource acquisition.")
|
||||
|
||||
targetGroup, ok := s.groupManager.GetGroupByName(groupName)
|
||||
|
||||
if !ok {
|
||||
return nil, apperrors.NewAPIError(apperrors.ErrGroupNotFound, "The specified group does not exist.")
|
||||
}
|
||||
|
||||
if !s.isTokenAllowedForGroup(authToken, targetGroup.ID) {
|
||||
return nil, apperrors.NewAPIError(apperrors.ErrPermissionDenied, "Token does not have permission to access this group.")
|
||||
}
|
||||
|
||||
apiKey, _, err := s.keyRepo.SelectOneActiveKey(targetGroup)
|
||||
apiKey, _, err := s.keyRepo.SelectOneActiveKey(ctx, targetGroup)
|
||||
if err != nil {
|
||||
log.WithError(err).Warn("Failed to select a key from the precisely targeted group.")
|
||||
return nil, apperrors.ErrNoKeysAvailable
|
||||
@@ -117,39 +120,39 @@ func (s *ResourceService) GetResourceFromGroup(authToken *models.AuthToken, grou
|
||||
log.WithError(err).Error("Failed to assemble resources for precise route.")
|
||||
return nil, err
|
||||
}
|
||||
resources.RequestConfig = targetGroup.RequestConfig
|
||||
resources.RequestConfig = targetGroup.RequestConfig // 精确路由使用该组的特定请求配置
|
||||
|
||||
log.Infof("Successfully selected KeyID %d by precise routing to GroupID %d.", apiKey.ID, targetGroup.ID)
|
||||
return resources, nil
|
||||
}
|
||||
|
||||
// GetAllowedModelsForToken 获取指定认证令牌有权访问的所有模型名称列表。
|
||||
func (s *ResourceService) GetAllowedModelsForToken(authToken *models.AuthToken) []string {
|
||||
allGroups := s.groupManager.GetAllGroups()
|
||||
if len(allGroups) == 0 {
|
||||
return []string{}
|
||||
}
|
||||
allowedModelsSet := make(map[string]struct{})
|
||||
|
||||
allowedGroupIDs := make(map[uint]bool)
|
||||
if authToken.IsAdmin {
|
||||
for _, group := range allGroups {
|
||||
for _, modelMapping := range group.AllowedModels {
|
||||
|
||||
allowedModelsSet[modelMapping.ModelName] = struct{}{}
|
||||
}
|
||||
allowedGroupIDs[group.ID] = true
|
||||
}
|
||||
} else {
|
||||
allowedGroupIDs := make(map[uint]bool)
|
||||
for _, ag := range authToken.AllowedGroups {
|
||||
allowedGroupIDs[ag.ID] = true
|
||||
}
|
||||
for _, group := range allGroups {
|
||||
if _, ok := allowedGroupIDs[group.ID]; ok {
|
||||
for _, modelMapping := range group.AllowedModels {
|
||||
}
|
||||
|
||||
allowedModelsSet[modelMapping.ModelName] = struct{}{}
|
||||
}
|
||||
allowedModelsSet := make(map[string]struct{})
|
||||
for _, group := range allGroups {
|
||||
if allowedGroupIDs[group.ID] {
|
||||
for _, modelMapping := range group.AllowedModels {
|
||||
allowedModelsSet[modelMapping.ModelName] = struct{}{}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
result := make([]string, 0, len(allowedModelsSet))
|
||||
for modelName := range allowedModelsSet {
|
||||
result = append(result, modelName)
|
||||
@@ -158,20 +161,52 @@ func (s *ResourceService) GetAllowedModelsForToken(authToken *models.AuthToken)
|
||||
return result
|
||||
}
|
||||
|
||||
// ReportRequestResult 向 APIKeyService 报告请求的最终结果,以便更新密钥状态。
|
||||
func (s *ResourceService) ReportRequestResult(resources *RequestResources, success bool, apiErr *apperrors.APIError) {
|
||||
if resources == nil || resources.KeyGroup == nil || resources.APIKey == nil {
|
||||
return
|
||||
}
|
||||
s.apiKeyService.HandleRequestResult(resources.KeyGroup, resources.APIKey, success, apiErr)
|
||||
}
|
||||
|
||||
// --- 私有辅助方法 ---
|
||||
|
||||
// preWarmCache 在后台执行一次性的缓存预热任务。
|
||||
func (s *ResourceService) preWarmCache() {
|
||||
time.Sleep(2 * time.Second) // 等待其他服务组件可能完成初始化
|
||||
s.logger.Info("Performing initial key cache pre-warming...")
|
||||
|
||||
// 强制加载 GroupManager 缓存
|
||||
s.logger.Info("Pre-warming GroupManager cache...")
|
||||
_ = s.groupManager.GetAllGroups()
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 60*time.Second) // 给予更长的超时
|
||||
defer cancel()
|
||||
|
||||
if err := s.keyRepo.LoadAllKeysToStore(ctx); err != nil {
|
||||
s.logger.WithError(err).Error("Failed to perform initial key cache pre-warming.")
|
||||
} else {
|
||||
s.logger.Info("Initial key cache pre-warming completed successfully.")
|
||||
}
|
||||
}
|
||||
|
||||
// assembleRequestResources 根据密钥组和API密钥组装最终的资源对象。
|
||||
func (s *ResourceService) assembleRequestResources(group *models.KeyGroup, apiKey *models.APIKey) (*RequestResources, error) {
|
||||
selectedUpstream := s.selectUpstreamForGroup(group)
|
||||
if selectedUpstream == nil {
|
||||
return nil, apperrors.NewAPIError(apperrors.ErrConfigurationError, "Selected group has no valid upstream and no global default is set.")
|
||||
}
|
||||
var proxyConfig *models.ProxyConfig
|
||||
// [注意] 代理逻辑需要一个 proxyModule 实例,我们暂时置空。后续需要重新注入依赖。
|
||||
// if group.EnableProxy && s.proxyModule != nil {
|
||||
// var err error
|
||||
// proxyConfig, err = s.proxyModule.AssignProxyIfNeeded(apiKey)
|
||||
// if err != nil {
|
||||
// s.logger.WithError(err).Warnf("Failed to assign proxy for API key %d.", apiKey.ID)
|
||||
// }
|
||||
// }
|
||||
var err error
|
||||
// 只有在组明确启用代理时,才为其分配代理
|
||||
if group.EnableProxy {
|
||||
proxyConfig, err = s.proxyManager.AssignProxyIfNeeded(apiKey)
|
||||
if err != nil {
|
||||
s.logger.WithError(err).Errorf("Group '%s' (ID: %d) requires a proxy, but failed to assign one for KeyID %d", group.Name, group.ID, apiKey.ID)
|
||||
// 根据业务需求,这里必须返回错误,因为代理是该组的强制要求
|
||||
return nil, apperrors.NewAPIError(apperrors.ErrProxyNotAvailable, "Required proxy is not available for this request.")
|
||||
}
|
||||
}
|
||||
return &RequestResources{
|
||||
KeyGroup: group,
|
||||
APIKey: apiKey,
|
||||
@@ -180,8 +215,10 @@ func (s *ResourceService) assembleRequestResources(group *models.KeyGroup, apiKe
|
||||
}, nil
|
||||
}
|
||||
|
||||
// selectUpstreamForGroup 为指定的密钥组选择一个上游端点。
|
||||
func (s *ResourceService) selectUpstreamForGroup(group *models.KeyGroup) *models.UpstreamEndpoint {
|
||||
if len(group.AllowedUpstreams) > 0 {
|
||||
// (未来可扩展负载均衡逻辑)
|
||||
return group.AllowedUpstreams[0]
|
||||
}
|
||||
globalSettings := s.settingsManager.GetSettings()
|
||||
@@ -191,62 +228,39 @@ func (s *ResourceService) selectUpstreamForGroup(group *models.KeyGroup) *models
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *ResourceService) preWarmCache(logger *logrus.Logger) error {
|
||||
time.Sleep(2 * time.Second)
|
||||
s.logger.Info("Performing initial key cache pre-warming...")
|
||||
if err := s.keyRepo.LoadAllKeysToStore(); err != nil {
|
||||
logger.WithError(err).Error("Failed to perform initial key cache pre-warming.")
|
||||
return err
|
||||
}
|
||||
s.logger.Info("Initial key cache pre-warming completed successfully.")
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *ResourceService) GetResourcesForRequest(modelName string, allowedGroups []*models.KeyGroup) (*RequestResources, error) {
|
||||
return nil, errors.New("GetResourcesForRequest is deprecated; use GetResourceFromBasePool or GetResourceFromGroup")
|
||||
}
|
||||
|
||||
func (s *ResourceService) filterAndSortCandidateGroups(modelName string, allowedGroupsFromToken []*models.KeyGroup) []*models.KeyGroup {
|
||||
// filterAndSortCandidateGroups 根据模型名称和令牌权限,筛选并排序出合格的候选密钥组。
|
||||
func (s *ResourceService) filterAndSortCandidateGroups(modelName string, authToken *models.AuthToken) []*models.KeyGroup {
|
||||
allGroupsFromCache := s.groupManager.GetAllGroups()
|
||||
var candidateGroups []*models.KeyGroup
|
||||
// 1. 确定权限范围
|
||||
allowedGroupIDs := make(map[uint]bool)
|
||||
isTokenRestricted := len(allowedGroupsFromToken) > 0
|
||||
if isTokenRestricted {
|
||||
for _, ag := range allowedGroupsFromToken {
|
||||
allowedGroupIDs[ag.ID] = true
|
||||
}
|
||||
}
|
||||
// 2. 筛选
|
||||
|
||||
for _, group := range allGroupsFromCache {
|
||||
// 检查Token权限
|
||||
if isTokenRestricted && !allowedGroupIDs[group.ID] {
|
||||
// 检查令牌权限
|
||||
if !s.isTokenAllowedForGroup(authToken, group.ID) {
|
||||
continue
|
||||
}
|
||||
// 检查模型是否被允许
|
||||
isModelAllowed := false
|
||||
if len(group.AllowedModels) == 0 { // 如果组不限制模型,则允许
|
||||
isModelAllowed = true
|
||||
} else {
|
||||
for _, m := range group.AllowedModels {
|
||||
if m.ModelName == modelName {
|
||||
isModelAllowed = true
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
if isModelAllowed {
|
||||
// 检查模型支持情况 (如果组内未限制模型,则默认支持所有模型)
|
||||
if len(group.AllowedModels) == 0 || s.groupSupportsModel(group, modelName) {
|
||||
candidateGroups = append(candidateGroups, group)
|
||||
}
|
||||
}
|
||||
|
||||
// 3.按 Order 字段升序排序
|
||||
sort.SliceStable(candidateGroups, func(i, j int) bool {
|
||||
return candidateGroups[i].Order < candidateGroups[j].Order
|
||||
})
|
||||
return candidateGroups
|
||||
}
|
||||
|
||||
// groupSupportsModel 检查指定的密钥组是否支持给定的模型名称。
|
||||
func (s *ResourceService) groupSupportsModel(group *models.KeyGroup, modelName string) bool {
|
||||
for _, m := range group.AllowedModels {
|
||||
if m.ModelName == modelName {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// isTokenAllowedForGroup 检查指定的认证令牌是否有权访问给定的密钥组。
|
||||
func (s *ResourceService) isTokenAllowedForGroup(authToken *models.AuthToken, groupID uint) bool {
|
||||
if authToken.IsAdmin {
|
||||
return true
|
||||
@@ -258,10 +272,3 @@ func (s *ResourceService) isTokenAllowedForGroup(authToken *models.AuthToken, gr
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (s *ResourceService) ReportRequestResult(resources *RequestResources, success bool, apiErr *apperrors.APIError) {
|
||||
if resources == nil || resources.KeyGroup == nil || resources.APIKey == nil {
|
||||
return
|
||||
}
|
||||
s.apiKeyService.HandleRequestResult(resources.KeyGroup, resources.APIKey, success, apiErr)
|
||||
}
|
||||
|
||||
@@ -52,7 +52,7 @@ func (s *SecurityService) AuthenticateToken(tokenValue string) (*models.AuthToke
|
||||
// IsIPBanned
|
||||
func (s *SecurityService) IsIPBanned(ctx context.Context, ip string) (bool, error) {
|
||||
banKey := fmt.Sprintf("banned_ip:%s", ip)
|
||||
return s.store.Exists(banKey)
|
||||
return s.store.Exists(ctx, banKey)
|
||||
}
|
||||
|
||||
// RecordFailedLoginAttempt
|
||||
@@ -61,7 +61,7 @@ func (s *SecurityService) RecordFailedLoginAttempt(ctx context.Context, ip strin
|
||||
return nil
|
||||
}
|
||||
|
||||
count, err := s.store.HIncrBy(loginAttemptsKey, ip, 1)
|
||||
count, err := s.store.HIncrBy(ctx, loginAttemptsKey, ip, 1)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -71,12 +71,12 @@ func (s *SecurityService) RecordFailedLoginAttempt(ctx context.Context, ip strin
|
||||
banDuration := s.SettingsManager.GetIPBanDuration()
|
||||
banKey := fmt.Sprintf("banned_ip:%s", ip)
|
||||
|
||||
if err := s.store.Set(banKey, []byte("1"), banDuration); err != nil {
|
||||
if err := s.store.Set(ctx, banKey, []byte("1"), banDuration); err != nil {
|
||||
return err
|
||||
}
|
||||
s.logger.Warnf("IP BANNED: IP [%s] has been banned for %v due to excessive failed login attempts.", ip, banDuration)
|
||||
|
||||
s.store.HDel(loginAttemptsKey, ip)
|
||||
s.store.HDel(ctx, loginAttemptsKey, ip)
|
||||
}
|
||||
|
||||
return nil
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"gemini-balancer/internal/models"
|
||||
@@ -34,7 +35,7 @@ func NewStatsService(db *gorm.DB, s store.Store, repo repository.KeyRepository,
|
||||
|
||||
func (s *StatsService) Start() {
|
||||
s.logger.Info("Starting event listener for stats maintenance.")
|
||||
sub, err := s.store.Subscribe(models.TopicKeyStatusChanged)
|
||||
sub, err := s.store.Subscribe(context.Background(), models.TopicKeyStatusChanged)
|
||||
if err != nil {
|
||||
s.logger.Fatalf("Failed to subscribe to topic %s: %v", models.TopicKeyStatusChanged, err)
|
||||
return
|
||||
@@ -67,42 +68,43 @@ func (s *StatsService) handleKeyStatusChange(event *models.KeyStatusChangedEvent
|
||||
s.logger.Warnf("Received KeyStatusChangedEvent with no GroupID. Reason: %s, KeyID: %d. Skipping.", event.ChangeReason, event.KeyID)
|
||||
return
|
||||
}
|
||||
ctx := context.Background()
|
||||
statsKey := fmt.Sprintf("stats:group:%d", event.GroupID)
|
||||
s.logger.Infof("Handling key status change for Group %d, KeyID: %d, Reason: %s", event.GroupID, event.KeyID, event.ChangeReason)
|
||||
|
||||
switch event.ChangeReason {
|
||||
case "key_unlinked", "key_hard_deleted":
|
||||
if event.OldStatus != "" {
|
||||
s.store.HIncrBy(statsKey, "total_keys", -1)
|
||||
s.store.HIncrBy(statsKey, fmt.Sprintf("%s_keys", event.OldStatus), -1)
|
||||
s.store.HIncrBy(ctx, statsKey, "total_keys", -1)
|
||||
s.store.HIncrBy(ctx, statsKey, fmt.Sprintf("%s_keys", event.OldStatus), -1)
|
||||
} else {
|
||||
s.logger.Warnf("Received '%s' event for group %d without OldStatus, forcing recalculation.", event.ChangeReason, event.GroupID)
|
||||
s.RecalculateGroupKeyStats(event.GroupID)
|
||||
s.RecalculateGroupKeyStats(ctx, event.GroupID)
|
||||
}
|
||||
case "key_linked":
|
||||
if event.NewStatus != "" {
|
||||
s.store.HIncrBy(statsKey, "total_keys", 1)
|
||||
s.store.HIncrBy(statsKey, fmt.Sprintf("%s_keys", event.NewStatus), 1)
|
||||
s.store.HIncrBy(ctx, statsKey, "total_keys", 1)
|
||||
s.store.HIncrBy(ctx, statsKey, fmt.Sprintf("%s_keys", event.NewStatus), 1)
|
||||
} else {
|
||||
s.logger.Warnf("Received 'key_linked' event for group %d without NewStatus, forcing recalculation.", event.GroupID)
|
||||
s.RecalculateGroupKeyStats(event.GroupID)
|
||||
s.RecalculateGroupKeyStats(ctx, event.GroupID)
|
||||
}
|
||||
case "manual_update", "error_threshold_reached", "key_recovered", "invalid_api_key":
|
||||
s.store.HIncrBy(statsKey, fmt.Sprintf("%s_keys", event.OldStatus), -1)
|
||||
s.store.HIncrBy(statsKey, fmt.Sprintf("%s_keys", event.NewStatus), 1)
|
||||
s.store.HIncrBy(ctx, statsKey, fmt.Sprintf("%s_keys", event.OldStatus), -1)
|
||||
s.store.HIncrBy(ctx, statsKey, fmt.Sprintf("%s_keys", event.NewStatus), 1)
|
||||
default:
|
||||
s.logger.Warnf("Unhandled event reason '%s' for group %d, forcing recalculation.", event.ChangeReason, event.GroupID)
|
||||
s.RecalculateGroupKeyStats(event.GroupID)
|
||||
s.RecalculateGroupKeyStats(ctx, event.GroupID)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *StatsService) RecalculateGroupKeyStats(groupID uint) error {
|
||||
func (s *StatsService) RecalculateGroupKeyStats(ctx context.Context, groupID uint) error {
|
||||
s.logger.Warnf("Performing full recalculation for group %d key stats.", groupID)
|
||||
var results []struct {
|
||||
Status models.APIKeyStatus
|
||||
Count int64
|
||||
}
|
||||
if err := s.db.Model(&models.GroupAPIKeyMapping{}).
|
||||
if err := s.db.WithContext(ctx).Model(&models.GroupAPIKeyMapping{}).
|
||||
Where("key_group_id = ?", groupID).
|
||||
Select("status, COUNT(*) as count").
|
||||
Group("status").
|
||||
@@ -119,37 +121,25 @@ func (s *StatsService) RecalculateGroupKeyStats(groupID uint) error {
|
||||
}
|
||||
updates["total_keys"] = totalKeys
|
||||
|
||||
if err := s.store.Del(statsKey); err != nil {
|
||||
if err := s.store.Del(ctx, statsKey); err != nil {
|
||||
s.logger.WithError(err).Warnf("Failed to delete stale stats key for group %d before recalculation.", groupID)
|
||||
}
|
||||
if err := s.store.HSet(statsKey, updates); err != nil {
|
||||
if err := s.store.HSet(ctx, statsKey, updates); err != nil {
|
||||
return fmt.Errorf("failed to HSet recalculated stats for group %d: %w", groupID, err)
|
||||
}
|
||||
s.logger.Infof("Successfully recalculated stats for group %d using HSet.", groupID)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *StatsService) GetDashboardStats() (*models.DashboardStatsResponse, error) {
|
||||
// TODO 逻辑:
|
||||
// 1. 从Redis中获取所有分组的Key统计 (HGetAll)
|
||||
// 2. 从 stats_hourly 表中获取过去24小时的请求数和错误率
|
||||
// 3. 组合成 DashboardStatsResponse
|
||||
// ... 这个方法的具体实现,我们可以在DashboardQueryService中完成,
|
||||
// 这里我们先确保StatsService的核心职责(维护缓存)已经完成。
|
||||
// 为了编译通过,我们先返回一个空对象。
|
||||
|
||||
// 伪代码:
|
||||
// keyCounts, _ := s.store.HGetAll("stats:global:keys")
|
||||
// ...
|
||||
|
||||
func (s *StatsService) GetDashboardStats(ctx context.Context) (*models.DashboardStatsResponse, error) {
|
||||
return &models.DashboardStatsResponse{}, nil
|
||||
}
|
||||
|
||||
func (s *StatsService) AggregateHourlyStats() error {
|
||||
func (s *StatsService) AggregateHourlyStats(ctx context.Context) error {
|
||||
s.logger.Info("Starting aggregation of the last hour's request data...")
|
||||
now := time.Now()
|
||||
endTime := now.Truncate(time.Hour) // 例如:15:23 -> 15:00
|
||||
startTime := endTime.Add(-1 * time.Hour) // 15:00 -> 14:00
|
||||
endTime := now.Truncate(time.Hour)
|
||||
startTime := endTime.Add(-1 * time.Hour)
|
||||
|
||||
s.logger.Infof("Aggregating data for time window: [%s, %s)", startTime.Format(time.RFC3339), endTime.Format(time.RFC3339))
|
||||
type aggregationResult struct {
|
||||
@@ -161,7 +151,8 @@ func (s *StatsService) AggregateHourlyStats() error {
|
||||
CompletionTokens int64
|
||||
}
|
||||
var results []aggregationResult
|
||||
err := s.db.Model(&models.RequestLog{}).
|
||||
|
||||
err := s.db.WithContext(ctx).Model(&models.RequestLog{}).
|
||||
Select("group_id, model_name, COUNT(*) as request_count, SUM(CASE WHEN is_success = true THEN 1 ELSE 0 END) as success_count, SUM(prompt_tokens) as prompt_tokens, SUM(completion_tokens) as completion_tokens").
|
||||
Where("request_time >= ? AND request_time < ?", startTime, endTime).
|
||||
Group("group_id, model_name").
|
||||
@@ -179,7 +170,7 @@ func (s *StatsService) AggregateHourlyStats() error {
|
||||
var hourlyStats []models.StatsHourly
|
||||
for _, res := range results {
|
||||
hourlyStats = append(hourlyStats, models.StatsHourly{
|
||||
Time: startTime, // 所有记录的时间戳都是该小时的起点
|
||||
Time: startTime,
|
||||
GroupID: res.GroupID,
|
||||
ModelName: res.ModelName,
|
||||
RequestCount: res.RequestCount,
|
||||
@@ -189,7 +180,7 @@ func (s *StatsService) AggregateHourlyStats() error {
|
||||
})
|
||||
}
|
||||
|
||||
return s.db.Clauses(clause.OnConflict{
|
||||
return s.db.WithContext(ctx).Clauses(clause.OnConflict{
|
||||
Columns: []clause.Column{{Name: "time"}, {Name: "group_id"}, {Name: "model_name"}},
|
||||
DoUpdates: clause.AssignmentColumns([]string{"request_count", "success_count", "prompt_tokens", "completion_tokens"}),
|
||||
}).Create(&hourlyStats).Error
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
// file: gemini-balancer\internal\settings\settings.go
|
||||
// Filename: gemini-balancer/internal/settings/settings.go (最终审计修复版)
|
||||
package settings
|
||||
|
||||
import (
|
||||
@@ -19,7 +19,9 @@ import (
|
||||
const SettingsUpdateChannel = "system_settings:updated"
|
||||
const DefaultGeminiEndpoint = "https://generativelanguage.googleapis.com/v1beta/models"
|
||||
|
||||
// SettingsManager [核心修正] syncer现在缓存正确的“蓝图”类型
|
||||
var _ models.SettingsManager = (*SettingsManager)(nil)
|
||||
|
||||
// SettingsManager 负责管理系统的动态设置,包括从数据库加载、缓存同步和更新。
|
||||
type SettingsManager struct {
|
||||
db *gorm.DB
|
||||
syncer *syncer.CacheSyncer[*models.SystemSettings]
|
||||
@@ -27,13 +29,14 @@ type SettingsManager struct {
|
||||
jsonToFieldType map[string]reflect.Type // 用于将JSON字段映射到Go类型
|
||||
}
|
||||
|
||||
// NewSettingsManager 创建一个新的 SettingsManager 实例。
|
||||
func NewSettingsManager(db *gorm.DB, store store.Store, logger *logrus.Logger) (*SettingsManager, error) {
|
||||
sm := &SettingsManager{
|
||||
db: db,
|
||||
logger: logger.WithField("component", "SettingsManager⚙️"),
|
||||
jsonToFieldType: make(map[string]reflect.Type),
|
||||
}
|
||||
// settingsLoader 的职责:读取“砖块”,组装并返回“蓝图”
|
||||
|
||||
settingsType := reflect.TypeOf(models.SystemSettings{})
|
||||
for i := 0; i < settingsType.NumField(); i++ {
|
||||
field := settingsType.Field(i)
|
||||
@@ -42,102 +45,89 @@ func NewSettingsManager(db *gorm.DB, store store.Store, logger *logrus.Logger) (
|
||||
sm.jsonToFieldType[jsonTag] = field.Type
|
||||
}
|
||||
}
|
||||
// settingsLoader 的职责:读取“砖块”,智能组装成“蓝图”
|
||||
|
||||
settingsLoader := func() (*models.SystemSettings, error) {
|
||||
sm.logger.Info("Loading system settings from database...")
|
||||
var dbRecords []models.Setting
|
||||
if err := sm.db.Find(&dbRecords).Error; err != nil {
|
||||
return nil, fmt.Errorf("failed to load system settings from db: %w", err)
|
||||
}
|
||||
|
||||
settingsMap := make(map[string]string)
|
||||
for _, record := range dbRecords {
|
||||
settingsMap[record.Key] = record.Value
|
||||
}
|
||||
// 从一个包含了所有“出厂设置”的“蓝图”开始
|
||||
|
||||
settings := defaultSystemSettings()
|
||||
v := reflect.ValueOf(settings).Elem()
|
||||
t := v.Type()
|
||||
// [智能卸货]
|
||||
for i := 0; i < t.NumField(); i++ {
|
||||
field := t.Field(i)
|
||||
|
||||
for i := 0; i < v.NumField(); i++ {
|
||||
field := v.Type().Field(i)
|
||||
fieldValue := v.Field(i)
|
||||
jsonTag := field.Tag.Get("json")
|
||||
if dbValue, ok := settingsMap[jsonTag]; ok {
|
||||
|
||||
if dbValue, ok := settingsMap[jsonTag]; ok {
|
||||
if err := parseAndSetField(fieldValue, dbValue); err != nil {
|
||||
sm.logger.Warnf("Failed to set config field '%s' from DB value '%s': %v. Using default.", field.Name, dbValue, err)
|
||||
sm.logger.Warnf("Failed to set field '%s' from DB value '%s': %v. Using default.", field.Name, dbValue, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if settings.BaseKeyCheckEndpoint == DefaultGeminiEndpoint || settings.BaseKeyCheckEndpoint == "" {
|
||||
if settings.DefaultUpstreamURL != "" {
|
||||
// 如果全局上游URL已设置,则基于它构建新的检查端点。
|
||||
originalEndpoint := settings.BaseKeyCheckEndpoint
|
||||
derivedEndpoint := strings.TrimSuffix(settings.DefaultUpstreamURL, "/") + "/models"
|
||||
settings.BaseKeyCheckEndpoint = derivedEndpoint
|
||||
sm.logger.Infof(
|
||||
"BaseKeyCheckEndpoint is dynamically derived from DefaultUpstreamURL. Original: '%s', New: '%s'",
|
||||
originalEndpoint, derivedEndpoint,
|
||||
)
|
||||
}
|
||||
} else {
|
||||
// [评估确认] 派生逻辑与原始版本在功能和日志行为上完全一致。
|
||||
if (settings.BaseKeyCheckEndpoint == DefaultGeminiEndpoint || settings.BaseKeyCheckEndpoint == "") && settings.DefaultUpstreamURL != "" {
|
||||
derivedEndpoint := strings.TrimSuffix(settings.DefaultUpstreamURL, "/") + "/models"
|
||||
sm.logger.Infof("BaseKeyCheckEndpoint is dynamically derived from DefaultUpstreamURL: %s", derivedEndpoint)
|
||||
settings.BaseKeyCheckEndpoint = derivedEndpoint
|
||||
} else if settings.BaseKeyCheckEndpoint != DefaultGeminiEndpoint && settings.BaseKeyCheckEndpoint != "" {
|
||||
// 恢复 else 日志,以明确告知用户正在使用自定义覆盖。
|
||||
sm.logger.Infof("BaseKeyCheckEndpoint is using a user-defined override: %s", settings.BaseKeyCheckEndpoint)
|
||||
}
|
||||
|
||||
sm.logger.Info("System settings loaded and cached.")
|
||||
sm.DisplaySettings(settings)
|
||||
return settings, nil
|
||||
}
|
||||
|
||||
s, err := syncer.NewCacheSyncer(settingsLoader, store, SettingsUpdateChannel)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create system settings syncer: %w", err)
|
||||
}
|
||||
sm.syncer = s
|
||||
go sm.ensureSettingsInitialized()
|
||||
|
||||
if err := sm.ensureSettingsInitialized(); err != nil {
|
||||
return nil, fmt.Errorf("failed to ensure system settings are initialized: %w", err)
|
||||
}
|
||||
|
||||
return sm, nil
|
||||
}
|
||||
|
||||
// GetSettings [核心修正] 现在它正确地返回我们需要的“蓝图”
|
||||
// GetSettings 返回当前缓存的系统设置。
|
||||
func (sm *SettingsManager) GetSettings() *models.SystemSettings {
|
||||
return sm.syncer.Get()
|
||||
}
|
||||
|
||||
// UpdateSettings [核心修正] 它接收更新,并将它们转换为“砖块”存入数据库
|
||||
// UpdateSettings 更新一个或多个系统设置。
|
||||
func (sm *SettingsManager) UpdateSettings(settingsMap map[string]interface{}) error {
|
||||
var settingsToUpdate []models.Setting
|
||||
|
||||
for key, value := range settingsMap {
|
||||
fieldType, ok := sm.jsonToFieldType[key]
|
||||
if !ok {
|
||||
sm.logger.Warnf("Received update for unknown setting key '%s', ignoring.", key)
|
||||
continue
|
||||
}
|
||||
var dbValue string
|
||||
// [智能打包]
|
||||
// 如果字段是 slice 或 map,我们就将传入的 interface{} “打包”成 JSON string
|
||||
kind := fieldType.Kind()
|
||||
if kind == reflect.Slice || kind == reflect.Map {
|
||||
jsonBytes, marshalErr := json.Marshal(value)
|
||||
if marshalErr != nil {
|
||||
// [真正的错误处理] 如果打包失败,我们记录日志,并跳过这个“坏掉的集装箱”。
|
||||
sm.logger.Warnf("Failed to marshal setting '%s' to JSON: %v, skipping update.", key, marshalErr)
|
||||
continue // 跳过,继续处理下一个key
|
||||
}
|
||||
dbValue = string(jsonBytes)
|
||||
} else if kind == reflect.Bool {
|
||||
if b, ok := value.(bool); ok {
|
||||
dbValue = strconv.FormatBool(b)
|
||||
} else {
|
||||
dbValue = "false"
|
||||
}
|
||||
} else {
|
||||
dbValue = fmt.Sprintf("%v", value)
|
||||
|
||||
dbValue, err := sm.convertToDBValue(key, value, fieldType)
|
||||
if err != nil {
|
||||
sm.logger.Warnf("Failed to convert value for setting '%s': %v. Skipping update.", key, err)
|
||||
continue
|
||||
}
|
||||
|
||||
settingsToUpdate = append(settingsToUpdate, models.Setting{
|
||||
Key: key,
|
||||
Value: dbValue,
|
||||
})
|
||||
}
|
||||
|
||||
if len(settingsToUpdate) > 0 {
|
||||
err := sm.db.Clauses(clause.OnConflict{
|
||||
Columns: []clause.Column{{Name: "key"}},
|
||||
@@ -147,83 +137,20 @@ func (sm *SettingsManager) UpdateSettings(settingsMap map[string]interface{}) er
|
||||
return fmt.Errorf("failed to update settings in db: %w", err)
|
||||
}
|
||||
}
|
||||
return sm.syncer.Invalidate()
|
||||
}
|
||||
|
||||
// ensureSettingsInitialized [核心修正] 确保DB中有所有“砖块”的定义
|
||||
func (sm *SettingsManager) ensureSettingsInitialized() {
|
||||
defaults := defaultSystemSettings()
|
||||
v := reflect.ValueOf(defaults).Elem()
|
||||
t := v.Type()
|
||||
for i := 0; i < t.NumField(); i++ {
|
||||
field := t.Field(i)
|
||||
fieldValue := v.Field(i)
|
||||
key := field.Tag.Get("json")
|
||||
if key == "" || key == "-" {
|
||||
continue
|
||||
}
|
||||
var existing models.Setting
|
||||
if err := sm.db.Where("key = ?", key).First(&existing).Error; err == gorm.ErrRecordNotFound {
|
||||
|
||||
var defaultValue string
|
||||
kind := fieldValue.Kind()
|
||||
// [智能初始化]
|
||||
if kind == reflect.Slice || kind == reflect.Map {
|
||||
// 为复杂类型,生成一个“空的”JSON字符串,例如 "[]" 或 "{}"
|
||||
jsonBytes, _ := json.Marshal(fieldValue.Interface())
|
||||
defaultValue = string(jsonBytes)
|
||||
} else {
|
||||
defaultValue = field.Tag.Get("default")
|
||||
}
|
||||
setting := models.Setting{
|
||||
Key: key,
|
||||
Value: defaultValue,
|
||||
Name: field.Tag.Get("name"),
|
||||
Description: field.Tag.Get("desc"),
|
||||
Category: field.Tag.Get("category"),
|
||||
DefaultValue: field.Tag.Get("default"), // 元数据中的default,永远来自tag
|
||||
}
|
||||
if err := sm.db.Create(&setting).Error; err != nil {
|
||||
sm.logger.Errorf("Failed to initialize setting '%s': %v", key, err)
|
||||
}
|
||||
}
|
||||
if err := sm.syncer.Invalidate(); err != nil {
|
||||
sm.logger.Errorf("CRITICAL: Database settings updated, but cache invalidation failed: %v", err)
|
||||
return fmt.Errorf("settings updated but cache invalidation failed, system may be inconsistent: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ResetAndSaveSettings [核心新增] 將所有配置重置為其在 'default' 標籤中定義的值。
|
||||
|
||||
// ResetAndSaveSettings 将所有设置重置为其默认值。
|
||||
func (sm *SettingsManager) ResetAndSaveSettings() (*models.SystemSettings, error) {
|
||||
defaults := defaultSystemSettings()
|
||||
v := reflect.ValueOf(defaults).Elem()
|
||||
t := v.Type()
|
||||
var settingsToSave []models.Setting
|
||||
for i := 0; i < t.NumField(); i++ {
|
||||
field := t.Field(i)
|
||||
fieldValue := v.Field(i)
|
||||
key := field.Tag.Get("json")
|
||||
if key == "" || key == "-" {
|
||||
continue
|
||||
}
|
||||
settingsToSave := sm.buildSettingsFromDefaults(defaults)
|
||||
|
||||
var defaultValue string
|
||||
kind := fieldValue.Kind()
|
||||
// [智能重置]
|
||||
if kind == reflect.Slice || kind == reflect.Map {
|
||||
jsonBytes, _ := json.Marshal(fieldValue.Interface())
|
||||
defaultValue = string(jsonBytes)
|
||||
} else {
|
||||
defaultValue = field.Tag.Get("default")
|
||||
}
|
||||
setting := models.Setting{
|
||||
Key: key,
|
||||
Value: defaultValue,
|
||||
Name: field.Tag.Get("name"),
|
||||
Description: field.Tag.Get("desc"),
|
||||
Category: field.Tag.Get("category"),
|
||||
DefaultValue: field.Tag.Get("default"),
|
||||
}
|
||||
settingsToSave = append(settingsToSave, setting)
|
||||
}
|
||||
if len(settingsToSave) > 0 {
|
||||
err := sm.db.Clauses(clause.OnConflict{
|
||||
Columns: []clause.Column{{Name: "key"}},
|
||||
@@ -233,8 +160,93 @@ func (sm *SettingsManager) ResetAndSaveSettings() (*models.SystemSettings, error
|
||||
return nil, fmt.Errorf("failed to reset settings in db: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
if err := sm.syncer.Invalidate(); err != nil {
|
||||
sm.logger.Errorf("Failed to invalidate settings cache after reset: %v", err)
|
||||
sm.logger.Errorf("CRITICAL: Database settings reset, but cache invalidation failed: %v", err)
|
||||
return nil, fmt.Errorf("settings reset but cache invalidation failed: %w", err)
|
||||
}
|
||||
|
||||
return defaults, nil
|
||||
}
|
||||
|
||||
// --- 私有辅助函数 ---
|
||||
|
||||
func (sm *SettingsManager) ensureSettingsInitialized() error {
|
||||
defaults := defaultSystemSettings()
|
||||
settingsToCreate := sm.buildSettingsFromDefaults(defaults)
|
||||
|
||||
for _, setting := range settingsToCreate {
|
||||
var existing models.Setting
|
||||
err := sm.db.Where("key = ?", setting.Key).First(&existing).Error
|
||||
|
||||
if err == gorm.ErrRecordNotFound {
|
||||
sm.logger.Infof("Initializing new setting '%s'", setting.Key)
|
||||
if createErr := sm.db.Create(&setting).Error; createErr != nil {
|
||||
return fmt.Errorf("failed to create initial setting '%s': %w", setting.Key, createErr)
|
||||
}
|
||||
} else if err != nil {
|
||||
return fmt.Errorf("failed to check for existing setting '%s': %w", setting.Key, err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (sm *SettingsManager) buildSettingsFromDefaults(defaults *models.SystemSettings) []models.Setting {
|
||||
v := reflect.ValueOf(defaults).Elem()
|
||||
t := v.Type()
|
||||
var settings []models.Setting
|
||||
|
||||
for i := 0; i < t.NumField(); i++ {
|
||||
field := t.Field(i)
|
||||
fieldValue := v.Field(i)
|
||||
key := field.Tag.Get("json")
|
||||
|
||||
if key == "" || key == "-" {
|
||||
continue
|
||||
}
|
||||
|
||||
var defaultValue string
|
||||
kind := fieldValue.Kind()
|
||||
|
||||
if kind == reflect.Slice || kind == reflect.Map {
|
||||
jsonBytes, _ := json.Marshal(fieldValue.Interface())
|
||||
defaultValue = string(jsonBytes)
|
||||
} else {
|
||||
defaultValue = field.Tag.Get("default")
|
||||
}
|
||||
|
||||
settings = append(settings, models.Setting{
|
||||
Key: key,
|
||||
Value: defaultValue,
|
||||
Name: field.Tag.Get("name"),
|
||||
Description: field.Tag.Get("desc"),
|
||||
Category: field.Tag.Get("category"),
|
||||
DefaultValue: field.Tag.Get("default"),
|
||||
})
|
||||
}
|
||||
return settings
|
||||
}
|
||||
|
||||
// [修正] 使用空白标识符 `_` 修复 "unused parameter" 警告。
|
||||
func (sm *SettingsManager) convertToDBValue(_ string, value interface{}, fieldType reflect.Type) (string, error) {
|
||||
kind := fieldType.Kind()
|
||||
|
||||
switch kind {
|
||||
case reflect.Slice, reflect.Map:
|
||||
jsonBytes, err := json.Marshal(value)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to marshal to JSON: %w", err)
|
||||
}
|
||||
return string(jsonBytes), nil
|
||||
|
||||
case reflect.Bool:
|
||||
b, ok := value.(bool)
|
||||
if !ok {
|
||||
return "", fmt.Errorf("expected bool, but got %T", value)
|
||||
}
|
||||
return strconv.FormatBool(b), nil
|
||||
|
||||
default:
|
||||
return fmt.Sprintf("%v", value), nil
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
// Filename: internal/store/factory.go
|
||||
package store
|
||||
|
||||
import (
|
||||
@@ -11,7 +12,6 @@ import (
|
||||
|
||||
// NewStore creates a new store based on the application configuration.
|
||||
func NewStore(cfg *config.Config, logger *logrus.Logger) (Store, error) {
|
||||
// 检查是否有Redis配置
|
||||
if cfg.Redis.DSN != "" {
|
||||
opts, err := redis.ParseURL(cfg.Redis.DSN)
|
||||
if err != nil {
|
||||
@@ -20,10 +20,10 @@ func NewStore(cfg *config.Config, logger *logrus.Logger) (Store, error) {
|
||||
client := redis.NewClient(opts)
|
||||
if err := client.Ping(context.Background()).Err(); err != nil {
|
||||
logger.WithError(err).Warnf("WARN: Failed to connect to Redis (%s), falling back to in-memory store. Error: %v", cfg.Redis.DSN, err)
|
||||
return NewMemoryStore(logger), nil // 连接失败,也回退到内存模式,但不返回错误
|
||||
return NewMemoryStore(logger), nil
|
||||
}
|
||||
logger.Info("Successfully connected to Redis. Using Redis as store.")
|
||||
return NewRedisStore(client), nil
|
||||
return NewRedisStore(client, logger), nil
|
||||
}
|
||||
logger.Info("INFO: Redis DSN not configured, falling back to in-memory store.")
|
||||
return NewMemoryStore(logger), nil
|
||||
|
||||
@@ -1,17 +1,20 @@
|
||||
// Filename: internal/store/memory_store.go (经同行审查后最终修复版)
|
||||
// Filename: internal/store/memory_store.go
|
||||
|
||||
package store
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"sort"
|
||||
"strconv"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// ensure memoryStore implements Store interface
|
||||
var _ Store = (*memoryStore)(nil)
|
||||
|
||||
type memoryStoreItem struct {
|
||||
@@ -32,7 +35,6 @@ type memoryStore struct {
|
||||
items map[string]*memoryStoreItem
|
||||
pubsub map[string][]chan *Message
|
||||
mu sync.RWMutex
|
||||
// [USER SUGGESTION APPLIED] 使用带锁的随机数源以保证并发安全
|
||||
rng *rand.Rand
|
||||
rngMu sync.Mutex
|
||||
logger *logrus.Entry
|
||||
@@ -42,7 +44,6 @@ func NewMemoryStore(logger *logrus.Logger) Store {
|
||||
store := &memoryStore{
|
||||
items: make(map[string]*memoryStoreItem),
|
||||
pubsub: make(map[string][]chan *Message),
|
||||
// 使用当前时间作为种子,创建一个新的随机数源
|
||||
rng: rand.New(rand.NewSource(time.Now().UnixNano())),
|
||||
logger: logger.WithField("component", "store.memory 🗱"),
|
||||
}
|
||||
@@ -50,13 +51,12 @@ func NewMemoryStore(logger *logrus.Logger) Store {
|
||||
return store
|
||||
}
|
||||
|
||||
// [USER SUGGESTION INCORPORATED] Fix #1: 使用 now := time.Now() 进行原子性检查
|
||||
func (s *memoryStore) startGCollector() {
|
||||
ticker := time.NewTicker(5 * time.Minute)
|
||||
defer ticker.Stop()
|
||||
for range ticker.C {
|
||||
s.mu.Lock()
|
||||
now := time.Now() // 避免在循环中重复调用
|
||||
now := time.Now()
|
||||
for key, item := range s.items {
|
||||
if !item.expireAt.IsZero() && now.After(item.expireAt) {
|
||||
delete(s.items, key)
|
||||
@@ -66,92 +66,10 @@ func (s *memoryStore) startGCollector() {
|
||||
}
|
||||
}
|
||||
|
||||
// [USER SUGGESTION INCORPORATED] Fix #2 & #3: 修复了致命的nil检查和类型断言问题
|
||||
func (s *memoryStore) PopAndCycleSetMember(mainKey, cooldownKey string) (string, error) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
// --- 所有方法签名都增加了 context.Context 参数以匹配接口 ---
|
||||
// --- 内存实现可以忽略该参数,用 _ 接收 ---
|
||||
|
||||
mainItem, mainOk := s.items[mainKey]
|
||||
var mainSet map[string]struct{}
|
||||
|
||||
if mainOk && !mainItem.isExpired() {
|
||||
// 安全地进行类型断言
|
||||
mainSet, mainOk = mainItem.value.(map[string]struct{})
|
||||
// 确保断言成功且集合不为空
|
||||
mainOk = mainOk && len(mainSet) > 0
|
||||
} else {
|
||||
mainOk = false
|
||||
}
|
||||
|
||||
if !mainOk {
|
||||
cooldownItem, cooldownOk := s.items[cooldownKey]
|
||||
if !cooldownOk || cooldownItem.isExpired() {
|
||||
return "", ErrNotFound
|
||||
}
|
||||
// 安全地进行类型断言
|
||||
cooldownSet, cooldownSetOk := cooldownItem.value.(map[string]struct{})
|
||||
if !cooldownSetOk || len(cooldownSet) == 0 {
|
||||
return "", ErrNotFound
|
||||
}
|
||||
|
||||
s.items[mainKey] = cooldownItem
|
||||
delete(s.items, cooldownKey)
|
||||
mainSet = cooldownSet
|
||||
}
|
||||
|
||||
var popped string
|
||||
for k := range mainSet {
|
||||
popped = k
|
||||
break
|
||||
}
|
||||
delete(mainSet, popped)
|
||||
|
||||
cooldownItem, cooldownOk := s.items[cooldownKey]
|
||||
if !cooldownOk || cooldownItem.isExpired() {
|
||||
cooldownItem = &memoryStoreItem{value: make(map[string]struct{})}
|
||||
s.items[cooldownKey] = cooldownItem
|
||||
}
|
||||
// 安全地处理冷却池
|
||||
cooldownSet, ok := cooldownItem.value.(map[string]struct{})
|
||||
if !ok {
|
||||
cooldownSet = make(map[string]struct{})
|
||||
cooldownItem.value = cooldownSet
|
||||
}
|
||||
cooldownSet[popped] = struct{}{}
|
||||
|
||||
return popped, nil
|
||||
}
|
||||
|
||||
// SRandMember [并发修复版] 使用带锁的rng
|
||||
func (s *memoryStore) SRandMember(key string) (string, error) {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
item, ok := s.items[key]
|
||||
if !ok || item.isExpired() {
|
||||
return "", ErrNotFound
|
||||
}
|
||||
set, ok := item.value.(map[string]struct{})
|
||||
if !ok || len(set) == 0 {
|
||||
return "", ErrNotFound
|
||||
}
|
||||
members := make([]string, 0, len(set))
|
||||
for member := range set {
|
||||
members = append(members, member)
|
||||
}
|
||||
if len(members) == 0 {
|
||||
return "", ErrNotFound
|
||||
}
|
||||
|
||||
s.rngMu.Lock()
|
||||
n := s.rng.Intn(len(members))
|
||||
s.rngMu.Unlock()
|
||||
|
||||
return members[n], nil
|
||||
}
|
||||
|
||||
// --- 以下是其余函数的最终版本,它们都遵循了安全、原子的锁策略 ---
|
||||
|
||||
func (s *memoryStore) Set(key string, value []byte, ttl time.Duration) error {
|
||||
func (s *memoryStore) Set(_ context.Context, key string, value []byte, ttl time.Duration) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
var expireAt time.Time
|
||||
@@ -162,7 +80,7 @@ func (s *memoryStore) Set(key string, value []byte, ttl time.Duration) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *memoryStore) Get(key string) ([]byte, error) {
|
||||
func (s *memoryStore) Get(_ context.Context, key string) ([]byte, error) {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
item, ok := s.items[key]
|
||||
@@ -175,7 +93,7 @@ func (s *memoryStore) Get(key string) ([]byte, error) {
|
||||
return nil, ErrNotFound
|
||||
}
|
||||
|
||||
func (s *memoryStore) Del(keys ...string) error {
|
||||
func (s *memoryStore) Del(_ context.Context, keys ...string) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
for _, key := range keys {
|
||||
@@ -184,14 +102,25 @@ func (s *memoryStore) Del(keys ...string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *memoryStore) Exists(key string) (bool, error) {
|
||||
func (s *memoryStore) Exists(_ context.Context, key string) (bool, error) {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
item, ok := s.items[key]
|
||||
return ok && !item.isExpired(), nil
|
||||
}
|
||||
|
||||
func (s *memoryStore) SetNX(key string, value []byte, ttl time.Duration) (bool, error) {
|
||||
func (s *memoryStore) Expire(_ context.Context, key string, expiration time.Duration) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
item, ok := s.items[key]
|
||||
if !ok {
|
||||
return ErrNotFound
|
||||
}
|
||||
item.expireAt = time.Now().Add(expiration)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *memoryStore) SetNX(_ context.Context, key string, value []byte, ttl time.Duration) (bool, error) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
item, ok := s.items[key]
|
||||
@@ -208,7 +137,7 @@ func (s *memoryStore) SetNX(key string, value []byte, ttl time.Duration) (bool,
|
||||
|
||||
func (s *memoryStore) Close() error { return nil }
|
||||
|
||||
func (s *memoryStore) HDel(key string, fields ...string) error {
|
||||
func (s *memoryStore) HDel(_ context.Context, key string, fields ...string) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
item, ok := s.items[key]
|
||||
@@ -223,7 +152,7 @@ func (s *memoryStore) HDel(key string, fields ...string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *memoryStore) HSet(key string, values map[string]any) error {
|
||||
func (s *memoryStore) HSet(_ context.Context, key string, values map[string]any) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
item, ok := s.items[key]
|
||||
@@ -242,7 +171,22 @@ func (s *memoryStore) HSet(key string, values map[string]any) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *memoryStore) HGetAll(key string) (map[string]string, error) {
|
||||
func (s *memoryStore) HGet(_ context.Context, key, field string) (string, error) {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
item, ok := s.items[key]
|
||||
if !ok || item.isExpired() {
|
||||
return "", ErrNotFound
|
||||
}
|
||||
if hash, ok := item.value.(map[string]string); ok {
|
||||
if value, exists := hash[field]; exists {
|
||||
return value, nil
|
||||
}
|
||||
}
|
||||
return "", ErrNotFound
|
||||
}
|
||||
|
||||
func (s *memoryStore) HGetAll(_ context.Context, key string) (map[string]string, error) {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
item, ok := s.items[key]
|
||||
@@ -259,7 +203,7 @@ func (s *memoryStore) HGetAll(key string) (map[string]string, error) {
|
||||
return make(map[string]string), nil
|
||||
}
|
||||
|
||||
func (s *memoryStore) HIncrBy(key, field string, incr int64) (int64, error) {
|
||||
func (s *memoryStore) HIncrBy(_ context.Context, key, field string, incr int64) (int64, error) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
item, ok := s.items[key]
|
||||
@@ -281,7 +225,7 @@ func (s *memoryStore) HIncrBy(key, field string, incr int64) (int64, error) {
|
||||
return newVal, nil
|
||||
}
|
||||
|
||||
func (s *memoryStore) LPush(key string, values ...any) error {
|
||||
func (s *memoryStore) LPush(_ context.Context, key string, values ...any) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
item, ok := s.items[key]
|
||||
@@ -301,7 +245,7 @@ func (s *memoryStore) LPush(key string, values ...any) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *memoryStore) LRem(key string, count int64, value any) error {
|
||||
func (s *memoryStore) LRem(_ context.Context, key string, count int64, value any) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
item, ok := s.items[key]
|
||||
@@ -326,7 +270,7 @@ func (s *memoryStore) LRem(key string, count int64, value any) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *memoryStore) SAdd(key string, members ...any) error {
|
||||
func (s *memoryStore) SAdd(_ context.Context, key string, members ...any) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
item, ok := s.items[key]
|
||||
@@ -345,7 +289,7 @@ func (s *memoryStore) SAdd(key string, members ...any) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *memoryStore) SPopN(key string, count int64) ([]string, error) {
|
||||
func (s *memoryStore) SPopN(_ context.Context, key string, count int64) ([]string, error) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
item, ok := s.items[key]
|
||||
@@ -375,7 +319,7 @@ func (s *memoryStore) SPopN(key string, count int64) ([]string, error) {
|
||||
return popped, nil
|
||||
}
|
||||
|
||||
func (s *memoryStore) SMembers(key string) ([]string, error) {
|
||||
func (s *memoryStore) SMembers(_ context.Context, key string) ([]string, error) {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
item, ok := s.items[key]
|
||||
@@ -393,7 +337,7 @@ func (s *memoryStore) SMembers(key string) ([]string, error) {
|
||||
return members, nil
|
||||
}
|
||||
|
||||
func (s *memoryStore) SRem(key string, members ...any) error {
|
||||
func (s *memoryStore) SRem(_ context.Context, key string, members ...any) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
item, ok := s.items[key]
|
||||
@@ -410,7 +354,51 @@ func (s *memoryStore) SRem(key string, members ...any) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *memoryStore) Rotate(key string) (string, error) {
|
||||
func (s *memoryStore) SRandMember(_ context.Context, key string) (string, error) {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
item, ok := s.items[key]
|
||||
if !ok || item.isExpired() {
|
||||
return "", ErrNotFound
|
||||
}
|
||||
set, ok := item.value.(map[string]struct{})
|
||||
if !ok || len(set) == 0 {
|
||||
return "", ErrNotFound
|
||||
}
|
||||
members := make([]string, 0, len(set))
|
||||
for member := range set {
|
||||
members = append(members, member)
|
||||
}
|
||||
if len(members) == 0 {
|
||||
return "", ErrNotFound
|
||||
}
|
||||
s.rngMu.Lock()
|
||||
n := s.rng.Intn(len(members))
|
||||
s.rngMu.Unlock()
|
||||
return members[n], nil
|
||||
}
|
||||
|
||||
func (s *memoryStore) SUnionStore(_ context.Context, destination string, keys ...string) (int64, error) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
unionSet := make(map[string]struct{})
|
||||
for _, key := range keys {
|
||||
item, ok := s.items[key]
|
||||
if !ok || item.isExpired() {
|
||||
continue
|
||||
}
|
||||
if set, ok := item.value.(map[string]struct{}); ok {
|
||||
for member := range set {
|
||||
unionSet[member] = struct{}{}
|
||||
}
|
||||
}
|
||||
}
|
||||
destItem := &memoryStoreItem{value: unionSet}
|
||||
s.items[destination] = destItem
|
||||
return int64(len(unionSet)), nil
|
||||
}
|
||||
|
||||
func (s *memoryStore) Rotate(_ context.Context, key string) (string, error) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
item, ok := s.items[key]
|
||||
@@ -426,7 +414,7 @@ func (s *memoryStore) Rotate(key string) (string, error) {
|
||||
return val, nil
|
||||
}
|
||||
|
||||
func (s *memoryStore) LIndex(key string, index int64) (string, error) {
|
||||
func (s *memoryStore) LIndex(_ context.Context, key string, index int64) (string, error) {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
item, ok := s.items[key]
|
||||
@@ -447,8 +435,17 @@ func (s *memoryStore) LIndex(key string, index int64) (string, error) {
|
||||
return list[index], nil
|
||||
}
|
||||
|
||||
// Zset methods... (ZAdd, ZRange, ZRem)
|
||||
func (s *memoryStore) ZAdd(key string, members map[string]float64) error {
|
||||
func (s *memoryStore) MSet(ctx context.Context, values map[string]any) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
for key, value := range values {
|
||||
// 内存存储不支持独立的 TTL,因此我们假设永不过期
|
||||
s.items[key] = &memoryStoreItem{value: value, expireAt: time.Time{}}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *memoryStore) ZAdd(_ context.Context, key string, members map[string]float64) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
item, ok := s.items[key]
|
||||
@@ -471,8 +468,6 @@ func (s *memoryStore) ZAdd(key string, members map[string]float64) error {
|
||||
for val, score := range membersMap {
|
||||
newZSet = append(newZSet, zsetMember{Value: val, Score: score})
|
||||
}
|
||||
// NOTE: This ZSet implementation is simple but not performant for large sets.
|
||||
// A production implementation would use a skip list or a balanced tree.
|
||||
sort.Slice(newZSet, func(i, j int) bool {
|
||||
if newZSet[i].Score == newZSet[j].Score {
|
||||
return newZSet[i].Value < newZSet[j].Value
|
||||
@@ -482,7 +477,7 @@ func (s *memoryStore) ZAdd(key string, members map[string]float64) error {
|
||||
item.value = newZSet
|
||||
return nil
|
||||
}
|
||||
func (s *memoryStore) ZRange(key string, start, stop int64) ([]string, error) {
|
||||
func (s *memoryStore) ZRange(_ context.Context, key string, start, stop int64) ([]string, error) {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
item, ok := s.items[key]
|
||||
@@ -515,7 +510,7 @@ func (s *memoryStore) ZRange(key string, start, stop int64) ([]string, error) {
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
func (s *memoryStore) ZRem(key string, members ...any) error {
|
||||
func (s *memoryStore) ZRem(_ context.Context, key string, members ...any) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
item, ok := s.items[key]
|
||||
@@ -540,13 +535,56 @@ func (s *memoryStore) ZRem(key string, members ...any) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Pipeline implementation
|
||||
func (s *memoryStore) PopAndCycleSetMember(_ context.Context, mainKey, cooldownKey string) (string, error) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
mainItem, mainOk := s.items[mainKey]
|
||||
var mainSet map[string]struct{}
|
||||
if mainOk && !mainItem.isExpired() {
|
||||
mainSet, mainOk = mainItem.value.(map[string]struct{})
|
||||
mainOk = mainOk && len(mainSet) > 0
|
||||
} else {
|
||||
mainOk = false
|
||||
}
|
||||
if !mainOk {
|
||||
cooldownItem, cooldownOk := s.items[cooldownKey]
|
||||
if !cooldownOk || cooldownItem.isExpired() {
|
||||
return "", ErrNotFound
|
||||
}
|
||||
cooldownSet, cooldownSetOk := cooldownItem.value.(map[string]struct{})
|
||||
if !cooldownSetOk || len(cooldownSet) == 0 {
|
||||
return "", ErrNotFound
|
||||
}
|
||||
s.items[mainKey] = cooldownItem
|
||||
delete(s.items, cooldownKey)
|
||||
mainSet = cooldownSet
|
||||
}
|
||||
var popped string
|
||||
for k := range mainSet {
|
||||
popped = k
|
||||
break
|
||||
}
|
||||
delete(mainSet, popped)
|
||||
cooldownItem, cooldownOk := s.items[cooldownKey]
|
||||
if !cooldownOk || cooldownItem.isExpired() {
|
||||
cooldownItem = &memoryStoreItem{value: make(map[string]struct{})}
|
||||
s.items[cooldownKey] = cooldownItem
|
||||
}
|
||||
cooldownSet, ok := cooldownItem.value.(map[string]struct{})
|
||||
if !ok {
|
||||
cooldownSet = make(map[string]struct{})
|
||||
cooldownItem.value = cooldownSet
|
||||
}
|
||||
cooldownSet[popped] = struct{}{}
|
||||
return popped, nil
|
||||
}
|
||||
|
||||
type memoryPipeliner struct {
|
||||
store *memoryStore
|
||||
ops []func()
|
||||
}
|
||||
|
||||
func (s *memoryStore) Pipeline() Pipeliner {
|
||||
func (s *memoryStore) Pipeline(_ context.Context) Pipeliner {
|
||||
return &memoryPipeliner{store: s}
|
||||
}
|
||||
func (p *memoryPipeliner) Exec() error {
|
||||
@@ -559,7 +597,6 @@ func (p *memoryPipeliner) Exec() error {
|
||||
}
|
||||
|
||||
func (p *memoryPipeliner) Expire(key string, expiration time.Duration) {
|
||||
// [USER SUGGESTION APPLIED] Fix #4: Capture value, not reference
|
||||
capturedKey := key
|
||||
p.ops = append(p.ops, func() {
|
||||
if item, ok := p.store.items[capturedKey]; ok {
|
||||
@@ -576,6 +613,22 @@ func (p *memoryPipeliner) Del(keys ...string) {
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func (p *memoryPipeliner) Set(key string, value []byte, expiration time.Duration) {
|
||||
capturedKey := key
|
||||
capturedValue := value
|
||||
p.ops = append(p.ops, func() {
|
||||
var expireAt time.Time
|
||||
if expiration > 0 {
|
||||
expireAt = time.Now().Add(expiration)
|
||||
}
|
||||
p.store.items[capturedKey] = &memoryStoreItem{
|
||||
value: capturedValue,
|
||||
expireAt: expireAt,
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func (p *memoryPipeliner) SAdd(key string, members ...any) {
|
||||
capturedKey := key
|
||||
capturedMembers := make([]any, len(members))
|
||||
@@ -615,7 +668,6 @@ func (p *memoryPipeliner) SRem(key string, members ...any) {
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func (p *memoryPipeliner) LPush(key string, values ...any) {
|
||||
capturedKey := key
|
||||
capturedValues := make([]any, len(values))
|
||||
@@ -637,11 +689,126 @@ func (p *memoryPipeliner) LPush(key string, values ...any) {
|
||||
item.value = append(stringValues, list...)
|
||||
})
|
||||
}
|
||||
func (p *memoryPipeliner) LRem(key string, count int64, value any) {}
|
||||
func (p *memoryPipeliner) HSet(key string, values map[string]any) {}
|
||||
func (p *memoryPipeliner) HIncrBy(key, field string, incr int64) {}
|
||||
func (p *memoryPipeliner) LRem(key string, count int64, value any) {
|
||||
capturedKey := key
|
||||
capturedValue := fmt.Sprintf("%v", value)
|
||||
p.ops = append(p.ops, func() {
|
||||
item, ok := p.store.items[capturedKey]
|
||||
if !ok || item.isExpired() {
|
||||
return
|
||||
}
|
||||
list, ok := item.value.([]string)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
newList := make([]string, 0, len(list))
|
||||
removed := int64(0)
|
||||
for _, v := range list {
|
||||
if count != 0 && v == capturedValue && (count < 0 || removed < count) {
|
||||
removed++
|
||||
continue
|
||||
}
|
||||
newList = append(newList, v)
|
||||
}
|
||||
item.value = newList
|
||||
})
|
||||
}
|
||||
func (p *memoryPipeliner) HSet(key string, values map[string]any) {
|
||||
capturedKey := key
|
||||
capturedValues := make(map[string]any, len(values))
|
||||
for k, v := range values {
|
||||
capturedValues[k] = v
|
||||
}
|
||||
p.ops = append(p.ops, func() {
|
||||
item, ok := p.store.items[capturedKey]
|
||||
if !ok || item.isExpired() {
|
||||
item = &memoryStoreItem{value: make(map[string]string)}
|
||||
p.store.items[capturedKey] = item
|
||||
}
|
||||
hash, ok := item.value.(map[string]string)
|
||||
if !ok {
|
||||
hash = make(map[string]string)
|
||||
item.value = hash
|
||||
}
|
||||
for field, value := range capturedValues {
|
||||
hash[field] = fmt.Sprintf("%v", value)
|
||||
}
|
||||
})
|
||||
}
|
||||
func (p *memoryPipeliner) HIncrBy(key, field string, incr int64) {
|
||||
capturedKey := key
|
||||
capturedField := field
|
||||
p.ops = append(p.ops, func() {
|
||||
item, ok := p.store.items[capturedKey]
|
||||
if !ok || item.isExpired() {
|
||||
item = &memoryStoreItem{value: make(map[string]string)}
|
||||
p.store.items[capturedKey] = item
|
||||
}
|
||||
hash, ok := item.value.(map[string]string)
|
||||
if !ok {
|
||||
hash = make(map[string]string)
|
||||
item.value = hash
|
||||
}
|
||||
current, _ := strconv.ParseInt(hash[capturedField], 10, 64)
|
||||
hash[capturedField] = strconv.FormatInt(current+incr, 10)
|
||||
})
|
||||
}
|
||||
func (p *memoryPipeliner) ZAdd(key string, members map[string]float64) {
|
||||
capturedKey := key
|
||||
capturedMembers := make(map[string]float64, len(members))
|
||||
for k, v := range members {
|
||||
capturedMembers[k] = v
|
||||
}
|
||||
p.ops = append(p.ops, func() {
|
||||
item, ok := p.store.items[capturedKey]
|
||||
if !ok || item.isExpired() {
|
||||
item = &memoryStoreItem{value: make(map[string]float64)}
|
||||
p.store.items[capturedKey] = item
|
||||
}
|
||||
zset, ok := item.value.(map[string]float64)
|
||||
if !ok {
|
||||
zset = make(map[string]float64)
|
||||
item.value = zset
|
||||
}
|
||||
for member, score := range capturedMembers {
|
||||
zset[member] = score
|
||||
}
|
||||
})
|
||||
}
|
||||
func (p *memoryPipeliner) ZRem(key string, members ...any) {
|
||||
capturedKey := key
|
||||
capturedMembers := make([]any, len(members))
|
||||
copy(capturedMembers, members)
|
||||
p.ops = append(p.ops, func() {
|
||||
item, ok := p.store.items[capturedKey]
|
||||
if !ok || item.isExpired() {
|
||||
return
|
||||
}
|
||||
zset, ok := item.value.(map[string]float64)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
for _, member := range capturedMembers {
|
||||
delete(zset, fmt.Sprintf("%v", member))
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func (p *memoryPipeliner) MSet(values map[string]any) {
|
||||
capturedValues := make(map[string]any, len(values))
|
||||
for k, v := range values {
|
||||
capturedValues[k] = v
|
||||
}
|
||||
p.ops = append(p.ops, func() {
|
||||
for key, value := range capturedValues {
|
||||
p.store.items[key] = &memoryStoreItem{
|
||||
value: value,
|
||||
expireAt: time.Time{}, // Pipelined MSet 同样假设永不过期
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// --- Pub/Sub implementation (remains unchanged) ---
|
||||
type memorySubscription struct {
|
||||
store *memoryStore
|
||||
channelName string
|
||||
@@ -649,10 +816,11 @@ type memorySubscription struct {
|
||||
}
|
||||
|
||||
func (ms *memorySubscription) Channel() <-chan *Message { return ms.msgChan }
|
||||
func (ms *memorySubscription) ChannelName() string { return ms.channelName }
|
||||
func (ms *memorySubscription) Close() error {
|
||||
return ms.store.removeSubscriber(ms.channelName, ms.msgChan)
|
||||
}
|
||||
func (s *memoryStore) Publish(channel string, message []byte) error {
|
||||
func (s *memoryStore) Publish(_ context.Context, channel string, message []byte) error {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
subscribers, ok := s.pubsub[channel]
|
||||
@@ -669,7 +837,7 @@ func (s *memoryStore) Publish(channel string, message []byte) error {
|
||||
}
|
||||
return nil
|
||||
}
|
||||
func (s *memoryStore) Subscribe(channel string) (Subscription, error) {
|
||||
func (s *memoryStore) Subscribe(_ context.Context, channel string) (Subscription, error) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
msgChan := make(chan *Message, 10)
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
// Filename: internal/store/redis_store.go
|
||||
|
||||
package store
|
||||
|
||||
import (
|
||||
@@ -8,22 +10,20 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/redis/go-redis/v9"
|
||||
"github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// ensure RedisStore implements Store interface
|
||||
var _ Store = (*RedisStore)(nil)
|
||||
|
||||
// RedisStore is a Redis-backed key-value store.
|
||||
type RedisStore struct {
|
||||
client *redis.Client
|
||||
popAndCycleScript *redis.Script
|
||||
logger *logrus.Entry
|
||||
}
|
||||
|
||||
// NewRedisStore creates a new RedisStore instance.
|
||||
func NewRedisStore(client *redis.Client) Store {
|
||||
// Lua script for atomic pop-and-cycle operation.
|
||||
// KEYS[1]: main set key
|
||||
// KEYS[2]: cooldown set key
|
||||
func NewRedisStore(client *redis.Client, logger *logrus.Logger) Store {
|
||||
const script = `
|
||||
if redis.call('SCARD', KEYS[1]) == 0 then
|
||||
if redis.call('SCARD', KEYS[2]) == 0 then
|
||||
@@ -36,15 +36,16 @@ func NewRedisStore(client *redis.Client) Store {
|
||||
return &RedisStore{
|
||||
client: client,
|
||||
popAndCycleScript: redis.NewScript(script),
|
||||
logger: logger.WithField("component", "store.redis 🗄️"),
|
||||
}
|
||||
}
|
||||
|
||||
func (s *RedisStore) Set(key string, value []byte, ttl time.Duration) error {
|
||||
return s.client.Set(context.Background(), key, value, ttl).Err()
|
||||
func (s *RedisStore) Set(ctx context.Context, key string, value []byte, ttl time.Duration) error {
|
||||
return s.client.Set(ctx, key, value, ttl).Err()
|
||||
}
|
||||
|
||||
func (s *RedisStore) Get(key string) ([]byte, error) {
|
||||
val, err := s.client.Get(context.Background(), key).Bytes()
|
||||
func (s *RedisStore) Get(ctx context.Context, key string) ([]byte, error) {
|
||||
val, err := s.client.Get(ctx, key).Bytes()
|
||||
if err != nil {
|
||||
if errors.Is(err, redis.Nil) {
|
||||
return nil, ErrNotFound
|
||||
@@ -54,53 +55,67 @@ func (s *RedisStore) Get(key string) ([]byte, error) {
|
||||
return val, nil
|
||||
}
|
||||
|
||||
func (s *RedisStore) Del(keys ...string) error {
|
||||
func (s *RedisStore) Del(ctx context.Context, keys ...string) error {
|
||||
if len(keys) == 0 {
|
||||
return nil
|
||||
}
|
||||
return s.client.Del(context.Background(), keys...).Err()
|
||||
return s.client.Del(ctx, keys...).Err()
|
||||
}
|
||||
|
||||
func (s *RedisStore) Exists(key string) (bool, error) {
|
||||
val, err := s.client.Exists(context.Background(), key).Result()
|
||||
func (s *RedisStore) Exists(ctx context.Context, key string) (bool, error) {
|
||||
val, err := s.client.Exists(ctx, key).Result()
|
||||
return val > 0, err
|
||||
}
|
||||
|
||||
func (s *RedisStore) SetNX(key string, value []byte, ttl time.Duration) (bool, error) {
|
||||
return s.client.SetNX(context.Background(), key, value, ttl).Result()
|
||||
func (s *RedisStore) SetNX(ctx context.Context, key string, value []byte, ttl time.Duration) (bool, error) {
|
||||
return s.client.SetNX(ctx, key, value, ttl).Result()
|
||||
}
|
||||
|
||||
func (s *RedisStore) Close() error {
|
||||
return s.client.Close()
|
||||
}
|
||||
|
||||
func (s *RedisStore) HSet(key string, values map[string]any) error {
|
||||
return s.client.HSet(context.Background(), key, values).Err()
|
||||
func (s *RedisStore) Expire(ctx context.Context, key string, expiration time.Duration) error {
|
||||
return s.client.Expire(ctx, key, expiration).Err()
|
||||
}
|
||||
|
||||
func (s *RedisStore) HGetAll(key string) (map[string]string, error) {
|
||||
return s.client.HGetAll(context.Background(), key).Result()
|
||||
func (s *RedisStore) HSet(ctx context.Context, key string, values map[string]any) error {
|
||||
return s.client.HSet(ctx, key, values).Err()
|
||||
}
|
||||
|
||||
func (s *RedisStore) HIncrBy(key, field string, incr int64) (int64, error) {
|
||||
return s.client.HIncrBy(context.Background(), key, field, incr).Result()
|
||||
func (s *RedisStore) HGet(ctx context.Context, key, field string) (string, error) {
|
||||
val, err := s.client.HGet(ctx, key, field).Result()
|
||||
if err != nil {
|
||||
if errors.Is(err, redis.Nil) {
|
||||
return "", ErrNotFound
|
||||
}
|
||||
return "", err
|
||||
}
|
||||
return val, nil
|
||||
}
|
||||
func (s *RedisStore) HDel(key string, fields ...string) error {
|
||||
func (s *RedisStore) HGetAll(ctx context.Context, key string) (map[string]string, error) {
|
||||
return s.client.HGetAll(ctx, key).Result()
|
||||
}
|
||||
|
||||
func (s *RedisStore) HIncrBy(ctx context.Context, key, field string, incr int64) (int64, error) {
|
||||
return s.client.HIncrBy(ctx, key, field, incr).Result()
|
||||
}
|
||||
func (s *RedisStore) HDel(ctx context.Context, key string, fields ...string) error {
|
||||
if len(fields) == 0 {
|
||||
return nil
|
||||
}
|
||||
return s.client.HDel(context.Background(), key, fields...).Err()
|
||||
return s.client.HDel(ctx, key, fields...).Err()
|
||||
}
|
||||
func (s *RedisStore) LPush(key string, values ...any) error {
|
||||
return s.client.LPush(context.Background(), key, values...).Err()
|
||||
func (s *RedisStore) LPush(ctx context.Context, key string, values ...any) error {
|
||||
return s.client.LPush(ctx, key, values...).Err()
|
||||
}
|
||||
|
||||
func (s *RedisStore) LRem(key string, count int64, value any) error {
|
||||
return s.client.LRem(context.Background(), key, count, value).Err()
|
||||
func (s *RedisStore) LRem(ctx context.Context, key string, count int64, value any) error {
|
||||
return s.client.LRem(ctx, key, count, value).Err()
|
||||
}
|
||||
|
||||
func (s *RedisStore) Rotate(key string) (string, error) {
|
||||
val, err := s.client.RPopLPush(context.Background(), key, key).Result()
|
||||
func (s *RedisStore) Rotate(ctx context.Context, key string) (string, error) {
|
||||
val, err := s.client.RPopLPush(ctx, key, key).Result()
|
||||
if err != nil {
|
||||
if errors.Is(err, redis.Nil) {
|
||||
return "", ErrNotFound
|
||||
@@ -110,29 +125,40 @@ func (s *RedisStore) Rotate(key string) (string, error) {
|
||||
return val, nil
|
||||
}
|
||||
|
||||
func (s *RedisStore) SAdd(key string, members ...any) error {
|
||||
return s.client.SAdd(context.Background(), key, members...).Err()
|
||||
func (s *RedisStore) MSet(ctx context.Context, values map[string]any) error {
|
||||
if len(values) == 0 {
|
||||
return nil
|
||||
}
|
||||
// Redis MSet 命令需要 [key1, value1, key2, value2, ...] 格式的切片
|
||||
pairs := make([]interface{}, 0, len(values)*2)
|
||||
for k, v := range values {
|
||||
pairs = append(pairs, k, v)
|
||||
}
|
||||
return s.client.MSet(ctx, pairs...).Err()
|
||||
}
|
||||
|
||||
func (s *RedisStore) SPopN(key string, count int64) ([]string, error) {
|
||||
return s.client.SPopN(context.Background(), key, count).Result()
|
||||
func (s *RedisStore) SAdd(ctx context.Context, key string, members ...any) error {
|
||||
return s.client.SAdd(ctx, key, members...).Err()
|
||||
}
|
||||
|
||||
func (s *RedisStore) SMembers(key string) ([]string, error) {
|
||||
return s.client.SMembers(context.Background(), key).Result()
|
||||
func (s *RedisStore) SPopN(ctx context.Context, key string, count int64) ([]string, error) {
|
||||
return s.client.SPopN(ctx, key, count).Result()
|
||||
}
|
||||
|
||||
func (s *RedisStore) SRem(key string, members ...any) error {
|
||||
func (s *RedisStore) SMembers(ctx context.Context, key string) ([]string, error) {
|
||||
return s.client.SMembers(ctx, key).Result()
|
||||
}
|
||||
|
||||
func (s *RedisStore) SRem(ctx context.Context, key string, members ...any) error {
|
||||
if len(members) == 0 {
|
||||
return nil
|
||||
}
|
||||
return s.client.SRem(context.Background(), key, members...).Err()
|
||||
return s.client.SRem(ctx, key, members...).Err()
|
||||
}
|
||||
|
||||
func (s *RedisStore) SRandMember(key string) (string, error) {
|
||||
member, err := s.client.SRandMember(context.Background(), key).Result()
|
||||
func (s *RedisStore) SRandMember(ctx context.Context, key string) (string, error) {
|
||||
member, err := s.client.SRandMember(ctx, key).Result()
|
||||
if err != nil {
|
||||
|
||||
if errors.Is(err, redis.Nil) {
|
||||
return "", ErrNotFound
|
||||
}
|
||||
@@ -141,81 +167,50 @@ func (s *RedisStore) SRandMember(key string) (string, error) {
|
||||
return member, nil
|
||||
}
|
||||
|
||||
// === 新增方法实现 ===
|
||||
func (s *RedisStore) SUnionStore(ctx context.Context, destination string, keys ...string) (int64, error) {
|
||||
if len(keys) == 0 {
|
||||
return 0, nil
|
||||
}
|
||||
return s.client.SUnionStore(ctx, destination, keys...).Result()
|
||||
}
|
||||
|
||||
func (s *RedisStore) ZAdd(key string, members map[string]float64) error {
|
||||
func (s *RedisStore) ZAdd(ctx context.Context, key string, members map[string]float64) error {
|
||||
if len(members) == 0 {
|
||||
return nil
|
||||
}
|
||||
redisMembers := make([]redis.Z, 0, len(members))
|
||||
redisMembers := make([]redis.Z, len(members))
|
||||
i := 0
|
||||
for member, score := range members {
|
||||
redisMembers = append(redisMembers, redis.Z{Score: score, Member: member})
|
||||
redisMembers[i] = redis.Z{Score: score, Member: member}
|
||||
i++
|
||||
}
|
||||
return s.client.ZAdd(context.Background(), key, redisMembers...).Err()
|
||||
return s.client.ZAdd(ctx, key, redisMembers...).Err()
|
||||
}
|
||||
func (s *RedisStore) ZRange(key string, start, stop int64) ([]string, error) {
|
||||
return s.client.ZRange(context.Background(), key, start, stop).Result()
|
||||
func (s *RedisStore) ZRange(ctx context.Context, key string, start, stop int64) ([]string, error) {
|
||||
return s.client.ZRange(ctx, key, start, stop).Result()
|
||||
}
|
||||
func (s *RedisStore) ZRem(key string, members ...any) error {
|
||||
func (s *RedisStore) ZRem(ctx context.Context, key string, members ...any) error {
|
||||
if len(members) == 0 {
|
||||
return nil
|
||||
}
|
||||
return s.client.ZRem(context.Background(), key, members...).Err()
|
||||
return s.client.ZRem(ctx, key, members...).Err()
|
||||
}
|
||||
func (s *RedisStore) PopAndCycleSetMember(mainKey, cooldownKey string) (string, error) {
|
||||
val, err := s.popAndCycleScript.Run(context.Background(), s.client, []string{mainKey, cooldownKey}).Result()
|
||||
func (s *RedisStore) PopAndCycleSetMember(ctx context.Context, mainKey, cooldownKey string) (string, error) {
|
||||
val, err := s.popAndCycleScript.Run(ctx, s.client, []string{mainKey, cooldownKey}).Result()
|
||||
if err != nil {
|
||||
if errors.Is(err, redis.Nil) {
|
||||
return "", ErrNotFound
|
||||
}
|
||||
return "", err
|
||||
}
|
||||
// Lua script returns a string, so we need to type assert
|
||||
if str, ok := val.(string); ok {
|
||||
return str, nil
|
||||
}
|
||||
return "", ErrNotFound // This happens if both sets were empty and the script returned nil
|
||||
return "", ErrNotFound
|
||||
}
|
||||
|
||||
type redisPipeliner struct{ pipe redis.Pipeliner }
|
||||
|
||||
func (p *redisPipeliner) HSet(key string, values map[string]any) {
|
||||
p.pipe.HSet(context.Background(), key, values)
|
||||
}
|
||||
func (p *redisPipeliner) HIncrBy(key, field string, incr int64) {
|
||||
p.pipe.HIncrBy(context.Background(), key, field, incr)
|
||||
}
|
||||
func (p *redisPipeliner) Exec() error {
|
||||
_, err := p.pipe.Exec(context.Background())
|
||||
return err
|
||||
}
|
||||
|
||||
func (p *redisPipeliner) Del(keys ...string) {
|
||||
if len(keys) > 0 {
|
||||
p.pipe.Del(context.Background(), keys...)
|
||||
}
|
||||
}
|
||||
|
||||
func (p *redisPipeliner) SAdd(key string, members ...any) {
|
||||
p.pipe.SAdd(context.Background(), key, members...)
|
||||
}
|
||||
|
||||
func (p *redisPipeliner) SRem(key string, members ...any) {
|
||||
if len(members) > 0 {
|
||||
p.pipe.SRem(context.Background(), key, members...)
|
||||
}
|
||||
}
|
||||
|
||||
func (p *redisPipeliner) LPush(key string, values ...any) {
|
||||
p.pipe.LPush(context.Background(), key, values...)
|
||||
}
|
||||
|
||||
func (p *redisPipeliner) LRem(key string, count int64, value any) {
|
||||
p.pipe.LRem(context.Background(), key, count, value)
|
||||
}
|
||||
|
||||
func (s *RedisStore) LIndex(key string, index int64) (string, error) {
|
||||
val, err := s.client.LIndex(context.Background(), key, index).Result()
|
||||
func (s *RedisStore) LIndex(ctx context.Context, key string, index int64) (string, error) {
|
||||
val, err := s.client.LIndex(ctx, key, index).Result()
|
||||
if err != nil {
|
||||
if errors.Is(err, redis.Nil) {
|
||||
return "", ErrNotFound
|
||||
@@ -225,47 +220,131 @@ func (s *RedisStore) LIndex(key string, index int64) (string, error) {
|
||||
return val, nil
|
||||
}
|
||||
|
||||
func (p *redisPipeliner) Expire(key string, expiration time.Duration) {
|
||||
p.pipe.Expire(context.Background(), key, expiration)
|
||||
type redisPipeliner struct {
|
||||
pipe redis.Pipeliner
|
||||
ctx context.Context
|
||||
}
|
||||
|
||||
func (s *RedisStore) Pipeline() Pipeliner {
|
||||
return &redisPipeliner{pipe: s.client.Pipeline()}
|
||||
func (s *RedisStore) Pipeline(ctx context.Context) Pipeliner {
|
||||
return &redisPipeliner{
|
||||
pipe: s.client.Pipeline(),
|
||||
ctx: ctx,
|
||||
}
|
||||
}
|
||||
|
||||
func (p *redisPipeliner) Exec() error {
|
||||
_, err := p.pipe.Exec(p.ctx)
|
||||
return err
|
||||
}
|
||||
|
||||
func (p *redisPipeliner) Del(keys ...string) { p.pipe.Del(p.ctx, keys...) }
|
||||
func (p *redisPipeliner) Expire(key string, expiration time.Duration) {
|
||||
p.pipe.Expire(p.ctx, key, expiration)
|
||||
}
|
||||
func (p *redisPipeliner) HSet(key string, values map[string]any) { p.pipe.HSet(p.ctx, key, values) }
|
||||
func (p *redisPipeliner) HIncrBy(key, field string, incr int64) {
|
||||
p.pipe.HIncrBy(p.ctx, key, field, incr)
|
||||
}
|
||||
func (p *redisPipeliner) LPush(key string, values ...any) { p.pipe.LPush(p.ctx, key, values...) }
|
||||
func (p *redisPipeliner) LRem(key string, count int64, value any) {
|
||||
p.pipe.LRem(p.ctx, key, count, value)
|
||||
}
|
||||
|
||||
func (p *redisPipeliner) Set(key string, value []byte, expiration time.Duration) {
|
||||
p.pipe.Set(p.ctx, key, value, expiration)
|
||||
}
|
||||
|
||||
func (p *redisPipeliner) MSet(values map[string]any) {
|
||||
if len(values) == 0 {
|
||||
return
|
||||
}
|
||||
p.pipe.MSet(p.ctx, values)
|
||||
}
|
||||
func (p *redisPipeliner) SAdd(key string, members ...any) { p.pipe.SAdd(p.ctx, key, members...) }
|
||||
func (p *redisPipeliner) SRem(key string, members ...any) { p.pipe.SRem(p.ctx, key, members...) }
|
||||
func (p *redisPipeliner) ZAdd(key string, members map[string]float64) {
|
||||
if len(members) == 0 {
|
||||
return
|
||||
}
|
||||
redisMembers := make([]redis.Z, len(members))
|
||||
i := 0
|
||||
for member, score := range members {
|
||||
redisMembers[i] = redis.Z{Score: score, Member: member}
|
||||
i++
|
||||
}
|
||||
p.pipe.ZAdd(p.ctx, key, redisMembers...)
|
||||
}
|
||||
func (p *redisPipeliner) ZRem(key string, members ...any) { p.pipe.ZRem(p.ctx, key, members...) }
|
||||
|
||||
type redisSubscription struct {
|
||||
pubsub *redis.PubSub
|
||||
msgChan chan *Message
|
||||
once sync.Once
|
||||
pubsub *redis.PubSub
|
||||
msgChan chan *Message
|
||||
logger *logrus.Entry
|
||||
wg sync.WaitGroup
|
||||
close context.CancelFunc
|
||||
channelName string
|
||||
}
|
||||
|
||||
func (s *RedisStore) Subscribe(ctx context.Context, channel string) (Subscription, error) {
|
||||
pubsub := s.client.Subscribe(ctx, channel)
|
||||
_, err := pubsub.Receive(ctx)
|
||||
if err != nil {
|
||||
_ = pubsub.Close()
|
||||
return nil, fmt.Errorf("failed to subscribe to channel %s: %w", channel, err)
|
||||
}
|
||||
subCtx, cancel := context.WithCancel(context.Background())
|
||||
sub := &redisSubscription{
|
||||
pubsub: pubsub,
|
||||
msgChan: make(chan *Message, 10),
|
||||
logger: s.logger,
|
||||
close: cancel,
|
||||
channelName: channel,
|
||||
}
|
||||
sub.wg.Add(1)
|
||||
go sub.bridge(subCtx)
|
||||
return sub, nil
|
||||
}
|
||||
|
||||
func (rs *redisSubscription) bridge(ctx context.Context) {
|
||||
defer rs.wg.Done()
|
||||
defer close(rs.msgChan)
|
||||
redisCh := rs.pubsub.Channel()
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case redisMsg, ok := <-redisCh:
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
msg := &Message{
|
||||
Channel: redisMsg.Channel,
|
||||
Payload: []byte(redisMsg.Payload),
|
||||
}
|
||||
select {
|
||||
case rs.msgChan <- msg:
|
||||
default:
|
||||
rs.logger.Warnf("Message dropped for channel '%s' due to slow consumer.", rs.channelName)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (rs *redisSubscription) Channel() <-chan *Message {
|
||||
rs.once.Do(func() {
|
||||
rs.msgChan = make(chan *Message)
|
||||
go func() {
|
||||
defer close(rs.msgChan)
|
||||
for redisMsg := range rs.pubsub.Channel() {
|
||||
rs.msgChan <- &Message{
|
||||
Channel: redisMsg.Channel,
|
||||
Payload: []byte(redisMsg.Payload),
|
||||
}
|
||||
}
|
||||
}()
|
||||
})
|
||||
return rs.msgChan
|
||||
}
|
||||
|
||||
func (rs *redisSubscription) Close() error { return rs.pubsub.Close() }
|
||||
|
||||
func (s *RedisStore) Publish(channel string, message []byte) error {
|
||||
return s.client.Publish(context.Background(), channel, message).Err()
|
||||
func (rs *redisSubscription) ChannelName() string {
|
||||
return rs.channelName
|
||||
}
|
||||
|
||||
func (s *RedisStore) Subscribe(channel string) (Subscription, error) {
|
||||
pubsub := s.client.Subscribe(context.Background(), channel)
|
||||
_, err := pubsub.Receive(context.Background())
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to subscribe to channel %s: %w", channel, err)
|
||||
}
|
||||
return &redisSubscription{pubsub: pubsub}, nil
|
||||
func (rs *redisSubscription) Close() error {
|
||||
rs.close()
|
||||
err := rs.pubsub.Close()
|
||||
rs.wg.Wait()
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *RedisStore) Publish(ctx context.Context, channel string, message []byte) error {
|
||||
return s.client.Publish(ctx, channel, message).Err()
|
||||
}
|
||||
|
||||
@@ -1,6 +1,9 @@
|
||||
// Filename: internal/store/store.go
|
||||
|
||||
package store
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"time"
|
||||
)
|
||||
@@ -17,6 +20,7 @@ type Message struct {
|
||||
// Subscription represents an active subscription to a pub/sub channel.
|
||||
type Subscription interface {
|
||||
Channel() <-chan *Message
|
||||
ChannelName() string
|
||||
Close() error
|
||||
}
|
||||
|
||||
@@ -31,6 +35,8 @@ type Pipeliner interface {
|
||||
HIncrBy(key, field string, incr int64)
|
||||
|
||||
// SET
|
||||
MSet(values map[string]any)
|
||||
Set(key string, value []byte, expiration time.Duration)
|
||||
SAdd(key string, members ...any)
|
||||
SRem(key string, members ...any)
|
||||
|
||||
@@ -38,6 +44,10 @@ type Pipeliner interface {
|
||||
LPush(key string, values ...any)
|
||||
LRem(key string, count int64, value any)
|
||||
|
||||
// ZSET
|
||||
ZAdd(key string, members map[string]float64)
|
||||
ZRem(key string, members ...any)
|
||||
|
||||
// Execution
|
||||
Exec() error
|
||||
}
|
||||
@@ -45,44 +55,48 @@ type Pipeliner interface {
|
||||
// Store is the master interface for our cache service.
|
||||
type Store interface {
|
||||
// Basic K/V operations
|
||||
Set(key string, value []byte, ttl time.Duration) error
|
||||
Get(key string) ([]byte, error)
|
||||
Del(keys ...string) error
|
||||
Exists(key string) (bool, error)
|
||||
SetNX(key string, value []byte, ttl time.Duration) (bool, error)
|
||||
Set(ctx context.Context, key string, value []byte, ttl time.Duration) error
|
||||
Get(ctx context.Context, key string) ([]byte, error)
|
||||
Del(ctx context.Context, keys ...string) error
|
||||
Exists(ctx context.Context, key string) (bool, error)
|
||||
SetNX(ctx context.Context, key string, value []byte, ttl time.Duration) (bool, error)
|
||||
MSet(ctx context.Context, values map[string]any) error
|
||||
|
||||
// HASH operations
|
||||
HSet(key string, values map[string]any) error
|
||||
HGetAll(key string) (map[string]string, error)
|
||||
HIncrBy(key, field string, incr int64) (int64, error)
|
||||
HDel(key string, fields ...string) error // [新增]
|
||||
HSet(ctx context.Context, key string, values map[string]any) error
|
||||
HGet(ctx context.Context, key, field string) (string, error)
|
||||
HGetAll(ctx context.Context, key string) (map[string]string, error)
|
||||
HIncrBy(ctx context.Context, key, field string, incr int64) (int64, error)
|
||||
HDel(ctx context.Context, key string, fields ...string) error
|
||||
|
||||
// LIST operations
|
||||
LPush(key string, values ...any) error
|
||||
LRem(key string, count int64, value any) error
|
||||
Rotate(key string) (string, error)
|
||||
LIndex(key string, index int64) (string, error)
|
||||
LPush(ctx context.Context, key string, values ...any) error
|
||||
LRem(ctx context.Context, key string, count int64, value any) error
|
||||
Rotate(ctx context.Context, key string) (string, error)
|
||||
LIndex(ctx context.Context, key string, index int64) (string, error)
|
||||
Expire(ctx context.Context, key string, expiration time.Duration) error
|
||||
|
||||
// SET operations
|
||||
SAdd(key string, members ...any) error
|
||||
SPopN(key string, count int64) ([]string, error)
|
||||
SMembers(key string) ([]string, error)
|
||||
SRem(key string, members ...any) error
|
||||
SRandMember(key string) (string, error)
|
||||
SAdd(ctx context.Context, key string, members ...any) error
|
||||
SPopN(ctx context.Context, key string, count int64) ([]string, error)
|
||||
SMembers(ctx context.Context, key string) ([]string, error)
|
||||
SRem(ctx context.Context, key string, members ...any) error
|
||||
SRandMember(ctx context.Context, key string) (string, error)
|
||||
SUnionStore(ctx context.Context, destination string, keys ...string) (int64, error)
|
||||
|
||||
// Pub/Sub operations
|
||||
Publish(channel string, message []byte) error
|
||||
Subscribe(channel string) (Subscription, error)
|
||||
Publish(ctx context.Context, channel string, message []byte) error
|
||||
Subscribe(ctx context.Context, channel string) (Subscription, error)
|
||||
|
||||
// Pipeline (optional) - 我们在redis实现它,内存版暂时不实现
|
||||
Pipeline() Pipeliner
|
||||
// Pipeline
|
||||
Pipeline(ctx context.Context) Pipeliner
|
||||
|
||||
// Close closes the store and releases any underlying resources.
|
||||
Close() error
|
||||
|
||||
// === 新增方法,支持轮询策略 ===
|
||||
ZAdd(key string, members map[string]float64) error
|
||||
ZRange(key string, start, stop int64) ([]string, error)
|
||||
ZRem(key string, members ...any) error
|
||||
PopAndCycleSetMember(mainKey, cooldownKey string) (string, error)
|
||||
// ZSET operations
|
||||
ZAdd(ctx context.Context, key string, members map[string]float64) error
|
||||
ZRange(ctx context.Context, key string, start, stop int64) ([]string, error)
|
||||
ZRem(ctx context.Context, key string, members ...any) error
|
||||
PopAndCycleSetMember(ctx context.Context, mainKey, cooldownKey string) (string, error)
|
||||
}
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package syncer
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"gemini-balancer/internal/store"
|
||||
"log"
|
||||
@@ -51,7 +52,7 @@ func (s *CacheSyncer[T]) Get() T {
|
||||
|
||||
func (s *CacheSyncer[T]) Invalidate() error {
|
||||
log.Printf("INFO: Publishing invalidation notification on channel '%s'", s.channelName)
|
||||
return s.store.Publish(s.channelName, []byte("reload"))
|
||||
return s.store.Publish(context.Background(), s.channelName, []byte("reload"))
|
||||
}
|
||||
|
||||
func (s *CacheSyncer[T]) Stop() {
|
||||
@@ -84,7 +85,7 @@ func (s *CacheSyncer[T]) listenForUpdates() {
|
||||
default:
|
||||
}
|
||||
|
||||
subscription, err := s.store.Subscribe(s.channelName)
|
||||
subscription, err := s.store.Subscribe(context.Background(), s.channelName)
|
||||
if err != nil {
|
||||
log.Printf("ERROR: Failed to subscribe to '%s', retrying in 5s: %v", s.channelName, err)
|
||||
time.Sleep(5 * time.Second)
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
// Filename: internal/task/task.go (最终校准版)
|
||||
// Filename: internal/task/task.go
|
||||
package task
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
@@ -15,15 +16,13 @@ const (
|
||||
ResultTTL = 60 * time.Minute
|
||||
)
|
||||
|
||||
// Reporter 接口,定义了领域如何与任务服务交互。
|
||||
type Reporter interface {
|
||||
StartTask(keyGroupID uint, taskType, resourceID string, total int, timeout time.Duration) (*Status, error)
|
||||
EndTaskByID(taskID, resourceID string, result any, taskErr error)
|
||||
UpdateProgressByID(taskID string, processed int) error
|
||||
UpdateTotalByID(taskID string, total int) error
|
||||
StartTask(ctx context.Context, keyGroupID uint, taskType, resourceID string, total int, timeout time.Duration) (*Status, error)
|
||||
EndTaskByID(ctx context.Context, taskID, resourceID string, result any, taskErr error)
|
||||
UpdateProgressByID(ctx context.Context, taskID string, processed int) error
|
||||
UpdateTotalByID(ctx context.Context, taskID string, total int) error
|
||||
}
|
||||
|
||||
// Status 代表一个后台任务的完整状态
|
||||
type Status struct {
|
||||
ID string `json:"id"`
|
||||
TaskType string `json:"task_type"`
|
||||
@@ -38,13 +37,11 @@ type Status struct {
|
||||
DurationSeconds float64 `json:"duration_seconds,omitempty"`
|
||||
}
|
||||
|
||||
// Task 是任务管理的核心服务
|
||||
type Task struct {
|
||||
store store.Store
|
||||
logger *logrus.Entry
|
||||
}
|
||||
|
||||
// NewTask 是 Task 的构造函数
|
||||
func NewTask(store store.Store, logger *logrus.Logger) *Task {
|
||||
return &Task{
|
||||
store: store,
|
||||
@@ -62,15 +59,14 @@ func (s *Task) getTaskDataKey(taskID string) string {
|
||||
return fmt.Sprintf("task:data:%s", taskID)
|
||||
}
|
||||
|
||||
// --- 新增的輔助函數,用於獲取原子標記的鍵 ---
|
||||
func (s *Task) getIsRunningFlagKey(taskID string) string {
|
||||
return fmt.Sprintf("task:running:%s", taskID)
|
||||
}
|
||||
|
||||
func (s *Task) StartTask(keyGroupID uint, taskType, resourceID string, total int, timeout time.Duration) (*Status, error) {
|
||||
func (s *Task) StartTask(ctx context.Context, keyGroupID uint, taskType, resourceID string, total int, timeout time.Duration) (*Status, error) {
|
||||
lockKey := s.getResourceLockKey(resourceID)
|
||||
|
||||
if existingTaskID, err := s.store.Get(lockKey); err == nil && len(existingTaskID) > 0 {
|
||||
if existingTaskID, err := s.store.Get(ctx, lockKey); err == nil && len(existingTaskID) > 0 {
|
||||
return nil, fmt.Errorf("a task is already running for this resource (ID: %s)", string(existingTaskID))
|
||||
}
|
||||
|
||||
@@ -94,35 +90,34 @@ func (s *Task) StartTask(keyGroupID uint, taskType, resourceID string, total int
|
||||
timeout = ResultTTL * 24
|
||||
}
|
||||
|
||||
if err := s.store.Set(lockKey, []byte(taskID), timeout); err != nil {
|
||||
if err := s.store.Set(ctx, lockKey, []byte(taskID), timeout); err != nil {
|
||||
return nil, fmt.Errorf("failed to acquire task resource lock: %w", err)
|
||||
}
|
||||
if err := s.store.Set(taskKey, statusBytes, timeout); err != nil {
|
||||
_ = s.store.Del(lockKey)
|
||||
if err := s.store.Set(ctx, taskKey, statusBytes, timeout); err != nil {
|
||||
_ = s.store.Del(ctx, lockKey)
|
||||
return nil, fmt.Errorf("failed to set new task data in store: %w", err)
|
||||
}
|
||||
|
||||
// 創建一個獨立的“運行中”標記,它的存在與否是原子性的
|
||||
if err := s.store.Set(runningFlagKey, []byte("1"), timeout); err != nil {
|
||||
_ = s.store.Del(lockKey)
|
||||
_ = s.store.Del(taskKey)
|
||||
if err := s.store.Set(ctx, runningFlagKey, []byte("1"), timeout); err != nil {
|
||||
_ = s.store.Del(ctx, lockKey)
|
||||
_ = s.store.Del(ctx, taskKey)
|
||||
return nil, fmt.Errorf("failed to set task running flag: %w", err)
|
||||
}
|
||||
return status, nil
|
||||
}
|
||||
|
||||
func (s *Task) EndTaskByID(taskID, resourceID string, resultData any, taskErr error) {
|
||||
func (s *Task) EndTaskByID(ctx context.Context, taskID, resourceID string, resultData any, taskErr error) {
|
||||
lockKey := s.getResourceLockKey(resourceID)
|
||||
defer func() {
|
||||
if err := s.store.Del(lockKey); err != nil {
|
||||
if err := s.store.Del(ctx, lockKey); err != nil {
|
||||
s.logger.WithError(err).Warnf("Failed to release resource lock '%s' for task %s.", lockKey, taskID)
|
||||
}
|
||||
}()
|
||||
runningFlagKey := s.getIsRunningFlagKey(taskID)
|
||||
_ = s.store.Del(runningFlagKey)
|
||||
status, err := s.GetStatus(taskID)
|
||||
if err != nil {
|
||||
_ = s.store.Del(ctx, runningFlagKey)
|
||||
|
||||
status, err := s.GetStatus(ctx, taskID)
|
||||
if err != nil {
|
||||
s.logger.WithError(err).Errorf("Could not get task status for task ID %s during EndTask. Lock has been released, but task data may be stale.", taskID)
|
||||
return
|
||||
}
|
||||
@@ -141,15 +136,14 @@ func (s *Task) EndTaskByID(taskID, resourceID string, resultData any, taskErr er
|
||||
}
|
||||
updatedTaskBytes, _ := json.Marshal(status)
|
||||
taskKey := s.getTaskDataKey(taskID)
|
||||
if err := s.store.Set(taskKey, updatedTaskBytes, ResultTTL); err != nil {
|
||||
if err := s.store.Set(ctx, taskKey, updatedTaskBytes, ResultTTL); err != nil {
|
||||
s.logger.WithError(err).Errorf("Failed to save final status for task %s.", taskID)
|
||||
}
|
||||
}
|
||||
|
||||
// GetStatus 通过ID获取任务状态,供外部(如API Handler)调用
|
||||
func (s *Task) GetStatus(taskID string) (*Status, error) {
|
||||
func (s *Task) GetStatus(ctx context.Context, taskID string) (*Status, error) {
|
||||
taskKey := s.getTaskDataKey(taskID)
|
||||
statusBytes, err := s.store.Get(taskKey)
|
||||
statusBytes, err := s.store.Get(ctx, taskKey)
|
||||
if err != nil {
|
||||
if errors.Is(err, store.ErrNotFound) {
|
||||
return nil, errors.New("task not found")
|
||||
@@ -161,22 +155,18 @@ func (s *Task) GetStatus(taskID string) (*Status, error) {
|
||||
if err := json.Unmarshal(statusBytes, &status); err != nil {
|
||||
return nil, fmt.Errorf("corrupted task data in store for ID %s", taskID)
|
||||
}
|
||||
|
||||
if !status.IsRunning && status.FinishedAt != nil {
|
||||
status.DurationSeconds = status.FinishedAt.Sub(status.StartedAt).Seconds()
|
||||
}
|
||||
|
||||
return &status, nil
|
||||
}
|
||||
|
||||
// UpdateProgressByID 通过ID更新任务进度
|
||||
func (s *Task) updateTask(taskID string, updater func(status *Status)) error {
|
||||
func (s *Task) updateTask(ctx context.Context, taskID string, updater func(status *Status)) error {
|
||||
runningFlagKey := s.getIsRunningFlagKey(taskID)
|
||||
if _, err := s.store.Get(runningFlagKey); err != nil {
|
||||
// 任务已结束,静默返回是预期行为。
|
||||
if _, err := s.store.Get(ctx, runningFlagKey); err != nil {
|
||||
return nil
|
||||
}
|
||||
status, err := s.GetStatus(taskID)
|
||||
status, err := s.GetStatus(ctx, taskID)
|
||||
if err != nil {
|
||||
s.logger.WithError(err).Warnf("Failed to get task status for update on task %s. Update will not be saved.", taskID)
|
||||
return nil
|
||||
@@ -184,7 +174,6 @@ func (s *Task) updateTask(taskID string, updater func(status *Status)) error {
|
||||
if !status.IsRunning {
|
||||
return nil
|
||||
}
|
||||
// 调用传入的 updater 函数来修改 status
|
||||
updater(status)
|
||||
statusBytes, marshalErr := json.Marshal(status)
|
||||
if marshalErr != nil {
|
||||
@@ -192,23 +181,20 @@ func (s *Task) updateTask(taskID string, updater func(status *Status)) error {
|
||||
return nil
|
||||
}
|
||||
taskKey := s.getTaskDataKey(taskID)
|
||||
// 使用更长的TTL,确保运行中的任务不会过早过期
|
||||
if err := s.store.Set(taskKey, statusBytes, ResultTTL*24); err != nil {
|
||||
if err := s.store.Set(ctx, taskKey, statusBytes, ResultTTL*24); err != nil {
|
||||
s.logger.WithError(err).Warnf("Failed to save update for task %s.", taskID)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// [REFACTORED] UpdateProgressByID 现在是一个简单的、调用通用更新器的包装器。
|
||||
func (s *Task) UpdateProgressByID(taskID string, processed int) error {
|
||||
return s.updateTask(taskID, func(status *Status) {
|
||||
func (s *Task) UpdateProgressByID(ctx context.Context, taskID string, processed int) error {
|
||||
return s.updateTask(ctx, taskID, func(status *Status) {
|
||||
status.Processed = processed
|
||||
})
|
||||
}
|
||||
|
||||
// [REFACTORED] UpdateTotalByID 现在也是一个简单的、调用通用更新器的包装器。
|
||||
func (s *Task) UpdateTotalByID(taskID string, total int) error {
|
||||
return s.updateTask(taskID, func(status *Status) {
|
||||
func (s *Task) UpdateTotalByID(ctx context.Context, taskID string, total int) error {
|
||||
return s.updateTask(ctx, taskID, func(status *Status) {
|
||||
status.Total = total
|
||||
})
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user