aboutsummaryrefslogtreecommitdiff
path: root/tests
diff options
context:
space:
mode:
authorAlexander V. Chernikov <melifaro@FreeBSD.org>2022-06-25 19:01:45 +0000
committerAlexander V. Chernikov <melifaro@FreeBSD.org>2022-06-25 19:25:15 +0000
commit8eb2bee6c0f4957c6c1cea826e59cda4d18a2a64 (patch)
tree8a6481d536e076810de128b0ba49cece8b671554 /tests
parentc38da70c28a886cc31a2f009baa79deb7fceec88 (diff)
downloadsrc-8eb2bee6c0f4957c6c1cea826e59cda4d18a2a64.tar.gz
src-8eb2bee6c0f4957c6c1cea826e59cda4d18a2a64.zip
testing: Add basic atf support to pytest.
Implementation consists of a pytest plugin that implements the ATF format and a simple C++ wrapper that reorders the provided arguments from the ATF format into a format pytest understands. Each test has this wrapper specified after the shebang. When kyua executes the test, the wrapper calls pytest, which loads the atf plugin, does the work and returns the result. Additionally, a separate python "package", `/usr/tests/atf_python`, has been added to collect code that may be useful across different tests. Current limitations: * Opaque metadata passing via X-Name properties. Requires some fixtures to be written * The `-s srcdir` parameter passed by the runner is ignored. * No `atf-c-api(3)` or similar - relying on the pytest framework & existing python libraries * No support for `atf_tc_<get|has>_config_var()` & `atf_tc_set_md_var()`. These can probably be implemented with env variables & autoload fixtures Differential Revision: https://reviews.freebsd.org/D31084 Reviewed by: kp, ngie
Diffstat (limited to 'tests')
-rw-r--r--tests/Makefile4
-rw-r--r--tests/__init__.py0
-rw-r--r--tests/atf_python/Makefile12
-rw-r--r--tests/atf_python/__init__.py4
-rw-r--r--tests/atf_python/atf_pytest.py218
-rw-r--r--tests/atf_python/sys/Makefile11
-rw-r--r--tests/atf_python/sys/__init__.py0
-rw-r--r--tests/atf_python/sys/net/Makefile10
-rw-r--r--tests/atf_python/sys/net/__init__.py0
-rwxr-xr-xtests/atf_python/sys/net/rtsock.py604
-rw-r--r--tests/atf_python/sys/net/tools.py33
-rw-r--r--tests/atf_python/sys/net/vnet.py203
-rw-r--r--tests/conftest.py121
-rw-r--r--tests/freebsd_test_suite/Makefile13
-rw-r--r--tests/freebsd_test_suite/atf_pytest_wrapper.cpp192
15 files changed, 1424 insertions, 1 deletions
diff --git a/tests/Makefile b/tests/Makefile
index 561a0ec5fcab..cfd065d61539 100644
--- a/tests/Makefile
+++ b/tests/Makefile
@@ -4,12 +4,14 @@ PACKAGE= tests
TESTSDIR= ${TESTSBASE}
-${PACKAGE}FILES+= README
+${PACKAGE}FILES+= README __init__.py conftest.py
KYUAFILE= yes
SUBDIR+= etc
SUBDIR+= sys
+SUBDIR+= atf_python
+SUBDIR+= freebsd_test_suite
SUBDIR_PARALLEL=
diff --git a/tests/__init__.py b/tests/__init__.py
new file mode 100644
index 000000000000..e69de29bb2d1
--- /dev/null
+++ b/tests/__init__.py
diff --git a/tests/atf_python/Makefile b/tests/atf_python/Makefile
new file mode 100644
index 000000000000..26d419743257
--- /dev/null
+++ b/tests/atf_python/Makefile
@@ -0,0 +1,12 @@
+.include <src.opts.mk>
+
+.PATH: ${.CURDIR}
+
+FILES= __init__.py atf_pytest.py
+SUBDIR= sys
+
+.include <bsd.own.mk>
+FILESDIR= ${TESTSBASE}/atf_python
+
+
+.include <bsd.prog.mk>
diff --git a/tests/atf_python/__init__.py b/tests/atf_python/__init__.py
new file mode 100644
index 000000000000..6d5ec22ef054
--- /dev/null
+++ b/tests/atf_python/__init__.py
@@ -0,0 +1,4 @@
+import pytest
+
+pytest.register_assert_rewrite("atf_python.sys.net.rtsock")
+pytest.register_assert_rewrite("atf_python.sys.net.vnet")
diff --git a/tests/atf_python/atf_pytest.py b/tests/atf_python/atf_pytest.py
new file mode 100644
index 000000000000..89c0e3a515b9
--- /dev/null
+++ b/tests/atf_python/atf_pytest.py
@@ -0,0 +1,218 @@
+import types
+from typing import Any
+from typing import Dict
+from typing import List
+from typing import NamedTuple
+from typing import Tuple
+
+import pytest
+
+
class ATFCleanupItem(pytest.Item):
    """Donor of the methods used when only a test's cleanup routine must run.

    ATFHandler.override_runtest() binds runtest() onto a collected item and
    installs the two no-op methods on the item's test class.
    """

    def runtest(self):
        """Runs cleanup procedure for the test instead of the test"""
        # A fresh instance of the test class is created; cleanup() receives
        # the node id of the item being "run".
        self.parent.cls().cleanup(self.nodeid)

    def setup_method_noop(self, method):
        """Overrides runtest setup method"""

    def teardown_method_noop(self, method):
        """Overrides runtest teardown method"""
+
+
class ATFTestObj(object):
    """ATF-format view of one collected pytest test item.

    The wrapped object must provide ``nodeid``, ``name``, ``function`` and
    ``iter_markers()``.
    """

    def __init__(self, obj, has_cleanup):
        # Use nodeid without name to properly name class-derived tests
        self.ident = obj.nodeid.split("::", 1)[1]
        self.description = self._get_test_description(obj)
        self.has_cleanup = has_cleanup
        self.obj = obj

    def _get_test_description(self, obj):
        """Returns first non-blank docstring line (stripped) or the test name"""
        docstr = obj.function.__doc__
        if docstr:
            for line in docstr.splitlines():
                # Docstring lines carry source indentation; a line made only
                # of whitespace must not be taken as the description.
                text = line.strip()
                if text:
                    return text
        return obj.name

    def _convert_marks(self, obj) -> Dict[str, Any]:
        """Translate the recognised pytest marks into ATF metadata pairs.

        Marks not listed in the table are ignored, as are recognised marks
        that carry no positional argument.
        """
        wj_func = lambda x: " ".join(x)  # noqa: E731
        _map: Dict[str, Dict] = {
            "require_user": {"name": "require.user"},
            "require_arch": {"name": "require.arch", "fmt": wj_func},
            "require_diskspace": {"name": "require.diskspace"},
            "require_files": {"name": "require.files", "fmt": wj_func},
            "require_machine": {"name": "require.machine", "fmt": wj_func},
            "require_memory": {"name": "require.memory"},
            "require_progs": {"name": "require.progs", "fmt": wj_func},
            "timeout": {},
        }
        ret: Dict[str, Any] = {}
        for mark in obj.iter_markers():
            spec = _map.get(mark.name)
            if spec is None or not mark.args:
                continue
            name = spec.get("name", mark.name)
            if name in ret:
                # The first occurrence wins.
                # NOTE(review): assumes iter_markers() yields the closest
                # (function-level) mark first - confirm against pytest docs.
                continue
            fmt = spec.get("fmt")
            ret[name] = fmt(mark.args[0]) if fmt else mark.args[0]
        return ret

    def as_lines(self) -> List[str]:
        """Output test definition in ATF-specific format"""
        ret = [
            "ident: {}".format(self.ident),
            "descr: {}".format(self.description),
        ]
        if self.has_cleanup:
            ret.append("has.cleanup: true")
        for key, value in self._convert_marks(self.obj).items():
            ret.append("{}: {}".format(key, value))
        return ret
