This repository was archived by the owner on Nov 22, 2022. It is now read-only.

Improve get_logits() #683

Closed · wants to merge 1 commit
5 changes: 5 additions & 0 deletions pytext/config/pytext_config.py
@@ -149,4 +149,9 @@ class TestConfig(ConfigBase):
     test_out_path: str = ""


+class LogitsConfig(TestConfig):
+    # Whether to dump the raw input to the output file.
+    dump_raw_input: bool = False
+
+
 LATEST_VERSION = 18
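
For reference, constructing the new config could look like the sketch below (it assumes the usual ConfigBase keyword-argument construction; the field values are made up):

    from pytext.config.pytext_config import LogitsConfig

    # test_out_path is inherited from TestConfig; dump_raw_input is the new
    # flag consumed by workflow.get_logits() further down.
    config = LogitsConfig(test_out_path="logits.tsv", dump_raw_input=True)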
5 changes: 5 additions & 0 deletions pytext/data/sources/squad.py
@@ -3,6 +3,7 @@

 import json
 import math
+import os
 from typing import List, Optional

 from pytext.data.sources.data_source import (
@@ -70,6 +71,10 @@ def _split_document(
 def process_squad_json(fname, ignore_impossible, max_character_length, min_overlap):
     if not fname:
         return
+    if not os.path.exists(fname):
+        print(f"{fname} does not exist. Not unflattening.")
+        return

     with open(fname) as infile:
         dump = json.load(infile)
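
The new guard is a standard defensive pattern: check for the file before open() so that a bad path prints a notice instead of raising FileNotFoundError. A self-contained sketch of the same pattern (the helper name is illustrative, not part of PyText):

    import json
    import os

    def load_json_if_present(fname):
        # Same guard as above: bail out quietly on an empty or missing path.
        if not fname:
            return None
        if not os.path.exists(fname):
            print(f"{fname} does not exist. Not unflattening.")
            return None
        with open(fname) as infile:
            return json.load(infile)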
6 changes: 3 additions & 3 deletions pytext/models/output_layers/squad_output_layer.py
@@ -105,7 +105,7 @@ def get_pred(
         targets: torch.Tensor,
         contexts: Dict[str, List[Any]],
     ) -> Tuple[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor, torch.Tensor]]:
-        start_pos_logits, end_pos_logits, has_answer_logits = logits
+        start_pos_logits, end_pos_logits, has_answer_logits, _, _ = logits
         start_pos_preds, end_pos_preds = self.get_position_preds(
             start_pos_logits, end_pos_logits, self.max_answer_len
         )
@@ -133,7 +133,7 @@ def get_pred(

     def get_loss(
         self,
-        logits: Tuple[torch.Tensor, torch.Tensor],
+        logits: Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor],
         targets: Tuple[torch.Tensor, torch.Tensor, torch.Tensor],
         contexts: Dict[str, Any] = None,
         *args,
@@ -152,7 +152,7 @@
             torch.Tensor: Model loss.

         """
-        start_pos_logit, end_pos_logit, has_answer_logit = logits
+        start_pos_logit, end_pos_logit, has_answer_logit, _, _ = logits
         start_pos_target, end_pos_target, has_answer_target = targets

         num_answers = start_pos_target.size()[-1]
11 changes: 6 additions & 5 deletions pytext/models/qna/bert_squad_qa.py
@@ -98,17 +98,18 @@ def arrange_targets(self, tensor_dict):
             answer_start_indices,
             answer_end_indices,
         ) = tensor_dict["squad_input"]
-        # label = True if answer exists
-        label = tensor_dict["has_answer"]
-        return answer_start_indices, answer_end_indices, label
+        # has_answer = True if answer exists
+        has_answer = tensor_dict["has_answer"]
+        return answer_start_indices, answer_end_indices, has_answer

     def forward(self, *inputs):
         tokens, pad_mask, segment_labels, _ = inputs  # See arrange_model_inputs()
         encoded_layers, cls_embed = self.encoder(inputs)
         pos_logits = self.decoder(encoded_layers[-1])
         if isinstance(pos_logits, (list, tuple)):
             pos_logits = pos_logits[0]

-        has_ans_logits = (
+        has_answer_logits = (
             torch.zeros((pos_logits.size(0), 2))  # dummy tensor
             if self.output_layer.ignore_impossible
             else self.has_ans_decoder(cls_embed)
@@ -122,4 +123,4 @@ def forward(self, *inputs):
         start_logits = start_logits.squeeze(-1)
         end_logits = end_logits.squeeze(-1)

-        return start_logits, end_logits, has_ans_logits
+        return start_logits, end_logits, has_answer_logits, pad_mask, segment_labels
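
forward() now returns a 5-tuple instead of a 3-tuple so that get_logits() can dump pad_mask and segment_labels alongside the logits. A minimal sketch of the new contract from the consumer's side (the shapes are illustrative assumptions: batch size 2, sequence length 4):

    import torch

    # Stand-ins for the five outputs of forward(); shapes are assumptions.
    start_logits = torch.randn(2, 4)
    end_logits = torch.randn(2, 4)
    has_answer_logits = torch.randn(2, 2)
    pad_mask = torch.ones(2, 4, dtype=torch.long)
    segment_labels = torch.zeros(2, 4, dtype=torch.long)

    model_outputs = (start_logits, end_logits, has_answer_logits, pad_mask, segment_labels)

    # Callers that only need the logits discard the trailing pair, exactly as
    # squad_output_layer.py now does:
    start, end, has_answer, _, _ = model_outputs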
2 changes: 2 additions & 0 deletions pytext/task/new_task.py
@@ -157,7 +157,9 @@ def _init_model(cls, model_config, tensorizers, model_state=None):
             ComponentType.MODEL, model_config, tensorizers=tensorizers
         )
         if model_state:
+            print("Loading model from model state dict...")
             model.load_state_dict(model_state)
+            print("Loaded!")

         if cuda.CUDA_ENABLED:
             model = model.cuda()
26 changes: 20 additions & 6 deletions pytext/workflow.py
@@ -294,11 +294,13 @@ def get_logits(
     output_path: Optional[str] = None,
     test_path: Optional[str] = None,
     field_names: Optional[List[str]] = None,
+    dump_raw_input: bool = False,
 ):
     _set_cuda(use_cuda_if_available)
     task, train_config, _training_state = load(snapshot_path)
     print(f"Successfully loaded model from {snapshot_path}")
     print(f"Model on GPU? {next(task.model.parameters()).is_cuda}")
+
     if isinstance(task, NewTask):
         task.model.eval()
         data_source = _get_data_source(
@@ -309,25 +311,37 @@
         batches = task.data.batches(Stage.TEST, data_source=data_source)

         with open(output_path, "w", encoding="utf-8") as fout, torch.no_grad():
-            for (_, tensor_dict) in batches:
+            for (raw_batch, tensor_dict) in batches:
+                raw_input_tuple = (
+                    dict_zip(*raw_batch, value_only=True) if dump_raw_input else ()
+                )
                 model_inputs = task.model.arrange_model_inputs(tensor_dict)
                 model_outputs = task.model(*model_inputs)
                 if isinstance(model_outputs, tuple):
-                    model_outputs_list = [m.tolist() for m in model_outputs]
-                    for row in zip(*model_outputs_list):
-                        # row is a tuple of lists
+                    model_outputs_tuple = tuple(m.tolist() for m in model_outputs)
+                    for row in zip(*raw_input_tuple, *model_outputs_tuple):
                         dump_row = "\t".join(json.dumps(r) for r in row)
                         fout.write(f"{dump_row}\n")
                 elif isinstance(model_outputs, torch.Tensor):
                     model_outputs_list = model_outputs.tolist()
-                    for row in zip(model_outputs_list):
-                        fout.write(f"{json.dumps(row)}\n")
+                    for row in zip(*raw_input_tuple, model_outputs_list):
+                        dump_row = "\t".join(json.dumps(r) for r in row)
+                        fout.write(f"{dump_row}\n")
                 else:
                     raise Exception(
                         "Expecting tuple or torch.Tensor types for model_outputs"
                     )
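
With dump_raw_input=True, each line of the output file is a tab-separated row of JSON values: the raw-input columns first, then one column per model output. A minimal sketch of the serialization step, with made-up data:

    import json

    raw_input_tuple = (["what is pytext?"], ["some context"])  # raw-input columns
    model_outputs_tuple = ([[0.1, 0.9]], [[0.8, 0.2]])  # per-tensor tolist() results

    for row in zip(*raw_input_tuple, *model_outputs_tuple):
        print("\t".join(json.dumps(r) for r in row))
    # prints: "what is pytext?"<TAB>"some context"<TAB>[0.1, 0.9]<TAB>[0.8, 0.2]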


+def dict_zip(*dicts, value_only=False):
+    dict_keys = dicts[0].keys()
+    return (
+        tuple([d[k] for d in dicts] for k in dict_keys)
+        if value_only
+        else {k: [d[k] for d in dicts] for k in dict_keys}
+    )
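
The new dict_zip() helper turns a batch of per-example dicts into per-field columns; with value_only=True it keeps just the values, in the key order of the first dict (all dicts are assumed to share the same keys). For example, given the definition above:

    batch = [
        {"question": "q1", "context": "c1"},
        {"question": "q2", "context": "c2"},
    ]

    dict_zip(*batch, value_only=True)
    # (["q1", "q2"], ["c1", "c2"])

    dict_zip(*batch)
    # {"question": ["q1", "q2"], "context": ["c1", "c2"]}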


 def batch_predict(model_file: str, examples: List[Dict[str, Any]]):
     task, train_config, _training_state = load(model_file)
     return task.predict(examples)