diff --git a/include/net/act_api.h b/include/net/act_api.h
index c745e9ccfab2d6f86a58c1a038e416fe29a63d67..54fbb49bd08a9d833c049a980d769ceff7458452 100644
--- a/include/net/act_api.h
+++ b/include/net/act_api.h
@@ -90,7 +90,7 @@ struct tc_action_ops {
 	int     (*lookup)(struct net *net, struct tc_action **a, u32 index);
 	int     (*init)(struct net *net, struct nlattr *nla,
 			struct nlattr *est, struct tc_action **act, int ovr,
-			int bind, bool rtnl_held,
+			int bind, bool rtnl_held, struct tcf_proto *tp,
 			struct netlink_ext_ack *extack);
 	int     (*walk)(struct net *, struct sk_buff *,
 			struct netlink_callback *, int,
@@ -181,6 +181,11 @@ int tcf_action_dump_old(struct sk_buff *skb, struct tc_action *a, int, int);
 int tcf_action_dump_1(struct sk_buff *skb, struct tc_action *a, int, int);
 int tcf_action_copy_stats(struct sk_buff *, struct tc_action *, int);
 
+int tcf_action_check_ctrlact(int action, struct tcf_proto *tp,
+			     struct tcf_chain **handle,
+			     struct netlink_ext_ack *newchain);
+struct tcf_chain *tcf_action_set_ctrlact(struct tc_action *a, int action,
+					 struct tcf_chain *newchain);
 #endif /* CONFIG_NET_CLS_ACT */
 
 static inline void tcf_action_stats_update(struct tc_action *a, u64 bytes,
diff --git a/net/sched/act_api.c b/net/sched/act_api.c
index aecf1bf233c8362673812b5ab212f32e5f868a5b..fe67b98ac64114a4df97eb9173ba40b341d7c6ae 100644
--- a/net/sched/act_api.c
+++ b/net/sched/act_api.c
@@ -28,23 +28,6 @@
 #include <net/act_api.h>
 #include <net/netlink.h>
 
-static int tcf_action_goto_chain_init(struct tc_action *a, struct tcf_proto *tp)
-{
-	u32 chain_index = a->tcfa_action & TC_ACT_EXT_VAL_MASK;
-
-	if (!tp)
-		return -EINVAL;
-	a->goto_chain = tcf_chain_get_by_act(tp->chain->block, chain_index);
-	if (!a->goto_chain)
-		return -ENOMEM;
-	return 0;
-}
-
-static void tcf_action_goto_chain_fini(struct tc_action *a)
-{
-	tcf_chain_put_by_act(a->goto_chain);
-}
-
 static void tcf_action_goto_chain_exec(const struct tc_action *a,
 				       struct tcf_result *res)
 {
@@ -71,6 +54,53 @@ static void tcf_set_action_cookie(struct tc_cookie __rcu **old_cookie,
 		call_rcu(&old->rcu, tcf_free_cookie_rcu);
 }
 
+int tcf_action_check_ctrlact(int action, struct tcf_proto *tp,
+			     struct tcf_chain **newchain,
+			     struct netlink_ext_ack *extack)
+{
+	int opcode = TC_ACT_EXT_OPCODE(action), ret = -EINVAL;
+	u32 chain_index;
+
+	if (!opcode)
+		ret = action > TC_ACT_VALUE_MAX ? -EINVAL : 0;
+	else if (opcode <= TC_ACT_EXT_OPCODE_MAX || action == TC_ACT_UNSPEC)
+		ret = 0;
+	if (ret) {
+		NL_SET_ERR_MSG(extack, "invalid control action");
+		goto end;
+	}
+
+	if (TC_ACT_EXT_CMP(action, TC_ACT_GOTO_CHAIN)) {
+		chain_index = action & TC_ACT_EXT_VAL_MASK;
+		if (!tp || !newchain) {
+			ret = -EINVAL;
+			NL_SET_ERR_MSG(extack,
+				       "can't goto NULL proto/chain");
+			goto end;
+		}
+		*newchain = tcf_chain_get_by_act(tp->chain->block, chain_index);
+		if (!*newchain) {
+			ret = -ENOMEM;
+			NL_SET_ERR_MSG(extack,
+				       "can't allocate goto_chain");
+		}
+	}
+end:
+	return ret;
+}
+EXPORT_SYMBOL(tcf_action_check_ctrlact);
+
+struct tcf_chain *tcf_action_set_ctrlact(struct tc_action *a, int action,
+					 struct tcf_chain *newchain)
+{
+	struct tcf_chain *oldchain = a->goto_chain;
+
+	a->tcfa_action = action;
+	a->goto_chain = newchain;
+	return oldchain;
+}
+EXPORT_SYMBOL(tcf_action_set_ctrlact);
+
 /* XXX: For standalone actions, we don't need a RCU grace period either, because
  * actions are always connected to filters and filters are already destroyed in
  * RCU callbacks, so after a RCU grace period actions are already disconnected
@@ -78,13 +108,15 @@ static void tcf_set_action_cookie(struct tc_cookie __rcu **old_cookie,
  */
 static void free_tcf(struct tc_action *p)
 {
+	struct tcf_chain *chain = p->goto_chain;
+
 	free_percpu(p->cpu_bstats);
 	free_percpu(p->cpu_bstats_hw);
 	free_percpu(p->cpu_qstats);
 
 	tcf_set_action_cookie(&p->act_cookie, NULL);
-	if (p->goto_chain)
-		tcf_action_goto_chain_fini(p);
+	if (chain)
+		tcf_chain_put_by_act(chain);
 
 	kfree(p);
 }
@@ -800,15 +832,6 @@ static struct tc_cookie *nla_memdup_cookie(struct nlattr **tb)
 	return c;
 }
 
-static bool tcf_action_valid(int action)
-{
-	int opcode = TC_ACT_EXT_OPCODE(action);
-
-	if (!opcode)
-		return action <= TC_ACT_VALUE_MAX;
-	return opcode <= TC_ACT_EXT_OPCODE_MAX || action == TC_ACT_UNSPEC;
-}
-
 struct tc_action *tcf_action_init_1(struct net *net, struct tcf_proto *tp,
 				    struct nlattr *nla, struct nlattr *est,
 				    char *name, int ovr, int bind,
@@ -890,10 +913,10 @@ struct tc_action *tcf_action_init_1(struct net *net, struct tcf_proto *tp,
 	/* backward compatibility for policer */
 	if (name == NULL)
 		err = a_o->init(net, tb[TCA_ACT_OPTIONS], est, &a, ovr, bind,
-				rtnl_held, extack);
+				rtnl_held, tp, extack);
 	else
 		err = a_o->init(net, nla, est, &a, ovr, bind, rtnl_held,
-				extack);
+				tp, extack);
 	if (err < 0)
 		goto err_mod;
 
@@ -907,18 +930,10 @@ struct tc_action *tcf_action_init_1(struct net *net, struct tcf_proto *tp,
 	if (err != ACT_P_CREATED)
 		module_put(a_o->owner);
 
-	if (TC_ACT_EXT_CMP(a->tcfa_action, TC_ACT_GOTO_CHAIN)) {
-		err = tcf_action_goto_chain_init(a, tp);
-		if (err) {
-			tcf_action_destroy_1(a, bind);
-			NL_SET_ERR_MSG(extack, "Failed to init TC action chain");
-			return ERR_PTR(err);
-		}
-	}
-
-	if (!tcf_action_valid(a->tcfa_action)) {
+	if (TC_ACT_EXT_CMP(a->tcfa_action, TC_ACT_GOTO_CHAIN) &&
+	    !a->goto_chain) {
 		tcf_action_destroy_1(a, bind);
-		NL_SET_ERR_MSG(extack, "Invalid control action value");
+		NL_SET_ERR_MSG(extack, "can't use goto chain with NULL chain");
 		return ERR_PTR(-EINVAL);
 	}
 
diff --git a/net/sched/act_bpf.c b/net/sched/act_bpf.c
index aa5c38d11a3079644d36329c8b4637e0bdfaa5c6..3c0468f2aae617ee67464a04fb1d11e8eb5a9346 100644
--- a/net/sched/act_bpf.c
+++ b/net/sched/act_bpf.c
@@ -278,7 +278,7 @@ static void tcf_bpf_prog_fill_cfg(const struct tcf_bpf *prog,
 static int tcf_bpf_init(struct net *net, struct nlattr *nla,
 			struct nlattr *est, struct tc_action **act,
 			int replace, int bind, bool rtnl_held,
-			struct netlink_ext_ack *extack)
+			struct tcf_proto *tp, struct netlink_ext_ack *extack)
 {
 	struct tc_action_net *tn = net_generic(net, bpf_net_id);
 	struct nlattr *tb[TCA_ACT_BPF_MAX + 1];
diff --git a/net/sched/act_connmark.c b/net/sched/act_connmark.c
index 5d24993cccfebead613c7dd15bb41a1087c8024e..44aa046a92ea4a4b4a00db725175935bfe1d0827 100644
--- a/net/sched/act_connmark.c
+++ b/net/sched/act_connmark.c
@@ -97,6 +97,7 @@ static const struct nla_policy connmark_policy[TCA_CONNMARK_MAX + 1] = {
 static int tcf_connmark_init(struct net *net, struct nlattr *nla,
 			     struct nlattr *est, struct tc_action **a,
 			     int ovr, int bind, bool rtnl_held,
+			     struct tcf_proto *tp,
 			     struct netlink_ext_ack *extack)
 {
 	struct tc_action_net *tn = net_generic(net, connmark_net_id);
diff --git a/net/sched/act_csum.c b/net/sched/act_csum.c
index c79aca29505e33a44b7402b3ccc1ccaf8827b280..9ba0f61a1e82eb5cc14855d15beb5c4a3a84aa02 100644
--- a/net/sched/act_csum.c
+++ b/net/sched/act_csum.c
@@ -46,7 +46,7 @@ static struct tc_action_ops act_csum_ops;
 
 static int tcf_csum_init(struct net *net, struct nlattr *nla,
 			 struct nlattr *est, struct tc_action **a, int ovr,
-			 int bind, bool rtnl_held,
+			 int bind, bool rtnl_held, struct tcf_proto *tp,
 			 struct netlink_ext_ack *extack)
 {
 	struct tc_action_net *tn = net_generic(net, csum_net_id);
diff --git a/net/sched/act_gact.c b/net/sched/act_gact.c
index 93da0004e9f415e9eb8439eeacdc0ae73f5783fd..b8ad311bd8cc840c2a41d9a20e4dc37d08579106 100644
--- a/net/sched/act_gact.c
+++ b/net/sched/act_gact.c
@@ -57,7 +57,7 @@ static const struct nla_policy gact_policy[TCA_GACT_MAX + 1] = {
 static int tcf_gact_init(struct net *net, struct nlattr *nla,
 			 struct nlattr *est, struct tc_action **a,
 			 int ovr, int bind, bool rtnl_held,
-			 struct netlink_ext_ack *extack)
+			 struct tcf_proto *tp, struct netlink_ext_ack *extack)
 {
 	struct tc_action_net *tn = net_generic(net, gact_net_id);
 	struct nlattr *tb[TCA_GACT_MAX + 1];
diff --git a/net/sched/act_ife.c b/net/sched/act_ife.c
index 9b1f2b3990eedeeca44ba6ad5c7e334043727dd1..c1ba74d5c1e38a5bd3c0592a64be5978e380ee22 100644
--- a/net/sched/act_ife.c
+++ b/net/sched/act_ife.c
@@ -469,7 +469,7 @@ static int populate_metalist(struct tcf_ife_info *ife, struct nlattr **tb,
 static int tcf_ife_init(struct net *net, struct nlattr *nla,
 			struct nlattr *est, struct tc_action **a,
 			int ovr, int bind, bool rtnl_held,
-			struct netlink_ext_ack *extack)
+			struct tcf_proto *tp, struct netlink_ext_ack *extack)
 {
 	struct tc_action_net *tn = net_generic(net, ife_net_id);
 	struct nlattr *tb[TCA_IFE_MAX + 1];
diff --git a/net/sched/act_ipt.c b/net/sched/act_ipt.c
index 98f5b6ea77b46ea7a55e1c325fd60542020b2165..04a0b5c611943a4e10bfa0855f1c3928e4e141de 100644
--- a/net/sched/act_ipt.c
+++ b/net/sched/act_ipt.c
@@ -97,7 +97,8 @@ static const struct nla_policy ipt_policy[TCA_IPT_MAX + 1] = {
 
 static int __tcf_ipt_init(struct net *net, unsigned int id, struct nlattr *nla,
 			  struct nlattr *est, struct tc_action **a,
-			  const struct tc_action_ops *ops, int ovr, int bind)
+			  const struct tc_action_ops *ops, int ovr, int bind,
+			  struct tcf_proto *tp)
 {
 	struct tc_action_net *tn = net_generic(net, id);
 	struct nlattr *tb[TCA_IPT_MAX + 1];
@@ -205,20 +206,20 @@ static int __tcf_ipt_init(struct net *net, unsigned int id, struct nlattr *nla,
 
 static int tcf_ipt_init(struct net *net, struct nlattr *nla,
 			struct nlattr *est, struct tc_action **a, int ovr,
-			int bind, bool rtnl_held,
+			int bind, bool rtnl_held, struct tcf_proto *tp,
 			struct netlink_ext_ack *extack)
 {
 	return __tcf_ipt_init(net, ipt_net_id, nla, est, a, &act_ipt_ops, ovr,
-			      bind);
+			      bind, tp);
 }
 
 static int tcf_xt_init(struct net *net, struct nlattr *nla,
 		       struct nlattr *est, struct tc_action **a, int ovr,
-		       int bind, bool unlocked,
+		       int bind, bool unlocked, struct tcf_proto *tp,
 		       struct netlink_ext_ack *extack)
 {
 	return __tcf_ipt_init(net, xt_net_id, nla, est, a, &act_xt_ops, ovr,
-			      bind);
+			      bind, tp);
 }
 
 static int tcf_ipt_act(struct sk_buff *skb, const struct tc_action *a,
diff --git a/net/sched/act_mirred.c b/net/sched/act_mirred.c
index 6692fd0546177347a70123a959156d1120eace90..383f4024452ccf5cf1d7f10b4a6f1852278bb62a 100644
--- a/net/sched/act_mirred.c
+++ b/net/sched/act_mirred.c
@@ -94,6 +94,7 @@ static struct tc_action_ops act_mirred_ops;
 static int tcf_mirred_init(struct net *net, struct nlattr *nla,
 			   struct nlattr *est, struct tc_action **a,
 			   int ovr, int bind, bool rtnl_held,
+			   struct tcf_proto *tp,
 			   struct netlink_ext_ack *extack)
 {
 	struct tc_action_net *tn = net_generic(net, mirred_net_id);
diff --git a/net/sched/act_nat.c b/net/sched/act_nat.c
index 543eab9193f17ca94756bff060fee937078a1e21..de4b493e26d2262bcd0341b7a0ff2ff1800401e0 100644
--- a/net/sched/act_nat.c
+++ b/net/sched/act_nat.c
@@ -38,7 +38,8 @@ static const struct nla_policy nat_policy[TCA_NAT_MAX + 1] = {
 
 static int tcf_nat_init(struct net *net, struct nlattr *nla, struct nlattr *est,
 			struct tc_action **a, int ovr, int bind,
-			bool rtnl_held, struct netlink_ext_ack *extack)
+			bool rtnl_held,	struct tcf_proto *tp,
+			struct netlink_ext_ack *extack)
 {
 	struct tc_action_net *tn = net_generic(net, nat_net_id);
 	struct nlattr *tb[TCA_NAT_MAX + 1];
diff --git a/net/sched/act_pedit.c b/net/sched/act_pedit.c
index a80373878df769d180a6c34a9c6fdda4e846efd8..8ca82aefa11a967c65f260dc5a62f6db8d89d2e6 100644
--- a/net/sched/act_pedit.c
+++ b/net/sched/act_pedit.c
@@ -138,7 +138,7 @@ static int tcf_pedit_key_ex_dump(struct sk_buff *skb,
 static int tcf_pedit_init(struct net *net, struct nlattr *nla,
 			  struct nlattr *est, struct tc_action **a,
 			  int ovr, int bind, bool rtnl_held,
-			  struct netlink_ext_ack *extack)
+			  struct tcf_proto *tp, struct netlink_ext_ack *extack)
 {
 	struct tc_action_net *tn = net_generic(net, pedit_net_id);
 	struct nlattr *tb[TCA_PEDIT_MAX + 1];
diff --git a/net/sched/act_police.c b/net/sched/act_police.c
index 8271a6263824bf53aaa92f97a35916d6d31aa244..229eba7925e516393e1758ec07a52f9c5b58dbb3 100644
--- a/net/sched/act_police.c
+++ b/net/sched/act_police.c
@@ -83,6 +83,7 @@ static const struct nla_policy police_policy[TCA_POLICE_MAX + 1] = {
 static int tcf_police_init(struct net *net, struct nlattr *nla,
 			       struct nlattr *est, struct tc_action **a,
 			       int ovr, int bind, bool rtnl_held,
+			       struct tcf_proto *tp,
 			       struct netlink_ext_ack *extack)
 {
 	int ret = 0, tcfp_result = TC_ACT_OK, err, size;
diff --git a/net/sched/act_sample.c b/net/sched/act_sample.c
index 203e399e5c85a293b29d52a84f31bff929827cf4..36b8adbe935d9a578f1746c12f92ef419d37a01d 100644
--- a/net/sched/act_sample.c
+++ b/net/sched/act_sample.c
@@ -37,7 +37,7 @@ static const struct nla_policy sample_policy[TCA_SAMPLE_MAX + 1] = {
 
 static int tcf_sample_init(struct net *net, struct nlattr *nla,
 			   struct nlattr *est, struct tc_action **a, int ovr,
-			   int bind, bool rtnl_held,
+			   int bind, bool rtnl_held, struct tcf_proto *tp,
 			   struct netlink_ext_ack *extack)
 {
 	struct tc_action_net *tn = net_generic(net, sample_net_id);
diff --git a/net/sched/act_simple.c b/net/sched/act_simple.c
index d54cb608dbafae7ea9a333bd6a117824f2a1caac..4916dc3e3668b0c3b9ba0c186194203136b98ac9 100644
--- a/net/sched/act_simple.c
+++ b/net/sched/act_simple.c
@@ -78,7 +78,7 @@ static const struct nla_policy simple_policy[TCA_DEF_MAX + 1] = {
 static int tcf_simp_init(struct net *net, struct nlattr *nla,
 			 struct nlattr *est, struct tc_action **a,
 			 int ovr, int bind, bool rtnl_held,
-			 struct netlink_ext_ack *extack)
+			 struct tcf_proto *tp, struct netlink_ext_ack *extack)
 {
 	struct tc_action_net *tn = net_generic(net, simp_net_id);
 	struct nlattr *tb[TCA_DEF_MAX + 1];
diff --git a/net/sched/act_skbedit.c b/net/sched/act_skbedit.c
index 65879500b688bca58c451b017928f22994fa9d82..4566eff3d0278524f5f8bd84220db24855fc2cec 100644
--- a/net/sched/act_skbedit.c
+++ b/net/sched/act_skbedit.c
@@ -96,6 +96,7 @@ static const struct nla_policy skbedit_policy[TCA_SKBEDIT_MAX + 1] = {
 static int tcf_skbedit_init(struct net *net, struct nlattr *nla,
 			    struct nlattr *est, struct tc_action **a,
 			    int ovr, int bind, bool rtnl_held,
+			    struct tcf_proto *tp,
 			    struct netlink_ext_ack *extack)
 {
 	struct tc_action_net *tn = net_generic(net, skbedit_net_id);
diff --git a/net/sched/act_skbmod.c b/net/sched/act_skbmod.c
index 7bac1d78e7a39994ccb9bea33fd3aa0fd7030a7e..b9ab2c8f07f1af9b55ceb5bcff51c4c068d6c075 100644
--- a/net/sched/act_skbmod.c
+++ b/net/sched/act_skbmod.c
@@ -82,6 +82,7 @@ static const struct nla_policy skbmod_policy[TCA_SKBMOD_MAX + 1] = {
 static int tcf_skbmod_init(struct net *net, struct nlattr *nla,
 			   struct nlattr *est, struct tc_action **a,
 			   int ovr, int bind, bool rtnl_held,
+			   struct tcf_proto *tp,
 			   struct netlink_ext_ack *extack)
 {
 	struct tc_action_net *tn = net_generic(net, skbmod_net_id);
diff --git a/net/sched/act_tunnel_key.c b/net/sched/act_tunnel_key.c
index 7c6591b991d510318f0eba8ec8036f62d8019a1d..fc295d91559acb273e6519d6d6b4aa3519dd216d 100644
--- a/net/sched/act_tunnel_key.c
+++ b/net/sched/act_tunnel_key.c
@@ -210,6 +210,7 @@ static void tunnel_key_release_params(struct tcf_tunnel_key_params *p)
 static int tunnel_key_init(struct net *net, struct nlattr *nla,
 			   struct nlattr *est, struct tc_action **a,
 			   int ovr, int bind, bool rtnl_held,
+			   struct tcf_proto *tp,
 			   struct netlink_ext_ack *extack)
 {
 	struct tc_action_net *tn = net_generic(net, tunnel_key_net_id);
diff --git a/net/sched/act_vlan.c b/net/sched/act_vlan.c
index ac0061599225b6871d846dce30dd83ae37abff25..4651ee15e35d4f6347a6c0d8f71bd701598b2671 100644
--- a/net/sched/act_vlan.c
+++ b/net/sched/act_vlan.c
@@ -105,7 +105,7 @@ static const struct nla_policy vlan_policy[TCA_VLAN_MAX + 1] = {
 static int tcf_vlan_init(struct net *net, struct nlattr *nla,
 			 struct nlattr *est, struct tc_action **a,
 			 int ovr, int bind, bool rtnl_held,
-			 struct netlink_ext_ack *extack)
+			 struct tcf_proto *tp, struct netlink_ext_ack *extack)
 {
 	struct tc_action_net *tn = net_generic(net, vlan_net_id);
 	struct nlattr *tb[TCA_VLAN_MAX + 1];