@@ -1139,51 +1139,129 @@ def fit_bpr(CSRMatrix item_features,
1139
1139
user_alpha )
1140
1140
1141
1141
1142
- def predict_lightfm (CSRMatrix item_features ,
1143
- CSRMatrix user_features ,
1144
- int [::1 ] user_ids ,
1145
- int [::1 ] item_ids ,
1146
- double [::1 ] predictions ,
1147
- FastLightFM lightfm ,
1148
- int num_threads ):
1142
+ cdef precompute_unique (CSRMatrix item_features ,
1143
+ CSRMatrix user_features ,
1144
+ int [::1 ] unique_users ,
1145
+ int [::1 ] unique_items ,
1146
+ flt * user_reprs ,
1147
+ flt * it_reprs ,
1148
+ FastLightFM lightfm ,
1149
+ int num_threads ):
1149
1150
"""
1150
- Generate predictions.
1151
+ Precomputes the representations for all the users in unique_users and
1152
+ all the items in unique_items
1151
1153
"""
1152
-
1153
- cdef int i , no_examples
1154
- cdef flt * user_repr
1154
+ cdef int i , j
1155
1155
cdef flt * it_repr
1156
+ cdef flt * user_repr
1157
+ cdef int no_features
1158
+ cdef int no_users
1156
1159
1157
- no_examples = predictions .shape [0 ]
1158
-
1160
+ no_features = unique_items .shape [0 ]
1161
+ no_users = unique_users . shape [ 0 ]
1159
1162
{nogil_block }
1160
-
1161
1163
user_repr = < flt * > malloc (sizeof (flt ) * (lightfm .no_components + 1 ))
1162
1164
it_repr = < flt * > malloc (sizeof (flt ) * (lightfm .no_components + 1 ))
1163
-
1164
- for i in {range_block }(no_examples ):
1165
-
1165
+ # users representations
1166
+ for i in {range_block }(no_users ):
1166
1167
compute_representation (user_features ,
1167
1168
lightfm .user_features ,
1168
1169
lightfm .user_biases ,
1169
1170
lightfm ,
1170
- user_ids [i ],
1171
+ unique_users [i ],
1171
1172
lightfm .user_scale ,
1172
1173
user_repr )
1174
+ for j in {range_block }(lightfm .no_components + 1 ):
1175
+ user_reprs [i * (lightfm .no_components + 1 ) + j ] = user_repr [j ]
1176
+
1177
+ # items representations
1178
+ for i in {range_block }(no_features ):
1173
1179
compute_representation (item_features ,
1174
1180
lightfm .item_features ,
1175
1181
lightfm .item_biases ,
1176
1182
lightfm ,
1177
- item_ids [i ],
1183
+ unique_items [i ],
1178
1184
lightfm .item_scale ,
1179
1185
it_repr )
1186
+ for j in {range_block }(lightfm .no_components + 1 ):
1187
+ it_reprs [i * (lightfm .no_components + 1 ) + j ] = it_repr [j ]
1188
+
1189
+
1190
+ def predict_lightfm (CSRMatrix item_features ,
1191
+ CSRMatrix user_features ,
1192
+ int [::1 ] user_ids ,
1193
+ int [::1 ] item_ids ,
1194
+ double [::1 ] predictions ,
1195
+ FastLightFM lightfm ,
1196
+ int num_threads ,
1197
+ bint precompute ):
1198
+ """
1199
+ Generate predictions.
1200
+ """
1201
+ cdef int i , j , no_examples
1202
+ cdef flt * user_repr
1203
+ cdef flt * it_repr
1204
+ cdef flt * user_reprs
1205
+ cdef flt * it_reprs
1206
+ cdef int [::1 ] unique_users
1207
+ cdef int [::1 ] unique_items
1208
+ cdef long [::1 ] inverse_users
1209
+ cdef long [::1 ] inverse_items
1210
+ cdef int no_features
1211
+ cdef int no_users
1212
+
1213
+ no_examples = predictions .shape [0 ]
1214
+
1215
+ if precompute :
1216
+ unique_users , inverse_users = np .unique (user_ids , return_inverse = True )
1217
+ unique_items , inverse_items = np .unique (item_ids , return_inverse = True )
1218
+ no_features = unique_items .shape [0 ]
1219
+ no_users = unique_users .shape [0 ]
1220
+
1221
+ user_reprs = < flt * > malloc (sizeof (flt ) * no_users * (lightfm .no_components + 1 ))
1222
+ it_reprs = < flt * > malloc (sizeof (flt ) * no_features * (lightfm .no_components + 1 ))
1223
+ precompute_unique (item_features ,
1224
+ user_features ,
1225
+ unique_users ,
1226
+ unique_items ,
1227
+ user_reprs ,
1228
+ it_reprs ,
1229
+ lightfm ,
1230
+ num_threads )
1231
+
1232
+ {nogil_block }
1233
+ user_repr = < flt * > malloc (sizeof (flt ) * (lightfm .no_components + 1 ))
1234
+ it_repr = < flt * > malloc (sizeof (flt ) * (lightfm .no_components + 1 ))
1235
+ for i in {range_block }(no_examples ):
1236
+ if precompute :
1237
+ for j in {range_block }(lightfm .no_components + 1 ):
1238
+ user_repr [j ] = user_reprs [inverse_users [i ] * (lightfm .no_components + 1 ) + j ]
1239
+ it_repr [j ] = it_reprs [inverse_items [i ] * (lightfm .no_components + 1 ) + j ]
1240
+ else :
1241
+ compute_representation (user_features ,
1242
+ lightfm .user_features ,
1243
+ lightfm .user_biases ,
1244
+ lightfm ,
1245
+ user_ids [i ],
1246
+ lightfm .user_scale ,
1247
+ user_repr )
1248
+ compute_representation (item_features ,
1249
+ lightfm .item_features ,
1250
+ lightfm .item_biases ,
1251
+ lightfm ,
1252
+ item_ids [i ],
1253
+ lightfm .item_scale ,
1254
+ it_repr )
1180
1255
1181
1256
predictions [i ] = compute_prediction_from_repr (user_repr ,
1182
- it_repr ,
1183
- lightfm .no_components )
1257
+ it_repr ,
1258
+ lightfm .no_components )
1184
1259
1185
1260
free (user_repr )
1186
1261
free (it_repr )
1262
+ if precompute :
1263
+ free (user_reprs )
1264
+ free (it_reprs )
1187
1265
1188
1266
1189
1267
def predict_ranks (CSRMatrix item_features ,
@@ -1341,3 +1419,4 @@ def __test_in_positives(int row, int col, CSRMatrix mat):
1341
1419
return True
1342
1420
else :
1343
1421
return False
1422
+
0 commit comments