+
+
class ATFHandler(object):
    """Turns pytest collection results and test reports into ATF output.

    list_tests() prints the ATF test-program listing; add_report() folds the
    per-stage pytest reports into one ATF state per test; write_report()
    stores the resulting status line in a file.
    """

    class ReportState(NamedTuple):
        state: str
        reason: str

    def __init__(self):
        # test name -> last recorded (state, reason)
        self._tests_state_map: Dict[str, "ATFHandler.ReportState"] = {}

    def override_runtest(self, obj):
        """Make the item run its class cleanup() instead of the test body."""
        # Override basic runtest command
        obj.runtest = types.MethodType(ATFCleanupItem.runtest, obj)
        # Override class setup/teardown
        obj.parent.cls.setup_method = ATFCleanupItem.setup_method_noop
        obj.parent.cls.teardown_method = ATFCleanupItem.teardown_method_noop

    def get_object_cleanup_class(self, obj):
        """Return the item's test class if it defines ``cleanup``, else None."""
        parent = getattr(obj, "parent", None)
        cls = getattr(parent, "cls", None)
        if cls is not None and hasattr(cls, "cleanup"):
            return cls
        return None

    def has_object_cleanup(self, obj):
        """True when the item's test class defines a ``cleanup`` attribute."""
        return self.get_object_cleanup_class(obj) is not None

    def list_tests(self, tests: List[Any]):
        """Print the ATF test list (header plus one stanza per test item)."""
        print('Content-Type: application/X-atf-tp; version="1"')
        print()
        for test_obj in tests:
            has_cleanup = self.has_object_cleanup(test_obj)
            atf_test = ATFTestObj(test_obj, has_cleanup)
            for line in atf_test.as_lines():
                print(line)
            print()

    def set_report_state(self, test_name: str, state: str, reason: str):
        """Record (overwriting any previous value) the state of a test."""
        self._tests_state_map[test_name] = self.ReportState(state, reason)

    def _extract_report_reason(self, report):
        """Return a one-line reason taken from ``report.longrepr`` or None."""
        data = report.longrepr
        if data is None:
            return None
        if isinstance(data, tuple):
            # ('/path/to/test.py', 23, 'Skipped: unable to test')
            reason = data[2]
            # Must be a tuple of prefixes: iterating a bare string would
            # strip matching single characters from unrelated reasons.
            for prefix in ("Skipped: ",):
                if reason.startswith(prefix):
                    reason = reason[len(prefix):]
            return reason
        # string/ traceback / exception report. Capture the last line
        return str(data).split("\n")[-1]

    def add_report(self, report):
        """Fold one pytest stage report into the ATF state of its test."""
        # MAP pytest report state to the atf-desired state
        #
        # ATF test states:
        # (1) expected_death, (2) expected_exit, (3) expected_failure
        # (4) expected_signal, (5) expected_timeout, (6) passed
        # (7) skipped, (8) failed
        #
        # Note that ATF doesn't have the concept of "soft xfail" - xpass
        # is a failure. It also calls the teardown routine in a separate
        # process, thus teardown states (pytest-only) are handled as
        # body continuation.

        # (stage, state, wasxfail)

        # Just a passing test: WANT: passed
        # GOT: (setup, passed, F), (call, passed, F), (teardown, passed, F)
        #
        # Failing body test: WANT: failed
        # GOT: (setup, passed, F), (call, failed, F), (teardown, passed, F)
        #
        # pytest.skip test decorator: WANT: skipped
        # GOT: (setup,skipped, False), (teardown, passed, False)
        #
        # pytest.skip call inside test function: WANT: skipped
        # GOT: (setup, passed, F), (call, skipped, F), (teardown,passed, F)
        #
        # mark.xfail decorator+pytest.xfail: WANT: expected_failure
        # GOT: (setup, passed, F), (call, skipped, T), (teardown, passed, F)
        #
        # mark.xfail decorator+pass: WANT: failed
        # GOT: (setup, passed, F), (call, passed, T), (teardown, passed, F)

        test_name = report.location[2]
        stage = report.when
        state = report.outcome
        reason = self._extract_report_reason(report)
        # The xfail marker lives on the report object, not on the reason text
        is_xfail = hasattr(report, "wasxfail")

        # We don't care about strict xfail - it gets translated to False

        if stage == "setup":
            if state in ("skipped", "failed"):
                # failed init -> failed test, skipped setup -> xskip
                # for the whole test
                self.set_report_state(test_name, state, reason)
        elif stage == "call":
            # "call" stage shouldn't matter if setup failed
            previous = self._tests_state_map.get(test_name)
            if previous is not None and previous.state == "failed":
                return
            if state == "failed":
                # Record failure & override "skipped" state
                self.set_report_state(test_name, state, reason)
            elif state == "skipped":
                if is_xfail:
                    # xfail() called in the test body
                    state = "expected_failure"
                # otherwise: skip inside the body
                self.set_report_state(test_name, state, reason)
            elif state == "passed":
                if is_xfail:
                    # the test was expected to fail but didn't
                    # mark as hard failure
                    state = "failed"
                    if reason is None:
                        # Without this the status file would read "failed: None"
                        reason = "unexpected pass: {}".format(report.wasxfail)
                self.set_report_state(test_name, state, reason)
        elif stage == "teardown":
            if state == "failed":
                # teardown should be empty, as the cleanup
                # procedures should be implemented as a separate
                # function/method, so mark teardown failure as
                # global failure
                self.set_report_state(test_name, state, reason)

    def write_report(self, path):
        """Write the status line of the first recorded test to ``path``.

        Nothing is written when no test state has been recorded.
        """
        if self._tests_state_map:
            # If we're executing in ATF mode, there has to be just one test
            # Anyway, deterministically pick the first one
            first_test_name = next(iter(self._tests_state_map))
            test = self._tests_state_map[first_test_name]
            if test.state == "passed":
                line = test.state
            else:
                line = "{}: {}".format(test.state, test.reason)
            with open(path, mode="w") as f:
                print(line, file=f)
diff --git a/tests/atf_python/sys/Makefile b/tests/atf_python/sys/Makefile
new file mode 100644
index 000000000000..ff4cf17b85d2
--- /dev/null
+++ b/tests/atf_python/sys/Makefile
@@ -0,0 +1,11 @@
+.include <src.opts.mk>
+
+.PATH: ${.CURDIR}
+
+FILES= __init__.py
+SUBDIR= net
+
+.include <bsd.own.mk>
+FILESDIR= ${TESTSBASE}/atf_python/sys
+
+.include <bsd.prog.mk>
diff --git a/tests/atf_python/sys/__init__.py b/tests/atf_python/sys/__init__.py
new file mode 100644
index 000000000000..e69de29bb2d1
--- /dev/null
+++ b/tests/atf_python/sys/__init__.py
diff --git a/tests/atf_python/sys/net/Makefile b/tests/atf_python/sys/net/Makefile
new file mode 100644
index 000000000000..05b1d8afe863
--- /dev/null
+++ b/tests/atf_python/sys/net/Makefile
@@ -0,0 +1,10 @@
+.include <src.opts.mk>
+
+.PATH: ${.CURDIR}
+
+FILES= __init__.py rtsock.py tools.py vnet.py
+
+.include <bsd.own.mk>
+FILESDIR= ${TESTSBASE}/atf_python/sys/net
+
+.include <bsd.prog.mk>
diff --git a/tests/atf_python/sys/net/__init__.py b/tests/atf_python/sys/net/__init__.py
new file mode 100644
index 000000000000..e69de29bb2d1
--- /dev/null
+++ b/tests/atf_python/sys/net/__init__.py
diff --git a/tests/atf_python/sys/net/rtsock.py b/tests/atf_python/sys/net/rtsock.py
new file mode 100755
index 000000000000..788e863f8b28
--- /dev/null
+++ b/tests/atf_python/sys/net/rtsock.py
@@ -0,0 +1,604 @@
+#!/usr/local/bin/python3
+import os
+import socket
+import struct
+import sys
+from ctypes import c_byte
+from ctypes import c_char
+from ctypes import c_int
+from ctypes import c_long
+from ctypes import c_uint32
+from ctypes import c_ulong
+from ctypes import c_ushort
+from ctypes import sizeof
+from ctypes import Structure
+from typing import Dict
+from typing import List
+from typing import Optional
+from typing import Union
+
+
def roundup2(val: int, num: int) -> int:
    """Round ``val`` up to a multiple of ``num``.

    A value that is already a multiple is returned unchanged. The bit trick
    used for the other case only yields the next multiple when ``num`` is a
    power of two.
    """
    if val % num == 0:
        return val
    return (val | (num - 1)) + 1
+
+
class RtSockException(OSError):
    """OSError subclass raised by the sockaddr printers on undersized data."""
