7
7
import logging
8
8
import os
9
9
import struct
10
- from typing import Awaitable , Callable , List , Optional , Tuple , TypeVar
10
+ from typing import Awaitable , Callable , List , Optional , Tuple , TypeVar , Union
11
+ from uuid import UUID
11
12
12
13
import reactivex .operators as op
13
14
import semver
19
20
from tqdm .auto import tqdm
20
21
from tqdm .contrib .logging import logging_redirect_tqdm
21
22
23
+ from usb .control import get_descriptor
24
+ from usb .core import Device as USBDevice
25
+ from usb .core import Endpoint , USBTimeoutError
26
+ from usb .util import ENDPOINT_IN , ENDPOINT_OUT , endpoint_direction , find_descriptor
27
+
22
28
from ..ble .lwp3 .bytecodes import HubKind
23
29
from ..ble .nus import NUS_RX_UUID , NUS_TX_UUID
24
30
from ..ble .pybricks import (
31
+ DEVICE_NAME_UUID ,
25
32
FW_REV_UUID ,
26
33
PNP_ID_UUID ,
27
34
PYBRICKS_COMMAND_EVENT_UUID ,
38
45
from ..compile import compile_file , compile_multi_file
39
46
from ..tools import chunk
40
47
from ..tools .checksum import xor_bytes
48
+ from ..usb import LegoUsbMsg , LegoUsbPid
41
49
from . import ConnectionState
42
50
43
51
logger = logging .getLogger (__name__ )
@@ -138,6 +146,156 @@ def handler(_, data):
138
146
await self ._client .start_notify (NUS_TX_UUID , handler )
139
147
140
148
149
+ class _USBTransport (_Transport ):
150
+ _device : USBDevice
151
+ _disconnected_callback : Callable
152
+ _ep_in : Endpoint
153
+ _ep_out : Endpoint
154
+ _notify_callbacks = {}
155
+ _monitor_task : asyncio .Task
156
+ _response : asyncio .Future
157
+
158
+ def __init__ (self , device : USBDevice ):
159
+ self ._device = device
160
+ self ._notify_callbacks [
161
+ LegoUsbMsg .USB_PYBRICKS_MSG_COMMAND_RESPONSE
162
+ ] = self ._response_handler
163
+
164
+ async def connect (self , disconnected_callback : Callable ) -> None :
165
+ self ._disconnected_callback = disconnected_callback
166
+ self ._device .set_configuration ()
167
+
168
+ # Save input and output endpoints
169
+ cfg = self ._device .get_active_configuration ()
170
+ intf = cfg [(0 , 0 )]
171
+ self ._ep_in = find_descriptor (
172
+ intf ,
173
+ custom_match = lambda e : endpoint_direction (e .bEndpointAddress )
174
+ == ENDPOINT_IN ,
175
+ )
176
+ self ._ep_out = find_descriptor (
177
+ intf ,
178
+ custom_match = lambda e : endpoint_direction (e .bEndpointAddress )
179
+ == ENDPOINT_OUT ,
180
+ )
181
+
182
+ # Get length of BOS descriptor
183
+ bos_descriptor = get_descriptor (self ._device , 5 , 0x0F , 0 )
184
+ (ofst , bos_len ) = struct .unpack ("<BxHx" , bos_descriptor )
185
+
186
+ # Get full BOS descriptor
187
+ bos_descriptor = get_descriptor (self ._device , bos_len , 0x0F , 0 )
188
+
189
+ while ofst < bos_len :
190
+ (len , desc_type , cap_type ) = struct .unpack_from (
191
+ "<BBB" , bos_descriptor , offset = ofst
192
+ )
193
+
194
+ if desc_type != 0x10 :
195
+ raise Exception ("Expected Device Capability descriptor" )
196
+
197
+ # Look for platform descriptors
198
+ if cap_type == 0x05 :
199
+ uuid_bytes = bos_descriptor [ofst + 4 : ofst + 4 + 16 ]
200
+ uuid_str = str (UUID (bytes_le = bytes (uuid_bytes )))
201
+
202
+ if uuid_str == DEVICE_NAME_UUID :
203
+ self ._device_name = bytes (
204
+ bos_descriptor [ofst + 20 : ofst + len ]
205
+ ).decode ()
206
+ print ("Connected to hub '" + self ._device_name + "'" )
207
+
208
+ elif uuid_str == FW_REV_UUID :
209
+ fw_version = bytes (bos_descriptor [ofst + 20 : ofst + len ])
210
+ self ._fw_version = Version (fw_version .decode ())
211
+
212
+ elif uuid_str == SW_REV_UUID :
213
+ protocol_version = bytes (bos_descriptor [ofst + 20 : ofst + len ])
214
+ self ._protocol_version = semver .VersionInfo .parse (
215
+ protocol_version .decode ()
216
+ )
217
+
218
+ elif uuid_str == PYBRICKS_HUB_CAPABILITIES_UUID :
219
+ caps = bytes (bos_descriptor [ofst + 20 : ofst + len ])
220
+ (
221
+ self ._max_write_size ,
222
+ self ._capability_flags ,
223
+ self ._max_user_program_size ,
224
+ ) = unpack_hub_capabilities (caps )
225
+
226
+ ofst += len
227
+
228
+ self ._monitor_task = asyncio .create_task (self ._monitor_usb ())
229
+
230
+ async def disconnect (self ) -> None :
231
+ # FIXME: Need to make sure this is called when the USB cable is unplugged
232
+ self ._monitor_task .cancel ()
233
+ self ._disconnected_callback ()
234
+
235
+ async def get_firmware_version (self ) -> Version :
236
+ return self ._fw_version
237
+
238
+ async def get_protocol_version (self ) -> Version :
239
+ return self ._protocol_version
240
+
241
+ async def get_hub_type (self ) -> Tuple [HubKind , int ]:
242
+ hub_types = {
243
+ LegoUsbPid .SPIKE_PRIME : (HubKind .TECHNIC_LARGE , 0 ),
244
+ LegoUsbPid .ROBOT_INVENTOR : (HubKind .TECHNIC_LARGE , 1 ),
245
+ LegoUsbPid .SPIKE_ESSENTIAL : (HubKind .TECHNIC_SMALL , 0 ),
246
+ }
247
+
248
+ return hub_types [self ._device .idProduct ]
249
+
250
+ async def get_hub_capabilities (self ) -> Tuple [int , HubCapabilityFlag , int ]:
251
+ return (
252
+ self ._max_write_size ,
253
+ self ._capability_flags ,
254
+ self ._max_user_program_size ,
255
+ )
256
+
257
+ async def send_command (self , command : bytes ) -> None :
258
+ self ._response = asyncio .Future ()
259
+ self ._ep_out .write (
260
+ struct .pack ("B" , LegoUsbMsg .USB_PYBRICKS_MSG_COMMAND ) + command
261
+ )
262
+ try :
263
+ await asyncio .wait_for (self ._response , 1 )
264
+ if self ._response .result () != 0 :
265
+ print (
266
+ f"Received error response for command: { self ._response .result ()} "
267
+ )
268
+ except asyncio .TimeoutError :
269
+ print ("Timed out waiting for a response" )
270
+
271
+ async def set_service_handler (self , callback : Callable ) -> None :
272
+ self ._notify_callbacks [LegoUsbMsg .USB_PYBRICKS_MSG_EVENT ] = callback
273
+
274
+ async def _monitor_usb (self ):
275
+ loop = asyncio .get_running_loop ()
276
+
277
+ while True :
278
+ msg = await loop .run_in_executor (None , self ._read_usb )
279
+
280
+ if msg is None or len (msg ) == 0 :
281
+ continue
282
+
283
+ callback = self ._notify_callbacks .get (msg [0 ])
284
+ if callback is not None :
285
+ callback (bytes (msg [1 :]))
286
+
287
+ def _read_usb (self ):
288
+ try :
289
+ msg = self ._ep_in .read (self ._ep_in .wMaxPacketSize )
290
+ return msg
291
+ except USBTimeoutError :
292
+ return None
293
+
294
+ def _response_handler (self , data : bytes ) -> None :
295
+ (response ,) = struct .unpack ("<I" , data )
296
+ self ._response .set_result (response )
297
+
298
+
141
299
class PybricksHub :
142
300
EOL = b"\r \n " # MicroPython EOL
143
301
@@ -326,11 +484,12 @@ def _pybricks_service_handler(self, data: bytes) -> None:
326
484
if self ._enable_line_handler :
327
485
self ._handle_line_data (payload )
328
486
329
- async def connect (self , device : BLEDevice ):
487
+ async def connect (self , device : Union [ BLEDevice , USBDevice ] ):
330
488
"""Connects to a device that was discovered with :meth:`pybricksdev.ble.find_device`
489
+ or :meth:`usb.core.find`
331
490
332
491
Args:
333
- device: The device to connect to.
492
+ device: The device to connect to (`BLEDevice` or `USBDevice`) .
334
493
335
494
Raises:
336
495
BleakError: if connecting failed (or old firmware without Device
@@ -350,7 +509,12 @@ async def connect(self, device: BLEDevice):
350
509
self .connection_state_observable .on_next , ConnectionState .DISCONNECTED
351
510
)
352
511
353
- self ._transport = _BLETransport (device )
512
+ if isinstance (device , BLEDevice ):
513
+ self ._transport = _BLETransport (device )
514
+ elif isinstance (device , USBDevice ):
515
+ self ._transport = _USBTransport (device )
516
+ else :
517
+ raise TypeError ("Unsupported device type" )
354
518
355
519
def handle_disconnect ():
356
520
logger .info ("Disconnected!" )
0 commit comments