xref: /linux/tools/testing/selftests/drivers/net/psp.py (revision 50d3bdfb84c88408934f75430d0e3d2baa4f5d7a)
1#!/usr/bin/env python3
2# SPDX-License-Identifier: GPL-2.0
3
4"""Test suite for PSP capable drivers."""
5
6import errno
7import fcntl
8import os
9import socket
10import struct
11import termios
12import time
13
14from lib.py import defer
15from lib.py import ksft_run, ksft_exit, ksft_pr
16from lib.py import ksft_true, ksft_eq, ksft_ne, ksft_gt, ksft_raises
17from lib.py import ksft_not_none
18from lib.py import ksft_variants, KsftNamedVariant
19from lib.py import KsftSkipEx, KsftFailEx
20from lib.py import NetDrvEpEnv, NetDrvContEnv
21from lib.py import Netlink, NlError, PSPFamily, RtnlFamily
22from lib.py import NetNSEnter
23from lib.py import bkg, rand_port, wait_port_listen
24from lib.py import ip
25
26
def _get_outq(s):
    """Return the value of the TIOCOUTQ ioctl for socket s.

    The 4-byte ioctl result is decoded as a native unsigned int.
    """
    raw = fcntl.ioctl(s.fileno(), termios.TIOCOUTQ, bytes(4))
    (pending,) = struct.unpack("I", raw)
    return pending
31
32
def _send_with_ack(cfg, msg):
    """Send msg on cfg.comm_sock and require an 'ack' + NUL reply.

    Reads up to 4 bytes of response; raises RuntimeError (carrying the
    received bytes) when the response is anything else.
    """
    sock = cfg.comm_sock
    sock.send(msg)
    reply = sock.recv(4)
    if reply == b'ack\0':
        return
    raise RuntimeError("Unexpected server response", reply)
38
39
def _remote_read_len(cfg):
    """Send the 'read len' command on cfg.comm_sock and return the reply as int.

    The last byte of the reply is dropped before parsing (presumably the
    responder's NUL terminator).
    """
    sock = cfg.comm_sock
    sock.send(b'read len\0')
    reply = sock.recv(1024)
    return int(reply[:-1].decode('utf-8'))
43
44
def _make_clr_conn(cfg, ipver=None):
    """Announce a 'conn clr' connection to the responder and open it.

    Connects to cfg.remote_addr_v[ipver] when ipver is given, otherwise to
    cfg.remote_addr, on cfg.comm_port. Returns the connected socket.
    """
    _send_with_ack(cfg, b'conn clr\0')
    if ipver:
        remote_addr = cfg.remote_addr_v[ipver]
    else:
        remote_addr = cfg.remote_addr
    return socket.create_connection((remote_addr, cfg.comm_port))
50
51
def _make_psp_conn(cfg, version=0, ipver=None):
    """Announce a 'conn psp' connection to the responder and open it.

    The version is packed twice as two unsigned bytes after the command.
    Connects to cfg.remote_addr_v[ipver] when ipver is given, otherwise to
    cfg.remote_addr, on cfg.comm_port. Returns the connected socket.
    """
    request = b'conn psp\0' + struct.pack('BB', version, version)
    _send_with_ack(cfg, request)
    if ipver:
        remote_addr = cfg.remote_addr_v[ipver]
    else:
        remote_addr = cfg.remote_addr
    return socket.create_connection((remote_addr, cfg.comm_port))
57
58
def _close_conn(cfg, s):
    """Send 'data close' to the responder (waiting for its ack), then close s."""
    # The control message goes first so the ack is received while s is
    # still open (presumably lets the responder drop its end first).
    _send_with_ack(cfg, b'data close\0')
    s.close()
62
63
def _close_psp_conn(cfg, s):
    """Close a PSP data connection; currently identical to _close_conn()."""
    _close_conn(cfg, s)
66
67
def _spi_xchg(s, rx):
    """Exchange SPI + key with the peer over socket s.

    Sends rx['spi'] packed as a native unsigned int followed by rx['key'],
    then reads a reply of the same size and returns it as a dict with
    'spi' and 'key' entries.
    """
    key_len = len(rx['key'])
    s.send(struct.pack('I', rx['spi']) + rx['key'])
    reply = s.recv(4 + key_len)
    (peer_spi,) = struct.unpack('I', reply[:4])
    return {'spi': peer_spi, 'key': reply[4:]}
75
76
def _send_careful(cfg, s, rounds):
    """Send `rounds` chunks of 2000 bytes on s using non-blocking sends.

    Each chunk gets up to 10 send attempts; an attempt that would block is
    followed by a 50ms sleep. If a chunk is still incomplete after that,
    RuntimeError is raised with the sent byte count, the responder's read
    length and the local TIOCOUTQ value. Returns the total bytes sent.
    """
    payload = b'0123456789' * 200
    total = len(payload)
    for rnd in range(rounds):
        sent = 0
        for _ in range(10):  # allow 10 retries
            try:
                sent += s.send(payload[sent:], socket.MSG_DONTWAIT)
            except BlockingIOError:
                time.sleep(0.05)
                continue
            if sent == total:
                break
        else:
            # Ran out of attempts, gather state from both ends for the report
            rlen = _remote_read_len(cfg)
            outq = _get_outq(s)
            raise RuntimeError(
                f'sent: {rnd * total + sent} remote len: {rlen} outq: {outq}')

    return total * rounds
95
96
def _check_data_rx(cfg, exp_len):
    """Poll the responder's read length until it equals exp_len.

    Polls up to 30 times, 10ms apart, then records the comparison of the
    last value seen against exp_len via ksft_eq().
    """
    seen = -1
    for _ in range(30):
        seen = _remote_read_len(cfg)
        if seen == exp_len:
            break
        time.sleep(0.01)
    ksft_eq(seen, exp_len)
106
107
def _check_data_outq(s, exp_len, force_wait=False):
    """Poll the TIOCOUTQ value of s and check it against exp_len.

    Polls up to 10 times, 10ms apart. Without force_wait the loop stops as
    soon as the value matches; with force_wait all 10 polls are always
    performed, so the check applies to the value after the full wait.
    """
    pending = 0
    for _ in range(10):
        pending = _get_outq(s)
        if pending == exp_len and not force_wait:
            break
        time.sleep(0.01)
    ksft_eq(pending, exp_len)
116
117
def _get_stat(cfg, key):
    """Fetch the stats of the PSP device under test and return entry `key`."""
    stats = cfg.pspnl.get_stats({'dev-id': cfg.psp_dev_id})
    return stats[key]
120
121#
# Test case boilerplate
123#
124
def _init_psp_dev(cfg, use_psp_ifindex=False):
    """Locate the PSP device under test and make sure PSP is enabled on it.

    On first use (cfg has no psp_dev_id yet) the device dump is searched for
    the entry whose ifindex matches cfg.psp_ifindex (use_psp_ifindex=True,
    for NetDrvContEnv) or cfg.ifindex; the match is cached as cfg.psp_info
    and cfg.psp_dev_id. Raises KsftSkipEx when no device matches.

    If the cached capability and enabled version sets differ, all capable
    versions are enabled and restoring the previous set is deferred.
    """
    if not hasattr(cfg, 'psp_dev_id'):
        # Figure out which local device we are testing against
        wanted = cfg.psp_ifindex if use_psp_ifindex else cfg.ifindex
        match = next((dev for dev in cfg.pspnl.dev_get({}, dump=True)
                      if dev['ifindex'] == wanted), None)
        if match is None:
            raise KsftSkipEx("No PSP devices found")
        cfg.psp_info = match
        cfg.psp_dev_id = match['id']

    # Enable PSP if necessary
    cap = cfg.psp_info['psp-versions-cap']
    ena = cfg.psp_info['psp-versions-ena']
    if cap == ena:
        return
    cfg.pspnl.dev_set({'id': cfg.psp_dev_id, 'psp-versions-ena': cap})
    defer(cfg.pspnl.dev_set, {'id': cfg.psp_dev_id, 'psp-versions-ena': ena})
145
146#
147# Test cases
148#
149
def dev_list_devices(cfg):
    """ Dump all devices and check the device under test is listed """
    _init_psp_dev(cfg)

    dump = cfg.pspnl.dev_get({}, dump=True)
    ksft_true(any(dev['id'] == cfg.psp_dev_id for dev in dump))
160
161
def dev_get_device(cfg):
    """ Get the device we intend to use, by ID """
    _init_psp_dev(cfg)

    dev_id = cfg.psp_dev_id
    reply = cfg.pspnl.dev_get({'id': dev_id})
    ksft_eq(reply['id'], dev_id)
168
169
def dev_get_device_bad(cfg):
    """ Test getting device which doesn't exist (expects ENODEV) """
    try:
        cfg.pspnl.dev_get({'id': 1234567})
    except NlError as e:
        ksft_eq(e.nl_msg.error, -errno.ENODEV)
        raised = True
    else:
        raised = False
    ksft_true(raised)
179
180
def dev_rotate(cfg):
    """ Test key rotation: two rotations bump 'key-rotations' by two """
    _init_psp_dev(cfg)

    before = _get_stat(cfg, 'key-rotations')

    for _ in range(2):
        reply = cfg.pspnl.key_rotate({"id": cfg.psp_dev_id})
        ksft_eq(reply['id'], cfg.psp_dev_id)

    after = _get_stat(cfg, 'key-rotations')
    ksft_eq(after, before + 2)
194
195
def dev_rotate_spi(cfg):
    """ Test key rotation and SPI check

    Allocates an Rx association before and after a key rotation and
    checks that the top bit (bit 31) of the returned SPI differs.
    """
    _init_psp_dev(cfg)

    def spi_top_bit():
        # Throwaway socket, only needed to allocate an Rx SPI; the "with"
        # block closes it, no explicit close() required.
        with socket.socket(socket.AF_INET6, socket.SOCK_STREAM) as s:
            assoc = cfg.pspnl.rx_assoc({"version": 0,
                                        "dev-id": cfg.psp_dev_id,
                                        "sock-fd": s.fileno()})
        return assoc['rx-key']['spi'] >> 31

    top_a = spi_top_bit()

    rot = cfg.pspnl.key_rotate({"id": cfg.psp_dev_id})
    ksft_eq(rot['id'], cfg.psp_dev_id)

    top_b = spi_top_bit()
    ksft_ne(top_a, top_b)
216
217
def assoc_basic(cfg):
    """ Test creating associations

    Creates an Rx association on an unconnected socket, checks the reply
    (device ID, non-zero SPI, 16 byte key), then installs the same key as
    the Tx association and expects an empty reply.
    """
    _init_psp_dev(cfg)

    # The "with" block closes the socket, no explicit close() needed
    with socket.socket(socket.AF_INET6, socket.SOCK_STREAM) as s:
        rx = cfg.pspnl.rx_assoc({"version": 0,
                                 "dev-id": cfg.psp_dev_id,
                                 "sock-fd": s.fileno()})
        ksft_eq(rx['dev-id'], cfg.psp_dev_id)
        ksft_gt(rx['rx-key']['spi'], 0)
        ksft_eq(len(rx['rx-key']['key']), 16)

        tx = cfg.pspnl.tx_assoc({"dev-id": cfg.psp_dev_id,
                                 "version": 0,
                                 "tx-key": rx['rx-key'],
                                 "sock-fd": s.fileno()})
        ksft_eq(len(tx), 0)
236
237
def assoc_bad_dev(cfg):
    """ Test creating associations with bad device ID (expects ENODEV) """
    _init_psp_dev(cfg)

    with socket.socket(socket.AF_INET6, socket.SOCK_STREAM) as sk:
        request = {"version": 0,
                   "dev-id": cfg.psp_dev_id + 1234567,
                   "sock-fd": sk.fileno()}
        with ksft_raises(NlError) as cm:
            cfg.pspnl.rx_assoc(request)
        ksft_eq(cm.exception.nl_msg.error, -errno.ENODEV)
