GNU Linux-libre 4.19.207-gnu1
net/tls/tls_device.c
1 /* Copyright (c) 2018, Mellanox Technologies All rights reserved.
2  *
3  * This software is available to you under a choice of one of two
4  * licenses.  You may choose to be licensed under the terms of the GNU
5  * General Public License (GPL) Version 2, available from the file
6  * COPYING in the main directory of this source tree, or the
7  * OpenIB.org BSD license below:
8  *
9  *     Redistribution and use in source and binary forms, with or
10  *     without modification, are permitted provided that the following
11  *     conditions are met:
12  *
13  *      - Redistributions of source code must retain the above
14  *        copyright notice, this list of conditions and the following
15  *        disclaimer.
16  *
17  *      - Redistributions in binary form must reproduce the above
18  *        copyright notice, this list of conditions and the following
19  *        disclaimer in the documentation and/or other materials
20  *        provided with the distribution.
21  *
22  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
23  * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
24  * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
25  * NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS
26  * BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN
27  * ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
28  * CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
29  * SOFTWARE.
30  */
31
32 #include <crypto/aead.h>
33 #include <linux/highmem.h>
34 #include <linux/module.h>
35 #include <linux/netdevice.h>
36 #include <net/dst.h>
37 #include <net/inet_connection_sock.h>
38 #include <net/tcp.h>
39 #include <net/tls.h>
40
41 /* device_offload_lock is used to synchronize tls_dev_add
42  * against NETDEV_DOWN notifications.
43  */
44 static DECLARE_RWSEM(device_offload_lock);
45
46 static void tls_device_gc_task(struct work_struct *work);
47
48 static DECLARE_WORK(tls_device_gc_work, tls_device_gc_task);
49 static LIST_HEAD(tls_device_gc_list);
50 static LIST_HEAD(tls_device_list);
51 static DEFINE_SPINLOCK(tls_device_lock);
52
53 static void tls_device_free_ctx(struct tls_context *ctx)
54 {
55         if (ctx->tx_conf == TLS_HW) {
56                 kfree(tls_offload_ctx_tx(ctx));
57                 kfree(ctx->tx.rec_seq);
58                 kfree(ctx->tx.iv);
59         }
60
61         if (ctx->rx_conf == TLS_HW)
62                 kfree(tls_offload_ctx_rx(ctx));
63
64         tls_ctx_free(ctx);
65 }
66
67 static void tls_device_gc_task(struct work_struct *work)
68 {
69         struct tls_context *ctx, *tmp;
70         unsigned long flags;
71         LIST_HEAD(gc_list);
72
73         spin_lock_irqsave(&tls_device_lock, flags);
74         list_splice_init(&tls_device_gc_list, &gc_list);
75         spin_unlock_irqrestore(&tls_device_lock, flags);
76
77         list_for_each_entry_safe(ctx, tmp, &gc_list, list) {
78                 struct net_device *netdev = ctx->netdev;
79
80                 if (netdev && ctx->tx_conf == TLS_HW) {
81                         netdev->tlsdev_ops->tls_dev_del(netdev, ctx,
82                                                         TLS_OFFLOAD_CTX_DIR_TX);
83                         dev_put(netdev);
84                         ctx->netdev = NULL;
85                 }
86
87                 list_del(&ctx->list);
88                 tls_device_free_ctx(ctx);
89         }
90 }
91
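/* Take over the socket the first time a direction is offloaded: grab a
 * reference on the context and on the netdev, link the context into
 * tls_device_list, and chain the original sk_destruct behind
 * tls_device_sk_destruct so device resources are released on socket teardown.
 */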
92 static void tls_device_attach(struct tls_context *ctx, struct sock *sk,
93                               struct net_device *netdev)
94 {
95         if (sk->sk_destruct != tls_device_sk_destruct) {
96                 refcount_set(&ctx->refcount, 1);
97                 dev_hold(netdev);
98                 ctx->netdev = netdev;
99                 spin_lock_irq(&tls_device_lock);
100                 list_add_tail(&ctx->list, &tls_device_list);
101                 spin_unlock_irq(&tls_device_lock);
102
103                 ctx->sk_destruct = sk->sk_destruct;
104                 sk->sk_destruct = tls_device_sk_destruct;
105         }
106 }
107
108 static void tls_device_queue_ctx_destruction(struct tls_context *ctx)
109 {
110         unsigned long flags;
111
112         spin_lock_irqsave(&tls_device_lock, flags);
113         list_move_tail(&ctx->list, &tls_device_gc_list);
114
115         /* schedule_work inside the spinlock
116          * to make sure tls_device_down waits for that work.
117          */
118         schedule_work(&tls_device_gc_work);
119
120         spin_unlock_irqrestore(&tls_device_lock, flags);
121 }
122
123 /* We assume that the socket is already connected */
124 static struct net_device *get_netdev_for_sock(struct sock *sk)
125 {
126         struct dst_entry *dst = sk_dst_get(sk);
127         struct net_device *netdev = NULL;
128
129         if (likely(dst)) {
130                 netdev = dst->dev;
131                 dev_hold(netdev);
132         }
133
134         dst_release(dst);
135
136         return netdev;
137 }
138
139 static void destroy_record(struct tls_record_info *record)
140 {
141         int nr_frags = record->num_frags;
142         skb_frag_t *frag;
143
144         while (nr_frags-- > 0) {
145                 frag = &record->frags[nr_frags];
146                 __skb_frag_unref(frag);
147         }
148         kfree(record);
149 }
150
151 static void delete_all_records(struct tls_offload_context_tx *offload_ctx)
152 {
153         struct tls_record_info *info, *temp;
154
155         list_for_each_entry_safe(info, temp, &offload_ctx->records_list, list) {
156                 list_del(&info->list);
157                 destroy_record(info);
158         }
159
160         offload_ctx->retransmit_hint = NULL;
161 }
162
163 static void tls_icsk_clean_acked(struct sock *sk, u32 acked_seq)
164 {
165         struct tls_context *tls_ctx = tls_get_ctx(sk);
166         struct tls_record_info *info, *temp;
167         struct tls_offload_context_tx *ctx;
168         u64 deleted_records = 0;
169         unsigned long flags;
170
171         if (!tls_ctx)
172                 return;
173
174         ctx = tls_offload_ctx_tx(tls_ctx);
175
176         spin_lock_irqsave(&ctx->lock, flags);
177         info = ctx->retransmit_hint;
178         if (info && !before(acked_seq, info->end_seq)) {
179                 ctx->retransmit_hint = NULL;
180                 list_del(&info->list);
181                 destroy_record(info);
182                 deleted_records++;
183         }
184
185         list_for_each_entry_safe(info, temp, &ctx->records_list, list) {
186                 if (before(acked_seq, info->end_seq))
187                         break;
188                 list_del(&info->list);
189
190                 destroy_record(info);
191                 deleted_records++;
192         }
193
194         ctx->unacked_record_sn += deleted_records;
195         spin_unlock_irqrestore(&ctx->lock, flags);
196 }
197
198 /* At this point, there should be no references on this
199  * socket and no in-flight SKBs associated with this
200  * socket, so it is safe to free all the resources.
201  */
202 void tls_device_sk_destruct(struct sock *sk)
203 {
204         struct tls_context *tls_ctx = tls_get_ctx(sk);
205         struct tls_offload_context_tx *ctx = tls_offload_ctx_tx(tls_ctx);
206
207         tls_ctx->sk_destruct(sk);
208
209         if (tls_ctx->tx_conf == TLS_HW) {
210                 if (ctx->open_record)
211                         destroy_record(ctx->open_record);
212                 delete_all_records(ctx);
213                 crypto_free_aead(ctx->aead_send);
214                 clean_acked_data_disable(inet_csk(sk));
215         }
216
217         if (refcount_dec_and_test(&tls_ctx->refcount))
218                 tls_device_queue_ctx_destruction(tls_ctx);
219 }
220 EXPORT_SYMBOL(tls_device_sk_destruct);
221
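/* Append @size bytes from @pfrag to the open record: extend the last frag
 * when it is contiguous with the page fragment, otherwise start a new frag
 * and take an extra reference on the page, then advance @pfrag past the
 * appended bytes.
 */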
222 static void tls_append_frag(struct tls_record_info *record,
223                             struct page_frag *pfrag,
224                             int size)
225 {
226         skb_frag_t *frag;
227
228         frag = &record->frags[record->num_frags - 1];
229         if (frag->page.p == pfrag->page &&
230             frag->page_offset + frag->size == pfrag->offset) {
231                 frag->size += size;
232         } else {
233                 ++frag;
234                 frag->page.p = pfrag->page;
235                 frag->page_offset = pfrag->offset;
236                 frag->size = size;
237                 ++record->num_frags;
238                 get_page(pfrag->page);
239         }
240
241         pfrag->offset += size;
242         record->len += size;
243 }
244
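/* Close the currently open record: write the TLS header into the first frag,
 * append a dummy fragment where the device will place the authentication tag,
 * queue the record on records_list, advance the record sequence number and
 * hand the plaintext fragments to TCP via tls_push_sg().
 */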
245 static int tls_push_record(struct sock *sk,
246                            struct tls_context *ctx,
247                            struct tls_offload_context_tx *offload_ctx,
248                            struct tls_record_info *record,
249                            struct page_frag *pfrag,
250                            int flags,
251                            unsigned char record_type)
252 {
253         struct tcp_sock *tp = tcp_sk(sk);
254         struct page_frag dummy_tag_frag;
255         skb_frag_t *frag;
256         int i;
257
258         /* fill prepend */
259         frag = &record->frags[0];
260         tls_fill_prepend(ctx,
261                          skb_frag_address(frag),
262                          record->len - ctx->tx.prepend_size,
263                          record_type);
264
265         /* HW doesn't care about the data in the tag, because it fills it. */
266         dummy_tag_frag.page = skb_frag_page(frag);
267         dummy_tag_frag.offset = 0;
268
269         tls_append_frag(record, &dummy_tag_frag, ctx->tx.tag_size);
270         record->end_seq = tp->write_seq + record->len;
271         spin_lock_irq(&offload_ctx->lock);
272         list_add_tail(&record->list, &offload_ctx->records_list);
273         spin_unlock_irq(&offload_ctx->lock);
274         offload_ctx->open_record = NULL;
275         set_bit(TLS_PENDING_CLOSED_RECORD, &ctx->flags);
276         tls_advance_record_sn(sk, &ctx->tx);
277
278         for (i = 0; i < record->num_frags; i++) {
279                 frag = &record->frags[i];
280                 sg_unmark_end(&offload_ctx->sg_tx_data[i]);
281                 sg_set_page(&offload_ctx->sg_tx_data[i], skb_frag_page(frag),
282                             frag->size, frag->page_offset);
283                 sk_mem_charge(sk, frag->size);
284                 get_page(skb_frag_page(frag));
285         }
286         sg_mark_end(&offload_ctx->sg_tx_data[record->num_frags - 1]);
287
288         /* all ready, send */
289         return tls_push_sg(sk, ctx, offload_ctx->sg_tx_data, 0, flags);
290 }
291
292 static int tls_create_new_record(struct tls_offload_context_tx *offload_ctx,
293                                  struct page_frag *pfrag,
294                                  size_t prepend_size)
295 {
296         struct tls_record_info *record;
297         skb_frag_t *frag;
298
299         record = kmalloc(sizeof(*record), GFP_KERNEL);
300         if (!record)
301                 return -ENOMEM;
302
303         frag = &record->frags[0];
304         __skb_frag_set_page(frag, pfrag->page);
305         frag->page_offset = pfrag->offset;
306         skb_frag_size_set(frag, prepend_size);
307
308         get_page(pfrag->page);
309         pfrag->offset += prepend_size;
310
311         record->num_frags = 1;
312         record->len = prepend_size;
313         offload_ctx->open_record = record;
314         return 0;
315 }
316
317 static int tls_do_allocation(struct sock *sk,
318                              struct tls_offload_context_tx *offload_ctx,
319                              struct page_frag *pfrag,
320                              size_t prepend_size)
321 {
322         int ret;
323
324         if (!offload_ctx->open_record) {
325                 if (unlikely(!skb_page_frag_refill(prepend_size, pfrag,
326                                                    sk->sk_allocation))) {
327                         sk->sk_prot->enter_memory_pressure(sk);
328                         sk_stream_moderate_sndbuf(sk);
329                         return -ENOMEM;
330                 }
331
332                 ret = tls_create_new_record(offload_ctx, pfrag, prepend_size);
333                 if (ret)
334                         return ret;
335
336                 if (pfrag->size > pfrag->offset)
337                         return 0;
338         }
339
340         if (!sk_page_frag_refill(sk, pfrag))
341                 return -ENOMEM;
342
343         return 0;
344 }
345
346 static int tls_push_data(struct sock *sk,
347                          struct iov_iter *msg_iter,
348                          size_t size, int flags,
349                          unsigned char record_type)
350 {
351         struct tls_context *tls_ctx = tls_get_ctx(sk);
352         struct tls_offload_context_tx *ctx = tls_offload_ctx_tx(tls_ctx);
353         int tls_push_record_flags = flags | MSG_SENDPAGE_NOTLAST;
354         struct tls_record_info *record = ctx->open_record;
355         struct page_frag *pfrag;
356         size_t orig_size = size;
357         u32 max_open_record_len;
358         bool more = false;
359         bool done = false;
360         int copy, rc = 0;
361         long timeo;
362
363         if (flags &
364             ~(MSG_MORE | MSG_DONTWAIT | MSG_NOSIGNAL | MSG_SENDPAGE_NOTLAST))
365                 return -ENOTSUPP;
366
367         if (sk->sk_err)
368                 return -sk->sk_err;
369
370         timeo = sock_sndtimeo(sk, flags & MSG_DONTWAIT);
371         rc = tls_complete_pending_work(sk, tls_ctx, flags, &timeo);
372         if (rc < 0)
373                 return rc;
374
375         pfrag = sk_page_frag(sk);
376
377         /* TLS_HEADER_SIZE is not counted as part of the TLS record, and
378          * we need to leave room for an authentication tag.
379          */
380         max_open_record_len = TLS_MAX_PAYLOAD_SIZE +
381                               tls_ctx->tx.prepend_size;
382         do {
383                 rc = tls_do_allocation(sk, ctx, pfrag,
384                                        tls_ctx->tx.prepend_size);
385                 if (rc) {
386                         rc = sk_stream_wait_memory(sk, &timeo);
387                         if (!rc)
388                                 continue;
389
390                         record = ctx->open_record;
391                         if (!record)
392                                 break;
393 handle_error:
394                         if (record_type != TLS_RECORD_TYPE_DATA) {
395                                 /* avoid sending partial
396                                  * record with type !=
397                                  * application_data
398                                  */
399                                 size = orig_size;
400                                 destroy_record(record);
401                                 ctx->open_record = NULL;
402                         } else if (record->len > tls_ctx->tx.prepend_size) {
403                                 goto last_record;
404                         }
405
406                         break;
407                 }
408
409                 record = ctx->open_record;
410                 copy = min_t(size_t, size, (pfrag->size - pfrag->offset));
411                 copy = min_t(size_t, copy, (max_open_record_len - record->len));
412
413                 if (copy_from_iter_nocache(page_address(pfrag->page) +
414                                                pfrag->offset,
415                                            copy, msg_iter) != copy) {
416                         rc = -EFAULT;
417                         goto handle_error;
418                 }
419                 tls_append_frag(record, pfrag, copy);
420
421                 size -= copy;
422                 if (!size) {
423 last_record:
424                         tls_push_record_flags = flags;
425                         if (flags & (MSG_SENDPAGE_NOTLAST | MSG_MORE)) {
426                                 more = true;
427                                 break;
428                         }
429
430                         done = true;
431                 }
432
433                 if (done || record->len >= max_open_record_len ||
434                     (record->num_frags >= MAX_SKB_FRAGS - 1)) {
435                         rc = tls_push_record(sk,
436                                              tls_ctx,
437                                              ctx,
438                                              record,
439                                              pfrag,
440                                              tls_push_record_flags,
441                                              record_type);
442                         if (rc < 0)
443                                 break;
444                 }
445         } while (!done);
446
447         tls_ctx->pending_open_record_frags = more;
448
449         if (orig_size - size > 0)
450                 rc = orig_size - size;
451
452         return rc;
453 }
454
455 int tls_device_sendmsg(struct sock *sk, struct msghdr *msg, size_t size)
456 {
457         unsigned char record_type = TLS_RECORD_TYPE_DATA;
458         int rc;
459
460         lock_sock(sk);
461
462         if (unlikely(msg->msg_controllen)) {
463                 rc = tls_proccess_cmsg(sk, msg, &record_type);
464                 if (rc)
465                         goto out;
466         }
467
468         rc = tls_push_data(sk, &msg->msg_iter, size,
469                            msg->msg_flags, record_type);
470
471 out:
472         release_sock(sk);
473         return rc;
474 }
475
476 int tls_device_sendpage(struct sock *sk, struct page *page,
477                         int offset, size_t size, int flags)
478 {
479         struct iov_iter msg_iter;
480         char *kaddr;
481         struct kvec iov;
482         int rc;
483
484         if (flags & MSG_SENDPAGE_NOTLAST)
485                 flags |= MSG_MORE;
486
487         lock_sock(sk);
488
489         if (flags & MSG_OOB) {
490                 rc = -ENOTSUPP;
491                 goto out;
492         }
493
494         kaddr = kmap(page);
495         iov.iov_base = kaddr + offset;
496         iov.iov_len = size;
497         iov_iter_kvec(&msg_iter, WRITE | ITER_KVEC, &iov, 1, size);
498         rc = tls_push_data(sk, &msg_iter, size,
499                            flags, TLS_RECORD_TYPE_DATA);
500         kunmap(page);
501
502 out:
503         release_sock(sk);
504         return rc;
505 }
506
507 struct tls_record_info *tls_get_record(struct tls_offload_context_tx *context,
508                                        u32 seq, u64 *p_record_sn)
509 {
510         u64 record_sn = context->hint_record_sn;
511         struct tls_record_info *info, *last;
512
513         info = context->retransmit_hint;
514         if (!info ||
515             before(seq, info->end_seq - info->len)) {
516                 /* if retransmit_hint is irrelevant start
517                  * from the beginning of the list
518                  */
519                 info = list_first_entry(&context->records_list,
520                                         struct tls_record_info, list);
521
522                 /* Send the start_marker record if the seq number is before
523                  * the tls offload start marker sequence number.  This record
524                  * is required to handle TCP packets which were sent before
525                  * TLS offload started.
526                  * If it's not the start marker, check whether this seq number
527                  * belongs to the list.
528                  */
529                 if (likely(!tls_record_is_start_marker(info))) {
530                         /* we have the first record, get the last record to see
531                          * if this seq number belongs to the list.
532                          */
533                         last = list_last_entry(&context->records_list,
534                                                struct tls_record_info, list);
535
536                         if (!between(seq, tls_record_start_seq(info),
537                                      last->end_seq))
538                                 return NULL;
539                 }
540                 record_sn = context->unacked_record_sn;
541         }
542
543         list_for_each_entry_from(info, &context->records_list, list) {
544                 if (before(seq, info->end_seq)) {
545                         if (!context->retransmit_hint ||
546                             after(info->end_seq,
547                                   context->retransmit_hint->end_seq)) {
548                                 context->hint_record_sn = record_sn;
549                                 context->retransmit_hint = info;
550                         }
551                         *p_record_sn = record_sn;
552                         return info;
553                 }
554                 record_sn++;
555         }
556
557         return NULL;
558 }
559 EXPORT_SYMBOL(tls_get_record);
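
/* Illustrative sketch, not part of this file: roughly how a TX-offload
 * driver's transmit path might use tls_get_record() to find the record
 * covering a retransmitted TCP segment.  The example_ names are
 * hypothetical; tls_get_ctx(), tls_offload_ctx_tx(),
 * tls_is_sk_tx_device_offloaded(), tls_record_is_start_marker() and
 * tls_get_record() are the real helpers, and offload_ctx->lock must be held
 * across the lookup (as in tls_icsk_clean_acked() above).
 *
 *	static bool example_handle_retransmit(struct sk_buff *skb)
 *	{
 *		struct tls_offload_context_tx *tx_ctx;
 *		struct tls_record_info *record;
 *		unsigned long flags;
 *		u64 record_sn;
 *		bool found;
 *
 *		if (!skb->sk || !tls_is_sk_tx_device_offloaded(skb->sk))
 *			return false;
 *
 *		tx_ctx = tls_offload_ctx_tx(tls_get_ctx(skb->sk));
 *
 *		spin_lock_irqsave(&tx_ctx->lock, flags);
 *		record = tls_get_record(tx_ctx, TCP_SKB_CB(skb)->seq, &record_sn);
 *		found = record && !tls_record_is_start_marker(record);
 *		if (found) {
 *			// reprogram the device crypto state for record_sn here
 *		}
 *		spin_unlock_irqrestore(&tx_ctx->lock, flags);
 *
 *		return found;
 *	}
 */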
560
561 static int tls_device_push_pending_record(struct sock *sk, int flags)
562 {
563         struct iov_iter msg_iter;
564
565         iov_iter_kvec(&msg_iter, WRITE | ITER_KVEC, NULL, 0, 0);
566         return tls_push_data(sk, &msg_iter, 0, flags, TLS_RECORD_TYPE_DATA);
567 }
568
569 static void tls_device_resync_rx(struct tls_context *tls_ctx,
570                                  struct sock *sk, u32 seq, u64 rcd_sn)
571 {
572         struct net_device *netdev;
573
574         if (WARN_ON(test_and_set_bit(TLS_RX_SYNC_RUNNING, &tls_ctx->flags)))
575                 return;
576         netdev = READ_ONCE(tls_ctx->netdev);
577         if (netdev)
578                 netdev->tlsdev_ops->tls_dev_resync_rx(netdev, sk, seq, rcd_sn);
579         clear_bit_unlock(TLS_RX_SYNC_RUNNING, &tls_ctx->flags);
580 }
581
582 void handle_device_resync(struct sock *sk, u32 seq, u64 rcd_sn)
583 {
584         struct tls_context *tls_ctx = tls_get_ctx(sk);
585         struct tls_offload_context_rx *rx_ctx;
586         u32 is_req_pending;
587         s64 resync_req;
588         u32 req_seq;
589
590         if (tls_ctx->rx_conf != TLS_HW)
591                 return;
592
593         rx_ctx = tls_offload_ctx_rx(tls_ctx);
594         resync_req = atomic64_read(&rx_ctx->resync_req);
595         req_seq = ntohl(resync_req >> 32) - ((u32)TLS_HEADER_SIZE - 1);
596         is_req_pending = resync_req;
597
598         if (unlikely(is_req_pending) && req_seq == seq &&
599             atomic64_try_cmpxchg(&rx_ctx->resync_req, &resync_req, 0)) {
600                 seq += TLS_HEADER_SIZE - 1;
601                 tls_device_resync_rx(tls_ctx, sk, seq, rcd_sn);
602         }
603 }
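
/* Illustrative sketch, not part of this file: judging by the read side above,
 * the resync_req word carries the requested TCP sequence number (network byte
 * order) in its upper 32 bits and uses the low bit as the "request pending"
 * flag.  An RX-offload driver that has lost track of record boundaries posts
 * a request roughly as below (example_ names are hypothetical; in practice a
 * driver would use the tls_offload_rx_resync_request() helper from
 * include/net/tls.h rather than open-coding the encoding):
 *
 *	static void example_request_resync(struct sock *sk, __be32 tcp_seq)
 *	{
 *		struct tls_offload_context_rx *rx_ctx;
 *
 *		rx_ctx = tls_offload_ctx_rx(tls_get_ctx(sk));
 *		atomic64_set(&rx_ctx->resync_req,
 *			     ((u64)(__force u32)tcp_seq << 32) | 1);
 *	}
 *
 * Once the stack reaches that sequence number, handle_device_resync() above
 * hands the (seq, record sequence number) pair back to the driver via the
 * tls_dev_resync_rx() callback.
 */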
604
605 static int tls_device_reencrypt(struct sock *sk, struct sk_buff *skb)
606 {
607         struct strp_msg *rxm = strp_msg(skb);
608         int err = 0, offset = rxm->offset, copy, nsg, data_len, pos;
609         struct sk_buff *skb_iter, *unused;
610         struct scatterlist sg[1];
611         char *orig_buf, *buf;
612
613         orig_buf = kmalloc(rxm->full_len + TLS_HEADER_SIZE +
614                            TLS_CIPHER_AES_GCM_128_IV_SIZE, sk->sk_allocation);
615         if (!orig_buf)
616                 return -ENOMEM;
617         buf = orig_buf;
618
619         nsg = skb_cow_data(skb, 0, &unused);
620         if (unlikely(nsg < 0)) {
621                 err = nsg;
622                 goto free_buf;
623         }
624
625         sg_init_table(sg, 1);
626         sg_set_buf(&sg[0], buf,
627                    rxm->full_len + TLS_HEADER_SIZE +
628                    TLS_CIPHER_AES_GCM_128_IV_SIZE);
629         skb_copy_bits(skb, offset, buf,
630                       TLS_HEADER_SIZE + TLS_CIPHER_AES_GCM_128_IV_SIZE);
631
632         /* We are interested only in the decrypted data, not the auth tag. */
633         err = decrypt_skb(sk, skb, sg);
634         if (err != -EBADMSG)
635                 goto free_buf;
636         else
637                 err = 0;
638
639         data_len = rxm->full_len - TLS_CIPHER_AES_GCM_128_TAG_SIZE;
640
641         if (skb_pagelen(skb) > offset) {
642                 copy = min_t(int, skb_pagelen(skb) - offset, data_len);
643
644                 if (skb->decrypted)
645                         skb_store_bits(skb, offset, buf, copy);
646
647                 offset += copy;
648                 buf += copy;
649         }
650
651         pos = skb_pagelen(skb);
652         skb_walk_frags(skb, skb_iter) {
653                 int frag_pos;
654
655                 /* Practically all frags must belong to msg if reencrypt
656                  * is needed with current strparser and coalescing logic,
657                  * but strparser may "get optimized", so let's be safe.
658                  */
659                 if (pos + skb_iter->len <= offset)
660                         goto done_with_frag;
661                 if (pos >= data_len + rxm->offset)
662                         break;
663
664                 frag_pos = offset - pos;
665                 copy = min_t(int, skb_iter->len - frag_pos,
666                              data_len + rxm->offset - offset);
667
668                 if (skb_iter->decrypted)
669                         skb_store_bits(skb_iter, frag_pos, buf, copy);
670
671                 offset += copy;
672                 buf += copy;
673 done_with_frag:
674                 pos += skb_iter->len;
675         }
676
677 free_buf:
678         kfree(orig_buf);
679         return err;
680 }
681
682 int tls_device_decrypted(struct sock *sk, struct sk_buff *skb)
683 {
684         struct tls_context *tls_ctx = tls_get_ctx(sk);
685         struct tls_offload_context_rx *ctx = tls_offload_ctx_rx(tls_ctx);
686         int is_decrypted = skb->decrypted;
687         int is_encrypted = !is_decrypted;
688         struct sk_buff *skb_iter;
689
690         /* Skip if it is already decrypted */
691         if (ctx->sw.decrypted)
692                 return 0;
693
694         /* Check if all the data is decrypted already */
695         skb_walk_frags(skb, skb_iter) {
696                 is_decrypted &= skb_iter->decrypted;
697                 is_encrypted &= !skb_iter->decrypted;
698         }
699
700         ctx->sw.decrypted |= is_decrypted;
701
702         /* Return immediately if the record is either entirely plaintext or
703          * entirely ciphertext.  Otherwise reencrypt the partially decrypted
704          * record.
705          */
706         return (is_encrypted || is_decrypted) ? 0 :
707                 tls_device_reencrypt(sk, skb);
708 }
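
/* Illustrative sketch, not part of this file: an RX-offload driver marks the
 * packets its hardware has already decrypted so that tls_device_decrypted()
 * above can tell whether software decryption is still needed or the record
 * must be reencrypted.  The example_ names are hypothetical; skb->decrypted
 * is the real field (available when CONFIG_TLS_DEVICE is set):
 *
 *	static void example_rx_deliver(struct napi_struct *napi,
 *				       struct sk_buff *skb, bool hw_decrypted)
 *	{
 *		skb->decrypted = hw_decrypted;
 *		napi_gro_receive(napi, skb);
 *	}
 *
 * A record reassembled from several such skbs can mix decrypted and encrypted
 * segments, which is why the skb_walk_frags() loop above checks every
 * fragment before deciding whether reencryption is needed.
 */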
709
710 int tls_set_device_offload(struct sock *sk, struct tls_context *ctx)
711 {
712         u16 nonce_size, tag_size, iv_size, rec_seq_size;
713         struct tls_record_info *start_marker_record;
714         struct tls_offload_context_tx *offload_ctx;
715         struct tls_crypto_info *crypto_info;
716         struct net_device *netdev;
717         char *iv, *rec_seq;
718         struct sk_buff *skb;
719         int rc = -EINVAL;
720         __be64 rcd_sn;
721
722         if (!ctx)
723                 goto out;
724
725         if (ctx->priv_ctx_tx) {
726                 rc = -EEXIST;
727                 goto out;
728         }
729
730         start_marker_record = kmalloc(sizeof(*start_marker_record), GFP_KERNEL);
731         if (!start_marker_record) {
732                 rc = -ENOMEM;
733                 goto out;
734         }
735
736         offload_ctx = kzalloc(TLS_OFFLOAD_CONTEXT_SIZE_TX, GFP_KERNEL);
737         if (!offload_ctx) {
738                 rc = -ENOMEM;
739                 goto free_marker_record;
740         }
741
742         crypto_info = &ctx->crypto_send.info;
743         switch (crypto_info->cipher_type) {
744         case TLS_CIPHER_AES_GCM_128:
745                 nonce_size = TLS_CIPHER_AES_GCM_128_IV_SIZE;
746                 tag_size = TLS_CIPHER_AES_GCM_128_TAG_SIZE;
747                 iv_size = TLS_CIPHER_AES_GCM_128_IV_SIZE;
748                 iv = ((struct tls12_crypto_info_aes_gcm_128 *)crypto_info)->iv;
749                 rec_seq_size = TLS_CIPHER_AES_GCM_128_REC_SEQ_SIZE;
750                 rec_seq =
751                  ((struct tls12_crypto_info_aes_gcm_128 *)crypto_info)->rec_seq;
752                 break;
753         default:
754                 rc = -EINVAL;
755                 goto free_offload_ctx;
756         }
757
758         ctx->tx.prepend_size = TLS_HEADER_SIZE + nonce_size;
759         ctx->tx.tag_size = tag_size;
760         ctx->tx.overhead_size = ctx->tx.prepend_size + ctx->tx.tag_size;
761         ctx->tx.iv_size = iv_size;
762         ctx->tx.iv = kmalloc(iv_size + TLS_CIPHER_AES_GCM_128_SALT_SIZE,
763                              GFP_KERNEL);
764         if (!ctx->tx.iv) {
765                 rc = -ENOMEM;
766                 goto free_offload_ctx;
767         }
768
769         memcpy(ctx->tx.iv + TLS_CIPHER_AES_GCM_128_SALT_SIZE, iv, iv_size);
770
771         ctx->tx.rec_seq_size = rec_seq_size;
772         ctx->tx.rec_seq = kmemdup(rec_seq, rec_seq_size, GFP_KERNEL);
773         if (!ctx->tx.rec_seq) {
774                 rc = -ENOMEM;
775                 goto free_iv;
776         }
777
778         rc = tls_sw_fallback_init(sk, offload_ctx, crypto_info);
779         if (rc)
780                 goto free_rec_seq;
781
782         /* start at rec_seq - 1 to account for the start marker record */
783         memcpy(&rcd_sn, ctx->tx.rec_seq, sizeof(rcd_sn));
784         offload_ctx->unacked_record_sn = be64_to_cpu(rcd_sn) - 1;
785
786         start_marker_record->end_seq = tcp_sk(sk)->write_seq;
787         start_marker_record->len = 0;
788         start_marker_record->num_frags = 0;
789
790         INIT_LIST_HEAD(&offload_ctx->records_list);
791         list_add_tail(&start_marker_record->list, &offload_ctx->records_list);
792         spin_lock_init(&offload_ctx->lock);
793         sg_init_table(offload_ctx->sg_tx_data,
794                       ARRAY_SIZE(offload_ctx->sg_tx_data));
795
796         clean_acked_data_enable(inet_csk(sk), &tls_icsk_clean_acked);
797         ctx->push_pending_record = tls_device_push_pending_record;
798
799         /* TLS offload is greatly simplified if we don't send
800          * SKBs where only part of the payload needs to be encrypted.
801          * So mark the last skb in the write queue as end of record.
802          */
803         skb = tcp_write_queue_tail(sk);
804         if (skb)
805                 TCP_SKB_CB(skb)->eor = 1;
806
807         /* We support starting offload on multiple sockets
808          * concurrently, so we only need a read lock here.
809          * This lock must precede get_netdev_for_sock to prevent races between
810          * NETDEV_DOWN and setsockopt.
811          */
812         down_read(&device_offload_lock);
813         netdev = get_netdev_for_sock(sk);
814         if (!netdev) {
815                 pr_err_ratelimited("%s: netdev not found\n", __func__);
816                 rc = -EINVAL;
817                 goto release_lock;
818         }
819
820         if (!(netdev->features & NETIF_F_HW_TLS_TX)) {
821                 rc = -ENOTSUPP;
822                 goto release_netdev;
823         }
824
825         /* Avoid offloading if the device is down
826          * We don't want to offload new flows after
827          * the NETDEV_DOWN event
828          */
829         if (!(netdev->flags & IFF_UP)) {
830                 rc = -EINVAL;
831                 goto release_netdev;
832         }
833
834         ctx->priv_ctx_tx = offload_ctx;
835         rc = netdev->tlsdev_ops->tls_dev_add(netdev, sk, TLS_OFFLOAD_CTX_DIR_TX,
836                                              &ctx->crypto_send.info,
837                                              tcp_sk(sk)->write_seq);
838         if (rc)
839                 goto release_netdev;
840
841         tls_device_attach(ctx, sk, netdev);
842
843         /* following this assignment tls_is_sk_tx_device_offloaded
844          * will return true and the context might be accessed
845          * by the netdev's xmit function.
846          */
847         smp_store_release(&sk->sk_validate_xmit_skb, tls_validate_xmit_skb);
848         dev_put(netdev);
849         up_read(&device_offload_lock);
850         goto out;
851
852 release_netdev:
853         dev_put(netdev);
854 release_lock:
855         up_read(&device_offload_lock);
856         clean_acked_data_disable(inet_csk(sk));
857         crypto_free_aead(offload_ctx->aead_send);
858 free_rec_seq:
859         kfree(ctx->tx.rec_seq);
860 free_iv:
861         kfree(ctx->tx.iv);
862 free_offload_ctx:
863         kfree(offload_ctx);
864         ctx->priv_ctx_tx = NULL;
865 free_marker_record:
866         kfree(start_marker_record);
867 out:
868         return rc;
869 }
870
871 int tls_set_device_offload_rx(struct sock *sk, struct tls_context *ctx)
872 {
873         struct tls_offload_context_rx *context;
874         struct net_device *netdev;
875         int rc = 0;
876
877         /* We support starting offload on multiple sockets
878          * concurrently, so we only need a read lock here.
879          * This lock must precede get_netdev_for_sock to prevent races between
880          * NETDEV_DOWN and setsockopt.
881          */
882         down_read(&device_offload_lock);
883         netdev = get_netdev_for_sock(sk);
884         if (!netdev) {
885                 pr_err_ratelimited("%s: netdev not found\n", __func__);
886                 rc = -EINVAL;
887                 goto release_lock;
888         }
889
890         if (!(netdev->features & NETIF_F_HW_TLS_RX)) {
891                 pr_err_ratelimited("%s: netdev %s with no TLS offload\n",
892                                    __func__, netdev->name);
893                 rc = -ENOTSUPP;
894                 goto release_netdev;
895         }
896
897         /* Avoid offloading if the device is down
898          * We don't want to offload new flows after
899          * the NETDEV_DOWN event
900          */
901         if (!(netdev->flags & IFF_UP)) {
902                 rc = -EINVAL;
903                 goto release_netdev;
904         }
905
906         context = kzalloc(TLS_OFFLOAD_CONTEXT_SIZE_RX, GFP_KERNEL);
907         if (!context) {
908                 rc = -ENOMEM;
909                 goto release_netdev;
910         }
911
912         ctx->priv_ctx_rx = context;
913         rc = tls_set_sw_offload(sk, ctx, 0);
914         if (rc)
915                 goto release_ctx;
916
917         rc = netdev->tlsdev_ops->tls_dev_add(netdev, sk, TLS_OFFLOAD_CTX_DIR_RX,
918                                              &ctx->crypto_recv.info,
919                                              tcp_sk(sk)->copied_seq);
920         if (rc) {
921                 pr_err_ratelimited("%s: The netdev has refused to offload this socket\n",
922                                    __func__);
923                 goto free_sw_resources;
924         }
925
926         tls_device_attach(ctx, sk, netdev);
927         goto release_netdev;
928
929 free_sw_resources:
930         up_read(&device_offload_lock);
931         tls_sw_free_resources_rx(sk);
932         down_read(&device_offload_lock);
933 release_ctx:
934         ctx->priv_ctx_rx = NULL;
935 release_netdev:
936         dev_put(netdev);
937 release_lock:
938         up_read(&device_offload_lock);
939         return rc;
940 }
941
942 void tls_device_offload_cleanup_rx(struct sock *sk)
943 {
944         struct tls_context *tls_ctx = tls_get_ctx(sk);
945         struct net_device *netdev;
946
947         down_read(&device_offload_lock);
948         netdev = tls_ctx->netdev;
949         if (!netdev)
950                 goto out;
951
952         netdev->tlsdev_ops->tls_dev_del(netdev, tls_ctx,
953                                         TLS_OFFLOAD_CTX_DIR_RX);
954
955         if (tls_ctx->tx_conf != TLS_HW) {
956                 dev_put(netdev);
957                 tls_ctx->netdev = NULL;
958         } else {
959                 set_bit(TLS_RX_DEV_CLOSED, &tls_ctx->flags);
960         }
961 out:
962         up_read(&device_offload_lock);
963         tls_sw_release_resources_rx(sk);
964 }
965
966 static int tls_device_down(struct net_device *netdev)
967 {
968         struct tls_context *ctx, *tmp;
969         unsigned long flags;
970         LIST_HEAD(list);
971
972         /* Request a write lock to block new offload attempts */
973         down_write(&device_offload_lock);
974
975         spin_lock_irqsave(&tls_device_lock, flags);
976         list_for_each_entry_safe(ctx, tmp, &tls_device_list, list) {
977                 if (ctx->netdev != netdev ||
978                     !refcount_inc_not_zero(&ctx->refcount))
979                         continue;
980
981                 list_move(&ctx->list, &list);
982         }
983         spin_unlock_irqrestore(&tls_device_lock, flags);
984
985         list_for_each_entry_safe(ctx, tmp, &list, list) {
986                 if (ctx->tx_conf == TLS_HW)
987                         netdev->tlsdev_ops->tls_dev_del(netdev, ctx,
988                                                         TLS_OFFLOAD_CTX_DIR_TX);
989                 if (ctx->rx_conf == TLS_HW &&
990                     !test_bit(TLS_RX_DEV_CLOSED, &ctx->flags))
991                         netdev->tlsdev_ops->tls_dev_del(netdev, ctx,
992                                                         TLS_OFFLOAD_CTX_DIR_RX);
993                 WRITE_ONCE(ctx->netdev, NULL);
994                 smp_mb__before_atomic(); /* pairs with test_and_set_bit() */
995                 while (test_bit(TLS_RX_SYNC_RUNNING, &ctx->flags))
996                         usleep_range(10, 200);
997                 dev_put(netdev);
998                 list_del_init(&ctx->list);
999
1000                 if (refcount_dec_and_test(&ctx->refcount))
1001                         tls_device_free_ctx(ctx);
1002         }
1003
1004         up_write(&device_offload_lock);
1005
1006         flush_work(&tls_device_gc_work);
1007
1008         return NOTIFY_DONE;
1009 }
1010
1011 static int tls_dev_event(struct notifier_block *this, unsigned long event,
1012                          void *ptr)
1013 {
1014         struct net_device *dev = netdev_notifier_info_to_dev(ptr);
1015
1016         if (!dev->tlsdev_ops &&
1017             !(dev->features & (NETIF_F_HW_TLS_RX | NETIF_F_HW_TLS_TX)))
1018                 return NOTIFY_DONE;
1019
1020         switch (event) {
1021         case NETDEV_REGISTER:
1022         case NETDEV_FEAT_CHANGE:
1023                 if ((dev->features & NETIF_F_HW_TLS_RX) &&
1024                     !dev->tlsdev_ops->tls_dev_resync_rx)
1025                         return NOTIFY_BAD;
1026
1027                 if  (dev->tlsdev_ops &&
1028                      dev->tlsdev_ops->tls_dev_add &&
1029                      dev->tlsdev_ops->tls_dev_del)
1030                         return NOTIFY_DONE;
1031                 else
1032                         return NOTIFY_BAD;
1033         case NETDEV_DOWN:
1034                 return tls_device_down(dev);
1035         }
1036         return NOTIFY_DONE;
1037 }
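
/* Illustrative sketch, not part of this file: the shape of the tlsdev_ops a
 * driver advertises, matching the checks in tls_dev_event() above and the
 * calls made elsewhere in this file.  The example_ implementations are
 * hypothetical stubs; only the ops structure and the NETIF_F_HW_TLS_* feature
 * flags are real kernel interfaces.
 *
 *	static int example_tls_dev_add(struct net_device *netdev, struct sock *sk,
 *				       enum tls_offload_ctx_dir direction,
 *				       struct tls_crypto_info *crypto_info,
 *				       u32 start_offload_tcp_sn)
 *	{
 *		return -EOPNOTSUPP;	// accept the socket by returning 0
 *	}
 *
 *	static void example_tls_dev_del(struct net_device *netdev,
 *					struct tls_context *tls_ctx,
 *					enum tls_offload_ctx_dir direction)
 *	{
 *	}
 *
 *	static void example_tls_dev_resync_rx(struct net_device *netdev,
 *					      struct sock *sk, u32 seq, u64 rcd_sn)
 *	{
 *	}
 *
 *	static const struct tlsdev_ops example_tlsdev_ops = {
 *		.tls_dev_add		= example_tls_dev_add,
 *		.tls_dev_del		= example_tls_dev_del,
 *		.tls_dev_resync_rx	= example_tls_dev_resync_rx,
 *	};
 *
 * The driver then sets netdev->tlsdev_ops = &example_tlsdev_ops and enables
 * NETIF_F_HW_TLS_TX and/or NETIF_F_HW_TLS_RX in netdev->features.
 */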
1038
1039 static struct notifier_block tls_dev_notifier = {
1040         .notifier_call  = tls_dev_event,
1041 };
1042
1043 void __init tls_device_init(void)
1044 {
1045         register_netdevice_notifier(&tls_dev_notifier);
1046 }
1047
1048 void __exit tls_device_cleanup(void)
1049 {
1050         unregister_netdevice_notifier(&tls_dev_notifier);
1051         flush_work(&tls_device_gc_work);
1052 }