xref: /linux/tools/testing/selftests/drivers/net/lib/py/env.py (revision 77a6401a8722be20ea8db98ac900c93ccc7068ff)
1# SPDX-License-Identifier: GPL-2.0
2
3import ipaddress
4import os
5import time
6import json
7from pathlib import Path
8from lib.py import KsftSkipEx, KsftXfailEx
9from lib.py import ksft_setup, wait_file
10from lib.py import cmd, ethtool, ip, CmdExitFailure
11from lib.py import NetNS, NetdevSimDev
12from .remote import Remote
13from . import bpftool, RtnlFamily, Netlink
14
15
class NetDrvEnvBase:
    """
    Base class for NIC / host environments.

    Attributes:
      src_path: Path of the test source file passed to the constructor
      test_dir: Path to the source directory of the test
      net_lib_dir: Path to the net/lib directory
      env: os.environ overlaid with the test directory's net.config
           (if present), as returned by ksft_setup()
      dev: link info dict of the device under test, set by subclasses
    """
    def __init__(self, src_path):
        self.src_path = Path(src_path)
        self.test_dir = self.src_path.parent.resolve()
        self.net_lib_dir = (Path(__file__).parent / "../../../../net/lib").resolve()

        self.env = self._load_env_file()

        # Following attrs must be set by inheriting classes
        self.dev = None

    @staticmethod
    def _parse_config_lines(lines, env=None):
        """
        Parse KEY=VALUE lines in net.config format into a dict.

        Everything after a '#' is a comment, blank lines are skipped and
        whitespace around the key and the value is stripped, so both
        "NETIF=eth0" and "NETIF = eth0" give {"NETIF": "eth0"}.
        Only the first '=' separates key from value.

        Args:
          lines: iterable of text lines (e.g. an open file)
          env: dict to update in place; a new dict is used if None
        Returns:
          The updated dict.
        Raises:
          Exception: if a non-empty line has no '=' or has an empty key.
        """
        if env is None:
            env = {}
        for full_line in lines:
            # Strip comments, then surrounding whitespace
            line = full_line.split("#", maxsplit=1)[0].strip()
            if not line:
                continue
            key, sep, value = line.partition("=")
            key = key.strip()
            if not sep or not key:
                raise Exception("Can't parse configuration line:", full_line)
            env[key] = value.strip()
        return env

    def _load_env_file(self):
        """
        Return a copy of os.environ overlaid with net.config from the test
        directory (if that file exists), passed through ksft_setup().
        """
        env = os.environ.copy()

        config = Path(self.src_path).parent.resolve() / "net.config"
        if config.exists():
            with open(config, "r", encoding="utf-8") as fp:
                self._parse_config_lines(fp, env)
        return ksft_setup(env)

    def __del__(self):
        pass

    def __enter__(self):
        """Bring the device under test up and wait for carrier."""
        ip(f"link set dev {self.dev['ifname']} up")
        wait_file(f"/sys/class/net/{self.dev['ifname']}/carrier",
                  lambda x: x.strip() == "1")

        return self

    def __exit__(self, ex_type, ex_value, ex_tb):
        """
        __exit__ gets called at the end of a "with" block.
        """
        self.__del__()
72
73
class NetDrvEnv(NetDrvEnvBase):
    """
    Class for a single NIC / host env, with no remote end.

    If NETIF is set (in the environment or net.config) that device is used,
    otherwise a netdevsim device is created and removed again on cleanup.

    Args:
      src_path: path to the test source file
      nsim_test: True if the test only works on netdevsim, False if it does
                 not work on netdevsim, None if it works on either
      kwargs: passed to NetdevSimDev() when a netdevsim device is created
    """
    def __init__(self, src_path, nsim_test=None, **kwargs):
        # Set before anything that can raise, so __del__ never sees a
        # partially constructed object without this attribute.
        self._ns = None

        super().__init__(src_path)

        if 'NETIF' in self.env:
            if nsim_test is True:
                raise KsftXfailEx("Test only works on netdevsim")

            self.dev = ip("-d link show dev " + self.env['NETIF'], json=True)[0]
        else:
            if nsim_test is False:
                raise KsftXfailEx("Test does not work on netdevsim")

            self._ns = NetdevSimDev(**kwargs)
            self.dev = self._ns.nsims[0].dev
        self.ifname = self.dev['ifname']
        self.ifindex = self.dev['ifindex']

    def __del__(self):
        # Only a netdevsim device we created needs removing
        if self._ns:
            self._ns.remove()
            self._ns = None
