xref: /linux/tools/testing/selftests/drivers/net/lib/py/env.py (revision ab771c938d9a57d510bb70c565c9388b10494090)
1# SPDX-License-Identifier: GPL-2.0
2
3import ipaddress
4import os
5import re
6import time
7from pathlib import Path
8from lib.py import KsftSkipEx, KsftXfailEx
9from lib.py import ksft_setup, wait_file
10from lib.py import cmd, ethtool, ip, CmdExitFailure
11from lib.py import NetNS, NetdevSimDev
12from lib.py import NetdevFamily, EthtoolFamily
13from .remote import Remote
14from . import bpftool
15
16
class NetDrvEnvBase:
    """
    Common base class for NIC / host test environments.

    Attributes:
      src_path: Path of the test script, as passed in by the test
      test_dir: Resolved directory containing the test script
      net_lib_dir: Resolved path to the net/lib directory
      env: Environment variables (os.environ overlaid with net.config,
           if present) as returned by ksft_setup()
      dev: Dict describing the device under test (must provide 'ifname');
           left as None here and filled in by subclasses
    """
    def __init__(self, src_path):
        self.src_path = Path(src_path)
        self.test_dir = self.src_path.parent.resolve()
        lib_dir = Path(__file__).parent / "../../../../net/lib"
        self.net_lib_dir = lib_dir.resolve()

        self.env = self._load_env_file()

        # Inheriting classes are responsible for setting this
        self.dev = None

    def _load_env_file(self):
        """
        Build the test environment from os.environ plus net.config.

        net.config is looked up in the directory of the test script. Every
        line, once "#" comments and surrounding whitespace are dropped, is
        either empty (ignored) or a KEY=VALUE pair that overrides the
        inherited environment.

        Returns:
          Whatever ksft_setup() returns for the merged environment.

        Raises:
          Exception: if a non-empty line contains no "=".
        """
        env = os.environ.copy()
        config = Path(self.src_path).parent.resolve() / "net.config"

        if config.exists():
            with open(config.as_posix(), 'r') as fp:
                for raw_line in fp:
                    text = raw_line.split("#", 1)[0].strip()
                    if not text:
                        continue
                    key, sep, value = text.partition("=")
                    if not sep:
                        raise Exception("Can't parse configuration line:", raw_line)
                    env[key] = value

        return ksft_setup(env)

    def __del__(self):
        """Release resources; nothing to release at this level."""

    def __enter__(self):
        """Bring the device under test up and wait until it reports carrier."""
        ifname = self.dev['ifname']
        ip(f"link set dev {ifname} up")
        wait_file(f"/sys/class/net/{ifname}/carrier",
                  lambda content: content.strip() == "1")
        return self

    def __exit__(self, ex_type, ex_value, ex_tb):
        """
        Tear down at the end of a "with" block by running __del__().
        Exceptions raised inside the block are not suppressed.
        """
        self.__del__()
73
74
class NetDrvEnv(NetDrvEnvBase):
    """
    Class for a single NIC / host env, with no remote end.

    If NETIF is set in the environment, the named device is used.
    Otherwise a netdevsim device is created, and it is removed again
    by __del__().
    """
    def __init__(self, src_path, nsim_test=None, **kwargs):
        """
        Args:
          src_path: path of the test script (used to locate net.config)
          nsim_test: True if the test only works on netdevsim, False if it
                     does not work on netdevsim, None if it works on both.
                     A mismatch with the environment raises KsftXfailEx.
          kwargs: forwarded to NetdevSimDev when netdevsim is used
        """
        # Set before the base constructor, which may raise (e.g. on a
        # malformed net.config). __del__ still runs on such a half-built
        # object and must not trip over a missing attribute.
        self._ns = None

        super().__init__(src_path)

        if 'NETIF' in self.env:
            if nsim_test is True:
                raise KsftXfailEx("Test only works on netdevsim")

            self.dev = ip("-d link show dev " + self.env['NETIF'], json=True)[0]
        else:
            if nsim_test is False:
                raise KsftXfailEx("Test does not work on netdevsim")

            self._ns = NetdevSimDev(**kwargs)
            self.dev = self._ns.nsims[0].dev
        self.ifname = self.dev['ifname']
        self.ifindex = self.dev['ifindex']

    def __del__(self):
        """Remove the netdevsim device if one was created; safe to call twice."""
        if self._ns:
            self._ns.remove()
            self._ns = None
