From a952ece7d205fb6cb95384c03f3d08ec683eb56c Mon Sep 17 00:00:00 2001 From: Nagaraj P N Date: Sun, 23 Mar 2025 20:35:51 +0530 Subject: [PATCH] nvtzvault: support up to 72 sessions Also - add suspend/resume callbacks - close sessions in case of app crash - perform dynamic memory allocation Jira ESSS-1713 Jira ESSS-1830 Bug 5225204 Change-Id: Ic3bd6d74a530bd10e2e5758b9c59b4a552c7d4b1 Signed-off-by: Nagaraj P N Reviewed-on: https://git-master.nvidia.com/r/c/linux-nv-oot/+/3344180 (cherry picked from commit 69d8511e80e2e45e9f8aea4b2b736c632d7894a3) Reviewed-on: https://git-master.nvidia.com/r/c/linux-nv-oot/+/3324497 Reviewed-by: Leo Chiu Reviewed-by: Sandeep Trasi GVS: buildbot_gerritrpt --- drivers/nvtzvault/nvtzvault-helper.c | 12 +- drivers/nvtzvault/nvtzvault-main.c | 158 +++++++++++++++++++-------- 2 files changed, 115 insertions(+), 55 deletions(-) diff --git a/drivers/nvtzvault/nvtzvault-helper.c b/drivers/nvtzvault/nvtzvault-helper.c index 72200a31..0518ec13 100644 --- a/drivers/nvtzvault/nvtzvault-helper.c +++ b/drivers/nvtzvault/nvtzvault-helper.c @@ -187,21 +187,13 @@ int nvtzvault_tee_check_overflow_and_write(struct nvtzvault_tee_buf_context *ctx goto end; } - local_buf = kzalloc(size, GFP_KERNEL); - if (!local_buf) { - NVTZVAULT_ERR("Failed to allocate memory\n"); - result = -ENOMEM; - goto end; - } - if (is_user_space) { - result = copy_from_user(local_buf, (void __user *)data, size); + result = copy_from_user(&ctx->buf_ptr[ctx->current_offset], + (void __user *)data, size); if (result != 0) { NVTZVAULT_ERR("%s(): Failed to copy_from_user %d\n", __func__, result); goto end; } - for (i = 0U; i < size; i++) - ctx->buf_ptr[ctx->current_offset + i] = ((uint8_t *)local_buf)[i]; } else { for (i = 0U; i < size; i++) ctx->buf_ptr[ctx->current_offset + i] = ((uint8_t *)data)[i]; diff --git a/drivers/nvtzvault/nvtzvault-main.c b/drivers/nvtzvault/nvtzvault-main.c index a0b0ca99..4046b6dd 100644 --- a/drivers/nvtzvault/nvtzvault-main.c +++ b/drivers/nvtzvault/nvtzvault-main.c @@ -18,7 +18,7 @@ #define NVTZVAULT_TA_UUID_LEN (16U) #define NVTZVAULT_TA_DEVICE_NAME_LEN (17U) #define NVTZVAULT_BUFFER_SIZE (8192U) -#define NVTZVAULT_MAX_SESSIONS (32U) +#define NVTZVAULT_MAX_SESSIONS (72U) enum nvtzvault_session_op_type { NVTZVAULT_SESSION_OP_OPEN, @@ -51,6 +51,8 @@ struct nvtzvault_dev { struct device *dev; struct mutex lock; void *data_buf; + atomic_t in_suspend_state; + atomic_t total_active_session_count; } g_nvtzvault_dev; struct nvtzvault_ctx { @@ -62,48 +64,6 @@ struct nvtzvault_ctx { uint32_t driver_id; }; -static int nvtzvault_ta_dev_open(struct inode *inode, struct file *filp) -{ - struct miscdevice *misc; - struct nvtzvault_ctx *ctx = NULL; - int32_t ret; - - misc = filp->private_data; - - ctx = kzalloc(sizeof(struct nvtzvault_ctx), GFP_KERNEL); - if (!ctx) { - NVTZVAULT_ERR("%s: Failed to allocate context memory\n", __func__); - return -ENOMEM; - } - - ctx->node_id = misc->this_device->id; - ctx->task_opcode = g_nvtzvault_dev.ta[ctx->node_id].task_opcode; - ctx->driver_id = g_nvtzvault_dev.ta[ctx->node_id].driver_id; - ctx->is_session_open = false; - - ret = nvtzvault_tee_buf_context_init(&ctx->buf_ctx, g_nvtzvault_dev.data_buf, - NVTZVAULT_BUFFER_SIZE); - if (ret != 0) { - NVTZVAULT_ERR("%s: Failed to initialize buffer context\n", __func__); - kfree(ctx); - return ret; - } - memset(ctx->session_bitmap, 0, sizeof(ctx->session_bitmap)); - - filp->private_data = ctx; - - return 0; -} - -static int nvtzvault_ta_dev_release(struct inode *inode, struct file *filp) -{ - struct nvtzvault_ctx *ctx = filp->private_data; - - kfree(ctx); - - return 0; -} - static bool is_session_open(struct nvtzvault_ctx *ctx, uint32_t session_id) { uint32_t byte_idx = session_id / 8U; @@ -128,6 +88,7 @@ static void set_session_open(struct nvtzvault_ctx *ctx, uint32_t session_id) } ctx->session_bitmap[byte_idx] |= (1U << bit_idx); + atomic_inc(&g_nvtzvault_dev.total_active_session_count); } static void set_session_closed(struct nvtzvault_ctx *ctx, uint32_t session_id) @@ -141,6 +102,8 @@ static void set_session_closed(struct nvtzvault_ctx *ctx, uint32_t session_id) } ctx->session_bitmap[byte_idx] &= ~(1U << bit_idx); + + atomic_dec(&g_nvtzvault_dev.total_active_session_count); } static int nvtzvault_write_process_name(struct nvtzvault_tee_buf_context *ctx) @@ -381,6 +344,59 @@ static int nvtzvault_close_session(struct nvtzvault_ctx *ctx, return ret; } +static int nvtzvault_ta_dev_open(struct inode *inode, struct file *filp) +{ + struct miscdevice *misc; + struct nvtzvault_ctx *ctx = NULL; + int32_t ret; + + misc = filp->private_data; + + ctx = kzalloc(sizeof(struct nvtzvault_ctx), GFP_KERNEL); + if (!ctx) { + NVTZVAULT_ERR("%s: Failed to allocate context memory\n", __func__); + return -ENOMEM; + } + + ctx->node_id = misc->this_device->id; + ctx->task_opcode = g_nvtzvault_dev.ta[ctx->node_id].task_opcode; + ctx->driver_id = g_nvtzvault_dev.ta[ctx->node_id].driver_id; + ctx->is_session_open = false; + + ret = nvtzvault_tee_buf_context_init(&ctx->buf_ctx, g_nvtzvault_dev.data_buf, + NVTZVAULT_BUFFER_SIZE); + if (ret != 0) { + NVTZVAULT_ERR("%s: Failed to initialize buffer context\n", __func__); + kfree(ctx); + return ret; + } + memset(ctx->session_bitmap, 0, sizeof(ctx->session_bitmap)); + + filp->private_data = ctx; + + return 0; +} + +static int nvtzvault_ta_dev_release(struct inode *inode, struct file *filp) +{ + struct nvtzvault_ctx *ctx = filp->private_data; + struct nvtzvault_close_session_ctl close_session_ctl; + + if (atomic_read(&g_nvtzvault_dev.total_active_session_count) > 0) { + for (uint32_t i = 0; i < NVTZVAULT_MAX_SESSIONS; i++) { + if (is_session_open(ctx, i)) { + NVTZVAULT_ERR("%s: closing session %u\n", __func__, i); + close_session_ctl.session_id = i; + nvtzvault_close_session(ctx, &close_session_ctl); + } + } + } + + kfree(ctx); + + return 0; +} + static long nvtzvault_ta_dev_ioctl(struct file *filp, unsigned int ioctl_num, unsigned long arg) { struct nvtzvault_ctx *ctx = filp->private_data; @@ -394,6 +410,11 @@ static long nvtzvault_ta_dev_ioctl(struct file *filp, unsigned int ioctl_num, un return -EPERM; } + if (atomic_read(&g_nvtzvault_dev.in_suspend_state)) { + NVTZVAULT_ERR("%s(): device is in suspend state\n", __func__); + return -EBUSY; + } + mutex_lock(&g_nvtzvault_dev.lock); switch (ioctl_num) { @@ -606,6 +627,8 @@ static int nvtzvault_probe(struct platform_device *pdev) int len, i; int ret; + dev_info(dev, "probe start\n"); + if (!nvtzvault_node) return -ENODEV; @@ -695,8 +718,6 @@ static int nvtzvault_probe(struct platform_device *pdev) ta->task_opcode = task_opcode; ta->driver_id = driver_id; - dev_info(dev, "TA ID: %u, UUID: %16ph\n", ta_id, ta_uuid); - ret = nvtzvault_ta_create_dev_node(ta->dev, ta->id); if (ret != 0) goto fail; @@ -711,8 +732,12 @@ static int nvtzvault_probe(struct platform_device *pdev) goto fail; } + atomic_set(&g_nvtzvault_dev.total_active_session_count, 0); + mutex_init(&g_nvtzvault_dev.lock); + dev_info(dev, "probe success\n"); + return 0; fail: @@ -758,6 +783,47 @@ static int nvtzvault_remove_wrapper(struct platform_device *pdev) } #endif +static void nvtzvault_shutdown(struct platform_device *pdev) +{ + atomic_set(&g_nvtzvault_dev.in_suspend_state, 1); +} + +#if defined(CONFIG_PM) +static int nvtzvault_suspend(struct device *dev) +{ + struct platform_device *pdev = to_platform_device(dev); + + /* Add print to log in nvlog buffer */ + dev_err(dev, "%s start\n", __func__); + + if (atomic_read(&g_nvtzvault_dev.total_active_session_count) > 0) + return -EBUSY; + + nvtzvault_shutdown(pdev); + + /* Add print to log in nvlog buffer */ + dev_err(dev, "%s done\n", __func__); + + return 0; +} + +static int nvtzvault_resume(struct device *dev) +{ + /* Add print to log in nvlog buffer */ + dev_err(dev, "%s start\n", __func__); + atomic_set(&g_nvtzvault_dev.in_suspend_state, 0); + /* Add print to log in nvlog buffer */ + dev_err(dev, "%s done\n", __func__); + return 0; +} +static const struct dev_pm_ops nvtzvault_pm_ops = { + .suspend = nvtzvault_suspend, + .resume = nvtzvault_resume, +}; + +#endif /* CONFIG_PM */ + + static const struct of_device_id nvtzvault_match[] = { {.compatible = "nvidia,nvtzvault"}, {} @@ -767,10 +833,12 @@ MODULE_DEVICE_TABLE(of, nvtzvault_match); static struct platform_driver nvtzvault_driver = { .probe = nvtzvault_probe, .remove = nvtzvault_remove_wrapper, + .shutdown = nvtzvault_shutdown, .driver = { .owner = THIS_MODULE, .name = "nvtzvault", .of_match_table = of_match_ptr(nvtzvault_match), + .pm = &nvtzvault_pm_ops, } };