
[BetterTransformer] Transformers 4.41.0 (torch SDPA for BERT) breaks BetterTransformer BERT, but it works in Transformers 4.40.2 #1902

Closed
@michaelfeil

Description


System Info

Python 3.11
Windows / WSL2
Poetry venv


Reproduction (minimal, reproducible, runnable)

Install torch==2.3.1 and transformers==4.41.0 (or transformers==4.40.2 for the fix).

from optimum.bettertransformer import BetterTransformer
from transformers import AutoModel, AutoTokenizer
import torch
model_name = "michaelfeil/bge-small-en-v1.5"

model = AutoModel.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)


model = BetterTransformer.transform(model) # COMMENT OUT TO MAKE IT WORK.

model = model.cuda()

for num_sentences in [1, 10, 100]:
    sentences = [f"This is sentence number {i * 10}" for i in range(num_sentences)]
    inputs = tokenizer(sentences, return_tensors="pt", padding=True, truncation=True).to(model.device)
    outputs = model(**inputs)
    print(outputs.last_hidden_state.shape)

Output:

The BetterTransformer implementation does not support padding during training, as the fused kernels do not support attention masks. Beware that passing padded batched data during training may result in unexpected outputs. Please refer to https://huggingface.co/docs/optimum/bettertransformer/overview for more details.
torch.Size([1, 7, 384])
torch.Size([10, 7, 384])
Traceback (most recent call last):
  File "/usr/lib/python3.10/runpy.py", line 196, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/usr/lib/python3.10/runpy.py", line 86, in _run_code
    exec(code, run_globals)
  File "/home/michael/.vscode-server/extensions/ms-python.debugpy-2024.6.0-linux-x64/bundled/libs/debugpy/adapter/../../debugpy/launcher/../../debugpy/__main__.py", line 39, in <module>
    cli.main()
  File "/home/michael/.vscode-server/extensions/ms-python.debugpy-2024.6.0-linux-x64/bundled/libs/debugpy/adapter/../../debugpy/launcher/../../debugpy/../debugpy/server/cli.py", line 430, in main
    run()
  File "/home/michael/.vscode-server/extensions/ms-python.debugpy-2024.6.0-linux-x64/bundled/libs/debugpy/adapter/../../debugpy/launcher/../../debugpy/../debugpy/server/cli.py", line 284, in run_file
    runpy.run_path(target, run_name="__main__")
  File "/home/michael/.vscode-server/extensions/ms-python.debugpy-2024.6.0-linux-x64/bundled/libs/debugpy/_vendored/pydevd/_pydevd_bundle/pydevd_runpy.py", line 321, in run_path
    return _run_module_code(code, init_globals, run_name,
  File "/home/michael/.vscode-server/extensions/ms-python.debugpy-2024.6.0-linux-x64/bundled/libs/debugpy/_vendored/pydevd/_pydevd_bundle/pydevd_runpy.py", line 135, in _run_module_code
    _run_code(code, mod_globals, init_globals,
  File "/home/michael/.vscode-server/extensions/ms-python.debugpy-2024.6.0-linux-x64/bundled/libs/debugpy/_vendored/pydevd/_pydevd_bundle/pydevd_runpy.py", line 124, in _run_code
    exec(code, run_globals)
  File "/home/michael/infinity/bert.py", line 16, in <module>
    outputs = model(**inputs)
  File "/home/michael/infinity/libs/infinity_emb/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/michael/infinity/libs/infinity_emb/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/michael/infinity/libs/infinity_emb/.venv/lib/python3.10/site-packages/transformers/models/bert/modeling_bert.py", line 1137, in forward
    encoder_outputs = self.encoder(
  File "/home/michael/infinity/libs/infinity_emb/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/michael/infinity/libs/infinity_emb/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/michael/infinity/libs/infinity_emb/.venv/lib/python3.10/site-packages/transformers/models/bert/modeling_bert.py", line 690, in forward
    layer_outputs = layer_module(
  File "/home/michael/infinity/libs/infinity_emb/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/michael/infinity/libs/infinity_emb/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/michael/infinity/libs/infinity_emb/.venv/lib/python3.10/site-packages/optimum/bettertransformer/models/encoder_models.py", line 300, in forward
    attention_mask = torch.reshape(attention_mask, (attention_mask.shape[0], attention_mask.shape[-1]))
RuntimeError: shape '[100, 8]' is invalid for input of size 6400
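The numbers in the error point at the mask shape rather than the data. A likely reading, assuming Transformers 4.41 builds BERT's attention mask for SDPA as a 4D tensor of shape (batch, 1, seq_len, seq_len) whenever the batch is padded, while the BetterTransformer encoder layer still expects the older (batch, 1, 1, seq_len) extended mask: 100 x 8 x 8 = 6400 elements cannot be reshaped to [100, 8]. For the 1- and 10-sentence batches presumably nothing is padded, so the SDPA path drops the mask and the reshape is never reached. The following sketch reproduces only the reshape arithmetic with dummy tensors; `padding_mask`, `eager_mask` and `sdpa_mask` are hypothetical stand-ins, not the tensors Transformers actually builds.

```python
import torch

batch, seq_len = 100, 8

# Hypothetical padding mask: 1 = real token, 0 = padding (last position padded in every row).
padding_mask = torch.ones(batch, seq_len)
padding_mask[:, -1] = 0

# Older "extended" mask layout: (batch, 1, 1, seq_len) -> 800 elements.
eager_mask = padding_mask[:, None, None, :]

# Assumed SDPA layout: the same key-padding row repeated for every query,
# (batch, 1, seq_len, seq_len) -> 6400 elements.
sdpa_mask = padding_mask[:, None, None, :].expand(batch, 1, seq_len, seq_len)


def to_2d(mask: torch.Tensor) -> torch.Tensor:
    # Same reshape as in the traceback: (shape[0], shape[-1]).
    return torch.reshape(mask, (mask.shape[0], mask.shape[-1]))


print(eager_mask.numel(), sdpa_mask.numel())  # 800 6400
print(to_2d(eager_mask).shape)  # torch.Size([100, 8])

try:
    to_2d(sdpa_mask)
except RuntimeError as err:
    print(err)  # shape '[100, 8]' is invalid for input of size 6400

# For a non-causal mask every query row is identical, so a single row
# recovers the 2D key-padding mask.
recovered = sdpa_mask[:, 0, 0, :]
print(torch.equal(recovered, padding_mask))  # True
```

Slicing one query row is only valid under the non-causal assumption above; it is an illustration of the mismatch, not the fix adopted in optimum.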

Solution (workaround): downgrade transformers

Installing dependencies from lock file
Package operations: 0 installs, 1 update, 0 removals
  • Downgrading transformers (4.41.2 -> 4.40.2)

With transformers 4.40.2, all three batch sizes work:

torch.Size([1, 7, 384])
torch.Size([10, 7, 384])
torch.Size([100, 8, 384])

Expected behavior

BetterTransformer should keep working with Transformers 4.41: it is still about 1.5x faster than torch SDPA. For now I'm stuck pinning huggingface transformers to <=4.40.2.
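An alternative to pinning might be to request eager attention when loading the model, so BERT keeps building the older (batch, 1, 1, seq_len) mask that BetterTransformer expects. This is an untested sketch: `bettertransformer_load_kwargs` is a hypothetical helper, and the assumption that `attn_implementation="eager"` restores BetterTransformer's behavior on 4.41 has not been verified here. It only inspects the major and minor version, so strings like "4.41.0.dev0" are handled.

```python
def bettertransformer_load_kwargs(transformers_version: str) -> dict:
    """Hypothetical helper: extra from_pretrained kwargs to pass before BetterTransformer.transform.

    Assumption (untested): from transformers 4.41 on, BertModel defaults to SDPA
    attention, whose 4D mask BetterTransformer cannot reshape; eager attention
    keeps the older (batch, 1, 1, seq_len) mask.
    """
    major, minor = (int(part) for part in transformers_version.split(".")[:2])
    if (major, minor) >= (4, 41):
        return {"attn_implementation": "eager"}
    return {}


print(bettertransformer_load_kwargs("4.40.2"))  # {}
print(bettertransformer_load_kwargs("4.41.2"))  # {'attn_implementation': 'eager'}
```

In the reproduction script this would be used as `model = AutoModel.from_pretrained(model_name, **bettertransformer_load_kwargs(transformers.__version__))` (after `import transformers`), followed by `BetterTransformer.transform(model)` as before.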

Metadata



Labels: bug (Something isn't working)

