Skip to content

feat: add flash_attn 2 to bert #27478

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion src/transformers/models/align/modeling_align.py
Original file line number Diff line number Diff line change
Expand Up @@ -890,7 +890,12 @@ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> to
class AlignTextAttention(nn.Module):
def __init__(self, config, position_embedding_type=None):
super().__init__()
self.self = AlignTextSelfAttention(config, position_embedding_type=position_embedding_type)
if not getattr(config, "_flash_attn_2_enabled", False):
self.self = AlignTextSelfAttention(config, position_embedding_type=position_embedding_type)
else:
if config.position_embedding_type != "absolute":
raise NotImplementedError("flash_attn_2 now only supports absolute position embedding")
self.self = AlignTextSelfFlashAttention(config, position_embedding_type=position_embedding_type)
self.output = AlignTextSelfOutput(config)
self.pruned_heads = set()

Expand Down
9 changes: 7 additions & 2 deletions src/transformers/models/altclip/modeling_altclip.py
Original file line number Diff line number Diff line change
Expand Up @@ -441,7 +441,12 @@ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> to
class AltRobertaAttention(nn.Module):
def __init__(self, config, position_embedding_type=None):
super().__init__()
self.self = AltRobertaSelfAttention(config, position_embedding_type=position_embedding_type)
if not getattr(config, "_flash_attn_2_enabled", False):
self.self = AltRobertaSelfAttention(config, position_embedding_type=position_embedding_type)
else:
if config.position_embedding_type != "absolute":
raise NotImplementedError("flash_attn_2 now only supports absolute position embedding")
self.self = AltRobertaSelfFlashAttention(config, position_embedding_type=position_embedding_type)
self.output = AltRobertaSelfOutput(config)
self.pruned_heads = set()

Expand Down Expand Up @@ -1343,7 +1348,7 @@ def forward(
)
encoder_outputs = self.encoder(
embedding_output,
attention_mask=extended_attention_mask,
attention_mask=extended_attention_mask if not getattr(self.config, "_flash_attn_2_enabled", False) else attention_mask,
head_mask=head_mask,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_extended_attention_mask,
Expand Down
149 changes: 147 additions & 2 deletions src/transformers/models/bert/modeling_bert.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,9 +48,13 @@
add_start_docstrings_to_model_forward,
logging,
replace_return_docstrings,
is_flash_attn_2_available
)
from .configuration_bert import BertConfig

if is_flash_attn_2_available():
from flash_attn import flash_attn_func, flash_attn_varlen_func
from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input

logger = logging.get_logger(__name__)

Expand Down Expand Up @@ -103,6 +107,18 @@
# See all BERT models at https://huggingface.co/models?filter=bert
]

# Copied from transformers.models.llama.modeling_llama._get_unpad_data
def _get_unpad_data(attention_mask):
seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
max_seqlen_in_batch = seqlens_in_batch.max().item()
cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0))
return (
indices,
cu_seqlens,
max_seqlen_in_batch,
)


def load_tf_weights_in_bert(model, config, tf_checkpoint_path):
"""Load tf checkpoints in a pytorch model."""
Expand Down Expand Up @@ -374,6 +390,129 @@ def forward(
outputs = outputs + (past_key_value,)
return outputs

class BertSelfFlashAttention(BertSelfAttention):
def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
x = x.view(new_x_shape)
return x

def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.FloatTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
encoder_hidden_states: Optional[torch.FloatTensor] = None,
encoder_attention_mask: Optional[torch.FloatTensor] = None,
past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
output_attentions: Optional[bool] = False,
) -> Tuple[torch.Tensor]:

mixed_query_layer = self.query(hidden_states)
batch_size, query_length, hidden_size = mixed_query_layer.shape

is_cross_attention = encoder_hidden_states is not None

if is_cross_attention and past_key_value is not None:
# reuse k,v, cross_attentions
key_layer = past_key_value[0]
value_layer = past_key_value[1]
attention_mask = encoder_attention_mask
elif is_cross_attention:
key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))
value_layer = self.transpose_for_scores(self.value(encoder_hidden_states))
attention_mask = encoder_attention_mask
elif past_key_value is not None:
key_layer = self.transpose_for_scores(self.key(hidden_states))
value_layer = self.transpose_for_scores(self.value(hidden_states))
key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
else:
key_layer = self.transpose_for_scores(self.key(hidden_states))
value_layer = self.transpose_for_scores(self.value(hidden_states))

