]> git.proxmox.com Git - mirror_ubuntu-hirsute-kernel.git/blobdiff - fs/io_uring.c
io_uring: fix ltout double free on completion race
[mirror_ubuntu-hirsute-kernel.git] / fs / io_uring.c
index 38c6cbe1ab387d8d50560dd71e9f6008d6566764..2b86b413641a490b610f4123943325939a4f5276 100644 (file)
@@ -222,7 +222,7 @@ struct fixed_file_data {
 struct io_buffer {
        struct list_head list;
        __u64 addr;
-       __s32 len;
+       __u32 len;
        __u16 bid;
 };
 
@@ -411,7 +411,6 @@ struct io_poll_remove {
 
 struct io_close {
        struct file                     *file;
-       struct file                     *put_file;
        int                             fd;
 };
 
@@ -536,7 +535,7 @@ struct io_splice {
 struct io_provide_buf {
        struct file                     *file;
        __u64                           addr;
-       __s32                           len;
+       __u32                           len;
        __u32                           bgid;
        __u16                           nbufs;
        __u16                           bid;
@@ -575,7 +574,7 @@ struct io_unlink {
 struct io_completion {
        struct file                     *file;
        struct list_head                list;
-       int                             cflags;
+       u32                             cflags;
 };
 
 struct io_async_connect {
@@ -857,7 +856,8 @@ static const struct io_op_def io_op_defs[] = {
                .pollout                = 1,
                .needs_async_data       = 1,
                .async_size             = sizeof(struct io_async_msghdr),
-               .work_flags             = IO_WQ_WORK_MM | IO_WQ_WORK_BLKCG,
+               .work_flags             = IO_WQ_WORK_MM | IO_WQ_WORK_BLKCG |
+                                               IO_WQ_WORK_FS,
        },
        [IORING_OP_RECVMSG] = {
                .needs_file             = 1,
@@ -866,7 +866,8 @@ static const struct io_op_def io_op_defs[] = {
                .buffer_select          = 1,
                .needs_async_data       = 1,
                .async_size             = sizeof(struct io_async_msghdr),
-               .work_flags             = IO_WQ_WORK_MM | IO_WQ_WORK_BLKCG,
+               .work_flags             = IO_WQ_WORK_MM | IO_WQ_WORK_BLKCG |
+                                               IO_WQ_WORK_FS,
        },
        [IORING_OP_TIMEOUT] = {
                .needs_async_data       = 1,
@@ -906,8 +907,6 @@ static const struct io_op_def io_op_defs[] = {
                                                IO_WQ_WORK_FS | IO_WQ_WORK_MM,
        },
        [IORING_OP_CLOSE] = {
-               .needs_file             = 1,
-               .needs_file_no_error    = 1,
                .work_flags             = IO_WQ_WORK_FILES | IO_WQ_WORK_BLKCG,
        },
        [IORING_OP_FILES_UPDATE] = {
@@ -994,9 +993,9 @@ enum io_mem_account {
        ACCT_PINNED,
 };
 
-static void __io_uring_cancel_task_requests(struct io_ring_ctx *ctx,
-                                           struct task_struct *task);
-
+static void io_uring_try_cancel_requests(struct io_ring_ctx *ctx,
+                                        struct task_struct *task,
+                                        struct files_struct *files);
 static void destroy_fixed_file_ref_node(struct fixed_file_ref_node *ref_node);
 static struct fixed_file_ref_node *alloc_fixed_file_ref_node(
                        struct io_ring_ctx *ctx);
@@ -1547,7 +1546,7 @@ static void io_prep_async_work(struct io_kiocb *req)
        if (req->flags & REQ_F_ISREG) {
                if (def->hash_reg_file || (ctx->flags & IORING_SETUP_IOPOLL))
                        io_wq_hash_work(&req->work, file_inode(req->file));
-       } else {
+       } else if (!req->file || !S_ISBLK(file_inode(req->file)->i_mode)) {
                if (def->unbound_nonreg_file)
                        req->work.flags |= IO_WQ_WORK_UNBOUND;
        }
@@ -1595,7 +1594,7 @@ static void io_queue_async_work(struct io_kiocb *req)
                io_queue_linked_timeout(link);
 }
 
-static void io_kill_timeout(struct io_kiocb *req)
+static void io_kill_timeout(struct io_kiocb *req, int status)
 {
        struct io_timeout_data *io = req->async_data;
        int ret;
@@ -1605,7 +1604,7 @@ static void io_kill_timeout(struct io_kiocb *req)
                atomic_set(&req->ctx->cq_timeouts,
                        atomic_read(&req->ctx->cq_timeouts) + 1);
                list_del_init(&req->timeout.list);
-               io_cqring_fill_event(req, 0);
+               io_cqring_fill_event(req, status);
                io_put_req_deferred(req, 1);
        }
 }
@@ -1622,7 +1621,7 @@ static bool io_kill_timeouts(struct io_ring_ctx *ctx, struct task_struct *tsk,
        spin_lock_irq(&ctx->completion_lock);
        list_for_each_entry_safe(req, tmp, &ctx->timeout_list, timeout.list) {
                if (io_match_task(req, tsk, files)) {
-                       io_kill_timeout(req);
+                       io_kill_timeout(req, -ECANCELED);
                        canceled++;
                }
        }
@@ -1674,7 +1673,7 @@ static void io_flush_timeouts(struct io_ring_ctx *ctx)
                        break;
 
                list_del_init(&req->timeout.list);
-               io_kill_timeout(req);
+               io_kill_timeout(req, 0);
        } while (!list_empty(&ctx->timeout_list));
 
        ctx->cq_last_tm_flush = seq;
@@ -1824,21 +1823,26 @@ static bool __io_cqring_overflow_flush(struct io_ring_ctx *ctx, bool force,
        return all_flushed;
 }
 
-static void io_cqring_overflow_flush(struct io_ring_ctx *ctx, bool force,
+static bool io_cqring_overflow_flush(struct io_ring_ctx *ctx, bool force,
                                     struct task_struct *tsk,
                                     struct files_struct *files)
 {
+       bool ret = true;
+
        if (test_bit(0, &ctx->cq_check_overflow)) {
                /* iopoll syncs against uring_lock, not completion_lock */
                if (ctx->flags & IORING_SETUP_IOPOLL)
                        mutex_lock(&ctx->uring_lock);
-               __io_cqring_overflow_flush(ctx, force, tsk, files);
+               ret = __io_cqring_overflow_flush(ctx, force, tsk, files);
                if (ctx->flags & IORING_SETUP_IOPOLL)
                        mutex_unlock(&ctx->uring_lock);
        }
+
+       return ret;
 }
 
-static void __io_cqring_fill_event(struct io_kiocb *req, long res, long cflags)
+static void __io_cqring_fill_event(struct io_kiocb *req, long res,
+                                  unsigned int cflags)
 {
        struct io_ring_ctx *ctx = req->ctx;
        struct io_uring_cqe *cqe;
@@ -2170,6 +2174,16 @@ static int io_req_task_work_add(struct io_kiocb *req)
        return ret;
 }
 
+static void io_req_task_work_add_fallback(struct io_kiocb *req,
+                                         void (*cb)(struct callback_head *))
+{
+       struct task_struct *tsk = io_wq_get_task(req->ctx->io_wq);
+
+       init_task_work(&req->task_work, cb);
+       task_work_add(tsk, &req->task_work, TWA_NONE);
+       wake_up_process(tsk);
+}
+
 static void __io_req_task_cancel(struct io_kiocb *req, int error)
 {
        struct io_ring_ctx *ctx = req->ctx;
@@ -2189,7 +2203,9 @@ static void io_req_task_cancel(struct callback_head *cb)
        struct io_kiocb *req = container_of(cb, struct io_kiocb, task_work);
        struct io_ring_ctx *ctx = req->ctx;
 
+       mutex_lock(&ctx->uring_lock);
        __io_req_task_cancel(req, -ECANCELED);
+       mutex_unlock(&ctx->uring_lock);
        percpu_ref_put(&ctx->refs);
 }
 
@@ -2205,6 +2221,10 @@ static void __io_req_task_submit(struct io_kiocb *req)
        else
                __io_req_task_cancel(req, -EFAULT);
        mutex_unlock(&ctx->uring_lock);
+
+       ctx->flags &= ~IORING_SETUP_R_DISABLED;
+       if (ctx->flags & IORING_SETUP_SQPOLL)
+               io_sq_thread_drop_mm_files();
 }
 
 static void io_req_task_submit(struct callback_head *cb)
@@ -2224,14 +2244,8 @@ static void io_req_task_queue(struct io_kiocb *req)
        percpu_ref_get(&req->ctx->refs);
 
        ret = io_req_task_work_add(req);
-       if (unlikely(ret)) {
-               struct task_struct *tsk;
-
-               init_task_work(&req->task_work, io_req_task_cancel);
-               tsk = io_wq_get_task(req->ctx->io_wq);
-               task_work_add(tsk, &req->task_work, TWA_NONE);
-               wake_up_process(tsk);
-       }
+       if (unlikely(ret))
+               io_req_task_work_add_fallback(req, io_req_task_cancel);
 }
 
 static inline void io_queue_next(struct io_kiocb *req)
@@ -2349,13 +2363,8 @@ static void io_free_req_deferred(struct io_kiocb *req)
 
        init_task_work(&req->task_work, io_put_req_deferred_cb);
        ret = io_req_task_work_add(req);
-       if (unlikely(ret)) {
-               struct task_struct *tsk;
-
-               tsk = io_wq_get_task(req->ctx->io_wq);
-               task_work_add(tsk, &req->task_work, TWA_NONE);
-               wake_up_process(tsk);
-       }
+       if (unlikely(ret))
+               io_req_task_work_add_fallback(req, io_put_req_deferred_cb);
 }
 
 static inline void io_put_req_deferred(struct io_kiocb *req, int refs)
@@ -2364,22 +2373,6 @@ static inline void io_put_req_deferred(struct io_kiocb *req, int refs)
                io_free_req_deferred(req);
 }
 
-static struct io_wq_work *io_steal_work(struct io_kiocb *req)
-{
-       struct io_kiocb *nxt;
-
-       /*
-        * A ref is owned by io-wq in which context we're. So, if that's the
-        * last one, it's safe to steal next work. False negatives are Ok,
-        * it just will be re-punted async in io_put_work()
-        */
-       if (refcount_read(&req->refs) != 1)
-               return NULL;
-
-       nxt = io_req_find_next(req);
-       return nxt ? &nxt->work : NULL;
-}
-
 static void io_double_put_req(struct io_kiocb *req)
 {
        /* drop both submit and complete references */
@@ -2730,6 +2723,13 @@ static bool io_rw_reissue(struct io_kiocb *req, long res)
                return false;
        if ((res != -EAGAIN && res != -EOPNOTSUPP) || io_wq_current_is_worker())
                return false;
+       /*
+        * If ref is dying, we might be running poll reap from the exit work.
+        * Don't attempt to reissue from that path, just let it fail with
+        * -EAGAIN.
+        */
+       if (percpu_ref_is_dying(&req->ctx->refs))
+               return false;
 
        lockdep_assert_held(&req->ctx->uring_lock);
 
@@ -3434,15 +3434,8 @@ static int io_async_buf_func(struct wait_queue_entry *wait, unsigned mode,
        /* submit ref gets dropped, acquire a new one */
        refcount_inc(&req->refs);
        ret = io_req_task_work_add(req);
-       if (unlikely(ret)) {
-               struct task_struct *tsk;
-
-               /* queue just for cancelation */
-               init_task_work(&req->task_work, io_req_task_cancel);
-               tsk = io_wq_get_task(req->ctx->io_wq);
-               task_work_add(tsk, &req->task_work, TWA_NONE);
-               wake_up_process(tsk);
-       }
+       if (unlikely(ret))
+               io_req_task_work_add_fallback(req, io_req_task_cancel);
        return 1;
 }
 
@@ -3527,7 +3520,6 @@ static int io_read(struct io_kiocb *req, bool force_nonblock,
        else
                kiocb->ki_flags |= IOCB_NOWAIT;
 
-
        /* If the file doesn't support async, just async punt */
        no_async = force_nonblock && !io_file_supports_async(req->file, READ);
        if (no_async)
@@ -3539,9 +3531,7 @@ static int io_read(struct io_kiocb *req, bool force_nonblock,
 
        ret = io_iter_do_read(req, iter);
 
-       if (!ret) {
-               goto done;
-       } else if (ret == -EIOCBQUEUED) {
+       if (ret == -EIOCBQUEUED) {
                ret = 0;
                goto out_free;
        } else if (ret == -EAGAIN) {
@@ -3555,7 +3545,7 @@ static int io_read(struct io_kiocb *req, bool force_nonblock,
                iov_iter_revert(iter, io_size - iov_iter_count(iter));
                ret = 0;
                goto copy_iov;
-       } else if (ret < 0) {
+       } else if (ret <= 0) {
                /* make sure -ERESTARTSYS -> -EINTR is done */
                goto done;
        }
@@ -3599,6 +3589,7 @@ retry:
                goto out_free;
        } else if (ret > 0 && ret < io_size) {
                /* we got some bytes, but not all. retry. */
+               kiocb->ki_flags &= ~IOCB_WAITQ;
                goto retry;
        }
 done:
@@ -4224,6 +4215,7 @@ static int io_remove_buffers(struct io_kiocb *req, bool force_nonblock,
 static int io_provide_buffers_prep(struct io_kiocb *req,
                                   const struct io_uring_sqe *sqe)
 {
+       unsigned long size, tmp_check;
        struct io_provide_buf *p = &req->pbuf;
        u64 tmp;
 
@@ -4237,7 +4229,14 @@ static int io_provide_buffers_prep(struct io_kiocb *req,
        p->addr = READ_ONCE(sqe->addr);
        p->len = READ_ONCE(sqe->len);
 
-       if (!access_ok(u64_to_user_ptr(p->addr), (p->len * p->nbufs)))
+       if (check_mul_overflow((unsigned long)p->len, (unsigned long)p->nbufs,
+                               &size))
+               return -EOVERFLOW;
+       if (check_add_overflow((unsigned long)p->addr, size, &tmp_check))
+               return -EOVERFLOW;
+
+       size = (unsigned long)p->len * p->nbufs;
+       if (!access_ok(u64_to_user_ptr(p->addr), size))
                return -EFAULT;
 
        p->bgid = READ_ONCE(sqe->buf_group);
@@ -4260,7 +4259,7 @@ static int io_add_buffers(struct io_provide_buf *pbuf, struct io_buffer **head)
                        break;
 
                buf->addr = addr;
-               buf->len = pbuf->len;
+               buf->len = min_t(__u32, pbuf->len, MAX_RW_COUNT);
                buf->bid = bid;
                addr += pbuf->len;
                bid++;
@@ -4476,13 +4475,6 @@ static int io_statx(struct io_kiocb *req, bool force_nonblock)
 
 static int io_close_prep(struct io_kiocb *req, const struct io_uring_sqe *sqe)
 {
-       /*
-        * If we queue this for async, it must not be cancellable. That would
-        * leave the 'file' in an undeterminate state, and here need to modify
-        * io_wq_work.flags, so initialize io_wq_work firstly.
-        */
-       io_req_init_async(req);
-
        if (unlikely(req->ctx->flags & IORING_SETUP_IOPOLL))
                return -EINVAL;
        if (sqe->ioprio || sqe->off || sqe->addr || sqe->len ||
@@ -4492,43 +4484,59 @@ static int io_close_prep(struct io_kiocb *req, const struct io_uring_sqe *sqe)
                return -EBADF;
 
        req->close.fd = READ_ONCE(sqe->fd);
-       if ((req->file && req->file->f_op == &io_uring_fops))
-               return -EBADF;
-
-       req->close.put_file = NULL;
        return 0;
 }
 
 static int io_close(struct io_kiocb *req, bool force_nonblock,
                    struct io_comp_state *cs)
 {
+       struct files_struct *files = current->files;
        struct io_close *close = &req->close;
+       struct fdtable *fdt;
+       struct file *file;
        int ret;
 
-       /* might be already done during nonblock submission */
-       if (!close->put_file) {
-               ret = close_fd_get_file(close->fd, &close->put_file);
-               if (ret < 0)
-                       return (ret == -ENOENT) ? -EBADF : ret;
+       file = NULL;
+       ret = -EBADF;
+       spin_lock(&files->file_lock);
+       fdt = files_fdtable(files);
+       if (close->fd >= fdt->max_fds) {
+               spin_unlock(&files->file_lock);
+               goto err;
+       }
+       file = fdt->fd[close->fd];
+       if (!file) {
+               spin_unlock(&files->file_lock);
+               goto err;
+       }
+
+       if (file->f_op == &io_uring_fops) {
+               spin_unlock(&files->file_lock);
+               file = NULL;
+               goto err;
        }
 
        /* if the file has a flush method, be safe and punt to async */
-       if (close->put_file->f_op->flush && force_nonblock) {
-               /* not safe to cancel at this point */
-               req->work.flags |= IO_WQ_WORK_NO_CANCEL;
-               /* was never set, but play safe */
-               req->flags &= ~REQ_F_NOWAIT;
-               /* avoid grabbing files - we don't need the files */
-               req->flags |= REQ_F_NO_FILE_TABLE;
+       if (file->f_op->flush && force_nonblock) {
+               spin_unlock(&files->file_lock);
                return -EAGAIN;
        }
 
+       ret = __close_fd_get_file(close->fd, &file);
+       spin_unlock(&files->file_lock);
+       if (ret < 0) {
+               if (ret == -ENOENT)
+                       ret = -EBADF;
+               goto err;
+       }
+
        /* No ->flush() or already async, safely close from here */
-       ret = filp_close(close->put_file, req->work.identity->files);
+       ret = filp_close(file, current->files);
+err:
        if (ret < 0)
                req_set_fail_links(req);
-       fput(close->put_file);
-       close->put_file = NULL;
+       if (file)
+               fput(file);
        __io_req_complete(req, ret, 0, cs);
        return 0;
 }
@@ -4627,6 +4635,7 @@ static int io_sendmsg(struct io_kiocb *req, bool force_nonblock,
        struct io_async_msghdr iomsg, *kmsg;
        struct socket *sock;
        unsigned flags;
+       int min_ret = 0;
        int ret;
 
        sock = sock_from_file(req->file);
@@ -4647,12 +4656,15 @@ static int io_sendmsg(struct io_kiocb *req, bool force_nonblock,
                kmsg = &iomsg;
        }
 
-       flags = req->sr_msg.msg_flags;
+       flags = req->sr_msg.msg_flags | MSG_NOSIGNAL;
        if (flags & MSG_DONTWAIT)
                req->flags |= REQ_F_NOWAIT;
        else if (force_nonblock)
                flags |= MSG_DONTWAIT;
 
+       if (flags & MSG_WAITALL)
+               min_ret = iov_iter_count(&kmsg->msg.msg_iter);
+
        ret = __sys_sendmsg_sock(sock, &kmsg->msg, flags);
        if (force_nonblock && ret == -EAGAIN)
                return io_setup_async_msg(req, kmsg);
@@ -4662,7 +4674,7 @@ static int io_sendmsg(struct io_kiocb *req, bool force_nonblock,
        if (kmsg->iov != kmsg->fast_iov)
                kfree(kmsg->iov);
        req->flags &= ~REQ_F_NEED_CLEANUP;
-       if (ret < 0)
+       if (ret < min_ret)
                req_set_fail_links(req);
        __io_req_complete(req, ret, 0, cs);
        return 0;
@@ -4676,6 +4688,7 @@ static int io_send(struct io_kiocb *req, bool force_nonblock,
        struct iovec iov;
        struct socket *sock;
        unsigned flags;
+       int min_ret = 0;
        int ret;
 
        sock = sock_from_file(req->file);
@@ -4691,12 +4704,15 @@ static int io_send(struct io_kiocb *req, bool force_nonblock,
        msg.msg_controllen = 0;
        msg.msg_namelen = 0;
 
-       flags = req->sr_msg.msg_flags;
+       flags = req->sr_msg.msg_flags | MSG_NOSIGNAL;
        if (flags & MSG_DONTWAIT)
                req->flags |= REQ_F_NOWAIT;
        else if (force_nonblock)
                flags |= MSG_DONTWAIT;
 
+       if (flags & MSG_WAITALL)
+               min_ret = iov_iter_count(&msg.msg_iter);
+
        msg.msg_flags = flags;
        ret = sock_sendmsg(sock, &msg);
        if (force_nonblock && ret == -EAGAIN)
@@ -4704,7 +4720,7 @@ static int io_send(struct io_kiocb *req, bool force_nonblock,
        if (ret == -ERESTARTSYS)
                ret = -EINTR;
 
-       if (ret < 0)
+       if (ret < min_ret)
                req_set_fail_links(req);
        __io_req_complete(req, ret, 0, cs);
        return 0;
@@ -4856,6 +4872,7 @@ static int io_recvmsg(struct io_kiocb *req, bool force_nonblock,
        struct socket *sock;
        struct io_buffer *kbuf;
        unsigned flags;
+       int min_ret = 0;
        int ret, cflags = 0;
 
        sock = sock_from_file(req->file);
@@ -4885,12 +4902,15 @@ static int io_recvmsg(struct io_kiocb *req, bool force_nonblock,
                                1, req->sr_msg.len);
        }
 
-       flags = req->sr_msg.msg_flags;
+       flags = req->sr_msg.msg_flags | MSG_NOSIGNAL;
        if (flags & MSG_DONTWAIT)
                req->flags |= REQ_F_NOWAIT;
        else if (force_nonblock)
                flags |= MSG_DONTWAIT;
 
+       if (flags & MSG_WAITALL)
+               min_ret = iov_iter_count(&kmsg->msg.msg_iter);
+
        ret = __sys_recvmsg_sock(sock, &kmsg->msg, req->sr_msg.umsg,
                                        kmsg->uaddr, flags);
        if (force_nonblock && ret == -EAGAIN)
@@ -4903,7 +4923,7 @@ static int io_recvmsg(struct io_kiocb *req, bool force_nonblock,
        if (kmsg->iov != kmsg->fast_iov)
                kfree(kmsg->iov);
        req->flags &= ~REQ_F_NEED_CLEANUP;
-       if (ret < 0)
+       if (ret < min_ret || ((flags & MSG_WAITALL) && (kmsg->msg.msg_flags & (MSG_TRUNC | MSG_CTRUNC))))
                req_set_fail_links(req);
        __io_req_complete(req, ret, cflags, cs);
        return 0;
@@ -4919,6 +4939,7 @@ static int io_recv(struct io_kiocb *req, bool force_nonblock,
        struct socket *sock;
        struct iovec iov;
        unsigned flags;
+       int min_ret = 0;
        int ret, cflags = 0;
 
        sock = sock_from_file(req->file);
@@ -4943,12 +4964,15 @@ static int io_recv(struct io_kiocb *req, bool force_nonblock,
        msg.msg_iocb = NULL;
        msg.msg_flags = 0;
 
-       flags = req->sr_msg.msg_flags;
+       flags = req->sr_msg.msg_flags | MSG_NOSIGNAL;
        if (flags & MSG_DONTWAIT)
                req->flags |= REQ_F_NOWAIT;
        else if (force_nonblock)
                flags |= MSG_DONTWAIT;
 
+       if (flags & MSG_WAITALL)
+               min_ret = iov_iter_count(&msg.msg_iter);
+
        ret = sock_recvmsg(sock, &msg, flags);
        if (force_nonblock && ret == -EAGAIN)
                return -EAGAIN;
@@ -4957,7 +4981,7 @@ static int io_recv(struct io_kiocb *req, bool force_nonblock,
 out_free:
        if (req->flags & REQ_F_BUFFER_SELECTED)
                cflags = io_put_recv_kbuf(req);
-       if (ret < 0)
+       if (ret < min_ret || ((flags & MSG_WAITALL) && (msg.msg_flags & (MSG_TRUNC | MSG_CTRUNC))))
                req_set_fail_links(req);
        __io_req_complete(req, ret, cflags, cs);
        return 0;
@@ -5154,12 +5178,8 @@ static int __io_async_wake(struct io_kiocb *req, struct io_poll_iocb *poll,
         */
        ret = io_req_task_work_add(req);
        if (unlikely(ret)) {
-               struct task_struct *tsk;
-
                WRITE_ONCE(poll->canceled, true);
-               tsk = io_wq_get_task(req->ctx->io_wq);
-               task_work_add(tsk, &req->task_work, TWA_NONE);
-               wake_up_process(tsk);
+               io_req_task_work_add_fallback(req, func);
        }
        return 1;
 }
@@ -5311,6 +5331,9 @@ static void __io_queue_proc(struct io_poll_iocb *poll, struct io_poll_table *pt,
                        pt->error = -EINVAL;
                        return;
                }
+               /* double add on the same waitqueue head, ignore */
+               if (poll->head == head)
+                       return;
                poll = kmalloc(sizeof(*poll), GFP_ATOMIC);
                if (!poll) {
                        pt->error = -ENOMEM;
@@ -6376,7 +6399,7 @@ static int io_issue_sqe(struct io_kiocb *req, bool force_nonblock,
        return 0;
 }
 
-static struct io_wq_work *io_wq_submit_work(struct io_wq_work *work)
+static void io_wq_submit_work(struct io_wq_work *work)
 {
        struct io_kiocb *req = container_of(work, struct io_kiocb, work);
        struct io_kiocb *timeout;
@@ -6386,10 +6409,12 @@ static struct io_wq_work *io_wq_submit_work(struct io_wq_work *work)
        if (timeout)
                io_queue_linked_timeout(timeout);
 
-       /* if NO_CANCEL is set, we must still run the work */
-       if ((work->flags & (IO_WQ_WORK_CANCEL|IO_WQ_WORK_NO_CANCEL)) ==
-                               IO_WQ_WORK_CANCEL) {
-               ret = -ECANCELED;
+       if (work->flags & IO_WQ_WORK_CANCEL) {
+               /* io-wq is going to take down one */
+               refcount_inc(&req->refs);
+               percpu_ref_get(&req->ctx->refs);
+               io_req_task_work_add_fallback(req, io_req_task_cancel);
+               return;
        }
 
        if (!ret) {
@@ -6430,8 +6455,6 @@ static struct io_wq_work *io_wq_submit_work(struct io_wq_work *work)
                if (lock_ctx)
                        mutex_unlock(&lock_ctx->uring_lock);
        }
-
-       return io_steal_work(req);
 }
 
 static inline struct file *io_file_from_index(struct io_ring_ctx *ctx,
@@ -6489,18 +6512,20 @@ static enum hrtimer_restart io_link_timeout_fn(struct hrtimer *timer)
         * We don't expect the list to be empty, that will only happen if we
         * race with the completion of the linked work.
         */
-       if (prev && refcount_inc_not_zero(&prev->refs))
+       if (prev) {
                io_remove_next_linked(prev);
-       else
-               prev = NULL;
+               if (!refcount_inc_not_zero(&prev->refs))
+                       prev = NULL;
+       }
        spin_unlock_irqrestore(&ctx->completion_lock, flags);
 
        if (prev) {
-               req_set_fail_links(prev);
                io_async_find_and_cancel(ctx, req, prev->user_data, -ETIME);
-               io_put_req(prev);
+               io_put_req_deferred(prev, 1);
+               io_put_req_deferred(req, 1);
        } else {
-               io_req_complete(req, -ETIME);
+               io_cqring_add_event(req, -ETIME, 0);
+               io_put_req_deferred(req, 1);
        }
        return HRTIMER_NORESTART;
 }
@@ -7212,6 +7237,25 @@ static int io_run_task_work_sig(void)
        return -EINTR;
 }
 
+/* when returns >0, the caller should retry */
+static inline int io_cqring_wait_schedule(struct io_ring_ctx *ctx,
+                                         struct io_wait_queue *iowq,
+                                         signed long *timeout)
+{
+       int ret;
+
+       /* make sure we run task_work before checking for signals */
+       ret = io_run_task_work_sig();
+       if (ret || io_should_wake(iowq))
+               return ret;
+       /* let the caller flush overflows, retry */
+       if (test_bit(0, &ctx->cq_check_overflow))
+               return 1;
+
+       *timeout = schedule_timeout(*timeout);
+       return !*timeout ? -ETIME : 1;
+}
+
 /*
  * Wait until events become available, if we don't already have some. The
  * application must reap them itself, as they reside on the shared cq ring.
@@ -7230,9 +7274,8 @@ static int io_cqring_wait(struct io_ring_ctx *ctx, int min_events,
                .to_wait        = min_events,
        };
        struct io_rings *rings = ctx->rings;
-       struct timespec64 ts;
-       signed long timeout = 0;
-       int ret = 0;
+       signed long timeout = MAX_SCHEDULE_TIMEOUT;
+       int ret;
 
        do {
                io_cqring_overflow_flush(ctx, false, NULL, NULL);
@@ -7256,6 +7299,8 @@ static int io_cqring_wait(struct io_ring_ctx *ctx, int min_events,
        }
 
        if (uts) {
+               struct timespec64 ts;
+
                if (get_timespec64(&ts, uts))
                        return -EFAULT;
                timeout = timespec64_to_jiffies(&ts);
@@ -7264,34 +7309,17 @@ static int io_cqring_wait(struct io_ring_ctx *ctx, int min_events,
        iowq.nr_timeouts = atomic_read(&ctx->cq_timeouts);
        trace_io_uring_cqring_wait(ctx, min_events);
        do {
-               io_cqring_overflow_flush(ctx, false, NULL, NULL);
-               prepare_to_wait_exclusive(&ctx->wait, &iowq.wq,
-                                               TASK_INTERRUPTIBLE);
-               /* make sure we run task_work before checking for signals */
-               ret = io_run_task_work_sig();
-               if (ret > 0) {
-                       finish_wait(&ctx->wait, &iowq.wq);
-                       continue;
-               }
-               else if (ret < 0)
-                       break;
-               if (io_should_wake(&iowq))
+               /* if we can't even flush overflow, don't wait for more */
+               if (!io_cqring_overflow_flush(ctx, false, NULL, NULL)) {
+                       ret = -EBUSY;
                        break;
-               if (test_bit(0, &ctx->cq_check_overflow)) {
-                       finish_wait(&ctx->wait, &iowq.wq);
-                       continue;
-               }
-               if (uts) {
-                       timeout = schedule_timeout(timeout);
-                       if (timeout == 0) {
-                               ret = -ETIME;
-                               break;
-                       }
-               } else {
-                       schedule();
                }
-       } while (1);
-       finish_wait(&ctx->wait, &iowq.wq);
+               prepare_to_wait_exclusive(&ctx->wait, &iowq.wq,
+                                               TASK_INTERRUPTIBLE);
+               ret = io_cqring_wait_schedule(ctx, &iowq, &timeout);
+               finish_wait(&ctx->wait, &iowq.wq);
+               cond_resched();
+       } while (ret > 0);
 
        restore_saved_sigmask_unless(ret == -EINTR);
 
@@ -8062,12 +8090,12 @@ static int io_sqe_files_update(struct io_ring_ctx *ctx, void __user *arg,
        return __io_sqe_files_update(ctx, &up, nr_args);
 }
 
-static void io_free_work(struct io_wq_work *work)
+static struct io_wq_work *io_free_work(struct io_wq_work *work)
 {
        struct io_kiocb *req = container_of(work, struct io_kiocb, work);
 
-       /* Consider that io_steal_work() relies on this ref */
-       io_put_req(req);
+       req = io_put_req_find_next(req);
+       return req ? &req->work : NULL;
 }
 
 static int io_init_wq_offload(struct io_ring_ctx *ctx,
@@ -8718,8 +8746,29 @@ static __poll_t io_uring_poll(struct file *file, poll_table *wait)
        smp_rmb();
        if (!io_sqring_full(ctx))
                mask |= EPOLLOUT | EPOLLWRNORM;
-       io_cqring_overflow_flush(ctx, false, NULL, NULL);
-       if (io_cqring_events(ctx))
+
+       /* prevent SQPOLL from submitting new requests */
+       if (ctx->sq_data) {
+               io_sq_thread_park(ctx->sq_data);
+               list_del_init(&ctx->sqd_list);
+               io_sqd_update_thread_idle(ctx->sq_data);
+               io_sq_thread_unpark(ctx->sq_data);
+       }
+
+       /*
+        * Don't flush cqring overflow list here, just do a simple check.
+        * Otherwise there could possible be ABBA deadlock:
+        *      CPU0                    CPU1
+        *      ----                    ----
+        * lock(&ctx->uring_lock);
+        *                              lock(&ep->mtx);
+        *                              lock(&ctx->uring_lock);
+        * lock(&ep->mtx);
+        *
+        * Users may get EPOLLIN meanwhile seeing nothing in cqring, this
+        * pushs them to do the flush.
+        */
+       if (io_cqring_events(ctx) || test_bit(0, &ctx->cq_check_overflow))
                mask |= EPOLLIN | EPOLLRDNORM;
 
        return mask;
@@ -8758,7 +8807,7 @@ static void io_ring_exit_work(struct work_struct *work)
         * as nobody else will be looking for them.
         */
        do {
-               __io_uring_cancel_task_requests(ctx, NULL);
+               io_uring_try_cancel_requests(ctx, NULL, NULL);
        } while (!wait_for_completion_timeout(&ctx->ref_comp, HZ/20));
        io_ring_ctx_free(ctx);
 }
@@ -8846,11 +8895,11 @@ static bool io_cancel_task_cb(struct io_wq_work *work, void *data)
        return ret;
 }
 
-static void io_cancel_defer_files(struct io_ring_ctx *ctx,
+static bool io_cancel_defer_files(struct io_ring_ctx *ctx,
                                  struct task_struct *task,
                                  struct files_struct *files)
 {
-       struct io_defer_entry *de = NULL;
+       struct io_defer_entry *de;
        LIST_HEAD(list);
 
        spin_lock_irq(&ctx->completion_lock);
@@ -8861,6 +8910,8 @@ static void io_cancel_defer_files(struct io_ring_ctx *ctx,
                }
        }
        spin_unlock_irq(&ctx->completion_lock);
+       if (list_empty(&list))
+               return false;
 
        while (!list_empty(&list)) {
                de = list_first_entry(&list, struct io_defer_entry, list);
@@ -8870,6 +8921,43 @@ static void io_cancel_defer_files(struct io_ring_ctx *ctx,
                io_req_complete(de->req, -ECANCELED);
                kfree(de);
        }
+       return true;
+}
+
+static void io_uring_try_cancel_requests(struct io_ring_ctx *ctx,
+                                        struct task_struct *task,
+                                        struct files_struct *files)
+{
+       struct io_task_cancel cancel = { .task = task, .files = files, };
+
+       while (1) {
+               enum io_wq_cancel cret;
+               bool ret = false;
+
+               if (ctx->io_wq) {
+                       cret = io_wq_cancel_cb(ctx->io_wq, io_cancel_task_cb,
+                                              &cancel, true);
+                       ret |= (cret != IO_WQ_CANCEL_NOTFOUND);
+               }
+
+               /* SQPOLL thread does its own polling */
+               if ((!(ctx->flags & IORING_SETUP_SQPOLL) && !files) ||
+                   (ctx->sq_data && ctx->sq_data->thread == current)) {
+                       while (!list_empty_careful(&ctx->iopoll_list)) {
+                               io_iopoll_try_reap_events(ctx);
+                               ret = true;
+                       }
+               }
+
+               ret |= io_cancel_defer_files(ctx, task, files);
+               ret |= io_poll_remove_all(ctx, task, files);
+               ret |= io_kill_timeouts(ctx, task, files);
+               ret |= io_run_task_work();
+               io_cqring_overflow_flush(ctx, true, task, files);
+               if (!ret)
+                       break;
+               cond_resched();
+       }
 }
 
 static int io_uring_count_inflight(struct io_ring_ctx *ctx,
@@ -8891,7 +8979,6 @@ static void io_uring_cancel_files(struct io_ring_ctx *ctx,
                                  struct files_struct *files)
 {
        while (!list_empty_careful(&ctx->inflight_list)) {
-               struct io_task_cancel cancel = { .task = task, .files = files };
                DEFINE_WAIT(wait);
                int inflight;
 
@@ -8899,49 +8986,17 @@ static void io_uring_cancel_files(struct io_ring_ctx *ctx,
                if (!inflight)
                        break;
 
-               io_wq_cancel_cb(ctx->io_wq, io_cancel_task_cb, &cancel, true);
-               io_poll_remove_all(ctx, task, files);
-               io_kill_timeouts(ctx, task, files);
-               io_cqring_overflow_flush(ctx, true, task, files);
-               /* cancellations _may_ trigger task work */
-               io_run_task_work();
+               io_uring_try_cancel_requests(ctx, task, files);
 
+               if (ctx->sq_data)
+                       io_sq_thread_unpark(ctx->sq_data);
                prepare_to_wait(&task->io_uring->wait, &wait,
                                TASK_UNINTERRUPTIBLE);
                if (inflight == io_uring_count_inflight(ctx, task, files))
                        schedule();
                finish_wait(&task->io_uring->wait, &wait);
-       }
-}
-
-static void __io_uring_cancel_task_requests(struct io_ring_ctx *ctx,
-                                           struct task_struct *task)
-{
-       while (1) {
-               struct io_task_cancel cancel = { .task = task, .files = NULL, };
-               enum io_wq_cancel cret;
-               bool ret = false;
-
-               if (ctx->io_wq) {
-                       cret = io_wq_cancel_cb(ctx->io_wq, io_cancel_task_cb,
-                                              &cancel, true);
-                       ret |= (cret != IO_WQ_CANCEL_NOTFOUND);
-               }
-
-               /* SQPOLL thread does its own polling */
-               if (!(ctx->flags & IORING_SETUP_SQPOLL)) {
-                       while (!list_empty_careful(&ctx->iopoll_list)) {
-                               io_iopoll_try_reap_events(ctx);
-                               ret = true;
-                       }
-               }
-
-               ret |= io_poll_remove_all(ctx, task, NULL);
-               ret |= io_kill_timeouts(ctx, task, NULL);
-               ret |= io_run_task_work();
-               if (!ret)
-                       break;
-               cond_resched();
+               if (ctx->sq_data)
+                       io_sq_thread_park(ctx->sq_data);
        }
 }
 
@@ -8949,6 +9004,8 @@ static void io_disable_sqo_submit(struct io_ring_ctx *ctx)
 {
        mutex_lock(&ctx->uring_lock);
        ctx->sqo_dead = 1;
+       if (ctx->flags & IORING_SETUP_R_DISABLED)
+               io_sq_offload_start(ctx);
        mutex_unlock(&ctx->uring_lock);
 
        /* make sure callers enter the ring to get error */
@@ -8973,21 +9030,12 @@ static void io_uring_cancel_task_requests(struct io_ring_ctx *ctx,
                io_sq_thread_park(ctx->sq_data);
        }
 
-       io_cancel_defer_files(ctx, task, files);
-       io_cqring_overflow_flush(ctx, true, task, files);
-
        io_uring_cancel_files(ctx, task, files);
        if (!files)
-               __io_uring_cancel_task_requests(ctx, task);
+               io_uring_try_cancel_requests(ctx, task, NULL);
 
        if ((ctx->flags & IORING_SETUP_SQPOLL) && ctx->sq_data) {
                atomic_dec(&task->io_uring->in_idle);
-               /*
-                * If the files that are going away are the ones in the thread
-                * identity, clear them out.
-                */
-               if (task->io_uring->identity->files == files)
-                       task->io_uring->identity->files = NULL;
                io_sq_thread_unpark(ctx->sq_data);
        }
 }
@@ -9971,10 +10019,7 @@ static int io_register_enable_rings(struct io_ring_ctx *ctx)
        if (ctx->restrictions.registered)
                ctx->restricted = 1;
 
-       ctx->flags &= ~IORING_SETUP_R_DISABLED;
-
        io_sq_offload_start(ctx);
-
        return 0;
 }