+
+
class RtConst:
    """Routing-socket constants plus helpers mapping values back to names.

    The lookup helpers work purely on attribute-name prefixes ("RTM_",
    "RTA_", "RTF_", "RTV_", "AF_") of this class.
    """

    RTM_VERSION = 5
    ALIGN = sizeof(c_long)

    AF_INET = socket.AF_INET
    AF_INET6 = socket.AF_INET6
    # Python only exposes AF_LINK where the platform defines it; fall back to
    # the FreeBSD value so the module can still be imported elsewhere.
    AF_LINK = getattr(socket, "AF_LINK", 18)

    RTA_DST = 0x1
    RTA_GATEWAY = 0x2
    RTA_NETMASK = 0x4
    RTA_GENMASK = 0x8
    RTA_IFP = 0x10
    RTA_IFA = 0x20
    RTA_AUTHOR = 0x40
    RTA_BRD = 0x80

    RTM_ADD = 1
    RTM_DELETE = 2
    RTM_CHANGE = 3
    RTM_GET = 4

    RTF_UP = 0x1
    RTF_GATEWAY = 0x2
    RTF_HOST = 0x4
    RTF_REJECT = 0x8
    RTF_DYNAMIC = 0x10
    RTF_MODIFIED = 0x20
    RTF_DONE = 0x40
    RTF_XRESOLVE = 0x200
    RTF_LLINFO = 0x400
    RTF_LLDATA = 0x400
    RTF_STATIC = 0x800
    RTF_BLACKHOLE = 0x1000
    RTF_PROTO2 = 0x4000
    RTF_PROTO1 = 0x8000
    RTF_PROTO3 = 0x40000
    RTF_FIXEDMTU = 0x80000
    RTF_PINNED = 0x100000
    RTF_LOCAL = 0x200000
    RTF_BROADCAST = 0x400000
    RTF_MULTICAST = 0x800000
    RTF_STICKY = 0x10000000
    RTF_RNH_LOCKED = 0x40000000
    RTF_GWFLAG_COMPAT = 0x80000000

    RTV_MTU = 0x1
    RTV_HOPCOUNT = 0x2
    RTV_EXPIRE = 0x4
    RTV_RPIPE = 0x8
    RTV_SPIPE = 0x10
    RTV_SSTHRESH = 0x20
    RTV_RTT = 0x40
    RTV_RTTVAR = 0x80
    RTV_WEIGHT = 0x100

    @staticmethod
    def get_props(prefix: str) -> List[str]:
        """Return the names of all class attributes starting with ``prefix``."""
        return [n for n in dir(RtConst) if n.startswith(prefix)]

    @staticmethod
    def get_name(prefix: str, value: int) -> str:
        """Return the first ``prefix``-named constant equal to ``value``.

        Falls back to ``"U:<prefix>:<value>"`` when nothing matches.
        """
        for prop in RtConst.get_props(prefix):
            if getattr(RtConst, prop) == value:
                return prop
        return "U:{}:{}".format(prefix, value)

    @staticmethod
    def get_bitmask_map(prefix: str, value: int) -> Dict[int, str]:
        """Split ``value`` into its set bits, mapping each bit to a name.

        Bits without a matching ``prefix`` constant map to their hex string.
        """
        if value < 0:
            # Signed ctypes fields (rtm_flags is a c_int) report bit 31 as a
            # negative number; reinterpret it as unsigned 32-bit, otherwise
            # the loop below would never reach zero.
            value &= 0xFFFFFFFF
        propmap = {getattr(RtConst, prop): prop for prop in RtConst.get_props(prefix)}
        ret: Dict[int, str] = {}
        bit = 1
        while value:
            if value & bit:
                ret[bit] = propmap.get(bit, hex(bit))
                value -= bit
            bit <<= 1
        return ret

    @staticmethod
    def get_bitmask_str(prefix: str, value: int) -> str:
        """Comma-joined names of the bits set in ``value``, lowest bit first."""
        return ",".join(RtConst.get_bitmask_map(prefix, value).values())
+
+
class RtMetrics(Structure):
    # Route metrics block; embedded in RtMsgHdr as the "rtm_rmx" field.
    # Every member is an unsigned long, the last one a two-element filler.
    # NOTE(review): presumably mirrors struct rt_metrics from net/route.h -
    # confirm the layout against the kernel header.
    _fields_ = [
        ("rmx_locks", c_ulong),
        ("rmx_mtu", c_ulong),
        ("rmx_hopcount", c_ulong),
        ("rmx_expire", c_ulong),
        ("rmx_recvpipe", c_ulong),
        ("rmx_sendpipe", c_ulong),
        ("rmx_ssthresh", c_ulong),
        ("rmx_rtt", c_ulong),
        ("rmx_rttvar", c_ulong),
        ("rmx_pksent", c_ulong),
        ("rmx_weight", c_ulong),
        ("rmx_nhidx", c_ulong),
        ("rmx_filler", c_ulong * 2),
    ]
+
+
class RtMsgHdr(Structure):
    # Fixed-size header at the start of a routing-socket message; the
    # sockaddr attributes selected by rtm_addrs follow it in the buffer.
    # rtm_type sits at byte offset 3 (after a ushort and one byte).
    # NOTE(review): presumably mirrors struct rt_msghdr - confirm against
    # net/route.h.
    _fields_ = [
        ("rtm_msglen", c_ushort),
        ("rtm_version", c_byte),
        ("rtm_type", c_byte),
        ("rtm_index", c_ushort),
        ("_rtm_spare1", c_ushort),
        ("rtm_flags", c_int),
        ("rtm_addrs", c_int),
        ("rtm_pid", c_int),
        ("rtm_seq", c_int),
        ("rtm_errno", c_int),
        ("rtm_fmask", c_int),
        ("rtm_inits", c_ulong),
        ("rtm_rmx", RtMetrics),
    ]
+
+
class SockaddrIn(Structure):
    # IPv4 socket address: length byte, family byte, port, 32-bit address
    # and 8 bytes of zero padding (16 bytes in total).
    _fields_ = [
        ("sin_len", c_byte),
        ("sin_family", c_byte),
        ("sin_port", c_ushort),
        ("sin_addr", c_uint32),
        ("sin_zero", c_char * 8),
    ]
+
+
class SockaddrIn6(Structure):
    # IPv6 socket address: length byte, family byte, port, flow info,
    # 16 address bytes and a 32-bit scope id.
    _fields_ = [
        ("sin6_len", c_byte),
        ("sin6_family", c_byte),
        ("sin6_port", c_ushort),
        ("sin6_flowinfo", c_uint32),
        ("sin6_addr", c_byte * 16),
        ("sin6_scope_id", c_uint32),
    ]
+
+
class SockaddrDl(Structure):
    # Link-level socket address; sdl_data starts at byte offset 8 and is
    # declared here with 8 bytes.
    # NOTE(review): the C struct sockaddr_dl may declare a larger sdl_data -
    # confirm against net/if_dl.h before relying on sizeof().
    _fields_ = [
        ("sdl_len", c_byte),
        ("sdl_family", c_byte),
        ("sdl_index", c_ushort),
        ("sdl_type", c_byte),
        ("sdl_nlen", c_byte),
        ("sdl_alen", c_byte),
        ("sdl_slen", c_byte),
        ("sdl_data", c_byte * 8),
    ]
