Skip to content
This repository was archived by the owner on Nov 22, 2022. It is now read-only.

Commit dbad191

Browse files
chenyangyu1988 authored and facebook-github-bot committed
Migrate pytext/utils/torch.py logic into pytext/torchscript/ for long term maintainability (#1082)
Summary: Pull Request resolved: #1082. Migrate pytext/utils/torch.py logic into pytext/torchscript/ for long-term maintainability. Differential Revision: D18207798. fbshipit-source-id: c5680edf99d20b4e46fa887d4865dce138da885a
1 parent 0ce889d commit dbad191

File tree

11 files changed

+770
-6
lines changed

11 files changed

+770
-6
lines changed
Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
11
#!/usr/bin/env python3
22
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
33

4-
from pytext.torchscript.tensorizer.bert import ScriptBERTTensorizer
5-
from pytext.torchscript.tensorizer.roberta import ScriptRoBERTaTensorizer
4+
from .bert import ScriptBERTTensorizer
5+
from .normalizer import VectorNormalizer
6+
from .roberta import ScriptRoBERTaTensorizer
67

78

8-
__all__ = ["ScriptBERTTensorizer", "ScriptRoBERTaTensorizer"]
9+
__all__ = ["ScriptBERTTensorizer", "ScriptRoBERTaTensorizer", "VectorNormalizer"]

pytext/torchscript/tensorizer/bert.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,8 @@
44
from typing import List, Optional, Tuple
55

66
import torch
7-
from pytext.utils.torch import Vocabulary as ScriptVocabulary, pad_2d_mask
7+
from pytext.torchscript.utils import pad_2d_mask
8+
from pytext.torchscript.vocab import ScriptVocabulary
89

910
from .tensorizer import ScriptTensorizer, VocabLookup
1011

Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,83 @@
1+
#!/usr/bin/env python3
2+
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
3+
4+
from typing import List
5+
6+
import torch
7+
8+
9+
class VectorNormalizer(torch.nn.Module):
    """Normalize dense feature vectors in place as (x - mean) / stddev.

    The statistics are accumulated incrementally: call ``update_meta_data``
    once per row while initializing your tensorizer, then call
    ``calculate_feature_stats`` after the last row. ``normalize`` can then be
    used at training time in the tensorizer and at inference time from a
    TorchScript forward function. See usage in FloatListTensorizer for an
    example.

    Setting ``do_normalization=False`` turns ``update_meta_data`` and
    ``calculate_feature_stats`` into no-ops and ``normalize`` into an identity
    function.

    Args:
        dim: number of features in each dense feature vector.
        do_normalization: whether statistics are collected and applied.
    """

    def __init__(self, dim: int, do_normalization: bool = True):
        super().__init__()
        self.num_rows = 0
        self.feature_sums = [0] * dim
        self.feature_squared_sums = [0] * dim
        self.do_normalization = do_normalization
        # Defaults make normalize() an identity until stats are calculated.
        self.feature_avgs = [0.0] * dim
        self.feature_stddevs = [1.0] * dim

    def __getstate__(self):
        """Return only the plain-Python statistics needed to rebuild the object."""
        return {
            "num_rows": self.num_rows,
            "feature_sums": self.feature_sums,
            "feature_squared_sums": self.feature_squared_sums,
            "do_normalization": self.do_normalization,
            "feature_avgs": self.feature_avgs,
            "feature_stddevs": self.feature_stddevs,
        }

    def __setstate__(self, state):
        """Restore the statistics saved by ``__getstate__``."""
        # Unpickling bypasses __init__, so the nn.Module bookkeeping
        # (_parameters, _modules, ...) has to be created here; otherwise the
        # restored object cannot be used through the regular Module API.
        super().__init__()
        self.num_rows = state["num_rows"]
        self.feature_sums = state["feature_sums"]
        self.feature_squared_sums = state["feature_squared_sums"]
        self.do_normalization = state["do_normalization"]
        self.feature_avgs = state["feature_avgs"]
        self.feature_stddevs = state["feature_stddevs"]

    # TODO: this is only to satisfy the TorchScript compiler.
    # Can remove when D17551196 lands
    def forward(self):
        pass

    def update_meta_data(self, vec):
        """Fold one feature vector into the running sum and squared sum."""
        if not self.do_normalization:
            return
        self.num_rows += 1
        for i, value in enumerate(vec):
            self.feature_sums[i] += value
            self.feature_squared_sums[i] += value ** 2

    def calculate_feature_stats(self):
        """Derive per-feature mean and stddev from the accumulated sums.

        Does nothing when normalization is disabled or when no rows have been
        seen yet; in the latter case the identity defaults (mean 0, stddev 1)
        are kept instead of dividing by zero.
        """
        if not self.do_normalization or self.num_rows == 0:
            return
        row_count = self.num_rows
        self.feature_avgs = [total / row_count for total in self.feature_sums]
        # E[x^2] - E[x]^2 can come out slightly negative through floating
        # point rounding (e.g. for a constant feature); clamp at zero so the
        # square root never produces a complex number.
        self.feature_stddevs = [
            max(squared_total / row_count - avg ** 2, 0.0) ** 0.5
            for squared_total, avg in zip(
                self.feature_squared_sums, self.feature_avgs
            )
        ]

    def normalize(self, vec: List[List[float]]) -> List[List[float]]:
        """Normalize every row of ``vec`` in place and return ``vec``.

        A feature whose stddev is 0 is only mean-centered (divided by 1.0).
        """
        # Plain index loops are kept deliberately: this method is meant to be
        # callable from TorchScript, which is stricter than regular Python.
        if self.do_normalization:
            for i in range(len(vec)):
                for j in range(len(vec[i])):
                    vec[i][j] -= self.feature_avgs[j]
                    vec[i][j] /= (
                        self.feature_stddevs[j] if self.feature_stddevs[j] != 0 else 1.0
                    )
        return vec

pytext/torchscript/tensorizer/tensorizer.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from typing import List, Optional, Tuple
55

66
import torch
7-
from pytext.utils.torch import Vocabulary as ScriptVocabulary
7+
from pytext.torchscript.vocab import ScriptVocabulary
88

99

1010
class ScriptTensorizer(torch.jit.ScriptModule):
@@ -29,6 +29,10 @@ def tensorize(self, rows):
2929

3030

3131
class VocabLookup(torch.jit.ScriptModule):
32+
"""
33+
TorchScript implementation of lookup_tokens() in pytext/data/tensorizers.py
34+
"""
35+
3236
def __init__(self, vocab: ScriptVocabulary):
3337
super().__init__()
3438
self.vocab = vocab

pytext/torchscript/tests/test_tensorizer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
import torch
99
from pytext.torchscript.tensorizer import ScriptBERTTensorizer, ScriptRoBERTaTensorizer
1010
from pytext.torchscript.tensorizer.tensorizer import VocabLookup
11-
from pytext.utils.torch import Vocabulary as ScriptVocabulary
11+
from pytext.torchscript.vocab import ScriptVocabulary
1212

1313

1414
class TensorizerTest(unittest.TestCase):
Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,87 @@
1+
#!/usr/bin/env python3
2+
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
3+
4+
import io
5+
import pickle
6+
import unittest
7+
8+
import torch
9+
from pytext.torchscript.tokenizer import ScriptBPE
10+
from pytext.torchscript.utils import make_byte_inputs, utf8_chars
11+
12+
13+
# In-memory BPE vocabulary: one "<token> <score>" entry per line.
BPE_VOCAB_FILE = io.StringIO(
    """
hello_EOW 20
world_EOW 18
th 17
is_EOW 16
bpe_EOW 15
! 14
h 13
t 6
s_EOW 2
i -1
今_EOW -2
"""
)


class BPETest(unittest.TestCase):
    """Tests for ScriptBPE tokenization and the utf8/byte helpers."""

    # Input words and the tokens ScriptBPE is expected to produce for them;
    # shared by the plain and the pickled tokenizer tests.
    WORDS = ["hello", "world", "this", "is", "bpe", "今日"]
    EXPECTED_TOKENS = [
        "hello_EOW",
        "world_EOW",
        "th",
        "is_EOW",
        "is_EOW",
        "bpe_EOW",
        "今_EOW",
    ]

    def _load_bpe(self):
        """Build a ScriptBPE from the shared in-memory vocab file."""
        # The StringIO is module-level state, so rewind it before every read.
        BPE_VOCAB_FILE.seek(0)
        return ScriptBPE.from_vocab_file(BPE_VOCAB_FILE)

    def test_utf8_chars(self):
        """utf8_chars must split a string into the same characters as list()."""
        words = ["hello", "💩", "¯\\_(ツ)_/¯", "今日"]
        for word in words:
            self.assertEqual(list(word), utf8_chars(word))

    def test_simple_bpe(self):
        """A freshly loaded tokenizer produces the expected tokens."""
        bpe = self._load_bpe()
        tokenized = bpe.tokenize(list(self.WORDS))
        self.assertEqual(self.EXPECTED_TOKENS, tokenized)

    def test_pickle_bpe(self):
        """A tokenizer survives a pickle round trip unchanged."""
        original_bpe = self._load_bpe()
        bpe = pickle.loads(pickle.dumps(original_bpe))
        tokenized = bpe.tokenize(list(self.WORDS))
        self.assertEqual(self.EXPECTED_TOKENS, tokenized)

    def test_make_bytes_input(self):
        """make_byte_inputs pads/truncates words to max_char_length bytes."""
        s1 = "I want some coffee today"
        s2 = "Turn it up"
        max_char_length = 5

        batch = [s1.split(), s2.split()]
        # Named byte_tensor rather than bytes to avoid shadowing the builtin.
        byte_tensor, seq_lens = make_byte_inputs(batch, max_char_length)

        def to_bytes(word, pad_to):
            # Pad by the encoded length so the helper stays correct for
            # non-ASCII words, where len(word) != number of bytes.
            encoded = list(word.encode())
            return encoded + [0] * (pad_to - len(encoded))

        expected_bytes = [
            [
                to_bytes("I", 5),
                to_bytes("want", 5),
                to_bytes("some", 5),
                # "coffee" is expected to be truncated to max_char_length.
                to_bytes("coffe", 5),
                to_bytes("today", 5),
            ],
            [
                to_bytes("Turn", 5),
                to_bytes("it", 5),
                to_bytes("up", 5),
                # The shorter sentence is padded with all-zero words.
                to_bytes("", 5),
                to_bytes("", 5),
            ],
        ]
        expected_seq_lens = [5, 3]

        self.assertIsInstance(byte_tensor, torch.LongTensor)
        self.assertIsInstance(seq_lens, torch.LongTensor)
        self.assertEqual(byte_tensor.tolist(), expected_bytes)
        self.assertEqual(seq_lens.tolist(), expected_seq_lens)
Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
#!/usr/bin/env python3
2+
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
3+
4+
import unittest
5+
6+
import torch
7+
from pytext.torchscript.vocab import ScriptVocabulary
8+
from torch import jit
9+
10+
11+
class VocabTest(unittest.TestCase):
    """Exercises ScriptVocabulary lookups, both directly and from TorchScript."""

    def setUp(self):
        # Index 0 ("UNK") is what out-of-vocabulary tokens map to in the
        # lookup tests below.
        self.vocab = ScriptVocabulary(["UNK", "a", "b", "c", "d"])

    def test_vocab_lookup(self):
        # There are bugs with just making this a script, eventually these can be simpler
        class LookupWord(jit.ScriptModule):
            def __init__(self, vocab):
                super().__init__()
                self.vocab = vocab

            @jit.script_method
            def forward(self, word: str):
                return self.vocab.idx[word]

        word_to_index = LookupWord(self.vocab)

        self.assertEqual(1, word_to_index("a"))
        self.assertEqual(3, word_to_index("c"))
        # A word missing from the vocabulary must raise.
        with self.assertRaises(Exception):
            word_to_index("notaword")

    def test_vocab_idx_lookup(self):
        # There are bugs with just making this a script, eventually these can be simpler
        class LookupIndex(jit.ScriptModule):
            def __init__(self, vocab):
                super().__init__()
                self.vocab = vocab

            @jit.script_method
            def forward(self, i: int):
                return self.vocab.vocab[i]

        index_to_word = LookupIndex(self.vocab)

        self.assertEqual("UNK", index_to_word(0))
        self.assertEqual("b", index_to_word(2))
        # An index past the end of the vocabulary must raise.
        with self.assertRaises(Exception):
            index_to_word(20)

    def test_lookup_1d(self):
        # "e" is not in the vocabulary and is expected to map to index 0.
        found = self.vocab.lookup_indices_1d(["a", "e", "c", "d"])
        self.assertEqual([1, 0, 3, 4], found)
        self.assertEqual([], self.vocab.lookup_indices_1d([]))

    def test_lookup_2d(self):
        rows = [["a", "e", "c", "d"], [], ["b"]]
        self.assertEqual(
            [[1, 0, 3, 4], [], [2]],
            self.vocab.lookup_indices_2d(rows),
        )
        self.assertEqual([], self.vocab.lookup_indices_2d([]))

    def test_custom_unk(self):
        # With unk_idx=1 the unknown token "e" is expected to map to index 1.
        custom_vocab = ScriptVocabulary(["a", "UNK", "b", "c", "d"], unk_idx=1)
        found = custom_vocab.lookup_indices_1d(["a", "e", "c", "d"])
        self.assertEqual([0, 1, 3, 4], found)

    def test_lookup_words_1d_cycle_heuristic(self):
        indices = torch.tensor([1, 0, 0])
        actual = self.vocab.lookup_words_1d_cycle_heuristic(indices, [], ["y", "z"])
        self.assertEqual(actual, ["a", "y", "z"])
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
#!/usr/bin/env python3
2+
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
3+
4+
from .bpe import ScriptBPE
5+
6+
7+
__all__ = ["ScriptBPE"]

0 commit comments

Comments
 (0)