query_layer = self.transpose_for_scores(mixed_query_layer)

use_cache = past_key_value is not None
if self.is_decoder:
past_key_value = (key_layer, value_layer)

if attention_mask is None:
attention_scores = flash_attn_func(
query_layer, key_layer, value_layer, self.dropout.p, softmax_scale=None, causal=self.is_decoder
)
else:
query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input(
query_layer, key_layer, value_layer, attention_mask, query_length
)

cu_seqlens_q, cu_seqlens_k = cu_seq_lens
max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens

attn_output_unpad = flash_attn_varlen_func(
query_states,
key_states,
value_states,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_k=cu_seqlens_k,
max_seqlen_q=max_seqlen_in_batch_q,
max_seqlen_k=max_seqlen_in_batch_k,
dropout_p=self.dropout.p,
softmax_scale=None,
causal=self.is_decoder,
)

attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)

attn_output = attn_output.reshape(batch_size, query_length, hidden_size).contiguous()

if output_attentions:
raise NotImplementedError("output_attentions is not implemented")

outputs = (attn_output, None)
if self.is_decoder:
outputs = outputs + (past_key_value,)
return outputs


def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length):
indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)

batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape

key_layer = index_first_axis(
key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
)
value_layer = index_first_axis(
value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
)
if query_length == kv_seq_len:
query_layer = index_first_axis(
query_layer.reshape(batch_size * kv_seq_len, self.num_attention_heads, head_dim), indices_k
)
cu_seqlens_q = cu_seqlens_k
max_seqlen_in_batch_q = max_seqlen_in_batch_k
indices_q = indices_k
elif query_length == 1:
max_seqlen_in_batch_q = 1
cu_seqlens_q = torch.arange(
batch_size + 1, dtype=torch.int32, device=query_layer.device
) # There is a memcpy here, that is very bad.
indices_q = cu_seqlens_q[:-1]
query_layer = query_layer.squeeze(1)
else:
# The -q_len: slice assumes left padding.
attention_mask = attention_mask[:, -query_length:]
query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask)

return (
query_layer,
key_layer,
value_layer,
indices_q,
(cu_seqlens_q, cu_seqlens_k),
(max_seqlen_in_batch_q, max_seqlen_in_batch_k),
)


class BertSelfOutput(nn.Module):
def __init__(self, config):
Expand All @@ -392,7 +531,12 @@ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> to
class BertAttention(nn.Module):
def __init__(self, config, position_embedding_type=None):
super().__init__()
self.self = BertSelfAttention(config, position_embedding_type=position_embedding_type)
if not getattr(config, "_flash_attn_2_enabled", False):
self.self = BertSelfAttention(config, position_embedding_type=position_embedding_type)
else:
if config.position_embedding_type != "absolute":
raise NotImplementedError("flash_attn_2 now only supports absolute position embedding")
self.self = BertSelfFlashAttention(config, position_embedding_type=position_embedding_type)
self.output = BertSelfOutput(config)
self.pruned_heads = set()

Expand Down Expand Up @@ -740,6 +884,7 @@ class BertPreTrainedModel(PreTrainedModel):
load_tf_weights = load_tf_weights_in_bert
base_model_prefix = "bert"
supports_gradient_checkpointing = True
_supports_flash_attn_2 = True

def _init_weights(self, module):
"""Initialize the weights"""
Expand Down Expand Up @@ -1012,7 +1157,7 @@ def forward(
)
encoder_outputs = self.encoder(
embedding_output,
attention_mask=extended_attention_mask,
attention_mask=extended_attention_mask if not getattr(self.config, "_flash_attn_2_enabled", False) else attention_mask,
head_mask=head_mask,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_extended_attention_mask,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -196,7 +196,12 @@ def forward(
class BertGenerationAttention(nn.Module):
def __init__(self, config, position_embedding_type=None):
super().__init__()
self.self = BertGenerationSelfAttention(config, position_embedding_type=position_embedding_type)
if not getattr(config, "_flash_attn_2_enabled", False):
self.self = BertGenerationSelfAttention(config, position_embedding_type=position_embedding_type)
else:
if config.position_embedding_type != "absolute":
raise NotImplementedError("flash_attn_2 now only supports absolute position embedding")
self.self = BertGenerationSelfFlashAttention(config, position_embedding_type=position_embedding_type)
self.output = BertGenerationSelfOutput(config)
self.pruned_heads = set()

Expand Down
9 changes: 7 additions & 2 deletions src/transformers/models/bridgetower/modeling_bridgetower.py
Original file line number Diff line number Diff line change
Expand Up @@ -569,7 +569,12 @@ def forward(
class BridgeTowerAttention(nn.Module):
def __init__(self, config, position_embedding_type=None):
super().__init__()
self.self = BridgeTowerSelfAttention(config, position_embedding_type=position_embedding_type)
if not getattr(config, "_flash_attn_2_enabled", False):
self.self = BridgeTowerSelfAttention(config, position_embedding_type=position_embedding_type)
else:
if config.position_embedding_type != "absolute":
raise NotImplementedError("flash_attn_2 now only supports absolute position embedding")
self.self = BridgeTowerSelfFlashAttention(config, position_embedding_type=position_embedding_type)
self.output = BridgeTowerSelfOutput(config)
self.pruned_heads = set()

Expand Down Expand Up @@ -1168,7 +1173,7 @@ def forward(
)
encoder_outputs = self.encoder(
embedding_output,
attention_mask=extended_attention_mask,
attention_mask=extended_attention_mask if not getattr(self.config, "_flash_attn_2_enabled", False) else attention_mask,
head_mask=head_mask,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_extended_attention_mask,
Expand Down
9 changes: 7 additions & 2 deletions src/transformers/models/camembert/modeling_camembert.py
Original file line number Diff line number Diff line change
Expand Up @@ -319,7 +319,12 @@ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> to
class CamembertAttention(nn.Module):
def __init__(self, config, position_embedding_type=None):
super().__init__()
self.self = CamembertSelfAttention(config, position_embedding_type=position_embedding_type)
if not getattr(config, "_flash_attn_2_enabled", False):
self.self = CamembertSelfAttention(config, position_embedding_type=position_embedding_type)
else:
if config.position_embedding_type != "absolute":
raise NotImplementedError("flash_attn_2 now only supports absolute position embedding")
self.self = CamembertSelfFlashAttention(config, position_embedding_type=position_embedding_type)
self.output = CamembertSelfOutput(config)
self.pruned_heads = set()

Expand Down Expand Up @@ -887,7 +892,7 @@ def forward(
)
encoder_outputs = self.encoder(
embedding_output,
attention_mask=extended_attention_mask,
attention_mask=extended_attention_mask if not getattr(self.config, "_flash_attn_2_enabled", False) else attention_mask,
head_mask=head_mask,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_extended_attention_mask,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -360,7 +360,12 @@ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> to
class ChineseCLIPTextAttention(nn.Module):
def __init__(self, config, position_embedding_type=None):
super().__init__()
self.self = ChineseCLIPTextSelfAttention(config, position_embedding_type=position_embedding_type)
if not getattr(config, "_flash_attn_2_enabled", False):
self.self = ChineseCLIPTextSelfAttention(config, position_embedding_type=position_embedding_type)
else:
if config.position_embedding_type != "absolute":
raise NotImplementedError("flash_attn_2 now only supports absolute position embedding")
self.self = ChineseCLIPTextSelfFlashAttention(config, position_embedding_type=position_embedding_type)
self.output = ChineseCLIPTextSelfOutput(config)
self.pruned_heads = set()

Expand Down
9 changes: 7 additions & 2 deletions src/transformers/models/clap/modeling_clap.py
Original file line number Diff line number Diff line change
Expand Up @@ -1383,7 +1383,12 @@ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> to
class ClapTextAttention(nn.Module):
def __init__(self, config, position_embedding_type=None):
super().__init__()
self.self = ClapTextSelfAttention(config, position_embedding_type=position_embedding_type)
if not getattr(config, "_flash_attn_2_enabled", False):
self.self = ClapTextSelfAttention(config, position_embedding_type=position_embedding_type)
else:
if config.position_embedding_type != "absolute":
raise NotImplementedError("flash_attn_2 now only supports absolute position embedding")
self.self = ClapTextSelfFlashAttention(config, position_embedding_type=position_embedding_type)
self.output = ClapTextSelfOutput(config)
self.pruned_heads = set()

Expand Down Expand Up @@ -1891,7 +1896,7 @@ def forward(
)
encoder_outputs = self.encoder(
embedding_output,
attention_mask=extended_attention_mask,
attention_mask=extended_attention_mask if not getattr(self.config, "_flash_attn_2_enabled", False) else attention_mask,
head_mask=head_mask,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_extended_attention_mask,
Expand Down
9 changes: 7 additions & 2 deletions src/transformers/models/data2vec/modeling_data2vec_text.py
Original file line number Diff line number Diff line change
Expand Up @@ -305,7 +305,12 @@ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> to
class Data2VecTextAttention(nn.Module):
def __init__(self, config, position_embedding_type=None):
super().__init__()
self.self = Data2VecTextSelfAttention(config, position_embedding_type=position_embedding_type)
if not getattr(config, "_flash_attn_2_enabled", False):
self.self = Data2VecTextSelfAttention(config, position_embedding_type=position_embedding_type)
else:
if config.position_embedding_type != "absolute":
raise NotImplementedError("flash_attn_2 now only supports absolute position embedding")
self.self = Data2VecTextSelfFlashAttention(config, position_embedding_type=position_embedding_type)
self.output = Data2VecTextSelfOutput(config)
self.pruned_heads = set()

Expand Down Expand Up @@ -836,7 +841,7 @@ def forward(
)
encoder_outputs = self.encoder(
embedding_output,
attention_mask=extended_attention_mask,
attention_mask=extended_attention_mask if not getattr(self.config, "_flash_attn_2_enabled", False) else attention_mask,
head_mask=head_mask,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_extended_attention_mask,
Expand Down
7 changes: 6 additions & 1 deletion src/transformers/models/electra/modeling_electra.py
Original file line number Diff line number Diff line change
Expand Up @@ -366,7 +366,12 @@ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> to
class ElectraAttention(nn.Module):
def __init__(self, config, position_embedding_type=None):
super().__init__()
self.self = ElectraSelfAttention(config, position_embedding_type=position_embedding_type)
if not getattr(config, "_flash_attn_2_enabled", False):
self.self = ElectraSelfAttention(config, position_embedding_type=position_embedding_type)
else:
if config.position_embedding_type != "absolute":
raise NotImplementedError("flash_attn_2 now only supports absolute position embedding")
self.self = ElectraSelfFlashAttention(config, position_embedding_type=position_embedding_type)
self.output = ElectraSelfOutput(config)
self.pruned_heads = set()

Expand Down
7 changes: 6 additions & 1 deletion src/transformers/models/ernie/modeling_ernie.py
Original file line number Diff line number Diff line change
Expand Up @@ -301,7 +301,12 @@ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> to
class ErnieAttention(nn.Module):
def __init__(self, config, position_embedding_type=None):
super().__init__()
self.self = ErnieSelfAttention(config, position_embedding_type=position_embedding_type)
if not getattr(config, "_flash_attn_2_enabled", False):
self.self = ErnieSelfAttention(config, position_embedding_type=position_embedding_type)
else:
if config.position_embedding_type != "absolute":
raise NotImplementedError("flash_attn_2 now only supports absolute position embedding")
self.self = ErnieSelfFlashAttention(config, position_embedding_type=position_embedding_type)
self.output = ErnieSelfOutput(config)
self.pruned_heads = set()

Expand Down
7 changes: 6 additions & 1 deletion src/transformers/models/git/modeling_git.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,7 +273,12 @@ class GitAttention(nn.Module):
# Copied from transformers.models.bert.modeling_bert.BertAttention.__init__ with Bert->Git
def __init__(self, config, position_embedding_type=None):
super().__init__()
self.self = GitSelfAttention(config, position_embedding_type=position_embedding_type)
if not getattr(config, "_flash_attn_2_enabled", False):
self.self = GitSelfAttention(config, position_embedding_type=position_embedding_type)
else:
if config.position_embedding_type != "absolute":
raise NotImplementedError("flash_attn_2 now only supports absolute position embedding")
self.self = GitSelfFlashAttention(config, position_embedding_type=position_embedding_type)
self.output = GitSelfOutput(config)
self.pruned_heads = set()

Expand Down
Loading