Mention branches and keyring.
[releases.git] / netfilter / ipvs / ip_vs_mh.c
1 // SPDX-License-Identifier: GPL-2.0
2 /* IPVS:        Maglev Hashing scheduling module
3  *
4  * Authors:     Inju Song <inju.song@navercorp.com>
5  *
6  */
7
8 /* The mh algorithm is to assign a preference list of all the lookup
9  * table positions to each destination and populate the table with
10  * the most-preferred position of destinations. Then it is to select
11  * destination with the hash key of source IP address through looking
12  * up the lookup table.
13  *
14  * The algorithm is detailed in:
15  * [3.4 Consistent Hashing]
16  * https://www.usenix.org/system/files/conference/nsdi16/nsdi16-paper-eisenbud.pdf
17  *
18  */
19
/* Log prefix: pr_fmt() prepends KMSG_COMPONENT and ": " to the format
 * string it is given.
 */
#define KMSG_COMPONENT "IPVS"
#define pr_fmt(fmt) KMSG_COMPONENT ": " fmt
22
23 #include <linux/ip.h>
24 #include <linux/slab.h>
25 #include <linux/module.h>
26 #include <linux/kernel.h>
27 #include <linux/skbuff.h>
28
29 #include <net/ip_vs.h>
30
31 #include <linux/siphash.h>
32 #include <linux/bitops.h>
33 #include <linux/gcd.h>
34
/* Scheduler-specific service flags, mapped onto the generic
 * IP_VS_SVC_F_SCHED1 / IP_VS_SVC_F_SCHED2 bits:
 *  - FALLBACK: ip_vs_mh_schedule() uses ip_vs_mh_get_fallback(), which
 *    probes further table slots when the selected dest is unavailable.
 *  - PORT: ip_vs_mh_schedule() includes the port in the hash input.
 */
#define IP_VS_SVC_F_SCHED_MH_FALLBACK   IP_VS_SVC_F_SCHED1 /* MH fallback */
#define IP_VS_SVC_F_SCHED_MH_PORT       IP_VS_SVC_F_SCHED2 /* MH use port */
37
/* One slot of the MH lookup table. A reference on the dest is taken when
 * it is stored (ip_vs_mh_populate) and dropped when the slot is cleared or
 * overwritten (ip_vs_mh_reset / ip_vs_mh_populate).
 */
struct ip_vs_mh_lookup {
        struct ip_vs_dest __rcu *dest;  /* real server (cache) */
};
41
/* Per-destination scratch state used while (re)building the lookup table.
 * One entry per destination, in svc->destinations list order.
 */
struct ip_vs_mh_dest_setup {
        unsigned int    offset; /* starting offset: hash1 % table size */
        unsigned int    skip;   /* skip: step in [1, table size - 1] */
        unsigned int    perm;   /* next_offset: next position to try */
        int             turns;  /* weight / gcd() and rshift; slots taken
                                 * per visit, < 1 means dest is skipped
                                 */
};
48
/* Available prime numbers for MH table.
 * The table is only ever read (see IP_VS_MH_TAB_SIZE), so it is const.
 */
static const int primes[] = {251, 509, 1021, 2039, 4093,
                             8191, 16381, 32749, 65521, 131071};
52
/* For IPVS MH entry hash table.
 *
 * CONFIG_IP_VS_MH_TAB_INDEX falls back to 12 when it is not defined.
 *  - IP_VS_MH_TAB_INDEX (config value - 8) is the index into primes[],
 *    so the fallback gives primes[4] = 4093 table slots.
 *  - IP_VS_MH_TAB_BITS (config value / 2) is subtracted from the bit
 *    width of max weight / gcd in ip_vs_mh_shift_weight().
 *
 * NOTE(review): the config value must lie in 8..17 for the primes[] index
 * to be in range; presumably enforced by Kconfig - confirm.
 */
