Skip to content

Commit 8e1f4d0

Browse files
committed
SPTENSOR:
* Fix argument mismatch for ttm (modes s.b. dims) * Fix ttm for rectangular matrices * Make error message consitent with tensor TENSOR: * Fix error message
1 parent d3577f2 commit 8e1f4d0

File tree

4 files changed

+33
-25
lines changed

4 files changed

+33
-25
lines changed

pyttb/sptensor.py

Lines changed: 18 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -2173,15 +2173,14 @@ def __repr__(self): # pragma: no cover
21732173

21742174
__str__ = __repr__
21752175

2176-
def ttm(self, matrices, mode, dims=None, transpose=False):
2176+
def ttm(self, matrices, dims=None, transpose=False):
21772177
"""
21782178
Sparse tensor times matrix.
21792179
21802180
Parameters
21812181
----------
21822182
matrices: A matrix or list of matrices
2183-
mode:
2184-
dims:
2183+
dims: :class:`Numpy.ndarray`, int
21852184
transpose: Transpose matrices to be multiplied
21862185
21872186
Returns
@@ -2190,10 +2189,15 @@ def ttm(self, matrices, mode, dims=None, transpose=False):
21902189
"""
21912190
if dims is None:
21922191
dims = np.arange(self.ndims)
2192+
elif isinstance(dims, list):
2193+
dims = np.array(dims)
2194+
elif np.isscalar(dims) or isinstance(dims, list):
2195+
dims = np.array([dims])
2196+
21932197
# Handle list of matrices
21942198
if isinstance(matrices, list):
21952199
# Check dimensions are valid
2196-
[dims, vidx] = tt_dimscheck(mode, self.ndims, len(matrices))
2200+
[dims, vidx] = tt_dimscheck(dims, self.ndims, len(matrices))
21972201
# Calculate individual products
21982202
Y = self.ttm(matrices[vidx[0]], dims[0], transpose=transpose)
21992203
for i in range(1, dims.size):
@@ -2208,33 +2212,34 @@ def ttm(self, matrices, mode, dims=None, transpose=False):
22082212
if transpose:
22092213
matrices = matrices.transpose()
22102214

2211-
# Check mode
2212-
if not np.isscalar(mode) or mode < 0 or mode > self.ndims-1:
2213-
assert False, "Mode must be in [0, ndims)"
2215+
# Ensure this is the terminal single dimension case
2216+
if not (dims.size == 1 and np.isin(dims, np.arange(self.ndims))):
2217+
assert False, "dims must contain values in [0,self.dims)"
2218+
dims = dims[0]
22142219

22152220
# Compute the product
22162221

22172222
# Check that sizes match
2218-
if self.shape[mode] != matrices.shape[1]:
2223+
if self.shape[dims] != matrices.shape[1]:
22192224
assert False, "Matrix shape doesn't match tensor shape"
22202225

22212226
# Compute the new size
22222227
siz = np.array(self.shape)
2223-
siz[mode] = matrices.shape[0]
2228+
siz[dims] = matrices.shape[0]
22242229

22252230
# Compute self[mode]'
2226-
Xnt = ttb.tt_to_sparse_matrix(self, mode, True)
2231+
Xnt = ttb.tt_to_sparse_matrix(self, dims, True)
22272232

22282233
# Reshape puts the reshaped things after the unchanged modes, transpose then puts it in front
22292234
idx = 0
22302235

22312236
# Convert to sparse matrix and do multiplication; generally result is sparse
22322237
Z = Xnt.dot(matrices.transpose())
22332238

2234-
# Rearrange back into sparse tensor of original shape
2235-
Ynt = ttb.tt_from_sparse_matrix(Z, self.shape, mode, idx)
2239+
# Rearrange back into sparse tensor of correct shape
2240+
Ynt = ttb.tt_from_sparse_matrix(Z, siz, dims, idx)
22362241

2237-
if Z.nnz <= 0.5 * np.prod(siz):
2242+
if not isinstance(Z, np.ndarray) and Z.nnz <= 0.5 * np.prod(siz):
22382243
return Ynt
22392244
else:
22402245
# TODO evaluate performance loss by casting into sptensor then tensor. I assume minimal since we are already

pyttb/tensor.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -921,7 +921,7 @@ def ttm(self, matrix, dims=None, transpose=False):
921921
assert False, "matrix must be of type numpy.ndarray"
922922

923923
if not (dims.size == 1 and np.isin(dims, np.arange(self.ndims))):
924-
assert False, "dims must contain values in [0,self.dims]"
924+
assert False, "dims must contain values in [0,self.dims)"
925925

926926
# old version (ver=0)
927927
shape = np.array(self.shape)

