GNU Linux-libre 6.5.10-gnu
[releases.git] / net / vmw_vsock / virtio_transport_common.c
1 // SPDX-License-Identifier: GPL-2.0-only
2 /*
3  * common code for virtio vsock
4  *
5  * Copyright (C) 2013-2015 Red Hat, Inc.
6  * Author: Asias He <asias@redhat.com>
7  *         Stefan Hajnoczi <stefanha@redhat.com>
8  */
9 #include <linux/spinlock.h>
10 #include <linux/module.h>
11 #include <linux/sched/signal.h>
12 #include <linux/ctype.h>
13 #include <linux/list.h>
14 #include <linux/virtio_vsock.h>
15 #include <uapi/linux/vsockmon.h>
16
17 #include <net/sock.h>
18 #include <net/af_vsock.h>
19
20 #define CREATE_TRACE_POINTS
21 #include <trace/events/vsock_virtio_transport_common.h>
22
/* How long to wait for graceful shutdown of a connection.
 * Expressed in jiffies: 8 * HZ is eight seconds' worth of ticks.
 */
#define VSOCK_CLOSE_TIMEOUT (8 * HZ)

/* Threshold for detecting small packets to copy.
 * NOTE(review): presumably a payload size in bytes; the macro is not
 * referenced in the part of the file reviewed here - confirm against users.
 */
