/*
 * common code for virtio vsock
 *
 * Copyright (C) 2013-2015 Red Hat, Inc.
 * Author: Asias He <asias@redhat.com>
 *         Stefan Hajnoczi <stefanha@redhat.com>
 *
 * This work is licensed under the terms of the GNU GPL, version 2.
 */
#include <linux/spinlock.h>
#include <linux/module.h>
#include <linux/sched/signal.h>
#include <linux/ctype.h>
#include <linux/list.h>
#include <linux/virtio.h>
#include <linux/virtio_ids.h>
#include <linux/virtio_config.h>
#include <linux/virtio_vsock.h>

#include <net/sock.h>
#include <net/af_vsock.h>

#define CREATE_TRACE_POINTS
#include <trace/events/vsock_virtio_transport_common.h>

/* How long to wait for graceful shutdown of a connection */
#define VSOCK_CLOSE_TIMEOUT (8 * HZ)

static const struct virtio_transport *virtio_transport_get_ops(void)
{
	const struct vsock_transport *t = vsock_core_get_transport();

	return container_of(t, struct virtio_transport, transport);
}

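/* Allocate a packet, fill in the virtio_vsock header from @info and the
 * given source/destination address, and copy the payload from info->msg
 * when one is supplied.  Returns NULL on allocation or copy failure.
 */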
static struct virtio_vsock_pkt *
virtio_transport_alloc_pkt(struct virtio_vsock_pkt_info *info,
			   size_t len,
			   u32 src_cid,
			   u32 src_port,
			   u32 dst_cid,
			   u32 dst_port)
{
	struct virtio_vsock_pkt *pkt;
	int err;

	pkt = kzalloc(sizeof(*pkt), GFP_KERNEL);
	if (!pkt)
		return NULL;

	pkt->hdr.type = cpu_to_le16(info->type);
	pkt->hdr.op = cpu_to_le16(info->op);
	pkt->hdr.src_cid = cpu_to_le64(src_cid);
	pkt->hdr.dst_cid = cpu_to_le64(dst_cid);
	pkt->hdr.src_port = cpu_to_le32(src_port);
	pkt->hdr.dst_port = cpu_to_le32(dst_port);
	pkt->hdr.flags = cpu_to_le32(info->flags);
	pkt->len = len;
	pkt->hdr.len = cpu_to_le32(len);
	pkt->reply = info->reply;
	pkt->vsk = info->vsk;

	if (info->msg && len > 0) {
		pkt->buf = kmalloc(len, GFP_KERNEL);
		if (!pkt->buf)
			goto out_pkt;
		err = memcpy_from_msg(pkt->buf, info->msg, len);
		if (err)
			goto out;
	}

	trace_virtio_transport_alloc_pkt(src_cid, src_port,
					 dst_cid, dst_port,
					 len,
					 info->type,
					 info->op,
					 info->flags);

	return pkt;

out:
	kfree(pkt->buf);
out_pkt:
	kfree(pkt);
	return NULL;
}

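/* Common transmit path: resolve source/destination addresses, cap the
 * payload at one rx buffer, reserve send credit, then hand the packet to
 * the transport's send_pkt() callback.
 */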
static int virtio_transport_send_pkt_info(struct vsock_sock *vsk,
					  struct virtio_vsock_pkt_info *info)
{
	u32 src_cid, src_port, dst_cid, dst_port;
	struct virtio_vsock_sock *vvs;
	struct virtio_vsock_pkt *pkt;
	u32 pkt_len = info->pkt_len;

	src_cid = vm_sockets_get_local_cid();
	src_port = vsk->local_addr.svm_port;
	if (!info->remote_cid) {
		dst_cid = vsk->remote_addr.svm_cid;
		dst_port = vsk->remote_addr.svm_port;
	} else {
		dst_cid = info->remote_cid;
		dst_port = info->remote_port;
	}

	vvs = vsk->trans;

	/* we can send less than pkt_len bytes */
	if (pkt_len > VIRTIO_VSOCK_DEFAULT_RX_BUF_SIZE)
		pkt_len = VIRTIO_VSOCK_DEFAULT_RX_BUF_SIZE;

	/* virtio_transport_get_credit might return less than pkt_len credit */
	pkt_len = virtio_transport_get_credit(vvs, pkt_len);

	/* Do not send zero length OP_RW pkt */
	if (pkt_len == 0 && info->op == VIRTIO_VSOCK_OP_RW)
		return pkt_len;

	pkt = virtio_transport_alloc_pkt(info, pkt_len,
					 src_cid, src_port,
					 dst_cid, dst_port);
	if (!pkt) {
		virtio_transport_put_credit(vvs, pkt_len);
		return -ENOMEM;
	}

	virtio_transport_inc_tx_pkt(vvs, pkt);

	return virtio_transport_get_ops()->send_pkt(pkt);
}

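/* rx accounting: rx_bytes is the payload currently queued on the socket,
 * fwd_cnt counts bytes already handed to the application and is advertised
 * back to the peer in every outgoing packet header.
 */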
static void virtio_transport_inc_rx_pkt(struct virtio_vsock_sock *vvs,
					struct virtio_vsock_pkt *pkt)
{
	vvs->rx_bytes += pkt->len;
}

static void virtio_transport_dec_rx_pkt(struct virtio_vsock_sock *vvs,
					struct virtio_vsock_pkt *pkt)
{
	vvs->rx_bytes -= pkt->len;
	vvs->fwd_cnt += pkt->len;
}

void virtio_transport_inc_tx_pkt(struct virtio_vsock_sock *vvs, struct virtio_vsock_pkt *pkt)
{
	spin_lock_bh(&vvs->tx_lock);
	pkt->hdr.fwd_cnt = cpu_to_le32(vvs->fwd_cnt);
	pkt->hdr.buf_alloc = cpu_to_le32(vvs->buf_alloc);
	spin_unlock_bh(&vvs->tx_lock);
}
EXPORT_SYMBOL_GPL(virtio_transport_inc_tx_pkt);

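/* Credit-based flow control: the peer can buffer at most peer_buf_alloc
 * bytes, of which (tx_cnt - peer_fwd_cnt) are still in flight.  Reserve up
 * to @credit bytes of the remaining space; unused credit must be returned
 * with virtio_transport_put_credit().
 */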
u32 virtio_transport_get_credit(struct virtio_vsock_sock *vvs, u32 credit)
{
	u32 ret;

	spin_lock_bh(&vvs->tx_lock);
	ret = vvs->peer_buf_alloc - (vvs->tx_cnt - vvs->peer_fwd_cnt);
	if (ret > credit)
		ret = credit;
	vvs->tx_cnt += ret;
	spin_unlock_bh(&vvs->tx_lock);

	return ret;
}
EXPORT_SYMBOL_GPL(virtio_transport_get_credit);

void virtio_transport_put_credit(struct virtio_vsock_sock *vvs, u32 credit)
{
	spin_lock_bh(&vvs->tx_lock);
	vvs->tx_cnt -= credit;
	spin_unlock_bh(&vvs->tx_lock);
}
EXPORT_SYMBOL_GPL(virtio_transport_put_credit);

static int virtio_transport_send_credit_update(struct vsock_sock *vsk,
					       int type,
					       struct virtio_vsock_hdr *hdr)
{
	struct virtio_vsock_pkt_info info = {
		.op = VIRTIO_VSOCK_OP_CREDIT_UPDATE,
		.type = type,
		.vsk = vsk,
	};

	return virtio_transport_send_pkt_info(vsk, &info);
}

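/* Copy queued rx packets into the user's msghdr.  rx_lock is dropped around
 * memcpy_to_msg() because it may sleep; the socket lock held by the caller
 * keeps other readers out.  A credit update is sent afterwards so the peer
 * learns how much receive space was freed.
 */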
static ssize_t
virtio_transport_stream_do_dequeue(struct vsock_sock *vsk,
				   struct msghdr *msg,
				   size_t len)
{
	struct virtio_vsock_sock *vvs = vsk->trans;
	struct virtio_vsock_pkt *pkt;
	size_t bytes, total = 0;
	int err = -EFAULT;

	spin_lock_bh(&vvs->rx_lock);
	while (total < len && !list_empty(&vvs->rx_queue)) {
		pkt = list_first_entry(&vvs->rx_queue,
				       struct virtio_vsock_pkt, list);

		bytes = len - total;
		if (bytes > pkt->len - pkt->off)
			bytes = pkt->len - pkt->off;

		/* sk_lock is held by caller so no one else can dequeue.
		 * Unlock rx_lock since memcpy_to_msg() may sleep.
		 */
		spin_unlock_bh(&vvs->rx_lock);

		err = memcpy_to_msg(msg, pkt->buf + pkt->off, bytes);
		if (err)
			goto out;

		spin_lock_bh(&vvs->rx_lock);

		total += bytes;
		pkt->off += bytes;
		if (pkt->off == pkt->len) {
			virtio_transport_dec_rx_pkt(vvs, pkt);
			list_del(&pkt->list);
			virtio_transport_free_pkt(pkt);
		}
	}
	spin_unlock_bh(&vvs->rx_lock);

	/* Send a credit pkt to peer */
	virtio_transport_send_credit_update(vsk, VIRTIO_VSOCK_TYPE_STREAM,
					    NULL);

	return total;

out:
	if (total)
		err = total;
	return err;
}

ssize_t
virtio_transport_stream_dequeue(struct vsock_sock *vsk,
				struct msghdr *msg,
				size_t len, int flags)
{
	if (flags & MSG_PEEK)
		return -EOPNOTSUPP;

	return virtio_transport_stream_do_dequeue(vsk, msg, len);
}
EXPORT_SYMBOL_GPL(virtio_transport_stream_dequeue);

int
virtio_transport_dgram_dequeue(struct vsock_sock *vsk,
			       struct msghdr *msg,
			       size_t len, int flags)
{
	return -EOPNOTSUPP;
}
EXPORT_SYMBOL_GPL(virtio_transport_dgram_dequeue);

s64 virtio_transport_stream_has_data(struct vsock_sock *vsk)
{
	struct virtio_vsock_sock *vvs = vsk->trans;
	s64 bytes;

	spin_lock_bh(&vvs->rx_lock);
	bytes = vvs->rx_bytes;
	spin_unlock_bh(&vvs->rx_lock);

	return bytes;
}
EXPORT_SYMBOL_GPL(virtio_transport_stream_has_data);

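/* Send credit currently available, in bytes.  Caller must hold tx_lock;
 * virtio_transport_stream_has_space() is the locking wrapper.
 */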
static s64 virtio_transport_has_space(struct vsock_sock *vsk)
{
	struct virtio_vsock_sock *vvs = vsk->trans;
	s64 bytes;

	bytes = vvs->peer_buf_alloc - (vvs->tx_cnt - vvs->peer_fwd_cnt);
	if (bytes < 0)
		bytes = 0;

	return bytes;
}

s64 virtio_transport_stream_has_space(struct vsock_sock *vsk)
{
	struct virtio_vsock_sock *vvs = vsk->trans;
	s64 bytes;

	spin_lock_bh(&vvs->tx_lock);
	bytes = virtio_transport_has_space(vsk);
	spin_unlock_bh(&vvs->tx_lock);

	return bytes;
}
EXPORT_SYMBOL_GPL(virtio_transport_stream_has_space);

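/* Allocate the per-socket transport state.  A child socket created by a
 * listener inherits the parent's buffer size limits and peer credit;
 * otherwise the defaults are used.
 */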
int virtio_transport_do_socket_init(struct vsock_sock *vsk,
				    struct vsock_sock *psk)
{
	struct virtio_vsock_sock *vvs;

	vvs = kzalloc(sizeof(*vvs), GFP_KERNEL);
	if (!vvs)
		return -ENOMEM;

	vsk->trans = vvs;
	vvs->vsk = vsk;
	if (psk) {
		struct virtio_vsock_sock *ptrans = psk->trans;

		vvs->buf_size = ptrans->buf_size;
		vvs->buf_size_min = ptrans->buf_size_min;
		vvs->buf_size_max = ptrans->buf_size_max;
		vvs->peer_buf_alloc = ptrans->peer_buf_alloc;
	} else {
		vvs->buf_size = VIRTIO_VSOCK_DEFAULT_BUF_SIZE;
		vvs->buf_size_min = VIRTIO_VSOCK_DEFAULT_MIN_BUF_SIZE;
		vvs->buf_size_max = VIRTIO_VSOCK_DEFAULT_MAX_BUF_SIZE;
	}

	vvs->buf_alloc = vvs->buf_size;

	spin_lock_init(&vvs->rx_lock);
	spin_lock_init(&vvs->tx_lock);
	INIT_LIST_HEAD(&vvs->rx_queue);

	return 0;
}
EXPORT_SYMBOL_GPL(virtio_transport_do_socket_init);

u64 virtio_transport_get_buffer_size(struct vsock_sock *vsk)
{
	struct virtio_vsock_sock *vvs = vsk->trans;

	return vvs->buf_size;
}
EXPORT_SYMBOL_GPL(virtio_transport_get_buffer_size);

u64 virtio_transport_get_min_buffer_size(struct vsock_sock *vsk)
{
	struct virtio_vsock_sock *vvs = vsk->trans;

	return vvs->buf_size_min;
}
EXPORT_SYMBOL_GPL(virtio_transport_get_min_buffer_size);

u64 virtio_transport_get_max_buffer_size(struct vsock_sock *vsk)
{
	struct virtio_vsock_sock *vvs = vsk->trans;

	return vvs->buf_size_max;
}
EXPORT_SYMBOL_GPL(virtio_transport_get_max_buffer_size);

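/* Buffer size knobs (backing the SO_VM_SOCKETS_BUFFER_* socket options in
 * af_vsock).  Values are clamped to VIRTIO_VSOCK_MAX_BUF_SIZE and the
 * min <= size <= max invariant is kept by adjusting the other limits.
 */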
void virtio_transport_set_buffer_size(struct vsock_sock *vsk, u64 val)
{
	struct virtio_vsock_sock *vvs = vsk->trans;

	if (val > VIRTIO_VSOCK_MAX_BUF_SIZE)
		val = VIRTIO_VSOCK_MAX_BUF_SIZE;
	if (val < vvs->buf_size_min)
		vvs->buf_size_min = val;
	if (val > vvs->buf_size_max)
		vvs->buf_size_max = val;
	vvs->buf_size = val;
	vvs->buf_alloc = val;
}
EXPORT_SYMBOL_GPL(virtio_transport_set_buffer_size);

void virtio_transport_set_min_buffer_size(struct vsock_sock *vsk, u64 val)
{
	struct virtio_vsock_sock *vvs = vsk->trans;

	if (val > VIRTIO_VSOCK_MAX_BUF_SIZE)
		val = VIRTIO_VSOCK_MAX_BUF_SIZE;
	if (val > vvs->buf_size)
		vvs->buf_size = val;
	vvs->buf_size_min = val;
}
EXPORT_SYMBOL_GPL(virtio_transport_set_min_buffer_size);

void virtio_transport_set_max_buffer_size(struct vsock_sock *vsk, u64 val)
{
	struct virtio_vsock_sock *vvs = vsk->trans;

	if (val > VIRTIO_VSOCK_MAX_BUF_SIZE)
		val = VIRTIO_VSOCK_MAX_BUF_SIZE;
	if (val < vvs->buf_size)
		vvs->buf_size = val;
	vvs->buf_size_max = val;
}
EXPORT_SYMBOL_GPL(virtio_transport_set_max_buffer_size);

int
virtio_transport_notify_poll_in(struct vsock_sock *vsk,
				size_t target,
				bool *data_ready_now)
{
	if (vsock_stream_has_data(vsk))
		*data_ready_now = true;
	else
		*data_ready_now = false;

	return 0;
}
EXPORT_SYMBOL_GPL(virtio_transport_notify_poll_in);

int
virtio_transport_notify_poll_out(struct vsock_sock *vsk,
				 size_t target,
				 bool *space_avail_now)
{
	s64 free_space;

	free_space = vsock_stream_has_space(vsk);
	if (free_space > 0)
		*space_avail_now = true;
	else if (free_space == 0)
		*space_avail_now = false;

	return 0;
}
EXPORT_SYMBOL_GPL(virtio_transport_notify_poll_out);

int virtio_transport_notify_recv_init(struct vsock_sock *vsk,
	size_t target, struct vsock_transport_recv_notify_data *data)
{
	return 0;
}
EXPORT_SYMBOL_GPL(virtio_transport_notify_recv_init);

int virtio_transport_notify_recv_pre_block(struct vsock_sock *vsk,
	size_t target, struct vsock_transport_recv_notify_data *data)
{
	return 0;
}
EXPORT_SYMBOL_GPL(virtio_transport_notify_recv_pre_block);

int virtio_transport_notify_recv_pre_dequeue(struct vsock_sock *vsk,
	size_t target, struct vsock_transport_recv_notify_data *data)
{
	return 0;
}
EXPORT_SYMBOL_GPL(virtio_transport_notify_recv_pre_dequeue);

int virtio_transport_notify_recv_post_dequeue(struct vsock_sock *vsk,
	size_t target, ssize_t copied, bool data_read,
	struct vsock_transport_recv_notify_data *data)
{
	return 0;
}
EXPORT_SYMBOL_GPL(virtio_transport_notify_recv_post_dequeue);

int virtio_transport_notify_send_init(struct vsock_sock *vsk,
	struct vsock_transport_send_notify_data *data)
{
	return 0;
}
EXPORT_SYMBOL_GPL(virtio_transport_notify_send_init);

int virtio_transport_notify_send_pre_block(struct vsock_sock *vsk,
	struct vsock_transport_send_notify_data *data)
{
	return 0;
}
EXPORT_SYMBOL_GPL(virtio_transport_notify_send_pre_block);

int virtio_transport_notify_send_pre_enqueue(struct vsock_sock *vsk,
	struct vsock_transport_send_notify_data *data)
{
	return 0;
}
EXPORT_SYMBOL_GPL(virtio_transport_notify_send_pre_enqueue);

int virtio_transport_notify_send_post_enqueue(struct vsock_sock *vsk,
	ssize_t written, struct vsock_transport_send_notify_data *data)
{
	return 0;
}
EXPORT_SYMBOL_GPL(virtio_transport_notify_send_post_enqueue);

u64 virtio_transport_stream_rcvhiwat(struct vsock_sock *vsk)
{
	struct virtio_vsock_sock *vvs = vsk->trans;

	return vvs->buf_size;
}
EXPORT_SYMBOL_GPL(virtio_transport_stream_rcvhiwat);

bool virtio_transport_stream_is_active(struct vsock_sock *vsk)
{
	return true;
}
EXPORT_SYMBOL_GPL(virtio_transport_stream_is_active);

bool virtio_transport_stream_allow(u32 cid, u32 port)
{
	return true;
}
EXPORT_SYMBOL_GPL(virtio_transport_stream_allow);

int virtio_transport_dgram_bind(struct vsock_sock *vsk,
				struct sockaddr_vm *addr)
{
	return -EOPNOTSUPP;
}
EXPORT_SYMBOL_GPL(virtio_transport_dgram_bind);

bool virtio_transport_dgram_allow(u32 cid, u32 port)
{
	return false;
}
EXPORT_SYMBOL_GPL(virtio_transport_dgram_allow);

int virtio_transport_connect(struct vsock_sock *vsk)
{
	struct virtio_vsock_pkt_info info = {
		.op = VIRTIO_VSOCK_OP_REQUEST,
		.type = VIRTIO_VSOCK_TYPE_STREAM,
		.vsk = vsk,
	};

	return virtio_transport_send_pkt_info(vsk, &info);
}
EXPORT_SYMBOL_GPL(virtio_transport_connect);

int virtio_transport_shutdown(struct vsock_sock *vsk, int mode)
{
	struct virtio_vsock_pkt_info info = {
		.op = VIRTIO_VSOCK_OP_SHUTDOWN,
		.type = VIRTIO_VSOCK_TYPE_STREAM,
		.flags = (mode & RCV_SHUTDOWN ?
			  VIRTIO_VSOCK_SHUTDOWN_RCV : 0) |
			 (mode & SEND_SHUTDOWN ?
			  VIRTIO_VSOCK_SHUTDOWN_SEND : 0),
		.vsk = vsk,
	};

	return virtio_transport_send_pkt_info(vsk, &info);
}
EXPORT_SYMBOL_GPL(virtio_transport_shutdown);

int
virtio_transport_dgram_enqueue(struct vsock_sock *vsk,
			       struct sockaddr_vm *remote_addr,
			       struct msghdr *msg,
			       size_t dgram_len)
{
	return -EOPNOTSUPP;
}
EXPORT_SYMBOL_GPL(virtio_transport_dgram_enqueue);

ssize_t
virtio_transport_stream_enqueue(struct vsock_sock *vsk,
				struct msghdr *msg,
				size_t len)
{
	struct virtio_vsock_pkt_info info = {
		.op = VIRTIO_VSOCK_OP_RW,
		.type = VIRTIO_VSOCK_TYPE_STREAM,
		.msg = msg,
		.pkt_len = len,
		.vsk = vsk,
	};

	return virtio_transport_send_pkt_info(vsk, &info);
}
EXPORT_SYMBOL_GPL(virtio_transport_stream_enqueue);

void virtio_transport_destruct(struct vsock_sock *vsk)
{
	struct virtio_vsock_sock *vvs = vsk->trans;

	kfree(vvs);
}
EXPORT_SYMBOL_GPL(virtio_transport_destruct);

static int virtio_transport_reset(struct vsock_sock *vsk,
				  struct virtio_vsock_pkt *pkt)
{
	struct virtio_vsock_pkt_info info = {
		.op = VIRTIO_VSOCK_OP_RST,
		.type = VIRTIO_VSOCK_TYPE_STREAM,
		.reply = !!pkt,
		.vsk = vsk,
	};

	/* Send RST only if the original pkt is not a RST pkt */
	if (pkt && le16_to_cpu(pkt->hdr.op) == VIRTIO_VSOCK_OP_RST)
		return 0;

	return virtio_transport_send_pkt_info(vsk, &info);
}

/* Normally packets are associated with a socket.  There may be no socket if an
 * attempt was made to connect to a socket that does not exist.
 */
static int virtio_transport_reset_no_sock(struct virtio_vsock_pkt *pkt)
{
	struct virtio_vsock_pkt_info info = {
		.op = VIRTIO_VSOCK_OP_RST,
		.type = le16_to_cpu(pkt->hdr.type),
		.reply = true,
	};

	/* Send RST only if the original pkt is not a RST pkt */
	if (le16_to_cpu(pkt->hdr.op) == VIRTIO_VSOCK_OP_RST)
		return 0;

	pkt = virtio_transport_alloc_pkt(&info, 0,
					 le64_to_cpu(pkt->hdr.dst_cid),
					 le32_to_cpu(pkt->hdr.dst_port),
					 le64_to_cpu(pkt->hdr.src_cid),
					 le32_to_cpu(pkt->hdr.src_port));
	if (!pkt)
		return -ENOMEM;

	return virtio_transport_get_ops()->send_pkt(pkt);
}

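/* Graceful close.  virtio_transport_close() sends SHUTDOWN and, with
 * SO_LINGER, waits up to the linger time for the peer to finish.  If the
 * close cannot complete immediately it returns false and arms close_work,
 * so virtio_transport_do_close() runs when the peer's RST arrives or the
 * VSOCK_CLOSE_TIMEOUT fires.
 */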
static void virtio_transport_wait_close(struct sock *sk, long timeout)
{
	if (timeout) {
		DEFINE_WAIT_FUNC(wait, woken_wake_function);

		add_wait_queue(sk_sleep(sk), &wait);

		do {
			if (sk_wait_event(sk, &timeout,
					  sock_flag(sk, SOCK_DONE), &wait))
				break;
		} while (!signal_pending(current) && timeout);

		remove_wait_queue(sk_sleep(sk), &wait);
	}
}

static void virtio_transport_do_close(struct vsock_sock *vsk,
				      bool cancel_timeout)
{
	struct sock *sk = sk_vsock(vsk);

	sock_set_flag(sk, SOCK_DONE);
	vsk->peer_shutdown = SHUTDOWN_MASK;
	if (vsock_stream_has_data(vsk) <= 0)
		sk->sk_state = SS_DISCONNECTING;
	sk->sk_state_change(sk);

	if (vsk->close_work_scheduled &&
	    (!cancel_timeout || cancel_delayed_work(&vsk->close_work))) {
		vsk->close_work_scheduled = false;

		vsock_remove_sock(vsk);

		/* Release refcnt obtained when we scheduled the timeout */
		sock_put(sk);
	}
}

static void virtio_transport_close_timeout(struct work_struct *work)
{
	struct vsock_sock *vsk =
		container_of(work, struct vsock_sock, close_work.work);
	struct sock *sk = sk_vsock(vsk);

	sock_hold(sk);
	lock_sock(sk);

	if (!sock_flag(sk, SOCK_DONE)) {
		(void)virtio_transport_reset(vsk, NULL);

		virtio_transport_do_close(vsk, false);
	}

	vsk->close_work_scheduled = false;

	release_sock(sk);
	sock_put(sk);
}

/* User context, vsk->sk is locked */
static bool virtio_transport_close(struct vsock_sock *vsk)
{
	struct sock *sk = &vsk->sk;

	if (!(sk->sk_state == SS_CONNECTED ||
	      sk->sk_state == SS_DISCONNECTING))
		return true;

	/* Already received SHUTDOWN from peer, reply with RST */
	if ((vsk->peer_shutdown & SHUTDOWN_MASK) == SHUTDOWN_MASK) {
		(void)virtio_transport_reset(vsk, NULL);
		return true;
	}

	if ((sk->sk_shutdown & SHUTDOWN_MASK) != SHUTDOWN_MASK)
		(void)virtio_transport_shutdown(vsk, SHUTDOWN_MASK);

	if (sock_flag(sk, SOCK_LINGER) && !(current->flags & PF_EXITING))
		virtio_transport_wait_close(sk, sk->sk_lingertime);

	if (sock_flag(sk, SOCK_DONE)) {
		return true;
	}

	sock_hold(sk);
	INIT_DELAYED_WORK(&vsk->close_work,
			  virtio_transport_close_timeout);
	vsk->close_work_scheduled = true;
	schedule_delayed_work(&vsk->close_work, VSOCK_CLOSE_TIMEOUT);
	return false;
}

void virtio_transport_release(struct vsock_sock *vsk)
{
	struct sock *sk = &vsk->sk;
	bool remove_sock = true;

	lock_sock(sk);
	if (sk->sk_type == SOCK_STREAM)
		remove_sock = virtio_transport_close(vsk);
	release_sock(sk);

	if (remove_sock)
		vsock_remove_sock(vsk);
}
EXPORT_SYMBOL_GPL(virtio_transport_release);

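/* Handle a packet while connect() is in progress: OP_RESPONSE completes the
 * handshake, OP_RST or an unrecognized op resets the socket and reports the
 * error.
 */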
static int
virtio_transport_recv_connecting(struct sock *sk,
				 struct virtio_vsock_pkt *pkt)
{
	struct vsock_sock *vsk = vsock_sk(sk);
	int err;
	int skerr;

	switch (le16_to_cpu(pkt->hdr.op)) {
	case VIRTIO_VSOCK_OP_RESPONSE:
		sk->sk_state = SS_CONNECTED;
		sk->sk_socket->state = SS_CONNECTED;
		vsock_insert_connected(vsk);
		sk->sk_state_change(sk);
		break;
	case VIRTIO_VSOCK_OP_INVALID:
		break;
	case VIRTIO_VSOCK_OP_RST:
		skerr = ECONNRESET;
		err = 0;
		goto destroy;
	default:
		skerr = EPROTO;
		err = -EINVAL;
		goto destroy;
	}
	return 0;

destroy:
	virtio_transport_reset(vsk, pkt);
	sk->sk_state = SS_UNCONNECTED;
	sk->sk_err = skerr;
	sk->sk_error_report(sk);
	return err;
}

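/* Handle a packet on an established connection: queue OP_RW payload for the
 * reader, wake writers on OP_CREDIT_UPDATE, record peer shutdown flags on
 * OP_SHUTDOWN, and close on OP_RST.  The packet is freed here except in the
 * OP_RW case, where it stays on the rx queue until fully read.
 */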
static int
virtio_transport_recv_connected(struct sock *sk,
				struct virtio_vsock_pkt *pkt)
{
	struct vsock_sock *vsk = vsock_sk(sk);
	struct virtio_vsock_sock *vvs = vsk->trans;
	int err = 0;

	switch (le16_to_cpu(pkt->hdr.op)) {
	case VIRTIO_VSOCK_OP_RW:
		pkt->len = le32_to_cpu(pkt->hdr.len);
		pkt->off = 0;

		spin_lock_bh(&vvs->rx_lock);
		virtio_transport_inc_rx_pkt(vvs, pkt);
		list_add_tail(&pkt->list, &vvs->rx_queue);
		spin_unlock_bh(&vvs->rx_lock);

		sk->sk_data_ready(sk);
		return err;
	case VIRTIO_VSOCK_OP_CREDIT_UPDATE:
		sk->sk_write_space(sk);
		break;
	case VIRTIO_VSOCK_OP_SHUTDOWN:
		if (le32_to_cpu(pkt->hdr.flags) & VIRTIO_VSOCK_SHUTDOWN_RCV)
			vsk->peer_shutdown |= RCV_SHUTDOWN;
		if (le32_to_cpu(pkt->hdr.flags) & VIRTIO_VSOCK_SHUTDOWN_SEND)
			vsk->peer_shutdown |= SEND_SHUTDOWN;
		if (vsk->peer_shutdown == SHUTDOWN_MASK &&
		    vsock_stream_has_data(vsk) <= 0)
			sk->sk_state = SS_DISCONNECTING;
		if (le32_to_cpu(pkt->hdr.flags))
			sk->sk_state_change(sk);
		break;
	case VIRTIO_VSOCK_OP_RST:
		virtio_transport_do_close(vsk, true);
		break;
	default:
		err = -EINVAL;
		break;
	}

	virtio_transport_free_pkt(pkt);
	return err;
}

static void
virtio_transport_recv_disconnecting(struct sock *sk,
				    struct virtio_vsock_pkt *pkt)
{
	struct vsock_sock *vsk = vsock_sk(sk);

	if (le16_to_cpu(pkt->hdr.op) == VIRTIO_VSOCK_OP_RST)
		virtio_transport_do_close(vsk, true);
}

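/* Reply to a connection request with OP_RESPONSE, completing the handshake
 * from the listening side.
 */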
static int
virtio_transport_send_response(struct vsock_sock *vsk,
			       struct virtio_vsock_pkt *pkt)
{
	struct virtio_vsock_pkt_info info = {
		.op = VIRTIO_VSOCK_OP_RESPONSE,
		.type = VIRTIO_VSOCK_TYPE_STREAM,
		.remote_cid = le64_to_cpu(pkt->hdr.src_cid),
		.remote_port = le32_to_cpu(pkt->hdr.src_port),
		.reply = true,
		.vsk = vsk,
	};

	return virtio_transport_send_pkt_info(vsk, &info);
}

/* Handle server socket */
static int
virtio_transport_recv_listen(struct sock *sk, struct virtio_vsock_pkt *pkt)
{
	struct vsock_sock *vsk = vsock_sk(sk);
	struct vsock_sock *vchild;
	struct sock *child;

	if (le16_to_cpu(pkt->hdr.op) != VIRTIO_VSOCK_OP_REQUEST) {
		virtio_transport_reset(vsk, pkt);
		return -EINVAL;
	}

	if (sk_acceptq_is_full(sk)) {
		virtio_transport_reset(vsk, pkt);
		return -ENOMEM;
	}

	child = __vsock_create(sock_net(sk), NULL, sk, GFP_KERNEL,
			       sk->sk_type, 0);
	if (!child) {
		virtio_transport_reset(vsk, pkt);
		return -ENOMEM;
	}

	sk->sk_ack_backlog++;

	lock_sock_nested(child, SINGLE_DEPTH_NESTING);

	child->sk_state = SS_CONNECTED;

	vchild = vsock_sk(child);
	vsock_addr_init(&vchild->local_addr, le64_to_cpu(pkt->hdr.dst_cid),
			le32_to_cpu(pkt->hdr.dst_port));
	vsock_addr_init(&vchild->remote_addr, le64_to_cpu(pkt->hdr.src_cid),
			le32_to_cpu(pkt->hdr.src_port));

	vsock_insert_connected(vchild);
	vsock_enqueue_accept(sk, child);
	virtio_transport_send_response(vchild, pkt);

	release_sock(child);

	sk->sk_data_ready(sk);
	return 0;
}

static bool virtio_transport_space_update(struct sock *sk,
					  struct virtio_vsock_pkt *pkt)
{
	struct vsock_sock *vsk = vsock_sk(sk);
	struct virtio_vsock_sock *vvs = vsk->trans;
	bool space_available;

	/* buf_alloc and fwd_cnt is always included in the hdr */
	spin_lock_bh(&vvs->tx_lock);
	vvs->peer_buf_alloc = le32_to_cpu(pkt->hdr.buf_alloc);
	vvs->peer_fwd_cnt = le32_to_cpu(pkt->hdr.fwd_cnt);
	space_available = virtio_transport_has_space(vsk);
	spin_unlock_bh(&vvs->tx_lock);
	return space_available;
}

/* We are under the virtio-vsock's vsock->rx_lock or vhost-vsock's vq->mutex
 * lock.
 */
void virtio_transport_recv_pkt(struct virtio_vsock_pkt *pkt)
{
	struct sockaddr_vm src, dst;
	struct vsock_sock *vsk;
	struct sock *sk;
	bool space_available;

	vsock_addr_init(&src, le64_to_cpu(pkt->hdr.src_cid),
			le32_to_cpu(pkt->hdr.src_port));
	vsock_addr_init(&dst, le64_to_cpu(pkt->hdr.dst_cid),
			le32_to_cpu(pkt->hdr.dst_port));

	trace_virtio_transport_recv_pkt(src.svm_cid, src.svm_port,
					dst.svm_cid, dst.svm_port,
					le32_to_cpu(pkt->hdr.len),
					le16_to_cpu(pkt->hdr.type),
					le16_to_cpu(pkt->hdr.op),
					le32_to_cpu(pkt->hdr.flags),
					le32_to_cpu(pkt->hdr.buf_alloc),
					le32_to_cpu(pkt->hdr.fwd_cnt));

	if (le16_to_cpu(pkt->hdr.type) != VIRTIO_VSOCK_TYPE_STREAM) {
		(void)virtio_transport_reset_no_sock(pkt);
		goto free_pkt;
	}

	/* The socket must be in connected or bound table
	 * otherwise send reset back
	 */
	sk = vsock_find_connected_socket(&src, &dst);
	if (!sk) {
		sk = vsock_find_bound_socket(&dst);
		if (!sk) {
			(void)virtio_transport_reset_no_sock(pkt);
			goto free_pkt;
		}
	}

	vsk = vsock_sk(sk);

	space_available = virtio_transport_space_update(sk, pkt);

	lock_sock(sk);

	/* Update CID in case it has changed after a transport reset event */
	vsk->local_addr.svm_cid = dst.svm_cid;

	if (space_available)
		sk->sk_write_space(sk);

	switch (sk->sk_state) {
	case VSOCK_SS_LISTEN:
		virtio_transport_recv_listen(sk, pkt);
		virtio_transport_free_pkt(pkt);
		break;
	case SS_CONNECTING:
		virtio_transport_recv_connecting(sk, pkt);
		virtio_transport_free_pkt(pkt);
		break;
	case SS_CONNECTED:
		virtio_transport_recv_connected(sk, pkt);
		break;
	case SS_DISCONNECTING:
		virtio_transport_recv_disconnecting(sk, pkt);
		virtio_transport_free_pkt(pkt);
		break;
	default:
		virtio_transport_free_pkt(pkt);
		break;
	}
	release_sock(sk);

	/* Release refcnt obtained when we fetched this socket out of the
	 * bound or connected list.
	 */
	sock_put(sk);
	return;

free_pkt:
	virtio_transport_free_pkt(pkt);
}
EXPORT_SYMBOL_GPL(virtio_transport_recv_pkt);

void virtio_transport_free_pkt(struct virtio_vsock_pkt *pkt)
{
	kfree(pkt->buf);
	kfree(pkt);
}
EXPORT_SYMBOL_GPL(virtio_transport_free_pkt);

MODULE_LICENSE("GPL v2");
MODULE_AUTHOR("Asias He");
MODULE_DESCRIPTION("common code for virtio vsock");