#!/usr/bin/env python3
# SPDX-License-Identifier: GPL-2.0
"""Forward 9P (9PFS) messages between a usb9pfs USB gadget and a 9P TCP server."""

import argparse
import errno
import logging
import socket
import struct
import time

import usb.core
import usb.util

# Log level below DEBUG used for per-transfer tracing and hexdumps. Defined at
# module level so Forwarder also works when main() has not registered it as
# logging.TRACE.
TRACE = logging.DEBUG - 5


def path_from_usb_dev(dev):
    """Return the position of a pyUSB device on the USB bus tree as a string.

    The path is made up of the number of the USB controller followed by the
    ports of the hub tree, e.g. "1-2.4". It is used to find a USB device on the
    bus or all devices connected to a hub. Returns "" when the device reports
    no port numbers.
    """
    if dev.port_numbers:
        dev_path = ".".join(str(i) for i in dev.port_numbers)
        return f"{dev.bus}-{dev_path}"
    return ""


# Byte -> character table for the hexdump's text column: printable ASCII maps
# to itself, everything else (non-printable and bytes >= 0x80) maps to ".".
HEXDUMP_FILTER = "".join(chr(x) if chr(x).isprintable() else "." for x in range(128)) + "." * 128


class Forwarder:
    """Bridge the 9PFS interface of a USB gadget to a 9P server over TCP.

    Every message starts with a little-endian 32-bit size that counts the whole
    message including the size field itself; both directions read until that
    many bytes are buffered and then forward the message unchanged.
    """

    @staticmethod
    def _log_hexdump(data):
        """Log *data* as a 16-bytes-per-line hexdump at TRACE level."""
        if not logging.root.isEnabledFor(TRACE):
            return
        width = 16
        for offset in range(0, len(data), width):
            chunk = data[offset : offset + width]
            dump = " ".join(f"{x:02x}" for x in chunk)
            printable = "".join(HEXDUMP_FILTER[x] for x in chunk)
            line = f"{offset:08x} {dump:{width * 3}s} |{printable:{width}s}|"
            logging.root.log(TRACE, "%s", line)

    @staticmethod
    def _find_endpoint(intf, direction):
        """Return the endpoint of *intf* with the given direction.

        Raises:
            ValueError: if the interface has no endpoint in that direction.
        """
        ep = usb.util.find_descriptor(
            intf,
            custom_match=lambda e: usb.util.endpoint_direction(e.bEndpointAddress) == direction,
        )
        if ep is None:
            raise ValueError("Endpoint not found")
        return ep

    def __init__(self, server, vid, pid, path):
        """Find the gadget, claim its 9PFS interface and connect to the server.

        Args:
            server: (host, port) address of the 9P TCP server.
            vid: USB vendor ID of the gadget.
            pid: USB product ID of the gadget.
            path: bus path as produced by path_from_usb_dev(), or None to use
                the first device matching vid/pid.

        Raises:
            ValueError: if the device, its 9PFS interface or one of its
                endpoints cannot be found.
            OSError: if connecting to the server fails.
        """
        self.stats = {
            "c2s packets": 0,
            "c2s bytes": 0,
            "s2c packets": 0,
            "s2c bytes": 0,
        }
        self.stats_logged = time.monotonic()

        def find_filter(dev):
            # Only restrict by bus position when a path was requested.
            return path is None or path_from_usb_dev(dev) == path

        dev = usb.core.find(idVendor=vid, idProduct=pid, custom_match=find_filter)
        if dev is None:
            raise ValueError("Device not found")

        logging.info("found device: %s/%s located at %s", dev.bus, dev.address, path_from_usb_dev(dev))

        # dev.set_configuration() is not necessary since g_multi has only one
        # configuration.
        usb9pfs = None
        # g_multi adds 9pfs as last interface
        cfg = dev.get_active_configuration()
        for intf in cfg:
            # We have to detach the usb-storage driver from the multi gadget
            # since the stall option could be set, which will lead to
            # spontaneous port resets and our transfers will run dead.
            if intf.bInterfaceClass == 0x08 and dev.is_kernel_driver_active(intf.bInterfaceNumber):
                dev.detach_kernel_driver(intf.bInterfaceNumber)

            # The 9PFS interface is matched as class/subclass 0xFF, protocol 0x09.
            if intf.bInterfaceClass == 0xFF and intf.bInterfaceSubClass == 0xFF and intf.bInterfaceProtocol == 0x09:
                usb9pfs = intf
        if usb9pfs is None:
            raise ValueError("Interface not found")

        logging.info("claiming interface:\n%s", usb9pfs)
        usb.util.claim_interface(dev, usb9pfs.bInterfaceNumber)
        ep_out = self._find_endpoint(usb9pfs, usb.util.ENDPOINT_OUT)
        ep_in = self._find_endpoint(usb9pfs, usb.util.ENDPOINT_IN)
        logging.info("interface claimed")

        self.ep_out = ep_out
        self.ep_in = ep_in
        self.dev = dev

        # create and connect socket
        self.s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        self.s.connect(server)

        logging.info("connected to server")

    def close(self):
        """Close the server connection and release the claimed USB resources."""
        self.s.close()
        try:
            usb.util.dispose_resources(self.dev)
        except usb.core.USBError as e:
            # Device may already be gone; don't mask the error that ended the loop.
            logging.debug("releasing USB resources failed with %s", repr(e))

    def _recv_exact(self, size):
        """Read exactly *size* bytes from the server socket.

        Raises:
            ConnectionError: if the server closes the connection first
                (recv() returning b"" would otherwise loop forever).
        """
        buf = bytearray()
        while len(buf) < size:
            chunk = self.s.recv(size - len(buf))
            if not chunk:
                raise ConnectionError(f"server closed connection after {len(buf)} of {size} bytes")
            buf += chunk
        return bytes(buf)

    def c2s(self):
        """Forward one request from the USB client to the TCP server.

        Timeouts and EIO errors on the first read are retried; other USB
        errors are logged and re-raised.
        """
        data = None
        # An empty read (zero-length packet) carries no size header; read again.
        while not data:
            try:
                logging.log(TRACE, "c2s: reading")
                data = self.ep_in.read(self.ep_in.wMaxPacketSize)
            except usb.core.USBTimeoutError:
                # Must precede USBError, of which it is a subclass.
                logging.log(TRACE, "c2s: reading timed out")
                continue
            except usb.core.USBError as e:
                if e.errno == errno.EIO:
                    logging.debug("c2s: reading failed with %s, retrying", repr(e))
                    time.sleep(0.5)
                    continue
                logging.error("c2s: reading failed with %s, aborting", repr(e))
                raise
        size = struct.unpack("<I", data[:4])[0]
        while len(data) < size:
            data += self.ep_in.read(size - len(data))
        logging.log(TRACE, "c2s: writing")
        self._log_hexdump(data)
        # send() may transmit only part of the buffer; sendall() retries until done.
        self.s.sendall(data)
        logging.debug("c2s: forwarded %i bytes", size)
        self.stats["c2s packets"] += 1
        self.stats["c2s bytes"] += size

    def s2c(self):
        """Forward one response from the TCP server to the USB client.

        Raises:
            ConnectionError: if the server disconnects mid-message.
        """
        logging.log(TRACE, "s2c: reading")
        header = self._recv_exact(4)
        size = struct.unpack("<I", header)[0]
        # size includes the 4-byte header already received.
        data = header + self._recv_exact(max(size - 4, 0))
        logging.log(TRACE, "s2c: writing")
        self._log_hexdump(data)
        total = len(data)
        while data:
            written = self.ep_out.write(data)
            if written <= 0:
                raise RuntimeError("s2c: USB write made no progress")
            data = data[written:]
        # A transfer that is an exact multiple of the packet size must be
        # terminated with a zero-length packet so the device sees its end.
        if total % self.ep_out.wMaxPacketSize == 0:
            logging.log(TRACE, "sending zero length packet")
            self.ep_out.write(b"")
        logging.debug("s2c: forwarded %i bytes", size)
        self.stats["s2c packets"] += 1
        self.stats["s2c bytes"] += size

    def log_stats(self):
        """Log the packet and byte counters for both directions."""
        logging.info("statistics:")
        for k, v in self.stats.items():
            logging.info(f" {k+':':14s} {v}")

    def log_stats_interval(self, interval=5):
        """Log statistics if at least *interval* seconds passed since the last time."""
        if (time.monotonic() - self.stats_logged) < interval:
            return

        self.log_stats()
        self.stats_logged = time.monotonic()


