diff --git a/drivers/md/md.c b/drivers/md/md.c
index a114b05e3db48372fb7ecf512c45d8902f2e1b3b..24638ccedce42b8d9fb346fc7a01ff8622f4c26d 100644
--- a/drivers/md/md.c
+++ b/drivers/md/md.c
@@ -5316,7 +5316,8 @@ int mddev_init_writes_pending(struct mddev *mddev)
 {
 	if (mddev->writes_pending.percpu_count_ptr)
 		return 0;
-	if (percpu_ref_init(&mddev->writes_pending, no_op, 0, GFP_KERNEL) < 0)
+	if (percpu_ref_init(&mddev->writes_pending, no_op,
+			    PERCPU_REF_ALLOW_REINIT, GFP_KERNEL) < 0)
 		return -ENOMEM;
 	/* We want to start with the refcount at zero */
 	percpu_ref_put(&mddev->writes_pending);
diff --git a/fs/io_uring.c b/fs/io_uring.c
index 3fd884b4e0bec41b72406cb36a49ff05a5cde090..d682049c07b2c0b6daa78bccaa0da81b2df4811b 100644
--- a/fs/io_uring.c
+++ b/fs/io_uring.c
@@ -399,7 +399,8 @@ static struct io_ring_ctx *io_ring_ctx_alloc(struct io_uring_params *p)
 	if (!ctx)
 		return NULL;
 
-	if (percpu_ref_init(&ctx->refs, io_ring_ctx_ref_free, 0, GFP_KERNEL)) {
+	if (percpu_ref_init(&ctx->refs, io_ring_ctx_ref_free,
+			    PERCPU_REF_ALLOW_REINIT, GFP_KERNEL)) {
 		kfree(ctx);
 		return NULL;
 	}
diff --git a/include/linux/percpu-refcount.h b/include/linux/percpu-refcount.h
index b297cd1cd4f190ccb36a75d73ec2d56a3300c2c0..7aef0abc194a2acabf47333bf10150053d681766 100644
--- a/include/linux/percpu-refcount.h
+++ b/include/linux/percpu-refcount.h
@@ -75,14 +75,21 @@ enum {
 	 * operation using percpu_ref_switch_to_percpu().  If initialized
 	 * with this flag, the ref will stay in atomic mode until
 	 * percpu_ref_switch_to_percpu() is invoked on it.
+	 * Implies ALLOW_REINIT.
 	 */
 	PERCPU_REF_INIT_ATOMIC	= 1 << 0,
 
 	/*
 	 * Start dead w/ ref == 0 in atomic mode.  Must be revived with
-	 * percpu_ref_reinit() before used.  Implies INIT_ATOMIC.
+	 * percpu_ref_reinit() before used.  Implies INIT_ATOMIC and
+	 * ALLOW_REINIT.
 	 */
 	PERCPU_REF_INIT_DEAD	= 1 << 1,
+
+	/*
+	 * Allow switching from atomic mode to percpu mode.
+	 */
+	PERCPU_REF_ALLOW_REINIT	= 1 << 2,
 };
 
 struct percpu_ref {
@@ -95,6 +102,7 @@ struct percpu_ref {
 	percpu_ref_func_t	*release;
 	percpu_ref_func_t	*confirm_switch;
 	bool			force_atomic:1;
+	bool			allow_reinit:1;
 	struct rcu_head		rcu;
 };
 
diff --git a/lib/percpu-refcount.c b/lib/percpu-refcount.c
index 071a76c7bac079d421840ad4c708c5e66d02ea8c..4f6c6ebbbbdea8c311d30eda388b29c2f02893b6 100644
--- a/lib/percpu-refcount.c
+++ b/lib/percpu-refcount.c
@@ -70,11 +70,14 @@ int percpu_ref_init(struct percpu_ref *ref, percpu_ref_func_t *release,
 		return -ENOMEM;
 
 	ref->force_atomic = flags & PERCPU_REF_INIT_ATOMIC;
+	ref->allow_reinit = flags & PERCPU_REF_ALLOW_REINIT;
 
-	if (flags & (PERCPU_REF_INIT_ATOMIC | PERCPU_REF_INIT_DEAD))
+	if (flags & (PERCPU_REF_INIT_ATOMIC | PERCPU_REF_INIT_DEAD)) {
 		ref->percpu_count_ptr |= __PERCPU_REF_ATOMIC;
-	else
+		ref->allow_reinit = true;
+	} else {
 		start_count += PERCPU_COUNT_BIAS;
+	}
 
 	if (flags & PERCPU_REF_INIT_DEAD)
 		ref->percpu_count_ptr |= __PERCPU_REF_DEAD;
@@ -120,6 +123,9 @@ static void percpu_ref_call_confirm_rcu(struct rcu_head *rcu)
 	ref->confirm_switch = NULL;
 	wake_up_all(&percpu_ref_switch_waitq);
 
+	if (!ref->allow_reinit)
+		percpu_ref_exit(ref);
+
 	/* drop ref from percpu_ref_switch_to_atomic() */
 	percpu_ref_put(ref);
 }
@@ -195,6 +201,9 @@ static void __percpu_ref_switch_to_percpu(struct percpu_ref *ref)
 	if (!(ref->percpu_count_ptr & __PERCPU_REF_ATOMIC))
 		return;
 
+	if (WARN_ON_ONCE(!ref->allow_reinit))
+		return;
+
 	atomic_long_add(PERCPU_COUNT_BIAS, &ref->count);
 
 	/*