GNU Linux-libre 6.8.7-gnu
[releases.git] / net / netfilter / nf_nat_proto.c
1 // SPDX-License-Identifier: GPL-2.0-only
2 /* (C) 1999-2001 Paul `Rusty' Russell
3  * (C) 2002-2006 Netfilter Core Team <coreteam@netfilter.org>
4  */
5
6 #include <linux/types.h>
7 #include <linux/export.h>
8 #include <linux/init.h>
9 #include <linux/udp.h>
10 #include <linux/tcp.h>
11 #include <linux/icmp.h>
12 #include <linux/icmpv6.h>
13
14 #include <linux/dccp.h>
15 #include <linux/sctp.h>
16 #include <net/sctp/checksum.h>
17
18 #include <linux/netfilter.h>
19 #include <net/netfilter/nf_nat.h>
20
21 #include <linux/ipv6.h>
22 #include <linux/netfilter_ipv6.h>
23 #include <net/checksum.h>
24 #include <net/ip6_checksum.h>
25 #include <net/ip6_route.h>
26 #include <net/xfrm.h>
27 #include <net/ipv6.h>
28
29 #include <net/netfilter/nf_conntrack_core.h>
30 #include <net/netfilter/nf_conntrack.h>
31 #include <linux/netfilter/nfnetlink_conntrack.h>
32
33 static void nf_csum_update(struct sk_buff *skb,
34                            unsigned int iphdroff, __sum16 *check,
35                            const struct nf_conntrack_tuple *t,
36                            enum nf_nat_manip_type maniptype);
37
38 static void
39 __udp_manip_pkt(struct sk_buff *skb,
40                 unsigned int iphdroff, struct udphdr *hdr,
41                 const struct nf_conntrack_tuple *tuple,
42                 enum nf_nat_manip_type maniptype, bool do_csum)
43 {
44         __be16 *portptr, newport;
45
46         if (maniptype == NF_NAT_MANIP_SRC) {
47                 /* Get rid of src port */
48                 newport = tuple->src.u.udp.port;
49                 portptr = &hdr->source;
50         } else {
51                 /* Get rid of dst port */
52                 newport = tuple->dst.u.udp.port;
53                 portptr = &hdr->dest;
54         }
55         if (do_csum) {
56                 nf_csum_update(skb, iphdroff, &hdr->check, tuple, maniptype);
57                 inet_proto_csum_replace2(&hdr->check, skb, *portptr, newport,
58                                          false);
59                 if (!hdr->check)
60                         hdr->check = CSUM_MANGLED_0;
61         }
62         *portptr = newport;
63 }
64
65 static bool udp_manip_pkt(struct sk_buff *skb,
66                           unsigned int iphdroff, unsigned int hdroff,
67                           const struct nf_conntrack_tuple *tuple,
68                           enum nf_nat_manip_type maniptype)
69 {
70         struct udphdr *hdr;
71
72         if (skb_ensure_writable(skb, hdroff + sizeof(*hdr)))
73                 return false;
74
75         hdr = (struct udphdr *)(skb->data + hdroff);
76         __udp_manip_pkt(skb, iphdroff, hdr, tuple, maniptype, !!hdr->check);
77
78         return true;
79 }
80
81 static bool udplite_manip_pkt(struct sk_buff *skb,
82                               unsigned int iphdroff, unsigned int hdroff,
83                               const struct nf_conntrack_tuple *tuple,
84                               enum nf_nat_manip_type maniptype)
85 {
86 #ifdef CONFIG_NF_CT_PROTO_UDPLITE
87         struct udphdr *hdr;
88
89         if (skb_ensure_writable(skb, hdroff + sizeof(*hdr)))
90                 return false;
91
92         hdr = (struct udphdr *)(skb->data + hdroff);
93         __udp_manip_pkt(skb, iphdroff, hdr, tuple, maniptype, true);
94 #endif
95         return true;
96 }
97
98 static bool
99 sctp_manip_pkt(struct sk_buff *skb,
100                unsigned int iphdroff, unsigned int hdroff,
101                const struct nf_conntrack_tuple *tuple,
102                enum nf_nat_manip_type maniptype)
103 {
104 #ifdef CONFIG_NF_CT_PROTO_SCTP
105         struct sctphdr *hdr;
106         int hdrsize = 8;
107
108         /* This could be an inner header returned in imcp packet; in such
109          * cases we cannot update the checksum field since it is outside
110          * of the 8 bytes of transport layer headers we are guaranteed.
111          */
112         if (skb->len >= hdroff + sizeof(*hdr))
113                 hdrsize = sizeof(*hdr);
114
115         if (skb_ensure_writable(skb, hdroff + hdrsize))
116                 return false;
117
118         hdr = (struct sctphdr *)(skb->data + hdroff);
119
120         if (maniptype == NF_NAT_MANIP_SRC) {
121                 /* Get rid of src port */
122                 hdr->source = tuple->src.u.sctp.port;
123         } else {
124                 /* Get rid of dst port */
125                 hdr->dest = tuple->dst.u.sctp.port;
126         }
127
128         if (hdrsize < sizeof(*hdr))
129                 return true;
130
131         if (skb->ip_summed != CHECKSUM_PARTIAL) {
132                 hdr->checksum = sctp_compute_cksum(skb, hdroff);
133                 skb->ip_summed = CHECKSUM_NONE;
134         }
135
136 #endif
137         return true;
138 }
139
140 static bool
141 tcp_manip_pkt(struct sk_buff *skb,
142               unsigned int iphdroff, unsigned int hdroff,
143               const struct nf_conntrack_tuple *tuple,
144               enum nf_nat_manip_type maniptype)
145 {
146         struct tcphdr *hdr;
147         __be16 *portptr, newport, oldport;
148         int hdrsize = 8; /* TCP connection tracking guarantees this much */
149
150         /* this could be a inner header returned in icmp packet; in such
151            cases we cannot update the checksum field since it is outside of
152            the 8 bytes of transport layer headers we are guaranteed */
153         if (skb->len >= hdroff + sizeof(struct tcphdr))
154                 hdrsize = sizeof(struct tcphdr);
155
156         if (skb_ensure_writable(skb, hdroff + hdrsize))
157                 return false;
158
159         hdr = (struct tcphdr *)(skb->data + hdroff);
160
161         if (maniptype == NF_NAT_MANIP_SRC) {
162                 /* Get rid of src port */
163                 newport = tuple->src.u.tcp.port;
164                 portptr = &hdr->source;
165         } else {
166                 /* Get rid of dst port */
167                 newport = tuple->dst.u.tcp.port;
168                 portptr = &hdr->dest;
169         }
170
171         oldport = *portptr;
172         *portptr = newport;
173
174         if (hdrsize < sizeof(*hdr))
175                 return true;
176
177         nf_csum_update(skb, iphdroff, &hdr->check, tuple, maniptype);
178         inet_proto_csum_replace2(&hdr->check, skb, oldport, newport, false);
179         return true;
180 }
181
182 static bool
183 dccp_manip_pkt(struct sk_buff *skb,
184                unsigned int iphdroff, unsigned int hdroff,
185                const struct nf_conntrack_tuple *tuple,
186                enum nf_nat_manip_type maniptype)
187 {
188 #ifdef CONFIG_NF_CT_PROTO_DCCP
189         struct dccp_hdr *hdr;
190         __be16 *portptr, oldport, newport;
191         int hdrsize = 8; /* DCCP connection tracking guarantees this much */
192
193         if (skb->len >= hdroff + sizeof(struct dccp_hdr))
194                 hdrsize = sizeof(struct dccp_hdr);
195
196         if (skb_ensure_writable(skb, hdroff + hdrsize))
197                 return false;
198
199         hdr = (struct dccp_hdr *)(skb->data + hdroff);
200
201         if (maniptype == NF_NAT_MANIP_SRC) {
202                 newport = tuple->src.u.dccp.port;
203                 portptr = &hdr->dccph_sport;
204         } else {
205                 newport = tuple->dst.u.dccp.port;
206                 portptr = &hdr->dccph_dport;
207         }
208
209         oldport = *portptr;
210         *portptr = newport;
211
212         if (hdrsize < sizeof(*hdr))
213                 return true;
214
215         nf_csum_update(skb, iphdroff, &hdr->dccph_checksum, tuple, maniptype);
216         inet_proto_csum_replace2(&hdr->dccph_checksum, skb, oldport, newport,
217                                  false);
218 #endif
219         return true;
220 }
221
222 static bool
223 icmp_manip_pkt(struct sk_buff *skb,
224                unsigned int iphdroff, unsigned int hdroff,
225                const struct nf_conntrack_tuple *tuple,
226                enum nf_nat_manip_type maniptype)
227 {
228         struct icmphdr *hdr;
229
230         if (skb_ensure_writable(skb, hdroff + sizeof(*hdr)))
231                 return false;
232
233         hdr = (struct icmphdr *)(skb->data + hdroff);
234         switch (hdr->type) {
235         case ICMP_ECHO:
236         case ICMP_ECHOREPLY:
237         case ICMP_TIMESTAMP:
238         case ICMP_TIMESTAMPREPLY:
239         case ICMP_INFO_REQUEST:
240         case ICMP_INFO_REPLY:
241         case ICMP_ADDRESS:
242         case ICMP_ADDRESSREPLY:
243                 break;
244         default:
245                 return true;
246         }
247         inet_proto_csum_replace2(&hdr->checksum, skb,
248                                  hdr->un.echo.id, tuple->src.u.icmp.id, false);
249         hdr->un.echo.id = tuple->src.u.icmp.id;
250         return true;
251 }
252
253 static bool
254 icmpv6_manip_pkt(struct sk_buff *skb,
255                  unsigned int iphdroff, unsigned int hdroff,
256                  const struct nf_conntrack_tuple *tuple,
257                  enum nf_nat_manip_type maniptype)
258 {
259         struct icmp6hdr *hdr;
260
261         if (skb_ensure_writable(skb, hdroff + sizeof(*hdr)))
262                 return false;
263
264         hdr = (struct icmp6hdr *)(skb->data + hdroff);
265         nf_csum_update(skb, iphdroff, &hdr->icmp6_cksum, tuple, maniptype);
266         if (hdr->icmp6_type == ICMPV6_ECHO_REQUEST ||
267             hdr->icmp6_type == ICMPV6_ECHO_REPLY) {
268                 inet_proto_csum_replace2(&hdr->icmp6_cksum, skb,
269                                          hdr->icmp6_identifier,
270                                          tuple->src.u.icmp.id, false);
271                 hdr->icmp6_identifier = tuple->src.u.icmp.id;
272         }
273         return true;
274 }
275
276 /* manipulate a GRE packet according to maniptype */
277 static bool
278 gre_manip_pkt(struct sk_buff *skb,
279               unsigned int iphdroff, unsigned int hdroff,
280               const struct nf_conntrack_tuple *tuple,
281               enum nf_nat_manip_type maniptype)
282 {
283 #if IS_ENABLED(CONFIG_NF_CT_PROTO_GRE)
284         const struct gre_base_hdr *greh;
285         struct pptp_gre_header *pgreh;
286
287         /* pgreh includes two optional 32bit fields which are not required
288          * to be there.  That's where the magic '8' comes from */
289         if (skb_ensure_writable(skb, hdroff + sizeof(*pgreh) - 8))
290                 return false;
291
292         greh = (void *)skb->data + hdroff;
293         pgreh = (struct pptp_gre_header *)greh;
294
295         /* we only have destination manip of a packet, since 'source key'
296          * is not present in the packet itself */
297         if (maniptype != NF_NAT_MANIP_DST)
298                 return true;
299
300         switch (greh->flags & GRE_VERSION) {
301         case GRE_VERSION_0:
302                 /* We do not currently NAT any GREv0 packets.
303                  * Try to behave like "nf_nat_proto_unknown" */
304                 break;
305         case GRE_VERSION_1:
306                 pr_debug("call_id -> 0x%04x\n", ntohs(tuple->dst.u.gre.key));
307                 pgreh->call_id = tuple->dst.u.gre.key;
308                 break;
309         default:
310                 pr_debug("can't nat unknown GRE version\n");
311                 return false;
312         }
313 #endif
314         return true;
315 }
316
317 static bool l4proto_manip_pkt(struct sk_buff *skb,
318                               unsigned int iphdroff, unsigned int hdroff,
319                               const struct nf_conntrack_tuple *tuple,
320                               enum nf_nat_manip_type maniptype)
321 {
322         switch (tuple->dst.protonum) {
323         case IPPROTO_TCP:
324                 return tcp_manip_pkt(skb, iphdroff, hdroff,
325                                      tuple, maniptype);
326         case IPPROTO_UDP:
327                 return udp_manip_pkt(skb, iphdroff, hdroff,
328                                      tuple, maniptype);
329         case IPPROTO_UDPLITE:
330                 return udplite_manip_pkt(skb, iphdroff, hdroff,
331                                          tuple, maniptype);
332         case IPPROTO_SCTP:
333                 return sctp_manip_pkt(skb, iphdroff, hdroff,
334                                       tuple, maniptype);
335         case IPPROTO_ICMP:
336                 return icmp_manip_pkt(skb, iphdroff, hdroff,
337                                       tuple, maniptype);
338         case IPPROTO_ICMPV6:
339                 return icmpv6_manip_pkt(skb, iphdroff, hdroff,
340                                         tuple, maniptype);
341         case IPPROTO_DCCP:
342                 return dccp_manip_pkt(skb, iphdroff, hdroff,
343                                       tuple, maniptype);
344         case IPPROTO_GRE:
345                 return gre_manip_pkt(skb, iphdroff, hdroff,
346                                      tuple, maniptype);
347         }
348
349         /* If we don't know protocol -- no error, pass it unmodified. */
350         return true;
351 }
352
353 static bool nf_nat_ipv4_manip_pkt(struct sk_buff *skb,
354                                   unsigned int iphdroff,
355                                   const struct nf_conntrack_tuple *target,
356                                   enum nf_nat_manip_type maniptype)
357 {
358         struct iphdr *iph;
359         unsigned int hdroff;
360
361         if (skb_ensure_writable(skb, iphdroff + sizeof(*iph)))
362                 return false;
363
364         iph = (void *)skb->data + iphdroff;
365         hdroff = iphdroff + iph->ihl * 4;
366
367         if (!l4proto_manip_pkt(skb, iphdroff, hdroff, target, maniptype))
368                 return false;
369         iph = (void *)skb->data + iphdroff;
370
371         if (maniptype == NF_NAT_MANIP_SRC) {
372                 csum_replace4(&iph->check, iph->saddr, target->src.u3.ip);
373                 iph->saddr = target->src.u3.ip;
374         } else {
375                 csum_replace4(&iph->check, iph->daddr, target->dst.u3.ip);
376                 iph->daddr = target->dst.u3.ip;
377         }
378         return true;
379 }
380
381 static bool nf_nat_ipv6_manip_pkt(struct sk_buff *skb,
382                                   unsigned int iphdroff,
383                                   const struct nf_conntrack_tuple *target,
384                                   enum nf_nat_manip_type maniptype)
385 {
386 #if IS_ENABLED(CONFIG_IPV6)
387         struct ipv6hdr *ipv6h;
388         __be16 frag_off;
389         int hdroff;
390         u8 nexthdr;
391
392         if (skb_ensure_writable(skb, iphdroff + sizeof(*ipv6h)))
393                 return false;
394
395         ipv6h = (void *)skb->data + iphdroff;
396         nexthdr = ipv6h->nexthdr;
397         hdroff = ipv6_skip_exthdr(skb, iphdroff + sizeof(*ipv6h),
398                                   &nexthdr, &frag_off);
399         if (hdroff < 0)
400                 goto manip_addr;
401
402         if ((frag_off & htons(~0x7)) == 0 &&
403             !l4proto_manip_pkt(skb, iphdroff, hdroff, target, maniptype))
404                 return false;
405
406         /* must reload, offset might have changed */
407         ipv6h = (void *)skb->data + iphdroff;
408
409 manip_addr:
410         if (maniptype == NF_NAT_MANIP_SRC)
411                 ipv6h->saddr = target->src.u3.in6;
412         else
413                 ipv6h->daddr = target->dst.u3.in6;
414
415 #endif
416         return true;
417 }
418
419 unsigned int nf_nat_manip_pkt(struct sk_buff *skb, struct nf_conn *ct,
420                               enum nf_nat_manip_type mtype,
421                               enum ip_conntrack_dir dir)
422 {
423         struct nf_conntrack_tuple target;
424
425         /* We are aiming to look like inverse of other direction. */
426         nf_ct_invert_tuple(&target, &ct->tuplehash[!dir].tuple);
427
428         switch (target.src.l3num) {
429         case NFPROTO_IPV6:
430                 if (nf_nat_ipv6_manip_pkt(skb, 0, &target, mtype))
431                         return NF_ACCEPT;
432                 break;
433         case NFPROTO_IPV4:
434                 if (nf_nat_ipv4_manip_pkt(skb, 0, &target, mtype))
435                         return NF_ACCEPT;
436                 break;
437         default:
438                 WARN_ON_ONCE(1);
439                 break;
440         }
441
442         return NF_DROP;
443 }
444
445 static void nf_nat_ipv4_csum_update(struct sk_buff *skb,
446                                     unsigned int iphdroff, __sum16 *check,
447                                     const struct nf_conntrack_tuple *t,
448                                     enum nf_nat_manip_type maniptype)
449 {
450         struct iphdr *iph = (struct iphdr *)(skb->data + iphdroff);
451         __be32 oldip, newip;
452
453         if (maniptype == NF_NAT_MANIP_SRC) {
454                 oldip = iph->saddr;
455                 newip = t->src.u3.ip;
456         } else {
457                 oldip = iph->daddr;
458                 newip = t->dst.u3.ip;
459         }
460         inet_proto_csum_replace4(check, skb, oldip, newip, true);
461 }
462
463 static void nf_nat_ipv6_csum_update(struct sk_buff *skb,
464                                     unsigned int iphdroff, __sum16 *check,
465                                     const struct nf_conntrack_tuple *t,
466                                     enum nf_nat_manip_type maniptype)
467 {
468 #if IS_ENABLED(CONFIG_IPV6)
469         const struct ipv6hdr *ipv6h = (struct ipv6hdr *)(skb->data + iphdroff);
470         const struct in6_addr *oldip, *newip;
471
472         if (maniptype == NF_NAT_MANIP_SRC) {
473                 oldip = &ipv6h->saddr;
474                 newip = &t->src.u3.in6;
475         } else {
476                 oldip = &ipv6h->daddr;
477                 newip = &t->dst.u3.in6;
478         }
479         inet_proto_csum_replace16(check, skb, oldip->s6_addr32,
480                                   newip->s6_addr32, true);
481 #endif
482 }
483
484 static void nf_csum_update(struct sk_buff *skb,
485                            unsigned int iphdroff, __sum16 *check,
486                            const struct nf_conntrack_tuple *t,
487                            enum nf_nat_manip_type maniptype)
488 {
489         switch (t->src.l3num) {
490         case NFPROTO_IPV4:
491                 nf_nat_ipv4_csum_update(skb, iphdroff, check, t, maniptype);
492                 return;
493         case NFPROTO_IPV6:
494                 nf_nat_ipv6_csum_update(skb, iphdroff, check, t, maniptype);
495                 return;
496         }
497 }
498
499 static void nf_nat_ipv4_csum_recalc(struct sk_buff *skb,
500                                     u8 proto, void *data, __sum16 *check,
501                                     int datalen, int oldlen)
502 {
503         if (skb->ip_summed != CHECKSUM_PARTIAL) {
504                 const struct iphdr *iph = ip_hdr(skb);
505
506                 skb->ip_summed = CHECKSUM_PARTIAL;
507                 skb->csum_start = skb_headroom(skb) + skb_network_offset(skb) +
508                         ip_hdrlen(skb);
509                 skb->csum_offset = (void *)check - data;
510                 *check = ~csum_tcpudp_magic(iph->saddr, iph->daddr, datalen,
511                                             proto, 0);
512         } else {
513                 inet_proto_csum_replace2(check, skb,
514                                          htons(oldlen), htons(datalen), true);
515         }
516 }
517
518 #if IS_ENABLED(CONFIG_IPV6)
519 static void nf_nat_ipv6_csum_recalc(struct sk_buff *skb,
520                                     u8 proto, void *data, __sum16 *check,
521                                     int datalen, int oldlen)
522 {
523         if (skb->ip_summed != CHECKSUM_PARTIAL) {
524                 const struct ipv6hdr *ipv6h = ipv6_hdr(skb);
525
526                 skb->ip_summed = CHECKSUM_PARTIAL;
527                 skb->csum_start = skb_headroom(skb) + skb_network_offset(skb) +
528                         (data - (void *)skb->data);
529                 skb->csum_offset = (void *)check - data;
530                 *check = ~csum_ipv6_magic(&ipv6h->saddr, &ipv6h->daddr,
531                                           datalen, proto, 0);
532         } else {
533                 inet_proto_csum_replace2(check, skb,
534                                          htons(oldlen), htons(datalen), true);
535         }
536 }
537 #endif
538
539 void nf_nat_csum_recalc(struct sk_buff *skb,
540                         u8 nfproto, u8 proto, void *data, __sum16 *check,
541                         int datalen, int oldlen)
542 {
543         switch (nfproto) {
544         case NFPROTO_IPV4:
545                 nf_nat_ipv4_csum_recalc(skb, proto, data, check,
546                                         datalen, oldlen);
547                 return;
548 #if IS_ENABLED(CONFIG_IPV6)
549         case NFPROTO_IPV6:
550                 nf_nat_ipv6_csum_recalc(skb, proto, data, check,
551                                         datalen, oldlen);
552                 return;
553 #endif
554         }
555
556         WARN_ON_ONCE(1);
557 }
558
559 int nf_nat_icmp_reply_translation(struct sk_buff *skb,
560                                   struct nf_conn *ct,
561                                   enum ip_conntrack_info ctinfo,
562                                   unsigned int hooknum)
563 {
564         struct {
565                 struct icmphdr  icmp;
566                 struct iphdr    ip;
567         } *inside;
568         enum ip_conntrack_dir dir = CTINFO2DIR(ctinfo);
569         enum nf_nat_manip_type manip = HOOK2MANIP(hooknum);
570         unsigned int hdrlen = ip_hdrlen(skb);
571         struct nf_conntrack_tuple target;
572         unsigned long statusbit;
573
574         WARN_ON(ctinfo != IP_CT_RELATED && ctinfo != IP_CT_RELATED_REPLY);
575
576         if (skb_ensure_writable(skb, hdrlen + sizeof(*inside)))
577                 return 0;
578         if (nf_ip_checksum(skb, hooknum, hdrlen, IPPROTO_ICMP))
579                 return 0;
580
581         inside = (void *)skb->data + hdrlen;
582         if (inside->icmp.type == ICMP_REDIRECT) {
583                 if ((ct->status & IPS_NAT_DONE_MASK) != IPS_NAT_DONE_MASK)
584                         return 0;
585                 if (ct->status & IPS_NAT_MASK)
586                         return 0;
587         }
588
589         if (manip == NF_NAT_MANIP_SRC)
590                 statusbit = IPS_SRC_NAT;
591         else
592                 statusbit = IPS_DST_NAT;
593
594         /* Invert if this is reply direction */
595         if (dir == IP_CT_DIR_REPLY)
596                 statusbit ^= IPS_NAT_MASK;
597
598         if (!(ct->status & statusbit))
599                 return 1;
600
601         if (!nf_nat_ipv4_manip_pkt(skb, hdrlen + sizeof(inside->icmp),
602                                    &ct->tuplehash[!dir].tuple, !manip))
603                 return 0;
604
605         if (skb->ip_summed != CHECKSUM_PARTIAL) {
606                 /* Reloading "inside" here since manip_pkt may reallocate */
607                 inside = (void *)skb->data + hdrlen;
608                 inside->icmp.checksum = 0;
609                 inside->icmp.checksum =
610                         csum_fold(skb_checksum(skb, hdrlen,
611                                                skb->len - hdrlen, 0));
612         }
613
614         /* Change outer to look like the reply to an incoming packet */
615         nf_ct_invert_tuple(&target, &ct->tuplehash[!dir].tuple);
616         target.dst.protonum = IPPROTO_ICMP;
617         if (!nf_nat_ipv4_manip_pkt(skb, 0, &target, manip))
618                 return 0;
619
620         return 1;
621 }
622 EXPORT_SYMBOL_GPL(nf_nat_icmp_reply_translation);
623
624 static unsigned int
625 nf_nat_ipv4_fn(void *priv, struct sk_buff *skb,
626                const struct nf_hook_state *state)
627 {
628         struct nf_conn *ct;
629         enum ip_conntrack_info ctinfo;
630
631         ct = nf_ct_get(skb, &ctinfo);
632         if (!ct)
633                 return NF_ACCEPT;
634
635         if (ctinfo == IP_CT_RELATED || ctinfo == IP_CT_RELATED_REPLY) {
636                 if (ip_hdr(skb)->protocol == IPPROTO_ICMP) {
637                         if (!nf_nat_icmp_reply_translation(skb, ct, ctinfo,
638                                                            state->hook))
639                                 return NF_DROP;
640                         else
641                                 return NF_ACCEPT;
642                 }
643         }
644
645         return nf_nat_inet_fn(priv, skb, state);
646 }
647
648 static unsigned int
649 nf_nat_ipv4_pre_routing(void *priv, struct sk_buff *skb,
650                         const struct nf_hook_state *state)
651 {
652         unsigned int ret;
653         __be32 daddr = ip_hdr(skb)->daddr;
654
655         ret = nf_nat_ipv4_fn(priv, skb, state);
656         if (ret == NF_ACCEPT && daddr != ip_hdr(skb)->daddr)
657                 skb_dst_drop(skb);
658
659         return ret;
660 }
661
#ifdef CONFIG_XFRM
/*
 * Redo the xfrm lookup for @skb after NAT changed its addresses or ports.
 * The skb's dst is replaced with the lookup result, and headroom is grown
 * if the new output device needs a larger hard header.
 *
 * Returns 0 on success or a negative errno.
 */
