JAX grad function: why am I getting a list of zeros instead of the gradients?

656 Views Asked by At

I'm trying to find the global maximum of a Python function with many variables (500+). For this purpose I'm trying to use JAX's grad() to compute the gradient of that function, MyFunction.

But I'm obviously doing something wrong, because each time I try to get the derivatives of MyFunction I just get a list of zeros, which doesn't make any sense. (Note: MyFunction itself works as expected.)

Any idea?

from jax import grad
import jax.numpy as jnp
import numpy as np
import json

# Example data - in real case I have 150,000+ rows
# Each row has 57 columns, which MyFunction reads as:
#   [0]       result
#   [1], [2]  fOdds, dOdds
#   [3:17]    14 `s` values     (MyFunction indexes with value - 1)
#   [17:31]   14 `h` values     (MyFunction indexes with the value as-is)
#   [31:43]   12 `hSign` values (MyFunction indexes with value - 1)
#   [43:57]   14 `r` flags      (compared against 1)
data = jnp.array([
    [1.0, 1.06, 9.77]
    + [5.0, 3.0, 2.0, 6.0, 12.0, 4.0, 10.0, 1.0, 7.0, 1.0, 12.0, 10.0, 12.0, 4.0]
    + [10.0, 8.0, 7.0, 11.0, 5.0, 9.0, 3.0, 6.0, 12.0, 6.0, 5.0, 3.0, 5.0, 9.0]
    + [8.0, 9.0, 10.0, 11.0, 12.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0]
    + [0.0] * 14,
    [1.0, 1.33, 3.33]
    + [5.0, 3.0, 2.0, 6.0, 12.0, 4.0, 10.0, 1.0, 7.0, 1.0, 12.0, 10.0, 12.0, 4.0]
    + [10.0, 8.0, 7.0, 11.0, 5.0, 9.0, 3.0, 6.0, 12.0, 6.0, 5.0, 3.0, 5.0, 9.0]
    + [8.0, 9.0, 10.0, 11.0, 12.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0]
    + [0.0] * 14,
    [2.0, 1.65, 2.07]
    + [5.0, 3.0, 2.0, 6.0, 12.0, 4.0, 10.0, 1.0, 7.0, 1.0, 12.0, 10.0, 12.0, 4.0]
    + [8.0, 6.0, 5.0, 9.0, 3.0, 7.0, 1.0, 4.0, 10.0, 4.0, 3.0, 1.0, 3.0, 7.0]
    + [10.0, 11.0, 12.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0]
    + [0.0] * 14,
])

        
def MyFunction(coefs, data):
    """Return the balance obtained by applying ``coefs`` to every row of ``data``.

    The balance starts at ``-1000`` per row. For each row, twelve accumulators
    are built from products of coefficients selected by the row's integer
    columns and then combined into two scores, ``fPoints`` and ``dPoints``:

    * ``fOdds * 1000`` is added when ``result == 1`` and ``fPoints >= dPoints``;
    * ``dOdds * 1000`` is added when ``result == 2`` and ``dPoints > fPoints``.

    ``coefs`` affects the return value only through those two comparisons;
    the amounts added to the balance never contain ``coefs``. The function is
    therefore piecewise constant in ``coefs``, and its derivative with respect
    to ``coefs`` is zero wherever it is defined.

    Offsets into ``coefs`` used below:
        ``p``                        per-term base coefficient (p in 0..13)
        ``14 + 12 * p + s``          coefficient selected by the ``s`` column
        ``182 + 12 * p + h``         coefficient selected by the ``h`` column
        ``350 + p``                  extra factor, applied only when ``r == 1``
        ``364 + 12 * k + hSign``     multiplier for accumulator ``k``
        ``508 + k`` / ``520 + k``    weights of accumulator ``k`` in
                                     ``fPoints`` / ``dPoints``

    Args:
        coefs: indexable sequence of coefficients; the fixed offsets above
            reach index 531.
        data: sized iterable of rows, each indexable at positions 0..56
            (result, fOdds, dOdds, 14 ``s``, 14 ``h``, 12 ``hSign``, 14 ``r``).

    Returns:
        The final balance.
    """
    n_terms = 14
    n_slots = 12

    balance = float(len(data) * -1000)

    for row in data:
        result, fOdds, dOdds = row[0], row[1], row[2]

        # One accumulator per slot; replaces the unrolled h1P..h12P variables.
        slots = [0.0] * n_slots

        for p in range(n_terms):
            s = int(row[3 + p] - 1)
            # NOTE(review): unlike `s` and `hSign`, `h` is used without
            # subtracting 1; a value outside 0..11 still selects a coefficient
            # but is added to no accumulator - confirm this is intended.
            h = int(row[17 + p])
            r = int(row[43 + p])

            base = coefs[p] * coefs[14 + 12 * p + s] * coefs[182 + 12 * p + h]
            pStrength = base * (coefs[350 + p] if r == 1 else 1.0)

            # Explicit bounds check: a negative `h` must not wrap around.
            if 0 <= h < n_slots:
                slots[h] += pStrength

        for k in range(n_slots):
            hSign = int(row[31 + k] - 1)
            slots[k] *= coefs[364 + 12 * k + hSign]

        # Accumulate left to right so floating-point results are unchanged.
        fPoints = 0.0
        dPoints = 0.0
        for k in range(n_slots):
            fPoints += slots[k] * coefs[508 + k]
        for k in range(n_slots):
            dPoints += slots[k] * coefs[520 + k]

        # The only place the scores matter: they gate a coefs-free payout.
        if result == 1 and fPoints >= dPoints:
            balance += fOdds * 1000
        elif result == 2 and dPoints > fPoints:
            balance += dOdds * 1000

    return balance


# Gradient of MyFunction; jax.grad differentiates with respect to the first
# positional argument (coefs) by default.
derivFunction = grad(MyFunction)

# First random draw of the 532 coefficients.
coefs = np.random.sample(532)
# Observed: this prints 532 zeros rather than non-zero derivatives.
print(derivFunction(coefs, data))

# Second, independent random draw; same call.
coefs = np.random.sample(532)
print(derivFunction(coefs, data))
1

There is 1 best solution below

0
On

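It appears that you're getting a zero gradient because this is the correct result: your function has a local gradient of zero at the input values. The reason is that `coefs` only enters the result through the comparisons `fPoints >= dPoints` and `dPoints > fPoints`; the amounts actually added to `balance` (`fOdds*1000` and `dOdds*1000`) do not depend on `coefs`, so the function is piecewise constant in `coefs`. One way to see this is by perturbing the coefficients and observing that it does not change the output value: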

# Evaluate MyFunction at coefs, coefs + 0.1 and coefs - 0.1.
# Each of the three calls printed -3000.0.
for shift in (0.0, 0.1, -0.1):
    print(MyFunction(coefs + shift, data))