GNU Linux-libre 4.14.324-gnu1
[releases.git] / arch / arm64 / net / bpf_jit_comp.c
1 /*
2  * BPF JIT compiler for ARM64
3  *
4  * Copyright (C) 2014-2016 Zi Shen Lim <zlim.lnx@gmail.com>
5  *
6  * This program is free software; you can redistribute it and/or modify
7  * it under the terms of the GNU General Public License version 2 as
8  * published by the Free Software Foundation.
9  *
10  * This program is distributed in the hope that it will be useful,
11  * but WITHOUT ANY WARRANTY; without even the implied warranty of
12  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
13  * GNU General Public License for more details.
14  *
15  * You should have received a copy of the GNU General Public License
16  * along with this program.  If not, see <http://www.gnu.org/licenses/>.
17  */
18
19 #define pr_fmt(fmt) "bpf_jit: " fmt
20
21 #include <linux/bpf.h>
22 #include <linux/filter.h>
23 #include <linux/printk.h>
24 #include <linux/skbuff.h>
25 #include <linux/slab.h>
26
27 #include <asm/byteorder.h>
28 #include <asm/cacheflush.h>
29 #include <asm/debug-monitors.h>
30 #include <asm/set_memory.h>
31
32 #include "bpf_jit.h"
33
34 #define TMP_REG_1 (MAX_BPF_JIT_REG + 0)
35 #define TMP_REG_2 (MAX_BPF_JIT_REG + 1)
36 #define TCALL_CNT (MAX_BPF_JIT_REG + 2)
37 #define TMP_REG_3 (MAX_BPF_JIT_REG + 3)
38
39 /* Map BPF registers to A64 registers */
40 static const int bpf2a64[] = {
41         /* return value from in-kernel function, and exit value from eBPF */
42         [BPF_REG_0] = A64_R(7),
43         /* arguments from eBPF program to in-kernel function */
44         [BPF_REG_1] = A64_R(0),
45         [BPF_REG_2] = A64_R(1),
46         [BPF_REG_3] = A64_R(2),
47         [BPF_REG_4] = A64_R(3),
48         [BPF_REG_5] = A64_R(4),
49         /* callee saved registers that in-kernel function will preserve */
50         [BPF_REG_6] = A64_R(19),
51         [BPF_REG_7] = A64_R(20),
52         [BPF_REG_8] = A64_R(21),
53         [BPF_REG_9] = A64_R(22),
54         /* read-only frame pointer to access stack */
55         [BPF_REG_FP] = A64_R(25),
56         /* temporary registers for internal BPF JIT */
57         [TMP_REG_1] = A64_R(10),
58         [TMP_REG_2] = A64_R(11),
59         [TMP_REG_3] = A64_R(12),
60         /* tail_call_cnt */
61         [TCALL_CNT] = A64_R(26),
62         /* temporary register for blinding constants */
63         [BPF_REG_AX] = A64_R(9),
64 };
65
66 struct jit_ctx {
67         const struct bpf_prog *prog;
68         int idx;
69         int epilogue_offset;
70         int *offset;
71         __le32 *image;
72         u32 stack_size;
73 };
74
75 static inline void emit(const u32 insn, struct jit_ctx *ctx)
76 {
77         if (ctx->image != NULL)
78                 ctx->image[ctx->idx] = cpu_to_le32(insn);
79
80         ctx->idx++;
81 }
82
83 static inline void emit_a64_mov_i64(const int reg, const u64 val,
84                                     struct jit_ctx *ctx)
85 {
86         u64 tmp = val;
87         int shift = 0;
88
89         emit(A64_MOVZ(1, reg, tmp & 0xffff, shift), ctx);
90         tmp >>= 16;
91         shift += 16;
92         while (tmp) {
93                 if (tmp & 0xffff)
94                         emit(A64_MOVK(1, reg, tmp & 0xffff, shift), ctx);
95                 tmp >>= 16;
96                 shift += 16;
97         }
98 }
99
100 static inline void emit_a64_mov_i(const int is64, const int reg,
101                                   const s32 val, struct jit_ctx *ctx)
102 {
103         u16 hi = val >> 16;
104         u16 lo = val & 0xffff;
105
106         if (hi & 0x8000) {
107                 if (hi == 0xffff) {
108                         emit(A64_MOVN(is64, reg, (u16)~lo, 0), ctx);
109                 } else {
110                         emit(A64_MOVN(is64, reg, (u16)~hi, 16), ctx);
111                         emit(A64_MOVK(is64, reg, lo, 0), ctx);
112                 }
113         } else {
114                 emit(A64_MOVZ(is64, reg, lo, 0), ctx);
115                 if (hi)
116                         emit(A64_MOVK(is64, reg, hi, 16), ctx);
117         }
118 }
119
120 static inline int bpf2a64_offset(int bpf_to, int bpf_from,
121                                  const struct jit_ctx *ctx)
122 {
123         int to = ctx->offset[bpf_to];
124         /* -1 to account for the Branch instruction */
125         int from = ctx->offset[bpf_from] - 1;
126
127         return to - from;
128 }
129
130 static void jit_fill_hole(void *area, unsigned int size)
131 {
132         __le32 *ptr;
133         /* We are guaranteed to have aligned memory. */
134         for (ptr = area; size >= sizeof(u32); size -= sizeof(u32))
135                 *ptr++ = cpu_to_le32(AARCH64_BREAK_FAULT);
136 }
137
138 static inline int epilogue_offset(const struct jit_ctx *ctx)
139 {
140         int to = ctx->epilogue_offset;
141         int from = ctx->idx;
142
143         return to - from;
144 }
145
146 /* Stack must be multiples of 16B */
147 #define STACK_ALIGN(sz) (((sz) + 15) & ~15)
148
149 /* Tail call offset to jump into */
150 #define PROLOGUE_OFFSET 7
151
152 static int build_prologue(struct jit_ctx *ctx)
153 {
154         const struct bpf_prog *prog = ctx->prog;
155         const u8 r6 = bpf2a64[BPF_REG_6];
156         const u8 r7 = bpf2a64[BPF_REG_7];
157         const u8 r8 = bpf2a64[BPF_REG_8];
158         const u8 r9 = bpf2a64[BPF_REG_9];
159         const u8 fp = bpf2a64[BPF_REG_FP];
160         const u8 tcc = bpf2a64[TCALL_CNT];
161         const int idx0 = ctx->idx;
162         int cur_offset;
163
164         /*
165          * BPF prog stack layout
166          *
167          *                         high
168          * original A64_SP =>   0:+-----+ BPF prologue
169          *                        |FP/LR|
170          * current A64_FP =>  -16:+-----+
171          *                        | ... | callee saved registers
172          * BPF fp register => -64:+-----+ <= (BPF_FP)
173          *                        |     |
174          *                        | ... | BPF prog stack
175          *                        |     |
176          *                        +-----+ <= (BPF_FP - prog->aux->stack_depth)
177          *                        |RSVD | JIT scratchpad
178          * current A64_SP =>      +-----+ <= (BPF_FP - ctx->stack_size)
179          *                        |     |
180          *                        | ... | Function call stack
181          *                        |     |
182          *                        +-----+
183          *                          low
184          *
185          */
186
187         /* Save FP and LR registers to stay align with ARM64 AAPCS */
188         emit(A64_PUSH(A64_FP, A64_LR, A64_SP), ctx);
189         emit(A64_MOV(1, A64_FP, A64_SP), ctx);
190
191         /* Save callee-saved registers */
192         emit(A64_PUSH(r6, r7, A64_SP), ctx);
193         emit(A64_PUSH(r8, r9, A64_SP), ctx);
194         emit(A64_PUSH(fp, tcc, A64_SP), ctx);
195
196         /* Set up BPF prog stack base register */
197         emit(A64_MOV(1, fp, A64_SP), ctx);
198
199         /* Initialize tail_call_cnt */
200         emit(A64_MOVZ(1, tcc, 0, 0), ctx);
201
202         cur_offset = ctx->idx - idx0;
203         if (cur_offset != PROLOGUE_OFFSET) {
204                 pr_err_once("PROLOGUE_OFFSET = %d, expected %d!\n",
205                             cur_offset, PROLOGUE_OFFSET);
206                 return -1;
207         }
208
209         /* 4 byte extra for skb_copy_bits buffer */
210         ctx->stack_size = prog->aux->stack_depth + 4;
211         ctx->stack_size = STACK_ALIGN(ctx->stack_size);
212
213         /* Set up function call stack */
214         emit(A64_SUB_I(1, A64_SP, A64_SP, ctx->stack_size), ctx);
215         return 0;
216 }
217
218 static int out_offset = -1; /* initialized on the first pass of build_body() */
219 static int emit_bpf_tail_call(struct jit_ctx *ctx)
220 {
221         /* bpf_tail_call(void *prog_ctx, struct bpf_array *array, u64 index) */
222         const u8 r2 = bpf2a64[BPF_REG_2];
223         const u8 r3 = bpf2a64[BPF_REG_3];
224
225         const u8 tmp = bpf2a64[TMP_REG_1];
226         const u8 prg = bpf2a64[TMP_REG_2];
227         const u8 tcc = bpf2a64[TCALL_CNT];
228         const int idx0 = ctx->idx;
229 #define cur_offset (ctx->idx - idx0)
230 #define jmp_offset (out_offset - (cur_offset))
231         size_t off;
232
233         /* if (index >= array->map.max_entries)
234          *     goto out;
235          */
236         off = offsetof(struct bpf_array, map.max_entries);
237         emit_a64_mov_i64(tmp, off, ctx);
238         emit(A64_LDR32(tmp, r2, tmp), ctx);
239         emit(A64_MOV(0, r3, r3), ctx);
240         emit(A64_CMP(0, r3, tmp), ctx);
241         emit(A64_B_(A64_COND_CS, jmp_offset), ctx);
242
243         /* if (tail_call_cnt > MAX_TAIL_CALL_CNT)
244          *     goto out;
245          * tail_call_cnt++;
246          */
247         emit_a64_mov_i64(tmp, MAX_TAIL_CALL_CNT, ctx);
248         emit(A64_CMP(1, tcc, tmp), ctx);
249         emit(A64_B_(A64_COND_HI, jmp_offset), ctx);
250         emit(A64_ADD_I(1, tcc, tcc, 1), ctx);
251
252         /* prog = array->ptrs[index];
253          * if (prog == NULL)
254          *     goto out;
255          */
256         off = offsetof(struct bpf_array, ptrs);
257         emit_a64_mov_i64(tmp, off, ctx);
258         emit(A64_ADD(1, tmp, r2, tmp), ctx);
259         emit(A64_LSL(1, prg, r3, 3), ctx);
260         emit(A64_LDR64(prg, tmp, prg), ctx);
261         emit(A64_CBZ(1, prg, jmp_offset), ctx);
262
263         /* goto *(prog->bpf_func + prologue_offset); */
264         off = offsetof(struct bpf_prog, bpf_func);
265         emit_a64_mov_i64(tmp, off, ctx);
266         emit(A64_LDR64(tmp, prg, tmp), ctx);
267         emit(A64_ADD_I(1, tmp, tmp, sizeof(u32) * PROLOGUE_OFFSET), ctx);
268         emit(A64_ADD_I(1, A64_SP, A64_SP, ctx->stack_size), ctx);
269         emit(A64_BR(tmp), ctx);
270
271         /* out: */
272         if (out_offset == -1)
273                 out_offset = cur_offset;
274         if (cur_offset != out_offset) {
275                 pr_err_once("tail_call out_offset = %d, expected %d!\n",
276                             cur_offset, out_offset);
277                 return -1;
278         }
279         return 0;
280 #undef cur_offset
281 #undef jmp_offset
282 }
283
284 static void build_epilogue(struct jit_ctx *ctx)
285 {
286         const u8 r0 = bpf2a64[BPF_REG_0];
287         const u8 r6 = bpf2a64[BPF_REG_6];
288         const u8 r7 = bpf2a64[BPF_REG_7];
289         const u8 r8 = bpf2a64[BPF_REG_8];
290         const u8 r9 = bpf2a64[BPF_REG_9];
291         const u8 fp = bpf2a64[BPF_REG_FP];
292
293         /* We're done with BPF stack */
294         emit(A64_ADD_I(1, A64_SP, A64_SP, ctx->stack_size), ctx);
295
296         /* Restore fs (x25) and x26 */
297         emit(A64_POP(fp, A64_R(26), A64_SP), ctx);
298
299         /* Restore callee-saved register */
300         emit(A64_POP(r8, r9, A64_SP), ctx);
301         emit(A64_POP(r6, r7, A64_SP), ctx);
302
303         /* Restore FP/LR registers */
304         emit(A64_POP(A64_FP, A64_LR, A64_SP), ctx);
305
306         /* Set return value */
307         emit(A64_MOV(1, A64_R(0), r0), ctx);
308
309         emit(A64_RET(A64_LR), ctx);
310 }
311
312 /* JITs an eBPF instruction.
313  * Returns:
314  * 0  - successfully JITed an 8-byte eBPF instruction.
315  * >0 - successfully JITed a 16-byte eBPF instruction.
316  * <0 - failed to JIT.
317  */
318 static int build_insn(const struct bpf_insn *insn, struct jit_ctx *ctx)
319 {
320         const u8 code = insn->code;
321         const u8 dst = bpf2a64[insn->dst_reg];
322         const u8 src = bpf2a64[insn->src_reg];
323         const u8 tmp = bpf2a64[TMP_REG_1];
324         const u8 tmp2 = bpf2a64[TMP_REG_2];
325         const u8 tmp3 = bpf2a64[TMP_REG_3];
326         const s16 off = insn->off;
327         const s32 imm = insn->imm;
328         const int i = insn - ctx->prog->insnsi;
329         const bool is64 = BPF_CLASS(code) == BPF_ALU64;
330         const bool isdw = BPF_SIZE(code) == BPF_DW;
331         u8 jmp_cond, reg;
332         s32 jmp_offset;
333
334 #define check_imm(bits, imm) do {                               \
335         if ((((imm) > 0) && ((imm) >> (bits))) ||               \
336             (((imm) < 0) && (~(imm) >> (bits)))) {              \
337                 pr_info("[%2d] imm=%d(0x%x) out of range\n",    \
338                         i, imm, imm);                           \
339                 return -EINVAL;                                 \
340         }                                                       \
341 } while (0)
342 #define check_imm19(imm) check_imm(19, imm)
343 #define check_imm26(imm) check_imm(26, imm)
344
345         switch (code) {
346         /* dst = src */
347         case BPF_ALU | BPF_MOV | BPF_X:
348         case BPF_ALU64 | BPF_MOV | BPF_X:
349                 emit(A64_MOV(is64, dst, src), ctx);
350                 break;
351         /* dst = dst OP src */
352         case BPF_ALU | BPF_ADD | BPF_X:
353         case BPF_ALU64 | BPF_ADD | BPF_X:
354                 emit(A64_ADD(is64, dst, dst, src), ctx);
355                 break;
356         case BPF_ALU | BPF_SUB | BPF_X:
357         case BPF_ALU64 | BPF_SUB | BPF_X:
358                 emit(A64_SUB(is64, dst, dst, src), ctx);
359                 break;
360         case BPF_ALU | BPF_AND | BPF_X:
361         case BPF_ALU64 | BPF_AND | BPF_X:
362                 emit(A64_AND(is64, dst, dst, src), ctx);
363                 break;
364         case BPF_ALU | BPF_OR | BPF_X:
365         case BPF_ALU64 | BPF_OR | BPF_X:
366                 emit(A64_ORR(is64, dst, dst, src), ctx);
367                 break;
368         case BPF_ALU | BPF_XOR | BPF_X:
369         case BPF_ALU64 | BPF_XOR | BPF_X:
370                 emit(A64_EOR(is64, dst, dst, src), ctx);
371                 break;
372         case BPF_ALU | BPF_MUL | BPF_X:
373         case BPF_ALU64 | BPF_MUL | BPF_X:
374                 emit(A64_MUL(is64, dst, dst, src), ctx);
375                 break;
376         case BPF_ALU | BPF_DIV | BPF_X:
377         case BPF_ALU64 | BPF_DIV | BPF_X:
378         case BPF_ALU | BPF_MOD | BPF_X:
379         case BPF_ALU64 | BPF_MOD | BPF_X:
380         {
381                 const u8 r0 = bpf2a64[BPF_REG_0];
382
383                 /* if (src == 0) return 0 */
384                 jmp_offset = 3; /* skip ahead to else path */
385                 check_imm19(jmp_offset);
386                 emit(A64_CBNZ(is64, src, jmp_offset), ctx);
387                 emit(A64_MOVZ(1, r0, 0, 0), ctx);
388                 jmp_offset = epilogue_offset(ctx);
389                 check_imm26(jmp_offset);
390                 emit(A64_B(jmp_offset), ctx);
391                 /* else */
392                 switch (BPF_OP(code)) {
393                 case BPF_DIV:
394                         emit(A64_UDIV(is64, dst, dst, src), ctx);
395                         break;
396                 case BPF_MOD:
397                         emit(A64_UDIV(is64, tmp, dst, src), ctx);
398                         emit(A64_MUL(is64, tmp, tmp, src), ctx);
399                         emit(A64_SUB(is64, dst, dst, tmp), ctx);
400                         break;
401                 }
402                 break;
403         }
404         case BPF_ALU | BPF_LSH | BPF_X:
405         case BPF_ALU64 | BPF_LSH | BPF_X:
406                 emit(A64_LSLV(is64, dst, dst, src), ctx);
407                 break;
408         case BPF_ALU | BPF_RSH | BPF_X:
409         case BPF_ALU64 | BPF_RSH | BPF_X:
410                 emit(A64_LSRV(is64, dst, dst, src), ctx);
411                 break;
412         case BPF_ALU | BPF_ARSH | BPF_X:
413         case BPF_ALU64 | BPF_ARSH | BPF_X:
414                 emit(A64_ASRV(is64, dst, dst, src), ctx);
415                 break;
416         /* dst = -dst */
417         case BPF_ALU | BPF_NEG:
418         case BPF_ALU64 | BPF_NEG:
419                 emit(A64_NEG(is64, dst, dst), ctx);
420                 break;
421         /* dst = BSWAP##imm(dst) */
422         case BPF_ALU | BPF_END | BPF_FROM_LE:
423         case BPF_ALU | BPF_END | BPF_FROM_BE:
424 #ifdef CONFIG_CPU_BIG_ENDIAN
425                 if (BPF_SRC(code) == BPF_FROM_BE)
426                         goto emit_bswap_uxt;
427 #else /* !CONFIG_CPU_BIG_ENDIAN */
428                 if (BPF_SRC(code) == BPF_FROM_LE)
429                         goto emit_bswap_uxt;
430 #endif
431                 switch (imm) {
432                 case 16:
433                         emit(A64_REV16(is64, dst, dst), ctx);
434                         /* zero-extend 16 bits into 64 bits */
435                         emit(A64_UXTH(is64, dst, dst), ctx);
436                         break;
437                 case 32:
438                         emit(A64_REV32(is64, dst, dst), ctx);
439                         /* upper 32 bits already cleared */
440                         break;
441                 case 64:
442                         emit(A64_REV64(dst, dst), ctx);
443                         break;
444                 }
445                 break;
446 emit_bswap_uxt:
447                 switch (imm) {
448                 case 16:
449                         /* zero-extend 16 bits into 64 bits */
450                         emit(A64_UXTH(is64, dst, dst), ctx);
451                         break;
452                 case 32:
453                         /* zero-extend 32 bits into 64 bits */
454                         emit(A64_UXTW(is64, dst, dst), ctx);
455                         break;
456                 case 64:
457                         /* nop */
458                         break;
459                 }
460                 break;
461         /* dst = imm */
462         case BPF_ALU | BPF_MOV | BPF_K:
463         case BPF_ALU64 | BPF_MOV | BPF_K:
464                 emit_a64_mov_i(is64, dst, imm, ctx);
465                 break;
466         /* dst = dst OP imm */
467         case BPF_ALU | BPF_ADD | BPF_K:
468         case BPF_ALU64 | BPF_ADD | BPF_K:
469                 emit_a64_mov_i(is64, tmp, imm, ctx);
470                 emit(A64_ADD(is64, dst, dst, tmp), ctx);
471                 break;
472         case BPF_ALU | BPF_SUB | BPF_K:
473         case BPF_ALU64 | BPF_SUB | BPF_K:
474                 emit_a64_mov_i(is64, tmp, imm, ctx);
475                 emit(A64_SUB(is64, dst, dst, tmp), ctx);
476                 break;
477         case BPF_ALU | BPF_AND | BPF_K:
478         case BPF_ALU64 | BPF_AND | BPF_K:
479                 emit_a64_mov_i(is64, tmp, imm, ctx);
480                 emit(A64_AND(is64, dst, dst, tmp), ctx);
481                 break;
482         case BPF_ALU | BPF_OR | BPF_K:
483         case BPF_ALU64 | BPF_OR | BPF_K:
484                 emit_a64_mov_i(is64, tmp, imm, ctx);
485                 emit(A64_ORR(is64, dst, dst, tmp), ctx);
486                 break;
487         case BPF_ALU | BPF_XOR | BPF_K:
488         case BPF_ALU64 | BPF_XOR | BPF_K:
489                 emit_a64_mov_i(is64, tmp, imm, ctx);
490                 emit(A64_EOR(is64, dst, dst, tmp), ctx);
491                 break;
492         case BPF_ALU | BPF_MUL | BPF_K:
493         case BPF_ALU64 | BPF_MUL | BPF_K:
494                 emit_a64_mov_i(is64, tmp, imm, ctx);
495                 emit(A64_MUL(is64, dst, dst, tmp), ctx);
496                 break;
497         case BPF_ALU | BPF_DIV | BPF_K:
498         case BPF_ALU64 | BPF_DIV | BPF_K:
499                 emit_a64_mov_i(is64, tmp, imm, ctx);
500                 emit(A64_UDIV(is64, dst, dst, tmp), ctx);
501                 break;
502         case BPF_ALU | BPF_MOD | BPF_K:
503         case BPF_ALU64 | BPF_MOD | BPF_K:
504                 emit_a64_mov_i(is64, tmp2, imm, ctx);
505                 emit(A64_UDIV(is64, tmp, dst, tmp2), ctx);
506                 emit(A64_MUL(is64, tmp, tmp, tmp2), ctx);
507                 emit(A64_SUB(is64, dst, dst, tmp), ctx);
508                 break;
509         case BPF_ALU | BPF_LSH | BPF_K:
510         case BPF_ALU64 | BPF_LSH | BPF_K:
511                 emit(A64_LSL(is64, dst, dst, imm), ctx);
512                 break;
513         case BPF_ALU | BPF_RSH | BPF_K:
514         case BPF_ALU64 | BPF_RSH | BPF_K:
515                 emit(A64_LSR(is64, dst, dst, imm), ctx);
516                 break;
517         case BPF_ALU | BPF_ARSH | BPF_K:
518         case BPF_ALU64 | BPF_ARSH | BPF_K:
519                 emit(A64_ASR(is64, dst, dst, imm), ctx);
520                 break;
521
522         /* JUMP off */
523         case BPF_JMP | BPF_JA:
524                 jmp_offset = bpf2a64_offset(i + off, i, ctx);
525                 check_imm26(jmp_offset);
526                 emit(A64_B(jmp_offset), ctx);
527                 break;
528         /* IF (dst COND src) JUMP off */
529         case BPF_JMP | BPF_JEQ | BPF_X:
530         case BPF_JMP | BPF_JGT | BPF_X:
531         case BPF_JMP | BPF_JLT | BPF_X:
532         case BPF_JMP | BPF_JGE | BPF_X:
533         case BPF_JMP | BPF_JLE | BPF_X:
534         case BPF_JMP | BPF_JNE | BPF_X:
535         case BPF_JMP | BPF_JSGT | BPF_X:
536         case BPF_JMP | BPF_JSLT | BPF_X:
537         case BPF_JMP | BPF_JSGE | BPF_X:
538         case BPF_JMP | BPF_JSLE | BPF_X:
539                 emit(A64_CMP(1, dst, src), ctx);
540 emit_cond_jmp:
541                 jmp_offset = bpf2a64_offset(i + off, i, ctx);
542                 check_imm19(jmp_offset);
543                 switch (BPF_OP(code)) {
544                 case BPF_JEQ:
545                         jmp_cond = A64_COND_EQ;
546                         break;
547                 case BPF_JGT:
548                         jmp_cond = A64_COND_HI;
549                         break;
550                 case BPF_JLT:
551                         jmp_cond = A64_COND_CC;
552                         break;
553                 case BPF_JGE:
554                         jmp_cond = A64_COND_CS;
555                         break;
556                 case BPF_JLE:
557                         jmp_cond = A64_COND_LS;
558                         break;
559                 case BPF_JSET:
560                 case BPF_JNE:
561                         jmp_cond = A64_COND_NE;
562                         break;
563                 case BPF_JSGT:
564                         jmp_cond = A64_COND_GT;
565                         break;
566                 case BPF_JSLT:
567                         jmp_cond = A64_COND_LT;
568                         break;
569                 case BPF_JSGE:
570                         jmp_cond = A64_COND_GE;
571                         break;
572                 case BPF_JSLE:
573                         jmp_cond = A64_COND_LE;
574                         break;
575                 default:
576                         return -EFAULT;
577                 }
578                 emit(A64_B_(jmp_cond, jmp_offset), ctx);
579                 break;
580         case BPF_JMP | BPF_JSET | BPF_X:
581                 emit(A64_TST(1, dst, src), ctx);
582                 goto emit_cond_jmp;
583         /* IF (dst COND imm) JUMP off */
584         case BPF_JMP | BPF_JEQ | BPF_K:
585         case BPF_JMP | BPF_JGT | BPF_K:
586         case BPF_JMP | BPF_JLT | BPF_K:
587         case BPF_JMP | BPF_JGE | BPF_K:
588         case BPF_JMP | BPF_JLE | BPF_K:
589         case BPF_JMP | BPF_JNE | BPF_K:
590         case BPF_JMP | BPF_JSGT | BPF_K:
591         case BPF_JMP | BPF_JSLT | BPF_K:
592         case BPF_JMP | BPF_JSGE | BPF_K:
593         case BPF_JMP | BPF_JSLE | BPF_K:
594                 emit_a64_mov_i(1, tmp, imm, ctx);
595                 emit(A64_CMP(1, dst, tmp), ctx);
596                 goto emit_cond_jmp;
597         case BPF_JMP | BPF_JSET | BPF_K:
598                 emit_a64_mov_i(1, tmp, imm, ctx);
599                 emit(A64_TST(1, dst, tmp), ctx);
600                 goto emit_cond_jmp;
601         /* function call */
602         case BPF_JMP | BPF_CALL:
603         {
604                 const u8 r0 = bpf2a64[BPF_REG_0];
605                 const u64 func = (u64)__bpf_call_base + imm;
606
607                 emit_a64_mov_i64(tmp, func, ctx);
608                 emit(A64_BLR(tmp), ctx);
609                 emit(A64_MOV(1, r0, A64_R(0)), ctx);
610                 break;
611         }
612         /* tail call */
613         case BPF_JMP | BPF_TAIL_CALL:
614                 if (emit_bpf_tail_call(ctx))
615                         return -EFAULT;
616                 break;
617         /* function return */
618         case BPF_JMP | BPF_EXIT:
619                 /* Optimization: when last instruction is EXIT,
620                    simply fallthrough to epilogue. */
621                 if (i == ctx->prog->len - 1)
622                         break;
623                 jmp_offset = epilogue_offset(ctx);
624                 check_imm26(jmp_offset);
625                 emit(A64_B(jmp_offset), ctx);
626                 break;
627
628         /* dst = imm64 */
629         case BPF_LD | BPF_IMM | BPF_DW:
630         {
631                 const struct bpf_insn insn1 = insn[1];
632                 u64 imm64;
633
634                 imm64 = (u64)insn1.imm << 32 | (u32)imm;
635                 emit_a64_mov_i64(dst, imm64, ctx);
636
637                 return 1;
638         }
639
640         /* LDX: dst = *(size *)(src + off) */
641         case BPF_LDX | BPF_MEM | BPF_W:
642         case BPF_LDX | BPF_MEM | BPF_H:
643         case BPF_LDX | BPF_MEM | BPF_B:
644         case BPF_LDX | BPF_MEM | BPF_DW:
645                 emit_a64_mov_i(1, tmp, off, ctx);
646                 switch (BPF_SIZE(code)) {
647                 case BPF_W:
648                         emit(A64_LDR32(dst, src, tmp), ctx);
649                         break;
650                 case BPF_H:
651                         emit(A64_LDRH(dst, src, tmp), ctx);
652                         break;
653                 case BPF_B:
654                         emit(A64_LDRB(dst, src, tmp), ctx);
655                         break;
656                 case BPF_DW:
657                         emit(A64_LDR64(dst, src, tmp), ctx);
658                         break;
659                 }
660                 break;
661
662         /* ST: *(size *)(dst + off) = imm */
663         case BPF_ST | BPF_MEM | BPF_W:
664         case BPF_ST | BPF_MEM | BPF_H:
665         case BPF_ST | BPF_MEM | BPF_B:
666         case BPF_ST | BPF_MEM | BPF_DW:
667                 /* Load imm to a register then store it */
668                 emit_a64_mov_i(1, tmp2, off, ctx);
669                 emit_a64_mov_i(1, tmp, imm, ctx);
670                 switch (BPF_SIZE(code)) {
671                 case BPF_W:
672                         emit(A64_STR32(tmp, dst, tmp2), ctx);
673                         break;
674                 case BPF_H:
675                         emit(A64_STRH(tmp, dst, tmp2), ctx);
676                         break;
677                 case BPF_B:
678                         emit(A64_STRB(tmp, dst, tmp2), ctx);
679                         break;
680                 case BPF_DW:
681                         emit(A64_STR64(tmp, dst, tmp2), ctx);
682                         break;
683                 }
684                 break;
685
686         /* STX: *(size *)(dst + off) = src */
687         case BPF_STX | BPF_MEM | BPF_W:
688         case BPF_STX | BPF_MEM | BPF_H:
689         case BPF_STX | BPF_MEM | BPF_B:
690         case BPF_STX | BPF_MEM | BPF_DW:
691                 emit_a64_mov_i(1, tmp, off, ctx);
692                 switch (BPF_SIZE(code)) {
693                 case BPF_W:
694                         emit(A64_STR32(src, dst, tmp), ctx);
695                         break;
696                 case BPF_H:
697                         emit(A64_STRH(src, dst, tmp), ctx);
698                         break;
699                 case BPF_B:
700                         emit(A64_STRB(src, dst, tmp), ctx);
701                         break;
702                 case BPF_DW:
703                         emit(A64_STR64(src, dst, tmp), ctx);
704                         break;
705                 }
706                 break;
707
708         /* STX XADD: lock *(u32 *)(dst + off) += src */
709         case BPF_STX | BPF_XADD | BPF_W:
710         /* STX XADD: lock *(u64 *)(dst + off) += src */
711         case BPF_STX | BPF_XADD | BPF_DW:
712                 if (!off) {
713                         reg = dst;
714                 } else {
715                         emit_a64_mov_i(1, tmp, off, ctx);
716                         emit(A64_ADD(1, tmp, tmp, dst), ctx);
717                         reg = tmp;
718                 }
719                 if (cpus_have_cap(ARM64_HAS_LSE_ATOMICS)) {
720                         emit(A64_STADD(isdw, reg, src), ctx);
721                 } else {
722                         emit(A64_LDXR(isdw, tmp2, reg), ctx);
723                         emit(A64_ADD(isdw, tmp2, tmp2, src), ctx);
724                         emit(A64_STXR(isdw, tmp2, reg, tmp3), ctx);
725                         jmp_offset = -3;
726                         check_imm19(jmp_offset);
727                         emit(A64_CBNZ(0, tmp3, jmp_offset), ctx);
728                 }
729                 break;
730
731         /* R0 = ntohx(*(size *)(((struct sk_buff *)R6)->data + imm)) */
732         case BPF_LD | BPF_ABS | BPF_W:
733         case BPF_LD | BPF_ABS | BPF_H:
734         case BPF_LD | BPF_ABS | BPF_B:
735         /* R0 = ntohx(*(size *)(((struct sk_buff *)R6)->data + src + imm)) */
736         case BPF_LD | BPF_IND | BPF_W:
737         case BPF_LD | BPF_IND | BPF_H:
738         case BPF_LD | BPF_IND | BPF_B:
739         {
740                 const u8 r0 = bpf2a64[BPF_REG_0]; /* r0 = return value */
741                 const u8 r6 = bpf2a64[BPF_REG_6]; /* r6 = pointer to sk_buff */
742                 const u8 fp = bpf2a64[BPF_REG_FP];
743                 const u8 r1 = bpf2a64[BPF_REG_1]; /* r1: struct sk_buff *skb */
744                 const u8 r2 = bpf2a64[BPF_REG_2]; /* r2: int k */
745                 const u8 r3 = bpf2a64[BPF_REG_3]; /* r3: unsigned int size */
746                 const u8 r4 = bpf2a64[BPF_REG_4]; /* r4: void *buffer */
747                 const u8 r5 = bpf2a64[BPF_REG_5]; /* r5: void *(*func)(...) */
748                 int size;
749
750                 emit(A64_MOV(1, r1, r6), ctx);
751                 emit_a64_mov_i(0, r2, imm, ctx);
752                 if (BPF_MODE(code) == BPF_IND)
753                         emit(A64_ADD(0, r2, r2, src), ctx);
754                 switch (BPF_SIZE(code)) {
755                 case BPF_W:
756                         size = 4;
757                         break;
758                 case BPF_H:
759                         size = 2;
760                         break;
761                 case BPF_B:
762                         size = 1;
763                         break;
764                 default:
765                         return -EINVAL;
766                 }
767                 emit_a64_mov_i64(r3, size, ctx);
768                 emit(A64_SUB_I(1, r4, fp, ctx->stack_size), ctx);
769                 emit_a64_mov_i64(r5, (unsigned long)bpf_load_pointer, ctx);
770                 emit(A64_BLR(r5), ctx);
771                 emit(A64_MOV(1, r0, A64_R(0)), ctx);
772
773                 jmp_offset = epilogue_offset(ctx);
774                 check_imm19(jmp_offset);
775                 emit(A64_CBZ(1, r0, jmp_offset), ctx);
776                 emit(A64_MOV(1, r5, r0), ctx);
777                 switch (BPF_SIZE(code)) {
778                 case BPF_W:
779                         emit(A64_LDR32(r0, r5, A64_ZR), ctx);
780 #ifndef CONFIG_CPU_BIG_ENDIAN
781                         emit(A64_REV32(0, r0, r0), ctx);
782 #endif
783                         break;
784                 case BPF_H:
785                         emit(A64_LDRH(r0, r5, A64_ZR), ctx);
786 #ifndef CONFIG_CPU_BIG_ENDIAN
787                         emit(A64_REV16(0, r0, r0), ctx);
788 #endif
789                         break;
790                 case BPF_B:
791                         emit(A64_LDRB(r0, r5, A64_ZR), ctx);
792                         break;
793                 }
794                 break;
795         }
796         default:
797                 pr_err_once("unknown opcode %02x\n", code);
798                 return -EINVAL;
799         }
800
801         return 0;
802 }
803
804 static int build_body(struct jit_ctx *ctx)
805 {
806         const struct bpf_prog *prog = ctx->prog;
807         int i;
808
809         for (i = 0; i < prog->len; i++) {
810                 const struct bpf_insn *insn = &prog->insnsi[i];
811                 int ret;
812
813                 ret = build_insn(insn, ctx);
814                 if (ret > 0) {
815                         i++;
816                         if (ctx->image == NULL)
817                                 ctx->offset[i] = ctx->idx;
818                         continue;
819                 }
820                 if (ctx->image == NULL)
821                         ctx->offset[i] = ctx->idx;
822                 if (ret)
823                         return ret;
824         }
825
826         return 0;
827 }
828
829 static int validate_code(struct jit_ctx *ctx)
830 {
831         int i;
832
833         for (i = 0; i < ctx->idx; i++) {
834                 u32 a64_insn = le32_to_cpu(ctx->image[i]);
835
836                 if (a64_insn == AARCH64_BREAK_FAULT)
837                         return -1;
838         }
839
840         return 0;
841 }
842
843 static inline void bpf_flush_icache(void *start, void *end)
844 {
845         flush_icache_range((unsigned long)start, (unsigned long)end);
846 }
847
848 struct bpf_prog *bpf_int_jit_compile(struct bpf_prog *prog)
849 {
850         struct bpf_prog *tmp, *orig_prog = prog;
851         struct bpf_binary_header *header;
852         bool tmp_blinded = false;
853         struct jit_ctx ctx;
854         int image_size;
855         u8 *image_ptr;
856
857         if (!bpf_jit_enable)
858                 return orig_prog;
859
860         tmp = bpf_jit_blind_constants(prog);
861         /* If blinding was requested and we failed during blinding,
862          * we must fall back to the interpreter.
863          */
864         if (IS_ERR(tmp))
865                 return orig_prog;
866         if (tmp != prog) {
867                 tmp_blinded = true;
868                 prog = tmp;
869         }
870
871         memset(&ctx, 0, sizeof(ctx));
872         ctx.prog = prog;
873
874         ctx.offset = kcalloc(prog->len, sizeof(int), GFP_KERNEL);
875         if (ctx.offset == NULL) {
876                 prog = orig_prog;
877                 goto out;
878         }
879
880         /* 1. Initial fake pass to compute ctx->idx. */
881
882         /* Fake pass to fill in ctx->offset. */
883         if (build_body(&ctx)) {
884                 prog = orig_prog;
885                 goto out_off;
886         }
887
888         if (build_prologue(&ctx)) {
889                 prog = orig_prog;
890                 goto out_off;
891         }
892
893         ctx.epilogue_offset = ctx.idx;
894         build_epilogue(&ctx);
895
896         /* Now we know the actual image size. */
897         image_size = sizeof(u32) * ctx.idx;
898         header = bpf_jit_binary_alloc(image_size, &image_ptr,
899                                       sizeof(u32), jit_fill_hole);
900         if (header == NULL) {
901                 prog = orig_prog;
902                 goto out_off;
903         }
904
905         /* 2. Now, the actual pass. */
906
907         ctx.image = (__le32 *)image_ptr;
908         ctx.idx = 0;
909
910         build_prologue(&ctx);
911
912         if (build_body(&ctx)) {
913                 bpf_jit_binary_free(header);
914                 prog = orig_prog;
915                 goto out_off;
916         }
917
918         build_epilogue(&ctx);
919
920         /* 3. Extra pass to validate JITed code. */
921         if (validate_code(&ctx)) {
922                 bpf_jit_binary_free(header);
923                 prog = orig_prog;
924                 goto out_off;
925         }
926
927         /* And we're done. */
928         if (bpf_jit_enable > 1)
929                 bpf_jit_dump(prog->len, image_size, 2, ctx.image);
930
931         bpf_flush_icache(header, ctx.image + ctx.idx);
932
933         bpf_jit_binary_lock_ro(header);
934         prog->bpf_func = (void *)ctx.image;
935         prog->jited = 1;
936         prog->jited_len = image_size;
937
938 out_off:
939         kfree(ctx.offset);
940 out:
941         if (tmp_blinded)
942                 bpf_jit_prog_release_other(prog, prog == orig_prog ?
943                                            tmp : orig_prog);
944         return prog;
945 }