GNU Linux-libre 6.7.9-gnu
[releases.git] / drivers / net / wireguard / noise.c
1 // SPDX-License-Identifier: GPL-2.0
2 /*
3  * Copyright (C) 2015-2019 Jason A. Donenfeld <Jason@zx2c4.com>. All Rights Reserved.
4  */
5
#include "noise.h"
#include "device.h"
#include "peer.h"
#include "messages.h"
#include "queueing.h"
#include "peerlookup.h"

#include <linux/rcupdate.h>
#include <linux/slab.h>
#include <linux/bitmap.h>
#include <linux/scatterlist.h>
#include <linux/highmem.h>
#include <crypto/utils.h>
#include <asm/unaligned.h>
19
20 /* This implements Noise_IKpsk2:
21  *
22  * <- s
23  * ******
24  * -> e, es, s, ss, {t}
25  * <- e, ee, se, psk, {}
26  */
27
/* Protocol name hashed into the initial chaining key by wg_noise_init().
 * The array length equals the string length, so no NUL terminator is stored
 * and sizeof(handshake_name) covers exactly the 37 name bytes.
 */
static const u8 handshake_name[37] = "Noise_IKpsk2_25519_ChaChaPoly_BLAKE2s";
/* Identifier string mixed into the initial hash by wg_noise_init(); likewise
 * sized without a NUL terminator so sizeof() is exactly 34 bytes.
 */
static const u8 identifier_name[34] = "WireGuard v1 zx2c4 Jason@zx2c4.com";
/* Computed once in wg_noise_init() and read-only afterwards; handshake_init()
 * copies them to seed every new handshake transcript.
 */
static u8 handshake_init_hash[NOISE_HASH_LEN] __ro_after_init;
static u8 handshake_init_chaining_key[NOISE_HASH_LEN] __ro_after_init;
/* Source of keypair->internal_id values (see keypair_create()), which appear
 * in the debug message printed when a keypair is destroyed.
 */
