]> git.proxmox.com Git - mirror_ubuntu-bionic-kernel.git/blame - net/vmw_vsock/virtio_transport_common.c
Merge branch 'sh_eth-optimize-mdio'
[mirror_ubuntu-bionic-kernel.git] / net / vmw_vsock / virtio_transport_common.c
CommitLineData
80a19e33
AH
1/*
2 * common code for virtio vsock
3 *
4 * Copyright (C) 2013-2015 Red Hat, Inc.
5 * Author: Asias He <asias@redhat.com>
6 * Stefan Hajnoczi <stefanha@redhat.com>
7 *
8 * This work is licensed under the terms of the GNU GPL, version 2.
9 */
10#include <linux/module.h>
11#include <linux/ctype.h>
12#include <linux/list.h>
13#include <linux/virtio.h>
14#include <linux/virtio_ids.h>
15#include <linux/virtio_config.h>
16#include <linux/virtio_vsock.h>
17#include <linux/random.h>
18#include <linux/cryptohash.h>
19
20#include <net/sock.h>
21#include <net/af_vsock.h>
22
/* SYN-cookie-style connection cookies: a 24-bit keyed hash plus a coarse
 * count is carried in the packet header flags (see cookie_hash() and
 * virtio_vsock_secure_cookie() below).
 */
#define COOKIEBITS 24
#define COOKIEMASK (((u32)1 << COOKIEBITS) - 1)
/* Max allowed age of a cookie, in count epochs (one-minute granularity). */
#define VSOCK_TIMEOUT_INIT 4

/* SHA-1 message is 16 words; the first 5 hold addresses/ports/count. */
#define SHA_MESSAGE_WORDS 16
#define SHA_VSOCK_WORDS 5

/* Two secrets, indexed by the @c argument of cookie_hash().
 * NOTE(review): no initialization of these secrets is visible in this
 * file chunk — confirm they are seeded (e.g. with get_random_bytes()).
 */
static u32 vsockcookie_secret[2][SHA_MESSAGE_WORDS - SHA_VSOCK_WORDS +
				SHA_DIGEST_WORDS];

/* Per-CPU scratch for cookie_hash(): message + digest + SHA workspace. */
static DEFINE_PER_CPU(__u32[SHA_MESSAGE_WORDS + SHA_DIGEST_WORDS +
			SHA_WORKSPACE_WORDS], vsock_cookie_scratch);
35
36static u32 cookie_hash(u32 saddr, u32 daddr, u16 sport, u16 dport,
37 u32 count, int c)
38{
39 __u32 *tmp = this_cpu_ptr(vsock_cookie_scratch);
40
41 memcpy(tmp + SHA_VSOCK_WORDS, vsockcookie_secret[c],
42 sizeof(vsockcookie_secret[c]));
43 tmp[0] = saddr;
44 tmp[1] = daddr;
45 tmp[2] = sport;
46 tmp[3] = dport;
47 tmp[4] = count;
48 sha_transform(tmp + SHA_MESSAGE_WORDS, (__u8 *)tmp,
49 tmp + SHA_MESSAGE_WORDS + SHA_DIGEST_WORDS);
50
51 return tmp[17];
52}
53
54static u32
55virtio_vsock_secure_cookie(u32 saddr, u32 daddr, u32 sport, u32 dport,
56 u32 count)
57{
58 u32 h1, h2;
59
60 h1 = cookie_hash(saddr, daddr, sport, dport, 0, 0);
61 h2 = cookie_hash(saddr, daddr, sport, dport, count, 1);
62
63 return h1 + (count << COOKIEBITS) + (h2 & COOKIEMASK);
64}
65
66static u32
67virtio_vsock_check_cookie(u32 saddr, u32 daddr, u32 sport, u32 dport,
68 u32 count, u32 cookie, u32 maxdiff)
69{
70 u32 diff;
71 u32 ret;
72
73 cookie -= cookie_hash(saddr, daddr, sport, dport, 0, 0);
74
75 diff = (count - (cookie >> COOKIEBITS)) & ((u32)-1 >> COOKIEBITS);
76 pr_debug("%s: diff=%x\n", __func__, diff);
77 if (diff >= maxdiff)
78 return (u32)-1;
79
80 ret = (cookie -
81 cookie_hash(saddr, daddr, sport, dport, count - diff, 1))
82 & COOKIEMASK;
83 pr_debug("%s: ret=%x\n", __func__, diff);
84
85 return ret;
86}
87
/* Dump a packet's header fields to the debug log on behalf of @func.
 * The first "len" is the wire header length, the second is pkt->len
 * (the host-side payload length).
 */
void virtio_vsock_dumppkt(const char *func, const struct virtio_vsock_pkt *pkt)
{
	pr_debug("%s: pkt=%p, op=%d, len=%d, %d:%d---%d:%d, len=%d\n",
		 func, pkt,
		 le16_to_cpu(pkt->hdr.op),
		 le32_to_cpu(pkt->hdr.len),
		 le32_to_cpu(pkt->hdr.src_cid),
		 le32_to_cpu(pkt->hdr.src_port),
		 le32_to_cpu(pkt->hdr.dst_cid),
		 le32_to_cpu(pkt->hdr.dst_port),
		 pkt->len);
}
EXPORT_SYMBOL_GPL(virtio_vsock_dumppkt);
101
/* Allocate and initialize a packet for @vsk.
 *
 * @len is this packet's payload size.  For datagrams the header length
 * field packs the whole datagram length (info->dgram_len) into the upper
 * 16 bits and @len into the lower 16; for streams it is simply @len.
 * If info->msg is set, @len bytes are copied from the user's iovec into
 * a freshly allocated buffer owned by the packet.
 *
 * Returns the packet on success, NULL on allocation or copy failure.
 * The caller owns the returned packet and must free it with
 * virtio_transport_free_pkt().
 */
struct virtio_vsock_pkt *
virtio_transport_alloc_pkt(struct vsock_sock *vsk,
			   struct virtio_vsock_pkt_info *info,
			   size_t len,
			   u32 src_cid,
			   u32 src_port,
			   u32 dst_cid,
			   u32 dst_port)
{
	struct virtio_transport *trans = vsk->trans;
	struct virtio_vsock_pkt *pkt;
	int err;

	BUG_ON(!trans);

	pkt = kzalloc(sizeof(*pkt), GFP_KERNEL);
	if (!pkt)
		return NULL;

	/* All header fields are little-endian on the wire. */
	pkt->hdr.type = cpu_to_le16(info->type);
	pkt->hdr.op = cpu_to_le16(info->op);
	pkt->hdr.src_cid = cpu_to_le32(src_cid);
	pkt->hdr.src_port = cpu_to_le32(src_port);
	pkt->hdr.dst_cid = cpu_to_le32(dst_cid);
	pkt->hdr.dst_port = cpu_to_le32(dst_port);
	pkt->hdr.flags = cpu_to_le32(info->flags);
	pkt->len = len;
	pkt->trans = trans;
	/* hdr.len: dgram_len << 16 | pkt_len for dgrams, plain len for streams. */
	if (info->type == VIRTIO_VSOCK_TYPE_DGRAM)
		pkt->hdr.len = cpu_to_le32(len + (info->dgram_len << 16));
	else if (info->type == VIRTIO_VSOCK_TYPE_STREAM)
		pkt->hdr.len = cpu_to_le32(len);

	if (info->msg && len > 0) {
		pkt->buf = kmalloc(len, GFP_KERNEL);
		if (!pkt->buf)
			goto out_pkt;
		err = memcpy_from_msg(pkt->buf, info->msg, len);
		if (err)
			goto out;
	}

	return pkt;

out:
	kfree(pkt->buf);
out_pkt:
	kfree(pkt);
	return NULL;
}
EXPORT_SYMBOL_GPL(virtio_transport_alloc_pkt);
153
154struct sock *
155virtio_transport_get_pending(struct sock *listener,
156 struct virtio_vsock_pkt *pkt)
157{
158 struct vsock_sock *vlistener;
159 struct vsock_sock *vpending;
160 struct sockaddr_vm src;
161 struct sockaddr_vm dst;
162 struct sock *pending;
163
164 vsock_addr_init(&src, le32_to_cpu(pkt->hdr.src_cid), le32_to_cpu(pkt->hdr.src_port));
165 vsock_addr_init(&dst, le32_to_cpu(pkt->hdr.dst_cid), le32_to_cpu(pkt->hdr.dst_port));
166
167 vlistener = vsock_sk(listener);
168 list_for_each_entry(vpending, &vlistener->pending_links,
169 pending_links) {
170 if (vsock_addr_equals_addr(&src, &vpending->remote_addr) &&
171 vsock_addr_equals_addr(&dst, &vpending->local_addr)) {
172 pending = sk_vsock(vpending);
173 sock_hold(pending);
174 return pending;
175 }
176 }
177
178 return NULL;
179}
180EXPORT_SYMBOL_GPL(virtio_transport_get_pending);
181
/* Account @pkt's payload as queued receive data.
 * Caller must hold trans->rx_lock.
 */
static void virtio_transport_inc_rx_pkt(struct virtio_vsock_pkt *pkt)
{
	pkt->trans->rx_bytes += pkt->len;
}
186
/* Account @pkt's payload as consumed: it leaves the rx queue and is added
 * to fwd_cnt, which is advertised back to the peer for credit accounting.
 * Caller must hold trans->rx_lock.
 */
static void virtio_transport_dec_rx_pkt(struct virtio_vsock_pkt *pkt)
{
	pkt->trans->rx_bytes -= pkt->len;
	pkt->trans->fwd_cnt += pkt->len;
}
192
/* Stamp an outgoing packet with the current credit state (fwd_cnt and
 * buf_alloc), which every header carries so the peer can update its view
 * of our receive capacity.
 */
void virtio_transport_inc_tx_pkt(struct virtio_vsock_pkt *pkt)
{
	mutex_lock(&pkt->trans->tx_lock);
	pkt->hdr.fwd_cnt = cpu_to_le32(pkt->trans->fwd_cnt);
	pkt->hdr.buf_alloc = cpu_to_le32(pkt->trans->buf_alloc);
	mutex_unlock(&pkt->trans->tx_lock);
}
EXPORT_SYMBOL_GPL(virtio_transport_inc_tx_pkt);
201
/* Intentionally a no-op: kept only to mirror inc_tx_pkt() in the
 * transport interface.
 */
void virtio_transport_dec_tx_pkt(struct virtio_vsock_pkt *pkt)
{
}
EXPORT_SYMBOL_GPL(virtio_transport_dec_tx_pkt);
206
/* Reserve up to @credit bytes of send credit.
 *
 * Available credit is the peer's advertised buffer space minus our
 * unacknowledged in-flight bytes (tx_cnt - peer_fwd_cnt).  The reserved
 * amount is added to tx_cnt; unused credit must be returned with
 * virtio_transport_put_credit().
 *
 * NOTE(review): the pr_debug format strings concatenate without a space
 * ("...peer_buf_alloc=%d,tx_cnt=...") — cosmetic only, left as-is here.
 */
u32 virtio_transport_get_credit(struct virtio_transport *trans, u32 credit)
{
	u32 ret;

	mutex_lock(&trans->tx_lock);
	ret = trans->peer_buf_alloc - (trans->tx_cnt - trans->peer_fwd_cnt);
	if (ret > credit)
		ret = credit;
	trans->tx_cnt += ret;
	mutex_unlock(&trans->tx_lock);

	pr_debug("%s: ret=%d, buf_alloc=%d, peer_buf_alloc=%d,"
		 "tx_cnt=%d, fwd_cnt=%d, peer_fwd_cnt=%d\n", __func__,
		 ret, trans->buf_alloc, trans->peer_buf_alloc,
		 trans->tx_cnt, trans->fwd_cnt, trans->peer_fwd_cnt);

	return ret;
}
EXPORT_SYMBOL_GPL(virtio_transport_get_credit);
226
/* Return @credit bytes of previously reserved (but unused) send credit. */
void virtio_transport_put_credit(struct virtio_transport *trans, u32 credit)
{
	mutex_lock(&trans->tx_lock);
	trans->tx_cnt -= credit;
	mutex_unlock(&trans->tx_lock);
}
EXPORT_SYMBOL_GPL(virtio_transport_put_credit);
234
/* Send a CREDIT_UPDATE control packet so the peer learns our current
 * fwd_cnt/buf_alloc (stamped into every header by inc_tx_pkt()).
 *
 * For datagrams @hdr (the packet being answered) supplies the reply
 * address, since dgram sockets have no fixed remote peer.
 */
static int virtio_transport_send_credit_update(struct vsock_sock *vsk, int type, struct virtio_vsock_hdr *hdr)
{
	struct virtio_transport *trans = vsk->trans;
	struct virtio_vsock_pkt_info info = {
		.op = VIRTIO_VSOCK_OP_CREDIT_UPDATE,
		.type = type,
	};

	if (hdr && type == VIRTIO_VSOCK_TYPE_DGRAM) {
		info.remote_cid = le32_to_cpu(hdr->src_cid);
		info.remote_port = le32_to_cpu(hdr->src_port);
	}

	pr_debug("%s: sk=%p send_credit_update\n", __func__, vsk);
	return trans->ops->send_pkt(vsk, &info);
}
251
/* Ask the peer to send us a CREDIT_UPDATE for socket type @type. */
static int virtio_transport_send_credit_request(struct vsock_sock *vsk, int type)
{
	struct virtio_transport *trans = vsk->trans;
	struct virtio_vsock_pkt_info info = {
		.op = VIRTIO_VSOCK_OP_CREDIT_REQUEST,
		.type = type,
	};

	pr_debug("%s: sk=%p send_credit_request\n", __func__, vsk);
	return trans->ops->send_pkt(vsk, &info);
}
263
/* Copy up to @len bytes of queued stream data into @msg.
 *
 * Walks the rx_queue under rx_lock, consuming packets as they are fully
 * drained.  After releasing the lock a CREDIT_UPDATE is sent so the peer
 * learns about the freed buffer space.
 *
 * Returns the number of bytes copied; on a copy fault, returns the bytes
 * already copied if any, otherwise a negative errno.
 */
static ssize_t
virtio_transport_stream_do_dequeue(struct vsock_sock *vsk,
				   struct msghdr *msg,
				   size_t len)
{
	struct virtio_transport *trans = vsk->trans;
	struct virtio_vsock_pkt *pkt;
	size_t bytes, total = 0;
	int err = -EFAULT;

	mutex_lock(&trans->rx_lock);
	while (total < len && trans->rx_bytes > 0 &&
	       !list_empty(&trans->rx_queue)) {
		pkt = list_first_entry(&trans->rx_queue,
				       struct virtio_vsock_pkt, list);

		/* Take the smaller of what's wanted and what's left in pkt. */
		bytes = len - total;
		if (bytes > pkt->len - pkt->off)
			bytes = pkt->len - pkt->off;

		err = memcpy_to_msg(msg, pkt->buf + pkt->off, bytes);
		if (err)
			goto out;
		total += bytes;
		pkt->off += bytes;
		/* Packet fully consumed: account it and free it. */
		if (pkt->off == pkt->len) {
			virtio_transport_dec_rx_pkt(pkt);
			list_del(&pkt->list);
			virtio_transport_free_pkt(pkt);
		}
	}
	mutex_unlock(&trans->rx_lock);

	/* Send a credit pkt to peer */
	virtio_transport_send_credit_update(vsk, VIRTIO_VSOCK_TYPE_STREAM,
					    NULL);

	return total;

out:
	mutex_unlock(&trans->rx_lock);
	if (total)
		err = total;
	return err;
}
309
310ssize_t
311virtio_transport_stream_dequeue(struct vsock_sock *vsk,
312 struct msghdr *msg,
313 size_t len, int flags)
314{
315 if (flags & MSG_PEEK)
316 return -EOPNOTSUPP;
317
318 return virtio_transport_stream_do_dequeue(vsk, msg, len);
319}
320EXPORT_SYMBOL_GPL(virtio_transport_stream_dequeue);
321
/* A partially reassembled datagram: the sk_buff accumulating fragments,
 * keyed by the 16-bit datagram id carried in the packet flags.  Linked
 * on trans->incomplete_dgrams until the final fragment arrives.
 */
struct dgram_skb {
	struct list_head list;
	struct sk_buff *skb;
	u16 id;
};
327
328static struct dgram_skb *dgram_id_to_skb(struct virtio_transport *trans,
329 u16 id)
330{
331 struct dgram_skb *dgram_skb;
332
333 list_for_each_entry(dgram_skb, &trans->incomplete_dgrams, list) {
334 if (dgram_skb->id == id)
335 return dgram_skb;
336 }
337
338 return NULL;
339}
340
/* Receive one datagram fragment for @sk.
 *
 * hdr.len packs dgram_len (upper 16 bits) | pkt_len (lower 16);
 * hdr.flags packs dgram_id (upper 16 bits) | pkt_off (lower 16).
 * Fragments are accumulated into an sk_buff tracked by a dgram_skb
 * entry; when the final fragment lands (pkt_off + pkt_len == dgram_len)
 * the skb is handed to sk_receive_skb().
 *
 * Out-of-bounds fragments drop the whole in-progress datagram.
 * Always consumes @pkt.  Caller holds the socket lock.
 */
static void
virtio_transport_recv_dgram(struct sock *sk,
			    struct virtio_vsock_pkt *pkt)
{
	struct sk_buff *skb = NULL;
	struct vsock_sock *vsk;
	struct virtio_transport *trans;
	size_t size;
	u16 dgram_id, pkt_off, dgram_len, pkt_len;
	u32 flags, len;
	struct dgram_skb *dgram_skb;

	vsk = vsock_sk(sk);
	trans = vsk->trans;

	/* len: dgram_len | pkt_len */
	len = le32_to_cpu(pkt->hdr.len);
	dgram_len = len >> 16;
	pkt_len = len & 0xFFFF;

	/* flags: dgram_id | pkt_off */
	flags = le32_to_cpu(pkt->hdr.flags);
	dgram_id = flags >> 16;
	pkt_off = flags & 0xFFFF;

	pr_debug("%s: dgram_len=%d, pkt_len=%d, id=%d, off=%d\n", __func__,
		 dgram_len, pkt_len, dgram_id, pkt_off);

	dgram_skb = dgram_id_to_skb(trans, dgram_id);
	if (dgram_skb) {
		/* This pkt is for a existing dgram */
		skb = dgram_skb->skb;
		pr_debug("%s:found skb\n", __func__);
	}

	/* Packet payload must be within datagram bounds */
	if (pkt_len > VIRTIO_VSOCK_DEFAULT_RX_BUF_SIZE)
		goto drop;
	if (pkt_len > dgram_len)
		goto drop;
	if (pkt_off > dgram_len)
		goto drop;
	if (dgram_len - pkt_off < pkt_len)
		goto drop;

	if (!skb) {
		/* This pkt is for a new dgram */
		pr_debug("%s:create skb\n", __func__);

		size = sizeof(pkt->hdr) + dgram_len;
		/* Attach the packet to the socket's receive queue as an sk_buff. */
		dgram_skb = kzalloc(sizeof(struct dgram_skb), GFP_ATOMIC);
		if (!dgram_skb)
			goto drop;

		skb = alloc_skb(size, GFP_ATOMIC);
		if (!skb) {
			/* Clear so the drop path doesn't free/put anything. */
			kfree(dgram_skb);
			dgram_skb = NULL;
			goto drop;
		}
		dgram_skb->id = dgram_id;
		dgram_skb->skb = skb;
		list_add_tail(&dgram_skb->list, &trans->incomplete_dgrams);

		/* sk_receive_skb() will do a sock_put(), so hold here. */
		sock_hold(sk);
		skb_put(skb, size);
		/* First fragment: copy the wire header ahead of the payload. */
		memcpy(skb->data, &pkt->hdr, sizeof(pkt->hdr));
	}

	/* Place this fragment at its offset within the datagram payload. */
	memcpy(skb->data + sizeof(pkt->hdr) + pkt_off, pkt->buf, pkt_len);

	pr_debug("%s:C, off=%d, pkt_len=%d, dgram_len=%d\n", __func__,
		 pkt_off, pkt_len, dgram_len);

	/* We are done with this dgram */
	if (pkt_off + pkt_len == dgram_len) {
		pr_debug("%s:dgram_id=%d is done\n", __func__, dgram_id);
		list_del(&dgram_skb->list);
		kfree(dgram_skb);
		sk_receive_skb(sk, skb, 0);
	}
	virtio_transport_free_pkt(pkt);
	return;

drop:
	/* Tear down the in-progress datagram, if any, dropping the sock
	 * reference taken when it was created.
	 */
	if (dgram_skb) {
		list_del(&dgram_skb->list);
		kfree(dgram_skb);
		kfree_skb(skb);
		sock_put(sk);
	}
	virtio_transport_free_pkt(pkt);
}
436
/* Dequeue one complete datagram for @vsk into @msg.
 *
 * The skb layout (wire header followed by payload) is produced by
 * virtio_transport_recv_dgram().  Fills in msg_name with the sender's
 * address when requested and credits the peer afterwards.
 *
 * Returns the datagram payload length on success, negative errno on
 * failure.  MSG_OOB and MSG_ERRQUEUE are not supported.
 */
int
virtio_transport_dgram_dequeue(struct vsock_sock *vsk,
			       struct msghdr *msg,
			       size_t len, int flags)
{
	struct virtio_vsock_hdr *hdr;
	struct sk_buff *skb;
	int noblock;
	int err;
	int dgram_len;

	noblock = flags & MSG_DONTWAIT;

	if (flags & MSG_OOB || flags & MSG_ERRQUEUE)
		return -EOPNOTSUPP;

	/* Retrieve the head sk_buff from the socket's receive queue. */
	err = 0;
	skb = skb_recv_datagram(&vsk->sk, flags, noblock, &err);
	if (err)
		return err;
	if (!skb)
		return -EAGAIN;

	hdr = (struct virtio_vsock_hdr *)skb->data;
	if (!hdr)
		goto out;

	/* Upper 16 bits of hdr.len carry the whole datagram's length.
	 * NOTE(review): @len (the user buffer size) is not checked against
	 * dgram_len here; skb_copy_datagram_msg() is relied on to fail if
	 * the iovec is too small — confirm that is the intended contract.
	 */
	dgram_len = le32_to_cpu(hdr->len) >> 16;
	/* Place the datagram payload in the user's iovec. */
	err = skb_copy_datagram_msg(skb, sizeof(*hdr), msg, dgram_len);
	if (err)
		goto out;

	if (msg->msg_name) {
		/* Provide the address of the sender. */
		DECLARE_SOCKADDR(struct sockaddr_vm *, vm_addr, msg->msg_name);
		vsock_addr_init(vm_addr, le32_to_cpu(hdr->src_cid), le32_to_cpu(hdr->src_port));
		msg->msg_namelen = sizeof(*vm_addr);
	}
	err = dgram_len;

	/* Send a credit pkt to peer */
	virtio_transport_send_credit_update(vsk, VIRTIO_VSOCK_TYPE_DGRAM, hdr);

	pr_debug("%s:done, recved =%d\n", __func__, dgram_len);
out:
	skb_free_datagram(&vsk->sk, skb);
	return err;
}
EXPORT_SYMBOL_GPL(virtio_transport_dgram_dequeue);
488
489s64 virtio_transport_stream_has_data(struct vsock_sock *vsk)
490{
491 struct virtio_transport *trans = vsk->trans;
492 s64 bytes;
493
494 mutex_lock(&trans->rx_lock);
495 bytes = trans->rx_bytes;
496 mutex_unlock(&trans->rx_lock);
497
498 return bytes;
499}
500EXPORT_SYMBOL_GPL(virtio_transport_stream_has_data);
501
502static s64 virtio_transport_has_space(struct vsock_sock *vsk)
503{
504 struct virtio_transport *trans = vsk->trans;
505 s64 bytes;
506
507 bytes = trans->peer_buf_alloc - (trans->tx_cnt - trans->peer_fwd_cnt);
508 if (bytes < 0)
509 bytes = 0;
510
511 return bytes;
512}
513
/* Locked wrapper around virtio_transport_has_space(): bytes the caller
 * may currently write without exceeding the peer's buffer.
 */
s64 virtio_transport_stream_has_space(struct vsock_sock *vsk)
{
	struct virtio_transport *trans = vsk->trans;
	s64 bytes;

	mutex_lock(&trans->tx_lock);
	bytes = virtio_transport_has_space(vsk);
	mutex_unlock(&trans->tx_lock);

	pr_debug("%s: bytes=%lld\n", __func__, bytes);

	return bytes;
}
EXPORT_SYMBOL_GPL(virtio_transport_stream_has_space);
528
/* Allocate and initialize per-socket transport state for @vsk.
 *
 * A child socket (@psk non-NULL, i.e. created from a listener) inherits
 * the parent's buffer-size settings and peer_buf_alloc; otherwise the
 * defaults apply.  Freed in virtio_transport_destruct().
 *
 * Returns 0 or -ENOMEM.
 */
int virtio_transport_do_socket_init(struct vsock_sock *vsk,
				    struct vsock_sock *psk)
{
	struct virtio_transport *trans;

	trans = kzalloc(sizeof(*trans), GFP_KERNEL);
	if (!trans)
		return -ENOMEM;

	vsk->trans = trans;
	trans->vsk = vsk;
	if (psk) {
		/* Inherit settings from the listening (parent) socket. */
		struct virtio_transport *ptrans = psk->trans;
		trans->buf_size = ptrans->buf_size;
		trans->buf_size_min = ptrans->buf_size_min;
		trans->buf_size_max = ptrans->buf_size_max;
		trans->peer_buf_alloc = ptrans->peer_buf_alloc;
	} else {
		trans->buf_size = VIRTIO_VSOCK_DEFAULT_BUF_SIZE;
		trans->buf_size_min = VIRTIO_VSOCK_DEFAULT_MIN_BUF_SIZE;
		trans->buf_size_max = VIRTIO_VSOCK_DEFAULT_MAX_BUF_SIZE;
	}

	/* buf_alloc is what we advertise to the peer in every header. */
	trans->buf_alloc = trans->buf_size;

	pr_debug("%s: trans->buf_alloc=%d\n", __func__, trans->buf_alloc);

	mutex_init(&trans->rx_lock);
	mutex_init(&trans->tx_lock);
	INIT_LIST_HEAD(&trans->rx_queue);
	INIT_LIST_HEAD(&trans->incomplete_dgrams);

	return 0;
}
EXPORT_SYMBOL_GPL(virtio_transport_do_socket_init);
564
565u64 virtio_transport_get_buffer_size(struct vsock_sock *vsk)
566{
567 struct virtio_transport *trans = vsk->trans;
568
569 return trans->buf_size;
570}
571EXPORT_SYMBOL_GPL(virtio_transport_get_buffer_size);
572
573u64 virtio_transport_get_min_buffer_size(struct vsock_sock *vsk)
574{
575 struct virtio_transport *trans = vsk->trans;
576
577 return trans->buf_size_min;
578}
579EXPORT_SYMBOL_GPL(virtio_transport_get_min_buffer_size);
580
581u64 virtio_transport_get_max_buffer_size(struct vsock_sock *vsk)
582{
583 struct virtio_transport *trans = vsk->trans;
584
585 return trans->buf_size_max;
586}
587EXPORT_SYMBOL_GPL(virtio_transport_get_max_buffer_size);
588
/* Set the buffer size, clamped to the global maximum.  The min/max
 * bounds are widened to keep the invariant min <= size <= max, and
 * the advertised buf_alloc follows the new size.
 */
void virtio_transport_set_buffer_size(struct vsock_sock *vsk, u64 val)
{
	struct virtio_transport *trans = vsk->trans;

	if (val > VIRTIO_VSOCK_MAX_BUF_SIZE)
		val = VIRTIO_VSOCK_MAX_BUF_SIZE;
	if (val < trans->buf_size_min)
		trans->buf_size_min = val;
	if (val > trans->buf_size_max)
		trans->buf_size_max = val;
	trans->buf_size = val;
	trans->buf_alloc = val;
}
EXPORT_SYMBOL_GPL(virtio_transport_set_buffer_size);
603
/* Raise the minimum buffer size (clamped to the global maximum); the
 * current size is pulled up if it falls below the new minimum.
 * NOTE(review): buf_alloc is not updated here, unlike set_buffer_size().
 */
void virtio_transport_set_min_buffer_size(struct vsock_sock *vsk, u64 val)
{
	struct virtio_transport *trans = vsk->trans;

	if (val > VIRTIO_VSOCK_MAX_BUF_SIZE)
		val = VIRTIO_VSOCK_MAX_BUF_SIZE;
	if (val > trans->buf_size)
		trans->buf_size = val;
	trans->buf_size_min = val;
}
EXPORT_SYMBOL_GPL(virtio_transport_set_min_buffer_size);
615
/* Lower the maximum buffer size (clamped to the global maximum); the
 * current size is pulled down if it exceeds the new maximum.
 * NOTE(review): buf_alloc is not updated here, unlike set_buffer_size().
 */
void virtio_transport_set_max_buffer_size(struct vsock_sock *vsk, u64 val)
{
	struct virtio_transport *trans = vsk->trans;

	if (val > VIRTIO_VSOCK_MAX_BUF_SIZE)
		val = VIRTIO_VSOCK_MAX_BUF_SIZE;
	if (val < trans->buf_size)
		trans->buf_size = val;
	trans->buf_size_max = val;
}
EXPORT_SYMBOL_GPL(virtio_transport_set_max_buffer_size);
627
628int
629virtio_transport_notify_poll_in(struct vsock_sock *vsk,
630 size_t target,
631 bool *data_ready_now)
632{
633 if (vsock_stream_has_data(vsk))
634 *data_ready_now = true;
635 else
636 *data_ready_now = false;
637
638 return 0;
639}
640EXPORT_SYMBOL_GPL(virtio_transport_notify_poll_in);
641
/* poll() writability callback.
 * NOTE(review): when vsock_stream_has_space() returns a negative value,
 * *space_avail_now is left untouched — confirm callers pre-initialize it.
 */
int
virtio_transport_notify_poll_out(struct vsock_sock *vsk,
				 size_t target,
				 bool *space_avail_now)
{
	s64 free_space;

	free_space = vsock_stream_has_space(vsk);
	if (free_space > 0)
		*space_avail_now = true;
	else if (free_space == 0)
		*space_avail_now = false;

	return 0;
}
EXPORT_SYMBOL_GPL(virtio_transport_notify_poll_out);
658
/* The remaining notify_* callbacks are intentional no-ops: this
 * transport needs no bookkeeping around the vsock core's receive/send
 * progress notifications, but the vsock_transport interface requires
 * the hooks to exist.
 */
int virtio_transport_notify_recv_init(struct vsock_sock *vsk,
	size_t target, struct vsock_transport_recv_notify_data *data)
{
	return 0;
}
EXPORT_SYMBOL_GPL(virtio_transport_notify_recv_init);

int virtio_transport_notify_recv_pre_block(struct vsock_sock *vsk,
	size_t target, struct vsock_transport_recv_notify_data *data)
{
	return 0;
}
EXPORT_SYMBOL_GPL(virtio_transport_notify_recv_pre_block);

int virtio_transport_notify_recv_pre_dequeue(struct vsock_sock *vsk,
	size_t target, struct vsock_transport_recv_notify_data *data)
{
	return 0;
}
EXPORT_SYMBOL_GPL(virtio_transport_notify_recv_pre_dequeue);

int virtio_transport_notify_recv_post_dequeue(struct vsock_sock *vsk,
	size_t target, ssize_t copied, bool data_read,
	struct vsock_transport_recv_notify_data *data)
{
	return 0;
}
EXPORT_SYMBOL_GPL(virtio_transport_notify_recv_post_dequeue);

int virtio_transport_notify_send_init(struct vsock_sock *vsk,
	struct vsock_transport_send_notify_data *data)
{
	return 0;
}
EXPORT_SYMBOL_GPL(virtio_transport_notify_send_init);

int virtio_transport_notify_send_pre_block(struct vsock_sock *vsk,
	struct vsock_transport_send_notify_data *data)
{
	return 0;
}
EXPORT_SYMBOL_GPL(virtio_transport_notify_send_pre_block);

int virtio_transport_notify_send_pre_enqueue(struct vsock_sock *vsk,
	struct vsock_transport_send_notify_data *data)
{
	return 0;
}
EXPORT_SYMBOL_GPL(virtio_transport_notify_send_pre_enqueue);

int virtio_transport_notify_send_post_enqueue(struct vsock_sock *vsk,
	ssize_t written, struct vsock_transport_send_notify_data *data)
{
	return 0;
}
EXPORT_SYMBOL_GPL(virtio_transport_notify_send_post_enqueue);
715
/* Receive high-watermark: the whole configured buffer size. */
u64 virtio_transport_stream_rcvhiwat(struct vsock_sock *vsk)
{
	struct virtio_transport *trans = vsk->trans;

	return trans->buf_size;
}
EXPORT_SYMBOL_GPL(virtio_transport_stream_rcvhiwat);
723
/* Virtio streams are always considered active. */
bool virtio_transport_stream_is_active(struct vsock_sock *vsk)
{
	return true;
}
EXPORT_SYMBOL_GPL(virtio_transport_stream_is_active);

/* No cid/port restrictions for stream sockets. */
bool virtio_transport_stream_allow(u32 cid, u32 port)
{
	return true;
}
EXPORT_SYMBOL_GPL(virtio_transport_stream_allow);

/* Datagram bind is handled entirely by the generic vsock core. */
int virtio_transport_dgram_bind(struct vsock_sock *vsk,
				struct sockaddr_vm *addr)
{
	return vsock_bind_dgram_generic(vsk, addr);
}
EXPORT_SYMBOL_GPL(virtio_transport_dgram_bind);

/* No cid/port restrictions for datagram sockets. */
bool virtio_transport_dgram_allow(u32 cid, u32 port)
{
	return true;
}
EXPORT_SYMBOL_GPL(virtio_transport_dgram_allow);
748
/* Initiate a stream connection by sending an OP_REQUEST to the peer.
 * The handshake continues in virtio_transport_recv_connecting().
 */
int virtio_transport_connect(struct vsock_sock *vsk)
{
	struct virtio_transport *trans = vsk->trans;
	struct virtio_vsock_pkt_info info = {
		.op = VIRTIO_VSOCK_OP_REQUEST,
		.type = VIRTIO_VSOCK_TYPE_STREAM,
	};

	pr_debug("%s: vsk=%p send_request\n", __func__, vsk);
	return trans->ops->send_pkt(vsk, &info);
}
EXPORT_SYMBOL_GPL(virtio_transport_connect);
761
/* Send an OP_SHUTDOWN announcing which directions (@mode: RCV_SHUTDOWN
 * and/or SEND_SHUTDOWN) we are closing.
 */
int virtio_transport_shutdown(struct vsock_sock *vsk, int mode)
{
	struct virtio_transport *trans = vsk->trans;
	struct virtio_vsock_pkt_info info = {
		.op = VIRTIO_VSOCK_OP_SHUTDOWN,
		.type = VIRTIO_VSOCK_TYPE_STREAM,
		/* Translate kernel shutdown bits to wire flags. */
		.flags = (mode & RCV_SHUTDOWN ?
			  VIRTIO_VSOCK_SHUTDOWN_RCV : 0) |
			 (mode & SEND_SHUTDOWN ?
			  VIRTIO_VSOCK_SHUTDOWN_SEND : 0),
	};

	pr_debug("%s: vsk=%p: send_shutdown\n", __func__, vsk);
	return trans->ops->send_pkt(vsk, &info);
}
EXPORT_SYMBOL_GPL(virtio_transport_shutdown);
778
/* Socket release hook: notify the peer (connected streams only) and free
 * any half-assembled datagrams, dropping the sock reference each one
 * holds (taken in virtio_transport_recv_dgram()).
 */
void virtio_transport_release(struct vsock_sock *vsk)
{
	struct virtio_transport *trans = vsk->trans;
	struct sock *sk = &vsk->sk;
	struct dgram_skb *dgram_skb;
	struct dgram_skb *dgram_skb_tmp;

	pr_debug("%s: vsk=%p\n", __func__, vsk);

	/* Tell other side to terminate connection */
	if (sk->sk_type == SOCK_STREAM && sk->sk_state == SS_CONNECTED) {
		virtio_transport_shutdown(vsk, SHUTDOWN_MASK);
	}

	/* Free incomplete dgrams */
	lock_sock(sk);
	list_for_each_entry_safe(dgram_skb, dgram_skb_tmp,
				 &trans->incomplete_dgrams, list) {
		list_del(&dgram_skb->list);
		kfree_skb(dgram_skb->skb);
		kfree(dgram_skb);
		sock_put(sk); /* held in virtio_transport_recv_dgram() */
	}
	release_sock(sk);
}
EXPORT_SYMBOL_GPL(virtio_transport_release);
805
806int
807virtio_transport_dgram_enqueue(struct vsock_sock *vsk,
808 struct sockaddr_vm *remote_addr,
809 struct msghdr *msg,
810 size_t dgram_len)
811{
812 struct virtio_transport *trans = vsk->trans;
813 struct virtio_vsock_pkt_info info = {
814 .op = VIRTIO_VSOCK_OP_RW,
815 .type = VIRTIO_VSOCK_TYPE_DGRAM,
816 .msg = msg,
817 };
818 size_t total_written = 0, pkt_off = 0, written;
819 u16 dgram_id;
820
821 /* The max size of a single dgram we support is 64KB */
822 if (dgram_len > VIRTIO_VSOCK_MAX_DGRAM_SIZE)
823 return -EMSGSIZE;
824
825 info.dgram_len = dgram_len;
826 vsk->remote_addr = *remote_addr;
827
828 dgram_id = trans->dgram_id++;
829
830 /* TODO: To optimize, if we have enough credit to send the pkt already,
831 * do not ask the peer to send credit to use */
832 virtio_transport_send_credit_request(vsk, VIRTIO_VSOCK_TYPE_DGRAM);
833
834 while (total_written < dgram_len) {
835 info.pkt_len = dgram_len - total_written;
836 info.flags = dgram_id << 16 | pkt_off;
837 written = trans->ops->send_pkt(vsk, &info);
838 if (written < 0)
839 return -ENOMEM;
840 if (written == 0) {
841 /* TODO: if written = 0, we need a sleep & wakeup
842 * instead of sleep */
843 pr_debug("%s: SHOULD WAIT written==0", __func__);
844 msleep(10);
845 }
846 total_written += written;
847 pkt_off += written;
848 pr_debug("%s:id=%d, dgram_len=%zu, off=%zu, total_written=%zu, written=%zu\n",
849 __func__, dgram_id, dgram_len, pkt_off, total_written, written);
850 }
851
852 return dgram_len;
853}
854EXPORT_SYMBOL_GPL(virtio_transport_dgram_enqueue);
855
/* Queue up to @len bytes of stream data from @msg for transmission.
 * Returns the number of bytes accepted by the transport, or a negative
 * errno from send_pkt().
 */
ssize_t
virtio_transport_stream_enqueue(struct vsock_sock *vsk,
				struct msghdr *msg,
				size_t len)
{
	struct virtio_transport *trans = vsk->trans;
	struct virtio_vsock_pkt_info info = {
		.op = VIRTIO_VSOCK_OP_RW,
		.type = VIRTIO_VSOCK_TYPE_STREAM,
		.msg = msg,
		.pkt_len = len,
	};

	return trans->ops->send_pkt(vsk, &info);
}
EXPORT_SYMBOL_GPL(virtio_transport_stream_enqueue);
872
/* Free the per-socket transport state allocated in
 * virtio_transport_do_socket_init().
 */
void virtio_transport_destruct(struct vsock_sock *vsk)
{
	struct virtio_transport *trans = vsk->trans;

	pr_debug("%s: vsk=%p\n", __func__, vsk);
	kfree(trans);
}
EXPORT_SYMBOL_GPL(virtio_transport_destruct);
881
882static int virtio_transport_send_ack(struct vsock_sock *vsk, u32 cookie)
883{
884 struct virtio_transport *trans = vsk->trans;
885 struct virtio_vsock_pkt_info info = {
886 .op = VIRTIO_VSOCK_OP_ACK,
887 .type = VIRTIO_VSOCK_TYPE_STREAM,
888 .flags = cpu_to_le32(cookie),
889 };
890
891 pr_debug("%s: sk=%p send_offer\n", __func__, vsk);
892 return trans->ops->send_pkt(vsk, &info);
893}
894
/* Send an OP_RST in response to @pkt — unless @pkt itself is an RST,
 * to avoid RST ping-pong between the two sides.
 */
static int virtio_transport_send_reset(struct vsock_sock *vsk,
				       struct virtio_vsock_pkt *pkt)
{
	struct virtio_transport *trans = vsk->trans;
	struct virtio_vsock_pkt_info info = {
		.op = VIRTIO_VSOCK_OP_RST,
		.type = VIRTIO_VSOCK_TYPE_STREAM,
	};

	pr_debug("%s\n", __func__);

	/* Send RST only if the original pkt is not a RST pkt */
	if (le16_to_cpu(pkt->hdr.op) == VIRTIO_VSOCK_OP_RST)
		return 0;

	return trans->ops->send_pkt(vsk, &info);
}
912
/* Handle a packet on a socket in the SS_CONNECTING state (client side).
 *
 * OP_RESPONSE carries the server's cookie in hdr.flags; we echo it in an
 * OP_ACK and move to SS_CONNECTED.  OP_RST, send failure, or an unknown
 * op tears the connection down with an error reported to the socket.
 */
static int
virtio_transport_recv_connecting(struct sock *sk,
				 struct virtio_vsock_pkt *pkt)
{
	struct vsock_sock *vsk = vsock_sk(sk);
	int err;
	int skerr;
	u32 cookie;

	pr_debug("%s: vsk=%p\n", __func__, vsk);
	switch (le16_to_cpu(pkt->hdr.op)) {
	case VIRTIO_VSOCK_OP_RESPONSE:
		cookie = le32_to_cpu(pkt->hdr.flags);
		pr_debug("%s: got RESPONSE and send ACK, cookie=%x\n", __func__, cookie);
		err = virtio_transport_send_ack(vsk, cookie);
		if (err < 0) {
			skerr = -err;
			goto destroy;
		}
		sk->sk_state = SS_CONNECTED;
		sk->sk_socket->state = SS_CONNECTED;
		vsock_insert_connected(vsk);
		sk->sk_state_change(sk);
		break;
	case VIRTIO_VSOCK_OP_INVALID:
		pr_debug("%s: got invalid\n", __func__);
		break;
	case VIRTIO_VSOCK_OP_RST:
		pr_debug("%s: got rst\n", __func__);
		skerr = ECONNRESET;
		err = 0;
		goto destroy;
	default:
		pr_debug("%s: got def\n", __func__);
		skerr = EPROTO;
		err = -EINVAL;
		goto destroy;
	}
	return 0;

destroy:
	/* Reset the peer, mark the socket failed and wake any waiters. */
	virtio_transport_send_reset(vsk, pkt);
	sk->sk_state = SS_UNCONNECTED;
	sk->sk_err = skerr;
	sk->sk_error_report(sk);
	return err;
}
960
/* Handle a packet on an established (SS_CONNECTED) stream socket.
 *
 * OP_RW packets are appended to the rx_queue (ownership transfers to the
 * queue — note the early return that skips the free at the bottom); all
 * other ops are control packets, handled and freed here.
 */
static int
virtio_transport_recv_connected(struct sock *sk,
				struct virtio_vsock_pkt *pkt)
{
	struct vsock_sock *vsk = vsock_sk(sk);
	struct virtio_transport *trans = vsk->trans;
	int err = 0;

	switch (le16_to_cpu(pkt->hdr.op)) {
	case VIRTIO_VSOCK_OP_RW:
		pkt->len = le32_to_cpu(pkt->hdr.len);
		pkt->off = 0;
		pkt->trans = trans;

		mutex_lock(&trans->rx_lock);
		virtio_transport_inc_rx_pkt(pkt);
		list_add_tail(&pkt->list, &trans->rx_queue);
		mutex_unlock(&trans->rx_lock);

		sk->sk_data_ready(sk);
		return err;
	case VIRTIO_VSOCK_OP_CREDIT_UPDATE:
		/* Credit fields were already absorbed by space_update(). */
		sk->sk_write_space(sk);
		break;
	case VIRTIO_VSOCK_OP_SHUTDOWN:
		pr_debug("%s: got shutdown\n", __func__);
		if (le32_to_cpu(pkt->hdr.flags) & VIRTIO_VSOCK_SHUTDOWN_RCV)
			vsk->peer_shutdown |= RCV_SHUTDOWN;
		if (le32_to_cpu(pkt->hdr.flags) & VIRTIO_VSOCK_SHUTDOWN_SEND)
			vsk->peer_shutdown |= SEND_SHUTDOWN;
		if (le32_to_cpu(pkt->hdr.flags))
			sk->sk_state_change(sk);
		break;
	case VIRTIO_VSOCK_OP_RST:
		pr_debug("%s: got rst\n", __func__);
		sock_set_flag(sk, SOCK_DONE);
		vsk->peer_shutdown = SHUTDOWN_MASK;
		/* Keep the socket readable until queued data is drained. */
		if (vsock_stream_has_data(vsk) <= 0)
			sk->sk_state = SS_DISCONNECTING;
		sk->sk_state_change(sk);
		break;
	default:
		err = -EINVAL;
		break;
	}

	virtio_transport_free_pkt(pkt);
	return err;
}
1010
1011static int
1012virtio_transport_send_response(struct vsock_sock *vsk,
1013 struct virtio_vsock_pkt *pkt)
1014{
1015 struct virtio_transport *trans = vsk->trans;
1016 struct virtio_vsock_pkt_info info = {
1017 .op = VIRTIO_VSOCK_OP_RESPONSE,
1018 .type = VIRTIO_VSOCK_TYPE_STREAM,
1019 .remote_cid = le32_to_cpu(pkt->hdr.src_cid),
1020 .remote_port = le32_to_cpu(pkt->hdr.src_port),
1021 };
1022 u32 cookie;
1023
1024 cookie = virtio_vsock_secure_cookie(le32_to_cpu(pkt->hdr.src_cid),
1025 le32_to_cpu(pkt->hdr.dst_cid),
1026 le32_to_cpu(pkt->hdr.src_port),
1027 le32_to_cpu(pkt->hdr.dst_port),
1028 jiffies / (HZ * 60));
1029 info.flags = cpu_to_le32(cookie);
1030
1031 pr_debug("%s: send_response, cookie=%x\n", __func__, le32_to_cpu(cookie));
1032
1033 return trans->ops->send_pkt(vsk, &info);
1034}
1035
/* Handle a packet on a listening socket (server side).
 *
 * OP_REQUEST: reply with OP_RESPONSE carrying a cookie; no child socket
 * is created yet (stateless, SYN-cookie style).
 * OP_ACK: validate the echoed cookie against the 4-tuple and a coarse
 * timestamp; on success create the pending child socket, mark it
 * connected and queue it for accept().
 */
static int
virtio_transport_recv_listen(struct sock *sk, struct virtio_vsock_pkt *pkt)
{
	struct vsock_sock *vsk = vsock_sk(sk);
	struct vsock_sock *vpending;
	struct sock *pending;
	int err;
	u32 cookie;

	switch (le16_to_cpu(pkt->hdr.op)) {
	case VIRTIO_VSOCK_OP_REQUEST:
		err = virtio_transport_send_response(vsk, pkt);
		if (err < 0) {
			// FIXME vsk should be vpending
			virtio_transport_send_reset(vsk, pkt);
			return err;
		}
		break;
	case VIRTIO_VSOCK_OP_ACK:
		cookie = le32_to_cpu(pkt->hdr.flags);
		/* Non-zero means forged or expired (> VSOCK_TIMEOUT_INIT
		 * one-minute epochs old) — drop without creating state.
		 */
		err = virtio_vsock_check_cookie(le32_to_cpu(pkt->hdr.src_cid),
						le32_to_cpu(pkt->hdr.dst_cid),
						le32_to_cpu(pkt->hdr.src_port),
						le32_to_cpu(pkt->hdr.dst_port),
						jiffies / (HZ * 60),
						le32_to_cpu(pkt->hdr.flags),
						VSOCK_TIMEOUT_INIT);
		pr_debug("%s: cookie=%x, err=%d\n", __func__, cookie, err);
		if (err)
			return err;

		/* So no pending socket are responsible for this pkt, create one */
		pr_debug("%s: create pending\n", __func__);
		pending = __vsock_create(sock_net(sk), NULL, sk, GFP_KERNEL,
					 sk->sk_type, 0);
		if (!pending) {
			virtio_transport_send_reset(vsk, pkt);
			return -ENOMEM;
		}
		sk->sk_ack_backlog++;
		pending->sk_state = SS_CONNECTING;

		vpending = vsock_sk(pending);
		vsock_addr_init(&vpending->local_addr, le32_to_cpu(pkt->hdr.dst_cid),
				le32_to_cpu(pkt->hdr.dst_port));
		vsock_addr_init(&vpending->remote_addr, le32_to_cpu(pkt->hdr.src_cid),
				le32_to_cpu(pkt->hdr.src_port));
		vsock_add_pending(sk, pending);

		/* Re-fetch (with a reference) the socket we just added.
		 * NOTE(review): the return value is not NULL-checked; if
		 * the lookup ever misses, lock_sock(NULL) would oops.
		 */
		pr_debug("%s: get pending\n", __func__);
		pending = virtio_transport_get_pending(sk, pkt);
		vpending = vsock_sk(pending);
		lock_sock(pending);
		switch (pending->sk_state) {
		case SS_CONNECTING:
			/* NOTE(review): this branch can only see OP_ACK
			 * here (checked above), so the else path is the
			 * one actually taken.
			 */
			if (le16_to_cpu(pkt->hdr.op) != VIRTIO_VSOCK_OP_ACK) {
				pr_debug("%s: op=%d != OP_ACK\n", __func__,
					 le16_to_cpu(pkt->hdr.op));
				virtio_transport_send_reset(vpending, pkt);
				pending->sk_err = EPROTO;
				pending->sk_state = SS_UNCONNECTED;
				sock_put(pending);
			} else {
				pending->sk_state = SS_CONNECTED;
				vsock_insert_connected(vpending);

				vsock_remove_pending(sk, pending);
				vsock_enqueue_accept(sk, pending);

				sk->sk_data_ready(sk);
			}
			err = 0;
			break;
		default:
			pr_debug("%s: sk->sk_ack_backlog=%d\n", __func__,
				 sk->sk_ack_backlog);
			virtio_transport_send_reset(vpending, pkt);
			err = -EINVAL;
			break;
		}
		if (err < 0)
			vsock_remove_pending(sk, pending);
		release_sock(pending);

		/* Release refcnt obtained in virtio_transport_get_pending */
		sock_put(pending);
		break;
	default:
		break;
	}

	return 0;
}
1130
/* Absorb the credit fields that every incoming header carries into our
 * view of the peer's receive capacity, and wake writers if that opened
 * up send space.
 */
static void virtio_transport_space_update(struct sock *sk,
					  struct virtio_vsock_pkt *pkt)
{
	struct vsock_sock *vsk = vsock_sk(sk);
	struct virtio_transport *trans = vsk->trans;
	bool space_available;

	/* buf_alloc and fwd_cnt is always included in the hdr */
	mutex_lock(&trans->tx_lock);
	trans->peer_buf_alloc = le32_to_cpu(pkt->hdr.buf_alloc);
	trans->peer_fwd_cnt = le32_to_cpu(pkt->hdr.fwd_cnt);
	space_available = virtio_transport_has_space(vsk);
	mutex_unlock(&trans->tx_lock);

	if (space_available)
		sk->sk_write_space(sk);
}
1148
/* Top-level receive path: demultiplex an incoming packet to the right
 * socket and dispatch on packet type and socket state.
 *
 * We are under the virtio-vsock's vsock->rx_lock or vhost-vsock's
 * vq->mutex lock, so this runs serialized per device.
 *
 * Packet ownership: every path below either frees @pkt or hands it to a
 * handler.  Note the asymmetry in the switch arms: OP_RW (dgram) and
 * SS_CONNECTED (stream) do NOT free the packet here — the called
 * handler is presumably expected to consume it; confirm against
 * virtio_transport_recv_dgram()/virtio_transport_recv_connected().
 */
void virtio_transport_recv_pkt(struct virtio_vsock_pkt *pkt)
{
	struct virtio_transport *trans;
	struct sockaddr_vm src, dst;
	struct vsock_sock *vsk;
	struct sock *sk;

	/* Addresses come from the little-endian wire header */
	vsock_addr_init(&src, le32_to_cpu(pkt->hdr.src_cid), le32_to_cpu(pkt->hdr.src_port));
	vsock_addr_init(&dst, le32_to_cpu(pkt->hdr.dst_cid), le32_to_cpu(pkt->hdr.dst_port));

	virtio_vsock_dumppkt(__func__, pkt);

	if (le16_to_cpu(pkt->hdr.type) == VIRTIO_VSOCK_TYPE_DGRAM) {
		/* Datagrams are matched on the destination address only;
		 * the lookup takes a reference dropped via sock_put() below.
		 */
		sk = vsock_find_unbound_socket(&dst);
		if (!sk)
			goto free_pkt;

		vsk = vsock_sk(sk);
		trans = vsk->trans;
		BUG_ON(!trans);

		/* Absorb the credit info carried in every header before
		 * dispatching on the operation.
		 */
		virtio_transport_space_update(sk, pkt);

		lock_sock(sk);
		switch (le16_to_cpu(pkt->hdr.op)) {
		case VIRTIO_VSOCK_OP_CREDIT_UPDATE:
			/* Credit was already absorbed by space_update above */
			virtio_transport_free_pkt(pkt);
			break;
		case VIRTIO_VSOCK_OP_CREDIT_REQUEST:
			virtio_transport_send_credit_update(vsk, VIRTIO_VSOCK_TYPE_DGRAM,
							    &pkt->hdr);
			virtio_transport_free_pkt(pkt);
			break;
		case VIRTIO_VSOCK_OP_RW:
			/* Payload; pkt is not freed here — recv_dgram is
			 * expected to consume it.
			 */
			virtio_transport_recv_dgram(sk, pkt);
			break;
		default:
			virtio_transport_free_pkt(pkt);
			break;
		}
		release_sock(sk);

		/* Release refcnt obtained when we fetched this socket out of
		 * the unbound list.
		 */
		sock_put(sk);
		return;
	} else if (le16_to_cpu(pkt->hdr.type) == VIRTIO_VSOCK_TYPE_STREAM) {
		/* The socket must be in connected or bound table
		 * otherwise send reset back
		 */
		sk = vsock_find_connected_socket(&src, &dst);
		if (!sk) {
			sk = vsock_find_bound_socket(&dst);
			if (!sk) {
				pr_debug("%s: can not find bound_socket\n", __func__);
				virtio_vsock_dumppkt(__func__, pkt);
				/* Ignore this pkt instead of sending reset back */
				/* TODO send a RST unless this packet is a RST (to avoid infinite loops) */
				goto free_pkt;
			}
		}

		vsk = vsock_sk(sk);
		trans = vsk->trans;
		BUG_ON(!trans);

		virtio_transport_space_update(sk, pkt);

		/* Dispatch on the socket's connection state under the sock
		 * lock; state may have changed since the table lookup.
		 */
		lock_sock(sk);
		switch (sk->sk_state) {
		case VSOCK_SS_LISTEN:
			virtio_transport_recv_listen(sk, pkt);
			virtio_transport_free_pkt(pkt);
			break;
		case SS_CONNECTING:
			virtio_transport_recv_connecting(sk, pkt);
			virtio_transport_free_pkt(pkt);
			break;
		case SS_CONNECTED:
			/* pkt ownership passes to recv_connected (not freed
			 * here, unlike the other arms).
			 */
			virtio_transport_recv_connected(sk, pkt);
			break;
		default:
			virtio_transport_free_pkt(pkt);
			break;
		}
		release_sock(sk);

		/* Release refcnt obtained when we fetched this socket out of the
		 * bound or connected list.
		 */
		sock_put(sk);
	}
	return;

free_pkt:
	virtio_transport_free_pkt(pkt);
}
EXPORT_SYMBOL_GPL(virtio_transport_recv_pkt);
1250
/* Free a packet and its payload buffer.
 *
 * The payload must be released before the packet itself, since the
 * buffer pointer lives inside @pkt.  kfree(NULL) is a no-op, so a
 * packet without an allocated buf is handled fine.
 */
void virtio_transport_free_pkt(struct virtio_vsock_pkt *pkt)
{
	kfree(pkt->buf);
	kfree(pkt);
}
EXPORT_SYMBOL_GPL(virtio_transport_free_pkt);
1257
/* Module init: seed the connection-cookie secret (used by the SHA-based
 * cookie hashing at the top of this file) with fresh random data.
 */
static int __init virtio_vsock_common_init(void)
{
	get_random_bytes(vsockcookie_secret, sizeof(vsockcookie_secret));
	return 0;
}
1263
/* Module exit: nothing to tear down — the cookie secret and per-CPU
 * scratch space are static storage.
 */
static void __exit virtio_vsock_common_exit(void)
{
}
1267
/* Module plumbing: entry/exit hooks and metadata. */
module_init(virtio_vsock_common_init);
module_exit(virtio_vsock_common_exit);
MODULE_LICENSE("GPL v2");
MODULE_AUTHOR("Asias He");
MODULE_DESCRIPTION("common code for virtio vsock");