diff --git a/include/linux/gfp.h b/include/linux/gfp.h
index 570383a4185371bda5713257898979057075d5a3..c29e9d347bc697125e6eb02fe2ca0b033cbf83fa 100644
--- a/include/linux/gfp.h
+++ b/include/linux/gfp.h
@@ -78,8 +78,7 @@ struct vm_area_struct;
  * __GFP_THISNODE forces the allocation to be satisfied from the requested
  *   node with no fallbacks or placement policy enforcements.
  *
- * __GFP_ACCOUNT causes the allocation to be accounted to kmemcg (only relevant
- *   to kmem allocations).
+ * __GFP_ACCOUNT causes the allocation to be accounted to kmemcg.
  */
 #define __GFP_RECLAIMABLE ((__force gfp_t)___GFP_RECLAIMABLE)
 #define __GFP_WRITE	((__force gfp_t)___GFP_WRITE)
@@ -486,10 +485,6 @@ extern struct page *alloc_pages_vma(gfp_t gfp_mask, int order,
 #define alloc_page_vma_node(gfp_mask, vma, addr, node)		\
 	alloc_pages_vma(gfp_mask, 0, vma, addr, node, false)
 
-extern struct page *alloc_kmem_pages(gfp_t gfp_mask, unsigned int order);
-extern struct page *alloc_kmem_pages_node(int nid, gfp_t gfp_mask,
-					  unsigned int order);
-
 extern unsigned long __get_free_pages(gfp_t gfp_mask, unsigned int order);
 extern unsigned long get_zeroed_page(gfp_t gfp_mask);
 
@@ -513,9 +508,6 @@ extern void *__alloc_page_frag(struct page_frag_cache *nc,
 			       unsigned int fragsz, gfp_t gfp_mask);
 extern void __free_page_frag(void *addr);
 
-extern void __free_kmem_pages(struct page *page, unsigned int order);
-extern void free_kmem_pages(unsigned long addr, unsigned int order);
-
 #define __free_page(page) __free_pages((page), 0)
 #define free_page(addr) free_pages((addr), 0)
 
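Note: with the dedicated helpers gone, callers request kmemcg accounting
purely through the gfp mask. A minimal caller-side sketch of the
conversion (illustrative only, not part of this patch; "order" stands in
for the caller's order):

	/* before: page = alloc_kmem_pages(GFP_KERNEL | __GFP_ACCOUNT, order);
	 * after:  the generic entry points charge/uncharge automatically
	 */
	struct page *page = alloc_pages(GFP_KERNEL | __GFP_ACCOUNT, order);

	if (page)
		__free_pages(page, order);	/* uncharged via PageKmemcg */
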
diff --git a/include/linux/page-flags.h b/include/linux/page-flags.h
index 96084ee74ee8fc11b6a325f66dfeb5bd5dcac2b4..7c8e82ac2eb7f36f590ad65aa1c3c8de7b6e774f 100644
--- a/include/linux/page-flags.h
+++ b/include/linux/page-flags.h
@@ -641,6 +641,13 @@ PAGE_MAPCOUNT_OPS(Buddy, BUDDY)
 #define PAGE_BALLOON_MAPCOUNT_VALUE		(-256)
 PAGE_MAPCOUNT_OPS(Balloon, BALLOON)
 
+/*
+ * If kmemcg is enabled, the buddy allocator will set PageKmemcg() on
+ * pages allocated with __GFP_ACCOUNT. It gets cleared on page free.
+ */
+#define PAGE_KMEMCG_MAPCOUNT_VALUE		(-512)
+PAGE_MAPCOUNT_OPS(Kmemcg, KMEMCG)
+
 extern bool is_free_buddy_page(struct page *page);
 
 __PAGEFLAG(Isolated, isolated, PF_ANY);
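
Note: PAGE_MAPCOUNT_OPS() generates the accessors from page->_mapcount,
the same trick used for PageBuddy() and PageBalloon() above. For Kmemcg
it expands to roughly the following (sketch; the real macro also carries
VM_BUG_ON_PAGE() sanity checks):

	static __always_inline int PageKmemcg(struct page *page)
	{
		return atomic_read(&page->_mapcount) ==
					PAGE_KMEMCG_MAPCOUNT_VALUE;
	}

	static __always_inline void __SetPageKmemcg(struct page *page)
	{
		/* only valid on a page whose _mapcount is still -1 */
		atomic_set(&page->_mapcount, PAGE_KMEMCG_MAPCOUNT_VALUE);
	}

	static __always_inline void __ClearPageKmemcg(struct page *page)
	{
		atomic_set(&page->_mapcount, -1);
	}
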
diff --git a/kernel/fork.c b/kernel/fork.c
index 4a7ec0c6c88c9ea85b2c74c6a749e11533220d44..de21f25e0d2ce9755b80ecc653269b8b5f8e5415 100644
--- a/kernel/fork.c
+++ b/kernel/fork.c
@@ -162,8 +162,8 @@ void __weak arch_release_thread_stack(unsigned long *stack)
 static unsigned long *alloc_thread_stack_node(struct task_struct *tsk,
 						  int node)
 {
-	struct page *page = alloc_kmem_pages_node(node, THREADINFO_GFP,
-						  THREAD_SIZE_ORDER);
+	struct page *page = alloc_pages_node(node, THREADINFO_GFP,
+					     THREAD_SIZE_ORDER);
 
 	if (page)
 		memcg_kmem_update_page_stat(page, MEMCG_KERNEL_STACK,
@@ -178,7 +178,7 @@ static inline void free_thread_stack(unsigned long *stack)
 
 	memcg_kmem_update_page_stat(page, MEMCG_KERNEL_STACK,
 				    -(1 << THREAD_SIZE_ORDER));
-	__free_kmem_pages(page, THREAD_SIZE_ORDER);
+	__free_pages(page, THREAD_SIZE_ORDER);
 }
 # else
 static struct kmem_cache *thread_stack_cache;
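
Note: the stack pages are still charged here because THREADINFO_GFP
already carries __GFP_ACCOUNT, so no flag change is needed in this
caller. Assuming the definitions in this tree are unchanged:

	/* include/linux/gfp.h */
	#define GFP_KERNEL_ACCOUNT	(GFP_KERNEL | __GFP_ACCOUNT)

	/* include/linux/thread_info.h */
	#define THREADINFO_GFP		(GFP_KERNEL_ACCOUNT | __GFP_NOTRACK)
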
diff --git a/mm/page_alloc.c b/mm/page_alloc.c
index de2491c42d4ffd433e84368f6b669a26c2fe2071..7023a31edc5c265a34b963ae6e9f917e408d5ddc 100644
--- a/mm/page_alloc.c
+++ b/mm/page_alloc.c
@@ -63,6 +63,7 @@
 #include <linux/sched/rt.h>
 #include <linux/page_owner.h>
 #include <linux/kthread.h>
+#include <linux/memcontrol.h>
 
 #include <asm/sections.h>
 #include <asm/tlbflush.h>
@@ -1018,6 +1019,10 @@ static __always_inline bool free_pages_prepare(struct page *page,
 	}
 	if (PageMappingFlags(page))
 		page->mapping = NULL;
+	if (memcg_kmem_enabled() && PageKmemcg(page)) {
+		memcg_kmem_uncharge(page, order);
+		__ClearPageKmemcg(page);
+	}
 	if (check_free)
 		bad += free_pages_check(page);
 	if (bad)
@@ -3841,6 +3846,15 @@ __alloc_pages_nodemask(gfp_t gfp_mask, unsigned int order,
 	}
 
 out:
+	if (memcg_kmem_enabled() && (gfp_mask & __GFP_ACCOUNT) && page) {
+		if (unlikely(memcg_kmem_charge(page, gfp_mask, order))) {
+			__free_pages(page, order);
+			page = NULL;
+		} else {
+			__SetPageKmemcg(page);
+		}
+	}
+
 	if (kmemcheck_enabled && page)
 		kmemcheck_pagealloc_alloc(page, order, gfp_mask);
 
@@ -3996,59 +4010,6 @@
 }
 EXPORT_SYMBOL(__free_page_frag);
 
-/*
- * alloc_kmem_pages charges newly allocated pages to the kmem resource counter
- * of the current memory cgroup if __GFP_ACCOUNT is set, other than that it is
- * equivalent to alloc_pages.
- *
- * It should be used when the caller would like to use kmalloc, but since the
- * allocation is large, it has to fall back to the page allocator.
- */
-struct page *alloc_kmem_pages(gfp_t gfp_mask, unsigned int order)
-{
-	struct page *page;
-
-	page = alloc_pages(gfp_mask, order);
-	if (memcg_kmem_enabled() && (gfp_mask & __GFP_ACCOUNT) &&
-	    page && memcg_kmem_charge(page, gfp_mask, order) != 0) {
-		__free_pages(page, order);
-		page = NULL;
-	}
-	return page;
-}
-
-struct page *alloc_kmem_pages_node(int nid, gfp_t gfp_mask, unsigned int order)
-{
-	struct page *page;
-
-	page = alloc_pages_node(nid, gfp_mask, order);
-	if (memcg_kmem_enabled() && (gfp_mask & __GFP_ACCOUNT) &&
-	    page && memcg_kmem_charge(page, gfp_mask, order) != 0) {
-		__free_pages(page, order);
-		page = NULL;
-	}
-	return page;
-}
-
-/*
- * __free_kmem_pages and free_kmem_pages will free pages allocated with
- * alloc_kmem_pages.
- */
-void __free_kmem_pages(struct page *page, unsigned int order)
-{
-	if (memcg_kmem_enabled())
-		memcg_kmem_uncharge(page, order);
-	__free_pages(page, order);
-}
-
-void free_kmem_pages(unsigned long addr, unsigned int order)
-{
-	if (addr != 0) {
-		VM_BUG_ON(!virt_addr_valid((void *)addr));
-		__free_kmem_pages(virt_to_page((void *)addr), order);
-	}
-}
-
 static void *make_alloc_exact(unsigned long addr, unsigned int order,
 		size_t size)
 {
diff --git a/mm/slab_common.c b/mm/slab_common.c
index da88c1588752a8ff0ff5f91eff7df7777bebc134..71f0b28a1bec8bc58a479f7c53343b647bcf5f24 100644
--- a/mm/slab_common.c
+++ b/mm/slab_common.c
@@ -1012,7 +1012,7 @@ void *kmalloc_order(size_t size, gfp_t flags, unsigned int order)
 	struct page *page;
 
 	flags |= __GFP_COMP;
-	page = alloc_kmem_pages(flags, order);
+	page = alloc_pages(flags, order);
 	ret = page ? page_address(page) : NULL;
 	kmemleak_alloc(ret, size, 1, flags);
 	kasan_kmalloc_large(ret, size, flags);
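
Note: kmalloc_order() no longer charges by itself; a large allocation is
charged in the page allocator according to the caller's flags, and
PageKmemcg lands on the compound head (__GFP_COMP is set above). Usage
sketch (hypothetical caller; with SLUB, 64 KB is past
KMALLOC_MAX_CACHE_SIZE and falls back to kmalloc_order()):

	void *buf = kmalloc(SZ_64K, GFP_KERNEL_ACCOUNT);

	kfree(buf);	/* __free_pages() -> free_pages_prepare() uncharges */
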
diff --git a/mm/slub.c b/mm/slub.c
index c0cfa272253972b54f4e62cda46014c9ef6c8112..f9da8716b8b358fb6a0d7184a5c2ccb95d7369a5 100644
--- a/mm/slub.c
+++ b/mm/slub.c
@@ -2977,7 +2977,7 @@ int build_detached_freelist(struct kmem_cache *s, size_t size,
 		if (unlikely(!PageSlab(page))) {
 			BUG_ON(!PageCompound(page));
 			kfree_hook(object);
-			__free_kmem_pages(page, compound_order(page));
+			__free_pages(page, compound_order(page));
 			p[size] = NULL; /* mark object processed */
 			return size;
 		}
@@ -3693,7 +3693,7 @@ static void *kmalloc_large_node(size_t size, gfp_t flags, int node)
 	void *ptr = NULL;
 
 	flags |= __GFP_COMP | __GFP_NOTRACK;
-	page = alloc_kmem_pages_node(node, flags, get_order(size));
+	page = alloc_pages_node(node, flags, get_order(size));
 	if (page)
 		ptr = page_address(page);
 
@@ -3774,7 +3774,7 @@ void kfree(const void *x)
 	if (unlikely(!PageSlab(page))) {
 		BUG_ON(!PageCompound(page));
 		kfree_hook(x);
-		__free_kmem_pages(page, compound_order(page));
+		__free_pages(page, compound_order(page));
 		return;
 	}
 	slab_free(page->slab_cache, page, object, NULL, 1, _RET_IP_);
diff --git a/mm/vmalloc.c b/mm/vmalloc.c
index e11475cdeb7adb66194d8fd985fcbe155be806f3..91f44e78c516f67ff9969f47679bcefa22eed61d 100644
--- a/mm/vmalloc.c
+++ b/mm/vmalloc.c
@@ -1501,7 +1501,7 @@ static void __vunmap(const void *addr, int deallocate_pages)
 			struct page *page = area->pages[i];
 
 			BUG_ON(!page);
-			__free_kmem_pages(page, 0);
+			__free_pages(page, 0);
 		}
 
 		kvfree(area->pages);
@@ -1629,9 +1629,9 @@ static void *__vmalloc_area_node(struct vm_struct *area, gfp_t gfp_mask,
 		struct page *page;
 
 		if (node == NUMA_NO_NODE)
-			page = alloc_kmem_pages(alloc_mask, order);
+			page = alloc_pages(alloc_mask, order);
 		else
-			page = alloc_kmem_pages_node(node, alloc_mask, order);
+			page = alloc_pages_node(node, alloc_mask, order);
 
 		if (unlikely(!page)) {
 			/* Successfully allocated i pages, free them in __vunmap() */
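
Note: for vmalloc the conversion means each backing page is charged
individually at allocation time whenever the caller's gfp_mask contains
__GFP_ACCOUNT, and uncharged page by page in __vunmap(). Hypothetical
usage (size is a caller-supplied length):

	void *p = __vmalloc(size, GFP_KERNEL_ACCOUNT | __GFP_HIGHMEM,
			    PAGE_KERNEL);

	vfree(p);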