
Commit aeef706

Add start/end token padding to GPT2Preprocessor (#704)
* Add start/end token padding to the GPT2 preprocessor
* Change the pad token to the magic one, "<|endoftext|>"
1 parent 8581352 commit aeef706
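For orientation, a minimal usage sketch of the options this commit introduces (not part of the commit; it mirrors how the new tests construct the preprocessor and assumes the `gpt2_base_en` preset assets can be downloaded):

```python
import keras_nlp

# Opt into the new start/end padding at construction time. Both flags
# default to False, so existing behavior is unchanged.
preprocessor = keras_nlp.models.GPT2Preprocessor(
    tokenizer=keras_nlp.models.GPT2Tokenizer.from_preset("gpt2_base_en"),
    sequence_length=8,
    add_start_token=True,  # prepend "<|endoftext|>" to every sequence
    add_end_token=True,    # append "<|endoftext|>" to every sequence
)
output = preprocessor(["a quick brown fox"])
print(output["token_ids"], output["padding_mask"])
```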

File tree: 5 files changed (+94, −13 lines)


keras_nlp/models/gpt2/gpt2_causal_lm.py

Lines changed: 6 additions & 1 deletion
@@ -161,7 +161,12 @@ class GPT2CausalLM(Task):

     """

-    def __init__(self, backbone, preprocessor=None, **kwargs):
+    def __init__(
+        self,
+        backbone,
+        preprocessor=None,
+        **kwargs,
+    ):
         inputs = backbone.input
         x = backbone(inputs)
         # Use token embedding weights to project from the token representation

keras_nlp/models/gpt2/gpt2_causal_lm_preprocessor_test.py

Lines changed: 25 additions & 3 deletions
@@ -28,7 +28,7 @@

 class GPT2CausalLMPreprocessorTest(tf.test.TestCase, parameterized.TestCase):
     def setUp(self):
-        vocab = {
+        self.vocab = {
             "<|endoftext|>": 0,
             "!": 1,
             "air": 2,
@@ -47,11 +47,12 @@ def setUp(self):
         merges += ["Ġa t", "p o", "r t", "o h", "l i", "Ġi s", "Ġb e", "s t"]
         merges += ["Ġt h", "ai r", "pl a", "Ġk oh", "Ġth e", "Ġbe st", "po rt"]
         merges += ["Ġai r", "Ġa i", "pla ne"]
+        self.merges = merges

         self.preprocessor = GPT2CausalLMPreprocessor(
             tokenizer=GPT2Tokenizer(
-                vocabulary=vocab,
-                merges=merges,
+                vocabulary=self.vocab,
+                merges=self.merges,
             ),
             sequence_length=8,
         )
@@ -77,6 +78,27 @@ def test_tokenize_list_of_strings(self):
         self.assertAllEqual(y, [[4, 5, 3, 6, 0, 0, 0]] * 4)
         self.assertAllEqual(sw, [[1, 1, 1, 1, 0, 0, 0]] * 4)

+    def test_pad_start_end_token(self):
+        input_data = ["airplane at airport"] * 4
+
+        preprocessor = GPT2CausalLMPreprocessor(
+            tokenizer=GPT2Tokenizer(
+                vocabulary=self.vocab,
+                merges=self.merges,
+            ),
+            sequence_length=8,
+            add_start_token=True,
+            add_end_token=True,
+        )
+        x, y, sw = preprocessor(input_data)
+        self.assertAllEqual(
+            x["token_ids"],
+            [[0, 2, 4, 5, 3, 6, 0]] * 4,
+        )
+        self.assertAllEqual(x["padding_mask"], [[1, 1, 1, 1, 1, 1, 1]] * 4)
+        self.assertAllEqual(y, [[2, 4, 5, 3, 6, 0, 0]] * 4)
+        self.assertAllEqual(sw, [[1, 1, 1, 1, 1, 1, 0]] * 4)
+
     def test_tokenize_labeled_batch(self):
         x = tf.constant(["airplane at airport"] * 4)
         y = tf.constant([1] * 4)
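The expected tensors in the new test are consistent with packing to `sequence_length=8` and the usual one-token shift between features and language-modeling labels. A pure-Python sketch of that arithmetic (illustrative only, not the preprocessor's actual implementation):

```python
# "airplane at airport" -> [2, 4, 5, 3, 6] with the toy vocab above.
packed = [0, 2, 4, 5, 3, 6, 0, 0]  # start + tokens + end, padded to 8
mask = [1, 1, 1, 1, 1, 1, 1, 0]    # only the final pad position is masked

x_ids = packed[:-1]  # [0, 2, 4, 5, 3, 6, 0] -> x["token_ids"]
x_mask = mask[:-1]   # [1, 1, 1, 1, 1, 1, 1] -> x["padding_mask"]
y_ids = packed[1:]   # [2, 4, 5, 3, 6, 0, 0] -> y
sw = mask[1:]        # [1, 1, 1, 1, 1, 1, 0] -> sample weights
```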

keras_nlp/models/gpt2/gpt2_preprocessor.py

Lines changed: 33 additions & 6 deletions
@@ -58,23 +58,27 @@ class GPT2Preprocessor(Preprocessor):
     Args:
         tokenizer: A `keras_nlp.models.GPT2Tokenizer` instance.
        sequence_length: The length of the packed inputs.
+        add_start_token: If true, the preprocessor will append the tokenizer
+            start token to each input sequence.
+        add_end_token: If true, the preprocessor will append the tokenizer
+            end token to each input sequence.

     Examples:
     ```python
     # Load the preprocessor from a preset.
     preprocessor = keras_nlp.models.GPT2Preprocessor.from_preset("gpt2_base_en")

     # Tokenize and pack a single sentence.
-    sentence = tf.constant("league of legends")
+    sentence = tf.constant("League of legends")
     preprocessor(sentence)
     # Same output.
-    preprocessor("league of legends")
+    preprocessor("League of legends")

     # Tokenize a batch of sentences.
-    sentences = tf.constant(["taco tuesday", "gi gi gi gi"])
+    sentences = tf.constant(["Taco tuesday", "Fish taco!"])
     preprocessor(sentences)
     # Same output.
-    preprocessor(["taco tuesday", "gi gi gi gi"])
+    preprocessor(["Taco tuesday", "Fish taco!"])

     # Map a dataset to preprocess a single sentence.
     features = tf.constant(
@@ -119,12 +123,16 @@ def __init__(
         self,
         tokenizer,
         sequence_length,
+        add_start_token=False,
+        add_end_token=False,
         **kwargs,
     ):
         super().__init__(**kwargs)

         self.tokenizer = tokenizer
         self.sequence_length = sequence_length
+        self.add_start_token = add_start_token
+        self.add_end_token = add_end_token

     def get_config(self):
         config = super().get_config()
@@ -148,9 +156,28 @@ def call(self, x, y=None, sample_weight=None):
         input_is_1d = len(token_ids.shape) == 1
         if input_is_1d:
             token_ids = tf.RaggedTensor.from_tensor([token_ids])
+        if self.add_start_token:
+            start_tokens = tf.fill(
+                [tf.shape(token_ids)[0], 1],
+                self.tokenizer.start_token_id,
+            )
+            token_ids = tf.concat([start_tokens, token_ids], axis=1)
+        if self.add_end_token:
+            end_tokens = tf.fill(
+                [tf.shape(token_ids)[0], 1],
+                self.tokenizer.end_token_id,
+            )
+            token_ids = tf.concat([token_ids, end_tokens], axis=1)
         mask = tf.ones_like(token_ids, dtype=tf.bool)
-        mask = mask.to_tensor(shape=(None, self.sequence_length))
-        token_ids = token_ids.to_tensor(shape=(None, self.sequence_length))
+        shape_after_padding = tf.stack(
+            [tf.constant(-1), self.sequence_length],
+            axis=0,
+        )
+        mask = mask.to_tensor(shape=shape_after_padding)
+        token_ids = token_ids.to_tensor(
+            shape=shape_after_padding,
+            default_value=self.tokenizer.pad_token_id,
+        )
         if input_is_1d:
             # If the input is a single string, we let the output be a 1D tensor.
             token_ids = tf.squeeze(token_ids, axis=0)
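To see what the new `call()` logic does to a ragged batch, here is a standalone toy sketch (not the library code) using made-up token ids; the first row mirrors the `"airplane at airport"` example from the tests:

```python
import tensorflow as tf

PAD_ID = 0  # "<|endoftext|>" doubles as start, end, and pad token for GPT-2.
SEQUENCE_LENGTH = 8

token_ids = tf.ragged.constant([[2, 4, 5, 3, 6], [2, 4]], dtype=tf.int32)

# Prepend and append the special id to every row of the ragged batch.
start_tokens = tf.fill([token_ids.nrows(), 1], tf.constant(PAD_ID, tf.int32))
end_tokens = tf.fill([token_ids.nrows(), 1], tf.constant(PAD_ID, tf.int32))
token_ids = tf.concat([start_tokens, token_ids, end_tokens], axis=1)

# Densify to a fixed width, filling the tail with the pad id; the mask keeps
# track of which positions held real tokens (including start/end).
mask = tf.ones_like(token_ids, dtype=tf.bool)
dense_ids = token_ids.to_tensor(
    shape=[None, SEQUENCE_LENGTH], default_value=PAD_ID
)
dense_mask = mask.to_tensor(shape=[None, SEQUENCE_LENGTH])

print(dense_ids.numpy())
# [[0 2 4 5 3 6 0 0]
#  [0 2 4 0 0 0 0 0]]
print(dense_mask.numpy().astype(int))
# [[1 1 1 1 1 1 1 0]
#  [1 1 1 1 0 0 0 0]]
```

Because the pad id equals the end-of-text id, only `padding_mask` distinguishes real end tokens from padding downstream.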

keras_nlp/models/gpt2/gpt2_preprocessor_test.py

Lines changed: 26 additions & 3 deletions
@@ -26,7 +26,7 @@

 class GPT2PreprocessorTest(tf.test.TestCase, parameterized.TestCase):
     def setUp(self):
-        vocab = {
+        self.vocab = {
             "<|endoftext|>": 0,
             "!": 1,
             "air": 2,
@@ -45,11 +45,12 @@ def setUp(self):
         merges += ["Ġa t", "p o", "r t", "o h", "l i", "Ġi s", "Ġb e", "s t"]
         merges += ["Ġt h", "ai r", "pl a", "Ġk oh", "Ġth e", "Ġbe st", "po rt"]
         merges += ["Ġai r", "Ġa i", "pla ne"]
+        self.merges = merges

         self.preprocessor = GPT2Preprocessor(
             tokenizer=GPT2Tokenizer(
-                vocabulary=vocab,
-                merges=merges,
+                vocabulary=self.vocab,
+                merges=self.merges,
             ),
             sequence_length=8,
         )
@@ -74,6 +75,28 @@ def test_tokenize_list_of_strings(self):
             output["padding_mask"], [[1, 1, 1, 1, 1, 0, 0, 0]] * 4
         )

+    def test_pad_start_end_token(self):
+        input_data = ["airplane at airport"] * 4
+
+        preprocessor = GPT2Preprocessor(
+            tokenizer=GPT2Tokenizer(
+                vocabulary=self.vocab,
+                merges=self.merges,
+            ),
+            sequence_length=8,
+            add_start_token=True,
+            add_end_token=True,
+        )
+        output = preprocessor(input_data)
+        self.assertAllEqual(
+            output["token_ids"],
+            [[0, 2, 4, 5, 3, 6, 0, 0]] * 4,
+        )
+
+        self.assertAllEqual(
+            output["padding_mask"], [[1, 1, 1, 1, 1, 1, 1, 0]] * 4
+        )
+
     def test_tokenize_labeled_batch(self):
         x = tf.constant(["airplane at airport"] * 4)
         y = tf.constant([1] * 4)

keras_nlp/models/gpt2/gpt2_tokenizer.py

Lines changed: 4 additions & 0 deletions
@@ -112,6 +112,10 @@ def __init__(
         )

         self.end_token_id = self.token_to_id(end_token)
+        # GPT2 uses the same start and pad token as end token, i.e.,
+        # "<|endoftext|>".
+        self.start_token_id = self.end_token_id
+        self.pad_token_id = self.end_token_id

     @classproperty
     def presets(cls):
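With this tokenizer change, the start and pad ids are simply aliases of the end-of-text id. A quick sanity check (a sketch, not part of the commit; it uses the real `gpt2_base_en` preset rather than the toy test vocabulary):

```python
import keras_nlp

tokenizer = keras_nlp.models.GPT2Tokenizer.from_preset("gpt2_base_en")
eot_id = tokenizer.token_to_id("<|endoftext|>")

# After this commit all three special ids resolve to "<|endoftext|>".
assert tokenizer.start_token_id == eot_id
assert tokenizer.end_token_id == eot_id
assert tokenizer.pad_token_id == eot_id
```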
