#!/usr/bin/env python3
# SPDX-License-Identifier: GPL-2.0

import argparse
import errno
import logging
import socket
import struct
import time

import usb.core
import usb.util


def path_from_usb_dev(dev):
    """Return the position of a pyUSB device on the USB bus tree as a string.

    The path is made up of the number of the USB controller (bus) followed by
    the ports of the hub tree, e.g. ``"1-2.3"`` for port 3 of a hub plugged
    into port 2 of bus 1. It is used to select one specific device on the bus,
    or all devices connected to a hub.

    :returns: the path, or ``""`` when the device reports no port numbers.
    """
    if dev.port_numbers:
        dev_path = ".".join(str(i) for i in dev.port_numbers)
        return f"{dev.bus}-{dev_path}"
    return ""


# Maps every byte value to the character shown in the ASCII column of a hex
# dump: printable 7-bit characters as themselves, everything else as ".".
HEXDUMP_FILTER = "".join(chr(x) if chr(x).isprintable() else "." for x in range(128)) + "." * 128


class Forwarder:
    """Relay 9P messages between a usb9pfs gadget and a TCP 9P server.

    Each call to :meth:`c2s` forwards one message from the USB device to the
    server; each call to :meth:`s2c` forwards one message back. Per-direction
    packet/byte counters are kept in :attr:`stats`.
    """

    @staticmethod
    def _log_hexdump(data):
        """Log *data* as a 16-bytes-per-line hex dump at TRACE level."""
        if not logging.root.isEnabledFor(logging.TRACE):
            return
        width = 16
        for offset in range(0, len(data), width):
            chunk = data[offset : offset + width]
            dump = " ".join(f"{x:02x}" for x in chunk)
            printable = "".join(HEXDUMP_FILTER[x] for x in chunk)
            line = f"{offset:08x} {dump:{width*3}s} |{printable:{width}s}|"
            logging.root.log(logging.TRACE, "%s", line)

    def __init__(self, server, vid, pid, path):
        """Find and claim the usb9pfs interface, then connect to the server.

        :param server: ``(host, port)`` address of the TCP 9P server.
        :param vid: USB vendor ID of the gadget.
        :param pid: USB product ID of the gadget.
        :param path: bus path as returned by :func:`path_from_usb_dev`, or
            ``None`` to accept the first device matching *vid*/*pid*.
        :raises ValueError: if the device, the usb9pfs interface or one of its
            endpoints cannot be found.
        """
        self.stats = {
            "c2s packets": 0,
            "c2s bytes": 0,
            "s2c packets": 0,
            "s2c bytes": 0,
        }
        self.stats_logged = time.monotonic()

        def find_filter(dev):
            # Without an explicit path any device with the right vid:pid matches.
            if path is None:
                return True
            return path_from_usb_dev(dev) == path

        dev = usb.core.find(idVendor=vid, idProduct=pid, custom_match=find_filter)
        if dev is None:
            raise ValueError("Device not found")

        logging.info(f"found device: {dev.bus}/{dev.address} located at {path_from_usb_dev(dev)}")

        # dev.set_configuration() is not necessary since g_multi has only one
        usb9pfs = None
        # g_multi adds 9pfs as last interface
        cfg = dev.get_active_configuration()
        for intf in cfg:
            # we have to detach the usb-storage driver from multi gadget since
            # stall option could be set, which will lead to spontaneous port
            # resets and our transfers will run dead
            if intf.bInterfaceClass == 0x08:
                if dev.is_kernel_driver_active(intf.bInterfaceNumber):
                    dev.detach_kernel_driver(intf.bInterfaceNumber)

            # usb9pfs: vendor-specific class/subclass (0xFF) with protocol 0x09
            if intf.bInterfaceClass == 0xFF and intf.bInterfaceSubClass == 0xFF and intf.bInterfaceProtocol == 0x09:
                usb9pfs = intf
        if usb9pfs is None:
            raise ValueError("Interface not found")

        logging.info(f"claiming interface:\n{usb9pfs}")
        usb.util.claim_interface(dev, usb9pfs.bInterfaceNumber)
        ep_out = usb.util.find_descriptor(
            usb9pfs,
            custom_match=lambda e: usb.util.endpoint_direction(e.bEndpointAddress) == usb.util.ENDPOINT_OUT,
        )
        # Explicit checks instead of assert: asserts vanish under `python -O`.
        if ep_out is None:
            raise ValueError("OUT endpoint not found")
        ep_in = usb.util.find_descriptor(
            usb9pfs,
            custom_match=lambda e: usb.util.endpoint_direction(e.bEndpointAddress) == usb.util.ENDPOINT_IN,
        )
        if ep_in is None:
            raise ValueError("IN endpoint not found")
        logging.info("interface claimed")

        self.ep_out = ep_out
        self.ep_in = ep_in
        self.dev = dev

        # create and connect socket
        self.s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        self.s.connect(server)

        logging.info("connected to server")

    def c2s(self):
        """Forward one request from the USB client to the TCP server.

        While waiting for the first packet, timeouts are retried indefinitely
        and ``EIO`` errors are retried after a 0.5 s pause; any other USB error
        is logged and re-raised.
        """
        data = None
        while data is None:
            try:
                logging.log(logging.TRACE, "c2s: reading")
                data = self.ep_in.read(self.ep_in.wMaxPacketSize)
            except usb.core.USBTimeoutError:
                logging.log(logging.TRACE, "c2s: reading timed out")
                continue
            except usb.core.USBError as e:
                if e.errno == errno.EIO:
                    logging.debug("c2s: reading failed with %s, retrying", repr(e))
                    time.sleep(0.5)
                    continue
                logging.error("c2s: reading failed with %s, aborting", repr(e))
                raise

        # A 9P message starts with its total length as a little-endian u32;
        # the length includes the 4-byte size field itself.
        size = struct.unpack("<I", data[:4])[0]
        while len(data) < size:
            data += self.ep_in.read(size - len(data))

        logging.log(logging.TRACE, "c2s: writing")
        self._log_hexdump(data)
        # send() may transmit only part of the buffer; sendall() loops until
        # the whole message has been written.
        self.s.sendall(data)
        logging.debug("c2s: forwarded %i bytes", size)
        self.stats["c2s packets"] += 1
        self.stats["c2s bytes"] += size

    def _recv_exact(self, count):
        """Read exactly *count* bytes from the server socket (``b""`` if count <= 0).

        :raises ConnectionError: if the server closes the connection first;
            recv() returning ``b""`` would otherwise make callers spin forever.
        """
        chunks = []
        remaining = count
        while remaining > 0:
            chunk = self.s.recv(remaining)
            if not chunk:
                raise ConnectionError("s2c: server closed the connection")
            chunks.append(chunk)
            remaining -= len(chunk)
        return b"".join(chunks)

    def s2c(self):
        """Forward one response from the TCP server to the USB client.

        :raises ConnectionError: if the server closes the connection mid-message.
        :raises OSError: if a USB write makes no progress.
        """
        logging.log(logging.TRACE, "s2c: reading")
        # recv(4) alone may return fewer than 4 bytes, so read the header exactly.
        data = self._recv_exact(4)
        size = struct.unpack("<I", data)[0]
        data += self._recv_exact(size - len(data))

        logging.log(logging.TRACE, "s2c: writing")
        self._log_hexdump(data)
        while data:
            written = self.ep_out.write(data)
            if written <= 0:
                raise OSError(errno.EIO, "s2c: USB write made no progress")
            data = data[written:]
        # A transfer that is an exact multiple of the packet size must be
        # terminated by a zero-length packet so the device sees its end.
        if size % self.ep_out.wMaxPacketSize == 0:
            logging.log(logging.TRACE, "sending zero length packet")
            self.ep_out.write(b"")
        logging.debug("s2c: forwarded %i bytes", size)
        self.stats["s2c packets"] += 1
        self.stats["s2c bytes"] += size

    def log_stats(self):
        """Log all forwarding counters at INFO level."""
        logging.info("statistics:")
        for k, v in self.stats.items():
            logging.info(f"  {k+':':14s} {v}")

    def log_stats_interval(self, interval=5):
        """Log the counters if at least *interval* seconds passed since the last log."""
        if (time.monotonic() - self.stats_logged) < interval:
            return

        self.log_stats()
        self.stats_logged = time.monotonic()


