diff --git a/arch/x86/kvm/mmu.c b/arch/x86/kvm/mmu.c
index 6fc5c389f5a1ba3169befe5147c9f57df6fc6b7b..af9dafa54f852dddd4a664f6bc051505fa8a7787 100644
--- a/arch/x86/kvm/mmu.c
+++ b/arch/x86/kvm/mmu.c
@@ -3181,40 +3181,39 @@ static void direct_pte_prefetch(struct kvm_vcpu *vcpu, u64 *sptep)
 	__direct_pte_prefetch(vcpu, sp, sptep);
 }
 
-static int __direct_map(struct kvm_vcpu *vcpu, int write, int map_writable,
-			int level, gfn_t gfn, kvm_pfn_t pfn, bool prefault)
+static int __direct_map(struct kvm_vcpu *vcpu, gpa_t gpa, int write,
+			int map_writable, int level, kvm_pfn_t pfn,
+			bool prefault)
 {
-	struct kvm_shadow_walk_iterator iterator;
+	struct kvm_shadow_walk_iterator it;
 	struct kvm_mmu_page *sp;
-	int emulate = 0;
-	gfn_t pseudo_gfn;
+	int ret;
+	gfn_t gfn = gpa >> PAGE_SHIFT;
+	gfn_t base_gfn = gfn;
 
 	if (!VALID_PAGE(vcpu->arch.mmu->root_hpa))
-		return 0;
+		return RET_PF_RETRY;
 
-	for_each_shadow_entry(vcpu, (u64)gfn << PAGE_SHIFT, iterator) {
-		if (iterator.level == level) {
-			emulate = mmu_set_spte(vcpu, iterator.sptep, ACC_ALL,
-					       write, level, gfn, pfn, prefault,
-					       map_writable);
-			direct_pte_prefetch(vcpu, iterator.sptep);
-			++vcpu->stat.pf_fixed;
+	for_each_shadow_entry(vcpu, gpa, it) {
+		base_gfn = gfn & ~(KVM_PAGES_PER_HPAGE(it.level) - 1);
+		if (it.level == level)
 			break;
-		}
 
-		drop_large_spte(vcpu, iterator.sptep);
-		if (!is_shadow_present_pte(*iterator.sptep)) {
-			u64 base_addr = iterator.addr;
+		drop_large_spte(vcpu, it.sptep);
+		if (!is_shadow_present_pte(*it.sptep)) {
+			sp = kvm_mmu_get_page(vcpu, base_gfn, it.addr,
+					      it.level - 1, true, ACC_ALL);
 
-			base_addr &= PT64_LVL_ADDR_MASK(iterator.level);
-			pseudo_gfn = base_addr >> PAGE_SHIFT;
-			sp = kvm_mmu_get_page(vcpu, pseudo_gfn, iterator.addr,
-					      iterator.level - 1, 1, ACC_ALL);
-
-			link_shadow_page(vcpu, iterator.sptep, sp);
+			link_shadow_page(vcpu, it.sptep, sp);
 		}
 	}
-	return emulate;
+
+	ret = mmu_set_spte(vcpu, it.sptep, ACC_ALL,
+			   write, level, base_gfn, pfn, prefault,
+			   map_writable);
+	direct_pte_prefetch(vcpu, it.sptep);
+	++vcpu->stat.pf_fixed;
+	return ret;
 }
 
 static void kvm_send_hwpoison_signal(unsigned long address, struct task_struct *tsk)
@@ -3538,8 +3537,7 @@ static int nonpaging_map(struct kvm_vcpu *vcpu, gva_t v, u32 error_code,
 		goto out_unlock;
 	if (likely(!force_pt_level))
 		transparent_hugepage_adjust(vcpu, &gfn, &pfn, &level);
-	r = __direct_map(vcpu, write, map_writable, level, gfn, pfn, prefault);
-
+	r = __direct_map(vcpu, v, write, map_writable, level, pfn, prefault);
 out_unlock:
 	spin_unlock(&vcpu->kvm->mmu_lock);
 	kvm_release_pfn_clean(pfn);
@@ -4165,8 +4163,7 @@ static int tdp_page_fault(struct kvm_vcpu *vcpu, gva_t gpa, u32 error_code,
 		goto out_unlock;
 	if (likely(!force_pt_level))
 		transparent_hugepage_adjust(vcpu, &gfn, &pfn, &level);
-	r = __direct_map(vcpu, write, map_writable, level, gfn, pfn, prefault);
-
+	r = __direct_map(vcpu, gpa, write, map_writable, level, pfn, prefault);
 out_unlock:
 	spin_unlock(&vcpu->kvm->mmu_lock);
 	kvm_release_pfn_clean(pfn);
diff --git a/arch/x86/kvm/paging_tmpl.h b/arch/x86/kvm/paging_tmpl.h
index 2db96401178ed0dfddda846bca23acb4665cd13d..bfd89966832b0207ff6c027340bcdd561ad7f0f9 100644
--- a/arch/x86/kvm/paging_tmpl.h
+++ b/arch/x86/kvm/paging_tmpl.h
@@ -623,6 +623,7 @@ static int FNAME(fetch)(struct kvm_vcpu *vcpu, gva_t addr,
 	struct kvm_shadow_walk_iterator it;
 	unsigned direct_access, access = gw->pt_access;
 	int top_level, ret;
+	gfn_t base_gfn;
 
 	direct_access = gw->pte_access;
 
@@ -667,31 +668,29 @@ static int FNAME(fetch)(struct kvm_vcpu *vcpu, gva_t addr,
 			link_shadow_page(vcpu, it.sptep, sp);
 	}
 
-	for (;
-	     shadow_walk_okay(&it) && it.level > hlevel;
-	     shadow_walk_next(&it)) {
-		gfn_t direct_gfn;
+	base_gfn = gw->gfn;
 
+	for (; shadow_walk_okay(&it); shadow_walk_next(&it)) {
 		clear_sp_write_flooding_count(it.sptep);
+		base_gfn = gw->gfn & ~(KVM_PAGES_PER_HPAGE(it.level) - 1);
+		if (it.level == hlevel)
+			break;
+
 		validate_direct_spte(vcpu, it.sptep, direct_access);
 
 		drop_large_spte(vcpu, it.sptep);
 
-		if (is_shadow_present_pte(*it.sptep))
-			continue;
-
-		direct_gfn = gw->gfn & ~(KVM_PAGES_PER_HPAGE(it.level) - 1);
-
-		sp = kvm_mmu_get_page(vcpu, direct_gfn, addr, it.level-1,
-				      true, direct_access);
-		link_shadow_page(vcpu, it.sptep, sp);
+		if (!is_shadow_present_pte(*it.sptep)) {
+			sp = kvm_mmu_get_page(vcpu, base_gfn, addr,
+					      it.level - 1, true, direct_access);
+			link_shadow_page(vcpu, it.sptep, sp);
+		}
 	}
 
-	clear_sp_write_flooding_count(it.sptep);
 	ret = mmu_set_spte(vcpu, it.sptep, gw->pte_access, write_fault,
-			   it.level, gw->gfn, pfn, prefault, map_writable);
+			   it.level, base_gfn, pfn, prefault, map_writable);
 	FNAME(pte_prefetch)(vcpu, gw, it.sptep);
-
+	++vcpu->stat.pf_fixed;
 	return ret;
 
 out_gpte_changed:
@@ -854,7 +853,6 @@ static int FNAME(page_fault)(struct kvm_vcpu *vcpu, gva_t addr, u32 error_code,
 		transparent_hugepage_adjust(vcpu, &walker.gfn, &pfn, &level);
 	r = FNAME(fetch)(vcpu, addr, &walker, write_fault,
 			 level, pfn, map_writable, prefault);
-	++vcpu->stat.pf_fixed;
 	kvm_mmu_audit(vcpu, AUDIT_POST_PAGE_FAULT);
 
 out_unlock: