GNU Linux-libre 6.8.9-gnu
[releases.git] / drivers / crypto / intel / keembay / keembay-ocs-ecc.c
1 // SPDX-License-Identifier: GPL-2.0-only
2 /*
3  * Intel Keem Bay OCS ECC Crypto Driver.
4  *
5  * Copyright (C) 2019-2021 Intel Corporation
6  */
7
8 #define pr_fmt(fmt) KBUILD_MODNAME ": " fmt
9
10 #include <crypto/ecc_curve.h>
11 #include <crypto/ecdh.h>
12 #include <crypto/engine.h>
13 #include <crypto/internal/ecc.h>
14 #include <crypto/internal/kpp.h>
15 #include <crypto/kpp.h>
16 #include <crypto/rng.h>
17 #include <linux/clk.h>
18 #include <linux/completion.h>
19 #include <linux/err.h>
20 #include <linux/fips.h>
21 #include <linux/interrupt.h>
22 #include <linux/io.h>
23 #include <linux/iopoll.h>
24 #include <linux/irq.h>
25 #include <linux/kernel.h>
26 #include <linux/module.h>
27 #include <linux/of.h>
28 #include <linux/platform_device.h>
29 #include <linux/scatterlist.h>
30 #include <linux/string.h>
31
32 #define DRV_NAME                        "keembay-ocs-ecc"
33
34 #define KMB_OCS_ECC_PRIORITY            350
35
36 #define HW_OFFS_OCS_ECC_COMMAND         0x00000000
37 #define HW_OFFS_OCS_ECC_STATUS          0x00000004
38 #define HW_OFFS_OCS_ECC_DATA_IN         0x00000080
39 #define HW_OFFS_OCS_ECC_CX_DATA_OUT     0x00000100
40 #define HW_OFFS_OCS_ECC_CY_DATA_OUT     0x00000180
41 #define HW_OFFS_OCS_ECC_ISR             0x00000400
42 #define HW_OFFS_OCS_ECC_IER             0x00000404
43
44 #define HW_OCS_ECC_ISR_INT_STATUS_DONE  BIT(0)
45 #define HW_OCS_ECC_COMMAND_INS_BP       BIT(0)
46
47 #define HW_OCS_ECC_COMMAND_START_VAL    BIT(0)
48
49 #define OCS_ECC_OP_SIZE_384             BIT(8)
50 #define OCS_ECC_OP_SIZE_256             0
51
52 /* ECC Instruction : for ECC_COMMAND */
53 #define OCS_ECC_INST_WRITE_AX           (0x1 << HW_OCS_ECC_COMMAND_INS_BP)
54 #define OCS_ECC_INST_WRITE_AY           (0x2 << HW_OCS_ECC_COMMAND_INS_BP)
55 #define OCS_ECC_INST_WRITE_BX_D         (0x3 << HW_OCS_ECC_COMMAND_INS_BP)
56 #define OCS_ECC_INST_WRITE_BY_L         (0x4 << HW_OCS_ECC_COMMAND_INS_BP)
57 #define OCS_ECC_INST_WRITE_P            (0x5 << HW_OCS_ECC_COMMAND_INS_BP)
58 #define OCS_ECC_INST_WRITE_A            (0x6 << HW_OCS_ECC_COMMAND_INS_BP)
59 #define OCS_ECC_INST_CALC_D_IDX_A       (0x8 << HW_OCS_ECC_COMMAND_INS_BP)
60 #define OCS_ECC_INST_CALC_A_POW_B_MODP  (0xB << HW_OCS_ECC_COMMAND_INS_BP)
61 #define OCS_ECC_INST_CALC_A_MUL_B_MODP  (0xC  << HW_OCS_ECC_COMMAND_INS_BP)
62 #define OCS_ECC_INST_CALC_A_ADD_B_MODP  (0xD << HW_OCS_ECC_COMMAND_INS_BP)
63
64 #define ECC_ENABLE_INTR                 1
65
66 #define POLL_USEC                       100
67 #define TIMEOUT_USEC                    10000
68
69 #define KMB_ECC_VLI_MAX_DIGITS          ECC_CURVE_NIST_P384_DIGITS
70 #define KMB_ECC_VLI_MAX_BYTES           (KMB_ECC_VLI_MAX_DIGITS \
71                                          << ECC_DIGITS_TO_BYTES_SHIFT)
72
73 #define POW_CUBE                        3
74
75 /**
76  * struct ocs_ecc_dev - ECC device context
77  * @list: List of device contexts
78  * @dev: OCS ECC device
79  * @base_reg: IO base address of OCS ECC
80  * @engine: Crypto engine for the device
81  * @irq_done: IRQ done completion.
82  * @irq: IRQ number
83  */
84 struct ocs_ecc_dev {
85         struct list_head list;
86         struct device *dev;
87         void __iomem *base_reg;
88         struct crypto_engine *engine;
89         struct completion irq_done;
90         int irq;
91 };
92
93 /**
94  * struct ocs_ecc_ctx - Transformation context.
95  * @ecc_dev:     The ECC driver associated with this context.
96  * @curve:       The elliptic curve used by this transformation.
97  * @private_key: The private key.
98  */
99 struct ocs_ecc_ctx {
100         struct ocs_ecc_dev *ecc_dev;
101         const struct ecc_curve *curve;
102         u64 private_key[KMB_ECC_VLI_MAX_DIGITS];
103 };
104
105 /* Driver data. */
106 struct ocs_ecc_drv {
107         struct list_head dev_list;
108         spinlock_t lock;        /* Protects dev_list. */
109 };
110
111 /* Global variable holding the list of OCS ECC devices (only one expected). */
112 static struct ocs_ecc_drv ocs_ecc = {
113         .dev_list = LIST_HEAD_INIT(ocs_ecc.dev_list),
114         .lock = __SPIN_LOCK_UNLOCKED(ocs_ecc.lock),
115 };
116
117 /* Get OCS ECC tfm context from kpp_request. */
118 static inline struct ocs_ecc_ctx *kmb_ocs_ecc_tctx(struct kpp_request *req)
119 {
120         return kpp_tfm_ctx(crypto_kpp_reqtfm(req));
121 }
122
123 /* Converts number of digits to number of bytes. */
124 static inline unsigned int digits_to_bytes(unsigned int n)
125 {
126         return n << ECC_DIGITS_TO_BYTES_SHIFT;
127 }
128
129 /*
130  * Wait for ECC idle i.e when an operation (other than write operations)
131  * is done.
132  */
133 static inline int ocs_ecc_wait_idle(struct ocs_ecc_dev *dev)
134 {
135         u32 value;
136
137         return readl_poll_timeout((dev->base_reg + HW_OFFS_OCS_ECC_STATUS),
138                                   value,
139                                   !(value & HW_OCS_ECC_ISR_INT_STATUS_DONE),
140                                   POLL_USEC, TIMEOUT_USEC);
141 }
142
143 static void ocs_ecc_cmd_start(struct ocs_ecc_dev *ecc_dev, u32 op_size)
144 {
145         iowrite32(op_size | HW_OCS_ECC_COMMAND_START_VAL,
146                   ecc_dev->base_reg + HW_OFFS_OCS_ECC_COMMAND);
147 }
148
149 /* Direct write of u32 buffer to ECC engine with associated instruction. */
150 static void ocs_ecc_write_cmd_and_data(struct ocs_ecc_dev *dev,
151                                        u32 op_size,
152                                        u32 inst,
153                                        const void *data_in,
154                                        size_t data_size)
155 {
156         iowrite32(op_size | inst, dev->base_reg + HW_OFFS_OCS_ECC_COMMAND);
157
158         /* MMIO Write src uint32 to dst. */
159         memcpy_toio(dev->base_reg + HW_OFFS_OCS_ECC_DATA_IN, data_in,
160                     data_size);
161 }
162
163 /* Start OCS ECC operation and wait for its completion. */
164 static int ocs_ecc_trigger_op(struct ocs_ecc_dev *ecc_dev, u32 op_size,
165                               u32 inst)
166 {
167         reinit_completion(&ecc_dev->irq_done);
168
169         iowrite32(ECC_ENABLE_INTR, ecc_dev->base_reg + HW_OFFS_OCS_ECC_IER);
170         iowrite32(op_size | inst, ecc_dev->base_reg + HW_OFFS_OCS_ECC_COMMAND);
171
172         return wait_for_completion_interruptible(&ecc_dev->irq_done);
173 }
174
175 /**
176  * ocs_ecc_read_cx_out() - Read the CX data output buffer.
177  * @dev:        The OCS ECC device to read from.
178  * @cx_out:     The buffer where to store the CX value. Must be at least
179  *              @byte_count byte long.
180  * @byte_count: The amount of data to read.
181  */
182 static inline void ocs_ecc_read_cx_out(struct ocs_ecc_dev *dev, void *cx_out,
183                                        size_t byte_count)
184 {
185         memcpy_fromio(cx_out, dev->base_reg + HW_OFFS_OCS_ECC_CX_DATA_OUT,
186                       byte_count);
187 }
188
189 /**
190  * ocs_ecc_read_cy_out() - Read the CX data output buffer.
191  * @dev:        The OCS ECC device to read from.
192  * @cy_out:     The buffer where to store the CY value. Must be at least
193  *              @byte_count byte long.
194  * @byte_count: The amount of data to read.
195  */
196 static inline void ocs_ecc_read_cy_out(struct ocs_ecc_dev *dev, void *cy_out,
197                                        size_t byte_count)
198 {
199         memcpy_fromio(cy_out, dev->base_reg + HW_OFFS_OCS_ECC_CY_DATA_OUT,
200                       byte_count);
201 }
202
203 static struct ocs_ecc_dev *kmb_ocs_ecc_find_dev(struct ocs_ecc_ctx *tctx)
204 {
205         if (tctx->ecc_dev)
206                 return tctx->ecc_dev;
207
208         spin_lock(&ocs_ecc.lock);
209
210         /* Only a single OCS device available. */
211         tctx->ecc_dev = list_first_entry(&ocs_ecc.dev_list, struct ocs_ecc_dev,
212                                          list);
213
214         spin_unlock(&ocs_ecc.lock);
215
216         return tctx->ecc_dev;
217 }
218
219 /* Do point multiplication using OCS ECC HW. */
220 static int kmb_ecc_point_mult(struct ocs_ecc_dev *ecc_dev,
221                               struct ecc_point *result,
222                               const struct ecc_point *point,
223                               u64 *scalar,
224                               const struct ecc_curve *curve)
225 {
226         u8 sca[KMB_ECC_VLI_MAX_BYTES]; /* Use the maximum data size. */
227         u32 op_size = (curve->g.ndigits > ECC_CURVE_NIST_P256_DIGITS) ?
228                       OCS_ECC_OP_SIZE_384 : OCS_ECC_OP_SIZE_256;
229         size_t nbytes = digits_to_bytes(curve->g.ndigits);
230         int rc = 0;
231
232         /* Generate random nbytes for Simple and Differential SCA protection. */
233         rc = crypto_get_default_rng();
234         if (rc)
235                 return rc;
236
237         rc = crypto_rng_get_bytes(crypto_default_rng, sca, nbytes);
238         crypto_put_default_rng();
239         if (rc)
240                 return rc;
241
242         /* Wait engine to be idle before starting new operation. */
243         rc = ocs_ecc_wait_idle(ecc_dev);
244         if (rc)
245                 return rc;
246
247         /* Send ecc_start pulse as well as indicating operation size. */
248         ocs_ecc_cmd_start(ecc_dev, op_size);
249
250         /* Write ax param; Base point (Gx). */
251         ocs_ecc_write_cmd_and_data(ecc_dev, op_size, OCS_ECC_INST_WRITE_AX,
252                                    point->x, nbytes);
253
254         /* Write ay param; Base point (Gy). */
255         ocs_ecc_write_cmd_and_data(ecc_dev, op_size, OCS_ECC_INST_WRITE_AY,
256                                    point->y, nbytes);
257
258         /*
259          * Write the private key into DATA_IN reg.
260          *
261          * Since DATA_IN register is used to write different values during the
262          * computation private Key value is overwritten with
263          * side-channel-resistance value.
264          */
265         ocs_ecc_write_cmd_and_data(ecc_dev, op_size, OCS_ECC_INST_WRITE_BX_D,
266                                    scalar, nbytes);
267
268         /* Write operand by/l. */
269         ocs_ecc_write_cmd_and_data(ecc_dev, op_size, OCS_ECC_INST_WRITE_BY_L,
270                                    sca, nbytes);
271         memzero_explicit(sca, sizeof(sca));
272
273         /* Write p = curve prime(GF modulus). */
274         ocs_ecc_write_cmd_and_data(ecc_dev, op_size, OCS_ECC_INST_WRITE_P,
275                                    curve->p, nbytes);
276
277         /* Write a = curve coefficient. */
278         ocs_ecc_write_cmd_and_data(ecc_dev, op_size, OCS_ECC_INST_WRITE_A,
279                                    curve->a, nbytes);
280
281         /* Make hardware perform the multiplication. */
282         rc = ocs_ecc_trigger_op(ecc_dev, op_size, OCS_ECC_INST_CALC_D_IDX_A);
283         if (rc)
284                 return rc;
285
286         /* Read result. */
287         ocs_ecc_read_cx_out(ecc_dev, result->x, nbytes);
288         ocs_ecc_read_cy_out(ecc_dev, result->y, nbytes);
289
290         return 0;
291 }
292
293 /**
294  * kmb_ecc_do_scalar_op() - Perform Scalar operation using OCS ECC HW.
295  * @ecc_dev:    The OCS ECC device to use.
296  * @scalar_out: Where to store the output scalar.
297  * @scalar_a:   Input scalar operand 'a'.
298  * @scalar_b:   Input scalar operand 'b'
299  * @curve:      The curve on which the operation is performed.
300  * @ndigits:    The size of the operands (in digits).
301  * @inst:       The operation to perform (as an OCS ECC instruction).
302  *
303  * Return:      0 on success, negative error code otherwise.
304  */
305 static int kmb_ecc_do_scalar_op(struct ocs_ecc_dev *ecc_dev, u64 *scalar_out,
306                                 const u64 *scalar_a, const u64 *scalar_b,
307                                 const struct ecc_curve *curve,
308                                 unsigned int ndigits, const u32 inst)
309 {
310         u32 op_size = (ndigits > ECC_CURVE_NIST_P256_DIGITS) ?
311                       OCS_ECC_OP_SIZE_384 : OCS_ECC_OP_SIZE_256;
312         size_t nbytes = digits_to_bytes(ndigits);
313         int rc;
314
315         /* Wait engine to be idle before starting new operation. */
316         rc = ocs_ecc_wait_idle(ecc_dev);
317         if (rc)
318                 return rc;
319
320         /* Send ecc_start pulse as well as indicating operation size. */
321         ocs_ecc_cmd_start(ecc_dev, op_size);
322
323         /* Write ax param (Base point (Gx).*/
324         ocs_ecc_write_cmd_and_data(ecc_dev, op_size, OCS_ECC_INST_WRITE_AX,
325                                    scalar_a, nbytes);
326
327         /* Write ay param Base point (Gy).*/
328         ocs_ecc_write_cmd_and_data(ecc_dev, op_size, OCS_ECC_INST_WRITE_AY,
329                                    scalar_b, nbytes);
330
331         /* Write p = curve prime(GF modulus).*/
332         ocs_ecc_write_cmd_and_data(ecc_dev, op_size, OCS_ECC_INST_WRITE_P,
333                                    curve->p, nbytes);
334
335         /* Give instruction A.B or A+B to ECC engine. */
336         rc = ocs_ecc_trigger_op(ecc_dev, op_size, inst);
337         if (rc)
338                 return rc;
339
340         ocs_ecc_read_cx_out(ecc_dev, scalar_out, nbytes);
341
342         if (vli_is_zero(scalar_out, ndigits))
343                 return -EINVAL;
344
345         return 0;
346 }
347
348 /* SP800-56A section 5.6.2.3.4 partial verification: ephemeral keys only */
349 static int kmb_ocs_ecc_is_pubkey_valid_partial(struct ocs_ecc_dev *ecc_dev,
350                                                const struct ecc_curve *curve,
351                                                struct ecc_point *pk)
352 {
353         u64 xxx[KMB_ECC_VLI_MAX_DIGITS] = { 0 };
354         u64 yy[KMB_ECC_VLI_MAX_DIGITS] = { 0 };
355         u64 w[KMB_ECC_VLI_MAX_DIGITS] = { 0 };
356         int rc;
357
358         if (WARN_ON(pk->ndigits != curve->g.ndigits))
359                 return -EINVAL;
360
361         /* Check 1: Verify key is not the zero point. */
362         if (ecc_point_is_zero(pk))
363                 return -EINVAL;
364
365         /* Check 2: Verify key is in the range [0, p-1]. */
366         if (vli_cmp(curve->p, pk->x, pk->ndigits) != 1)
367                 return -EINVAL;
368
369         if (vli_cmp(curve->p, pk->y, pk->ndigits) != 1)
370                 return -EINVAL;
371
372         /* Check 3: Verify that y^2 == (x^3 + a·x + b) mod p */
373
374          /* y^2 */
375         /* Compute y^2 -> store in yy */
376         rc = kmb_ecc_do_scalar_op(ecc_dev, yy, pk->y, pk->y, curve, pk->ndigits,
377                                   OCS_ECC_INST_CALC_A_MUL_B_MODP);
378         if (rc)
379                 goto exit;
380
381         /* x^3 */
382         /* Assigning w = 3, used for calculating x^3. */
383         w[0] = POW_CUBE;
384         /* Load the next stage.*/
385         rc = kmb_ecc_do_scalar_op(ecc_dev, xxx, pk->x, w, curve, pk->ndigits,
386                                   OCS_ECC_INST_CALC_A_POW_B_MODP);
387         if (rc)
388                 goto exit;
389
390         /* Do a*x -> store in w. */
391         rc = kmb_ecc_do_scalar_op(ecc_dev, w, curve->a, pk->x, curve,
392                                   pk->ndigits,
393                                   OCS_ECC_INST_CALC_A_MUL_B_MODP);
394         if (rc)
395                 goto exit;
396
397         /* Do ax + b == w + b; store in w. */
398         rc = kmb_ecc_do_scalar_op(ecc_dev, w, w, curve->b, curve,
399                                   pk->ndigits,
400                                   OCS_ECC_INST_CALC_A_ADD_B_MODP);
401         if (rc)
402                 goto exit;
403
404         /* x^3 + ax + b == x^3 + w -> store in w. */
405         rc = kmb_ecc_do_scalar_op(ecc_dev, w, xxx, w, curve, pk->ndigits,
406                                   OCS_ECC_INST_CALC_A_ADD_B_MODP);
407         if (rc)
408                 goto exit;
409
410         /* Compare y^2 == x^3 + a·x + b. */
411         rc = vli_cmp(yy, w, pk->ndigits);
412         if (rc)
413                 rc = -EINVAL;
414
415 exit:
416         memzero_explicit(xxx, sizeof(xxx));
417         memzero_explicit(yy, sizeof(yy));
418         memzero_explicit(w, sizeof(w));
419
420         return rc;
421 }
422
423 /* SP800-56A section 5.6.2.3.3 full verification */
424 static int kmb_ocs_ecc_is_pubkey_valid_full(struct ocs_ecc_dev *ecc_dev,
425                                             const struct ecc_curve *curve,
426                                             struct ecc_point *pk)
427 {
428         struct ecc_point *nQ;
429         int rc;
430
431         /* Checks 1 through 3 */
432         rc = kmb_ocs_ecc_is_pubkey_valid_partial(ecc_dev, curve, pk);
433         if (rc)
434                 return rc;
435
436         /* Check 4: Verify that nQ is the zero point. */
437         nQ = ecc_alloc_point(pk->ndigits);
438         if (!nQ)
439                 return -ENOMEM;
440
441         rc = kmb_ecc_point_mult(ecc_dev, nQ, pk, curve->n, curve);
442         if (rc)
443                 goto exit;
444
445         if (!ecc_point_is_zero(nQ))
446                 rc = -EINVAL;
447
448 exit:
449         ecc_free_point(nQ);
450
451         return rc;
452 }
453
454 static int kmb_ecc_is_key_valid(const struct ecc_curve *curve,
455                                 const u64 *private_key, size_t private_key_len)
456 {
457         size_t ndigits = curve->g.ndigits;
458         u64 one[KMB_ECC_VLI_MAX_DIGITS] = {1};
459         u64 res[KMB_ECC_VLI_MAX_DIGITS];
460
461         if (private_key_len != digits_to_bytes(ndigits))
462                 return -EINVAL;
463
464         if (!private_key)
465                 return -EINVAL;
466
467         /* Make sure the private key is in the range [2, n-3]. */
468         if (vli_cmp(one, private_key, ndigits) != -1)
469                 return -EINVAL;
470
471         vli_sub(res, curve->n, one, ndigits);
472         vli_sub(res, res, one, ndigits);
473         if (vli_cmp(res, private_key, ndigits) != 1)
474                 return -EINVAL;
475
476         return 0;
477 }
478
479 /*
480  * ECC private keys are generated using the method of extra random bits,
481  * equivalent to that described in FIPS 186-4, Appendix B.4.1.
482  *
483  * d = (c mod(n–1)) + 1    where c is a string of random bits, 64 bits longer
484  *                         than requested
485  * 0 <= c mod(n-1) <= n-2  and implies that
486  * 1 <= d <= n-1
487  *
488  * This method generates a private key uniformly distributed in the range
489  * [1, n-1].
490  */
491 static int kmb_ecc_gen_privkey(const struct ecc_curve *curve, u64 *privkey)
492 {
493         size_t nbytes = digits_to_bytes(curve->g.ndigits);
494         u64 priv[KMB_ECC_VLI_MAX_DIGITS];
495         size_t nbits;
496         int rc;
497
498         nbits = vli_num_bits(curve->n, curve->g.ndigits);
499
500         /* Check that N is included in Table 1 of FIPS 186-4, section 6.1.1 */
501         if (nbits < 160 || curve->g.ndigits > ARRAY_SIZE(priv))
502                 return -EINVAL;
503
504         /*
505          * FIPS 186-4 recommends that the private key should be obtained from a
506          * RBG with a security strength equal to or greater than the security
507          * strength associated with N.
508          *
509          * The maximum security strength identified by NIST SP800-57pt1r4 for
510          * ECC is 256 (N >= 512).
511          *
512          * This condition is met by the default RNG because it selects a favored
513          * DRBG with a security strength of 256.
514          */
515         if (crypto_get_default_rng())
516                 return -EFAULT;
517
518         rc = crypto_rng_get_bytes(crypto_default_rng, (u8 *)priv, nbytes);
519         crypto_put_default_rng();
520         if (rc)
521                 goto cleanup;
522
523         rc = kmb_ecc_is_key_valid(curve, priv, nbytes);
524         if (rc)
525                 goto cleanup;
526
527         ecc_swap_digits(priv, privkey, curve->g.ndigits);
528
529 cleanup:
530         memzero_explicit(&priv, sizeof(priv));
531
532         return rc;
533 }
534
535 static int kmb_ocs_ecdh_set_secret(struct crypto_kpp *tfm, const void *buf,
536                                    unsigned int len)
537 {
538         struct ocs_ecc_ctx *tctx = kpp_tfm_ctx(tfm);
539         struct ecdh params;
540         int rc = 0;
541
542         rc = crypto_ecdh_decode_key(buf, len, &params);
543         if (rc)
544                 goto cleanup;
545
546         /* Ensure key size is not bigger then expected. */
547         if (params.key_size > digits_to_bytes(tctx->curve->g.ndigits)) {
548                 rc = -EINVAL;
549                 goto cleanup;
550         }
551
552         /* Auto-generate private key is not provided. */
553         if (!params.key || !params.key_size) {
554                 rc = kmb_ecc_gen_privkey(tctx->curve, tctx->private_key);
555                 goto cleanup;
556         }
557
558         rc = kmb_ecc_is_key_valid(tctx->curve, (const u64 *)params.key,
559                                   params.key_size);
560         if (rc)
561                 goto cleanup;
562
563         ecc_swap_digits((const u64 *)params.key, tctx->private_key,
564                         tctx->curve->g.ndigits);
565 cleanup:
566         memzero_explicit(&params, sizeof(params));
567
568         if (rc)
569                 tctx->curve = NULL;
570
571         return rc;
572 }
573
574 /* Compute shared secret. */
575 static int kmb_ecc_do_shared_secret(struct ocs_ecc_ctx *tctx,
576                                     struct kpp_request *req)
577 {
578         struct ocs_ecc_dev *ecc_dev = tctx->ecc_dev;
579         const struct ecc_curve *curve = tctx->curve;
580         u64 shared_secret[KMB_ECC_VLI_MAX_DIGITS];
581         u64 pubk_buf[KMB_ECC_VLI_MAX_DIGITS * 2];
582         size_t copied, nbytes, pubk_len;
583         struct ecc_point *pk, *result;
584         int rc;
585
586         nbytes = digits_to_bytes(curve->g.ndigits);
587
588         /* Public key is a point, thus it has two coordinates */
589         pubk_len = 2 * nbytes;
590
591         /* Copy public key from SG list to pubk_buf. */
592         copied = sg_copy_to_buffer(req->src,
593                                    sg_nents_for_len(req->src, pubk_len),
594                                    pubk_buf, pubk_len);
595         if (copied != pubk_len)
596                 return -EINVAL;
597
598         /* Allocate and initialize public key point. */
599         pk = ecc_alloc_point(curve->g.ndigits);
600         if (!pk)
601                 return -ENOMEM;
602
603         ecc_swap_digits(pubk_buf, pk->x, curve->g.ndigits);
604         ecc_swap_digits(&pubk_buf[curve->g.ndigits], pk->y, curve->g.ndigits);
605
606         /*
607          * Check the public key for following
608          * Check 1: Verify key is not the zero point.
609          * Check 2: Verify key is in the range [1, p-1].
610          * Check 3: Verify that y^2 == (x^3 + a·x + b) mod p
611          */
612         rc = kmb_ocs_ecc_is_pubkey_valid_partial(ecc_dev, curve, pk);
613         if (rc)
614                 goto exit_free_pk;
615
616         /* Allocate point for storing computed shared secret. */
617         result = ecc_alloc_point(pk->ndigits);
618         if (!result) {
619                 rc = -ENOMEM;
620                 goto exit_free_pk;
621         }
622
623         /* Calculate the shared secret.*/
624         rc = kmb_ecc_point_mult(ecc_dev, result, pk, tctx->private_key, curve);
625         if (rc)
626                 goto exit_free_result;
627
628         if (ecc_point_is_zero(result)) {
629                 rc = -EFAULT;
630                 goto exit_free_result;
631         }
632
633         /* Copy shared secret from point to buffer. */
634         ecc_swap_digits(result->x, shared_secret, result->ndigits);
635
636         /* Request might ask for less bytes than what we have. */
637         nbytes = min_t(size_t, nbytes, req->dst_len);
638
639         copied = sg_copy_from_buffer(req->dst,
640                                      sg_nents_for_len(req->dst, nbytes),
641                                      shared_secret, nbytes);
642
643         if (copied != nbytes)
644                 rc = -EINVAL;
645
646         memzero_explicit(shared_secret, sizeof(shared_secret));
647
648 exit_free_result:
649         ecc_free_point(result);
650
651 exit_free_pk:
652         ecc_free_point(pk);
653
654         return rc;
655 }
656
657 /* Compute public key. */
658 static int kmb_ecc_do_public_key(struct ocs_ecc_ctx *tctx,
659                                  struct kpp_request *req)
660 {
661         const struct ecc_curve *curve = tctx->curve;
662         u64 pubk_buf[KMB_ECC_VLI_MAX_DIGITS * 2];
663         struct ecc_point *pk;
664         size_t pubk_len;
665         size_t copied;
666         int rc;
667
668         /* Public key is a point, so it has double the digits. */
669         pubk_len = 2 * digits_to_bytes(curve->g.ndigits);
670
671         pk = ecc_alloc_point(curve->g.ndigits);
672         if (!pk)
673                 return -ENOMEM;
674
675         /* Public Key(pk) = priv * G. */
676         rc = kmb_ecc_point_mult(tctx->ecc_dev, pk, &curve->g, tctx->private_key,
677                                 curve);
678         if (rc)
679                 goto exit;
680
681         /* SP800-56A rev 3 5.6.2.1.3 key check */
682         if (kmb_ocs_ecc_is_pubkey_valid_full(tctx->ecc_dev, curve, pk)) {
683                 rc = -EAGAIN;
684                 goto exit;
685         }
686
687         /* Copy public key from point to buffer. */
688         ecc_swap_digits(pk->x, pubk_buf, pk->ndigits);
689         ecc_swap_digits(pk->y, &pubk_buf[pk->ndigits], pk->ndigits);
690
691         /* Copy public key to req->dst. */
692         copied = sg_copy_from_buffer(req->dst,
693                                      sg_nents_for_len(req->dst, pubk_len),
694                                      pubk_buf, pubk_len);
695
696         if (copied != pubk_len)
697                 rc = -EINVAL;
698
699 exit:
700         ecc_free_point(pk);
701
702         return rc;
703 }
704
705 static int kmb_ocs_ecc_do_one_request(struct crypto_engine *engine,
706                                       void *areq)
707 {
708         struct kpp_request *req = container_of(areq, struct kpp_request, base);
709         struct ocs_ecc_ctx *tctx = kmb_ocs_ecc_tctx(req);
710         struct ocs_ecc_dev *ecc_dev = tctx->ecc_dev;
711         int rc;
712
713         if (req->src)
714                 rc = kmb_ecc_do_shared_secret(tctx, req);
715         else
716                 rc = kmb_ecc_do_public_key(tctx, req);
717
718         crypto_finalize_kpp_request(ecc_dev->engine, req, rc);
719
720         return 0;
721 }
722
723 static int kmb_ocs_ecdh_generate_public_key(struct kpp_request *req)
724 {
725         struct ocs_ecc_ctx *tctx = kmb_ocs_ecc_tctx(req);
726         const struct ecc_curve *curve = tctx->curve;
727
728         /* Ensure kmb_ocs_ecdh_set_secret() has been successfully called. */
729         if (!tctx->curve)
730                 return -EINVAL;
731
732         /* Ensure dst is present. */
733         if (!req->dst)
734                 return -EINVAL;
735
736         /* Check the request dst is big enough to hold the public key. */
737         if (req->dst_len < (2 * digits_to_bytes(curve->g.ndigits)))
738                 return -EINVAL;
739
740         /* 'src' is not supposed to be present when generate pubk is called. */
741         if (req->src)
742                 return -EINVAL;
743
744         return crypto_transfer_kpp_request_to_engine(tctx->ecc_dev->engine,
745                                                      req);
746 }
747
748 static int kmb_ocs_ecdh_compute_shared_secret(struct kpp_request *req)
749 {
750         struct ocs_ecc_ctx *tctx = kmb_ocs_ecc_tctx(req);
751         const struct ecc_curve *curve = tctx->curve;
752
753         /* Ensure kmb_ocs_ecdh_set_secret() has been successfully called. */
754         if (!tctx->curve)
755                 return -EINVAL;
756
757         /* Ensure dst is present. */
758         if (!req->dst)
759                 return -EINVAL;
760
761         /* Ensure src is present. */
762         if (!req->src)
763                 return -EINVAL;
764
765         /*
766          * req->src is expected to the (other-side) public key, so its length
767          * must be 2 * coordinate size (in bytes).
768          */
769         if (req->src_len != 2 * digits_to_bytes(curve->g.ndigits))
770                 return -EINVAL;
771
772         return crypto_transfer_kpp_request_to_engine(tctx->ecc_dev->engine,
773                                                      req);
774 }
775
776 static int kmb_ecc_tctx_init(struct ocs_ecc_ctx *tctx, unsigned int curve_id)
777 {
778         memset(tctx, 0, sizeof(*tctx));
779
780         tctx->ecc_dev = kmb_ocs_ecc_find_dev(tctx);
781
782         if (IS_ERR(tctx->ecc_dev)) {
783                 pr_err("Failed to find the device : %ld\n",
784                        PTR_ERR(tctx->ecc_dev));
785                 return PTR_ERR(tctx->ecc_dev);
786         }
787
788         tctx->curve = ecc_get_curve(curve_id);
789         if (!tctx->curve)
790                 return -EOPNOTSUPP;
791
792         return 0;
793 }
794
795 static int kmb_ocs_ecdh_nist_p256_init_tfm(struct crypto_kpp *tfm)
796 {
797         struct ocs_ecc_ctx *tctx = kpp_tfm_ctx(tfm);
798
799         return kmb_ecc_tctx_init(tctx, ECC_CURVE_NIST_P256);
800 }
801
802 static int kmb_ocs_ecdh_nist_p384_init_tfm(struct crypto_kpp *tfm)
803 {
804         struct ocs_ecc_ctx *tctx = kpp_tfm_ctx(tfm);
805
806         return kmb_ecc_tctx_init(tctx, ECC_CURVE_NIST_P384);
807 }
808
809 static void kmb_ocs_ecdh_exit_tfm(struct crypto_kpp *tfm)
810 {
811         struct ocs_ecc_ctx *tctx = kpp_tfm_ctx(tfm);
812
813         memzero_explicit(tctx->private_key, sizeof(*tctx->private_key));
814 }
815
816 static unsigned int kmb_ocs_ecdh_max_size(struct crypto_kpp *tfm)
817 {
818         struct ocs_ecc_ctx *tctx = kpp_tfm_ctx(tfm);
819
820         /* Public key is made of two coordinates, so double the digits. */
821         return digits_to_bytes(tctx->curve->g.ndigits) * 2;
822 }
823
824 static struct kpp_engine_alg ocs_ecdh_p256 = {
825         .base.set_secret = kmb_ocs_ecdh_set_secret,
826         .base.generate_public_key = kmb_ocs_ecdh_generate_public_key,
827         .base.compute_shared_secret = kmb_ocs_ecdh_compute_shared_secret,
828         .base.init = kmb_ocs_ecdh_nist_p256_init_tfm,
829         .base.exit = kmb_ocs_ecdh_exit_tfm,
830         .base.max_size = kmb_ocs_ecdh_max_size,
831         .base.base = {
832                 .cra_name = "ecdh-nist-p256",
833                 .cra_driver_name = "ecdh-nist-p256-keembay-ocs",
834                 .cra_priority = KMB_OCS_ECC_PRIORITY,
835                 .cra_module = THIS_MODULE,
836                 .cra_ctxsize = sizeof(struct ocs_ecc_ctx),
837         },
838         .op.do_one_request = kmb_ocs_ecc_do_one_request,
839 };
840
841 static struct kpp_engine_alg ocs_ecdh_p384 = {
842         .base.set_secret = kmb_ocs_ecdh_set_secret,
843         .base.generate_public_key = kmb_ocs_ecdh_generate_public_key,
844         .base.compute_shared_secret = kmb_ocs_ecdh_compute_shared_secret,
845         .base.init = kmb_ocs_ecdh_nist_p384_init_tfm,
846         .base.exit = kmb_ocs_ecdh_exit_tfm,
847         .base.max_size = kmb_ocs_ecdh_max_size,
848         .base.base = {
849                 .cra_name = "ecdh-nist-p384",
850                 .cra_driver_name = "ecdh-nist-p384-keembay-ocs",
851                 .cra_priority = KMB_OCS_ECC_PRIORITY,
852                 .cra_module = THIS_MODULE,
853                 .cra_ctxsize = sizeof(struct ocs_ecc_ctx),
854         },
855         .op.do_one_request = kmb_ocs_ecc_do_one_request,
856 };
857
858 static irqreturn_t ocs_ecc_irq_handler(int irq, void *dev_id)
859 {
860         struct ocs_ecc_dev *ecc_dev = dev_id;
861         u32 status;
862
863         /*
864          * Read the status register and write it back to clear the
865          * DONE_INT_STATUS bit.
866          */
867         status = ioread32(ecc_dev->base_reg + HW_OFFS_OCS_ECC_ISR);
868         iowrite32(status, ecc_dev->base_reg + HW_OFFS_OCS_ECC_ISR);
869
870         if (!(status & HW_OCS_ECC_ISR_INT_STATUS_DONE))
871                 return IRQ_NONE;
872
873         complete(&ecc_dev->irq_done);
874
875         return IRQ_HANDLED;
876 }
877
878 static int kmb_ocs_ecc_probe(struct platform_device *pdev)
879 {
880         struct device *dev = &pdev->dev;
881         struct ocs_ecc_dev *ecc_dev;
882         int rc;
883
884         ecc_dev = devm_kzalloc(dev, sizeof(*ecc_dev), GFP_KERNEL);
885         if (!ecc_dev)
886                 return -ENOMEM;
887
888         ecc_dev->dev = dev;
889
890         platform_set_drvdata(pdev, ecc_dev);
891
892         INIT_LIST_HEAD(&ecc_dev->list);
893         init_completion(&ecc_dev->irq_done);
894
895         /* Get base register address. */
896         ecc_dev->base_reg = devm_platform_ioremap_resource(pdev, 0);
897         if (IS_ERR(ecc_dev->base_reg)) {
898                 dev_err(dev, "Failed to get base address\n");
899                 rc = PTR_ERR(ecc_dev->base_reg);
900                 goto list_del;
901         }
902
903         /* Get and request IRQ */
904         ecc_dev->irq = platform_get_irq(pdev, 0);
905         if (ecc_dev->irq < 0) {
906                 rc = ecc_dev->irq;
907                 goto list_del;
908         }
909
910         rc = devm_request_threaded_irq(dev, ecc_dev->irq, ocs_ecc_irq_handler,
911                                        NULL, 0, "keembay-ocs-ecc", ecc_dev);
912         if (rc < 0) {
913                 dev_err(dev, "Could not request IRQ\n");
914                 goto list_del;
915         }
916
917         /* Add device to the list of OCS ECC devices. */
918         spin_lock(&ocs_ecc.lock);
919         list_add_tail(&ecc_dev->list, &ocs_ecc.dev_list);
920         spin_unlock(&ocs_ecc.lock);
921
922         /* Initialize crypto engine. */
923         ecc_dev->engine = crypto_engine_alloc_init(dev, 1);
924         if (!ecc_dev->engine) {
925                 dev_err(dev, "Could not allocate crypto engine\n");
926                 rc = -ENOMEM;
927                 goto list_del;
928         }
929
930         rc = crypto_engine_start(ecc_dev->engine);
931         if (rc) {
932                 dev_err(dev, "Could not start crypto engine\n");
933                 goto cleanup;
934         }
935
936         /* Register the KPP algo. */
937         rc = crypto_engine_register_kpp(&ocs_ecdh_p256);
938         if (rc) {
939                 dev_err(dev,
940                         "Could not register OCS algorithms with Crypto API\n");
941                 goto cleanup;
942         }
943
944         rc = crypto_engine_register_kpp(&ocs_ecdh_p384);
945         if (rc) {
946                 dev_err(dev,
947                         "Could not register OCS algorithms with Crypto API\n");
948                 goto ocs_ecdh_p384_error;
949         }
950
951         return 0;
952
953 ocs_ecdh_p384_error:
954         crypto_engine_unregister_kpp(&ocs_ecdh_p256);
955
956 cleanup:
957         crypto_engine_exit(ecc_dev->engine);
958
959 list_del:
960         spin_lock(&ocs_ecc.lock);
961         list_del(&ecc_dev->list);
962         spin_unlock(&ocs_ecc.lock);
963
964         return rc;
965 }
966
967 static void kmb_ocs_ecc_remove(struct platform_device *pdev)
968 {
969         struct ocs_ecc_dev *ecc_dev;
970
971         ecc_dev = platform_get_drvdata(pdev);
972
973         crypto_engine_unregister_kpp(&ocs_ecdh_p384);
974         crypto_engine_unregister_kpp(&ocs_ecdh_p256);
975
976         spin_lock(&ocs_ecc.lock);
977         list_del(&ecc_dev->list);
978         spin_unlock(&ocs_ecc.lock);
979
980         crypto_engine_exit(ecc_dev->engine);
981 }
982
983 /* Device tree driver match. */
984 static const struct of_device_id kmb_ocs_ecc_of_match[] = {
985         {
986                 .compatible = "intel,keembay-ocs-ecc",
987         },
988         {}
989 };
990
991 /* The OCS driver is a platform device. */
992 static struct platform_driver kmb_ocs_ecc_driver = {
993         .probe = kmb_ocs_ecc_probe,
994         .remove_new = kmb_ocs_ecc_remove,
995         .driver = {
996                         .name = DRV_NAME,
997                         .of_match_table = kmb_ocs_ecc_of_match,
998                 },
999 };
1000 module_platform_driver(kmb_ocs_ecc_driver);
1001
1002 MODULE_LICENSE("GPL");
1003 MODULE_DESCRIPTION("Intel Keem Bay OCS ECC Driver");
1004 MODULE_ALIAS_CRYPTO("ecdh-nist-p256");
1005 MODULE_ALIAS_CRYPTO("ecdh-nist-p384");
1006 MODULE_ALIAS_CRYPTO("ecdh-nist-p256-keembay-ocs");
1007 MODULE_ALIAS_CRYPTO("ecdh-nist-p384-keembay-ocs");