Files
gemini-banlancer/internal/handlers/keygroup_handler.go
2025-11-20 12:24:05 +08:00

370 lines
12 KiB
Go

// Filename: internal/handlers/keygroup_handler.go
package handlers
import (
"encoding/json"
"fmt"
"gemini-balancer/internal/errors"
"gemini-balancer/internal/models"
"gemini-balancer/internal/response"
"gemini-balancer/internal/service"
"gemini-balancer/internal/store"
"regexp"
"strconv"
"time"
"github.com/gin-gonic/gin"
"github.com/microcosm-cc/bluemonday"
)
// KeyGroupHandler serves the HTTP endpoints that manage key groups.
// Group reads and writes go through the GroupManager, change notifications
// are published through the Store, and statistics queries are delegated to
// the DashboardQueryService.
type KeyGroupHandler struct {
	groupManager *service.GroupManager          // group CRUD, cloning, ordering and key counts
	queryService *service.DashboardQueryService // per-group statistics
	store        store.Store                    // used to Publish group change events
}
// NewKeyGroupHandler wires a KeyGroupHandler to its group manager, store and
// dashboard query service. No validation is performed on the arguments.
func NewKeyGroupHandler(gm *service.GroupManager, s store.Store, qs *service.DashboardQueryService) *KeyGroupHandler {
	h := new(KeyGroupHandler)
	h.groupManager = gm
	h.store = s
	h.queryService = qs
	return h
}
// DTOs & helper functions
// groupNamePattern matches valid key group names: 3 to 30 characters drawn
// from lowercase ASCII letters, digits, underscores and hyphens. It is
// compiled once at package initialization rather than on every validation.
var groupNamePattern = regexp.MustCompile(`^[a-z0-9_-]{3,30}$`)

