diff --git a/include/linux/xarray.h b/include/linux/xarray.h
index f492e21c4aa2c81f5301116408f84d0703a61a18..4cf3cd12868938835a384bd1bb49a2ad6f5cb6b7 100644
--- a/include/linux/xarray.h
+++ b/include/linux/xarray.h
@@ -286,7 +286,6 @@ struct xarray {
  */
 #define DEFINE_XARRAY_ALLOC(name) DEFINE_XARRAY_FLAGS(name, XA_FLAGS_ALLOC)
 
-void xa_init_flags(struct xarray *, gfp_t flags);
 void *xa_load(struct xarray *, unsigned long index);
 void *xa_store(struct xarray *, unsigned long index, void *entry, gfp_t);
 void *xa_erase(struct xarray *, unsigned long index);
@@ -303,6 +302,24 @@ unsigned int xa_extract(struct xarray *, void **dst, unsigned long start,
 		unsigned long max, unsigned int n, xa_mark_t);
 void xa_destroy(struct xarray *);
 
+/**
+ * xa_init_flags() - Initialise an empty XArray with flags.
+ * @xa: XArray.
+ * @flags: XA_FLAG values.
+ *
+ * If you need to initialise an XArray with special flags (eg you need
+ * to take the lock from interrupt context), use this function instead
+ * of xa_init().
+ *
+ * Context: Any context.
+ */
+static inline void xa_init_flags(struct xarray *xa, gfp_t flags)
+{
+	spin_lock_init(&xa->xa_lock);
+	xa->xa_flags = flags;
+	xa->xa_head = NULL;
+}
+
 /**
  * xa_init() - Initialise an empty XArray.
  * @xa: XArray.
diff --git a/lib/xarray.c b/lib/xarray.c
index 5f3f9311de893a2975990060f5dfae6a6fb3d462..dda6026d202ec2174e1c459b4f24be67b827fa74 100644
--- a/lib/xarray.c
+++ b/lib/xarray.c
@@ -1250,35 +1250,6 @@ void *xas_find_conflict(struct xa_state *xas)
 }
 EXPORT_SYMBOL_GPL(xas_find_conflict);
 
-/**
- * xa_init_flags() - Initialise an empty XArray with flags.
- * @xa: XArray.
- * @flags: XA_FLAG values.
- *
- * If you need to initialise an XArray with special flags (eg you need
- * to take the lock from interrupt context), use this function instead
- * of xa_init().
- *
- * Context: Any context.
- */
-void xa_init_flags(struct xarray *xa, gfp_t flags)
-{
-	unsigned int lock_type;
-	static struct lock_class_key xa_lock_irq;
-	static struct lock_class_key xa_lock_bh;
-
-	spin_lock_init(&xa->xa_lock);
-	xa->xa_flags = flags;
-	xa->xa_head = NULL;
-
-	lock_type = xa_lock_type(xa);
-	if (lock_type == XA_LOCK_IRQ)
-		lockdep_set_class(&xa->xa_lock, &xa_lock_irq);
-	else if (lock_type == XA_LOCK_BH)
-		lockdep_set_class(&xa->xa_lock, &xa_lock_bh);
-}
-EXPORT_SYMBOL(xa_init_flags);
-
 /**
  * xa_load() - Load an entry from an XArray.
  * @xa: XArray.