]> git.proxmox.com Git - mirror_ubuntu-zesty-kernel.git/blobdiff - fs/dax.c
Merge tag 'ext4_for_linus' of git://git.kernel.org/pub/scm/linux/kernel/git/tytso...
[mirror_ubuntu-zesty-kernel.git] / fs / dax.c
index 6916ed37d4631846a7478cfa3e3dd14c8808d373..5ae8e11ad78677ef3103fc569d0959742a380370 100644 (file)
--- a/fs/dax.c
+++ b/fs/dax.c
 #include <linux/iomap.h>
 #include "internal.h"
 
-/*
- * We use lowest available bit in exceptional entry for locking, other two
- * bits to determine entry type. In total 3 special bits.
- */
-#define RADIX_DAX_SHIFT        (RADIX_TREE_EXCEPTIONAL_SHIFT + 3)
-#define RADIX_DAX_PTE (1 << (RADIX_TREE_EXCEPTIONAL_SHIFT + 1))
-#define RADIX_DAX_PMD (1 << (RADIX_TREE_EXCEPTIONAL_SHIFT + 2))
-#define RADIX_DAX_TYPE_MASK (RADIX_DAX_PTE | RADIX_DAX_PMD)
-#define RADIX_DAX_TYPE(entry) ((unsigned long)entry & RADIX_DAX_TYPE_MASK)
-#define RADIX_DAX_SECTOR(entry) (((unsigned long)entry >> RADIX_DAX_SHIFT))
-#define RADIX_DAX_ENTRY(sector, pmd) ((void *)((unsigned long)sector << \
-               RADIX_DAX_SHIFT | (pmd ? RADIX_DAX_PMD : RADIX_DAX_PTE) | \
-               RADIX_TREE_EXCEPTIONAL_ENTRY))
-
 /* We choose 4096 entries - same as per-zone page wait tables */
 #define DAX_WAIT_TABLE_BITS 12
 #define DAX_WAIT_TABLE_ENTRIES (1 << DAX_WAIT_TABLE_BITS)
 
-wait_queue_head_t wait_table[DAX_WAIT_TABLE_ENTRIES];
+static wait_queue_head_t wait_table[DAX_WAIT_TABLE_ENTRIES];
 
 static int __init init_dax_wait_table(void)
 {
@@ -64,14 +50,6 @@ static int __init init_dax_wait_table(void)
 }
 fs_initcall(init_dax_wait_table);
 
