Skip to content
This repository was archived by the owner on Nov 22, 2022. It is now read-only.

Commit a7cdb3a

Browse files
Debojeet Chatterjeefacebook-github-bot
authored andcommitted
Modify Return Signature of TorchScript BERT
Summary: * Return actual answer text instead of spans. (Blank means no answer, no need for exception handling in caller.) * Return answer confidence score. * Return has_answer score. Differential Revision: D17983996 fbshipit-source-id: fc3084681e9d1b0453b8e4a28eeb1293dcb4d541
1 parent 8f93ce1 commit a7cdb3a

File tree

1 file changed

+5
-9
lines changed

1 file changed

+5
-9
lines changed

pytext/models/output_layers/squad_output_layer.py

Lines changed: 5 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ def get_position_preds(
6464
max_span_length: int,
6565
):
6666
# the following is to enforce end_pos > start_pos. We create a matrix
67-
# of start_positions X end_positions, fill it with the sum logits,
67+
# of start_position X end_position, fill it with the sum logits,
6868
# then mask it to be upper-triangular
6969
# e.g. start_pos_logits = [1, 3, 0, 5, 2]
7070
# end_pos_logits = [2, 4, 6, 3, 5]
@@ -94,14 +94,10 @@ def get_position_preds(
9494
for i in range(logit_sum_matrix.size()[1]):
9595
logit_sum_matrix[:, i, i + max_span_length :] = 0
9696
vals, ids = logit_sum_matrix.max(-1)
97-
_, start_positions = vals.max(-1)
98-
end_positions = ids.gather(-1, start_positions.unsqueeze(-1)).squeeze(-1)
97+
_, start_position = vals.max(-1)
98+
end_position = ids.gather(-1, start_position.unsqueeze(-1)).squeeze(-1)
9999

100-
return (
101-
start_positions,
102-
end_positions,
103-
logit_sum_matrix[0, start_positions, end_positions],
104-
)
100+
return start_position, end_position
105101

106102
def get_pred(
107103
self,
@@ -110,7 +106,7 @@ def get_pred(
110106
contexts: Dict[str, List[Any]],
111107
) -> Tuple[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor, torch.Tensor]]:
112108
start_pos_logits, end_pos_logits, has_answer_logits = logits
113-
start_pos_preds, end_pos_preds, _ = self.get_position_preds(
109+
start_pos_preds, end_pos_preds = self.get_position_preds(
114110
start_pos_logits, end_pos_logits, self.max_answer_len
115111
)
116112
has_answer_preds = has_answer_logits.argmax(-1)

0 commit comments

Comments
 (0)