diff --git a/drivers/net/ethernet/mellanox/mlxsw/spectrum.c b/drivers/net/ethernet/mellanox/mlxsw/spectrum.c
index f276c453885f2116228d427008885ac3f905b585..30fe0d2e31a6eb33d7541db4b81c8a620e4206b9 100644
--- a/drivers/net/ethernet/mellanox/mlxsw/spectrum.c
+++ b/drivers/net/ethernet/mellanox/mlxsw/spectrum.c
@@ -2800,11 +2800,19 @@ mlxsw_sp_port_pvid_vport_lag_join(struct mlxsw_sp_port *mlxsw_sp_port,
 				  u16 lag_id)
 {
 	struct mlxsw_sp_port *mlxsw_sp_vport;
+	struct mlxsw_sp_fid *f;
 
 	mlxsw_sp_vport = mlxsw_sp_port_vport_find(mlxsw_sp_port, 1);
 	if (WARN_ON(!mlxsw_sp_vport))
 		return;
 
+	/* If vPort is assigned a RIF, then leave it since it's no
+	 * longer valid.
+	 */
+	f = mlxsw_sp_vport_fid_get(mlxsw_sp_vport);
+	if (f)
+		f->leave(mlxsw_sp_vport);
+
 	mlxsw_sp_vport->lag_id = lag_id;
 	mlxsw_sp_vport->lagged = 1;
 }
@@ -2813,11 +2821,16 @@ static void
 mlxsw_sp_port_pvid_vport_lag_leave(struct mlxsw_sp_port *mlxsw_sp_port)
 {
 	struct mlxsw_sp_port *mlxsw_sp_vport;
+	struct mlxsw_sp_fid *f;
 
 	mlxsw_sp_vport = mlxsw_sp_port_vport_find(mlxsw_sp_port, 1);
 	if (WARN_ON(!mlxsw_sp_vport))
 		return;
 
+	f = mlxsw_sp_vport_fid_get(mlxsw_sp_vport);
+	if (f)
+		f->leave(mlxsw_sp_vport);
+
 	mlxsw_sp_vport->lagged = 0;
 }