static atomic64_t keypair_counter = ATOMIC64_INIT(0);
33
34 void __init wg_noise_init(void)
35 {
36         struct blake2s_state blake;
37
38         blake2s(handshake_init_chaining_key, handshake_name, NULL,
39                 NOISE_HASH_LEN, sizeof(handshake_name), 0);
40         blake2s_init(&blake, NOISE_HASH_LEN);
41         blake2s_update(&blake, handshake_init_chaining_key, NOISE_HASH_LEN);
42         blake2s_update(&blake, identifier_name, sizeof(identifier_name));
43         blake2s_final(&blake, handshake_init_hash);
44 }
45
46 /* Must hold peer->handshake.static_identity->lock */
47 void wg_noise_precompute_static_static(struct wg_peer *peer)
48 {
49         down_write(&peer->handshake.lock);
50         if (!peer->handshake.static_identity->has_identity ||
51             !curve25519(peer->handshake.precomputed_static_static,
52                         peer->handshake.static_identity->static_private,
53                         peer->handshake.remote_static))
54                 memset(peer->handshake.precomputed_static_static, 0,
55                        NOISE_PUBLIC_KEY_LEN);
56         up_write(&peer->handshake.lock);
57 }
58
59 void wg_noise_handshake_init(struct noise_handshake *handshake,
60                              struct noise_static_identity *static_identity,
61                              const u8 peer_public_key[NOISE_PUBLIC_KEY_LEN],
62                              const u8 peer_preshared_key[NOISE_SYMMETRIC_KEY_LEN],
63                              struct wg_peer *peer)
64 {
65         memset(handshake, 0, sizeof(*handshake));
66         init_rwsem(&handshake->lock);
67         handshake->entry.type = INDEX_HASHTABLE_HANDSHAKE;
68         handshake->entry.peer = peer;
69         memcpy(handshake->remote_static, peer_public_key, NOISE_PUBLIC_KEY_LEN);
70         if (peer_preshared_key)
71                 memcpy(handshake->preshared_key, peer_preshared_key,
72                        NOISE_SYMMETRIC_KEY_LEN);
73         handshake->static_identity = static_identity;
74         handshake->state = HANDSHAKE_ZEROED;
75         wg_noise_precompute_static_static(peer);
76 }
77
78 static void handshake_zero(struct noise_handshake *handshake)
79 {
80         memset(&handshake->ephemeral_private, 0, NOISE_PUBLIC_KEY_LEN);
81         memset(&handshake->remote_ephemeral, 0, NOISE_PUBLIC_KEY_LEN);
82         memset(&handshake->hash, 0, NOISE_HASH_LEN);
83         memset(&handshake->chaining_key, 0, NOISE_HASH_LEN);
84         handshake->remote_index = 0;
85         handshake->state = HANDSHAKE_ZEROED;
86 }
87
88 void wg_noise_handshake_clear(struct noise_handshake *handshake)
89 {
90         down_write(&handshake->lock);
91         wg_index_hashtable_remove(
92                         handshake->entry.peer->device->index_hashtable,
93                         &handshake->entry);
94         handshake_zero(handshake);
95         up_write(&handshake->lock);
96 }
97
98 static struct noise_keypair *keypair_create(struct wg_peer *peer)
99 {
100         struct noise_keypair *keypair = kzalloc(sizeof(*keypair), GFP_KERNEL);
101
102         if (unlikely(!keypair))
103                 return NULL;
104         spin_lock_init(&keypair->receiving_counter.lock);
105         keypair->internal_id = atomic64_inc_return(&keypair_counter);
106         keypair->entry.type = INDEX_HASHTABLE_KEYPAIR;
107         keypair->entry.peer = peer;
108         kref_init(&keypair->refcount);
109         return keypair;
110 }
111
112 static void keypair_free_rcu(struct rcu_head *rcu)
113 {
114         kfree_sensitive(container_of(rcu, struct noise_keypair, rcu));
115 }
116
117 static void keypair_free_kref(struct kref *kref)
118 {
119         struct noise_keypair *keypair =
120                 container_of(kref, struct noise_keypair, refcount);
121
122         net_dbg_ratelimited("%s: Keypair %llu destroyed for peer %llu\n",
123                             keypair->entry.peer->device->dev->name,
124                             keypair->internal_id,
125                             keypair->entry.peer->internal_id);
126         wg_index_hashtable_remove(keypair->entry.peer->device->index_hashtable,
127                                   &keypair->entry);
128         call_rcu(&keypair->rcu, keypair_free_rcu);
129 }
130
131 void wg_noise_keypair_put(struct noise_keypair *keypair, bool unreference_now)
132 {
133         if (unlikely(!keypair))
134                 return;
135         if (unlikely(unreference_now))
136                 wg_index_hashtable_remove(
137                         keypair->entry.peer->device->index_hashtable,
138                         &keypair->entry);
139         kref_put(&keypair->refcount, keypair_free_kref);
140 }
141
142 struct noise_keypair *wg_noise_keypair_get(struct noise_keypair *keypair)
143 {
144         RCU_LOCKDEP_WARN(!rcu_read_lock_bh_held(),
145                 "Taking noise keypair reference without holding the RCU BH read lock");
146         if (unlikely(!keypair || !kref_get_unless_zero(&keypair->refcount)))
147                 return NULL;
148         return keypair;
149 }
150
151 void wg_noise_keypairs_clear(struct noise_keypairs *keypairs)
152 {
153         struct noise_keypair *old;
154
155         spin_lock_bh(&keypairs->keypair_update_lock);
156
157         /* We zero the next_keypair before zeroing the others, so that
158          * wg_noise_received_with_keypair returns early before subsequent ones
159          * are zeroed.
160          */
161         old = rcu_dereference_protected(keypairs->next_keypair,
162                 lockdep_is_held(&keypairs->keypair_update_lock));
163         RCU_INIT_POINTER(keypairs->next_keypair, NULL);
164         wg_noise_keypair_put(old, true);
165
166         old = rcu_dereference_protected(keypairs->previous_keypair,
167                 lockdep_is_held(&keypairs->keypair_update_lock));
168         RCU_INIT_POINTER(keypairs->previous_keypair, NULL);
169         wg_noise_keypair_put(old, true);
170
171         old = rcu_dereference_protected(keypairs->current_keypair,
172                 lockdep_is_held(&keypairs->keypair_update_lock));
173         RCU_INIT_POINTER(keypairs->current_keypair, NULL);
174         wg_noise_keypair_put(old, true);
175
176         spin_unlock_bh(&keypairs->keypair_update_lock);
177 }
178
179 void wg_noise_expire_current_peer_keypairs(struct wg_peer *peer)
180 {
181         struct noise_keypair *keypair;
182
183         wg_noise_handshake_clear(&peer->handshake);
184         wg_noise_reset_last_sent_handshake(&peer->last_sent_handshake);
185
186         spin_lock_bh(&peer->keypairs.keypair_update_lock);
187         keypair = rcu_dereference_protected(peer->keypairs.next_keypair,
188                         lockdep_is_held(&peer->keypairs.keypair_update_lock));
189         if (keypair)
190                 keypair->sending.is_valid = false;
191         keypair = rcu_dereference_protected(peer->keypairs.current_keypair,
192                         lockdep_is_held(&peer->keypairs.keypair_update_lock));
193         if (keypair)
194                 keypair->sending.is_valid = false;
195         spin_unlock_bh(&peer->keypairs.keypair_update_lock);
196 }
197
/* Install new_keypair into the next/current/previous rotation, under
 * keypairs->keypair_update_lock. All three slots are snapshotted up front, so
 * every branch below works with the pre-update values. Keypairs that drop out
 * of the rotation are released with wg_noise_keypair_put(..., true), which
 * also removes them from the index hashtable.
 */
static void add_new_keypair(struct noise_keypairs *keypairs,
			    struct noise_keypair *new_keypair)
{
	struct noise_keypair *previous_keypair, *next_keypair, *current_keypair;