def try_get_usb_str(dev, name):
    """Return the sysfs string attribute *name* of *dev*, or None if unavailable.

    sysfs names USB device directories by bus and port path (e.g. "1-2.4"),
    not by device address, so the path comes from path_from_usb_dev().
    """
    dev_path = path_from_usb_dev(dev)
    if not dev_path:
        return None
    try:
        with open(f"/sys/bus/usb/devices/{dev_path}/{name}") as f:
            return f.read().strip()
    except FileNotFoundError:
        return None


def list_usb(args):
    """Print a table of all connected devices matching ``args.id`` (vid:pid in hex)."""
    vid, pid = [int(x, 16) for x in args.id.split(":", 1)]

    print("Bus | Addr | Manufacturer | Product | ID | Path")
    print("--- | ---- | ---------------- | ---------------- | --------- | ----")
    for dev in usb.core.find(find_all=True, idVendor=vid, idProduct=pid):
        path = path_from_usb_dev(dev)
        manufacturer = try_get_usb_str(dev, "manufacturer") or "unknown"
        product = try_get_usb_str(dev, "product") or "unknown"
        print(
            f"{dev.bus:3} | {dev.address:4} | {manufacturer:16} | {product:16} | {dev.idVendor:04x}:{dev.idProduct:04x} | {path:18}"
        )