+
+
class SaHelper(object):
    """Builders and pretty-printers for raw sockaddr byte strings."""

    @staticmethod
    def is_ipv6(ip: str) -> bool:
        """Treat any address text containing ':' as IPv6."""
        return ":" in ip

    @staticmethod
    def ip_sa(ip: str, scopeid: int = 0) -> bytes:
        """Build a sockaddr for an IPv4 or IPv6 address string."""
        if SaHelper.is_ipv6(ip):
            return SaHelper.ip6_sa(ip, scopeid)
        return SaHelper.ip4_sa(ip)

    @staticmethod
    def ip4_sa(ip: str) -> bytes:
        """Build the bytes of a SockaddrIn for ``ip`` (port 0)."""
        packed = socket.inet_pton(socket.AF_INET, ip)
        # The packed address is network order; reading it with the host byte
        # order makes the c_uint32 field store the same bytes back.
        addr_int = int.from_bytes(packed, sys.byteorder)
        sin = SockaddrIn(sizeof(SockaddrIn), socket.AF_INET, 0, addr_int)
        return bytes(sin)

    @staticmethod
    def ip6_sa(ip6: str, scopeid: int) -> bytes:
        """Build the bytes of a SockaddrIn6 for ``ip6`` with ``scopeid``."""
        packed = socket.inet_pton(socket.AF_INET6, ip6)
        addr_bytes = (c_byte * 16).from_buffer_copy(packed)
        sin6 = SockaddrIn6(
            sizeof(SockaddrIn6), socket.AF_INET6, 0, 0, addr_bytes, scopeid
        )
        return bytes(sin6)

    @staticmethod
    def link_sa(ifindex: int = 0, iftype: int = 0) -> bytes:
        """Build the bytes of a SockaddrDl carrying only index and type."""
        sa = SockaddrDl(sizeof(SockaddrDl), socket.AF_LINK, ifindex, iftype)
        return bytes(sa)

    @staticmethod
    def pxlen4_sa(pxlen: int) -> bytes:
        """IPv4 netmask sockaddr for a prefix length."""
        return SaHelper.ip_sa(SaHelper.pxlen_to_ip4(pxlen))

    @staticmethod
    def pxlen_to_ip4(pxlen: int) -> str:
        """Dotted-quad netmask for a prefix length in 0..32.

        Raises ValueError for a prefix length outside that range.
        """
        if not 0 <= pxlen <= 32:
            raise ValueError("invalid IPv4 prefix length: {}".format(pxlen))
        addr = 0xFFFFFFFF - ((1 << (32 - pxlen)) - 1)
        return socket.inet_ntop(socket.AF_INET, struct.pack("!I", addr))

    @staticmethod
    def pxlen6_sa(pxlen: int) -> bytes:
        """IPv6 netmask sockaddr for a prefix length."""
        return SaHelper.ip_sa(SaHelper.pxlen_to_ip6(pxlen))

    @staticmethod
    def pxlen_to_ip6(pxlen: int) -> str:
        """IPv6 netmask text for a prefix length in 0..128.

        Raises ValueError for a prefix length outside that range.
        """
        if not 0 <= pxlen <= 128:
            raise ValueError("invalid IPv6 prefix length: {}".format(pxlen))
        ip6_b = [0] * 16
        start = 0
        while pxlen > 8:
            ip6_b[start] = 0xFF
            pxlen -= 8
            start += 1
        # Remaining 0..8 bits go into the next byte, most significant first
        ip6_b[start] = 0xFF - ((1 << (8 - pxlen)) - 1)
        return socket.inet_ntop(socket.AF_INET6, bytes(ip6_b))

    @staticmethod
    def print_sa_inet(sa: bytes):
        """Render the address bytes (offset 4..8) of an IPv4 sockaddr."""
        if len(sa) < 8:
            raise RtSockException("IPv4 sa size too small: {}".format(len(sa)))
        return "{}".format(socket.inet_ntop(socket.AF_INET, sa[4:8]))

    @staticmethod
    def print_sa_inet6(sa: bytes):
        """Render address and scope id of an IPv6 sockaddr."""
        if len(sa) < sizeof(SockaddrIn6):
            raise RtSockException("IPv6 sa size too small: {}".format(len(sa)))
        addr = socket.inet_ntop(socket.AF_INET6, sa[8:24])
        # ip6_sa() stores sin6_scope_id through a ctypes c_uint32, i.e. in
        # native byte order, so it has to be decoded the same way.
        scopeid = struct.unpack("=I", sa[24:28])[0]
        return "{} scopeid {}".format(addr, scopeid)

    @staticmethod
    def print_sa_link(sa: bytes, hd: Optional[bool] = True):
        """Render interface index and name of a link-level sockaddr.

        ``hd`` is accepted for signature compatibility and is not used.
        """
        if len(sa) < sizeof(SockaddrDl):
            raise RtSockException("LINK sa size too small: {}".format(len(sa)))
        sdl = SockaddrDl.from_buffer_copy(sa)
        ifindex = "link#{} ".format(sdl.sdl_index) if sdl.sdl_index else ""
        ifname = ""
        if sdl.sdl_nlen:
            iface_offset = 8  # sdl_data follows the 8 fixed header bytes
            if sdl.sdl_nlen + iface_offset > len(sa):
                raise RtSockException(
                    "LINK sa sdl_nlen {} > total len {}".format(sdl.sdl_nlen, len(sa))
                )
            raw_name = sa[iface_offset : iface_offset + sdl.sdl_nlen]
            ifname = "ifname:{} ".format(raw_name.decode())
        return "{}{}".format(ifindex, ifname)

    @staticmethod
    def print_sa_unknown(sa: bytes):
        """Fallback rendering that only reports the family byte."""
        return "unknown_type:{}".format(sa[1])

    @classmethod
    def print_sa(cls, sa: bytes, hd: Optional[bool] = False):
        """Render a sockaddr by family; with ``hd`` append a raw byte dump.

        Raises Exception when the buffer is shorter than the two header
        bytes or when its length byte disagrees with the buffer size.
        """
        # Check the size first: both header bytes are indexed below.
        if len(sa) < 2:
            raise Exception("sa too short: {}".format(len(sa)))
        if sa[0] != len(sa):
            raise Exception("sa size {} != buffer size {}".format(sa[0], len(sa)))

        family = sa[1]
        if family == socket.AF_INET:
            text = cls.print_sa_inet(sa)
        elif family == socket.AF_INET6:
            text = cls.print_sa_inet6(sa)
        elif family == socket.AF_LINK:
            text = cls.print_sa_link(sa)
        else:
            text = cls.print_sa_unknown(sa)
        dump = " [{!r}]".format(sa) if hd else ""
        return "{}{}".format(text, dump)
+
+
class BaseRtsockMessage(object):
    """Shared base of routing-socket messages: an RTM_* type plus a SaHelper."""

    def __init__(self, rtm_type):
        self.sa = SaHelper()
        self.rtm_type = rtm_type

    @staticmethod
    def print_rtm_type(rtm_type):
        """Return the RtConst name matching an RTM_* value."""
        return RtConst.get_name("RTM_", rtm_type)

    @property
    def rtm_type_str(self):
        """Symbolic name of this message's own ``rtm_type``."""
        type_name = self.print_rtm_type(self.rtm_type)
        return type_name