tests/test_sptensor.py

Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1365,25 +1365,25 @@ def test_sptensor_ttm(sample_sptensor):
13651365
result[:, 3, 3] = 3.5
13661366
result = ttb.tensor.from_data(result)
13671367
result = ttb.sptensor.from_tensor_type(result)
1368-
assert sptensorInstance.ttm(sparse.coo_matrix(np.ones((4, 4))), mode=0).isequal(result)
1369-
assert sptensorInstance.ttm(sparse.coo_matrix(np.ones((4, 4))), mode=0, transpose=True).isequal(result)
1368+
assert sptensorInstance.ttm(sparse.coo_matrix(np.ones((4, 4))), dims=0).isequal(result)
1369+
assert sptensorInstance.ttm(sparse.coo_matrix(np.ones((4, 4))), dims=0, transpose=True).isequal(result)
13701370

13711371
# This is a multiway multiplication yielding a sparse tensor, yielding a dense tensor relies on tensor.ttm
13721372
matrix = sparse.coo_matrix(np.eye(4))
13731373
list_of_matrices = [matrix, matrix, matrix]
1374-
assert sptensorInstance.ttm(list_of_matrices, mode=np.array([0, 1, 2])).isequal(sptensorInstance)
1374+
assert sptensorInstance.ttm(list_of_matrices, dims=np.array([0, 1, 2])).isequal(sptensorInstance)
13751375

13761376
with pytest.raises(AssertionError) as excinfo:
1377-
sptensorInstance.ttm(sparse.coo_matrix(np.ones((5, 5))), mode=0)
1377+
sptensorInstance.ttm(sparse.coo_matrix(np.ones((5, 5))), dims=0)
13781378
assert "Matrix shape doesn't match tensor shape" in str(excinfo)
13791379

13801380
with pytest.raises(AssertionError) as excinfo:
1381-
sptensorInstance.ttm(np.array([1, 2, 3, 4]), mode=0)
1381+
sptensorInstance.ttm(np.array([1, 2, 3, 4]), dims=0)
13821382
assert "Sptensor.ttm: second argument must be a matrix" in str(excinfo)
13831383

13841384
with pytest.raises(AssertionError) as excinfo:
1385-
sptensorInstance.ttm(sparse.coo_matrix(np.ones((4, 4))), mode=4)
1386-
assert "Mode must be in [0, ndims)" in str(excinfo)
1385+
sptensorInstance.ttm(sparse.coo_matrix(np.ones((4, 4))), dims=4)
1386+
assert "dims must contain values in [0,self.dims)" in str(excinfo)
13871387

13881388
sptensorInstance[0, :, :] = 1
13891389
sptensorInstance[3, :, :] = 1
@@ -1397,17 +1397,20 @@ def test_sptensor_ttm(sample_sptensor):
13971397
# TODO: Ensure mode mappings are consistent between matlab and numpy
13981398
# MATLAB is opposite orientation so the mapping from matlab to numpy is
13991399
# {3:0, 2:2, 1:1}
1400-
assert (sptensorInstance.ttm(sparse.coo_matrix(np.ones((4, 4))), mode=1).isequal(ttb.tensor.from_data(result)))
1400+
assert (sptensorInstance.ttm(sparse.coo_matrix(np.ones((4, 4))), dims=1).isequal(ttb.tensor.from_data(result)))
14011401

14021402
result = 2*np.ones((4, 4, 4))
14031403
result[:, 1, 1] = 2.5
14041404
result[:, 1, 3] = 3.5
14051405
result[:, 2, 2] = 4.5
1406-
assert (sptensorInstance.ttm(sparse.coo_matrix(np.ones((4, 4))), mode=0).isequal(ttb.tensor.from_data(result)))
1406+
assert (sptensorInstance.ttm(sparse.coo_matrix(np.ones((4, 4))), dims=0).isequal(ttb.tensor.from_data(result)))
14071407

14081408
result = np.zeros((4, 4, 4))
14091409
result[0, :, :] = 4.0
14101410
result[3, :, :] = 4.0
14111411
result[1, 1, :] = 2
14121412
result[2, 2, :] = 2.5
1413-
assert (sptensorInstance.ttm(sparse.coo_matrix(np.ones((4, 4))), mode=2).isequal(ttb.tensor.from_data(result)))
1413+
assert (sptensorInstance.ttm(sparse.coo_matrix(np.ones((4, 4))), dims=2).isequal(ttb.tensor.from_data(result)))
1414+
1415+
# Confirm reshape for non-square matrix
1416+
assert sptensorInstance.ttm(sparse.coo_matrix(np.ones((1, 4))), dims=2).shape == (4,4,1)

tests/test_tensor.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1057,7 +1057,7 @@ def test_tensor_ttm(sample_tensor_2way, sample_tensor_3way, sample_tensor_4way):
10571057
# 3-way, dims must be in range [0,self.ndims]
10581058
with pytest.raises(AssertionError) as excinfo:
10591059
tensorInstance3.ttm(M2, tensorInstance3.ndims + 1)
1060-
assert "dims must contain values in [0,self.dims]" in str(excinfo)
1060+
assert "dims must contain values in [0,self.dims)" in str(excinfo)
10611061

10621062
@pytest.mark.indevelopment
10631063
def test_tensor_ttt(sample_tensor_2way, sample_tensor_3way, sample_tensor_4way):

0 commit comments

Comments
 (0)