101
102
class NetDrvEpEnv(NetDrvEnvBase):
    """
    Class for an environment with a local device and "remote endpoint"
    which can be used to send traffic in.

    For local testing (no NETIF configured) it creates a network namespace
    for the remote end and a pair of linked netdevsim devices, one in the
    current namespace and one in the new namespace.
    """

    # Network prefixes used for local tests
    nsim_v4_pfx = "192.0.2."
    nsim_v6_pfx = "2001:db8::"

    def __init__(self, src_path, nsim_test=None):
        # Everything __del__ looks at is set before anything that can raise
        # (including the config parsing in the base class constructor).
        self._stats_settle_time = None
        self._required_cmd = {}

        # Things we try to destroy
        self.remote = None
        # These are for local testing state
        self._netns = None
        self._ns = None
        self._ns_peer = None

        super().__init__(src_path)

        self.addr_v        = { "4": None, "6": None }
        self.remote_addr_v = { "4": None, "6": None }

        if "NETIF" in self.env:
            if nsim_test is True:
                raise KsftXfailEx("Test only works on netdevsim")
            self._check_env()

            self.dev = ip("-d link show dev " + self.env['NETIF'], json=True)[0]

            self.addr_v["4"] = self.env.get("LOCAL_V4")
            self.addr_v["6"] = self.env.get("LOCAL_V6")
            self.remote_addr_v["4"] = self.env.get("REMOTE_V4")
            self.remote_addr_v["6"] = self.env.get("REMOTE_V6")
            kind = self.env["REMOTE_TYPE"]
            args = self.env["REMOTE_ARGS"]
        else:
            if nsim_test is False:
                raise KsftXfailEx("Test does not work on netdevsim")

            self.create_local()

            self.dev = self._ns.nsims[0].dev

            self.addr_v["4"] = self.nsim_v4_pfx + "1"
            self.addr_v["6"] = self.nsim_v6_pfx + "1"
            self.remote_addr_v["4"] = self.nsim_v4_pfx + "2"
            self.remote_addr_v["6"] = self.nsim_v6_pfx + "2"
            kind = "netns"
            args = self._netns.name

        self.remote = Remote(kind, args, src_path)

        # Prefer IPv6 when it is configured
        self.addr_ipver = "6" if self.addr_v["6"] else "4"
        self.addr = self.addr_v[self.addr_ipver]
        self.remote_addr = self.remote_addr_v[self.addr_ipver]

        # Bracketed addresses, some commands need IPv6 to be inside []
        self.baddr = f"[{self.addr_v['6']}]" if self.addr_v["6"] else self.addr_v["4"]
        self.remote_baddr = f"[{self.remote_addr_v['6']}]" if self.remote_addr_v["6"] else self.remote_addr_v["4"]

        self.ifname = self.dev['ifname']
        self.ifindex = self.dev['ifindex']

        # resolve remote interface name
        self.remote_ifname = self.resolve_remote_ifc()
        self.remote_dev = ip("-d link show dev " + self.remote_ifname,
                             host=self.remote, json=True)[0]
        self.remote_ifindex = self.remote_dev['ifindex']

    def create_local(self):
        """
        Create the local test topology: a netns for the remote end, one
        netdevsim device here and one in that netns, linked to each other,
        with nsim_v4_pfx/nsim_v6_pfx ".1" locally and ".2" on the peer.
        """
        self._netns = NetNS()
        self._ns = NetdevSimDev()
        self._ns_peer = NetdevSimDev(ns=self._netns)

        # netdevsim links devices by "<netns fd>:<ifindex>" pairs
        with open("/proc/self/ns/net") as nsfd0, \
             open("/var/run/netns/" + self._netns.name) as nsfd1:
            ifi0 = self._ns.nsims[0].ifindex
            ifi1 = self._ns_peer.nsims[0].ifindex
            NetdevSimDev.ctrl_write('link_device',
                                    f'{nsfd0.fileno()}:{ifi0} {nsfd1.fileno()}:{ifi1}')

        ip(f"   addr add dev {self._ns.nsims[0].ifname} {self.nsim_v4_pfx}1/24")
        ip(f"-6 addr add dev {self._ns.nsims[0].ifname} {self.nsim_v6_pfx}1/64 nodad")
        ip(f"   link set dev {self._ns.nsims[0].ifname} up")

        ip(f"   addr add dev {self._ns_peer.nsims[0].ifname} {self.nsim_v4_pfx}2/24", ns=self._netns)
        ip(f"-6 addr add dev {self._ns_peer.nsims[0].ifname} {self.nsim_v6_pfx}2/64 nodad", ns=self._netns)
        ip(f"   link set dev {self._ns_peer.nsims[0].ifname} up", ns=self._netns)

    def _check_env(self):
        """
        Validate the configuration needed for a real (NETIF) setup.

        Requires at least one local and one remote address, REMOTE_TYPE and
        REMOTE_ARGS, and that v4 / v6 addresses are given for both ends.

        Raises:
          Exception: listing the missing / inconsistent variables.
        """
        vars_needed = [
            ["LOCAL_V4", "LOCAL_V6"],
            ["REMOTE_V4", "REMOTE_V6"],
            ["REMOTE_TYPE"],
            ["REMOTE_ARGS"]
        ]
        missing = []

        for choice in vars_needed:
            # Any one entry of each group is sufficient
            for entry in choice:
                if entry in self.env:
                    break
            else:
                missing.append(choice)
        # Make sure v4 / v6 configs are symmetric
        if ("LOCAL_V6" in self.env) != ("REMOTE_V6" in self.env):
            missing.append(["LOCAL_V6", "REMOTE_V6"])
        if ("LOCAL_V4" in self.env) != ("REMOTE_V4" in self.env):
            missing.append(["LOCAL_V4", "REMOTE_V4"])
        if missing:
            raise Exception("Invalid environment, missing configuration:", missing,
                            "Please see tools/testing/selftests/drivers/net/README.rst")

    def resolve_remote_ifc(self):
        """
        Return the name of the remote interface holding the remote address(es).

        Raises:
          Exception: if no interface holds any of the configured remote
            addresses, if more than one interface matches, or if the v4 and
            v6 addresses live on different interfaces.
        """
        v4 = v6 = None
        if self.remote_addr_v["4"]:
            v4 = ip("addr show to " + self.remote_addr_v["4"], json=True, host=self.remote)
        if self.remote_addr_v["6"]:
            v6 = ip("addr show to " + self.remote_addr_v["6"], json=True, host=self.remote)
        # Without this check the final v4[0] would raise a bare IndexError
        if not v4 and not v6:
            raise Exception("Can't resolve remote interface name, address not found")
        if v4 and v6 and v4[0]["ifname"] != v6[0]["ifname"]:
            raise Exception("Can't resolve remote interface name, v4 and v6 don't match")
        if (v4 and len(v4) > 1) or (v6 and len(v6) > 1):
            raise Exception("Can't resolve remote interface name, multiple interfaces match")
        return v6[0]["ifname"] if v6 else v4[0]["ifname"]

    def __del__(self):
        # Tear down local-test state and the remote handle; safe to call twice
        if self._ns:
            self._ns.remove()
            self._ns = None
        if self._ns_peer:
            self._ns_peer.remove()
            self._ns_peer = None
        if self._netns:
            del self._netns
            self._netns = None
        if self.remote:
            del self.remote
            self.remote = None

    def require_ipver(self, ipver):
        """Skip the test unless both ends have an address of version ipver ("4"/"6")."""
        if not self.addr_v[ipver] or not self.remote_addr_v[ipver]:
            raise KsftSkipEx(f"Test requires IPv{ipver} connectivity")

    def require_nsim(self, nsim_test=True):
        """Require or exclude netdevsim for this test"""
        if nsim_test and self._ns is None:
            raise KsftXfailEx("Test only works on netdevsim")
        if nsim_test is False and self._ns is not None:
            raise KsftXfailEx("Test does not work on netdevsim")

    def _require_cmd(self, comm, key, host=None):
        """Return (and cache per comm/key) whether comm is available on host."""
        cached = self._required_cmd.get(comm, {})
        if cached.get(key) is None:
            cached[key] = cmd("command -v -- " + comm, fail=False,
                              shell=True, host=host).ret == 0
        self._required_cmd[comm] = cached
        return cached[key]

    def require_cmd(self, comm, local=True, remote=False):
        """Skip the test unless comm is available locally and/or remotely."""
        if local:
            if not self._require_cmd(comm, "local"):
                raise KsftSkipEx("Test requires command: " + comm)
        if remote:
            if not self._require_cmd(comm, "remote", host=self.remote):
                raise KsftSkipEx("Test requires (remote) command: " + comm)

    def wait_hw_stats_settle(self):
        """
        Wait for HW stats to become consistent, some devices DMA HW stats
        periodically so events won't be reflected until next sync.
        Good drivers will tell us via ethtool what their sync period is.
        """
        if self._stats_settle_time is None:
            data = {}
            try:
                data = ethtool("-c " + self.ifname, json=True)[0]
            except CmdExitFailure as e:
                if "Operation not supported" not in e.cmd.stderr:
                    raise

            # 25ms base plus the driver-reported stats block period (usecs)
            self._stats_settle_time = 0.025 + \
                data.get('stats-block-usecs', 0) / 1000 / 1000

        time.sleep(self._stats_settle_time)
