diff --git a/arch/x86/kvm/svm.c b/arch/x86/kvm/svm.c
index cc6467b35a85f6cec9300011cfa0c464574ed5d3..101f53ccf5718d99e5c1e95927b153031114c497 100644
--- a/arch/x86/kvm/svm.c
+++ b/arch/x86/kvm/svm.c
@@ -2937,6 +2937,8 @@ static void nested_svm_inject_npf_exit(struct kvm_vcpu *vcpu,
 static void nested_svm_init_mmu_context(struct kvm_vcpu *vcpu)
 {
 	WARN_ON(mmu_is_nested(vcpu));
+
+	vcpu->arch.mmu = &vcpu->arch.guest_mmu;
 	kvm_init_shadow_mmu(vcpu);
 	vcpu->arch.mmu->set_cr3           = nested_svm_set_tdp_cr3;
 	vcpu->arch.mmu->get_cr3           = nested_svm_get_tdp_cr3;
@@ -2949,6 +2951,7 @@ static void nested_svm_init_mmu_context(struct kvm_vcpu *vcpu)
 
 static void nested_svm_uninit_mmu_context(struct kvm_vcpu *vcpu)
 {
+	vcpu->arch.mmu = &vcpu->arch.root_mmu;
 	vcpu->arch.walk_mmu = &vcpu->arch.root_mmu;
 }
 
@@ -3458,7 +3461,6 @@ static void enter_svm_guest_mode(struct vcpu_svm *svm, u64 vmcb_gpa,
 		svm->vcpu.arch.hflags &= ~HF_HIF_MASK;
 
 	if (nested_vmcb->control.nested_ctl & SVM_NESTED_CTL_NP_ENABLE) {
-		kvm_mmu_unload(&svm->vcpu);
 		svm->nested.nested_cr3 = nested_vmcb->control.nested_cr3;
 		nested_svm_init_mmu_context(&svm->vcpu);
 	}