8
8
import os
9
9
import struct
10
10
from typing import Awaitable , Callable , List , Optional , Tuple , TypeVar
11
+ from uuid import UUID
11
12
12
13
import reactivex .operators as op
13
14
import semver
19
20
from tqdm .auto import tqdm
20
21
from tqdm .contrib .logging import logging_redirect_tqdm
21
22
23
+ from usb .control import get_descriptor
24
+ from usb .core import Device as USBDevice
25
+ from usb .core import Endpoint , USBTimeoutError
26
+ from usb .util import ENDPOINT_IN , ENDPOINT_OUT , endpoint_direction , find_descriptor
27
+
22
28
from ..ble .lwp3 .bytecodes import HubKind
23
29
from ..ble .nus import NUS_RX_UUID , NUS_TX_UUID
24
30
from ..ble .pybricks import (
31
+ DEVICE_NAME_UUID ,
25
32
FW_REV_UUID ,
26
33
PNP_ID_UUID ,
27
34
PYBRICKS_COMMAND_EVENT_UUID ,
38
45
from ..compile import compile_file , compile_multi_file
39
46
from ..tools import chunk
40
47
from ..tools .checksum import xor_bytes
48
+ from ..usb import LegoUsbPid
41
49
from . import ConnectionState
42
50
43
51
logger = logging .getLogger (__name__ )
@@ -138,6 +146,141 @@ def handler(_, data):
138
146
await self ._client .start_notify (NUS_TX_UUID , handler )
139
147
140
148
149
+ class _USBTransport (_Transport ):
150
+ _device : USBDevice
151
+ _disconnected_callback : Callable
152
+ _ep_in : Endpoint
153
+ _ep_out : Endpoint
154
+ _notify_callbacks = {}
155
+ _monitor_task : asyncio .Task
156
+
157
+ _USB_PYBRICKS_MSG_COMMAND = b"\x00 "
158
+ _USB_PYBRICKS_MSG_COMMAND_RESPONSE = b"\x01 "
159
+ _USB_PYBRICKS_MSG_EVENT = b"\x02 "
160
+
161
+ def __init__ (self , device : USBDevice ):
162
+ self ._device = device
163
+
164
+ async def connect (self , disconnected_callback : Callable ) -> None :
165
+ self ._disconnected_callback = disconnected_callback
166
+ self ._device .set_configuration ()
167
+
168
+ # Save input and output endpoints
169
+ cfg = self ._device .get_active_configuration ()
170
+ intf = cfg [(0 , 0 )]
171
+ self ._ep_in = find_descriptor (
172
+ intf ,
173
+ custom_match = lambda e : endpoint_direction (e .bEndpointAddress )
174
+ == ENDPOINT_IN ,
175
+ )
176
+ self ._ep_out = find_descriptor (
177
+ intf ,
178
+ custom_match = lambda e : endpoint_direction (e .bEndpointAddress )
179
+ == ENDPOINT_OUT ,
180
+ )
181
+
182
+ # Get length of BOS descriptor
183
+ bos_descriptor = get_descriptor (self ._device , 5 , 0x0F , 0 )
184
+ (ofst , bos_len ) = struct .unpack ("<BxHx" , bos_descriptor )
185
+
186
+ # Get full BOS descriptor
187
+ bos_descriptor = get_descriptor (self ._device , bos_len , 0x0F , 0 )
188
+
189
+ while ofst < bos_len :
190
+ (len , desc_type , cap_type ) = struct .unpack_from (
191
+ "<BBB" , bos_descriptor , offset = ofst
192
+ )
193
+
194
+ if desc_type != 0x10 :
195
+ raise Exception ("Expected Device Capability descriptor" )
196
+
197
+ # Look for platform descriptors
198
+ if cap_type == 0x05 :
199
+ uuid_bytes = bos_descriptor [ofst + 4 : ofst + 4 + 16 ]
200
+ uuid_str = str (UUID (bytes_le = bytes (uuid_bytes )))
201
+
202
+ if uuid_str == DEVICE_NAME_UUID :
203
+ self ._device_name = bytes (
204
+ bos_descriptor [ofst + 20 : ofst + len ]
205
+ ).decode ()
206
+ print ("Connected to hub '" + self ._device_name + "'" )
207
+
208
+ elif uuid_str == FW_REV_UUID :
209
+ fw_version = bytes (bos_descriptor [ofst + 20 : ofst + len ])
210
+ self ._fw_version = Version (fw_version .decode ())
211
+
212
+ elif uuid_str == SW_REV_UUID :
213
+ protocol_version = bytes (bos_descriptor [ofst + 20 : ofst + len ])
214
+ self ._protocol_version = semver .VersionInfo .parse (
215
+ protocol_version .decode ()
216
+ )
217
+
218
+ elif uuid_str == PYBRICKS_HUB_CAPABILITIES_UUID :
219
+ caps = bytes (bos_descriptor [ofst + 20 : ofst + len ])
220
+ (
221
+ self ._max_write_size ,
222
+ self ._capability_flags ,
223
+ self ._max_user_program_size ,
224
+ ) = unpack_hub_capabilities (caps )
225
+
226
+ ofst += len
227
+
228
+ self ._monitor_task = asyncio .create_task (self ._monitor_usb ())
229
+
230
+ async def disconnect (self ) -> None :
231
+ # FIXME: Need to make sure this is called when the USB cable is unplugged
232
+ self ._monitor_task .cancel ()
233
+ self ._disconnected_callback ()
234
+
235
+ async def get_firmware_version (self ) -> Version :
236
+ return self ._fw_version
237
+
238
+ async def get_protocol_version (self ) -> Version :
239
+ return self ._protocol_version
240
+
241
+ async def get_hub_type (self ) -> Tuple [HubKind , int ]:
242
+ hub_types = {
243
+ LegoUsbPid .SPIKE_PRIME : (HubKind .TECHNIC_LARGE , 0 ),
244
+ LegoUsbPid .ROBOT_INVENTOR : (HubKind .TECHNIC_LARGE , 1 ),
245
+ LegoUsbPid .SPIKE_ESSENTIAL : (HubKind .TECHNIC_SMALL , 0 ),
246
+ }
247
+
248
+ return hub_types [self ._device .idProduct ]
249
+
250
+ async def get_hub_capabilities (self ) -> Tuple [int , HubCapabilityFlag , int ]:
251
+ return (
252
+ self ._max_write_size ,
253
+ self ._capability_flags ,
254
+ self ._max_user_program_size ,
255
+ )
256
+
257
+ async def send_command (self , command : bytes ) -> None :
258
+ self ._ep_out .write (self ._USB_PYBRICKS_MSG_COMMAND + command )
259
+
260
+ async def set_service_handler (self , callback : Callable ) -> None :
261
+ self ._notify_callbacks [self ._USB_PYBRICKS_MSG_EVENT [0 ]] = callback
262
+
263
+ async def _monitor_usb (self ):
264
+ loop = asyncio .get_running_loop ()
265
+
266
+ while True :
267
+ msg = await loop .run_in_executor (None , self ._read_usb )
268
+
269
+ if msg is None or len (msg ) == 0 :
270
+ continue
271
+
272
+ if msg [0 ] in self ._notify_callbacks :
273
+ callback = self ._notify_callbacks [msg [0 ]]
274
+ callback (bytes (msg [1 :]))
275
+
276
+ def _read_usb (self ):
277
+ try :
278
+ msg = self ._ep_in .read (self ._ep_in .wMaxPacketSize )
279
+ return msg
280
+ except USBTimeoutError :
281
+ return None
282
+
283
+
141
284
class PybricksHub :
142
285
EOL = b"\r \n " # MicroPython EOL
143
286
@@ -326,11 +469,12 @@ def _pybricks_service_handler(self, data: bytes) -> None:
326
469
if self ._enable_line_handler :
327
470
self ._handle_line_data (payload )
328
471
329
- async def connect (self , device : BLEDevice ):
472
+ async def connect (self , device ):
330
473
"""Connects to a device that was discovered with :meth:`pybricksdev.ble.find_device`
474
+ or :meth:`usb.core.find`
331
475
332
476
Args:
333
- device: The device to connect to.
477
+ device: The device to connect to (`BLEDevice` or `USBDevice`) .
334
478
335
479
Raises:
336
480
BleakError: if connecting failed (or old firmware without Device
@@ -350,7 +494,12 @@ async def connect(self, device: BLEDevice):
350
494
self .connection_state_observable .on_next , ConnectionState .DISCONNECTED
351
495
)
352
496
353
- self ._transport = _BLETransport (device )
497
+ if isinstance (device , BLEDevice ):
498
+ self ._transport = _BLETransport (device )
499
+ elif isinstance (device , USBDevice ):
500
+ self ._transport = _USBTransport (device )
501
+ else :
502
+ raise TypeError ("Unsupported device type" )
354
503
355
504
def handle_disconnect ():
356
505
logger .info ("Disconnected!" )
0 commit comments