
StableHLO Composite Op Example #28891

Answered by dfm
gokul-uf asked this question in Q&A

StableHLO composites are exposed in JAX via `jax.lax.composite`. Its documentation includes one example, and you can find a few more in the tests:

jax/tests/lax_test.py

Lines 4594 to 4824 in 609fb7f

```python
class CompositeTest(jtu.JaxTestCase):
  def test_composite(self):
    def my_square_impl(x):
      return x ** 2
    my_square = lax.composite(my_square_impl, name="my.square")
    x = jnp.array(2.0, dtype=jnp.float32)
    output = my_square(x)
    self.assertEqual(output, jnp.array(4.0, dtype=jnp.float32))
    mlir_module = jax.jit(my_square).lower(x).as_text()
    self.assertIn(
        'stablehlo.composite "my.square" %arg0 {decomposition = @my.square} : '
        '(tensor<f32>…
```

Answer selected by gokul-uf