248
249
def assoc_sk_only_conn(cfg):
    """ Test creating associations based on socket

    No dev-id is passed; the device is expected to be derived from the
    connected socket.
    """
    _init_psp_dev(cfg)

    with _make_clr_conn(cfg) as s:
        try:
            assoc = cfg.pspnl.rx_assoc({"version": 0,
                                        "sock-fd": s.fileno()})
            ksft_eq(assoc['dev-id'], cfg.psp_dev_id)
            cfg.pspnl.tx_assoc({"version": 0,
                                "tx-key": assoc['rx-key'],
                                "sock-fd": s.fileno()})
        finally:
            # Always notify the responder, even if the test body raised,
            # so the data connection does not linger into the next test.
            _close_conn(cfg, s)
262
263
def assoc_sk_only_mismatch(cfg):
    """ Test creating associations based on socket (dev mismatch)

    Rx association with a dev-id that does not match the connected socket
    must fail with EINVAL and point at the dev-id attribute.
    """
    _init_psp_dev(cfg)

    with _make_clr_conn(cfg) as s:
        try:
            with ksft_raises(NlError) as cm:
                cfg.pspnl.rx_assoc({"version": 0,
                                    "dev-id": cfg.psp_dev_id + 1234567,
                                    "sock-fd": s.fileno()})
            the_exception = cm.exception
            ksft_eq(the_exception.nl_msg.extack['bad-attr'], ".dev-id")
            ksft_eq(the_exception.nl_msg.error, -errno.EINVAL)
        finally:
            # Always notify the responder, even if the checks above raised
            # (e.g. missing extack), so the data connection is torn down.
            _close_conn(cfg, s)
277
278
def assoc_sk_only_mismatch_tx(cfg):
    """ Test creating Tx association based on socket (dev mismatch)

    The Rx association is created without a dev-id and must succeed; only
    the Tx association, which carries a mismatching dev-id, is expected to
    fail with EINVAL pointing at the dev-id attribute.
    """
    _init_psp_dev(cfg)

    with _make_clr_conn(cfg) as s:
        try:
            # Keep the Rx association outside of ksft_raises() so that an
            # unexpected Rx failure is not mistaken for the expected Tx one.
            assoc = cfg.pspnl.rx_assoc({"version": 0,
                                        "sock-fd": s.fileno()})
            with ksft_raises(NlError) as cm:
                cfg.pspnl.tx_assoc({"version": 0,
                                    "tx-key": assoc['rx-key'],
                                    "dev-id": cfg.psp_dev_id + 1234567,
                                    "sock-fd": s.fileno()})
            the_exception = cm.exception
            ksft_eq(the_exception.nl_msg.extack['bad-attr'], ".dev-id")
            ksft_eq(the_exception.nl_msg.error, -errno.EINVAL)
        finally:
            # Always notify the responder, even if the test body raised
            _close_conn(cfg, s)
295
296
def assoc_sk_only_unconn(cfg):
    """ Test socket-only association on an unconnected socket (should fail) """
    _init_psp_dev(cfg)

    with socket.socket(socket.AF_INET6, socket.SOCK_STREAM) as sk:
        request = {"version": 0, "sock-fd": sk.fileno()}
        with ksft_raises(NlError) as cm:
            cfg.pspnl.rx_assoc(request)
        nl_msg = cm.exception.nl_msg
        ksft_eq(nl_msg.extack['miss-type'], "dev-id")
        ksft_eq(nl_msg.error, -errno.EINVAL)
308
309
def assoc_version_mismatch(cfg):
    """ Test that Tx association with a version other than the Rx one fails """
    _init_psp_dev(cfg)

    cap_names = list(cfg.psp_info['psp-versions-cap'])
    if len(cap_names) < 2:
        raise KsftSkipEx("Not enough PSP versions supported by the device for the test")

    # Translate version names to their integer values
    rx_version, *other_versions = [
        cfg.pspnl.consts["version"].entries[name].value for name in cap_names
    ]

    with socket.socket(socket.AF_INET6, socket.SOCK_STREAM) as s:
        rx = cfg.pspnl.rx_assoc({"version": rx_version,
                                 "dev-id": cfg.psp_dev_id,
                                 "sock-fd": s.fileno()})

        # Every other supported version must be rejected for Tx
        for tx_version in other_versions:
            with ksft_raises(NlError) as cm:
                cfg.pspnl.tx_assoc({"dev-id": cfg.psp_dev_id,
                                    "version": tx_version,
                                    "tx-key": rx['rx-key'],
                                    "sock-fd": s.fileno()})
            ksft_eq(cm.exception.nl_msg.error, -errno.EINVAL)
334
335
def assoc_twice(cfg):
    """ Test reusing Tx assoc for two sockets

    The Rx key obtained for the first socket is installed as the Tx key
    on both the first and a second socket; both Tx associations are
    expected to succeed with an empty reply.
    """
    _init_psp_dev(cfg)

    def rx_assoc_check(s):
        # Create an Rx association on s and sanity check the reply
        assoc = cfg.pspnl.rx_assoc({"version": 0,
                                    "dev-id": cfg.psp_dev_id,
                                    "sock-fd": s.fileno()})
        ksft_eq(assoc['dev-id'], cfg.psp_dev_id)
        ksft_gt(assoc['rx-key']['spi'], 0)
        ksft_eq(len(assoc['rx-key']['key']), 16)

        return assoc

    # Both sockets are closed by their "with" blocks, no explicit close()
    with socket.socket(socket.AF_INET6, socket.SOCK_STREAM) as s:
        assoc = rx_assoc_check(s)
        tx = cfg.pspnl.tx_assoc({"dev-id": cfg.psp_dev_id,
                                 "version": 0,
                                 "tx-key": assoc['rx-key'],
                                 "sock-fd": s.fileno()})
        ksft_eq(len(tx), 0)

        # Use the same Tx assoc second time
        with socket.socket(socket.AF_INET6, socket.SOCK_STREAM) as s2:
            rx_assoc_check(s2)
            tx = cfg.pspnl.tx_assoc({"dev-id": cfg.psp_dev_id,
                                     "version": 0,
                                     "tx-key": assoc['rx-key'],
                                     "sock-fd": s2.fileno()})
            ksft_eq(len(tx), 0)
