Skip to content

Commit 2b900f1

Browse files
committed
More fair comparison
stack-info: PR: #146, branch: drisspg/stack/7
1 parent 3ccadfd commit 2b900f1

File tree

1 file changed

+4
-3
lines changed

1 file changed

+4
-3
lines changed

examples/attention.py

Lines changed: 4 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -87,9 +87,10 @@ def test(
8787
p = torch.softmax(p.float(), dim=-1).to(dtype)
8888
ref_out = torch.matmul(p, v)
8989

90-
# flex attention version
90+
# flex attention version=
9191
# TODO(jansel): turn the above kernel into a flex attention kernel
92-
flex_out = flex_attention(q, k, v)
92+
flex_compiled = torch.compile(flex_attention, fullgraph=True)
93+
flex_out = flex_compiled(q, k, v)
9394
torch.testing.assert_close(flex_out, ref_out, atol=1e-2, rtol=1e-2)
9495

9596
# sdpa version
@@ -106,7 +107,7 @@ def test(
106107
spda_sec = do_bench(
107108
lambda: torch.nn.functional.scaled_dot_product_attention(q, k, v)
108109
)
109-
flex_sec = do_bench(lambda: flex_attention(q, k, v))
110+
flex_sec = do_bench(lambda: flex_compiled(q, k, v))
110111
helion_sec = do_bench(lambda: attention(q, k, v))
111112
print(
112113
f"Helion time: {helion_sec:.4f}ms, flex time: {flex_sec:.4f}, torch time: {spda_sec:.4f}"

0 commit comments

Comments
 (0)