#define GOOD_COPY_LEN  128
28
29 static const struct virtio_transport *
30 virtio_transport_get_ops(struct vsock_sock *vsk)
31 {
32         const struct vsock_transport *t = vsock_core_get_transport(vsk);
33
34         if (WARN_ON(!t))
35                 return NULL;
36
37         return container_of(t, struct virtio_transport, transport);
38 }
39
40 /* Returns a new packet on success, otherwise returns NULL.
41  *
42  * If NULL is returned, errp is set to a negative errno.
43  */
44 static struct sk_buff *
45 virtio_transport_alloc_skb(struct virtio_vsock_pkt_info *info,
46                            size_t len,
47                            u32 src_cid,
48                            u32 src_port,
49                            u32 dst_cid,
50                            u32 dst_port)
51 {
52         const size_t skb_len = VIRTIO_VSOCK_SKB_HEADROOM + len;
53         struct virtio_vsock_hdr *hdr;
54         struct sk_buff *skb;
55         void *payload;
56         int err;
57
58         skb = virtio_vsock_alloc_skb(skb_len, GFP_KERNEL);
59         if (!skb)
60                 return NULL;
61
62         hdr = virtio_vsock_hdr(skb);
63         hdr->type       = cpu_to_le16(info->type);
64         hdr->op         = cpu_to_le16(info->op);
65         hdr->src_cid    = cpu_to_le64(src_cid);
66         hdr->dst_cid    = cpu_to_le64(dst_cid);
67         hdr->src_port   = cpu_to_le32(src_port);
68         hdr->dst_port   = cpu_to_le32(dst_port);
69         hdr->flags      = cpu_to_le32(info->flags);
70         hdr->len        = cpu_to_le32(len);
71
72         if (info->msg && len > 0) {
73                 payload = skb_put(skb, len);
74                 err = memcpy_from_msg(payload, info->msg, len);
75                 if (err)
76                         goto out;
77
78                 if (msg_data_left(info->msg) == 0 &&
79                     info->type == VIRTIO_VSOCK_TYPE_SEQPACKET) {
80                         hdr->flags |= cpu_to_le32(VIRTIO_VSOCK_SEQ_EOM);
81
82                         if (info->msg->msg_flags & MSG_EOR)
83                                 hdr->flags |= cpu_to_le32(VIRTIO_VSOCK_SEQ_EOR);
84                 }
85         }
86
87         if (info->reply)
88                 virtio_vsock_skb_set_reply(skb);
89
90         trace_virtio_transport_alloc_pkt(src_cid, src_port,
91                                          dst_cid, dst_port,
92                                          len,
93                                          info->type,
94                                          info->op,
95                                          info->flags);
96
97         if (info->vsk && !skb_set_owner_sk_safe(skb, sk_vsock(info->vsk))) {
98                 WARN_ONCE(1, "failed to allocate skb on vsock socket with sk_refcnt == 0\n");
99                 goto out;
100         }
101
102         return skb;
103
104 out:
105         kfree_skb(skb);
106         return NULL;
107 }
108
109 /* Packet capture */
110 static struct sk_buff *virtio_transport_build_skb(void *opaque)
111 {
112         struct virtio_vsock_hdr *pkt_hdr;
113         struct sk_buff *pkt = opaque;
114         struct af_vsockmon_hdr *hdr;
115         struct sk_buff *skb;
116         size_t payload_len;
117         void *payload_buf;
118
119         /* A packet could be split to fit the RX buffer, so we can retrieve
120          * the payload length from the header and the buffer pointer taking
121          * care of the offset in the original packet.
122          */
123         pkt_hdr = virtio_vsock_hdr(pkt);
124         payload_len = pkt->len;
125         payload_buf = pkt->data;
126
127         skb = alloc_skb(sizeof(*hdr) + sizeof(*pkt_hdr) + payload_len,
128                         GFP_ATOMIC);
129         if (!skb)
130                 return NULL;
131
132         hdr = skb_put(skb, sizeof(*hdr));
133
134         /* pkt->hdr is little-endian so no need to byteswap here */
135         hdr->src_cid = pkt_hdr->src_cid;
136         hdr->src_port = pkt_hdr->src_port;
137         hdr->dst_cid = pkt_hdr->dst_cid;
138         hdr->dst_port = pkt_hdr->dst_port;
139
140         hdr->transport = cpu_to_le16(AF_VSOCK_TRANSPORT_VIRTIO);
141         hdr->len = cpu_to_le16(sizeof(*pkt_hdr));
142         memset(hdr->reserved, 0, sizeof(hdr->reserved));
143
144         switch (le16_to_cpu(pkt_hdr->op)) {
145         case VIRTIO_VSOCK_OP_REQUEST:
146         case VIRTIO_VSOCK_OP_RESPONSE:
147                 hdr->op = cpu_to_le16(AF_VSOCK_OP_CONNECT);
148                 break;
149         case VIRTIO_VSOCK_OP_RST:
150         case VIRTIO_VSOCK_OP_SHUTDOWN:
151                 hdr->op = cpu_to_le16(AF_VSOCK_OP_DISCONNECT);
152                 break;
153         case VIRTIO_VSOCK_OP_RW:
154                 hdr->op = cpu_to_le16(AF_VSOCK_OP_PAYLOAD);
155                 break;
156         case VIRTIO_VSOCK_OP_CREDIT_UPDATE:
157         case VIRTIO_VSOCK_OP_CREDIT_REQUEST:
158                 hdr->op = cpu_to_le16(AF_VSOCK_OP_CONTROL);
159                 break;
160         default:
161                 hdr->op = cpu_to_le16(AF_VSOCK_OP_UNKNOWN);
162                 break;
163         }
164
165         skb_put_data(skb, pkt_hdr, sizeof(*pkt_hdr));
166
167         if (payload_len) {
168                 skb_put_data(skb, payload_buf, payload_len);
169         }
170
171         return skb;
172 }
173
174 void virtio_transport_deliver_tap_pkt(struct sk_buff *skb)
175 {
176         if (virtio_vsock_skb_tap_delivered(skb))
177                 return;
178
179         vsock_deliver_tap(virtio_transport_build_skb, skb);
180         virtio_vsock_skb_set_tap_delivered(skb);
181 }
182 EXPORT_SYMBOL_GPL(virtio_transport_deliver_tap_pkt);
183
184 static u16 virtio_transport_get_type(struct sock *sk)
185 {
186         if (sk->sk_type == SOCK_STREAM)
187                 return VIRTIO_VSOCK_TYPE_STREAM;
188         else
189                 return VIRTIO_VSOCK_TYPE_SEQPACKET;
190 }
191
192 /* This function can only be used on connecting/connected sockets,
193  * since a socket assigned to a transport is required.
194  *
195  * Do not use on listener sockets!
196  */
197 static int virtio_transport_send_pkt_info(struct vsock_sock *vsk,
198                                           struct virtio_vsock_pkt_info *info)
199 {
200         u32 src_cid, src_port, dst_cid, dst_port;
201         const struct virtio_transport *t_ops;
202         struct virtio_vsock_sock *vvs;
203         u32 pkt_len = info->pkt_len;
204         u32 rest_len;
205         int ret;
206
207         info->type = virtio_transport_get_type(sk_vsock(vsk));
208
209         t_ops = virtio_transport_get_ops(vsk);
210         if (unlikely(!t_ops))
211                 return -EFAULT;
212
213         src_cid = t_ops->transport.get_local_cid();
214         src_port = vsk->local_addr.svm_port;
215         if (!info->remote_cid) {
216                 dst_cid = vsk->remote_addr.svm_cid;
217                 dst_port = vsk->remote_addr.svm_port;
218         } else {
219                 dst_cid = info->remote_cid;
220                 dst_port = info->remote_port;
221         }
222
223         vvs = vsk->trans;
224
225         /* virtio_transport_get_credit might return less than pkt_len credit */
226         pkt_len = virtio_transport_get_credit(vvs, pkt_len);
227
228         /* Do not send zero length OP_RW pkt */
229         if (pkt_len == 0 && info->op == VIRTIO_VSOCK_OP_RW)
230                 return pkt_len;
231
232         rest_len = pkt_len;
233
234         do {
235                 struct sk_buff *skb;
236                 size_t skb_len;
237
238                 skb_len = min_t(u32, VIRTIO_VSOCK_MAX_PKT_BUF_SIZE, rest_len);
239
240                 skb = virtio_transport_alloc_skb(info, skb_len,
241                                                  src_cid, src_port,
242                                                  dst_cid, dst_port);
243                 if (!skb) {
244                         ret = -ENOMEM;
245                         break;
246                 }
247
248                 virtio_transport_inc_tx_pkt(vvs, skb);
249
250                 ret = t_ops->send_pkt(skb);
251                 if (ret < 0)
252                         break;
253
254                 /* Both virtio and vhost 'send_pkt()' returns 'skb_len',
255                  * but for reliability use 'ret' instead of 'skb_len'.
256                  * Also if partial send happens (e.g. 'ret' != 'skb_len')
257                  * somehow, we break this loop, but account such returned
258                  * value in 'virtio_transport_put_credit()'.
259                  */
260                 rest_len -= ret;
261
262                 if (WARN_ONCE(ret != skb_len,
263                               "'send_pkt()' returns %i, but %zu expected\n",
264                               ret, skb_len))
265                         break;
266         } while (rest_len);
267
268         virtio_transport_put_credit(vvs, rest_len);
269
270         /* Return number of bytes, if any data has been sent. */
271         if (rest_len != pkt_len)
272                 ret = pkt_len - rest_len;
273
274         return ret;
275 }
276
277 static bool virtio_transport_inc_rx_pkt(struct virtio_vsock_sock *vvs,
278                                         u32 len)
279 {
280         if (vvs->rx_bytes + len > vvs->buf_alloc)
281                 return false;
282
283         vvs->rx_bytes += len;
284         return true;
285 }
286
287 static void virtio_transport_dec_rx_pkt(struct virtio_vsock_sock *vvs,
288                                         u32 len)
289 {
290         vvs->rx_bytes -= len;
291         vvs->fwd_cnt += len;
292 }
293
294 void virtio_transport_inc_tx_pkt(struct virtio_vsock_sock *vvs, struct sk_buff *skb)
295 {
296         struct virtio_vsock_hdr *hdr = virtio_vsock_hdr(skb);
297
298         spin_lock_bh(&vvs->rx_lock);
299         vvs->last_fwd_cnt = vvs->fwd_cnt;
300         hdr->fwd_cnt = cpu_to_le32(vvs->fwd_cnt);
301         hdr->buf_alloc = cpu_to_le32(vvs->buf_alloc);
302         spin_unlock_bh(&vvs->rx_lock);
303 }
304 EXPORT_SYMBOL_GPL(virtio_transport_inc_tx_pkt);
305
306 u32 virtio_transport_get_credit(struct virtio_vsock_sock *vvs, u32 credit)
307 {
308         u32 ret;
309
310         if (!credit)
311                 return 0;
312
313         spin_lock_bh(&vvs->tx_lock);
314         ret = vvs->peer_buf_alloc - (vvs->tx_cnt - vvs->peer_fwd_cnt);
315         if (ret > credit)
316                 ret = credit;
317         vvs->tx_cnt += ret;
318         spin_unlock_bh(&vvs->tx_lock);
319
320         return ret;
321 }
322 EXPORT_SYMBOL_GPL(virtio_transport_get_credit);
323
324 void virtio_transport_put_credit(struct virtio_vsock_sock *vvs, u32 credit)
325 {
326         if (!credit)
327                 return;
328
329         spin_lock_bh(&vvs->tx_lock);
330         vvs->tx_cnt -= credit;
331         spin_unlock_bh(&vvs->tx_lock);
332 }
333 EXPORT_SYMBOL_GPL(virtio_transport_put_credit);
334
335 static int virtio_transport_send_credit_update(struct vsock_sock *vsk)
336 {
337         struct virtio_vsock_pkt_info info = {
338                 .op = VIRTIO_VSOCK_OP_CREDIT_UPDATE,
339                 .vsk = vsk,
340         };
341
342         return virtio_transport_send_pkt_info(vsk, &info);
343 }
344
345 static ssize_t
346 virtio_transport_stream_do_peek(struct vsock_sock *vsk,
347                                 struct msghdr *msg,
348                                 size_t len)
349 {
350         struct virtio_vsock_sock *vvs = vsk->trans;
351         size_t bytes, total = 0, off;
352         struct sk_buff *skb, *tmp;
353         int err = -EFAULT;
354
355         spin_lock_bh(&vvs->rx_lock);
356
357         skb_queue_walk_safe(&vvs->rx_queue, skb,  tmp) {
358                 off = 0;
359
360                 if (total == len)
361                         break;
362
363                 while (total < len && off < skb->len) {
364                         bytes = len - total;
365                         if (bytes > skb->len - off)
366                                 bytes = skb->len - off;
367
368                         /* sk_lock is held by caller so no one else can dequeue.
369                          * Unlock rx_lock since memcpy_to_msg() may sleep.
370                          */
371                         spin_unlock_bh(&vvs->rx_lock);
372
373                         err = memcpy_to_msg(msg, skb->data + off, bytes);
374                         if (err)
375                                 goto out;
376
377                         spin_lock_bh(&vvs->rx_lock);
378
379                         total += bytes;
380                         off += bytes;
381                 }
382         }
383
384         spin_unlock_bh(&vvs->rx_lock);
385
386         return total;
387
388 out:
389         if (total)
390                 err = total;
391         return err;
392 }
393
394 static ssize_t
395 virtio_transport_stream_do_dequeue(struct vsock_sock *vsk,
396                                    struct msghdr *msg,
397                                    size_t len)
398 {
399         struct virtio_vsock_sock *vvs = vsk->trans;
400         size_t bytes, total = 0;
401         struct sk_buff *skb;
402         int err = -EFAULT;
403         u32 free_space;
404
405         spin_lock_bh(&vvs->rx_lock);
406
407         if (WARN_ONCE(skb_queue_empty(&vvs->rx_queue) && vvs->rx_bytes,
408                       "rx_queue is empty, but rx_bytes is non-zero\n")) {
409                 spin_unlock_bh(&vvs->rx_lock);
410                 return err;
411         }
412
413         while (total < len && !skb_queue_empty(&vvs->rx_queue)) {
414                 skb = skb_peek(&vvs->rx_queue);
415
416                 bytes = len - total;
417                 if (bytes > skb->len)
418                         bytes = skb->len;
419
420                 /* sk_lock is held by caller so no one else can dequeue.
421                  * Unlock rx_lock since memcpy_to_msg() may sleep.
422                  */
423                 spin_unlock_bh(&vvs->rx_lock);
424
425                 err = memcpy_to_msg(msg, skb->data, bytes);
426                 if (err)
427                         goto out;
428
429                 spin_lock_bh(&vvs->rx_lock);
430
431                 total += bytes;
432                 skb_pull(skb, bytes);
433
434                 if (skb->len == 0) {
435                         u32 pkt_len = le32_to_cpu(virtio_vsock_hdr(skb)->len);
436
437                         virtio_transport_dec_rx_pkt(vvs, pkt_len);
438                         __skb_unlink(skb, &vvs->rx_queue);
439                         consume_skb(skb);
440                 }
441         }
442
443         free_space = vvs->buf_alloc - (vvs->fwd_cnt - vvs->last_fwd_cnt);
444
445         spin_unlock_bh(&vvs->rx_lock);
446
447         /* To reduce the number of credit update messages,
448          * don't update credits as long as lots of space is available.
449          * Note: the limit chosen here is arbitrary. Setting the limit
450          * too high causes extra messages. Too low causes transmitter
451          * stalls. As stalls are in theory more expensive than extra
452          * messages, we set the limit to a high value. TODO: experiment
453          * with different values.
454          */
455         if (free_space < VIRTIO_VSOCK_MAX_PKT_BUF_SIZE)
456                 virtio_transport_send_credit_update(vsk);
457
458         return total;
459
460 out:
461         if (total)
462                 err = total;
463         return err;
464 }
465
466 static int virtio_transport_seqpacket_do_dequeue(struct vsock_sock *vsk,
467                                                  struct msghdr *msg,
468                                                  int flags)
469 {
470         struct virtio_vsock_sock *vvs = vsk->trans;
471         int dequeued_len = 0;
472         size_t user_buf_len = msg_data_left(msg);
473         bool msg_ready = false;
474         struct sk_buff *skb;
475
476         spin_lock_bh(&vvs->rx_lock);
477
478         if (vvs->msg_count == 0) {
479                 spin_unlock_bh(&vvs->rx_lock);
480                 return 0;
481         }
482
483         while (!msg_ready) {
484                 struct virtio_vsock_hdr *hdr;
485                 size_t pkt_len;
486
487                 skb = __skb_dequeue(&vvs->rx_queue);
488                 if (!skb)
489                         break;
490                 hdr = virtio_vsock_hdr(skb);
491                 pkt_len = (size_t)le32_to_cpu(hdr->len);
492
493                 if (dequeued_len >= 0) {
494                         size_t bytes_to_copy;
495
496                         bytes_to_copy = min(user_buf_len, pkt_len);
497
498                         if (bytes_to_copy) {
499                                 int err;
500
501                                 /* sk_lock is held by caller so no one else can dequeue.
502                                  * Unlock rx_lock since memcpy_to_msg() may sleep.
503                                  */
504                                 spin_unlock_bh(&vvs->rx_lock);
505
506                                 err = memcpy_to_msg(msg, skb->data, bytes_to_copy);
507                                 if (err) {
508                                         /* Copy of message failed. Rest of
509                                          * fragments will be freed without copy.
510                                          */
511                                         dequeued_len = err;
512                                 } else {
513                                         user_buf_len -= bytes_to_copy;
514                                 }
515
516                                 spin_lock_bh(&vvs->rx_lock);
517                         }
518
519                         if (dequeued_len >= 0)
520                                 dequeued_len += pkt_len;
521                 }
522
523                 if (le32_to_cpu(hdr->flags) & VIRTIO_VSOCK_SEQ_EOM) {
524                         msg_ready = true;
525                         vvs->msg_count--;
526
527                         if (le32_to_cpu(hdr->flags) & VIRTIO_VSOCK_SEQ_EOR)
528                                 msg->msg_flags |= MSG_EOR;
529                 }
530
531                 virtio_transport_dec_rx_pkt(vvs, pkt_len);
532                 kfree_skb(skb);
533         }
534
535         spin_unlock_bh(&vvs->rx_lock);
536
537         virtio_transport_send_credit_update(vsk);
538
539         return dequeued_len;
540 }
541
542 ssize_t
543 virtio_transport_stream_dequeue(struct vsock_sock *vsk,
544                                 struct msghdr *msg,
545                                 size_t len, int flags)
546 {
547         if (flags & MSG_PEEK)
548                 return virtio_transport_stream_do_peek(vsk, msg, len);
549         else
550                 return virtio_transport_stream_do_dequeue(vsk, msg, len);
551 }
552 EXPORT_SYMBOL_GPL(virtio_transport_stream_dequeue);
553
554 ssize_t
555 virtio_transport_seqpacket_dequeue(struct vsock_sock *vsk,
556                                    struct msghdr *msg,
557                                    int flags)
558 {
559         if (flags & MSG_PEEK)
560                 return -EOPNOTSUPP;
561
562         return virtio_transport_seqpacket_do_dequeue(vsk, msg, flags);
563 }
564 EXPORT_SYMBOL_GPL(virtio_transport_seqpacket_dequeue);
565
566 int
567 virtio_transport_seqpacket_enqueue(struct vsock_sock *vsk,
568                                    struct msghdr *msg,
569                                    size_t len)
570 {
571         struct virtio_vsock_sock *vvs = vsk->trans;
572
573         spin_lock_bh(&vvs->tx_lock);
574
575         if (len > vvs->peer_buf_alloc) {
576                 spin_unlock_bh(&vvs->tx_lock);
577                 return -EMSGSIZE;
578         }
579
580         spin_unlock_bh(&vvs->tx_lock);
581
582         return virtio_transport_stream_enqueue(vsk, msg, len);
583 }
584 EXPORT_SYMBOL_GPL(virtio_transport_seqpacket_enqueue);
585
586 int
587 virtio_transport_dgram_dequeue(struct vsock_sock *vsk,
588                                struct msghdr *msg,
589                                size_t len, int flags)
590 {
591         return -EOPNOTSUPP;
592 }
593 EXPORT_SYMBOL_GPL(virtio_transport_dgram_dequeue);
594
595 s64 virtio_transport_stream_has_data(struct vsock_sock *vsk)
596 {
597         struct virtio_vsock_sock *vvs = vsk->trans;
598         s64 bytes;
599
600         spin_lock_bh(&vvs->rx_lock);
601         bytes = vvs->rx_bytes;
602         spin_unlock_bh(&vvs->rx_lock);
603
604         return bytes;
605 }
606 EXPORT_SYMBOL_GPL(virtio_transport_stream_has_data);
607
608 u32 virtio_transport_seqpacket_has_data(struct vsock_sock *vsk)
609 {
610         struct virtio_vsock_sock *vvs = vsk->trans;
611         u32 msg_count;
612
613         spin_lock_bh(&vvs->rx_lock);
614         msg_count = vvs->msg_count;
615         spin_unlock_bh(&vvs->rx_lock);
616
617         return msg_count;
618 }
619 EXPORT_SYMBOL_GPL(virtio_transport_seqpacket_has_data);
620
621 static s64 virtio_transport_has_space(struct vsock_sock *vsk)
622 {
623         struct virtio_vsock_sock *vvs = vsk->trans;
624         s64 bytes;
625
626         bytes = vvs->peer_buf_alloc - (vvs->tx_cnt - vvs->peer_fwd_cnt);
627         if (bytes < 0)
628                 bytes = 0;
629
630         return bytes;
631 }
632
633 s64 virtio_transport_stream_has_space(struct vsock_sock *vsk)
634 {
635         struct virtio_vsock_sock *vvs = vsk->trans;
636         s64 bytes;
637
638         spin_lock_bh(&vvs->tx_lock);
639         bytes = virtio_transport_has_space(vsk);
640         spin_unlock_bh(&vvs->tx_lock);
641
642         return bytes;
643 }
644 EXPORT_SYMBOL_GPL(virtio_transport_stream_has_space);
645
646 int virtio_transport_do_socket_init(struct vsock_sock *vsk,
647                                     struct vsock_sock *psk)
648 {
649         struct virtio_vsock_sock *vvs;
650
651         vvs = kzalloc(sizeof(*vvs), GFP_KERNEL);
652         if (!vvs)
653                 return -ENOMEM;
654
655         vsk->trans = vvs;
656         vvs->vsk = vsk;
657         if (psk && psk->trans) {
658                 struct virtio_vsock_sock *ptrans = psk->trans;
659
660                 vvs->peer_buf_alloc = ptrans->peer_buf_alloc;
661         }
662
663         if (vsk->buffer_size > VIRTIO_VSOCK_MAX_BUF_SIZE)
664                 vsk->buffer_size = VIRTIO_VSOCK_MAX_BUF_SIZE;
665
666         vvs->buf_alloc = vsk->buffer_size;
667
668         spin_lock_init(&vvs->rx_lock);
669         spin_lock_init(&vvs->tx_lock);
670         skb_queue_head_init(&vvs->rx_queue);
671
672         return 0;
673 }
674 EXPORT_SYMBOL_GPL(virtio_transport_do_socket_init);
675
676 /* sk_lock held by the caller */
677 void virtio_transport_notify_buffer_size(struct vsock_sock *vsk, u64 *val)
678 {
679         struct virtio_vsock_sock *vvs = vsk->trans;
680
681         if (*val > VIRTIO_VSOCK_MAX_BUF_SIZE)
682                 *val = VIRTIO_VSOCK_MAX_BUF_SIZE;
683
684         vvs->buf_alloc = *val;
685
686         virtio_transport_send_credit_update(vsk);
687 }
688 EXPORT_SYMBOL_GPL(virtio_transport_notify_buffer_size);
689
690 int
691 virtio_transport_notify_poll_in(struct vsock_sock *vsk,
692                                 size_t target,
693                                 bool *data_ready_now)
694 {
695         *data_ready_now = vsock_stream_has_data(vsk) >= target;
696
697         return 0;
698 }
699 EXPORT_SYMBOL_GPL(virtio_transport_notify_poll_in);
700
701 int
702 virtio_transport_notify_poll_out(struct vsock_sock *vsk,
703                                  size_t target,
704                                  bool *space_avail_now)
705 {
706         s64 free_space;
707
708         free_space = vsock_stream_has_space(vsk);
709         if (free_space > 0)
710                 *space_avail_now = true;
711         else if (free_space == 0)
712                 *space_avail_now = false;
713
714         return 0;
715 }
716 EXPORT_SYMBOL_GPL(virtio_transport_notify_poll_out);
717
718 int virtio_transport_notify_recv_init(struct vsock_sock *vsk,
719         size_t target, struct vsock_transport_recv_notify_data *data)
720 {
721         return 0;
722 }
723 EXPORT_SYMBOL_GPL(virtio_transport_notify_recv_init);
724
725 int virtio_transport_notify_recv_pre_block(struct vsock_sock *vsk,
726         size_t target, struct vsock_transport_recv_notify_data *data)
727 {
728         return 0;
729 }
730 EXPORT_SYMBOL_GPL(virtio_transport_notify_recv_pre_block);
731
732 int virtio_transport_notify_recv_pre_dequeue(struct vsock_sock *vsk,
733         size_t target, struct vsock_transport_recv_notify_data *data)
734 {
735         return 0;
736 }
737 EXPORT_SYMBOL_GPL(virtio_transport_notify_recv_pre_dequeue);
738
739 int virtio_transport_notify_recv_post_dequeue(struct vsock_sock *vsk,
740         size_t target, ssize_t copied, bool data_read,
741         struct vsock_transport_recv_notify_data *data)
742 {
743         return 0;
744 }
745 EXPORT_SYMBOL_GPL(virtio_transport_notify_recv_post_dequeue);
746
/* Send-init notification: nothing to do here, always returns 0. */
int virtio_transport_notify_send_init(struct vsock_sock *vsk,
				      struct vsock_transport_send_notify_data *data)
{
	return 0;
}
EXPORT_SYMBOL_GPL(virtio_transport_notify_send_init);
753
/* Send pre-block notification: nothing to do here, always returns 0. */
int virtio_transport_notify_send_pre_block(struct vsock_sock *vsk,
					   struct vsock_transport_send_notify_data *data)
{
	return 0;
}
EXPORT_SYMBOL_GPL(virtio_transport_notify_send_pre_block);
760
/* Send pre-enqueue notification: nothing to do here, always returns 0. */
int virtio_transport_notify_send_pre_enqueue(struct vsock_sock *vsk,
					     struct vsock_transport_send_notify_data *data)
{
	return 0;
}
EXPORT_SYMBOL_GPL(virtio_transport_notify_send_pre_enqueue);
767
768 int virtio_transport_notify_send_post_enqueue(struct vsock_sock *vsk,
769         ssize_t written, struct vsock_transport_send_notify_data *data)
770 {
771         return 0;
772 }
773 EXPORT_SYMBOL_GPL(virtio_transport_notify_send_post_enqueue);
774
775 u64 virtio_transport_stream_rcvhiwat(struct vsock_sock *vsk)
776 {
777         return vsk->buffer_size;
778 }
779 EXPORT_SYMBOL_GPL(virtio_transport_stream_rcvhiwat);
780
781 bool virtio_transport_stream_is_active(struct vsock_sock *vsk)
782 {
783         return true;
784 }
785 EXPORT_SYMBOL_GPL(virtio_transport_stream_is_active);
786
787 bool virtio_transport_stream_allow(u32 cid, u32 port)
788 {
789         return true;
790 }
791 EXPORT_SYMBOL_GPL(virtio_transport_stream_allow);
792
793 int virtio_transport_dgram_bind(struct vsock_sock *vsk,
794                                 struct sockaddr_vm *addr)
795 {
796         return -EOPNOTSUPP;
797 }
798 EXPORT_SYMBOL_GPL(virtio_transport_dgram_bind);
799
800 bool virtio_transport_dgram_allow(u32 cid, u32 port)
801 {
802         return false;
803 }
804 EXPORT_SYMBOL_GPL(virtio_transport_dgram_allow);
805
806 int virtio_transport_connect(struct vsock_sock *vsk)
807 {
808         struct virtio_vsock_pkt_info info = {
809                 .op = VIRTIO_VSOCK_OP_REQUEST,
810                 .vsk = vsk,
811         };
812
813         return virtio_transport_send_pkt_info(vsk, &info);
814 }
815 EXPORT_SYMBOL_GPL(virtio_transport_connect);
816
817 int virtio_transport_shutdown(struct vsock_sock *vsk, int mode)
818 {
819         struct virtio_vsock_pkt_info info = {
820                 .op = VIRTIO_VSOCK_OP_SHUTDOWN,
821                 .flags = (mode & RCV_SHUTDOWN ?
822                           VIRTIO_VSOCK_SHUTDOWN_RCV : 0) |
823                          (mode & SEND_SHUTDOWN ?
824                           VIRTIO_VSOCK_SHUTDOWN_SEND : 0),
825                 .vsk = vsk,
826         };
827
828         return virtio_transport_send_pkt_info(vsk, &info);
829 }
830 EXPORT_SYMBOL_GPL(virtio_transport_shutdown);
831
832 int
833 virtio_transport_dgram_enqueue(struct vsock_sock *vsk,
834                                struct sockaddr_vm *remote_addr,
835                                struct msghdr *msg,
836                                size_t dgram_len)
837 {
838         return -EOPNOTSUPP;
839 }
840 EXPORT_SYMBOL_GPL(virtio_transport_dgram_enqueue);
841
842 ssize_t
843 virtio_transport_stream_enqueue(struct vsock_sock *vsk,
844                                 struct msghdr *msg,
845                                 size_t len)
846 {
847         struct virtio_vsock_pkt_info info = {
848                 .op = VIRTIO_VSOCK_OP_RW,
849                 .msg = msg,
850                 .pkt_len = len,
851                 .vsk = vsk,
852         };
853
854         return virtio_transport_send_pkt_info(vsk, &info);
855 }
856 EXPORT_SYMBOL_GPL(virtio_transport_stream_enqueue);
857
858 void virtio_transport_destruct(struct vsock_sock *vsk)
859 {
860         struct virtio_vsock_sock *vvs = vsk->trans;
861
862         kfree(vvs);
863 }
864 EXPORT_SYMBOL_GPL(virtio_transport_destruct);
865
866 static int virtio_transport_reset(struct vsock_sock *vsk,
867                                   struct sk_buff *skb)
868 {
869         struct virtio_vsock_pkt_info info = {
870                 .op = VIRTIO_VSOCK_OP_RST,
871                 .reply = !!skb,
872                 .vsk = vsk,
873         };
874
875         /* Send RST only if the original pkt is not a RST pkt */
876         if (skb && le16_to_cpu(virtio_vsock_hdr(skb)->op) == VIRTIO_VSOCK_OP_RST)
877                 return 0;
878
879         return virtio_transport_send_pkt_info(vsk, &info);
880 }
881
882 /* Normally packets are associated with a socket.  There may be no socket if an
883  * attempt was made to connect to a socket that does not exist.
884  */
885 static int virtio_transport_reset_no_sock(const struct virtio_transport *t,
886                                           struct sk_buff *skb)
887 {
888         struct virtio_vsock_hdr *hdr = virtio_vsock_hdr(skb);
889         struct virtio_vsock_pkt_info info = {
890                 .op = VIRTIO_VSOCK_OP_RST,
891                 .type = le16_to_cpu(hdr->type),
892                 .reply = true,
893         };
894         struct sk_buff *reply;
895
896         /* Send RST only if the original pkt is not a RST pkt */
897         if (le16_to_cpu(hdr->op) == VIRTIO_VSOCK_OP_RST)
898                 return 0;
899
900         if (!t)
901                 return -ENOTCONN;
902
903         reply = virtio_transport_alloc_skb(&info, 0,
904                                            le64_to_cpu(hdr->dst_cid),
905                                            le32_to_cpu(hdr->dst_port),
906                                            le64_to_cpu(hdr->src_cid),
907                                            le32_to_cpu(hdr->src_port));
908         if (!reply)
909                 return -ENOMEM;
910
911         return t->send_pkt(reply);
912 }
913
914 /* This function should be called with sk_lock held and SOCK_DONE set */
915 static void virtio_transport_remove_sock(struct vsock_sock *vsk)
916 {
917         struct virtio_vsock_sock *vvs = vsk->trans;
918
919         /* We don't need to take rx_lock, as the socket is closing and we are
920          * removing it.
921          */
922         __skb_queue_purge(&vvs->rx_queue);
923         vsock_remove_sock(vsk);
924 }
925
926 static void virtio_transport_wait_close(struct sock *sk, long timeout)
927 {
928         if (timeout) {
929                 DEFINE_WAIT_FUNC(wait, woken_wake_function);
930
931                 add_wait_queue(sk_sleep(sk), &wait);
932
933                 do {
934                         if (sk_wait_event(sk, &timeout,
935                                           sock_flag(sk, SOCK_DONE), &wait))
936                                 break;
937                 } while (!signal_pending(current) && timeout);
938
939                 remove_wait_queue(sk_sleep(sk), &wait);
940         }
941 }
942
943 static void virtio_transport_do_close(struct vsock_sock *vsk,
944                                       bool cancel_timeout)
945 {
946         struct sock *sk = sk_vsock(vsk);
947
948         sock_set_flag(sk, SOCK_DONE);
949         vsk->peer_shutdown = SHUTDOWN_MASK;
950         if (vsock_stream_has_data(vsk) <= 0)
951                 sk->sk_state = TCP_CLOSING;
952         sk->sk_state_change(sk);
953
954         if (vsk->close_work_scheduled &&
955             (!cancel_timeout || cancel_delayed_work(&vsk->close_work))) {
956                 vsk->close_work_scheduled = false;
957
958                 virtio_transport_remove_sock(vsk);
959
960                 /* Release refcnt obtained when we scheduled the timeout */
961                 sock_put(sk);
962         }
963 }
964
965 static void virtio_transport_close_timeout(struct work_struct *work)
966 {
967         struct vsock_sock *vsk =
968                 container_of(work, struct vsock_sock, close_work.work);
969         struct sock *sk = sk_vsock(vsk);
970
971         sock_hold(sk);
972         lock_sock(sk);
973
974         if (!sock_flag(sk, SOCK_DONE)) {
975                 (void)virtio_transport_reset(vsk, NULL);
976
977                 virtio_transport_do_close(vsk, false);
978         }
979
980         vsk->close_work_scheduled = false;
981
982         release_sock(sk);
983         sock_put(sk);
984 }
985
986 /* User context, vsk->sk is locked */
987 static bool virtio_transport_close(struct vsock_sock *vsk)
988 {
989         struct sock *sk = &vsk->sk;
990
991         if (!(sk->sk_state == TCP_ESTABLISHED ||
992               sk->sk_state == TCP_CLOSING))
993                 return true;
994
995         /* Already received SHUTDOWN from peer, reply with RST */
996         if ((vsk->peer_shutdown & SHUTDOWN_MASK) == SHUTDOWN_MASK) {
997                 (void)virtio_transport_reset(vsk, NULL);
998                 return true;
999         }
1000
1001         if ((sk->sk_shutdown & SHUTDOWN_MASK) != SHUTDOWN_MASK)
1002                 (void)virtio_transport_shutdown(vsk, SHUTDOWN_MASK);
1003
1004         if (sock_flag(sk, SOCK_LINGER) && !(current->flags & PF_EXITING))
1005                 virtio_transport_wait_close(sk, sk->sk_lingertime);
1006
1007         if (sock_flag(sk, SOCK_DONE)) {
1008                 return true;
1009         }
1010
1011         sock_hold(sk);
1012         INIT_DELAYED_WORK(&vsk->close_work,
1013                           virtio_transport_close_timeout);
1014         vsk->close_work_scheduled = true;
1015         schedule_delayed_work(&vsk->close_work, VSOCK_CLOSE_TIMEOUT);
1016         return false;
1017 }
1018
1019 void virtio_transport_release(struct vsock_sock *vsk)
1020 {
1021         struct sock *sk = &vsk->sk;
1022         bool remove_sock = true;
1023
1024         if (sk->sk_type == SOCK_STREAM || sk->sk_type == SOCK_SEQPACKET)
1025                 remove_sock = virtio_transport_close(vsk);
1026
1027         if (remove_sock) {
1028                 sock_set_flag(sk, SOCK_DONE);
1029                 virtio_transport_remove_sock(vsk);
1030         }
1031 }
1032 EXPORT_SYMBOL_GPL(virtio_transport_release);
1033
1034 static int
1035 virtio_transport_recv_connecting(struct sock *sk,
1036                                  struct sk_buff *skb)
1037 {
1038         struct virtio_vsock_hdr *hdr = virtio_vsock_hdr(skb);
1039         struct vsock_sock *vsk = vsock_sk(sk);
1040         int skerr;
1041         int err;
1042
1043         switch (le16_to_cpu(hdr->op)) {
1044         case VIRTIO_VSOCK_OP_RESPONSE:
1045                 sk->sk_state = TCP_ESTABLISHED;
1046                 sk->sk_socket->state = SS_CONNECTED;
1047                 vsock_insert_connected(vsk);
1048                 sk->sk_state_change(sk);
1049                 break;
1050         case VIRTIO_VSOCK_OP_INVALID:
1051                 break;
1052         case VIRTIO_VSOCK_OP_RST:
1053                 skerr = ECONNRESET;
1054                 err = 0;
1055                 goto destroy;
1056         default:
1057                 skerr = EPROTO;
1058                 err = -EINVAL;
1059                 goto destroy;
1060         }
1061         return 0;
1062
1063 destroy:
1064         virtio_transport_reset(vsk, skb);
1065         sk->sk_state = TCP_CLOSE;
1066         sk->sk_err = skerr;
1067         sk_error_report(sk);
1068         return err;
1069 }
1070
1071 static void
1072 virtio_transport_recv_enqueue(struct vsock_sock *vsk,
1073                               struct sk_buff *skb)
1074 {
1075         struct virtio_vsock_sock *vvs = vsk->trans;
1076         bool can_enqueue, free_pkt = false;
1077         struct virtio_vsock_hdr *hdr;
1078         u32 len;
1079
1080         hdr = virtio_vsock_hdr(skb);
1081         len = le32_to_cpu(hdr->len);
1082
1083         spin_lock_bh(&vvs->rx_lock);
1084
1085         can_enqueue = virtio_transport_inc_rx_pkt(vvs, len);
1086         if (!can_enqueue) {
1087                 free_pkt = true;
1088                 goto out;
1089         }
1090
1091         if (le32_to_cpu(hdr->flags) & VIRTIO_VSOCK_SEQ_EOM)
1092                 vvs->msg_count++;
1093
1094         /* Try to copy small packets into the buffer of last packet queued,
1095          * to avoid wasting memory queueing the entire buffer with a small
1096          * payload.
1097          */
1098         if (len <= GOOD_COPY_LEN && !skb_queue_empty(&vvs->rx_queue)) {
1099                 struct virtio_vsock_hdr *last_hdr;
1100                 struct sk_buff *last_skb;
1101
1102                 last_skb = skb_peek_tail(&vvs->rx_queue);
1103                 last_hdr = virtio_vsock_hdr(last_skb);
1104
1105                 /* If there is space in the last packet queued, we copy the
1106                  * new packet in its buffer. We avoid this if the last packet
1107                  * queued has VIRTIO_VSOCK_SEQ_EOM set, because this is
1108                  * delimiter of SEQPACKET message, so 'pkt' is the first packet
1109                  * of a new message.
1110                  */
1111                 if (skb->len < skb_tailroom(last_skb) &&
1112                     !(le32_to_cpu(last_hdr->flags) & VIRTIO_VSOCK_SEQ_EOM)) {
1113                         memcpy(skb_put(last_skb, skb->len), skb->data, skb->len);
1114                         free_pkt = true;
1115                         last_hdr->flags |= hdr->flags;
1116                         le32_add_cpu(&last_hdr->len, len);
1117                         goto out;
1118                 }
1119         }
1120
1121         __skb_queue_tail(&vvs->rx_queue, skb);
1122
1123 out:
1124         spin_unlock_bh(&vvs->rx_lock);
1125         if (free_pkt)
1126                 kfree_skb(skb);
1127 }
1128
1129 static int
1130 virtio_transport_recv_connected(struct sock *sk,
1131                                 struct sk_buff *skb)
1132 {
1133         struct virtio_vsock_hdr *hdr = virtio_vsock_hdr(skb);
1134         struct vsock_sock *vsk = vsock_sk(sk);
1135         int err = 0;
1136
1137         switch (le16_to_cpu(hdr->op)) {
1138         case VIRTIO_VSOCK_OP_RW:
1139                 virtio_transport_recv_enqueue(vsk, skb);
1140                 vsock_data_ready(sk);
1141                 return err;
1142         case VIRTIO_VSOCK_OP_CREDIT_REQUEST:
1143                 virtio_transport_send_credit_update(vsk);
1144                 break;
1145         case VIRTIO_VSOCK_OP_CREDIT_UPDATE:
1146                 sk->sk_write_space(sk);
1147                 break;
1148         case VIRTIO_VSOCK_OP_SHUTDOWN:
1149                 if (le32_to_cpu(hdr->flags) & VIRTIO_VSOCK_SHUTDOWN_RCV)
1150                         vsk->peer_shutdown |= RCV_SHUTDOWN;
1151                 if (le32_to_cpu(hdr->flags) & VIRTIO_VSOCK_SHUTDOWN_SEND)
1152                         vsk->peer_shutdown |= SEND_SHUTDOWN;
1153                 if (vsk->peer_shutdown == SHUTDOWN_MASK &&
1154                     vsock_stream_has_data(vsk) <= 0 &&
1155                     !sock_flag(sk, SOCK_DONE)) {
1156                         (void)virtio_transport_reset(vsk, NULL);
1157                         virtio_transport_do_close(vsk, true);
1158                 }
1159                 if (le32_to_cpu(virtio_vsock_hdr(skb)->flags))
1160                         sk->sk_state_change(sk);
1161                 break;
1162         case VIRTIO_VSOCK_OP_RST:
1163                 virtio_transport_do_close(vsk, true);
1164                 break;
1165         default:
1166                 err = -EINVAL;
1167                 break;
1168         }
1169
1170         kfree_skb(skb);
1171         return err;
1172 }
1173
1174 static void
1175 virtio_transport_recv_disconnecting(struct sock *sk,
1176                                     struct sk_buff *skb)
1177 {
1178         struct virtio_vsock_hdr *hdr = virtio_vsock_hdr(skb);
1179         struct vsock_sock *vsk = vsock_sk(sk);
1180
1181         if (le16_to_cpu(hdr->op) == VIRTIO_VSOCK_OP_RST)
1182                 virtio_transport_do_close(vsk, true);
1183 }
1184
1185 static int
1186 virtio_transport_send_response(struct vsock_sock *vsk,
1187                                struct sk_buff *skb)
1188 {
1189         struct virtio_vsock_hdr *hdr = virtio_vsock_hdr(skb);
1190         struct virtio_vsock_pkt_info info = {
1191                 .op = VIRTIO_VSOCK_OP_RESPONSE,
1192                 .remote_cid = le64_to_cpu(hdr->src_cid),
1193                 .remote_port = le32_to_cpu(hdr->src_port),
1194                 .reply = true,
1195                 .vsk = vsk,
1196         };
1197
1198         return virtio_transport_send_pkt_info(vsk, &info);
1199 }
1200
1201 static bool virtio_transport_space_update(struct sock *sk,
1202                                           struct sk_buff *skb)
1203 {
1204         struct virtio_vsock_hdr *hdr = virtio_vsock_hdr(skb);
1205         struct vsock_sock *vsk = vsock_sk(sk);
1206         struct virtio_vsock_sock *vvs = vsk->trans;
1207         bool space_available;
1208
1209         /* Listener sockets are not associated with any transport, so we are
1210          * not able to take the state to see if there is space available in the
1211          * remote peer, but since they are only used to receive requests, we
1212          * can assume that there is always space available in the other peer.
1213          */
1214         if (!vvs)
1215                 return true;
1216
1217         /* buf_alloc and fwd_cnt is always included in the hdr */
1218         spin_lock_bh(&vvs->tx_lock);
1219         vvs->peer_buf_alloc = le32_to_cpu(hdr->buf_alloc);
1220         vvs->peer_fwd_cnt = le32_to_cpu(hdr->fwd_cnt);
1221         space_available = virtio_transport_has_space(vsk);
1222         spin_unlock_bh(&vvs->tx_lock);
1223         return space_available;
1224 }
1225
1226 /* Handle server socket */
1227 static int
1228 virtio_transport_recv_listen(struct sock *sk, struct sk_buff *skb,
1229                              struct virtio_transport *t)
1230 {
1231         struct virtio_vsock_hdr *hdr = virtio_vsock_hdr(skb);
1232         struct vsock_sock *vsk = vsock_sk(sk);
1233         struct vsock_sock *vchild;
1234         struct sock *child;
1235         int ret;
1236
1237         if (le16_to_cpu(hdr->op) != VIRTIO_VSOCK_OP_REQUEST) {
1238                 virtio_transport_reset_no_sock(t, skb);
1239                 return -EINVAL;
1240         }
1241
1242         if (sk_acceptq_is_full(sk)) {
1243                 virtio_transport_reset_no_sock(t, skb);
1244                 return -ENOMEM;
1245         }
1246
1247         child = vsock_create_connected(sk);
1248         if (!child) {
1249                 virtio_transport_reset_no_sock(t, skb);
1250                 return -ENOMEM;
1251         }
1252
1253         sk_acceptq_added(sk);
1254
1255         lock_sock_nested(child, SINGLE_DEPTH_NESTING);
1256
1257         child->sk_state = TCP_ESTABLISHED;
1258
1259         vchild = vsock_sk(child);
1260         vsock_addr_init(&vchild->local_addr, le64_to_cpu(hdr->dst_cid),
1261                         le32_to_cpu(hdr->dst_port));
1262         vsock_addr_init(&vchild->remote_addr, le64_to_cpu(hdr->src_cid),
1263                         le32_to_cpu(hdr->src_port));
1264
1265         ret = vsock_assign_transport(vchild, vsk);
1266         /* Transport assigned (looking at remote_addr) must be the same
1267          * where we received the request.
1268          */
1269         if (ret || vchild->transport != &t->transport) {
1270                 release_sock(child);
1271                 virtio_transport_reset_no_sock(t, skb);
1272                 sock_put(child);
1273                 return ret;
1274         }
1275
1276         if (virtio_transport_space_update(child, skb))
1277                 child->sk_write_space(child);
1278
1279         vsock_insert_connected(vchild);
1280         vsock_enqueue_accept(sk, child);
1281         virtio_transport_send_response(vchild, skb);
1282
1283         release_sock(child);
1284
1285         sk->sk_data_ready(sk);
1286         return 0;
1287 }
1288
1289 static bool virtio_transport_valid_type(u16 type)
1290 {
1291         return (type == VIRTIO_VSOCK_TYPE_STREAM) ||
1292                (type == VIRTIO_VSOCK_TYPE_SEQPACKET);
1293 }
1294
1295 /* We are under the virtio-vsock's vsock->rx_lock or vhost-vsock's vq->mutex
1296  * lock.
1297  */
1298 void virtio_transport_recv_pkt(struct virtio_transport *t,
1299                                struct sk_buff *skb)
1300 {
1301         struct virtio_vsock_hdr *hdr = virtio_vsock_hdr(skb);
1302         struct sockaddr_vm src, dst;
1303         struct vsock_sock *vsk;
1304         struct sock *sk;
1305         bool space_available;
1306
1307         vsock_addr_init(&src, le64_to_cpu(hdr->src_cid),
1308                         le32_to_cpu(hdr->src_port));
1309         vsock_addr_init(&dst, le64_to_cpu(hdr->dst_cid),
1310                         le32_to_cpu(hdr->dst_port));
1311
1312         trace_virtio_transport_recv_pkt(src.svm_cid, src.svm_port,
1313                                         dst.svm_cid, dst.svm_port,
1314                                         le32_to_cpu(hdr->len),
1315                                         le16_to_cpu(hdr->type),
1316                                         le16_to_cpu(hdr->op),
1317                                         le32_to_cpu(hdr->flags),
1318                                         le32_to_cpu(hdr->buf_alloc),
1319                                         le32_to_cpu(hdr->fwd_cnt));
1320
1321         if (!virtio_transport_valid_type(le16_to_cpu(hdr->type))) {
1322                 (void)virtio_transport_reset_no_sock(t, skb);
1323                 goto free_pkt;
1324         }
1325
1326         /* The socket must be in connected or bound table
1327          * otherwise send reset back
1328          */
1329         sk = vsock_find_connected_socket(&src, &dst);
1330         if (!sk) {
1331                 sk = vsock_find_bound_socket(&dst);
1332                 if (!sk) {
1333                         (void)virtio_transport_reset_no_sock(t, skb);
1334                         goto free_pkt;
1335                 }
1336         }
1337
1338         if (virtio_transport_get_type(sk) != le16_to_cpu(hdr->type)) {
1339                 (void)virtio_transport_reset_no_sock(t, skb);
1340                 sock_put(sk);
1341                 goto free_pkt;
1342         }
1343
1344         if (!skb_set_owner_sk_safe(skb, sk)) {
1345                 WARN_ONCE(1, "receiving vsock socket has sk_refcnt == 0\n");
1346                 goto free_pkt;
1347         }
1348
1349         vsk = vsock_sk(sk);
1350
1351         lock_sock(sk);
1352
1353         /* Check if sk has been closed before lock_sock */
1354         if (sock_flag(sk, SOCK_DONE)) {
1355                 (void)virtio_transport_reset_no_sock(t, skb);
1356                 release_sock(sk);
1357                 sock_put(sk);
1358                 goto free_pkt;
1359         }
1360
1361         space_available = virtio_transport_space_update(sk, skb);
1362
1363         /* Update CID in case it has changed after a transport reset event */
1364         if (vsk->local_addr.svm_cid != VMADDR_CID_ANY)
1365                 vsk->local_addr.svm_cid = dst.svm_cid;
1366
1367         if (space_available)
1368                 sk->sk_write_space(sk);
1369
1370         switch (sk->sk_state) {
1371         case TCP_LISTEN:
1372                 virtio_transport_recv_listen(sk, skb, t);
1373                 kfree_skb(skb);
1374                 break;
1375         case TCP_SYN_SENT:
1376                 virtio_transport_recv_connecting(sk, skb);
1377                 kfree_skb(skb);
1378                 break;
1379         case TCP_ESTABLISHED:
1380                 virtio_transport_recv_connected(sk, skb);
1381                 break;
1382         case TCP_CLOSING:
1383                 virtio_transport_recv_disconnecting(sk, skb);
1384                 kfree_skb(skb);
1385                 break;
1386         default:
1387                 (void)virtio_transport_reset_no_sock(t, skb);
1388                 kfree_skb(skb);
1389                 break;
1390         }
1391
1392         release_sock(sk);
1393
1394         /* Release refcnt obtained when we fetched this socket out of the
1395          * bound or connected list.
1396          */
1397         sock_put(sk);
1398         return;
1399
1400 free_pkt:
1401         kfree_skb(skb);
1402 }
1403 EXPORT_SYMBOL_GPL(virtio_transport_recv_pkt);
1404
1405 /* Remove skbs found in a queue that have a vsk that matches.
1406  *
1407  * Each skb is freed.
1408  *
1409  * Returns the count of skbs that were reply packets.
1410  */
1411 int virtio_transport_purge_skbs(void *vsk, struct sk_buff_head *queue)
1412 {
1413         struct sk_buff_head freeme;
1414         struct sk_buff *skb, *tmp;
1415         int cnt = 0;
1416
1417         skb_queue_head_init(&freeme);
1418
1419         spin_lock_bh(&queue->lock);
1420         skb_queue_walk_safe(queue, skb, tmp) {
1421                 if (vsock_sk(skb->sk) != vsk)
1422                         continue;
1423
1424                 __skb_unlink(skb, queue);
1425                 __skb_queue_tail(&freeme, skb);
1426
1427                 if (virtio_vsock_skb_reply(skb))
1428                         cnt++;
1429         }
1430         spin_unlock_bh(&queue->lock);
1431
1432         __skb_queue_purge(&freeme);
1433
1434         return cnt;
1435 }
1436 EXPORT_SYMBOL_GPL(virtio_transport_purge_skbs);
1437
1438 int virtio_transport_read_skb(struct vsock_sock *vsk, skb_read_actor_t recv_actor)
1439 {
1440         struct virtio_vsock_sock *vvs = vsk->trans;
1441         struct sock *sk = sk_vsock(vsk);
1442         struct sk_buff *skb;
1443         int off = 0;
1444         int err;
1445
1446         spin_lock_bh(&vvs->rx_lock);
1447         /* Use __skb_recv_datagram() for race-free handling of the receive. It
1448          * works for types other than dgrams.
1449          */
1450         skb = __skb_recv_datagram(sk, &vvs->rx_queue, MSG_DONTWAIT, &off, &err);
1451         spin_unlock_bh(&vvs->rx_lock);
1452
1453         if (!skb)
1454                 return err;
1455
1456         return recv_actor(sk, skb);
1457 }
1458 EXPORT_SYMBOL_GPL(virtio_transport_read_skb);
1459
1460 MODULE_LICENSE("GPL v2");
1461 MODULE_AUTHOR("Asias He");
1462 MODULE_DESCRIPTION("common code for virtio vsock");