// isValidGroupName reports whether name is an acceptable key group name.
// The empty string is rejected because the pattern requires at least three
// characters.
func isValidGroupName(name string) bool {
	return groupNamePattern.MatchString(name)
}
// KeyGroupOperationalSettings defines the shared operational settings for a key group.
//
// It is embedded in both CreateKeyGroupRequest and UpdateKeyGroupRequest, so
// encoding/json promotes its fields to the top level of those request bodies.
// Every field is a pointer so that an omitted JSON field (nil) can be told
// apart from an explicit zero value. packGroupSettings copies these pointers
// unchanged into models.KeyGroupSettings.
type KeyGroupOperationalSettings struct {
	EnableKeyCheck          *bool   `json:"enable_key_check"`
	KeyCheckIntervalMinutes *int    `json:"key_check_interval_minutes"`
	KeyBlacklistThreshold   *int    `json:"key_blacklist_threshold"`
	KeyCooldownMinutes      *int    `json:"key_cooldown_minutes"`
	KeyCheckConcurrency     *int    `json:"key_check_concurrency"`
	KeyCheckEndpoint        *string `json:"key_check_endpoint"`
	KeyCheckModel           *string `json:"key_check_model"`
	MaxRetries              *int    `json:"max_retries"`
	EnableSmartGateway      *bool   `json:"enable_smart_gateway"`
}
// CreateKeyGroupRequest is the JSON body accepted by CreateKeyGroup.
//
// Name is required by the binding, and the handler additionally checks it
// with isValidGroupName. The handler strips HTML tags from DisplayName and
// Description. When non-empty, PollingStrategy must be one of sequential,
// random or weighted. The handler defaults an empty PollingStrategy to
// sequential and an empty ChannelType to "gemini".
type CreateKeyGroupRequest struct {
	Name            string `json:"name" binding:"required"`
	DisplayName     string `json:"display_name"`
	Description     string `json:"description"`
	PollingStrategy string `json:"polling_strategy" binding:"omitempty,oneof=sequential random weighted"`
	EnableProxy     bool   `json:"enable_proxy"`
	ChannelType     string `json:"channel_type"`
	// Embedded shared operational settings; their JSON fields are promoted to
	// the top level of the request body.
	KeyGroupOperationalSettings
}
// UpdateKeyGroupRequest is the JSON body accepted by UpdateKeyGroup.
//
// All scalar fields are pointers: nil means "not supplied, leave unchanged"
// (applyUpdateRequestToGroup only copies non-nil fields). Name, when present,
// is checked with isValidGroupName by the handler.
type UpdateKeyGroupRequest struct {
	Name            *string `json:"name"`
	DisplayName     *string `json:"display_name"`
	Description     *string `json:"description"`
	PollingStrategy *string `json:"polling_strategy" binding:"omitempty,oneof=sequential random weighted"`
	EnableProxy     *bool   `json:"enable_proxy"`
	ChannelType     *string `json:"channel_type"`
	// Embedded shared operational settings; their JSON fields are promoted to
	// the top level of the request body.
	KeyGroupOperationalSettings
	// M:N associations, passed straight through to GroupManager.UpdateKeyGroup.
	// NOTE(review): whether an omitted (nil) list preserves or clears the
	// existing associations is decided by the group manager — confirm there.
	AllowedUpstreams []string `json:"allowed_upstreams"`
	AllowedModels    []string `json:"allowed_models"`
}
// KeyGroupResponse is the JSON representation of a key group, built by
// newKeyGroupResponse. APIKeysCount is supplied by the caller rather than
// read from the model. AllowedModels and AllowedUpstreams are the group's
// associations flattened to model names and upstream URLs respectively.
type KeyGroupResponse struct {
	ID               uint                   `json:"id"`
	Name             string                 `json:"name"`
	DisplayName      string                 `json:"display_name"`
	Description      string                 `json:"description"`
	PollingStrategy  models.PollingStrategy `json:"polling_strategy"`
	ChannelType      string                 `json:"channel_type"`
	EnableProxy      bool                   `json:"enable_proxy"`
	APIKeysCount     int64                  `json:"api_keys_count"`
	CreatedAt        time.Time              `json:"created_at"`
	UpdatedAt        time.Time              `json:"updated_at"`
	Order            int                    `json:"order"`
	AllowedModels    []string               `json:"allowed_models"`
	AllowedUpstreams []string               `json:"allowed_upstreams"`
}
// KeyGroupDetailsResponse is the detailed response for a single group
// (GET /keygroups/:id). It embeds KeyGroupResponse, so those fields are
// flattened into the same JSON object, and adds the group's settings and
// request config, each omitted from the JSON when nil.
type KeyGroupDetailsResponse struct {
	KeyGroupResponse
	Settings      *models.GroupSettings  `json:"settings,omitempty"`
	RequestConfig *models.RequestConfig `json:"request_config,omitempty"`
}
// transformModelsToStrings returns the ModelName of every non-nil mapping,
// in input order. The result is never nil, so it encodes to JSON as [] rather
// than null when there are no mappings.
func transformModelsToStrings(mappings []*models.GroupModelMapping) []string {
	names := make([]string, 0, len(mappings))
	for i := range mappings {
		m := mappings[i]
		if m == nil {
			continue // tolerate nil entries instead of panicking
		}
		names = append(names, m.ModelName)
	}
	return names
}
// transformUpstreamsToStrings returns the URL of every non-nil upstream, in
// input order. The result is never nil, so it encodes to JSON as [] rather
// than null when there are no upstreams.
func transformUpstreamsToStrings(upstreams []*models.UpstreamEndpoint) []string {
	out := make([]string, 0, len(upstreams))
	for i := range upstreams {
		ep := upstreams[i]
		if ep == nil {
			continue // tolerate nil entries instead of panicking
		}
		out = append(out, ep.URL)
	}
	return out
}
// newKeyGroupResponse converts a key group model into its API representation.
// keyCount is supplied by the caller and reported as APIKeysCount. The
// group's allowed models and upstreams are flattened to model names and
// upstream URLs.
func (h *KeyGroupHandler) newKeyGroupResponse(group *models.KeyGroup, keyCount int64) KeyGroupResponse {
	allowedModels := transformModelsToStrings(group.AllowedModels)
	allowedUpstreams := transformUpstreamsToStrings(group.AllowedUpstreams)

	var resp KeyGroupResponse
	resp.ID = group.ID
	resp.Name = group.Name
	resp.DisplayName = group.DisplayName
	resp.Description = group.Description
	resp.PollingStrategy = group.PollingStrategy
	resp.ChannelType = group.ChannelType
	resp.EnableProxy = group.EnableProxy
	resp.APIKeysCount = keyCount
	resp.CreatedAt = group.CreatedAt
	resp.UpdatedAt = group.UpdatedAt
	resp.Order = group.Order
	resp.AllowedModels = allowedModels
	resp.AllowedUpstreams = allowedUpstreams
	return resp
}
// packGroupSettings converts request-level operational settings into a newly
// allocated model-level settings struct. Pointers are copied as-is, so nil
// (not supplied) fields stay nil. Any other KeyGroupSettings fields are left
// at their zero values.
func packGroupSettings(settings KeyGroupOperationalSettings) *models.KeyGroupSettings {
	out := new(models.KeyGroupSettings)
	out.EnableKeyCheck = settings.EnableKeyCheck
	out.KeyCheckIntervalMinutes = settings.KeyCheckIntervalMinutes
	out.KeyBlacklistThreshold = settings.KeyBlacklistThreshold
	out.KeyCooldownMinutes = settings.KeyCooldownMinutes
	out.KeyCheckConcurrency = settings.KeyCheckConcurrency
	out.KeyCheckEndpoint = settings.KeyCheckEndpoint
	out.KeyCheckModel = settings.KeyCheckModel
	out.MaxRetries = settings.MaxRetries
	out.EnableSmartGateway = settings.EnableSmartGateway
	return out
}
// getGroupFromContext resolves the ":id" route parameter to a key group
// through the group manager.
//
// It returns an ErrBadRequest API error when the parameter is not an
// unsigned decimal integer that fits in a uint. It returns
// ErrResourceNotFound when the manager has no group with that ID.
func (h *KeyGroupHandler) getGroupFromContext(c *gin.Context) (*models.KeyGroup, *errors.APIError) {
	// ParseUint (rather than Atoi) rejects negative values, which would
	// otherwise wrap around to huge uint IDs when converted. bitSize 0
	// guarantees the result fits in a uint on this platform.
	id, err := strconv.ParseUint(c.Param("id"), 10, 0)
	if err != nil {
		return nil, errors.NewAPIError(errors.ErrBadRequest, "Invalid ID format")
	}
	group, ok := h.groupManager.GetGroupByID(uint(id))
	if !ok {
		return nil, errors.NewAPIError(errors.ErrResourceNotFound, "Group not found")
	}
	return group, nil
}
// applyUpdateRequestToGroup copies every field present (non-nil) in req onto
// group, leaving absent fields untouched. DisplayName and Description are
// stripped of HTML tags before being stored. Name is copied verbatim; the
// UpdateKeyGroup handler validates it before calling this function.
//
// An explicitly empty polling strategy or channel type is replaced with the
// same defaults CreateKeyGroup applies (sequential and "gemini"). This keeps
// an update from storing an empty value that creation would never store.
func applyUpdateRequestToGroup(req *UpdateKeyGroupRequest, group *models.KeyGroup) {
	if req.Name != nil {
		group.Name = *req.Name
	}
	p := bluemonday.StripTagsPolicy()
	if req.DisplayName != nil {
		group.DisplayName = p.Sanitize(*req.DisplayName)
	}
	if req.Description != nil {
		group.Description = p.Sanitize(*req.Description)
	}
	if req.PollingStrategy != nil {
		group.PollingStrategy = models.PollingStrategy(*req.PollingStrategy)
		if group.PollingStrategy == "" {
			group.PollingStrategy = models.StrategySequential
		}
	}
	if req.EnableProxy != nil {
		group.EnableProxy = *req.EnableProxy
	}
	if req.ChannelType != nil {
		group.ChannelType = *req.ChannelType
		// channel_type has no binding validation, so "" can reach here.
		if group.ChannelType == "" {
			group.ChannelType = "gemini"
		}
	}
}
// publishGroupChangeEvent publishes a KeyStatusChangedEvent for groupID,
// tagged with reason, on models.TopicKeyStatusChanged.
//
// Publishing runs on its own goroutine so the HTTP response is not delayed.
// It is fire-and-forget: the caller is not told whether publishing succeeded.
func (h *KeyGroupHandler) publishGroupChangeEvent(groupID uint, reason string) {
	go func() {
		event := models.KeyStatusChangedEvent{GroupID: groupID, ChangeReason: reason}
		eventData, err := json.Marshal(event)
		if err != nil {
			// Do not publish a nil payload that carries no event. This
			// handler has no logger, so the notification is dropped.
			return
		}
		h.store.Publish(models.TopicKeyStatusChanged, eventData)
	}()
}
// --- Handler methods ---
// CreateKeyGroup creates a key group from a CreateKeyGroupRequest JSON body.
//
// The name must pass isValidGroupName. DisplayName and Description are
// stripped of HTML tags. An empty polling strategy defaults to sequential and
// an empty channel type to "gemini". On success it publishes a
// "group_created" change event and responds via response.Created with a key
// count of 0. Persistence errors are translated with errors.ParseDBError.
func (h *KeyGroupHandler) CreateKeyGroup(c *gin.Context) {
	var req CreateKeyGroupRequest
	if bindErr := c.ShouldBindJSON(&req); bindErr != nil {
		response.Error(c, errors.NewAPIError(errors.ErrInvalidJSON, bindErr.Error()))
		return
	}
	if !isValidGroupName(req.Name) {
		response.Error(c, errors.NewAPIError(errors.ErrValidation, "Invalid group name. Must be 3-30 characters, lowercase letters, numbers, hyphens, or underscores."))
		return
	}

	strategy := models.PollingStrategy(req.PollingStrategy)
	if strategy == "" {
		strategy = models.StrategySequential
	}
	channel := req.ChannelType
	if channel == "" {
		channel = "gemini"
	}

	stripper := bluemonday.StripTagsPolicy()
	keyGroup := &models.KeyGroup{
		Name:            req.Name,
		DisplayName:     stripper.Sanitize(req.DisplayName),
		Description:     stripper.Sanitize(req.Description),
		PollingStrategy: strategy,
		EnableProxy:     req.EnableProxy,
		ChannelType:     channel,
	}

	if createErr := h.groupManager.CreateKeyGroup(keyGroup, packGroupSettings(req.KeyGroupOperationalSettings)); createErr != nil {
		response.Error(c, errors.ParseDBError(createErr))
		return
	}
	h.publishGroupChangeEvent(keyGroup.ID, "group_created")
	response.Created(c, h.newKeyGroupResponse(keyGroup, 0))
}
// GetKeyGroups is a single handler that serves two routes:
//  1. GET /keygroups     - responds with a list of every group.
//  2. GET /keygroups/:id - responds with one group, including its settings
//     and request config, as a KeyGroupDetailsResponse.
//
// A non-empty ":id" route parameter selects the single-group form.
func (h *KeyGroupHandler) GetKeyGroups(c *gin.Context) {
	if c.Param("id") == "" {
		groups := h.groupManager.GetAllGroups()
		list := make([]KeyGroupResponse, 0, len(groups))
		for _, g := range groups {
			list = append(list, h.newKeyGroupResponse(g, h.groupManager.GetKeyCount(g.ID)))
		}
		response.Success(c, list)
		return
	}

	group, apiErr := h.getGroupFromContext(c)
	if apiErr != nil {
		response.Error(c, apiErr)
		return
	}
	response.Success(c, KeyGroupDetailsResponse{
		KeyGroupResponse: h.newKeyGroupResponse(group, h.groupManager.GetKeyCount(group.ID)),
		Settings:         group.Settings,
		RequestConfig:    group.RequestConfig,
	})
}
// UpdateKeyGroup applies an UpdateKeyGroupRequest to the group identified by
// the ":id" route parameter.
//
// Only supplied fields are changed. A supplied name must pass
// isValidGroupName. Operational settings and the allowed upstream/model
// lists are passed to GroupManager.UpdateKeyGroup. On success it publishes a
// "group_updated" change event and responds with the group as re-read from
// the manager.
func (h *KeyGroupHandler) UpdateKeyGroup(c *gin.Context) {
	group, apiErr := h.getGroupFromContext(c)
	if apiErr != nil {
		response.Error(c, apiErr)
		return
	}
	var req UpdateKeyGroupRequest
	if err := c.ShouldBindJSON(&req); err != nil {
		response.Error(c, errors.NewAPIError(errors.ErrInvalidJSON, err.Error()))
		return
	}
	if req.Name != nil && !isValidGroupName(*req.Name) {
		response.Error(c, errors.NewAPIError(errors.ErrValidation, "Invalid group name format."))
		return
	}
	// NOTE(review): group is modified in place before persistence. If the
	// manager hands out shared cached pointers, a failed update below leaves
	// them modified — confirm GetGroupByID's ownership semantics.
	applyUpdateRequestToGroup(&req, group)
	groupSettings := packGroupSettings(req.KeyGroupOperationalSettings)
	if err := h.groupManager.UpdateKeyGroup(group, groupSettings, req.AllowedUpstreams, req.AllowedModels); err != nil {
		response.Error(c, errors.ParseDBError(err))
		return
	}
	h.publishGroupChangeEvent(group.ID, "group_updated")

	// Re-read the group for the response. If the lookup misses (for example,
	// the group was deleted concurrently), fall back to the locally updated
	// copy instead of dereferencing a nil pointer.
	freshGroup, ok := h.groupManager.GetGroupByID(group.ID)
	if !ok || freshGroup == nil {
		freshGroup = group
	}
	keyCount := h.groupManager.GetKeyCount(freshGroup.ID)
	response.Success(c, h.newKeyGroupResponse(freshGroup, keyCount))
}
// DeleteKeyGroup deletes the group identified by the ":id" route parameter
// through GroupManager.DeleteKeyGroup. On success it publishes a
// "group_deleted" change event and responds with a confirmation message
// naming the group. Persistence errors are translated with
// errors.ParseDBError.
func (h *KeyGroupHandler) DeleteKeyGroup(c *gin.Context) {
	group, apiErr := h.getGroupFromContext(c)
	if apiErr != nil {
		response.Error(c, apiErr)
		return
	}
	// Capture the name up front so the confirmation message does not depend
	// on the group's state after deletion.
	deletedName := group.Name
	if delErr := h.groupManager.DeleteKeyGroup(group.ID); delErr != nil {
		response.Error(c, errors.ParseDBError(delErr))
		return
	}
	h.publishGroupChangeEvent(group.ID, "group_deleted")
	msg := fmt.Sprintf("Group '%s' and its associated keys deleted successfully", deletedName)
	response.Success(c, gin.H{"message": msg})
}
// GetKeyGroupStats responds with the dashboard statistics for the group
// identified by the ":id" route parameter. Query failures are reported as
// ErrDatabase API errors carrying the underlying error text.
func (h *KeyGroupHandler) GetKeyGroupStats(c *gin.Context) {
	group, apiErr := h.getGroupFromContext(c)
	if apiErr != nil {
		response.Error(c, apiErr)
		return
	}
	stats, queryErr := h.queryService.GetGroupStats(group.ID)
	if queryErr == nil {
		response.Success(c, stats)
		return
	}
	response.Error(c, errors.NewAPIError(errors.ErrDatabase, queryErr.Error()))
}
// CloneKeyGroup duplicates the group identified by the ":id" route parameter
// through GroupManager.CloneKeyGroup and responds via response.Created with
// the new group. The reported key count is the number of mappings on the
// returned clone. Persistence errors are translated with errors.ParseDBError.
//
// Cloning creates a new group, so, like CreateKeyGroup, it publishes a
// "group_created" change event for subscribers.
func (h *KeyGroupHandler) CloneKeyGroup(c *gin.Context) {
	group, apiErr := h.getGroupFromContext(c)
	if apiErr != nil {
		response.Error(c, apiErr)
		return
	}
	clonedGroup, err := h.groupManager.CloneKeyGroup(group.ID)
	if err != nil {
		response.Error(c, errors.ParseDBError(err))
		return
	}
	// NOTE(review): if GroupManager.CloneKeyGroup already publishes its own
	// event, this results in a duplicate notification — confirm there.
	h.publishGroupChangeEvent(clonedGroup.ID, "group_created")
	keyCount := int64(len(clonedGroup.Mappings))
	response.Created(c, h.newKeyGroupResponse(clonedGroup, keyCount))
}
// UpdateKeyGroupOrder updates the ordering of key groups from a JSON array of
// service.UpdateOrderPayload entries. An empty array is a no-op: it is
// acknowledged without calling the group manager. Persistence errors are
// translated with errors.ParseDBError.
func (h *KeyGroupHandler) UpdateKeyGroupOrder(c *gin.Context) {
	var payload []service.UpdateOrderPayload
	if bindErr := c.ShouldBindJSON(&payload); bindErr != nil {
		response.Error(c, errors.NewAPIError(errors.ErrInvalidJSON, bindErr.Error()))
		return
	}
	message := "No order data to update."
	if len(payload) > 0 {
		if orderErr := h.groupManager.UpdateOrder(payload); orderErr != nil {
			response.Error(c, errors.ParseDBError(orderErr))
			return
		}
		message = "Group order updated successfully."
	}
	response.Success(c, gin.H{"message": message})
}