diff --git a/drivers/net/ethernet/netronome/nfp/bpf/jit.c b/drivers/net/ethernet/netronome/nfp/bpf/jit.c
index 662cbc21d9093ae40e41ebe70b14fde8cefb8c7d..e23ca90289f71ca5776bab7e66659640de1857c2 100644
--- a/drivers/net/ethernet/netronome/nfp/bpf/jit.c
+++ b/drivers/net/ethernet/netronome/nfp/bpf/jit.c
@@ -3052,26 +3052,19 @@ static int jset_imm(struct nfp_prog *nfp_prog, struct nfp_insn_meta *meta)
 {
 	const struct bpf_insn *insn = &meta->insn;
 	u64 imm = insn->imm; /* sign extend */
+	u8 dst_gpr = insn->dst_reg * 2;
 	swreg tmp_reg;
 
-	if (!imm) {
-		meta->skip = true;
-		return 0;
-	}
-
-	if (imm & ~0U) {
-		tmp_reg = ur_load_imm_any(nfp_prog, imm & ~0U, imm_b(nfp_prog));
-		emit_alu(nfp_prog, reg_none(),
-			 reg_a(insn->dst_reg * 2), ALU_OP_AND, tmp_reg);
-		emit_br(nfp_prog, BR_BNE, insn->off, 0);
-	}
-
-	if (imm >> 32) {
-		tmp_reg = ur_load_imm_any(nfp_prog, imm >> 32, imm_b(nfp_prog));
+	tmp_reg = ur_load_imm_any(nfp_prog, imm & ~0U, imm_b(nfp_prog));
+	emit_alu(nfp_prog, imm_b(nfp_prog),
+		 reg_a(dst_gpr), ALU_OP_AND, tmp_reg);
+	/* Upper word of the mask can only be 0 or ~0 from sign extension,
+	 * so either ignore it or OR the whole thing in.
+	 */
+	if (imm >> 32)
 		emit_alu(nfp_prog, reg_none(),
-			 reg_a(insn->dst_reg * 2 + 1), ALU_OP_AND, tmp_reg);
-		emit_br(nfp_prog, BR_BNE, insn->off, 0);
-	}
+			 reg_a(dst_gpr + 1), ALU_OP_OR, imm_b(nfp_prog));
+	emit_br(nfp_prog, BR_BNE, insn->off, 0);
 
 	return 0;
 }
diff --git a/kernel/bpf/verifier.c b/kernel/bpf/verifier.c
index 5c64281d566ecd21c14ad78a53bcdb78b3f4c049..d27d5a880015b92e0ce4e09c3add5f7dabc54a98 100644
--- a/kernel/bpf/verifier.c
+++ b/kernel/bpf/verifier.c
@@ -3859,6 +3859,12 @@ static int is_branch_taken(struct bpf_reg_state *reg, u64 val, u8 opcode)
 		if (tnum_is_const(reg->var_off))
 			return !tnum_equals_const(reg->var_off, val);
 		break;
+	case BPF_JSET:
+		if ((~reg->var_off.mask & reg->var_off.value) & val)
+			return 1;
+		if (!((reg->var_off.mask | reg->var_off.value) & val))
+			return 0;
+		break;
 	case BPF_JGT:
 		if (reg->umin_value > val)
 			return 1;