368
369
def _data_basic_send(cfg, version, ipver):
    """ Test basic data send

    For a non-zero version the device does not list in its capabilities,
    checks that Rx association fails with EOPNOTSUPP and then skips.
    Otherwise sets up Rx/Tx associations on a PSP connection, sends
    100 chunks of data and checks the responder received all of it.
    """
    _init_psp_dev(cfg)

    # Version 0 is required by spec, don't let it skip
    if version:
        name = cfg.pspnl.consts["version"].entries_by_val[version].name
        if name not in cfg.psp_info['psp-versions-cap']:
            with socket.socket(socket.AF_INET6, socket.SOCK_STREAM) as s:
                with ksft_raises(NlError) as cm:
                    cfg.pspnl.rx_assoc({"version": version,
                                        "dev-id": cfg.psp_dev_id,
                                        "sock-fd": s.fileno()})
                ksft_eq(cm.exception.nl_msg.error, -errno.EOPNOTSUPP)
            raise KsftSkipEx("PSP version not supported", name)

    s = _make_psp_conn(cfg, version, ipver)
    try:
        rx_assoc = cfg.pspnl.rx_assoc({"version": version,
                                       "dev-id": cfg.psp_dev_id,
                                       "sock-fd": s.fileno()})
        rx = rx_assoc['rx-key']
        tx = _spi_xchg(s, rx)

        cfg.pspnl.tx_assoc({"dev-id": cfg.psp_dev_id,
                            "version": version,
                            "tx-key": tx,
                            "sock-fd": s.fileno()})

        data_len = _send_careful(cfg, s, 100)
        _check_data_rx(cfg, data_len)
    finally:
        # Close the connection on both ends even if the test body raised,
        # matching the other data tests in this file.
        _close_psp_conn(cfg, s)
402
403
def __bad_xfer_do(cfg, s, tx, version='hdr0-aes-gcm-128'):
    """Install Tx key `tx` on s and check that sent data never arrives.

    After the Tx association, 20 chunks are sent; the test expects all of
    it to remain in the local output queue and the responder to have read
    nothing. The connection is always closed before returning.
    """
    try:
        # Make sure we accept the ACK for the SPI before we seal with the bad assoc
        _check_data_outq(s, 0)

        cfg.pspnl.tx_assoc({"dev-id": cfg.psp_dev_id,
                            "version": version,
                            "tx-key": tx,
                            "sock-fd": s.fileno()})

        data_len = _send_careful(cfg, s, 20)
        _check_data_outq(s, data_len, force_wait=True)
        _check_data_rx(cfg, 0)
    finally:
        # Close on both ends even if the association or the send raised
        _close_psp_conn(cfg, s)
416    _close_psp_conn(cfg, s)
417
418
def data_send_bad_key(cfg):
    """ Test send data with bad key (first key byte inverted) """
    _init_psp_dev(cfg)

    s = _make_psp_conn(cfg)

    rx_key = cfg.pspnl.rx_assoc({"version": 0,
                                 "dev-id": cfg.psp_dev_id,
                                 "sock-fd": s.fileno()})['rx-key']
    tx = _spi_xchg(s, rx_key)

    # Corrupt the peer's key by flipping all bits of its first byte
    good_key = tx['key']
    tx['key'] = bytes([good_key[0] ^ 0xff]) + good_key[1:]
    __bad_xfer_do(cfg, s, tx)
432
433
def data_send_disconnect(cfg):
    """ Test socket shutdown + close after sending data over PSP """
    _init_psp_dev(cfg)

    with _make_psp_conn(cfg) as s:
        fd = s.fileno()
        rx_key = cfg.pspnl.rx_assoc({"version": 0, "sock-fd": fd})['rx-key']
        tx_key = _spi_xchg(s, rx_key)
        cfg.pspnl.tx_assoc({"version": 0, "tx-key": tx_key, "sock-fd": fd})

        sent = _send_careful(cfg, s, 100)
        _check_data_rx(cfg, sent)

        # Disconnect locally, without telling the responder via 'data close'
        s.shutdown(socket.SHUT_RDWR)
        s.close()
451
452
def _data_mss_adjust(cfg, ipver):
    """Check that installing a PSP Tx association lowers the socket's MSS.

    The baseline MSS is read from a clear-text connection. On a PSP
    connection the MSS must equal the baseline after the Rx association
    only, and be 40 bytes lower once the Tx association is installed
    (presumably the PSP encapsulation overhead - TODO confirm), both
    before and after sending data.
    """
    _init_psp_dev(cfg)

    # First figure out what the MSS would be without any adjustments
    s = _make_clr_conn(cfg, ipver)
    try:
        s.send(b"0123456789abcdef" * 1024)
        _check_data_rx(cfg, 16 * 1024)
        mss = s.getsockopt(socket.IPPROTO_TCP, socket.TCP_MAXSEG)
    finally:
        # Same guarantee as for the PSP connection below: always close
        _close_conn(cfg, s)

    s = _make_psp_conn(cfg, 0, ipver)
    try:
        rx_assoc = cfg.pspnl.rx_assoc({"version": 0,
                                       "dev-id": cfg.psp_dev_id,
                                       "sock-fd": s.fileno()})
        rx = rx_assoc['rx-key']
        tx = _spi_xchg(s, rx)

        # Rx association alone must not change the MSS
        rxmss = s.getsockopt(socket.IPPROTO_TCP, socket.TCP_MAXSEG)
        ksft_eq(mss, rxmss)

        cfg.pspnl.tx_assoc({"dev-id": cfg.psp_dev_id,
                            "version": 0,
                            "tx-key": tx,
                            "sock-fd": s.fileno()})

        txmss = s.getsockopt(socket.IPPROTO_TCP, socket.TCP_MAXSEG)
        ksft_eq(mss, txmss + 40)

        data_len = _send_careful(cfg, s, 100)
        _check_data_rx(cfg, data_len)
        _check_data_outq(s, 0)

        # The adjustment must still hold after data has been exchanged
        txmss = s.getsockopt(socket.IPPROTO_TCP, socket.TCP_MAXSEG)
        ksft_eq(mss, txmss + 40)
    finally:
        _close_psp_conn(cfg, s)
