diff --git a/mm/ksm.c b/mm/ksm.c
index 21c5f4ff722935e6c5d699ba192ff3f9d9e4ce78..5621840401c05ef439764bdf74f59fed8b5fbc3f 100644
--- a/mm/ksm.c
+++ b/mm/ksm.c
@@ -338,6 +338,7 @@ static inline void stable_node_chain_add_dup(struct stable_node *dup,
 
 static inline void __stable_node_dup_del(struct stable_node *dup)
 {
+	VM_BUG_ON(!is_stable_node_dup(dup));
 	hlist_del(&dup->hlist_dup);
 	ksm_stable_node_dups--;
 }
@@ -1312,12 +1313,12 @@ bool is_page_sharing_candidate(struct stable_node *stable_node)
 	return __is_page_sharing_candidate(stable_node, 0);
 }
 
-static struct stable_node *stable_node_dup(struct stable_node *stable_node,
+static struct stable_node *stable_node_dup(struct stable_node **_stable_node,
 					   struct page **tree_page,
 					   struct rb_root *root,
 					   bool prune_stale_stable_nodes)
 {
-	struct stable_node *dup, *found = NULL;
+	struct stable_node *dup, *found = NULL, *stable_node = *_stable_node;
 	struct hlist_node *hlist_safe;
 	struct page *_tree_page;
 	int nr = 0;
@@ -1390,6 +1391,15 @@ static struct stable_node *stable_node_dup(struct stable_node *stable_node,
 			free_stable_node(stable_node);
 			ksm_stable_node_chains--;
 			ksm_stable_node_dups--;
+			/*
+			 * NOTE: the caller depends on the
+			 * *_stable_node to become NULL if the chain
+			 * was collapsed. Enforce that if anything
+			 * uses a stale (freed) stable_node chain a
+			 * visible crash will materialize (instead of
+			 * an use after free).
+			 */
+			*_stable_node = stable_node = NULL;
 		} else if (__is_page_sharing_candidate(found, 1)) {
 			/*
 			 * Refile our candidate at the head
@@ -1419,11 +1429,12 @@ static struct stable_node *stable_node_dup_any(struct stable_node *stable_node,
 			   typeof(*stable_node), hlist_dup);
 }
 
-static struct stable_node *__stable_node_chain(struct stable_node *stable_node,
+static struct stable_node *__stable_node_chain(struct stable_node **_stable_node,
 					       struct page **tree_page,
 					       struct rb_root *root,
 					       bool prune_stale_stable_nodes)
 {
+	struct stable_node *stable_node = *_stable_node;
 	if (!is_stable_node_chain(stable_node)) {
 		if (is_page_sharing_candidate(stable_node)) {
 			*tree_page = get_ksm_page(stable_node, false);
@@ -1431,11 +1442,11 @@ static struct stable_node *__stable_node_chain(struct stable_node *stable_node,
 		}
 		return NULL;
 	}
-	return stable_node_dup(stable_node, tree_page, root,
+	return stable_node_dup(_stable_node, tree_page, root,
 			       prune_stale_stable_nodes);
 }
 
-static __always_inline struct stable_node *chain_prune(struct stable_node *s_n,
+static __always_inline struct stable_node *chain_prune(struct stable_node **s_n,
 						       struct page **t_p,
 						       struct rb_root *root)
 {
@@ -1446,7 +1457,7 @@ static __always_inline struct stable_node *chain(struct stable_node *s_n,
 						 struct page **t_p,
 						 struct rb_root *root)
 {
-	return __stable_node_chain(s_n, t_p, root, false);
+	return __stable_node_chain(&s_n, t_p, root, false);
 }
 
 /*
@@ -1487,7 +1498,15 @@ static struct page *stable_tree_search(struct page *page)
 		cond_resched();
 		stable_node = rb_entry(*new, struct stable_node, node);
 		stable_node_any = NULL;
-		stable_node_dup = chain_prune(stable_node, &tree_page, root);
+		stable_node_dup = chain_prune(&stable_node, &tree_page, root);
+		/*
+		 * NOTE: stable_node may have been freed by
+		 * chain_prune() if the returned stable_node_dup is
+		 * not NULL. stable_node_dup may have been inserted in
+		 * the rbtree instead as a regular stable_node (in
+		 * order to collapse the stable_node chain if a single
+		 * stable_node dup was found in it).
+		 */
 		if (!stable_node_dup) {
 			/*
 			 * Either all stable_node dups were full in
@@ -1602,20 +1621,33 @@ static struct page *stable_tree_search(struct page *page)
 		return NULL;
 
 replace:
-	if (stable_node_dup == stable_node) {
+	/*
+	 * If stable_node was a chain and chain_prune collapsed it,
+	 * stable_node will be NULL here. In that case the
+	 * stable_node_dup is the regular stable_node that has
+	 * replaced the chain. If stable_node is not NULL and equal to
+	 * stable_node_dup there was no chain and stable_node_dup is
+	 * the regular stable_node in the stable rbtree. Otherwise
+	 * stable_node is the chain and stable_node_dup is the dup to
+	 * replace.
+	 */
+	if (!stable_node || stable_node_dup == stable_node) {
+		VM_BUG_ON(is_stable_node_chain(stable_node_dup));
+		VM_BUG_ON(is_stable_node_dup(stable_node_dup));
 		/* there is no chain */
 		if (page_node) {
 			VM_BUG_ON(page_node->head != &migrate_nodes);
 			list_del(&page_node->list);
 			DO_NUMA(page_node->nid = nid);
-			rb_replace_node(&stable_node->node, &page_node->node,
+			rb_replace_node(&stable_node_dup->node,
+					&page_node->node,
 					root);
 			if (is_page_sharing_candidate(page_node))
 				get_page(page);
 			else
 				page = NULL;
 		} else {
-			rb_erase(&stable_node->node, root);
+			rb_erase(&stable_node_dup->node, root);
 			page = NULL;
 		}
 	} else {
@@ -1642,7 +1674,17 @@ static struct page *stable_tree_search(struct page *page)
 	/* stable_node_dup could be null if it reached the limit */
 	if (!stable_node_dup)
 		stable_node_dup = stable_node_any;
-	if (stable_node_dup == stable_node) {
+	/*
+	 * If stable_node was a chain and chain_prune collapsed it,
+	 * stable_node will be NULL here. In that case the
+	 * stable_node_dup is the regular stable_node that has
+	 * replaced the chain. If stable_node is not NULL and equal to
+	 * stable_node_dup there was no chain and stable_node_dup is
+	 * the regular stable_node in the stable rbtree.
+	 */
+	if (!stable_node || stable_node_dup == stable_node) {
+		VM_BUG_ON(is_stable_node_chain(stable_node_dup));
+		VM_BUG_ON(is_stable_node_dup(stable_node_dup));
 		/* chain is missing so create it */
 		stable_node = alloc_stable_node_chain(stable_node_dup,
 						      root);
@@ -1655,6 +1697,8 @@ static struct page *stable_tree_search(struct page *page)
 	 * of the current nid for this page
 	 * content.
 	 */
+	VM_BUG_ON(!is_stable_node_chain(stable_node));
+	VM_BUG_ON(!is_stable_node_dup(stable_node_dup));
 	VM_BUG_ON(page_node->head != &migrate_nodes);
 	list_del(&page_node->list);
 	DO_NUMA(page_node->nid = nid);