// Filename: internal/handlers/keygroup_handler.go package handlers import ( "context" "encoding/json" "fmt" "gemini-balancer/internal/errors" "gemini-balancer/internal/models" "gemini-balancer/internal/response" "gemini-balancer/internal/service" "gemini-balancer/internal/store" "regexp" "strconv" "time" "github.com/gin-gonic/gin" "github.com/microcosm-cc/bluemonday" ) type KeyGroupHandler struct { groupManager *service.GroupManager store store.Store queryService *service.DashboardQueryService } func NewKeyGroupHandler(gm *service.GroupManager, s store.Store, qs *service.DashboardQueryService) *KeyGroupHandler { return &KeyGroupHandler{ groupManager: gm, queryService: qs, store: s, } } func isValidGroupName(name string) bool { if name == "" { return false } match, _ := regexp.MatchString("^[a-z0-9_-]{3,30}$", name) return match } type KeyGroupOperationalSettings struct { EnableKeyCheck *bool `json:"enable_key_check"` KeyCheckIntervalMinutes *int `json:"key_check_interval_minutes"` KeyBlacklistThreshold *int `json:"key_blacklist_threshold"` KeyCooldownMinutes *int `json:"key_cooldown_minutes"` KeyCheckConcurrency *int `json:"key_check_concurrency"` KeyCheckEndpoint *string `json:"key_check_endpoint"` KeyCheckModel *string `json:"key_check_model"` MaxRetries *int `json:"max_retries"` EnableSmartGateway *bool `json:"enable_smart_gateway"` } type CreateKeyGroupRequest struct { Name string `json:"name" binding:"required"` DisplayName string `json:"display_name"` Description string `json:"description"` PollingStrategy string `json:"polling_strategy" binding:"omitempty,oneof=sequential random weighted"` EnableProxy bool `json:"enable_proxy"` ChannelType string `json:"channel_type"` KeyGroupOperationalSettings } type UpdateKeyGroupRequest struct { Name *string `json:"name"` DisplayName *string `json:"display_name"` Description *string `json:"description"` PollingStrategy *string `json:"polling_strategy" binding:"omitempty,oneof=sequential random weighted"` EnableProxy *bool `json:"enable_proxy"` ChannelType *string `json:"channel_type"` KeyGroupOperationalSettings AllowedUpstreams []string `json:"allowed_upstreams"` AllowedModels []string `json:"allowed_models"` } type KeyGroupResponse struct { ID uint `json:"id"` Name string `json:"name"` DisplayName string `json:"display_name"` Description string `json:"description"` PollingStrategy models.PollingStrategy `json:"polling_strategy"` ChannelType string `json:"channel_type"` EnableProxy bool `json:"enable_proxy"` APIKeysCount int64 `json:"api_keys_count"` CreatedAt time.Time `json:"created_at"` UpdatedAt time.Time `json:"updated_at"` Order int `json:"order"` AllowedModels []string `json:"allowed_models"` AllowedUpstreams []string `json:"allowed_upstreams"` } type KeyGroupDetailsResponse struct { KeyGroupResponse Settings *models.GroupSettings `json:"settings,omitempty"` RequestConfig *models.RequestConfig `json:"request_config,omitempty"` } func transformModelsToStrings(mappings []*models.GroupModelMapping) []string { modelNames := make([]string, 0, len(mappings)) for _, mapping := range mappings { if mapping != nil { modelNames = append(modelNames, mapping.ModelName) } } return modelNames } func transformUpstreamsToStrings(upstreams []*models.UpstreamEndpoint) []string { urls := make([]string, 0, len(upstreams)) for _, upstream := range upstreams { if upstream != nil { urls = append(urls, upstream.URL) } } return urls } func (h *KeyGroupHandler) newKeyGroupResponse(group *models.KeyGroup, keyCount int64) KeyGroupResponse { return KeyGroupResponse{ ID: group.ID, Name: group.Name, DisplayName: group.DisplayName, Description: group.Description, PollingStrategy: group.PollingStrategy, ChannelType: group.ChannelType, EnableProxy: group.EnableProxy, APIKeysCount: keyCount, CreatedAt: group.CreatedAt, UpdatedAt: group.UpdatedAt, Order: group.Order, AllowedModels: transformModelsToStrings(group.AllowedModels), AllowedUpstreams: transformUpstreamsToStrings(group.AllowedUpstreams), } } func packGroupSettings(settings KeyGroupOperationalSettings) *models.KeyGroupSettings { return &models.KeyGroupSettings{ EnableKeyCheck: settings.EnableKeyCheck, KeyCheckIntervalMinutes: settings.KeyCheckIntervalMinutes, KeyBlacklistThreshold: settings.KeyBlacklistThreshold, KeyCooldownMinutes: settings.KeyCooldownMinutes, KeyCheckConcurrency: settings.KeyCheckConcurrency, KeyCheckEndpoint: settings.KeyCheckEndpoint, KeyCheckModel: settings.KeyCheckModel, MaxRetries: settings.MaxRetries, EnableSmartGateway: settings.EnableSmartGateway, } } func (h *KeyGroupHandler) getGroupFromContext(c *gin.Context) (*models.KeyGroup, *errors.APIError) { id, err := strconv.Atoi(c.Param("id")) if err != nil { return nil, errors.NewAPIError(errors.ErrBadRequest, "Invalid ID format") } group, ok := h.groupManager.GetGroupByID(uint(id)) if !ok { return nil, errors.NewAPIError(errors.ErrResourceNotFound, "Group not found") } return group, nil } func applyUpdateRequestToGroup(req *UpdateKeyGroupRequest, group *models.KeyGroup) { if req.Name != nil { group.Name = *req.Name } p := bluemonday.StripTagsPolicy() if req.DisplayName != nil { group.DisplayName = p.Sanitize(*req.DisplayName) } if req.Description != nil { group.Description = p.Sanitize(*req.Description) } if req.PollingStrategy != nil { group.PollingStrategy = models.PollingStrategy(*req.PollingStrategy) } if req.EnableProxy != nil { group.EnableProxy = *req.EnableProxy } if req.ChannelType != nil { group.ChannelType = *req.ChannelType } } // publishGroupChangeEvent encapsulates the logic for marshaling and publishing a group change event. func (h *KeyGroupHandler) publishGroupChangeEvent(groupID uint, reason string) { go func() { ctx := context.Background() event := models.KeyStatusChangedEvent{GroupID: groupID, ChangeReason: reason} eventData, _ := json.Marshal(event) _ = h.store.Publish(ctx, models.TopicKeyStatusChanged, eventData) }() } // --- Handler 方法 --- func (h *KeyGroupHandler) CreateKeyGroup(c *gin.Context) { var req CreateKeyGroupRequest if err := c.ShouldBindJSON(&req); err != nil { response.Error(c, errors.NewAPIError(errors.ErrInvalidJSON, err.Error())) return } if !isValidGroupName(req.Name) { response.Error(c, errors.NewAPIError(errors.ErrValidation, "Invalid group name. Must be 3-30 characters, lowercase letters, numbers, hyphens, or underscores.")) return } p := bluemonday.StripTagsPolicy() sanitizedDisplayName := p.Sanitize(req.DisplayName) sanitizedDescription := p.Sanitize(req.Description) keyGroup := &models.KeyGroup{ Name: req.Name, DisplayName: sanitizedDisplayName, Description: sanitizedDescription, PollingStrategy: models.PollingStrategy(req.PollingStrategy), EnableProxy: req.EnableProxy, ChannelType: req.ChannelType, } if keyGroup.PollingStrategy == "" { keyGroup.PollingStrategy = models.StrategySequential } if keyGroup.ChannelType == "" { keyGroup.ChannelType = "gemini" } groupSettings := packGroupSettings(req.KeyGroupOperationalSettings) if err := h.groupManager.CreateKeyGroup(keyGroup, groupSettings); err != nil { response.Error(c, errors.ParseDBError(err)) return } h.publishGroupChangeEvent(keyGroup.ID, "group_created") response.Created(c, h.newKeyGroupResponse(keyGroup, 0)) } // 1. GET /keygroups - 返回所有组的列表 // 2. GET /keygroups/:id - 返回指定ID的单个组 func (h *KeyGroupHandler) GetKeyGroups(c *gin.Context) { if idStr := c.Param("id"); idStr != "" { group, apiErr := h.getGroupFromContext(c) if apiErr != nil { response.Error(c, apiErr) return } keyCount := h.groupManager.GetKeyCount(group.ID) baseResponse := h.newKeyGroupResponse(group, keyCount) detailedResponse := KeyGroupDetailsResponse{ KeyGroupResponse: baseResponse, Settings: group.Settings, RequestConfig: group.RequestConfig, } response.Success(c, detailedResponse) return } allGroups := h.groupManager.GetAllGroups() responses := make([]KeyGroupResponse, 0, len(allGroups)) for _, group := range allGroups { keyCount := h.groupManager.GetKeyCount(group.ID) responses = append(responses, h.newKeyGroupResponse(group, keyCount)) } response.Success(c, responses) } func (h *KeyGroupHandler) UpdateKeyGroup(c *gin.Context) { group, apiErr := h.getGroupFromContext(c) if apiErr != nil { response.Error(c, apiErr) return } var req UpdateKeyGroupRequest if err := c.ShouldBindJSON(&req); err != nil { response.Error(c, errors.NewAPIError(errors.ErrInvalidJSON, err.Error())) return } if req.Name != nil && !isValidGroupName(*req.Name) { response.Error(c, errors.NewAPIError(errors.ErrValidation, "Invalid group name format.")) return } applyUpdateRequestToGroup(&req, group) groupSettings := packGroupSettings(req.KeyGroupOperationalSettings) err := h.groupManager.UpdateKeyGroup(group, groupSettings, req.AllowedUpstreams, req.AllowedModels) if err != nil { response.Error(c, errors.ParseDBError(err)) return } h.publishGroupChangeEvent(group.ID, "group_updated") freshGroup, _ := h.groupManager.GetGroupByID(group.ID) keyCount := h.groupManager.GetKeyCount(freshGroup.ID) response.Success(c, h.newKeyGroupResponse(freshGroup, keyCount)) } func (h *KeyGroupHandler) DeleteKeyGroup(c *gin.Context) { group, apiErr := h.getGroupFromContext(c) if apiErr != nil { response.Error(c, apiErr) return } groupName := group.Name if err := h.groupManager.DeleteKeyGroup(group.ID); err != nil { response.Error(c, errors.ParseDBError(err)) return } h.publishGroupChangeEvent(group.ID, "group_deleted") response.Success(c, gin.H{"message": fmt.Sprintf("Group '%s' and its associated keys deleted successfully", groupName)}) } func (h *KeyGroupHandler) GetKeyGroupStats(c *gin.Context) { group, apiErr := h.getGroupFromContext(c) if apiErr != nil { response.Error(c, apiErr) return } stats, err := h.queryService.GetGroupStats(c.Request.Context(), group.ID) if err != nil { response.Error(c, errors.NewAPIError(errors.ErrDatabase, err.Error())) return } response.Success(c, stats) } func (h *KeyGroupHandler) CloneKeyGroup(c *gin.Context) { group, apiErr := h.getGroupFromContext(c) if apiErr != nil { response.Error(c, apiErr) return } clonedGroup, err := h.groupManager.CloneKeyGroup(group.ID) if err != nil { response.Error(c, errors.ParseDBError(err)) return } keyCount := int64(len(clonedGroup.Mappings)) response.Created(c, h.newKeyGroupResponse(clonedGroup, keyCount)) } func (h *KeyGroupHandler) UpdateKeyGroupOrder(c *gin.Context) { var payload []service.UpdateOrderPayload if err := c.ShouldBindJSON(&payload); err != nil { response.Error(c, errors.NewAPIError(errors.ErrInvalidJSON, err.Error())) return } if len(payload) == 0 { response.Success(c, gin.H{"message": "No order data to update."}) return } if err := h.groupManager.UpdateOrder(payload); err != nil { response.Error(c, errors.ParseDBError(err)) return } response.Success(c, gin.H{"message": "Group order updated successfully."}) }