415 lines
14 KiB
Go
415 lines
14 KiB
Go
// Filename: internal/handlers/apikey_handler.go (final version)
|
|
package handlers
|
|
|
|
import (
|
|
"gemini-balancer/internal/errors"
|
|
"gemini-balancer/internal/models"
|
|
"gemini-balancer/internal/response"
|
|
"gemini-balancer/internal/service"
|
|
"gemini-balancer/internal/task"
|
|
"net/http"
|
|
"strconv"
|
|
"strings"
|
|
|
|
"github.com/gin-gonic/gin"
|
|
"gorm.io/gorm"
|
|
)
|
|
|
|
type APIKeyHandler struct {
|
|
apiKeyService *service.APIKeyService
|
|
db *gorm.DB
|
|
keyImportService *service.KeyImportService
|
|
keyValidationService *service.KeyValidationService
|
|
}
|
|
|
|
func NewAPIKeyHandler(apiKeyService *service.APIKeyService, db *gorm.DB, keyImportService *service.KeyImportService, keyValidationService *service.KeyValidationService) *APIKeyHandler {
|
|
return &APIKeyHandler{
|
|
apiKeyService: apiKeyService,
|
|
db: db,
|
|
keyImportService: keyImportService,
|
|
keyValidationService: keyValidationService,
|
|
}
|
|
}
|
|
|
|
// Request bodies for the bulk key endpoints. In each of them, Keys is the raw
// key text submitted by the client; handlers forward it to the service layer
// unparsed.
type (
	// BulkAddKeysToGroupRequest is the JSON body of AddMultipleKeysToGroup.
	// ValidateOnImport is forwarded to the import task.
	BulkAddKeysToGroupRequest struct {
		KeyGroupID       uint   `json:"key_group_id" binding:"required"`
		Keys             string `json:"keys" binding:"required"`
		ValidateOnImport bool   `json:"validate_on_import"`
	}

	// BulkUnlinkKeysFromGroupRequest is the JSON body of UnlinkMultipleKeysFromGroup.
	BulkUnlinkKeysFromGroupRequest struct {
		KeyGroupID uint   `json:"key_group_id" binding:"required"`
		Keys       string `json:"keys" binding:"required"`
	}

	// BulkHardDeleteKeysRequest is the JSON body of HardDeleteMultipleKeys.
	BulkHardDeleteKeysRequest struct {
		Keys string `json:"keys" binding:"required"`
	}

	// BulkRestoreKeysRequest is the JSON body of RestoreMultipleKeys.
	BulkRestoreKeysRequest struct {
		Keys string `json:"keys" binding:"required"`
	}
)
|
|
|
|
// UpdateAPIKeyRequest is the body of the legacy per-key update endpoint.
// Status is optional; when present it must be one of the listed values.
//
// The oneof values are space-separated on purpose: the validator splits a tag
// on ',' into separate rules. A comma-separated list would turn every value
// after the first into an unknown validator, and validation would panic.
type UpdateAPIKeyRequest struct {
	Status *string `json:"status" binding:"omitempty,oneof=ACTIVE PENDING_VALIDATION COOLDOWN DISABLED BANNED"`
}
|
|
|
|
type UpdateMappingRequest struct {
|
|
Status models.APIKeyStatus `json:"status" binding:"required,oneof=ACTIVE PENDING_VALIDATION COOLDOWN DISABLED BANNED"`
|
|
}
|
|
|
|
// Request bodies for the test, restore-by-ID and filter-driven bulk endpoints.
type (
	// BulkTestKeysRequest is the JSON body of TestMultipleKeys.
	BulkTestKeysRequest struct {
		KeyGroupID uint   `json:"key_group_id" binding:"required"`
		Keys       string `json:"keys" binding:"required"`
	}

	// RestoreKeysRequest is the JSON body of RestoreKeysInGroup; KeyIDs must be
	// a non-empty list.
	RestoreKeysRequest struct {
		KeyIDs []uint `json:"key_ids" binding:"required,gt=0"`
	}

	// BulkTestKeysForGroupRequest is the JSON body of TestKeysForGroup; the group
	// comes from the URL path instead of the body.
	BulkTestKeysForGroupRequest struct {
		Keys string `json:"keys" binding:"required"`
	}

	// BulkActionFilter selects the keys a bulk action applies to.
	BulkActionFilter struct {
		Status []string `json:"status"`
	}

	// BulkActionRequest is the JSON body of HandleBulkAction. NewStatus is only
	// meaningful for the "set_status" action.
	// NOTE(review): NewStatus accepts lowercase values while the model statuses
	// validated elsewhere in this file are uppercase — code converting it to
	// models.APIKeyStatus must normalize the case.
	BulkActionRequest struct {
		Action    string           `json:"action" binding:"required,oneof=revalidate set_status delete"`
		NewStatus string           `json:"new_status" binding:"omitempty,oneof=active disabled cooldown banned"`
		Filter    BulkActionFilter `json:"filter" binding:"required"`
	}
)
|
|
|
|
// --- Handler Methods ---
|
|
|
|
// AddMultipleKeysToGroup handles adding/linking multiple keys to a specific group.
|
|
func (h *APIKeyHandler) AddMultipleKeysToGroup(c *gin.Context) {
|
|
var req BulkAddKeysToGroupRequest
|
|
if err := c.ShouldBindJSON(&req); err != nil {
|
|
response.Error(c, errors.NewAPIError(errors.ErrInvalidJSON, err.Error()))
|
|
return
|
|
}
|
|
// [修正] 将请求的 context 传递给 service 层
|
|
taskStatus, err := h.keyImportService.StartAddKeysTask(c.Request.Context(), req.KeyGroupID, req.Keys, req.ValidateOnImport)
|
|
if err != nil {
|
|
response.Error(c, errors.NewAPIError(errors.ErrTaskInProgress, err.Error()))
|
|
return
|
|
}
|
|
response.Success(c, taskStatus)
|
|
}
|
|
|
|
// UnlinkMultipleKeysFromGroup handles unlinking multiple keys from a specific group.
|
|
func (h *APIKeyHandler) UnlinkMultipleKeysFromGroup(c *gin.Context) {
|
|
var req BulkUnlinkKeysFromGroupRequest
|
|
if err := c.ShouldBindJSON(&req); err != nil {
|
|
response.Error(c, errors.NewAPIError(errors.ErrInvalidJSON, err.Error()))
|
|
return
|
|
}
|
|
// [修正] 将请求的 context 传递给 service 层
|
|
taskStatus, err := h.keyImportService.StartUnlinkKeysTask(c.Request.Context(), req.KeyGroupID, req.Keys)
|
|
if err != nil {
|
|
response.Error(c, errors.NewAPIError(errors.ErrTaskInProgress, err.Error()))
|
|
return
|
|
}
|
|
response.Success(c, taskStatus)
|
|
}
|
|
|
|
// HardDeleteMultipleKeys handles globally deleting multiple key entities.
|
|
func (h *APIKeyHandler) HardDeleteMultipleKeys(c *gin.Context) {
|
|
var req BulkHardDeleteKeysRequest
|
|
if err := c.ShouldBindJSON(&req); err != nil {
|
|
response.Error(c, errors.NewAPIError(errors.ErrInvalidJSON, err.Error()))
|
|
return
|
|
}
|
|
// [修正] 将请求的 context 传递给 service 层
|
|
taskStatus, err := h.keyImportService.StartHardDeleteKeysTask(c.Request.Context(), req.Keys)
|
|
if err != nil {
|
|
response.Error(c, errors.NewAPIError(errors.ErrTaskInProgress, err.Error()))
|
|
return
|
|
}
|
|
response.Success(c, taskStatus)
|
|
}
|
|
|
|
// RestoreMultipleKeys handles restoring multiple keys to ACTIVE status globally.
|
|
func (h *APIKeyHandler) RestoreMultipleKeys(c *gin.Context) {
|
|
var req BulkRestoreKeysRequest
|
|
if err := c.ShouldBindJSON(&req); err != nil {
|
|
response.Error(c, errors.NewAPIError(errors.ErrInvalidJSON, err.Error()))
|
|
return
|
|
}
|
|
// [修正] 将请求的 context 传递给 service 层
|
|
taskStatus, err := h.keyImportService.StartRestoreKeysTask(c.Request.Context(), req.Keys)
|
|
if err != nil {
|
|
response.Error(c, errors.NewAPIError(errors.ErrTaskInProgress, err.Error()))
|
|
return
|
|
}
|
|
response.Success(c, taskStatus)
|
|
}
|
|
|
|
func (h *APIKeyHandler) TestMultipleKeys(c *gin.Context) {
|
|
var req BulkTestKeysRequest
|
|
if err := c.ShouldBindJSON(&req); err != nil {
|
|
response.Error(c, errors.NewAPIError(errors.ErrInvalidJSON, err.Error()))
|
|
return
|
|
}
|
|
// [修正] 将请求的 context 传递给 service 层
|
|
taskStatus, err := h.keyValidationService.StartTestKeysTask(c.Request.Context(), req.KeyGroupID, req.Keys)
|
|
if err != nil {
|
|
response.Error(c, errors.NewAPIError(errors.ErrTaskInProgress, err.Error()))
|
|
return
|
|
}
|
|
response.Success(c, taskStatus)
|
|
}
|
|
|
|
// ListAPIKeys lists API keys, either by explicit ID or as a paginated query.
//
// Query parameters are bound into models.APIKeyQueryParams; a binding failure
// yields ErrBadRequest.
//
// When params.IDs is non-empty, it is read as a comma-separated list of key
// IDs. Surrounding whitespace is trimmed, and entries that are not valid
// unsigned integers are skipped. If at least one ID parses, the matching keys
// are returned directly, with no pagination. A lookup failure yields a 500
// with code "DATA_FETCH_ERROR".
//
// Otherwise the request falls through to the regular paginated listing, with
// Page defaulting to 1 and PageSize to 20. Listing errors are translated by
// errors.ParseDBError.
func (h *APIKeyHandler) ListAPIKeys(c *gin.Context) {
	var params models.APIKeyQueryParams
	if err := c.ShouldBindQuery(&params); err != nil {
		response.Error(c, errors.NewAPIError(errors.ErrBadRequest, err.Error()))
		return
	}

	if params.IDs != "" {
		idStrs := strings.Split(params.IDs, ",")
		ids := make([]uint, 0, len(idStrs))
		for _, s := range idStrs {
			// Trim so "1, 2" yields both IDs. A bitSize of strconv.IntSize
			// matches uint exactly, so the conversion below cannot truncate
			// on 32-bit platforms.
			id, err := strconv.ParseUint(strings.TrimSpace(s), 10, strconv.IntSize)
			if err != nil {
				continue
			}
			ids = append(ids, uint(id))
		}
		if len(ids) > 0 {
			keys, err := h.apiKeyService.GetKeysByIds(c.Request.Context(), ids)
			if err != nil {
				response.Error(c, &errors.APIError{
					HTTPStatus: http.StatusInternalServerError,
					Code:       "DATA_FETCH_ERROR",
					Message:    err.Error(),
				})
				return
			}
			response.Success(c, keys)
			return
		}
	}

	if params.Page <= 0 {
		params.Page = 1
	}
	if params.PageSize <= 0 {
		params.PageSize = 20
	}
	result, err := h.apiKeyService.ListAPIKeys(c.Request.Context(), &params)
	if err != nil {
		response.Error(c, errors.ParseDBError(err))
		return
	}
	response.Success(c, result)
}
|
|
|
|
// ListKeysForGroup handles the GET /keygroups/:id/keys request.
//
// The group ID comes from the "id" path parameter, parsed as an unsigned
// 32-bit integer. Query parameters are bound into models.APIKeyQueryParams,
// with Page defaulting to 1 and PageSize to 20. The response is an object
// with "items", "total", "page" and "pages" taken from the paginated result.
// Listing errors are translated by errors.ParseDBError.
func (h *APIKeyHandler) ListKeysForGroup(c *gin.Context) {
	groupID, err := strconv.ParseUint(c.Param("id"), 10, 32)
	if err != nil {
		response.Error(c, errors.NewAPIError(errors.ErrBadRequest, "Invalid group ID format"))
		return
	}
	var params models.APIKeyQueryParams
	if err := c.ShouldBindQuery(&params); err != nil {
		response.Error(c, errors.NewAPIError(errors.ErrBadRequest, err.Error()))
		return
	}
	if params.Page <= 0 {
		params.Page = 1
	}
	if params.PageSize <= 0 {
		params.PageSize = 20
	}
	// The path parameter is authoritative: it overrides any group ID that
	// query binding may have filled in.
	params.KeyGroupID = uint(groupID)

	paginatedResult, err := h.apiKeyService.ListAPIKeys(c.Request.Context(), &params)
	if err != nil {
		response.Error(c, errors.ParseDBError(err))
		return
	}
	response.Success(c, gin.H{
		"items": paginatedResult.Items,
		"total": paginatedResult.Total,
		"page":  paginatedResult.Page,
		"pages": paginatedResult.TotalPages,
	})
}
|
|
|
|
// TestKeysForGroup starts a task that tests the submitted keys against the
// group identified by the "id" path parameter.
//
// It binds a BulkTestKeysForGroupRequest from the JSON body and starts the task
// through KeyValidationService, responding with the task status.
//   - An invalid group ID yields ErrBadRequest.
//   - A malformed body yields ErrInvalidJSON.
//   - A service error that is already an *errors.APIError is passed through
//     unchanged; any other service error is reported as ErrTaskInProgress.
//
// NOTE(review): the request context is handed to the service; confirm a
// long-running task does not inherit its cancellation.
func (h *APIKeyHandler) TestKeysForGroup(c *gin.Context) {
	groupID, err := strconv.ParseUint(c.Param("id"), 10, 32)
	if err != nil {
		response.Error(c, errors.NewAPIError(errors.ErrBadRequest, "Invalid group ID format"))
		return
	}
	var req BulkTestKeysForGroupRequest
	if err := c.ShouldBindJSON(&req); err != nil {
		response.Error(c, errors.NewAPIError(errors.ErrInvalidJSON, err.Error()))
		return
	}
	taskStatus, err := h.keyValidationService.StartTestKeysTask(c.Request.Context(), uint(groupID), req.Keys)
	if err != nil {
		var apiErr *errors.APIError
		if errors.As(err, &apiErr) {
			response.Error(c, apiErr)
			return
		}
		response.Error(c, errors.NewAPIError(errors.ErrTaskInProgress, err.Error()))
		return
	}
	response.Success(c, taskStatus)
}
|
|
|
|
// UpdateAPIKey is DEPRECATED. Status is now contextual to a group.
|
|
func (h *APIKeyHandler) UpdateAPIKey(c *gin.Context) {
|
|
err := errors.NewAPIError(errors.ErrBadRequest, "This endpoint is deprecated. Use 'PUT /keygroups/:id/apikeys/:keyId' to update key status within a group context.")
|
|
response.Error(c, err)
|
|
}
|
|
|
|
// UpdateGroupAPIKeyMapping changes the status of one key within one group.
//
// Inputs:
//   - Path parameters "id" (group ID) and "keyId" (key ID), each parsed as an
//     unsigned 32-bit integer.
//   - A JSON body bound into UpdateMappingRequest.
//
// On success, the mapping returned by APIKeyService.UpdateMappingStatus is
// sent back. A service error that is already an *errors.APIError is passed
// through; anything else is translated by errors.ParseDBError.
func (h *APIKeyHandler) UpdateGroupAPIKeyMapping(c *gin.Context) {
	gid, parseErr := strconv.ParseUint(c.Param("id"), 10, 32)
	if parseErr != nil {
		response.Error(c, errors.NewAPIError(errors.ErrBadRequest, "Invalid Group ID format"))
		return
	}
	kid, parseErr := strconv.ParseUint(c.Param("keyId"), 10, 32)
	if parseErr != nil {
		response.Error(c, errors.NewAPIError(errors.ErrBadRequest, "Invalid Key ID format"))
		return
	}

	var body UpdateMappingRequest
	if bindErr := c.ShouldBindJSON(&body); bindErr != nil {
		response.Error(c, errors.NewAPIError(errors.ErrInvalidJSON, bindErr.Error()))
		return
	}

	mapping, svcErr := h.apiKeyService.UpdateMappingStatus(c.Request.Context(), uint(gid), uint(kid), body.Status)
	if svcErr == nil {
		response.Success(c, mapping)
		return
	}
	var apiErr *errors.APIError
	switch {
	case errors.As(svcErr, &apiErr):
		response.Error(c, apiErr)
	default:
		response.Error(c, errors.ParseDBError(svcErr))
	}
}
|
|
|
|
// HardDeleteAPIKey handles globally deleting a single key entity.
//
// The key ID comes from the "id" path parameter and must be an unsigned 32-bit
// integer, like every other ID in this file. Negative input is rejected rather
// than wrapped into a huge uint. Deletion errors are translated by
// errors.ParseDBError.
func (h *APIKeyHandler) HardDeleteAPIKey(c *gin.Context) {
	id, err := strconv.ParseUint(c.Param("id"), 10, 32)
	if err != nil {
		response.Error(c, errors.NewAPIError(errors.ErrBadRequest, "Invalid ID format"))
		return
	}
	if err := h.apiKeyService.HardDeleteAPIKeyByID(c.Request.Context(), uint(id)); err != nil {
		response.Error(c, errors.ParseDBError(err))
		return
	}
	response.Success(c, gin.H{"message": "API key globally deleted successfully"})
}
|
|
|
|
// RestoreKeysInGroup starts a task that restores the given keys within one group.
//
// Inputs:
//   - The group ID, from the "id" path parameter.
//   - The key IDs, from a RestoreKeysRequest JSON body.
//
// The task status from APIKeyService.StartRestoreKeysTask is returned. A
// service error that is already an *errors.APIError is passed through; any
// other error is reported as ErrInternalServer.
func (h *APIKeyHandler) RestoreKeysInGroup(c *gin.Context) {
	gid, parseErr := strconv.ParseUint(c.Param("id"), 10, 32)
	if parseErr != nil {
		response.Error(c, errors.NewAPIError(errors.ErrBadRequest, "Invalid Group ID format"))
		return
	}

	var body RestoreKeysRequest
	if bindErr := c.ShouldBindJSON(&body); bindErr != nil {
		response.Error(c, errors.NewAPIError(errors.ErrInvalidJSON, bindErr.Error()))
		return
	}

	status, svcErr := h.apiKeyService.StartRestoreKeysTask(c.Request.Context(), uint(gid), body.KeyIDs)
	if svcErr == nil {
		response.Success(c, status)
		return
	}
	var apiErr *errors.APIError
	switch {
	case errors.As(svcErr, &apiErr):
		response.Error(c, apiErr)
	default:
		response.Error(c, errors.NewAPIError(errors.ErrInternalServer, svcErr.Error()))
	}
}
|
|
|
|
// RestoreAllBannedInGroup starts a task that restores all banned keys in a group.
//
// The group ID is read from the "groupId" path parameter. If that parameter is
// absent, it falls back to "id", the name used by every other group-scoped
// handler in this file. gin does not allow differently named wildcards at the
// same path position, so a route under /keygroups/:id would never populate
// "groupId".
//
// A service error that is already an *errors.APIError is passed through; any
// other error is reported as ErrInternalServer.
func (h *APIKeyHandler) RestoreAllBannedInGroup(c *gin.Context) {
	rawGroupID := c.Param("groupId")
	if rawGroupID == "" {
		rawGroupID = c.Param("id")
	}
	groupID, err := strconv.ParseUint(rawGroupID, 10, 32)
	if err != nil {
		response.Error(c, errors.NewAPIError(errors.ErrBadRequest, "Invalid Group ID format"))
		return
	}
	taskStatus, err := h.apiKeyService.StartRestoreAllBannedTask(c.Request.Context(), uint(groupID))
	if err != nil {
		var apiErr *errors.APIError
		if errors.As(err, &apiErr) {
			response.Error(c, apiErr)
			return
		}
		response.Error(c, errors.NewAPIError(errors.ErrInternalServer, err.Error()))
		return
	}
	response.Success(c, taskStatus)
}
|
|
|
|
// HandleBulkAction handles generic bulk actions on a key group based on server-side filters.
//
// The group ID comes from the "id" path parameter; the body is a
// BulkActionRequest. Supported actions, each starting a task whose status is
// returned:
//   - "revalidate": KeyValidationService.StartTestKeysByFilterTask
//   - "set_status": APIKeyService.StartUpdateStatusByFilterTask (requires new_status)
//   - "delete":     KeyImportService.StartUnlinkKeysByFilterTask
//
// new_status is accepted in lowercase by the request binding. It is upper-cased
// before conversion, because the model statuses validated elsewhere in this
// file (ACTIVE, DISABLED, ...) are uppercase.
//
// A service error that is already an *errors.APIError is passed through; any
// other error is reported as ErrInternalServer.
func (h *APIKeyHandler) HandleBulkAction(c *gin.Context) {
	groupID, err := strconv.ParseUint(c.Param("id"), 10, 32)
	if err != nil {
		response.Error(c, errors.NewAPIError(errors.ErrBadRequest, "Invalid Group ID format"))
		return
	}
	var req BulkActionRequest
	if err := c.ShouldBindJSON(&req); err != nil {
		response.Error(c, errors.NewAPIError(errors.ErrInvalidJSON, err.Error()))
		return
	}

	ctx := c.Request.Context()
	gid := uint(groupID)
	// Named taskStatus (not task) so the imported task package stays visible.
	var taskStatus *task.Status
	switch req.Action {
	case "revalidate":
		taskStatus, err = h.keyValidationService.StartTestKeysByFilterTask(ctx, gid, req.Filter.Status)
	case "set_status":
		if req.NewStatus == "" {
			response.Error(c, errors.NewAPIError(errors.ErrBadRequest, "new_status is required for set_status action"))
			return
		}
		targetStatus := models.APIKeyStatus(strings.ToUpper(req.NewStatus))
		taskStatus, err = h.apiKeyService.StartUpdateStatusByFilterTask(ctx, gid, req.Filter.Status, targetStatus)
	case "delete":
		taskStatus, err = h.keyImportService.StartUnlinkKeysByFilterTask(ctx, gid, req.Filter.Status)
	default:
		response.Error(c, errors.NewAPIError(errors.ErrBadRequest, "Unsupported action: "+req.Action))
		return
	}

	if err != nil {
		var apiErr *errors.APIError
		if errors.As(err, &apiErr) {
			response.Error(c, apiErr)
			return
		}
		response.Error(c, errors.NewAPIError(errors.ErrInternalServer, err.Error()))
		return
	}
	response.Success(c, taskStatus)
}
|
|
|
|
// ExportKeysForGroup returns the key strings of a group for export.
//
// The group ID comes from the "id" path parameter. Every "status" query value
// (repeated, e.g. ?status=a&status=b) is forwarded unmodified as a filter to
// APIKeyService.GetAPIKeyStringsForExport. Service errors are translated by
// errors.ParseDBError.
func (h *APIKeyHandler) ExportKeysForGroup(c *gin.Context) {
	gid, convErr := strconv.ParseUint(c.Param("id"), 10, 32)
	if convErr != nil {
		response.Error(c, errors.NewAPIError(errors.ErrBadRequest, "Invalid Group ID format"))
		return
	}

	filters := c.QueryArray("status")
	exported, fetchErr := h.apiKeyService.GetAPIKeyStringsForExport(c.Request.Context(), uint(gid), filters)
	if fetchErr == nil {
		response.Success(c, exported)
		return
	}
	response.Error(c, errors.ParseDBError(fetchErr))
}
|