diff --git a/fs/proc/base.c b/fs/proc/base.c
index 64dadd469786033dad9a2a86c48660854075648c..77eb628ecc7f2b1adf1ffefc7aa7d7331f3ef9a4 100644
--- a/fs/proc/base.c
+++ b/fs/proc/base.c
@@ -532,8 +532,7 @@ static int proc_oom_score(struct seq_file *m, struct pid_namespace *ns,
 	unsigned long totalpages = totalram_pages() + total_swap_pages;
 	unsigned long points = 0;
 
-	points = oom_badness(task, NULL, totalpages) *
-					1000 / totalpages;
+	points = oom_badness(task, totalpages) * 1000 / totalpages;
 	seq_printf(m, "%lu\n", points);
 
 	return 0;
diff --git a/include/linux/oom.h b/include/linux/oom.h
index b75104690311557b38ac474118b5f17cb87a3428..c696c265f0193e2a06127f8815e89714ca3ab0fd 100644
--- a/include/linux/oom.h
+++ b/include/linux/oom.h
@@ -108,7 +108,6 @@ static inline vm_fault_t check_stable_address_space(struct mm_struct *mm)
 bool __oom_reap_task_mm(struct mm_struct *mm);
 
 extern unsigned long oom_badness(struct task_struct *p,
-		const nodemask_t *nodemask,
 		unsigned long totalpages);
 
 extern bool out_of_memory(struct oom_control *oc);
diff --git a/mm/oom_kill.c b/mm/oom_kill.c
index b353f468a36aebf1eccc8a036df946c277c260ee..d1c9c4e66d5915be1774d9178bc3ba1bcfc74db7 100644
--- a/mm/oom_kill.c
+++ b/mm/oom_kill.c
@@ -64,21 +64,33 @@ int sysctl_oom_dump_tasks = 1;
  */
 DEFINE_MUTEX(oom_lock);
 
+static inline bool is_memcg_oom(struct oom_control *oc)
+{
+	return oc->memcg != NULL;
+}
+
 #ifdef CONFIG_NUMA
 /**
- * has_intersects_mems_allowed() - check task eligiblity for kill
+ * oom_cpuset_eligible() - check task eligiblity for kill
  * @start: task struct of which task to consider
  * @mask: nodemask passed to page allocator for mempolicy ooms
  *
  * Task eligibility is determined by whether or not a candidate task, @tsk,
  * shares the same mempolicy nodes as current if it is bound by such a policy
  * and whether or not it has the same set of allowed cpuset nodes.
+ *
+ * This function is assuming oom-killer context and 'current' has triggered
+ * the oom-killer.
  */
-static bool has_intersects_mems_allowed(struct task_struct *start,
-					const nodemask_t *mask)
+static bool oom_cpuset_eligible(struct task_struct *start,
+				struct oom_control *oc)
 {
 	struct task_struct *tsk;
 	bool ret = false;
+	const nodemask_t *mask = oc->nodemask;
+
+	if (is_memcg_oom(oc))
+		return true;
 
 	rcu_read_lock();
 	for_each_thread(start, tsk) {
@@ -105,8 +117,7 @@ static bool has_intersects_mems_allowed(struct task_struct *start,
 	return ret;
 }
 #else
-static bool has_intersects_mems_allowed(struct task_struct *tsk,
-					const nodemask_t *mask)
+static bool oom_cpuset_eligible(struct task_struct *tsk, struct oom_control *oc)
 {
 	return true;
 }
@@ -146,24 +157,13 @@ static inline bool is_sysrq_oom(struct oom_control *oc)
 	return oc->order == -1;
 }
 
-static inline bool is_memcg_oom(struct oom_control *oc)
-{
-	return oc->memcg != NULL;
-}
-
 /* return true if the task is not adequate as candidate victim task. */
-static bool oom_unkillable_task(struct task_struct *p,
-				const nodemask_t *nodemask)
+static bool oom_unkillable_task(struct task_struct *p)
 {
 	if (is_global_init(p))
 		return true;
 	if (p->flags & PF_KTHREAD)
 		return true;
-
-	/* p may not have freeable memory in nodemask */
-	if (!has_intersects_mems_allowed(p, nodemask))
-		return true;
-
 	return false;
 }
 
@@ -190,19 +190,17 @@ static bool is_dump_unreclaim_slabs(void)
  * oom_badness - heuristic function to determine which candidate task to kill
  * @p: task struct of which task we should calculate
  * @totalpages: total present RAM allowed for page allocation
- * @nodemask: nodemask passed to page allocator for mempolicy ooms
  *
  * The heuristic for determining which task to kill is made to be as simple and
  * predictable as possible.  The goal is to return the highest value for the
  * task consuming the most memory to avoid subsequent oom failures.
  */
-unsigned long oom_badness(struct task_struct *p,
-			  const nodemask_t *nodemask, unsigned long totalpages)
+unsigned long oom_badness(struct task_struct *p, unsigned long totalpages)
 {
 	long points;
 	long adj;
 
-	if (oom_unkillable_task(p, nodemask))
+	if (oom_unkillable_task(p))
 		return 0;
 
 	p = find_lock_task_mm(p);
@@ -313,7 +311,11 @@ static int oom_evaluate_task(struct task_struct *task, void *arg)
 	struct oom_control *oc = arg;
 	unsigned long points;
 
-	if (oom_unkillable_task(task, oc->nodemask))
+	if (oom_unkillable_task(task))
+		goto next;
+
+	/* p may not have freeable memory in nodemask */
+	if (!is_memcg_oom(oc) && !oom_cpuset_eligible(task, oc))
 		goto next;
 
 	/*
@@ -337,7 +339,7 @@ static int oom_evaluate_task(struct task_struct *task, void *arg)
 		goto select;
 	}
 
-	points = oom_badness(task, oc->nodemask, oc->totalpages);
+	points = oom_badness(task, oc->totalpages);
 	if (!points || points < oc->chosen_points)
 		goto next;
 
@@ -382,7 +384,11 @@ static int dump_task(struct task_struct *p, void *arg)
 	struct oom_control *oc = arg;
 	struct task_struct *task;
 
-	if (oom_unkillable_task(p, oc->nodemask))
+	if (oom_unkillable_task(p))
+		return 0;
+
+	/* p may not have freeable memory in nodemask */
+	if (!is_memcg_oom(oc) && !oom_cpuset_eligible(p, oc))
 		return 0;
 
 	task = find_lock_task_mm(p);
@@ -1079,7 +1085,8 @@ bool out_of_memory(struct oom_control *oc)
 	check_panic_on_oom(oc);
 
 	if (!is_memcg_oom(oc) && sysctl_oom_kill_allocating_task &&
-	    current->mm && !oom_unkillable_task(current, oc->nodemask) &&
+	    current->mm && !oom_unkillable_task(current) &&
+	    oom_cpuset_eligible(current, oc) &&
 	    current->signal->oom_score_adj != OOM_SCORE_ADJ_MIN) {
 		get_task_struct(current);
 		oc->chosen = current;