	spin_lock_bh(&keypairs->keypair_update_lock);
	previous_keypair = rcu_dereference_protected(keypairs->previous_keypair,
		lockdep_is_held(&keypairs->keypair_update_lock));
	next_keypair = rcu_dereference_protected(keypairs->next_keypair,
		lockdep_is_held(&keypairs->keypair_update_lock));
	current_keypair = rcu_dereference_protected(keypairs->current_keypair,
		lockdep_is_held(&keypairs->keypair_update_lock));
	if (new_keypair->i_am_the_initiator) {
		/* If we're the initiator, it means we've sent a handshake, and
		 * received a confirmation response, which means this new
		 * keypair can now be used.
		 */
		if (next_keypair) {
			/* If there already was a next keypair pending, we
			 * demote it to be the previous keypair, and free the
			 * existing current. Note that this means KCI can result
			 * in this transition. It would perhaps be more sound to
			 * always just get rid of the unused next keypair
			 * instead of putting it in the previous slot, but this
			 * might be a bit less robust. Something to think about
			 * for the future.
			 */
			RCU_INIT_POINTER(keypairs->next_keypair, NULL);
			rcu_assign_pointer(keypairs->previous_keypair,
					   next_keypair);
			wg_noise_keypair_put(current_keypair, true);
		} else /* If there wasn't an existing next keypair, we replace
			* the previous with the current one.
			*/
			rcu_assign_pointer(keypairs->previous_keypair,
					   current_keypair);
		/* At this point we can get rid of the old previous keypair, and
		 * set up the new keypair.
		 */
		wg_noise_keypair_put(previous_keypair, true);
		rcu_assign_pointer(keypairs->current_keypair, new_keypair);
	} else {
		/* If we're the responder, it means we can't use the new keypair
		 * until we receive confirmation via the first data packet, so
		 * we get rid of the existing previous one, the possibly
		 * existing next one, and slide in the new next one.
		 * current_keypair is left in place in this branch.
		 */
		rcu_assign_pointer(keypairs->next_keypair, new_keypair);
		wg_noise_keypair_put(next_keypair, true);
		RCU_INIT_POINTER(keypairs->previous_keypair, NULL);
		wg_noise_keypair_put(previous_keypair, true);
	}
	spin_unlock_bh(&keypairs->keypair_update_lock);
}
252
/* Called with the keypair an incoming packet was received on (presumably
 * after successful decryption; confirm at call sites). If that keypair is the
 * pending next_keypair, this is the confirmation a responder waits for. The
 * keypair is promoted to current, the old current becomes previous, and the
 * old previous is released.
 *
 * Returns true only when this promotion happened.
 */
bool wg_noise_received_with_keypair(struct noise_keypairs *keypairs,
				    struct noise_keypair *received_keypair)
{
	struct noise_keypair *old_keypair;
	bool key_is_new;

	/* We first check without taking the spinlock. */
	/* rcu_access_pointer() only fetches the pointer value for comparison;
	 * it is never dereferenced here, so this fast path is safe without the
	 * lock.
	 */
	key_is_new = received_keypair ==
		     rcu_access_pointer(keypairs->next_keypair);
	if (likely(!key_is_new))
		return false;

	spin_lock_bh(&keypairs->keypair_update_lock);
	/* After locking, we double check that things didn't change from
	 * beneath us.
	 */
	if (unlikely(received_keypair !=
		    rcu_dereference_protected(keypairs->next_keypair,
			    lockdep_is_held(&keypairs->keypair_update_lock)))) {
		spin_unlock_bh(&keypairs->keypair_update_lock);
		return false;
	}

	/* When we've finally received the confirmation, we slide the next
	 * into the current, the current into the previous, and get rid of
	 * the old previous.
	 */
	old_keypair = rcu_dereference_protected(keypairs->previous_keypair,
		lockdep_is_held(&keypairs->keypair_update_lock));
	rcu_assign_pointer(keypairs->previous_keypair,
		rcu_dereference_protected(keypairs->current_keypair,
			lockdep_is_held(&keypairs->keypair_update_lock)));
	wg_noise_keypair_put(old_keypair, true);
	rcu_assign_pointer(keypairs->current_keypair, received_keypair);
	RCU_INIT_POINTER(keypairs->next_keypair, NULL);

	spin_unlock_bh(&keypairs->keypair_update_lock);
	return true;
}
292
293 /* Must hold static_identity->lock */
294 void wg_noise_set_static_identity_private_key(
295         struct noise_static_identity *static_identity,
296         const u8 private_key[NOISE_PUBLIC_KEY_LEN])
297 {
298         memcpy(static_identity->static_private, private_key,
299                NOISE_PUBLIC_KEY_LEN);
300         curve25519_clamp_secret(static_identity->static_private);
301         static_identity->has_identity = curve25519_generate_public(
302                 static_identity->static_public, private_key);
303 }
304
305 static void hmac(u8 *out, const u8 *in, const u8 *key, const size_t inlen, const size_t keylen)
306 {
307         struct blake2s_state state;
308         u8 x_key[BLAKE2S_BLOCK_SIZE] __aligned(__alignof__(u32)) = { 0 };
309         u8 i_hash[BLAKE2S_HASH_SIZE] __aligned(__alignof__(u32));
310         int i;
311
312         if (keylen > BLAKE2S_BLOCK_SIZE) {
313                 blake2s_init(&state, BLAKE2S_HASH_SIZE);
314                 blake2s_update(&state, key, keylen);
315                 blake2s_final(&state, x_key);
316         } else
317                 memcpy(x_key, key, keylen);
318
319         for (i = 0; i < BLAKE2S_BLOCK_SIZE; ++i)
320                 x_key[i] ^= 0x36;
321
322         blake2s_init(&state, BLAKE2S_HASH_SIZE);
323         blake2s_update(&state, x_key, BLAKE2S_BLOCK_SIZE);
324         blake2s_update(&state, in, inlen);
325         blake2s_final(&state, i_hash);
326
327         for (i = 0; i < BLAKE2S_BLOCK_SIZE; ++i)
328                 x_key[i] ^= 0x5c ^ 0x36;
329
330         blake2s_init(&state, BLAKE2S_HASH_SIZE);
331         blake2s_update(&state, x_key, BLAKE2S_BLOCK_SIZE);
332         blake2s_update(&state, i_hash, BLAKE2S_HASH_SIZE);
333         blake2s_final(&state, i_hash);
334
335         memcpy(out, i_hash, BLAKE2S_HASH_SIZE);
336         memzero_explicit(x_key, BLAKE2S_BLOCK_SIZE);
337         memzero_explicit(i_hash, BLAKE2S_HASH_SIZE);
338 }
339
/* This is Hugo Krawczyk's HKDF:
 *  - https://eprint.iacr.org/2010/264.pdf
 *  - https://tools.ietf.org/html/rfc5869
 *
 * Extract: secret = HMAC(chaining_key, data[0..data_len)).
 * Expand:  T1 = HMAC(secret, 0x1)
 *          T2 = HMAC(secret, T1 || 0x2)
 *          T3 = HMAC(secret, T2 || 0x3)
 * The first first_len / second_len / third_len bytes of T1 / T2 / T3 are
 * copied to first_dst / second_dst / third_dst. Expansion stops at the first
 * output whose destination is NULL or whose length is 0, so outputs must be
 * requested in order.
 *
 * Each length should be at most BLAKE2S_HASH_SIZE. The WARN_ON below reports
 * misuse only when DEBUG is enabled; nothing is enforced otherwise.
 *
 * chaining_key may alias first_dst (callers pass the same buffer). It is only
 * read by the extract step, before any output is written.
 */
