// SPDX-License-Identifier: GPL-2.0-or-later
/*
 * drivers/xen/pvcalls-front.c
 *
 * (c) 2017 Stefano Stabellini <stefano@aporeto.com>
 */

#include <linux/module.h>
#include <linux/net.h>
#include <linux/socket.h>

#include <net/sock.h>

#include <xen/events.h>
#include <xen/grant_table.h>
#include <xen/xen.h>
#include <xen/xenbus.h>
#include <xen/interface/io/pvcalls.h>

#include "pvcalls-front.h"

#define PVCALLS_INVALID_ID UINT_MAX
#define PVCALLS_RING_ORDER XENBUS_MAX_RING_GRANT_ORDER
#define PVCALLS_NR_RSP_PER_RING __CONST_RING_SIZE(xen_pvcalls, XEN_PAGE_SIZE)
#define PVCALLS_FRONT_MAX_SPIN 5000

static struct proto pvcalls_proto = {
        .name   = "PVCalls",
        .owner  = THIS_MODULE,
        .obj_size = sizeof(struct sock),
};

struct pvcalls_bedata {
        struct xen_pvcalls_front_ring ring;
        grant_ref_t ref;
        int irq;

        struct list_head socket_mappings;
        spinlock_t socket_lock;

        wait_queue_head_t inflight_req;
        struct xen_pvcalls_response rsp[PVCALLS_NR_RSP_PER_RING];
};
/* Only one front/back connection supported. */
static struct xenbus_device *pvcalls_front_dev;
static atomic_t pvcalls_refcount;

/* first increment refcount, then proceed */
#define pvcalls_enter() {               \
        atomic_inc(&pvcalls_refcount);      \
}

/* first complete other operations, then decrement refcount */
#define pvcalls_exit() {                \
        atomic_dec(&pvcalls_refcount);      \
}
struct sock_mapping {
        bool active_socket;
        struct list_head list;
        struct socket *sock;
        atomic_t refcount;
        union {
                struct {
                        int irq;
                        grant_ref_t ref;
                        struct pvcalls_data_intf *ring;
                        struct pvcalls_data data;
                        struct mutex in_mutex;
                        struct mutex out_mutex;

                        wait_queue_head_t inflight_conn_req;
                } active;
                struct {
                /*
                 * Socket status, needs to be 64-bit aligned due to the
                 * test_and_* functions which have this requirement on arm64.
                 */
#define PVCALLS_STATUS_UNINITIALIZED 0
#define PVCALLS_STATUS_BIND          1
#define PVCALLS_STATUS_LISTEN        2
                        uint8_t status __attribute__((aligned(8)));
                /*
                 * Internal state-machine flags.
                 * Only one accept operation can be inflight for a socket.
                 * Only one poll operation can be inflight for a given socket.
                 * flags needs to be 64-bit aligned due to the test_and_*
                 * functions which have this requirement on arm64.
                 */
#define PVCALLS_FLAG_ACCEPT_INFLIGHT 0
#define PVCALLS_FLAG_POLL_INFLIGHT   1
#define PVCALLS_FLAG_POLL_RET        2
                        uint8_t flags __attribute__((aligned(8)));
                        uint32_t inflight_req_id;
                        struct sock_mapping *accept_map;
                        wait_queue_head_t inflight_accept_req;
                } passive;
        };
};

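/*
 * Look up the sock_mapping for @sock and take a reference on both the
 * mapping and the frontend. Returns -ENOTCONN if the frontend is gone
 * and -ENOTSOCK if the socket has no mapping attached.
 */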
static inline struct sock_mapping *pvcalls_enter_sock(struct socket *sock)
{
        struct sock_mapping *map;

        if (!pvcalls_front_dev ||
                dev_get_drvdata(&pvcalls_front_dev->dev) == NULL)
                return ERR_PTR(-ENOTCONN);

        map = (struct sock_mapping *)sock->sk->sk_send_head;
        if (map == NULL)
                return ERR_PTR(-ENOTSOCK);

        pvcalls_enter();
        atomic_inc(&map->refcount);
        return map;
}

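/* Drop the references taken by pvcalls_enter_sock(). */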
static inline void pvcalls_exit_sock(struct socket *sock)
{
        struct sock_mapping *map;

        map = (struct sock_mapping *)sock->sk->sk_send_head;
        atomic_dec(&map->refcount);
        pvcalls_exit();
}

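/*
 * Reserve a slot on the command ring. Called with socket_lock held;
 * returns -EAGAIN if the ring is full or the slot is still occupied by
 * an earlier request awaiting its response.
 */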
static inline int get_request(struct pvcalls_bedata *bedata, int *req_id)
{
        *req_id = bedata->ring.req_prod_pvt & (RING_SIZE(&bedata->ring) - 1);
        if (RING_FULL(&bedata->ring) ||
            bedata->rsp[*req_id].req_id != PVCALLS_INVALID_ID)
                return -EAGAIN;
        return 0;
}

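/*
 * Return true if there is either room to write on the out ring or a
 * pending error to report (except -ENOTCONN, which means the remote
 * end is gone and no further writes are possible).
 */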
static bool pvcalls_front_write_todo(struct sock_mapping *map)
{
        struct pvcalls_data_intf *intf = map->active.ring;
        RING_IDX cons, prod, size = XEN_FLEX_RING_SIZE(PVCALLS_RING_ORDER);
        int32_t error;

        error = intf->out_error;
        if (error == -ENOTCONN)
                return false;
        if (error != 0)
                return true;

        cons = intf->out_cons;
        prod = intf->out_prod;
        return !!(size - pvcalls_queued(prod, cons, size));
}

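/* Return true if there is data to read on the in ring or an error is pending. */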
static bool pvcalls_front_read_todo(struct sock_mapping *map)
{
        struct pvcalls_data_intf *intf = map->active.ring;
        RING_IDX cons, prod;
        int32_t error;

        cons = intf->in_cons;
        prod = intf->in_prod;
        error = intf->in_error;
        return (error != 0 ||
                pvcalls_queued(prod, cons,
                               XEN_FLEX_RING_SIZE(PVCALLS_RING_ORDER)) != 0);
}

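/*
 * Interrupt handler for the main command ring: copy each response into
 * bedata->rsp[] (or, for PVCALLS_POLL, update the passive socket flags)
 * and wake up the waiters in the frontend entry points.
 */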
static irqreturn_t pvcalls_front_event_handler(int irq, void *dev_id)
{
        struct xenbus_device *dev = dev_id;
        struct pvcalls_bedata *bedata;
        struct xen_pvcalls_response *rsp;
        uint8_t *src, *dst;
        int req_id = 0, more = 0, done = 0;

        if (dev == NULL)
                return IRQ_HANDLED;

        pvcalls_enter();
        bedata = dev_get_drvdata(&dev->dev);
        if (bedata == NULL) {
                pvcalls_exit();
                return IRQ_HANDLED;
        }

again:
        while (RING_HAS_UNCONSUMED_RESPONSES(&bedata->ring)) {
                rsp = RING_GET_RESPONSE(&bedata->ring, bedata->ring.rsp_cons);

                req_id = rsp->req_id;
                if (rsp->cmd == PVCALLS_POLL) {
                        struct sock_mapping *map = (struct sock_mapping *)(uintptr_t)
                                                   rsp->u.poll.id;

                        clear_bit(PVCALLS_FLAG_POLL_INFLIGHT,
                                  (void *)&map->passive.flags);
                        /*
                         * clear INFLIGHT, then set RET. It pairs with
                         * the checks at the beginning of
                         * pvcalls_front_poll_passive.
                         */
                        smp_wmb();
                        set_bit(PVCALLS_FLAG_POLL_RET,
                                (void *)&map->passive.flags);
                } else {
                        dst = (uint8_t *)&bedata->rsp[req_id] +
                              sizeof(rsp->req_id);
                        src = (uint8_t *)rsp + sizeof(rsp->req_id);
                        memcpy(dst, src, sizeof(*rsp) - sizeof(rsp->req_id));
                        /*
                         * First copy the rest of the data, then req_id. It is
                         * paired with the barrier when accessing bedata->rsp.
                         */
                        smp_wmb();
                        bedata->rsp[req_id].req_id = req_id;
                }

                done = 1;
                bedata->ring.rsp_cons++;
        }

        RING_FINAL_CHECK_FOR_RESPONSES(&bedata->ring, more);
        if (more)
                goto again;
        if (done)
                wake_up(&bedata->inflight_req);
        pvcalls_exit();
        return IRQ_HANDLED;
}

static void free_active_ring(struct sock_mapping *map);

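/*
 * Tear down an active socket mapping: unbind its interrupt, drop it
 * from the socket list, revoke the grants for the data ring and the
 * indexes page, and free the memory.
 */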
static void pvcalls_front_free_map(struct pvcalls_bedata *bedata,
                                   struct sock_mapping *map)
{
        int i;

        unbind_from_irqhandler(map->active.irq, map);

        spin_lock(&bedata->socket_lock);
        if (!list_empty(&map->list))
                list_del_init(&map->list);
        spin_unlock(&bedata->socket_lock);

        for (i = 0; i < (1 << PVCALLS_RING_ORDER); i++)
                gnttab_end_foreign_access(map->active.ring->ref[i], NULL);
        gnttab_end_foreign_access(map->active.ref, NULL);
        free_active_ring(map);

        kfree(map);
}

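/* Per-connection interrupt handler: wake up sendmsg/recvmsg/poll waiters. */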
static irqreturn_t pvcalls_front_conn_handler(int irq, void *sock_map)
{
        struct sock_mapping *map = sock_map;

        if (map == NULL)
                return IRQ_HANDLED;

        wake_up_interruptible(&map->active.inflight_conn_req);

        return IRQ_HANDLED;
}

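/*
 * Create a new PVCalls socket: allocate a sock_mapping, attach it to
 * sock->sk->sk_send_head and issue a PVCALLS_SOCKET request to the
 * backend, waiting for its response.
 */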
int pvcalls_front_socket(struct socket *sock)
{
        struct pvcalls_bedata *bedata;
        struct sock_mapping *map = NULL;
        struct xen_pvcalls_request *req;
        int notify, req_id, ret;

        /*
         * PVCalls only supports domain AF_INET,
         * type SOCK_STREAM and protocol 0 sockets for now.
         *
         * Check socket type here, AF_INET and protocol checks are done
         * by the caller.
         */
        if (sock->type != SOCK_STREAM)
                return -EOPNOTSUPP;

        pvcalls_enter();
        if (!pvcalls_front_dev) {
                pvcalls_exit();
                return -EACCES;
        }
        bedata = dev_get_drvdata(&pvcalls_front_dev->dev);

        map = kzalloc(sizeof(*map), GFP_KERNEL);
        if (map == NULL) {
                pvcalls_exit();
                return -ENOMEM;
        }

        spin_lock(&bedata->socket_lock);

        ret = get_request(bedata, &req_id);
        if (ret < 0) {
                kfree(map);
                spin_unlock(&bedata->socket_lock);
                pvcalls_exit();
                return ret;
        }

        /*
         * sock->sk->sk_send_head is not used for ip sockets: reuse the
         * field to store a pointer to the struct sock_mapping
         * corresponding to the socket. This way, we can easily get the
         * struct sock_mapping from the struct socket.
         */
        sock->sk->sk_send_head = (void *)map;
        list_add_tail(&map->list, &bedata->socket_mappings);

        req = RING_GET_REQUEST(&bedata->ring, req_id);
        req->req_id = req_id;
        req->cmd = PVCALLS_SOCKET;
        req->u.socket.id = (uintptr_t) map;
        req->u.socket.domain = AF_INET;
        req->u.socket.type = SOCK_STREAM;
        req->u.socket.protocol = IPPROTO_IP;

        bedata->ring.req_prod_pvt++;
        RING_PUSH_REQUESTS_AND_CHECK_NOTIFY(&bedata->ring, notify);
        spin_unlock(&bedata->socket_lock);
        if (notify)
                notify_remote_via_irq(bedata->irq);

        wait_event(bedata->inflight_req,
                   READ_ONCE(bedata->rsp[req_id].req_id) == req_id);

        /* read req_id, then the content */
        smp_rmb();
        ret = bedata->rsp[req_id].ret;
        bedata->rsp[req_id].req_id = PVCALLS_INVALID_ID;

        pvcalls_exit();
        return ret;
}

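/* Free the data ring and the shared indexes page of an active socket. */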
static void free_active_ring(struct sock_mapping *map)
{
        if (!map->active.ring)
                return;

        free_pages_exact(map->active.data.in,
                         PAGE_SIZE << map->active.ring->ring_order);
        free_page((unsigned long)map->active.ring);
}

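/*
 * Allocate the shared indexes page and the data ring (in and out
 * buffers) for an active socket.
 */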
static int alloc_active_ring(struct sock_mapping *map)
{
        void *bytes;

        map->active.ring = (struct pvcalls_data_intf *)
                get_zeroed_page(GFP_KERNEL);
        if (!map->active.ring)
                goto out;

        map->active.ring->ring_order = PVCALLS_RING_ORDER;
        bytes = alloc_pages_exact(PAGE_SIZE << PVCALLS_RING_ORDER,
                                  GFP_KERNEL | __GFP_ZERO);
        if (!bytes)
                goto out;

        map->active.data.in = bytes;
        map->active.data.out = bytes +
                XEN_FLEX_RING_SIZE(PVCALLS_RING_ORDER);

        return 0;

out:
        free_active_ring(map);
        return -ENOMEM;
}

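/*
 * Grant the backend access to the data ring and indexes page of an
 * active socket and bind a per-connection event channel to
 * pvcalls_front_conn_handler().
 */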
static int create_active(struct sock_mapping *map, evtchn_port_t *evtchn)
{
        void *bytes;
        int ret, irq = -1, i;

        *evtchn = 0;
        init_waitqueue_head(&map->active.inflight_conn_req);

        bytes = map->active.data.in;
        for (i = 0; i < (1 << PVCALLS_RING_ORDER); i++)
                map->active.ring->ref[i] = gnttab_grant_foreign_access(
                        pvcalls_front_dev->otherend_id,
                        pfn_to_gfn(virt_to_pfn(bytes) + i), 0);

        map->active.ref = gnttab_grant_foreign_access(
                pvcalls_front_dev->otherend_id,
                pfn_to_gfn(virt_to_pfn((void *)map->active.ring)), 0);

        ret = xenbus_alloc_evtchn(pvcalls_front_dev, evtchn);
        if (ret)
                goto out_error;
        irq = bind_evtchn_to_irqhandler(*evtchn, pvcalls_front_conn_handler,
                                        0, "pvcalls-frontend", map);
        if (irq < 0) {
                ret = irq;
                goto out_error;
        }

        map->active.irq = irq;
        map->active_socket = true;
        mutex_init(&map->active.in_mutex);
        mutex_init(&map->active.out_mutex);

        return 0;

out_error:
        if (*evtchn > 0)
                xenbus_free_evtchn(pvcalls_front_dev, *evtchn);
        return ret;
}

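/*
 * Connect an active socket: allocate and grant the data ring, then
 * issue a PVCALLS_CONNECT request and wait for the backend's response.
 */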
int pvcalls_front_connect(struct socket *sock, struct sockaddr *addr,
                                int addr_len, int flags)
{
        struct pvcalls_bedata *bedata;
        struct sock_mapping *map = NULL;
        struct xen_pvcalls_request *req;
        int notify, req_id, ret;
        evtchn_port_t evtchn;

        if (addr->sa_family != AF_INET || sock->type != SOCK_STREAM)
                return -EOPNOTSUPP;

        map = pvcalls_enter_sock(sock);
        if (IS_ERR(map))
                return PTR_ERR(map);

        bedata = dev_get_drvdata(&pvcalls_front_dev->dev);
        ret = alloc_active_ring(map);
        if (ret < 0) {
                pvcalls_exit_sock(sock);
                return ret;
        }

        spin_lock(&bedata->socket_lock);
        ret = get_request(bedata, &req_id);
        if (ret < 0) {
                spin_unlock(&bedata->socket_lock);
                free_active_ring(map);
                pvcalls_exit_sock(sock);
                return ret;
        }
        ret = create_active(map, &evtchn);
        if (ret < 0) {
                spin_unlock(&bedata->socket_lock);
                free_active_ring(map);
                pvcalls_exit_sock(sock);
                return ret;
        }

        req = RING_GET_REQUEST(&bedata->ring, req_id);
        req->req_id = req_id;
        req->cmd = PVCALLS_CONNECT;
        req->u.connect.id = (uintptr_t)map;
        req->u.connect.len = addr_len;
        req->u.connect.flags = flags;
        req->u.connect.ref = map->active.ref;
        req->u.connect.evtchn = evtchn;
        memcpy(req->u.connect.addr, addr, sizeof(*addr));

        map->sock = sock;

        bedata->ring.req_prod_pvt++;
        RING_PUSH_REQUESTS_AND_CHECK_NOTIFY(&bedata->ring, notify);
        spin_unlock(&bedata->socket_lock);

        if (notify)
                notify_remote_via_irq(bedata->irq);

        wait_event(bedata->inflight_req,
                   READ_ONCE(bedata->rsp[req_id].req_id) == req_id);

        /* read req_id, then the content */
        smp_rmb();
        ret = bedata->rsp[req_id].ret;
        bedata->rsp[req_id].req_id = PVCALLS_INVALID_ID;
        pvcalls_exit_sock(sock);
        return ret;
}

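/*
 * Copy up to @len bytes from @msg_iter into the out ring, handling the
 * wrap-around at the end of the ring buffer. Returns the number of
 * bytes written, or a negative out_error reported by the backend.
 */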
static int __write_ring(struct pvcalls_data_intf *intf,
                        struct pvcalls_data *data,
                        struct iov_iter *msg_iter,
                        int len)
{
        RING_IDX cons, prod, size, masked_prod, masked_cons;
        RING_IDX array_size = XEN_FLEX_RING_SIZE(PVCALLS_RING_ORDER);
        int32_t error;

        error = intf->out_error;
        if (error < 0)
                return error;
        cons = intf->out_cons;
        prod = intf->out_prod;
        /* read indexes before continuing */
        virt_mb();

        size = pvcalls_queued(prod, cons, array_size);
        if (size > array_size)
                return -EINVAL;
        if (size == array_size)
                return 0;
        if (len > array_size - size)
                len = array_size - size;

        masked_prod = pvcalls_mask(prod, array_size);
        masked_cons = pvcalls_mask(cons, array_size);

        if (masked_prod < masked_cons) {
                len = copy_from_iter(data->out + masked_prod, len, msg_iter);
        } else {
                if (len > array_size - masked_prod) {
                        int ret = copy_from_iter(data->out + masked_prod,
                                       array_size - masked_prod, msg_iter);
                        if (ret != array_size - masked_prod) {
                                len = ret;
                                goto out;
                        }
                        len = ret + copy_from_iter(data->out, len - ret, msg_iter);
                } else {
                        len = copy_from_iter(data->out + masked_prod, len, msg_iter);
                }
        }
out:
        /* write to ring before updating pointer */
        virt_wmb();
        intf->out_prod += len;

        return len;
}

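/*
 * Write data to an active socket, retrying __write_ring() for up to
 * PVCALLS_FRONT_MAX_SPIN iterations while data remains to be sent, and
 * notify the backend after each successful write.
 */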
int pvcalls_front_sendmsg(struct socket *sock, struct msghdr *msg,
                          size_t len)
{
        struct sock_mapping *map;
        int sent, tot_sent = 0;
        int count = 0, flags;

        flags = msg->msg_flags;
        if (flags & (MSG_CONFIRM|MSG_DONTROUTE|MSG_EOR|MSG_OOB))
                return -EOPNOTSUPP;

        map = pvcalls_enter_sock(sock);
        if (IS_ERR(map))
                return PTR_ERR(map);

        mutex_lock(&map->active.out_mutex);
        if ((flags & MSG_DONTWAIT) && !pvcalls_front_write_todo(map)) {
                mutex_unlock(&map->active.out_mutex);
                pvcalls_exit_sock(sock);
                return -EAGAIN;
        }
        if (len > INT_MAX)
                len = INT_MAX;

again:
        count++;
        sent = __write_ring(map->active.ring,
                            &map->active.data, &msg->msg_iter,
                            len);
        if (sent > 0) {
                len -= sent;
                tot_sent += sent;
                notify_remote_via_irq(map->active.irq);
        }
        if (sent >= 0 && len > 0 && count < PVCALLS_FRONT_MAX_SPIN)
                goto again;
        if (sent < 0)
                tot_sent = sent;

        mutex_unlock(&map->active.out_mutex);
        pvcalls_exit_sock(sock);
        return tot_sent;
}

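/*
 * Copy up to @len bytes from the in ring into @msg_iter, handling the
 * wrap-around at the end of the ring buffer. The consumer index is not
 * advanced when MSG_PEEK is set.
 */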
static int __read_ring(struct pvcalls_data_intf *intf,
                       struct pvcalls_data *data,
                       struct iov_iter *msg_iter,
                       size_t len, int flags)
{
        RING_IDX cons, prod, size, masked_prod, masked_cons;
        RING_IDX array_size = XEN_FLEX_RING_SIZE(PVCALLS_RING_ORDER);
        int32_t error;

        cons = intf->in_cons;
        prod = intf->in_prod;
        error = intf->in_error;
        /* get pointers before reading from the ring */
        virt_rmb();

        size = pvcalls_queued(prod, cons, array_size);
        masked_prod = pvcalls_mask(prod, array_size);
        masked_cons = pvcalls_mask(cons, array_size);

        if (size == 0)
                return error ?: size;

        if (len > size)
                len = size;

        if (masked_prod > masked_cons) {
                len = copy_to_iter(data->in + masked_cons, len, msg_iter);
        } else {
                if (len > (array_size - masked_cons)) {
                        int ret = copy_to_iter(data->in + masked_cons,
                                     array_size - masked_cons, msg_iter);
                        if (ret != array_size - masked_cons) {
                                len = ret;
                                goto out;
                        }
                        len = ret + copy_to_iter(data->in, len - ret, msg_iter);
                } else {
                        len = copy_to_iter(data->in + masked_cons, len, msg_iter);
                }
        }
out:
        /* read data from the ring before increasing the index */
        virt_mb();
        if (!(flags & MSG_PEEK))
                intf->in_cons += len;

        return len;
}

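/*
 * Read data from an active socket, blocking until data is available
 * unless MSG_DONTWAIT is set, and notify the backend once data has
 * been consumed from the in ring.
 */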
int pvcalls_front_recvmsg(struct socket *sock, struct msghdr *msg, size_t len,
                     int flags)
{
        int ret;
        struct sock_mapping *map;

        if (flags & (MSG_CMSG_CLOEXEC|MSG_ERRQUEUE|MSG_OOB|MSG_TRUNC))
                return -EOPNOTSUPP;

        map = pvcalls_enter_sock(sock);
        if (IS_ERR(map))
                return PTR_ERR(map);

        mutex_lock(&map->active.in_mutex);
        if (len > XEN_FLEX_RING_SIZE(PVCALLS_RING_ORDER))
                len = XEN_FLEX_RING_SIZE(PVCALLS_RING_ORDER);

        while (!(flags & MSG_DONTWAIT) && !pvcalls_front_read_todo(map)) {
                wait_event_interruptible(map->active.inflight_conn_req,
                                         pvcalls_front_read_todo(map));
        }
        ret = __read_ring(map->active.ring, &map->active.data,
                          &msg->msg_iter, len, flags);

        if (ret > 0)
                notify_remote_via_irq(map->active.irq);
        if (ret == 0)
                ret = (flags & MSG_DONTWAIT) ? -EAGAIN : 0;
        if (ret == -ENOTCONN)
                ret = 0;

        mutex_unlock(&map->active.in_mutex);
        pvcalls_exit_sock(sock);
        return ret;
}

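/*
 * Bind a passive socket: issue a PVCALLS_BIND request with the given
 * address, wait for the response and move the socket to
 * PVCALLS_STATUS_BIND.
 */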
int pvcalls_front_bind(struct socket *sock, struct sockaddr *addr, int addr_len)
{
        struct pvcalls_bedata *bedata;
        struct sock_mapping *map = NULL;
        struct xen_pvcalls_request *req;
        int notify, req_id, ret;

        if (addr->sa_family != AF_INET || sock->type != SOCK_STREAM)
                return -EOPNOTSUPP;

        map = pvcalls_enter_sock(sock);
        if (IS_ERR(map))
                return PTR_ERR(map);
        bedata = dev_get_drvdata(&pvcalls_front_dev->dev);

        spin_lock(&bedata->socket_lock);
        ret = get_request(bedata, &req_id);
        if (ret < 0) {
                spin_unlock(&bedata->socket_lock);
                pvcalls_exit_sock(sock);
                return ret;
        }
        req = RING_GET_REQUEST(&bedata->ring, req_id);
        req->req_id = req_id;
        map->sock = sock;
        req->cmd = PVCALLS_BIND;
        req->u.bind.id = (uintptr_t)map;
        memcpy(req->u.bind.addr, addr, sizeof(*addr));
        req->u.bind.len = addr_len;

        init_waitqueue_head(&map->passive.inflight_accept_req);

        map->active_socket = false;

        bedata->ring.req_prod_pvt++;
        RING_PUSH_REQUESTS_AND_CHECK_NOTIFY(&bedata->ring, notify);
        spin_unlock(&bedata->socket_lock);
        if (notify)
                notify_remote_via_irq(bedata->irq);

        wait_event(bedata->inflight_req,
                   READ_ONCE(bedata->rsp[req_id].req_id) == req_id);

        /* read req_id, then the content */
        smp_rmb();
        ret = bedata->rsp[req_id].ret;
        bedata->rsp[req_id].req_id = PVCALLS_INVALID_ID;

        map->passive.status = PVCALLS_STATUS_BIND;
        pvcalls_exit_sock(sock);
        return 0;
}

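/*
 * Move a bound passive socket to the listening state by issuing a
 * PVCALLS_LISTEN request to the backend.
 */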
int pvcalls_front_listen(struct socket *sock, int backlog)
{
        struct pvcalls_bedata *bedata;
        struct sock_mapping *map;
        struct xen_pvcalls_request *req;
        int notify, req_id, ret;

        map = pvcalls_enter_sock(sock);
        if (IS_ERR(map))
                return PTR_ERR(map);
        bedata = dev_get_drvdata(&pvcalls_front_dev->dev);

        if (map->passive.status != PVCALLS_STATUS_BIND) {
                pvcalls_exit_sock(sock);
                return -EOPNOTSUPP;
        }

        spin_lock(&bedata->socket_lock);
        ret = get_request(bedata, &req_id);
        if (ret < 0) {
                spin_unlock(&bedata->socket_lock);
                pvcalls_exit_sock(sock);
                return ret;
        }
        req = RING_GET_REQUEST(&bedata->ring, req_id);
        req->req_id = req_id;
        req->cmd = PVCALLS_LISTEN;
        req->u.listen.id = (uintptr_t) map;
        req->u.listen.backlog = backlog;

        bedata->ring.req_prod_pvt++;
        RING_PUSH_REQUESTS_AND_CHECK_NOTIFY(&bedata->ring, notify);
        spin_unlock(&bedata->socket_lock);
        if (notify)
                notify_remote_via_irq(bedata->irq);

        wait_event(bedata->inflight_req,
                   READ_ONCE(bedata->rsp[req_id].req_id) == req_id);

        /* read req_id, then the content */
        smp_rmb();
        ret = bedata->rsp[req_id].ret;
        bedata->rsp[req_id].req_id = PVCALLS_INVALID_ID;

        map->passive.status = PVCALLS_STATUS_LISTEN;
        pvcalls_exit_sock(sock);
        return ret;
}

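/*
 * Accept a connection on a listening socket. Only one accept request
 * can be inflight per socket: concurrent callers either pick up the
 * response of the inflight request or wait for their turn. The new
 * connection gets its own sock_mapping, data ring and event channel.
 */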
int pvcalls_front_accept(struct socket *sock, struct socket *newsock, int flags)
{
        struct pvcalls_bedata *bedata;
        struct sock_mapping *map;
        struct sock_mapping *map2 = NULL;
        struct xen_pvcalls_request *req;
        int notify, req_id, ret, nonblock;
        evtchn_port_t evtchn;

        map = pvcalls_enter_sock(sock);
        if (IS_ERR(map))
                return PTR_ERR(map);
        bedata = dev_get_drvdata(&pvcalls_front_dev->dev);

        if (map->passive.status != PVCALLS_STATUS_LISTEN) {
                pvcalls_exit_sock(sock);
                return -EINVAL;
        }

        nonblock = flags & SOCK_NONBLOCK;
        /*
         * Backend only supports 1 inflight accept request, will return
         * errors for the others
         */
        if (test_and_set_bit(PVCALLS_FLAG_ACCEPT_INFLIGHT,
                             (void *)&map->passive.flags)) {
                req_id = READ_ONCE(map->passive.inflight_req_id);
                if (req_id != PVCALLS_INVALID_ID &&
                    READ_ONCE(bedata->rsp[req_id].req_id) == req_id) {
                        map2 = map->passive.accept_map;
                        goto received;
                }
                if (nonblock) {
                        pvcalls_exit_sock(sock);
                        return -EAGAIN;
                }
                if (wait_event_interruptible(map->passive.inflight_accept_req,
                        !test_and_set_bit(PVCALLS_FLAG_ACCEPT_INFLIGHT,
                                          (void *)&map->passive.flags))) {
                        pvcalls_exit_sock(sock);
                        return -EINTR;
                }
        }

        map2 = kzalloc(sizeof(*map2), GFP_KERNEL);
        if (map2 == NULL) {
                clear_bit(PVCALLS_FLAG_ACCEPT_INFLIGHT,
                          (void *)&map->passive.flags);
                pvcalls_exit_sock(sock);
                return -ENOMEM;
        }
        ret = alloc_active_ring(map2);
        if (ret < 0) {
                clear_bit(PVCALLS_FLAG_ACCEPT_INFLIGHT,
                                (void *)&map->passive.flags);
                kfree(map2);
                pvcalls_exit_sock(sock);
                return ret;
        }
        spin_lock(&bedata->socket_lock);
        ret = get_request(bedata, &req_id);
        if (ret < 0) {
                clear_bit(PVCALLS_FLAG_ACCEPT_INFLIGHT,
                          (void *)&map->passive.flags);
                spin_unlock(&bedata->socket_lock);
                free_active_ring(map2);
                kfree(map2);
                pvcalls_exit_sock(sock);
                return ret;
        }

        ret = create_active(map2, &evtchn);
        if (ret < 0) {
                free_active_ring(map2);
                kfree(map2);
                clear_bit(PVCALLS_FLAG_ACCEPT_INFLIGHT,
                          (void *)&map->passive.flags);
                spin_unlock(&bedata->socket_lock);
                pvcalls_exit_sock(sock);
                return ret;
        }
        list_add_tail(&map2->list, &bedata->socket_mappings);

        req = RING_GET_REQUEST(&bedata->ring, req_id);
        req->req_id = req_id;
        req->cmd = PVCALLS_ACCEPT;
        req->u.accept.id = (uintptr_t) map;
        req->u.accept.ref = map2->active.ref;
        req->u.accept.id_new = (uintptr_t) map2;
        req->u.accept.evtchn = evtchn;
        map->passive.accept_map = map2;

        bedata->ring.req_prod_pvt++;
        RING_PUSH_REQUESTS_AND_CHECK_NOTIFY(&bedata->ring, notify);
        spin_unlock(&bedata->socket_lock);
        if (notify)
                notify_remote_via_irq(bedata->irq);
        /* We could check if we have received a response before returning. */
        if (nonblock) {
                WRITE_ONCE(map->passive.inflight_req_id, req_id);
                pvcalls_exit_sock(sock);
                return -EAGAIN;
        }

        if (wait_event_interruptible(bedata->inflight_req,
                READ_ONCE(bedata->rsp[req_id].req_id) == req_id)) {
                pvcalls_exit_sock(sock);
                return -EINTR;
        }
        /* read req_id, then the content */
        smp_rmb();

received:
        map2->sock = newsock;
        newsock->sk = sk_alloc(sock_net(sock->sk), PF_INET, GFP_KERNEL, &pvcalls_proto, false);
        if (!newsock->sk) {
                bedata->rsp[req_id].req_id = PVCALLS_INVALID_ID;
                map->passive.inflight_req_id = PVCALLS_INVALID_ID;
                clear_bit(PVCALLS_FLAG_ACCEPT_INFLIGHT,
                          (void *)&map->passive.flags);
                pvcalls_front_free_map(bedata, map2);
                pvcalls_exit_sock(sock);
                return -ENOMEM;
        }
        newsock->sk->sk_send_head = (void *)map2;

        ret = bedata->rsp[req_id].ret;
        bedata->rsp[req_id].req_id = PVCALLS_INVALID_ID;
        map->passive.inflight_req_id = PVCALLS_INVALID_ID;

        clear_bit(PVCALLS_FLAG_ACCEPT_INFLIGHT, (void *)&map->passive.flags);
        wake_up(&map->passive.inflight_accept_req);

        pvcalls_exit_sock(sock);
        return ret;
}

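/*
 * Poll a passive socket: report EPOLLIN when an inflight accept
 * request has completed or a previous PVCALLS_POLL response arrived,
 * otherwise issue a new PVCALLS_POLL request (at most one inflight
 * per socket) and register on the wait queues.
 */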
static __poll_t pvcalls_front_poll_passive(struct file *file,
                                               struct pvcalls_bedata *bedata,
                                               struct sock_mapping *map,
                                               poll_table *wait)
{
        int notify, req_id, ret;
        struct xen_pvcalls_request *req;

        if (test_bit(PVCALLS_FLAG_ACCEPT_INFLIGHT,
                     (void *)&map->passive.flags)) {
                uint32_t req_id = READ_ONCE(map->passive.inflight_req_id);

                if (req_id != PVCALLS_INVALID_ID &&
                    READ_ONCE(bedata->rsp[req_id].req_id) == req_id)
                        return EPOLLIN | EPOLLRDNORM;

                poll_wait(file, &map->passive.inflight_accept_req, wait);
                return 0;
        }

        if (test_and_clear_bit(PVCALLS_FLAG_POLL_RET,
                               (void *)&map->passive.flags))
                return EPOLLIN | EPOLLRDNORM;

        /*
         * First check RET, then INFLIGHT. No barriers necessary to
         * ensure execution ordering because of the conditional
         * instructions creating control dependencies.
         */

        if (test_and_set_bit(PVCALLS_FLAG_POLL_INFLIGHT,
                             (void *)&map->passive.flags)) {
                poll_wait(file, &bedata->inflight_req, wait);
                return 0;
        }

        spin_lock(&bedata->socket_lock);
        ret = get_request(bedata, &req_id);
        if (ret < 0) {
                spin_unlock(&bedata->socket_lock);
                return ret;
        }
        req = RING_GET_REQUEST(&bedata->ring, req_id);
        req->req_id = req_id;
        req->cmd = PVCALLS_POLL;
        req->u.poll.id = (uintptr_t) map;

        bedata->ring.req_prod_pvt++;
        RING_PUSH_REQUESTS_AND_CHECK_NOTIFY(&bedata->ring, notify);
        spin_unlock(&bedata->socket_lock);
        if (notify)
                notify_remote_via_irq(bedata->irq);

        poll_wait(file, &bedata->inflight_req, wait);
        return 0;
}

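/*
 * Poll an active socket by inspecting the in/out rings and the error
 * fields shared with the backend.
 */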
static __poll_t pvcalls_front_poll_active(struct file *file,
                                              struct pvcalls_bedata *bedata,
                                              struct sock_mapping *map,
                                              poll_table *wait)
{
        __poll_t mask = 0;
        int32_t in_error, out_error;
        struct pvcalls_data_intf *intf = map->active.ring;

        out_error = intf->out_error;
        in_error = intf->in_error;

        poll_wait(file, &map->active.inflight_conn_req, wait);
        if (pvcalls_front_write_todo(map))
                mask |= EPOLLOUT | EPOLLWRNORM;
        if (pvcalls_front_read_todo(map))
                mask |= EPOLLIN | EPOLLRDNORM;
        if (in_error != 0 || out_error != 0)
                mask |= EPOLLERR;

        return mask;
}

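/* Dispatch to the active or passive poll implementation. */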
__poll_t pvcalls_front_poll(struct file *file, struct socket *sock,
                               poll_table *wait)
{
        struct pvcalls_bedata *bedata;
        struct sock_mapping *map;
        __poll_t ret;

        map = pvcalls_enter_sock(sock);
        if (IS_ERR(map))
                return EPOLLNVAL;
        bedata = dev_get_drvdata(&pvcalls_front_dev->dev);

        if (map->active_socket)
                ret = pvcalls_front_poll_active(file, bedata, map, wait);
        else
                ret = pvcalls_front_poll_passive(file, bedata, map, wait);
        pvcalls_exit_sock(sock);
        return ret;
}

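/*
 * Release a socket: detach it from sock->sk, send a PVCALLS_RELEASE
 * request, then wait for all other callers to leave before freeing
 * the mapping (and, for passive sockets, any pending accept_map).
 */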
int pvcalls_front_release(struct socket *sock)
{
        struct pvcalls_bedata *bedata;
        struct sock_mapping *map;
        int req_id, notify, ret;
        struct xen_pvcalls_request *req;

        if (sock->sk == NULL)
                return 0;

        map = pvcalls_enter_sock(sock);
        if (IS_ERR(map)) {
                if (PTR_ERR(map) == -ENOTCONN)
                        return -EIO;
                else
                        return 0;
        }
        bedata = dev_get_drvdata(&pvcalls_front_dev->dev);

        spin_lock(&bedata->socket_lock);
        ret = get_request(bedata, &req_id);
        if (ret < 0) {
                spin_unlock(&bedata->socket_lock);
                pvcalls_exit_sock(sock);
                return ret;
        }
        sock->sk->sk_send_head = NULL;

        req = RING_GET_REQUEST(&bedata->ring, req_id);
        req->req_id = req_id;
        req->cmd = PVCALLS_RELEASE;
        req->u.release.id = (uintptr_t)map;

        bedata->ring.req_prod_pvt++;
        RING_PUSH_REQUESTS_AND_CHECK_NOTIFY(&bedata->ring, notify);
        spin_unlock(&bedata->socket_lock);
        if (notify)
                notify_remote_via_irq(bedata->irq);

        wait_event(bedata->inflight_req,
                   READ_ONCE(bedata->rsp[req_id].req_id) == req_id);

        if (map->active_socket) {
                /*
                 * Set in_error and wake up inflight_conn_req to force
                 * recvmsg waiters to exit.
                 */
                map->active.ring->in_error = -EBADF;
                wake_up_interruptible(&map->active.inflight_conn_req);

                /*
                 * We need to make sure that sendmsg/recvmsg on this socket have
                 * not started before we've cleared sk_send_head here. The
                 * easiest way to guarantee this is to see that no pvcalls
                 * (other than us) is in progress on this socket.
                 */
                while (atomic_read(&map->refcount) > 1)
                        cpu_relax();

                pvcalls_front_free_map(bedata, map);
        } else {
                wake_up(&bedata->inflight_req);
                wake_up(&map->passive.inflight_accept_req);

                while (atomic_read(&map->refcount) > 1)
                        cpu_relax();

                spin_lock(&bedata->socket_lock);
                list_del(&map->list);
                spin_unlock(&bedata->socket_lock);
                if (READ_ONCE(map->passive.inflight_req_id) != PVCALLS_INVALID_ID &&
                        READ_ONCE(map->passive.inflight_req_id) != 0) {
                        pvcalls_front_free_map(bedata,
                                               map->passive.accept_map);
                }
                kfree(map);
        }
        WRITE_ONCE(bedata->rsp[req_id].req_id, PVCALLS_INVALID_ID);

        pvcalls_exit();
        return 0;
}

static const struct xenbus_device_id pvcalls_front_ids[] = {
        { "pvcalls" },
        { "" }
};

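/*
 * Tear down the frontend: unbind the interrupt, wake up and wait for
 * all in-progress callers, then free every socket mapping and the
 * command ring.
 */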
static int pvcalls_front_remove(struct xenbus_device *dev)
{
        struct pvcalls_bedata *bedata;
        struct sock_mapping *map = NULL, *n;

        bedata = dev_get_drvdata(&pvcalls_front_dev->dev);
        dev_set_drvdata(&dev->dev, NULL);
        pvcalls_front_dev = NULL;
        if (bedata->irq >= 0)
                unbind_from_irqhandler(bedata->irq, dev);

        list_for_each_entry_safe(map, n, &bedata->socket_mappings, list) {
                map->sock->sk->sk_send_head = NULL;
                if (map->active_socket) {
                        map->active.ring->in_error = -EBADF;
                        wake_up_interruptible(&map->active.inflight_conn_req);
                }
        }

        smp_mb();
        while (atomic_read(&pvcalls_refcount) > 0)
                cpu_relax();
        list_for_each_entry_safe(map, n, &bedata->socket_mappings, list) {
                if (map->active_socket) {
                        /* No need to lock, refcount is 0 */
                        pvcalls_front_free_map(bedata, map);
                } else {
                        list_del(&map->list);
                        kfree(map);
                }
        }
        if (bedata->ref != -1)
                gnttab_end_foreign_access(bedata->ref, NULL);
        kfree(bedata->ring.sring);
        kfree(bedata);
        xenbus_switch_state(dev, XenbusStateClosed);
        return 0;
}

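/*
 * Probe the frontend: check the backend's advertised version and
 * limits, allocate and grant the command ring, bind its event channel
 * and publish ring-ref and port via a xenstore transaction.
 */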
static int pvcalls_front_probe(struct xenbus_device *dev,
                          const struct xenbus_device_id *id)
{
        int ret = -ENOMEM, i;
        evtchn_port_t evtchn;
        unsigned int max_page_order, function_calls, len;
        char *versions;
        grant_ref_t gref_head = 0;
        struct xenbus_transaction xbt;
        struct pvcalls_bedata *bedata = NULL;
        struct xen_pvcalls_sring *sring;

        if (pvcalls_front_dev != NULL) {
                dev_err(&dev->dev, "only one PV Calls connection supported\n");
                return -EINVAL;
        }

        versions = xenbus_read(XBT_NIL, dev->otherend, "versions", &len);
        if (IS_ERR(versions))
                return PTR_ERR(versions);
        if (!len)
                return -EINVAL;
        if (strcmp(versions, "1")) {
                kfree(versions);
                return -EINVAL;
        }
        kfree(versions);
        max_page_order = xenbus_read_unsigned(dev->otherend,
                                              "max-page-order", 0);
        if (max_page_order < PVCALLS_RING_ORDER)
                return -ENODEV;
        function_calls = xenbus_read_unsigned(dev->otherend,
                                              "function-calls", 0);
        /* See XENBUS_FUNCTIONS_CALLS in pvcalls.h */
        if (function_calls != 1)
                return -ENODEV;
        pr_info("%s max-page-order is %u\n", __func__, max_page_order);

        bedata = kzalloc(sizeof(struct pvcalls_bedata), GFP_KERNEL);
        if (!bedata)
                return -ENOMEM;

        dev_set_drvdata(&dev->dev, bedata);
        pvcalls_front_dev = dev;
        init_waitqueue_head(&bedata->inflight_req);
        INIT_LIST_HEAD(&bedata->socket_mappings);
        spin_lock_init(&bedata->socket_lock);
        bedata->irq = -1;
        bedata->ref = -1;

        for (i = 0; i < PVCALLS_NR_RSP_PER_RING; i++)
                bedata->rsp[i].req_id = PVCALLS_INVALID_ID;

        sring = (struct xen_pvcalls_sring *) __get_free_page(GFP_KERNEL |
                                                             __GFP_ZERO);
        if (!sring)
                goto error;
        SHARED_RING_INIT(sring);
        FRONT_RING_INIT(&bedata->ring, sring, XEN_PAGE_SIZE);

        ret = xenbus_alloc_evtchn(dev, &evtchn);
        if (ret)
                goto error;

        bedata->irq = bind_evtchn_to_irqhandler(evtchn,
                                                pvcalls_front_event_handler,
                                                0, "pvcalls-frontend", dev);
        if (bedata->irq < 0) {
                ret = bedata->irq;
                goto error;
        }

        ret = gnttab_alloc_grant_references(1, &gref_head);
        if (ret < 0)
                goto error;
        ret = gnttab_claim_grant_reference(&gref_head);
        if (ret < 0)
                goto error;
        bedata->ref = ret;
        gnttab_grant_foreign_access_ref(bedata->ref, dev->otherend_id,
                                        virt_to_gfn((void *)sring), 0);

 again:
        ret = xenbus_transaction_start(&xbt);
        if (ret) {
                xenbus_dev_fatal(dev, ret, "starting transaction");
                goto error;
        }
        ret = xenbus_printf(xbt, dev->nodename, "version", "%u", 1);
        if (ret)
                goto error_xenbus;
        ret = xenbus_printf(xbt, dev->nodename, "ring-ref", "%d", bedata->ref);
        if (ret)
                goto error_xenbus;
        ret = xenbus_printf(xbt, dev->nodename, "port", "%u",
                            evtchn);
        if (ret)
                goto error_xenbus;
        ret = xenbus_transaction_end(xbt, 0);
        if (ret) {
                if (ret == -EAGAIN)
                        goto again;
                xenbus_dev_fatal(dev, ret, "completing transaction");
                goto error;
        }
        xenbus_switch_state(dev, XenbusStateInitialised);

        return 0;

 error_xenbus:
        xenbus_transaction_end(xbt, 1);
        xenbus_dev_fatal(dev, ret, "writing xenstore");
 error:
        pvcalls_front_remove(dev);
        return ret;
}

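/* React to backend state changes reported via xenstore. */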
static void pvcalls_front_changed(struct xenbus_device *dev,
                            enum xenbus_state backend_state)
{
        switch (backend_state) {
        case XenbusStateReconfiguring:
        case XenbusStateReconfigured:
        case XenbusStateInitialising:
        case XenbusStateInitialised:
        case XenbusStateUnknown:
                break;

        case XenbusStateInitWait:
                break;

        case XenbusStateConnected:
                xenbus_switch_state(dev, XenbusStateConnected);
                break;

        case XenbusStateClosed:
                if (dev->state == XenbusStateClosed)
                        break;
                /* Missed the backend's CLOSING state */
                fallthrough;
        case XenbusStateClosing:
                xenbus_frontend_closed(dev);
                break;
        }
}

static struct xenbus_driver pvcalls_front_driver = {
        .ids = pvcalls_front_ids,
        .probe = pvcalls_front_probe,
        .remove = pvcalls_front_remove,
        .otherend_changed = pvcalls_front_changed,
        .not_essential = true,
};

static int __init pvcalls_frontend_init(void)
{
        if (!xen_domain())
                return -ENODEV;

        pr_info("Initialising Xen pvcalls frontend driver\n");

        return xenbus_register_frontend(&pvcalls_front_driver);
}

module_init(pvcalls_frontend_init);

MODULE_DESCRIPTION("Xen PV Calls frontend driver");
MODULE_AUTHOR("Stefano Stabellini <sstabellini@kernel.org>");
MODULE_LICENSE("GPL");