diff --git a/arch/x86/kernel/process_64.c b/arch/x86/kernel/process_64.c
index 59013f480b8664cfe42ea1a08d03ceade0690c44..8f239091c15d43c15274e551028de9e65815d876 100644
--- a/arch/x86/kernel/process_64.c
+++ b/arch/x86/kernel/process_64.c
@@ -245,13 +245,17 @@ static __always_inline void save_fsgs(struct task_struct *task)
 	savesegment(fs, task->thread.fsindex);
 	savesegment(gs, task->thread.gsindex);
 	if (static_cpu_has(X86_FEATURE_FSGSBASE)) {
+		unsigned long flags;
+
 		/*
 		 * If FSGSBASE is enabled, we can't make any useful guesses
 		 * about the base, and user code expects us to save the current
 		 * value.  Fortunately, reading the base directly is efficient.
 		 */
 		task->thread.fsbase = rdfsbase();
+		local_irq_save(flags);
 		task->thread.gsbase = __rdgsbase_inactive();
+		local_irq_restore(flags);
 	} else {
 		save_base_legacy(task, task->thread.fsindex, FS);
 		save_base_legacy(task, task->thread.gsindex, GS);
@@ -433,7 +437,8 @@ unsigned long x86_fsbase_read_task(struct task_struct *task)
 
 	if (task == current)
 		fsbase = x86_fsbase_read_cpu();
-	else if (task->thread.fsindex == 0)
+	else if (static_cpu_has(X86_FEATURE_FSGSBASE) ||
+		 (task->thread.fsindex == 0))
 		fsbase = task->thread.fsbase;
 	else
 		fsbase = x86_fsgsbase_read_task(task, task->thread.fsindex);
@@ -447,7 +452,8 @@ unsigned long x86_gsbase_read_task(struct task_struct *task)
 
 	if (task == current)
 		gsbase = x86_gsbase_read_cpu_inactive();
-	else if (task->thread.gsindex == 0)
+	else if (static_cpu_has(X86_FEATURE_FSGSBASE) ||
+		 (task->thread.gsindex == 0))
 		gsbase = task->thread.gsbase;
 	else
 		gsbase = x86_fsgsbase_read_task(task, task->thread.gsindex);
@@ -487,10 +493,11 @@ int copy_thread_tls(unsigned long clone_flags, unsigned long sp,
 	p->thread.sp = (unsigned long) fork_frame;
 	p->thread.io_bitmap_ptr = NULL;
 
-	savesegment(gs, p->thread.gsindex);
-	p->thread.gsbase = p->thread.gsindex ? 0 : me->thread.gsbase;
-	savesegment(fs, p->thread.fsindex);
-	p->thread.fsbase = p->thread.fsindex ? 0 : me->thread.fsbase;
+	save_fsgs(me);
+	p->thread.fsindex = me->thread.fsindex;
+	p->thread.fsbase = me->thread.fsbase;
+	p->thread.gsindex = me->thread.gsindex;
+	p->thread.gsbase = me->thread.gsbase;
 	savesegment(es, p->thread.es);
 	savesegment(ds, p->thread.ds);
 	memset(p->thread.ptrace_bps, 0, sizeof(p->thread.ptrace_bps));