diff --git a/drivers/infiniband/core/cache.c b/drivers/infiniband/core/cache.c
index a28dc1901c800092fdfbd9e278464abe4070c782..43c67e5f43c643a8228dea355031ad98287a8bdd 100644
--- a/drivers/infiniband/core/cache.c
+++ b/drivers/infiniband/core/cache.c
@@ -547,21 +547,19 @@ int ib_cache_gid_add(struct ib_device *ib_dev, u8 port,
 	unsigned long mask;
 	int ret;
 
-	if (ib_dev->ops.get_netdev) {
-		idev = ib_dev->ops.get_netdev(ib_dev, port);
-		if (idev && attr->ndev != idev) {
-			union ib_gid default_gid;
-
-			/* Adding default GIDs in not permitted */
-			make_default_gid(idev, &default_gid);
-			if (!memcmp(gid, &default_gid, sizeof(*gid))) {
-				dev_put(idev);
-				return -EPERM;
-			}
-		}
-		if (idev)
+	idev = ib_device_get_netdev(ib_dev, port);
+	if (idev && attr->ndev != idev) {
+		union ib_gid default_gid;
+
+		/* Adding default GIDs is not permitted */
+		make_default_gid(idev, &default_gid);
+		if (!memcmp(gid, &default_gid, sizeof(*gid))) {
 			dev_put(idev);
+			return -EPERM;
+		}
 	}
+	if (idev)
+		dev_put(idev);
 
 	mask = GID_ATTR_FIND_MASK_GID |
 	       GID_ATTR_FIND_MASK_GID_TYPE |
diff --git a/drivers/infiniband/core/core_priv.h b/drivers/infiniband/core/core_priv.h
index eeabe9ca8427d7f3f19fc845e801eef097ac848b..08c69024959489f6459b291ecb558c0f7ff14e18 100644
--- a/drivers/infiniband/core/core_priv.h
+++ b/drivers/infiniband/core/core_priv.h
@@ -66,6 +66,9 @@ typedef void (*roce_netdev_callback)(struct ib_device *device, u8 port,
 typedef bool (*roce_netdev_filter)(struct ib_device *device, u8 port,
 				   struct net_device *idev, void *cookie);
 
+struct net_device *ib_device_get_netdev(struct ib_device *ib_dev,
+					unsigned int port);
+
 void ib_enum_roce_netdev(struct ib_device *ib_dev,
 			 roce_netdev_filter filter,
 			 void *filter_cookie,
diff --git a/drivers/infiniband/core/device.c b/drivers/infiniband/core/device.c
index 8d7d63a60ef5157a9b0727e45ab8b3de75154447..7680a64a98bc540aa4f0132c29d6563b4899860f 100644
--- a/drivers/infiniband/core/device.c
+++ b/drivers/infiniband/core/device.c
@@ -134,6 +134,7 @@ static void *xan_find_marked(struct xarray *xa, unsigned long *indexp,
 	     !xa_is_err(entry);                                                \
 	     (index)++, entry = xan_find_marked(xa, &(index), filter))
 
+static void free_netdevs(struct ib_device *ib_dev);
 static int ib_security_change(struct notifier_block *nb, unsigned long event,
 			      void *lsm_data);
 static void ib_policy_change_task(struct work_struct *work);
@@ -290,6 +291,7 @@ static void ib_device_release(struct device *device)
 {
 	struct ib_device *dev = container_of(device, struct ib_device, dev);
 
+	free_netdevs(dev);
 	WARN_ON(refcount_read(&dev->refcount));
 	ib_cache_release_one(dev);
 	ib_security_release_port_pkey_list(dev);
@@ -371,6 +373,9 @@ EXPORT_SYMBOL(_ib_alloc_device);
  */
 void ib_dealloc_device(struct ib_device *device)
 {
+	/* Expedite releasing netdev references */
+	free_netdevs(device);
+
 	WARN_ON(!xa_empty(&device->client_data));
 	WARN_ON(refcount_read(&device->refcount));
 	rdma_restrack_clean(device);
@@ -461,16 +466,16 @@ static void remove_client_context(struct ib_device *device,
 	up_read(&device->client_data_rwsem);
 }
 
-static int verify_immutable(const struct ib_device *dev, u8 port)
-{
-	return WARN_ON(!rdma_cap_ib_mad(dev, port) &&
-			    rdma_max_mad_size(dev, port) != 0);
-}
-
-static int setup_port_data(struct ib_device *device)
+static int alloc_port_data(struct ib_device *device)
 {
 	unsigned int port;
-	int ret;
+
+	if (device->port_data)
+		return 0;
+
+	/* This can only be called once the physical port range is defined */
+	if (WARN_ON(!device->phys_port_cnt))
+		return -EINVAL;
 
 	/*
 	 * device->port_data is indexed directly by the port number to make
@@ -489,6 +494,28 @@ static int setup_port_data(struct ib_device *device)
 
 		spin_lock_init(&pdata->pkey_list_lock);
 		INIT_LIST_HEAD(&pdata->pkey_list);
+		spin_lock_init(&pdata->netdev_lock);
+	}
+	return 0;
+}
+
+static int verify_immutable(const struct ib_device *dev, u8 port)
+{
+	return WARN_ON(!rdma_cap_ib_mad(dev, port) &&
+			    rdma_max_mad_size(dev, port) != 0);
+}
+
+static int setup_port_data(struct ib_device *device)
+{
+	unsigned int port;
+	int ret;
+
+	ret = alloc_port_data(device);
+	if (ret)
+		return ret;
+
+	rdma_for_each_port (device, port) {
+		struct ib_port_data *pdata = &device->port_data[port];
 
 		ret = device->ops.get_port_immutable(device, port,
 						     &pdata->immutable);
@@ -682,6 +709,9 @@ static void disable_device(struct ib_device *device)
 	/* Pairs with refcount_set in enable_device */
 	ib_device_put(device);
 	wait_for_completion(&device->unreg_completion);
+
+	/* Expedite removing unregistered pointers from the hash table */
+	free_netdevs(device);
 }
 
 /*
@@ -1012,6 +1042,114 @@ int ib_query_port(struct ib_device *device,
 }
 EXPORT_SYMBOL(ib_query_port);
 
+/**
+ * ib_device_set_netdev - Associate the ib_dev with an underlying net_device
+ * @ib_dev: Device to modify
+ * @ndev: net_device to affiliate, may be NULL
+ * @port: IB port the net_device is connected to
+ *
+ * Drivers should use this to link the ib_device to a netdev so the netdev
+ * shows up in interfaces like ib_enum_roce_netdev. Only one netdev may be
+ * affiliated with any port.
+ *
+ * The caller must ensure that the given ndev is not unregistered or
+ * unregistering, and that either the ib_device is unregistered or
+ * ib_device_set_netdev() is called with NULL when the ndev sends a
+ * NETDEV_UNREGISTER event.
+ */
+int ib_device_set_netdev(struct ib_device *ib_dev, struct net_device *ndev,
+			 unsigned int port)
+{
+	struct net_device *old_ndev;
+	struct ib_port_data *pdata;
+	unsigned long flags;
+	int ret;
+
+	/*
+	 * Drivers wish to call this before ib_register_driver, so we have to
+	 * setup the port data early.
+	 */
+	ret = alloc_port_data(ib_dev);
+	if (ret)
+		return ret;
+
+	if (!rdma_is_port_valid(ib_dev, port))
+		return -EINVAL;
+
+	pdata = &ib_dev->port_data[port];
+	spin_lock_irqsave(&pdata->netdev_lock, flags);
+	if (pdata->netdev == ndev) {
+		spin_unlock_irqrestore(&pdata->netdev_lock, flags);
+		return 0;
+	}
+	old_ndev = pdata->netdev;
+
+	if (ndev)
+		dev_hold(ndev);
+	pdata->netdev = ndev;
+	spin_unlock_irqrestore(&pdata->netdev_lock, flags);
+
+	if (old_ndev)
+		dev_put(old_ndev);
+
+	return 0;
+}
+EXPORT_SYMBOL(ib_device_set_netdev);
+
+static void free_netdevs(struct ib_device *ib_dev)
+{
+	unsigned long flags;
+	unsigned int port;
+
+	rdma_for_each_port (ib_dev, port) {
+		struct ib_port_data *pdata = &ib_dev->port_data[port];
+
+		spin_lock_irqsave(&pdata->netdev_lock, flags);
+		if (pdata->netdev) {
+			dev_put(pdata->netdev);
+			pdata->netdev = NULL;
+		}
+		spin_unlock_irqrestore(&pdata->netdev_lock, flags);
+	}
+}
+
+struct net_device *ib_device_get_netdev(struct ib_device *ib_dev,
+					unsigned int port)
+{
+	struct ib_port_data *pdata;
+	struct net_device *res;
+
+	if (!rdma_is_port_valid(ib_dev, port))
+		return NULL;
+
+	pdata = &ib_dev->port_data[port];
+
+	/*
+	 * New drivers should use ib_device_set_netdev() not the legacy
+	 * get_netdev().
+	 */
+	if (ib_dev->ops.get_netdev)
+		res = ib_dev->ops.get_netdev(ib_dev, port);
+	else {
+		spin_lock(&pdata->netdev_lock);
+		res = pdata->netdev;
+		if (res)
+			dev_hold(res);
+		spin_unlock(&pdata->netdev_lock);
+	}
+
+	/*
+	 * If we are starting to unregister expedite things by preventing
+	 * propagation of an unregistering netdev.
+	 */
+	if (res && res->reg_state != NETREG_REGISTERED) {
+		dev_put(res);
+		return NULL;
+	}
+
+	return res;
+}
+
 /**
  * ib_enum_roce_netdev - enumerate all RoCE ports
  * @ib_dev : IB device we want to query
@@ -1034,16 +1172,8 @@ void ib_enum_roce_netdev(struct ib_device *ib_dev,
 
 	rdma_for_each_port (ib_dev, port)
 		if (rdma_protocol_roce(ib_dev, port)) {
-			struct net_device *idev = NULL;
-
-			if (ib_dev->ops.get_netdev)
-				idev = ib_dev->ops.get_netdev(ib_dev, port);
-
-			if (idev &&
-			    idev->reg_state >= NETREG_UNREGISTERED) {
-				dev_put(idev);
-				idev = NULL;
-			}
+			struct net_device *idev =
+				ib_device_get_netdev(ib_dev, port);
 
 			if (filter(ib_dev, port, idev, filter_cookie))
 				cb(ib_dev, port, idev, cookie);
diff --git a/drivers/infiniband/core/nldev.c b/drivers/infiniband/core/nldev.c
index 85f6f2bcce4010c9101bb8c2ae6fd97130416779..1980ddc5f7bc6ea69f5b1bd0fd258180e3571328 100644
--- a/drivers/infiniband/core/nldev.c
+++ b/drivers/infiniband/core/nldev.c
@@ -268,9 +268,7 @@ static int fill_port_info(struct sk_buff *msg,
 	if (nla_put_u8(msg, RDMA_NLDEV_ATTR_PORT_PHYS_STATE, attr.phys_state))
 		return -EMSGSIZE;
 
-	if (device->ops.get_netdev)
-		netdev = device->ops.get_netdev(device, port);
-
+	netdev = ib_device_get_netdev(device, port);
 	if (netdev && net_eq(dev_net(netdev), net)) {
 		ret = nla_put_u32(msg,
 				  RDMA_NLDEV_ATTR_NDEV_INDEX, netdev->ifindex);
diff --git a/drivers/infiniband/core/verbs.c b/drivers/infiniband/core/verbs.c
index de5d895a50544ceeab721331e15a3e04d46d21e4..5a5e83f5f0fc4c8638f2c17efc455ae732757d45 100644
--- a/drivers/infiniband/core/verbs.c
+++ b/drivers/infiniband/core/verbs.c
@@ -1723,10 +1723,7 @@ int ib_get_eth_speed(struct ib_device *dev, u8 port_num, u8 *speed, u8 *width)
 	if (rdma_port_get_link_layer(dev, port_num) != IB_LINK_LAYER_ETHERNET)
 		return -EINVAL;
 
-	if (!dev->ops.get_netdev)
-		return -EOPNOTSUPP;
-
-	netdev = dev->ops.get_netdev(dev, port_num);
+	netdev = ib_device_get_netdev(dev, port_num);
 	if (!netdev)
 		return -ENODEV;
 
diff --git a/include/rdma/ib_verbs.h b/include/rdma/ib_verbs.h
index 50b7ebc2885e9c549586520825db04e954d6507f..7f81a313c01bcf834f604921faa17dfb6a7e74d7 100644
--- a/include/rdma/ib_verbs.h
+++ b/include/rdma/ib_verbs.h
@@ -2204,6 +2204,9 @@ struct ib_port_data {
 	struct list_head pkey_list;
 
 	struct ib_port_cache cache;
+
+	spinlock_t netdev_lock;
+	struct net_device *netdev;
 };
 
 /* rdma netdev type - specifies protocol type */
@@ -3996,6 +3999,10 @@ void ib_device_put(struct ib_device *device);
 struct net_device *ib_get_net_dev_by_params(struct ib_device *dev, u8 port,
 					    u16 pkey, const union ib_gid *gid,
 					    const struct sockaddr *addr);
+int ib_device_set_netdev(struct ib_device *ib_dev, struct net_device *ndev,
+			 unsigned int port);
+struct net_device *ib_device_netdev(struct ib_device *dev, u8 port);
+
 struct ib_wq *ib_create_wq(struct ib_pd *pd,
 			   struct ib_wq_init_attr *init_attr);
 int ib_destroy_wq(struct ib_wq *wq);