The context of the problem is that I have a ResNet model in JAX (basically NumPy), and I take the gradient of the class prediction with respect to the input image. This gives me a gradient vector, g, which I then want to normalize. The trouble is, the magnitudes of the components g[i] are so small that g[i]**2 underflows to 0, meaning np.linalg.norm(g) evaluates to 0, and dividing by it gives me NaNs.
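For illustration, here's a minimal reproducer; the magnitudes are made up, but they show the underflow:

```python
import jax.numpy as np  # jax.numpy, aliased the way I use it below

# Made-up magnitudes: in float32 (JAX's default), squaring values around
# 1e-30 underflows to exactly 0, so the computed norm is 0 as well.
g = np.array([0.0, 1e-30, -1e-30])
print(np.linalg.norm(g))      # 0.0 -- every g[i]**2 underflowed
print(g / np.linalg.norm(g))  # [nan inf -inf] -- dividing by zero
```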
What I've done so far is just check whether the norm is effectively zero and, if so, multiply by some constant factor, as in g = np.where(np.linalg.norm(g) < 1e-20, g * 1e20, g).
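Spelled out with the normalization step that follows (the 1e-20 threshold and the 1e20 factor are arbitrary picks on my part):

```python
# Boost g by a constant factor when its norm has underflowed, then
# normalize as usual, recomputing the norm on the rescaled vector.
norm = np.linalg.norm(g)
g = np.where(norm < 1e-20, g * 1e20, g)
g = g / np.linalg.norm(g)
```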
I was thinking maybe I should instead divide by the smallest nonzero element and then normalize (rough sketch below). Does anyone have ideas on how to properly normalize this vector?
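Here's roughly what I mean; untested, and normalize_via_smallest is just an illustrative name:

```python
import jax.numpy as np  # jax.numpy, as above

def normalize_via_smallest(g):
    abs_g = np.abs(g)
    # Smallest nonzero magnitude (inf if g happens to be all zeros).
    smallest = np.min(np.where(abs_g > 0, abs_g, np.inf))
    g = g / smallest  # smallest nonzero component becomes +/-1
    # Caveat: if the components span many orders of magnitude, the largest
    # rescaled entries can overflow when squared inside the norm.
    return g / np.linalg.norm(g)
```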