diff --git a/mm/huge_memory.c b/mm/huge_memory.c
index f0ae8d1d4329d5c72c9bcf9e460e16a7815a1059..229ab8c75a6b01b9e92c31751e4c1923b98cf1fc 100644
--- a/mm/huge_memory.c
+++ b/mm/huge_memory.c
@@ -555,8 +555,7 @@ static int __do_huge_pmd_anonymous_page(struct vm_fault *vmf, struct page *page,
 
 	VM_BUG_ON_PAGE(!PageCompound(page), page);
 
-	if (mem_cgroup_try_charge(page, vma->vm_mm, gfp | __GFP_NORETRY, &memcg,
-				  true)) {
+	if (mem_cgroup_try_charge(page, vma->vm_mm, gfp, &memcg, true)) {
 		put_page(page);
 		count_vm_event(THP_FAULT_FALLBACK);
 		return VM_FAULT_FALLBACK;
@@ -1317,7 +1316,7 @@ int do_huge_pmd_wp_page(struct vm_fault *vmf, pmd_t orig_pmd)
 	}
 
 	if (unlikely(mem_cgroup_try_charge(new_page, vma->vm_mm,
-				huge_gfp | __GFP_NORETRY, &memcg, true))) {
+					huge_gfp, &memcg, true))) {
 		put_page(new_page);
 		split_huge_pmd(vma, vmf->pmd, vmf->address);
 		if (page)
diff --git a/mm/khugepaged.c b/mm/khugepaged.c
index e42568284e06038ab70ec1344f63a2e5182ea90d..c15da1ea7e639bc0ab56bc747a4f7110c38f7187 100644
--- a/mm/khugepaged.c
+++ b/mm/khugepaged.c
@@ -965,9 +965,7 @@ static void collapse_huge_page(struct mm_struct *mm,
 		goto out_nolock;
 	}
 
-	/* Do not oom kill for khugepaged charges */
-	if (unlikely(mem_cgroup_try_charge(new_page, mm, gfp | __GFP_NORETRY,
-					   &memcg, true))) {
+	if (unlikely(mem_cgroup_try_charge(new_page, mm, gfp, &memcg, true))) {
 		result = SCAN_CGROUP_CHARGE_FAIL;
 		goto out_nolock;
 	}
@@ -1326,9 +1324,7 @@ static void collapse_shmem(struct mm_struct *mm,
 		goto out;
 	}
 
-	/* Do not oom kill for khugepaged charges */
-	if (unlikely(mem_cgroup_try_charge(new_page, mm, gfp | __GFP_NORETRY,
-					   &memcg, true))) {
+	if (unlikely(mem_cgroup_try_charge(new_page, mm, gfp, &memcg, true))) {
 		result = SCAN_CGROUP_CHARGE_FAIL;
 		goto out;
 	}
diff --git a/mm/memcontrol.c b/mm/memcontrol.c
index 9ec024b862aca01cb32d302584e6b110d11107a7..6b4f5c0a8eefdf49539532517d08c19e0dfef276 100644
--- a/mm/memcontrol.c
+++ b/mm/memcontrol.c
@@ -1485,7 +1485,7 @@ static void memcg_oom_recover(struct mem_cgroup *memcg)
 
 static void mem_cgroup_oom(struct mem_cgroup *memcg, gfp_t mask, int order)
 {
-	if (!current->memcg_may_oom)
+	if (!current->memcg_may_oom || order > PAGE_ALLOC_COSTLY_ORDER)
 		return;
 	/*
 	 * We are in the middle of the charge context here, so we