Is there a way to define the gradient of a 3D FFT using tensorflow's custom_gradient decorator?


Context & problem

I am using the Hamiltonian Monte Carlo (HMC) method of the tensorflow-probability module to explore the most probable states of a self-written probability function. Among the parameters I am trying to fit are the Fourier modes of a real three-dimensional field.

For the HMC to run, every building block of the computation needs to have a gradient implemented. However, the inverse real FFT tf.signal.irfft3d does not come with a gradient by default.
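In a stripped-down, hypothetical form (much smaller grid, dummy likelihood, names of my own choosing), the setup looks roughly like this: HMC samples the real and imaginary parts of the modes, the log-probability goes through tf.signal.irfft3d, and the sampler needs the gradient of that log-probability with respect to the modes, which is exactly where things break down.

import tensorflow as tf
import tensorflow_probability as tfp

n = 8  # toy grid, the real problem is much larger

def target_log_prob(modes_re, modes_im):
    # HMC states have to be real tensors, so the complex modes are rebuilt here
    modes = tf.complex(modes_re, modes_im)
    field = tf.signal.irfft3d(modes)  # <- the op whose gradient is missing
    # dummy Gaussian log-likelihood on the real-space field
    return -0.5 * tf.reduce_sum(tf.square(field))

kernel = tfp.mcmc.HamiltonianMonteCarlo(
    target_log_prob_fn=target_log_prob,
    step_size=0.01,
    num_leapfrog_steps=5)

# the sampler differentiates target_log_prob with respect to the state here,
# which requires a gradient for irfft3d
samples = tfp.mcmc.sample_chain(
    num_results=10,
    current_state=[tf.zeros((n, n, n // 2 + 1)),
                   tf.zeros((n, n, n // 2 + 1))],
    kernel=kernel,
    trace_fn=None)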

Question

Is there a way to implement the gradient of irfft3d? I already have a running, self-implemented irfft3d built from more basic tensorflow blocks, on which automatic differentiation seems to work, but I would like to wrap the optimized and stable tf.signal.irfft3d itself with the @tf.custom_gradient decorator so that automatic differentiation works on it as well.
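For reference, the hand-rolled version does something along these lines (a simplified sketch, not my exact code, assuming an even grid size): the redundant half of the spectrum is rebuilt from the Hermitian symmetry X[k1, k2, k3] = conj(X[-k1, -k2, -k3]), a full complex ifft3d is applied, and the real part is kept. Every op involved has a registered gradient, and the result should match tf.signal.irfft3d up to floating-point error; it is just slower and more memory-hungry.

import tensorflow as tf

def manual_irfft3d(half_spec):
    # half_spec: output of rfft3d, shape (N, N, N//2 + 1), N even
    # missing modes along the last axis: k3 = N//2 + 1 .. N - 1
    mirrored = tf.math.conj(tf.reverse(half_spec[..., 1:-1], axis=[-1]))
    # map k -> -k (mod N) on the first two axes: reverse, then roll by one
    mirrored = tf.roll(tf.reverse(mirrored, axis=[-3, -2]),
                       shift=[1, 1], axis=[-3, -2])
    full_spec = tf.concat([half_spec, mirrored], axis=-1)
    return tf.math.real(tf.signal.ifft3d(full_spec))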

My guess

The Fourier transform is linear, so this problem is theoretically trivial. However, writing out the Jacobian of the Fourier transform on a grid is numerically infeasible (its dimensions would be huge). Luckily, tensorflow only asks for a function that applies the Jacobian to an input vector, and I believe this can be done efficiently thanks to the FFT algorithm. Unfortunately, it seems to me that what tensorflow actually demands is the transposed Jacobian applied to the "upstream gradient" (the vector-Jacobian product), and I am not sure what that means for the FFT:

https://www.tensorflow.org/api_docs/python/tf/custom_gradient?version=nightly

function f(*x) that returns a tuple (y, grad_fn) where:

  • x is a sequence of (nested structures of) Tensor inputs to the function.
  • y is a (nested structure of) Tensor outputs of applying TensorFlow operations in f to x.
  • grad_fn is a function with the signature g(*grad_ys) which returns a list of Tensors the same size as (flattened) x - the derivatives of Tensors in y with respect to the Tensors in x. grad_ys is a sequence of Tensors the same size as (flattened) y holding the initial value gradients for each Tensor in y.

In a pure mathematical sense, a vector-argument vector-valued function f's derivatives should be its Jacobian matrix J. Here we are expressing the Jacobian J as a function grad_fn which defines how J will transform a vector grad_ys when left-multiplied with it (grad_ys * J, the vector-Jacobian product, or VJP). This functional representation of a matrix is convenient to use for chain-rule calculation (in e.g. the back-propagation algorithm).
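To check that I read this correctly, here is a toy example of my own (unrelated to the FFT) for the linear map y = x A: grad_fn returns the vector-Jacobian product dy A^T, and the result agrees with plain automatic differentiation through tf.matmul.

import tensorflow as tf

A = tf.random.normal((3, 5))

@tf.custom_gradient
def linear_map(x):
    y = tf.matmul(x, A)  # forward pass; the Jacobian of y w.r.t. x is A
    def grad_fn(dy):
        # vector-Jacobian product: dy (shape of y) -> dy @ A^T (shape of x)
        return tf.matmul(dy, A, transpose_b=True)
    return y, grad_fn

x = tf.random.normal((1, 3))

with tf.GradientTape() as tape:
    tape.watch(x)
    y = tf.reduce_sum(linear_map(x))

with tf.GradientTape() as tape_ref:
    tape_ref.watch(x)
    y_ref = tf.reduce_sum(tf.matmul(x, A))

# both should print the same gradient (a row of ones times A^T)
print(tape.gradient(y, x))
print(tape_ref.gradient(y_ref, x))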

Complying with the dimensions and formats given in the doc, I cannot imagine any solution other than:

#!/usr/bin/env python3

# set up
import tensorflow as tf

n = 64

noise = tf.random.normal((n, n, n))
modes = tf.signal.rfft3d(noise)

# the function 
@tf.custom_gradient
def irfft3d(x):
    def grad_fn(dy):
        # dy: upstream gradient with the shape of the real-space output;
        # must return something with the shape of x (the complex modes)
        return tf.signal.rfft3d(dy)

    return (tf.signal.irfft3d(x), grad_fn)

# test 
with tf.GradientTape() as gt:
    gt.watch(modes)
    rec_noise = irfft3d(modes)

dn_dm = gt.gradient(rec_noise, modes)

print(dn_dm)

This does run and returns:

tf.Tensor(
[[[262144.+0.j      0.+0.j      0.+0.j ...      0.+0.j      0.+0.j
        0.+0.j]
  [     0.+0.j      0.+0.j      0.+0.j ...      0.+0.j      0.+0.j
        0.+0.j]
  [     0.+0.j      0.+0.j      0.+0.j ...      0.+0.j      0.+0.j
        0.+0.j]
  ...
  [     0.+0.j      0.+0.j      0.+0.j ...      0.+0.j      0.+0.j
        0.+0.j]
  [     0.+0.j      0.+0.j      0.+0.j ...      0.+0.j      0.+0.j
        0.+0.j]
  [     0.+0.j      0.+0.j      0.+0.j ...      0.+0.j      0.+0.j
        0.+0.j]]], shape=(64, 64, 33), dtype=complex64)

I cannot really wrap my mind around it. First, this would be such a simple solution that I do not understand why it has not been implemented natively. But more importantly, I am simply lost as to what tensorflow expects from this self-written gradient function, and I am not able to express its result mathematically in a way that makes sense to me.
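The only sanity check I could come up with is a sketch along these lines (reusing the manual_irfft3d built from basic blocks above, so it is only as trustworthy as that reconstruction): push the same scalar loss through the wrapped op and through the fully differentiable reconstruction, and compare the two gradients with respect to the modes.

# sanity-check sketch: differentiate the same scalar loss through the wrapped
# op and through the hand-rolled manual_irfft3d, then compare the gradients
with tf.GradientTape(persistent=True) as tape:
    tape.watch(modes)
    loss_wrapped = tf.reduce_sum(tf.square(irfft3d(modes)))
    loss_manual = tf.reduce_sum(tf.square(manual_irfft3d(modes)))

g_wrapped = tape.gradient(loss_wrapped, modes)
g_manual = tape.gradient(loss_manual, modes)

# if the custom grad_fn above were correct, the two gradients should agree
print(tf.reduce_max(tf.abs(g_wrapped - g_manual)))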

Is there anybody out there who understands the way tensorflow handles differentiation and could help or correct me?
