Skip to content
This repository was archived by the owner on Nov 22, 2022. It is now read-only.

Add predict function for NewTask #936

Closed
wants to merge 1 commit into from
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 22 additions & 1 deletion pytext/task/new_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from pytext.common.constants import Stage
from pytext.config import ConfigBase, PyTextConfig
from pytext.config.component import ComponentType, create_component, create_trainer
from pytext.data.data import Data
from pytext.data.data import Data, pad_and_tensorize_batches
from pytext.data.sources.data_source import Schema
from pytext.data.tensorizers import Tensorizer
from pytext.metric_reporters import MetricReporter
Expand Down Expand Up @@ -206,6 +206,27 @@ def test(self, data_source):
self.metric_reporter,
)

def predict(self, examples):
"""
Generates predictions using PyTorch model. The difference with `test()` is
that this should be used when the the examples do not have any true
label/target.

Args:
examples: json format examples, input names should match the names specified
in this task's features config
"""
results = []
for row in examples:
self.model.eval()
numberized_rows = self.data.numberize_rows([row])
batches = self.data.batcher.batchify(numberized_rows)
_, inputs = next(pad_and_tensorize_batches(self.data.tensorizers, batches))
model_inputs = self.model.arrange_model_inputs(inputs)
predictions, scores = self.model.get_pred(self.model(*model_inputs))
results.append({"prediction": predictions, "score": scores})
return results

def export(self, model, export_path, metric_channels=None, export_onnx_path=None):
# Make sure to put the model on CPU and disable CUDA before exporting to
# ONNX to disable any data_parallel pieces
Expand Down