diff --git a/drivers/dax/pmem.c b/drivers/dax/pmem.c
index 033f49b31fdcfeeedede6a6349bbdab5115c82f6..cb0d742fa23f212102aca744e94095ab835b6023 100644
--- a/drivers/dax/pmem.c
+++ b/drivers/dax/pmem.c
@@ -43,6 +43,7 @@ static void dax_pmem_percpu_exit(void *data)
 	struct dax_pmem *dax_pmem = to_dax_pmem(ref);
 
 	dev_dbg(dax_pmem->dev, "%s\n", __func__);
+	wait_for_completion(&dax_pmem->cmp);
 	percpu_ref_exit(ref);
 }
 
@@ -53,7 +54,6 @@ static void dax_pmem_percpu_kill(void *data)
 
 	dev_dbg(dax_pmem->dev, "%s\n", __func__);
 	percpu_ref_kill(ref);
-	wait_for_completion(&dax_pmem->cmp);
 }
 
 static int dax_pmem_probe(struct device *dev)
diff --git a/drivers/nvdimm/pmem.c b/drivers/nvdimm/pmem.c
index 5b536be5a12eb97023745a59f65283280b7b3675..fb7bbc79ac264421eba88de2120dbc2295b3b74d 100644
--- a/drivers/nvdimm/pmem.c
+++ b/drivers/nvdimm/pmem.c
@@ -25,6 +25,7 @@
 #include <linux/badblocks.h>
 #include <linux/memremap.h>
 #include <linux/vmalloc.h>
+#include <linux/blk-mq.h>
 #include <linux/pfn_t.h>
 #include <linux/slab.h>
 #include <linux/pmem.h>
@@ -231,6 +232,11 @@ static void pmem_release_queue(void *q)
 	blk_cleanup_queue(q);
 }
 
