Linux 6.7-rc7
[linux-modified.git] / arch / arm64 / crypto / sm4-ce-ccm-glue.c
1 /* SPDX-License-Identifier: GPL-2.0-or-later */
2 /*
3  * SM4-CCM AEAD Algorithm using ARMv8 Crypto Extensions
4  * as specified in rfc8998
5  * https://datatracker.ietf.org/doc/html/rfc8998
6  *
7  * Copyright (C) 2022 Tianjia Zhang <tianjia.zhang@linux.alibaba.com>
8  */
9
10 #include <linux/module.h>
11 #include <linux/crypto.h>
12 #include <linux/kernel.h>
13 #include <linux/cpufeature.h>
14 #include <asm/neon.h>
15 #include <crypto/scatterwalk.h>
16 #include <crypto/internal/aead.h>
17 #include <crypto/internal/skcipher.h>
18 #include <crypto/sm4.h>
19 #include "sm4-ce.h"
20
21 asmlinkage void sm4_ce_cbcmac_update(const u32 *rkey_enc, u8 *mac,
22                                      const u8 *src, unsigned int nblocks);
23 asmlinkage void sm4_ce_ccm_enc(const u32 *rkey_enc, u8 *dst, const u8 *src,
24                                u8 *iv, unsigned int nbytes, u8 *mac);
25 asmlinkage void sm4_ce_ccm_dec(const u32 *rkey_enc, u8 *dst, const u8 *src,
26                                u8 *iv, unsigned int nbytes, u8 *mac);
27 asmlinkage void sm4_ce_ccm_final(const u32 *rkey_enc, u8 *iv, u8 *mac);
28
29
30 static int ccm_setkey(struct crypto_aead *tfm, const u8 *key,
31                       unsigned int key_len)
32 {
33         struct sm4_ctx *ctx = crypto_aead_ctx(tfm);
34
35         if (key_len != SM4_KEY_SIZE)
36                 return -EINVAL;
37
38         kernel_neon_begin();
39         sm4_ce_expand_key(key, ctx->rkey_enc, ctx->rkey_dec,
40                           crypto_sm4_fk, crypto_sm4_ck);
41         kernel_neon_end();
42
43         return 0;
44 }
45
/*
 * ccm_setauthsize() - validate the requested CCM tag length
 * @tfm:      AEAD transform (unused)
 * @authsize: requested tag length M in bytes
 *
 * CCM tag lengths must be even and at least 4 bytes. This hook does not
 * check an upper bound; NOTE(review): presumably the crypto core rejects
 * values above .maxauthsize before calling it — confirm.
 *
 * Return: 0 if @authsize is acceptable, -EINVAL otherwise.
 */