static void kdf(u8 *first_dst, u8 *second_dst, u8 *third_dst, const u8 *data,
		size_t first_len, size_t second_len, size_t third_len,
		size_t data_len, const u8 chaining_key[NOISE_HASH_LEN])
{
	/* Holds T(i) in its first BLAKE2S_HASH_SIZE bytes, plus one spare byte
	 * for the next counter, so "T(i) || i+1" can be fed straight back into
	 * hmac(), which consumes its input before writing its output.
	 */
	u8 output[BLAKE2S_HASH_SIZE + 1];
	u8 secret[BLAKE2S_HASH_SIZE];

	WARN_ON(IS_ENABLED(DEBUG) &&
		(first_len > BLAKE2S_HASH_SIZE ||
		 second_len > BLAKE2S_HASH_SIZE ||
		 third_len > BLAKE2S_HASH_SIZE ||
		 ((second_len || second_dst || third_len || third_dst) &&
		  (!first_len || !first_dst)) ||
		 ((third_len || third_dst) && (!second_len || !second_dst))));

	/* Extract entropy from data into secret */
	hmac(secret, data, chaining_key, data_len, NOISE_HASH_LEN);

	if (!first_dst || !first_len)
		goto out;

	/* Expand first key: key = secret, data = 0x1 */
	output[0] = 1;
	hmac(output, output, secret, 1, BLAKE2S_HASH_SIZE);
	memcpy(first_dst, output, first_len);

	if (!second_dst || !second_len)
		goto out;

	/* Expand second key: key = secret, data = first-key || 0x2 */
	output[BLAKE2S_HASH_SIZE] = 2;
	hmac(output, output, secret, BLAKE2S_HASH_SIZE + 1, BLAKE2S_HASH_SIZE);
	memcpy(second_dst, output, second_len);

	if (!third_dst || !third_len)
		goto out;

	/* Expand third key: key = secret, data = second-key || 0x3 */
	output[BLAKE2S_HASH_SIZE] = 3;
	hmac(output, output, secret, BLAKE2S_HASH_SIZE + 1, BLAKE2S_HASH_SIZE);
	memcpy(third_dst, output, third_len);

out:
	/* Clear sensitive data from stack */
	memzero_explicit(secret, BLAKE2S_HASH_SIZE);
	memzero_explicit(output, BLAKE2S_HASH_SIZE + 1);
}
391
392 static void derive_keys(struct noise_symmetric_key *first_dst,
393                         struct noise_symmetric_key *second_dst,
394                         const u8 chaining_key[NOISE_HASH_LEN])
395 {
396         u64 birthdate = ktime_get_coarse_boottime_ns();
397         kdf(first_dst->key, second_dst->key, NULL, NULL,
398             NOISE_SYMMETRIC_KEY_LEN, NOISE_SYMMETRIC_KEY_LEN, 0, 0,
399             chaining_key);
400         first_dst->birthdate = second_dst->birthdate = birthdate;
401         first_dst->is_valid = second_dst->is_valid = true;
402 }
403
404 static bool __must_check mix_dh(u8 chaining_key[NOISE_HASH_LEN],
405                                 u8 key[NOISE_SYMMETRIC_KEY_LEN],
406                                 const u8 private[NOISE_PUBLIC_KEY_LEN],
407                                 const u8 public[NOISE_PUBLIC_KEY_LEN])
408 {
409         u8 dh_calculation[NOISE_PUBLIC_KEY_LEN];
410
411         if (unlikely(!curve25519(dh_calculation, private, public)))
412                 return false;
413         kdf(chaining_key, key, NULL, dh_calculation, NOISE_HASH_LEN,
414             NOISE_SYMMETRIC_KEY_LEN, 0, NOISE_PUBLIC_KEY_LEN, chaining_key);
415         memzero_explicit(dh_calculation, NOISE_PUBLIC_KEY_LEN);
416         return true;
417 }
418
419 static bool __must_check mix_precomputed_dh(u8 chaining_key[NOISE_HASH_LEN],
420                                             u8 key[NOISE_SYMMETRIC_KEY_LEN],
421                                             const u8 precomputed[NOISE_PUBLIC_KEY_LEN])
422 {
423         static u8 zero_point[NOISE_PUBLIC_KEY_LEN];
424         if (unlikely(!crypto_memneq(precomputed, zero_point, NOISE_PUBLIC_KEY_LEN)))
425                 return false;
426         kdf(chaining_key, key, NULL, precomputed, NOISE_HASH_LEN,
427             NOISE_SYMMETRIC_KEY_LEN, 0, NOISE_PUBLIC_KEY_LEN,
428             chaining_key);
429         return true;
430 }
431
432 static void mix_hash(u8 hash[NOISE_HASH_LEN], const u8 *src, size_t src_len)
433 {
434         struct blake2s_state blake;
435
436         blake2s_init(&blake, NOISE_HASH_LEN);
437         blake2s_update(&blake, hash, NOISE_HASH_LEN);
438         blake2s_update(&blake, src, src_len);
439         blake2s_final(&blake, hash);
440 }
441
442 static void mix_psk(u8 chaining_key[NOISE_HASH_LEN], u8 hash[NOISE_HASH_LEN],
443                     u8 key[NOISE_SYMMETRIC_KEY_LEN],
444                     const u8 psk[NOISE_SYMMETRIC_KEY_LEN])
445 {
446         u8 temp_hash[NOISE_HASH_LEN];
447
448         kdf(chaining_key, temp_hash, key, psk, NOISE_HASH_LEN, NOISE_HASH_LEN,
449             NOISE_SYMMETRIC_KEY_LEN, NOISE_SYMMETRIC_KEY_LEN, chaining_key);
450         mix_hash(hash, temp_hash, NOISE_HASH_LEN);
451         memzero_explicit(temp_hash, NOISE_HASH_LEN);
452 }
453
454 static void handshake_init(u8 chaining_key[NOISE_HASH_LEN],
455                            u8 hash[NOISE_HASH_LEN],
456                            const u8 remote_static[NOISE_PUBLIC_KEY_LEN])
457 {
458         memcpy(hash, handshake_init_hash, NOISE_HASH_LEN);
459         memcpy(chaining_key, handshake_init_chaining_key, NOISE_HASH_LEN);
460         mix_hash(hash, remote_static, NOISE_PUBLIC_KEY_LEN);
461 }
462
463 static void message_encrypt(u8 *dst_ciphertext, const u8 *src_plaintext,
464                             size_t src_len, u8 key[NOISE_SYMMETRIC_KEY_LEN],
465                             u8 hash[NOISE_HASH_LEN])
466 {
467         chacha20poly1305_encrypt(dst_ciphertext, src_plaintext, src_len, hash,
468                                  NOISE_HASH_LEN,
469                                  0 /* Always zero for Noise_IK */, key);
470         mix_hash(hash, dst_ciphertext, noise_encrypted_len(src_len));
471 }
472
473 static bool message_decrypt(u8 *dst_plaintext, const u8 *src_ciphertext,
474                             size_t src_len, u8 key[NOISE_SYMMETRIC_KEY_LEN],
475                             u8 hash[NOISE_HASH_LEN])
476 {
477         if (!chacha20poly1305_decrypt(dst_plaintext, src_ciphertext, src_len,
478                                       hash, NOISE_HASH_LEN,
479                                       0 /* Always zero for Noise_IK */, key))
480                 return false;
481         mix_hash(hash, src_ciphertext, src_len);
482         return true;
483 }
484
485 static void message_ephemeral(u8 ephemeral_dst[NOISE_PUBLIC_KEY_LEN],
486                               const u8 ephemeral_src[NOISE_PUBLIC_KEY_LEN],
487                               u8 chaining_key[NOISE_HASH_LEN],
488                               u8 hash[NOISE_HASH_LEN])
489 {
490         if (ephemeral_dst != ephemeral_src)
491                 memcpy(ephemeral_dst, ephemeral_src, NOISE_PUBLIC_KEY_LEN);
492         mix_hash(hash, ephemeral_src, NOISE_PUBLIC_KEY_LEN);
493         kdf(chaining_key, NULL, NULL, ephemeral_src, NOISE_HASH_LEN, 0, 0,
494             NOISE_PUBLIC_KEY_LEN, chaining_key);
495 }
496
497 static void tai64n_now(u8 output[NOISE_TIMESTAMP_LEN])
498 {
499         struct timespec64 now;
500
501         ktime_get_real_ts64(&now);
502
503         /* In order to prevent some sort of infoleak from precise timers, we
504          * round down the nanoseconds part to the closest rounded-down power of
505          * two to the maximum initiations per second allowed anyway by the
506          * implementation.
507          */
508         now.tv_nsec = ALIGN_DOWN(now.tv_nsec,
509                 rounddown_pow_of_two(NSEC_PER_SEC / INITIATIONS_PER_SECOND));
510
511         /* https://cr.yp.to/libtai/tai64.html */
512         *(__be64 *)output = cpu_to_be64(0x400000000000000aULL + now.tv_sec);
513         *(__be32 *)(output + sizeof(__be64)) = cpu_to_be32(now.tv_nsec);
514 }
515
/* Build a handshake initiation (Noise IK: "-> e, es, s, ss, {t}") into dst.
 * The running chaining key, transcript hash and the freshly generated
 * ephemeral private key are kept in handshake.
 *
 * On success, handshake->entry is inserted into the device's index
 * hashtable, and the index it returns is sent as dst->sender_index. The state
 * becomes HANDSHAKE_CREATED_INITIATION and true is returned.
 *
 * Returns false if there is no static identity or a key-generation/DH step
 * fails. In that case dst and the handshake's chaining key/hash may already
 * be partially written, but handshake->state is not changed.
 *
 * Takes static_identity->lock for reading, then handshake->lock for writing.
 * May sleep (waits for the RNG and takes rw_semaphores).
 */
