GNU Linux-libre 6.9.1-gnu
[releases.git] / kernel / bpf / memalloc.c
1 // SPDX-License-Identifier: GPL-2.0-only
2 /* Copyright (c) 2022 Meta Platforms, Inc. and affiliates. */
3 #include <linux/mm.h>
4 #include <linux/llist.h>
5 #include <linux/bpf.h>
6 #include <linux/irq_work.h>
7 #include <linux/bpf_mem_alloc.h>
8 #include <linux/memcontrol.h>
9 #include <asm/local.h>
10
11 /* Any context (including NMI) BPF specific memory allocator.
12  *
13  * Tracing BPF programs can attach to kprobe and fentry. Hence they
14  * run in unknown context where calling plain kmalloc() might not be safe.
15  *
16  * Front-end kmalloc() with per-cpu per-bucket cache of free elements.
17  * Refill this cache asynchronously from irq_work.
18  *
19  * CPU_0 buckets
20  * 16 32 64 96 128 196 256 512 1024 2048 4096
21  * ...
22  * CPU_N buckets
23  * 16 32 64 96 128 196 256 512 1024 2048 4096
24  *
25  * The buckets are prefilled at the start.
26  * BPF programs always run with migration disabled.
27  * It's safe to allocate from cache of the current cpu with irqs disabled.
28  * Free-ing is always done into bucket of the current cpu as well.
29  * irq_work trims extra free elements from buckets with kfree
30  * and refills them with kmalloc, so global kmalloc logic takes care
31  * of freeing objects allocated by one cpu and freed on another.
32  *
33  * Every allocated objected is padded with extra 8 bytes that contains
34  * struct llist_node.
35  */
36 #define LLIST_NODE_SZ sizeof(struct llist_node)
37
38 /* similar to kmalloc, but sizeof == 8 bucket is gone */
39 static u8 size_index[24] __ro_after_init = {
40         3,      /* 8 */
41         3,      /* 16 */
42         4,      /* 24 */
43         4,      /* 32 */
44         5,      /* 40 */
45         5,      /* 48 */
46         5,      /* 56 */
47         5,      /* 64 */
48         1,      /* 72 */
49         1,      /* 80 */
50         1,      /* 88 */
51         1,      /* 96 */
52         6,      /* 104 */
53         6,      /* 112 */
54         6,      /* 120 */
55         6,      /* 128 */
56         2,      /* 136 */
57         2,      /* 144 */
58         2,      /* 152 */
59         2,      /* 160 */
60         2,      /* 168 */
61         2,      /* 176 */
62         2,      /* 184 */
63         2       /* 192 */
64 };
65
66 static int bpf_mem_cache_idx(size_t size)
67 {
68         if (!size || size > 4096)
69                 return -1;
70
71         if (size <= 192)
72                 return size_index[(size - 1) / 8] - 1;
73
74         return fls(size - 1) - 2;
75 }
76
77 #define NUM_CACHES 11
78
79 struct bpf_mem_cache {
80         /* per-cpu list of free objects of size 'unit_size'.
81          * All accesses are done with interrupts disabled and 'active' counter
82          * protection with __llist_add() and __llist_del_first().
83          */
84         struct llist_head free_llist;
85         local_t active;
86
87         /* Operations on the free_list from unit_alloc/unit_free/bpf_mem_refill
88          * are sequenced by per-cpu 'active' counter. But unit_free() cannot
89          * fail. When 'active' is busy the unit_free() will add an object to
90          * free_llist_extra.
91          */
92         struct llist_head free_llist_extra;
93
94         struct irq_work refill_work;
95         struct obj_cgroup *objcg;
96         int unit_size;
97         /* count of objects in free_llist */
98         int free_cnt;
99         int low_watermark, high_watermark, batch;
100         int percpu_size;
101         bool draining;
102         struct bpf_mem_cache *tgt;
103
104         /* list of objects to be freed after RCU GP */
105         struct llist_head free_by_rcu;
106         struct llist_node *free_by_rcu_tail;
107         struct llist_head waiting_for_gp;
108         struct llist_node *waiting_for_gp_tail;
109         struct rcu_head rcu;
110         atomic_t call_rcu_in_progress;
111         struct llist_head free_llist_extra_rcu;
112
113         /* list of objects to be freed after RCU tasks trace GP */
114         struct llist_head free_by_rcu_ttrace;
115         struct llist_head waiting_for_gp_ttrace;
116         struct rcu_head rcu_ttrace;
117         atomic_t call_rcu_ttrace_in_progress;
118 };
119
120 struct bpf_mem_caches {
121         struct bpf_mem_cache cache[NUM_CACHES];
122 };
123
124 static const u16 sizes[NUM_CACHES] = {96, 192, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096};
125
126 static struct llist_node notrace *__llist_del_first(struct llist_head *head)
127 {
128         struct llist_node *entry, *next;
129
130         entry = head->first;
131         if (!entry)
132                 return NULL;
133         next = entry->next;
134         head->first = next;
135         return entry;
136 }
137
138 static void *__alloc(struct bpf_mem_cache *c, int node, gfp_t flags)
139 {
140         if (c->percpu_size) {
141                 void **obj = kmalloc_node(c->percpu_size, flags, node);
142                 void *pptr = __alloc_percpu_gfp(c->unit_size, 8, flags);
143
144                 if (!obj || !pptr) {
145                         free_percpu(pptr);
146                         kfree(obj);
147                         return NULL;
148                 }
149                 obj[1] = pptr;
150                 return obj;
151         }
152
153         return kmalloc_node(c->unit_size, flags | __GFP_ZERO, node);
154 }
155
156 static struct mem_cgroup *get_memcg(const struct bpf_mem_cache *c)
157 {
158 #ifdef CONFIG_MEMCG_KMEM
159         if (c->objcg)
160                 return get_mem_cgroup_from_objcg(c->objcg);
161 #endif
162
163 #ifdef CONFIG_MEMCG
164         return root_mem_cgroup;
165 #else
166         return NULL;
167 #endif
168 }
169
170 static void inc_active(struct bpf_mem_cache *c, unsigned long *flags)
171 {
172         if (IS_ENABLED(CONFIG_PREEMPT_RT))
173                 /* In RT irq_work runs in per-cpu kthread, so disable
174                  * interrupts to avoid preemption and interrupts and
175                  * reduce the chance of bpf prog executing on this cpu
176                  * when active counter is busy.
177                  */
178                 local_irq_save(*flags);
179         /* alloc_bulk runs from irq_work which will not preempt a bpf
180          * program that does unit_alloc/unit_free since IRQs are
181          * disabled there. There is no race to increment 'active'
182          * counter. It protects free_llist from corruption in case NMI
183          * bpf prog preempted this loop.
184          */
185         WARN_ON_ONCE(local_inc_return(&c->active) != 1);
186 }
187
188 static void dec_active(struct bpf_mem_cache *c, unsigned long *flags)
189 {
190         local_dec(&c->active);
191         if (IS_ENABLED(CONFIG_PREEMPT_RT))
192                 local_irq_restore(*flags);
193 }
194
195 static void add_obj_to_free_list(struct bpf_mem_cache *c, void *obj)
196 {
197         unsigned long flags;
198
199         inc_active(c, &flags);
200         __llist_add(obj, &c->free_llist);
201         c->free_cnt++;
202         dec_active(c, &flags);
203 }
204
205 /* Mostly runs from irq_work except __init phase. */
206 static void alloc_bulk(struct bpf_mem_cache *c, int cnt, int node, bool atomic)
207 {
208         struct mem_cgroup *memcg = NULL, *old_memcg;
209         gfp_t gfp;
210         void *obj;
211         int i;
212
213         gfp = __GFP_NOWARN | __GFP_ACCOUNT;
214         gfp |= atomic ? GFP_NOWAIT : GFP_KERNEL;
215
216         for (i = 0; i < cnt; i++) {
217                 /*
218                  * For every 'c' llist_del_first(&c->free_by_rcu_ttrace); is
219                  * done only by one CPU == current CPU. Other CPUs might
220                  * llist_add() and llist_del_all() in parallel.
221                  */
222                 obj = llist_del_first(&c->free_by_rcu_ttrace);
223                 if (!obj)
224                         break;
225                 add_obj_to_free_list(c, obj);
226         }
227         if (i >= cnt)
228                 return;
229
230         for (; i < cnt; i++) {
231                 obj = llist_del_first(&c->waiting_for_gp_ttrace);
232                 if (!obj)
233                         break;
234                 add_obj_to_free_list(c, obj);
235         }
236         if (i >= cnt)
237                 return;
238
239         memcg = get_memcg(c);
240         old_memcg = set_active_memcg(memcg);
241         for (; i < cnt; i++) {
242                 /* Allocate, but don't deplete atomic reserves that typical
243                  * GFP_ATOMIC would do. irq_work runs on this cpu and kmalloc
244                  * will allocate from the current numa node which is what we
245                  * want here.
246                  */
247                 obj = __alloc(c, node, gfp);
248                 if (!obj)
249                         break;
250                 add_obj_to_free_list(c, obj);
251         }
252         set_active_memcg(old_memcg);
253         mem_cgroup_put(memcg);
254 }
255
256 static void free_one(void *obj, bool percpu)
257 {
258         if (percpu) {
259                 free_percpu(((void **)obj)[1]);
260                 kfree(obj);
261                 return;
262         }
263
264         kfree(obj);
265 }
266
267 static int free_all(struct llist_node *llnode, bool percpu)
268 {
269         struct llist_node *pos, *t;
270         int cnt = 0;
271
272         llist_for_each_safe(pos, t, llnode) {
273                 free_one(pos, percpu);
274                 cnt++;
275         }
276         return cnt;
277 }
278
279 static void __free_rcu(struct rcu_head *head)
280 {
281         struct bpf_mem_cache *c = container_of(head, struct bpf_mem_cache, rcu_ttrace);
282
283         free_all(llist_del_all(&c->waiting_for_gp_ttrace), !!c->percpu_size);
284         atomic_set(&c->call_rcu_ttrace_in_progress, 0);
285 }
286
287 static void __free_rcu_tasks_trace(struct rcu_head *head)
288 {
289         /* If RCU Tasks Trace grace period implies RCU grace period,
290          * there is no need to invoke call_rcu().
291          */
292         if (rcu_trace_implies_rcu_gp())
293                 __free_rcu(head);
294         else
295                 call_rcu(head, __free_rcu);
296 }
297
298 static void enque_to_free(struct bpf_mem_cache *c, void *obj)
299 {
300         struct llist_node *llnode = obj;
301
302         /* bpf_mem_cache is a per-cpu object. Freeing happens in irq_work.
303          * Nothing races to add to free_by_rcu_ttrace list.
304          */
305         llist_add(llnode, &c->free_by_rcu_ttrace);
306 }
307
308 static void do_call_rcu_ttrace(struct bpf_mem_cache *c)
309 {
310         struct llist_node *llnode, *t;
311
312         if (atomic_xchg(&c->call_rcu_ttrace_in_progress, 1)) {
313                 if (unlikely(READ_ONCE(c->draining))) {
314                         llnode = llist_del_all(&c->free_by_rcu_ttrace);
315                         free_all(llnode, !!c->percpu_size);
316                 }
317                 return;
318         }
319
320         WARN_ON_ONCE(!llist_empty(&c->waiting_for_gp_ttrace));
321         llist_for_each_safe(llnode, t, llist_del_all(&c->free_by_rcu_ttrace))
322                 llist_add(llnode, &c->waiting_for_gp_ttrace);
323
324         if (unlikely(READ_ONCE(c->draining))) {
325                 __free_rcu(&c->rcu_ttrace);
326                 return;
327         }
328
329         /* Use call_rcu_tasks_trace() to wait for sleepable progs to finish.
330          * If RCU Tasks Trace grace period implies RCU grace period, free
331          * these elements directly, else use call_rcu() to wait for normal
332          * progs to finish and finally do free_one() on each element.
333          */
334         call_rcu_tasks_trace(&c->rcu_ttrace, __free_rcu_tasks_trace);
335 }
336
337 static void free_bulk(struct bpf_mem_cache *c)
338 {
339         struct bpf_mem_cache *tgt = c->tgt;
340         struct llist_node *llnode, *t;
341         unsigned long flags;
342         int cnt;
343
344         WARN_ON_ONCE(tgt->unit_size != c->unit_size);
345         WARN_ON_ONCE(tgt->percpu_size != c->percpu_size);
346
347         do {
348                 inc_active(c, &flags);
349                 llnode = __llist_del_first(&c->free_llist);
350                 if (llnode)
351                         cnt = --c->free_cnt;
352                 else
353                         cnt = 0;
354                 dec_active(c, &flags);
355                 if (llnode)
356                         enque_to_free(tgt, llnode);
357         } while (cnt > (c->high_watermark + c->low_watermark) / 2);
358
359         /* and drain free_llist_extra */
360         llist_for_each_safe(llnode, t, llist_del_all(&c->free_llist_extra))
361                 enque_to_free(tgt, llnode);
362         do_call_rcu_ttrace(tgt);
363 }
364
365 static void __free_by_rcu(struct rcu_head *head)
366 {
367         struct bpf_mem_cache *c = container_of(head, struct bpf_mem_cache, rcu);
368         struct bpf_mem_cache *tgt = c->tgt;
369         struct llist_node *llnode;
370
371         WARN_ON_ONCE(tgt->unit_size != c->unit_size);
372         WARN_ON_ONCE(tgt->percpu_size != c->percpu_size);
373
374         llnode = llist_del_all(&c->waiting_for_gp);
375         if (!llnode)
376                 goto out;
377
378         llist_add_batch(llnode, c->waiting_for_gp_tail, &tgt->free_by_rcu_ttrace);
379
380         /* Objects went through regular RCU GP. Send them to RCU tasks trace */
381         do_call_rcu_ttrace(tgt);
382 out:
383         atomic_set(&c->call_rcu_in_progress, 0);
384 }
385
386 static void check_free_by_rcu(struct bpf_mem_cache *c)
387 {
388         struct llist_node *llnode, *t;
389         unsigned long flags;
390
391         /* drain free_llist_extra_rcu */
392         if (unlikely(!llist_empty(&c->free_llist_extra_rcu))) {
393                 inc_active(c, &flags);
394                 llist_for_each_safe(llnode, t, llist_del_all(&c->free_llist_extra_rcu))
395                         if (__llist_add(llnode, &c->free_by_rcu))
396                                 c->free_by_rcu_tail = llnode;
397                 dec_active(c, &flags);
398         }
399
400         if (llist_empty(&c->free_by_rcu))
401                 return;
402
403         if (atomic_xchg(&c->call_rcu_in_progress, 1)) {
404                 /*
405                  * Instead of kmalloc-ing new rcu_head and triggering 10k
406                  * call_rcu() to hit rcutree.qhimark and force RCU to notice
407                  * the overload just ask RCU to hurry up. There could be many
408                  * objects in free_by_rcu list.
409                  * This hint reduces memory consumption for an artificial
410                  * benchmark from 2 Gbyte to 150 Mbyte.
411                  */
412                 rcu_request_urgent_qs_task(current);
413                 return;
414         }
415
416         WARN_ON_ONCE(!llist_empty(&c->waiting_for_gp));
417
418         inc_active(c, &flags);
419         WRITE_ONCE(c->waiting_for_gp.first, __llist_del_all(&c->free_by_rcu));
420         c->waiting_for_gp_tail = c->free_by_rcu_tail;
421         dec_active(c, &flags);
422
423         if (unlikely(READ_ONCE(c->draining))) {
424                 free_all(llist_del_all(&c->waiting_for_gp), !!c->percpu_size);
425                 atomic_set(&c->call_rcu_in_progress, 0);
426         } else {
427                 call_rcu_hurry(&c->rcu, __free_by_rcu);
428         }
429 }
430
431 static void bpf_mem_refill(struct irq_work *work)
432 {
433         struct bpf_mem_cache *c = container_of(work, struct bpf_mem_cache, refill_work);
434         int cnt;
435
436         /* Racy access to free_cnt. It doesn't need to be 100% accurate */
437         cnt = c->free_cnt;
438         if (cnt < c->low_watermark)
439                 /* irq_work runs on this cpu and kmalloc will allocate
440                  * from the current numa node which is what we want here.
441                  */
442                 alloc_bulk(c, c->batch, NUMA_NO_NODE, true);
443         else if (cnt > c->high_watermark)
444                 free_bulk(c);
445
446         check_free_by_rcu(c);
447 }
448
449 static void notrace irq_work_raise(struct bpf_mem_cache *c)
450 {
451         irq_work_queue(&c->refill_work);
452 }
453
454 /* For typical bpf map case that uses bpf_mem_cache_alloc and single bucket
455  * the freelist cache will be elem_size * 64 (or less) on each cpu.
456  *
457  * For bpf programs that don't have statically known allocation sizes and
458  * assuming (low_mark + high_mark) / 2 as an average number of elements per
459  * bucket and all buckets are used the total amount of memory in freelists
460  * on each cpu will be:
461  * 64*16 + 64*32 + 64*64 + 64*96 + 64*128 + 64*196 + 64*256 + 32*512 + 16*1024 + 8*2048 + 4*4096
462  * == ~ 116 Kbyte using below heuristic.
463  * Initialized, but unused bpf allocator (not bpf map specific one) will
464  * consume ~ 11 Kbyte per cpu.
465  * Typical case will be between 11K and 116K closer to 11K.
466  * bpf progs can and should share bpf_mem_cache when possible.
467  *
468  * Percpu allocation is typically rare. To avoid potential unnecessary large
469  * memory consumption, set low_mark = 1 and high_mark = 3, resulting in c->batch = 1.
470  */
471 static void init_refill_work(struct bpf_mem_cache *c)
472 {
473         init_irq_work(&c->refill_work, bpf_mem_refill);
474         if (c->percpu_size) {
475                 c->low_watermark = 1;
476                 c->high_watermark = 3;
477         } else if (c->unit_size <= 256) {
478                 c->low_watermark = 32;
479                 c->high_watermark = 96;
480         } else {
481                 /* When page_size == 4k, order-0 cache will have low_mark == 2
482                  * and high_mark == 6 with batch alloc of 3 individual pages at
483                  * a time.
484                  * 8k allocs and above low == 1, high == 3, batch == 1.
485                  */
486                 c->low_watermark = max(32 * 256 / c->unit_size, 1);
487                 c->high_watermark = max(96 * 256 / c->unit_size, 3);
488         }
489         c->batch = max((c->high_watermark - c->low_watermark) / 4 * 3, 1);
490 }
491
492 static void prefill_mem_cache(struct bpf_mem_cache *c, int cpu)
493 {
494         int cnt = 1;
495
496         /* To avoid consuming memory, for non-percpu allocation, assume that
497          * 1st run of bpf prog won't be doing more than 4 map_update_elem from
498          * irq disabled region if unit size is less than or equal to 256.
499          * For all other cases, let us just do one allocation.
500          */
501         if (!c->percpu_size && c->unit_size <= 256)
502                 cnt = 4;
503         alloc_bulk(c, cnt, cpu_to_node(cpu), false);
504 }
505
506 /* When size != 0 bpf_mem_cache for each cpu.
507  * This is typical bpf hash map use case when all elements have equal size.
508  *
509  * When size == 0 allocate 11 bpf_mem_cache-s for each cpu, then rely on
510  * kmalloc/kfree. Max allocation size is 4096 in this case.
511  * This is bpf_dynptr and bpf_kptr use case.
512  */
513 int bpf_mem_alloc_init(struct bpf_mem_alloc *ma, int size, bool percpu)
514 {
515         struct bpf_mem_caches *cc, __percpu *pcc;
516         struct bpf_mem_cache *c, __percpu *pc;
517         struct obj_cgroup *objcg = NULL;
518         int cpu, i, unit_size, percpu_size = 0;
519
520         if (percpu && size == 0)
521                 return -EINVAL;
522
523         /* room for llist_node and per-cpu pointer */
524         if (percpu)
525                 percpu_size = LLIST_NODE_SZ + sizeof(void *);
526         ma->percpu = percpu;
527
528         if (size) {
529                 pc = __alloc_percpu_gfp(sizeof(*pc), 8, GFP_KERNEL);
530                 if (!pc)
531                         return -ENOMEM;
532
533                 if (!percpu)
534                         size += LLIST_NODE_SZ; /* room for llist_node */
535                 unit_size = size;
536
537 #ifdef CONFIG_MEMCG_KMEM
538                 if (memcg_bpf_enabled())
539                         objcg = get_obj_cgroup_from_current();
540 #endif
541                 ma->objcg = objcg;
542
543                 for_each_possible_cpu(cpu) {
544                         c = per_cpu_ptr(pc, cpu);
545                         c->unit_size = unit_size;
546                         c->objcg = objcg;
547                         c->percpu_size = percpu_size;
548                         c->tgt = c;
549                         init_refill_work(c);
550                         prefill_mem_cache(c, cpu);
551                 }
552                 ma->cache = pc;
553                 return 0;
554         }
555
556         pcc = __alloc_percpu_gfp(sizeof(*cc), 8, GFP_KERNEL);
557         if (!pcc)
558                 return -ENOMEM;
559 #ifdef CONFIG_MEMCG_KMEM
560         objcg = get_obj_cgroup_from_current();
561 #endif
562         ma->objcg = objcg;
563         for_each_possible_cpu(cpu) {
564                 cc = per_cpu_ptr(pcc, cpu);
565                 for (i = 0; i < NUM_CACHES; i++) {
566                         c = &cc->cache[i];
567                         c->unit_size = sizes[i];
568                         c->objcg = objcg;
569                         c->percpu_size = percpu_size;
570                         c->tgt = c;
571
572                         init_refill_work(c);
573                         prefill_mem_cache(c, cpu);
574                 }
575         }
576
577         ma->caches = pcc;
578         return 0;
579 }
580
581 int bpf_mem_alloc_percpu_init(struct bpf_mem_alloc *ma, struct obj_cgroup *objcg)
582 {
583         struct bpf_mem_caches __percpu *pcc;
584
585         pcc = __alloc_percpu_gfp(sizeof(struct bpf_mem_caches), 8, GFP_KERNEL);
586         if (!pcc)
587                 return -ENOMEM;
588
589         ma->caches = pcc;
590         ma->objcg = objcg;
591         ma->percpu = true;
592         return 0;
593 }
594
595 int bpf_mem_alloc_percpu_unit_init(struct bpf_mem_alloc *ma, int size)
596 {
597         struct bpf_mem_caches *cc, __percpu *pcc;
598         int cpu, i, unit_size, percpu_size;
599         struct obj_cgroup *objcg;
600         struct bpf_mem_cache *c;
601
602         i = bpf_mem_cache_idx(size);
603         if (i < 0)
604                 return -EINVAL;
605
606         /* room for llist_node and per-cpu pointer */
607         percpu_size = LLIST_NODE_SZ + sizeof(void *);
608
609         unit_size = sizes[i];
610         objcg = ma->objcg;
611         pcc = ma->caches;
612
613         for_each_possible_cpu(cpu) {
614                 cc = per_cpu_ptr(pcc, cpu);
615                 c = &cc->cache[i];
616                 if (c->unit_size)
617                         break;
618
619                 c->unit_size = unit_size;
620                 c->objcg = objcg;
621                 c->percpu_size = percpu_size;
622                 c->tgt = c;
623
624                 init_refill_work(c);
625                 prefill_mem_cache(c, cpu);
626         }
627
628         return 0;
629 }
630
631 static void drain_mem_cache(struct bpf_mem_cache *c)
632 {
633         bool percpu = !!c->percpu_size;
634
635         /* No progs are using this bpf_mem_cache, but htab_map_free() called
636          * bpf_mem_cache_free() for all remaining elements and they can be in
637          * free_by_rcu_ttrace or in waiting_for_gp_ttrace lists, so drain those lists now.
638          *
639          * Except for waiting_for_gp_ttrace list, there are no concurrent operations
640          * on these lists, so it is safe to use __llist_del_all().
641          */
642         free_all(llist_del_all(&c->free_by_rcu_ttrace), percpu);
643         free_all(llist_del_all(&c->waiting_for_gp_ttrace), percpu);
644         free_all(__llist_del_all(&c->free_llist), percpu);
645         free_all(__llist_del_all(&c->free_llist_extra), percpu);
646         free_all(__llist_del_all(&c->free_by_rcu), percpu);
647         free_all(__llist_del_all(&c->free_llist_extra_rcu), percpu);
648         free_all(llist_del_all(&c->waiting_for_gp), percpu);
649 }
650
651 static void check_mem_cache(struct bpf_mem_cache *c)
652 {
653         WARN_ON_ONCE(!llist_empty(&c->free_by_rcu_ttrace));
654         WARN_ON_ONCE(!llist_empty(&c->waiting_for_gp_ttrace));
655         WARN_ON_ONCE(!llist_empty(&c->free_llist));
656         WARN_ON_ONCE(!llist_empty(&c->free_llist_extra));
657         WARN_ON_ONCE(!llist_empty(&c->free_by_rcu));
658         WARN_ON_ONCE(!llist_empty(&c->free_llist_extra_rcu));
659         WARN_ON_ONCE(!llist_empty(&c->waiting_for_gp));
660 }
661
662 static void check_leaked_objs(struct bpf_mem_alloc *ma)
663 {
664         struct bpf_mem_caches *cc;
665         struct bpf_mem_cache *c;
666         int cpu, i;
667
668         if (ma->cache) {
669                 for_each_possible_cpu(cpu) {
670                         c = per_cpu_ptr(ma->cache, cpu);
671                         check_mem_cache(c);
672                 }
673         }
674         if (ma->caches) {
675                 for_each_possible_cpu(cpu) {
676                         cc = per_cpu_ptr(ma->caches, cpu);
677                         for (i = 0; i < NUM_CACHES; i++) {
678                                 c = &cc->cache[i];
679                                 check_mem_cache(c);
680                         }
681                 }
682         }
683 }
684
685 static void free_mem_alloc_no_barrier(struct bpf_mem_alloc *ma)
686 {
687         check_leaked_objs(ma);
688         free_percpu(ma->cache);
689         free_percpu(ma->caches);
690         ma->cache = NULL;
691         ma->caches = NULL;
692 }
693
694 static void free_mem_alloc(struct bpf_mem_alloc *ma)
695 {
696         /* waiting_for_gp[_ttrace] lists were drained, but RCU callbacks
697          * might still execute. Wait for them.
698          *
699          * rcu_barrier_tasks_trace() doesn't imply synchronize_rcu_tasks_trace(),
700          * but rcu_barrier_tasks_trace() and rcu_barrier() below are only used
701          * to wait for the pending __free_rcu_tasks_trace() and __free_rcu(),
702          * so if call_rcu(head, __free_rcu) is skipped due to
703          * rcu_trace_implies_rcu_gp(), it will be OK to skip rcu_barrier() by
704          * using rcu_trace_implies_rcu_gp() as well.
705          */
706         rcu_barrier(); /* wait for __free_by_rcu */
707         rcu_barrier_tasks_trace(); /* wait for __free_rcu */
708         if (!rcu_trace_implies_rcu_gp())
709                 rcu_barrier();
710         free_mem_alloc_no_barrier(ma);
711 }
712
713 static void free_mem_alloc_deferred(struct work_struct *work)
714 {
715         struct bpf_mem_alloc *ma = container_of(work, struct bpf_mem_alloc, work);
716
717         free_mem_alloc(ma);
718         kfree(ma);
719 }
720
721 static void destroy_mem_alloc(struct bpf_mem_alloc *ma, int rcu_in_progress)
722 {
723         struct bpf_mem_alloc *copy;
724
725         if (!rcu_in_progress) {
726                 /* Fast path. No callbacks are pending, hence no need to do
727                  * rcu_barrier-s.
728                  */
729                 free_mem_alloc_no_barrier(ma);
730                 return;
731         }
732
733         copy = kmemdup(ma, sizeof(*ma), GFP_KERNEL);
734         if (!copy) {
735                 /* Slow path with inline barrier-s */
736                 free_mem_alloc(ma);
737                 return;
738         }
739
740         /* Defer barriers into worker to let the rest of map memory to be freed */
741         memset(ma, 0, sizeof(*ma));
742         INIT_WORK(&copy->work, free_mem_alloc_deferred);
743         queue_work(system_unbound_wq, &copy->work);
744 }
745
746 void bpf_mem_alloc_destroy(struct bpf_mem_alloc *ma)
747 {
748         struct bpf_mem_caches *cc;
749         struct bpf_mem_cache *c;
750         int cpu, i, rcu_in_progress;
751
752         if (ma->cache) {
753                 rcu_in_progress = 0;
754                 for_each_possible_cpu(cpu) {
755                         c = per_cpu_ptr(ma->cache, cpu);
756                         WRITE_ONCE(c->draining, true);
757                         irq_work_sync(&c->refill_work);
758                         drain_mem_cache(c);
759                         rcu_in_progress += atomic_read(&c->call_rcu_ttrace_in_progress);
760                         rcu_in_progress += atomic_read(&c->call_rcu_in_progress);
761                 }
762                 if (ma->objcg)
763                         obj_cgroup_put(ma->objcg);
764                 destroy_mem_alloc(ma, rcu_in_progress);
765         }
766         if (ma->caches) {
767                 rcu_in_progress = 0;
768                 for_each_possible_cpu(cpu) {
769                         cc = per_cpu_ptr(ma->caches, cpu);
770                         for (i = 0; i < NUM_CACHES; i++) {
771                                 c = &cc->cache[i];
772                                 WRITE_ONCE(c->draining, true);
773                                 irq_work_sync(&c->refill_work);
774                                 drain_mem_cache(c);
775                                 rcu_in_progress += atomic_read(&c->call_rcu_ttrace_in_progress);
776                                 rcu_in_progress += atomic_read(&c->call_rcu_in_progress);
777                         }
778                 }
779                 if (ma->objcg)
780                         obj_cgroup_put(ma->objcg);
781                 destroy_mem_alloc(ma, rcu_in_progress);
782         }
783 }
784
785 /* notrace is necessary here and in other functions to make sure
786  * bpf programs cannot attach to them and cause llist corruptions.
787  */
788 static void notrace *unit_alloc(struct bpf_mem_cache *c)
789 {
790         struct llist_node *llnode = NULL;
791         unsigned long flags;
792         int cnt = 0;
793
794         /* Disable irqs to prevent the following race for majority of prog types:
795          * prog_A
796          *   bpf_mem_alloc
797          *      preemption or irq -> prog_B
798          *        bpf_mem_alloc
799          *
800          * but prog_B could be a perf_event NMI prog.
801          * Use per-cpu 'active' counter to order free_list access between
802          * unit_alloc/unit_free/bpf_mem_refill.
803          */
804         local_irq_save(flags);
805         if (local_inc_return(&c->active) == 1) {
806                 llnode = __llist_del_first(&c->free_llist);
807                 if (llnode) {
808                         cnt = --c->free_cnt;
809                         *(struct bpf_mem_cache **)llnode = c;
810                 }
811         }
812         local_dec(&c->active);
813
814         WARN_ON(cnt < 0);
815
816         if (cnt < c->low_watermark)
817                 irq_work_raise(c);
818         /* Enable IRQ after the enqueue of irq work completes, so irq work
819          * will run after IRQ is enabled and free_llist may be refilled by
820          * irq work before other task preempts current task.
821          */
822         local_irq_restore(flags);
823
824         return llnode;
825 }
826
827 /* Though 'ptr' object could have been allocated on a different cpu
828  * add it to the free_llist of the current cpu.
829  * Let kfree() logic deal with it when it's later called from irq_work.
830  */
831 static void notrace unit_free(struct bpf_mem_cache *c, void *ptr)
832 {
833         struct llist_node *llnode = ptr - LLIST_NODE_SZ;
834         unsigned long flags;
835         int cnt = 0;
836
837         BUILD_BUG_ON(LLIST_NODE_SZ > 8);
838
839         /*
840          * Remember bpf_mem_cache that allocated this object.
841          * The hint is not accurate.
842          */
843         c->tgt = *(struct bpf_mem_cache **)llnode;
844
845         local_irq_save(flags);
846         if (local_inc_return(&c->active) == 1) {
847                 __llist_add(llnode, &c->free_llist);
848                 cnt = ++c->free_cnt;
849         } else {
850                 /* unit_free() cannot fail. Therefore add an object to atomic
851                  * llist. free_bulk() will drain it. Though free_llist_extra is
852                  * a per-cpu list we have to use atomic llist_add here, since
853                  * it also can be interrupted by bpf nmi prog that does another
854                  * unit_free() into the same free_llist_extra.
855                  */
856                 llist_add(llnode, &c->free_llist_extra);
857         }
858         local_dec(&c->active);
859
860         if (cnt > c->high_watermark)
861                 /* free few objects from current cpu into global kmalloc pool */
862                 irq_work_raise(c);
863         /* Enable IRQ after irq_work_raise() completes, otherwise when current
864          * task is preempted by task which does unit_alloc(), unit_alloc() may
865          * return NULL unexpectedly because irq work is already pending but can
866          * not been triggered and free_llist can not be refilled timely.
867          */
868         local_irq_restore(flags);
869 }
870
871 static void notrace unit_free_rcu(struct bpf_mem_cache *c, void *ptr)
872 {
873         struct llist_node *llnode = ptr - LLIST_NODE_SZ;
874         unsigned long flags;
875
876         c->tgt = *(struct bpf_mem_cache **)llnode;
877
878         local_irq_save(flags);
879         if (local_inc_return(&c->active) == 1) {
880                 if (__llist_add(llnode, &c->free_by_rcu))
881                         c->free_by_rcu_tail = llnode;
882         } else {
883                 llist_add(llnode, &c->free_llist_extra_rcu);
884         }
885         local_dec(&c->active);
886
887         if (!atomic_read(&c->call_rcu_in_progress))
888                 irq_work_raise(c);
889         local_irq_restore(flags);
890 }
891
892 /* Called from BPF program or from sys_bpf syscall.
893  * In both cases migration is disabled.
894  */
895 void notrace *bpf_mem_alloc(struct bpf_mem_alloc *ma, size_t size)
896 {
897         int idx;
898         void *ret;
899
900         if (!size)
901                 return NULL;
902
903         if (!ma->percpu)
904                 size += LLIST_NODE_SZ;
905         idx = bpf_mem_cache_idx(size);
906         if (idx < 0)
907                 return NULL;
908
909         ret = unit_alloc(this_cpu_ptr(ma->caches)->cache + idx);
910         return !ret ? NULL : ret + LLIST_NODE_SZ;
911 }
912
913 void notrace bpf_mem_free(struct bpf_mem_alloc *ma, void *ptr)
914 {
915         struct bpf_mem_cache *c;
916         int idx;
917
918         if (!ptr)
919                 return;
920
921         c = *(void **)(ptr - LLIST_NODE_SZ);
922         idx = bpf_mem_cache_idx(c->unit_size);
923         if (WARN_ON_ONCE(idx < 0))
924                 return;
925
926         unit_free(this_cpu_ptr(ma->caches)->cache + idx, ptr);
927 }
928
929 void notrace bpf_mem_free_rcu(struct bpf_mem_alloc *ma, void *ptr)
930 {
931         struct bpf_mem_cache *c;
932         int idx;
933
934         if (!ptr)
935                 return;
936
937         c = *(void **)(ptr - LLIST_NODE_SZ);
938         idx = bpf_mem_cache_idx(c->unit_size);
939         if (WARN_ON_ONCE(idx < 0))
940                 return;
941
942         unit_free_rcu(this_cpu_ptr(ma->caches)->cache + idx, ptr);
943 }
944
945 void notrace *bpf_mem_cache_alloc(struct bpf_mem_alloc *ma)
946 {
947         void *ret;
948
949         ret = unit_alloc(this_cpu_ptr(ma->cache));
950         return !ret ? NULL : ret + LLIST_NODE_SZ;
951 }
952
953 void notrace bpf_mem_cache_free(struct bpf_mem_alloc *ma, void *ptr)
954 {
955         if (!ptr)
956                 return;
957
958         unit_free(this_cpu_ptr(ma->cache), ptr);
959 }
960
961 void notrace bpf_mem_cache_free_rcu(struct bpf_mem_alloc *ma, void *ptr)
962 {
963         if (!ptr)
964                 return;
965
966         unit_free_rcu(this_cpu_ptr(ma->cache), ptr);
967 }
968
969 /* Directly does a kfree() without putting 'ptr' back to the free_llist
970  * for reuse and without waiting for a rcu_tasks_trace gp.
971  * The caller must first go through the rcu_tasks_trace gp for 'ptr'
972  * before calling bpf_mem_cache_raw_free().
973  * It could be used when the rcu_tasks_trace callback does not have
974  * a hold on the original bpf_mem_alloc object that allocated the
975  * 'ptr'. This should only be used in the uncommon code path.
976  * Otherwise, the bpf_mem_alloc's free_llist cannot be refilled
977  * and may affect performance.
978  */
979 void bpf_mem_cache_raw_free(void *ptr)
980 {
981         if (!ptr)
982                 return;
983
984         kfree(ptr - LLIST_NODE_SZ);
985 }
986
987 /* When flags == GFP_KERNEL, it signals that the caller will not cause
988  * deadlock when using kmalloc. bpf_mem_cache_alloc_flags() will use
989  * kmalloc if the free_llist is empty.
990  */
991 void notrace *bpf_mem_cache_alloc_flags(struct bpf_mem_alloc *ma, gfp_t flags)
992 {
993         struct bpf_mem_cache *c;
994         void *ret;
995
996         c = this_cpu_ptr(ma->cache);
997
998         ret = unit_alloc(c);
999         if (!ret && flags == GFP_KERNEL) {
1000                 struct mem_cgroup *memcg, *old_memcg;
1001
1002                 memcg = get_memcg(c);
1003                 old_memcg = set_active_memcg(memcg);
1004                 ret = __alloc(c, NUMA_NO_NODE, GFP_KERNEL | __GFP_NOWARN | __GFP_ACCOUNT);
1005                 if (ret)
1006                         *(struct bpf_mem_cache **)ret = c;
1007                 set_active_memcg(old_memcg);
1008                 mem_cgroup_put(memcg);
1009         }
1010
1011         return !ret ? NULL : ret + LLIST_NODE_SZ;
1012 }