#ifndef CONFIG_IP_VS_MH_TAB_INDEX
#define CONFIG_IP_VS_MH_TAB_INDEX       12
#endif
#define IP_VS_MH_TAB_BITS               (CONFIG_IP_VS_MH_TAB_INDEX / 2)
#define IP_VS_MH_TAB_INDEX              (CONFIG_IP_VS_MH_TAB_INDEX - 8)
#define IP_VS_MH_TAB_SIZE               primes[IP_VS_MH_TAB_INDEX]
60
/* Per-service scheduler state, attached to svc->sched_data by
 * ip_vs_mh_init_svc().
 */
struct ip_vs_mh_state {
        /* Used by call_rcu() in ip_vs_mh_done_svc() to free the state */
        struct rcu_head                 rcu_head;
        /* Lookup table with IP_VS_MH_TAB_SIZE slots */
        struct ip_vs_mh_lookup          *lookup;
        /* Temporary array, allocated and freed inside ip_vs_mh_reassign() */
        struct ip_vs_mh_dest_setup      *dest_setup;
        /* hash1: dest offset and packet lookup key; hash2: dest skip key */
        hsiphash_key_t                  hash1, hash2;
        /* gcd of the positive last_weight values, 0 if there are none */
        int                             gcd;
        /* Right shift applied to weight / gcd when computing turns */
        int                             rshift;
};
69
70 static inline void generate_hash_secret(hsiphash_key_t *hash1,
71                                         hsiphash_key_t *hash2)
72 {
73         hash1->key[0] = 2654435761UL;
74         hash1->key[1] = 2654435761UL;
75
76         hash2->key[0] = 2654446892UL;
77         hash2->key[1] = 2654446892UL;
78 }
79
80 /* Helper function to determine if server is unavailable */
81 static inline bool is_unavailable(struct ip_vs_dest *dest)
82 {
83         return atomic_read(&dest->weight) <= 0 ||
84                dest->flags & IP_VS_DEST_F_OVERLOAD;
85 }
86
87 /* Returns hash value for IPVS MH entry */
88 static inline unsigned int
89 ip_vs_mh_hashkey(int af, const union nf_inet_addr *addr,
90                  __be16 port, hsiphash_key_t *key, unsigned int offset)
91 {
92         unsigned int v;
93         __be32 addr_fold = addr->ip;
94
95 #ifdef CONFIG_IP_VS_IPV6
96         if (af == AF_INET6)
97                 addr_fold = addr->ip6[0] ^ addr->ip6[1] ^
98                             addr->ip6[2] ^ addr->ip6[3];
99 #endif
100         v = (offset + ntohs(port) + ntohl(addr_fold));
101         return hsiphash(&v, sizeof(v), key);
102 }
103
104 /* Reset all the hash buckets of the specified table. */
105 static void ip_vs_mh_reset(struct ip_vs_mh_state *s)
106 {
107         int i;
108         struct ip_vs_mh_lookup *l;
109         struct ip_vs_dest *dest;
110
111         l = &s->lookup[0];
112         for (i = 0; i < IP_VS_MH_TAB_SIZE; i++) {
113                 dest = rcu_dereference_protected(l->dest, 1);
114                 if (dest) {
115                         ip_vs_dest_put(dest);
116                         RCU_INIT_POINTER(l->dest, NULL);
117                 }
118                 l++;
119         }
120 }
121
122 static int ip_vs_mh_permutate(struct ip_vs_mh_state *s,
123                               struct ip_vs_service *svc)
124 {
125         struct list_head *p;
126         struct ip_vs_mh_dest_setup *ds;
127         struct ip_vs_dest *dest;
128         int lw;
129
130         /* If gcd is smaller then 1, number of dests or
131          * all last_weight of dests are zero. So, skip
132          * permutation for the dests.
133          */
134         if (s->gcd < 1)
135                 return 0;
136
137         /* Set dest_setup for the dests permutation */
138         p = &svc->destinations;
139         ds = &s->dest_setup[0];
140         while ((p = p->next) != &svc->destinations) {
141                 dest = list_entry(p, struct ip_vs_dest, n_list);
142
143                 ds->offset = ip_vs_mh_hashkey(svc->af, &dest->addr,
144                                               dest->port, &s->hash1, 0) %
145                                               IP_VS_MH_TAB_SIZE;
146                 ds->skip = ip_vs_mh_hashkey(svc->af, &dest->addr,
147                                             dest->port, &s->hash2, 0) %
148                                             (IP_VS_MH_TAB_SIZE - 1) + 1;
149                 ds->perm = ds->offset;
150
151                 lw = atomic_read(&dest->last_weight);
152                 ds->turns = ((lw / s->gcd) >> s->rshift) ? : (lw != 0);
153                 ds++;
154         }
155
156         return 0;
157 }
158
/* Fill the lookup table from the per-destination permutations prepared by
 * ip_vs_mh_permutate().
 *
 * Destinations are visited round-robin in list order. On each visit a
 * destination claims up to ds->turns free positions; a free position is
 * found by stepping ds->perm by ds->skip (mod table size) until an
 * unclaimed one is hit. The walk stops once all IP_VS_MH_TAB_SIZE
 * positions are claimed.
 *
 * Returns 0 on success, -ENOMEM if the temporary bitmap cannot be
 * allocated.
 */
