diff --git a/drivers/misc/mods/mods_arm_ffa.c b/drivers/misc/mods/mods_arm_ffa.c index 566ceae2..34ddea78 100644 --- a/drivers/misc/mods/mods_arm_ffa.c +++ b/drivers/misc/mods/mods_arm_ffa.c @@ -1,5 +1,5 @@ // SPDX-License-Identifier: GPL-2.0-only -/* SPDX-FileCopyrightText: Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved. */ +/* SPDX-FileCopyrightText: Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved. */ #include "mods_internal.h" #include @@ -19,7 +19,7 @@ static const struct ffa_device_id mods_ffa_device_id[] = { struct mods_ffa_ctx { struct ffa_device *ffa_dev; -#if KERNEL_VERSION(6, 1, 0) <= MODS_KERNEL_VERSION +#if KERNEL_VERSION(6, 1, 0) <= MODS_KERNEL_VERSION || defined(FFA_PARTITION_AARCH64_EXEC) const struct ffa_msg_ops *ffa_ops; #else const struct ffa_dev_ops *ffa_ops; @@ -34,7 +34,7 @@ static int ffa_probe(struct ffa_device *ffa_dev) { int ret = 0; -#if KERNEL_VERSION(6, 1, 0) <= MODS_KERNEL_VERSION +#if KERNEL_VERSION(6, 1, 0) <= MODS_KERNEL_VERSION || defined(FFA_PARTITION_AARCH64_EXEC) const struct ffa_msg_ops *ffa_ops = NULL; if (ffa_dev->ops) diff --git a/drivers/misc/mods/mods_config.h b/drivers/misc/mods/mods_config.h index af5b7d91..ed0743ce 100644 --- a/drivers/misc/mods/mods_config.h +++ b/drivers/misc/mods/mods_config.h @@ -1,5 +1,5 @@ /* SPDX-License-Identifier: GPL-2.0-only */ -/* SPDX-FileCopyrightText: Copyright (c) 2008-2023, NVIDIA CORPORATION. All rights reserved. */ +/* SPDX-FileCopyrightText: Copyright (c) 2008-2024, NVIDIA CORPORATION. All rights reserved. 
*/ #ifndef _MODS_CONFIG_H_ #define _MODS_CONFIG_H_ @@ -45,6 +45,10 @@ # define MODS_HAS_SRIOV 1 #endif +#if KERNEL_VERSION(3, 13, 0) <= MODS_KERNEL_VERSION +# define MODS_HAS_REINIT_COMPLETION 1 +#endif + #if KERNEL_VERSION(3, 14, 0) <= MODS_KERNEL_VERSION # define MODS_HAS_MSIX_RANGE 1 #else @@ -78,6 +82,11 @@ # define MODS_HAS_KERNEL_WRITE #endif +#if KERNEL_VERSION(4, 14, 0) <= MODS_KERNEL_VERSION && \ + defined(CONFIG_X86) +# define MODS_HAS_PGPROT_DECRYPTED +#endif + #if KERNEL_VERSION(4, 16, 0) <= MODS_KERNEL_VERSION # define MODS_HAS_POLL_T 1 #endif @@ -86,6 +95,10 @@ # define MODS_HAS_PXM_TO_NODE 1 #endif +#if KERNEL_VERSION(5, 10, 0) <= MODS_KERNEL_VERSION +# define MODS_HAS_DMA_ALLOC_PAGES 1 +#endif + #if KERNEL_VERSION(5, 17, 0) <= MODS_KERNEL_VERSION # define MODS_HAS_ACPI_FETCH 1 #endif @@ -94,9 +107,13 @@ # define MODS_ENABLE_BPMP_MRQ_API 1 #endif -#if KERNEL_VERSION(5, 15, 0) > MODS_KERNEL_VERSION +#if (KERNEL_VERSION(5, 14, 0) > MODS_KERNEL_VERSION) # define MODS_HAS_FB_SET_SUSPEND 1 -#elif (KERNEL_VERSION(6, 1, 0) > MODS_KERNEL_VERSION) && !defined(CONFIG_CHROME_PLATFORMS) +#elif (KERNEL_VERSION(5, 15, 0) > MODS_KERNEL_VERSION) && !defined(CONFIG_RHEL_DIFFERENCES) +# define MODS_HAS_FB_SET_SUSPEND 1 +#elif (KERNEL_VERSION(6, 1, 0) > MODS_KERNEL_VERSION) && \ + !defined(CONFIG_CHROME_PLATFORMS) && \ + !defined(CONFIG_RHEL_DIFFERENCES) # define MODS_HAS_FB_SET_SUSPEND 1 #endif diff --git a/drivers/misc/mods/mods_internal.h b/drivers/misc/mods/mods_internal.h index 6539afb9..2ffa2fc6 100644 --- a/drivers/misc/mods/mods_internal.h +++ b/drivers/misc/mods/mods_internal.h @@ -4,6 +4,7 @@ #ifndef _MODS_INTERNAL_H_ #define _MODS_INTERNAL_H_ +#include #include #include #include @@ -44,9 +45,10 @@ #define MSI_DEV_NOT_FOUND 0 struct en_dev_entry { + struct list_head list; struct pci_dev *dev; - struct en_dev_entry *next; struct msix_entry *msix_entries; + struct completion client_completion; u32 irq_flags; u32 nvecs; #ifdef MODS_HAS_SRIOV @@ -86,10 +88,10 @@ 
struct mods_client { struct list_head ppc_tce_bypass_list; struct list_head nvlink_sysmem_trained_list; #endif + struct list_head enabled_devices; wait_queue_head_t interrupt_event; struct irq_q_info irq_queue; spinlock_t irq_lock; - struct en_dev_entry *enabled_devices; struct workqueue_struct *work_queue; struct mem_type mem_type; #if defined(CONFIG_PCI) @@ -107,12 +109,6 @@ struct mods_client { u8 client_id; }; -/* VM private data */ -struct mods_vm_private_data { - struct mods_client *client; - atomic_t usage_count; -}; - /* Free WC or UC chunk, which can be reused */ struct MODS_FREE_PHYS_CHUNK { struct list_head list; @@ -131,7 +127,7 @@ struct MODS_DMA_MAP { * was mapped to is not a PCI device. */ struct device *dev; /* device these mappings are for */ - struct scatterlist sg[1]; /* each entry corresponds to phys chunk + struct scatterlist sg[]; /* each entry corresponds to phys chunk * in sg array in MODS_MEM_INFO at the * same index */ @@ -146,23 +142,24 @@ struct MODS_MEM_INFO { */ struct list_head dma_map_list; - u32 num_pages; /* total number of allocated pages */ - u32 num_chunks; /* number of allocated contig chunks */ - int numa_node; /* numa node for the allocation */ - u8 cache_type : 2; /* MODS_ALLOC_* */ - u8 dma32 : 1; /* true/false */ - u8 force_numa : 1; /* true/false */ - u8 reservation_tag; /* zero if not reserved */ + u32 num_pages; /* total number of allocated pages */ + u32 num_chunks; /* number of allocated contig chunks */ + int numa_node; /* numa node for the allocation */ + u8 cache_type : 2; /* MODS_ALLOC_* */ + u8 dma32 : 1; /* true/false */ + u8 force_numa : 1; /* true/false */ + u8 no_free_opt : 1; /* true/false */ + u8 dma_pages : 1; /* true/false */ + u8 decrypted_mmap : 1; /* true/false */ + u8 reservation_tag; /* zero if not reserved */ - struct pci_dev *dev; /* (optional) pci_dev this allocation - * is for. 
- */ - unsigned long *wc_bitmap; /* marks which chunks use WC/UC */ - struct scatterlist *sg; /* current list of chunks */ - struct scatterlist contig_sg; /* contiguous merged chunk */ - struct scatterlist alloc_sg[1]; /* allocated memory chunks, each chunk - * consists of 2^n contiguous pages - */ + struct pci_dev *dev; /* (optional) pci_dev this allocation is for. */ + unsigned long *wc_bitmap; /* marks which chunks use WC/UC */ + struct scatterlist *sg; /* current list of chunks */ + struct scatterlist contig_sg; /* contiguous merged chunk */ + struct scatterlist alloc_sg[]; /* allocated memory chunks, each chunk + * consists of 2^n contiguous pages + */ }; static inline u32 get_num_chunks(const struct MODS_MEM_INFO *p_mem_info) @@ -180,10 +177,12 @@ struct SYS_MAP_MEMORY { /* used for offset lookup, NULL for device memory */ struct MODS_MEM_INFO *p_mem_info; - phys_addr_t phys_addr; - unsigned long virtual_addr; - unsigned long mapping_offs; /* mapped offset from the beginning of the allocation */ - unsigned long mapping_length; /* how many bytes were mapped */ + struct mods_client *client; + atomic_t usage_count; + phys_addr_t phys_addr; + unsigned long virtual_addr; + unsigned long mapping_offs; /* mapped offset from the beginning of the allocation */ + unsigned long mapping_length; /* how many bytes were mapped */ }; struct mods_smmu_dev { @@ -233,6 +232,8 @@ struct NVL_TRAINED { #define IRQ_VAL_POISON 0xfafbfcfdU +#define INVALID_CLIENT_ID 0 + /* debug print masks */ #define DEBUG_IOCTL 0x2 #define DEBUG_PCI 0x4 @@ -242,13 +243,12 @@ struct NVL_TRAINED { #define DEBUG_FUNC 0x40 #define DEBUG_CLOCK 0x80 #define DEBUG_DETAILED 0x100 -#define DEBUG_TEGRADC 0x200 #define DEBUG_TEGRADMA 0x400 #define DEBUG_ISR_DETAILED (DEBUG_ISR | DEBUG_DETAILED) #define DEBUG_MEM_DETAILED (DEBUG_MEM | DEBUG_DETAILED) #define DEBUG_ALL (DEBUG_IOCTL | DEBUG_PCI | DEBUG_ACPI | \ DEBUG_ISR | DEBUG_MEM | DEBUG_FUNC | DEBUG_CLOCK | DEBUG_DETAILED | \ - DEBUG_TEGRADC | 
DEBUG_TEGRADMA) + DEBUG_TEGRADMA) #define LOG_ENT() mods_debug_printk(DEBUG_FUNC, "> %s\n", __func__) #define LOG_EXT() mods_debug_printk(DEBUG_FUNC, "< %s\n", __func__) @@ -284,6 +284,9 @@ struct NVL_TRAINED { #define cl_warn(fmt, args...)\ pr_notice("mods [%u] warning: " fmt, client->client_id, ##args) +#define is_valid_client_id(client_id)\ + ((client_id) != INVALID_CLIENT_ID) + struct irq_mask_info { void __iomem *dev_irq_mask_reg; /*IRQ mask register, read-only reg*/ void __iomem *dev_irq_state; /* IRQ status register*/ diff --git a/drivers/misc/mods/mods_irq.c b/drivers/misc/mods/mods_irq.c index 4838b94c..6d59bb53 100644 --- a/drivers/misc/mods/mods_irq.c +++ b/drivers/misc/mods/mods_irq.c @@ -47,18 +47,31 @@ int mods_enable_device(struct mods_client *client, struct en_dev_entry **dev_entry) { int err = OK; - struct en_dev_entry *dpriv = client->enabled_devices; + struct en_dev_entry *dpriv = NULL; WARN_ON(!mutex_is_locked(&irq_mtx)); dpriv = pci_get_drvdata(dev); - if (dpriv) { - if (dpriv->client_id == client->client_id) { - if (dev_entry) - *dev_entry = dpriv; - return OK; - } + if (!dpriv) { + cl_error( + "driver data is not set for %04x:%02x:%02x.%x\n", + pci_domain_nr(dev->bus), + dev->bus->number, + PCI_SLOT(dev->devfn), + PCI_FUNC(dev->devfn)); + return -EINVAL; + } + + /* Client already owns the device */ + if (dpriv->client_id == client->client_id) { + if (dev_entry) + *dev_entry = dpriv; + return OK; + } + + /* Another client owns the device */ + if (is_valid_client_id(dpriv->client_id)) { cl_error( "invalid client for dev %04x:%02x:%02x.%x, expected %u\n", pci_domain_nr(dev->bus), @@ -69,11 +82,6 @@ int mods_enable_device(struct mods_client *client, return -EBUSY; } - dpriv = kzalloc(sizeof(*dpriv), GFP_KERNEL | __GFP_NORETRY); - if (unlikely(!dpriv)) - return -ENOMEM; - atomic_inc(&client->num_allocs); - err = pci_enable_device(dev); if (unlikely(err)) { cl_error("failed to enable dev %04x:%02x:%02x.%x\n", @@ -81,8 +89,6 @@ int 
mods_enable_device(struct mods_client *client, dev->bus->number, PCI_SLOT(dev->devfn), PCI_FUNC(dev->devfn)); - kfree(dpriv); - atomic_dec(&client->num_allocs); return err; } @@ -93,10 +99,12 @@ int mods_enable_device(struct mods_client *client, PCI_FUNC(dev->devfn)); dpriv->client_id = client->client_id; - dpriv->dev = pci_dev_get(dev); - dpriv->next = client->enabled_devices; - client->enabled_devices = dpriv; - pci_set_drvdata(dev, dpriv); + list_add(&dpriv->list, &client->enabled_devices); +#ifdef MODS_HAS_REINIT_COMPLETION + reinit_completion(&dpriv->client_completion); +#else + INIT_COMPLETION(dpriv->client_completion); +#endif if (dev_entry) *dev_entry = dpriv; @@ -110,18 +118,13 @@ void mods_disable_device(struct mods_client *client, WARN_ON(!mutex_is_locked(&irq_mtx)); -#ifdef MODS_HAS_SRIOV - if (dpriv && dpriv->num_vfs) - pci_disable_sriov(dev); -#endif + pci_disable_device(dev); if (dpriv) { - pci_set_drvdata(dev, NULL); - pci_dev_put(dev); + dpriv->client_id = INVALID_CLIENT_ID; + complete(&dpriv->client_completion); } - pci_disable_device(dev); - cl_info("disabled dev %04x:%02x:%02x.%x\n", pci_domain_nr(dev->bus), dev->bus->number, @@ -675,9 +678,27 @@ static int mods_free_irqs(struct mods_client *client, dpriv = pci_get_drvdata(dev); if (!dpriv) { + cl_error( + "driver data is not set for %04x:%02x:%02x.%x\n", + pci_domain_nr(dev->bus), + dev->bus->number, + PCI_SLOT(dev->devfn), + PCI_FUNC(dev->devfn)); mutex_unlock(&irq_mtx); LOG_EXT(); - return OK; + return -EINVAL; + } + + if (!is_valid_client_id(dpriv->client_id)) { + cl_error( + "no client owns dev %04x:%02x:%02x.%x\n", + pci_domain_nr(dev->bus), + dev->bus->number, + PCI_SLOT(dev->devfn), + PCI_FUNC(dev->devfn)); + mutex_unlock(&irq_mtx); + LOG_EXT(); + return -EINVAL; } if (dpriv->client_id != client->client_id) { @@ -751,14 +772,16 @@ static int mods_free_irqs(struct mods_client *client, void mods_free_client_interrupts(struct mods_client *client) { - struct en_dev_entry *dpriv = 
client->enabled_devices; + struct list_head *head = &client->enabled_devices; + struct list_head *iter; + struct en_dev_entry *dpriv; LOG_ENT(); /* Release all interrupts */ - while (dpriv) { + list_for_each(iter, head) { + dpriv = list_entry(iter, struct en_dev_entry, list); mods_free_irqs(client, dpriv->dev); - dpriv = dpriv->next; } LOG_EXT(); @@ -1014,7 +1037,20 @@ static int mods_register_pci_irq(struct mods_client *client, } dpriv = pci_get_drvdata(dev); - if (dpriv) { + if (!dpriv) { + cl_error( + "driver data is not set for %04x:%02x:%02x.%x\n", + pci_domain_nr(dev->bus), + dev->bus->number, + PCI_SLOT(dev->devfn), + PCI_FUNC(dev->devfn)); + mutex_unlock(&irq_mtx); + pci_dev_put(dev); + LOG_EXT(); + return -EINVAL; + } + + if (is_valid_client_id(dpriv->client_id)) { if (dpriv->client_id != client->client_id) { cl_error( "dev %04x:%02x:%02x.%x already owned by client %u\n", diff --git a/drivers/misc/mods/mods_krnl.c b/drivers/misc/mods/mods_krnl.c index 75510eaf..601e8668 100644 --- a/drivers/misc/mods/mods_krnl.c +++ b/drivers/misc/mods/mods_krnl.c @@ -96,6 +96,16 @@ static const struct pci_device_id mods_pci_table[] = { static int mods_pci_probe(struct pci_dev *dev, const struct pci_device_id *id) { + struct en_dev_entry *dpriv; + + dpriv = kzalloc(sizeof(*dpriv), GFP_KERNEL | __GFP_NORETRY); + if (unlikely(!dpriv)) + return -ENOMEM; + + dpriv->dev = pci_dev_get(dev); + init_completion(&dpriv->client_completion); + pci_set_drvdata(dev, dpriv); + mods_debug_printk(DEBUG_PCI, "probed dev %04x:%02x:%02x.%x vendor %04x device %04x\n", pci_domain_nr(dev->bus), @@ -107,6 +117,44 @@ static int mods_pci_probe(struct pci_dev *dev, const struct pci_device_id *id) return 0; } +static void mods_pci_remove(struct pci_dev *dev) +{ + struct en_dev_entry *dpriv = pci_get_drvdata(dev); + + WARN_ON(!dpriv); + + while (true) { + mutex_lock(mods_get_irq_mutex()); + + if (!is_valid_client_id(dpriv->client_id)) + break; + + mods_info_printk("removing dev %04x:%02x:%02x.%x, 
waiting for client %u\n", + pci_domain_nr(dev->bus), + dev->bus->number, + PCI_SLOT(dev->devfn), + PCI_FUNC(dev->devfn), + dpriv->client_id); + + mutex_unlock(mods_get_irq_mutex()); + wait_for_completion(&dpriv->client_completion); + } + + pci_dev_put(dpriv->dev); + pci_set_drvdata(dev, NULL); + kfree(dpriv); + + mutex_unlock(mods_get_irq_mutex()); + + mods_debug_printk(DEBUG_PCI, + "removed dev %04x:%02x:%02x.%x vendor %04x device %04x\n", + pci_domain_nr(dev->bus), + dev->bus->number, + PCI_SLOT(dev->devfn), + PCI_FUNC(dev->devfn), + dev->vendor, dev->device); +} + #if defined(CONFIG_PCI) && defined(MODS_HAS_SRIOV) static int mods_pci_sriov_configure(struct pci_dev *dev, int numvfs); #endif @@ -115,6 +163,7 @@ static struct pci_driver mods_pci_driver = { .name = DEVICE_NAME, .id_table = mods_pci_table, .probe = mods_pci_probe, + .remove = mods_pci_remove, .err_handler = &mods_pci_error_handlers, #ifdef MODS_HAS_SRIOV .sriov_configure = mods_pci_sriov_configure, @@ -229,14 +278,14 @@ static int esc_mods_set_num_vf(struct mods_client *client, } dpriv = pci_get_drvdata(dev); - if (!dpriv) { + if (!dpriv || !is_valid_client_id(dpriv->client_id)) { cl_error( "failed to enable sriov, dev %04x:%02x:%02x.%x was not enabled\n", pci_domain_nr(dev->bus), dev->bus->number, PCI_SLOT(dev->devfn), PCI_FUNC(dev->devfn)); - err = -EBUSY; + err = -EINVAL; goto error; } if (dpriv->client_id != client->client_id) { @@ -287,7 +336,7 @@ static int esc_mods_set_total_vf(struct mods_client *client, } dpriv = pci_get_drvdata(dev); - if (!dpriv) { + if (!dpriv || !is_valid_client_id(dpriv->client_id)) { cl_error( "failed to enable sriov, dev %04x:%02x:%02x.%x was not enabled\n", pci_domain_nr(dev->bus), @@ -408,6 +457,7 @@ static struct mods_client *alloc_client(void) init_waitqueue_head(&client->interrupt_event); INIT_LIST_HEAD(&client->irq_list); INIT_LIST_HEAD(&client->mem_alloc_list); + INIT_LIST_HEAD(&client->enabled_devices); INIT_LIST_HEAD(&client->mem_map_list); 
INIT_LIST_HEAD(&client->free_mem_list); #if defined(CONFIG_PPC64) @@ -634,18 +684,32 @@ MODULE_PARM_DESC(ppc_tce_bypass, static void mods_disable_all_devices(struct mods_client *client) { #ifdef CONFIG_PCI - if (unlikely(mutex_lock_interruptible(mods_get_irq_mutex()))) - return; + struct list_head *head = &client->enabled_devices; + struct en_dev_entry *entry; + struct en_dev_entry *tmp; - while (client->enabled_devices != NULL) { - struct en_dev_entry *old = client->enabled_devices; +#ifdef MODS_HAS_SRIOV + mutex_lock(mods_get_irq_mutex()); + list_for_each_entry_safe(entry, tmp, head, list) { + struct en_dev_entry *dpriv = pci_get_drvdata(entry->dev); - mods_disable_device(client, old->dev); - client->enabled_devices = old->next; - kfree(old); - atomic_dec(&client->num_allocs); + if (dpriv->num_vfs == 0) { + mods_disable_device(client, entry->dev); + list_del(&entry->list); + } } + mutex_unlock(mods_get_irq_mutex()); + list_for_each_entry(entry, head, list) { + pci_disable_sriov(entry->dev); + } +#endif + + mutex_lock(mods_get_irq_mutex()); + list_for_each_entry_safe(entry, tmp, head, list) { + mods_disable_device(client, entry->dev); + list_del(&entry->list); + } mutex_unlock(mods_get_irq_mutex()); if (client->cached_dev) { @@ -653,7 +717,7 @@ static void mods_disable_all_devices(struct mods_client *client) client->cached_dev = NULL; } #else - WARN_ON(client->enabled_devices != NULL); + WARN_ON(!list_empty(&client->enabled_devices)); #endif } @@ -666,24 +730,16 @@ static inline int mods_resume_console(struct mods_client *client) { return 0; } /********************* * MAPPING FUNCTIONS * *********************/ -static int register_mapping(struct mods_client *client, - struct MODS_MEM_INFO *p_mem_info, - phys_addr_t phys_addr, - unsigned long virtual_address, - unsigned long mapping_offs, - unsigned long mapping_length) +static int register_mapping(struct mods_client *client, + struct MODS_MEM_INFO *p_mem_info, + phys_addr_t phys_addr, + struct SYS_MAP_MEMORY 
*p_map_mem, + unsigned long virtual_address, + unsigned long mapping_offs, + unsigned long mapping_length) { - struct SYS_MAP_MEMORY *p_map_mem; - LOG_ENT(); - p_map_mem = kzalloc(sizeof(*p_map_mem), GFP_KERNEL | __GFP_NORETRY); - if (unlikely(!p_map_mem)) { - LOG_EXT(); - return -ENOMEM; - } - atomic_inc(&client->num_allocs); - p_map_mem->phys_addr = phys_addr; p_map_mem->virtual_addr = virtual_address; p_map_mem->mapping_offs = mapping_offs; @@ -704,56 +760,6 @@ static int register_mapping(struct mods_client *client, return OK; } -static void unregister_mapping(struct mods_client *client, - struct SYS_MAP_MEMORY *p_map_mem) -{ - list_del(&p_map_mem->list); - - kfree(p_map_mem); - atomic_dec(&client->num_allocs); -} - -static struct SYS_MAP_MEMORY *find_mapping(struct mods_client *client, - unsigned long virtual_address) -{ - struct SYS_MAP_MEMORY *p_map_mem = NULL; - struct list_head *head = &client->mem_map_list; - struct list_head *iter; - - LOG_ENT(); - - list_for_each(iter, head) { - p_map_mem = list_entry(iter, struct SYS_MAP_MEMORY, list); - - if (p_map_mem->virtual_addr == virtual_address) - break; - - p_map_mem = NULL; - } - - LOG_EXT(); - - return p_map_mem; -} - -static void unregister_all_mappings(struct mods_client *client) -{ - struct SYS_MAP_MEMORY *p_map_mem; - struct list_head *head = &client->mem_map_list; - struct list_head *iter; - struct list_head *tmp; - - LOG_ENT(); - - list_for_each_safe(iter, tmp, head) { - p_map_mem = list_entry(iter, struct SYS_MAP_MEMORY, list); - - unregister_mapping(client, p_map_mem); - } - - LOG_EXT(); -} - static pgprot_t get_prot(struct mods_client *client, u8 mem_type, pgprot_t prot) @@ -880,41 +886,39 @@ static void mods_pci_resume(struct pci_dev *dev) ********************/ static void mods_krnl_vma_open(struct vm_area_struct *vma) { - struct mods_vm_private_data *priv; + struct SYS_MAP_MEMORY *p_map_mem; LOG_ENT(); + mods_debug_printk(DEBUG_MEM_DETAILED, "open vma, virt 0x%lx, size 0x%lx, phys 0x%llx\n", 
vma->vm_start, vma->vm_end - vma->vm_start, (unsigned long long)vma->vm_pgoff << PAGE_SHIFT); - priv = vma->vm_private_data; - if (priv) - atomic_inc(&priv->usage_count); + p_map_mem = vma->vm_private_data; + if (p_map_mem) + atomic_inc(&p_map_mem->usage_count); LOG_EXT(); } static void mods_krnl_vma_close(struct vm_area_struct *vma) { - struct mods_vm_private_data *priv; + struct SYS_MAP_MEMORY *p_map_mem; LOG_ENT(); - priv = vma->vm_private_data; - if (priv && atomic_dec_and_test(&priv->usage_count)) { - struct mods_client *client = priv->client; - struct SYS_MAP_MEMORY *p_map_mem; + p_map_mem = vma->vm_private_data; - mutex_lock(&client->mtx); + if (p_map_mem && atomic_dec_and_test(&p_map_mem->usage_count)) { + struct mods_client *client = p_map_mem->client; - /* we need to unregister the mapping */ - p_map_mem = find_mapping(client, vma->vm_start); - if (p_map_mem) - unregister_mapping(client, p_map_mem); - - mutex_unlock(&client->mtx); + if (p_map_mem->mapping_length) { + mutex_lock(&client->mtx); + list_del(&p_map_mem->list); + mutex_unlock(&client->mtx); + } mods_debug_printk(DEBUG_MEM_DETAILED, "closed vma, virt 0x%lx\n", @@ -922,7 +926,7 @@ static void mods_krnl_vma_close(struct vm_area_struct *vma) vma->vm_private_data = NULL; - kfree(priv); + kfree(p_map_mem); atomic_dec(&client->num_allocs); } @@ -936,20 +940,19 @@ static int mods_krnl_vma_access(struct vm_area_struct *vma, int len, int write) { - struct mods_vm_private_data *priv = vma->vm_private_data; - struct mods_client *client; - struct SYS_MAP_MEMORY *p_map_mem; - unsigned long map_offs; - int err = OK; + struct SYS_MAP_MEMORY *p_map_mem = vma->vm_private_data; + struct mods_client *client; + unsigned long map_offs; + int err = OK; LOG_ENT(); - if (!priv) { + if (!p_map_mem) { LOG_EXT(); return -EINVAL; } - client = priv->client; + client = p_map_mem->client; cl_debug(DEBUG_MEM_DETAILED, "access vma [virt 0x%lx, size 0x%lx, phys 0x%llx] at virt 0x%lx, len 0x%x\n", @@ -964,8 +967,6 @@ static int 
mods_krnl_vma_access(struct vm_area_struct *vma, return -EINTR; } - p_map_mem = find_mapping(client, vma->vm_start); - if (unlikely(!p_map_mem || addr < p_map_mem->virtual_addr || addr + len > p_map_mem->virtual_addr + p_map_mem->mapping_length)) { @@ -1099,7 +1100,10 @@ static int mods_krnl_close(struct inode *ip, struct file *fp) mods_resume_console(client); - unregister_all_mappings(client); + /* All memory mappings should be gone before close */ + if (unlikely(!list_empty(&client->mem_map_list))) + cl_error("not all memory mappings have been freed\n"); + err = mods_unregister_all_alloc(client); if (err) cl_error("failed to free all memory\n"); @@ -1177,13 +1181,13 @@ static POLL_TYPE mods_krnl_poll(struct file *fp, poll_table *wait) return mask; } -static int mods_krnl_map_inner(struct mods_client *client, - struct vm_area_struct *vma); +static int map_internal(struct mods_client *client, + struct vm_area_struct *vma); static int mods_krnl_mmap(struct file *fp, struct vm_area_struct *vma) { - struct mods_vm_private_data *vma_private_data; - struct mods_client *client = fp->private_data; + struct SYS_MAP_MEMORY *p_map_mem; + struct mods_client *client = fp->private_data; int err; LOG_ENT(); @@ -1201,24 +1205,21 @@ static int mods_krnl_mmap(struct file *fp, struct vm_area_struct *vma) vma->vm_ops = &mods_krnl_vm_ops; - vma_private_data = kzalloc(sizeof(*vma_private_data), - GFP_KERNEL | __GFP_NORETRY); - if (unlikely(!vma_private_data)) { + p_map_mem = kzalloc(sizeof(*p_map_mem), GFP_KERNEL | __GFP_NORETRY); + if (unlikely(!p_map_mem)) { LOG_EXT(); return -ENOMEM; } atomic_inc(&client->num_allocs); - /* set private data for vm_area_struct */ - atomic_set(&vma_private_data->usage_count, 0); - vma_private_data->client = client; - vma->vm_private_data = vma_private_data; + p_map_mem->client = client; + vma->vm_private_data = p_map_mem; mods_krnl_vma_open(vma); err = mutex_lock_interruptible(&client->mtx); if (likely(!err)) { - err = mods_krnl_map_inner(client, vma); 
+ err = map_internal(client, vma); mutex_unlock(&client->mtx); } @@ -1244,10 +1245,15 @@ static int map_system_mem(struct mods_client *client, const u32 num_chunks = get_num_chunks(p_mem_info); u32 map_chunks; u32 i = 0; - const pgprot_t prot = get_prot(client, + pgprot_t prot = get_prot(client, p_mem_info->cache_type, vma->vm_page_prot); +#ifdef MODS_HAS_PGPROT_DECRYPTED + if (p_mem_info->decrypted_mmap) + prot = pgprot_decrypted(prot); +#endif + /* Find the beginning of the requested range */ for_each_sg(p_mem_info->sg, sg, num_chunks, i) { const phys_addr_t phys_addr = sg_phys(sg); @@ -1318,6 +1324,7 @@ static int map_system_mem(struct mods_client *client, register_mapping(client, p_mem_info, reg_pa, + vma->vm_private_data, vma->vm_start, skip_size, vma_size); @@ -1357,6 +1364,7 @@ static int map_device_mem(struct mods_client *client, register_mapping(client, NULL, req_pa, + vma->vm_private_data, vma->vm_start, 0, vma_size); @@ -1364,10 +1372,10 @@ return OK; } -static int mods_krnl_map_inner(struct mods_client *client, - struct vm_area_struct *vma) +static int map_internal(struct mods_client *client, + struct vm_area_struct *vma) { - const phys_addr_t req_pa = (phys_addr_t)vma->vm_pgoff << PAGE_SHIFT; + const phys_addr_t req_pa = (phys_addr_t)vma->vm_pgoff << PAGE_SHIFT; struct MODS_MEM_INFO *p_mem_info = mods_find_alloc(client, req_pa); const unsigned long vma_size = vma->vm_end - vma->vm_start; diff --git a/drivers/misc/mods/mods_mem.c b/drivers/misc/mods/mods_mem.c index 5f9d64a3..e8312e72 100644 --- a/drivers/misc/mods/mods_mem.c +++ b/drivers/misc/mods/mods_mem.c @@ -1,5 +1,5 @@ // SPDX-License-Identifier: GPL-2.0-only -/* SPDX-FileCopyrightText: Copyright (c) 2008-2023, NVIDIA CORPORATION. All rights reserved. */ +/* SPDX-FileCopyrightText: Copyright (c) 2008-2024, NVIDIA CORPORATION. All rights reserved. 
*/ #include "mods_internal.h" @@ -248,6 +248,8 @@ static void unmap_sg(struct device *dev, dma_unmap_sg(dev, sg, (int)chunks_to_unmap, DMA_BIDIRECTIONAL); + sg_dma_address(sg) = 0; + sg += chunks_to_unmap; num_chunks -= chunks_to_unmap; @@ -296,14 +298,12 @@ static int dma_unmap_all(struct mods_client *client, struct list_head *tmp; #ifdef CONFIG_PCI - if (sg_dma_address(p_mem_info->sg) && + if (sg_dma_address(p_mem_info->sg) && !p_mem_info->dma_pages && (dev == &p_mem_info->dev->dev || !dev)) { unmap_sg(&p_mem_info->dev->dev, p_mem_info->sg, get_num_chunks(p_mem_info)); - - sg_dma_address(p_mem_info->sg) = 0; } #endif @@ -494,6 +494,9 @@ static void save_non_wb_chunks(struct mods_client *client, if (p_mem_info->cache_type == MODS_ALLOC_CACHED) return; + if (p_mem_info->no_free_opt) + return; + if (unlikely(mutex_lock_interruptible(&client->mtx))) return; @@ -661,8 +664,8 @@ static void release_chunks(struct mods_client *client, { u32 i; - WARN_ON(sg_dma_address(p_mem_info->sg)); WARN_ON(!list_empty(&p_mem_info->dma_map_list)); + WARN_ON(sg_dma_address(p_mem_info->sg) && !p_mem_info->dma_pages); restore_cache(client, p_mem_info); @@ -679,7 +682,19 @@ static void release_chunks(struct mods_client *client, order = get_order(sg->length); WARN_ON((PAGE_SIZE << order) != sg->length); - __free_pages(sg_page(sg), order); + if (p_mem_info->dma_pages) { +#ifdef MODS_HAS_DMA_ALLOC_PAGES + WARN_ON(!sg_dma_address(sg)); + WARN_ON(!p_mem_info->dev); + dma_free_pages(&p_mem_info->dev->dev, + PAGE_SIZE << order, + sg_page(sg), + sg_dma_address(sg), + DMA_BIDIRECTIONAL); +#endif + } else + __free_pages(sg_page(sg), order); + atomic_sub(1u << order, &client->num_pages); sg_set_page(sg, NULL, 0, 0); @@ -711,7 +726,8 @@ static gfp_t get_alloc_flags(struct MODS_MEM_INFO *p_mem_info, u32 order) static struct page *alloc_chunk(struct mods_client *client, struct MODS_MEM_INFO *p_mem_info, u32 order, - int *need_cup) + int *need_cup, + dma_addr_t *dma_handle) { struct page *p_page = NULL; 
u8 cache_type = p_mem_info->cache_type; @@ -759,9 +775,18 @@ static struct page *alloc_chunk(struct mods_client *client, } } - p_page = alloc_pages_node(p_mem_info->numa_node, - get_alloc_flags(p_mem_info, order), - order); + if (p_mem_info->dma_pages) { +#ifdef MODS_HAS_DMA_ALLOC_PAGES + p_page = dma_alloc_pages(&p_mem_info->dev->dev, + PAGE_SIZE << order, + dma_handle, + DMA_BIDIRECTIONAL, + GFP_KERNEL | __GFP_NORETRY | __GFP_NOWARN); +#endif + } else + p_page = alloc_pages_node(p_mem_info->numa_node, + get_alloc_flags(p_mem_info, order), + order); *need_cup = 1; @@ -777,6 +802,7 @@ static int alloc_contig_sys_pages(struct mods_client *client, const unsigned long req_bytes = (unsigned long)p_mem_info->num_pages << PAGE_SHIFT; struct page *p_page; + dma_addr_t dma_handle = 0; u64 phys_addr; u64 end_addr = 0; u32 order = 0; @@ -788,7 +814,7 @@ static int alloc_contig_sys_pages(struct mods_client *client, while ((1U << order) < p_mem_info->num_pages) order++; - p_page = alloc_chunk(client, p_mem_info, order, &is_wb); + p_page = alloc_chunk(client, p_mem_info, order, &is_wb, &dma_handle); if (unlikely(!p_page)) goto failed; @@ -796,6 +822,9 @@ static int alloc_contig_sys_pages(struct mods_client *client, p_mem_info->num_pages = 1U << order; sg_set_page(p_mem_info->alloc_sg, p_page, PAGE_SIZE << order, 0); + sg_dma_address(p_mem_info->alloc_sg) = dma_handle; + if (!sg_dma_len(p_mem_info->alloc_sg)) + sg_dma_len(p_mem_info->alloc_sg) = PAGE_SIZE << order; if (!is_wb) mark_chunk_wc(p_mem_info, 0); @@ -864,12 +893,18 @@ static int alloc_noncontig_sys_pages(struct mods_client *client, } for (;;) { + dma_addr_t dma_handle = 0; + struct page *p_page = alloc_chunk(client, p_mem_info, order, - &is_wb); + &is_wb, + &dma_handle); if (p_page) { sg_set_page(sg, p_page, PAGE_SIZE << order, 0); + sg_dma_address(sg) = dma_handle; + if (!sg_dma_len(sg)) + sg_dma_len(sg) = PAGE_SIZE << order; allocated_pages = 1u << order; break; } @@ -1377,6 +1412,21 @@ int 
esc_mods_alloc_pages_2(struct mods_client *client, goto failed; } + if (unlikely((p->flags & MODS_ALLOC_DMA_PAGES) && + ((p->flags & MODS_ALLOC_DMA32) || + (p->flags & MODS_ALLOC_USE_NUMA) || + !(p->flags & MODS_ALLOC_MAP_DEV)))) { + cl_error("invalid combination of alloc flags 0x%x for dma pages\n", p->flags); + goto failed; + } + +#ifndef MODS_HAS_DMA_ALLOC_PAGES + if (unlikely(p->flags & MODS_ALLOC_DMA_PAGES)) { + cl_error("dma pages are not supported in this kernel\n"); + goto failed; + } +#endif + #ifdef CONFIG_PPC64 if (unlikely((p->flags & MODS_ALLOC_CACHE_MASK) != MODS_ALLOC_CACHED)) { cl_error("unsupported cache attr %u (%s)\n", @@ -1405,6 +1455,14 @@ int esc_mods_alloc_pages_2(struct mods_client *client, p_mem_info->dma32 = (p->flags & MODS_ALLOC_DMA32) ? true : false; p_mem_info->force_numa = (p->flags & MODS_ALLOC_FORCE_NUMA) ? true : false; + p_mem_info->no_free_opt = (p->flags & MODS_ALLOC_NO_FREE_OPTIMIZATION) || + (p->flags & MODS_ALLOC_DMA_PAGES) || + (p->flags & MODS_ALLOC_DECRYPTED_MMAP) + ? true : false; + p_mem_info->dma_pages = (p->flags & MODS_ALLOC_DMA_PAGES) ? true : false; +#ifdef MODS_HAS_PGPROT_DECRYPTED + p_mem_info->decrypted_mmap = (p->flags & MODS_ALLOC_DECRYPTED_MMAP) ? 
true : false; +#endif p_mem_info->reservation_tag = 0; #ifdef MODS_HASNT_NUMA_NO_NODE p_mem_info->numa_node = numa_node_id(); diff --git a/include/uapi/misc/mods.h b/include/uapi/misc/mods.h index 5f3d1d69..09f4a4e2 100644 --- a/include/uapi/misc/mods.h +++ b/include/uapi/misc/mods.h @@ -8,7 +8,7 @@ /* Driver version */ #define MODS_DRIVER_VERSION_MAJOR 4 -#define MODS_DRIVER_VERSION_MINOR 24 +#define MODS_DRIVER_VERSION_MINOR 28 #define MODS_DRIVER_VERSION ((MODS_DRIVER_VERSION_MAJOR << 8) | \ ((MODS_DRIVER_VERSION_MINOR / 10) << 4) | \ (MODS_DRIVER_VERSION_MINOR % 10)) @@ -114,22 +114,25 @@ struct MODS_ALLOC_PAGES_2 { #define MODS_ANY_NUMA_NODE (-1) /* Bit flags for the flags member above */ -#define MODS_ALLOC_CACHED 0 /* Default WB cache attr */ -#define MODS_ALLOC_UNCACHED 1 /* UC cache attr */ -#define MODS_ALLOC_WRITECOMBINE 2 /* WC cache attr */ -#define MODS_ALLOC_CACHE_MASK 7U /* The first three bits are cache attr */ +#define MODS_ALLOC_CACHED 0U /* Default WB cache attr */ +#define MODS_ALLOC_UNCACHED (1U << 0) /* UC cache attr */ +#define MODS_ALLOC_WRITECOMBINE (1U << 1) /* WC cache attr */ +#define MODS_ALLOC_CACHE_MASK 7U /* The first three bits are cache attr */ -#define MODS_ALLOC_DMA32 8 /* Force 32-bit PA, else any PA */ -#define MODS_ALLOC_CONTIGUOUS 16 /* Force contiguous, else non-contig */ -#define MODS_ALLOC_USE_NUMA 32 /* Use numa_node member instead of PCI dev - * for NUMA node hint - */ -#define MODS_ALLOC_FORCE_NUMA 64 /* Force memory to be from a given NUMA - * node (specified by PCI dev or - * numa_node). Otherwise use PCI dev or - * numa_node as a hint only. 
- */ -#define MODS_ALLOC_MAP_DEV 128 /* DMA map to PCI device */ +#define MODS_ALLOC_DMA32 (1U << 3) /* Force 32-bit PA, else any PA */ +#define MODS_ALLOC_CONTIGUOUS (1U << 4) /* Force contiguous, else non-contig */ +#define MODS_ALLOC_USE_NUMA (1U << 5) /* Use numa_node member instead of PCI dev + * for NUMA node hint + */ +#define MODS_ALLOC_FORCE_NUMA (1U << 6) /* Force memory to be from a given NUMA + * node (specified by PCI dev or + * numa_node). Otherwise use PCI dev or + * numa_node as a hint only. + */ +#define MODS_ALLOC_MAP_DEV (1U << 7) /* DMA map to PCI device */ +#define MODS_ALLOC_NO_FREE_OPTIMIZATION (1U << 8) /* Don't cache surfaces on free */ +#define MODS_ALLOC_DMA_PAGES (1U << 9) /* Allocate memory from DMA region */ +#define MODS_ALLOC_DECRYPTED_MMAP (1U << 10) /* Allocate decrypted memory */ /* Used by MODS_ESC_ALLOC_PAGES ioctl */ struct MODS_ALLOC_PAGES {