490
491
def data_stale_key(cfg):
    """ Test send on a double-rotated key

    After two key rotations the 'stale-events' stat must have grown and
    newly sent data is expected to stay in the local output queue.
    """
    _init_psp_dev(cfg)

    stale_before = _get_stat(cfg, 'stale-events')
    s = _make_psp_conn(cfg)
    try:
        rx_key = cfg.pspnl.rx_assoc({"version": 0,
                                     "dev-id": cfg.psp_dev_id,
                                     "sock-fd": s.fileno()})['rx-key']
        tx_key = _spi_xchg(s, rx_key)

        cfg.pspnl.tx_assoc({"dev-id": cfg.psp_dev_id,
                            "version": 0,
                            "tx-key": tx_key,
                            "sock-fd": s.fileno()})

        # Sanity check that the connection works before rotating
        sent = _send_careful(cfg, s, 100)
        _check_data_rx(cfg, sent)
        _check_data_outq(s, 0)

        for _ in range(2):
            cfg.pspnl.key_rotate({"id": cfg.psp_dev_id})

        stale_after = _get_stat(cfg, 'stale-events')
        ksft_gt(stale_after, stale_before)

        s.send(b'0123456789' * 200)
        _check_data_outq(s, 2000, force_wait=True)
    finally:
        _close_psp_conn(cfg, s)
524
525
def __nsim_psp_rereg(cfg):
    """Trigger 'psp_rereg' on the first netdevsim and pick up the new dev ID.

    The PSP dev ID will change, so device IDs are dumped before and after
    the debugfs write; exactly one new ID is expected and it is stored in
    cfg.psp_dev_id.
    """
    def dev_ids():
        return {dev['id'] for dev in cfg.pspnl.dev_get({}, dump=True)}

    before = dev_ids()

    cfg._ns.nsims[0].dfs_write('psp_rereg', '1')

    new_devs = list(dev_ids() - before)
    ksft_eq(len(new_devs), 1)
    cfg.psp_dev_id = new_devs[0]
537
538
def removal_device_rx(cfg):
    """ Test removing a netdev / PSD with active Rx assoc """

    # Real devices could technically be devlink reloaded, too, but that
    # kills the control socket. So this is netdevsim only for now.
    cfg.require_nsim()

    s = _make_clr_conn(cfg)
    try:
        request = {"version": 0,
                   "dev-id": cfg.psp_dev_id,
                   "sock-fd": s.fileno()}
        ksft_not_none(cfg.pspnl.rx_assoc(request))

        # Re-register the PSP device while the Rx association is live
        __nsim_psp_rereg(cfg)
    finally:
        _close_conn(cfg, s)
557
558
def removal_device_bi(cfg):
    """ Test removing a netdev / PSD with active Rx/Tx assoc """

    # Real devices could technically be devlink reloaded, too, but that
    # kills the control socket. So this is netdevsim only for now.
    cfg.require_nsim()

    s = _make_clr_conn(cfg)
    try:
        rx_key = cfg.pspnl.rx_assoc({"version": 0,
                                     "dev-id": cfg.psp_dev_id,
                                     "sock-fd": s.fileno()})['rx-key']
        # Loop the Rx key back as the Tx key, only the association matters
        cfg.pspnl.tx_assoc({"dev-id": cfg.psp_dev_id,
                            "version": 0,
                            "tx-key": rx_key,
                            "sock-fd": s.fileno()})

        # Re-register the PSP device while both associations are live
        __nsim_psp_rereg(cfg)
    finally:
        _close_conn(cfg, s)
579
580
def _get_psp_ver_ip_variants():
    """Yield a named variant for each PSP version 0-3 and IP version 4/6."""
    yield from (KsftNamedVariant(f"v{ver}_ip{ipv}", ver, ipv)
                for ver in range(4)
                for ipv in ("4", "6"))
585
586
def _get_ip_variants():
    """Yield one named variant per IP version ("4" and "6")."""
    yield from (KsftNamedVariant(f"ip{ipv}", ipv) for ipv in ("4", "6"))
590
591
@ksft_variants(_get_psp_ver_ip_variants())
def data_basic_send(cfg, version, ipver):
    """Test basic PSP data send.

    Runs once per (version, ipver) variant produced by
    _get_psp_ver_ip_variants(); the actual test is _data_basic_send().
    """
    # Checks the environment for the requested IP version first
    # (presumably skips the variant when it is not available).
    cfg.require_ipver(ipver)
    _data_basic_send(cfg, version, ipver)
597
598
@ksft_variants(_get_ip_variants())
def data_mss_adjust(cfg, ipver):
    """Test MSS adjustment with PSP.

    Runs once per IP version variant produced by _get_ip_variants();
    the actual test is _data_mss_adjust().
    """
    # Checks the environment for the requested IP version first
    # (presumably skips the variant when it is not available).
    cfg.require_ipver(ipver)
    _data_mss_adjust(cfg, ipver)
604
605
def _check_assoc_list(cfg, psp_dev_id, ifindex, nsid=None):
    """Verify assoc-list contains device with given ifindex, no duplicates.

    Entries are matched on ifindex and, when nsid is given, on nsid too.
    """
    dev_info = cfg.pspnl.dev_get({'id': psp_dev_id})

    ksft_true('assoc-list' in dev_info,
              "No assoc-list in dev_get() response after association")
    seen = False
    for entry in dev_info['assoc-list']:
        if entry['ifindex'] == ifindex and (nsid is None or
                                            entry['nsid'] == nsid):
            # Any match after the first one is a duplicate
            ksft_eq(seen, False, "Duplicate assoc entry found")
            seen = True
    ksft_eq(seen, True,
            "Associated device not found in dev_get() response")
