diff --git a/drivers/virt/tegra/tegra_hv_pm_ctl.c b/drivers/virt/tegra/tegra_hv_pm_ctl.c index f6d30e8e..f82e92d6 100644 --- a/drivers/virt/tegra/tegra_hv_pm_ctl.c +++ b/drivers/virt/tegra/tegra_hv_pm_ctl.c @@ -1,6 +1,6 @@ /* SPDX-License-Identifier: GPL-2.0-only */ /* - * SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-FileCopyrightText: Copyright (c) 2022-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. */ #include @@ -17,6 +17,7 @@ #include #include #include +#include #ifdef CONFIG_PM_SLEEP #include @@ -48,7 +49,7 @@ static struct { uint32_t client_pid; bool suspend_response; } user_client[MAX_USER_CLIENT]; -static uint32_t user_client_count; +static atomic_t user_client_count; #endif struct tegra_hv_pm_ctl { @@ -109,7 +110,7 @@ static int tegra_hv_pm_ctl_trigger_guest_suspend(u32 vmid) static int do_wait_for_guests_inactive(void) { bool sent_guest_suspend = false; - int i = 0; + uint32_t i = 0; int ret = 0; while (i < tegra_hv_pm_ctl_data->wait_for_guests_size) { @@ -609,14 +610,44 @@ static ssize_t wait_for_guests_show(struct device *dev, struct device_attribute *attr, char *buf) { struct tegra_hv_pm_ctl *data = dev_get_drvdata(dev); - ssize_t count = 0; - int i; + ssize_t count = 0, result; + ssize_t page_size = PAGE_SIZE; + int i, ret; for (i = 0; i < data->wait_for_guests_size; i++) { - count += snprintf(buf + count, PAGE_SIZE - count, "%u ", - data->wait_for_guests[i]); + if (check_sub_overflow(page_size, count, &result)) { + pr_err("%s: operation got overflown.\n", __func__); + return -EINVAL; + } + + ret = snprintf(buf + count, result, "%u ", + data->wait_for_guests[i]); + if (ret < 0) { + pr_err("%s: snprintf API got failed.\n", __func__); + return -EINVAL; + } + + if (check_add_overflow((ssize_t)ret, count, &count)) { + pr_err("%s: operation got overflown.\n", __func__); + return -EINVAL; + } + } + + if (check_sub_overflow(page_size, count, &result)) { + pr_err("%s: operation got overflown.\n", __func__); + return -EINVAL; + } + + ret = snprintf(buf + count, result, "\n"); + if (ret < 0) { + pr_err("%s: snprintf API got failed.\n", __func__); + return -EINVAL; + } + + if (check_add_overflow((ssize_t)ret, count, &count)) { + pr_err("%s: operation got overflown.\n", __func__); + return -EINVAL; } - count += snprintf(buf + count, PAGE_SIZE - count, "\n"); return count; } @@ -762,8 +793,9 @@ static void pm_recv_msg(struct sk_buff *skb) struct tegra_hv_pm_ctl *data = tegra_hv_pm_ctl_data; struct nlmsghdr *nlh; static uint32_t i; + void *ptr = skb->data; - nlh = (struct nlmsghdr *)skb->data; + nlh = (struct nlmsghdr *)ptr; /*process messages coming from User Space only*/ if (nlh->nlmsg_pid != 0) { @@ -774,7 +806,7 @@ static void pm_recv_msg(struct sk_buff *skb) /*regiter userspace Client with kernel*/ spin_lock(&netlink_lock); - for (i = 0; i < user_client_count; i++) { + for (i = 0; i < atomic_read(&user_client_count); i++) { if (user_client[i].client_pid == nlh->nlmsg_pid) { dev_warn(data->dev, "Client already registered \ with pid:%d\n", nlh->nlmsg_pid); @@ -793,9 +825,11 @@ static void pm_recv_msg(struct sk_buff *skb) if (hole == true) { user_client[loc].client_pid = nlh->nlmsg_pid; } else { - if (user_client_count < MAX_USER_CLIENT - 1) - user_client[user_client_count++].client_pid = nlh->nlmsg_pid; - else + uint32_t index = atomic_read(&user_client_count); + if (index < MAX_USER_CLIENT - 1) { + user_client[index].client_pid = nlh->nlmsg_pid; + atomic_inc(&user_client_count); + } else dev_err(data->dev, "Client Registration failed for pid:%d \ due to resource exhaustion\n", nlh->nlmsg_pid); } @@ -806,7 +840,7 @@ static void pm_recv_msg(struct sk_buff *skb) bool deregistered = false; spin_lock(&netlink_lock); - for (i = 0; i < user_client_count; i++) { + for (i = 0; i < atomic_read(&user_client_count); i++) { if (user_client[i].client_pid == nlh->nlmsg_pid) { dev_dbg(data->dev, "Deregistering UserSpace Client \ with pid:%d\n", nlh->nlmsg_pid); @@ -827,7 +861,7 @@ static void pm_recv_msg(struct sk_buff *skb) bool active = false; spin_lock(&netlink_lock); - for (i = 0; i < user_client_count; i++) { + for (i = 0; i < atomic_read(&user_client_count); i++) { if (user_client[i].client_pid == nlh->nlmsg_pid && user_client[i].suspend_response == false) { dev_dbg(data->dev, "Received Suspend Response \ @@ -855,7 +889,7 @@ static void pm_recv_msg(struct sk_buff *skb) bool active = false; spin_lock(&netlink_lock); - for (i = 0; i < user_client_count; i++) { + for (i = 0; i < atomic_read(&user_client_count); i++) { if (user_client[i].client_pid == nlh->nlmsg_pid && user_client[i].suspend_response == false) { dev_warn(data->dev, "Already Received Resume Response \ @@ -888,7 +922,7 @@ static void pm_recv_msg(struct sk_buff *skb) /*invoke blocked task if got suspend response from all clients*/ spin_lock(&netlink_lock); if (strcmp((char *)nlmsg_data(nlh), "Suspend Response") == 0) { - for (i = 0; i < user_client_count && user_client_count > 0; i++) { + for (i = 0; i < atomic_read(&user_client_count) && atomic_read(&user_client_count) > 0; i++) { if (user_client[i].client_pid > 0 && user_client[i].suspend_response == false) { spin_unlock(&netlink_lock); @@ -899,7 +933,7 @@ static void pm_recv_msg(struct sk_buff *skb) /*invoke blocked task if got resume response from all clients*/ if (strcmp((char *)nlmsg_data(nlh), "Resume Response") == 0) { - for (i = 0; i < user_client_count && user_client_count > 0; i++) { + for (i = 0; i < atomic_read(&user_client_count) && atomic_read(&user_client_count) > 0; i++) { if (user_client[i].client_pid > 0 && user_client[i].suspend_response == true) { spin_unlock(&netlink_lock); @@ -927,7 +961,7 @@ static int notify_client(const char *msg, size_t msg_size) int ret = 0; spin_lock(&netlink_lock); - for (i = 0; i < user_client_count; i++) { + for (i = 0; i < atomic_read(&user_client_count); i++) { if (user_client[i].client_pid == 0) continue; @@ -957,7 +991,7 @@ static int notify_client(const char *msg, size_t msg_size) nlh = nlmsg_put(skb_out, 0, 0, NLMSG_DONE, msg_size, 0); if (nlh != NULL) - strncpy(nlmsg_data(nlh), msg, msg_size); + strscpy(nlmsg_data(nlh), msg, msg_size); else { dev_err(data->dev, "Failed to allocate netlink msg header\n"); ret = -ENOMEM; @@ -1001,7 +1035,7 @@ static int netlink_pm_notify(struct notifier_block *nb, dev_dbg(data->dev, "all client notified successful\n"); /*Receive the message from userspace*/ - if (user_client_count) + if (atomic_read(&user_client_count)) if (wait_for_completion_timeout(&netlink_complete, USERSPACE_RESPONSE_TIMEOUT) == 0) { dev_err(data->dev, "%s target suspend failed\n", __func__); @@ -1023,7 +1057,7 @@ static int netlink_pm_notify(struct notifier_block *nb, dev_dbg(data->dev, "all client notified successful\n"); /*Receive the message from userspace*/ - if (user_client_count) + if (atomic_read(&user_client_count)) if (wait_for_completion_timeout(&netlink_complete, USERSPACE_RESPONSE_TIMEOUT) == 0) { dev_err(data->dev, "%s target resume failed\n", __func__);