xref: /linux/tools/testing/selftests/drivers/net/psp.py (revision 8f90dc6e417a64144d06405ee0404965139c825a)
1#!/usr/bin/env python3
2# SPDX-License-Identifier: GPL-2.0
3
4"""Test suite for PSP capable drivers."""
5
6import errno
7import fcntl
8import socket
9import struct
10import termios
11import time
12
13from lib.py import defer
14from lib.py import ksft_run, ksft_exit, ksft_pr
15from lib.py import ksft_true, ksft_eq, ksft_ne, ksft_raises
16from lib.py import KsftSkipEx
17from lib.py import NetDrvEpEnv, PSPFamily, NlError
18from lib.py import bkg, rand_port, wait_port_listen
19
20
def _get_outq(s):
    """Return the value reported by the TIOCOUTQ ioctl for socket `s`.

    The ioctl is issued on the socket's file descriptor with a zeroed
    4-byte buffer, and the returned bytes are decoded as one native
    unsigned int.
    """
    zeroed = bytes(4)
    raw = fcntl.ioctl(s.fileno(), termios.TIOCOUTQ, zeroed)
    (queued,) = struct.unpack("I", raw)
    return queued
25
26
def _send_with_ack(cfg, msg):
    """Send `msg` on the control socket and require an 'ack' reply.

    Reads up to 4 bytes from `cfg.comm_sock` after sending and raises
    RuntimeError (carrying the raw reply) unless they are exactly the
    NUL-terminated 'ack' token.
    """
    ctrl = cfg.comm_sock
    ctrl.send(msg)
    reply = ctrl.recv(4)
    if reply == b'ack\0':
        return
    raise RuntimeError("Unexpected server response", reply)
32
33
def _remote_read_len(cfg):
    """Send a 'read len' request on the control socket and parse the reply.

    The reply (up to 1024 bytes) has its final byte stripped and the rest
    is decoded as UTF-8 and converted to an int.
    """
    ctrl = cfg.comm_sock
    ctrl.send(b'read len\0')
    reply = ctrl.recv(1024)
    digits = reply[:-1]  # drop the last byte (terminator) before parsing
    return int(digits.decode('utf-8'))
37
38
def _make_psp_conn(cfg, version=0, ipver=None):
    """Request a PSP connection from the responder and open the data socket.

    A 'conn psp' command followed by two bytes, both set to `version`, is
    sent over the control channel and must be acked. The data connection
    is then opened to `cfg.comm_port` on the remote address for `ipver`,
    or on `cfg.remote_addr` when `ipver` is not given.

    Returns the connected socket.
    """
    request = b'conn psp\0' + struct.pack('BB', version, version)
    _send_with_ack(cfg, request)

    if ipver:
        remote_addr = cfg.remote_addr_v[ipver]
    else:
        remote_addr = cfg.remote_addr
    return socket.create_connection((remote_addr, cfg.comm_port))
44
45
def _close_conn(cfg, s):
    """Send 'data close' on the control channel, then close local socket `s`.

    The local socket is only closed once the command has been acked.
    """
    close_cmd = b'data close\0'
    _send_with_ack(cfg, close_cmd)
    s.close()
49
50
def _close_psp_conn(cfg, s):
    """Close a PSP data connection; currently the same as `_close_conn()`."""
    _close_conn(cfg, s)
53
54
def _spi_xchg(s, rx):
    """Swap SPI and key material with the peer over socket `s`.

    Sends `rx['spi']` packed as a native unsigned int followed by
    `rx['key']`, then reads a reply of the same size and returns it as a
    dict with the same 'spi' / 'key' layout.
    """
    spi_bytes = struct.pack('I', rx['spi'])
    s.send(spi_bytes + rx['key'])

    reply = s.recv(4 + len(rx['key']))
    (peer_spi,) = struct.unpack('I', reply[:4])
    return {'spi': peer_spi, 'key': reply[4:]}
62
63
def _send_careful(cfg, s, rounds):
    """Send `rounds` fixed 2000-byte payloads on `s` using non-blocking sends.

    Each payload gets at most 10 send attempts. An attempt that hits
    BlockingIOError is followed by a 50 ms sleep, and a partial send is
    resumed from where it stopped. If a payload is still incomplete after
    all attempts, RuntimeError is raised with the bytes sent so far, the
    length reported by the remote, and the local output-queue value.

    Returns the total number of bytes sent.
    """
    payload = b'0123456789' * 200
    payload_len = len(payload)

    for round_idx in range(rounds):
        sent = 0
        attempts_left = 10  # allow 10 retries
        while attempts_left:
            attempts_left -= 1
            try:
                sent += s.send(payload[sent:], socket.MSG_DONTWAIT)
            except BlockingIOError:
                time.sleep(0.05)
                continue
            if sent == payload_len:
                break
        else:
            # Ran out of attempts without completing this payload
            rlen = _remote_read_len(cfg)
            outq = _get_outq(s)
            report = f'sent: {round_idx * payload_len + sent} remote len: {rlen} outq: {outq}'
            raise RuntimeError(report)

    return payload_len * rounds