bool
wg_noise_handshake_create_initiation(struct message_handshake_initiation *dst,
				     struct noise_handshake *handshake)
{
	u8 timestamp[NOISE_TIMESTAMP_LEN];
	u8 key[NOISE_SYMMETRIC_KEY_LEN];
	bool ret = false;

	/* We need to wait for crng _before_ taking any locks, since
	 * curve25519_generate_secret uses get_random_bytes_wait.
	 */
	wait_for_random_bytes();

	down_read(&handshake->static_identity->lock);
	down_write(&handshake->lock);

	if (unlikely(!handshake->static_identity->has_identity))
		goto out;

	dst->header.type = cpu_to_le32(MESSAGE_HANDSHAKE_INITIATION);

	/* Transcript starts from the global constants bound to the
	 * responder's (i.e. the peer's) static public key.
	 */
	handshake_init(handshake->chaining_key, handshake->hash,
		       handshake->remote_static);

	/* e */
	curve25519_generate_secret(handshake->ephemeral_private);
	if (!curve25519_generate_public(dst->unencrypted_ephemeral,
					handshake->ephemeral_private))
		goto out;
	/* dst and src are the same buffer, so only the hash/chaining-key
	 * mixing happens here.
	 */
	message_ephemeral(dst->unencrypted_ephemeral,
			  dst->unencrypted_ephemeral, handshake->chaining_key,
			  handshake->hash);

