diff --git a/arch/arm64/net/bpf_jit_comp.c b/arch/arm64/net/bpf_jit_comp.c
index 1542df00b23c1ceb0ef72509923731f1b919517e..aaddc0217e73dfce56b44f2e7f19c26a64ac9c2b 100644
--- a/arch/arm64/net/bpf_jit_comp.c
+++ b/arch/arm64/net/bpf_jit_comp.c
@@ -362,7 +362,8 @@ static int build_insn(const struct bpf_insn *insn, struct jit_ctx *ctx,
 	const s16 off = insn->off;
 	const s32 imm = insn->imm;
 	const int i = insn - ctx->prog->insnsi;
-	const bool is64 = BPF_CLASS(code) == BPF_ALU64;
+	const bool is64 = BPF_CLASS(code) == BPF_ALU64 ||
+			  BPF_CLASS(code) == BPF_JMP;
 	const bool isdw = BPF_SIZE(code) == BPF_DW;
 	u8 jmp_cond;
 	s32 jmp_offset;
@@ -559,7 +560,17 @@ static int build_insn(const struct bpf_insn *insn, struct jit_ctx *ctx,
 	case BPF_JMP | BPF_JSLT | BPF_X:
 	case BPF_JMP | BPF_JSGE | BPF_X:
 	case BPF_JMP | BPF_JSLE | BPF_X:
-		emit(A64_CMP(1, dst, src), ctx);
+	case BPF_JMP32 | BPF_JEQ | BPF_X:
+	case BPF_JMP32 | BPF_JGT | BPF_X:
+	case BPF_JMP32 | BPF_JLT | BPF_X:
+	case BPF_JMP32 | BPF_JGE | BPF_X:
+	case BPF_JMP32 | BPF_JLE | BPF_X:
+	case BPF_JMP32 | BPF_JNE | BPF_X:
+	case BPF_JMP32 | BPF_JSGT | BPF_X:
+	case BPF_JMP32 | BPF_JSLT | BPF_X:
+	case BPF_JMP32 | BPF_JSGE | BPF_X:
+	case BPF_JMP32 | BPF_JSLE | BPF_X:
+		emit(A64_CMP(is64, dst, src), ctx);
 emit_cond_jmp:
 		jmp_offset = bpf2a64_offset(i + off, i, ctx);
 		check_imm19(jmp_offset);
@@ -601,7 +612,8 @@ static int build_insn(const struct bpf_insn *insn, struct jit_ctx *ctx,
 		emit(A64_B_(jmp_cond, jmp_offset), ctx);
 		break;
 	case BPF_JMP | BPF_JSET | BPF_X:
-		emit(A64_TST(1, dst, src), ctx);
+	case BPF_JMP32 | BPF_JSET | BPF_X:
+		emit(A64_TST(is64, dst, src), ctx);
 		goto emit_cond_jmp;
 	/* IF (dst COND imm) JUMP off */
 	case BPF_JMP | BPF_JEQ | BPF_K:
@@ -614,12 +626,23 @@ static int build_insn(const struct bpf_insn *insn, struct jit_ctx *ctx,
 	case BPF_JMP | BPF_JSLT | BPF_K:
 	case BPF_JMP | BPF_JSGE | BPF_K:
 	case BPF_JMP | BPF_JSLE | BPF_K:
-		emit_a64_mov_i(1, tmp, imm, ctx);
-		emit(A64_CMP(1, dst, tmp), ctx);
+	case BPF_JMP32 | BPF_JEQ | BPF_K:
+	case BPF_JMP32 | BPF_JGT | BPF_K:
+	case BPF_JMP32 | BPF_JLT | BPF_K:
+	case BPF_JMP32 | BPF_JGE | BPF_K:
+	case BPF_JMP32 | BPF_JLE | BPF_K:
+	case BPF_JMP32 | BPF_JNE | BPF_K:
+	case BPF_JMP32 | BPF_JSGT | BPF_K:
+	case BPF_JMP32 | BPF_JSLT | BPF_K:
+	case BPF_JMP32 | BPF_JSGE | BPF_K:
+	case BPF_JMP32 | BPF_JSLE | BPF_K:
+		emit_a64_mov_i(is64, tmp, imm, ctx);
+		emit(A64_CMP(is64, dst, tmp), ctx);
 		goto emit_cond_jmp;
 	case BPF_JMP | BPF_JSET | BPF_K:
-		emit_a64_mov_i(1, tmp, imm, ctx);
-		emit(A64_TST(1, dst, tmp), ctx);
+	case BPF_JMP32 | BPF_JSET | BPF_K:
+		emit_a64_mov_i(is64, tmp, imm, ctx);
+		emit(A64_TST(is64, dst, tmp), ctx);
 		goto emit_cond_jmp;
 	/* function call */
 	case BPF_JMP | BPF_CALL: