]>
Commit | Line | Data |
---|---|---|
80a19e33 AH |
1 | /* |
2 | * common code for virtio vsock | |
3 | * | |
4 | * Copyright (C) 2013-2015 Red Hat, Inc. | |
5 | * Author: Asias He <asias@redhat.com> | |
6 | * Stefan Hajnoczi <stefanha@redhat.com> | |
7 | * | |
8 | * This work is licensed under the terms of the GNU GPL, version 2. | |
9 | */ | |
10 | #include <linux/module.h> | |
11 | #include <linux/ctype.h> | |
12 | #include <linux/list.h> | |
13 | #include <linux/virtio.h> | |
14 | #include <linux/virtio_ids.h> | |
15 | #include <linux/virtio_config.h> | |
16 | #include <linux/virtio_vsock.h> | |
17 | #include <linux/random.h> | |
18 | #include <linux/cryptohash.h> | |
19 | ||
20 | #include <net/sock.h> | |
21 | #include <net/af_vsock.h> | |
22 | ||
23 | #define COOKIEBITS 24 | |
24 | #define COOKIEMASK (((u32)1 << COOKIEBITS) - 1) | |
25 | #define VSOCK_TIMEOUT_INIT 4 | |
26 | ||
27 | #define SHA_MESSAGE_WORDS 16 | |
28 | #define SHA_VSOCK_WORDS 5 | |
29 | ||
30 | static u32 vsockcookie_secret[2][SHA_MESSAGE_WORDS - SHA_VSOCK_WORDS + | |
31 | SHA_DIGEST_WORDS]; | |
32 | ||
33 | static DEFINE_PER_CPU(__u32[SHA_MESSAGE_WORDS + SHA_DIGEST_WORDS + | |
34 | SHA_WORKSPACE_WORDS], vsock_cookie_scratch); | |
35 | ||
36 | static u32 cookie_hash(u32 saddr, u32 daddr, u16 sport, u16 dport, | |
37 | u32 count, int c) | |
38 | { | |
39 | __u32 *tmp = this_cpu_ptr(vsock_cookie_scratch); | |
40 | ||
41 | memcpy(tmp + SHA_VSOCK_WORDS, vsockcookie_secret[c], | |
42 | sizeof(vsockcookie_secret[c])); | |
43 | tmp[0] = saddr; | |
44 | tmp[1] = daddr; | |
45 | tmp[2] = sport; | |
46 | tmp[3] = dport; | |
47 | tmp[4] = count; | |
48 | sha_transform(tmp + SHA_MESSAGE_WORDS, (__u8 *)tmp, | |
49 | tmp + SHA_MESSAGE_WORDS + SHA_DIGEST_WORDS); | |
50 | ||
51 | return tmp[17]; | |
52 | } | |
53 | ||
54 | static u32 | |
55 | virtio_vsock_secure_cookie(u32 saddr, u32 daddr, u32 sport, u32 dport, | |
56 | u32 count) | |
57 | { | |
58 | u32 h1, h2; | |
59 | ||
60 | h1 = cookie_hash(saddr, daddr, sport, dport, 0, 0); | |
61 | h2 = cookie_hash(saddr, daddr, sport, dport, count, 1); | |
62 | ||
63 | return h1 + (count << COOKIEBITS) + (h2 & COOKIEMASK); | |
64 | } | |
65 | ||
66 | static u32 | |
67 | virtio_vsock_check_cookie(u32 saddr, u32 daddr, u32 sport, u32 dport, | |
68 | u32 count, u32 cookie, u32 maxdiff) | |
69 | { | |
70 | u32 diff; | |
71 | u32 ret; | |
72 | ||
73 | cookie -= cookie_hash(saddr, daddr, sport, dport, 0, 0); | |
74 | ||
75 | diff = (count - (cookie >> COOKIEBITS)) & ((u32)-1 >> COOKIEBITS); | |
76 | pr_debug("%s: diff=%x\n", __func__, diff); | |
77 | if (diff >= maxdiff) | |
78 | return (u32)-1; | |
79 | ||
80 | ret = (cookie - | |
81 | cookie_hash(saddr, daddr, sport, dport, count - diff, 1)) | |
82 | & COOKIEMASK; | |
83 | pr_debug("%s: ret=%x\n", __func__, diff); | |
84 | ||
85 | return ret; | |
86 | } | |
87 | ||
88 | void virtio_vsock_dumppkt(const char *func, const struct virtio_vsock_pkt *pkt) | |
89 | { | |
90 | pr_debug("%s: pkt=%p, op=%d, len=%d, %d:%d---%d:%d, len=%d\n", | |
91 | func, pkt, | |
92 | le16_to_cpu(pkt->hdr.op), | |
93 | le32_to_cpu(pkt->hdr.len), | |
94 | le32_to_cpu(pkt->hdr.src_cid), | |
95 | le32_to_cpu(pkt->hdr.src_port), | |
96 | le32_to_cpu(pkt->hdr.dst_cid), | |
97 | le32_to_cpu(pkt->hdr.dst_port), | |
98 | pkt->len); | |
99 | } | |
100 | EXPORT_SYMBOL_GPL(virtio_vsock_dumppkt); | |
101 | ||
102 | struct virtio_vsock_pkt * | |
103 | virtio_transport_alloc_pkt(struct vsock_sock *vsk, | |
104 | struct virtio_vsock_pkt_info *info, | |
105 | size_t len, | |
106 | u32 src_cid, | |
107 | u32 src_port, | |
108 | u32 dst_cid, | |
109 | u32 dst_port) | |
110 | { | |
111 | struct virtio_transport *trans = vsk->trans; | |
112 | struct virtio_vsock_pkt *pkt; | |
113 | int err; | |
114 | ||
115 | BUG_ON(!trans); | |
116 | ||
117 | pkt = kzalloc(sizeof(*pkt), GFP_KERNEL); | |
118 | if (!pkt) | |
119 | return NULL; | |
120 | ||
121 | pkt->hdr.type = cpu_to_le16(info->type); | |
122 | pkt->hdr.op = cpu_to_le16(info->op); | |
123 | pkt->hdr.src_cid = cpu_to_le32(src_cid); | |
124 | pkt->hdr.src_port = cpu_to_le32(src_port); | |
125 | pkt->hdr.dst_cid = cpu_to_le32(dst_cid); | |
126 | pkt->hdr.dst_port = cpu_to_le32(dst_port); | |
127 | pkt->hdr.flags = cpu_to_le32(info->flags); | |
128 | pkt->len = len; | |
129 | pkt->trans = trans; | |
130 | if (info->type == VIRTIO_VSOCK_TYPE_DGRAM) | |
131 | pkt->hdr.len = cpu_to_le32(len + (info->dgram_len << 16)); | |
132 | else if (info->type == VIRTIO_VSOCK_TYPE_STREAM) | |
133 | pkt->hdr.len = cpu_to_le32(len); | |
134 | ||
135 | if (info->msg && len > 0) { | |
136 | pkt->buf = kmalloc(len, GFP_KERNEL); | |
137 | if (!pkt->buf) | |
138 | goto out_pkt; | |
139 | err = memcpy_from_msg(pkt->buf, info->msg, len); | |
140 | if (err) | |
141 | goto out; | |
142 | } | |
143 | ||
144 | return pkt; | |
145 | ||
146 | out: | |
147 | kfree(pkt->buf); | |
148 | out_pkt: | |
149 | kfree(pkt); | |
150 | return NULL; | |
151 | } | |
152 | EXPORT_SYMBOL_GPL(virtio_transport_alloc_pkt); | |
153 | ||
154 | struct sock * | |
155 | virtio_transport_get_pending(struct sock *listener, | |
156 | struct virtio_vsock_pkt *pkt) | |
157 | { | |
158 | struct vsock_sock *vlistener; | |
159 | struct vsock_sock *vpending; | |
160 | struct sockaddr_vm src; | |
161 | struct sockaddr_vm dst; | |
162 | struct sock *pending; | |
163 | ||
164 | vsock_addr_init(&src, le32_to_cpu(pkt->hdr.src_cid), le32_to_cpu(pkt->hdr.src_port)); | |
165 | vsock_addr_init(&dst, le32_to_cpu(pkt->hdr.dst_cid), le32_to_cpu(pkt->hdr.dst_port)); | |
166 | ||
167 | vlistener = vsock_sk(listener); | |
168 | list_for_each_entry(vpending, &vlistener->pending_links, | |
169 | pending_links) { | |
170 | if (vsock_addr_equals_addr(&src, &vpending->remote_addr) && | |
171 | vsock_addr_equals_addr(&dst, &vpending->local_addr)) { | |
172 | pending = sk_vsock(vpending); | |
173 | sock_hold(pending); | |
174 | return pending; | |
175 | } | |
176 | } | |
177 | ||
178 | return NULL; | |
179 | } | |
180 | EXPORT_SYMBOL_GPL(virtio_transport_get_pending); | |
181 | ||
182 | static void virtio_transport_inc_rx_pkt(struct virtio_vsock_pkt *pkt) | |
183 | { | |
184 | pkt->trans->rx_bytes += pkt->len; | |
185 | } | |
186 | ||
187 | static void virtio_transport_dec_rx_pkt(struct virtio_vsock_pkt *pkt) | |
188 | { | |
189 | pkt->trans->rx_bytes -= pkt->len; | |
190 | pkt->trans->fwd_cnt += pkt->len; | |
191 | } | |
192 | ||
193 | void virtio_transport_inc_tx_pkt(struct virtio_vsock_pkt *pkt) | |
194 | { | |
195 | mutex_lock(&pkt->trans->tx_lock); | |
196 | pkt->hdr.fwd_cnt = cpu_to_le32(pkt->trans->fwd_cnt); | |
197 | pkt->hdr.buf_alloc = cpu_to_le32(pkt->trans->buf_alloc); | |
198 | mutex_unlock(&pkt->trans->tx_lock); | |
199 | } | |
200 | EXPORT_SYMBOL_GPL(virtio_transport_inc_tx_pkt); | |
201 | ||
/* Exported hook for TX-completion accounting; currently does nothing. */
void virtio_transport_dec_tx_pkt(struct virtio_vsock_pkt *pkt)
{
	/* intentionally empty */
}
EXPORT_SYMBOL_GPL(virtio_transport_dec_tx_pkt);
206 | ||
207 | u32 virtio_transport_get_credit(struct virtio_transport *trans, u32 credit) | |
208 | { | |
209 | u32 ret; | |
210 | ||
211 | mutex_lock(&trans->tx_lock); | |
212 | ret = trans->peer_buf_alloc - (trans->tx_cnt - trans->peer_fwd_cnt); | |
213 | if (ret > credit) | |
214 | ret = credit; | |
215 | trans->tx_cnt += ret; | |
216 | mutex_unlock(&trans->tx_lock); | |
217 | ||
218 | pr_debug("%s: ret=%d, buf_alloc=%d, peer_buf_alloc=%d," | |
219 | "tx_cnt=%d, fwd_cnt=%d, peer_fwd_cnt=%d\n", __func__, | |
220 | ret, trans->buf_alloc, trans->peer_buf_alloc, | |
221 | trans->tx_cnt, trans->fwd_cnt, trans->peer_fwd_cnt); | |
222 | ||
223 | return ret; | |
224 | } | |
225 | EXPORT_SYMBOL_GPL(virtio_transport_get_credit); | |
226 | ||
227 | void virtio_transport_put_credit(struct virtio_transport *trans, u32 credit) | |
228 | { | |
229 | mutex_lock(&trans->tx_lock); | |
230 | trans->tx_cnt -= credit; | |
231 | mutex_unlock(&trans->tx_lock); | |
232 | } | |
233 | EXPORT_SYMBOL_GPL(virtio_transport_put_credit); | |
234 | ||
235 | static int virtio_transport_send_credit_update(struct vsock_sock *vsk, int type, struct virtio_vsock_hdr *hdr) | |
236 | { | |
237 | struct virtio_transport *trans = vsk->trans; | |
238 | struct virtio_vsock_pkt_info info = { | |
239 | .op = VIRTIO_VSOCK_OP_CREDIT_UPDATE, | |
240 | .type = type, | |
241 | }; | |
242 | ||
243 | if (hdr && type == VIRTIO_VSOCK_TYPE_DGRAM) { | |
244 | info.remote_cid = le32_to_cpu(hdr->src_cid); | |
245 | info.remote_port = le32_to_cpu(hdr->src_port); | |
246 | } | |
247 | ||
248 | pr_debug("%s: sk=%p send_credit_update\n", __func__, vsk); | |
249 | return trans->ops->send_pkt(vsk, &info); | |
250 | } | |
251 | ||
252 | static int virtio_transport_send_credit_request(struct vsock_sock *vsk, int type) | |
253 | { | |
254 | struct virtio_transport *trans = vsk->trans; | |
255 | struct virtio_vsock_pkt_info info = { | |
256 | .op = VIRTIO_VSOCK_OP_CREDIT_REQUEST, | |
257 | .type = type, | |
258 | }; | |
259 | ||
260 | pr_debug("%s: sk=%p send_credit_request\n", __func__, vsk); | |
261 | return trans->ops->send_pkt(vsk, &info); | |
262 | } | |
263 | ||
264 | static ssize_t | |
265 | virtio_transport_stream_do_dequeue(struct vsock_sock *vsk, | |
266 | struct msghdr *msg, | |
267 | size_t len) | |
268 | { | |
269 | struct virtio_transport *trans = vsk->trans; | |
270 | struct virtio_vsock_pkt *pkt; | |
271 | size_t bytes, total = 0; | |
272 | int err = -EFAULT; | |
273 | ||
274 | mutex_lock(&trans->rx_lock); | |
275 | while (total < len && trans->rx_bytes > 0 && | |
276 | !list_empty(&trans->rx_queue)) { | |
277 | pkt = list_first_entry(&trans->rx_queue, | |
278 | struct virtio_vsock_pkt, list); | |
279 | ||
280 | bytes = len - total; | |
281 | if (bytes > pkt->len - pkt->off) | |
282 | bytes = pkt->len - pkt->off; | |
283 | ||
284 | err = memcpy_to_msg(msg, pkt->buf + pkt->off, bytes); | |
285 | if (err) | |
286 | goto out; | |
287 | total += bytes; | |
288 | pkt->off += bytes; | |
289 | if (pkt->off == pkt->len) { | |
290 | virtio_transport_dec_rx_pkt(pkt); | |
291 | list_del(&pkt->list); | |
292 | virtio_transport_free_pkt(pkt); | |
293 | } | |
294 | } | |
295 | mutex_unlock(&trans->rx_lock); | |
296 | ||
297 | /* Send a credit pkt to peer */ | |
298 | virtio_transport_send_credit_update(vsk, VIRTIO_VSOCK_TYPE_STREAM, | |
299 | NULL); | |
300 | ||
301 | return total; | |
302 | ||
303 | out: | |
304 | mutex_unlock(&trans->rx_lock); | |
305 | if (total) | |
306 | err = total; | |
307 | return err; | |
308 | } | |
309 | ||
310 | ssize_t | |
311 | virtio_transport_stream_dequeue(struct vsock_sock *vsk, | |
312 | struct msghdr *msg, | |
313 | size_t len, int flags) | |
314 | { | |
315 | if (flags & MSG_PEEK) | |
316 | return -EOPNOTSUPP; | |
317 | ||
318 | return virtio_transport_stream_do_dequeue(vsk, msg, len); | |
319 | } | |
320 | EXPORT_SYMBOL_GPL(virtio_transport_stream_dequeue); | |
321 | ||
322 | struct dgram_skb { | |
323 | struct list_head list; | |
324 | struct sk_buff *skb; | |
325 | u16 id; | |
326 | }; | |
327 | ||
328 | static struct dgram_skb *dgram_id_to_skb(struct virtio_transport *trans, | |
329 | u16 id) | |
330 | { | |
331 | struct dgram_skb *dgram_skb; | |
332 | ||
333 | list_for_each_entry(dgram_skb, &trans->incomplete_dgrams, list) { | |
334 | if (dgram_skb->id == id) | |
335 | return dgram_skb; | |
336 | } | |
337 | ||
338 | return NULL; | |
339 | } | |
340 | ||
341 | static void | |
342 | virtio_transport_recv_dgram(struct sock *sk, | |
343 | struct virtio_vsock_pkt *pkt) | |
344 | { | |
345 | struct sk_buff *skb = NULL; | |
346 | struct vsock_sock *vsk; | |
347 | struct virtio_transport *trans; | |
348 | size_t size; | |
349 | u16 dgram_id, pkt_off, dgram_len, pkt_len; | |
350 | u32 flags, len; | |
351 | struct dgram_skb *dgram_skb; | |
352 | ||
353 | vsk = vsock_sk(sk); | |
354 | trans = vsk->trans; | |
355 | ||
356 | /* len: dgram_len | pkt_len */ | |
357 | len = le32_to_cpu(pkt->hdr.len); | |
358 | dgram_len = len >> 16; | |
359 | pkt_len = len & 0xFFFF; | |
360 | ||
361 | /* flags: dgram_id | pkt_off */ | |
362 | flags = le32_to_cpu(pkt->hdr.flags); | |
363 | dgram_id = flags >> 16; | |
364 | pkt_off = flags & 0xFFFF; | |
365 | ||
366 | pr_debug("%s: dgram_len=%d, pkt_len=%d, id=%d, off=%d\n", __func__, | |
367 | dgram_len, pkt_len, dgram_id, pkt_off); | |
368 | ||
369 | dgram_skb = dgram_id_to_skb(trans, dgram_id); | |
370 | if (dgram_skb) { | |
371 | /* This pkt is for a existing dgram */ | |
372 | skb = dgram_skb->skb; | |
373 | pr_debug("%s:found skb\n", __func__); | |
374 | } | |
375 | ||
376 | /* Packet payload must be within datagram bounds */ | |
377 | if (pkt_len > VIRTIO_VSOCK_DEFAULT_RX_BUF_SIZE) | |
378 | goto drop; | |
379 | if (pkt_len > dgram_len) | |
380 | goto drop; | |
381 | if (pkt_off > dgram_len) | |
382 | goto drop; | |
383 | if (dgram_len - pkt_off < pkt_len) | |
384 | goto drop; | |
385 | ||
386 | if (!skb) { | |
387 | /* This pkt is for a new dgram */ | |
388 | pr_debug("%s:create skb\n", __func__); | |
389 | ||
390 | size = sizeof(pkt->hdr) + dgram_len; | |
391 | /* Attach the packet to the socket's receive queue as an sk_buff. */ | |
392 | dgram_skb = kzalloc(sizeof(struct dgram_skb), GFP_ATOMIC); | |
393 | if (!dgram_skb) | |
394 | goto drop; | |
395 | ||
396 | skb = alloc_skb(size, GFP_ATOMIC); | |
397 | if (!skb) { | |
398 | kfree(dgram_skb); | |
399 | dgram_skb = NULL; | |
400 | goto drop; | |
401 | } | |
402 | dgram_skb->id = dgram_id; | |
403 | dgram_skb->skb = skb; | |
404 | list_add_tail(&dgram_skb->list, &trans->incomplete_dgrams); | |
405 | ||
406 | /* sk_receive_skb() will do a sock_put(), so hold here. */ | |
407 | sock_hold(sk); | |
408 | skb_put(skb, size); | |
409 | memcpy(skb->data, &pkt->hdr, sizeof(pkt->hdr)); | |
410 | } | |
411 | ||
412 | memcpy(skb->data + sizeof(pkt->hdr) + pkt_off, pkt->buf, pkt_len); | |
413 | ||
414 | pr_debug("%s:C, off=%d, pkt_len=%d, dgram_len=%d\n", __func__, | |
415 | pkt_off, pkt_len, dgram_len); | |
416 | ||
417 | /* We are done with this dgram */ | |
418 | if (pkt_off + pkt_len == dgram_len) { | |
419 | pr_debug("%s:dgram_id=%d is done\n", __func__, dgram_id); | |
420 | list_del(&dgram_skb->list); | |
421 | kfree(dgram_skb); | |
422 | sk_receive_skb(sk, skb, 0); | |
423 | } | |
424 | virtio_transport_free_pkt(pkt); | |
425 | return; | |
426 | ||
427 | drop: | |
428 | if (dgram_skb) { | |
429 | list_del(&dgram_skb->list); | |
430 | kfree(dgram_skb); | |
431 | kfree_skb(skb); | |
432 | sock_put(sk); | |
433 | } | |
434 | virtio_transport_free_pkt(pkt); | |
435 | } | |
436 | ||
437 | int | |
438 | virtio_transport_dgram_dequeue(struct vsock_sock *vsk, | |
439 | struct msghdr *msg, | |
440 | size_t len, int flags) | |
441 | { | |
442 | struct virtio_vsock_hdr *hdr; | |
443 | struct sk_buff *skb; | |
444 | int noblock; | |
445 | int err; | |
446 | int dgram_len; | |
447 | ||
448 | noblock = flags & MSG_DONTWAIT; | |
449 | ||
450 | if (flags & MSG_OOB || flags & MSG_ERRQUEUE) | |
451 | return -EOPNOTSUPP; | |
452 | ||
453 | /* Retrieve the head sk_buff from the socket's receive queue. */ | |
454 | err = 0; | |
455 | skb = skb_recv_datagram(&vsk->sk, flags, noblock, &err); | |
456 | if (err) | |
457 | return err; | |
458 | if (!skb) | |
459 | return -EAGAIN; | |
460 | ||
461 | hdr = (struct virtio_vsock_hdr *)skb->data; | |
462 | if (!hdr) | |
463 | goto out; | |
464 | ||
465 | dgram_len = le32_to_cpu(hdr->len) >> 16; | |
466 | /* Place the datagram payload in the user's iovec. */ | |
467 | err = skb_copy_datagram_msg(skb, sizeof(*hdr), msg, dgram_len); | |
468 | if (err) | |
469 | goto out; | |
470 | ||
471 | if (msg->msg_name) { | |
472 | /* Provide the address of the sender. */ | |
473 | DECLARE_SOCKADDR(struct sockaddr_vm *, vm_addr, msg->msg_name); | |
474 | vsock_addr_init(vm_addr, le32_to_cpu(hdr->src_cid), le32_to_cpu(hdr->src_port)); | |
475 | msg->msg_namelen = sizeof(*vm_addr); | |
476 | } | |
477 | err = dgram_len; | |
478 | ||
479 | /* Send a credit pkt to peer */ | |
480 | virtio_transport_send_credit_update(vsk, VIRTIO_VSOCK_TYPE_DGRAM, hdr); | |
481 | ||
482 | pr_debug("%s:done, recved =%d\n", __func__, dgram_len); | |
483 | out: | |
484 | skb_free_datagram(&vsk->sk, skb); | |
485 | return err; | |
486 | } | |
487 | EXPORT_SYMBOL_GPL(virtio_transport_dgram_dequeue); | |
488 | ||
489 | s64 virtio_transport_stream_has_data(struct vsock_sock *vsk) | |
490 | { | |
491 | struct virtio_transport *trans = vsk->trans; | |
492 | s64 bytes; | |
493 | ||
494 | mutex_lock(&trans->rx_lock); | |
495 | bytes = trans->rx_bytes; | |
496 | mutex_unlock(&trans->rx_lock); | |
497 | ||
498 | return bytes; | |
499 | } | |
500 | EXPORT_SYMBOL_GPL(virtio_transport_stream_has_data); | |
501 | ||
502 | static s64 virtio_transport_has_space(struct vsock_sock *vsk) | |
503 | { | |
504 | struct virtio_transport *trans = vsk->trans; | |
505 | s64 bytes; | |
506 | ||
507 | bytes = trans->peer_buf_alloc - (trans->tx_cnt - trans->peer_fwd_cnt); | |
508 | if (bytes < 0) | |
509 | bytes = 0; | |
510 | ||
511 | return bytes; | |
512 | } | |
513 | ||
514 | s64 virtio_transport_stream_has_space(struct vsock_sock *vsk) | |
515 | { | |
516 | struct virtio_transport *trans = vsk->trans; | |
517 | s64 bytes; | |
518 | ||
519 | mutex_lock(&trans->tx_lock); | |
520 | bytes = virtio_transport_has_space(vsk); | |
521 | mutex_unlock(&trans->tx_lock); | |
522 | ||
523 | pr_debug("%s: bytes=%lld\n", __func__, bytes); | |
524 | ||
525 | return bytes; | |
526 | } | |
527 | EXPORT_SYMBOL_GPL(virtio_transport_stream_has_space); | |
528 | ||
529 | int virtio_transport_do_socket_init(struct vsock_sock *vsk, | |
530 | struct vsock_sock *psk) | |
531 | { | |
532 | struct virtio_transport *trans; | |
533 | ||
534 | trans = kzalloc(sizeof(*trans), GFP_KERNEL); | |
535 | if (!trans) | |
536 | return -ENOMEM; | |
537 | ||
538 | vsk->trans = trans; | |
539 | trans->vsk = vsk; | |
540 | if (psk) { | |
541 | struct virtio_transport *ptrans = psk->trans; | |
542 | trans->buf_size = ptrans->buf_size; | |
543 | trans->buf_size_min = ptrans->buf_size_min; | |
544 | trans->buf_size_max = ptrans->buf_size_max; | |
545 | trans->peer_buf_alloc = ptrans->peer_buf_alloc; | |
546 | } else { | |
547 | trans->buf_size = VIRTIO_VSOCK_DEFAULT_BUF_SIZE; | |
548 | trans->buf_size_min = VIRTIO_VSOCK_DEFAULT_MIN_BUF_SIZE; | |
549 | trans->buf_size_max = VIRTIO_VSOCK_DEFAULT_MAX_BUF_SIZE; | |
550 | } | |
551 | ||
552 | trans->buf_alloc = trans->buf_size; | |
553 | ||
554 | pr_debug("%s: trans->buf_alloc=%d\n", __func__, trans->buf_alloc); | |
555 | ||
556 | mutex_init(&trans->rx_lock); | |
557 | mutex_init(&trans->tx_lock); | |
558 | INIT_LIST_HEAD(&trans->rx_queue); | |
559 | INIT_LIST_HEAD(&trans->incomplete_dgrams); | |
560 | ||
561 | return 0; | |
562 | } | |
563 | EXPORT_SYMBOL_GPL(virtio_transport_do_socket_init); | |
564 | ||
565 | u64 virtio_transport_get_buffer_size(struct vsock_sock *vsk) | |
566 | { | |
567 | struct virtio_transport *trans = vsk->trans; | |
568 | ||
569 | return trans->buf_size; | |
570 | } | |
571 | EXPORT_SYMBOL_GPL(virtio_transport_get_buffer_size); | |
572 | ||
573 | u64 virtio_transport_get_min_buffer_size(struct vsock_sock *vsk) | |
574 | { | |
575 | struct virtio_transport *trans = vsk->trans; | |
576 | ||
577 | return trans->buf_size_min; | |
578 | } | |
579 | EXPORT_SYMBOL_GPL(virtio_transport_get_min_buffer_size); | |
580 | ||
581 | u64 virtio_transport_get_max_buffer_size(struct vsock_sock *vsk) | |
582 | { | |
583 | struct virtio_transport *trans = vsk->trans; | |
584 | ||
585 | return trans->buf_size_max; | |
586 | } | |
587 | EXPORT_SYMBOL_GPL(virtio_transport_get_max_buffer_size); | |
588 | ||
589 | void virtio_transport_set_buffer_size(struct vsock_sock *vsk, u64 val) | |
590 | { | |
591 | struct virtio_transport *trans = vsk->trans; | |
592 | ||
593 | if (val > VIRTIO_VSOCK_MAX_BUF_SIZE) | |
594 | val = VIRTIO_VSOCK_MAX_BUF_SIZE; | |
595 | if (val < trans->buf_size_min) | |
596 | trans->buf_size_min = val; | |
597 | if (val > trans->buf_size_max) | |
598 | trans->buf_size_max = val; | |
599 | trans->buf_size = val; | |
600 | trans->buf_alloc = val; | |
601 | } | |
602 | EXPORT_SYMBOL_GPL(virtio_transport_set_buffer_size); | |
603 | ||
604 | void virtio_transport_set_min_buffer_size(struct vsock_sock *vsk, u64 val) | |
605 | { | |
606 | struct virtio_transport *trans = vsk->trans; | |
607 | ||
608 | if (val > VIRTIO_VSOCK_MAX_BUF_SIZE) | |
609 | val = VIRTIO_VSOCK_MAX_BUF_SIZE; | |
610 | if (val > trans->buf_size) | |
611 | trans->buf_size = val; | |
612 | trans->buf_size_min = val; | |
613 | } | |
614 | EXPORT_SYMBOL_GPL(virtio_transport_set_min_buffer_size); | |
615 | ||
616 | void virtio_transport_set_max_buffer_size(struct vsock_sock *vsk, u64 val) | |
617 | { | |
618 | struct virtio_transport *trans = vsk->trans; | |
619 | ||
620 | if (val > VIRTIO_VSOCK_MAX_BUF_SIZE) | |
621 | val = VIRTIO_VSOCK_MAX_BUF_SIZE; | |
622 | if (val < trans->buf_size) | |
623 | trans->buf_size = val; | |
624 | trans->buf_size_max = val; | |
625 | } | |
626 | EXPORT_SYMBOL_GPL(virtio_transport_set_max_buffer_size); | |
627 | ||
628 | int | |
629 | virtio_transport_notify_poll_in(struct vsock_sock *vsk, | |
630 | size_t target, | |
631 | bool *data_ready_now) | |
632 | { | |
633 | if (vsock_stream_has_data(vsk)) | |
634 | *data_ready_now = true; | |
635 | else | |
636 | *data_ready_now = false; | |
637 | ||
638 | return 0; | |
639 | } | |
640 | EXPORT_SYMBOL_GPL(virtio_transport_notify_poll_in); | |
641 | ||
642 | int | |
643 | virtio_transport_notify_poll_out(struct vsock_sock *vsk, | |
644 | size_t target, | |
645 | bool *space_avail_now) | |
646 | { | |
647 | s64 free_space; | |
648 | ||
649 | free_space = vsock_stream_has_space(vsk); | |
650 | if (free_space > 0) | |
651 | *space_avail_now = true; | |
652 | else if (free_space == 0) | |
653 | *space_avail_now = false; | |
654 | ||
655 | return 0; | |
656 | } | |
657 | EXPORT_SYMBOL_GPL(virtio_transport_notify_poll_out); | |
658 | ||
659 | int virtio_transport_notify_recv_init(struct vsock_sock *vsk, | |
660 | size_t target, struct vsock_transport_recv_notify_data *data) | |
661 | { | |
662 | return 0; | |
663 | } | |
664 | EXPORT_SYMBOL_GPL(virtio_transport_notify_recv_init); | |
665 | ||
666 | int virtio_transport_notify_recv_pre_block(struct vsock_sock *vsk, | |
667 | size_t target, struct vsock_transport_recv_notify_data *data) | |
668 | { | |
669 | return 0; | |
670 | } | |
671 | EXPORT_SYMBOL_GPL(virtio_transport_notify_recv_pre_block); | |
672 | ||
673 | int virtio_transport_notify_recv_pre_dequeue(struct vsock_sock *vsk, | |
674 | size_t target, struct vsock_transport_recv_notify_data *data) | |
675 | { | |
676 | return 0; | |
677 | } | |
678 | EXPORT_SYMBOL_GPL(virtio_transport_notify_recv_pre_dequeue); | |
679 | ||
680 | int virtio_transport_notify_recv_post_dequeue(struct vsock_sock *vsk, | |
681 | size_t target, ssize_t copied, bool data_read, | |
682 | struct vsock_transport_recv_notify_data *data) | |
683 | { | |
684 | return 0; | |
685 | } | |
686 | EXPORT_SYMBOL_GPL(virtio_transport_notify_recv_post_dequeue); | |
687 | ||
688 | int virtio_transport_notify_send_init(struct vsock_sock *vsk, | |
689 | struct vsock_transport_send_notify_data *data) | |
690 | { | |
691 | return 0; | |
692 | } | |
693 | EXPORT_SYMBOL_GPL(virtio_transport_notify_send_init); | |
694 | ||
695 | int virtio_transport_notify_send_pre_block(struct vsock_sock *vsk, | |
696 | struct vsock_transport_send_notify_data *data) | |
697 | { | |
698 | return 0; | |
699 | } | |
700 | EXPORT_SYMBOL_GPL(virtio_transport_notify_send_pre_block); | |
701 | ||
702 | int virtio_transport_notify_send_pre_enqueue(struct vsock_sock *vsk, | |
703 | struct vsock_transport_send_notify_data *data) | |
704 | { | |
705 | return 0; | |
706 | } | |
707 | EXPORT_SYMBOL_GPL(virtio_transport_notify_send_pre_enqueue); | |
708 | ||
709 | int virtio_transport_notify_send_post_enqueue(struct vsock_sock *vsk, | |
710 | ssize_t written, struct vsock_transport_send_notify_data *data) | |
711 | { | |
712 | return 0; | |
713 | } | |
714 | EXPORT_SYMBOL_GPL(virtio_transport_notify_send_post_enqueue); | |
715 | ||
716 | u64 virtio_transport_stream_rcvhiwat(struct vsock_sock *vsk) | |
717 | { | |
718 | struct virtio_transport *trans = vsk->trans; | |
719 | ||
720 | return trans->buf_size; | |
721 | } | |
722 | EXPORT_SYMBOL_GPL(virtio_transport_stream_rcvhiwat); | |
723 | ||
724 | bool virtio_transport_stream_is_active(struct vsock_sock *vsk) | |
725 | { | |
726 | return true; | |
727 | } | |
728 | EXPORT_SYMBOL_GPL(virtio_transport_stream_is_active); | |
729 | ||
730 | bool virtio_transport_stream_allow(u32 cid, u32 port) | |
731 | { | |
732 | return true; | |
733 | } | |
734 | EXPORT_SYMBOL_GPL(virtio_transport_stream_allow); | |
735 | ||
/* Datagram bind is delegated entirely to vsock_bind_dgram_generic(). */
int virtio_transport_dgram_bind(struct vsock_sock *vsk,
				struct sockaddr_vm *addr)
{
	int ret = vsock_bind_dgram_generic(vsk, addr);

	return ret;
}
EXPORT_SYMBOL_GPL(virtio_transport_dgram_bind);
742 | ||
743 | bool virtio_transport_dgram_allow(u32 cid, u32 port) | |
744 | { | |
745 | return true; | |
746 | } | |
747 | EXPORT_SYMBOL_GPL(virtio_transport_dgram_allow); | |
748 | ||
749 | int virtio_transport_connect(struct vsock_sock *vsk) | |
750 | { | |
751 | struct virtio_transport *trans = vsk->trans; | |
752 | struct virtio_vsock_pkt_info info = { | |
753 | .op = VIRTIO_VSOCK_OP_REQUEST, | |
754 | .type = VIRTIO_VSOCK_TYPE_STREAM, | |
755 | }; | |
756 | ||
757 | pr_debug("%s: vsk=%p send_request\n", __func__, vsk); | |
758 | return trans->ops->send_pkt(vsk, &info); | |
759 | } | |
760 | EXPORT_SYMBOL_GPL(virtio_transport_connect); | |
761 | ||
762 | int virtio_transport_shutdown(struct vsock_sock *vsk, int mode) | |
763 | { | |
764 | struct virtio_transport *trans = vsk->trans; | |
765 | struct virtio_vsock_pkt_info info = { | |
766 | .op = VIRTIO_VSOCK_OP_SHUTDOWN, | |
767 | .type = VIRTIO_VSOCK_TYPE_STREAM, | |
768 | .flags = (mode & RCV_SHUTDOWN ? | |
769 | VIRTIO_VSOCK_SHUTDOWN_RCV : 0) | | |
770 | (mode & SEND_SHUTDOWN ? | |
771 | VIRTIO_VSOCK_SHUTDOWN_SEND : 0), | |
772 | }; | |
773 | ||
774 | pr_debug("%s: vsk=%p: send_shutdown\n", __func__, vsk); | |
775 | return trans->ops->send_pkt(vsk, &info); | |
776 | } | |
777 | EXPORT_SYMBOL_GPL(virtio_transport_shutdown); | |
778 | ||
779 | void virtio_transport_release(struct vsock_sock *vsk) | |
780 | { | |
781 | struct virtio_transport *trans = vsk->trans; | |
782 | struct sock *sk = &vsk->sk; | |
783 | struct dgram_skb *dgram_skb; | |
784 | struct dgram_skb *dgram_skb_tmp; | |
785 | ||
786 | pr_debug("%s: vsk=%p\n", __func__, vsk); | |
787 | ||
788 | /* Tell other side to terminate connection */ | |
789 | if (sk->sk_type == SOCK_STREAM && sk->sk_state == SS_CONNECTED) { | |
790 | virtio_transport_shutdown(vsk, SHUTDOWN_MASK); | |
791 | } | |
792 | ||
793 | /* Free incomplete dgrams */ | |
794 | lock_sock(sk); | |
795 | list_for_each_entry_safe(dgram_skb, dgram_skb_tmp, | |
796 | &trans->incomplete_dgrams, list) { | |
797 | list_del(&dgram_skb->list); | |
798 | kfree_skb(dgram_skb->skb); | |
799 | kfree(dgram_skb); | |
800 | sock_put(sk); /* held in virtio_transport_recv_dgram() */ | |
801 | } | |
802 | release_sock(sk); | |
803 | } | |
804 | EXPORT_SYMBOL_GPL(virtio_transport_release); | |
805 | ||
806 | int | |
807 | virtio_transport_dgram_enqueue(struct vsock_sock *vsk, | |
808 | struct sockaddr_vm *remote_addr, | |
809 | struct msghdr *msg, | |
810 | size_t dgram_len) | |
811 | { | |
812 | struct virtio_transport *trans = vsk->trans; | |
813 | struct virtio_vsock_pkt_info info = { | |
814 | .op = VIRTIO_VSOCK_OP_RW, | |
815 | .type = VIRTIO_VSOCK_TYPE_DGRAM, | |
816 | .msg = msg, | |
817 | }; | |
818 | size_t total_written = 0, pkt_off = 0, written; | |
819 | u16 dgram_id; | |
820 | ||
821 | /* The max size of a single dgram we support is 64KB */ | |
822 | if (dgram_len > VIRTIO_VSOCK_MAX_DGRAM_SIZE) | |
823 | return -EMSGSIZE; | |
824 | ||
825 | info.dgram_len = dgram_len; | |
826 | vsk->remote_addr = *remote_addr; | |
827 | ||
828 | dgram_id = trans->dgram_id++; | |
829 | ||
830 | /* TODO: To optimize, if we have enough credit to send the pkt already, | |
831 | * do not ask the peer to send credit to use */ | |
832 | virtio_transport_send_credit_request(vsk, VIRTIO_VSOCK_TYPE_DGRAM); | |
833 | ||
834 | while (total_written < dgram_len) { | |
835 | info.pkt_len = dgram_len - total_written; | |
836 | info.flags = dgram_id << 16 | pkt_off; | |
837 | written = trans->ops->send_pkt(vsk, &info); | |
838 | if (written < 0) | |
839 | return -ENOMEM; | |
840 | if (written == 0) { | |
841 | /* TODO: if written = 0, we need a sleep & wakeup | |
842 | * instead of sleep */ | |
843 | pr_debug("%s: SHOULD WAIT written==0", __func__); | |
844 | msleep(10); | |
845 | } | |
846 | total_written += written; | |
847 | pkt_off += written; | |
848 | pr_debug("%s:id=%d, dgram_len=%zu, off=%zu, total_written=%zu, written=%zu\n", | |
849 | __func__, dgram_id, dgram_len, pkt_off, total_written, written); | |
850 | } | |
851 | ||
852 | return dgram_len; | |
853 | } | |
854 | EXPORT_SYMBOL_GPL(virtio_transport_dgram_enqueue); | |
855 | ||
856 | ssize_t | |
857 | virtio_transport_stream_enqueue(struct vsock_sock *vsk, | |
858 | struct msghdr *msg, | |
859 | size_t len) | |
860 | { | |
861 | struct virtio_transport *trans = vsk->trans; | |
862 | struct virtio_vsock_pkt_info info = { | |
863 | .op = VIRTIO_VSOCK_OP_RW, | |
864 | .type = VIRTIO_VSOCK_TYPE_STREAM, | |
865 | .msg = msg, | |
866 | .pkt_len = len, | |
867 | }; | |
868 | ||
869 | return trans->ops->send_pkt(vsk, &info); | |
870 | } | |
871 | EXPORT_SYMBOL_GPL(virtio_transport_stream_enqueue); | |
872 | ||
873 | void virtio_transport_destruct(struct vsock_sock *vsk) | |
874 | { | |
875 | struct virtio_transport *trans = vsk->trans; | |
876 | ||
877 | pr_debug("%s: vsk=%p\n", __func__, vsk); | |
878 | kfree(trans); | |
879 | } | |
880 | EXPORT_SYMBOL_GPL(virtio_transport_destruct); | |
881 | ||
882 | static int virtio_transport_send_ack(struct vsock_sock *vsk, u32 cookie) | |
883 | { | |
884 | struct virtio_transport *trans = vsk->trans; | |
885 | struct virtio_vsock_pkt_info info = { | |
886 | .op = VIRTIO_VSOCK_OP_ACK, | |
887 | .type = VIRTIO_VSOCK_TYPE_STREAM, | |
888 | .flags = cpu_to_le32(cookie), | |
889 | }; | |
890 | ||
891 | pr_debug("%s: sk=%p send_offer\n", __func__, vsk); | |
892 | return trans->ops->send_pkt(vsk, &info); | |
893 | } | |
894 | ||
895 | static int virtio_transport_send_reset(struct vsock_sock *vsk, | |
896 | struct virtio_vsock_pkt *pkt) | |
897 | { | |
898 | struct virtio_transport *trans = vsk->trans; | |
899 | struct virtio_vsock_pkt_info info = { | |
900 | .op = VIRTIO_VSOCK_OP_RST, | |
901 | .type = VIRTIO_VSOCK_TYPE_STREAM, | |
902 | }; | |
903 | ||
904 | pr_debug("%s\n", __func__); | |
905 | ||
906 | /* Send RST only if the original pkt is not a RST pkt */ | |
907 | if (le16_to_cpu(pkt->hdr.op) == VIRTIO_VSOCK_OP_RST) | |
908 | return 0; | |
909 | ||
910 | return trans->ops->send_pkt(vsk, &info); | |
911 | } | |
912 | ||
913 | static int | |
914 | virtio_transport_recv_connecting(struct sock *sk, | |
915 | struct virtio_vsock_pkt *pkt) | |
916 | { | |
917 | struct vsock_sock *vsk = vsock_sk(sk); | |
918 | int err; | |
919 | int skerr; | |
920 | u32 cookie; | |
921 | ||
922 | pr_debug("%s: vsk=%p\n", __func__, vsk); | |
923 | switch (le16_to_cpu(pkt->hdr.op)) { | |
924 | case VIRTIO_VSOCK_OP_RESPONSE: | |
925 | cookie = le32_to_cpu(pkt->hdr.flags); | |
926 | pr_debug("%s: got RESPONSE and send ACK, cookie=%x\n", __func__, cookie); | |
927 | err = virtio_transport_send_ack(vsk, cookie); | |
928 | if (err < 0) { | |
929 | skerr = -err; | |
930 | goto destroy; | |
931 | } | |
932 | sk->sk_state = SS_CONNECTED; | |
933 | sk->sk_socket->state = SS_CONNECTED; | |
934 | vsock_insert_connected(vsk); | |
935 | sk->sk_state_change(sk); | |
936 | break; | |
937 | case VIRTIO_VSOCK_OP_INVALID: | |
938 | pr_debug("%s: got invalid\n", __func__); | |
939 | break; | |
940 | case VIRTIO_VSOCK_OP_RST: | |
941 | pr_debug("%s: got rst\n", __func__); | |
942 | skerr = ECONNRESET; | |
943 | err = 0; | |
944 | goto destroy; | |
945 | default: | |
946 | pr_debug("%s: got def\n", __func__); | |
947 | skerr = EPROTO; | |
948 | err = -EINVAL; | |
949 | goto destroy; | |
950 | } | |
951 | return 0; | |
952 | ||
953 | destroy: | |
954 | virtio_transport_send_reset(vsk, pkt); | |
955 | sk->sk_state = SS_UNCONNECTED; | |
956 | sk->sk_err = skerr; | |
957 | sk->sk_error_report(sk); | |
958 | return err; | |
959 | } | |
960 | ||
961 | static int | |
962 | virtio_transport_recv_connected(struct sock *sk, | |
963 | struct virtio_vsock_pkt *pkt) | |
964 | { | |
965 | struct vsock_sock *vsk = vsock_sk(sk); | |
966 | struct virtio_transport *trans = vsk->trans; | |
967 | int err = 0; | |
968 | ||
969 | switch (le16_to_cpu(pkt->hdr.op)) { | |
970 | case VIRTIO_VSOCK_OP_RW: | |
971 | pkt->len = le32_to_cpu(pkt->hdr.len); | |
972 | pkt->off = 0; | |
973 | pkt->trans = trans; | |
974 | ||
975 | mutex_lock(&trans->rx_lock); | |
976 | virtio_transport_inc_rx_pkt(pkt); | |
977 | list_add_tail(&pkt->list, &trans->rx_queue); | |
978 | mutex_unlock(&trans->rx_lock); | |
979 | ||
980 | sk->sk_data_ready(sk); | |
981 | return err; | |
982 | case VIRTIO_VSOCK_OP_CREDIT_UPDATE: | |
983 | sk->sk_write_space(sk); | |
984 | break; | |
985 | case VIRTIO_VSOCK_OP_SHUTDOWN: | |
986 | pr_debug("%s: got shutdown\n", __func__); | |
987 | if (le32_to_cpu(pkt->hdr.flags) & VIRTIO_VSOCK_SHUTDOWN_RCV) | |
988 | vsk->peer_shutdown |= RCV_SHUTDOWN; | |
989 | if (le32_to_cpu(pkt->hdr.flags) & VIRTIO_VSOCK_SHUTDOWN_SEND) | |
990 | vsk->peer_shutdown |= SEND_SHUTDOWN; | |
991 | if (le32_to_cpu(pkt->hdr.flags)) | |
992 | sk->sk_state_change(sk); | |
993 | break; | |
994 | case VIRTIO_VSOCK_OP_RST: | |
995 | pr_debug("%s: got rst\n", __func__); | |
996 | sock_set_flag(sk, SOCK_DONE); | |
997 | vsk->peer_shutdown = SHUTDOWN_MASK; | |
998 | if (vsock_stream_has_data(vsk) <= 0) | |
999 | sk->sk_state = SS_DISCONNECTING; | |
1000 | sk->sk_state_change(sk); | |
1001 | break; | |
1002 | default: | |
1003 | err = -EINVAL; | |
1004 | break; | |
1005 | } | |
1006 | ||
1007 | virtio_transport_free_pkt(pkt); | |
1008 | return err; | |
1009 | } | |
1010 | ||
1011 | static int | |
1012 | virtio_transport_send_response(struct vsock_sock *vsk, | |
1013 | struct virtio_vsock_pkt *pkt) | |
1014 | { | |
1015 | struct virtio_transport *trans = vsk->trans; | |
1016 | struct virtio_vsock_pkt_info info = { | |
1017 | .op = VIRTIO_VSOCK_OP_RESPONSE, | |
1018 | .type = VIRTIO_VSOCK_TYPE_STREAM, | |
1019 | .remote_cid = le32_to_cpu(pkt->hdr.src_cid), | |
1020 | .remote_port = le32_to_cpu(pkt->hdr.src_port), | |
1021 | }; | |
1022 | u32 cookie; | |
1023 | ||
1024 | cookie = virtio_vsock_secure_cookie(le32_to_cpu(pkt->hdr.src_cid), | |
1025 | le32_to_cpu(pkt->hdr.dst_cid), | |
1026 | le32_to_cpu(pkt->hdr.src_port), | |
1027 | le32_to_cpu(pkt->hdr.dst_port), | |
1028 | jiffies / (HZ * 60)); | |
1029 | info.flags = cpu_to_le32(cookie); | |
1030 | ||
1031 | pr_debug("%s: send_response, cookie=%x\n", __func__, le32_to_cpu(cookie)); | |
1032 | ||
1033 | return trans->ops->send_pkt(vsk, &info); | |
1034 | } | |
1035 | ||
1036 | /* Handle server socket */ | |
1037 | static int | |
1038 | virtio_transport_recv_listen(struct sock *sk, struct virtio_vsock_pkt *pkt) | |
1039 | { | |
1040 | struct vsock_sock *vsk = vsock_sk(sk); | |
1041 | struct vsock_sock *vpending; | |
1042 | struct sock *pending; | |
1043 | int err; | |
1044 | u32 cookie; | |
1045 | ||
1046 | switch (le16_to_cpu(pkt->hdr.op)) { | |
1047 | case VIRTIO_VSOCK_OP_REQUEST: | |
1048 | err = virtio_transport_send_response(vsk, pkt); | |
1049 | if (err < 0) { | |
1050 | // FIXME vsk should be vpending | |
1051 | virtio_transport_send_reset(vsk, pkt); | |
1052 | return err; | |
1053 | } | |
1054 | break; | |
1055 | case VIRTIO_VSOCK_OP_ACK: | |
1056 | cookie = le32_to_cpu(pkt->hdr.flags); | |
1057 | err = virtio_vsock_check_cookie(le32_to_cpu(pkt->hdr.src_cid), | |
1058 | le32_to_cpu(pkt->hdr.dst_cid), | |
1059 | le32_to_cpu(pkt->hdr.src_port), | |
1060 | le32_to_cpu(pkt->hdr.dst_port), | |
1061 | jiffies / (HZ * 60), | |
1062 | le32_to_cpu(pkt->hdr.flags), | |
1063 | VSOCK_TIMEOUT_INIT); | |
1064 | pr_debug("%s: cookie=%x, err=%d\n", __func__, cookie, err); | |
1065 | if (err) | |
1066 | return err; | |
1067 | ||
1068 | /* So no pending socket are responsible for this pkt, create one */ | |
1069 | pr_debug("%s: create pending\n", __func__); | |
1070 | pending = __vsock_create(sock_net(sk), NULL, sk, GFP_KERNEL, | |
1071 | sk->sk_type, 0); | |
1072 | if (!pending) { | |
1073 | virtio_transport_send_reset(vsk, pkt); | |
1074 | return -ENOMEM; | |
1075 | } | |
1076 | sk->sk_ack_backlog++; | |
1077 | pending->sk_state = SS_CONNECTING; | |
1078 | ||
1079 | vpending = vsock_sk(pending); | |
1080 | vsock_addr_init(&vpending->local_addr, le32_to_cpu(pkt->hdr.dst_cid), | |
1081 | le32_to_cpu(pkt->hdr.dst_port)); | |
1082 | vsock_addr_init(&vpending->remote_addr, le32_to_cpu(pkt->hdr.src_cid), | |
1083 | le32_to_cpu(pkt->hdr.src_port)); | |
1084 | vsock_add_pending(sk, pending); | |
1085 | ||
1086 | pr_debug("%s: get pending\n", __func__); | |
1087 | pending = virtio_transport_get_pending(sk, pkt); | |
1088 | vpending = vsock_sk(pending); | |
1089 | lock_sock(pending); | |
1090 | switch (pending->sk_state) { | |
1091 | case SS_CONNECTING: | |
1092 | if (le16_to_cpu(pkt->hdr.op) != VIRTIO_VSOCK_OP_ACK) { | |
1093 | pr_debug("%s: op=%d != OP_ACK\n", __func__, | |
1094 | le16_to_cpu(pkt->hdr.op)); | |
1095 | virtio_transport_send_reset(vpending, pkt); | |
1096 | pending->sk_err = EPROTO; | |
1097 | pending->sk_state = SS_UNCONNECTED; | |
1098 | sock_put(pending); | |
1099 | } else { | |
1100 | pending->sk_state = SS_CONNECTED; | |
1101 | vsock_insert_connected(vpending); | |
1102 | ||
1103 | vsock_remove_pending(sk, pending); | |
1104 | vsock_enqueue_accept(sk, pending); | |
1105 | ||
1106 | sk->sk_data_ready(sk); | |
1107 | } | |
1108 | err = 0; | |
1109 | break; | |
1110 | default: | |
1111 | pr_debug("%s: sk->sk_ack_backlog=%d\n", __func__, | |
1112 | sk->sk_ack_backlog); | |
1113 | virtio_transport_send_reset(vpending, pkt); | |
1114 | err = -EINVAL; | |
1115 | break; | |
1116 | } | |
1117 | if (err < 0) | |
1118 | vsock_remove_pending(sk, pending); | |
1119 | release_sock(pending); | |
1120 | ||
1121 | /* Release refcnt obtained in virtio_transport_get_pending */ | |
1122 | sock_put(pending); | |
1123 | break; | |
1124 | default: | |
1125 | break; | |
1126 | } | |
1127 | ||
1128 | return 0; | |
1129 | } | |
1130 | ||
1131 | static void virtio_transport_space_update(struct sock *sk, | |
1132 | struct virtio_vsock_pkt *pkt) | |
1133 | { | |
1134 | struct vsock_sock *vsk = vsock_sk(sk); | |
1135 | struct virtio_transport *trans = vsk->trans; | |
1136 | bool space_available; | |
1137 | ||
1138 | /* buf_alloc and fwd_cnt is always included in the hdr */ | |
1139 | mutex_lock(&trans->tx_lock); | |
1140 | trans->peer_buf_alloc = le32_to_cpu(pkt->hdr.buf_alloc); | |
1141 | trans->peer_fwd_cnt = le32_to_cpu(pkt->hdr.fwd_cnt); | |
1142 | space_available = virtio_transport_has_space(vsk); | |
1143 | mutex_unlock(&trans->tx_lock); | |
1144 | ||
1145 | if (space_available) | |
1146 | sk->sk_write_space(sk); | |
1147 | } | |
1148 | ||
1149 | /* We are under the virtio-vsock's vsock->rx_lock or | |
1150 | * vhost-vsock's vq->mutex lock */ | |
1151 | void virtio_transport_recv_pkt(struct virtio_vsock_pkt *pkt) | |
1152 | { | |
1153 | struct virtio_transport *trans; | |
1154 | struct sockaddr_vm src, dst; | |
1155 | struct vsock_sock *vsk; | |
1156 | struct sock *sk; | |
1157 | ||
1158 | vsock_addr_init(&src, le32_to_cpu(pkt->hdr.src_cid), le32_to_cpu(pkt->hdr.src_port)); | |
1159 | vsock_addr_init(&dst, le32_to_cpu(pkt->hdr.dst_cid), le32_to_cpu(pkt->hdr.dst_port)); | |
1160 | ||
1161 | virtio_vsock_dumppkt(__func__, pkt); | |
1162 | ||
1163 | if (le16_to_cpu(pkt->hdr.type) == VIRTIO_VSOCK_TYPE_DGRAM) { | |
1164 | sk = vsock_find_unbound_socket(&dst); | |
1165 | if (!sk) | |
1166 | goto free_pkt; | |
1167 | ||
1168 | vsk = vsock_sk(sk); | |
1169 | trans = vsk->trans; | |
1170 | BUG_ON(!trans); | |
1171 | ||
1172 | virtio_transport_space_update(sk, pkt); | |
1173 | ||
1174 | lock_sock(sk); | |
1175 | switch (le16_to_cpu(pkt->hdr.op)) { | |
1176 | case VIRTIO_VSOCK_OP_CREDIT_UPDATE: | |
1177 | virtio_transport_free_pkt(pkt); | |
1178 | break; | |
1179 | case VIRTIO_VSOCK_OP_CREDIT_REQUEST: | |
1180 | virtio_transport_send_credit_update(vsk, VIRTIO_VSOCK_TYPE_DGRAM, | |
1181 | &pkt->hdr); | |
1182 | virtio_transport_free_pkt(pkt); | |
1183 | break; | |
1184 | case VIRTIO_VSOCK_OP_RW: | |
1185 | virtio_transport_recv_dgram(sk, pkt); | |
1186 | break; | |
1187 | default: | |
1188 | virtio_transport_free_pkt(pkt); | |
1189 | break; | |
1190 | } | |
1191 | release_sock(sk); | |
1192 | ||
1193 | /* Release refcnt obtained when we fetched this socket out of | |
1194 | * the unbound list. | |
1195 | */ | |
1196 | sock_put(sk); | |
1197 | return; | |
1198 | } else if (le16_to_cpu(pkt->hdr.type) == VIRTIO_VSOCK_TYPE_STREAM) { | |
1199 | /* The socket must be in connected or bound table | |
1200 | * otherwise send reset back | |
1201 | */ | |
1202 | sk = vsock_find_connected_socket(&src, &dst); | |
1203 | if (!sk) { | |
1204 | sk = vsock_find_bound_socket(&dst); | |
1205 | if (!sk) { | |
1206 | pr_debug("%s: can not find bound_socket\n", __func__); | |
1207 | virtio_vsock_dumppkt(__func__, pkt); | |
1208 | /* Ignore this pkt instead of sending reset back */ | |
1209 | /* TODO send a RST unless this packet is a RST (to avoid infinite loops) */ | |
1210 | goto free_pkt; | |
1211 | } | |
1212 | } | |
1213 | ||
1214 | vsk = vsock_sk(sk); | |
1215 | trans = vsk->trans; | |
1216 | BUG_ON(!trans); | |
1217 | ||
1218 | virtio_transport_space_update(sk, pkt); | |
1219 | ||
1220 | lock_sock(sk); | |
1221 | switch (sk->sk_state) { | |
1222 | case VSOCK_SS_LISTEN: | |
1223 | virtio_transport_recv_listen(sk, pkt); | |
1224 | virtio_transport_free_pkt(pkt); | |
1225 | break; | |
1226 | case SS_CONNECTING: | |
1227 | virtio_transport_recv_connecting(sk, pkt); | |
1228 | virtio_transport_free_pkt(pkt); | |
1229 | break; | |
1230 | case SS_CONNECTED: | |
1231 | virtio_transport_recv_connected(sk, pkt); | |
1232 | break; | |
1233 | default: | |
1234 | virtio_transport_free_pkt(pkt); | |
1235 | break; | |
1236 | } | |
1237 | release_sock(sk); | |
1238 | ||
1239 | /* Release refcnt obtained when we fetched this socket out of the | |
1240 | * bound or connected list. | |
1241 | */ | |
1242 | sock_put(sk); | |
1243 | } | |
1244 | return; | |
1245 | ||
1246 | free_pkt: | |
1247 | virtio_transport_free_pkt(pkt); | |
1248 | } | |
1249 | EXPORT_SYMBOL_GPL(virtio_transport_recv_pkt); | |
1250 | ||
1251 | void virtio_transport_free_pkt(struct virtio_vsock_pkt *pkt) | |
1252 | { | |
1253 | kfree(pkt->buf); | |
1254 | kfree(pkt); | |
1255 | } | |
1256 | EXPORT_SYMBOL_GPL(virtio_transport_free_pkt); | |
1257 | ||
1258 | static int __init virtio_vsock_common_init(void) | |
1259 | { | |
1260 | get_random_bytes(vsockcookie_secret, sizeof(vsockcookie_secret)); | |
1261 | return 0; | |
1262 | } | |
1263 | ||
1264 | static void __exit virtio_vsock_common_exit(void) | |
1265 | { | |
1266 | } | |
1267 | ||
1268 | module_init(virtio_vsock_common_init); | |
1269 | module_exit(virtio_vsock_common_exit); | |
1270 | MODULE_LICENSE("GPL v2"); | |
1271 | MODULE_AUTHOR("Asias He"); | |
1272 | MODULE_DESCRIPTION("common code for virtio vsock"); |