@@ -3943,6 +3949,13 @@ static void reg_set_min_max(struct bpf_reg_state *true_reg,
 		 */
 		__mark_reg_known(false_reg, val);
 		break;
+	case BPF_JSET:
+		false_reg->var_off = tnum_and(false_reg->var_off,
+					      tnum_const(~val));
+		if (is_power_of_2(val))
+			true_reg->var_off = tnum_or(true_reg->var_off,
+						    tnum_const(val));
+		break;
 	case BPF_JGT:
 		false_reg->umax_value = min(false_reg->umax_value, val);
 		true_reg->umin_value = max(true_reg->umin_value, val + 1);
@@ -4015,6 +4028,13 @@ static void reg_set_min_max_inv(struct bpf_reg_state *true_reg,
 		 */
 		__mark_reg_known(false_reg, val);
 		break;
+	case BPF_JSET:
+		false_reg->var_off = tnum_and(false_reg->var_off,
+					      tnum_const(~val));
+		if (is_power_of_2(val))
+			true_reg->var_off = tnum_or(true_reg->var_off,
+						    tnum_const(val));
+		break;
 	case BPF_JGT:
 		true_reg->umax_value = min(true_reg->umax_value, val - 1);
 		false_reg->umin_value = max(false_reg->umin_value, val);
@@ -6963,10 +6983,11 @@ int bpf_check(struct bpf_prog **prog, union bpf_attr *attr,
 	free_states(env);
 
 	if (ret == 0)
-		sanitize_dead_code(env);
+		ret = check_max_stack_depth(env);
 
+	/* instruction rewrites happen after this point */
 	if (ret == 0)
-		ret = check_max_stack_depth(env);
+		sanitize_dead_code(env);
 
 	if (ret == 0)
 		/* program is valid, convert *(u32*)(ctx + off) accesses */
diff --git a/tools/testing/selftests/bpf/.gitignore b/tools/testing/selftests/bpf/.gitignore
index 1b799e30c06d38fc744c3b98bf17f53ba89ba1a3..4a9785043a3914959fef99de25983fe3b5b256f3 100644
--- a/tools/testing/selftests/bpf/.gitignore
+++ b/tools/testing/selftests/bpf/.gitignore
@@ -27,3 +27,4 @@ test_flow_dissector
 flow_dissector_load
 test_netcnt
 test_section_names
+test_tcpnotify_user
diff --git a/tools/testing/selftests/bpf/test_verifier.c b/tools/testing/selftests/bpf/test_verifier.c
index 7865b94c02c4c0db3bbb10496f4ab8dd5177e8d4..b246931c46ef7a9a430bb0346b44d1e9fce7b41f 100644
--- a/tools/testing/selftests/bpf/test_verifier.c
+++ b/tools/testing/selftests/bpf/test_verifier.c
@@ -49,6 +49,7 @@
 #define MAX_INSNS	BPF_MAXINSNS
 #define MAX_FIXUPS	8
 #define MAX_NR_MAPS	13
+#define MAX_TEST_RUNS	8
 #define POINTER_VALUE	0xcafe4all
 #define TEST_DATA_LEN	64
 
@@ -86,6 +87,14 @@ struct bpf_test {
 	uint8_t flags;
 	__u8 data[TEST_DATA_LEN];
 	void (*fill_helper)(struct bpf_test *self);
+	uint8_t runs;
+	struct {
+		uint32_t retval, retval_unpriv;
+		union {
+			__u8 data[TEST_DATA_LEN];
+			__u64 data64[TEST_DATA_LEN / 8];
+		};
+	} retvals[MAX_TEST_RUNS];
 };
 
 /* Note we want this to be 64 bit aligned so that the end of our array is
@@ -14161,6 +14170,197 @@ static struct bpf_test tests[] = {
 		.errstr_unpriv = "R1 leaks addr",
 		.result = REJECT,
 	},
+	{
+		"jset: functional",
+		.insns = {
+			/* r0 = 0 */
+			BPF_MOV64_IMM(BPF_REG_0, 0),
+			/* prep for direct packet access via r2 */
+			BPF_LDX_MEM(BPF_W, BPF_REG_2, BPF_REG_1,
+				    offsetof(struct __sk_buff, data)),
+			BPF_LDX_MEM(BPF_W, BPF_REG_3, BPF_REG_1,
+				    offsetof(struct __sk_buff, data_end)),
+			BPF_MOV64_REG(BPF_REG_4, BPF_REG_2),
+			BPF_ALU64_IMM(BPF_ADD, BPF_REG_4, 8),
+			BPF_JMP_REG(BPF_JLE, BPF_REG_4, BPF_REG_3, 1),
+			BPF_EXIT_INSN(),
+
+			BPF_LDX_MEM(BPF_DW, BPF_REG_7, BPF_REG_2, 0),
+
+			/* reg, bit 63 or bit 0 set, taken */
+			BPF_LD_IMM64(BPF_REG_8, 0x8000000000000001),
+			BPF_JMP_REG(BPF_JSET, BPF_REG_7, BPF_REG_8, 1),
+			BPF_EXIT_INSN(),
+
+			/* reg, bit 62, not taken */
+			BPF_LD_IMM64(BPF_REG_8, 0x4000000000000000),
+			BPF_JMP_REG(BPF_JSET, BPF_REG_7, BPF_REG_8, 1),
+			BPF_JMP_IMM(BPF_JA, 0, 0, 1),
+			BPF_EXIT_INSN(),
+
+			/* imm, any bit set, taken */
+			BPF_JMP_IMM(BPF_JSET, BPF_REG_7, -1, 1),
+			BPF_EXIT_INSN(),
+
+			/* imm, bit 31 set, taken */
+			BPF_JMP_IMM(BPF_JSET, BPF_REG_7, 0x80000000, 1),
+			BPF_EXIT_INSN(),
+
+			/* all good - return r0 == 2 */
+			BPF_MOV64_IMM(BPF_REG_0, 2),
+			BPF_EXIT_INSN(),
+		},
+		.prog_type = BPF_PROG_TYPE_SCHED_CLS,
+		.result = ACCEPT,
+		.runs = 7,
+		.retvals = {
+			{ .retval = 2,
+			  .data64 = { (1ULL << 63) | (1U << 31) | (1U << 0), }
+			},
+			{ .retval = 2,
+			  .data64 = { (1ULL << 63) | (1U << 31), }
+			},
+			{ .retval = 2,
+			  .data64 = { (1ULL << 31) | (1U << 0), }
+			},
+			{ .retval = 2,
+			  .data64 = { (__u32)-1, }
+			},
+			{ .retval = 2,
+			  .data64 = { ~0x4000000000000000ULL, }
+			},
+			{ .retval = 0,
+			  .data64 = { 0, }
+			},
+			{ .retval = 0,
+			  .data64 = { ~0ULL, }
+			},
+		},
+	},
+	{
+		"jset: sign-extend",
+		.insns = {
+			/* r0 = 0 */
+			BPF_MOV64_IMM(BPF_REG_0, 0),
+			/* prep for direct packet access via r2 */
+			BPF_LDX_MEM(BPF_W, BPF_REG_2, BPF_REG_1,
+				    offsetof(struct __sk_buff, data)),
+			BPF_LDX_MEM(BPF_W, BPF_REG_3, BPF_REG_1,
+				    offsetof(struct __sk_buff, data_end)),
+			BPF_MOV64_REG(BPF_REG_4, BPF_REG_2),
+			BPF_ALU64_IMM(BPF_ADD, BPF_REG_4, 8),
+			BPF_JMP_REG(BPF_JLE, BPF_REG_4, BPF_REG_3, 1),
+			BPF_EXIT_INSN(),
+
+			BPF_LDX_MEM(BPF_DW, BPF_REG_7, BPF_REG_2, 0),
+
+			BPF_JMP_IMM(BPF_JSET, BPF_REG_7, 0x80000000, 1),
+			BPF_EXIT_INSN(),
+
+			BPF_MOV64_IMM(BPF_REG_0, 2),
+			BPF_EXIT_INSN(),
+		},
+		.prog_type = BPF_PROG_TYPE_SCHED_CLS,
+		.result = ACCEPT,
+		.retval = 2,
+		.data = { 1, 0, 0, 0, 0, 0, 0, 1, },
+	},
+	{
+		"jset: known const compare",
+		.insns = {
+			BPF_MOV64_IMM(BPF_REG_0, 1),
+			BPF_JMP_IMM(BPF_JSET, BPF_REG_0, 1, 1),
+			BPF_LDX_MEM(BPF_B, BPF_REG_8, BPF_REG_9, 0),
+			BPF_EXIT_INSN(),
+		},
+		.prog_type = BPF_PROG_TYPE_SOCKET_FILTER,
+		.retval_unpriv = 1,
+		.result_unpriv = ACCEPT,
+		.retval = 1,
+		.result = ACCEPT,
+	},
+	{
+		"jset: known const compare bad",
+		.insns = {
+			BPF_MOV64_IMM(BPF_REG_0, 0),
+			BPF_JMP_IMM(BPF_JSET, BPF_REG_0, 1, 1),
+			BPF_LDX_MEM(BPF_B, BPF_REG_8, BPF_REG_9, 0),
+			BPF_EXIT_INSN(),
+		},
+		.prog_type = BPF_PROG_TYPE_SOCKET_FILTER,
+		.errstr_unpriv = "!read_ok",
+		.result_unpriv = REJECT,
+		.errstr = "!read_ok",
+		.result = REJECT,
+	},
+	{
+		"jset: unknown const compare taken",
+		.insns = {
+			BPF_RAW_INSN(BPF_JMP | BPF_CALL, 0, 0, 0,
+				     BPF_FUNC_get_prandom_u32),
+			BPF_JMP_IMM(BPF_JSET, BPF_REG_0, 1, 1),
+			BPF_JMP_IMM(BPF_JA, 0, 0, 1),
+			BPF_LDX_MEM(BPF_B, BPF_REG_8, BPF_REG_9, 0),
+			BPF_EXIT_INSN(),
+		},
+		.prog_type = BPF_PROG_TYPE_SOCKET_FILTER,
+		.errstr_unpriv = "!read_ok",
+		.result_unpriv = REJECT,
+		.errstr = "!read_ok",
+		.result = REJECT,
+	},
+	{
+		"jset: unknown const compare not taken",
+		.insns = {
+			BPF_RAW_INSN(BPF_JMP | BPF_CALL, 0, 0, 0,
+				     BPF_FUNC_get_prandom_u32),
+			BPF_JMP_IMM(BPF_JSET, BPF_REG_0, 1, 1),
+			BPF_LDX_MEM(BPF_B, BPF_REG_8, BPF_REG_9, 0),
+			BPF_EXIT_INSN(),
+		},
+		.prog_type = BPF_PROG_TYPE_SOCKET_FILTER,
+		.errstr_unpriv = "!read_ok",
+		.result_unpriv = REJECT,
+		.errstr = "!read_ok",
+		.result = REJECT,
+	},
+	{
+		"jset: half-known const compare",
+		.insns = {
+			BPF_RAW_INSN(BPF_JMP | BPF_CALL, 0, 0, 0,
+				     BPF_FUNC_get_prandom_u32),
+			BPF_ALU64_IMM(BPF_OR, BPF_REG_0, 2),
+			BPF_JMP_IMM(BPF_JSET, BPF_REG_0, 3, 1),
+			BPF_LDX_MEM(BPF_B, BPF_REG_8, BPF_REG_9, 0),
+			BPF_MOV64_IMM(BPF_REG_0, 0),
+			BPF_EXIT_INSN(),
+		},
+		.prog_type = BPF_PROG_TYPE_SOCKET_FILTER,
+		.result_unpriv = ACCEPT,
+		.result = ACCEPT,
+	},
+	{
+		"jset: range",
+		.insns = {
+			BPF_RAW_INSN(BPF_JMP | BPF_CALL, 0, 0, 0,
+				     BPF_FUNC_get_prandom_u32),
+			BPF_MOV64_REG(BPF_REG_1, BPF_REG_0),
+			BPF_MOV64_IMM(BPF_REG_0, 0),
+			BPF_ALU64_IMM(BPF_AND, BPF_REG_1, 0xff),
+			BPF_JMP_IMM(BPF_JSET, BPF_REG_1, 0xf0, 3),
+			BPF_JMP_IMM(BPF_JLT, BPF_REG_1, 0x10, 1),
+			BPF_LDX_MEM(BPF_B, BPF_REG_8, BPF_REG_9, 0),
+			BPF_EXIT_INSN(),
+			BPF_JMP_IMM(BPF_JSET, BPF_REG_1, 0x10, 1),
+			BPF_EXIT_INSN(),
+			BPF_JMP_IMM(BPF_JGE, BPF_REG_1, 0x10, 1),
+			BPF_LDX_MEM(BPF_B, BPF_REG_8, BPF_REG_9, 0),
+			BPF_EXIT_INSN(),
+		},
+		.prog_type = BPF_PROG_TYPE_SOCKET_FILTER,
+		.result_unpriv = ACCEPT,
+		.result = ACCEPT,
+	},
 };
 
 static int probe_filter_length(const struct bpf_insn *fp)
