xref: /linux/tools/usb/p9_fwd.py (revision 3ba84ac69b53e6ee07c31d54554e00793d7b144f)
1#!/usr/bin/env python3
2# SPDX-License-Identifier: GPL-2.0
3
4import argparse
5import errno
6import logging
7import socket
8import struct
9import time
10
11import usb.core
12import usb.util
13
14
def path_from_usb_dev(dev):
    """Return the position of a pyUSB device in the USB bus tree as a string.

    The path is the number of the USB bus (controller) followed by a dash and
    the dot-separated chain of hub port numbers leading to the device, e.g.
    ``"1-2.3"``. It is used to find one specific device on the bus, or all
    devices connected to a hub (it is compared against ``--path``).

    Returns an empty string when ``dev.port_numbers`` is empty or None.
    """
    ports = dev.port_numbers
    if not ports:
        return ""
    return f"{dev.bus}-{'.'.join(str(port) for port in ports)}"
25
26
# Lookup table mapping each byte value (0-255) to the character shown in the
# ASCII column of a hexdump: printable 7-bit characters map to themselves,
# everything else (control characters and all bytes >= 0x80) maps to ".".
HEXDUMP_FILTER = "".join(ch if ch.isprintable() else "." for ch in map(chr, range(128))) + "." * 128
28
29
class Forwarder:
    """Relay 9P messages between a usb9pfs USB gadget and a TCP 9P server.

    Every 9P message starts with a little-endian 32-bit size field that
    counts the whole message, including the field itself; both directions
    use it to find message boundaries. ``c2s`` and ``s2c`` each forward
    exactly one message and update the counters in ``self.stats``.
    """

    @staticmethod
    def _log_hexdump(data):
        """Log *data* as a 16-bytes-per-line hexdump at TRACE level.

        Returns immediately unless the root logger has TRACE enabled, so the
        formatting work is skipped in normal operation.
        """
        if not logging.root.isEnabledFor(logging.TRACE):
            return
        width = 16
        for offset in range(0, len(data), width):
            chunk = data[offset : offset + width]
            dump = " ".join(f"{x:02x}" for x in chunk)
            printable = "".join(HEXDUMP_FILTER[x] for x in chunk)
            line = f"{offset:08x}  {dump:{width * 3}s} |{printable:{width}s}|"
            logging.root.log(logging.TRACE, "%s", line)

    @staticmethod
    def _recv_exact(sock, n):
        """Read exactly *n* bytes from *sock* and return them as ``bytes``.

        ``socket.recv`` may return fewer bytes than requested, and returns an
        empty result once the peer has closed the connection; looping on it
        without checking for that would spin forever.

        Raises:
            ConnectionError: if the peer closes before *n* bytes arrived.
        """
        buf = bytearray()
        while len(buf) < n:
            chunk = sock.recv(n - len(buf))
            if not chunk:
                raise ConnectionError(f"connection closed after {len(buf)} of {n} bytes")
            buf += chunk
        return bytes(buf)

    def __init__(self, server, vid, pid, path):
        """Find the gadget, claim its 9P interface and connect to the server.

        Args:
            server: ``(host, port)`` address of the TCP 9P server.
            vid: USB vendor ID of the gadget.
            pid: USB product ID of the gadget.
            path: bus-tree path as returned by ``path_from_usb_dev`` to select
                one specific device, or None to use the first device with
                matching IDs.

        Raises:
            ValueError: if no matching device, 9P interface or endpoint is
                found.
            OSError: if the TCP connection cannot be established.
        """
        self.stats = {
            "c2s packets": 0,
            "c2s bytes": 0,
            "s2c packets": 0,
            "s2c bytes": 0,
        }
        self.stats_logged = time.monotonic()

        def find_filter(dev):
            # Without an explicit path, any device with matching IDs is accepted.
            return path is None or path_from_usb_dev(dev) == path

        dev = usb.core.find(idVendor=vid, idProduct=pid, custom_match=find_filter)
        if dev is None:
            raise ValueError("Device not found")

        logging.info(f"found device: {dev.bus}/{dev.address} located at {path_from_usb_dev(dev)}")

        # dev.set_configuration() is not necessary since g_multi has only one
        # configuration; g_multi adds the 9pfs function as its last interface.
        usb9pfs = None
        cfg = dev.get_active_configuration()
        for intf in cfg:
            # We have to detach the usb-storage driver (interface class 0x08)
            # from the multi gadget since its stall option could be set, which
            # leads to spontaneous port resets and our transfers would run dead.
            if intf.bInterfaceClass == 0x08:
                if dev.is_kernel_driver_active(intf.bInterfaceNumber):
                    dev.detach_kernel_driver(intf.bInterfaceNumber)

            # The 9P function is identified by class/subclass 0xFF, protocol 0x09.
            if intf.bInterfaceClass == 0xFF and intf.bInterfaceSubClass == 0xFF and intf.bInterfaceProtocol == 0x09:
                usb9pfs = intf
        if usb9pfs is None:
            raise ValueError("Interface not found")

        logging.info(f"claiming interface:\n{usb9pfs}")
        usb.util.claim_interface(dev, usb9pfs.bInterfaceNumber)
        ep_out = usb.util.find_descriptor(
            usb9pfs,
            custom_match=lambda e: usb.util.endpoint_direction(e.bEndpointAddress) == usb.util.ENDPOINT_OUT,
        )
        # Explicit checks instead of assert: asserts are stripped under -O.
        if ep_out is None:
            raise ValueError("OUT endpoint not found")
        ep_in = usb.util.find_descriptor(
            usb9pfs,
            custom_match=lambda e: usb.util.endpoint_direction(e.bEndpointAddress) == usb.util.ENDPOINT_IN,
        )
        if ep_in is None:
            raise ValueError("IN endpoint not found")
        logging.info("interface claimed")

        self.ep_out = ep_out
        self.ep_in = ep_in
        self.dev = dev

        # create_connection() resolves the host via getaddrinfo (so IPv6
        # addresses work as well as IPv4) and closes the socket if connecting
        # fails.
        self.s = socket.create_connection(server)

        logging.info("connected to server")

    def c2s(self):
        """Forward one request from the USB client to the TCP server.

        Blocks until a request arrives: read timeouts are retried forever and
        EIO errors are retried after a 0.5 s pause. Any other USB error is
        logged and re-raised.
        """
        data = None
        # Loop until a non-empty packet arrives. An empty read carries no
        # message data (presumably a zero-length packet terminating the
        # previous message, mirroring what s2c sends), so it must not be
        # parsed as the start of a message.
        while not data:
            try:
                logging.log(logging.TRACE, "c2s: reading")
                data = self.ep_in.read(self.ep_in.wMaxPacketSize)
            except usb.core.USBTimeoutError:
                logging.log(logging.TRACE, "c2s: reading timed out")
                continue
            except usb.core.USBError as e:
                if e.errno == errno.EIO:
                    logging.debug("c2s: reading failed with %s, retrying", repr(e))
                    time.sleep(0.5)
                    continue
                logging.error("c2s: reading failed with %s, aborting", repr(e))
                raise
            if not data:
                logging.log(logging.TRACE, "c2s: skipping zero-length packet")
        size = struct.unpack("<I", data[:4])[0]
        # The first packet may hold only part of the message; read the rest.
        while len(data) < size:
            data += self.ep_in.read(size - len(data))
        logging.log(logging.TRACE, "c2s: writing")
        self._log_hexdump(data)
        # send() may transmit only part of the buffer; sendall() sends it all.
        self.s.sendall(data)
        logging.debug("c2s: forwarded %i bytes", size)
        self.stats["c2s packets"] += 1
        self.stats["c2s bytes"] += size

    def s2c(self):
        """Forward one response from the TCP server to the USB client.

        Raises:
            ConnectionError: if the server closes the connection before a
                complete message was received.
            OSError: if a USB write reports no progress.
        """
        logging.log(logging.TRACE, "s2c: reading")
        header = self._recv_exact(self.s, 4)
        size = struct.unpack("<I", header)[0]
        # size includes the 4-byte header; a bogus size < 4 reads nothing more.
        data = header + self._recv_exact(self.s, size - 4)
        logging.log(logging.TRACE, "s2c: writing")
        self._log_hexdump(data)
        remaining = data
        while remaining:
            written = self.ep_out.write(remaining)
            if written <= 0:
                raise OSError(errno.EIO, "s2c: USB write made no progress")
            remaining = remaining[written:]
        # A transfer whose length is an exact multiple of the max packet size
        # ends without a short packet, so terminate it with a zero-length one.
        if len(data) % self.ep_out.wMaxPacketSize == 0:
            logging.log(logging.TRACE, "sending zero length packet")
            self.ep_out.write(b"")
        logging.debug("s2c: forwarded %i bytes", size)
        self.stats["s2c packets"] += 1
        self.stats["s2c bytes"] += size

    def log_stats(self):
        """Log the per-direction packet and byte counters at INFO level."""
        logging.info("statistics:")
        for k, v in self.stats.items():
            logging.info(f"  {k+':':14s} {v}")

    def log_stats_interval(self, interval=5):
        """Log statistics if at least *interval* seconds passed since the last log."""
        if (time.monotonic() - self.stats_logged) < interval:
            return

        self.log_stats()
        self.stats_logged = time.monotonic()
