Skip to content

Commit fbbf554

Browse files
committed
Merge branch 'main' of github.com:sandialabs/pyttb
2 parents bf26ef6 + 598b5bf commit fbbf554

File tree

8 files changed

+264
-62
lines changed

8 files changed

+264
-62
lines changed

pyttb/ktensor.py

Lines changed: 137 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1360,9 +1360,143 @@ def shape(self):
13601360
"""
13611361
return tuple([f.shape[0] for f in self.factor_matrices])
13621362

1363-
# TODO implement
1364-
def score(self, other, **kwargs):
1365-
assert False, "Not yet implemented" # pragma: no cover
1363+
def score(self, other, weight_penalty=True, threshold=0.99, greedy=True):
1364+
"""
1365+
Checks if two ktensor instances match except for permutation.
1366+
1367+
We define matching as follows. If A (self) and B (other) are single component
1368+
ktensors that have been normalized so that their weights are weights_a and
1369+
weights_b, then the score is defined as
1370+
1371+
score = penalty * (a1.T*b1) * (a2.T*b2) * ... * (aR.T*bR),
1372+
1373+
where the penalty is defined by the weights such that
1374+
1375+
penalty = 1 - abs(weights_a - weights_b) / max(weights_a, weights_b).
1376+
1377+
The score of multi-component ktensors is a normalized sum of the
1378+
scores across the best permutation of the components of A. A can have
1379+
more components than B --- any extra components are ignored in terms of
1380+
the matching score.
1381+
1382+
Parameters
1383+
----------
1384+
other: :class:`pyttb.ktensor`
1385+
`ktensor` to match against
1386+
weight_penalty: bool
1387+
Flag indicating whether or not to consider the weights in the calculations.
1388+
Default: true
1389+
threshold: float
1390+
Threshold specified in the formula above for determining a match.
1391+
Default: 0.99
1392+
greedy: bool
1393+
Flag indicating whether or not to consider all possible matchings
1394+
(exponentially expensive) or just do a greedy matching. Default: true
1395+
1396+
Returns
1397+
-------
1398+
int
1399+
Score (between 0 and 1)
1400+
:class:`pyttb.ktensor`
1401+
Copy of `self`, which has been normalized and permuted to best match `other`
1402+
bool
1403+
Flag indicating a match according to a user-specified threshold
1404+
:class:`Numpy.ndarray`
1405+
Permutation (i.e. array of indices of the modes of self) of the components
1406+
of self that was used to best match other
1407+
1408+
Example
1409+
-------
1410+
Create two `ktensor` instances:
1411+
1412+
>>> A = ttb.ktensor.from_data(np.array([2, 1, 3]), np.ones((3,3)), np.ones((4,3)), np.ones((5,3)))
1413+
>>> B = ttb.ktensor.from_data(np.array([2, 4]), np.ones((3,2)), np.ones((4,2)), np.ones((5,2)))
1414+
1415+
Compute `score` using `ktensor.weights`:
1416+
1417+
>>> score,Aperm,flag,perm = A.score(B)
1418+
>>> print(score)
1419+
0.875
1420+
>>> print(perm)
1421+
[0 2 1]
1422+
1423+
Compute `score` not using `ktensor.weights`:
1424+
1425+
>>> score,Aperm,flag,perm = A.score(B,weight_penalty=False)
1426+
>>> print(score)
1427+
1.0
1428+
>>> print(perm)
1429+
[0 1 2]
1430+
"""
1431+
1432+
if not greedy:
1433+
assert False, "Not yet implemented. Only greedy method is implemented currently."
1434+
1435+
if not isinstance(other, ktensor):
1436+
assert False, "The first input should be a ktensor"
1437+
1438+
if not (self.shape == other.shape):
1439+
assert False, "Size mismatch"
1440+
1441+
# Set-up
1442+
N = self.ndims
1443+
RA = self.ncomponents
1444+
RB = other.ncomponents
1445+
1446+
# We're matching components in A to B
1447+
if (RA < RB):
1448+
assert False, "Tensor A must have at least as many components as tensor B"
1449+
1450+
# Make sure columns of factor matrices are normalized
1451+
A = ttb.ktensor.from_tensor_type(self).normalize()
1452+
B = ttb.ktensor.from_tensor_type(other).normalize()
1453+
1454+
# Compute all possible vector-vector congruences.
1455+
1456+
# Compute every pair for each mode
1457+
Cbig = ttb.tensor.from_function(np.zeros, (RA,RB,N))
1458+
for n in range(N):
1459+
Cbig[:,:,n] = np.abs(A.factor_matrices[n].T @ B.factor_matrices[n])
1460+
1461+
# Collapse across all modes using the product
1462+
C = Cbig.collapse(np.array([2]), np.prod).double()
1463+
1464+
# Calculate penalty based on differences in the Lambda's
1465+
# Note that we are assuming the the lambda value are positive because the
1466+
# ktensor's were previously normalized.
1467+
if weight_penalty:
1468+
P = np.zeros((RA, RB))
1469+
for ra in range(RA):
1470+
la = A.weights[ra]
1471+
for rb in range(RB):
1472+
lb = B.weights[rb]
1473+
if (la == 0) and (lb == 0):
1474+
# if both lambda values are zero (0), they match
1475+
P[ra, rb] = 1
1476+
else:
1477+
P[ra, rb] = 1 - (np.abs(la-lb) / np.max([np.abs(la),np.abs(lb)]))
1478+
C = P * C
1479+
1480+
# Option to do greedy matching
1481+
if greedy:
1482+
best_perm = -1 * np.ones((RA), dtype=np.int)
1483+
best_score = 0
1484+
for r in range(RB):
1485+
idx = np.argmax(C.reshape(np.prod(C.shape),order='F'))
1486+
ij = tt_ind2sub((RA, RB), idx)
1487+
best_score = best_score + C[ij[0], ij[1]]
1488+
C[ij[0], :] = -10
1489+
C[:, ij[1]] = -10
1490+
best_perm[ij[1]] = ij[0]
1491+
best_score = best_score / RB
1492+
flag = 1
1493+
1494+
# Rearrange the components of A according to the best matching
1495+
foo = np.arange(RA)
1496+
tf = np.in1d(foo, best_perm)
1497+
best_perm[RB:RA+1] = foo[~tf]
1498+
A.arrange(permutation=best_perm)
1499+
return best_score, A, flag, best_perm
13661500

13671501
def symmetrize(self):
13681502
"""

