diff --git a/lib/sbitmap.c b/lib/sbitmap.c
index 2261136ae0675e4a90b7ff4125c3441c00a03ebb..5b3e56d68dab8863eadf173b9dcd48452a7b5a87 100644
--- a/lib/sbitmap.c
+++ b/lib/sbitmap.c
@@ -20,6 +20,47 @@
 #include <linux/sbitmap.h>
 #include <linux/seq_file.h>
 
+/*
+ * See if we have deferred clears that we can batch move
+ */
+static inline bool sbitmap_deferred_clear(struct sbitmap *sb, int index)
+{
+	unsigned long mask, val;
+	unsigned long __maybe_unused flags;
+	bool ret = false;
+
+	/* Silence bogus lockdep warning */
+#if defined(CONFIG_LOCKDEP)
+	local_irq_save(flags);
+#endif
+	spin_lock(&sb->map[index].swap_lock);
+
+	if (!sb->map[index].cleared)
+		goto out_unlock;
+
+	/*
+	 * First get a stable cleared mask, setting the old mask to 0.
+	 */
+	do {
+		mask = sb->map[index].cleared;
+	} while (cmpxchg(&sb->map[index].cleared, mask, 0) != mask);
+
+	/*
+	 * Now clear the masked bits in our free word
+	 */
+	do {
+		val = sb->map[index].word;
+	} while (cmpxchg(&sb->map[index].word, val, val & ~mask) != val);
+
+	ret = true;
+out_unlock:
+	spin_unlock(&sb->map[index].swap_lock);
+#if defined(CONFIG_LOCKDEP)
+	local_irq_restore(flags);
+#endif
+	return ret;
+}
+
 int sbitmap_init_node(struct sbitmap *sb, unsigned int depth, int shift,
 		      gfp_t flags, int node)
 {
@@ -70,6 +111,9 @@ void sbitmap_resize(struct sbitmap *sb, unsigned int depth)
 	unsigned int bits_per_word = 1U << sb->shift;
 	unsigned int i;
 
+	for (i = 0; i < sb->map_nr; i++)
+		sbitmap_deferred_clear(sb, i);
+
 	sb->depth = depth;
 	sb->map_nr = DIV_ROUND_UP(sb->depth, bits_per_word);
 
@@ -112,47 +156,6 @@ static int __sbitmap_get_word(unsigned long *word, unsigned long depth,
 	return nr;
 }
 
-/*
- * See if we have deferred clears that we can batch move
- */
-static inline bool sbitmap_deferred_clear(struct sbitmap *sb, int index)
-{
-	unsigned long mask, val;
-	unsigned long __maybe_unused flags;
-	bool ret = false;
-
-	/* Silence bogus lockdep warning */
-#if defined(CONFIG_LOCKDEP)
-	local_irq_save(flags);
-#endif
-	spin_lock(&sb->map[index].swap_lock);
-
-	if (!sb->map[index].cleared)
-		goto out_unlock;
-
-	/*
-	 * First get a stable cleared mask, setting the old mask to 0.
-	 */
-	do {
-		mask = sb->map[index].cleared;
-	} while (cmpxchg(&sb->map[index].cleared, mask, 0) != mask);
-
-	/*
-	 * Now clear the masked bits in our free word
-	 */
-	do {
-		val = sb->map[index].word;
-	} while (cmpxchg(&sb->map[index].word, val, val & ~mask) != val);
-
-	ret = true;
-out_unlock:
-	spin_unlock(&sb->map[index].swap_lock);
-#if defined(CONFIG_LOCKDEP)
-	local_irq_restore(flags);
-#endif
-	return ret;
-}
-
 static int sbitmap_find_bit_in_index(struct sbitmap *sb, int index,
 				     unsigned int alloc_hint, bool round_robin)
 {
@@ -215,6 +218,7 @@ int sbitmap_get_shallow(struct sbitmap *sb, unsigned int alloc_hint,
 	index = SB_NR_TO_INDEX(sb, alloc_hint);
 
 	for (i = 0; i < sb->map_nr; i++) {
+again:
 		nr = __sbitmap_get_word(&sb->map[index].word,
 					min(sb->map[index].depth, shallow_depth),
 					SB_NR_TO_BIT(sb, alloc_hint), true);
@@ -223,6 +227,9 @@ int sbitmap_get_shallow(struct sbitmap *sb, unsigned int alloc_hint,
 			break;
 		}
 
+		if (sbitmap_deferred_clear(sb, index))
+			goto again;
+
 		/* Jump to next index. */
 		index++;
 		alloc_hint = index << sb->shift;
@@ -242,7 +249,7 @@ bool sbitmap_any_bit_set(const struct sbitmap *sb)
 	unsigned int i;
 
 	for (i = 0; i < sb->map_nr; i++) {
-		if (sb->map[i].word)
+		if (sb->map[i].word & ~sb->map[i].cleared)
 			return true;
 	}
 	return false;
@@ -255,9 +262,10 @@ bool sbitmap_any_bit_clear(const struct sbitmap *sb)
 
 	for (i = 0; i < sb->map_nr; i++) {
 		const struct sbitmap_word *word = &sb->map[i];
+		unsigned long mask = word->word & ~word->cleared;
 		unsigned long ret;
 
-		ret = find_first_zero_bit(&word->word, word->depth);
+		ret = find_first_zero_bit(&mask, word->depth);
 		if (ret < word->depth)
 			return true;
 	}