
_prepare_4d_attention_mask_for_sdpa is not for causal attention but claims... #30095

Closed

Description

@minostauros

... the comment claims that SDPA causal mask generation may be wrong, even though this function handles non-causal (bi-directional) masks.

if torch.all(mask == 1):
    if is_tracing:
        pass
    elif tgt_len == 1:
        # For query_length == 1, causal attention and bi-directional attention are the same.
        return None
    elif key_value_length == tgt_len:
        return None
    else:
        # Unfortunately, for query_length > 1 and key_value_length != query_length, we can not generally ignore the attention mask, as SDPA causal mask generation
        # may be wrong. We will set is_causal=False in SDPA and rely on Transformers attention_mask instead, hence not setting it to None here.
        # Reference: https://github.com/pytorch/pytorch/issues/108108
        return AttentionMaskConverter._expand_mask(mask=mask, dtype=dtype, tgt_len=tgt_len)

Would it be safe to simply return None in the else: branch as well?

For causal attention, we can just use _prepare_4d_causal_attention_mask_for_sdpa instead (see the sketch below).
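
A minimal sketch (not from the issue) of why returning None appears safe for this bi-directional helper, assuming the 2D mask is all ones: with nothing to mask out, SDPA called with attn_mask=None gives the same result as the expanded all-ones (all-zero additive) mask, even when query_length != key_value_length. The shapes are arbitrary illustration values.

import torch
import torch.nn.functional as F

batch, heads, q_len, kv_len, dim = 2, 4, 3, 7, 16
q = torch.randn(batch, heads, q_len, dim)
k = torch.randn(batch, heads, kv_len, dim)
v = torch.randn(batch, heads, kv_len, dim)

# An all-ones 2D padding mask expands to an all-zero additive 4D mask,
# i.e. no positions are masked out.
expanded = torch.zeros(batch, 1, q_len, kv_len)

out_with_mask = F.scaled_dot_product_attention(q, k, v, attn_mask=expanded, is_causal=False)
out_without = F.scaled_dot_product_attention(q, k, v, attn_mask=None, is_causal=False)

print(torch.allclose(out_with_mask, out_without))  # True

The caution quoted in the comment seems aimed at the causal path, where is_causal=True with query_length != key_value_length can align the causal mask incorrectly (pytorch/pytorch#108108); it is not obvious that this concern applies to the bi-directional case handled here.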

Related issues:
pytorch/pytorch#108108
Dao-AILab/flash-attention@9e5e8bc
#28802
