diff --git a/arch/x86/kvm/vmx/nested.c b/arch/x86/kvm/vmx/nested.c
index 153e539c29c92fcb3c55c3737ec7e0533c13e1e2..897d70e3d291dfe58802eba8b919853106bf33bc 100644
--- a/arch/x86/kvm/vmx/nested.c
+++ b/arch/x86/kvm/vmx/nested.c
@@ -500,6 +500,17 @@ static void nested_vmx_disable_intercept_for_msr(unsigned long *msr_bitmap_l1,
 	}
 }
 
+static inline void enable_x2apic_msr_intercepts(unsigned long *msr_bitmap) {
+	int msr;
+
+	for (msr = 0x800; msr <= 0x8ff; msr += BITS_PER_LONG) {
+		unsigned word = msr / BITS_PER_LONG;
+
+		msr_bitmap[word] = ~0;
+		msr_bitmap[word + (0x800 / sizeof(long))] = ~0;
+	}
+}
+
 /*
  * Merge L0's and L1's MSR bitmap, return false to indicate that
  * we do not use the hardware.
@@ -541,39 +552,44 @@ static inline bool nested_vmx_prepare_msr_bitmap(struct kvm_vcpu *vcpu,
 		return false;
 
 	msr_bitmap_l1 = (unsigned long *)kmap(page);
-	if (nested_cpu_has_apic_reg_virt(vmcs12)) {
-		/*
-		 * L0 need not intercept reads for MSRs between 0x800 and 0x8ff, it
-		 * just lets the processor take the value from the virtual-APIC page;
-		 * take those 256 bits directly from the L1 bitmap.
-		 */
-		for (msr = 0x800; msr <= 0x8ff; msr += BITS_PER_LONG) {
-			unsigned word = msr / BITS_PER_LONG;
-			msr_bitmap_l0[word] = msr_bitmap_l1[word];
-			msr_bitmap_l0[word + (0x800 / sizeof(long))] = ~0;
-		}
-	} else {
-		for (msr = 0x800; msr <= 0x8ff; msr += BITS_PER_LONG) {
-			unsigned word = msr / BITS_PER_LONG;
-			msr_bitmap_l0[word] = ~0;
-			msr_bitmap_l0[word + (0x800 / sizeof(long))] = ~0;
-		}
-	}
 
-	nested_vmx_disable_intercept_for_msr(
-		msr_bitmap_l1, msr_bitmap_l0,
-		X2APIC_MSR(APIC_TASKPRI),
-		MSR_TYPE_W);
+	/*
+	 * To keep the control flow simple, pay eight 8-byte writes (sixteen
+	 * 4-byte writes on 32-bit systems) up front to enable intercepts for
+	 * the x2APIC MSR range and selectively disable them below.
+	 */
+	enable_x2apic_msr_intercepts(msr_bitmap_l0);
+
+	if (nested_cpu_has_virt_x2apic_mode(vmcs12)) {
+		if (nested_cpu_has_apic_reg_virt(vmcs12)) {
+			/*
+			 * L0 need not intercept reads for MSRs between 0x800
+			 * and 0x8ff, it just lets the processor take the value
+			 * from the virtual-APIC page; take those 256 bits
+			 * directly from the L1 bitmap.
+			 */
+			for (msr = 0x800; msr <= 0x8ff; msr += BITS_PER_LONG) {
+				unsigned word = msr / BITS_PER_LONG;
+
+				msr_bitmap_l0[word] = msr_bitmap_l1[word];
+			}
+		}
 
-	if (nested_cpu_has_vid(vmcs12)) {
-		nested_vmx_disable_intercept_for_msr(
-			msr_bitmap_l1, msr_bitmap_l0,
-			X2APIC_MSR(APIC_EOI),
-			MSR_TYPE_W);
 		nested_vmx_disable_intercept_for_msr(
 			msr_bitmap_l1, msr_bitmap_l0,
-			X2APIC_MSR(APIC_SELF_IPI),
+			X2APIC_MSR(APIC_TASKPRI),
 			MSR_TYPE_W);
+
+		if (nested_cpu_has_vid(vmcs12)) {
+			nested_vmx_disable_intercept_for_msr(
+				msr_bitmap_l1, msr_bitmap_l0,
+				X2APIC_MSR(APIC_EOI),
+				MSR_TYPE_W);
+			nested_vmx_disable_intercept_for_msr(
+				msr_bitmap_l1, msr_bitmap_l0,
+				X2APIC_MSR(APIC_SELF_IPI),
+				MSR_TYPE_W);
+		}
 	}
 
 	if (spec_ctrl)