I'm using JAX to apply a convolution to an image:
```python
import jax.numpy as jnp
from jax.scipy.signal import convolve2d  # JAX's SciPy-compatible 2D convolution

def gaussian_kernel(size: int, std: float):
    """Generates a 2D Gaussian kernel of shape (2*size + 1, 2*size + 1)."""
    x, y = jnp.mgrid[-size:size+1, -size:size+1]
    g = jnp.exp(-(x**2 + y**2) / (2 * std**2))
    return g / g.sum()  # normalize so the weights sum to 1

def gaussian_blur(image, kernel_size=5, sigma=1.0):
    """Applies Gaussian blur to a 2D image."""
    kernel = gaussian_kernel(kernel_size, sigma)
    blurred_image = convolve2d(image, kernel, mode='same')
    return blurred_image
```
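As a quick sanity check (an illustrative sketch), evaluating `gaussian_kernel` directly shows that `size` acts as a half-width, so the default `kernel_size=5` in `gaussian_blur` yields an 11×11 kernel rather than a 5×5 one. It also confirms the kernel is normalized and symmetric under a 180° flip, which means the flip inside a true convolution makes no difference for this kernel.

```python
import jax.numpy as jnp


def gaussian_kernel(size: int, std: float):
    """Generates a 2D Gaussian kernel of shape (2*size + 1, 2*size + 1)."""
    x, y = jnp.mgrid[-size:size + 1, -size:size + 1]
    g = jnp.exp(-(x**2 + y**2) / (2 * std**2))
    return g / g.sum()


kernel = gaussian_kernel(5, 1.0)
print(kernel.shape)  # (11, 11): `size` is a half-width, not the side length
print(kernel.sum())  # should be close to 1.0 after normalization
# Symmetric under a 180-degree flip, so convolution and correlation agree.
print(bool(jnp.allclose(kernel, kernel[::-1, ::-1])))  # True
```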
Basically, just an ordinary blur.
Mathematically, I don't understand what the DERIVATIVE of a convolution looks like with respect to the input pixels.
In other words, what is the effect on output pixel x of changing input pixel y?
How can I define this, and how can I extract it from JAX? I don't even know where to start!
Ideally, I'd like to use JAX to extract the gradients of the output pixels with respect to the input pixels.
You can compute the Jacobian using `jax.jacobian`.

For an input `image` of shape `(M, N)`, `image_out` will also have shape `(M, N)`, and `image_jac` will have shape `(M, N, M, N)`. The value `image_jac[i, j, k, l]` tells you the partial derivative of `image_out[i, j]` with respect to `image[k, l]`.
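A minimal sketch of that `jax.jacobian` call, assuming a small 8×8 test image and a kernel half-width of 2 so the `(M, N, M, N)` Jacobian stays small. These sizes and the helper names `blur`, `one_hot` and `column` are illustrative choices, not from the original. Because the blur is linear in the image, the Jacobian does not depend on the pixel values: with kernel half-width `c`, `image_jac[i, j, k, l]` equals `kernel[i - k + c, j - l + c]` when both indices fall inside the kernel, and 0 otherwise. The last lines use `jax.jvp` with a one-hot tangent to get the effect of a single input pixel on every output pixel without building the full Jacobian.

```python
from functools import partial

import jax
import jax.numpy as jnp
from jax.scipy.signal import convolve2d


def gaussian_kernel(size: int, std: float):
    """Generates a (2*size + 1) x (2*size + 1) normalized Gaussian kernel."""
    x, y = jnp.mgrid[-size:size + 1, -size:size + 1]
    g = jnp.exp(-(x**2 + y**2) / (2 * std**2))
    return g / g.sum()


def gaussian_blur(image, kernel_size=5, sigma=1.0):
    """Applies Gaussian blur to a 2D image."""
    kernel = gaussian_kernel(kernel_size, sigma)
    return convolve2d(image, kernel, mode='same')


# Small example image (assumed 8x8 so the dense Jacobian stays small).
M, N = 8, 8
image = jnp.arange(M * N, dtype=jnp.float32).reshape(M, N) / (M * N)

# Fix the non-differentiated arguments; half-width 2 gives a 5x5 kernel.
blur = partial(gaussian_blur, kernel_size=2, sigma=1.0)

image_out = blur(image)                # shape (M, N)
image_jac = jax.jacobian(blur)(image)  # shape (M, N, M, N)

# Effect of one input pixel (k, l) on every output pixel, computed with a
# single forward-mode JVP instead of materializing the whole Jacobian.
k, l = 4, 4
one_hot = jnp.zeros((M, N), dtype=jnp.float32).at[k, l].set(1.0)
_, column = jax.jvp(blur, (image,), (one_hot,))  # equals image_jac[:, :, k, l]
```

For a full-size image the dense Jacobian has (M·N)² entries, so a JVP per input pixel (or `jax.vjp` with a one-hot cotangent, which gives `image_jac[i, j, :, :]` for one output pixel) is usually the practical route.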