622
623
def _data_basic_send_netkit_psp_assoc(cfg, version, ipver):
    """
    Test basic data send with netkit interface associated with PSP dev.

    The guest netkit device is associated with the PSP device first; the
    rest of the test runs inside the guest namespace with a PSP netlink
    socket created there (stored in cfg.pspnl).
    """
    _assoc_nk_guest(cfg)

    # Enter guest namespace (netns) to run PSP test
    with NetNSEnter(cfg.netns.name):
        cfg.pspnl = PSPFamily()

        sock = _make_psp_conn(cfg, version, ipver)
        try:
            rx_assoc = cfg.pspnl.rx_assoc({"version": version,
                                           "dev-id": cfg.psp_dev_id,
                                           "sock-fd": sock.fileno()})
            rx_key = rx_assoc['rx-key']
            tx_key = _spi_xchg(sock, rx_key)

            cfg.pspnl.tx_assoc({"dev-id": cfg.psp_dev_id,
                                "version": version,
                                "tx-key": tx_key,
                                "sock-fd": sock.fileno()})

            data_len = _send_careful(cfg, sock, 100)
            _check_data_rx(cfg, data_len)
        finally:
            # Close the connection on both ends even if the body raised,
            # matching the other data tests in this file.
            _close_psp_conn(cfg, sock)
650
651
def _assoc_check_list(cfg):
    """Test that assoc-list is correctly populated after dev-assoc.

    Associates the guest netkit device, then checks the PSP device's
    assoc-list for its ifindex / peer nsid pair.
    """
    _assoc_nk_guest(cfg)
    _check_assoc_list(cfg, cfg.psp_dev_id, cfg.nk_guest_ifindex,
                      cfg.psp_dev_peer_nsid)
657
658
def _get_psp_ver_ip6_variants():
    """Yield a named IPv6-only variant for each PSP version 0-3."""
    yield from (KsftNamedVariant(f"v{ver}_ip6", ver, "6") for ver in range(4))
662
663
@ksft_variants(_get_psp_ver_ip6_variants())
def data_basic_send_netkit_psp_assoc(cfg, version, ipver):
    """Test PSP data send via netkit with dev-assoc.

    Runs once per variant produced by _get_psp_ver_ip6_variants(); the
    actual test is _data_basic_send_netkit_psp_assoc().
    """
    # Checks the environment for the requested IP version first
    # (presumably skips the variant when it is not available).
    cfg.require_ipver(ipver)
    _data_basic_send_netkit_psp_assoc(cfg, version, ipver)
669
670
def _key_rotation_notify_multi_ns_netkit(cfg):
    """ Test key rotation notifications across multiple namespaces using netkit

    Subscribes to the 'use' notification group in both the guest and the
    main namespace, rotates the key and expects each listener to see a
    notification carrying the PSP device ID within 10 seconds.
    """
    _assoc_nk_guest(cfg)

    # Listener created inside the guest namespace; the socket stays bound
    # to that namespace after we leave it
    with NetNSEnter(cfg.netns.name):
        guest_listener = PSPFamily()
        guest_listener.ntf_subscribe('use')

    # Listener in the main namespace
    host_listener = PSPFamily()
    host_listener.ntf_subscribe('use')

    # Trigger key rotation on the PSP device
    cfg.pspnl.key_rotate({"id": cfg.psp_dev_id})

    # Poll both sockets from main thread, main namespace first
    for label, listener in (("main", host_listener), ("guest", guest_listener)):
        notified = any(ntf['msg'].get('id') == cfg.psp_dev_id
                       for ntf in listener.poll_ntf(duration=10))
        if not notified:
            raise KsftFailEx(
                f"No key rotation notification received in {label} namespace")
696
697
def _dev_change_notify_multi_ns_netkit(cfg):
    """ Test dev_change notifications across multiple namespaces using netkit

    Subscribes to the 'mgmt' notification group in both the guest and the
    main namespace, issues a dev_set and expects each listener to see a
    notification carrying the PSP device ID within 10 seconds.
    """
    _assoc_nk_guest(cfg)

    # Listener created inside the guest namespace; the socket stays bound
    # to that namespace after we leave it
    with NetNSEnter(cfg.netns.name):
        guest_listener = PSPFamily()
        guest_listener.ntf_subscribe('mgmt')

    # Listener in the main namespace
    host_listener = PSPFamily()
    host_listener.ntf_subscribe('mgmt')

    # Trigger dev_change by calling dev_set (notification is always sent)
    enable_all = {'id': cfg.psp_dev_id,
                  'psp-versions-ena': cfg.psp_info['psp-versions-cap']}
    cfg.pspnl.dev_set(enable_all)

    # Poll both sockets from main thread, main namespace first
    for label, listener in (("main", host_listener), ("guest", guest_listener)):
        notified = any(ntf['msg'].get('id') == cfg.psp_dev_id
                       for ntf in listener.poll_ntf(duration=10))
        if not notified:
            raise KsftFailEx(
                f"No dev_change notification received in {label} namespace")
724
725
def _psp_dev_get_check_netkit_psp_assoc(cfg):
    """ Check psp dev-get output with netkit interface associated with PSP dev """
    _assoc_nk_guest(cfg)

    # Check 1: from the default netns, dev-get must report the PSP ifindex
    # and list the guest device in assoc-list
    host_view = cfg.pspnl.dev_get({'id': cfg.psp_dev_id})
    ksft_eq(host_view['ifindex'], cfg.psp_ifindex)
    _check_assoc_list(cfg, cfg.psp_dev_id, cfg.nk_guest_ifindex,
                      cfg.psp_dev_peer_nsid)

    # Check 2: from the guest netns, the device must be visible with the
    # by-association flag and list nk_guest in its assoc-list
    with NetNSEnter(cfg.netns.name):
        guest_nl = PSPFamily()

        # First device in the guest's dump which has by-association set
        peer_dev = next((dev for dev in guest_nl.dev_get({}, dump=True)
                         if dev.get('by-association')), None)

        ksft_not_none(peer_dev, "No PSP device found with by-association flag in guest netns")

        has_entries = 'assoc-list' in peer_dev and len(peer_dev['assoc-list']) > 0
        ksft_true(has_entries,
                  "Guest device should have assoc-list with local devices")

        # The nk_guest entry must carry nsid=-1 (same namespace)
        guest_entry = next((entry for entry in peer_dev['assoc-list']
                            if entry['ifindex'] == cfg.nk_guest_ifindex), None)
        found = guest_entry is not None
        if found:
            ksft_eq(guest_entry['nsid'], -1,
                    "nsid should be -1 (NETNSA_NSID_NOT_ASSIGNED) for same-namespace device")
        ksft_true(found, "nk_guest ifindex not found in assoc-list")
