/*
 * Copyright (c) 2016-2017, Mellanox Technologies. All rights reserved.
 * Copyright (c) 2016-2017, Dave Watson <davejwatson@fb.com>. All rights reserved.
 *
 * This software is available to you under a choice of one of two
 * licenses.  You may choose to be licensed under the terms of the GNU
 * General Public License (GPL) Version 2, available from the file
 * COPYING in the main directory of this source tree, or the
 * OpenIB.org BSD license below:
 *
 *     Redistribution and use in source and binary forms, with or
 *     without modification, are permitted provided that the following
 *     conditions are met:
 *
 *      - Redistributions of source code must retain the above
 *        copyright notice, this list of conditions and the following
 *        disclaimer.
 *
 *      - Redistributions in binary form must reproduce the above
 *        copyright notice, this list of conditions and the following
 *        disclaimer in the documentation and/or other materials
 *        provided with the distribution.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
 * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
 * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
 * NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS
 * BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN
 * ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
 * CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */

#include <linux/module.h>

#include <net/tcp.h>
#include <net/inet_common.h>
#include <linux/highmem.h>
#include <linux/netdevice.h>
#include <linux/sched/signal.h>
#include <linux/inetdevice.h>
#include <linux/inet_diag.h>

#include <net/snmp.h>
#include <net/tls.h>
#include <net/tls_toe.h>

#include "tls.h"

MODULE_AUTHOR("Mellanox Technologies");
MODULE_DESCRIPTION("Transport Layer Security Support");
MODULE_LICENSE("Dual BSD/GPL");
MODULE_ALIAS_TCP_ULP("tls");
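
/*
 * Usage note (a minimal illustrative sketch, error handling omitted):
 * userspace attaches this ULP to an established TCP socket before any
 * SOL_TLS options can be used:
 *
 *      setsockopt(fd, SOL_TCP, TCP_ULP, "tls", sizeof("tls"));
 *
 * After that, keys are installed with setsockopt(SOL_TLS, TLS_TX/TLS_RX);
 * see do_tls_setsockopt_conf() below.
 */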

enum {
        TLSV4,
        TLSV6,
        TLS_NUM_PROTS,
};

#define CIPHER_SIZE_DESC(cipher) [cipher] = { \
        .iv = cipher ## _IV_SIZE, \
        .key = cipher ## _KEY_SIZE, \
        .salt = cipher ## _SALT_SIZE, \
        .tag = cipher ## _TAG_SIZE, \
        .rec_seq = cipher ## _REC_SEQ_SIZE, \
}

const struct tls_cipher_size_desc tls_cipher_size_desc[] = {
        CIPHER_SIZE_DESC(TLS_CIPHER_AES_GCM_128),
        CIPHER_SIZE_DESC(TLS_CIPHER_AES_GCM_256),
        CIPHER_SIZE_DESC(TLS_CIPHER_AES_CCM_128),
        CIPHER_SIZE_DESC(TLS_CIPHER_CHACHA20_POLY1305),
        CIPHER_SIZE_DESC(TLS_CIPHER_SM4_GCM),
        CIPHER_SIZE_DESC(TLS_CIPHER_SM4_CCM),
};
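
/*
 * For reference, each CIPHER_SIZE_DESC() entry token-pastes the cipher name
 * onto the per-cipher size constants; e.g. CIPHER_SIZE_DESC(TLS_CIPHER_AES_GCM_128)
 * expands to:
 *
 *      [TLS_CIPHER_AES_GCM_128] = {
 *              .iv      = TLS_CIPHER_AES_GCM_128_IV_SIZE,
 *              .key     = TLS_CIPHER_AES_GCM_128_KEY_SIZE,
 *              .salt    = TLS_CIPHER_AES_GCM_128_SALT_SIZE,
 *              .tag     = TLS_CIPHER_AES_GCM_128_TAG_SIZE,
 *              .rec_seq = TLS_CIPHER_AES_GCM_128_REC_SEQ_SIZE,
 *      },
 */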

static const struct proto *saved_tcpv6_prot;
static DEFINE_MUTEX(tcpv6_prot_mutex);
static const struct proto *saved_tcpv4_prot;
static DEFINE_MUTEX(tcpv4_prot_mutex);
static struct proto tls_prots[TLS_NUM_PROTS][TLS_NUM_CONFIG][TLS_NUM_CONFIG];
static struct proto_ops tls_proto_ops[TLS_NUM_PROTS][TLS_NUM_CONFIG][TLS_NUM_CONFIG];
static void build_protos(struct proto prot[TLS_NUM_CONFIG][TLS_NUM_CONFIG],
                         const struct proto *base);

void update_sk_prot(struct sock *sk, struct tls_context *ctx)
{
        int ip_ver = sk->sk_family == AF_INET6 ? TLSV6 : TLSV4;

        WRITE_ONCE(sk->sk_prot,
                   &tls_prots[ip_ver][ctx->tx_conf][ctx->rx_conf]);
        WRITE_ONCE(sk->sk_socket->ops,
                   &tls_proto_ops[ip_ver][ctx->tx_conf][ctx->rx_conf]);
}

int wait_on_pending_writer(struct sock *sk, long *timeo)
{
        DEFINE_WAIT_FUNC(wait, woken_wake_function);
        int ret, rc = 0;

        add_wait_queue(sk_sleep(sk), &wait);
        while (1) {
                if (!*timeo) {
                        rc = -EAGAIN;
                        break;
                }

                if (signal_pending(current)) {
                        rc = sock_intr_errno(*timeo);
                        break;
                }

                ret = sk_wait_event(sk, timeo,
                                    !READ_ONCE(sk->sk_write_pending), &wait);
                if (ret) {
                        if (ret < 0)
                                rc = ret;
                        break;
                }
        }
        remove_wait_queue(sk_sleep(sk), &wait);
        return rc;
}

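/* Push a scatterlist of already-protected record data into the TCP layer
 * via do_tcp_sendpages(). If the TCP layer cannot take everything, the
 * remaining scatterlist and offset are parked in ctx->partially_sent_record
 * and ctx->partially_sent_offset so tls_push_partial_record() can resume
 * later.
 */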
int tls_push_sg(struct sock *sk,
                struct tls_context *ctx,
                struct scatterlist *sg,
                u16 first_offset,
                int flags)
{
        int sendpage_flags = flags | MSG_SENDPAGE_NOTLAST;
        int ret = 0;
        struct page *p;
        size_t size;
        int offset = first_offset;

        size = sg->length - offset;
        offset += sg->offset;

        ctx->in_tcp_sendpages = true;
        while (1) {
                if (sg_is_last(sg))
                        sendpage_flags = flags;

                /* is sending application-limited? */
                tcp_rate_check_app_limited(sk);
                p = sg_page(sg);
retry:
                ret = do_tcp_sendpages(sk, p, offset, size, sendpage_flags);

                if (ret != size) {
                        if (ret > 0) {
                                offset += ret;
                                size -= ret;
                                goto retry;
                        }

                        offset -= sg->offset;
                        ctx->partially_sent_offset = offset;
                        ctx->partially_sent_record = (void *)sg;
                        ctx->in_tcp_sendpages = false;
                        return ret;
                }

                put_page(p);
                sk_mem_uncharge(sk, sg->length);
                sg = sg_next(sg);
                if (!sg)
                        break;

                offset = sg->offset;
                size = sg->length;
        }

        ctx->in_tcp_sendpages = false;

        return 0;
}

static int tls_handle_open_record(struct sock *sk, int flags)
{
        struct tls_context *ctx = tls_get_ctx(sk);

        if (tls_is_pending_open_record(ctx))
                return ctx->push_pending_record(sk, flags);

        return 0;
}

int tls_process_cmsg(struct sock *sk, struct msghdr *msg,
                     unsigned char *record_type)
{
        struct cmsghdr *cmsg;
        int rc = -EINVAL;

        for_each_cmsghdr(cmsg, msg) {
                if (!CMSG_OK(msg, cmsg))
                        return -EINVAL;
                if (cmsg->cmsg_level != SOL_TLS)
                        continue;

                switch (cmsg->cmsg_type) {
                case TLS_SET_RECORD_TYPE:
                        if (cmsg->cmsg_len < CMSG_LEN(sizeof(*record_type)))
                                return -EINVAL;

                        if (msg->msg_flags & MSG_MORE)
                                return -EINVAL;

                        rc = tls_handle_open_record(sk, msg->msg_flags);
                        if (rc)
                                return rc;

                        *record_type = *(unsigned char *)CMSG_DATA(cmsg);
                        rc = 0;
                        break;
                default:
                        return -EINVAL;
                }
        }

        return rc;
}
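
/*
 * Illustrative userspace counterpart of the cmsg handling above (a sketch;
 * "alert" and "alert_len" stand in for the application's record payload,
 * error handling omitted): sending a record with a non-default ContentType,
 * e.g. a TLS alert (type 21), through sendmsg():
 *
 *      char buf[CMSG_SPACE(sizeof(unsigned char))];
 *      struct iovec iov = { .iov_base = alert, .iov_len = alert_len };
 *      struct msghdr msg = { .msg_control = buf,
 *                            .msg_controllen = sizeof(buf),
 *                            .msg_iov = &iov, .msg_iovlen = 1 };
 *      struct cmsghdr *cmsg = CMSG_FIRSTHDR(&msg);
 *
 *      cmsg->cmsg_level = SOL_TLS;
 *      cmsg->cmsg_type = TLS_SET_RECORD_TYPE;
 *      cmsg->cmsg_len = CMSG_LEN(sizeof(unsigned char));
 *      *CMSG_DATA(cmsg) = 21;
 *      sendmsg(fd, &msg, 0);
 */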

int tls_push_partial_record(struct sock *sk, struct tls_context *ctx,
                            int flags)
{
        struct scatterlist *sg;
        u16 offset;

        sg = ctx->partially_sent_record;
        offset = ctx->partially_sent_offset;

        ctx->partially_sent_record = NULL;
        return tls_push_sg(sk, ctx, sg, offset, flags);
}

void tls_free_partial_record(struct sock *sk, struct tls_context *ctx)
{
        struct scatterlist *sg;

        for (sg = ctx->partially_sent_record; sg; sg = sg_next(sg)) {
                put_page(sg_page(sg));
                sk_mem_uncharge(sk, sg->length);
        }
        ctx->partially_sent_record = NULL;
}

static void tls_write_space(struct sock *sk)
{
        struct tls_context *ctx = tls_get_ctx(sk);

        /* If in_tcp_sendpages, call the lower protocol's write space handler
         * to ensure we wake up any operations waiting there; for example,
         * if do_tcp_sendpages were to call sk_wait_event.
         */
        if (ctx->in_tcp_sendpages) {
                ctx->sk_write_space(sk);
                return;
        }

#ifdef CONFIG_TLS_DEVICE
        if (ctx->tx_conf == TLS_HW)
                tls_device_write_space(sk, ctx);
        else
#endif
                tls_sw_write_space(sk, ctx);

        ctx->sk_write_space(sk);
}

/**
 * tls_ctx_free() - free TLS ULP context
 * @sk:  socket to which @ctx is attached
 * @ctx: TLS context structure
 *
 * Free TLS context. If @sk is %NULL the caller guarantees that the socket
 * to which @ctx was attached has no outstanding references.
 */
void tls_ctx_free(struct sock *sk, struct tls_context *ctx)
{
        if (!ctx)
                return;

        memzero_explicit(&ctx->crypto_send, sizeof(ctx->crypto_send));
        memzero_explicit(&ctx->crypto_recv, sizeof(ctx->crypto_recv));
        mutex_destroy(&ctx->tx_lock);

        if (sk)
                kfree_rcu(ctx, rcu);
        else
                kfree(ctx);
}

static void tls_sk_proto_cleanup(struct sock *sk,
                                 struct tls_context *ctx, long timeo)
{
        if (unlikely(sk->sk_write_pending) &&
            !wait_on_pending_writer(sk, &timeo))
                tls_handle_open_record(sk, 0);

        /* We need these for tls_sw_fallback handling of other packets */
        if (ctx->tx_conf == TLS_SW) {
                kfree(ctx->tx.rec_seq);
                kfree(ctx->tx.iv);
                tls_sw_release_resources_tx(sk);
                TLS_DEC_STATS(sock_net(sk), LINUX_MIB_TLSCURRTXSW);
        } else if (ctx->tx_conf == TLS_HW) {
                tls_device_free_resources_tx(sk);
                TLS_DEC_STATS(sock_net(sk), LINUX_MIB_TLSCURRTXDEVICE);
        }

        if (ctx->rx_conf == TLS_SW) {
                tls_sw_release_resources_rx(sk);
                TLS_DEC_STATS(sock_net(sk), LINUX_MIB_TLSCURRRXSW);
        } else if (ctx->rx_conf == TLS_HW) {
                tls_device_offload_cleanup_rx(sk);
                TLS_DEC_STATS(sock_net(sk), LINUX_MIB_TLSCURRRXDEVICE);
        }
}

static void tls_sk_proto_close(struct sock *sk, long timeout)
{
        struct inet_connection_sock *icsk = inet_csk(sk);
        struct tls_context *ctx = tls_get_ctx(sk);
        long timeo = sock_sndtimeo(sk, 0);
        bool free_ctx;

        if (ctx->tx_conf == TLS_SW)
                tls_sw_cancel_work_tx(ctx);

        lock_sock(sk);
        free_ctx = ctx->tx_conf != TLS_HW && ctx->rx_conf != TLS_HW;

        if (ctx->tx_conf != TLS_BASE || ctx->rx_conf != TLS_BASE)
                tls_sk_proto_cleanup(sk, ctx, timeo);

        write_lock_bh(&sk->sk_callback_lock);
        if (free_ctx)
                rcu_assign_pointer(icsk->icsk_ulp_data, NULL);
        WRITE_ONCE(sk->sk_prot, ctx->sk_proto);
        if (sk->sk_write_space == tls_write_space)
                sk->sk_write_space = ctx->sk_write_space;
        write_unlock_bh(&sk->sk_callback_lock);
        release_sock(sk);
        if (ctx->tx_conf == TLS_SW)
                tls_sw_free_ctx_tx(ctx);
        if (ctx->rx_conf == TLS_SW || ctx->rx_conf == TLS_HW)
                tls_sw_strparser_done(ctx);
        if (ctx->rx_conf == TLS_SW)
                tls_sw_free_ctx_rx(ctx);
        ctx->sk_proto->close(sk, timeout);

        if (free_ctx)
                tls_ctx_free(sk, ctx);
}

static int do_tls_getsockopt_conf(struct sock *sk, char __user *optval,
                                  int __user *optlen, int tx)
{
        int rc = 0;
        struct tls_context *ctx = tls_get_ctx(sk);
        struct tls_crypto_info *crypto_info;
        struct cipher_context *cctx;
        int len;

        if (get_user(len, optlen))
                return -EFAULT;

        if (!optval || (len < sizeof(*crypto_info))) {
                rc = -EINVAL;
                goto out;
        }

        if (!ctx) {
                rc = -EBUSY;
                goto out;
        }

        /* get user crypto info */
        if (tx) {
                crypto_info = &ctx->crypto_send.info;
                cctx = &ctx->tx;
        } else {
                crypto_info = &ctx->crypto_recv.info;
                cctx = &ctx->rx;
        }

        if (!TLS_CRYPTO_INFO_READY(crypto_info)) {
                rc = -EBUSY;
                goto out;
        }

        if (len == sizeof(*crypto_info)) {
                if (copy_to_user(optval, crypto_info, sizeof(*crypto_info)))
                        rc = -EFAULT;
                goto out;
        }

        switch (crypto_info->cipher_type) {
        case TLS_CIPHER_AES_GCM_128: {
                struct tls12_crypto_info_aes_gcm_128 *
                  crypto_info_aes_gcm_128 =
                  container_of(crypto_info,
                               struct tls12_crypto_info_aes_gcm_128,
                               info);

                if (len != sizeof(*crypto_info_aes_gcm_128)) {
                        rc = -EINVAL;
                        goto out;
                }
                memcpy(crypto_info_aes_gcm_128->iv,
                       cctx->iv + TLS_CIPHER_AES_GCM_128_SALT_SIZE,
                       TLS_CIPHER_AES_GCM_128_IV_SIZE);
                memcpy(crypto_info_aes_gcm_128->rec_seq, cctx->rec_seq,
                       TLS_CIPHER_AES_GCM_128_REC_SEQ_SIZE);
                if (copy_to_user(optval,
                                 crypto_info_aes_gcm_128,
                                 sizeof(*crypto_info_aes_gcm_128)))
                        rc = -EFAULT;
                break;
        }
        case TLS_CIPHER_AES_GCM_256: {
                struct tls12_crypto_info_aes_gcm_256 *
                  crypto_info_aes_gcm_256 =
                  container_of(crypto_info,
                               struct tls12_crypto_info_aes_gcm_256,
                               info);

                if (len != sizeof(*crypto_info_aes_gcm_256)) {
                        rc = -EINVAL;
                        goto out;
                }
                memcpy(crypto_info_aes_gcm_256->iv,
                       cctx->iv + TLS_CIPHER_AES_GCM_256_SALT_SIZE,
                       TLS_CIPHER_AES_GCM_256_IV_SIZE);
                memcpy(crypto_info_aes_gcm_256->rec_seq, cctx->rec_seq,
                       TLS_CIPHER_AES_GCM_256_REC_SEQ_SIZE);
                if (copy_to_user(optval,
                                 crypto_info_aes_gcm_256,
                                 sizeof(*crypto_info_aes_gcm_256)))
                        rc = -EFAULT;
                break;
        }
        case TLS_CIPHER_AES_CCM_128: {
                struct tls12_crypto_info_aes_ccm_128 *aes_ccm_128 =
                        container_of(crypto_info,
                                struct tls12_crypto_info_aes_ccm_128, info);

                if (len != sizeof(*aes_ccm_128)) {
                        rc = -EINVAL;
                        goto out;
                }
                memcpy(aes_ccm_128->iv,
                       cctx->iv + TLS_CIPHER_AES_CCM_128_SALT_SIZE,
                       TLS_CIPHER_AES_CCM_128_IV_SIZE);
                memcpy(aes_ccm_128->rec_seq, cctx->rec_seq,
                       TLS_CIPHER_AES_CCM_128_REC_SEQ_SIZE);
                if (copy_to_user(optval, aes_ccm_128, sizeof(*aes_ccm_128)))
                        rc = -EFAULT;
                break;
        }
        case TLS_CIPHER_CHACHA20_POLY1305: {
                struct tls12_crypto_info_chacha20_poly1305 *chacha20_poly1305 =
                        container_of(crypto_info,
                                struct tls12_crypto_info_chacha20_poly1305,
                                info);

                if (len != sizeof(*chacha20_poly1305)) {
                        rc = -EINVAL;
                        goto out;
                }
                memcpy(chacha20_poly1305->iv,
                       cctx->iv + TLS_CIPHER_CHACHA20_POLY1305_SALT_SIZE,
                       TLS_CIPHER_CHACHA20_POLY1305_IV_SIZE);
                memcpy(chacha20_poly1305->rec_seq, cctx->rec_seq,
                       TLS_CIPHER_CHACHA20_POLY1305_REC_SEQ_SIZE);
                if (copy_to_user(optval, chacha20_poly1305,
                                sizeof(*chacha20_poly1305)))
                        rc = -EFAULT;
                break;
        }
        case TLS_CIPHER_SM4_GCM: {
                struct tls12_crypto_info_sm4_gcm *sm4_gcm_info =
                        container_of(crypto_info,
                                struct tls12_crypto_info_sm4_gcm, info);

                if (len != sizeof(*sm4_gcm_info)) {
                        rc = -EINVAL;
                        goto out;
                }
                memcpy(sm4_gcm_info->iv,
                       cctx->iv + TLS_CIPHER_SM4_GCM_SALT_SIZE,
                       TLS_CIPHER_SM4_GCM_IV_SIZE);
                memcpy(sm4_gcm_info->rec_seq, cctx->rec_seq,
                       TLS_CIPHER_SM4_GCM_REC_SEQ_SIZE);
                if (copy_to_user(optval, sm4_gcm_info, sizeof(*sm4_gcm_info)))
                        rc = -EFAULT;
                break;
        }
        case TLS_CIPHER_SM4_CCM: {
                struct tls12_crypto_info_sm4_ccm *sm4_ccm_info =
                        container_of(crypto_info,
                                struct tls12_crypto_info_sm4_ccm, info);

                if (len != sizeof(*sm4_ccm_info)) {
                        rc = -EINVAL;
                        goto out;
                }
                memcpy(sm4_ccm_info->iv,
                       cctx->iv + TLS_CIPHER_SM4_CCM_SALT_SIZE,
                       TLS_CIPHER_SM4_CCM_IV_SIZE);
                memcpy(sm4_ccm_info->rec_seq, cctx->rec_seq,
                       TLS_CIPHER_SM4_CCM_REC_SEQ_SIZE);
                if (copy_to_user(optval, sm4_ccm_info, sizeof(*sm4_ccm_info)))
                        rc = -EFAULT;
                break;
        }
        case TLS_CIPHER_ARIA_GCM_128: {
                struct tls12_crypto_info_aria_gcm_128 *
                  crypto_info_aria_gcm_128 =
                  container_of(crypto_info,
                               struct tls12_crypto_info_aria_gcm_128,
                               info);

                if (len != sizeof(*crypto_info_aria_gcm_128)) {
                        rc = -EINVAL;
                        goto out;
                }
                memcpy(crypto_info_aria_gcm_128->iv,
                       cctx->iv + TLS_CIPHER_ARIA_GCM_128_SALT_SIZE,
                       TLS_CIPHER_ARIA_GCM_128_IV_SIZE);
                memcpy(crypto_info_aria_gcm_128->rec_seq, cctx->rec_seq,
                       TLS_CIPHER_ARIA_GCM_128_REC_SEQ_SIZE);
                if (copy_to_user(optval,
                                 crypto_info_aria_gcm_128,
                                 sizeof(*crypto_info_aria_gcm_128)))
                        rc = -EFAULT;
                break;
        }
        case TLS_CIPHER_ARIA_GCM_256: {
                struct tls12_crypto_info_aria_gcm_256 *
                  crypto_info_aria_gcm_256 =
                  container_of(crypto_info,
                               struct tls12_crypto_info_aria_gcm_256,
                               info);

                if (len != sizeof(*crypto_info_aria_gcm_256)) {
                        rc = -EINVAL;
                        goto out;
                }
                memcpy(crypto_info_aria_gcm_256->iv,
                       cctx->iv + TLS_CIPHER_ARIA_GCM_256_SALT_SIZE,
                       TLS_CIPHER_ARIA_GCM_256_IV_SIZE);
                memcpy(crypto_info_aria_gcm_256->rec_seq, cctx->rec_seq,
                       TLS_CIPHER_ARIA_GCM_256_REC_SEQ_SIZE);
                if (copy_to_user(optval,
                                 crypto_info_aria_gcm_256,
                                 sizeof(*crypto_info_aria_gcm_256)))
                        rc = -EFAULT;
                break;
        }
        default:
                rc = -EINVAL;
        }

out:
        return rc;
}
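
/*
 * Illustrative userspace read-back of the TX state via the path above
 * (a sketch; assumes AES-GCM-128 was configured on the socket):
 *
 *      struct tls12_crypto_info_aes_gcm_128 ci;
 *      socklen_t len = sizeof(ci);
 *
 *      getsockopt(fd, SOL_TLS, TLS_TX, &ci, &len);
 *      // ci.iv and ci.rec_seq now hold the kernel's current IV and
 *      // record sequence number for the transmit direction
 */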

static int do_tls_getsockopt_tx_zc(struct sock *sk, char __user *optval,
                                   int __user *optlen)
{
        struct tls_context *ctx = tls_get_ctx(sk);
        unsigned int value;
        int len;

        if (get_user(len, optlen))
                return -EFAULT;

        if (len != sizeof(value))
                return -EINVAL;

        value = ctx->zerocopy_sendfile;
        if (copy_to_user(optval, &value, sizeof(value)))
                return -EFAULT;

        return 0;
}

static int do_tls_getsockopt_no_pad(struct sock *sk, char __user *optval,
                                    int __user *optlen)
{
        struct tls_context *ctx = tls_get_ctx(sk);
        int value, len;

        if (ctx->prot_info.version != TLS_1_3_VERSION)
                return -EINVAL;

        if (get_user(len, optlen))
                return -EFAULT;
        if (len < sizeof(value))
                return -EINVAL;

        value = -EINVAL;
        if (ctx->rx_conf == TLS_SW || ctx->rx_conf == TLS_HW)
                value = ctx->rx_no_pad;
        if (value < 0)
                return value;

        if (put_user(sizeof(value), optlen))
                return -EFAULT;
        if (copy_to_user(optval, &value, sizeof(value)))
                return -EFAULT;

        return 0;
}

static int do_tls_getsockopt(struct sock *sk, int optname,
                             char __user *optval, int __user *optlen)
{
        int rc = 0;

        lock_sock(sk);

        switch (optname) {
        case TLS_TX:
        case TLS_RX:
                rc = do_tls_getsockopt_conf(sk, optval, optlen,
                                            optname == TLS_TX);
                break;
        case TLS_TX_ZEROCOPY_RO:
                rc = do_tls_getsockopt_tx_zc(sk, optval, optlen);
                break;
        case TLS_RX_EXPECT_NO_PAD:
                rc = do_tls_getsockopt_no_pad(sk, optval, optlen);
                break;
        default:
                rc = -ENOPROTOOPT;
                break;
        }

        release_sock(sk);

        return rc;
}

static int tls_getsockopt(struct sock *sk, int level, int optname,
                          char __user *optval, int __user *optlen)
{
        struct tls_context *ctx = tls_get_ctx(sk);

        if (level != SOL_TLS)
                return ctx->sk_proto->getsockopt(sk, level,
                                                 optname, optval, optlen);

        return do_tls_getsockopt(sk, optname, optval, optlen);
}

static int do_tls_setsockopt_conf(struct sock *sk, sockptr_t optval,
                                  unsigned int optlen, int tx)
{
        struct tls_crypto_info *crypto_info;
        struct tls_crypto_info *alt_crypto_info;
        struct tls_context *ctx = tls_get_ctx(sk);
        size_t optsize;
        int rc = 0;
        int conf;

        if (sockptr_is_null(optval) || (optlen < sizeof(*crypto_info)))
                return -EINVAL;

        if (tx) {
                crypto_info = &ctx->crypto_send.info;
                alt_crypto_info = &ctx->crypto_recv.info;
        } else {
                crypto_info = &ctx->crypto_recv.info;
                alt_crypto_info = &ctx->crypto_send.info;
        }

        /* Currently we don't support setting crypto info more than once */
        if (TLS_CRYPTO_INFO_READY(crypto_info))
                return -EBUSY;

        rc = copy_from_sockptr(crypto_info, optval, sizeof(*crypto_info));
        if (rc) {
                rc = -EFAULT;
                goto err_crypto_info;
        }

        /* check version */
        if (crypto_info->version != TLS_1_2_VERSION &&
            crypto_info->version != TLS_1_3_VERSION) {
                rc = -EINVAL;
                goto err_crypto_info;
        }

        /* Ensure that the TLS version and cipher are the same in both directions */
        if (TLS_CRYPTO_INFO_READY(alt_crypto_info)) {
                if (alt_crypto_info->version != crypto_info->version ||
                    alt_crypto_info->cipher_type != crypto_info->cipher_type) {
                        rc = -EINVAL;
                        goto err_crypto_info;
                }
        }

        switch (crypto_info->cipher_type) {
        case TLS_CIPHER_AES_GCM_128:
                optsize = sizeof(struct tls12_crypto_info_aes_gcm_128);
                break;
        case TLS_CIPHER_AES_GCM_256:
                optsize = sizeof(struct tls12_crypto_info_aes_gcm_256);
                break;
        case TLS_CIPHER_AES_CCM_128:
                optsize = sizeof(struct tls12_crypto_info_aes_ccm_128);
                break;
        case TLS_CIPHER_CHACHA20_POLY1305:
                optsize = sizeof(struct tls12_crypto_info_chacha20_poly1305);
                break;
        case TLS_CIPHER_SM4_GCM:
                optsize = sizeof(struct tls12_crypto_info_sm4_gcm);
                break;
        case TLS_CIPHER_SM4_CCM:
                optsize = sizeof(struct tls12_crypto_info_sm4_ccm);
                break;
        case TLS_CIPHER_ARIA_GCM_128:
                if (crypto_info->version != TLS_1_2_VERSION) {
                        rc = -EINVAL;
                        goto err_crypto_info;
                }
                optsize = sizeof(struct tls12_crypto_info_aria_gcm_128);
                break;
        case TLS_CIPHER_ARIA_GCM_256:
                if (crypto_info->version != TLS_1_2_VERSION) {
                        rc = -EINVAL;
                        goto err_crypto_info;
                }
                optsize = sizeof(struct tls12_crypto_info_aria_gcm_256);
                break;
        default:
                rc = -EINVAL;
                goto err_crypto_info;
        }

        if (optlen != optsize) {
                rc = -EINVAL;
                goto err_crypto_info;
        }

        rc = copy_from_sockptr_offset(crypto_info + 1, optval,
                                      sizeof(*crypto_info),
                                      optlen - sizeof(*crypto_info));
        if (rc) {
                rc = -EFAULT;
                goto err_crypto_info;
        }

        if (tx) {
                rc = tls_set_device_offload(sk, ctx);
                conf = TLS_HW;
                if (!rc) {
                        TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSTXDEVICE);
                        TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSCURRTXDEVICE);
                } else {
                        rc = tls_set_sw_offload(sk, ctx, 1);
                        if (rc)
                                goto err_crypto_info;
                        TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSTXSW);
                        TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSCURRTXSW);
                        conf = TLS_SW;
                }
        } else {
                rc = tls_set_device_offload_rx(sk, ctx);
                conf = TLS_HW;
                if (!rc) {
                        TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSRXDEVICE);
                        TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSCURRRXDEVICE);
                } else {
                        rc = tls_set_sw_offload(sk, ctx, 0);
                        if (rc)
                                goto err_crypto_info;
                        TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSRXSW);
                        TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSCURRRXSW);
                        conf = TLS_SW;
                }
                tls_sw_strparser_arm(sk, ctx);
        }

        if (tx)
                ctx->tx_conf = conf;
        else
                ctx->rx_conf = conf;
        update_sk_prot(sk, ctx);
        if (tx) {
                ctx->sk_write_space = sk->sk_write_space;
                sk->sk_write_space = tls_write_space;
        } else {
                struct tls_sw_context_rx *rx_ctx = tls_sw_ctx_rx(ctx);

                tls_strp_check_rcv(&rx_ctx->strp);
        }
        return 0;

err_crypto_info:
        memzero_explicit(crypto_info, sizeof(union tls_crypto_context));
        return rc;
}
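
/*
 * Illustrative end-to-end configuration from userspace (a sketch; key
 * material comes from the application's TLS handshake, error handling
 * omitted, TLS 1.2 with AES-GCM-128 assumed):
 *
 *      struct tls12_crypto_info_aes_gcm_128 ci = {
 *              .info.version = TLS_1_2_VERSION,
 *              .info.cipher_type = TLS_CIPHER_AES_GCM_128,
 *      };
 *
 *      // copy key, iv, salt and rec_seq from the handshake into ci ...
 *      setsockopt(fd, SOL_TCP, TCP_ULP, "tls", sizeof("tls"));
 *      setsockopt(fd, SOL_TLS, TLS_TX, &ci, sizeof(ci));
 *      setsockopt(fd, SOL_TLS, TLS_RX, &ci, sizeof(ci));
 *
 * Note that optlen must match the cipher-specific struct size exactly,
 * per the optlen != optsize check above.
 */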

static int do_tls_setsockopt_tx_zc(struct sock *sk, sockptr_t optval,
                                   unsigned int optlen)
{
        struct tls_context *ctx = tls_get_ctx(sk);
        unsigned int value;

        if (sockptr_is_null(optval) || optlen != sizeof(value))
                return -EINVAL;

        if (copy_from_sockptr(&value, optval, sizeof(value)))
                return -EFAULT;

        if (value > 1)
                return -EINVAL;

        ctx->zerocopy_sendfile = value;

        return 0;
}
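
/*
 * Illustrative toggle (a sketch): opt a TLS TX socket into zerocopy
 * sendfile:
 *
 *      unsigned int one = 1;
 *
 *      setsockopt(fd, SOL_TLS, TLS_TX_ZEROCOPY_RO, &one, sizeof(one));
 */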

static int do_tls_setsockopt_no_pad(struct sock *sk, sockptr_t optval,
                                    unsigned int optlen)
{
        struct tls_context *ctx = tls_get_ctx(sk);
        u32 val;
        int rc;

        if (ctx->prot_info.version != TLS_1_3_VERSION ||
            sockptr_is_null(optval) || optlen < sizeof(val))
                return -EINVAL;

        rc = copy_from_sockptr(&val, optval, sizeof(val));
        if (rc)
                return -EFAULT;
        if (val > 1)
                return -EINVAL;
        rc = check_zeroed_sockptr(optval, sizeof(val), optlen - sizeof(val));
        if (rc < 1)
                return rc == 0 ? -EINVAL : rc;

        lock_sock(sk);
        rc = -EINVAL;
        if (ctx->rx_conf == TLS_SW || ctx->rx_conf == TLS_HW) {
                ctx->rx_no_pad = val;
                tls_update_rx_zc_capable(ctx);
                rc = 0;
        }
        release_sock(sk);

        return rc;
}

static int do_tls_setsockopt(struct sock *sk, int optname, sockptr_t optval,
                             unsigned int optlen)
{
        int rc = 0;

        switch (optname) {
        case TLS_TX:
        case TLS_RX:
                lock_sock(sk);
                rc = do_tls_setsockopt_conf(sk, optval, optlen,
                                            optname == TLS_TX);
                release_sock(sk);
                break;
        case TLS_TX_ZEROCOPY_RO:
                lock_sock(sk);
                rc = do_tls_setsockopt_tx_zc(sk, optval, optlen);
                release_sock(sk);
                break;
        case TLS_RX_EXPECT_NO_PAD:
                rc = do_tls_setsockopt_no_pad(sk, optval, optlen);
                break;
        default:
                rc = -ENOPROTOOPT;
                break;
        }
        return rc;
}

static int tls_setsockopt(struct sock *sk, int level, int optname,
                          sockptr_t optval, unsigned int optlen)
{
        struct tls_context *ctx = tls_get_ctx(sk);

        if (level != SOL_TLS)
                return ctx->sk_proto->setsockopt(sk, level, optname, optval,
                                                 optlen);

        return do_tls_setsockopt(sk, optname, optval, optlen);
}

struct tls_context *tls_ctx_create(struct sock *sk)
{
        struct inet_connection_sock *icsk = inet_csk(sk);
        struct tls_context *ctx;

        ctx = kzalloc(sizeof(*ctx), GFP_ATOMIC);
        if (!ctx)
                return NULL;

        mutex_init(&ctx->tx_lock);
        rcu_assign_pointer(icsk->icsk_ulp_data, ctx);
        ctx->sk_proto = READ_ONCE(sk->sk_prot);
        ctx->sk = sk;
        return ctx;
}

static void build_proto_ops(struct proto_ops ops[TLS_NUM_CONFIG][TLS_NUM_CONFIG],
                            const struct proto_ops *base)
{
        ops[TLS_BASE][TLS_BASE] = *base;

        ops[TLS_SW  ][TLS_BASE] = ops[TLS_BASE][TLS_BASE];
        ops[TLS_SW  ][TLS_BASE].splice_eof      = tls_sw_splice_eof;
        ops[TLS_SW  ][TLS_BASE].sendpage_locked = tls_sw_sendpage_locked;

        ops[TLS_BASE][TLS_SW  ] = ops[TLS_BASE][TLS_BASE];
        ops[TLS_BASE][TLS_SW  ].splice_read     = tls_sw_splice_read;

        ops[TLS_SW  ][TLS_SW  ] = ops[TLS_SW  ][TLS_BASE];
        ops[TLS_SW  ][TLS_SW  ].splice_read     = tls_sw_splice_read;

#ifdef CONFIG_TLS_DEVICE
        ops[TLS_HW  ][TLS_BASE] = ops[TLS_BASE][TLS_BASE];
        ops[TLS_HW  ][TLS_BASE].sendpage_locked = NULL;

        ops[TLS_HW  ][TLS_SW  ] = ops[TLS_BASE][TLS_SW  ];
        ops[TLS_HW  ][TLS_SW  ].sendpage_locked = NULL;

        ops[TLS_BASE][TLS_HW  ] = ops[TLS_BASE][TLS_SW  ];

        ops[TLS_SW  ][TLS_HW  ] = ops[TLS_SW  ][TLS_SW  ];

        ops[TLS_HW  ][TLS_HW  ] = ops[TLS_HW  ][TLS_SW  ];
        ops[TLS_HW  ][TLS_HW  ].sendpage_locked = NULL;
#endif
#ifdef CONFIG_TLS_TOE
        ops[TLS_HW_RECORD][TLS_HW_RECORD] = *base;
#endif
}

static void tls_build_proto(struct sock *sk)
{
        int ip_ver = sk->sk_family == AF_INET6 ? TLSV6 : TLSV4;
        struct proto *prot = READ_ONCE(sk->sk_prot);

        /* Build IPv6 TLS whenever the address of tcpv6_prot changes */
        if (ip_ver == TLSV6 &&
            unlikely(prot != smp_load_acquire(&saved_tcpv6_prot))) {
                mutex_lock(&tcpv6_prot_mutex);
                if (likely(prot != saved_tcpv6_prot)) {
                        build_protos(tls_prots[TLSV6], prot);
                        build_proto_ops(tls_proto_ops[TLSV6],
                                        sk->sk_socket->ops);
                        smp_store_release(&saved_tcpv6_prot, prot);
                }
                mutex_unlock(&tcpv6_prot_mutex);
        }

        if (ip_ver == TLSV4 &&
            unlikely(prot != smp_load_acquire(&saved_tcpv4_prot))) {
                mutex_lock(&tcpv4_prot_mutex);
                if (likely(prot != saved_tcpv4_prot)) {
                        build_protos(tls_prots[TLSV4], prot);
                        build_proto_ops(tls_proto_ops[TLSV4],
                                        sk->sk_socket->ops);
                        smp_store_release(&saved_tcpv4_prot, prot);
                }
                mutex_unlock(&tcpv4_prot_mutex);
        }
}

static void build_protos(struct proto prot[TLS_NUM_CONFIG][TLS_NUM_CONFIG],
                         const struct proto *base)
{
        prot[TLS_BASE][TLS_BASE] = *base;
        prot[TLS_BASE][TLS_BASE].setsockopt     = tls_setsockopt;
        prot[TLS_BASE][TLS_BASE].getsockopt     = tls_getsockopt;
        prot[TLS_BASE][TLS_BASE].close          = tls_sk_proto_close;

        prot[TLS_SW][TLS_BASE] = prot[TLS_BASE][TLS_BASE];
        prot[TLS_SW][TLS_BASE].sendmsg          = tls_sw_sendmsg;
        prot[TLS_SW][TLS_BASE].splice_eof       = tls_sw_splice_eof;
        prot[TLS_SW][TLS_BASE].sendpage         = tls_sw_sendpage;

        prot[TLS_BASE][TLS_SW] = prot[TLS_BASE][TLS_BASE];
        prot[TLS_BASE][TLS_SW].recvmsg            = tls_sw_recvmsg;
        prot[TLS_BASE][TLS_SW].sock_is_readable   = tls_sw_sock_is_readable;
        prot[TLS_BASE][TLS_SW].close              = tls_sk_proto_close;

        prot[TLS_SW][TLS_SW] = prot[TLS_SW][TLS_BASE];
        prot[TLS_SW][TLS_SW].recvmsg            = tls_sw_recvmsg;
        prot[TLS_SW][TLS_SW].sock_is_readable   = tls_sw_sock_is_readable;
        prot[TLS_SW][TLS_SW].close              = tls_sk_proto_close;

#ifdef CONFIG_TLS_DEVICE
        prot[TLS_HW][TLS_BASE] = prot[TLS_BASE][TLS_BASE];
        prot[TLS_HW][TLS_BASE].sendmsg          = tls_device_sendmsg;
        prot[TLS_HW][TLS_BASE].sendpage         = tls_device_sendpage;

        prot[TLS_HW][TLS_SW] = prot[TLS_BASE][TLS_SW];
        prot[TLS_HW][TLS_SW].sendmsg            = tls_device_sendmsg;
        prot[TLS_HW][TLS_SW].sendpage           = tls_device_sendpage;

        prot[TLS_BASE][TLS_HW] = prot[TLS_BASE][TLS_SW];

        prot[TLS_SW][TLS_HW] = prot[TLS_SW][TLS_SW];

        prot[TLS_HW][TLS_HW] = prot[TLS_HW][TLS_SW];
#endif
#ifdef CONFIG_TLS_TOE
        prot[TLS_HW_RECORD][TLS_HW_RECORD] = *base;
        prot[TLS_HW_RECORD][TLS_HW_RECORD].hash         = tls_toe_hash;
        prot[TLS_HW_RECORD][TLS_HW_RECORD].unhash       = tls_toe_unhash;
#endif
}
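
/*
 * The matrices built above are indexed as [tx_conf][rx_conf]; e.g. a socket
 * with software TX and device-offloaded RX runs on prot[TLS_SW][TLS_HW],
 * which update_sk_prot() installs via
 * &tls_prots[ip_ver][ctx->tx_conf][ctx->rx_conf].
 */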

static int tls_init(struct sock *sk)
{
        struct tls_context *ctx;
        int rc = 0;

        tls_build_proto(sk);

#ifdef CONFIG_TLS_TOE
        if (tls_toe_bypass(sk))
                return 0;
#endif

        /* The TLS ulp is currently supported only for TCP sockets
         * in ESTABLISHED state.
         * Supporting sockets in LISTEN state will require us
         * to modify the accept implementation to clone rather than
         * share the ulp context.
         */
        if (sk->sk_state != TCP_ESTABLISHED)
                return -ENOTCONN;

        /* allocate tls context */
        write_lock_bh(&sk->sk_callback_lock);
        ctx = tls_ctx_create(sk);
        if (!ctx) {
                rc = -ENOMEM;
                goto out;
        }

        ctx->tx_conf = TLS_BASE;
        ctx->rx_conf = TLS_BASE;
        update_sk_prot(sk, ctx);
out:
        write_unlock_bh(&sk->sk_callback_lock);
        return rc;
}

static void tls_update(struct sock *sk, struct proto *p,
                       void (*write_space)(struct sock *sk))
{
        struct tls_context *ctx;

        WARN_ON_ONCE(sk->sk_prot == p);

        ctx = tls_get_ctx(sk);
        if (likely(ctx)) {
                ctx->sk_write_space = write_space;
                ctx->sk_proto = p;
        } else {
                /* Pairs with lockless read in sk_clone_lock(). */
                WRITE_ONCE(sk->sk_prot, p);
                sk->sk_write_space = write_space;
        }
}

static u16 tls_user_config(struct tls_context *ctx, bool tx)
{
        u16 config = tx ? ctx->tx_conf : ctx->rx_conf;

        switch (config) {
        case TLS_BASE:
                return TLS_CONF_BASE;
        case TLS_SW:
                return TLS_CONF_SW;
        case TLS_HW:
                return TLS_CONF_HW;
        case TLS_HW_RECORD:
                return TLS_CONF_HW_RECORD;
        }
        return 0;
}

static int tls_get_info(struct sock *sk, struct sk_buff *skb)
{
        u16 version, cipher_type;
        struct tls_context *ctx;
        struct nlattr *start;
        int err;

        start = nla_nest_start_noflag(skb, INET_ULP_INFO_TLS);
        if (!start)
                return -EMSGSIZE;

        rcu_read_lock();
        ctx = rcu_dereference(inet_csk(sk)->icsk_ulp_data);
        if (!ctx) {
                err = 0;
                goto nla_failure;
        }
        version = ctx->prot_info.version;
        if (version) {
                err = nla_put_u16(skb, TLS_INFO_VERSION, version);
                if (err)
                        goto nla_failure;
        }
        cipher_type = ctx->prot_info.cipher_type;
        if (cipher_type) {
                err = nla_put_u16(skb, TLS_INFO_CIPHER, cipher_type);
                if (err)
                        goto nla_failure;
        }
        err = nla_put_u16(skb, TLS_INFO_TXCONF, tls_user_config(ctx, true));
        if (err)
                goto nla_failure;

        err = nla_put_u16(skb, TLS_INFO_RXCONF, tls_user_config(ctx, false));
        if (err)
                goto nla_failure;

        if (ctx->tx_conf == TLS_HW && ctx->zerocopy_sendfile) {
                err = nla_put_flag(skb, TLS_INFO_ZC_RO_TX);
                if (err)
                        goto nla_failure;
        }
        if (ctx->rx_no_pad) {
                err = nla_put_flag(skb, TLS_INFO_RX_NO_PAD);
                if (err)
                        goto nla_failure;
        }

        rcu_read_unlock();
        nla_nest_end(skb, start);
        return 0;

nla_failure:
        rcu_read_unlock();
        nla_nest_cancel(skb, start);
        return err;
}

static size_t tls_get_info_size(const struct sock *sk)
{
        size_t size = 0;

        size += nla_total_size(0) +             /* INET_ULP_INFO_TLS */
                nla_total_size(sizeof(u16)) +   /* TLS_INFO_VERSION */
                nla_total_size(sizeof(u16)) +   /* TLS_INFO_CIPHER */
                nla_total_size(sizeof(u16)) +   /* TLS_INFO_RXCONF */
                nla_total_size(sizeof(u16)) +   /* TLS_INFO_TXCONF */
                nla_total_size(0) +             /* TLS_INFO_ZC_RO_TX */
                nla_total_size(0) +             /* TLS_INFO_RX_NO_PAD */
                0;

        return size;
}

static int __net_init tls_init_net(struct net *net)
{
        int err;

        net->mib.tls_statistics = alloc_percpu(struct linux_tls_mib);
        if (!net->mib.tls_statistics)
                return -ENOMEM;

        err = tls_proc_init(net);
        if (err)
                goto err_free_stats;

        return 0;
err_free_stats:
        free_percpu(net->mib.tls_statistics);
        return err;
}

static void __net_exit tls_exit_net(struct net *net)
{
        tls_proc_fini(net);
        free_percpu(net->mib.tls_statistics);
}

static struct pernet_operations tls_proc_ops = {
        .init = tls_init_net,
        .exit = tls_exit_net,
};

static struct tcp_ulp_ops tcp_tls_ulp_ops __read_mostly = {
        .name                   = "tls",
        .owner                  = THIS_MODULE,
        .init                   = tls_init,
        .update                 = tls_update,
        .get_info               = tls_get_info,
        .get_info_size          = tls_get_info_size,
};

static int __init tls_register(void)
{
        int err;

        err = register_pernet_subsys(&tls_proc_ops);
        if (err)
                return err;

        err = tls_strp_dev_init();
        if (err)
                goto err_pernet;

        err = tls_device_init();
        if (err)
                goto err_strp;

        tcp_register_ulp(&tcp_tls_ulp_ops);

        return 0;
err_strp:
        tls_strp_dev_exit();
err_pernet:
        unregister_pernet_subsys(&tls_proc_ops);
        return err;
}

static void __exit tls_unregister(void)
{
        tcp_unregister_ulp(&tcp_tls_ulp_ops);
        tls_strp_dev_exit();
        tls_device_cleanup();
        unregister_pernet_subsys(&tls_proc_ops);
}

module_init(tls_register);
module_exit(tls_unregister);