8
8
import os
9
9
import struct
10
10
from typing import Awaitable , Callable , List , Optional , Tuple , TypeVar
11
+ from uuid import UUID
11
12
12
13
import reactivex .operators as op
13
14
import semver
19
20
from tqdm .auto import tqdm
20
21
from tqdm .contrib .logging import logging_redirect_tqdm
21
22
23
+ from usb .control import get_descriptor
24
+ from usb .core import Device as USBDevice
25
+ from usb .core import Endpoint , USBTimeoutError
26
+ from usb .util import ENDPOINT_IN , ENDPOINT_OUT , endpoint_direction , find_descriptor
27
+
22
28
from ..ble .lwp3 .bytecodes import HubKind
23
29
from ..ble .nus import NUS_RX_UUID , NUS_TX_UUID
24
30
from ..ble .pybricks import (
31
+ DEVICE_NAME_UUID ,
25
32
FW_REV_UUID ,
26
33
PNP_ID_UUID ,
27
34
PYBRICKS_COMMAND_EVENT_UUID ,
38
45
from ..compile import compile_file , compile_multi_file
39
46
from ..tools import chunk
40
47
from ..tools .checksum import xor_bytes
48
+ from ..usb import LegoUsbPid
41
49
from . import ConnectionState
42
50
43
51
logger = logging .getLogger (__name__ )
@@ -138,6 +146,140 @@ def handler(_, data):
138
146
await self ._client .start_notify (NUS_TX_UUID , handler )
139
147
140
148
149
+ class _USBTransport (_Transport ):
150
+ _device : USBDevice
151
+ _disconnected_callback : Callable
152
+ _ep_in : Endpoint
153
+ _ep_out : Endpoint
154
+ _notify_callbacks = {}
155
+ _monitor_task : asyncio .Task
156
+
157
+ def __init__ (self , device : USBDevice ):
158
+ self ._device = device
159
+
160
+ async def connect (self , disconnected_callback : Callable ) -> None :
161
+ self ._disconnected_callback = disconnected_callback
162
+ self ._device .set_configuration ()
163
+
164
+ # Save input and output endpoints
165
+ cfg = self ._device .get_active_configuration ()
166
+ intf = cfg [(0 , 0 )]
167
+ self ._ep_in = find_descriptor (
168
+ intf ,
169
+ custom_match = lambda e : endpoint_direction (e .bEndpointAddress )
170
+ == ENDPOINT_IN ,
171
+ )
172
+ self ._ep_out = find_descriptor (
173
+ intf ,
174
+ custom_match = lambda e : endpoint_direction (e .bEndpointAddress )
175
+ == ENDPOINT_OUT ,
176
+ )
177
+
178
+ # Get length of BOS descriptor
179
+ bos_descriptor = get_descriptor (self ._device , 5 , 0x0F , 0 )
180
+ (ofst , bos_len ) = struct .unpack ("<BxHx" , bos_descriptor )
181
+
182
+ # Get full BOS descriptor
183
+ bos_descriptor = get_descriptor (self ._device , bos_len , 0x0F , 0 )
184
+
185
+ while ofst < bos_len :
186
+ (len , desc_type , cap_type ) = struct .unpack_from (
187
+ "<BBB" , bos_descriptor , offset = ofst
188
+ )
189
+
190
+ if desc_type != 0x10 :
191
+ raise Exception ("Expected Device Capability descriptor" )
192
+
193
+ # Look for platform descriptors
194
+ if cap_type == 0x05 :
195
+ uuid_bytes = bos_descriptor [ofst + 4 : ofst + 4 + 16 ]
196
+ uuid_str = str (UUID (bytes_le = bytes (uuid_bytes )))
197
+
198
+ if uuid_str == DEVICE_NAME_UUID :
199
+ self ._device_name = bytes (
200
+ bos_descriptor [ofst + 20 : ofst + len ]
201
+ ).decode ()
202
+ print ("Connected to hub '" + self ._device_name + "'" )
203
+
204
+ elif uuid_str == FW_REV_UUID :
205
+ fw_version = bytes (bos_descriptor [ofst + 20 : ofst + len ])
206
+ self ._fw_version = Version (fw_version .decode ())
207
+
208
+ elif uuid_str == SW_REV_UUID :
209
+ protocol_version = bytes (bos_descriptor [ofst + 20 : ofst + len ])
210
+ self ._protocol_version = semver .VersionInfo .parse (
211
+ protocol_version .decode ()
212
+ )
213
+
214
+ elif uuid_str == PYBRICKS_HUB_CAPABILITIES_UUID :
215
+ caps = bytes (bos_descriptor [ofst + 20 : ofst + len ])
216
+ (
217
+ self ._max_write_size ,
218
+ self ._capability_flags ,
219
+ self ._max_user_program_size ,
220
+ ) = unpack_hub_capabilities (caps )
221
+
222
+ ofst += len
223
+
224
+ self ._monitor_task = asyncio .create_task (self ._monitor_usb ())
225
+
226
+ async def disconnect (self ) -> None :
227
+ # FIXME: Need to make sure this is called when the USB cable is unplugged
228
+ self ._monitor_task .cancel ()
229
+ self ._disconnected_callback ()
230
+
231
+ async def get_firmware_version (self ) -> Version :
232
+ return self ._fw_version
233
+
234
+ async def get_protocol_version (self ) -> Version :
235
+ return self ._protocol_version
236
+
237
+ async def get_hub_type (self ) -> Tuple [HubKind , int ]:
238
+ hub_types = {
239
+ LegoUsbPid .SPIKE_PRIME : (HubKind .TECHNIC_LARGE , 0 ),
240
+ LegoUsbPid .ROBOT_INVENTOR : (HubKind .TECHNIC_LARGE , 1 ),
241
+ LegoUsbPid .SPIKE_ESSENTIAL : (HubKind .TECHNIC_SMALL , 0 ),
242
+ }
243
+
244
+ return hub_types [self ._device .idProduct ]
245
+
246
+ async def get_hub_capabilities (self ) -> Tuple [int , HubCapabilityFlag , int ]:
247
+ return (
248
+ self ._max_write_size ,
249
+ self ._capability_flags ,
250
+ self ._max_user_program_size ,
251
+ )
252
+
253
+ async def send_command (self , command : bytes ) -> None :
254
+ self ._ep_out .write (UUID (PYBRICKS_COMMAND_EVENT_UUID ).bytes_le + command )
255
+
256
+ async def set_service_handler (self , callback : Callable ) -> None :
257
+ self ._notify_callbacks [PYBRICKS_COMMAND_EVENT_UUID ] = callback
258
+
259
+ async def _monitor_usb (self ):
260
+ loop = asyncio .get_running_loop ()
261
+
262
+ while True :
263
+ msg = await loop .run_in_executor (None , self ._read_usb )
264
+
265
+ if msg is None :
266
+ continue
267
+
268
+ if len (msg ) > 16 :
269
+ uuid = str (UUID (bytes_le = bytes (msg [:16 ])))
270
+ if uuid in self ._notify_callbacks :
271
+ callback = self ._notify_callbacks [uuid ]
272
+ if callback :
273
+ callback (bytes (msg [16 :]))
274
+
275
+ def _read_usb (self ):
276
+ try :
277
+ msg = self ._ep_in .read (self ._ep_in .wMaxPacketSize )
278
+ return msg
279
+ except USBTimeoutError :
280
+ return None
281
+
282
+
141
283
class PybricksHub :
142
284
EOL = b"\r \n " # MicroPython EOL
143
285
@@ -326,11 +468,12 @@ def _pybricks_service_handler(self, data: bytes) -> None:
326
468
if self ._enable_line_handler :
327
469
self ._handle_line_data (payload )
328
470
329
- async def connect (self , device : BLEDevice ):
471
+ async def connect (self , device ):
330
472
"""Connects to a device that was discovered with :meth:`pybricksdev.ble.find_device`
473
+ or :meth:`usb.core.find`
331
474
332
475
Args:
333
- device: The device to connect to.
476
+ device: The device to connect to (`BLEDevice` or `USBDevice`) .
334
477
335
478
Raises:
336
479
BleakError: if connecting failed (or old firmware without Device
@@ -350,7 +493,12 @@ async def connect(self, device: BLEDevice):
350
493
self .connection_state_observable .on_next , ConnectionState .DISCONNECTED
351
494
)
352
495
353
- self ._transport = _BLETransport (device )
496
+ if isinstance (device , BLEDevice ):
497
+ self ._transport = _BLETransport (device )
498
+ elif isinstance (device , USBDevice ):
499
+ self ._transport = _USBTransport (device )
500
+ else :
501
+ raise TypeError ("Unsupported device type" )
354
502
355
503
def handle_disconnect ():
356
504
logger .info ("Disconnected!" )
0 commit comments