diff --git a/lib/idr.c b/lib/idr.c
index fab2fd5bc326bef8bdc9277485d9fe0d4ec50733..729e381e23b4324676e8f4375bde88c17dbfa5fa 100644
--- a/lib/idr.c
+++ b/lib/idr.c
@@ -39,8 +39,6 @@ int idr_alloc_u32(struct idr *idr, void *ptr, u32 *nextid,
 	unsigned int base = idr->idr_base;
 	unsigned int id = *nextid;
 
-	if (WARN_ON_ONCE(radix_tree_is_internal_node(ptr)))
-		return -EINVAL;
 	if (WARN_ON_ONCE(!(idr->idr_rt.gfp_mask & ROOT_IS_IDR)))
 		idr->idr_rt.gfp_mask |= IDR_RT_MARKER;
 
@@ -295,8 +293,6 @@ void *idr_replace(struct idr *idr, void *ptr, unsigned long id)
 	void __rcu **slot = NULL;
 	void *entry;
 
-	if (WARN_ON_ONCE(radix_tree_is_internal_node(ptr)))
-		return ERR_PTR(-EINVAL);
 	id -= idr->idr_base;
 
 	entry = __radix_tree_lookup(&idr->idr_rt, id, &node, &slot);
diff --git a/lib/radix-tree.c b/lib/radix-tree.c
index bc03ecc4dfd2f69c8638cbff3270bc55793b7f6b..a904a8ddd174372170f94507d556899448730d51 100644
--- a/lib/radix-tree.c
+++ b/lib/radix-tree.c
@@ -703,6 +703,14 @@ static inline bool radix_tree_shrink(struct radix_tree_root *root,
 		if (!radix_tree_is_internal_node(child) && node->shift)
 			break;
 
+		/*
+		 * For an IDR, we must not shrink entry 0 into the root in
+		 * case somebody calls idr_replace() with a pointer that
+		 * appears to be an internal entry
+		 */
+		if (!node->shift && is_idr(root))
+			break;
+
 		if (radix_tree_is_internal_node(child))
 			entry_to_node(child)->parent = NULL;
 
@@ -875,8 +883,8 @@ static void radix_tree_free_nodes(struct radix_tree_node *node)
 
 	for (;;) {
 		void *entry = rcu_dereference_raw(child->slots[offset]);
-		if (radix_tree_is_internal_node(entry) &&
-					!is_sibling_entry(child, entry)) {
+		if (radix_tree_is_internal_node(entry) && child->shift &&
+				!is_sibling_entry(child, entry)) {
 			child = entry_to_node(entry);
 			offset = 0;
 			continue;
@@ -1049,6 +1057,8 @@ void *__radix_tree_lookup(const struct radix_tree_root *root,
 		parent = entry_to_node(node);
 		offset = radix_tree_descend(parent, &node, index);
 		slot = parent->slots + offset;
+		if (parent->shift == 0)
+			break;
 	}
 
 	if (nodep)
@@ -1123,9 +1133,6 @@ static inline void replace_sibling_entries(struct radix_tree_node *node,
 static void replace_slot(void __rcu **slot, void *item,
 		struct radix_tree_node *node, int count, int exceptional)
 {
-	if (WARN_ON_ONCE(radix_tree_is_internal_node(item)))
-		return;
-
 	if (node && (count || exceptional)) {
 		node->count += count;
 		node->exceptional += exceptional;
@@ -1784,7 +1791,7 @@ void __rcu **radix_tree_next_chunk(const struct radix_tree_root *root,
 			goto restart;
 		if (child == RADIX_TREE_RETRY)
 			break;
-	} while (radix_tree_is_internal_node(child));
+	} while (node->shift && radix_tree_is_internal_node(child));
 
 	/* Update the iterator state */
 	iter->index = (index &~ node_maxindex(node)) | (offset << node->shift);
@@ -2150,6 +2157,8 @@ void __rcu **idr_get_free(struct radix_tree_root *root,
 		shift = error;
 		child = rcu_dereference_raw(root->rnode);
 	}
+	if (start == 0 && shift == 0)
+		shift = RADIX_TREE_MAP_SHIFT;
 
 	while (shift) {
 		shift -= RADIX_TREE_MAP_SHIFT;
diff --git a/tools/testing/radix-tree/idr-test.c b/tools/testing/radix-tree/idr-test.c
index 321ba92c70d2acb56ef5e9573c0a4b23b304062f..f620c831a4b5982f47191024a9992d27a6f633e5 100644
--- a/tools/testing/radix-tree/idr-test.c
+++ b/tools/testing/radix-tree/idr-test.c
@@ -227,6 +227,66 @@ void idr_u32_test(int base)
 	idr_u32_test1(&idr, 0xffffffff);
 }
 
+static void idr_align_test(struct idr *idr)
+{
+	char name[] = "Motorola 68000";
+	int i, id;
+	void *entry;
+
+	for (i = 0; i < 9; i++) {
+		BUG_ON(idr_alloc(idr, &name[i], 0, 0, GFP_KERNEL) != i);
+		idr_for_each_entry(idr, entry, id);
+	}
+	idr_destroy(idr);
+
+	for (i = 1; i < 10; i++) {
+		BUG_ON(idr_alloc(idr, &name[i], 0, 0, GFP_KERNEL) != i - 1);
+		idr_for_each_entry(idr, entry, id);
+	}
+	idr_destroy(idr);
+
+	for (i = 2; i < 11; i++) {
+		BUG_ON(idr_alloc(idr, &name[i], 0, 0, GFP_KERNEL) != i - 2);
+		idr_for_each_entry(idr, entry, id);
+	}
+	idr_destroy(idr);
+
+	for (i = 3; i < 12; i++) {
+		BUG_ON(idr_alloc(idr, &name[i], 0, 0, GFP_KERNEL) != i - 3);
+		idr_for_each_entry(idr, entry, id);
+	}
+	idr_destroy(idr);
+
+	for (i = 0; i < 8; i++) {
+		BUG_ON(idr_alloc(idr, &name[i], 0, 0, GFP_KERNEL) != 0);
+		BUG_ON(idr_alloc(idr, &name[i + 1], 0, 0, GFP_KERNEL) != 1);
+		idr_for_each_entry(idr, entry, id);
+		idr_remove(idr, 1);
+		idr_for_each_entry(idr, entry, id);
+		idr_remove(idr, 0);
+		BUG_ON(!idr_is_empty(idr));
+	}
+
+	for (i = 0; i < 8; i++) {
+		BUG_ON(idr_alloc(idr, NULL, 0, 0, GFP_KERNEL) != 0);
+		idr_for_each_entry(idr, entry, id);
+		idr_replace(idr, &name[i], 0);
+		idr_for_each_entry(idr, entry, id);
+		BUG_ON(idr_find(idr, 0) != &name[i]);
+		idr_remove(idr, 0);
+	}
+
+	for (i = 0; i < 8; i++) {
+		BUG_ON(idr_alloc(idr, &name[i], 0, 0, GFP_KERNEL) != 0);
+		BUG_ON(idr_alloc(idr, NULL, 0, 0, GFP_KERNEL) != 1);
+		idr_remove(idr, 1);
+		idr_for_each_entry(idr, entry, id);
+		idr_replace(idr, &name[i + 1], 0);
+		idr_for_each_entry(idr, entry, id);
+		idr_remove(idr, 0);
+	}
+}
+
 void idr_checks(void)
 {
 	unsigned long i;
@@ -307,6 +367,7 @@ void idr_checks(void)
 	idr_u32_test(4);
 	idr_u32_test(1);
 	idr_u32_test(0);
+	idr_align_test(&idr);
 }
 
 #define module_init(x)
@@ -341,6 +402,7 @@ void ida_check_nomem(void)
  */
 void ida_check_conv_user(void)
 {
+#if 0
 	DEFINE_IDA(ida);
 	unsigned long i;
 
@@ -358,6 +420,7 @@ void ida_check_conv_user(void)
 		IDA_BUG_ON(&ida, id != i);
 	}
 	ida_destroy(&ida);
+#endif
 }
 
 void ida_check_random(void)