def connect(args):
    """Forward request/response pairs until an error occurs, then log stats and clean up."""
    vid, pid = [int(x, 16) for x in args.id.split(":", 1)]

    f = Forwarder(server=(args.server, args.port), vid=vid, pid=pid, path=args.path)

    try:
        while True:
            f.c2s()
            f.s2c()
            f.log_stats_interval()
    finally:
        f.log_stats()
        f.close()


def main():
    """Parse the command line, configure logging and dispatch to the subcommand."""
    parser = argparse.ArgumentParser(
        description="Forward 9PFS requests from USB to TCP",
    )

    parser.add_argument("--id", type=str, default="1d6b:0109", help="vid:pid of target device")
    parser.add_argument("--path", type=str, required=False, help="path of target device")
    parser.add_argument("-v", "--verbose", action="count", default=0)

    subparsers = parser.add_subparsers()
    subparsers.required = True
    subparsers.dest = "command"

    parser_list = subparsers.add_parser("list", help="List all connected 9p gadgets")
    parser_list.set_defaults(func=list_usb)

    parser_connect = subparsers.add_parser(
        "connect", help="Forward messages between the usb9pfs gadget and the 9p server"
    )
    parser_connect.set_defaults(func=connect)
    parser_connect.add_argument("-s", "--server", type=str, default="127.0.0.1", help="server hostname")
    parser_connect.add_argument("-p", "--port", type=int, default=564, help="server port")

    args = parser.parse_args()

    # Expose the custom level as logging.TRACE and give it a printable name.
    logging.TRACE = TRACE
    logging.addLevelName(TRACE, "TRACE")

    if args.verbose >= 2:
        level = TRACE
    elif args.verbose:
        level = logging.DEBUG
    else:
        level = logging.INFO
    logging.basicConfig(level=level, format="%(asctime)-15s %(levelname)-8s %(message)s")

    args.func(args)


if __name__ == "__main__":
    main()