def try_get_usb_str(dev, name):
    """Return the sysfs attribute *name* (e.g. ``"product"``) of *dev*, or None.

    Linux names USB device directories in sysfs ``<bus>-<port>[.<port>...]``,
    which is what path_from_usb_dev() builds; the device address is not part
    of that name, so it cannot be used to locate the directory.
    """
    dev_path = path_from_usb_dev(dev)
    if not dev_path:
        return None
    try:
        with open(f"/sys/bus/usb/devices/{dev_path}/{name}", encoding="utf-8", errors="replace") as f:
            return f.read().strip()
    except OSError:
        # Best effort: missing or unreadable attributes are reported as unknown.
        return None


def list_usb(args):
    """Print a table of all connected devices matching ``args.id`` ("vid:pid", hex)."""
    vid, pid = [int(x, 16) for x in args.id.split(":", 1)]

    print("Bus | Addr | Manufacturer     | Product          | ID        | Path")
    print("--- | ---- | ---------------- | ---------------- | --------- | ----")
    for dev in usb.core.find(find_all=True, idVendor=vid, idProduct=pid):
        path = path_from_usb_dev(dev)
        manufacturer = try_get_usb_str(dev, "manufacturer") or "unknown"
        product = try_get_usb_str(dev, "product") or "unknown"
        print(
            f"{dev.bus:3} | {dev.address:4} | {manufacturer:16} | {product:16} | {dev.idVendor:04x}:{dev.idProduct:04x} | {path:18}"
        )