-static wait_queue_head_t *dax_entry_waitqueue(struct address_space *mapping,
-                                             pgoff_t index)
-{
-       unsigned long hash = hash_long((unsigned long)mapping ^ index,
-                                      DAX_WAIT_TABLE_BITS);
-       return wait_table + hash;
-}
-
 static long dax_map_atomic(struct block_device *bdev, struct blk_dax_ctl *dax)
 {
        struct request_queue *q = bdev->bd_queue;
@@ -98,209 +76,52 @@ static void dax_unmap_atomic(struct block_device *bdev,
        blk_queue_exit(bdev->bd_queue);
 }
 
-struct page *read_dax_sector(struct block_device *bdev, sector_t n)
+static int dax_is_pmd_entry(void *entry)
 {
-       struct page *page = alloc_pages(GFP_KERNEL, 0);
-       struct blk_dax_ctl dax = {
-               .size = PAGE_SIZE,
-               .sector = n & ~((((int) PAGE_SIZE) / 512) - 1),
-       };
-       long rc;
-
-       if (!page)
-               return ERR_PTR(-ENOMEM);
-
-       rc = dax_map_atomic(bdev, &dax);
-       if (rc < 0)
-               return ERR_PTR(rc);
-       memcpy_from_pmem(page_address(page), dax.addr, PAGE_SIZE);
-       dax_unmap_atomic(bdev, &dax);
-       return page;
+       return (unsigned long)entry & RADIX_DAX_PMD;
 }
 
-static bool buffer_written(struct buffer_head *bh)
+static int dax_is_pte_entry(void *entry)
 {
-       return buffer_mapped(bh) && !buffer_unwritten(bh);
+       return !((unsigned long)entry & RADIX_DAX_PMD);
 }
 
-/*
- * When ext4 encounters a hole, it returns without modifying the buffer_head
- * which means that we can't trust b_size.  To cope with this, we set b_state
- * to 0 before calling get_block and, if any bit is set, we know we can trust
- * b_size.  Unfortunate, really, since ext4 knows precisely how long a hole is
- * and would save us time calling get_block repeatedly.
- */
-static bool buffer_size_valid(struct buffer_head *bh)
+static int dax_is_zero_entry(void *entry)
 {
-       return bh->b_state != 0;
+       return (unsigned long)entry & RADIX_DAX_HZP;
 }
 
-
-static sector_t to_sector(const struct buffer_head *bh,
-               const struct inode *inode)
+static int dax_is_empty_entry(void *entry)
 {
-       sector_t sector = bh->b_blocknr << (inode->i_blkbits - 9);
-
-       return sector;
+       return (unsigned long)entry & RADIX_DAX_EMPTY;
 }
 
-static ssize_t dax_io(struct inode *inode, struct iov_iter *iter,
-                     loff_t start, loff_t end, get_block_t get_block,
-                     struct buffer_head *bh)
+struct page *read_dax_sector(struct block_device *bdev, sector_t n)
 {
-       loff_t pos = start, max = start, bh_max = start;
-       bool hole = false;
-       struct block_device *bdev = NULL;
-       int rw = iov_iter_rw(iter), rc;
-       long map_len = 0;
+       struct page *page = alloc_pages(GFP_KERNEL, 0);
        struct blk_dax_ctl dax = {
-               .addr = ERR_PTR(-EIO),
+               .size = PAGE_SIZE,
+               .sector = n & ~((((int) PAGE_SIZE) / 512) - 1),
        };
-       unsigned blkbits = inode->i_blkbits;
-       sector_t file_blks = (i_size_read(inode) + (1 << blkbits) - 1)
-                                                               >> blkbits;
-
-       if (rw == READ)
-               end = min(end, i_size_read(inode));
-
-       while (pos < end) {
-               size_t len;
-               if (pos == max) {
-                       long page = pos >> PAGE_SHIFT;
-                       sector_t block = page << (PAGE_SHIFT - blkbits);
-                       unsigned first = pos - (block << blkbits);
-                       long size;
-
-                       if (pos == bh_max) {
-                               bh->b_size = PAGE_ALIGN(end - pos);
-                               bh->b_state = 0;
-                               rc = get_block(inode, block, bh, rw == WRITE);
-                               if (rc)
-                                       break;
-                               if (!buffer_size_valid(bh))
-                                       bh->b_size = 1 << blkbits;
-                               bh_max = pos - first + bh->b_size;
-                               bdev = bh->b_bdev;
-                               /*
-                                * We allow uninitialized buffers for writes
-                                * beyond EOF as those cannot race with faults
-                                */
-                               WARN_ON_ONCE(
-                                       (buffer_new(bh) && block < file_blks) ||
-                                       (rw == WRITE && buffer_unwritten(bh)));
-                       } else {
-                               unsigned done = bh->b_size -
-                                               (bh_max - (pos - first));
-                               bh->b_blocknr += done >> blkbits;
-                               bh->b_size -= done;
-                       }
-
-                       hole = rw == READ && !buffer_written(bh);
-                       if (hole) {
-                               size = bh->b_size - first;
-                       } else {
-                               dax_unmap_atomic(bdev, &dax);
-                               dax.sector = to_sector(bh, inode);
-                               dax.size = bh->b_size;
-                               map_len = dax_map_atomic(bdev, &dax);
-                               if (map_len < 0) {
-                                       rc = map_len;
-                                       break;
-                               }
-                               dax.addr += first;
-                               size = map_len - first;
-                       }
-                       /*
-                        * pos + size is one past the last offset for IO,
-                        * so pos + size can overflow loff_t at extreme offsets.
-                        * Cast to u64 to catch this and get the true minimum.
-                        */
-                       max = min_t(u64, pos + size, end);
-               }
-
-               if (iov_iter_rw(iter) == WRITE) {
-                       len = copy_from_iter_pmem(dax.addr, max - pos, iter);
-               } else if (!hole)
-                       len = copy_to_iter((void __force *) dax.addr, max - pos,
-                                       iter);
-               else
-                       len = iov_iter_zero(max - pos, iter);
-
-               if (!len) {
-                       rc = -EFAULT;
-                       break;
-               }
+       long rc;
 
-               pos += len;
-               if (!IS_ERR(dax.addr))
-                       dax.addr += len;
-       }
+       if (!page)
+               return ERR_PTR(-ENOMEM);
 
+       rc = dax_map_atomic(bdev, &dax);
+       if (rc < 0)
+               return ERR_PTR(rc);
+       memcpy_from_pmem(page_address(page), dax.addr, PAGE_SIZE);
        dax_unmap_atomic(bdev, &dax);
-
-       return (pos == start) ? rc : pos - start;
-}
-
-/**
- * dax_do_io - Perform I/O to a DAX file
- * @iocb: The control block for this I/O
- * @inode: The file which the I/O is directed at
- * @iter: The addresses to do I/O from or to
- * @get_block: The filesystem method used to translate file offsets to blocks
- * @end_io: A filesystem callback for I/O completion
- * @flags: See below
- *
- * This function uses the same locking scheme as do_blockdev_direct_IO:
- * If @flags has DIO_LOCKING set, we assume that the i_mutex is held by the
- * caller for writes.  For reads, we take and release the i_mutex ourselves.
- * If DIO_LOCKING is not set, the filesystem takes care of its own locking.
- * As with do_blockdev_direct_IO(), we increment i_dio_count while the I/O
- * is in progress.
- */
-ssize_t dax_do_io(struct kiocb *iocb, struct inode *inode,
-                 struct iov_iter *iter, get_block_t get_block,
-                 dio_iodone_t end_io, int flags)
-{
-       struct buffer_head bh;
-       ssize_t retval = -EINVAL;
-       loff_t pos = iocb->ki_pos;
-       loff_t end = pos + iov_iter_count(iter);
-
-       memset(&bh, 0, sizeof(bh));
-       bh.b_bdev = inode->i_sb->s_bdev;
-
-       if ((flags & DIO_LOCKING) && iov_iter_rw(iter) == READ)
-               inode_lock(inode);
-
-       /* Protects against truncate */
-       if (!(flags & DIO_SKIP_DIO_COUNT))
-               inode_dio_begin(inode);
-
-       retval = dax_io(inode, iter, pos, end, get_block, &bh);
-
-       if ((flags & DIO_LOCKING) && iov_iter_rw(iter) == READ)
-               inode_unlock(inode);
-
-       if (end_io) {
-               int err;
-
-               err = end_io(iocb, pos, retval, bh.b_private);
-               if (err)
-                       retval = err;
-       }
-
-       if (!(flags & DIO_SKIP_DIO_COUNT))
-               inode_dio_end(inode);
-       return retval;
+       return page;
 }
-EXPORT_SYMBOL_GPL(dax_do_io);
 
 /*
  * DAX radix tree locking
  */
 struct exceptional_entry_key {
        struct address_space *mapping;
-       unsigned long index;
+       pgoff_t entry_start;
 };
 
 struct wait_exceptional_entry_queue {
@@ -308,6 +129,26 @@ struct wait_exceptional_entry_queue {
        struct exceptional_entry_key key;
 };
 
+static wait_queue_head_t *dax_entry_waitqueue(struct address_space *mapping,
+               pgoff_t index, void *entry, struct exceptional_entry_key *key)
+{
+       unsigned long hash;
+
+       /*
+        * If 'entry' is a PMD, align the 'index' that we use for the wait
+        * queue to the start of that PMD.  This ensures that all offsets in
+        * the range covered by the PMD map to the same bit lock.
+        */
+       if (dax_is_pmd_entry(entry))
+               index &= ~((1UL << (PMD_SHIFT - PAGE_SHIFT)) - 1);
+
+       key->mapping = mapping;
+       key->entry_start = index;
+
+       hash = hash_long((unsigned long)mapping ^ index, DAX_WAIT_TABLE_BITS);
+       return wait_table + hash;
+}
+
 static int wake_exceptional_entry_func(wait_queue_t *wait, unsigned int mode,
                                       int sync, void *keyp)
 {
@@ -316,7 +157,7 @@ static int wake_exceptional_entry_func(wait_queue_t *wait, unsigned int mode,
                container_of(wait, struct wait_exceptional_entry_queue, wait);
 
        if (key->mapping != ewait->key.mapping ||
-           key->index != ewait->key.index)
+           key->entry_start != ewait->key.entry_start)
                return 0;
        return autoremove_wake_function(wait, mode, sync, NULL);
 }
@@ -372,24 +213,24 @@ static inline void *unlock_slot(struct address_space *mapping, void **slot)
 static void *get_unlocked_mapping_entry(struct address_space *mapping,
                                        pgoff_t index, void ***slotp)
 {
-       void *ret, **slot;
+       void *entry, **slot;
        struct wait_exceptional_entry_queue ewait;
-       wait_queue_head_t *wq = dax_entry_waitqueue(mapping, index);
+       wait_queue_head_t *wq;
 
        init_wait(&ewait.wait);
        ewait.wait.func = wake_exceptional_entry_func;
-       ewait.key.mapping = mapping;
-       ewait.key.index = index;
 
        for (;;) {
-               ret = __radix_tree_lookup(&mapping->page_tree, index, NULL,
+               entry = __radix_tree_lookup(&mapping->page_tree, index, NULL,
                                          &slot);
-               if (!ret || !radix_tree_exceptional_entry(ret) ||
+               if (!entry || !radix_tree_exceptional_entry(entry) ||
                    !slot_locked(mapping, slot)) {
                        if (slotp)
                                *slotp = slot;
-                       return ret;
+                       return entry;
                }
+
+               wq = dax_entry_waitqueue(mapping, index, entry, &ewait.key);
                prepare_to_wait_exclusive(wq, &ewait.wait,
                                          TASK_UNINTERRUPTIBLE);
                spin_unlock_irq(&mapping->tree_lock);
@@ -399,52 +240,156 @@ static void *get_unlocked_mapping_entry(struct address_space *mapping,
        }
 }
 
+static void put_locked_mapping_entry(struct address_space *mapping,
+                                    pgoff_t index, void *entry)
+{
+       if (!radix_tree_exceptional_entry(entry)) {
+               unlock_page(entry);
+               put_page(entry);
+       } else {
+               dax_unlock_mapping_entry(mapping, index);
+       }
+}
+
+/*
+ * Called when we are done with radix tree entry we looked up via
+ * get_unlocked_mapping_entry() and which we didn't lock in the end.
+ */
+static void put_unlocked_mapping_entry(struct address_space *mapping,
+                                      pgoff_t index, void *entry)
+{
+       if (!radix_tree_exceptional_entry(entry))
+               return;
+
+       /* We have to wake up next waiter for the radix tree entry lock */
+       dax_wake_mapping_entry_waiter(mapping, index, entry, false);
+}
+
 /*
  * Find radix tree entry at given index. If it points to a page, return with
  * the page locked. If it points to the exceptional entry, return with the
  * radix tree entry locked. If the radix tree doesn't contain given index,
  * create empty exceptional entry for the index and return with it locked.
  *
+ * When requesting an entry with size RADIX_DAX_PMD, grab_mapping_entry() will
+ * either return that locked entry or will return an error.  This error will
+ * happen if there are any 4k entries (either zero pages or DAX entries)
+ * within the 2MiB range that we are requesting.
+ *
+ * We always favor 4k entries over 2MiB entries. There isn't a flow where we
+ * evict 4k entries in order to 'upgrade' them to a 2MiB entry.  A 2MiB
+ * insertion will fail if it finds any 4k entries already in the tree, and a
+ * 4k insertion will cause an existing 2MiB entry to be unmapped and
+ * downgraded to 4k entries.  This happens for both 2MiB huge zero pages as
+ * well as 2MiB empty entries.
+ *
+ * The exception to this downgrade path is for 2MiB DAX PMD entries that have
+ * real storage backing them.  We will leave these real 2MiB DAX entries in
+ * the tree, and PTE writes will simply dirty the entire 2MiB DAX entry.
+ *
  * Note: Unlike filemap_fault() we don't honor FAULT_FLAG_RETRY flags. For
  * persistent memory the benefit is doubtful. We can add that later if we can
  * show it helps.
  */
-static void *grab_mapping_entry(struct address_space *mapping, pgoff_t index)
+static void *grab_mapping_entry(struct address_space *mapping, pgoff_t index,
+               unsigned long size_flag)
 {
-       void *ret, **slot;
+       bool pmd_downgrade = false; /* splitting 2MiB entry into 4k entries? */
+       void *entry, **slot;
 
 restart:
        spin_lock_irq(&mapping->tree_lock);
-       ret = get_unlocked_mapping_entry(mapping, index, &slot);
+       entry = get_unlocked_mapping_entry(mapping, index, &slot);
+
+       if (entry) {
+               if (size_flag & RADIX_DAX_PMD) {
+                       if (!radix_tree_exceptional_entry(entry) ||
+                           dax_is_pte_entry(entry)) {
+                               put_unlocked_mapping_entry(mapping, index,
+                                               entry);
+                               entry = ERR_PTR(-EEXIST);
+                               goto out_unlock;
+                       }
+               } else { /* trying to grab a PTE entry */
+                       if (radix_tree_exceptional_entry(entry) &&
+                           dax_is_pmd_entry(entry) &&
+                           (dax_is_zero_entry(entry) ||
+                            dax_is_empty_entry(entry))) {
+                               pmd_downgrade = true;
+                       }
+               }
+       }
+
        /* No entry for given index? Make sure radix tree is big enough. */
-       if (!ret) {
+       if (!entry || pmd_downgrade) {
                int err;
 
+               if (pmd_downgrade) {
+                       /*
+                        * Make sure 'entry' remains valid while we drop
+                        * mapping->tree_lock.
+                        */
+                       entry = lock_slot(mapping, slot);
+               }
+
                spin_unlock_irq(&mapping->tree_lock);
+               /*
+                * Besides huge zero pages the only other thing that gets
+                * downgraded are empty entries which don't need to be
+                * unmapped.
+                */
+               if (pmd_downgrade && dax_is_zero_entry(entry))
+                       unmap_mapping_range(mapping,
+                               (index << PAGE_SHIFT) & PMD_MASK, PMD_SIZE, 0);
+
                err = radix_tree_preload(
                                mapping_gfp_mask(mapping) & ~__GFP_HIGHMEM);
-               if (err)
+               if (err) {
+                       if (pmd_downgrade)
+                               put_locked_mapping_entry(mapping, index, entry);
                        return ERR_PTR(err);
-               ret = (void *)(RADIX_TREE_EXCEPTIONAL_ENTRY |
-                              RADIX_DAX_ENTRY_LOCK);
+               }
                spin_lock_irq(&mapping->tree_lock);
-               err = radix_tree_insert(&mapping->page_tree, index, ret);
+
+               if (pmd_downgrade) {
+                       radix_tree_delete(&mapping->page_tree, index);
+                       mapping->nrexceptional--;
+                       dax_wake_mapping_entry_waiter(mapping, index, entry,
+                                       true);
+               }
+
+               entry = dax_radix_locked_entry(0, size_flag | RADIX_DAX_EMPTY);
+
+               err = __radix_tree_insert(&mapping->page_tree, index,
+                               dax_radix_order(entry), entry);
                radix_tree_preload_end();
                if (err) {
                        spin_unlock_irq(&mapping->tree_lock);
-                       /* Someone already created the entry? */
-                       if (err == -EEXIST)
+                       /*
+                        * Someone already created the entry?  This is a
+                        * normal failure when inserting PMDs in a range
+                        * that already contains PTEs.  In that case we want
+                        * to return -EEXIST immediately.
+                        */
+                       if (err == -EEXIST && !(size_flag & RADIX_DAX_PMD))
                                goto restart;
+                       /*
+                        * Our insertion of a DAX PMD entry failed, most
+                        * likely because it collided with a PTE sized entry
+                        * at a different index in the PMD range.  We haven't
+                        * inserted anything into the radix tree and have no
+                        * waiters to wake.
+                        */
                        return ERR_PTR(err);
                }
                /* Good, we have inserted empty locked entry into the tree. */
                mapping->nrexceptional++;
                spin_unlock_irq(&mapping->tree_lock);
-               return ret;
+               return entry;
        }
        /* Normal page in radix tree? */
-       if (!radix_tree_exceptional_entry(ret)) {
-               struct page *page = ret;
+       if (!radix_tree_exceptional_entry(entry)) {
+               struct page *page = entry;
 
                get_page(page);
                spin_unlock_irq(&mapping->tree_lock);
@@ -457,15 +402,26 @@ restart:
                }
                return page;
        }
-       ret = lock_slot(mapping, slot);
+       entry = lock_slot(mapping, slot);
+ out_unlock:
        spin_unlock_irq(&mapping->tree_lock);
-       return ret;
+       return entry;
 }
 
+/*
+ * We do not necessarily hold the mapping->tree_lock when we call this
+ * function so it is possible that 'entry' is no longer a valid item in the
+ * radix tree.  This is okay because all we really need to do is to find the
+ * correct waitqueue where tasks might be waiting for that old 'entry' and
+ * wake them.
+ */
 void dax_wake_mapping_entry_waiter(struct address_space *mapping,
-                                  pgoff_t index, bool wake_all)
+               pgoff_t index, void *entry, bool wake_all)
 {
-       wait_queue_head_t *wq = dax_entry_waitqueue(mapping, index);
+       struct exceptional_entry_key key;
+       wait_queue_head_t *wq;
+
+       wq = dax_entry_waitqueue(mapping, index, entry, &key);
 
        /*
         * Checking for locked entry and prepare_to_wait_exclusive() happens
@@ -473,54 +429,24 @@ void dax_wake_mapping_entry_waiter(struct address_space *mapping,
         * So at this point all tasks that could have seen our entry locked
         * must be in the waitqueue and the following check will see them.
         */
-       if (waitqueue_active(wq)) {
-               struct exceptional_entry_key key;
-
-               key.mapping = mapping;
-               key.index = index;
+       if (waitqueue_active(wq))
                __wake_up(wq, TASK_NORMAL, wake_all ? 0 : 1, &key);
-       }
 }
 
 void dax_unlock_mapping_entry(struct address_space *mapping, pgoff_t index)
 {
-       void *ret, **slot;
+       void *entry, **slot;
 
        spin_lock_irq(&mapping->tree_lock);
-       ret = __radix_tree_lookup(&mapping->page_tree, index, NULL, &slot);
-       if (WARN_ON_ONCE(!ret || !radix_tree_exceptional_entry(ret) ||
+       entry = __radix_tree_lookup(&mapping->page_tree, index, NULL, &slot);
+       if (WARN_ON_ONCE(!entry || !radix_tree_exceptional_entry(entry) ||
                         !slot_locked(mapping, slot))) {
                spin_unlock_irq(&mapping->tree_lock);
                return;
        }
        unlock_slot(mapping, slot);
        spin_unlock_irq(&mapping->tree_lock);
-       dax_wake_mapping_entry_waiter(mapping, index, false);
-}
-
-static void put_locked_mapping_entry(struct address_space *mapping,
-                                    pgoff_t index, void *entry)
-{
-       if (!radix_tree_exceptional_entry(entry)) {
-               unlock_page(entry);
-               put_page(entry);
-       } else {
-               dax_unlock_mapping_entry(mapping, index);
-       }
-}
-
-/*
- * Called when we are done with radix tree entry we looked up via
- * get_unlocked_mapping_entry() and which we didn't lock in the end.
- */
-static void put_unlocked_mapping_entry(struct address_space *mapping,
-                                      pgoff_t index, void *entry)
-{
-       if (!radix_tree_exceptional_entry(entry))
-               return;
-
-       /* We have to wake up next waiter for the radix tree entry lock */
-       dax_wake_mapping_entry_waiter(mapping, index, false);
+       dax_wake_mapping_entry_waiter(mapping, index, entry, false);
 }
 
 /*
@@ -547,7 +473,7 @@ int dax_delete_mapping_entry(struct address_space *mapping, pgoff_t index)
        radix_tree_delete(&mapping->page_tree, index);
        mapping->nrexceptional--;
        spin_unlock_irq(&mapping->tree_lock);
-       dax_wake_mapping_entry_waiter(mapping, index, true);
+       dax_wake_mapping_entry_waiter(mapping, index, entry, true);
 
        return 1;
 }
@@ -600,11 +526,17 @@ static int copy_user_dax(struct block_device *bdev, sector_t sector, size_t size
        return 0;
 }
 
-#define DAX_PMD_INDEX(page_index) (page_index & (PMD_MASK >> PAGE_SHIFT))
-
+/*
+ * By this point grab_mapping_entry() has ensured that we have a locked entry
+ * of the appropriate size so we don't have to worry about downgrading PMDs to
+ * PTEs.  If we happen to be trying to insert a PTE and there is a PMD
+ * already in the tree, we will skip the insertion and just dirty the PMD as
+ * appropriate.
+ */
 static void *dax_insert_mapping_entry(struct address_space *mapping,
                                      struct vm_fault *vmf,
-                                     void *entry, sector_t sector)
+                                     void *entry, sector_t sector,
+                                     unsigned long flags)
 {
        struct radix_tree_root *page_tree = &mapping->page_tree;
        int error = 0;
@@ -627,22 +559,35 @@ static void *dax_insert_mapping_entry(struct address_space *mapping,
                error = radix_tree_preload(vmf->gfp_mask & ~__GFP_HIGHMEM);
                if (error)
                        return ERR_PTR(error);
+       } else if (dax_is_zero_entry(entry) && !(flags & RADIX_DAX_HZP)) {
+               /* replacing huge zero page with PMD block mapping */
+               unmap_mapping_range(mapping,
+                       (vmf->pgoff << PAGE_SHIFT) & PMD_MASK, PMD_SIZE, 0);
        }
 
        spin_lock_irq(&mapping->tree_lock);
-       new_entry = (void *)((unsigned long)RADIX_DAX_ENTRY(sector, false) |
-                      RADIX_DAX_ENTRY_LOCK);
+       new_entry = dax_radix_locked_entry(sector, flags);
+
        if (hole_fill) {
                __delete_from_page_cache(entry, NULL);
                /* Drop pagecache reference */
                put_page(entry);
-               error = radix_tree_insert(page_tree, index, new_entry);
+               error = __radix_tree_insert(page_tree, index,
+                               dax_radix_order(new_entry), new_entry);
                if (error) {
                        new_entry = ERR_PTR(error);
                        goto unlock;
                }
                mapping->nrexceptional++;
-       } else {
+       } else if (dax_is_zero_entry(entry) || dax_is_empty_entry(entry)) {
+               /*
+                * Only swap our new entry into the radix tree if the current
+                * entry is a zero page or an empty entry.  If a normal PTE or
+                * PMD entry is already in the tree, we leave it alone.  This
+                * means that if we are trying to insert a PTE and the
+                * existing entry is a PMD, we will just leave the PMD in the
+                * tree and dirty it if necessary.
+                */
                struct radix_tree_node *node;
                void **slot;
                void *ret;
@@ -674,7 +619,6 @@ static int dax_writeback_one(struct block_device *bdev,
                struct address_space *mapping, pgoff_t index, void *entry)
 {
        struct radix_tree_root *page_tree = &mapping->page_tree;
-       int type = RADIX_DAX_TYPE(entry);
        struct radix_tree_node *node;
        struct blk_dax_ctl dax;
        void **slot;
@@ -695,13 +639,21 @@ static int dax_writeback_one(struct block_device *bdev,
        if (!radix_tree_tag_get(page_tree, index, PAGECACHE_TAG_TOWRITE))
                goto unlock;
 
-       if (WARN_ON_ONCE(type != RADIX_DAX_PTE && type != RADIX_DAX_PMD)) {
+       if (WARN_ON_ONCE(dax_is_empty_entry(entry) ||
+                               dax_is_zero_entry(entry))) {
                ret = -EIO;
                goto unlock;
        }
 
-       dax.sector = RADIX_DAX_SECTOR(entry);
-       dax.size = (type == RADIX_DAX_PMD ? PMD_SIZE : PAGE_SIZE);
+       /*
+        * Even if dax_writeback_mapping_range() was given a wbc->range_start
+        * in the middle of a PMD, the 'index' we are given will be aligned to
+        * the start index of the PMD, as will the sector we pull from
+        * 'entry'.  This allows us to flush for PMD_SIZE and not have to
+        * worry about partial PMD writebacks.
+        */
+       dax.sector = dax_radix_sector(entry);
+       dax.size = PAGE_SIZE << dax_radix_order(entry);
        spin_unlock_irq(&mapping->tree_lock);
 
        /*
@@ -740,12 +692,11 @@ int dax_writeback_mapping_range(struct address_space *mapping,
                struct block_device *bdev, struct writeback_control *wbc)
 {
        struct inode *inode = mapping->host;
-       pgoff_t start_index, end_index, pmd_index;
+       pgoff_t start_index, end_index;
        pgoff_t indices[PAGEVEC_SIZE];
        struct pagevec pvec;
        bool done = false;
        int i, ret = 0;
-       void *entry;
 
        if (WARN_ON_ONCE(inode->i_blkbits != PAGE_SHIFT))
                return -EIO;
@@ -755,15 +706,6 @@ int dax_writeback_mapping_range(struct address_space *mapping,
 
        start_index = wbc->range_start >> PAGE_SHIFT;
        end_index = wbc->range_end >> PAGE_SHIFT;
-       pmd_index = DAX_PMD_INDEX(start_index);
-
-       rcu_read_lock();
-       entry = radix_tree_lookup(&mapping->page_tree, pmd_index);
-       rcu_read_unlock();
-
-       /* see if the start of our range is covered by a PMD entry */
-       if (entry && RADIX_DAX_TYPE(entry) == RADIX_DAX_PMD)
-               start_index = pmd_index;
 
        tag_pages_for_writeback(mapping, start_index, end_index);
 
@@ -808,7 +750,7 @@ static int dax_insert_mapping(struct address_space *mapping,
                return PTR_ERR(dax.addr);
        dax_unmap_atomic(bdev, &dax);
 
-       ret = dax_insert_mapping_entry(mapping, vmf, entry, dax.sector);
+       ret = dax_insert_mapping_entry(mapping, vmf, entry, dax.sector, 0);
        if (IS_ERR(ret))
                return PTR_ERR(ret);
        *entryp = ret;
@@ -816,323 +758,6 @@ static int dax_insert_mapping(struct address_space *mapping,
        return vm_insert_mixed(vma, vaddr, dax.pfn);
 }
 
-/**
- * dax_fault - handle a page fault on a DAX file
- * @vma: The virtual memory area where the fault occurred
- * @vmf: The description of the fault
- * @get_block: The filesystem method used to translate file offsets to blocks
- *
- * When a page fault occurs, filesystems may call this helper in their
- * fault handler for DAX files. dax_fault() assumes the caller has done all
- * the necessary locking for the page fault to proceed successfully.
- */
-int dax_fault(struct vm_area_struct *vma, struct vm_fault *vmf,
-                       get_block_t get_block)
-{
-       struct file *file = vma->vm_file;
-       struct address_space *mapping = file->f_mapping;
-       struct inode *inode = mapping->host;
-       void *entry;
-       struct buffer_head bh;
-       unsigned long vaddr = (unsigned long)vmf->virtual_address;
-       unsigned blkbits = inode->i_blkbits;
-       sector_t block;
-       pgoff_t size;
-       int error;
-       int major = 0;
-
-       /*
-        * Check whether offset isn't beyond end of file now. Caller is supposed
-        * to hold locks serializing us with truncate / punch hole so this is
-        * a reliable test.
-        */
-       size = (i_size_read(inode) + PAGE_SIZE - 1) >> PAGE_SHIFT;
-       if (vmf->pgoff >= size)
-               return VM_FAULT_SIGBUS;
-
-       memset(&bh, 0, sizeof(bh));
-       block = (sector_t)vmf->pgoff << (PAGE_SHIFT - blkbits);
-       bh.b_bdev = inode->i_sb->s_bdev;
-       bh.b_size = PAGE_SIZE;
-
-       entry = grab_mapping_entry(mapping, vmf->pgoff);
-       if (IS_ERR(entry)) {
-               error = PTR_ERR(entry);
-               goto out;
-       }
-
-       error = get_block(inode, block, &bh, 0);
-       if (!error && (bh.b_size < PAGE_SIZE))
-               error = -EIO;           /* fs corruption? */
-       if (error)
-               goto unlock_entry;
-
-       if (vmf->cow_page) {
-               struct page *new_page = vmf->cow_page;
-               if (buffer_written(&bh))
-                       error = copy_user_dax(bh.b_bdev, to_sector(&bh, inode),
-                                       bh.b_size, new_page, vaddr);
-               else
-                       clear_user_highpage(new_page, vaddr);
-               if (error)
-                       goto unlock_entry;
-               if (!radix_tree_exceptional_entry(entry)) {
-                       vmf->page = entry;
-                       return VM_FAULT_LOCKED;
-               }
-               vmf->entry = entry;
-               return VM_FAULT_DAX_LOCKED;
-       }
-
-       if (!buffer_mapped(&bh)) {
-               if (vmf->flags & FAULT_FLAG_WRITE) {
-                       error = get_block(inode, block, &bh, 1);
-                       count_vm_event(PGMAJFAULT);
-                       mem_cgroup_count_vm_event(vma->vm_mm, PGMAJFAULT);
-                       major = VM_FAULT_MAJOR;
-                       if (!error && (bh.b_size < PAGE_SIZE))
-                               error = -EIO;
-                       if (error)
-                               goto unlock_entry;
-               } else {
-                       return dax_load_hole(mapping, entry, vmf);
-               }
-       }
-
-       /* Filesystem should not return unwritten buffers to us! */
-       WARN_ON_ONCE(buffer_unwritten(&bh) || buffer_new(&bh));
-       error = dax_insert_mapping(mapping, bh.b_bdev, to_sector(&bh, inode),
-                       bh.b_size, &entry, vma, vmf);
- unlock_entry:
-       put_locked_mapping_entry(mapping, vmf->pgoff, entry);
- out:
-       if (error == -ENOMEM)
-               return VM_FAULT_OOM | major;
-       /* -EBUSY is fine, somebody else faulted on the same PTE */
-       if ((error < 0) && (error != -EBUSY))
-               return VM_FAULT_SIGBUS | major;
-       return VM_FAULT_NOPAGE | major;
-}
-EXPORT_SYMBOL_GPL(dax_fault);
-
-#if defined(CONFIG_TRANSPARENT_HUGEPAGE)
-/*
- * The 'colour' (ie low bits) within a PMD of a page offset.  This comes up
- * more often than one might expect in the below function.
- */
-#define PG_PMD_COLOUR  ((PMD_SIZE >> PAGE_SHIFT) - 1)
-
-static void __dax_dbg(struct buffer_head *bh, unsigned long address,
-               const char *reason, const char *fn)
-{
-       if (bh) {
-               char bname[BDEVNAME_SIZE];
-               bdevname(bh->b_bdev, bname);
-               pr_debug("%s: %s addr: %lx dev %s state %lx start %lld "
-                       "length %zd fallback: %s\n", fn, current->comm,
-                       address, bname, bh->b_state, (u64)bh->b_blocknr,
-                       bh->b_size, reason);
-       } else {
-               pr_debug("%s: %s addr: %lx fallback: %s\n", fn,
-                       current->comm, address, reason);
-       }
-}
-
-#define dax_pmd_dbg(bh, address, reason)       __dax_dbg(bh, address, reason, "dax_pmd")
-
-/**
- * dax_pmd_fault - handle a PMD fault on a DAX file
- * @vma: The virtual memory area where the fault occurred
- * @vmf: The description of the fault
- * @get_block: The filesystem method used to translate file offsets to blocks
- *
- * When a page fault occurs, filesystems may call this helper in their
- * pmd_fault handler for DAX files.
- */
-int dax_pmd_fault(struct vm_area_struct *vma, unsigned long address,
-               pmd_t *pmd, unsigned int flags, get_block_t get_block)
-{
-       struct file *file = vma->vm_file;
-       struct address_space *mapping = file->f_mapping;
-       struct inode *inode = mapping->host;
-       struct buffer_head bh;
-       unsigned blkbits = inode->i_blkbits;
-       unsigned long pmd_addr = address & PMD_MASK;
-       bool write = flags & FAULT_FLAG_WRITE;
-       struct block_device *bdev;
-       pgoff_t size, pgoff;
-       sector_t block;
-       int result = 0;
-       bool alloc = false;
-
-       /* dax pmd mappings require pfn_t_devmap() */
-       if (!IS_ENABLED(CONFIG_FS_DAX_PMD))
-               return VM_FAULT_FALLBACK;
-
-       /* Fall back to PTEs if we're going to COW */
-       if (write && !(vma->vm_flags & VM_SHARED)) {
-               split_huge_pmd(vma, pmd, address);
-               dax_pmd_dbg(NULL, address, "cow write");
-               return VM_FAULT_FALLBACK;
-       }
-       /* If the PMD would extend outside the VMA */
-       if (pmd_addr < vma->vm_start) {
-               dax_pmd_dbg(NULL, address, "vma start unaligned");
-               return VM_FAULT_FALLBACK;
-       }
-       if ((pmd_addr + PMD_SIZE) > vma->vm_end) {
-               dax_pmd_dbg(NULL, address, "vma end unaligned");
-               return VM_FAULT_FALLBACK;
-       }
-
-       pgoff = linear_page_index(vma, pmd_addr);
-       size = (i_size_read(inode) + PAGE_SIZE - 1) >> PAGE_SHIFT;
-       if (pgoff >= size)
-               return VM_FAULT_SIGBUS;
-       /* If the PMD would cover blocks out of the file */
-       if ((pgoff | PG_PMD_COLOUR) >= size) {
-               dax_pmd_dbg(NULL, address,
-                               "offset + huge page size > file size");
-               return VM_FAULT_FALLBACK;
-       }
-
-       memset(&bh, 0, sizeof(bh));
-       bh.b_bdev = inode->i_sb->s_bdev;
-       block = (sector_t)pgoff << (PAGE_SHIFT - blkbits);
-
-       bh.b_size = PMD_SIZE;
-
-       if (get_block(inode, block, &bh, 0) != 0)
-               return VM_FAULT_SIGBUS;
-
-       if (!buffer_mapped(&bh) && write) {
-               if (get_block(inode, block, &bh, 1) != 0)
-                       return VM_FAULT_SIGBUS;
-               alloc = true;
-               WARN_ON_ONCE(buffer_unwritten(&bh) || buffer_new(&bh));
-       }
-
-       bdev = bh.b_bdev;
-
-       /*
-        * If the filesystem isn't willing to tell us the length of a hole,
-        * just fall back to PTEs.  Calling get_block 512 times in a loop
-        * would be silly.
-        */
-       if (!buffer_size_valid(&bh) || bh.b_size < PMD_SIZE) {
-               dax_pmd_dbg(&bh, address, "allocated block too small");
-               return VM_FAULT_FALLBACK;
-       }
-
-       /*
-        * If we allocated new storage, make sure no process has any
-        * zero pages covering this hole
-        */
-       if (alloc) {
-               loff_t lstart = pgoff << PAGE_SHIFT;
-               loff_t lend = lstart + PMD_SIZE - 1; /* inclusive */
-
-               truncate_pagecache_range(inode, lstart, lend);
-       }
-
-       if (!write && !buffer_mapped(&bh)) {
-               spinlock_t *ptl;
-               pmd_t entry;
-               struct page *zero_page = mm_get_huge_zero_page(vma->vm_mm);
-
-               if (unlikely(!zero_page)) {
-                       dax_pmd_dbg(&bh, address, "no zero page");
-                       goto fallback;
-               }
-
-               ptl = pmd_lock(vma->vm_mm, pmd);
-               if (!pmd_none(*pmd)) {
-                       spin_unlock(ptl);
-                       dax_pmd_dbg(&bh, address, "pmd already present");
-                       goto fallback;
-               }
-
-               dev_dbg(part_to_dev(bdev->bd_part),
-                               "%s: %s addr: %lx pfn: <zero> sect: %llx\n",
-                               __func__, current->comm, address,
-                               (unsigned long long) to_sector(&bh, inode));
-
-               entry = mk_pmd(zero_page, vma->vm_page_prot);
-               entry = pmd_mkhuge(entry);
-               set_pmd_at(vma->vm_mm, pmd_addr, pmd, entry);
-               result = VM_FAULT_NOPAGE;
-               spin_unlock(ptl);
-       } else {
-               struct blk_dax_ctl dax = {
-                       .sector = to_sector(&bh, inode),
-                       .size = PMD_SIZE,
-               };
-               long length = dax_map_atomic(bdev, &dax);
-
-               if (length < 0) {
-                       dax_pmd_dbg(&bh, address, "dax-error fallback");
-                       goto fallback;
-               }
-               if (length < PMD_SIZE) {
-                       dax_pmd_dbg(&bh, address, "dax-length too small");
-                       dax_unmap_atomic(bdev, &dax);
-                       goto fallback;
-               }
-               if (pfn_t_to_pfn(dax.pfn) & PG_PMD_COLOUR) {
-                       dax_pmd_dbg(&bh, address, "pfn unaligned");
-                       dax_unmap_atomic(bdev, &dax);
-                       goto fallback;
-               }
-
-               if (!pfn_t_devmap(dax.pfn)) {
-                       dax_unmap_atomic(bdev, &dax);
-                       dax_pmd_dbg(&bh, address, "pfn not in memmap");
-                       goto fallback;
-               }
-               dax_unmap_atomic(bdev, &dax);
-
-               /*
-                * For PTE faults we insert a radix tree entry for reads, and
-                * leave it clean.  Then on the first write we dirty the radix
-                * tree entry via the dax_pfn_mkwrite() path.  This sequence
-                * allows the dax_pfn_mkwrite() call to be simpler and avoid a
-                * call into get_block() to translate the pgoff to a sector in
-                * order to be able to create a new radix tree entry.
-                *
-                * The PMD path doesn't have an equivalent to
-                * dax_pfn_mkwrite(), though, so for a read followed by a
-                * write we traverse all the way through dax_pmd_fault()
-                * twice.  This means we can just skip inserting a radix tree
-                * entry completely on the initial read and just wait until
-                * the write to insert a dirty entry.
-                */
-               if (write) {
-                       /*
-                        * We should insert radix-tree entry and dirty it here.
-                        * For now this is broken...
-                        */
-               }
-
-               dev_dbg(part_to_dev(bdev->bd_part),
-                               "%s: %s addr: %lx pfn: %lx sect: %llx\n",
-                               __func__, current->comm, address,
-                               pfn_t_to_pfn(dax.pfn),
-                               (unsigned long long) dax.sector);
-               result |= vmf_insert_pfn_pmd(vma, address, pmd,
-                               dax.pfn, write);
-       }
-
- out:
-       return result;
-
- fallback:
-       count_vm_event(THP_FAULT_FALLBACK);
-       result = VM_FAULT_FALLBACK;
-       goto out;
-}
-EXPORT_SYMBOL_GPL(dax_pmd_fault);
-#endif /* CONFIG_TRANSPARENT_HUGEPAGE */
-
 /**
  * dax_pfn_mkwrite - handle first write to DAX page
  * @vma: The virtual memory area where the fault occurred
@@ -1193,62 +818,14 @@ int __dax_zero_page_range(struct block_device *bdev, sector_t sector,
 }
 EXPORT_SYMBOL_GPL(__dax_zero_page_range);
 
-/**
- * dax_zero_page_range - zero a range within a page of a DAX file
- * @inode: The file being truncated
- * @from: The file offset that is being truncated to
- * @length: The number of bytes to zero
- * @get_block: The filesystem method used to translate file offsets to blocks
- *
- * This function can be called by a filesystem when it is zeroing part of a
- * page in a DAX file.  This is intended for hole-punch operations.  If
- * you are truncating a file, the helper function dax_truncate_page() may be
- * more convenient.
- */
-int dax_zero_page_range(struct inode *inode, loff_t from, unsigned length,
-                                                       get_block_t get_block)
-{
-       struct buffer_head bh;
-       pgoff_t index = from >> PAGE_SHIFT;
-       unsigned offset = from & (PAGE_SIZE-1);
-       int err;
-
-       /* Block boundary? Nothing to do */
-       if (!length)
-               return 0;
-       BUG_ON((offset + length) > PAGE_SIZE);
-
-       memset(&bh, 0, sizeof(bh));
-       bh.b_bdev = inode->i_sb->s_bdev;
-       bh.b_size = PAGE_SIZE;
-       err = get_block(inode, index, &bh, 0);
-       if (err < 0 || !buffer_written(&bh))
-               return err;
-
-       return __dax_zero_page_range(bh.b_bdev, to_sector(&bh, inode),
-                       offset, length);
-}
-EXPORT_SYMBOL_GPL(dax_zero_page_range);
-
-/**
- * dax_truncate_page - handle a partial page being truncated in a DAX file
- * @inode: The file being truncated
- * @from: The file offset that is being truncated to
- * @get_block: The filesystem method used to translate file offsets to blocks
- *
- * Similar to block_truncate_page(), this function can be called by a
- * filesystem when it is truncating a DAX file to handle the partial page.
- */
-int dax_truncate_page(struct inode *inode, loff_t from, get_block_t get_block)
+#ifdef CONFIG_FS_IOMAP
+static sector_t dax_iomap_sector(struct iomap *iomap, loff_t pos)
 {
-       unsigned length = PAGE_ALIGN(from) - from;
-       return dax_zero_page_range(inode, from, length, get_block);
+       return iomap->blkno + (((pos & PAGE_MASK) - iomap->offset) >> 9);
 }
-EXPORT_SYMBOL_GPL(dax_truncate_page);
 
-#ifdef CONFIG_FS_IOMAP
 static loff_t
-iomap_dax_actor(struct inode *inode, loff_t pos, loff_t length, void *data,
+dax_iomap_actor(struct inode *inode, loff_t pos, loff_t length, void *data,
                struct iomap *iomap)
 {
        struct iov_iter *iter = data;
@@ -1272,8 +849,7 @@ iomap_dax_actor(struct inode *inode, loff_t pos, loff_t length, void *data,
                struct blk_dax_ctl dax = { 0 };
                ssize_t map_len;
 
-               dax.sector = iomap->blkno +
-                       (((pos & PAGE_MASK) - iomap->offset) >> 9);
+               dax.sector = dax_iomap_sector(iomap, pos);
                dax.size = (length + offset + PAGE_SIZE - 1) & PAGE_MASK;
                map_len = dax_map_atomic(iomap->bdev, &dax);
                if (map_len < 0) {
@@ -1305,7 +881,7 @@ iomap_dax_actor(struct inode *inode, loff_t pos, loff_t length, void *data,
 }
 
 /**
- * iomap_dax_rw - Perform I/O to a DAX file
+ * dax_iomap_rw - Perform I/O to a DAX file
  * @iocb:      The control block for this I/O
  * @iter:      The addresses to do I/O from or to
  * @ops:       iomap ops passed from the file system
@@ -1315,7 +891,7 @@ iomap_dax_actor(struct inode *inode, loff_t pos, loff_t length, void *data,
  * and evicting any page cache pages in the region under I/O.
  */
 ssize_t
-iomap_dax_rw(struct kiocb *iocb, struct iov_iter *iter,
+dax_iomap_rw(struct kiocb *iocb, struct iov_iter *iter,
                struct iomap_ops *ops)
 {
        struct address_space *mapping = iocb->ki_filp->f_mapping;
@@ -1345,7 +921,7 @@ iomap_dax_rw(struct kiocb *iocb, struct iov_iter *iter,
 
        while (iov_iter_count(iter)) {
                ret = iomap_apply(inode, pos, iov_iter_count(iter), flags, ops,
-                               iter, iomap_dax_actor);
+                               iter, dax_iomap_actor);
                if (ret <= 0)
                        break;
                pos += ret;
@@ -1355,10 +931,10 @@ iomap_dax_rw(struct kiocb *iocb, struct iov_iter *iter,
        iocb->ki_pos += done;
        return done ? done : ret;
 }
-EXPORT_SYMBOL_GPL(iomap_dax_rw);
+EXPORT_SYMBOL_GPL(dax_iomap_rw);
 
 /**
- * iomap_dax_fault - handle a page fault on a DAX file
+ * dax_iomap_fault - handle a page fault on a DAX file
  * @vma: The virtual memory area where the fault occurred
  * @vmf: The description of the fault
  * @ops: iomap ops passed from the file system
@@ -1367,7 +943,7 @@ EXPORT_SYMBOL_GPL(iomap_dax_rw);
  * or mkwrite handler for DAX files. Assumes the caller has done all the
  * necessary locking for the page fault to proceed successfully.
  */
-int iomap_dax_fault(struct vm_area_struct *vma, struct vm_fault *vmf,
+int dax_iomap_fault(struct vm_area_struct *vma, struct vm_fault *vmf,
                        struct iomap_ops *ops)
 {
        struct address_space *mapping = vma->vm_file->f_mapping;
@@ -1376,8 +952,9 @@ int iomap_dax_fault(struct vm_area_struct *vma, struct vm_fault *vmf,
        loff_t pos = (loff_t)vmf->pgoff << PAGE_SHIFT;
        sector_t sector;
        struct iomap iomap = { 0 };
-       unsigned flags = 0;
+       unsigned flags = IOMAP_FAULT;
        int error, major = 0;
+       int locked_status = 0;
        void *entry;
 
        /*
@@ -1388,7 +965,7 @@ int iomap_dax_fault(struct vm_area_struct *vma, struct vm_fault *vmf,
        if (pos >= i_size_read(inode))
                return VM_FAULT_SIGBUS;
 
-       entry = grab_mapping_entry(mapping, vmf->pgoff);
+       entry = grab_mapping_entry(mapping, vmf->pgoff, 0);
        if (IS_ERR(entry)) {
                error = PTR_ERR(entry);
                goto out;
@@ -1407,10 +984,10 @@ int iomap_dax_fault(struct vm_area_struct *vma, struct vm_fault *vmf,
                goto unlock_entry;
        if (WARN_ON_ONCE(iomap.offset + iomap.length < pos + PAGE_SIZE)) {
                error = -EIO;           /* fs corruption? */
-               goto unlock_entry;
+               goto finish_iomap;
        }
 
-       sector = iomap.blkno + (((pos & PAGE_MASK) - iomap.offset) >> 9);
+       sector = dax_iomap_sector(&iomap, pos);
 
        if (vmf->cow_page) {
                switch (iomap.type) {
@@ -1429,13 +1006,15 @@ int iomap_dax_fault(struct vm_area_struct *vma, struct vm_fault *vmf,
                }
 
                if (error)
-                       goto unlock_entry;
+                       goto finish_iomap;
                if (!radix_tree_exceptional_entry(entry)) {
                        vmf->page = entry;
-                       return VM_FAULT_LOCKED;
+                       locked_status = VM_FAULT_LOCKED;
+               } else {
+                       vmf->entry = entry;
+                       locked_status = VM_FAULT_DAX_LOCKED;
                }
-               vmf->entry = entry;
-               return VM_FAULT_DAX_LOCKED;
+               goto finish_iomap;
        }
 
        switch (iomap.type) {
@@ -1450,8 +1029,10 @@ int iomap_dax_fault(struct vm_area_struct *vma, struct vm_fault *vmf,
                break;
        case IOMAP_UNWRITTEN:
        case IOMAP_HOLE:
-               if (!(vmf->flags & FAULT_FLAG_WRITE))
-                       return dax_load_hole(mapping, entry, vmf);
+               if (!(vmf->flags & FAULT_FLAG_WRITE)) {
+                       locked_status = dax_load_hole(mapping, entry, vmf);
+                       break;
+               }
                /*FALLTHRU*/
        default:
                WARN_ON_ONCE(1);
@@ -1459,15 +1040,218 @@ int iomap_dax_fault(struct vm_area_struct *vma, struct vm_fault *vmf,
                break;
        }
 
+ finish_iomap:
+       if (ops->iomap_end) {
+               if (error) {
+                       /* keep previous error */
+                       ops->iomap_end(inode, pos, PAGE_SIZE, 0, flags,
+                                       &iomap);
+               } else {
+                       error = ops->iomap_end(inode, pos, PAGE_SIZE,
+                                       PAGE_SIZE, flags, &iomap);
+               }
+       }
  unlock_entry:
-       put_locked_mapping_entry(mapping, vmf->pgoff, entry);
+       if (!locked_status || error)
+               put_locked_mapping_entry(mapping, vmf->pgoff, entry);
  out:
        if (error == -ENOMEM)
                return VM_FAULT_OOM | major;
        /* -EBUSY is fine, somebody else faulted on the same PTE */
        if (error < 0 && error != -EBUSY)
                return VM_FAULT_SIGBUS | major;
+       if (locked_status) {
+               WARN_ON_ONCE(error); /* -EBUSY from ops->iomap_end? */
+               return locked_status;
+       }
        return VM_FAULT_NOPAGE | major;
 }
-EXPORT_SYMBOL_GPL(iomap_dax_fault);
+EXPORT_SYMBOL_GPL(dax_iomap_fault);
+
+#ifdef CONFIG_FS_DAX_PMD
+/*
+ * The 'colour' (ie low bits) within a PMD of a page offset.  This comes up
+ * more often than one might expect in the below functions.
+ */
+#define PG_PMD_COLOUR  ((PMD_SIZE >> PAGE_SHIFT) - 1)
+
+static int dax_pmd_insert_mapping(struct vm_area_struct *vma, pmd_t *pmd,
+               struct vm_fault *vmf, unsigned long address,
+               struct iomap *iomap, loff_t pos, bool write, void **entryp)
+{
+       struct address_space *mapping = vma->vm_file->f_mapping;
+       struct block_device *bdev = iomap->bdev;
+       struct blk_dax_ctl dax = {
+               .sector = dax_iomap_sector(iomap, pos),
+               .size = PMD_SIZE,
+       };
+       long length = dax_map_atomic(bdev, &dax);
+       void *ret;
+
+       if (length < 0) /* dax_map_atomic() failed */
+               return VM_FAULT_FALLBACK;
+       if (length < PMD_SIZE)
+               goto unmap_fallback;
+       if (pfn_t_to_pfn(dax.pfn) & PG_PMD_COLOUR)
+               goto unmap_fallback;
+       if (!pfn_t_devmap(dax.pfn))
+               goto unmap_fallback;
+
+       dax_unmap_atomic(bdev, &dax);
+
+       ret = dax_insert_mapping_entry(mapping, vmf, *entryp, dax.sector,
+                       RADIX_DAX_PMD);
+       if (IS_ERR(ret))
+               return VM_FAULT_FALLBACK;
+       *entryp = ret;
+
+       return vmf_insert_pfn_pmd(vma, address, pmd, dax.pfn, write);
+
+ unmap_fallback:
+       dax_unmap_atomic(bdev, &dax);
+       return VM_FAULT_FALLBACK;
+}
+
+static int dax_pmd_load_hole(struct vm_area_struct *vma, pmd_t *pmd,
+               struct vm_fault *vmf, unsigned long address,
+               struct iomap *iomap, void **entryp)
+{
+       struct address_space *mapping = vma->vm_file->f_mapping;
+       unsigned long pmd_addr = address & PMD_MASK;
+       struct page *zero_page;
+       spinlock_t *ptl;
+       pmd_t pmd_entry;
+       void *ret;
+
+       zero_page = mm_get_huge_zero_page(vma->vm_mm);
+
+       if (unlikely(!zero_page))
+               return VM_FAULT_FALLBACK;
+
+       ret = dax_insert_mapping_entry(mapping, vmf, *entryp, 0,
+                       RADIX_DAX_PMD | RADIX_DAX_HZP);
+       if (IS_ERR(ret))
+               return VM_FAULT_FALLBACK;
+       *entryp = ret;
+
+       ptl = pmd_lock(vma->vm_mm, pmd);
+       if (!pmd_none(*pmd)) {
+               spin_unlock(ptl);
+               return VM_FAULT_FALLBACK;
+       }
+
+       pmd_entry = mk_pmd(zero_page, vma->vm_page_prot);
+       pmd_entry = pmd_mkhuge(pmd_entry);
+       set_pmd_at(vma->vm_mm, pmd_addr, pmd, pmd_entry);
+       spin_unlock(ptl);
+       return VM_FAULT_NOPAGE;
+}
+
+int dax_iomap_pmd_fault(struct vm_area_struct *vma, unsigned long address,
+               pmd_t *pmd, unsigned int flags, struct iomap_ops *ops)
+{
+       struct address_space *mapping = vma->vm_file->f_mapping;
+       unsigned long pmd_addr = address & PMD_MASK;
+       bool write = flags & FAULT_FLAG_WRITE;
+       unsigned int iomap_flags = (write ? IOMAP_WRITE : 0) | IOMAP_FAULT;
+       struct inode *inode = mapping->host;
+       int result = VM_FAULT_FALLBACK;
+       struct iomap iomap = { 0 };
+       pgoff_t max_pgoff, pgoff;
+       struct vm_fault vmf;
+       void *entry;
+       loff_t pos;
+       int error;
+
+       /* Fall back to PTEs if we're going to COW */
+       if (write && !(vma->vm_flags & VM_SHARED))
+               goto fallback;
+
+       /* If the PMD would extend outside the VMA */
+       if (pmd_addr < vma->vm_start)
+               goto fallback;
+       if ((pmd_addr + PMD_SIZE) > vma->vm_end)
+               goto fallback;
+
+       /*
+        * Check whether offset isn't beyond end of file now. Caller is
+        * supposed to hold locks serializing us with truncate / punch hole so
+        * this is a reliable test.
+        */
+       pgoff = linear_page_index(vma, pmd_addr);
+       max_pgoff = (i_size_read(inode) - 1) >> PAGE_SHIFT;
+
+       if (pgoff > max_pgoff)
+               return VM_FAULT_SIGBUS;
+
+       /* If the PMD would extend beyond the file size */
+       if ((pgoff | PG_PMD_COLOUR) > max_pgoff)
+               goto fallback;
+
+       /*
+        * grab_mapping_entry() will make sure we get a 2M empty entry, a DAX
+        * PMD or a HZP entry.  If it can't (because a 4k page is already in
+        * the tree, for instance), it will return -EEXIST and we just fall
+        * back to 4k entries.
+        */
+       entry = grab_mapping_entry(mapping, pgoff, RADIX_DAX_PMD);
+       if (IS_ERR(entry))
+               goto fallback;
+
+       /*
+        * Note that we don't use iomap_apply here.  We aren't doing I/O, only
+        * setting up a mapping, so really we're using iomap_begin() as a way
+        * to look up our filesystem block.
+        */
+       pos = (loff_t)pgoff << PAGE_SHIFT;
+       error = ops->iomap_begin(inode, pos, PMD_SIZE, iomap_flags, &iomap);
+       if (error)
+               goto unlock_entry;
+       if (iomap.offset + iomap.length < pos + PMD_SIZE)
+               goto finish_iomap;
+
+       vmf.pgoff = pgoff;
+       vmf.flags = flags;
+       vmf.gfp_mask = mapping_gfp_mask(mapping) | __GFP_IO;
+
+       switch (iomap.type) {
+       case IOMAP_MAPPED:
+               result = dax_pmd_insert_mapping(vma, pmd, &vmf, address,
+                               &iomap, pos, write, &entry);
+               break;
+       case IOMAP_UNWRITTEN:
+       case IOMAP_HOLE:
+               if (WARN_ON_ONCE(write))
+                       goto finish_iomap;
+               result = dax_pmd_load_hole(vma, pmd, &vmf, address, &iomap,
+                               &entry);
+               break;
+       default:
+               WARN_ON_ONCE(1);
+               break;
+       }
+
+ finish_iomap:
+       if (ops->iomap_end) {
+               if (result == VM_FAULT_FALLBACK) {
+                       ops->iomap_end(inode, pos, PMD_SIZE, 0, iomap_flags,
+                                       &iomap);
+               } else {
+                       error = ops->iomap_end(inode, pos, PMD_SIZE, PMD_SIZE,
+                                       iomap_flags, &iomap);
+                       if (error)
+                               result = VM_FAULT_FALLBACK;
+               }
+       }
+ unlock_entry:
+       put_locked_mapping_entry(mapping, pgoff, entry);
+ fallback:
+       if (result == VM_FAULT_FALLBACK) {
+               split_huge_pmd(vma, pmd, address);
+               count_vm_event(THP_FAULT_FALLBACK);
+       }
+       return result;
+}
+EXPORT_SYMBOL_GPL(dax_iomap_pmd_fault);
+#endif /* CONFIG_FS_DAX_PMD */
 #endif /* CONFIG_FS_IOMAP */