82
83
def _check_data_rx(cfg, exp_len):
    """Poll the remote read length until it equals `exp_len`, then check it.

    The remote is queried up to 30 times with a 10 ms pause after each
    mismatch. The last observed value is compared against `exp_len` with
    ksft_eq(), so a mismatch is reported as a test failure.
    """
    observed = -1
    for _attempt in range(30):
        observed = _remote_read_len(cfg)
        if observed != exp_len:
            time.sleep(0.01)
            continue
        break
    ksft_eq(observed, exp_len)
93
94#
95# Test case boiler plate
96#
97
def _init_psp_dev(cfg):
    """Locate the PSP device under test and enable all its PSP versions.

    On the first call the PSP device list is dumped and the entry whose
    ifindex matches `cfg.ifindex` is cached as `cfg.psp_info` and
    `cfg.psp_dev_id`. KsftSkipEx is raised if there is no such entry.

    On every call, if the cached enabled-versions value differs from the
    capability value, all capable versions are enabled and a deferred
    call is registered to restore the cached enabled value.
    """
    if not hasattr(cfg, 'psp_dev_id'):
        # Figure out which local device we are testing against
        match = next((dev for dev in cfg.pspnl.dev_get({}, dump=True)
                      if dev['ifindex'] == cfg.ifindex), None)
        if match is None:
            raise KsftSkipEx("No PSP devices found")
        cfg.psp_info = match
        cfg.psp_dev_id = cfg.psp_info['id']

    # Enable PSP if necessary
    supported = cfg.psp_info['psp-versions-cap']
    enabled = cfg.psp_info['psp-versions-ena']
    if supported == enabled:
        return

    cfg.pspnl.dev_set({'id': cfg.psp_dev_id, 'psp-versions-ena': supported})
    defer(cfg.pspnl.dev_set, {'id': cfg.psp_dev_id,
                              'psp-versions-ena': enabled})
116
117#
118# Test cases
119#
120
def dev_list_devices(cfg):
    """ Dump all devices and check the one under test is listed """
    _init_psp_dev(cfg)

    listed = any(dev['id'] == cfg.psp_dev_id
                 for dev in cfg.pspnl.dev_get({}, dump=True))
    ksft_true(listed)
131
132
def dev_get_device(cfg):
    """ Get the device we intend to use and check the returned id """
    _init_psp_dev(cfg)

    wanted_id = cfg.psp_dev_id
    reply = cfg.pspnl.dev_get({'id': wanted_id})
    ksft_eq(reply['id'], cfg.psp_dev_id)
139
140
def dev_get_device_bad(cfg):
    """ Test getting device which doesn't exist (expects ENODEV) """
    try:
        cfg.pspnl.dev_get({'id': 1234567})
    except NlError as err:
        ksft_eq(err.nl_msg.error, -errno.ENODEV)
        lookup_failed = True
    else:
        # The lookup unexpectedly succeeded
        lookup_failed = False
    ksft_true(lookup_failed)
150
151
def dev_rotate(cfg):
    """ Test key rotation: rotate twice and check the reported device id """
    _init_psp_dev(cfg)

    for _ in range(2):
        reply = cfg.pspnl.key_rotate({"id": cfg.psp_dev_id})
        ksft_eq(reply['id'], cfg.psp_dev_id)
160
161
def dev_rotate_spi(cfg):
    """ Test key rotation and SPI check

    Allocates an Rx association before and after a key rotation and
    checks that the top bit (bit 31) of the returned SPI differs between
    the two.
    """
    _init_psp_dev(cfg)

    def _rx_spi_top_bit():
        # A fresh, unconnected socket is used just to request the Rx
        # association. The with-block closes it, so no explicit close()
        # is needed.
        with socket.socket(socket.AF_INET6, socket.SOCK_STREAM) as s:
            assoc = cfg.pspnl.rx_assoc({"version": 0,
                                        "dev-id": cfg.psp_dev_id,
                                        "sock-fd": s.fileno()})
        return assoc['rx-key']['spi'] >> 31

    top_a = _rx_spi_top_bit()

    rot = cfg.pspnl.key_rotate({"id": cfg.psp_dev_id})
    ksft_eq(rot['id'], cfg.psp_dev_id)

    top_b = _rx_spi_top_bit()
    ksft_ne(top_a, top_b)
