GNU Linux-libre 6.1.90-gnu
[releases.git] / net / xfrm / xfrm_ipcomp.c
1 // SPDX-License-Identifier: GPL-2.0-or-later
2 /*
3  * IP Payload Compression Protocol (IPComp) - RFC3173.
4  *
5  * Copyright (c) 2003 James Morris <jmorris@intercode.com.au>
6  * Copyright (c) 2003-2008 Herbert Xu <herbert@gondor.apana.org.au>
7  *
8  * Todo:
9  *   - Tunable compression parameters.
10  *   - Compression stats.
11  *   - Adaptive compression.
12  */
13
14 #include <linux/crypto.h>
15 #include <linux/err.h>
16 #include <linux/list.h>
17 #include <linux/module.h>
18 #include <linux/mutex.h>
19 #include <linux/percpu.h>
20 #include <linux/slab.h>
21 #include <linux/smp.h>
22 #include <linux/vmalloc.h>
23 #include <net/ip.h>
24 #include <net/ipcomp.h>
25 #include <net/xfrm.h>
26
27 struct ipcomp_tfms {
28         struct list_head list;
29         struct crypto_comp * __percpu *tfms;
30         int users;
31 };
32
33 static DEFINE_MUTEX(ipcomp_resource_mutex);
34 static void * __percpu *ipcomp_scratches;
35 static int ipcomp_scratch_users;
36 static LIST_HEAD(ipcomp_tfms_list);
37
38 static int ipcomp_decompress(struct xfrm_state *x, struct sk_buff *skb)
39 {
40         struct ipcomp_data *ipcd = x->data;
41         const int plen = skb->len;
42         int dlen = IPCOMP_SCRATCH_SIZE;
43         const u8 *start = skb->data;
44         u8 *scratch = *this_cpu_ptr(ipcomp_scratches);
45         struct crypto_comp *tfm = *this_cpu_ptr(ipcd->tfms);
46         int err = crypto_comp_decompress(tfm, start, plen, scratch, &dlen);
47         int len;
48
49         if (err)
50                 return err;
51
52         if (dlen < (plen + sizeof(struct ip_comp_hdr)))
53                 return -EINVAL;
54
55         len = dlen - plen;
56         if (len > skb_tailroom(skb))
57                 len = skb_tailroom(skb);
58
59         __skb_put(skb, len);
60
61         len += plen;
62         skb_copy_to_linear_data(skb, scratch, len);
63
64         while ((scratch += len, dlen -= len) > 0) {
65                 skb_frag_t *frag;
66                 struct page *page;
67
68                 if (WARN_ON(skb_shinfo(skb)->nr_frags >= MAX_SKB_FRAGS))
69                         return -EMSGSIZE;
70
71                 frag = skb_shinfo(skb)->frags + skb_shinfo(skb)->nr_frags;
72                 page = alloc_page(GFP_ATOMIC);
73
74                 if (!page)
75                         return -ENOMEM;
76
77                 __skb_frag_set_page(frag, page);
78
79                 len = PAGE_SIZE;
80                 if (dlen < len)
81                         len = dlen;
82
83                 skb_frag_off_set(frag, 0);
84                 skb_frag_size_set(frag, len);
85                 memcpy(skb_frag_address(frag), scratch, len);
86
87                 skb->truesize += len;
88                 skb->data_len += len;
89                 skb->len += len;
90
91                 skb_shinfo(skb)->nr_frags++;
92         }
93
94         return 0;
95 }
96
97 int ipcomp_input(struct xfrm_state *x, struct sk_buff *skb)
98 {
99         int nexthdr;
100         int err = -ENOMEM;
101         struct ip_comp_hdr *ipch;
102
103         if (skb_linearize_cow(skb))
104                 goto out;
105
106         skb->ip_summed = CHECKSUM_NONE;
107
108         /* Remove ipcomp header and decompress original payload */
109         ipch = (void *)skb->data;
110         nexthdr = ipch->nexthdr;
111
112         skb->transport_header = skb->network_header + sizeof(*ipch);
113         __skb_pull(skb, sizeof(*ipch));
114         err = ipcomp_decompress(x, skb);
115         if (err)
116                 goto out;
117
118         err = nexthdr;
119
120 out:
121         return err;
122 }
123 EXPORT_SYMBOL_GPL(ipcomp_input);
124
125 static int ipcomp_compress(struct xfrm_state *x, struct sk_buff *skb)
126 {
127         struct ipcomp_data *ipcd = x->data;
128         const int plen = skb->len;
129         int dlen = IPCOMP_SCRATCH_SIZE;
130         u8 *start = skb->data;
131         struct crypto_comp *tfm;
132         u8 *scratch;
133         int err;
134
135         local_bh_disable();
136         scratch = *this_cpu_ptr(ipcomp_scratches);
137         tfm = *this_cpu_ptr(ipcd->tfms);
138         err = crypto_comp_compress(tfm, start, plen, scratch, &dlen);
139         if (err)
140                 goto out;
141
142         if ((dlen + sizeof(struct ip_comp_hdr)) >= plen) {
143                 err = -EMSGSIZE;
144                 goto out;
145         }
146
147         memcpy(start + sizeof(struct ip_comp_hdr), scratch, dlen);
148         local_bh_enable();
149
150         pskb_trim(skb, dlen + sizeof(struct ip_comp_hdr));
151         return 0;
152
153 out:
154         local_bh_enable();
155         return err;
156 }
157
158 int ipcomp_output(struct xfrm_state *x, struct sk_buff *skb)
159 {
160         int err;
161         struct ip_comp_hdr *ipch;
162         struct ipcomp_data *ipcd = x->data;
163
164         if (skb->len < ipcd->threshold) {
165                 /* Don't bother compressing */
166                 goto out_ok;
167         }
168
169         if (skb_linearize_cow(skb))
170                 goto out_ok;
171
172         err = ipcomp_compress(x, skb);
173
174         if (err) {
175                 goto out_ok;
176         }
177
178         /* Install ipcomp header, convert into ipcomp datagram. */
179         ipch = ip_comp_hdr(skb);
180         ipch->nexthdr = *skb_mac_header(skb);
181         ipch->flags = 0;
182         ipch->cpi = htons((u16 )ntohl(x->id.spi));
183         *skb_mac_header(skb) = IPPROTO_COMP;
184 out_ok:
185         skb_push(skb, -skb_network_offset(skb));
186         return 0;
187 }
188 EXPORT_SYMBOL_GPL(ipcomp_output);
189
190 static void ipcomp_free_scratches(void)
191 {
192         int i;
193         void * __percpu *scratches;
194
195         if (--ipcomp_scratch_users)
196                 return;
197
198         scratches = ipcomp_scratches;
199         if (!scratches)
200                 return;
201
202         for_each_possible_cpu(i)
203                 vfree(*per_cpu_ptr(scratches, i));
204
205         free_percpu(scratches);
206         ipcomp_scratches = NULL;
207 }
208
209 static void * __percpu *ipcomp_alloc_scratches(void)
210 {
211         void * __percpu *scratches;
212         int i;
213
214         if (ipcomp_scratch_users++)
215                 return ipcomp_scratches;
216
217         scratches = alloc_percpu(void *);
218         if (!scratches)
219                 return NULL;
220
221         ipcomp_scratches = scratches;
222
223         for_each_possible_cpu(i) {
224                 void *scratch;
225
226                 scratch = vmalloc_node(IPCOMP_SCRATCH_SIZE, cpu_to_node(i));
227                 if (!scratch)
228                         return NULL;
229                 *per_cpu_ptr(scratches, i) = scratch;
230         }
231
232         return scratches;
233 }
234
235 static void ipcomp_free_tfms(struct crypto_comp * __percpu *tfms)
236 {
237         struct ipcomp_tfms *pos;
238         int cpu;
239
240         list_for_each_entry(pos, &ipcomp_tfms_list, list) {
241                 if (pos->tfms == tfms)
242                         break;
243         }
244
245         WARN_ON(list_entry_is_head(pos, &ipcomp_tfms_list, list));
246
247         if (--pos->users)
248                 return;
249
250         list_del(&pos->list);
251         kfree(pos);
252
253         if (!tfms)
254                 return;
255
256         for_each_possible_cpu(cpu) {
257                 struct crypto_comp *tfm = *per_cpu_ptr(tfms, cpu);
258                 crypto_free_comp(tfm);
259         }
260         free_percpu(tfms);
261 }
262
263 static struct crypto_comp * __percpu *ipcomp_alloc_tfms(const char *alg_name)
264 {
265         struct ipcomp_tfms *pos;
266         struct crypto_comp * __percpu *tfms;
267         int cpu;
268
269
270         list_for_each_entry(pos, &ipcomp_tfms_list, list) {
271                 struct crypto_comp *tfm;
272
273                 /* This can be any valid CPU ID so we don't need locking. */
274                 tfm = this_cpu_read(*pos->tfms);
275
276                 if (!strcmp(crypto_comp_name(tfm), alg_name)) {
277                         pos->users++;
278                         return pos->tfms;
279                 }
280         }
281
282         pos = kmalloc(sizeof(*pos), GFP_KERNEL);
283         if (!pos)
284                 return NULL;
285
286         pos->users = 1;
287         INIT_LIST_HEAD(&pos->list);
288         list_add(&pos->list, &ipcomp_tfms_list);
289
290         pos->tfms = tfms = alloc_percpu(struct crypto_comp *);
291         if (!tfms)
292                 goto error;
293
294         for_each_possible_cpu(cpu) {
295                 struct crypto_comp *tfm = crypto_alloc_comp(alg_name, 0,
296                                                             CRYPTO_ALG_ASYNC);
297                 if (IS_ERR(tfm))
298                         goto error;
299                 *per_cpu_ptr(tfms, cpu) = tfm;
300         }
301
302         return tfms;
303
304 error:
305         ipcomp_free_tfms(tfms);
306         return NULL;
307 }
308
309 static void ipcomp_free_data(struct ipcomp_data *ipcd)
310 {
311         if (ipcd->tfms)
312                 ipcomp_free_tfms(ipcd->tfms);
313         ipcomp_free_scratches();
314 }
315
316 void ipcomp_destroy(struct xfrm_state *x)
317 {
318         struct ipcomp_data *ipcd = x->data;
319         if (!ipcd)
320                 return;
321         xfrm_state_delete_tunnel(x);
322         mutex_lock(&ipcomp_resource_mutex);
323         ipcomp_free_data(ipcd);
324         mutex_unlock(&ipcomp_resource_mutex);
325         kfree(ipcd);
326 }
327 EXPORT_SYMBOL_GPL(ipcomp_destroy);
328
329 int ipcomp_init_state(struct xfrm_state *x, struct netlink_ext_ack *extack)
330 {
331         int err;
332         struct ipcomp_data *ipcd;
333         struct xfrm_algo_desc *calg_desc;
334
335         err = -EINVAL;
336         if (!x->calg) {
337                 NL_SET_ERR_MSG(extack, "Missing required compression algorithm");
338                 goto out;
339         }
340
341         if (x->encap) {
342                 NL_SET_ERR_MSG(extack, "IPComp is not compatible with encapsulation");
343                 goto out;
344         }
345
346         err = -ENOMEM;
347         ipcd = kzalloc(sizeof(*ipcd), GFP_KERNEL);
348         if (!ipcd)
349                 goto out;
350
351         mutex_lock(&ipcomp_resource_mutex);
352         if (!ipcomp_alloc_scratches())
353                 goto error;
354
355         ipcd->tfms = ipcomp_alloc_tfms(x->calg->alg_name);
356         if (!ipcd->tfms)
357                 goto error;
358         mutex_unlock(&ipcomp_resource_mutex);
359
360         calg_desc = xfrm_calg_get_byname(x->calg->alg_name, 0);
361         BUG_ON(!calg_desc);
362         ipcd->threshold = calg_desc->uinfo.comp.threshold;
363         x->data = ipcd;
364         err = 0;
365 out:
366         return err;
367
368 error:
369         ipcomp_free_data(ipcd);
370         mutex_unlock(&ipcomp_resource_mutex);
371         kfree(ipcd);
372         goto out;
373 }
374 EXPORT_SYMBOL_GPL(ipcomp_init_state);
375
376 MODULE_LICENSE("GPL");
377 MODULE_DESCRIPTION("IP Payload Compression Protocol (IPComp) - RFC3173");
378 MODULE_AUTHOR("James Morris <jmorris@intercode.com.au>");