static int ip_vs_mh_populate(struct ip_vs_mh_state *s,
                             struct ip_vs_service *svc)
{
        int n, c, dt_count;
        unsigned long *table;
        struct list_head *p;
        struct ip_vs_mh_dest_setup *ds;
        struct ip_vs_dest *dest, *new_dest;

        /* If gcd is smaller than 1, number of dests or
         * all last_weight of dests are zero. So, skip
         * the population for the dests and reset lookup table.
         */
        if (s->gcd < 1) {
                ip_vs_mh_reset(s);
                return 0;
        }

        /* Bitmap of the table positions already claimed in this rebuild */
        table = bitmap_zalloc(IP_VS_MH_TAB_SIZE, GFP_KERNEL);
        if (!table)
                return -ENOMEM;

        p = &svc->destinations;
        n = 0;          /* number of positions claimed so far */
        dt_count = 0;   /* positions claimed by the current dest this visit */
        while (n < IP_VS_MH_TAB_SIZE) {
                /* Wrap around: restart from the first dest and its setup */
                if (p == &svc->destinations)
                        p = p->next;

                ds = &s->dest_setup[0];
                while (p != &svc->destinations) {
                        /* Ignore added server with zero weight */
                        if (ds->turns < 1) {
                                p = p->next;
                                ds++;
                                continue;
                        }

                        c = ds->perm;
                        while (test_bit(c, table)) {
                                /* Add skip, mod IP_VS_MH_TAB_SIZE */
                                ds->perm += ds->skip;
                                if (ds->perm >= IP_VS_MH_TAB_SIZE)
                                        ds->perm -= IP_VS_MH_TAB_SIZE;
                                c = ds->perm;
                        }

                        __set_bit(c, table);

                        /* Only touch the slot (and the refcounts) when the
                         * owner of this position actually changes.
                         */
                        dest = rcu_dereference_protected(s->lookup[c].dest, 1);
                        new_dest = list_entry(p, struct ip_vs_dest, n_list);
                        if (dest != new_dest) {
                                if (dest)
                                        ip_vs_dest_put(dest);
                                ip_vs_dest_hold(new_dest);
                                RCU_INIT_POINTER(s->lookup[c].dest, new_dest);
                        }

                        if (++n == IP_VS_MH_TAB_SIZE)
                                goto out;

                        /* Move on once this dest has used up its turns */
                        if (++dt_count >= ds->turns) {
                                dt_count = 0;
                                p = p->next;
                                ds++;
                        }
                }
        }

out:
        bitmap_free(table);
        return 0;
}
232
233 /* Get ip_vs_dest associated with supplied parameters. */
234 static inline struct ip_vs_dest *
235 ip_vs_mh_get(struct ip_vs_service *svc, struct ip_vs_mh_state *s,
236              const union nf_inet_addr *addr, __be16 port)
237 {
238         unsigned int hash = ip_vs_mh_hashkey(svc->af, addr, port, &s->hash1, 0)
239                                              % IP_VS_MH_TAB_SIZE;
240         struct ip_vs_dest *dest = rcu_dereference(s->lookup[hash].dest);
241
242         return (!dest || is_unavailable(dest)) ? NULL : dest;
243 }
244
245 /* As ip_vs_mh_get, but with fallback if selected server is unavailable */
246 static inline struct ip_vs_dest *
247 ip_vs_mh_get_fallback(struct ip_vs_service *svc, struct ip_vs_mh_state *s,
248                       const union nf_inet_addr *addr, __be16 port)
249 {
250         unsigned int offset, roffset;
251         unsigned int hash, ihash;
252         struct ip_vs_dest *dest;
253
254         /* First try the dest it's supposed to go to */
255         ihash = ip_vs_mh_hashkey(svc->af, addr, port,
256                                  &s->hash1, 0) % IP_VS_MH_TAB_SIZE;
257         dest = rcu_dereference(s->lookup[ihash].dest);
258         if (!dest)
259                 return NULL;
260         if (!is_unavailable(dest))
261                 return dest;
262
263         IP_VS_DBG_BUF(6, "MH: selected unavailable server %s:%u, reselecting",
264                       IP_VS_DBG_ADDR(dest->af, &dest->addr), ntohs(dest->port));
265
266         /* If the original dest is unavailable, loop around the table
267          * starting from ihash to find a new dest
268          */
269         for (offset = 0; offset < IP_VS_MH_TAB_SIZE; offset++) {
270                 roffset = (offset + ihash) % IP_VS_MH_TAB_SIZE;
271                 hash = ip_vs_mh_hashkey(svc->af, addr, port, &s->hash1,
272                                         roffset) % IP_VS_MH_TAB_SIZE;
273                 dest = rcu_dereference(s->lookup[hash].dest);
274                 if (!dest)
275                         break;
276                 if (!is_unavailable(dest))
277                         return dest;
278                 IP_VS_DBG_BUF(6,
279                               "MH: selected unavailable server %s:%u (offset %u), reselecting",
280                               IP_VS_DBG_ADDR(dest->af, &dest->addr),
281                               ntohs(dest->port), roffset);
282         }
283
284         return NULL;
285 }
286
287 /* Assign all the hash buckets of the specified table with the service. */
288 static int ip_vs_mh_reassign(struct ip_vs_mh_state *s,
289                              struct ip_vs_service *svc)
290 {
291         int ret;
292
293         if (svc->num_dests > IP_VS_MH_TAB_SIZE)
294                 return -EINVAL;
295
296         if (svc->num_dests >= 1) {
297                 s->dest_setup = kcalloc(svc->num_dests,
298                                         sizeof(struct ip_vs_mh_dest_setup),
299                                         GFP_KERNEL);
300                 if (!s->dest_setup)
301                         return -ENOMEM;
302         }
303
304         ip_vs_mh_permutate(s, svc);
305
306         ret = ip_vs_mh_populate(s, svc);
307         if (ret < 0)
308                 goto out;
309
310         IP_VS_DBG_BUF(6, "MH: reassign lookup table of %s:%u\n",
311                       IP_VS_DBG_ADDR(svc->af, &svc->addr),
312                       ntohs(svc->port));
313
314 out:
315         if (svc->num_dests >= 1) {
316                 kfree(s->dest_setup);
317                 s->dest_setup = NULL;
318         }
319         return ret;
320 }
321
322 static int ip_vs_mh_gcd_weight(struct ip_vs_service *svc)
323 {
324         struct ip_vs_dest *dest;
325         int weight;
326         int g = 0;
327
328         list_for_each_entry(dest, &svc->destinations, n_list) {
329                 weight = atomic_read(&dest->last_weight);
330                 if (weight > 0) {
331                         if (g > 0)
332                                 g = gcd(weight, g);
333                         else
334                                 g = weight;
335                 }
336         }
337         return g;
338 }
339
340 /* To avoid assigning huge weight for the MH table,
341  * calculate shift value with gcd.
342  */
343 static int ip_vs_mh_shift_weight(struct ip_vs_service *svc, int gcd)
344 {
345         struct ip_vs_dest *dest;
346         int new_weight, weight = 0;
347         int mw, shift;
348
349         /* If gcd is smaller then 1, number of dests or
350          * all last_weight of dests are zero. So, return
351          * shift value as zero.
352          */
353         if (gcd < 1)
354                 return 0;
355
356         list_for_each_entry(dest, &svc->destinations, n_list) {
357                 new_weight = atomic_read(&dest->last_weight);
358                 if (new_weight > weight)
359                         weight = new_weight;
360         }
361
362         /* Because gcd is greater than zero,
363          * the maximum weight and gcd are always greater than zero
364          */
365         mw = weight / gcd;
366
367         /* shift = occupied bits of weight/gcd - MH highest bits */
368         shift = fls(mw) - IP_VS_MH_TAB_BITS;
369         return (shift >= 0) ? shift : 0;
370 }
371
372 static void ip_vs_mh_state_free(struct rcu_head *head)
373 {
374         struct ip_vs_mh_state *s;
375
376         s = container_of(head, struct ip_vs_mh_state, rcu_head);
377         kfree(s->lookup);
378         kfree(s);
379 }
380
381 static int ip_vs_mh_init_svc(struct ip_vs_service *svc)
382 {
383         int ret;
384         struct ip_vs_mh_state *s;
385
386         /* Allocate the MH table for this service */
387         s = kzalloc(sizeof(*s), GFP_KERNEL);
388         if (!s)
389                 return -ENOMEM;
390
391         s->lookup = kcalloc(IP_VS_MH_TAB_SIZE, sizeof(struct ip_vs_mh_lookup),
392                             GFP_KERNEL);
393         if (!s->lookup) {
394                 kfree(s);
395                 return -ENOMEM;
396         }
397
398         generate_hash_secret(&s->hash1, &s->hash2);
399         s->gcd = ip_vs_mh_gcd_weight(svc);
400         s->rshift = ip_vs_mh_shift_weight(svc, s->gcd);
401
402         IP_VS_DBG(6,
403                   "MH lookup table (memory=%zdbytes) allocated for current service\n",
404                   sizeof(struct ip_vs_mh_lookup) * IP_VS_MH_TAB_SIZE);
405
406         /* Assign the lookup table with current dests */
407         ret = ip_vs_mh_reassign(s, svc);
408         if (ret < 0) {
409                 ip_vs_mh_reset(s);
410                 ip_vs_mh_state_free(&s->rcu_head);
411                 return ret;
412         }
413
414         /* No more failures, attach state */
415         svc->sched_data = s;
416         return 0;
417 }
418
/* Scheduler teardown hook: release every dest reference held by the
 * lookup table and schedule the state for freeing.
 */
