Why is this stratified Bayesian logit so slow in Turing.jl?

I'm trying to build a Bayesian logistic regression that relates the number of payments a person has made to their probability of default. I created synthetic data to check that the model can be fitted before moving on to real data:

using Turing
using StatsFuns: logistic  # logistic() is defined in StatsFuns

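# A payer counts as "bad" when the average delay exceeds 70 days.
is_bad_pay(x) = x > 70 ? 1 : 0

# Simulate N people, each with between 1 and Pₙ payments.
function simulate_payment_frequency(Pₙ, N)
    P = rand(DiscreteUniform(1, Pₙ), N)     # number of payments
    avg_delay = rand(LogNormal(2, 2), N)    # average days between payments
    payers = [is_bad_pay(x) for x in avg_delay]
    return payers, P, avg_delay
end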

I have three variables:

  1. A binary variable that is 1 if the person is a bad payer and 0 otherwise (payers)
  2. The number of payments (P)
  3. The average number of days between payments (avg_delay)

The model is the following:

# Fit the simulated data.
@model function freq_pay(prob_pay, number_payments, avg_delay)
    # One stratum per distinct number of payments.
    Num_payments = length(unique(number_payments))
    # Separate intercept and slope for each stratum.
    αₛ ~ filldist(Normal(60, 10), Num_payments)
    βₛ ~ filldist(Normal(0, 1), Num_payments)
    v = @. logistic(αₛ[number_payments] + βₛ[number_payments] * avg_delay)
    # Logistic regression likelihood.
    for i ∈ eachindex(v)
        prob_pay[i] ~ Bernoulli(v[i])
    end
end

I first tried simulating people with at most two payments, and it works well:

synthetic_payers = simulate_payment_frequency(2, 100)
s1_1 = sample(freq_pay(synthetic_payers[1], synthetic_payers[2], synthetic_payers[3]), NUTS(), 100)

However, when I allow more than 3 payments, sampling never finishes:

synthetic_payers = simulate_payment_frequency(4, 100)
s1_2 = sample(freq_pay(synthetic_payers[1], synthetic_payers[2], synthetic_payers[3]), NUTS(), 100)

What am I doing wrong?

There is 1 answer below.

Emiliano Isaza Villamizar answered:

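OK, so ~ Bernoulli is numerically unstable here: logistic can round to exactly 0 or 1 in floating point, which makes the log-likelihood -Inf, so convergence isn't assured. Turing provides a distribution that is parameterised directly by the logit, BernoulliLogit. This is the answer: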

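    # BernoulliLogit expects log-odds, so v must not go through logistic()
    v = @. αₛ[number_payments] + βₛ[number_payments] * avg_delay
    for i ∈ eachindex(v)
        prob_pay[i] ~ BernoulliLogit(v[i])
    end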

Just stay calm and use BernoulliLogit. (The Turing.jl documentation is wrong, by the way: it uses Bernoulli for the logit example.)
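
As a rough illustration of the instability mentioned in this answer, the sketch below uses plain Julia (no Turing) to compare a Bernoulli log-likelihood computed from the probability with one computed from the log-odds. The helper names (naive_logistic, softplus, loglik_from_prob, loglik_from_logit) are made up for this example, and the stable formula only shows the general idea behind a logit-parameterised likelihood; it is not a copy of how BernoulliLogit is implemented.

```julia
# Plain-Julia illustration; no packages required.
# All function names here are hypothetical helpers for this example.

# Naive logistic function, as used when forming Bernoulli(logistic(η)).
naive_logistic(η) = 1 / (1 + exp(-η))

# Numerically stable softplus: log(1 + exp(x)).
softplus(x) = x > 0 ? x + log1p(exp(-x)) : log1p(exp(x))

# Log-likelihood of one observation y ∈ {0, 1}, computed from the probability.
loglik_from_prob(y, η) = y == 1 ? log(naive_logistic(η)) : log(1 - naive_logistic(η))

# The same log-likelihood, computed directly from the log-odds η.
loglik_from_logit(y, η) = y == 1 ? -softplus(-η) : -softplus(η)

η = 60.0  # a value near the centre of the Normal(60, 10) intercept prior

println(naive_logistic(η) == 1.0)   # true: the probability rounds to exactly 1
println(loglik_from_prob(0, η))     # -Inf: log(1 - 1.0)
println(loglik_from_logit(0, η))    # -60.0: finite when computed from the log-odds
```

With the Normal(60, 10) prior on the intercepts, linear predictors around 60 are typical, so under the first form any observation with outcome 0 can contribute -Inf to the log-density, while the second form stays finite.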