KVM: x86/mmu: Use an rwlock for the x86 MMU
diff --git a/arch/x86/kvm/mmu/mmu.c b/arch/x86/kvm/mmu/mmu.c
index 4798a4472066d07fc941e1515523306063b8d7bd..329930d57774ca891ac4eff1aea97e4b03502302 100644
--- a/arch/x86/kvm/mmu/mmu.c
+++ b/arch/x86/kvm/mmu/mmu.c
@@ -190,7 +190,7 @@ static void kvm_flush_remote_tlbs_with_range(struct kvm *kvm,
        int ret = -ENOTSUPP;
 
        if (range && kvm_x86_ops.tlb_remote_flush_with_range)
-               ret = kvm_x86_ops.tlb_remote_flush_with_range(kvm, range);
+               ret = static_call(kvm_x86_tlb_remote_flush_with_range)(kvm, range);
 
        if (ret)
                kvm_flush_remote_tlbs(kvm);
@@ -820,7 +820,7 @@ gfn_to_memslot_dirty_bitmap(struct kvm_vcpu *vcpu, gfn_t gfn,
        slot = kvm_vcpu_gfn_to_memslot(vcpu, gfn);
        if (!slot || slot->flags & KVM_MEMSLOT_INVALID)
                return NULL;
-       if (no_dirty_log && slot->dirty_bitmap)
+       if (no_dirty_log && kvm_slot_dirty_track_enabled(slot))
                return NULL;
 
        return slot;
@@ -844,17 +844,17 @@ static int pte_list_add(struct kvm_vcpu *vcpu, u64 *spte,
        int i, count = 0;
 
        if (!rmap_head->val) {
-               rmap_printk("pte_list_add: %p %llx 0->1\n", spte, *spte);
+               rmap_printk("%p %llx 0->1\n", spte, *spte);
                rmap_head->val = (unsigned long)spte;
        } else if (!(rmap_head->val & 1)) {
-               rmap_printk("pte_list_add: %p %llx 1->many\n", spte, *spte);
+               rmap_printk("%p %llx 1->many\n", spte, *spte);
                desc = mmu_alloc_pte_list_desc(vcpu);
                desc->sptes[0] = (u64 *)rmap_head->val;
                desc->sptes[1] = spte;
                rmap_head->val = (unsigned long)desc | 1;
                ++count;
        } else {
-               rmap_printk("pte_list_add: %p %llx many->many\n", spte, *spte);
+               rmap_printk("%p %llx many->many\n", spte, *spte);
                desc = (struct pte_list_desc *)(rmap_head->val & ~1ul);
                while (desc->sptes[PTE_LIST_EXT-1]) {
                        count += PTE_LIST_EXT;
@@ -906,14 +906,14 @@ static void __pte_list_remove(u64 *spte, struct kvm_rmap_head *rmap_head)
                pr_err("%s: %p 0->BUG\n", __func__, spte);
                BUG();
        } else if (!(rmap_head->val & 1)) {
-               rmap_printk("%s:  %p 1->0\n", __func__, spte);
+               rmap_printk("%p 1->0\n", spte);
                if ((u64 *)rmap_head->val != spte) {
                        pr_err("%s:  %p 1->BUG\n", __func__, spte);
                        BUG();
                }
                rmap_head->val = 0;
        } else {
-               rmap_printk("%s:  %p many->many\n", __func__, spte);
+               rmap_printk("%p many->many\n", spte);
                desc = (struct pte_list_desc *)(rmap_head->val & ~1ul);
                prev_desc = NULL;
                while (desc) {
@@ -1115,7 +1115,7 @@ static bool spte_write_protect(u64 *sptep, bool pt_protect)
              !(pt_protect && spte_can_locklessly_be_made_writable(spte)))
                return false;
 
-       rmap_printk("rmap_write_protect: spte %p %llx\n", sptep, *sptep);
+       rmap_printk("spte %p %llx\n", sptep, *sptep);
 
        if (pt_protect)
                spte &= ~SPTE_MMU_WRITEABLE;
@@ -1142,7 +1142,7 @@ static bool spte_clear_dirty(u64 *sptep)
 {
        u64 spte = *sptep;
 
-       rmap_printk("rmap_clear_dirty: spte %p %llx\n", sptep, *sptep);
+       rmap_printk("spte %p %llx\n", sptep, *sptep);
 
        MMU_WARN_ON(!spte_ad_enabled(spte));
        spte &= ~shadow_dirty_mask;
@@ -1184,7 +1184,7 @@ static bool spte_set_dirty(u64 *sptep)
 {
        u64 spte = *sptep;
 
-       rmap_printk("rmap_set_dirty: spte %p %llx\n", sptep, *sptep);
+       rmap_printk("spte %p %llx\n", sptep, *sptep);
 
        /*
         * Similar to the !kvm_x86_ops.slot_disable_log_dirty case,
@@ -1283,12 +1283,21 @@ void kvm_arch_mmu_enable_log_dirty_pt_masked(struct kvm *kvm,
                                gfn_t gfn_offset, unsigned long mask)
 {
        if (kvm_x86_ops.enable_log_dirty_pt_masked)
-               kvm_x86_ops.enable_log_dirty_pt_masked(kvm, slot, gfn_offset,
-                               mask);
+               static_call(kvm_x86_enable_log_dirty_pt_masked)(kvm, slot,
+                                                               gfn_offset,
+                                                               mask);
        else
                kvm_mmu_write_protect_pt_masked(kvm, slot, gfn_offset, mask);
 }
 
+int kvm_cpu_dirty_log_size(void)
+{
+       if (kvm_x86_ops.cpu_dirty_log_size)
+               return static_call(kvm_x86_cpu_dirty_log_size)();
+
+       return 0;
+}
+
 bool kvm_mmu_slot_gfn_write_protect(struct kvm *kvm,
                                    struct kvm_memory_slot *slot, u64 gfn)
 {
@@ -1323,7 +1332,7 @@ static bool kvm_zap_rmapp(struct kvm *kvm, struct kvm_rmap_head *rmap_head)
        bool flush = false;
 
        while ((sptep = rmap_get_first(rmap_head, &iter))) {
-               rmap_printk("%s: spte %p %llx.\n", __func__, sptep, *sptep);
+               rmap_printk("spte %p %llx.\n", sptep, *sptep);
 
                pte_list_remove(rmap_head, sptep);
                flush = true;
@@ -1355,7 +1364,7 @@ static int kvm_set_pte_rmapp(struct kvm *kvm, struct kvm_rmap_head *rmap_head,
 
 restart:
        for_each_rmap_spte(rmap_head, &iter, sptep) {
-               rmap_printk("kvm_set_pte_rmapp: spte %p %llx gfn %llx (%d)\n",
+               rmap_printk("spte %p %llx gfn %llx (%d)\n",
                            sptep, *sptep, gfn, level);
 
                need_flush = 1;
@@ -1715,13 +1724,6 @@ static int nonpaging_sync_page(struct kvm_vcpu *vcpu,
        return 0;
 }
 
-static void nonpaging_update_pte(struct kvm_vcpu *vcpu,
-                                struct kvm_mmu_page *sp, u64 *spte,
-                                const void *pte)
-{
-       WARN_ON(1);
-}
-
 #define KVM_PAGE_ARRAY_NR 16
 
 struct kvm_mmu_pages {
@@ -2008,9 +2010,9 @@ static void mmu_sync_children(struct kvm_vcpu *vcpu,
                        flush |= kvm_sync_page(vcpu, sp, &invalid_list);
                        mmu_pages_clear_parents(&parents);
                }
-               if (need_resched() || spin_needbreak(&vcpu->kvm->mmu_lock)) {
+               if (need_resched() || rwlock_needbreak(&vcpu->kvm->mmu_lock)) {
                        kvm_mmu_flush_or_zap(vcpu, &invalid_list, false, flush);
-                       cond_resched_lock(&vcpu->kvm->mmu_lock);
+                       cond_resched_rwlock_write(&vcpu->kvm->mmu_lock);
                        flush = false;
                }
        }
@@ -2409,7 +2411,7 @@ static unsigned long kvm_mmu_zap_oldest_mmu_pages(struct kvm *kvm,
                return 0;
 
 restart:
-       list_for_each_entry_safe(sp, tmp, &kvm->arch.active_mmu_pages, link) {
+       list_for_each_entry_safe_reverse(sp, tmp, &kvm->arch.active_mmu_pages, link) {
                /*
                 * Don't zap active root pages, the page itself can't be freed
                 * and zapping it will just force vCPUs to realloc and reload.
@@ -2462,7 +2464,7 @@ static int make_mmu_pages_available(struct kvm_vcpu *vcpu)
  */
 void kvm_mmu_change_mmu_pages(struct kvm *kvm, unsigned long goal_nr_mmu_pages)
 {
-       spin_lock(&kvm->mmu_lock);
+       write_lock(&kvm->mmu_lock);
 
        if (kvm->arch.n_used_mmu_pages > goal_nr_mmu_pages) {
                kvm_mmu_zap_oldest_mmu_pages(kvm, kvm->arch.n_used_mmu_pages -
@@ -2473,7 +2475,7 @@ void kvm_mmu_change_mmu_pages(struct kvm *kvm, unsigned long goal_nr_mmu_pages)
 
        kvm->arch.n_max_mmu_pages = goal_nr_mmu_pages;
 
-       spin_unlock(&kvm->mmu_lock);
+       write_unlock(&kvm->mmu_lock);
 }
 
 int kvm_mmu_unprotect_page(struct kvm *kvm, gfn_t gfn)
@@ -2484,7 +2486,7 @@ int kvm_mmu_unprotect_page(struct kvm *kvm, gfn_t gfn)
 
        pgprintk("%s: looking for gfn %llx\n", __func__, gfn);
        r = 0;
-       spin_lock(&kvm->mmu_lock);
+       write_lock(&kvm->mmu_lock);
        for_each_gfn_indirect_valid_sp(kvm, sp, gfn) {
                pgprintk("%s: gfn %llx role %x\n", __func__, gfn,
                         sp->role.word);
@@ -2492,7 +2494,7 @@ int kvm_mmu_unprotect_page(struct kvm *kvm, gfn_t gfn)
                kvm_mmu_prepare_zap_page(kvm, sp, &invalid_list);
        }
        kvm_mmu_commit_zap_page(kvm, &invalid_list);
-       spin_unlock(&kvm->mmu_lock);
+       write_unlock(&kvm->mmu_lock);
 
        return r;
 }
@@ -3184,7 +3186,7 @@ void kvm_mmu_free_roots(struct kvm_vcpu *vcpu, struct kvm_mmu *mmu,
                        return;
        }
 
-       spin_lock(&kvm->mmu_lock);
+       write_lock(&kvm->mmu_lock);
 
        for (i = 0; i < KVM_MMU_NUM_PREV_ROOTS; i++)
                if (roots_to_free & KVM_MMU_ROOT_PREVIOUS(i))
@@ -3207,7 +3209,7 @@ void kvm_mmu_free_roots(struct kvm_vcpu *vcpu, struct kvm_mmu *mmu,
        }
 
        kvm_mmu_commit_zap_page(kvm, &invalid_list);
-       spin_unlock(&kvm->mmu_lock);
+       write_unlock(&kvm->mmu_lock);
 }
 EXPORT_SYMBOL_GPL(kvm_mmu_free_roots);
 
@@ -3228,16 +3230,16 @@ static hpa_t mmu_alloc_root(struct kvm_vcpu *vcpu, gfn_t gfn, gva_t gva,
 {
        struct kvm_mmu_page *sp;
 
-       spin_lock(&vcpu->kvm->mmu_lock);
+       write_lock(&vcpu->kvm->mmu_lock);
 
        if (make_mmu_pages_available(vcpu)) {
-               spin_unlock(&vcpu->kvm->mmu_lock);
+               write_unlock(&vcpu->kvm->mmu_lock);
                return INVALID_PAGE;
        }
        sp = kvm_mmu_get_page(vcpu, gfn, gva, level, direct, ACC_ALL);
        ++sp->root_count;
 
-       spin_unlock(&vcpu->kvm->mmu_lock);
+       write_unlock(&vcpu->kvm->mmu_lock);
        return __pa(sp->spt);
 }
 
@@ -3408,17 +3410,17 @@ void kvm_mmu_sync_roots(struct kvm_vcpu *vcpu)
                    !smp_load_acquire(&sp->unsync_children))
                        return;
 
-               spin_lock(&vcpu->kvm->mmu_lock);
+               write_lock(&vcpu->kvm->mmu_lock);
                kvm_mmu_audit(vcpu, AUDIT_PRE_SYNC);
 
                mmu_sync_children(vcpu, sp);
 
                kvm_mmu_audit(vcpu, AUDIT_POST_SYNC);
-               spin_unlock(&vcpu->kvm->mmu_lock);
+               write_unlock(&vcpu->kvm->mmu_lock);
                return;
        }
 
-       spin_lock(&vcpu->kvm->mmu_lock);
+       write_lock(&vcpu->kvm->mmu_lock);
        kvm_mmu_audit(vcpu, AUDIT_PRE_SYNC);
 
        for (i = 0; i < 4; ++i) {
@@ -3432,7 +3434,7 @@ void kvm_mmu_sync_roots(struct kvm_vcpu *vcpu)
        }
 
        kvm_mmu_audit(vcpu, AUDIT_POST_SYNC);
-       spin_unlock(&vcpu->kvm->mmu_lock);
+       write_unlock(&vcpu->kvm->mmu_lock);
 }
 EXPORT_SYMBOL_GPL(kvm_mmu_sync_roots);
 
@@ -3511,7 +3513,7 @@ static int get_walk(struct kvm_vcpu *vcpu, u64 addr, u64 *sptes, int *root_level
        return leaf;
 }
 
-/* return true if reserved bit is detected on spte. */
+/* return true if reserved bit(s) are detected on a valid, non-MMIO SPTE. */
 static bool get_mmio_spte(struct kvm_vcpu *vcpu, u64 addr, u64 *sptep)
 {
        u64 sptes[PT64_ROOT_MAX_LEVEL + 1];
@@ -3534,11 +3536,20 @@ static bool get_mmio_spte(struct kvm_vcpu *vcpu, u64 addr, u64 *sptep)
                return reserved;
        }
 
+       *sptep = sptes[leaf];
+
+       /*
+        * Skip reserved bits checks on the terminal leaf if it's not a valid
+        * SPTE.  Note, this also (intentionally) skips MMIO SPTEs, which, by
+        * design, always have reserved bits set.  The purpose of the checks is
+        * to detect reserved bits on non-MMIO SPTEs. i.e. buggy SPTEs.
+        */
+       if (!is_shadow_present_pte(sptes[leaf]))
+               leaf++;
+
        rsvd_check = &vcpu->arch.mmu->shadow_zero_check;
 
-       for (level = root; level >= leaf; level--) {
-               if (!is_shadow_present_pte(sptes[level]))
-                       break;
+       for (level = root; level >= leaf; level--)
                /*
                 * Use a bitwise-OR instead of a logical-OR to aggregate the
                 * reserved bit and EPT's invalid memtype/XWR checks to avoid
@@ -3546,7 +3557,6 @@ static bool get_mmio_spte(struct kvm_vcpu *vcpu, u64 addr, u64 *sptep)
                 */
                reserved |= __is_bad_mt_xwr(rsvd_check, sptes[level]) |
                            __is_rsvd_bits_set(rsvd_check, sptes[level], level);
-       }
 
        if (reserved) {
                pr_err("%s: detect reserved bits on spte, addr 0x%llx, dump hierarchy:\n",
@@ -3556,8 +3566,6 @@ static bool get_mmio_spte(struct kvm_vcpu *vcpu, u64 addr, u64 *sptep)
                               sptes[level], level);
        }
 
-       *sptep = sptes[leaf];
-
        return reserved;
 }
 
@@ -3710,7 +3718,7 @@ static int direct_page_fault(struct kvm_vcpu *vcpu, gpa_t gpa, u32 error_code,
                return r;
 
        r = RET_PF_RETRY;
-       spin_lock(&vcpu->kvm->mmu_lock);
+       write_lock(&vcpu->kvm->mmu_lock);
        if (mmu_notifier_retry(vcpu->kvm, mmu_seq))
                goto out_unlock;
        r = make_mmu_pages_available(vcpu);
@@ -3725,7 +3733,7 @@ static int direct_page_fault(struct kvm_vcpu *vcpu, gpa_t gpa, u32 error_code,
                                 prefault, is_tdp);
 
 out_unlock:
-       spin_unlock(&vcpu->kvm->mmu_lock);
+       write_unlock(&vcpu->kvm->mmu_lock);
        kvm_release_pfn_clean(pfn);
        return r;
 }
@@ -3799,7 +3807,6 @@ static void nonpaging_init_context(struct kvm_vcpu *vcpu,
        context->gva_to_gpa = nonpaging_gva_to_gpa;
        context->sync_page = nonpaging_sync_page;
        context->invlpg = NULL;
-       context->update_pte = nonpaging_update_pte;
        context->root_level = 0;
        context->shadow_root_level = PT32E_ROOT_LEVEL;
        context->direct_map = true;
@@ -4381,7 +4388,6 @@ static void paging64_init_context_common(struct kvm_vcpu *vcpu,
        context->gva_to_gpa = paging64_gva_to_gpa;
        context->sync_page = paging64_sync_page;
        context->invlpg = paging64_invlpg;
-       context->update_pte = paging64_update_pte;
        context->shadow_root_level = level;
        context->direct_map = false;
 }
@@ -4410,7 +4416,6 @@ static void paging32_init_context(struct kvm_vcpu *vcpu,
        context->gva_to_gpa = paging32_gva_to_gpa;
        context->sync_page = paging32_sync_page;
        context->invlpg = paging32_invlpg;
-       context->update_pte = paging32_update_pte;
        context->shadow_root_level = PT32E_ROOT_LEVEL;
        context->direct_map = false;
 }
@@ -4492,7 +4497,6 @@ static void init_kvm_tdp_mmu(struct kvm_vcpu *vcpu)
        context->page_fault = kvm_tdp_page_fault;
        context->sync_page = nonpaging_sync_page;
        context->invlpg = NULL;
-       context->update_pte = nonpaging_update_pte;
        context->shadow_root_level = kvm_mmu_get_tdp_level(vcpu);
        context->direct_map = true;
        context->get_guest_pgd = get_cr3;
@@ -4664,7 +4668,6 @@ void kvm_init_shadow_ept_mmu(struct kvm_vcpu *vcpu, bool execonly,
        context->gva_to_gpa = ept_gva_to_gpa;
        context->sync_page = ept_sync_page;
        context->invlpg = ept_invlpg;
-       context->update_pte = ept_update_pte;
        context->root_level = level;
        context->direct_map = false;
        context->mmu_role.as_u64 = new_role.as_u64;
@@ -4797,7 +4800,7 @@ int kvm_mmu_load(struct kvm_vcpu *vcpu)
        if (r)
                goto out;
        kvm_mmu_load_pgd(vcpu);
-       kvm_x86_ops.tlb_flush_current(vcpu);
+       static_call(kvm_x86_tlb_flush_current)(vcpu);
 out:
        return r;
 }
@@ -4812,19 +4815,6 @@ void kvm_mmu_unload(struct kvm_vcpu *vcpu)
 }
 EXPORT_SYMBOL_GPL(kvm_mmu_unload);
 
-static void mmu_pte_write_new_pte(struct kvm_vcpu *vcpu,
-                                 struct kvm_mmu_page *sp, u64 *spte,
-                                 const void *new)
-{
-       if (sp->role.level != PG_LEVEL_4K) {
-               ++vcpu->kvm->stat.mmu_pde_zapped;
-               return;
-        }
-
-       ++vcpu->kvm->stat.mmu_pte_updated;
-       vcpu->arch.mmu->update_pte(vcpu, sp, spte, new);
-}
-
 static bool need_remote_flush(u64 old, u64 new)
 {
        if (!is_shadow_present_pte(old))
@@ -4940,22 +4930,6 @@ static u64 *get_written_sptes(struct kvm_mmu_page *sp, gpa_t gpa, int *nspte)
        return spte;
 }
 
-/*
- * Ignore various flags when determining if a SPTE can be immediately
- * overwritten for the current MMU.
- *  - level: explicitly checked in mmu_pte_write_new_pte(), and will never
- *    match the current MMU role, as MMU's level tracks the root level.
- *  - access: updated based on the new guest PTE
- *  - quadrant: handled by get_written_sptes()
- *  - invalid: always false (loop only walks valid shadow pages)
- */
-static const union kvm_mmu_page_role role_ign = {
-       .level = 0xf,
-       .access = 0x7,
-       .quadrant = 0x3,
-       .invalid = 0x1,
-};
-
 static void kvm_mmu_pte_write(struct kvm_vcpu *vcpu, gpa_t gpa,
                              const u8 *new, int bytes,
                              struct kvm_page_track_notifier_node *node)
@@ -4985,7 +4959,7 @@ static void kvm_mmu_pte_write(struct kvm_vcpu *vcpu, gpa_t gpa,
         */
        mmu_topup_memory_caches(vcpu, true);
 
-       spin_lock(&vcpu->kvm->mmu_lock);
+       write_lock(&vcpu->kvm->mmu_lock);
 
        gentry = mmu_pte_write_fetch_gpte(vcpu, &gpa, &bytes);
 
@@ -5006,14 +4980,10 @@ static void kvm_mmu_pte_write(struct kvm_vcpu *vcpu, gpa_t gpa,
 
                local_flush = true;
                while (npte--) {
-                       u32 base_role = vcpu->arch.mmu->mmu_role.base.word;
-
                        entry = *spte;
                        mmu_page_zap_pte(vcpu->kvm, sp, spte, NULL);
-                       if (gentry &&
-                           !((sp->role.word ^ base_role) & ~role_ign.word) &&
-                           rmap_can_add(vcpu))
-                               mmu_pte_write_new_pte(vcpu, sp, spte, &gentry);
+                       if (gentry && sp->role.level != PG_LEVEL_4K)
+                               ++vcpu->kvm->stat.mmu_pde_zapped;
                        if (need_remote_flush(entry, *spte))
                                remote_flush = true;
                        ++spte;
@@ -5021,7 +4991,7 @@ static void kvm_mmu_pte_write(struct kvm_vcpu *vcpu, gpa_t gpa,
        }
        kvm_mmu_flush_or_zap(vcpu, &invalid_list, remote_flush, local_flush);
        kvm_mmu_audit(vcpu, AUDIT_POST_PTE_WRITE);
-       spin_unlock(&vcpu->kvm->mmu_lock);
+       write_unlock(&vcpu->kvm->mmu_lock);
 }
 
 int kvm_mmu_unprotect_page_virt(struct kvm_vcpu *vcpu, gva_t gva)
@@ -5111,7 +5081,7 @@ void kvm_mmu_invalidate_gva(struct kvm_vcpu *vcpu, struct kvm_mmu *mmu,
                if (is_noncanonical_address(gva, vcpu))
                        return;
 
-               kvm_x86_ops.tlb_flush_gva(vcpu, gva);
+               static_call(kvm_x86_tlb_flush_gva)(vcpu, gva);
        }
 
        if (!mmu->invlpg)
@@ -5168,7 +5138,7 @@ void kvm_mmu_invpcid_gva(struct kvm_vcpu *vcpu, gva_t gva, unsigned long pcid)
        }
 
        if (tlb_flush)
-               kvm_x86_ops.tlb_flush_gva(vcpu, gva);
+               static_call(kvm_x86_tlb_flush_gva)(vcpu, gva);
 
        ++vcpu->stat.invlpg;
 
@@ -5219,14 +5189,14 @@ slot_handle_level_range(struct kvm *kvm, struct kvm_memory_slot *memslot,
                if (iterator.rmap)
                        flush |= fn(kvm, iterator.rmap);
 
-               if (need_resched() || spin_needbreak(&kvm->mmu_lock)) {
+               if (need_resched() || rwlock_needbreak(&kvm->mmu_lock)) {
                        if (flush && lock_flush_tlb) {
                                kvm_flush_remote_tlbs_with_address(kvm,
                                                start_gfn,
                                                iterator.gfn - start_gfn + 1);
                                flush = false;
                        }
-                       cond_resched_lock(&kvm->mmu_lock);
+                       cond_resched_rwlock_write(&kvm->mmu_lock);
                }
        }
 
@@ -5376,7 +5346,7 @@ restart:
                 * be in active use by the guest.
                 */
                if (batch >= BATCH_ZAP_PAGES &&
-                   cond_resched_lock(&kvm->mmu_lock)) {
+                   cond_resched_rwlock_write(&kvm->mmu_lock)) {
                        batch = 0;
                        goto restart;
                }
@@ -5409,7 +5379,7 @@ static void kvm_mmu_zap_all_fast(struct kvm *kvm)
 {
        lockdep_assert_held(&kvm->slots_lock);
 
-       spin_lock(&kvm->mmu_lock);
+       write_lock(&kvm->mmu_lock);
        trace_kvm_mmu_zap_all_fast(kvm);
 
        /*
@@ -5436,7 +5406,7 @@ static void kvm_mmu_zap_all_fast(struct kvm *kvm)
        if (kvm->arch.tdp_mmu_enabled)
                kvm_tdp_mmu_zap_all(kvm);
 
-       spin_unlock(&kvm->mmu_lock);
+       write_unlock(&kvm->mmu_lock);
 }
 
 static bool kvm_has_zapped_obsolete_pages(struct kvm *kvm)
@@ -5478,7 +5448,7 @@ void kvm_zap_gfn_range(struct kvm *kvm, gfn_t gfn_start, gfn_t gfn_end)
        int i;
        bool flush;
 
-       spin_lock(&kvm->mmu_lock);
+       write_lock(&kvm->mmu_lock);
        for (i = 0; i < KVM_ADDRESS_SPACE_NUM; i++) {
                slots = __kvm_memslots(kvm, i);
                kvm_for_each_memslot(memslot, slots) {
@@ -5502,7 +5472,7 @@ void kvm_zap_gfn_range(struct kvm *kvm, gfn_t gfn_start, gfn_t gfn_end)
                        kvm_flush_remote_tlbs(kvm);
        }
 
-       spin_unlock(&kvm->mmu_lock);
+       write_unlock(&kvm->mmu_lock);
 }
 
 static bool slot_rmap_write_protect(struct kvm *kvm,
@@ -5517,12 +5487,12 @@ void kvm_mmu_slot_remove_write_access(struct kvm *kvm,
 {
        bool flush;
 
-       spin_lock(&kvm->mmu_lock);
+       write_lock(&kvm->mmu_lock);
        flush = slot_handle_level(kvm, memslot, slot_rmap_write_protect,
                                start_level, KVM_MAX_HUGEPAGE_LEVEL, false);
        if (kvm->arch.tdp_mmu_enabled)
                flush |= kvm_tdp_mmu_wrprot_slot(kvm, memslot, PG_LEVEL_4K);
-       spin_unlock(&kvm->mmu_lock);
+       write_unlock(&kvm->mmu_lock);
 
        /*
         * We can flush all the TLBs out of the mmu lock without TLB
@@ -5582,13 +5552,13 @@ void kvm_mmu_zap_collapsible_sptes(struct kvm *kvm,
                                   const struct kvm_memory_slot *memslot)
 {
        /* FIXME: const-ify all uses of struct kvm_memory_slot.  */
-       spin_lock(&kvm->mmu_lock);
+       write_lock(&kvm->mmu_lock);
        slot_handle_leaf(kvm, (struct kvm_memory_slot *)memslot,
                         kvm_mmu_zap_collapsible_spte, true);
 
        if (kvm->arch.tdp_mmu_enabled)
                kvm_tdp_mmu_zap_collapsible_sptes(kvm, memslot);
-       spin_unlock(&kvm->mmu_lock);
+       write_unlock(&kvm->mmu_lock);
 }
 
 void kvm_arch_flush_remote_tlbs_memslot(struct kvm *kvm,
@@ -5611,11 +5581,11 @@ void kvm_mmu_slot_leaf_clear_dirty(struct kvm *kvm,
 {
        bool flush;
 
-       spin_lock(&kvm->mmu_lock);
+       write_lock(&kvm->mmu_lock);
        flush = slot_handle_leaf(kvm, memslot, __rmap_clear_dirty, false);
        if (kvm->arch.tdp_mmu_enabled)
                flush |= kvm_tdp_mmu_clear_dirty_slot(kvm, memslot);
-       spin_unlock(&kvm->mmu_lock);
+       write_unlock(&kvm->mmu_lock);
 
        /*
         * It's also safe to flush TLBs out of mmu lock here as currently this
@@ -5633,12 +5603,12 @@ void kvm_mmu_slot_largepage_remove_write_access(struct kvm *kvm,
 {
        bool flush;
 
-       spin_lock(&kvm->mmu_lock);
+       write_lock(&kvm->mmu_lock);
        flush = slot_handle_large_level(kvm, memslot, slot_rmap_write_protect,
                                        false);
        if (kvm->arch.tdp_mmu_enabled)
                flush |= kvm_tdp_mmu_wrprot_slot(kvm, memslot, PG_LEVEL_2M);
-       spin_unlock(&kvm->mmu_lock);
+       write_unlock(&kvm->mmu_lock);
 
        if (flush)
                kvm_arch_flush_remote_tlbs_memslot(kvm, memslot);
@@ -5650,11 +5620,11 @@ void kvm_mmu_slot_set_dirty(struct kvm *kvm,
 {
        bool flush;
 
-       spin_lock(&kvm->mmu_lock);
+       write_lock(&kvm->mmu_lock);
        flush = slot_handle_all_level(kvm, memslot, __rmap_set_dirty, false);
        if (kvm->arch.tdp_mmu_enabled)
                flush |= kvm_tdp_mmu_slot_set_dirty(kvm, memslot);
-       spin_unlock(&kvm->mmu_lock);
+       write_unlock(&kvm->mmu_lock);
 
        if (flush)
                kvm_arch_flush_remote_tlbs_memslot(kvm, memslot);
@@ -5667,14 +5637,14 @@ void kvm_mmu_zap_all(struct kvm *kvm)
        LIST_HEAD(invalid_list);
        int ign;
 
-       spin_lock(&kvm->mmu_lock);
+       write_lock(&kvm->mmu_lock);
 restart:
        list_for_each_entry_safe(sp, node, &kvm->arch.active_mmu_pages, link) {
                if (WARN_ON(sp->role.invalid))
                        continue;
                if (__kvm_mmu_prepare_zap_page(kvm, sp, &invalid_list, &ign))
                        goto restart;
-               if (cond_resched_lock(&kvm->mmu_lock))
+               if (cond_resched_rwlock_write(&kvm->mmu_lock))
                        goto restart;
        }
 
@@ -5683,7 +5653,7 @@ restart:
        if (kvm->arch.tdp_mmu_enabled)
                kvm_tdp_mmu_zap_all(kvm);
 
-       spin_unlock(&kvm->mmu_lock);
+       write_unlock(&kvm->mmu_lock);
 }
 
 void kvm_mmu_invalidate_mmio_sptes(struct kvm *kvm, u64 gen)
@@ -5743,7 +5713,7 @@ mmu_shrink_scan(struct shrinker *shrink, struct shrink_control *sc)
                        continue;
 
                idx = srcu_read_lock(&kvm->srcu);
-               spin_lock(&kvm->mmu_lock);
+               write_lock(&kvm->mmu_lock);
 
                if (kvm_has_zapped_obsolete_pages(kvm)) {
                        kvm_mmu_commit_zap_page(kvm,
@@ -5754,7 +5724,7 @@ mmu_shrink_scan(struct shrinker *shrink, struct shrink_control *sc)
                freed = kvm_mmu_zap_oldest_mmu_pages(kvm, sc->nr_to_scan);
 
 unlock:
-               spin_unlock(&kvm->mmu_lock);
+               write_unlock(&kvm->mmu_lock);
                srcu_read_unlock(&kvm->srcu, idx);
 
                /*
@@ -5974,7 +5944,7 @@ static void kvm_recover_nx_lpages(struct kvm *kvm)
        ulong to_zap;
 
        rcu_idx = srcu_read_lock(&kvm->srcu);
-       spin_lock(&kvm->mmu_lock);
+       write_lock(&kvm->mmu_lock);
 
        ratio = READ_ONCE(nx_huge_pages_recovery_ratio);
        to_zap = ratio ? DIV_ROUND_UP(kvm->stat.nx_lpage_splits, ratio) : 0;
@@ -5991,22 +5961,22 @@ static void kvm_recover_nx_lpages(struct kvm *kvm)
                                      struct kvm_mmu_page,
                                      lpage_disallowed_link);
                WARN_ON_ONCE(!sp->lpage_disallowed);
-               if (sp->tdp_mmu_page)
+               if (sp->tdp_mmu_page) {
                        kvm_tdp_mmu_zap_gfn_range(kvm, sp->gfn,
                                sp->gfn + KVM_PAGES_PER_HPAGE(sp->role.level));
-               else {
+               } else {
                        kvm_mmu_prepare_zap_page(kvm, sp, &invalid_list);
                        WARN_ON_ONCE(sp->lpage_disallowed);
                }
 
-               if (need_resched() || spin_needbreak(&kvm->mmu_lock)) {
+               if (need_resched() || rwlock_needbreak(&kvm->mmu_lock)) {
                        kvm_mmu_commit_zap_page(kvm, &invalid_list);
-                       cond_resched_lock(&kvm->mmu_lock);
+                       cond_resched_rwlock_write(&kvm->mmu_lock);
                }
        }
        kvm_mmu_commit_zap_page(kvm, &invalid_list);
 
-       spin_unlock(&kvm->mmu_lock);
+       write_unlock(&kvm->mmu_lock);
        srcu_read_unlock(&kvm->srcu, rcu_idx);
 }