static int nf_xfrm_me_harder(struct net *net, struct sk_buff *skb, unsigned int family)
{
	struct sock *sk = skb->sk;
	struct dst_entry *route;
	unsigned int ll_hdr_len;
	struct flowi flow;
	int err;

	err = xfrm_decode_session(net, skb, &flow, family);
	if (err < 0)
		return err;

	/* If the current dst is an xfrm bundle, restart from its route. */
	route = skb_dst(skb);
	if (route->xfrm)
		route = ((struct xfrm_dst *)route)->route;
	if (!dst_hold_safe(route))
		return -EHOSTUNREACH;

	/* Only hand the socket to the lookup if it lives in @net. */
	if (sk && !net_eq(net, sock_net(sk)))
		sk = NULL;

	route = xfrm_lookup(net, route, &flow, sk, 0);
	if (IS_ERR(route))
		return PTR_ERR(route);

	skb_dst_drop(skb);
	skb_dst_set(skb, route);

	/* Change in oif may mean change in hh_len. */
	ll_hdr_len = skb_dst(skb)->dev->hard_header_len;
	if (skb_headroom(skb) >= ll_hdr_len)
		return 0;

	if (pskb_expand_head(skb, ll_hdr_len - skb_headroom(skb), 0, GFP_ATOMIC))
		return -ENOMEM;

	return 0;
}
#endif
699
700 static bool nf_nat_inet_port_was_mangled(const struct sk_buff *skb, __be16 sport)
701 {
702         enum ip_conntrack_info ctinfo;
703         enum ip_conntrack_dir dir;
704         const struct nf_conn *ct;
705
706         ct = nf_ct_get(skb, &ctinfo);
707         if (!ct)
708                 return false;
709
710         switch (nf_ct_protonum(ct)) {
711         case IPPROTO_TCP:
712         case IPPROTO_UDP:
713                 break;
714         default:
715                 return false;
716         }
717
718         dir = CTINFO2DIR(ctinfo);
719         if (dir != IP_CT_DIR_ORIGINAL)
720                 return false;
721
722         return ct->tuplehash[!dir].tuple.dst.u.all != sport;
723 }
724
725 static unsigned int
726 nf_nat_ipv4_local_in(void *priv, struct sk_buff *skb,
727                      const struct nf_hook_state *state)
728 {
729         __be32 saddr = ip_hdr(skb)->saddr;
730         struct sock *sk = skb->sk;
731         unsigned int ret;
732
733         ret = nf_nat_ipv4_fn(priv, skb, state);
734
735         if (ret != NF_ACCEPT || !sk || inet_sk_transparent(sk))
736                 return ret;
737
738         /* skb has a socket assigned via tcp edemux. We need to check
739          * if nf_nat_ipv4_fn() has mangled the packet in a way that
740          * edemux would not have found this socket.
741          *
742          * This includes both changes to the source address and changes
743          * to the source port, which are both handled by the
744          * nf_nat_ipv4_fn() call above -- long after tcp/udp early demux
745          * might have found a socket for the old (pre-snat) address.
746          */
747         if (saddr != ip_hdr(skb)->saddr ||
748             nf_nat_inet_port_was_mangled(skb, sk->sk_dport))
749                 skb_orphan(skb); /* TCP edemux obtained wrong socket */
750
751         return ret;
752 }
753
/*
 * POST_ROUTING hook.  Runs NAT.  With CONFIG_XFRM, if an accepted, not yet
 * transformed packet had its source address (or, for non-ICMP, its source
 * port) changed, the xfrm lookup is redone via nf_xfrm_me_harder().  A
 * failure there turns the verdict into a drop carrying that error.
 */
