xref: /linux/tools/testing/selftests/drivers/net/lib/py/env.py (revision 6be87fbb27763c2999e1c69bbec1f3a63cf05422)
1# SPDX-License-Identifier: GPL-2.0
2
3import ipaddress
4import os
5import re
6import time
7from pathlib import Path
8from lib.py import KsftSkipEx, KsftXfailEx
9from lib.py import ksft_setup, wait_file
10from lib.py import cmd, ethtool, ip, CmdExitFailure
11from lib.py import NetNS, NetdevSimDev
12from .remote import Remote
13from . import bpftool
14
15
class NetDrvEnvBase:
    """
    Common base for the NIC / host test environments.

    Attributes:
      src_path: Path of the test script, as handed in by the caller
      test_dir: Path to the source directory of the test
      net_lib_dir: Path to the net/lib directory
      env: Whatever ksft_setup() returns for the process environment
           overlaid with the entries of net.config (when that file exists)
      dev: Device under test, filled in by the inheriting class
    """
    def __init__(self, src_path):
        self.src_path = Path(src_path)
        self.test_dir = self.src_path.parent.resolve()
        self.net_lib_dir = (Path(__file__).parent / "../../../../net/lib").resolve()

        self.env = self._load_env_file()

        # Inheriting classes are responsible for setting the attrs below
        self.dev = None

    def _load_env_file(self):
        """
        Build the test environment from os.environ plus net.config.

        net.config is looked up next to the test script. Everything after
        a '#' is treated as a comment, blank lines are skipped and every
        remaining line must have the KEY=VALUE form (split on the first '=').

        Returns the result of ksft_setup() on the merged environment.
        Raises Exception for a line which has no '='.
        """
        merged = os.environ.copy()

        cfg_path = Path(self.src_path).parent.resolve() / "net.config"
        if cfg_path.exists():
            with open(cfg_path, 'r') as cfg:
                for raw_line in cfg:
                    # Drop the comment part, then surrounding whitespace
                    stripped = raw_line.partition("#")[0].strip()
                    if not stripped:
                        continue
                    key, eq_sign, value = stripped.partition('=')
                    if not eq_sign:
                        raise Exception("Can't parse configuration line:", raw_line)
                    merged[key] = value

        return ksft_setup(merged)

    def __del__(self):
        # Nothing to release in the base class, inheriting classes override
        pass

    def __enter__(self):
        """
        Bring the device under test up and wait for its carrier file in
        sysfs to read "1", then hand back the environment itself.
        """
        ifname = self.dev['ifname']

        ip(f"link set dev {ifname} up")
        wait_file(f"/sys/class/net/{ifname}/carrier",
                  lambda x: x.strip() == "1")

        return self

    def __exit__(self, ex_type, ex_value, ex_tb):
        """
        Called when a "with" block ends, runs the same cleanup as __del__().
        """
        self.__del__()
72
73
class NetDrvEnv(NetDrvEnvBase):
    """
    Class for a single NIC / host env, with no remote end

    When NETIF is present in the environment the named interface is used,
    otherwise a netdevsim device is created for the test.
    """
    def __init__(self, src_path, nsim_test=None, **kwargs):
        """
        Args:
          src_path: path of the test script, forwarded to the base class
          nsim_test: True if the test only works on netdevsim, False if it
                     does not work on netdevsim, None if either is fine
          kwargs: passed through to NetdevSimDev() when one is created

        Raises KsftXfailEx when nsim_test excludes the selected device type.
        """
        # __del__() may run on a partially constructed object if anything
        # below raises (including the base constructor), so the state it
        # inspects has to exist before that can happen.
        self._ns = None

        super().__init__(src_path)

        if 'NETIF' in self.env:
            if nsim_test is True:
                raise KsftXfailEx("Test only works on netdevsim")

            self.dev = ip("-d link show dev " + self.env['NETIF'], json=True)[0]
        else:
            if nsim_test is False:
                raise KsftXfailEx("Test does not work on netdevsim")

            self._ns = NetdevSimDev(**kwargs)
            self.dev = self._ns.nsims[0].dev
        self.ifname = self.dev['ifname']
        self.ifindex = self.dev['ifindex']

    def __del__(self):
        # Remove the netdevsim device if we created one; resetting the
        # attribute makes repeated calls (__exit__ followed by GC) harmless.
        if self._ns:
            self._ns.remove()
            self._ns = None
101
102
class NetDrvEpEnv(NetDrvEnvBase):
    """
    Class for an environment with a local device and "remote endpoint"
    which can be used to send traffic in.

    For local testing it creates two network namespaces and a pair
    of netdevsim devices.
    """

    # Network prefixes used for local tests
    nsim_v4_pfx = "192.0.2."
    nsim_v6_pfx = "2001:db8::"

    def __init__(self, src_path, nsim_test=None):
        """
        Args:
          src_path: path of the test script, forwarded to the base class
                    and to Remote()
          nsim_test: True if the test only works on netdevsim, False if it
                     does not work on netdevsim, None if either is fine

        Raises KsftXfailEx when nsim_test excludes the selected device type,
        and Exception when the NETIF configuration is incomplete or the
        remote interface can't be resolved.
        """
        # __del__() may run on a partially constructed object if anything
        # below raises (including the base constructor), so everything it
        # inspects has to exist before that can happen.
        self._stats_settle_time = None

        # Things we try to destroy
        self.remote = None
        # These are for local testing state
        self._netns = None
        self._ns = None
        self._ns_peer = None

        super().__init__(src_path)

        self.addr_v        = { "4": None, "6": None }
        self.remote_addr_v = { "4": None, "6": None }

        if "NETIF" in self.env:
            if nsim_test is True:
                raise KsftXfailEx("Test only works on netdevsim")
            self._check_env()

            self.dev = ip("-d link show dev " + self.env['NETIF'], json=True)[0]

            self.addr_v["4"] = self.env.get("LOCAL_V4")
            self.addr_v["6"] = self.env.get("LOCAL_V6")
            self.remote_addr_v["4"] = self.env.get("REMOTE_V4")
            self.remote_addr_v["6"] = self.env.get("REMOTE_V6")
            kind = self.env["REMOTE_TYPE"]
            args = self.env["REMOTE_ARGS"]
        else:
            if nsim_test is False:
                raise KsftXfailEx("Test does not work on netdevsim")

            self.create_local()

            self.dev = self._ns.nsims[0].dev

            self.addr_v["4"] = self.nsim_v4_pfx + "1"
            self.addr_v["6"] = self.nsim_v6_pfx + "1"
            self.remote_addr_v["4"] = self.nsim_v4_pfx + "2"
            self.remote_addr_v["6"] = self.nsim_v6_pfx + "2"
            kind = "netns"
            args = self._netns.name

        self.remote = Remote(kind, args, src_path)

        # IPv6 is preferred whenever a local IPv6 address is configured
        self.addr_ipver = "6" if self.addr_v["6"] else "4"
        self.addr = self.addr_v[self.addr_ipver]
        self.remote_addr = self.remote_addr_v[self.addr_ipver]

        # Bracketed addresses, some commands need IPv6 to be inside []
        self.baddr = f"[{self.addr_v['6']}]" if self.addr_v["6"] else self.addr_v["4"]
        self.remote_baddr = f"[{self.remote_addr_v['6']}]" if self.remote_addr_v["6"] else self.remote_addr_v["4"]

        self.ifname = self.dev['ifname']
        self.ifindex = self.dev['ifindex']

        # resolve remote interface name
        self.remote_ifname = self.resolve_remote_ifc()
        self.remote_dev = ip("-d link show dev " + self.remote_ifname,
                             host=self.remote, json=True)[0]
        self.remote_ifindex = self.remote_dev['ifindex']

        # Per-command cache for require_cmd(): {comm: {"local"/"remote": bool}}
        self._required_cmd = {}

    def create_local(self):
        """
        Create the netdevsim based local setup: a new netns, one netdevsim
        device in the current netns and one in the new netns, linked to
        each other and configured with the nsim_v4_pfx / nsim_v6_pfx
        addresses (.1 / ::1 locally, .2 / ::2 on the peer).
        """
        self._netns = NetNS()
        self._ns = NetdevSimDev()
        self._ns_peer = NetdevSimDev(ns=self._netns)

        # link_device takes "<netns fd>:<ifindex>" for each side of the link
        with open("/proc/self/ns/net") as nsfd0, \
             open("/var/run/netns/" + self._netns.name) as nsfd1:
            ifi0 = self._ns.nsims[0].ifindex
            ifi1 = self._ns_peer.nsims[0].ifindex
            NetdevSimDev.ctrl_write('link_device',
                                    f'{nsfd0.fileno()}:{ifi0} {nsfd1.fileno()}:{ifi1}')

        ip(f"   addr add dev {self._ns.nsims[0].ifname} {self.nsim_v4_pfx}1/24")
        ip(f"-6 addr add dev {self._ns.nsims[0].ifname} {self.nsim_v6_pfx}1/64 nodad")
        ip(f"   link set dev {self._ns.nsims[0].ifname} up")

        ip(f"   addr add dev {self._ns_peer.nsims[0].ifname} {self.nsim_v4_pfx}2/24", ns=self._netns)
        ip(f"-6 addr add dev {self._ns_peer.nsims[0].ifname} {self.nsim_v6_pfx}2/64 nodad", ns=self._netns)
        ip(f"   link set dev {self._ns_peer.nsims[0].ifname} up", ns=self._netns)

    def _check_env(self):
        """
        Validate the configuration used with a real NETIF.

        Each entry of vars_needed is a list of alternatives, at least one
        of which must be present in self.env. On top of that the v4 and v6
        settings must be given for both the local and the remote end.

        Raises Exception listing the missing settings.
        """
        vars_needed = [
            ["LOCAL_V4", "LOCAL_V6"],
            ["REMOTE_V4", "REMOTE_V6"],
            ["REMOTE_TYPE"],
            ["REMOTE_ARGS"]
        ]
        missing = []

        for choice in vars_needed:
            for entry in choice:
                if entry in self.env:
                    break
            else:
                # None of the alternatives is set
                missing.append(choice)
        # Make sure v4 / v6 configs are symmetric
        if ("LOCAL_V6" in self.env) != ("REMOTE_V6" in self.env):
            missing.append(["LOCAL_V6", "REMOTE_V6"])
        if ("LOCAL_V4" in self.env) != ("REMOTE_V4" in self.env):
            missing.append(["LOCAL_V4", "REMOTE_V4"])
        if missing:
            raise Exception("Invalid environment, missing configuration:", missing,
                            "Please see tools/testing/selftests/drivers/net/README.rst")

    def resolve_remote_ifc(self):
        """
        Find the name of the remote interface which carries the configured
        remote address(es), by running "ip addr show to <addr>" on the remote.

        Returns the interface name.
        Raises Exception if the v4 and v6 addresses sit on different
        interfaces, if more than one interface matches, or if no interface
        matches at all.
        """
        v4 = v6 = None
        if self.remote_addr_v["4"]:
            v4 = ip("addr show to " + self.remote_addr_v["4"], json=True, host=self.remote)
        if self.remote_addr_v["6"]:
            v6 = ip("addr show to " + self.remote_addr_v["6"], json=True, host=self.remote)
        if v4 and v6 and v4[0]["ifname"] != v6[0]["ifname"]:
            raise Exception("Can't resolve remote interface name, v4 and v6 don't match")
        if (v4 and len(v4) > 1) or (v6 and len(v6) > 1):
            raise Exception("Can't resolve remote interface name, multiple interfaces match")
        # Neither lookup was made, or both came back empty; indexing [0]
        # below would only produce an unhelpful TypeError / IndexError.
        if not v4 and not v6:
            raise Exception("Can't resolve remote interface name, no interface has the remote address")
        return v6[0]["ifname"] if v6 else v4[0]["ifname"]

    def __del__(self):
        # Every attribute is reset after release, so that running this more
        # than once (__exit__ followed by GC) is harmless.
        if self._ns:
            self._ns.remove()
            self._ns = None
        if self._ns_peer:
            self._ns_peer.remove()
            self._ns_peer = None
        if self._netns:
            del self._netns
            self._netns = None
        if self.remote:
            del self.remote
            self.remote = None

    def require_ipver(self, ipver):
        """
        Raise KsftSkipEx unless both the local and the remote address are
        configured for the given IP version ("4" or "6").
        """
        if not self.addr_v[ipver] or not self.remote_addr_v[ipver]:
            raise KsftSkipEx(f"Test requires IPv{ipver} connectivity")

    def require_nsim(self, nsim_test=True):
        """Require or exclude netdevsim for this test"""
        if nsim_test and self._ns is None:
            raise KsftXfailEx("Test only works on netdevsim")
        if nsim_test is False and self._ns is not None:
            raise KsftXfailEx("Test does not work on netdevsim")

    def _require_cmd(self, comm, key, host=None):
        """
        Check (and cache) whether command comm exists on the given host.

        Args:
          comm: name of the command to look up with "command -v"
          key: cache slot for the result ("local" or "remote")
          host: host to run the check on, None for the local system

        Returns True if the command was found.
        """
        cached = self._required_cmd.get(comm, {})
        if cached.get(key) is None:
            cached[key] = cmd("command -v -- " + comm, fail=False,
                              shell=True, host=host).ret == 0
        self._required_cmd[comm] = cached
        return cached[key]

    def require_cmd(self, comm, local=True, remote=False):
        """
        Raise KsftSkipEx if command comm is not available on the local
        system (when local is set) or on the remote (when remote is set).
        """
        if local:
            if not self._require_cmd(comm, "local"):
                raise KsftSkipEx("Test requires command: " + comm)
        if remote:
            if not self._require_cmd(comm, "remote", host=self.remote):
                raise KsftSkipEx("Test requires (remote) command: " + comm)

    def wait_hw_stats_settle(self):
        """
        Wait for HW stats to become consistent, some devices DMA HW stats
        periodically so events won't be reflected until next sync.
        Good drivers will tell us via ethtool what their sync period is.
        """
        if self._stats_settle_time is None:
            data = {}
            try:
                data = ethtool("-c " + self.ifname, json=True)[0]
            except CmdExitFailure as e:
                # Devices without coalescing support just get the base delay
                if "Operation not supported" not in e.cmd.stderr:
                    raise

            # 25 msec of slack on top of the reported period (usecs -> secs)
            self._stats_settle_time = 0.025 + \
                data.get('stats-block-usecs', 0) / 1000 / 1000

        time.sleep(self._stats_settle_time)
295
296
class NetDrvContEnv(NetDrvEpEnv):
    """
    Class for an environment with a netkit pair setup for forwarding traffic
    between the physical interface and a network namespace.
    """

    def __init__(self, src_path, nk_rxqueues=1, **kwargs):
        """
        Args:
          src_path: path of the test script, forwarded to NetDrvEpEnv
          nk_rxqueues: value for "numrxqueues" when creating the netkit pair
          kwargs: passed through to NetDrvEpEnv

        Raises KsftSkipEx when IPv6 or LOCAL_PREFIX_V6 is not configured,
        when the netkit pair can't be identified, or when the BPF object
        is missing.
        """
        # __del__() may run on a partially constructed object if anything
        # below raises (the parent constructor, or one of the skips), and it
        # has to get as far as super().__del__() to release what the parent
        # has created. So the state it inspects must exist up front.
        self.netns = None
        self._nk_host_ifname = None
        self._nk_guest_ifname = None
        self._tc_attached = False
        self._bpf_prog_pref = None
        self._bpf_prog_id = None

        super().__init__(src_path, **kwargs)

        self.require_ipver("6")
        local_prefix = self.env.get("LOCAL_PREFIX_V6")
        if not local_prefix:
            raise KsftSkipEx("LOCAL_PREFIX_V6 required")

        local_prefix = self._strip_prefix_suffix(local_prefix)
        self.ipv6_prefix = f"{local_prefix}::"
        self.nk_host_ipv6 = f"{local_prefix}::2:1"
        self.nk_guest_ipv6 = f"{local_prefix}::2:2"

        ip(f"link add type netkit mode l2 forward peer forward numrxqueues {nk_rxqueues}")

        # The new pair is located by looking for netkit links which are
        # not yet UP; exactly two are expected.
        all_links = ip("-d link show", json=True)
        netkit_links = [link for link in all_links
                        if link.get('linkinfo', {}).get('info_kind') == 'netkit'
                        and 'UP' not in link.get('flags', [])]

        if len(netkit_links) != 2:
            raise KsftSkipEx("Failed to create netkit pair")

        # The link with the higher ifindex is used as the host side
        netkit_links.sort(key=lambda x: x['ifindex'])
        self._nk_host_ifname = netkit_links[1]['ifname']
        self._nk_guest_ifname = netkit_links[0]['ifname']
        self.nk_host_ifindex = netkit_links[1]['ifindex']
        self.nk_guest_ifindex = netkit_links[0]['ifindex']

        self._setup_ns()
        self._attach_bpf()

    @staticmethod
    def _strip_prefix_suffix(prefix):
        """
        Reduce a prefix like "2001:db8:1::/64", "2001:db8:1::" or
        "2001:db8:1" to its leading groups ("2001:db8:1").

        Only a literal "/64" suffix and trailing colons are removed.
        str.rstrip("/64") can't be used for this, it strips any trailing
        run of the characters '/', '6' and '4' and would eat the last
        digits of a prefix given without the "/64".
        """
        if prefix.endswith("/64"):
            prefix = prefix[:-len("/64")]
        return prefix.rstrip(":")

    def __del__(self):
        # Each stage runs even if the previous one raised, otherwise one
        # failing command would leave the remaining resources behind.
        try:
            if self._tc_attached:
                cmd(f"tc filter del dev {self.ifname} ingress pref {self._bpf_prog_pref}")
                self._tc_attached = False
        finally:
            try:
                if self._nk_host_ifname:
                    cmd(f"ip link del dev {self._nk_host_ifname}")
                    self._nk_host_ifname = None
                    self._nk_guest_ifname = None
            finally:
                try:
                    if self.netns:
                        del self.netns
                        self.netns = None
                finally:
                    super().__del__()

    def _setup_ns(self):
        """
        Create the guest netns, move the guest end of the netkit pair into
        it and configure addresses and routes on both ends: fe80::1 on the
        host side with a /128 route to the guest address, fe80::2 plus the
        guest address on the guest side with a default route via fe80::1.
        """
        self.netns = NetNS()
        ip(f"link set dev {self._nk_guest_ifname} netns {self.netns.name}")
        ip(f"link set dev {self._nk_host_ifname} up")
        ip(f"-6 addr add fe80::1/64 dev {self._nk_host_ifname} nodad")
        ip(f"-6 route add {self.nk_guest_ipv6}/128 via fe80::2 dev {self._nk_host_ifname}")

        ip("link set lo up", ns=self.netns)
        ip(f"link set dev {self._nk_guest_ifname} up", ns=self.netns)
        ip(f"-6 addr add fe80::2/64 dev {self._nk_guest_ifname}", ns=self.netns)
        ip(f"-6 addr add {self.nk_guest_ipv6}/64 dev {self._nk_guest_ifname} nodad", ns=self.netns)
        ip(f"-6 route add default via fe80::1 dev {self._nk_guest_ifname}", ns=self.netns)

    def _attach_bpf(self):
        """
        Attach nk_forward.bpf.o (from the test directory) to the ingress
        of the local interface and write the IPv6 prefix and the host side
        netkit ifindex into the program's .bss map.

        Raises KsftSkipEx if the object file does not exist, Exception if
        the program ID or its .bss map can't be found.
        """
        bpf_obj = self.test_dir / "nk_forward.bpf.o"
        if not bpf_obj.exists():
            raise KsftSkipEx("BPF prog not found")

        cmd(f"tc filter add dev {self.ifname} ingress bpf obj {bpf_obj} sec tc/ingress direct-action")
        self._tc_attached = True

        # Recover the filter preference (needed for removal) and prog ID
        tc_info = cmd(f"tc filter show dev {self.ifname} ingress").stdout
        match = re.search(r'pref (\d+).*nk_forward\.bpf.*id (\d+)', tc_info)
        if not match:
            raise Exception("Failed to get BPF prog ID")
        self._bpf_prog_pref = int(match.group(1))
        self._bpf_prog_id = int(match.group(2))

        prog_info = bpftool(f"prog show id {self._bpf_prog_id}", json=True)
        map_ids = prog_info.get("map_ids", [])

        bss_map_id = None
        for map_id in map_ids:
            map_info = bpftool(f"map show id {map_id}", json=True)
            # Tolerate maps reported without a name
            if (map_info.get("name") or "").endswith("bss"):
                bss_map_id = map_id

        if bss_map_id is None:
            raise Exception("Failed to find .bss map")

        # Value layout: 16 bytes of IPv6 prefix followed by the host netkit
        # ifindex as a 4 byte little endian integer.
        # NOTE(review): presumably mirrors the globals of nk_forward.bpf.c,
        # confirm against the BPF source.
        ipv6_addr = ipaddress.IPv6Address(self.ipv6_prefix)
        ipv6_bytes = ipv6_addr.packed
        ifindex_bytes = self.nk_host_ifindex.to_bytes(4, byteorder='little')
        value = ipv6_bytes + ifindex_bytes
        value_hex = ' '.join(f'{b:02x}' for b in value)
        bpftool(f"map update id {bss_map_id} key hex 00 00 00 00 value hex {value_hex}")
404