def connect(args):
    """Forward messages between the selected gadget and the server until an error occurs."""
    vid, pid = [int(x, 16) for x in args.id.split(":", 1)]

    f = Forwarder(server=(args.server, args.port), vid=vid, pid=pid, path=args.path)

    try:
        # 9P is request/response: each client request is followed by one reply.
        while True:
            f.c2s()
            f.s2c()
            f.log_stats_interval()
    finally:
        f.log_stats()


def main():
    """Parse the command line, configure logging and run the chosen subcommand."""
    parser = argparse.ArgumentParser(
        description="Forward 9PFS requests from USB to TCP",
    )

    parser.add_argument("--id", type=str, default="1d6b:0109", help="vid:pid of target device")
    parser.add_argument("--path", type=str, required=False, help="path of target device")
    # -v enables DEBUG, -vv enables the custom TRACE level (incl. hex dumps).
    parser.add_argument("-v", "--verbose", action="count", default=0)

    subparsers = parser.add_subparsers()
    subparsers.required = True
    subparsers.dest = "command"

    parser_list = subparsers.add_parser("list", help="List all connected 9p gadgets")
    parser_list.set_defaults(func=list_usb)

    parser_connect = subparsers.add_parser(
        "connect", help="Forward messages between the usb9pfs gadget and the 9p server"
    )
    parser_connect.set_defaults(func=connect)
    parser_connect.add_argument("-s", "--server", type=str, default="127.0.0.1", help="server hostname")
    parser_connect.add_argument("-p", "--port", type=int, default=564, help="server port")

    args = parser.parse_args()

    # Custom level below DEBUG; Forwarder reads logging.TRACE at runtime.
    logging.TRACE = logging.DEBUG - 5
    logging.addLevelName(logging.TRACE, "TRACE")

    if args.verbose >= 2:
        level = logging.TRACE
    elif args.verbose:
        level = logging.DEBUG
    else:
        level = logging.INFO
    logging.basicConfig(level=level, format="%(asctime)-15s %(levelname)-8s %(message)s")

    args.func(args)


if __name__ == "__main__":
    main()