diff --git a/crypto/aead.c b/crypto/aead.c
index fe00cbd7243d028f05701993cf1b82bf23af25c4..60b3bbe973e752a879f39ac5ca5af712560ad888 100644
--- a/crypto/aead.c
+++ b/crypto/aead.c
@@ -54,11 +54,18 @@ int crypto_aead_setkey(struct crypto_aead *tfm,
 		       const u8 *key, unsigned int keylen)
 {
 	unsigned long alignmask = crypto_aead_alignmask(tfm);
+	int err;
 
 	if ((unsigned long)key & alignmask)
-		return setkey_unaligned(tfm, key, keylen);
+		err = setkey_unaligned(tfm, key, keylen);
+	else
+		err = crypto_aead_alg(tfm)->setkey(tfm, key, keylen);
+
+	if (err)
+		return err;
 
-	return crypto_aead_alg(tfm)->setkey(tfm, key, keylen);
+	crypto_aead_clear_flags(tfm, CRYPTO_TFM_NEED_KEY);
+	return 0;
 }
 EXPORT_SYMBOL_GPL(crypto_aead_setkey);
 
@@ -93,6 +100,8 @@ static int crypto_aead_init_tfm(struct crypto_tfm *tfm)
 	struct crypto_aead *aead = __crypto_aead_cast(tfm);
 	struct aead_alg *alg = crypto_aead_alg(aead);
 
+	crypto_aead_set_flags(aead, CRYPTO_TFM_NEED_KEY);
+
 	aead->authsize = alg->maxauthsize;
 
 	if (alg->exit)
diff --git a/crypto/algif_aead.c b/crypto/algif_aead.c
index d963c8cf8a552efa0d893edd69850fe70b092a0b..4b07edd5a9ff7265873ddaa700bb6aace08fadfd 100644
--- a/crypto/algif_aead.c
+++ b/crypto/algif_aead.c
@@ -42,7 +42,6 @@
 
 struct aead_tfm {
 	struct crypto_aead *aead;
-	bool has_key;
 	struct crypto_skcipher *null_tfm;
 };
 
@@ -398,7 +397,7 @@ static int aead_check_key(struct socket *sock)
 
 	err = -ENOKEY;
 	lock_sock_nested(psk, SINGLE_DEPTH_NESTING);
-	if (!tfm->has_key)
+	if (crypto_aead_get_flags(tfm->aead) & CRYPTO_TFM_NEED_KEY)
 		goto unlock;
 
 	if (!pask->refcnt++)
@@ -523,12 +522,8 @@ static int aead_setauthsize(void *private, unsigned int authsize)
 static int aead_setkey(void *private, const u8 *key, unsigned int keylen)
 {
 	struct aead_tfm *tfm = private;
-	int err;
-
-	err = crypto_aead_setkey(tfm->aead, key, keylen);
-	tfm->has_key = !err;
 
-	return err;
+	return crypto_aead_setkey(tfm->aead, key, keylen);
 }
 
 static void aead_sock_destruct(struct sock *sk)
@@ -589,7 +584,7 @@ static int aead_accept_parent(void *private, struct sock *sk)
 {
 	struct aead_tfm *tfm = private;
 
-	if (!tfm->has_key)
+	if (crypto_aead_get_flags(tfm->aead) & CRYPTO_TFM_NEED_KEY)
 		return -ENOKEY;
 
 	return aead_accept_parent_nokey(private, sk);
diff --git a/include/crypto/aead.h b/include/crypto/aead.h
index 03b97629442c183d923b0b7c3a68337702030fcd..1e26f790b03fa83864cd0ad46007a65ec09e8bf2 100644
--- a/include/crypto/aead.h
+++ b/include/crypto/aead.h
@@ -327,7 +327,12 @@ static inline struct crypto_aead *crypto_aead_reqtfm(struct aead_request *req)
  */
 static inline int crypto_aead_encrypt(struct aead_request *req)
 {
-	return crypto_aead_alg(crypto_aead_reqtfm(req))->encrypt(req);
+	struct crypto_aead *aead = crypto_aead_reqtfm(req);
+
+	if (crypto_aead_get_flags(aead) & CRYPTO_TFM_NEED_KEY)
+		return -ENOKEY;
+
+	return crypto_aead_alg(aead)->encrypt(req);
 }
 
 /**
@@ -356,6 +361,9 @@ static inline int crypto_aead_decrypt(struct aead_request *req)
 {
 	struct crypto_aead *aead = crypto_aead_reqtfm(req);
 
+	if (crypto_aead_get_flags(aead) & CRYPTO_TFM_NEED_KEY)
+		return -ENOKEY;
+
 	if (req->cryptlen < crypto_aead_authsize(aead))
 		return -EINVAL;