KVM: x86/mmu: Batch zap MMU pages when shrinking the slab
diff --git a/arch/x86/kvm/mmu/mmu.c b/arch/x86/kvm/mmu/mmu.c
index fd59fee846315d10f6208bcd46f14115cf0fcadd..8083ec32a0dd5fbaaa0026dee1e2c875cfc9be3e 100644
--- a/arch/x86/kvm/mmu/mmu.c
+++ b/arch/x86/kvm/mmu/mmu.c
@@ -16,6 +16,7 @@
  */
 
 #include "irq.h"
+#include "ioapic.h"
 #include "mmu.h"
 #include "x86.h"
 #include "kvm_cache_regs.h"
@@ -78,6 +79,9 @@ module_param_cb(nx_huge_pages_recovery_ratio, &nx_huge_pages_recovery_ratio_ops,
                &nx_huge_pages_recovery_ratio, 0644);
 __MODULE_PARM_TYPE(nx_huge_pages_recovery_ratio, "uint");
 
+static bool __read_mostly force_flush_and_sync_on_reuse;
+module_param_named(flush_on_reuse, force_flush_and_sync_on_reuse, bool, 0644);
+
 /*
  * When setting this variable to true it enables Two-Dimensional-Paging
  * where the hardware walks 2 page tables:
@@ -244,7 +248,6 @@ static u64 __read_mostly shadow_x_mask;     /* mutual exclusive with nx_mask */
 static u64 __read_mostly shadow_user_mask;
 static u64 __read_mostly shadow_accessed_mask;
 static u64 __read_mostly shadow_dirty_mask;
-static u64 __read_mostly shadow_mmio_mask;
 static u64 __read_mostly shadow_mmio_value;
 static u64 __read_mostly shadow_mmio_access_mask;
 static u64 __read_mostly shadow_present_mask;
@@ -331,19 +334,19 @@ static void kvm_flush_remote_tlbs_with_address(struct kvm *kvm,
        kvm_flush_remote_tlbs_with_range(kvm, &range);
 }
 
-void kvm_mmu_set_mmio_spte_mask(u64 mmio_mask, u64 mmio_value, u64 access_mask)
+void kvm_mmu_set_mmio_spte_mask(u64 mmio_value, u64 access_mask)
 {
        BUG_ON((u64)(unsigned)access_mask != access_mask);
-       BUG_ON((mmio_mask & mmio_value) != mmio_value);
+       WARN_ON(mmio_value & (shadow_nonpresent_or_rsvd_mask << shadow_nonpresent_or_rsvd_mask_len));
+       WARN_ON(mmio_value & shadow_nonpresent_or_rsvd_lower_gfn_mask);
        shadow_mmio_value = mmio_value | SPTE_MMIO_MASK;
-       shadow_mmio_mask = mmio_mask | SPTE_SPECIAL_MASK;
        shadow_mmio_access_mask = access_mask;
 }
 EXPORT_SYMBOL_GPL(kvm_mmu_set_mmio_spte_mask);
 
 static bool is_mmio_spte(u64 spte)
 {
-       return (spte & shadow_mmio_mask) == shadow_mmio_value;
+       return (spte & SPTE_SPECIAL_MASK) == SPTE_MMIO_MASK;
 }
 
 static inline bool sp_ad_disabled(struct kvm_mmu_page *sp)
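[The hunk above drops the standalone shadow_mmio_mask: the SPTE "special" bits by themselves now identify an MMIO entry. A minimal standalone sketch of that tagging scheme, with bit positions invented purely for illustration (not KVM's real layout), might look like:]

#include <stdbool.h>
#include <stdint.h>

/* Bit positions below are made up for the example; KVM's layout differs. */
#define SPTE_SPECIAL_MASK  (3ULL << 52)   /* bits reserved for tagging SPTEs */
#define SPTE_MMIO_MASK     (3ULL << 52)   /* tag value that means "MMIO"     */

static uint64_t shadow_mmio_value;

/* After the change, only the MMIO tag value itself is configured. */
static void set_mmio_spte_mask(uint64_t mmio_value)
{
        shadow_mmio_value = mmio_value | SPTE_MMIO_MASK;
}

/* The tag bits alone identify an MMIO SPTE; no separate mmio_mask needed. */
static bool is_mmio_spte(uint64_t spte)
{
        return (spte & SPTE_SPECIAL_MASK) == SPTE_MMIO_MASK;
}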
@@ -566,7 +569,6 @@ static void kvm_mmu_reset_all_pte_masks(void)
        shadow_dirty_mask = 0;
        shadow_nx_mask = 0;
        shadow_x_mask = 0;
-       shadow_mmio_mask = 0;
        shadow_present_mask = 0;
        shadow_acc_track_mask = 0;
 
@@ -583,16 +585,15 @@ static void kvm_mmu_reset_all_pte_masks(void)
         * the most significant bits of legal physical address space.
         */
        shadow_nonpresent_or_rsvd_mask = 0;
-       low_phys_bits = boot_cpu_data.x86_cache_bits;
-       if (boot_cpu_data.x86_cache_bits <
-           52 - shadow_nonpresent_or_rsvd_mask_len) {
+       low_phys_bits = boot_cpu_data.x86_phys_bits;
+       if (boot_cpu_has_bug(X86_BUG_L1TF) &&
+           !WARN_ON_ONCE(boot_cpu_data.x86_cache_bits >=
+                         52 - shadow_nonpresent_or_rsvd_mask_len)) {
+               low_phys_bits = boot_cpu_data.x86_cache_bits
+                       - shadow_nonpresent_or_rsvd_mask_len;
                shadow_nonpresent_or_rsvd_mask =
-                       rsvd_bits(boot_cpu_data.x86_cache_bits -
-                                 shadow_nonpresent_or_rsvd_mask_len,
-                                 boot_cpu_data.x86_cache_bits - 1);
-               low_phys_bits -= shadow_nonpresent_or_rsvd_mask_len;
-       } else
-               WARN_ON_ONCE(boot_cpu_has_bug(X86_BUG_L1TF));
+                       rsvd_bits(low_phys_bits, boot_cpu_data.x86_cache_bits - 1);
+       }
 
        shadow_nonpresent_or_rsvd_lower_gfn_mask =
                GENMASK_ULL(low_phys_bits - 1, PAGE_SHIFT);
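[To sanity-check the mask arithmetic above, here is a small userspace sketch with assumed CPU values (x86_cache_bits = 44 and a 5-bit shadow_nonpresent_or_rsvd_mask_len); it reproduces the two masks with a GENMASK_ULL equivalent and should report bits 39..43 and 12..38 respectively:]

#include <stdint.h>
#include <stdio.h>

#define GENMASK_ULL(h, l)  ((~0ULL << (l)) & (~0ULL >> (63 - (h))))
#define PAGE_SHIFT 12

int main(void)
{
        unsigned int cache_bits = 44;     /* assumed boot_cpu_data.x86_cache_bits */
        unsigned int mask_len = 5;        /* assumed shadow_nonpresent_or_rsvd_mask_len */
        unsigned int low_phys_bits = cache_bits - mask_len;        /* 39 */

        uint64_t rsvd_mask = GENMASK_ULL(cache_bits - 1, low_phys_bits);
        uint64_t lower_gfn_mask = GENMASK_ULL(low_phys_bits - 1, PAGE_SHIFT);

        /* Expect bits 39..43 and bits 12..38. */
        printf("shadow_nonpresent_or_rsvd_mask           = %#llx\n",
               (unsigned long long)rsvd_mask);
        printf("shadow_nonpresent_or_rsvd_lower_gfn_mask = %#llx\n",
               (unsigned long long)lower_gfn_mask);
        return 0;
}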
@@ -620,7 +621,7 @@ static int is_large_pte(u64 pte)
 
 static int is_last_spte(u64 pte, int level)
 {
-       if (level == PT_PAGE_TABLE_LEVEL)
+       if (level == PG_LEVEL_4K)
                return 1;
        if (is_large_pte(pte))
                return 1;
@@ -1196,7 +1197,7 @@ static void update_gfn_disallow_lpage_count(struct kvm_memory_slot *slot,
        struct kvm_lpage_info *linfo;
        int i;
 
-       for (i = PT_DIRECTORY_LEVEL; i <= PT_MAX_HUGEPAGE_LEVEL; ++i) {
+       for (i = PG_LEVEL_2M; i <= KVM_MAX_HUGEPAGE_LEVEL; ++i) {
                linfo = lpage_info_slot(gfn, slot, i);
                linfo->disallow_lpage += count;
                WARN_ON(linfo->disallow_lpage < 0);
@@ -1225,7 +1226,7 @@ static void account_shadowed(struct kvm *kvm, struct kvm_mmu_page *sp)
        slot = __gfn_to_memslot(slots, gfn);
 
        /* the non-leaf shadow pages are keeping readonly. */
-       if (sp->role.level > PT_PAGE_TABLE_LEVEL)
+       if (sp->role.level > PG_LEVEL_4K)
                return kvm_slot_page_track_add_page(kvm, slot, gfn,
                                                    KVM_PAGE_TRACK_WRITE);
 
@@ -1253,7 +1254,7 @@ static void unaccount_shadowed(struct kvm *kvm, struct kvm_mmu_page *sp)
        gfn = sp->gfn;
        slots = kvm_memslots_for_spte_role(kvm, sp->role);
        slot = __gfn_to_memslot(slots, gfn);
-       if (sp->role.level > PT_PAGE_TABLE_LEVEL)
+       if (sp->role.level > PG_LEVEL_4K)
                return kvm_slot_page_track_remove_page(kvm, slot, gfn,
                                                       KVM_PAGE_TRACK_WRITE);
 
@@ -1398,7 +1399,7 @@ static struct kvm_rmap_head *__gfn_to_rmap(gfn_t gfn, int level,
        unsigned long idx;
 
        idx = gfn_to_index(gfn, slot->base_gfn, level);
-       return &slot->arch.rmap[level - PT_PAGE_TABLE_LEVEL][idx];
+       return &slot->arch.rmap[level - PG_LEVEL_4K][idx];
 }
 
 static struct kvm_rmap_head *gfn_to_rmap(struct kvm *kvm, gfn_t gfn,
@@ -1529,8 +1530,7 @@ static void drop_spte(struct kvm *kvm, u64 *sptep)
 static bool __drop_large_spte(struct kvm *kvm, u64 *sptep)
 {
        if (is_large_pte(*sptep)) {
-               WARN_ON(page_header(__pa(sptep))->role.level ==
-                       PT_PAGE_TABLE_LEVEL);
+               WARN_ON(page_header(__pa(sptep))->role.level == PG_LEVEL_4K);
                drop_spte(kvm, sptep);
                --kvm->stat.lpages;
                return true;
@@ -1682,7 +1682,7 @@ static void kvm_mmu_write_protect_pt_masked(struct kvm *kvm,
 
        while (mask) {
                rmap_head = __gfn_to_rmap(slot->base_gfn + gfn_offset + __ffs(mask),
-                                         PT_PAGE_TABLE_LEVEL, slot);
+                                         PG_LEVEL_4K, slot);
                __rmap_write_protect(kvm, rmap_head, false);
 
                /* clear the first set bit */
@@ -1708,7 +1708,7 @@ void kvm_mmu_clear_dirty_pt_masked(struct kvm *kvm,
 
        while (mask) {
                rmap_head = __gfn_to_rmap(slot->base_gfn + gfn_offset + __ffs(mask),
-                                         PT_PAGE_TABLE_LEVEL, slot);
+                                         PG_LEVEL_4K, slot);
                __rmap_clear_dirty(kvm, rmap_head);
 
                /* clear the first set bit */
@@ -1738,21 +1738,6 @@ void kvm_arch_mmu_enable_log_dirty_pt_masked(struct kvm *kvm,
                kvm_mmu_write_protect_pt_masked(kvm, slot, gfn_offset, mask);
 }
 
-/**
- * kvm_arch_write_log_dirty - emulate dirty page logging
- * @vcpu: Guest mode vcpu
- *
- * Emulate arch specific page modification logging for the
- * nested hypervisor
- */
-int kvm_arch_write_log_dirty(struct kvm_vcpu *vcpu)
-{
-       if (kvm_x86_ops.write_log_dirty)
-               return kvm_x86_ops.write_log_dirty(vcpu);
-
-       return 0;
-}
-
 bool kvm_mmu_slot_gfn_write_protect(struct kvm *kvm,
                                    struct kvm_memory_slot *slot, u64 gfn)
 {
@@ -1760,7 +1745,7 @@ bool kvm_mmu_slot_gfn_write_protect(struct kvm *kvm,
        int i;
        bool write_protected = false;
 
-       for (i = PT_PAGE_TABLE_LEVEL; i <= PT_MAX_HUGEPAGE_LEVEL; ++i) {
+       for (i = PG_LEVEL_4K; i <= KVM_MAX_HUGEPAGE_LEVEL; ++i) {
                rmap_head = __gfn_to_rmap(gfn, i, slot);
                write_protected |= __rmap_write_protect(kvm, rmap_head, true);
        }
@@ -1948,8 +1933,8 @@ static int kvm_handle_hva_range(struct kvm *kvm,
                        gfn_start = hva_to_gfn_memslot(hva_start, memslot);
                        gfn_end = hva_to_gfn_memslot(hva_end + PAGE_SIZE - 1, memslot);
 
-                       for_each_slot_rmap_range(memslot, PT_PAGE_TABLE_LEVEL,
-                                                PT_MAX_HUGEPAGE_LEVEL,
+                       for_each_slot_rmap_range(memslot, PG_LEVEL_4K,
+                                                KVM_MAX_HUGEPAGE_LEVEL,
                                                 gfn_start, gfn_end - 1,
                                                 &iterator)
                                ret |= handler(kvm, iterator.rmap, memslot,
@@ -2153,10 +2138,6 @@ static int nonpaging_sync_page(struct kvm_vcpu *vcpu,
        return 0;
 }
 
-static void nonpaging_invlpg(struct kvm_vcpu *vcpu, gva_t gva, hpa_t root)
-{
-}
-
 static void nonpaging_update_pte(struct kvm_vcpu *vcpu,
                                 struct kvm_mmu_page *sp, u64 *spte,
                                 const void *pte)
@@ -2262,15 +2243,14 @@ static bool kvm_mmu_prepare_zap_page(struct kvm *kvm, struct kvm_mmu_page *sp,
 static void kvm_mmu_commit_zap_page(struct kvm *kvm,
                                    struct list_head *invalid_list);
 
-
-#define for_each_valid_sp(_kvm, _sp, _gfn)                             \
-       hlist_for_each_entry(_sp,                                       \
-         &(_kvm)->arch.mmu_page_hash[kvm_page_table_hashfn(_gfn)], hash_link) \
+#define for_each_valid_sp(_kvm, _sp, _list)                            \
+       hlist_for_each_entry(_sp, _list, hash_link)                     \
                if (is_obsolete_sp((_kvm), (_sp))) {                    \
                } else
 
 #define for_each_gfn_indirect_valid_sp(_kvm, _sp, _gfn)                        \
-       for_each_valid_sp(_kvm, _sp, _gfn)                              \
+       for_each_valid_sp(_kvm, _sp,                                    \
+         &(_kvm)->arch.mmu_page_hash[kvm_page_table_hashfn(_gfn)])     \
                if ((_sp)->gfn != (_gfn) || (_sp)->role.direct) {} else
 
 static inline bool is_ept_sp(struct kvm_mmu_page *sp)
@@ -2313,7 +2293,7 @@ static void kvm_mmu_flush_or_zap(struct kvm_vcpu *vcpu,
                return;
 
        if (local_flush)
-               kvm_make_request(KVM_REQ_TLB_FLUSH, vcpu);
+               kvm_make_request(KVM_REQ_TLB_FLUSH_CURRENT, vcpu);
 }
 
 #ifdef CONFIG_KVM_MMU_AUDIT
@@ -2347,7 +2327,7 @@ static bool kvm_sync_pages(struct kvm_vcpu *vcpu, gfn_t gfn,
                if (!s->unsync)
                        continue;
 
-               WARN_ON(s->role.level != PT_PAGE_TABLE_LEVEL);
+               WARN_ON(s->role.level != PG_LEVEL_4K);
                ret |= kvm_sync_page(vcpu, s, invalid_list);
        }
 
@@ -2376,7 +2356,7 @@ static int mmu_pages_next(struct kvm_mmu_pages *pvec,
                int level = sp->role.level;
 
                parents->idx[level-1] = idx;
-               if (level == PT_PAGE_TABLE_LEVEL)
+               if (level == PG_LEVEL_4K)
                        break;
 
                parents->parent[level-2] = sp;
@@ -2398,7 +2378,7 @@ static int mmu_pages_first(struct kvm_mmu_pages *pvec,
 
        sp = pvec->page[0].sp;
        level = sp->role.level;
-       WARN_ON(level == PT_PAGE_TABLE_LEVEL);
+       WARN_ON(level == PG_LEVEL_4K);
 
        parents->parent[level-2] = sp;
 
@@ -2480,7 +2460,9 @@ static struct kvm_mmu_page *kvm_mmu_get_page(struct kvm_vcpu *vcpu,
                                             int direct,
                                             unsigned int access)
 {
+       bool direct_mmu = vcpu->arch.mmu->direct_map;
        union kvm_mmu_page_role role;
+       struct hlist_head *sp_list;
        unsigned quadrant;
        struct kvm_mmu_page *sp;
        bool need_sync = false;
@@ -2494,13 +2476,14 @@ static struct kvm_mmu_page *kvm_mmu_get_page(struct kvm_vcpu *vcpu,
        if (role.direct)
                role.gpte_is_8_bytes = true;
        role.access = access;
-       if (!vcpu->arch.mmu->direct_map
-           && vcpu->arch.mmu->root_level <= PT32_ROOT_LEVEL) {
+       if (!direct_mmu && vcpu->arch.mmu->root_level <= PT32_ROOT_LEVEL) {
                quadrant = gaddr >> (PAGE_SHIFT + (PT64_PT_BITS * level));
                quadrant &= (1 << ((PT32_PT_BITS - PT64_PT_BITS) * level)) - 1;
                role.quadrant = quadrant;
        }
-       for_each_valid_sp(vcpu->kvm, sp, gfn) {
+
+       sp_list = &vcpu->kvm->arch.mmu_page_hash[kvm_page_table_hashfn(gfn)];
+       for_each_valid_sp(vcpu->kvm, sp, sp_list) {
                if (sp->gfn != gfn) {
                        collisions++;
                        continue;
@@ -2512,6 +2495,9 @@ static struct kvm_mmu_page *kvm_mmu_get_page(struct kvm_vcpu *vcpu,
                if (sp->role.word != role.word)
                        continue;
 
+               if (direct_mmu)
+                       goto trace_get_page;
+
                if (sp->unsync) {
                        /* The page is good, but __kvm_sync_page might still end
                         * up zapping it.  If so, break in order to rebuild it.
@@ -2520,13 +2506,15 @@ static struct kvm_mmu_page *kvm_mmu_get_page(struct kvm_vcpu *vcpu,
                                break;
 
                        WARN_ON(!list_empty(&invalid_list));
-                       kvm_make_request(KVM_REQ_TLB_FLUSH, vcpu);
+                       kvm_make_request(KVM_REQ_TLB_FLUSH_CURRENT, vcpu);
                }
 
                if (sp->unsync_children)
-                       kvm_make_request(KVM_REQ_MMU_SYNC, vcpu);
+                       kvm_make_request(KVM_REQ_TLB_FLUSH_CURRENT, vcpu);
 
                __clear_sp_write_flooding_count(sp);
+
+trace_get_page:
                trace_kvm_mmu_get_page(sp, false);
                goto out;
        }
@@ -2537,8 +2525,7 @@ static struct kvm_mmu_page *kvm_mmu_get_page(struct kvm_vcpu *vcpu,
 
        sp->gfn = gfn;
        sp->role = role;
-       hlist_add_head(&sp->hash_link,
-               &vcpu->kvm->arch.mmu_page_hash[kvm_page_table_hashfn(gfn)]);
+       hlist_add_head(&sp->hash_link, sp_list);
        if (!direct) {
                /*
                 * we should do write protection before syncing pages
@@ -2546,11 +2533,10 @@ static struct kvm_mmu_page *kvm_mmu_get_page(struct kvm_vcpu *vcpu,
                 * be inconsistent with guest page table.
                 */
                account_shadowed(vcpu->kvm, sp);
-               if (level == PT_PAGE_TABLE_LEVEL &&
-                     rmap_write_protect(vcpu, gfn))
+               if (level == PG_LEVEL_4K && rmap_write_protect(vcpu, gfn))
                        kvm_flush_remote_tlbs_with_address(vcpu->kvm, gfn, 1);
 
-               if (level > PT_PAGE_TABLE_LEVEL && need_sync)
+               if (level > PG_LEVEL_4K && need_sync)
                        flush |= kvm_sync_pages(vcpu, gfn, &invalid_list);
        }
        clear_page(sp->spt);
@@ -2601,7 +2587,7 @@ static void shadow_walk_init(struct kvm_shadow_walk_iterator *iterator,
 
 static bool shadow_walk_okay(struct kvm_shadow_walk_iterator *iterator)
 {
-       if (iterator->level < PT_PAGE_TABLE_LEVEL)
+       if (iterator->level < PG_LEVEL_4K)
                return false;
 
        iterator->index = SHADOW_PT_INDEX(iterator->addr, iterator->level);
@@ -2722,7 +2708,7 @@ static int mmu_zap_unsync_children(struct kvm *kvm,
        struct mmu_page_path parents;
        struct kvm_mmu_pages pages;
 
-       if (parent->role.level == PT_PAGE_TABLE_LEVEL)
+       if (parent->role.level == PG_LEVEL_4K)
                return 0;
 
        while (mmu_unsync_walk(parent, &pages)) {
@@ -2762,10 +2748,23 @@ static bool __kvm_mmu_prepare_zap_page(struct kvm *kvm,
        if (!sp->root_count) {
                /* Count self */
                (*nr_zapped)++;
-               list_move(&sp->link, invalid_list);
+
+               /*
+                * Already invalid pages (previously active roots) are not on
+                * the active page list.  See list_del() in the "else" case of
+                * !sp->root_count.
+                */
+               if (sp->role.invalid)
+                       list_add(&sp->link, invalid_list);
+               else
+                       list_move(&sp->link, invalid_list);
                kvm_mod_used_mmu_pages(kvm, -1);
        } else {
-               list_move(&sp->link, &kvm->arch.active_mmu_pages);
+               /*
+                * Remove the active root from the active page list, the root
+                * will be explicitly freed when the root_count hits zero.
+                */
+               list_del(&sp->link);
 
                /*
                 * Obsolete pages cannot be used on any vCPUs, see the comment
@@ -2817,33 +2816,51 @@ static void kvm_mmu_commit_zap_page(struct kvm *kvm,
        }
 }
 
-static bool prepare_zap_oldest_mmu_page(struct kvm *kvm,
-                                       struct list_head *invalid_list)
+static unsigned long kvm_mmu_zap_oldest_mmu_pages(struct kvm *kvm,
+                                                 unsigned long nr_to_zap)
 {
-       struct kvm_mmu_page *sp;
+       unsigned long total_zapped = 0;
+       struct kvm_mmu_page *sp, *tmp;
+       LIST_HEAD(invalid_list);
+       bool unstable;
+       int nr_zapped;
 
        if (list_empty(&kvm->arch.active_mmu_pages))
-               return false;
+               return 0;
+
+restart:
+       list_for_each_entry_safe(sp, tmp, &kvm->arch.active_mmu_pages, link) {
+               /*
+                * Don't zap active root pages, the page itself can't be freed
+                * and zapping it will just force vCPUs to realloc and reload.
+                */
+               if (sp->root_count)
+                       continue;
+
+               unstable = __kvm_mmu_prepare_zap_page(kvm, sp, &invalid_list,
+                                                     &nr_zapped);
+               total_zapped += nr_zapped;
+               if (total_zapped >= nr_to_zap)
+                       break;
+
+               if (unstable)
+                       goto restart;
+       }
 
-       sp = list_last_entry(&kvm->arch.active_mmu_pages,
-                            struct kvm_mmu_page, link);
-       return kvm_mmu_prepare_zap_page(kvm, sp, invalid_list);
+       kvm_mmu_commit_zap_page(kvm, &invalid_list);
+
+       kvm->stat.mmu_recycled += total_zapped;
+       return total_zapped;
 }
 
 static int make_mmu_pages_available(struct kvm_vcpu *vcpu)
 {
-       LIST_HEAD(invalid_list);
+       unsigned long avail = kvm_mmu_available_pages(vcpu->kvm);
 
-       if (likely(kvm_mmu_available_pages(vcpu->kvm) >= KVM_MIN_FREE_MMU_PAGES))
+       if (likely(avail >= KVM_MIN_FREE_MMU_PAGES))
                return 0;
 
-       while (kvm_mmu_available_pages(vcpu->kvm) < KVM_REFILL_PAGES) {
-               if (!prepare_zap_oldest_mmu_page(vcpu->kvm, &invalid_list))
-                       break;
-
-               ++vcpu->kvm->stat.mmu_recycled;
-       }
-       kvm_mmu_commit_zap_page(vcpu->kvm, &invalid_list);
+       kvm_mmu_zap_oldest_mmu_pages(vcpu->kvm, KVM_REFILL_PAGES - avail);
 
        if (!kvm_mmu_available_pages(vcpu->kvm))
                return -ENOSPC;
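[This and the following hunk replace the one-page-at-a-time prepare_zap_oldest_mmu_page() with batched reclaim. A rough userspace sketch of the pattern, with made-up types and none of the kernel's locking, commit step, or restart-on-unstable handling, is:]

#include <stdio.h>
#include <stdlib.h>

struct mmu_page {
        struct mmu_page *next;      /* list is ordered oldest first */
        int root_count;             /* non-zero: in use, cannot be reclaimed */
        int id;
};

static unsigned long zap_oldest_pages(struct mmu_page **head,
                                      unsigned long nr_to_zap)
{
        unsigned long total_zapped = 0;
        struct mmu_page **pp = head;

        while (*pp && total_zapped < nr_to_zap) {
                struct mmu_page *sp = *pp;

                if (sp->root_count) {         /* active root: skip it */
                        pp = &sp->next;
                        continue;
                }

                *pp = sp->next;               /* unlink and "zap" */
                printf("zapped page %d\n", sp->id);
                free(sp);
                total_zapped++;
        }
        return total_zapped;
}

int main(void)
{
        struct mmu_page *head = NULL, **tail = &head;

        for (int i = 0; i < 5; i++) {
                struct mmu_page *sp = calloc(1, sizeof(*sp));
                sp->id = i;
                sp->root_count = (i == 2);    /* page 2 stands in for an in-use root */
                *tail = sp;
                tail = &sp->next;
        }

        printf("freed %lu pages\n", zap_oldest_pages(&head, 3));
        return 0;
}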
@@ -2856,17 +2873,12 @@ static int make_mmu_pages_available(struct kvm_vcpu *vcpu)
  */
 void kvm_mmu_change_mmu_pages(struct kvm *kvm, unsigned long goal_nr_mmu_pages)
 {
-       LIST_HEAD(invalid_list);
-
        spin_lock(&kvm->mmu_lock);
 
        if (kvm->arch.n_used_mmu_pages > goal_nr_mmu_pages) {
-               /* Need to free some mmu pages to achieve the goal. */
-               while (kvm->arch.n_used_mmu_pages > goal_nr_mmu_pages)
-                       if (!prepare_zap_oldest_mmu_page(kvm, &invalid_list))
-                               break;
+               kvm_mmu_zap_oldest_mmu_pages(kvm, kvm->arch.n_used_mmu_pages -
+                                                 goal_nr_mmu_pages);
 
-               kvm_mmu_commit_zap_page(kvm, &invalid_list);
                goal_nr_mmu_pages = kvm->arch.n_used_mmu_pages;
        }
 
@@ -2921,7 +2933,7 @@ static bool mmu_need_write_protect(struct kvm_vcpu *vcpu, gfn_t gfn,
                if (sp->unsync)
                        continue;
 
-               WARN_ON(sp->role.level != PT_PAGE_TABLE_LEVEL);
+               WARN_ON(sp->role.level != PG_LEVEL_4K);
                kvm_unsync_page(vcpu, sp);
        }
 
@@ -3020,7 +3032,7 @@ static int set_spte(struct kvm_vcpu *vcpu, u64 *sptep,
        if (!speculative)
                spte |= spte_shadow_accessed_mask(spte);
 
-       if (level > PT_PAGE_TABLE_LEVEL && (pte_access & ACC_EXEC_MASK) &&
+       if (level > PG_LEVEL_4K && (pte_access & ACC_EXEC_MASK) &&
            is_nx_huge_page_enabled()) {
                pte_access &= ~ACC_EXEC_MASK;
        }
@@ -3033,7 +3045,7 @@ static int set_spte(struct kvm_vcpu *vcpu, u64 *sptep,
        if (pte_access & ACC_USER_MASK)
                spte |= shadow_user_mask;
 
-       if (level > PT_PAGE_TABLE_LEVEL)
+       if (level > PG_LEVEL_4K)
                spte |= PT_PAGE_SIZE_MASK;
        if (tdp_enabled)
                spte |= kvm_x86_ops.get_mt_mask(vcpu, gfn,
@@ -3103,8 +3115,7 @@ static int mmu_set_spte(struct kvm_vcpu *vcpu, u64 *sptep,
                 * If we overwrite a PTE page pointer with a 2MB PMD, unlink
                 * the parent of the now unreachable PTE.
                 */
-               if (level > PT_PAGE_TABLE_LEVEL &&
-                   !is_large_pte(*sptep)) {
+               if (level > PG_LEVEL_4K && !is_large_pte(*sptep)) {
                        struct kvm_mmu_page *child;
                        u64 pte = *sptep;
 
@@ -3125,7 +3136,7 @@ static int mmu_set_spte(struct kvm_vcpu *vcpu, u64 *sptep,
        if (set_spte_ret & SET_SPTE_WRITE_PROTECTED_PT) {
                if (write_fault)
                        ret = RET_PF_EMULATE;
-               kvm_make_request(KVM_REQ_TLB_FLUSH, vcpu);
+               kvm_make_request(KVM_REQ_TLB_FLUSH_CURRENT, vcpu);
        }
 
        if (set_spte_ret & SET_SPTE_NEED_REMOTE_TLB_FLUSH || flush)
@@ -3228,7 +3239,7 @@ static void direct_pte_prefetch(struct kvm_vcpu *vcpu, u64 *sptep)
        if (sp_ad_disabled(sp))
                return;
 
-       if (sp->role.level > PT_PAGE_TABLE_LEVEL)
+       if (sp->role.level > PG_LEVEL_4K)
                return;
 
        __direct_pte_prefetch(vcpu, sp, sptep);
@@ -3241,12 +3252,8 @@ static int host_pfn_mapping_level(struct kvm_vcpu *vcpu, gfn_t gfn,
        pte_t *pte;
        int level;
 
-       BUILD_BUG_ON(PT_PAGE_TABLE_LEVEL != (int)PG_LEVEL_4K ||
-                    PT_DIRECTORY_LEVEL != (int)PG_LEVEL_2M ||
-                    PT_PDPE_LEVEL != (int)PG_LEVEL_1G);
-
        if (!PageCompound(pfn_to_page(pfn)) && !kvm_is_zone_device_pfn(pfn))
-               return PT_PAGE_TABLE_LEVEL;
+               return PG_LEVEL_4K;
 
        /*
         * Note, using the already-retrieved memslot and __gfn_to_hva_memslot()
@@ -3260,7 +3267,7 @@ static int host_pfn_mapping_level(struct kvm_vcpu *vcpu, gfn_t gfn,
 
        pte = lookup_address_in_mm(vcpu->kvm->mm, hva, &level);
        if (unlikely(!pte))
-               return PT_PAGE_TABLE_LEVEL;
+               return PG_LEVEL_4K;
 
        return level;
 }
@@ -3274,28 +3281,28 @@ static int kvm_mmu_hugepage_adjust(struct kvm_vcpu *vcpu, gfn_t gfn,
        kvm_pfn_t mask;
        int level;
 
-       if (unlikely(max_level == PT_PAGE_TABLE_LEVEL))
-               return PT_PAGE_TABLE_LEVEL;
+       if (unlikely(max_level == PG_LEVEL_4K))
+               return PG_LEVEL_4K;
 
        if (is_error_noslot_pfn(pfn) || kvm_is_reserved_pfn(pfn))
-               return PT_PAGE_TABLE_LEVEL;
+               return PG_LEVEL_4K;
 
        slot = gfn_to_memslot_dirty_bitmap(vcpu, gfn, true);
        if (!slot)
-               return PT_PAGE_TABLE_LEVEL;
+               return PG_LEVEL_4K;
 
        max_level = min(max_level, max_page_level);
-       for ( ; max_level > PT_PAGE_TABLE_LEVEL; max_level--) {
+       for ( ; max_level > PG_LEVEL_4K; max_level--) {
                linfo = lpage_info_slot(gfn, slot, max_level);
                if (!linfo->disallow_lpage)
                        break;
        }
 
-       if (max_level == PT_PAGE_TABLE_LEVEL)
-               return PT_PAGE_TABLE_LEVEL;
+       if (max_level == PG_LEVEL_4K)
+               return PG_LEVEL_4K;
 
        level = host_pfn_mapping_level(vcpu, gfn, pfn, slot);
-       if (level == PT_PAGE_TABLE_LEVEL)
+       if (level == PG_LEVEL_4K)
                return level;
 
        level = min(level, max_level);
@@ -3317,7 +3324,7 @@ static void disallowed_hugepage_adjust(struct kvm_shadow_walk_iterator it,
        int level = *levelp;
        u64 spte = *it.sptep;
 
-       if (it.level == level && level > PT_PAGE_TABLE_LEVEL &&
+       if (it.level == level && level > PG_LEVEL_4K &&
            is_nx_huge_page_enabled() &&
            is_shadow_present_pte(spte) &&
            !is_large_pte(spte)) {
@@ -3574,7 +3581,7 @@ static bool fast_page_fault(struct kvm_vcpu *vcpu, gpa_t cr2_or_gpa,
                         *
                         * See the comments in kvm_arch_commit_memory_region().
                         */
-                       if (sp->role.level > PT_PAGE_TABLE_LEVEL)
+                       if (sp->role.level > PG_LEVEL_4K)
                                break;
                }
 
@@ -3666,7 +3673,7 @@ void kvm_mmu_free_roots(struct kvm_vcpu *vcpu, struct kvm_mmu *mmu,
                                                           &invalid_list);
                        mmu->root_hpa = INVALID_PAGE;
                }
-               mmu->root_cr3 = 0;
+               mmu->root_pgd = 0;
        }
 
        kvm_mmu_commit_zap_page(vcpu->kvm, &invalid_list);
@@ -3686,58 +3693,64 @@ static int mmu_check_root(struct kvm_vcpu *vcpu, gfn_t root_gfn)
        return ret;
 }
 
-static int mmu_alloc_direct_roots(struct kvm_vcpu *vcpu)
+static hpa_t mmu_alloc_root(struct kvm_vcpu *vcpu, gfn_t gfn, gva_t gva,
+                           u8 level, bool direct)
 {
        struct kvm_mmu_page *sp;
+
+       spin_lock(&vcpu->kvm->mmu_lock);
+
+       if (make_mmu_pages_available(vcpu)) {
+               spin_unlock(&vcpu->kvm->mmu_lock);
+               return INVALID_PAGE;
+       }
+       sp = kvm_mmu_get_page(vcpu, gfn, gva, level, direct, ACC_ALL);
+       ++sp->root_count;
+
+       spin_unlock(&vcpu->kvm->mmu_lock);
+       return __pa(sp->spt);
+}
+
+static int mmu_alloc_direct_roots(struct kvm_vcpu *vcpu)
+{
+       u8 shadow_root_level = vcpu->arch.mmu->shadow_root_level;
+       hpa_t root;
        unsigned i;
 
-       if (vcpu->arch.mmu->shadow_root_level >= PT64_ROOT_4LEVEL) {
-               spin_lock(&vcpu->kvm->mmu_lock);
-               if(make_mmu_pages_available(vcpu) < 0) {
-                       spin_unlock(&vcpu->kvm->mmu_lock);
+       if (shadow_root_level >= PT64_ROOT_4LEVEL) {
+               root = mmu_alloc_root(vcpu, 0, 0, shadow_root_level, true);
+               if (!VALID_PAGE(root))
                        return -ENOSPC;
-               }
-               sp = kvm_mmu_get_page(vcpu, 0, 0,
-                               vcpu->arch.mmu->shadow_root_level, 1, ACC_ALL);
-               ++sp->root_count;
-               spin_unlock(&vcpu->kvm->mmu_lock);
-               vcpu->arch.mmu->root_hpa = __pa(sp->spt);
-       } else if (vcpu->arch.mmu->shadow_root_level == PT32E_ROOT_LEVEL) {
+               vcpu->arch.mmu->root_hpa = root;
+       } else if (shadow_root_level == PT32E_ROOT_LEVEL) {
                for (i = 0; i < 4; ++i) {
-                       hpa_t root = vcpu->arch.mmu->pae_root[i];
+                       MMU_WARN_ON(VALID_PAGE(vcpu->arch.mmu->pae_root[i]));
 
-                       MMU_WARN_ON(VALID_PAGE(root));
-                       spin_lock(&vcpu->kvm->mmu_lock);
-                       if (make_mmu_pages_available(vcpu) < 0) {
-                               spin_unlock(&vcpu->kvm->mmu_lock);
+                       root = mmu_alloc_root(vcpu, i << (30 - PAGE_SHIFT),
+                                             i << 30, PT32_ROOT_LEVEL, true);
+                       if (!VALID_PAGE(root))
                                return -ENOSPC;
-                       }
-                       sp = kvm_mmu_get_page(vcpu, i << (30 - PAGE_SHIFT),
-                                       i << 30, PT32_ROOT_LEVEL, 1, ACC_ALL);
-                       root = __pa(sp->spt);
-                       ++sp->root_count;
-                       spin_unlock(&vcpu->kvm->mmu_lock);
                        vcpu->arch.mmu->pae_root[i] = root | PT_PRESENT_MASK;
                }
                vcpu->arch.mmu->root_hpa = __pa(vcpu->arch.mmu->pae_root);
        } else
                BUG();
 
-       /* root_cr3 is ignored for direct MMUs. */
-       vcpu->arch.mmu->root_cr3 = 0;
+       /* root_pgd is ignored for direct MMUs. */
+       vcpu->arch.mmu->root_pgd = 0;
 
        return 0;
 }
 
 static int mmu_alloc_shadow_roots(struct kvm_vcpu *vcpu)
 {
-       struct kvm_mmu_page *sp;
        u64 pdptr, pm_mask;
-       gfn_t root_gfn, root_cr3;
+       gfn_t root_gfn, root_pgd;
+       hpa_t root;
        int i;
 
-       root_cr3 = vcpu->arch.mmu->get_guest_pgd(vcpu);
-       root_gfn = root_cr3 >> PAGE_SHIFT;
+       root_pgd = vcpu->arch.mmu->get_guest_pgd(vcpu);
+       root_gfn = root_pgd >> PAGE_SHIFT;
 
        if (mmu_check_root(vcpu, root_gfn))
                return 1;
@@ -3747,22 +3760,14 @@ static int mmu_alloc_shadow_roots(struct kvm_vcpu *vcpu)
         * write-protect the guests page table root.
         */
        if (vcpu->arch.mmu->root_level >= PT64_ROOT_4LEVEL) {
-               hpa_t root = vcpu->arch.mmu->root_hpa;
+               MMU_WARN_ON(VALID_PAGE(vcpu->arch.mmu->root_hpa));
 
-               MMU_WARN_ON(VALID_PAGE(root));
-
-               spin_lock(&vcpu->kvm->mmu_lock);
-               if (make_mmu_pages_available(vcpu) < 0) {
-                       spin_unlock(&vcpu->kvm->mmu_lock);
+               root = mmu_alloc_root(vcpu, root_gfn, 0,
+                                     vcpu->arch.mmu->shadow_root_level, false);
+               if (!VALID_PAGE(root))
                        return -ENOSPC;
-               }
-               sp = kvm_mmu_get_page(vcpu, root_gfn, 0,
-                               vcpu->arch.mmu->shadow_root_level, 0, ACC_ALL);
-               root = __pa(sp->spt);
-               ++sp->root_count;
-               spin_unlock(&vcpu->kvm->mmu_lock);
                vcpu->arch.mmu->root_hpa = root;
-               goto set_root_cr3;
+               goto set_root_pgd;
        }
 
        /*
@@ -3775,9 +3780,7 @@ static int mmu_alloc_shadow_roots(struct kvm_vcpu *vcpu)
                pm_mask |= PT_ACCESSED_MASK | PT_WRITABLE_MASK | PT_USER_MASK;
 
        for (i = 0; i < 4; ++i) {
-               hpa_t root = vcpu->arch.mmu->pae_root[i];
-
-               MMU_WARN_ON(VALID_PAGE(root));
+               MMU_WARN_ON(VALID_PAGE(vcpu->arch.mmu->pae_root[i]));
                if (vcpu->arch.mmu->root_level == PT32E_ROOT_LEVEL) {
                        pdptr = vcpu->arch.mmu->get_pdptr(vcpu, i);
                        if (!(pdptr & PT_PRESENT_MASK)) {
@@ -3788,17 +3791,11 @@ static int mmu_alloc_shadow_roots(struct kvm_vcpu *vcpu)
                        if (mmu_check_root(vcpu, root_gfn))
                                return 1;
                }
-               spin_lock(&vcpu->kvm->mmu_lock);
-               if (make_mmu_pages_available(vcpu) < 0) {
-                       spin_unlock(&vcpu->kvm->mmu_lock);
-                       return -ENOSPC;
-               }
-               sp = kvm_mmu_get_page(vcpu, root_gfn, i << 30, PT32_ROOT_LEVEL,
-                                     0, ACC_ALL);
-               root = __pa(sp->spt);
-               ++sp->root_count;
-               spin_unlock(&vcpu->kvm->mmu_lock);
 
+               root = mmu_alloc_root(vcpu, root_gfn, i << 30,
+                                     PT32_ROOT_LEVEL, false);
+               if (!VALID_PAGE(root))
+                       return -ENOSPC;
                vcpu->arch.mmu->pae_root[i] = root | pm_mask;
        }
        vcpu->arch.mmu->root_hpa = __pa(vcpu->arch.mmu->pae_root);
@@ -3828,8 +3825,8 @@ static int mmu_alloc_shadow_roots(struct kvm_vcpu *vcpu)
                vcpu->arch.mmu->root_hpa = __pa(vcpu->arch.mmu->lm_root);
        }
 
-set_root_cr3:
-       vcpu->arch.mmu->root_cr3 = root_cr3;
+set_root_pgd:
+       vcpu->arch.mmu->root_pgd = root_pgd;
 
        return 0;
 }
@@ -4065,8 +4062,8 @@ static void shadow_page_table_clear_flood(struct kvm_vcpu *vcpu, gva_t addr)
        walk_shadow_page_lockless_end(vcpu);
 }
 
-static int kvm_arch_setup_async_pf(struct kvm_vcpu *vcpu, gpa_t cr2_or_gpa,
-                                  gfn_t gfn)
+static bool kvm_arch_setup_async_pf(struct kvm_vcpu *vcpu, gpa_t cr2_or_gpa,
+                                   gfn_t gfn)
 {
        struct kvm_arch_async_pf arch;
 
@@ -4083,18 +4080,16 @@ static bool try_async_pf(struct kvm_vcpu *vcpu, bool prefault, gfn_t gfn,
                         gpa_t cr2_or_gpa, kvm_pfn_t *pfn, bool write,
                         bool *writable)
 {
-       struct kvm_memory_slot *slot;
+       struct kvm_memory_slot *slot = kvm_vcpu_gfn_to_memslot(vcpu, gfn);
        bool async;
 
-       /*
-        * Don't expose private memslots to L2.
-        */
-       if (is_guest_mode(vcpu) && !kvm_is_visible_gfn(vcpu->kvm, gfn)) {
+       /* Don't expose private memslots to L2. */
+       if (is_guest_mode(vcpu) && !kvm_is_visible_memslot(slot)) {
                *pfn = KVM_PFN_NOSLOT;
+               *writable = false;
                return false;
        }
 
-       slot = kvm_vcpu_gfn_to_memslot(vcpu, gfn);
        async = false;
        *pfn = __gfn_to_pfn_memslot(slot, gfn, false, &async, write, writable);
        if (!async)
@@ -4135,7 +4130,7 @@ static int direct_page_fault(struct kvm_vcpu *vcpu, gpa_t gpa, u32 error_code,
                return r;
 
        if (lpage_disallowed)
-               max_level = PT_PAGE_TABLE_LEVEL;
+               max_level = PG_LEVEL_4K;
 
        if (fast_page_fault(vcpu, gpa, error_code))
                return RET_PF_RETRY;
@@ -4171,13 +4166,14 @@ static int nonpaging_page_fault(struct kvm_vcpu *vcpu, gpa_t gpa,
 
        /* This path builds a PAE pagetable, we can map 2mb pages at maximum. */
        return direct_page_fault(vcpu, gpa & PAGE_MASK, error_code, prefault,
-                                PT_DIRECTORY_LEVEL, false);
+                                PG_LEVEL_2M, false);
 }
 
 int kvm_handle_page_fault(struct kvm_vcpu *vcpu, u64 error_code,
                                u64 fault_address, char *insn, int insn_len)
 {
        int r = 1;
+       u32 flags = vcpu->arch.apf.host_apf_flags;
 
 #ifndef CONFIG_X86_64
        /* A 64-bit CR2 should be impossible on 32-bit KVM. */
@@ -4186,28 +4182,22 @@ int kvm_handle_page_fault(struct kvm_vcpu *vcpu, u64 error_code,
 #endif
 
        vcpu->arch.l1tf_flush_l1d = true;
-       switch (vcpu->arch.apf.host_apf_reason) {
-       default:
+       if (!flags) {
                trace_kvm_page_fault(fault_address, error_code);
 
                if (kvm_event_needs_reinjection(vcpu))
                        kvm_mmu_unprotect_page_virt(vcpu, fault_address);
                r = kvm_mmu_page_fault(vcpu, fault_address, error_code, insn,
                                insn_len);
-               break;
-       case KVM_PV_REASON_PAGE_NOT_PRESENT:
-               vcpu->arch.apf.host_apf_reason = 0;
-               local_irq_disable();
-               kvm_async_pf_task_wait(fault_address, 0);
-               local_irq_enable();
-               break;
-       case KVM_PV_REASON_PAGE_READY:
-               vcpu->arch.apf.host_apf_reason = 0;
+       } else if (flags & KVM_PV_REASON_PAGE_NOT_PRESENT) {
+               vcpu->arch.apf.host_apf_flags = 0;
                local_irq_disable();
-               kvm_async_pf_task_wake(fault_address);
+               kvm_async_pf_task_wait_schedule(fault_address);
                local_irq_enable();
-               break;
+       } else {
+               WARN_ONCE(1, "Unexpected host async PF flags: %x\n", flags);
        }
+
        return r;
 }
 EXPORT_SYMBOL_GPL(kvm_handle_page_fault);
@@ -4217,8 +4207,8 @@ int kvm_tdp_page_fault(struct kvm_vcpu *vcpu, gpa_t gpa, u32 error_code,
 {
        int max_level;
 
-       for (max_level = PT_MAX_HUGEPAGE_LEVEL;
-            max_level > PT_PAGE_TABLE_LEVEL;
+       for (max_level = KVM_MAX_HUGEPAGE_LEVEL;
+            max_level > PG_LEVEL_4K;
             max_level--) {
                int page_num = KVM_PAGES_PER_HPAGE(max_level);
                gfn_t base = (gpa >> PAGE_SHIFT) & ~(page_num - 1);
@@ -4237,7 +4227,7 @@ static void nonpaging_init_context(struct kvm_vcpu *vcpu,
        context->page_fault = nonpaging_page_fault;
        context->gva_to_gpa = nonpaging_gva_to_gpa;
        context->sync_page = nonpaging_sync_page;
-       context->invlpg = nonpaging_invlpg;
+       context->invlpg = NULL;
        context->update_pte = nonpaging_update_pte;
        context->root_level = 0;
        context->shadow_root_level = PT32E_ROOT_LEVEL;
@@ -4245,51 +4235,50 @@ static void nonpaging_init_context(struct kvm_vcpu *vcpu,
        context->nx = false;
 }
 
-static inline bool is_root_usable(struct kvm_mmu_root_info *root, gpa_t cr3,
+static inline bool is_root_usable(struct kvm_mmu_root_info *root, gpa_t pgd,
                                  union kvm_mmu_page_role role)
 {
-       return (role.direct || cr3 == root->cr3) &&
+       return (role.direct || pgd == root->pgd) &&
               VALID_PAGE(root->hpa) && page_header(root->hpa) &&
               role.word == page_header(root->hpa)->role.word;
 }
 
 /*
- * Find out if a previously cached root matching the new CR3/role is available.
+ * Find out if a previously cached root matching the new pgd/role is available.
  * The current root is also inserted into the cache.
  * If a matching root was found, it is assigned to kvm_mmu->root_hpa and true is
  * returned.
  * Otherwise, the LRU root from the cache is assigned to kvm_mmu->root_hpa and
  * false is returned. This root should now be freed by the caller.
  */
-static bool cached_root_available(struct kvm_vcpu *vcpu, gpa_t new_cr3,
+static bool cached_root_available(struct kvm_vcpu *vcpu, gpa_t new_pgd,
                                  union kvm_mmu_page_role new_role)
 {
        uint i;
        struct kvm_mmu_root_info root;
        struct kvm_mmu *mmu = vcpu->arch.mmu;
 
-       root.cr3 = mmu->root_cr3;
+       root.pgd = mmu->root_pgd;
        root.hpa = mmu->root_hpa;
 
-       if (is_root_usable(&root, new_cr3, new_role))
+       if (is_root_usable(&root, new_pgd, new_role))
                return true;
 
        for (i = 0; i < KVM_MMU_NUM_PREV_ROOTS; i++) {
                swap(root, mmu->prev_roots[i]);
 
-               if (is_root_usable(&root, new_cr3, new_role))
+               if (is_root_usable(&root, new_pgd, new_role))
                        break;
        }
 
        mmu->root_hpa = root.hpa;
-       mmu->root_cr3 = root.cr3;
+       mmu->root_pgd = root.pgd;
 
        return i < KVM_MMU_NUM_PREV_ROOTS;
 }
 
-static bool fast_cr3_switch(struct kvm_vcpu *vcpu, gpa_t new_cr3,
-                           union kvm_mmu_page_role new_role,
-                           bool skip_tlb_flush)
+static bool fast_pgd_switch(struct kvm_vcpu *vcpu, gpa_t new_pgd,
+                           union kvm_mmu_page_role new_role)
 {
        struct kvm_mmu *mmu = vcpu->arch.mmu;
 
@@ -4299,70 +4288,59 @@ static bool fast_cr3_switch(struct kvm_vcpu *vcpu, gpa_t new_cr3,
         * later if necessary.
         */
        if (mmu->shadow_root_level >= PT64_ROOT_4LEVEL &&
-           mmu->root_level >= PT64_ROOT_4LEVEL) {
-               if (mmu_check_root(vcpu, new_cr3 >> PAGE_SHIFT))
-                       return false;
-
-               if (cached_root_available(vcpu, new_cr3, new_role)) {
-                       /*
-                        * It is possible that the cached previous root page is
-                        * obsolete because of a change in the MMU generation
-                        * number. However, changing the generation number is
-                        * accompanied by KVM_REQ_MMU_RELOAD, which will free
-                        * the root set here and allocate a new one.
-                        */
-                       kvm_make_request(KVM_REQ_LOAD_MMU_PGD, vcpu);
-                       if (!skip_tlb_flush) {
-                               kvm_make_request(KVM_REQ_MMU_SYNC, vcpu);
-                               kvm_make_request(KVM_REQ_TLB_FLUSH, vcpu);
-                       }
-
-                       /*
-                        * The last MMIO access's GVA and GPA are cached in the
-                        * VCPU. When switching to a new CR3, that GVA->GPA
-                        * mapping may no longer be valid. So clear any cached
-                        * MMIO info even when we don't need to sync the shadow
-                        * page tables.
-                        */
-                       vcpu_clear_mmio_info(vcpu, MMIO_GVA_ANY);
-
-                       __clear_sp_write_flooding_count(
-                               page_header(mmu->root_hpa));
-
-                       return true;
-               }
-       }
+           mmu->root_level >= PT64_ROOT_4LEVEL)
+               return !mmu_check_root(vcpu, new_pgd >> PAGE_SHIFT) &&
+                      cached_root_available(vcpu, new_pgd, new_role);
 
        return false;
 }
 
-static void __kvm_mmu_new_cr3(struct kvm_vcpu *vcpu, gpa_t new_cr3,
+static void __kvm_mmu_new_pgd(struct kvm_vcpu *vcpu, gpa_t new_pgd,
                              union kvm_mmu_page_role new_role,
-                             bool skip_tlb_flush)
+                             bool skip_tlb_flush, bool skip_mmu_sync)
 {
-       if (!fast_cr3_switch(vcpu, new_cr3, new_role, skip_tlb_flush))
-               kvm_mmu_free_roots(vcpu, vcpu->arch.mmu,
-                                  KVM_MMU_ROOT_CURRENT);
+       if (!fast_pgd_switch(vcpu, new_pgd, new_role)) {
+               kvm_mmu_free_roots(vcpu, vcpu->arch.mmu, KVM_MMU_ROOT_CURRENT);
+               return;
+       }
+
+       /*
+        * It's possible that the cached previous root page is obsolete because
+        * of a change in the MMU generation number. However, changing the
+        * generation number is accompanied by KVM_REQ_MMU_RELOAD, which will
+        * free the root set here and allocate a new one.
+        */
+       kvm_make_request(KVM_REQ_LOAD_MMU_PGD, vcpu);
+
+       if (!skip_mmu_sync || force_flush_and_sync_on_reuse)
+               kvm_make_request(KVM_REQ_MMU_SYNC, vcpu);
+       if (!skip_tlb_flush || force_flush_and_sync_on_reuse)
+               kvm_make_request(KVM_REQ_TLB_FLUSH_CURRENT, vcpu);
+
+       /*
+        * The last MMIO access's GVA and GPA are cached in the VCPU. When
+        * switching to a new CR3, that GVA->GPA mapping may no longer be
+        * valid. So clear any cached MMIO info even when we don't need to sync
+        * the shadow page tables.
+        */
+       vcpu_clear_mmio_info(vcpu, MMIO_GVA_ANY);
+
+       __clear_sp_write_flooding_count(page_header(vcpu->arch.mmu->root_hpa));
 }
 
-void kvm_mmu_new_cr3(struct kvm_vcpu *vcpu, gpa_t new_cr3, bool skip_tlb_flush)
+void kvm_mmu_new_pgd(struct kvm_vcpu *vcpu, gpa_t new_pgd, bool skip_tlb_flush,
+                    bool skip_mmu_sync)
 {
-       __kvm_mmu_new_cr3(vcpu, new_cr3, kvm_mmu_calc_root_page_role(vcpu),
-                         skip_tlb_flush);
+       __kvm_mmu_new_pgd(vcpu, new_pgd, kvm_mmu_calc_root_page_role(vcpu),
+                         skip_tlb_flush, skip_mmu_sync);
 }
-EXPORT_SYMBOL_GPL(kvm_mmu_new_cr3);
+EXPORT_SYMBOL_GPL(kvm_mmu_new_pgd);
 
 static unsigned long get_cr3(struct kvm_vcpu *vcpu)
 {
        return kvm_read_cr3(vcpu);
 }
 
-static void inject_page_fault(struct kvm_vcpu *vcpu,
-                             struct x86_exception *fault)
-{
-       vcpu->arch.mmu->inject_page_fault(vcpu, fault);
-}
-
 static bool sync_mmio_spte(struct kvm_vcpu *vcpu, u64 *sptep, gfn_t gfn,
                           unsigned int access, int *nr_present)
 {
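[For readers tracing the root-reuse path above: cached_root_available() rotates the current root through the prev_roots array while it searches, which keeps the cache ordered most-recently-used first. A simplified standalone sketch of that swap chain (the types and the usable-root test are stand-ins, not KVM's):]

#include <stdbool.h>
#include <stdint.h>

#define NUM_PREV_ROOTS 3

struct root_info { uint64_t pgd; uint64_t hpa; };

static struct root_info current_root;
static struct root_info prev_roots[NUM_PREV_ROOTS];

static bool is_root_usable(const struct root_info *root, uint64_t new_pgd)
{
        return root->hpa != 0 && root->pgd == new_pgd;   /* simplified check */
}

static bool cached_root_available(uint64_t new_pgd)
{
        struct root_info root = current_root;
        unsigned int i;

        if (is_root_usable(&root, new_pgd))
                return true;

        for (i = 0; i < NUM_PREV_ROOTS; i++) {
                struct root_info tmp = prev_roots[i];

                prev_roots[i] = root;        /* demote older entries (swap) */
                root = tmp;
                if (is_root_usable(&root, new_pgd))
                        break;
        }

        current_root = root;                 /* either the hit or the LRU entry */
        return i < NUM_PREV_ROOTS;
}

[If the loop finds no match, the caller is left holding the least-recently-used root, which it can then free and replace, as fast_pgd_switch() does above.]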
@@ -4391,11 +4369,11 @@ static inline bool is_last_gpte(struct kvm_mmu *mmu,
        gpte &= level - mmu->last_nonleaf_level;
 
        /*
-        * PT_PAGE_TABLE_LEVEL always terminates.  The RHS has bit 7 set
-        * iff level <= PT_PAGE_TABLE_LEVEL, which for our purpose means
-        * level == PT_PAGE_TABLE_LEVEL; set PT_PAGE_SIZE_MASK in gpte then.
+        * PG_LEVEL_4K always terminates.  The RHS has bit 7 set
+        * iff level <= PG_LEVEL_4K, which for our purpose means
+        * level == PG_LEVEL_4K; set PT_PAGE_SIZE_MASK in gpte then.
         */
-       gpte |= level - PT_PAGE_TABLE_LEVEL - 1;
+       gpte |= level - PG_LEVEL_4K - 1;
 
        return gpte & PT_PAGE_SIZE_MASK;
 }
@@ -4483,7 +4461,7 @@ __reset_rsvds_bits_mask(struct kvm_vcpu *vcpu,
                        nonleaf_bit8_rsvd | rsvd_bits(7, 7) |
                        rsvd_bits(maxphyaddr, 51);
                rsvd_check->rsvd_bits_mask[0][2] = exb_bit_rsvd |
-                       nonleaf_bit8_rsvd | gbpages_bit_rsvd |
+                       gbpages_bit_rsvd |
                        rsvd_bits(maxphyaddr, 51);
                rsvd_check->rsvd_bits_mask[0][1] = exb_bit_rsvd |
                        rsvd_bits(maxphyaddr, 51);
@@ -4909,7 +4887,7 @@ kvm_calc_tdp_mmu_root_page_role(struct kvm_vcpu *vcpu, bool base_only)
        union kvm_mmu_role role = kvm_calc_mmu_role_common(vcpu, base_only);
 
        role.base.ad_disabled = (shadow_accessed_mask == 0);
-       role.base.level = kvm_x86_ops.get_tdp_level(vcpu);
+       role.base.level = vcpu->arch.tdp_level;
        role.base.direct = true;
        role.base.gpte_is_8_bytes = true;
 
@@ -4928,9 +4906,9 @@ static void init_kvm_tdp_mmu(struct kvm_vcpu *vcpu)
        context->mmu_role.as_u64 = new_role.as_u64;
        context->page_fault = kvm_tdp_page_fault;
        context->sync_page = nonpaging_sync_page;
-       context->invlpg = nonpaging_invlpg;
+       context->invlpg = NULL;
        context->update_pte = nonpaging_update_pte;
-       context->shadow_root_level = kvm_x86_ops.get_tdp_level(vcpu);
+       context->shadow_root_level = vcpu->arch.tdp_level;
        context->direct_map = true;
        context->get_guest_pgd = get_cr3;
        context->get_pdptr = kvm_pdptr_read;
@@ -4986,7 +4964,7 @@ kvm_calc_shadow_mmu_root_page_role(struct kvm_vcpu *vcpu, bool base_only)
        return role;
 }
 
-void kvm_init_shadow_mmu(struct kvm_vcpu *vcpu)
+void kvm_init_shadow_mmu(struct kvm_vcpu *vcpu, u32 cr0, u32 cr4, u32 efer)
 {
        struct kvm_mmu *context = vcpu->arch.mmu;
        union kvm_mmu_role new_role =
@@ -4995,11 +4973,11 @@ void kvm_init_shadow_mmu(struct kvm_vcpu *vcpu)
        if (new_role.as_u64 == context->mmu_role.as_u64)
                return;
 
-       if (!is_paging(vcpu))
+       if (!(cr0 & X86_CR0_PG))
                nonpaging_init_context(vcpu, context);
-       else if (is_long_mode(vcpu))
+       else if (efer & EFER_LMA)
                paging64_init_context(vcpu, context);
-       else if (is_pae(vcpu))
+       else if (cr4 & X86_CR4_PAE)
                paging32E_init_context(vcpu, context);
        else
                paging32_init_context(vcpu, context);
@@ -5047,7 +5025,7 @@ void kvm_init_shadow_ept_mmu(struct kvm_vcpu *vcpu, bool execonly,
                kvm_calc_shadow_ept_root_page_role(vcpu, accessed_dirty,
                                                   execonly, level);
 
-       __kvm_mmu_new_cr3(vcpu, new_eptp, new_role.base, false);
+       __kvm_mmu_new_pgd(vcpu, new_eptp, new_role.base, true, true);
 
        if (new_role.as_u64 == context->mmu_role.as_u64)
                return;
@@ -5077,7 +5055,11 @@ static void init_kvm_softmmu(struct kvm_vcpu *vcpu)
 {
        struct kvm_mmu *context = vcpu->arch.mmu;
 
-       kvm_init_shadow_mmu(vcpu);
+       kvm_init_shadow_mmu(vcpu,
+                           kvm_read_cr0_bits(vcpu, X86_CR0_PG),
+                           kvm_read_cr4_bits(vcpu, X86_CR4_PAE),
+                           vcpu->arch.efer);
+
        context->get_guest_pgd     = get_cr3;
        context->get_pdptr         = kvm_pdptr_read;
        context->inject_page_fault = kvm_inject_page_fault;
@@ -5096,6 +5078,12 @@ static void init_kvm_nested_mmu(struct kvm_vcpu *vcpu)
        g_context->get_pdptr         = kvm_pdptr_read;
        g_context->inject_page_fault = kvm_inject_page_fault;
 
+       /*
+        * L2 page tables are never shadowed, so there is no need to sync
+        * SPTEs.
+        */
+       g_context->invlpg            = NULL;
+
        /*
         * Note that arch.mmu->gva_to_gpa translates l2_gpa to l1_gpa using
         * L1's nested page tables (e.g. EPT12). The nested translation
@@ -5183,7 +5171,7 @@ int kvm_mmu_load(struct kvm_vcpu *vcpu)
        if (r)
                goto out;
        kvm_mmu_load_pgd(vcpu);
-       kvm_x86_ops.tlb_flush(vcpu, true);
+       kvm_x86_ops.tlb_flush_current(vcpu);
 out:
        return r;
 }
@@ -5202,7 +5190,7 @@ static void mmu_pte_write_new_pte(struct kvm_vcpu *vcpu,
                                  struct kvm_mmu_page *sp, u64 *spte,
                                  const void *new)
 {
-       if (sp->role.level != PT_PAGE_TABLE_LEVEL) {
+       if (sp->role.level != PG_LEVEL_4K) {
                ++vcpu->kvm->stat.mmu_pde_zapped;
                return;
         }
@@ -5260,7 +5248,7 @@ static bool detect_write_flooding(struct kvm_mmu_page *sp)
         * Skip write-flooding detected for the sp whose level is 1, because
         * it can become unsync, then the guest page is not write-protected.
         */
-       if (sp->role.level == PT_PAGE_TABLE_LEVEL)
+       if (sp->role.level == PG_LEVEL_4K)
                return false;
 
        atomic_inc(&sp->write_flooding_count);
@@ -5497,37 +5485,54 @@ emulate:
 }
 EXPORT_SYMBOL_GPL(kvm_mmu_page_fault);
 
-void kvm_mmu_invlpg(struct kvm_vcpu *vcpu, gva_t gva)
+void kvm_mmu_invalidate_gva(struct kvm_vcpu *vcpu, struct kvm_mmu *mmu,
+                           gva_t gva, hpa_t root_hpa)
 {
-       struct kvm_mmu *mmu = vcpu->arch.mmu;
        int i;
 
-       /* INVLPG on a * non-canonical address is a NOP according to the SDM.  */
-       if (is_noncanonical_address(gva, vcpu))
+       /* It's actually a GPA for vcpu->arch.guest_mmu.  */
+       if (mmu != &vcpu->arch.guest_mmu) {
+               /* INVLPG on a non-canonical address is a NOP according to the SDM.  */
+               if (is_noncanonical_address(gva, vcpu))
+                       return;
+
+               kvm_x86_ops.tlb_flush_gva(vcpu, gva);
+       }
+
+       if (!mmu->invlpg)
                return;
 
-       mmu->invlpg(vcpu, gva, mmu->root_hpa);
+       if (root_hpa == INVALID_PAGE) {
+               mmu->invlpg(vcpu, gva, mmu->root_hpa);
 
-       /*
-        * INVLPG is required to invalidate any global mappings for the VA,
-        * irrespective of PCID. Since it would take us roughly similar amount
-        * of work to determine whether any of the prev_root mappings of the VA
-        * is marked global, or to just sync it blindly, so we might as well
-        * just always sync it.
-        *
-        * Mappings not reachable via the current cr3 or the prev_roots will be
-        * synced when switching to that cr3, so nothing needs to be done here
-        * for them.
-        */
-       for (i = 0; i < KVM_MMU_NUM_PREV_ROOTS; i++)
-               if (VALID_PAGE(mmu->prev_roots[i].hpa))
-                       mmu->invlpg(vcpu, gva, mmu->prev_roots[i].hpa);
+               /*
+                * INVLPG is required to invalidate any global mappings for the VA,
+                * irrespective of PCID. Since it would take us roughly similar amount
+                * of work to determine whether any of the prev_root mappings of the VA
+                * is marked global, or to just sync it blindly, so we might as well
+                * just always sync it.
+                *
+                * Mappings not reachable via the current cr3 or the prev_roots will be
+                * synced when switching to that cr3, so nothing needs to be done here
+                * for them.
+                */
+               for (i = 0; i < KVM_MMU_NUM_PREV_ROOTS; i++)
+                       if (VALID_PAGE(mmu->prev_roots[i].hpa))
+                               mmu->invlpg(vcpu, gva, mmu->prev_roots[i].hpa);
+       } else {
+               mmu->invlpg(vcpu, gva, root_hpa);
+       }
+}
+EXPORT_SYMBOL_GPL(kvm_mmu_invalidate_gva);
 
-       kvm_x86_ops.tlb_flush_gva(vcpu, gva);
+void kvm_mmu_invlpg(struct kvm_vcpu *vcpu, gva_t gva)
+{
+       kvm_mmu_invalidate_gva(vcpu, vcpu->arch.mmu, gva, INVALID_PAGE);
        ++vcpu->stat.invlpg;
 }
 EXPORT_SYMBOL_GPL(kvm_mmu_invlpg);
 
+
 void kvm_mmu_invpcid_gva(struct kvm_vcpu *vcpu, gva_t gva, unsigned long pcid)
 {
        struct kvm_mmu *mmu = vcpu->arch.mmu;
@@ -5541,7 +5546,7 @@ void kvm_mmu_invpcid_gva(struct kvm_vcpu *vcpu, gva_t gva, unsigned long pcid)
 
        for (i = 0; i < KVM_MMU_NUM_PREV_ROOTS; i++) {
                if (VALID_PAGE(mmu->prev_roots[i].hpa) &&
-                   pcid == kvm_get_pcid(vcpu, mmu->prev_roots[i].cr3)) {
+                   pcid == kvm_get_pcid(vcpu, mmu->prev_roots[i].pgd)) {
                        mmu->invlpg(vcpu, gva, mmu->prev_roots[i].hpa);
                        tlb_flush = true;
                }
@@ -5574,9 +5579,9 @@ void kvm_configure_mmu(bool enable_tdp, int tdp_page_level)
        if (tdp_enabled)
                max_page_level = tdp_page_level;
        else if (boot_cpu_has(X86_FEATURE_GBPAGES))
-               max_page_level = PT_PDPE_LEVEL;
+               max_page_level = PG_LEVEL_1G;
        else
-               max_page_level = PT_DIRECTORY_LEVEL;
+               max_page_level = PG_LEVEL_2M;
 }
 EXPORT_SYMBOL_GPL(kvm_configure_mmu);
 
@@ -5632,24 +5637,24 @@ static __always_inline bool
 slot_handle_all_level(struct kvm *kvm, struct kvm_memory_slot *memslot,
                      slot_level_handler fn, bool lock_flush_tlb)
 {
-       return slot_handle_level(kvm, memslot, fn, PT_PAGE_TABLE_LEVEL,
-                                PT_MAX_HUGEPAGE_LEVEL, lock_flush_tlb);
+       return slot_handle_level(kvm, memslot, fn, PG_LEVEL_4K,
+                                KVM_MAX_HUGEPAGE_LEVEL, lock_flush_tlb);
 }
 
 static __always_inline bool
 slot_handle_large_level(struct kvm *kvm, struct kvm_memory_slot *memslot,
                        slot_level_handler fn, bool lock_flush_tlb)
 {
-       return slot_handle_level(kvm, memslot, fn, PT_PAGE_TABLE_LEVEL + 1,
-                                PT_MAX_HUGEPAGE_LEVEL, lock_flush_tlb);
+       return slot_handle_level(kvm, memslot, fn, PG_LEVEL_4K + 1,
+                                KVM_MAX_HUGEPAGE_LEVEL, lock_flush_tlb);
 }
 
 static __always_inline bool
 slot_handle_leaf(struct kvm *kvm, struct kvm_memory_slot *memslot,
                 slot_level_handler fn, bool lock_flush_tlb)
 {
-       return slot_handle_level(kvm, memslot, fn, PT_PAGE_TABLE_LEVEL,
-                                PT_PAGE_TABLE_LEVEL, lock_flush_tlb);
+       return slot_handle_level(kvm, memslot, fn, PG_LEVEL_4K,
+                                PG_LEVEL_4K, lock_flush_tlb);
 }
 
 static void free_mmu_pages(struct kvm_mmu *mmu)
@@ -5672,7 +5677,7 @@ static int alloc_mmu_pages(struct kvm_vcpu *vcpu, struct kvm_mmu *mmu)
         * SVM's 32-bit NPT support, TDP paging doesn't use PAE paging and can
         * skip allocating the PDP table.
         */
-       if (tdp_enabled && kvm_x86_ops.get_tdp_level(vcpu) > PT32E_ROOT_LEVEL)
+       if (tdp_enabled && vcpu->arch.tdp_level > PT32E_ROOT_LEVEL)
                return 0;
 
        page = alloc_page(GFP_KERNEL_ACCOUNT | __GFP_DMA32);
@@ -5695,13 +5700,13 @@ int kvm_mmu_create(struct kvm_vcpu *vcpu)
        vcpu->arch.walk_mmu = &vcpu->arch.root_mmu;
 
        vcpu->arch.root_mmu.root_hpa = INVALID_PAGE;
-       vcpu->arch.root_mmu.root_cr3 = 0;
+       vcpu->arch.root_mmu.root_pgd = 0;
        vcpu->arch.root_mmu.translate_gpa = translate_gpa;
        for (i = 0; i < KVM_MMU_NUM_PREV_ROOTS; i++)
                vcpu->arch.root_mmu.prev_roots[i] = KVM_MMU_ROOT_INFO_INVALID;
 
        vcpu->arch.guest_mmu.root_hpa = INVALID_PAGE;
-       vcpu->arch.guest_mmu.root_cr3 = 0;
+       vcpu->arch.guest_mmu.root_pgd = 0;
        vcpu->arch.guest_mmu.translate_gpa = translate_gpa;
        for (i = 0; i < KVM_MMU_NUM_PREV_ROOTS; i++)
                vcpu->arch.guest_mmu.prev_roots[i] = KVM_MMU_ROOT_INFO_INVALID;
@@ -5739,12 +5744,11 @@ restart:
                        break;
 
                /*
-                * Skip invalid pages with a non-zero root count, zapping pages
-                * with a non-zero root count will never succeed, i.e. the page
-                * will get thrown back on active_mmu_pages and we'll get stuck
-                * in an infinite loop.
+                * Invalid pages should never land back on the list of active
+                * pages.  Skip the bogus page, otherwise we'll get stuck in an
+                * infinite loop if the page gets put back on the list (again).
                 */
-               if (sp->role.invalid && sp->root_count)
+               if (WARN_ON(sp->role.invalid))
                        continue;
 
                /*
@@ -5859,7 +5863,8 @@ void kvm_zap_gfn_range(struct kvm *kvm, gfn_t gfn_start, gfn_t gfn_end)
                                continue;
 
                        slot_handle_level_range(kvm, memslot, kvm_zap_rmapp,
-                                               PT_PAGE_TABLE_LEVEL, PT_MAX_HUGEPAGE_LEVEL,
+                                               PG_LEVEL_4K,
+                                               KVM_MAX_HUGEPAGE_LEVEL,
                                                start, end - 1, true);
                }
        }
@@ -5881,7 +5886,7 @@ void kvm_mmu_slot_remove_write_access(struct kvm *kvm,
 
        spin_lock(&kvm->mmu_lock);
        flush = slot_handle_level(kvm, memslot, slot_rmap_write_protect,
-                               start_level, PT_MAX_HUGEPAGE_LEVEL, false);
+                               start_level, KVM_MAX_HUGEPAGE_LEVEL, false);
        spin_unlock(&kvm->mmu_lock);
 
        /*
@@ -6021,7 +6026,7 @@ void kvm_mmu_zap_all(struct kvm *kvm)
        spin_lock(&kvm->mmu_lock);
 restart:
        list_for_each_entry_safe(sp, node, &kvm->arch.active_mmu_pages, link) {
-               if (sp->role.invalid && sp->root_count)
+               if (WARN_ON(sp->role.invalid))
                        continue;
                if (__kvm_mmu_prepare_zap_page(kvm, sp, &invalid_list, &ign))
                        goto restart;
@@ -6098,9 +6103,7 @@ mmu_shrink_scan(struct shrinker *shrink, struct shrink_control *sc)
                        goto unlock;
                }
 
-               if (prepare_zap_oldest_mmu_page(kvm, &invalid_list))
-                       freed++;
-               kvm_mmu_commit_zap_page(kvm, &invalid_list);
+               freed = kvm_mmu_zap_oldest_mmu_pages(kvm, sc->nr_to_scan);
 
 unlock:
                spin_unlock(&kvm->mmu_lock);
@@ -6142,27 +6145,18 @@ static void kvm_set_mmio_spte_mask(void)
        u64 mask;
 
        /*
-        * Set the reserved bits and the present bit of an paging-structure
-        * entry to generate page fault with PFER.RSV = 1.
+        * Set a reserved PA bit in MMIO SPTEs to generate page faults with
+        * PFEC.RSVD=1 on MMIO accesses.  64-bit PTEs (PAE, x86-64, and EPT
+        * paging) support a maximum of 52 bits of PA, i.e. if the CPU supports
+        * 52-bit physical addresses then there are no reserved PA bits in the
+        * PTEs and so the reserved PA approach must be disabled.
         */
+       if (shadow_phys_bits < 52)
+               mask = BIT_ULL(51) | PT_PRESENT_MASK;
+       else
+               mask = 0;
 
-       /*
-        * Mask the uppermost physical address bit, which would be reserved as
-        * long as the supported physical address width is less than 52.
-        */
-       mask = 1ull << 51;
-
-       /* Set the present bit. */
-       mask |= 1ull;
-
-       /*
-        * If reserved bit is not supported, clear the present bit to disable
-        * mmio page fault.
-        */
-       if (shadow_phys_bits == 52)
-               mask &= ~1ull;
-
-       kvm_mmu_set_mmio_spte_mask(mask, mask, ACC_WRITE_MASK | ACC_USER_MASK);
+       kvm_mmu_set_mmio_spte_mask(mask, ACC_WRITE_MASK | ACC_USER_MASK);
 }
 
 static bool get_nx_auto_mode(void)
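
[Finally, the rework of kvm_set_mmio_spte_mask() above picks the MMIO mask from shadow_phys_bits. A small worked example of that decision (userspace sketch, values assumed): with a 46-bit physical address width, bit 51 is reserved, so MMIO SPTEs can be marked present+reserved and fault with PFEC.RSVD=1; with a full 52-bit width there is no reserved PA bit to use, so the mask stays 0.]

#include <stdint.h>
#include <stdio.h>

#define BIT_ULL(n)       (1ULL << (n))
#define PT_PRESENT_MASK  BIT_ULL(0)

static uint64_t pick_mmio_mask(unsigned int shadow_phys_bits)
{
        if (shadow_phys_bits < 52)
                return BIT_ULL(51) | PT_PRESENT_MASK;
        return 0;                 /* no reserved PA bits: disable the trick */
}

int main(void)
{
        printf("46-bit PA: mask = %#llx\n",
               (unsigned long long)pick_mmio_mask(46));
        printf("52-bit PA: mask = %#llx\n",
               (unsigned long long)pick_mmio_mask(52));
        return 0;
}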