I have some basic code like this:
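```python
import jax
import jax.numpy as jnp
from jax import jit

@jit
def concat_permute(indices, in1, in2):
    # Join the two inputs along axis 0, then reorder the rows by `indices`
    tensor = jnp.concatenate([jnp.atleast_1d(in1), jnp.atleast_1d(in2)])
    return tensor[indices]
```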
Here are my test tensors:
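```python
key = jax.random.PRNGKey(758493)
# Split the key so each random draw uses independent randomness
key1, key2, key3 = jax.random.split(key, 3)
in1 = jax.random.uniform(key1, shape=(15, 5, 3))
in2 = jax.random.uniform(key2, shape=(10, 5, 3))
indices = jax.random.choice(key3, 25, (25,), replace=False)
```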
And here is the Jaxpr of the function:
```
{ lambda ; a:i32[25] b:f32[15,5,3] c:f32[10,5,3]. let
    d:f32[25,5,3] = xla_call[
      call_jaxpr={ lambda ; e:i32[25] f:f32[15,5,3] g:f32[10,5,3]. let
          h:f32[15,5,3] = xla_call[
            call_jaxpr={ lambda ; i:f32[15,5,3]. let in (i,) }
            name=atleast_1d
          ] f
          j:f32[10,5,3] = xla_call[
            call_jaxpr={ lambda ; k:f32[10,5,3]. let in (k,) }
            name=atleast_1d
          ] g
          l:f32[25,5,3] = concatenate[dimension=0] h j
          m:bool[25] = lt e 0
          n:i32[25] = add e 25
          o:i32[25] = select_n m e n
          p:i32[25,1] = broadcast_in_dim[
            broadcast_dimensions=(0,)
            shape=(25, 1)
          ] o
          q:f32[25,5,3] = gather[
            dimension_numbers=GatherDimensionNumbers(offset_dims=(1, 2), collapsed_slice_dims=(0,), start_index_map=(0,))
            fill_value=None
            indices_are_sorted=False
            mode=GatherScatterMode.PROMISE_IN_BOUNDS
            slice_sizes=(1, 5, 3)
            unique_indices=False
          ] l p
        in (q,) }
      name=concat_permute
    ] a b c
  in (d,) }
```
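The tail of the jaxpr above (variables `m` through `q`) is what `tensor[indices]` lowers to, and the `atleast_1d` calls are identity functions for these 3-D inputs. As a minimal sketch, assuming a recent JAX version in which `jax.lax.gather` and `jax.lax.GatherDimensionNumbers` accept the arguments below, the same steps can be written by hand: wrap negative indices, add a trailing index-vector axis, then gather one `(1, 5, 3)` slice per index into a new `(25, 5, 3)` output. The keys are split here rather than reused, so the random values differ from the question's.

```python
import jax
import jax.numpy as jnp
from jax import lax

key = jax.random.PRNGKey(758493)
k1, k2, k3 = jax.random.split(key, 3)
in1 = jax.random.uniform(k1, shape=(15, 5, 3))
in2 = jax.random.uniform(k2, shape=(10, 5, 3))
indices = jax.random.choice(k3, 25, (25,), replace=False)

# `l` in the jaxpr: the concatenated operand, shape (25, 5, 3)
operand = jnp.concatenate([in1, in2])

# `m`, `n`, `o`: NumPy-style wrapping of negative indices (i -> i + 25 if i < 0)
o = jnp.where(indices < 0, indices + 25, indices)

# `p`: one start index per output row, shape (25, 1)
p = o[:, None]

# `q`: take a (1, 5, 3) slice at each start index; axis 0 of the slice is
# collapsed, so each of the 25 slices becomes a (5, 3) block of the output
dnums = lax.GatherDimensionNumbers(
    offset_dims=(1, 2),
    collapsed_slice_dims=(0,),
    start_index_map=(0,),
)
q = lax.gather(
    operand,
    p,
    dimension_numbers=dnums,
    slice_sizes=(1, 5, 3),
    mode=lax.GatherScatterMode.PROMISE_IN_BOUNDS,
)

print(q.shape)  # (25, 5, 3)
print(bool(jnp.array_equal(q, operand[indices])))  # True
```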
It seems to create a new tensor using my permutation array, but I'm not sure. Is there a clearer way to see whether this operation creates a new tensor or not?
I tried `jax.make_jaxpr` and looked at the result, but I'm still not sure.
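A jaxpr describes which operations run, not how XLA assigns memory to them. As a sketch of how to look one level lower, assuming a JAX version with the ahead-of-time `.lower()` / `.compile()` API: `lowered.as_text()` prints the StableHLO module JAX hands to XLA, and `compiled.as_text()` prints the HLO after XLA's optimizations for the current backend, where you can check whether the `gather` survives as its own op or inside a fusion. The exact text varies across JAX versions and backends (for example, nested `jit` calls may not appear as `xla_call` in newer jaxprs).

```python
import jax
import jax.numpy as jnp

@jax.jit
def concat_permute(indices, in1, in2):
    tensor = jnp.concatenate([jnp.atleast_1d(in1), jnp.atleast_1d(in2)])
    return tensor[indices]

key = jax.random.PRNGKey(758493)
k1, k2, k3 = jax.random.split(key, 3)
in1 = jax.random.uniform(k1, shape=(15, 5, 3))
in2 = jax.random.uniform(k2, shape=(10, 5, 3))
indices = jax.random.choice(k3, 25, (25,), replace=False)

# Level 1: the jaxpr, as in the question
print(jax.make_jaxpr(concat_permute)(indices, in1, in2))

# Level 2: the StableHLO module given to XLA, before XLA's own optimizations
lowered = concat_permute.lower(indices, in1, in2)
stablehlo_text = lowered.as_text()
print(stablehlo_text)

# Level 3: the HLO after XLA compilation for the current backend
compiled_text = lowered.compile().as_text()
print(compiled_text)
```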
The short answer is no: the output of your function will not share memory with the array allocated for `tensor`. In XLA, an array is represented by a uniformly strided buffer, and when you select values at arbitrary indices from an array (as a random permutation does), the result cannot in general be expressed as a uniformly strided view of the input buffer, so a new buffer has to be allocated for it.
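The striding argument can be illustrated with NumPy, which exposes strided views directly. This is only an analogy: XLA does not let user code take views of buffers, but the constraint is the same. A slice with a constant step is still uniformly strided and can share the original buffer, while an arbitrary row permutation has no single stride describing it and forces a copy.

```python
import numpy as np

# A contiguous (25, 5, 3) float32 buffer, standing in for `tensor`
tensor = np.arange(25 * 5 * 3, dtype=np.float32).reshape(25, 5, 3)

# Every other row: still a uniform stride along axis 0 (doubled from
# 60 to 120 bytes), so NumPy returns a view into the same buffer
every_other = tensor[::2]
print(every_other.strides)                     # (120, 12, 4)
print(np.shares_memory(tensor, every_other))   # True

# An arbitrary row permutation: no single stride can describe it, so
# integer-array (advanced) indexing copies into a fresh contiguous buffer
rng = np.random.default_rng(0)
perm = rng.permutation(25)
permuted = tensor[perm]
print(permuted.strides)                        # (60, 12, 4)
print(np.shares_memory(tensor, permuted))      # False
```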