295
296
class NetDrvContEnv(NetDrvEpEnv):
    """
    Class for an environment with a netkit pair setup for forwarding traffic
    between the physical interface and a network namespace.

    Requires IPv6 and LOCAL_PREFIX_V6 on top of the NetDrvEpEnv
    configuration (REMOTE_TYPE / REMOTE_ARGS are needed as well), e.g.:
      NETIF=eth0
      LOCAL_V6=2001:db8:1::1
      REMOTE_V6=2001:db8:1::2
      LOCAL_PREFIX_V6=2001:db8:2::0/64

    The guest end of the pair gets <prefix>::2:2 (nk_guest_ipv6) and
    nk_host_ipv6 is <prefix>::2:1.

              +-----------------------------+        +------------------------------+
      dst     | INIT NS                     |        | TEST NS                      |
      2001:   | +---------------+           |        |                              |
      db8:2:: | | NETIF         |           |  bpf   |                              |
      2:2 +---|>| 2001:db8:1::1 |           |redirect| +-------------------------+  |
          |   | |               |-----------|--------|>| Netkit                  |  |
          |   | +---------------+           | _peer  | | nk_guest                |  |
          |   | +-------------+ Netkit pair |        | | fe80::2/64              |  |
          |   | | Netkit      |.............|........|>| 2001:db8:2::2:2/64      |  |
          |   | | nk_host     |             |        | +-------------------------+  |
          |   | | fe80::1/64  |             |        |                              |
          |   | +-------------+             |        | route:                       |
          |   |                             |        |   default                    |
          |   | route:                      |        |     via fe80::1 dev nk_guest |
          |   |   2001:db8:2::2:2/128       |        +------------------------------+
          |   |     via fe80::2 dev nk_host |
          |   +-----------------------------+
          |
          |   +---------------+
          |   | REMOTE        |
          +---| 2001:db8:1::2 |
              +---------------+
    """

    def __init__(self, src_path, rxqueues=1, **kwargs):
        """
        Args:
          src_path: path to the test source file
          rxqueues: number of RX queues requested for the netkit device
          kwargs: passed through to NetDrvEpEnv (e.g. nsim_test)
        Raises:
          KsftSkipEx: if IPv6 / LOCAL_PREFIX_V6 / the BPF object are
            unavailable, or the netkit pair could not be created.
        """
        # Cleanup state, set before anything can raise so __del__ is safe
        self.netns = None
        self._nk_host_ifname = None
        self._nk_guest_ifname = None
        self._tc_clsact_added = False
        self._tc_attached = False
        self._bpf_prog_pref = None
        self._bpf_prog_id = None
        self._init_ns_attached = False
        self._old_fwd = None
        self._old_accept_ra = None

        super().__init__(src_path, **kwargs)

        self.require_ipver("6")
        local_prefix = self.env.get("LOCAL_PREFIX_V6")
        if not local_prefix:
            raise KsftSkipEx("LOCAL_PREFIX_V6 required")

        net = ipaddress.IPv6Network(local_prefix, strict=False)
        self.ipv6_prefix = str(net.network_address)
        # <prefix>::2:1 / <prefix>::2:2 computed numerically; appending text
        # to the prefix string only works when the compressed network
        # address happens to end in "::".
        self.nk_host_ipv6 = str(net.network_address + 0x20001)
        self.nk_guest_ipv6 = str(net.network_address + 0x20002)

        local_v6 = ipaddress.IPv6Address(self.addr_v["6"])
        if local_v6 in net:
            raise KsftSkipEx("LOCAL_V6 must not fall within LOCAL_PREFIX_V6")

        # Snapshot existing links so the new pair can be told apart from any
        # netkit devices which already exist on the system.
        old_ifindexes = {link['ifindex'] for link in ip("link show", json=True)}

        rtnl = RtnlFamily()
        rtnl.newlink(
            {
                "linkinfo": {
                    "kind": "netkit",
                    "data": {
                        "mode": "l2",
                        "policy": "forward",
                        "peer-policy": "forward",
                    },
                },
                "num-rx-queues": rxqueues,
            },
            flags=[Netlink.NLM_F_CREATE, Netlink.NLM_F_EXCL],
        )

        netkit_links = [link for link in ip("-d link show", json=True)
                        if link['ifindex'] not in old_ifindexes
                        and link.get('linkinfo', {}).get('info_kind') == 'netkit']

        if len(netkit_links) != 2:
            # __del__ doesn't know these names yet, so don't leak them
            for link in netkit_links:
                cmd(f"ip link del dev {link['ifname']}", fail=False)
            raise KsftSkipEx("Failed to create netkit pair")

        # Lower ifindex becomes the guest end, higher one the host end
        netkit_links.sort(key=lambda x: x['ifindex'])
        self._nk_host_ifname = netkit_links[1]['ifname']
        self._nk_guest_ifname = netkit_links[0]['ifname']
        self.nk_host_ifindex = netkit_links[1]['ifindex']
        self.nk_guest_ifindex = netkit_links[0]['ifindex']

        self._setup_ns()
        self._attach_bpf()

    def __del__(self):
        """Undo everything set up by __init__; safe to call more than once."""
        if self._tc_attached:
            # The pref is unknown if looking it up failed after attaching;
            # "pref None" would fail and abort the rest of the cleanup.
            # The filter then goes away only if we added clsact ourselves.
            if self._bpf_prog_pref is not None:
                cmd(f"tc filter del dev {self.ifname} ingress pref {self._bpf_prog_pref}")
            self._tc_attached = False

        if self._tc_clsact_added:
            cmd(f"tc qdisc del dev {self.ifname} clsact")
            self._tc_clsact_added = False

        if self._nk_host_ifname:
            # Deleting one end removes the whole pair
            cmd(f"ip link del dev {self._nk_host_ifname}")
            self._nk_host_ifname = None
            self._nk_guest_ifname = None

        if self._init_ns_attached:
            cmd("ip netns del init", fail=False)
            self._init_ns_attached = False

        if self.netns:
            del self.netns
            self.netns = None

        if self._old_fwd is not None:
            with open("/proc/sys/net/ipv6/conf/all/forwarding", "w",
                      encoding="utf-8") as f:
                f.write(self._old_fwd)
            self._old_fwd = None
        if self._old_accept_ra is not None:
            with open("/proc/sys/net/ipv6/conf/all/accept_ra", "w",
                      encoding="utf-8") as f:
                f.write(self._old_accept_ra)
            self._old_accept_ra = None

        super().__del__()

    def _setup_ns(self):
        """
        Enable IPv6 forwarding (saving the old sysctls for __del__), move
        the guest end into a new netns and configure addresses and routes
        on both ends as shown in the class diagram.
        """
        fwd_path = "/proc/sys/net/ipv6/conf/all/forwarding"
        ra_path = "/proc/sys/net/ipv6/conf/all/accept_ra"
        with open(fwd_path, encoding="utf-8") as f:
            self._old_fwd = f.read().strip()
        with open(ra_path, encoding="utf-8") as f:
            self._old_accept_ra = f.read().strip()
        with open(fwd_path, "w", encoding="utf-8") as f:
            f.write("1")
        # accept_ra=2: keep accepting RAs even with forwarding enabled
        with open(ra_path, "w", encoding="utf-8") as f:
            f.write("2")

        self.netns = NetNS()
        # Name the init netns so the test netns can reference it (nsid 0)
        cmd("ip netns attach init 1")
        self._init_ns_attached = True
        ip("netns set init 0", ns=self.netns)
        ip(f"link set dev {self._nk_guest_ifname} netns {self.netns.name}")
        ip(f"link set dev {self._nk_host_ifname} up")
        ip(f"-6 addr add fe80::1/64 dev {self._nk_host_ifname} nodad")
        ip(f"-6 route add {self.nk_guest_ipv6}/128 via fe80::2 dev {self._nk_host_ifname}")

        ip("link set lo up", ns=self.netns)
        ip(f"link set dev {self._nk_guest_ifname} up", ns=self.netns)
        # nodad like every other address here: a tentative fe80::2 would
        # drop traffic on the host route until DAD completes
        ip(f"-6 addr add fe80::2/64 dev {self._nk_guest_ifname} nodad", ns=self.netns)
        ip(f"-6 addr add {self.nk_guest_ipv6}/64 dev {self._nk_guest_ifname} nodad", ns=self.netns)
        ip(f"-6 route add default via fe80::1 dev {self._nk_guest_ifname}", ns=self.netns)

    def _tc_ensure_clsact(self):
        """Add a clsact qdisc to NETIF unless one is already present."""
        qdisc = json.loads(cmd(f"tc -j qdisc show dev {self.ifname}").stdout)
        for q in qdisc:
            if q['kind'] == 'clsact':
                return
        cmd(f"tc qdisc add dev {self.ifname} clsact")
        self._tc_clsact_added = True

    def _get_bpf_prog_ids(self):
        """
        Return (pref, prog id) of the nk_forward BPF filter on NETIF ingress.

        Raises:
          Exception: if no such filter is found.
        """
        filters = json.loads(cmd(f"tc -j filter show dev {self.ifname} ingress").stdout)
        for bpf in filters:
            # Other classifiers (flower, u32, ...) also carry 'options' but
            # have no 'bpf_name'; entries without 'options' are skipped too.
            bpf_name = bpf.get('options', {}).get('bpf_name', '')
            if bpf_name.startswith('nk_forward.bpf'):
                return (bpf['pref'], bpf['options']['prog']['id'])
        raise Exception("Failed to get BPF prog ID")

    def _attach_bpf(self):
        """
        Attach nk_forward.bpf.o to NETIF ingress and write the prefix and
        nk_host ifindex into the program's .bss map.
        """
        import sys

        bpf_obj = self.test_dir / "nk_forward.bpf.o"
        if not bpf_obj.exists():
            raise KsftSkipEx("BPF prog not found")

        self._tc_ensure_clsact()
        cmd(f"tc filter add dev {self.ifname} ingress bpf obj {bpf_obj}"
            " sec tc/ingress direct-action")
        self._tc_attached = True

        (self._bpf_prog_pref, self._bpf_prog_id) = self._get_bpf_prog_ids()
        prog_info = bpftool(f"prog show id {self._bpf_prog_id}", json=True)
        map_ids = prog_info.get("map_ids", [])

        bss_map_id = None
        for map_id in map_ids:
            map_info = bpftool(f"map show id {map_id}", json=True)
            if map_info.get("name", "").endswith("bss"):
                bss_map_id = map_id

        if bss_map_id is None:
            raise Exception("Failed to find .bss map")

        # Value layout: 16-byte IPv6 prefix (network order) followed by the
        # host ifindex as u32. BPF globals use host byte order, so pack the
        # ifindex natively rather than assuming a little-endian machine.
        ipv6_addr = ipaddress.IPv6Address(self.ipv6_prefix)
        ipv6_bytes = ipv6_addr.packed
        ifindex_bytes = self.nk_host_ifindex.to_bytes(4, byteorder=sys.byteorder)
        value = ipv6_bytes + ifindex_bytes
        value_hex = ' '.join(f'{b:02x}' for b in value)
        bpftool(f"map update id {bss_map_id} key hex 00 00 00 00 value hex {value_hex}")
499