JAX Metal raises jaxlib.xla_extension.XlaRuntimeError on MacBook M3 for jnp.linalg.qr(A)

I have a new MacBook M3 Pro and I followed the instructions to install jax-metal 0.0.5 with jax and jaxlib 0.4.20.

I can run some commands using the JAX library; for example, print(jax.random.PRNGKey(10)) returns the expected result.

But some commands produce errors.

Running Q, R = jnp.linalg.qr(A) produces the following error:

Traceback (most recent call last):
  File "/Users/username/Documents/Code/ppx/test_jax.py", line 7, in <module>
    Q, R = jnp.linalg.qr(A)
jaxlib.xla_extension.XlaRuntimeError: UNKNOWN: /Users/username/Documents/Code/ppx/test_jax.py:7:0: error: failed to legalize operation 'mhlo.custom_call'
/Users/username/Documents/Code/ppx/test_jax.py:7:0: note: see current operation: %6:2 = "mhlo.custom_call"(%5) {api_version = 1 : i32, backend_config = "", call_target_name = "Qr", called_computations = [], has_side_effect = false} : (tensor<2x2xf32>) -> (tensor<2x2xf32>, tensor<2xf32>)

Note that A is defined as follows: A = jnp.array([[1, 2], [3, 4]])
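For reference, here is a minimal, self-contained version of the failing script, assembled only from the snippets above (the original test_jax.py is not shown, so the surrounding lines are an assumption). On a standard backend such as CPU it should return the reduced QR factors; under jax-metal 0.0.5 the qr call is where the XlaRuntimeError was raised.

```python
import jax.numpy as jnp

# The 2x2 matrix from the question. jnp.linalg.qr promotes the integer
# entries to float32, which matches the tensor<2x2xf32> in the error message.
A = jnp.array([[1, 2], [3, 4]])

# On the CPU or CUDA backends this returns Q (orthogonal) and R (upper triangular).
# Under jax-metal 0.0.5 this is the call that failed to legalize 'mhlo.custom_call'.
Q, R = jnp.linalg.qr(A)

print(Q)
print(R)

# Sanity check: Q @ R should reconstruct A up to float32 rounding.
print(jnp.allclose(Q @ R, A, atol=1e-5))
```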

I expected it to return the QR decomposition.

Answer by jakevdp:

The jax-metal plugin is experimental and incomplete (you should see a prominent warning to that effect each time you load it). You can find the current set of open issues under the Apple GPU Metal plugin label within the JAX issue tracker.
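To confirm which backend a script is actually running on, a quick check like the following can help. It only uses jax.devices() and jax.default_backend(); the exact platform string reported by the Metal plugin ("METAL" below) is an assumption, so treat the comment as a guide rather than a guarantee.

```python
import jax

# List the devices JAX initialised for its default backend.
print(jax.devices())

# Name of the platform used when no device is specified explicitly.
# Assumption: with jax-metal active this is expected to read "METAL";
# on a plain CPU-only install it reads "cpu".
print(jax.default_backend())
```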

Given this, problems like the one you're seeing are not uncommon, nor are they unexpected. Your best course of action would be to switch to one of the standard, non-experimental hardware backends supported by the JAX team.
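If you want to keep the plugin installed but run unsupported operations elsewhere, one possible workaround (a sketch, not something the answer above prescribes) is to place the input explicitly on the CPU backend, which ships with jaxlib and is normally available alongside jax-metal. The variable name cpu is just illustrative.

```python
import jax
import jax.numpy as jnp
import numpy as np

# The CPU backend is bundled with jaxlib, so it should be available
# even when the experimental Metal plugin is the default backend.
cpu = jax.devices("cpu")[0]

# device_put with an explicit device commits the array to that device,
# so computations on it run on the CPU instead of the default backend.
# Building it from a NumPy array avoids creating it on Metal first.
A = jax.device_put(np.array([[1, 2], [3, 4]], dtype=np.float32), cpu)

Q, R = jnp.linalg.qr(A)
print(Q)
print(R)
```

Alternatively, setting the JAX_PLATFORMS=cpu environment variable before starting Python restricts the whole script to the CPU backend, at the cost of not using the GPU at all.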