765
766
def _dev_assoc_no_nsid(cfg):
    """ Test dev-assoc and dev-disassoc without nsid attribute """
    _init_psp_dev(cfg, True)

    dev_id = cfg.psp_dev_id
    host_ifindex = cfg.nk_host_ifindex

    # Associate without nsid - should look up ifindex in caller's netns
    cfg.pspnl.dev_assoc({'id': dev_id, 'ifindex': host_ifindex})
    defer(_try_disassoc, cfg, dev_id, host_ifindex)
    # Drop the cached PSP device state from cfg on cleanup
    for attr in ('psp_dev_id', 'psp_info'):
        defer(delattr, cfg, attr)

    # Verify assoc-list contains the device (match by ifindex only)
    _check_assoc_list(cfg, dev_id, host_ifindex)

    # Disassociate without nsid - should also use caller's netns
    cfg.pspnl.dev_disassoc({'id': dev_id, 'ifindex': host_ifindex})

    # Verify assoc-list no longer contains the device
    dev_info = cfg.pspnl.dev_get({'id': dev_id})
    entries = dev_info['assoc-list'] if 'assoc-list' in dev_info else []
    still_listed = any(entry['ifindex'] == host_ifindex for entry in entries)
    ksft_true(not still_listed, "Device should not be in assoc-list after disassociation")
795
796
def _psp_dev_assoc_cleanup_on_netkit_del(cfg):
    """Test that assoc-list is cleared when associated netkit is deleted.

    Creates a disposable netkit pair for this test to avoid destroying
    the shared environment. The guest end is moved into the test
    namespace and associated with the PSP device; the pair is then
    deleted and the PSP device's assoc-list is expected to be empty.
    """
    _init_psp_dev(cfg, True)
    # Drop the cached PSP device state from cfg when the test is done
    defer(delattr, cfg, 'psp_dev_id')
    defer(delattr, cfg, 'psp_info')

    # ifindexes of the shared netkit pair, used below to tell it apart
    # from the temporary one
    existing = {cfg.nk_host_ifindex, cfg.nk_guest_ifindex}

    # Create a temporary netkit pair
    tmp_host_name = "tmp_nk_host"
    tmp_guest_name = "tmp_nk_guest"
    rtnl = RtnlFamily()
    rtnl.newlink(
        {
            "ifname": tmp_host_name,
            "linkinfo": {
                "kind": "netkit",
                "data": {
                    "mode": "l2",
                    "policy": "forward",
                    "peer-policy": "forward",
                },
            },
        },
        flags=[Netlink.NLM_F_CREATE, Netlink.NLM_F_EXCL],
    )
    # Safety net in case the test bails out before the explicit delete
    cleanup_netkit = defer(ip, f"link del {tmp_host_name}")

    # Find the peer by diffing against existing netkit ifindexes
    all_links = ip("-d link show", json=True)
    tmp_peer = [link for link in all_links
                if link.get('linkinfo', {}).get('info_kind') == 'netkit'
                and link['ifindex'] not in existing
                and link['ifname'] != tmp_host_name]
    ksft_eq(len(tmp_peer), 1,
            "Failed to find temporary netkit peer")
    guest_name = tmp_peer[0]['ifname']

    # Rename and move guest end into the test namespace
    ip(f"link set dev {guest_name} name {tmp_guest_name}")
    ip(f"link set dev {tmp_guest_name} netns {cfg.netns.name}")
    # Look the device up again inside the namespace to get its ifindex there
    tmp_guest_dev = ip(f"link show dev {tmp_guest_name}",
                       json=True, ns=cfg.netns)[0]
    tmp_guest_ifindex = tmp_guest_dev['ifindex']
    ip(f"link set dev {tmp_guest_name} up", ns=cfg.netns)

    # Associate PSP device with the temporary guest interface
    cfg.pspnl.dev_assoc({'id': cfg.psp_dev_id,
                         'ifindex': tmp_guest_ifindex,
                         'nsid': cfg.psp_dev_peer_nsid})

    # Verify assoc-list contains the temporary device
    _check_assoc_list(cfg, cfg.psp_dev_id, tmp_guest_ifindex,
                      cfg.psp_dev_peer_nsid)

    # Delete the temporary netkit pair (deleting one end removes both)
    ip(f"link del {tmp_host_name}")
    # The pair is gone, the deferred delete must not run anymore
    cleanup_netkit.cancel()

    # Verify assoc-list is cleared after netkit deletion
    dev_info = cfg.pspnl.dev_get({'id': cfg.psp_dev_id})
    ksft_true('assoc-list' not in dev_info
              or len(dev_info['assoc-list']) == 0,
              "assoc-list should be empty after netkit deletion")
865
866
def _try_disassoc(cfg, psp_dev_id, ifindex, nsid=None):
    """Best-effort disassociate, ignoring errors if already removed.

    The 'nsid' attribute is only included in the request when nsid is
    not None. Any NlError from the request is swallowed.
    """
    request = {'id': psp_dev_id, 'ifindex': ifindex}
    if nsid is not None:
        request['nsid'] = nsid
    try:
        cfg.pspnl.dev_disassoc(request)
    except NlError:
        # Deliberately ignored, this is cleanup
        pass