static void ip_vs_mh_done_svc(struct ip_vs_service *svc)
{
        struct ip_vs_mh_state *s = svc->sched_data;

        /* Got to clean up lookup entry here */
        ip_vs_mh_reset(s);

        /* The state is freed through call_rcu() rather than directly;
         * presumably because ip_vs_mh_schedule() may still be reading it
         * under RCU - confirm against the caller.
         */
        call_rcu(&s->rcu_head, ip_vs_mh_state_free);
        IP_VS_DBG(6, "MH lookup table (memory=%zdbytes) released\n",
                  sizeof(struct ip_vs_mh_lookup) * IP_VS_MH_TAB_SIZE);
}
430
431 static int ip_vs_mh_dest_changed(struct ip_vs_service *svc,
432                                  struct ip_vs_dest *dest)
433 {
434         struct ip_vs_mh_state *s = svc->sched_data;
435
436         s->gcd = ip_vs_mh_gcd_weight(svc);
437         s->rshift = ip_vs_mh_shift_weight(svc, s->gcd);
438
439         /* Assign the lookup table with the updated service */
440         return ip_vs_mh_reassign(s, svc);
441 }
442
443 /* Helper function to get port number */
444 static inline __be16
445 ip_vs_mh_get_port(const struct sk_buff *skb, struct ip_vs_iphdr *iph)
446 {
447         __be16 _ports[2], *ports;
448
449         /* At this point we know that we have a valid packet of some kind.
450          * Because ICMP packets are only guaranteed to have the first 8
451          * bytes, let's just grab the ports.  Fortunately they're in the
452          * same position for all three of the protocols we care about.
453          */
454         switch (iph->protocol) {
455         case IPPROTO_TCP:
456         case IPPROTO_UDP:
457         case IPPROTO_SCTP:
458                 ports = skb_header_pointer(skb, iph->len, sizeof(_ports),
459                                            &_ports);
460                 if (unlikely(!ports))
461                         return 0;
462
463                 if (likely(!ip_vs_iph_inverse(iph)))
464                         return ports[0];
465                 else
466                         return ports[1];
467         default:
468                 return 0;
469         }
470 }
471
472 /* Maglev Hashing scheduling */
473 static struct ip_vs_dest *
474 ip_vs_mh_schedule(struct ip_vs_service *svc, const struct sk_buff *skb,
475                   struct ip_vs_iphdr *iph)
476 {
477         struct ip_vs_dest *dest;
478         struct ip_vs_mh_state *s;
479         __be16 port = 0;
480         const union nf_inet_addr *hash_addr;
481
482         hash_addr = ip_vs_iph_inverse(iph) ? &iph->daddr : &iph->saddr;
483
484         IP_VS_DBG(6, "%s : Scheduling...\n", __func__);
485
486         if (svc->flags & IP_VS_SVC_F_SCHED_MH_PORT)
487                 port = ip_vs_mh_get_port(skb, iph);
488
489         s = (struct ip_vs_mh_state *)svc->sched_data;
490
491         if (svc->flags & IP_VS_SVC_F_SCHED_MH_FALLBACK)
492                 dest = ip_vs_mh_get_fallback(svc, s, hash_addr, port);
493         else
494                 dest = ip_vs_mh_get(svc, s, hash_addr, port);
495
496         if (!dest) {
497                 ip_vs_scheduler_err(svc, "no destination available");
498                 return NULL;
499         }
500
501         IP_VS_DBG_BUF(6, "MH: source IP address %s:%u --> server %s:%u\n",
502                       IP_VS_DBG_ADDR(svc->af, hash_addr),
503                       ntohs(port),
504                       IP_VS_DBG_ADDR(dest->af, &dest->addr),
505                       ntohs(dest->port));
506
507         return dest;
508 }
509
/* IPVS MH Scheduler structure.
 * add_dest, del_dest and upd_dest all point at ip_vs_mh_dest_changed, so
 * each of those events recomputes gcd/rshift and rebuilds the lookup
 * table.
 */
