Calculating the Jacobian of a JAX Convolution


I'm using JAX to compute a convolution:

import jax.numpy as jnp
from jax.scipy.signal import convolve2d  # assumed source of convolve2d

def gaussian_kernel(size: int, std: float):
    """Generates a normalized 2D Gaussian kernel of shape (2*size + 1, 2*size + 1)."""
    x, y = jnp.mgrid[-size:size+1, -size:size+1]
    g = jnp.exp(-(x**2 + y**2) / (2 * std**2))
    return g / g.sum()

def gaussian_blur(image, kernel_size=5, sigma=1.0):
    """Applies Gaussian blur to a 2D image."""
    kernel = gaussian_kernel(kernel_size, sigma)
    blurred_image = convolve2d(image, kernel, mode='same')
    return blurred_image

Basically, just an ordinary blur.

Mathematically, I don't understand what the derivative of a convolution looks like with respect to the input pixels.

That is, what is the effect on output pixel x of changing input pixel y?

How can I define this, and how can I extract it from JAX? I don't even know where to start!

Ideally, I'd like to use JAX to extract the gradients of the output pixels with respect to the input pixels.

Answer by jakevdp:

You can compute the Jacobian using jax.jacobian:

image_out = gaussian_blur(image)
image_jac = jax.jacobian(gaussian_blur)(image)

For an input image of shape (M, N), image_out will also have shape (M, N), and image_jac will have shape (M, N, M, N).
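As a quick check of those shapes, here is a minimal sketch. The 8x6 random image and the reduced kernel_size=2 (a 5x5 kernel) are assumptions chosen here only to keep the Jacobian small; they are not from the question. Because the blur is linear in the image, the Jacobian can also be flattened into an (M*N, M*N) matrix that reproduces the blur as a matrix-vector product.

```python
import jax
import jax.numpy as jnp
from jax.scipy.signal import convolve2d

def gaussian_kernel(size, std):
    # Normalized Gaussian kernel of shape (2*size + 1, 2*size + 1).
    x, y = jnp.mgrid[-size:size+1, -size:size+1]
    g = jnp.exp(-(x**2 + y**2) / (2 * std**2))
    return g / g.sum()

def gaussian_blur(image, kernel_size=2, sigma=1.0):
    # kernel_size=2 is an illustrative default (the question uses 5).
    return convolve2d(image, gaussian_kernel(kernel_size, sigma), mode='same')

# Hypothetical small image; it must be at least as large as the kernel.
M, N = 8, 6
image = jax.random.uniform(jax.random.PRNGKey(0), (M, N))

image_out = gaussian_blur(image)
image_jac = jax.jacobian(gaussian_blur)(image)
print(image_out.shape)  # (8, 6)
print(image_jac.shape)  # (8, 6, 8, 6)

# Flatten (M, N, M, N) -> (M*N, M*N): rows index output pixels,
# columns index input pixels.
jac_matrix = image_jac.reshape(M * N, M * N)

# For a linear map, Jacobian @ input reproduces the output.
reconstructed = (jac_matrix @ image.ravel()).reshape(M, N)
print(jnp.allclose(reconstructed, image_out, atol=1e-5))
```

Note that the full Jacobian has (M*N)**2 entries, so this is only practical for small images.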

The value image_jac[i, j, k, l] tells you the partial derivative of image_out[i, j] with respect to image[k, l].
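On the mathematical side: since the blur is linear, the Jacobian does not depend on the pixel values, and each entry is either a kernel weight or zero. With the zero-padded mode='same' convention and a kernel of half-width s, image_jac[i, j, k, l] should equal kernel[i - k + s, j - l + s] when that index lies inside the kernel, and 0 otherwise. The sketch below checks this on a small example; the image size, the half-width s = 2, and the chosen pixel indices are illustrative assumptions. It also uses jax.jvp with a one-hot tangent to get the effect of a single input pixel on every output pixel without building the full Jacobian.

```python
import jax
import jax.numpy as jnp
from jax.scipy.signal import convolve2d

def gaussian_kernel(size, std):
    # Normalized Gaussian kernel of shape (2*size + 1, 2*size + 1).
    x, y = jnp.mgrid[-size:size+1, -size:size+1]
    g = jnp.exp(-(x**2 + y**2) / (2 * std**2))
    return g / g.sum()

def gaussian_blur(image, kernel_size=2, sigma=1.0):
    # kernel_size=2 is an illustrative default (the question uses 5).
    return convolve2d(image, gaussian_kernel(kernel_size, sigma), mode='same')

s = 2                                   # kernel half-width (assumed for this sketch)
kernel = gaussian_kernel(s, 1.0)
image = jax.random.uniform(jax.random.PRNGKey(0), (8, 6))
image_jac = jax.jacobian(gaussian_blur)(image)

# Output pixel (i, j) and a nearby input pixel (k, l).
i, j = 4, 3
k, l = 3, 2

# Inside the kernel window: the derivative is the matching kernel weight.
print(image_jac[i, j, k, l], kernel[i - k + s, j - l + s])

# Input pixel (0, 0) is more than s rows away from (i, j): derivative is zero.
print(image_jac[i, j, 0, 0])

# Effect of input pixel (k, l) on all output pixels, via a
# forward-mode Jacobian-vector product with a one-hot tangent.
one_hot = jnp.zeros_like(image).at[k, l].set(1.0)
_, column = jax.jvp(gaussian_blur, (image,), (one_hot,))
print(jnp.allclose(column, image_jac[:, :, k, l], atol=1e-6))
```

The jvp route costs one forward pass per input pixel of interest, which may be preferable when only a few pixels matter and the full (M, N, M, N) array would be too large.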