Linux 6.7-rc7
[linux-modified.git] / arch / arm64 / lib / xor-neon.c
1 // SPDX-License-Identifier: GPL-2.0-only
2 /*
3  * arch/arm64/lib/xor-neon.c
4  *
5  * Authors: Jackie Liu <liuyun01@kylinos.cn>
6  * Copyright (C) 2018,Tianjin KYLIN Information Technology Co., Ltd.
7  */
8
9 #include <linux/raid/xor.h>
10 #include <linux/module.h>
11 #include <asm/neon-intrinsics.h>
12
13 static void xor_arm64_neon_2(unsigned long bytes, unsigned long * __restrict p1,
14         const unsigned long * __restrict p2)
15 {
16         uint64_t *dp1 = (uint64_t *)p1;
17         uint64_t *dp2 = (uint64_t *)p2;
18
19         register uint64x2_t v0, v1, v2, v3;
20         long lines = bytes / (sizeof(uint64x2_t) * 4);
21
22         do {
23                 /* p1 ^= p2 */
24                 v0 = veorq_u64(vld1q_u64(dp1 +  0), vld1q_u64(dp2 +  0));
25                 v1 = veorq_u64(vld1q_u64(dp1 +  2), vld1q_u64(dp2 +  2));
26                 v2 = veorq_u64(vld1q_u64(dp1 +  4), vld1q_u64(dp2 +  4));
27                 v3 = veorq_u64(vld1q_u64(dp1 +  6), vld1q_u64(dp2 +  6));
28
29                 /* store */
30                 vst1q_u64(dp1 +  0, v0);
31                 vst1q_u64(dp1 +  2, v1);
32                 vst1q_u64(dp1 +  4, v2);
33                 vst1q_u64(dp1 +  6, v3);
34
35                 dp1 += 8;
36                 dp2 += 8;
37         } while (--lines > 0);
38 }
39
40 static void xor_arm64_neon_3(unsigned long bytes, unsigned long * __restrict p1,
41         const unsigned long * __restrict p2,
42         const unsigned long * __restrict p3)
43 {
44         uint64_t *dp1 = (uint64_t *)p1;
45         uint64_t *dp2 = (uint64_t *)p2;
46         uint64_t *dp3 = (uint64_t *)p3;
47
48         register uint64x2_t v0, v1, v2, v3;
49         long lines = bytes / (sizeof(uint64x2_t) * 4);
50
51         do {
52                 /* p1 ^= p2 */
53                 v0 = veorq_u64(vld1q_u64(dp1 +  0), vld1q_u64(dp2 +  0));
54                 v1 = veorq_u64(vld1q_u64(dp1 +  2), vld1q_u64(dp2 +  2));
55                 v2 = veorq_u64(vld1q_u64(dp1 +  4), vld1q_u64(dp2 +  4));
56                 v3 = veorq_u64(vld1q_u64(dp1 +  6), vld1q_u64(dp2 +  6));
57
58                 /* p1 ^= p3 */
59                 v0 = veorq_u64(v0, vld1q_u64(dp3 +  0));
60                 v1 = veorq_u64(v1, vld1q_u64(dp3 +  2));
61                 v2 = veorq_u64(v2, vld1q_u64(dp3 +  4));
62                 v3 = veorq_u64(v3, vld1q_u64(dp3 +  6));
63
64                 /* store */
65                 vst1q_u64(dp1 +  0, v0);
66                 vst1q_u64(dp1 +  2, v1);
67                 vst1q_u64(dp1 +  4, v2);
68                 vst1q_u64(dp1 +  6, v3);
69
70                 dp1 += 8;
71                 dp2 += 8;
72                 dp3 += 8;
73         } while (--lines > 0);
74 }
75
76 static void xor_arm64_neon_4(unsigned long bytes, unsigned long * __restrict p1,
77         const unsigned long * __restrict p2,
78         const unsigned long * __restrict p3,
79         const unsigned long * __restrict p4)
80 {
81         uint64_t *dp1 = (uint64_t *)p1;
82         uint64_t *dp2 = (uint64_t *)p2;
83         uint64_t *dp3 = (uint64_t *)p3;
84         uint64_t *dp4 = (uint64_t *)p4;
85
86         register uint64x2_t v0, v1, v2, v3;
87         long lines = bytes / (sizeof(uint64x2_t) * 4);
88
89         do {
90                 /* p1 ^= p2 */
91                 v0 = veorq_u64(vld1q_u64(dp1 +  0), vld1q_u64(dp2 +  0));
92                 v1 = veorq_u64(vld1q_u64(dp1 +  2), vld1q_u64(dp2 +  2));
93                 v2 = veorq_u64(vld1q_u64(dp1 +  4), vld1q_u64(dp2 +  4));
94                 v3 = veorq_u64(vld1q_u64(dp1 +  6), vld1q_u64(dp2 +  6));
95
96                 /* p1 ^= p3 */
97                 v0 = veorq_u64(v0, vld1q_u64(dp3 +  0));
98                 v1 = veorq_u64(v1, vld1q_u64(dp3 +  2));
99                 v2 = veorq_u64(v2, vld1q_u64(dp3 +  4));
100                 v3 = veorq_u64(v3, vld1q_u64(dp3 +  6));
101
102                 /* p1 ^= p4 */
103                 v0 = veorq_u64(v0, vld1q_u64(dp4 +  0));
104                 v1 = veorq_u64(v1, vld1q_u64(dp4 +  2));
105                 v2 = veorq_u64(v2, vld1q_u64(dp4 +  4));
106                 v3 = veorq_u64(v3, vld1q_u64(dp4 +  6));
107
108                 /* store */
109                 vst1q_u64(dp1 +  0, v0);
110                 vst1q_u64(dp1 +  2, v1);
111                 vst1q_u64(dp1 +  4, v2);
112                 vst1q_u64(dp1 +  6, v3);
113
114                 dp1 += 8;
115                 dp2 += 8;
116                 dp3 += 8;
117                 dp4 += 8;
118         } while (--lines > 0);
119 }
120
121 static void xor_arm64_neon_5(unsigned long bytes, unsigned long * __restrict p1,
122         const unsigned long * __restrict p2,
123         const unsigned long * __restrict p3,
124         const unsigned long * __restrict p4,
125         const unsigned long * __restrict p5)
126 {
127         uint64_t *dp1 = (uint64_t *)p1;
128         uint64_t *dp2 = (uint64_t *)p2;
129         uint64_t *dp3 = (uint64_t *)p3;
130         uint64_t *dp4 = (uint64_t *)p4;
131         uint64_t *dp5 = (uint64_t *)p5;
132
133         register uint64x2_t v0, v1, v2, v3;
134         long lines = bytes / (sizeof(uint64x2_t) * 4);
135
136         do {
137                 /* p1 ^= p2 */
138                 v0 = veorq_u64(vld1q_u64(dp1 +  0), vld1q_u64(dp2 +  0));
139                 v1 = veorq_u64(vld1q_u64(dp1 +  2), vld1q_u64(dp2 +  2));
140                 v2 = veorq_u64(vld1q_u64(dp1 +  4), vld1q_u64(dp2 +  4));
141                 v3 = veorq_u64(vld1q_u64(dp1 +  6), vld1q_u64(dp2 +  6));
142
143                 /* p1 ^= p3 */
144                 v0 = veorq_u64(v0, vld1q_u64(dp3 +  0));
145                 v1 = veorq_u64(v1, vld1q_u64(dp3 +  2));
146                 v2 = veorq_u64(v2, vld1q_u64(dp3 +  4));
147                 v3 = veorq_u64(v3, vld1q_u64(dp3 +  6));
148
149                 /* p1 ^= p4 */
150                 v0 = veorq_u64(v0, vld1q_u64(dp4 +  0));
151                 v1 = veorq_u64(v1, vld1q_u64(dp4 +  2));
152                 v2 = veorq_u64(v2, vld1q_u64(dp4 +  4));
153                 v3 = veorq_u64(v3, vld1q_u64(dp4 +  6));
154
155                 /* p1 ^= p5 */
156                 v0 = veorq_u64(v0, vld1q_u64(dp5 +  0));
157                 v1 = veorq_u64(v1, vld1q_u64(dp5 +  2));
158                 v2 = veorq_u64(v2, vld1q_u64(dp5 +  4));
159                 v3 = veorq_u64(v3, vld1q_u64(dp5 +  6));
160
161                 /* store */
162                 vst1q_u64(dp1 +  0, v0);
163                 vst1q_u64(dp1 +  2, v1);
164                 vst1q_u64(dp1 +  4, v2);
165                 vst1q_u64(dp1 +  6, v3);
166
167                 dp1 += 8;
168                 dp2 += 8;
169                 dp3 += 8;
170                 dp4 += 8;
171                 dp5 += 8;
172         } while (--lines > 0);
173 }
174
175 struct xor_block_template xor_block_inner_neon __ro_after_init = {
176         .name   = "__inner_neon__",
177         .do_2   = xor_arm64_neon_2,
178         .do_3   = xor_arm64_neon_3,
179         .do_4   = xor_arm64_neon_4,
180         .do_5   = xor_arm64_neon_5,
181 };
182 EXPORT_SYMBOL(xor_block_inner_neon);
183
184 static inline uint64x2_t eor3(uint64x2_t p, uint64x2_t q, uint64x2_t r)
185 {
186         uint64x2_t res;
187
188         asm(ARM64_ASM_PREAMBLE ".arch_extension sha3\n"
189             "eor3 %0.16b, %1.16b, %2.16b, %3.16b"
190             : "=w"(res) : "w"(p), "w"(q), "w"(r));
191         return res;
192 }
193
194 static void xor_arm64_eor3_3(unsigned long bytes,
195         unsigned long * __restrict p1,
196         const unsigned long * __restrict p2,
197         const unsigned long * __restrict p3)
198 {
199         uint64_t *dp1 = (uint64_t *)p1;
200         uint64_t *dp2 = (uint64_t *)p2;
201         uint64_t *dp3 = (uint64_t *)p3;
202
203         register uint64x2_t v0, v1, v2, v3;
204         long lines = bytes / (sizeof(uint64x2_t) * 4);
205
206         do {
207                 /* p1 ^= p2 ^ p3 */
208                 v0 = eor3(vld1q_u64(dp1 + 0), vld1q_u64(dp2 + 0),
209                           vld1q_u64(dp3 + 0));
210                 v1 = eor3(vld1q_u64(dp1 + 2), vld1q_u64(dp2 + 2),
211                           vld1q_u64(dp3 + 2));
212                 v2 = eor3(vld1q_u64(dp1 + 4), vld1q_u64(dp2 + 4),
213                           vld1q_u64(dp3 + 4));
214                 v3 = eor3(vld1q_u64(dp1 + 6), vld1q_u64(dp2 + 6),
215                           vld1q_u64(dp3 + 6));
216
217                 /* store */
218                 vst1q_u64(dp1 + 0, v0);
219                 vst1q_u64(dp1 + 2, v1);
220                 vst1q_u64(dp1 + 4, v2);
221                 vst1q_u64(dp1 + 6, v3);
222
223                 dp1 += 8;
224                 dp2 += 8;
225                 dp3 += 8;
226         } while (--lines > 0);
227 }
228
229 static void xor_arm64_eor3_4(unsigned long bytes,
230         unsigned long * __restrict p1,
231         const unsigned long * __restrict p2,
232         const unsigned long * __restrict p3,
233         const unsigned long * __restrict p4)
234 {
235         uint64_t *dp1 = (uint64_t *)p1;
236         uint64_t *dp2 = (uint64_t *)p2;
237         uint64_t *dp3 = (uint64_t *)p3;
238         uint64_t *dp4 = (uint64_t *)p4;
239
240         register uint64x2_t v0, v1, v2, v3;
241         long lines = bytes / (sizeof(uint64x2_t) * 4);
242
243         do {
244                 /* p1 ^= p2 ^ p3 */
245                 v0 = eor3(vld1q_u64(dp1 + 0), vld1q_u64(dp2 + 0),
246                           vld1q_u64(dp3 + 0));
247                 v1 = eor3(vld1q_u64(dp1 + 2), vld1q_u64(dp2 + 2),
248                           vld1q_u64(dp3 + 2));
249                 v2 = eor3(vld1q_u64(dp1 + 4), vld1q_u64(dp2 + 4),
250                           vld1q_u64(dp3 + 4));
251                 v3 = eor3(vld1q_u64(dp1 + 6), vld1q_u64(dp2 + 6),
252                           vld1q_u64(dp3 + 6));
253
254                 /* p1 ^= p4 */
255                 v0 = veorq_u64(v0, vld1q_u64(dp4 + 0));
256                 v1 = veorq_u64(v1, vld1q_u64(dp4 + 2));
257                 v2 = veorq_u64(v2, vld1q_u64(dp4 + 4));
258                 v3 = veorq_u64(v3, vld1q_u64(dp4 + 6));
259
260                 /* store */
261                 vst1q_u64(dp1 + 0, v0);
262                 vst1q_u64(dp1 + 2, v1);
263                 vst1q_u64(dp1 + 4, v2);
264                 vst1q_u64(dp1 + 6, v3);
265
266                 dp1 += 8;
267                 dp2 += 8;
268                 dp3 += 8;
269                 dp4 += 8;
270         } while (--lines > 0);
271 }
272
273 static void xor_arm64_eor3_5(unsigned long bytes,
274         unsigned long * __restrict p1,
275         const unsigned long * __restrict p2,
276         const unsigned long * __restrict p3,
277         const unsigned long * __restrict p4,
278         const unsigned long * __restrict p5)
279 {
280         uint64_t *dp1 = (uint64_t *)p1;
281         uint64_t *dp2 = (uint64_t *)p2;
282         uint64_t *dp3 = (uint64_t *)p3;
283         uint64_t *dp4 = (uint64_t *)p4;
284         uint64_t *dp5 = (uint64_t *)p5;
285
286         register uint64x2_t v0, v1, v2, v3;
287         long lines = bytes / (sizeof(uint64x2_t) * 4);
288
289         do {
290                 /* p1 ^= p2 ^ p3 */
291                 v0 = eor3(vld1q_u64(dp1 + 0), vld1q_u64(dp2 + 0),
292                           vld1q_u64(dp3 + 0));
293                 v1 = eor3(vld1q_u64(dp1 + 2), vld1q_u64(dp2 + 2),
294                           vld1q_u64(dp3 + 2));
295                 v2 = eor3(vld1q_u64(dp1 + 4), vld1q_u64(dp2 + 4),
296                           vld1q_u64(dp3 + 4));
297                 v3 = eor3(vld1q_u64(dp1 + 6), vld1q_u64(dp2 + 6),
298                           vld1q_u64(dp3 + 6));
299
300                 /* p1 ^= p4 ^ p5 */
301                 v0 = eor3(v0, vld1q_u64(dp4 + 0), vld1q_u64(dp5 + 0));
302                 v1 = eor3(v1, vld1q_u64(dp4 + 2), vld1q_u64(dp5 + 2));
303                 v2 = eor3(v2, vld1q_u64(dp4 + 4), vld1q_u64(dp5 + 4));
304                 v3 = eor3(v3, vld1q_u64(dp4 + 6), vld1q_u64(dp5 + 6));
305
306                 /* store */
307                 vst1q_u64(dp1 + 0, v0);
308                 vst1q_u64(dp1 + 2, v1);
309                 vst1q_u64(dp1 + 4, v2);
310                 vst1q_u64(dp1 + 6, v3);
311
312                 dp1 += 8;
313                 dp2 += 8;
314                 dp3 += 8;
315                 dp4 += 8;
316                 dp5 += 8;
317         } while (--lines > 0);
318 }
319
320 static int __init xor_neon_init(void)
321 {
322         if (IS_ENABLED(CONFIG_AS_HAS_SHA3) && cpu_have_named_feature(SHA3)) {
323                 xor_block_inner_neon.do_3 = xor_arm64_eor3_3;
324                 xor_block_inner_neon.do_4 = xor_arm64_eor3_4;
325                 xor_block_inner_neon.do_5 = xor_arm64_eor3_5;
326         }
327         return 0;
328 }
329 module_init(xor_neon_init);
330
331 static void __exit xor_neon_exit(void)
332 {
333 }
334 module_exit(xor_neon_exit);
335
336 MODULE_AUTHOR("Jackie Liu <liuyun01@kylinos.cn>");
337 MODULE_DESCRIPTION("ARMv8 XOR Extensions");
338 MODULE_LICENSE("GPL");