GNU Linux-libre 6.5.10-gnu
[releases.git] / net / vmw_vsock / vsock_bpf.c
1 // SPDX-License-Identifier: GPL-2.0
2 /* Copyright (c) 2022 Bobby Eshleman <bobby.eshleman@bytedance.com>
3  *
4  * Based off of net/unix/unix_bpf.c
5  */
6
7 #include <linux/bpf.h>
8 #include <linux/module.h>
9 #include <linux/skmsg.h>
10 #include <linux/socket.h>
11 #include <linux/wait.h>
12 #include <net/af_vsock.h>
13 #include <net/sock.h>
14
15 #define vsock_sk_has_data(__sk, __psock)                                \
16                 ({      !skb_queue_empty(&(__sk)->sk_receive_queue) ||  \
17                         !skb_queue_empty(&(__psock)->ingress_skb) ||    \
18                         !list_empty(&(__psock)->ingress_msg);           \
19                 })
20
21 static struct proto *vsock_prot_saved __read_mostly;
22 static DEFINE_SPINLOCK(vsock_prot_lock);
23 static struct proto vsock_bpf_prot;
24
25 static bool vsock_has_data(struct sock *sk, struct sk_psock *psock)
26 {
27         struct vsock_sock *vsk = vsock_sk(sk);
28         s64 ret;
29
30         ret = vsock_connectible_has_data(vsk);
31         if (ret > 0)
32                 return true;
33
34         return vsock_sk_has_data(sk, psock);
35 }
36
37 static bool vsock_msg_wait_data(struct sock *sk, struct sk_psock *psock, long timeo)
38 {
39         bool ret;
40
41         DEFINE_WAIT_FUNC(wait, woken_wake_function);
42
43         if (sk->sk_shutdown & RCV_SHUTDOWN)
44                 return true;
45
46         if (!timeo)
47                 return false;
48
49         add_wait_queue(sk_sleep(sk), &wait);
50         sk_set_bit(SOCKWQ_ASYNC_WAITDATA, sk);
51         ret = vsock_has_data(sk, psock);
52         if (!ret) {
53                 wait_woken(&wait, TASK_INTERRUPTIBLE, timeo);
54                 ret = vsock_has_data(sk, psock);
55         }
56         sk_clear_bit(SOCKWQ_ASYNC_WAITDATA, sk);
57         remove_wait_queue(sk_sleep(sk), &wait);
58         return ret;
59 }
60
61 static int __vsock_recvmsg(struct sock *sk, struct msghdr *msg, size_t len, int flags)
62 {
63         struct socket *sock = sk->sk_socket;
64         int err;
65
66         if (sk->sk_type == SOCK_STREAM || sk->sk_type == SOCK_SEQPACKET)
67                 err = vsock_connectible_recvmsg(sock, msg, len, flags);
68         else if (sk->sk_type == SOCK_DGRAM)
69                 err = vsock_dgram_recvmsg(sock, msg, len, flags);
70         else
71                 err = -EPROTOTYPE;
72
73         return err;
74 }
75
76 static int vsock_bpf_recvmsg(struct sock *sk, struct msghdr *msg,
77                              size_t len, int flags, int *addr_len)
78 {
79         struct sk_psock *psock;
80         int copied;
81
82         psock = sk_psock_get(sk);
83         if (unlikely(!psock))
84                 return __vsock_recvmsg(sk, msg, len, flags);
85
86         lock_sock(sk);
87         if (vsock_has_data(sk, psock) && sk_psock_queue_empty(psock)) {
88                 release_sock(sk);
89                 sk_psock_put(sk, psock);
90                 return __vsock_recvmsg(sk, msg, len, flags);
91         }
92
93         copied = sk_msg_recvmsg(sk, psock, msg, len, flags);
94         while (copied == 0) {
95                 long timeo = sock_rcvtimeo(sk, flags & MSG_DONTWAIT);
96
97                 if (!vsock_msg_wait_data(sk, psock, timeo)) {
98                         copied = -EAGAIN;
99                         break;
100                 }
101
102                 if (sk_psock_queue_empty(psock)) {
103                         release_sock(sk);
104                         sk_psock_put(sk, psock);
105                         return __vsock_recvmsg(sk, msg, len, flags);
106                 }
107
108                 copied = sk_msg_recvmsg(sk, psock, msg, len, flags);
109         }
110
111         release_sock(sk);
112         sk_psock_put(sk, psock);
113
114         return copied;
115 }
116
117 /* Copy of original proto with updated sock_map methods */
118 static struct proto vsock_bpf_prot = {
119         .close = sock_map_close,
120         .recvmsg = vsock_bpf_recvmsg,
121         .sock_is_readable = sk_msg_is_readable,
122         .unhash = sock_map_unhash,
123 };
124
125 static void vsock_bpf_rebuild_protos(struct proto *prot, const struct proto *base)
126 {
127         *prot        = *base;
128         prot->close  = sock_map_close;
129         prot->recvmsg = vsock_bpf_recvmsg;
130         prot->sock_is_readable = sk_msg_is_readable;
131 }
132
133 static void vsock_bpf_check_needs_rebuild(struct proto *ops)
134 {
135         /* Paired with the smp_store_release() below. */
136         if (unlikely(ops != smp_load_acquire(&vsock_prot_saved))) {
137                 spin_lock_bh(&vsock_prot_lock);
138                 if (likely(ops != vsock_prot_saved)) {
139                         vsock_bpf_rebuild_protos(&vsock_bpf_prot, ops);
140                         /* Make sure proto function pointers are updated before publishing the
141                          * pointer to the struct.
142                          */
143                         smp_store_release(&vsock_prot_saved, ops);
144                 }
145                 spin_unlock_bh(&vsock_prot_lock);
146         }
147 }
148
149 int vsock_bpf_update_proto(struct sock *sk, struct sk_psock *psock, bool restore)
150 {
151         struct vsock_sock *vsk;
152
153         if (restore) {
154                 sk->sk_write_space = psock->saved_write_space;
155                 sock_replace_proto(sk, psock->sk_proto);
156                 return 0;
157         }
158
159         vsk = vsock_sk(sk);
160         if (!vsk->transport)
161                 return -ENODEV;
162
163         if (!vsk->transport->read_skb)
164                 return -EOPNOTSUPP;
165
166         vsock_bpf_check_needs_rebuild(psock->sk_proto);
167         sock_replace_proto(sk, &vsock_bpf_prot);
168         return 0;
169 }
170
171 void __init vsock_bpf_build_proto(void)
172 {
173         vsock_bpf_rebuild_protos(&vsock_bpf_prot, &vsock_proto);
174 }