author    | John Baldwin <jhb@FreeBSD.org> | 2022-10-28 20:36:12 +0000
committer | John Baldwin <jhb@FreeBSD.org> | 2022-10-28 20:36:12 +0000
commit    | 744bfb213144c63cbaf38d91a1c4f7aebb9b9fbc (patch)
tree      | 668f485d546b43d129c21513afccdfc2cc30fd3c
parent    | 9e0aaedd704ee8a040ecb1d1aadf0bd75ed4dc09 (diff)
download  | src-744bfb213144c63cbaf38d91a1c4f7aebb9b9fbc.tar.gz
          | src-744bfb213144c63cbaf38d91a1c4f7aebb9b9fbc.zip
Import the WireGuard driver from zx2c4.com.
This commit brings back the driver from FreeBSD commit
f187d6dfbf633665ba6740fe22742aec60ce02a2 plus subsequent fixes from
upstream.
Relative to upstream, this commit includes a few other small fixes, such
as additional INET and INET6 #ifdefs, #include cleanups, and updates
for recent API changes in main.
Reviewed by: pauamma, gbe, kevans, emaste
Obtained from: git@git.zx2c4.com:wireguard-freebsd @ 3cc22b2
Sponsored by: The FreeBSD Foundation
Differential Revision: https://reviews.freebsd.org/D36909
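
For quick reference, a minimal end-to-end setup distilled from the wg.4
manual page added by this commit; the listen port, peer key, endpoint,
and allowed-ips values below are illustrative placeholders, not values
fixed by the driver:

    # kldload if_wg                   (or if_wg_load="YES" in loader.conf(5))
    # ifconfig wg0 create
    # wg genkey | wg set wg0 listen-port 54321 private-key /dev/stdin
    $ wg show wg0 public-key
    # wg set wg0 peer '<peer public key>' endpoint 10.0.1.100:54321 allowed-ips 192.168.2.100/32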
-rw-r--r-- | etc/mtree/BSD.include.dist |    2
-rw-r--r-- | include/Makefile           |    9
-rw-r--r-- | share/man/man4/Makefile    |    2
-rw-r--r-- | share/man/man4/wg.4        |  213
-rw-r--r-- | sys/conf/NOTES             |    3
-rw-r--r-- | sys/conf/files             |   12
-rw-r--r-- | sys/dev/wg/compat.h        |  118
-rw-r--r-- | sys/dev/wg/crypto.h        |  182
-rw-r--r-- | sys/dev/wg/if_wg.c         | 3055
-rw-r--r-- | sys/dev/wg/if_wg.h         |   37
-rw-r--r-- | sys/dev/wg/support.h       |   21
-rw-r--r-- | sys/dev/wg/version.h       |    1
-rw-r--r-- | sys/dev/wg/wg_cookie.c     |  500
-rw-r--r-- | sys/dev/wg/wg_cookie.h     |   72
-rw-r--r-- | sys/dev/wg/wg_crypto.c     | 1830
-rw-r--r-- | sys/dev/wg/wg_noise.c      | 1410
-rw-r--r-- | sys/dev/wg/wg_noise.h      |  131
-rw-r--r-- | sys/kern/kern_jail.c       |    1
-rw-r--r-- | sys/modules/Makefile       |    4
-rw-r--r-- | sys/modules/if_wg/Makefile |   10
-rw-r--r-- | sys/net/if_types.h         |    1
-rw-r--r-- | sys/netinet6/nd6.c         |    4
-rw-r--r-- | sys/sys/priv.h             |    1
23 files changed, 7613 insertions, 6 deletions
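
The build glue below hooks the driver into both static kernels and the
module build. As a sketch (assuming a source tree at /usr/src; "make load"
is the stock kmod make target), either of the following enables it:

    # Static: add to the kernel configuration, per the new sys/conf/NOTES entry
    device wg

    # Module: build and load from the new sys/modules/if_wg directory
    cd /usr/src/sys/modules/if_wg && make && make load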
diff --git a/etc/mtree/BSD.include.dist b/etc/mtree/BSD.include.dist
index 192508bbf6f1..9a1fe1cd60a7 100644
--- a/etc/mtree/BSD.include.dist
+++ b/etc/mtree/BSD.include.dist
@@ -136,6 +136,8 @@
         ..
         vkbd
         ..
+        wg
+        ..
         wi
         ..
     ..
diff --git a/include/Makefile b/include/Makefile
index 80d2d9da8b06..988b0a56baa7 100644
--- a/include/Makefile
+++ b/include/Makefile
@@ -49,7 +49,7 @@ LSUBDIRS=	dev/acpica dev/agp dev/ciss dev/filemon dev/firewire \
 	dev/hwpmc dev/hyperv \
 	dev/ic dev/iicbus dev/io dev/mfi dev/mmc dev/nvme \
 	dev/ofw dev/pbio dev/pci ${_dev_powermac_nvram} dev/ppbus dev/pwm \
-	dev/smbus dev/speaker dev/tcp_log dev/veriexec dev/vkbd \
+	dev/smbus dev/speaker dev/tcp_log dev/veriexec dev/vkbd dev/wg \
 	fs/devfs fs/fdescfs fs/msdosfs fs/nfs fs/nullfs \
 	fs/procfs fs/smbfs fs/udf fs/unionfs \
 	geom/cache geom/concat geom/eli geom/gate geom/journal geom/label \
@@ -225,6 +225,10 @@ NVPAIRDIR=	${INCLUDEDIR}/sys
 MLX5=	mlx5io.h
 MLX5DIR=	${INCLUDEDIR}/dev/mlx5
 
+.PATH: ${SRCTOP}/sys/dev/wg
+WG=	if_wg.h
+WGDIR=	${INCLUDEDIR}/dev/wg
+
 INCSGROUPS=	INCS \
 	ACPICA \
 	AGP \
@@ -244,7 +248,8 @@ INCSGROUPS=	INCS \
 	RPC \
 	SECAUDIT \
 	TEKEN \
-	VERIEXEC
+	VERIEXEC \
+	WG
 
 .if ${MK_IPFILTER} != "no"
 INCSGROUPS+=	IPFILTER
diff --git a/share/man/man4/Makefile b/share/man/man4/Makefile
index 4650d9d3ede8..413ac035003d 100644
--- a/share/man/man4/Makefile
+++ b/share/man/man4/Makefile
@@ -584,6 +584,7 @@ MAN=	aac.4 \
 	vtnet.4 \
 	watchdog.4 \
 	${_wbwd.4} \
+	wg.4 \
 	witness.4 \
 	wlan.4 \
 	wlan_acl.4 \
@@ -761,6 +762,7 @@ MLINKS+=vr.4 if_vr.4
 MLINKS+=vte.4 if_vte.4
 MLINKS+=vtnet.4 if_vtnet.4
 MLINKS+=watchdog.4 SW_WATCHDOG.4
+MLINKS+=wg.4 if_wg.4
 MLINKS+=${_wpi.4} ${_if_wpi.4}
 MLINKS+=xl.4 if_xl.4
diff --git a/share/man/man4/wg.4 b/share/man/man4/wg.4
new file mode 100644
index 000000000000..f2ae425002d7
--- /dev/null
+++ b/share/man/man4/wg.4
@@ -0,0 +1,213 @@
+.\" Copyright (c) 2020 Gordon Bergling <gbe@FreeBSD.org>
+.\"
+.\" Redistribution and use in source and binary forms, with or without
+.\" modification, are permitted provided that the following conditions
+.\" are met:
+.\" 1. Redistributions of source code must retain the above copyright
+.\"    notice, this list of conditions and the following disclaimer.
+.\" 2. Redistributions in binary form must reproduce the above copyright
+.\"    notice, this list of conditions and the following disclaimer in the
+.\"    documentation and/or other materials provided with the distribution.
+.\"
+.\" THIS SOFTWARE IS PROVIDED BY THE AUTHOR AND CONTRIBUTORS ``AS IS'' AND
+.\" ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+.\" IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
+.\" ARE DISCLAIMED. IN NO EVENT SHALL THE AUTHOR OR CONTRIBUTORS BE LIABLE
+.\" FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+.\" DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS
+.\" OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION)
+.\" HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
+.\" LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY
+.\" OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF
+.\" SUCH DAMAGE.
+.\" +.\" $FreeBSD$ +.\" +.Dd October 28, 2022 +.Dt WG 4 +.Os +.Sh NAME +.Nm wg +.Nd "WireGuard - pseudo-device" +.Sh SYNOPSIS +To load the driver as a module at boot time, place the following line in +.Xr loader.conf 5 : +.Bd -literal -offset indent +if_wg_load="YES" +.Ed +.Sh DESCRIPTION +The +.Nm +driver provides Virtual Private Network (VPN) interfaces for the secure +exchange of layer 3 traffic with other WireGuard peers using the WireGuard +protocol. +.Pp +A +.Nm +interface recognises one or more peers, establishes a secure tunnel with +each on demand, and tracks each peer's UDP endpoint for exchanging encrypted +traffic with. +.Pp +The interfaces can be created at runtime using the +.Ic ifconfig Cm wg Ns Ar N Cm create +command. +The interface itself can be configured with +.Xr wg 8 . +.Pp +The following glossary provides a brief overview of WireGuard +terminology: +.Bl -tag -width indent -offset 3n +.It Peer +Peers exchange IPv4 or IPv6 traffic over secure tunnels. +Each +.Nm +interface may be configured to recognise one or more peers. +.It Key +Each peer uses its private key and corresponding public key to +identify itself to others. +A peer configures a +.Nm +interface with its own private key and with the public keys of its peers. +.It Pre-shared key +In addition to the public keys, each peer pair may be configured with a +unique pre-shared symmetric key. +This is used in their handshake to guard against future compromise of the +peers' encrypted tunnel if a quantum-computational attack on their +Diffie-Hellman exchange becomes feasible. +It is optional, but recommended. +.It Allowed IPs +A single +.Nm +interface may maintain concurrent tunnels connecting diverse networks. +The interface therefore implements rudimentary routing and reverse-path +filtering functions for its tunneled traffic. +These functions reference a set of allowed IP ranges configured against +each peer. +.Pp +The interface will route outbound tunneled traffic to the peer configured +with the most specific matching allowed IP address range, or drop it +if no such match exists. +.Pp +The interface will accept tunneled traffic only from the peer +configured with the most specific matching allowed IP address range +for the incoming traffic, or drop it if no such match exists. +That is, tunneled traffic routed to a given peer cannot return through +another peer of the same +.Nm +interface. +This ensures that peers cannot spoof another's traffic. +.It Handshake +Two peers handshake to mutually authenticate each other and to +establish a shared series of secret ephemeral encryption keys. +Any peer may initiate a handshake. +Handshakes occur only when there is traffic to send, and recur every +two minutes during transfers. +.It Connectionless +Due to the handshake behavior, there is no connected or disconnected +state. +.El +.Ss Keys +Private keys for WireGuard can be generated from any sufficiently +secure random source. +The Curve25519 keys and the pre-shared keys are both 32 bytes +long and are commonly encoded in base64 for ease of use. +.Pp +Keys can be generated with +.Xr wg 8 +as follows: +.Pp +.Dl $ wg genkey +.Pp +Although a valid Curve25519 key must have 5 bits set to +specific values, this is done by the interface and so it +will accept any random 32-byte base64 string. +.Sh EXAMPLES +Create a +.Nm +interface and set random private key. 
+.Bd -literal -offset indent
+# ifconfig wg0 create
+# wg genkey | wg set wg0 listen-port 54321 private-key /dev/stdin
+.Ed
+.Pp
+Retrieve the associated public key from a
+.Nm
+interface.
+.Bd -literal -offset indent
+$ wg show wg0 public-key
+.Ed
+.Pp
+Connect to a specific endpoint using its public-key and set the allowed IP address
+.Bd -literal -offset indent
+# wg set wg0 peer '7lWtsDdqaGB3EY9WNxRN3hVaHMtu1zXw71+bOjNOVUw=' endpoint 10.0.1.100:54321 allowed-ips 192.168.2.100/32
+.Ed
+.Pp
+Remove a peer
+.Bd -literal -offset indent
+# wg set wg0 peer '7lWtsDdqaGB3EY9WNxRN3hVaHMtu1zXw71+bOjNOVUw=' remove
+.Ed
+.Sh DIAGNOSTICS
+The
+.Nm
+interface supports runtime debugging, which can be enabled with:
+.Pp
+.D1 Ic ifconfig Cm wg Ns Ar N Cm debug
+.Pp
+Some common error messages include:
+.Bl -diag
+.It "Handshake for peer X did not complete after 5 seconds, retrying"
+Peer X did not reply to our initiation packet, for example because:
+.Bl -bullet
+.It
+The peer does not have the local interface configured as a peer.
+Peers must be able to mutually authenticate each other.
+.It
+The peer endpoint IP address is incorrectly configured.
+.It
+There are firewall rules preventing communication between hosts.
+.El
+.It "Invalid handshake initiation"
+The incoming handshake packet could not be processed.
+This is likely due to the local interface not containing
+the correct public key for the peer.
+.It "Invalid initiation MAC"
+The incoming handshake initiation packet had an invalid MAC.
+This is likely because the initiation sender has the wrong public key
+for the handshake receiver.
+.It "Packet has unallowed src IP from peer X"
+After decryption, an incoming data packet has a source IP address that
+is not assigned to the allowed IPs of Peer X.
+.El
+.Sh SEE ALSO
+.Xr inet 4 ,
+.Xr ip 4 ,
+.Xr netintro 4 ,
+.Xr ipf 5 ,
+.Xr pf.conf 5 ,
+.Xr ifconfig 8 ,
+.Xr ipfw 8 ,
+.Xr wg 8
+.Rs
+.%T WireGuard whitepaper
+.%U https://www.wireguard.com/papers/wireguard.pdf
+.Re
+.Sh HISTORY
+The
+.Nm
+device driver first appeared in
+.Fx 14.0 .
+.Sh AUTHORS
+The
+.Nm
+device driver written by
+.An Jason A. Donenfeld Aq Mt Jason@zx2c4.com ,
+.An Matt Dunwoodie Aq Mt ncon@nconroy.net ,
+and
+.An Kyle Evans Aq Mt kevans@FreeBSD.org .
+.Pp
+This manual page was written by
+.An Gordon Bergling Aq Mt gbe@FreeBSD.org
+and is based on the
+.Ox
+manual page written by
+.An David Gwynne Aq Mt dlg@openbsd.org .
diff --git a/sys/conf/NOTES b/sys/conf/NOTES
index 434c739c8b21..8a9c726b792c 100644
--- a/sys/conf/NOTES
+++ b/sys/conf/NOTES
@@ -961,6 +961,9 @@ device		enc
 # Link aggregation interface.
 device		lagg
 
+# WireGuard interface.
+device		wg
+
 #
 # Internet family options:
 #
diff --git a/sys/conf/files b/sys/conf/files
index f4f7cf6208e1..e47f6577e39c 100644
--- a/sys/conf/files
+++ b/sys/conf/files
@@ -750,8 +750,8 @@ crypto/sha2/sha256c.c		optional crypto | ekcd | geom_bde | \
 crypto/sha2/sha512c.c		optional crypto | geom_bde | zfs
 crypto/skein/skein.c		optional crypto | zfs
 crypto/skein/skein_block.c	optional crypto | zfs
-crypto/siphash/siphash.c	optional inet | inet6
-crypto/siphash/siphash_test.c	optional inet | inet6
+crypto/siphash/siphash.c	optional inet | inet6 | wg
+crypto/siphash/siphash_test.c	optional inet | inet6 | wg
 ddb/db_access.c			optional ddb
 ddb/db_break.c			optional ddb
 ddb/db_capture.c		optional ddb
@@ -3480,6 +3480,14 @@ dev/vt/vt_font.c		optional vt
 dev/vt/vt_sysmouse.c		optional vt
 dev/vte/if_vte.c		optional vte pci
 dev/watchdog/watchdog.c		standard
+dev/wg/if_wg.c			optional wg \
+	compile-with "${NORMAL_C} -include $S/dev/wg/compat.h"
+dev/wg/wg_cookie.c		optional wg \
+	compile-with "${NORMAL_C} -include $S/dev/wg/compat.h"
+dev/wg/wg_crypto.c		optional wg \
+	compile-with "${NORMAL_C} -include $S/dev/wg/compat.h"
+dev/wg/wg_noise.c		optional wg \
+	compile-with "${NORMAL_C} -include $S/dev/wg/compat.h"
 dev/wpi/if_wpi.c		optional wpi pci
 wpifw.c				optional wpifw \
	compile-with "${AWK} -f $S/tools/fw_stub.awk wpi.fw:wpifw:153229 -mwpi -c${.TARGET}" \
diff --git a/sys/dev/wg/compat.h b/sys/dev/wg/compat.h
new file mode 100644
index 000000000000..101a771579d9
--- /dev/null
+++ b/sys/dev/wg/compat.h
@@ -0,0 +1,118 @@
+/* SPDX-License-Identifier: MIT
+ *
+ * Copyright (C) 2021 Jason A. Donenfeld <Jason@zx2c4.com>. All Rights Reserved.
+ * Copyright (c) 2022 The FreeBSD Foundation
+ *
+ * compat.h contains code that is backported from FreeBSD's main branch.
+ * It is different from support.h, which is for code that is not _yet_ upstream.
+ */
+
+#include <sys/param.h>
+
+#if (__FreeBSD_version < 1400036 && __FreeBSD_version >= 1400000) || __FreeBSD_version < 1300519
+#define COMPAT_NEED_CHACHA20POLY1305_MBUF
+#endif
+
+#if __FreeBSD_version < 1400048
+#define COMPAT_NEED_CHACHA20POLY1305
+#endif
+
+#if __FreeBSD_version < 1400049
+#define COMPAT_NEED_CURVE25519
+#endif
+
+#if __FreeBSD_version < 0x7fffffff /* TODO: update this when implemented */
+#define COMPAT_NEED_BLAKE2S
+#endif
+
+#if __FreeBSD_version < 1400059
+#include <sys/sockbuf.h>
+#define sbcreatecontrol(a, b, c, d, e) sbcreatecontrol(a, b, c, d)
+#endif
+
+#if __FreeBSD_version < 1300507
+#include <sys/smp.h>
+#include <sys/gtaskqueue.h>
+
+struct taskqgroup_cpu {
+	LIST_HEAD(, grouptask)	tgc_tasks;
+	struct gtaskqueue	*tgc_taskq;
+	int			tgc_cnt;
+	int			tgc_cpu;
+};
+
+struct taskqgroup {
+	struct taskqgroup_cpu	tqg_queue[MAXCPU];
+	/* Other members trimmed from compat. */
+};
+
+static inline void taskqgroup_drain_all(struct taskqgroup *tqg) {
+	struct gtaskqueue *q;
+
+	for (int i = 0; i < mp_ncpus; i++) {
+		q = tqg->tqg_queue[i].tgc_taskq;
+		if (q == NULL)
+			continue;
+		gtaskqueue_drain_all(q);
+	}
+}
+#endif
+
+#if __FreeBSD_version < 1300000
+#define VIMAGE
+
+#include <sys/types.h>
+#include <sys/limits.h>
+#include <sys/endian.h>
+#include <sys/socket.h>
+#include <sys/libkern.h>
+#include <sys/malloc.h>
+#include <sys/proc.h>
+#include <sys/lock.h>
+#include <sys/socketvar.h>
+#include <sys/protosw.h>
+#include <net/vnet.h>
+#include <net/if.h>
+#include <net/if_var.h>
+#include <vm/uma.h>
+
+#define taskqgroup_attach(a, b, c, d, e, f) taskqgroup_attach((a), (b), (c), -1, (f))
+#define taskqgroup_attach_cpu(a, b, c, d, e, f, g) taskqgroup_attach_cpu((a), (b), (c), (d), -1, (g))
+
+#undef NET_EPOCH_ENTER
+#define NET_EPOCH_ENTER(et) NET_EPOCH_ENTER_ET(et)
+#undef NET_EPOCH_EXIT
+#define NET_EPOCH_EXIT(et) NET_EPOCH_EXIT_ET(et)
+#define NET_EPOCH_CALL(f, c) epoch_call(net_epoch_preempt, (c), (f))
+#define NET_EPOCH_ASSERT() MPASS(in_epoch(net_epoch_preempt))
+
+#undef atomic_load_ptr
+#define atomic_load_ptr(p) (*(volatile __typeof(*p) *)(p))
+
+#endif
+
+#if __FreeBSD_version < 1202000
+static inline uint32_t arc4random_uniform(uint32_t bound) {
+	uint32_t ret, max_mod_bound;
+
+	if (bound < 2)
+		return 0;
+
+	max_mod_bound = (1 + ~bound) % bound;
+
+	do {
+		ret = arc4random();
+	} while (ret < max_mod_bound);
+
+	return ret % bound;
+}
+
+typedef void callout_func_t(void *);
+
+#ifndef CSUM_SND_TAG
+#define CSUM_SND_TAG 0x80000000
+#endif
+
+#endif
diff --git a/sys/dev/wg/crypto.h b/sys/dev/wg/crypto.h
new file mode 100644
index 000000000000..2115039321b1
--- /dev/null
+++ b/sys/dev/wg/crypto.h
@@ -0,0 +1,182 @@
+/* SPDX-License-Identifier: MIT
+ *
+ * Copyright (C) 2015-2021 Jason A. Donenfeld <Jason@zx2c4.com>. All Rights Reserved.
+ * Copyright (c) 2022 The FreeBSD Foundation
+ */
+
+#ifndef _WG_CRYPTO
+#define _WG_CRYPTO
+
+#include <sys/param.h>
+
+struct mbuf;
+
+int crypto_init(void);
+void crypto_deinit(void);
+
+enum chacha20poly1305_lengths {
+	XCHACHA20POLY1305_NONCE_SIZE = 24,
+	CHACHA20POLY1305_KEY_SIZE = 32,
+	CHACHA20POLY1305_AUTHTAG_SIZE = 16
+};
+
+#ifdef COMPAT_NEED_CHACHA20POLY1305
+void
+chacha20poly1305_encrypt(uint8_t *dst, const uint8_t *src, const size_t src_len,
+			 const uint8_t *ad, const size_t ad_len,
+			 const uint64_t nonce,
+			 const uint8_t key[CHACHA20POLY1305_KEY_SIZE]);
+
+bool
+chacha20poly1305_decrypt(uint8_t *dst, const uint8_t *src, const size_t src_len,
+			 const uint8_t *ad, const size_t ad_len,
+			 const uint64_t nonce,
+			 const uint8_t key[CHACHA20POLY1305_KEY_SIZE]);
+
+void
+xchacha20poly1305_encrypt(uint8_t *dst, const uint8_t *src,
+			  const size_t src_len, const uint8_t *ad,
+			  const size_t ad_len,
+			  const uint8_t nonce[XCHACHA20POLY1305_NONCE_SIZE],
+			  const uint8_t key[CHACHA20POLY1305_KEY_SIZE]);
+
+bool
+xchacha20poly1305_decrypt(uint8_t *dst, const uint8_t *src,
+			  const size_t src_len, const uint8_t *ad,
+			  const size_t ad_len,
+			  const uint8_t nonce[XCHACHA20POLY1305_NONCE_SIZE],
+			  const uint8_t key[CHACHA20POLY1305_KEY_SIZE]);
+#else
+#include <sys/endian.h>
+#include <crypto/chacha20_poly1305.h>
+
+static inline void
+chacha20poly1305_encrypt(uint8_t *dst, const uint8_t *src, const size_t src_len,
+			 const uint8_t *ad, const size_t ad_len,
+			 const uint64_t nonce,
+			 const uint8_t key[CHACHA20POLY1305_KEY_SIZE])
+{
+	uint8_t nonce_bytes[8];
+
+	le64enc(nonce_bytes, nonce);
+	chacha20_poly1305_encrypt(dst, src, src_len, ad, ad_len,
+				  nonce_bytes, sizeof(nonce_bytes), key);
+}
+
+static inline bool
+chacha20poly1305_decrypt(uint8_t *dst, const uint8_t *src, const size_t src_len,
+			 const uint8_t *ad, const size_t ad_len,
+			 const uint64_t nonce,
+			 const uint8_t key[CHACHA20POLY1305_KEY_SIZE])
+{
+	uint8_t nonce_bytes[8];
+
+	le64enc(nonce_bytes, nonce);
+	return (chacha20_poly1305_decrypt(dst, src, src_len, ad, ad_len,
+					  nonce_bytes, sizeof(nonce_bytes), key));
+}
+
+static inline void
+xchacha20poly1305_encrypt(uint8_t *dst, const uint8_t *src,
+			  const size_t src_len, const uint8_t *ad,
+			  const size_t ad_len,
+			  const uint8_t nonce[XCHACHA20POLY1305_NONCE_SIZE],
+			  const uint8_t key[CHACHA20POLY1305_KEY_SIZE])
+{
+	xchacha20_poly1305_encrypt(dst, src, src_len, ad, ad_len, nonce, key);
+}
+
+static inline bool
+xchacha20poly1305_decrypt(uint8_t *dst, const uint8_t *src,
+			  const size_t src_len, const uint8_t *ad,
+			  const size_t ad_len,
+			  const uint8_t nonce[XCHACHA20POLY1305_NONCE_SIZE],
+			  const uint8_t key[CHACHA20POLY1305_KEY_SIZE])
+{
+	return (xchacha20_poly1305_decrypt(dst, src, src_len, ad, ad_len, nonce, key));
+}
+#endif
+
+int
+chacha20poly1305_encrypt_mbuf(struct mbuf *, const uint64_t nonce,
+			      const uint8_t key[CHACHA20POLY1305_KEY_SIZE]);
+
+int
+chacha20poly1305_decrypt_mbuf(struct mbuf *, const uint64_t nonce,
+			      const uint8_t key[CHACHA20POLY1305_KEY_SIZE]);
+
+
+enum blake2s_lengths {
+	BLAKE2S_BLOCK_SIZE = 64,
+	BLAKE2S_HASH_SIZE = 32,
+	BLAKE2S_KEY_SIZE = 32
+};
+
+#ifdef COMPAT_NEED_BLAKE2S
+struct blake2s_state {
+	uint32_t h[8];
+	uint32_t t[2];
+	uint32_t f[2];
+	uint8_t buf[BLAKE2S_BLOCK_SIZE];
+	unsigned int buflen;
+	unsigned int outlen;
+};
+
+void blake2s_init(struct blake2s_state *state, const size_t outlen);
+
+void blake2s_init_key(struct blake2s_state *state, const size_t outlen,
+		      const uint8_t *key, const size_t keylen);
+
+void blake2s_update(struct blake2s_state *state, const uint8_t *in, size_t inlen);
+
+void blake2s_final(struct blake2s_state *state, uint8_t *out);
+
+static inline void blake2s(uint8_t *out, const uint8_t *in, const uint8_t *key,
+			   const size_t outlen, const size_t inlen, const size_t keylen)
+{
+	struct blake2s_state state;
+
+	if (keylen)
+		blake2s_init_key(&state, outlen, key, keylen);
+	else
+		blake2s_init(&state, outlen);
+
+	blake2s_update(&state, in, inlen);
+	blake2s_final(&state, out);
+}
+#endif
+
+#ifdef COMPAT_NEED_CURVE25519
+enum curve25519_lengths {
+	CURVE25519_KEY_SIZE = 32
+};
+
+bool curve25519(uint8_t mypublic[static CURVE25519_KEY_SIZE],
+		const uint8_t secret[static CURVE25519_KEY_SIZE],
+		const uint8_t basepoint[static CURVE25519_KEY_SIZE]);
+
+static inline bool
+curve25519_generate_public(uint8_t pub[static CURVE25519_KEY_SIZE],
+			   const uint8_t secret[static CURVE25519_KEY_SIZE])
+{
+	static const uint8_t basepoint[CURVE25519_KEY_SIZE] = { 9 };
+
+	return curve25519(pub, secret, basepoint);
+}
+
+static inline void curve25519_clamp_secret(uint8_t secret[static CURVE25519_KEY_SIZE])
+{
+	secret[0] &= 248;
+	secret[31] = (secret[31] & 127) | 64;
+}
+
+static inline void curve25519_generate_secret(uint8_t secret[CURVE25519_KEY_SIZE])
+{
+	arc4random_buf(secret, CURVE25519_KEY_SIZE);
+	curve25519_clamp_secret(secret);
}
+#else
+#include <crypto/curve25519.h>
+#endif
+
+#endif
diff --git a/sys/dev/wg/if_wg.c b/sys/dev/wg/if_wg.c
new file mode 100644
index 000000000000..59979c087db2
--- /dev/null
+++ b/sys/dev/wg/if_wg.c
@@ -0,0 +1,3055 @@
+/* SPDX-License-Identifier: ISC
+ *
+ * Copyright (C) 2015-2021 Jason A. Donenfeld <Jason@zx2c4.com>. All Rights Reserved.
+ * Copyright (C) 2019-2021 Matt Dunwoodie <ncon@noconroy.net>
+ * Copyright (c) 2019-2020 Rubicon Communications, LLC (Netgate)
+ * Copyright (c) 2021 Kyle Evans <kevans@FreeBSD.org>
+ * Copyright (c) 2022 The FreeBSD Foundation
+ */
+
+#include "opt_inet.h"
+#include "opt_inet6.h"
+
+#include <sys/param.h>
+#include <sys/systm.h>
+#include <sys/counter.h>
+#include <sys/gtaskqueue.h>
+#include <sys/jail.h>
+#include <sys/kernel.h>
+#include <sys/lock.h>
+#include <sys/mbuf.h>
+#include <sys/module.h>
+#include <sys/nv.h>
+#include <sys/priv.h>
+#include <sys/protosw.h>
+#include <sys/rmlock.h>
+#include <sys/rwlock.h>
+#include <sys/smp.h>
+#include <sys/socket.h>
+#include <sys/socketvar.h>
+#include <sys/sockio.h>
+#include <sys/sysctl.h>
+#include <sys/sx.h>
+#include <machine/_inttypes.h>
+#include <net/bpf.h>
+#include <net/ethernet.h>
+#include <net/if.h>
+#include <net/if_clone.h>
+#include <net/if_types.h>
+#include <net/if_var.h>
+#include <net/netisr.h>
+#include <net/radix.h>
+#include <netinet/in.h>
+#include <netinet6/in6_var.h>
+#include <netinet/ip.h>
+#include <netinet/ip6.h>
+#include <netinet/ip_icmp.h>
+#include <netinet/icmp6.h>
+#include <netinet/udp_var.h>
+#include <netinet6/nd6.h>
+
+#include "support.h"
+#include "wg_noise.h"
+#include "wg_cookie.h"
+#include "version.h"
+#include "if_wg.h"
+
+#define DEFAULT_MTU		(ETHERMTU - 80)
+#define MAX_MTU			(IF_MAXMTU - 80)
+
+#define MAX_STAGED_PKT		128
+#define MAX_QUEUED_PKT		1024
+#define MAX_QUEUED_PKT_MASK	(MAX_QUEUED_PKT - 1)
+
+#define MAX_QUEUED_HANDSHAKES	4096
+
+#define REKEY_TIMEOUT_JITTER	334 /* 1/3 sec, round for arc4random_uniform */
+#define MAX_TIMER_HANDSHAKES	(90 / REKEY_TIMEOUT)
+#define NEW_HANDSHAKE_TIMEOUT	(REKEY_TIMEOUT + KEEPALIVE_TIMEOUT)
+#define UNDERLOAD_TIMEOUT	1
+
+#define DPRINTF(sc, ...)
if (sc->sc_ifp->if_flags & IFF_DEBUG) if_printf(sc->sc_ifp, ##__VA_ARGS__) + +/* First byte indicating packet type on the wire */ +#define WG_PKT_INITIATION htole32(1) +#define WG_PKT_RESPONSE htole32(2) +#define WG_PKT_COOKIE htole32(3) +#define WG_PKT_DATA htole32(4) + +#define WG_PKT_PADDING 16 +#define WG_KEY_SIZE 32 + +struct wg_pkt_initiation { + uint32_t t; + uint32_t s_idx; + uint8_t ue[NOISE_PUBLIC_KEY_LEN]; + uint8_t es[NOISE_PUBLIC_KEY_LEN + NOISE_AUTHTAG_LEN]; + uint8_t ets[NOISE_TIMESTAMP_LEN + NOISE_AUTHTAG_LEN]; + struct cookie_macs m; +}; + +struct wg_pkt_response { + uint32_t t; + uint32_t s_idx; + uint32_t r_idx; + uint8_t ue[NOISE_PUBLIC_KEY_LEN]; + uint8_t en[0 + NOISE_AUTHTAG_LEN]; + struct cookie_macs m; +}; + +struct wg_pkt_cookie { + uint32_t t; + uint32_t r_idx; + uint8_t nonce[COOKIE_NONCE_SIZE]; + uint8_t ec[COOKIE_ENCRYPTED_SIZE]; +}; + +struct wg_pkt_data { + uint32_t t; + uint32_t r_idx; + uint64_t nonce; + uint8_t buf[]; +}; + +struct wg_endpoint { + union { + struct sockaddr r_sa; + struct sockaddr_in r_sin; +#ifdef INET6 + struct sockaddr_in6 r_sin6; +#endif + } e_remote; + union { + struct in_addr l_in; +#ifdef INET6 + struct in6_pktinfo l_pktinfo6; +#define l_in6 l_pktinfo6.ipi6_addr +#endif + } e_local; +}; + +struct aip_addr { + uint8_t length; + union { + uint8_t bytes[16]; + uint32_t ip; + uint32_t ip6[4]; + struct in_addr in; + struct in6_addr in6; + }; +}; + +struct wg_aip { + struct radix_node a_nodes[2]; + LIST_ENTRY(wg_aip) a_entry; + struct aip_addr a_addr; + struct aip_addr a_mask; + struct wg_peer *a_peer; + sa_family_t a_af; +}; + +struct wg_packet { + STAILQ_ENTRY(wg_packet) p_serial; + STAILQ_ENTRY(wg_packet) p_parallel; + struct wg_endpoint p_endpoint; + struct noise_keypair *p_keypair; + uint64_t p_nonce; + struct mbuf *p_mbuf; + int p_mtu; + sa_family_t p_af; + enum wg_ring_state { + WG_PACKET_UNCRYPTED, + WG_PACKET_CRYPTED, + WG_PACKET_DEAD, + } p_state; +}; + +STAILQ_HEAD(wg_packet_list, wg_packet); + +struct wg_queue { + struct mtx q_mtx; + struct wg_packet_list q_queue; + size_t q_len; +}; + +struct wg_peer { + TAILQ_ENTRY(wg_peer) p_entry; + uint64_t p_id; + struct wg_softc *p_sc; + + struct noise_remote *p_remote; + struct cookie_maker p_cookie; + + struct rwlock p_endpoint_lock; + struct wg_endpoint p_endpoint; + + struct wg_queue p_stage_queue; + struct wg_queue p_encrypt_serial; + struct wg_queue p_decrypt_serial; + + bool p_enabled; + bool p_need_another_keepalive; + uint16_t p_persistent_keepalive_interval; + struct callout p_new_handshake; + struct callout p_send_keepalive; + struct callout p_retry_handshake; + struct callout p_zero_key_material; + struct callout p_persistent_keepalive; + + struct mtx p_handshake_mtx; + struct timespec p_handshake_complete; /* nanotime */ + int p_handshake_retries; + + struct grouptask p_send; + struct grouptask p_recv; + + counter_u64_t p_tx_bytes; + counter_u64_t p_rx_bytes; + + LIST_HEAD(, wg_aip) p_aips; + size_t p_aips_num; +}; + +struct wg_socket { + struct socket *so_so4; + struct socket *so_so6; + uint32_t so_user_cookie; + int so_fibnum; + in_port_t so_port; +}; + +struct wg_softc { + LIST_ENTRY(wg_softc) sc_entry; + struct ifnet *sc_ifp; + int sc_flags; + + struct ucred *sc_ucred; + struct wg_socket sc_socket; + + TAILQ_HEAD(,wg_peer) sc_peers; + size_t sc_peers_num; + + struct noise_local *sc_local; + struct cookie_checker sc_cookie; + + struct radix_node_head *sc_aip4; + struct radix_node_head *sc_aip6; + + struct grouptask sc_handshake; + struct wg_queue sc_handshake_queue; + + 
struct grouptask *sc_encrypt; + struct grouptask *sc_decrypt; + struct wg_queue sc_encrypt_parallel; + struct wg_queue sc_decrypt_parallel; + u_int sc_encrypt_last_cpu; + u_int sc_decrypt_last_cpu; + + struct sx sc_lock; +}; + +#define WGF_DYING 0x0001 + +#define MAX_LOOPS 8 +#define MTAG_WGLOOP 0x77676c70 /* wglp */ +#ifndef ENOKEY +#define ENOKEY ENOTCAPABLE +#endif + +#define GROUPTASK_DRAIN(gtask) \ + gtaskqueue_drain((gtask)->gt_taskqueue, &(gtask)->gt_task) + +#define BPF_MTAP2_AF(ifp, m, af) do { \ + uint32_t __bpf_tap_af = (af); \ + BPF_MTAP2(ifp, &__bpf_tap_af, sizeof(__bpf_tap_af), m); \ + } while (0) + +static int clone_count; +static uma_zone_t wg_packet_zone; +static volatile unsigned long peer_counter = 0; +static const char wgname[] = "wg"; +static unsigned wg_osd_jail_slot; + +static struct sx wg_sx; +SX_SYSINIT(wg_sx, &wg_sx, "wg_sx"); + +static LIST_HEAD(, wg_softc) wg_list = LIST_HEAD_INITIALIZER(wg_list); + +static TASKQGROUP_DEFINE(wg_tqg, mp_ncpus, 1); + +MALLOC_DEFINE(M_WG, "WG", "wireguard"); + +VNET_DEFINE_STATIC(struct if_clone *, wg_cloner); + +#define V_wg_cloner VNET(wg_cloner) +#define WG_CAPS IFCAP_LINKSTATE + +struct wg_timespec64 { + uint64_t tv_sec; + uint64_t tv_nsec; +}; + +static int wg_socket_init(struct wg_softc *, in_port_t); +static int wg_socket_bind(struct socket **, struct socket **, in_port_t *); +static void wg_socket_set(struct wg_softc *, struct socket *, struct socket *); +static void wg_socket_uninit(struct wg_softc *); +static int wg_socket_set_sockopt(struct socket *, struct socket *, int, void *, size_t); +static int wg_socket_set_cookie(struct wg_softc *, uint32_t); +static int wg_socket_set_fibnum(struct wg_softc *, int); +static int wg_send(struct wg_softc *, struct wg_endpoint *, struct mbuf *); +static void wg_timers_enable(struct wg_peer *); +static void wg_timers_disable(struct wg_peer *); +static void wg_timers_set_persistent_keepalive(struct wg_peer *, uint16_t); +static void wg_timers_get_last_handshake(struct wg_peer *, struct wg_timespec64 *); +static void wg_timers_event_data_sent(struct wg_peer *); +static void wg_timers_event_data_received(struct wg_peer *); +static void wg_timers_event_any_authenticated_packet_sent(struct wg_peer *); +static void wg_timers_event_any_authenticated_packet_received(struct wg_peer *); +static void wg_timers_event_any_authenticated_packet_traversal(struct wg_peer *); +static void wg_timers_event_handshake_initiated(struct wg_peer *); +static void wg_timers_event_handshake_complete(struct wg_peer *); +static void wg_timers_event_session_derived(struct wg_peer *); +static void wg_timers_event_want_initiation(struct wg_peer *); +static void wg_timers_run_send_initiation(struct wg_peer *, bool); +static void wg_timers_run_retry_handshake(void *); +static void wg_timers_run_send_keepalive(void *); +static void wg_timers_run_new_handshake(void *); +static void wg_timers_run_zero_key_material(void *); +static void wg_timers_run_persistent_keepalive(void *); +static int wg_aip_add(struct wg_softc *, struct wg_peer *, sa_family_t, const void *, uint8_t); +static struct wg_peer *wg_aip_lookup(struct wg_softc *, sa_family_t, void *); +static void wg_aip_remove_all(struct wg_softc *, struct wg_peer *); +static struct wg_peer *wg_peer_alloc(struct wg_softc *, const uint8_t [WG_KEY_SIZE]); +static void wg_peer_free_deferred(struct noise_remote *); +static void wg_peer_destroy(struct wg_peer *); +static void wg_peer_destroy_all(struct wg_softc *); +static void wg_peer_send_buf(struct wg_peer *, uint8_t *, 
size_t); +static void wg_send_initiation(struct wg_peer *); +static void wg_send_response(struct wg_peer *); +static void wg_send_cookie(struct wg_softc *, struct cookie_macs *, uint32_t, struct wg_endpoint *); +static void wg_peer_set_endpoint(struct wg_peer *, struct wg_endpoint *); +static void wg_peer_clear_src(struct wg_peer *); +static void wg_peer_get_endpoint(struct wg_peer *, struct wg_endpoint *); +static void wg_send_buf(struct wg_softc *, struct wg_endpoint *, uint8_t *, size_t); +static void wg_send_keepalive(struct wg_peer *); +static void wg_handshake(struct wg_softc *, struct wg_packet *); +static void wg_encrypt(struct wg_softc *, struct wg_packet *); +static void wg_decrypt(struct wg_softc *, struct wg_packet *); +static void wg_softc_handshake_receive(struct wg_softc *); +static void wg_softc_decrypt(struct wg_softc *); +static void wg_softc_encrypt(struct wg_softc *); +static void wg_encrypt_dispatch(struct wg_softc *); +static void wg_decrypt_dispatch(struct wg_softc *); +static void wg_deliver_out(struct wg_peer *); +static void wg_deliver_in(struct wg_peer *); +static struct wg_packet *wg_packet_alloc(struct mbuf *); +static void wg_packet_free(struct wg_packet *); +static void wg_queue_init(struct wg_queue *, const char *); +static void wg_queue_deinit(struct wg_queue *); +static size_t wg_queue_len(struct wg_queue *); +static int wg_queue_enqueue_handshake(struct wg_queue *, struct wg_packet *); +static struct wg_packet *wg_queue_dequeue_handshake(struct wg_queue *); +static void wg_queue_push_staged(struct wg_queue *, struct wg_packet *); +static void wg_queue_enlist_staged(struct wg_queue *, struct wg_packet_list *); +static void wg_queue_delist_staged(struct wg_queue *, struct wg_packet_list *); +static void wg_queue_purge(struct wg_queue *); +static int wg_queue_both(struct wg_queue *, struct wg_queue *, struct wg_packet *); +static struct wg_packet *wg_queue_dequeue_serial(struct wg_queue *); +static struct wg_packet *wg_queue_dequeue_parallel(struct wg_queue *); +static bool wg_input(struct mbuf *, int, struct inpcb *, const struct sockaddr *, void *); +static void wg_peer_send_staged(struct wg_peer *); +static int wg_clone_create(struct if_clone *, int, caddr_t); +static void wg_qflush(struct ifnet *); +static inline int determine_af_and_pullup(struct mbuf **m, sa_family_t *af); +static int wg_xmit(struct ifnet *, struct mbuf *, sa_family_t, uint32_t); +static int wg_transmit(struct ifnet *, struct mbuf *); +static int wg_output(struct ifnet *, struct mbuf *, const struct sockaddr *, struct route *); +static void wg_clone_destroy(struct ifnet *); +static bool wgc_privileged(struct wg_softc *); +static int wgc_get(struct wg_softc *, struct wg_data_io *); +static int wgc_set(struct wg_softc *, struct wg_data_io *); +static int wg_up(struct wg_softc *); +static void wg_down(struct wg_softc *); +static void wg_reassign(struct ifnet *, struct vnet *, char *unused); +static void wg_init(void *); +static int wg_ioctl(struct ifnet *, u_long, caddr_t); +static void vnet_wg_init(const void *); +static void vnet_wg_uninit(const void *); +static int wg_module_init(void); +static void wg_module_deinit(void); + +/* TODO Peer */ +static struct wg_peer * +wg_peer_alloc(struct wg_softc *sc, const uint8_t pub_key[WG_KEY_SIZE]) +{ + struct wg_peer *peer; + + sx_assert(&sc->sc_lock, SX_XLOCKED); + + peer = malloc(sizeof(*peer), M_WG, M_WAITOK | M_ZERO); + peer->p_remote = noise_remote_alloc(sc->sc_local, peer, pub_key); + peer->p_tx_bytes = counter_u64_alloc(M_WAITOK); + 
peer->p_rx_bytes = counter_u64_alloc(M_WAITOK); + peer->p_id = peer_counter++; + peer->p_sc = sc; + + cookie_maker_init(&peer->p_cookie, pub_key); + + rw_init(&peer->p_endpoint_lock, "wg_peer_endpoint"); + + wg_queue_init(&peer->p_stage_queue, "stageq"); + wg_queue_init(&peer->p_encrypt_serial, "txq"); + wg_queue_init(&peer->p_decrypt_serial, "rxq"); + + peer->p_enabled = false; + peer->p_need_another_keepalive = false; + peer->p_persistent_keepalive_interval = 0; + callout_init(&peer->p_new_handshake, true); + callout_init(&peer->p_send_keepalive, true); + callout_init(&peer->p_retry_handshake, true); + callout_init(&peer->p_persistent_keepalive, true); + callout_init(&peer->p_zero_key_material, true); + + mtx_init(&peer->p_handshake_mtx, "peer handshake", NULL, MTX_DEF); + bzero(&peer->p_handshake_complete, sizeof(peer->p_handshake_complete)); + peer->p_handshake_retries = 0; + + GROUPTASK_INIT(&peer->p_send, 0, (gtask_fn_t *)wg_deliver_out, peer); + taskqgroup_attach(qgroup_wg_tqg, &peer->p_send, peer, NULL, NULL, "wg send"); + GROUPTASK_INIT(&peer->p_recv, 0, (gtask_fn_t *)wg_deliver_in, peer); + taskqgroup_attach(qgroup_wg_tqg, &peer->p_recv, peer, NULL, NULL, "wg recv"); + + LIST_INIT(&peer->p_aips); + peer->p_aips_num = 0; + + return (peer); +} + +static void +wg_peer_free_deferred(struct noise_remote *r) +{ + struct wg_peer *peer = noise_remote_arg(r); + + /* While there are no references remaining, we may still have + * p_{send,recv} executing (think empty queue, but wg_deliver_{in,out} + * needs to check the queue. We should wait for them and then free. */ + GROUPTASK_DRAIN(&peer->p_recv); + GROUPTASK_DRAIN(&peer->p_send); + taskqgroup_detach(qgroup_wg_tqg, &peer->p_recv); + taskqgroup_detach(qgroup_wg_tqg, &peer->p_send); + + wg_queue_deinit(&peer->p_decrypt_serial); + wg_queue_deinit(&peer->p_encrypt_serial); + wg_queue_deinit(&peer->p_stage_queue); + + counter_u64_free(peer->p_tx_bytes); + counter_u64_free(peer->p_rx_bytes); + rw_destroy(&peer->p_endpoint_lock); + mtx_destroy(&peer->p_handshake_mtx); + + cookie_maker_free(&peer->p_cookie); + + free(peer, M_WG); +} + +static void +wg_peer_destroy(struct wg_peer *peer) +{ + struct wg_softc *sc = peer->p_sc; + sx_assert(&sc->sc_lock, SX_XLOCKED); + + /* Disable remote and timers. This will prevent any new handshakes + * occuring. */ + noise_remote_disable(peer->p_remote); + wg_timers_disable(peer); + + /* Now we can remove all allowed IPs so no more packets will be routed + * to the peer. */ + wg_aip_remove_all(sc, peer); + + /* Remove peer from the interface, then free. Some references may still + * exist to p_remote, so noise_remote_free will wait until they're all + * put to call wg_peer_free_deferred. 
*/ + sc->sc_peers_num--; + TAILQ_REMOVE(&sc->sc_peers, peer, p_entry); + DPRINTF(sc, "Peer %" PRIu64 " destroyed\n", peer->p_id); + noise_remote_free(peer->p_remote, wg_peer_free_deferred); +} + +static void +wg_peer_destroy_all(struct wg_softc *sc) +{ + struct wg_peer *peer, *tpeer; + TAILQ_FOREACH_SAFE(peer, &sc->sc_peers, p_entry, tpeer) + wg_peer_destroy(peer); +} + +static void +wg_peer_set_endpoint(struct wg_peer *peer, struct wg_endpoint *e) +{ + MPASS(e->e_remote.r_sa.sa_family != 0); + if (memcmp(e, &peer->p_endpoint, sizeof(*e)) == 0) + return; + + rw_wlock(&peer->p_endpoint_lock); + peer->p_endpoint = *e; + rw_wunlock(&peer->p_endpoint_lock); +} + +static void +wg_peer_clear_src(struct wg_peer *peer) +{ + rw_wlock(&peer->p_endpoint_lock); + bzero(&peer->p_endpoint.e_local, sizeof(peer->p_endpoint.e_local)); + rw_wunlock(&peer->p_endpoint_lock); +} + +static void +wg_peer_get_endpoint(struct wg_peer *peer, struct wg_endpoint *e) +{ + rw_rlock(&peer->p_endpoint_lock); + *e = peer->p_endpoint; + rw_runlock(&peer->p_endpoint_lock); +} + +/* Allowed IP */ +static int +wg_aip_add(struct wg_softc *sc, struct wg_peer *peer, sa_family_t af, const void *addr, uint8_t cidr) +{ + struct radix_node_head *root; + struct radix_node *node; + struct wg_aip *aip; + int ret = 0; + + aip = malloc(sizeof(*aip), M_WG, M_WAITOK | M_ZERO); + aip->a_peer = peer; + aip->a_af = af; + + switch (af) { +#ifdef INET + case AF_INET: + if (cidr > 32) cidr = 32; + root = sc->sc_aip4; + aip->a_addr.in = *(const struct in_addr *)addr; + aip->a_mask.ip = htonl(~((1LL << (32 - cidr)) - 1) & 0xffffffff); + aip->a_addr.ip &= aip->a_mask.ip; + aip->a_addr.length = aip->a_mask.length = offsetof(struct aip_addr, in) + sizeof(struct in_addr); + break; +#endif +#ifdef INET6 + case AF_INET6: + if (cidr > 128) cidr = 128; + root = sc->sc_aip6; + aip->a_addr.in6 = *(const struct in6_addr *)addr; + in6_prefixlen2mask(&aip->a_mask.in6, cidr); + for (int i = 0; i < 4; i++) + aip->a_addr.ip6[i] &= aip->a_mask.ip6[i]; + aip->a_addr.length = aip->a_mask.length = offsetof(struct aip_addr, in6) + sizeof(struct in6_addr); + break; +#endif + default: + free(aip, M_WG); + return (EAFNOSUPPORT); + } + + RADIX_NODE_HEAD_LOCK(root); + node = root->rnh_addaddr(&aip->a_addr, &aip->a_mask, &root->rh, aip->a_nodes); + if (node == aip->a_nodes) { + LIST_INSERT_HEAD(&peer->p_aips, aip, a_entry); + peer->p_aips_num++; + } else if (!node) + node = root->rnh_lookup(&aip->a_addr, &aip->a_mask, &root->rh); + if (!node) { + free(aip, M_WG); + return (ENOMEM); + } else if (node != aip->a_nodes) { + free(aip, M_WG); + aip = (struct wg_aip *)node; + if (aip->a_peer != peer) { + LIST_REMOVE(aip, a_entry); + aip->a_peer->p_aips_num--; + aip->a_peer = peer; + LIST_INSERT_HEAD(&peer->p_aips, aip, a_entry); + aip->a_peer->p_aips_num++; + } + } + RADIX_NODE_HEAD_UNLOCK(root); + return (ret); +} + +static struct wg_peer * +wg_aip_lookup(struct wg_softc *sc, sa_family_t af, void *a) +{ + struct radix_node_head *root; + struct radix_node *node; + struct wg_peer *peer; + struct aip_addr addr; + RADIX_NODE_HEAD_RLOCK_TRACKER; + + switch (af) { + case AF_INET: + root = sc->sc_aip4; + memcpy(&addr.in, a, sizeof(addr.in)); + addr.length = offsetof(struct aip_addr, in) + sizeof(struct in_addr); + break; + case AF_INET6: + root = sc->sc_aip6; + memcpy(&addr.in6, a, sizeof(addr.in6)); + addr.length = offsetof(struct aip_addr, in6) + sizeof(struct in6_addr); + break; + default: + return NULL; + } + + RADIX_NODE_HEAD_RLOCK(root); + node = root->rnh_matchaddr(&addr, 
&root->rh); + if (node != NULL) { + peer = ((struct wg_aip *)node)->a_peer; + noise_remote_ref(peer->p_remote); + } else { + peer = NULL; + } + RADIX_NODE_HEAD_RUNLOCK(root); + + return (peer); +} + +static void +wg_aip_remove_all(struct wg_softc *sc, struct wg_peer *peer) +{ + struct wg_aip *aip, *taip; + + RADIX_NODE_HEAD_LOCK(sc->sc_aip4); + LIST_FOREACH_SAFE(aip, &peer->p_aips, a_entry, taip) { + if (aip->a_af == AF_INET) { + if (sc->sc_aip4->rnh_deladdr(&aip->a_addr, &aip->a_mask, &sc->sc_aip4->rh) == NULL) + panic("failed to delete aip %p", aip); + LIST_REMOVE(aip, a_entry); + peer->p_aips_num--; + free(aip, M_WG); + } + } + RADIX_NODE_HEAD_UNLOCK(sc->sc_aip4); + + RADIX_NODE_HEAD_LOCK(sc->sc_aip6); + LIST_FOREACH_SAFE(aip, &peer->p_aips, a_entry, taip) { + if (aip->a_af == AF_INET6) { + if (sc->sc_aip6->rnh_deladdr(&aip->a_addr, &aip->a_mask, &sc->sc_aip6->rh) == NULL) + panic("failed to delete aip %p", aip); + LIST_REMOVE(aip, a_entry); + peer->p_aips_num--; + free(aip, M_WG); + } + } + RADIX_NODE_HEAD_UNLOCK(sc->sc_aip6); + + if (!LIST_EMPTY(&peer->p_aips) || peer->p_aips_num != 0) + panic("wg_aip_remove_all could not delete all %p", peer); +} + +static int +wg_socket_init(struct wg_softc *sc, in_port_t port) +{ + struct ucred *cred = sc->sc_ucred; + struct socket *so4 = NULL, *so6 = NULL; + int rc; + + sx_assert(&sc->sc_lock, SX_XLOCKED); + + if (!cred) + return (EBUSY); + + /* + * For socket creation, we use the creds of the thread that created the + * tunnel rather than the current thread to maintain the semantics that + * WireGuard has on Linux with network namespaces -- that the sockets + * are created in their home vnet so that they can be configured and + * functionally attached to a foreign vnet as the jail's only interface + * to the network. + */ +#ifdef INET + rc = socreate(AF_INET, &so4, SOCK_DGRAM, IPPROTO_UDP, cred, curthread); + if (rc) + goto out; + + rc = udp_set_kernel_tunneling(so4, wg_input, NULL, sc); + /* + * udp_set_kernel_tunneling can only fail if there is already a tunneling function set. + * This should never happen with a new socket. 
+ */ + MPASS(rc == 0); +#endif + +#ifdef INET6 + rc = socreate(AF_INET6, &so6, SOCK_DGRAM, IPPROTO_UDP, cred, curthread); + if (rc) + goto out; + rc = udp_set_kernel_tunneling(so6, wg_input, NULL, sc); + MPASS(rc == 0); +#endif + + if (sc->sc_socket.so_user_cookie) { + rc = wg_socket_set_sockopt(so4, so6, SO_USER_COOKIE, &sc->sc_socket.so_user_cookie, sizeof(sc->sc_socket.so_user_cookie)); + if (rc) + goto out; + } + rc = wg_socket_set_sockopt(so4, so6, SO_SETFIB, &sc->sc_socket.so_fibnum, sizeof(sc->sc_socket.so_fibnum)); + if (rc) + goto out; + + rc = wg_socket_bind(&so4, &so6, &port); + if (!rc) { + sc->sc_socket.so_port = port; + wg_socket_set(sc, so4, so6); + } +out: + if (rc) { + if (so4 != NULL) + soclose(so4); + if (so6 != NULL) + soclose(so6); + } + return (rc); +} + +static int wg_socket_set_sockopt(struct socket *so4, struct socket *so6, int name, void *val, size_t len) +{ + int ret4 = 0, ret6 = 0; + struct sockopt sopt = { + .sopt_dir = SOPT_SET, + .sopt_level = SOL_SOCKET, + .sopt_name = name, + .sopt_val = val, + .sopt_valsize = len + }; + + if (so4) + ret4 = sosetopt(so4, &sopt); + if (so6) + ret6 = sosetopt(so6, &sopt); + return (ret4 ?: ret6); +} + +static int wg_socket_set_cookie(struct wg_softc *sc, uint32_t user_cookie) +{ + struct wg_socket *so = &sc->sc_socket; + int ret; + + sx_assert(&sc->sc_lock, SX_XLOCKED); + ret = wg_socket_set_sockopt(so->so_so4, so->so_so6, SO_USER_COOKIE, &user_cookie, sizeof(user_cookie)); + if (!ret) + so->so_user_cookie = user_cookie; + return (ret); +} + +static int wg_socket_set_fibnum(struct wg_softc *sc, int fibnum) +{ + struct wg_socket *so = &sc->sc_socket; + int ret; + + sx_assert(&sc->sc_lock, SX_XLOCKED); + + ret = wg_socket_set_sockopt(so->so_so4, so->so_so6, SO_SETFIB, &fibnum, sizeof(fibnum)); + if (!ret) + so->so_fibnum = fibnum; + return (ret); +} + +static void +wg_socket_uninit(struct wg_softc *sc) +{ + wg_socket_set(sc, NULL, NULL); +} + +static void +wg_socket_set(struct wg_softc *sc, struct socket *new_so4, struct socket *new_so6) +{ + struct wg_socket *so = &sc->sc_socket; + struct socket *so4, *so6; + + sx_assert(&sc->sc_lock, SX_XLOCKED); + + so4 = ck_pr_load_ptr(&so->so_so4); + so6 = ck_pr_load_ptr(&so->so_so6); + ck_pr_store_ptr(&so->so_so4, new_so4); + ck_pr_store_ptr(&so->so_so6, new_so6); + + if (!so4 && !so6) + return; + NET_EPOCH_WAIT(); + if (so4) + soclose(so4); + if (so6) + soclose(so6); +} + +static int +wg_socket_bind(struct socket **in_so4, struct socket **in_so6, in_port_t *requested_port) +{ + struct socket *so4 = *in_so4, *so6 = *in_so6; + int ret4 = 0, ret6 = 0; + in_port_t port = *requested_port; + struct sockaddr_in sin = { + .sin_len = sizeof(struct sockaddr_in), + .sin_family = AF_INET, + .sin_port = htons(port) + }; + struct sockaddr_in6 sin6 = { + .sin6_len = sizeof(struct sockaddr_in6), + .sin6_family = AF_INET6, + .sin6_port = htons(port) + }; + + if (so4) { + ret4 = sobind(so4, (struct sockaddr *)&sin, curthread); + if (ret4 && ret4 != EADDRNOTAVAIL) + return (ret4); + if (!ret4 && !sin.sin_port) { + struct sockaddr_in *bound_sin; + int ret = so4->so_proto->pr_sockaddr(so4, + (struct sockaddr **)&bound_sin); + if (ret) + return (ret); + port = ntohs(bound_sin->sin_port); + sin6.sin6_port = bound_sin->sin_port; + free(bound_sin, M_SONAME); + } + } + + if (so6) { + ret6 = sobind(so6, (struct sockaddr *)&sin6, curthread); + if (ret6 && ret6 != EADDRNOTAVAIL) + return (ret6); + if (!ret6 && !sin6.sin6_port) { + struct sockaddr_in6 *bound_sin6; + int ret = so6->so_proto->pr_sockaddr(so6, + (struct 
sockaddr **)&bound_sin6); + if (ret) + return (ret); + port = ntohs(bound_sin6->sin6_port); + free(bound_sin6, M_SONAME); + } + } + + if (ret4 && ret6) + return (ret4); + *requested_port = port; + if (ret4 && !ret6 && so4) { + soclose(so4); + *in_so4 = NULL; + } else if (ret6 && !ret4 && so6) { + soclose(so6); + *in_so6 = NULL; + } + return (0); +} + +static int +wg_send(struct wg_softc *sc, struct wg_endpoint *e, struct mbuf *m) +{ + struct epoch_tracker et; + struct sockaddr *sa; + struct wg_socket *so = &sc->sc_socket; + struct socket *so4, *so6; + struct mbuf *control = NULL; + int ret = 0; + size_t len = m->m_pkthdr.len; + + /* Get local control address before locking */ + if (e->e_remote.r_sa.sa_family == AF_INET) { + if (e->e_local.l_in.s_addr != INADDR_ANY) + control = sbcreatecontrol((caddr_t)&e->e_local.l_in, + sizeof(struct in_addr), IP_SENDSRCADDR, + IPPROTO_IP, M_NOWAIT); +#ifdef INET6 + } else if (e->e_remote.r_sa.sa_family == AF_INET6) { + if (!IN6_IS_ADDR_UNSPECIFIED(&e->e_local.l_in6)) + control = sbcreatecontrol((caddr_t)&e->e_local.l_pktinfo6, + sizeof(struct in6_pktinfo), IPV6_PKTINFO, + IPPROTO_IPV6, M_NOWAIT); +#endif + } else { + m_freem(m); + return (EAFNOSUPPORT); + } + + /* Get remote address */ + sa = &e->e_remote.r_sa; + + NET_EPOCH_ENTER(et); + so4 = ck_pr_load_ptr(&so->so_so4); + so6 = ck_pr_load_ptr(&so->so_so6); + if (e->e_remote.r_sa.sa_family == AF_INET && so4 != NULL) + ret = sosend(so4, sa, NULL, m, control, 0, curthread); + else if (e->e_remote.r_sa.sa_family == AF_INET6 && so6 != NULL) + ret = sosend(so6, sa, NULL, m, control, 0, curthread); + else { + ret = ENOTCONN; + m_freem(control); + m_freem(m); + } + NET_EPOCH_EXIT(et); + if (ret == 0) { + if_inc_counter(sc->sc_ifp, IFCOUNTER_OPACKETS, 1); + if_inc_counter(sc->sc_ifp, IFCOUNTER_OBYTES, len); + } + return (ret); +} + +static void +wg_send_buf(struct wg_softc *sc, struct wg_endpoint *e, uint8_t *buf, size_t len) +{ + struct mbuf *m; + int ret = 0; + bool retried = false; + +retry: + m = m_get2(len, M_NOWAIT, MT_DATA, M_PKTHDR); + if (!m) { + ret = ENOMEM; + goto out; + } + m_copyback(m, 0, len, buf); + + if (ret == 0) { + ret = wg_send(sc, e, m); + /* Retry if we couldn't bind to e->e_local */ + if (ret == EADDRNOTAVAIL && !retried) { + bzero(&e->e_local, sizeof(e->e_local)); + retried = true; + goto retry; + } + } else { + ret = wg_send(sc, e, m); + } +out: + if (ret) + DPRINTF(sc, "Unable to send packet: %d\n", ret); +} + +/* Timers */ +static void +wg_timers_enable(struct wg_peer *peer) +{ + ck_pr_store_bool(&peer->p_enabled, true); + wg_timers_run_persistent_keepalive(peer); +} + +static void +wg_timers_disable(struct wg_peer *peer) +{ + /* By setting p_enabled = false, then calling NET_EPOCH_WAIT, we can be + * sure no new handshakes are created after the wait. This is because + * all callout_resets (scheduling the callout) are guarded by + * p_enabled. We can be sure all sections that read p_enabled and then + * optionally call callout_reset are finished as they are surrounded by + * NET_EPOCH_{ENTER,EXIT}. + * + * However, as new callouts may be scheduled during NET_EPOCH_WAIT (but + * not after), we stop all callouts leaving no callouts active. + * + * We should also pull NET_EPOCH_WAIT out of the FOREACH(peer) loops, but the + * performance impact is acceptable for the time being. 
*/ + ck_pr_store_bool(&peer->p_enabled, false); + NET_EPOCH_WAIT(); + ck_pr_store_bool(&peer->p_need_another_keepalive, false); + + callout_stop(&peer->p_new_handshake); + callout_stop(&peer->p_send_keepalive); + callout_stop(&peer->p_retry_handshake); + callout_stop(&peer->p_persistent_keepalive); + callout_stop(&peer->p_zero_key_material); +} + +static void +wg_timers_set_persistent_keepalive(struct wg_peer *peer, uint16_t interval) +{ + struct epoch_tracker et; + if (interval != peer->p_persistent_keepalive_interval) { + ck_pr_store_16(&peer->p_persistent_keepalive_interval, interval); + NET_EPOCH_ENTER(et); + if (ck_pr_load_bool(&peer->p_enabled)) + wg_timers_run_persistent_keepalive(peer); + NET_EPOCH_EXIT(et); + } +} + +static void +wg_timers_get_last_handshake(struct wg_peer *peer, struct wg_timespec64 *time) +{ + mtx_lock(&peer->p_handshake_mtx); + time->tv_sec = peer->p_handshake_complete.tv_sec; + time->tv_nsec = peer->p_handshake_complete.tv_nsec; + mtx_unlock(&peer->p_handshake_mtx); +} + +static void +wg_timers_event_data_sent(struct wg_peer *peer) +{ + struct epoch_tracker et; + NET_EPOCH_ENTER(et); + if (ck_pr_load_bool(&peer->p_enabled) && !callout_pending(&peer->p_new_handshake)) + callout_reset(&peer->p_new_handshake, MSEC_2_TICKS( + NEW_HANDSHAKE_TIMEOUT * 1000 + + arc4random_uniform(REKEY_TIMEOUT_JITTER)), + wg_timers_run_new_handshake, peer); + NET_EPOCH_EXIT(et); +} + +static void +wg_timers_event_data_received(struct wg_peer *peer) +{ + struct epoch_tracker et; + NET_EPOCH_ENTER(et); + if (ck_pr_load_bool(&peer->p_enabled)) { + if (!callout_pending(&peer->p_send_keepalive)) + callout_reset(&peer->p_send_keepalive, + MSEC_2_TICKS(KEEPALIVE_TIMEOUT * 1000), + wg_timers_run_send_keepalive, peer); + else + ck_pr_store_bool(&peer->p_need_another_keepalive, true); + } + NET_EPOCH_EXIT(et); +} + +static void +wg_timers_event_any_authenticated_packet_sent(struct wg_peer *peer) +{ + callout_stop(&peer->p_send_keepalive); +} + +static void +wg_timers_event_any_authenticated_packet_received(struct wg_peer *peer) +{ + callout_stop(&peer->p_new_handshake); +} + +static void +wg_timers_event_any_authenticated_packet_traversal(struct wg_peer *peer) +{ + struct epoch_tracker et; + uint16_t interval; + NET_EPOCH_ENTER(et); + interval = ck_pr_load_16(&peer->p_persistent_keepalive_interval); + if (ck_pr_load_bool(&peer->p_enabled) && interval > 0) + callout_reset(&peer->p_persistent_keepalive, + MSEC_2_TICKS(interval * 1000), + wg_timers_run_persistent_keepalive, peer); + NET_EPOCH_EXIT(et); +} + +static void +wg_timers_event_handshake_initiated(struct wg_peer *peer) +{ + struct epoch_tracker et; + NET_EPOCH_ENTER(et); + if (ck_pr_load_bool(&peer->p_enabled)) + callout_reset(&peer->p_retry_handshake, MSEC_2_TICKS( + REKEY_TIMEOUT * 1000 + + arc4random_uniform(REKEY_TIMEOUT_JITTER)), + wg_timers_run_retry_handshake, peer); + NET_EPOCH_EXIT(et); +} + +static void +wg_timers_event_handshake_complete(struct wg_peer *peer) +{ + struct epoch_tracker et; + NET_EPOCH_ENTER(et); + if (ck_pr_load_bool(&peer->p_enabled)) { + mtx_lock(&peer->p_handshake_mtx); + callout_stop(&peer->p_retry_handshake); + peer->p_handshake_retries = 0; + getnanotime(&peer->p_handshake_complete); + mtx_unlock(&peer->p_handshake_mtx); + wg_timers_run_send_keepalive(peer); + } + NET_EPOCH_EXIT(et); +} + +static void +wg_timers_event_session_derived(struct wg_peer *peer) +{ + struct epoch_tracker et; + NET_EPOCH_ENTER(et); + if (ck_pr_load_bool(&peer->p_enabled)) + callout_reset(&peer->p_zero_key_material, + 
MSEC_2_TICKS(REJECT_AFTER_TIME * 3 * 1000), + wg_timers_run_zero_key_material, peer); + NET_EPOCH_EXIT(et); +} + +static void +wg_timers_event_want_initiation(struct wg_peer *peer) +{ + struct epoch_tracker et; + NET_EPOCH_ENTER(et); + if (ck_pr_load_bool(&peer->p_enabled)) + wg_timers_run_send_initiation(peer, false); + NET_EPOCH_EXIT(et); +} + +static void +wg_timers_run_send_initiation(struct wg_peer *peer, bool is_retry) +{ + if (!is_retry) + peer->p_handshake_retries = 0; + if (noise_remote_initiation_expired(peer->p_remote) == ETIMEDOUT) + wg_send_initiation(peer); +} + +static void +wg_timers_run_retry_handshake(void *_peer) +{ + struct epoch_tracker et; + struct wg_peer *peer = _peer; + + mtx_lock(&peer->p_handshake_mtx); + if (peer->p_handshake_retries <= MAX_TIMER_HANDSHAKES) { + peer->p_handshake_retries++; + mtx_unlock(&peer->p_handshake_mtx); + + DPRINTF(peer->p_sc, "Handshake for peer %" PRIu64 " did not complete " + "after %d seconds, retrying (try %d)\n", peer->p_id, + REKEY_TIMEOUT, peer->p_handshake_retries + 1); + wg_peer_clear_src(peer); + wg_timers_run_send_initiation(peer, true); + } else { + mtx_unlock(&peer->p_handshake_mtx); + + DPRINTF(peer->p_sc, "Handshake for peer %" PRIu64 " did not complete " + "after %d retries, giving up\n", peer->p_id, + MAX_TIMER_HANDSHAKES + 2); + + callout_stop(&peer->p_send_keepalive); + wg_queue_purge(&peer->p_stage_queue); + NET_EPOCH_ENTER(et); + if (ck_pr_load_bool(&peer->p_enabled) && + !callout_pending(&peer->p_zero_key_material)) + callout_reset(&peer->p_zero_key_material, + MSEC_2_TICKS(REJECT_AFTER_TIME * 3 * 1000), + wg_timers_run_zero_key_material, peer); + NET_EPOCH_EXIT(et); + } +} + +static void +wg_timers_run_send_keepalive(void *_peer) +{ + struct epoch_tracker et; + struct wg_peer *peer = _peer; + + wg_send_keepalive(peer); + NET_EPOCH_ENTER(et); + if (ck_pr_load_bool(&peer->p_enabled) && + ck_pr_load_bool(&peer->p_need_another_keepalive)) { + ck_pr_store_bool(&peer->p_need_another_keepalive, false); + callout_reset(&peer->p_send_keepalive, + MSEC_2_TICKS(KEEPALIVE_TIMEOUT * 1000), + wg_timers_run_send_keepalive, peer); + } + NET_EPOCH_EXIT(et); +} + +static void +wg_timers_run_new_handshake(void *_peer) +{ + struct wg_peer *peer = _peer; + + DPRINTF(peer->p_sc, "Retrying handshake with peer %" PRIu64 " because we " + "stopped hearing back after %d seconds\n", + peer->p_id, NEW_HANDSHAKE_TIMEOUT); + + wg_peer_clear_src(peer); + wg_timers_run_send_initiation(peer, false); +} + +static void +wg_timers_run_zero_key_material(void *_peer) +{ + struct wg_peer *peer = _peer; + + DPRINTF(peer->p_sc, "Zeroing out keys for peer %" PRIu64 ", since we " + "haven't received a new one in %d seconds\n", + peer->p_id, REJECT_AFTER_TIME * 3); + noise_remote_keypairs_clear(peer->p_remote); +} + +static void +wg_timers_run_persistent_keepalive(void *_peer) +{ + struct wg_peer *peer = _peer; + + if (ck_pr_load_16(&peer->p_persistent_keepalive_interval) > 0) + wg_send_keepalive(peer); +} + +/* TODO Handshake */ +static void +wg_peer_send_buf(struct wg_peer *peer, uint8_t *buf, size_t len) +{ + struct wg_endpoint endpoint; + + counter_u64_add(peer->p_tx_bytes, len); + wg_timers_event_any_authenticated_packet_traversal(peer); + wg_timers_event_any_authenticated_packet_sent(peer); + wg_peer_get_endpoint(peer, &endpoint); + wg_send_buf(peer->p_sc, &endpoint, buf, len); +} + +static void +wg_send_initiation(struct wg_peer *peer) +{ + struct wg_pkt_initiation pkt; + + if (noise_create_initiation(peer->p_remote, &pkt.s_idx, pkt.ue, + pkt.es, 
pkt.ets) != 0) + return; + + DPRINTF(peer->p_sc, "Sending handshake initiation to peer %" PRIu64 "\n", peer->p_id); + + pkt.t = WG_PKT_INITIATION; + cookie_maker_mac(&peer->p_cookie, &pkt.m, &pkt, + sizeof(pkt) - sizeof(pkt.m)); + wg_peer_send_buf(peer, (uint8_t *)&pkt, sizeof(pkt)); + wg_timers_event_handshake_initiated(peer); +} + +static void +wg_send_response(struct wg_peer *peer) +{ + struct wg_pkt_response pkt; + + if (noise_create_response(peer->p_remote, &pkt.s_idx, &pkt.r_idx, + pkt.ue, pkt.en) != 0) + return; + + DPRINTF(peer->p_sc, "Sending handshake response to peer %" PRIu64 "\n", peer->p_id); + + wg_timers_event_session_derived(peer); + pkt.t = WG_PKT_RESPONSE; + cookie_maker_mac(&peer->p_cookie, &pkt.m, &pkt, + sizeof(pkt)-sizeof(pkt.m)); + wg_peer_send_buf(peer, (uint8_t*)&pkt, sizeof(pkt)); +} + +static void +wg_send_cookie(struct wg_softc *sc, struct cookie_macs *cm, uint32_t idx, + struct wg_endpoint *e) +{ + struct wg_pkt_cookie pkt; + + DPRINTF(sc, "Sending cookie response for denied handshake message\n"); + + pkt.t = WG_PKT_COOKIE; + pkt.r_idx = idx; + + cookie_checker_create_payload(&sc->sc_cookie, cm, pkt.nonce, + pkt.ec, &e->e_remote.r_sa); + wg_send_buf(sc, e, (uint8_t *)&pkt, sizeof(pkt)); +} + +static void +wg_send_keepalive(struct wg_peer *peer) +{ + struct wg_packet *pkt; + struct mbuf *m; + + if (wg_queue_len(&peer->p_stage_queue) > 0) + goto send; + if ((m = m_gethdr(M_NOWAIT, MT_DATA)) == NULL) + return; + if ((pkt = wg_packet_alloc(m)) == NULL) { + m_freem(m); + return; + } + wg_queue_push_staged(&peer->p_stage_queue, pkt); + DPRINTF(peer->p_sc, "Sending keepalive packet to peer %" PRIu64 "\n", peer->p_id); +send: + wg_peer_send_staged(peer); +} + +static void +wg_handshake(struct wg_softc *sc, struct wg_packet *pkt) +{ + struct wg_pkt_initiation *init; + struct wg_pkt_response *resp; + struct wg_pkt_cookie *cook; + struct wg_endpoint *e; + struct wg_peer *peer; + struct mbuf *m; + struct noise_remote *remote = NULL; + int res; + bool underload = false; + static sbintime_t wg_last_underload; /* sbinuptime */ + + underload = wg_queue_len(&sc->sc_handshake_queue) >= MAX_QUEUED_HANDSHAKES / 8; + if (underload) { + wg_last_underload = getsbinuptime(); + } else if (wg_last_underload) { + underload = wg_last_underload + UNDERLOAD_TIMEOUT * SBT_1S > getsbinuptime(); + if (!underload) + wg_last_underload = 0; + } + + m = pkt->p_mbuf; + e = &pkt->p_endpoint; + + if ((pkt->p_mbuf = m = m_pullup(m, m->m_pkthdr.len)) == NULL) + goto error; + + switch (*mtod(m, uint32_t *)) { + case WG_PKT_INITIATION: + init = mtod(m, struct wg_pkt_initiation *); + + res = cookie_checker_validate_macs(&sc->sc_cookie, &init->m, + init, sizeof(*init) - sizeof(init->m), + underload, &e->e_remote.r_sa, + sc->sc_ifp->if_vnet); + + if (res == EINVAL) { + DPRINTF(sc, "Invalid initiation MAC\n"); + goto error; + } else if (res == ECONNREFUSED) { + DPRINTF(sc, "Handshake ratelimited\n"); + goto error; + } else if (res == EAGAIN) { + wg_send_cookie(sc, &init->m, init->s_idx, e); + goto error; + } else if (res != 0) { + panic("unexpected response: %d\n", res); + } + + if (noise_consume_initiation(sc->sc_local, &remote, + init->s_idx, init->ue, init->es, init->ets) != 0) { + DPRINTF(sc, "Invalid handshake initiation\n"); + goto error; + } + + peer = noise_remote_arg(remote); + + DPRINTF(sc, "Receiving handshake initiation from peer %" PRIu64 "\n", peer->p_id); + + wg_peer_set_endpoint(peer, e); + wg_send_response(peer); + break; + case WG_PKT_RESPONSE: + resp = mtod(m, struct wg_pkt_response *); + 
+		res = cookie_checker_validate_macs(&sc->sc_cookie, &resp->m,
+		    resp, sizeof(*resp) - sizeof(resp->m),
+		    underload, &e->e_remote.r_sa,
+		    sc->sc_ifp->if_vnet);
+
+		if (res == EINVAL) {
+			DPRINTF(sc, "Invalid response MAC\n");
+			goto error;
+		} else if (res == ECONNREFUSED) {
+			DPRINTF(sc, "Handshake ratelimited\n");
+			goto error;
+		} else if (res == EAGAIN) {
+			wg_send_cookie(sc, &resp->m, resp->s_idx, e);
+			goto error;
+		} else if (res != 0) {
+			panic("unexpected response: %d\n", res);
+		}
+
+		if (noise_consume_response(sc->sc_local, &remote,
+		    resp->s_idx, resp->r_idx, resp->ue, resp->en) != 0) {
+			DPRINTF(sc, "Invalid handshake response\n");
+			goto error;
+		}
+
+		peer = noise_remote_arg(remote);
+		DPRINTF(sc, "Receiving handshake response from peer %" PRIu64 "\n", peer->p_id);
+
+		wg_peer_set_endpoint(peer, e);
+		wg_timers_event_session_derived(peer);
+		wg_timers_event_handshake_complete(peer);
+		break;
+	case WG_PKT_COOKIE:
+		cook = mtod(m, struct wg_pkt_cookie *);
+
+		if ((remote = noise_remote_index(sc->sc_local, cook->r_idx)) == NULL) {
+			DPRINTF(sc, "Unknown cookie index\n");
+			goto error;
+		}
+
+		peer = noise_remote_arg(remote);
+
+		if (cookie_maker_consume_payload(&peer->p_cookie,
+		    cook->nonce, cook->ec) == 0) {
+			DPRINTF(sc, "Receiving cookie response\n");
+		} else {
+			DPRINTF(sc, "Could not decrypt cookie response\n");
+			goto error;
+		}
+
+		goto not_authenticated;
+	default:
+		panic("invalid packet in handshake queue");
+	}
+
+	wg_timers_event_any_authenticated_packet_received(peer);
+	wg_timers_event_any_authenticated_packet_traversal(peer);
+
+not_authenticated:
+	counter_u64_add(peer->p_rx_bytes, m->m_pkthdr.len);
+	if_inc_counter(sc->sc_ifp, IFCOUNTER_IPACKETS, 1);
+	if_inc_counter(sc->sc_ifp, IFCOUNTER_IBYTES, m->m_pkthdr.len);
+error:
+	if (remote != NULL)
+		noise_remote_put(remote);
+	wg_packet_free(pkt);
+}
+
+static void
+wg_softc_handshake_receive(struct wg_softc *sc)
+{
+	struct wg_packet *pkt;
+	while ((pkt = wg_queue_dequeue_handshake(&sc->sc_handshake_queue)) != NULL)
+		wg_handshake(sc, pkt);
+}
+
+static void
+wg_mbuf_reset(struct mbuf *m)
+{
+	struct m_tag *t, *tmp;
+
+	/*
+	 * We want to reset the mbuf to a newly allocated state, containing
+	 * just the packet contents. Unfortunately FreeBSD doesn't seem to
+	 * offer this anywhere, so we have to make it up as we go. If we can
+	 * get this in kern/kern_mbuf.c, that would be best.
+	 *
+	 * Notice: this may break things unexpectedly but it is better to fail
+	 * closed in the extreme case than leak information in every
+	 * case.
+	 *
+	 * With that said, all this attempts to do is remove any extraneous
+	 * information that could be present.
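+	 * Concretely, every m_tag other than our loop-count tag and any
+	 * MAC label is deleted, and the flowid, checksum flags and
+	 * protocol-private header fields are cleared below.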
+ */ + + M_ASSERTPKTHDR(m); + + m->m_flags &= ~(M_BCAST|M_MCAST|M_VLANTAG|M_PROMISC|M_PROTOFLAGS); + + M_HASHTYPE_CLEAR(m); +#ifdef NUMA + m->m_pkthdr.numa_domain = M_NODOM; +#endif + SLIST_FOREACH_SAFE(t, &m->m_pkthdr.tags, m_tag_link, tmp) { + if ((t->m_tag_id != 0 || t->m_tag_cookie != MTAG_WGLOOP) && + t->m_tag_id != PACKET_TAG_MACLABEL) + m_tag_delete(m, t); + } + + KASSERT((m->m_pkthdr.csum_flags & CSUM_SND_TAG) == 0, + ("%s: mbuf %p has a send tag", __func__, m)); + + m->m_pkthdr.csum_flags = 0; + m->m_pkthdr.PH_per.sixtyfour[0] = 0; + m->m_pkthdr.PH_loc.sixtyfour[0] = 0; +} + +static inline unsigned int +calculate_padding(struct wg_packet *pkt) +{ + unsigned int padded_size, last_unit = pkt->p_mbuf->m_pkthdr.len; + + if (__predict_false(!pkt->p_mtu)) + return (last_unit + (WG_PKT_PADDING - 1)) & ~(WG_PKT_PADDING - 1); + + if (__predict_false(last_unit > pkt->p_mtu)) + last_unit %= pkt->p_mtu; + + padded_size = (last_unit + (WG_PKT_PADDING - 1)) & ~(WG_PKT_PADDING - 1); + if (pkt->p_mtu < padded_size) + padded_size = pkt->p_mtu; + return padded_size - last_unit; +} + +static void +wg_encrypt(struct wg_softc *sc, struct wg_packet *pkt) +{ + static const uint8_t padding[WG_PKT_PADDING] = { 0 }; + struct wg_pkt_data *data; + struct wg_peer *peer; + struct noise_remote *remote; + struct mbuf *m; + uint32_t idx; + unsigned int padlen; + enum wg_ring_state state = WG_PACKET_DEAD; + + remote = noise_keypair_remote(pkt->p_keypair); + peer = noise_remote_arg(remote); + m = pkt->p_mbuf; + + /* Pad the packet */ + padlen = calculate_padding(pkt); + if (padlen != 0 && !m_append(m, padlen, padding)) + goto out; + + /* Do encryption */ + if (noise_keypair_encrypt(pkt->p_keypair, &idx, pkt->p_nonce, m) != 0) + goto out; + + /* Put header into packet */ + M_PREPEND(m, sizeof(struct wg_pkt_data), M_NOWAIT); + if (m == NULL) + goto out; + data = mtod(m, struct wg_pkt_data *); + data->t = WG_PKT_DATA; + data->r_idx = idx; + data->nonce = htole64(pkt->p_nonce); + + wg_mbuf_reset(m); + state = WG_PACKET_CRYPTED; +out: + pkt->p_mbuf = m; + wmb(); + pkt->p_state = state; + GROUPTASK_ENQUEUE(&peer->p_send); + noise_remote_put(remote); +} + +static void +wg_decrypt(struct wg_softc *sc, struct wg_packet *pkt) +{ + struct wg_peer *peer, *allowed_peer; + struct noise_remote *remote; + struct mbuf *m; + int len; + enum wg_ring_state state = WG_PACKET_DEAD; + + remote = noise_keypair_remote(pkt->p_keypair); + peer = noise_remote_arg(remote); + m = pkt->p_mbuf; + + /* Read nonce and then adjust to remove the header. */ + pkt->p_nonce = le64toh(mtod(m, struct wg_pkt_data *)->nonce); + m_adj(m, sizeof(struct wg_pkt_data)); + + if (noise_keypair_decrypt(pkt->p_keypair, pkt->p_nonce, m) != 0) + goto out; + + /* A packet with length 0 is a keepalive packet */ + if (__predict_false(m->m_pkthdr.len == 0)) { + DPRINTF(sc, "Receiving keepalive packet from peer " + "%" PRIu64 "\n", peer->p_id); + state = WG_PACKET_CRYPTED; + goto out; + } + + /* + * We can let the network stack handle the intricate validation of the + * IP header, we just worry about the sizeof and the version, so we can + * read the source address in wg_aip_lookup. 
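+	 * After the lookup we also trim the mbuf to the length claimed by
+	 * the IP header, since the decrypted payload may carry WireGuard
+	 * padding beyond the end of the inner datagram.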
+ */ + + if (determine_af_and_pullup(&m, &pkt->p_af) == 0) { + if (pkt->p_af == AF_INET) { + struct ip *ip = mtod(m, struct ip *); + allowed_peer = wg_aip_lookup(sc, AF_INET, &ip->ip_src); + len = ntohs(ip->ip_len); + if (len >= sizeof(struct ip) && len < m->m_pkthdr.len) + m_adj(m, len - m->m_pkthdr.len); + } else if (pkt->p_af == AF_INET6) { + struct ip6_hdr *ip6 = mtod(m, struct ip6_hdr *); + allowed_peer = wg_aip_lookup(sc, AF_INET6, &ip6->ip6_src); + len = ntohs(ip6->ip6_plen) + sizeof(struct ip6_hdr); + if (len < m->m_pkthdr.len) + m_adj(m, len - m->m_pkthdr.len); + } else + panic("determine_af_and_pullup returned unexpected value"); + } else { + DPRINTF(sc, "Packet is neither ipv4 nor ipv6 from peer %" PRIu64 "\n", peer->p_id); + goto out; + } + + /* We only want to compare the address, not dereference, so drop the ref. */ + if (allowed_peer != NULL) + noise_remote_put(allowed_peer->p_remote); + + if (__predict_false(peer != allowed_peer)) { + DPRINTF(sc, "Packet has unallowed src IP from peer %" PRIu64 "\n", peer->p_id); + goto out; + } + + wg_mbuf_reset(m); + state = WG_PACKET_CRYPTED; +out: + pkt->p_mbuf = m; + wmb(); + pkt->p_state = state; + GROUPTASK_ENQUEUE(&peer->p_recv); + noise_remote_put(remote); +} + +static void +wg_softc_decrypt(struct wg_softc *sc) +{ + struct wg_packet *pkt; + + while ((pkt = wg_queue_dequeue_parallel(&sc->sc_decrypt_parallel)) != NULL) + wg_decrypt(sc, pkt); +} + +static void +wg_softc_encrypt(struct wg_softc *sc) +{ + struct wg_packet *pkt; + + while ((pkt = wg_queue_dequeue_parallel(&sc->sc_encrypt_parallel)) != NULL) + wg_encrypt(sc, pkt); +} + +static void +wg_encrypt_dispatch(struct wg_softc *sc) +{ + /* + * The update to encrypt_last_cpu is racey such that we may + * reschedule the task for the same CPU multiple times, but + * the race doesn't really matter. 
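+	 * At worst the same CPU is picked twice in a row, so the
+	 * round-robin spread of crypto work is briefly uneven.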
+ */ + u_int cpu = (sc->sc_encrypt_last_cpu + 1) % mp_ncpus; + sc->sc_encrypt_last_cpu = cpu; + GROUPTASK_ENQUEUE(&sc->sc_encrypt[cpu]); +} + +static void +wg_decrypt_dispatch(struct wg_softc *sc) +{ + u_int cpu = (sc->sc_decrypt_last_cpu + 1) % mp_ncpus; + sc->sc_decrypt_last_cpu = cpu; + GROUPTASK_ENQUEUE(&sc->sc_decrypt[cpu]); +} + +static void +wg_deliver_out(struct wg_peer *peer) +{ + struct wg_endpoint endpoint; + struct wg_softc *sc = peer->p_sc; + struct wg_packet *pkt; + struct mbuf *m; + int rc, len; + + wg_peer_get_endpoint(peer, &endpoint); + + while ((pkt = wg_queue_dequeue_serial(&peer->p_encrypt_serial)) != NULL) { + if (pkt->p_state != WG_PACKET_CRYPTED) + goto error; + + m = pkt->p_mbuf; + pkt->p_mbuf = NULL; + + len = m->m_pkthdr.len; + + wg_timers_event_any_authenticated_packet_traversal(peer); + wg_timers_event_any_authenticated_packet_sent(peer); + rc = wg_send(sc, &endpoint, m); + if (rc == 0) { + if (len > (sizeof(struct wg_pkt_data) + NOISE_AUTHTAG_LEN)) + wg_timers_event_data_sent(peer); + counter_u64_add(peer->p_tx_bytes, len); + } else if (rc == EADDRNOTAVAIL) { + wg_peer_clear_src(peer); + wg_peer_get_endpoint(peer, &endpoint); + goto error; + } else { + goto error; + } + wg_packet_free(pkt); + if (noise_keep_key_fresh_send(peer->p_remote)) + wg_timers_event_want_initiation(peer); + continue; +error: + if_inc_counter(sc->sc_ifp, IFCOUNTER_OERRORS, 1); + wg_packet_free(pkt); + } +} + +static void +wg_deliver_in(struct wg_peer *peer) +{ + struct wg_softc *sc = peer->p_sc; + struct ifnet *ifp = sc->sc_ifp; + struct wg_packet *pkt; + struct mbuf *m; + struct epoch_tracker et; + + while ((pkt = wg_queue_dequeue_serial(&peer->p_decrypt_serial)) != NULL) { + if (pkt->p_state != WG_PACKET_CRYPTED) + goto error; + + m = pkt->p_mbuf; + if (noise_keypair_nonce_check(pkt->p_keypair, pkt->p_nonce) != 0) + goto error; + + if (noise_keypair_received_with(pkt->p_keypair) == ECONNRESET) + wg_timers_event_handshake_complete(peer); + + wg_timers_event_any_authenticated_packet_received(peer); + wg_timers_event_any_authenticated_packet_traversal(peer); + wg_peer_set_endpoint(peer, &pkt->p_endpoint); + + counter_u64_add(peer->p_rx_bytes, m->m_pkthdr.len + + sizeof(struct wg_pkt_data) + NOISE_AUTHTAG_LEN); + if_inc_counter(sc->sc_ifp, IFCOUNTER_IPACKETS, 1); + if_inc_counter(sc->sc_ifp, IFCOUNTER_IBYTES, m->m_pkthdr.len + + sizeof(struct wg_pkt_data) + NOISE_AUTHTAG_LEN); + + if (m->m_pkthdr.len == 0) + goto done; + + MPASS(pkt->p_af == AF_INET || pkt->p_af == AF_INET6); + pkt->p_mbuf = NULL; + + m->m_pkthdr.rcvif = ifp; + + NET_EPOCH_ENTER(et); + BPF_MTAP2_AF(ifp, m, pkt->p_af); + + CURVNET_SET(ifp->if_vnet); + M_SETFIB(m, ifp->if_fib); + if (pkt->p_af == AF_INET) + netisr_dispatch(NETISR_IP, m); + if (pkt->p_af == AF_INET6) + netisr_dispatch(NETISR_IPV6, m); + CURVNET_RESTORE(); + NET_EPOCH_EXIT(et); + + wg_timers_event_data_received(peer); + +done: + if (noise_keep_key_fresh_recv(peer->p_remote)) + wg_timers_event_want_initiation(peer); + wg_packet_free(pkt); + continue; +error: + if_inc_counter(ifp, IFCOUNTER_IERRORS, 1); + wg_packet_free(pkt); + } +} + +static struct wg_packet * +wg_packet_alloc(struct mbuf *m) +{ + struct wg_packet *pkt; + + if ((pkt = uma_zalloc(wg_packet_zone, M_NOWAIT | M_ZERO)) == NULL) + return (NULL); + pkt->p_mbuf = m; + return (pkt); +} + +static void +wg_packet_free(struct wg_packet *pkt) +{ + if (pkt->p_keypair != NULL) + noise_keypair_put(pkt->p_keypair); + if (pkt->p_mbuf != NULL) + m_freem(pkt->p_mbuf); + uma_zfree(wg_packet_zone, pkt); +} + +static 
void +wg_queue_init(struct wg_queue *queue, const char *name) +{ + mtx_init(&queue->q_mtx, name, NULL, MTX_DEF); + STAILQ_INIT(&queue->q_queue); + queue->q_len = 0; +} + +static void +wg_queue_deinit(struct wg_queue *queue) +{ + wg_queue_purge(queue); + mtx_destroy(&queue->q_mtx); +} + +static size_t +wg_queue_len(struct wg_queue *queue) +{ + return (queue->q_len); +} + +static int +wg_queue_enqueue_handshake(struct wg_queue *hs, struct wg_packet *pkt) +{ + int ret = 0; + mtx_lock(&hs->q_mtx); + if (hs->q_len < MAX_QUEUED_HANDSHAKES) { + STAILQ_INSERT_TAIL(&hs->q_queue, pkt, p_parallel); + hs->q_len++; + } else { + ret = ENOBUFS; + } + mtx_unlock(&hs->q_mtx); + if (ret != 0) + wg_packet_free(pkt); + return (ret); +} + +static struct wg_packet * +wg_queue_dequeue_handshake(struct wg_queue *hs) +{ + struct wg_packet *pkt; + mtx_lock(&hs->q_mtx); + if ((pkt = STAILQ_FIRST(&hs->q_queue)) != NULL) { + STAILQ_REMOVE_HEAD(&hs->q_queue, p_parallel); + hs->q_len--; + } + mtx_unlock(&hs->q_mtx); + return (pkt); +} + +static void +wg_queue_push_staged(struct wg_queue *staged, struct wg_packet *pkt) +{ + struct wg_packet *old = NULL; + + mtx_lock(&staged->q_mtx); + if (staged->q_len >= MAX_STAGED_PKT) { + old = STAILQ_FIRST(&staged->q_queue); + STAILQ_REMOVE_HEAD(&staged->q_queue, p_parallel); + staged->q_len--; + } + STAILQ_INSERT_TAIL(&staged->q_queue, pkt, p_parallel); + staged->q_len++; + mtx_unlock(&staged->q_mtx); + + if (old != NULL) + wg_packet_free(old); +} + +static void +wg_queue_enlist_staged(struct wg_queue *staged, struct wg_packet_list *list) +{ + struct wg_packet *pkt, *tpkt; + STAILQ_FOREACH_SAFE(pkt, list, p_parallel, tpkt) + wg_queue_push_staged(staged, pkt); +} + +static void +wg_queue_delist_staged(struct wg_queue *staged, struct wg_packet_list *list) +{ + STAILQ_INIT(list); + mtx_lock(&staged->q_mtx); + STAILQ_CONCAT(list, &staged->q_queue); + staged->q_len = 0; + mtx_unlock(&staged->q_mtx); +} + +static void +wg_queue_purge(struct wg_queue *staged) +{ + struct wg_packet_list list; + struct wg_packet *pkt, *tpkt; + wg_queue_delist_staged(staged, &list); + STAILQ_FOREACH_SAFE(pkt, &list, p_parallel, tpkt) + wg_packet_free(pkt); +} + +static int +wg_queue_both(struct wg_queue *parallel, struct wg_queue *serial, struct wg_packet *pkt) +{ + pkt->p_state = WG_PACKET_UNCRYPTED; + + mtx_lock(&serial->q_mtx); + if (serial->q_len < MAX_QUEUED_PKT) { + serial->q_len++; + STAILQ_INSERT_TAIL(&serial->q_queue, pkt, p_serial); + } else { + mtx_unlock(&serial->q_mtx); + wg_packet_free(pkt); + return (ENOBUFS); + } + mtx_unlock(&serial->q_mtx); + + mtx_lock(¶llel->q_mtx); + if (parallel->q_len < MAX_QUEUED_PKT) { + parallel->q_len++; + STAILQ_INSERT_TAIL(¶llel->q_queue, pkt, p_parallel); + } else { + mtx_unlock(¶llel->q_mtx); + pkt->p_state = WG_PACKET_DEAD; + return (ENOBUFS); + } + mtx_unlock(¶llel->q_mtx); + + return (0); +} + +static struct wg_packet * +wg_queue_dequeue_serial(struct wg_queue *serial) +{ + struct wg_packet *pkt = NULL; + mtx_lock(&serial->q_mtx); + if (serial->q_len > 0 && STAILQ_FIRST(&serial->q_queue)->p_state != WG_PACKET_UNCRYPTED) { + serial->q_len--; + pkt = STAILQ_FIRST(&serial->q_queue); + STAILQ_REMOVE_HEAD(&serial->q_queue, p_serial); + } + mtx_unlock(&serial->q_mtx); + return (pkt); +} + +static struct wg_packet * +wg_queue_dequeue_parallel(struct wg_queue *parallel) +{ + struct wg_packet *pkt = NULL; + mtx_lock(¶llel->q_mtx); + if (parallel->q_len > 0) { + parallel->q_len--; + pkt = STAILQ_FIRST(¶llel->q_queue); + STAILQ_REMOVE_HEAD(¶llel->q_queue, p_parallel); + 
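+		/*
+		 * The packet remains linked on its peer's serial queue via
+		 * p_serial; only the parallel linkage is consumed here.
+		 */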
} + mtx_unlock(¶llel->q_mtx); + return (pkt); +} + +static bool +wg_input(struct mbuf *m, int offset, struct inpcb *inpcb, + const struct sockaddr *sa, void *_sc) +{ +#ifdef INET + const struct sockaddr_in *sin; +#endif +#ifdef INET6 + const struct sockaddr_in6 *sin6; +#endif + struct noise_remote *remote; + struct wg_pkt_data *data; + struct wg_packet *pkt; + struct wg_peer *peer; + struct wg_softc *sc = _sc; + struct mbuf *defragged; + + defragged = m_defrag(m, M_NOWAIT); + if (defragged) + m = defragged; + m = m_unshare(m, M_NOWAIT); + if (!m) { + if_inc_counter(sc->sc_ifp, IFCOUNTER_IQDROPS, 1); + return true; + } + + /* Caller provided us with `sa`, no need for this header. */ + m_adj(m, offset + sizeof(struct udphdr)); + + /* Pullup enough to read packet type */ + if ((m = m_pullup(m, sizeof(uint32_t))) == NULL) { + if_inc_counter(sc->sc_ifp, IFCOUNTER_IQDROPS, 1); + return true; + } + + if ((pkt = wg_packet_alloc(m)) == NULL) { + if_inc_counter(sc->sc_ifp, IFCOUNTER_IQDROPS, 1); + m_freem(m); + return true; + } + + /* Save send/recv address and port for later. */ + switch (sa->sa_family) { +#ifdef INET + case AF_INET: + sin = (const struct sockaddr_in *)sa; + pkt->p_endpoint.e_remote.r_sin = sin[0]; + pkt->p_endpoint.e_local.l_in = sin[1].sin_addr; + break; +#endif +#ifdef INET6 + case AF_INET6: + sin6 = (const struct sockaddr_in6 *)sa; + pkt->p_endpoint.e_remote.r_sin6 = sin6[0]; + pkt->p_endpoint.e_local.l_in6 = sin6[1].sin6_addr; + break; +#endif + default: + goto error; + } + + if ((m->m_pkthdr.len == sizeof(struct wg_pkt_initiation) && + *mtod(m, uint32_t *) == WG_PKT_INITIATION) || + (m->m_pkthdr.len == sizeof(struct wg_pkt_response) && + *mtod(m, uint32_t *) == WG_PKT_RESPONSE) || + (m->m_pkthdr.len == sizeof(struct wg_pkt_cookie) && + *mtod(m, uint32_t *) == WG_PKT_COOKIE)) { + + if (wg_queue_enqueue_handshake(&sc->sc_handshake_queue, pkt) != 0) { + if_inc_counter(sc->sc_ifp, IFCOUNTER_IQDROPS, 1); + DPRINTF(sc, "Dropping handshake packet\n"); + } + GROUPTASK_ENQUEUE(&sc->sc_handshake); + } else if (m->m_pkthdr.len >= sizeof(struct wg_pkt_data) + + NOISE_AUTHTAG_LEN && *mtod(m, uint32_t *) == WG_PKT_DATA) { + + /* Pullup whole header to read r_idx below. 
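+		 * The receiver index in the data header selects the keypair,
+		 * and thereby the peer, that this packet belongs to.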
*/ + if ((pkt->p_mbuf = m_pullup(m, sizeof(struct wg_pkt_data))) == NULL) + goto error; + + data = mtod(pkt->p_mbuf, struct wg_pkt_data *); + if ((pkt->p_keypair = noise_keypair_lookup(sc->sc_local, data->r_idx)) == NULL) + goto error; + + remote = noise_keypair_remote(pkt->p_keypair); + peer = noise_remote_arg(remote); + if (wg_queue_both(&sc->sc_decrypt_parallel, &peer->p_decrypt_serial, pkt) != 0) + if_inc_counter(sc->sc_ifp, IFCOUNTER_IQDROPS, 1); + wg_decrypt_dispatch(sc); + noise_remote_put(remote); + } else { + goto error; + } + return true; +error: + if_inc_counter(sc->sc_ifp, IFCOUNTER_IERRORS, 1); + wg_packet_free(pkt); + return true; +} + +static void +wg_peer_send_staged(struct wg_peer *peer) +{ + struct wg_packet_list list; + struct noise_keypair *keypair; + struct wg_packet *pkt, *tpkt; + struct wg_softc *sc = peer->p_sc; + + wg_queue_delist_staged(&peer->p_stage_queue, &list); + + if (STAILQ_EMPTY(&list)) + return; + + if ((keypair = noise_keypair_current(peer->p_remote)) == NULL) + goto error; + + STAILQ_FOREACH(pkt, &list, p_parallel) { + if (noise_keypair_nonce_next(keypair, &pkt->p_nonce) != 0) + goto error_keypair; + } + STAILQ_FOREACH_SAFE(pkt, &list, p_parallel, tpkt) { + pkt->p_keypair = noise_keypair_ref(keypair); + if (wg_queue_both(&sc->sc_encrypt_parallel, &peer->p_encrypt_serial, pkt) != 0) + if_inc_counter(sc->sc_ifp, IFCOUNTER_OQDROPS, 1); + } + wg_encrypt_dispatch(sc); + noise_keypair_put(keypair); + return; + +error_keypair: + noise_keypair_put(keypair); +error: + wg_queue_enlist_staged(&peer->p_stage_queue, &list); + wg_timers_event_want_initiation(peer); +} + +static inline void +xmit_err(struct ifnet *ifp, struct mbuf *m, struct wg_packet *pkt, sa_family_t af) +{ + if_inc_counter(ifp, IFCOUNTER_OERRORS, 1); + switch (af) { +#ifdef INET + case AF_INET: + icmp_error(m, ICMP_UNREACH, ICMP_UNREACH_HOST, 0, 0); + if (pkt) + pkt->p_mbuf = NULL; + m = NULL; + break; +#endif +#ifdef INET6 + case AF_INET6: + icmp6_error(m, ICMP6_DST_UNREACH, 0, 0); + if (pkt) + pkt->p_mbuf = NULL; + m = NULL; + break; +#endif + } + if (pkt) + wg_packet_free(pkt); + else if (m) + m_freem(m); +} + +static int +wg_xmit(struct ifnet *ifp, struct mbuf *m, sa_family_t af, uint32_t mtu) +{ + struct wg_packet *pkt = NULL; + struct wg_softc *sc = ifp->if_softc; + struct wg_peer *peer; + int rc = 0; + sa_family_t peer_af; + + /* Work around lifetime issue in the ipv6 mld code. 
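+	 * The ifnet can outlive its softc, so bail out below if the
+	 * interface is dying or if_softc has already been cleared.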
*/ + if (__predict_false((ifp->if_flags & IFF_DYING) || !sc)) { + rc = ENXIO; + goto err_xmit; + } + + if ((pkt = wg_packet_alloc(m)) == NULL) { + rc = ENOBUFS; + goto err_xmit; + } + pkt->p_mtu = mtu; + pkt->p_af = af; + + if (af == AF_INET) { + peer = wg_aip_lookup(sc, AF_INET, &mtod(m, struct ip *)->ip_dst); + } else if (af == AF_INET6) { + peer = wg_aip_lookup(sc, AF_INET6, &mtod(m, struct ip6_hdr *)->ip6_dst); + } else { + rc = EAFNOSUPPORT; + goto err_xmit; + } + + BPF_MTAP2_AF(ifp, m, pkt->p_af); + + if (__predict_false(peer == NULL)) { + rc = ENOKEY; + goto err_xmit; + } + + if (__predict_false(if_tunnel_check_nesting(ifp, m, MTAG_WGLOOP, MAX_LOOPS))) { + DPRINTF(sc, "Packet looped"); + rc = ELOOP; + goto err_peer; + } + + peer_af = peer->p_endpoint.e_remote.r_sa.sa_family; + if (__predict_false(peer_af != AF_INET && peer_af != AF_INET6)) { + DPRINTF(sc, "No valid endpoint has been configured or " + "discovered for peer %" PRIu64 "\n", peer->p_id); + rc = EHOSTUNREACH; + goto err_peer; + } + + wg_queue_push_staged(&peer->p_stage_queue, pkt); + wg_peer_send_staged(peer); + noise_remote_put(peer->p_remote); + return (0); + +err_peer: + noise_remote_put(peer->p_remote); +err_xmit: + xmit_err(ifp, m, pkt, af); + return (rc); +} + +static inline int +determine_af_and_pullup(struct mbuf **m, sa_family_t *af) +{ + u_char ipv; + if ((*m)->m_pkthdr.len >= sizeof(struct ip6_hdr)) + *m = m_pullup(*m, sizeof(struct ip6_hdr)); + else if ((*m)->m_pkthdr.len >= sizeof(struct ip)) + *m = m_pullup(*m, sizeof(struct ip)); + else + return (EAFNOSUPPORT); + if (*m == NULL) + return (ENOBUFS); + ipv = mtod(*m, struct ip *)->ip_v; + if (ipv == 4) + *af = AF_INET; + else if (ipv == 6 && (*m)->m_pkthdr.len >= sizeof(struct ip6_hdr)) + *af = AF_INET6; + else + return (EAFNOSUPPORT); + return (0); +} + +static int +wg_transmit(struct ifnet *ifp, struct mbuf *m) +{ + sa_family_t af; + int ret; + struct mbuf *defragged; + + defragged = m_defrag(m, M_NOWAIT); + if (defragged) + m = defragged; + m = m_unshare(m, M_NOWAIT); + if (!m) { + xmit_err(ifp, m, NULL, AF_UNSPEC); + return (ENOBUFS); + } + + ret = determine_af_and_pullup(&m, &af); + if (ret) { + xmit_err(ifp, m, NULL, AF_UNSPEC); + return (ret); + } + return (wg_xmit(ifp, m, af, ifp->if_mtu)); +} + +static int +wg_output(struct ifnet *ifp, struct mbuf *m, const struct sockaddr *dst, struct route *ro) +{ + sa_family_t parsed_af; + uint32_t af, mtu; + int ret; + struct mbuf *defragged; + + if (dst->sa_family == AF_UNSPEC) + memcpy(&af, dst->sa_data, sizeof(af)); + else + af = dst->sa_family; + if (af == AF_UNSPEC) { + xmit_err(ifp, m, NULL, af); + return (EAFNOSUPPORT); + } + + defragged = m_defrag(m, M_NOWAIT); + if (defragged) + m = defragged; + m = m_unshare(m, M_NOWAIT); + if (!m) { + xmit_err(ifp, m, NULL, AF_UNSPEC); + return (ENOBUFS); + } + + ret = determine_af_and_pullup(&m, &parsed_af); + if (ret) { + xmit_err(ifp, m, NULL, AF_UNSPEC); + return (ret); + } + if (parsed_af != af) { + xmit_err(ifp, m, NULL, AF_UNSPEC); + return (EAFNOSUPPORT); + } + mtu = (ro != NULL && ro->ro_mtu > 0) ? 
ro->ro_mtu : ifp->if_mtu; + return (wg_xmit(ifp, m, parsed_af, mtu)); +} + +static int +wg_peer_add(struct wg_softc *sc, const nvlist_t *nvl) +{ + uint8_t public[WG_KEY_SIZE]; + const void *pub_key, *preshared_key = NULL; + const struct sockaddr *endpoint; + int err; + size_t size; + struct noise_remote *remote; + struct wg_peer *peer = NULL; + bool need_insert = false; + + sx_assert(&sc->sc_lock, SX_XLOCKED); + + if (!nvlist_exists_binary(nvl, "public-key")) { + return (EINVAL); + } + pub_key = nvlist_get_binary(nvl, "public-key", &size); + if (size != WG_KEY_SIZE) { + return (EINVAL); + } + if (noise_local_keys(sc->sc_local, public, NULL) == 0 && + bcmp(public, pub_key, WG_KEY_SIZE) == 0) { + return (0); // Silently ignored; not actually a failure. + } + if ((remote = noise_remote_lookup(sc->sc_local, pub_key)) != NULL) + peer = noise_remote_arg(remote); + if (nvlist_exists_bool(nvl, "remove") && + nvlist_get_bool(nvl, "remove")) { + if (remote != NULL) { + wg_peer_destroy(peer); + noise_remote_put(remote); + } + return (0); + } + if (nvlist_exists_bool(nvl, "replace-allowedips") && + nvlist_get_bool(nvl, "replace-allowedips") && + peer != NULL) { + + wg_aip_remove_all(sc, peer); + } + if (peer == NULL) { + peer = wg_peer_alloc(sc, pub_key); + need_insert = true; + } + if (nvlist_exists_binary(nvl, "endpoint")) { + endpoint = nvlist_get_binary(nvl, "endpoint", &size); + if (size > sizeof(peer->p_endpoint.e_remote)) { + err = EINVAL; + goto out; + } + memcpy(&peer->p_endpoint.e_remote, endpoint, size); + } + if (nvlist_exists_binary(nvl, "preshared-key")) { + preshared_key = nvlist_get_binary(nvl, "preshared-key", &size); + if (size != WG_KEY_SIZE) { + err = EINVAL; + goto out; + } + noise_remote_set_psk(peer->p_remote, preshared_key); + } + if (nvlist_exists_number(nvl, "persistent-keepalive-interval")) { + uint64_t pki = nvlist_get_number(nvl, "persistent-keepalive-interval"); + if (pki > UINT16_MAX) { + err = EINVAL; + goto out; + } + wg_timers_set_persistent_keepalive(peer, pki); + } + if (nvlist_exists_nvlist_array(nvl, "allowed-ips")) { + const void *addr; + uint64_t cidr; + const nvlist_t * const * aipl; + size_t allowedip_count; + + aipl = nvlist_get_nvlist_array(nvl, "allowed-ips", &allowedip_count); + for (size_t idx = 0; idx < allowedip_count; idx++) { + if (!nvlist_exists_number(aipl[idx], "cidr")) + continue; + cidr = nvlist_get_number(aipl[idx], "cidr"); + if (nvlist_exists_binary(aipl[idx], "ipv4")) { + addr = nvlist_get_binary(aipl[idx], "ipv4", &size); + if (addr == NULL || cidr > 32 || size != sizeof(struct in_addr)) { + err = EINVAL; + goto out; + } + if ((err = wg_aip_add(sc, peer, AF_INET, addr, cidr)) != 0) + goto out; + } else if (nvlist_exists_binary(aipl[idx], "ipv6")) { + addr = nvlist_get_binary(aipl[idx], "ipv6", &size); + if (addr == NULL || cidr > 128 || size != sizeof(struct in6_addr)) { + err = EINVAL; + goto out; + } + if ((err = wg_aip_add(sc, peer, AF_INET6, addr, cidr)) != 0) + goto out; + } else { + continue; + } + } + } + if (need_insert) { + if ((err = noise_remote_enable(peer->p_remote)) != 0) + goto out; + TAILQ_INSERT_TAIL(&sc->sc_peers, peer, p_entry); + sc->sc_peers_num++; + if (sc->sc_ifp->if_link_state == LINK_STATE_UP) + wg_timers_enable(peer); + } + if (remote != NULL) + noise_remote_put(remote); + return (0); +out: + if (need_insert) /* If we fail, only destroy if it was new. 
*/ + wg_peer_destroy(peer); + if (remote != NULL) + noise_remote_put(remote); + return (err); +} + +static int +wgc_set(struct wg_softc *sc, struct wg_data_io *wgd) +{ + uint8_t public[WG_KEY_SIZE], private[WG_KEY_SIZE]; + struct ifnet *ifp; + void *nvlpacked; + nvlist_t *nvl; + ssize_t size; + int err; + + ifp = sc->sc_ifp; + if (wgd->wgd_size == 0 || wgd->wgd_data == NULL) + return (EFAULT); + + /* Can nvlists be streamed in? It's not nice to impose arbitrary limits like that but + * there needs to be _some_ limitation. */ + if (wgd->wgd_size >= UINT32_MAX / 2) + return (E2BIG); + + nvlpacked = malloc(wgd->wgd_size, M_TEMP, M_WAITOK | M_ZERO); + + err = copyin(wgd->wgd_data, nvlpacked, wgd->wgd_size); + if (err) + goto out; + nvl = nvlist_unpack(nvlpacked, wgd->wgd_size, 0); + if (nvl == NULL) { + err = EBADMSG; + goto out; + } + sx_xlock(&sc->sc_lock); + if (nvlist_exists_bool(nvl, "replace-peers") && + nvlist_get_bool(nvl, "replace-peers")) + wg_peer_destroy_all(sc); + if (nvlist_exists_number(nvl, "listen-port")) { + uint64_t new_port = nvlist_get_number(nvl, "listen-port"); + if (new_port > UINT16_MAX) { + err = EINVAL; + goto out_locked; + } + if (new_port != sc->sc_socket.so_port) { + if ((ifp->if_drv_flags & IFF_DRV_RUNNING) != 0) { + if ((err = wg_socket_init(sc, new_port)) != 0) + goto out_locked; + } else + sc->sc_socket.so_port = new_port; + } + } + if (nvlist_exists_binary(nvl, "private-key")) { + const void *key = nvlist_get_binary(nvl, "private-key", &size); + if (size != WG_KEY_SIZE) { + err = EINVAL; + goto out_locked; + } + + if (noise_local_keys(sc->sc_local, NULL, private) != 0 || + timingsafe_bcmp(private, key, WG_KEY_SIZE) != 0) { + struct wg_peer *peer; + + if (curve25519_generate_public(public, key)) { + /* Peer conflict: remove conflicting peer. */ + struct noise_remote *remote; + if ((remote = noise_remote_lookup(sc->sc_local, + public)) != NULL) { + peer = noise_remote_arg(remote); + wg_peer_destroy(peer); + noise_remote_put(remote); + } + } + + /* + * Set the private key and invalidate all existing + * handshakes. + */ + /* Note: we might be removing the private key. 
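+			 * If the key was cleared, noise_local_keys() fails
+			 * below and the cookie checker is reset with a NULL
+			 * key.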
*/ + noise_local_private(sc->sc_local, key); + if (noise_local_keys(sc->sc_local, NULL, NULL) == 0) + cookie_checker_update(&sc->sc_cookie, public); + else + cookie_checker_update(&sc->sc_cookie, NULL); + } + } + if (nvlist_exists_number(nvl, "user-cookie")) { + uint64_t user_cookie = nvlist_get_number(nvl, "user-cookie"); + if (user_cookie > UINT32_MAX) { + err = EINVAL; + goto out_locked; + } + err = wg_socket_set_cookie(sc, user_cookie); + if (err) + goto out_locked; + } + if (nvlist_exists_nvlist_array(nvl, "peers")) { + size_t peercount; + const nvlist_t * const*nvl_peers; + + nvl_peers = nvlist_get_nvlist_array(nvl, "peers", &peercount); + for (int i = 0; i < peercount; i++) { + err = wg_peer_add(sc, nvl_peers[i]); + if (err != 0) + goto out_locked; + } + } + +out_locked: + sx_xunlock(&sc->sc_lock); + nvlist_destroy(nvl); +out: + explicit_bzero(nvlpacked, wgd->wgd_size); + free(nvlpacked, M_TEMP); + return (err); +} + +static int +wgc_get(struct wg_softc *sc, struct wg_data_io *wgd) +{ + uint8_t public_key[WG_KEY_SIZE] = { 0 }; + uint8_t private_key[WG_KEY_SIZE] = { 0 }; + uint8_t preshared_key[NOISE_SYMMETRIC_KEY_LEN] = { 0 }; + nvlist_t *nvl, *nvl_peer, *nvl_aip, **nvl_peers, **nvl_aips; + size_t size, peer_count, aip_count, i, j; + struct wg_timespec64 ts64; + struct wg_peer *peer; + struct wg_aip *aip; + void *packed; + int err = 0; + + nvl = nvlist_create(0); + if (!nvl) + return (ENOMEM); + + sx_slock(&sc->sc_lock); + + if (sc->sc_socket.so_port != 0) + nvlist_add_number(nvl, "listen-port", sc->sc_socket.so_port); + if (sc->sc_socket.so_user_cookie != 0) + nvlist_add_number(nvl, "user-cookie", sc->sc_socket.so_user_cookie); + if (noise_local_keys(sc->sc_local, public_key, private_key) == 0) { + nvlist_add_binary(nvl, "public-key", public_key, WG_KEY_SIZE); + if (wgc_privileged(sc)) + nvlist_add_binary(nvl, "private-key", private_key, WG_KEY_SIZE); + explicit_bzero(private_key, sizeof(private_key)); + } + peer_count = sc->sc_peers_num; + if (peer_count) { + nvl_peers = mallocarray(peer_count, sizeof(void *), M_NVLIST, M_WAITOK | M_ZERO); + i = 0; + TAILQ_FOREACH(peer, &sc->sc_peers, p_entry) { + if (i >= peer_count) + panic("peers changed from under us"); + + nvl_peers[i++] = nvl_peer = nvlist_create(0); + if (!nvl_peer) { + err = ENOMEM; + goto err_peer; + } + + (void)noise_remote_keys(peer->p_remote, public_key, preshared_key); + nvlist_add_binary(nvl_peer, "public-key", public_key, sizeof(public_key)); + if (wgc_privileged(sc)) + nvlist_add_binary(nvl_peer, "preshared-key", preshared_key, sizeof(preshared_key)); + explicit_bzero(preshared_key, sizeof(preshared_key)); + if (peer->p_endpoint.e_remote.r_sa.sa_family == AF_INET) + nvlist_add_binary(nvl_peer, "endpoint", &peer->p_endpoint.e_remote, sizeof(struct sockaddr_in)); + else if (peer->p_endpoint.e_remote.r_sa.sa_family == AF_INET6) + nvlist_add_binary(nvl_peer, "endpoint", &peer->p_endpoint.e_remote, sizeof(struct sockaddr_in6)); + wg_timers_get_last_handshake(peer, &ts64); + nvlist_add_binary(nvl_peer, "last-handshake-time", &ts64, sizeof(ts64)); + nvlist_add_number(nvl_peer, "persistent-keepalive-interval", peer->p_persistent_keepalive_interval); + nvlist_add_number(nvl_peer, "rx-bytes", counter_u64_fetch(peer->p_rx_bytes)); + nvlist_add_number(nvl_peer, "tx-bytes", counter_u64_fetch(peer->p_tx_bytes)); + + aip_count = peer->p_aips_num; + if (aip_count) { + nvl_aips = mallocarray(aip_count, sizeof(void *), M_NVLIST, M_WAITOK | M_ZERO); + j = 0; + LIST_FOREACH(aip, &peer->p_aips, a_entry) { + if (j >= aip_count) + 
panic("aips changed from under us"); + + nvl_aips[j++] = nvl_aip = nvlist_create(0); + if (!nvl_aip) { + err = ENOMEM; + goto err_aip; + } + if (aip->a_af == AF_INET) { + nvlist_add_binary(nvl_aip, "ipv4", &aip->a_addr.in, sizeof(aip->a_addr.in)); + nvlist_add_number(nvl_aip, "cidr", bitcount32(aip->a_mask.ip)); + } +#ifdef INET6 + else if (aip->a_af == AF_INET6) { + nvlist_add_binary(nvl_aip, "ipv6", &aip->a_addr.in6, sizeof(aip->a_addr.in6)); + nvlist_add_number(nvl_aip, "cidr", in6_mask2len(&aip->a_mask.in6, NULL)); + } +#endif + } + nvlist_add_nvlist_array(nvl_peer, "allowed-ips", (const nvlist_t *const *)nvl_aips, aip_count); + err_aip: + for (j = 0; j < aip_count; ++j) + nvlist_destroy(nvl_aips[j]); + free(nvl_aips, M_NVLIST); + if (err) + goto err_peer; + } + } + nvlist_add_nvlist_array(nvl, "peers", (const nvlist_t * const *)nvl_peers, peer_count); + err_peer: + for (i = 0; i < peer_count; ++i) + nvlist_destroy(nvl_peers[i]); + free(nvl_peers, M_NVLIST); + if (err) { + sx_sunlock(&sc->sc_lock); + goto err; + } + } + sx_sunlock(&sc->sc_lock); + packed = nvlist_pack(nvl, &size); + if (!packed) { + err = ENOMEM; + goto err; + } + if (!wgd->wgd_size) { + wgd->wgd_size = size; + goto out; + } + if (wgd->wgd_size < size) { + err = ENOSPC; + goto out; + } + err = copyout(packed, wgd->wgd_data, size); + wgd->wgd_size = size; + +out: + explicit_bzero(packed, size); + free(packed, M_NVLIST); +err: + nvlist_destroy(nvl); + return (err); +} + +static int +wg_ioctl(struct ifnet *ifp, u_long cmd, caddr_t data) +{ + struct wg_data_io *wgd = (struct wg_data_io *)data; + struct ifreq *ifr = (struct ifreq *)data; + struct wg_softc *sc; + int ret = 0; + + sx_slock(&wg_sx); + sc = ifp->if_softc; + if (!sc) { + ret = ENXIO; + goto out; + } + + switch (cmd) { + case SIOCSWG: + ret = priv_check(curthread, PRIV_NET_WG); + if (ret == 0) + ret = wgc_set(sc, wgd); + break; + case SIOCGWG: + ret = wgc_get(sc, wgd); + break; + /* Interface IOCTLs */ + case SIOCSIFADDR: + /* + * This differs from *BSD norms, but is more uniform with how + * WireGuard behaves elsewhere. + */ + break; + case SIOCSIFFLAGS: + if (ifp->if_flags & IFF_UP) + ret = wg_up(sc); + else + wg_down(sc); + break; + case SIOCSIFMTU: + if (ifr->ifr_mtu <= 0 || ifr->ifr_mtu > MAX_MTU) + ret = EINVAL; + else + ifp->if_mtu = ifr->ifr_mtu; + break; + case SIOCADDMULTI: + case SIOCDELMULTI: + break; + case SIOCGTUNFIB: + ifr->ifr_fib = sc->sc_socket.so_fibnum; + break; + case SIOCSTUNFIB: + ret = priv_check(curthread, PRIV_NET_WG); + if (ret) + break; + ret = priv_check(curthread, PRIV_NET_SETIFFIB); + if (ret) + break; + sx_xlock(&sc->sc_lock); + ret = wg_socket_set_fibnum(sc, ifr->ifr_fib); + sx_xunlock(&sc->sc_lock); + break; + default: + ret = ENOTTY; + } + +out: + sx_sunlock(&wg_sx); + return (ret); +} + +static int +wg_up(struct wg_softc *sc) +{ + struct ifnet *ifp = sc->sc_ifp; + struct wg_peer *peer; + int rc = EBUSY; + + sx_xlock(&sc->sc_lock); + /* Jail's being removed, no more wg_up(). */ + if ((sc->sc_flags & WGF_DYING) != 0) + goto out; + + /* Silent success if we're already running. 
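+	 * This keeps SIOCSIFFLAGS idempotent for an interface that is
+	 * already up.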
*/ + rc = 0; + if (ifp->if_drv_flags & IFF_DRV_RUNNING) + goto out; + ifp->if_drv_flags |= IFF_DRV_RUNNING; + + rc = wg_socket_init(sc, sc->sc_socket.so_port); + if (rc == 0) { + TAILQ_FOREACH(peer, &sc->sc_peers, p_entry) + wg_timers_enable(peer); + if_link_state_change(sc->sc_ifp, LINK_STATE_UP); + } else { + ifp->if_drv_flags &= ~IFF_DRV_RUNNING; + DPRINTF(sc, "Unable to initialize sockets: %d\n", rc); + } +out: + sx_xunlock(&sc->sc_lock); + return (rc); +} + +static void +wg_down(struct wg_softc *sc) +{ + struct ifnet *ifp = sc->sc_ifp; + struct wg_peer *peer; + + sx_xlock(&sc->sc_lock); + if (!(ifp->if_drv_flags & IFF_DRV_RUNNING)) { + sx_xunlock(&sc->sc_lock); + return; + } + ifp->if_drv_flags &= ~IFF_DRV_RUNNING; + + TAILQ_FOREACH(peer, &sc->sc_peers, p_entry) { + wg_queue_purge(&peer->p_stage_queue); + wg_timers_disable(peer); + } + + wg_queue_purge(&sc->sc_handshake_queue); + + TAILQ_FOREACH(peer, &sc->sc_peers, p_entry) { + noise_remote_handshake_clear(peer->p_remote); + noise_remote_keypairs_clear(peer->p_remote); + } + + if_link_state_change(sc->sc_ifp, LINK_STATE_DOWN); + wg_socket_uninit(sc); + + sx_xunlock(&sc->sc_lock); +} + +static int +wg_clone_create(struct if_clone *ifc, int unit, caddr_t params) +{ + struct wg_softc *sc; + struct ifnet *ifp; + + sc = malloc(sizeof(*sc), M_WG, M_WAITOK | M_ZERO); + + sc->sc_local = noise_local_alloc(sc); + + sc->sc_encrypt = mallocarray(sizeof(struct grouptask), mp_ncpus, M_WG, M_WAITOK | M_ZERO); + + sc->sc_decrypt = mallocarray(sizeof(struct grouptask), mp_ncpus, M_WG, M_WAITOK | M_ZERO); + + if (!rn_inithead((void **)&sc->sc_aip4, offsetof(struct aip_addr, in) * NBBY)) + goto free_decrypt; + + if (!rn_inithead((void **)&sc->sc_aip6, offsetof(struct aip_addr, in6) * NBBY)) + goto free_aip4; + + atomic_add_int(&clone_count, 1); + ifp = sc->sc_ifp = if_alloc(IFT_WIREGUARD); + + sc->sc_ucred = crhold(curthread->td_ucred); + sc->sc_socket.so_fibnum = curthread->td_proc->p_fibnum; + sc->sc_socket.so_port = 0; + + TAILQ_INIT(&sc->sc_peers); + sc->sc_peers_num = 0; + + cookie_checker_init(&sc->sc_cookie); + + RADIX_NODE_HEAD_LOCK_INIT(sc->sc_aip4); + RADIX_NODE_HEAD_LOCK_INIT(sc->sc_aip6); + + GROUPTASK_INIT(&sc->sc_handshake, 0, (gtask_fn_t *)wg_softc_handshake_receive, sc); + taskqgroup_attach(qgroup_wg_tqg, &sc->sc_handshake, sc, NULL, NULL, "wg tx initiation"); + wg_queue_init(&sc->sc_handshake_queue, "hsq"); + + for (int i = 0; i < mp_ncpus; i++) { + GROUPTASK_INIT(&sc->sc_encrypt[i], 0, + (gtask_fn_t *)wg_softc_encrypt, sc); + taskqgroup_attach_cpu(qgroup_wg_tqg, &sc->sc_encrypt[i], sc, i, NULL, NULL, "wg encrypt"); + GROUPTASK_INIT(&sc->sc_decrypt[i], 0, + (gtask_fn_t *)wg_softc_decrypt, sc); + taskqgroup_attach_cpu(qgroup_wg_tqg, &sc->sc_decrypt[i], sc, i, NULL, NULL, "wg decrypt"); + } + + wg_queue_init(&sc->sc_encrypt_parallel, "encp"); + wg_queue_init(&sc->sc_decrypt_parallel, "decp"); + + sx_init(&sc->sc_lock, "wg softc lock"); + + ifp->if_softc = sc; + ifp->if_capabilities = ifp->if_capenable = WG_CAPS; + if_initname(ifp, wgname, unit); + + if_setmtu(ifp, DEFAULT_MTU); + ifp->if_flags = IFF_NOARP | IFF_MULTICAST; + ifp->if_init = wg_init; + ifp->if_reassign = wg_reassign; + ifp->if_qflush = wg_qflush; + ifp->if_transmit = wg_transmit; + ifp->if_output = wg_output; + ifp->if_ioctl = wg_ioctl; + if_attach(ifp); + bpfattach(ifp, DLT_NULL, sizeof(uint32_t)); +#ifdef INET6 + ND_IFINFO(ifp)->flags &= ~ND6_IFF_AUTO_LINKLOCAL; + ND_IFINFO(ifp)->flags |= ND6_IFF_NO_DAD; +#endif + sx_xlock(&wg_sx); + LIST_INSERT_HEAD(&wg_list, sc, 
sc_entry); + sx_xunlock(&wg_sx); + return (0); +free_aip4: + RADIX_NODE_HEAD_DESTROY(sc->sc_aip4); + free(sc->sc_aip4, M_RTABLE); +free_decrypt: + free(sc->sc_decrypt, M_WG); + free(sc->sc_encrypt, M_WG); + noise_local_free(sc->sc_local, NULL); + free(sc, M_WG); + return (ENOMEM); +} + +static void +wg_clone_deferred_free(struct noise_local *l) +{ + struct wg_softc *sc = noise_local_arg(l); + + free(sc, M_WG); + atomic_add_int(&clone_count, -1); +} + +static void +wg_clone_destroy(struct ifnet *ifp) +{ + struct wg_softc *sc = ifp->if_softc; + struct ucred *cred; + + sx_xlock(&wg_sx); + ifp->if_softc = NULL; + sx_xlock(&sc->sc_lock); + sc->sc_flags |= WGF_DYING; + cred = sc->sc_ucred; + sc->sc_ucred = NULL; + sx_xunlock(&sc->sc_lock); + LIST_REMOVE(sc, sc_entry); + sx_xunlock(&wg_sx); + + if_link_state_change(sc->sc_ifp, LINK_STATE_DOWN); + CURVNET_SET(sc->sc_ifp->if_vnet); + if_purgeaddrs(sc->sc_ifp); + CURVNET_RESTORE(); + + sx_xlock(&sc->sc_lock); + wg_socket_uninit(sc); + sx_xunlock(&sc->sc_lock); + + /* + * No guarantees that all traffic have passed until the epoch has + * elapsed with the socket closed. + */ + NET_EPOCH_WAIT(); + + taskqgroup_drain_all(qgroup_wg_tqg); + sx_xlock(&sc->sc_lock); + wg_peer_destroy_all(sc); + epoch_drain_callbacks(net_epoch_preempt); + sx_xunlock(&sc->sc_lock); + sx_destroy(&sc->sc_lock); + taskqgroup_detach(qgroup_wg_tqg, &sc->sc_handshake); + for (int i = 0; i < mp_ncpus; i++) { + taskqgroup_detach(qgroup_wg_tqg, &sc->sc_encrypt[i]); + taskqgroup_detach(qgroup_wg_tqg, &sc->sc_decrypt[i]); + } + free(sc->sc_encrypt, M_WG); + free(sc->sc_decrypt, M_WG); + wg_queue_deinit(&sc->sc_handshake_queue); + wg_queue_deinit(&sc->sc_encrypt_parallel); + wg_queue_deinit(&sc->sc_decrypt_parallel); + + RADIX_NODE_HEAD_DESTROY(sc->sc_aip4); + RADIX_NODE_HEAD_DESTROY(sc->sc_aip6); + rn_detachhead((void **)&sc->sc_aip4); + rn_detachhead((void **)&sc->sc_aip6); + + cookie_checker_free(&sc->sc_cookie); + + if (cred != NULL) + crfree(cred); + if_detach(sc->sc_ifp); + if_free(sc->sc_ifp); + + noise_local_free(sc->sc_local, wg_clone_deferred_free); +} + +static void +wg_qflush(struct ifnet *ifp __unused) +{ +} + +/* + * Privileged information (private-key, preshared-key) are only exported for + * root and jailed root by default. + */ +static bool +wgc_privileged(struct wg_softc *sc) +{ + struct thread *td; + + td = curthread; + return (priv_check(td, PRIV_NET_WG) == 0); +} + +static void +wg_reassign(struct ifnet *ifp, struct vnet *new_vnet __unused, + char *unused __unused) +{ + struct wg_softc *sc; + + sc = ifp->if_softc; + wg_down(sc); +} + +static void +wg_init(void *xsc) +{ + struct wg_softc *sc; + + sc = xsc; + wg_up(sc); +} + +static void +vnet_wg_init(const void *unused __unused) +{ + V_wg_cloner = if_clone_simple(wgname, wg_clone_create, wg_clone_destroy, + 0); +} +VNET_SYSINIT(vnet_wg_init, SI_SUB_PROTO_IFATTACHDOMAIN, SI_ORDER_ANY, + vnet_wg_init, NULL); + +static void +vnet_wg_uninit(const void *unused __unused) +{ + if (V_wg_cloner) + if_clone_detach(V_wg_cloner); +} +VNET_SYSUNINIT(vnet_wg_uninit, SI_SUB_PROTO_IFATTACHDOMAIN, SI_ORDER_ANY, + vnet_wg_uninit, NULL); + +static int +wg_prison_remove(void *obj, void *data __unused) +{ + const struct prison *pr = obj; + struct wg_softc *sc; + + /* + * Do a pass through all if_wg interfaces and release creds on any from + * the jail that are supposed to be going away. This will, in turn, let + * the jail die so that we don't end up with Schrödinger's jail. 
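+	 * The cached ucred holds a reference on the prison, so letting go
+	 * of it here is what allows the dying jail to be fully removed.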
+	 */
+	sx_slock(&wg_sx);
+	LIST_FOREACH(sc, &wg_list, sc_entry) {
+		sx_xlock(&sc->sc_lock);
+		if (!(sc->sc_flags & WGF_DYING) && sc->sc_ucred && sc->sc_ucred->cr_prison == pr) {
+			struct ucred *cred = sc->sc_ucred;
+			DPRINTF(sc, "Creating jail is exiting\n");
+			if_link_state_change(sc->sc_ifp, LINK_STATE_DOWN);
+			wg_socket_uninit(sc);
+			sc->sc_ucred = NULL;
+			crfree(cred);
+			sc->sc_flags |= WGF_DYING;
+		}
+		sx_xunlock(&sc->sc_lock);
+	}
+	sx_sunlock(&wg_sx);
+
+	return (0);
+}
+
+#ifdef SELFTESTS
+#include "selftest/allowedips.c"
+static bool wg_run_selftests(void)
+{
+	bool ret = true;
+	ret &= wg_allowedips_selftest();
+	ret &= noise_counter_selftest();
+	ret &= cookie_selftest();
+	return ret;
+}
+#else
+static inline bool wg_run_selftests(void) { return true; }
+#endif
+
+static int
+wg_module_init(void)
+{
+	int ret = ENOMEM;
+
+	osd_method_t methods[PR_MAXMETHOD] = {
+		[PR_METHOD_REMOVE] = wg_prison_remove,
+	};
+
+	if ((wg_packet_zone = uma_zcreate("wg packet", sizeof(struct wg_packet),
+	    NULL, NULL, NULL, NULL, 0, 0)) == NULL)
+		goto free_none;
+	ret = crypto_init();
+	if (ret != 0)
+		goto free_zone;
+	if (cookie_init() != 0)
+		goto free_crypto;
+
+	wg_osd_jail_slot = osd_jail_register(NULL, methods);
+
+	ret = ENOTRECOVERABLE;
+	if (!wg_run_selftests())
+		goto free_all;
+
+	return (0);
+
+free_all:
+	osd_jail_deregister(wg_osd_jail_slot);
+	cookie_deinit();
+free_crypto:
+	crypto_deinit();
+free_zone:
+	uma_zdestroy(wg_packet_zone);
+free_none:
+	return (ret);
+}
+
+static void
+wg_module_deinit(void)
+{
+	VNET_ITERATOR_DECL(vnet_iter);
+	VNET_LIST_RLOCK();
+	VNET_FOREACH(vnet_iter) {
+		struct if_clone *clone = VNET_VNET(vnet_iter, wg_cloner);
+		if (clone) {
+			if_clone_detach(clone);
+			VNET_VNET(vnet_iter, wg_cloner) = NULL;
+		}
+	}
+	VNET_LIST_RUNLOCK();
+	NET_EPOCH_WAIT();
+	MPASS(LIST_EMPTY(&wg_list));
+	osd_jail_deregister(wg_osd_jail_slot);
+	cookie_deinit();
+	crypto_deinit();
+	uma_zdestroy(wg_packet_zone);
+}
+
+static int
+wg_module_event_handler(module_t mod, int what, void *arg)
+{
+	switch (what) {
+	case MOD_LOAD:
+		return wg_module_init();
+	case MOD_UNLOAD:
+		wg_module_deinit();
+		break;
+	default:
+		return (EOPNOTSUPP);
+	}
+	return (0);
+}
+
+static moduledata_t wg_moduledata = {
+	wgname,
+	wg_module_event_handler,
+	NULL
+};
+
+DECLARE_MODULE(wg, wg_moduledata, SI_SUB_PSEUDO, SI_ORDER_ANY);
+MODULE_VERSION(wg, WIREGUARD_VERSION);
+MODULE_DEPEND(wg, crypto, 1, 1, 1);
diff --git a/sys/dev/wg/if_wg.h b/sys/dev/wg/if_wg.h
new file mode 100644
index 000000000000..f137c931b5ce
--- /dev/null
+++ b/sys/dev/wg/if_wg.h
@@ -0,0 +1,37 @@
+/* SPDX-License-Identifier: ISC
+ *
+ * Copyright (c) 2019 Matt Dunwoodie <ncon@noconroy.net>
+ *
+ * Permission to use, copy, modify, and distribute this software for any
+ * purpose with or without fee is hereby granted, provided that the above
+ * copyright notice and this permission notice appear in all copies.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
+ * WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
+ * MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
+ * ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
+ * WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
+ * ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
+ * OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
+ * + * $FreeBSD$ + */ + +#ifndef __IF_WG_H__ +#define __IF_WG_H__ + +#include <net/if.h> +#include <netinet/in.h> + +struct wg_data_io { + char wgd_name[IFNAMSIZ]; + void *wgd_data; + size_t wgd_size; +}; + +#define WG_KEY_SIZE 32 + +#define SIOCSWG _IOWR('i', 210, struct wg_data_io) +#define SIOCGWG _IOWR('i', 211, struct wg_data_io) + +#endif /* __IF_WG_H__ */ diff --git a/sys/dev/wg/support.h b/sys/dev/wg/support.h new file mode 100644 index 000000000000..7934c5784a40 --- /dev/null +++ b/sys/dev/wg/support.h @@ -0,0 +1,21 @@ +/* SPDX-License-Identifier: ISC + * + * Copyright (C) 2021 Jason A. Donenfeld <Jason@zx2c4.com>. All Rights Reserved. + * Copyright (c) 2021 Kyle Evans <kevans@FreeBSD.org> + * + * support.h contains code that is not _yet_ upstream in FreeBSD's main branch. + * It is different from compat.h, which is strictly for backports. + */ + +#ifndef _WG_SUPPORT +#define _WG_SUPPORT + +#ifndef ck_pr_store_bool +#define ck_pr_store_bool(dst, val) ck_pr_store_8((uint8_t *)(dst), (uint8_t)(val)) +#endif + +#ifndef ck_pr_load_bool +#define ck_pr_load_bool(src) ((bool)ck_pr_load_8((uint8_t *)(src))) +#endif + +#endif diff --git a/sys/dev/wg/version.h b/sys/dev/wg/version.h new file mode 100644 index 000000000000..f1a1d7246832 --- /dev/null +++ b/sys/dev/wg/version.h @@ -0,0 +1 @@ +#define WIREGUARD_VERSION 20220615 diff --git a/sys/dev/wg/wg_cookie.c b/sys/dev/wg/wg_cookie.c new file mode 100644 index 000000000000..16fa5d7fb52d --- /dev/null +++ b/sys/dev/wg/wg_cookie.c @@ -0,0 +1,500 @@ +/* SPDX-License-Identifier: ISC + * + * Copyright (C) 2015-2021 Jason A. Donenfeld <Jason@zx2c4.com>. All Rights Reserved. + * Copyright (C) 2019-2021 Matt Dunwoodie <ncon@noconroy.net> + */ + +#include "opt_inet.h" +#include "opt_inet6.h" + +#include <sys/param.h> +#include <sys/systm.h> +#include <sys/kernel.h> +#include <sys/lock.h> +#include <sys/mutex.h> +#include <sys/rwlock.h> +#include <sys/socket.h> +#include <crypto/siphash/siphash.h> +#include <netinet/in.h> +#include <vm/uma.h> + +#include "wg_cookie.h" + +#define COOKIE_MAC1_KEY_LABEL "mac1----" +#define COOKIE_COOKIE_KEY_LABEL "cookie--" +#define COOKIE_SECRET_MAX_AGE 120 +#define COOKIE_SECRET_LATENCY 5 + +/* Constants for initiation rate limiting */ +#define RATELIMIT_SIZE (1 << 13) +#define RATELIMIT_MASK (RATELIMIT_SIZE - 1) +#define RATELIMIT_SIZE_MAX (RATELIMIT_SIZE * 8) +#define INITIATIONS_PER_SECOND 20 +#define INITIATIONS_BURSTABLE 5 +#define INITIATION_COST (SBT_1S / INITIATIONS_PER_SECOND) +#define TOKEN_MAX (INITIATION_COST * INITIATIONS_BURSTABLE) +#define ELEMENT_TIMEOUT 1 +#define IPV4_MASK_SIZE 4 /* Use all 4 bytes of IPv4 address */ +#define IPV6_MASK_SIZE 8 /* Use top 8 bytes (/64) of IPv6 address */ + +struct ratelimit_key { + struct vnet *vnet; + uint8_t ip[IPV6_MASK_SIZE]; +}; + +struct ratelimit_entry { + LIST_ENTRY(ratelimit_entry) r_entry; + struct ratelimit_key r_key; + sbintime_t r_last_time; /* sbinuptime */ + uint64_t r_tokens; +}; + +struct ratelimit { + uint8_t rl_secret[SIPHASH_KEY_LENGTH]; + struct mtx rl_mtx; + struct callout rl_gc; + LIST_HEAD(, ratelimit_entry) rl_table[RATELIMIT_SIZE]; + size_t rl_table_num; +}; + +static void precompute_key(uint8_t *, + const uint8_t[COOKIE_INPUT_SIZE], const char *); +static void macs_mac1(struct cookie_macs *, const void *, size_t, + const uint8_t[COOKIE_KEY_SIZE]); +static void macs_mac2(struct cookie_macs *, const void *, size_t, + const uint8_t[COOKIE_COOKIE_SIZE]); +static int timer_expired(sbintime_t, uint32_t, uint32_t); +static void make_cookie(struct 
cookie_checker *, + uint8_t[COOKIE_COOKIE_SIZE], struct sockaddr *); +static void ratelimit_init(struct ratelimit *); +static void ratelimit_deinit(struct ratelimit *); +static void ratelimit_gc_callout(void *); +static void ratelimit_gc_schedule(struct ratelimit *); +static void ratelimit_gc(struct ratelimit *, bool); +static int ratelimit_allow(struct ratelimit *, struct sockaddr *, struct vnet *); +static uint64_t siphash13(const uint8_t [SIPHASH_KEY_LENGTH], const void *, size_t); + +static struct ratelimit ratelimit_v4; +#ifdef INET6 +static struct ratelimit ratelimit_v6; +#endif +static uma_zone_t ratelimit_zone; + +/* Public Functions */ +int +cookie_init(void) +{ + if ((ratelimit_zone = uma_zcreate("wg ratelimit", + sizeof(struct ratelimit_entry), NULL, NULL, NULL, NULL, 0, 0)) == NULL) + return ENOMEM; + + ratelimit_init(&ratelimit_v4); +#ifdef INET6 + ratelimit_init(&ratelimit_v6); +#endif + return (0); +} + +void +cookie_deinit(void) +{ + ratelimit_deinit(&ratelimit_v4); +#ifdef INET6 + ratelimit_deinit(&ratelimit_v6); +#endif + uma_zdestroy(ratelimit_zone); +} + +void +cookie_checker_init(struct cookie_checker *cc) +{ + bzero(cc, sizeof(*cc)); + + rw_init(&cc->cc_key_lock, "cookie_checker_key"); + mtx_init(&cc->cc_secret_mtx, "cookie_checker_secret", NULL, MTX_DEF); +} + +void +cookie_checker_free(struct cookie_checker *cc) +{ + rw_destroy(&cc->cc_key_lock); + mtx_destroy(&cc->cc_secret_mtx); + explicit_bzero(cc, sizeof(*cc)); +} + +void +cookie_checker_update(struct cookie_checker *cc, + const uint8_t key[COOKIE_INPUT_SIZE]) +{ + rw_wlock(&cc->cc_key_lock); + if (key) { + precompute_key(cc->cc_mac1_key, key, COOKIE_MAC1_KEY_LABEL); + precompute_key(cc->cc_cookie_key, key, COOKIE_COOKIE_KEY_LABEL); + } else { + bzero(cc->cc_mac1_key, sizeof(cc->cc_mac1_key)); + bzero(cc->cc_cookie_key, sizeof(cc->cc_cookie_key)); + } + rw_wunlock(&cc->cc_key_lock); +} + +void +cookie_checker_create_payload(struct cookie_checker *cc, + struct cookie_macs *macs, uint8_t nonce[COOKIE_NONCE_SIZE], + uint8_t ecookie[COOKIE_ENCRYPTED_SIZE], struct sockaddr *sa) +{ + uint8_t cookie[COOKIE_COOKIE_SIZE]; + + make_cookie(cc, cookie, sa); + arc4random_buf(nonce, COOKIE_NONCE_SIZE); + + rw_rlock(&cc->cc_key_lock); + xchacha20poly1305_encrypt(ecookie, cookie, COOKIE_COOKIE_SIZE, + macs->mac1, COOKIE_MAC_SIZE, nonce, cc->cc_cookie_key); + rw_runlock(&cc->cc_key_lock); + + explicit_bzero(cookie, sizeof(cookie)); +} + +void +cookie_maker_init(struct cookie_maker *cm, const uint8_t key[COOKIE_INPUT_SIZE]) +{ + bzero(cm, sizeof(*cm)); + precompute_key(cm->cm_mac1_key, key, COOKIE_MAC1_KEY_LABEL); + precompute_key(cm->cm_cookie_key, key, COOKIE_COOKIE_KEY_LABEL); + rw_init(&cm->cm_lock, "cookie_maker"); +} + +void +cookie_maker_free(struct cookie_maker *cm) +{ + rw_destroy(&cm->cm_lock); + explicit_bzero(cm, sizeof(*cm)); +} + +int +cookie_maker_consume_payload(struct cookie_maker *cm, + uint8_t nonce[COOKIE_NONCE_SIZE], uint8_t ecookie[COOKIE_ENCRYPTED_SIZE]) +{ + uint8_t cookie[COOKIE_COOKIE_SIZE]; + int ret; + + rw_rlock(&cm->cm_lock); + if (!cm->cm_mac1_sent) { + ret = ETIMEDOUT; + goto error; + } + + if (!xchacha20poly1305_decrypt(cookie, ecookie, COOKIE_ENCRYPTED_SIZE, + cm->cm_mac1_last, COOKIE_MAC_SIZE, nonce, cm->cm_cookie_key)) { + ret = EINVAL; + goto error; + } + rw_runlock(&cm->cm_lock); + + rw_wlock(&cm->cm_lock); + memcpy(cm->cm_cookie, cookie, COOKIE_COOKIE_SIZE); + cm->cm_cookie_birthdate = getsbinuptime(); + cm->cm_cookie_valid = true; + cm->cm_mac1_sent = false; + rw_wunlock(&cm->cm_lock); + + 
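+	/*
+	 * The stored cookie is only consumed by cookie_maker_mac() when
+	 * computing mac2 for the next handshake message we send.
+	 */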
+	return 0;
+error:
+	rw_runlock(&cm->cm_lock);
+	return ret;
+}
+
+void
+cookie_maker_mac(struct cookie_maker *cm, struct cookie_macs *macs, void *buf,
+    size_t len)
+{
+	rw_wlock(&cm->cm_lock);
+	macs_mac1(macs, buf, len, cm->cm_mac1_key);
+	memcpy(cm->cm_mac1_last, macs->mac1, COOKIE_MAC_SIZE);
+	cm->cm_mac1_sent = true;
+
+	if (cm->cm_cookie_valid &&
+	    !timer_expired(cm->cm_cookie_birthdate,
+	    COOKIE_SECRET_MAX_AGE - COOKIE_SECRET_LATENCY, 0)) {
+		macs_mac2(macs, buf, len, cm->cm_cookie);
+	} else {
+		bzero(macs->mac2, COOKIE_MAC_SIZE);
+		cm->cm_cookie_valid = false;
+	}
+	rw_wunlock(&cm->cm_lock);
+}
+
+int
+cookie_checker_validate_macs(struct cookie_checker *cc, struct cookie_macs *macs,
+    void *buf, size_t len, bool check_cookie, struct sockaddr *sa, struct vnet *vnet)
+{
+	struct cookie_macs our_macs;
+	uint8_t cookie[COOKIE_COOKIE_SIZE];
+
+	/* Validate incoming MACs */
+	rw_rlock(&cc->cc_key_lock);
+	macs_mac1(&our_macs, buf, len, cc->cc_mac1_key);
+	rw_runlock(&cc->cc_key_lock);
+
+	/* If mac1 is invalid, we want to drop the packet */
+	if (timingsafe_bcmp(our_macs.mac1, macs->mac1, COOKIE_MAC_SIZE) != 0)
+		return EINVAL;
+
+	if (check_cookie) {
+		make_cookie(cc, cookie, sa);
+		macs_mac2(&our_macs, buf, len, cookie);
+
+		/* If the mac2 is invalid, we want to send a cookie response */
+		if (timingsafe_bcmp(our_macs.mac2, macs->mac2, COOKIE_MAC_SIZE) != 0)
+			return EAGAIN;
+
+		/* If the mac2 is valid, we may want to rate limit the peer.
+		 * ratelimit_allow will return either 0 or ECONNREFUSED,
+		 * implying there is no ratelimiting, or we should ratelimit
+		 * (refuse) respectively. */
+		if (sa->sa_family == AF_INET)
+			return ratelimit_allow(&ratelimit_v4, sa, vnet);
+#ifdef INET6
+		else if (sa->sa_family == AF_INET6)
+			return ratelimit_allow(&ratelimit_v6, sa, vnet);
+#endif
+		else
+			return EAFNOSUPPORT;
+	}
+
+	return 0;
+}
+
+/* Private functions */
+static void
+precompute_key(uint8_t *key, const uint8_t input[COOKIE_INPUT_SIZE],
+    const char *label)
+{
+	struct blake2s_state blake;
+	blake2s_init(&blake, COOKIE_KEY_SIZE);
+	blake2s_update(&blake, label, strlen(label));
+	blake2s_update(&blake, input, COOKIE_INPUT_SIZE);
+	blake2s_final(&blake, key);
+}
+
+static void
+macs_mac1(struct cookie_macs *macs, const void *buf, size_t len,
+    const uint8_t key[COOKIE_KEY_SIZE])
+{
+	struct blake2s_state state;
+	blake2s_init_key(&state, COOKIE_MAC_SIZE, key, COOKIE_KEY_SIZE);
+	blake2s_update(&state, buf, len);
+	blake2s_final(&state, macs->mac1);
+}
+
+static void
+macs_mac2(struct cookie_macs *macs, const void *buf, size_t len,
+    const uint8_t key[COOKIE_COOKIE_SIZE])
+{
+	struct blake2s_state state;
+	blake2s_init_key(&state, COOKIE_MAC_SIZE, key, COOKIE_COOKIE_SIZE);
+	blake2s_update(&state, buf, len);
+	blake2s_update(&state, macs->mac1, COOKIE_MAC_SIZE);
+	blake2s_final(&state, macs->mac2);
+}
+
+static __inline int
+timer_expired(sbintime_t timer, uint32_t sec, uint32_t nsec)
+{
+	sbintime_t now = getsbinuptime();
+	return (now > (timer + sec * SBT_1S + nstosbt(nsec))) ? ETIMEDOUT : 0;
+}
+
+static void
+make_cookie(struct cookie_checker *cc, uint8_t cookie[COOKIE_COOKIE_SIZE],
+    struct sockaddr *sa)
+{
+	struct blake2s_state state;
+
+	mtx_lock(&cc->cc_secret_mtx);
+	if (timer_expired(cc->cc_secret_birthdate,
+	    COOKIE_SECRET_MAX_AGE, 0)) {
+		arc4random_buf(cc->cc_secret, COOKIE_SECRET_SIZE);
+		cc->cc_secret_birthdate = getsbinuptime();
+	}
+	blake2s_init_key(&state, COOKIE_COOKIE_SIZE, cc->cc_secret,
+	    COOKIE_SECRET_SIZE);
+	mtx_unlock(&cc->cc_secret_mtx);
+
+	if (sa->sa_family == AF_INET) {
+		blake2s_update(&state, (uint8_t *)&satosin(sa)->sin_addr,
+		    sizeof(struct in_addr));
+		blake2s_update(&state, (uint8_t *)&satosin(sa)->sin_port,
+		    sizeof(in_port_t));
+		blake2s_final(&state, cookie);
+#ifdef INET6
+	} else if (sa->sa_family == AF_INET6) {
+		blake2s_update(&state, (uint8_t *)&satosin6(sa)->sin6_addr,
+		    sizeof(struct in6_addr));
+		blake2s_update(&state, (uint8_t *)&satosin6(sa)->sin6_port,
+		    sizeof(in_port_t));
+		blake2s_final(&state, cookie);
+#endif
+	} else {
+		arc4random_buf(cookie, COOKIE_COOKIE_SIZE);
+	}
+}
+
+static void
+ratelimit_init(struct ratelimit *rl)
+{
+	size_t i;
+	mtx_init(&rl->rl_mtx, "ratelimit_lock", NULL, MTX_DEF);
+	callout_init_mtx(&rl->rl_gc, &rl->rl_mtx, 0);
+	arc4random_buf(rl->rl_secret, sizeof(rl->rl_secret));
+	for (i = 0; i < RATELIMIT_SIZE; i++)
+		LIST_INIT(&rl->rl_table[i]);
+	rl->rl_table_num = 0;
+}
+
+static void
+ratelimit_deinit(struct ratelimit *rl)
+{
+	mtx_lock(&rl->rl_mtx);
+	callout_stop(&rl->rl_gc);
+	ratelimit_gc(rl, true);
+	mtx_unlock(&rl->rl_mtx);
+	mtx_destroy(&rl->rl_mtx);
+}
+
+static void
+ratelimit_gc_callout(void *_rl)
+{
+	/* callout will lock rl_mtx for us */
+	ratelimit_gc(_rl, false);
+}
+
+static void
+ratelimit_gc_schedule(struct ratelimit *rl)
+{
+	/* Trigger another GC if needed. There is no point calling GC if there
+	 * are no entries in the table. We also want to ensure that GC occurs
+	 * on a regular interval, so don't override a currently pending GC.
+	 *
+	 * In the case of a forced ratelimit_gc, there will be no entries left
+	 * so we will not schedule another GC.
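+	 * ELEMENT_TIMEOUT is in seconds, so an entry that has sat idle for
+	 * a full second is reclaimed on the next pass.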
*/ + if (rl->rl_table_num > 0 && !callout_pending(&rl->rl_gc)) + callout_reset(&rl->rl_gc, ELEMENT_TIMEOUT * hz, + ratelimit_gc_callout, rl); +} + +static void +ratelimit_gc(struct ratelimit *rl, bool force) +{ + size_t i; + struct ratelimit_entry *r, *tr; + sbintime_t expiry; + + mtx_assert(&rl->rl_mtx, MA_OWNED); + + if (rl->rl_table_num == 0) + return; + + expiry = getsbinuptime() - ELEMENT_TIMEOUT * SBT_1S; + + for (i = 0; i < RATELIMIT_SIZE; i++) { + LIST_FOREACH_SAFE(r, &rl->rl_table[i], r_entry, tr) { + if (r->r_last_time < expiry || force) { + rl->rl_table_num--; + LIST_REMOVE(r, r_entry); + uma_zfree(ratelimit_zone, r); + } + } + } + + ratelimit_gc_schedule(rl); +} + +static int +ratelimit_allow(struct ratelimit *rl, struct sockaddr *sa, struct vnet *vnet) +{ + uint64_t bucket, tokens; + sbintime_t diff, now; + struct ratelimit_entry *r; + int ret = ECONNREFUSED; + struct ratelimit_key key = { .vnet = vnet }; + size_t len = sizeof(key); + + if (sa->sa_family == AF_INET) { + memcpy(key.ip, &satosin(sa)->sin_addr, IPV4_MASK_SIZE); + len -= IPV6_MASK_SIZE - IPV4_MASK_SIZE; + } +#ifdef INET6 + else if (sa->sa_family == AF_INET6) + memcpy(key.ip, &satosin6(sa)->sin6_addr, IPV6_MASK_SIZE); +#endif + else + return ret; + + bucket = siphash13(rl->rl_secret, &key, len) & RATELIMIT_MASK; + mtx_lock(&rl->rl_mtx); + + LIST_FOREACH(r, &rl->rl_table[bucket], r_entry) { + if (bcmp(&r->r_key, &key, len) != 0) + continue; + + /* If we get to here, we've found an entry for the endpoint. + * We apply a standard token bucket: calculate the time elapsed + * since last_time, add it to the token count, and cap the + * tokens at TOKEN_MAX. If the endpoint has no tokens left + * (that is, tokens < INITIATION_COST) then we block the + * request, otherwise we subtract the INITIATION_COST and + * return OK. */ + now = getsbinuptime(); + diff = now - r->r_last_time; + r->r_last_time = now; + + tokens = r->r_tokens + diff; + + if (tokens > TOKEN_MAX) + tokens = TOKEN_MAX; + + if (tokens >= INITIATION_COST) { + r->r_tokens = tokens - INITIATION_COST; + goto ok; + } else { + r->r_tokens = tokens; + goto error; + } + } + + /* If we get to here, we didn't have an entry for the endpoint, let's + * add one if we have space. */ + if (rl->rl_table_num >= RATELIMIT_SIZE_MAX) + goto error; + + /* Goto error if out of memory */ + if ((r = uma_zalloc(ratelimit_zone, M_NOWAIT | M_ZERO)) == NULL) + goto error; + + rl->rl_table_num++; + + /* Insert entry into the hashtable and ensure it's initialised */ + LIST_INSERT_HEAD(&rl->rl_table[bucket], r, r_entry); + r->r_key = key; + r->r_last_time = getsbinuptime(); + r->r_tokens = TOKEN_MAX - INITIATION_COST; + + /* If we've added a new entry, let's trigger GC. */ + ratelimit_gc_schedule(rl); +ok: + ret = 0; +error: + mtx_unlock(&rl->rl_mtx); + return ret; +} + +static uint64_t siphash13(const uint8_t key[SIPHASH_KEY_LENGTH], const void *src, size_t len) +{ + SIPHASH_CTX ctx; + return (SipHashX(&ctx, 1, 3, key, src, len)); +} + +#ifdef SELFTESTS +#include "selftest/cookie.c" +#endif /* SELFTESTS */ diff --git a/sys/dev/wg/wg_cookie.h b/sys/dev/wg/wg_cookie.h new file mode 100644 index 000000000000..97ff10da2aa5 --- /dev/null +++ b/sys/dev/wg/wg_cookie.h @@ -0,0 +1,72 @@ +/* SPDX-License-Identifier: ISC + * + * Copyright (C) 2015-2021 Jason A. Donenfeld <Jason@zx2c4.com>. All Rights Reserved. 
+ * Copyright (C) 2019-2021 Matt Dunwoodie <ncon@noconroy.net> + */ + +#ifndef __COOKIE_H__ +#define __COOKIE_H__ + +#include "crypto.h" + +#define COOKIE_MAC_SIZE 16 +#define COOKIE_KEY_SIZE 32 +#define COOKIE_NONCE_SIZE XCHACHA20POLY1305_NONCE_SIZE +#define COOKIE_COOKIE_SIZE 16 +#define COOKIE_SECRET_SIZE 32 +#define COOKIE_INPUT_SIZE 32 +#define COOKIE_ENCRYPTED_SIZE (COOKIE_COOKIE_SIZE + COOKIE_MAC_SIZE) + +struct vnet; + +struct cookie_macs { + uint8_t mac1[COOKIE_MAC_SIZE]; + uint8_t mac2[COOKIE_MAC_SIZE]; +}; + +struct cookie_maker { + uint8_t cm_mac1_key[COOKIE_KEY_SIZE]; + uint8_t cm_cookie_key[COOKIE_KEY_SIZE]; + + struct rwlock cm_lock; + bool cm_cookie_valid; + uint8_t cm_cookie[COOKIE_COOKIE_SIZE]; + sbintime_t cm_cookie_birthdate; /* sbinuptime */ + bool cm_mac1_sent; + uint8_t cm_mac1_last[COOKIE_MAC_SIZE]; +}; + +struct cookie_checker { + struct rwlock cc_key_lock; + uint8_t cc_mac1_key[COOKIE_KEY_SIZE]; + uint8_t cc_cookie_key[COOKIE_KEY_SIZE]; + + struct mtx cc_secret_mtx; + sbintime_t cc_secret_birthdate; /* sbinuptime */ + uint8_t cc_secret[COOKIE_SECRET_SIZE]; +}; + +int cookie_init(void); +void cookie_deinit(void); +void cookie_checker_init(struct cookie_checker *); +void cookie_checker_free(struct cookie_checker *); +void cookie_checker_update(struct cookie_checker *, + const uint8_t[COOKIE_INPUT_SIZE]); +void cookie_checker_create_payload(struct cookie_checker *, + struct cookie_macs *cm, uint8_t[COOKIE_NONCE_SIZE], + uint8_t [COOKIE_ENCRYPTED_SIZE], struct sockaddr *); +void cookie_maker_init(struct cookie_maker *, const uint8_t[COOKIE_INPUT_SIZE]); +void cookie_maker_free(struct cookie_maker *); +int cookie_maker_consume_payload(struct cookie_maker *, + uint8_t[COOKIE_NONCE_SIZE], uint8_t[COOKIE_ENCRYPTED_SIZE]); +void cookie_maker_mac(struct cookie_maker *, struct cookie_macs *, + void *, size_t); +int cookie_checker_validate_macs(struct cookie_checker *, + struct cookie_macs *, void *, size_t, bool, struct sockaddr *, + struct vnet *); + +#ifdef SELFTESTS +bool cookie_selftest(void); +#endif /* SELFTESTS */ + +#endif /* __COOKIE_H__ */ diff --git a/sys/dev/wg/wg_crypto.c b/sys/dev/wg/wg_crypto.c new file mode 100644 index 000000000000..29d9487d647f --- /dev/null +++ b/sys/dev/wg/wg_crypto.c @@ -0,0 +1,1830 @@ +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2015-2021 Jason A. Donenfeld <Jason@zx2c4.com>. All Rights Reserved. 
+ * Copyright (c) 2022 The FreeBSD Foundation + */ + +#include <sys/types.h> +#include <sys/systm.h> +#include <sys/endian.h> +#include <sys/mbuf.h> +#include <opencrypto/cryptodev.h> + +#include "crypto.h" + +#ifndef COMPAT_NEED_CHACHA20POLY1305_MBUF +static crypto_session_t chacha20_poly1305_sid; +#endif + +#ifndef ARRAY_SIZE +#define ARRAY_SIZE(x) (sizeof(x) / sizeof((x)[0])) +#endif +#ifndef noinline +#define noinline __attribute__((noinline)) +#endif +#ifndef __aligned +#define __aligned(x) __attribute__((aligned(x))) +#endif +#ifndef DIV_ROUND_UP +#define DIV_ROUND_UP(n,d) (((n) + (d) - 1) / (d)) +#endif + +#define le32_to_cpup(a) le32toh(*(a)) +#define le64_to_cpup(a) le64toh(*(a)) +#define cpu_to_le32(a) htole32(a) +#define cpu_to_le64(a) htole64(a) + +static inline __unused uint32_t get_unaligned_le32(const uint8_t *a) +{ + uint32_t l; + __builtin_memcpy(&l, a, sizeof(l)); + return le32_to_cpup(&l); +} +static inline __unused uint64_t get_unaligned_le64(const uint8_t *a) +{ + uint64_t l; + __builtin_memcpy(&l, a, sizeof(l)); + return le64_to_cpup(&l); +} +static inline __unused void put_unaligned_le32(uint32_t s, uint8_t *d) +{ + uint32_t l = cpu_to_le32(s); + __builtin_memcpy(d, &l, sizeof(l)); +} +static inline __unused void cpu_to_le32_array(uint32_t *buf, unsigned int words) +{ + while (words--) { + *buf = cpu_to_le32(*buf); + ++buf; + } +} +static inline __unused void le32_to_cpu_array(uint32_t *buf, unsigned int words) +{ + while (words--) { + *buf = le32_to_cpup(buf); + ++buf; + } +} +static inline __unused uint32_t rol32(uint32_t word, unsigned int shift) +{ + return (word << (shift & 31)) | (word >> ((-shift) & 31)); +} +static inline __unused uint32_t ror32(uint32_t word, unsigned int shift) +{ + return (word >> (shift & 31)) | (word << ((-shift) & 31)); +} + +#if defined(COMPAT_NEED_CHACHA20POLY1305) || defined(COMPAT_NEED_CHACHA20POLY1305_MBUF) +static void xor_cpy(uint8_t *dst, const uint8_t *src1, const uint8_t *src2, size_t len) +{ + size_t i; + + for (i = 0; i < len; ++i) + dst[i] = src1[i] ^ src2[i]; +} + +#define QUARTER_ROUND(x, a, b, c, d) ( \ + x[a] += x[b], \ + x[d] = rol32((x[d] ^ x[a]), 16), \ + x[c] += x[d], \ + x[b] = rol32((x[b] ^ x[c]), 12), \ + x[a] += x[b], \ + x[d] = rol32((x[d] ^ x[a]), 8), \ + x[c] += x[d], \ + x[b] = rol32((x[b] ^ x[c]), 7) \ +) + +#define C(i, j) (i * 4 + j) + +#define DOUBLE_ROUND(x) ( \ + /* Column Round */ \ + QUARTER_ROUND(x, C(0, 0), C(1, 0), C(2, 0), C(3, 0)), \ + QUARTER_ROUND(x, C(0, 1), C(1, 1), C(2, 1), C(3, 1)), \ + QUARTER_ROUND(x, C(0, 2), C(1, 2), C(2, 2), C(3, 2)), \ + QUARTER_ROUND(x, C(0, 3), C(1, 3), C(2, 3), C(3, 3)), \ + /* Diagonal Round */ \ + QUARTER_ROUND(x, C(0, 0), C(1, 1), C(2, 2), C(3, 3)), \ + QUARTER_ROUND(x, C(0, 1), C(1, 2), C(2, 3), C(3, 0)), \ + QUARTER_ROUND(x, C(0, 2), C(1, 3), C(2, 0), C(3, 1)), \ + QUARTER_ROUND(x, C(0, 3), C(1, 0), C(2, 1), C(3, 2)) \ +) + +#define TWENTY_ROUNDS(x) ( \ + DOUBLE_ROUND(x), \ + DOUBLE_ROUND(x), \ + DOUBLE_ROUND(x), \ + DOUBLE_ROUND(x), \ + DOUBLE_ROUND(x), \ + DOUBLE_ROUND(x), \ + DOUBLE_ROUND(x), \ + DOUBLE_ROUND(x), \ + DOUBLE_ROUND(x), \ + DOUBLE_ROUND(x) \ +) + +enum chacha20_lengths { + CHACHA20_NONCE_SIZE = 16, + CHACHA20_KEY_SIZE = 32, + CHACHA20_KEY_WORDS = CHACHA20_KEY_SIZE / sizeof(uint32_t), + CHACHA20_BLOCK_SIZE = 64, + CHACHA20_BLOCK_WORDS = CHACHA20_BLOCK_SIZE / sizeof(uint32_t), + HCHACHA20_NONCE_SIZE = CHACHA20_NONCE_SIZE, + HCHACHA20_KEY_SIZE = CHACHA20_KEY_SIZE +}; + +enum chacha20_constants { /* expand 32-byte k */ + CHACHA20_CONSTANT_EXPA = 
0x61707865U, + CHACHA20_CONSTANT_ND_3 = 0x3320646eU, + CHACHA20_CONSTANT_2_BY = 0x79622d32U, + CHACHA20_CONSTANT_TE_K = 0x6b206574U +}; + +struct chacha20_ctx { + union { + uint32_t state[16]; + struct { + uint32_t constant[4]; + uint32_t key[8]; + uint32_t counter[4]; + }; + }; +}; + +static void chacha20_init(struct chacha20_ctx *ctx, + const uint8_t key[CHACHA20_KEY_SIZE], + const uint64_t nonce) +{ + ctx->constant[0] = CHACHA20_CONSTANT_EXPA; + ctx->constant[1] = CHACHA20_CONSTANT_ND_3; + ctx->constant[2] = CHACHA20_CONSTANT_2_BY; + ctx->constant[3] = CHACHA20_CONSTANT_TE_K; + ctx->key[0] = get_unaligned_le32(key + 0); + ctx->key[1] = get_unaligned_le32(key + 4); + ctx->key[2] = get_unaligned_le32(key + 8); + ctx->key[3] = get_unaligned_le32(key + 12); + ctx->key[4] = get_unaligned_le32(key + 16); + ctx->key[5] = get_unaligned_le32(key + 20); + ctx->key[6] = get_unaligned_le32(key + 24); + ctx->key[7] = get_unaligned_le32(key + 28); + ctx->counter[0] = 0; + ctx->counter[1] = 0; + ctx->counter[2] = nonce & 0xffffffffU; + ctx->counter[3] = nonce >> 32; +} + +static void chacha20_block(struct chacha20_ctx *ctx, uint32_t *stream) +{ + uint32_t x[CHACHA20_BLOCK_WORDS]; + int i; + + for (i = 0; i < ARRAY_SIZE(x); ++i) + x[i] = ctx->state[i]; + + TWENTY_ROUNDS(x); + + for (i = 0; i < ARRAY_SIZE(x); ++i) + stream[i] = cpu_to_le32(x[i] + ctx->state[i]); + + ctx->counter[0] += 1; +} + +static void chacha20(struct chacha20_ctx *ctx, uint8_t *out, const uint8_t *in, + uint32_t len) +{ + uint32_t buf[CHACHA20_BLOCK_WORDS]; + + while (len >= CHACHA20_BLOCK_SIZE) { + chacha20_block(ctx, buf); + xor_cpy(out, in, (uint8_t *)buf, CHACHA20_BLOCK_SIZE); + len -= CHACHA20_BLOCK_SIZE; + out += CHACHA20_BLOCK_SIZE; + in += CHACHA20_BLOCK_SIZE; + } + if (len) { + chacha20_block(ctx, buf); + xor_cpy(out, in, (uint8_t *)buf, len); + } +} + +static void hchacha20(uint32_t derived_key[CHACHA20_KEY_WORDS], + const uint8_t nonce[HCHACHA20_NONCE_SIZE], + const uint8_t key[HCHACHA20_KEY_SIZE]) +{ + uint32_t x[] = { CHACHA20_CONSTANT_EXPA, + CHACHA20_CONSTANT_ND_3, + CHACHA20_CONSTANT_2_BY, + CHACHA20_CONSTANT_TE_K, + get_unaligned_le32(key + 0), + get_unaligned_le32(key + 4), + get_unaligned_le32(key + 8), + get_unaligned_le32(key + 12), + get_unaligned_le32(key + 16), + get_unaligned_le32(key + 20), + get_unaligned_le32(key + 24), + get_unaligned_le32(key + 28), + get_unaligned_le32(nonce + 0), + get_unaligned_le32(nonce + 4), + get_unaligned_le32(nonce + 8), + get_unaligned_le32(nonce + 12) + }; + + TWENTY_ROUNDS(x); + + memcpy(derived_key + 0, x + 0, sizeof(uint32_t) * 4); + memcpy(derived_key + 4, x + 12, sizeof(uint32_t) * 4); +} + +enum poly1305_lengths { + POLY1305_BLOCK_SIZE = 16, + POLY1305_KEY_SIZE = 32, + POLY1305_MAC_SIZE = 16 +}; + +struct poly1305_internal { + uint32_t h[5]; + uint32_t r[5]; + uint32_t s[4]; +}; + +struct poly1305_ctx { + struct poly1305_internal state; + uint32_t nonce[4]; + uint8_t data[POLY1305_BLOCK_SIZE]; + size_t num; +}; + +static void poly1305_init_core(struct poly1305_internal *st, + const uint8_t key[16]) +{ + /* r &= 0xffffffc0ffffffc0ffffffc0fffffff */ + st->r[0] = (get_unaligned_le32(&key[0])) & 0x3ffffff; + st->r[1] = (get_unaligned_le32(&key[3]) >> 2) & 0x3ffff03; + st->r[2] = (get_unaligned_le32(&key[6]) >> 4) & 0x3ffc0ff; + st->r[3] = (get_unaligned_le32(&key[9]) >> 6) & 0x3f03fff; + st->r[4] = (get_unaligned_le32(&key[12]) >> 8) & 0x00fffff; + + /* s = 5*r */ + st->s[0] = st->r[1] * 5; + st->s[1] = st->r[2] * 5; + st->s[2] = st->r[3] * 5; + st->s[3] = st->r[4] * 5; + + 
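/* Precomputing s = 5*r lets the multiplication step fold limbs that + * overflow 2^130 straight back in, since 2^130 = 5 (mod 2^130 - 5). */ +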
/* h = 0 */ + st->h[0] = 0; + st->h[1] = 0; + st->h[2] = 0; + st->h[3] = 0; + st->h[4] = 0; +} + +static void poly1305_blocks_core(struct poly1305_internal *st, + const uint8_t *input, size_t len, + const uint32_t padbit) +{ + const uint32_t hibit = padbit << 24; + uint32_t r0, r1, r2, r3, r4; + uint32_t s1, s2, s3, s4; + uint32_t h0, h1, h2, h3, h4; + uint64_t d0, d1, d2, d3, d4; + uint32_t c; + + r0 = st->r[0]; + r1 = st->r[1]; + r2 = st->r[2]; + r3 = st->r[3]; + r4 = st->r[4]; + + s1 = st->s[0]; + s2 = st->s[1]; + s3 = st->s[2]; + s4 = st->s[3]; + + h0 = st->h[0]; + h1 = st->h[1]; + h2 = st->h[2]; + h3 = st->h[3]; + h4 = st->h[4]; + + while (len >= POLY1305_BLOCK_SIZE) { + /* h += m[i] */ + h0 += (get_unaligned_le32(&input[0])) & 0x3ffffff; + h1 += (get_unaligned_le32(&input[3]) >> 2) & 0x3ffffff; + h2 += (get_unaligned_le32(&input[6]) >> 4) & 0x3ffffff; + h3 += (get_unaligned_le32(&input[9]) >> 6) & 0x3ffffff; + h4 += (get_unaligned_le32(&input[12]) >> 8) | hibit; + + /* h *= r */ + d0 = ((uint64_t)h0 * r0) + ((uint64_t)h1 * s4) + + ((uint64_t)h2 * s3) + ((uint64_t)h3 * s2) + + ((uint64_t)h4 * s1); + d1 = ((uint64_t)h0 * r1) + ((uint64_t)h1 * r0) + + ((uint64_t)h2 * s4) + ((uint64_t)h3 * s3) + + ((uint64_t)h4 * s2); + d2 = ((uint64_t)h0 * r2) + ((uint64_t)h1 * r1) + + ((uint64_t)h2 * r0) + ((uint64_t)h3 * s4) + + ((uint64_t)h4 * s3); + d3 = ((uint64_t)h0 * r3) + ((uint64_t)h1 * r2) + + ((uint64_t)h2 * r1) + ((uint64_t)h3 * r0) + + ((uint64_t)h4 * s4); + d4 = ((uint64_t)h0 * r4) + ((uint64_t)h1 * r3) + + ((uint64_t)h2 * r2) + ((uint64_t)h3 * r1) + + ((uint64_t)h4 * r0); + + /* (partial) h %= p */ + c = (uint32_t)(d0 >> 26); + h0 = (uint32_t)d0 & 0x3ffffff; + d1 += c; + c = (uint32_t)(d1 >> 26); + h1 = (uint32_t)d1 & 0x3ffffff; + d2 += c; + c = (uint32_t)(d2 >> 26); + h2 = (uint32_t)d2 & 0x3ffffff; + d3 += c; + c = (uint32_t)(d3 >> 26); + h3 = (uint32_t)d3 & 0x3ffffff; + d4 += c; + c = (uint32_t)(d4 >> 26); + h4 = (uint32_t)d4 & 0x3ffffff; + h0 += c * 5; + c = (h0 >> 26); + h0 = h0 & 0x3ffffff; + h1 += c; + + input += POLY1305_BLOCK_SIZE; + len -= POLY1305_BLOCK_SIZE; + } + + st->h[0] = h0; + st->h[1] = h1; + st->h[2] = h2; + st->h[3] = h3; + st->h[4] = h4; +} + +static void poly1305_emit_core(struct poly1305_internal *st, uint8_t mac[16], + const uint32_t nonce[4]) +{ + uint32_t h0, h1, h2, h3, h4, c; + uint32_t g0, g1, g2, g3, g4; + uint64_t f; + uint32_t mask; + + /* fully carry h */ + h0 = st->h[0]; + h1 = st->h[1]; + h2 = st->h[2]; + h3 = st->h[3]; + h4 = st->h[4]; + + c = h1 >> 26; + h1 = h1 & 0x3ffffff; + h2 += c; + c = h2 >> 26; + h2 = h2 & 0x3ffffff; + h3 += c; + c = h3 >> 26; + h3 = h3 & 0x3ffffff; + h4 += c; + c = h4 >> 26; + h4 = h4 & 0x3ffffff; + h0 += c * 5; + c = h0 >> 26; + h0 = h0 & 0x3ffffff; + h1 += c; + + /* compute h + -p */ + g0 = h0 + 5; + c = g0 >> 26; + g0 &= 0x3ffffff; + g1 = h1 + c; + c = g1 >> 26; + g1 &= 0x3ffffff; + g2 = h2 + c; + c = g2 >> 26; + g2 &= 0x3ffffff; + g3 = h3 + c; + c = g3 >> 26; + g3 &= 0x3ffffff; + g4 = h4 + c - (1UL << 26); + + /* select h if h < p, or h + -p if h >= p */ + mask = (g4 >> ((sizeof(uint32_t) * 8) - 1)) - 1; + g0 &= mask; + g1 &= mask; + g2 &= mask; + g3 &= mask; + g4 &= mask; + mask = ~mask; + + h0 = (h0 & mask) | g0; + h1 = (h1 & mask) | g1; + h2 = (h2 & mask) | g2; + h3 = (h3 & mask) | g3; + h4 = (h4 & mask) | g4; + + /* h = h % (2^128) */ + h0 = ((h0) | (h1 << 26)) & 0xffffffff; + h1 = ((h1 >> 6) | (h2 << 20)) & 0xffffffff; + h2 = ((h2 >> 12) | (h3 << 14)) & 0xffffffff; + h3 = ((h3 >> 18) | (h4 << 8)) & 0xffffffff; + + /* 
mac = (h + nonce) % (2^128) */ + f = (uint64_t)h0 + nonce[0]; + h0 = (uint32_t)f; + f = (uint64_t)h1 + nonce[1] + (f >> 32); + h1 = (uint32_t)f; + f = (uint64_t)h2 + nonce[2] + (f >> 32); + h2 = (uint32_t)f; + f = (uint64_t)h3 + nonce[3] + (f >> 32); + h3 = (uint32_t)f; + + put_unaligned_le32(h0, &mac[0]); + put_unaligned_le32(h1, &mac[4]); + put_unaligned_le32(h2, &mac[8]); + put_unaligned_le32(h3, &mac[12]); +} + +static void poly1305_init(struct poly1305_ctx *ctx, + const uint8_t key[POLY1305_KEY_SIZE]) +{ + ctx->nonce[0] = get_unaligned_le32(&key[16]); + ctx->nonce[1] = get_unaligned_le32(&key[20]); + ctx->nonce[2] = get_unaligned_le32(&key[24]); + ctx->nonce[3] = get_unaligned_le32(&key[28]); + + poly1305_init_core(&ctx->state, key); + + ctx->num = 0; +} + +static void poly1305_update(struct poly1305_ctx *ctx, const uint8_t *input, + size_t len) +{ + const size_t num = ctx->num; + size_t rem; + + if (num) { + rem = POLY1305_BLOCK_SIZE - num; + if (len < rem) { + memcpy(ctx->data + num, input, len); + ctx->num = num + len; + return; + } + memcpy(ctx->data + num, input, rem); + poly1305_blocks_core(&ctx->state, ctx->data, + POLY1305_BLOCK_SIZE, 1); + input += rem; + len -= rem; + } + + rem = len % POLY1305_BLOCK_SIZE; + len -= rem; + + if (len >= POLY1305_BLOCK_SIZE) { + poly1305_blocks_core(&ctx->state, input, len, 1); + input += len; + } + + if (rem) + memcpy(ctx->data, input, rem); + + ctx->num = rem; +} + +static void poly1305_final(struct poly1305_ctx *ctx, + uint8_t mac[POLY1305_MAC_SIZE]) +{ + size_t num = ctx->num; + + if (num) { + ctx->data[num++] = 1; + while (num < POLY1305_BLOCK_SIZE) + ctx->data[num++] = 0; + poly1305_blocks_core(&ctx->state, ctx->data, + POLY1305_BLOCK_SIZE, 0); + } + + poly1305_emit_core(&ctx->state, mac, ctx->nonce); + + explicit_bzero(ctx, sizeof(*ctx)); +} +#endif + +#ifdef COMPAT_NEED_CHACHA20POLY1305 +static const uint8_t pad0[16] = { 0 }; + +void +chacha20poly1305_encrypt(uint8_t *dst, const uint8_t *src, const size_t src_len, + const uint8_t *ad, const size_t ad_len, + const uint64_t nonce, + const uint8_t key[CHACHA20POLY1305_KEY_SIZE]) +{ + struct poly1305_ctx poly1305_state; + struct chacha20_ctx chacha20_state; + union { + uint8_t block0[POLY1305_KEY_SIZE]; + uint64_t lens[2]; + } b = { { 0 } }; + + chacha20_init(&chacha20_state, key, nonce); + chacha20(&chacha20_state, b.block0, b.block0, sizeof(b.block0)); + poly1305_init(&poly1305_state, b.block0); + + poly1305_update(&poly1305_state, ad, ad_len); + poly1305_update(&poly1305_state, pad0, (0x10 - ad_len) & 0xf); + + chacha20(&chacha20_state, dst, src, src_len); + + poly1305_update(&poly1305_state, dst, src_len); + poly1305_update(&poly1305_state, pad0, (0x10 - src_len) & 0xf); + + b.lens[0] = cpu_to_le64(ad_len); + b.lens[1] = cpu_to_le64(src_len); + poly1305_update(&poly1305_state, (uint8_t *)b.lens, sizeof(b.lens)); + + poly1305_final(&poly1305_state, dst + src_len); + + explicit_bzero(&chacha20_state, sizeof(chacha20_state)); + explicit_bzero(&b, sizeof(b)); +} + +bool +chacha20poly1305_decrypt(uint8_t *dst, const uint8_t *src, const size_t src_len, + const uint8_t *ad, const size_t ad_len, + const uint64_t nonce, + const uint8_t key[CHACHA20POLY1305_KEY_SIZE]) +{ + struct poly1305_ctx poly1305_state; + struct chacha20_ctx chacha20_state; + bool ret; + size_t dst_len; + union { + uint8_t block0[POLY1305_KEY_SIZE]; + uint8_t mac[POLY1305_MAC_SIZE]; + uint64_t lens[2]; + } b = { { 0 } }; + + if (src_len < POLY1305_MAC_SIZE) + return false; + + chacha20_init(&chacha20_state, key, nonce); + 
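/* The first keystream block (counter 0) becomes the one-time Poly1305 + * key; the ciphertext itself is decrypted starting from block counter 1, + * as in the RFC 8439 construction. */ +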
chacha20(&chacha20_state, b.block0, b.block0, sizeof(b.block0)); + poly1305_init(&poly1305_state, b.block0); + + poly1305_update(&poly1305_state, ad, ad_len); + poly1305_update(&poly1305_state, pad0, (0x10 - ad_len) & 0xf); + + dst_len = src_len - POLY1305_MAC_SIZE; + poly1305_update(&poly1305_state, src, dst_len); + poly1305_update(&poly1305_state, pad0, (0x10 - dst_len) & 0xf); + + b.lens[0] = cpu_to_le64(ad_len); + b.lens[1] = cpu_to_le64(dst_len); + poly1305_update(&poly1305_state, (uint8_t *)b.lens, sizeof(b.lens)); + + poly1305_final(&poly1305_state, b.mac); + + ret = timingsafe_bcmp(b.mac, src + dst_len, POLY1305_MAC_SIZE) == 0; + if (ret) + chacha20(&chacha20_state, dst, src, dst_len); + + explicit_bzero(&chacha20_state, sizeof(chacha20_state)); + explicit_bzero(&b, sizeof(b)); + + return ret; +} + +void +xchacha20poly1305_encrypt(uint8_t *dst, const uint8_t *src, + const size_t src_len, const uint8_t *ad, + const size_t ad_len, + const uint8_t nonce[XCHACHA20POLY1305_NONCE_SIZE], + const uint8_t key[CHACHA20POLY1305_KEY_SIZE]) +{ + uint32_t derived_key[CHACHA20_KEY_WORDS]; + + hchacha20(derived_key, nonce, key); + cpu_to_le32_array(derived_key, ARRAY_SIZE(derived_key)); + chacha20poly1305_encrypt(dst, src, src_len, ad, ad_len, + get_unaligned_le64(nonce + 16), + (uint8_t *)derived_key); + explicit_bzero(derived_key, CHACHA20POLY1305_KEY_SIZE); +} + +bool +xchacha20poly1305_decrypt(uint8_t *dst, const uint8_t *src, + const size_t src_len, const uint8_t *ad, + const size_t ad_len, + const uint8_t nonce[XCHACHA20POLY1305_NONCE_SIZE], + const uint8_t key[CHACHA20POLY1305_KEY_SIZE]) +{ + bool ret; + uint32_t derived_key[CHACHA20_KEY_WORDS]; + + hchacha20(derived_key, nonce, key); + cpu_to_le32_array(derived_key, ARRAY_SIZE(derived_key)); + ret = chacha20poly1305_decrypt(dst, src, src_len, ad, ad_len, + get_unaligned_le64(nonce + 16), + (uint8_t *)derived_key); + explicit_bzero(derived_key, CHACHA20POLY1305_KEY_SIZE); + return ret; +} +#endif + +#ifdef COMPAT_NEED_CHACHA20POLY1305_MBUF +static inline int +chacha20poly1305_crypt_mbuf(struct mbuf *m0, uint64_t nonce, + const uint8_t key[CHACHA20POLY1305_KEY_SIZE], bool encrypt) +{ + struct poly1305_ctx poly1305_state; + struct chacha20_ctx chacha20_state; + uint8_t *buf, mbuf_mac[POLY1305_MAC_SIZE]; + size_t len, leftover = 0; + struct mbuf *m; + int ret; + union { + uint32_t stream[CHACHA20_BLOCK_WORDS]; + uint8_t block0[POLY1305_KEY_SIZE]; + uint8_t mac[POLY1305_MAC_SIZE]; + uint64_t lens[2]; + } b = { { 0 } }; + + if (!encrypt) { + if (m0->m_pkthdr.len < POLY1305_MAC_SIZE) + return EMSGSIZE; + m_copydata(m0, m0->m_pkthdr.len - POLY1305_MAC_SIZE, POLY1305_MAC_SIZE, mbuf_mac); + m_adj(m0, -POLY1305_MAC_SIZE); + } + + chacha20_init(&chacha20_state, key, nonce); + chacha20(&chacha20_state, b.block0, b.block0, sizeof(b.block0)); + poly1305_init(&poly1305_state, b.block0); + + for (m = m0; m; m = m->m_next) { + len = m->m_len; + buf = m->m_data; + + if (!encrypt) + poly1305_update(&poly1305_state, m->m_data, m->m_len); + + if (leftover != 0) { + size_t l = min(len, leftover); + xor_cpy(buf, buf, ((uint8_t *)b.stream) + (CHACHA20_BLOCK_SIZE - leftover), l); + leftover -= l; + buf += l; + len -= l; + } + + while (len >= CHACHA20_BLOCK_SIZE) { + chacha20_block(&chacha20_state, b.stream); + xor_cpy(buf, buf, (uint8_t *)b.stream, CHACHA20_BLOCK_SIZE); + buf += CHACHA20_BLOCK_SIZE; + len -= CHACHA20_BLOCK_SIZE; + } + + if (len) { + chacha20_block(&chacha20_state, b.stream); + xor_cpy(buf, buf, (uint8_t *)b.stream, len); + leftover = 
CHACHA20_BLOCK_SIZE - len; + } + + if (encrypt) + poly1305_update(&poly1305_state, m->m_data, m->m_len); + } + poly1305_update(&poly1305_state, pad0, (0x10 - m0->m_pkthdr.len) & 0xf); + + b.lens[0] = 0; + b.lens[1] = cpu_to_le64(m0->m_pkthdr.len); + poly1305_update(&poly1305_state, (uint8_t *)b.lens, sizeof(b.lens)); + + poly1305_final(&poly1305_state, b.mac); + + if (encrypt) + ret = m_append(m0, POLY1305_MAC_SIZE, b.mac) ? 0 : ENOMEM; + else + ret = timingsafe_bcmp(b.mac, mbuf_mac, POLY1305_MAC_SIZE) == 0 ? 0 : EBADMSG; + + explicit_bzero(&chacha20_state, sizeof(chacha20_state)); + explicit_bzero(&b, sizeof(b)); + + return ret; +} + +int +chacha20poly1305_encrypt_mbuf(struct mbuf *m, const uint64_t nonce, + const uint8_t key[CHACHA20POLY1305_KEY_SIZE]) +{ + return chacha20poly1305_crypt_mbuf(m, nonce, key, true); +} + +int +chacha20poly1305_decrypt_mbuf(struct mbuf *m, const uint64_t nonce, + const uint8_t key[CHACHA20POLY1305_KEY_SIZE]) +{ + return chacha20poly1305_crypt_mbuf(m, nonce, key, false); +} +#else +static int +crypto_callback(struct cryptop *crp) +{ + return (0); +} + +int +chacha20poly1305_encrypt_mbuf(struct mbuf *m, const uint64_t nonce, + const uint8_t key[CHACHA20POLY1305_KEY_SIZE]) +{ + static const char blank_tag[POLY1305_HASH_LEN]; + struct cryptop crp; + int ret; + + if (!m_append(m, POLY1305_HASH_LEN, blank_tag)) + return (ENOMEM); + crypto_initreq(&crp, chacha20_poly1305_sid); + crp.crp_op = CRYPTO_OP_ENCRYPT | CRYPTO_OP_COMPUTE_DIGEST; + crp.crp_flags = CRYPTO_F_IV_SEPARATE | CRYPTO_F_CBIMM; + crypto_use_mbuf(&crp, m); + crp.crp_payload_length = m->m_pkthdr.len - POLY1305_HASH_LEN; + crp.crp_digest_start = crp.crp_payload_length; + le64enc(crp.crp_iv, nonce); + crp.crp_cipher_key = key; + crp.crp_callback = crypto_callback; + ret = crypto_dispatch(&crp); + crypto_destroyreq(&crp); + return (ret); +} + +int +chacha20poly1305_decrypt_mbuf(struct mbuf *m, const uint64_t nonce, + const uint8_t key[CHACHA20POLY1305_KEY_SIZE]) +{ + struct cryptop crp; + int ret; + + if (m->m_pkthdr.len < POLY1305_HASH_LEN) + return (EMSGSIZE); + crypto_initreq(&crp, chacha20_poly1305_sid); + crp.crp_op = CRYPTO_OP_DECRYPT | CRYPTO_OP_VERIFY_DIGEST; + crp.crp_flags = CRYPTO_F_IV_SEPARATE | CRYPTO_F_CBIMM; + crypto_use_mbuf(&crp, m); + crp.crp_payload_length = m->m_pkthdr.len - POLY1305_HASH_LEN; + crp.crp_digest_start = crp.crp_payload_length; + le64enc(crp.crp_iv, nonce); + crp.crp_cipher_key = key; + crp.crp_callback = crypto_callback; + ret = crypto_dispatch(&crp); + crypto_destroyreq(&crp); + if (ret) + return (ret); + m_adj(m, -POLY1305_HASH_LEN); + return (0); +} +#endif + +#ifdef COMPAT_NEED_BLAKE2S +static const uint32_t blake2s_iv[8] = { + 0x6A09E667UL, 0xBB67AE85UL, 0x3C6EF372UL, 0xA54FF53AUL, + 0x510E527FUL, 0x9B05688CUL, 0x1F83D9ABUL, 0x5BE0CD19UL +}; + +static const uint8_t blake2s_sigma[10][16] = { + { 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15 }, + { 14, 10, 4, 8, 9, 15, 13, 6, 1, 12, 0, 2, 11, 7, 5, 3 }, + { 11, 8, 12, 0, 5, 2, 15, 13, 10, 14, 3, 6, 7, 1, 9, 4 }, + { 7, 9, 3, 1, 13, 12, 11, 14, 2, 6, 5, 10, 4, 0, 15, 8 }, + { 9, 0, 5, 7, 2, 4, 10, 15, 14, 1, 11, 12, 6, 8, 3, 13 }, + { 2, 12, 6, 10, 0, 11, 8, 3, 4, 13, 7, 5, 15, 14, 1, 9 }, + { 12, 5, 1, 15, 14, 13, 4, 10, 0, 7, 6, 3, 9, 2, 8, 11 }, + { 13, 11, 7, 14, 12, 1, 3, 9, 5, 0, 15, 4, 8, 6, 2, 10 }, + { 6, 15, 14, 9, 11, 3, 0, 8, 12, 2, 13, 7, 1, 4, 10, 5 }, + { 10, 2, 8, 4, 7, 6, 1, 5, 15, 11, 9, 14, 3, 12, 13, 0 }, +}; + +static inline void blake2s_set_lastblock(struct blake2s_state *state) +{ + 
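/* Setting f[0] to all ones is the BLAKE2s finalization flag; + * blake2s_compress XORs it into the working vector for the last block. */ +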
state->f[0] = -1; +} + +static inline void blake2s_increment_counter(struct blake2s_state *state, + const uint32_t inc) +{ + state->t[0] += inc; + state->t[1] += (state->t[0] < inc); +} + +static inline void blake2s_init_param(struct blake2s_state *state, + const uint32_t param) +{ + int i; + + memset(state, 0, sizeof(*state)); + for (i = 0; i < 8; ++i) + state->h[i] = blake2s_iv[i]; + state->h[0] ^= param; +} + +void blake2s_init(struct blake2s_state *state, const size_t outlen) +{ + blake2s_init_param(state, 0x01010000 | outlen); + state->outlen = outlen; +} + +void blake2s_init_key(struct blake2s_state *state, const size_t outlen, + const uint8_t *key, const size_t keylen) +{ + uint8_t block[BLAKE2S_BLOCK_SIZE] = { 0 }; + + blake2s_init_param(state, 0x01010000 | keylen << 8 | outlen); + state->outlen = outlen; + memcpy(block, key, keylen); + blake2s_update(state, block, BLAKE2S_BLOCK_SIZE); + explicit_bzero(block, BLAKE2S_BLOCK_SIZE); +} + +static inline void blake2s_compress(struct blake2s_state *state, + const uint8_t *block, size_t nblocks, + const uint32_t inc) +{ + uint32_t m[16]; + uint32_t v[16]; + int i; + + while (nblocks > 0) { + blake2s_increment_counter(state, inc); + memcpy(m, block, BLAKE2S_BLOCK_SIZE); + le32_to_cpu_array(m, ARRAY_SIZE(m)); + memcpy(v, state->h, 32); + v[ 8] = blake2s_iv[0]; + v[ 9] = blake2s_iv[1]; + v[10] = blake2s_iv[2]; + v[11] = blake2s_iv[3]; + v[12] = blake2s_iv[4] ^ state->t[0]; + v[13] = blake2s_iv[5] ^ state->t[1]; + v[14] = blake2s_iv[6] ^ state->f[0]; + v[15] = blake2s_iv[7] ^ state->f[1]; + +#define G(r, i, a, b, c, d) do { \ + a += b + m[blake2s_sigma[r][2 * i + 0]]; \ + d = ror32(d ^ a, 16); \ + c += d; \ + b = ror32(b ^ c, 12); \ + a += b + m[blake2s_sigma[r][2 * i + 1]]; \ + d = ror32(d ^ a, 8); \ + c += d; \ + b = ror32(b ^ c, 7); \ +} while (0) + +#define ROUND(r) do { \ + G(r, 0, v[0], v[ 4], v[ 8], v[12]); \ + G(r, 1, v[1], v[ 5], v[ 9], v[13]); \ + G(r, 2, v[2], v[ 6], v[10], v[14]); \ + G(r, 3, v[3], v[ 7], v[11], v[15]); \ + G(r, 4, v[0], v[ 5], v[10], v[15]); \ + G(r, 5, v[1], v[ 6], v[11], v[12]); \ + G(r, 6, v[2], v[ 7], v[ 8], v[13]); \ + G(r, 7, v[3], v[ 4], v[ 9], v[14]); \ +} while (0) + ROUND(0); + ROUND(1); + ROUND(2); + ROUND(3); + ROUND(4); + ROUND(5); + ROUND(6); + ROUND(7); + ROUND(8); + ROUND(9); + +#undef G +#undef ROUND + + for (i = 0; i < 8; ++i) + state->h[i] ^= v[i] ^ v[i + 8]; + + block += BLAKE2S_BLOCK_SIZE; + --nblocks; + } +} + +void blake2s_update(struct blake2s_state *state, const uint8_t *in, size_t inlen) +{ + const size_t fill = BLAKE2S_BLOCK_SIZE - state->buflen; + + if (!inlen) + return; + if (inlen > fill) { + memcpy(state->buf + state->buflen, in, fill); + blake2s_compress(state, state->buf, 1, BLAKE2S_BLOCK_SIZE); + state->buflen = 0; + in += fill; + inlen -= fill; + } + if (inlen > BLAKE2S_BLOCK_SIZE) { + const size_t nblocks = DIV_ROUND_UP(inlen, BLAKE2S_BLOCK_SIZE); + /* Hash one less (full) block than strictly possible */ + blake2s_compress(state, in, nblocks - 1, BLAKE2S_BLOCK_SIZE); + in += BLAKE2S_BLOCK_SIZE * (nblocks - 1); + inlen -= BLAKE2S_BLOCK_SIZE * (nblocks - 1); + } + memcpy(state->buf + state->buflen, in, inlen); + state->buflen += inlen; +} + +void blake2s_final(struct blake2s_state *state, uint8_t *out) +{ + blake2s_set_lastblock(state); + memset(state->buf + state->buflen, 0, + BLAKE2S_BLOCK_SIZE - state->buflen); /* Padding */ + blake2s_compress(state, state->buf, 1, state->buflen); + cpu_to_le32_array(state->h, ARRAY_SIZE(state->h)); + memcpy(out, state->h, state->outlen); + 
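/* Wipe the state so keyed-hash material does not linger in memory. */ +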
explicit_bzero(state, sizeof(*state)); +} +#endif + +#ifdef COMPAT_NEED_CURVE25519 +/* Below here is fiat's implementation of x25519. + * + * Copyright (C) 2015-2016 The fiat-crypto Authors. + * Copyright (C) 2018-2021 Jason A. Donenfeld <Jason@zx2c4.com>. All Rights Reserved. + * + * This is a machine-generated formally verified implementation of Curve25519 + * ECDH from: <https://github.com/mit-plv/fiat-crypto>. Though originally + * machine generated, it has been tweaked to be suitable for use in the kernel. + * It is optimized for 32-bit machines and machines that cannot work efficiently + * with 128-bit integer types. + */ + +/* fe means field element. Here the field is \Z/(2^255-19). An element t, + * entries t[0]...t[9], represents the integer t[0]+2^26 t[1]+2^51 t[2]+2^77 + * t[3]+2^102 t[4]+...+2^230 t[9]. + * fe limbs are bounded by 1.125*2^26,1.125*2^25,1.125*2^26,1.125*2^25,etc. + * Multiplication and carrying produce fe from fe_loose. + */ +typedef struct fe { uint32_t v[10]; } fe; + +/* fe_loose limbs are bounded by 3.375*2^26,3.375*2^25,3.375*2^26,3.375*2^25,etc + * Addition and subtraction produce fe_loose from (fe, fe). + */ +typedef struct fe_loose { uint32_t v[10]; } fe_loose; + +static inline void fe_frombytes_impl(uint32_t h[10], const uint8_t *s) +{ + /* Ignores top bit of s. */ + uint32_t a0 = get_unaligned_le32(s); + uint32_t a1 = get_unaligned_le32(s+4); + uint32_t a2 = get_unaligned_le32(s+8); + uint32_t a3 = get_unaligned_le32(s+12); + uint32_t a4 = get_unaligned_le32(s+16); + uint32_t a5 = get_unaligned_le32(s+20); + uint32_t a6 = get_unaligned_le32(s+24); + uint32_t a7 = get_unaligned_le32(s+28); + h[0] = a0&((1<<26)-1); /* 26 used, 32-26 left. 26 */ + h[1] = (a0>>26) | ((a1&((1<<19)-1))<< 6); /* (32-26) + 19 = 6+19 = 25 */ + h[2] = (a1>>19) | ((a2&((1<<13)-1))<<13); /* (32-19) + 13 = 13+13 = 26 */ + h[3] = (a2>>13) | ((a3&((1<< 6)-1))<<19); /* (32-13) + 6 = 19+ 6 = 25 */ + h[4] = (a3>> 6); /* (32- 6) = 26 */ + h[5] = a4&((1<<25)-1); /* 25 */ + h[6] = (a4>>25) | ((a5&((1<<19)-1))<< 7); /* (32-25) + 19 = 7+19 = 26 */ + h[7] = (a5>>19) | ((a6&((1<<12)-1))<<13); /* (32-19) + 12 = 13+12 = 25 */ + h[8] = (a6>>12) | ((a7&((1<< 6)-1))<<20); /* (32-12) + 6 = 20+ 6 = 26 */ + h[9] = (a7>> 6)&((1<<25)-1); /* 25 */ +} + +static inline void fe_frombytes(fe *h, const uint8_t *s) +{ + fe_frombytes_impl(h->v, s); +} + +static inline uint8_t /*bool*/ +addcarryx_u25(uint8_t /*bool*/ c, uint32_t a, uint32_t b, uint32_t *low) +{ + /* This function extracts 25 bits of result and 1 bit of carry + * (26 total), so a 32-bit intermediate is sufficient. + */ + uint32_t x = a + b + c; + *low = x & ((1 << 25) - 1); + return (x >> 25) & 1; +} + +static inline uint8_t /*bool*/ +addcarryx_u26(uint8_t /*bool*/ c, uint32_t a, uint32_t b, uint32_t *low) +{ + /* This function extracts 26 bits of result and 1 bit of carry + * (27 total), so a 32-bit intermediate is sufficient. + */ + uint32_t x = a + b + c; + *low = x & ((1 << 26) - 1); + return (x >> 26) & 1; +} + +static inline uint8_t /*bool*/ +subborrow_u25(uint8_t /*bool*/ c, uint32_t a, uint32_t b, uint32_t *low) +{ + /* This function extracts 25 bits of result and 1 bit of borrow + * (26 total), so a 32-bit intermediate is sufficient. 
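+ * The borrow is recovered from the sign bit of the wrapped subtraction.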
+ */ + uint32_t x = a - b - c; + *low = x & ((1 << 25) - 1); + return x >> 31; +} + +static inline uint8_t /*bool*/ +subborrow_u26(uint8_t /*bool*/ c, uint32_t a, uint32_t b, uint32_t *low) +{ + /* This function extracts 26 bits of result and 1 bit of borrow + *(27 total), so a 32-bit intermediate is sufficient. + */ + uint32_t x = a - b - c; + *low = x & ((1 << 26) - 1); + return x >> 31; +} + +static inline uint32_t cmovznz32(uint32_t t, uint32_t z, uint32_t nz) +{ + t = -!!t; /* all set if nonzero, 0 if 0 */ + return (t&nz) | ((~t)&z); +} + +static inline void fe_freeze(uint32_t out[10], const uint32_t in1[10]) +{ + const uint32_t x17 = in1[9]; + const uint32_t x18 = in1[8]; + const uint32_t x16 = in1[7]; + const uint32_t x14 = in1[6]; + const uint32_t x12 = in1[5]; + const uint32_t x10 = in1[4]; + const uint32_t x8 = in1[3]; + const uint32_t x6 = in1[2]; + const uint32_t x4 = in1[1]; + const uint32_t x2 = in1[0]; + uint32_t x20; uint8_t/*bool*/ x21 = subborrow_u26(0x0, x2, 0x3ffffed, &x20); + uint32_t x23; uint8_t/*bool*/ x24 = subborrow_u25(x21, x4, 0x1ffffff, &x23); + uint32_t x26; uint8_t/*bool*/ x27 = subborrow_u26(x24, x6, 0x3ffffff, &x26); + uint32_t x29; uint8_t/*bool*/ x30 = subborrow_u25(x27, x8, 0x1ffffff, &x29); + uint32_t x32; uint8_t/*bool*/ x33 = subborrow_u26(x30, x10, 0x3ffffff, &x32); + uint32_t x35; uint8_t/*bool*/ x36 = subborrow_u25(x33, x12, 0x1ffffff, &x35); + uint32_t x38; uint8_t/*bool*/ x39 = subborrow_u26(x36, x14, 0x3ffffff, &x38); + uint32_t x41; uint8_t/*bool*/ x42 = subborrow_u25(x39, x16, 0x1ffffff, &x41); + uint32_t x44; uint8_t/*bool*/ x45 = subborrow_u26(x42, x18, 0x3ffffff, &x44); + uint32_t x47; uint8_t/*bool*/ x48 = subborrow_u25(x45, x17, 0x1ffffff, &x47); + uint32_t x49 = cmovznz32(x48, 0x0, 0xffffffff); + uint32_t x50 = (x49 & 0x3ffffed); + uint32_t x52; uint8_t/*bool*/ x53 = addcarryx_u26(0x0, x20, x50, &x52); + uint32_t x54 = (x49 & 0x1ffffff); + uint32_t x56; uint8_t/*bool*/ x57 = addcarryx_u25(x53, x23, x54, &x56); + uint32_t x58 = (x49 & 0x3ffffff); + uint32_t x60; uint8_t/*bool*/ x61 = addcarryx_u26(x57, x26, x58, &x60); + uint32_t x62 = (x49 & 0x1ffffff); + uint32_t x64; uint8_t/*bool*/ x65 = addcarryx_u25(x61, x29, x62, &x64); + uint32_t x66 = (x49 & 0x3ffffff); + uint32_t x68; uint8_t/*bool*/ x69 = addcarryx_u26(x65, x32, x66, &x68); + uint32_t x70 = (x49 & 0x1ffffff); + uint32_t x72; uint8_t/*bool*/ x73 = addcarryx_u25(x69, x35, x70, &x72); + uint32_t x74 = (x49 & 0x3ffffff); + uint32_t x76; uint8_t/*bool*/ x77 = addcarryx_u26(x73, x38, x74, &x76); + uint32_t x78 = (x49 & 0x1ffffff); + uint32_t x80; uint8_t/*bool*/ x81 = addcarryx_u25(x77, x41, x78, &x80); + uint32_t x82 = (x49 & 0x3ffffff); + uint32_t x84; uint8_t/*bool*/ x85 = addcarryx_u26(x81, x44, x82, &x84); + uint32_t x86 = (x49 & 0x1ffffff); + uint32_t x88; addcarryx_u25(x85, x47, x86, &x88); + out[0] = x52; + out[1] = x56; + out[2] = x60; + out[3] = x64; + out[4] = x68; + out[5] = x72; + out[6] = x76; + out[7] = x80; + out[8] = x84; + out[9] = x88; +} + +static inline void fe_tobytes(uint8_t s[32], const fe *f) +{ + uint32_t h[10]; + fe_freeze(h, f->v); + s[0] = h[0] >> 0; + s[1] = h[0] >> 8; + s[2] = h[0] >> 16; + s[3] = (h[0] >> 24) | (h[1] << 2); + s[4] = h[1] >> 6; + s[5] = h[1] >> 14; + s[6] = (h[1] >> 22) | (h[2] << 3); + s[7] = h[2] >> 5; + s[8] = h[2] >> 13; + s[9] = (h[2] >> 21) | (h[3] << 5); + s[10] = h[3] >> 3; + s[11] = h[3] >> 11; + s[12] = (h[3] >> 19) | (h[4] << 6); + s[13] = h[4] >> 2; + s[14] = h[4] >> 10; + s[15] = h[4] >> 18; + s[16] = h[5] >> 0; + s[17] = 
h[5] >> 8; + s[18] = h[5] >> 16; + s[19] = (h[5] >> 24) | (h[6] << 1); + s[20] = h[6] >> 7; + s[21] = h[6] >> 15; + s[22] = (h[6] >> 23) | (h[7] << 3); + s[23] = h[7] >> 5; + s[24] = h[7] >> 13; + s[25] = (h[7] >> 21) | (h[8] << 4); + s[26] = h[8] >> 4; + s[27] = h[8] >> 12; + s[28] = (h[8] >> 20) | (h[9] << 6); + s[29] = h[9] >> 2; + s[30] = h[9] >> 10; + s[31] = h[9] >> 18; +} + +/* h = f */ +static inline void fe_copy(fe *h, const fe *f) +{ + memmove(h, f, sizeof(uint32_t) * 10); +} + +static inline void fe_copy_lt(fe_loose *h, const fe *f) +{ + memmove(h, f, sizeof(uint32_t) * 10); +} + +/* h = 0 */ +static inline void fe_0(fe *h) +{ + memset(h, 0, sizeof(uint32_t) * 10); +} + +/* h = 1 */ +static inline void fe_1(fe *h) +{ + memset(h, 0, sizeof(uint32_t) * 10); + h->v[0] = 1; +} + +static void fe_add_impl(uint32_t out[10], const uint32_t in1[10], const uint32_t in2[10]) +{ + const uint32_t x20 = in1[9]; + const uint32_t x21 = in1[8]; + const uint32_t x19 = in1[7]; + const uint32_t x17 = in1[6]; + const uint32_t x15 = in1[5]; + const uint32_t x13 = in1[4]; + const uint32_t x11 = in1[3]; + const uint32_t x9 = in1[2]; + const uint32_t x7 = in1[1]; + const uint32_t x5 = in1[0]; + const uint32_t x38 = in2[9]; + const uint32_t x39 = in2[8]; + const uint32_t x37 = in2[7]; + const uint32_t x35 = in2[6]; + const uint32_t x33 = in2[5]; + const uint32_t x31 = in2[4]; + const uint32_t x29 = in2[3]; + const uint32_t x27 = in2[2]; + const uint32_t x25 = in2[1]; + const uint32_t x23 = in2[0]; + out[0] = (x5 + x23); + out[1] = (x7 + x25); + out[2] = (x9 + x27); + out[3] = (x11 + x29); + out[4] = (x13 + x31); + out[5] = (x15 + x33); + out[6] = (x17 + x35); + out[7] = (x19 + x37); + out[8] = (x21 + x39); + out[9] = (x20 + x38); +} + +/* h = f + g + * Can overlap h with f or g. + */ +static inline void fe_add(fe_loose *h, const fe *f, const fe *g) +{ + fe_add_impl(h->v, f->v, g->v); +} + +static void fe_sub_impl(uint32_t out[10], const uint32_t in1[10], const uint32_t in2[10]) +{ + const uint32_t x20 = in1[9]; + const uint32_t x21 = in1[8]; + const uint32_t x19 = in1[7]; + const uint32_t x17 = in1[6]; + const uint32_t x15 = in1[5]; + const uint32_t x13 = in1[4]; + const uint32_t x11 = in1[3]; + const uint32_t x9 = in1[2]; + const uint32_t x7 = in1[1]; + const uint32_t x5 = in1[0]; + const uint32_t x38 = in2[9]; + const uint32_t x39 = in2[8]; + const uint32_t x37 = in2[7]; + const uint32_t x35 = in2[6]; + const uint32_t x33 = in2[5]; + const uint32_t x31 = in2[4]; + const uint32_t x29 = in2[3]; + const uint32_t x27 = in2[2]; + const uint32_t x25 = in2[1]; + const uint32_t x23 = in2[0]; + out[0] = ((0x7ffffda + x5) - x23); + out[1] = ((0x3fffffe + x7) - x25); + out[2] = ((0x7fffffe + x9) - x27); + out[3] = ((0x3fffffe + x11) - x29); + out[4] = ((0x7fffffe + x13) - x31); + out[5] = ((0x3fffffe + x15) - x33); + out[6] = ((0x7fffffe + x17) - x35); + out[7] = ((0x3fffffe + x19) - x37); + out[8] = ((0x7fffffe + x21) - x39); + out[9] = ((0x3fffffe + x20) - x38); +} + +/* h = f - g + * Can overlap h with f or g. 
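+ * + * The constants added in fe_sub_impl are the limbs of 2*p, which keeps + * every limb difference nonnegative; the extra multiple of p vanishes + * under later reduction.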
+ */ +static inline void fe_sub(fe_loose *h, const fe *f, const fe *g) +{ + fe_sub_impl(h->v, f->v, g->v); +} + +static void fe_mul_impl(uint32_t out[10], const uint32_t in1[10], const uint32_t in2[10]) +{ + const uint32_t x20 = in1[9]; + const uint32_t x21 = in1[8]; + const uint32_t x19 = in1[7]; + const uint32_t x17 = in1[6]; + const uint32_t x15 = in1[5]; + const uint32_t x13 = in1[4]; + const uint32_t x11 = in1[3]; + const uint32_t x9 = in1[2]; + const uint32_t x7 = in1[1]; + const uint32_t x5 = in1[0]; + const uint32_t x38 = in2[9]; + const uint32_t x39 = in2[8]; + const uint32_t x37 = in2[7]; + const uint32_t x35 = in2[6]; + const uint32_t x33 = in2[5]; + const uint32_t x31 = in2[4]; + const uint32_t x29 = in2[3]; + const uint32_t x27 = in2[2]; + const uint32_t x25 = in2[1]; + const uint32_t x23 = in2[0]; + uint64_t x40 = ((uint64_t)x23 * x5); + uint64_t x41 = (((uint64_t)x23 * x7) + ((uint64_t)x25 * x5)); + uint64_t x42 = ((((uint64_t)(0x2 * x25) * x7) + ((uint64_t)x23 * x9)) + ((uint64_t)x27 * x5)); + uint64_t x43 = (((((uint64_t)x25 * x9) + ((uint64_t)x27 * x7)) + ((uint64_t)x23 * x11)) + ((uint64_t)x29 * x5)); + uint64_t x44 = (((((uint64_t)x27 * x9) + (0x2 * (((uint64_t)x25 * x11) + ((uint64_t)x29 * x7)))) + ((uint64_t)x23 * x13)) + ((uint64_t)x31 * x5)); + uint64_t x45 = (((((((uint64_t)x27 * x11) + ((uint64_t)x29 * x9)) + ((uint64_t)x25 * x13)) + ((uint64_t)x31 * x7)) + ((uint64_t)x23 * x15)) + ((uint64_t)x33 * x5)); + uint64_t x46 = (((((0x2 * ((((uint64_t)x29 * x11) + ((uint64_t)x25 * x15)) + ((uint64_t)x33 * x7))) + ((uint64_t)x27 * x13)) + ((uint64_t)x31 * x9)) + ((uint64_t)x23 * x17)) + ((uint64_t)x35 * x5)); + uint64_t x47 = (((((((((uint64_t)x29 * x13) + ((uint64_t)x31 * x11)) + ((uint64_t)x27 * x15)) + ((uint64_t)x33 * x9)) + ((uint64_t)x25 * x17)) + ((uint64_t)x35 * x7)) + ((uint64_t)x23 * x19)) + ((uint64_t)x37 * x5)); + uint64_t x48 = (((((((uint64_t)x31 * x13) + (0x2 * (((((uint64_t)x29 * x15) + ((uint64_t)x33 * x11)) + ((uint64_t)x25 * x19)) + ((uint64_t)x37 * x7)))) + ((uint64_t)x27 * x17)) + ((uint64_t)x35 * x9)) + ((uint64_t)x23 * x21)) + ((uint64_t)x39 * x5)); + uint64_t x49 = (((((((((((uint64_t)x31 * x15) + ((uint64_t)x33 * x13)) + ((uint64_t)x29 * x17)) + ((uint64_t)x35 * x11)) + ((uint64_t)x27 * x19)) + ((uint64_t)x37 * x9)) + ((uint64_t)x25 * x21)) + ((uint64_t)x39 * x7)) + ((uint64_t)x23 * x20)) + ((uint64_t)x38 * x5)); + uint64_t x50 = (((((0x2 * ((((((uint64_t)x33 * x15) + ((uint64_t)x29 * x19)) + ((uint64_t)x37 * x11)) + ((uint64_t)x25 * x20)) + ((uint64_t)x38 * x7))) + ((uint64_t)x31 * x17)) + ((uint64_t)x35 * x13)) + ((uint64_t)x27 * x21)) + ((uint64_t)x39 * x9)); + uint64_t x51 = (((((((((uint64_t)x33 * x17) + ((uint64_t)x35 * x15)) + ((uint64_t)x31 * x19)) + ((uint64_t)x37 * x13)) + ((uint64_t)x29 * x21)) + ((uint64_t)x39 * x11)) + ((uint64_t)x27 * x20)) + ((uint64_t)x38 * x9)); + uint64_t x52 = (((((uint64_t)x35 * x17) + (0x2 * (((((uint64_t)x33 * x19) + ((uint64_t)x37 * x15)) + ((uint64_t)x29 * x20)) + ((uint64_t)x38 * x11)))) + ((uint64_t)x31 * x21)) + ((uint64_t)x39 * x13)); + uint64_t x53 = (((((((uint64_t)x35 * x19) + ((uint64_t)x37 * x17)) + ((uint64_t)x33 * x21)) + ((uint64_t)x39 * x15)) + ((uint64_t)x31 * x20)) + ((uint64_t)x38 * x13)); + uint64_t x54 = (((0x2 * ((((uint64_t)x37 * x19) + ((uint64_t)x33 * x20)) + ((uint64_t)x38 * x15))) + ((uint64_t)x35 * x21)) + ((uint64_t)x39 * x17)); + uint64_t x55 = (((((uint64_t)x37 * x21) + ((uint64_t)x39 * x19)) + ((uint64_t)x35 * x20)) + ((uint64_t)x38 * x17)); + uint64_t x56 = (((uint64_t)x39 * 
x21) + (0x2 * (((uint64_t)x37 * x20) + ((uint64_t)x38 * x19)))); + uint64_t x57 = (((uint64_t)x39 * x20) + ((uint64_t)x38 * x21)); + uint64_t x58 = ((uint64_t)(0x2 * x38) * x20); + uint64_t x59 = (x48 + (x58 << 0x4)); + uint64_t x60 = (x59 + (x58 << 0x1)); + uint64_t x61 = (x60 + x58); + uint64_t x62 = (x47 + (x57 << 0x4)); + uint64_t x63 = (x62 + (x57 << 0x1)); + uint64_t x64 = (x63 + x57); + uint64_t x65 = (x46 + (x56 << 0x4)); + uint64_t x66 = (x65 + (x56 << 0x1)); + uint64_t x67 = (x66 + x56); + uint64_t x68 = (x45 + (x55 << 0x4)); + uint64_t x69 = (x68 + (x55 << 0x1)); + uint64_t x70 = (x69 + x55); + uint64_t x71 = (x44 + (x54 << 0x4)); + uint64_t x72 = (x71 + (x54 << 0x1)); + uint64_t x73 = (x72 + x54); + uint64_t x74 = (x43 + (x53 << 0x4)); + uint64_t x75 = (x74 + (x53 << 0x1)); + uint64_t x76 = (x75 + x53); + uint64_t x77 = (x42 + (x52 << 0x4)); + uint64_t x78 = (x77 + (x52 << 0x1)); + uint64_t x79 = (x78 + x52); + uint64_t x80 = (x41 + (x51 << 0x4)); + uint64_t x81 = (x80 + (x51 << 0x1)); + uint64_t x82 = (x81 + x51); + uint64_t x83 = (x40 + (x50 << 0x4)); + uint64_t x84 = (x83 + (x50 << 0x1)); + uint64_t x85 = (x84 + x50); + uint64_t x86 = (x85 >> 0x1a); + uint32_t x87 = ((uint32_t)x85 & 0x3ffffff); + uint64_t x88 = (x86 + x82); + uint64_t x89 = (x88 >> 0x19); + uint32_t x90 = ((uint32_t)x88 & 0x1ffffff); + uint64_t x91 = (x89 + x79); + uint64_t x92 = (x91 >> 0x1a); + uint32_t x93 = ((uint32_t)x91 & 0x3ffffff); + uint64_t x94 = (x92 + x76); + uint64_t x95 = (x94 >> 0x19); + uint32_t x96 = ((uint32_t)x94 & 0x1ffffff); + uint64_t x97 = (x95 + x73); + uint64_t x98 = (x97 >> 0x1a); + uint32_t x99 = ((uint32_t)x97 & 0x3ffffff); + uint64_t x100 = (x98 + x70); + uint64_t x101 = (x100 >> 0x19); + uint32_t x102 = ((uint32_t)x100 & 0x1ffffff); + uint64_t x103 = (x101 + x67); + uint64_t x104 = (x103 >> 0x1a); + uint32_t x105 = ((uint32_t)x103 & 0x3ffffff); + uint64_t x106 = (x104 + x64); + uint64_t x107 = (x106 >> 0x19); + uint32_t x108 = ((uint32_t)x106 & 0x1ffffff); + uint64_t x109 = (x107 + x61); + uint64_t x110 = (x109 >> 0x1a); + uint32_t x111 = ((uint32_t)x109 & 0x3ffffff); + uint64_t x112 = (x110 + x49); + uint64_t x113 = (x112 >> 0x19); + uint32_t x114 = ((uint32_t)x112 & 0x1ffffff); + uint64_t x115 = (x87 + (0x13 * x113)); + uint32_t x116 = (uint32_t) (x115 >> 0x1a); + uint32_t x117 = ((uint32_t)x115 & 0x3ffffff); + uint32_t x118 = (x116 + x90); + uint32_t x119 = (x118 >> 0x19); + uint32_t x120 = (x118 & 0x1ffffff); + out[0] = x117; + out[1] = x120; + out[2] = (x119 + x93); + out[3] = x96; + out[4] = x99; + out[5] = x102; + out[6] = x105; + out[7] = x108; + out[8] = x111; + out[9] = x114; +} + +static inline void fe_mul_ttt(fe *h, const fe *f, const fe *g) +{ + fe_mul_impl(h->v, f->v, g->v); +} + +static inline void fe_mul_tlt(fe *h, const fe_loose *f, const fe *g) +{ + fe_mul_impl(h->v, f->v, g->v); +} + +static inline void +fe_mul_tll(fe *h, const fe_loose *f, const fe_loose *g) +{ + fe_mul_impl(h->v, f->v, g->v); +} + +static void fe_sqr_impl(uint32_t out[10], const uint32_t in1[10]) +{ + const uint32_t x17 = in1[9]; + const uint32_t x18 = in1[8]; + const uint32_t x16 = in1[7]; + const uint32_t x14 = in1[6]; + const uint32_t x12 = in1[5]; + const uint32_t x10 = in1[4]; + const uint32_t x8 = in1[3]; + const uint32_t x6 = in1[2]; + const uint32_t x4 = in1[1]; + const uint32_t x2 = in1[0]; + uint64_t x19 = ((uint64_t)x2 * x2); + uint64_t x20 = ((uint64_t)(0x2 * x2) * x4); + uint64_t x21 = (0x2 * (((uint64_t)x4 * x4) + ((uint64_t)x2 * x6))); + uint64_t x22 = (0x2 * (((uint64_t)x4 * 
x6) + ((uint64_t)x2 * x8))); + uint64_t x23 = ((((uint64_t)x6 * x6) + ((uint64_t)(0x4 * x4) * x8)) + ((uint64_t)(0x2 * x2) * x10)); + uint64_t x24 = (0x2 * ((((uint64_t)x6 * x8) + ((uint64_t)x4 * x10)) + ((uint64_t)x2 * x12))); + uint64_t x25 = (0x2 * (((((uint64_t)x8 * x8) + ((uint64_t)x6 * x10)) + ((uint64_t)x2 * x14)) + ((uint64_t)(0x2 * x4) * x12))); + uint64_t x26 = (0x2 * (((((uint64_t)x8 * x10) + ((uint64_t)x6 * x12)) + ((uint64_t)x4 * x14)) + ((uint64_t)x2 * x16))); + uint64_t x27 = (((uint64_t)x10 * x10) + (0x2 * ((((uint64_t)x6 * x14) + ((uint64_t)x2 * x18)) + (0x2 * (((uint64_t)x4 * x16) + ((uint64_t)x8 * x12)))))); + uint64_t x28 = (0x2 * ((((((uint64_t)x10 * x12) + ((uint64_t)x8 * x14)) + ((uint64_t)x6 * x16)) + ((uint64_t)x4 * x18)) + ((uint64_t)x2 * x17))); + uint64_t x29 = (0x2 * (((((uint64_t)x12 * x12) + ((uint64_t)x10 * x14)) + ((uint64_t)x6 * x18)) + (0x2 * (((uint64_t)x8 * x16) + ((uint64_t)x4 * x17))))); + uint64_t x30 = (0x2 * (((((uint64_t)x12 * x14) + ((uint64_t)x10 * x16)) + ((uint64_t)x8 * x18)) + ((uint64_t)x6 * x17))); + uint64_t x31 = (((uint64_t)x14 * x14) + (0x2 * (((uint64_t)x10 * x18) + (0x2 * (((uint64_t)x12 * x16) + ((uint64_t)x8 * x17)))))); + uint64_t x32 = (0x2 * ((((uint64_t)x14 * x16) + ((uint64_t)x12 * x18)) + ((uint64_t)x10 * x17))); + uint64_t x33 = (0x2 * ((((uint64_t)x16 * x16) + ((uint64_t)x14 * x18)) + ((uint64_t)(0x2 * x12) * x17))); + uint64_t x34 = (0x2 * (((uint64_t)x16 * x18) + ((uint64_t)x14 * x17))); + uint64_t x35 = (((uint64_t)x18 * x18) + ((uint64_t)(0x4 * x16) * x17)); + uint64_t x36 = ((uint64_t)(0x2 * x18) * x17); + uint64_t x37 = ((uint64_t)(0x2 * x17) * x17); + uint64_t x38 = (x27 + (x37 << 0x4)); + uint64_t x39 = (x38 + (x37 << 0x1)); + uint64_t x40 = (x39 + x37); + uint64_t x41 = (x26 + (x36 << 0x4)); + uint64_t x42 = (x41 + (x36 << 0x1)); + uint64_t x43 = (x42 + x36); + uint64_t x44 = (x25 + (x35 << 0x4)); + uint64_t x45 = (x44 + (x35 << 0x1)); + uint64_t x46 = (x45 + x35); + uint64_t x47 = (x24 + (x34 << 0x4)); + uint64_t x48 = (x47 + (x34 << 0x1)); + uint64_t x49 = (x48 + x34); + uint64_t x50 = (x23 + (x33 << 0x4)); + uint64_t x51 = (x50 + (x33 << 0x1)); + uint64_t x52 = (x51 + x33); + uint64_t x53 = (x22 + (x32 << 0x4)); + uint64_t x54 = (x53 + (x32 << 0x1)); + uint64_t x55 = (x54 + x32); + uint64_t x56 = (x21 + (x31 << 0x4)); + uint64_t x57 = (x56 + (x31 << 0x1)); + uint64_t x58 = (x57 + x31); + uint64_t x59 = (x20 + (x30 << 0x4)); + uint64_t x60 = (x59 + (x30 << 0x1)); + uint64_t x61 = (x60 + x30); + uint64_t x62 = (x19 + (x29 << 0x4)); + uint64_t x63 = (x62 + (x29 << 0x1)); + uint64_t x64 = (x63 + x29); + uint64_t x65 = (x64 >> 0x1a); + uint32_t x66 = ((uint32_t)x64 & 0x3ffffff); + uint64_t x67 = (x65 + x61); + uint64_t x68 = (x67 >> 0x19); + uint32_t x69 = ((uint32_t)x67 & 0x1ffffff); + uint64_t x70 = (x68 + x58); + uint64_t x71 = (x70 >> 0x1a); + uint32_t x72 = ((uint32_t)x70 & 0x3ffffff); + uint64_t x73 = (x71 + x55); + uint64_t x74 = (x73 >> 0x19); + uint32_t x75 = ((uint32_t)x73 & 0x1ffffff); + uint64_t x76 = (x74 + x52); + uint64_t x77 = (x76 >> 0x1a); + uint32_t x78 = ((uint32_t)x76 & 0x3ffffff); + uint64_t x79 = (x77 + x49); + uint64_t x80 = (x79 >> 0x19); + uint32_t x81 = ((uint32_t)x79 & 0x1ffffff); + uint64_t x82 = (x80 + x46); + uint64_t x83 = (x82 >> 0x1a); + uint32_t x84 = ((uint32_t)x82 & 0x3ffffff); + uint64_t x85 = (x83 + x43); + uint64_t x86 = (x85 >> 0x19); + uint32_t x87 = ((uint32_t)x85 & 0x1ffffff); + uint64_t x88 = (x86 + x40); + uint64_t x89 = (x88 >> 0x1a); + uint32_t x90 = ((uint32_t)x88 & 
0x3ffffff); + uint64_t x91 = (x89 + x28); + uint64_t x92 = (x91 >> 0x19); + uint32_t x93 = ((uint32_t)x91 & 0x1ffffff); + uint64_t x94 = (x66 + (0x13 * x92)); + uint32_t x95 = (uint32_t) (x94 >> 0x1a); + uint32_t x96 = ((uint32_t)x94 & 0x3ffffff); + uint32_t x97 = (x95 + x69); + uint32_t x98 = (x97 >> 0x19); + uint32_t x99 = (x97 & 0x1ffffff); + out[0] = x96; + out[1] = x99; + out[2] = (x98 + x72); + out[3] = x75; + out[4] = x78; + out[5] = x81; + out[6] = x84; + out[7] = x87; + out[8] = x90; + out[9] = x93; +} + +static inline void fe_sq_tl(fe *h, const fe_loose *f) +{ + fe_sqr_impl(h->v, f->v); +} + +static inline void fe_sq_tt(fe *h, const fe *f) +{ + fe_sqr_impl(h->v, f->v); +} + +static inline void fe_loose_invert(fe *out, const fe_loose *z) +{ + fe t0; + fe t1; + fe t2; + fe t3; + int i; + + fe_sq_tl(&t0, z); + fe_sq_tt(&t1, &t0); + for (i = 1; i < 2; ++i) + fe_sq_tt(&t1, &t1); + fe_mul_tlt(&t1, z, &t1); + fe_mul_ttt(&t0, &t0, &t1); + fe_sq_tt(&t2, &t0); + fe_mul_ttt(&t1, &t1, &t2); + fe_sq_tt(&t2, &t1); + for (i = 1; i < 5; ++i) + fe_sq_tt(&t2, &t2); + fe_mul_ttt(&t1, &t2, &t1); + fe_sq_tt(&t2, &t1); + for (i = 1; i < 10; ++i) + fe_sq_tt(&t2, &t2); + fe_mul_ttt(&t2, &t2, &t1); + fe_sq_tt(&t3, &t2); + for (i = 1; i < 20; ++i) + fe_sq_tt(&t3, &t3); + fe_mul_ttt(&t2, &t3, &t2); + fe_sq_tt(&t2, &t2); + for (i = 1; i < 10; ++i) + fe_sq_tt(&t2, &t2); + fe_mul_ttt(&t1, &t2, &t1); + fe_sq_tt(&t2, &t1); + for (i = 1; i < 50; ++i) + fe_sq_tt(&t2, &t2); + fe_mul_ttt(&t2, &t2, &t1); + fe_sq_tt(&t3, &t2); + for (i = 1; i < 100; ++i) + fe_sq_tt(&t3, &t3); + fe_mul_ttt(&t2, &t3, &t2); + fe_sq_tt(&t2, &t2); + for (i = 1; i < 50; ++i) + fe_sq_tt(&t2, &t2); + fe_mul_ttt(&t1, &t2, &t1); + fe_sq_tt(&t1, &t1); + for (i = 1; i < 5; ++i) + fe_sq_tt(&t1, &t1); + fe_mul_ttt(out, &t1, &t0); +} + +static inline void fe_invert(fe *out, const fe *z) +{ + fe_loose l; + fe_copy_lt(&l, z); + fe_loose_invert(out, &l); +} + +/* Replace (f,g) with (g,f) if b == 1; + * replace (f,g) with (f,g) if b == 0. 
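+ * The swap is branch-free: b is stretched into an all-ones or all-zero + * mask and applied with XORs, so it executes in constant time.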
+ * + * Preconditions: b in {0,1} + */ +static inline void fe_cswap(fe *f, fe *g, unsigned int b) +{ + unsigned i; + b = 0 - b; + for (i = 0; i < 10; i++) { + uint32_t x = f->v[i] ^ g->v[i]; + x &= b; + f->v[i] ^= x; + g->v[i] ^= x; + } +} + +/* NOTE: based on fiat-crypto fe_mul, edited for in2=121666, 0, 0.*/ +static inline void fe_mul_121666_impl(uint32_t out[10], const uint32_t in1[10]) +{ + const uint32_t x20 = in1[9]; + const uint32_t x21 = in1[8]; + const uint32_t x19 = in1[7]; + const uint32_t x17 = in1[6]; + const uint32_t x15 = in1[5]; + const uint32_t x13 = in1[4]; + const uint32_t x11 = in1[3]; + const uint32_t x9 = in1[2]; + const uint32_t x7 = in1[1]; + const uint32_t x5 = in1[0]; + const uint32_t x38 = 0; + const uint32_t x39 = 0; + const uint32_t x37 = 0; + const uint32_t x35 = 0; + const uint32_t x33 = 0; + const uint32_t x31 = 0; + const uint32_t x29 = 0; + const uint32_t x27 = 0; + const uint32_t x25 = 0; + const uint32_t x23 = 121666; + uint64_t x40 = ((uint64_t)x23 * x5); + uint64_t x41 = (((uint64_t)x23 * x7) + ((uint64_t)x25 * x5)); + uint64_t x42 = ((((uint64_t)(0x2 * x25) * x7) + ((uint64_t)x23 * x9)) + ((uint64_t)x27 * x5)); + uint64_t x43 = (((((uint64_t)x25 * x9) + ((uint64_t)x27 * x7)) + ((uint64_t)x23 * x11)) + ((uint64_t)x29 * x5)); + uint64_t x44 = (((((uint64_t)x27 * x9) + (0x2 * (((uint64_t)x25 * x11) + ((uint64_t)x29 * x7)))) + ((uint64_t)x23 * x13)) + ((uint64_t)x31 * x5)); + uint64_t x45 = (((((((uint64_t)x27 * x11) + ((uint64_t)x29 * x9)) + ((uint64_t)x25 * x13)) + ((uint64_t)x31 * x7)) + ((uint64_t)x23 * x15)) + ((uint64_t)x33 * x5)); + uint64_t x46 = (((((0x2 * ((((uint64_t)x29 * x11) + ((uint64_t)x25 * x15)) + ((uint64_t)x33 * x7))) + ((uint64_t)x27 * x13)) + ((uint64_t)x31 * x9)) + ((uint64_t)x23 * x17)) + ((uint64_t)x35 * x5)); + uint64_t x47 = (((((((((uint64_t)x29 * x13) + ((uint64_t)x31 * x11)) + ((uint64_t)x27 * x15)) + ((uint64_t)x33 * x9)) + ((uint64_t)x25 * x17)) + ((uint64_t)x35 * x7)) + ((uint64_t)x23 * x19)) + ((uint64_t)x37 * x5)); + uint64_t x48 = (((((((uint64_t)x31 * x13) + (0x2 * (((((uint64_t)x29 * x15) + ((uint64_t)x33 * x11)) + ((uint64_t)x25 * x19)) + ((uint64_t)x37 * x7)))) + ((uint64_t)x27 * x17)) + ((uint64_t)x35 * x9)) + ((uint64_t)x23 * x21)) + ((uint64_t)x39 * x5)); + uint64_t x49 = (((((((((((uint64_t)x31 * x15) + ((uint64_t)x33 * x13)) + ((uint64_t)x29 * x17)) + ((uint64_t)x35 * x11)) + ((uint64_t)x27 * x19)) + ((uint64_t)x37 * x9)) + ((uint64_t)x25 * x21)) + ((uint64_t)x39 * x7)) + ((uint64_t)x23 * x20)) + ((uint64_t)x38 * x5)); + uint64_t x50 = (((((0x2 * ((((((uint64_t)x33 * x15) + ((uint64_t)x29 * x19)) + ((uint64_t)x37 * x11)) + ((uint64_t)x25 * x20)) + ((uint64_t)x38 * x7))) + ((uint64_t)x31 * x17)) + ((uint64_t)x35 * x13)) + ((uint64_t)x27 * x21)) + ((uint64_t)x39 * x9)); + uint64_t x51 = (((((((((uint64_t)x33 * x17) + ((uint64_t)x35 * x15)) + ((uint64_t)x31 * x19)) + ((uint64_t)x37 * x13)) + ((uint64_t)x29 * x21)) + ((uint64_t)x39 * x11)) + ((uint64_t)x27 * x20)) + ((uint64_t)x38 * x9)); + uint64_t x52 = (((((uint64_t)x35 * x17) + (0x2 * (((((uint64_t)x33 * x19) + ((uint64_t)x37 * x15)) + ((uint64_t)x29 * x20)) + ((uint64_t)x38 * x11)))) + ((uint64_t)x31 * x21)) + ((uint64_t)x39 * x13)); + uint64_t x53 = (((((((uint64_t)x35 * x19) + ((uint64_t)x37 * x17)) + ((uint64_t)x33 * x21)) + ((uint64_t)x39 * x15)) + ((uint64_t)x31 * x20)) + ((uint64_t)x38 * x13)); + uint64_t x54 = (((0x2 * ((((uint64_t)x37 * x19) + ((uint64_t)x33 * x20)) + ((uint64_t)x38 * x15))) + ((uint64_t)x35 * x21)) + ((uint64_t)x39 * x17)); + uint64_t 
x55 = (((((uint64_t)x37 * x21) + ((uint64_t)x39 * x19)) + ((uint64_t)x35 * x20)) + ((uint64_t)x38 * x17)); + uint64_t x56 = (((uint64_t)x39 * x21) + (0x2 * (((uint64_t)x37 * x20) + ((uint64_t)x38 * x19)))); + uint64_t x57 = (((uint64_t)x39 * x20) + ((uint64_t)x38 * x21)); + uint64_t x58 = ((uint64_t)(0x2 * x38) * x20); + uint64_t x59 = (x48 + (x58 << 0x4)); + uint64_t x60 = (x59 + (x58 << 0x1)); + uint64_t x61 = (x60 + x58); + uint64_t x62 = (x47 + (x57 << 0x4)); + uint64_t x63 = (x62 + (x57 << 0x1)); + uint64_t x64 = (x63 + x57); + uint64_t x65 = (x46 + (x56 << 0x4)); + uint64_t x66 = (x65 + (x56 << 0x1)); + uint64_t x67 = (x66 + x56); + uint64_t x68 = (x45 + (x55 << 0x4)); + uint64_t x69 = (x68 + (x55 << 0x1)); + uint64_t x70 = (x69 + x55); + uint64_t x71 = (x44 + (x54 << 0x4)); + uint64_t x72 = (x71 + (x54 << 0x1)); + uint64_t x73 = (x72 + x54); + uint64_t x74 = (x43 + (x53 << 0x4)); + uint64_t x75 = (x74 + (x53 << 0x1)); + uint64_t x76 = (x75 + x53); + uint64_t x77 = (x42 + (x52 << 0x4)); + uint64_t x78 = (x77 + (x52 << 0x1)); + uint64_t x79 = (x78 + x52); + uint64_t x80 = (x41 + (x51 << 0x4)); + uint64_t x81 = (x80 + (x51 << 0x1)); + uint64_t x82 = (x81 + x51); + uint64_t x83 = (x40 + (x50 << 0x4)); + uint64_t x84 = (x83 + (x50 << 0x1)); + uint64_t x85 = (x84 + x50); + uint64_t x86 = (x85 >> 0x1a); + uint32_t x87 = ((uint32_t)x85 & 0x3ffffff); + uint64_t x88 = (x86 + x82); + uint64_t x89 = (x88 >> 0x19); + uint32_t x90 = ((uint32_t)x88 & 0x1ffffff); + uint64_t x91 = (x89 + x79); + uint64_t x92 = (x91 >> 0x1a); + uint32_t x93 = ((uint32_t)x91 & 0x3ffffff); + uint64_t x94 = (x92 + x76); + uint64_t x95 = (x94 >> 0x19); + uint32_t x96 = ((uint32_t)x94 & 0x1ffffff); + uint64_t x97 = (x95 + x73); + uint64_t x98 = (x97 >> 0x1a); + uint32_t x99 = ((uint32_t)x97 & 0x3ffffff); + uint64_t x100 = (x98 + x70); + uint64_t x101 = (x100 >> 0x19); + uint32_t x102 = ((uint32_t)x100 & 0x1ffffff); + uint64_t x103 = (x101 + x67); + uint64_t x104 = (x103 >> 0x1a); + uint32_t x105 = ((uint32_t)x103 & 0x3ffffff); + uint64_t x106 = (x104 + x64); + uint64_t x107 = (x106 >> 0x19); + uint32_t x108 = ((uint32_t)x106 & 0x1ffffff); + uint64_t x109 = (x107 + x61); + uint64_t x110 = (x109 >> 0x1a); + uint32_t x111 = ((uint32_t)x109 & 0x3ffffff); + uint64_t x112 = (x110 + x49); + uint64_t x113 = (x112 >> 0x19); + uint32_t x114 = ((uint32_t)x112 & 0x1ffffff); + uint64_t x115 = (x87 + (0x13 * x113)); + uint32_t x116 = (uint32_t) (x115 >> 0x1a); + uint32_t x117 = ((uint32_t)x115 & 0x3ffffff); + uint32_t x118 = (x116 + x90); + uint32_t x119 = (x118 >> 0x19); + uint32_t x120 = (x118 & 0x1ffffff); + out[0] = x117; + out[1] = x120; + out[2] = (x119 + x93); + out[3] = x96; + out[4] = x99; + out[5] = x102; + out[6] = x105; + out[7] = x108; + out[8] = x111; + out[9] = x114; +} + +static inline void fe_mul121666(fe *h, const fe_loose *f) +{ + fe_mul_121666_impl(h->v, f->v); +} + +static const uint8_t curve25519_null_point[CURVE25519_KEY_SIZE]; + +bool curve25519(uint8_t out[CURVE25519_KEY_SIZE], + const uint8_t scalar[CURVE25519_KEY_SIZE], + const uint8_t point[CURVE25519_KEY_SIZE]) +{ + fe x1, x2, z2, x3, z3; + fe_loose x2l, z2l, x3l; + unsigned swap = 0; + int pos; + uint8_t e[32]; + + memcpy(e, scalar, 32); + curve25519_clamp_secret(e); + + /* The following implementation was transcribed to Coq and proven to + * correspond to unary scalar multiplication in affine coordinates given + * that x1 != 0 is the x coordinate of some point on the curve. 
It was + * also checked in Coq that doing a ladderstep with x1 = x3 = 0 gives + * z2' = z3' = 0, and z2 = z3 = 0 gives z2' = z3' = 0. The statement was + * quantified over the underlying field, so it applies to Curve25519 + * itself and the quadratic twist of Curve25519. It was not proven in + * Coq that prime-field arithmetic correctly simulates extension-field + * arithmetic on prime-field values. The decoding of the byte array + * representation of e was not considered. + * + * Specification of Montgomery curves in affine coordinates: + * <https://github.com/mit-plv/fiat-crypto/blob/2456d821825521f7e03e65882cc3521795b0320f/src/Spec/MontgomeryCurve.v#L27> + * + * Proof that these form a group that is isomorphic to a Weierstrass + * curve: + * <https://github.com/mit-plv/fiat-crypto/blob/2456d821825521f7e03e65882cc3521795b0320f/src/Curves/Montgomery/AffineProofs.v#L35> + * + * Coq transcription and correctness proof of the loop + * (where scalarbits=255): + * <https://github.com/mit-plv/fiat-crypto/blob/2456d821825521f7e03e65882cc3521795b0320f/src/Curves/Montgomery/XZ.v#L118> + * <https://github.com/mit-plv/fiat-crypto/blob/2456d821825521f7e03e65882cc3521795b0320f/src/Curves/Montgomery/XZProofs.v#L278> + * preconditions: 0 <= e < 2^255 (not necessarily e < order), + * fe_invert(0) = 0 + */ + fe_frombytes(&x1, point); + fe_1(&x2); + fe_0(&z2); + fe_copy(&x3, &x1); + fe_1(&z3); + + for (pos = 254; pos >= 0; --pos) { + fe tmp0, tmp1; + fe_loose tmp0l, tmp1l; + /* loop invariant as of right before the test, for the case + * where x1 != 0: + * pos >= -1; if z2 = 0 then x2 is nonzero; if z3 = 0 then x3 + * is nonzero + * let r := e >> (pos+1) in the following equalities of + * projective points: + * to_xz (r*P) === if swap then (x3, z3) else (x2, z2) + * to_xz ((r+1)*P) === if swap then (x2, z2) else (x3, z3) + * x1 is the nonzero x coordinate of the nonzero + * point (r*P-(r+1)*P) + */ + unsigned b = 1 & (e[pos / 8] >> (pos & 7)); + swap ^= b; + fe_cswap(&x2, &x3, swap); + fe_cswap(&z2, &z3, swap); + swap = b; + /* Coq transcription of ladderstep formula (called from + * transcribed loop): + * <https://github.com/mit-plv/fiat-crypto/blob/2456d821825521f7e03e65882cc3521795b0320f/src/Curves/Montgomery/XZ.v#L89> + * <https://github.com/mit-plv/fiat-crypto/blob/2456d821825521f7e03e65882cc3521795b0320f/src/Curves/Montgomery/XZProofs.v#L131> + * x1 != 0 <https://github.com/mit-plv/fiat-crypto/blob/2456d821825521f7e03e65882cc3521795b0320f/src/Curves/Montgomery/XZProofs.v#L217> + * x1 = 0 <https://github.com/mit-plv/fiat-crypto/blob/2456d821825521f7e03e65882cc3521795b0320f/src/Curves/Montgomery/XZProofs.v#L147> + */ + fe_sub(&tmp0l, &x3, &z3); + fe_sub(&tmp1l, &x2, &z2); + fe_add(&x2l, &x2, &z2); + fe_add(&z2l, &x3, &z3); + fe_mul_tll(&z3, &tmp0l, &x2l); + fe_mul_tll(&z2, &z2l, &tmp1l); + fe_sq_tl(&tmp0, &tmp1l); + fe_sq_tl(&tmp1, &x2l); + fe_add(&x3l, &z3, &z2); + fe_sub(&z2l, &z3, &z2); + fe_mul_ttt(&x2, &tmp1, &tmp0); + fe_sub(&tmp1l, &tmp1, &tmp0); + fe_sq_tl(&z2, &z2l); + fe_mul121666(&z3, &tmp1l); + fe_sq_tl(&x3, &x3l); + fe_add(&tmp0l, &tmp0, &z3); + fe_mul_ttt(&z3, &x1, &z2); + fe_mul_tll(&z2, &tmp1l, &tmp0l); + } + /* here pos=-1, so r=e, so to_xz (e*P) === if swap then (x3, z3) + * else (x2, z2) + */ + fe_cswap(&x2, &x3, swap); + fe_cswap(&z2, &z3, swap); + + fe_invert(&z2, &z2); + fe_mul_ttt(&x2, &x2, &z2); + fe_tobytes(out, &x2); + + explicit_bzero(&x1, sizeof(x1)); + explicit_bzero(&x2, sizeof(x2)); + explicit_bzero(&z2, sizeof(z2)); + explicit_bzero(&x3, sizeof(x3)); + 
explicit_bzero(&z3, sizeof(z3)); + explicit_bzero(&x2l, sizeof(x2l)); + explicit_bzero(&z2l, sizeof(z2l)); + explicit_bzero(&x3l, sizeof(x3l)); + explicit_bzero(&e, sizeof(e)); + + return timingsafe_bcmp(out, curve25519_null_point, CURVE25519_KEY_SIZE) != 0; +} +#endif + +int +crypto_init(void) +{ +#ifndef COMPAT_NEED_CHACHA20POLY1305_MBUF + struct crypto_session_params csp = { + .csp_mode = CSP_MODE_AEAD, + .csp_ivlen = sizeof(uint64_t), + .csp_cipher_alg = CRYPTO_CHACHA20_POLY1305, + .csp_cipher_klen = CHACHA20POLY1305_KEY_SIZE, + .csp_flags = CSP_F_SEPARATE_AAD | CSP_F_SEPARATE_OUTPUT + }; + int ret = crypto_newsession(&chacha20_poly1305_sid, &csp, CRYPTOCAP_F_SOFTWARE); + if (ret != 0) + return (ret); +#endif + return (0); +} + +void +crypto_deinit(void) +{ +#ifndef COMPAT_NEED_CHACHA20POLY1305_MBUF + crypto_freesession(chacha20_poly1305_sid); +#endif +} diff --git a/sys/dev/wg/wg_noise.c b/sys/dev/wg/wg_noise.c new file mode 100644 index 000000000000..756b5c07c10a --- /dev/null +++ b/sys/dev/wg/wg_noise.c @@ -0,0 +1,1410 @@ +/* SPDX-License-Identifier: ISC + * + * Copyright (C) 2015-2021 Jason A. Donenfeld <Jason@zx2c4.com>. All Rights Reserved. + * Copyright (C) 2019-2021 Matt Dunwoodie <ncon@noconroy.net> + * Copyright (c) 2022 The FreeBSD Foundation + */ + +#include <sys/param.h> +#include <sys/systm.h> +#include <sys/ck.h> +#include <sys/endian.h> +#include <sys/epoch.h> +#include <sys/kernel.h> +#include <sys/lock.h> +#include <sys/malloc.h> +#include <sys/mutex.h> +#include <sys/refcount.h> +#include <sys/rwlock.h> +#include <crypto/siphash/siphash.h> + +#include "crypto.h" +#include "wg_noise.h" +#include "support.h" + +/* Protocol string constants */ +#define NOISE_HANDSHAKE_NAME "Noise_IKpsk2_25519_ChaChaPoly_BLAKE2s" +#define NOISE_IDENTIFIER_NAME "WireGuard v1 zx2c4 Jason@zx2c4.com" + +/* Constants for the counter */ +#define COUNTER_BITS_TOTAL 8192 +#ifdef __LP64__ +#define COUNTER_ORDER 6 +#define COUNTER_BITS 64 +#else +#define COUNTER_ORDER 5 +#define COUNTER_BITS 32 +#endif +#define COUNTER_REDUNDANT_BITS COUNTER_BITS +#define COUNTER_WINDOW_SIZE (COUNTER_BITS_TOTAL - COUNTER_REDUNDANT_BITS) + +/* Constants for the keypair */ +#define REKEY_AFTER_MESSAGES (1ull << 60) +#define REJECT_AFTER_MESSAGES (UINT64_MAX - COUNTER_WINDOW_SIZE - 1) +#define REKEY_AFTER_TIME 120 +#define REKEY_AFTER_TIME_RECV 165 +#define REJECT_INTERVAL (1000000000 / 50) /* fifty times per sec */ +/* 24 = floor(log2(REJECT_INTERVAL)) */ +#define REJECT_INTERVAL_MASK (~((1ull<<24)-1)) +#define TIMER_RESET (SBT_1S * -(REKEY_TIMEOUT+1)) + +#define HT_INDEX_SIZE (1 << 13) +#define HT_INDEX_MASK (HT_INDEX_SIZE - 1) +#define HT_REMOTE_SIZE (1 << 11) +#define HT_REMOTE_MASK (HT_REMOTE_SIZE - 1) +#define MAX_REMOTE_PER_LOCAL (1 << 20) + +struct noise_index { + CK_LIST_ENTRY(noise_index) i_entry; + uint32_t i_local_index; + uint32_t i_remote_index; + int i_is_keypair; +}; + +struct noise_keypair { + struct noise_index kp_index; + u_int kp_refcnt; + bool kp_can_send; + bool kp_is_initiator; + sbintime_t kp_birthdate; /* sbinuptime */ + struct noise_remote *kp_remote; + + uint8_t kp_send[NOISE_SYMMETRIC_KEY_LEN]; + uint8_t kp_recv[NOISE_SYMMETRIC_KEY_LEN]; + + /* Counter elements */ + struct rwlock kp_nonce_lock; + uint64_t kp_nonce_send; + uint64_t kp_nonce_recv; + unsigned long kp_backtrack[COUNTER_BITS_TOTAL / COUNTER_BITS]; + + struct epoch_context kp_smr; +}; + +struct noise_handshake { + uint8_t hs_e[NOISE_PUBLIC_KEY_LEN]; + uint8_t hs_hash[NOISE_HASH_LEN]; + uint8_t hs_ck[NOISE_HASH_LEN]; +}; + +enum 
noise_handshake_state { + HANDSHAKE_DEAD, + HANDSHAKE_INITIATOR, + HANDSHAKE_RESPONDER, +}; + +struct noise_remote { + struct noise_index r_index; + + CK_LIST_ENTRY(noise_remote) r_entry; + bool r_entry_inserted; + uint8_t r_public[NOISE_PUBLIC_KEY_LEN]; + + struct rwlock r_handshake_lock; + struct noise_handshake r_handshake; + enum noise_handshake_state r_handshake_state; + sbintime_t r_last_sent; /* sbinuptime */ + sbintime_t r_last_init_recv; /* sbinuptime */ + uint8_t r_timestamp[NOISE_TIMESTAMP_LEN]; + uint8_t r_psk[NOISE_SYMMETRIC_KEY_LEN]; + uint8_t r_ss[NOISE_PUBLIC_KEY_LEN]; + + u_int r_refcnt; + struct noise_local *r_local; + void *r_arg; + + struct mtx r_keypair_mtx; + struct noise_keypair *r_next, *r_current, *r_previous; + + struct epoch_context r_smr; + void (*r_cleanup)(struct noise_remote *); +}; + +struct noise_local { + struct rwlock l_identity_lock; + bool l_has_identity; + uint8_t l_public[NOISE_PUBLIC_KEY_LEN]; + uint8_t l_private[NOISE_PUBLIC_KEY_LEN]; + + u_int l_refcnt; + uint8_t l_hash_key[SIPHASH_KEY_LENGTH]; + void *l_arg; + void (*l_cleanup)(struct noise_local *); + + struct mtx l_remote_mtx; + size_t l_remote_num; + CK_LIST_HEAD(,noise_remote) l_remote_hash[HT_REMOTE_SIZE]; + + struct mtx l_index_mtx; + CK_LIST_HEAD(,noise_index) l_index_hash[HT_INDEX_SIZE]; +}; + +static void noise_precompute_ss(struct noise_local *, struct noise_remote *); + +static void noise_remote_index_insert(struct noise_local *, struct noise_remote *); +static struct noise_remote * + noise_remote_index_lookup(struct noise_local *, uint32_t, bool); +static int noise_remote_index_remove(struct noise_local *, struct noise_remote *); +static void noise_remote_expire_current(struct noise_remote *); + +static void noise_add_new_keypair(struct noise_local *, struct noise_remote *, struct noise_keypair *); +static int noise_begin_session(struct noise_remote *); +static void noise_keypair_drop(struct noise_keypair *); + +static void noise_kdf(uint8_t *, uint8_t *, uint8_t *, const uint8_t *, + size_t, size_t, size_t, size_t, + const uint8_t [NOISE_HASH_LEN]); +static int noise_mix_dh(uint8_t [NOISE_HASH_LEN], uint8_t [NOISE_SYMMETRIC_KEY_LEN], + const uint8_t [NOISE_PUBLIC_KEY_LEN], + const uint8_t [NOISE_PUBLIC_KEY_LEN]); +static int noise_mix_ss(uint8_t ck[NOISE_HASH_LEN], uint8_t [NOISE_SYMMETRIC_KEY_LEN], + const uint8_t [NOISE_PUBLIC_KEY_LEN]); +static void noise_mix_hash(uint8_t [NOISE_HASH_LEN], const uint8_t *, size_t); +static void noise_mix_psk(uint8_t [NOISE_HASH_LEN], uint8_t [NOISE_HASH_LEN], + uint8_t [NOISE_SYMMETRIC_KEY_LEN], const uint8_t [NOISE_SYMMETRIC_KEY_LEN]); +static void noise_param_init(uint8_t [NOISE_HASH_LEN], uint8_t [NOISE_HASH_LEN], + const uint8_t [NOISE_PUBLIC_KEY_LEN]); +static void noise_msg_encrypt(uint8_t *, const uint8_t *, size_t, + uint8_t [NOISE_SYMMETRIC_KEY_LEN], uint8_t [NOISE_HASH_LEN]); +static int noise_msg_decrypt(uint8_t *, const uint8_t *, size_t, + uint8_t [NOISE_SYMMETRIC_KEY_LEN], uint8_t [NOISE_HASH_LEN]); +static void noise_msg_ephemeral(uint8_t [NOISE_HASH_LEN], uint8_t [NOISE_HASH_LEN], + const uint8_t [NOISE_PUBLIC_KEY_LEN]); +static void noise_tai64n_now(uint8_t [NOISE_TIMESTAMP_LEN]); +static int noise_timer_expired(sbintime_t, uint32_t, uint32_t); +static uint64_t siphash24(const uint8_t [SIPHASH_KEY_LENGTH], const void *, size_t); + +MALLOC_DEFINE(M_NOISE, "NOISE", "wgnoise"); + +/* Local configuration */ +struct noise_local * +noise_local_alloc(void *arg) +{ + struct noise_local *l; + size_t i; + + l = malloc(sizeof(*l), M_NOISE, 
M_WAITOK | M_ZERO); + + rw_init(&l->l_identity_lock, "noise_identity"); + l->l_has_identity = false; + bzero(l->l_public, NOISE_PUBLIC_KEY_LEN); + bzero(l->l_private, NOISE_PUBLIC_KEY_LEN); + + refcount_init(&l->l_refcnt, 1); + arc4random_buf(l->l_hash_key, sizeof(l->l_hash_key)); + l->l_arg = arg; + l->l_cleanup = NULL; + + mtx_init(&l->l_remote_mtx, "noise_remote", NULL, MTX_DEF); + l->l_remote_num = 0; + for (i = 0; i < HT_REMOTE_SIZE; i++) + CK_LIST_INIT(&l->l_remote_hash[i]); + + mtx_init(&l->l_index_mtx, "noise_index", NULL, MTX_DEF); + for (i = 0; i < HT_INDEX_SIZE; i++) + CK_LIST_INIT(&l->l_index_hash[i]); + + return (l); +} + +struct noise_local * +noise_local_ref(struct noise_local *l) +{ + refcount_acquire(&l->l_refcnt); + return (l); +} + +void +noise_local_put(struct noise_local *l) +{ + if (refcount_release(&l->l_refcnt)) { + if (l->l_cleanup != NULL) + l->l_cleanup(l); + rw_destroy(&l->l_identity_lock); + mtx_destroy(&l->l_remote_mtx); + mtx_destroy(&l->l_index_mtx); + explicit_bzero(l, sizeof(*l)); + free(l, M_NOISE); + } +} + +void +noise_local_free(struct noise_local *l, void (*cleanup)(struct noise_local *)) +{ + l->l_cleanup = cleanup; + noise_local_put(l); +} + +void * +noise_local_arg(struct noise_local *l) +{ + return (l->l_arg); +} + +void +noise_local_private(struct noise_local *l, const uint8_t private[NOISE_PUBLIC_KEY_LEN]) +{ + struct epoch_tracker et; + struct noise_remote *r; + size_t i; + + rw_wlock(&l->l_identity_lock); + memcpy(l->l_private, private, NOISE_PUBLIC_KEY_LEN); + curve25519_clamp_secret(l->l_private); + l->l_has_identity = curve25519_generate_public(l->l_public, l->l_private); + + NET_EPOCH_ENTER(et); + for (i = 0; i < HT_REMOTE_SIZE; i++) { + CK_LIST_FOREACH(r, &l->l_remote_hash[i], r_entry) { + noise_precompute_ss(l, r); + noise_remote_expire_current(r); + } + } + NET_EPOCH_EXIT(et); + rw_wunlock(&l->l_identity_lock); +} + +int +noise_local_keys(struct noise_local *l, uint8_t public[NOISE_PUBLIC_KEY_LEN], + uint8_t private[NOISE_PUBLIC_KEY_LEN]) +{ + int has_identity; + rw_rlock(&l->l_identity_lock); + if ((has_identity = l->l_has_identity)) { + if (public != NULL) + memcpy(public, l->l_public, NOISE_PUBLIC_KEY_LEN); + if (private != NULL) + memcpy(private, l->l_private, NOISE_PUBLIC_KEY_LEN); + } + rw_runlock(&l->l_identity_lock); + return (has_identity ? 
0 : ENXIO); +} + +static void +noise_precompute_ss(struct noise_local *l, struct noise_remote *r) +{ + rw_wlock(&r->r_handshake_lock); + if (!l->l_has_identity || + !curve25519(r->r_ss, l->l_private, r->r_public)) + bzero(r->r_ss, NOISE_PUBLIC_KEY_LEN); + rw_wunlock(&r->r_handshake_lock); +} + +/* Remote configuration */ +struct noise_remote * +noise_remote_alloc(struct noise_local *l, void *arg, + const uint8_t public[NOISE_PUBLIC_KEY_LEN]) +{ + struct noise_remote *r; + + r = malloc(sizeof(*r), M_NOISE, M_WAITOK | M_ZERO); + memcpy(r->r_public, public, NOISE_PUBLIC_KEY_LEN); + + rw_init(&r->r_handshake_lock, "noise_handshake"); + r->r_handshake_state = HANDSHAKE_DEAD; + r->r_last_sent = TIMER_RESET; + r->r_last_init_recv = TIMER_RESET; + noise_precompute_ss(l, r); + + refcount_init(&r->r_refcnt, 1); + r->r_local = noise_local_ref(l); + r->r_arg = arg; + + mtx_init(&r->r_keypair_mtx, "noise_keypair", NULL, MTX_DEF); + + return (r); +} + +int +noise_remote_enable(struct noise_remote *r) +{ + struct noise_local *l = r->r_local; + uint64_t idx; + int ret = 0; + + /* Insert to hashtable */ + idx = siphash24(l->l_hash_key, r->r_public, NOISE_PUBLIC_KEY_LEN) & HT_REMOTE_MASK; + + mtx_lock(&l->l_remote_mtx); + if (!r->r_entry_inserted) { + if (l->l_remote_num < MAX_REMOTE_PER_LOCAL) { + r->r_entry_inserted = true; + l->l_remote_num++; + CK_LIST_INSERT_HEAD(&l->l_remote_hash[idx], r, r_entry); + } else { + ret = ENOSPC; + } + } + mtx_unlock(&l->l_remote_mtx); + + return ret; +} + +void +noise_remote_disable(struct noise_remote *r) +{ + struct noise_local *l = r->r_local; + /* remove from hashtable */ + mtx_lock(&l->l_remote_mtx); + if (r->r_entry_inserted) { + r->r_entry_inserted = false; + CK_LIST_REMOVE(r, r_entry); + l->l_remote_num--; + }; + mtx_unlock(&l->l_remote_mtx); +} + +struct noise_remote * +noise_remote_lookup(struct noise_local *l, const uint8_t public[NOISE_PUBLIC_KEY_LEN]) +{ + struct epoch_tracker et; + struct noise_remote *r, *ret = NULL; + uint64_t idx; + + idx = siphash24(l->l_hash_key, public, NOISE_PUBLIC_KEY_LEN) & HT_REMOTE_MASK; + + NET_EPOCH_ENTER(et); + CK_LIST_FOREACH(r, &l->l_remote_hash[idx], r_entry) { + if (timingsafe_bcmp(r->r_public, public, NOISE_PUBLIC_KEY_LEN) == 0) { + if (refcount_acquire_if_not_zero(&r->r_refcnt)) + ret = r; + break; + } + } + NET_EPOCH_EXIT(et); + return (ret); +} + +static void +noise_remote_index_insert(struct noise_local *l, struct noise_remote *r) +{ + struct noise_index *i, *r_i = &r->r_index; + struct epoch_tracker et; + uint32_t idx; + + noise_remote_index_remove(l, r); + + NET_EPOCH_ENTER(et); +assign_id: + r_i->i_local_index = arc4random(); + idx = r_i->i_local_index & HT_INDEX_MASK; + CK_LIST_FOREACH(i, &l->l_index_hash[idx], i_entry) { + if (i->i_local_index == r_i->i_local_index) + goto assign_id; + } + + mtx_lock(&l->l_index_mtx); + CK_LIST_FOREACH(i, &l->l_index_hash[idx], i_entry) { + if (i->i_local_index == r_i->i_local_index) { + mtx_unlock(&l->l_index_mtx); + goto assign_id; + } + } + CK_LIST_INSERT_HEAD(&l->l_index_hash[idx], r_i, i_entry); + mtx_unlock(&l->l_index_mtx); + + NET_EPOCH_EXIT(et); +} + +static struct noise_remote * +noise_remote_index_lookup(struct noise_local *l, uint32_t idx0, bool lookup_keypair) +{ + struct epoch_tracker et; + struct noise_index *i; + struct noise_keypair *kp; + struct noise_remote *r, *ret = NULL; + uint32_t idx = idx0 & HT_INDEX_MASK; + + NET_EPOCH_ENTER(et); + CK_LIST_FOREACH(i, &l->l_index_hash[idx], i_entry) { + if (i->i_local_index == idx0) { + if (!i->i_is_keypair) { + r = 
(struct noise_remote *) i; + } else if (lookup_keypair) { + kp = (struct noise_keypair *) i; + r = kp->kp_remote; + } else { + break; + } + if (refcount_acquire_if_not_zero(&r->r_refcnt)) + ret = r; + break; + } + } + NET_EPOCH_EXIT(et); + return (ret); +} + +struct noise_remote * +noise_remote_index(struct noise_local *l, uint32_t idx) +{ + return noise_remote_index_lookup(l, idx, true); +} + +static int +noise_remote_index_remove(struct noise_local *l, struct noise_remote *r) +{ + rw_assert(&r->r_handshake_lock, RA_WLOCKED); + if (r->r_handshake_state != HANDSHAKE_DEAD) { + mtx_lock(&l->l_index_mtx); + r->r_handshake_state = HANDSHAKE_DEAD; + CK_LIST_REMOVE(&r->r_index, i_entry); + mtx_unlock(&l->l_index_mtx); + return (1); + } + return (0); +} + +struct noise_remote * +noise_remote_ref(struct noise_remote *r) +{ + refcount_acquire(&r->r_refcnt); + return (r); +} + +static void +noise_remote_smr_free(struct epoch_context *smr) +{ + struct noise_remote *r; + r = __containerof(smr, struct noise_remote, r_smr); + if (r->r_cleanup != NULL) + r->r_cleanup(r); + noise_local_put(r->r_local); + rw_destroy(&r->r_handshake_lock); + mtx_destroy(&r->r_keypair_mtx); + explicit_bzero(r, sizeof(*r)); + free(r, M_NOISE); +} + +void +noise_remote_put(struct noise_remote *r) +{ + if (refcount_release(&r->r_refcnt)) + NET_EPOCH_CALL(noise_remote_smr_free, &r->r_smr); +} + +void +noise_remote_free(struct noise_remote *r, void (*cleanup)(struct noise_remote *)) +{ + r->r_cleanup = cleanup; + noise_remote_disable(r); + + /* now clear all keypairs and handshakes, then put this reference */ + noise_remote_handshake_clear(r); + noise_remote_keypairs_clear(r); + noise_remote_put(r); +} + +struct noise_local * +noise_remote_local(struct noise_remote *r) +{ + return (noise_local_ref(r->r_local)); +} + +void * +noise_remote_arg(struct noise_remote *r) +{ + return (r->r_arg); +} + +void +noise_remote_set_psk(struct noise_remote *r, + const uint8_t psk[NOISE_SYMMETRIC_KEY_LEN]) +{ + rw_wlock(&r->r_handshake_lock); + if (psk == NULL) + bzero(r->r_psk, NOISE_SYMMETRIC_KEY_LEN); + else + memcpy(r->r_psk, psk, NOISE_SYMMETRIC_KEY_LEN); + rw_wunlock(&r->r_handshake_lock); +} + +int +noise_remote_keys(struct noise_remote *r, uint8_t public[NOISE_PUBLIC_KEY_LEN], + uint8_t psk[NOISE_SYMMETRIC_KEY_LEN]) +{ + static uint8_t null_psk[NOISE_SYMMETRIC_KEY_LEN]; + int ret; + + if (public != NULL) + memcpy(public, r->r_public, NOISE_PUBLIC_KEY_LEN); + + rw_rlock(&r->r_handshake_lock); + if (psk != NULL) + memcpy(psk, r->r_psk, NOISE_SYMMETRIC_KEY_LEN); + ret = timingsafe_bcmp(r->r_psk, null_psk, NOISE_SYMMETRIC_KEY_LEN); + rw_runlock(&r->r_handshake_lock); + + return (ret ? 
0 : ENOENT); +} + +int +noise_remote_initiation_expired(struct noise_remote *r) +{ + int expired; + rw_rlock(&r->r_handshake_lock); + expired = noise_timer_expired(r->r_last_sent, REKEY_TIMEOUT, 0); + rw_runlock(&r->r_handshake_lock); + return (expired); +} + +void +noise_remote_handshake_clear(struct noise_remote *r) +{ + rw_wlock(&r->r_handshake_lock); + if (noise_remote_index_remove(r->r_local, r)) + bzero(&r->r_handshake, sizeof(r->r_handshake)); + r->r_last_sent = TIMER_RESET; + rw_wunlock(&r->r_handshake_lock); +} + +void +noise_remote_keypairs_clear(struct noise_remote *r) +{ + struct noise_keypair *kp; + + mtx_lock(&r->r_keypair_mtx); + kp = ck_pr_load_ptr(&r->r_next); + ck_pr_store_ptr(&r->r_next, NULL); + noise_keypair_drop(kp); + + kp = ck_pr_load_ptr(&r->r_current); + ck_pr_store_ptr(&r->r_current, NULL); + noise_keypair_drop(kp); + + kp = ck_pr_load_ptr(&r->r_previous); + ck_pr_store_ptr(&r->r_previous, NULL); + noise_keypair_drop(kp); + mtx_unlock(&r->r_keypair_mtx); +} + +static void +noise_remote_expire_current(struct noise_remote *r) +{ + struct epoch_tracker et; + struct noise_keypair *kp; + + noise_remote_handshake_clear(r); + + NET_EPOCH_ENTER(et); + kp = ck_pr_load_ptr(&r->r_next); + if (kp != NULL) + ck_pr_store_bool(&kp->kp_can_send, false); + kp = ck_pr_load_ptr(&r->r_current); + if (kp != NULL) + ck_pr_store_bool(&kp->kp_can_send, false); + NET_EPOCH_EXIT(et); +} + +/* Keypair functions */ +static void +noise_add_new_keypair(struct noise_local *l, struct noise_remote *r, + struct noise_keypair *kp) +{ + struct noise_keypair *next, *current, *previous; + struct noise_index *r_i = &r->r_index; + + /* Insert into the keypair table */ + mtx_lock(&r->r_keypair_mtx); + next = ck_pr_load_ptr(&r->r_next); + current = ck_pr_load_ptr(&r->r_current); + previous = ck_pr_load_ptr(&r->r_previous); + + if (kp->kp_is_initiator) { + if (next != NULL) { + ck_pr_store_ptr(&r->r_next, NULL); + ck_pr_store_ptr(&r->r_previous, next); + noise_keypair_drop(current); + } else { + ck_pr_store_ptr(&r->r_previous, current); + } + noise_keypair_drop(previous); + ck_pr_store_ptr(&r->r_current, kp); + } else { + ck_pr_store_ptr(&r->r_next, kp); + noise_keypair_drop(next); + ck_pr_store_ptr(&r->r_previous, NULL); + noise_keypair_drop(previous); + + } + mtx_unlock(&r->r_keypair_mtx); + + /* Insert into index table */ + rw_assert(&r->r_handshake_lock, RA_WLOCKED); + + kp->kp_index.i_is_keypair = true; + kp->kp_index.i_local_index = r_i->i_local_index; + kp->kp_index.i_remote_index = r_i->i_remote_index; + + mtx_lock(&l->l_index_mtx); + CK_LIST_INSERT_BEFORE(r_i, &kp->kp_index, i_entry); + r->r_handshake_state = HANDSHAKE_DEAD; + CK_LIST_REMOVE(r_i, i_entry); + mtx_unlock(&l->l_index_mtx); + + explicit_bzero(&r->r_handshake, sizeof(r->r_handshake)); +} + +static int +noise_begin_session(struct noise_remote *r) +{ + struct noise_keypair *kp; + + rw_assert(&r->r_handshake_lock, RA_WLOCKED); + + if ((kp = malloc(sizeof(*kp), M_NOISE, M_NOWAIT | M_ZERO)) == NULL) + return (ENOSPC); + + refcount_init(&kp->kp_refcnt, 1); + kp->kp_can_send = true; + kp->kp_is_initiator = r->r_handshake_state == HANDSHAKE_INITIATOR; + kp->kp_birthdate = getsbinuptime(); + kp->kp_remote = noise_remote_ref(r); + + if (kp->kp_is_initiator) + noise_kdf(kp->kp_send, kp->kp_recv, NULL, NULL, + NOISE_SYMMETRIC_KEY_LEN, NOISE_SYMMETRIC_KEY_LEN, 0, 0, + r->r_handshake.hs_ck); + else + noise_kdf(kp->kp_recv, kp->kp_send, NULL, NULL, + NOISE_SYMMETRIC_KEY_LEN, NOISE_SYMMETRIC_KEY_LEN, 0, 0, + r->r_handshake.hs_ck); + + 
rw_init(&kp->kp_nonce_lock, "noise_nonce"); + + noise_add_new_keypair(r->r_local, r, kp); + return (0); +} + +struct noise_keypair * +noise_keypair_lookup(struct noise_local *l, uint32_t idx0) +{ + struct epoch_tracker et; + struct noise_index *i; + struct noise_keypair *kp, *ret = NULL; + uint32_t idx = idx0 & HT_INDEX_MASK; + + NET_EPOCH_ENTER(et); + CK_LIST_FOREACH(i, &l->l_index_hash[idx], i_entry) { + if (i->i_local_index == idx0 && i->i_is_keypair) { + kp = (struct noise_keypair *) i; + if (refcount_acquire_if_not_zero(&kp->kp_refcnt)) + ret = kp; + break; + } + } + NET_EPOCH_EXIT(et); + return (ret); +} + +struct noise_keypair * +noise_keypair_current(struct noise_remote *r) +{ + struct epoch_tracker et; + struct noise_keypair *kp, *ret = NULL; + + NET_EPOCH_ENTER(et); + kp = ck_pr_load_ptr(&r->r_current); + if (kp != NULL && ck_pr_load_bool(&kp->kp_can_send)) { + if (noise_timer_expired(kp->kp_birthdate, REJECT_AFTER_TIME, 0)) + ck_pr_store_bool(&kp->kp_can_send, false); + else if (refcount_acquire_if_not_zero(&kp->kp_refcnt)) + ret = kp; + } + NET_EPOCH_EXIT(et); + return (ret); +} + +struct noise_keypair * +noise_keypair_ref(struct noise_keypair *kp) +{ + refcount_acquire(&kp->kp_refcnt); + return (kp); +} + +int +noise_keypair_received_with(struct noise_keypair *kp) +{ + struct noise_keypair *old; + struct noise_remote *r = kp->kp_remote; + + if (kp != ck_pr_load_ptr(&r->r_next)) + return (0); + + mtx_lock(&r->r_keypair_mtx); + if (kp != ck_pr_load_ptr(&r->r_next)) { + mtx_unlock(&r->r_keypair_mtx); + return (0); + } + + old = ck_pr_load_ptr(&r->r_previous); + ck_pr_store_ptr(&r->r_previous, ck_pr_load_ptr(&r->r_current)); + noise_keypair_drop(old); + ck_pr_store_ptr(&r->r_current, kp); + ck_pr_store_ptr(&r->r_next, NULL); + mtx_unlock(&r->r_keypair_mtx); + + return (ECONNRESET); +} + +static void +noise_keypair_smr_free(struct epoch_context *smr) +{ + struct noise_keypair *kp; + kp = __containerof(smr, struct noise_keypair, kp_smr); + noise_remote_put(kp->kp_remote); + rw_destroy(&kp->kp_nonce_lock); + explicit_bzero(kp, sizeof(*kp)); + free(kp, M_NOISE); +} + +void +noise_keypair_put(struct noise_keypair *kp) +{ + if (refcount_release(&kp->kp_refcnt)) + NET_EPOCH_CALL(noise_keypair_smr_free, &kp->kp_smr); +} + +static void +noise_keypair_drop(struct noise_keypair *kp) +{ + struct noise_remote *r; + struct noise_local *l; + + if (kp == NULL) + return; + + r = kp->kp_remote; + l = r->r_local; + + mtx_lock(&l->l_index_mtx); + CK_LIST_REMOVE(&kp->kp_index, i_entry); + mtx_unlock(&l->l_index_mtx); + + noise_keypair_put(kp); +} + +struct noise_remote * +noise_keypair_remote(struct noise_keypair *kp) +{ + return (noise_remote_ref(kp->kp_remote)); +} + +int +noise_keypair_nonce_next(struct noise_keypair *kp, uint64_t *send) +{ + if (!ck_pr_load_bool(&kp->kp_can_send)) + return (EINVAL); + +#ifdef __LP64__ + *send = ck_pr_faa_64(&kp->kp_nonce_send, 1); +#else + rw_wlock(&kp->kp_nonce_lock); + *send = kp->kp_nonce_send++; + rw_wunlock(&kp->kp_nonce_lock); +#endif + if (*send < REJECT_AFTER_MESSAGES) + return (0); + ck_pr_store_bool(&kp->kp_can_send, false); + return (EINVAL); +} + +int +noise_keypair_nonce_check(struct noise_keypair *kp, uint64_t recv) +{ + unsigned long index, index_current, top, i, bit; + int ret = EEXIST; + + rw_wlock(&kp->kp_nonce_lock); + + if (__predict_false(kp->kp_nonce_recv >= REJECT_AFTER_MESSAGES + 1 || + recv >= REJECT_AFTER_MESSAGES)) + goto error; + + ++recv; + + if (__predict_false(recv + COUNTER_WINDOW_SIZE < kp->kp_nonce_recv)) + goto error; + + index = 
recv >> COUNTER_ORDER; + + if (__predict_true(recv > kp->kp_nonce_recv)) { + index_current = kp->kp_nonce_recv >> COUNTER_ORDER; + top = MIN(index - index_current, COUNTER_BITS_TOTAL / COUNTER_BITS); + for (i = 1; i <= top; i++) + kp->kp_backtrack[ + (i + index_current) & + ((COUNTER_BITS_TOTAL / COUNTER_BITS) - 1)] = 0; +#ifdef __LP64__ + ck_pr_store_64(&kp->kp_nonce_recv, recv); +#else + kp->kp_nonce_recv = recv; +#endif + } + + index &= (COUNTER_BITS_TOTAL / COUNTER_BITS) - 1; + bit = 1ul << (recv & (COUNTER_BITS - 1)); + if (kp->kp_backtrack[index] & bit) + goto error; + + kp->kp_backtrack[index] |= bit; + ret = 0; +error: + rw_wunlock(&kp->kp_nonce_lock); + return (ret); +} + +int +noise_keep_key_fresh_send(struct noise_remote *r) +{ + struct epoch_tracker et; + struct noise_keypair *current; + int keep_key_fresh; + uint64_t nonce; + + NET_EPOCH_ENTER(et); + current = ck_pr_load_ptr(&r->r_current); + keep_key_fresh = current != NULL && ck_pr_load_bool(&current->kp_can_send); + if (!keep_key_fresh) + goto out; +#ifdef __LP64__ + nonce = ck_pr_load_64(&current->kp_nonce_send); +#else + rw_rlock(&current->kp_nonce_lock); + nonce = current->kp_nonce_send; + rw_runlock(&current->kp_nonce_lock); +#endif + keep_key_fresh = nonce > REKEY_AFTER_MESSAGES; + if (keep_key_fresh) + goto out; + keep_key_fresh = current->kp_is_initiator && noise_timer_expired(current->kp_birthdate, REKEY_AFTER_TIME, 0); + +out: + NET_EPOCH_EXIT(et); + return (keep_key_fresh ? ESTALE : 0); +} + +int +noise_keep_key_fresh_recv(struct noise_remote *r) +{ + struct epoch_tracker et; + struct noise_keypair *current; + int keep_key_fresh; + + NET_EPOCH_ENTER(et); + current = ck_pr_load_ptr(&r->r_current); + keep_key_fresh = current != NULL && ck_pr_load_bool(&current->kp_can_send) && + current->kp_is_initiator && noise_timer_expired(current->kp_birthdate, + REJECT_AFTER_TIME - KEEPALIVE_TIMEOUT - REKEY_TIMEOUT, 0); + NET_EPOCH_EXIT(et); + + return (keep_key_fresh ?
ESTALE : 0); +} + +int +noise_keypair_encrypt(struct noise_keypair *kp, uint32_t *r_idx, uint64_t nonce, struct mbuf *m) +{ + int ret; + + ret = chacha20poly1305_encrypt_mbuf(m, nonce, kp->kp_send); + if (ret) + return (ret); + + *r_idx = kp->kp_index.i_remote_index; + return (0); +} + +int +noise_keypair_decrypt(struct noise_keypair *kp, uint64_t nonce, struct mbuf *m) +{ + uint64_t cur_nonce; + int ret; + +#ifdef __LP64__ + cur_nonce = ck_pr_load_64(&kp->kp_nonce_recv); +#else + rw_rlock(&kp->kp_nonce_lock); + cur_nonce = kp->kp_nonce_recv; + rw_runlock(&kp->kp_nonce_lock); +#endif + + if (cur_nonce >= REJECT_AFTER_MESSAGES || + noise_timer_expired(kp->kp_birthdate, REJECT_AFTER_TIME, 0)) + return (EINVAL); + + ret = chacha20poly1305_decrypt_mbuf(m, nonce, kp->kp_recv); + if (ret) + return (ret); + + return (0); +} + +/* Handshake functions */ +int +noise_create_initiation(struct noise_remote *r, + uint32_t *s_idx, + uint8_t ue[NOISE_PUBLIC_KEY_LEN], + uint8_t es[NOISE_PUBLIC_KEY_LEN + NOISE_AUTHTAG_LEN], + uint8_t ets[NOISE_TIMESTAMP_LEN + NOISE_AUTHTAG_LEN]) +{ + struct noise_handshake *hs = &r->r_handshake; + struct noise_local *l = r->r_local; + uint8_t key[NOISE_SYMMETRIC_KEY_LEN]; + int ret = EINVAL; + + rw_rlock(&l->l_identity_lock); + rw_wlock(&r->r_handshake_lock); + if (!l->l_has_identity) + goto error; + if (!noise_timer_expired(r->r_last_sent, REKEY_TIMEOUT, 0)) + goto error; + noise_param_init(hs->hs_ck, hs->hs_hash, r->r_public); + + /* e */ + curve25519_generate_secret(hs->hs_e); + if (curve25519_generate_public(ue, hs->hs_e) == 0) + goto error; + noise_msg_ephemeral(hs->hs_ck, hs->hs_hash, ue); + + /* es */ + if (noise_mix_dh(hs->hs_ck, key, hs->hs_e, r->r_public) != 0) + goto error; + + /* s */ + noise_msg_encrypt(es, l->l_public, + NOISE_PUBLIC_KEY_LEN, key, hs->hs_hash); + + /* ss */ + if (noise_mix_ss(hs->hs_ck, key, r->r_ss) != 0) + goto error; + + /* {t} */ + noise_tai64n_now(ets); + noise_msg_encrypt(ets, ets, + NOISE_TIMESTAMP_LEN, key, hs->hs_hash); + + noise_remote_index_insert(l, r); + r->r_handshake_state = HANDSHAKE_INITIATOR; + r->r_last_sent = getsbinuptime(); + *s_idx = r->r_index.i_local_index; + ret = 0; +error: + rw_wunlock(&r->r_handshake_lock); + rw_runlock(&l->l_identity_lock); + explicit_bzero(key, NOISE_SYMMETRIC_KEY_LEN); + return (ret); +} + +int +noise_consume_initiation(struct noise_local *l, struct noise_remote **rp, + uint32_t s_idx, + uint8_t ue[NOISE_PUBLIC_KEY_LEN], + uint8_t es[NOISE_PUBLIC_KEY_LEN + NOISE_AUTHTAG_LEN], + uint8_t ets[NOISE_TIMESTAMP_LEN + NOISE_AUTHTAG_LEN]) +{ + struct noise_remote *r; + struct noise_handshake hs; + uint8_t key[NOISE_SYMMETRIC_KEY_LEN]; + uint8_t r_public[NOISE_PUBLIC_KEY_LEN]; + uint8_t timestamp[NOISE_TIMESTAMP_LEN]; + int ret = EINVAL; + + rw_rlock(&l->l_identity_lock); + if (!l->l_has_identity) + goto error; + noise_param_init(hs.hs_ck, hs.hs_hash, l->l_public); + + /* e */ + noise_msg_ephemeral(hs.hs_ck, hs.hs_hash, ue); + + /* es */ + if (noise_mix_dh(hs.hs_ck, key, l->l_private, ue) != 0) + goto error; + + /* s */ + if (noise_msg_decrypt(r_public, es, + NOISE_PUBLIC_KEY_LEN + NOISE_AUTHTAG_LEN, key, hs.hs_hash) != 0) + goto error; + + /* Lookup the remote we received from */ + if ((r = noise_remote_lookup(l, r_public)) == NULL) + goto error; + + /* ss */ + if (noise_mix_ss(hs.hs_ck, key, r->r_ss) != 0) + goto error_put; + + /* {t} */ + if (noise_msg_decrypt(timestamp, ets, + NOISE_TIMESTAMP_LEN + NOISE_AUTHTAG_LEN, key, hs.hs_hash) != 0) + goto error_put; + + memcpy(hs.hs_e, ue, 
NOISE_PUBLIC_KEY_LEN); + + /* We have successfully computed the same results, now we ensure that + * this is not an initiation replay, or a flood attack */ + rw_wlock(&r->r_handshake_lock); + + /* Replay */ + if (memcmp(timestamp, r->r_timestamp, NOISE_TIMESTAMP_LEN) > 0) + memcpy(r->r_timestamp, timestamp, NOISE_TIMESTAMP_LEN); + else + goto error_set; + /* Flood attack */ + if (noise_timer_expired(r->r_last_init_recv, 0, REJECT_INTERVAL)) + r->r_last_init_recv = getsbinuptime(); + else + goto error_set; + + /* Ok, we're happy to accept this initiation now */ + noise_remote_index_insert(l, r); + r->r_index.i_remote_index = s_idx; + r->r_handshake_state = HANDSHAKE_RESPONDER; + r->r_handshake = hs; + *rp = noise_remote_ref(r); + ret = 0; +error_set: + rw_wunlock(&r->r_handshake_lock); +error_put: + noise_remote_put(r); +error: + rw_runlock(&l->l_identity_lock); + explicit_bzero(key, NOISE_SYMMETRIC_KEY_LEN); + explicit_bzero(&hs, sizeof(hs)); + return (ret); +} + +int +noise_create_response(struct noise_remote *r, + uint32_t *s_idx, uint32_t *r_idx, + uint8_t ue[NOISE_PUBLIC_KEY_LEN], + uint8_t en[0 + NOISE_AUTHTAG_LEN]) +{ + struct noise_handshake *hs = &r->r_handshake; + struct noise_local *l = r->r_local; + uint8_t key[NOISE_SYMMETRIC_KEY_LEN]; + uint8_t e[NOISE_PUBLIC_KEY_LEN]; + int ret = EINVAL; + + rw_rlock(&l->l_identity_lock); + rw_wlock(&r->r_handshake_lock); + + if (r->r_handshake_state != HANDSHAKE_RESPONDER) + goto error; + + /* e */ + curve25519_generate_secret(e); + if (curve25519_generate_public(ue, e) == 0) + goto error; + noise_msg_ephemeral(hs->hs_ck, hs->hs_hash, ue); + + /* ee */ + if (noise_mix_dh(hs->hs_ck, NULL, e, hs->hs_e) != 0) + goto error; + + /* se */ + if (noise_mix_dh(hs->hs_ck, NULL, e, r->r_public) != 0) + goto error; + + /* psk */ + noise_mix_psk(hs->hs_ck, hs->hs_hash, key, r->r_psk); + + /* {} */ + noise_msg_encrypt(en, NULL, 0, key, hs->hs_hash); + + if ((ret = noise_begin_session(r)) == 0) { + r->r_last_sent = getsbinuptime(); + *s_idx = r->r_index.i_local_index; + *r_idx = r->r_index.i_remote_index; + } +error: + rw_wunlock(&r->r_handshake_lock); + rw_runlock(&l->l_identity_lock); + explicit_bzero(key, NOISE_SYMMETRIC_KEY_LEN); + explicit_bzero(e, NOISE_PUBLIC_KEY_LEN); + return (ret); +} + +int +noise_consume_response(struct noise_local *l, struct noise_remote **rp, + uint32_t s_idx, uint32_t r_idx, + uint8_t ue[NOISE_PUBLIC_KEY_LEN], + uint8_t en[0 + NOISE_AUTHTAG_LEN]) +{ + uint8_t preshared_key[NOISE_SYMMETRIC_KEY_LEN]; + uint8_t key[NOISE_SYMMETRIC_KEY_LEN]; + struct noise_handshake hs; + struct noise_remote *r = NULL; + int ret = EINVAL; + + if ((r = noise_remote_index_lookup(l, r_idx, false)) == NULL) + return (ret); + + rw_rlock(&l->l_identity_lock); + if (!l->l_has_identity) + goto error; + + rw_rlock(&r->r_handshake_lock); + if (r->r_handshake_state != HANDSHAKE_INITIATOR) { + rw_runlock(&r->r_handshake_lock); + goto error; + } + memcpy(preshared_key, r->r_psk, NOISE_SYMMETRIC_KEY_LEN); + hs = r->r_handshake; + rw_runlock(&r->r_handshake_lock); + + /* e */ + noise_msg_ephemeral(hs.hs_ck, hs.hs_hash, ue); + + /* ee */ + if (noise_mix_dh(hs.hs_ck, NULL, hs.hs_e, ue) != 0) + goto error_zero; + + /* se */ + if (noise_mix_dh(hs.hs_ck, NULL, l->l_private, ue) != 0) + goto error_zero; + + /* psk */ + noise_mix_psk(hs.hs_ck, hs.hs_hash, key, preshared_key); + + /* {} */ + if (noise_msg_decrypt(NULL, en, + 0 + NOISE_AUTHTAG_LEN, key, hs.hs_hash) != 0) + goto error_zero; + + rw_wlock(&r->r_handshake_lock); + if (r->r_handshake_state == 
HANDSHAKE_INITIATOR && + r->r_index.i_local_index == r_idx) { + r->r_handshake = hs; + r->r_index.i_remote_index = s_idx; + if ((ret = noise_begin_session(r)) == 0) + *rp = noise_remote_ref(r); + } + rw_wunlock(&r->r_handshake_lock); +error_zero: + explicit_bzero(preshared_key, NOISE_SYMMETRIC_KEY_LEN); + explicit_bzero(key, NOISE_SYMMETRIC_KEY_LEN); + explicit_bzero(&hs, sizeof(hs)); +error: + rw_runlock(&l->l_identity_lock); + noise_remote_put(r); + return (ret); +} + +static void +hmac(uint8_t *out, const uint8_t *in, const uint8_t *key, const size_t outlen, + const size_t inlen, const size_t keylen) +{ + struct blake2s_state state; + uint8_t x_key[BLAKE2S_BLOCK_SIZE] __aligned(sizeof(uint32_t)) = { 0 }; + uint8_t i_hash[BLAKE2S_HASH_SIZE] __aligned(sizeof(uint32_t)); + int i; + + if (keylen > BLAKE2S_BLOCK_SIZE) { + blake2s_init(&state, BLAKE2S_HASH_SIZE); + blake2s_update(&state, key, keylen); + blake2s_final(&state, x_key); + } else + memcpy(x_key, key, keylen); + + for (i = 0; i < BLAKE2S_BLOCK_SIZE; ++i) + x_key[i] ^= 0x36; + + blake2s_init(&state, BLAKE2S_HASH_SIZE); + blake2s_update(&state, x_key, BLAKE2S_BLOCK_SIZE); + blake2s_update(&state, in, inlen); + blake2s_final(&state, i_hash); + + for (i = 0; i < BLAKE2S_BLOCK_SIZE; ++i) + x_key[i] ^= 0x5c ^ 0x36; + + blake2s_init(&state, BLAKE2S_HASH_SIZE); + blake2s_update(&state, x_key, BLAKE2S_BLOCK_SIZE); + blake2s_update(&state, i_hash, BLAKE2S_HASH_SIZE); + blake2s_final(&state, i_hash); + + memcpy(out, i_hash, outlen); + explicit_bzero(x_key, BLAKE2S_BLOCK_SIZE); + explicit_bzero(i_hash, BLAKE2S_HASH_SIZE); +} + +/* Handshake helper functions */ +static void +noise_kdf(uint8_t *a, uint8_t *b, uint8_t *c, const uint8_t *x, + size_t a_len, size_t b_len, size_t c_len, size_t x_len, + const uint8_t ck[NOISE_HASH_LEN]) +{ + uint8_t out[BLAKE2S_HASH_SIZE + 1]; + uint8_t sec[BLAKE2S_HASH_SIZE]; + + /* Extract entropy from "x" into sec */ + hmac(sec, x, ck, BLAKE2S_HASH_SIZE, x_len, NOISE_HASH_LEN); + + if (a == NULL || a_len == 0) + goto out; + + /* Expand first key: key = sec, data = 0x1 */ + out[0] = 1; + hmac(out, out, sec, BLAKE2S_HASH_SIZE, 1, BLAKE2S_HASH_SIZE); + memcpy(a, out, a_len); + + if (b == NULL || b_len == 0) + goto out; + + /* Expand second key: key = sec, data = "a" || 0x2 */ + out[BLAKE2S_HASH_SIZE] = 2; + hmac(out, out, sec, BLAKE2S_HASH_SIZE, BLAKE2S_HASH_SIZE + 1, BLAKE2S_HASH_SIZE); + memcpy(b, out, b_len); + + if (c == NULL || c_len == 0) + goto out; + + /* Expand third key: key = sec, data = "b" || 0x3 */ + out[BLAKE2S_HASH_SIZE] = 3; + hmac(out, out, sec, BLAKE2S_HASH_SIZE, BLAKE2S_HASH_SIZE + 1, BLAKE2S_HASH_SIZE); + memcpy(c, out, c_len); + +out: + /* Clear sensitive data from stack */ + explicit_bzero(sec, BLAKE2S_HASH_SIZE); + explicit_bzero(out, BLAKE2S_HASH_SIZE + 1); +} + +static int +noise_mix_dh(uint8_t ck[NOISE_HASH_LEN], uint8_t key[NOISE_SYMMETRIC_KEY_LEN], + const uint8_t private[NOISE_PUBLIC_KEY_LEN], + const uint8_t public[NOISE_PUBLIC_KEY_LEN]) +{ + uint8_t dh[NOISE_PUBLIC_KEY_LEN]; + + if (!curve25519(dh, private, public)) + return (EINVAL); + noise_kdf(ck, key, NULL, dh, + NOISE_HASH_LEN, NOISE_SYMMETRIC_KEY_LEN, 0, NOISE_PUBLIC_KEY_LEN, ck); + explicit_bzero(dh, NOISE_PUBLIC_KEY_LEN); + return (0); +} + +static int +noise_mix_ss(uint8_t ck[NOISE_HASH_LEN], uint8_t key[NOISE_SYMMETRIC_KEY_LEN], + const uint8_t ss[NOISE_PUBLIC_KEY_LEN]) +{ + static uint8_t null_point[NOISE_PUBLIC_KEY_LEN]; + if (timingsafe_bcmp(ss, null_point, NOISE_PUBLIC_KEY_LEN) == 0) + return (ENOENT); + noise_kdf(ck, key, 
NULL, ss, + NOISE_HASH_LEN, NOISE_SYMMETRIC_KEY_LEN, 0, NOISE_PUBLIC_KEY_LEN, ck); + return (0); +} + +static void +noise_mix_hash(uint8_t hash[NOISE_HASH_LEN], const uint8_t *src, + size_t src_len) +{ + struct blake2s_state blake; + + blake2s_init(&blake, NOISE_HASH_LEN); + blake2s_update(&blake, hash, NOISE_HASH_LEN); + blake2s_update(&blake, src, src_len); + blake2s_final(&blake, hash); +} + +static void +noise_mix_psk(uint8_t ck[NOISE_HASH_LEN], uint8_t hash[NOISE_HASH_LEN], + uint8_t key[NOISE_SYMMETRIC_KEY_LEN], + const uint8_t psk[NOISE_SYMMETRIC_KEY_LEN]) +{ + uint8_t tmp[NOISE_HASH_LEN]; + + noise_kdf(ck, tmp, key, psk, + NOISE_HASH_LEN, NOISE_HASH_LEN, NOISE_SYMMETRIC_KEY_LEN, + NOISE_SYMMETRIC_KEY_LEN, ck); + noise_mix_hash(hash, tmp, NOISE_HASH_LEN); + explicit_bzero(tmp, NOISE_HASH_LEN); +} + +static void +noise_param_init(uint8_t ck[NOISE_HASH_LEN], uint8_t hash[NOISE_HASH_LEN], + const uint8_t s[NOISE_PUBLIC_KEY_LEN]) +{ + struct blake2s_state blake; + + blake2s(ck, (uint8_t *)NOISE_HANDSHAKE_NAME, NULL, + NOISE_HASH_LEN, strlen(NOISE_HANDSHAKE_NAME), 0); + blake2s_init(&blake, NOISE_HASH_LEN); + blake2s_update(&blake, ck, NOISE_HASH_LEN); + blake2s_update(&blake, (uint8_t *)NOISE_IDENTIFIER_NAME, + strlen(NOISE_IDENTIFIER_NAME)); + blake2s_final(&blake, hash); + + noise_mix_hash(hash, s, NOISE_PUBLIC_KEY_LEN); +} + +static void +noise_msg_encrypt(uint8_t *dst, const uint8_t *src, size_t src_len, + uint8_t key[NOISE_SYMMETRIC_KEY_LEN], uint8_t hash[NOISE_HASH_LEN]) +{ + /* Nonce always zero for Noise_IK */ + chacha20poly1305_encrypt(dst, src, src_len, + hash, NOISE_HASH_LEN, 0, key); + noise_mix_hash(hash, dst, src_len + NOISE_AUTHTAG_LEN); +} + +static int +noise_msg_decrypt(uint8_t *dst, const uint8_t *src, size_t src_len, + uint8_t key[NOISE_SYMMETRIC_KEY_LEN], uint8_t hash[NOISE_HASH_LEN]) +{ + /* Nonce always zero for Noise_IK */ + if (!chacha20poly1305_decrypt(dst, src, src_len, + hash, NOISE_HASH_LEN, 0, key)) + return (EINVAL); + noise_mix_hash(hash, src, src_len); + return (0); +} + +static void +noise_msg_ephemeral(uint8_t ck[NOISE_HASH_LEN], uint8_t hash[NOISE_HASH_LEN], + const uint8_t src[NOISE_PUBLIC_KEY_LEN]) +{ + noise_mix_hash(hash, src, NOISE_PUBLIC_KEY_LEN); + noise_kdf(ck, NULL, NULL, src, NOISE_HASH_LEN, 0, 0, + NOISE_PUBLIC_KEY_LEN, ck); +} + +static void +noise_tai64n_now(uint8_t output[NOISE_TIMESTAMP_LEN]) +{ + struct timespec time; + uint64_t sec; + uint32_t nsec; + + getnanotime(&time); + + /* Round down the nsec counter to limit precise timing leak. */ + time.tv_nsec &= REJECT_INTERVAL_MASK; + + /* https://cr.yp.to/libtai/tai64.html */ + sec = htobe64(0x400000000000000aULL + time.tv_sec); + nsec = htobe32(time.tv_nsec); + + /* memcpy to output buffer, assuming output could be unaligned. */ + memcpy(output, &sec, sizeof(sec)); + memcpy(output + sizeof(sec), &nsec, sizeof(nsec)); +} + +static inline int +noise_timer_expired(sbintime_t timer, uint32_t sec, uint32_t nsec) +{ + sbintime_t now = getsbinuptime(); + return (now > (timer + sec * SBT_1S + nstosbt(nsec))) ? 
ETIMEDOUT : 0; +} + +static uint64_t siphash24(const uint8_t key[SIPHASH_KEY_LENGTH], const void *src, size_t len) +{ + SIPHASH_CTX ctx; + return (SipHashX(&ctx, 2, 4, key, src, len)); +} + +#ifdef SELFTESTS +#include "selftest/counter.c" +#endif /* SELFTESTS */ diff --git a/sys/dev/wg/wg_noise.h b/sys/dev/wg/wg_noise.h new file mode 100644 index 000000000000..27e31e260cf4 --- /dev/null +++ b/sys/dev/wg/wg_noise.h @@ -0,0 +1,131 @@ +/* SPDX-License-Identifier: ISC + * + * Copyright (C) 2015-2021 Jason A. Donenfeld <Jason@zx2c4.com>. All Rights Reserved. + * Copyright (C) 2019-2021 Matt Dunwoodie <ncon@noconroy.net> + */ + +#ifndef __NOISE_H__ +#define __NOISE_H__ + +#include "crypto.h" + +#define NOISE_PUBLIC_KEY_LEN CURVE25519_KEY_SIZE +#define NOISE_SYMMETRIC_KEY_LEN CHACHA20POLY1305_KEY_SIZE +#define NOISE_TIMESTAMP_LEN (sizeof(uint64_t) + sizeof(uint32_t)) +#define NOISE_AUTHTAG_LEN CHACHA20POLY1305_AUTHTAG_SIZE +#define NOISE_HASH_LEN BLAKE2S_HASH_SIZE + +#define REJECT_AFTER_TIME 180 +#define REKEY_TIMEOUT 5 +#define KEEPALIVE_TIMEOUT 10 + +struct noise_local; +struct noise_remote; +struct noise_keypair; + +/* Local configuration */ +struct noise_local * + noise_local_alloc(void *); +struct noise_local * + noise_local_ref(struct noise_local *); +void noise_local_put(struct noise_local *); +void noise_local_free(struct noise_local *, void (*)(struct noise_local *)); +void * noise_local_arg(struct noise_local *); + +void noise_local_private(struct noise_local *, + const uint8_t[NOISE_PUBLIC_KEY_LEN]); +int noise_local_keys(struct noise_local *, + uint8_t[NOISE_PUBLIC_KEY_LEN], + uint8_t[NOISE_PUBLIC_KEY_LEN]); + +/* Remote configuration */ +struct noise_remote * + noise_remote_alloc(struct noise_local *, void *, + const uint8_t[NOISE_PUBLIC_KEY_LEN]); +int noise_remote_enable(struct noise_remote *); +void noise_remote_disable(struct noise_remote *); +struct noise_remote * + noise_remote_lookup(struct noise_local *, const uint8_t[NOISE_PUBLIC_KEY_LEN]); +struct noise_remote * + noise_remote_index(struct noise_local *, uint32_t); +struct noise_remote * + noise_remote_ref(struct noise_remote *); +void noise_remote_put(struct noise_remote *); +void noise_remote_free(struct noise_remote *, void (*)(struct noise_remote *)); +struct noise_local * + noise_remote_local(struct noise_remote *); +void * noise_remote_arg(struct noise_remote *); + +void noise_remote_set_psk(struct noise_remote *, + const uint8_t[NOISE_SYMMETRIC_KEY_LEN]); +int noise_remote_keys(struct noise_remote *, + uint8_t[NOISE_PUBLIC_KEY_LEN], + uint8_t[NOISE_SYMMETRIC_KEY_LEN]); +int noise_remote_initiation_expired(struct noise_remote *); +void noise_remote_handshake_clear(struct noise_remote *); +void noise_remote_keypairs_clear(struct noise_remote *); + +/* Keypair functions */ +struct noise_keypair * + noise_keypair_lookup(struct noise_local *, uint32_t); +struct noise_keypair * + noise_keypair_current(struct noise_remote *); +struct noise_keypair * + noise_keypair_ref(struct noise_keypair *); +int noise_keypair_received_with(struct noise_keypair *); +void noise_keypair_put(struct noise_keypair *); + +struct noise_remote * + noise_keypair_remote(struct noise_keypair *); + +int noise_keypair_nonce_next(struct noise_keypair *, uint64_t *); +int noise_keypair_nonce_check(struct noise_keypair *, uint64_t); + +int noise_keep_key_fresh_send(struct noise_remote *); +int noise_keep_key_fresh_recv(struct noise_remote *); +int noise_keypair_encrypt( + struct noise_keypair *, + uint32_t *r_idx, + uint64_t nonce, + struct mbuf *); 
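The encrypt/decrypt declarations above and below imply a nonce discipline that the data path has to respect: a send nonce is reserved with noise_keypair_nonce_next() before sealing a packet, and a received nonce is committed to the anti-replay window with noise_keypair_nonce_check() only after noise_keypair_decrypt() has verified the authentication tag. A minimal sketch of both directions (tx_sketch and rx_sketch are hypothetical helpers; kp and m are assumed to come from the driver's packet path, and error handling is reduced to early returns):

    #include <sys/param.h>
    #include <sys/mbuf.h>

    #include "wg_noise.h"

    /* Sketch: reserve a nonce, then seal the mbuf in place. */
    static int
    tx_sketch(struct noise_keypair *kp, struct mbuf *m,
        uint32_t *r_idx, uint64_t *nonce)
    {
        int ret;

        /* Fails once the keypair is too old or past its message limit. */
        if ((ret = noise_keypair_nonce_next(kp, nonce)) != 0)
            return (ret);
        return (noise_keypair_encrypt(kp, r_idx, *nonce, m));
    }

    /* Sketch: verify first, only then advance the replay window. */
    static int
    rx_sketch(struct noise_keypair *kp, uint64_t nonce, struct mbuf *m)
    {
        int ret;

        if ((ret = noise_keypair_decrypt(kp, nonce, m)) != 0)
            return (ret);
        return (noise_keypair_nonce_check(kp, nonce));
    }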
+int noise_keypair_decrypt( + struct noise_keypair *, + uint64_t nonce, + struct mbuf *); + +/* Handshake functions */ +int noise_create_initiation( + struct noise_remote *, + uint32_t *s_idx, + uint8_t ue[NOISE_PUBLIC_KEY_LEN], + uint8_t es[NOISE_PUBLIC_KEY_LEN + NOISE_AUTHTAG_LEN], + uint8_t ets[NOISE_TIMESTAMP_LEN + NOISE_AUTHTAG_LEN]); + +int noise_consume_initiation( + struct noise_local *, + struct noise_remote **, + uint32_t s_idx, + uint8_t ue[NOISE_PUBLIC_KEY_LEN], + uint8_t es[NOISE_PUBLIC_KEY_LEN + NOISE_AUTHTAG_LEN], + uint8_t ets[NOISE_TIMESTAMP_LEN + NOISE_AUTHTAG_LEN]); + +int noise_create_response( + struct noise_remote *, + uint32_t *s_idx, + uint32_t *r_idx, + uint8_t ue[NOISE_PUBLIC_KEY_LEN], + uint8_t en[0 + NOISE_AUTHTAG_LEN]); + +int noise_consume_response( + struct noise_local *, + struct noise_remote **, + uint32_t s_idx, + uint32_t r_idx, + uint8_t ue[NOISE_PUBLIC_KEY_LEN], + uint8_t en[0 + NOISE_AUTHTAG_LEN]); + +#ifdef SELFTESTS +bool noise_counter_selftest(void); +#endif /* SELFTESTS */ + +#endif /* __NOISE_H__ */ diff --git a/sys/kern/kern_jail.c b/sys/kern/kern_jail.c index 51210c11bf20..8e9cdadd94cd 100644 --- a/sys/kern/kern_jail.c +++ b/sys/kern/kern_jail.c @@ -3758,6 +3758,7 @@ prison_priv_check(struct ucred *cred, int priv) case PRIV_NET_SETIFFIB: case PRIV_NET_OVPN: case PRIV_NET_ME: + case PRIV_NET_WG: /* * 802.11-related privileges. diff --git a/sys/modules/Makefile b/sys/modules/Makefile index 6f718acab38b..00afbffb1baf 100644 --- a/sys/modules/Makefile +++ b/sys/modules/Makefile @@ -164,6 +164,7 @@ SUBDIR= \ if_tuntap \ if_vlan \ if_vxlan \ + ${_if_wg} \ iflib \ ${_igc} \ imgact_binmisc \ @@ -449,6 +450,9 @@ _toecore= toecore _if_enc= if_enc _if_gif= if_gif _if_gre= if_gre +.if ${MK_CRYPT} != "no" || defined(ALL_MODULES) +_if_wg= if_wg +.endif _ipfw_pmod= ipfw_pmod .if ${KERN_OPTS:MIPSEC_SUPPORT} && !${KERN_OPTS:MIPSEC} _ipsec= ipsec diff --git a/sys/modules/if_wg/Makefile b/sys/modules/if_wg/Makefile new file mode 100644 index 000000000000..b47a87472116 --- /dev/null +++ b/sys/modules/if_wg/Makefile @@ -0,0 +1,10 @@ +.PATH: ${SRCTOP}/sys/dev/wg + +KMOD= if_wg + +SRCS= if_wg.c wg_cookie.c wg_crypto.c wg_noise.c +SRCS+= opt_inet.h opt_inet6.h device_if.h bus_if.h + +.include <bsd.kmod.mk> + +CFLAGS+= -include ${SRCTOP}/sys/dev/wg/compat.h diff --git a/sys/net/if_types.h b/sys/net/if_types.h index 419df6aa5647..6794da878587 100644 --- a/sys/net/if_types.h +++ b/sys/net/if_types.h @@ -256,6 +256,7 @@ typedef enum { IFT_ENC = 0xf4, /* Encapsulating interface */ IFT_PFLOG = 0xf6, /* PF packet filter logging */ IFT_PFSYNC = 0xf7, /* PF packet filter synchronization */ + IFT_WIREGUARD = 0xf8, /* WireGuard tunnel */ } ifType; /* diff --git a/sys/netinet6/nd6.c b/sys/netinet6/nd6.c index 84af00eabaac..be881b6291ac 100644 --- a/sys/netinet6/nd6.c +++ b/sys/netinet6/nd6.c @@ -284,8 +284,8 @@ nd6_ifattach(struct ifnet *ifp) * default regardless of the V_ip6_auto_linklocal configuration to * give a reasonable default behavior. */ - if ((V_ip6_auto_linklocal && ifp->if_type != IFT_BRIDGE) || - (ifp->if_flags & IFF_LOOPBACK)) + if ((V_ip6_auto_linklocal && ifp->if_type != IFT_BRIDGE && + ifp->if_type != IFT_WIREGUARD) || (ifp->if_flags & IFF_LOOPBACK)) nd->flags |= ND6_IFF_AUTO_LINKLOCAL; /* * A loopback interface does not need to accept RTADV. 
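The nd6.c hunk above stops the kernel from auto-configuring an IPv6 link-local address on WireGuard interfaces: a wg tunnel has no link layer for neighbor discovery to run over, so the ND6_IFF_AUTO_LINKLOCAL default would only produce a useless fe80:: address. The check keys off the new IFT_WIREGUARD interface type, which the driver applies when it creates its ifnet. A minimal sketch of that tagging (wg_ifnet_sketch is a hypothetical helper; the real clone-create path in if_wg.c performs considerably more setup):

    #include <sys/param.h>
    #include <net/if.h>
    #include <net/if_var.h>
    #include <net/if_types.h>

    /*
     * Sketch: allocating the ifnet as IFT_WIREGUARD is what makes
     * nd6_ifattach() skip ND6_IFF_AUTO_LINKLOCAL for the interface.
     */
    static struct ifnet *
    wg_ifnet_sketch(void)
    {
        struct ifnet *ifp;

        ifp = if_alloc(IFT_WIREGUARD);
        /* ... queue, ioctl, and output setup elided ... */
        return (ifp);
    }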
diff --git a/sys/sys/priv.h b/sys/sys/priv.h index f07a252295ae..20bfc7312ce3 100644 --- a/sys/sys/priv.h +++ b/sys/sys/priv.h @@ -350,6 +350,7 @@ #define PRIV_NET_SETVLANPCP PRIV_NET_SETLANPCP /* Alias Set VLAN priority */ #define PRIV_NET_OVPN 422 /* Administer OpenVPN DCO. */ #define PRIV_NET_ME 423 /* Administer ME interface. */ +#define PRIV_NET_WG 424 /* Administer WireGuard interface. */ /* * 802.11-related privileges.
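Lastly, the priv.h and kern_jail.c hunks define PRIV_NET_WG and add it to the set of privileges prison_priv_check() can grant, so jails with their own network stack (VNET) may be allowed to administer wg interfaces. A configuration-changing request is expected to be gated on this privilege roughly as follows (wg_cfg_priv_sketch is a hypothetical helper; the actual check lives in the if_wg.c ioctl path):

    #include <sys/param.h>
    #include <sys/priv.h>
    #include <sys/ucred.h>

    /* Sketch: deny configuration writes to unprivileged credentials. */
    static int
    wg_cfg_priv_sketch(struct ucred *cred)
    {
        return (priv_check_cred(cred, PRIV_NET_WG));
    }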