+
+
class RtsockRtMessage(BaseRtsockMessage):
    """Route message (RTM_ADD/DELETE/CHANGE/GET): header plus sockaddr attrs.

    Attributes are kept in ``_attrs`` keyed by their RTA_* bit. Instances can
    be built by hand and serialised with ``bytes()``, or parsed from a raw
    buffer with ``from_bytes()``.
    """

    messages = [
        RtConst.RTM_ADD,
        RtConst.RTM_DELETE,
        RtConst.RTM_CHANGE,
        RtConst.RTM_GET,
    ]

    def __init__(self, rtm_type, rtm_seq=1, dst_sa=None, mask_sa=None):
        super().__init__(rtm_type)
        self.rtm_flags = 0
        self.rtm_seq = rtm_seq
        self._attrs = {}
        self.rtm_errno = 0
        self.rtm_pid = 0
        self.rtm_inits = 0
        self.rtm_rmx = RtMetrics()
        # Raw buffer the message was parsed from (None for hand-built ones)
        self._orig_data = None
        if dst_sa:
            self.add_sa_attr(RtConst.RTA_DST, dst_sa)
        if mask_sa:
            self.add_sa_attr(RtConst.RTA_NETMASK, mask_sa)

    def add_sa_attr(self, attr_type, attr_bytes: bytes):
        """Store raw sockaddr bytes under the given RTA_* bit."""
        self._attrs[attr_type] = attr_bytes

    def add_ip_attr(self, attr_type, ip_addr: str, scopeid: int = 0):
        """Store an IPv4 or IPv6 address (chosen by ':' in the text)."""
        if ":" in ip_addr:
            self.add_ip6_attr(attr_type, ip_addr, scopeid)
        else:
            self.add_ip4_attr(attr_type, ip_addr)

    def add_ip4_attr(self, attr_type, ip: str):
        self.add_sa_attr(attr_type, self.sa.ip_sa(ip))

    def add_ip6_attr(self, attr_type, ip6: str, scopeid: int):
        self.add_sa_attr(attr_type, self.sa.ip6_sa(ip6, scopeid))

    def add_link_attr(self, attr_type, ifindex: Optional[int] = 0):
        self.add_sa_attr(attr_type, self.sa.link_sa(ifindex))

    def get_sa(self, attr_type) -> Optional[bytes]:
        """Return the raw sockaddr stored for ``attr_type`` or None."""
        return self._attrs.get(attr_type)

    def print_message(self):
        """Print a human-readable dump of the header and all attributes."""
        # RTM_GET: Report Metrics: len 272, pid: 87839, seq 1, errno 0, flags:<UP,GATEWAY,DONE,STATIC>
        if self._orig_data:
            rtm_len = len(self._orig_data)
        else:
            rtm_len = len(bytes(self))
        print(
            "{}: len {}, pid: {}, seq {}, errno {}, flags: <{}>".format(
                self.rtm_type_str,
                rtm_len,
                self.rtm_pid,
                self.rtm_seq,
                self.rtm_errno,
                RtConst.get_bitmask_str("RTF_", self.rtm_flags),
            )
        )
        # Keys are distinct single bits, so their sum equals the bitmask
        rtm_addrs = sum(self._attrs.keys())
        print("Addrs: <{}>".format(RtConst.get_bitmask_str("RTA_", rtm_addrs)))
        for attr in sorted(self._attrs.keys()):
            raw = self._attrs[attr]
            # A zero-length sockaddr has no header bytes to decode
            sa_data = SaHelper.print_sa(raw) if raw else "<empty>"
            print(" {}: {}".format(RtConst.get_name("RTA_", attr), sa_data))

    def print_in_message(self):
        print("vvvvvvvv IN vvvvvvvv")
        self.print_message()
        print()

    def verify_sa_inet(self, sa_data):
        """Sanity-check an IPv4 sockaddr: sizes, zero port, zero padding."""
        if len(sa_data) < 8:
            raise Exception("IPv4 sa size too small: {}".format(sa_data))
        if sa_data[0] > len(sa_data):
            raise Exception(
                "IPv4 sin_len too big: {} vs sa size {}: {}".format(
                    sa_data[0], len(sa_data), sa_data
                )
            )
        # from_buffer_copy() needs the full structure size; a truncated
        # sockaddr is zero-padded so the checks below still apply.
        padded = bytes(sa_data).ljust(sizeof(SockaddrIn), b"\0")
        sin = SockaddrIn.from_buffer_copy(padded)
        assert sin.sin_port == 0
        # Compare the raw padding bytes: reading the c_char array field
        # yields bytes, which never equal a list of ints.
        zero_off = SockaddrIn.sin_zero.offset
        zero_len = SockaddrIn.sin_zero.size
        assert padded[zero_off : zero_off + zero_len] == bytes(zero_len)

    def compare_sa(self, sa_type, sa_data):
        """Assert that ``sa_data`` equals the attribute stored for ``sa_type``."""
        if len(sa_data) < 4:
            sa_type_name = RtConst.get_name("RTA_", sa_type)
            raise Exception(
                "sa_len for type {} too short: {}".format(sa_type_name, len(sa_data))
            )
        our_sa = self._attrs[sa_type]
        # Compare the rendered form first for a readable assertion message
        assert SaHelper.print_sa(sa_data) == SaHelper.print_sa(our_sa)
        assert len(sa_data) == len(our_sa)
        assert sa_data == our_sa

    def verify(self, rtm_type: int, rtm_sa):
        """Assert type, errno, spare field and the expected attributes.

        ``rtm_sa`` maps RTA_* bits to the expected raw sockaddr bytes. The
        message must have been created by ``from_bytes()``.
        """
        assert self.rtm_type_str == self.print_rtm_type(rtm_type)
        assert self.rtm_errno == 0
        hdr = RtMsgHdr.from_buffer_copy(self._orig_data)
        assert hdr._rtm_spare1 == 0
        for sa_type, sa_data in rtm_sa.items():
            if sa_type not in self._attrs:
                sa_type_name = RtConst.get_name("RTA_", sa_type)
                raise Exception("SA type {} not present".format(sa_type_name))
            self.compare_sa(sa_type, sa_data)

    @classmethod
    def from_bytes(cls, data: bytes):
        """Parse a raw routing message into a new instance.

        Raises Exception when the buffer is shorter than the header or when
        an attribute does not fit into the buffer.
        """
        if len(data) < sizeof(RtMsgHdr):
            raise Exception(
                "messages size {} is less than expected {}".format(
                    len(data), sizeof(RtMsgHdr)
                )
            )
        hdr = RtMsgHdr.from_buffer_copy(data)

        msg = cls(hdr.rtm_type)
        msg.rtm_flags = hdr.rtm_flags
        msg.rtm_seq = hdr.rtm_seq
        msg.rtm_errno = hdr.rtm_errno
        msg.rtm_pid = hdr.rtm_pid
        msg.rtm_inits = hdr.rtm_inits
        msg.rtm_rmx = hdr.rtm_rmx
        msg._orig_data = data

        off = sizeof(RtMsgHdr)
        bit = 1
        # rtm_addrs is a signed c_int; mask it so the loop always terminates
        addrs_mask = hdr.rtm_addrs & 0xFFFFFFFF
        while addrs_mask:
            if addrs_mask & bit:
                addrs_mask -= bit
                if off >= len(data):
                    raise Exception(
                        "SA for {} starts beyond message end: {} >= {}".format(
                            RtConst.get_name("RTA_", bit), off, len(data)
                        )
                    )
                sa_len = data[off]
                if off + sa_len > len(data):
                    raise Exception(
                        "SA sizeof for {} > total message length: {}+{} > {}".format(
                            RtConst.get_name("RTA_", bit), off, sa_len, len(data)
                        )
                    )
                msg._attrs[bit] = data[off : off + sa_len]
                # A sockaddr with sa_len == 0 still occupies one alignment
                # unit; not advancing would misparse every later attribute.
                # NOTE(review): matches the SA_SIZE() convention - confirm
                # against net/route.h.
                off += roundup2(sa_len, RtConst.ALIGN) if sa_len else RtConst.ALIGN
            bit *= 2
        return msg

    def __bytes__(self):
        """Serialise header and attributes (ascending RTA_* order, aligned)."""
        sz = sizeof(RtMsgHdr)
        addrs_mask = 0
        for attr_type, sa in self._attrs.items():
            sz += roundup2(len(sa), RtConst.ALIGN)
            addrs_mask += attr_type
        hdr = RtMsgHdr(
            rtm_msglen=sz,
            rtm_version=RtConst.RTM_VERSION,
            rtm_type=self.rtm_type,
            rtm_flags=self.rtm_flags,
            rtm_seq=self.rtm_seq,
            rtm_addrs=addrs_mask,
            rtm_inits=self.rtm_inits,
            rtm_rmx=self.rtm_rmx,
        )
        buf = bytearray(sz)
        buf[0 : sizeof(RtMsgHdr)] = bytes(hdr)
        off = sizeof(RtMsgHdr)
        for attr in sorted(self._attrs.keys()):
            sa = self._attrs[attr]
            buf[off : off + len(sa)] = sa
            off += roundup2(len(sa), RtConst.ALIGN)
        return bytes(buf)
+
+
class Rtsock:
    """Thin wrapper around an AF_ROUTE socket for sending/receiving messages."""

    def __init__(self):
        self.socket = self._setup_rtsock()
        self.rtm_seq = 1
        self.msgmap = self.build_msgmap()

    def build_msgmap(self):
        """Map every handled RTM_* type to the class that parses it."""
        classes = [RtsockRtMessage]
        xmap = {}
        for cls in classes:
            for message in cls.messages:
                xmap[message] = cls
        return xmap

    def get_seq(self):
        """Return the current sequence number and advance the counter."""
        ret = self.rtm_seq
        self.rtm_seq += 1
        return ret

    def get_weight(self, weight) -> int:
        """Return ``weight``, or 1 when it is zero/None."""
        if weight:
            return weight
        return 1  # RT_DEFAULT_WEIGHT

    def new_rtm_any(self, msg_type, prefix: str, gw: Union[str, bytes]):
        """Build a route message for ``prefix`` ("addr" or "addr/len").

        ``gw`` is either ready-made sockaddr bytes or an IP address string.
        """
        px = prefix.split("/")
        addr_sa = SaHelper.ip_sa(px[0])
        if len(px) > 1:
            pxlen = int(px[1])
            if SaHelper.is_ipv6(px[0]):
                mask_sa = SaHelper.pxlen6_sa(pxlen)
            else:
                mask_sa = SaHelper.pxlen4_sa(pxlen)
        else:
            mask_sa = None
        msg = RtsockRtMessage(msg_type, self.get_seq(), addr_sa, mask_sa)
        if isinstance(gw, bytes):
            msg.add_sa_attr(RtConst.RTA_GATEWAY, gw)
        else:
            # String
            msg.add_ip_attr(RtConst.RTA_GATEWAY, gw)
        return msg

    def new_rtm_add(self, prefix: str, gw: Union[str, bytes]):
        return self.new_rtm_any(RtConst.RTM_ADD, prefix, gw)

    def new_rtm_del(self, prefix: str, gw: Union[str, bytes]):
        return self.new_rtm_any(RtConst.RTM_DELETE, prefix, gw)

    def new_rtm_change(self, prefix: str, gw: Union[str, bytes]):
        return self.new_rtm_any(RtConst.RTM_CHANGE, prefix, gw)

    def _setup_rtsock(self) -> socket.socket:
        """Open the raw routing socket with SO_USELOOPBACK enabled."""
        s = socket.socket(socket.AF_ROUTE, socket.SOCK_RAW, socket.AF_UNSPEC)
        s.setsockopt(socket.SOL_SOCKET, socket.SO_USELOOPBACK, 1)
        return s

    def print_hd(self, data: bytes):
        """Print a hex dump of ``data``, 16 bytes per line."""
        width = 16
        print("==========================================")
        for start in range(0, len(data), width):
            for b in data[start : start + width]:
                print("0x{:02X} ".format(b), end="")
            print()
        print()

    def write_message(self, msg):
        """Print ``msg`` and write its serialised form to the socket."""
        print("vvvvvvvv OUT vvvvvvvv")
        msg.print_message()
        print()
        msg_bytes = bytes(msg)
        # os.write() raises on error, so only a short write needs checking
        ret = os.write(self.socket.fileno(), msg_bytes)
        assert ret == len(msg_bytes)

    def parse_message(self, data: bytes):
        """Parse a raw message with the class registered for its type.

        Returns None for message types missing from ``msgmap``; raises
        OSError when the buffer is too short to contain the type byte.
        """
        if len(data) < 4:
            raise OSError("Short read from rtsock: {} bytes".format(len(data)))
        # rtm_type is the 4th byte: after rtm_msglen (2) and rtm_version (1)
        rtm_type = data[3]
        msg_cls = self.msgmap.get(rtm_type)
        if msg_cls is None:
            return None
        return msg_cls.from_bytes(data)

    def write_data(self, data: bytes):
        self.socket.send(data)

    def read_data(self, seq: Optional[int] = None) -> bytes:
        """Read one datagram; with ``seq`` keep reading until it matches."""
        while True:
            data = self.socket.recv(4096)
            if seq is None:
                break
            # A message consisting of just the header is still valid
            if len(data) >= sizeof(RtMsgHdr):
                hdr = RtMsgHdr.from_buffer_copy(data)
                if hdr.rtm_seq == seq:
                    break
        return data

    def read_message(self):
        """Read one datagram and return parse_message()'s result for it."""
        data = self.read_data()
        return self.parse_message(data)
diff --git a/tests/atf_python/sys/net/tools.py b/tests/atf_python/sys/net/tools.py
new file mode 100644
index 000000000000..9f44872c2c37
--- /dev/null
+++ b/tests/atf_python/sys/net/tools.py
@@ -0,0 +1,33 @@
+#!/usr/local/bin/python3
+import json
+import os
+import socket
+import time
+from ctypes import cdll
+from ctypes import get_errno
+from ctypes.util import find_library
+from typing import List
+from typing import Optional
+
+
class ToolsHelper(object):
    """Helpers that run system tools and parse their output."""

    NETSTAT_PATH = "/usr/bin/netstat"

    @classmethod
    def get_output(cls, cmd: str, verbose=False) -> str:
        """Run ``cmd`` through os.popen() and return its standard output."""
        if verbose:
            print("run: '{}'".format(cmd))
        # Close the pipe explicitly instead of leaving it to the collector
        with os.popen(cmd) as pipe:
            return pipe.read()

    @classmethod
    def get_routes(cls, family: str, fibnum: int = 0):
        """Return the netstat route entries for ``family`` in FIB ``fibnum``.

        ``family`` must be "inet" or "inet6"; anything else raises
        ValueError instead of passing the text "None" to netstat. The
        result is the "rt-entry" list of the first reported family, or an
        empty list when netstat reports none.
        """
        family_key = {"inet": "-4", "inet6": "-6"}.get(family)
        if family_key is None:
            raise ValueError("unsupported address family: {!r}".format(family))
        out = cls.get_output(
            "{} {} -rn -F {} --libxo json".format(
                cls.NETSTAT_PATH, family_key, int(fibnum)
            )
        )
        js = json.loads(out)
        js = js["statistics"]["route-information"]["route-table"]["rt-family"]
        if js:
            return js[0]["rt-entry"]
        return []
diff --git a/tests/atf_python/sys/net/vnet.py b/tests/atf_python/sys/net/vnet.py
new file mode 100644
index 000000000000..0957364f627c
--- /dev/null
+++ b/tests/atf_python/sys/net/vnet.py
@@ -0,0 +1,203 @@
+#!/usr/local/bin/python3
+import os
+import socket
+import time
+from ctypes import cdll
+from ctypes import get_errno
+from ctypes.util import find_library
+from typing import List
+from typing import Optional
+
+
def run_cmd(cmd: str) -> str:
    """Log ``cmd``, run it through os.popen() and return its standard output."""
    print("run: '{}'".format(cmd))
    # The context manager closes the pipe as soon as the output is read
    with os.popen(cmd) as pipe:
        return pipe.read()
+
+
class VnetInterface(object):
    """A network interface driven through ifconfig(8)/ndp(8) commands.

    Names of created interfaces are appended to INTERFACES_FNAME (relative
    to the current directory) so cleanup_ifaces() can destroy them later.
    """

    INTERFACES_FNAME = "created_interfaces.lst"

    # defines from net/if_types.h
    IFT_LOOP = 0x18
    IFT_ETHER = 0x06

    def __init__(self, iface_name: str):
        self.name = iface_name
        self.vnet_name = ""
        self.jailed = False
        # Anything not named "lo*" is treated as an ethernet interface
        is_loopback = iface_name.startswith("lo")
        self.iftype = self.IFT_LOOP if is_loopback else self.IFT_ETHER

    @property
    def ifindex(self):
        """Interface index as reported by socket.if_nametoindex()."""
        return socket.if_nametoindex(self.name)

    def set_vnet(self, vnet_name: str):
        self.vnet_name = vnet_name

    def set_jailed(self, jailed: bool):
        self.jailed = jailed

    def run_cmd(self, cmd):
        """Run ``cmd``, via jexec when a vnet is set and not marked jailed."""
        needs_jexec = bool(self.vnet_name) and not self.jailed
        if needs_jexec:
            cmd = "jexec {} {}".format(self.vnet_name, cmd)
        run_cmd(cmd)

    @staticmethod
    def file_append_line(line):
        """Append one line to the created-interfaces list file."""
        with open(VnetInterface.INTERFACES_FNAME, "a") as f:
            f.write(line + "\n")

    @classmethod
    def create_iface(cls, iface_name: str):
        """Create an interface with ifconfig and record it for cleanup.

        For epair interfaces the "b" side is recorded as well. Raises
        Exception when ifconfig prints no interface name.
        """
        created = run_cmd("/sbin/ifconfig {} create".format(iface_name)).rstrip()
        if not created:
            raise Exception("Unable to create iface {}".format(iface_name))
        cls.file_append_line(created)
        if created.startswith("epair"):
            peer = created[:-1] + "b"
            cls.file_append_line(peer)
        return cls(created)

    @staticmethod
    def cleanup_ifaces():
        """Destroy every recorded interface and remove the list file.

        Any error (including a missing list file) is silently ignored.
        """
        list_path = VnetInterface.INTERFACES_FNAME
        try:
            with open(list_path, "r") as f:
                for entry in f:
                    run_cmd("/sbin/ifconfig {} destroy".format(entry.strip()))
            os.unlink(list_path)
        except Exception:
            pass

    def setup_addr(self, addr: str):
        """Assign ``addr`` to the interface (inet6 when it contains ':')."""
        family = "inet6" if ":" in addr else "inet"
        self.run_cmd("/sbin/ifconfig {} {} {}".format(self.name, family, addr))

    def delete_addr(self, addr: str):
        """Remove ``addr`` from the interface."""
        if ":" in addr:
            template = "/sbin/ifconfig {} inet6 {} delete"
        else:
            template = "/sbin/ifconfig {} -alias {}"
        self.run_cmd(template.format(self.name, addr))

    def turn_up(self):
        self.run_cmd("/sbin/ifconfig {} up".format(self.name))

    def enable_ipv6(self):
        self.run_cmd("/usr/sbin/ndp -i {} -disabled".format(self.name))