static unsigned int
nf_nat_ipv4_out(void *priv, struct sk_buff *skb,
		const struct nf_hook_state *state)
{
#ifdef CONFIG_XFRM
	const struct nf_conntrack_tuple *cur, *other;
	enum ip_conntrack_info ctinfo;
	const struct nf_conn *ct;
	enum ip_conntrack_dir dir;
	int err;
#endif
	unsigned int verdict;

	verdict = nf_nat_ipv4_fn(priv, skb, state);
#ifdef CONFIG_XFRM
	if (verdict != NF_ACCEPT || (IPCB(skb)->flags & IPSKB_XFRM_TRANSFORMED))
		return verdict;

	ct = nf_ct_get(skb, &ctinfo);
	if (!ct)
		return verdict;

	dir = CTINFO2DIR(ctinfo);
	cur = &ct->tuplehash[dir].tuple;
	other = &ct->tuplehash[!dir].tuple;

	if (cur->src.u3.ip != other->dst.u3.ip ||
	    (cur->dst.protonum != IPPROTO_ICMP &&
	     cur->src.u.all != other->dst.u.all)) {
		err = nf_xfrm_me_harder(state->net, skb, AF_INET);
		if (err < 0)
			verdict = NF_DROP_ERR(err);
	}
#endif
	return verdict;
}
790
791 static unsigned int
792 nf_nat_ipv4_local_fn(void *priv, struct sk_buff *skb,
793                      const struct nf_hook_state *state)
794 {
795         const struct nf_conn *ct;
796         enum ip_conntrack_info ctinfo;
797         unsigned int ret;
798         int err;
799
800         ret = nf_nat_ipv4_fn(priv, skb, state);
801         if (ret != NF_ACCEPT)
802                 return ret;
803
804         ct = nf_ct_get(skb, &ctinfo);
805         if (ct) {
806                 enum ip_conntrack_dir dir = CTINFO2DIR(ctinfo);
807
808                 if (ct->tuplehash[dir].tuple.dst.u3.ip !=
809                     ct->tuplehash[!dir].tuple.src.u3.ip) {
810                         err = ip_route_me_harder(state->net, state->sk, skb, RTN_UNSPEC);
811                         if (err < 0)
812                                 ret = NF_DROP_ERR(err);
813                 }
814 #ifdef CONFIG_XFRM
815                 else if (!(IPCB(skb)->flags & IPSKB_XFRM_TRANSFORMED) &&
816                          ct->tuplehash[dir].tuple.dst.protonum != IPPROTO_ICMP &&
817                          ct->tuplehash[dir].tuple.dst.u.all !=
818                          ct->tuplehash[!dir].tuple.src.u.all) {
819                         err = nf_xfrm_me_harder(state->net, skb, AF_INET);
820                         if (err < 0)
821                                 ret = NF_DROP_ERR(err);
822                 }
823 #endif
824         }
825         return ret;
826 }
827
828 static const struct nf_hook_ops nf_nat_ipv4_ops[] = {
829         /* Before packet filtering, change destination */
830         {
831                 .hook           = nf_nat_ipv4_pre_routing,
832                 .pf             = NFPROTO_IPV4,
833                 .hooknum        = NF_INET_PRE_ROUTING,
834                 .priority       = NF_IP_PRI_NAT_DST,
835         },
836         /* After packet filtering, change source */
837         {
838                 .hook           = nf_nat_ipv4_out,
839                 .pf             = NFPROTO_IPV4,
840                 .hooknum        = NF_INET_POST_ROUTING,
841                 .priority       = NF_IP_PRI_NAT_SRC,
842         },
843         /* Before packet filtering, change destination */
844         {
845                 .hook           = nf_nat_ipv4_local_fn,
846                 .pf             = NFPROTO_IPV4,
847                 .hooknum        = NF_INET_LOCAL_OUT,
848                 .priority       = NF_IP_PRI_NAT_DST,
849         },
850         /* After packet filtering, change source */
851         {
852                 .hook           = nf_nat_ipv4_local_in,
853                 .pf             = NFPROTO_IPV4,
854                 .hooknum        = NF_INET_LOCAL_IN,
855                 .priority       = NF_IP_PRI_NAT_SRC,
856         },
857 };
858
859 int nf_nat_ipv4_register_fn(struct net *net, const struct nf_hook_ops *ops)
860 {
861         return nf_nat_register_fn(net, ops->pf, ops, nf_nat_ipv4_ops,
862                                   ARRAY_SIZE(nf_nat_ipv4_ops));
863 }
864 EXPORT_SYMBOL_GPL(nf_nat_ipv4_register_fn);
865
866 void nf_nat_ipv4_unregister_fn(struct net *net, const struct nf_hook_ops *ops)
867 {
868         nf_nat_unregister_fn(net, ops->pf, ops, ARRAY_SIZE(nf_nat_ipv4_ops));
869 }
870 EXPORT_SYMBOL_GPL(nf_nat_ipv4_unregister_fn);
871
872 #if IS_ENABLED(CONFIG_IPV6)
873 int nf_nat_icmpv6_reply_translation(struct sk_buff *skb,
874                                     struct nf_conn *ct,
875                                     enum ip_conntrack_info ctinfo,
876                                     unsigned int hooknum,
877                                     unsigned int hdrlen)
878 {
879         struct {
880                 struct icmp6hdr icmp6;
881                 struct ipv6hdr  ip6;
882         } *inside;
883         enum ip_conntrack_dir dir = CTINFO2DIR(ctinfo);
884         enum nf_nat_manip_type manip = HOOK2MANIP(hooknum);
885         struct nf_conntrack_tuple target;
886         unsigned long statusbit;
887
888         WARN_ON(ctinfo != IP_CT_RELATED && ctinfo != IP_CT_RELATED_REPLY);
889
890         if (skb_ensure_writable(skb, hdrlen + sizeof(*inside)))
891                 return 0;
892         if (nf_ip6_checksum(skb, hooknum, hdrlen, IPPROTO_ICMPV6))
893                 return 0;
894
895         inside = (void *)skb->data + hdrlen;
896         if (inside->icmp6.icmp6_type == NDISC_REDIRECT) {
897                 if ((ct->status & IPS_NAT_DONE_MASK) != IPS_NAT_DONE_MASK)
898                         return 0;
899                 if (ct->status & IPS_NAT_MASK)
900                         return 0;
901         }
902
903         if (manip == NF_NAT_MANIP_SRC)
904                 statusbit = IPS_SRC_NAT;
905         else
906                 statusbit = IPS_DST_NAT;
907
908         /* Invert if this is reply direction */
909         if (dir == IP_CT_DIR_REPLY)
910                 statusbit ^= IPS_NAT_MASK;
911
912         if (!(ct->status & statusbit))
913                 return 1;
914
915         if (!nf_nat_ipv6_manip_pkt(skb, hdrlen + sizeof(inside->icmp6),
916                                    &ct->tuplehash[!dir].tuple, !manip))
917                 return 0;
918
919         if (skb->ip_summed != CHECKSUM_PARTIAL) {
920                 struct ipv6hdr *ipv6h = ipv6_hdr(skb);
921
922                 inside = (void *)skb->data + hdrlen;
923                 inside->icmp6.icmp6_cksum = 0;
924                 inside->icmp6.icmp6_cksum =
925                         csum_ipv6_magic(&ipv6h->saddr, &ipv6h->daddr,
926                                         skb->len - hdrlen, IPPROTO_ICMPV6,
927                                         skb_checksum(skb, hdrlen,
928                                                      skb->len - hdrlen, 0));
929         }
930
931         nf_ct_invert_tuple(&target, &ct->tuplehash[!dir].tuple);
932         target.dst.protonum = IPPROTO_ICMPV6;
933         if (!nf_nat_ipv6_manip_pkt(skb, 0, &target, manip))
934                 return 0;
935
936         return 1;
937 }
938 EXPORT_SYMBOL_GPL(nf_nat_icmpv6_reply_translation);
939
940 static unsigned int
941 nf_nat_ipv6_fn(void *priv, struct sk_buff *skb,
942                const struct nf_hook_state *state)
943 {
944         struct nf_conn *ct;
945         enum ip_conntrack_info ctinfo;
946         __be16 frag_off;
947         int hdrlen;
948         u8 nexthdr;
949
950         ct = nf_ct_get(skb, &ctinfo);
951         /* Can't track?  It's not due to stress, or conntrack would
952          * have dropped it.  Hence it's the user's responsibilty to
953          * packet filter it out, or implement conntrack/NAT for that
954          * protocol. 8) --RR
955          */
956         if (!ct)
957                 return NF_ACCEPT;
958
959         if (ctinfo == IP_CT_RELATED || ctinfo == IP_CT_RELATED_REPLY) {
960                 nexthdr = ipv6_hdr(skb)->nexthdr;
961                 hdrlen = ipv6_skip_exthdr(skb, sizeof(struct ipv6hdr),
962                                           &nexthdr, &frag_off);
963
964                 if (hdrlen >= 0 && nexthdr == IPPROTO_ICMPV6) {
965                         if (!nf_nat_icmpv6_reply_translation(skb, ct, ctinfo,
966                                                              state->hook,
967                                                              hdrlen))
968                                 return NF_DROP;
969                         else
970                                 return NF_ACCEPT;
971                 }
972         }
973
974         return nf_nat_inet_fn(priv, skb, state);
975 }
976
977 static unsigned int
978 nf_nat_ipv6_local_in(void *priv, struct sk_buff *skb,
979                      const struct nf_hook_state *state)
980 {
981         struct in6_addr saddr = ipv6_hdr(skb)->saddr;
982         struct sock *sk = skb->sk;
983         unsigned int ret;
984
985         ret = nf_nat_ipv6_fn(priv, skb, state);
986
987         if (ret != NF_ACCEPT || !sk || inet_sk_transparent(sk))
988                 return ret;
989
990         /* see nf_nat_ipv4_local_in */
991         if (ipv6_addr_cmp(&saddr, &ipv6_hdr(skb)->saddr) ||
992             nf_nat_inet_port_was_mangled(skb, sk->sk_dport))
993                 skb_orphan(skb);
994
995         return ret;
996 }
997
998 static unsigned int
999 nf_nat_ipv6_in(void *priv, struct sk_buff *skb,
1000                const struct nf_hook_state *state)
1001 {
1002         unsigned int ret, verdict;
1003         struct in6_addr daddr = ipv6_hdr(skb)->daddr;
1004
1005         ret = nf_nat_ipv6_fn(priv, skb, state);
1006         verdict = ret & NF_VERDICT_MASK;
1007         if (verdict != NF_DROP && verdict != NF_STOLEN &&
1008             ipv6_addr_cmp(&daddr, &ipv6_hdr(skb)->daddr))
1009                 skb_dst_drop(skb);
1010
1011         return ret;
1012 }
1013
/*
 * POST_ROUTING hook.  Runs NAT; with CONFIG_XFRM, an accepted packet that
 * has not been transformed yet gets a fresh xfrm lookup when NAT changed
 * its source address, or (for non-ICMPv6) its source port/id.
 */
