diff --git a/fs/proc/base.c b/fs/proc/base.c
index 7d795d28dd02954234bab6d8eda1cadb20c771fd..872a3f28bfe41b7dd6cf847047187a9d28ba63be 100644
--- a/fs/proc/base.c
+++ b/fs/proc/base.c
@@ -1363,16 +1363,16 @@ static ssize_t proc_fail_nth_write(struct file *file, const char __user *buf,
 	int err;
 	unsigned int n;
 
+	err = kstrtouint_from_user(buf, count, 0, &n);
+	if (err)
+		return err;
+
 	task = get_proc_task(file_inode(file));
 	if (!task)
 		return -ESRCH;
+	WRITE_ONCE(task->fail_nth, n);
 	put_task_struct(task);
-	if (task != current)
-		return -EPERM;
-	err = kstrtouint_from_user(buf, count, 0, &n);
-	if (err)
-		return err;
-	current->fail_nth = n;
+
 	return count;
 }
 
@@ -1386,11 +1386,10 @@ static ssize_t proc_fail_nth_read(struct file *file, char __user *buf,
 	task = get_proc_task(file_inode(file));
 	if (!task)
 		return -ESRCH;
-	put_task_struct(task);
-	if (task != current)
-		return -EPERM;
-	len = snprintf(numbuf, sizeof(numbuf), "%u\n", task->fail_nth);
+	len = snprintf(numbuf, sizeof(numbuf), "%u\n",
+			READ_ONCE(task->fail_nth));
 	len = simple_read_from_buffer(buf, count, ppos, numbuf, len);
+	put_task_struct(task);
 
 	return len;
 }
@@ -3355,11 +3354,7 @@ static const struct pid_entry tid_base_stuff[] = {
 #endif
 #ifdef CONFIG_FAULT_INJECTION
 	REG("make-it-fail", S_IRUGO|S_IWUSR, proc_fault_inject_operations),
-	/*
-	 * Operations on the file check that the task is current,
-	 * so we create it with 0666 to support testing under unprivileged user.
-	 */
-	REG("fail-nth", 0666, proc_fail_nth_operations),
+	REG("fail-nth", 0644, proc_fail_nth_operations),
 #endif
 #ifdef CONFIG_TASK_IO_ACCOUNTING
 	ONE("io",	S_IRUSR, proc_tid_io_accounting),
diff --git a/lib/fault-inject.c b/lib/fault-inject.c
index 09ac73c177fd555b86ef8b3a155ee1043258b532..7d315fdb9f13d9b17d8a2aa129c75790c7599bdb 100644
--- a/lib/fault-inject.c
+++ b/lib/fault-inject.c
@@ -107,9 +107,12 @@ static inline bool fail_stacktrace(struct fault_attr *attr)
 
 bool should_fail(struct fault_attr *attr, ssize_t size)
 {
-	if (in_task() && current->fail_nth) {
-		if (--current->fail_nth == 0)
+	if (in_task()) {
+		unsigned int fail_nth = READ_ONCE(current->fail_nth);
+
+		if (fail_nth && !WRITE_ONCE(current->fail_nth, fail_nth - 1))
 			goto fail;
+
 		return false;
 	}