Add ALBERT Backbone #622
Conversation
Thanks! One main comment.
)(x)

albert_group_layers = [
    AlbertGroupLayer(
This feels like more indirection than we need here... Couldn't you just do something like nested for loops here?
for i in range(num_hidden_groups):
    for j in range(num_layers_per_group):
        # Set up the transformer layer.
        # Call the transformer layer.
Is there something I am missing here? AlbertGroupLayer looks like it is just calling an encoder block successively.
The alternative here is to have a nested list. Something like this:
[[group1_layer1, group1_layer2], [group2_layer1, group2_layer2], ...].
Initially, I'd wanted to do this, but thought that a separate layer might be better than a nested list. I can change it to a nested list.
group_transformer_layers = []
for i in range(num_hidden_groups):
    transformer_layers = [
        TransformerEncoder(..., name=f"group_{i}_transformer_{j}")
        for j in range(num_layers_per_group)
    ]
    group_transformer_layers.append(transformer_layers)
We can't do the forward pass here...we need to have a separate loop for that.
Edit: Put it in the same nested loop.
for i in range(num_layers):
    # Index of the hidden group
    group_idx = int(i / (num_layers / num_hidden_groups))
    x = albert_group_layers[group_idx](x, padding_mask)
I would generally kwarg padding_mask=padding_mask here rather than rely on arg ordering.
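For example (a sketch using the names from the diff above):
x = albert_group_layers[group_idx](x, padding_mask=padding_mask)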
from keras_nlp.layers.transformer_encoder import TransformerEncoder


class AlbertGroupLayer(keras.layers.Layer):
It seems to me like we could remove this (and definitely preference to do so if we can). But I may be missing something here.
@keras.utils.register_keras_serializable(package="keras_nlp")
class AlbertBackbone(keras.Model):
Depending on timing of this landing vs #621, you may want to update this to follow the new base class.
)(x)

layer_idx = 0
for i in range(num_groups):
This loop is pretty inscrutable. Can we at least offer a comment? I particularly don't understand the while loop logic and wonder if there's a more readable approach.
I was looking at the original implementation; I think we have overparameterized. We just want to take in the number of groups and split the layers evenly among them, so that any remainder layers are distributed among the first groups.
I tried to write this out as readably as possible in a nested loop.
for group_idx in range(num_groups):
    # Create a single encoder block with shared weights for the group.
    shared_encoder = TransformerEncoder(...)
    # Split our total layers evenly among groups.
    layers_per_group = num_layers // num_groups
    # Any remainder layers go into the earlier groups.
    if group_idx < num_layers % num_groups:
        layers_per_group += 1
    # Apply our shared encoder block once per layer in the group.
    for _ in range(layers_per_group):
        x = shared_encoder(x, padding_mask=padding_mask)
@abheesht17 would this work?
Hmmm, is this what's happening? 🤔
They have an argument called "inner_group_num" (which I have renamed as num_layers_per_group because inner_group_num isn't exactly easy to understand). Check this notebook out; the graph is similar to what I have done. And the number of parameters is the same for both.
https://colab.research.google.com/drive/17m3ozVYBTuodsSRxb8ycn9g37i6q48gc?usp=sharing
@jbischof, yeah. The alternative is to have a separate loop for this. Let me know which one you guys prefer. Something like:
group_layers = []
for i in range(num_groups):
    transformer_layers = [
        TransformerEncoder(
            num_heads=num_heads,
            intermediate_dim=intermediate_dim,
            activation=lambda x: keras.activations.gelu(
                x, approximate=True
            ),
            dropout=dropout,
            kernel_initializer=albert_kernel_initializer(),
            name=f"group_{i}_transformer_layer_{j}",
        )
        for j in range(num_layers_per_group)
    ]
    group_layers.append(transformer_layers)

for layer_idx in range(num_layers):
    group_idx = int(layer_idx / (num_layers / num_groups))
    transformer_layers = group_layers[group_idx]
    for transformer_layer in transformer_layers:
        x = transformer_layer(x, padding_mask=padding_mask)
Interesting! I missed the "inner_group_num", but that seems fairly separate? A single "layer" according to the num_layers argument is just a stack of our encoders, right? I think my solution could still stand, just slightly expanded...
def get_layer():
    # A "layer" in ALBERT terminology is actually any number of repeated attention and
    # feed-forward blocks, controlled by the `inner_group_num` parameter.
    layer = keras.Sequential()
    for _ in range(inner_group_num):
        layer.add(TransformerEncoder(...))
    return layer

for group_idx in range(num_groups):
    # Create a single encoder block with shared weights for the group.
    shared_encoder = get_layer()
    # Split our total layers evenly among groups.
    layers_per_group = num_layers // num_groups
    # Any remainder layers go into the earlier groups.
    if group_idx < num_layers % num_groups:
        layers_per_group += 1
    # Apply our shared encoder block once per layer in the group.
    for _ in range(layers_per_group):
        x = shared_encoder(x, padding_mask=padding_mask)
Maybe I am still missing something though, let's sync tomorrow!
Hmmm, not quite. There is a difference between the two:
>>> for group_idx in range(5):
...     layers_per_group = 12 // 5
...     if group_idx < 12 % 5:
...         layers_per_group += 1
...     for _ in range(layers_per_group):
...         print(group_idx, end=", ")
...
0, 0, 0, 1, 1, 1, 2, 2, 3, 3, 4, 4,
>>> for layer_idx in range(12):
...     group_idx = int(layer_idx / (12 / 5))
...     print(group_idx, end=", ")
...
0, 0, 0, 1, 1, 2, 2, 2, 3, 3, 4, 4,
Thanks! This might be it?
layer_idx = 0
for group_idx in range(num_groups):
    # Create a single block with shared weights for the group.
    shared_block = get_shared_block()
    # Apply the shared block until the fraction of total layers exceeds
    # the fractional boundary of the next group.
    while layer_idx / num_layers < (group_idx + 1) / num_groups:
        x = shared_block(x, padding_mask=padding_mask)
        layer_idx += 1
Yep, this should work!
Also,
while layer_idx / num_layers < (group_idx + 1) / num_groups:
    x = shared_block(x, padding_mask=padding_mask)
    layer_idx += 1
is the same as what is on the PR currently, if we just shift num_groups to the other side:
while int(layer_idx / (num_layers / num_groups)) == group_idx
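For what it's worth, here is a quick sanity check that the two conditions pick the same group order (just a sketch, reusing num_layers = 12 and num_groups = 5 from above; the list names are only for this check):
num_layers, num_groups = 12, 5

order_a = []
layer_idx = 0
for group_idx in range(num_groups):
    # Fractional-boundary form from the comment above.
    while layer_idx / num_layers < (group_idx + 1) / num_groups:
        order_a.append(group_idx)
        layer_idx += 1

order_b = []
layer_idx = 0
for group_idx in range(num_groups):
    # Form currently on the PR, with num_groups shifted to the other side.
    while int(layer_idx / (num_layers / num_groups)) == group_idx:
        order_b.append(group_idx)
        layer_idx += 1

print(order_a == order_b)  # True: both give 0, 0, 0, 1, 1, 2, 2, 2, 3, 3, 4, 4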
        The hidden size must be divisible by the number of attention heads.
    num_hidden_groups: int. Number of groups, with each group having a
        certain number of Transformer layers.
    num_layers_per_group: int. Number of Transformer layers per group.
Can we remove this? The original implementation does not seem to have it, and it seems like we are overparameterized here. Let's just take in num_layers and num_groups.
https://github.com/google-research/albert/blob/master/modeling.py#L47
I think this could also lead to a simplification of the loop as well, will comment below.
Same comment as #622 (comment).
Sorry, a few more comments! We're getting super close.
    name="embedding_projection",
)(x)

def get_group_layer(group_idx):
Would this be more readable if we extract this as a keras.Layer at the top of this file? Otherwise we have a funky closure in the middle of the forward pass.
Ah, this may be my bad, I recommended this (it's a style used by KerasCV and Keras Applications somewhat frequently). Ideally, we could just use a keras.Sequential here, which would be super readable. But there is no way to pass a padding mask to a Sequential.
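For instance, something like this would never get the mask to the encoders (just a sketch; num_inner_repetitions and the encoder args are borrowed from the other snippets in this thread):
group_layer = keras.Sequential(
    [
        TransformerEncoder(
            num_heads=num_heads,
            intermediate_dim=intermediate_dim,
        )
        for _ in range(num_inner_repetitions)
    ]
)
x = group_layer(x)  # works, but the padding mask never reaches the encoders
# x = group_layer(x, padding_mask=padding_mask)  # Sequential can't forward this kwarg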
@jbischof, yeah, that's what I'd initially done in the first few commits. But we want to avoid making extra layers whenever possible; I discussed this with @mattdangerw. And secondly, we would have to pass a bunch of args to the layer in the for loop, which is something we can avoid.
This snippet, which is succinct and small:
for group_idx in range(num_groups):
    # Define the group. A group in ALBERT terminology is any number of
    # repeated attention and FFN blocks.
    group_layer = get_group_layer(group_idx)
    # Assume num_layers = 8, num_groups = 5. Then, the order of group
    # calls will be 0, 0, 1, 1, 2, 3, 3, 4.
    while int(layer_idx / num_calls_per_group) == group_idx:
        x = group_layer(x, padding_mask=padding_mask)
        layer_idx += 1
will become this:
for group_idx in range(num_groups):
    # Define the group. A group in ALBERT terminology is any number of
    # repeated attention and FFN blocks.
    group_layer = GroupLayer(
        group_idx,
        num_heads=num_heads,
        intermediate_dim=intermediate_dim,
        dropout=dropout,
        ...
    )
    # Assume num_layers = 8, num_groups = 5. Then, the order of group
    # calls will be 0, 0, 1, 1, 2, 3, 3, 4.
    while int(layer_idx / num_calls_per_group) == group_idx:
        x = group_layer(x, padding_mask=padding_mask)
        layer_idx += 1
Or something similar, anyway.
On second thoughts, doesn't look too bad, but yeah, the first point still stands.
OK, so be it
There is this, which has the advantage of showing up in a summary/plot_model call as a single entity.
def get_group_layer(group_idx):
    """Stack `num_inner_repetitions` encoder blocks as a single layer."""
    outputs = inputs = keras.Input(shape=(None, hidden_dim))
    for _ in range(num_inner_repetitions):
        outputs = TransformerEncoder(
            num_heads=num_heads,
            intermediate_dim=intermediate_dim,
            activation=lambda x: keras.activations.gelu(
                x, approximate=True
            ),
            dropout=dropout,
            kernel_initializer=albert_kernel_initializer(),
        )(outputs, padding_mask=padding_mask)
    return keras.Model(
        (inputs, padding_mask), (outputs), name=f"group_{group_idx}",
    )
It needs to be called with a tuple: group_layer((x, padding_mask)).
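The group loop would then look something like this (a sketch combining the snippet above with the earlier while loop):
layer_idx = 0
for group_idx in range(num_groups):
    group_layer = get_group_layer(group_idx)
    while int(layer_idx / num_calls_per_group) == group_idx:
        # Note the tuple call instead of a padding_mask kwarg.
        x = group_layer((x, padding_mask))
        layer_idx += 1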
I'm also totally happy to just quit and remove num_groups and num_inner_repetitions entirely if we want to 😈
Can do the functional approach! Currently, this is how model.summary() looks. I feel this isn't too bad, especially because the arguments are sort of confusing. This helps clarify which arg does what, especially num_groups and num_inner_repetitions.
Model: "albert_backbone"
__________________________________________________________________________________________________
 Layer (type)                    Output Shape         Param #   Connected to
==================================================================================================
 token_ids (InputLayer)          [(None, None)]       0         []
 token_embedding (Embedding)     (None, None, 128)    3840000   ['token_ids[0][0]']
 segment_ids (InputLayer)        [(None, None)]       0         []
 position_embedding              (None, None, 128)    65536     ['token_embedding[0][0]']
 (PositionEmbedding)
 segment_embedding (Embedding)   (None, None, 128)    256       ['segment_ids[0][0]']
 add (Add)                       (None, None, 128)    0         ['token_embedding[0][0]',
                                                                  'position_embedding[0][0]',
                                                                  'segment_embedding[0][0]']
 embeddings_layer_norm           (None, None, 128)    256       ['add[0][0]']
 (LayerNormalization)
 embeddings_dropout (Dropout)    (None, None, 128)    0         ['embeddings_layer_norm[0][0]']
 embedding_projection (Dense)    (None, None, 768)    99072     ['embeddings_dropout[0][0]']
 padding_mask (InputLayer)       [(None, None)]       0         []
 group_0_transformer_layer_0     (None, None, 768)    7087872   ['embedding_projection[0][0]',
 (TransformerEncoder)                                            'padding_mask[0][0]',
                                                                  'group_0_transformer_layer_0[0][0]',
                                                                  'padding_mask[0][0]',
                                                                  'group_0_transformer_layer_0[1][0]',
                                                                  'padding_mask[0][0]',
                                                                  'group_0_transformer_layer_0[2][0]',
                                                                  'padding_mask[0][0]',
                                                                  'group_0_transformer_layer_0[3][0]',
                                                                  'padding_mask[0][0]',
                                                                  'group_0_transformer_layer_0[4][0]',
                                                                  'padding_mask[0][0]',
                                                                  'group_0_transformer_layer_0[5][0]',
                                                                  'padding_mask[0][0]',
                                                                  'group_0_transformer_layer_0[6][0]',
                                                                  'padding_mask[0][0]',
                                                                  'group_0_transformer_layer_0[7][0]',
                                                                  'padding_mask[0][0]',
                                                                  'group_0_transformer_layer_0[8][0]',
                                                                  'padding_mask[0][0]',
                                                                  'group_0_transformer_layer_0[9][0]',
                                                                  'padding_mask[0][0]',
                                                                  'group_0_transformer_layer_0[10][0]',
                                                                  'padding_mask[0][0]']
 tf.__operators__.getitem        (None, 768)          0         ['group_0_transformer_layer_0[11][0]']
 (SlicingOpLambda)
 pooled_dense (Dense)            (None, 768)          590592    ['tf.__operators__.getitem[0][0]']
==================================================================================================
Total params: 11,683,584
Trainable params: 11,683,584
Non-trainable params: 0
__________________________________________________________________________________________________
    # repeated attention and FFN blocks.
    group_layer = get_group_layer(group_idx)

    # Assume num_layers = 8, num_groups = 5. Then, the order of group
int(8 / 5) = 1, so I don't see how your example works.
We're not really testing this logic in the test suite so we have to be super careful here.
ALBERT has weird logic...they don't int-ify "num_calls_per_group". It's just 8 / 5 = 1.6. That's why I have this comment: https://github.com/keras-team/keras-nlp/pull/622/files#diff-df72ed88a450bf9293e3b81087b9ee43aa547a07372043c589c126f8a0141aaeR198-R199.
num_layers = 8
num_groups = 5
num_calls_per_group = num_layers / num_groups
layer_idx = 0
for group_idx in range(num_groups):
    # Define the group. A group in ALBERT terminology is any number of
    # repeated attention and FFN blocks.
    # group_layer = get_group_layer(group_idx)
    # Assume num_layers = 8, num_groups = 5. Then, the order of group
    # calls will be 0, 0, 1, 1, 2, 3, 3, 4.
    while int(layer_idx / num_calls_per_group) == group_idx:
        # x = group_layer(x, padding_mask=padding_mask)
        print(group_idx)
        layer_idx += 1
gives
0
0
1
1
2
3
3
4
Ok LGTM! I really appreciate the hard work @abheesht17 and sorry I didn't understand all the nuance. I'll let you and @mattdangerw make the final call.
This looks good to me! Just a few last comments
if num_layers % num_groups != 0:
    raise ValueError(
        "`num_layers` must be divisible by `num_groups`. Received "
        f"`num_layers` = {num_layers}` and `num_groups` = {num_groups}."
Style nits
"`num_layers` must be divisible by `num_groups`. Received: "
f"`num_layers={num_layers}` and `num_groups={num_groups}`."
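i.e. the full check would read:
if num_layers % num_groups != 0:
    raise ValueError(
        "`num_layers` must be divisible by `num_groups`. Received: "
        f"`num_layers={num_layers}` and `num_groups={num_groups}`."
    )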
vocabulary_size,
num_layers,
num_heads,
num_groups,
Let's give num_groups and num_inner_repetitions defaults of 1. That will be a good way to indicate to users how this is used in practice (this means we should also move them below the dim arguments).
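Roughly something like this for the constructor (a sketch; the dim and dropout arguments are just taken from other snippets in this thread, and the default values other than the two 1s are placeholders):
def __init__(
    self,
    vocabulary_size,
    num_layers,
    num_heads,
    hidden_dim,
    intermediate_dim,
    num_groups=1,
    num_inner_repetitions=1,
    dropout=0.1,
    **kwargs,
):
    ...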
),
dropout=dropout,
kernel_initializer=albert_kernel_initializer(),
name=f"group_{group_idx}_transformer_layer_{inner_repetition_idx}",
Let's name this:
group_{group_idx}_inner_layer_{inner_repetition_idx}
We want the link between num_inner_repetitions and the "inner" index to be clear when viewing a model summary.
Resolves #615
This PR implements the ALBERT model graph.
Checkpoint Conversion Colab Notebook: https://colab.research.google.com/drive/1JkGOSQh5cg7u7Y2K503Zmv2EjsYN9UF3?usp=sharing
Checkpoint Source: https://huggingface.co/albert-base-v2
The original repository uses TF 1.x, and I am not entirely familiar with TF 1.x (checkpoint loading, etc. might prove to be painful, considering we are using TF 2.x for KerasNLP). Hence, I've used Hugging Face's model.
Notes: