xref: /linux/tools/testing/selftests/drivers/net/lib/py/env.py (revision 35c2c39832e569449b9192fa1afbbc4c66227af7)
1# SPDX-License-Identifier: GPL-2.0
2
3import ipaddress
4import os
5import time
6import json
7from pathlib import Path
8from lib.py import KsftSkipEx, KsftXfailEx
9from lib.py import ksft_setup, wait_file
10from lib.py import cmd, ethtool, ip, CmdExitFailure
11from lib.py import NetNS, NetdevSimDev
12from .remote import Remote
13from . import bpftool, RtnlFamily, Netlink
14
15
class NetDrvEnvBase:
    """
    Common base for NIC / host test environments.

    Attributes:
      test_dir: Path to the source directory of the test
      net_lib_dir: Path to the net/lib directory
      env: copy of the process environment, overlaid with the KEY=VALUE
           pairs from net.config next to the test (when present) and
           passed through ksft_setup()
      dev: link info dict of the device under test (must contain
           'ifname'); inheriting classes are responsible for setting it
    """
    def __init__(self, src_path):
        self.src_path = Path(src_path)
        self.test_dir = self.src_path.parent.resolve()
        lib_py_dir = Path(__file__).parent
        self.net_lib_dir = (lib_py_dir / "../../../../net/lib").resolve()

        self.env = self._load_env_file()

        # Inheriting classes must fill this in
        self.dev = None

    def _load_env_file(self):
        """
        Build the test environment from os.environ plus net.config.

        net.config lives in the test's directory. Text after '#' is a
        comment, blank lines are skipped, and every remaining line must be
        KEY=VALUE (split on the first '='); otherwise an Exception carrying
        the raw line is raised. The result is passed through ksft_setup().
        """
        merged = os.environ.copy()

        config = Path(self.src_path).parent.resolve() / "net.config"
        if config.exists():
            with open(config.as_posix(), 'r') as cfg:
                for raw in cfg:
                    body = raw.partition("#")[0].strip()
                    if not body:
                        continue
                    key, sep, value = body.partition("=")
                    if not sep:
                        raise Exception("Can't parse configuration line:", raw)
                    merged[key] = value
        return ksft_setup(merged)

    def __del__(self):
        # Nothing to release at this level; subclasses override this.
        pass

    def __enter__(self):
        """Bring the device up and wait until it reports carrier."""
        ifname = self.dev['ifname']
        ip(f"link set dev {ifname} up")
        wait_file(f"/sys/class/net/{ifname}/carrier",
                  lambda contents: contents.strip() == "1")
        return self

    def __exit__(self, ex_type, ex_value, ex_tb):
        """
        Called when a "with" block ends; releases resources via __del__().
        """
        self.__del__()
72
73
class NetDrvEnv(NetDrvEnvBase):
    """
    Class for a single NIC / host env, with no remote end

    If NETIF is set in the environment / net.config, that existing device
    is used. Otherwise a netdevsim device is created (extra keyword
    arguments are passed to NetdevSimDev) and removed again by __del__.
    """

    # Class-level default so __del__ is safe even when
    # NetDrvEnvBase.__init__() raises (e.g. on a malformed net.config)
    # before __init__ below gets to assign the instance attribute.
    _ns = None

    def __init__(self, src_path, nsim_test=None, **kwargs):
        """
        nsim_test: True if the test only works on netdevsim, False if it
        does not work on netdevsim, None if it works on either. A mismatch
        with the current setup raises KsftXfailEx.
        """
        super().__init__(src_path)

        self._ns = None

        if 'NETIF' in self.env:
            if nsim_test is True:
                raise KsftXfailEx("Test only works on netdevsim")

            self.dev = ip("-d link show dev " + self.env['NETIF'], json=True)[0]
        else:
            if nsim_test is False:
                raise KsftXfailEx("Test does not work on netdevsim")

            self._ns = NetdevSimDev(**kwargs)
            self.dev = self._ns.nsims[0].dev
        self.ifname = self.dev['ifname']
        self.ifindex = self.dev['ifindex']

    def __del__(self):
        """Remove the netdevsim device if this env created one."""
        if self._ns:
            self._ns.remove()
            self._ns = None
