Python instability due to exponentials


Setup: I have the following function in Python, where x can get very large:

import numpy as np

def function(x, pi):
  d = len(pi)
  output = 0
  for r in range(d):
    output += pi[r] * np.exp(-x)
  return output

Input description: x can be very large, which causes np.exp(-x) to underflow to zero, so the entire function returns zero. pi is a vector of probabilities (e.g., [0.5, 0.5]).

Question: Is there a more stable way to implement this function such that it wouldn't lead to the output being zero? Thanks.

Edit: I have decided to give more details since it was asked in the comments. The entire function is

def entire_function(x_array, pi, r):
  d = len(pi)
  numerator = np.exp(-x_array[r])
  denominator = 0
  for r_prime in range(d):
    denominator += pi[r_prime] * np.exp(-x_array[r_prime])
  return numerator / denominator 

Even trying to use np.log doesn't really help. For example:

a = np.array([np.exp(-900), np.exp(-800)])
print(np.log(a[0]+a[1]))

This gives me -inf, because both exponentials underflow to zero before the sum is taken. The summation in the denominator is the part giving me trouble, since it prevents me from working with the exponents directly (which would make the computation more numerically stable). I guess this issue is similar to the logsumexp examples in machine learning, but with the extra pi[r] factors in front.

Best answer, by hbwales

Note that in general we have:

$p\,e^{x} = e^{\log p}\,e^{x} = e^{\log p + x}$
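
Applied to the expression in the question, the identity moves each weight into its exponent, so the whole ratio can be written as a single exponential. The notation here is assumed for illustration: $\pi_i$ stands for the entries of pi, $x_i$ for the entries of x_array, and every $\pi_i$ is taken to be strictly positive so that $\log \pi_i$ is finite.

$$
\frac{e^{-x_r}}{\sum_i \pi_i\, e^{-x_i}}
  = \exp\!\Big(-x_r - \log \sum_i e^{-x_i + \log \pi_i}\Big),
\qquad
\log \sum_i e^{a_i} = c + \log \sum_i e^{a_i - c},
\quad c = \max_i a_i .
$$

With $a_i = -x_i + \log \pi_i$, the largest shifted exponent $a_i - c$ is exactly 0, so at least one term of the inner sum equals 1 and the logarithm never receives zero.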

Using this, we can apply the log-sum-exp trick you linked:

import numpy as np

xs = np.array([700, 900])
ps = np.array([0.6, 0.4])

def original(xs, ps, r):
    ex = np.exp(-xs) 
    return ex[r] / (ps*ex).sum()

def log_sum_exp(x):
    c = x.max()
    return c + np.log(np.sum(np.exp(x - c)))

def adjusted(xs, ps, r):
    return np.exp(-xs[r] - log_sum_exp(-xs + np.log(ps)))

We can check this against a high-precision reference:

def check(xs, ps, r):
    # calculate to 1,000 decimal places to check result against
    from decimal import Decimal, getcontext
    getcontext().prec = 1000
    ex = [Decimal.exp(-Decimal(float(xi))) for xi in xs]
    return ex[r] / sum(Decimal(float(pi))*ei for pi,ei in zip(ps, ex))

print([adjusted(xs, ps, i) for i in range(2)])       # [1.6666666666666516, 2.306494211227875e-87]
print([float(check(xs, ps, i)) for i in range(2)])   # [1.6666666666666667, 2.306494211227896e-87]
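
As a follow-up sketch (not part of the answer above), the same idea can be vectorized so that the ratio is returned for every r in one call. The name `stable_ratios` is hypothetical, and the sketch assumes every entry of ps is strictly positive, since np.log(0) would give -inf.

```python
import numpy as np

def stable_ratios(xs, ps):
    """Return exp(-xs[r]) / sum(ps * exp(-xs)) for every r, computed in log space."""
    xs = np.asarray(xs, dtype=float)
    ps = np.asarray(ps, dtype=float)
    a = -xs + np.log(ps)   # exponents with the weights folded in
    c = a.max()            # shift so the largest exponent becomes 0
    log_denominator = c + np.log(np.sum(np.exp(a - c)))
    # Subtract in log space first; exponentiate only once at the end.
    return np.exp(-xs - log_denominator)

print(stable_ratios([700, 900], [0.6, 0.4]))
print(stable_ratios([800, 900], [0.5, 0.5]))
```

The second call uses inputs for which both np.exp(-800) and np.exp(-900) underflow to zero, so the direct formula would divide zero by zero, while the log-space version stays finite.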