Linux 6.7-rc7
[linux-modified.git] / arch / loongarch / lib / csum.c
1 // SPDX-License-Identifier: GPL-2.0-only
2 // Copyright (C) 2019-2020 Arm Ltd.
3
4 #include <linux/compiler.h>
5 #include <linux/kasan-checks.h>
6 #include <linux/kernel.h>
7
8 #include <net/checksum.h>
9
10 static u64 accumulate(u64 sum, u64 data)
11 {
12         sum += data;
13         if (sum < data)
14                 sum += 1;
15         return sum;
16 }
17
18 /*
19  * We over-read the buffer and this makes KASAN unhappy. Instead, disable
20  * instrumentation and call kasan explicitly.
21  */
22 unsigned int __no_sanitize_address do_csum(const unsigned char *buff, int len)
23 {
24         unsigned int offset, shift, sum;
25         const u64 *ptr;
26         u64 data, sum64 = 0;
27
28         if (unlikely(len == 0))
29                 return 0;
30
31         offset = (unsigned long)buff & 7;
32         /*
33          * This is to all intents and purposes safe, since rounding down cannot
34          * result in a different page or cache line being accessed, and @buff
35          * should absolutely not be pointing to anything read-sensitive. We do,
36          * however, have to be careful not to piss off KASAN, which means using
37          * unchecked reads to accommodate the head and tail, for which we'll
38          * compensate with an explicit check up-front.
39          */
40         kasan_check_read(buff, len);
41         ptr = (u64 *)(buff - offset);
42         len = len + offset - 8;
43
44         /*
45          * Head: zero out any excess leading bytes. Shifting back by the same
46          * amount should be at least as fast as any other way of handling the
47          * odd/even alignment, and means we can ignore it until the very end.
48          */
49         shift = offset * 8;
50         data = *ptr++;
51         data = (data >> shift) << shift;
52
53         /*
54          * Body: straightforward aligned loads from here on (the paired loads
55          * underlying the quadword type still only need dword alignment). The
56          * main loop strictly excludes the tail, so the second loop will always
57          * run at least once.
58          */
59         while (unlikely(len > 64)) {
60                 __uint128_t tmp1, tmp2, tmp3, tmp4;
61
62                 tmp1 = *(__uint128_t *)ptr;
63                 tmp2 = *(__uint128_t *)(ptr + 2);
64                 tmp3 = *(__uint128_t *)(ptr + 4);
65                 tmp4 = *(__uint128_t *)(ptr + 6);
66
67                 len -= 64;
68                 ptr += 8;
69
70                 /* This is the "don't dump the carry flag into a GPR" idiom */
71                 tmp1 += (tmp1 >> 64) | (tmp1 << 64);
72                 tmp2 += (tmp2 >> 64) | (tmp2 << 64);
73                 tmp3 += (tmp3 >> 64) | (tmp3 << 64);
74                 tmp4 += (tmp4 >> 64) | (tmp4 << 64);
75                 tmp1 = ((tmp1 >> 64) << 64) | (tmp2 >> 64);
76                 tmp1 += (tmp1 >> 64) | (tmp1 << 64);
77                 tmp3 = ((tmp3 >> 64) << 64) | (tmp4 >> 64);
78                 tmp3 += (tmp3 >> 64) | (tmp3 << 64);
79                 tmp1 = ((tmp1 >> 64) << 64) | (tmp3 >> 64);
80                 tmp1 += (tmp1 >> 64) | (tmp1 << 64);
81                 tmp1 = ((tmp1 >> 64) << 64) | sum64;
82                 tmp1 += (tmp1 >> 64) | (tmp1 << 64);
83                 sum64 = tmp1 >> 64;
84         }
85         while (len > 8) {
86                 __uint128_t tmp;
87
88                 sum64 = accumulate(sum64, data);
89                 tmp = *(__uint128_t *)ptr;
90
91                 len -= 16;
92                 ptr += 2;
93
94                 data = tmp >> 64;
95                 sum64 = accumulate(sum64, tmp);
96         }
97         if (len > 0) {
98                 sum64 = accumulate(sum64, data);
99                 data = *ptr;
100                 len -= 8;
101         }
102         /*
103          * Tail: zero any over-read bytes similarly to the head, again
104          * preserving odd/even alignment.
105          */
106         shift = len * -8;
107         data = (data << shift) >> shift;
108         sum64 = accumulate(sum64, data);
109
110         /* Finally, folding */
111         sum64 += (sum64 >> 32) | (sum64 << 32);
112         sum = sum64 >> 32;
113         sum += (sum >> 16) | (sum << 16);
114         if (offset & 1)
115                 return (u16)swab32(sum);
116
117         return sum >> 16;
118 }
119
120 __sum16 csum_ipv6_magic(const struct in6_addr *saddr,
121                         const struct in6_addr *daddr,
122                         __u32 len, __u8 proto, __wsum csum)
123 {
124         __uint128_t src, dst;
125         u64 sum = (__force u64)csum;
126
127         src = *(const __uint128_t *)saddr->s6_addr;
128         dst = *(const __uint128_t *)daddr->s6_addr;
129
130         sum += (__force u32)htonl(len);
131         sum += (u32)proto << 24;
132         src += (src >> 64) | (src << 64);
133         dst += (dst >> 64) | (dst << 64);
134
135         sum = accumulate(sum, src >> 64);
136         sum = accumulate(sum, dst >> 64);
137
138         sum += ((sum >> 32) | (sum << 32));
139         return csum_fold((__force __wsum)(sum >> 32));
140 }
141 EXPORT_SYMBOL(csum_ipv6_magic);