GNU Linux-libre 5.4.274-gnu1: arch/arm64/crypto/aes-neonbs-glue.c
// SPDX-License-Identifier: GPL-2.0-only
/*
 * Bit sliced AES using NEON instructions
 *
 * Copyright (C) 2016 - 2017 Linaro Ltd <ard.biesheuvel@linaro.org>
 */

#include <asm/neon.h>
#include <asm/simd.h>
#include <crypto/aes.h>
#include <crypto/ctr.h>
#include <crypto/internal/simd.h>
#include <crypto/internal/skcipher.h>
#include <crypto/scatterwalk.h>
#include <crypto/xts.h>
#include <linux/module.h>

MODULE_AUTHOR("Ard Biesheuvel <ard.biesheuvel@linaro.org>");
MODULE_LICENSE("GPL v2");

MODULE_ALIAS_CRYPTO("ecb(aes)");
MODULE_ALIAS_CRYPTO("cbc(aes)");
MODULE_ALIAS_CRYPTO("ctr(aes)");
MODULE_ALIAS_CRYPTO("xts(aes)");

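/*
 * The aesbs_ prefixed routines below are implemented in assembly (in
 * aes-neonbs-core.S) and process data in bit-sliced form, eight blocks
 * at a time; hence the 8 * AES_BLOCK_SIZE walk size used throughout
 * this file. The neon_aes_ routines are ordinary (non-bit-sliced) NEON
 * primitives, reused here for the inherently serial or sub-block cases
 * that bit-slicing cannot cover.
 */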
asmlinkage void aesbs_convert_key(u8 out[], u32 const rk[], int rounds);

asmlinkage void aesbs_ecb_encrypt(u8 out[], u8 const in[], u8 const rk[],
                                  int rounds, int blocks);
asmlinkage void aesbs_ecb_decrypt(u8 out[], u8 const in[], u8 const rk[],
                                  int rounds, int blocks);

asmlinkage void aesbs_cbc_decrypt(u8 out[], u8 const in[], u8 const rk[],
                                  int rounds, int blocks, u8 iv[]);

asmlinkage void aesbs_ctr_encrypt(u8 out[], u8 const in[], u8 const rk[],
                                  int rounds, int blocks, u8 iv[], u8 final[]);

asmlinkage void aesbs_xts_encrypt(u8 out[], u8 const in[], u8 const rk[],
                                  int rounds, int blocks, u8 iv[]);
asmlinkage void aesbs_xts_decrypt(u8 out[], u8 const in[], u8 const rk[],
                                  int rounds, int blocks, u8 iv[]);

/* borrowed from aes-neon-blk.ko */
asmlinkage void neon_aes_ecb_encrypt(u8 out[], u8 const in[], u32 const rk[],
                                     int rounds, int blocks);
asmlinkage void neon_aes_cbc_encrypt(u8 out[], u8 const in[], u32 const rk[],
                                     int rounds, int blocks, u8 iv[]);
asmlinkage void neon_aes_xts_encrypt(u8 out[], u8 const in[],
                                     u32 const rk1[], int rounds, int bytes,
                                     u32 const rk2[], u8 iv[], int first);
asmlinkage void neon_aes_xts_decrypt(u8 out[], u8 const in[],
                                     u32 const rk1[], int rounds, int bytes,
                                     u32 const rk2[], u8 iv[], int first);

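/*
 * Holds the key schedule in the bit-sliced layout produced by
 * aesbs_convert_key(): room for up to 13 round keys of
 * 8 * AES_BLOCK_SIZE bytes each (AES-256 uses the most rounds, 14),
 * plus 32 trailing bytes. The exact layout is defined by the assembly
 * core, not here.
 */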
struct aesbs_ctx {
        u8      rk[13 * (8 * AES_BLOCK_SIZE) + 32];
        int     rounds;
} __aligned(AES_BLOCK_SIZE);

struct aesbs_cbc_ctx {
        struct aesbs_ctx        key;
        u32                     enc[AES_MAX_KEYLENGTH_U32];
};

struct aesbs_ctr_ctx {
        struct aesbs_ctx        key;            /* must be first member */
        struct crypto_aes_ctx   fallback;
};

struct aesbs_xts_ctx {
        struct aesbs_ctx        key;
        u32                     twkey[AES_MAX_KEYLENGTH_U32];
        struct crypto_aes_ctx   cts;
};

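/*
 * The round count follows directly from the key length, per the AES
 * specification:
 *
 *   AES-128: 6 + 16 / 4 = 10 rounds
 *   AES-192: 6 + 24 / 4 = 12 rounds
 *   AES-256: 6 + 32 / 4 = 14 rounds
 *
 * aes_expandkey() has already rejected invalid key lengths by the time
 * the division below runs.
 */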
static int aesbs_setkey(struct crypto_skcipher *tfm, const u8 *in_key,
                        unsigned int key_len)
{
        struct aesbs_ctx *ctx = crypto_skcipher_ctx(tfm);
        struct crypto_aes_ctx rk;
        int err;

        err = aes_expandkey(&rk, in_key, key_len);
        if (err)
                return err;

        ctx->rounds = 6 + key_len / 4;

        kernel_neon_begin();
        aesbs_convert_key(ctx->rk, rk.key_enc, ctx->rounds);
        kernel_neon_end();

        return 0;
}

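/*
 * Walk the request and hand full blocks to the bit-sliced core. On all
 * but the final chunk, the block count is rounded down to a multiple of
 * the walk stride (eight blocks), so the assembly always sees complete
 * bundles; skcipher_walk_done() carries any remainder over to the next
 * iteration.
 */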
static int __ecb_crypt(struct skcipher_request *req,
                       void (*fn)(u8 out[], u8 const in[], u8 const rk[],
                                  int rounds, int blocks))
{
        struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
        struct aesbs_ctx *ctx = crypto_skcipher_ctx(tfm);
        struct skcipher_walk walk;
        int err;

        err = skcipher_walk_virt(&walk, req, false);

        while (walk.nbytes >= AES_BLOCK_SIZE) {
                unsigned int blocks = walk.nbytes / AES_BLOCK_SIZE;

                if (walk.nbytes < walk.total)
                        blocks = round_down(blocks,
                                            walk.stride / AES_BLOCK_SIZE);

                kernel_neon_begin();
                fn(walk.dst.virt.addr, walk.src.virt.addr, ctx->rk,
                   ctx->rounds, blocks);
                kernel_neon_end();
                err = skcipher_walk_done(&walk,
                                         walk.nbytes - blocks * AES_BLOCK_SIZE);
        }

        return err;
}

static int ecb_encrypt(struct skcipher_request *req)
{
        return __ecb_crypt(req, aesbs_ecb_encrypt);
}

static int ecb_decrypt(struct skcipher_request *req)
{
        return __ecb_crypt(req, aesbs_ecb_decrypt);
}

static int aesbs_cbc_setkey(struct crypto_skcipher *tfm, const u8 *in_key,
                            unsigned int key_len)
{
        struct aesbs_cbc_ctx *ctx = crypto_skcipher_ctx(tfm);
        struct crypto_aes_ctx rk;
        int err;

        err = aes_expandkey(&rk, in_key, key_len);
        if (err)
                return err;

        ctx->key.rounds = 6 + key_len / 4;

        memcpy(ctx->enc, rk.key_enc, sizeof(ctx->enc));

        kernel_neon_begin();
        aesbs_convert_key(ctx->key.rk, rk.key_enc, ctx->key.rounds);
        kernel_neon_end();

        return 0;
}

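/*
 * CBC encryption is inherently serial: each block's input is the
 * previous block's ciphertext, so eight blocks can never be computed in
 * parallel and bit-slicing buys nothing. Use the plain NEON
 * implementation with the untransformed key schedule (ctx->enc)
 * instead.
 */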
static int cbc_encrypt(struct skcipher_request *req)
{
        struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
        struct aesbs_cbc_ctx *ctx = crypto_skcipher_ctx(tfm);
        struct skcipher_walk walk;
        int err;

        err = skcipher_walk_virt(&walk, req, false);

        while (walk.nbytes >= AES_BLOCK_SIZE) {
                unsigned int blocks = walk.nbytes / AES_BLOCK_SIZE;

                /* fall back to the non-bitsliced NEON implementation */
                kernel_neon_begin();
                neon_aes_cbc_encrypt(walk.dst.virt.addr, walk.src.virt.addr,
                                     ctx->enc, ctx->key.rounds, blocks,
                                     walk.iv);
                kernel_neon_end();
                err = skcipher_walk_done(&walk, walk.nbytes % AES_BLOCK_SIZE);
        }
        return err;
}

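/*
 * CBC decryption, by contrast, parallelizes cleanly: all ciphertext
 * blocks are available up front, so the bit-sliced core can decrypt
 * eight at a time, with the XOR chaining handled inside the assembly
 * routine via the IV it is passed.
 */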
static int cbc_decrypt(struct skcipher_request *req)
{
        struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
        struct aesbs_cbc_ctx *ctx = crypto_skcipher_ctx(tfm);
        struct skcipher_walk walk;
        int err;

        err = skcipher_walk_virt(&walk, req, false);

        while (walk.nbytes >= AES_BLOCK_SIZE) {
                unsigned int blocks = walk.nbytes / AES_BLOCK_SIZE;

                if (walk.nbytes < walk.total)
                        blocks = round_down(blocks,
                                            walk.stride / AES_BLOCK_SIZE);

                kernel_neon_begin();
                aesbs_cbc_decrypt(walk.dst.virt.addr, walk.src.virt.addr,
                                  ctx->key.rk, ctx->key.rounds, blocks,
                                  walk.iv);
                kernel_neon_end();
                err = skcipher_walk_done(&walk,
                                         walk.nbytes - blocks * AES_BLOCK_SIZE);
        }

        return err;
}

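/*
 * The synchronous CTR variant keeps a full generic key schedule
 * (ctx->fallback) alongside the bit-sliced one, so that blocks can
 * still be encrypted with the scalar aes_encrypt() when the NEON unit
 * is off limits.
 */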
static int aesbs_ctr_setkey_sync(struct crypto_skcipher *tfm, const u8 *in_key,
                                 unsigned int key_len)
{
        struct aesbs_ctr_ctx *ctx = crypto_skcipher_ctx(tfm);
        int err;

        err = aes_expandkey(&ctx->fallback, in_key, key_len);
        if (err)
                return err;

        ctx->key.rounds = 6 + key_len / 4;

        kernel_neon_begin();
        aesbs_convert_key(ctx->key.rk, ctx->fallback.key_enc, ctx->key.rounds);
        kernel_neon_end();

        return 0;
}

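/*
 * CTR is a stream cipher, so requests need not be block aligned. When
 * the request ends in a partial block, 'final' points at a stack buffer
 * and the assembly deposits the keystream for the last counter block
 * there; the tail is then produced by XORing that keystream into the
 * remaining source bytes with crypto_xor_cpy(). On intermediate chunks
 * 'final' is forced to NULL, so the partial-block path only triggers at
 * the very end of the request.
 */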
static int ctr_encrypt(struct skcipher_request *req)
{
        struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
        struct aesbs_ctx *ctx = crypto_skcipher_ctx(tfm);
        struct skcipher_walk walk;
        u8 buf[AES_BLOCK_SIZE];
        int err;

        err = skcipher_walk_virt(&walk, req, false);

        while (walk.nbytes > 0) {
                unsigned int blocks = walk.nbytes / AES_BLOCK_SIZE;
                u8 *final = (walk.total % AES_BLOCK_SIZE) ? buf : NULL;

                if (walk.nbytes < walk.total) {
                        blocks = round_down(blocks,
                                            walk.stride / AES_BLOCK_SIZE);
                        final = NULL;
                }

                kernel_neon_begin();
                aesbs_ctr_encrypt(walk.dst.virt.addr, walk.src.virt.addr,
                                  ctx->rk, ctx->rounds, blocks, walk.iv, final);
                kernel_neon_end();

                if (final) {
                        u8 *dst = walk.dst.virt.addr + blocks * AES_BLOCK_SIZE;
                        u8 *src = walk.src.virt.addr + blocks * AES_BLOCK_SIZE;

                        crypto_xor_cpy(dst, src, final,
                                       walk.total % AES_BLOCK_SIZE);

                        err = skcipher_walk_done(&walk, 0);
                        break;
                }
                err = skcipher_walk_done(&walk,
                                         walk.nbytes - blocks * AES_BLOCK_SIZE);
        }
        return err;
}

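/*
 * An XTS key is two AES keys of equal size concatenated: the first half
 * encrypts the data (converted to bit-sliced form via aesbs_setkey()
 * below), the second half encrypts the tweak (its schedule lands in
 * ctx->twkey for the plain NEON helpers). The generic schedule of the
 * data key is also kept in ctx->cts for the ciphertext stealing tail.
 * xts_verify_key() rejects invalid lengths (and, in FIPS mode,
 * identical key halves).
 */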
static int aesbs_xts_setkey(struct crypto_skcipher *tfm, const u8 *in_key,
                            unsigned int key_len)
{
        struct aesbs_xts_ctx *ctx = crypto_skcipher_ctx(tfm);
        struct crypto_aes_ctx rk;
        int err;

        err = xts_verify_key(tfm, in_key, key_len);
        if (err)
                return err;

        key_len /= 2;
        err = aes_expandkey(&ctx->cts, in_key, key_len);
        if (err)
                return err;

        err = aes_expandkey(&rk, in_key + key_len, key_len);
        if (err)
                return err;

        memcpy(ctx->twkey, rk.key_enc, sizeof(ctx->twkey));

        return aesbs_setkey(tfm, in_key, key_len);
}

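/*
 * Scalar fallback for the synchronous CTR algorithm, used when the NEON
 * unit cannot be claimed (e.g. in hard IRQ context): encrypt one
 * counter block at a time with the generic AES library code.
 */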
static void ctr_encrypt_one(struct crypto_skcipher *tfm, const u8 *src, u8 *dst)
{
        struct aesbs_ctr_ctx *ctx = crypto_skcipher_ctx(tfm);
        unsigned long flags;

        /*
         * Temporarily disable interrupts to avoid races where
         * cachelines are evicted when the CPU is interrupted
         * to do something else.
         */
        local_irq_save(flags);
        aes_encrypt(&ctx->fallback, dst, src);
        local_irq_restore(flags);
}

static int ctr_encrypt_sync(struct skcipher_request *req)
{
        if (!crypto_simd_usable())
                return crypto_ctr_encrypt_walk(req, ctr_encrypt_one);

        return ctr_encrypt(req);
}

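/*
 * XTS bulk processing with ciphertext stealing (CTS) for requests that
 * are not a multiple of the block size. Chunks of more than six full
 * blocks go through the bit-sliced routine (plain NEON wins below that
 * threshold); whatever is left over, including the CTS tail, is handled
 * by neon_aes_xts_{en,de}crypt(), which accepts arbitrary byte counts.
 * When the length modulo the walk size leaves a sub-block tail, the
 * request is first shortened so that the final step spans exactly one
 * full block plus the tail, as CTS requires.
 */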
static int __xts_crypt(struct skcipher_request *req, bool encrypt,
                       void (*fn)(u8 out[], u8 const in[], u8 const rk[],
                                  int rounds, int blocks, u8 iv[]))
{
        struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
        struct aesbs_xts_ctx *ctx = crypto_skcipher_ctx(tfm);
        int tail = req->cryptlen % (8 * AES_BLOCK_SIZE);
        struct scatterlist sg_src[2], sg_dst[2];
        struct skcipher_request subreq;
        struct scatterlist *src, *dst;
        struct skcipher_walk walk;
        int nbytes, err;
        int first = 1;
        u8 *out, *in;

        if (req->cryptlen < AES_BLOCK_SIZE)
                return -EINVAL;

        /* ensure that the cts tail is covered by a single step */
        if (unlikely(tail > 0 && tail < AES_BLOCK_SIZE)) {
                int xts_blocks = DIV_ROUND_UP(req->cryptlen,
                                              AES_BLOCK_SIZE) - 2;

                skcipher_request_set_tfm(&subreq, tfm);
                skcipher_request_set_callback(&subreq,
                                              skcipher_request_flags(req),
                                              NULL, NULL);
                skcipher_request_set_crypt(&subreq, req->src, req->dst,
                                           xts_blocks * AES_BLOCK_SIZE,
                                           req->iv);
                req = &subreq;
        } else {
                tail = 0;
        }

        err = skcipher_walk_virt(&walk, req, false);
        if (err)
                return err;

        while (walk.nbytes >= AES_BLOCK_SIZE) {
                unsigned int blocks = walk.nbytes / AES_BLOCK_SIZE;

                if (walk.nbytes < walk.total || walk.nbytes % AES_BLOCK_SIZE)
                        blocks = round_down(blocks,
                                            walk.stride / AES_BLOCK_SIZE);

                out = walk.dst.virt.addr;
                in = walk.src.virt.addr;
                nbytes = walk.nbytes;

                kernel_neon_begin();
                if (likely(blocks > 6)) { /* plain NEON is faster otherwise */
                        if (first)
                                neon_aes_ecb_encrypt(walk.iv, walk.iv,
                                                     ctx->twkey,
                                                     ctx->key.rounds, 1);
                        first = 0;

                        fn(out, in, ctx->key.rk, ctx->key.rounds, blocks,
                           walk.iv);

                        out += blocks * AES_BLOCK_SIZE;
                        in += blocks * AES_BLOCK_SIZE;
                        nbytes -= blocks * AES_BLOCK_SIZE;
                }

                if (walk.nbytes == walk.total && nbytes > 0)
                        goto xts_tail;

                kernel_neon_end();
                err = skcipher_walk_done(&walk, nbytes);
        }

        if (err || likely(!tail))
                return err;

        /* handle ciphertext stealing */
        dst = src = scatterwalk_ffwd(sg_src, req->src, req->cryptlen);
        if (req->dst != req->src)
                dst = scatterwalk_ffwd(sg_dst, req->dst, req->cryptlen);

        skcipher_request_set_crypt(req, src, dst, AES_BLOCK_SIZE + tail,
                                   req->iv);

        err = skcipher_walk_virt(&walk, req, false);
        if (err)
                return err;

        out = walk.dst.virt.addr;
        in = walk.src.virt.addr;
        nbytes = walk.nbytes;

        kernel_neon_begin();
xts_tail:
        if (encrypt)
                neon_aes_xts_encrypt(out, in, ctx->cts.key_enc, ctx->key.rounds,
                                     nbytes, ctx->twkey, walk.iv, first ?: 2);
        else
                neon_aes_xts_decrypt(out, in, ctx->cts.key_dec, ctx->key.rounds,
                                     nbytes, ctx->twkey, walk.iv, first ?: 2);
        kernel_neon_end();

        return skcipher_walk_done(&walk, 0);
}

static int xts_encrypt(struct skcipher_request *req)
{
        return __xts_crypt(req, true, aesbs_xts_encrypt);
}

static int xts_decrypt(struct skcipher_request *req)
{
        return __xts_crypt(req, false, aesbs_xts_decrypt);
}

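/*
 * The "__" prefixed entries are flagged CRYPTO_ALG_INTERNAL: they may
 * only run where the NEON unit is usable and are never handed out
 * directly. aes_init() wraps each of them in a simd skcipher that
 * defers to cryptd when SIMD is unavailable. The one bare entry,
 * "ctr(aes)", is registered at priority 250 - 1, one below the wrapped
 * variants, as a fully synchronous alternative with its own scalar
 * fallback for callers that cannot tolerate asynchronous completion.
 */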
static struct skcipher_alg aes_algs[] = { {
        .base.cra_name          = "__ecb(aes)",
        .base.cra_driver_name   = "__ecb-aes-neonbs",
        .base.cra_priority      = 250,
        .base.cra_blocksize     = AES_BLOCK_SIZE,
        .base.cra_ctxsize       = sizeof(struct aesbs_ctx),
        .base.cra_module        = THIS_MODULE,
        .base.cra_flags         = CRYPTO_ALG_INTERNAL,

        .min_keysize            = AES_MIN_KEY_SIZE,
        .max_keysize            = AES_MAX_KEY_SIZE,
        .walksize               = 8 * AES_BLOCK_SIZE,
        .setkey                 = aesbs_setkey,
        .encrypt                = ecb_encrypt,
        .decrypt                = ecb_decrypt,
}, {
        .base.cra_name          = "__cbc(aes)",
        .base.cra_driver_name   = "__cbc-aes-neonbs",
        .base.cra_priority      = 250,
        .base.cra_blocksize     = AES_BLOCK_SIZE,
        .base.cra_ctxsize       = sizeof(struct aesbs_cbc_ctx),
        .base.cra_module        = THIS_MODULE,
        .base.cra_flags         = CRYPTO_ALG_INTERNAL,

        .min_keysize            = AES_MIN_KEY_SIZE,
        .max_keysize            = AES_MAX_KEY_SIZE,
        .walksize               = 8 * AES_BLOCK_SIZE,
        .ivsize                 = AES_BLOCK_SIZE,
        .setkey                 = aesbs_cbc_setkey,
        .encrypt                = cbc_encrypt,
        .decrypt                = cbc_decrypt,
}, {
        .base.cra_name          = "__ctr(aes)",
        .base.cra_driver_name   = "__ctr-aes-neonbs",
        .base.cra_priority      = 250,
        .base.cra_blocksize     = 1,
        .base.cra_ctxsize       = sizeof(struct aesbs_ctx),
        .base.cra_module        = THIS_MODULE,
        .base.cra_flags         = CRYPTO_ALG_INTERNAL,

        .min_keysize            = AES_MIN_KEY_SIZE,
        .max_keysize            = AES_MAX_KEY_SIZE,
        .chunksize              = AES_BLOCK_SIZE,
        .walksize               = 8 * AES_BLOCK_SIZE,
        .ivsize                 = AES_BLOCK_SIZE,
        .setkey                 = aesbs_setkey,
        .encrypt                = ctr_encrypt,
        .decrypt                = ctr_encrypt,
}, {
        .base.cra_name          = "ctr(aes)",
        .base.cra_driver_name   = "ctr-aes-neonbs",
        .base.cra_priority      = 250 - 1,
        .base.cra_blocksize     = 1,
        .base.cra_ctxsize       = sizeof(struct aesbs_ctr_ctx),
        .base.cra_module        = THIS_MODULE,

        .min_keysize            = AES_MIN_KEY_SIZE,
        .max_keysize            = AES_MAX_KEY_SIZE,
        .chunksize              = AES_BLOCK_SIZE,
        .walksize               = 8 * AES_BLOCK_SIZE,
        .ivsize                 = AES_BLOCK_SIZE,
        .setkey                 = aesbs_ctr_setkey_sync,
        .encrypt                = ctr_encrypt_sync,
        .decrypt                = ctr_encrypt_sync,
}, {
        .base.cra_name          = "__xts(aes)",
        .base.cra_driver_name   = "__xts-aes-neonbs",
        .base.cra_priority      = 250,
        .base.cra_blocksize     = AES_BLOCK_SIZE,
        .base.cra_ctxsize       = sizeof(struct aesbs_xts_ctx),
        .base.cra_module        = THIS_MODULE,
        .base.cra_flags         = CRYPTO_ALG_INTERNAL,

        .min_keysize            = 2 * AES_MIN_KEY_SIZE,
        .max_keysize            = 2 * AES_MAX_KEY_SIZE,
        .walksize               = 8 * AES_BLOCK_SIZE,
        .ivsize                 = AES_BLOCK_SIZE,
        .setkey                 = aesbs_xts_setkey,
        .encrypt                = xts_encrypt,
        .decrypt                = xts_decrypt,
} };

static struct simd_skcipher_alg *aes_simd_algs[ARRAY_SIZE(aes_algs)];

static void aes_exit(void)
{
        int i;

        for (i = 0; i < ARRAY_SIZE(aes_simd_algs); i++)
                if (aes_simd_algs[i])
                        simd_skcipher_free(aes_simd_algs[i]);

        crypto_unregister_skciphers(aes_algs, ARRAY_SIZE(aes_algs));
}

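/*
 * Register everything, then create a simd wrapper for each internal
 * algorithm. The "+ 2" below strips the "__" prefix, so e.g. the
 * internal "__ecb-aes-neonbs" yields a user-visible "ecb(aes)" wrapper
 * named "ecb-aes-neonbs". For illustration (not part of this file), a
 * caller allocating a transform the usual way, e.g.
 *
 *      tfm = crypto_alloc_skcipher("ecb(aes)", 0, 0);
 *
 * may then, priority permitting, resolve to the wrapper, which invokes
 * the internal algorithm directly when SIMD is usable and punts to
 * cryptd otherwise.
 */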
static int __init aes_init(void)
{
        struct simd_skcipher_alg *simd;
        const char *basename;
        const char *algname;
        const char *drvname;
        int err;
        int i;

        if (!cpu_have_named_feature(ASIMD))
                return -ENODEV;

        err = crypto_register_skciphers(aes_algs, ARRAY_SIZE(aes_algs));
        if (err)
                return err;

        for (i = 0; i < ARRAY_SIZE(aes_algs); i++) {
                if (!(aes_algs[i].base.cra_flags & CRYPTO_ALG_INTERNAL))
                        continue;

                algname = aes_algs[i].base.cra_name + 2;
                drvname = aes_algs[i].base.cra_driver_name + 2;
                basename = aes_algs[i].base.cra_driver_name;
                simd = simd_skcipher_create_compat(algname, drvname, basename);
                err = PTR_ERR(simd);
                if (IS_ERR(simd))
                        goto unregister_simds;

                aes_simd_algs[i] = simd;
        }
        return 0;

unregister_simds:
        aes_exit();
        return err;
}

module_init(aes_init);
module_exit(aes_exit);