GNU Linux-libre 4.9.294-gnu1
[releases.git] / net / vmw_vsock / virtio_transport_common.c
1 /*
2  * common code for virtio vsock
3  *
4  * Copyright (C) 2013-2015 Red Hat, Inc.
5  * Author: Asias He <asias@redhat.com>
6  *         Stefan Hajnoczi <stefanha@redhat.com>
7  *
8  * This work is licensed under the terms of the GNU GPL, version 2.
9  */
10 #include <linux/spinlock.h>
11 #include <linux/module.h>
12 #include <linux/ctype.h>
13 #include <linux/list.h>
14 #include <linux/virtio.h>
15 #include <linux/virtio_ids.h>
16 #include <linux/virtio_config.h>
17 #include <linux/virtio_vsock.h>
18
19 #include <net/sock.h>
20 #include <net/af_vsock.h>
21
22 #define CREATE_TRACE_POINTS
23 #include <trace/events/vsock_virtio_transport_common.h>
24
25 /* How long to wait for graceful shutdown of a connection */
26 #define VSOCK_CLOSE_TIMEOUT (8 * HZ)
27
28 static const struct virtio_transport *virtio_transport_get_ops(void)
29 {
30         const struct vsock_transport *t = vsock_core_get_transport();
31
32         return container_of(t, struct virtio_transport, transport);
33 }
34
35 struct virtio_vsock_pkt *
36 virtio_transport_alloc_pkt(struct virtio_vsock_pkt_info *info,
37                            size_t len,
38                            u32 src_cid,
39                            u32 src_port,
40                            u32 dst_cid,
41                            u32 dst_port)
42 {
43         struct virtio_vsock_pkt *pkt;
44         int err;
45
46         pkt = kzalloc(sizeof(*pkt), GFP_KERNEL);
47         if (!pkt)
48                 return NULL;
49
50         pkt->hdr.type           = cpu_to_le16(info->type);
51         pkt->hdr.op             = cpu_to_le16(info->op);
52         pkt->hdr.src_cid        = cpu_to_le64(src_cid);
53         pkt->hdr.dst_cid        = cpu_to_le64(dst_cid);
54         pkt->hdr.src_port       = cpu_to_le32(src_port);
55         pkt->hdr.dst_port       = cpu_to_le32(dst_port);
56         pkt->hdr.flags          = cpu_to_le32(info->flags);
57         pkt->len                = len;
58         pkt->hdr.len            = cpu_to_le32(len);
59         pkt->reply              = info->reply;
60         pkt->vsk                = info->vsk;
61
62         if (info->msg && len > 0) {
63                 pkt->buf = kmalloc(len, GFP_KERNEL);
64                 if (!pkt->buf)
65                         goto out_pkt;
66                 err = memcpy_from_msg(pkt->buf, info->msg, len);
67                 if (err)
68                         goto out;
69         }
70
71         trace_virtio_transport_alloc_pkt(src_cid, src_port,
72                                          dst_cid, dst_port,
73                                          len,
74                                          info->type,
75                                          info->op,
76                                          info->flags);
77
78         return pkt;
79
80 out:
81         kfree(pkt->buf);
82 out_pkt:
83         kfree(pkt);
84         return NULL;
85 }
86 EXPORT_SYMBOL_GPL(virtio_transport_alloc_pkt);
87
88 static int virtio_transport_send_pkt_info(struct vsock_sock *vsk,
89                                           struct virtio_vsock_pkt_info *info)
90 {
91         u32 src_cid, src_port, dst_cid, dst_port;
92         struct virtio_vsock_sock *vvs;
93         struct virtio_vsock_pkt *pkt;
94         u32 pkt_len = info->pkt_len;
95
96         src_cid = vm_sockets_get_local_cid();
97         src_port = vsk->local_addr.svm_port;
98         if (!info->remote_cid) {
99                 dst_cid = vsk->remote_addr.svm_cid;
100                 dst_port = vsk->remote_addr.svm_port;
101         } else {
102                 dst_cid = info->remote_cid;
103                 dst_port = info->remote_port;
104         }
105
106         vvs = vsk->trans;
107
108         /* we can send less than pkt_len bytes */
109         if (pkt_len > VIRTIO_VSOCK_DEFAULT_RX_BUF_SIZE)
110                 pkt_len = VIRTIO_VSOCK_DEFAULT_RX_BUF_SIZE;
111
112         /* virtio_transport_get_credit might return less than pkt_len credit */
113         pkt_len = virtio_transport_get_credit(vvs, pkt_len);
114
115         /* Do not send zero length OP_RW pkt */
116         if (pkt_len == 0 && info->op == VIRTIO_VSOCK_OP_RW)
117                 return pkt_len;
118
119         pkt = virtio_transport_alloc_pkt(info, pkt_len,
120                                          src_cid, src_port,
121                                          dst_cid, dst_port);
122         if (!pkt) {
123                 virtio_transport_put_credit(vvs, pkt_len);
124                 return -ENOMEM;
125         }
126
127         virtio_transport_inc_tx_pkt(vvs, pkt);
128
129         return virtio_transport_get_ops()->send_pkt(pkt);
130 }
131
132 static void virtio_transport_inc_rx_pkt(struct virtio_vsock_sock *vvs,
133                                         struct virtio_vsock_pkt *pkt)
134 {
135         vvs->rx_bytes += pkt->len;
136 }
137
138 static void virtio_transport_dec_rx_pkt(struct virtio_vsock_sock *vvs,
139                                         struct virtio_vsock_pkt *pkt)
140 {
141         vvs->rx_bytes -= pkt->len;
142         vvs->fwd_cnt += pkt->len;
143 }
144
145 void virtio_transport_inc_tx_pkt(struct virtio_vsock_sock *vvs, struct virtio_vsock_pkt *pkt)
146 {
147         spin_lock_bh(&vvs->tx_lock);
148         pkt->hdr.fwd_cnt = cpu_to_le32(vvs->fwd_cnt);
149         pkt->hdr.buf_alloc = cpu_to_le32(vvs->buf_alloc);
150         spin_unlock_bh(&vvs->tx_lock);
151 }
152 EXPORT_SYMBOL_GPL(virtio_transport_inc_tx_pkt);
153
154 u32 virtio_transport_get_credit(struct virtio_vsock_sock *vvs, u32 credit)
155 {
156         u32 ret;
157
158         spin_lock_bh(&vvs->tx_lock);
159         ret = vvs->peer_buf_alloc - (vvs->tx_cnt - vvs->peer_fwd_cnt);
160         if (ret > credit)
161                 ret = credit;
162         vvs->tx_cnt += ret;
163         spin_unlock_bh(&vvs->tx_lock);
164
165         return ret;
166 }
167 EXPORT_SYMBOL_GPL(virtio_transport_get_credit);
168
169 void virtio_transport_put_credit(struct virtio_vsock_sock *vvs, u32 credit)
170 {
171         spin_lock_bh(&vvs->tx_lock);
172         vvs->tx_cnt -= credit;
173         spin_unlock_bh(&vvs->tx_lock);
174 }
175 EXPORT_SYMBOL_GPL(virtio_transport_put_credit);
176
177 static int virtio_transport_send_credit_update(struct vsock_sock *vsk,
178                                                int type,
179                                                struct virtio_vsock_hdr *hdr)
180 {
181         struct virtio_vsock_pkt_info info = {
182                 .op = VIRTIO_VSOCK_OP_CREDIT_UPDATE,
183                 .type = type,
184                 .vsk = vsk,
185         };
186
187         return virtio_transport_send_pkt_info(vsk, &info);
188 }
189
190 static ssize_t
191 virtio_transport_stream_do_dequeue(struct vsock_sock *vsk,
192                                    struct msghdr *msg,
193                                    size_t len)
194 {
195         struct virtio_vsock_sock *vvs = vsk->trans;
196         struct virtio_vsock_pkt *pkt;
197         size_t bytes, total = 0;
198         int err = -EFAULT;
199
200         spin_lock_bh(&vvs->rx_lock);
201         while (total < len && !list_empty(&vvs->rx_queue)) {
202                 pkt = list_first_entry(&vvs->rx_queue,
203                                        struct virtio_vsock_pkt, list);
204
205                 bytes = len - total;
206                 if (bytes > pkt->len - pkt->off)
207                         bytes = pkt->len - pkt->off;
208
209                 /* sk_lock is held by caller so no one else can dequeue.
210                  * Unlock rx_lock since memcpy_to_msg() may sleep.
211                  */
212                 spin_unlock_bh(&vvs->rx_lock);
213
214                 err = memcpy_to_msg(msg, pkt->buf + pkt->off, bytes);
215                 if (err)
216                         goto out;
217
218                 spin_lock_bh(&vvs->rx_lock);
219
220                 total += bytes;
221                 pkt->off += bytes;
222                 if (pkt->off == pkt->len) {
223                         virtio_transport_dec_rx_pkt(vvs, pkt);
224                         list_del(&pkt->list);
225                         virtio_transport_free_pkt(pkt);
226                 }
227         }
228         spin_unlock_bh(&vvs->rx_lock);
229
230         /* Send a credit pkt to peer */
231         virtio_transport_send_credit_update(vsk, VIRTIO_VSOCK_TYPE_STREAM,
232                                             NULL);
233
234         return total;
235
236 out:
237         if (total)
238                 err = total;
239         return err;
240 }
241
242 ssize_t
243 virtio_transport_stream_dequeue(struct vsock_sock *vsk,
244                                 struct msghdr *msg,
245                                 size_t len, int flags)
246 {
247         if (flags & MSG_PEEK)
248                 return -EOPNOTSUPP;
249
250         return virtio_transport_stream_do_dequeue(vsk, msg, len);
251 }
252 EXPORT_SYMBOL_GPL(virtio_transport_stream_dequeue);
253
254 int
255 virtio_transport_dgram_dequeue(struct vsock_sock *vsk,
256                                struct msghdr *msg,
257                                size_t len, int flags)
258 {
259         return -EOPNOTSUPP;
260 }
261 EXPORT_SYMBOL_GPL(virtio_transport_dgram_dequeue);
262
263 s64 virtio_transport_stream_has_data(struct vsock_sock *vsk)
264 {
265         struct virtio_vsock_sock *vvs = vsk->trans;
266         s64 bytes;
267
268         spin_lock_bh(&vvs->rx_lock);
269         bytes = vvs->rx_bytes;
270         spin_unlock_bh(&vvs->rx_lock);
271
272         return bytes;
273 }
274 EXPORT_SYMBOL_GPL(virtio_transport_stream_has_data);
275
276 static s64 virtio_transport_has_space(struct vsock_sock *vsk)
277 {
278         struct virtio_vsock_sock *vvs = vsk->trans;
279         s64 bytes;
280
281         bytes = vvs->peer_buf_alloc - (vvs->tx_cnt - vvs->peer_fwd_cnt);
282         if (bytes < 0)
283                 bytes = 0;
284
285         return bytes;
286 }
287
288 s64 virtio_transport_stream_has_space(struct vsock_sock *vsk)
289 {
290         struct virtio_vsock_sock *vvs = vsk->trans;
291         s64 bytes;
292
293         spin_lock_bh(&vvs->tx_lock);
294         bytes = virtio_transport_has_space(vsk);
295         spin_unlock_bh(&vvs->tx_lock);
296
297         return bytes;
298 }
299 EXPORT_SYMBOL_GPL(virtio_transport_stream_has_space);
300
301 int virtio_transport_do_socket_init(struct vsock_sock *vsk,
302                                     struct vsock_sock *psk)
303 {
304         struct virtio_vsock_sock *vvs;
305
306         vvs = kzalloc(sizeof(*vvs), GFP_KERNEL);
307         if (!vvs)
308                 return -ENOMEM;
309
310         vsk->trans = vvs;
311         vvs->vsk = vsk;
312         if (psk) {
313                 struct virtio_vsock_sock *ptrans = psk->trans;
314
315                 vvs->buf_size   = ptrans->buf_size;
316                 vvs->buf_size_min = ptrans->buf_size_min;
317                 vvs->buf_size_max = ptrans->buf_size_max;
318                 vvs->peer_buf_alloc = ptrans->peer_buf_alloc;
319         } else {
320                 vvs->buf_size = VIRTIO_VSOCK_DEFAULT_BUF_SIZE;
321                 vvs->buf_size_min = VIRTIO_VSOCK_DEFAULT_MIN_BUF_SIZE;
322                 vvs->buf_size_max = VIRTIO_VSOCK_DEFAULT_MAX_BUF_SIZE;
323         }
324
325         vvs->buf_alloc = vvs->buf_size;
326
327         spin_lock_init(&vvs->rx_lock);
328         spin_lock_init(&vvs->tx_lock);
329         INIT_LIST_HEAD(&vvs->rx_queue);
330
331         return 0;
332 }
333 EXPORT_SYMBOL_GPL(virtio_transport_do_socket_init);
334
335 u64 virtio_transport_get_buffer_size(struct vsock_sock *vsk)
336 {
337         struct virtio_vsock_sock *vvs = vsk->trans;
338
339         return vvs->buf_size;
340 }
341 EXPORT_SYMBOL_GPL(virtio_transport_get_buffer_size);
342
343 u64 virtio_transport_get_min_buffer_size(struct vsock_sock *vsk)
344 {
345         struct virtio_vsock_sock *vvs = vsk->trans;
346
347         return vvs->buf_size_min;
348 }
349 EXPORT_SYMBOL_GPL(virtio_transport_get_min_buffer_size);
350
351 u64 virtio_transport_get_max_buffer_size(struct vsock_sock *vsk)
352 {
353         struct virtio_vsock_sock *vvs = vsk->trans;
354
355         return vvs->buf_size_max;
356 }
357 EXPORT_SYMBOL_GPL(virtio_transport_get_max_buffer_size);
358
359 void virtio_transport_set_buffer_size(struct vsock_sock *vsk, u64 val)
360 {
361         struct virtio_vsock_sock *vvs = vsk->trans;
362
363         if (val > VIRTIO_VSOCK_MAX_BUF_SIZE)
364                 val = VIRTIO_VSOCK_MAX_BUF_SIZE;
365         if (val < vvs->buf_size_min)
366                 vvs->buf_size_min = val;
367         if (val > vvs->buf_size_max)
368                 vvs->buf_size_max = val;
369         vvs->buf_size = val;
370         vvs->buf_alloc = val;
371 }
372 EXPORT_SYMBOL_GPL(virtio_transport_set_buffer_size);
373
374 void virtio_transport_set_min_buffer_size(struct vsock_sock *vsk, u64 val)
375 {
376         struct virtio_vsock_sock *vvs = vsk->trans;
377
378         if (val > VIRTIO_VSOCK_MAX_BUF_SIZE)
379                 val = VIRTIO_VSOCK_MAX_BUF_SIZE;
380         if (val > vvs->buf_size)
381                 vvs->buf_size = val;
382         vvs->buf_size_min = val;
383 }
384 EXPORT_SYMBOL_GPL(virtio_transport_set_min_buffer_size);
385
386 void virtio_transport_set_max_buffer_size(struct vsock_sock *vsk, u64 val)
387 {
388         struct virtio_vsock_sock *vvs = vsk->trans;
389
390         if (val > VIRTIO_VSOCK_MAX_BUF_SIZE)
391                 val = VIRTIO_VSOCK_MAX_BUF_SIZE;
392         if (val < vvs->buf_size)
393                 vvs->buf_size = val;
394         vvs->buf_size_max = val;
395 }
396 EXPORT_SYMBOL_GPL(virtio_transport_set_max_buffer_size);
397
398 int
399 virtio_transport_notify_poll_in(struct vsock_sock *vsk,
400                                 size_t target,
401                                 bool *data_ready_now)
402 {
403         if (vsock_stream_has_data(vsk))
404                 *data_ready_now = true;
405         else
406                 *data_ready_now = false;
407
408         return 0;
409 }
410 EXPORT_SYMBOL_GPL(virtio_transport_notify_poll_in);
411
412 int
413 virtio_transport_notify_poll_out(struct vsock_sock *vsk,
414                                  size_t target,
415                                  bool *space_avail_now)
416 {
417         s64 free_space;
418
419         free_space = vsock_stream_has_space(vsk);
420         if (free_space > 0)
421                 *space_avail_now = true;
422         else if (free_space == 0)
423                 *space_avail_now = false;
424
425         return 0;
426 }
427 EXPORT_SYMBOL_GPL(virtio_transport_notify_poll_out);
428
429 int virtio_transport_notify_recv_init(struct vsock_sock *vsk,
430         size_t target, struct vsock_transport_recv_notify_data *data)
431 {
432         return 0;
433 }
434 EXPORT_SYMBOL_GPL(virtio_transport_notify_recv_init);
435
436 int virtio_transport_notify_recv_pre_block(struct vsock_sock *vsk,
437         size_t target, struct vsock_transport_recv_notify_data *data)
438 {
439         return 0;
440 }
441 EXPORT_SYMBOL_GPL(virtio_transport_notify_recv_pre_block);
442
443 int virtio_transport_notify_recv_pre_dequeue(struct vsock_sock *vsk,
444         size_t target, struct vsock_transport_recv_notify_data *data)
445 {
446         return 0;
447 }
448 EXPORT_SYMBOL_GPL(virtio_transport_notify_recv_pre_dequeue);
449
450 int virtio_transport_notify_recv_post_dequeue(struct vsock_sock *vsk,
451         size_t target, ssize_t copied, bool data_read,
452         struct vsock_transport_recv_notify_data *data)
453 {
454         return 0;
455 }
456 EXPORT_SYMBOL_GPL(virtio_transport_notify_recv_post_dequeue);
457
/* Send-path notify hook: nothing to do for virtio; always succeeds. */
int
virtio_transport_notify_send_init(struct vsock_sock *vsk,
				  struct vsock_transport_send_notify_data *data)
{
	return 0;
}
EXPORT_SYMBOL_GPL(virtio_transport_notify_send_init);
464
/* Send-path notify hook: nothing to do for virtio; always succeeds. */
int
virtio_transport_notify_send_pre_block(struct vsock_sock *vsk,
				       struct vsock_transport_send_notify_data *data)
{
	return 0;
}
EXPORT_SYMBOL_GPL(virtio_transport_notify_send_pre_block);
471
/* Send-path notify hook: nothing to do for virtio; always succeeds. */
int
virtio_transport_notify_send_pre_enqueue(struct vsock_sock *vsk,
					 struct vsock_transport_send_notify_data *data)
{
	return 0;
}
EXPORT_SYMBOL_GPL(virtio_transport_notify_send_pre_enqueue);
478
479 int virtio_transport_notify_send_post_enqueue(struct vsock_sock *vsk,
480         ssize_t written, struct vsock_transport_send_notify_data *data)
481 {
482         return 0;
483 }
484 EXPORT_SYMBOL_GPL(virtio_transport_notify_send_post_enqueue);
485
486 u64 virtio_transport_stream_rcvhiwat(struct vsock_sock *vsk)
487 {
488         struct virtio_vsock_sock *vvs = vsk->trans;
489
490         return vvs->buf_size;
491 }
492 EXPORT_SYMBOL_GPL(virtio_transport_stream_rcvhiwat);
493
494 bool virtio_transport_stream_is_active(struct vsock_sock *vsk)
495 {
496         return true;
497 }
498 EXPORT_SYMBOL_GPL(virtio_transport_stream_is_active);
499
500 bool virtio_transport_stream_allow(u32 cid, u32 port)
501 {
502         return true;
503 }
504 EXPORT_SYMBOL_GPL(virtio_transport_stream_allow);
505
506 int virtio_transport_dgram_bind(struct vsock_sock *vsk,
507                                 struct sockaddr_vm *addr)
508 {
509         return -EOPNOTSUPP;
510 }
511 EXPORT_SYMBOL_GPL(virtio_transport_dgram_bind);
512
513 bool virtio_transport_dgram_allow(u32 cid, u32 port)
514 {
515         return false;
516 }
517 EXPORT_SYMBOL_GPL(virtio_transport_dgram_allow);
518
519 int virtio_transport_connect(struct vsock_sock *vsk)
520 {
521         struct virtio_vsock_pkt_info info = {
522                 .op = VIRTIO_VSOCK_OP_REQUEST,
523                 .type = VIRTIO_VSOCK_TYPE_STREAM,
524                 .vsk = vsk,
525         };
526
527         return virtio_transport_send_pkt_info(vsk, &info);
528 }
529 EXPORT_SYMBOL_GPL(virtio_transport_connect);
530
531 int virtio_transport_shutdown(struct vsock_sock *vsk, int mode)
532 {
533         struct virtio_vsock_pkt_info info = {
534                 .op = VIRTIO_VSOCK_OP_SHUTDOWN,
535                 .type = VIRTIO_VSOCK_TYPE_STREAM,
536                 .flags = (mode & RCV_SHUTDOWN ?
537                           VIRTIO_VSOCK_SHUTDOWN_RCV : 0) |
538                          (mode & SEND_SHUTDOWN ?
539                           VIRTIO_VSOCK_SHUTDOWN_SEND : 0),
540                 .vsk = vsk,
541         };
542
543         return virtio_transport_send_pkt_info(vsk, &info);
544 }
545 EXPORT_SYMBOL_GPL(virtio_transport_shutdown);
546
547 int
548 virtio_transport_dgram_enqueue(struct vsock_sock *vsk,
549                                struct sockaddr_vm *remote_addr,
550                                struct msghdr *msg,
551                                size_t dgram_len)
552 {
553         return -EOPNOTSUPP;
554 }
555 EXPORT_SYMBOL_GPL(virtio_transport_dgram_enqueue);
556
557 ssize_t
558 virtio_transport_stream_enqueue(struct vsock_sock *vsk,
559                                 struct msghdr *msg,
560                                 size_t len)
561 {
562         struct virtio_vsock_pkt_info info = {
563                 .op = VIRTIO_VSOCK_OP_RW,
564                 .type = VIRTIO_VSOCK_TYPE_STREAM,
565                 .msg = msg,
566                 .pkt_len = len,
567                 .vsk = vsk,
568         };
569
570         return virtio_transport_send_pkt_info(vsk, &info);
571 }
572 EXPORT_SYMBOL_GPL(virtio_transport_stream_enqueue);
573
574 void virtio_transport_destruct(struct vsock_sock *vsk)
575 {
576         struct virtio_vsock_sock *vvs = vsk->trans;
577
578         kfree(vvs);
579 }
580 EXPORT_SYMBOL_GPL(virtio_transport_destruct);
581
582 static int virtio_transport_reset(struct vsock_sock *vsk,
583                                   struct virtio_vsock_pkt *pkt)
584 {
585         struct virtio_vsock_pkt_info info = {
586                 .op = VIRTIO_VSOCK_OP_RST,
587                 .type = VIRTIO_VSOCK_TYPE_STREAM,
588                 .reply = !!pkt,
589                 .vsk = vsk,
590         };
591
592         /* Send RST only if the original pkt is not a RST pkt */
593         if (pkt && le16_to_cpu(pkt->hdr.op) == VIRTIO_VSOCK_OP_RST)
594                 return 0;
595
596         return virtio_transport_send_pkt_info(vsk, &info);
597 }
598
599 /* Normally packets are associated with a socket.  There may be no socket if an
600  * attempt was made to connect to a socket that does not exist.
601  */
602 static int virtio_transport_reset_no_sock(struct virtio_vsock_pkt *pkt)
603 {
604         const struct virtio_transport *t;
605         struct virtio_vsock_pkt *reply;
606         struct virtio_vsock_pkt_info info = {
607                 .op = VIRTIO_VSOCK_OP_RST,
608                 .type = le16_to_cpu(pkt->hdr.type),
609                 .reply = true,
610         };
611
612         /* Send RST only if the original pkt is not a RST pkt */
613         if (le16_to_cpu(pkt->hdr.op) == VIRTIO_VSOCK_OP_RST)
614                 return 0;
615
616         reply = virtio_transport_alloc_pkt(&info, 0,
617                                            le64_to_cpu(pkt->hdr.dst_cid),
618                                            le32_to_cpu(pkt->hdr.dst_port),
619                                            le64_to_cpu(pkt->hdr.src_cid),
620                                            le32_to_cpu(pkt->hdr.src_port));
621         if (!reply)
622                 return -ENOMEM;
623
624         t = virtio_transport_get_ops();
625         if (!t) {
626                 virtio_transport_free_pkt(reply);
627                 return -ENOTCONN;
628         }
629
630         return t->send_pkt(reply);
631 }
632
633 static void virtio_transport_wait_close(struct sock *sk, long timeout)
634 {
635         if (timeout) {
636                 DEFINE_WAIT(wait);
637
638                 do {
639                         prepare_to_wait(sk_sleep(sk), &wait,
640                                         TASK_INTERRUPTIBLE);
641                         if (sk_wait_event(sk, &timeout,
642                                           sock_flag(sk, SOCK_DONE)))
643                                 break;
644                 } while (!signal_pending(current) && timeout);
645
646                 finish_wait(sk_sleep(sk), &wait);
647         }
648 }
649
650 static void virtio_transport_do_close(struct vsock_sock *vsk,
651                                       bool cancel_timeout)
652 {
653         struct sock *sk = sk_vsock(vsk);
654
655         sock_set_flag(sk, SOCK_DONE);
656         vsk->peer_shutdown = SHUTDOWN_MASK;
657         if (vsock_stream_has_data(vsk) <= 0)
658                 sk->sk_state = SS_DISCONNECTING;
659         sk->sk_state_change(sk);
660
661         if (vsk->close_work_scheduled &&
662             (!cancel_timeout || cancel_delayed_work(&vsk->close_work))) {
663                 vsk->close_work_scheduled = false;
664
665                 vsock_remove_sock(vsk);
666
667                 /* Release refcnt obtained when we scheduled the timeout */
668                 sock_put(sk);
669         }
670 }
671
672 static void virtio_transport_close_timeout(struct work_struct *work)
673 {
674         struct vsock_sock *vsk =
675                 container_of(work, struct vsock_sock, close_work.work);
676         struct sock *sk = sk_vsock(vsk);
677
678         sock_hold(sk);
679         lock_sock(sk);
680
681         if (!sock_flag(sk, SOCK_DONE)) {
682                 (void)virtio_transport_reset(vsk, NULL);
683
684                 virtio_transport_do_close(vsk, false);
685         }
686
687         vsk->close_work_scheduled = false;
688
689         release_sock(sk);
690         sock_put(sk);
691 }
692
693 /* User context, vsk->sk is locked */
694 static bool virtio_transport_close(struct vsock_sock *vsk)
695 {
696         struct sock *sk = &vsk->sk;
697
698         if (!(sk->sk_state == SS_CONNECTED ||
699               sk->sk_state == SS_DISCONNECTING))
700                 return true;
701
702         /* Already received SHUTDOWN from peer, reply with RST */
703         if ((vsk->peer_shutdown & SHUTDOWN_MASK) == SHUTDOWN_MASK) {
704                 (void)virtio_transport_reset(vsk, NULL);
705                 return true;
706         }
707
708         if ((sk->sk_shutdown & SHUTDOWN_MASK) != SHUTDOWN_MASK)
709                 (void)virtio_transport_shutdown(vsk, SHUTDOWN_MASK);
710
711         if (sock_flag(sk, SOCK_LINGER) && !(current->flags & PF_EXITING))
712                 virtio_transport_wait_close(sk, sk->sk_lingertime);
713
714         if (sock_flag(sk, SOCK_DONE)) {
715                 return true;
716         }
717
718         sock_hold(sk);
719         INIT_DELAYED_WORK(&vsk->close_work,
720                           virtio_transport_close_timeout);
721         vsk->close_work_scheduled = true;
722         schedule_delayed_work(&vsk->close_work, VSOCK_CLOSE_TIMEOUT);
723         return false;
724 }
725
726 void virtio_transport_release(struct vsock_sock *vsk)
727 {
728         struct virtio_vsock_sock *vvs = vsk->trans;
729         struct virtio_vsock_pkt *pkt, *tmp;
730         struct sock *sk = &vsk->sk;
731         bool remove_sock = true;
732
733         lock_sock(sk);
734         if (sk->sk_type == SOCK_STREAM)
735                 remove_sock = virtio_transport_close(vsk);
736
737         list_for_each_entry_safe(pkt, tmp, &vvs->rx_queue, list) {
738                 list_del(&pkt->list);
739                 virtio_transport_free_pkt(pkt);
740         }
741         release_sock(sk);
742
743         if (remove_sock)
744                 vsock_remove_sock(vsk);
745 }
746 EXPORT_SYMBOL_GPL(virtio_transport_release);
747
748 static int
749 virtio_transport_recv_connecting(struct sock *sk,
750                                  struct virtio_vsock_pkt *pkt)
751 {
752         struct vsock_sock *vsk = vsock_sk(sk);
753         int err;
754         int skerr;
755
756         switch (le16_to_cpu(pkt->hdr.op)) {
757         case VIRTIO_VSOCK_OP_RESPONSE:
758                 sk->sk_state = SS_CONNECTED;
759                 sk->sk_socket->state = SS_CONNECTED;
760                 vsock_insert_connected(vsk);
761                 sk->sk_state_change(sk);
762                 break;
763         case VIRTIO_VSOCK_OP_INVALID:
764                 break;
765         case VIRTIO_VSOCK_OP_RST:
766                 skerr = ECONNRESET;
767                 err = 0;
768                 goto destroy;
769         default:
770                 skerr = EPROTO;
771                 err = -EINVAL;
772                 goto destroy;
773         }
774         return 0;
775
776 destroy:
777         virtio_transport_reset(vsk, pkt);
778         sk->sk_state = SS_UNCONNECTED;
779         sk->sk_err = skerr;
780         sk->sk_error_report(sk);
781         return err;
782 }
783
784 static int
785 virtio_transport_recv_connected(struct sock *sk,
786                                 struct virtio_vsock_pkt *pkt)
787 {
788         struct vsock_sock *vsk = vsock_sk(sk);
789         struct virtio_vsock_sock *vvs = vsk->trans;
790         int err = 0;
791
792         switch (le16_to_cpu(pkt->hdr.op)) {
793         case VIRTIO_VSOCK_OP_RW:
794                 pkt->len = le32_to_cpu(pkt->hdr.len);
795                 pkt->off = 0;
796
797                 spin_lock_bh(&vvs->rx_lock);
798                 virtio_transport_inc_rx_pkt(vvs, pkt);
799                 list_add_tail(&pkt->list, &vvs->rx_queue);
800                 spin_unlock_bh(&vvs->rx_lock);
801
802                 sk->sk_data_ready(sk);
803                 return err;
804         case VIRTIO_VSOCK_OP_CREDIT_UPDATE:
805                 sk->sk_write_space(sk);
806                 break;
807         case VIRTIO_VSOCK_OP_SHUTDOWN:
808                 if (le32_to_cpu(pkt->hdr.flags) & VIRTIO_VSOCK_SHUTDOWN_RCV)
809                         vsk->peer_shutdown |= RCV_SHUTDOWN;
810                 if (le32_to_cpu(pkt->hdr.flags) & VIRTIO_VSOCK_SHUTDOWN_SEND)
811                         vsk->peer_shutdown |= SEND_SHUTDOWN;
812                 if (vsk->peer_shutdown == SHUTDOWN_MASK &&
813                     vsock_stream_has_data(vsk) <= 0)
814                         sk->sk_state = SS_DISCONNECTING;
815                 if (le32_to_cpu(pkt->hdr.flags))
816                         sk->sk_state_change(sk);
817                 break;
818         case VIRTIO_VSOCK_OP_RST:
819                 virtio_transport_do_close(vsk, true);
820                 break;
821         default:
822                 err = -EINVAL;
823                 break;
824         }
825
826         virtio_transport_free_pkt(pkt);
827         return err;
828 }
829
830 static void
831 virtio_transport_recv_disconnecting(struct sock *sk,
832                                     struct virtio_vsock_pkt *pkt)
833 {
834         struct vsock_sock *vsk = vsock_sk(sk);
835
836         if (le16_to_cpu(pkt->hdr.op) == VIRTIO_VSOCK_OP_RST)
837                 virtio_transport_do_close(vsk, true);
838 }
839
840 static int
841 virtio_transport_send_response(struct vsock_sock *vsk,
842                                struct virtio_vsock_pkt *pkt)
843 {
844         struct virtio_vsock_pkt_info info = {
845                 .op = VIRTIO_VSOCK_OP_RESPONSE,
846                 .type = VIRTIO_VSOCK_TYPE_STREAM,
847                 .remote_cid = le64_to_cpu(pkt->hdr.src_cid),
848                 .remote_port = le32_to_cpu(pkt->hdr.src_port),
849                 .reply = true,
850                 .vsk = vsk,
851         };
852
853         return virtio_transport_send_pkt_info(vsk, &info);
854 }
855
856 /* Handle server socket */
857 static int
858 virtio_transport_recv_listen(struct sock *sk, struct virtio_vsock_pkt *pkt)
859 {
860         struct vsock_sock *vsk = vsock_sk(sk);
861         struct vsock_sock *vchild;
862         struct sock *child;
863
864         if (le16_to_cpu(pkt->hdr.op) != VIRTIO_VSOCK_OP_REQUEST) {
865                 virtio_transport_reset(vsk, pkt);
866                 return -EINVAL;
867         }
868
869         if (sk_acceptq_is_full(sk)) {
870                 virtio_transport_reset(vsk, pkt);
871                 return -ENOMEM;
872         }
873
874         child = __vsock_create(sock_net(sk), NULL, sk, GFP_KERNEL,
875                                sk->sk_type, 0);
876         if (!child) {
877                 virtio_transport_reset(vsk, pkt);
878                 return -ENOMEM;
879         }
880
881         sk->sk_ack_backlog++;
882
883         lock_sock_nested(child, SINGLE_DEPTH_NESTING);
884
885         child->sk_state = SS_CONNECTED;
886
887         vchild = vsock_sk(child);
888         vsock_addr_init(&vchild->local_addr, le64_to_cpu(pkt->hdr.dst_cid),
889                         le32_to_cpu(pkt->hdr.dst_port));
890         vsock_addr_init(&vchild->remote_addr, le64_to_cpu(pkt->hdr.src_cid),
891                         le32_to_cpu(pkt->hdr.src_port));
892
893         vsock_insert_connected(vchild);
894         vsock_enqueue_accept(sk, child);
895         virtio_transport_send_response(vchild, pkt);
896
897         release_sock(child);
898
899         sk->sk_data_ready(sk);
900         return 0;
901 }
902
903 static bool virtio_transport_space_update(struct sock *sk,
904                                           struct virtio_vsock_pkt *pkt)
905 {
906         struct vsock_sock *vsk = vsock_sk(sk);
907         struct virtio_vsock_sock *vvs = vsk->trans;
908         bool space_available;
909
910         /* buf_alloc and fwd_cnt is always included in the hdr */
911         spin_lock_bh(&vvs->tx_lock);
912         vvs->peer_buf_alloc = le32_to_cpu(pkt->hdr.buf_alloc);
913         vvs->peer_fwd_cnt = le32_to_cpu(pkt->hdr.fwd_cnt);
914         space_available = virtio_transport_has_space(vsk);
915         spin_unlock_bh(&vvs->tx_lock);
916         return space_available;
917 }
918
919 /* We are under the virtio-vsock's vsock->rx_lock or vhost-vsock's vq->mutex
920  * lock.
921  */
922 void virtio_transport_recv_pkt(struct virtio_vsock_pkt *pkt)
923 {
924         struct sockaddr_vm src, dst;
925         struct vsock_sock *vsk;
926         struct sock *sk;
927         bool space_available;
928
929         vsock_addr_init(&src, le64_to_cpu(pkt->hdr.src_cid),
930                         le32_to_cpu(pkt->hdr.src_port));
931         vsock_addr_init(&dst, le64_to_cpu(pkt->hdr.dst_cid),
932                         le32_to_cpu(pkt->hdr.dst_port));
933
934         trace_virtio_transport_recv_pkt(src.svm_cid, src.svm_port,
935                                         dst.svm_cid, dst.svm_port,
936                                         le32_to_cpu(pkt->hdr.len),
937                                         le16_to_cpu(pkt->hdr.type),
938                                         le16_to_cpu(pkt->hdr.op),
939                                         le32_to_cpu(pkt->hdr.flags),
940                                         le32_to_cpu(pkt->hdr.buf_alloc),
941                                         le32_to_cpu(pkt->hdr.fwd_cnt));
942
943         if (le16_to_cpu(pkt->hdr.type) != VIRTIO_VSOCK_TYPE_STREAM) {
944                 (void)virtio_transport_reset_no_sock(pkt);
945                 goto free_pkt;
946         }
947
948         /* The socket must be in connected or bound table
949          * otherwise send reset back
950          */
951         sk = vsock_find_connected_socket(&src, &dst);
952         if (!sk) {
953                 sk = vsock_find_bound_socket(&dst);
954                 if (!sk) {
955                         (void)virtio_transport_reset_no_sock(pkt);
956                         goto free_pkt;
957                 }
958         }
959
960         vsk = vsock_sk(sk);
961
962         lock_sock(sk);
963
964         space_available = virtio_transport_space_update(sk, pkt);
965
966         /* Update CID in case it has changed after a transport reset event */
967         vsk->local_addr.svm_cid = dst.svm_cid;
968
969         if (space_available)
970                 sk->sk_write_space(sk);
971
972         switch (sk->sk_state) {
973         case VSOCK_SS_LISTEN:
974                 virtio_transport_recv_listen(sk, pkt);
975                 virtio_transport_free_pkt(pkt);
976                 break;
977         case SS_CONNECTING:
978                 virtio_transport_recv_connecting(sk, pkt);
979                 virtio_transport_free_pkt(pkt);
980                 break;
981         case SS_CONNECTED:
982                 virtio_transport_recv_connected(sk, pkt);
983                 break;
984         case SS_DISCONNECTING:
985                 virtio_transport_recv_disconnecting(sk, pkt);
986                 virtio_transport_free_pkt(pkt);
987                 break;
988         default:
989                 virtio_transport_free_pkt(pkt);
990                 break;
991         }
992         release_sock(sk);
993
994         /* Release refcnt obtained when we fetched this socket out of the
995          * bound or connected list.
996          */
997         sock_put(sk);
998         return;
999
1000 free_pkt:
1001         virtio_transport_free_pkt(pkt);
1002 }
1003 EXPORT_SYMBOL_GPL(virtio_transport_recv_pkt);
1004
1005 void virtio_transport_free_pkt(struct virtio_vsock_pkt *pkt)
1006 {
1007         kfree(pkt->buf);
1008         kfree(pkt);
1009 }
1010 EXPORT_SYMBOL_GPL(virtio_transport_free_pkt);
1011
/* Module metadata */
MODULE_LICENSE("GPL v2");
MODULE_AUTHOR("Asias He");
MODULE_DESCRIPTION("common code for virtio vsock");