1 parent 3ccadfd · commit 2b900f1
examples/attention.py
@@ -87,9 +87,10 @@ def test(
     p = torch.softmax(p.float(), dim=-1).to(dtype)
     ref_out = torch.matmul(p, v)

-    # flex attention version
+    # flex attention version=
     # TODO(jansel): turn the above kernel into a flex attention kernel
-    flex_out = flex_attention(q, k, v)
+    flex_compiled = torch.compile(flex_attention, fullgraph=True)
+    flex_out = flex_compiled(q, k, v)
     torch.testing.assert_close(flex_out, ref_out, atol=1e-2, rtol=1e-2)

     # sdpa version
@@ -106,7 +107,7 @@ def test(
     spda_sec = do_bench(
         lambda: torch.nn.functional.scaled_dot_product_attention(q, k, v)
     )
-    flex_sec = do_bench(lambda: flex_attention(q, k, v))
+    flex_sec = do_bench(lambda: flex_compiled(q, k, v))
     helion_sec = do_bench(lambda: attention(q, k, v))
     print(
         f"Helion time: {helion_sec:.4f}ms, flex time: {flex_sec:.4f}, torch time: {spda_sec:.4f}"