author     Kyle Evans <kevans@FreeBSD.org>    2021-03-15 02:25:40 +0000
committer  Kyle Evans <kevans@FreeBSD.org>    2021-03-15 04:52:04 +0000
commit     74ae3f3e33b810248da19004c58b3581cd367843
tree       b17ce98b77a3a1a86e8255dad7861d9c160222a9
parent     3e5e9939cda3b24df37c37da5f195415a894d9fd
if_wg: import latest fixup work from the wireguard-freebsd project
This is the culmination of about a week of work from three developers to fix a number of functional and security issues. This patch consists of work done by the following folks:

- Jason A. Donenfeld <Jason@zx2c4.com>
- Matt Dunwoodie <ncon@noconroy.net>
- Kyle Evans <kevans@FreeBSD.org>

Notable changes include:

- Packets are now correctly staged for processing once the handshake has completed, resulting in less packet loss in the interim.
- Various race conditions have been resolved, particularly w.r.t. socket and packet lifetime (panics).
- Various tests have been added to assure correct functionality and tooling conformance.
- Many security issues have been addressed.
- if_wg now maintains jail-friendly semantics: sockets are created in the interface's home vnet so that it can act as the sole network connection for a jail.
- if_wg no longer fails to remove peer allowed-ips of 0.0.0.0/0.
- if_wg now exports via ioctl a format that is future proof and complete. It is additionally supported by the upstream wireguard-tools (which we plan to merge into base soon).
- if_wg now conforms to the WireGuard protocol and is more closely aligned with security auditing guidelines.

Note that the driver has been rebased away from using iflib. iflib poses a number of challenges for a cloned device trying to operate in a vnet that are non-trivial to solve, and it adds complexity to the implementation for little gain.

The crypto implementation that was previously added to the tree was a super complex integration of what previously appeared in an old out-of-tree Linux module; it has been reduced to crypto.c containing simple, boring reference implementations. This is part of a near-to-mid term goal to work with FreeBSD kernel crypto folks and take advantage of, or improve, accelerated crypto already offered elsewhere.

There is additional test suite effort underway out-of-tree taking advantage of the aforementioned jail-friendly semantics to test a number of real-world topologies, based on netns.sh.

Also note that this is still a work in progress; further work will be much smaller in nature.

MFC after:	1 month (maybe)
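For context, the jail-friendly semantics described above come down to one pattern: the tunnel's UDP sockets are created with the credentials captured when the interface was created, so they live in the interface's home vnet even after the interface is handed to a jail. Below is a condensed, illustrative sketch of that pattern only (the name wg_socket_init_sketch is made up, error handling is trimmed); the authoritative logic is wg_socket_init() in the diff that follows.

/*
 * Illustrative sketch, not part of this commit: create the tunnel's IPv4
 * UDP socket with the creating thread's credentials and hook it up to
 * wg_input() for kernel tunneling, mirroring wg_socket_init() below.
 */
static int
wg_socket_init_sketch(struct wg_softc *sc)
{
	struct ucred *cred;
	struct socket *so4;
	int rc;

	/* Creds of the thread that created the tunnel, not curthread. */
	cred = crhold(sc->sc_ucred);
	rc = socreate(AF_INET, &so4, SOCK_DGRAM, IPPROTO_UDP, cred, curthread);
	if (rc == 0)
		rc = udp_set_kernel_tunneling(so4, wg_input, NULL, sc);
	crfree(cred);
	return (rc);
}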
Diffstat (limited to 'sys/dev/if_wg/if_wg.c')
-rw-r--r--   sys/dev/if_wg/if_wg.c   3454
1 file changed, 3454 insertions, 0 deletions
diff --git a/sys/dev/if_wg/if_wg.c b/sys/dev/if_wg/if_wg.c
new file mode 100644
index 000000000000..ba2eb3221fac
--- /dev/null
+++ b/sys/dev/if_wg/if_wg.c
@@ -0,0 +1,3454 @@
+/*
+ * Copyright (C) 2015-2021 Jason A. Donenfeld <Jason@zx2c4.com>. All Rights Reserved.
+ * Copyright (C) 2019-2021 Matt Dunwoodie <ncon@noconroy.net>
+ * Copyright (c) 2019-2020 Rubicon Communications, LLC (Netgate)
+ * Copyright (c) 2021 Kyle Evans <kevans@FreeBSD.org>
+ *
+ * Permission to use, copy, modify, and distribute this software for any
+ * purpose with or without fee is hereby granted, provided that the above
+ * copyright notice and this permission notice appear in all copies.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
+ * WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
+ * MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
+ * ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
+ * WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
+ * ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
+ * OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
+ */
+
+/* TODO audit imports */
+#include "opt_inet.h"
+#include "opt_inet6.h"
+
+#include <sys/cdefs.h>
+__FBSDID("$FreeBSD$");
+
+#include <sys/param.h>
+#include <sys/types.h>
+#include <sys/systm.h>
+#include <vm/uma.h>
+
+#include <sys/mbuf.h>
+#include <sys/socket.h>
+#include <sys/kernel.h>
+
+#include <sys/sockio.h>
+#include <sys/socketvar.h>
+#include <sys/errno.h>
+#include <sys/jail.h>
+#include <sys/priv.h>
+#include <sys/proc.h>
+#include <sys/lock.h>
+#include <sys/rwlock.h>
+#include <sys/rmlock.h>
+#include <sys/protosw.h>
+#include <sys/module.h>
+#include <sys/endian.h>
+#include <sys/kdb.h>
+#include <sys/sx.h>
+#include <sys/sysctl.h>
+#include <sys/gtaskqueue.h>
+#include <sys/smp.h>
+#include <sys/nv.h>
+
+#include <net/bpf.h>
+
+#include <sys/syslog.h>
+
+#include <net/if.h>
+#include <net/if_var.h>
+#include <net/if_clone.h>
+#include <net/if_types.h>
+#include <net/ethernet.h>
+#include <net/radix.h>
+
+#include <netinet/in.h>
+#include <netinet/in_var.h>
+#include <netinet/ip.h>
+#include <netinet/ip_var.h>
+#include <netinet/ip6.h>
+#include <netinet6/ip6_var.h>
+#include <netinet6/scope6_var.h>
+#include <netinet/udp.h>
+#include <netinet/ip_icmp.h>
+#include <netinet/icmp6.h>
+#include <netinet/in_pcb.h>
+#include <netinet6/in6_pcb.h>
+#include <netinet/udp_var.h>
+
+#include <machine/in_cksum.h>
+
+#include "support.h"
+#include "wg_noise.h"
+#include "wg_cookie.h"
+#include "if_wg.h"
+
+/* It'd be nice to use IF_MAXMTU, but that means more complicated mbuf allocations,
+ * so instead just do the biggest mbuf we can easily allocate minus the usual maximum
+ * IPv6 overhead of 80 bytes. If somebody wants bigger frames, we can revisit this. */
+#define MAX_MTU (MJUM16BYTES - 80)
+
+#define DEFAULT_MTU 1420
+
+#define MAX_STAGED_PKT 128
+#define MAX_QUEUED_PKT 1024
+#define MAX_QUEUED_PKT_MASK (MAX_QUEUED_PKT - 1)
+
+#define MAX_QUEUED_HANDSHAKES 4096
+
+#define HASHTABLE_PEER_SIZE (1 << 11)
+#define HASHTABLE_INDEX_SIZE (1 << 13)
+#define MAX_PEERS_PER_IFACE (1 << 20)
+
+#define REKEY_TIMEOUT 5
+#define REKEY_TIMEOUT_JITTER 334 /* 1/3 sec, round for arc4random_uniform */
+#define KEEPALIVE_TIMEOUT 10
+#define MAX_TIMER_HANDSHAKES (90 / REKEY_TIMEOUT)
+#define NEW_HANDSHAKE_TIMEOUT (REKEY_TIMEOUT + KEEPALIVE_TIMEOUT)
+#define UNDERLOAD_TIMEOUT 1
+
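+/* Debug logging, gated on the net.wg.debug sysctl (wireguard_debug, declared below). */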
+#define DPRINTF(sc, ...) if (wireguard_debug) if_printf(sc->sc_ifp, ##__VA_ARGS__)
+
+/* First byte indicating packet type on the wire */
+#define WG_PKT_INITIATION htole32(1)
+#define WG_PKT_RESPONSE htole32(2)
+#define WG_PKT_COOKIE htole32(3)
+#define WG_PKT_DATA htole32(4)
+
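+/* Rounds a plaintext length up to a 16-byte multiple, per the WireGuard spec;
+ * e.g. 1..16 -> 16, 17 -> 32, and 0 (a keepalive) stays 0. */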
+#define WG_PKT_WITH_PADDING(n) (((n) + (16-1)) & (~(16-1)))
+#define WG_KEY_SIZE 32
+
+struct wg_pkt_initiation {
+ uint32_t t;
+ uint32_t s_idx;
+ uint8_t ue[NOISE_PUBLIC_KEY_LEN];
+ uint8_t es[NOISE_PUBLIC_KEY_LEN + NOISE_AUTHTAG_LEN];
+ uint8_t ets[NOISE_TIMESTAMP_LEN + NOISE_AUTHTAG_LEN];
+ struct cookie_macs m;
+};
+
+struct wg_pkt_response {
+ uint32_t t;
+ uint32_t s_idx;
+ uint32_t r_idx;
+ uint8_t ue[NOISE_PUBLIC_KEY_LEN];
+ uint8_t en[0 + NOISE_AUTHTAG_LEN];
+ struct cookie_macs m;
+};
+
+struct wg_pkt_cookie {
+ uint32_t t;
+ uint32_t r_idx;
+ uint8_t nonce[COOKIE_NONCE_SIZE];
+ uint8_t ec[COOKIE_ENCRYPTED_SIZE];
+};
+
+struct wg_pkt_data {
+ uint32_t t;
+ uint32_t r_idx;
+ uint8_t nonce[sizeof(uint64_t)];
+ uint8_t buf[];
+};
+
+struct wg_endpoint {
+ union {
+ struct sockaddr r_sa;
+ struct sockaddr_in r_sin;
+#ifdef INET6
+ struct sockaddr_in6 r_sin6;
+#endif
+ } e_remote;
+ union {
+ struct in_addr l_in;
+#ifdef INET6
+ struct in6_pktinfo l_pktinfo6;
+#define l_in6 l_pktinfo6.ipi6_addr
+#endif
+ } e_local;
+};
+
+struct wg_tag {
+ struct m_tag t_tag;
+ struct wg_endpoint t_endpoint;
+ struct wg_peer *t_peer;
+ struct mbuf *t_mbuf;
+ int t_done;
+ int t_mtu;
+};
+
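+/* Maps a locally chosen 32-bit session index (i_key) to its noise_remote; each
+ * peer pre-allocates three entries (p_index), kept on p_unused_index when idle. */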
+struct wg_index {
+ LIST_ENTRY(wg_index) i_entry;
+ SLIST_ENTRY(wg_index) i_unused_entry;
+ uint32_t i_key;
+ struct noise_remote *i_value;
+};
+
+struct wg_timers {
+ /* t_lock is for blocking wg_timers_event_* when setting t_disabled. */
+ struct rwlock t_lock;
+
+ int t_disabled;
+ int t_need_another_keepalive;
+ uint16_t t_persistent_keepalive_interval;
+ struct callout t_new_handshake;
+ struct callout t_send_keepalive;
+ struct callout t_retry_handshake;
+ struct callout t_zero_key_material;
+ struct callout t_persistent_keepalive;
+
+ struct mtx t_handshake_mtx;
+ struct timespec t_handshake_last_sent;
+ struct timespec t_handshake_complete;
+ volatile int t_handshake_retries;
+};
+
+struct wg_aip {
+ struct radix_node r_nodes[2];
+ CK_LIST_ENTRY(wg_aip) r_entry;
+ struct sockaddr_storage r_addr;
+ struct sockaddr_storage r_mask;
+ struct wg_peer *r_peer;
+};
+
+struct wg_queue {
+ struct mtx q_mtx;
+ struct mbufq q;
+};
+
+struct wg_peer {
+ CK_LIST_ENTRY(wg_peer) p_hash_entry;
+ CK_LIST_ENTRY(wg_peer) p_entry;
+ uint64_t p_id;
+ struct wg_softc *p_sc;
+
+ struct noise_remote p_remote;
+ struct cookie_maker p_cookie;
+ struct wg_timers p_timers;
+
+ struct rwlock p_endpoint_lock;
+ struct wg_endpoint p_endpoint;
+
+ SLIST_HEAD(,wg_index) p_unused_index;
+ struct wg_index p_index[3];
+
+ struct wg_queue p_stage_queue;
+ struct wg_queue p_encap_queue;
+ struct wg_queue p_decap_queue;
+
+ struct grouptask p_clear_secrets;
+ struct grouptask p_send_initiation;
+ struct grouptask p_send_keepalive;
+ struct grouptask p_send;
+ struct grouptask p_recv;
+
+ counter_u64_t p_tx_bytes;
+ counter_u64_t p_rx_bytes;
+
+ CK_LIST_HEAD(, wg_aip) p_aips;
+ struct mtx p_lock;
+ struct epoch_context p_ctx;
+};
+
+enum route_direction {
+ /* TODO OpenBSD doesn't use IN/OUT, instead passes the address buffer
+ * directly to route_lookup. */
+ IN,
+ OUT,
+};
+
+struct wg_aip_table {
+ size_t t_count;
+ struct radix_node_head *t_ip;
+ struct radix_node_head *t_ip6;
+};
+
+struct wg_allowedip {
+ uint16_t family;
+ union {
+ struct in_addr ip4;
+ struct in6_addr ip6;
+ };
+ uint8_t cidr;
+};
+
+struct wg_hashtable {
+ struct mtx h_mtx;
+ SIPHASH_KEY h_secret;
+ CK_LIST_HEAD(, wg_peer) h_peers_list;
+ CK_LIST_HEAD(, wg_peer) *h_peers;
+ u_long h_peers_mask;
+ size_t h_num_peers;
+};
+
+struct wg_socket {
+ struct mtx so_mtx;
+ struct socket *so_so4;
+ struct socket *so_so6;
+ uint32_t so_user_cookie;
+ in_port_t so_port;
+};
+
+struct wg_softc {
+ LIST_ENTRY(wg_softc) sc_entry;
+ struct ifnet *sc_ifp;
+ int sc_flags;
+
+ struct ucred *sc_ucred;
+ struct wg_socket sc_socket;
+ struct wg_hashtable sc_hashtable;
+ struct wg_aip_table sc_aips;
+
+ struct mbufq sc_handshake_queue;
+ struct grouptask sc_handshake;
+
+ struct noise_local sc_local;
+ struct cookie_checker sc_cookie;
+
+ struct buf_ring *sc_encap_ring;
+ struct buf_ring *sc_decap_ring;
+
+ struct grouptask *sc_encrypt;
+ struct grouptask *sc_decrypt;
+
+ struct rwlock sc_index_lock;
+ LIST_HEAD(,wg_index) *sc_index;
+ u_long sc_index_mask;
+
+ struct sx sc_lock;
+ volatile u_int sc_peer_count;
+};
+
+#define WGF_DYING 0x0001
+
+/* TODO the following defines are freebsd specific, we should see what is
+ * necessary and cleanup from there (i suspect a lot can be junked). */
+
+#ifndef ENOKEY
+#define ENOKEY ENOTCAPABLE
+#endif
+
+#if __FreeBSD_version > 1300000
+typedef void timeout_t (void *);
+#endif
+
+#define GROUPTASK_DRAIN(gtask) \
+ gtaskqueue_drain((gtask)->gt_taskqueue, &(gtask)->gt_task)
+
+#define MTAG_WIREGUARD 0xBEAD
+#define M_ENQUEUED M_PROTO1
+
+static int clone_count;
+static uma_zone_t ratelimit_zone;
+static int wireguard_debug;
+static volatile unsigned long peer_counter = 0;
+static const char wgname[] = "wg";
+static unsigned wg_osd_jail_slot;
+
+static struct sx wg_sx;
+SX_SYSINIT(wg_sx, &wg_sx, "wg_sx");
+
+static LIST_HEAD(, wg_softc) wg_list = LIST_HEAD_INITIALIZER(wg_list);
+
+SYSCTL_NODE(_net, OID_AUTO, wg, CTLFLAG_RW, 0, "WireGuard");
+SYSCTL_INT(_net_wg, OID_AUTO, debug, CTLFLAG_RWTUN, &wireguard_debug, 0,
+ "enable debug logging");
+
+TASKQGROUP_DECLARE(if_io_tqg);
+
+MALLOC_DEFINE(M_WG, "WG", "wireguard");
+VNET_DEFINE_STATIC(struct if_clone *, wg_cloner);
+
+
+#define V_wg_cloner VNET(wg_cloner)
+#define WG_CAPS IFCAP_LINKSTATE
+#define ph_family PH_loc.eight[5]
+
+struct wg_timespec64 {
+ uint64_t tv_sec;
+ uint64_t tv_nsec;
+};
+
+struct wg_peer_export {
+ struct sockaddr_storage endpoint;
+ struct timespec last_handshake;
+ uint8_t public_key[WG_KEY_SIZE];
+ uint8_t preshared_key[NOISE_SYMMETRIC_KEY_LEN];
+ size_t endpoint_sz;
+ struct wg_allowedip *aip;
+ uint64_t rx_bytes;
+ uint64_t tx_bytes;
+ int aip_count;
+ uint16_t persistent_keepalive;
+};
+
+static struct wg_tag *wg_tag_get(struct mbuf *);
+static struct wg_endpoint *wg_mbuf_endpoint_get(struct mbuf *);
+static int wg_socket_init(struct wg_softc *, in_port_t);
+static int wg_socket_bind(struct socket *, struct socket *, in_port_t *);
+static void wg_socket_set(struct wg_softc *, struct socket *, struct socket *);
+static void wg_socket_uninit(struct wg_softc *);
+static void wg_socket_set_cookie(struct wg_softc *, uint32_t);
+static int wg_send(struct wg_softc *, struct wg_endpoint *, struct mbuf *);
+static void wg_timers_event_data_sent(struct wg_timers *);
+static void wg_timers_event_data_received(struct wg_timers *);
+static void wg_timers_event_any_authenticated_packet_sent(struct wg_timers *);
+static void wg_timers_event_any_authenticated_packet_received(struct wg_timers *);
+static void wg_timers_event_any_authenticated_packet_traversal(struct wg_timers *);
+static void wg_timers_event_handshake_initiated(struct wg_timers *);
+static void wg_timers_event_handshake_responded(struct wg_timers *);
+static void wg_timers_event_handshake_complete(struct wg_timers *);
+static void wg_timers_event_session_derived(struct wg_timers *);
+static void wg_timers_event_want_initiation(struct wg_timers *);
+static void wg_timers_event_reset_handshake_last_sent(struct wg_timers *);
+static void wg_timers_run_send_initiation(struct wg_timers *, int);
+static void wg_timers_run_retry_handshake(struct wg_timers *);
+static void wg_timers_run_send_keepalive(struct wg_timers *);
+static void wg_timers_run_new_handshake(struct wg_timers *);
+static void wg_timers_run_zero_key_material(struct wg_timers *);
+static void wg_timers_run_persistent_keepalive(struct wg_timers *);
+static void wg_timers_init(struct wg_timers *);
+static void wg_timers_enable(struct wg_timers *);
+static void wg_timers_disable(struct wg_timers *);
+static void wg_timers_set_persistent_keepalive(struct wg_timers *, uint16_t);
+static void wg_timers_get_last_handshake(struct wg_timers *, struct timespec *);
+static int wg_timers_expired_handshake_last_sent(struct wg_timers *);
+static int wg_timers_check_handshake_last_sent(struct wg_timers *);
+static void wg_queue_init(struct wg_queue *, const char *);
+static void wg_queue_deinit(struct wg_queue *);
+static void wg_queue_purge(struct wg_queue *);
+static struct mbuf *wg_queue_dequeue(struct wg_queue *, struct wg_tag **);
+static int wg_queue_len(struct wg_queue *);
+static int wg_queue_in(struct wg_peer *, struct mbuf *);
+static void wg_queue_out(struct wg_peer *);
+static void wg_queue_stage(struct wg_peer *, struct mbuf *);
+static int wg_aip_init(struct wg_aip_table *);
+static void wg_aip_destroy(struct wg_aip_table *);
+static void wg_aip_populate_aip4(struct wg_aip *, const struct in_addr *, uint8_t);
+static void wg_aip_populate_aip6(struct wg_aip *, const struct in6_addr *, uint8_t);
+static int wg_aip_add(struct wg_aip_table *, struct wg_peer *, const struct wg_allowedip *);
+static int wg_peer_remove(struct radix_node *, void *);
+static void wg_peer_remove_all(struct wg_softc *);
+static int wg_aip_delete(struct wg_aip_table *, struct wg_peer *);
+static struct wg_peer *wg_aip_lookup(struct wg_aip_table *, struct mbuf *, enum route_direction);
+static void wg_hashtable_init(struct wg_hashtable *);
+static void wg_hashtable_destroy(struct wg_hashtable *);
+static void wg_hashtable_peer_insert(struct wg_hashtable *, struct wg_peer *);
+static struct wg_peer *wg_peer_lookup(struct wg_softc *, const uint8_t [32]);
+static void wg_hashtable_peer_remove(struct wg_hashtable *, struct wg_peer *);
+static int wg_cookie_validate_packet(struct cookie_checker *, struct mbuf *, int);
+static struct wg_peer *wg_peer_alloc(struct wg_softc *);
+static void wg_peer_free_deferred(epoch_context_t);
+static void wg_peer_destroy(struct wg_peer *);
+static void wg_peer_send_buf(struct wg_peer *, uint8_t *, size_t);
+static void wg_send_initiation(struct wg_peer *);
+static void wg_send_response(struct wg_peer *);
+static void wg_send_cookie(struct wg_softc *, struct cookie_macs *, uint32_t, struct mbuf *);
+static void wg_peer_set_endpoint_from_tag(struct wg_peer *, struct wg_tag *);
+static void wg_peer_clear_src(struct wg_peer *);
+static void wg_peer_get_endpoint(struct wg_peer *, struct wg_endpoint *);
+static void wg_deliver_out(struct wg_peer *);
+static void wg_deliver_in(struct wg_peer *);
+static void wg_send_buf(struct wg_softc *, struct wg_endpoint *, uint8_t *, size_t);
+static void wg_send_keepalive(struct wg_peer *);
+static void wg_handshake(struct wg_softc *, struct mbuf *);
+static void wg_encap(struct wg_softc *, struct mbuf *);
+static void wg_decap(struct wg_softc *, struct mbuf *);
+static void wg_softc_handshake_receive(struct wg_softc *);
+static void wg_softc_decrypt(struct wg_softc *);
+static void wg_softc_encrypt(struct wg_softc *);
+static struct noise_remote *wg_remote_get(struct wg_softc *, uint8_t [NOISE_PUBLIC_KEY_LEN]);
+static uint32_t wg_index_set(struct wg_softc *, struct noise_remote *);
+static struct noise_remote *wg_index_get(struct wg_softc *, uint32_t);
+static void wg_index_drop(struct wg_softc *, uint32_t);
+static int wg_update_endpoint_addrs(struct wg_endpoint *, const struct sockaddr *, struct ifnet *);
+static void wg_input(struct mbuf *, int, struct inpcb *, const struct sockaddr *, void *);
+static void wg_encrypt_dispatch(struct wg_softc *);
+static void wg_decrypt_dispatch(struct wg_softc *);
+static void crypto_taskq_setup(struct wg_softc *);
+static void crypto_taskq_destroy(struct wg_softc *);
+static int wg_clone_create(struct if_clone *, int, caddr_t);
+static void wg_qflush(struct ifnet *);
+static int wg_transmit(struct ifnet *, struct mbuf *);
+static int wg_output(struct ifnet *, struct mbuf *, const struct sockaddr *, struct route *);
+static void wg_clone_destroy(struct ifnet *);
+static int wg_peer_to_export(struct wg_peer *, struct wg_peer_export *);
+static bool wgc_privileged(struct wg_softc *);
+static int wgc_get(struct wg_softc *, struct wg_data_io *);
+static int wgc_set(struct wg_softc *, struct wg_data_io *);
+static int wg_up(struct wg_softc *);
+static void wg_down(struct wg_softc *);
+static void wg_reassign(struct ifnet *, struct vnet *, char *unused);
+static void wg_init(void *);
+static int wg_ioctl(struct ifnet *, u_long, caddr_t);
+static void vnet_wg_init(const void *);
+static void vnet_wg_uninit(const void *);
+static void wg_module_init(void);
+static void wg_module_deinit(void);
+
+/* TODO Peer */
+static struct wg_peer *
+wg_peer_alloc(struct wg_softc *sc)
+{
+ struct wg_peer *peer;
+
+ sx_assert(&sc->sc_lock, SX_XLOCKED);
+
+ peer = malloc(sizeof(*peer), M_WG, M_WAITOK|M_ZERO);
+ peer->p_sc = sc;
+ peer->p_id = peer_counter++;
+ CK_LIST_INIT(&peer->p_aips);
+
+ rw_init(&peer->p_endpoint_lock, "wg_peer_endpoint");
+ wg_queue_init(&peer->p_stage_queue, "stageq");
+ wg_queue_init(&peer->p_encap_queue, "txq");
+ wg_queue_init(&peer->p_decap_queue, "rxq");
+
+ GROUPTASK_INIT(&peer->p_send_initiation, 0, (gtask_fn_t *)wg_send_initiation, peer);
+ taskqgroup_attach(qgroup_if_io_tqg, &peer->p_send_initiation, peer, NULL, NULL, "wg initiation");
+ GROUPTASK_INIT(&peer->p_send_keepalive, 0, (gtask_fn_t *)wg_send_keepalive, peer);
+ taskqgroup_attach(qgroup_if_io_tqg, &peer->p_send_keepalive, peer, NULL, NULL, "wg keepalive");
+ GROUPTASK_INIT(&peer->p_clear_secrets, 0, (gtask_fn_t *)noise_remote_clear, &peer->p_remote);
+ taskqgroup_attach(qgroup_if_io_tqg, &peer->p_clear_secrets,
+ &peer->p_remote, NULL, NULL, "wg clear secrets");
+
+ GROUPTASK_INIT(&peer->p_send, 0, (gtask_fn_t *)wg_deliver_out, peer);
+ taskqgroup_attach(qgroup_if_io_tqg, &peer->p_send, peer, NULL, NULL, "wg send");
+ GROUPTASK_INIT(&peer->p_recv, 0, (gtask_fn_t *)wg_deliver_in, peer);
+ taskqgroup_attach(qgroup_if_io_tqg, &peer->p_recv, peer, NULL, NULL, "wg recv");
+
+ wg_timers_init(&peer->p_timers);
+
+ peer->p_tx_bytes = counter_u64_alloc(M_WAITOK);
+ peer->p_rx_bytes = counter_u64_alloc(M_WAITOK);
+
+ SLIST_INIT(&peer->p_unused_index);
+ SLIST_INSERT_HEAD(&peer->p_unused_index, &peer->p_index[0],
+ i_unused_entry);
+ SLIST_INSERT_HEAD(&peer->p_unused_index, &peer->p_index[1],
+ i_unused_entry);
+ SLIST_INSERT_HEAD(&peer->p_unused_index, &peer->p_index[2],
+ i_unused_entry);
+
+ return (peer);
+}
+
+#define WG_HASHTABLE_PEER_FOREACH(peer, i, ht) \
+ for (i = 0; i < HASHTABLE_PEER_SIZE; i++) \
+ LIST_FOREACH(peer, &(ht)->h_peers[i], p_hash_entry)
+#define WG_HASHTABLE_PEER_FOREACH_SAFE(peer, i, ht, tpeer) \
+ for (i = 0; i < HASHTABLE_PEER_SIZE; i++) \
+ CK_LIST_FOREACH_SAFE(peer, &(ht)->h_peers[i], p_hash_entry, tpeer)
+static void
+wg_hashtable_init(struct wg_hashtable *ht)
+{
+ mtx_init(&ht->h_mtx, "hash lock", NULL, MTX_DEF);
+ arc4random_buf(&ht->h_secret, sizeof(ht->h_secret));
+ ht->h_num_peers = 0;
+ ht->h_peers = hashinit(HASHTABLE_PEER_SIZE, M_DEVBUF,
+ &ht->h_peers_mask);
+}
+
+static void
+wg_hashtable_destroy(struct wg_hashtable *ht)
+{
+ MPASS(ht->h_num_peers == 0);
+ mtx_destroy(&ht->h_mtx);
+ hashdestroy(ht->h_peers, M_DEVBUF, ht->h_peers_mask);
+}
+
+static void
+wg_hashtable_peer_insert(struct wg_hashtable *ht, struct wg_peer *peer)
+{
+ uint64_t key;
+
+ key = siphash24(&ht->h_secret, peer->p_remote.r_public,
+ sizeof(peer->p_remote.r_public));
+
+ mtx_lock(&ht->h_mtx);
+ ht->h_num_peers++;
+ CK_LIST_INSERT_HEAD(&ht->h_peers[key & ht->h_peers_mask], peer, p_hash_entry);
+ CK_LIST_INSERT_HEAD(&ht->h_peers_list, peer, p_entry);
+ mtx_unlock(&ht->h_mtx);
+}
+
+static struct wg_peer *
+wg_peer_lookup(struct wg_softc *sc,
+ const uint8_t pubkey[WG_KEY_SIZE])
+{
+ struct wg_hashtable *ht = &sc->sc_hashtable;
+ uint64_t key;
+ struct wg_peer *i = NULL;
+
+ key = siphash24(&ht->h_secret, pubkey, WG_KEY_SIZE);
+
+ mtx_lock(&ht->h_mtx);
+ CK_LIST_FOREACH(i, &ht->h_peers[key & ht->h_peers_mask], p_hash_entry) {
+ if (timingsafe_bcmp(i->p_remote.r_public, pubkey,
+ WG_KEY_SIZE) == 0)
+ break;
+ }
+ mtx_unlock(&ht->h_mtx);
+
+ return i;
+}
+
+static void
+wg_hashtable_peer_remove(struct wg_hashtable *ht, struct wg_peer *peer)
+{
+ mtx_lock(&ht->h_mtx);
+ ht->h_num_peers--;
+ CK_LIST_REMOVE(peer, p_hash_entry);
+ CK_LIST_REMOVE(peer, p_entry);
+ mtx_unlock(&ht->h_mtx);
+}
+
+static void
+wg_peer_free_deferred(epoch_context_t ctx)
+{
+ struct wg_peer *peer = __containerof(ctx, struct wg_peer, p_ctx);
+ counter_u64_free(peer->p_tx_bytes);
+ counter_u64_free(peer->p_rx_bytes);
+ rw_destroy(&peer->p_timers.t_lock);
+ rw_destroy(&peer->p_endpoint_lock);
+ free(peer, M_WG);
+}
+
+static void
+wg_peer_destroy(struct wg_peer *peer)
+{
+ /* Callers should already have called:
+ * wg_hashtable_peer_remove(&sc->sc_hashtable, peer);
+ */
+ wg_aip_delete(&peer->p_sc->sc_aips, peer);
+ MPASS(CK_LIST_EMPTY(&peer->p_aips));
+
+ /* Disable all timers so the following tasks can no longer be scheduled. */
+ wg_timers_disable(&peer->p_timers);
+
+ /* Ensure the tasks have finished running */
+ GROUPTASK_DRAIN(&peer->p_clear_secrets);
+ GROUPTASK_DRAIN(&peer->p_send_initiation);
+ GROUPTASK_DRAIN(&peer->p_send_keepalive);
+ GROUPTASK_DRAIN(&peer->p_recv);
+ GROUPTASK_DRAIN(&peer->p_send);
+
+ taskqgroup_detach(qgroup_if_io_tqg, &peer->p_clear_secrets);
+ taskqgroup_detach(qgroup_if_io_tqg, &peer->p_send_initiation);
+ taskqgroup_detach(qgroup_if_io_tqg, &peer->p_send_keepalive);
+ taskqgroup_detach(qgroup_if_io_tqg, &peer->p_recv);
+ taskqgroup_detach(qgroup_if_io_tqg, &peer->p_send);
+
+ wg_queue_deinit(&peer->p_decap_queue);
+ wg_queue_deinit(&peer->p_encap_queue);
+ wg_queue_deinit(&peer->p_stage_queue);
+
+ /* Final cleanup */
+ --peer->p_sc->sc_peer_count;
+ noise_remote_clear(&peer->p_remote);
+ DPRINTF(peer->p_sc, "Peer %llu destroyed\n", (unsigned long long)peer->p_id);
+ NET_EPOCH_CALL(wg_peer_free_deferred, &peer->p_ctx);
+}
+
+static void
+wg_peer_set_endpoint_from_tag(struct wg_peer *peer, struct wg_tag *t)
+{
+ struct wg_endpoint *e = &t->t_endpoint;
+
+ MPASS(e->e_remote.r_sa.sa_family != 0);
+ if (memcmp(e, &peer->p_endpoint, sizeof(*e)) == 0)
+ return;
+
+ peer->p_endpoint = *e;
+}
+
+static void
+wg_peer_clear_src(struct wg_peer *peer)
+{
+ rw_rlock(&peer->p_endpoint_lock);
+ bzero(&peer->p_endpoint.e_local, sizeof(peer->p_endpoint.e_local));
+ rw_runlock(&peer->p_endpoint_lock);
+}
+
+static void
+wg_peer_get_endpoint(struct wg_peer *p, struct wg_endpoint *e)
+{
+ memcpy(e, &p->p_endpoint, sizeof(*e));
+}
+
+/* Allowed IP */
+static int
+wg_aip_init(struct wg_aip_table *tbl)
+{
+ int rc;
+
+ tbl->t_count = 0;
+ rc = rn_inithead((void **)&tbl->t_ip,
+ offsetof(struct sockaddr_in, sin_addr) * NBBY);
+
+ if (rc == 0)
+ return (ENOMEM);
+ RADIX_NODE_HEAD_LOCK_INIT(tbl->t_ip);
+#ifdef INET6
+ rc = rn_inithead((void **)&tbl->t_ip6,
+ offsetof(struct sockaddr_in6, sin6_addr) * NBBY);
+ if (rc == 0) {
+ free(tbl->t_ip, M_RTABLE);
+ return (ENOMEM);
+ }
+ RADIX_NODE_HEAD_LOCK_INIT(tbl->t_ip6);
+#endif
+ return (0);
+}
+
+static void
+wg_aip_destroy(struct wg_aip_table *tbl)
+{
+ RADIX_NODE_HEAD_DESTROY(tbl->t_ip);
+ free(tbl->t_ip, M_RTABLE);
+#ifdef INET6
+ RADIX_NODE_HEAD_DESTROY(tbl->t_ip6);
+ free(tbl->t_ip6, M_RTABLE);
+#endif
+}
+
+static void
+wg_aip_populate_aip4(struct wg_aip *aip, const struct in_addr *addr,
+ uint8_t mask)
+{
+ struct sockaddr_in *raddr, *rmask;
+ uint8_t *p;
+ unsigned int i;
+
+ raddr = (struct sockaddr_in *)&aip->r_addr;
+ rmask = (struct sockaddr_in *)&aip->r_mask;
+
+ raddr->sin_len = sizeof(*raddr);
+ raddr->sin_family = AF_INET;
+ raddr->sin_addr = *addr;
+
+ rmask->sin_len = sizeof(*rmask);
+ p = (uint8_t *)&rmask->sin_addr.s_addr;
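+ /* Build a contiguous netmask from the prefix length; e.g. mask = 20 gives 255.255.240.0. */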
+ for (i = 0; i < mask / NBBY; i++)
+ p[i] = 0xff;
+ if ((mask % NBBY) != 0)
+ p[i] = (0xff00 >> (mask % NBBY)) & 0xff;
+ raddr->sin_addr.s_addr &= rmask->sin_addr.s_addr;
+}
+
+static void
+wg_aip_populate_aip6(struct wg_aip *aip, const struct in6_addr *addr,
+ uint8_t mask)
+{
+ struct sockaddr_in6 *raddr, *rmask;
+
+ raddr = (struct sockaddr_in6 *)&aip->r_addr;
+ rmask = (struct sockaddr_in6 *)&aip->r_mask;
+
+ raddr->sin6_len = sizeof(*raddr);
+ raddr->sin6_family = AF_INET6;
+ raddr->sin6_addr = *addr;
+
+ rmask->sin6_len = sizeof(*rmask);
+ in6_prefixlen2mask(&rmask->sin6_addr, mask);
+ for (int i = 0; i < 4; ++i)
+ raddr->sin6_addr.__u6_addr.__u6_addr32[i] &= rmask->sin6_addr.__u6_addr.__u6_addr32[i];
+}
+
+/* wg_aip_take assumes that the caller guarantees the allowed-ip exists. */
+static void
+wg_aip_take(struct radix_node_head *root, struct wg_peer *peer,
+ struct wg_aip *route)
+{
+ struct radix_node *node;
+ struct wg_peer *ppeer;
+
+ RADIX_NODE_HEAD_LOCK_ASSERT(root);
+
+ node = root->rnh_lookup(&route->r_addr, &route->r_mask,
+ &root->rh);
+ MPASS(node != NULL);
+
+ route = (struct wg_aip *)node;
+ ppeer = route->r_peer;
+ if (ppeer != peer) {
+ route->r_peer = peer;
+
+ CK_LIST_REMOVE(route, r_entry);
+ CK_LIST_INSERT_HEAD(&peer->p_aips, route, r_entry);
+ }
+}
+
+static int
+wg_aip_add(struct wg_aip_table *tbl, struct wg_peer *peer,
+ const struct wg_allowedip *aip)
+{
+ struct radix_node *node;
+ struct radix_node_head *root;
+ struct wg_aip *route;
+ sa_family_t family;
+ bool needfree = false;
+
+ family = aip->family;
+ if (family != AF_INET && family != AF_INET6) {
+ return (EINVAL);
+ }
+
+ route = malloc(sizeof(*route), M_WG, M_WAITOK|M_ZERO);
+ switch (family) {
+ case AF_INET:
+ root = tbl->t_ip;
+
+ wg_aip_populate_aip4(route, &aip->ip4, aip->cidr);
+ break;
+ case AF_INET6:
+ root = tbl->t_ip6;
+
+ wg_aip_populate_aip6(route, &aip->ip6, aip->cidr);
+ break;
+ }
+
+ route->r_peer = peer;
+
+ RADIX_NODE_HEAD_LOCK(root);
+ node = root->rnh_addaddr(&route->r_addr, &route->r_mask, &root->rh,
+ route->r_nodes);
+ if (node == route->r_nodes) {
+ tbl->t_count++;
+ CK_LIST_INSERT_HEAD(&peer->p_aips, route, r_entry);
+ } else {
+ needfree = true;
+ wg_aip_take(root, peer, route);
+ }
+ RADIX_NODE_HEAD_UNLOCK(root);
+ if (needfree) {
+ free(route, M_WG);
+ }
+ return (0);
+}
+
+static struct wg_peer *
+wg_aip_lookup(struct wg_aip_table *tbl, struct mbuf *m,
+ enum route_direction dir)
+{
+ RADIX_NODE_HEAD_RLOCK_TRACKER;
+ struct ip *iphdr;
+ struct ip6_hdr *ip6hdr;
+ struct radix_node_head *root;
+ struct radix_node *node;
+ struct wg_peer *peer = NULL;
+ struct sockaddr_in sin;
+ struct sockaddr_in6 sin6;
+ void *addr;
+ int version;
+
+ NET_EPOCH_ASSERT();
+ iphdr = mtod(m, struct ip *);
+ version = iphdr->ip_v;
+
+ if (__predict_false(dir != IN && dir != OUT))
+ return NULL;
+
+ if (version == 4) {
+ root = tbl->t_ip;
+ memset(&sin, 0, sizeof(sin));
+ sin.sin_len = sizeof(struct sockaddr_in);
+ if (dir == IN)
+ sin.sin_addr = iphdr->ip_src;
+ else
+ sin.sin_addr = iphdr->ip_dst;
+ addr = &sin;
+ } else if (version == 6) {
+ ip6hdr = mtod(m, struct ip6_hdr *);
+ memset(&sin6, 0, sizeof(sin6));
+ sin6.sin6_len = sizeof(struct sockaddr_in6);
+
+ root = tbl->t_ip6;
+ if (dir == IN)
+ addr = &ip6hdr->ip6_src;
+ else
+ addr = &ip6hdr->ip6_dst;
+ memcpy(&sin6.sin6_addr, addr, sizeof(sin6.sin6_addr));
+ addr = &sin6;
+ } else {
+ return (NULL);
+ }
+ RADIX_NODE_HEAD_RLOCK(root);
+ if ((node = root->rnh_matchaddr(addr, &root->rh)) != NULL) {
+ peer = ((struct wg_aip *) node)->r_peer;
+ }
+ RADIX_NODE_HEAD_RUNLOCK(root);
+ return (peer);
+}
+
+struct peer_del_arg {
+ struct radix_node_head * pda_head;
+ struct wg_peer *pda_peer;
+ struct wg_aip_table *pda_tbl;
+};
+
+static int
+wg_peer_remove(struct radix_node *rn, void *arg)
+{
+ struct peer_del_arg *pda = arg;
+ struct wg_peer *peer = pda->pda_peer;
+ struct radix_node_head * rnh = pda->pda_head;
+ struct wg_aip_table *tbl = pda->pda_tbl;
+ struct wg_aip *route = (struct wg_aip *)rn;
+ struct radix_node *x;
+
+ if (route->r_peer != peer)
+ return (0);
+ x = (struct radix_node *)rnh->rnh_deladdr(&route->r_addr,
+ &route->r_mask, &rnh->rh);
+ if (x != NULL) {
+ tbl->t_count--;
+ CK_LIST_REMOVE(route, r_entry);
+ free(route, M_WG);
+ }
+ return (0);
+}
+
+static void
+wg_peer_remove_all(struct wg_softc *sc)
+{
+ struct wg_peer *peer, *tpeer;
+
+ sx_assert(&sc->sc_lock, SX_XLOCKED);
+
+ CK_LIST_FOREACH_SAFE(peer, &sc->sc_hashtable.h_peers_list,
+ p_entry, tpeer) {
+ wg_hashtable_peer_remove(&sc->sc_hashtable, peer);
+ wg_peer_destroy(peer);
+ }
+}
+
+static int
+wg_aip_delete(struct wg_aip_table *tbl, struct wg_peer *peer)
+{
+ struct peer_del_arg pda;
+
+ pda.pda_peer = peer;
+ pda.pda_tbl = tbl;
+ RADIX_NODE_HEAD_LOCK(tbl->t_ip);
+ pda.pda_head = tbl->t_ip;
+ rn_walktree(&tbl->t_ip->rh, wg_peer_remove, &pda);
+ RADIX_NODE_HEAD_UNLOCK(tbl->t_ip);
+
+ RADIX_NODE_HEAD_LOCK(tbl->t_ip6);
+ pda.pda_head = tbl->t_ip6;
+ rn_walktree(&tbl->t_ip6->rh, wg_peer_remove, &pda);
+ RADIX_NODE_HEAD_UNLOCK(tbl->t_ip6);
+ return (0);
+}
+
+static int
+wg_socket_init(struct wg_softc *sc, in_port_t port)
+{
+ struct thread *td;
+ struct ucred *cred;
+ struct socket *so4, *so6;
+ int rc;
+
+ sx_assert(&sc->sc_lock, SX_XLOCKED);
+
+ td = curthread;
+ if (sc->sc_ucred == NULL)
+ return (EBUSY);
+ cred = crhold(sc->sc_ucred);
+
+ /*
+ * For socket creation, we use the creds of the thread that created the
+ * tunnel rather than the current thread to maintain the semantics that
+ * WireGuard has on Linux with network namespaces -- that the sockets
+ * are created in their home vnet so that they can be configured and
+ * functionally attached to a foreign vnet as the jail's only interface
+ * to the network.
+ */
+ rc = socreate(AF_INET, &so4, SOCK_DGRAM, IPPROTO_UDP, cred, td);
+ if (rc)
+ goto out;
+
+ rc = udp_set_kernel_tunneling(so4, wg_input, NULL, sc);
+ /*
+ * udp_set_kernel_tunneling can only fail if there is already a tunneling function set.
+ * This should never happen with a new socket.
+ */
+ MPASS(rc == 0);
+
+ rc = socreate(AF_INET6, &so6, SOCK_DGRAM, IPPROTO_UDP, cred, td);
+ if (rc) {
+ SOCK_LOCK(so4);
+ sofree(so4);
+ goto out;
+ }
+ rc = udp_set_kernel_tunneling(so6, wg_input, NULL, sc);
+ MPASS(rc == 0);
+
+ so4->so_user_cookie = so6->so_user_cookie = sc->sc_socket.so_user_cookie;
+
+ rc = wg_socket_bind(so4, so6, &port);
+ if (rc == 0) {
+ sc->sc_socket.so_port = port;
+ wg_socket_set(sc, so4, so6);
+ }
+out:
+ crfree(cred);
+ return (rc);
+}
+
+static void
+wg_socket_set_cookie(struct wg_softc *sc, uint32_t user_cookie)
+{
+ struct wg_socket *so = &sc->sc_socket;
+
+ sx_assert(&sc->sc_lock, SX_XLOCKED);
+
+ so->so_user_cookie = user_cookie;
+ if (so->so_so4)
+ so->so_so4->so_user_cookie = user_cookie;
+ if (so->so_so6)
+ so->so_so6->so_user_cookie = user_cookie;
+}
+
+static void
+wg_socket_uninit(struct wg_softc *sc)
+{
+ wg_socket_set(sc, NULL, NULL);
+}
+
+static void
+wg_socket_set(struct wg_softc *sc, struct socket *new_so4, struct socket *new_so6)
+{
+ struct wg_socket *so = &sc->sc_socket;
+ struct socket *so4, *so6;
+
+ sx_assert(&sc->sc_lock, SX_XLOCKED);
+
+ so4 = atomic_load_ptr(&so->so_so4);
+ so6 = atomic_load_ptr(&so->so_so6);
+ atomic_store_ptr(&so->so_so4, new_so4);
+ atomic_store_ptr(&so->so_so6, new_so6);
+
+ if (!so4 && !so6)
+ return;
+ NET_EPOCH_WAIT();
+ if (so4)
+ soclose(so4);
+ if (so6)
+ soclose(so6);
+}
+
+union wg_sockaddr {
+ struct sockaddr sa;
+ struct sockaddr_in in4;
+ struct sockaddr_in6 in6;
+};
+
+static int
+wg_socket_bind(struct socket *so4, struct socket *so6, in_port_t *requested_port)
+{
+ int rc;
+ struct thread *td;
+ union wg_sockaddr laddr;
+ struct sockaddr_in *sin;
+ struct sockaddr_in6 *sin6;
+ in_port_t port = *requested_port;
+
+ td = curthread;
+ bzero(&laddr, sizeof(laddr));
+ sin = &laddr.in4;
+ sin->sin_len = sizeof(laddr.in4);
+ sin->sin_family = AF_INET;
+ sin->sin_port = htons(port);
+ sin->sin_addr = (struct in_addr) { 0 };
+
+ if ((rc = sobind(so4, &laddr.sa, td)) != 0)
+ return (rc);
+
+ if (port == 0) {
+ rc = sogetsockaddr(so4, (struct sockaddr **)&sin);
+ if (rc != 0)
+ return (rc);
+ port = ntohs(sin->sin_port);
+ free(sin, M_SONAME);
+ }
+
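+ /* Bind the IPv6 socket to the same (possibly kernel-chosen) port as the IPv4 socket. */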
+ sin6 = &laddr.in6;
+ sin6->sin6_len = sizeof(laddr.in6);
+ sin6->sin6_family = AF_INET6;
+ sin6->sin6_port = htons(port);
+ sin6->sin6_addr = (struct in6_addr) { .s6_addr = { 0 } };
+ rc = sobind(so6, &laddr.sa, td);
+ if (rc != 0)
+ return (rc);
+ *requested_port = port;
+ return (0);
+}
+
+static int
+wg_send(struct wg_softc *sc, struct wg_endpoint *e, struct mbuf *m)
+{
+ struct epoch_tracker et;
+ struct sockaddr *sa;
+ struct wg_socket *so = &sc->sc_socket;
+ struct socket *so4, *so6;
+ struct mbuf *control = NULL;
+ int ret = 0;
+ size_t len = m->m_pkthdr.len;
+
+ /* Get local control address before locking */
+ if (e->e_remote.r_sa.sa_family == AF_INET) {
+ if (e->e_local.l_in.s_addr != INADDR_ANY)
+ control = sbcreatecontrol((caddr_t)&e->e_local.l_in,
+ sizeof(struct in_addr), IP_SENDSRCADDR,
+ IPPROTO_IP);
+ } else if (e->e_remote.r_sa.sa_family == AF_INET6) {
+ if (!IN6_IS_ADDR_UNSPECIFIED(&e->e_local.l_in6))
+ control = sbcreatecontrol((caddr_t)&e->e_local.l_pktinfo6,
+ sizeof(struct in6_pktinfo), IPV6_PKTINFO,
+ IPPROTO_IPV6);
+ } else {
+ m_freem(m);
+ return (EAFNOSUPPORT);
+ }
+
+ /* Get remote address */
+ sa = &e->e_remote.r_sa;
+
+ NET_EPOCH_ENTER(et);
+ so4 = atomic_load_ptr(&so->so_so4);
+ so6 = atomic_load_ptr(&so->so_so6);
+ if (e->e_remote.r_sa.sa_family == AF_INET && so4 != NULL)
+ ret = sosend(so4, sa, NULL, m, control, 0, curthread);
+ else if (e->e_remote.r_sa.sa_family == AF_INET6 && so6 != NULL)
+ ret = sosend(so6, sa, NULL, m, control, 0, curthread);
+ else {
+ ret = ENOTCONN;
+ m_freem(control);
+ m_freem(m);
+ }
+ NET_EPOCH_EXIT(et);
+ if (ret == 0) {
+ if_inc_counter(sc->sc_ifp, IFCOUNTER_OPACKETS, 1);
+ if_inc_counter(sc->sc_ifp, IFCOUNTER_OBYTES, len);
+ }
+ return (ret);
+}
+
+static void
+wg_send_buf(struct wg_softc *sc, struct wg_endpoint *e, uint8_t *buf,
+ size_t len)
+{
+ struct mbuf *m;
+ int ret = 0;
+
+retry:
+ m = m_gethdr(M_WAITOK, MT_DATA);
+ m->m_len = 0;
+ m_copyback(m, 0, len, buf);
+
+ if (ret == 0) {
+ ret = wg_send(sc, e, m);
+ /* Retry if we couldn't bind to e->e_local */
+ if (ret == EADDRNOTAVAIL) {
+ bzero(&e->e_local, sizeof(e->e_local));
+ goto retry;
+ }
+ } else {
+ ret = wg_send(sc, e, m);
+ }
+ if (ret)
+ DPRINTF(sc, "Unable to send packet: %d\n", ret);
+}
+
+/* TODO Tag */
+static struct wg_tag *
+wg_tag_get(struct mbuf *m)
+{
+ struct m_tag *tag;
+
+ tag = m_tag_find(m, MTAG_WIREGUARD, NULL);
+ if (tag == NULL) {
+ tag = m_tag_get(MTAG_WIREGUARD, sizeof(struct wg_tag), M_NOWAIT|M_ZERO);
+ m_tag_prepend(m, tag);
+ MPASS(!SLIST_EMPTY(&m->m_pkthdr.tags));
+ MPASS(m_tag_locate(m, MTAG_ABI_COMPAT, MTAG_WIREGUARD, NULL) == tag);
+ }
+ return (struct wg_tag *)tag;
+}
+
+static struct wg_endpoint *
+wg_mbuf_endpoint_get(struct mbuf *m)
+{
+ struct wg_tag *hdr;
+
+ if ((hdr = wg_tag_get(m)) == NULL)
+ return (NULL);
+
+ return (&hdr->t_endpoint);
+}
+
+/* Timers */
+static void
+wg_timers_init(struct wg_timers *t)
+{
+ bzero(t, sizeof(*t));
+
+ t->t_disabled = 1;
+ rw_init(&t->t_lock, "wg peer timers");
+ callout_init(&t->t_retry_handshake, true);
+ callout_init(&t->t_send_keepalive, true);
+ callout_init(&t->t_new_handshake, true);
+ callout_init(&t->t_zero_key_material, true);
+ callout_init(&t->t_persistent_keepalive, true);
+}
+
+static void
+wg_timers_enable(struct wg_timers *t)
+{
+ rw_wlock(&t->t_lock);
+ t->t_disabled = 0;
+ rw_wunlock(&t->t_lock);
+ wg_timers_run_persistent_keepalive(t);
+}
+
+static void
+wg_timers_disable(struct wg_timers *t)
+{
+ rw_wlock(&t->t_lock);
+ t->t_disabled = 1;
+ t->t_need_another_keepalive = 0;
+ rw_wunlock(&t->t_lock);
+
+ callout_stop(&t->t_retry_handshake);
+ callout_stop(&t->t_send_keepalive);
+ callout_stop(&t->t_new_handshake);
+ callout_stop(&t->t_zero_key_material);
+ callout_stop(&t->t_persistent_keepalive);
+}
+
+static void
+wg_timers_set_persistent_keepalive(struct wg_timers *t, uint16_t interval)
+{
+ rw_rlock(&t->t_lock);
+ if (!t->t_disabled) {
+ t->t_persistent_keepalive_interval = interval;
+ wg_timers_run_persistent_keepalive(t);
+ }
+ rw_runlock(&t->t_lock);
+}
+
+static void
+wg_timers_get_last_handshake(struct wg_timers *t, struct timespec *time)
+{
+ rw_rlock(&t->t_lock);
+ time->tv_sec = t->t_handshake_complete.tv_sec;
+ time->tv_nsec = t->t_handshake_complete.tv_nsec;
+ rw_runlock(&t->t_lock);
+}
+
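+/* Returns ETIMEDOUT once more than REKEY_TIMEOUT seconds have passed since the
+ * last handshake was sent (so another initiation may be sent), otherwise 0. */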
+static int
+wg_timers_expired_handshake_last_sent(struct wg_timers *t)
+{
+ struct timespec uptime;
+ struct timespec expire = { .tv_sec = REKEY_TIMEOUT, .tv_nsec = 0 };
+
+ getnanouptime(&uptime);
+ timespecadd(&t->t_handshake_last_sent, &expire, &expire);
+ return timespeccmp(&uptime, &expire, >) ? ETIMEDOUT : 0;
+}
+
+static int
+wg_timers_check_handshake_last_sent(struct wg_timers *t)
+{
+ int ret;
+
+ rw_wlock(&t->t_lock);
+ if ((ret = wg_timers_expired_handshake_last_sent(t)) == ETIMEDOUT)
+ getnanouptime(&t->t_handshake_last_sent);
+ rw_wunlock(&t->t_lock);
+ return (ret);
+}
+
+/* Should be called after an authenticated data packet is sent. */
+static void
+wg_timers_event_data_sent(struct wg_timers *t)
+{
+ rw_rlock(&t->t_lock);
+ if (!t->t_disabled && !callout_pending(&t->t_new_handshake))
+ callout_reset(&t->t_new_handshake, MSEC_2_TICKS(
+ NEW_HANDSHAKE_TIMEOUT * 1000 +
+ arc4random_uniform(REKEY_TIMEOUT_JITTER)),
+ (timeout_t *)wg_timers_run_new_handshake, t);
+ rw_runlock(&t->t_lock);
+}
+
+/* Should be called after an authenticated data packet is received. */
+static void
+wg_timers_event_data_received(struct wg_timers *t)
+{
+ rw_rlock(&t->t_lock);
+ if (!t->t_disabled) {
+ if (!callout_pending(&t->t_send_keepalive)) {
+ callout_reset(&t->t_send_keepalive,
+ MSEC_2_TICKS(KEEPALIVE_TIMEOUT * 1000),
+ (timeout_t *)wg_timers_run_send_keepalive, t);
+ } else {
+ t->t_need_another_keepalive = 1;
+ }
+ }
+ rw_runlock(&t->t_lock);
+}
+
+/*
+ * Should be called after any type of authenticated packet is sent, whether
+ * keepalive, data, or handshake.
+ */
+static void
+wg_timers_event_any_authenticated_packet_sent(struct wg_timers *t)
+{
+ callout_stop(&t->t_send_keepalive);
+}
+
+/*
+ * Should be called after any type of authenticated packet is received, whether
+ * keepalive, data, or handshake.
+ */
+static void
+wg_timers_event_any_authenticated_packet_received(struct wg_timers *t)
+{
+ callout_stop(&t->t_new_handshake);
+}
+
+/*
+ * Should be called before a packet with authentication, whether
+ * keepalive, data, or handshake is sent, or after one is received.
+ */
+static void
+wg_timers_event_any_authenticated_packet_traversal(struct wg_timers *t)
+{
+ rw_rlock(&t->t_lock);
+ if (!t->t_disabled && t->t_persistent_keepalive_interval > 0)
+ callout_reset(&t->t_persistent_keepalive,
+ MSEC_2_TICKS(t->t_persistent_keepalive_interval * 1000),
+ (timeout_t *)wg_timers_run_persistent_keepalive, t);
+ rw_runlock(&t->t_lock);
+}
+
+/* Should be called after a handshake initiation message is sent. */
+static void
+wg_timers_event_handshake_initiated(struct wg_timers *t)
+{
+ rw_rlock(&t->t_lock);
+ if (!t->t_disabled)
+ callout_reset(&t->t_retry_handshake, MSEC_2_TICKS(
+ REKEY_TIMEOUT * 1000 +
+ arc4random_uniform(REKEY_TIMEOUT_JITTER)),
+ (timeout_t *)wg_timers_run_retry_handshake, t);
+ rw_runlock(&t->t_lock);
+}
+
+static void
+wg_timers_event_handshake_responded(struct wg_timers *t)
+{
+ rw_wlock(&t->t_lock);
+ getnanouptime(&t->t_handshake_last_sent);
+ rw_wunlock(&t->t_lock);
+}
+
+/*
+ * Should be called after a handshake response message is received and processed
+ * or when getting key confirmation via the first data message.
+ */
+static void
+wg_timers_event_handshake_complete(struct wg_timers *t)
+{
+ rw_wlock(&t->t_lock);
+ if (!t->t_disabled) {
+ callout_stop(&t->t_retry_handshake);
+ t->t_handshake_retries = 0;
+ getnanotime(&t->t_handshake_complete);
+ wg_timers_run_send_keepalive(t);
+ }
+ rw_wunlock(&t->t_lock);
+}
+
+/*
+ * Should be called after an ephemeral key is created, which is before sending a
+ * handshake response or after receiving a handshake response.
+ */
+static void
+wg_timers_event_session_derived(struct wg_timers *t)
+{
+ rw_rlock(&t->t_lock);
+ if (!t->t_disabled) {
+ callout_reset(&t->t_zero_key_material,
+ MSEC_2_TICKS(REJECT_AFTER_TIME * 3 * 1000),
+ (timeout_t *)wg_timers_run_zero_key_material, t);
+ }
+ rw_runlock(&t->t_lock);
+}
+
+static void
+wg_timers_event_want_initiation(struct wg_timers *t)
+{
+ rw_rlock(&t->t_lock);
+ if (!t->t_disabled)
+ wg_timers_run_send_initiation(t, 0);
+ rw_runlock(&t->t_lock);
+}
+
+static void
+wg_timers_event_reset_handshake_last_sent(struct wg_timers *t)
+{
+ rw_wlock(&t->t_lock);
+ t->t_handshake_last_sent.tv_sec -= (REKEY_TIMEOUT + 1);
+ rw_wunlock(&t->t_lock);
+}
+
+static void
+wg_timers_run_send_initiation(struct wg_timers *t, int is_retry)
+{
+ struct wg_peer *peer = __containerof(t, struct wg_peer, p_timers);
+ if (!is_retry)
+ t->t_handshake_retries = 0;
+ if (wg_timers_expired_handshake_last_sent(t) == ETIMEDOUT)
+ GROUPTASK_ENQUEUE(&peer->p_send_initiation);
+}
+
+static void
+wg_timers_run_retry_handshake(struct wg_timers *t)
+{
+ struct wg_peer *peer = __containerof(t, struct wg_peer, p_timers);
+
+ rw_wlock(&t->t_lock);
+ if (t->t_handshake_retries <= MAX_TIMER_HANDSHAKES) {
+ t->t_handshake_retries++;
+ rw_wunlock(&t->t_lock);
+
+ DPRINTF(peer->p_sc, "Handshake for peer %llu did not complete "
+ "after %d seconds, retrying (try %d)\n",
+ (unsigned long long)peer->p_id,
+ REKEY_TIMEOUT, t->t_handshake_retries + 1);
+ wg_peer_clear_src(peer);
+ wg_timers_run_send_initiation(t, 1);
+ } else {
+ rw_wunlock(&t->t_lock);
+
+ DPRINTF(peer->p_sc, "Handshake for peer %llu did not complete "
+ "after %d retries, giving up\n",
+ (unsigned long long) peer->p_id, MAX_TIMER_HANDSHAKES + 2);
+
+ callout_stop(&t->t_send_keepalive);
+ wg_queue_purge(&peer->p_stage_queue);
+ if (!callout_pending(&t->t_zero_key_material))
+ callout_reset(&t->t_zero_key_material,
+ MSEC_2_TICKS(REJECT_AFTER_TIME * 3 * 1000),
+ (timeout_t *)wg_timers_run_zero_key_material, t);
+ }
+}
+
+static void
+wg_timers_run_send_keepalive(struct wg_timers *t)
+{
+ struct wg_peer *peer = __containerof(t, struct wg_peer, p_timers);
+
+ GROUPTASK_ENQUEUE(&peer->p_send_keepalive);
+ if (t->t_need_another_keepalive) {
+ t->t_need_another_keepalive = 0;
+ callout_reset(&t->t_send_keepalive,
+ MSEC_2_TICKS(KEEPALIVE_TIMEOUT * 1000),
+ (timeout_t *)wg_timers_run_send_keepalive, t);
+ }
+}
+
+static void
+wg_timers_run_new_handshake(struct wg_timers *t)
+{
+ struct wg_peer *peer = __containerof(t, struct wg_peer, p_timers);
+
+ DPRINTF(peer->p_sc, "Retrying handshake with peer %llu because we "
+ "stopped hearing back after %d seconds\n",
+ (unsigned long long)peer->p_id, NEW_HANDSHAKE_TIMEOUT);
+ wg_peer_clear_src(peer);
+
+ wg_timers_run_send_initiation(t, 0);
+}
+
+static void
+wg_timers_run_zero_key_material(struct wg_timers *t)
+{
+ struct wg_peer *peer = __containerof(t, struct wg_peer, p_timers);
+
+ DPRINTF(peer->p_sc, "Zeroing out all keys for peer %llu, since we "
+ "haven't received a new one in %d seconds\n",
+ (unsigned long long)peer->p_id, REJECT_AFTER_TIME * 3);
+ GROUPTASK_ENQUEUE(&peer->p_clear_secrets);
+}
+
+static void
+wg_timers_run_persistent_keepalive(struct wg_timers *t)
+{
+ struct wg_peer *peer = __containerof(t, struct wg_peer, p_timers);
+
+ if (t->t_persistent_keepalive_interval != 0)
+ GROUPTASK_ENQUEUE(&peer->p_send_keepalive);
+}
+
+/* TODO Handshake */
+static void
+wg_peer_send_buf(struct wg_peer *peer, uint8_t *buf, size_t len)
+{
+ struct wg_endpoint endpoint;
+
+ counter_u64_add(peer->p_tx_bytes, len);
+ wg_timers_event_any_authenticated_packet_traversal(&peer->p_timers);
+ wg_timers_event_any_authenticated_packet_sent(&peer->p_timers);
+ wg_peer_get_endpoint(peer, &endpoint);
+ wg_send_buf(peer->p_sc, &endpoint, buf, len);
+}
+
+static void
+wg_send_initiation(struct wg_peer *peer)
+{
+ struct wg_pkt_initiation pkt;
+ struct epoch_tracker et;
+
+ if (wg_timers_check_handshake_last_sent(&peer->p_timers) != ETIMEDOUT)
+ return;
+ DPRINTF(peer->p_sc, "Sending handshake initiation to peer %llu\n",
+ (unsigned long long)peer->p_id);
+
+ NET_EPOCH_ENTER(et);
+ if (noise_create_initiation(&peer->p_remote, &pkt.s_idx, pkt.ue,
+ pkt.es, pkt.ets) != 0)
+ goto out;
+ pkt.t = WG_PKT_INITIATION;
+ cookie_maker_mac(&peer->p_cookie, &pkt.m, &pkt,
+ sizeof(pkt)-sizeof(pkt.m));
+ wg_peer_send_buf(peer, (uint8_t *)&pkt, sizeof(pkt));
+ wg_timers_event_handshake_initiated(&peer->p_timers);
+out:
+ NET_EPOCH_EXIT(et);
+}
+
+static void
+wg_send_response(struct wg_peer *peer)
+{
+ struct wg_pkt_response pkt;
+ struct epoch_tracker et;
+
+ NET_EPOCH_ENTER(et);
+
+ DPRINTF(peer->p_sc, "Sending handshake response to peer %llu\n",
+ (unsigned long long)peer->p_id);
+
+ if (noise_create_response(&peer->p_remote, &pkt.s_idx, &pkt.r_idx,
+ pkt.ue, pkt.en) != 0)
+ goto out;
+ if (noise_remote_begin_session(&peer->p_remote) != 0)
+ goto out;
+
+ wg_timers_event_session_derived(&peer->p_timers);
+ pkt.t = WG_PKT_RESPONSE;
+ cookie_maker_mac(&peer->p_cookie, &pkt.m, &pkt,
+ sizeof(pkt)-sizeof(pkt.m));
+ wg_timers_event_handshake_responded(&peer->p_timers);
+ wg_peer_send_buf(peer, (uint8_t*)&pkt, sizeof(pkt));
+out:
+ NET_EPOCH_EXIT(et);
+}
+
+static void
+wg_send_cookie(struct wg_softc *sc, struct cookie_macs *cm, uint32_t idx,
+ struct mbuf *m)
+{
+ struct wg_pkt_cookie pkt;
+ struct wg_endpoint *e;
+
+ DPRINTF(sc, "Sending cookie response for denied handshake message\n");
+
+ pkt.t = WG_PKT_COOKIE;
+ pkt.r_idx = idx;
+
+ e = wg_mbuf_endpoint_get(m);
+ cookie_checker_create_payload(&sc->sc_cookie, cm, pkt.nonce,
+ pkt.ec, &e->e_remote.r_sa);
+ wg_send_buf(sc, e, (uint8_t *)&pkt, sizeof(pkt));
+}
+
+static void
+wg_send_keepalive(struct wg_peer *peer)
+{
+ struct mbuf *m = NULL;
+ struct wg_tag *t;
+ struct epoch_tracker et;
+
+ if (wg_queue_len(&peer->p_stage_queue) != 0) {
+ NET_EPOCH_ENTER(et);
+ goto send;
+ }
+ if ((m = m_gethdr(M_NOWAIT, MT_DATA)) == NULL)
+ return;
+ if ((t = wg_tag_get(m)) == NULL) {
+ m_freem(m);
+ return;
+ }
+ t->t_peer = peer;
+ t->t_mbuf = NULL;
+ t->t_done = 0;
+ t->t_mtu = 0; /* MTU == 0 OK for keepalive */
+
+ NET_EPOCH_ENTER(et);
+ wg_queue_stage(peer, m);
+send:
+ wg_queue_out(peer);
+ NET_EPOCH_EXIT(et);
+}
+
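+/* Check the cookie MACs on an initiation/response; wg_handshake() treats EINVAL
+ * as a bad MAC, ECONNREFUSED as ratelimited, and EAGAIN as "reply with a cookie".
+ * Other packet types pass through as 0. */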
+static int
+wg_cookie_validate_packet(struct cookie_checker *checker, struct mbuf *m,
+ int under_load)
+{
+ struct wg_pkt_initiation *init;
+ struct wg_pkt_response *resp;
+ struct cookie_macs *macs;
+ struct wg_endpoint *e;
+ int type, size;
+ void *data;
+
+ type = *mtod(m, uint32_t *);
+ data = m->m_data;
+ e = wg_mbuf_endpoint_get(m);
+ if (type == WG_PKT_INITIATION) {
+ init = mtod(m, struct wg_pkt_initiation *);
+ macs = &init->m;
+ size = sizeof(*init) - sizeof(*macs);
+ } else if (type == WG_PKT_RESPONSE) {
+ resp = mtod(m, struct wg_pkt_response *);
+ macs = &resp->m;
+ size = sizeof(*resp) - sizeof(*macs);
+ } else
+ return 0;
+
+ return (cookie_checker_validate_macs(checker, macs, data, size,
+ under_load, &e->e_remote.r_sa));
+}
+
+
+static void
+wg_handshake(struct wg_softc *sc, struct mbuf *m)
+{
+ struct wg_pkt_initiation *init;
+ struct wg_pkt_response *resp;
+ struct noise_remote *remote;
+ struct wg_pkt_cookie *cook;
+ struct wg_peer *peer;
+ struct wg_tag *t;
+
+ /* This is global, so that our load calculation applies to the whole
+ * system. We don't care about races with it at all.
+ */
+ static struct timeval wg_last_underload;
+ static const struct timeval underload_interval = { UNDERLOAD_TIMEOUT, 0 };
+ bool packet_needs_cookie = false;
+ int underload, res;
+
+ underload = mbufq_len(&sc->sc_handshake_queue) >=
+ MAX_QUEUED_HANDSHAKES / 8;
+ if (underload)
+ getmicrouptime(&wg_last_underload);
+ else if (wg_last_underload.tv_sec != 0) {
+ if (!ratecheck(&wg_last_underload, &underload_interval))
+ underload = 1;
+ else
+ bzero(&wg_last_underload, sizeof(wg_last_underload));
+ }
+
+ res = wg_cookie_validate_packet(&sc->sc_cookie, m, underload);
+
+ if (res && res != EAGAIN) {
+ printf("validate_packet got %d\n", res);
+ goto free;
+ }
+ if (res == EINVAL) {
+ DPRINTF(sc, "Invalid initiation MAC\n");
+ goto free;
+ } else if (res == ECONNREFUSED) {
+ DPRINTF(sc, "Handshake ratelimited\n");
+ goto free;
+ } else if (res == EAGAIN) {
+ packet_needs_cookie = true;
+ } else if (res != 0) {
+ DPRINTF(sc, "Unexpected handshake ratelimit response: %d\n", res);
+ goto free;
+ }
+
+ t = wg_tag_get(m);
+ switch (*mtod(m, uint32_t *)) {
+ case WG_PKT_INITIATION:
+ init = mtod(m, struct wg_pkt_initiation *);
+
+ if (packet_needs_cookie) {
+ wg_send_cookie(sc, &init->m, init->s_idx, m);
+ goto free;
+ }
+ if (noise_consume_initiation(&sc->sc_local, &remote,
+ init->s_idx, init->ue, init->es, init->ets) != 0) {
+ DPRINTF(sc, "Invalid handshake initiation");
+ goto free;
+ }
+
+ peer = __containerof(remote, struct wg_peer, p_remote);
+ DPRINTF(sc, "Receiving handshake initiation from peer %llu\n",
+ (unsigned long long)peer->p_id);
+ counter_u64_add(peer->p_rx_bytes, sizeof(*init));
+ if_inc_counter(sc->sc_ifp, IFCOUNTER_IPACKETS, 1);
+ if_inc_counter(sc->sc_ifp, IFCOUNTER_IBYTES, sizeof(*init));
+ wg_peer_set_endpoint_from_tag(peer, t);
+ wg_send_response(peer);
+ break;
+ case WG_PKT_RESPONSE:
+ resp = mtod(m, struct wg_pkt_response *);
+
+ if (packet_needs_cookie) {
+ wg_send_cookie(sc, &resp->m, resp->s_idx, m);
+ goto free;
+ }
+
+ if ((remote = wg_index_get(sc, resp->r_idx)) == NULL) {
+ DPRINTF(sc, "Unknown handshake response\n");
+ goto free;
+ }
+ peer = __containerof(remote, struct wg_peer, p_remote);
+ if (noise_consume_response(remote, resp->s_idx, resp->r_idx,
+ resp->ue, resp->en) != 0) {
+ DPRINTF(sc, "Invalid handshake response\n");
+ goto free;
+ }
+
+ DPRINTF(sc, "Receiving handshake response from peer %llu\n",
+ (unsigned long long)peer->p_id);
+ counter_u64_add(peer->p_rx_bytes, sizeof(*resp));
+ if_inc_counter(sc->sc_ifp, IFCOUNTER_IPACKETS, 1);
+ if_inc_counter(sc->sc_ifp, IFCOUNTER_IBYTES, sizeof(*resp));
+ wg_peer_set_endpoint_from_tag(peer, t);
+ if (noise_remote_begin_session(&peer->p_remote) == 0) {
+ wg_timers_event_session_derived(&peer->p_timers);
+ wg_timers_event_handshake_complete(&peer->p_timers);
+ }
+ break;
+ case WG_PKT_COOKIE:
+ cook = mtod(m, struct wg_pkt_cookie *);
+
+ if ((remote = wg_index_get(sc, cook->r_idx)) == NULL) {
+ DPRINTF(sc, "Unknown cookie index\n");
+ goto free;
+ }
+
+ peer = __containerof(remote, struct wg_peer, p_remote);
+
+ if (cookie_maker_consume_payload(&peer->p_cookie,
+ cook->nonce, cook->ec) != 0) {
+ DPRINTF(sc, "Could not decrypt cookie response\n");
+ goto free;
+ }
+
+ DPRINTF(sc, "Receiving cookie response\n");
+ goto free;
+ default:
+ goto free;
+ }
+ MPASS(peer != NULL);
+ wg_timers_event_any_authenticated_packet_received(&peer->p_timers);
+ wg_timers_event_any_authenticated_packet_traversal(&peer->p_timers);
+
+free:
+ m_freem(m);
+}
+
+static void
+wg_softc_handshake_receive(struct wg_softc *sc)
+{
+ struct mbuf *m;
+
+ while ((m = mbufq_dequeue(&sc->sc_handshake_queue)) != NULL)
+ wg_handshake(sc, m);
+}
+
+/* TODO Encrypt */
+static void
+wg_encap(struct wg_softc *sc, struct mbuf *m)
+{
+ struct wg_pkt_data *data;
+ size_t padding_len, plaintext_len, out_len;
+ struct mbuf *mc;
+ struct wg_peer *peer;
+ struct wg_tag *t;
+ uint64_t nonce;
+ int res, allocation_order;
+
+ NET_EPOCH_ASSERT();
+ t = wg_tag_get(m);
+ peer = t->t_peer;
+
+ plaintext_len = MIN(WG_PKT_WITH_PADDING(m->m_pkthdr.len), t->t_mtu);
+ padding_len = plaintext_len - m->m_pkthdr.len;
+ out_len = sizeof(struct wg_pkt_data) + plaintext_len + NOISE_AUTHTAG_LEN;
+
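+ /* Pick the smallest mbuf cluster that can hold the data header, padded payload, and auth tag. */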
+ if (out_len <= MCLBYTES)
+ allocation_order = MCLBYTES;
+ else if (out_len <= MJUMPAGESIZE)
+ allocation_order = MJUMPAGESIZE;
+ else if (out_len <= MJUM9BYTES)
+ allocation_order = MJUM9BYTES;
+ else if (out_len <= MJUM16BYTES)
+ allocation_order = MJUM16BYTES;
+ else
+ goto error;
+
+ if ((mc = m_getjcl(M_NOWAIT, MT_DATA, M_PKTHDR, allocation_order)) == NULL)
+ goto error;
+
+ data = mtod(mc, struct wg_pkt_data *);
+ m_copydata(m, 0, m->m_pkthdr.len, data->buf);
+ bzero(data->buf + m->m_pkthdr.len, padding_len);
+
+ data->t = WG_PKT_DATA;
+
+ res = noise_remote_encrypt(&peer->p_remote, &data->r_idx, &nonce,
+ data->buf, plaintext_len);
+ nonce = htole64(nonce); /* Wire format is little endian. */
+ memcpy(data->nonce, &nonce, sizeof(data->nonce));
+
+ if (__predict_false(res)) {
+ if (res == EINVAL) {
+ wg_timers_event_want_initiation(&peer->p_timers);
+ m_freem(mc);
+ goto error;
+ } else if (res == ESTALE) {
+ wg_timers_event_want_initiation(&peer->p_timers);
+ } else {
+ m_freem(mc);
+ goto error;
+ }
+ }
+
+ /* A packet with length 0 is a keepalive packet */
+ if (m->m_pkthdr.len == 0)
+ DPRINTF(sc, "Sending keepalive packet to peer %llu\n",
+ (unsigned long long)peer->p_id);
+ /*
+ * Set the correct output value here since it will be copied
+ * when we move the pkthdr in send.
+ */
+ mc->m_len = mc->m_pkthdr.len = out_len;
+ mc->m_flags &= ~(M_MCAST | M_BCAST);
+
+ t->t_mbuf = mc;
+ error:
+ /* XXX membar ? */
+ t->t_done = 1;
+ GROUPTASK_ENQUEUE(&peer->p_send);
+}
+
+static void
+wg_decap(struct wg_softc *sc, struct mbuf *m)
+{
+ struct wg_pkt_data *data;
+ struct wg_peer *peer, *routed_peer;
+ struct wg_tag *t;
+ size_t plaintext_len;
+ uint8_t version;
+ uint64_t nonce;
+ int res;
+
+ NET_EPOCH_ASSERT();
+ data = mtod(m, struct wg_pkt_data *);
+ plaintext_len = m->m_pkthdr.len - sizeof(struct wg_pkt_data);
+
+ t = wg_tag_get(m);
+ peer = t->t_peer;
+
+ memcpy(&nonce, data->nonce, sizeof(nonce));
+ nonce = le64toh(nonce); /* Wire format is little endian. */
+
+ res = noise_remote_decrypt(&peer->p_remote, data->r_idx, nonce,
+ data->buf, plaintext_len);
+
+ if (__predict_false(res)) {
+ if (res == EINVAL) {
+ goto error;
+ } else if (res == ECONNRESET) {
+ wg_timers_event_handshake_complete(&peer->p_timers);
+ } else if (res == ESTALE) {
+ wg_timers_event_want_initiation(&peer->p_timers);
+ } else {
+ panic("unexpected response: %d\n", res);
+ }
+ }
+ wg_peer_set_endpoint_from_tag(peer, t);
+
+ /* Remove the data header, and crypto mac tail from the packet */
+ m_adj(m, sizeof(struct wg_pkt_data));
+ m_adj(m, -NOISE_AUTHTAG_LEN);
+
+ /* A packet with length 0 is a keepalive packet */
+ if (m->m_pkthdr.len == 0) {
+ DPRINTF(peer->p_sc, "Receiving keepalive packet from peer "
+ "%llu\n", (unsigned long long)peer->p_id);
+ goto done;
+ }
+
+ version = mtod(m, struct ip *)->ip_v;
+ if (!((version == 4 && m->m_pkthdr.len >= sizeof(struct ip)) ||
+ (version == 6 && m->m_pkthdr.len >= sizeof(struct ip6_hdr)))) {
+ DPRINTF(peer->p_sc, "Packet is neither ipv4 nor ipv6 from peer "
+ "%llu\n", (unsigned long long)peer->p_id);
+ goto error;
+ }
+
+ routed_peer = wg_aip_lookup(&peer->p_sc->sc_aips, m, IN);
+ if (routed_peer != peer) {
+ DPRINTF(peer->p_sc, "Packet has unallowed src IP from peer "
+ "%llu\n", (unsigned long long)peer->p_id);
+ goto error;
+ }
+
+done:
+ t->t_mbuf = m;
+error:
+ t->t_done = 1;
+ GROUPTASK_ENQUEUE(&peer->p_recv);
+}
+
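+/*
+ * Per-CPU grouptask handlers: drain the softc's shared decap/encap rings and
+ * run wg_decap()/wg_encap() on each packet within the network epoch.
+ */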
+static void
+wg_softc_decrypt(struct wg_softc *sc)
+{
+ struct epoch_tracker et;
+ struct mbuf *m;
+
+ NET_EPOCH_ENTER(et);
+ while ((m = buf_ring_dequeue_mc(sc->sc_decap_ring)) != NULL)
+ wg_decap(sc, m);
+ NET_EPOCH_EXIT(et);
+}
+
+static void
+wg_softc_encrypt(struct wg_softc *sc)
+{
+ struct mbuf *m;
+ struct epoch_tracker et;
+
+ NET_EPOCH_ENTER(et);
+ while ((m = buf_ring_dequeue_mc(sc->sc_encap_ring)) != NULL)
+ wg_encap(sc, m);
+ NET_EPOCH_EXIT(et);
+}
+
+static void
+wg_encrypt_dispatch(struct wg_softc *sc)
+{
+ for (int i = 0; i < mp_ncpus; i++) {
+ if (sc->sc_encrypt[i].gt_task.ta_flags & TASK_ENQUEUED)
+ continue;
+ GROUPTASK_ENQUEUE(&sc->sc_encrypt[i]);
+ }
+}
+
+static void
+wg_decrypt_dispatch(struct wg_softc *sc)
+{
+ for (int i = 0; i < mp_ncpus; i++) {
+ if (sc->sc_decrypt[i].gt_task.ta_flags & TASK_ENQUEUED)
+ continue;
+ GROUPTASK_ENQUEUE(&sc->sc_decrypt[i]);
+ }
+}
+
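+/*
+ * Serial send task for a peer: dequeue completed packets from the encap
+ * queue in order and transmit the encrypted mbufs to the peer's current
+ * endpoint, updating timers and counters on success.
+ */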
+static void
+wg_deliver_out(struct wg_peer *peer)
+{
+ struct epoch_tracker et;
+ struct wg_tag *t;
+ struct mbuf *m;
+ struct wg_endpoint endpoint;
+ size_t len;
+ int ret;
+
+ NET_EPOCH_ENTER(et);
+ if (peer->p_sc->sc_ifp->if_link_state == LINK_STATE_DOWN)
+ goto done;
+
+ wg_peer_get_endpoint(peer, &endpoint);
+
+ while ((m = wg_queue_dequeue(&peer->p_encap_queue, &t)) != NULL) {
+ /* t_mbuf will contain the encrypted packet */
+ if (t->t_mbuf == NULL) {
+ if_inc_counter(peer->p_sc->sc_ifp, IFCOUNTER_OERRORS, 1);
+ m_freem(m);
+ continue;
+ }
+ len = t->t_mbuf->m_pkthdr.len;
+ ret = wg_send(peer->p_sc, &endpoint, t->t_mbuf);
+
+ if (ret == 0) {
+ wg_timers_event_any_authenticated_packet_traversal(
+ &peer->p_timers);
+ wg_timers_event_any_authenticated_packet_sent(
+ &peer->p_timers);
+
+ if (m->m_pkthdr.len != 0)
+ wg_timers_event_data_sent(&peer->p_timers);
+ counter_u64_add(peer->p_tx_bytes, len);
+ } else if (ret == EADDRNOTAVAIL) {
+ wg_peer_clear_src(peer);
+ wg_peer_get_endpoint(peer, &endpoint);
+ }
+ m_freem(m);
+ }
+done:
+ NET_EPOCH_EXIT(et);
+}
+
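+/*
+ * Serial receive task for a peer: dequeue completed packets from the decap
+ * queue in order, update counters and timers, and hand the plaintext to
+ * ip_input()/ip6_input().
+ */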
+static void
+wg_deliver_in(struct wg_peer *peer)
+{
+ struct mbuf *m;
+ struct ifnet *ifp;
+ struct wg_softc *sc;
+ struct epoch_tracker et;
+ struct wg_tag *t;
+ uint32_t af;
+ int version;
+
+ NET_EPOCH_ENTER(et);
+ sc = peer->p_sc;
+ ifp = sc->sc_ifp;
+
+ while ((m = wg_queue_dequeue(&peer->p_decap_queue, &t)) != NULL) {
+		/* t_mbuf will contain the decrypted packet */
+ if (t->t_mbuf == NULL) {
+ if_inc_counter(ifp, IFCOUNTER_IERRORS, 1);
+ m_freem(m);
+ continue;
+ }
+ MPASS(m == t->t_mbuf);
+
+ wg_timers_event_any_authenticated_packet_received(
+ &peer->p_timers);
+ wg_timers_event_any_authenticated_packet_traversal(
+ &peer->p_timers);
+
+ counter_u64_add(peer->p_rx_bytes, m->m_pkthdr.len + sizeof(struct wg_pkt_data) + NOISE_AUTHTAG_LEN);
+ if_inc_counter(sc->sc_ifp, IFCOUNTER_IPACKETS, 1);
+ if_inc_counter(sc->sc_ifp, IFCOUNTER_IBYTES, m->m_pkthdr.len + sizeof(struct wg_pkt_data) + NOISE_AUTHTAG_LEN);
+
+ if (m->m_pkthdr.len == 0) {
+ m_freem(m);
+ continue;
+ }
+
+ m->m_flags &= ~(M_MCAST | M_BCAST);
+ m->m_pkthdr.rcvif = ifp;
+ version = mtod(m, struct ip *)->ip_v;
+ if (version == IPVERSION) {
+ af = AF_INET;
+ BPF_MTAP2(ifp, &af, sizeof(af), m);
+ CURVNET_SET(ifp->if_vnet);
+ ip_input(m);
+ CURVNET_RESTORE();
+ } else if (version == 6) {
+ af = AF_INET6;
+ BPF_MTAP2(ifp, &af, sizeof(af), m);
+ CURVNET_SET(ifp->if_vnet);
+ ip6_input(m);
+ CURVNET_RESTORE();
+ } else
+ m_freem(m);
+
+ wg_timers_event_data_received(&peer->p_timers);
+ }
+ NET_EPOCH_EXIT(et);
+}
+
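+/*
+ * Queue an inbound data packet on both the peer's serial decap queue (for
+ * in-order delivery) and the softc's parallel decap ring (for decryption).
+ */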
+static int
+wg_queue_in(struct wg_peer *peer, struct mbuf *m)
+{
+ struct buf_ring *parallel = peer->p_sc->sc_decap_ring;
+ struct wg_queue *serial = &peer->p_decap_queue;
+ struct wg_tag *t;
+ int rc;
+
+ MPASS(wg_tag_get(m) != NULL);
+
+ mtx_lock(&serial->q_mtx);
+ if ((rc = mbufq_enqueue(&serial->q, m)) == ENOBUFS) {
+ m_freem(m);
+ if_inc_counter(peer->p_sc->sc_ifp, IFCOUNTER_OQDROPS, 1);
+ } else {
+ m->m_flags |= M_ENQUEUED;
+ rc = buf_ring_enqueue(parallel, m);
+ if (rc == ENOBUFS) {
+ t = wg_tag_get(m);
+ t->t_done = 1;
+ }
+ }
+ mtx_unlock(&serial->q_mtx);
+ return (rc);
+}
+
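+/*
+ * Stage an outbound packet for a peer; if the stage queue overflows, the
+ * oldest staged packets are dropped.
+ */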
+static void
+wg_queue_stage(struct wg_peer *peer, struct mbuf *m)
+{
+ struct wg_queue *q = &peer->p_stage_queue;
+ mtx_lock(&q->q_mtx);
+ STAILQ_INSERT_TAIL(&q->q.mq_head, m, m_stailqpkt);
+ q->q.mq_len++;
+ while (mbufq_full(&q->q)) {
+ m = mbufq_dequeue(&q->q);
+ if (m) {
+ m_freem(m);
+ if_inc_counter(peer->p_sc->sc_ifp, IFCOUNTER_OQDROPS, 1);
+ }
+ }
+ mtx_unlock(&q->q_mtx);
+}
+
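+/*
+ * Move staged packets onto the peer's serial encap queue and the parallel
+ * encap ring once a session is established; otherwise leave them staged and
+ * request a new handshake.
+ */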
+static void
+wg_queue_out(struct wg_peer *peer)
+{
+ struct buf_ring *parallel = peer->p_sc->sc_encap_ring;
+ struct wg_queue *serial = &peer->p_encap_queue;
+ struct wg_tag *t;
+ struct mbufq staged;
+ struct mbuf *m;
+
+ if (noise_remote_ready(&peer->p_remote) != 0) {
+ if (wg_queue_len(&peer->p_stage_queue))
+ wg_timers_event_want_initiation(&peer->p_timers);
+ return;
+ }
+
+	/*
+	 * We first "steal" the staged queue to a local queue, so that we can do
+	 * these remaining operations without having to hold the staged queue mutex.
+	 */
+ STAILQ_INIT(&staged.mq_head);
+ mtx_lock(&peer->p_stage_queue.q_mtx);
+ STAILQ_SWAP(&staged.mq_head, &peer->p_stage_queue.q.mq_head, mbuf);
+ staged.mq_len = peer->p_stage_queue.q.mq_len;
+ peer->p_stage_queue.q.mq_len = 0;
+ staged.mq_maxlen = peer->p_stage_queue.q.mq_maxlen;
+ mtx_unlock(&peer->p_stage_queue.q_mtx);
+
+ while ((m = mbufq_dequeue(&staged)) != NULL) {
+ if ((t = wg_tag_get(m)) == NULL) {
+ m_freem(m);
+ continue;
+ }
+ t->t_peer = peer;
+ mtx_lock(&serial->q_mtx);
+ if (mbufq_enqueue(&serial->q, m) != 0) {
+ m_freem(m);
+ if_inc_counter(peer->p_sc->sc_ifp, IFCOUNTER_OQDROPS, 1);
+ } else {
+ m->m_flags |= M_ENQUEUED;
+ if (buf_ring_enqueue(parallel, m)) {
+ t = wg_tag_get(m);
+ t->t_done = 1;
+ }
+ }
+ mtx_unlock(&serial->q_mtx);
+ }
+ wg_encrypt_dispatch(peer->p_sc);
+}
+
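+/*
+ * Dequeue the head of a serial queue, but only once the crypto path has
+ * marked its tag done; this keeps delivery in submission order.
+ */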
+static struct mbuf *
+wg_queue_dequeue(struct wg_queue *q, struct wg_tag **t)
+{
+ struct mbuf *m_, *m;
+
+ m = NULL;
+ mtx_lock(&q->q_mtx);
+ m_ = mbufq_first(&q->q);
+ if (m_ != NULL && (*t = wg_tag_get(m_))->t_done) {
+ m = mbufq_dequeue(&q->q);
+ m->m_flags &= ~M_ENQUEUED;
+ }
+ mtx_unlock(&q->q_mtx);
+ return (m);
+}
+
+static int
+wg_queue_len(struct wg_queue *q)
+{
+ /* This access races. We might consider adding locking here. */
+ return (mbufq_len(&q->q));
+}
+
+static void
+wg_queue_init(struct wg_queue *q, const char *name)
+{
+ mtx_init(&q->q_mtx, name, NULL, MTX_DEF);
+ mbufq_init(&q->q, MAX_QUEUED_PKT);
+}
+
+static void
+wg_queue_deinit(struct wg_queue *q)
+{
+ wg_queue_purge(q);
+ mtx_destroy(&q->q_mtx);
+}
+
+static void
+wg_queue_purge(struct wg_queue *q)
+{
+ mtx_lock(&q->q_mtx);
+ mbufq_drain(&q->q);
+ mtx_unlock(&q->q_mtx);
+}
+
+/* TODO Indexes */
+static struct noise_remote *
+wg_remote_get(struct wg_softc *sc, uint8_t public[NOISE_PUBLIC_KEY_LEN])
+{
+ struct wg_peer *peer;
+
+ if ((peer = wg_peer_lookup(sc, public)) == NULL)
+ return (NULL);
+ return (&peer->p_remote);
+}
+
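+/*
+ * Allocate a receiver index for the remote: take an unused index entry from
+ * the peer, pick a random key not already present in its hash bucket, and
+ * publish it in the softc's index table.
+ */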
+static uint32_t
+wg_index_set(struct wg_softc *sc, struct noise_remote *remote)
+{
+ struct wg_index *index, *iter;
+ struct wg_peer *peer;
+ uint32_t key;
+
+	/*
+	 * We can modify this without a lock as wg_index_set, wg_index_drop are
+	 * guaranteed to be serialised (per remote).
+	 */
+ peer = __containerof(remote, struct wg_peer, p_remote);
+ index = SLIST_FIRST(&peer->p_unused_index);
+ MPASS(index != NULL);
+ SLIST_REMOVE_HEAD(&peer->p_unused_index, i_unused_entry);
+
+ index->i_value = remote;
+
+ rw_wlock(&sc->sc_index_lock);
+assign_id:
+ key = index->i_key = arc4random();
+ key &= sc->sc_index_mask;
+ LIST_FOREACH(iter, &sc->sc_index[key], i_entry)
+ if (iter->i_key == index->i_key)
+ goto assign_id;
+
+ LIST_INSERT_HEAD(&sc->sc_index[key], index, i_entry);
+
+ rw_wunlock(&sc->sc_index_lock);
+
+ /* Likewise, no need to lock for index here. */
+ return index->i_key;
+}
+
+static struct noise_remote *
+wg_index_get(struct wg_softc *sc, uint32_t key0)
+{
+ struct wg_index *iter;
+ struct noise_remote *remote = NULL;
+ uint32_t key = key0 & sc->sc_index_mask;
+
+ rw_enter_read(&sc->sc_index_lock);
+ LIST_FOREACH(iter, &sc->sc_index[key], i_entry)
+ if (iter->i_key == key0) {
+ remote = iter->i_value;
+ break;
+ }
+ rw_exit_read(&sc->sc_index_lock);
+ return remote;
+}
+
+static void
+wg_index_drop(struct wg_softc *sc, uint32_t key0)
+{
+ struct wg_index *iter;
+ struct wg_peer *peer = NULL;
+ uint32_t key = key0 & sc->sc_index_mask;
+
+ rw_enter_write(&sc->sc_index_lock);
+ LIST_FOREACH(iter, &sc->sc_index[key], i_entry)
+ if (iter->i_key == key0) {
+ LIST_REMOVE(iter, i_entry);
+ break;
+ }
+ rw_exit_write(&sc->sc_index_lock);
+
+ if (iter == NULL)
+ return;
+
+ /* We expect a peer */
+ peer = __containerof(iter->i_value, struct wg_peer, p_remote);
+ MPASS(peer != NULL);
+ SLIST_INSERT_HEAD(&peer->p_unused_index, iter, i_unused_entry);
+}
+
+static int
+wg_update_endpoint_addrs(struct wg_endpoint *e, const struct sockaddr *srcsa,
+ struct ifnet *rcvif)
+{
+ const struct sockaddr_in *sa4;
+ const struct sockaddr_in6 *sa6;
+ int ret = 0;
+
+ /*
+ * UDP passes a 2-element sockaddr array: first element is the
+ * source addr/port, second the destination addr/port.
+ */
+ if (srcsa->sa_family == AF_INET) {
+ sa4 = (const struct sockaddr_in *)srcsa;
+ e->e_remote.r_sin = sa4[0];
+ e->e_local.l_in = sa4[1].sin_addr;
+ } else if (srcsa->sa_family == AF_INET6) {
+ sa6 = (const struct sockaddr_in6 *)srcsa;
+ e->e_remote.r_sin6 = sa6[0];
+ e->e_local.l_in6 = sa6[1].sin6_addr;
+ } else {
+ ret = EAFNOSUPPORT;
+ }
+
+ return (ret);
+}
+
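+/*
+ * UDP tunnel input: strip the UDP header, dispatch handshake packets to the
+ * handshake task, and queue data packets to the owning peer for decryption.
+ */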
+static void
+wg_input(struct mbuf *m0, int offset, struct inpcb *inpcb,
+ const struct sockaddr *srcsa, void *_sc)
+{
+ struct wg_pkt_data *pkt_data;
+ struct wg_endpoint *e;
+ struct wg_softc *sc = _sc;
+ struct mbuf *m;
+ int pktlen, pkttype;
+ struct noise_remote *remote;
+ struct wg_tag *t;
+ void *data;
+
+ /* Caller provided us with srcsa, no need for this header. */
+ m_adj(m0, offset + sizeof(struct udphdr));
+
+ /*
+ * Ensure mbuf has at least enough contiguous data to peel off our
+ * headers at the beginning.
+ */
+ if ((m = m_defrag(m0, M_NOWAIT)) == NULL) {
+ m_freem(m0);
+ return;
+ }
+ data = mtod(m, void *);
+ pkttype = *(uint32_t*)data;
+ t = wg_tag_get(m);
+ if (t == NULL) {
+ goto free;
+ }
+ e = wg_mbuf_endpoint_get(m);
+
+ if (wg_update_endpoint_addrs(e, srcsa, m->m_pkthdr.rcvif)) {
+ goto free;
+ }
+
+ pktlen = m->m_pkthdr.len;
+
+ if ((pktlen == sizeof(struct wg_pkt_initiation) &&
+ pkttype == WG_PKT_INITIATION) ||
+ (pktlen == sizeof(struct wg_pkt_response) &&
+ pkttype == WG_PKT_RESPONSE) ||
+ (pktlen == sizeof(struct wg_pkt_cookie) &&
+ pkttype == WG_PKT_COOKIE)) {
+ if (mbufq_enqueue(&sc->sc_handshake_queue, m) == 0) {
+ GROUPTASK_ENQUEUE(&sc->sc_handshake);
+ } else {
+ DPRINTF(sc, "Dropping handshake packet\n");
+ m_freem(m);
+ }
+ } else if (pktlen >= sizeof(struct wg_pkt_data) + NOISE_AUTHTAG_LEN
+ && pkttype == WG_PKT_DATA) {
+
+ pkt_data = data;
+ remote = wg_index_get(sc, pkt_data->r_idx);
+ if (remote == NULL) {
+ if_inc_counter(sc->sc_ifp, IFCOUNTER_IERRORS, 1);
+ m_freem(m);
+ } else if (buf_ring_count(sc->sc_decap_ring) > MAX_QUEUED_PKT) {
+ if_inc_counter(sc->sc_ifp, IFCOUNTER_IQDROPS, 1);
+ m_freem(m);
+ } else {
+ t->t_peer = __containerof(remote, struct wg_peer,
+ p_remote);
+ t->t_mbuf = NULL;
+ t->t_done = 0;
+
+ wg_queue_in(t->t_peer, m);
+ wg_decrypt_dispatch(sc);
+ }
+ } else {
+free:
+ m_freem(m);
+ }
+}
+
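+/*
+ * if_transmit handler: look up the destination peer via the allowed-ips
+ * table, stage the packet on that peer, and kick the encrypt path.
+ */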
+static int
+wg_transmit(struct ifnet *ifp, struct mbuf *m)
+{
+ struct wg_softc *sc;
+ sa_family_t family;
+ struct epoch_tracker et;
+ struct wg_peer *peer;
+ struct wg_tag *t;
+ uint32_t af;
+ int rc;
+
+ /*
+	 * Work around a lifetime issue in the IPv6 MLD code.
+ */
+ if (__predict_false(ifp->if_flags & IFF_DYING))
+ return (ENXIO);
+
+ rc = 0;
+ sc = ifp->if_softc;
+ if ((t = wg_tag_get(m)) == NULL) {
+ rc = ENOBUFS;
+ goto early_out;
+ }
+ af = m->m_pkthdr.ph_family;
+ BPF_MTAP2(ifp, &af, sizeof(af), m);
+
+ NET_EPOCH_ENTER(et);
+ peer = wg_aip_lookup(&sc->sc_aips, m, OUT);
+ if (__predict_false(peer == NULL)) {
+ rc = ENOKEY;
+ goto err;
+ }
+
+ family = peer->p_endpoint.e_remote.r_sa.sa_family;
+ if (__predict_false(family != AF_INET && family != AF_INET6)) {
+ DPRINTF(sc, "No valid endpoint has been configured or "
+ "discovered for peer %llu\n", (unsigned long long)peer->p_id);
+
+ rc = EHOSTUNREACH;
+ goto err;
+ }
+ t->t_peer = peer;
+ t->t_mbuf = NULL;
+ t->t_done = 0;
+ t->t_mtu = ifp->if_mtu;
+
+ wg_queue_stage(peer, m);
+ wg_queue_out(peer);
+ NET_EPOCH_EXIT(et);
+ return (rc);
+err:
+ NET_EPOCH_EXIT(et);
+early_out:
+ if_inc_counter(sc->sc_ifp, IFCOUNTER_OERRORS, 1);
+ /* TODO: send ICMP unreachable */
+ m_free(m);
+ return (rc);
+}
+
+static int
+wg_output(struct ifnet *ifp, struct mbuf *m, const struct sockaddr *sa, struct route *rt)
+{
+ m->m_pkthdr.ph_family = sa->sa_family;
+ return (wg_transmit(ifp, m));
+}
+
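+/*
+ * Create or update a single peer from its nvlist description: handles the
+ * remove and replace-allowedips flags as well as endpoint, preshared key,
+ * persistent keepalive and allowed-ips updates.
+ */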
+static int
+wg_peer_add(struct wg_softc *sc, const nvlist_t *nvl)
+{
+ uint8_t public[WG_KEY_SIZE];
+ const void *pub_key;
+ const struct sockaddr *endpoint;
+ int err;
+ size_t size;
+ struct wg_peer *peer = NULL;
+ bool need_insert = false;
+
+ sx_assert(&sc->sc_lock, SX_XLOCKED);
+
+ if (!nvlist_exists_binary(nvl, "public-key")) {
+ return (EINVAL);
+ }
+ pub_key = nvlist_get_binary(nvl, "public-key", &size);
+ if (size != WG_KEY_SIZE) {
+ return (EINVAL);
+ }
+ if (noise_local_keys(&sc->sc_local, public, NULL) == 0 &&
+ bcmp(public, pub_key, WG_KEY_SIZE) == 0) {
+		return (0); /* Silently ignored; not actually a failure. */
+ }
+ peer = wg_peer_lookup(sc, pub_key);
+ if (nvlist_exists_bool(nvl, "remove") &&
+ nvlist_get_bool(nvl, "remove")) {
+ if (peer != NULL) {
+ wg_hashtable_peer_remove(&sc->sc_hashtable, peer);
+ wg_peer_destroy(peer);
+ }
+ return (0);
+ }
+ if (nvlist_exists_bool(nvl, "replace-allowedips") &&
+ nvlist_get_bool(nvl, "replace-allowedips") &&
+ peer != NULL) {
+
+ wg_aip_delete(&peer->p_sc->sc_aips, peer);
+ }
+ if (peer == NULL) {
+ if (sc->sc_peer_count >= MAX_PEERS_PER_IFACE)
+ return (E2BIG);
+ sc->sc_peer_count++;
+
+ need_insert = true;
+ peer = wg_peer_alloc(sc);
+ MPASS(peer != NULL);
+ noise_remote_init(&peer->p_remote, pub_key, &sc->sc_local);
+ cookie_maker_init(&peer->p_cookie, pub_key);
+ }
+ if (nvlist_exists_binary(nvl, "endpoint")) {
+ endpoint = nvlist_get_binary(nvl, "endpoint", &size);
+ if (size > sizeof(peer->p_endpoint.e_remote)) {
+ err = EINVAL;
+ goto out;
+ }
+ memcpy(&peer->p_endpoint.e_remote, endpoint, size);
+ }
+ if (nvlist_exists_binary(nvl, "preshared-key")) {
+ const void *key;
+
+ key = nvlist_get_binary(nvl, "preshared-key", &size);
+ if (size != WG_KEY_SIZE) {
+ err = EINVAL;
+ goto out;
+ }
+ noise_remote_set_psk(&peer->p_remote, key);
+ }
+ if (nvlist_exists_number(nvl, "persistent-keepalive-interval")) {
+ uint64_t pki = nvlist_get_number(nvl, "persistent-keepalive-interval");
+ if (pki > UINT16_MAX) {
+ err = EINVAL;
+ goto out;
+ }
+ wg_timers_set_persistent_keepalive(&peer->p_timers, pki);
+ }
+ if (nvlist_exists_nvlist_array(nvl, "allowed-ips")) {
+ const void *binary;
+ uint64_t cidr;
+ const nvlist_t * const * aipl;
+ struct wg_allowedip aip;
+ size_t allowedip_count;
+
+ aipl = nvlist_get_nvlist_array(nvl, "allowed-ips",
+ &allowedip_count);
+ for (size_t idx = 0; idx < allowedip_count; idx++) {
+ if (!nvlist_exists_number(aipl[idx], "cidr"))
+ continue;
+ cidr = nvlist_get_number(aipl[idx], "cidr");
+ if (nvlist_exists_binary(aipl[idx], "ipv4")) {
+ binary = nvlist_get_binary(aipl[idx], "ipv4", &size);
+ if (binary == NULL || cidr > 32 || size != sizeof(aip.ip4)) {
+ err = EINVAL;
+ goto out;
+ }
+ aip.family = AF_INET;
+ memcpy(&aip.ip4, binary, sizeof(aip.ip4));
+ } else if (nvlist_exists_binary(aipl[idx], "ipv6")) {
+ binary = nvlist_get_binary(aipl[idx], "ipv6", &size);
+ if (binary == NULL || cidr > 128 || size != sizeof(aip.ip6)) {
+ err = EINVAL;
+ goto out;
+ }
+ aip.family = AF_INET6;
+ memcpy(&aip.ip6, binary, sizeof(aip.ip6));
+ } else {
+ continue;
+ }
+ aip.cidr = cidr;
+
+ if ((err = wg_aip_add(&sc->sc_aips, peer, &aip)) != 0) {
+ goto out;
+ }
+ }
+ }
+ if (need_insert) {
+ wg_hashtable_peer_insert(&sc->sc_hashtable, peer);
+ if (sc->sc_ifp->if_link_state == LINK_STATE_UP)
+ wg_timers_enable(&peer->p_timers);
+ }
+ return (0);
+
+out:
+ if (need_insert) /* If we fail, only destroy if it was new. */
+ wg_peer_destroy(peer);
+ return (err);
+}
+
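+/*
+ * SIOCSWG handler: copy in and unpack the nvlist from userland, then apply
+ * interface-wide settings (listen port, private key, user cookie) and the
+ * peer list.
+ */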
+static int
+wgc_set(struct wg_softc *sc, struct wg_data_io *wgd)
+{
+ uint8_t public[WG_KEY_SIZE], private[WG_KEY_SIZE];
+ struct ifnet *ifp;
+ void *nvlpacked;
+ nvlist_t *nvl;
+ ssize_t size;
+ int err;
+
+ ifp = sc->sc_ifp;
+ if (wgd->wgd_size == 0 || wgd->wgd_data == NULL)
+ return (EFAULT);
+
+ sx_xlock(&sc->sc_lock);
+
+ nvlpacked = malloc(wgd->wgd_size, M_TEMP, M_WAITOK);
+ err = copyin(wgd->wgd_data, nvlpacked, wgd->wgd_size);
+ if (err)
+ goto out;
+ nvl = nvlist_unpack(nvlpacked, wgd->wgd_size, 0);
+ if (nvl == NULL) {
+ err = EBADMSG;
+ goto out;
+ }
+ if (nvlist_exists_bool(nvl, "replace-peers") &&
+ nvlist_get_bool(nvl, "replace-peers"))
+ wg_peer_remove_all(sc);
+ if (nvlist_exists_number(nvl, "listen-port")) {
+ uint64_t new_port = nvlist_get_number(nvl, "listen-port");
+ if (new_port > UINT16_MAX) {
+ err = EINVAL;
+ goto out;
+ }
+ if (new_port != sc->sc_socket.so_port) {
+ if ((ifp->if_drv_flags & IFF_DRV_RUNNING) != 0) {
+ if ((err = wg_socket_init(sc, new_port)) != 0)
+ goto out;
+ } else
+ sc->sc_socket.so_port = new_port;
+ }
+ }
+ if (nvlist_exists_binary(nvl, "private-key")) {
+ const void *key = nvlist_get_binary(nvl, "private-key", &size);
+ if (size != WG_KEY_SIZE) {
+ err = EINVAL;
+ goto out;
+ }
+
+ if (noise_local_keys(&sc->sc_local, NULL, private) != 0 ||
+ timingsafe_bcmp(private, key, WG_KEY_SIZE) != 0) {
+ struct noise_local *local;
+ struct wg_peer *peer;
+ struct wg_hashtable *ht = &sc->sc_hashtable;
+ bool has_identity;
+
+ if (curve25519_generate_public(public, key)) {
+ /* Peer conflict: remove conflicting peer. */
+ if ((peer = wg_peer_lookup(sc, public)) !=
+ NULL) {
+ wg_hashtable_peer_remove(ht, peer);
+ wg_peer_destroy(peer);
+ }
+ }
+
+ /*
+ * Set the private key and invalidate all existing
+ * handshakes.
+ */
+ local = &sc->sc_local;
+ noise_local_lock_identity(local);
+ /* Note: we might be removing the private key. */
+ has_identity = noise_local_set_private(local, key) == 0;
+ mtx_lock(&ht->h_mtx);
+ CK_LIST_FOREACH(peer, &ht->h_peers_list, p_entry) {
+ noise_remote_precompute(&peer->p_remote);
+ wg_timers_event_reset_handshake_last_sent(
+ &peer->p_timers);
+ noise_remote_expire_current(&peer->p_remote);
+ }
+ mtx_unlock(&ht->h_mtx);
+ cookie_checker_update(&sc->sc_cookie,
+ has_identity ? public : NULL);
+ noise_local_unlock_identity(local);
+ }
+ }
+ if (nvlist_exists_number(nvl, "user-cookie")) {
+ uint64_t user_cookie = nvlist_get_number(nvl, "user-cookie");
+ if (user_cookie > UINT32_MAX) {
+ err = EINVAL;
+ goto out;
+ }
+ wg_socket_set_cookie(sc, user_cookie);
+ }
+ if (nvlist_exists_nvlist_array(nvl, "peers")) {
+ size_t peercount;
+ const nvlist_t * const*nvl_peers;
+
+ nvl_peers = nvlist_get_nvlist_array(nvl, "peers", &peercount);
+ for (int i = 0; i < peercount; i++) {
+ err = wg_peer_add(sc, nvl_peers[i]);
+ if (err != 0)
+ goto out;
+ }
+ }
+
+ nvlist_destroy(nvl);
+out:
+ free(nvlpacked, M_TEMP);
+ sx_xunlock(&sc->sc_lock);
+ return (err);
+}
+
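+/*
+ * Count the leading one bits of an IPv4 netmask (the mask length in bits).
+ */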
+static unsigned int
+in_mask2len(struct in_addr *mask)
+{
+ unsigned int x, y;
+ uint8_t *p;
+
+ p = (uint8_t *)mask;
+ for (x = 0; x < sizeof(*mask); x++) {
+ if (p[x] != 0xff)
+ break;
+ }
+ y = 0;
+ if (x < sizeof(*mask)) {
+ for (y = 0; y < NBBY; y++) {
+ if ((p[x] & (0x80 >> y)) == 0)
+ break;
+ }
+ }
+ return x * NBBY + y;
+}
+
+static int
+wg_peer_to_export(struct wg_peer *peer, struct wg_peer_export *exp)
+{
+ struct wg_endpoint *ep;
+ struct wg_aip *rt;
+ struct noise_remote *remote;
+ int i;
+
+ /* Non-sleepable context. */
+ NET_EPOCH_ASSERT();
+
+ bzero(&exp->endpoint, sizeof(exp->endpoint));
+ remote = &peer->p_remote;
+ ep = &peer->p_endpoint;
+ if (ep->e_remote.r_sa.sa_family != 0) {
+ exp->endpoint_sz = (ep->e_remote.r_sa.sa_family == AF_INET) ?
+ sizeof(struct sockaddr_in) : sizeof(struct sockaddr_in6);
+
+ memcpy(&exp->endpoint, &ep->e_remote, exp->endpoint_sz);
+ }
+
+ /* We always export it. */
+ (void)noise_remote_keys(remote, exp->public_key, exp->preshared_key);
+ exp->persistent_keepalive =
+ peer->p_timers.t_persistent_keepalive_interval;
+ wg_timers_get_last_handshake(&peer->p_timers, &exp->last_handshake);
+ exp->rx_bytes = counter_u64_fetch(peer->p_rx_bytes);
+ exp->tx_bytes = counter_u64_fetch(peer->p_tx_bytes);
+
+ exp->aip_count = 0;
+ CK_LIST_FOREACH(rt, &peer->p_aips, r_entry) {
+ exp->aip_count++;
+ }
+
+ /* Early success; no allowed-ips to copy out. */
+ if (exp->aip_count == 0)
+ return (0);
+
+ exp->aip = malloc(exp->aip_count * sizeof(*exp->aip), M_TEMP, M_NOWAIT);
+ if (exp->aip == NULL)
+ return (ENOMEM);
+
+ i = 0;
+ CK_LIST_FOREACH(rt, &peer->p_aips, r_entry) {
+ exp->aip[i].family = rt->r_addr.ss_family;
+ if (exp->aip[i].family == AF_INET) {
+ struct sockaddr_in *sin =
+ (struct sockaddr_in *)&rt->r_addr;
+
+ exp->aip[i].ip4 = sin->sin_addr;
+
+ sin = (struct sockaddr_in *)&rt->r_mask;
+ exp->aip[i].cidr = in_mask2len(&sin->sin_addr);
+ } else if (exp->aip[i].family == AF_INET6) {
+ struct sockaddr_in6 *sin6 =
+ (struct sockaddr_in6 *)&rt->r_addr;
+
+ exp->aip[i].ip6 = sin6->sin6_addr;
+
+ sin6 = (struct sockaddr_in6 *)&rt->r_mask;
+ exp->aip[i].cidr = in6_mask2len(&sin6->sin6_addr, NULL);
+ }
+ i++;
+ if (i == exp->aip_count)
+ break;
+ }
+
+	/* Again, the allowed-ips list might have shrunk; update the count. */
+ exp->aip_count = i;
+
+ return (0);
+}
+
+static nvlist_t *
+wg_peer_export_to_nvl(struct wg_softc *sc, struct wg_peer_export *exp)
+{
+ struct wg_timespec64 ts64;
+ nvlist_t *nvl, **nvl_aips;
+ size_t i;
+ uint16_t family;
+
+ nvl_aips = NULL;
+ if ((nvl = nvlist_create(0)) == NULL)
+ return (NULL);
+
+ nvlist_add_binary(nvl, "public-key", exp->public_key,
+ sizeof(exp->public_key));
+ if (wgc_privileged(sc))
+ nvlist_add_binary(nvl, "preshared-key", exp->preshared_key,
+ sizeof(exp->preshared_key));
+ if (exp->endpoint_sz != 0)
+ nvlist_add_binary(nvl, "endpoint", &exp->endpoint,
+ exp->endpoint_sz);
+
+ if (exp->aip_count != 0) {
+ nvl_aips = mallocarray(exp->aip_count, sizeof(*nvl_aips),
+ M_WG, M_WAITOK | M_ZERO);
+ }
+
+ for (i = 0; i < exp->aip_count; i++) {
+ nvl_aips[i] = nvlist_create(0);
+ if (nvl_aips[i] == NULL)
+ goto err;
+ family = exp->aip[i].family;
+ nvlist_add_number(nvl_aips[i], "cidr", exp->aip[i].cidr);
+ if (family == AF_INET)
+ nvlist_add_binary(nvl_aips[i], "ipv4",
+ &exp->aip[i].ip4, sizeof(exp->aip[i].ip4));
+ else if (family == AF_INET6)
+ nvlist_add_binary(nvl_aips[i], "ipv6",
+ &exp->aip[i].ip6, sizeof(exp->aip[i].ip6));
+ }
+
+ if (i != 0) {
+ nvlist_add_nvlist_array(nvl, "allowed-ips",
+ (const nvlist_t *const *)nvl_aips, i);
+ }
+
+ for (i = 0; i < exp->aip_count; ++i)
+ nvlist_destroy(nvl_aips[i]);
+
+ free(nvl_aips, M_WG);
+ nvl_aips = NULL;
+
+ ts64.tv_sec = exp->last_handshake.tv_sec;
+ ts64.tv_nsec = exp->last_handshake.tv_nsec;
+ nvlist_add_binary(nvl, "last-handshake-time", &ts64, sizeof(ts64));
+
+ if (exp->persistent_keepalive != 0)
+ nvlist_add_number(nvl, "persistent-keepalive-interval",
+ exp->persistent_keepalive);
+
+ if (exp->rx_bytes != 0)
+ nvlist_add_number(nvl, "rx-bytes", exp->rx_bytes);
+ if (exp->tx_bytes != 0)
+ nvlist_add_number(nvl, "tx-bytes", exp->tx_bytes);
+
+ return (nvl);
+err:
+ for (i = 0; i < exp->aip_count && nvl_aips[i] != NULL; i++) {
+ nvlist_destroy(nvl_aips[i]);
+ }
+
+ free(nvl_aips, M_WG);
+ nvlist_destroy(nvl);
+ return (NULL);
+}
+
+static int
+wg_marshal_peers(struct wg_softc *sc, nvlist_t **nvlp, nvlist_t ***nvl_arrayp, int *peer_countp)
+{
+ struct wg_peer *peer;
+ int err, i, peer_count;
+ nvlist_t *nvl, **nvl_array;
+ struct epoch_tracker et;
+ struct wg_peer_export *wpe;
+
+ nvl = NULL;
+ nvl_array = NULL;
+ if (nvl_arrayp)
+ *nvl_arrayp = NULL;
+ if (nvlp)
+ *nvlp = NULL;
+ if (peer_countp)
+ *peer_countp = 0;
+ peer_count = sc->sc_hashtable.h_num_peers;
+ if (peer_count == 0) {
+ return (ENOENT);
+ }
+
+ if (nvlp && (nvl = nvlist_create(0)) == NULL)
+ return (ENOMEM);
+
+ err = i = 0;
+ nvl_array = malloc(peer_count*sizeof(void*), M_TEMP, M_WAITOK | M_ZERO);
+ wpe = malloc(peer_count*sizeof(*wpe), M_TEMP, M_WAITOK | M_ZERO);
+
+ NET_EPOCH_ENTER(et);
+ CK_LIST_FOREACH(peer, &sc->sc_hashtable.h_peers_list, p_entry) {
+ if ((err = wg_peer_to_export(peer, &wpe[i])) != 0) {
+ break;
+ }
+
+ i++;
+ if (i == peer_count)
+ break;
+ }
+ NET_EPOCH_EXIT(et);
+
+ if (err != 0)
+ goto out;
+
+ /* Update the peer count, in case we found fewer entries. */
+ *peer_countp = peer_count = i;
+ if (peer_count == 0) {
+ err = ENOENT;
+ goto out;
+ }
+
+ for (i = 0; i < peer_count; i++) {
+ int idx;
+
+ /*
+ * Peers are added to the list in reverse order, effectively,
+ * because it's simpler/quicker to add at the head every time.
+ *
+		 * Export them in reverse order. If we fail mid-way through,
+		 * the cleanup below will DTRT.
+ */
+ idx = peer_count - i - 1;
+ nvl_array[idx] = wg_peer_export_to_nvl(sc, &wpe[i]);
+ if (nvl_array[idx] == NULL) {
+ break;
+ }
+ }
+
+ if (i < peer_count) {
+ /* Error! */
+ *peer_countp = 0;
+ err = ENOMEM;
+ } else if (nvl) {
+ nvlist_add_nvlist_array(nvl, "peers",
+ (const nvlist_t * const *)nvl_array, peer_count);
+ if ((err = nvlist_error(nvl))) {
+ goto out;
+ }
+ *nvlp = nvl;
+ }
+ *nvl_arrayp = nvl_array;
+ out:
+ if (err != 0) {
+ /* Note that nvl_array is populated in reverse order. */
+ for (i = 0; i < peer_count; i++) {
+ nvlist_destroy(nvl_array[i]);
+ }
+
+ free(nvl_array, M_TEMP);
+ if (nvl != NULL)
+ nvlist_destroy(nvl);
+ }
+
+ for (i = 0; i < peer_count; i++)
+ free(wpe[i].aip, M_TEMP);
+ free(wpe, M_TEMP);
+ return (err);
+}
+
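+/*
+ * SIOCGWG handler: marshal the interface and peer state into a packed nvlist
+ * and copy it out; when wgd_size is 0, only report the required buffer size.
+ */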
+static int
+wgc_get(struct wg_softc *sc, struct wg_data_io *wgd)
+{
+ nvlist_t *nvl, **nvl_array;
+ void *packed;
+ size_t size;
+ int peer_count, err;
+
+ nvl = nvlist_create(0);
+ if (nvl == NULL)
+ return (ENOMEM);
+
+ sx_slock(&sc->sc_lock);
+
+ err = 0;
+ packed = NULL;
+ if (sc->sc_socket.so_port != 0)
+ nvlist_add_number(nvl, "listen-port", sc->sc_socket.so_port);
+ if (sc->sc_socket.so_user_cookie != 0)
+ nvlist_add_number(nvl, "user-cookie", sc->sc_socket.so_user_cookie);
+ if (sc->sc_local.l_has_identity) {
+ nvlist_add_binary(nvl, "public-key", sc->sc_local.l_public, WG_KEY_SIZE);
+ if (wgc_privileged(sc))
+ nvlist_add_binary(nvl, "private-key", sc->sc_local.l_private, WG_KEY_SIZE);
+ }
+ if (sc->sc_hashtable.h_num_peers > 0) {
+ err = wg_marshal_peers(sc, NULL, &nvl_array, &peer_count);
+ if (err)
+ goto out_nvl;
+ nvlist_add_nvlist_array(nvl, "peers",
+ (const nvlist_t * const *)nvl_array, peer_count);
+ }
+ packed = nvlist_pack(nvl, &size);
+ if (packed == NULL) {
+ err = ENOMEM;
+ goto out_nvl;
+ }
+ if (wgd->wgd_size == 0) {
+ wgd->wgd_size = size;
+ goto out_packed;
+ }
+ if (wgd->wgd_size < size) {
+ err = ENOSPC;
+ goto out_packed;
+ }
+ if (wgd->wgd_data == NULL) {
+ err = EFAULT;
+ goto out_packed;
+ }
+ err = copyout(packed, wgd->wgd_data, size);
+ wgd->wgd_size = size;
+
+out_packed:
+ free(packed, M_NVLIST);
+out_nvl:
+ nvlist_destroy(nvl);
+ sx_sunlock(&sc->sc_lock);
+ return (err);
+}
+
+static int
+wg_ioctl(struct ifnet *ifp, u_long cmd, caddr_t data)
+{
+ struct wg_data_io *wgd = (struct wg_data_io *)data;
+ struct ifreq *ifr = (struct ifreq *)data;
+ struct wg_softc *sc = ifp->if_softc;
+ int ret = 0;
+
+ switch (cmd) {
+ case SIOCSWG:
+ ret = priv_check(curthread, PRIV_NET_WG);
+ if (ret == 0)
+ ret = wgc_set(sc, wgd);
+ break;
+ case SIOCGWG:
+ ret = wgc_get(sc, wgd);
+ break;
+ /* Interface IOCTLs */
+ case SIOCSIFADDR:
+ /*
+		 * This differs from *BSD norms, but is more consistent with
+		 * how WireGuard behaves elsewhere.
+ */
+ break;
+ case SIOCSIFFLAGS:
+ if ((ifp->if_flags & IFF_UP) != 0)
+ ret = wg_up(sc);
+ else
+ wg_down(sc);
+ break;
+ case SIOCSIFMTU:
+ if (ifr->ifr_mtu <= 0 || ifr->ifr_mtu > MAX_MTU)
+ ret = EINVAL;
+ else
+ ifp->if_mtu = ifr->ifr_mtu;
+ break;
+ case SIOCADDMULTI:
+ case SIOCDELMULTI:
+ break;
+ default:
+ ret = ENOTTY;
+ }
+
+ return ret;
+}
+
+static int
+wg_up(struct wg_softc *sc)
+{
+ struct wg_hashtable *ht = &sc->sc_hashtable;
+ struct ifnet *ifp = sc->sc_ifp;
+ struct wg_peer *peer;
+ int rc = EBUSY;
+
+ sx_xlock(&sc->sc_lock);
+ /* Jail's being removed, no more wg_up(). */
+ if ((sc->sc_flags & WGF_DYING) != 0)
+ goto out;
+
+ /* Silent success if we're already running. */
+ rc = 0;
+ if (ifp->if_drv_flags & IFF_DRV_RUNNING)
+ goto out;
+ ifp->if_drv_flags |= IFF_DRV_RUNNING;
+
+ rc = wg_socket_init(sc, sc->sc_socket.so_port);
+ if (rc == 0) {
+ mtx_lock(&ht->h_mtx);
+ CK_LIST_FOREACH(peer, &ht->h_peers_list, p_entry) {
+ wg_timers_enable(&peer->p_timers);
+ wg_queue_out(peer);
+ }
+ mtx_unlock(&ht->h_mtx);
+
+ if_link_state_change(sc->sc_ifp, LINK_STATE_UP);
+ } else {
+ ifp->if_drv_flags &= ~IFF_DRV_RUNNING;
+ }
+out:
+ sx_xunlock(&sc->sc_lock);
+ return (rc);
+}
+
+static void
+wg_down(struct wg_softc *sc)
+{
+ struct wg_hashtable *ht = &sc->sc_hashtable;
+ struct ifnet *ifp = sc->sc_ifp;
+ struct wg_peer *peer;
+
+ sx_xlock(&sc->sc_lock);
+ if (!(ifp->if_drv_flags & IFF_DRV_RUNNING)) {
+ sx_xunlock(&sc->sc_lock);
+ return;
+ }
+ ifp->if_drv_flags &= ~IFF_DRV_RUNNING;
+
+ mtx_lock(&ht->h_mtx);
+ CK_LIST_FOREACH(peer, &ht->h_peers_list, p_entry) {
+ wg_queue_purge(&peer->p_stage_queue);
+ wg_timers_disable(&peer->p_timers);
+ }
+ mtx_unlock(&ht->h_mtx);
+
+ mbufq_drain(&sc->sc_handshake_queue);
+
+ mtx_lock(&ht->h_mtx);
+ CK_LIST_FOREACH(peer, &ht->h_peers_list, p_entry) {
+ noise_remote_clear(&peer->p_remote);
+ wg_timers_event_reset_handshake_last_sent(&peer->p_timers);
+ }
+ mtx_unlock(&ht->h_mtx);
+
+ if_link_state_change(sc->sc_ifp, LINK_STATE_DOWN);
+ wg_socket_uninit(sc);
+
+ sx_xunlock(&sc->sc_lock);
+}
+
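+/*
+ * Allocate and attach one encrypt and one decrypt grouptask per CPU.
+ */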
+static void
+crypto_taskq_setup(struct wg_softc *sc)
+{
+
+ sc->sc_encrypt = malloc(sizeof(struct grouptask)*mp_ncpus, M_WG, M_WAITOK);
+ sc->sc_decrypt = malloc(sizeof(struct grouptask)*mp_ncpus, M_WG, M_WAITOK);
+
+ for (int i = 0; i < mp_ncpus; i++) {
+ GROUPTASK_INIT(&sc->sc_encrypt[i], 0,
+ (gtask_fn_t *)wg_softc_encrypt, sc);
+ taskqgroup_attach_cpu(qgroup_if_io_tqg, &sc->sc_encrypt[i], sc, i, NULL, NULL, "wg encrypt");
+ GROUPTASK_INIT(&sc->sc_decrypt[i], 0,
+ (gtask_fn_t *)wg_softc_decrypt, sc);
+ taskqgroup_attach_cpu(qgroup_if_io_tqg, &sc->sc_decrypt[i], sc, i, NULL, NULL, "wg decrypt");
+ }
+}
+
+static void
+crypto_taskq_destroy(struct wg_softc *sc)
+{
+ for (int i = 0; i < mp_ncpus; i++) {
+ taskqgroup_detach(qgroup_if_io_tqg, &sc->sc_encrypt[i]);
+ taskqgroup_detach(qgroup_if_io_tqg, &sc->sc_decrypt[i]);
+ }
+ free(sc->sc_encrypt, M_WG);
+ free(sc->sc_decrypt, M_WG);
+}
+
+static int
+wg_clone_create(struct if_clone *ifc, int unit, caddr_t params)
+{
+ struct wg_softc *sc;
+ struct ifnet *ifp;
+ struct noise_upcall noise_upcall;
+
+ sc = malloc(sizeof(*sc), M_WG, M_WAITOK | M_ZERO);
+ sc->sc_ucred = crhold(curthread->td_ucred);
+ ifp = sc->sc_ifp = if_alloc(IFT_WIREGUARD);
+ ifp->if_softc = sc;
+ if_initname(ifp, wgname, unit);
+
+ noise_upcall.u_arg = sc;
+ noise_upcall.u_remote_get =
+ (struct noise_remote *(*)(void *, uint8_t *))wg_remote_get;
+ noise_upcall.u_index_set =
+ (uint32_t (*)(void *, struct noise_remote *))wg_index_set;
+ noise_upcall.u_index_drop =
+ (void (*)(void *, uint32_t))wg_index_drop;
+ noise_local_init(&sc->sc_local, &noise_upcall);
+ cookie_checker_init(&sc->sc_cookie, ratelimit_zone);
+
+ sc->sc_socket.so_port = 0;
+
+ atomic_add_int(&clone_count, 1);
+ ifp->if_capabilities = ifp->if_capenable = WG_CAPS;
+
+ mbufq_init(&sc->sc_handshake_queue, MAX_QUEUED_HANDSHAKES);
+ sx_init(&sc->sc_lock, "wg softc lock");
+ rw_init(&sc->sc_index_lock, "wg index lock");
+ sc->sc_peer_count = 0;
+ sc->sc_encap_ring = buf_ring_alloc(MAX_QUEUED_PKT, M_WG, M_WAITOK, NULL);
+ sc->sc_decap_ring = buf_ring_alloc(MAX_QUEUED_PKT, M_WG, M_WAITOK, NULL);
+ GROUPTASK_INIT(&sc->sc_handshake, 0,
+ (gtask_fn_t *)wg_softc_handshake_receive, sc);
+ taskqgroup_attach(qgroup_if_io_tqg, &sc->sc_handshake, sc, NULL, NULL, "wg tx initiation");
+ crypto_taskq_setup(sc);
+
+ wg_hashtable_init(&sc->sc_hashtable);
+ sc->sc_index = hashinit(HASHTABLE_INDEX_SIZE, M_DEVBUF, &sc->sc_index_mask);
+ wg_aip_init(&sc->sc_aips);
+
+ if_setmtu(ifp, ETHERMTU - 80);
+ ifp->if_flags = IFF_BROADCAST | IFF_MULTICAST | IFF_NOARP;
+ ifp->if_init = wg_init;
+ ifp->if_reassign = wg_reassign;
+ ifp->if_qflush = wg_qflush;
+ ifp->if_transmit = wg_transmit;
+ ifp->if_output = wg_output;
+ ifp->if_ioctl = wg_ioctl;
+
+ if_attach(ifp);
+ bpfattach(ifp, DLT_NULL, sizeof(uint32_t));
+
+ sx_xlock(&wg_sx);
+ LIST_INSERT_HEAD(&wg_list, sc, sc_entry);
+ sx_xunlock(&wg_sx);
+
+ return 0;
+}
+
+static void
+wg_clone_destroy(struct ifnet *ifp)
+{
+ struct wg_softc *sc = ifp->if_softc;
+ struct ucred *cred;
+
+ sx_xlock(&wg_sx);
+ sx_xlock(&sc->sc_lock);
+ sc->sc_flags |= WGF_DYING;
+ cred = sc->sc_ucred;
+ sc->sc_ucred = NULL;
+ sx_xunlock(&sc->sc_lock);
+ LIST_REMOVE(sc, sc_entry);
+ sx_xunlock(&wg_sx);
+
+ if_link_state_change(sc->sc_ifp, LINK_STATE_DOWN);
+
+ sx_xlock(&sc->sc_lock);
+ wg_socket_uninit(sc);
+ sx_xunlock(&sc->sc_lock);
+
+ /*
+	 * There is no guarantee that all traffic has passed until the epoch
+	 * has elapsed with the socket closed.
+ */
+ NET_EPOCH_WAIT();
+
+ taskqgroup_drain_all(qgroup_if_io_tqg);
+ sx_xlock(&sc->sc_lock);
+ wg_peer_remove_all(sc);
+ epoch_drain_callbacks(net_epoch_preempt);
+ sx_xunlock(&sc->sc_lock);
+ sx_destroy(&sc->sc_lock);
+ rw_destroy(&sc->sc_index_lock);
+ taskqgroup_detach(qgroup_if_io_tqg, &sc->sc_handshake);
+ crypto_taskq_destroy(sc);
+ buf_ring_free(sc->sc_encap_ring, M_WG);
+ buf_ring_free(sc->sc_decap_ring, M_WG);
+
+ wg_aip_destroy(&sc->sc_aips);
+ wg_hashtable_destroy(&sc->sc_hashtable);
+
+ if (cred != NULL)
+ crfree(cred);
+ if_detach(sc->sc_ifp);
+ if_free(sc->sc_ifp);
+ /* Ensure any local/private keys are cleaned up */
+ explicit_bzero(sc, sizeof(*sc));
+ free(sc, M_WG);
+
+ atomic_add_int(&clone_count, -1);
+}
+
+static void
+wg_qflush(struct ifnet *ifp __unused)
+{
+}
+
+/*
+ * Privileged information (private-key, preshared-key) is only exported to
+ * root and jailed root by default.
+ */
+static bool
+wgc_privileged(struct wg_softc *sc)
+{
+ struct thread *td;
+
+ td = curthread;
+ return (priv_check(td, PRIV_NET_WG) == 0);
+}
+
+static void
+wg_reassign(struct ifnet *ifp, struct vnet *new_vnet __unused,
+ char *unused __unused)
+{
+ struct wg_softc *sc;
+
+ sc = ifp->if_softc;
+ wg_down(sc);
+}
+
+static void
+wg_init(void *xsc)
+{
+ struct wg_softc *sc;
+
+ sc = xsc;
+ wg_up(sc);
+}
+
+static void
+vnet_wg_init(const void *unused __unused)
+{
+
+ V_wg_cloner = if_clone_simple(wgname, wg_clone_create, wg_clone_destroy,
+ 0);
+}
+VNET_SYSINIT(vnet_wg_init, SI_SUB_PROTO_IFATTACHDOMAIN, SI_ORDER_ANY,
+ vnet_wg_init, NULL);
+
+static void
+vnet_wg_uninit(const void *unused __unused)
+{
+
+ if_clone_detach(V_wg_cloner);
+}
+VNET_SYSUNINIT(vnet_wg_uninit, SI_SUB_PROTO_IFATTACHDOMAIN, SI_ORDER_ANY,
+ vnet_wg_uninit, NULL);
+
+static int
+wg_prison_remove(void *obj, void *data __unused)
+{
+ const struct prison *pr = obj;
+ struct wg_softc *sc;
+ struct ucred *cred;
+ bool dying;
+
+ /*
+ * Do a pass through all if_wg interfaces and release creds on any from
+ * the jail that are supposed to be going away. This will, in turn, let
+ * the jail die so that we don't end up with Schrödinger's jail.
+ */
+ sx_slock(&wg_sx);
+ LIST_FOREACH(sc, &wg_list, sc_entry) {
+ cred = NULL;
+
+ sx_xlock(&sc->sc_lock);
+ dying = (sc->sc_flags & WGF_DYING) != 0;
+ if (!dying && sc->sc_ucred != NULL &&
+ sc->sc_ucred->cr_prison == pr) {
+ /* Home jail is going away. */
+ cred = sc->sc_ucred;
+ sc->sc_ucred = NULL;
+
+ sc->sc_flags |= WGF_DYING;
+ }
+
+ /*
+ * If this is our foreign vnet going away, we'll also down the
+ * link and kill the socket because traffic needs to stop. Any
+ * address will be revoked in the rehoming process.
+ */
+ if (cred != NULL || (!dying &&
+ sc->sc_ifp->if_vnet == pr->pr_vnet)) {
+ if_link_state_change(sc->sc_ifp, LINK_STATE_DOWN);
+ /* Have to kill the sockets, as they also hold refs. */
+ wg_socket_uninit(sc);
+ }
+
+ sx_xunlock(&sc->sc_lock);
+
+ if (cred != NULL) {
+ CURVNET_SET(sc->sc_ifp->if_vnet);
+ if_purgeaddrs(sc->sc_ifp);
+ CURVNET_RESTORE();
+ crfree(cred);
+ }
+ }
+ sx_sunlock(&wg_sx);
+
+ return (0);
+}
+
+static void
+wg_module_init(void)
+{
+ osd_method_t methods[PR_MAXMETHOD] = {
+ [PR_METHOD_REMOVE] = wg_prison_remove,
+ };
+
+ ratelimit_zone = uma_zcreate("wg ratelimit", sizeof(struct ratelimit),
+ NULL, NULL, NULL, NULL, 0, 0);
+ wg_osd_jail_slot = osd_jail_register(NULL, methods);
+}
+
+static void
+wg_module_deinit(void)
+{
+
+ uma_zdestroy(ratelimit_zone);
+ osd_jail_deregister(wg_osd_jail_slot);
+
+ MPASS(LIST_EMPTY(&wg_list));
+}
+
+static int
+wg_module_event_handler(module_t mod, int what, void *arg)
+{
+
+ switch (what) {
+ case MOD_LOAD:
+ wg_module_init();
+ break;
+ case MOD_UNLOAD:
+ if (atomic_load_int(&clone_count) == 0)
+ wg_module_deinit();
+ else
+ return (EBUSY);
+ break;
+ default:
+ return (EOPNOTSUPP);
+ }
+ return (0);
+}
+
+static moduledata_t wg_moduledata = {
+ "wg",
+ wg_module_event_handler,
+ NULL
+};
+
+DECLARE_MODULE(wg, wg_moduledata, SI_SUB_PSEUDO, SI_ORDER_ANY);
+MODULE_VERSION(wg, 1);
+MODULE_DEPEND(wg, crypto, 1, 1, 1);