GNU Linux-libre 4.14.266-gnu1
[releases.git] / net / tls / tls_main.c
1 /*
2  * Copyright (c) 2016-2017, Mellanox Technologies. All rights reserved.
3  * Copyright (c) 2016-2017, Dave Watson <davejwatson@fb.com>. All rights reserved.
4  *
5  * This software is available to you under a choice of one of two
6  * licenses.  You may choose to be licensed under the terms of the GNU
7  * General Public License (GPL) Version 2, available from the file
8  * COPYING in the main directory of this source tree, or the
9  * OpenIB.org BSD license below:
10  *
11  *     Redistribution and use in source and binary forms, with or
12  *     without modification, are permitted provided that the following
13  *     conditions are met:
14  *
15  *      - Redistributions of source code must retain the above
16  *        copyright notice, this list of conditions and the following
17  *        disclaimer.
18  *
19  *      - Redistributions in binary form must reproduce the above
20  *        copyright notice, this list of conditions and the following
21  *        disclaimer in the documentation and/or other materials
22  *        provided with the distribution.
23  *
24  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
25  * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
26  * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
27  * NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS
28  * BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN
29  * ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
30  * CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
31  * SOFTWARE.
32  */
33
34 #include <linux/module.h>
35
36 #include <net/tcp.h>
37 #include <net/inet_common.h>
38 #include <linux/highmem.h>
39 #include <linux/netdevice.h>
40 #include <linux/sched/signal.h>
41
42 #include <net/tls.h>
43
44 MODULE_AUTHOR("Mellanox Technologies");
45 MODULE_DESCRIPTION("Transport Layer Security Support");
46 MODULE_LICENSE("Dual BSD/GPL");
47 MODULE_ALIAS_TCP_ULP("tls");
48
49 enum {
50         TLSV4,
51         TLSV6,
52         TLS_NUM_PROTS,
53 };
54
55 enum {
56         TLS_BASE_TX,
57         TLS_SW_TX,
58         TLS_NUM_CONFIG,
59 };
60
61 static struct proto *saved_tcpv6_prot;
62 static DEFINE_MUTEX(tcpv6_prot_mutex);
63 static struct proto tls_prots[TLS_NUM_PROTS][TLS_NUM_CONFIG];
64
65 static inline void update_sk_prot(struct sock *sk, struct tls_context *ctx)
66 {
67         int ip_ver = sk->sk_family == AF_INET6 ? TLSV6 : TLSV4;
68
69         sk->sk_prot = &tls_prots[ip_ver][ctx->tx_conf];
70 }
71
72 int wait_on_pending_writer(struct sock *sk, long *timeo)
73 {
74         int rc = 0;
75         DEFINE_WAIT_FUNC(wait, woken_wake_function);
76
77         add_wait_queue(sk_sleep(sk), &wait);
78         while (1) {
79                 if (!*timeo) {
80                         rc = -EAGAIN;
81                         break;
82                 }
83
84                 if (signal_pending(current)) {
85                         rc = sock_intr_errno(*timeo);
86                         break;
87                 }
88
89                 if (sk_wait_event(sk, timeo, !sk->sk_write_pending, &wait))
90                         break;
91         }
92         remove_wait_queue(sk_sleep(sk), &wait);
93         return rc;
94 }
95
96 int tls_push_sg(struct sock *sk,
97                 struct tls_context *ctx,
98                 struct scatterlist *sg,
99                 u16 first_offset,
100                 int flags)
101 {
102         int sendpage_flags = flags | MSG_SENDPAGE_NOTLAST;
103         int ret = 0;
104         struct page *p;
105         size_t size;
106         int offset = first_offset;
107
108         size = sg->length - offset;
109         offset += sg->offset;
110
111         ctx->in_tcp_sendpages = true;
112         while (1) {
113                 if (sg_is_last(sg))
114                         sendpage_flags = flags;
115
116                 /* is sending application-limited? */
117                 tcp_rate_check_app_limited(sk);
118                 p = sg_page(sg);
119 retry:
120                 ret = do_tcp_sendpages(sk, p, offset, size, sendpage_flags);
121
122                 if (ret != size) {
123                         if (ret > 0) {
124                                 offset += ret;
125                                 size -= ret;
126                                 goto retry;
127                         }
128
129                         offset -= sg->offset;
130                         ctx->partially_sent_offset = offset;
131                         ctx->partially_sent_record = (void *)sg;
132                         ctx->in_tcp_sendpages = false;
133                         return ret;
134                 }
135
136                 put_page(p);
137                 sk_mem_uncharge(sk, sg->length);
138                 sg = sg_next(sg);
139                 if (!sg)
140                         break;
141
142                 offset = sg->offset;
143                 size = sg->length;
144         }
145
146         clear_bit(TLS_PENDING_CLOSED_RECORD, &ctx->flags);
147         ctx->in_tcp_sendpages = false;
148         ctx->sk_write_space(sk);
149
150         return 0;
151 }
152
153 static int tls_handle_open_record(struct sock *sk, int flags)
154 {
155         struct tls_context *ctx = tls_get_ctx(sk);
156
157         if (tls_is_pending_open_record(ctx))
158                 return ctx->push_pending_record(sk, flags);
159
160         return 0;
161 }
162
163 int tls_proccess_cmsg(struct sock *sk, struct msghdr *msg,
164                       unsigned char *record_type)
165 {
166         struct cmsghdr *cmsg;
167         int rc = -EINVAL;
168
169         for_each_cmsghdr(cmsg, msg) {
170                 if (!CMSG_OK(msg, cmsg))
171                         return -EINVAL;
172                 if (cmsg->cmsg_level != SOL_TLS)
173                         continue;
174
175                 switch (cmsg->cmsg_type) {
176                 case TLS_SET_RECORD_TYPE:
177                         if (cmsg->cmsg_len < CMSG_LEN(sizeof(*record_type)))
178                                 return -EINVAL;
179
180                         if (msg->msg_flags & MSG_MORE)
181                                 return -EINVAL;
182
183                         rc = tls_handle_open_record(sk, msg->msg_flags);
184                         if (rc)
185                                 return rc;
186
187                         *record_type = *(unsigned char *)CMSG_DATA(cmsg);
188                         rc = 0;
189                         break;
190                 default:
191                         return -EINVAL;
192                 }
193         }
194
195         return rc;
196 }
197
198 int tls_push_pending_closed_record(struct sock *sk, struct tls_context *ctx,
199                                    int flags, long *timeo)
200 {
201         struct scatterlist *sg;
202         u16 offset;
203
204         if (!tls_is_partially_sent_record(ctx))
205                 return ctx->push_pending_record(sk, flags);
206
207         sg = ctx->partially_sent_record;
208         offset = ctx->partially_sent_offset;
209
210         ctx->partially_sent_record = NULL;
211         return tls_push_sg(sk, ctx, sg, offset, flags);
212 }
213
214 static void tls_write_space(struct sock *sk)
215 {
216         struct tls_context *ctx = tls_get_ctx(sk);
217
218         /* If in_tcp_sendpages call lower protocol write space handler
219          * to ensure we wake up any waiting operations there. For example
220          * if do_tcp_sendpages where to call sk_wait_event.
221          */
222         if (ctx->in_tcp_sendpages) {
223                 ctx->sk_write_space(sk);
224                 return;
225         }
226
227         if (!sk->sk_write_pending && tls_is_pending_closed_record(ctx)) {
228                 gfp_t sk_allocation = sk->sk_allocation;
229                 int rc;
230                 long timeo = 0;
231
232                 sk->sk_allocation = GFP_ATOMIC;
233                 rc = tls_push_pending_closed_record(sk, ctx,
234                                                     MSG_DONTWAIT |
235                                                     MSG_NOSIGNAL,
236                                                     &timeo);
237                 sk->sk_allocation = sk_allocation;
238
239                 if (rc < 0)
240                         return;
241         }
242
243         ctx->sk_write_space(sk);
244 }
245
246 static void tls_ctx_free(struct tls_context *ctx)
247 {
248         if (!ctx)
249                 return;
250
251         memzero_explicit(&ctx->crypto_send, sizeof(ctx->crypto_send));
252         kfree(ctx);
253 }
254
255 static void tls_sk_proto_close(struct sock *sk, long timeout)
256 {
257         struct tls_context *ctx = tls_get_ctx(sk);
258         long timeo = sock_sndtimeo(sk, 0);
259         void (*sk_proto_close)(struct sock *sk, long timeout);
260
261         lock_sock(sk);
262         sk_proto_close = ctx->sk_proto_close;
263
264         if (ctx->tx_conf == TLS_BASE_TX) {
265                 tls_ctx_free(ctx);
266                 goto skip_tx_cleanup;
267         }
268
269         if (!tls_complete_pending_work(sk, ctx, 0, &timeo))
270                 tls_handle_open_record(sk, 0);
271
272         if (ctx->partially_sent_record) {
273                 struct scatterlist *sg = ctx->partially_sent_record;
274
275                 while (1) {
276                         put_page(sg_page(sg));
277                         sk_mem_uncharge(sk, sg->length);
278
279                         if (sg_is_last(sg))
280                                 break;
281                         sg++;
282                 }
283         }
284
285         kfree(ctx->rec_seq);
286         kfree(ctx->iv);
287
288         if (ctx->tx_conf == TLS_SW_TX) {
289                 tls_sw_free_tx_resources(sk);
290                 tls_ctx_free(ctx);
291         }
292
293 skip_tx_cleanup:
294         release_sock(sk);
295         sk_proto_close(sk, timeout);
296 }
297
298 static int do_tls_getsockopt_tx(struct sock *sk, char __user *optval,
299                                 int __user *optlen)
300 {
301         int rc = 0;
302         struct tls_context *ctx = tls_get_ctx(sk);
303         struct tls_crypto_info *crypto_info;
304         int len;
305
306         if (get_user(len, optlen))
307                 return -EFAULT;
308
309         if (!optval || (len < sizeof(*crypto_info))) {
310                 rc = -EINVAL;
311                 goto out;
312         }
313
314         if (!ctx) {
315                 rc = -EBUSY;
316                 goto out;
317         }
318
319         /* get user crypto info */
320         crypto_info = &ctx->crypto_send.info;
321
322         if (!TLS_CRYPTO_INFO_READY(crypto_info)) {
323                 rc = -EBUSY;
324                 goto out;
325         }
326
327         if (len == sizeof(*crypto_info)) {
328                 if (copy_to_user(optval, crypto_info, sizeof(*crypto_info)))
329                         rc = -EFAULT;
330                 goto out;
331         }
332
333         switch (crypto_info->cipher_type) {
334         case TLS_CIPHER_AES_GCM_128: {
335                 struct tls12_crypto_info_aes_gcm_128 *
336                   crypto_info_aes_gcm_128 =
337                   container_of(crypto_info,
338                                struct tls12_crypto_info_aes_gcm_128,
339                                info);
340
341                 if (len != sizeof(*crypto_info_aes_gcm_128)) {
342                         rc = -EINVAL;
343                         goto out;
344                 }
345                 lock_sock(sk);
346                 memcpy(crypto_info_aes_gcm_128->iv,
347                        ctx->iv + TLS_CIPHER_AES_GCM_128_SALT_SIZE,
348                        TLS_CIPHER_AES_GCM_128_IV_SIZE);
349                 release_sock(sk);
350                 if (copy_to_user(optval,
351                                  crypto_info_aes_gcm_128,
352                                  sizeof(*crypto_info_aes_gcm_128)))
353                         rc = -EFAULT;
354                 break;
355         }
356         default:
357                 rc = -EINVAL;
358         }
359
360 out:
361         return rc;
362 }
363
364 static int do_tls_getsockopt(struct sock *sk, int optname,
365                              char __user *optval, int __user *optlen)
366 {
367         int rc = 0;
368
369         switch (optname) {
370         case TLS_TX:
371                 rc = do_tls_getsockopt_tx(sk, optval, optlen);
372                 break;
373         default:
374                 rc = -ENOPROTOOPT;
375                 break;
376         }
377         return rc;
378 }
379
380 static int tls_getsockopt(struct sock *sk, int level, int optname,
381                           char __user *optval, int __user *optlen)
382 {
383         struct tls_context *ctx = tls_get_ctx(sk);
384
385         if (level != SOL_TLS)
386                 return ctx->getsockopt(sk, level, optname, optval, optlen);
387
388         return do_tls_getsockopt(sk, optname, optval, optlen);
389 }
390
391 static int do_tls_setsockopt_tx(struct sock *sk, char __user *optval,
392                                 unsigned int optlen)
393 {
394         struct tls_crypto_info *crypto_info;
395         struct tls_context *ctx = tls_get_ctx(sk);
396         int rc = 0;
397         int tx_conf;
398
399         if (!optval || (optlen < sizeof(*crypto_info))) {
400                 rc = -EINVAL;
401                 goto out;
402         }
403
404         crypto_info = &ctx->crypto_send.info;
405         /* Currently we don't support set crypto info more than one time */
406         if (TLS_CRYPTO_INFO_READY(crypto_info)) {
407                 rc = -EBUSY;
408                 goto out;
409         }
410
411         rc = copy_from_user(crypto_info, optval, sizeof(*crypto_info));
412         if (rc) {
413                 rc = -EFAULT;
414                 goto out;
415         }
416
417         /* check version */
418         if (crypto_info->version != TLS_1_2_VERSION) {
419                 rc = -ENOTSUPP;
420                 goto err_crypto_info;
421         }
422
423         switch (crypto_info->cipher_type) {
424         case TLS_CIPHER_AES_GCM_128: {
425                 if (optlen != sizeof(struct tls12_crypto_info_aes_gcm_128)) {
426                         rc = -EINVAL;
427                         goto err_crypto_info;
428                 }
429                 rc = copy_from_user(crypto_info + 1, optval + sizeof(*crypto_info),
430                                     optlen - sizeof(*crypto_info));
431                 if (rc) {
432                         rc = -EFAULT;
433                         goto err_crypto_info;
434                 }
435                 break;
436         }
437         default:
438                 rc = -EINVAL;
439                 goto err_crypto_info;
440         }
441
442         /* currently SW is default, we will have ethtool in future */
443         rc = tls_set_sw_offload(sk, ctx);
444         tx_conf = TLS_SW_TX;
445         if (rc)
446                 goto err_crypto_info;
447
448         ctx->tx_conf = tx_conf;
449         update_sk_prot(sk, ctx);
450         ctx->sk_write_space = sk->sk_write_space;
451         sk->sk_write_space = tls_write_space;
452         goto out;
453
454 err_crypto_info:
455         memzero_explicit(crypto_info, sizeof(union tls_crypto_context));
456 out:
457         return rc;
458 }
459
460 static int do_tls_setsockopt(struct sock *sk, int optname,
461                              char __user *optval, unsigned int optlen)
462 {
463         int rc = 0;
464
465         switch (optname) {
466         case TLS_TX:
467                 lock_sock(sk);
468                 rc = do_tls_setsockopt_tx(sk, optval, optlen);
469                 release_sock(sk);
470                 break;
471         default:
472                 rc = -ENOPROTOOPT;
473                 break;
474         }
475         return rc;
476 }
477
478 static int tls_setsockopt(struct sock *sk, int level, int optname,
479                           char __user *optval, unsigned int optlen)
480 {
481         struct tls_context *ctx = tls_get_ctx(sk);
482
483         if (level != SOL_TLS)
484                 return ctx->setsockopt(sk, level, optname, optval, optlen);
485
486         return do_tls_setsockopt(sk, optname, optval, optlen);
487 }
488
489 static void build_protos(struct proto *prot, struct proto *base)
490 {
491         prot[TLS_BASE_TX] = *base;
492         prot[TLS_BASE_TX].setsockopt    = tls_setsockopt;
493         prot[TLS_BASE_TX].getsockopt    = tls_getsockopt;
494         prot[TLS_BASE_TX].close         = tls_sk_proto_close;
495
496         prot[TLS_SW_TX] = prot[TLS_BASE_TX];
497         prot[TLS_SW_TX].sendmsg         = tls_sw_sendmsg;
498         prot[TLS_SW_TX].sendpage        = tls_sw_sendpage;
499 }
500
501 static int tls_init(struct sock *sk)
502 {
503         int ip_ver = sk->sk_family == AF_INET6 ? TLSV6 : TLSV4;
504         struct inet_connection_sock *icsk = inet_csk(sk);
505         struct tls_context *ctx;
506         int rc = 0;
507
508         /* The TLS ulp is currently supported only for TCP sockets
509          * in ESTABLISHED state.
510          * Supporting sockets in LISTEN state will require us
511          * to modify the accept implementation to clone rather then
512          * share the ulp context.
513          */
514         if (sk->sk_state != TCP_ESTABLISHED)
515                 return -ENOTSUPP;
516
517         /* allocate tls context */
518         ctx = kzalloc(sizeof(*ctx), GFP_KERNEL);
519         if (!ctx) {
520                 rc = -ENOMEM;
521                 goto out;
522         }
523         icsk->icsk_ulp_data = ctx;
524         ctx->setsockopt = sk->sk_prot->setsockopt;
525         ctx->getsockopt = sk->sk_prot->getsockopt;
526         ctx->sk_proto_close = sk->sk_prot->close;
527
528         /* Build IPv6 TLS whenever the address of tcpv6_prot changes */
529         if (ip_ver == TLSV6 &&
530             unlikely(sk->sk_prot != smp_load_acquire(&saved_tcpv6_prot))) {
531                 mutex_lock(&tcpv6_prot_mutex);
532                 if (likely(sk->sk_prot != saved_tcpv6_prot)) {
533                         build_protos(tls_prots[TLSV6], sk->sk_prot);
534                         smp_store_release(&saved_tcpv6_prot, sk->sk_prot);
535                 }
536                 mutex_unlock(&tcpv6_prot_mutex);
537         }
538
539         ctx->tx_conf = TLS_BASE_TX;
540         update_sk_prot(sk, ctx);
541 out:
542         return rc;
543 }
544
545 static struct tcp_ulp_ops tcp_tls_ulp_ops __read_mostly = {
546         .name                   = "tls",
547         .owner                  = THIS_MODULE,
548         .init                   = tls_init,
549 };
550
551 static int __init tls_register(void)
552 {
553         build_protos(tls_prots[TLSV4], &tcp_prot);
554
555         tcp_register_ulp(&tcp_tls_ulp_ops);
556
557         return 0;
558 }
559
560 static void __exit tls_unregister(void)
561 {
562         tcp_unregister_ulp(&tcp_tls_ulp_ops);
563 }
564
565 module_init(tls_register);
566 module_exit(tls_unregister);