GNU Linux-libre 5.4.257-gnu1
[releases.git] / net / rxrpc / local_object.c
1 // SPDX-License-Identifier: GPL-2.0-or-later
2 /* Local endpoint object management
3  *
4  * Copyright (C) 2016 Red Hat, Inc. All Rights Reserved.
5  * Written by David Howells (dhowells@redhat.com)
6  */
7
8 #define pr_fmt(fmt) KBUILD_MODNAME ": " fmt
9
10 #include <linux/module.h>
11 #include <linux/net.h>
12 #include <linux/skbuff.h>
13 #include <linux/slab.h>
14 #include <linux/udp.h>
15 #include <linux/ip.h>
16 #include <linux/hashtable.h>
17 #include <net/sock.h>
18 #include <net/udp.h>
19 #include <net/af_rxrpc.h>
20 #include "ar-internal.h"
21
22 static void rxrpc_local_processor(struct work_struct *);
23 static void rxrpc_local_rcu(struct rcu_head *);
24
25 /*
26  * Compare a local to an address.  Return -ve, 0 or +ve to indicate less than,
27  * same or greater than.
28  *
29  * We explicitly don't compare the RxRPC service ID as we want to reject
30  * conflicting uses by differing services.  Further, we don't want to share
31  * addresses with different options (IPv6), so we don't compare those bits
32  * either.
33  */
34 static long rxrpc_local_cmp_key(const struct rxrpc_local *local,
35                                 const struct sockaddr_rxrpc *srx)
36 {
37         long diff;
38
39         diff = ((local->srx.transport_type - srx->transport_type) ?:
40                 (local->srx.transport_len - srx->transport_len) ?:
41                 (local->srx.transport.family - srx->transport.family));
42         if (diff != 0)
43                 return diff;
44
45         switch (srx->transport.family) {
46         case AF_INET:
47                 /* If the choice of UDP port is left up to the transport, then
48                  * the endpoint record doesn't match.
49                  */
50                 return ((u16 __force)local->srx.transport.sin.sin_port -
51                         (u16 __force)srx->transport.sin.sin_port) ?:
52                         memcmp(&local->srx.transport.sin.sin_addr,
53                                &srx->transport.sin.sin_addr,
54                                sizeof(struct in_addr));
55 #ifdef CONFIG_AF_RXRPC_IPV6
56         case AF_INET6:
57                 /* If the choice of UDP6 port is left up to the transport, then
58                  * the endpoint record doesn't match.
59                  */
60                 return ((u16 __force)local->srx.transport.sin6.sin6_port -
61                         (u16 __force)srx->transport.sin6.sin6_port) ?:
62                         memcmp(&local->srx.transport.sin6.sin6_addr,
63                                &srx->transport.sin6.sin6_addr,
64                                sizeof(struct in6_addr));
65 #endif
66         default:
67                 BUG();
68         }
69 }
70
71 /*
72  * Allocate a new local endpoint.
73  */
74 static struct rxrpc_local *rxrpc_alloc_local(struct rxrpc_net *rxnet,
75                                              const struct sockaddr_rxrpc *srx)
76 {
77         struct rxrpc_local *local;
78
79         local = kzalloc(sizeof(struct rxrpc_local), GFP_KERNEL);
80         if (local) {
81                 atomic_set(&local->usage, 1);
82                 atomic_set(&local->active_users, 1);
83                 local->rxnet = rxnet;
84                 INIT_LIST_HEAD(&local->link);
85                 INIT_WORK(&local->processor, rxrpc_local_processor);
86                 init_rwsem(&local->defrag_sem);
87                 skb_queue_head_init(&local->reject_queue);
88                 skb_queue_head_init(&local->event_queue);
89                 local->client_conns = RB_ROOT;
90                 spin_lock_init(&local->client_conns_lock);
91                 spin_lock_init(&local->lock);
92                 rwlock_init(&local->services_lock);
93                 local->debug_id = atomic_inc_return(&rxrpc_debug_id);
94                 memcpy(&local->srx, srx, sizeof(*srx));
95                 local->srx.srx_service = 0;
96                 trace_rxrpc_local(local->debug_id, rxrpc_local_new, 1, NULL);
97         }
98
99         _leave(" = %p", local);
100         return local;
101 }
102
103 /*
104  * create the local socket
105  * - must be called with rxrpc_local_mutex locked
106  */
107 static int rxrpc_open_socket(struct rxrpc_local *local, struct net *net)
108 {
109         struct sock *usk;
110         int ret, opt;
111
112         _enter("%p{%d,%d}",
113                local, local->srx.transport_type, local->srx.transport.family);
114
115         /* create a socket to represent the local endpoint */
116         ret = sock_create_kern(net, local->srx.transport.family,
117                                local->srx.transport_type, 0, &local->socket);
118         if (ret < 0) {
119                 _leave(" = %d [socket]", ret);
120                 return ret;
121         }
122
123         /* set the socket up */
124         usk = local->socket->sk;
125         inet_sk(usk)->mc_loop = 0;
126
127         /* Enable CHECKSUM_UNNECESSARY to CHECKSUM_COMPLETE conversion */
128         inet_inc_convert_csum(usk);
129
130         rcu_assign_sk_user_data(usk, local);
131
132         udp_sk(usk)->encap_type = UDP_ENCAP_RXRPC;
133         udp_sk(usk)->encap_rcv = rxrpc_input_packet;
134         udp_sk(usk)->encap_destroy = NULL;
135         udp_sk(usk)->gro_receive = NULL;
136         udp_sk(usk)->gro_complete = NULL;
137
138         udp_encap_enable();
139 #if IS_ENABLED(CONFIG_AF_RXRPC_IPV6)
140         if (local->srx.transport.family == AF_INET6)
141                 udpv6_encap_enable();
142 #endif
143         usk->sk_error_report = rxrpc_error_report;
144
145         /* if a local address was supplied then bind it */
146         if (local->srx.transport_len > sizeof(sa_family_t)) {
147                 _debug("bind");
148                 ret = kernel_bind(local->socket,
149                                   (struct sockaddr *)&local->srx.transport,
150                                   local->srx.transport_len);
151                 if (ret < 0) {
152                         _debug("bind failed %d", ret);
153                         goto error;
154                 }
155         }
156
157         switch (local->srx.transport.family) {
158         case AF_INET6:
159                 /* we want to receive ICMPv6 errors */
160                 opt = 1;
161                 ret = kernel_setsockopt(local->socket, SOL_IPV6, IPV6_RECVERR,
162                                         (char *) &opt, sizeof(opt));
163                 if (ret < 0) {
164                         _debug("setsockopt failed");
165                         goto error;
166                 }
167
168                 /* Fall through and set IPv4 options too otherwise we don't get
169                  * errors from IPv4 packets sent through the IPv6 socket.
170                  */
171                 /* Fall through */
172         case AF_INET:
173                 /* we want to receive ICMP errors */
174                 opt = 1;
175                 ret = kernel_setsockopt(local->socket, SOL_IP, IP_RECVERR,
176                                         (char *) &opt, sizeof(opt));
177                 if (ret < 0) {
178                         _debug("setsockopt failed");
179                         goto error;
180                 }
181
182                 /* we want to set the don't fragment bit */
183                 opt = IP_PMTUDISC_DO;
184                 ret = kernel_setsockopt(local->socket, SOL_IP, IP_MTU_DISCOVER,
185                                         (char *) &opt, sizeof(opt));
186                 if (ret < 0) {
187                         _debug("setsockopt failed");
188                         goto error;
189                 }
190
191                 /* We want receive timestamps. */
192                 opt = 1;
193                 ret = kernel_setsockopt(local->socket, SOL_SOCKET, SO_TIMESTAMPNS_OLD,
194                                         (char *)&opt, sizeof(opt));
195                 if (ret < 0) {
196                         _debug("setsockopt failed");
197                         goto error;
198                 }
199                 break;
200
201         default:
202                 BUG();
203         }
204
205         _leave(" = 0");
206         return 0;
207
208 error:
209         kernel_sock_shutdown(local->socket, SHUT_RDWR);
210         local->socket->sk->sk_user_data = NULL;
211         sock_release(local->socket);
212         local->socket = NULL;
213
214         _leave(" = %d", ret);
215         return ret;
216 }
217
218 /*
219  * Look up or create a new local endpoint using the specified local address.
220  */
221 struct rxrpc_local *rxrpc_lookup_local(struct net *net,
222                                        const struct sockaddr_rxrpc *srx)
223 {
224         struct rxrpc_local *local;
225         struct rxrpc_net *rxnet = rxrpc_net(net);
226         struct list_head *cursor;
227         const char *age;
228         long diff;
229         int ret;
230
231         _enter("{%d,%d,%pISp}",
232                srx->transport_type, srx->transport.family, &srx->transport);
233
234         mutex_lock(&rxnet->local_mutex);
235
236         for (cursor = rxnet->local_endpoints.next;
237              cursor != &rxnet->local_endpoints;
238              cursor = cursor->next) {
239                 local = list_entry(cursor, struct rxrpc_local, link);
240
241                 diff = rxrpc_local_cmp_key(local, srx);
242                 if (diff < 0)
243                         continue;
244                 if (diff > 0)
245                         break;
246
247                 /* Services aren't allowed to share transport sockets, so
248                  * reject that here.  It is possible that the object is dying -
249                  * but it may also still have the local transport address that
250                  * we want bound.
251                  */
252                 if (srx->srx_service) {
253                         local = NULL;
254                         goto addr_in_use;
255                 }
256
257                 /* Found a match.  We replace a dying object.  Attempting to
258                  * bind the transport socket may still fail if we're attempting
259                  * to use a local address that the dying object is still using.
260                  */
261                 if (!rxrpc_use_local(local))
262                         break;
263
264                 age = "old";
265                 goto found;
266         }
267
268         local = rxrpc_alloc_local(rxnet, srx);
269         if (!local)
270                 goto nomem;
271
272         ret = rxrpc_open_socket(local, net);
273         if (ret < 0)
274                 goto sock_error;
275
276         if (cursor != &rxnet->local_endpoints)
277                 list_replace_init(cursor, &local->link);
278         else
279                 list_add_tail(&local->link, cursor);
280         age = "new";
281
282 found:
283         mutex_unlock(&rxnet->local_mutex);
284
285         _net("LOCAL %s %d {%pISp}",
286              age, local->debug_id, &local->srx.transport);
287
288         _leave(" = %p", local);
289         return local;
290
291 nomem:
292         ret = -ENOMEM;
293 sock_error:
294         mutex_unlock(&rxnet->local_mutex);
295         if (local)
296                 call_rcu(&local->rcu, rxrpc_local_rcu);
297         _leave(" = %d", ret);
298         return ERR_PTR(ret);
299
300 addr_in_use:
301         mutex_unlock(&rxnet->local_mutex);
302         _leave(" = -EADDRINUSE");
303         return ERR_PTR(-EADDRINUSE);
304 }
305
306 /*
307  * Get a ref on a local endpoint.
308  */
309 struct rxrpc_local *rxrpc_get_local(struct rxrpc_local *local)
310 {
311         const void *here = __builtin_return_address(0);
312         int n;
313
314         n = atomic_inc_return(&local->usage);
315         trace_rxrpc_local(local->debug_id, rxrpc_local_got, n, here);
316         return local;
317 }
318
319 /*
320  * Get a ref on a local endpoint unless its usage has already reached 0.
321  */
322 struct rxrpc_local *rxrpc_get_local_maybe(struct rxrpc_local *local)
323 {
324         const void *here = __builtin_return_address(0);
325
326         if (local) {
327                 int n = atomic_fetch_add_unless(&local->usage, 1, 0);
328                 if (n > 0)
329                         trace_rxrpc_local(local->debug_id, rxrpc_local_got,
330                                           n + 1, here);
331                 else
332                         local = NULL;
333         }
334         return local;
335 }
336
337 /*
338  * Queue a local endpoint and pass the caller's reference to the work item.
339  */
340 void rxrpc_queue_local(struct rxrpc_local *local)
341 {
342         const void *here = __builtin_return_address(0);
343         unsigned int debug_id = local->debug_id;
344         int n = atomic_read(&local->usage);
345
346         if (rxrpc_queue_work(&local->processor))
347                 trace_rxrpc_local(debug_id, rxrpc_local_queued, n, here);
348         else
349                 rxrpc_put_local(local);
350 }
351
352 /*
353  * Drop a ref on a local endpoint.
354  */
355 void rxrpc_put_local(struct rxrpc_local *local)
356 {
357         const void *here = __builtin_return_address(0);
358         unsigned int debug_id;
359         int n;
360
361         if (local) {
362                 debug_id = local->debug_id;
363
364                 n = atomic_dec_return(&local->usage);
365                 trace_rxrpc_local(debug_id, rxrpc_local_put, n, here);
366
367                 if (n == 0)
368                         call_rcu(&local->rcu, rxrpc_local_rcu);
369         }
370 }
371
372 /*
373  * Start using a local endpoint.
374  */
375 struct rxrpc_local *rxrpc_use_local(struct rxrpc_local *local)
376 {
377         local = rxrpc_get_local_maybe(local);
378         if (!local)
379                 return NULL;
380
381         if (!__rxrpc_use_local(local)) {
382                 rxrpc_put_local(local);
383                 return NULL;
384         }
385
386         return local;
387 }
388
389 /*
390  * Cease using a local endpoint.  Once the number of active users reaches 0, we
391  * start the closure of the transport in the work processor.
392  */
393 void rxrpc_unuse_local(struct rxrpc_local *local)
394 {
395         if (local) {
396                 if (__rxrpc_unuse_local(local)) {
397                         rxrpc_get_local(local);
398                         rxrpc_queue_local(local);
399                 }
400         }
401 }
402
403 /*
404  * Destroy a local endpoint's socket and then hand the record to RCU to dispose
405  * of.
406  *
407  * Closing the socket cannot be done from bottom half context or RCU callback
408  * context because it might sleep.
409  */
410 static void rxrpc_local_destroyer(struct rxrpc_local *local)
411 {
412         struct socket *socket = local->socket;
413         struct rxrpc_net *rxnet = local->rxnet;
414
415         _enter("%d", local->debug_id);
416
417         local->dead = true;
418
419         mutex_lock(&rxnet->local_mutex);
420         list_del_init(&local->link);
421         mutex_unlock(&rxnet->local_mutex);
422
423         rxrpc_clean_up_local_conns(local);
424         rxrpc_service_connection_reaper(&rxnet->service_conn_reaper);
425         ASSERT(!local->service);
426
427         if (socket) {
428                 local->socket = NULL;
429                 kernel_sock_shutdown(socket, SHUT_RDWR);
430                 socket->sk->sk_user_data = NULL;
431                 sock_release(socket);
432         }
433
434         /* At this point, there should be no more packets coming in to the
435          * local endpoint.
436          */
437         rxrpc_purge_queue(&local->reject_queue);
438         rxrpc_purge_queue(&local->event_queue);
439 }
440
441 /*
442  * Process events on an endpoint.  The work item carries a ref which
443  * we must release.
444  */
445 static void rxrpc_local_processor(struct work_struct *work)
446 {
447         struct rxrpc_local *local =
448                 container_of(work, struct rxrpc_local, processor);
449         bool again;
450
451         if (local->dead)
452                 return;
453
454         trace_rxrpc_local(local->debug_id, rxrpc_local_processing,
455                           atomic_read(&local->usage), NULL);
456
457         do {
458                 again = false;
459                 if (!__rxrpc_use_local(local)) {
460                         rxrpc_local_destroyer(local);
461                         break;
462                 }
463
464                 if (!skb_queue_empty(&local->reject_queue)) {
465                         rxrpc_reject_packets(local);
466                         again = true;
467                 }
468
469                 if (!skb_queue_empty(&local->event_queue)) {
470                         rxrpc_process_local_events(local);
471                         again = true;
472                 }
473
474                 __rxrpc_unuse_local(local);
475         } while (again);
476
477         rxrpc_put_local(local);
478 }
479
480 /*
481  * Destroy a local endpoint after the RCU grace period expires.
482  */
483 static void rxrpc_local_rcu(struct rcu_head *rcu)
484 {
485         struct rxrpc_local *local = container_of(rcu, struct rxrpc_local, rcu);
486
487         _enter("%d", local->debug_id);
488
489         ASSERT(!work_pending(&local->processor));
490
491         _net("DESTROY LOCAL %d", local->debug_id);
492         kfree(local);
493         _leave("");
494 }
495
496 /*
497  * Verify the local endpoint list is empty by this point.
498  */
499 void rxrpc_destroy_all_locals(struct rxrpc_net *rxnet)
500 {
501         struct rxrpc_local *local;
502
503         _enter("");
504
505         flush_workqueue(rxrpc_workqueue);
506
507         if (!list_empty(&rxnet->local_endpoints)) {
508                 mutex_lock(&rxnet->local_mutex);
509                 list_for_each_entry(local, &rxnet->local_endpoints, link) {
510                         pr_err("AF_RXRPC: Leaked local %p {%d}\n",
511                                local, atomic_read(&local->usage));
512                 }
513                 mutex_unlock(&rxnet->local_mutex);
514                 BUG();
515         }
516 }