diff --git a/include/linux/sched.h b/include/linux/sched.h
index b48cd32be445b14b6a07feacff59f410a21d94ea..67ea79610e6748900d988d4425ca8d900de26d51 100644
--- a/include/linux/sched.h
+++ b/include/linux/sched.h
@@ -805,6 +805,8 @@ struct signal_struct {
 	short oom_score_adj;		/* OOM kill score adjustment */
 	short oom_score_adj_min;	/* OOM kill score adjustment min value.
 					 * Only settable by CAP_SYS_RESOURCE. */
+	struct mm_struct *oom_mm;	/* recorded mm when the thread group got
+					 * killed by the oom killer */
 
 	struct mutex cred_guard_mutex;	/* guard against foreign influences on
 					 * credential calculations
diff --git a/kernel/fork.c b/kernel/fork.c
index 9a05bd93f8e72872e5086e60c2ecde396846ab06..48cafe787b7590c01f6262762659f1cdc1590e45 100644
--- a/kernel/fork.c
+++ b/kernel/fork.c
@@ -359,6 +359,8 @@ static inline void free_signal_struct(struct signal_struct *sig)
 {
 	taskstats_tgid_free(sig);
 	sched_autogroup_exit(sig);
+	if (sig->oom_mm)
+		mmdrop(sig->oom_mm);
 	kmem_cache_free(signal_cachep, sig);
 }
 
diff --git a/mm/oom_kill.c b/mm/oom_kill.c
index 45097f5a8f303c08c26c6f1e8847fc5a0c9c2117..f16ec0840a0e72d15e2e29ed84f10d845cad814a 100644
--- a/mm/oom_kill.c
+++ b/mm/oom_kill.c
@@ -300,14 +300,7 @@ static int oom_evaluate_task(struct task_struct *task, void *arg)
 	 * any memory is quite low.
 	 */
 	if (!is_sysrq_oom(oc) && atomic_read(&task->signal->oom_victims)) {
-		struct task_struct *p = find_lock_task_mm(task);
-		bool reaped = false;
-
-		if (p) {
-			reaped = test_bit(MMF_OOM_REAPED, &p->mm->flags);
-			task_unlock(p);
-		}
-		if (reaped)
+		if (test_bit(MMF_OOM_REAPED, &task->signal->oom_mm->flags))
 			goto next;
 		goto abort;
 	}
@@ -536,11 +529,6 @@ static bool __oom_reap_task_mm(struct task_struct *tsk, struct mm_struct *mm)
 			K(get_mm_counter(mm, MM_SHMEMPAGES)));
 	up_read(&mm->mmap_sem);
 
-	/*
-	 * This task can be safely ignored because we cannot do much more
-	 * to release its memory.
-	 */
-	set_bit(MMF_OOM_REAPED, &mm->flags);
 	/*
 	 * Drop our reference but make sure the mmput slow path is called from a
 	 * different context because we shouldn't risk we get stuck there and
@@ -556,20 +544,7 @@ static bool __oom_reap_task_mm(struct task_struct *tsk, struct mm_struct *mm)
 static void oom_reap_task(struct task_struct *tsk)
 {
 	int attempts = 0;
-	struct mm_struct *mm = NULL;
-	struct task_struct *p = find_lock_task_mm(tsk);
-
-	/*
-	 * Make sure we find the associated mm_struct even when the particular
-	 * thread has already terminated and cleared its mm.
-	 * We might have race with exit path so consider our work done if there
-	 * is no mm.
-	 */
-	if (!p)
-		goto done;
-	mm = p->mm;
-	atomic_inc(&mm->mm_count);
-	task_unlock(p);
+	struct mm_struct *mm = tsk->signal->oom_mm;
 
 	/* Retry the down_read_trylock(mmap_sem) a few times */
 	while (attempts++ < MAX_OOM_REAP_RETRIES && !__oom_reap_task_mm(tsk, mm))
@@ -578,8 +553,6 @@ static void oom_reap_task(struct task_struct *tsk)
 	if (attempts <= MAX_OOM_REAP_RETRIES)
 		goto done;
 
-	/* Ignore this mm because somebody can't call up_write(mmap_sem). */
-	set_bit(MMF_OOM_REAPED, &mm->flags);
 
 	pr_info("oom_reaper: unable to reap pid:%d (%s)\n",
 		task_pid_nr(tsk), tsk->comm);
@@ -595,11 +568,14 @@ static void oom_reap_task(struct task_struct *tsk)
 	tsk->oom_reaper_list = NULL;
 	exit_oom_victim(tsk);
 
+	/*
+	 * Hide this mm from OOM killer because it has been either reaped or
+	 * somebody can't call up_write(mmap_sem).
+	 */
+	set_bit(MMF_OOM_REAPED, &mm->flags);
+
 	/* Drop a reference taken by wake_oom_reaper */
 	put_task_struct(tsk);
-	/* Drop a reference taken above. */
-	if (mm)
-		mmdrop(mm);
 }
 
 static int oom_reaper(void *unused)
@@ -665,14 +641,25 @@ static inline void wake_oom_reaper(struct task_struct *tsk)
  *
  * Has to be called with oom_lock held and never after
  * oom has been disabled already.
+ *
+ * tsk->mm has to be non NULL and caller has to guarantee it is stable (either
+ * under task_lock or operate on the current).
  */
 static void mark_oom_victim(struct task_struct *tsk)
 {
+	struct mm_struct *mm = tsk->mm;
+
 	WARN_ON(oom_killer_disabled);
 	/* OOM killer might race with memcg OOM */
 	if (test_and_set_tsk_thread_flag(tsk, TIF_MEMDIE))
 		return;
+
 	atomic_inc(&tsk->signal->oom_victims);
+
+	/* oom_mm is bound to the signal struct life time. */
+	if (!cmpxchg(&tsk->signal->oom_mm, NULL, mm))
+		atomic_inc(&tsk->signal->oom_mm->mm_count);
+
 	/*
 	 * Make sure that the task is woken up from uninterruptible sleep
 	 * if it is frozen because OOM killer wouldn't be able to free