]> git.proxmox.com Git - mirror_ubuntu-hirsute-kernel.git/blobdiff - fs/io_uring.c
io_uring: fix ltout double free on completion race
[mirror_ubuntu-hirsute-kernel.git] / fs / io_uring.c
index 381f82ebd28286d76524d9f6c7fd61c90adc47ac..2b86b413641a490b610f4123943325939a4f5276 100644 (file)
@@ -222,7 +222,7 @@ struct fixed_file_data {
 struct io_buffer {
        struct list_head list;
        __u64 addr;
-       __s32 len;
+       __u32 len;
        __u16 bid;
 };
 
@@ -535,7 +535,7 @@ struct io_splice {
 struct io_provide_buf {
        struct file                     *file;
        __u64                           addr;
-       __s32                           len;
+       __u32                           len;
        __u32                           bgid;
        __u16                           nbufs;
        __u16                           bid;
@@ -574,7 +574,7 @@ struct io_unlink {
 struct io_completion {
        struct file                     *file;
        struct list_head                list;
-       int                             cflags;
+       u32                             cflags;
 };
 
 struct io_async_connect {
@@ -1546,7 +1546,7 @@ static void io_prep_async_work(struct io_kiocb *req)
        if (req->flags & REQ_F_ISREG) {
                if (def->hash_reg_file || (ctx->flags & IORING_SETUP_IOPOLL))
                        io_wq_hash_work(&req->work, file_inode(req->file));
-       } else {
+       } else if (!req->file || !S_ISBLK(file_inode(req->file)->i_mode)) {
                if (def->unbound_nonreg_file)
                        req->work.flags |= IO_WQ_WORK_UNBOUND;
        }
@@ -1594,7 +1594,7 @@ static void io_queue_async_work(struct io_kiocb *req)
                io_queue_linked_timeout(link);
 }
 
-static void io_kill_timeout(struct io_kiocb *req)
+static void io_kill_timeout(struct io_kiocb *req, int status)
 {
        struct io_timeout_data *io = req->async_data;
        int ret;
@@ -1604,7 +1604,7 @@ static void io_kill_timeout(struct io_kiocb *req)
                atomic_set(&req->ctx->cq_timeouts,
                        atomic_read(&req->ctx->cq_timeouts) + 1);
                list_del_init(&req->timeout.list);
-               io_cqring_fill_event(req, 0);
+               io_cqring_fill_event(req, status);
                io_put_req_deferred(req, 1);
        }
 }
@@ -1621,7 +1621,7 @@ static bool io_kill_timeouts(struct io_ring_ctx *ctx, struct task_struct *tsk,
        spin_lock_irq(&ctx->completion_lock);
        list_for_each_entry_safe(req, tmp, &ctx->timeout_list, timeout.list) {
                if (io_match_task(req, tsk, files)) {
-                       io_kill_timeout(req);
+                       io_kill_timeout(req, -ECANCELED);
                        canceled++;
                }
        }
@@ -1673,7 +1673,7 @@ static void io_flush_timeouts(struct io_ring_ctx *ctx)
                        break;
 
                list_del_init(&req->timeout.list);
-               io_kill_timeout(req);
+               io_kill_timeout(req, 0);
        } while (!list_empty(&ctx->timeout_list));
 
        ctx->cq_last_tm_flush = seq;
@@ -1841,7 +1841,8 @@ static bool io_cqring_overflow_flush(struct io_ring_ctx *ctx, bool force,
        return ret;
 }
 
-static void __io_cqring_fill_event(struct io_kiocb *req, long res, long cflags)
+static void __io_cqring_fill_event(struct io_kiocb *req, long res,
+                                  unsigned int cflags)
 {
        struct io_ring_ctx *ctx = req->ctx;
        struct io_uring_cqe *cqe;
@@ -4214,7 +4215,7 @@ static int io_remove_buffers(struct io_kiocb *req, bool force_nonblock,
 static int io_provide_buffers_prep(struct io_kiocb *req,
                                   const struct io_uring_sqe *sqe)
 {
-       unsigned long size;
+       unsigned long size, tmp_check;
        struct io_provide_buf *p = &req->pbuf;
        u64 tmp;
 
@@ -4228,6 +4229,12 @@ static int io_provide_buffers_prep(struct io_kiocb *req,
        p->addr = READ_ONCE(sqe->addr);
        p->len = READ_ONCE(sqe->len);
 
+       if (check_mul_overflow((unsigned long)p->len, (unsigned long)p->nbufs,
+                               &size))
+               return -EOVERFLOW;
+       if (check_add_overflow((unsigned long)p->addr, size, &tmp_check))
+               return -EOVERFLOW;
+
        size = (unsigned long)p->len * p->nbufs;
        if (!access_ok(u64_to_user_ptr(p->addr), size))
                return -EFAULT;
@@ -4252,7 +4259,7 @@ static int io_add_buffers(struct io_provide_buf *pbuf, struct io_buffer **head)
                        break;
 
                buf->addr = addr;
-               buf->len = pbuf->len;
+               buf->len = min_t(__u32, pbuf->len, MAX_RW_COUNT);
                buf->bid = bid;
                addr += pbuf->len;
                bid++;
@@ -4628,6 +4635,7 @@ static int io_sendmsg(struct io_kiocb *req, bool force_nonblock,
        struct io_async_msghdr iomsg, *kmsg;
        struct socket *sock;
        unsigned flags;
+       int min_ret = 0;
        int ret;
 
        sock = sock_from_file(req->file);
@@ -4648,12 +4656,15 @@ static int io_sendmsg(struct io_kiocb *req, bool force_nonblock,
                kmsg = &iomsg;
        }
 
-       flags = req->sr_msg.msg_flags;
+       flags = req->sr_msg.msg_flags | MSG_NOSIGNAL;
        if (flags & MSG_DONTWAIT)
                req->flags |= REQ_F_NOWAIT;
        else if (force_nonblock)
                flags |= MSG_DONTWAIT;
 
+       if (flags & MSG_WAITALL)
+               min_ret = iov_iter_count(&kmsg->msg.msg_iter);
+
        ret = __sys_sendmsg_sock(sock, &kmsg->msg, flags);
        if (force_nonblock && ret == -EAGAIN)
                return io_setup_async_msg(req, kmsg);
@@ -4663,7 +4674,7 @@ static int io_sendmsg(struct io_kiocb *req, bool force_nonblock,
        if (kmsg->iov != kmsg->fast_iov)
                kfree(kmsg->iov);
        req->flags &= ~REQ_F_NEED_CLEANUP;
-       if (ret < 0)
+       if (ret < min_ret)
                req_set_fail_links(req);
        __io_req_complete(req, ret, 0, cs);
        return 0;
@@ -4677,6 +4688,7 @@ static int io_send(struct io_kiocb *req, bool force_nonblock,
        struct iovec iov;
        struct socket *sock;
        unsigned flags;
+       int min_ret = 0;
        int ret;
 
        sock = sock_from_file(req->file);
@@ -4692,12 +4704,15 @@ static int io_send(struct io_kiocb *req, bool force_nonblock,
        msg.msg_controllen = 0;
        msg.msg_namelen = 0;
 
-       flags = req->sr_msg.msg_flags;
+       flags = req->sr_msg.msg_flags | MSG_NOSIGNAL;
        if (flags & MSG_DONTWAIT)
                req->flags |= REQ_F_NOWAIT;
        else if (force_nonblock)
                flags |= MSG_DONTWAIT;
 
+       if (flags & MSG_WAITALL)
+               min_ret = iov_iter_count(&msg.msg_iter);
+
        msg.msg_flags = flags;
        ret = sock_sendmsg(sock, &msg);
        if (force_nonblock && ret == -EAGAIN)
@@ -4705,7 +4720,7 @@ static int io_send(struct io_kiocb *req, bool force_nonblock,
        if (ret == -ERESTARTSYS)
                ret = -EINTR;
 
-       if (ret < 0)
+       if (ret < min_ret)
                req_set_fail_links(req);
        __io_req_complete(req, ret, 0, cs);
        return 0;
@@ -4857,6 +4872,7 @@ static int io_recvmsg(struct io_kiocb *req, bool force_nonblock,
        struct socket *sock;
        struct io_buffer *kbuf;
        unsigned flags;
+       int min_ret = 0;
        int ret, cflags = 0;
 
        sock = sock_from_file(req->file);
@@ -4886,12 +4902,15 @@ static int io_recvmsg(struct io_kiocb *req, bool force_nonblock,
                                1, req->sr_msg.len);
        }
 
-       flags = req->sr_msg.msg_flags;
+       flags = req->sr_msg.msg_flags | MSG_NOSIGNAL;
        if (flags & MSG_DONTWAIT)
                req->flags |= REQ_F_NOWAIT;
        else if (force_nonblock)
                flags |= MSG_DONTWAIT;
 
+       if (flags & MSG_WAITALL)
+               min_ret = iov_iter_count(&kmsg->msg.msg_iter);
+
        ret = __sys_recvmsg_sock(sock, &kmsg->msg, req->sr_msg.umsg,
                                        kmsg->uaddr, flags);
        if (force_nonblock && ret == -EAGAIN)
@@ -4904,7 +4923,7 @@ static int io_recvmsg(struct io_kiocb *req, bool force_nonblock,
        if (kmsg->iov != kmsg->fast_iov)
                kfree(kmsg->iov);
        req->flags &= ~REQ_F_NEED_CLEANUP;
-       if (ret < 0)
+       if (ret < min_ret || ((flags & MSG_WAITALL) && (kmsg->msg.msg_flags & (MSG_TRUNC | MSG_CTRUNC))))
                req_set_fail_links(req);
        __io_req_complete(req, ret, cflags, cs);
        return 0;
@@ -4920,6 +4939,7 @@ static int io_recv(struct io_kiocb *req, bool force_nonblock,
        struct socket *sock;
        struct iovec iov;
        unsigned flags;
+       int min_ret = 0;
        int ret, cflags = 0;
 
        sock = sock_from_file(req->file);
@@ -4944,12 +4964,15 @@ static int io_recv(struct io_kiocb *req, bool force_nonblock,
        msg.msg_iocb = NULL;
        msg.msg_flags = 0;
 
-       flags = req->sr_msg.msg_flags;
+       flags = req->sr_msg.msg_flags | MSG_NOSIGNAL;
        if (flags & MSG_DONTWAIT)
                req->flags |= REQ_F_NOWAIT;
        else if (force_nonblock)
                flags |= MSG_DONTWAIT;
 
+       if (flags & MSG_WAITALL)
+               min_ret = iov_iter_count(&msg.msg_iter);
+
        ret = sock_recvmsg(sock, &msg, flags);
        if (force_nonblock && ret == -EAGAIN)
                return -EAGAIN;
@@ -4958,7 +4981,7 @@ static int io_recv(struct io_kiocb *req, bool force_nonblock,
 out_free:
        if (req->flags & REQ_F_BUFFER_SELECTED)
                cflags = io_put_recv_kbuf(req);
-       if (ret < 0)
+       if (ret < min_ret || ((flags & MSG_WAITALL) && (msg.msg_flags & (MSG_TRUNC | MSG_CTRUNC))))
                req_set_fail_links(req);
        __io_req_complete(req, ret, cflags, cs);
        return 0;
@@ -6489,15 +6512,17 @@ static enum hrtimer_restart io_link_timeout_fn(struct hrtimer *timer)
         * We don't expect the list to be empty, that will only happen if we
         * race with the completion of the linked work.
         */
-       if (prev && refcount_inc_not_zero(&prev->refs))
+       if (prev) {
                io_remove_next_linked(prev);
-       else
-               prev = NULL;
+               if (!refcount_inc_not_zero(&prev->refs))
+                       prev = NULL;
+       }
        spin_unlock_irqrestore(&ctx->completion_lock, flags);
 
        if (prev) {
                io_async_find_and_cancel(ctx, req, prev->user_data, -ETIME);
                io_put_req_deferred(prev, 1);
+               io_put_req_deferred(req, 1);
        } else {
                io_cqring_add_event(req, -ETIME, 0);
                io_put_req_deferred(req, 1);
@@ -8722,6 +8747,14 @@ static __poll_t io_uring_poll(struct file *file, poll_table *wait)
        if (!io_sqring_full(ctx))
                mask |= EPOLLOUT | EPOLLWRNORM;
 
+       /* prevent SQPOLL from submitting new requests */
+       if (ctx->sq_data) {
+               io_sq_thread_park(ctx->sq_data);
+               list_del_init(&ctx->sqd_list);
+               io_sqd_update_thread_idle(ctx->sq_data);
+               io_sq_thread_unpark(ctx->sq_data);
+       }
+
        /*
         * Don't flush cqring overflow list here, just do a simple check.
         * Otherwise there could possible be ABBA deadlock: