@@ -251,6 +251,13 @@ def test_tensor__setitem__(sample_tensor_2way):
251
251
# Subtensor add dimension
252
252
empty_tensor [0 , 0 , 0 ] = 2
253
253
254
+ # Subtensor with lists
255
+ some_tensor = ttb .tenones ((3 , 3 ))
256
+ some_tensor [[0 , 1 ], [0 , 1 ]] = 11
257
+ assert some_tensor [0 , 0 ] == 11
258
+ assert some_tensor [1 , 1 ] == 11
259
+ assert np .all (some_tensor [[0 , 1 ], [0 , 1 ]].data == 11 )
260
+
254
261
# Subscripts with constant
255
262
tensorInstance [np .array ([[1 , 1 ]])] = 13.0
256
263
dataGrowth [1 , 1 ] = 13.0
@@ -293,6 +300,13 @@ def test_tensor__setitem__(sample_tensor_2way):
293
300
dataGrowth [np .unravel_index ([0 , 3 , 4 ], dataGrowth .shape , "F" )] = 13
294
301
assert (tensorInstance .data == dataGrowth ).all ()
295
302
303
+ # Linear index with multiple indicies
304
+ some_tensor = ttb .tenones ((3 , 3 ))
305
+ some_tensor [[0 , 1 ]] = 2
306
+ assert some_tensor [0 ] == 2
307
+ assert some_tensor [1 ] == 2
308
+ assert np .array_equal (some_tensor [[0 , 1 ]], [2 , 2 ])
309
+
296
310
# Test Empty Tensor Set Item, subtensor
297
311
emptyTensor = ttb .tensor .from_data (np .array ([]))
298
312
emptyTensor [0 , 0 , 0 ] = 0
@@ -313,10 +327,17 @@ def test_tensor__setitem__(sample_tensor_2way):
313
327
)
314
328
315
329
# Attempting to set some other way
316
- # TODO either catch this error ourselves or specify more specific exception we expect here
317
- with pytest .raises (Exception ) as excinfo :
330
+ with pytest .raises (ValueError ) as excinfo :
318
331
tensorInstance [0 , "a" , 5 ] = 13.0
319
- # assert "Invalid use of tensor setitem" in str(excinfo)
332
+ assert "must be numeric" in str (excinfo )
333
+
334
+ with pytest .raises (AssertionError ) as excinfo :
335
+
336
+ class BadKey :
337
+ pass
338
+
339
+ tensorInstance [BadKey ] = 13.0
340
+ assert "Invalid use of tensor setitem" in str (excinfo )
320
341
321
342
322
343
@pytest .mark .indevelopment
@@ -346,6 +367,11 @@ def test_tensor__getitem__(sample_tensor_2way):
346
367
tensorInstance [np .array ([[0 , 0 ], [1 , 1 ]]), "extract" ]
347
368
== params ["data" ][([0 , 0 ], [1 , 1 ])]
348
369
).all ()
370
+ # Case 2a: Extract doesn't seem to be needed
371
+ assert tensorInstance [np .array ([0 , 0 ])] == params ["data" ][0 , 0 ]
372
+ assert (
373
+ tensorInstance [np .array ([[0 , 0 ], [1 , 1 ]])] == params ["data" ][([0 , 0 ], [1 , 1 ])]
374
+ ).all ()
349
375
350
376
# Case 2b: Linear Indexing
351
377
assert tensorInstance [np .array ([0 ])] == params ["data" ][0 , 0 ]
0 commit comments