diff --git a/arch/x86/include/asm/kvm_emulate.h b/arch/x86/include/asm/kvm_emulate.h
index 93c4bf598fb06c7e53865141dd3e7faa514194ff..feab24cac610e25f276d3d1f71f4705c23106b00 100644
--- a/arch/x86/include/asm/kvm_emulate.h
+++ b/arch/x86/include/asm/kvm_emulate.h
@@ -226,7 +226,9 @@ struct x86_emulate_ops {
 
 	unsigned (*get_hflags)(struct x86_emulate_ctxt *ctxt);
 	void (*set_hflags)(struct x86_emulate_ctxt *ctxt, unsigned hflags);
-	int (*pre_leave_smm)(struct x86_emulate_ctxt *ctxt, u64 smbase);
+	int (*pre_leave_smm)(struct x86_emulate_ctxt *ctxt,
+			     const char *smstate);
+	void (*post_leave_smm)(struct x86_emulate_ctxt *ctxt);
 
 };
 
diff --git a/arch/x86/include/asm/kvm_host.h b/arch/x86/include/asm/kvm_host.h
index 159b5988292f33ec2d1a079bf7d10ba2bc999d4b..a9d03af340307db6589376cf3bfb29a533910cdd 100644
--- a/arch/x86/include/asm/kvm_host.h
+++ b/arch/x86/include/asm/kvm_host.h
@@ -126,7 +126,7 @@ static inline gfn_t gfn_to_index(gfn_t gfn, gfn_t base_gfn, int level)
 }
 
 #define KVM_PERMILLE_MMU_PAGES 20
-#define KVM_MIN_ALLOC_MMU_PAGES 64
+#define KVM_MIN_ALLOC_MMU_PAGES 64UL
 #define KVM_MMU_HASH_SHIFT 12
 #define KVM_NUM_MMU_PAGES (1 << KVM_MMU_HASH_SHIFT)
 #define KVM_MIN_FREE_MMU_PAGES 5