	/* es */
	/* DH(our ephemeral, their static); also yields the key for "s". */
	if (!mix_dh(handshake->chaining_key, key, handshake->ephemeral_private,
		    handshake->remote_static))
		goto out;

	/* s */
	/* Our static public key, encrypted and absorbed into the hash. */
	message_encrypt(dst->encrypted_static,
			handshake->static_identity->static_public,
			NOISE_PUBLIC_KEY_LEN, key, handshake->hash);

	/* ss */
	/* Cached DH(our static, their static); rekeys for the timestamp. */
	if (!mix_precomputed_dh(handshake->chaining_key, key,
				handshake->precomputed_static_static))
		goto out;

	/* {t} */
	/* Encrypted TAI64N timestamp; the responder rejects initiations whose
	 * timestamp is not newer than the last one it accepted.
	 */
	tai64n_now(timestamp);
	message_encrypt(dst->encrypted_timestamp, timestamp,
			NOISE_TIMESTAMP_LEN, key, handshake->hash);

	dst->sender_index = wg_index_hashtable_insert(
		handshake->entry.peer->device->index_hashtable,
		&handshake->entry);

	handshake->state = HANDSHAKE_CREATED_INITIATION;
	ret = true;

out:
	up_write(&handshake->lock);
	up_read(&handshake->static_identity->lock);
	/* The message key never leaves this function; scrub it. */
	memzero_explicit(key, NOISE_SYMMETRIC_KEY_LEN);
	return ret;
}
582
/* Process a received handshake initiation as the responder.
 *
 * The whole transcript (chaining key, hash, message key) is computed in
 * local buffers. The sending peer is identified by the static public key
 * decrypted from the message. Only if every step succeeds and the message
 * passes the replay and flood checks is the result copied into
 * peer->handshake, with state set to HANDSHAKE_CONSUMED_INITIATION.
 *
 * Returns that peer on success. On any failure it returns NULL, and the peer
 * obtained from wg_pubkey_hashtable_lookup(), if any, is released with
 * wg_peer_put(). The successful return presumably hands the lookup's
 * reference to the caller; confirm against wg_pubkey_hashtable_lookup().
 *
 * Holds wg->static_identity.lock for reading throughout.
 */
struct wg_peer *
wg_noise_handshake_consume_initiation(struct message_handshake_initiation *src,
				      struct wg_device *wg)
{
	struct wg_peer *peer = NULL, *ret_peer = NULL;
	struct noise_handshake *handshake;
	bool replay_attack, flood_attack;
	u8 key[NOISE_SYMMETRIC_KEY_LEN];
	u8 chaining_key[NOISE_HASH_LEN];
	u8 hash[NOISE_HASH_LEN];
	u8 s[NOISE_PUBLIC_KEY_LEN];
	u8 e[NOISE_PUBLIC_KEY_LEN];
	u8 t[NOISE_TIMESTAMP_LEN];
	u64 initiation_consumption;

	down_read(&wg->static_identity.lock);
	if (unlikely(!wg->static_identity.has_identity))
		goto out;

	/* We are the responder, so the transcript is bound to our own key. */
	handshake_init(chaining_key, hash, wg->static_identity.static_public);

	/* e */
	message_ephemeral(e, src->unencrypted_ephemeral, chaining_key, hash);

	/* es */
	if (!mix_dh(chaining_key, key, wg->static_identity.static_private, e))
		goto out;

	/* s */
	/* Decrypting the initiator's static key also authenticates it. */
	if (!message_decrypt(s, src->encrypted_static,
			     sizeof(src->encrypted_static), key, hash))
		goto out;

	/* Lookup which peer we're actually talking to */
	peer = wg_pubkey_hashtable_lookup(wg->peer_hashtable, s);
	if (!peer)
		goto out;
	handshake = &peer->handshake;

	/* ss */
	if (!mix_precomputed_dh(chaining_key, key,
				handshake->precomputed_static_static))
	    goto out;

	/* {t} */
	if (!message_decrypt(t, src->encrypted_timestamp,
			     sizeof(src->encrypted_timestamp), key, hash))
		goto out;

