GNU Linux-libre 6.8.9-gnu
[releases.git] / drivers / crypto / starfive / jh7110-aes.c
1 // SPDX-License-Identifier: GPL-2.0
2 /*
3  * StarFive AES acceleration driver
4  *
5  * Copyright (c) 2022 StarFive Technology
6  */
7
8 #include <crypto/engine.h>
9 #include <crypto/gcm.h>
10 #include <crypto/internal/aead.h>
11 #include <crypto/internal/skcipher.h>
12 #include <crypto/scatterwalk.h>
13 #include "jh7110-cryp.h"
14 #include <linux/err.h>
15 #include <linux/iopoll.h>
16 #include <linux/kernel.h>
17 #include <linux/slab.h>
18 #include <linux/string.h>
19
/* The AES register bank starts 0x100 bytes into the device's MMIO window. */
#define STARFIVE_AES_REGS_OFFSET        0x100

/* Data in/out register: blocks are streamed through it one word at a time. */
#define STARFIVE_AES_AESDIO0R           (STARFIVE_AES_REGS_OFFSET + 0x00)

/* Key words; 128/192/256-bit keys use KEY0..3, KEY0..5 and KEY0..7. */
#define STARFIVE_AES_KEY0               (STARFIVE_AES_REGS_OFFSET + 0x04)
#define STARFIVE_AES_KEY1               (STARFIVE_AES_KEY0 + 0x04)
#define STARFIVE_AES_KEY2               (STARFIVE_AES_KEY0 + 0x08)
#define STARFIVE_AES_KEY3               (STARFIVE_AES_KEY0 + 0x0C)
#define STARFIVE_AES_KEY4               (STARFIVE_AES_KEY0 + 0x10)
#define STARFIVE_AES_KEY5               (STARFIVE_AES_KEY0 + 0x14)
#define STARFIVE_AES_KEY6               (STARFIVE_AES_KEY0 + 0x18)
#define STARFIVE_AES_KEY7               (STARFIVE_AES_KEY0 + 0x1C)

/* Control / status register. */
#define STARFIVE_AES_CSR                (STARFIVE_AES_REGS_OFFSET + 0x24)

/* IV words. */
#define STARFIVE_AES_IV0                (STARFIVE_AES_REGS_OFFSET + 0x28)
#define STARFIVE_AES_IV1                (STARFIVE_AES_IV0 + 0x04)
#define STARFIVE_AES_IV2                (STARFIVE_AES_IV0 + 0x08)
#define STARFIVE_AES_IV3                (STARFIVE_AES_IV0 + 0x0C)

/*
 * Nonce words (offset 0x38 is not used by this driver). Besides the CCM
 * B0 block, these also take the GCM AAD and return the GCM tag.
 */
#define STARFIVE_AES_NONCE0             (STARFIVE_AES_REGS_OFFSET + 0x3C)
#define STARFIVE_AES_NONCE1             (STARFIVE_AES_NONCE0 + 0x04)
#define STARFIVE_AES_NONCE2             (STARFIVE_AES_NONCE0 + 0x08)
#define STARFIVE_AES_NONCE3             (STARFIVE_AES_NONCE0 + 0x0C)

/* AAD and message lengths: word 0 = upper 32 bits, word 1 = lower 32 bits. */
#define STARFIVE_AES_ALEN0              (STARFIVE_AES_REGS_OFFSET + 0x4C)
#define STARFIVE_AES_ALEN1              (STARFIVE_AES_ALEN0 + 0x04)
#define STARFIVE_AES_MLEN0              (STARFIVE_AES_REGS_OFFSET + 0x54)
#define STARFIVE_AES_MLEN1              (STARFIVE_AES_MLEN0 + 0x04)

/* IV length in bytes. */
#define STARFIVE_AES_IVLEN              (STARFIVE_AES_REGS_OFFSET + 0x5C)

/* Driver-private request flags kept in cryp->flags. */
#define FLG_MODE_MASK                   GENMASK(2, 0)
#define FLG_ENCRYPT                     BIT(4)

