GNU Linux-libre 5.10.153-gnu1: drivers/infiniband/ulp/rtrs/rtrs-srv.c
1 // SPDX-License-Identifier: GPL-2.0-or-later
2 /*
3  * RDMA Transport Layer
4  *
5  * Copyright (c) 2014 - 2018 ProfitBricks GmbH. All rights reserved.
6  * Copyright (c) 2018 - 2019 1&1 IONOS Cloud GmbH. All rights reserved.
7  * Copyright (c) 2019 - 2020 1&1 IONOS SE. All rights reserved.
8  */
9
10 #undef pr_fmt
11 #define pr_fmt(fmt) KBUILD_MODNAME " L" __stringify(__LINE__) ": " fmt
12
13 #include <linux/module.h>
14 #include <linux/mempool.h>
15
16 #include "rtrs-srv.h"
17 #include "rtrs-log.h"
18 #include <rdma/ib_cm.h>
19 #include <rdma/ib_verbs.h>
20
21 MODULE_DESCRIPTION("RDMA Transport Server");
22 MODULE_LICENSE("GPL");
23
24 /* Must be power of 2, see mask from mr->page_size in ib_sg_to_pages() */
25 #define DEFAULT_MAX_CHUNK_SIZE (128 << 10)
26 #define DEFAULT_SESS_QUEUE_DEPTH 512
27 #define MAX_HDR_SIZE PAGE_SIZE
28
29 /* We guarantee to serve at least 10 paths */
30 #define CHUNK_POOL_SZ 10
31
32 static struct rtrs_rdma_dev_pd dev_pd;
33 static mempool_t *chunk_pool;
34 struct class *rtrs_dev_class;
35 static struct rtrs_srv_ib_ctx ib_ctx;
36
37 static int __read_mostly max_chunk_size = DEFAULT_MAX_CHUNK_SIZE;
38 static int __read_mostly sess_queue_depth = DEFAULT_SESS_QUEUE_DEPTH;
39
40 static bool always_invalidate = true;
41 module_param(always_invalidate, bool, 0444);
42 MODULE_PARM_DESC(always_invalidate,
43                  "Invalidate memory registration for contiguous memory regions before accessing.");
44
45 module_param_named(max_chunk_size, max_chunk_size, int, 0444);
46 MODULE_PARM_DESC(max_chunk_size,
47                  "Max size for each IO request, when change the unit is in byte (default: "
48                  __stringify(DEFAULT_MAX_CHUNK_SIZE) "KB)");
49
50 module_param_named(sess_queue_depth, sess_queue_depth, int, 0444);
51 MODULE_PARM_DESC(sess_queue_depth,
52                  "Number of buffers for pending I/O requests to allocate per session. Maximum: "
53                  __stringify(MAX_SESS_QUEUE_DEPTH) " (default: "
54                  __stringify(DEFAULT_SESS_QUEUE_DEPTH) ")");
55
56 static cpumask_t cq_affinity_mask = { CPU_BITS_ALL };
57
58 static struct workqueue_struct *rtrs_wq;
59
60 static inline struct rtrs_srv_con *to_srv_con(struct rtrs_con *c)
61 {
62         return container_of(c, struct rtrs_srv_con, c);
63 }
64
65 static inline struct rtrs_srv_sess *to_srv_sess(struct rtrs_sess *s)
66 {
67         return container_of(s, struct rtrs_srv_sess, s);
68 }
69
70 static bool __rtrs_srv_change_state(struct rtrs_srv_sess *sess,
71                                      enum rtrs_srv_state new_state)
72 {
73         enum rtrs_srv_state old_state;
74         bool changed = false;
75
76         lockdep_assert_held(&sess->state_lock);
77         old_state = sess->state;
78         switch (new_state) {
79         case RTRS_SRV_CONNECTED:
80                 switch (old_state) {
81                 case RTRS_SRV_CONNECTING:
82                         changed = true;
83                         fallthrough;
84                 default:
85                         break;
86                 }
87                 break;
88         case RTRS_SRV_CLOSING:
89                 switch (old_state) {
90                 case RTRS_SRV_CONNECTING:
91                 case RTRS_SRV_CONNECTED:
92                         changed = true;
93                         fallthrough;
94                 default:
95                         break;
96                 }
97                 break;
98         case RTRS_SRV_CLOSED:
99                 switch (old_state) {
100                 case RTRS_SRV_CLOSING:
101                         changed = true;
102                         fallthrough;
103                 default:
104                         break;
105                 }
106                 break;
107         default:
108                 break;
109         }
110         if (changed)
111                 sess->state = new_state;
112
113         return changed;
114 }
115
116 static bool rtrs_srv_change_state_get_old(struct rtrs_srv_sess *sess,
117                                            enum rtrs_srv_state new_state,
118                                            enum rtrs_srv_state *old_state)
119 {
120         bool changed;
121
122         spin_lock_irq(&sess->state_lock);
123         *old_state = sess->state;
124         changed = __rtrs_srv_change_state(sess, new_state);
125         spin_unlock_irq(&sess->state_lock);
126
127         return changed;
128 }
129
130 static bool rtrs_srv_change_state(struct rtrs_srv_sess *sess,
131                                    enum rtrs_srv_state new_state)
132 {
133         enum rtrs_srv_state old_state;
134
135         return rtrs_srv_change_state_get_old(sess, new_state, &old_state);
136 }
137
138 static void free_id(struct rtrs_srv_op *id)
139 {
140         if (!id)
141                 return;
142         kfree(id);
143 }
144
145 static void rtrs_srv_free_ops_ids(struct rtrs_srv_sess *sess)
146 {
147         struct rtrs_srv *srv = sess->srv;
148         int i;
149
150         WARN_ON(atomic_read(&sess->ids_inflight));
151         if (sess->ops_ids) {
152                 for (i = 0; i < srv->queue_depth; i++)
153                         free_id(sess->ops_ids[i]);
154                 kfree(sess->ops_ids);
155                 sess->ops_ids = NULL;
156         }
157 }
158
159 static void rtrs_srv_rdma_done(struct ib_cq *cq, struct ib_wc *wc);
160
161 static struct ib_cqe io_comp_cqe = {
162         .done = rtrs_srv_rdma_done
163 };
164
165 static int rtrs_srv_alloc_ops_ids(struct rtrs_srv_sess *sess)
166 {
167         struct rtrs_srv *srv = sess->srv;
168         struct rtrs_srv_op *id;
169         int i;
170
171         sess->ops_ids = kcalloc(srv->queue_depth, sizeof(*sess->ops_ids),
172                                 GFP_KERNEL);
173         if (!sess->ops_ids)
174                 goto err;
175
176         for (i = 0; i < srv->queue_depth; ++i) {
177                 id = kzalloc(sizeof(*id), GFP_KERNEL);
178                 if (!id)
179                         goto err;
180
181                 sess->ops_ids[i] = id;
182         }
183         init_waitqueue_head(&sess->ids_waitq);
184         atomic_set(&sess->ids_inflight, 0);
185
186         return 0;
187
188 err:
189         rtrs_srv_free_ops_ids(sess);
190         return -ENOMEM;
191 }
192
193 static inline void rtrs_srv_get_ops_ids(struct rtrs_srv_sess *sess)
194 {
195         atomic_inc(&sess->ids_inflight);
196 }
197
198 static inline void rtrs_srv_put_ops_ids(struct rtrs_srv_sess *sess)
199 {
200         if (atomic_dec_and_test(&sess->ids_inflight))
201                 wake_up(&sess->ids_waitq);
202 }
203
204 static void rtrs_srv_wait_ops_ids(struct rtrs_srv_sess *sess)
205 {
206         wait_event(sess->ids_waitq, !atomic_read(&sess->ids_inflight));
207 }
208
209
210 static void rtrs_srv_reg_mr_done(struct ib_cq *cq, struct ib_wc *wc)
211 {
212         struct rtrs_srv_con *con = cq->cq_context;
213         struct rtrs_sess *s = con->c.sess;
214         struct rtrs_srv_sess *sess = to_srv_sess(s);
215
216         if (unlikely(wc->status != IB_WC_SUCCESS)) {
217                 rtrs_err(s, "REG MR failed: %s\n",
218                           ib_wc_status_msg(wc->status));
219                 close_sess(sess);
220                 return;
221         }
222 }
223
224 static struct ib_cqe local_reg_cqe = {
225         .done = rtrs_srv_reg_mr_done
226 };
227
228 static int rdma_write_sg(struct rtrs_srv_op *id)
229 {
230         struct rtrs_sess *s = id->con->c.sess;
231         struct rtrs_srv_sess *sess = to_srv_sess(s);
232         dma_addr_t dma_addr = sess->dma_addr[id->msg_id];
233         struct rtrs_srv_mr *srv_mr;
234         struct rtrs_srv *srv = sess->srv;
235         struct ib_send_wr inv_wr;
236         struct ib_rdma_wr imm_wr;
237         struct ib_rdma_wr *wr = NULL;
238         enum ib_send_flags flags;
239         size_t sg_cnt;
240         int err, offset;
241         bool need_inval;
242         u32 rkey = 0;
243         struct ib_reg_wr rwr;
244         struct ib_sge *plist;
245         struct ib_sge list;
246
247         sg_cnt = le16_to_cpu(id->rd_msg->sg_cnt);
248         need_inval = le16_to_cpu(id->rd_msg->flags) & RTRS_MSG_NEED_INVAL_F;
249         if (unlikely(sg_cnt != 1))
250                 return -EINVAL;
251
252         offset = 0;
253
254         wr              = &id->tx_wr;
255         plist           = &id->tx_sg;
256         plist->addr     = dma_addr + offset;
257         plist->length   = le32_to_cpu(id->rd_msg->desc[0].len);
258
259         /*
260          * The WR will fail with a length error if this is 0.
261          */
262         if (unlikely(plist->length == 0)) {
263                 rtrs_err(s, "Invalid RDMA-Write sg list length 0\n");
264                 return -EINVAL;
265         }
266
267         plist->lkey = sess->s.dev->ib_pd->local_dma_lkey;
268         offset += plist->length;
269
270         wr->wr.sg_list  = plist;
271         wr->wr.num_sge  = 1;
272         wr->remote_addr = le64_to_cpu(id->rd_msg->desc[0].addr);
273         wr->rkey        = le32_to_cpu(id->rd_msg->desc[0].key);
274         if (rkey == 0)
275                 rkey = wr->rkey;
276         else
277                 /* Only one key is actually used */
278                 WARN_ON_ONCE(rkey != wr->rkey);
279
280         wr->wr.opcode = IB_WR_RDMA_WRITE;
281         wr->wr.wr_cqe   = &io_comp_cqe;
282         wr->wr.ex.imm_data = 0;
283         wr->wr.send_flags  = 0;
284
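        /*
         * Resulting WR chains (the last WR carries the immediate):
         *   need_inval && always_invalidate: WRITE -> REG_MR -> SEND_WITH_INV -> SEND_WITH_IMM
         *   always_invalidate only:          WRITE -> REG_MR -> SEND_WITH_IMM
         *   need_inval only:                 WRITE -> SEND_WITH_INV -> RDMA_WRITE_WITH_IMM
         *   neither:                         WRITE -> RDMA_WRITE_WITH_IMM
         */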
285         if (need_inval && always_invalidate) {
286                 wr->wr.next = &rwr.wr;
287                 rwr.wr.next = &inv_wr;
288                 inv_wr.next = &imm_wr.wr;
289         } else if (always_invalidate) {
290                 wr->wr.next = &rwr.wr;
291                 rwr.wr.next = &imm_wr.wr;
292         } else if (need_inval) {
293                 wr->wr.next = &inv_wr;
294                 inv_wr.next = &imm_wr.wr;
295         } else {
296                 wr->wr.next = &imm_wr.wr;
297         }
298         /*
299          * From time to time we have to post signaled sends,
300          * or the send queue will fill up and only a QP reset can help.
301          */
302         flags = (atomic_inc_return(&id->con->wr_cnt) % srv->queue_depth) ?
303                 0 : IB_SEND_SIGNALED;
304
305         if (need_inval) {
306                 inv_wr.sg_list = NULL;
307                 inv_wr.num_sge = 0;
308                 inv_wr.opcode = IB_WR_SEND_WITH_INV;
309                 inv_wr.wr_cqe   = &io_comp_cqe;
310                 inv_wr.send_flags = 0;
311                 inv_wr.ex.invalidate_rkey = rkey;
312         }
313
314         imm_wr.wr.next = NULL;
315         if (always_invalidate) {
316                 struct rtrs_msg_rkey_rsp *msg;
317
318                 srv_mr = &sess->mrs[id->msg_id];
319                 rwr.wr.opcode = IB_WR_REG_MR;
320                 rwr.wr.wr_cqe = &local_reg_cqe;
321                 rwr.wr.num_sge = 0;
322                 rwr.mr = srv_mr->mr;
323                 rwr.wr.send_flags = 0;
324                 rwr.key = srv_mr->mr->rkey;
325                 rwr.access = (IB_ACCESS_LOCAL_WRITE |
326                               IB_ACCESS_REMOTE_WRITE);
327                 msg = srv_mr->iu->buf;
328                 msg->buf_id = cpu_to_le16(id->msg_id);
329                 msg->type = cpu_to_le16(RTRS_MSG_RKEY_RSP);
330                 msg->rkey = cpu_to_le32(srv_mr->mr->rkey);
331
332                 list.addr   = srv_mr->iu->dma_addr;
333                 list.length = sizeof(*msg);
334                 list.lkey   = sess->s.dev->ib_pd->local_dma_lkey;
335                 imm_wr.wr.sg_list = &list;
336                 imm_wr.wr.num_sge = 1;
337                 imm_wr.wr.opcode = IB_WR_SEND_WITH_IMM;
338                 ib_dma_sync_single_for_device(sess->s.dev->ib_dev,
339                                               srv_mr->iu->dma_addr,
340                                               srv_mr->iu->size, DMA_TO_DEVICE);
341         } else {
342                 imm_wr.wr.sg_list = NULL;
343                 imm_wr.wr.num_sge = 0;
344                 imm_wr.wr.opcode = IB_WR_RDMA_WRITE_WITH_IMM;
345         }
346         imm_wr.wr.send_flags = flags;
347         imm_wr.wr.ex.imm_data = cpu_to_be32(rtrs_to_io_rsp_imm(id->msg_id,
348                                                              0, need_inval));
349
350         imm_wr.wr.wr_cqe   = &io_comp_cqe;
351         ib_dma_sync_single_for_device(sess->s.dev->ib_dev, dma_addr,
352                                       offset, DMA_BIDIRECTIONAL);
353
354         err = ib_post_send(id->con->c.qp, &id->tx_wr.wr, NULL);
355         if (unlikely(err))
356                 rtrs_err(s,
357                           "Posting RDMA-Write-Request to QP failed, err: %d\n",
358                           err);
359
360         return err;
361 }
362
363 /**
364  * send_io_resp_imm() - respond to client with empty IMM on failed READ/WRITE
365  *                      requests or on successful WRITE request.
366  * @con:        the connection to send back result
367  * @id:         the id associated with the IO
368  * @errno:      the error number of the IO.
369  *
370  * Return 0 on success, errno otherwise.
371  */
372 static int send_io_resp_imm(struct rtrs_srv_con *con, struct rtrs_srv_op *id,
373                             int errno)
374 {
375         struct rtrs_sess *s = con->c.sess;
376         struct rtrs_srv_sess *sess = to_srv_sess(s);
377         struct ib_send_wr inv_wr, *wr = NULL;
378         struct ib_rdma_wr imm_wr;
379         struct ib_reg_wr rwr;
380         struct rtrs_srv *srv = sess->srv;
381         struct rtrs_srv_mr *srv_mr;
382         bool need_inval = false;
383         enum ib_send_flags flags;
384         u32 imm;
385         int err;
386
387         if (id->dir == READ) {
388                 struct rtrs_msg_rdma_read *rd_msg = id->rd_msg;
389                 size_t sg_cnt;
390
391                 need_inval = le16_to_cpu(rd_msg->flags) &
392                                 RTRS_MSG_NEED_INVAL_F;
393                 sg_cnt = le16_to_cpu(rd_msg->sg_cnt);
394
395                 if (need_inval) {
396                         if (likely(sg_cnt)) {
397                                 inv_wr.wr_cqe   = &io_comp_cqe;
398                                 inv_wr.sg_list = NULL;
399                                 inv_wr.num_sge = 0;
400                                 inv_wr.opcode = IB_WR_SEND_WITH_INV;
401                                 inv_wr.send_flags = 0;
402                                 /* Only one key is actually used */
403                                 inv_wr.ex.invalidate_rkey =
404                                         le32_to_cpu(rd_msg->desc[0].key);
405                         } else {
406                                 WARN_ON_ONCE(1);
407                                 need_inval = false;
408                         }
409                 }
410         }
411
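        /*
         * Resulting WR chains (no RDMA write here, only the response):
         *   need_inval && always_invalidate: SEND_WITH_INV -> REG_MR -> SEND_WITH_IMM
         *   always_invalidate only:          REG_MR -> SEND_WITH_IMM
         *   need_inval only:                 SEND_WITH_INV -> RDMA_WRITE_WITH_IMM
         *   neither:                         RDMA_WRITE_WITH_IMM
         */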
412         if (need_inval && always_invalidate) {
413                 wr = &inv_wr;
414                 inv_wr.next = &rwr.wr;
415                 rwr.wr.next = &imm_wr.wr;
416         } else if (always_invalidate) {
417                 wr = &rwr.wr;
418                 rwr.wr.next = &imm_wr.wr;
419         } else if (need_inval) {
420                 wr = &inv_wr;
421                 inv_wr.next = &imm_wr.wr;
422         } else {
423                 wr = &imm_wr.wr;
424         }
425         /*
426          * From time to time we have to post signaled sends,
427          * or the send queue will fill up and only a QP reset can help.
428          */
429         flags = (atomic_inc_return(&con->wr_cnt) % srv->queue_depth) ?
430                 0 : IB_SEND_SIGNALED;
431         imm = rtrs_to_io_rsp_imm(id->msg_id, errno, need_inval);
432         imm_wr.wr.next = NULL;
433         if (always_invalidate) {
434                 struct ib_sge list;
435                 struct rtrs_msg_rkey_rsp *msg;
436
437                 srv_mr = &sess->mrs[id->msg_id];
438                 rwr.wr.next = &imm_wr.wr;
439                 rwr.wr.opcode = IB_WR_REG_MR;
440                 rwr.wr.wr_cqe = &local_reg_cqe;
441                 rwr.wr.num_sge = 0;
442                 rwr.wr.send_flags = 0;
443                 rwr.mr = srv_mr->mr;
444                 rwr.key = srv_mr->mr->rkey;
445                 rwr.access = (IB_ACCESS_LOCAL_WRITE |
446                               IB_ACCESS_REMOTE_WRITE);
447                 msg = srv_mr->iu->buf;
448                 msg->buf_id = cpu_to_le16(id->msg_id);
449                 msg->type = cpu_to_le16(RTRS_MSG_RKEY_RSP);
450                 msg->rkey = cpu_to_le32(srv_mr->mr->rkey);
451
452                 list.addr   = srv_mr->iu->dma_addr;
453                 list.length = sizeof(*msg);
454                 list.lkey   = sess->s.dev->ib_pd->local_dma_lkey;
455                 imm_wr.wr.sg_list = &list;
456                 imm_wr.wr.num_sge = 1;
457                 imm_wr.wr.opcode = IB_WR_SEND_WITH_IMM;
458                 ib_dma_sync_single_for_device(sess->s.dev->ib_dev,
459                                               srv_mr->iu->dma_addr,
460                                               srv_mr->iu->size, DMA_TO_DEVICE);
461         } else {
462                 imm_wr.wr.sg_list = NULL;
463                 imm_wr.wr.num_sge = 0;
464                 imm_wr.wr.opcode = IB_WR_RDMA_WRITE_WITH_IMM;
465         }
466         imm_wr.wr.send_flags = flags;
467         imm_wr.wr.wr_cqe   = &io_comp_cqe;
468
469         imm_wr.wr.ex.imm_data = cpu_to_be32(imm);
470
471         err = ib_post_send(id->con->c.qp, wr, NULL);
472         if (unlikely(err))
473                 rtrs_err_rl(s, "Posting RDMA-Reply to QP failed, err: %d\n",
474                              err);
475
476         return err;
477 }
478
479 void close_sess(struct rtrs_srv_sess *sess)
480 {
481         enum rtrs_srv_state old_state;
482
483         if (rtrs_srv_change_state_get_old(sess, RTRS_SRV_CLOSING,
484                                            &old_state))
485                 queue_work(rtrs_wq, &sess->close_work);
486         WARN_ON(sess->state != RTRS_SRV_CLOSING);
487 }
488
489 static inline const char *rtrs_srv_state_str(enum rtrs_srv_state state)
490 {
491         switch (state) {
492         case RTRS_SRV_CONNECTING:
493                 return "RTRS_SRV_CONNECTING";
494         case RTRS_SRV_CONNECTED:
495                 return "RTRS_SRV_CONNECTED";
496         case RTRS_SRV_CLOSING:
497                 return "RTRS_SRV_CLOSING";
498         case RTRS_SRV_CLOSED:
499                 return "RTRS_SRV_CLOSED";
500         default:
501                 return "UNKNOWN";
502         }
503 }
504
505 /**
506  * rtrs_srv_resp_rdma() - Finish an RDMA request
507  *
508  * @id:         Internal RTRS operation identifier
509  * @status:     Response Code sent to the other side for this operation.
510  *              0 = success, < 0 on error
511  * Context: any
512  *
513  * Finish an RDMA operation. A message is sent to the client and the
514  * corresponding memory areas will be released.
515  */
516 bool rtrs_srv_resp_rdma(struct rtrs_srv_op *id, int status)
517 {
518         struct rtrs_srv_sess *sess;
519         struct rtrs_srv_con *con;
520         struct rtrs_sess *s;
521         int err;
522
523         if (WARN_ON(!id))
524                 return true;
525
526         con = id->con;
527         s = con->c.sess;
528         sess = to_srv_sess(s);
529
530         id->status = status;
531
532         if (unlikely(sess->state != RTRS_SRV_CONNECTED)) {
533                 rtrs_err_rl(s,
534                              "Sending I/O response failed,  session is disconnected, sess state %s\n",
535                              rtrs_srv_state_str(sess->state));
536                 goto out;
537         }
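        /*
         * Pick a fresh rkey before the MR is registered again in the
         * response path, so a stale key from a previous I/O can no
         * longer be used to access this chunk.
         */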
538         if (always_invalidate) {
539                 struct rtrs_srv_mr *mr = &sess->mrs[id->msg_id];
540
541                 ib_update_fast_reg_key(mr->mr, ib_inc_rkey(mr->mr->rkey));
542         }
543         if (unlikely(atomic_sub_return(1,
544                                        &con->sq_wr_avail) < 0)) {
545                 pr_err("IB send queue full\n");
546                 atomic_add(1, &con->sq_wr_avail);
547                 spin_lock(&con->rsp_wr_wait_lock);
548                 list_add_tail(&id->wait_list, &con->rsp_wr_wait_list);
549                 spin_unlock(&con->rsp_wr_wait_lock);
550                 return false;
551         }
552
553         if (status || id->dir == WRITE || !id->rd_msg->sg_cnt)
554                 err = send_io_resp_imm(con, id, status);
555         else
556                 err = rdma_write_sg(id);
557
558         if (unlikely(err)) {
559                 rtrs_err_rl(s, "IO response failed: %d\n", err);
560                 close_sess(sess);
561         }
562 out:
563         rtrs_srv_put_ops_ids(sess);
564         return true;
565 }
566 EXPORT_SYMBOL(rtrs_srv_resp_rdma);
567
568 /**
569  * rtrs_srv_set_sess_priv() - Set private pointer in rtrs_srv.
570  * @srv:        Session pointer
571  * @priv:       The private pointer that is associated with the session.
572  */
573 void rtrs_srv_set_sess_priv(struct rtrs_srv *srv, void *priv)
574 {
575         srv->priv = priv;
576 }
577 EXPORT_SYMBOL(rtrs_srv_set_sess_priv);
578
579 static void unmap_cont_bufs(struct rtrs_srv_sess *sess)
580 {
581         int i;
582
583         for (i = 0; i < sess->mrs_num; i++) {
584                 struct rtrs_srv_mr *srv_mr;
585
586                 srv_mr = &sess->mrs[i];
587                 rtrs_iu_free(srv_mr->iu, sess->s.dev->ib_dev, 1);
588                 ib_dereg_mr(srv_mr->mr);
589                 ib_dma_unmap_sg(sess->s.dev->ib_dev, srv_mr->sgt.sgl,
590                                 srv_mr->sgt.nents, DMA_BIDIRECTIONAL);
591                 sg_free_table(&srv_mr->sgt);
592         }
593         kfree(sess->mrs);
594 }
595
596 static int map_cont_bufs(struct rtrs_srv_sess *sess)
597 {
598         struct rtrs_srv *srv = sess->srv;
599         struct rtrs_sess *ss = &sess->s;
600         int i, mri, err, mrs_num;
601         unsigned int chunk_bits;
602         int chunks_per_mr = 1;
603
604         /*
605          * Here we map queue_depth chunks to MRs.  First we have to
606          * figure out how many chunks we can map per MR.
607          */
608         if (always_invalidate) {
609                 /*
610                  * In order to invalidate each chunk of memory separately,
611                  * we need more memory regions: one per chunk.
612                  */
613                 mrs_num = srv->queue_depth;
614         } else {
615                 chunks_per_mr =
616                         sess->s.dev->ib_dev->attrs.max_fast_reg_page_list_len;
617                 mrs_num = DIV_ROUND_UP(srv->queue_depth, chunks_per_mr);
618                 chunks_per_mr = DIV_ROUND_UP(srv->queue_depth, mrs_num);
619         }
620
621         sess->mrs = kcalloc(mrs_num, sizeof(*sess->mrs), GFP_KERNEL);
622         if (!sess->mrs)
623                 return -ENOMEM;
624
625         sess->mrs_num = mrs_num;
626
627         for (mri = 0; mri < mrs_num; mri++) {
628                 struct rtrs_srv_mr *srv_mr = &sess->mrs[mri];
629                 struct sg_table *sgt = &srv_mr->sgt;
630                 struct scatterlist *s;
631                 struct ib_mr *mr;
632                 int nr, chunks;
633
634                 chunks = chunks_per_mr * mri;
635                 if (!always_invalidate)
636                         chunks_per_mr = min_t(int, chunks_per_mr,
637                                               srv->queue_depth - chunks);
638
639                 err = sg_alloc_table(sgt, chunks_per_mr, GFP_KERNEL);
640                 if (err)
641                         goto err;
642
643                 for_each_sg(sgt->sgl, s, chunks_per_mr, i)
644                         sg_set_page(s, srv->chunks[chunks + i],
645                                     max_chunk_size, 0);
646
647                 nr = ib_dma_map_sg(sess->s.dev->ib_dev, sgt->sgl,
648                                    sgt->nents, DMA_BIDIRECTIONAL);
649                 if (nr < sgt->nents) {
650                         err = nr < 0 ? nr : -EINVAL;
651                         goto free_sg;
652                 }
653                 mr = ib_alloc_mr(sess->s.dev->ib_pd, IB_MR_TYPE_MEM_REG,
654                                  sgt->nents);
655                 if (IS_ERR(mr)) {
656                         err = PTR_ERR(mr);
657                         goto unmap_sg;
658                 }
659                 nr = ib_map_mr_sg(mr, sgt->sgl, sgt->nents,
660                                   NULL, max_chunk_size);
661                 if (nr < 0 || nr < sgt->nents) {
662                         err = nr < 0 ? nr : -EINVAL;
663                         goto dereg_mr;
664                 }
665
666                 if (always_invalidate) {
667                         srv_mr->iu = rtrs_iu_alloc(1,
668                                         sizeof(struct rtrs_msg_rkey_rsp),
669                                         GFP_KERNEL, sess->s.dev->ib_dev,
670                                         DMA_TO_DEVICE, rtrs_srv_rdma_done);
671                         if (!srv_mr->iu) {
672                                 err = -ENOMEM;
673                                 rtrs_err(ss, "rtrs_iu_alloc(), err: %d\n", err);
674                                 goto dereg_mr;
675                         }
676                 }
677                 /* Eventually dma addr for each chunk can be cached */
678                 for_each_sg(sgt->sgl, s, sgt->orig_nents, i)
679                         sess->dma_addr[chunks + i] = sg_dma_address(s);
680
681                 ib_update_fast_reg_key(mr, ib_inc_rkey(mr->rkey));
682                 srv_mr->mr = mr;
683
684                 continue;
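                /*
                 * The unwind labels below jump into the middle of the loop
                 * body on purpose: they first tear down whatever was set up
                 * for the current entry, then the while condition walks mri
                 * backwards and fully releases all previous entries.
                 */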
685 err:
686                 while (mri--) {
687                         srv_mr = &sess->mrs[mri];
688                         sgt = &srv_mr->sgt;
689                         mr = srv_mr->mr;
690                         rtrs_iu_free(srv_mr->iu, sess->s.dev->ib_dev, 1);
691 dereg_mr:
692                         ib_dereg_mr(mr);
693 unmap_sg:
694                         ib_dma_unmap_sg(sess->s.dev->ib_dev, sgt->sgl,
695                                         sgt->nents, DMA_BIDIRECTIONAL);
696 free_sg:
697                         sg_free_table(sgt);
698                 }
699                 kfree(sess->mrs);
700
701                 return err;
702         }
703
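        /*
         * Example with the default sess_queue_depth of 512:
         * chunk_bits = ilog2(511) + 1 = 9, so the top 9 bits of the
         * immediate payload carry the chunk (buffer) id and the low
         * mem_bits bits carry the offset within the chunk.
         */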
704         chunk_bits = ilog2(srv->queue_depth - 1) + 1;
705         sess->mem_bits = (MAX_IMM_PAYL_BITS - chunk_bits);
706
707         return 0;
708 }
709
710 static void rtrs_srv_hb_err_handler(struct rtrs_con *c)
711 {
712         close_sess(to_srv_sess(c->sess));
713 }
714
715 static void rtrs_srv_init_hb(struct rtrs_srv_sess *sess)
716 {
717         rtrs_init_hb(&sess->s, &io_comp_cqe,
718                       RTRS_HB_INTERVAL_MS,
719                       RTRS_HB_MISSED_MAX,
720                       rtrs_srv_hb_err_handler,
721                       rtrs_wq);
722 }
723
724 static void rtrs_srv_start_hb(struct rtrs_srv_sess *sess)
725 {
726         rtrs_start_hb(&sess->s);
727 }
728
729 static void rtrs_srv_stop_hb(struct rtrs_srv_sess *sess)
730 {
731         rtrs_stop_hb(&sess->s);
732 }
733
734 static void rtrs_srv_info_rsp_done(struct ib_cq *cq, struct ib_wc *wc)
735 {
736         struct rtrs_srv_con *con = cq->cq_context;
737         struct rtrs_sess *s = con->c.sess;
738         struct rtrs_srv_sess *sess = to_srv_sess(s);
739         struct rtrs_iu *iu;
740
741         iu = container_of(wc->wr_cqe, struct rtrs_iu, cqe);
742         rtrs_iu_free(iu, sess->s.dev->ib_dev, 1);
743
744         if (unlikely(wc->status != IB_WC_SUCCESS)) {
745                 rtrs_err(s, "Sess info response send failed: %s\n",
746                           ib_wc_status_msg(wc->status));
747                 close_sess(sess);
748                 return;
749         }
750         WARN_ON(wc->opcode != IB_WC_SEND);
751 }
752
753 static void rtrs_srv_sess_up(struct rtrs_srv_sess *sess)
754 {
755         struct rtrs_srv *srv = sess->srv;
756         struct rtrs_srv_ctx *ctx = srv->ctx;
757         int up;
758
759         mutex_lock(&srv->paths_ev_mutex);
760         up = ++srv->paths_up;
761         if (up == 1)
762                 ctx->ops.link_ev(srv, RTRS_SRV_LINK_EV_CONNECTED, NULL);
763         mutex_unlock(&srv->paths_ev_mutex);
764
765         /* Mark session as established */
766         sess->established = true;
767 }
768
769 static void rtrs_srv_sess_down(struct rtrs_srv_sess *sess)
770 {
771         struct rtrs_srv *srv = sess->srv;
772         struct rtrs_srv_ctx *ctx = srv->ctx;
773
774         if (!sess->established)
775                 return;
776
777         sess->established = false;
778         mutex_lock(&srv->paths_ev_mutex);
779         WARN_ON(!srv->paths_up);
780         if (--srv->paths_up == 0)
781                 ctx->ops.link_ev(srv, RTRS_SRV_LINK_EV_DISCONNECTED, srv->priv);
782         mutex_unlock(&srv->paths_ev_mutex);
783 }
784
785 static int post_recv_sess(struct rtrs_srv_sess *sess);
786
787 static int process_info_req(struct rtrs_srv_con *con,
788                             struct rtrs_msg_info_req *msg)
789 {
790         struct rtrs_sess *s = con->c.sess;
791         struct rtrs_srv_sess *sess = to_srv_sess(s);
792         struct ib_send_wr *reg_wr = NULL;
793         struct rtrs_msg_info_rsp *rsp;
794         struct rtrs_iu *tx_iu;
795         struct ib_reg_wr *rwr;
796         int mri, err;
797         size_t tx_sz;
798
799         err = post_recv_sess(sess);
800         if (unlikely(err)) {
801                 rtrs_err(s, "post_recv_sess(), err: %d\n", err);
802                 return err;
803         }
804         rwr = kcalloc(sess->mrs_num, sizeof(*rwr), GFP_KERNEL);
805         if (unlikely(!rwr))
806                 return -ENOMEM;
807         strlcpy(sess->s.sessname, msg->sessname, sizeof(sess->s.sessname));
808
809         tx_sz  = sizeof(*rsp);
810         tx_sz += sizeof(rsp->desc[0]) * sess->mrs_num;
811         tx_iu = rtrs_iu_alloc(1, tx_sz, GFP_KERNEL, sess->s.dev->ib_dev,
812                                DMA_TO_DEVICE, rtrs_srv_info_rsp_done);
813         if (unlikely(!tx_iu)) {
814                 err = -ENOMEM;
815                 goto rwr_free;
816         }
817
818         rsp = tx_iu->buf;
819         rsp->type = cpu_to_le16(RTRS_MSG_INFO_RSP);
820         rsp->sg_cnt = cpu_to_le16(sess->mrs_num);
821
822         for (mri = 0; mri < sess->mrs_num; mri++) {
823                 struct ib_mr *mr = sess->mrs[mri].mr;
824
825                 rsp->desc[mri].addr = cpu_to_le64(mr->iova);
826                 rsp->desc[mri].key  = cpu_to_le32(mr->rkey);
827                 rsp->desc[mri].len  = cpu_to_le32(mr->length);
828
829                 /*
830                  * Fill in reg MR request and chain them *backwards*
831                  */
832                 rwr[mri].wr.next = mri ? &rwr[mri - 1].wr : NULL;
833                 rwr[mri].wr.opcode = IB_WR_REG_MR;
834                 rwr[mri].wr.wr_cqe = &local_reg_cqe;
835                 rwr[mri].wr.num_sge = 0;
836                 rwr[mri].wr.send_flags = 0;
837                 rwr[mri].mr = mr;
838                 rwr[mri].key = mr->rkey;
839                 rwr[mri].access = (IB_ACCESS_LOCAL_WRITE |
840                                    IB_ACCESS_REMOTE_WRITE);
841                 reg_wr = &rwr[mri].wr;
842         }
843
844         err = rtrs_srv_create_sess_files(sess);
845         if (unlikely(err))
846                 goto iu_free;
847         kobject_get(&sess->kobj);
848         get_device(&sess->srv->dev);
849         rtrs_srv_change_state(sess, RTRS_SRV_CONNECTED);
850         rtrs_srv_start_hb(sess);
851
852         /*
853          * We do not account for the number of established connections; we
854          * rely on the client, which sends the info request only when all
855          * connections have been successfully established.  Thus, simply
856          * notify the listener with the proper event if we are the first path.
857          */
858         rtrs_srv_sess_up(sess);
859
860         ib_dma_sync_single_for_device(sess->s.dev->ib_dev, tx_iu->dma_addr,
861                                       tx_iu->size, DMA_TO_DEVICE);
862
863         /* Send info response */
864         err = rtrs_iu_post_send(&con->c, tx_iu, tx_sz, reg_wr);
865         if (unlikely(err)) {
866                 rtrs_err(s, "rtrs_iu_post_send(), err: %d\n", err);
867 iu_free:
868                 rtrs_iu_free(tx_iu, sess->s.dev->ib_dev, 1);
869         }
870 rwr_free:
871         kfree(rwr);
872
873         return err;
874 }
875
876 static void rtrs_srv_info_req_done(struct ib_cq *cq, struct ib_wc *wc)
877 {
878         struct rtrs_srv_con *con = cq->cq_context;
879         struct rtrs_sess *s = con->c.sess;
880         struct rtrs_srv_sess *sess = to_srv_sess(s);
881         struct rtrs_msg_info_req *msg;
882         struct rtrs_iu *iu;
883         int err;
884
885         WARN_ON(con->c.cid);
886
887         iu = container_of(wc->wr_cqe, struct rtrs_iu, cqe);
888         if (unlikely(wc->status != IB_WC_SUCCESS)) {
889                 rtrs_err(s, "Sess info request receive failed: %s\n",
890                           ib_wc_status_msg(wc->status));
891                 goto close;
892         }
893         WARN_ON(wc->opcode != IB_WC_RECV);
894
895         if (unlikely(wc->byte_len < sizeof(*msg))) {
896                 rtrs_err(s, "Sess info request is malformed: size %d\n",
897                           wc->byte_len);
898                 goto close;
899         }
900         ib_dma_sync_single_for_cpu(sess->s.dev->ib_dev, iu->dma_addr,
901                                    iu->size, DMA_FROM_DEVICE);
902         msg = iu->buf;
903         if (unlikely(le16_to_cpu(msg->type) != RTRS_MSG_INFO_REQ)) {
904                 rtrs_err(s, "Sess info request is malformed: type %d\n",
905                           le16_to_cpu(msg->type));
906                 goto close;
907         }
908         err = process_info_req(con, msg);
909         if (unlikely(err))
910                 goto close;
911
912 out:
913         rtrs_iu_free(iu, sess->s.dev->ib_dev, 1);
914         return;
915 close:
916         close_sess(sess);
917         goto out;
918 }
919
920 static int post_recv_info_req(struct rtrs_srv_con *con)
921 {
922         struct rtrs_sess *s = con->c.sess;
923         struct rtrs_srv_sess *sess = to_srv_sess(s);
924         struct rtrs_iu *rx_iu;
925         int err;
926
927         rx_iu = rtrs_iu_alloc(1, sizeof(struct rtrs_msg_info_req),
928                                GFP_KERNEL, sess->s.dev->ib_dev,
929                                DMA_FROM_DEVICE, rtrs_srv_info_req_done);
930         if (unlikely(!rx_iu))
931                 return -ENOMEM;
932         /* Prepare for getting info response */
933         err = rtrs_iu_post_recv(&con->c, rx_iu);
934         if (unlikely(err)) {
935                 rtrs_err(s, "rtrs_iu_post_recv(), err: %d\n", err);
936                 rtrs_iu_free(rx_iu, sess->s.dev->ib_dev, 1);
937                 return err;
938         }
939
940         return 0;
941 }
942
943 static int post_recv_io(struct rtrs_srv_con *con, size_t q_size)
944 {
945         int i, err;
946
947         for (i = 0; i < q_size; i++) {
948                 err = rtrs_post_recv_empty(&con->c, &io_comp_cqe);
949                 if (unlikely(err))
950                         return err;
951         }
952
953         return 0;
954 }
955
956 static int post_recv_sess(struct rtrs_srv_sess *sess)
957 {
958         struct rtrs_srv *srv = sess->srv;
959         struct rtrs_sess *s = &sess->s;
960         size_t q_size;
961         int err, cid;
962
963         for (cid = 0; cid < sess->s.con_num; cid++) {
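                /*
                 * cid 0 is the service connection used for info and
                 * heartbeat messages; all other connections carry I/O
                 * and use the full queue depth.
                 */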
964                 if (cid == 0)
965                         q_size = SERVICE_CON_QUEUE_DEPTH;
966                 else
967                         q_size = srv->queue_depth;
968
969                 err = post_recv_io(to_srv_con(sess->s.con[cid]), q_size);
970                 if (unlikely(err)) {
971                         rtrs_err(s, "post_recv_io(), err: %d\n", err);
972                         return err;
973                 }
974         }
975
976         return 0;
977 }
978
979 static void process_read(struct rtrs_srv_con *con,
980                          struct rtrs_msg_rdma_read *msg,
981                          u32 buf_id, u32 off)
982 {
983         struct rtrs_sess *s = con->c.sess;
984         struct rtrs_srv_sess *sess = to_srv_sess(s);
985         struct rtrs_srv *srv = sess->srv;
986         struct rtrs_srv_ctx *ctx = srv->ctx;
987         struct rtrs_srv_op *id;
988
989         size_t usr_len, data_len;
990         void *data;
991         int ret;
992
993         if (unlikely(sess->state != RTRS_SRV_CONNECTED)) {
994                 rtrs_err_rl(s,
995                              "Processing read request failed,  session is disconnected, sess state %s\n",
996                              rtrs_srv_state_str(sess->state));
997                 return;
998         }
999         if (unlikely(msg->sg_cnt != 1 && msg->sg_cnt != 0)) {
1000                 rtrs_err_rl(s,
1001                             "Processing read request failed, invalid message\n");
1002                 return;
1003         }
1004         rtrs_srv_get_ops_ids(sess);
1005         rtrs_srv_update_rdma_stats(sess->stats, off, READ);
1006         id = sess->ops_ids[buf_id];
1007         id->con         = con;
1008         id->dir         = READ;
1009         id->msg_id      = buf_id;
1010         id->rd_msg      = msg;
1011         usr_len = le16_to_cpu(msg->usr_len);
1012         data_len = off - usr_len;
1013         data = page_address(srv->chunks[buf_id]);
1014         ret = ctx->ops.rdma_ev(srv, srv->priv, id, READ, data, data_len,
1015                            data + data_len, usr_len);
1016
1017         if (unlikely(ret)) {
1018                 rtrs_err_rl(s,
1019                              "Processing read request failed, user module cb reported for msg_id %d, err: %d\n",
1020                              buf_id, ret);
1021                 goto send_err_msg;
1022         }
1023
1024         return;
1025
1026 send_err_msg:
1027         ret = send_io_resp_imm(con, id, ret);
1028         if (ret < 0) {
1029                 rtrs_err_rl(s,
1030                              "Sending err msg for failed RDMA-Write-Req failed, msg_id %d, err: %d\n",
1031                              buf_id, ret);
1032                 close_sess(sess);
1033         }
1034         rtrs_srv_put_ops_ids(sess);
1035 }
1036
1037 static void process_write(struct rtrs_srv_con *con,
1038                           struct rtrs_msg_rdma_write *req,
1039                           u32 buf_id, u32 off)
1040 {
1041         struct rtrs_sess *s = con->c.sess;
1042         struct rtrs_srv_sess *sess = to_srv_sess(s);
1043         struct rtrs_srv *srv = sess->srv;
1044         struct rtrs_srv_ctx *ctx = srv->ctx;
1045         struct rtrs_srv_op *id;
1046
1047         size_t data_len, usr_len;
1048         void *data;
1049         int ret;
1050
1051         if (unlikely(sess->state != RTRS_SRV_CONNECTED)) {
1052                 rtrs_err_rl(s,
1053                              "Processing write request failed,  session is disconnected, sess state %s\n",
1054                              rtrs_srv_state_str(sess->state));
1055                 return;
1056         }
1057         rtrs_srv_get_ops_ids(sess);
1058         rtrs_srv_update_rdma_stats(sess->stats, off, WRITE);
1059         id = sess->ops_ids[buf_id];
1060         id->con    = con;
1061         id->dir    = WRITE;
1062         id->msg_id = buf_id;
1063
1064         usr_len = le16_to_cpu(req->usr_len);
1065         data_len = off - usr_len;
1066         data = page_address(srv->chunks[buf_id]);
1067         ret = ctx->ops.rdma_ev(srv, srv->priv, id, WRITE, data, data_len,
1068                            data + data_len, usr_len);
1069         if (unlikely(ret)) {
1070                 rtrs_err_rl(s,
1071                              "Processing write request failed, user module callback reports err: %d\n",
1072                              ret);
1073                 goto send_err_msg;
1074         }
1075
1076         return;
1077
1078 send_err_msg:
1079         ret = send_io_resp_imm(con, id, ret);
1080         if (ret < 0) {
1081                 rtrs_err_rl(s,
1082                              "Processing write request failed, sending I/O response failed, msg_id %d, err: %d\n",
1083                              buf_id, ret);
1084                 close_sess(sess);
1085         }
1086         rtrs_srv_put_ops_ids(sess);
1087 }
1088
1089 static void process_io_req(struct rtrs_srv_con *con, void *msg,
1090                            u32 id, u32 off)
1091 {
1092         struct rtrs_sess *s = con->c.sess;
1093         struct rtrs_srv_sess *sess = to_srv_sess(s);
1094         struct rtrs_msg_rdma_hdr *hdr;
1095         unsigned int type;
1096
1097         ib_dma_sync_single_for_cpu(sess->s.dev->ib_dev, sess->dma_addr[id],
1098                                    max_chunk_size, DMA_BIDIRECTIONAL);
1099         hdr = msg;
1100         type = le16_to_cpu(hdr->type);
1101
1102         switch (type) {
1103         case RTRS_MSG_WRITE:
1104                 process_write(con, msg, id, off);
1105                 break;
1106         case RTRS_MSG_READ:
1107                 process_read(con, msg, id, off);
1108                 break;
1109         default:
1110                 rtrs_err(s,
1111                           "Processing I/O request failed, unknown message type received: 0x%02x\n",
1112                           type);
1113                 goto err;
1114         }
1115
1116         return;
1117
1118 err:
1119         close_sess(sess);
1120 }
1121
1122 static void rtrs_srv_inv_rkey_done(struct ib_cq *cq, struct ib_wc *wc)
1123 {
1124         struct rtrs_srv_mr *mr =
1125                 container_of(wc->wr_cqe, typeof(*mr), inv_cqe);
1126         struct rtrs_srv_con *con = cq->cq_context;
1127         struct rtrs_sess *s = con->c.sess;
1128         struct rtrs_srv_sess *sess = to_srv_sess(s);
1129         struct rtrs_srv *srv = sess->srv;
1130         u32 msg_id, off;
1131         void *data;
1132
1133         if (unlikely(wc->status != IB_WC_SUCCESS)) {
1134                 rtrs_err(s, "Failed IB_WR_LOCAL_INV: %s\n",
1135                           ib_wc_status_msg(wc->status));
1136                 close_sess(sess);
1137         }
1138         msg_id = mr->msg_id;
1139         off = mr->msg_off;
1140         data = page_address(srv->chunks[msg_id]) + off;
1141         process_io_req(con, data, msg_id, off);
1142 }
1143
1144 static int rtrs_srv_inv_rkey(struct rtrs_srv_con *con,
1145                               struct rtrs_srv_mr *mr)
1146 {
1147         struct ib_send_wr wr = {
1148                 .opcode             = IB_WR_LOCAL_INV,
1149                 .wr_cqe             = &mr->inv_cqe,
1150                 .send_flags         = IB_SEND_SIGNALED,
1151                 .ex.invalidate_rkey = mr->mr->rkey,
1152         };
1153         mr->inv_cqe.done = rtrs_srv_inv_rkey_done;
1154
1155         return ib_post_send(con->c.qp, &wr, NULL);
1156 }
1157
1158 static void rtrs_rdma_process_wr_wait_list(struct rtrs_srv_con *con)
1159 {
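        /*
         * The lock is dropped around rtrs_srv_resp_rdma(), which may post
         * to the QP and takes this lock itself; if the send queue is still
         * full, the id is put back at the head of the list and we stop.
         */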
1160         spin_lock(&con->rsp_wr_wait_lock);
1161         while (!list_empty(&con->rsp_wr_wait_list)) {
1162                 struct rtrs_srv_op *id;
1163                 int ret;
1164
1165                 id = list_entry(con->rsp_wr_wait_list.next,
1166                                 struct rtrs_srv_op, wait_list);
1167                 list_del(&id->wait_list);
1168
1169                 spin_unlock(&con->rsp_wr_wait_lock);
1170                 ret = rtrs_srv_resp_rdma(id, id->status);
1171                 spin_lock(&con->rsp_wr_wait_lock);
1172
1173                 if (!ret) {
1174                         list_add(&id->wait_list, &con->rsp_wr_wait_list);
1175                         break;
1176                 }
1177         }
1178         spin_unlock(&con->rsp_wr_wait_lock);
1179 }
1180
1181 static void rtrs_srv_rdma_done(struct ib_cq *cq, struct ib_wc *wc)
1182 {
1183         struct rtrs_srv_con *con = cq->cq_context;
1184         struct rtrs_sess *s = con->c.sess;
1185         struct rtrs_srv_sess *sess = to_srv_sess(s);
1186         struct rtrs_srv *srv = sess->srv;
1187         u32 imm_type, imm_payload;
1188         int err;
1189
1190         if (unlikely(wc->status != IB_WC_SUCCESS)) {
1191                 if (wc->status != IB_WC_WR_FLUSH_ERR) {
1192                         rtrs_err(s,
1193                                   "%s (wr_cqe: %p, type: %d, vendor_err: 0x%x, len: %u)\n",
1194                                   ib_wc_status_msg(wc->status), wc->wr_cqe,
1195                                   wc->opcode, wc->vendor_err, wc->byte_len);
1196                         close_sess(sess);
1197                 }
1198                 return;
1199         }
1200
1201         switch (wc->opcode) {
1202         case IB_WC_RECV_RDMA_WITH_IMM:
1203                 /*
1204                  * post_recv() RDMA write completions of IO reqs (read/write)
1205                  * and hb
1206                  */
1207                 if (WARN_ON(wc->wr_cqe != &io_comp_cqe))
1208                         return;
1209                 err = rtrs_post_recv_empty(&con->c, &io_comp_cqe);
1210                 if (unlikely(err)) {
1211                         rtrs_err(s, "rtrs_post_recv(), err: %d\n", err);
1212                         close_sess(sess);
1213                         break;
1214                 }
1215                 rtrs_from_imm(be32_to_cpu(wc->ex.imm_data),
1216                                &imm_type, &imm_payload);
1217                 if (likely(imm_type == RTRS_IO_REQ_IMM)) {
1218                         u32 msg_id, off;
1219                         void *data;
1220
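                        /*
                         * The immediate payload packs the chunk id in the
                         * upper bits and the offset within the chunk in the
                         * lower mem_bits bits, see map_cont_bufs().
                         */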
1221                         msg_id = imm_payload >> sess->mem_bits;
1222                         off = imm_payload & ((1 << sess->mem_bits) - 1);
1223                         if (unlikely(msg_id >= srv->queue_depth ||
1224                                      off >= max_chunk_size)) {
1225                                 rtrs_err(s, "Wrong msg_id %u, off %u\n",
1226                                           msg_id, off);
1227                                 close_sess(sess);
1228                                 return;
1229                         }
1230                         if (always_invalidate) {
1231                                 struct rtrs_srv_mr *mr = &sess->mrs[msg_id];
1232
1233                                 mr->msg_off = off;
1234                                 mr->msg_id = msg_id;
1235                                 err = rtrs_srv_inv_rkey(con, mr);
1236                                 if (unlikely(err)) {
1237                                         rtrs_err(s, "rtrs_post_recv(), err: %d\n",
1238                                                   err);
1239                                         close_sess(sess);
1240                                         break;
1241                                 }
1242                         } else {
1243                                 data = page_address(srv->chunks[msg_id]) + off;
1244                                 process_io_req(con, data, msg_id, off);
1245                         }
1246                 } else if (imm_type == RTRS_HB_MSG_IMM) {
1247                         WARN_ON(con->c.cid);
1248                         rtrs_send_hb_ack(&sess->s);
1249                 } else if (imm_type == RTRS_HB_ACK_IMM) {
1250                         WARN_ON(con->c.cid);
1251                         sess->s.hb_missed_cnt = 0;
1252                 } else {
1253                         rtrs_wrn(s, "Unknown IMM type %u\n", imm_type);
1254                 }
1255                 break;
1256         case IB_WC_RDMA_WRITE:
1257         case IB_WC_SEND:
1258                 /*
1259                  * post_send() RDMA write completions of IO reqs (read/write)
1260                  */
1261                 atomic_add(srv->queue_depth, &con->sq_wr_avail);
1262
1263                 if (unlikely(!list_empty_careful(&con->rsp_wr_wait_list)))
1264                         rtrs_rdma_process_wr_wait_list(con);
1265
1266                 break;
1267         default:
1268                 rtrs_wrn(s, "Unexpected WC type: %d\n", wc->opcode);
1269                 return;
1270         }
1271 }
1272
1273 /**
1274  * rtrs_srv_get_sess_name() - Get the session name of the first connected path.
1275  * @srv:        Session
1276  * @sessname:   Sessname buffer
1277  * @len:        Length of sessname buffer
1278  */
1279 int rtrs_srv_get_sess_name(struct rtrs_srv *srv, char *sessname, size_t len)
1280 {
1281         struct rtrs_srv_sess *sess;
1282         int err = -ENOTCONN;
1283
1284         mutex_lock(&srv->paths_mutex);
1285         list_for_each_entry(sess, &srv->paths_list, s.entry) {
1286                 if (sess->state != RTRS_SRV_CONNECTED)
1287                         continue;
1288                 strlcpy(sessname, sess->s.sessname,
1289                        min_t(size_t, sizeof(sess->s.sessname), len));
1290                 err = 0;
1291                 break;
1292         }
1293         mutex_unlock(&srv->paths_mutex);
1294
1295         return err;
1296 }
1297 EXPORT_SYMBOL(rtrs_srv_get_sess_name);
1298
1299 /**
1300  * rtrs_srv_get_queue_depth() - Get rtrs_srv queue depth.
1301  * @srv:        Session
1302  */
1303 int rtrs_srv_get_queue_depth(struct rtrs_srv *srv)
1304 {
1305         return srv->queue_depth;
1306 }
1307 EXPORT_SYMBOL(rtrs_srv_get_queue_depth);
1308
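/*
 * Pick CQ vectors round-robin from cq_affinity_mask, wrapping back to the
 * first allowed CPU when the mask or the device's completion vectors are
 * exhausted.
 */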
1309 static int find_next_bit_ring(struct rtrs_srv_sess *sess)
1310 {
1311         struct ib_device *ib_dev = sess->s.dev->ib_dev;
1312         int v;
1313
1314         v = cpumask_next(sess->cur_cq_vector, &cq_affinity_mask);
1315         if (v >= nr_cpu_ids || v >= ib_dev->num_comp_vectors)
1316                 v = cpumask_first(&cq_affinity_mask);
1317         return v;
1318 }
1319
1320 static int rtrs_srv_get_next_cq_vector(struct rtrs_srv_sess *sess)
1321 {
1322         sess->cur_cq_vector = find_next_bit_ring(sess);
1323
1324         return sess->cur_cq_vector;
1325 }
1326
1327 static void rtrs_srv_dev_release(struct device *dev)
1328 {
1329         struct rtrs_srv *srv = container_of(dev, struct rtrs_srv, dev);
1330
1331         kfree(srv);
1332 }
1333
1334 static void free_srv(struct rtrs_srv *srv)
1335 {
1336         int i;
1337
1338         WARN_ON(refcount_read(&srv->refcount));
1339         for (i = 0; i < srv->queue_depth; i++)
1340                 mempool_free(srv->chunks[i], chunk_pool);
1341         kfree(srv->chunks);
1342         mutex_destroy(&srv->paths_mutex);
1343         mutex_destroy(&srv->paths_ev_mutex);
1344         /* last put to release the srv structure */
1345         put_device(&srv->dev);
1346 }
1347
1348 static struct rtrs_srv *get_or_create_srv(struct rtrs_srv_ctx *ctx,
1349                                           const uuid_t *paths_uuid,
1350                                           bool first_conn)
1351 {
1352         struct rtrs_srv *srv;
1353         int i;
1354
1355         mutex_lock(&ctx->srv_mutex);
1356         list_for_each_entry(srv, &ctx->srv_list, ctx_list) {
1357                 if (uuid_equal(&srv->paths_uuid, paths_uuid) &&
1358                     refcount_inc_not_zero(&srv->refcount)) {
1359                         mutex_unlock(&ctx->srv_mutex);
1360                         return srv;
1361                 }
1362         }
1363         mutex_unlock(&ctx->srv_mutex);
1364         /*
1365          * If this request is not the first connection request from the
1366          * client for this session then fail and return error.
1367          */
1368         if (!first_conn)
1369                 return ERR_PTR(-ENXIO);
1370
1371         /* need to allocate a new srv */
1372         srv = kzalloc(sizeof(*srv), GFP_KERNEL);
1373         if  (!srv)
1374                 return ERR_PTR(-ENOMEM);
1375
1376         INIT_LIST_HEAD(&srv->paths_list);
1377         mutex_init(&srv->paths_mutex);
1378         mutex_init(&srv->paths_ev_mutex);
1379         uuid_copy(&srv->paths_uuid, paths_uuid);
1380         srv->queue_depth = sess_queue_depth;
1381         srv->ctx = ctx;
1382         device_initialize(&srv->dev);
1383         srv->dev.release = rtrs_srv_dev_release;
1384
1385         srv->chunks = kcalloc(srv->queue_depth, sizeof(*srv->chunks),
1386                               GFP_KERNEL);
1387         if (!srv->chunks)
1388                 goto err_free_srv;
1389
1390         for (i = 0; i < srv->queue_depth; i++) {
1391                 srv->chunks[i] = mempool_alloc(chunk_pool, GFP_KERNEL);
1392                 if (!srv->chunks[i])
1393                         goto err_free_chunks;
1394         }
1395         refcount_set(&srv->refcount, 1);
1396         mutex_lock(&ctx->srv_mutex);
1397         list_add(&srv->ctx_list, &ctx->srv_list);
1398         mutex_unlock(&ctx->srv_mutex);
1399
1400         return srv;
1401
1402 err_free_chunks:
1403         while (i--)
1404                 mempool_free(srv->chunks[i], chunk_pool);
1405         kfree(srv->chunks);
1406
1407 err_free_srv:
1408         kfree(srv);
1409         return ERR_PTR(-ENOMEM);
1410 }
1411
1412 static void put_srv(struct rtrs_srv *srv)
1413 {
1414         if (refcount_dec_and_test(&srv->refcount)) {
1415                 struct rtrs_srv_ctx *ctx = srv->ctx;
1416
1417                 WARN_ON(srv->dev.kobj.state_in_sysfs);
1418
1419                 mutex_lock(&ctx->srv_mutex);
1420                 list_del(&srv->ctx_list);
1421                 mutex_unlock(&ctx->srv_mutex);
1422                 free_srv(srv);
1423         }
1424 }
1425
1426 static void __add_path_to_srv(struct rtrs_srv *srv,
1427                               struct rtrs_srv_sess *sess)
1428 {
1429         list_add_tail(&sess->s.entry, &srv->paths_list);
1430         srv->paths_num++;
1431         WARN_ON(srv->paths_num >= MAX_PATHS_NUM);
1432 }
1433
1434 static void del_path_from_srv(struct rtrs_srv_sess *sess)
1435 {
1436         struct rtrs_srv *srv = sess->srv;
1437
1438         if (WARN_ON(!srv))
1439                 return;
1440
1441         mutex_lock(&srv->paths_mutex);
1442         list_del(&sess->s.entry);
1443         WARN_ON(!srv->paths_num);
1444         srv->paths_num--;
1445         mutex_unlock(&srv->paths_mutex);
1446 }
1447
1448 /* Returns 0 if the addresses match (memcmp() semantics), non-zero otherwise */
1449 static int sockaddr_cmp(const struct sockaddr *a, const struct sockaddr *b)
1450 {
1451         switch (a->sa_family) {
1452         case AF_IB:
1453                 return memcmp(&((struct sockaddr_ib *)a)->sib_addr,
1454                               &((struct sockaddr_ib *)b)->sib_addr,
1455                               sizeof(struct ib_addr)) &&
1456                         (b->sa_family == AF_IB);
1457         case AF_INET:
1458                 return memcmp(&((struct sockaddr_in *)a)->sin_addr,
1459                               &((struct sockaddr_in *)b)->sin_addr,
1460                               sizeof(struct in_addr)) &&
1461                         (b->sa_family == AF_INET);
1462         case AF_INET6:
1463                 return memcmp(&((struct sockaddr_in6 *)a)->sin6_addr,
1464                               &((struct sockaddr_in6 *)b)->sin6_addr,
1465                               sizeof(struct in6_addr)) &&
1466                         (b->sa_family == AF_INET6);
1467         default:
1468                 return -ENOENT;
1469         }
1470 }
1471
1472 static bool __is_path_w_addr_exists(struct rtrs_srv *srv,
1473                                     struct rdma_addr *addr)
1474 {
1475         struct rtrs_srv_sess *sess;
1476
1477         list_for_each_entry(sess, &srv->paths_list, s.entry)
1478                 if (!sockaddr_cmp((struct sockaddr *)&sess->s.dst_addr,
1479                                   (struct sockaddr *)&addr->dst_addr) &&
1480                     !sockaddr_cmp((struct sockaddr *)&sess->s.src_addr,
1481                                   (struct sockaddr *)&addr->src_addr))
1482                         return true;
1483
1484         return false;
1485 }
1486
1487 static void free_sess(struct rtrs_srv_sess *sess)
1488 {
1489         if (sess->kobj.state_in_sysfs) {
1490                 kobject_del(&sess->kobj);
1491                 kobject_put(&sess->kobj);
1492         } else {
1493                 kfree(sess->stats);
1494                 kfree(sess);
1495         }
1496 }
1497
1498 static void rtrs_srv_close_work(struct work_struct *work)
1499 {
1500         struct rtrs_srv_sess *sess;
1501         struct rtrs_srv_con *con;
1502         int i;
1503
1504         sess = container_of(work, typeof(*sess), close_work);
1505
1506         rtrs_srv_destroy_sess_files(sess);
1507         rtrs_srv_stop_hb(sess);
1508
1509         for (i = 0; i < sess->s.con_num; i++) {
1510                 if (!sess->s.con[i])
1511                         continue;
1512                 con = to_srv_con(sess->s.con[i]);
1513                 rdma_disconnect(con->c.cm_id);
1514                 ib_drain_qp(con->c.qp);
1515         }
1516         /* Wait for all inflights */
1517         rtrs_srv_wait_ops_ids(sess);
1518
1519         /* Notify upper layer if we are the last path */
1520         rtrs_srv_sess_down(sess);
1521
1522         unmap_cont_bufs(sess);
1523         rtrs_srv_free_ops_ids(sess);
1524
1525         for (i = 0; i < sess->s.con_num; i++) {
1526                 if (!sess->s.con[i])
1527                         continue;
1528                 con = to_srv_con(sess->s.con[i]);
1529                 rtrs_cq_qp_destroy(&con->c);
1530                 rdma_destroy_id(con->c.cm_id);
1531                 kfree(con);
1532         }
1533         rtrs_ib_dev_put(sess->s.dev);
1534
1535         del_path_from_srv(sess);
1536         put_srv(sess->srv);
1537         sess->srv = NULL;
1538         rtrs_srv_change_state(sess, RTRS_SRV_CLOSED);
1539
1540         kfree(sess->dma_addr);
1541         kfree(sess->s.con);
1542         free_sess(sess);
1543 }
1544
1545 static int rtrs_rdma_do_accept(struct rtrs_srv_sess *sess,
1546                                struct rdma_cm_id *cm_id)
1547 {
1548         struct rtrs_srv *srv = sess->srv;
1549         struct rtrs_msg_conn_rsp msg;
1550         struct rdma_conn_param param;
1551         int err;
1552
1553         param = (struct rdma_conn_param) {
1554                 .rnr_retry_count = 7,
1555                 .private_data = &msg,
1556                 .private_data_len = sizeof(msg),
1557         };
1558
1559         msg = (struct rtrs_msg_conn_rsp) {
1560                 .magic = cpu_to_le16(RTRS_MAGIC),
1561                 .version = cpu_to_le16(RTRS_PROTO_VER),
1562                 .queue_depth = cpu_to_le16(srv->queue_depth),
1563                 .max_io_size = cpu_to_le32(max_chunk_size - MAX_HDR_SIZE),
1564                 .max_hdr_size = cpu_to_le32(MAX_HDR_SIZE),
1565         };
1566
1567         if (always_invalidate)
1568                 msg.flags = cpu_to_le32(RTRS_MSG_NEW_RKEY_F);
1569
1570         err = rdma_accept(cm_id, &param);
1571         if (err)
1572                 pr_err("rdma_accept(), err: %d\n", err);
1573
1574         return err;
1575 }
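
/*
 * Worked example (illustrative): with the default module parameters,
 * srv->queue_depth set from sess_queue_depth (512) and 4 KiB pages,
 * the accept response above advertises queue_depth = 512,
 * max_io_size = 131072 - 4096 = 126976 bytes and max_hdr_size = 4096,
 * so the client can size its buffers before issuing any I/O.
 */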
1576
1577 static int rtrs_rdma_do_reject(struct rdma_cm_id *cm_id, int errno)
1578 {
1579         struct rtrs_msg_conn_rsp msg;
1580         int err;
1581
1582         msg = (struct rtrs_msg_conn_rsp) {
1583                 .magic = cpu_to_le16(RTRS_MAGIC),
1584                 .version = cpu_to_le16(RTRS_PROTO_VER),
1585                 .errno = cpu_to_le16(errno),
1586         };
1587
1588         err = rdma_reject(cm_id, &msg, sizeof(msg), IB_CM_REJ_CONSUMER_DEFINED);
1589         if (err)
1590                 pr_err("rdma_reject(), err: %d\n", err);
1591
1592         /* Bounce errno back */
1593         return errno;
1594 }
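
/*
 * Bouncing the errno back lets callers reject and propagate the error
 * in a single statement; see the reject_w_err label in
 * rtrs_rdma_connect() below, which does
 * "return rtrs_rdma_do_reject(cm_id, err);".
 */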
1595
1596 static struct rtrs_srv_sess *
1597 __find_sess(struct rtrs_srv *srv, const uuid_t *sess_uuid)
1598 {
1599         struct rtrs_srv_sess *sess;
1600
1601         list_for_each_entry(sess, &srv->paths_list, s.entry) {
1602                 if (uuid_equal(&sess->s.uuid, sess_uuid))
1603                         return sess;
1604         }
1605
1606         return NULL;
1607 }
1608
1609 static int create_con(struct rtrs_srv_sess *sess,
1610                       struct rdma_cm_id *cm_id,
1611                       unsigned int cid)
1612 {
1613         struct rtrs_srv *srv = sess->srv;
1614         struct rtrs_sess *s = &sess->s;
1615         struct rtrs_srv_con *con;
1616
1617         u32 cq_size, max_send_wr, max_recv_wr, wr_limit;
1618         int err, cq_vector;
1619
1620         con = kzalloc(sizeof(*con), GFP_KERNEL);
1621         if (!con) {
1622                 err = -ENOMEM;
1623                 goto err;
1624         }
1625
1626         spin_lock_init(&con->rsp_wr_wait_lock);
1627         INIT_LIST_HEAD(&con->rsp_wr_wait_list);
1628         con->c.cm_id = cm_id;
1629         con->c.sess = &sess->s;
1630         con->c.cid = cid;
1631         atomic_set(&con->wr_cnt, 1);
1632
1633         if (con->c.cid == 0) {
1634                 /*
1635                  * All receive and all send (each requiring invalidate)
1636                  * + 2 for drain and heartbeat
1637                  */
1638                 max_send_wr = SERVICE_CON_QUEUE_DEPTH * 2 + 2;
1639                 max_recv_wr = SERVICE_CON_QUEUE_DEPTH + 2;
1640                 cq_size = max_send_wr + max_recv_wr;
1641         } else {
1642                 /*
1643                  * In theory we might have queue_depth * 32
1644                  * outstanding requests if an unsafe global key is used
1645                  * and we have queue_depth read requests each consisting
1646                  * of 32 different addresses. Divide by 3 for mlx5 headroom.
1647                  */
1648                 wr_limit = sess->s.dev->ib_dev->attrs.max_qp_wr / 3;
1649                 /* when always_invalidate is enabled we need linv+rinv+mr+imm */
1650                 if (always_invalidate)
1651                         max_send_wr =
1652                                 min_t(int, wr_limit,
1653                                       srv->queue_depth * (1 + 4) + 1);
1654                 else
1655                         max_send_wr =
1656                                 min_t(int, wr_limit,
1657                                       srv->queue_depth * (1 + 2) + 1);
1658
1659                 max_recv_wr = srv->queue_depth + 1;
1660                 /*
1661                  * Worst case: all receive requests posted, all write
1662                  * requests posted, each read request requiring an
1663                  * invalidate request, plus a drain in case the QP gets
1664                  * into the error state.
1665                  */
1666                 cq_size = max_send_wr + max_recv_wr;
1667         }
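        /*
         * Worked example (illustrative, assuming srv->queue_depth is the
         * default 512): with always_invalidate enabled max_send_wr is
         * min(wr_limit, 512 * (1 + 4) + 1) = min(wr_limit, 2561),
         * max_recv_wr is 513, and the CQ is sized to their sum so a
         * completion slot exists for every posted WR.
         */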
1668         atomic_set(&con->sq_wr_avail, max_send_wr);
1669         cq_vector = rtrs_srv_get_next_cq_vector(sess);
1670
1671         /* TODO: SOFTIRQ can be faster, but be careful with softirq context */
1672         err = rtrs_cq_qp_create(&sess->s, &con->c, 1, cq_vector, cq_size,
1673                                  max_send_wr, max_recv_wr,
1674                                  IB_POLL_WORKQUEUE);
1675         if (err) {
1676                 rtrs_err(s, "rtrs_cq_qp_create(), err: %d\n", err);
1677                 goto free_con;
1678         }
1679         if (con->c.cid == 0) {
1680                 err = post_recv_info_req(con);
1681                 if (err)
1682                         goto free_cqqp;
1683         }
1684         WARN_ON(sess->s.con[cid]);
1685         sess->s.con[cid] = &con->c;
1686
1687         /*
1688          * Change context from server to current connection.  The other
1689          * way is to use cm_id->qp->qp_context, which does not work on OFED.
1690          */
1691         cm_id->context = &con->c;
1692
1693         return 0;
1694
1695 free_cqqp:
1696         rtrs_cq_qp_destroy(&con->c);
1697 free_con:
1698         kfree(con);
1699
1700 err:
1701         return err;
1702 }
1703
1704 static struct rtrs_srv_sess *__alloc_sess(struct rtrs_srv *srv,
1705                                            struct rdma_cm_id *cm_id,
1706                                            unsigned int con_num,
1707                                            unsigned int recon_cnt,
1708                                            const uuid_t *uuid)
1709 {
1710         struct rtrs_srv_sess *sess;
1711         int err = -ENOMEM;
1712
1713         if (srv->paths_num >= MAX_PATHS_NUM) {
1714                 err = -ECONNRESET;
1715                 goto err;
1716         }
1717         if (__is_path_w_addr_exists(srv, &cm_id->route.addr)) {
1718                 err = -EEXIST;
1719                 pr_err("Path with same addr exists\n");
1720                 goto err;
1721         }
1722         sess = kzalloc(sizeof(*sess), GFP_KERNEL);
1723         if (!sess)
1724                 goto err;
1725
1726         sess->stats = kzalloc(sizeof(*sess->stats), GFP_KERNEL);
1727         if (!sess->stats)
1728                 goto err_free_sess;
1729
1730         sess->stats->sess = sess;
1731
1732         sess->dma_addr = kcalloc(srv->queue_depth, sizeof(*sess->dma_addr),
1733                                  GFP_KERNEL);
1734         if (!sess->dma_addr)
1735                 goto err_free_stats;
1736
1737         sess->s.con = kcalloc(con_num, sizeof(*sess->s.con), GFP_KERNEL);
1738         if (!sess->s.con)
1739                 goto err_free_dma_addr;
1740
1741         sess->state = RTRS_SRV_CONNECTING;
1742         sess->srv = srv;
1743         sess->cur_cq_vector = -1;
1744         sess->s.dst_addr = cm_id->route.addr.dst_addr;
1745         sess->s.src_addr = cm_id->route.addr.src_addr;
1746         sess->s.con_num = con_num;
1747         sess->s.recon_cnt = recon_cnt;
1748         uuid_copy(&sess->s.uuid, uuid);
1749         spin_lock_init(&sess->state_lock);
1750         INIT_WORK(&sess->close_work, rtrs_srv_close_work);
1751         rtrs_srv_init_hb(sess);
1752
1753         sess->s.dev = rtrs_ib_dev_find_or_add(cm_id->device, &dev_pd);
1754         if (!sess->s.dev) {
1755                 err = -ENOMEM;
1756                 goto err_free_con;
1757         }
1758         err = map_cont_bufs(sess);
1759         if (err)
1760                 goto err_put_dev;
1761
1762         err = rtrs_srv_alloc_ops_ids(sess);
1763         if (err)
1764                 goto err_unmap_bufs;
1765
1766         __add_path_to_srv(srv, sess);
1767
1768         return sess;
1769
1770 err_unmap_bufs:
1771         unmap_cont_bufs(sess);
1772 err_put_dev:
1773         rtrs_ib_dev_put(sess->s.dev);
1774 err_free_con:
1775         kfree(sess->s.con);
1776 err_free_dma_addr:
1777         kfree(sess->dma_addr);
1778 err_free_stats:
1779         kfree(sess->stats);
1780 err_free_sess:
1781         kfree(sess);
1782 err:
1783         return ERR_PTR(err);
1784 }
1785
1786 static int rtrs_rdma_connect(struct rdma_cm_id *cm_id,
1787                               const struct rtrs_msg_conn_req *msg,
1788                               size_t len)
1789 {
1790         struct rtrs_srv_ctx *ctx = cm_id->context;
1791         struct rtrs_srv_sess *sess;
1792         struct rtrs_srv *srv;
1793
1794         u16 version, con_num, cid;
1795         u16 recon_cnt;
1796         int err;
1797
1798         if (len < sizeof(*msg)) {
1799                 pr_err("Invalid RTRS connection request\n");
1800                 goto reject_w_econnreset;
1801         }
1802         if (le16_to_cpu(msg->magic) != RTRS_MAGIC) {
1803                 pr_err("Invalid RTRS magic\n");
1804                 goto reject_w_econnreset;
1805         }
1806         version = le16_to_cpu(msg->version);
1807         if (version >> 8 != RTRS_PROTO_VER_MAJOR) {
1808                 pr_err("Unsupported major RTRS version: %d, expected %d\n",
1809                        version >> 8, RTRS_PROTO_VER_MAJOR);
1810                 goto reject_w_econnreset;
1811         }
1812         con_num = le16_to_cpu(msg->cid_num);
1813         if (con_num > 4096) {
1814                 /* Sanity check */
1815                 pr_err("Too many connections requested: %d\n", con_num);
1816                 goto reject_w_econnreset;
1817         }
1818         cid = le16_to_cpu(msg->cid);
1819         if (cid >= con_num) {
1820                 /* Sanity check */
1821                 pr_err("Incorrect cid: %d >= %d\n", cid, con_num);
1822                 goto reject_w_econnreset;
1823         }
1824         recon_cnt = le16_to_cpu(msg->recon_cnt);
1825         srv = get_or_create_srv(ctx, &msg->paths_uuid, msg->first_conn);
1826         if (IS_ERR(srv)) {
1827                 err = PTR_ERR(srv);
1828                 goto reject_w_err;
1829         }
1830         mutex_lock(&srv->paths_mutex);
1831         sess = __find_sess(srv, &msg->sess_uuid);
1832         if (sess) {
1833                 struct rtrs_sess *s = &sess->s;
1834
1835                 /* Session already holds a reference */
1836                 put_srv(srv);
1837
1838                 if (sess->state != RTRS_SRV_CONNECTING) {
1839                         rtrs_err(s, "Session in wrong state: %s\n",
1840                                   rtrs_srv_state_str(sess->state));
1841                         mutex_unlock(&srv->paths_mutex);
1842                         goto reject_w_econnreset;
1843                 }
1844                 /*
1845                  * Sanity checks
1846                  */
1847                 if (con_num != s->con_num || cid >= s->con_num) {
1848                         rtrs_err(s, "Incorrect request: %d, %d\n",
1849                                   cid, con_num);
1850                         mutex_unlock(&srv->paths_mutex);
1851                         goto reject_w_econnreset;
1852                 }
1853                 if (s->con[cid]) {
1854                         rtrs_err(s, "Connection already exists: %d\n",
1855                                   cid);
1856                         mutex_unlock(&srv->paths_mutex);
1857                         goto reject_w_econnreset;
1858                 }
1859         } else {
1860                 sess = __alloc_sess(srv, cm_id, con_num, recon_cnt,
1861                                     &msg->sess_uuid);
1862                 if (IS_ERR(sess)) {
1863                         mutex_unlock(&srv->paths_mutex);
1864                         put_srv(srv);
1865                         err = PTR_ERR(sess);
1866                         goto reject_w_err;
1867                 }
1868         }
1869         err = create_con(sess, cm_id, cid);
1870         if (err) {
1871                 (void)rtrs_rdma_do_reject(cm_id, err);
1872                 /*
1873                  * Since the session has other connections we close it the
1874                  * normal way through the workqueue, but still return an error
1875                  * to tell cma.c to call rdma_destroy_id() for this connection.
1876                  */
1877                 goto close_and_return_err;
1878         }
1879         err = rtrs_rdma_do_accept(sess, cm_id);
1880         if (err) {
1881                 (void)rtrs_rdma_do_reject(cm_id, err);
1882                 /*
1883                  * Since the current connection was successfully added to the
1884                  * session we close the session the normal way through the
1885                  * workqueue, and return 0 to tell cma.c that we will call
1886                  * rdma_destroy_id() ourselves.
1887                  */
1888                 err = 0;
1889                 goto close_and_return_err;
1890         }
1891         mutex_unlock(&srv->paths_mutex);
1892
1893         return 0;
1894
1895 reject_w_err:
1896         return rtrs_rdma_do_reject(cm_id, err);
1897
1898 reject_w_econnreset:
1899         return rtrs_rdma_do_reject(cm_id, -ECONNRESET);
1900
1901 close_and_return_err:
1902         mutex_unlock(&srv->paths_mutex);
1903         close_sess(sess);
1904
1905         return err;
1906 }
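
/*
 * For reference, a first connection request that passes the sanity
 * checks above carries (field names per struct rtrs_msg_conn_req,
 * values illustrative only):
 *
 *	.magic     = cpu_to_le16(RTRS_MAGIC),
 *	.version   = cpu_to_le16(RTRS_PROTO_VER),
 *	.cid       = cpu_to_le16(0),
 *	.cid_num   = cpu_to_le16(n), where 0 < n <= 4096 and cid < n,
 *	.recon_cnt = cpu_to_le16(0),
 *	.sess_uuid and .paths_uuid chosen by the client, and
 *	.first_conn set on the very first connection of a path.
 */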
1907
1908 static int rtrs_srv_rdma_cm_handler(struct rdma_cm_id *cm_id,
1909                                      struct rdma_cm_event *ev)
1910 {
1911         struct rtrs_srv_sess *sess = NULL;
1912         struct rtrs_sess *s = NULL;
1913
1914         if (ev->event != RDMA_CM_EVENT_CONNECT_REQUEST) {
1915                 struct rtrs_con *c = cm_id->context;
1916
1917                 s = c->sess;
1918                 sess = to_srv_sess(s);
1919         }
1920
1921         switch (ev->event) {
1922         case RDMA_CM_EVENT_CONNECT_REQUEST:
1923                 /*
1924                  * In case of error cma.c will destroy cm_id,
1925                  * see cma_process_remove()
1926                  */
1927                 return rtrs_rdma_connect(cm_id, ev->param.conn.private_data,
1928                                           ev->param.conn.private_data_len);
1929         case RDMA_CM_EVENT_ESTABLISHED:
1930                 /* Nothing here */
1931                 break;
1932         case RDMA_CM_EVENT_REJECTED:
1933         case RDMA_CM_EVENT_CONNECT_ERROR:
1934         case RDMA_CM_EVENT_UNREACHABLE:
1935                 rtrs_err(s, "CM error (CM event: %s, err: %d)\n",
1936                           rdma_event_msg(ev->event), ev->status);
1937                 close_sess(sess);
1938                 break;
1939         case RDMA_CM_EVENT_DISCONNECTED:
1940         case RDMA_CM_EVENT_ADDR_CHANGE:
1941         case RDMA_CM_EVENT_TIMEWAIT_EXIT:
1942                 close_sess(sess);
1943                 break;
1944         case RDMA_CM_EVENT_DEVICE_REMOVAL:
1945                 close_sess(sess);
1946                 break;
1947         default:
1948                 pr_err("Ignoring unexpected CM event %s, err %d\n",
1949                        rdma_event_msg(ev->event), ev->status);
1950                 break;
1951         }
1952
1953         return 0;
1954 }
1955
1956 static struct rdma_cm_id *rtrs_srv_cm_init(struct rtrs_srv_ctx *ctx,
1957                                             struct sockaddr *addr,
1958                                             enum rdma_ucm_port_space ps)
1959 {
1960         struct rdma_cm_id *cm_id;
1961         int ret;
1962
1963         cm_id = rdma_create_id(&init_net, rtrs_srv_rdma_cm_handler,
1964                                ctx, ps, IB_QPT_RC);
1965         if (IS_ERR(cm_id)) {
1966                 ret = PTR_ERR(cm_id);
1967                 pr_err("Creating id for RDMA connection failed, err: %d\n",
1968                        ret);
1969                 goto err_out;
1970         }
1971         ret = rdma_bind_addr(cm_id, addr);
1972         if (ret) {
1973                 pr_err("Binding RDMA address failed, err: %d\n", ret);
1974                 goto err_cm;
1975         }
1976         ret = rdma_listen(cm_id, 64);
1977         if (ret) {
1978                 pr_err("Listening on RDMA connection failed, err: %d\n",
1979                        ret);
1980                 goto err_cm;
1981         }
1982
1983         return cm_id;
1984
1985 err_cm:
1986         rdma_destroy_id(cm_id);
1987 err_out:
1988
1989         return ERR_PTR(ret);
1990 }
1991
1992 static int rtrs_srv_rdma_init(struct rtrs_srv_ctx *ctx, u16 port)
1993 {
1994         struct sockaddr_in6 sin = {
1995                 .sin6_family    = AF_INET6,
1996                 .sin6_addr      = IN6ADDR_ANY_INIT,
1997                 .sin6_port      = htons(port),
1998         };
1999         struct sockaddr_ib sib = {
2000                 .sib_family                     = AF_IB,
2001                 .sib_sid        = cpu_to_be64(RDMA_IB_IP_PS_IB | port),
2002                 .sib_sid_mask   = cpu_to_be64(0xffffffffffffffffULL),
2003                 .sib_pkey       = cpu_to_be16(0xffff),
2004         };
2005         struct rdma_cm_id *cm_ip, *cm_ib;
2006         int ret;
2007
2008         /*
2009          * We accept both IPoIB and IB connections, so we need to keep
2010          * two CM IDs, one for each socket type and port space.
2011          * If the CM initialization of either ID fails, we abort
2012          * everything.
2013          */
2014         cm_ip = rtrs_srv_cm_init(ctx, (struct sockaddr *)&sin, RDMA_PS_TCP);
2015         if (IS_ERR(cm_ip))
2016                 return PTR_ERR(cm_ip);
2017
2018         cm_ib = rtrs_srv_cm_init(ctx, (struct sockaddr *)&sib, RDMA_PS_IB);
2019         if (IS_ERR(cm_ib)) {
2020                 ret = PTR_ERR(cm_ib);
2021                 goto free_cm_ip;
2022         }
2023
2024         ctx->cm_id_ip = cm_ip;
2025         ctx->cm_id_ib = cm_ib;
2026
2027         return 0;
2028
2029 free_cm_ip:
2030         rdma_destroy_id(cm_ip);
2031
2032         return ret;
2033 }
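
/*
 * For example (illustrative): with port 1234 this leaves the server
 * with two listeners, one on [::]:1234 in the RDMA_PS_TCP port space
 * for IP-addressed clients (e.g. IPoIB) and one in the RDMA_PS_IB
 * port space under service ID (RDMA_IB_IP_PS_IB | 1234) for native
 * IB clients.
 */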
2034
2035 static struct rtrs_srv_ctx *alloc_srv_ctx(struct rtrs_srv_ops *ops)
2036 {
2037         struct rtrs_srv_ctx *ctx;
2038
2039         ctx = kzalloc(sizeof(*ctx), GFP_KERNEL);
2040         if (!ctx)
2041                 return NULL;
2042
2043         ctx->ops = *ops;
2044         mutex_init(&ctx->srv_mutex);
2045         INIT_LIST_HEAD(&ctx->srv_list);
2046
2047         return ctx;
2048 }
2049
2050 static void free_srv_ctx(struct rtrs_srv_ctx *ctx)
2051 {
2052         WARN_ON(!list_empty(&ctx->srv_list));
2053         mutex_destroy(&ctx->srv_mutex);
2054         kfree(ctx);
2055 }
2056
2057 static int rtrs_srv_add_one(struct ib_device *device)
2058 {
2059         struct rtrs_srv_ctx *ctx;
2060         int ret = 0;
2061
2062         mutex_lock(&ib_ctx.ib_dev_mutex);
2063         if (ib_ctx.ib_dev_count)
2064                 goto out;
2065
2066         /*
2067          * Since our CM IDs are NOT bound to any ib device we will create them
2068          * only once
2069          */
2070         ctx = ib_ctx.srv_ctx;
2071         ret = rtrs_srv_rdma_init(ctx, ib_ctx.port);
2072         if (ret) {
2073                 /*
2074                  * According to the ib core, if we encounter an error here
2075                  * then the error code is ignored, and no more calls to our
2076                  * ops are made.
2077                  */
2078                 pr_err("Failed to initialize RDMA connection\n");
2079                 goto err_out;
2080         }
2081
2082 out:
2083         /*
2084          * Keep track of the number of ib devices added
2085          */
2086         ib_ctx.ib_dev_count++;
2087
2088 err_out:
2089         mutex_unlock(&ib_ctx.ib_dev_mutex);
2090         return ret;
2091 }
2092
2093 static void rtrs_srv_remove_one(struct ib_device *device, void *client_data)
2094 {
2095         struct rtrs_srv_ctx *ctx;
2096
2097         mutex_lock(&ib_ctx.ib_dev_mutex);
2098         ib_ctx.ib_dev_count--;
2099
2100         if (ib_ctx.ib_dev_count)
2101                 goto out;
2102
2103         /*
2104          * Since our CM IDs are NOT bound to any ib device we will remove them
2105          * only once, when the last device is removed
2106          */
2107         ctx = ib_ctx.srv_ctx;
2108         rdma_destroy_id(ctx->cm_id_ip);
2109         rdma_destroy_id(ctx->cm_id_ib);
2110
2111 out:
2112         mutex_unlock(&ib_ctx.ib_dev_mutex);
2113 }
2114
2115 static struct ib_client rtrs_srv_client = {
2116         .name   = "rtrs_server",
2117         .add    = rtrs_srv_add_one,
2118         .remove = rtrs_srv_remove_one
2119 };
2120
2121 /**
2122  * rtrs_srv_open() - open RTRS server context
2123  * @ops:                callback functions
2124  * @port:               port to listen on
2125  *
2126  * Creates server context with specified callbacks.
2127  *
2128  * Return: a valid pointer on success, an ERR_PTR() encoded error otherwise.
2129  */
2130 struct rtrs_srv_ctx *rtrs_srv_open(struct rtrs_srv_ops *ops, u16 port)
2131 {
2132         struct rtrs_srv_ctx *ctx;
2133         int err;
2134
2135         ctx = alloc_srv_ctx(ops);
2136         if (!ctx)
2137                 return ERR_PTR(-ENOMEM);
2138
2139         mutex_init(&ib_ctx.ib_dev_mutex);
2140         ib_ctx.srv_ctx = ctx;
2141         ib_ctx.port = port;
2142
2143         err = ib_register_client(&rtrs_srv_client);
2144         if (err) {
2145                 free_srv_ctx(ctx);
2146                 return ERR_PTR(err);
2147         }
2148
2149         return ctx;
2150 }
2151 EXPORT_SYMBOL(rtrs_srv_open);
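
/*
 * A minimal usage sketch (illustrative only; see the struct
 * rtrs_srv_ops definition in rtrs.h for the exact callback
 * signatures). A ULP opens one server context at module load and
 * closes it on unload; ulp_rdma_ev(), ulp_link_ev() and ULP_PORT are
 * hypothetical names:
 *
 *	static struct rtrs_srv_ctx *srv_ctx;
 *
 *	static int __init ulp_init(void)
 *	{
 *		static struct rtrs_srv_ops ops = {
 *			.rdma_ev = ulp_rdma_ev,
 *			.link_ev = ulp_link_ev,
 *		};
 *
 *		srv_ctx = rtrs_srv_open(&ops, ULP_PORT);
 *		return PTR_ERR_OR_ZERO(srv_ctx);
 *	}
 *
 *	static void __exit ulp_exit(void)
 *	{
 *		rtrs_srv_close(srv_ctx);
 *	}
 */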
2152
2153 static void close_sessions(struct rtrs_srv *srv)
2154 {
2155         struct rtrs_srv_sess *sess;
2156
2157         mutex_lock(&srv->paths_mutex);
2158         list_for_each_entry(sess, &srv->paths_list, s.entry)
2159                 close_sess(sess);
2160         mutex_unlock(&srv->paths_mutex);
2161 }
2162
2163 static void close_ctx(struct rtrs_srv_ctx *ctx)
2164 {
2165         struct rtrs_srv *srv;
2166
2167         mutex_lock(&ctx->srv_mutex);
2168         list_for_each_entry(srv, &ctx->srv_list, ctx_list)
2169                 close_sessions(srv);
2170         mutex_unlock(&ctx->srv_mutex);
2171         flush_workqueue(rtrs_wq);
2172 }
2173
2174 /**
2175  * rtrs_srv_close() - close RTRS server context
2176  * @ctx: pointer to server context
2177  *
2178  * Closes RTRS server context with all client sessions.
2179  */
2180 void rtrs_srv_close(struct rtrs_srv_ctx *ctx)
2181 {
2182         ib_unregister_client(&rtrs_srv_client);
2183         mutex_destroy(&ib_ctx.ib_dev_mutex);
2184         close_ctx(ctx);
2185         free_srv_ctx(ctx);
2186 }
2187 EXPORT_SYMBOL(rtrs_srv_close);
2188
2189 static int check_module_params(void)
2190 {
2191         if (sess_queue_depth < 1 || sess_queue_depth > MAX_SESS_QUEUE_DEPTH) {
2192                 pr_err("Invalid sess_queue_depth value %d, has to be >= %d, <= %d.\n",
2193                        sess_queue_depth, 1, MAX_SESS_QUEUE_DEPTH);
2194                 return -EINVAL;
2195         }
2196         if (max_chunk_size < MIN_CHUNK_SIZE || !is_power_of_2(max_chunk_size)) {
2197                 pr_err("Invalid max_chunk_size value %d, has to be >= %d and a power of two.\n",
2198                        max_chunk_size, MIN_CHUNK_SIZE);
2199                 return -EINVAL;
2200         }
2201
2202         /*
2203          * Check if IB immediate data size is enough to hold the mem_id and the
2204          * offset inside the memory chunk
2205          */
2206         if ((ilog2(sess_queue_depth - 1) + 1) +
2207             (ilog2(max_chunk_size - 1) + 1) > MAX_IMM_PAYL_BITS) {
2208                 pr_err("RDMA immediate size (%db) not enough to encode %d buffers of size %dB. Reduce 'sess_queue_depth' or 'max_chunk_size' parameters.\n",
2209                        MAX_IMM_PAYL_BITS, sess_queue_depth, max_chunk_size);
2210                 return -EINVAL;
2211         }
2212
2213         return 0;
2214 }
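
/*
 * Worked example (illustrative, for the default parameters): with
 * sess_queue_depth = 512 and max_chunk_size = 128 KiB the immediate
 * data must encode ilog2(511) + 1 = 9 bits of buffer id plus
 * ilog2(131071) + 1 = 17 bits of chunk offset, 26 bits in total,
 * which fits into MAX_IMM_PAYL_BITS.
 */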
2215
2216 static int __init rtrs_server_init(void)
2217 {
2218         int err;
2219
2220         pr_info("Loading module %s, proto %s: (max_chunk_size: %d (pure IO %ld, headers %ld), sess_queue_depth: %d, always_invalidate: %d)\n",
2221                 KBUILD_MODNAME, RTRS_PROTO_VER_STRING,
2222                 max_chunk_size, max_chunk_size - MAX_HDR_SIZE, MAX_HDR_SIZE,
2223                 sess_queue_depth, always_invalidate);
2224
2225         rtrs_rdma_dev_pd_init(0, &dev_pd);
2226
2227         err = check_module_params();
2228         if (err) {
2229                 pr_err("Failed to load module, invalid module parameters, err: %d\n",
2230                        err);
2231                 return err;
2232         }
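        /*
         * Pool sizing, for illustration with the defaults and 4 KiB
         * pages: a minimum of 512 * 10 = 5120 chunks of
         * get_order(131072) = order-5 pages (128 KiB each) is
         * reserved, backing the CHUNK_POOL_SZ guarantee of serving at
         * least 10 paths.
         */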
2233         chunk_pool = mempool_create_page_pool(sess_queue_depth * CHUNK_POOL_SZ,
2234                                               get_order(max_chunk_size));
2235         if (!chunk_pool)
2236                 return -ENOMEM;
2237         rtrs_dev_class = class_create(THIS_MODULE, "rtrs-server");
2238         if (IS_ERR(rtrs_dev_class)) {
2239                 err = PTR_ERR(rtrs_dev_class);
2240                 goto out_chunk_pool;
2241         }
2242         rtrs_wq = alloc_workqueue("rtrs_server_wq", 0, 0);
2243         if (!rtrs_wq) {
2244                 err = -ENOMEM;
2245                 goto out_dev_class;
2246         }
2247
2248         return 0;
2249
2250 out_dev_class:
2251         class_destroy(rtrs_dev_class);
2252 out_chunk_pool:
2253         mempool_destroy(chunk_pool);
2254
2255         return err;
2256 }
2257
2258 static void __exit rtrs_server_exit(void)
2259 {
2260         destroy_workqueue(rtrs_wq);
2261         class_destroy(rtrs_dev_class);
2262         mempool_destroy(chunk_pool);
2263         rtrs_rdma_dev_pd_deinit(&dev_pd);
2264 }
2265
2266 module_init(rtrs_server_init);
2267 module_exit(rtrs_server_exit);