	/* Replay: the TAI64N timestamp is big-endian, so memcmp() orders it
	 * chronologically and it must be strictly newer than the last accepted
	 * one. Flood: reject if less than NSEC_PER_SEC / INITIATIONS_PER_SECOND
	 * has passed since the last accepted initiation.
	 */
	down_read(&handshake->lock);
	replay_attack = memcmp(t, handshake->latest_timestamp,
			       NOISE_TIMESTAMP_LEN) <= 0;
	flood_attack = (s64)handshake->last_initiation_consumption +
			       NSEC_PER_SEC / INITIATIONS_PER_SECOND >
		       (s64)ktime_get_coarse_boottime_ns();
	up_read(&handshake->lock);
	if (replay_attack || flood_attack)
		goto out;

	/* Success! Copy everything to peer */
	/* The lock was dropped between the checks and here, so the timestamp
	 * and consumption time are re-compared and only ever moved forward.
	 */
	down_write(&handshake->lock);
	memcpy(handshake->remote_ephemeral, e, NOISE_PUBLIC_KEY_LEN);
	if (memcmp(t, handshake->latest_timestamp, NOISE_TIMESTAMP_LEN) > 0)
		memcpy(handshake->latest_timestamp, t, NOISE_TIMESTAMP_LEN);
	memcpy(handshake->hash, hash, NOISE_HASH_LEN);
	memcpy(handshake->chaining_key, chaining_key, NOISE_HASH_LEN);
	handshake->remote_index = src->sender_index;
	initiation_consumption = ktime_get_coarse_boottime_ns();
	if ((s64)(handshake->last_initiation_consumption - initiation_consumption) < 0)
		handshake->last_initiation_consumption = initiation_consumption;
	handshake->state = HANDSHAKE_CONSUMED_INITIATION;
	up_write(&handshake->lock);
	ret_peer = peer;

out:
	/* Scrub the local transcript secrets on every path. */
	memzero_explicit(key, NOISE_SYMMETRIC_KEY_LEN);
	memzero_explicit(hash, NOISE_HASH_LEN);
	memzero_explicit(chaining_key, NOISE_HASH_LEN);
	up_read(&wg->static_identity.lock);
	if (!ret_peer)
		wg_peer_put(peer);
	return ret_peer;
}
666
/* Build the handshake response (Noise IK: "<- e, ee, se, psk, {}") for a
 * handshake in state HANDSHAKE_CONSUMED_INITIATION, i.e. one accepted by
 * wg_noise_handshake_consume_initiation(). Any other state fails.
 *
 * On success, dst is filled in and handshake->entry is inserted into the
 * index hashtable, with the index it returns sent as dst->sender_index. The
 * state becomes HANDSHAKE_CREATED_RESPONSE and true is returned.
 *
 * On failure, dst and the handshake's chaining key/hash may be partially
 * updated, but handshake->state is not changed.
 *
 * Takes static_identity->lock for reading, then handshake->lock for writing.
 * May sleep (waits for the RNG and takes rw_semaphores).
 */
bool wg_noise_handshake_create_response(struct message_handshake_response *dst,
					struct noise_handshake *handshake)
{
	u8 key[NOISE_SYMMETRIC_KEY_LEN];
	bool ret = false;

	/* We need to wait for crng _before_ taking any locks, since
	 * curve25519_generate_secret uses get_random_bytes_wait.
	 */
	wait_for_random_bytes();

	down_read(&handshake->static_identity->lock);
	down_write(&handshake->lock);

	if (handshake->state != HANDSHAKE_CONSUMED_INITIATION)
		goto out;

	dst->header.type = cpu_to_le32(MESSAGE_HANDSHAKE_RESPONSE);
	dst->receiver_index = handshake->remote_index;

	/* e */
	curve25519_generate_secret(handshake->ephemeral_private);
	if (!curve25519_generate_public(dst->unencrypted_ephemeral,
					handshake->ephemeral_private))
		goto out;
	message_ephemeral(dst->unencrypted_ephemeral,
			  dst->unencrypted_ephemeral, handshake->chaining_key,
			  handshake->hash);

	/* ee */
	/* DH(our ephemeral, their ephemeral). key is NULL: only the chaining
	 * key is updated.
	 */
	if (!mix_dh(handshake->chaining_key, NULL, handshake->ephemeral_private,
		    handshake->remote_ephemeral))
		goto out;

	/* se */
	/* DH(our ephemeral, the initiator's static key). */
	if (!mix_dh(handshake->chaining_key, NULL, handshake->ephemeral_private,
		    handshake->remote_static))
		goto out;

	/* psk */
	/* preshared_key is left all-zero by wg_noise_handshake_init() when no
	 * key was supplied. This step also produces the key for "{}".
	 */
	mix_psk(handshake->chaining_key, handshake->hash, key,
		handshake->preshared_key);

	/* {} */
	/* Empty payload: only the AEAD output for zero bytes is sent. */
	message_encrypt(dst->encrypted_nothing, NULL, 0, key, handshake->hash);

	dst->sender_index = wg_index_hashtable_insert(
		handshake->entry.peer->device->index_hashtable,
		&handshake->entry);

