Skip to content
This repository was archived by the owner on Nov 22, 2022. It is now read-only.

Fix OSS predict-py API #1320

Closed
wants to merge 7 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 38 additions & 28 deletions pytext/task/new_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from pytext.common.constants import Stage
from pytext.config import ConfigBase, PyTextConfig
from pytext.config.component import ComponentType, create_component, create_trainer
from pytext.data.data import Data, pad_and_tensorize_batches
from pytext.data.data import Data
from pytext.data.sources.data_source import Schema
from pytext.data.tensorizers import Tensorizer
from pytext.metric_reporters import MetricReporter
Expand All @@ -21,7 +21,7 @@


def create_schema(
tensorizers: Dict[str, Tensorizer], extra_schema: Optional[Dict[str, Type]] = None
tensorizers: Dict[str, Tensorizer], extra_schema: Optional[Dict[str, Type]] = None
) -> Schema:
schema: Dict[str, Type] = {}

Expand Down Expand Up @@ -49,7 +49,7 @@ def add_to_schema(name, type):


def create_tensorizers(
model_inputs: Union[BaseModel.Config.ModelInput, Dict[str, Tensorizer.Config]],
model_inputs: Union[BaseModel.Config.ModelInput, Dict[str, Tensorizer.Config]],
) -> Dict[str, Tensorizer]:
if not isinstance(model_inputs, dict):
model_inputs = model_inputs._asdict()
Expand Down Expand Up @@ -98,13 +98,13 @@ class Config(ConfigBase):

@classmethod
def from_config(
cls,
config: Config,
unused_metadata=None,
model_state=None,
tensorizers=None,
rank=0,
world_size=1,
cls,
config: Config,
unused_metadata=None,
model_state=None,
tensorizers=None,
rank=0,
world_size=1,
):
print(f"Creating task: {cls.__name__}...")
tensorizers, data = cls._init_tensorizers(config, tensorizers, rank, world_size)
Expand Down Expand Up @@ -132,9 +132,9 @@ def _init_tensorizers(cls, config: Config, tensorizers=None, rank=0, world_size=
# Pull extra columns from the metric reporter config to pass into
# the data source schema.
extra_columns = (
getattr(config.metric_reporter, "text_column_names", [])
+ getattr(config.metric_reporter, "additional_column_names", [])
+ getattr(config.metric_reporter, "student_column_names", [])
getattr(config.metric_reporter, "text_column_names", [])
+ getattr(config.metric_reporter, "additional_column_names", [])
+ getattr(config.metric_reporter, "student_column_names", [])
)
extra_schema = {column: Any for column in extra_columns}

Expand Down Expand Up @@ -173,11 +173,11 @@ def _init_model(cls, model_config, tensorizers, model_state=None):
return model

def __init__(
self,
data: Data,
model: BaseModel,
metric_reporter: Optional[MetricReporter] = None,
trainer: Optional[TaskTrainer] = None,
self,
data: Data,
model: BaseModel,
metric_reporter: Optional[MetricReporter] = None,
trainer: Optional[TaskTrainer] = None,
):
self.data = data
self.model = model
Expand All @@ -186,14 +186,17 @@ def __init__(
self.Config.metric_reporter, model
)
self.trainer = trainer or TaskTrainer()

self.input_tensorizers = [t for name, t in self.data.tensorizers.items() if t.is_input]

log_class_usage

def train(
self,
config: PyTextConfig,
rank: int = 0,
world_size: int = 1,
training_state: TrainingState = None,
self,
config: PyTextConfig,
rank: int = 0,
world_size: int = 1,
training_state: TrainingState = None,
):
# next to move dist_init back to prepare_task in pytext/workflow.py
# when processing time between dist_init and first loss.backward() is short
Expand Down Expand Up @@ -234,12 +237,19 @@ def predict(self, examples):
"""
self.model.eval()
results = []
input_tensorizers = {name: tensorizer for name, tensorizer in self.data.tensorizers.items()
if tensorizer.is_input}
for row in examples:
numberized_rows = self.data.numberize_rows([row])
batches = self.data.batcher.batchify(numberized_rows)
_, inputs = next(pad_and_tensorize_batches(self.data.tensorizers, batches))
model_inputs = self.model.arrange_model_inputs(inputs)
model_context = self.model.arrange_model_context(inputs)
numberized_row = {
name: [tensorizer.numberize(row)]
for name, tensorizer in input_tensorizers.items()
}
tensor_dict = {
name: tensorizer.tensorize(batch=numberized_row[name])
for name, tensorizer in input_tensorizers.items()
}
model_inputs = self.model.arrange_model_inputs(tensor_dict)
model_context = self.model.arrange_model_context(tensor_dict)
predictions, scores = self.model.get_pred(
self.model(*model_inputs), context=model_context
)
Expand Down
25 changes: 16 additions & 9 deletions tests/main_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,19 +16,26 @@
class TestMain(unittest.TestCase):
def setUp(self):
os.chdir(PYTEXT_HOME)

def run_from_command(self, args, config_filename):
runner = CliRunner()
config_path = os.path.join(tests_module.TEST_CONFIG_DIR, config_filename)
with PathManager.open(config_path, "r") as f:
config_str = f.read()
return runner.invoke(main, args=args, input=config_str)
self.runner = CliRunner()

def test_docnn(self):
# train model
result = self.run_from_command(args=["train"], config_filename="docnn.json")
result = self.runner.invoke(
main, args=["--config-file", "demo/configs/docnn.json", "train"]
)
assert not result.exception, result.exception

# export the trained model
result = self.run_from_command(args=["export"], config_filename="docnn.json")
result = self.runner.invoke(
main, args=["--config-file", "demo/configs/docnn.json", "export"]
)
print(result.output)
assert not result.exception, result.exception

# predict with PyTorch model
result = self.runner.invoke(
main,
args=["predict-py", "--model-file", "/tmp/model.pt"],
input='{"text": "create an alarm for 1:30 pm"}',
)
assert "'prediction':" in result.output, result.exception