Diagonal of the Hessian with TensorFlow


I'm doing some machine learning and I have to deal with a custom loss function. The gradient and the Hessian of the loss function are difficult to derive by hand, so I've resorted to computing them automatically with TensorFlow.

Here is an example, written against the TensorFlow 1.x API.

import numpy as np
import tensorflow as tf

y_true = np.array([
    [1, 0, 0, 0, 0],
    [0, 1, 0, 0, 0],
    [0, 0, 0, 1, 0],
    [0, 0, 1, 0, 0],
    [0, 0, 0, 0, 1],
    [0, 0, 0, 0, 1],
    [0, 0, 0, 0, 1]
], dtype=float)

y_pred = np.array([
    [1, 0, 0, 0, 0],
    [0, 1, 0, 0, 0],
    [0, 0, 0, 1, 0],
    [0, 0, 1, 0, 0],
    [0, 0, 0, 0, 1],
    [0, 0, 0, 0, 1],
    [0, 0, 0, 0, 1]
], dtype=float)

weights = np.array([1, 1, 1, 1, 1], dtype=float)

with tf.Session():

    # We first convert the numpy arrays to Tensorflow tensors
    y_true = tf.convert_to_tensor(y_true)
    y_pred = tf.convert_to_tensor(y_pred)
    weights = tf.convert_to_tensor(weights)

    # The following code block is a custom loss 
    ys = tf.reduce_sum(y_true, axis=0)
    y_true = y_true / ys
    ln_p = tf.nn.log_softmax(y_pred)
    wll = tf.reduce_sum(y_true * ln_p, axis=0)
    loss = -tf.tensordot(weights, wll, axes=1)

    grad = tf.gradients(loss, y_pred)[0]

    hess = tf.hessians(loss, y_pred)[0]
    hess = tf.diag_part(hess)

    print(hess.eval())

which prints out

[[0.24090069 0.12669198 0.12669198 0.12669198 0.12669198]
 [0.12669198 0.24090069 0.12669198 0.12669198 0.12669198]
 [0.12669198 0.12669198 0.12669198 0.24090069 0.12669198]
 [0.12669198 0.12669198 0.24090069 0.12669198 0.12669198]
 [0.04223066 0.04223066 0.04223066 0.04223066 0.08030023]
 [0.04223066 0.04223066 0.04223066 0.04223066 0.08030023]
 [0.04223066 0.04223066 0.04223066 0.04223066 0.08030023]]

I'm happy with this because it works. The problem is that it doesn't scale. For my use case I only need the diagonal of the Hessian. I've managed to extract it using hess = tf.diag_part(hess), but this still computes the full Hessian first, which is unnecessary: for y_pred of shape (n, 5) the full Hessian has (5n)^2 entries. The overhead is so bad that I can't use it for moderately sized datasets (~100k rows).

My question is thus: is there any better way to extract the diagonal of the Hessian? I'm well aware of this post and this one but I don't find the answers good enough.
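For what it's worth, one direction that might avoid automatic differentiation entirely, assuming the loss stays exactly as written above: the loss is a weighted sum of row-wise log-softmax terms, and the coefficients weights * y_true / ys do not depend on y_pred. Under that assumption the second derivative with respect to y_pred[i, k] should reduce to (sum over j of coef[i, j]) * p[i, k] * (1 - p[i, k]), where p is the row-wise softmax of y_pred. The NumPy sketch below is only an illustration of that closed form; hessian_diagonal is a hypothetical helper name, not a TensorFlow function, and it would not carry over to a different loss.

```python
import numpy as np


def hessian_diagonal(y_true, y_pred, weights):
    """Closed-form Hessian diagonal for the loss in the question (hypothetical helper).

    Assumes loss = -sum_j weights[j] * sum_i (y_true[i, j] / ys[j]) * log_softmax(y_pred)[i, j]
    and that every column of y_true has a non-zero sum, as in the original code.
    """
    # Same column normalisation as in the loss: coef[i, j] = weights[j] * y_true[i, j] / ys[j]
    coef = weights * y_true / y_true.sum(axis=0)
    # Row-wise softmax, shifted by the row maximum for numerical stability
    shifted = y_pred - y_pred.max(axis=1, keepdims=True)
    p = np.exp(shifted)
    p /= p.sum(axis=1, keepdims=True)
    # The linear term in y_pred has no curvature; the log-sum-exp term contributes p * (1 - p)
    return coef.sum(axis=1, keepdims=True) * p * (1.0 - p)


# The one-hot rows from the example: class index of each of the 7 rows
labels = [0, 1, 3, 2, 4, 4, 4]
y_true = np.eye(5)[labels]
y_pred = y_true.copy()
weights = np.ones(5)

diag = hessian_diagonal(y_true, y_pred, weights)
print(diag)
```

On the example data this appears to reproduce the values printed by tf.diag_part(tf.hessians(...)), and it only ever touches arrays of the same shape as y_pred, so memory and time grow linearly with the number of rows.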
