diff --git a/Documentation/admin-guide/cgroup-v2.rst b/Documentation/admin-guide/cgroup-v2.rst
index 1746131bc9cb315dca4d4b03b23678cece518727..184193bcb262ac908f1f5a7a7c2c662dec0ea4b8 100644
--- a/Documentation/admin-guide/cgroup-v2.rst
+++ b/Documentation/admin-guide/cgroup-v2.rst
@@ -1072,6 +1072,24 @@ PAGE_SIZE multiple when read back.
 	high limit is used and monitored properly, this limit's
 	utility is limited to providing the final safety net.
 
+  memory.oom.group
+	A read-write single value file which exists on non-root
+	cgroups.  The default value is "0".
+
+	Determines whether the cgroup should be treated as
+	an indivisible workload by the OOM killer. If set,
+	all tasks belonging to the cgroup or to its descendants
+	(if the memory cgroup is not a leaf cgroup) are killed
+	together or not at all. This can be used to avoid
+	partial kills and thus guarantee workload integrity.
+
+	Tasks with OOM protection (oom_score_adj set to -1000)
+	are treated as an exception and are never killed.
+
+	If the OOM killer is invoked in a cgroup, it will not
+	kill any tasks outside of this cgroup, regardless of
+	the memory.oom.group values of ancestor cgroups.
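+
+	For example, a workload that must not survive partially can
+	be placed into its own cgroup and marked as indivisible
+	(the path below is only an illustration)::
+
+	  # "workload" is an example cgroup name
+	  echo 1 > /sys/fs/cgroup/workload/memory.oom.group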
+
   memory.events
 	A read-only flat-keyed file which exists on non-root cgroups.
 	The following entries are defined.  Unless specified
diff --git a/include/linux/memcontrol.h b/include/linux/memcontrol.h
index 0e6c515fb698f7a6e8c139b2fc87b9286342ca06..652f602167df49b58830acca591e3fbe5fb67185 100644
--- a/include/linux/memcontrol.h
+++ b/include/linux/memcontrol.h
@@ -225,6 +225,11 @@ struct mem_cgroup {
 	 */
 	bool use_hierarchy;
 
+	/*
+	 * Should the OOM killer kill all tasks belonging to this
+	 * cgroup if it has to kill one of them?
+	 */
+	bool oom_group;
+
 	/* protected by memcg_oom_lock */
 	bool		oom_lock;
 	int		under_oom;
@@ -542,6 +547,9 @@ static inline bool task_in_memcg_oom(struct task_struct *p)
 }
 
 bool mem_cgroup_oom_synchronize(bool wait);
+struct mem_cgroup *mem_cgroup_get_oom_group(struct task_struct *victim,
+					    struct mem_cgroup *oom_domain);
+void mem_cgroup_print_oom_group(struct mem_cgroup *memcg);
 
 #ifdef CONFIG_MEMCG_SWAP
 extern int do_swap_account;
@@ -1001,6 +1009,16 @@ static inline bool mem_cgroup_oom_synchronize(bool wait)
 	return false;
 }
 
+static inline struct mem_cgroup *mem_cgroup_get_oom_group(
+	struct task_struct *victim, struct mem_cgroup *oom_domain)
+{
+	return NULL;
+}
+
+static inline void mem_cgroup_print_oom_group(struct mem_cgroup *memcg)
+{
+}
+
 static inline unsigned long memcg_page_state(struct mem_cgroup *memcg,
 					     int idx)
 {
diff --git a/mm/memcontrol.c b/mm/memcontrol.c
index 59c14c988143b2ce6cfe4b90fc6bf6c2ff564be5..4ead5a4817de3ffbf1477cfe77b8f4f4802281c8 100644
--- a/mm/memcontrol.c
+++ b/mm/memcontrol.c
@@ -1776,6 +1776,62 @@ bool mem_cgroup_oom_synchronize(bool handle)
 	return true;
 }
 
+/**
+ * mem_cgroup_get_oom_group - get a memory cgroup to clean up after OOM
+ * @victim: task to be killed by the OOM killer
+ * @oom_domain: memcg in case of memcg OOM, NULL in case of system-wide OOM
+ *
+ * Returns a pointer to a memory cgroup which has to be cleaned up
+ * by killing all OOM-killable tasks belonging to it.
+ *
+ * The caller must call mem_cgroup_put() on the returned non-NULL memcg.
+ */
+struct mem_cgroup *mem_cgroup_get_oom_group(struct task_struct *victim,
+					    struct mem_cgroup *oom_domain)
+{
+	struct mem_cgroup *oom_group = NULL;
+	struct mem_cgroup *memcg;
+
+	if (!cgroup_subsys_on_dfl(memory_cgrp_subsys))
+		return NULL;
+
+	if (!oom_domain)
+		oom_domain = root_mem_cgroup;
+
+	rcu_read_lock();
+
+	memcg = mem_cgroup_from_task(victim);
+	if (memcg == root_mem_cgroup)
+		goto out;
+
+	/*
+	 * Traverse the memory cgroup hierarchy from the victim task's
+	 * cgroup up to the OOMing cgroup (or root) to find the
+	 * highest-level memory cgroup with oom.group set.
+	 */
+	for (; memcg; memcg = parent_mem_cgroup(memcg)) {
+		if (memcg->oom_group)
+			oom_group = memcg;
+
+		if (memcg == oom_domain)
+			break;
+	}
+
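+	/* The reference is dropped by the caller via mem_cgroup_put() */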
+	if (oom_group)
+		css_get(&oom_group->css);
+out:
+	rcu_read_unlock();
+
+	return oom_group;
+}
+
+void mem_cgroup_print_oom_group(struct mem_cgroup *memcg)
+{
+	pr_info("Tasks in ");
+	pr_cont_cgroup_path(memcg->css.cgroup);
+	pr_cont(" are going to be killed due to memory.oom.group set\n");
+}
+
 /**
  * lock_page_memcg - lock a page->mem_cgroup binding
  * @page: the page
@@ -5561,6 +5617,37 @@ static int memory_stat_show(struct seq_file *m, void *v)
 	return 0;
 }
 
+static int memory_oom_group_show(struct seq_file *m, void *v)
+{
+	struct mem_cgroup *memcg = mem_cgroup_from_css(seq_css(m));
+
+	seq_printf(m, "%d\n", memcg->oom_group);
+
+	return 0;
+}
+
+static ssize_t memory_oom_group_write(struct kernfs_open_file *of,
+				      char *buf, size_t nbytes, loff_t off)
+{
+	struct mem_cgroup *memcg = mem_cgroup_from_css(of_css(of));
+	int ret, oom_group;
+
+	buf = strstrip(buf);
+	if (!buf)
+		return -EINVAL;
+
+	ret = kstrtoint(buf, 0, &oom_group);
+	if (ret)
+		return ret;
+
+	if (oom_group != 0 && oom_group != 1)
+		return -EINVAL;
+
+	memcg->oom_group = oom_group;
+
+	return nbytes;
+}
+
 static struct cftype memory_files[] = {
 	{
 		.name = "current",
@@ -5602,6 +5689,12 @@ static struct cftype memory_files[] = {
 		.flags = CFTYPE_NOT_ON_ROOT,
 		.seq_show = memory_stat_show,
 	},
+	{
+		.name = "oom.group",
+		.flags = CFTYPE_NOT_ON_ROOT | CFTYPE_NS_DELEGATABLE,
+		.seq_show = memory_oom_group_show,
+		.write = memory_oom_group_write,
+	},
 	{ }	/* terminate */
 };
 
diff --git a/mm/oom_kill.c b/mm/oom_kill.c
index 330416c67ce52391ebc18727f5e917064f3a9c5f..0e10b864e0742da9f545130dbd913e1e8c7a8efb 100644
--- a/mm/oom_kill.c
+++ b/mm/oom_kill.c
@@ -908,6 +908,19 @@ static void __oom_kill_process(struct task_struct *victim)
 }
 #undef K
 
+/*
+ * Kill the provided task unless it is protected by having
+ * oom_score_adj set to OOM_SCORE_ADJ_MIN.
+ */
+static int oom_kill_memcg_member(struct task_struct *task, void *unused)
+{
+	if (task->signal->oom_score_adj != OOM_SCORE_ADJ_MIN) {
+		get_task_struct(task);
+		__oom_kill_process(task);
+	}
+	return 0;
+}
+
 static void oom_kill_process(struct oom_control *oc, const char *message)
 {
 	struct task_struct *p = oc->chosen;
@@ -915,6 +928,7 @@ static void oom_kill_process(struct oom_control *oc, const char *message)
 	struct task_struct *victim = p;
 	struct task_struct *child;
 	struct task_struct *t;
+	struct mem_cgroup *oom_group;
 	unsigned int victim_points = 0;
 	static DEFINE_RATELIMIT_STATE(oom_rs, DEFAULT_RATELIMIT_INTERVAL,
 					      DEFAULT_RATELIMIT_BURST);
@@ -968,7 +982,23 @@ static void oom_kill_process(struct oom_control *oc, const char *message)
 	}
 	read_unlock(&tasklist_lock);
 
+	/*
+	 * Do we need to kill the entire memory cgroup?
+	 * Or even one of the ancestor memory cgroups?
+	 * Determine this before the victim task is killed.
+	 */
+	oom_group = mem_cgroup_get_oom_group(victim, oc->memcg);
+
 	__oom_kill_process(victim);
+
+	/*
+	 * If necessary, kill all tasks in the selected memory cgroup.
+	 */
+	if (oom_group) {
+		mem_cgroup_print_oom_group(oom_group);
+		mem_cgroup_scan_tasks(oom_group, oom_kill_memcg_member, NULL);
+		mem_cgroup_put(oom_group);
+	}
 }
 
 /*