diff --git a/drivers/net/ethernet/mellanox/mlxsw/spectrum.c b/drivers/net/ethernet/mellanox/mlxsw/spectrum.c
index 166be18541117ae6ee22ab076d5a8c97fcd4d7b2..2f0e14974a0868581f9c32082e701e8b5c2ed492 100644
--- a/drivers/net/ethernet/mellanox/mlxsw/spectrum.c
+++ b/drivers/net/ethernet/mellanox/mlxsw/spectrum.c
@@ -210,6 +210,41 @@ static void mlxsw_sp_txhdr_construct(struct sk_buff *skb,
 	mlxsw_tx_hdr_type_set(txhdr, MLXSW_TXHDR_TYPE_CONTROL);
 }
 
+int mlxsw_sp_port_vid_stp_set(struct mlxsw_sp_port *mlxsw_sp_port, u16 vid,
+			      u8 state)
+{
+	struct mlxsw_sp *mlxsw_sp = mlxsw_sp_port->mlxsw_sp;
+	enum mlxsw_reg_spms_state spms_state;
+	char *spms_pl;
+	int err;
+
+	switch (state) {
+	case BR_STATE_FORWARDING:
+		spms_state = MLXSW_REG_SPMS_STATE_FORWARDING;
+		break;
+	case BR_STATE_LEARNING:
+		spms_state = MLXSW_REG_SPMS_STATE_LEARNING;
+		break;
+	case BR_STATE_LISTENING: /* fall-through */
+	case BR_STATE_DISABLED: /* fall-through */
+	case BR_STATE_BLOCKING:
+		spms_state = MLXSW_REG_SPMS_STATE_DISCARDING;
+		break;
+	default:
+		BUG();
+	}
+
+	spms_pl = kmalloc(MLXSW_REG_SPMS_LEN, GFP_KERNEL);
+	if (!spms_pl)
+		return -ENOMEM;
+	mlxsw_reg_spms_pack(spms_pl, mlxsw_sp_port->local_port);
+	mlxsw_reg_spms_vid_pack(spms_pl, vid, spms_state);
+
+	err = mlxsw_reg_write(mlxsw_sp->core, MLXSW_REG(spms), spms_pl);
+	kfree(spms_pl);
+	return err;
+}
+
 static int mlxsw_sp_base_mac_get(struct mlxsw_sp *mlxsw_sp)
 {
 	char spad_pl[MLXSW_REG_SPAD_LEN] = {0};
@@ -649,8 +684,8 @@ int __mlxsw_sp_port_vid_learning_set(struct mlxsw_sp_port *mlxsw_sp_port,
 	return err;
 }
 
-static int mlxsw_sp_port_vid_learning_set(struct mlxsw_sp_port *mlxsw_sp_port,
-					  u16 vid, bool learn_enable)
+int mlxsw_sp_port_vid_learning_set(struct mlxsw_sp_port *mlxsw_sp_port, u16 vid,
+				   bool learn_enable)
 {
 	return __mlxsw_sp_port_vid_learning_set(mlxsw_sp_port, vid, vid,
 						learn_enable);
diff --git a/drivers/net/ethernet/mellanox/mlxsw/spectrum.h b/drivers/net/ethernet/mellanox/mlxsw/spectrum.h
index 29db77f8bb6f993c709383c487d03bb881f27c53..d96e9126262e99c500edca50eb4529f81834f8d2 100644
--- a/drivers/net/ethernet/mellanox/mlxsw/spectrum.h
+++ b/drivers/net/ethernet/mellanox/mlxsw/spectrum.h
@@ -445,6 +445,10 @@ int mlxsw_sp_port_ets_maxrate_set(struct mlxsw_sp_port *mlxsw_sp_port,
 int __mlxsw_sp_port_vid_learning_set(struct mlxsw_sp_port *mlxsw_sp_port,
 				     u16 vid_begin, u16 vid_end,
 				     bool learn_enable);
+int mlxsw_sp_port_vid_stp_set(struct mlxsw_sp_port *mlxsw_sp_port, u16 vid,
+			      u8 state);
+int mlxsw_sp_port_vid_learning_set(struct mlxsw_sp_port *mlxsw_sp_port, u16 vid,
+				   bool learn_enable);
 
 #ifdef CONFIG_MLXSW_SPECTRUM_DCB
 
diff --git a/drivers/net/ethernet/mellanox/mlxsw/spectrum_switchdev.c b/drivers/net/ethernet/mellanox/mlxsw/spectrum_switchdev.c
index d9393f7ff79f93b762da4e34a5fe3a931214eea6..8a31bf9013f2eed98e55343559dc513d57c58676 100644
--- a/drivers/net/ethernet/mellanox/mlxsw/spectrum_switchdev.c
+++ b/drivers/net/ethernet/mellanox/mlxsw/spectrum_switchdev.c
@@ -650,61 +650,44 @@ static int mlxsw_sp_port_fid_map(struct mlxsw_sp_port *mlxsw_sp_port, u16 fid,
 	return mlxsw_sp_port_vid_to_fid_set(mlxsw_sp_port, mt, valid, fid, fid);
 }
 
-static int mlxsw_sp_port_fid_join(struct mlxsw_sp_port *mlxsw_sp_port,
-				  u16 fid_begin, u16 fid_end)
+static int mlxsw_sp_port_fid_join(struct mlxsw_sp_port *mlxsw_sp_port, u16 fid)
 {
 	bool mc_flood;
-	int fid, err;
+	int err;
 
-	for (fid = fid_begin; fid <= fid_end; fid++) {
-		err = __mlxsw_sp_port_fid_join(mlxsw_sp_port, fid);
-		if (err)
-			goto err_port_fid_join;
-	}
+	err = __mlxsw_sp_port_fid_join(mlxsw_sp_port, fid);
+	if (err)
+		return err;
 
 	mc_flood = mlxsw_sp_port->mc_disabled ?
 			mlxsw_sp_port->mc_flood : mlxsw_sp_port->mc_router;
 
-	err = __mlxsw_sp_port_flood_set(mlxsw_sp_port, fid_begin, fid_end,
+	err = __mlxsw_sp_port_flood_set(mlxsw_sp_port, fid, fid,
 					mlxsw_sp_port->uc_flood, true,
 					mc_flood);
 	if (err)
 		goto err_port_flood_set;
 
-	for (fid = fid_begin; fid <= fid_end; fid++) {
-		err = mlxsw_sp_port_fid_map(mlxsw_sp_port, fid, true);
-		if (err)
-			goto err_port_fid_map;
-	}
+	err = mlxsw_sp_port_fid_map(mlxsw_sp_port, fid, true);
+	if (err)
+		goto err_port_fid_map;
 
 	return 0;
 
 err_port_fid_map:
-	for (fid--; fid >= fid_begin; fid--)
-		mlxsw_sp_port_fid_map(mlxsw_sp_port, fid, false);
-	__mlxsw_sp_port_flood_set(mlxsw_sp_port, fid_begin, fid_end, false,
-				  false, false);
+	__mlxsw_sp_port_flood_set(mlxsw_sp_port, fid, fid, false, false, false);
 err_port_flood_set:
-	fid = fid_end;
-err_port_fid_join:
-	for (fid--; fid >= fid_begin; fid--)
-		__mlxsw_sp_port_fid_leave(mlxsw_sp_port, fid);
+	__mlxsw_sp_port_fid_leave(mlxsw_sp_port, fid);
 	return err;
 }
 
 static void mlxsw_sp_port_fid_leave(struct mlxsw_sp_port *mlxsw_sp_port,
-				    u16 fid_begin, u16 fid_end)
+				    u16 fid)
 {
-	int fid;
-
-	for (fid = fid_begin; fid <= fid_end; fid++)
-		mlxsw_sp_port_fid_map(mlxsw_sp_port, fid, false);
-
-	__mlxsw_sp_port_flood_set(mlxsw_sp_port, fid_begin, fid_end, false,
+	mlxsw_sp_port_fid_map(mlxsw_sp_port, fid, false);
+	__mlxsw_sp_port_flood_set(mlxsw_sp_port, fid, fid, false,
 				  false, false);
-
-	for (fid = fid_begin; fid <= fid_end; fid++)
-		__mlxsw_sp_port_fid_leave(mlxsw_sp_port, fid);
+	__mlxsw_sp_port_fid_leave(mlxsw_sp_port, fid);
 }
 
 static int __mlxsw_sp_port_pvid_set(struct mlxsw_sp_port *mlxsw_sp_port,
@@ -764,104 +747,64 @@ int mlxsw_sp_port_pvid_set(struct mlxsw_sp_port *mlxsw_sp_port, u16 vid)
 	return err;
 }
 
-static int mlxsw_sp_port_vid_learning_set(struct mlxsw_sp_port *mlxsw_sp_port,
-					  u16 vid_begin, u16 vid_end,
-					  bool learn_enable)
+static u16
+mlxsw_sp_port_pvid_determine(const struct mlxsw_sp_port *mlxsw_sp_port,
+			     u16 vid, bool is_pvid)
 {
-	u16 vid, vid_e;
-	int err;
-
-	for (vid = vid_begin; vid <= vid_end;
-	     vid += MLXSW_REG_SPVMLR_REC_MAX_COUNT) {
-		vid_e = min((u16) (vid + MLXSW_REG_SPVMLR_REC_MAX_COUNT - 1),
-			    vid_end);
-
-		err = __mlxsw_sp_port_vid_learning_set(mlxsw_sp_port, vid,
-						       vid_e, learn_enable);
-		if (err)
-			return err;
-	}
-
-	return 0;
+	if (is_pvid)
+		return vid;
+	else if (mlxsw_sp_port->pvid == vid)
+		return 0;	/* Dis-allow untagged packets */
+	else
+		return mlxsw_sp_port->pvid;
 }
 
-static int __mlxsw_sp_port_vlans_add(struct mlxsw_sp_port *mlxsw_sp_port,
-				     u16 vid_begin, u16 vid_end,
-				     bool flag_untagged, bool flag_pvid)
+static int mlxsw_sp_port_vlan_add(struct mlxsw_sp_port *mlxsw_sp_port, u16 vid,
+				  bool is_untagged, bool is_pvid)
 {
-	struct net_device *dev = mlxsw_sp_port->dev;
-	u16 vid, old_pvid;
+	u16 pvid = mlxsw_sp_port_pvid_determine(mlxsw_sp_port, vid, is_pvid);
+	u16 old_pvid = mlxsw_sp_port->pvid;
 	int err;
 
-	err = mlxsw_sp_port_fid_join(mlxsw_sp_port, vid_begin, vid_end);
-	if (err) {
-		netdev_err(dev, "Failed to join FIDs\n");
+	err = mlxsw_sp_port_fid_join(mlxsw_sp_port, vid);
+	if (err)
 		return err;
-	}
 
-	err = mlxsw_sp_port_vlan_set(mlxsw_sp_port, vid_begin, vid_end,
-				     true, flag_untagged);
-	if (err) {
-		netdev_err(dev, "Unable to add VIDs %d-%d\n", vid_begin,
-			   vid_end);
-		goto err_port_vlans_set;
-	}
+	err = mlxsw_sp_port_vlan_set(mlxsw_sp_port, vid, vid, true,
+				     is_untagged);
+	if (err)
+		goto err_port_vlan_set;
 
-	old_pvid = mlxsw_sp_port->pvid;
-	if (flag_pvid && old_pvid != vid_begin) {
-		err = mlxsw_sp_port_pvid_set(mlxsw_sp_port, vid_begin);
-		if (err) {
-			netdev_err(dev, "Unable to add PVID %d\n", vid_begin);
-			goto err_port_pvid_set;
-		}
-	} else if (!flag_pvid && old_pvid >= vid_begin && old_pvid <= vid_end) {
-		err = mlxsw_sp_port_pvid_set(mlxsw_sp_port, 0);
-		if (err) {
-			netdev_err(dev, "Unable to del PVID\n");
-			goto err_port_pvid_set;
-		}
-	}
+	err = mlxsw_sp_port_pvid_set(mlxsw_sp_port, pvid);
+	if (err)
+		goto err_port_pvid_set;
 
-	err = mlxsw_sp_port_vid_learning_set(mlxsw_sp_port, vid_begin, vid_end,
+	err = mlxsw_sp_port_vid_learning_set(mlxsw_sp_port, vid,
 					     mlxsw_sp_port->learning);
-	if (err) {
-		netdev_err(dev, "Failed to set learning for VIDs %d-%d\n",
-			   vid_begin, vid_end);
+	if (err)
 		goto err_port_vid_learning_set;
-	}
 
-	/* Changing activity bits only if HW operation succeded */
-	for (vid = vid_begin; vid <= vid_end; vid++) {
-		set_bit(vid, mlxsw_sp_port->active_vlans);
-		if (flag_untagged)
-			set_bit(vid, mlxsw_sp_port->untagged_vlans);
-		else
-			clear_bit(vid, mlxsw_sp_port->untagged_vlans);
-	}
+	err = mlxsw_sp_port_vid_stp_set(mlxsw_sp_port, vid,
+					mlxsw_sp_port->stp_state);
+	if (err)
+		goto err_port_vid_stp_set;
 
-	/* STP state change must be done after we set active VLANs */
-	err = mlxsw_sp_port_stp_state_set(mlxsw_sp_port,
-					  mlxsw_sp_port->stp_state);
-	if (err) {
-		netdev_err(dev, "Failed to set STP state\n");
-		goto err_port_stp_state_set;
-	}
+	if (is_untagged)
+		__set_bit(vid, mlxsw_sp_port->untagged_vlans);
+	else
+		__clear_bit(vid, mlxsw_sp_port->untagged_vlans);
+	__set_bit(vid, mlxsw_sp_port->active_vlans);
 
 	return 0;
 
-err_port_stp_state_set:
-	for (vid = vid_begin; vid <= vid_end; vid++)
-		clear_bit(vid, mlxsw_sp_port->active_vlans);
-	mlxsw_sp_port_vid_learning_set(mlxsw_sp_port, vid_begin, vid_end,
-				       false);
+err_port_vid_stp_set:
+	mlxsw_sp_port_vid_learning_set(mlxsw_sp_port, vid, false);
 err_port_vid_learning_set:
-	if (old_pvid != mlxsw_sp_port->pvid)
-		mlxsw_sp_port_pvid_set(mlxsw_sp_port, old_pvid);
+	mlxsw_sp_port_pvid_set(mlxsw_sp_port, old_pvid);
 err_port_pvid_set:
-	mlxsw_sp_port_vlan_set(mlxsw_sp_port, vid_begin, vid_end,
-			       false, false);
-err_port_vlans_set:
-	mlxsw_sp_port_fid_leave(mlxsw_sp_port, vid_begin, vid_end);
+	mlxsw_sp_port_vlan_set(mlxsw_sp_port, vid, vid, false, false);
+err_port_vlan_set:
+	mlxsw_sp_port_fid_leave(mlxsw_sp_port, vid);
 	return err;
 }
 
@@ -871,13 +814,21 @@ static int mlxsw_sp_port_vlans_add(struct mlxsw_sp_port *mlxsw_sp_port,
 {
 	bool flag_untagged = vlan->flags & BRIDGE_VLAN_INFO_UNTAGGED;
 	bool flag_pvid = vlan->flags & BRIDGE_VLAN_INFO_PVID;
+	u16 vid;
 
 	if (switchdev_trans_ph_prepare(trans))
 		return 0;
 
-	return __mlxsw_sp_port_vlans_add(mlxsw_sp_port,
-					 vlan->vid_begin, vlan->vid_end,
-					 flag_untagged, flag_pvid);
+	for (vid = vlan->vid_begin; vid <= vlan->vid_end; vid++) {
+		int err;
+
+		err = mlxsw_sp_port_vlan_add(mlxsw_sp_port, vid, flag_untagged,
+					     flag_pvid);
+		if (err)
+			return err;
+	}
+
+	return 0;
 }
 
 static enum mlxsw_reg_sfd_rec_policy mlxsw_sp_sfd_rec_policy(bool dynamic)
@@ -1151,35 +1102,27 @@ static int mlxsw_sp_port_obj_add(struct net_device *dev,
 	return err;
 }
 
-static int __mlxsw_sp_port_vlans_del(struct mlxsw_sp_port *mlxsw_sp_port,
-				     u16 vid_begin, u16 vid_end)
+static void mlxsw_sp_port_vlan_del(struct mlxsw_sp_port *mlxsw_sp_port, u16 vid)
 {
-	u16 vid, pvid;
-
-	mlxsw_sp_port_vid_learning_set(mlxsw_sp_port, vid_begin, vid_end,
-				       false);
-
-	pvid = mlxsw_sp_port->pvid;
-	if (pvid >= vid_begin && pvid <= vid_end)
-		mlxsw_sp_port_pvid_set(mlxsw_sp_port, 0);
-
-	mlxsw_sp_port_vlan_set(mlxsw_sp_port, vid_begin, vid_end,
-			       false, false);
+	u16 pvid = mlxsw_sp_port->pvid == vid ? 0 : vid;
 
-	mlxsw_sp_port_fid_leave(mlxsw_sp_port, vid_begin, vid_end);
-
-	/* Changing activity bits only if HW operation succeded */
-	for (vid = vid_begin; vid <= vid_end; vid++)
-		clear_bit(vid, mlxsw_sp_port->active_vlans);
-
-	return 0;
+	__clear_bit(vid, mlxsw_sp_port->active_vlans);
+	mlxsw_sp_port_vid_stp_set(mlxsw_sp_port, vid, BR_STATE_DISABLED);
+	mlxsw_sp_port_vid_learning_set(mlxsw_sp_port, vid, false);
+	mlxsw_sp_port_pvid_set(mlxsw_sp_port, pvid);
+	mlxsw_sp_port_vlan_set(mlxsw_sp_port, vid, vid, false, false);
+	mlxsw_sp_port_fid_leave(mlxsw_sp_port, vid);
 }
 
 static int mlxsw_sp_port_vlans_del(struct mlxsw_sp_port *mlxsw_sp_port,
 				   const struct switchdev_obj_port_vlan *vlan)
 {
-	return __mlxsw_sp_port_vlans_del(mlxsw_sp_port, vlan->vid_begin,
-					 vlan->vid_end);
+	u16 vid;
+
+	for (vid = vlan->vid_begin; vid <= vlan->vid_end; vid++)
+		mlxsw_sp_port_vlan_del(mlxsw_sp_port, vid);
+
+	return 0;
 }
 
 void mlxsw_sp_port_active_vlans_del(struct mlxsw_sp_port *mlxsw_sp_port)
@@ -1187,7 +1130,7 @@ void mlxsw_sp_port_active_vlans_del(struct mlxsw_sp_port *mlxsw_sp_port)
 	u16 vid;
 
 	for_each_set_bit(vid, mlxsw_sp_port->active_vlans, VLAN_N_VID)
-		__mlxsw_sp_port_vlans_del(mlxsw_sp_port, vid, vid);
+		mlxsw_sp_port_vlan_del(mlxsw_sp_port, vid);
 }
 
 static int