/* Misc */
#define CCM_B0_ADATA                    0x40    /* B0 flag: AAD present */
#define AES_BLOCK_32                    (AES_BLOCK_SIZE / sizeof(u32))
51
52 static inline int starfive_aes_wait_busy(struct starfive_cryp_dev *cryp)
53 {
54         u32 status;
55
56         return readl_relaxed_poll_timeout(cryp->base + STARFIVE_AES_CSR, status,
57                                           !(status & STARFIVE_AES_BUSY), 10, 100000);
58 }
59
60 static inline int starfive_aes_wait_keydone(struct starfive_cryp_dev *cryp)
61 {
62         u32 status;
63
64         return readl_relaxed_poll_timeout(cryp->base + STARFIVE_AES_CSR, status,
65                                           (status & STARFIVE_AES_KEY_DONE), 10, 100000);
66 }
67
68 static inline int starfive_aes_wait_gcmdone(struct starfive_cryp_dev *cryp)
69 {
70         u32 status;
71
72         return readl_relaxed_poll_timeout(cryp->base + STARFIVE_AES_CSR, status,
73                                           (status & STARFIVE_AES_GCM_DONE), 10, 100000);
74 }
75
76 static inline int is_gcm(struct starfive_cryp_dev *cryp)
77 {
78         return (cryp->flags & FLG_MODE_MASK) == STARFIVE_AES_MODE_GCM;
79 }
80
81 static inline int is_encrypt(struct starfive_cryp_dev *cryp)
82 {
83         return cryp->flags & FLG_ENCRYPT;
84 }
85
86 static void starfive_aes_aead_hw_start(struct starfive_cryp_ctx *ctx, u32 hw_mode)
87 {
88         struct starfive_cryp_dev *cryp = ctx->cryp;
89         unsigned int value;
90
91         switch (hw_mode) {
92         case STARFIVE_AES_MODE_GCM:
93                 value = readl(ctx->cryp->base + STARFIVE_AES_CSR);
94                 value |= STARFIVE_AES_GCM_START;
95                 writel(value, cryp->base + STARFIVE_AES_CSR);
96                 starfive_aes_wait_gcmdone(cryp);
97                 break;
98         case STARFIVE_AES_MODE_CCM:
99                 value = readl(ctx->cryp->base + STARFIVE_AES_CSR);
100                 value |= STARFIVE_AES_CCM_START;
101                 writel(value, cryp->base + STARFIVE_AES_CSR);
102                 break;
103         }
104 }
105
106 static inline void starfive_aes_set_ivlen(struct starfive_cryp_ctx *ctx)
107 {
108         struct starfive_cryp_dev *cryp = ctx->cryp;
109
110         if (is_gcm(cryp))
111                 writel(GCM_AES_IV_SIZE, cryp->base + STARFIVE_AES_IVLEN);
112         else
113                 writel(AES_BLOCK_SIZE, cryp->base + STARFIVE_AES_IVLEN);
114 }
115
116 static inline void starfive_aes_set_alen(struct starfive_cryp_ctx *ctx)
117 {
118         struct starfive_cryp_dev *cryp = ctx->cryp;
119
120         writel(upper_32_bits(cryp->assoclen), cryp->base + STARFIVE_AES_ALEN0);
121         writel(lower_32_bits(cryp->assoclen), cryp->base + STARFIVE_AES_ALEN1);
122 }
123
124 static inline void starfive_aes_set_mlen(struct starfive_cryp_ctx *ctx)
125 {
126         struct starfive_cryp_dev *cryp = ctx->cryp;
127
128         writel(upper_32_bits(cryp->total_in), cryp->base + STARFIVE_AES_MLEN0);
129         writel(lower_32_bits(cryp->total_in), cryp->base + STARFIVE_AES_MLEN1);
130 }
131
132 static inline int starfive_aes_ccm_check_iv(const u8 *iv)
133 {
134         /* 2 <= L <= 8, so 1 <= L' <= 7. */
135         if (iv[0] < 1 || iv[0] > 7)
136                 return -EINVAL;
137
138         return 0;
139 }
140
141 static int starfive_aes_write_iv(struct starfive_cryp_ctx *ctx, u32 *iv)
142 {
143         struct starfive_cryp_dev *cryp = ctx->cryp;
144
145         writel(iv[0], cryp->base + STARFIVE_AES_IV0);
146         writel(iv[1], cryp->base + STARFIVE_AES_IV1);
147         writel(iv[2], cryp->base + STARFIVE_AES_IV2);
148
149         if (is_gcm(cryp)) {
150                 if (starfive_aes_wait_gcmdone(cryp))
151                         return -ETIMEDOUT;
152
153                 return 0;
154         }
155
156         writel(iv[3], cryp->base + STARFIVE_AES_IV3);
157
158         return 0;
159 }
160
161 static inline void starfive_aes_get_iv(struct starfive_cryp_dev *cryp, u32 *iv)
162 {
163         iv[0] = readl(cryp->base + STARFIVE_AES_IV0);
164         iv[1] = readl(cryp->base + STARFIVE_AES_IV1);
165         iv[2] = readl(cryp->base + STARFIVE_AES_IV2);
166         iv[3] = readl(cryp->base + STARFIVE_AES_IV3);
167 }
168
169 static inline void starfive_aes_write_nonce(struct starfive_cryp_ctx *ctx, u32 *nonce)
170 {
171         struct starfive_cryp_dev *cryp = ctx->cryp;
172
173         writel(nonce[0], cryp->base + STARFIVE_AES_NONCE0);
174         writel(nonce[1], cryp->base + STARFIVE_AES_NONCE1);
175         writel(nonce[2], cryp->base + STARFIVE_AES_NONCE2);
176         writel(nonce[3], cryp->base + STARFIVE_AES_NONCE3);
177 }
178
179 static int starfive_aes_write_key(struct starfive_cryp_ctx *ctx)
180 {
181         struct starfive_cryp_dev *cryp = ctx->cryp;
182         u32 *key = (u32 *)ctx->key;
183
184         if (ctx->keylen >= AES_KEYSIZE_128) {
185                 writel(key[0], cryp->base + STARFIVE_AES_KEY0);
186                 writel(key[1], cryp->base + STARFIVE_AES_KEY1);
187                 writel(key[2], cryp->base + STARFIVE_AES_KEY2);
188                 writel(key[3], cryp->base + STARFIVE_AES_KEY3);
189         }
190
191         if (ctx->keylen >= AES_KEYSIZE_192) {
192                 writel(key[4], cryp->base + STARFIVE_AES_KEY4);
193                 writel(key[5], cryp->base + STARFIVE_AES_KEY5);
194         }
195
196         if (ctx->keylen >= AES_KEYSIZE_256) {
197                 writel(key[6], cryp->base + STARFIVE_AES_KEY6);
198                 writel(key[7], cryp->base + STARFIVE_AES_KEY7);
199         }
200
201         if (starfive_aes_wait_keydone(cryp))
202                 return -ETIMEDOUT;
203
204         return 0;
205 }
206
207 static int starfive_aes_ccm_init(struct starfive_cryp_ctx *ctx)
208 {
209         struct starfive_cryp_dev *cryp = ctx->cryp;
210         u8 iv[AES_BLOCK_SIZE], b0[AES_BLOCK_SIZE];
211         unsigned int textlen;
212
213         memcpy(iv, cryp->req.areq->iv, AES_BLOCK_SIZE);
214         memset(iv + AES_BLOCK_SIZE - 1 - iv[0], 0, iv[0] + 1);
215
216         /* Build B0 */
217         memcpy(b0, iv, AES_BLOCK_SIZE);
218
219         b0[0] |= (8 * ((cryp->authsize - 2) / 2));
220
221         if (cryp->assoclen)
222                 b0[0] |= CCM_B0_ADATA;
223
224         textlen = cryp->total_in;
225
226         b0[AES_BLOCK_SIZE - 2] = textlen >> 8;
227         b0[AES_BLOCK_SIZE - 1] = textlen & 0xFF;
228
229         starfive_aes_write_nonce(ctx, (u32 *)b0);
230
231         return 0;
232 }
233
234 static int starfive_aes_hw_init(struct starfive_cryp_ctx *ctx)
235 {
236         struct starfive_cryp_request_ctx *rctx = ctx->rctx;
237         struct starfive_cryp_dev *cryp = ctx->cryp;
238         u32 hw_mode;
239
240         /* reset */
241         rctx->csr.aes.v = 0;
242         rctx->csr.aes.aesrst = 1;
243         writel(rctx->csr.aes.v, cryp->base + STARFIVE_AES_CSR);
244
245         /* csr setup */
246         hw_mode = cryp->flags & FLG_MODE_MASK;
247
248         rctx->csr.aes.v = 0;
249
250         switch (ctx->keylen) {
251         case AES_KEYSIZE_128:
252                 rctx->csr.aes.keymode = STARFIVE_AES_KEYMODE_128;
253                 break;
254         case AES_KEYSIZE_192:
255                 rctx->csr.aes.keymode = STARFIVE_AES_KEYMODE_192;
256                 break;
257         case AES_KEYSIZE_256:
258                 rctx->csr.aes.keymode = STARFIVE_AES_KEYMODE_256;
259                 break;
260         }
261
262         rctx->csr.aes.mode  = hw_mode;
263         rctx->csr.aes.cmode = !is_encrypt(cryp);
264         rctx->csr.aes.ie = 1;
265         rctx->csr.aes.stmode = STARFIVE_AES_MODE_XFB_1;
266
267         if (cryp->side_chan) {
268                 rctx->csr.aes.delay_aes = 1;
269                 rctx->csr.aes.vaes_start = 1;
270         }
271
272         writel(rctx->csr.aes.v, cryp->base + STARFIVE_AES_CSR);
273
274         cryp->err = starfive_aes_write_key(ctx);
275         if (cryp->err)
276                 return cryp->err;
277
278         switch (hw_mode) {
279         case STARFIVE_AES_MODE_GCM:
280                 starfive_aes_set_alen(ctx);
281                 starfive_aes_set_mlen(ctx);
282                 starfive_aes_set_ivlen(ctx);
283                 starfive_aes_aead_hw_start(ctx, hw_mode);
284                 starfive_aes_write_iv(ctx, (void *)cryp->req.areq->iv);
285                 break;
286         case STARFIVE_AES_MODE_CCM:
287                 starfive_aes_set_alen(ctx);
288                 starfive_aes_set_mlen(ctx);
289                 starfive_aes_ccm_init(ctx);
290                 starfive_aes_aead_hw_start(ctx, hw_mode);
291                 break;
292         case STARFIVE_AES_MODE_CBC:
293         case STARFIVE_AES_MODE_CTR:
294                 starfive_aes_write_iv(ctx, (void *)cryp->req.sreq->iv);
295                 break;
296         default:
297                 break;
298         }
299
300         return cryp->err;
301 }
302
303 static int starfive_aes_read_authtag(struct starfive_cryp_dev *cryp)
304 {
305         int i, start_addr;
306
307         if (starfive_aes_wait_busy(cryp))
308                 return dev_err_probe(cryp->dev, -ETIMEDOUT,
309                                      "Timeout waiting for tag generation.");
310
311         start_addr = STARFIVE_AES_NONCE0;
312
313         if (is_gcm(cryp))
314                 for (i = 0; i < AES_BLOCK_32; i++, start_addr += 4)
315                         cryp->tag_out[i] = readl(cryp->base + start_addr);
316         else
317                 for (i = 0; i < AES_BLOCK_32; i++)
318                         cryp->tag_out[i] = readl(cryp->base + STARFIVE_AES_AESDIO0R);
319
320         if (is_encrypt(cryp)) {
321                 scatterwalk_copychunks(cryp->tag_out, &cryp->out_walk, cryp->authsize, 1);
322         } else {
323                 scatterwalk_copychunks(cryp->tag_in, &cryp->in_walk, cryp->authsize, 0);
324
325                 if (crypto_memneq(cryp->tag_in, cryp->tag_out, cryp->authsize))
326                         return dev_err_probe(cryp->dev, -EBADMSG, "Failed tag verification\n");
327         }
328
329         return 0;
330 }
331
332 static void starfive_aes_finish_req(struct starfive_cryp_dev *cryp)
333 {
334         union starfive_aes_csr csr;
335         int err = cryp->err;
336
337         if (!err && cryp->authsize)
338                 err = starfive_aes_read_authtag(cryp);
339
340         if (!err && ((cryp->flags & FLG_MODE_MASK) == STARFIVE_AES_MODE_CBC ||
341                      (cryp->flags & FLG_MODE_MASK) == STARFIVE_AES_MODE_CTR))
342                 starfive_aes_get_iv(cryp, (void *)cryp->req.sreq->iv);
343
344         /* reset irq flags*/
345         csr.v = 0;
346         csr.aesrst = 1;
347         writel(csr.v, cryp->base + STARFIVE_AES_CSR);
348
349         if (cryp->authsize)
350                 crypto_finalize_aead_request(cryp->engine, cryp->req.areq, err);
351         else
352                 crypto_finalize_skcipher_request(cryp->engine, cryp->req.sreq,
353                                                  err);
354 }
355
356 void starfive_aes_done_task(unsigned long param)
357 {
358         struct starfive_cryp_dev *cryp = (struct starfive_cryp_dev *)param;
359         u32 block[AES_BLOCK_32];
360         u32 stat;
361         int i;
362
363         for (i = 0; i < AES_BLOCK_32; i++)
364                 block[i] = readl(cryp->base + STARFIVE_AES_AESDIO0R);
365
366         scatterwalk_copychunks(block, &cryp->out_walk, min_t(size_t, AES_BLOCK_SIZE,
367                                                              cryp->total_out), 1);
368
369         cryp->total_out -= min_t(size_t, AES_BLOCK_SIZE, cryp->total_out);
370
371         if (!cryp->total_out) {
372                 starfive_aes_finish_req(cryp);
373                 return;
374         }
375
376         memset(block, 0, AES_BLOCK_SIZE);
377         scatterwalk_copychunks(block, &cryp->in_walk, min_t(size_t, AES_BLOCK_SIZE,
378                                                             cryp->total_in), 0);
379         cryp->total_in -= min_t(size_t, AES_BLOCK_SIZE, cryp->total_in);
380
381         for (i = 0; i < AES_BLOCK_32; i++)
382                 writel(block[i], cryp->base + STARFIVE_AES_AESDIO0R);
383
384         stat = readl(cryp->base + STARFIVE_IE_MASK_OFFSET);
385         stat &= ~STARFIVE_IE_MASK_AES_DONE;
386         writel(stat, cryp->base + STARFIVE_IE_MASK_OFFSET);
387 }
388
389 static int starfive_aes_gcm_write_adata(struct starfive_cryp_ctx *ctx)
390 {
391         struct starfive_cryp_dev *cryp = ctx->cryp;
392         struct starfive_cryp_request_ctx *rctx = ctx->rctx;
393         u32 *buffer;
394         int total_len, loop;
395
396         total_len = ALIGN(cryp->assoclen, AES_BLOCK_SIZE) / sizeof(unsigned int);
397         buffer = (u32 *)rctx->adata;
398
399         for (loop = 0; loop < total_len; loop += 4) {
400                 writel(*buffer, cryp->base + STARFIVE_AES_NONCE0);
401                 buffer++;
402                 writel(*buffer, cryp->base + STARFIVE_AES_NONCE1);
403                 buffer++;
404                 writel(*buffer, cryp->base + STARFIVE_AES_NONCE2);
405                 buffer++;
406                 writel(*buffer, cryp->base + STARFIVE_AES_NONCE3);
407                 buffer++;
408         }
409
410         if (starfive_aes_wait_gcmdone(cryp))
411                 return dev_err_probe(cryp->dev, -ETIMEDOUT,
412                                      "Timeout processing gcm aad block");
413
414         return 0;
415 }
416
417 static int starfive_aes_ccm_write_adata(struct starfive_cryp_ctx *ctx)
418 {
419         struct starfive_cryp_dev *cryp = ctx->cryp;
420         struct starfive_cryp_request_ctx *rctx = ctx->rctx;
421         u32 *buffer;
422         u8 *ci;
423         int total_len, loop;
424
425         total_len = cryp->assoclen;
426
427         ci = rctx->adata;
428         writeb(*ci, cryp->base + STARFIVE_AES_AESDIO0R);
429         ci++;
430         writeb(*ci, cryp->base + STARFIVE_AES_AESDIO0R);
431         ci++;
432         total_len -= 2;
433         buffer = (u32 *)ci;
434
435         for (loop = 0; loop < 3; loop++, buffer++)
436                 writel(*buffer, cryp->base + STARFIVE_AES_AESDIO0R);
437
438         total_len -= 12;
439
440         while (total_len > 0) {
441                 for (loop = 0; loop < AES_BLOCK_32; loop++, buffer++)
442                         writel(*buffer, cryp->base + STARFIVE_AES_AESDIO0R);
443
444                 total_len -= AES_BLOCK_SIZE;
445         }
446
447         if (starfive_aes_wait_busy(cryp))
448                 return dev_err_probe(cryp->dev, -ETIMEDOUT,
449                                      "Timeout processing ccm aad block");
450
451         return 0;
452 }
453
454 static int starfive_aes_prepare_req(struct skcipher_request *req,
455                                     struct aead_request *areq)
456 {
457         struct starfive_cryp_ctx *ctx;
458         struct starfive_cryp_request_ctx *rctx;
459         struct starfive_cryp_dev *cryp;
460
461         if (!req && !areq)
462                 return -EINVAL;
463
464         ctx = req ? crypto_skcipher_ctx(crypto_skcipher_reqtfm(req)) :
465                     crypto_aead_ctx(crypto_aead_reqtfm(areq));
466
467         cryp = ctx->cryp;
468         rctx = req ? skcipher_request_ctx(req) : aead_request_ctx(areq);
469
470         if (req) {
471                 cryp->req.sreq = req;
472                 cryp->total_in = req->cryptlen;
473                 cryp->total_out = req->cryptlen;
474                 cryp->assoclen = 0;
475                 cryp->authsize = 0;
476         } else {
477                 cryp->req.areq = areq;
478                 cryp->assoclen = areq->assoclen;
479                 cryp->authsize = crypto_aead_authsize(crypto_aead_reqtfm(areq));
480                 if (is_encrypt(cryp)) {
481                         cryp->total_in = areq->cryptlen;
482                         cryp->total_out = areq->cryptlen;
483                 } else {
484                         cryp->total_in = areq->cryptlen - cryp->authsize;
485                         cryp->total_out = cryp->total_in;
486                 }
487         }
488
489         rctx->in_sg = req ? req->src : areq->src;
490         scatterwalk_start(&cryp->in_walk, rctx->in_sg);
491
492         rctx->out_sg = req ? req->dst : areq->dst;
493         scatterwalk_start(&cryp->out_walk, rctx->out_sg);
494
495         if (cryp->assoclen) {
496                 rctx->adata = kzalloc(cryp->assoclen + AES_BLOCK_SIZE, GFP_KERNEL);
497                 if (!rctx->adata)
498                         return dev_err_probe(cryp->dev, -ENOMEM,
499                                              "Failed to alloc memory for adata");
500
501                 scatterwalk_copychunks(rctx->adata, &cryp->in_walk, cryp->assoclen, 0);
502                 scatterwalk_copychunks(NULL, &cryp->out_walk, cryp->assoclen, 2);
503         }
504
505         ctx->rctx = rctx;
506
507         return starfive_aes_hw_init(ctx);
508 }
509
510 static int starfive_aes_do_one_req(struct crypto_engine *engine, void *areq)
511 {
512         struct skcipher_request *req =
513                 container_of(areq, struct skcipher_request, base);
514         struct starfive_cryp_ctx *ctx =
515                 crypto_skcipher_ctx(crypto_skcipher_reqtfm(req));
516         struct starfive_cryp_dev *cryp = ctx->cryp;
517         u32 block[AES_BLOCK_32];
518         u32 stat;
519         int err;
520         int i;
521
522         err = starfive_aes_prepare_req(req, NULL);
523         if (err)
524                 return err;
525
526         /*
527          * Write first plain/ciphertext block to start the module
528          * then let irq tasklet handle the rest of the data blocks.
529          */
530         scatterwalk_copychunks(block, &cryp->in_walk, min_t(size_t, AES_BLOCK_SIZE,
531                                                             cryp->total_in), 0);
532         cryp->total_in -= min_t(size_t, AES_BLOCK_SIZE, cryp->total_in);
533
534         for (i = 0; i < AES_BLOCK_32; i++)
535                 writel(block[i], cryp->base + STARFIVE_AES_AESDIO0R);
536
537         stat = readl(cryp->base + STARFIVE_IE_MASK_OFFSET);
538         stat &= ~STARFIVE_IE_MASK_AES_DONE;
539         writel(stat, cryp->base + STARFIVE_IE_MASK_OFFSET);
540
541         return 0;
542 }
543
544 static int starfive_aes_init_tfm(struct crypto_skcipher *tfm)
545 {
546         struct starfive_cryp_ctx *ctx = crypto_skcipher_ctx(tfm);
547
548         ctx->cryp = starfive_cryp_find_dev(ctx);
549         if (!ctx->cryp)
550                 return -ENODEV;
551
552         crypto_skcipher_set_reqsize(tfm, sizeof(struct starfive_cryp_request_ctx) +
553                                     sizeof(struct skcipher_request));
554
555         return 0;
556 }
557
558 static int starfive_aes_aead_do_one_req(struct crypto_engine *engine, void *areq)
559 {
560         struct aead_request *req =
561                 container_of(areq, struct aead_request, base);
562         struct starfive_cryp_ctx *ctx =
563                 crypto_aead_ctx(crypto_aead_reqtfm(req));
564         struct starfive_cryp_dev *cryp = ctx->cryp;
565         struct starfive_cryp_request_ctx *rctx;
566         u32 block[AES_BLOCK_32];
567         u32 stat;
568         int err;
569         int i;
570
571         err = starfive_aes_prepare_req(NULL, req);
572         if (err)
573                 return err;
574
575         rctx = ctx->rctx;
576
577         if (!cryp->assoclen)
578                 goto write_text;
579
580         if ((cryp->flags & FLG_MODE_MASK) == STARFIVE_AES_MODE_CCM)
581                 cryp->err = starfive_aes_ccm_write_adata(ctx);
582         else
583                 cryp->err = starfive_aes_gcm_write_adata(ctx);
584
585         kfree(rctx->adata);
586
587         if (cryp->err)
588                 return cryp->err;
589
590 write_text:
591         if (!cryp->total_in)
592                 goto finish_req;
593
594         /*
595          * Write first plain/ciphertext block to start the module
596          * then let irq tasklet handle the rest of the data blocks.
597          */
598         scatterwalk_copychunks(block, &cryp->in_walk, min_t(size_t, AES_BLOCK_SIZE,
599                                                             cryp->total_in), 0);
600         cryp->total_in -= min_t(size_t, AES_BLOCK_SIZE, cryp->total_in);
601
602         for (i = 0; i < AES_BLOCK_32; i++)
603                 writel(block[i], cryp->base + STARFIVE_AES_AESDIO0R);
604
605         stat = readl(cryp->base + STARFIVE_IE_MASK_OFFSET);
606         stat &= ~STARFIVE_IE_MASK_AES_DONE;
607         writel(stat, cryp->base + STARFIVE_IE_MASK_OFFSET);
608
609         return 0;
610
611 finish_req:
612         starfive_aes_finish_req(cryp);
613         return 0;
614 }
615
616 static int starfive_aes_aead_init_tfm(struct crypto_aead *tfm)
617 {
618         struct starfive_cryp_ctx *ctx = crypto_aead_ctx(tfm);
619         struct starfive_cryp_dev *cryp = ctx->cryp;
620         struct crypto_tfm *aead = crypto_aead_tfm(tfm);
621         struct crypto_alg *alg = aead->__crt_alg;
622
623         ctx->cryp = starfive_cryp_find_dev(ctx);
624         if (!ctx->cryp)
625                 return -ENODEV;
626
627         if (alg->cra_flags & CRYPTO_ALG_NEED_FALLBACK) {
628                 ctx->aead_fbk = crypto_alloc_aead(alg->cra_name, 0,
629                                                   CRYPTO_ALG_NEED_FALLBACK);
630                 if (IS_ERR(ctx->aead_fbk))
631                         return dev_err_probe(cryp->dev, PTR_ERR(ctx->aead_fbk),
632                                              "%s() failed to allocate fallback for %s\n",
633                                              __func__, alg->cra_name);
634         }
635
636         crypto_aead_set_reqsize(tfm, sizeof(struct starfive_cryp_ctx) +
637                                 sizeof(struct aead_request));
638
639         return 0;
640 }
641
642 static void starfive_aes_aead_exit_tfm(struct crypto_aead *tfm)
643 {
644         struct starfive_cryp_ctx *ctx = crypto_aead_ctx(tfm);
645
646         crypto_free_aead(ctx->aead_fbk);
647 }
648
649 static int starfive_aes_crypt(struct skcipher_request *req, unsigned long flags)
650 {
651         struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
652         struct starfive_cryp_ctx *ctx = crypto_skcipher_ctx(tfm);
653         struct starfive_cryp_dev *cryp = ctx->cryp;
654         unsigned int blocksize_align = crypto_skcipher_blocksize(tfm) - 1;
655
656         cryp->flags = flags;
657
658         if ((cryp->flags & FLG_MODE_MASK) == STARFIVE_AES_MODE_ECB ||
659             (cryp->flags & FLG_MODE_MASK) == STARFIVE_AES_MODE_CBC)
660                 if (req->cryptlen & blocksize_align)
661                         return -EINVAL;
662
663         return crypto_transfer_skcipher_request_to_engine(cryp->engine, req);
664 }
665
666 static int starfive_aes_aead_crypt(struct aead_request *req, unsigned long flags)
667 {
668         struct starfive_cryp_ctx *ctx = crypto_aead_ctx(crypto_aead_reqtfm(req));
669         struct starfive_cryp_dev *cryp = ctx->cryp;
670
671         cryp->flags = flags;
672
673         /*
674          * HW engine could not perform CCM tag verification on
675          * non-blocksize aligned text, use fallback algo instead
676          */
677         if (ctx->aead_fbk && !is_encrypt(cryp)) {
678                 struct aead_request *subreq = aead_request_ctx(req);
679
680                 aead_request_set_tfm(subreq, ctx->aead_fbk);
681                 aead_request_set_callback(subreq, req->base.flags,
682                                           req->base.complete, req->base.data);
683                 aead_request_set_crypt(subreq, req->src,
684                                        req->dst, req->cryptlen, req->iv);
685                 aead_request_set_ad(subreq, req->assoclen);
686
687                 return crypto_aead_decrypt(subreq);
688         }
689
690         return crypto_transfer_aead_request_to_engine(cryp->engine, req);
691 }
692
693 static int starfive_aes_setkey(struct crypto_skcipher *tfm, const u8 *key,
694                                unsigned int keylen)
695 {
696         struct starfive_cryp_ctx *ctx = crypto_skcipher_ctx(tfm);
697
698         if (!key || !keylen)
699                 return -EINVAL;
700
701         if (keylen != AES_KEYSIZE_128 &&
702             keylen != AES_KEYSIZE_192 &&
703             keylen != AES_KEYSIZE_256)
704                 return -EINVAL;
705
706         memcpy(ctx->key, key, keylen);
707         ctx->keylen = keylen;
708
709         return 0;
710 }
711
712 static int starfive_aes_aead_setkey(struct crypto_aead *tfm, const u8 *key,
713                                     unsigned int keylen)
714 {
715         struct starfive_cryp_ctx *ctx = crypto_aead_ctx(tfm);
716
717         if (!key || !keylen)
718                 return -EINVAL;
719
720         if (keylen != AES_KEYSIZE_128 &&
721             keylen != AES_KEYSIZE_192 &&
722             keylen != AES_KEYSIZE_256)
723                 return -EINVAL;
724
725         memcpy(ctx->key, key, keylen);
726         ctx->keylen = keylen;
727
728         if (ctx->aead_fbk)
729                 return crypto_aead_setkey(ctx->aead_fbk, key, keylen);
730
731         return 0;
732 }
733
/*
 * GCM setauthsize: validation is delegated to crypto_gcm_check_authsize().
 * Nothing is stored here; the per-request path reads the tag size with
 * crypto_aead_authsize().
 */
static int starfive_aes_gcm_setauthsize(struct crypto_aead *tfm,
					unsigned int authsize)
{
	int ret = crypto_gcm_check_authsize(authsize);

	return ret;
}
739
740 static int starfive_aes_ccm_setauthsize(struct crypto_aead *tfm,
741                                         unsigned int authsize)
742 {
743         struct starfive_cryp_ctx *ctx = crypto_aead_ctx(tfm);
744
745         switch (authsize) {
746         case 4:
747         case 6:
748         case 8:
749         case 10:
750         case 12:
751         case 14:
752         case 16:
753                 break;
754         default:
755                 return -EINVAL;
756         }
757
758         return crypto_aead_setauthsize(ctx->aead_fbk, authsize);
759 }
760
761 static int starfive_aes_ecb_encrypt(struct skcipher_request *req)
762 {
763         return starfive_aes_crypt(req, STARFIVE_AES_MODE_ECB | FLG_ENCRYPT);
764 }
765
766 static int starfive_aes_ecb_decrypt(struct skcipher_request *req)
767 {
768         return starfive_aes_crypt(req, STARFIVE_AES_MODE_ECB);
769 }
770
771 static int starfive_aes_cbc_encrypt(struct skcipher_request *req)
772 {
773         return starfive_aes_crypt(req, STARFIVE_AES_MODE_CBC | FLG_ENCRYPT);
774 }
775
776 static int starfive_aes_cbc_decrypt(struct skcipher_request *req)
777 {
778         return starfive_aes_crypt(req, STARFIVE_AES_MODE_CBC);
779 }
780
781 static int starfive_aes_ctr_encrypt(struct skcipher_request *req)
782 {
783         return starfive_aes_crypt(req, STARFIVE_AES_MODE_CTR | FLG_ENCRYPT);
784 }
785
786 static int starfive_aes_ctr_decrypt(struct skcipher_request *req)
787 {
788         return starfive_aes_crypt(req, STARFIVE_AES_MODE_CTR);
789 }
790
791 static int starfive_aes_gcm_encrypt(struct aead_request *req)
792 {
793         return starfive_aes_aead_crypt(req, STARFIVE_AES_MODE_GCM | FLG_ENCRYPT);
794 }
795
796 static int starfive_aes_gcm_decrypt(struct aead_request *req)
797 {
798         return starfive_aes_aead_crypt(req, STARFIVE_AES_MODE_GCM);
799 }
800
801 static int starfive_aes_ccm_encrypt(struct aead_request *req)
802 {
803         int ret;
804
805         ret = starfive_aes_ccm_check_iv(req->iv);
806         if (ret)
807                 return ret;
808
809         return starfive_aes_aead_crypt(req, STARFIVE_AES_MODE_CCM | FLG_ENCRYPT);
810 }
811
812 static int starfive_aes_ccm_decrypt(struct aead_request *req)
813 {
814         int ret;
815
816         ret = starfive_aes_ccm_check_iv(req->iv);
817         if (ret)
818                 return ret;
819
820         return starfive_aes_aead_crypt(req, STARFIVE_AES_MODE_CCM);
821 }
822
823 static struct skcipher_engine_alg skcipher_algs[] = {
824 {
825         .base.init                      = starfive_aes_init_tfm,
826         .base.setkey                    = starfive_aes_setkey,
827         .base.encrypt                   = starfive_aes_ecb_encrypt,
828         .base.decrypt                   = starfive_aes_ecb_decrypt,
829         .base.min_keysize               = AES_MIN_KEY_SIZE,
830         .base.max_keysize               = AES_MAX_KEY_SIZE,
831         .base.base = {
832                 .cra_name               = "ecb(aes)",
833                 .cra_driver_name        = "starfive-ecb-aes",
834                 .cra_priority           = 200,
835                 .cra_flags              = CRYPTO_ALG_ASYNC,
836                 .cra_blocksize          = AES_BLOCK_SIZE,
837                 .cra_ctxsize            = sizeof(struct starfive_cryp_ctx),
838                 .cra_alignmask          = 0xf,
839                 .cra_module             = THIS_MODULE,
840         },
841         .op = {
842                 .do_one_request = starfive_aes_do_one_req,
843         },
844 }, {
845         .base.init                      = starfive_aes_init_tfm,
846         .base.setkey                    = starfive_aes_setkey,
847         .base.encrypt                   = starfive_aes_cbc_encrypt,
848         .base.decrypt                   = starfive_aes_cbc_decrypt,
849         .base.min_keysize               = AES_MIN_KEY_SIZE,
850         .base.max_keysize               = AES_MAX_KEY_SIZE,
851         .base.ivsize                    = AES_BLOCK_SIZE,
852         .base.base = {
853                 .cra_name               = "cbc(aes)",
854                 .cra_driver_name        = "starfive-cbc-aes",
855                 .cra_priority           = 200,
856                 .cra_flags              = CRYPTO_ALG_ASYNC,
857                 .cra_blocksize          = AES_BLOCK_SIZE,
858                 .cra_ctxsize            = sizeof(struct starfive_cryp_ctx),
859                 .cra_alignmask          = 0xf,
860                 .cra_module             = THIS_MODULE,
861         },
862         .op = {
863                 .do_one_request = starfive_aes_do_one_req,
864         },
865 }, {
866         .base.init                      = starfive_aes_init_tfm,
867         .base.setkey                    = starfive_aes_setkey,
868         .base.encrypt                   = starfive_aes_ctr_encrypt,
869         .base.decrypt                   = starfive_aes_ctr_decrypt,
870         .base.min_keysize               = AES_MIN_KEY_SIZE,
871         .base.max_keysize               = AES_MAX_KEY_SIZE,
872         .base.ivsize                    = AES_BLOCK_SIZE,
873         .base.base = {
874                 .cra_name               = "ctr(aes)",
875                 .cra_driver_name        = "starfive-ctr-aes",
876                 .cra_priority           = 200,
877                 .cra_flags              = CRYPTO_ALG_ASYNC,
878                 .cra_blocksize          = 1,
879                 .cra_ctxsize            = sizeof(struct starfive_cryp_ctx),
880                 .cra_alignmask          = 0xf,
881                 .cra_module             = THIS_MODULE,
882         },
883         .op = {
884                 .do_one_request = starfive_aes_do_one_req,
885         },
886 },
887 };
888
889 static struct aead_engine_alg aead_algs[] = {
890 {
891         .base.setkey                    = starfive_aes_aead_setkey,
892         .base.setauthsize               = starfive_aes_gcm_setauthsize,
893         .base.encrypt                   = starfive_aes_gcm_encrypt,
894         .base.decrypt                   = starfive_aes_gcm_decrypt,
895         .base.init                      = starfive_aes_aead_init_tfm,
896         .base.exit                      = starfive_aes_aead_exit_tfm,
897         .base.ivsize                    = GCM_AES_IV_SIZE,
898         .base.maxauthsize               = AES_BLOCK_SIZE,
899         .base.base = {
900                 .cra_name               = "gcm(aes)",
901                 .cra_driver_name        = "starfive-gcm-aes",
902                 .cra_priority           = 200,
903                 .cra_flags              = CRYPTO_ALG_ASYNC,
904                 .cra_blocksize          = 1,
905                 .cra_ctxsize            = sizeof(struct starfive_cryp_ctx),
906                 .cra_alignmask          = 0xf,
907                 .cra_module             = THIS_MODULE,
908         },
909         .op = {
910                 .do_one_request = starfive_aes_aead_do_one_req,
911         },
912 }, {
913         .base.setkey                    = starfive_aes_aead_setkey,
914         .base.setauthsize               = starfive_aes_ccm_setauthsize,
915         .base.encrypt                   = starfive_aes_ccm_encrypt,
916         .base.decrypt                   = starfive_aes_ccm_decrypt,
917         .base.init                      = starfive_aes_aead_init_tfm,
918         .base.exit                      = starfive_aes_aead_exit_tfm,
919         .base.ivsize                    = AES_BLOCK_SIZE,
920         .base.maxauthsize               = AES_BLOCK_SIZE,
921         .base.base = {
922                 .cra_name               = "ccm(aes)",
923                 .cra_driver_name        = "starfive-ccm-aes",
924                 .cra_priority           = 200,
925                 .cra_flags              = CRYPTO_ALG_ASYNC |
926                                           CRYPTO_ALG_NEED_FALLBACK,
927                 .cra_blocksize          = 1,
928                 .cra_ctxsize            = sizeof(struct starfive_cryp_ctx),
929                 .cra_alignmask          = 0xf,
930                 .cra_module             = THIS_MODULE,
931         },
932         .op = {
933                 .do_one_request = starfive_aes_aead_do_one_req,
934         },
935 },
936 };
937
938 int starfive_aes_register_algs(void)
939 {
940         int ret;
941
942         ret = crypto_engine_register_skciphers(skcipher_algs, ARRAY_SIZE(skcipher_algs));
943         if (ret)
944                 return ret;
945
946         ret = crypto_engine_register_aeads(aead_algs, ARRAY_SIZE(aead_algs));
947         if (ret)
948                 crypto_engine_unregister_skciphers(skcipher_algs, ARRAY_SIZE(skcipher_algs));
949
950         return ret;
951 }
952
953 void starfive_aes_unregister_algs(void)
954 {
955         crypto_engine_unregister_aeads(aead_algs, ARRAY_SIZE(aead_algs));
956         crypto_engine_unregister_skciphers(skcipher_algs, ARRAY_SIZE(skcipher_algs));
957 }