static struct ip_vs_scheduler ip_vs_mh_scheduler = {
        .name =                 "mh",
        .refcnt =               ATOMIC_INIT(0),
        .module =               THIS_MODULE,
        .n_list  =              LIST_HEAD_INIT(ip_vs_mh_scheduler.n_list),
        .init_service =         ip_vs_mh_init_svc,
        .done_service =         ip_vs_mh_done_svc,
        .add_dest =             ip_vs_mh_dest_changed,
        .del_dest =             ip_vs_mh_dest_changed,
        .upd_dest =             ip_vs_mh_dest_changed,
        .schedule =             ip_vs_mh_schedule,
};
523
/* Module init: register the "mh" scheduler with IPVS and return the
 * result of register_ip_vs_scheduler().
 */
static int __init ip_vs_mh_init(void)
{
        return register_ip_vs_scheduler(&ip_vs_mh_scheduler);
}
528
/* Module exit: unregister the scheduler, then wait for outstanding RCU
 * callbacks so that any ip_vs_mh_state_free() queued by
 * ip_vs_mh_done_svc() has finished before the module goes away.
 */
static void __exit ip_vs_mh_cleanup(void)
{
        unregister_ip_vs_scheduler(&ip_vs_mh_scheduler);
        rcu_barrier();
}
534
/* Module entry/exit points and metadata */
module_init(ip_vs_mh_init);
module_exit(ip_vs_mh_cleanup);
MODULE_DESCRIPTION("Maglev hashing ipvs scheduler");
MODULE_LICENSE("GPL v2");
MODULE_AUTHOR("Inju Song <inju.song@navercorp.com>");