	handshake->state = HANDSHAKE_CREATED_RESPONSE;
	ret = true;

out:
	up_write(&handshake->lock);
	up_read(&handshake->static_identity->lock);
	/* The message key never leaves this function; scrub it. */
	memzero_explicit(key, NOISE_SYMMETRIC_KEY_LEN);
	return ret;
}
726
727 struct wg_peer *
728 wg_noise_handshake_consume_response(struct message_handshake_response *src,
729                                     struct wg_device *wg)
730 {
731         enum noise_handshake_state state = HANDSHAKE_ZEROED;
732         struct wg_peer *peer = NULL, *ret_peer = NULL;
733         struct noise_handshake *handshake;
734         u8 key[NOISE_SYMMETRIC_KEY_LEN];
735         u8 hash[NOISE_HASH_LEN];
736         u8 chaining_key[NOISE_HASH_LEN];
737         u8 e[NOISE_PUBLIC_KEY_LEN];
738         u8 ephemeral_private[NOISE_PUBLIC_KEY_LEN];
739         u8 static_private[NOISE_PUBLIC_KEY_LEN];
740         u8 preshared_key[NOISE_SYMMETRIC_KEY_LEN];
741
742         down_read(&wg->static_identity.lock);
743
744         if (unlikely(!wg->static_identity.has_identity))
745                 goto out;
746
747         handshake = (struct noise_handshake *)wg_index_hashtable_lookup(
748                 wg->index_hashtable, INDEX_HASHTABLE_HANDSHAKE,
749                 src->receiver_index, &peer);
750         if (unlikely(!handshake))
751                 goto out;
752
753         down_read(&handshake->lock);
754         state = handshake->state;
755         memcpy(hash, handshake->hash, NOISE_HASH_LEN);
756         memcpy(chaining_key, handshake->chaining_key, NOISE_HASH_LEN);
757         memcpy(ephemeral_private, handshake->ephemeral_private,
758                NOISE_PUBLIC_KEY_LEN);
759         memcpy(preshared_key, handshake->preshared_key,
760                NOISE_SYMMETRIC_KEY_LEN);
761         up_read(&handshake->lock);
762
763         if (state != HANDSHAKE_CREATED_INITIATION)
764                 goto fail;
765
766         /* e */
767         message_ephemeral(e, src->unencrypted_ephemeral, chaining_key, hash);
768
769         /* ee */
770         if (!mix_dh(chaining_key, NULL, ephemeral_private, e))
771                 goto fail;
772
773         /* se */
774         if (!mix_dh(chaining_key, NULL, wg->static_identity.static_private, e))
775                 goto fail;
776
777         /* psk */
778         mix_psk(chaining_key, hash, key, preshared_key);
779
780         /* {} */
781         if (!message_decrypt(NULL, src->encrypted_nothing,
782                              sizeof(src->encrypted_nothing), key, hash))
783                 goto fail;
784
785         /* Success! Copy everything to peer */
786         down_write(&handshake->lock);
787         /* It's important to check that the state is still the same, while we
788          * have an exclusive lock.
789          */
790         if (handshake->state != state) {
791                 up_write(&handshake->lock);
792                 goto fail;
793         }
794         memcpy(handshake->remote_ephemeral, e, NOISE_PUBLIC_KEY_LEN);
795         memcpy(handshake->hash, hash, NOISE_HASH_LEN);
796         memcpy(handshake->chaining_key, chaining_key, NOISE_HASH_LEN);
797         handshake->remote_index = src->sender_index;
798         handshake->state = HANDSHAKE_CONSUMED_RESPONSE;
799         up_write(&handshake->lock);
800         ret_peer = peer;
801         goto out;
802
803 fail:
804         wg_peer_put(peer);
805 out:
806         memzero_explicit(key, NOISE_SYMMETRIC_KEY_LEN);
807         memzero_explicit(hash, NOISE_HASH_LEN);
808         memzero_explicit(chaining_key, NOISE_HASH_LEN);
809         memzero_explicit(ephemeral_private, NOISE_PUBLIC_KEY_LEN);
810         memzero_explicit(static_private, NOISE_PUBLIC_KEY_LEN);
811         memzero_explicit(preshared_key, NOISE_SYMMETRIC_KEY_LEN);
812         up_read(&wg->static_identity.lock);
813         return ret_peer;
814 }
815
816 bool wg_noise_handshake_begin_session(struct noise_handshake *handshake,
817                                       struct noise_keypairs *keypairs)
818 {
819         struct noise_keypair *new_keypair;
820         bool ret = false;
821
822         down_write(&handshake->lock);
823         if (handshake->state != HANDSHAKE_CREATED_RESPONSE &&
824             handshake->state != HANDSHAKE_CONSUMED_RESPONSE)
825                 goto out;
826
827         new_keypair = keypair_create(handshake->entry.peer);
828         if (!new_keypair)
829                 goto out;
830         new_keypair->i_am_the_initiator = handshake->state ==
831                                           HANDSHAKE_CONSUMED_RESPONSE;
832         new_keypair->remote_index = handshake->remote_index;
833
834         if (new_keypair->i_am_the_initiator)
835                 derive_keys(&new_keypair->sending, &new_keypair->receiving,
836                             handshake->chaining_key);
837         else
838                 derive_keys(&new_keypair->receiving, &new_keypair->sending,
839                             handshake->chaining_key);
840
841         handshake_zero(handshake);
842         rcu_read_lock_bh();
843         if (likely(!READ_ONCE(container_of(handshake, struct wg_peer,
844                                            handshake)->is_dead))) {
845                 add_new_keypair(keypairs, new_keypair);
846                 net_dbg_ratelimited("%s: Keypair %llu created for peer %llu\n",
847                                     handshake->entry.peer->device->dev->name,
848                                     new_keypair->internal_id,
849                                     handshake->entry.peer->internal_id);
850                 ret = wg_index_hashtable_replace(
851                         handshake->entry.peer->device->index_hashtable,
852                         &handshake->entry, &new_keypair->entry);
853         } else {
854                 kfree_sensitive(new_keypair);
855         }
856         rcu_read_unlock_bh();
857
858 out:
859         up_write(&handshake->lock);
860         return ret;
861 }