|
17 | 17 | from reactivex.subject import BehaviorSubject, Subject
|
18 | 18 | from tqdm.auto import tqdm
|
19 | 19 | from tqdm.contrib.logging import logging_redirect_tqdm
|
20 |
| -from typing import Callable |
| 20 | +from typing import Callable, Union |
| 21 | +from uuid import UUID |
| 22 | +from usb.core import Device as USBDevice |
| 23 | +from usb.control import get_descriptor |
| 24 | +from usb.util import * |
21 | 25 |
|
22 | 26 | from ..ble.lwp3.bytecodes import HubKind
|
23 | 27 | from ..ble.nus import NUS_RX_UUID, NUS_TX_UUID
|
@@ -75,6 +79,79 @@ async def write_gatt_char(self, uuid: str, data, response: bool) -> None:
|
75 | 79 | async def start_notify(self, uuid: str, callback: Callable) -> None:
|
76 | 80 | return await self._client.start_notify(uuid, callback)
|
77 | 81 |
|
| 82 | +class PybricksHubUSBClient: |
| 83 | + _device: USBDevice |
| 84 | + _disconnected_callback: Callable = None |
| 85 | + _fw_version: str |
| 86 | + _protocol_version: str |
| 87 | + _hub_capabilities: bytearray |
| 88 | + |
| 89 | + def __init__(self, device: USBDevice, disconnected_callback = None): |
| 90 | + self._device = device |
| 91 | + self._disconnected_callback = disconnected_callback |
| 92 | + |
| 93 | + async def connect(self) -> bool: |
| 94 | + self._device.set_configuration() |
| 95 | + |
| 96 | + # Get length of BOS descriptor |
| 97 | + bos_descriptor = get_descriptor(self._device, 5, 0x0F, 0) |
| 98 | + (ofst, _, bos_len, _) = struct.unpack("<BBHB", bos_descriptor) |
| 99 | + |
| 100 | + # Get full BOS descriptor |
| 101 | + bos_descriptor = get_descriptor(self._device, bos_len, 0x0F, 0) |
| 102 | + |
| 103 | + while ofst < bos_len: |
| 104 | + (len, desc_type, cap_type) = struct.unpack_from("<BBB", bos_descriptor, offset=ofst) |
| 105 | + |
| 106 | + if desc_type != 0x10: |
| 107 | + logger.error("Expected Device Capability descriptor") |
| 108 | + exit(1) |
| 109 | + |
| 110 | + # Look for platform descriptors |
| 111 | + if cap_type == 0x05: |
| 112 | + uuid_bytes = bos_descriptor[ofst + 4 : ofst + 4 + 16] |
| 113 | + uuid_str = str(UUID(bytes_le = bytes(uuid_bytes))) |
| 114 | + |
| 115 | + if uuid_str == FW_REV_UUID: |
| 116 | + self._fw_version = bytearray(bos_descriptor[ofst + 20 : ofst + len - 1]) # Remove null-terminator |
| 117 | + elif uuid_str == SW_REV_UUID: |
| 118 | + self._protocol_version = bytearray(bos_descriptor[ofst + 20 : ofst + len - 1]) # Remove null-terminator |
| 119 | + elif uuid_str == PYBRICKS_HUB_CAPABILITIES_UUID: |
| 120 | + self._hub_capabilities = bytearray(bos_descriptor[ofst + 20 : ofst + len]) |
| 121 | + |
| 122 | + ofst += len |
| 123 | + |
| 124 | + return True |
| 125 | + |
| 126 | + async def disconnect(self) -> bool: |
| 127 | + self._disconnected_callback() |
| 128 | + return True |
| 129 | + |
| 130 | + async def read_gatt_char(self, uuid: str) -> bytearray: |
| 131 | + if uuid == FW_REV_UUID: |
| 132 | + return self._fw_version |
| 133 | + elif uuid == SW_REV_UUID: |
| 134 | + return self._protocol_version |
| 135 | + elif uuid == PYBRICKS_HUB_CAPABILITIES_UUID: |
| 136 | + return self._hub_capabilities |
| 137 | + elif uuid == PNP_ID_UUID: |
| 138 | + return None |
| 139 | + |
| 140 | + async def write_gatt_char(self, uuid: str, data, response: bool) -> None: |
| 141 | + # Get output endpoint |
| 142 | + cfg = self._device.get_active_configuration() |
| 143 | + intf = cfg[(0,0)] |
| 144 | + ep = find_descriptor(intf, custom_match = lambda e: endpoint_direction(e.bEndpointAddress) == ENDPOINT_OUT) |
| 145 | + |
| 146 | + uuid_bytes = UUID(uuid).bytes_le |
| 147 | + ep.write(uuid_bytes + data) |
| 148 | + |
| 149 | + # TODO: Handle response |
| 150 | + |
| 151 | + async def start_notify(self, uuid: str, callback: Callable) -> None: |
| 152 | + # TODO |
| 153 | + return |
| 154 | + |
78 | 155 | class PybricksHub:
|
79 | 156 | EOL = b"\r\n" # MicroPython EOL
|
80 | 157 |
|
@@ -109,7 +186,7 @@ class PybricksHub:
|
109 | 186 | has not been connected yet or the connected hub has Pybricks profile < v1.2.0.
|
110 | 187 | """
|
111 | 188 |
|
112 |
| - def __init__(self, device: BLEDevice): |
| 189 | + def __init__(self, device: Union[BLEDevice, USBDevice]): |
113 | 190 | self.connection_state_observable = BehaviorSubject(ConnectionState.DISCONNECTED)
|
114 | 191 | self.status_observable = BehaviorSubject(StatusFlag(0))
|
115 | 192 | self._stdout_subject = Subject()
|
@@ -155,7 +232,12 @@ def handle_disconnect():
|
155 | 232 | logger.info("Disconnected!")
|
156 | 233 | self.connection_state_observable.on_next(ConnectionState.DISCONNECTED)
|
157 | 234 |
|
158 |
| - self.client = PybricksHubBLEClient(device, disconnected_callback=handle_disconnect) |
| 235 | + if isinstance(device, BLEDevice): |
| 236 | + self.client = PybricksHubBLEClient(device, disconnected_callback=handle_disconnect) |
| 237 | + elif isinstance(device, USBDevice): |
| 238 | + self.client = PybricksHubUSBClient(device, disconnected_callback=handle_disconnect) |
| 239 | + else: |
| 240 | + raise TypeError |
159 | 241 |
|
160 | 242 | @property
|
161 | 243 | def stdout_observable(self) -> Observable[bytes]:
|
@@ -308,7 +390,8 @@ async def connect(self):
|
308 | 390 | )
|
309 | 391 |
|
310 | 392 | pnp_id = await self.client.read_gatt_char(PNP_ID_UUID)
|
311 |
| - _, _, self.hub_kind, self.hub_variant = unpack_pnp_id(pnp_id) |
| 393 | + if pnp_id: |
| 394 | + _, _, self.hub_kind, self.hub_variant = unpack_pnp_id(pnp_id) |
312 | 395 |
|
313 | 396 | if protocol_version >= "1.2.0":
|
314 | 397 | caps = await self.client.read_gatt_char(PYBRICKS_HUB_CAPABILITIES_UUID)
|
|
0 commit comments