@@ -14443,16 +14643,42 @@ static int set_admin(bool admin)
 	return ret;
 }
 
+static int do_prog_test_run(int fd_prog, bool unpriv, uint32_t expected_val,
+			    void *data, size_t size_data)
+{
+	__u8 tmp[TEST_DATA_LEN << 2];
+	__u32 size_tmp = sizeof(tmp);
+	uint32_t retval;
+	int err;
+
+	if (unpriv)
+		set_admin(true);
+	err = bpf_prog_test_run(fd_prog, 1, data, size_data,
+				tmp, &size_tmp, &retval, NULL);
+	if (unpriv)
+		set_admin(false);
+	if (err && errno != 524/*ENOTSUPP*/ && errno != EPERM) {
+		printf("Unexpected bpf_prog_test_run error ");
+		return err;
+	}
+	if (!err && retval != expected_val &&
+	    expected_val != POINTER_VALUE) {
+		printf("FAIL retval %d != %d ", retval, expected_val);
+		return 1;
+	}
+
+	return 0;
+}
+
 static void do_test_single(struct bpf_test *test, bool unpriv,
 			   int *passes, int *errors)
 {
 	int fd_prog, expected_ret, alignment_prevented_execution;
 	int prog_len, prog_type = test->prog_type;
 	struct bpf_insn *prog = test->insns;
+	int run_errs, run_successes;
 	int map_fds[MAX_NR_MAPS];
 	const char *expected_err;
-	uint32_t expected_val;
-	uint32_t retval;
 	__u32 pflags;
 	int i, err;
 
@@ -14476,8 +14702,6 @@ static void do_test_single(struct bpf_test *test, bool unpriv,
 		       test->result_unpriv : test->result;
 	expected_err = unpriv && test->errstr_unpriv ?
 		       test->errstr_unpriv : test->errstr;
-	expected_val = unpriv && test->retval_unpriv ?
-		       test->retval_unpriv : test->retval;
 
 	alignment_prevented_execution = 0;
 
@@ -14489,10 +14713,8 @@ static void do_test_single(struct bpf_test *test, bool unpriv,
 		}
 #ifndef CONFIG_HAVE_EFFICIENT_UNALIGNED_ACCESS
 		if (fd_prog >= 0 &&
-		    (test->flags & F_NEEDS_EFFICIENT_UNALIGNED_ACCESS)) {
+		    (test->flags & F_NEEDS_EFFICIENT_UNALIGNED_ACCESS))
 			alignment_prevented_execution = 1;
-			goto test_ok;
-		}
 #endif
 	} else {
 		if (fd_prog >= 0) {
@@ -14519,33 +14741,54 @@ static void do_test_single(struct bpf_test *test, bool unpriv,
 		}
 	}
 
-	if (fd_prog >= 0) {
-		__u8 tmp[TEST_DATA_LEN << 2];
-		__u32 size_tmp = sizeof(tmp);
-
-		if (unpriv)
-			set_admin(true);
-		err = bpf_prog_test_run(fd_prog, 1, test->data,
-					sizeof(test->data), tmp, &size_tmp,
-					&retval, NULL);
-		if (unpriv)
-			set_admin(false);
-		if (err && errno != 524/*ENOTSUPP*/ && errno != EPERM) {
-			printf("Unexpected bpf_prog_test_run error\n");
-			goto fail_log;
+	run_errs = 0;
+	run_successes = 0;
+	if (!alignment_prevented_execution && fd_prog >= 0) {
+		uint32_t expected_val;
+		int i;
+
+		if (!test->runs) {
+			expected_val = unpriv && test->retval_unpriv ?
+				test->retval_unpriv : test->retval;
+
+			err = do_prog_test_run(fd_prog, unpriv, expected_val,
+					       test->data, sizeof(test->data));
+			if (err)
+				run_errs++;
+			else
+				run_successes++;
 		}
-		if (!err && retval != expected_val &&
-		    expected_val != POINTER_VALUE) {
-			printf("FAIL retval %d != %d\n", retval, expected_val);
-			goto fail_log;
+
+		for (i = 0; i < test->runs; i++) {
+			if (unpriv && test->retvals[i].retval_unpriv)
+				expected_val = test->retvals[i].retval_unpriv;
+			else
+				expected_val = test->retvals[i].retval;
+
+			err = do_prog_test_run(fd_prog, unpriv, expected_val,
+					       test->retvals[i].data,
+					       sizeof(test->retvals[i].data));
+			if (err) {
+				printf("(run %d/%d) ", i + 1, test->runs);
+				run_errs++;
+			} else {
+				run_successes++;
+			}
 		}
 	}
-#ifndef CONFIG_HAVE_EFFICIENT_UNALIGNED_ACCESS
-test_ok:
-#endif
-	(*passes)++;
-	printf("OK%s\n", alignment_prevented_execution ?
-	       " (NOTE: not executed due to unknown alignment)" : "");
+
+	if (!run_errs) {
+		(*passes)++;
+		if (run_successes > 1)
+			printf("%d cases ", run_successes);
+		printf("OK");
+		if (alignment_prevented_execution)
+			printf(" (NOTE: not executed due to unknown alignment)");
+		printf("\n");
+	} else {
+		printf("\n");
+		goto fail_log;
+	}
 close_fds:
 	close(fd_prog);
 	for (i = 0; i < MAX_NR_MAPS; i++)