GNU Linux-libre 5.10.217-gnu1: drivers/infiniband/sw/rxe/rxe_recv.c
// SPDX-License-Identifier: GPL-2.0 OR Linux-OpenIB
/*
 * Copyright (c) 2016 Mellanox Technologies Ltd. All rights reserved.
 * Copyright (c) 2015 System Fabric Works, Inc. All rights reserved.
 */

#include <linux/skbuff.h>

#include "rxe.h"
#include "rxe_loc.h"

/* check that QP matches packet opcode type and is in a valid state */
static int check_type_state(struct rxe_dev *rxe, struct rxe_pkt_info *pkt,
                            struct rxe_qp *qp)
{
        unsigned int pkt_type;

        if (unlikely(!qp->valid))
                goto err1;

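        /* the top three bits of the BTH opcode encode the transport
         * type (RC, UC, RD, UD); mask off the operation bits so the
         * result can be compared against the IB_OPCODE_* transport
         * prefixes below
         */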
        pkt_type = pkt->opcode & 0xe0;

        switch (qp_type(qp)) {
        case IB_QPT_RC:
                if (unlikely(pkt_type != IB_OPCODE_RC)) {
                        pr_warn_ratelimited("bad qp type\n");
                        goto err1;
                }
                break;
        case IB_QPT_UC:
                if (unlikely(pkt_type != IB_OPCODE_UC)) {
                        pr_warn_ratelimited("bad qp type\n");
                        goto err1;
                }
                break;
        case IB_QPT_UD:
        case IB_QPT_SMI:
        case IB_QPT_GSI:
                if (unlikely(pkt_type != IB_OPCODE_UD)) {
                        pr_warn_ratelimited("bad qp type\n");
                        goto err1;
                }
                break;
        default:
                pr_warn_ratelimited("unsupported qp type\n");
                goto err1;
        }

        if (pkt->mask & RXE_REQ_MASK) {
                if (unlikely(qp->resp.state != QP_STATE_READY))
                        goto err1;
        } else if (unlikely(qp->req.state < QP_STATE_READY ||
                                qp->req.state > QP_STATE_DRAINED)) {
                goto err1;
        }

        return 0;

err1:
        return -EINVAL;
}

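/* increment the port's bad_pkey_cntr attribute, saturating at the
 * 16-bit maximum the counter can hold
 */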
static void set_bad_pkey_cntr(struct rxe_port *port)
{
        spin_lock_bh(&port->port_lock);
        port->attr.bad_pkey_cntr = min((u32)0xffff,
                                       port->attr.bad_pkey_cntr + 1);
        spin_unlock_bh(&port->port_lock);
}

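/* increment the port's qkey_viol_cntr attribute, saturating at the
 * 16-bit maximum the counter can hold
 */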
static void set_qkey_viol_cntr(struct rxe_port *port)
{
        spin_lock_bh(&port->port_lock);
        port->attr.qkey_viol_cntr = min((u32)0xffff,
                                        port->attr.qkey_viol_cntr + 1);
        spin_unlock_bh(&port->port_lock);
}

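/* check that the packet carries the default full-membership pkey and,
 * for UD and GSI QPs, that its qkey matches the QP's (GSI_QKEY for
 * QP1); bump the matching violation counter on failure
 */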
static int check_keys(struct rxe_dev *rxe, struct rxe_pkt_info *pkt,
                      u32 qpn, struct rxe_qp *qp)
{
        struct rxe_port *port = &rxe->port;
        u16 pkey = bth_pkey(pkt);

        pkt->pkey_index = 0;

        if (!pkey_match(pkey, IB_DEFAULT_PKEY_FULL)) {
                pr_warn_ratelimited("bad pkey = 0x%x\n", pkey);
                set_bad_pkey_cntr(port);
                goto err1;
        }

        if ((qp_type(qp) == IB_QPT_UD || qp_type(qp) == IB_QPT_GSI) &&
            pkt->mask) {
                u32 qkey = (qpn == 1) ? GSI_QKEY : qp->attr.qkey;

                if (unlikely(deth_qkey(pkt) != qkey)) {
                        pr_warn_ratelimited("bad qkey, got 0x%x expected 0x%x for qpn 0x%x\n",
                                            deth_qkey(pkt), qkey, qpn);
                        set_qkey_viol_cntr(port);
                        goto err1;
                }
        }

        return 0;

err1:
        return -EINVAL;
}

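/* for connected (RC/UC) QPs, verify that the packet arrived on the
 * QP's port and that its IP addresses mirror the QP's primary address
 * vector: the packet's destination is our source GID and vice versa
 */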
static int check_addr(struct rxe_dev *rxe, struct rxe_pkt_info *pkt,
                      struct rxe_qp *qp)
{
        struct sk_buff *skb = PKT_TO_SKB(pkt);

        if (qp_type(qp) != IB_QPT_RC && qp_type(qp) != IB_QPT_UC)
                goto done;

        if (unlikely(pkt->port_num != qp->attr.port_num)) {
                pr_warn_ratelimited("port %d != qp port %d\n",
                                    pkt->port_num, qp->attr.port_num);
                goto err1;
        }

        if (skb->protocol == htons(ETH_P_IP)) {
                struct in_addr *saddr =
                        &qp->pri_av.sgid_addr._sockaddr_in.sin_addr;
                struct in_addr *daddr =
                        &qp->pri_av.dgid_addr._sockaddr_in.sin_addr;

                if (ip_hdr(skb)->daddr != saddr->s_addr) {
                        pr_warn_ratelimited("dst addr %pI4 != qp source addr %pI4\n",
                                            &ip_hdr(skb)->daddr,
                                            &saddr->s_addr);
                        goto err1;
                }

                if (ip_hdr(skb)->saddr != daddr->s_addr) {
                        pr_warn_ratelimited("source addr %pI4 != qp dst addr %pI4\n",
                                            &ip_hdr(skb)->saddr,
                                            &daddr->s_addr);
                        goto err1;
                }

        } else if (skb->protocol == htons(ETH_P_IPV6)) {
                struct in6_addr *saddr =
                        &qp->pri_av.sgid_addr._sockaddr_in6.sin6_addr;
                struct in6_addr *daddr =
                        &qp->pri_av.dgid_addr._sockaddr_in6.sin6_addr;

                if (memcmp(&ipv6_hdr(skb)->daddr, saddr, sizeof(*saddr))) {
                        pr_warn_ratelimited("dst addr %pI6 != qp source addr %pI6\n",
                                            &ipv6_hdr(skb)->daddr, saddr);
                        goto err1;
                }

                if (memcmp(&ipv6_hdr(skb)->saddr, daddr, sizeof(*daddr))) {
                        pr_warn_ratelimited("source addr %pI6 != qp dst addr %pI6\n",
                                            &ipv6_hdr(skb)->saddr, daddr);
                        goto err1;
                }
        }

done:
        return 0;

err1:
        return -EINVAL;
}

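/* validate the packet headers and resolve the destination QP: for a
 * unicast QPN, look up the QP (taking a reference that is kept in
 * pkt->qp on success) and run the type/state, address and key checks;
 * for the multicast QPN only require that a GRH is present
 */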
static int hdr_check(struct rxe_pkt_info *pkt)
{
        struct rxe_dev *rxe = pkt->rxe;
        struct rxe_port *port = &rxe->port;
        struct rxe_qp *qp = NULL;
        u32 qpn = bth_qpn(pkt);
        int index;
        int err;

        if (unlikely(bth_tver(pkt) != BTH_TVER)) {
                pr_warn_ratelimited("bad tver\n");
                goto err1;
        }

        if (unlikely(qpn == 0)) {
                pr_warn_once("QP 0 not supported\n");
                goto err1;
        }

        if (qpn != IB_MULTICAST_QPN) {
                index = (qpn == 1) ? port->qp_gsi_index : qpn;

                qp = rxe_pool_get_index(&rxe->qp_pool, index);
                if (unlikely(!qp)) {
                        pr_warn_ratelimited("no qp matches qpn 0x%x\n", qpn);
                        goto err1;
                }

                err = check_type_state(rxe, pkt, qp);
                if (unlikely(err))
                        goto err2;

                err = check_addr(rxe, pkt, qp);
                if (unlikely(err))
                        goto err2;

                err = check_keys(rxe, pkt, qpn, qp);
                if (unlikely(err))
                        goto err2;
        } else {
                if (unlikely((pkt->mask & RXE_GRH_MASK) == 0)) {
                        pr_warn_ratelimited("no grh for mcast qpn\n");
                        goto err1;
                }
        }

        pkt->qp = qp;
        return 0;

err2:
        rxe_drop_ref(qp);
err1:
        return -EINVAL;
}

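/* hand the packet to the matched QP: requests go to the responder,
 * responses and acks go to the completer
 */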
static inline void rxe_rcv_pkt(struct rxe_pkt_info *pkt, struct sk_buff *skb)
{
        if (pkt->mask & RXE_REQ_MASK)
                rxe_resp_queue_pkt(pkt->qp, skb);
        else
                rxe_comp_queue_pkt(pkt->qp, skb);
}

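/* deliver a copy of a multicast packet to every QP attached to the
 * multicast group addressed by the packet's destination GID; the
 * original skb is consumed by the last QP on the list
 */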
static void rxe_rcv_mcast_pkt(struct rxe_dev *rxe, struct sk_buff *skb)
{
        struct rxe_pkt_info *pkt = SKB_TO_PKT(skb);
        struct rxe_mc_grp *mcg;
        struct rxe_mc_elem *mce;
        struct rxe_qp *qp;
        union ib_gid dgid;
        struct sk_buff *per_qp_skb;
        struct rxe_pkt_info *per_qp_pkt;
        int err;

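        /* extract the destination GID; IPv4 destinations are
         * represented as IPv4-mapped IPv6 addresses
         */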
        if (skb->protocol == htons(ETH_P_IP))
                ipv6_addr_set_v4mapped(ip_hdr(skb)->daddr,
                                       (struct in6_addr *)&dgid);
        else if (skb->protocol == htons(ETH_P_IPV6))
                memcpy(&dgid, &ipv6_hdr(skb)->daddr, sizeof(dgid));

        /* lookup mcast group corresponding to mgid, takes a ref */
        mcg = rxe_pool_get_key(&rxe->mc_grp_pool, &dgid);
        if (!mcg)
                goto err1;      /* mcast group not registered */

        spin_lock_bh(&mcg->mcg_lock);

        list_for_each_entry(mce, &mcg->qp_list, qp_list) {
                qp = mce->qp;

                /* validate qp for incoming packet */
                err = check_type_state(rxe, pkt, qp);
                if (err)
                        continue;

                err = check_keys(rxe, pkt, bth_qpn(pkt), qp);
                if (err)
                        continue;

                /* For all but the last QP, pass a fresh clone of the
                 * skb to the QP. The last QP consumes the original
                 * skb; if its checks fail we must free the skb here,
                 * since it was never handed to rxe_rcv_pkt(), which
                 * would otherwise free it.
                 */
                if (mce->qp_list.next != &mcg->qp_list) {
                        per_qp_skb = skb_clone(skb, GFP_ATOMIC);
                } else {
                        per_qp_skb = skb;
                        /* show we have consumed the skb */
                        skb = NULL;
                }

                if (unlikely(!per_qp_skb))
                        continue;

                per_qp_pkt = SKB_TO_PKT(per_qp_skb);
                per_qp_pkt->qp = qp;
                rxe_add_ref(qp);
                rxe_rcv_pkt(per_qp_pkt, per_qp_skb);
        }

        spin_unlock_bh(&mcg->mcg_lock);

        rxe_drop_ref(mcg);      /* drop ref from rxe_pool_get_key. */

err1:
        /* free skb if not consumed */
        kfree_skb(skb);
}

/**
 * rxe_chk_dgid - validate the destination IP address
 * @rxe: rxe device that received the packet
 * @skb: the received packet buffer
 *
 * Accept any loopback packet. Otherwise extract the destination IP
 * address from the packet and accept it if it is a multicast address
 * or if it matches an entry in the device's SGID table.
 *
 * Returns 0 if the address is acceptable, a negative errno otherwise.
 */
static int rxe_chk_dgid(struct rxe_dev *rxe, struct sk_buff *skb)
{
        struct rxe_pkt_info *pkt = SKB_TO_PKT(skb);
        const struct ib_gid_attr *gid_attr;
        union ib_gid dgid;
        union ib_gid *pdgid;

        if (pkt->mask & RXE_LOOPBACK_MASK)
                return 0;

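        /* GIDs for IPv4 destinations are IPv4-mapped IPv6 addresses,
         * so convert before the GID table lookup
         */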
        if (skb->protocol == htons(ETH_P_IP)) {
                ipv6_addr_set_v4mapped(ip_hdr(skb)->daddr,
                                       (struct in6_addr *)&dgid);
                pdgid = &dgid;
        } else {
                pdgid = (union ib_gid *)&ipv6_hdr(skb)->daddr;
        }

        if (rdma_is_multicast_addr((struct in6_addr *)pdgid))
                return 0;

        gid_attr = rdma_find_gid_by_port(&rxe->ib_dev, pdgid,
                                         IB_GID_TYPE_ROCE_UDP_ENCAP,
                                         1, skb->dev);
        if (IS_ERR(gid_attr))
                return PTR_ERR(gid_attr);

        rdma_put_gid_attr(gid_attr);
        return 0;
}

/* rxe_rcv is called from the interface driver */
void rxe_rcv(struct sk_buff *skb)
{
        int err;
        struct rxe_pkt_info *pkt = SKB_TO_PKT(skb);
        struct rxe_dev *rxe = pkt->rxe;
        __be32 *icrcp;
        u32 calc_icrc, pack_icrc;

        pkt->offset = 0;

        if (unlikely(skb->len < pkt->offset + RXE_BTH_BYTES))
                goto drop;

        if (rxe_chk_dgid(rxe, skb) < 0) {
                pr_warn_ratelimited("failed checking dgid\n");
                goto drop;
        }

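        /* decode the BTH; the per-opcode mask bits folded into
         * pkt->mask tell later stages which headers are present
         */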
        pkt->opcode = bth_opcode(pkt);
        pkt->psn = bth_psn(pkt);
        pkt->qp = NULL;
        pkt->mask |= rxe_opcode[pkt->opcode].mask;

        if (unlikely(skb->len < header_size(pkt)))
                goto drop;

        err = hdr_check(pkt);
        if (unlikely(err))
                goto drop;

        /* Verify ICRC */
        icrcp = (__be32 *)(pkt->hdr + pkt->paylen - RXE_ICRC_SIZE);
        pack_icrc = be32_to_cpu(*icrcp);

        calc_icrc = rxe_icrc_hdr(pkt, skb);
        calc_icrc = rxe_crc32(rxe, calc_icrc, (u8 *)payload_addr(pkt),
                              payload_size(pkt) + bth_pad(pkt));
        calc_icrc = (__force u32)cpu_to_be32(~calc_icrc);
        if (unlikely(calc_icrc != pack_icrc)) {
                if (skb->protocol == htons(ETH_P_IPV6))
                        pr_warn_ratelimited("bad ICRC from %pI6c\n",
                                            &ipv6_hdr(skb)->saddr);
                else if (skb->protocol == htons(ETH_P_IP))
                        pr_warn_ratelimited("bad ICRC from %pI4\n",
                                            &ip_hdr(skb)->saddr);
                else
                        pr_warn_ratelimited("bad ICRC from unknown\n");

                goto drop;
        }

        rxe_counter_inc(rxe, RXE_CNT_RCVD_PKTS);

        if (unlikely(bth_qpn(pkt) == IB_MULTICAST_QPN))
                rxe_rcv_mcast_pkt(rxe, skb);
        else
                rxe_rcv_pkt(pkt, skb);

        return;

drop:
        if (pkt->qp)
                rxe_drop_ref(pkt->qp);

        kfree_skb(skb);
}