// SPDX-License-Identifier: GPL-2.0-only
/*
 * Bit sliced AES using NEON instructions
 *
 * Copyright (C) 2016 - 2017 Linaro Ltd <ard.biesheuvel@linaro.org>
 */

#include <asm/neon.h>
#include <asm/simd.h>
#include <crypto/aes.h>
#include <crypto/ctr.h>
#include <crypto/internal/simd.h>
#include <crypto/internal/skcipher.h>
#include <crypto/scatterwalk.h>
#include <crypto/xts.h>
#include <linux/module.h>

MODULE_AUTHOR("Ard Biesheuvel <ard.biesheuvel@linaro.org>");
MODULE_LICENSE("GPL v2");

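/*
 * Advertise the generic algorithm names so the module can be loaded
 * automatically when one of these ciphers is requested.
 */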
MODULE_ALIAS_CRYPTO("ecb(aes)");
MODULE_ALIAS_CRYPTO("cbc(aes)");
MODULE_ALIAS_CRYPTO("ctr(aes)");
MODULE_ALIAS_CRYPTO("xts(aes)");

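/*
 * Low-level bit-sliced routines, implemented in assembly
 * (aes-neonbs-core.S). They may only be called with the NEON unit
 * enabled, i.e. between kernel_neon_begin() and kernel_neon_end().
 */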
asmlinkage void aesbs_convert_key(u8 out[], u32 const rk[], int rounds);

asmlinkage void aesbs_ecb_encrypt(u8 out[], u8 const in[], u8 const rk[],
                                  int rounds, int blocks);
asmlinkage void aesbs_ecb_decrypt(u8 out[], u8 const in[], u8 const rk[],
                                  int rounds, int blocks);

asmlinkage void aesbs_cbc_decrypt(u8 out[], u8 const in[], u8 const rk[],
                                  int rounds, int blocks, u8 iv[]);

asmlinkage void aesbs_ctr_encrypt(u8 out[], u8 const in[], u8 const rk[],
                                  int rounds, int blocks, u8 iv[], u8 final[]);

asmlinkage void aesbs_xts_encrypt(u8 out[], u8 const in[], u8 const rk[],
                                  int rounds, int blocks, u8 iv[]);
asmlinkage void aesbs_xts_decrypt(u8 out[], u8 const in[], u8 const rk[],
                                  int rounds, int blocks, u8 iv[]);

/* borrowed from aes-neon-blk.ko */
asmlinkage void neon_aes_ecb_encrypt(u8 out[], u8 const in[], u32 const rk[],
                                     int rounds, int blocks);
asmlinkage void neon_aes_cbc_encrypt(u8 out[], u8 const in[], u32 const rk[],
                                     int rounds, int blocks, u8 iv[]);
asmlinkage void neon_aes_xts_encrypt(u8 out[], u8 const in[],
                                     u32 const rk1[], int rounds, int bytes,
                                     u32 const rk2[], u8 iv[], int first);
asmlinkage void neon_aes_xts_decrypt(u8 out[], u8 const in[],
                                     u32 const rk1[], int rounds, int bytes,
                                     u32 const rk2[], u8 iv[], int first);

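/*
 * Room for the converted key schedule: the bit-sliced representation
 * expands a round key to 8 * AES_BLOCK_SIZE bytes, for up to 13 rounds,
 * with (presumably) 32 bytes left for round keys kept in the regular
 * format.
 */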
struct aesbs_ctx {
        u8      rk[13 * (8 * AES_BLOCK_SIZE) + 32];
        int     rounds;
} __aligned(AES_BLOCK_SIZE);

struct aesbs_cbc_ctx {
        struct aesbs_ctx        key;
        u32                     enc[AES_MAX_KEYLENGTH_U32];
};

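/*
 * The synchronous ctr implementation keeps a generic key schedule
 * around so it can fall back to the AES library code when the NEON
 * unit is not usable (see ctr_encrypt_sync()).
 */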
struct aesbs_ctr_ctx {
        struct aesbs_ctx        key;            /* must be first member */
        struct crypto_aes_ctx   fallback;
};

struct aesbs_xts_ctx {
        struct aesbs_ctx        key;
        u32                     twkey[AES_MAX_KEYLENGTH_U32];
        struct crypto_aes_ctx   cts;
};

static int aesbs_setkey(struct crypto_skcipher *tfm, const u8 *in_key,
                        unsigned int key_len)
{
        struct aesbs_ctx *ctx = crypto_skcipher_ctx(tfm);
        struct crypto_aes_ctx rk;
        int err;

        err = aes_expandkey(&rk, in_key, key_len);
        if (err)
                return err;

        ctx->rounds = 6 + key_len / 4;

        kernel_neon_begin();
        aesbs_convert_key(ctx->rk, rk.key_enc, ctx->rounds);
        kernel_neon_end();
        /* wipe the expanded key off the stack, as aesbs_cbc_setkey() does */
        memzero_explicit(&rk, sizeof(rk));

        return 0;
}

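/*
 * Common helper for ECB: as long as more input follows the current
 * walk step, round the block count down to a multiple of the walk
 * stride (8 blocks), leaving any smaller remainder for the final
 * step.
 */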
static int __ecb_crypt(struct skcipher_request *req,
                       void (*fn)(u8 out[], u8 const in[], u8 const rk[],
                                  int rounds, int blocks))
{
        struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
        struct aesbs_ctx *ctx = crypto_skcipher_ctx(tfm);
        struct skcipher_walk walk;
        int err;

        err = skcipher_walk_virt(&walk, req, false);

        while (walk.nbytes >= AES_BLOCK_SIZE) {
                unsigned int blocks = walk.nbytes / AES_BLOCK_SIZE;

                if (walk.nbytes < walk.total)
                        blocks = round_down(blocks,
                                            walk.stride / AES_BLOCK_SIZE);

                kernel_neon_begin();
                fn(walk.dst.virt.addr, walk.src.virt.addr, ctx->rk,
                   ctx->rounds, blocks);
                kernel_neon_end();
                err = skcipher_walk_done(&walk,
                                         walk.nbytes - blocks * AES_BLOCK_SIZE);
        }

        return err;
}

static int ecb_encrypt(struct skcipher_request *req)
{
        return __ecb_crypt(req, aesbs_ecb_encrypt);
}

static int ecb_decrypt(struct skcipher_request *req)
{
        return __ecb_crypt(req, aesbs_ecb_decrypt);
}

static int aesbs_cbc_setkey(struct crypto_skcipher *tfm, const u8 *in_key,
                            unsigned int key_len)
{
        struct aesbs_cbc_ctx *ctx = crypto_skcipher_ctx(tfm);
        struct crypto_aes_ctx rk;
        int err;

        err = aes_expandkey(&rk, in_key, key_len);
        if (err)
                return err;

        ctx->key.rounds = 6 + key_len / 4;

        memcpy(ctx->enc, rk.key_enc, sizeof(ctx->enc));

        kernel_neon_begin();
        aesbs_convert_key(ctx->key.rk, rk.key_enc, ctx->key.rounds);
        kernel_neon_end();
        memzero_explicit(&rk, sizeof(rk));

        return 0;
}

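/*
 * CBC encryption is inherently sequential, as each block depends on
 * the previous ciphertext block, so the 8-way bit-sliced core is of no
 * use here; encrypt using the plain NEON implementation instead.
 */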
static int cbc_encrypt(struct skcipher_request *req)
{
        struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
        struct aesbs_cbc_ctx *ctx = crypto_skcipher_ctx(tfm);
        struct skcipher_walk walk;
        int err;

        err = skcipher_walk_virt(&walk, req, false);

        while (walk.nbytes >= AES_BLOCK_SIZE) {
                unsigned int blocks = walk.nbytes / AES_BLOCK_SIZE;

                /* fall back to the non-bitsliced NEON implementation */
                kernel_neon_begin();
                neon_aes_cbc_encrypt(walk.dst.virt.addr, walk.src.virt.addr,
                                     ctx->enc, ctx->key.rounds, blocks,
                                     walk.iv);
                kernel_neon_end();
                err = skcipher_walk_done(&walk, walk.nbytes % AES_BLOCK_SIZE);
        }
        return err;
}

static int cbc_decrypt(struct skcipher_request *req)
{
        struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
        struct aesbs_cbc_ctx *ctx = crypto_skcipher_ctx(tfm);
        struct skcipher_walk walk;
        int err;

        err = skcipher_walk_virt(&walk, req, false);

        while (walk.nbytes >= AES_BLOCK_SIZE) {
                unsigned int blocks = walk.nbytes / AES_BLOCK_SIZE;

                if (walk.nbytes < walk.total)
                        blocks = round_down(blocks,
                                            walk.stride / AES_BLOCK_SIZE);

                kernel_neon_begin();
                aesbs_cbc_decrypt(walk.dst.virt.addr, walk.src.virt.addr,
                                  ctx->key.rk, ctx->key.rounds, blocks,
                                  walk.iv);
                kernel_neon_end();
                err = skcipher_walk_done(&walk,
                                         walk.nbytes - blocks * AES_BLOCK_SIZE);
        }

        return err;
}

static int aesbs_ctr_setkey_sync(struct crypto_skcipher *tfm, const u8 *in_key,
                                 unsigned int key_len)
{
        struct aesbs_ctr_ctx *ctx = crypto_skcipher_ctx(tfm);
        int err;

        err = aes_expandkey(&ctx->fallback, in_key, key_len);
        if (err)
                return err;

        ctx->key.rounds = 6 + key_len / 4;

        kernel_neon_begin();
        aesbs_convert_key(ctx->key.rk, ctx->fallback.key_enc, ctx->key.rounds);
        kernel_neon_end();

        return 0;
}

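/*
 * CTR encryption. When the request ends in a partial block, the core
 * is passed a 'final' buffer to receive one extra block of keystream,
 * which is then XORed with the tail of the input by hand.
 */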
static int ctr_encrypt(struct skcipher_request *req)
{
        struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
        struct aesbs_ctx *ctx = crypto_skcipher_ctx(tfm);
        struct skcipher_walk walk;
        u8 buf[AES_BLOCK_SIZE];
        int err;

        err = skcipher_walk_virt(&walk, req, false);

        while (walk.nbytes > 0) {
                unsigned int blocks = walk.nbytes / AES_BLOCK_SIZE;
                u8 *final = (walk.total % AES_BLOCK_SIZE) ? buf : NULL;

                if (walk.nbytes < walk.total) {
                        blocks = round_down(blocks,
                                            walk.stride / AES_BLOCK_SIZE);
                        final = NULL;
                }

                kernel_neon_begin();
                aesbs_ctr_encrypt(walk.dst.virt.addr, walk.src.virt.addr,
                                  ctx->rk, ctx->rounds, blocks, walk.iv, final);
                kernel_neon_end();

                if (final) {
                        u8 *dst = walk.dst.virt.addr + blocks * AES_BLOCK_SIZE;
                        u8 *src = walk.src.virt.addr + blocks * AES_BLOCK_SIZE;

                        crypto_xor_cpy(dst, src, final,
                                       walk.total % AES_BLOCK_SIZE);

                        err = skcipher_walk_done(&walk, 0);
                        break;
                }
                err = skcipher_walk_done(&walk,
                                         walk.nbytes - blocks * AES_BLOCK_SIZE);
        }
        return err;
}

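/*
 * XTS keys consist of two halves: the first is the data key, which is
 * both converted for the bit-sliced core and expanded in the regular
 * format for the ciphertext stealing tail, and the second is the tweak
 * key, kept in the regular format for the plain NEON helpers.
 */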
static int aesbs_xts_setkey(struct crypto_skcipher *tfm, const u8 *in_key,
                            unsigned int key_len)
{
        struct aesbs_xts_ctx *ctx = crypto_skcipher_ctx(tfm);
        struct crypto_aes_ctx rk;
        int err;

        err = xts_verify_key(tfm, in_key, key_len);
        if (err)
                return err;

        key_len /= 2;
        err = aes_expandkey(&ctx->cts, in_key, key_len);
        if (err)
                return err;

        err = aes_expandkey(&rk, in_key + key_len, key_len);
        if (err)
                return err;

        memcpy(ctx->twkey, rk.key_enc, sizeof(ctx->twkey));
        /* wipe the expanded tweak key off the stack */
        memzero_explicit(&rk, sizeof(rk));

        return aesbs_setkey(tfm, in_key, key_len);
}

static void ctr_encrypt_one(struct crypto_skcipher *tfm, const u8 *src, u8 *dst)
{
        struct aesbs_ctr_ctx *ctx = crypto_skcipher_ctx(tfm);
        unsigned long flags;

        /*
         * Temporarily disable interrupts to avoid races where
         * cachelines are evicted when the CPU is interrupted
         * to do something else.
         */
        local_irq_save(flags);
        aes_encrypt(&ctx->fallback, dst, src);
        local_irq_restore(flags);
}

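/*
 * Synchronous ctr entry point: use the bit-sliced NEON code when the
 * SIMD unit may be used, and fall back to the generic AES library
 * code, one block at a time, otherwise.
 */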
static int ctr_encrypt_sync(struct skcipher_request *req)
{
        if (!crypto_simd_usable())
                return crypto_ctr_encrypt_walk(req, ctr_encrypt_one);

        return ctr_encrypt(req);
}

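/*
 * XTS en/decryption. The initial tweak is computed by encrypting the
 * IV using the plain NEON code. Bulk data goes through the bit-sliced
 * core whenever a walk step yields more than six blocks; smaller
 * chunks and the ciphertext stealing tail are handled by the plain
 * NEON helpers.
 */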
static int __xts_crypt(struct skcipher_request *req, bool encrypt,
                       void (*fn)(u8 out[], u8 const in[], u8 const rk[],
                                  int rounds, int blocks, u8 iv[]))
{
        struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
        struct aesbs_xts_ctx *ctx = crypto_skcipher_ctx(tfm);
        int tail = req->cryptlen % (8 * AES_BLOCK_SIZE);
        struct scatterlist sg_src[2], sg_dst[2];
        struct skcipher_request subreq;
        struct scatterlist *src, *dst;
        struct skcipher_walk walk;
        int nbytes, err;
        int first = 1;
        u8 *out, *in;

        if (req->cryptlen < AES_BLOCK_SIZE)
                return -EINVAL;

        /* ensure that the cts tail is covered by a single step */
        if (unlikely(tail > 0 && tail < AES_BLOCK_SIZE)) {
                int xts_blocks = DIV_ROUND_UP(req->cryptlen,
                                              AES_BLOCK_SIZE) - 2;

                skcipher_request_set_tfm(&subreq, tfm);
                skcipher_request_set_callback(&subreq,
                                              skcipher_request_flags(req),
                                              NULL, NULL);
                skcipher_request_set_crypt(&subreq, req->src, req->dst,
                                           xts_blocks * AES_BLOCK_SIZE,
                                           req->iv);
                req = &subreq;
        } else {
                tail = 0;
        }

        err = skcipher_walk_virt(&walk, req, false);
        if (err)
                return err;

        while (walk.nbytes >= AES_BLOCK_SIZE) {
                unsigned int blocks = walk.nbytes / AES_BLOCK_SIZE;

                if (walk.nbytes < walk.total || walk.nbytes % AES_BLOCK_SIZE)
                        blocks = round_down(blocks,
                                            walk.stride / AES_BLOCK_SIZE);

                out = walk.dst.virt.addr;
                in = walk.src.virt.addr;
                nbytes = walk.nbytes;

                kernel_neon_begin();
                if (likely(blocks > 6)) { /* plain NEON is faster otherwise */
                        if (first)
                                neon_aes_ecb_encrypt(walk.iv, walk.iv,
                                                     ctx->twkey,
                                                     ctx->key.rounds, 1);
                        first = 0;

                        fn(out, in, ctx->key.rk, ctx->key.rounds, blocks,
                           walk.iv);

                        out += blocks * AES_BLOCK_SIZE;
                        in += blocks * AES_BLOCK_SIZE;
                        nbytes -= blocks * AES_BLOCK_SIZE;
                }

                if (walk.nbytes == walk.total && nbytes > 0)
                        goto xts_tail;

                kernel_neon_end();
                err = skcipher_walk_done(&walk, nbytes);
        }

        if (err || likely(!tail))
                return err;

        /* handle ciphertext stealing */
        dst = src = scatterwalk_ffwd(sg_src, req->src, req->cryptlen);
        if (req->dst != req->src)
                dst = scatterwalk_ffwd(sg_dst, req->dst, req->cryptlen);

        skcipher_request_set_crypt(req, src, dst, AES_BLOCK_SIZE + tail,
                                   req->iv);

        err = skcipher_walk_virt(&walk, req, false);
        if (err)
                return err;

        out = walk.dst.virt.addr;
        in = walk.src.virt.addr;
        nbytes = walk.nbytes;

        kernel_neon_begin();
xts_tail:
        if (encrypt)
                neon_aes_xts_encrypt(out, in, ctx->cts.key_enc, ctx->key.rounds,
                                     nbytes, ctx->twkey, walk.iv, first ?: 2);
        else
                neon_aes_xts_decrypt(out, in, ctx->cts.key_dec, ctx->key.rounds,
                                     nbytes, ctx->twkey, walk.iv, first ?: 2);
        kernel_neon_end();

        return skcipher_walk_done(&walk, 0);
}

static int xts_encrypt(struct skcipher_request *req)
{
        return __xts_crypt(req, true, aesbs_xts_encrypt);
}

static int xts_decrypt(struct skcipher_request *req)
{
        return __xts_crypt(req, false, aesbs_xts_decrypt);
}

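/*
 * The "__" prefixed algorithms are flagged CRYPTO_ALG_INTERNAL and are
 * not directly usable; aes_init() below wraps each of them in a simd
 * skcipher instance that defers to an asynchronous helper whenever the
 * NEON unit cannot be used in the caller's context. Users reach them
 * under the generic names, e.g. via crypto_alloc_skcipher("xts(aes)",
 * 0, 0). Only the synchronous ctr variant is registered under its
 * plain name.
 */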
static struct skcipher_alg aes_algs[] = { {
        .base.cra_name          = "__ecb(aes)",
        .base.cra_driver_name   = "__ecb-aes-neonbs",
        .base.cra_priority      = 250,
        .base.cra_blocksize     = AES_BLOCK_SIZE,
        .base.cra_ctxsize       = sizeof(struct aesbs_ctx),
        .base.cra_module        = THIS_MODULE,
        .base.cra_flags         = CRYPTO_ALG_INTERNAL,

        .min_keysize            = AES_MIN_KEY_SIZE,
        .max_keysize            = AES_MAX_KEY_SIZE,
        .walksize               = 8 * AES_BLOCK_SIZE,
        .setkey                 = aesbs_setkey,
        .encrypt                = ecb_encrypt,
        .decrypt                = ecb_decrypt,
}, {
        .base.cra_name          = "__cbc(aes)",
        .base.cra_driver_name   = "__cbc-aes-neonbs",
        .base.cra_priority      = 250,
        .base.cra_blocksize     = AES_BLOCK_SIZE,
        .base.cra_ctxsize       = sizeof(struct aesbs_cbc_ctx),
        .base.cra_module        = THIS_MODULE,
        .base.cra_flags         = CRYPTO_ALG_INTERNAL,

        .min_keysize            = AES_MIN_KEY_SIZE,
        .max_keysize            = AES_MAX_KEY_SIZE,
        .walksize               = 8 * AES_BLOCK_SIZE,
        .ivsize                 = AES_BLOCK_SIZE,
        .setkey                 = aesbs_cbc_setkey,
        .encrypt                = cbc_encrypt,
        .decrypt                = cbc_decrypt,
}, {
        .base.cra_name          = "__ctr(aes)",
        .base.cra_driver_name   = "__ctr-aes-neonbs",
        .base.cra_priority      = 250,
        .base.cra_blocksize     = 1,
        .base.cra_ctxsize       = sizeof(struct aesbs_ctx),
        .base.cra_module        = THIS_MODULE,
        .base.cra_flags         = CRYPTO_ALG_INTERNAL,

        .min_keysize            = AES_MIN_KEY_SIZE,
        .max_keysize            = AES_MAX_KEY_SIZE,
        .chunksize              = AES_BLOCK_SIZE,
        .walksize               = 8 * AES_BLOCK_SIZE,
        .ivsize                 = AES_BLOCK_SIZE,
        .setkey                 = aesbs_setkey,
        .encrypt                = ctr_encrypt,
        .decrypt                = ctr_encrypt,
}, {
        .base.cra_name          = "ctr(aes)",
        .base.cra_driver_name   = "ctr-aes-neonbs",
        .base.cra_priority      = 250 - 1,
        .base.cra_blocksize     = 1,
        .base.cra_ctxsize       = sizeof(struct aesbs_ctr_ctx),
        .base.cra_module        = THIS_MODULE,

        .min_keysize            = AES_MIN_KEY_SIZE,
        .max_keysize            = AES_MAX_KEY_SIZE,
        .chunksize              = AES_BLOCK_SIZE,
        .walksize               = 8 * AES_BLOCK_SIZE,
        .ivsize                 = AES_BLOCK_SIZE,
        .setkey                 = aesbs_ctr_setkey_sync,
        .encrypt                = ctr_encrypt_sync,
        .decrypt                = ctr_encrypt_sync,
}, {
        .base.cra_name          = "__xts(aes)",
        .base.cra_driver_name   = "__xts-aes-neonbs",
        .base.cra_priority      = 250,
        .base.cra_blocksize     = AES_BLOCK_SIZE,
        .base.cra_ctxsize       = sizeof(struct aesbs_xts_ctx),
        .base.cra_module        = THIS_MODULE,
        .base.cra_flags         = CRYPTO_ALG_INTERNAL,

        .min_keysize            = 2 * AES_MIN_KEY_SIZE,
        .max_keysize            = 2 * AES_MAX_KEY_SIZE,
        .walksize               = 8 * AES_BLOCK_SIZE,
        .ivsize                 = AES_BLOCK_SIZE,
        .setkey                 = aesbs_xts_setkey,
        .encrypt                = xts_encrypt,
        .decrypt                = xts_decrypt,
} };

static struct simd_skcipher_alg *aes_simd_algs[ARRAY_SIZE(aes_algs)];

static void aes_exit(void)
{
        int i;

        for (i = 0; i < ARRAY_SIZE(aes_simd_algs); i++)
                if (aes_simd_algs[i])
                        simd_skcipher_free(aes_simd_algs[i]);

        crypto_unregister_skciphers(aes_algs, ARRAY_SIZE(aes_algs));
}

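/*
 * Register all algorithms, then create a simd wrapper around each
 * internal one. The wrapper names are derived by stripping the "__"
 * prefix from the internal names.
 */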
static int __init aes_init(void)
{
        struct simd_skcipher_alg *simd;
        const char *basename;
        const char *algname;
        const char *drvname;
        int err;
        int i;

        if (!cpu_have_named_feature(ASIMD))
                return -ENODEV;

        err = crypto_register_skciphers(aes_algs, ARRAY_SIZE(aes_algs));
        if (err)
                return err;

        for (i = 0; i < ARRAY_SIZE(aes_algs); i++) {
                if (!(aes_algs[i].base.cra_flags & CRYPTO_ALG_INTERNAL))
                        continue;

                algname = aes_algs[i].base.cra_name + 2;
                drvname = aes_algs[i].base.cra_driver_name + 2;
                basename = aes_algs[i].base.cra_driver_name;
                simd = simd_skcipher_create_compat(algname, drvname, basename);
                err = PTR_ERR(simd);
                if (IS_ERR(simd))
                        goto unregister_simds;

                aes_simd_algs[i] = simd;
        }
        return 0;

unregister_simds:
        aes_exit();
        return err;
}

module_init(aes_init);
module_exit(aes_exit);