Passing JAX tracers to the Hugging Face CLIP transformer for calculating the loss


I'm working on a vision task using JAX, and I'm facing an issue with passing intermediate JAX tracer objects as images to the CLIP model for calculating the loss. The CLIP preprocessing step (the Hugging Face processor) expects concrete NumPy arrays or images as inputs, so the JAX tracer objects are not directly compatible.

Here's a simplified version of the code:

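import jax
import jax.numpy as jnp
from jax import lax
from transformers import AutoProcessor, FlaxCLIPModel

# perform_volume_rendering (not shown) is defined elsewhere in the project.

img_txt_clip = FlaxCLIPModel.from_pretrained("openai/clip-vit-base-patch32")
processor = AutoProcessor.from_pretrained("openai/clip-vit-base-patch32")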

def train_step(state, batch, rng):
    """Train Step"""
    inputs, targets = batch

    def clip_loss_fn(params):
        model_fn = lambda x: state.apply_fn({"params": params}, x)
        ray_origins, ray_directions = inputs
        rgb, *_ = perform_volume_rendering(
            model_fn, ray_origins, ray_directions, rng
        )

        # Here's where the issue arises
        clip_input = processor(
            text=["a bulldozer"], images=[rgb], return_tensors="jax", padding=True
        )
        outputs = img_txt_clip(**clip_input)
        logits_per_image = outputs.logits_per_image
        return jnp.mean(logits_per_image)

    train_loss, gradients = jax.value_and_grad(clip_loss_fn)(state.params)
    gradients = lax.pmean(gradients, axis_name="batch")
    new_state = state.apply_gradients(grads=gradients)
    train_loss = jnp.mean(train_loss)
    return train_loss, new_state

I've tried using FlaxCLIPModel for compatibility with JAX, but passing the JAX tracer objects as images raises an error. Converting the JAX tracer objects to NumPy arrays would be inefficient.

I would appreciate any suggestions for either converting the JAX tracer objects to NumPy arrays efficiently or making the CLIP model accept JAX tracer objects as inputs.

Thank you in advance for your help!


This question is answered in the JAX FAQ: "How can I convert a JAX Tracer to a NumPy array?"

It is impossible to convert a JAX tracer into a NumPy array, because a tracer is not an array, but rather an abstract representation of all possible NumPy arrays of a given shape and dtype.
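To make this concrete, here is a minimal sketch (my own illustration, not taken from the FAQ) showing that inside jax.jit the traced value exposes its shape and dtype but has no concrete values, so np.asarray fails with jax.errors.TracerArrayConversionError. The names to_numpy and seen are purely illustrative.

```python
import jax
import jax.numpy as jnp
import numpy as np

seen = {}


def to_numpy(x):
    # During tracing, x is a Tracer: its shape and dtype are known...
    seen["shape"] = x.shape
    seen["dtype"] = x.dtype
    # ...but it holds no concrete values, so NumPy conversion fails.
    return np.asarray(x)


try:
    jax.jit(to_numpy)(jnp.arange(3.0))
    conversion_failed = False
except jax.errors.TracerArrayConversionError:
    conversion_failed = True

print(seen, conversion_failed)
```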

When this question comes up, though, it usually means the asker wants to call host-side, non-JAX code at runtime, which shows up as an error about being unable to convert tracers to arrays. If that is actually what you want to do, the best avenue is probably jax.pure_callback, described in "External Callbacks in JAX". Be aware, though, that callbacks can slow execution, because the computation on the device pauses while the host communication and computation take place.
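A minimal sketch of the callback route, assuming a recent JAX version: host_fn is a hypothetical stand-in for whatever NumPy-only code you need to run (here just np.sin), and jax.ShapeDtypeStruct tells JAX the shape and dtype of the result so tracing can continue past the callback. One caveat that matters for the code in the question: pure_callback has no differentiation rule, so jax.grad (or jax.value_and_grad) through it raises an error unless you wrap it in jax.custom_jvp or jax.custom_vjp and supply the derivative yourself.

```python
import jax
import jax.numpy as jnp
import numpy as np


def host_fn(x):
    # Runs on the host with a concrete array; any NumPy-only code could go here.
    return np.sin(np.asarray(x))


@jax.jit
def f(x):
    # Declare the result's shape and dtype up front so JAX can trace past the callback.
    result_shape = jax.ShapeDtypeStruct(x.shape, x.dtype)
    y = jax.pure_callback(host_fn, result_shape, x)
    return y + 1.0


x = jnp.array([0.0, 1.0, 2.0], dtype=jnp.float32)
print(f(x))

# pure_callback has no JVP rule, so differentiating through it fails
# unless a custom_jvp/custom_vjp wrapper supplies the derivative.
try:
    jax.grad(lambda v: jnp.sum(f(v)))(x)
    grad_failed = False
except Exception:
    grad_failed = True
```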

A better alternative would be to find a JAX-compatible implementation of the function you're trying to call (i.e. one that supports JAX Tracers as inputs), which could be executed directly on-device rather than requiring a callback to host.
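For the code in the question, one way to follow this advice is sketched below, under several assumptions. Rather than passing rgb through the Hugging Face processor (which converts its inputs with NumPy), you could tokenize the text once outside the loss and reproduce the image preprocessing with jax.numpy so it stays traceable and differentiable. The sketch assumes rgb has already been reshaped to an (H, W, 3) float image with values in [0, 1]. It resizes straight to 224 x 224 instead of the processor's shortest-side resize plus center crop, and it uses the OpenAI CLIP normalization constants. clip_preprocess and clip_score are hypothetical helper names, not part of transformers.

```python
import jax
import jax.numpy as jnp

# Per-channel mean/std used for OpenAI CLIP checkpoints (e.g. clip-vit-base-patch32).
CLIP_MEAN = jnp.array([0.48145466, 0.4578275, 0.40821073])
CLIP_STD = jnp.array([0.26862954, 0.26130258, 0.27577711])


def clip_preprocess(rgb, size=224):
    """Traceable stand-in for the processor's image path (hypothetical helper).

    Assumes rgb is a float (H, W, 3) image with values in [0, 1]. Resizes
    directly to (size, size) with cubic interpolation (no center crop),
    normalizes each channel, and returns channels-first pixel_values of
    shape (1, 3, size, size).
    """
    resized = jax.image.resize(rgb, (size, size, 3), method="cubic")
    normalized = (resized - CLIP_MEAN) / CLIP_STD
    # HWC -> CHW, then add a batch dimension.
    return jnp.transpose(normalized, (2, 0, 1))[None, ...]


def clip_score(model, rgb, input_ids, attention_mask):
    """Mean image-text logit computed with JAX ops only (hypothetical helper).

    model is assumed to be a FlaxCLIPModel; input_ids and attention_mask
    come from tokenizing the text once, outside any traced function.
    """
    outputs = model(
        input_ids=input_ids,
        pixel_values=clip_preprocess(rgb),
        attention_mask=attention_mask,
    )
    return jnp.mean(outputs.logits_per_image)
```

In the train step, this would mean computing text_inputs = processor(text=["a bulldozer"], return_tensors="np", padding=True) once, outside clip_loss_fn, and returning clip_score(img_txt_clip, rgb, text_inputs["input_ids"], text_inputs["attention_mask"]) from inside it. Since FlaxCLIPModel is itself written in JAX, gradients could then flow from logits_per_image back through rgb to the rendering model's parameters, with no host callback involved.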