+static void pmem_freeze_queue(void *q)
+{
+	blk_mq_freeze_queue_start(q);
+}
+
 static void pmem_release_disk(void *disk)
 {
 	del_gendisk(disk);
@@ -284,6 +290,9 @@ static int pmem_attach_disk(struct device *dev,
 	if (!q)
 		return -ENOMEM;
 
+	if (devm_add_action_or_reset(dev, pmem_release_queue, q))
+		return -ENOMEM;
+
 	pmem->pfn_flags = PFN_DEV;
 	if (is_nd_pfn(dev)) {
 		addr = devm_memremap_pages(dev, &pfn_res, &q->q_usage_counter,
@@ -303,10 +312,10 @@ static int pmem_attach_disk(struct device *dev,
 				pmem->size, ARCH_MEMREMAP_PMEM);
 
 	/*
-	 * At release time the queue must be dead before
+	 * At release time the queue must be frozen before
 	 * devm_memremap_pages is unwound
 	 */
-	if (devm_add_action_or_reset(dev, pmem_release_queue, q))
+	if (devm_add_action_or_reset(dev, pmem_freeze_queue, q))
 		return -ENOMEM;
 
 	if (IS_ERR(addr))
diff --git a/include/linux/mm.h b/include/linux/mm.h
index a835edd2db3486b1b2e28f9aac19841d1c860461..695da2a19b4cbb355810b88df2095b6c3e242801 100644
--- a/include/linux/mm.h
+++ b/include/linux/mm.h
@@ -762,19 +762,11 @@ static inline enum zone_type page_zonenum(const struct page *page)
 }
 
 #ifdef CONFIG_ZONE_DEVICE
-void get_zone_device_page(struct page *page);
-void put_zone_device_page(struct page *page);
 static inline bool is_zone_device_page(const struct page *page)
 {
 	return page_zonenum(page) == ZONE_DEVICE;
 }
 #else
-static inline void get_zone_device_page(struct page *page)
-{
-}
-static inline void put_zone_device_page(struct page *page)
-{
-}
 static inline bool is_zone_device_page(const struct page *page)
 {
 	return false;
@@ -790,9 +782,6 @@ static inline void get_page(struct page *page)
 	 */
 	VM_BUG_ON_PAGE(page_ref_count(page) <= 0, page);
 	page_ref_inc(page);
-
-	if (unlikely(is_zone_device_page(page)))
-		get_zone_device_page(page);
 }
 
 static inline void put_page(struct page *page)
@@ -801,9 +790,6 @@ static inline void put_page(struct page *page)
 
 	if (put_page_testzero(page))
 		__put_page(page);
-
-	if (unlikely(is_zone_device_page(page)))
-		put_zone_device_page(page);
 }
 
 #if defined(CONFIG_SPARSEMEM) && !defined(CONFIG_SPARSEMEM_VMEMMAP)
diff --git a/kernel/memremap.c b/kernel/memremap.c
index 07e85e5229da849d33391f97234c1e1fff2c5ce1..23a6483c3666e35fec7f927d2f6cd4a2e38dbe81 100644
--- a/kernel/memremap.c
+++ b/kernel/memremap.c
@@ -182,18 +182,6 @@ struct page_map {
 	struct vmem_altmap altmap;
 };
 
-void get_zone_device_page(struct page *page)
-{
-	percpu_ref_get(page->pgmap->ref);
-}
-EXPORT_SYMBOL(get_zone_device_page);
-
-void put_zone_device_page(struct page *page)
-{
-	put_dev_pagemap(page->pgmap);
-}
-EXPORT_SYMBOL(put_zone_device_page);
-
 static void pgmap_radix_release(struct resource *res)
 {
 	resource_size_t key, align_start, align_size, align_end;
@@ -237,6 +225,10 @@ static void devm_memremap_pages_release(struct device *dev, void *data)
 	struct resource *res = &page_map->res;
 	resource_size_t align_start, align_size;
 	struct dev_pagemap *pgmap = &page_map->pgmap;
+	unsigned long pfn;
+
+	for_each_device_pfn(pfn, page_map)
+		put_page(pfn_to_page(pfn));
 
 	if (percpu_ref_tryget_live(pgmap->ref)) {
 		dev_WARN(dev, "%s: page mapping is still live!\n", __func__);
@@ -277,7 +269,10 @@ struct dev_pagemap *find_dev_pagemap(resource_size_t phys)
  *
  * Notes:
  * 1/ @ref must be 'live' on entry and 'dead' before devm_memunmap_pages() time
- *    (or devm release event).
+ *    (or devm release event). The expected order of events is that @ref has
+ *    been through percpu_ref_kill() before devm_memremap_pages_release(). The
+ *    wait for the completion of all references being dropped and
+ *    percpu_ref_exit() must occur after devm_memremap_pages_release().
  *
  * 2/ @res is expected to be a host memory range that could feasibly be
  *    treated as a "System RAM" range, i.e. not a device mmio range, but
@@ -379,6 +374,7 @@ void *devm_memremap_pages(struct device *dev, struct resource *res,
 		 */
 		list_del(&page->lru);
 		page->pgmap = pgmap;
+		percpu_ref_get(ref);
 	}
 	devres_add(dev, page_map);
 	return __va(res->start);
diff --git a/mm/swap.c b/mm/swap.c
index c4910f14f9579ef1d8b165355f9294715968bf2d..a4e6113276b5554dafd8d70ba9712f3bd28813b9 100644
--- a/mm/swap.c
+++ b/mm/swap.c
@@ -97,6 +97,16 @@ static void __put_compound_page(struct page *page)
 
 void __put_page(struct page *page)
 {
+	if (is_zone_device_page(page)) {
+		put_dev_pagemap(page->pgmap);
+
+		/*
+		 * The page belongs to the device that created pgmap. Do
+		 * not return it to page allocator.
+		 */
+		return;
+	}
+
 	if (unlikely(PageCompound(page)))
 		__put_compound_page(page);
 	else