Skip to content

Commit 61ec65c

Browse files
committed
sptensor: Add coverage for improved indexing capability
1 parent 5aad34f commit 61ec65c

File tree

2 files changed

+14
-0
lines changed

2 files changed

+14
-0
lines changed

pyttb/sptensor.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1647,6 +1647,8 @@ def _set_subtensor(self, key, value):
16471647
newsz.append(self.shape[n])
16481648
else:
16491649
newsz.append(max([self.shape[n], key[n].stop]))
1650+
elif isinstance(key[n], Iterable):
1651+
newsz.append(max([self.shape[n], max(key[n]) + 1]))
16501652
else:
16511653
newsz.append(max([self.shape[n], key[n] + 1]))
16521654

tests/test_sptensor.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -533,6 +533,18 @@ def test_sptensor_setitem_Case1(sample_sptensor):
533533
assert (sptensorInstance.vals == np.vstack((data["vals"], np.array([[7]])))).all()
534534
assert sptensorInstance.shape == data["shape"]
535535

536+
# Case I(b)ii: Set with scalar, iterable index, empty sptensor
537+
someTensor = ttb.sptensor()
538+
someTensor[[0, 1], 0] = 1
539+
assert someTensor[0, 0] == 1
540+
assert someTensor[1, 0] == 1
541+
assert np.all(someTensor[[0, 1], 0].vals == 1)
542+
# Case I(b)ii: Set with scalar, iterable index, non-empty sptensor
543+
someTensor[[0, 1], 1] = 2
544+
assert someTensor[0, 1] == 2
545+
assert someTensor[1, 1] == 2
546+
assert np.all(someTensor[[0, 1], 1].vals == 2)
547+
536548
# Case I: Assign with non-scalar or sptensor
537549
sptensorInstanceLarger = ttb.sptensor.from_tensor_type(sptensorInstance)
538550
with pytest.raises(AssertionError) as excinfo:

0 commit comments

Comments
 (0)