pnnx fuse wav2vec2 style mha #6004

Merged · 1 commit · Apr 19, 2025
31 changes: 31 additions & 0 deletions tools/pnnx/src/pass_level5/fuse_multiheadattention.cpp
@@ -1162,6 +1162,35 @@ pnnx.Output output 1 0 out
}
};

class fuse_multiheadattention_pass_12_2 : public fuse_multiheadattention_pass_12
{
public:
const char* match_pattern_graph() const
{
return R"PNNXIR(7767517
18 17
pnnx.Input input_0 0 1 input
nn.Linear op_0 1 1 input 14 bias=%qbias in_features=%embed_dim out_features=%embed_dim @bias @weight
nn.Linear op_1 1 1 input 15 bias=%kbias in_features=%embed_dim out_features=%embed_dim @bias @weight
nn.Linear op_2 1 1 input 16 bias=%vbias in_features=%embed_dim out_features=%embed_dim @bias @weight
Tensor.view op_3 1 1 14 17 shape=(%batch,%size,%num_heads,%feat_per_head)
Tensor.view op_4 1 1 15 18 shape=(%batch,%size,%num_heads,%feat_per_head)
Tensor.view op_5 1 1 16 19 shape=(%batch,%size,%num_heads,%feat_per_head)
torch.transpose op_6 1 1 19 20 dim0=1 dim1=2
torch.transpose op_7 1 1 18 21 dim0=1 dim1=2
torch.transpose op_8 1 1 17 22 dim0=1 dim1=2
Tensor.contiguous op_9 1 1 20 201 memory_format=*
Tensor.contiguous op_10 1 1 21 211 memory_format=*
Tensor.contiguous op_11 1 1 22 221 memory_format=*
F.scaled_dot_product_attention op_12 3 1 221 211 201 23 attn_mask=None dropout_p=0.000000e+00 is_causal=False
torch.transpose op_13 1 1 23 24 dim0=1 dim1=2
Tensor.reshape op_14 1 1 24 25 shape=(%batch,%size,%embed_dim)
nn.Linear out_proj 1 1 25 out bias=%outbias in_features=%embed_dim out_features=%embed_dim @bias @weight
pnnx.Output output 1 0 out
)PNNXIR";
}
};

class fuse_multiheadattention_pass_13 : public fuse_multiheadattention_pass_qkv
{
public:
@@ -2145,6 +2174,7 @@ void fuse_multiheadattention(Graph& graph)
fuse_multiheadattention_pass_10 j;
fuse_multiheadattention_pass_12 k;
fuse_multiheadattention_pass_12_1 k1;
fuse_multiheadattention_pass_12_2 k2;
fuse_multiheadattention_pass_13 l;
fuse_multiheadattention_pass_14 m;
fuse_multiheadattention_pass_15 n;
@@ -2186,6 +2216,7 @@ void fuse_multiheadattention(Graph& graph)
pnnx_graph_rewrite(graph, &j, opindex);
pnnx_graph_rewrite(graph, &k, opindex);
pnnx_graph_rewrite(graph, &k1, opindex);
pnnx_graph_rewrite(graph, &k2, opindex);
pnnx_graph_rewrite(graph, &l, opindex);
pnnx_graph_rewrite(graph, &m, opindex);
pnnx_graph_rewrite(graph, &n, opindex);
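For reference, here is a minimal PyTorch sketch of the attention layout the new pattern matches: separate q/k/v linear projections, a view into (batch, size, num_heads, feat_per_head), transpose plus contiguous into (batch, num_heads, size, feat_per_head), F.scaled_dot_product_attention, then a transpose/reshape back and a final output projection. The module and parameter names (Wav2Vec2StyleAttention, q_proj, etc.) are illustrative assumptions, not code taken from this PR or from the transformers source.

    # Hypothetical sketch of a module whose traced graph should match
    # fuse_multiheadattention_pass_12_2; names are illustrative only.
    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    class Wav2Vec2StyleAttention(nn.Module):
        def __init__(self, embed_dim=768, num_heads=12):
            super().__init__()
            self.num_heads = num_heads
            self.head_dim = embed_dim // num_heads
            self.q_proj = nn.Linear(embed_dim, embed_dim)
            self.k_proj = nn.Linear(embed_dim, embed_dim)
            self.v_proj = nn.Linear(embed_dim, embed_dim)
            self.out_proj = nn.Linear(embed_dim, embed_dim)

        def forward(self, x):
            b, s, e = x.shape
            # view to (batch, size, num_heads, feat_per_head), then
            # transpose(1, 2) + contiguous() to (batch, num_heads, size, feat_per_head)
            q = self.q_proj(x).view(b, s, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
            k = self.k_proj(x).view(b, s, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
            v = self.v_proj(x).view(b, s, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
            out = F.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False)
            # back to (batch, size, embed_dim) for the output projection
            out = out.transpose(1, 2).reshape(b, s, e)
            return self.out_proj(out)

Since fuse_multiheadattention_pass_12_2 derives from fuse_multiheadattention_pass_12 and overrides only match_pattern_graph(), the replacement behavior is presumably inherited unchanged; the new pattern appears to differ only in the extra Tensor.contiguous ops between the transposes and the scaled_dot_product_attention call.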