370 lines
12 KiB
Go
370 lines
12 KiB
Go
// Filename: internal/handlers/keygroup_handler.go
|
|
package handlers
|
|
|
|
import (
|
|
"encoding/json"
|
|
"fmt"
|
|
"gemini-balancer/internal/errors"
|
|
"gemini-balancer/internal/models"
|
|
"gemini-balancer/internal/response"
|
|
"gemini-balancer/internal/service"
|
|
"gemini-balancer/internal/store"
|
|
"regexp"
|
|
"strconv"
|
|
"time"
|
|
|
|
"github.com/gin-gonic/gin"
|
|
"github.com/microcosm-cc/bluemonday"
|
|
)
|
|
|
|
type KeyGroupHandler struct {
|
|
groupManager *service.GroupManager
|
|
store store.Store
|
|
queryService *service.DashboardQueryService
|
|
}
|
|
|
|
func NewKeyGroupHandler(gm *service.GroupManager, s store.Store, qs *service.DashboardQueryService) *KeyGroupHandler {
|
|
return &KeyGroupHandler{
|
|
groupManager: gm,
|
|
queryService: qs,
|
|
store: s,
|
|
}
|
|
}
|
|
|
|
// DTOs and helper functions
|
|
// groupNameRegexp matches a valid key group name: 3 to 30 characters drawn
// from lowercase ASCII letters, digits, hyphens and underscores. It is
// compiled once at package initialisation rather than on every validation.
var groupNameRegexp = regexp.MustCompile(`^[a-z0-9_-]{3,30}$`)

