Skip to content

Commit 13639bc

Browse files
committed
Add support for USB connections
Adds a new subclass of PybricksHub that manages USB connections.
1 parent a4db08e commit 13639bc

File tree

2 files changed

+84
-4
lines changed

2 files changed

+84
-4
lines changed

pybricksdev/cli/__init__.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -171,10 +171,11 @@ def add_parser(self, subparsers: argparse._SubParsersAction):
171171
)
172172

173173
async def run(self, args: argparse.Namespace):
174-
from ..ble import find_device
174+
from ..ble import find_device as find_ble
175175
from ..connections.ev3dev import EV3Connection
176176
from ..connections.lego import REPLHub
177-
from ..connections.pybricks import PybricksHubBLE
177+
from ..connections.pybricks import PybricksHubBLE, PybricksHubUSB
178+
from usb.core import find as find_usb
178179

179180
# Pick the right connection
180181
if args.conntype == "ssh":
@@ -185,14 +186,20 @@ async def run(self, args: argparse.Namespace):
185186

186187
device_or_address = socket.gethostbyname(args.name)
187188
hub = EV3Connection(device_or_address)
189+
188190
elif args.conntype == "ble":
189191
# It is a Pybricks Hub with BLE. Device name or address is given.
190192
print(f"Searching for {args.name or 'any hub with Pybricks service'}...")
191-
device_or_address = await find_device(args.name)
193+
device_or_address = await find_ble(args.name)
192194
hub = PybricksHubBLE(device_or_address)
193195

194196
elif args.conntype == "usb":
195-
hub = REPLHub()
197+
device_or_address = find_usb(idVendor=0x0483, idProduct=0x5740)
198+
199+
if device_or_address is not None and device_or_address.product == "Pybricks Hub":
200+
hub = PybricksHubUSB(device_or_address)
201+
else:
202+
hub = REPLHub()
196203
else:
197204
raise ValueError(f"Unknown connection type: {args.conntype}")
198205

pybricksdev/connections/pybricks.py

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,10 @@
1818
from tqdm.auto import tqdm
1919
from tqdm.contrib.logging import logging_redirect_tqdm
2020
from typing import Callable
21+
from uuid import UUID
22+
from usb.core import Device as USBDevice, Endpoint
23+
from usb.control import get_descriptor
24+
from usb.util import *
2125

2226
from ..ble.lwp3.bytecodes import HubKind
2327
from ..ble.nus import NUS_RX_UUID, NUS_TX_UUID
@@ -703,3 +707,72 @@ async def write_gatt_char(self, uuid: str, data, response: bool) -> None:
703707

704708
async def start_notify(self, uuid: str, callback: Callable) -> None:
705709
return await self._client.start_notify(uuid, callback)
710+
711+
class PybricksHubUSB(PybricksHub):
712+
_device: USBDevice
713+
_ep_in: Endpoint
714+
_ep_out: Endpoint
715+
_notify_callbacks = {}
716+
717+
def __init__(self, device: USBDevice):
718+
super().__init__()
719+
self._device = device
720+
721+
async def _client_connect(self) -> bool:
722+
self._device.set_configuration()
723+
724+
# Save input and output endpoints
725+
cfg = self._device.get_active_configuration()
726+
intf = cfg[(0,0)]
727+
self._ep_in = find_descriptor(intf, custom_match = lambda e: endpoint_direction(e.bEndpointAddress) == ENDPOINT_IN)
728+
self._ep_out = find_descriptor(intf, custom_match = lambda e: endpoint_direction(e.bEndpointAddress) == ENDPOINT_OUT)
729+
730+
# Set write size to endpoint packet size minus length of UUID
731+
self._max_write_size = self._ep_out.wMaxPacketSize - 16
732+
733+
# Get length of BOS descriptor
734+
bos_descriptor = get_descriptor(self._device, 5, 0x0F, 0)
735+
(ofst, _, bos_len, _) = struct.unpack("<BBHB", bos_descriptor)
736+
737+
# Get full BOS descriptor
738+
bos_descriptor = get_descriptor(self._device, bos_len, 0x0F, 0)
739+
740+
while ofst < bos_len:
741+
(len, desc_type, cap_type) = struct.unpack_from("<BBB", bos_descriptor, offset=ofst)
742+
743+
if desc_type != 0x10:
744+
logger.error("Expected Device Capability descriptor")
745+
exit(1)
746+
747+
# Look for platform descriptors
748+
if cap_type == 0x05:
749+
uuid_bytes = bos_descriptor[ofst + 4 : ofst + 4 + 16]
750+
uuid_str = str(UUID(bytes_le = bytes(uuid_bytes)))
751+
752+
if uuid_str == FW_REV_UUID:
753+
fw_version = bytearray(bos_descriptor[ofst + 20 : ofst + len - 1]) # Remove null-terminator
754+
self.fw_version = Version(fw_version.decode())
755+
756+
elif uuid_str == SW_REV_UUID:
757+
self._protocol_version = bytearray(bos_descriptor[ofst + 20 : ofst + len - 1]) # Remove null-terminator
758+
759+
elif uuid_str == PYBRICKS_HUB_CAPABILITIES_UUID:
760+
caps = bytearray(bos_descriptor[ofst + 20 : ofst + len])
761+
(_, self._capability_flags, self._max_user_program_size) = unpack_hub_capabilities(caps)
762+
763+
ofst += len
764+
765+
return True
766+
767+
async def _client_disconnect(self) -> bool:
768+
self._handle_disconnect()
769+
770+
async def read_gatt_char(self, uuid: str) -> bytearray:
771+
return None
772+
773+
async def write_gatt_char(self, uuid: str, data, response: bool) -> None:
774+
self._ep_out.write(UUID(uuid).bytes_le + data)
775+
# TODO: Handle response
776+
777+
async def start_notify(self, uuid: str, callback: Callable) -> None:
778+
self._notify_callbacks[uuid] = callback

0 commit comments

Comments
 (0)