Skip to content
This repository was archived by the owner on Nov 22, 2022. It is now read-only.

Fix caffe2 predict #1103

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 8 additions & 2 deletions pytext/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,15 @@ def _predict(workspace_id, predict_net, model, tensorizers, input):
}
model_inputs = model.arrange_model_inputs(tensor_dict)
model_input_names = model.get_export_input_names(tensorizers)
vocab_to_export = model.vocab_to_export(tensorizers)
for blob_name, model_input in zip(model_input_names, model_inputs):
converted_blob_name = convert_caffe2_blob_name(blob_name)
workspace.blobs[converted_blob_name] = np.array([model_input], dtype=str)
converted_blob_name = blob_name
dtype = np.float32
if blob_name in vocab_to_export:
converted_blob_name = convert_caffe2_blob_name(blob_name)
dtype = str

workspace.blobs[converted_blob_name] = np.array([model_input], dtype=dtype)
workspace.RunNet(predict_net)
return {
str(blob): workspace.blobs[blob][0] for blob in predict_net.external_outputs
Expand Down
8 changes: 4 additions & 4 deletions tests/data/test_data_tiny.tsv
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
alarm/set_alarm 11:17:datetime reactivate weekly alarm
alarm/set_alarm 23:38:datetime Set alarm to ring only on the weekdays
alarm/time_left_on_alarm When will alarm go off
reminder/set_reminder 10:18:datetime,22:38:reminder/todo remind me tomorrow to call the groomer
alarm/set_alarm 11:17:datetime reactivate weekly alarm [1.0]
alarm/set_alarm 23:38:datetime Set alarm to ring only on the weekdays [1.0]
alarm/time_left_on_alarm When will alarm go off [1.0]
reminder/set_reminder 10:18:datetime,22:38:reminder/todo remind me tomorrow to call the groomer [1.0]
20 changes: 10 additions & 10 deletions tests/data/train_data_tiny.tsv
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
alarm/modify_alarm 16:24:datetime,39:57:datetime change my alarm tomorrow to wake me up 30 minutes earlier
alarm/set_alarm Turn on all my alarms
alarm/set_alarm 12:27:datetime sound alarm every 8 minutes
alarm/set_alarm 7:17:datetime repeat yesterdays alarm
alarm/snooze_alarm continue my alarm
alarm/time_left_on_alarm Do I have anymore time on the alarm
reminder/set_reminder 10:22:datetime,28:56:reminder/todo remind me Monday night that get doe with work 12 Tuesday
reminder/show_reminders 8:15:datetime,18:27:reminder/noun display Tuesday's reminders
weather/find 12:19:weather/noun,20:28:datetime what is the weather tomorrow
weather/find 13:17:weather/attribute When will it snow
alarm/modify_alarm 16:24:datetime,39:57:datetime change my alarm tomorrow to wake me up 30 minutes earlier [1.0]
alarm/set_alarm Turn on all my alarms [1.0]
alarm/set_alarm 12:27:datetime sound alarm every 8 minutes [1.0]
alarm/set_alarm 7:17:datetime repeat yesterdays alarm [1.0]
alarm/snooze_alarm continue my alarm [1.0]
alarm/time_left_on_alarm Do I have anymore time on the alarm [1.0]
reminder/set_reminder 10:22:datetime,28:56:reminder/todo remind me Monday night that get doe with work 12 Tuesday [1.0]
reminder/show_reminders 8:15:datetime,18:27:reminder/noun display Tuesday's reminders [1.0]
weather/find 12:19:weather/noun,20:28:datetime what is the weather tomorrow [1.0]
weather/find 13:17:weather/attribute When will it snow [1.0]
27 changes: 25 additions & 2 deletions tests/predictor_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,17 @@
import tempfile
import unittest

import numpy as np
from pytext import batch_predict_caffe2_model
from pytext.config import LATEST_VERSION, PyTextConfig
from pytext.data import Data
from pytext.data.sources import TSVDataSource
from pytext.data.tensorizers import (
FloatListTensorizer,
LabelTensorizer,
TokenTensorizer,
)
from pytext.models.doc_model import DocModel
from pytext.task import create_task
from pytext.task.serialize import save
from pytext.task.tasks import DocumentClassificationTask
Expand All @@ -25,14 +32,23 @@ def test_batch_predict_caffe2_model(self):
eval_data = tests_module.test_file("test_data_tiny.tsv")
config = PyTextConfig(
task=DocumentClassificationTask.Config(
model=DocModel.Config(
inputs=DocModel.Config.ModelInput(
tokens=TokenTensorizer.Config(),
dense=FloatListTensorizer.Config(
column="dense", dim=1, error_check=True
),
labels=LabelTensorizer.Config(),
)
),
data=Data.Config(
source=TSVDataSource.Config(
train_filename=train_data,
eval_filename=eval_data,
test_filename=eval_data,
field_names=["label", "slots", "text"],
field_names=["label", "slots", "text", "dense"],
)
)
),
),
version=LATEST_VERSION,
save_snapshot_path=snapshot_file.name,
Expand All @@ -47,3 +63,10 @@ def test_batch_predict_caffe2_model(self):
snapshot_file.name, caffe2_model_file.name
)
self.assertEqual(4, len(results))

pt_results = task.predict(task.data.data_source.test)

for pt_res, res in zip(pt_results, results):
np.testing.assert_array_almost_equal(
pt_res["score"].tolist()[0], [score[0] for score in res.values()]
)