+
+
+class VnetInstance(object):
+ JAILS_FNAME = "created_jails.lst"
+
+ def __init__(self, vnet_name: str, jid: int, ifaces: List[VnetInterface]):
+ self.name = vnet_name
+ self.jid = jid
+ self.ifaces = ifaces
+ for iface in ifaces:
+ iface.set_vnet(vnet_name)
+ iface.set_jailed(True)
+
+ def run_vnet_cmd(self, cmd):
+ if self.vnet_name:
+ cmd = "jexec {} {}".format(self.vnet_name, cmd)
+ return run_cmd(cmd)
+
+ @staticmethod
+ def wait_interface(vnet_name: str, iface_name: str):
+ cmd = "jexec {} /sbin/ifconfig -l".format(vnet_name)
+ for i in range(50):
+ ifaces = run_cmd(cmd).strip().split(" ")
+ if iface_name in ifaces:
+ return True
+ time.sleep(0.1)
+ return False
+
+ @staticmethod
+ def file_append_line(line):
+ with open(VnetInstance.JAILS_FNAME, "a") as f:
+ f.write(line + "\n")
+
+ @staticmethod
+ def cleanup_vnets():
+ try:
+ with open(VnetInstance.JAILS_FNAME) as f:
+ for line in f:
+ run_cmd("/usr/sbin/jail -r {}".format(line.strip()))
+ os.unlink(VnetInstance.JAILS_FNAME)
+ except Exception:
+ pass
+
+ @classmethod
+ def create_with_interfaces(cls, vnet_name: str, ifaces: List[VnetInterface]):
+ iface_cmds = " ".join(["vnet.interface={}".format(i.name) for i in ifaces])
+ cmd = "/usr/sbin/jail -i -c name={} persist vnet {}".format(
+ vnet_name, iface_cmds
+ )
+ jid_str = run_cmd(cmd)
+ jid = int(jid_str)
+ if jid <= 0:
+ raise Exception("Jail creation failed, output: {}".format(jid))
+ cls.file_append_line(vnet_name)
+
+ for iface in ifaces:
+ if cls.wait_interface(vnet_name, iface.name):
+ continue
+ raise Exception(
+ "Interface {} has not appeared in vnet {}".format(iface.name, vnet_name)
+ )
+ return cls(vnet_name, jid, ifaces)
+
+ @staticmethod
+ def attach_jid(jid: int):
+ _path: Optional[str] = find_library("c")
+ if _path is None:
+ raise Exception("libc not found")
+ path: str = _path
+ libc = cdll.LoadLibrary(path)
+ if libc.jail_attach(jid) != 0:
+ raise Exception("jail_attach() failed: errno {}".format(get_errno()))
+
+ def attach(self):
+ self.attach_jid(self.jid)
+
+
class SingleVnetTestTemplate(object):
    """Base class for tests that run inside one freshly created vnet jail.

    Subclasses set ``num_epairs`` and the prefix lists; the N-th prefix is
    assigned to the N-th created interface, empty entries are skipped.
    """

    num_epairs = 1
    IPV6_PREFIXES: List[str] = []
    IPV4_PREFIXES: List[str] = []

    def setup_method(self, method):
        """Create the epairs and the jail, attach to it, configure addresses."""
        test_name = method.__name__
        vnet_name = "jail_{}".format(test_name)
        ifaces = [
            VnetInterface.create_iface("epair") for _ in range(self.num_epairs)
        ]
        self.vnet = VnetInstance.create_with_interfaces(vnet_name, ifaces)
        self.vnet.attach()
        for i, addr in enumerate(self.IPV6_PREFIXES):
            if addr:
                iface = self.vnet.ifaces[i]
                iface.turn_up()
                iface.enable_ipv6()
                iface.setup_addr(addr)
        for i, addr in enumerate(self.IPV4_PREFIXES):
            if addr:
                iface = self.vnet.ifaces[i]
                iface.turn_up()
                iface.setup_addr(addr)

    def cleanup(self, nodeid: str):
        """Remove the recorded jails first, then the recorded interfaces."""
        print("==== vnet cleanup ===")
        VnetInstance.cleanup_vnets()
        VnetInterface.cleanup_ifaces()

    def run_cmd(self, cmd: str) -> str:
        """Run ``cmd`` through os.popen() and return its standard output."""
        # The context manager closes the pipe as soon as the output is read
        with os.popen(cmd) as pipe:
            return pipe.read()
diff --git a/tests/conftest.py b/tests/conftest.py
new file mode 100644
index 000000000000..193d2adfb5e0
--- /dev/null
+++ b/tests/conftest.py
@@ -0,0 +1,121 @@
+import pytest
+from atf_python.atf_pytest import ATFHandler
+
+
# Switched on by pytest_configure() when pytest is started with --atf.
PLUGIN_ENABLED = False
# Lazily created shared ATFHandler; always access it through get_handler().
DEFAULT_HANDLER = None


def get_handler():
    """Return the shared ATFHandler, creating it on first use."""
    global DEFAULT_HANDLER
    handler = DEFAULT_HANDLER
    if handler is None:
        handler = DEFAULT_HANDLER = ATFHandler()
    return handler
+
+
def pytest_addoption(parser):
    """Register the ATF-specific command-line options.

    ``--atf-var``, ``--atf-source-dir`` and ``--atf-cleanup`` go to the
    "general" option group; ``--atf`` and ``--atf-file`` go to the
    "terminal reporting" group.
    """
    # Add meta-values
    general_opts = parser.getgroup("general", "Running and selection options")
    general_opts.addoption(
        "--atf-var", action="append", default=[], dest="atf_vars"
    )
    general_opts.addoption(
        "--atf-source-dir",
        dest="atf_source_dir",
        type=str,
        help="Path to the test source directory",
    )
    general_opts.addoption(
        "--atf-cleanup",
        dest="atf_cleanup",
        action="store_true",
        default=False,
        help="Call cleanup procedure for a given test",
    )

    report_opts = parser.getgroup("terminal reporting", "reporting", after="general")
    report_opts.addoption(
        "--atf",
        action="store_true",
        default=False,
        help="Enable test listing/results output in atf format",
    )
    report_opts.addoption(
        "--atf-file",
        dest="atf_file",
        type=str,
        help="Path to the status file provided by atf runtime",
    )
+
+
@pytest.hookimpl(trylast=True)
def pytest_configure(config):
    """Register the ATF markers and, with ``--atf``, enable the ATF plugin.

    Does nothing when pytest was started only to show help. The markers are
    registered even when ``--atf`` is absent. With ``--atf`` the shared
    handler is created, and in collect-only mode the standard terminal
    reporter is unregistered.

    ``pytest.hookimpl(trylast=True)`` is used instead of the
    ``pytest.mark.trylast`` marker, which pytest has deprecated for
    configuring hook implementations.
    """
    global PLUGIN_ENABLED

    if config.option.help:
        return

    # Register markings anyway to avoid warnings
    marker_specs = (
        "require_user(name): user to run the test with",
        "require_arch(names): List[str] of support archs",
        # "require_config(config): List[Tuple[str,Any]] of k=v pairs",
        "require_diskspace(amount): str with required diskspace",
        "require_files(space): List[str] with file paths",
        "require_machine(names): List[str] of support machine types",
        "require_memory(amount): str with required memory",
        "require_progs(space): List[str] with file paths",
        "timeout(dur): int/float with max duration in sec",
    )
    for spec in marker_specs:
        config.addinivalue_line("markers", spec)

    PLUGIN_ENABLED = config.option.atf
    if not PLUGIN_ENABLED:
        return
    get_handler()

    if config.option.collectonly:
        # Need to output list of tests to stdout, hence override
        # standard reporter plugin
        reporter = config.pluginmanager.getplugin("terminalreporter")
        if reporter:
            config.pluginmanager.unregister(reporter)
