diff --git a/fs/afs/addr_list.c b/fs/afs/addr_list.c
index 967db336d11ae016324f4f15d7cbd33b809045c2..9eaff55df7b43ee490187b629856a4f35be6a46f 100644
--- a/fs/afs/addr_list.c
+++ b/fs/afs/addr_list.c
@@ -251,7 +251,7 @@ struct afs_vlserver_list *afs_dns_query(struct afs_cell *cell, time64_t *_expiry
 	_enter("%s", cell->name);
 
 	ret = dns_query("afsdb", cell->name, cell->name_len, "srv=1",
-			&result, _expiry);
+			&result, _expiry, true);
 	if (ret < 0) {
 		_leave(" = %d [dns]", ret);
 		return ERR_PTR(ret);
diff --git a/fs/afs/dynroot.c b/fs/afs/dynroot.c
index a9ba81ddf1546272d4a5cbb7e0885326c250c6ff..07484b5a3bbb837e1b2b58402e7e593d3b4db68e 100644
--- a/fs/afs/dynroot.c
+++ b/fs/afs/dynroot.c
@@ -46,7 +46,7 @@ static int afs_probe_cell_name(struct dentry *dentry)
 		return 0;
 	}
 
-	ret = dns_query("afsdb", name, len, "srv=1", NULL, NULL);
+	ret = dns_query("afsdb", name, len, "srv=1", NULL, NULL, false);
 	if (ret == -ENODATA)
 		ret = -EDESTADDRREQ;
 	return ret;
diff --git a/fs/cifs/dns_resolve.c b/fs/cifs/dns_resolve.c
index 7ede7306599f47449bdd02533db47bfede2c81df..1e21b2528cfb3e8397db38ab0d927ef047ab3e37 100644
--- a/fs/cifs/dns_resolve.c
+++ b/fs/cifs/dns_resolve.c
@@ -77,7 +77,7 @@ dns_resolve_server_name_to_ip(const char *unc, char **ip_addr)
 		goto name_is_IP_address;
 
 	/* Perform the upcall */
-	rc = dns_query(NULL, hostname, len, NULL, ip_addr, NULL);
+	rc = dns_query(NULL, hostname, len, NULL, ip_addr, NULL, false);
 	if (rc < 0)
 		cifs_dbg(FYI, "%s: unable to resolve: %*.*s\n",
 			 __func__, len, len, hostname);
diff --git a/fs/nfs/dns_resolve.c b/fs/nfs/dns_resolve.c
index a7d3df85736dfba49c461ed54a17b5db0d3bfa44..e6a700f014528088ef0e15260b9a5d4d95822a64 100644
--- a/fs/nfs/dns_resolve.c
+++ b/fs/nfs/dns_resolve.c
@@ -22,7 +22,7 @@ ssize_t nfs_dns_resolve_name(struct net *net, char *name, size_t namelen,
 	char *ip_addr = NULL;
 	int ip_len;
 
-	ip_len = dns_query(NULL, name, namelen, NULL, &ip_addr, NULL);
+	ip_len = dns_query(NULL, name, namelen, NULL, &ip_addr, NULL, false);
 	if (ip_len > 0)
 		ret = rpc_pton(net, ip_addr, ip_len, sa, salen);
 	else
diff --git a/include/linux/dns_resolver.h b/include/linux/dns_resolver.h
index 34a744a1bafcbc84c4e2992435f14b71610a4c88..f2b3ae22e6b77bf3ba37ed7b42abc94ca95222cb 100644
--- a/include/linux/dns_resolver.h
+++ b/include/linux/dns_resolver.h
@@ -27,6 +27,7 @@
 #include <uapi/linux/dns_resolver.h>
 
 extern int dns_query(const char *type, const char *name, size_t namelen,
-		     const char *options, char **_result, time64_t *_expiry);
+		     const char *options, char **_result, time64_t *_expiry,
+		     bool invalidate);
 
 #endif /* _LINUX_DNS_RESOLVER_H */
diff --git a/net/ceph/messenger.c b/net/ceph/messenger.c
index 3083988ce729dbe01771e9433b7de72e484394f9..579d6a1ac7fefb940cdc61e0eeeb16e9c812c9af 100644
--- a/net/ceph/messenger.c
+++ b/net/ceph/messenger.c
@@ -1889,7 +1889,7 @@ static int ceph_dns_resolve_name(const char *name, size_t namelen,
 		return -EINVAL;
 
 	/* do dns_resolve upcall */
-	ip_len = dns_query(NULL, name, end - name, NULL, &ip_addr, NULL);
+	ip_len = dns_query(NULL, name, end - name, NULL, &ip_addr, NULL, false);
 	if (ip_len > 0)
 		ret = ceph_pton(ip_addr, ip_len, ss, -1, NULL);
 	else
diff --git a/net/dns_resolver/dns_query.c b/net/dns_resolver/dns_query.c
index 19aa32fc1802b44a971eecd6f7b9f967ec2900d9..2d260432b3be0a7ddefbd6cb68c2bb2b18163525 100644
--- a/net/dns_resolver/dns_query.c
+++ b/net/dns_resolver/dns_query.c
@@ -54,6 +54,7 @@
  * @options: Request options (or NULL if no options)
  * @_result: Where to place the returned data (or NULL)
  * @_expiry: Where to store the result expiry time (or NULL)
+ * @invalidate: Always invalidate the key after use
  *
  * The data will be returned in the pointer at *result, if provided, and the
  * caller is responsible for freeing it.
@@ -69,7 +70,8 @@
  * Returns the size of the result on success, -ve error code otherwise.
  */
 int dns_query(const char *type, const char *name, size_t namelen,
-	      const char *options, char **_result, time64_t *_expiry)
+	      const char *options, char **_result, time64_t *_expiry,
+	      bool invalidate)
 {
 	struct key *rkey;
 	struct user_key_payload *upayload;
@@ -157,6 +159,8 @@ int dns_query(const char *type, const char *name, size_t namelen,
 	ret = len;
 put:
 	up_read(&rkey->sem);
+	if (invalidate)
+		key_invalidate(rkey);
 	key_put(rkey);
 out:
 	kleave(" = %d", ret);