Skip to content

Commit 99bae21

Browse files
committed
WIP: Add USB device support
TODO
1 parent 8c85931 commit 99bae21

File tree

2 files changed

+96
-7
lines changed

2 files changed

+96
-7
lines changed

pybricksdev/cli/__init__.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -171,10 +171,11 @@ def add_parser(self, subparsers: argparse._SubParsersAction):
171171
)
172172

173173
async def run(self, args: argparse.Namespace):
174-
from ..ble import find_device
174+
from ..ble import find_device as find_ble
175175
from ..connections.ev3dev import EV3Connection
176176
from ..connections.lego import REPLHub
177177
from ..connections.pybricks import PybricksHub
178+
from usb.core import find as find_usb
178179

179180
# Pick the right connection
180181
if args.conntype == "ssh":
@@ -188,11 +189,16 @@ async def run(self, args: argparse.Namespace):
188189
elif args.conntype == "ble":
189190
# It is a Pybricks Hub with BLE. Device name or address is given.
190191
print(f"Searching for {args.name or 'any hub with Pybricks service'}...")
191-
device_or_address = await find_device(args.name)
192+
device_or_address = await find_ble(args.name)
192193
hub = PybricksHub(device_or_address)
193194

194195
elif args.conntype == "usb":
195-
hub = REPLHub()
196+
device_or_address = find_usb(idVendor=0x0483, idProduct=0x5740)
197+
198+
if device_or_address is not None and device_or_address.product == "Pybricks Hub":
199+
hub = PybricksHub(device_or_address)
200+
else:
201+
hub = REPLHub()
196202
else:
197203
raise ValueError(f"Unknown connection type: {args.conntype}")
198204

pybricksdev/connections/pybricks.py

Lines changed: 87 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,11 @@
1717
from reactivex.subject import BehaviorSubject, Subject
1818
from tqdm.auto import tqdm
1919
from tqdm.contrib.logging import logging_redirect_tqdm
20-
from typing import Callable
20+
from typing import Callable, Union
21+
from uuid import UUID
22+
from usb.core import Device as USBDevice
23+
from usb.control import get_descriptor
24+
from usb.util import *
2125

2226
from ..ble.lwp3.bytecodes import HubKind
2327
from ..ble.nus import NUS_RX_UUID, NUS_TX_UUID
@@ -75,6 +79,79 @@ async def write_gatt_char(self, uuid: str, data, response: bool) -> None:
7579
async def start_notify(self, uuid: str, callback: Callable) -> None:
7680
return await self._client.start_notify(uuid, callback)
7781

82+
class PybricksHubUSBClient:
83+
_device: USBDevice
84+
_disconnected_callback: Callable = None
85+
_fw_version: str
86+
_protocol_version: str
87+
_hub_capabilities: bytearray
88+
89+
def __init__(self, device: USBDevice, disconnected_callback = None):
90+
self._device = device
91+
self._disconnected_callback = disconnected_callback
92+
93+
async def connect(self) -> bool:
94+
self._device.set_configuration()
95+
96+
# Get length of BOS descriptor
97+
bos_descriptor = get_descriptor(self._device, 5, 0x0F, 0)
98+
(ofst, _, bos_len, _) = struct.unpack("<BBHB", bos_descriptor)
99+
100+
# Get full BOS descriptor
101+
bos_descriptor = get_descriptor(self._device, bos_len, 0x0F, 0)
102+
103+
while ofst < bos_len:
104+
(len, desc_type, cap_type) = struct.unpack_from("<BBB", bos_descriptor, offset=ofst)
105+
106+
if desc_type != 0x10:
107+
logger.error("Expected Device Capability descriptor")
108+
exit(1)
109+
110+
# Look for platform descriptors
111+
if cap_type == 0x05:
112+
uuid_bytes = bos_descriptor[ofst + 4 : ofst + 4 + 16]
113+
uuid_str = str(UUID(bytes_le = bytes(uuid_bytes)))
114+
115+
if uuid_str == FW_REV_UUID:
116+
self._fw_version = bytearray(bos_descriptor[ofst + 20 : ofst + len - 1]) # Remove null-terminator
117+
elif uuid_str == SW_REV_UUID:
118+
self._protocol_version = bytearray(bos_descriptor[ofst + 20 : ofst + len - 1]) # Remove null-terminator
119+
elif uuid_str == PYBRICKS_HUB_CAPABILITIES_UUID:
120+
self._hub_capabilities = bytearray(bos_descriptor[ofst + 20 : ofst + len])
121+
122+
ofst += len
123+
124+
return True
125+
126+
async def disconnect(self) -> bool:
127+
self._disconnected_callback()
128+
return True
129+
130+
async def read_gatt_char(self, uuid: str) -> bytearray:
131+
if uuid == FW_REV_UUID:
132+
return self._fw_version
133+
elif uuid == SW_REV_UUID:
134+
return self._protocol_version
135+
elif uuid == PYBRICKS_HUB_CAPABILITIES_UUID:
136+
return self._hub_capabilities
137+
elif uuid == PNP_ID_UUID:
138+
return None
139+
140+
async def write_gatt_char(self, uuid: str, data, response: bool) -> None:
141+
# Get output endpoint
142+
cfg = self._device.get_active_configuration()
143+
intf = cfg[(0,0)]
144+
ep = find_descriptor(intf, custom_match = lambda e: endpoint_direction(e.bEndpointAddress) == ENDPOINT_OUT)
145+
146+
uuid_bytes = UUID(uuid).bytes_le
147+
ep.write(uuid_bytes + data)
148+
149+
# TODO: Handle response
150+
151+
async def start_notify(self, uuid: str, callback: Callable) -> None:
152+
# TODO
153+
return
154+
78155
class PybricksHub:
79156
EOL = b"\r\n" # MicroPython EOL
80157

@@ -109,7 +186,7 @@ class PybricksHub:
109186
has not been connected yet or the connected hub has Pybricks profile < v1.2.0.
110187
"""
111188

112-
def __init__(self, device: BLEDevice):
189+
def __init__(self, device: Union[BLEDevice, USBDevice]):
113190
self.connection_state_observable = BehaviorSubject(ConnectionState.DISCONNECTED)
114191
self.status_observable = BehaviorSubject(StatusFlag(0))
115192
self._stdout_subject = Subject()
@@ -155,7 +232,12 @@ def handle_disconnect():
155232
logger.info("Disconnected!")
156233
self.connection_state_observable.on_next(ConnectionState.DISCONNECTED)
157234

158-
self.client = PybricksHubBLEClient(device, disconnected_callback=handle_disconnect)
235+
if isinstance(device, BLEDevice):
236+
self.client = PybricksHubBLEClient(device, disconnected_callback=handle_disconnect)
237+
elif isinstance(device, USBDevice):
238+
self.client = PybricksHubUSBClient(device, disconnected_callback=handle_disconnect)
239+
else:
240+
raise TypeError
159241

160242
@property
161243
def stdout_observable(self) -> Observable[bytes]:
@@ -308,7 +390,8 @@ async def connect(self):
308390
)
309391

310392
pnp_id = await self.client.read_gatt_char(PNP_ID_UUID)
311-
_, _, self.hub_kind, self.hub_variant = unpack_pnp_id(pnp_id)
393+
if pnp_id:
394+
_, _, self.hub_kind, self.hub_variant = unpack_pnp_id(pnp_id)
312395

313396
if protocol_version >= "1.2.0":
314397
caps = await self.client.read_gatt_char(PYBRICKS_HUB_CAPABILITIES_UUID)

0 commit comments

Comments
 (0)