876
877
def _assoc_nk_guest(cfg):
    """Bind the nk_guest interface to the PSP device and queue its undo.

    Runs _init_psp_dev(cfg, True) first, then issues dev_assoc for
    cfg.nk_guest_ifindex with the nsid stored in cfg.psp_dev_peer_nsid.
    The matching _disassoc_nk_guest call is registered through defer().
    """
    _init_psp_dev(cfg, True)

    dev_id = cfg.psp_dev_id
    guest_ifindex = cfg.nk_guest_ifindex
    assoc_req = {
        'id': dev_id,
        'ifindex': guest_ifindex,
        'nsid': cfg.psp_dev_peer_nsid,
    }
    cfg.pspnl.dev_assoc(assoc_req)

    # Register the teardown with the same id/ifindex pair used above.
    defer(_disassoc_nk_guest, cfg, dev_id, guest_ifindex)
887
888
def _disassoc_nk_guest(cfg, psp_dev_id, nk_guest_ifindex):
    """Undo the nk_guest association and drop cached PSP state from cfg.

    The request is sent through a newly created PSPFamily, which then
    replaces cfg.pspnl. Afterwards cfg.psp_dev_id and cfg.psp_info are
    deleted from cfg.
    """
    fresh_nl = PSPFamily()
    fresh_nl.dev_disassoc({
        'id': psp_dev_id,
        'ifindex': nk_guest_ifindex,
        'nsid': cfg.psp_dev_peer_nsid,
    })
    cfg.pspnl = fresh_nl

    # NOTE(review): presumably removed so a later _init_psp_dev() starts
    # from scratch -- confirm against that helper.
    for stale_attr in ('psp_dev_id', 'psp_info'):
        delattr(cfg, stale_attr)
897
898
def _get_nsid(ns_name):
    """Look up the nsid that `ip netns list-id` reports for a namespace.

    Returns the 'nsid' of the first entry whose 'name' equals str(ns_name).
    Raises KsftSkipEx when no entry matches.
    """
    entries = ip("netns list-id", json=True)
    wanted = str(ns_name)
    hit = next((ent for ent in entries if ent.get("name") == wanted), None)
    if hit is None:
        raise KsftSkipEx(f"nsid not found for namespace {ns_name}")
    return hit["nsid"]
905
906
def _setup_psp_attributes(cfg):
    # pylint: disable=protected-access
    """
    Populate cfg with the PSP device, the PSP peer device and the peer nsid.

    When cfg._ns is set (netdevsim) the local and peer devices are taken
    from the first nsim of cfg._ns and cfg._ns_peer. Otherwise (real NIC)
    the environment's own dev/ifname/ifindex and remote_* values are reused.
    """
    if cfg._ns is None:
        # Real NIC case: PSP device is the local interface and the PSP peer
        # device is the remote interface
        local = (cfg.dev, cfg.ifname, cfg.ifindex)
        peer = (cfg.remote_dev, cfg.remote_ifname, cfg.remote_ifindex)
    else:
        # netdevsim case: PSP device is the local dev (in host namespace),
        # PSP peer device is the remote dev (in _netns, where psp_responder
        # runs)
        local_dev = cfg._ns.nsims[0].dev
        peer_dev = cfg._ns_peer.nsims[0].dev
        local = (local_dev, local_dev['ifname'], local_dev['ifindex'])
        peer = (peer_dev, peer_dev['ifname'], peer_dev['ifindex'])

    cfg.psp_dev, cfg.psp_ifname, cfg.psp_ifindex = local
    cfg.psp_dev_peer, cfg.psp_dev_peer_ifname, cfg.psp_dev_peer_ifindex = peer

    # Get nsid for the guest namespace (netns) where nk_guest is
    cfg.psp_dev_peer_nsid = _get_nsid(cfg.netns.name)
938
939
940
def main() -> None:
    """ Ksft boiler plate main """

    # Provide a default LOCAL_PREFIX_V6 when the environment has none
    os.environ.setdefault("LOCAL_PREFIX_V6", "2001:db8:2::")

    # Try NetDrvContEnv first; fall back to NetDrvEpEnv when it skips
    try:
        env = NetDrvContEnv(__file__, primary_rx_redirect=True)
    except KsftSkipEx:
        env = NetDrvEpEnv(__file__)
        has_cont = False
    else:
        has_cont = True

    with env as cfg:
        cfg.pspnl = PSPFamily()
        if has_cont:
            _setup_psp_attributes(cfg)

        # Set up responder and communication sock
        # psp_responder runs in _netns (remote namespace with psp_dev_peer)
        responder = cfg.remote.deploy("psp_responder")
        cfg.comm_port = rand_port()

        srv = None
        try:
            responder_cmd = (responder +
                             f" -p {cfg.comm_port} -i {cfg.remote_ifindex}")
            with bkg(responder_cmd, host=cfg.remote, exit_wait=True) as srv:
                wait_port_listen(cfg.comm_port, host=cfg.remote)

                comm_addr = (cfg.remote_addr, cfg.comm_port)
                cfg.comm_sock = socket.create_connection(comm_addr, timeout=1)

                cases = [data_basic_send, data_mss_adjust]
                if has_cont:
                    # These cases are only added for the NetDrvContEnv setup
                    cases.extend([
                        _assoc_check_list,
                        data_basic_send_netkit_psp_assoc,
                        _key_rotation_notify_multi_ns_netkit,
                        _dev_change_notify_multi_ns_netkit,
                        _psp_dev_get_check_netkit_psp_assoc,
                        _dev_assoc_no_nsid,
                        _psp_dev_assoc_cleanup_on_netkit_del,
                    ])

                ksft_run(cases=cases, globs=globals(),
                         case_pfx={"dev_", "data_", "assoc_", "removal_"},
                         args=(cfg, ))

                # Send the responder its exit command, then drop the socket
                cfg.comm_sock.send(b"exit\0")
                cfg.comm_sock.close()
        finally:
            # Dump whatever the responder printed, if it was started at all
            if srv:
                out, err = srv.stdout, srv.stderr
                if out or err:
                    ksft_pr("")
                    ksft_pr(f"Responder logs ({srv.ret}):")
                if out:
                    ksft_pr("STDOUT:\n#  " + out.strip().replace("\n", "\n#  "))
                if err:
                    ksft_pr("STDERR:\n#  " + err.strip().replace("\n", "\n#  "))
    ksft_exit()
1004
1005
# Script entry point: run the test suite only when executed directly.
if __name__ == "__main__":
    main()
1008