// Filename: internal/handlers/apikey_handler.go (最终决战版) package handlers import ( "gemini-balancer/internal/errors" "gemini-balancer/internal/models" "gemini-balancer/internal/response" "gemini-balancer/internal/service" "gemini-balancer/internal/task" "net/http" "strconv" "strings" "github.com/gin-gonic/gin" "gorm.io/gorm" ) type APIKeyHandler struct { apiKeyService *service.APIKeyService db *gorm.DB keyImportService *service.KeyImportService keyValidationService *service.KeyValidationService } func NewAPIKeyHandler(apiKeyService *service.APIKeyService, db *gorm.DB, keyImportService *service.KeyImportService, keyValidationService *service.KeyValidationService) *APIKeyHandler { return &APIKeyHandler{ apiKeyService: apiKeyService, db: db, keyImportService: keyImportService, keyValidationService: keyValidationService, } } type BulkAddKeysToGroupRequest struct { KeyGroupID uint `json:"key_group_id" binding:"required"` Keys string `json:"keys" binding:"required"` ValidateOnImport bool `json:"validate_on_import"` } type BulkUnlinkKeysFromGroupRequest struct { KeyGroupID uint `json:"key_group_id" binding:"required"` Keys string `json:"keys" binding:"required"` } type BulkHardDeleteKeysRequest struct { Keys string `json:"keys" binding:"required"` } type BulkRestoreKeysRequest struct { Keys string `json:"keys" binding:"required"` } type UpdateAPIKeyRequest struct { Status *string `json:"status" binding:"omitempty,oneof=ACTIVE,PENDING_VALIDATION,COOLDOWN,DISABLED,BANNED"` } type UpdateMappingRequest struct { Status models.APIKeyStatus `json:"status" binding:"required,oneof=ACTIVE PENDING_VALIDATION COOLDOWN DISABLED BANNED"` } type BulkTestKeysRequest struct { KeyGroupID uint `json:"key_group_id" binding:"required"` Keys string `json:"keys" binding:"required"` } type RestoreKeysRequest struct { KeyIDs []uint `json:"key_ids" binding:"required,gt=0"` } type BulkTestKeysForGroupRequest struct { Keys string `json:"keys" binding:"required"` } type BulkActionFilter struct { Status []string `json:"status"` } type BulkActionRequest struct { Action string `json:"action" binding:"required,oneof=revalidate set_status delete"` NewStatus string `json:"new_status" binding:"omitempty,oneof=active disabled cooldown banned"` Filter BulkActionFilter `json:"filter" binding:"required"` } // --- Handler Methods --- // AddMultipleKeysToGroup handles adding/linking multiple keys to a specific group. func (h *APIKeyHandler) AddMultipleKeysToGroup(c *gin.Context) { var req BulkAddKeysToGroupRequest if err := c.ShouldBindJSON(&req); err != nil { response.Error(c, errors.NewAPIError(errors.ErrInvalidJSON, err.Error())) return } // [修正] 将请求的 context 传递给 service 层 taskStatus, err := h.keyImportService.StartAddKeysTask(c.Request.Context(), req.KeyGroupID, req.Keys, req.ValidateOnImport) if err != nil { response.Error(c, errors.NewAPIError(errors.ErrTaskInProgress, err.Error())) return } response.Success(c, taskStatus) } // UnlinkMultipleKeysFromGroup handles unlinking multiple keys from a specific group. func (h *APIKeyHandler) UnlinkMultipleKeysFromGroup(c *gin.Context) { var req BulkUnlinkKeysFromGroupRequest if err := c.ShouldBindJSON(&req); err != nil { response.Error(c, errors.NewAPIError(errors.ErrInvalidJSON, err.Error())) return } // [修正] 将请求的 context 传递给 service 层 taskStatus, err := h.keyImportService.StartUnlinkKeysTask(c.Request.Context(), req.KeyGroupID, req.Keys) if err != nil { response.Error(c, errors.NewAPIError(errors.ErrTaskInProgress, err.Error())) return } response.Success(c, taskStatus) } // HardDeleteMultipleKeys handles globally deleting multiple key entities. func (h *APIKeyHandler) HardDeleteMultipleKeys(c *gin.Context) { var req BulkHardDeleteKeysRequest if err := c.ShouldBindJSON(&req); err != nil { response.Error(c, errors.NewAPIError(errors.ErrInvalidJSON, err.Error())) return } // [修正] 将请求的 context 传递给 service 层 taskStatus, err := h.keyImportService.StartHardDeleteKeysTask(c.Request.Context(), req.Keys) if err != nil { response.Error(c, errors.NewAPIError(errors.ErrTaskInProgress, err.Error())) return } response.Success(c, taskStatus) } // RestoreMultipleKeys handles restoring multiple keys to ACTIVE status globally. func (h *APIKeyHandler) RestoreMultipleKeys(c *gin.Context) { var req BulkRestoreKeysRequest if err := c.ShouldBindJSON(&req); err != nil { response.Error(c, errors.NewAPIError(errors.ErrInvalidJSON, err.Error())) return } // [修正] 将请求的 context 传递给 service 层 taskStatus, err := h.keyImportService.StartRestoreKeysTask(c.Request.Context(), req.Keys) if err != nil { response.Error(c, errors.NewAPIError(errors.ErrTaskInProgress, err.Error())) return } response.Success(c, taskStatus) } func (h *APIKeyHandler) TestMultipleKeys(c *gin.Context) { var req BulkTestKeysRequest if err := c.ShouldBindJSON(&req); err != nil { response.Error(c, errors.NewAPIError(errors.ErrInvalidJSON, err.Error())) return } // [修正] 将请求的 context 传递给 service 层 taskStatus, err := h.keyValidationService.StartTestKeysTask(c.Request.Context(), req.KeyGroupID, req.Keys) if err != nil { response.Error(c, errors.NewAPIError(errors.ErrTaskInProgress, err.Error())) return } response.Success(c, taskStatus) } func (h *APIKeyHandler) ListAPIKeys(c *gin.Context) { var params models.APIKeyQueryParams if err := c.ShouldBindQuery(¶ms); err != nil { response.Error(c, errors.NewAPIError(errors.ErrBadRequest, err.Error())) return } if params.IDs != "" { idStrs := strings.Split(params.IDs, ",") ids := make([]uint, 0, len(idStrs)) for _, s := range idStrs { id, err := strconv.ParseUint(s, 10, 64) if err == nil { ids = append(ids, uint(id)) } } if len(ids) > 0 { keys, err := h.apiKeyService.GetKeysByIds(c.Request.Context(), ids) if err != nil { response.Error(c, &errors.APIError{ HTTPStatus: http.StatusInternalServerError, Code: "DATA_FETCH_ERROR", Message: err.Error(), }) return } response.Success(c, keys) return } } if params.Page <= 0 { params.Page = 1 } if params.PageSize <= 0 { params.PageSize = 20 } result, err := h.apiKeyService.ListAPIKeys(c.Request.Context(), ¶ms) if err != nil { response.Error(c, errors.ParseDBError(err)) return } response.Success(c, result) } // ListKeysForGroup handles the GET /keygroups/:id/keys request. func (h *APIKeyHandler) ListKeysForGroup(c *gin.Context) { groupID, err := strconv.ParseUint(c.Param("id"), 10, 32) if err != nil { response.Error(c, errors.NewAPIError(errors.ErrBadRequest, "Invalid group ID format")) return } var params models.APIKeyQueryParams if err := c.ShouldBindQuery(¶ms); err != nil { response.Error(c, errors.NewAPIError(errors.ErrBadRequest, err.Error())) return } if params.Page <= 0 { params.Page = 1 } if params.PageSize <= 0 { params.PageSize = 20 } params.KeyGroupID = uint(groupID) paginatedResult, err := h.apiKeyService.ListAPIKeys(c.Request.Context(), ¶ms) if err != nil { response.Error(c, errors.ParseDBError(err)) return } response.Success(c, gin.H{ "items": paginatedResult.Items, "total": paginatedResult.Total, "page": paginatedResult.Page, "pages": paginatedResult.TotalPages, }) } func (h *APIKeyHandler) TestKeysForGroup(c *gin.Context) { groupID, err := strconv.ParseUint(c.Param("id"), 10, 32) if err != nil { response.Error(c, errors.NewAPIError(errors.ErrBadRequest, "Invalid group ID format")) return } var req BulkTestKeysForGroupRequest if err := c.ShouldBindJSON(&req); err != nil { response.Error(c, errors.NewAPIError(errors.ErrInvalidJSON, err.Error())) return } // [修正] 将请求的 context 传递给 service 层 taskStatus, err := h.keyValidationService.StartTestKeysTask(c.Request.Context(), uint(groupID), req.Keys) if err != nil { response.Error(c, errors.NewAPIError(errors.ErrTaskInProgress, err.Error())) return } response.Success(c, taskStatus) } // UpdateAPIKey is DEPRECATED. Status is now contextual to a group. func (h *APIKeyHandler) UpdateAPIKey(c *gin.Context) { err := errors.NewAPIError(errors.ErrBadRequest, "This endpoint is deprecated. Use 'PUT /keygroups/:id/apikeys/:keyId' to update key status within a group context.") response.Error(c, err) } // UpdateGroupAPIKeyMapping handles updating a key's status within a specific group. func (h *APIKeyHandler) UpdateGroupAPIKeyMapping(c *gin.Context) { groupID, err := strconv.ParseUint(c.Param("id"), 10, 32) if err != nil { response.Error(c, errors.NewAPIError(errors.ErrBadRequest, "Invalid Group ID format")) return } keyID, err := strconv.ParseUint(c.Param("keyId"), 10, 32) if err != nil { response.Error(c, errors.NewAPIError(errors.ErrBadRequest, "Invalid Key ID format")) return } var req UpdateMappingRequest if err := c.ShouldBindJSON(&req); err != nil { response.Error(c, errors.NewAPIError(errors.ErrInvalidJSON, err.Error())) return } updatedMapping, err := h.apiKeyService.UpdateMappingStatus(c.Request.Context(), uint(groupID), uint(keyID), req.Status) if err != nil { var apiErr *errors.APIError if errors.As(err, &apiErr) { response.Error(c, apiErr) } else { response.Error(c, errors.ParseDBError(err)) } return } response.Success(c, updatedMapping) } // HardDeleteAPIKey handles globally deleting a single key entity. func (h *APIKeyHandler) HardDeleteAPIKey(c *gin.Context) { id, err := strconv.Atoi(c.Param("id")) if err != nil { response.Error(c, errors.NewAPIError(errors.ErrBadRequest, "Invalid ID format")) return } if err := h.apiKeyService.HardDeleteAPIKeyByID(c.Request.Context(), uint(id)); err != nil { response.Error(c, errors.ParseDBError(err)) return } response.Success(c, gin.H{"message": "API key globally deleted successfully"}) } // RestoreKeysInGroup 恢复指定Key的接口 func (h *APIKeyHandler) RestoreKeysInGroup(c *gin.Context) { groupID, err := strconv.ParseUint(c.Param("id"), 10, 32) if err != nil { response.Error(c, errors.NewAPIError(errors.ErrBadRequest, "Invalid Group ID format")) return } var req RestoreKeysRequest if err := c.ShouldBindJSON(&req); err != nil { response.Error(c, errors.NewAPIError(errors.ErrInvalidJSON, err.Error())) return } taskStatus, err := h.apiKeyService.StartRestoreKeysTask(c.Request.Context(), uint(groupID), req.KeyIDs) if err != nil { var apiErr *errors.APIError if errors.As(err, &apiErr) { response.Error(c, apiErr) } else { response.Error(c, errors.NewAPIError(errors.ErrInternalServer, err.Error())) } return } response.Success(c, taskStatus) } // RestoreAllBannedInGroup 一键恢复所有Banned Key的接口 func (h *APIKeyHandler) RestoreAllBannedInGroup(c *gin.Context) { groupID, err := strconv.ParseUint(c.Param("groupId"), 10, 32) if err != nil { response.Error(c, errors.NewAPIError(errors.ErrBadRequest, "Invalid Group ID format")) return } taskStatus, err := h.apiKeyService.StartRestoreAllBannedTask(c.Request.Context(), uint(groupID)) if err != nil { var apiErr *errors.APIError if errors.As(err, &apiErr) { response.Error(c, apiErr) } else { response.Error(c, errors.NewAPIError(errors.ErrInternalServer, err.Error())) } return } response.Success(c, taskStatus) } // HandleBulkAction handles generic bulk actions on a key group based on server-side filters. func (h *APIKeyHandler) HandleBulkAction(c *gin.Context) { groupID, err := strconv.ParseUint(c.Param("id"), 10, 32) if err != nil { response.Error(c, errors.NewAPIError(errors.ErrBadRequest, "Invalid Group ID format")) return } var req BulkActionRequest if err := c.ShouldBindJSON(&req); err != nil { response.Error(c, errors.NewAPIError(errors.ErrInvalidJSON, err.Error())) return } var task *task.Status var apiErr *errors.APIError switch req.Action { case "revalidate": // [修正] 将请求的 context 传递给 service 层 task, err = h.keyValidationService.StartTestKeysByFilterTask(c.Request.Context(), uint(groupID), req.Filter.Status) case "set_status": if req.NewStatus == "" { apiErr = errors.NewAPIError(errors.ErrBadRequest, "new_status is required for set_status action") break } targetStatus := models.APIKeyStatus(req.NewStatus) task, err = h.apiKeyService.StartUpdateStatusByFilterTask(c.Request.Context(), uint(groupID), req.Filter.Status, targetStatus) case "delete": // [修正] 将请求的 context 传递给 service 层 task, err = h.keyImportService.StartUnlinkKeysByFilterTask(c.Request.Context(), uint(groupID), req.Filter.Status) default: apiErr = errors.NewAPIError(errors.ErrBadRequest, "Unsupported action: "+req.Action) } if apiErr != nil { response.Error(c, apiErr) return } if err != nil { var parsedErr *errors.APIError if errors.As(err, &parsedErr) { response.Error(c, parsedErr) } else { response.Error(c, errors.NewAPIError(errors.ErrInternalServer, err.Error())) } return } response.Success(c, task) } // ExportKeysForGroup handles requests to export all keys for a group based on status filters. func (h *APIKeyHandler) ExportKeysForGroup(c *gin.Context) { groupID, err := strconv.ParseUint(c.Param("id"), 10, 32) if err != nil { response.Error(c, errors.NewAPIError(errors.ErrBadRequest, "Invalid Group ID format")) return } statuses := c.QueryArray("status") keyStrings, err := h.apiKeyService.GetAPIKeyStringsForExport(c.Request.Context(), uint(groupID), statuses) if err != nil { response.Error(c, errors.ParseDBError(err)) return } response.Success(c, keyStrings) }