GNU Linux-libre 5.19-rc6-gnu
[releases.git] / arch / arm / crypto / aes-neonbs-glue.c
1 // SPDX-License-Identifier: GPL-2.0-only
2 /*
3  * Bit sliced AES using NEON instructions
4  *
5  * Copyright (C) 2017 Linaro Ltd <ard.biesheuvel@linaro.org>
6  */
7
8 #include <asm/neon.h>
9 #include <asm/simd.h>
10 #include <crypto/aes.h>
11 #include <crypto/ctr.h>
12 #include <crypto/internal/cipher.h>
13 #include <crypto/internal/simd.h>
14 #include <crypto/internal/skcipher.h>
15 #include <crypto/scatterwalk.h>
16 #include <crypto/xts.h>
17 #include <linux/module.h>
18
19 MODULE_AUTHOR("Ard Biesheuvel <ard.biesheuvel@linaro.org>");
20 MODULE_LICENSE("GPL v2");
21
22 MODULE_ALIAS_CRYPTO("ecb(aes)");
23 MODULE_ALIAS_CRYPTO("cbc(aes)-all");
24 MODULE_ALIAS_CRYPTO("ctr(aes)");
25 MODULE_ALIAS_CRYPTO("xts(aes)");
26
27 MODULE_IMPORT_NS(CRYPTO_INTERNAL);
28
29 asmlinkage void aesbs_convert_key(u8 out[], u32 const rk[], int rounds);
30
31 asmlinkage void aesbs_ecb_encrypt(u8 out[], u8 const in[], u8 const rk[],
32                                   int rounds, int blocks);
33 asmlinkage void aesbs_ecb_decrypt(u8 out[], u8 const in[], u8 const rk[],
34                                   int rounds, int blocks);
35
36 asmlinkage void aesbs_cbc_decrypt(u8 out[], u8 const in[], u8 const rk[],
37                                   int rounds, int blocks, u8 iv[]);
38
39 asmlinkage void aesbs_ctr_encrypt(u8 out[], u8 const in[], u8 const rk[],
40                                   int rounds, int blocks, u8 ctr[]);
41
42 asmlinkage void aesbs_xts_encrypt(u8 out[], u8 const in[], u8 const rk[],
43                                   int rounds, int blocks, u8 iv[], int);
44 asmlinkage void aesbs_xts_decrypt(u8 out[], u8 const in[], u8 const rk[],
45                                   int rounds, int blocks, u8 iv[], int);
46
47 struct aesbs_ctx {
48         int     rounds;
49         u8      rk[13 * (8 * AES_BLOCK_SIZE) + 32] __aligned(AES_BLOCK_SIZE);
50 };
51
52 struct aesbs_cbc_ctx {
53         struct aesbs_ctx        key;
54         struct crypto_skcipher  *enc_tfm;
55 };
56
57 struct aesbs_xts_ctx {
58         struct aesbs_ctx        key;
59         struct crypto_cipher    *cts_tfm;
60         struct crypto_cipher    *tweak_tfm;
61 };
62
63 struct aesbs_ctr_ctx {
64         struct aesbs_ctx        key;            /* must be first member */
65         struct crypto_aes_ctx   fallback;
66 };
67
68 static int aesbs_setkey(struct crypto_skcipher *tfm, const u8 *in_key,
69                         unsigned int key_len)
70 {
71         struct aesbs_ctx *ctx = crypto_skcipher_ctx(tfm);
72         struct crypto_aes_ctx rk;
73         int err;
74
75         err = aes_expandkey(&rk, in_key, key_len);
76         if (err)
77                 return err;
78
79         ctx->rounds = 6 + key_len / 4;
80
81         kernel_neon_begin();
82         aesbs_convert_key(ctx->rk, rk.key_enc, ctx->rounds);
83         kernel_neon_end();
84
85         return 0;
86 }
87
88 static int __ecb_crypt(struct skcipher_request *req,
89                        void (*fn)(u8 out[], u8 const in[], u8 const rk[],
90                                   int rounds, int blocks))
91 {
92         struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
93         struct aesbs_ctx *ctx = crypto_skcipher_ctx(tfm);
94         struct skcipher_walk walk;
95         int err;
96
97         err = skcipher_walk_virt(&walk, req, false);
98
99         while (walk.nbytes >= AES_BLOCK_SIZE) {
100                 unsigned int blocks = walk.nbytes / AES_BLOCK_SIZE;
101
102                 if (walk.nbytes < walk.total)
103                         blocks = round_down(blocks,
104                                             walk.stride / AES_BLOCK_SIZE);
105
106                 kernel_neon_begin();
107                 fn(walk.dst.virt.addr, walk.src.virt.addr, ctx->rk,
108                    ctx->rounds, blocks);
109                 kernel_neon_end();
110                 err = skcipher_walk_done(&walk,
111                                          walk.nbytes - blocks * AES_BLOCK_SIZE);
112         }
113
114         return err;
115 }
116
117 static int ecb_encrypt(struct skcipher_request *req)
118 {
119         return __ecb_crypt(req, aesbs_ecb_encrypt);
120 }
121
122 static int ecb_decrypt(struct skcipher_request *req)
123 {
124         return __ecb_crypt(req, aesbs_ecb_decrypt);
125 }
126
127 static int aesbs_cbc_setkey(struct crypto_skcipher *tfm, const u8 *in_key,
128                             unsigned int key_len)
129 {
130         struct aesbs_cbc_ctx *ctx = crypto_skcipher_ctx(tfm);
131         struct crypto_aes_ctx rk;
132         int err;
133
134         err = aes_expandkey(&rk, in_key, key_len);
135         if (err)
136                 return err;
137
138         ctx->key.rounds = 6 + key_len / 4;
139
140         kernel_neon_begin();
141         aesbs_convert_key(ctx->key.rk, rk.key_enc, ctx->key.rounds);
142         kernel_neon_end();
143         memzero_explicit(&rk, sizeof(rk));
144
145         return crypto_skcipher_setkey(ctx->enc_tfm, in_key, key_len);
146 }
147
148 static int cbc_encrypt(struct skcipher_request *req)
149 {
150         struct skcipher_request *subreq = skcipher_request_ctx(req);
151         struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
152         struct aesbs_cbc_ctx *ctx = crypto_skcipher_ctx(tfm);
153
154         skcipher_request_set_tfm(subreq, ctx->enc_tfm);
155         skcipher_request_set_callback(subreq,
156                                       skcipher_request_flags(req),
157                                       NULL, NULL);
158         skcipher_request_set_crypt(subreq, req->src, req->dst,
159                                    req->cryptlen, req->iv);
160
161         return crypto_skcipher_encrypt(subreq);
162 }
163
164 static int cbc_decrypt(struct skcipher_request *req)
165 {
166         struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
167         struct aesbs_cbc_ctx *ctx = crypto_skcipher_ctx(tfm);
168         struct skcipher_walk walk;
169         int err;
170
171         err = skcipher_walk_virt(&walk, req, false);
172
173         while (walk.nbytes >= AES_BLOCK_SIZE) {
174                 unsigned int blocks = walk.nbytes / AES_BLOCK_SIZE;
175
176                 if (walk.nbytes < walk.total)
177                         blocks = round_down(blocks,
178                                             walk.stride / AES_BLOCK_SIZE);
179
180                 kernel_neon_begin();
181                 aesbs_cbc_decrypt(walk.dst.virt.addr, walk.src.virt.addr,
182                                   ctx->key.rk, ctx->key.rounds, blocks,
183                                   walk.iv);
184                 kernel_neon_end();
185                 err = skcipher_walk_done(&walk,
186                                          walk.nbytes - blocks * AES_BLOCK_SIZE);
187         }
188
189         return err;
190 }
191
192 static int cbc_init(struct crypto_skcipher *tfm)
193 {
194         struct aesbs_cbc_ctx *ctx = crypto_skcipher_ctx(tfm);
195         unsigned int reqsize;
196
197         ctx->enc_tfm = crypto_alloc_skcipher("cbc(aes)", 0, CRYPTO_ALG_ASYNC |
198                                              CRYPTO_ALG_NEED_FALLBACK);
199         if (IS_ERR(ctx->enc_tfm))
200                 return PTR_ERR(ctx->enc_tfm);
201
202         reqsize = sizeof(struct skcipher_request);
203         reqsize += crypto_skcipher_reqsize(ctx->enc_tfm);
204         crypto_skcipher_set_reqsize(tfm, reqsize);
205
206         return 0;
207 }
208
209 static void cbc_exit(struct crypto_skcipher *tfm)
210 {
211         struct aesbs_cbc_ctx *ctx = crypto_skcipher_ctx(tfm);
212
213         crypto_free_skcipher(ctx->enc_tfm);
214 }
215
216 static int aesbs_ctr_setkey_sync(struct crypto_skcipher *tfm, const u8 *in_key,
217                                  unsigned int key_len)
218 {
219         struct aesbs_ctr_ctx *ctx = crypto_skcipher_ctx(tfm);
220         int err;
221
222         err = aes_expandkey(&ctx->fallback, in_key, key_len);
223         if (err)
224                 return err;
225
226         ctx->key.rounds = 6 + key_len / 4;
227
228         kernel_neon_begin();
229         aesbs_convert_key(ctx->key.rk, ctx->fallback.key_enc, ctx->key.rounds);
230         kernel_neon_end();
231
232         return 0;
233 }
234
235 static int ctr_encrypt(struct skcipher_request *req)
236 {
237         struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
238         struct aesbs_ctx *ctx = crypto_skcipher_ctx(tfm);
239         struct skcipher_walk walk;
240         u8 buf[AES_BLOCK_SIZE];
241         int err;
242
243         err = skcipher_walk_virt(&walk, req, false);
244
245         while (walk.nbytes > 0) {
246                 const u8 *src = walk.src.virt.addr;
247                 u8 *dst = walk.dst.virt.addr;
248                 int bytes = walk.nbytes;
249
250                 if (unlikely(bytes < AES_BLOCK_SIZE))
251                         src = dst = memcpy(buf + sizeof(buf) - bytes,
252                                            src, bytes);
253                 else if (walk.nbytes < walk.total)
254                         bytes &= ~(8 * AES_BLOCK_SIZE - 1);
255
256                 kernel_neon_begin();
257                 aesbs_ctr_encrypt(dst, src, ctx->rk, ctx->rounds, bytes, walk.iv);
258                 kernel_neon_end();
259
260                 if (unlikely(bytes < AES_BLOCK_SIZE))
261                         memcpy(walk.dst.virt.addr,
262                                buf + sizeof(buf) - bytes, bytes);
263
264                 err = skcipher_walk_done(&walk, walk.nbytes - bytes);
265         }
266
267         return err;
268 }
269
270 static void ctr_encrypt_one(struct crypto_skcipher *tfm, const u8 *src, u8 *dst)
271 {
272         struct aesbs_ctr_ctx *ctx = crypto_skcipher_ctx(tfm);
273         unsigned long flags;
274
275         /*
276          * Temporarily disable interrupts to avoid races where
277          * cachelines are evicted when the CPU is interrupted
278          * to do something else.
279          */
280         local_irq_save(flags);
281         aes_encrypt(&ctx->fallback, dst, src);
282         local_irq_restore(flags);
283 }
284
285 static int ctr_encrypt_sync(struct skcipher_request *req)
286 {
287         if (!crypto_simd_usable())
288                 return crypto_ctr_encrypt_walk(req, ctr_encrypt_one);
289
290         return ctr_encrypt(req);
291 }
292
293 static int aesbs_xts_setkey(struct crypto_skcipher *tfm, const u8 *in_key,
294                             unsigned int key_len)
295 {
296         struct aesbs_xts_ctx *ctx = crypto_skcipher_ctx(tfm);
297         int err;
298
299         err = xts_verify_key(tfm, in_key, key_len);
300         if (err)
301                 return err;
302
303         key_len /= 2;
304         err = crypto_cipher_setkey(ctx->cts_tfm, in_key, key_len);
305         if (err)
306                 return err;
307         err = crypto_cipher_setkey(ctx->tweak_tfm, in_key + key_len, key_len);
308         if (err)
309                 return err;
310
311         return aesbs_setkey(tfm, in_key, key_len);
312 }
313
314 static int xts_init(struct crypto_skcipher *tfm)
315 {
316         struct aesbs_xts_ctx *ctx = crypto_skcipher_ctx(tfm);
317
318         ctx->cts_tfm = crypto_alloc_cipher("aes", 0, 0);
319         if (IS_ERR(ctx->cts_tfm))
320                 return PTR_ERR(ctx->cts_tfm);
321
322         ctx->tweak_tfm = crypto_alloc_cipher("aes", 0, 0);
323         if (IS_ERR(ctx->tweak_tfm))
324                 crypto_free_cipher(ctx->cts_tfm);
325
326         return PTR_ERR_OR_ZERO(ctx->tweak_tfm);
327 }
328
329 static void xts_exit(struct crypto_skcipher *tfm)
330 {
331         struct aesbs_xts_ctx *ctx = crypto_skcipher_ctx(tfm);
332
333         crypto_free_cipher(ctx->tweak_tfm);
334         crypto_free_cipher(ctx->cts_tfm);
335 }
336
337 static int __xts_crypt(struct skcipher_request *req, bool encrypt,
338                        void (*fn)(u8 out[], u8 const in[], u8 const rk[],
339                                   int rounds, int blocks, u8 iv[], int))
340 {
341         struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
342         struct aesbs_xts_ctx *ctx = crypto_skcipher_ctx(tfm);
343         int tail = req->cryptlen % AES_BLOCK_SIZE;
344         struct skcipher_request subreq;
345         u8 buf[2 * AES_BLOCK_SIZE];
346         struct skcipher_walk walk;
347         int err;
348
349         if (req->cryptlen < AES_BLOCK_SIZE)
350                 return -EINVAL;
351
352         if (unlikely(tail)) {
353                 skcipher_request_set_tfm(&subreq, tfm);
354                 skcipher_request_set_callback(&subreq,
355                                               skcipher_request_flags(req),
356                                               NULL, NULL);
357                 skcipher_request_set_crypt(&subreq, req->src, req->dst,
358                                            req->cryptlen - tail, req->iv);
359                 req = &subreq;
360         }
361
362         err = skcipher_walk_virt(&walk, req, true);
363         if (err)
364                 return err;
365
366         crypto_cipher_encrypt_one(ctx->tweak_tfm, walk.iv, walk.iv);
367
368         while (walk.nbytes >= AES_BLOCK_SIZE) {
369                 unsigned int blocks = walk.nbytes / AES_BLOCK_SIZE;
370                 int reorder_last_tweak = !encrypt && tail > 0;
371
372                 if (walk.nbytes < walk.total) {
373                         blocks = round_down(blocks,
374                                             walk.stride / AES_BLOCK_SIZE);
375                         reorder_last_tweak = 0;
376                 }
377
378                 kernel_neon_begin();
379                 fn(walk.dst.virt.addr, walk.src.virt.addr, ctx->key.rk,
380                    ctx->key.rounds, blocks, walk.iv, reorder_last_tweak);
381                 kernel_neon_end();
382                 err = skcipher_walk_done(&walk,
383                                          walk.nbytes - blocks * AES_BLOCK_SIZE);
384         }
385
386         if (err || likely(!tail))
387                 return err;
388
389         /* handle ciphertext stealing */
390         scatterwalk_map_and_copy(buf, req->dst, req->cryptlen - AES_BLOCK_SIZE,
391                                  AES_BLOCK_SIZE, 0);
392         memcpy(buf + AES_BLOCK_SIZE, buf, tail);
393         scatterwalk_map_and_copy(buf, req->src, req->cryptlen, tail, 0);
394
395         crypto_xor(buf, req->iv, AES_BLOCK_SIZE);
396
397         if (encrypt)
398                 crypto_cipher_encrypt_one(ctx->cts_tfm, buf, buf);
399         else
400                 crypto_cipher_decrypt_one(ctx->cts_tfm, buf, buf);
401
402         crypto_xor(buf, req->iv, AES_BLOCK_SIZE);
403
404         scatterwalk_map_and_copy(buf, req->dst, req->cryptlen - AES_BLOCK_SIZE,
405                                  AES_BLOCK_SIZE + tail, 1);
406         return 0;
407 }
408
409 static int xts_encrypt(struct skcipher_request *req)
410 {
411         return __xts_crypt(req, true, aesbs_xts_encrypt);
412 }
413
414 static int xts_decrypt(struct skcipher_request *req)
415 {
416         return __xts_crypt(req, false, aesbs_xts_decrypt);
417 }
418
419 static struct skcipher_alg aes_algs[] = { {
420         .base.cra_name          = "__ecb(aes)",
421         .base.cra_driver_name   = "__ecb-aes-neonbs",
422         .base.cra_priority      = 250,
423         .base.cra_blocksize     = AES_BLOCK_SIZE,
424         .base.cra_ctxsize       = sizeof(struct aesbs_ctx),
425         .base.cra_module        = THIS_MODULE,
426         .base.cra_flags         = CRYPTO_ALG_INTERNAL,
427
428         .min_keysize            = AES_MIN_KEY_SIZE,
429         .max_keysize            = AES_MAX_KEY_SIZE,
430         .walksize               = 8 * AES_BLOCK_SIZE,
431         .setkey                 = aesbs_setkey,
432         .encrypt                = ecb_encrypt,
433         .decrypt                = ecb_decrypt,
434 }, {
435         .base.cra_name          = "__cbc(aes)",
436         .base.cra_driver_name   = "__cbc-aes-neonbs",
437         .base.cra_priority      = 250,
438         .base.cra_blocksize     = AES_BLOCK_SIZE,
439         .base.cra_ctxsize       = sizeof(struct aesbs_cbc_ctx),
440         .base.cra_module        = THIS_MODULE,
441         .base.cra_flags         = CRYPTO_ALG_INTERNAL |
442                                   CRYPTO_ALG_NEED_FALLBACK,
443
444         .min_keysize            = AES_MIN_KEY_SIZE,
445         .max_keysize            = AES_MAX_KEY_SIZE,
446         .walksize               = 8 * AES_BLOCK_SIZE,
447         .ivsize                 = AES_BLOCK_SIZE,
448         .setkey                 = aesbs_cbc_setkey,
449         .encrypt                = cbc_encrypt,
450         .decrypt                = cbc_decrypt,
451         .init                   = cbc_init,
452         .exit                   = cbc_exit,
453 }, {
454         .base.cra_name          = "__ctr(aes)",
455         .base.cra_driver_name   = "__ctr-aes-neonbs",
456         .base.cra_priority      = 250,
457         .base.cra_blocksize     = 1,
458         .base.cra_ctxsize       = sizeof(struct aesbs_ctx),
459         .base.cra_module        = THIS_MODULE,
460         .base.cra_flags         = CRYPTO_ALG_INTERNAL,
461
462         .min_keysize            = AES_MIN_KEY_SIZE,
463         .max_keysize            = AES_MAX_KEY_SIZE,
464         .chunksize              = AES_BLOCK_SIZE,
465         .walksize               = 8 * AES_BLOCK_SIZE,
466         .ivsize                 = AES_BLOCK_SIZE,
467         .setkey                 = aesbs_setkey,
468         .encrypt                = ctr_encrypt,
469         .decrypt                = ctr_encrypt,
470 }, {
471         .base.cra_name          = "ctr(aes)",
472         .base.cra_driver_name   = "ctr-aes-neonbs-sync",
473         .base.cra_priority      = 250 - 1,
474         .base.cra_blocksize     = 1,
475         .base.cra_ctxsize       = sizeof(struct aesbs_ctr_ctx),
476         .base.cra_module        = THIS_MODULE,
477
478         .min_keysize            = AES_MIN_KEY_SIZE,
479         .max_keysize            = AES_MAX_KEY_SIZE,
480         .chunksize              = AES_BLOCK_SIZE,
481         .walksize               = 8 * AES_BLOCK_SIZE,
482         .ivsize                 = AES_BLOCK_SIZE,
483         .setkey                 = aesbs_ctr_setkey_sync,
484         .encrypt                = ctr_encrypt_sync,
485         .decrypt                = ctr_encrypt_sync,
486 }, {
487         .base.cra_name          = "__xts(aes)",
488         .base.cra_driver_name   = "__xts-aes-neonbs",
489         .base.cra_priority      = 250,
490         .base.cra_blocksize     = AES_BLOCK_SIZE,
491         .base.cra_ctxsize       = sizeof(struct aesbs_xts_ctx),
492         .base.cra_module        = THIS_MODULE,
493         .base.cra_flags         = CRYPTO_ALG_INTERNAL,
494
495         .min_keysize            = 2 * AES_MIN_KEY_SIZE,
496         .max_keysize            = 2 * AES_MAX_KEY_SIZE,
497         .walksize               = 8 * AES_BLOCK_SIZE,
498         .ivsize                 = AES_BLOCK_SIZE,
499         .setkey                 = aesbs_xts_setkey,
500         .encrypt                = xts_encrypt,
501         .decrypt                = xts_decrypt,
502         .init                   = xts_init,
503         .exit                   = xts_exit,
504 } };
505
506 static struct simd_skcipher_alg *aes_simd_algs[ARRAY_SIZE(aes_algs)];
507
508 static void aes_exit(void)
509 {
510         int i;
511
512         for (i = 0; i < ARRAY_SIZE(aes_simd_algs); i++)
513                 if (aes_simd_algs[i])
514                         simd_skcipher_free(aes_simd_algs[i]);
515
516         crypto_unregister_skciphers(aes_algs, ARRAY_SIZE(aes_algs));
517 }
518
519 static int __init aes_init(void)
520 {
521         struct simd_skcipher_alg *simd;
522         const char *basename;
523         const char *algname;
524         const char *drvname;
525         int err;
526         int i;
527
528         if (!(elf_hwcap & HWCAP_NEON))
529                 return -ENODEV;
530
531         err = crypto_register_skciphers(aes_algs, ARRAY_SIZE(aes_algs));
532         if (err)
533                 return err;
534
535         for (i = 0; i < ARRAY_SIZE(aes_algs); i++) {
536                 if (!(aes_algs[i].base.cra_flags & CRYPTO_ALG_INTERNAL))
537                         continue;
538
539                 algname = aes_algs[i].base.cra_name + 2;
540                 drvname = aes_algs[i].base.cra_driver_name + 2;
541                 basename = aes_algs[i].base.cra_driver_name;
542                 simd = simd_skcipher_create_compat(algname, drvname, basename);
543                 err = PTR_ERR(simd);
544                 if (IS_ERR(simd))
545                         goto unregister_simds;
546
547                 aes_simd_algs[i] = simd;
548         }
549         return 0;
550
551 unregister_simds:
552         aes_exit();
553         return err;
554 }
555
556 late_initcall(aes_init);
557 module_exit(aes_exit);