GNU Linux-libre 5.4.257-gnu1
[releases.git] / net / ipv4 / esp4.c
1 // SPDX-License-Identifier: GPL-2.0-only
2 #define pr_fmt(fmt) "IPsec: " fmt
3
4 #include <crypto/aead.h>
5 #include <crypto/authenc.h>
6 #include <linux/err.h>
7 #include <linux/module.h>
8 #include <net/ip.h>
9 #include <net/xfrm.h>
10 #include <net/esp.h>
11 #include <linux/scatterlist.h>
12 #include <linux/kernel.h>
13 #include <linux/pfkeyv2.h>
14 #include <linux/rtnetlink.h>
15 #include <linux/slab.h>
16 #include <linux/spinlock.h>
17 #include <linux/in6.h>
18 #include <net/icmp.h>
19 #include <net/protocol.h>
20 #include <net/udp.h>
21
22 #include <linux/highmem.h>
23
24 struct esp_skb_cb {
25         struct xfrm_skb_cb xfrm;
26         void *tmp;
27 };
28
29 struct esp_output_extra {
30         __be32 seqhi;
31         u32 esphoff;
32 };
33
34 #define ESP_SKB_CB(__skb) ((struct esp_skb_cb *)&((__skb)->cb[0]))
35
36 /*
37  * Allocate an AEAD request structure with extra space for SG and IV.
38  *
39  * For alignment considerations the IV is placed at the front, followed
40  * by the request and finally the SG list.
41  *
42  * TODO: Use spare space in skb for this where possible.
43  */
44 static void *esp_alloc_tmp(struct crypto_aead *aead, int nfrags, int extralen)
45 {
46         unsigned int len;
47
48         len = extralen;
49
50         len += crypto_aead_ivsize(aead);
51
52         if (len) {
53                 len += crypto_aead_alignmask(aead) &
54                        ~(crypto_tfm_ctx_alignment() - 1);
55                 len = ALIGN(len, crypto_tfm_ctx_alignment());
56         }
57
58         len += sizeof(struct aead_request) + crypto_aead_reqsize(aead);
59         len = ALIGN(len, __alignof__(struct scatterlist));
60
61         len += sizeof(struct scatterlist) * nfrags;
62
63         return kmalloc(len, GFP_ATOMIC);
64 }
65
66 static inline void *esp_tmp_extra(void *tmp)
67 {
68         return PTR_ALIGN(tmp, __alignof__(struct esp_output_extra));
69 }
70
71 static inline u8 *esp_tmp_iv(struct crypto_aead *aead, void *tmp, int extralen)
72 {
73         return crypto_aead_ivsize(aead) ?
74                PTR_ALIGN((u8 *)tmp + extralen,
75                          crypto_aead_alignmask(aead) + 1) : tmp + extralen;
76 }
77
78 static inline struct aead_request *esp_tmp_req(struct crypto_aead *aead, u8 *iv)
79 {
80         struct aead_request *req;
81
82         req = (void *)PTR_ALIGN(iv + crypto_aead_ivsize(aead),
83                                 crypto_tfm_ctx_alignment());
84         aead_request_set_tfm(req, aead);
85         return req;
86 }
87
88 static inline struct scatterlist *esp_req_sg(struct crypto_aead *aead,
89                                              struct aead_request *req)
90 {
91         return (void *)ALIGN((unsigned long)(req + 1) +
92                              crypto_aead_reqsize(aead),
93                              __alignof__(struct scatterlist));
94 }
95
96 static void esp_ssg_unref(struct xfrm_state *x, void *tmp)
97 {
98         struct esp_output_extra *extra = esp_tmp_extra(tmp);
99         struct crypto_aead *aead = x->data;
100         int extralen = 0;
101         u8 *iv;
102         struct aead_request *req;
103         struct scatterlist *sg;
104
105         if (x->props.flags & XFRM_STATE_ESN)
106                 extralen += sizeof(*extra);
107
108         extra = esp_tmp_extra(tmp);
109         iv = esp_tmp_iv(aead, tmp, extralen);
110         req = esp_tmp_req(aead, iv);
111
112         /* Unref skb_frag_pages in the src scatterlist if necessary.
113          * Skip the first sg which comes from skb->data.
114          */
115         if (req->src != req->dst)
116                 for (sg = sg_next(req->src); sg; sg = sg_next(sg))
117                         put_page(sg_page(sg));
118 }
119
120 static void esp_output_done(struct crypto_async_request *base, int err)
121 {
122         struct sk_buff *skb = base->data;
123         struct xfrm_offload *xo = xfrm_offload(skb);
124         void *tmp;
125         struct xfrm_state *x;
126
127         if (xo && (xo->flags & XFRM_DEV_RESUME)) {
128                 struct sec_path *sp = skb_sec_path(skb);
129
130                 x = sp->xvec[sp->len - 1];
131         } else {
132                 x = skb_dst(skb)->xfrm;
133         }
134
135         tmp = ESP_SKB_CB(skb)->tmp;
136         esp_ssg_unref(x, tmp);
137         kfree(tmp);
138
139         if (xo && (xo->flags & XFRM_DEV_RESUME)) {
140                 if (err) {
141                         XFRM_INC_STATS(xs_net(x), LINUX_MIB_XFRMOUTSTATEPROTOERROR);
142                         kfree_skb(skb);
143                         return;
144                 }
145
146                 skb_push(skb, skb->data - skb_mac_header(skb));
147                 secpath_reset(skb);
148                 xfrm_dev_resume(skb);
149         } else {
150                 xfrm_output_resume(skb, err);
151         }
152 }
153
154 /* Move ESP header back into place. */
155 static void esp_restore_header(struct sk_buff *skb, unsigned int offset)
156 {
157         struct ip_esp_hdr *esph = (void *)(skb->data + offset);
158         void *tmp = ESP_SKB_CB(skb)->tmp;
159         __be32 *seqhi = esp_tmp_extra(tmp);
160
161         esph->seq_no = esph->spi;
162         esph->spi = *seqhi;
163 }
164
165 static void esp_output_restore_header(struct sk_buff *skb)
166 {
167         void *tmp = ESP_SKB_CB(skb)->tmp;
168         struct esp_output_extra *extra = esp_tmp_extra(tmp);
169
170         esp_restore_header(skb, skb_transport_offset(skb) + extra->esphoff -
171                                 sizeof(__be32));
172 }
173
174 static struct ip_esp_hdr *esp_output_set_extra(struct sk_buff *skb,
175                                                struct xfrm_state *x,
176                                                struct ip_esp_hdr *esph,
177                                                struct esp_output_extra *extra)
178 {
179         /* For ESN we move the header forward by 4 bytes to
180          * accomodate the high bits.  We will move it back after
181          * encryption.
182          */
183         if ((x->props.flags & XFRM_STATE_ESN)) {
184                 __u32 seqhi;
185                 struct xfrm_offload *xo = xfrm_offload(skb);
186
187                 if (xo)
188                         seqhi = xo->seq.hi;
189                 else
190                         seqhi = XFRM_SKB_CB(skb)->seq.output.hi;
191
192                 extra->esphoff = (unsigned char *)esph -
193                                  skb_transport_header(skb);
194                 esph = (struct ip_esp_hdr *)((unsigned char *)esph - 4);
195                 extra->seqhi = esph->spi;
196                 esph->seq_no = htonl(seqhi);
197         }
198
199         esph->spi = x->id.spi;
200
201         return esph;
202 }
203
204 static void esp_output_done_esn(struct crypto_async_request *base, int err)
205 {
206         struct sk_buff *skb = base->data;
207
208         esp_output_restore_header(skb);
209         esp_output_done(base, err);
210 }
211
212 static void esp_output_fill_trailer(u8 *tail, int tfclen, int plen, __u8 proto)
213 {
214         /* Fill padding... */
215         if (tfclen) {
216                 memset(tail, 0, tfclen);
217                 tail += tfclen;
218         }
219         do {
220                 int i;
221                 for (i = 0; i < plen - 2; i++)
222                         tail[i] = i + 1;
223         } while (0);
224         tail[plen - 2] = plen - 2;
225         tail[plen - 1] = proto;
226 }
227
228 static int esp_output_udp_encap(struct xfrm_state *x, struct sk_buff *skb, struct esp_info *esp)
229 {
230         int encap_type;
231         struct udphdr *uh;
232         __be32 *udpdata32;
233         __be16 sport, dport;
234         struct xfrm_encap_tmpl *encap = x->encap;
235         struct ip_esp_hdr *esph = esp->esph;
236         unsigned int len;
237
238         spin_lock_bh(&x->lock);
239         sport = encap->encap_sport;
240         dport = encap->encap_dport;
241         encap_type = encap->encap_type;
242         spin_unlock_bh(&x->lock);
243
244         len = skb->len + esp->tailen - skb_transport_offset(skb);
245         if (len + sizeof(struct iphdr) >= IP_MAX_MTU)
246                 return -EMSGSIZE;
247
248         uh = (struct udphdr *)esph;
249         uh->source = sport;
250         uh->dest = dport;
251         uh->len = htons(len);
252         uh->check = 0;
253
254         switch (encap_type) {
255         default:
256         case UDP_ENCAP_ESPINUDP:
257                 esph = (struct ip_esp_hdr *)(uh + 1);
258                 break;
259         case UDP_ENCAP_ESPINUDP_NON_IKE:
260                 udpdata32 = (__be32 *)(uh + 1);
261                 udpdata32[0] = udpdata32[1] = 0;
262                 esph = (struct ip_esp_hdr *)(udpdata32 + 2);
263                 break;
264         }
265
266         *skb_mac_header(skb) = IPPROTO_UDP;
267         esp->esph = esph;
268
269         return 0;
270 }
271
272 int esp_output_head(struct xfrm_state *x, struct sk_buff *skb, struct esp_info *esp)
273 {
274         u8 *tail;
275         int nfrags;
276         int esph_offset;
277         struct page *page;
278         struct sk_buff *trailer;
279         int tailen = esp->tailen;
280
281         /* this is non-NULL only with UDP Encapsulation */
282         if (x->encap) {
283                 int err = esp_output_udp_encap(x, skb, esp);
284
285                 if (err < 0)
286                         return err;
287         }
288
289         if (ALIGN(tailen, L1_CACHE_BYTES) > PAGE_SIZE ||
290             ALIGN(skb->data_len, L1_CACHE_BYTES) > PAGE_SIZE)
291                 goto cow;
292
293         if (!skb_cloned(skb)) {
294                 if (tailen <= skb_tailroom(skb)) {
295                         nfrags = 1;
296                         trailer = skb;
297                         tail = skb_tail_pointer(trailer);
298
299                         goto skip_cow;
300                 } else if ((skb_shinfo(skb)->nr_frags < MAX_SKB_FRAGS)
301                            && !skb_has_frag_list(skb)) {
302                         int allocsize;
303                         struct sock *sk = skb->sk;
304                         struct page_frag *pfrag = &x->xfrag;
305
306                         esp->inplace = false;
307
308                         allocsize = ALIGN(tailen, L1_CACHE_BYTES);
309
310                         spin_lock_bh(&x->lock);
311
312                         if (unlikely(!skb_page_frag_refill(allocsize, pfrag, GFP_ATOMIC))) {
313                                 spin_unlock_bh(&x->lock);
314                                 goto cow;
315                         }
316
317                         page = pfrag->page;
318                         get_page(page);
319
320                         tail = page_address(page) + pfrag->offset;
321
322                         esp_output_fill_trailer(tail, esp->tfclen, esp->plen, esp->proto);
323
324                         nfrags = skb_shinfo(skb)->nr_frags;
325
326                         __skb_fill_page_desc(skb, nfrags, page, pfrag->offset,
327                                              tailen);
328                         skb_shinfo(skb)->nr_frags = ++nfrags;
329
330                         pfrag->offset = pfrag->offset + allocsize;
331
332                         spin_unlock_bh(&x->lock);
333
334                         nfrags++;
335
336                         skb->len += tailen;
337                         skb->data_len += tailen;
338                         skb->truesize += tailen;
339                         if (sk && sk_fullsock(sk))
340                                 refcount_add(tailen, &sk->sk_wmem_alloc);
341
342                         goto out;
343                 }
344         }
345
346 cow:
347         esph_offset = (unsigned char *)esp->esph - skb_transport_header(skb);
348
349         nfrags = skb_cow_data(skb, tailen, &trailer);
350         if (nfrags < 0)
351                 goto out;
352         tail = skb_tail_pointer(trailer);
353         esp->esph = (struct ip_esp_hdr *)(skb_transport_header(skb) + esph_offset);
354
355 skip_cow:
356         esp_output_fill_trailer(tail, esp->tfclen, esp->plen, esp->proto);
357         pskb_put(skb, trailer, tailen);
358
359 out:
360         return nfrags;
361 }
362 EXPORT_SYMBOL_GPL(esp_output_head);
363
364 int esp_output_tail(struct xfrm_state *x, struct sk_buff *skb, struct esp_info *esp)
365 {
366         u8 *iv;
367         int alen;
368         void *tmp;
369         int ivlen;
370         int assoclen;
371         int extralen;
372         struct page *page;
373         struct ip_esp_hdr *esph;
374         struct crypto_aead *aead;
375         struct aead_request *req;
376         struct scatterlist *sg, *dsg;
377         struct esp_output_extra *extra;
378         int err = -ENOMEM;
379
380         assoclen = sizeof(struct ip_esp_hdr);
381         extralen = 0;
382
383         if (x->props.flags & XFRM_STATE_ESN) {
384                 extralen += sizeof(*extra);
385                 assoclen += sizeof(__be32);
386         }
387
388         aead = x->data;
389         alen = crypto_aead_authsize(aead);
390         ivlen = crypto_aead_ivsize(aead);
391
392         tmp = esp_alloc_tmp(aead, esp->nfrags + 2, extralen);
393         if (!tmp)
394                 goto error;
395
396         extra = esp_tmp_extra(tmp);
397         iv = esp_tmp_iv(aead, tmp, extralen);
398         req = esp_tmp_req(aead, iv);
399         sg = esp_req_sg(aead, req);
400
401         if (esp->inplace)
402                 dsg = sg;
403         else
404                 dsg = &sg[esp->nfrags];
405
406         esph = esp_output_set_extra(skb, x, esp->esph, extra);
407         esp->esph = esph;
408
409         sg_init_table(sg, esp->nfrags);
410         err = skb_to_sgvec(skb, sg,
411                            (unsigned char *)esph - skb->data,
412                            assoclen + ivlen + esp->clen + alen);
413         if (unlikely(err < 0))
414                 goto error_free;
415
416         if (!esp->inplace) {
417                 int allocsize;
418                 struct page_frag *pfrag = &x->xfrag;
419
420                 allocsize = ALIGN(skb->data_len, L1_CACHE_BYTES);
421
422                 spin_lock_bh(&x->lock);
423                 if (unlikely(!skb_page_frag_refill(allocsize, pfrag, GFP_ATOMIC))) {
424                         spin_unlock_bh(&x->lock);
425                         goto error_free;
426                 }
427
428                 skb_shinfo(skb)->nr_frags = 1;
429
430                 page = pfrag->page;
431                 get_page(page);
432                 /* replace page frags in skb with new page */
433                 __skb_fill_page_desc(skb, 0, page, pfrag->offset, skb->data_len);
434                 pfrag->offset = pfrag->offset + allocsize;
435                 spin_unlock_bh(&x->lock);
436
437                 sg_init_table(dsg, skb_shinfo(skb)->nr_frags + 1);
438                 err = skb_to_sgvec(skb, dsg,
439                                    (unsigned char *)esph - skb->data,
440                                    assoclen + ivlen + esp->clen + alen);
441                 if (unlikely(err < 0))
442                         goto error_free;
443         }
444
445         if ((x->props.flags & XFRM_STATE_ESN))
446                 aead_request_set_callback(req, 0, esp_output_done_esn, skb);
447         else
448                 aead_request_set_callback(req, 0, esp_output_done, skb);
449
450         aead_request_set_crypt(req, sg, dsg, ivlen + esp->clen, iv);
451         aead_request_set_ad(req, assoclen);
452
453         memset(iv, 0, ivlen);
454         memcpy(iv + ivlen - min(ivlen, 8), (u8 *)&esp->seqno + 8 - min(ivlen, 8),
455                min(ivlen, 8));
456
457         ESP_SKB_CB(skb)->tmp = tmp;
458         err = crypto_aead_encrypt(req);
459
460         switch (err) {
461         case -EINPROGRESS:
462                 goto error;
463
464         case -ENOSPC:
465                 err = NET_XMIT_DROP;
466                 break;
467
468         case 0:
469                 if ((x->props.flags & XFRM_STATE_ESN))
470                         esp_output_restore_header(skb);
471         }
472
473         if (sg != dsg)
474                 esp_ssg_unref(x, tmp);
475
476 error_free:
477         kfree(tmp);
478 error:
479         return err;
480 }
481 EXPORT_SYMBOL_GPL(esp_output_tail);
482
483 static int esp_output(struct xfrm_state *x, struct sk_buff *skb)
484 {
485         int alen;
486         int blksize;
487         struct ip_esp_hdr *esph;
488         struct crypto_aead *aead;
489         struct esp_info esp;
490
491         esp.inplace = true;
492
493         esp.proto = *skb_mac_header(skb);
494         *skb_mac_header(skb) = IPPROTO_ESP;
495
496         /* skb is pure payload to encrypt */
497
498         aead = x->data;
499         alen = crypto_aead_authsize(aead);
500
501         esp.tfclen = 0;
502         if (x->tfcpad) {
503                 struct xfrm_dst *dst = (struct xfrm_dst *)skb_dst(skb);
504                 u32 padto;
505
506                 padto = min(x->tfcpad, xfrm_state_mtu(x, dst->child_mtu_cached));
507                 if (skb->len < padto)
508                         esp.tfclen = padto - skb->len;
509         }
510         blksize = ALIGN(crypto_aead_blocksize(aead), 4);
511         esp.clen = ALIGN(skb->len + 2 + esp.tfclen, blksize);
512         esp.plen = esp.clen - skb->len - esp.tfclen;
513         esp.tailen = esp.tfclen + esp.plen + alen;
514
515         esp.esph = ip_esp_hdr(skb);
516
517         esp.nfrags = esp_output_head(x, skb, &esp);
518         if (esp.nfrags < 0)
519                 return esp.nfrags;
520
521         esph = esp.esph;
522         esph->spi = x->id.spi;
523
524         esph->seq_no = htonl(XFRM_SKB_CB(skb)->seq.output.low);
525         esp.seqno = cpu_to_be64(XFRM_SKB_CB(skb)->seq.output.low +
526                                  ((u64)XFRM_SKB_CB(skb)->seq.output.hi << 32));
527
528         skb_push(skb, -skb_network_offset(skb));
529
530         return esp_output_tail(x, skb, &esp);
531 }
532
533 static inline int esp_remove_trailer(struct sk_buff *skb)
534 {
535         struct xfrm_state *x = xfrm_input_state(skb);
536         struct xfrm_offload *xo = xfrm_offload(skb);
537         struct crypto_aead *aead = x->data;
538         int alen, hlen, elen;
539         int padlen, trimlen;
540         __wsum csumdiff;
541         u8 nexthdr[2];
542         int ret;
543
544         alen = crypto_aead_authsize(aead);
545         hlen = sizeof(struct ip_esp_hdr) + crypto_aead_ivsize(aead);
546         elen = skb->len - hlen;
547
548         if (xo && (xo->flags & XFRM_ESP_NO_TRAILER)) {
549                 ret = xo->proto;
550                 goto out;
551         }
552
553         if (skb_copy_bits(skb, skb->len - alen - 2, nexthdr, 2))
554                 BUG();
555
556         ret = -EINVAL;
557         padlen = nexthdr[0];
558         if (padlen + 2 + alen >= elen) {
559                 net_dbg_ratelimited("ipsec esp packet is garbage padlen=%d, elen=%d\n",
560                                     padlen + 2, elen - alen);
561                 goto out;
562         }
563
564         trimlen = alen + padlen + 2;
565         if (skb->ip_summed == CHECKSUM_COMPLETE) {
566                 csumdiff = skb_checksum(skb, skb->len - trimlen, trimlen, 0);
567                 skb->csum = csum_block_sub(skb->csum, csumdiff,
568                                            skb->len - trimlen);
569         }
570         pskb_trim(skb, skb->len - trimlen);
571
572         ret = nexthdr[1];
573
574 out:
575         return ret;
576 }
577
578 int esp_input_done2(struct sk_buff *skb, int err)
579 {
580         const struct iphdr *iph;
581         struct xfrm_state *x = xfrm_input_state(skb);
582         struct xfrm_offload *xo = xfrm_offload(skb);
583         struct crypto_aead *aead = x->data;
584         int hlen = sizeof(struct ip_esp_hdr) + crypto_aead_ivsize(aead);
585         int ihl;
586
587         if (!xo || (xo && !(xo->flags & CRYPTO_DONE)))
588                 kfree(ESP_SKB_CB(skb)->tmp);
589
590         if (unlikely(err))
591                 goto out;
592
593         err = esp_remove_trailer(skb);
594         if (unlikely(err < 0))
595                 goto out;
596
597         iph = ip_hdr(skb);
598         ihl = iph->ihl * 4;
599
600         if (x->encap) {
601                 struct xfrm_encap_tmpl *encap = x->encap;
602                 struct udphdr *uh = (void *)(skb_network_header(skb) + ihl);
603
604                 /*
605                  * 1) if the NAT-T peer's IP or port changed then
606                  *    advertize the change to the keying daemon.
607                  *    This is an inbound SA, so just compare
608                  *    SRC ports.
609                  */
610                 if (iph->saddr != x->props.saddr.a4 ||
611                     uh->source != encap->encap_sport) {
612                         xfrm_address_t ipaddr;
613
614                         ipaddr.a4 = iph->saddr;
615                         km_new_mapping(x, &ipaddr, uh->source);
616
617                         /* XXX: perhaps add an extra
618                          * policy check here, to see
619                          * if we should allow or
620                          * reject a packet from a
621                          * different source
622                          * address/port.
623                          */
624                 }
625
626                 /*
627                  * 2) ignore UDP/TCP checksums in case
628                  *    of NAT-T in Transport Mode, or
629                  *    perform other post-processing fixes
630                  *    as per draft-ietf-ipsec-udp-encaps-06,
631                  *    section 3.1.2
632                  */
633                 if (x->props.mode == XFRM_MODE_TRANSPORT)
634                         skb->ip_summed = CHECKSUM_UNNECESSARY;
635         }
636
637         skb_pull_rcsum(skb, hlen);
638         if (x->props.mode == XFRM_MODE_TUNNEL)
639                 skb_reset_transport_header(skb);
640         else
641                 skb_set_transport_header(skb, -ihl);
642
643         /* RFC4303: Drop dummy packets without any error */
644         if (err == IPPROTO_NONE)
645                 err = -EINVAL;
646
647 out:
648         return err;
649 }
650 EXPORT_SYMBOL_GPL(esp_input_done2);
651
652 static void esp_input_done(struct crypto_async_request *base, int err)
653 {
654         struct sk_buff *skb = base->data;
655
656         xfrm_input_resume(skb, esp_input_done2(skb, err));
657 }
658
659 static void esp_input_restore_header(struct sk_buff *skb)
660 {
661         esp_restore_header(skb, 0);
662         __skb_pull(skb, 4);
663 }
664
665 static void esp_input_set_header(struct sk_buff *skb, __be32 *seqhi)
666 {
667         struct xfrm_state *x = xfrm_input_state(skb);
668         struct ip_esp_hdr *esph;
669
670         /* For ESN we move the header forward by 4 bytes to
671          * accomodate the high bits.  We will move it back after
672          * decryption.
673          */
674         if ((x->props.flags & XFRM_STATE_ESN)) {
675                 esph = skb_push(skb, 4);
676                 *seqhi = esph->spi;
677                 esph->spi = esph->seq_no;
678                 esph->seq_no = XFRM_SKB_CB(skb)->seq.input.hi;
679         }
680 }
681
682 static void esp_input_done_esn(struct crypto_async_request *base, int err)
683 {
684         struct sk_buff *skb = base->data;
685
686         esp_input_restore_header(skb);
687         esp_input_done(base, err);
688 }
689
690 /*
691  * Note: detecting truncated vs. non-truncated authentication data is very
692  * expensive, so we only support truncated data, which is the recommended
693  * and common case.
694  */
695 static int esp_input(struct xfrm_state *x, struct sk_buff *skb)
696 {
697         struct crypto_aead *aead = x->data;
698         struct aead_request *req;
699         struct sk_buff *trailer;
700         int ivlen = crypto_aead_ivsize(aead);
701         int elen = skb->len - sizeof(struct ip_esp_hdr) - ivlen;
702         int nfrags;
703         int assoclen;
704         int seqhilen;
705         __be32 *seqhi;
706         void *tmp;
707         u8 *iv;
708         struct scatterlist *sg;
709         int err = -EINVAL;
710
711         if (!pskb_may_pull(skb, sizeof(struct ip_esp_hdr) + ivlen))
712                 goto out;
713
714         if (elen <= 0)
715                 goto out;
716
717         assoclen = sizeof(struct ip_esp_hdr);
718         seqhilen = 0;
719
720         if (x->props.flags & XFRM_STATE_ESN) {
721                 seqhilen += sizeof(__be32);
722                 assoclen += seqhilen;
723         }
724
725         if (!skb_cloned(skb)) {
726                 if (!skb_is_nonlinear(skb)) {
727                         nfrags = 1;
728
729                         goto skip_cow;
730                 } else if (!skb_has_frag_list(skb)) {
731                         nfrags = skb_shinfo(skb)->nr_frags;
732                         nfrags++;
733
734                         goto skip_cow;
735                 }
736         }
737
738         err = skb_cow_data(skb, 0, &trailer);
739         if (err < 0)
740                 goto out;
741
742         nfrags = err;
743
744 skip_cow:
745         err = -ENOMEM;
746         tmp = esp_alloc_tmp(aead, nfrags, seqhilen);
747         if (!tmp)
748                 goto out;
749
750         ESP_SKB_CB(skb)->tmp = tmp;
751         seqhi = esp_tmp_extra(tmp);
752         iv = esp_tmp_iv(aead, tmp, seqhilen);
753         req = esp_tmp_req(aead, iv);
754         sg = esp_req_sg(aead, req);
755
756         esp_input_set_header(skb, seqhi);
757
758         sg_init_table(sg, nfrags);
759         err = skb_to_sgvec(skb, sg, 0, skb->len);
760         if (unlikely(err < 0)) {
761                 kfree(tmp);
762                 goto out;
763         }
764
765         skb->ip_summed = CHECKSUM_NONE;
766
767         if ((x->props.flags & XFRM_STATE_ESN))
768                 aead_request_set_callback(req, 0, esp_input_done_esn, skb);
769         else
770                 aead_request_set_callback(req, 0, esp_input_done, skb);
771
772         aead_request_set_crypt(req, sg, sg, elen + ivlen, iv);
773         aead_request_set_ad(req, assoclen);
774
775         err = crypto_aead_decrypt(req);
776         if (err == -EINPROGRESS)
777                 goto out;
778
779         if ((x->props.flags & XFRM_STATE_ESN))
780                 esp_input_restore_header(skb);
781
782         err = esp_input_done2(skb, err);
783
784 out:
785         return err;
786 }
787
788 static int esp4_err(struct sk_buff *skb, u32 info)
789 {
790         struct net *net = dev_net(skb->dev);
791         const struct iphdr *iph = (const struct iphdr *)skb->data;
792         struct ip_esp_hdr *esph = (struct ip_esp_hdr *)(skb->data+(iph->ihl<<2));
793         struct xfrm_state *x;
794
795         switch (icmp_hdr(skb)->type) {
796         case ICMP_DEST_UNREACH:
797                 if (icmp_hdr(skb)->code != ICMP_FRAG_NEEDED)
798                         return 0;
799         case ICMP_REDIRECT:
800                 break;
801         default:
802                 return 0;
803         }
804
805         x = xfrm_state_lookup(net, skb->mark, (const xfrm_address_t *)&iph->daddr,
806                               esph->spi, IPPROTO_ESP, AF_INET);
807         if (!x)
808                 return 0;
809
810         if (icmp_hdr(skb)->type == ICMP_DEST_UNREACH)
811                 ipv4_update_pmtu(skb, net, info, 0, IPPROTO_ESP);
812         else
813                 ipv4_redirect(skb, net, 0, IPPROTO_ESP);
814         xfrm_state_put(x);
815
816         return 0;
817 }
818
819 static void esp_destroy(struct xfrm_state *x)
820 {
821         struct crypto_aead *aead = x->data;
822
823         if (!aead)
824                 return;
825
826         crypto_free_aead(aead);
827 }
828
829 static int esp_init_aead(struct xfrm_state *x)
830 {
831         char aead_name[CRYPTO_MAX_ALG_NAME];
832         struct crypto_aead *aead;
833         int err;
834
835         err = -ENAMETOOLONG;
836         if (snprintf(aead_name, CRYPTO_MAX_ALG_NAME, "%s(%s)",
837                      x->geniv, x->aead->alg_name) >= CRYPTO_MAX_ALG_NAME)
838                 goto error;
839
840         aead = crypto_alloc_aead(aead_name, 0, 0);
841         err = PTR_ERR(aead);
842         if (IS_ERR(aead))
843                 goto error;
844
845         x->data = aead;
846
847         err = crypto_aead_setkey(aead, x->aead->alg_key,
848                                  (x->aead->alg_key_len + 7) / 8);
849         if (err)
850                 goto error;
851
852         err = crypto_aead_setauthsize(aead, x->aead->alg_icv_len / 8);
853         if (err)
854                 goto error;
855
856 error:
857         return err;
858 }
859
860 static int esp_init_authenc(struct xfrm_state *x)
861 {
862         struct crypto_aead *aead;
863         struct crypto_authenc_key_param *param;
864         struct rtattr *rta;
865         char *key;
866         char *p;
867         char authenc_name[CRYPTO_MAX_ALG_NAME];
868         unsigned int keylen;
869         int err;
870
871         err = -EINVAL;
872         if (!x->ealg)
873                 goto error;
874
875         err = -ENAMETOOLONG;
876
877         if ((x->props.flags & XFRM_STATE_ESN)) {
878                 if (snprintf(authenc_name, CRYPTO_MAX_ALG_NAME,
879                              "%s%sauthencesn(%s,%s)%s",
880                              x->geniv ?: "", x->geniv ? "(" : "",
881                              x->aalg ? x->aalg->alg_name : "digest_null",
882                              x->ealg->alg_name,
883                              x->geniv ? ")" : "") >= CRYPTO_MAX_ALG_NAME)
884                         goto error;
885         } else {
886                 if (snprintf(authenc_name, CRYPTO_MAX_ALG_NAME,
887                              "%s%sauthenc(%s,%s)%s",
888                              x->geniv ?: "", x->geniv ? "(" : "",
889                              x->aalg ? x->aalg->alg_name : "digest_null",
890                              x->ealg->alg_name,
891                              x->geniv ? ")" : "") >= CRYPTO_MAX_ALG_NAME)
892                         goto error;
893         }
894
895         aead = crypto_alloc_aead(authenc_name, 0, 0);
896         err = PTR_ERR(aead);
897         if (IS_ERR(aead))
898                 goto error;
899
900         x->data = aead;
901
902         keylen = (x->aalg ? (x->aalg->alg_key_len + 7) / 8 : 0) +
903                  (x->ealg->alg_key_len + 7) / 8 + RTA_SPACE(sizeof(*param));
904         err = -ENOMEM;
905         key = kmalloc(keylen, GFP_KERNEL);
906         if (!key)
907                 goto error;
908
909         p = key;
910         rta = (void *)p;
911         rta->rta_type = CRYPTO_AUTHENC_KEYA_PARAM;
912         rta->rta_len = RTA_LENGTH(sizeof(*param));
913         param = RTA_DATA(rta);
914         p += RTA_SPACE(sizeof(*param));
915
916         if (x->aalg) {
917                 struct xfrm_algo_desc *aalg_desc;
918
919                 memcpy(p, x->aalg->alg_key, (x->aalg->alg_key_len + 7) / 8);
920                 p += (x->aalg->alg_key_len + 7) / 8;
921
922                 aalg_desc = xfrm_aalg_get_byname(x->aalg->alg_name, 0);
923                 BUG_ON(!aalg_desc);
924
925                 err = -EINVAL;
926                 if (aalg_desc->uinfo.auth.icv_fullbits / 8 !=
927                     crypto_aead_authsize(aead)) {
928                         pr_info("ESP: %s digestsize %u != %hu\n",
929                                 x->aalg->alg_name,
930                                 crypto_aead_authsize(aead),
931                                 aalg_desc->uinfo.auth.icv_fullbits / 8);
932                         goto free_key;
933                 }
934
935                 err = crypto_aead_setauthsize(
936                         aead, x->aalg->alg_trunc_len / 8);
937                 if (err)
938                         goto free_key;
939         }
940
941         param->enckeylen = cpu_to_be32((x->ealg->alg_key_len + 7) / 8);
942         memcpy(p, x->ealg->alg_key, (x->ealg->alg_key_len + 7) / 8);
943
944         err = crypto_aead_setkey(aead, key, keylen);
945
946 free_key:
947         kfree(key);
948
949 error:
950         return err;
951 }
952
953 static int esp_init_state(struct xfrm_state *x)
954 {
955         struct crypto_aead *aead;
956         u32 align;
957         int err;
958
959         x->data = NULL;
960
961         if (x->aead)
962                 err = esp_init_aead(x);
963         else
964                 err = esp_init_authenc(x);
965
966         if (err)
967                 goto error;
968
969         aead = x->data;
970
971         x->props.header_len = sizeof(struct ip_esp_hdr) +
972                               crypto_aead_ivsize(aead);
973         if (x->props.mode == XFRM_MODE_TUNNEL)
974                 x->props.header_len += sizeof(struct iphdr);
975         else if (x->props.mode == XFRM_MODE_BEET && x->sel.family != AF_INET6)
976                 x->props.header_len += IPV4_BEET_PHMAXLEN;
977         if (x->encap) {
978                 struct xfrm_encap_tmpl *encap = x->encap;
979
980                 switch (encap->encap_type) {
981                 default:
982                         err = -EINVAL;
983                         goto error;
984                 case UDP_ENCAP_ESPINUDP:
985                         x->props.header_len += sizeof(struct udphdr);
986                         break;
987                 case UDP_ENCAP_ESPINUDP_NON_IKE:
988                         x->props.header_len += sizeof(struct udphdr) + 2 * sizeof(u32);
989                         break;
990                 }
991         }
992
993         align = ALIGN(crypto_aead_blocksize(aead), 4);
994         x->props.trailer_len = align + 1 + crypto_aead_authsize(aead);
995
996 error:
997         return err;
998 }
999
1000 static int esp4_rcv_cb(struct sk_buff *skb, int err)
1001 {
1002         return 0;
1003 }
1004
1005 static const struct xfrm_type esp_type =
1006 {
1007         .description    = "ESP4",
1008         .owner          = THIS_MODULE,
1009         .proto          = IPPROTO_ESP,
1010         .flags          = XFRM_TYPE_REPLAY_PROT,
1011         .init_state     = esp_init_state,
1012         .destructor     = esp_destroy,
1013         .input          = esp_input,
1014         .output         = esp_output,
1015 };
1016
1017 static struct xfrm4_protocol esp4_protocol = {
1018         .handler        =       xfrm4_rcv,
1019         .input_handler  =       xfrm_input,
1020         .cb_handler     =       esp4_rcv_cb,
1021         .err_handler    =       esp4_err,
1022         .priority       =       0,
1023 };
1024
1025 static int __init esp4_init(void)
1026 {
1027         if (xfrm_register_type(&esp_type, AF_INET) < 0) {
1028                 pr_info("%s: can't add xfrm type\n", __func__);
1029                 return -EAGAIN;
1030         }
1031         if (xfrm4_protocol_register(&esp4_protocol, IPPROTO_ESP) < 0) {
1032                 pr_info("%s: can't add protocol\n", __func__);
1033                 xfrm_unregister_type(&esp_type, AF_INET);
1034                 return -EAGAIN;
1035         }
1036         return 0;
1037 }
1038
1039 static void __exit esp4_fini(void)
1040 {
1041         if (xfrm4_protocol_deregister(&esp4_protocol, IPPROTO_ESP) < 0)
1042                 pr_info("%s: can't remove protocol\n", __func__);
1043         xfrm_unregister_type(&esp_type, AF_INET);
1044 }
1045
1046 module_init(esp4_init);
1047 module_exit(esp4_fini);
1048 MODULE_LICENSE("GPL");
1049 MODULE_ALIAS_XFRM_TYPE(AF_INET, XFRM_PROTO_ESP);