static unsigned int
nf_nat_ipv6_out(void *priv, struct sk_buff *skb,
		const struct nf_hook_state *state)
{
	unsigned int ret = nf_nat_ipv6_fn(priv, skb, state);
#ifdef CONFIG_XFRM
	enum ip_conntrack_info ctinfo;
	const struct nf_conn *ct;
	enum ip_conntrack_dir dir;
	int err;

	if (ret != NF_ACCEPT || (IP6CB(skb)->flags & IP6SKB_XFRM_TRANSFORMED))
		return ret;

	ct = nf_ct_get(skb, &ctinfo);
	if (!ct)
		return ret;

	dir = CTINFO2DIR(ctinfo);

	/* Source address and (where relevant) source port untouched:
	 * the existing xfrm decision still holds.
	 */
	if (nf_inet_addr_cmp(&ct->tuplehash[dir].tuple.src.u3,
			     &ct->tuplehash[!dir].tuple.dst.u3) &&
	    (ct->tuplehash[dir].tuple.dst.protonum == IPPROTO_ICMPV6 ||
	     ct->tuplehash[dir].tuple.src.u.all ==
	     ct->tuplehash[!dir].tuple.dst.u.all))
		return ret;

	err = nf_xfrm_me_harder(state->net, skb, AF_INET6);
	if (err < 0)
		ret = NF_DROP_ERR(err);
#endif
	return ret;
}
1050
1051 static unsigned int
1052 nf_nat_ipv6_local_fn(void *priv, struct sk_buff *skb,
1053                      const struct nf_hook_state *state)
1054 {
1055         const struct nf_conn *ct;
1056         enum ip_conntrack_info ctinfo;
1057         unsigned int ret;
1058         int err;
1059
1060         ret = nf_nat_ipv6_fn(priv, skb, state);
1061         if (ret != NF_ACCEPT)
1062                 return ret;
1063
1064         ct = nf_ct_get(skb, &ctinfo);
1065         if (ct) {
1066                 enum ip_conntrack_dir dir = CTINFO2DIR(ctinfo);
1067
1068                 if (!nf_inet_addr_cmp(&ct->tuplehash[dir].tuple.dst.u3,
1069                                       &ct->tuplehash[!dir].tuple.src.u3)) {
1070                         err = nf_ip6_route_me_harder(state->net, state->sk, skb);
1071                         if (err < 0)
1072                                 ret = NF_DROP_ERR(err);
1073                 }
1074 #ifdef CONFIG_XFRM
1075                 else if (!(IP6CB(skb)->flags & IP6SKB_XFRM_TRANSFORMED) &&
1076                          ct->tuplehash[dir].tuple.dst.protonum != IPPROTO_ICMPV6 &&
1077                          ct->tuplehash[dir].tuple.dst.u.all !=
1078                          ct->tuplehash[!dir].tuple.src.u.all) {
1079                         err = nf_xfrm_me_harder(state->net, skb, AF_INET6);
1080                         if (err < 0)
1081                                 ret = NF_DROP_ERR(err);
1082                 }
1083 #endif
1084         }
1085
1086         return ret;
1087 }
1088
1089 static const struct nf_hook_ops nf_nat_ipv6_ops[] = {
1090         /* Before packet filtering, change destination */
1091         {
1092                 .hook           = nf_nat_ipv6_in,
1093                 .pf             = NFPROTO_IPV6,
1094                 .hooknum        = NF_INET_PRE_ROUTING,
1095                 .priority       = NF_IP6_PRI_NAT_DST,
1096         },
1097         /* After packet filtering, change source */
1098         {
1099                 .hook           = nf_nat_ipv6_out,
1100                 .pf             = NFPROTO_IPV6,
1101                 .hooknum        = NF_INET_POST_ROUTING,
1102                 .priority       = NF_IP6_PRI_NAT_SRC,
1103         },
1104         /* Before packet filtering, change destination */
1105         {
1106                 .hook           = nf_nat_ipv6_local_fn,
1107                 .pf             = NFPROTO_IPV6,
1108                 .hooknum        = NF_INET_LOCAL_OUT,
1109                 .priority       = NF_IP6_PRI_NAT_DST,
1110         },
1111         /* After packet filtering, change source */
1112         {
1113                 .hook           = nf_nat_ipv6_local_in,
1114                 .pf             = NFPROTO_IPV6,
1115                 .hooknum        = NF_INET_LOCAL_IN,
1116                 .priority       = NF_IP6_PRI_NAT_SRC,
1117         },
1118 };
1119
1120 int nf_nat_ipv6_register_fn(struct net *net, const struct nf_hook_ops *ops)
1121 {
1122         return nf_nat_register_fn(net, ops->pf, ops, nf_nat_ipv6_ops,
1123                                   ARRAY_SIZE(nf_nat_ipv6_ops));
1124 }
1125 EXPORT_SYMBOL_GPL(nf_nat_ipv6_register_fn);
1126
1127 void nf_nat_ipv6_unregister_fn(struct net *net, const struct nf_hook_ops *ops)
1128 {
1129         nf_nat_unregister_fn(net, ops->pf, ops, ARRAY_SIZE(nf_nat_ipv6_ops));
1130 }
1131 EXPORT_SYMBOL_GPL(nf_nat_ipv6_unregister_fn);
1132 #endif /* CONFIG_IPV6 */
1133
1134 #if defined(CONFIG_NF_TABLES_INET) && IS_ENABLED(CONFIG_NFT_NAT)
1135 int nf_nat_inet_register_fn(struct net *net, const struct nf_hook_ops *ops)
1136 {
1137         int ret;
1138
1139         if (WARN_ON_ONCE(ops->pf != NFPROTO_INET))
1140                 return -EINVAL;
1141
1142         ret = nf_nat_register_fn(net, NFPROTO_IPV6, ops, nf_nat_ipv6_ops,
1143                                  ARRAY_SIZE(nf_nat_ipv6_ops));
1144         if (ret)
1145                 return ret;
1146
1147         ret = nf_nat_register_fn(net, NFPROTO_IPV4, ops, nf_nat_ipv4_ops,
1148                                  ARRAY_SIZE(nf_nat_ipv4_ops));
1149         if (ret)
1150                 nf_nat_unregister_fn(net, NFPROTO_IPV6, ops,
1151                                         ARRAY_SIZE(nf_nat_ipv6_ops));
1152         return ret;
1153 }
1154 EXPORT_SYMBOL_GPL(nf_nat_inet_register_fn);
1155
1156 void nf_nat_inet_unregister_fn(struct net *net, const struct nf_hook_ops *ops)
1157 {
1158         nf_nat_unregister_fn(net, NFPROTO_IPV4, ops, ARRAY_SIZE(nf_nat_ipv4_ops));
1159         nf_nat_unregister_fn(net, NFPROTO_IPV6, ops, ARRAY_SIZE(nf_nat_ipv6_ops));
1160 }
1161 EXPORT_SYMBOL_GPL(nf_nat_inet_unregister_fn);
1162 #endif /* NFT INET NAT */