102
103
class NetDrvEpEnv(NetDrvEnvBase):
    """
    Class for an environment with a local device and "remote endpoint"
    which can be used to send traffic in.

    With NETIF set, the local device and the remote end come from the
    environment (LOCAL_V4/LOCAL_V6, REMOTE_V4/REMOTE_V6, REMOTE_TYPE,
    REMOTE_ARGS). Otherwise, for local testing, it creates one network
    namespace for the remote end and a pair of linked netdevsim devices:
    one in the current namespace and one in the new namespace.
    """

    # Network prefixes used for local tests
    nsim_v4_pfx = "192.0.2."
    nsim_v6_pfx = "2001:db8::"

    def __init__(self, src_path, nsim_test=None):
        """
        Args:
          src_path: path of the test script (used to locate net.config and
                    passed on to Remote)
          nsim_test: True if the test only works on netdevsim, False if it
                     does not work on netdevsim, None if it works on both.
                     A mismatch with the environment raises KsftXfailEx.
        """
        # Things we try to destroy. These are set before the base
        # constructor, which may raise (e.g. on a malformed net.config).
        # __del__ still runs on such a half-built object and must find them.
        self.remote = None
        # These are for local testing state
        self._netns = None
        self._ns = None
        self._ns_peer = None

        super().__init__(src_path)

        self._stats_settle_time = None

        self.addr_v        = { "4": None, "6": None }
        self.remote_addr_v = { "4": None, "6": None }

        if "NETIF" in self.env:
            if nsim_test is True:
                raise KsftXfailEx("Test only works on netdevsim")
            self._check_env()

            self.dev = ip("-d link show dev " + self.env['NETIF'], json=True)[0]

            self.addr_v["4"] = self.env.get("LOCAL_V4")
            self.addr_v["6"] = self.env.get("LOCAL_V6")
            self.remote_addr_v["4"] = self.env.get("REMOTE_V4")
            self.remote_addr_v["6"] = self.env.get("REMOTE_V6")
            kind = self.env["REMOTE_TYPE"]
            args = self.env["REMOTE_ARGS"]
        else:
            if nsim_test is False:
                raise KsftXfailEx("Test does not work on netdevsim")

            self.create_local()

            self.dev = self._ns.nsims[0].dev

            self.addr_v["4"] = self.nsim_v4_pfx + "1"
            self.addr_v["6"] = self.nsim_v6_pfx + "1"
            self.remote_addr_v["4"] = self.nsim_v4_pfx + "2"
            self.remote_addr_v["6"] = self.nsim_v6_pfx + "2"
            kind = "netns"
            args = self._netns.name

        self.remote = Remote(kind, args, src_path)

        # Prefer IPv6 as the "default" address family when it is configured
        self.addr_ipver = "6" if self.addr_v["6"] else "4"
        self.addr = self.addr_v[self.addr_ipver]
        self.remote_addr = self.remote_addr_v[self.addr_ipver]

        # Bracketed addresses, some commands need IPv6 to be inside []
        self.baddr = f"[{self.addr_v['6']}]" if self.addr_v["6"] else self.addr_v["4"]
        self.remote_baddr = f"[{self.remote_addr_v['6']}]" if self.remote_addr_v["6"] else self.remote_addr_v["4"]

        self.ifname = self.dev['ifname']
        self.ifindex = self.dev['ifindex']

        # resolve remote interface name
        self.remote_ifname = self.resolve_remote_ifc()
        self.remote_dev = ip("-d link show dev " + self.remote_ifname,
                             host=self.remote, json=True)[0]
        self.remote_ifindex = self.remote_dev['ifindex']

        self._required_cmd = {}

    def create_local(self):
        """
        Create the local test topology. This is one new netns, a netdevsim
        device in the current netns and a peer netdevsim device in the new
        netns, linked to each other. Each side gets the .1 / .2 addresses
        from nsim_v4_pfx / nsim_v6_pfx and is brought up.
        """
        self._netns = NetNS()
        self._ns = NetdevSimDev()
        self._ns_peer = NetdevSimDev(ns=self._netns)

        # link_device takes "<netns fd>:<ifindex> <netns fd>:<ifindex>"
        with open("/proc/self/ns/net") as nsfd0, \
             open("/var/run/netns/" + self._netns.name) as nsfd1:
            ifi0 = self._ns.nsims[0].ifindex
            ifi1 = self._ns_peer.nsims[0].ifindex
            NetdevSimDev.ctrl_write('link_device',
                                    f'{nsfd0.fileno()}:{ifi0} {nsfd1.fileno()}:{ifi1}')

        ip(f"   addr add dev {self._ns.nsims[0].ifname} {self.nsim_v4_pfx}1/24")
        ip(f"-6 addr add dev {self._ns.nsims[0].ifname} {self.nsim_v6_pfx}1/64 nodad")
        ip(f"   link set dev {self._ns.nsims[0].ifname} up")

        ip(f"   addr add dev {self._ns_peer.nsims[0].ifname} {self.nsim_v4_pfx}2/24", ns=self._netns)
        ip(f"-6 addr add dev {self._ns_peer.nsims[0].ifname} {self.nsim_v6_pfx}2/64 nodad", ns=self._netns)
        ip(f"   link set dev {self._ns_peer.nsims[0].ifname} up", ns=self._netns)

    def _check_env(self):
        """
        Validate the environment for the NETIF (real hardware) case.

        Requires at least one of LOCAL_V4/LOCAL_V6, at least one of
        REMOTE_V4/REMOTE_V6, REMOTE_TYPE and REMOTE_ARGS. For each address
        family, the local and remote addresses must both be set or both
        be unset.

        Raises:
          Exception: listing the missing / unbalanced variables.
        """
        vars_needed = [
            ["LOCAL_V4", "LOCAL_V6"],
            ["REMOTE_V4", "REMOTE_V6"],
            ["REMOTE_TYPE"],
            ["REMOTE_ARGS"]
        ]
        missing = []

        for choice in vars_needed:
            for entry in choice:
                if entry in self.env:
                    break
            else:
                missing.append(choice)
        # Make sure v4 / v6 configs are symmetric
        if ("LOCAL_V6" in self.env) != ("REMOTE_V6" in self.env):
            missing.append(["LOCAL_V6", "REMOTE_V6"])
        if ("LOCAL_V4" in self.env) != ("REMOTE_V4" in self.env):
            missing.append(["LOCAL_V4", "REMOTE_V4"])
        if missing:
            raise Exception("Invalid environment, missing configuration:", missing,
                            "Please see tools/testing/selftests/drivers/net/README.rst")

    def resolve_remote_ifc(self):
        """
        Find the name of the remote interface holding the remote address(es).

        Returns:
          The interface name from the IPv6 lookup if it matched, otherwise
          the one from the IPv4 lookup.

        Raises:
          Exception: if no interface holds the remote address, if more than
          one does, or if the v4 and v6 lookups name different interfaces.
        """
        v4 = v6 = None
        if self.remote_addr_v["4"]:
            v4 = ip("addr show to " + self.remote_addr_v["4"], json=True, host=self.remote)
        if self.remote_addr_v["6"]:
            v6 = ip("addr show to " + self.remote_addr_v["6"], json=True, host=self.remote)
        # Without this check an unmatched address ends in an opaque
        # IndexError / TypeError on the final v4[0] lookup.
        if not v4 and not v6:
            raise Exception("Can't resolve remote interface name, no interface has the remote address",
                            self.remote_addr_v)
        if v4 and v6 and v4[0]["ifname"] != v6[0]["ifname"]:
            raise Exception("Can't resolve remote interface name, v4 and v6 don't match")
        if (v4 and len(v4) > 1) or (v6 and len(v6) > 1):
            raise Exception("Can't resolve remote interface name, multiple interfaces match")
        return v6[0]["ifname"] if v6 else v4[0]["ifname"]

    def __del__(self):
        """
        Release the local netdevsim devices, the test netns and the remote
        handle. Each is cleared after release, so calling this twice is safe.
        """
        if self._ns:
            self._ns.remove()
            self._ns = None
        if self._ns_peer:
            self._ns_peer.remove()
            self._ns_peer = None
        if self._netns:
            del self._netns
            self._netns = None
        if self.remote:
            del self.remote
            self.remote = None

    def require_ipver(self, ipver):
        """Skip the test unless both ends have an address of family `ipver` ("4"/"6")."""
        if not self.addr_v[ipver] or not self.remote_addr_v[ipver]:
            raise KsftSkipEx(f"Test requires IPv{ipver} connectivity")

    def require_nsim(self, nsim_test=True):
        """Require or exclude netdevsim for this test"""
        if nsim_test and self._ns is None:
            raise KsftXfailEx("Test only works on netdevsim")
        if nsim_test is False and self._ns is not None:
            raise KsftXfailEx("Test does not work on netdevsim")

    def _require_cmd(self, comm, key, host=None):
        """
        Return whether `comm` is found by `command -v` on `host` (local when
        None). Results are cached per command under `key` ("local"/"remote").
        """
        cached = self._required_cmd.get(comm, {})
        if cached.get(key) is None:
            cached[key] = cmd("command -v -- " + comm, fail=False,
                              shell=True, host=host).ret == 0
        self._required_cmd[comm] = cached
        return cached[key]

    def require_cmd(self, comm, local=True, remote=False):
        """Skip the test unless `comm` is available locally and/or remotely."""
        if local:
            if not self._require_cmd(comm, "local"):
                raise KsftSkipEx("Test requires command: " + comm)
        if remote:
            if not self._require_cmd(comm, "remote", host=self.remote):
                raise KsftSkipEx("Test requires (remote) command: " + comm)

    def wait_hw_stats_settle(self):
        """
        Wait for HW stats to become consistent, some devices DMA HW stats
        periodically so events won't be reflected until next sync.
        Good drivers will tell us via ethtool what their sync period is.

        The wait is 25 ms plus the device's stats-block-usecs (from
        `ethtool -c`, treated as 0 if unsupported). It is computed once
        and cached.
        """
        if self._stats_settle_time is None:
            data = {}
            try:
                data = ethtool("-c " + self.ifname, json=True)[0]
            except CmdExitFailure as e:
                if "Operation not supported" not in e.cmd.stderr:
                    raise

            self._stats_settle_time = 0.025 + \
                data.get('stats-block-usecs', 0) / 1000 / 1000

        time.sleep(self._stats_settle_time)
