This repository was archived by the owner on Nov 22, 2022. It is now read-only.

refactor ScriptTensorizor to support both text and tokens input #1096

Closed
12 changes: 9 additions & 3 deletions pytext/torchscript/tensorizer/__init__.py
@@ -1,9 +1,15 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved

from .bert import ScriptBERTTensorizer
from .bert import ScriptBERTTensorizer, ScriptBERTTokenTensorizer
from .normalizer import VectorNormalizer
from .roberta import ScriptRoBERTaTensorizer
from .roberta import ScriptRoBERTaTensorizer, ScriptRoBERTaTokenTensorizer


__all__ = ["ScriptBERTTensorizer", "ScriptRoBERTaTensorizer", "VectorNormalizer"]
__all__ = [
"ScriptBERTTensorizer",
"ScriptBERTTokenTensorizer",
"ScriptRoBERTaTensorizer",
"ScriptRoBERTaTokenTensorizer",
"VectorNormalizer",
]
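
As a quick usage note (not part of the diff), the refactor keeps the existing exports and adds the token-input variants next to them, so both can be imported from the same package:

```python
# Hedged sketch: importing the new token-input tensorizers alongside the
# existing text-input ones, assuming the package layout in this PR.
from pytext.torchscript.tensorizer import (
    ScriptBERTTensorizer,
    ScriptBERTTokenTensorizer,
    ScriptRoBERTaTensorizer,
    ScriptRoBERTaTokenTensorizer,
    VectorNormalizer,
)
```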
65 changes: 59 additions & 6 deletions pytext/torchscript/tensorizer/bert.py
@@ -10,7 +10,7 @@
from .tensorizer import ScriptTensorizer, VocabLookup


class ScriptBERTTensorizer(ScriptTensorizer):
class ScriptBERTTensorizerBase(ScriptTensorizer):
def __init__(
self,
tokenizer: torch.jit.ScriptModule,
@@ -29,28 +29,31 @@ def __init__(

@torch.jit.script_method
def numberize(self, row: List[str]) -> Tuple[List[int], List[int], int]:
"""Convert row into token ids by doing vocab look-up. It will also
append bos & eos index into token_ids if needed.
"""Convert raw inputs into token ids by doing vocab look-up. It will also
append bos & eos indices to the token ids if needed.

Args:
row: a list of input texts, in most case it is a
row: 1) a list of raw inputs, in most cases a
single text or a pair of texts.
2) a list of preprocessed tokens; further operations
(for example, BPE) can still be applied to them.

Returns:
a list of token ids (after vocab lookup) and segment labels.
"""
token_ids: List[int] = []
segment_labels: List[int] = []
seq_len: int = 0
per_sentence_tokens: List[List[Tuple[str, int, int]]] = self.tokenize(row)

for idx, text in enumerate(row):
for idx, tokens in enumerate(per_sentence_tokens):
if idx == 0 and self.add_bos_token:
bos_idx: Optional[int] = self.vocab.bos_idx
else:
bos_idx: Optional[int] = None

lookup_ids: List[int] = self.vocab_lookup(
self.tokenizer.tokenize(text),
tokens,
bos_idx=bos_idx,
eos_idx=self.vocab.eos_idx,
use_eos_token_for_bos=self.use_eos_token_for_bos,
@@ -66,6 +69,18 @@ def numberize(self, row: List[str]) -> Tuple[List[int], List[int], int]:
def tensorize(
self, rows: List[List[str]]
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""Convert multiple rows of raw inputs into model input tensors.

Args:
rows: 1) each row is a list of raw inputs, in most cases a
single text or a pair of texts.
2) each row is a list of preprocessed tokens; further
operations (for example, BPE) can still be applied to them.

Returns:
model input tensors.
"""

tokens_2d: List[List[int]] = []
segment_labels_2d: List[List[int]] = []
seq_len_2d: List[int] = []
@@ -79,3 +94,41 @@
tokens, pad_mask = pad_2d_mask(tokens_2d, pad_value=self.vocab.pad_idx)
segment_labels, _ = pad_2d_mask(segment_labels_2d, pad_value=self.vocab.pad_idx)
return tokens, pad_mask, segment_labels


class ScriptBERTTensorizer(ScriptBERTTensorizerBase):
@torch.jit.script_method
def tokenize(self, row: List[str]) -> List[List[Tuple[str, int, int]]]:
"""Convert raw inputs into tokens.

Args:
row: a list of raw inputs, in most cases a
single text or a pair of texts.

Returns:
a per-sentence list of tokens, each with its token indices.
"""

per_sentence_tokens: List[List[Tuple[str, int, int]]] = []
for text in row:
per_sentence_tokens.append(self.tokenizer.tokenize(text))
return per_sentence_tokens


class ScriptBERTTokenTensorizer(ScriptBERTTensorizerBase):
@torch.jit.script_method
def tokenize(self, row: List[str]) -> List[List[Tuple[str, int, int]]]:
"""Convert raw inputs into tokens.

Args:
row: a list of preprocessed tokens; further operations
(for example, BPE) can still be applied to them.

Returns:
a per-sentence list of tokens, each with its token indices.
"""

per_sentence_tokens: List[Tuple[str, int, int]] = []
for raw_token in row:
per_sentence_tokens.extend(self.tokenizer.tokenize(raw_token))
return [per_sentence_tokens]
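
To make the split concrete, here is a minimal, hedged sketch (not part of the diff) of what the two `tokenize` implementations return. `FakeTokenizer` is a hypothetical whitespace tokenizer standing in for the scripted tokenizer; the real classes call `self.tokenizer.tokenize` on a `torch.jit.ScriptModule`.

```python
from typing import List, Tuple

Token = Tuple[str, int, int]  # (piece, start, end), matching the tensorizer signatures

class FakeTokenizer:
    """Hypothetical stand-in for the scripted tokenizer; splits on whitespace."""
    def tokenize(self, text: str) -> List[Token]:
        tokens, pos = [], 0
        for piece in text.split():
            start = text.index(piece, pos)
            tokens.append((piece, start, start + len(piece)))
            pos = start + len(piece)
        return tokens

tokenizer = FakeTokenizer()

def tokenize_text_mode(row: List[str]) -> List[List[Token]]:
    # ScriptBERTTensorizer: each row element is a raw text -> one token list per text.
    return [tokenizer.tokenize(text) for text in row]

def tokenize_token_mode(row: List[str]) -> List[List[Token]]:
    # ScriptBERTTokenTensorizer: each row element is a pre-tokenized token; the
    # tokenizer (e.g. BPE) may split it further, and the result is flattened
    # into a single sentence.
    per_sentence_tokens: List[Token] = []
    for raw_token in row:
        per_sentence_tokens.extend(tokenizer.tokenize(raw_token))
    return [per_sentence_tokens]

print(tokenize_text_mode(["hello world", "how are you"]))  # two sentences
print(tokenize_token_mode(["hello", "world"]))             # one flattened sentence
```

Either way, `numberize` receives the same `List[List[Tuple[str, int, int]]]` shape, which is what lets the base class share the vocab lookup and padding logic.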
71 changes: 61 additions & 10 deletions pytext/torchscript/tensorizer/roberta.py
@@ -5,25 +5,76 @@

import torch

from .bert import ScriptBERTTensorizer
from .bert import ScriptBERTTensorizerBase


class ScriptRoBERTaTensorizer(ScriptBERTTensorizer):
class ScriptRoBERTaTensorizerBase(ScriptBERTTensorizerBase):
@torch.jit.script_method
def numberize(self, row: List[str]) -> Tuple[List[int], List[int], int]:
tokens: List[int] = []
"""Convert raw inputs into token ids by doing vocab look-up. It will also
append bos & eos indices to the token ids if needed.

Args:
row: 1) a list of raw inputs, in most cases a
single text or a pair of texts.
2) a list of preprocessed tokens; further operations
(for example, BPE) can still be applied to them.

Returns:
a list of token ids (after vocab lookup) and segment labels.
"""
token_ids: List[int] = []
segment_labels: List[int] = []
seq_len: int = 0
per_sentence_tokens: List[List[Tuple[str, int, int]]] = self.tokenize(row)

for idx, text in enumerate(row):
token_ids: List[int] = self.vocab_lookup(
self.tokenizer.tokenize(text),
for idx, tokens in enumerate(per_sentence_tokens):
lookup_ids: List[int] = self.vocab_lookup(
tokens,
bos_idx=self.vocab.bos_idx,
eos_idx=self.vocab.eos_idx,
max_seq_len=self.max_seq_len,
)[0]
tokens.extend(token_ids)
segment_labels.extend([idx] * len(token_ids))
seq_len = len(tokens)
token_ids.extend(lookup_ids)
segment_labels.extend([idx] * len(lookup_ids))
seq_len = len(token_ids)

return token_ids, segment_labels, seq_len


class ScriptRoBERTaTensorizer(ScriptRoBERTaTensorizerBase):
@torch.jit.script_method
def tokenize(self, row: List[str]) -> List[List[Tuple[str, int, int]]]:
"""Convert raw inputs into tokens.

Args:
row: a list of raw inputs, in most cases a
single text or a pair of texts.

Returns:
a per-sentence list of tokens, each with its token indices.
"""

per_sentence_tokens: List[List[Tuple[str, int, int]]] = []
for text in row:
per_sentence_tokens.append(self.tokenizer.tokenize(text))
return per_sentence_tokens


class ScriptRoBERTaTokenTensorizer(ScriptRoBERTaTensorizerBase):
@torch.jit.script_method
def tokenize(self, row: List[str]) -> List[List[Tuple[str, int, int]]]:
"""Convert raw inputs into tokens.

Args:
row: a list of preprocessed tokens; further operations
(for example, BPE) can still be applied to them.

Returns:
a per-sentence list of tokens, each with its token indices.
"""

return tokens, segment_labels, seq_len
per_sentence_tokens: List[Tuple[str, int, int]] = []
for raw_token in row:
per_sentence_tokens.extend(self.tokenizer.tokenize(raw_token))
return [per_sentence_tokens]
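
For the RoBERTa variants, the per-sentence lookups are stitched together with bos and eos around every sentence (the BERT base class only prepends bos to the first sentence when `add_bos_token` is set). Below is a minimal, hedged sketch of that stitching; `fake_vocab_lookup` is a hypothetical stand-in for `VocabLookup` and the special ids are assumed for illustration only.

```python
from typing import List, Optional, Tuple

BOS, EOS, FIRST_ID = 0, 2, 10  # assumed dummy ids, for illustration only

def fake_vocab_lookup(tokens: List[Tuple[str, int, int]],
                      bos_idx: Optional[int], eos_idx: int) -> List[int]:
    # Hypothetical lookup: assign increasing dummy ids, then add special tokens.
    ids = [FIRST_ID + i for i, _ in enumerate(tokens)]
    if bos_idx is not None:
        ids = [bos_idx] + ids
    return ids + [eos_idx]

def roberta_numberize(per_sentence_tokens: List[List[Tuple[str, int, int]]]):
    # Mirrors the structure above: bos/eos wrap every sentence, segment labels
    # record which sentence each id came from, seq_len is the total length.
    token_ids: List[int] = []
    segment_labels: List[int] = []
    for idx, tokens in enumerate(per_sentence_tokens):
        lookup_ids = fake_vocab_lookup(tokens, bos_idx=BOS, eos_idx=EOS)
        token_ids.extend(lookup_ids)
        segment_labels.extend([idx] * len(lookup_ids))
    return token_ids, segment_labels, len(token_ids)

sentences = [[("hello", 0, 5)], [("world", 0, 5)]]
print(roberta_numberize(sentences))
# -> ([0, 10, 2, 0, 10, 2], [0, 0, 0, 1, 1, 1], 6)
```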
8 changes: 8 additions & 0 deletions pytext/torchscript/tensorizer/tensorizer.py
@@ -8,6 +8,14 @@


class ScriptTensorizer(torch.jit.ScriptModule):
@torch.jit.script_method
def tokenize(self, row):
"""
Tokenize raw inputs into tokens with, for example, GPT-2 BPE,
SentencePiece, or the Yoda tokenizer.
"""
raise NotImplementedError

@torch.jit.script_method
def numberize(self, row):
"""