static int ccm_setauthsize(struct crypto_aead *tfm, unsigned int authsize)
{
	if (authsize >= 4 && authsize % 2 == 0)
		return 0;

	return -EINVAL;
}
52
53 static int ccm_format_input(u8 info[], struct aead_request *req,
54                             unsigned int msglen)
55 {
56         struct crypto_aead *aead = crypto_aead_reqtfm(req);
57         unsigned int l = req->iv[0] + 1;
58         unsigned int m;
59         __be32 len;
60
61         /* verify that CCM dimension 'L': 2 <= L <= 8 */
62         if (l < 2 || l > 8)
63                 return -EINVAL;
64         if (l < 4 && msglen >> (8 * l))
65                 return -EOVERFLOW;
66
67         memset(&req->iv[SM4_BLOCK_SIZE - l], 0, l);
68
69         memcpy(info, req->iv, SM4_BLOCK_SIZE);
70
71         m = crypto_aead_authsize(aead);
72
73         /* format flags field per RFC 3610/NIST 800-38C */
74         *info |= ((m - 2) / 2) << 3;
75         if (req->assoclen)
76                 *info |= (1 << 6);
77
78         /*
79          * format message length field,
80          * Linux uses a u32 type to represent msglen
81          */
82         if (l >= 4)
83                 l = 4;
84
85         len = cpu_to_be32(msglen);
86         memcpy(&info[SM4_BLOCK_SIZE - l], (u8 *)&len + 4 - l, l);
87
88         return 0;
89 }
90
91 static void ccm_calculate_auth_mac(struct aead_request *req, u8 mac[])
92 {
93         struct crypto_aead *aead = crypto_aead_reqtfm(req);
94         struct sm4_ctx *ctx = crypto_aead_ctx(aead);
95         struct __packed { __be16 l; __be32 h; } aadlen;
96         u32 assoclen = req->assoclen;
97         struct scatter_walk walk;
98         unsigned int len;
99
100         if (assoclen < 0xff00) {
101                 aadlen.l = cpu_to_be16(assoclen);
102                 len = 2;
103         } else {
104                 aadlen.l = cpu_to_be16(0xfffe);
105                 put_unaligned_be32(assoclen, &aadlen.h);
106                 len = 6;
107         }
108
109         sm4_ce_crypt_block(ctx->rkey_enc, mac, mac);
110         crypto_xor(mac, (const u8 *)&aadlen, len);
111
112         scatterwalk_start(&walk, req->src);
113
114         do {
115                 u32 n = scatterwalk_clamp(&walk, assoclen);
116                 u8 *p, *ptr;
117
118                 if (!n) {
119                         scatterwalk_start(&walk, sg_next(walk.sg));
120                         n = scatterwalk_clamp(&walk, assoclen);
121                 }
122
123                 p = ptr = scatterwalk_map(&walk);
124                 assoclen -= n;
125                 scatterwalk_advance(&walk, n);
126
127                 while (n > 0) {
128                         unsigned int l, nblocks;
129
130                         if (len == SM4_BLOCK_SIZE) {
131                                 if (n < SM4_BLOCK_SIZE) {
132                                         sm4_ce_crypt_block(ctx->rkey_enc,
133                                                            mac, mac);
134
135                                         len = 0;
136                                 } else {
137                                         nblocks = n / SM4_BLOCK_SIZE;
138                                         sm4_ce_cbcmac_update(ctx->rkey_enc,
139                                                              mac, ptr, nblocks);
140
141                                         ptr += nblocks * SM4_BLOCK_SIZE;
142                                         n %= SM4_BLOCK_SIZE;
143
144                                         continue;
145                                 }
146                         }
147
148                         l = min(n, SM4_BLOCK_SIZE - len);
149                         if (l) {
150                                 crypto_xor(mac + len, ptr, l);
151                                 len += l;
152                                 ptr += l;
153                                 n -= l;
154                         }
155                 }
156
157                 scatterwalk_unmap(p);
158                 scatterwalk_done(&walk, 0, assoclen);
159         } while (assoclen);
160 }
161
162 static int ccm_crypt(struct aead_request *req, struct skcipher_walk *walk,
163                      u32 *rkey_enc, u8 mac[],
164                      void (*sm4_ce_ccm_crypt)(const u32 *rkey_enc, u8 *dst,
165                                         const u8 *src, u8 *iv,
166                                         unsigned int nbytes, u8 *mac))
167 {
168         u8 __aligned(8) ctr0[SM4_BLOCK_SIZE];
169         int err = 0;
170
171         /* preserve the initial ctr0 for the TAG */
172         memcpy(ctr0, walk->iv, SM4_BLOCK_SIZE);
173         crypto_inc(walk->iv, SM4_BLOCK_SIZE);
174
175         kernel_neon_begin();
176
177         if (req->assoclen)
178                 ccm_calculate_auth_mac(req, mac);
179
180         while (walk->nbytes && walk->nbytes != walk->total) {
181                 unsigned int tail = walk->nbytes % SM4_BLOCK_SIZE;
182
183                 sm4_ce_ccm_crypt(rkey_enc, walk->dst.virt.addr,
184                                  walk->src.virt.addr, walk->iv,
185                                  walk->nbytes - tail, mac);
186
187                 kernel_neon_end();
188
189                 err = skcipher_walk_done(walk, tail);
190
191                 kernel_neon_begin();
192         }
193
194         if (walk->nbytes) {
195                 sm4_ce_ccm_crypt(rkey_enc, walk->dst.virt.addr,
196                                  walk->src.virt.addr, walk->iv,
197                                  walk->nbytes, mac);
198
199                 sm4_ce_ccm_final(rkey_enc, ctr0, mac);
200
201                 kernel_neon_end();
202
203                 err = skcipher_walk_done(walk, 0);
204         } else {
205                 sm4_ce_ccm_final(rkey_enc, ctr0, mac);
206
207                 kernel_neon_end();
208         }
209
210         return err;
211 }
212
213 static int ccm_encrypt(struct aead_request *req)
214 {
215         struct crypto_aead *aead = crypto_aead_reqtfm(req);
216         struct sm4_ctx *ctx = crypto_aead_ctx(aead);
217         u8 __aligned(8) mac[SM4_BLOCK_SIZE];
218         struct skcipher_walk walk;
219         int err;
220
221         err = ccm_format_input(mac, req, req->cryptlen);
222         if (err)
223                 return err;
224
225         err = skcipher_walk_aead_encrypt(&walk, req, false);
226         if (err)
227                 return err;
228
229         err = ccm_crypt(req, &walk, ctx->rkey_enc, mac, sm4_ce_ccm_enc);
230         if (err)
231                 return err;
232
233         /* copy authtag to end of dst */
234         scatterwalk_map_and_copy(mac, req->dst, req->assoclen + req->cryptlen,
235                                  crypto_aead_authsize(aead), 1);
236
237         return 0;
238 }
239
240 static int ccm_decrypt(struct aead_request *req)
241 {
242         struct crypto_aead *aead = crypto_aead_reqtfm(req);
243         unsigned int authsize = crypto_aead_authsize(aead);
244         struct sm4_ctx *ctx = crypto_aead_ctx(aead);
245         u8 __aligned(8) mac[SM4_BLOCK_SIZE];
246         u8 authtag[SM4_BLOCK_SIZE];
247         struct skcipher_walk walk;
248         int err;
249
250         err = ccm_format_input(mac, req, req->cryptlen - authsize);
251         if (err)
252                 return err;
253
254         err = skcipher_walk_aead_decrypt(&walk, req, false);
255         if (err)
256                 return err;
257
258         err = ccm_crypt(req, &walk, ctx->rkey_enc, mac, sm4_ce_ccm_dec);
259         if (err)
260                 return err;
261
262         /* compare calculated auth tag with the stored one */
263         scatterwalk_map_and_copy(authtag, req->src,
264                                  req->assoclen + req->cryptlen - authsize,
265                                  authsize, 0);
266
267         if (crypto_memneq(authtag, mac, authsize))
268                 return -EBADMSG;
269
270         return 0;
271 }
272
273 static struct aead_alg sm4_ccm_alg = {
274         .base = {
275                 .cra_name               = "ccm(sm4)",
276                 .cra_driver_name        = "ccm-sm4-ce",
277                 .cra_priority           = 400,
278                 .cra_blocksize          = 1,
279                 .cra_ctxsize            = sizeof(struct sm4_ctx),
280                 .cra_module             = THIS_MODULE,
281         },
282         .ivsize         = SM4_BLOCK_SIZE,
283         .chunksize      = SM4_BLOCK_SIZE,
284         .maxauthsize    = SM4_BLOCK_SIZE,
285         .setkey         = ccm_setkey,
286         .setauthsize    = ccm_setauthsize,
287         .encrypt        = ccm_encrypt,
288         .decrypt        = ccm_decrypt,
289 };
290
291 static int __init sm4_ce_ccm_init(void)
292 {
293         return crypto_register_aead(&sm4_ccm_alg);
294 }
295
296 static void __exit sm4_ce_ccm_exit(void)
297 {
298         crypto_unregister_aead(&sm4_ccm_alg);
299 }
300
301 module_cpu_feature_match(SM4, sm4_ce_ccm_init);
302 module_exit(sm4_ce_ccm_exit);
303
304 MODULE_DESCRIPTION("Synchronous SM4 in CCM mode using ARMv8 Crypto Extensions");
305 MODULE_ALIAS_CRYPTO("ccm(sm4)");
306 MODULE_AUTHOR("Tianjia Zhang <tianjia.zhang@linux.alibaba.com>");
307 MODULE_LICENSE("GPL v2");