diff --git a/kernel/fork.c b/kernel/fork.c
index 737db182843764cc1d2a16b32d65341ea21aef44..b409e792aadcce7a71abbe97c4fca723345cd3b9 100644
--- a/kernel/fork.c
+++ b/kernel/fork.c
@@ -955,6 +955,15 @@ static void mm_init_aio(struct mm_struct *mm)
 #endif
 }
 
+static __always_inline void mm_clear_owner(struct mm_struct *mm,
+					   struct task_struct *p)
+{
+#ifdef CONFIG_MEMCG
+	if (mm->owner == p)
+		WRITE_ONCE(mm->owner, NULL);
+#endif
+}
+
 static void mm_init_owner(struct mm_struct *mm, struct task_struct *p)
 {
 #ifdef CONFIG_MEMCG
@@ -1343,6 +1352,7 @@ static struct mm_struct *dup_mm(struct task_struct *tsk,
 free_pt:
 	/* don't put binfmt in mmput, we haven't got module yet */
 	mm->binfmt = NULL;
+	mm_init_owner(mm, NULL);
 	mmput(mm);
 
 fail_nomem:
@@ -1726,6 +1736,21 @@ static int pidfd_create(struct pid *pid)
 	return fd;
 }
 
+static void __delayed_free_task(struct rcu_head *rhp)
+{
+	struct task_struct *tsk = container_of(rhp, struct task_struct, rcu);
+
+	free_task(tsk);
+}
+
+static __always_inline void delayed_free_task(struct task_struct *tsk)
+{
+	if (IS_ENABLED(CONFIG_MEMCG))
+		call_rcu(&tsk->rcu, __delayed_free_task);
+	else
+		free_task(tsk);
+}
+
 /*
  * This creates a new process as a copy of the old one,
  * but does not actually start it yet.
@@ -2233,8 +2258,10 @@ static __latent_entropy struct task_struct *copy_process(
 bad_fork_cleanup_namespaces:
 	exit_task_namespaces(p);
 bad_fork_cleanup_mm:
-	if (p->mm)
+	if (p->mm) {
+		mm_clear_owner(p->mm, p);
 		mmput(p->mm);
+	}
 bad_fork_cleanup_signal:
 	if (!(clone_flags & CLONE_THREAD))
 		free_signal_struct(p->signal);
@@ -2265,7 +2292,7 @@ static __latent_entropy struct task_struct *copy_process(
 bad_fork_free:
 	p->state = TASK_DEAD;
 	put_task_stack(p);
-	free_task(p);
+	delayed_free_task(p);
 fork_out:
 	spin_lock_irq(&current->sighand->siglock);
 	hlist_del_init(&delayed.node);