aboutsummaryrefslogtreecommitdiff
path: root/tests
diff options
context:
space:
mode:
Diffstat (limited to 'tests')
-rw-r--r--tests/Makefile4
-rw-r--r--tests/__init__.py0
-rw-r--r--tests/atf_python/Makefile12
-rw-r--r--tests/atf_python/__init__.py4
-rw-r--r--tests/atf_python/atf_pytest.py218
-rw-r--r--tests/atf_python/sys/Makefile11
-rw-r--r--tests/atf_python/sys/__init__.py0
-rw-r--r--tests/atf_python/sys/net/Makefile10
-rw-r--r--tests/atf_python/sys/net/__init__.py0
-rwxr-xr-xtests/atf_python/sys/net/rtsock.py604
-rw-r--r--tests/atf_python/sys/net/tools.py33
-rw-r--r--tests/atf_python/sys/net/vnet.py203
-rw-r--r--tests/conftest.py121
-rw-r--r--tests/freebsd_test_suite/Makefile13
-rw-r--r--tests/freebsd_test_suite/atf_pytest_wrapper.cpp192
15 files changed, 1424 insertions, 1 deletions
diff --git a/tests/Makefile b/tests/Makefile
index 561a0ec5fcab..cfd065d61539 100644
--- a/tests/Makefile
+++ b/tests/Makefile
@@ -4,12 +4,14 @@ PACKAGE= tests
TESTSDIR= ${TESTSBASE}
-${PACKAGE}FILES+= README
+${PACKAGE}FILES+= README __init__.py conftest.py
KYUAFILE= yes
SUBDIR+= etc
SUBDIR+= sys
+SUBDIR+= atf_python
+SUBDIR+= freebsd_test_suite
SUBDIR_PARALLEL=
diff --git a/tests/__init__.py b/tests/__init__.py
new file mode 100644
index 000000000000..e69de29bb2d1
--- /dev/null
+++ b/tests/__init__.py
diff --git a/tests/atf_python/Makefile b/tests/atf_python/Makefile
new file mode 100644
index 000000000000..26d419743257
--- /dev/null
+++ b/tests/atf_python/Makefile
@@ -0,0 +1,12 @@
+.include <src.opts.mk>
+
+.PATH: ${.CURDIR}
+
+FILES= __init__.py atf_pytest.py
+SUBDIR= sys
+
+.include <bsd.own.mk>
+FILESDIR= ${TESTSBASE}/atf_python
+
+
+.include <bsd.prog.mk>
diff --git a/tests/atf_python/__init__.py b/tests/atf_python/__init__.py
new file mode 100644
index 000000000000..6d5ec22ef054
--- /dev/null
+++ b/tests/atf_python/__init__.py
@@ -0,0 +1,4 @@
+import pytest
+
+pytest.register_assert_rewrite("atf_python.sys.net.rtsock")
+pytest.register_assert_rewrite("atf_python.sys.net.vnet")
diff --git a/tests/atf_python/atf_pytest.py b/tests/atf_python/atf_pytest.py
new file mode 100644
index 000000000000..89c0e3a515b9
--- /dev/null
+++ b/tests/atf_python/atf_pytest.py
@@ -0,0 +1,218 @@
+import types
+from typing import Any
+from typing import Dict
+from typing import List
+from typing import NamedTuple
+from typing import Tuple
+
+import pytest
+
+
+class ATFCleanupItem(pytest.Item):
+ def runtest(self):
+ """Runs cleanup procedure for the test instead of the test"""
+ instance = self.parent.cls()
+ instance.cleanup(self.nodeid)
+
+ def setup_method_noop(self, method):
+ """Overrides runtest setup method"""
+ pass
+
+ def teardown_method_noop(self, method):
+ """Overrides runtest teardown method"""
+ pass
+
+
+class ATFTestObj(object):
+ def __init__(self, obj, has_cleanup):
+ # Use nodeid without name to properly name class-derived tests
+ self.ident = obj.nodeid.split("::", 1)[1]
+ self.description = self._get_test_description(obj)
+ self.has_cleanup = has_cleanup
+ self.obj = obj
+
+ def _get_test_description(self, obj):
+ """Returns first non-empty line from func docstring or func name"""
+ docstr = obj.function.__doc__
+ if docstr:
+ for line in docstr.split("\n"):
+ if line:
+ return line
+ return obj.name
+
+ def _convert_marks(self, obj) -> Dict[str, Any]:
+ wj_func = lambda x: " ".join(x) # noqa: E731
+ _map: Dict[str, Dict] = {
+ "require_user": {"name": "require.user"},
+ "require_arch": {"name": "require.arch", "fmt": wj_func},
+ "require_diskspace": {"name": "require.diskspace"},
+ "require_files": {"name": "require.files", "fmt": wj_func},
+ "require_machine": {"name": "require.machine", "fmt": wj_func},
+ "require_memory": {"name": "require.memory"},
+ "require_progs": {"name": "require.progs", "fmt": wj_func},
+ "timeout": {},
+ }
+ ret = {}
+ for mark in obj.iter_markers():
+ if mark.name in _map:
+ name = _map[mark.name].get("name", mark.name)
+ if "fmt" in _map[mark.name]:
+ val = _map[mark.name]["fmt"](mark.args[0])
+ else:
+ val = mark.args[0]
+ ret[name] = val
+ return ret
+
+ def as_lines(self) -> List[str]:
+ """Output test definition in ATF-specific format"""
+ ret = []
+ ret.append("ident: {}".format(self.ident))
+ ret.append("descr: {}".format(self._get_test_description(self.obj)))
+ if self.has_cleanup:
+ ret.append("has.cleanup: true")
+ for key, value in self._convert_marks(self.obj).items():
+ ret.append("{}: {}".format(key, value))
+ return ret
+
+
+class ATFHandler(object):
+ class ReportState(NamedTuple):
+ state: str
+ reason: str
+
+ def __init__(self):
+ self._tests_state_map: Dict[str, ReportStatus] = {}
+
+ def override_runtest(self, obj):
+ # Override basic runtest command
+ obj.runtest = types.MethodType(ATFCleanupItem.runtest, obj)
+ # Override class setup/teardown
+ obj.parent.cls.setup_method = ATFCleanupItem.setup_method_noop
+ obj.parent.cls.teardown_method = ATFCleanupItem.teardown_method_noop
+
+ def get_object_cleanup_class(self, obj):
+ if hasattr(obj, "parent") and obj.parent is not None:
+ if hasattr(obj.parent, "cls") and obj.parent.cls is not None:
+ if hasattr(obj.parent.cls, "cleanup"):
+ return obj.parent.cls
+ return None
+
+ def has_object_cleanup(self, obj):
+ return self.get_object_cleanup_class(obj) is not None
+
+ def list_tests(self, tests: List[str]):
+ print('Content-Type: application/X-atf-tp; version="1"')
+ print()
+ for test_obj in tests:
+ has_cleanup = self.has_object_cleanup(test_obj)
+ atf_test = ATFTestObj(test_obj, has_cleanup)
+ for line in atf_test.as_lines():
+ print(line)
+ print()
+
+ def set_report_state(self, test_name: str, state: str, reason: str):
+ self._tests_state_map[test_name] = self.ReportState(state, reason)
+
+ def _extract_report_reason(self, report):
+ data = report.longrepr
+ if data is None:
+ return None
+ if isinstance(data, Tuple):
+ # ('/path/to/test.py', 23, 'Skipped: unable to test')
+ reason = data[2]
+ for prefix in "Skipped: ":
+ if reason.startswith(prefix):
+ reason = reason[len(prefix):]
+ return reason
+ else:
+ # string/ traceback / exception report. Capture the last line
+ return str(data).split("\n")[-1]
+ return None
+
+ def add_report(self, report):
+ # MAP pytest report state to the atf-desired state
+ #
+ # ATF test states:
+ # (1) expected_death, (2) expected_exit, (3) expected_failure
+ # (4) expected_signal, (5) expected_timeout, (6) passed
+ # (7) skipped, (8) failed
+ #
+ # Note that ATF don't have the concept of "soft xfail" - xpass
+ # is a failure. It also calls teardown routine in a separate
+ # process, thus teardown states (pytest-only) are handled as
+ # body continuation.
+
+ # (stage, state, wasxfail)
+
+ # Just a passing test: WANT: passed
+ # GOT: (setup, passed, F), (call, passed, F), (teardown, passed, F)
+ #
+ # Failing body test: WHAT: failed
+ # GOT: (setup, passed, F), (call, failed, F), (teardown, passed, F)
+ #
+ # pytest.skip test decorator: WANT: skipped
+ # GOT: (setup,skipped, False), (teardown, passed, False)
+ #
+ # pytest.skip call inside test function: WANT: skipped
+ # GOT: (setup, passed, F), (call, skipped, F), (teardown,passed, F)
+ #
+ # mark.xfail decorator+pytest.xfail: WANT: expected_failure
+ # GOT: (setup, passed, F), (call, skipped, T), (teardown, passed, F)
+ #
+ # mark.xfail decorator+pass: WANT: failed
+ # GOT: (setup, passed, F), (call, passed, T), (teardown, passed, F)
+
+ test_name = report.location[2]
+ stage = report.when
+ state = report.outcome
+ reason = self._extract_report_reason(report)
+
+ # We don't care about strict xfail - it gets translated to False
+
+ if stage == "setup":
+ if state in ("skipped", "failed"):
+ # failed init -> failed test, skipped setup -> xskip
+ # for the whole test
+ self.set_report_state(test_name, state, reason)
+ elif stage == "call":
+ # "call" stage shouldn't matter if setup failed
+ if test_name in self._tests_state_map:
+ if self._tests_state_map[test_name].state == "failed":
+ return
+ if state == "failed":
+ # Record failure & override "skipped" state
+ self.set_report_state(test_name, state, reason)
+ elif state == "skipped":
+ if hasattr(reason, "wasxfail"):
+ # xfail() called in the test body
+ state = "expected_failure"
+ else:
+ # skip inside the body
+ pass
+ self.set_report_state(test_name, state, reason)
+ elif state == "passed":
+ if hasattr(reason, "wasxfail"):
+ # the test was expected to fail but didn't
+ # mark as hard failure
+ state = "failed"
+ self.set_report_state(test_name, state, reason)
+ elif stage == "teardown":
+ if state == "failed":
+ # teardown should be empty, as the cleanup
+ # procedures should be implemented as a separate
+ # function/method, so mark teardown failure as
+ # global failure
+ self.set_report_state(test_name, state, reason)
+
+ def write_report(self, path):
+ if self._tests_state_map:
+ # If we're executing in ATF mode, there has to be just one test
+ # Anyway, deterministically pick the first one
+ first_test_name = next(iter(self._tests_state_map))
+ test = self._tests_state_map[first_test_name]
+ if test.state == "passed":
+ line = test.state
+ else:
+ line = "{}: {}".format(test.state, test.reason)
+ with open(path, mode="w") as f:
+ print(line, file=f)
diff --git a/tests/atf_python/sys/Makefile b/tests/atf_python/sys/Makefile
new file mode 100644
index 000000000000..ff4cf17b85d2
--- /dev/null
+++ b/tests/atf_python/sys/Makefile
@@ -0,0 +1,11 @@
+.include <src.opts.mk>
+
+.PATH: ${.CURDIR}
+
+FILES= __init__.py
+SUBDIR= net
+
+.include <bsd.own.mk>
+FILESDIR= ${TESTSBASE}/atf_python/sys
+
+.include <bsd.prog.mk>
diff --git a/tests/atf_python/sys/__init__.py b/tests/atf_python/sys/__init__.py
new file mode 100644
index 000000000000..e69de29bb2d1
--- /dev/null
+++ b/tests/atf_python/sys/__init__.py
diff --git a/tests/atf_python/sys/net/Makefile b/tests/atf_python/sys/net/Makefile
new file mode 100644
index 000000000000..05b1d8afe863
--- /dev/null
+++ b/tests/atf_python/sys/net/Makefile
@@ -0,0 +1,10 @@
+.include <src.opts.mk>
+
+.PATH: ${.CURDIR}
+
+FILES= __init__.py rtsock.py tools.py vnet.py
+
+.include <bsd.own.mk>
+FILESDIR= ${TESTSBASE}/atf_python/sys/net
+
+.include <bsd.prog.mk>
diff --git a/tests/atf_python/sys/net/__init__.py b/tests/atf_python/sys/net/__init__.py
new file mode 100644
index 000000000000..e69de29bb2d1
--- /dev/null
+++ b/tests/atf_python/sys/net/__init__.py
diff --git a/tests/atf_python/sys/net/rtsock.py b/tests/atf_python/sys/net/rtsock.py
new file mode 100755
index 000000000000..788e863f8b28
--- /dev/null
+++ b/tests/atf_python/sys/net/rtsock.py
@@ -0,0 +1,604 @@
+#!/usr/local/bin/python3
+import os
+import socket
+import struct
+import sys
+from ctypes import c_byte
+from ctypes import c_char
+from ctypes import c_int
+from ctypes import c_long
+from ctypes import c_uint32
+from ctypes import c_ulong
+from ctypes import c_ushort
+from ctypes import sizeof
+from ctypes import Structure
+from typing import Dict
+from typing import List
+from typing import Optional
+from typing import Union
+
+
+def roundup2(val: int, num: int) -> int:
+ if val % num:
+ return (val | (num - 1)) + 1
+ else:
+ return val
+
+
+class RtSockException(OSError):
+ pass
+
+
+class RtConst:
+ RTM_VERSION = 5
+ ALIGN = sizeof(c_long)
+
+ AF_INET = socket.AF_INET
+ AF_INET6 = socket.AF_INET6
+ AF_LINK = socket.AF_LINK
+
+ RTA_DST = 0x1
+ RTA_GATEWAY = 0x2
+ RTA_NETMASK = 0x4
+ RTA_GENMASK = 0x8
+ RTA_IFP = 0x10
+ RTA_IFA = 0x20
+ RTA_AUTHOR = 0x40
+ RTA_BRD = 0x80
+
+ RTM_ADD = 1
+ RTM_DELETE = 2
+ RTM_CHANGE = 3
+ RTM_GET = 4
+
+ RTF_UP = 0x1
+ RTF_GATEWAY = 0x2
+ RTF_HOST = 0x4
+ RTF_REJECT = 0x8
+ RTF_DYNAMIC = 0x10
+ RTF_MODIFIED = 0x20
+ RTF_DONE = 0x40
+ RTF_XRESOLVE = 0x200
+ RTF_LLINFO = 0x400
+ RTF_LLDATA = 0x400
+ RTF_STATIC = 0x800
+ RTF_BLACKHOLE = 0x1000
+ RTF_PROTO2 = 0x4000
+ RTF_PROTO1 = 0x8000
+ RTF_PROTO3 = 0x40000
+ RTF_FIXEDMTU = 0x80000
+ RTF_PINNED = 0x100000
+ RTF_LOCAL = 0x200000
+ RTF_BROADCAST = 0x400000
+ RTF_MULTICAST = 0x800000
+ RTF_STICKY = 0x10000000
+ RTF_RNH_LOCKED = 0x40000000
+ RTF_GWFLAG_COMPAT = 0x80000000
+
+ RTV_MTU = 0x1
+ RTV_HOPCOUNT = 0x2
+ RTV_EXPIRE = 0x4
+ RTV_RPIPE = 0x8
+ RTV_SPIPE = 0x10
+ RTV_SSTHRESH = 0x20
+ RTV_RTT = 0x40
+ RTV_RTTVAR = 0x80
+ RTV_WEIGHT = 0x100
+
+ @staticmethod
+ def get_props(prefix: str) -> List[str]:
+ return [n for n in dir(RtConst) if n.startswith(prefix)]
+
+ @staticmethod
+ def get_name(prefix: str, value: int) -> str:
+ props = RtConst.get_props(prefix)
+ for prop in props:
+ if getattr(RtConst, prop) == value:
+ return prop
+ return "U:{}:{}".format(prefix, value)
+
+ @staticmethod
+ def get_bitmask_map(prefix: str, value: int) -> Dict[int, str]:
+ props = RtConst.get_props(prefix)
+ propmap = {getattr(RtConst, prop): prop for prop in props}
+ v = 1
+ ret = {}
+ while value:
+ if v & value:
+ if v in propmap:
+ ret[v] = propmap[v]
+ else:
+ ret[v] = hex(v)
+ value -= v
+ v *= 2
+ return ret
+
+ @staticmethod
+ def get_bitmask_str(prefix: str, value: int) -> str:
+ bmap = RtConst.get_bitmask_map(prefix, value)
+ return ",".join([v for k, v in bmap.items()])
+
+
+class RtMetrics(Structure):
+ _fields_ = [
+ ("rmx_locks", c_ulong),
+ ("rmx_mtu", c_ulong),
+ ("rmx_hopcount", c_ulong),
+ ("rmx_expire", c_ulong),
+ ("rmx_recvpipe", c_ulong),
+ ("rmx_sendpipe", c_ulong),
+ ("rmx_ssthresh", c_ulong),
+ ("rmx_rtt", c_ulong),
+ ("rmx_rttvar", c_ulong),
+ ("rmx_pksent", c_ulong),
+ ("rmx_weight", c_ulong),
+ ("rmx_nhidx", c_ulong),
+ ("rmx_filler", c_ulong * 2),
+ ]
+
+
+class RtMsgHdr(Structure):
+ _fields_ = [
+ ("rtm_msglen", c_ushort),
+ ("rtm_version", c_byte),
+ ("rtm_type", c_byte),
+ ("rtm_index", c_ushort),
+ ("_rtm_spare1", c_ushort),
+ ("rtm_flags", c_int),
+ ("rtm_addrs", c_int),
+ ("rtm_pid", c_int),
+ ("rtm_seq", c_int),
+ ("rtm_errno", c_int),
+ ("rtm_fmask", c_int),
+ ("rtm_inits", c_ulong),
+ ("rtm_rmx", RtMetrics),
+ ]
+
+
+class SockaddrIn(Structure):
+ _fields_ = [
+ ("sin_len", c_byte),
+ ("sin_family", c_byte),
+ ("sin_port", c_ushort),
+ ("sin_addr", c_uint32),
+ ("sin_zero", c_char * 8),
+ ]
+
+
+class SockaddrIn6(Structure):
+ _fields_ = [
+ ("sin6_len", c_byte),
+ ("sin6_family", c_byte),
+ ("sin6_port", c_ushort),
+ ("sin6_flowinfo", c_uint32),
+ ("sin6_addr", c_byte * 16),
+ ("sin6_scope_id", c_uint32),
+ ]
+
+
+class SockaddrDl(Structure):
+ _fields_ = [
+ ("sdl_len", c_byte),
+ ("sdl_family", c_byte),
+ ("sdl_index", c_ushort),
+ ("sdl_type", c_byte),
+ ("sdl_nlen", c_byte),
+ ("sdl_alen", c_byte),
+ ("sdl_slen", c_byte),
+ ("sdl_data", c_byte * 8),
+ ]
+
+
+class SaHelper(object):
+ @staticmethod
+ def is_ipv6(ip: str) -> bool:
+ return ":" in ip
+
+ @staticmethod
+ def ip_sa(ip: str, scopeid: int = 0) -> bytes:
+ if SaHelper.is_ipv6(ip):
+ return SaHelper.ip6_sa(ip, scopeid)
+ else:
+ return SaHelper.ip4_sa(ip)
+
+ @staticmethod
+ def ip4_sa(ip: str) -> bytes:
+ addr_int = int.from_bytes(socket.inet_pton(2, ip), sys.byteorder)
+ sin = SockaddrIn(sizeof(SockaddrIn), socket.AF_INET, 0, addr_int)
+ return bytes(sin)
+
+ @staticmethod
+ def ip6_sa(ip6: str, scopeid: int) -> bytes:
+ addr_bytes = (c_byte * 16)()
+ for i, b in enumerate(socket.inet_pton(socket.AF_INET6, ip6)):
+ addr_bytes[i] = b
+ sin6 = SockaddrIn6(
+ sizeof(SockaddrIn6), socket.AF_INET6, 0, 0, addr_bytes, scopeid
+ )
+ return bytes(sin6)
+
+ @staticmethod
+ def link_sa(ifindex: int = 0, iftype: int = 0) -> bytes:
+ sa = SockaddrDl(sizeof(SockaddrDl), socket.AF_LINK, c_ushort(ifindex), iftype)
+ return bytes(sa)
+
+ @staticmethod
+ def pxlen4_sa(pxlen: int) -> bytes:
+ return SaHelper.ip_sa(SaHelper.pxlen_to_ip4(pxlen))
+
+ @staticmethod
+ def pxlen_to_ip4(pxlen: int) -> str:
+ if pxlen == 32:
+ return "255.255.255.255"
+ else:
+ addr = 0xFFFFFFFF - ((1 << (32 - pxlen)) - 1)
+ addr_bytes = struct.pack("!I", addr)
+ return socket.inet_ntop(socket.AF_INET, addr_bytes)
+
+ @staticmethod
+ def pxlen6_sa(pxlen: int) -> bytes:
+ return SaHelper.ip_sa(SaHelper.pxlen_to_ip6(pxlen))
+
+ @staticmethod
+ def pxlen_to_ip6(pxlen: int) -> str:
+ ip6_b = [0] * 16
+ start = 0
+ while pxlen > 8:
+ ip6_b[start] = 0xFF
+ pxlen -= 8
+ start += 1
+ ip6_b[start] = 0xFF - ((1 << (8 - pxlen)) - 1)
+ return socket.inet_ntop(socket.AF_INET6, bytes(ip6_b))
+
+ @staticmethod
+ def print_sa_inet(sa: bytes):
+ if len(sa) < 8:
+ raise RtSockException("IPv4 sa size too small: {}".format(len(sa)))
+ addr = socket.inet_ntop(socket.AF_INET, sa[4:8])
+ return "{}".format(addr)
+
+ @staticmethod
+ def print_sa_inet6(sa: bytes):
+ if len(sa) < sizeof(SockaddrIn6):
+ raise RtSockException("IPv6 sa size too small: {}".format(len(sa)))
+ addr = socket.inet_ntop(socket.AF_INET6, sa[8:24])
+ scopeid = struct.unpack(">I", sa[24:28])[0]
+ return "{} scopeid {}".format(addr, scopeid)
+
+ @staticmethod
+ def print_sa_link(sa: bytes, hd: Optional[bool] = True):
+ if len(sa) < sizeof(SockaddrDl):
+ raise RtSockException("LINK sa size too small: {}".format(len(sa)))
+ sdl = SockaddrDl.from_buffer_copy(sa)
+ if sdl.sdl_index:
+ ifindex = "link#{} ".format(sdl.sdl_index)
+ else:
+ ifindex = ""
+ if sdl.sdl_nlen:
+ iface_offset = 8
+ if sdl.sdl_nlen + iface_offset > len(sa):
+ raise RtSockException(
+ "LINK sa sdl_nlen {} > total len {}".format(sdl.sdl_nlen, len(sa))
+ )
+ ifname = "ifname:{} ".format(
+ bytes.decode(sa[iface_offset : iface_offset + sdl.sdl_nlen])
+ )
+ else:
+ ifname = ""
+ return "{}{}".format(ifindex, ifname)
+
+ @staticmethod
+ def print_sa_unknown(sa: bytes):
+ return "unknown_type:{}".format(sa[1])
+
+ @classmethod
+ def print_sa(cls, sa: bytes, hd: Optional[bool] = False):
+ if sa[0] != len(sa):
+ raise Exception("sa size {} != buffer size {}".format(sa[0], len(sa)))
+
+ if len(sa) < 2:
+ raise Exception(
+ "sa type {} too short: {}".format(
+ RtConst.get_name("AF_", sa[1]), len(sa)
+ )
+ )
+
+ if sa[1] == socket.AF_INET:
+ text = cls.print_sa_inet(sa)
+ elif sa[1] == socket.AF_INET6:
+ text = cls.print_sa_inet6(sa)
+ elif sa[1] == socket.AF_LINK:
+ text = cls.print_sa_link(sa)
+ else:
+ text = cls.print_sa_unknown(sa)
+ if hd:
+ dump = " [{!r}]".format(sa)
+ else:
+ dump = ""
+ return "{}{}".format(text, dump)
+
+
+class BaseRtsockMessage(object):
+ def __init__(self, rtm_type):
+ self.rtm_type = rtm_type
+ self.sa = SaHelper()
+
+ @staticmethod
+ def print_rtm_type(rtm_type):
+ return RtConst.get_name("RTM_", rtm_type)
+
+ @property
+ def rtm_type_str(self):
+ return self.print_rtm_type(self.rtm_type)
+
+
+class RtsockRtMessage(BaseRtsockMessage):
+ messages = [
+ RtConst.RTM_ADD,
+ RtConst.RTM_DELETE,
+ RtConst.RTM_CHANGE,
+ RtConst.RTM_GET,
+ ]
+
+ def __init__(self, rtm_type, rtm_seq=1, dst_sa=None, mask_sa=None):
+ super().__init__(rtm_type)
+ self.rtm_flags = 0
+ self.rtm_seq = rtm_seq
+ self._attrs = {}
+ self.rtm_errno = 0
+ self.rtm_pid = 0
+ self.rtm_inits = 0
+ self.rtm_rmx = RtMetrics()
+ self._orig_data = None
+ if dst_sa:
+ self.add_sa_attr(RtConst.RTA_DST, dst_sa)
+ if mask_sa:
+ self.add_sa_attr(RtConst.RTA_NETMASK, mask_sa)
+
+ def add_sa_attr(self, attr_type, attr_bytes: bytes):
+ self._attrs[attr_type] = attr_bytes
+
+ def add_ip_attr(self, attr_type, ip_addr: str, scopeid: int = 0):
+ if ":" in ip_addr:
+ self.add_ip6_attr(attr_type, ip_addr, scopeid)
+ else:
+ self.add_ip4_attr(attr_type, ip_addr)
+
+ def add_ip4_attr(self, attr_type, ip: str):
+ self.add_sa_attr(attr_type, self.sa.ip_sa(ip))
+
+ def add_ip6_attr(self, attr_type, ip6: str, scopeid: int):
+ self.add_sa_attr(attr_type, self.sa.ip6_sa(ip6, scopeid))
+
+ def add_link_attr(self, attr_type, ifindex: Optional[int] = 0):
+ self.add_sa_attr(attr_type, self.sa.link_sa(ifindex))
+
+ def get_sa(self, attr_type) -> bytes:
+ return self._attrs.get(attr_type)
+
+ def print_message(self):
+ # RTM_GET: Report Metrics: len 272, pid: 87839, seq 1, errno 0, flags:<UP,GATEWAY,DONE,STATIC>
+ if self._orig_data:
+ rtm_len = len(self._orig_data)
+ else:
+ rtm_len = len(bytes(self))
+ print(
+ "{}: len {}, pid: {}, seq {}, errno {}, flags: <{}>".format(
+ self.rtm_type_str,
+ rtm_len,
+ self.rtm_pid,
+ self.rtm_seq,
+ self.rtm_errno,
+ RtConst.get_bitmask_str("RTF_", self.rtm_flags),
+ )
+ )
+ rtm_addrs = sum(list(self._attrs.keys()))
+ print("Addrs: <{}>".format(RtConst.get_bitmask_str("RTA_", rtm_addrs)))
+ for attr in sorted(self._attrs.keys()):
+ sa_data = SaHelper.print_sa(self._attrs[attr])
+ print(" {}: {}".format(RtConst.get_name("RTA_", attr), sa_data))
+
+ def print_in_message(self):
+ print("vvvvvvvv IN vvvvvvvv")
+ self.print_message()
+ print()
+
+ def verify_sa_inet(self, sa_data):
+ if len(sa_data) < 8:
+ raise Exception("IPv4 sa size too small: {}".format(sa_data))
+ if sa_data[0] > len(sa_data):
+ raise Exception(
+ "IPv4 sin_len too big: {} vs sa size {}: {}".format(
+ sa_data[0], len(sa_data), sa_data
+ )
+ )
+ sin = SockaddrIn.from_buffer_copy(sa_data)
+ assert sin.sin_port == 0
+ assert sin.sin_zero == [0] * 8
+
+ def compare_sa(self, sa_type, sa_data):
+ if len(sa_data) < 4:
+ sa_type_name = RtConst.get_name("RTA_", sa_type)
+ raise Exception(
+ "sa_len for type {} too short: {}".format(sa_type_name, len(sa_data))
+ )
+ our_sa = self._attrs[sa_type]
+ assert SaHelper.print_sa(sa_data) == SaHelper.print_sa(our_sa)
+ assert len(sa_data) == len(our_sa)
+ assert sa_data == our_sa
+
+ def verify(self, rtm_type: int, rtm_sa):
+ assert self.rtm_type_str == self.print_rtm_type(rtm_type)
+ assert self.rtm_errno == 0
+ hdr = RtMsgHdr.from_buffer_copy(self._orig_data)
+ assert hdr._rtm_spare1 == 0
+ for sa_type, sa_data in rtm_sa.items():
+ if sa_type not in self._attrs:
+ sa_type_name = RtConst.get_name("RTA_", sa_type)
+ raise Exception("SA type {} not present".format(sa_type_name))
+ self.compare_sa(sa_type, sa_data)
+
+ @classmethod
+ def from_bytes(cls, data: bytes):
+ if len(data) < sizeof(RtMsgHdr):
+ raise Exception(
+ "messages size {} is less than expected {}".format(
+ len(data), sizeof(RtMsgHdr)
+ )
+ )
+ hdr = RtMsgHdr.from_buffer_copy(data)
+
+ self = cls(hdr.rtm_type)
+ self.rtm_flags = hdr.rtm_flags
+ self.rtm_seq = hdr.rtm_seq
+ self.rtm_errno = hdr.rtm_errno
+ self.rtm_pid = hdr.rtm_pid
+ self.rtm_inits = hdr.rtm_inits
+ self.rtm_rmx = hdr.rtm_rmx
+ self._orig_data = data
+
+ off = sizeof(RtMsgHdr)
+ v = 1
+ addrs_mask = hdr.rtm_addrs
+ while addrs_mask:
+ if addrs_mask & v:
+ addrs_mask -= v
+
+ if off + data[off] > len(data):
+ raise Exception(
+ "SA sizeof for {} > total message length: {}+{} > {}".format(
+ RtConst.get_name("RTA_", v), off, data[off], len(data)
+ )
+ )
+ self._attrs[v] = data[off : off + data[off]]
+ off += roundup2(data[off], RtConst.ALIGN)
+ v *= 2
+ return self
+
+ def __bytes__(self):
+ sz = sizeof(RtMsgHdr)
+ addrs_mask = 0
+ for k, v in self._attrs.items():
+ sz += roundup2(len(v), RtConst.ALIGN)
+ addrs_mask += k
+ hdr = RtMsgHdr(
+ rtm_msglen=sz,
+ rtm_version=RtConst.RTM_VERSION,
+ rtm_type=self.rtm_type,
+ rtm_flags=self.rtm_flags,
+ rtm_seq=self.rtm_seq,
+ rtm_addrs=addrs_mask,
+ rtm_inits=self.rtm_inits,
+ rtm_rmx=self.rtm_rmx,
+ )
+ buf = bytearray(sz)
+ buf[0 : sizeof(RtMsgHdr)] = hdr
+ off = sizeof(RtMsgHdr)
+ for attr in sorted(self._attrs.keys()):
+ v = self._attrs[attr]
+ sa_len = len(v)
+ buf[off : off + sa_len] = v
+ off += roundup2(len(v), RtConst.ALIGN)
+ return bytes(buf)
+
+
+class Rtsock:
+ def __init__(self):
+ self.socket = self._setup_rtsock()
+ self.rtm_seq = 1
+ self.msgmap = self.build_msgmap()
+
+ def build_msgmap(self):
+ classes = [RtsockRtMessage]
+ xmap = {}
+ for cls in classes:
+ for message in cls.messages:
+ xmap[message] = cls
+ return xmap
+
+ def get_seq(self):
+ ret = self.rtm_seq
+ self.rtm_seq += 1
+ return ret
+
+ def get_weight(self, weight) -> int:
+ if weight:
+ return weight
+ else:
+ return 1 # RT_DEFAULT_WEIGHT
+
+ def new_rtm_any(self, msg_type, prefix: str, gw: Union[str, bytes]):
+ px = prefix.split("/")
+ addr_sa = SaHelper.ip_sa(px[0])
+ if len(px) > 1:
+ pxlen = int(px[1])
+ if SaHelper.is_ipv6(px[0]):
+ mask_sa = SaHelper.pxlen6_sa(pxlen)
+ else:
+ mask_sa = SaHelper.pxlen4_sa(pxlen)
+ else:
+ mask_sa = None
+ msg = RtsockRtMessage(msg_type, self.get_seq(), addr_sa, mask_sa)
+ if isinstance(gw, bytes):
+ msg.add_sa_attr(RtConst.RTA_GATEWAY, gw)
+ else:
+ # String
+ msg.add_ip_attr(RtConst.RTA_GATEWAY, gw)
+ return msg
+
+ def new_rtm_add(self, prefix: str, gw: Union[str, bytes]):
+ return self.new_rtm_any(RtConst.RTM_ADD, prefix, gw)
+
+ def new_rtm_del(self, prefix: str, gw: Union[str, bytes]):
+ return self.new_rtm_any(RtConst.RTM_DELETE, prefix, gw)
+
+ def new_rtm_change(self, prefix: str, gw: Union[str, bytes]):
+ return self.new_rtm_any(RtConst.RTM_CHANGE, prefix, gw)
+
+ def _setup_rtsock(self) -> socket.socket:
+ s = socket.socket(socket.AF_ROUTE, socket.SOCK_RAW, socket.AF_UNSPEC)
+ s.setsockopt(socket.SOL_SOCKET, socket.SO_USELOOPBACK, 1)
+ return s
+
+ def print_hd(self, data: bytes):
+ width = 16
+ print("==========================================")
+ for chunk in [data[i : i + width] for i in range(0, len(data), width)]:
+ for b in chunk:
+ print("0x{:02X} ".format(b), end="")
+ print()
+ print()
+
+ def write_message(self, msg):
+ print("vvvvvvvv OUT vvvvvvvv")
+ msg.print_message()
+ print()
+ msg_bytes = bytes(msg)
+ ret = os.write(self.socket.fileno(), msg_bytes)
+ if ret != -1:
+ assert ret == len(msg_bytes)
+
+ def parse_message(self, data: bytes):
+ if len(data) < 4:
+ raise OSError("Short read from rtsock: {} bytes".format(len(data)))
+ rtm_type = data[4]
+ if rtm_type not in self.msgmap:
+ return None
+
+ def write_data(self, data: bytes):
+ self.socket.send(data)
+
+ def read_data(self, seq: Optional[int] = None) -> bytes:
+ while True:
+ data = self.socket.recv(4096)
+ if seq is None:
+ break
+ if len(data) > sizeof(RtMsgHdr):
+ hdr = RtMsgHdr.from_buffer_copy(data)
+ if hdr.rtm_seq == seq:
+ break
+ return data
+
+ def read_message(self) -> bytes:
+ data = self.read_data()
+ return self.parse_message(data)
diff --git a/tests/atf_python/sys/net/tools.py b/tests/atf_python/sys/net/tools.py
new file mode 100644
index 000000000000..9f44872c2c37
--- /dev/null
+++ b/tests/atf_python/sys/net/tools.py
@@ -0,0 +1,33 @@
+#!/usr/local/bin/python3
+import json
+import os
+import socket
+import time
+from ctypes import cdll
+from ctypes import get_errno
+from ctypes.util import find_library
+from typing import List
+from typing import Optional
+
+
+class ToolsHelper(object):
+ NETSTAT_PATH = "/usr/bin/netstat"
+
+ @classmethod
+ def get_output(cls, cmd: str, verbose=False) -> str:
+ if verbose:
+ print("run: '{}'".format(cmd))
+ return os.popen(cmd).read()
+
+ @classmethod
+ def get_routes(cls, family: str, fibnum: int = 0):
+ family_key = {"inet": "-4", "inet6": "-6"}.get(family)
+ out = cls.get_output(
+ "{} {} -rn -F {} --libxo json".format(cls.NETSTAT_PATH, family_key, fibnum)
+ )
+ js = json.loads(out)
+ js = js["statistics"]["route-information"]["route-table"]["rt-family"]
+ if js:
+ return js[0]["rt-entry"]
+ else:
+ return []
diff --git a/tests/atf_python/sys/net/vnet.py b/tests/atf_python/sys/net/vnet.py
new file mode 100644
index 000000000000..0957364f627c
--- /dev/null
+++ b/tests/atf_python/sys/net/vnet.py
@@ -0,0 +1,203 @@
+#!/usr/local/bin/python3
+import os
+import socket
+import time
+from ctypes import cdll
+from ctypes import get_errno
+from ctypes.util import find_library
+from typing import List
+from typing import Optional
+
+
+def run_cmd(cmd: str) -> str:
+ print("run: '{}'".format(cmd))
+ return os.popen(cmd).read()
+
+
+class VnetInterface(object):
+ INTERFACES_FNAME = "created_interfaces.lst"
+
+ # defines from net/if_types.h
+ IFT_LOOP = 0x18
+ IFT_ETHER = 0x06
+
+ def __init__(self, iface_name: str):
+ self.name = iface_name
+ self.vnet_name = ""
+ self.jailed = False
+ if iface_name.startswith("lo"):
+ self.iftype = self.IFT_LOOP
+ else:
+ self.iftype = self.IFT_ETHER
+
+ @property
+ def ifindex(self):
+ return socket.if_nametoindex(self.name)
+
+ def set_vnet(self, vnet_name: str):
+ self.vnet_name = vnet_name
+
+ def set_jailed(self, jailed: bool):
+ self.jailed = jailed
+
+ def run_cmd(self, cmd):
+ if self.vnet_name and not self.jailed:
+ cmd = "jexec {} {}".format(self.vnet_name, cmd)
+ run_cmd(cmd)
+
+ @staticmethod
+ def file_append_line(line):
+ with open(VnetInterface.INTERFACES_FNAME, "a") as f:
+ f.write(line + "\n")
+
+ @classmethod
+ def create_iface(cls, iface_name: str):
+ name = run_cmd("/sbin/ifconfig {} create".format(iface_name)).rstrip()
+ if not name:
+ raise Exception("Unable to create iface {}".format(iface_name))
+ cls.file_append_line(name)
+ if name.startswith("epair"):
+ cls.file_append_line(name[:-1] + "b")
+ return cls(name)
+
+ @staticmethod
+ def cleanup_ifaces():
+ try:
+ with open(VnetInterface.INTERFACES_FNAME, "r") as f:
+ for line in f:
+ run_cmd("/sbin/ifconfig {} destroy".format(line.strip()))
+ os.unlink(VnetInterface.INTERFACES_FNAME)
+ except Exception:
+ pass
+
+ def setup_addr(self, addr: str):
+ if ":" in addr:
+ family = "inet6"
+ else:
+ family = "inet"
+ cmd = "/sbin/ifconfig {} {} {}".format(self.name, family, addr)
+ self.run_cmd(cmd)
+
+ def delete_addr(self, addr: str):
+ if ":" in addr:
+ cmd = "/sbin/ifconfig {} inet6 {} delete".format(self.name, addr)
+ else:
+ cmd = "/sbin/ifconfig {} -alias {}".format(self.name, addr)
+ self.run_cmd(cmd)
+
+ def turn_up(self):
+ cmd = "/sbin/ifconfig {} up".format(self.name)
+ self.run_cmd(cmd)
+
+ def enable_ipv6(self):
+ cmd = "/usr/sbin/ndp -i {} -disabled".format(self.name)
+ self.run_cmd(cmd)
+
+
+class VnetInstance(object):
+ JAILS_FNAME = "created_jails.lst"
+
+ def __init__(self, vnet_name: str, jid: int, ifaces: List[VnetInterface]):
+ self.name = vnet_name
+ self.jid = jid
+ self.ifaces = ifaces
+ for iface in ifaces:
+ iface.set_vnet(vnet_name)
+ iface.set_jailed(True)
+
+ def run_vnet_cmd(self, cmd):
+ if self.vnet_name:
+ cmd = "jexec {} {}".format(self.vnet_name, cmd)
+ return run_cmd(cmd)
+
+ @staticmethod
+ def wait_interface(vnet_name: str, iface_name: str):
+ cmd = "jexec {} /sbin/ifconfig -l".format(vnet_name)
+ for i in range(50):
+ ifaces = run_cmd(cmd).strip().split(" ")
+ if iface_name in ifaces:
+ return True
+ time.sleep(0.1)
+ return False
+
+ @staticmethod
+ def file_append_line(line):
+ with open(VnetInstance.JAILS_FNAME, "a") as f:
+ f.write(line + "\n")
+
+ @staticmethod
+ def cleanup_vnets():
+ try:
+ with open(VnetInstance.JAILS_FNAME) as f:
+ for line in f:
+ run_cmd("/usr/sbin/jail -r {}".format(line.strip()))
+ os.unlink(VnetInstance.JAILS_FNAME)
+ except Exception:
+ pass
+
+ @classmethod
+ def create_with_interfaces(cls, vnet_name: str, ifaces: List[VnetInterface]):
+ iface_cmds = " ".join(["vnet.interface={}".format(i.name) for i in ifaces])
+ cmd = "/usr/sbin/jail -i -c name={} persist vnet {}".format(
+ vnet_name, iface_cmds
+ )
+ jid_str = run_cmd(cmd)
+ jid = int(jid_str)
+ if jid <= 0:
+ raise Exception("Jail creation failed, output: {}".format(jid))
+ cls.file_append_line(vnet_name)
+
+ for iface in ifaces:
+ if cls.wait_interface(vnet_name, iface.name):
+ continue
+ raise Exception(
+ "Interface {} has not appeared in vnet {}".format(iface.name, vnet_name)
+ )
+ return cls(vnet_name, jid, ifaces)
+
+ @staticmethod
+ def attach_jid(jid: int):
+ _path: Optional[str] = find_library("c")
+ if _path is None:
+ raise Exception("libc not found")
+ path: str = _path
+ libc = cdll.LoadLibrary(path)
+ if libc.jail_attach(jid) != 0:
+ raise Exception("jail_attach() failed: errno {}".format(get_errno()))
+
+ def attach(self):
+ self.attach_jid(self.jid)
+
+
+class SingleVnetTestTemplate(object):
+ num_epairs = 1
+ IPV6_PREFIXES: List[str] = []
+ IPV4_PREFIXES: List[str] = []
+
+ def setup_method(self, method):
+ test_name = method.__name__
+ vnet_name = "jail_{}".format(test_name)
+ ifaces = []
+ for i in range(self.num_epairs):
+ ifaces.append(VnetInterface.create_iface("epair"))
+ self.vnet = VnetInstance.create_with_interfaces(vnet_name, ifaces)
+ self.vnet.attach()
+ for i, addr in enumerate(self.IPV6_PREFIXES):
+ if addr:
+ iface = self.vnet.ifaces[i]
+ iface.turn_up()
+ iface.enable_ipv6()
+ iface.setup_addr(addr)
+ for i, addr in enumerate(self.IPV4_PREFIXES):
+ if addr:
+ iface = self.vnet.ifaces[i]
+ iface.turn_up()
+ iface.setup_addr(addr)
+
+ def cleanup(self, nodeid: str):
+ print("==== vnet cleanup ===")
+ VnetInstance.cleanup_vnets()
+ VnetInterface.cleanup_ifaces()
+
+ def run_cmd(self, cmd: str) -> str:
+ return os.popen(cmd).read()
diff --git a/tests/conftest.py b/tests/conftest.py
new file mode 100644
index 000000000000..193d2adfb5e0
--- /dev/null
+++ b/tests/conftest.py
@@ -0,0 +1,121 @@
+import pytest
+from atf_python.atf_pytest import ATFHandler
+
+
+PLUGIN_ENABLED = False
+DEFAULT_HANDLER = None
+
+
+def get_handler():
+ global DEFAULT_HANDLER
+ if DEFAULT_HANDLER is None:
+ DEFAULT_HANDLER = ATFHandler()
+ return DEFAULT_HANDLER
+
+
+def pytest_addoption(parser):
+ """Add file output"""
+ # Add meta-values
+ group = parser.getgroup("general", "Running and selection options")
+ group.addoption("--atf-var", dest="atf_vars", action="append", default=[])
+ group.addoption(
+ "--atf-source-dir",
+ type=str,
+ dest="atf_source_dir",
+ help="Path to the test source directory",
+ )
+ group.addoption(
+ "--atf-cleanup",
+ default=False,
+ action="store_true",
+ dest="atf_cleanup",
+ help="Call cleanup procedure for a given test",
+ )
+ group = parser.getgroup("terminal reporting", "reporting", after="general")
+ group.addoption(
+ "--atf",
+ default=False,
+ action="store_true",
+ help="Enable test listing/results output in atf format",
+ )
+ group.addoption(
+ "--atf-file",
+ type=str,
+ dest="atf_file",
+ help="Path to the status file provided by atf runtime",
+ )
+
+
+@pytest.mark.trylast
+def pytest_configure(config):
+ if config.option.help:
+ return
+
+ # Register markings anyway to avoid warnings
+ config.addinivalue_line("markers", "require_user(name): user to run the test with")
+ config.addinivalue_line(
+ "markers", "require_arch(names): List[str] of support archs"
+ )
+ # config.addinivalue_line("markers", "require_config(config): List[Tuple[str,Any]] of k=v pairs")
+ config.addinivalue_line(
+ "markers", "require_diskspace(amount): str with required diskspace"
+ )
+ config.addinivalue_line(
+ "markers", "require_files(space): List[str] with file paths"
+ )
+ config.addinivalue_line(
+ "markers", "require_machine(names): List[str] of support machine types"
+ )
+ config.addinivalue_line(
+ "markers", "require_memory(amount): str with required memory"
+ )
+ config.addinivalue_line(
+ "markers", "require_progs(space): List[str] with file paths"
+ )
+ config.addinivalue_line(
+ "markers", "timeout(dur): int/float with max duration in sec"
+ )
+
+ global PLUGIN_ENABLED
+ PLUGIN_ENABLED = config.option.atf
+ if not PLUGIN_ENABLED:
+ return
+ get_handler()
+
+ if config.option.collectonly:
+ # Need to output list of tests to stdout, hence override
+ # standard reporter plugin
+ reporter = config.pluginmanager.getplugin("terminalreporter")
+ if reporter:
+ config.pluginmanager.unregister(reporter)
+
+
+def pytest_collection_modifyitems(session, config, items):
+ """If cleanup is requested, replace collected tests with their cleanups (if any)"""
+ if PLUGIN_ENABLED and config.option.atf_cleanup:
+ new_items = []
+ handler = get_handler()
+ for obj in items:
+ if handler.has_object_cleanup(obj):
+ handler.override_runtest(obj)
+ new_items.append(obj)
+ items.clear()
+ items.extend(new_items)
+
+
+def pytest_collection_finish(session):
+ if PLUGIN_ENABLED and session.config.option.collectonly:
+ handler = get_handler()
+ handler.list_tests(session.items)
+
+
+def pytest_runtest_logreport(report):
+ if PLUGIN_ENABLED:
+ handler = get_handler()
+ handler.add_report(report)
+
+
+def pytest_unconfigure(config):
+ if PLUGIN_ENABLED and config.option.atf_file:
+ handler = get_handler()
+ handler.write_report(config.option.atf_file)
diff --git a/tests/freebsd_test_suite/Makefile b/tests/freebsd_test_suite/Makefile
new file mode 100644
index 000000000000..c929ca2880eb
--- /dev/null
+++ b/tests/freebsd_test_suite/Makefile
@@ -0,0 +1,13 @@
+.include <src.opts.mk>
+
+PACKAGE= tests
+PROG_CXX= atf_pytest_wrapper
+SRCS= atf_pytest_wrapper.cpp
+CXXSTD= c++17
+MAN=
+BINDIR=
+
+.include <bsd.own.mk>
+DESTDIR=${TESTSBASE}
+
+.include <bsd.prog.mk>
diff --git a/tests/freebsd_test_suite/atf_pytest_wrapper.cpp b/tests/freebsd_test_suite/atf_pytest_wrapper.cpp
new file mode 100644
index 000000000000..11fd3c47d507
--- /dev/null
+++ b/tests/freebsd_test_suite/atf_pytest_wrapper.cpp
@@ -0,0 +1,192 @@
+#include <format>
+#include <iostream>
+#include <string>
+#include <vector>
+#include <stdlib.h>
+#include <unistd.h>
+
+class Handler {
+ private:
+ const std::string kPytestName = "pytest";
+ const std::string kCleanupSuffix = ":cleanup";
+ const std::string kPythonPathEnv = "PYTHONPATH";
+ public:
+ // Test listing requested
+ bool flag_list = false;
+ // Output debug data (will break listing)
+ bool flag_debug = false;
+ // Cleanup for the test requested
+ bool flag_cleanup = false;
+ // Test source directory (provided by ATF)
+ std::string src_dir;
+ // Path to write test status to (provided by ATF)
+ std::string dst_file;
+ // Path to add to PYTHONPATH (provided by the schebang args)
+ std::string python_path;
+ // Path to the script (provided by the schebang wrapper)
+ std::string script_path;
+ // Name of the test to run (provided by ATF)
+ std::string test_name;
+ // kv pairs (provided by ATF)
+ std::vector<std::string> kv_list;
+ // our binary name
+ std::string binary_name;
+
+ static std::vector<std::string> ToVector(int argc, char **argv) {
+ std::vector<std::string> ret;
+
+ for (int i = 0; i < argc; i++) {
+ ret.emplace_back(std::string(argv[i]));
+ }
+ return ret;
+ }
+
+ static void PrintVector(std::string prefix, const std::vector<std::string> &vec) {
+ std::cerr << prefix << ": ";
+ for (auto &val: vec) {
+ std::cerr << "'" << val << "' ";
+ }
+ std::cerr << std::endl;
+ }
+
+ void Usage(std::string msg, bool exit_with_error) {
+ std::cerr << binary_name << ": ERROR: " << msg << "." << std::endl;
+ std::cerr << binary_name << ": See atf-test-program(1) for usage details." << std::endl;
+ exit(exit_with_error != 0);
+ }
+
+ // Parse args received from the OS. There can be multiple valid options:
+ // * with schebang args (#!/binary -P/path):
+ // atf_wrap '-P /path' /path/to/script -l
+ // * without schebang args
+ // atf_wrap /path/to/script -l
+ // Running test:
+ // atf_wrap '-P /path' /path/to/script -r /path1 -s /path2 -vk1=v1 testname
+ void Parse(int argc, char **argv) {
+ if (flag_debug) {
+ PrintVector("IN", ToVector(argc, argv));
+ }
+ // getopt() skips the first argument (as it is typically binary name)
+ // it is possible to have either '-P\s*/path' followed by the script name
+ // or just the script name. Parse kernel-provided arg manually and adjust
+ // array to make getopt work
+
+ binary_name = std::string(argv[0]);
+ argc--; argv++;
+ // parse -P\s*path from the kernel.
+ if (argc > 0 && !strncmp(argv[0], "-P", 2)) {
+ char *path = &argv[0][2];
+ while (*path == ' ')
+ path++;
+ python_path = std::string(path);
+ argc--; argv++;
+ }
+
+ // The next argument is a script name. Copy and keep argc/argv the same
+ // Show usage for empty args
+ if (argc == 0) {
+ Usage("Must provide a test case name", true);
+ }
+ script_path = std::string(argv[0]);
+
+ int c;
+ while ((c = getopt(argc, argv, "lr:s:v:")) != -1) {
+ switch (c) {
+ case 'l':
+ flag_list = true;
+ break;
+ case 's':
+ src_dir = std::string(optarg);
+ break;
+ case 'r':
+ dst_file = std::string(optarg);
+ break;
+ case 'v':
+ kv_list.emplace_back(std::string(optarg));
+ break;
+ default:
+ Usage("Unknown option -" + std::string(1, static_cast<char>(c)), true);
+ }
+ }
+ argc -= optind;
+ argv += optind;
+
+ if (flag_list) {
+ return;
+ }
+ // There should be just one argument with the test name
+ if (argc != 1) {
+ Usage("Must provide a test case name", true);
+ }
+ test_name = std::string(argv[0]);
+ if (test_name.size() > kCleanupSuffix.size() &&
+ std::equal(kCleanupSuffix.rbegin(), kCleanupSuffix.rend(), test_name.rbegin())) {
+ test_name = test_name.substr(0, test_name.size() - kCleanupSuffix.size());
+ flag_cleanup = true;
+ }
+ }
+
+ std::vector<std::string> BuildArgs() {
+ std::vector<std::string> args = {"pytest", "-p", "no:cacheprovider", "-s", "--atf"};
+
+ if (flag_list) {
+ args.push_back("--co");
+ args.push_back(script_path);
+ return args;
+ }
+ if (flag_cleanup) {
+ args.push_back("--atf-cleanup");
+ }
+ if (!src_dir.empty()) {
+ args.push_back("--atf-source-dir");
+ args.push_back(src_dir);
+ }
+ if (!dst_file.empty()) {
+ args.push_back("--atf-file");
+ args.push_back(dst_file);
+ }
+ for (auto &pair: kv_list) {
+ args.push_back("--atf-var");
+ args.push_back(pair);
+ }
+ // Create nodeid from the test path &name
+ args.push_back(script_path + "::" + test_name);
+ return args;
+ }
+
+ void SetEnv() {
+ if (!python_path.empty()) {
+ char *env_path = getenv(kPythonPathEnv.c_str());
+ if (env_path != nullptr) {
+ python_path = python_path + ":" + std::string(env_path);
+ }
+ setenv(kPythonPathEnv.c_str(), python_path.c_str(), 1);
+ }
+ }
+
+ int Run(std::string binary, std::vector<std::string> args) {
+ if (flag_debug) {
+ PrintVector("OUT", args);
+ }
+ // allocate array with final NULL
+ char **arr = new char*[args.size() + 1]();
+ for (unsigned long i = 0; i < args.size(); i++) {
+ // work around 'char *const *'
+ arr[i] = strdup(args[i].c_str());
+ }
+ return (execvp(binary.c_str(), arr) != 0);
+ }
+
+ int Process() {
+ SetEnv();
+ return Run(kPytestName, BuildArgs());
+ }
+};
+
+
+int main(int argc, char **argv) {
+ Handler handler;
+
+ handler.Parse(argc, argv);
+ return handler.Process();
+}