// isValidGroupName reports whether name is an acceptable key group name.
// The empty string is rejected because the pattern requires at least three
// characters. Go's `$` anchors at the very end of the input, so a trailing
// newline is also rejected.
func isValidGroupName(name string) bool {
	return groupNameRegexp.MatchString(name)
}
|
|
|
|
// KeyGroupOperationalSettings holds the operational knobs shared by the
// create and update request bodies. Every field is a pointer, so a field that
// is absent from the JSON body stays nil and can be told apart from an
// explicit zero/false value. packGroupSettings copies these pointers as-is
// into a models.KeyGroupSettings.
type KeyGroupOperationalSettings struct {
	// EnableKeyCheck toggles key checking for the group.
	EnableKeyCheck *bool `json:"enable_key_check"`

	// KeyCheckIntervalMinutes is the key-check interval, in minutes per the field name.
	KeyCheckIntervalMinutes *int `json:"key_check_interval_minutes"`

	// KeyBlacklistThreshold is the blacklist threshold for keys.
	KeyBlacklistThreshold *int `json:"key_blacklist_threshold"`

	// KeyCooldownMinutes is the key cooldown, in minutes per the field name.
	KeyCooldownMinutes *int `json:"key_cooldown_minutes"`

	// KeyCheckConcurrency is the concurrency used for key checks.
	KeyCheckConcurrency *int `json:"key_check_concurrency"`

	// KeyCheckEndpoint is the endpoint used for key checks.
	KeyCheckEndpoint *string `json:"key_check_endpoint"`

	// KeyCheckModel is the model used for key checks.
	KeyCheckModel *string `json:"key_check_model"`

	// MaxRetries is the retry limit for the group.
	MaxRetries *int `json:"max_retries"`

	// EnableSmartGateway toggles the smart gateway for the group.
	EnableSmartGateway *bool `json:"enable_smart_gateway"`
}
|
|
|
|
type CreateKeyGroupRequest struct {
|
|
Name string `json:"name" binding:"required"`
|
|
DisplayName string `json:"display_name"`
|
|
Description string `json:"description"`
|
|
PollingStrategy string `json:"polling_strategy" binding:"omitempty,oneof=sequential random weighted"`
|
|
EnableProxy bool `json:"enable_proxy"`
|
|
ChannelType string `json:"channel_type"`
|
|
|
|
// Embed shared operational settings
|
|
KeyGroupOperationalSettings
|
|
}
|
|
|
|
type UpdateKeyGroupRequest struct {
|
|
Name *string `json:"name"`
|
|
DisplayName *string `json:"display_name"`
|
|
Description *string `json:"description"`
|
|
PollingStrategy *string `json:"polling_strategy" binding:"omitempty,oneof=sequential random weighted"`
|
|
EnableProxy *bool `json:"enable_proxy"`
|
|
ChannelType *string `json:"channel_type"`
|
|
|
|
// Embed shared operational settings
|
|
KeyGroupOperationalSettings
|
|
|
|
// M:N associations
|
|
AllowedUpstreams []string `json:"allowed_upstreams"`
|
|
AllowedModels []string `json:"allowed_models"`
|
|
}
|
|
|
|
type KeyGroupResponse struct {
|
|
ID uint `json:"id"`
|
|
Name string `json:"name"`
|
|
DisplayName string `json:"display_name"`
|
|
Description string `json:"description"`
|
|
PollingStrategy models.PollingStrategy `json:"polling_strategy"`
|
|
ChannelType string `json:"channel_type"`
|
|
EnableProxy bool `json:"enable_proxy"`
|
|
APIKeysCount int64 `json:"api_keys_count"`
|
|
CreatedAt time.Time `json:"created_at"`
|
|
UpdatedAt time.Time `json:"updated_at"`
|
|
Order int `json:"order"`
|
|
AllowedModels []string `json:"allowed_models"`
|
|
AllowedUpstreams []string `json:"allowed_upstreams"`
|
|
}
|
|
|
|
// [NEW] Define the detailed response structure for a single group.
|
|
type KeyGroupDetailsResponse struct {
|
|
KeyGroupResponse
|
|
Settings *models.GroupSettings `json:"settings,omitempty"`
|
|
RequestConfig *models.RequestConfig `json:"request_config,omitempty"`
|
|
}
|
|
|
|
// transformModelsToStrings converts a slice of GroupModelMapping pointers to a slice of model names.
|
|
func transformModelsToStrings(mappings []*models.GroupModelMapping) []string {
|
|
modelNames := make([]string, 0, len(mappings))
|
|
for _, mapping := range mappings {
|
|
if mapping != nil { // Safety check
|
|
modelNames = append(modelNames, mapping.ModelName)
|
|
}
|
|
}
|
|
return modelNames
|
|
}
|
|
|
|
// transformUpstreamsToStrings converts a slice of UpstreamEndpoint pointers to a slice of URLs.
|
|
func transformUpstreamsToStrings(upstreams []*models.UpstreamEndpoint) []string {
|
|
urls := make([]string, 0, len(upstreams))
|
|
for _, upstream := range upstreams {
|
|
if upstream != nil { // Safety check
|
|
urls = append(urls, upstream.URL)
|
|
}
|
|
}
|
|
return urls
|
|
}
|
|
|
|
func (h *KeyGroupHandler) newKeyGroupResponse(group *models.KeyGroup, keyCount int64) KeyGroupResponse {
|
|
return KeyGroupResponse{
|
|
ID: group.ID,
|
|
Name: group.Name,
|
|
DisplayName: group.DisplayName,
|
|
Description: group.Description,
|
|
PollingStrategy: group.PollingStrategy,
|
|
ChannelType: group.ChannelType,
|
|
EnableProxy: group.EnableProxy,
|
|
APIKeysCount: keyCount,
|
|
CreatedAt: group.CreatedAt,
|
|
UpdatedAt: group.UpdatedAt,
|
|
Order: group.Order,
|
|
AllowedModels: transformModelsToStrings(group.AllowedModels), // Call the new helper
|
|
AllowedUpstreams: transformUpstreamsToStrings(group.AllowedUpstreams), // Call the new helper
|
|
}
|
|
}
|
|
|
|
// packGroupSettings is a helper to convert request-level operational settings
|
|
// into the model-level settings struct.
|
|
func packGroupSettings(settings KeyGroupOperationalSettings) *models.KeyGroupSettings {
|
|
return &models.KeyGroupSettings{
|
|
EnableKeyCheck: settings.EnableKeyCheck,
|
|
KeyCheckIntervalMinutes: settings.KeyCheckIntervalMinutes,
|
|
KeyBlacklistThreshold: settings.KeyBlacklistThreshold,
|
|
KeyCooldownMinutes: settings.KeyCooldownMinutes,
|
|
KeyCheckConcurrency: settings.KeyCheckConcurrency,
|
|
KeyCheckEndpoint: settings.KeyCheckEndpoint,
|
|
KeyCheckModel: settings.KeyCheckModel,
|
|
MaxRetries: settings.MaxRetries,
|
|
EnableSmartGateway: settings.EnableSmartGateway,
|
|
}
|
|
}
|
|
|
|
// getGroupFromContext resolves the ":id" route parameter to a key group.
//
// It returns an ErrBadRequest API error ("Invalid ID format") when the
// parameter is not an unsigned decimal integer that fits in a uint. It
// returns an ErrResourceNotFound API error ("Group not found") when the
// group manager does not know the ID.
func (h *KeyGroupHandler) getGroupFromContext(c *gin.Context) (*models.KeyGroup, *errors.APIError) {
	// ParseUint with bitSize 0 targets uint directly and rejects negative or
	// overflowing input. With Atoi, "-1" would parse successfully and wrap to
	// a huge uint, producing a misleading "not found" error.
	id, err := strconv.ParseUint(c.Param("id"), 10, 0)
	if err != nil {
		return nil, errors.NewAPIError(errors.ErrBadRequest, "Invalid ID format")
	}
	group, ok := h.groupManager.GetGroupByID(uint(id))
	if !ok {
		return nil, errors.NewAPIError(errors.ErrResourceNotFound, "Group not found")
	}
	return group, nil
}
|
|
|
|
func applyUpdateRequestToGroup(req *UpdateKeyGroupRequest, group *models.KeyGroup) {
|
|
if req.Name != nil {
|
|
group.Name = *req.Name
|
|
}
|
|
p := bluemonday.StripTagsPolicy()
|
|
if req.DisplayName != nil {
|
|
group.DisplayName = p.Sanitize(*req.DisplayName)
|
|
}
|
|
if req.Description != nil {
|
|
group.Description = p.Sanitize(*req.Description)
|
|
}
|
|
if req.PollingStrategy != nil {
|
|
group.PollingStrategy = models.PollingStrategy(*req.PollingStrategy)
|
|
}
|
|
if req.EnableProxy != nil {
|
|
group.EnableProxy = *req.EnableProxy
|
|
}
|
|
if req.ChannelType != nil {
|
|
group.ChannelType = *req.ChannelType
|
|
}
|
|
}
|
|
|
|
// publishGroupChangeEvent publishes a models.KeyStatusChangedEvent for
// groupID, tagged with reason, on models.TopicKeyStatusChanged.
//
// Publishing runs on its own goroutine, so the calling handler does not wait
// for it (fire-and-forget). If the event cannot be marshalled it is dropped
// rather than published with a nil payload.
func (h *KeyGroupHandler) publishGroupChangeEvent(groupID uint, reason string) {
	go func() {
		event := models.KeyStatusChangedEvent{GroupID: groupID, ChangeReason: reason}
		eventData, err := json.Marshal(event)
		if err != nil {
			// json.Marshal returns nil data on failure; publishing that would
			// send an empty, undecodable payload. Best-effort: skip the event.
			return
		}
		h.store.Publish(models.TopicKeyStatusChanged, eventData)
	}()
}
|
|
|
|
// --- Handler methods ---
|
|
|
|
func (h *KeyGroupHandler) CreateKeyGroup(c *gin.Context) {
|
|
var req CreateKeyGroupRequest
|
|
if err := c.ShouldBindJSON(&req); err != nil {
|
|
response.Error(c, errors.NewAPIError(errors.ErrInvalidJSON, err.Error()))
|
|
return
|
|
}
|
|
if !isValidGroupName(req.Name) {
|
|
response.Error(c, errors.NewAPIError(errors.ErrValidation, "Invalid group name. Must be 3-30 characters, lowercase letters, numbers, hyphens, or underscores."))
|
|
return
|
|
}
|
|
|
|
// The core logic remains, as it's specific to creation.
|
|
p := bluemonday.StripTagsPolicy()
|
|
sanitizedDisplayName := p.Sanitize(req.DisplayName)
|
|
sanitizedDescription := p.Sanitize(req.Description)
|
|
keyGroup := &models.KeyGroup{
|
|
Name: req.Name,
|
|
DisplayName: sanitizedDisplayName,
|
|
Description: sanitizedDescription,
|
|
PollingStrategy: models.PollingStrategy(req.PollingStrategy),
|
|
EnableProxy: req.EnableProxy,
|
|
ChannelType: req.ChannelType,
|
|
}
|
|
if keyGroup.PollingStrategy == "" {
|
|
keyGroup.PollingStrategy = models.StrategySequential
|
|
}
|
|
if keyGroup.ChannelType == "" {
|
|
keyGroup.ChannelType = "gemini"
|
|
}
|
|
|
|
groupSettings := packGroupSettings(req.KeyGroupOperationalSettings)
|
|
if err := h.groupManager.CreateKeyGroup(keyGroup, groupSettings); err != nil {
|
|
response.Error(c, errors.ParseDBError(err))
|
|
return
|
|
}
|
|
h.publishGroupChangeEvent(keyGroup.ID, "group_created")
|
|
response.Created(c, h.newKeyGroupResponse(keyGroup, 0))
|
|
}
|
|
|
|
// 统一的处理器可以处理两种情况:
|
|
// 1. GET /keygroups - 返回所有组的列表
|
|
// 2. GET /keygroups/:id - 返回指定ID的单个组
|
|
func (h *KeyGroupHandler) GetKeyGroups(c *gin.Context) {
|
|
// Case 1: Get a single group
|
|
if idStr := c.Param("id"); idStr != "" {
|
|
group, apiErr := h.getGroupFromContext(c)
|
|
if apiErr != nil {
|
|
response.Error(c, apiErr)
|
|
return
|
|
}
|
|
keyCount := h.groupManager.GetKeyCount(group.ID)
|
|
baseResponse := h.newKeyGroupResponse(group, keyCount)
|
|
detailedResponse := KeyGroupDetailsResponse{
|
|
KeyGroupResponse: baseResponse,
|
|
Settings: group.Settings,
|
|
RequestConfig: group.RequestConfig,
|
|
}
|
|
response.Success(c, detailedResponse)
|
|
return
|
|
}
|
|
// Case 2: Get all groups
|
|
allGroups := h.groupManager.GetAllGroups()
|
|
responses := make([]KeyGroupResponse, 0, len(allGroups))
|
|
for _, group := range allGroups {
|
|
keyCount := h.groupManager.GetKeyCount(group.ID)
|
|
responses = append(responses, h.newKeyGroupResponse(group, keyCount))
|
|
}
|
|
response.Success(c, responses)
|
|
}
|
|
|
|
// UpdateKeyGroup
|
|
func (h *KeyGroupHandler) UpdateKeyGroup(c *gin.Context) {
|
|
group, apiErr := h.getGroupFromContext(c)
|
|
if apiErr != nil {
|
|
response.Error(c, apiErr)
|
|
return
|
|
}
|
|
var req UpdateKeyGroupRequest
|
|
if err := c.ShouldBindJSON(&req); err != nil {
|
|
response.Error(c, errors.NewAPIError(errors.ErrInvalidJSON, err.Error()))
|
|
return
|
|
}
|
|
if req.Name != nil && !isValidGroupName(*req.Name) {
|
|
response.Error(c, errors.NewAPIError(errors.ErrValidation, "Invalid group name format."))
|
|
return
|
|
}
|
|
applyUpdateRequestToGroup(&req, group)
|
|
groupSettings := packGroupSettings(req.KeyGroupOperationalSettings)
|
|
err := h.groupManager.UpdateKeyGroup(group, groupSettings, req.AllowedUpstreams, req.AllowedModels)
|
|
if err != nil {
|
|
response.Error(c, errors.ParseDBError(err))
|
|
return
|
|
}
|
|
h.publishGroupChangeEvent(group.ID, "group_updated")
|
|
freshGroup, _ := h.groupManager.GetGroupByID(group.ID)
|
|
keyCount := h.groupManager.GetKeyCount(freshGroup.ID)
|
|
response.Success(c, h.newKeyGroupResponse(freshGroup, keyCount))
|
|
}
|
|
|
|
// DeleteKeyGroup deletes the group identified by ":id" through
// GroupManager.DeleteKeyGroup (DB errors go through errors.ParseDBError).
// On success it publishes a "group_deleted" event and responds with a
// confirmation message naming the group.
func (h *KeyGroupHandler) DeleteKeyGroup(c *gin.Context) {
	target, apiErr := h.getGroupFromContext(c)
	if apiErr != nil {
		response.Error(c, apiErr)
		return
	}

	// Capture the name before deletion; it is needed for the response message.
	name := target.Name
	if delErr := h.groupManager.DeleteKeyGroup(target.ID); delErr != nil {
		response.Error(c, errors.ParseDBError(delErr))
		return
	}

	h.publishGroupChangeEvent(target.ID, "group_deleted")
	msg := fmt.Sprintf("Group '%s' and its associated keys deleted successfully", name)
	response.Success(c, gin.H{"message": msg})
}
|
|
|
|
// GetKeyGroupStats
|
|
func (h *KeyGroupHandler) GetKeyGroupStats(c *gin.Context) {
|
|
group, apiErr := h.getGroupFromContext(c)
|
|
if apiErr != nil {
|
|
response.Error(c, apiErr)
|
|
return
|
|
}
|
|
stats, err := h.queryService.GetGroupStats(group.ID)
|
|
if err != nil {
|
|
response.Error(c, errors.NewAPIError(errors.ErrDatabase, err.Error()))
|
|
return
|
|
}
|
|
response.Success(c, stats)
|
|
}
|
|
|
|
func (h *KeyGroupHandler) CloneKeyGroup(c *gin.Context) {
|
|
group, apiErr := h.getGroupFromContext(c)
|
|
if apiErr != nil {
|
|
response.Error(c, apiErr)
|
|
return
|
|
}
|
|
clonedGroup, err := h.groupManager.CloneKeyGroup(group.ID)
|
|
if err != nil {
|
|
response.Error(c, errors.ParseDBError(err))
|
|
return
|
|
}
|
|
keyCount := int64(len(clonedGroup.Mappings))
|
|
response.Created(c, h.newKeyGroupResponse(clonedGroup, keyCount))
|
|
}
|
|
|
|
// 更新分组排序
|
|
func (h *KeyGroupHandler) UpdateKeyGroupOrder(c *gin.Context) {
|
|
var payload []service.UpdateOrderPayload
|
|
if err := c.ShouldBindJSON(&payload); err != nil {
|
|
response.Error(c, errors.NewAPIError(errors.ErrInvalidJSON, err.Error()))
|
|
return
|
|
}
|
|
if len(payload) == 0 {
|
|
response.Success(c, gin.H{"message": "No order data to update."})
|
|
return
|
|
}
|
|
if err := h.groupManager.UpdateOrder(payload); err != nil {
|
|
response.Error(c, errors.ParseDBError(err))
|
|
return
|
|
}
|
|
response.Success(c, gin.H{"message": "Group order updated successfully."})
|
|
}
|