163
164
def try_get_usb_str(dev, name):
    """Return the sysfs attribute *name* (e.g. "manufacturer") of *dev*.

    The kernel names USB device directories under /sys/bus/usb/devices after
    the device's position in the bus tree ("<bus>-<port>[.<port>...]", the
    same string path_from_usb_dev() builds), not after its device address.
    Devices without port numbers fall back to the root-hub name "usb<bus>".

    Returns the stripped file content, or None if the attribute file does
    not exist.
    """
    dev_dir = path_from_usb_dev(dev) or f"usb{dev.bus}"
    try:
        with open(f"/sys/bus/usb/devices/{dev_dir}/{name}") as f:
            return f.read().strip()
    except FileNotFoundError:
        return None
171
172
def list_usb(args):
    """Print a table of every connected USB device matching ``args.id``.

    ``args.id`` is a "vid:pid" string in hexadecimal. Manufacturer and product
    names come from sysfs and show as "unknown" when unavailable.
    """
    vid, pid = (int(part, 16) for part in args.id.split(":", 1))

    table_head = (
        "Bus | Addr | Manufacturer     | Product          | ID        | Path",
        "--- | ---- | ---------------- | ---------------- | --------- | ----",
    )
    for heading in table_head:
        print(heading)

    for dev in usb.core.find(find_all=True, idVendor=vid, idProduct=pid):
        location = path_from_usb_dev(dev) or ""
        vendor_name = try_get_usb_str(dev, "manufacturer") or "unknown"
        product_name = try_get_usb_str(dev, "product") or "unknown"
        usb_id = f"{dev.idVendor:04x}:{dev.idProduct:04x}"
        print(f"{dev.bus:3} | {dev.address:4} | {vendor_name:16} | {product_name:16} | {usb_id} | {location:18}")
185
186
def connect(args):
    """Relay 9P traffic between the selected gadget and the server forever.

    Each iteration forwards one client request, then one server response,
    and periodically logs statistics. The statistics are logged once more
    when the loop ends, including when it ends with an exception.
    """
    vid, pid = (int(part, 16) for part in args.id.split(":", 1))

    forwarder = Forwarder(server=(args.server, args.port), vid=vid, pid=pid, path=args.path)

    try:
        while True:
            forwarder.c2s()
            forwarder.s2c()
            forwarder.log_stats_interval()
    finally:
        forwarder.log_stats()
199
200
def main():
    """Parse the command line, configure logging and run the chosen subcommand."""
    parser = argparse.ArgumentParser(
        description="Forward 9PFS requests from USB to TCP",
    )

    parser.add_argument("--id", type=str, default="1d6b:0109", help="vid:pid of target device")
    parser.add_argument("--path", type=str, required=False, help="path of target device")
    parser.add_argument("-v", "--verbose", action="count", default=0)

    # A subcommand is mandatory; its name is stored in args.command.
    subparsers = parser.add_subparsers()
    subparsers.required = True
    subparsers.dest = "command"

    parser_list = subparsers.add_parser("list", help="List all connected 9p gadgets")
    parser_list.set_defaults(func=list_usb)

    parser_connect = subparsers.add_parser(
        "connect", help="Forward messages between the usb9pfs gadget and the 9p server"
    )
    parser_connect.set_defaults(func=connect)
    parser_connect.add_argument("-s", "--server", type=str, default="127.0.0.1", help="server hostname")
    parser_connect.add_argument("-p", "--port", type=int, default=564, help="server port")

    args = parser.parse_args()

    # Register a custom TRACE level below DEBUG; Forwarder logs per-transfer
    # progress and hexdumps at this level.
    logging.TRACE = logging.DEBUG - 5
    logging.addLevelName(logging.TRACE, "TRACE")

    # -v enables DEBUG, -vv (or more) enables TRACE.
    if args.verbose >= 2:
        level = logging.TRACE
    elif args.verbose:
        level = logging.DEBUG
    else:
        level = logging.INFO
    logging.basicConfig(level=level, format="%(asctime)-15s %(levelname)-8s %(message)s")

    args.func(args)
240
241
if __name__ == "__main__":
    # Run the command-line interface only when executed as a script.
    main()
244