296
297
class NetDrvContEnv(NetDrvEpEnv):
    """
    Class for an environment with a netkit pair setup for forwarding traffic
    between the physical interface and a network namespace.

    The host side of the netkit pair stays in the current namespace. The
    guest side is moved into a new namespace (self.netns) and given
    nk_guest_ipv6, taken from LOCAL_PREFIX_V6. nk_forward.bpf.o (from the
    test directory) is attached to the physical device's tc ingress, and
    its .bss is seeded with the prefix and the host netkit ifindex.

    Requires IPv6 connectivity and LOCAL_PREFIX_V6; skips otherwise.
    """

    def __init__(self, src_path, lease=False, **kwargs):
        """
        Args:
          src_path: path of the test script
          lease: if True, create the netkit pair with 2 RX queues and lease
                 the last combined queue of the physical device to the
                 guest netkit device
          kwargs: forwarded to NetDrvEpEnv (e.g. nsim_test)
        """
        # Cleanup state. Set before anything that can raise (the parent
        # constructor, the skip checks below, netkit creation). __del__ must
        # always find these attributes and reach NetDrvEpEnv.__del__,
        # otherwise the parent's netdevsim devices / netns / remote leak.
        self.netns = None
        self._nk_host_ifname = None
        self._nk_guest_ifname = None
        self._tc_attached = False
        self._bpf_prog_pref = None
        self._bpf_prog_id = None
        self._leased = False

        super().__init__(src_path, **kwargs)

        self.require_ipver("6")
        local_prefix = self.env.get("LOCAL_PREFIX_V6")
        if not local_prefix:
            raise KsftSkipEx("LOCAL_PREFIX_V6 required")

        self.netdevnl = NetdevFamily()
        self.ethnl = EthtoolFamily()

        local_prefix = self._prefix_base(local_prefix)
        self.ipv6_prefix = f"{local_prefix}::"
        self.nk_host_ipv6 = f"{local_prefix}::2:1"
        self.nk_guest_ipv6 = f"{local_prefix}::2:2"

        nk_rxqueues = 2 if lease else 1

        # Remember which links already exist. We then pick up only the pair
        # we create, not unrelated netkit devices that happen to be down.
        existing = {link['ifindex'] for link in ip("link show", json=True)}
        ip(f"link add type netkit mode l2 forward peer forward numrxqueues {nk_rxqueues}")

        netkit_links = [link for link in ip("-d link show", json=True)
                        if link['ifindex'] not in existing
                        and link.get('linkinfo', {}).get('info_kind') == 'netkit']

        if len(netkit_links) != 2:
            raise KsftSkipEx("Failed to create netkit pair")

        # The higher ifindex becomes the host side, the lower the guest side
        netkit_links.sort(key=lambda x: x['ifindex'])
        self._nk_host_ifname = netkit_links[1]['ifname']
        self._nk_guest_ifname = netkit_links[0]['ifname']
        self.nk_host_ifindex = netkit_links[1]['ifindex']
        self.nk_guest_ifindex = netkit_links[0]['ifindex']

        if lease:
            self._lease_queues()

        self._setup_ns()
        self._attach_bpf()

    @staticmethod
    def _prefix_base(prefix):
        """
        Reduce LOCAL_PREFIX_V6 to its leading groups, e.g.
        "2001:db8:1::/64" -> "2001:db8:1".

        Removes a literal "/64" suffix, then any trailing colons. Note that
        str.rstrip("/64") strips any trailing '/', '6' and '4' characters,
        so it would eat hex digits of prefixes such as "2001:db8:1:46".
        """
        if prefix.endswith("/64"):
            prefix = prefix[:-len("/64")]
        return prefix.rstrip(":")

    def __del__(self):
        """
        Undo the setup in this order:
          1. remove the tc BPF filter;
          2. delete the host netkit device;
          3. drop the namespace;
          4. restore the ring settings changed for leasing;
          5. let NetDrvEpEnv release its own resources.
        Each step clears its flag, so calling this twice is safe.
        """
        try:
            if self._tc_attached:
                cmd(f"tc filter del dev {self.ifname} ingress pref {self._bpf_prog_pref}")
                self._tc_attached = False

            if self._nk_host_ifname:
                cmd(f"ip link del dev {self._nk_host_ifname}")
                self._nk_host_ifname = None
                self._nk_guest_ifname = None

            if self.netns:
                del self.netns
                self.netns = None

            if self._leased:
                self.ethnl.rings_set({'header': {'dev-index': self.ifindex},
                                      'tcp-data-split': 'unknown',
                                      'hds-thresh': self._hds_thresh,
                                      'rx': self._rx_rings})
                self._leased = False
        finally:
            # Always release the parent's netdevsim / netns / remote, even
            # if one of the steps above failed.
            super().__del__()

    def _lease_queues(self):
        """
        Lease the last combined RX queue of the physical device to the
        guest netkit device.

        First saves the current rx ring size and hds-thresh, which __del__
        restores. Then enables tcp-data-split with hds-thresh 0 and 64 rx
        descriptors.

        Raises:
          KsftSkipEx: if the device has fewer than 2 combined channels.
        """
        channels = self.ethnl.channels_get({'header': {'dev-index': self.ifindex}})
        channels = channels['combined-count']
        if channels < 2:
            raise KsftSkipEx('Test requires NETIF with at least 2 combined channels')

        rings = self.ethnl.rings_get({'header': {'dev-index': self.ifindex}})
        self._rx_rings = rings['rx']
        self._hds_thresh = rings.get('hds-thresh', 0)
        self.ethnl.rings_set({'header': {'dev-index': self.ifindex},
                            'tcp-data-split': 'enabled',
                            'hds-thresh': 0,
                            'rx': 64})
        self.src_queue = channels - 1
        bind_result = self.netdevnl.queue_create(
            {
                "ifindex": self.nk_guest_ifindex,
                "type": "rx",
                "lease": {
                    "ifindex": self.ifindex,
                    "queue": {"id": self.src_queue, "type": "rx"},
                },
            }
        )
        self.nk_queue = bind_result['id']
        self._leased = True

    def _setup_ns(self):
        """
        Move the guest netkit device into a new netns and configure routing.

        The host side gets fe80::1 and a /128 route to nk_guest_ipv6 via
        fe80::2. The guest side gets fe80::2 and nk_guest_ipv6/64, with a
        default route via fe80::1.
        """
        self.netns = NetNS()
        ip(f"link set dev {self._nk_guest_ifname} netns {self.netns.name}")
        ip(f"link set dev {self._nk_host_ifname} up")
        ip(f"-6 addr add fe80::1/64 dev {self._nk_host_ifname} nodad")
        ip(f"-6 route add {self.nk_guest_ipv6}/128 via fe80::2 dev {self._nk_host_ifname}")

        ip("link set lo up", ns=self.netns)
        ip(f"link set dev {self._nk_guest_ifname} up", ns=self.netns)
        ip(f"-6 addr add fe80::2/64 dev {self._nk_guest_ifname}", ns=self.netns)
        ip(f"-6 addr add {self.nk_guest_ipv6}/64 dev {self._nk_guest_ifname} nodad", ns=self.netns)
        ip(f"-6 route add default via fe80::1 dev {self._nk_guest_ifname}", ns=self.netns)

    def _attach_bpf(self):
        """
        Attach nk_forward.bpf.o to the physical device's tc ingress
        (direct-action) and seed its .bss map.

        Key 0 of the .bss map is written with the 16-byte IPv6 prefix
        followed by the host netkit ifindex as 4 little-endian bytes.
        NOTE(review): the fixed little-endian encoding presumably matches
        the BPF .bss layout on LE hosts only; confirm for big-endian.

        Raises:
          KsftSkipEx: if the BPF object was not built.
          Exception: if the filter pref / prog id or the .bss map cannot
          be found.
        """
        bpf_obj = self.test_dir / "nk_forward.bpf.o"
        if not bpf_obj.exists():
            raise KsftSkipEx("BPF prog not found")

        cmd(f"tc filter add dev {self.ifname} ingress bpf obj {bpf_obj} sec tc/ingress direct-action")
        self._tc_attached = True

        tc_info = cmd(f"tc filter show dev {self.ifname} ingress").stdout
        match = re.search(r'pref (\d+).*nk_forward\.bpf.*id (\d+)', tc_info)
        if not match:
            raise Exception("Failed to get BPF prog ID")
        self._bpf_prog_pref = int(match.group(1))
        self._bpf_prog_id = int(match.group(2))

        prog_info = bpftool(f"prog show id {self._bpf_prog_id}", json=True)
        map_ids = prog_info.get("map_ids", [])

        bss_map_id = None
        for map_id in map_ids:
            map_info = bpftool(f"map show id {map_id}", json=True)
            # Unnamed maps have no "name" key; don't crash on them
            if map_info.get("name", "").endswith("bss"):
                bss_map_id = map_id

        if bss_map_id is None:
            raise Exception("Failed to find .bss map")

        ipv6_addr = ipaddress.IPv6Address(self.ipv6_prefix)
        ipv6_bytes = ipv6_addr.packed
        ifindex_bytes = self.nk_host_ifindex.to_bytes(4, byteorder='little')
        value = ipv6_bytes + ifindex_bytes
        value_hex = ' '.join(f'{b:02x}' for b in value)
        bpftool(f"map update id {bss_map_id} key hex 00 00 00 00 value hex {value_hex}")
449