@@ -844,9 +844,9 @@ enum kvm_irqchip_mode {
 };
 
 struct kvm_arch {
-	unsigned int n_used_mmu_pages;
-	unsigned int n_requested_mmu_pages;
-	unsigned int n_max_mmu_pages;
+	unsigned long n_used_mmu_pages;
+	unsigned long n_requested_mmu_pages;
+	unsigned long n_max_mmu_pages;
 	unsigned int indirect_shadow_pages;
 	struct hlist_head mmu_page_hash[KVM_NUM_MMU_PAGES];
 	/*
@@ -1182,7 +1182,7 @@ struct kvm_x86_ops {
 
 	int (*smi_allowed)(struct kvm_vcpu *vcpu);
 	int (*pre_enter_smm)(struct kvm_vcpu *vcpu, char *smstate);
-	int (*pre_leave_smm)(struct kvm_vcpu *vcpu, u64 smbase);
+	int (*pre_leave_smm)(struct kvm_vcpu *vcpu, const char *smstate);
 	int (*enable_smi_window)(struct kvm_vcpu *vcpu);
 
 	int (*mem_enc_op)(struct kvm *kvm, void __user *argp);
@@ -1256,8 +1256,8 @@ void kvm_mmu_clear_dirty_pt_masked(struct kvm *kvm,
 				   gfn_t gfn_offset, unsigned long mask);
 void kvm_mmu_zap_all(struct kvm *kvm);
 void kvm_mmu_invalidate_mmio_sptes(struct kvm *kvm, u64 gen);
-unsigned int kvm_mmu_calculate_default_mmu_pages(struct kvm *kvm);
-void kvm_mmu_change_mmu_pages(struct kvm *kvm, unsigned int kvm_nr_mmu_pages);
+unsigned long kvm_mmu_calculate_default_mmu_pages(struct kvm *kvm);
+void kvm_mmu_change_mmu_pages(struct kvm *kvm, unsigned long kvm_nr_mmu_pages);
 
 int load_pdptrs(struct kvm_vcpu *vcpu, struct kvm_mmu *mmu, unsigned long cr3);
 bool pdptrs_changed(struct kvm_vcpu *vcpu);
@@ -1592,4 +1592,7 @@ static inline int kvm_cpu_get_apicid(int mps_cpu)
 #define put_smstate(type, buf, offset, val)                      \
 	*(type *)((buf) + (offset) - 0x7e00) = val
 
+#define GET_SMSTATE(type, buf, offset)		\
+	(*(type *)((buf) + (offset) - 0x7e00))
+
 #endif /* _ASM_X86_KVM_HOST_H */
diff --git a/arch/x86/include/uapi/asm/vmx.h b/arch/x86/include/uapi/asm/vmx.h
index f0b0c90dd398246eb2882050d69c6b53ccca11af..d213ec5c3766db0dd5176c951b13e5f3c1514cfb 100644
--- a/arch/x86/include/uapi/asm/vmx.h
+++ b/arch/x86/include/uapi/asm/vmx.h
@@ -146,6 +146,7 @@
 
 #define VMX_ABORT_SAVE_GUEST_MSR_FAIL        1
 #define VMX_ABORT_LOAD_HOST_PDPTE_FAIL       2
+#define VMX_ABORT_VMCS_CORRUPTED             3
 #define VMX_ABORT_LOAD_HOST_MSR_FAIL         4
 
 #endif /* _UAPIVMX_H */
diff --git a/arch/x86/kvm/emulate.c b/arch/x86/kvm/emulate.c
index c338984c850d28a1213e46f86efc06d425115660..d0d5dd44b4f478524cc959cefb245695d9e40894 100644
--- a/arch/x86/kvm/emulate.c
+++ b/arch/x86/kvm/emulate.c
@@ -2331,24 +2331,18 @@ static int em_lseg(struct x86_emulate_ctxt *ctxt)
 
 static int emulator_has_longmode(struct x86_emulate_ctxt *ctxt)
 {
+#ifdef CONFIG_X86_64
 	u32 eax, ebx, ecx, edx;
 
 	eax = 0x80000001;
 	ecx = 0;
 	ctxt->ops->get_cpuid(ctxt, &eax, &ebx, &ecx, &edx, false);
 	return edx & bit(X86_FEATURE_LM);
+#else
+	return false;
+#endif
 }
 
-#define GET_SMSTATE(type, smbase, offset)				  \
-	({								  \
-	 type __val;							  \
-	 int r = ctxt->ops->read_phys(ctxt, smbase + offset, &__val,      \
-				      sizeof(__val));			  \
-	 if (r != X86EMUL_CONTINUE)					  \
-		 return X86EMUL_UNHANDLEABLE;				  \
-	 __val;								  \
-	})
-
 static void rsm_set_desc_flags(struct desc_struct *desc, u32 flags)
 {
 	desc->g    = (flags >> 23) & 1;
@@ -2361,27 +2355,30 @@ static void rsm_set_desc_flags(struct desc_struct *desc, u32 flags)
 	desc->type = (flags >>  8) & 15;
 }
 
-static int rsm_load_seg_32(struct x86_emulate_ctxt *ctxt, u64 smbase, int n)
+static int rsm_load_seg_32(struct x86_emulate_ctxt *ctxt, const char *smstate,
+			   int n)
 {
 	struct desc_struct desc;
 	int offset;
 	u16 selector;
 
-	selector = GET_SMSTATE(u32, smbase, 0x7fa8 + n * 4);
+	selector = GET_SMSTATE(u32, smstate, 0x7fa8 + n * 4);
 
 	if (n < 3)
 		offset = 0x7f84 + n * 12;
 	else
 		offset = 0x7f2c + (n - 3) * 12;
 
-	set_desc_base(&desc,      GET_SMSTATE(u32, smbase, offset + 8));
-	set_desc_limit(&desc,     GET_SMSTATE(u32, smbase, offset + 4));
-	rsm_set_desc_flags(&desc, GET_SMSTATE(u32, smbase, offset));
+	set_desc_base(&desc,      GET_SMSTATE(u32, smstate, offset + 8));
+	set_desc_limit(&desc,     GET_SMSTATE(u32, smstate, offset + 4));
+	rsm_set_desc_flags(&desc, GET_SMSTATE(u32, smstate, offset));
 	ctxt->ops->set_segment(ctxt, selector, &desc, 0, n);
 	return X86EMUL_CONTINUE;
 }
 
-static int rsm_load_seg_64(struct x86_emulate_ctxt *ctxt, u64 smbase, int n)
+#ifdef CONFIG_X86_64
+static int rsm_load_seg_64(struct x86_emulate_ctxt *ctxt, const char *smstate,
+			   int n)
 {
 	struct desc_struct desc;
 	int offset;
@@ -2390,15 +2387,16 @@ static int rsm_load_seg_64(struct x86_emulate_ctxt *ctxt, u64 smbase, int n)
 
 	offset = 0x7e00 + n * 16;
 
-	selector =                GET_SMSTATE(u16, smbase, offset);
-	rsm_set_desc_flags(&desc, GET_SMSTATE(u16, smbase, offset + 2) << 8);
-	set_desc_limit(&desc,     GET_SMSTATE(u32, smbase, offset + 4));
-	set_desc_base(&desc,      GET_SMSTATE(u32, smbase, offset + 8));
-	base3 =                   GET_SMSTATE(u32, smbase, offset + 12);
+	selector =                GET_SMSTATE(u16, smstate, offset);
+	rsm_set_desc_flags(&desc, GET_SMSTATE(u16, smstate, offset + 2) << 8);
+	set_desc_limit(&desc,     GET_SMSTATE(u32, smstate, offset + 4));
+	set_desc_base(&desc,      GET_SMSTATE(u32, smstate, offset + 8));
+	base3 =                   GET_SMSTATE(u32, smstate, offset + 12);
 
 	ctxt->ops->set_segment(ctxt, selector, &desc, base3, n);
 	return X86EMUL_CONTINUE;
 }
+#endif
 
 static int rsm_enter_protected_mode(struct x86_emulate_ctxt *ctxt,
 				    u64 cr0, u64 cr3, u64 cr4)
@@ -2445,7 +2443,8 @@ static int rsm_enter_protected_mode(struct x86_emulate_ctxt *ctxt,
 	return X86EMUL_CONTINUE;
 }
 
-static int rsm_load_state_32(struct x86_emulate_ctxt *ctxt, u64 smbase)
+static int rsm_load_state_32(struct x86_emulate_ctxt *ctxt,
+			     const char *smstate)
 {
 	struct desc_struct desc;
 	struct desc_ptr dt;
@@ -2453,53 +2452,55 @@ static int rsm_load_state_32(struct x86_emulate_ctxt *ctxt, u64 smbase)
 	u32 val, cr0, cr3, cr4;
 	int i;
 
-	cr0 =                      GET_SMSTATE(u32, smbase, 0x7ffc);
-	cr3 =                      GET_SMSTATE(u32, smbase, 0x7ff8);
-	ctxt->eflags =             GET_SMSTATE(u32, smbase, 0x7ff4) | X86_EFLAGS_FIXED;
-	ctxt->_eip =               GET_SMSTATE(u32, smbase, 0x7ff0);
+	cr0 =                      GET_SMSTATE(u32, smstate, 0x7ffc);
+	cr3 =                      GET_SMSTATE(u32, smstate, 0x7ff8);
+	ctxt->eflags =             GET_SMSTATE(u32, smstate, 0x7ff4) | X86_EFLAGS_FIXED;
+	ctxt->_eip =               GET_SMSTATE(u32, smstate, 0x7ff0);
 
 	for (i = 0; i < 8; i++)
-		*reg_write(ctxt, i) = GET_SMSTATE(u32, smbase, 0x7fd0 + i * 4);
+		*reg_write(ctxt, i) = GET_SMSTATE(u32, smstate, 0x7fd0 + i * 4);
 
-	val = GET_SMSTATE(u32, smbase, 0x7fcc);
+	val = GET_SMSTATE(u32, smstate, 0x7fcc);
 	ctxt->ops->set_dr(ctxt, 6, (val & DR6_VOLATILE) | DR6_FIXED_1);
-	val = GET_SMSTATE(u32, smbase, 0x7fc8);
+	val = GET_SMSTATE(u32, smstate, 0x7fc8);
 	ctxt->ops->set_dr(ctxt, 7, (val & DR7_VOLATILE) | DR7_FIXED_1);
 
-	selector =                 GET_SMSTATE(u32, smbase, 0x7fc4);
-	set_desc_base(&desc,       GET_SMSTATE(u32, smbase, 0x7f64));
-	set_desc_limit(&desc,      GET_SMSTATE(u32, smbase, 0x7f60));
-	rsm_set_desc_flags(&desc,  GET_SMSTATE(u32, smbase, 0x7f5c));
+	selector =                 GET_SMSTATE(u32, smstate, 0x7fc4);
+	set_desc_base(&desc,       GET_SMSTATE(u32, smstate, 0x7f64));
+	set_desc_limit(&desc,      GET_SMSTATE(u32, smstate, 0x7f60));
+	rsm_set_desc_flags(&desc,  GET_SMSTATE(u32, smstate, 0x7f5c));
 	ctxt->ops->set_segment(ctxt, selector, &desc, 0, VCPU_SREG_TR);
 
-	selector =                 GET_SMSTATE(u32, smbase, 0x7fc0);
-	set_desc_base(&desc,       GET_SMSTATE(u32, smbase, 0x7f80));
-	set_desc_limit(&desc,      GET_SMSTATE(u32, smbase, 0x7f7c));
-	rsm_set_desc_flags(&desc,  GET_SMSTATE(u32, smbase, 0x7f78));
+	selector =                 GET_SMSTATE(u32, smstate, 0x7fc0);
+	set_desc_base(&desc,       GET_SMSTATE(u32, smstate, 0x7f80));
+	set_desc_limit(&desc,      GET_SMSTATE(u32, smstate, 0x7f7c));
+	rsm_set_desc_flags(&desc,  GET_SMSTATE(u32, smstate, 0x7f78));
 	ctxt->ops->set_segment(ctxt, selector, &desc, 0, VCPU_SREG_LDTR);
 
-	dt.address =               GET_SMSTATE(u32, smbase, 0x7f74);
-	dt.size =                  GET_SMSTATE(u32, smbase, 0x7f70);
+	dt.address =               GET_SMSTATE(u32, smstate, 0x7f74);
+	dt.size =                  GET_SMSTATE(u32, smstate, 0x7f70);
 	ctxt->ops->set_gdt(ctxt, &dt);
 
-	dt.address =               GET_SMSTATE(u32, smbase, 0x7f58);
-	dt.size =                  GET_SMSTATE(u32, smbase, 0x7f54);
+	dt.address =               GET_SMSTATE(u32, smstate, 0x7f58);
+	dt.size =                  GET_SMSTATE(u32, smstate, 0x7f54);
 	ctxt->ops->set_idt(ctxt, &dt);
 
 	for (i = 0; i < 6; i++) {
-		int r = rsm_load_seg_32(ctxt, smbase, i);
+		int r = rsm_load_seg_32(ctxt, smstate, i);
 		if (r != X86EMUL_CONTINUE)
 			return r;
 	}
 
-	cr4 = GET_SMSTATE(u32, smbase, 0x7f14);
+	cr4 = GET_SMSTATE(u32, smstate, 0x7f14);
 
-	ctxt->ops->set_smbase(ctxt, GET_SMSTATE(u32, smbase, 0x7ef8));
+	ctxt->ops->set_smbase(ctxt, GET_SMSTATE(u32, smstate, 0x7ef8));
 
 	return rsm_enter_protected_mode(ctxt, cr0, cr3, cr4);
 }
 
-static int rsm_load_state_64(struct x86_emulate_ctxt *ctxt, u64 smbase)
+#ifdef CONFIG_X86_64
+static int rsm_load_state_64(struct x86_emulate_ctxt *ctxt,
+			     const char *smstate)
 {
 	struct desc_struct desc;
 	struct desc_ptr dt;
@@ -2509,43 +2510,43 @@ static int rsm_load_state_64(struct x86_emulate_ctxt *ctxt, u64 smbase)
 	int i, r;
 
 	for (i = 0; i < 16; i++)
-		*reg_write(ctxt, i) = GET_SMSTATE(u64, smbase, 0x7ff8 - i * 8);
+		*reg_write(ctxt, i) = GET_SMSTATE(u64, smstate, 0x7ff8 - i * 8);
 
-	ctxt->_eip   = GET_SMSTATE(u64, smbase, 0x7f78);
-	ctxt->eflags = GET_SMSTATE(u32, smbase, 0x7f70) | X86_EFLAGS_FIXED;
+	ctxt->_eip   = GET_SMSTATE(u64, smstate, 0x7f78);
+	ctxt->eflags = GET_SMSTATE(u32, smstate, 0x7f70) | X86_EFLAGS_FIXED;
 
-	val = GET_SMSTATE(u32, smbase, 0x7f68);
+	val = GET_SMSTATE(u32, smstate, 0x7f68);
 	ctxt->ops->set_dr(ctxt, 6, (val & DR6_VOLATILE) | DR6_FIXED_1);
-	val = GET_SMSTATE(u32, smbase, 0x7f60);
+	val = GET_SMSTATE(u32, smstate, 0x7f60);
 	ctxt->ops->set_dr(ctxt, 7, (val & DR7_VOLATILE) | DR7_FIXED_1);
 
-	cr0 =                       GET_SMSTATE(u64, smbase, 0x7f58);
-	cr3 =                       GET_SMSTATE(u64, smbase, 0x7f50);
-	cr4 =                       GET_SMSTATE(u64, smbase, 0x7f48);
-	ctxt->ops->set_smbase(ctxt, GET_SMSTATE(u32, smbase, 0x7f00));
-	val =                       GET_SMSTATE(u64, smbase, 0x7ed0);
+	cr0 =                       GET_SMSTATE(u64, smstate, 0x7f58);
+	cr3 =                       GET_SMSTATE(u64, smstate, 0x7f50);
+	cr4 =                       GET_SMSTATE(u64, smstate, 0x7f48);
+	ctxt->ops->set_smbase(ctxt, GET_SMSTATE(u32, smstate, 0x7f00));
+	val =                       GET_SMSTATE(u64, smstate, 0x7ed0);
 	ctxt->ops->set_msr(ctxt, MSR_EFER, val & ~EFER_LMA);
 
-	selector =                  GET_SMSTATE(u32, smbase, 0x7e90);
-	rsm_set_desc_flags(&desc,   GET_SMSTATE(u32, smbase, 0x7e92) << 8);
-	set_desc_limit(&desc,       GET_SMSTATE(u32, smbase, 0x7e94));
-	set_desc_base(&desc,        GET_SMSTATE(u32, smbase, 0x7e98));
-	base3 =                     GET_SMSTATE(u32, smbase, 0x7e9c);
+	selector =                  GET_SMSTATE(u32, smstate, 0x7e90);
+	rsm_set_desc_flags(&desc,   GET_SMSTATE(u32, smstate, 0x7e92) << 8);
+	set_desc_limit(&desc,       GET_SMSTATE(u32, smstate, 0x7e94));
+	set_desc_base(&desc,        GET_SMSTATE(u32, smstate, 0x7e98));
+	base3 =                     GET_SMSTATE(u32, smstate, 0x7e9c);
 	ctxt->ops->set_segment(ctxt, selector, &desc, base3, VCPU_SREG_TR);
 
-	dt.size =                   GET_SMSTATE(u32, smbase, 0x7e84);
-	dt.address =                GET_SMSTATE(u64, smbase, 0x7e88);
+	dt.size =                   GET_SMSTATE(u32, smstate, 0x7e84);
+	dt.address =                GET_SMSTATE(u64, smstate, 0x7e88);
 	ctxt->ops->set_idt(ctxt, &dt);
 
-	selector =                  GET_SMSTATE(u32, smbase, 0x7e70);
-	rsm_set_desc_flags(&desc,   GET_SMSTATE(u32, smbase, 0x7e72) << 8);
-	set_desc_limit(&desc,       GET_SMSTATE(u32, smbase, 0x7e74));
-	set_desc_base(&desc,        GET_SMSTATE(u32, smbase, 0x7e78));
-	base3 =                     GET_SMSTATE(u32, smbase, 0x7e7c);
+	selector =                  GET_SMSTATE(u32, smstate, 0x7e70);
+	rsm_set_desc_flags(&desc,   GET_SMSTATE(u32, smstate, 0x7e72) << 8);
+	set_desc_limit(&desc,       GET_SMSTATE(u32, smstate, 0x7e74));
+	set_desc_base(&desc,        GET_SMSTATE(u32, smstate, 0x7e78));
+	base3 =                     GET_SMSTATE(u32, smstate, 0x7e7c);
 	ctxt->ops->set_segment(ctxt, selector, &desc, base3, VCPU_SREG_LDTR);
 
-	dt.size =                   GET_SMSTATE(u32, smbase, 0x7e64);
-	dt.address =                GET_SMSTATE(u64, smbase, 0x7e68);
+	dt.size =                   GET_SMSTATE(u32, smstate, 0x7e64);
+	dt.address =                GET_SMSTATE(u64, smstate, 0x7e68);
 	ctxt->ops->set_gdt(ctxt, &dt);
 
 	r = rsm_enter_protected_mode(ctxt, cr0, cr3, cr4);
@@ -2553,37 +2554,49 @@ static int rsm_load_state_64(struct x86_emulate_ctxt *ctxt, u64 smbase)
 		return r;
 
 	for (i = 0; i < 6; i++) {
-		r = rsm_load_seg_64(ctxt, smbase, i);
+		r = rsm_load_seg_64(ctxt, smstate, i);
 		if (r != X86EMUL_CONTINUE)
 			return r;
 	}
 
 	return X86EMUL_CONTINUE;
 }
+#endif
 
 static int em_rsm(struct x86_emulate_ctxt *ctxt)
 {
 	unsigned long cr0, cr4, efer;
+	char buf[512];
 	u64 smbase;
 	int ret;
 
 	if ((ctxt->ops->get_hflags(ctxt) & X86EMUL_SMM_MASK) == 0)
 		return emulate_ud(ctxt);
 
+	smbase = ctxt->ops->get_smbase(ctxt);
+
+	ret = ctxt->ops->read_phys(ctxt, smbase + 0xfe00, buf, sizeof(buf));
+	if (ret != X86EMUL_CONTINUE)
+		return X86EMUL_UNHANDLEABLE;
+
+	if ((ctxt->ops->get_hflags(ctxt) & X86EMUL_SMM_INSIDE_NMI_MASK) == 0)
+		ctxt->ops->set_nmi_mask(ctxt, false);
+
+	ctxt->ops->set_hflags(ctxt, ctxt->ops->get_hflags(ctxt) &
+		~(X86EMUL_SMM_INSIDE_NMI_MASK | X86EMUL_SMM_MASK));
+
 	/*
 	 * Get back to real mode, to prepare a safe state in which to load
 	 * CR0/CR3/CR4/EFER.  It's all a bit more complicated if the vCPU
 	 * supports long mode.
 	 */
-	cr4 = ctxt->ops->get_cr(ctxt, 4);
 	if (emulator_has_longmode(ctxt)) {
 		struct desc_struct cs_desc;
 
 		/* Zero CR4.PCIDE before CR0.PG.  */
-		if (cr4 & X86_CR4_PCIDE) {
+		cr4 = ctxt->ops->get_cr(ctxt, 4);
+		if (cr4 & X86_CR4_PCIDE)
 			ctxt->ops->set_cr(ctxt, 4, cr4 & ~X86_CR4_PCIDE);
-			cr4 &= ~X86_CR4_PCIDE;
-		}
 
 		/* A 32-bit code segment is required to clear EFER.LMA.  */
 		memset(&cs_desc, 0, sizeof(cs_desc));
@@ -2597,39 +2610,39 @@ static int em_rsm(struct x86_emulate_ctxt *ctxt)
 	if (cr0 & X86_CR0_PE)
 		ctxt->ops->set_cr(ctxt, 0, cr0 & ~(X86_CR0_PG | X86_CR0_PE));
 
-	/* Now clear CR4.PAE (which must be done before clearing EFER.LME).  */
-	if (cr4 & X86_CR4_PAE)
-		ctxt->ops->set_cr(ctxt, 4, cr4 & ~X86_CR4_PAE);
-
-	/* And finally go back to 32-bit mode.  */
-	efer = 0;
-	ctxt->ops->set_msr(ctxt, MSR_EFER, efer);
+	if (emulator_has_longmode(ctxt)) {
+		/* Clear CR4.PAE before clearing EFER.LME. */
+		cr4 = ctxt->ops->get_cr(ctxt, 4);
+		if (cr4 & X86_CR4_PAE)
+			ctxt->ops->set_cr(ctxt, 4, cr4 & ~X86_CR4_PAE);
 
-	smbase = ctxt->ops->get_smbase(ctxt);
+		/* And finally go back to 32-bit mode.  */
+		efer = 0;
+		ctxt->ops->set_msr(ctxt, MSR_EFER, efer);
+	}
 
 	/*
 	 * Give pre_leave_smm() a chance to make ISA-specific changes to the
 	 * vCPU state (e.g. enter guest mode) before loading state from the SMM
 	 * state-save area.
 	 */
-	if (ctxt->ops->pre_leave_smm(ctxt, smbase))
+	if (ctxt->ops->pre_leave_smm(ctxt, buf))
 		return X86EMUL_UNHANDLEABLE;
 
+#ifdef CONFIG_X86_64
 	if (emulator_has_longmode(ctxt))
-		ret = rsm_load_state_64(ctxt, smbase + 0x8000);
+		ret = rsm_load_state_64(ctxt, buf);
 	else
-		ret = rsm_load_state_32(ctxt, smbase + 0x8000);
+#endif
+		ret = rsm_load_state_32(ctxt, buf);
 
 	if (ret != X86EMUL_CONTINUE) {
 		/* FIXME: should triple fault */
 		return X86EMUL_UNHANDLEABLE;
 	}
 
-	if ((ctxt->ops->get_hflags(ctxt) & X86EMUL_SMM_INSIDE_NMI_MASK) == 0)
-		ctxt->ops->set_nmi_mask(ctxt, false);
+	ctxt->ops->post_leave_smm(ctxt);
 
-	ctxt->ops->set_hflags(ctxt, ctxt->ops->get_hflags(ctxt) &
-		~(X86EMUL_SMM_INSIDE_NMI_MASK | X86EMUL_SMM_MASK));
 	return X86EMUL_CONTINUE;
 }
 
diff --git a/arch/x86/kvm/lapic.c b/arch/x86/kvm/lapic.c
index 991fdf7fc17fbd9e1a4cab99d688a7af820d397c..9bf70cf845648f5e66143440166d57c5fd287bf9 100644
--- a/arch/x86/kvm/lapic.c
+++ b/arch/x86/kvm/lapic.c
@@ -138,6 +138,7 @@ static inline bool kvm_apic_map_get_logical_dest(struct kvm_apic_map *map,
 		if (offset <= max_apic_id) {
 			u8 cluster_size = min(max_apic_id - offset + 1, 16U);
 
+			offset = array_index_nospec(offset, map->max_apic_id + 1);
 			*cluster = &map->phys_map[offset];
 			*mask = dest_id & (0xffff >> (16 - cluster_size));
 		} else {
@@ -901,7 +902,8 @@ static inline bool kvm_apic_map_get_dest_lapic(struct kvm *kvm,
 		if (irq->dest_id > map->max_apic_id) {
 			*bitmap = 0;
 		} else {
-			*dst = &map->phys_map[irq->dest_id];
+			u32 dest_id = array_index_nospec(irq->dest_id, map->max_apic_id + 1);
+			*dst = &map->phys_map[dest_id];
 			*bitmap = 1;
 		}
 		return true;
diff --git a/arch/x86/kvm/mmu.c b/arch/x86/kvm/mmu.c
index eee455a8a612d00a516bfe892a690bcd8bc91e39..e10962dfc2032d982f124070b88f7d625d2b8f0b 100644
--- a/arch/x86/kvm/mmu.c
+++ b/arch/x86/kvm/mmu.c
@@ -2007,7 +2007,7 @@ static int is_empty_shadow_page(u64 *spt)
  * aggregate version in order to make the slab shrinker
  * faster
  */
-static inline void kvm_mod_used_mmu_pages(struct kvm *kvm, int nr)
+static inline void kvm_mod_used_mmu_pages(struct kvm *kvm, unsigned long nr)
 {
 	kvm->arch.n_used_mmu_pages += nr;
 	percpu_counter_add(&kvm_total_used_mmu_pages, nr);
@@ -2238,7 +2238,7 @@ static bool kvm_mmu_remote_flush_or_zap(struct kvm *kvm,
 					struct list_head *invalid_list,
 					bool remote_flush)
 {
-	if (!remote_flush && !list_empty(invalid_list))
+	if (!remote_flush && list_empty(invalid_list))
 		return false;
 
 	if (!list_empty(invalid_list))
@@ -2763,7 +2763,7 @@ static bool prepare_zap_oldest_mmu_page(struct kvm *kvm,
  * Changing the number of mmu pages allocated to the vm
  * Note: if goal_nr_mmu_pages is too small, you will get dead lock
  */
-void kvm_mmu_change_mmu_pages(struct kvm *kvm, unsigned int goal_nr_mmu_pages)
+void kvm_mmu_change_mmu_pages(struct kvm *kvm, unsigned long goal_nr_mmu_pages)
 {
 	LIST_HEAD(invalid_list);
 
@@ -6031,10 +6031,10 @@ int kvm_mmu_module_init(void)
 /*
  * Calculate mmu pages needed for kvm.
  */
-unsigned int kvm_mmu_calculate_default_mmu_pages(struct kvm *kvm)
+unsigned long kvm_mmu_calculate_default_mmu_pages(struct kvm *kvm)
 {
-	unsigned int nr_mmu_pages;
-	unsigned int  nr_pages = 0;
+	unsigned long nr_mmu_pages;
+	unsigned long nr_pages = 0;
 	struct kvm_memslots *slots;
 	struct kvm_memory_slot *memslot;
 	int i;
@@ -6047,8 +6047,7 @@ unsigned int kvm_mmu_calculate_default_mmu_pages(struct kvm *kvm)
 	}
 
 	nr_mmu_pages = nr_pages * KVM_PERMILLE_MMU_PAGES / 1000;
-	nr_mmu_pages = max(nr_mmu_pages,
-			   (unsigned int) KVM_MIN_ALLOC_MMU_PAGES);
+	nr_mmu_pages = max(nr_mmu_pages, KVM_MIN_ALLOC_MMU_PAGES);
 
 	return nr_mmu_pages;
 }
diff --git a/arch/x86/kvm/mmu.h b/arch/x86/kvm/mmu.h
index bbdc60f2fae89beb34c72716d9e7eb9c33584651..54c2a377795be6920bee9676e58555110c3a56b9 100644
--- a/arch/x86/kvm/mmu.h
+++ b/arch/x86/kvm/mmu.h
@@ -64,7 +64,7 @@ bool kvm_can_do_async_pf(struct kvm_vcpu *vcpu);
 int kvm_handle_page_fault(struct kvm_vcpu *vcpu, u64 error_code,
 				u64 fault_address, char *insn, int insn_len);
 
-static inline unsigned int kvm_mmu_available_pages(struct kvm *kvm)
+static inline unsigned long kvm_mmu_available_pages(struct kvm *kvm)
 {
 	if (kvm->arch.n_max_mmu_pages > kvm->arch.n_used_mmu_pages)
 		return kvm->arch.n_max_mmu_pages -
diff --git a/arch/x86/kvm/pmu.c b/arch/x86/kvm/pmu.c
index 58ead7db71a312764b56d9f242e84820239eeb93..e39741997893a977fdda077ff637bf465fbb1748 100644
--- a/arch/x86/kvm/pmu.c
+++ b/arch/x86/kvm/pmu.c
@@ -281,9 +281,13 @@ static int kvm_pmu_rdpmc_vmware(struct kvm_vcpu *vcpu, unsigned idx, u64 *data)
 int kvm_pmu_rdpmc(struct kvm_vcpu *vcpu, unsigned idx, u64 *data)
 {
 	bool fast_mode = idx & (1u << 31);
+	struct kvm_pmu *pmu = vcpu_to_pmu(vcpu);
 	struct kvm_pmc *pmc;
 	u64 ctr_val;
 
+	if (!pmu->version)
+		return 1;
+
 	if (is_vmware_backdoor_pmc(idx))
 		return kvm_pmu_rdpmc_vmware(vcpu, idx, data);
 
diff --git a/arch/x86/kvm/svm.c b/arch/x86/kvm/svm.c
index e0a791c3d4fcc6bb3b1426632e828196953101d5..406b558abfef7379eb46bd2de18e5d6890079eb9 100644
--- a/arch/x86/kvm/svm.c
+++ b/arch/x86/kvm/svm.c
@@ -262,6 +262,7 @@ struct amd_svm_iommu_ir {
 };
 
 #define AVIC_LOGICAL_ID_ENTRY_GUEST_PHYSICAL_ID_MASK	(0xFF)
+#define AVIC_LOGICAL_ID_ENTRY_VALID_BIT			31
 #define AVIC_LOGICAL_ID_ENTRY_VALID_MASK		(1 << 31)
 
 #define AVIC_PHYSICAL_ID_ENTRY_HOST_PHYSICAL_ID_MASK	(0xFFULL)
@@ -2692,6 +2693,7 @@ static int npf_interception(struct vcpu_svm *svm)
 static int db_interception(struct vcpu_svm *svm)
 {
 	struct kvm_run *kvm_run = svm->vcpu.run;
+	struct kvm_vcpu *vcpu = &svm->vcpu;
 
 	if (!(svm->vcpu.guest_debug &
 	      (KVM_GUESTDBG_SINGLESTEP | KVM_GUESTDBG_USE_HW_BP)) &&
@@ -2702,6 +2704,8 @@ static int db_interception(struct vcpu_svm *svm)
 
 	if (svm->nmi_singlestep) {
 		disable_nmi_singlestep(svm);
+		/* Make sure we check for pending NMIs upon entry */
+		kvm_make_request(KVM_REQ_EVENT, vcpu);
 	}
 
 	if (svm->vcpu.guest_debug &
@@ -4517,14 +4521,25 @@ static int avic_incomplete_ipi_interception(struct vcpu_svm *svm)
 		kvm_lapic_reg_write(apic, APIC_ICR, icrl);
 		break;
 	case AVIC_IPI_FAILURE_TARGET_NOT_RUNNING: {
+		int i;
+		struct kvm_vcpu *vcpu;
+		struct kvm *kvm = svm->vcpu.kvm;
 		struct kvm_lapic *apic = svm->vcpu.arch.apic;
 
 		/*
-		 * Update ICR high and low, then emulate sending IPI,
-		 * which is handled when writing APIC_ICR.
+		 * At this point, we expect that the AVIC HW has already
+		 * set the appropriate IRR bits on the valid target
+		 * vcpus. So, we just need to kick the appropriate vcpu.
 		 */
-		kvm_lapic_reg_write(apic, APIC_ICR2, icrh);
-		kvm_lapic_reg_write(apic, APIC_ICR, icrl);
+		kvm_for_each_vcpu(i, vcpu, kvm) {
+			bool m = kvm_apic_match_dest(vcpu, apic,
+						     icrl & KVM_APIC_SHORT_MASK,
+						     GET_APIC_DEST_FIELD(icrh),
+						     icrl & KVM_APIC_DEST_MASK);
+
+			if (m && !avic_vcpu_is_running(vcpu))
+				kvm_vcpu_wake_up(vcpu);
+		}
 		break;
 	}
 	case AVIC_IPI_FAILURE_INVALID_TARGET:
@@ -4596,7 +4611,7 @@ static void avic_invalidate_logical_id_entry(struct kvm_vcpu *vcpu)
 	u32 *entry = avic_get_logical_id_entry(vcpu, svm->ldr_reg, flat);
 
 	if (entry)
-		WRITE_ONCE(*entry, (u32) ~AVIC_LOGICAL_ID_ENTRY_VALID_MASK);
+		clear_bit(AVIC_LOGICAL_ID_ENTRY_VALID_BIT, (unsigned long *)entry);
 }
 
 static int avic_handle_ldr_update(struct kvm_vcpu *vcpu)
@@ -5621,6 +5636,7 @@ static void svm_vcpu_run(struct kvm_vcpu *vcpu)
 	svm->vmcb->save.cr2 = vcpu->arch.cr2;
 
 	clgi();
+	kvm_load_guest_xcr0(vcpu);
 
 	/*
 	 * If this vCPU has touched SPEC_CTRL, restore the guest's value if
@@ -5766,6 +5782,7 @@ static void svm_vcpu_run(struct kvm_vcpu *vcpu)
 	if (unlikely(svm->vmcb->control.exit_code == SVM_EXIT_NMI))
 		kvm_before_interrupt(&svm->vcpu);
 
+	kvm_put_guest_xcr0(vcpu);
 	stgi();
 
 	/* Any pending NMI will happen here */
@@ -6215,32 +6232,24 @@ static int svm_pre_enter_smm(struct kvm_vcpu *vcpu, char *smstate)
 	return 0;
 }
 
-static int svm_pre_leave_smm(struct kvm_vcpu *vcpu, u64 smbase)
+static int svm_pre_leave_smm(struct kvm_vcpu *vcpu, const char *smstate)
 {
 	struct vcpu_svm *svm = to_svm(vcpu);
 	struct vmcb *nested_vmcb;
 	struct page *page;
-	struct {
-		u64 guest;
-		u64 vmcb;
-	} svm_state_save;
-	int ret;
+	u64 guest;
+	u64 vmcb;
 
-	ret = kvm_vcpu_read_guest(vcpu, smbase + 0xfed8, &svm_state_save,
-				  sizeof(svm_state_save));
-	if (ret)
-		return ret;
+	guest = GET_SMSTATE(u64, smstate, 0x7ed8);
+	vmcb = GET_SMSTATE(u64, smstate, 0x7ee0);
 
-	if (svm_state_save.guest) {
-		vcpu->arch.hflags &= ~HF_SMM_MASK;
-		nested_vmcb = nested_svm_map(svm, svm_state_save.vmcb, &page);
-		if (nested_vmcb)
-			enter_svm_guest_mode(svm, svm_state_save.vmcb, nested_vmcb, page);
-		else
-			ret = 1;
-		vcpu->arch.hflags |= HF_SMM_MASK;
+	if (guest) {
+		nested_vmcb = nested_svm_map(svm, vmcb, &page);
+		if (!nested_vmcb)
+			return 1;
+		enter_svm_guest_mode(svm, vmcb, nested_vmcb, page);
 	}
-	return ret;
+	return 0;
 }
 
 static int enable_smi_window(struct kvm_vcpu *vcpu)
diff --git a/arch/x86/kvm/trace.h b/arch/x86/kvm/trace.h
index 6432d08c7de79ccbde654b7ab17c9649b75a25c2..4d47a2631d1fb46d9f913b59743cb5417d7401c6 100644
--- a/arch/x86/kvm/trace.h
+++ b/arch/x86/kvm/trace.h
@@ -438,13 +438,13 @@ TRACE_EVENT(kvm_apic_ipi,
 );
 
 TRACE_EVENT(kvm_apic_accept_irq,
-	    TP_PROTO(__u32 apicid, __u16 dm, __u8 tm, __u8 vec),
+	    TP_PROTO(__u32 apicid, __u16 dm, __u16 tm, __u8 vec),
 	    TP_ARGS(apicid, dm, tm, vec),
 
 	TP_STRUCT__entry(
 		__field(	__u32,		apicid		)
 		__field(	__u16,		dm		)
-		__field(	__u8,		tm		)
+		__field(	__u16,		tm		)
 		__field(	__u8,		vec		)
 	),
 
diff --git a/arch/x86/kvm/vmx/nested.c b/arch/x86/kvm/vmx/nested.c
index 7ec9bb1dd72315d7c725c3e197639b09ee3e6386..6401eb7ef19ce0e9f9258b617e001dfdac534a2a 100644
--- a/arch/x86/kvm/vmx/nested.c
+++ b/arch/x86/kvm/vmx/nested.c
@@ -2873,20 +2873,27 @@ static void nested_get_vmcs12_pages(struct kvm_vcpu *vcpu)
 		/*
 		 * If translation failed, VM entry will fail because
 		 * prepare_vmcs02 set VIRTUAL_APIC_PAGE_ADDR to -1ull.
-		 * Failing the vm entry is _not_ what the processor
-		 * does but it's basically the only possibility we
-		 * have.  We could still enter the guest if CR8 load
-		 * exits are enabled, CR8 store exits are enabled, and
-		 * virtualize APIC access is disabled; in this case
-		 * the processor would never use the TPR shadow and we
-		 * could simply clear the bit from the execution
-		 * control.  But such a configuration is useless, so
-		 * let's keep the code simple.
 		 */
 		if (!is_error_page(page)) {
 			vmx->nested.virtual_apic_page = page;
 			hpa = page_to_phys(vmx->nested.virtual_apic_page);
 			vmcs_write64(VIRTUAL_APIC_PAGE_ADDR, hpa);
+		} else if (nested_cpu_has(vmcs12, CPU_BASED_CR8_LOAD_EXITING) &&
+		           nested_cpu_has(vmcs12, CPU_BASED_CR8_STORE_EXITING) &&
+			   !nested_cpu_has2(vmcs12, SECONDARY_EXEC_VIRTUALIZE_APIC_ACCESSES)) {
+			/*
+			 * The processor will never use the TPR shadow, simply
+			 * clear the bit from the execution control.  Such a
+			 * configuration is useless, but it happens in tests.
+			 * For any other configuration, failing the vm entry is
+			 * _not_ what the processor does but it's basically the
+			 * only possibility we have.
+			 */
+			vmcs_clear_bits(CPU_BASED_VM_EXEC_CONTROL,
+					CPU_BASED_TPR_SHADOW);
+		} else {
+			printk("bad virtual-APIC page address\n");
+			dump_vmcs();
 		}
 	}
 
@@ -3789,8 +3796,18 @@ static void nested_vmx_restore_host_state(struct kvm_vcpu *vcpu)
 	vmx_set_cr4(vcpu, vmcs_readl(CR4_READ_SHADOW));
 
 	nested_ept_uninit_mmu_context(vcpu);
-	vcpu->arch.cr3 = vmcs_readl(GUEST_CR3);
-	__set_bit(VCPU_EXREG_CR3, (ulong *)&vcpu->arch.regs_avail);
+
+	/*
+	 * This is only valid if EPT is in use, otherwise the vmcs01 GUEST_CR3
+	 * points to shadow pages!  Fortunately we only get here after a WARN_ON
+	 * if EPT is disabled, so a VMabort is perfectly fine.
+	 */
+	if (enable_ept) {
+		vcpu->arch.cr3 = vmcs_readl(GUEST_CR3);
+		__set_bit(VCPU_EXREG_CR3, (ulong *)&vcpu->arch.regs_avail);
+	} else {
+		nested_vmx_abort(vcpu, VMX_ABORT_VMCS_CORRUPTED);
+	}
 
 	/*
 	 * Use ept_save_pdptrs(vcpu) to load the MMU's cached PDPTRs
@@ -5738,6 +5755,14 @@ __init int nested_vmx_hardware_setup(int (*exit_handlers[])(struct kvm_vcpu *))
 {
 	int i;
 
+	/*
+	 * Without EPT it is not possible to restore L1's CR3 and PDPTR on
+	 * VMfail, because they are not available in vmcs01.  Just always
+	 * use hardware checks.
+	 */
+	if (!enable_ept)
+		nested_early_check = 1;
+
 	if (!cpu_has_vmx_shadow_vmcs())
 		enable_shadow_vmcs = 0;
 	if (enable_shadow_vmcs) {
diff --git a/arch/x86/kvm/vmx/vmx.c b/arch/x86/kvm/vmx/vmx.c
index ab432a930ae865d0000d8273643de236d0738fb8..b4e7d645275a2153c42fa252cce8a8cbb930b59e 100644
--- a/arch/x86/kvm/vmx/vmx.c
+++ b/arch/x86/kvm/vmx/vmx.c
@@ -5603,7 +5603,7 @@ static void vmx_dump_dtsel(char *name, uint32_t limit)
 	       vmcs_readl(limit + GUEST_GDTR_BASE - GUEST_GDTR_LIMIT));
 }
 
-static void dump_vmcs(void)
+void dump_vmcs(void)
 {
 	u32 vmentry_ctl = vmcs_read32(VM_ENTRY_CONTROLS);
 	u32 vmexit_ctl = vmcs_read32(VM_EXIT_CONTROLS);
@@ -6410,6 +6410,8 @@ static void vmx_vcpu_run(struct kvm_vcpu *vcpu)
 	if (vcpu->guest_debug & KVM_GUESTDBG_SINGLESTEP)
 		vmx_set_interrupt_shadow(vcpu, 0);
 
+	kvm_load_guest_xcr0(vcpu);
+
 	if (static_cpu_has(X86_FEATURE_PKU) &&
 	    kvm_read_cr4_bits(vcpu, X86_CR4_PKE) &&
 	    vcpu->arch.pkru != vmx->host_pkru)
@@ -6506,6 +6508,8 @@ static void vmx_vcpu_run(struct kvm_vcpu *vcpu)
 			__write_pkru(vmx->host_pkru);
 	}
 
+	kvm_put_guest_xcr0(vcpu);
+
 	vmx->nested.nested_run_pending = 0;
 	vmx->idt_vectoring_info = 0;
 
@@ -6852,6 +6856,30 @@ static void nested_vmx_entry_exit_ctls_update(struct kvm_vcpu *vcpu)
 	}
 }
 
+static bool guest_cpuid_has_pmu(struct kvm_vcpu *vcpu)
+{
+	struct kvm_cpuid_entry2 *entry;
+	union cpuid10_eax eax;
+
+	entry = kvm_find_cpuid_entry(vcpu, 0xa, 0);
+	if (!entry)
+		return false;
+
+	eax.full = entry->eax;
+	return (eax.split.version_id > 0);
+}
+
+static void nested_vmx_procbased_ctls_update(struct kvm_vcpu *vcpu)
+{
+	struct vcpu_vmx *vmx = to_vmx(vcpu);
+	bool pmu_enabled = guest_cpuid_has_pmu(vcpu);
+
+	if (pmu_enabled)
+		vmx->nested.msrs.procbased_ctls_high |= CPU_BASED_RDPMC_EXITING;
+	else
+		vmx->nested.msrs.procbased_ctls_high &= ~CPU_BASED_RDPMC_EXITING;
+}
+
 static void update_intel_pt_cfg(struct kvm_vcpu *vcpu)
 {
 	struct vcpu_vmx *vmx = to_vmx(vcpu);
@@ -6940,6 +6968,7 @@ static void vmx_cpuid_update(struct kvm_vcpu *vcpu)
 	if (nested_vmx_allowed(vcpu)) {
 		nested_vmx_cr_fixed1_bits_update(vcpu);
 		nested_vmx_entry_exit_ctls_update(vcpu);
+		nested_vmx_procbased_ctls_update(vcpu);
 	}
 
 	if (boot_cpu_has(X86_FEATURE_INTEL_PT) &&
@@ -7369,7 +7398,7 @@ static int vmx_pre_enter_smm(struct kvm_vcpu *vcpu, char *smstate)
 	return 0;
 }
 
-static int vmx_pre_leave_smm(struct kvm_vcpu *vcpu, u64 smbase)
+static int vmx_pre_leave_smm(struct kvm_vcpu *vcpu, const char *smstate)
 {
 	struct vcpu_vmx *vmx = to_vmx(vcpu);
 	int ret;
@@ -7380,9 +7409,7 @@ static int vmx_pre_leave_smm(struct kvm_vcpu *vcpu, u64 smbase)
 	}
 
 	if (vmx->nested.smm.guest_mode) {
-		vcpu->arch.hflags &= ~HF_SMM_MASK;
 		ret = nested_vmx_enter_non_root_mode(vcpu, false);
-		vcpu->arch.hflags |= HF_SMM_MASK;
 		if (ret)
 			return ret;
 
diff --git a/arch/x86/kvm/vmx/vmx.h b/arch/x86/kvm/vmx/vmx.h
index a1e00d0a2482c16b81be561c30a4d10d3233975b..f879529906b48cd84e99cc0f672210aaeaffeabd 100644
--- a/arch/x86/kvm/vmx/vmx.h
+++ b/arch/x86/kvm/vmx/vmx.h
@@ -517,4 +517,6 @@ static inline void decache_tsc_multiplier(struct vcpu_vmx *vmx)
 	vmcs_write64(TSC_MULTIPLIER, vmx->current_tsc_ratio);
 }
 
+void dump_vmcs(void);
+
 #endif /* __KVM_X86_VMX_H */
diff --git a/arch/x86/kvm/x86.c b/arch/x86/kvm/x86.c
index 099b851dabafd7e2980f96472209777f9cc8f77b..a0d1fc80ac5a8407c123d8df12eb2215d4d70392 100644
--- a/arch/x86/kvm/x86.c
+++ b/arch/x86/kvm/x86.c
@@ -800,7 +800,7 @@ void kvm_lmsw(struct kvm_vcpu *vcpu, unsigned long msw)
 }
 EXPORT_SYMBOL_GPL(kvm_lmsw);
 
-static void kvm_load_guest_xcr0(struct kvm_vcpu *vcpu)
+void kvm_load_guest_xcr0(struct kvm_vcpu *vcpu)
 {
 	if (kvm_read_cr4_bits(vcpu, X86_CR4_OSXSAVE) &&
 			!vcpu->guest_xcr0_loaded) {
@@ -810,8 +810,9 @@ static void kvm_load_guest_xcr0(struct kvm_vcpu *vcpu)
 		vcpu->guest_xcr0_loaded = 1;
 	}
 }
+EXPORT_SYMBOL_GPL(kvm_load_guest_xcr0);
 
-static void kvm_put_guest_xcr0(struct kvm_vcpu *vcpu)
+void kvm_put_guest_xcr0(struct kvm_vcpu *vcpu)
 {
 	if (vcpu->guest_xcr0_loaded) {
 		if (vcpu->arch.xcr0 != host_xcr0)
@@ -819,6 +820,7 @@ static void kvm_put_guest_xcr0(struct kvm_vcpu *vcpu)
 		vcpu->guest_xcr0_loaded = 0;
 	}
 }
+EXPORT_SYMBOL_GPL(kvm_put_guest_xcr0);
 
 static int __kvm_set_xcr(struct kvm_vcpu *vcpu, u32 index, u64 xcr)
 {
@@ -3093,7 +3095,7 @@ int kvm_vm_ioctl_check_extension(struct kvm *kvm, long ext)
 		break;
 	case KVM_CAP_NESTED_STATE:
 		r = kvm_x86_ops->get_nested_state ?
-			kvm_x86_ops->get_nested_state(NULL, 0, 0) : 0;
+			kvm_x86_ops->get_nested_state(NULL, NULL, 0) : 0;
 		break;
 	default:
 		break;
@@ -3528,7 +3530,7 @@ static void kvm_vcpu_ioctl_x86_get_vcpu_events(struct kvm_vcpu *vcpu,
 	memset(&events->reserved, 0, sizeof(events->reserved));
 }
 
-static void kvm_set_hflags(struct kvm_vcpu *vcpu, unsigned emul_flags);
+static void kvm_smm_changed(struct kvm_vcpu *vcpu);
 
 static int kvm_vcpu_ioctl_x86_set_vcpu_events(struct kvm_vcpu *vcpu,
 					      struct kvm_vcpu_events *events)
@@ -3588,12 +3590,13 @@ static int kvm_vcpu_ioctl_x86_set_vcpu_events(struct kvm_vcpu *vcpu,
 		vcpu->arch.apic->sipi_vector = events->sipi_vector;
 
 	if (events->flags & KVM_VCPUEVENT_VALID_SMM) {
-		u32 hflags = vcpu->arch.hflags;
-		if (events->smi.smm)
-			hflags |= HF_SMM_MASK;
-		else
-			hflags &= ~HF_SMM_MASK;
-		kvm_set_hflags(vcpu, hflags);
+		if (!!(vcpu->arch.hflags & HF_SMM_MASK) != events->smi.smm) {
+			if (events->smi.smm)
+				vcpu->arch.hflags |= HF_SMM_MASK;
+			else
+				vcpu->arch.hflags &= ~HF_SMM_MASK;
+			kvm_smm_changed(vcpu);
+		}
 
 		vcpu->arch.smi_pending = events->smi.pending;
 
@@ -4270,7 +4273,7 @@ static int kvm_vm_ioctl_set_identity_map_addr(struct kvm *kvm,
 }
 
 static int kvm_vm_ioctl_set_nr_mmu_pages(struct kvm *kvm,
-					  u32 kvm_nr_mmu_pages)
+					 unsigned long kvm_nr_mmu_pages)
 {
 	if (kvm_nr_mmu_pages < KVM_MIN_ALLOC_MMU_PAGES)
 		return -EINVAL;
@@ -4284,7 +4287,7 @@ static int kvm_vm_ioctl_set_nr_mmu_pages(struct kvm *kvm,
 	return 0;
 }
 
-static int kvm_vm_ioctl_get_nr_mmu_pages(struct kvm *kvm)
+static unsigned long kvm_vm_ioctl_get_nr_mmu_pages(struct kvm *kvm)
 {
 	return kvm->arch.n_max_mmu_pages;
 }
@@ -5958,12 +5961,18 @@ static unsigned emulator_get_hflags(struct x86_emulate_ctxt *ctxt)
 
 static void emulator_set_hflags(struct x86_emulate_ctxt *ctxt, unsigned emul_flags)
 {
-	kvm_set_hflags(emul_to_vcpu(ctxt), emul_flags);
+	emul_to_vcpu(ctxt)->arch.hflags = emul_flags;
+}
+
+static int emulator_pre_leave_smm(struct x86_emulate_ctxt *ctxt,
+				  const char *smstate)
+{
+	return kvm_x86_ops->pre_leave_smm(emul_to_vcpu(ctxt), smstate);
 }
 
-static int emulator_pre_leave_smm(struct x86_emulate_ctxt *ctxt, u64 smbase)
+static void emulator_post_leave_smm(struct x86_emulate_ctxt *ctxt)
 {
-	return kvm_x86_ops->pre_leave_smm(emul_to_vcpu(ctxt), smbase);
+	kvm_smm_changed(emul_to_vcpu(ctxt));
 }
 
 static const struct x86_emulate_ops emulate_ops = {
@@ -6006,6 +6015,7 @@ static const struct x86_emulate_ops emulate_ops = {
 	.get_hflags          = emulator_get_hflags,
 	.set_hflags          = emulator_set_hflags,
 	.pre_leave_smm       = emulator_pre_leave_smm,
+	.post_leave_smm      = emulator_post_leave_smm,
 };
 
 static void toggle_interruptibility(struct kvm_vcpu *vcpu, u32 mask)
@@ -6247,16 +6257,6 @@ static void kvm_smm_changed(struct kvm_vcpu *vcpu)
 	kvm_mmu_reset_context(vcpu);
 }
 
-static void kvm_set_hflags(struct kvm_vcpu *vcpu, unsigned emul_flags)
-{
-	unsigned changed = vcpu->arch.hflags ^ emul_flags;
-
-	vcpu->arch.hflags = emul_flags;
-
-	if (changed & HF_SMM_MASK)
-		kvm_smm_changed(vcpu);
-}
-
 static int kvm_vcpu_check_hw_bp(unsigned long addr, u32 type, u32 dr7,
 				unsigned long *db)
 {
@@ -7441,9 +7441,9 @@ static void enter_smm_save_state_32(struct kvm_vcpu *vcpu, char *buf)
 	put_smstate(u32, buf, 0x7ef8, vcpu->arch.smbase);
 }
 
+#ifdef CONFIG_X86_64
 static void enter_smm_save_state_64(struct kvm_vcpu *vcpu, char *buf)
 {
-#ifdef CONFIG_X86_64
 	struct desc_ptr dt;
 	struct kvm_segment seg;
 	unsigned long val;
@@ -7493,10 +7493,8 @@ static void enter_smm_save_state_64(struct kvm_vcpu *vcpu, char *buf)
 
 	for (i = 0; i < 6; i++)
 		enter_smm_save_seg_64(vcpu, buf, i);
-#else
-	WARN_ON_ONCE(1);
-#endif
 }
+#endif
 
 static void enter_smm(struct kvm_vcpu *vcpu)
 {
@@ -7507,9 +7505,11 @@ static void enter_smm(struct kvm_vcpu *vcpu)
 
 	trace_kvm_enter_smm(vcpu->vcpu_id, vcpu->arch.smbase, true);
 	memset(buf, 0, 512);
+#ifdef CONFIG_X86_64
 	if (guest_cpuid_has(vcpu, X86_FEATURE_LM))
 		enter_smm_save_state_64(vcpu, buf);
 	else
+#endif
 		enter_smm_save_state_32(vcpu, buf);
 
 	/*
@@ -7567,8 +7567,10 @@ static void enter_smm(struct kvm_vcpu *vcpu)
 	kvm_set_segment(vcpu, &ds, VCPU_SREG_GS);
 	kvm_set_segment(vcpu, &ds, VCPU_SREG_SS);
 
+#ifdef CONFIG_X86_64
 	if (guest_cpuid_has(vcpu, X86_FEATURE_LM))
 		kvm_x86_ops->set_efer(vcpu, 0);
+#endif
 
 	kvm_update_cpuid(vcpu);
 	kvm_mmu_reset_context(vcpu);
@@ -7865,8 +7867,6 @@ static int vcpu_enter_guest(struct kvm_vcpu *vcpu)
 		goto cancel_injection;
 	}
 
-	kvm_load_guest_xcr0(vcpu);
-
 	if (req_immediate_exit) {
 		kvm_make_request(KVM_REQ_EVENT, vcpu);
 		kvm_x86_ops->request_immediate_exit(vcpu);
@@ -7919,8 +7919,6 @@ static int vcpu_enter_guest(struct kvm_vcpu *vcpu)
 	vcpu->mode = OUTSIDE_GUEST_MODE;
 	smp_wmb();
 
-	kvm_put_guest_xcr0(vcpu);
-
 	kvm_before_interrupt(vcpu);
 	kvm_x86_ops->handle_external_intr(vcpu);
 	kvm_after_interrupt(vcpu);
diff --git a/arch/x86/kvm/x86.h b/arch/x86/kvm/x86.h
index 28406aa1136d7eb772ed712f9df34ffe14290e66..aedc5d0d4989b3fc7422c17e55fc6b65bfef06a3 100644
--- a/arch/x86/kvm/x86.h
+++ b/arch/x86/kvm/x86.h
@@ -347,4 +347,6 @@ static inline void kvm_after_interrupt(struct kvm_vcpu *vcpu)
 	__this_cpu_write(current_vcpu, NULL);
 }
 
+void kvm_load_guest_xcr0(struct kvm_vcpu *vcpu);
+void kvm_put_guest_xcr0(struct kvm_vcpu *vcpu);
 #endif
diff --git a/include/linux/kvm_host.h b/include/linux/kvm_host.h
index 9d55c63db09b5dcb9ac997d802cb00ff356d4353..640a03642766bb4ae02c86e3606318c80adaf81d 100644
--- a/include/linux/kvm_host.h
+++ b/include/linux/kvm_host.h
@@ -28,6 +28,7 @@
 #include <linux/irqbypass.h>
 #include <linux/swait.h>
 #include <linux/refcount.h>
+#include <linux/nospec.h>
 #include <asm/signal.h>
 
 #include <linux/kvm.h>
@@ -513,10 +514,10 @@ static inline struct kvm_io_bus *kvm_get_bus(struct kvm *kvm, enum kvm_bus idx)
 
 static inline struct kvm_vcpu *kvm_get_vcpu(struct kvm *kvm, int i)
 {
-	/* Pairs with smp_wmb() in kvm_vm_ioctl_create_vcpu, in case
-	 * the caller has read kvm->online_vcpus before (as is the case
-	 * for kvm_for_each_vcpu, for example).
-	 */
+	int num_vcpus = atomic_read(&kvm->online_vcpus);
+	i = array_index_nospec(i, num_vcpus);
+
+	/* Pairs with smp_wmb() in kvm_vm_ioctl_create_vcpu.  */
 	smp_rmb();
 	return kvm->vcpus[i];
 }
@@ -600,6 +601,7 @@ void kvm_put_kvm(struct kvm *kvm);
 
 static inline struct kvm_memslots *__kvm_memslots(struct kvm *kvm, int as_id)
 {
+	as_id = array_index_nospec(as_id, KVM_ADDRESS_SPACE_NUM);
 	return srcu_dereference_check(kvm->memslots[as_id], &kvm->srcu,
 			lockdep_is_held(&kvm->slots_lock) ||
 			!refcount_read(&kvm->users_count));
diff --git a/tools/testing/selftests/kvm/Makefile b/tools/testing/selftests/kvm/Makefile
index 7514fcea91a73e80a91313ab280fb90e12375138..f8588cca2bef4bfe4d3cdf2afdb6586f21e67894 100644
--- a/tools/testing/selftests/kvm/Makefile
+++ b/tools/testing/selftests/kvm/Makefile
@@ -1,3 +1,5 @@
+include ../../../../scripts/Kbuild.include
+
 all:
 
 top_srcdir = ../../../..
@@ -17,6 +19,7 @@ TEST_GEN_PROGS_x86_64 += x86_64/state_test
 TEST_GEN_PROGS_x86_64 += x86_64/evmcs_test
 TEST_GEN_PROGS_x86_64 += x86_64/hyperv_cpuid
 TEST_GEN_PROGS_x86_64 += x86_64/vmx_close_while_nested_test
+TEST_GEN_PROGS_x86_64 += x86_64/smm_test
 TEST_GEN_PROGS_x86_64 += dirty_log_test
 TEST_GEN_PROGS_x86_64 += clear_dirty_log_test
 
@@ -30,7 +33,11 @@ INSTALL_HDR_PATH = $(top_srcdir)/usr
 LINUX_HDR_PATH = $(INSTALL_HDR_PATH)/include/
 LINUX_TOOL_INCLUDE = $(top_srcdir)/tools/include
 CFLAGS += -O2 -g -std=gnu99 -fno-stack-protector -fno-PIE -I$(LINUX_TOOL_INCLUDE) -I$(LINUX_HDR_PATH) -Iinclude -I$(<D) -Iinclude/$(UNAME_M) -I..
-LDFLAGS += -pthread -no-pie
+
+no-pie-option := $(call try-run, echo 'int main() { return 0; }' | \
+        $(CC) -Werror $(KBUILD_CPPFLAGS) $(CC_OPTION_CFLAGS) -no-pie -x c - -o "$$TMP", -no-pie)
+
+LDFLAGS += -pthread $(no-pie-option)
 
 # After inclusion, $(OUTPUT) is defined and
 # $(TEST_GEN_PROGS) starts with $(OUTPUT)/
diff --git a/tools/testing/selftests/kvm/include/x86_64/processor.h b/tools/testing/selftests/kvm/include/x86_64/processor.h
index e2884c2b81fff80c1ec6c261828dbb0493b3e98b..6063d5b2f3561c450778f86f3d1474390d79b5ec 100644
--- a/tools/testing/selftests/kvm/include/x86_64/processor.h
+++ b/tools/testing/selftests/kvm/include/x86_64/processor.h
@@ -778,6 +778,33 @@ void vcpu_set_msr(struct kvm_vm *vm, uint32_t vcpuid, uint64_t msr_index,
 #define MSR_IA32_APICBASE_ENABLE	(1<<11)
 #define MSR_IA32_APICBASE_BASE		(0xfffff<<12)
 
+#define APIC_BASE_MSR	0x800
+#define X2APIC_ENABLE	(1UL << 10)
+#define	APIC_ICR	0x300
+#define		APIC_DEST_SELF		0x40000
+#define		APIC_DEST_ALLINC	0x80000
+#define		APIC_DEST_ALLBUT	0xC0000
+#define		APIC_ICR_RR_MASK	0x30000
+#define		APIC_ICR_RR_INVALID	0x00000
+#define		APIC_ICR_RR_INPROG	0x10000
+#define		APIC_ICR_RR_VALID	0x20000
+#define		APIC_INT_LEVELTRIG	0x08000
+#define		APIC_INT_ASSERT		0x04000
+#define		APIC_ICR_BUSY		0x01000
+#define		APIC_DEST_LOGICAL	0x00800
+#define		APIC_DEST_PHYSICAL	0x00000
+#define		APIC_DM_FIXED		0x00000
+#define		APIC_DM_FIXED_MASK	0x00700
+#define		APIC_DM_LOWEST		0x00100
+#define		APIC_DM_SMI		0x00200
+#define		APIC_DM_REMRD		0x00300
+#define		APIC_DM_NMI		0x00400
+#define		APIC_DM_INIT		0x00500
+#define		APIC_DM_STARTUP		0x00600
+#define		APIC_DM_EXTINT		0x00700
+#define		APIC_VECTOR_MASK	0x000FF
+#define	APIC_ICR2	0x310
+
 #define MSR_IA32_TSCDEADLINE		0x000006e0
 
 #define MSR_IA32_UCODE_WRITE		0x00000079
diff --git a/tools/testing/selftests/kvm/lib/kvm_util.c b/tools/testing/selftests/kvm/lib/kvm_util.c
index efa0aad8b3c69ab370a1f5440194cee3486c11db..4ca96b228e46ba248476803583cb94d14410ff16 100644
--- a/tools/testing/selftests/kvm/lib/kvm_util.c
+++ b/tools/testing/selftests/kvm/lib/kvm_util.c
@@ -91,6 +91,11 @@ static void vm_open(struct kvm_vm *vm, int perm, unsigned long type)
 	if (vm->kvm_fd < 0)
 		exit(KSFT_SKIP);
 
+	if (!kvm_check_cap(KVM_CAP_IMMEDIATE_EXIT)) {
+		fprintf(stderr, "immediate_exit not available, skipping test\n");
+		exit(KSFT_SKIP);
+	}
+
 	vm->fd = ioctl(vm->kvm_fd, KVM_CREATE_VM, type);
 	TEST_ASSERT(vm->fd >= 0, "KVM_CREATE_VM ioctl failed, "
 		"rc: %i errno: %i", vm->fd, errno);
diff --git a/tools/testing/selftests/kvm/lib/x86_64/processor.c b/tools/testing/selftests/kvm/lib/x86_64/processor.c
index f28127f4a3af63cb9ac15d2124f425e7492fccda..dc7fae9fa424cf2b45fb7acf10c4b58c272763a0 100644
--- a/tools/testing/selftests/kvm/lib/x86_64/processor.c
+++ b/tools/testing/selftests/kvm/lib/x86_64/processor.c
@@ -1030,6 +1030,14 @@ struct kvm_x86_state *vcpu_save_state(struct kvm_vm *vm, uint32_t vcpuid)
 			    nested_size, sizeof(state->nested_));
 	}
 
+	/*
+	 * When KVM exits to userspace with KVM_EXIT_IO, KVM guarantees
+	 * guest state is consistent only after userspace re-enters the
+	 * kernel with KVM_RUN.  Complete IO prior to migrating state
+	 * to a new VM.
+	 */
+	vcpu_run_complete_io(vm, vcpuid);
+
 	nmsrs = kvm_get_num_msrs(vm);
 	list = malloc(sizeof(*list) + nmsrs * sizeof(list->indices[0]));
 	list->nmsrs = nmsrs;
@@ -1093,12 +1101,6 @@ void vcpu_load_state(struct kvm_vm *vm, uint32_t vcpuid, struct kvm_x86_state *s
 	struct vcpu *vcpu = vcpu_find(vm, vcpuid);
 	int r;
 
-	if (state->nested.size) {
-		r = ioctl(vcpu->fd, KVM_SET_NESTED_STATE, &state->nested);
-		TEST_ASSERT(r == 0, "Unexpected result from KVM_SET_NESTED_STATE, r: %i",
-			r);
-	}
-
 	r = ioctl(vcpu->fd, KVM_SET_XSAVE, &state->xsave);
         TEST_ASSERT(r == 0, "Unexpected result from KVM_SET_XSAVE, r: %i",
                 r);
@@ -1130,4 +1132,10 @@ void vcpu_load_state(struct kvm_vm *vm, uint32_t vcpuid, struct kvm_x86_state *s
 	r = ioctl(vcpu->fd, KVM_SET_REGS, &state->regs);
         TEST_ASSERT(r == 0, "Unexpected result from KVM_SET_REGS, r: %i",
                 r);
+
+	if (state->nested.size) {
+		r = ioctl(vcpu->fd, KVM_SET_NESTED_STATE, &state->nested);
+		TEST_ASSERT(r == 0, "Unexpected result from KVM_SET_NESTED_STATE, r: %i",
+			r);
+	}
 }
diff --git a/tools/testing/selftests/kvm/x86_64/evmcs_test.c b/tools/testing/selftests/kvm/x86_64/evmcs_test.c
index c49c2a28b0eb290ccd6c51498a0b9fd716b58b07..36669684eca58a6c09140453f70a403cf0119348 100644
--- a/tools/testing/selftests/kvm/x86_64/evmcs_test.c
+++ b/tools/testing/selftests/kvm/x86_64/evmcs_test.c
@@ -123,8 +123,6 @@ int main(int argc, char *argv[])
 			    stage, run->exit_reason,
 			    exit_reason_str(run->exit_reason));
 
-		memset(&regs1, 0, sizeof(regs1));
-		vcpu_regs_get(vm, VCPU_ID, &regs1);
 		switch (get_ucall(vm, VCPU_ID, &uc)) {
 		case UCALL_ABORT:
 			TEST_ASSERT(false, "%s at %s:%d", (const char *)uc.args[0],
@@ -144,6 +142,9 @@ int main(int argc, char *argv[])
 			    stage, (ulong)uc.args[1]);
 
 		state = vcpu_save_state(vm, VCPU_ID);
+		memset(&regs1, 0, sizeof(regs1));
+		vcpu_regs_get(vm, VCPU_ID, &regs1);
+
 		kvm_vm_release(vm);
 
 		/* Restore state in a new VM.  */
diff --git a/tools/testing/selftests/kvm/x86_64/smm_test.c b/tools/testing/selftests/kvm/x86_64/smm_test.c
new file mode 100644
index 0000000000000000000000000000000000000000..fb8086964d83b80642250b1626f444ede7bb34f2
--- /dev/null
+++ b/tools/testing/selftests/kvm/x86_64/smm_test.c
@@ -0,0 +1,157 @@
+// SPDX-License-Identifier: GPL-2.0
+/*
+ * Copyright (C) 2018, Red Hat, Inc.
+ *
+ * Tests for SMM.
+ */
+#define _GNU_SOURCE /* for program_invocation_short_name */
+#include <fcntl.h>
+#include <stdio.h>
+#include <stdlib.h>
+#include <stdint.h>
+#include <string.h>
+#include <sys/ioctl.h>
+
+#include "test_util.h"
+
+#include "kvm_util.h"
+
+#include "vmx.h"
+
+#define VCPU_ID	      1
+
+#define PAGE_SIZE  4096
+
+#define SMRAM_SIZE 65536
+#define SMRAM_MEMSLOT ((1 << 16) | 1)
+#define SMRAM_PAGES (SMRAM_SIZE / PAGE_SIZE)
+#define SMRAM_GPA 0x1000000
+#define SMRAM_STAGE 0xfe
+
+#define STR(x) #x
+#define XSTR(s) STR(s)
+
+#define SYNC_PORT 0xe
+#define DONE 0xff
+
+/*
+ * This is compiled as normal 64-bit code, however, SMI handler is executed
+ * in real-address mode. To stay simple we're limiting ourselves to a mode
+ * independent subset of asm here.
+ * SMI handler always report back fixed stage SMRAM_STAGE.
+ */
+uint8_t smi_handler[] = {
+	0xb0, SMRAM_STAGE,    /* mov $SMRAM_STAGE, %al */
+	0xe4, SYNC_PORT,      /* in $SYNC_PORT, %al */
+	0x0f, 0xaa,           /* rsm */
+};
+
+void sync_with_host(uint64_t phase)
+{
+	asm volatile("in $" XSTR(SYNC_PORT)", %%al \n"
+		     : : "a" (phase));
+}
+
+void self_smi(void)
+{
+	wrmsr(APIC_BASE_MSR + (APIC_ICR >> 4),
+	      APIC_DEST_SELF | APIC_INT_ASSERT | APIC_DM_SMI);
+}
+
+void guest_code(struct vmx_pages *vmx_pages)
+{
+	uint64_t apicbase = rdmsr(MSR_IA32_APICBASE);
+
+	sync_with_host(1);
+
+	wrmsr(MSR_IA32_APICBASE, apicbase | X2APIC_ENABLE);
+
+	sync_with_host(2);
+
+	self_smi();
+
+	sync_with_host(4);
+
+	if (vmx_pages) {
+		GUEST_ASSERT(prepare_for_vmx_operation(vmx_pages));
+
+		sync_with_host(5);
+
+		self_smi();
+
+		sync_with_host(7);
+	}
+
+	sync_with_host(DONE);
+}
+
+int main(int argc, char *argv[])
+{
+	struct vmx_pages *vmx_pages = NULL;
+	vm_vaddr_t vmx_pages_gva = 0;
+
+	struct kvm_regs regs;
+	struct kvm_vm *vm;
+	struct kvm_run *run;
+	struct kvm_x86_state *state;
+	int stage, stage_reported;
+
+	/* Create VM */
+	vm = vm_create_default(VCPU_ID, 0, guest_code);
+
+	vcpu_set_cpuid(vm, VCPU_ID, kvm_get_supported_cpuid());
+
+	run = vcpu_state(vm, VCPU_ID);
+
+	vm_userspace_mem_region_add(vm, VM_MEM_SRC_ANONYMOUS, SMRAM_GPA,
+				    SMRAM_MEMSLOT, SMRAM_PAGES, 0);
+	TEST_ASSERT(vm_phy_pages_alloc(vm, SMRAM_PAGES, SMRAM_GPA, SMRAM_MEMSLOT)
+		    == SMRAM_GPA, "could not allocate guest physical addresses?");
+
+	memset(addr_gpa2hva(vm, SMRAM_GPA), 0x0, SMRAM_SIZE);
+	memcpy(addr_gpa2hva(vm, SMRAM_GPA) + 0x8000, smi_handler,
+	       sizeof(smi_handler));
+
+	vcpu_set_msr(vm, VCPU_ID, MSR_IA32_SMBASE, SMRAM_GPA);
+
+	if (kvm_check_cap(KVM_CAP_NESTED_STATE)) {
+		vmx_pages = vcpu_alloc_vmx(vm, &vmx_pages_gva);
+		vcpu_args_set(vm, VCPU_ID, 1, vmx_pages_gva);
+	} else {
+		printf("will skip SMM test with VMX enabled\n");
+		vcpu_args_set(vm, VCPU_ID, 1, 0);
+	}
+
+	for (stage = 1;; stage++) {
+		_vcpu_run(vm, VCPU_ID);
+		TEST_ASSERT(run->exit_reason == KVM_EXIT_IO,
+			    "Stage %d: unexpected exit reason: %u (%s),\n",
+			    stage, run->exit_reason,
+			    exit_reason_str(run->exit_reason));
+
+		memset(&regs, 0, sizeof(regs));
+		vcpu_regs_get(vm, VCPU_ID, &regs);
+
+		stage_reported = regs.rax & 0xff;
+
+		if (stage_reported == DONE)
+			goto done;
+
+		TEST_ASSERT(stage_reported == stage ||
+			    stage_reported == SMRAM_STAGE,
+			    "Unexpected stage: #%x, got %x",
+			    stage, stage_reported);
+
+		state = vcpu_save_state(vm, VCPU_ID);
+		kvm_vm_release(vm);
+		kvm_vm_restart(vm, O_RDWR);
+		vm_vcpu_add(vm, VCPU_ID, 0, 0);
+		vcpu_set_cpuid(vm, VCPU_ID, kvm_get_supported_cpuid());
+		vcpu_load_state(vm, VCPU_ID, state);
+		run = vcpu_state(vm, VCPU_ID);
+		free(state);
+	}
+
+done:
+	kvm_vm_free(vm);
+}
diff --git a/tools/testing/selftests/kvm/x86_64/state_test.c b/tools/testing/selftests/kvm/x86_64/state_test.c
index 30f75856cf3984277bee22caad9e5df95f98aa26..e0a3c0204b7cd11c5da7024bea68f0da71e41bab 100644
--- a/tools/testing/selftests/kvm/x86_64/state_test.c
+++ b/tools/testing/selftests/kvm/x86_64/state_test.c
@@ -134,11 +134,6 @@ int main(int argc, char *argv[])
 
 	struct kvm_cpuid_entry2 *entry = kvm_get_supported_cpuid_entry(1);
 
-	if (!kvm_check_cap(KVM_CAP_IMMEDIATE_EXIT)) {
-		fprintf(stderr, "immediate_exit not available, skipping test\n");
-		exit(KSFT_SKIP);
-	}
-
 	/* Create VM */
 	vm = vm_create_default(VCPU_ID, 0, guest_code);
 	vcpu_set_cpuid(vm, VCPU_ID, kvm_get_supported_cpuid());
@@ -179,18 +174,10 @@ int main(int argc, char *argv[])
 			    uc.args[1] == stage, "Unexpected register values vmexit #%lx, got %lx",
 			    stage, (ulong)uc.args[1]);
 
-		/*
-		 * When KVM exits to userspace with KVM_EXIT_IO, KVM guarantees
-		 * guest state is consistent only after userspace re-enters the
-		 * kernel with KVM_RUN.  Complete IO prior to migrating state
-		 * to a new VM.
-		 */
-		vcpu_run_complete_io(vm, VCPU_ID);
-
+		state = vcpu_save_state(vm, VCPU_ID);
 		memset(&regs1, 0, sizeof(regs1));
 		vcpu_regs_get(vm, VCPU_ID, &regs1);
 
-		state = vcpu_save_state(vm, VCPU_ID);
 		kvm_vm_release(vm);
 
 		/* Restore state in a new VM.  */
diff --git a/virt/kvm/irqchip.c b/virt/kvm/irqchip.c
index 3547b0d8c91ea2c84e0869b769e9947829fe4286..79e59e4fa3dc6be751079e669e214b7fc614e07f 100644
--- a/virt/kvm/irqchip.c
+++ b/virt/kvm/irqchip.c
@@ -144,18 +144,19 @@ static int setup_routing_entry(struct kvm *kvm,
 {
 	struct kvm_kernel_irq_routing_entry *ei;
 	int r;
+	u32 gsi = array_index_nospec(ue->gsi, KVM_MAX_IRQ_ROUTES);
 
 	/*
 	 * Do not allow GSI to be mapped to the same irqchip more than once.
 	 * Allow only one to one mapping between GSI and non-irqchip routing.
 	 */
-	hlist_for_each_entry(ei, &rt->map[ue->gsi], link)
+	hlist_for_each_entry(ei, &rt->map[gsi], link)
 		if (ei->type != KVM_IRQ_ROUTING_IRQCHIP ||
 		    ue->type != KVM_IRQ_ROUTING_IRQCHIP ||
 		    ue->u.irqchip.irqchip == ei->irqchip.irqchip)
 			return -EINVAL;
 
-	e->gsi = ue->gsi;
+	e->gsi = gsi;
 	e->type = ue->type;
 	r = kvm_set_routing_entry(kvm, e, ue);
 	if (r)
diff --git a/virt/kvm/kvm_main.c b/virt/kvm/kvm_main.c
index 55fe8e20d8fd9b7367619a250dde9076a74bdc6e..dc8edc97ba850384680b56f88063f61bebfc96c8 100644
--- a/virt/kvm/kvm_main.c
+++ b/virt/kvm/kvm_main.c
@@ -2977,12 +2977,14 @@ static int kvm_ioctl_create_device(struct kvm *kvm,
 	struct kvm_device_ops *ops = NULL;
 	struct kvm_device *dev;
 	bool test = cd->flags & KVM_CREATE_DEVICE_TEST;
+	int type;
 	int ret;
 
 	if (cd->type >= ARRAY_SIZE(kvm_device_ops_table))
 		return -ENODEV;
 
-	ops = kvm_device_ops_table[cd->type];
+	type = array_index_nospec(cd->type, ARRAY_SIZE(kvm_device_ops_table));
+	ops = kvm_device_ops_table[type];
 	if (ops == NULL)
 		return -ENODEV;
 
@@ -2997,7 +2999,7 @@ static int kvm_ioctl_create_device(struct kvm *kvm,
 	dev->kvm = kvm;
 
 	mutex_lock(&kvm->lock);
-	ret = ops->create(dev, cd->type);
+	ret = ops->create(dev, type);
 	if (ret < 0) {
 		mutex_unlock(&kvm->lock);
 		kfree(dev);