+
+
def pytest_collection_modifyitems(session, config, items):
    """When ``--atf-cleanup`` is given, keep only tests that have a cleanup.

    Each kept item has its run step overridden through the handler, and
    ``items`` is updated in place. Without the plugin enabled or without
    ``--atf-cleanup`` the collected list is left as is.
    """
    if not (PLUGIN_ENABLED and config.option.atf_cleanup):
        return

    handler = get_handler()
    with_cleanup = []
    for item in items:
        if not handler.has_object_cleanup(item):
            continue
        handler.override_runtest(item)
        with_cleanup.append(item)
    # Slice assignment mutates the caller's list object.
    items[:] = with_cleanup
+
+
def pytest_collection_finish(session):
    """In ATF collect-only mode, have the handler list the collected tests."""
    if not PLUGIN_ENABLED:
        return
    if not session.config.option.collectonly:
        return
    get_handler().list_tests(session.items)
+
+
def pytest_runtest_logreport(report):
    """Pass every test report to the ATF handler while the plugin is enabled."""
    if not PLUGIN_ENABLED:
        return
    get_handler().add_report(report)
+
+
def pytest_unconfigure(config):
    """Write the handler's report to the ``--atf-file`` path, if one was given."""
    if not PLUGIN_ENABLED:
        return
    status_path = config.option.atf_file
    if not status_path:
        return
    get_handler().write_report(status_path)
diff --git a/tests/freebsd_test_suite/Makefile b/tests/freebsd_test_suite/Makefile
new file mode 100644
index 000000000000..c929ca2880eb
--- /dev/null
+++ b/tests/freebsd_test_suite/Makefile
@@ -0,0 +1,13 @@
+.include <src.opts.mk>
+
+PACKAGE= tests
+PROG_CXX= atf_pytest_wrapper
+SRCS= atf_pytest_wrapper.cpp
+CXXSTD= c++17
+MAN=
+BINDIR=
+
+.include <bsd.own.mk>
+DESTDIR=${TESTSBASE}
+
+.include <bsd.prog.mk>
diff --git a/tests/freebsd_test_suite/atf_pytest_wrapper.cpp b/tests/freebsd_test_suite/atf_pytest_wrapper.cpp
new file mode 100644
index 000000000000..11fd3c47d507
--- /dev/null
+++ b/tests/freebsd_test_suite/atf_pytest_wrapper.cpp
@@ -0,0 +1,192 @@
#include <errno.h>
#include <stdlib.h>
#include <string.h>
#include <unistd.h>

#include <format>
#include <iostream>
#include <string>
#include <vector>
+
+class Handler {
+ private:
+ const std::string kPytestName = "pytest";
+ const std::string kCleanupSuffix = ":cleanup";
+ const std::string kPythonPathEnv = "PYTHONPATH";
+ public:
+ // Test listing requested
+ bool flag_list = false;
+ // Output debug data (will break listing)
+ bool flag_debug = false;
+ // Cleanup for the test requested
+ bool flag_cleanup = false;
+ // Test source directory (provided by ATF)
+ std::string src_dir;
+ // Path to write test status to (provided by ATF)
+ std::string dst_file;
+ // Path to add to PYTHONPATH (provided by the schebang args)
+ std::string python_path;
+ // Path to the script (provided by the schebang wrapper)
+ std::string script_path;
+ // Name of the test to run (provided by ATF)
+ std::string test_name;
+ // kv pairs (provided by ATF)
+ std::vector<std::string> kv_list;
+ // our binary name
+ std::string binary_name;
+
+ static std::vector<std::string> ToVector(int argc, char **argv) {
+ std::vector<std::string> ret;
+
+ for (int i = 0; i < argc; i++) {
+ ret.emplace_back(std::string(argv[i]));
+ }
+ return ret;
+ }
+
+ static void PrintVector(std::string prefix, const std::vector<std::string> &vec) {
+ std::cerr << prefix << ": ";
+ for (auto &val: vec) {
+ std::cerr << "'" << val << "' ";
+ }
+ std::cerr << std::endl;
+ }
+
+ void Usage(std::string msg, bool exit_with_error) {
+ std::cerr << binary_name << ": ERROR: " << msg << "." << std::endl;
+ std::cerr << binary_name << ": See atf-test-program(1) for usage details." << std::endl;
+ exit(exit_with_error != 0);
+ }
+
+ // Parse args received from the OS. There can be multiple valid options:
+ // * with schebang args (#!/binary -P/path):
+ // atf_wrap '-P /path' /path/to/script -l
+ // * without schebang args
+ // atf_wrap /path/to/script -l
+ // Running test:
+ // atf_wrap '-P /path' /path/to/script -r /path1 -s /path2 -vk1=v1 testname
+ void Parse(int argc, char **argv) {
+ if (flag_debug) {
+ PrintVector("IN", ToVector(argc, argv));
+ }
+ // getopt() skips the first argument (as it is typically binary name)
+ // it is possible to have either '-P\s*/path' followed by the script name
+ // or just the script name. Parse kernel-provided arg manually and adjust
+ // array to make getopt work
+
+ binary_name = std::string(argv[0]);
+ argc--; argv++;
+ // parse -P\s*path from the kernel.
+ if (argc > 0 && !strncmp(argv[0], "-P", 2)) {
+ char *path = &argv[0][2];
+ while (*path == ' ')
+ path++;
+ python_path = std::string(path);
+ argc--; argv++;
+ }
+
+ // The next argument is a script name. Copy and keep argc/argv the same
+ // Show usage for empty args
+ if (argc == 0) {
+ Usage("Must provide a test case name", true);
+ }
+ script_path = std::string(argv[0]);
+
+ int c;
+ while ((c = getopt(argc, argv, "lr:s:v:")) != -1) {
+ switch (c) {
+ case 'l':
+ flag_list = true;
+ break;
+ case 's':
+ src_dir = std::string(optarg);
+ break;
+ case 'r':
+ dst_file = std::string(optarg);
+ break;
+ case 'v':
+ kv_list.emplace_back(std::string(optarg));
+ break;
+ default:
+ Usage("Unknown option -" + std::string(1, static_cast<char>(c)), true);
+ }
+ }
+ argc -= optind;
+ argv += optind;
+
+ if (flag_list) {
+ return;
+ }
+ // There should be just one argument with the test name
+ if (argc != 1) {
+ Usage("Must provide a test case name", true);
+ }
+ test_name = std::string(argv[0]);
+ if (test_name.size() > kCleanupSuffix.size() &&
+ std::equal(kCleanupSuffix.rbegin(), kCleanupSuffix.rend(), test_name.rbegin())) {
+ test_name = test_name.substr(0, test_name.size() - kCleanupSuffix.size());
+ flag_cleanup = true;
+ }
+ }
+
+ std::vector<std::string> BuildArgs() {
+ std::vector<std::string> args = {"pytest", "-p", "no:cacheprovider", "-s", "--atf"};
+
+ if (flag_list) {
+ args.push_back("--co");
+ args.push_back(script_path);
+ return args;
+ }
+ if (flag_cleanup) {
+ args.push_back("--atf-cleanup");
+ }
+ if (!src_dir.empty()) {
+ args.push_back("--atf-source-dir");
+ args.push_back(src_dir);
+ }
+ if (!dst_file.empty()) {
+ args.push_back("--atf-file");
+ args.push_back(dst_file);
+ }
+ for (auto &pair: kv_list) {
+ args.push_back("--atf-var");
+ args.push_back(pair);
+ }
+ // Create nodeid from the test path &name
+ args.push_back(script_path + "::" + test_name);
+ return args;
+ }
+
+ void SetEnv() {
+ if (!python_path.empty()) {
+ char *env_path = getenv(kPythonPathEnv.c_str());
+ if (env_path != nullptr) {
+ python_path = python_path + ":" + std::string(env_path);
+ }
+ setenv(kPythonPathEnv.c_str(), python_path.c_str(), 1);
+ }
+ }
+
+ int Run(std::string binary, std::vector<std::string> args) {
+ if (flag_debug) {
+ PrintVector("OUT", args);
+ }
+ // allocate array with final NULL
+ char **arr = new char*[args.size() + 1]();
+ for (unsigned long i = 0; i < args.size(); i++) {
+ // work around 'char *const *'
+ arr[i] = strdup(args[i].c_str());
+ }
+ return (execvp(binary.c_str(), arr) != 0);
+ }
+
+ int Process() {
+ SetEnv();
+ return Run(kPytestName, BuildArgs());
+ }
+};
+
+
+int main(int argc, char **argv) {
+ Handler handler;
+
+ handler.Parse(argc, argv);
+ return handler.Process();
+}