Skip to content

Commit 5cf8e21

Browse files
committed
Modified predict to compute representations just once; added predict_score with same input as predict_rank, tests
1 parent c90bbf9 commit 5cf8e21

File tree

5 files changed

+2280
-566
lines changed

5 files changed

+2280
-566
lines changed

lightfm/_lightfm_fast.pyx.template

Lines changed: 100 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1139,51 +1139,129 @@ def fit_bpr(CSRMatrix item_features,
11391139
user_alpha)
11401140

11411141

1142-
def predict_lightfm(CSRMatrix item_features,
1143-
CSRMatrix user_features,
1144-
int[::1] user_ids,
1145-
int[::1] item_ids,
1146-
double[::1] predictions,
1147-
FastLightFM lightfm,
1148-
int num_threads):
1142+
cdef precompute_unique(CSRMatrix item_features,
                       CSRMatrix user_features,
                       int[::1] unique_users,
                       int[::1] unique_items,
                       flt *user_reprs,
                       flt *it_reprs,
                       FastLightFM lightfm,
                       int num_threads):
    """
    Precompute the representations of every user in unique_users and of
    every item in unique_items.

    Row i of user_reprs, i.e. the lightfm.no_components + 1 values that
    start at offset i * (lightfm.no_components + 1), receives the
    representation of unique_users[i]; it_reprs is filled the same way
    from unique_items.

    Both output buffers are owned by the caller. They must have room for
    unique_users.shape[0] and unique_items.shape[0] rows respectively;
    nothing is allocated or freed here.

    num_threads is not referenced directly in this body.
    NOTE(review): presumably it is consumed by the nogil_block template
    expansion - confirm against the template driver.
    """
    cdef int i
    cdef int no_users
    cdef int no_items
    cdef int stride

    no_users = unique_users.shape[0]
    no_items = unique_items.shape[0]
    # Length of one representation row: the components plus one extra slot.
    stride = lightfm.no_components + 1

    {nogil_block}
        # Every representation is written straight into its own row of the
        # output buffer. Rows are disjoint, so loop iterations never touch
        # the same memory, and no per-call scratch buffer has to be
        # allocated, copied from, or freed.
        # NOTE(review): this relies on compute_representation fully
        # initialising the stride values it is handed; the previous
        # scratch-buffer version depended on the same thing, since its
        # scratch memory came uninitialised from malloc.

        # users representations
        for i in {range_block}(no_users):
            compute_representation(user_features,
                                   lightfm.user_features,
                                   lightfm.user_biases,
                                   lightfm,
                                   unique_users[i],
                                   lightfm.user_scale,
                                   user_reprs + i * stride)

        # items representations
        for i in {range_block}(no_items):
            compute_representation(item_features,
                                   lightfm.item_features,
                                   lightfm.item_biases,
                                   lightfm,
                                   unique_items[i],
                                   lightfm.item_scale,
                                   it_reprs + i * stride)
1188+
1189+
1190+
def predict_lightfm(CSRMatrix item_features,
                    CSRMatrix user_features,
                    int[::1] user_ids,
                    int[::1] item_ids,
                    double[::1] predictions,
                    FastLightFM lightfm,
                    int num_threads,
                    bint precompute):
    """
    Generate predictions.

    For every i below predictions.shape[0], predictions[i] is set to the
    score of the pair (user_ids[i], item_ids[i]).

    When precompute is true, the representation of each distinct user id
    and each distinct item id is computed once up front and then looked
    up per example. Otherwise both representations are recomputed for
    every example.

    Raises MemoryError if the buffers holding the precomputed
    representations cannot be allocated.

    num_threads is not referenced directly in this body.
    NOTE(review): presumably it is consumed by the nogil_block template
    expansion - confirm against the template driver.
    """
    cdef int i, j, no_examples
    cdef int no_users, no_items, stride
    cdef Py_ssize_t user_row, item_row
    cdef flt *user_repr
    cdef flt *it_repr
    # Start as NULL so the cleanup below is valid on every path:
    # free(NULL) is a no-op.
    cdef flt *user_reprs = NULL
    cdef flt *it_reprs = NULL
    cdef int[::1] unique_users
    cdef int[::1] unique_items
    # np.unique returns its inverse indices as intp, which is Py_ssize_t
    # sized; C long is narrower than that on LLP64 platforms.
    cdef Py_ssize_t[::1] inverse_users
    cdef Py_ssize_t[::1] inverse_items

    no_examples = predictions.shape[0]
    # Length of one representation row: the components plus one extra slot.
    stride = lightfm.no_components + 1

    try:
        if precompute:
            unique_users, inverse_users = np.unique(user_ids, return_inverse=True)
            unique_items, inverse_items = np.unique(item_ids, return_inverse=True)
            no_users = unique_users.shape[0]
            no_items = unique_items.shape[0]

            user_reprs = <flt *>malloc(sizeof(flt) * no_users * stride)
            it_reprs = <flt *>malloc(sizeof(flt) * no_items * stride)
            # malloc may legitimately return NULL for a zero-byte request,
            # so only a NULL for a non-empty buffer is an error.
            if ((no_users > 0 and user_reprs == NULL)
                    or (no_items > 0 and it_reprs == NULL)):
                raise MemoryError()

            precompute_unique(item_features,
                              user_features,
                              unique_users,
                              unique_items,
                              user_reprs,
                              it_reprs,
                              lightfm,
                              num_threads)

        {nogil_block}
            # Scratch buffers are allocated and released inside the block,
            # next to the loop that uses them.
            user_repr = <flt *>malloc(sizeof(flt) * stride)
            it_repr = <flt *>malloc(sizeof(flt) * stride)

            for i in {range_block}(no_examples):
                if precompute:
                    # inverse_*[i] is the row of the precomputed table that
                    # belongs to example i. The copy is a plain serial loop:
                    # it only moves stride values per example.
                    user_row = inverse_users[i] * stride
                    item_row = inverse_items[i] * stride
                    for j in range(stride):
                        user_repr[j] = user_reprs[user_row + j]
                        it_repr[j] = it_reprs[item_row + j]
                else:
                    compute_representation(user_features,
                                           lightfm.user_features,
                                           lightfm.user_biases,
                                           lightfm,
                                           user_ids[i],
                                           lightfm.user_scale,
                                           user_repr)
                    compute_representation(item_features,
                                           lightfm.item_features,
                                           lightfm.item_biases,
                                           lightfm,
                                           item_ids[i],
                                           lightfm.item_scale,
                                           it_repr)

                predictions[i] = compute_prediction_from_repr(user_repr,
                                                              it_repr,
                                                              lightfm.no_components)

            free(user_repr)
            free(it_repr)
    finally:
        # The precomputed tables are shared by the whole loop, so they are
        # released exactly once, after the block above has finished, and
        # also when an exception interrupts the set-up.
        free(user_reprs)
        free(it_reprs)
11871265

11881266

11891267
def predict_ranks(CSRMatrix item_features,
@@ -1341,3 +1419,4 @@ def __test_in_positives(int row, int col, CSRMatrix mat):
13411419
return True
13421420
else:
13431421
return False
1422+

0 commit comments

Comments
 (0)