How can I test whether a jitted JAX function creates a new tensor or a view?


I have some basic code like this:

import jax
import jax.numpy as jnp
from jax import jit

@jit
def concat_permute(indices, in1, in2):
    tensor = jnp.concatenate([jnp.atleast_1d(in1), jnp.atleast_1d(in2)])
    return tensor[indices]

Here are my test tensors:

key = jax.random.PRNGKey(758493)
key1, key2, key3 = jax.random.split(key, 3)  # separate keys for independent draws
in1 = jax.random.uniform(key1, shape=(15, 5, 3))
in2 = jax.random.uniform(key2, shape=(10, 5, 3))
indices = jax.random.choice(key3, 25, (25,), replace=False)

And here is the Jaxpr of the function:

{ lambda ; a:i32[25] b:f32[15,5,3] c:f32[10,5,3]. let
    d:f32[25,5,3] = xla_call[
      call_jaxpr={ lambda ; e:i32[25] f:f32[15,5,3] g:f32[10,5,3]. let
          h:f32[15,5,3] = xla_call[
            call_jaxpr={ lambda ; i:f32[15,5,3]. let  in (i,) }
            name=atleast_1d
          ] f
          j:f32[10,5,3] = xla_call[
            call_jaxpr={ lambda ; k:f32[10,5,3]. let  in (k,) }
            name=atleast_1d
          ] g
          l:f32[25,5,3] = concatenate[dimension=0] h j
          m:bool[25] = lt e 0
          n:i32[25] = add e 25
          o:i32[25] = select_n m e n
          p:i32[25,1] = broadcast_in_dim[
            broadcast_dimensions=(0,)
            shape=(25, 1)
          ] o
          q:f32[25,5,3] = gather[
            dimension_numbers=GatherDimensionNumbers(offset_dims=(1, 2), collapsed_slice_dims=(0,), start_index_map=(0,))
            fill_value=None
            indices_are_sorted=False
            mode=GatherScatterMode.PROMISE_IN_BOUNDS
            slice_sizes=(1, 5, 3)
            unique_indices=False
          ] l p
        in (q,) }
      name=concat_permute
    ] a b c
  in (d,) }

It seems to create a new tensor using my permutation array, but I'm not sure. Is there a clearer way to see whether this operation creates a new tensor or not?

I tried "jax.make_jaxpr" and looked at the results, but I'm not sure how to interpret them.
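One way to look below the jaxpr level is to inspect what XLA actually compiles. The sketch below assumes a reasonably recent JAX version in which a jitted function exposes .lower(...).compile(): the resulting Compiled object's as_text() returns the optimized HLO (where you can see whether concatenate and gather were fused), and memory_analysis() reports buffer sizes, including how many bytes of the output alias an input. memory_analysis() may return None on backends that don't report these statistics, and its exact fields may vary between versions.

```python
import jax
import jax.numpy as jnp


@jax.jit
def concat_permute(indices, in1, in2):
    tensor = jnp.concatenate([jnp.atleast_1d(in1), jnp.atleast_1d(in2)])
    return tensor[indices]


key1, key2, key3 = jax.random.split(jax.random.PRNGKey(758493), 3)
in1 = jax.random.uniform(key1, shape=(15, 5, 3))
in2 = jax.random.uniform(key2, shape=(10, 5, 3))
indices = jax.random.choice(key3, 25, (25,), replace=False)

# Lower and compile ahead of time so we can look at what XLA will actually run.
compiled = concat_permute.lower(indices, in1, in2).compile()

# Optimized HLO: shows fusions and which intermediate buffers survive.
hlo_text = compiled.as_text()
print(hlo_text)

# Buffer statistics; None if the backend does not provide them.
stats = compiled.memory_analysis()
if stats is not None:
    print("output bytes:", stats.output_size_in_bytes)
    # 0 means no output shares a buffer with an input argument.
    print("aliased bytes:", stats.alias_size_in_bytes)
```

Note that jaxprs are purely functional: every equation, including the final gather, binds a new value, so they don't tell you about memory layout. The compiled HLO is the level at which buffer reuse is decided.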

Best answer, by jakevdp:

The short answer is no: the output of your function will not share memory with the array allocated for tensor.
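Under jit the intermediate tensor exists only inside the compiled computation, so you can't take its address directly. As a rough check, the sketch below runs the same two steps eagerly (no jit, so each op materializes its own array) and compares device buffer addresses with jax.Array.unsafe_buffer_pointer(), which works for single-device arrays. Distinct addresses show that the gather wrote a new buffer in eager mode; this is an illustration, not proof of what XLA does inside a fused jitted computation.

```python
import jax
import jax.numpy as jnp

key1, key2, key3 = jax.random.split(jax.random.PRNGKey(758493), 3)
in1 = jax.random.uniform(key1, shape=(15, 5, 3))
in2 = jax.random.uniform(key2, shape=(10, 5, 3))
indices = jax.random.choice(key3, 25, (25,), replace=False)

# Eager execution: `tensor` is a real, addressable device array here.
# (atleast_1d is omitted because both inputs are already 3-D.)
tensor = jnp.concatenate([in1, in2])
permuted = tensor[indices]

# Device addresses of the underlying buffers (single-device arrays only).
tensor_ptr = tensor.unsafe_buffer_pointer()
permuted_ptr = permuted.unsafe_buffer_pointer()
print(hex(tensor_ptr), hex(permuted_ptr))
print("shares buffer:", tensor_ptr == permuted_ptr)
```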

In XLA, an array is represented by a uniformly strided buffer. When you select values at arbitrary (here, randomly permuted) indices, the result cannot in general be expressed as a uniform stride over the input buffer, so it cannot be a view of that buffer and must be written to a new one.
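NumPy exposes views to the user, so it makes the same stride argument easy to see. As an analogy only (XLA does not hand out views of its buffers the way NumPy does), the sketch below uses np.shares_memory: a slice with a constant step is a view with a larger stride, while indexing with a permutation array always produces a copy. The helper is_uniform_stride is hypothetical, written here just to show that a random permutation is generally not an arithmetic progression of row offsets, i.e. not addressable with a single offset and stride.

```python
import numpy as np


def is_uniform_stride(idx):
    """Hypothetical helper: True if idx is an arithmetic progression,
    so the selected rows could be addressed with one (offset, stride)."""
    idx = np.asarray(idx)
    if idx.size <= 1:
        return True
    steps = np.diff(idx)
    return bool(np.all(steps == steps[0]))


base = np.arange(25 * 5 * 3, dtype=np.float32).reshape(25, 5, 3)

# Basic slicing with a constant step: a view over the same buffer.
every_other = base[::2]
print(np.shares_memory(base, every_other))    # True
print(every_other.strides, base.strides)      # (120, 12, 4) (60, 12, 4)

# Integer-array ("advanced") indexing always returns a new array in NumPy.
rng = np.random.default_rng(758493)
perm = rng.permutation(25)
permuted = base[perm]
print(np.shares_memory(base, permuted))       # False

print(is_uniform_stride(np.arange(0, 25, 2)))  # True
print(is_uniform_stride(perm))  # almost surely False for a random permutation
```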