Skip to content

Commit 82ccb9a

Browse files
committed
Add support for USB connections
Adds a new subclass of PybricksHub that manages USB connections. Signed-off-by: Nate Karstens <[email protected]>
1 parent eb28049 commit 82ccb9a

File tree

2 files changed

+108
-4
lines changed

2 files changed

+108
-4
lines changed

pybricksdev/cli/__init__.py

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -171,10 +171,12 @@ def add_parser(self, subparsers: argparse._SubParsersAction):
171171
)
172172

173173
async def run(self, args: argparse.Namespace):
174-
from ..ble import find_device
174+
from usb.core import find as find_usb
175+
176+
from ..ble import find_device as find_ble
175177
from ..connections.ev3dev import EV3Connection
176178
from ..connections.lego import REPLHub
177-
from ..connections.pybricks import PybricksHubBLE
179+
from ..connections.pybricks import PybricksHubBLE, PybricksHubUSB
178180

179181
# Pick the right connection
180182
if args.conntype == "ssh":
@@ -185,14 +187,23 @@ async def run(self, args: argparse.Namespace):
185187

186188
device_or_address = socket.gethostbyname(args.name)
187189
hub = EV3Connection(device_or_address)
190+
188191
elif args.conntype == "ble":
189192
# It is a Pybricks Hub with BLE. Device name or address is given.
190193
print(f"Searching for {args.name or 'any hub with Pybricks service'}...")
191-
device_or_address = await find_device(args.name)
194+
device_or_address = await find_ble(args.name)
192195
hub = PybricksHubBLE(device_or_address)
193196

194197
elif args.conntype == "usb":
195-
hub = REPLHub()
198+
device_or_address = find_usb(idVendor=0x0483, idProduct=0x5740)
199+
200+
if (
201+
device_or_address is not None
202+
and device_or_address.product == "Pybricks Hub"
203+
):
204+
hub = PybricksHubUSB(device_or_address)
205+
else:
206+
hub = REPLHub()
196207
else:
197208
raise ValueError(f"Unknown connection type: {args.conntype}")
198209

pybricksdev/connections/pybricks.py

Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import os
88
import struct
99
from typing import Awaitable, Callable, List, Optional, TypeVar
10+
from uuid import UUID
1011

1112
import reactivex.operators as op
1213
import semver
@@ -17,6 +18,10 @@
1718
from reactivex.subject import BehaviorSubject, Subject
1819
from tqdm.auto import tqdm
1920
from tqdm.contrib.logging import logging_redirect_tqdm
21+
from usb.control import get_descriptor
22+
from usb.core import Device as USBDevice
23+
from usb.core import Endpoint
24+
from usb.util import ENDPOINT_IN, ENDPOINT_OUT, endpoint_direction, find_descriptor
2025

2126
from ..ble.lwp3.bytecodes import HubKind
2227
from ..ble.nus import NUS_RX_UUID, NUS_TX_UUID
@@ -705,3 +710,91 @@ async def write_gatt_char(self, uuid: str, data, response: bool) -> None:
705710

706711
async def start_notify(self, uuid: str, callback: Callable) -> None:
707712
return await self._client.start_notify(uuid, callback)
713+
714+
715+
class PybricksHubUSB(PybricksHub):
716+
_device: USBDevice
717+
_ep_in: Endpoint
718+
_ep_out: Endpoint
719+
_notify_callbacks = {}
720+
721+
def __init__(self, device: USBDevice):
722+
super().__init__()
723+
self._device = device
724+
725+
async def _client_connect(self) -> bool:
726+
self._device.set_configuration()
727+
728+
# Save input and output endpoints
729+
cfg = self._device.get_active_configuration()
730+
intf = cfg[(0, 0)]
731+
self._ep_in = find_descriptor(
732+
intf,
733+
custom_match=lambda e: endpoint_direction(e.bEndpointAddress)
734+
== ENDPOINT_IN,
735+
)
736+
self._ep_out = find_descriptor(
737+
intf,
738+
custom_match=lambda e: endpoint_direction(e.bEndpointAddress)
739+
== ENDPOINT_OUT,
740+
)
741+
742+
# Set write size to endpoint packet size minus length of UUID
743+
self._max_write_size = self._ep_out.wMaxPacketSize - 16
744+
745+
# Get length of BOS descriptor
746+
bos_descriptor = get_descriptor(self._device, 5, 0x0F, 0)
747+
(ofst, _, bos_len, _) = struct.unpack("<BBHB", bos_descriptor)
748+
749+
# Get full BOS descriptor
750+
bos_descriptor = get_descriptor(self._device, bos_len, 0x0F, 0)
751+
752+
while ofst < bos_len:
753+
(len, desc_type, cap_type) = struct.unpack_from(
754+
"<BBB", bos_descriptor, offset=ofst
755+
)
756+
757+
if desc_type != 0x10:
758+
logger.error("Expected Device Capability descriptor")
759+
exit(1)
760+
761+
# Look for platform descriptors
762+
if cap_type == 0x05:
763+
uuid_bytes = bos_descriptor[ofst + 4 : ofst + 4 + 16]
764+
uuid_str = str(UUID(bytes_le=bytes(uuid_bytes)))
765+
766+
if uuid_str == FW_REV_UUID:
767+
fw_version = bytearray(
768+
bos_descriptor[ofst + 20 : ofst + len - 1]
769+
) # Remove null-terminator
770+
self.fw_version = Version(fw_version.decode())
771+
772+
elif uuid_str == SW_REV_UUID:
773+
self._protocol_version = bytearray(
774+
bos_descriptor[ofst + 20 : ofst + len - 1]
775+
) # Remove null-terminator
776+
777+
elif uuid_str == PYBRICKS_HUB_CAPABILITIES_UUID:
778+
caps = bytearray(bos_descriptor[ofst + 20 : ofst + len])
779+
(
780+
_,
781+
self._capability_flags,
782+
self._max_user_program_size,
783+
) = unpack_hub_capabilities(caps)
784+
785+
ofst += len
786+
787+
return True
788+
789+
async def _client_disconnect(self) -> bool:
790+
self._handle_disconnect()
791+
792+
async def read_gatt_char(self, uuid: str) -> bytearray:
793+
return None
794+
795+
async def write_gatt_char(self, uuid: str, data, response: bool) -> None:
796+
self._ep_out.write(UUID(uuid).bytes_le + data)
797+
# TODO: Handle response
798+
799+
async def start_notify(self, uuid: str, callback: Callable) -> None:
800+
self._notify_callbacks[uuid] = callback

0 commit comments

Comments
 (0)