diff --git a/include/net/tls.h b/include/net/tls.h
index a93a8ed8f71697a128c1b201b37e45281505697a..a8b37226a28795aa2e961791397d5304de3ee751 100644
--- a/include/net/tls.h
+++ b/include/net/tls.h
@@ -199,15 +199,8 @@ enum {
 };
 
 struct cipher_context {
-	u16 prepend_size;
-	u16 tag_size;
-	u16 overhead_size;
-	u16 iv_size;
 	char *iv;
-	u16 rec_seq_size;
 	char *rec_seq;
-	u16 aad_size;
-	u16 tail_size;
 };
 
 union tls_crypto_context {
@@ -218,7 +211,21 @@ union tls_crypto_context {
 	};
 };
 
+struct tls_prot_info {
+	u16 version;
+	u16 cipher_type;
+	u16 prepend_size;
+	u16 tag_size;
+	u16 overhead_size;
+	u16 iv_size;
+	u16 rec_seq_size;
+	u16 aad_size;
+	u16 tail_size;
+};
+
 struct tls_context {
+	struct tls_prot_info prot_info;
+
 	union tls_crypto_context crypto_send;
 	union tls_crypto_context crypto_recv;
 
@@ -401,16 +408,26 @@ static inline bool tls_bigint_increment(unsigned char *seq, int len)
 	return (i == -1);
 }
 
+static inline struct tls_context *tls_get_ctx(const struct sock *sk)
+{
+	struct inet_connection_sock *icsk = inet_csk(sk);
+
+	return icsk->icsk_ulp_data;
+}
+
 static inline void tls_advance_record_sn(struct sock *sk,
 					 struct cipher_context *ctx,
 					 int version)
 {
-	if (tls_bigint_increment(ctx->rec_seq, ctx->rec_seq_size))
+	struct tls_context *tls_ctx = tls_get_ctx(sk);
+	struct tls_prot_info *prot = &tls_ctx->prot_info;
+
+	if (tls_bigint_increment(ctx->rec_seq, prot->rec_seq_size))
 		tls_err_abort(sk, EBADMSG);
 
 	if (version != TLS_1_3_VERSION) {
 		tls_bigint_increment(ctx->iv + TLS_CIPHER_AES_GCM_128_SALT_SIZE,
-				     ctx->iv_size);
+				     prot->iv_size);
 	}
 }
 
@@ -420,9 +437,10 @@ static inline void tls_fill_prepend(struct tls_context *ctx,
 			     unsigned char record_type,
 			     int version)
 {
-	size_t pkt_len, iv_size = ctx->tx.iv_size;
+	struct tls_prot_info *prot = &ctx->prot_info;
+	size_t pkt_len, iv_size = prot->iv_size;
 
-	pkt_len = plaintext_len + ctx->tx.tag_size;
+	pkt_len = plaintext_len + prot->tag_size;
 	if (version != TLS_1_3_VERSION) {
 		pkt_len += iv_size;
 
@@ -475,12 +493,6 @@ static inline void xor_iv_with_seq(int version, char *iv, char *seq)
 	}
 }
 
-static inline struct tls_context *tls_get_ctx(const struct sock *sk)
-{
-	struct inet_connection_sock *icsk = inet_csk(sk);
-
-	return icsk->icsk_ulp_data;
-}
 
 static inline struct tls_sw_context_rx *tls_sw_ctx_rx(
 		const struct tls_context *tls_ctx)
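The tls.h change above is the core of the patch: the size and version
constants that used to live per direction in struct cipher_context are
consolidated into a single struct tls_prot_info hung off tls_context,
and tls_get_ctx() moves up so that tls_advance_record_sn() and the other
inline helpers can reach it. For reference, a minimal userspace sketch
(not kernel code; fill_prot() is an illustrative stand-in) of the sizing
rules that tls_set_sw_offload() applies further down, assuming
AES-GCM-128:

#include <stdio.h>

#define TLS_HEADER_SIZE    5
#define TLS_AAD_SPACE_SIZE 13
#define TLS_1_2_VERSION    0x0303
#define TLS_1_3_VERSION    0x0304

struct tls_prot_info {
	unsigned short version, cipher_type;
	unsigned short prepend_size, tag_size, overhead_size;
	unsigned short iv_size, rec_seq_size, aad_size, tail_size;
};

static void fill_prot(struct tls_prot_info *prot, unsigned short version)
{
	/* AES-GCM-128: 8-byte explicit nonce, 16-byte tag, 8-byte seq */
	unsigned short nonce_size = 8;

	if (version == TLS_1_3_VERSION) {
		nonce_size = 0;                   /* nonce is implicit */
		prot->aad_size = TLS_HEADER_SIZE; /* AAD = record header */
		prot->tail_size = 1;              /* trailing content type */
	} else {
		prot->aad_size = TLS_AAD_SPACE_SIZE;
		prot->tail_size = 0;
	}
	prot->version = version;
	prot->prepend_size = TLS_HEADER_SIZE + nonce_size;
	prot->tag_size = 16;
	prot->overhead_size = prot->prepend_size + prot->tag_size +
			      prot->tail_size;
	prot->iv_size = 8;
	prot->rec_seq_size = 8;
}

int main(void)
{
	struct tls_prot_info prot;

	fill_prot(&prot, TLS_1_2_VERSION);
	printf("TLS 1.2 overhead: %d\n", (int)prot.overhead_size); /* 29 */
	fill_prot(&prot, TLS_1_3_VERSION);
	printf("TLS 1.3 overhead: %d\n", (int)prot.overhead_size); /* 22 */
	return 0;
}

With a single copy of these constants, TX and RX can no longer drift
apart, which is what makes the symmetry check added in tls_main.c below
sufficient.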
diff --git a/net/tls/tls_device.c b/net/tls/tls_device.c
index 7ee9008b2187d8135fc0bda7b90e150e82501527..a5c17c47d08a8d97c4bc7bc8c86c8f96939402ea 100644
--- a/net/tls/tls_device.c
+++ b/net/tls/tls_device.c
@@ -247,6 +247,7 @@ static int tls_push_record(struct sock *sk,
 			   int flags,
 			   unsigned char record_type)
 {
+	struct tls_prot_info *prot = &ctx->prot_info;
 	struct tcp_sock *tp = tcp_sk(sk);
 	struct page_frag dummy_tag_frag;
 	skb_frag_t *frag;
@@ -256,7 +257,7 @@ static int tls_push_record(struct sock *sk,
 	frag = &record->frags[0];
 	tls_fill_prepend(ctx,
 			 skb_frag_address(frag),
-			 record->len - ctx->tx.prepend_size,
+			 record->len - prot->prepend_size,
 			 record_type,
 			 ctx->crypto_send.info.version);
 
@@ -264,7 +265,7 @@ static int tls_push_record(struct sock *sk,
 	dummy_tag_frag.page = skb_frag_page(frag);
 	dummy_tag_frag.offset = 0;
 
-	tls_append_frag(record, &dummy_tag_frag, ctx->tx.tag_size);
+	tls_append_frag(record, &dummy_tag_frag, prot->tag_size);
 	record->end_seq = tp->write_seq + record->len;
 	spin_lock_irq(&offload_ctx->lock);
 	list_add_tail(&record->list, &offload_ctx->records_list);
@@ -347,6 +348,7 @@ static int tls_push_data(struct sock *sk,
 			 unsigned char record_type)
 {
 	struct tls_context *tls_ctx = tls_get_ctx(sk);
+	struct tls_prot_info *prot = &tls_ctx->prot_info;
 	struct tls_offload_context_tx *ctx = tls_offload_ctx_tx(tls_ctx);
 	int tls_push_record_flags = flags | MSG_SENDPAGE_NOTLAST;
 	int more = flags & (MSG_SENDPAGE_NOTLAST | MSG_MORE);
@@ -376,10 +378,10 @@ static int tls_push_data(struct sock *sk,
 	 * we need to leave room for an authentication tag.
 	 */
 	max_open_record_len = TLS_MAX_PAYLOAD_SIZE +
-			      tls_ctx->tx.prepend_size;
+			      prot->prepend_size;
 	do {
 		rc = tls_do_allocation(sk, ctx, pfrag,
-				       tls_ctx->tx.prepend_size);
+				       prot->prepend_size);
 		if (rc) {
 			rc = sk_stream_wait_memory(sk, &timeo);
 			if (!rc)
@@ -397,7 +399,7 @@ static int tls_push_data(struct sock *sk,
 				size = orig_size;
 				destroy_record(record);
 				ctx->open_record = NULL;
-			} else if (record->len > tls_ctx->tx.prepend_size) {
+			} else if (record->len > prot->prepend_size) {
 				goto last_record;
 			}
 
@@ -658,6 +660,8 @@ int tls_device_decrypted(struct sock *sk, struct sk_buff *skb)
 int tls_set_device_offload(struct sock *sk, struct tls_context *ctx)
 {
 	u16 nonce_size, tag_size, iv_size, rec_seq_size;
+	struct tls_context *tls_ctx = tls_get_ctx(sk);
+	struct tls_prot_info *prot = &tls_ctx->prot_info;
 	struct tls_record_info *start_marker_record;
 	struct tls_offload_context_tx *offload_ctx;
 	struct tls_crypto_info *crypto_info;
@@ -703,10 +707,10 @@ int tls_set_device_offload(struct sock *sk, struct tls_context *ctx)
 		goto free_offload_ctx;
 	}
 
-	ctx->tx.prepend_size = TLS_HEADER_SIZE + nonce_size;
-	ctx->tx.tag_size = tag_size;
-	ctx->tx.overhead_size = ctx->tx.prepend_size + ctx->tx.tag_size;
-	ctx->tx.iv_size = iv_size;
+	prot->prepend_size = TLS_HEADER_SIZE + nonce_size;
+	prot->tag_size = tag_size;
+	prot->overhead_size = prot->prepend_size + prot->tag_size;
+	prot->iv_size = iv_size;
 	ctx->tx.iv = kmalloc(iv_size + TLS_CIPHER_AES_GCM_128_SALT_SIZE,
 			     GFP_KERNEL);
 	if (!ctx->tx.iv) {
@@ -716,7 +720,7 @@ int tls_set_device_offload(struct sock *sk, struct tls_context *ctx)
 
 	memcpy(ctx->tx.iv + TLS_CIPHER_AES_GCM_128_SALT_SIZE, iv, iv_size);
 
-	ctx->tx.rec_seq_size = rec_seq_size;
+	prot->rec_seq_size = rec_seq_size;
 	ctx->tx.rec_seq = kmemdup(rec_seq, rec_seq_size, GFP_KERNEL);
 	if (!ctx->tx.rec_seq) {
 		rc = -ENOMEM;
diff --git a/net/tls/tls_main.c b/net/tls/tls_main.c
index d1c2fd9a3f63e09fcd361dc2949a59c9c645a35f..caff15b2f9b263b665157ea048eab494049aa1d8 100644
--- a/net/tls/tls_main.c
+++ b/net/tls/tls_main.c
@@ -435,6 +435,7 @@ static int do_tls_setsockopt_conf(struct sock *sk, char __user *optval,
 				  unsigned int optlen, int tx)
 {
 	struct tls_crypto_info *crypto_info;
+	struct tls_crypto_info *alt_crypto_info;
 	struct tls_context *ctx = tls_get_ctx(sk);
 	size_t optsize;
 	int rc = 0;
@@ -445,10 +446,13 @@ static int do_tls_setsockopt_conf(struct sock *sk, char __user *optval,
 		goto out;
 	}
 
-	if (tx)
+	if (tx) {
 		crypto_info = &ctx->crypto_send.info;
-	else
+		alt_crypto_info = &ctx->crypto_recv.info;
+	} else {
 		crypto_info = &ctx->crypto_recv.info;
+		alt_crypto_info = &ctx->crypto_send.info;
+	}
 
 	/* Currently we don't support set crypto info more than one time */
 	if (TLS_CRYPTO_INFO_READY(crypto_info)) {
@@ -469,6 +473,15 @@ static int do_tls_setsockopt_conf(struct sock *sk, char __user *optval,
 		goto err_crypto_info;
 	}
 
+	/* Ensure that TLS version and ciphers are the same in both directions */
+	if (TLS_CRYPTO_INFO_READY(alt_crypto_info)) {
+		if (alt_crypto_info->version != crypto_info->version ||
+		    alt_crypto_info->cipher_type != crypto_info->cipher_type) {
+			rc = -EINVAL;
+			goto err_crypto_info;
+		}
+	}
+
 	switch (crypto_info->cipher_type) {
 	case TLS_CIPHER_AES_GCM_128:
 	case TLS_CIPHER_AES_GCM_256: {
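The net effect of the alt_crypto_info check is that the second
setsockopt() of a TX/RX pair must carry the same TLS version and cipher
as the first. A hedged userspace sketch of the expected behavior
(set_dir() and demo() are illustrative helpers; fd is an established
TCP socket with the "tls" ULP already attached, and the zeroed key
material is for illustration only):

#include <errno.h>
#include <stdio.h>
#include <string.h>
#include <sys/socket.h>
#include <linux/tls.h>

#ifndef SOL_TLS
#define SOL_TLS 282
#endif

static int set_dir(int fd, int dir, unsigned short version)
{
	struct tls12_crypto_info_aes_gcm_128 ci;

	memset(&ci, 0, sizeof(ci));	/* key/IV/salt/seq left zeroed */
	ci.info.version = version;
	ci.info.cipher_type = TLS_CIPHER_AES_GCM_128;
	return setsockopt(fd, SOL_TLS, dir, &ci, sizeof(ci));
}

int demo(int fd)
{
	if (set_dir(fd, TLS_TX, TLS_1_2_VERSION))
		return -1;
	/* Mismatched version on RX should now fail with EINVAL */
	if (set_dir(fd, TLS_RX, TLS_1_3_VERSION) && errno == EINVAL)
		printf("RX rejected as expected: %s\n", strerror(errno));
	return 0;
}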
diff --git a/net/tls/tls_sw.c b/net/tls/tls_sw.c
index ae478473454779a098967f432db0d048773c8dc9..71be8acfbc9b382c53d07b0f677636b5f34ef24c 100644
--- a/net/tls/tls_sw.c
+++ b/net/tls/tls_sw.c
@@ -127,7 +127,7 @@ static int padding_length(struct tls_sw_context_rx *ctx,
 	int sub = 0;
 
 	/* Determine zero-padding length */
-	if (tls_ctx->crypto_recv.info.version == TLS_1_3_VERSION) {
+	if (tls_ctx->prot_info.version == TLS_1_3_VERSION) {
 		char content_type = 0;
 		int err;
 		int back = 17;
@@ -155,6 +155,7 @@ static void tls_decrypt_done(struct crypto_async_request *req, int err)
 	struct scatterlist *sgin = aead_req->src;
 	struct tls_sw_context_rx *ctx;
 	struct tls_context *tls_ctx;
+	struct tls_prot_info *prot;
 	struct scatterlist *sg;
 	struct sk_buff *skb;
 	unsigned int pages;
@@ -163,6 +164,7 @@ static void tls_decrypt_done(struct crypto_async_request *req, int err)
 	skb = (struct sk_buff *)req->data;
 	tls_ctx = tls_get_ctx(skb->sk);
 	ctx = tls_sw_ctx_rx(tls_ctx);
+	prot = &tls_ctx->prot_info;
 
 	/* Propagate if there was an err */
 	if (err) {
@@ -171,8 +173,8 @@ static void tls_decrypt_done(struct crypto_async_request *req, int err)
 	} else {
 		struct strp_msg *rxm = strp_msg(skb);
 		rxm->full_len -= padding_length(ctx, tls_ctx, skb);
-		rxm->offset += tls_ctx->rx.prepend_size;
-		rxm->full_len -= tls_ctx->rx.overhead_size;
+		rxm->offset += prot->prepend_size;
+		rxm->full_len -= prot->overhead_size;
 	}
 
 	/* After using skb->sk to propagate sk through crypto async callback
@@ -209,13 +211,14 @@ static int tls_do_decryption(struct sock *sk,
 			     bool async)
 {
 	struct tls_context *tls_ctx = tls_get_ctx(sk);
+	struct tls_prot_info *prot = &tls_ctx->prot_info;
 	struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
 	int ret;
 
 	aead_request_set_tfm(aead_req, ctx->aead_recv);
-	aead_request_set_ad(aead_req, tls_ctx->rx.aad_size);
+	aead_request_set_ad(aead_req, prot->aad_size);
 	aead_request_set_crypt(aead_req, sgin, sgout,
-			       data_len + tls_ctx->rx.tag_size,
+			       data_len + prot->tag_size,
 			       (u8 *)iv_recv);
 
 	if (async) {
@@ -253,12 +256,13 @@ static int tls_do_decryption(struct sock *sk,
 static void tls_trim_both_msgs(struct sock *sk, int target_size)
 {
 	struct tls_context *tls_ctx = tls_get_ctx(sk);
+	struct tls_prot_info *prot = &tls_ctx->prot_info;
 	struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
 	struct tls_rec *rec = ctx->open_rec;
 
 	sk_msg_trim(sk, &rec->msg_plaintext, target_size);
 	if (target_size > 0)
-		target_size += tls_ctx->tx.overhead_size;
+		target_size += prot->overhead_size;
 	sk_msg_trim(sk, &rec->msg_encrypted, target_size);
 }
 
@@ -275,6 +279,7 @@ static int tls_alloc_encrypted_msg(struct sock *sk, int len)
 static int tls_clone_plaintext_msg(struct sock *sk, int required)
 {
 	struct tls_context *tls_ctx = tls_get_ctx(sk);
+	struct tls_prot_info *prot = &tls_ctx->prot_info;
 	struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
 	struct tls_rec *rec = ctx->open_rec;
 	struct sk_msg *msg_pl = &rec->msg_plaintext;
@@ -290,7 +295,7 @@ static int tls_clone_plaintext_msg(struct sock *sk, int required)
 	/* Skip initial bytes in msg_en's data to be able to use
 	 * same offset of both plain and encrypted data.
 	 */
-	skip = tls_ctx->tx.prepend_size + msg_pl->sg.size;
+	skip = prot->prepend_size + msg_pl->sg.size;
 
 	return sk_msg_clone(sk, msg_pl, msg_en, skip, len);
 }
@@ -298,6 +303,7 @@ static int tls_clone_plaintext_msg(struct sock *sk, int required)
 static struct tls_rec *tls_get_rec(struct sock *sk)
 {
 	struct tls_context *tls_ctx = tls_get_ctx(sk);
+	struct tls_prot_info *prot = &tls_ctx->prot_info;
 	struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
 	struct sk_msg *msg_pl, *msg_en;
 	struct tls_rec *rec;
@@ -316,13 +322,11 @@ static struct tls_rec *tls_get_rec(struct sock *sk)
 	sk_msg_init(msg_en);
 
 	sg_init_table(rec->sg_aead_in, 2);
-	sg_set_buf(&rec->sg_aead_in[0], rec->aad_space,
-		   tls_ctx->tx.aad_size);
+	sg_set_buf(&rec->sg_aead_in[0], rec->aad_space, prot->aad_size);
 	sg_unmark_end(&rec->sg_aead_in[1]);
 
 	sg_init_table(rec->sg_aead_out, 2);
-	sg_set_buf(&rec->sg_aead_out[0], rec->aad_space,
-		   tls_ctx->tx.aad_size);
+	sg_set_buf(&rec->sg_aead_out[0], rec->aad_space, prot->aad_size);
 	sg_unmark_end(&rec->sg_aead_out[1]);
 
 	return rec;
@@ -411,6 +415,7 @@ static void tls_encrypt_done(struct crypto_async_request *req, int err)
 	struct aead_request *aead_req = (struct aead_request *)req;
 	struct sock *sk = req->data;
 	struct tls_context *tls_ctx = tls_get_ctx(sk);
+	struct tls_prot_info *prot = &tls_ctx->prot_info;
 	struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
 	struct scatterlist *sge;
 	struct sk_msg *msg_en;
@@ -422,8 +427,8 @@ static void tls_encrypt_done(struct crypto_async_request *req, int err)
 	msg_en = &rec->msg_encrypted;
 
 	sge = sk_msg_elem(msg_en, msg_en->sg.curr);
-	sge->offset -= tls_ctx->tx.prepend_size;
-	sge->length += tls_ctx->tx.prepend_size;
+	sge->offset -= prot->prepend_size;
+	sge->length += prot->prepend_size;
 
 	/* Check if error is previously set on socket */
 	if (err || sk->sk_err) {
@@ -470,22 +475,23 @@ static int tls_do_encryption(struct sock *sk,
 			     struct aead_request *aead_req,
 			     size_t data_len, u32 start)
 {
+	struct tls_prot_info *prot = &tls_ctx->prot_info;
 	struct tls_rec *rec = ctx->open_rec;
 	struct sk_msg *msg_en = &rec->msg_encrypted;
 	struct scatterlist *sge = sk_msg_elem(msg_en, start);
 	int rc;
 
 	memcpy(rec->iv_data, tls_ctx->tx.iv, sizeof(rec->iv_data));
-	xor_iv_with_seq(tls_ctx->crypto_send.info.version, rec->iv_data,
+	xor_iv_with_seq(prot->version, rec->iv_data,
 			tls_ctx->tx.rec_seq);
 
-	sge->offset += tls_ctx->tx.prepend_size;
-	sge->length -= tls_ctx->tx.prepend_size;
+	sge->offset += prot->prepend_size;
+	sge->length -= prot->prepend_size;
 
 	msg_en->sg.curr = start;
 
 	aead_request_set_tfm(aead_req, ctx->aead_send);
-	aead_request_set_ad(aead_req, tls_ctx->tx.aad_size);
+	aead_request_set_ad(aead_req, prot->aad_size);
 	aead_request_set_crypt(aead_req, rec->sg_aead_in,
 			       rec->sg_aead_out,
 			       data_len, rec->iv_data);
@@ -500,8 +506,8 @@ static int tls_do_encryption(struct sock *sk,
 	rc = crypto_aead_encrypt(aead_req);
 	if (!rc || rc != -EINPROGRESS) {
 		atomic_dec(&ctx->encrypt_pending);
-		sge->offset -= tls_ctx->tx.prepend_size;
-		sge->length += tls_ctx->tx.prepend_size;
+		sge->offset -= prot->prepend_size;
+		sge->length += prot->prepend_size;
 	}
 
 	if (!rc) {
@@ -513,8 +519,7 @@ static int tls_do_encryption(struct sock *sk,
 
 	/* Unhook the record from context if encryption is not failure */
 	ctx->open_rec = NULL;
-	tls_advance_record_sn(sk, &tls_ctx->tx,
-			      tls_ctx->crypto_send.info.version);
+	tls_advance_record_sn(sk, &tls_ctx->tx, prot->version);
 	return rc;
 }
 
@@ -640,6 +645,7 @@ static int tls_push_record(struct sock *sk, int flags,
 			   unsigned char record_type)
 {
 	struct tls_context *tls_ctx = tls_get_ctx(sk);
+	struct tls_prot_info *prot = &tls_ctx->prot_info;
 	struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
 	struct tls_rec *rec = ctx->open_rec, *tmp = NULL;
 	u32 i, split_point, uninitialized_var(orig_end);
@@ -658,12 +664,12 @@ static int tls_push_record(struct sock *sk, int flags,
 	split = split_point && split_point < msg_pl->sg.size;
 	if (split) {
 		rc = tls_split_open_record(sk, rec, &tmp, msg_pl, msg_en,
-					   split_point, tls_ctx->tx.overhead_size,
+					   split_point, prot->overhead_size,
 					   &orig_end);
 		if (rc < 0)
 			return rc;
 		sk_msg_trim(sk, msg_en, msg_pl->sg.size +
-			    tls_ctx->tx.overhead_size);
+			    prot->overhead_size);
 	}
 
 	rec->tx_flags = flags;
@@ -673,7 +679,7 @@ static int tls_push_record(struct sock *sk, int flags,
 	sk_msg_iter_var_prev(i);
 
 	rec->content_type = record_type;
-	if (tls_ctx->crypto_send.info.version == TLS_1_3_VERSION) {
+	if (prot->version == TLS_1_3_VERSION) {
 		/* Add content type to end of message.  No padding added */
 		sg_set_buf(&rec->sg_content_type, &rec->content_type, 1);
 		sg_mark_end(&rec->sg_content_type);
@@ -694,22 +700,20 @@ static int tls_push_record(struct sock *sk, int flags,
 	i = msg_en->sg.start;
 	sg_chain(rec->sg_aead_out, 2, &msg_en->sg.data[i]);
 
-	tls_make_aad(rec->aad_space, msg_pl->sg.size + tls_ctx->tx.tail_size,
-		     tls_ctx->tx.rec_seq, tls_ctx->tx.rec_seq_size,
-		     record_type,
-		     tls_ctx->crypto_send.info.version);
+	tls_make_aad(rec->aad_space, msg_pl->sg.size + prot->tail_size,
+		     tls_ctx->tx.rec_seq, prot->rec_seq_size,
+		     record_type, prot->version);
 
 	tls_fill_prepend(tls_ctx,
 			 page_address(sg_page(&msg_en->sg.data[i])) +
 			 msg_en->sg.data[i].offset,
-			 msg_pl->sg.size + tls_ctx->tx.tail_size,
-			 record_type,
-			 tls_ctx->crypto_send.info.version);
+			 msg_pl->sg.size + prot->tail_size,
+			 record_type, prot->version);
 
 	tls_ctx->pending_open_record_frags = false;
 
 	rc = tls_do_encryption(sk, tls_ctx, ctx, req,
-			       msg_pl->sg.size + tls_ctx->tx.tail_size, i);
+			       msg_pl->sg.size + prot->tail_size, i);
 	if (rc < 0) {
 		if (rc != -EINPROGRESS) {
 			tls_err_abort(sk, EBADMSG);
@@ -723,8 +727,7 @@ static int tls_push_record(struct sock *sk, int flags,
 	} else if (split) {
 		msg_pl = &tmp->msg_plaintext;
 		msg_en = &tmp->msg_encrypted;
-		sk_msg_trim(sk, msg_en, msg_pl->sg.size +
-			    tls_ctx->tx.overhead_size);
+		sk_msg_trim(sk, msg_en, msg_pl->sg.size + prot->overhead_size);
 		tls_ctx->pending_open_record_frags = true;
 		ctx->open_rec = tmp;
 	}
@@ -859,6 +862,7 @@ int tls_sw_sendmsg(struct sock *sk, struct msghdr *msg, size_t size)
 {
 	long timeo = sock_sndtimeo(sk, msg->msg_flags & MSG_DONTWAIT);
 	struct tls_context *tls_ctx = tls_get_ctx(sk);
+	struct tls_prot_info *prot = &tls_ctx->prot_info;
 	struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
 	bool async_capable = ctx->async_capable;
 	unsigned char record_type = TLS_RECORD_TYPE_DATA;
@@ -925,7 +929,7 @@ int tls_sw_sendmsg(struct sock *sk, struct msghdr *msg, size_t size)
 		}
 
 		required_size = msg_pl->sg.size + try_to_copy +
-				tls_ctx->tx.overhead_size;
+				prot->overhead_size;
 
 		if (!sk_stream_memory_free(sk))
 			goto wait_for_sndbuf;
@@ -994,8 +998,8 @@ int tls_sw_sendmsg(struct sock *sk, struct msghdr *msg, size_t size)
 			 */
 			try_to_copy -= required_size - msg_pl->sg.size;
 			full_record = true;
-			sk_msg_trim(sk, msg_en, msg_pl->sg.size +
-				    tls_ctx->tx.overhead_size);
+			sk_msg_trim(sk, msg_en,
+				    msg_pl->sg.size + prot->overhead_size);
 		}
 
 		if (try_to_copy) {
@@ -1081,6 +1085,7 @@ static int tls_sw_do_sendpage(struct sock *sk, struct page *page,
 	long timeo = sock_sndtimeo(sk, flags & MSG_DONTWAIT);
 	struct tls_context *tls_ctx = tls_get_ctx(sk);
 	struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
+	struct tls_prot_info *prot = &tls_ctx->prot_info;
 	unsigned char record_type = TLS_RECORD_TYPE_DATA;
 	struct sk_msg *msg_pl;
 	struct tls_rec *rec;
@@ -1130,8 +1135,7 @@ static int tls_sw_do_sendpage(struct sock *sk, struct page *page,
 			full_record = true;
 		}
 
-		required_size = msg_pl->sg.size + copy +
-				tls_ctx->tx.overhead_size;
+		required_size = msg_pl->sg.size + copy + prot->overhead_size;
 
 		if (!sk_stream_memory_free(sk))
 			goto wait_for_sndbuf;
@@ -1330,6 +1334,7 @@ static int decrypt_internal(struct sock *sk, struct sk_buff *skb,
 {
 	struct tls_context *tls_ctx = tls_get_ctx(sk);
 	struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
+	struct tls_prot_info *prot = &tls_ctx->prot_info;
 	struct strp_msg *rxm = strp_msg(skb);
 	int n_sgin, n_sgout, nsg, mem_size, aead_size, err, pages = 0;
 	struct aead_request *aead_req;
@@ -1337,16 +1342,16 @@ static int decrypt_internal(struct sock *sk, struct sk_buff *skb,
 	u8 *aad, *iv, *mem = NULL;
 	struct scatterlist *sgin = NULL;
 	struct scatterlist *sgout = NULL;
-	const int data_len = rxm->full_len - tls_ctx->rx.overhead_size +
-		tls_ctx->rx.tail_size;
+	const int data_len = rxm->full_len - prot->overhead_size +
+			     prot->tail_size;
 
 	if (*zc && (out_iov || out_sg)) {
 		if (out_iov)
 			n_sgout = iov_iter_npages(out_iov, INT_MAX) + 1;
 		else
 			n_sgout = sg_nents(out_sg);
-		n_sgin = skb_nsg(skb, rxm->offset + tls_ctx->rx.prepend_size,
-				 rxm->full_len - tls_ctx->rx.prepend_size);
+		n_sgin = skb_nsg(skb, rxm->offset + prot->prepend_size,
+				 rxm->full_len - prot->prepend_size);
 	} else {
 		n_sgout = 0;
 		*zc = false;
@@ -1363,7 +1368,7 @@ static int decrypt_internal(struct sock *sk, struct sk_buff *skb,
 
 	aead_size = sizeof(*aead_req) + crypto_aead_reqsize(ctx->aead_recv);
 	mem_size = aead_size + (nsg * sizeof(struct scatterlist));
-	mem_size = mem_size + tls_ctx->rx.aad_size;
+	mem_size = mem_size + prot->aad_size;
 	mem_size = mem_size + crypto_aead_ivsize(ctx->aead_recv);
 
 	/* Allocate a single block of memory which contains
@@ -1379,37 +1384,35 @@ static int decrypt_internal(struct sock *sk, struct sk_buff *skb,
 	sgin = (struct scatterlist *)(mem + aead_size);
 	sgout = sgin + n_sgin;
 	aad = (u8 *)(sgout + n_sgout);
-	iv = aad + tls_ctx->rx.aad_size;
+	iv = aad + prot->aad_size;
 
 	/* Prepare IV */
 	err = skb_copy_bits(skb, rxm->offset + TLS_HEADER_SIZE,
 			    iv + TLS_CIPHER_AES_GCM_128_SALT_SIZE,
-			    tls_ctx->rx.iv_size);
+			    prot->iv_size);
 	if (err < 0) {
 		kfree(mem);
 		return err;
 	}
-	if (tls_ctx->crypto_recv.info.version == TLS_1_3_VERSION)
+	if (prot->version == TLS_1_3_VERSION)
 		memcpy(iv, tls_ctx->rx.iv, crypto_aead_ivsize(ctx->aead_recv));
 	else
 		memcpy(iv, tls_ctx->rx.iv, TLS_CIPHER_AES_GCM_128_SALT_SIZE);
 
-	xor_iv_with_seq(tls_ctx->crypto_recv.info.version, iv,
-			tls_ctx->rx.rec_seq);
+	xor_iv_with_seq(prot->version, iv, tls_ctx->rx.rec_seq);
 
 	/* Prepare AAD */
-	tls_make_aad(aad, rxm->full_len - tls_ctx->rx.overhead_size +
-		     tls_ctx->rx.tail_size,
-		     tls_ctx->rx.rec_seq, tls_ctx->rx.rec_seq_size,
-		     ctx->control,
-		     tls_ctx->crypto_recv.info.version);
+	tls_make_aad(aad, rxm->full_len - prot->overhead_size +
+		     prot->tail_size,
+		     tls_ctx->rx.rec_seq, prot->rec_seq_size,
+		     ctx->control, prot->version);
 
 	/* Prepare sgin */
 	sg_init_table(sgin, n_sgin);
-	sg_set_buf(&sgin[0], aad, tls_ctx->rx.aad_size);
+	sg_set_buf(&sgin[0], aad, prot->aad_size);
 	err = skb_to_sgvec(skb, &sgin[1],
-			   rxm->offset + tls_ctx->rx.prepend_size,
-			   rxm->full_len - tls_ctx->rx.prepend_size);
+			   rxm->offset + prot->prepend_size,
+			   rxm->full_len - prot->prepend_size);
 	if (err < 0) {
 		kfree(mem);
 		return err;
@@ -1418,7 +1421,7 @@ static int decrypt_internal(struct sock *sk, struct sk_buff *skb,
 	if (n_sgout) {
 		if (out_iov) {
 			sg_init_table(sgout, n_sgout);
-			sg_set_buf(&sgout[0], aad, tls_ctx->rx.aad_size);
+			sg_set_buf(&sgout[0], aad, prot->aad_size);
 
 			*chunk = 0;
 			err = tls_setup_from_iter(sk, out_iov, data_len,
@@ -1459,7 +1462,8 @@ static int decrypt_skb_update(struct sock *sk, struct sk_buff *skb,
 {
 	struct tls_context *tls_ctx = tls_get_ctx(sk);
 	struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
-	int version = tls_ctx->crypto_recv.info.version;
+	struct tls_prot_info *prot = &tls_ctx->prot_info;
+	int version = prot->version;
 	struct strp_msg *rxm = strp_msg(skb);
 	int err = 0;
 
@@ -1480,8 +1484,8 @@ static int decrypt_skb_update(struct sock *sk, struct sk_buff *skb,
 
 		rxm->full_len -= padding_length(ctx, tls_ctx, skb);
 
-		rxm->offset += tls_ctx->rx.prepend_size;
-		rxm->full_len -= tls_ctx->rx.overhead_size;
+		rxm->offset += prot->prepend_size;
+		rxm->full_len -= prot->overhead_size;
 		tls_advance_record_sn(sk, &tls_ctx->rx, version);
 		ctx->decrypted = true;
 		ctx->saved_data_ready(sk);
@@ -1605,6 +1609,7 @@ int tls_sw_recvmsg(struct sock *sk,
 {
 	struct tls_context *tls_ctx = tls_get_ctx(sk);
 	struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
+	struct tls_prot_info *prot = &tls_ctx->prot_info;
 	struct sk_psock *psock;
 	unsigned char control = 0;
 	ssize_t decrypted = 0;
@@ -1667,11 +1672,11 @@ int tls_sw_recvmsg(struct sock *sk,
 
 		rxm = strp_msg(skb);
 
-		to_decrypt = rxm->full_len - tls_ctx->rx.overhead_size;
+		to_decrypt = rxm->full_len - prot->overhead_size;
 
 		if (to_decrypt <= len && !is_kvec && !is_peek &&
 		    ctx->control == TLS_RECORD_TYPE_DATA &&
-		    tls_ctx->crypto_recv.info.version != TLS_1_3_VERSION)
+		    prot->version != TLS_1_3_VERSION)
 			zc = true;
 
 		/* Do not use async mode if record is non-data */
@@ -1875,6 +1880,7 @@ static int tls_read_size(struct strparser *strp, struct sk_buff *skb)
 {
 	struct tls_context *tls_ctx = tls_get_ctx(strp->sk);
 	struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
+	struct tls_prot_info *prot = &tls_ctx->prot_info;
 	char header[TLS_HEADER_SIZE + MAX_IV_SIZE];
 	struct strp_msg *rxm = strp_msg(skb);
 	size_t cipher_overhead;
@@ -1882,17 +1888,17 @@ static int tls_read_size(struct strparser *strp, struct sk_buff *skb)
 	int ret;
 
 	/* Verify that we have a full TLS header, or wait for more data */
-	if (rxm->offset + tls_ctx->rx.prepend_size > skb->len)
+	if (rxm->offset + prot->prepend_size > skb->len)
 		return 0;
 
 	/* Sanity-check size of on-stack buffer. */
-	if (WARN_ON(tls_ctx->rx.prepend_size > sizeof(header))) {
+	if (WARN_ON(prot->prepend_size > sizeof(header))) {
 		ret = -EINVAL;
 		goto read_failure;
 	}
 
 	/* Linearize header to local buffer */
-	ret = skb_copy_bits(skb, rxm->offset, header, tls_ctx->rx.prepend_size);
+	ret = skb_copy_bits(skb, rxm->offset, header, prot->prepend_size);
 
 	if (ret < 0)
 		goto read_failure;
@@ -1901,12 +1907,12 @@ static int tls_read_size(struct strparser *strp, struct sk_buff *skb)
 
 	data_len = ((header[4] & 0xFF) | (header[3] << 8));
 
-	cipher_overhead = tls_ctx->rx.tag_size;
-	if (tls_ctx->crypto_recv.info.version != TLS_1_3_VERSION)
-		cipher_overhead += tls_ctx->rx.iv_size;
+	cipher_overhead = prot->tag_size;
+	if (prot->version != TLS_1_3_VERSION)
+		cipher_overhead += prot->iv_size;
 
 	if (data_len > TLS_MAX_PAYLOAD_SIZE + cipher_overhead +
-	    tls_ctx->rx.tail_size) {
+	    prot->tail_size) {
 		ret = -EMSGSIZE;
 		goto read_failure;
 	}
@@ -2066,6 +2072,8 @@ static void tx_work_handler(struct work_struct *work)
 
 int tls_set_sw_offload(struct sock *sk, struct tls_context *ctx, int tx)
 {
+	struct tls_context *tls_ctx = tls_get_ctx(sk);
+	struct tls_prot_info *prot = &tls_ctx->prot_info;
 	struct tls_crypto_info *crypto_info;
 	struct tls12_crypto_info_aes_gcm_128 *gcm_128_info;
 	struct tls12_crypto_info_aes_gcm_256 *gcm_256_info;
@@ -2171,18 +2179,20 @@ int tls_set_sw_offload(struct sock *sk, struct tls_context *ctx, int tx)
 
 	if (crypto_info->version == TLS_1_3_VERSION) {
 		nonce_size = 0;
-		cctx->aad_size = TLS_HEADER_SIZE;
-		cctx->tail_size = 1;
+		prot->aad_size = TLS_HEADER_SIZE;
+		prot->tail_size = 1;
 	} else {
-		cctx->aad_size = TLS_AAD_SPACE_SIZE;
-		cctx->tail_size = 0;
+		prot->aad_size = TLS_AAD_SPACE_SIZE;
+		prot->tail_size = 0;
 	}
 
-	cctx->prepend_size = TLS_HEADER_SIZE + nonce_size;
-	cctx->tag_size = tag_size;
-	cctx->overhead_size = cctx->prepend_size + cctx->tag_size +
-		cctx->tail_size;
-	cctx->iv_size = iv_size;
+	prot->version = crypto_info->version;
+	prot->cipher_type = crypto_info->cipher_type;
+	prot->prepend_size = TLS_HEADER_SIZE + nonce_size;
+	prot->tag_size = tag_size;
+	prot->overhead_size = prot->prepend_size +
+			      prot->tag_size + prot->tail_size;
+	prot->iv_size = iv_size;
 	cctx->iv = kmalloc(iv_size + TLS_CIPHER_AES_GCM_128_SALT_SIZE,
 			   GFP_KERNEL);
 	if (!cctx->iv) {
@@ -2192,7 +2202,7 @@ int tls_set_sw_offload(struct sock *sk, struct tls_context *ctx, int tx)
 	/* Note: 128 & 256 bit salt are the same size */
 	memcpy(cctx->iv, salt, TLS_CIPHER_AES_GCM_128_SALT_SIZE);
 	memcpy(cctx->iv + TLS_CIPHER_AES_GCM_128_SALT_SIZE, iv, iv_size);
-	cctx->rec_seq_size = rec_seq_size;
+	prot->rec_seq_size = rec_seq_size;
 	cctx->rec_seq = kmemdup(rec_seq, rec_seq_size, GFP_KERNEL);
 	if (!cctx->rec_seq) {
 		rc = -ENOMEM;
@@ -2215,7 +2225,7 @@ int tls_set_sw_offload(struct sock *sk, struct tls_context *ctx, int tx)
 	if (rc)
 		goto free_aead;
 
-	rc = crypto_aead_setauthsize(*aead, cctx->tag_size);
+	rc = crypto_aead_setauthsize(*aead, prot->tag_size);
 	if (rc)
 		goto free_aead;
 
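For scale, the tls_read_size() bound above works out as follows: with
TLS_MAX_PAYLOAD_SIZE = 16384 (2^14), a TLS 1.2 AES-GCM-128 record may
carry a data_len of up to 16384 + 16 (tag) + 8 (explicit IV) = 16408
bytes on the wire, while TLS 1.3 permits 16384 + 16 + 1 (content-type
byte, no explicit IV) = 16401; anything larger is rejected with
-EMSGSIZE before decryption is attempted.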