GNU Linux-libre 5.10.219-gnu1
[releases.git] / net / core / sock_reuseport.c
1 // SPDX-License-Identifier: GPL-2.0
2 /*
3  * To speed up listener socket lookup, create an array to store all sockets
4  * listening on the same port.  This allows a decision to be made after finding
5  * the first socket.  An optional BPF program can also be configured for
6  * selecting the socket index from the array of available sockets.
7  */
8
9 #include <net/sock_reuseport.h>
10 #include <linux/bpf.h>
11 #include <linux/idr.h>
12 #include <linux/filter.h>
13 #include <linux/rcupdate.h>
14
15 #define INIT_SOCKS 128
16
17 DEFINE_SPINLOCK(reuseport_lock);
18
19 static DEFINE_IDA(reuseport_ida);
20
21 void reuseport_has_conns_set(struct sock *sk)
22 {
23         struct sock_reuseport *reuse;
24
25         if (!rcu_access_pointer(sk->sk_reuseport_cb))
26                 return;
27
28         spin_lock_bh(&reuseport_lock);
29         reuse = rcu_dereference_protected(sk->sk_reuseport_cb,
30                                           lockdep_is_held(&reuseport_lock));
31         if (likely(reuse))
32                 reuse->has_conns = 1;
33         spin_unlock_bh(&reuseport_lock);
34 }
35 EXPORT_SYMBOL(reuseport_has_conns_set);
36
37 static struct sock_reuseport *__reuseport_alloc(unsigned int max_socks)
38 {
39         unsigned int size = sizeof(struct sock_reuseport) +
40                       sizeof(struct sock *) * max_socks;
41         struct sock_reuseport *reuse = kzalloc(size, GFP_ATOMIC);
42
43         if (!reuse)
44                 return NULL;
45
46         reuse->max_socks = max_socks;
47
48         RCU_INIT_POINTER(reuse->prog, NULL);
49         return reuse;
50 }
51
52 int reuseport_alloc(struct sock *sk, bool bind_inany)
53 {
54         struct sock_reuseport *reuse;
55         int id, ret = 0;
56
57         /* bh lock used since this function call may precede hlist lock in
58          * soft irq of receive path or setsockopt from process context
59          */
60         spin_lock_bh(&reuseport_lock);
61
62         /* Allocation attempts can occur concurrently via the setsockopt path
63          * and the bind/hash path.  Nothing to do when we lose the race.
64          */
65         reuse = rcu_dereference_protected(sk->sk_reuseport_cb,
66                                           lockdep_is_held(&reuseport_lock));
67         if (reuse) {
68                 /* Only set reuse->bind_inany if the bind_inany is true.
69                  * Otherwise, it will overwrite the reuse->bind_inany
70                  * which was set by the bind/hash path.
71                  */
72                 if (bind_inany)
73                         reuse->bind_inany = bind_inany;
74                 goto out;
75         }
76
77         reuse = __reuseport_alloc(INIT_SOCKS);
78         if (!reuse) {
79                 ret = -ENOMEM;
80                 goto out;
81         }
82
83         id = ida_alloc(&reuseport_ida, GFP_ATOMIC);
84         if (id < 0) {
85                 kfree(reuse);
86                 ret = id;
87                 goto out;
88         }
89
90         reuse->reuseport_id = id;
91         reuse->socks[0] = sk;
92         reuse->num_socks = 1;
93         reuse->bind_inany = bind_inany;
94         rcu_assign_pointer(sk->sk_reuseport_cb, reuse);
95
96 out:
97         spin_unlock_bh(&reuseport_lock);
98
99         return ret;
100 }
101 EXPORT_SYMBOL(reuseport_alloc);
102
103 static struct sock_reuseport *reuseport_grow(struct sock_reuseport *reuse)
104 {
105         struct sock_reuseport *more_reuse;
106         u32 more_socks_size, i;
107
108         more_socks_size = reuse->max_socks * 2U;
109         if (more_socks_size > U16_MAX)
110                 return NULL;
111
112         more_reuse = __reuseport_alloc(more_socks_size);
113         if (!more_reuse)
114                 return NULL;
115
116         more_reuse->num_socks = reuse->num_socks;
117         more_reuse->prog = reuse->prog;
118         more_reuse->reuseport_id = reuse->reuseport_id;
119         more_reuse->bind_inany = reuse->bind_inany;
120         more_reuse->has_conns = reuse->has_conns;
121
122         memcpy(more_reuse->socks, reuse->socks,
123                reuse->num_socks * sizeof(struct sock *));
124         more_reuse->synq_overflow_ts = READ_ONCE(reuse->synq_overflow_ts);
125
126         for (i = 0; i < reuse->num_socks; ++i)
127                 rcu_assign_pointer(reuse->socks[i]->sk_reuseport_cb,
128                                    more_reuse);
129
130         /* Note: we use kfree_rcu here instead of reuseport_free_rcu so
131          * that reuse and more_reuse can temporarily share a reference
132          * to prog.
133          */
134         kfree_rcu(reuse, rcu);
135         return more_reuse;
136 }
137
138 static void reuseport_free_rcu(struct rcu_head *head)
139 {
140         struct sock_reuseport *reuse;
141
142         reuse = container_of(head, struct sock_reuseport, rcu);
143         sk_reuseport_prog_free(rcu_dereference_protected(reuse->prog, 1));
144         ida_free(&reuseport_ida, reuse->reuseport_id);
145         kfree(reuse);
146 }
147
148 /**
149  *  reuseport_add_sock - Add a socket to the reuseport group of another.
150  *  @sk:  New socket to add to the group.
151  *  @sk2: Socket belonging to the existing reuseport group.
152  *  @bind_inany: Whether or not the group is bound to a local INANY address.
153  *
154  *  May return ENOMEM and not add socket to group under memory pressure.
155  */
156 int reuseport_add_sock(struct sock *sk, struct sock *sk2, bool bind_inany)
157 {
158         struct sock_reuseport *old_reuse, *reuse;
159
160         if (!rcu_access_pointer(sk2->sk_reuseport_cb)) {
161                 int err = reuseport_alloc(sk2, bind_inany);
162
163                 if (err)
164                         return err;
165         }
166
167         spin_lock_bh(&reuseport_lock);
168         reuse = rcu_dereference_protected(sk2->sk_reuseport_cb,
169                                           lockdep_is_held(&reuseport_lock));
170         old_reuse = rcu_dereference_protected(sk->sk_reuseport_cb,
171                                              lockdep_is_held(&reuseport_lock));
172         if (old_reuse && old_reuse->num_socks != 1) {
173                 spin_unlock_bh(&reuseport_lock);
174                 return -EBUSY;
175         }
176
177         if (reuse->num_socks == reuse->max_socks) {
178                 reuse = reuseport_grow(reuse);
179                 if (!reuse) {
180                         spin_unlock_bh(&reuseport_lock);
181                         return -ENOMEM;
182                 }
183         }
184
185         reuse->socks[reuse->num_socks] = sk;
186         /* paired with smp_rmb() in reuseport_select_sock() */
187         smp_wmb();
188         reuse->num_socks++;
189         rcu_assign_pointer(sk->sk_reuseport_cb, reuse);
190
191         spin_unlock_bh(&reuseport_lock);
192
193         if (old_reuse)
194                 call_rcu(&old_reuse->rcu, reuseport_free_rcu);
195         return 0;
196 }
197 EXPORT_SYMBOL(reuseport_add_sock);
198
199 void reuseport_detach_sock(struct sock *sk)
200 {
201         struct sock_reuseport *reuse;
202         int i;
203
204         spin_lock_bh(&reuseport_lock);
205         reuse = rcu_dereference_protected(sk->sk_reuseport_cb,
206                                           lockdep_is_held(&reuseport_lock));
207
208         /* Notify the bpf side. The sk may be added to a sockarray
209          * map. If so, sockarray logic will remove it from the map.
210          *
211          * Other bpf map types that work with reuseport, like sockmap,
212          * don't need an explicit callback from here. They override sk
213          * unhash/close ops to remove the sk from the map before we
214          * get to this point.
215          */
216         bpf_sk_reuseport_detach(sk);
217
218         rcu_assign_pointer(sk->sk_reuseport_cb, NULL);
219
220         for (i = 0; i < reuse->num_socks; i++) {
221                 if (reuse->socks[i] == sk) {
222                         reuse->socks[i] = reuse->socks[reuse->num_socks - 1];
223                         reuse->num_socks--;
224                         if (reuse->num_socks == 0)
225                                 call_rcu(&reuse->rcu, reuseport_free_rcu);
226                         break;
227                 }
228         }
229         spin_unlock_bh(&reuseport_lock);
230 }
231 EXPORT_SYMBOL(reuseport_detach_sock);
232
233 static struct sock *run_bpf_filter(struct sock_reuseport *reuse, u16 socks,
234                                    struct bpf_prog *prog, struct sk_buff *skb,
235                                    int hdr_len)
236 {
237         struct sk_buff *nskb = NULL;
238         u32 index;
239
240         if (skb_shared(skb)) {
241                 nskb = skb_clone(skb, GFP_ATOMIC);
242                 if (!nskb)
243                         return NULL;
244                 skb = nskb;
245         }
246
247         /* temporarily advance data past protocol header */
248         if (!pskb_pull(skb, hdr_len)) {
249                 kfree_skb(nskb);
250                 return NULL;
251         }
252         index = bpf_prog_run_save_cb(prog, skb);
253         __skb_push(skb, hdr_len);
254
255         consume_skb(nskb);
256
257         if (index >= socks)
258                 return NULL;
259
260         return reuse->socks[index];
261 }
262
263 /**
264  *  reuseport_select_sock - Select a socket from an SO_REUSEPORT group.
265  *  @sk: First socket in the group.
266  *  @hash: When no BPF filter is available, use this hash to select.
267  *  @skb: skb to run through BPF filter.
268  *  @hdr_len: BPF filter expects skb data pointer at payload data.  If
269  *    the skb does not yet point at the payload, this parameter represents
270  *    how far the pointer needs to advance to reach the payload.
271  *  Returns a socket that should receive the packet (or NULL on error).
272  */
273 struct sock *reuseport_select_sock(struct sock *sk,
274                                    u32 hash,
275                                    struct sk_buff *skb,
276                                    int hdr_len)
277 {
278         struct sock_reuseport *reuse;
279         struct bpf_prog *prog;
280         struct sock *sk2 = NULL;
281         u16 socks;
282
283         rcu_read_lock();
284         reuse = rcu_dereference(sk->sk_reuseport_cb);
285
286         /* if memory allocation failed or add call is not yet complete */
287         if (!reuse)
288                 goto out;
289
290         prog = rcu_dereference(reuse->prog);
291         socks = READ_ONCE(reuse->num_socks);
292         if (likely(socks)) {
293                 /* paired with smp_wmb() in reuseport_add_sock() */
294                 smp_rmb();
295
296                 if (!prog || !skb)
297                         goto select_by_hash;
298
299                 if (prog->type == BPF_PROG_TYPE_SK_REUSEPORT)
300                         sk2 = bpf_run_sk_reuseport(reuse, sk, prog, skb, hash);
301                 else
302                         sk2 = run_bpf_filter(reuse, socks, prog, skb, hdr_len);
303
304 select_by_hash:
305                 /* no bpf or invalid bpf result: fall back to hash usage */
306                 if (!sk2) {
307                         int i, j;
308
309                         i = j = reciprocal_scale(hash, socks);
310                         while (reuse->socks[i]->sk_state == TCP_ESTABLISHED) {
311                                 i++;
312                                 if (i >= socks)
313                                         i = 0;
314                                 if (i == j)
315                                         goto out;
316                         }
317                         sk2 = reuse->socks[i];
318                 }
319         }
320
321 out:
322         rcu_read_unlock();
323         return sk2;
324 }
325 EXPORT_SYMBOL(reuseport_select_sock);
326
327 int reuseport_attach_prog(struct sock *sk, struct bpf_prog *prog)
328 {
329         struct sock_reuseport *reuse;
330         struct bpf_prog *old_prog;
331
332         if (sk_unhashed(sk) && sk->sk_reuseport) {
333                 int err = reuseport_alloc(sk, false);
334
335                 if (err)
336                         return err;
337         } else if (!rcu_access_pointer(sk->sk_reuseport_cb)) {
338                 /* The socket wasn't bound with SO_REUSEPORT */
339                 return -EINVAL;
340         }
341
342         spin_lock_bh(&reuseport_lock);
343         reuse = rcu_dereference_protected(sk->sk_reuseport_cb,
344                                           lockdep_is_held(&reuseport_lock));
345         old_prog = rcu_dereference_protected(reuse->prog,
346                                              lockdep_is_held(&reuseport_lock));
347         rcu_assign_pointer(reuse->prog, prog);
348         spin_unlock_bh(&reuseport_lock);
349
350         sk_reuseport_prog_free(old_prog);
351         return 0;
352 }
353 EXPORT_SYMBOL(reuseport_attach_prog);
354
355 int reuseport_detach_prog(struct sock *sk)
356 {
357         struct sock_reuseport *reuse;
358         struct bpf_prog *old_prog;
359
360         if (!rcu_access_pointer(sk->sk_reuseport_cb))
361                 return sk->sk_reuseport ? -ENOENT : -EINVAL;
362
363         old_prog = NULL;
364         spin_lock_bh(&reuseport_lock);
365         reuse = rcu_dereference_protected(sk->sk_reuseport_cb,
366                                           lockdep_is_held(&reuseport_lock));
367         old_prog = rcu_replace_pointer(reuse->prog, old_prog,
368                                        lockdep_is_held(&reuseport_lock));
369         spin_unlock_bh(&reuseport_lock);
370
371         if (!old_prog)
372                 return -ENOENT;
373
374         sk_reuseport_prog_free(old_prog);
375         return 0;
376 }
377 EXPORT_SYMBOL(reuseport_detach_prog);