101
102
class NetDrvEpEnv(NetDrvEnvBase):
    """
    Class for an environment with a local device and "remote endpoint"
    which can be used to send traffic in.

    When NETIF is configured, the local device, the addresses and the
    remote are taken from the environment / net.config (see _check_env()
    for the required variables). Otherwise, for local testing, it creates
    a linked pair of netdevsim devices: one in the current network
    namespace and its peer in a new network namespace, which then serves
    as the remote.
    """

    # Network prefixes used for local tests
    nsim_v4_pfx = "192.0.2."
    nsim_v6_pfx = "2001:db8::"

    # Class-level defaults for everything __del__ releases, so __del__ is
    # safe even if NetDrvEnvBase.__init__() raises (e.g. on a malformed
    # net.config) before __init__ below assigns the instance attributes.
    remote = None
    _netns = None
    _ns = None
    _ns_peer = None

    def __init__(self, src_path, nsim_test=None):
        """
        nsim_test: True if the test only works on netdevsim, False if it
        does not work on netdevsim, None if it works on either. A mismatch
        with the current setup raises KsftXfailEx.
        """
        super().__init__(src_path)

        self._stats_settle_time = None

        # Things we try to destroy
        self.remote = None
        # These are for local testing state
        self._netns = None
        self._ns = None
        self._ns_peer = None

        self.addr_v        = { "4": None, "6": None }
        self.remote_addr_v = { "4": None, "6": None }

        if "NETIF" in self.env:
            if nsim_test is True:
                raise KsftXfailEx("Test only works on netdevsim")
            self._check_env()

            self.dev = ip("-d link show dev " + self.env['NETIF'], json=True)[0]

            self.addr_v["4"] = self.env.get("LOCAL_V4")
            self.addr_v["6"] = self.env.get("LOCAL_V6")
            self.remote_addr_v["4"] = self.env.get("REMOTE_V4")
            self.remote_addr_v["6"] = self.env.get("REMOTE_V6")
            kind = self.env["REMOTE_TYPE"]
            args = self.env["REMOTE_ARGS"]
        else:
            if nsim_test is False:
                raise KsftXfailEx("Test does not work on netdevsim")

            self.create_local()

            self.dev = self._ns.nsims[0].dev

            self.addr_v["4"] = self.nsim_v4_pfx + "1"
            self.addr_v["6"] = self.nsim_v6_pfx + "1"
            self.remote_addr_v["4"] = self.nsim_v4_pfx + "2"
            self.remote_addr_v["6"] = self.nsim_v6_pfx + "2"
            kind = "netns"
            args = self._netns.name

        self.remote = Remote(kind, args, src_path)

        # Prefer IPv6 for the "default" addresses when it is configured
        self.addr_ipver = "6" if self.addr_v["6"] else "4"
        self.addr = self.addr_v[self.addr_ipver]
        self.remote_addr = self.remote_addr_v[self.addr_ipver]

        # Bracketed addresses, some commands need IPv6 to be inside []
        self.baddr = f"[{self.addr_v['6']}]" if self.addr_v["6"] else self.addr_v["4"]
        self.remote_baddr = f"[{self.remote_addr_v['6']}]" if self.remote_addr_v["6"] else self.remote_addr_v["4"]

        self.ifname = self.dev['ifname']
        self.ifindex = self.dev['ifindex']

        # resolve remote interface name
        self.remote_ifname = self.resolve_remote_ifc()
        self.remote_dev = ip("-d link show dev " + self.remote_ifname,
                             host=self.remote, json=True)[0]
        self.remote_ifindex = self.remote_dev['ifindex']

        self._required_cmd = {}

    def create_local(self):
        """
        Create and link a local netdevsim pair for testing without HW.

        The local device stays in the current namespace with
        <nsim_v4_pfx>1 / <nsim_v6_pfx>1; its peer is created in a new
        NetNS with <nsim_v4_pfx>2 / <nsim_v6_pfx>2. Both are brought up.
        """
        self._netns = NetNS()
        self._ns = NetdevSimDev()
        self._ns_peer = NetdevSimDev(ns=self._netns)

        with open("/proc/self/ns/net") as nsfd0, \
             open("/var/run/netns/" + self._netns.name) as nsfd1:
            ifi0 = self._ns.nsims[0].ifindex
            ifi1 = self._ns_peer.nsims[0].ifindex
            NetdevSimDev.ctrl_write('link_device',
                                    f'{nsfd0.fileno()}:{ifi0} {nsfd1.fileno()}:{ifi1}')

        ip(f"   addr add dev {self._ns.nsims[0].ifname} {self.nsim_v4_pfx}1/24")
        ip(f"-6 addr add dev {self._ns.nsims[0].ifname} {self.nsim_v6_pfx}1/64 nodad")
        ip(f"   link set dev {self._ns.nsims[0].ifname} up")

        ip(f"   addr add dev {self._ns_peer.nsims[0].ifname} {self.nsim_v4_pfx}2/24", ns=self._netns)
        ip(f"-6 addr add dev {self._ns_peer.nsims[0].ifname} {self.nsim_v6_pfx}2/64 nodad", ns=self._netns)
        ip(f"   link set dev {self._ns_peer.nsims[0].ifname} up", ns=self._netns)

    def _check_env(self):
        """
        Validate the HW-test configuration.

        Requires at least one of LOCAL_V4/LOCAL_V6, one of
        REMOTE_V4/REMOTE_V6, REMOTE_TYPE and REMOTE_ARGS, and requires
        the v4 and v6 settings to be given on both sides or neither.
        """
        vars_needed = [
            ["LOCAL_V4", "LOCAL_V6"],
            ["REMOTE_V4", "REMOTE_V6"],
            ["REMOTE_TYPE"],
            ["REMOTE_ARGS"]
        ]
        missing = []

        for choice in vars_needed:
            for entry in choice:
                if entry in self.env:
                    break
            else:
                missing.append(choice)
        # Make sure v4 / v6 configs are symmetric
        if ("LOCAL_V6" in self.env) != ("REMOTE_V6" in self.env):
            missing.append(["LOCAL_V6", "REMOTE_V6"])
        if ("LOCAL_V4" in self.env) != ("REMOTE_V4" in self.env):
            missing.append(["LOCAL_V4", "REMOTE_V4"])
        if missing:
            raise Exception("Invalid environment, missing configuration:", missing,
                            "Please see tools/testing/selftests/drivers/net/README.rst")

    def resolve_remote_ifc(self):
        """
        Return the name of the remote interface carrying the remote address.

        Raises if neither configured remote address is found on the remote
        host, if both are found but on different interfaces, or if either
        matches more than one interface. The IPv6 match is preferred.
        """
        v4 = v6 = None
        if self.remote_addr_v["4"]:
            v4 = ip("addr show to " + self.remote_addr_v["4"], json=True, host=self.remote)
        if self.remote_addr_v["6"]:
            v6 = ip("addr show to " + self.remote_addr_v["6"], json=True, host=self.remote)
        # Without this check the final line indexes an empty list / None
        if not v4 and not v6:
            raise Exception("Can't resolve remote interface name, no interface matches")
        if v4 and v6 and v4[0]["ifname"] != v6[0]["ifname"]:
            raise Exception("Can't resolve remote interface name, v4 and v6 don't match")
        if (v4 and len(v4) > 1) or (v6 and len(v6) > 1):
            raise Exception("Can't resolve remote interface name, multiple interfaces match")
        return v6[0]["ifname"] if v6 else v4[0]["ifname"]

    def __del__(self):
        """Release the netdevsim devices, the test namespace and the remote."""
        if self._ns:
            self._ns.remove()
            self._ns = None
        if self._ns_peer:
            self._ns_peer.remove()
            self._ns_peer = None
        if self._netns:
            del self._netns
            self._netns = None
        if self.remote:
            del self.remote
            self.remote = None

    def require_ipver(self, ipver):
        """Skip the test unless both ends have an address of version ipver."""
        if not self.addr_v[ipver] or not self.remote_addr_v[ipver]:
            raise KsftSkipEx(f"Test requires IPv{ipver} connectivity")

    def require_nsim(self, nsim_test=True):
        """Require or exclude netdevsim for this test"""
        if nsim_test and self._ns is None:
            raise KsftXfailEx("Test only works on netdevsim")
        if nsim_test is False and self._ns is not None:
            raise KsftXfailEx("Test does not work on netdevsim")

    def get_local_nsim_dev(self):
        """Returns the local netdevsim device or None.
           Using this method is discouraged, as it makes tests nsim-specific.
           Standard interfaces available on all HW should ideally be used.
           This method is intended for the few cases where nsim-specific
           assertions need to be verified which cannot be verified otherwise.
        """
        return self._ns

    def _require_cmd(self, comm, key, host=None):
        """
        Return whether `comm` is available on `host`, caching the answer
        per (command, key) so each lookup runs at most once.
        """
        cached = self._required_cmd.setdefault(comm, {})
        if cached.get(key) is None:
            cached[key] = cmd("command -v -- " + comm, fail=False,
                              shell=True, host=host).ret == 0
        return cached[key]

    def require_cmd(self, comm, local=True, remote=False):
        """Skip the test unless `comm` exists locally and/or on the remote."""
        if local:
            if not self._require_cmd(comm, "local"):
                raise KsftSkipEx("Test requires command: " + comm)
        if remote:
            if not self._require_cmd(comm, "remote", host=self.remote):
                raise KsftSkipEx("Test requires (remote) command: " + comm)

    def wait_hw_stats_settle(self):
        """
        Wait for HW stats to become consistent, some devices DMA HW stats
        periodically so events won't be reflected until next sync.
        Good drivers will tell us via ethtool what their sync period is.
        """
        if self._stats_settle_time is None:
            data = {}
            try:
                data = ethtool("-c " + self.ifname, json=True)[0]
            except CmdExitFailure as e:
                if "Operation not supported" not in e.cmd.stderr:
                    raise

            # 1.25x the reported stats-block-usecs (default 20ms), in seconds
            self._stats_settle_time = \
                1.25 * data.get('stats-block-usecs', 20000) / 1000 / 1000

        time.sleep(self._stats_settle_time)