pyttb/pyttb_utils.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -514,9 +514,9 @@ def tt_ind2sub(shape, idx):
514514
:class:`numpy.ndarray`
515515
"""
516516
if idx.size == 0:
517-
return np.array([])
517+
return np.empty(shape=(0,len(shape)), dtype=int)
518518

519-
return np.array(np.unravel_index(idx, shape)).transpose()
519+
return np.array(np.unravel_index(idx, shape, order='F')).transpose()
520520

521521

522522
def tt_subsubsref(obj, s):
@@ -575,7 +575,7 @@ def tt_sub2ind(shape, subs):
575575
"""
576576
if subs.size == 0:
577577
return np.array([])
578-
idx = np.ravel_multi_index(tuple(subs.transpose()), shape)
578+
idx = np.ravel_multi_index(tuple(subs.transpose()), shape, order='F')
579579
return idx
580580

581581

pyttb/sptensor.py

Lines changed: 18 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -426,7 +426,7 @@ def extract(self, searchsubs):
426426
assert False, 'Invalid subscripts'
427427

428428
# Set the default answer to zero
429-
a = np.zeros(shape=(p, 1))
429+
a = np.zeros(shape=(p, 1), dtype=self.vals.dtype)
430430

431431
# Find which indices already exist and their locations
432432
loc = ttb.tt_ismember_rows(searchsubs, self.subs)
@@ -1112,18 +1112,20 @@ def __getitem__(self, item):
11121112
11131113
Examples
11141114
--------
1115-
>>> X = sptensor(np.array([[4,4,4],[2,2,1],[2,3,2]]),np.array([[3],[5],[1]]),(4,4,4))
1116-
>>> X[1,2,1] #<-- returns zero
1117-
>>> X[4,4,4] #<-- returns 3
1118-
>>> X[3:4,:,:] #<-- returns 1 x 4 x 4 sptensor
1115+
>>> X = sptensor(np.array([[3,3,3],[1,1,0],[1,2,1]]),np.array([3,5,1]),(4,4,4))
1116+
>>> X[0,1,0] #<-- returns zero
1117+
>>> X[3,3,3] #<-- returns 3
1118+
>>> X[2:3,:,:] #<-- returns 1 x 4 x 4 sptensor
11191119
X = sptensor([6;16;26],[1;1;1],30);
11201120
X([1:6]') <-- extracts a subtensor
1121-
X([1:6]','extract') %<-- extracts a vector of 6 elements
11221121
"""
1122+
# This does not work like MATLAB TTB; you must call sptensor.extract to get this functionality
1123+
# X([1:6]','extract') %<-- extracts a vector of 6 elements
1124+
11231125
#TODO IndexError for value outside of indices
11241126
# TODO Key error if item not in container
11251127
# *** CASE 1: Rectangular Subtensor ***
1126-
if isinstance(item, tuple) and len(item) == self.ndims and item[len(item)-1] != 'extract':
1128+
if isinstance(item, tuple) and len(item) == self.ndims:
11271129
# Extract the subdimensions to be extracted from self
11281130
region = item
11291131

@@ -1160,7 +1162,7 @@ def __getitem__(self, item):
11601162
# Return a single double value for a zero-order sub-tensor
11611163
if newsiz.size == 0:
11621164
if vals.size == 0:
1163-
a = 0
1165+
a = np.array([[0]])
11641166
else:
11651167
a = vals
11661168
return a
@@ -1177,21 +1179,22 @@ def __getitem__(self, item):
11771179
# Case 2: EXTRACT
11781180

11791181
# *** CASE 2a: Subscript indexing ***
1180-
if len(item) > 1 and isinstance(item[-1], str) and item[-1] == 'extract':
1181-
# extract array of subscripts
1182-
srchsubs = np.array(item[0])
1183-
item = item[0]
1182+
if isinstance(item, np.ndarray) and len(item.shape) == 2 and item.shape[1] == self.ndims:
1183+
srchsubs = np.array(item)
11841184

11851185
# *** CASE 2b: Linear indexing ***
11861186
else:
11871187
# Error checking
1188-
if not isinstance(item, list) and not isinstance(item, np.ndarray):
1188+
if isinstance(item, list):
1189+
idx = np.array(item)
1190+
elif isinstance(item, np.ndarray):
1191+
idx = item
1192+
else:
11891193
assert False, 'Invalid indexing'
11901194

1191-
idx = item
11921195
if len(idx.shape) != 1:
11931196
assert False, 'Expecting a row index'
1194-
#idx=np.expand_dims(idx, axis=1)
1197+
11951198
# extract linear indices and convert to subscripts
11961199
srchsubs = tt_ind2sub(self.shape, idx)
11971200

pyttb/tensor.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -298,10 +298,10 @@ def find(self):
298298
299299
:return:
300300
"""
301-
idx = np.where(self.data > 0)
302-
subs = np.array(idx).transpose()
303-
vals = self.data[idx]
304-
return subs, vals[:, None]
301+
idx = np.nonzero(np.ravel(self.data,order='F'))[0]
302+
subs = ttb.tt_ind2sub(self.shape,idx)
303+
vals = self.data[tuple(subs.T)][:,None]
304+
return subs, vals
305305

306306
def full(self):
307307
"""
@@ -1623,7 +1623,7 @@ def __repr__(self):
16231623
s += str(self.data)
16241624
s += '\n'
16251625
return s
1626-
for i, j in enumerate(range(0, np.prod(self.shape), self.shape[-1]*self.shape[-2])):
1626+
for i in np.arange(np.prod(self.shape[:-2])):
16271627
s += 'data'
16281628
if self.ndims == 2:
16291629
s += '[:, :]'

tests/test_ktensor.py

Lines changed: 38 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -375,7 +375,7 @@ def test_ktensor_issymetric(sample_ktensor_2way, sample_ktensor_symmetric):
375375
def test_ktensor_mask(sample_ktensor_2way):
376376
(data, K) = sample_ktensor_2way
377377
W = ttb.tensor.from_data(np.array([[0, 1], [1, 0]]))
378-
assert (K.mask(W) == np.array([[39], [63]])).all()
378+
assert (K.mask(W) == np.array([[63], [39]])).all()
379379

380380
# Mask too large
381381
with pytest.raises(AssertionError) as excinfo:
@@ -614,7 +614,43 @@ def test_ktensor_redistribute(sample_ktensor_2way):
614614
assert (np.array([[5, 6], [7, 8]]) == K[1]).all()
615615
assert (np.array([1, 1]) == K.weights).all()
616616

617-
@pytest.mark.indevelopment
617+
pytest.mark.indevelopment
618+
def test_ktensor_score():
619+
A = ttb.ktensor.from_data(np.array([2, 1, 3]), np.ones((3,3)), np.ones((4,3)), np.ones((5,3)))
620+
B = ttb.ktensor.from_data(np.array([2, 4]), np.ones((3,2)), np.ones((4,2)), np.ones((5,2)))
621+
622+
# defaults
623+
score, Aperm, flag, best_perm = A.score(B)
624+
assert score == 0.875
625+
assert np.allclose(Aperm.weights, np.array([15.49193338,23.23790008,7.74596669]))
626+
assert flag == 1
627+
assert (best_perm == np.array([0,2,1])).all()
628+
629+
# compare just factor matrices (i.e., do not use weights)
630+
score, Aperm, flag, best_perm = A.score(B, weight_penalty=False)
631+
assert score == 1.0
632+
assert np.allclose(Aperm.weights, np.array([15.49193338,7.74596669,23.23790008]))
633+
assert flag == 1
634+
assert (best_perm == np.array([0,1,2])).all()
635+
636+
# compute score using exhaustive search
637+
with pytest.raises(AssertionError) as excinfo:
638+
score, Aperm, flag, best_perm = A.score(B, greedy=False)
639+
assert "Not yet implemented. Only greedy method is implemented currently." in str(excinfo)
640+
641+
# try to compute score with tensor type other than ktensor
642+
with pytest.raises(AssertionError) as excinfo:
643+
score, Aperm, flag, best_perm = A.score(ttb.tensor.from_tensor_type(B))
644+
assert "The first input should be a ktensor" in str(excinfo)
645+
646+
# try to compute score when ktensor dimensions do not match
647+
with pytest.raises(AssertionError) as excinfo:
648+
# A is 3x4x5; B is 3x4x4
649+
B = ttb.ktensor.from_data(np.array([2, 4]), np.ones((3,2)), np.ones((4,2)), np.ones((4,2)))
650+
score, Aperm, flag, best_perm = A.score(B)
651+
assert "Size mismatch" in str(excinfo)
652+
653+
pytest.mark.indevelopment
618654
def test_ktensor_shape(sample_ktensor_2way, sample_ktensor_3way):
619655
(data, K0) = sample_ktensor_2way
620656
assert K0.shape == (2, 2)

tests/test_pyttb_utils.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -44,8 +44,8 @@ def test_sptensor_to_sparse_matrix():
4444
subs = np.array([[1, 1, 1], [1, 1, 3], [2, 2, 2], [3, 3, 3]])
4545
vals = np.array([[0.5], [1.5], [2.5], [3.5]])
4646
shape = (4, 4, 4)
47-
mode0 = sparse.coo_matrix(([0.5, 1.5, 2.5, 3.5], ([5, 7, 10, 15], [1, 1, 2, 3])))
48-
mode1 = sparse.coo_matrix(([0.5, 1.5, 2.5, 3.5], ([5, 7, 10, 15], [1, 1, 2, 3])))
47+
mode0 = sparse.coo_matrix(([0.5, 1.5, 2.5, 3.5], ([5, 13, 10, 15], [1, 1, 2, 3])))
48+
mode1 = sparse.coo_matrix(([0.5, 1.5, 2.5, 3.5], ([5, 13, 10, 15], [1, 1, 2, 3])))
4949
mode2 = sparse.coo_matrix(([0.5, 1.5, 2.5, 3.5], ([5, 5, 10, 15], [1, 3, 2, 3])))
5050
Ynt = [mode0, mode1, mode2]
5151
sptensorInstance = ttb.sptensor().from_data(subs, vals, shape)
@@ -330,15 +330,17 @@ def test_tt_ind2sub_valid():
330330
subs = np.array([[0, 0, 0], [1, 1, 1], [3, 3, 3]])
331331
idx = np.array([0, 21, 63])
332332
shape = (4, 4, 4)
333+
print(f'\nttb.tt_ind2sub(shape, idx): {ttb.tt_ind2sub(shape, idx)}')
333334
assert (ttb.tt_ind2sub(shape, idx) == subs).all()
334335

335-
subs = np.array([[0, 1], [1, 0]])
336+
subs = np.array([[1, 0], [0, 1]])
336337
idx = np.array([1, 2])
337338
shape = (2, 2)
339+
print(f'\nttb.tt_ind2sub(shape, idx): {ttb.tt_ind2sub(shape, idx)}')
338340
assert (ttb.tt_ind2sub(shape, idx) == subs).all()
339341

340342
empty = np.array([])
341-
assert (ttb.tt_ind2sub(shape, empty) == empty).all()
343+
assert (ttb.tt_ind2sub(shape, empty) == np.empty(shape=(0,len(shape)), dtype=int)).all()
342344

343345
@pytest.mark.indevelopment
344346
def test_tt_subsubsref_valid():

0 commit comments

Comments
 (0)