182
183
def _data_basic_send(cfg, version, ipver):
    """ Test basic data send

    Sets up a PSP connection to the responder for the given PSP `version`
    and IP version `ipver`, installs Rx and Tx associations on the data
    socket, sends 100 payloads and checks that the remote read them all.

    For a non-zero `version` the device does not advertise, the test
    checks that rx_assoc fails with EOPNOTSUPP and then raises KsftSkipEx.
    """
    _init_psp_dev(cfg)

    # Version 0 is required by spec, don't let it skip
    if version:
        name = cfg.pspnl.consts["version"].entries_by_val[version].name
        if name not in cfg.psp_info['psp-versions-cap']:
            with socket.socket(socket.AF_INET6, socket.SOCK_STREAM) as s:
                with ksft_raises(NlError) as cm:
                    cfg.pspnl.rx_assoc({"version": version,
                                        "dev-id": cfg.psp_dev_id,
                                        "sock-fd": s.fileno()})
                ksft_eq(cm.exception.nl_msg.error, -errno.EOPNOTSUPP)
            raise KsftSkipEx("PSP version not supported", name)

    s = _make_psp_conn(cfg, version, ipver)
    try:
        rx_assoc = cfg.pspnl.rx_assoc({"version": version,
                                       "dev-id": cfg.psp_dev_id,
                                       "sock-fd": s.fileno()})
        rx = rx_assoc['rx-key']
        # Hand our Rx SPI/key to the peer and receive the one to use for Tx
        tx = _spi_xchg(s, rx)

        cfg.pspnl.tx_assoc({"dev-id": cfg.psp_dev_id,
                            "version": version,
                            "tx-key": tx,
                            "sock-fd": s.fileno()})

        data_len = _send_careful(cfg, s, 100)
        _check_data_rx(cfg, data_len)
    finally:
        # Always tear the connection down, even if a step above raised,
        # so a failure does not leave the data connection open.
        _close_psp_conn(cfg, s)
216
217
def psp_ip_ver_test_builder(name, test_func, psp_ver, ipver):
    """Build test cases for each combo of PSP version and IP version

    Returns a one-argument test case which first requires `ipver` on the
    given cfg and then calls `test_func(cfg, psp_ver, ipver)`. The case is
    named "<name>_v<psp_ver>_ip<ipver>".
    """
    def test_case(cfg):
        cfg.require_ipver(ipver)
        test_func(cfg, psp_ver, ipver)

    # Name the case when it is built rather than when it runs. Otherwise
    # it is called "test_case" until it has executed, and keeps that name
    # if require_ipver() raises before the name would have been set.
    test_case.__name__ = f"{name}_v{psp_ver}_ip{ipver}"
    return test_case
225
226
def main() -> None:
    """ Ksft entry point: start the responder, run all cases, dump its logs """

    with NetDrvEpEnv(__file__) as cfg:
        cfg.pspnl = PSPFamily()

        # Set up responder and communication sock
        responder = cfg.remote.deploy("psp_responder")

        cfg.comm_port = rand_port()
        srv = None
        try:
            responder_cmd = responder + f" -p {cfg.comm_port}"
            with bkg(responder_cmd, host=cfg.remote, exit_wait=True) as srv:
                wait_port_listen(cfg.comm_port, host=cfg.remote)

                comm_addr = (cfg.remote_addr, cfg.comm_port)
                cfg.comm_sock = socket.create_connection(comm_addr, timeout=1)

                # One data test per (PSP version, IP version) combination
                cases = []
                for version in range(0, 4):
                    for ipver in ("4", "6"):
                        cases.append(psp_ip_ver_test_builder(
                            "data_basic_send", _data_basic_send, version, ipver
                        ))

                ksft_run(cases=cases, globs=globals(), case_pfx={"dev_",}, args=(cfg, ))

                # Tell the responder to quit before leaving the bkg context
                cfg.comm_sock.send(b"exit\0")
                cfg.comm_sock.close()
        finally:
            # Print whatever the responder logged, if it was started at all
            if srv:
                out, err = srv.stdout, srv.stderr
                if out or err:
                    ksft_pr("")
                    ksft_pr(f"Responder logs ({srv.ret}):")
                if out:
                    ksft_pr("STDOUT:\n#  " + out.strip().replace("\n", "\n#  "))
                if err:
                    ksft_pr("STDERR:\n#  " + err.strip().replace("\n", "\n#  "))
    ksft_exit()
268
269
# Run the test suite only when this file is executed as a script.
if __name__ == "__main__":
    main()
272