304
305
class NetDrvContEnv(NetDrvEpEnv):
    """
    Class for an environment with a netkit pair setup for forwarding traffic
    between the physical interface and a network namespace.

    Example configuration (REMOTE_TYPE / REMOTE_ARGS are also needed, as
    for NetDrvEpEnv):
      NETIF           = "eth0"
      LOCAL_V6        = "2001:db8:1::1"
      REMOTE_V6       = "2001:db8:1::2"
      LOCAL_PREFIX_V6 = "2001:db8:2::0/64"

    The netkit addresses are derived from the LOCAL_PREFIX_V6 network
    address: nk_guest_ipv6 is <network> + ::2:2 (configured on nk_guest),
    nk_host_ipv6 is <network> + ::2:1 (computed and exposed, but not
    assigned to an interface here). The BPF program attached to NETIF
    ingress is given the network address and nk_host's ifindex.

              +-----------------------------+        +------------------------------+
      dst     | INIT NS                     |        | TEST NS                      |
      2001:   | +---------------+           |        |                              |
      db8:2:: | | NETIF         |           |  bpf   |                              |
      2:2 +---|>| 2001:db8:1::1 |           |redirect| +-------------------------+  |
          |   | |               |-----------|--------|>| Netkit                  |  |
          |   | +---------------+           | _peer  | | nk_guest                |  |
          |   | +-------------+ Netkit pair |        | | fe80::2/64              |  |
          |   | | Netkit      |.............|........|>| 2001:db8:2::2:2/64      |  |
          |   | | nk_host     |             |        | +-------------------------+  |
          |   | | fe80::1/64  |             |        |                              |
          |   | +-------------+             |        | route:                       |
          |   |                             |        |   default                    |
          |   | route:                      |        |     via fe80::1 dev nk_guest |
          |   |   2001:db8:2::2:2/128       |        +------------------------------+
          |   |     via fe80::2 dev nk_host |
          |   +-----------------------------+
          |
          |   +---------------+
          |   | REMOTE        |
          +---| 2001:db8:1::2 |
              +---------------+
    """

    def __init__(self, src_path, rxqueues=1, **kwargs):
        # Cleanup state is initialised before super().__init__() so that
        # __del__ can run safely if anything below raises.
        self.netns = None
        self._nk_host_ifname = None
        self._nk_guest_ifname = None
        self._tc_clsact_added = False
        self._tc_attached = False
        self._bpf_prog_pref = None
        self._bpf_prog_id = None
        self._init_ns_attached = False
        self._old_fwd = None
        self._old_accept_ra = None

        super().__init__(src_path, **kwargs)

        self.require_ipver("6")
        local_prefix = self.env.get("LOCAL_PREFIX_V6")
        if not local_prefix:
            raise KsftSkipEx("LOCAL_PREFIX_V6 required")

        net = ipaddress.IPv6Network(local_prefix, strict=False)
        self.ipv6_prefix = str(net.network_address)
        # Use address arithmetic instead of appending "2:1" to the text
        # form: the compressed network address does not always end in "::"
        # (e.g. 2001:0:0:0:1::/80 prints as "2001::1:0:0:0"), and string
        # concatenation then produces an invalid address. For a /64 the
        # result is identical to "<prefix>2:1" / "<prefix>2:2".
        self.nk_host_ipv6 = str(net.network_address + 0x20001)
        self.nk_guest_ipv6 = str(net.network_address + 0x20002)

        local_v6 = ipaddress.IPv6Address(self.addr_v["6"])
        if local_v6 in net:
            raise KsftSkipEx("LOCAL_V6 must not fall within LOCAL_PREFIX_V6")

        rtnl = RtnlFamily()
        rtnl.newlink(
            {
                "linkinfo": {
                    "kind": "netkit",
                    "data": {
                        "mode": "l2",
                        "policy": "forward",
                        "peer-policy": "forward",
                    },
                },
                "num-rx-queues": rxqueues,
            },
            flags=[Netlink.NLM_F_CREATE, Netlink.NLM_F_EXCL],
        )

        # The new pair is identified as "the netkit links that are not UP".
        # NOTE(review): any other down netkit device on the host makes this
        # fail, and in that case the freshly created pair is not removed.
        all_links = ip("-d link show", json=True)
        netkit_links = [link for link in all_links
                        if link.get('linkinfo', {}).get('info_kind') == 'netkit'
                        and 'UP' not in link.get('flags', [])]

        if len(netkit_links) != 2:
            raise KsftSkipEx("Failed to create netkit pair")

        # Lower ifindex becomes the guest side, higher one the host side
        netkit_links.sort(key=lambda x: x['ifindex'])
        guest_link, host_link = netkit_links
        self._nk_host_ifname = host_link['ifname']
        self._nk_guest_ifname = guest_link['ifname']
        self.nk_host_ifindex = host_link['ifindex']
        self.nk_guest_ifindex = guest_link['ifindex']

        self._setup_ns()
        self._attach_bpf()

    def __del__(self):
        """
        Undo everything __init__ set up, then run the parent cleanup.

        Every step is attempted even if an earlier one fails, so one failed
        tc/ip command does not leave the netkit pair, the namespaces, the
        sysctls or the parent's devices behind. The first error is
        re-raised once all steps have run.
        """
        errors = []
        for step in (self._del_tc_filter, self._del_clsact,
                     self._del_netkit, self._del_namespaces,
                     self._restore_sysctls, super().__del__):
            try:
                step()
            except Exception as e:  # re-raised below after all steps ran
                errors.append(e)
        if errors:
            raise errors[0]

    def _del_tc_filter(self):
        """Remove the BPF filter added by _attach_bpf(), if any."""
        if self._tc_attached:
            cmd(f"tc filter del dev {self.ifname} ingress pref {self._bpf_prog_pref}")
            self._tc_attached = False

    def _del_clsact(self):
        """Remove the clsact qdisc, but only if this env added it."""
        if self._tc_clsact_added:
            cmd(f"tc qdisc del dev {self.ifname} clsact")
            self._tc_clsact_added = False

    def _del_netkit(self):
        """Delete the netkit pair (only the host side is deleted explicitly)."""
        if self._nk_host_ifname:
            cmd(f"ip link del dev {self._nk_host_ifname}")
            self._nk_host_ifname = None
            self._nk_guest_ifname = None

    def _del_namespaces(self):
        """Drop the "init" netns name and release the test namespace."""
        if self._init_ns_attached:
            cmd("ip netns del init", fail=False)
            self._init_ns_attached = False

        if self.netns:
            del self.netns
            self.netns = None

    def _restore_sysctls(self):
        """Write back the IPv6 forwarding / accept_ra values saved in _setup_ns()."""
        if self._old_fwd is not None:
            with open("/proc/sys/net/ipv6/conf/all/forwarding", "w",
                      encoding="utf-8") as f:
                f.write(self._old_fwd)
            self._old_fwd = None
        if self._old_accept_ra is not None:
            with open("/proc/sys/net/ipv6/conf/all/accept_ra", "w",
                      encoding="utf-8") as f:
                f.write(self._old_accept_ra)
            self._old_accept_ra = None

    def _setup_ns(self):
        """
        Configure routing and the test namespace around the netkit pair.

        Saves, then sets, net.ipv6.conf.all.forwarding=1 and accept_ra=2.
        It then creates the test NetNS and moves nk_guest into it. Finally
        it configures link-local/global addresses and the routes shown in
        the class diagram.
        """
        fwd_path = "/proc/sys/net/ipv6/conf/all/forwarding"
        ra_path = "/proc/sys/net/ipv6/conf/all/accept_ra"
        with open(fwd_path, encoding="utf-8") as f:
            self._old_fwd = f.read().strip()
        with open(ra_path, encoding="utf-8") as f:
            self._old_accept_ra = f.read().strip()
        with open(fwd_path, "w", encoding="utf-8") as f:
            f.write("1")
        with open(ra_path, "w", encoding="utf-8") as f:
            f.write("2")

        self.netns = NetNS()
        # Name PID 1's network namespace "init" and give it nsid 0 inside
        # the test namespace (presumably so netkit can refer to it).
        cmd("ip netns attach init 1")
        self._init_ns_attached = True
        ip("netns set init 0", ns=self.netns)
        ip(f"link set dev {self._nk_guest_ifname} netns {self.netns.name}")
        ip(f"link set dev {self._nk_host_ifname} up")
        ip(f"-6 addr add fe80::1/64 dev {self._nk_host_ifname} nodad")
        ip(f"-6 route add {self.nk_guest_ipv6}/128 via fe80::2 dev {self._nk_host_ifname}")

        ip("link set lo up", ns=self.netns)
        ip(f"link set dev {self._nk_guest_ifname} up", ns=self.netns)
        ip(f"-6 addr add fe80::2/64 dev {self._nk_guest_ifname}", ns=self.netns)
        ip(f"-6 addr add {self.nk_guest_ipv6}/64 dev {self._nk_guest_ifname} nodad", ns=self.netns)
        ip(f"-6 route add default via fe80::1 dev {self._nk_guest_ifname}", ns=self.netns)

    def _tc_ensure_clsact(self):
        """Add a clsact qdisc to the device unless one is already present."""
        qdiscs = json.loads(cmd(f"tc -j qdisc show dev {self.ifname}").stdout)
        if any(q['kind'] == 'clsact' for q in qdiscs):
            return
        cmd(f"tc qdisc add dev {self.ifname} clsact")
        self._tc_clsact_added = True

    def _get_bpf_prog_ids(self):
        """
        Return (pref, prog id) of the nk_forward BPF filter on ingress.

        Filters without options or without a 'bpf_name' (e.g. other
        classifier types already attached to the device) are skipped.
        """
        filters = json.loads(cmd(f"tc -j filter show dev {self.ifname} ingress").stdout)
        for flt in filters:
            if 'options' not in flt:
                continue
            opts = flt['options']
            if opts.get('bpf_name', '').startswith('nk_forward.bpf'):
                return (flt['pref'], opts['prog']['id'])
        raise Exception("Failed to get BPF prog ID")

    def _attach_bpf(self):
        """
        Attach nk_forward.bpf.o to NETIF ingress and seed its .bss map.

        The value written to the .bss map is the 16-byte packed network
        address followed by nk_host's ifindex as 4 bytes.
        NOTE(review): the layout presumably mirrors the program's global
        variables; confirm against nk_forward.bpf.c.
        """
        import sys  # host byte order for the BPF global data

        bpf_obj = self.test_dir / "nk_forward.bpf.o"
        if not bpf_obj.exists():
            raise KsftSkipEx("BPF prog not found")

        self._tc_ensure_clsact()
        cmd(f"tc filter add dev {self.ifname} ingress bpf obj {bpf_obj}"
            " sec tc/ingress direct-action")
        self._tc_attached = True

        (self._bpf_prog_pref, self._bpf_prog_id) = self._get_bpf_prog_ids()
        prog_info = bpftool(f"prog show id {self._bpf_prog_id}", json=True)
        map_ids = prog_info.get("map_ids", [])

        bss_map_id = None
        for map_id in map_ids:
            map_info = bpftool(f"map show id {map_id}", json=True)
            if map_info.get("name", "").endswith("bss"):
                bss_map_id = map_id

        if bss_map_id is None:
            raise Exception("Failed to find .bss map")

        ipv6_bytes = ipaddress.IPv6Address(self.ipv6_prefix).packed
        # BPF global data lives in host byte order; hard-coding "little"
        # would hand the program a wrong ifindex on big-endian hosts.
        ifindex_bytes = self.nk_host_ifindex.to_bytes(4, byteorder=sys.byteorder)
        value_hex = ' '.join(f'{b:02x}' for b in ipv6_bytes + ifindex_bytes)
        bpftool(f"map update id {bss_map_id} key hex 00 00 00 00 value hex {value_hex}")
508