Correct way to combine Normal distributions in Julia with Distributions.jl

I have two multivariate normal distributions, defined as follows:

using Distributions, LinearAlgebra
g1 = MvNormal([1,2], [2 1; 1 2])
A = [3 1; 1 3]
B = [[A [0;0]]; transpose([0,0,1])]
g2 = MvNormal([1,2,3], B)
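
Note that B is simply A with an independent, unit-variance third coordinate appended, i.e. a block-diagonal matrix. Below is a quick Base-only check, offered as a sketch. cat with dims=(1, 2) places its arguments along the diagonal and zero-fills the rest. fill(1, 1, 1) is used instead of [1;;] so that it also runs on Julia versions before 1.7.

```julia
A = [3 1; 1 3]
B = [[A [0;0]]; transpose([0,0,1])]

# cat along both dimensions builds a block-diagonal matrix [A 0; 0 1]
B_blockdiag = cat(A, fill(1, 1, 1); dims=(1, 2))

println(B == B_blockdiag)
```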

I would like to combine them into a single distribution over the concatenation of both random vectors, assuming they are independent. The function product_distribution seems like it should do the trick:

g3 = product_distribution(g1, g2)

But that results in an error:

ERROR: all distributions must be of the same size

I don't understand this error. It makes me think that this function's purpose is not what I thought it was, but I can't find any other function that would be more appropriate.
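
For context, the error comes from what product_distribution does in recent Distributions.jl 0.25.x releases. Its documentation describes building a product of independent array-variate distributions by stacking them along a new dimension. Two multivariate components therefore yield a matrix-variate distribution, with one column per component, and every component must have the same size. The sketch below assumes such a version and stacks g1 with h, a hypothetical second 2-dimensional component:

```julia
using Distributions, LinearAlgebra

g1 = MvNormal([1.0, 2.0], [2.0 1.0; 1.0 2.0])
h  = MvNormal([0.0, 0.0], Matrix(1.0I, 2, 2))  # hypothetical component of the same size as g1

p = product_distribution(g1, h)  # stacks g1 and h as columns, it does not concatenate them
x = rand(p)                       # a 2×2 matrix; column i is a draw from component i
println(size(x))
```

With g1 (length 2) and g2 (length 3), this same-size check fails, which produces the error above. Concatenating into one longer vector is a different operation.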

To be clear, the desired output should be equivalent to:

m3 = vcat(mean(g1), mean(g2))
s3 = hvcat( (2,2), cov(g1), zeros(2,3), zeros(3,2), cov(g2))
g3 = MvNormal(m3, s3)

(Perhaps a sparse matrix or some other optimised block-diagonal matrix type would be more appropriate, but I really don't care in this case.)
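
The desired output can be wrapped in a small helper for two independent MvNormals. This is a sketch rather than a library function, and concat_independent is a hypothetical name. It swaps hvcat for Base's cat(...; dims=(1, 2)), which builds the same dense block-diagonal covariance without spelling out the zero blocks:

```julia
using Distributions

# Hypothetical helper: joint distribution of [x1; x2] for independent x1 ~ d1, x2 ~ d2.
# Independence means the off-diagonal covariance blocks are zero.
function concat_independent(d1::MvNormal, d2::MvNormal)
    μ = vcat(mean(d1), mean(d2))
    Σ = cat(cov(d1), cov(d2); dims=(1, 2))  # dense block-diagonal matrix
    return MvNormal(μ, Σ)
end

g1 = MvNormal([1.0, 2.0], [2.0 1.0; 1.0 2.0])
g2 = MvNormal([1.0, 2.0, 3.0], [3.0 1.0 0.0; 1.0 3.0 0.0; 0.0 0.0 1.0])
g3 = concat_independent(g1, g2)
println(mean(g3))
```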

Accepted answer by Dan Getz:

I couldn't find a clean built-in way to do this, although the issue has come up before. The best I can suggest so far:

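using BlockDiagonals, Distributions

# Concatenate independent (Mv)Normal distributions into a single MvNormal
# whose covariance is block-diagonal.
concat(ds::Union{MvNormal, Normal}...) =
  foldl(ds; init = MvNormal(Float64[], zeros((0,0)))) do x, D
      # promote a univariate Normal to a 1-dimensional MvNormal
      # ([v;;] is a 1×1 matrix literal, Julia >= 1.7)
      d = D isa Normal ? MvNormal([mean(D)], [var(D);;]) : D
      m1, c1 = mean(x), cov(x)
      m2, c2 = mean(d), cov(d)
      # independence => zero off-diagonal blocks
      return MvNormal(vcat(m1, m2), BlockDiagonal([c1, c2]))
  end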

giving:

julia> g1 = MvNormal([1,2], [2 1; 1 2]);

julia> A = [3 1; 1 3];

julia> B = BlockDiagonal([A,[1;;]]);

julia> g2 = MvNormal([1,2,3], B);

julia> g12 = concat(g1,g2)
MvNormal{Float64, PDMats.PDMat{Float64, BlockDiagonal{Float64, Matrix{Float64}}}, Vector{Float64}}(
dim: 5
μ: [1.0, 2.0, 1.0, 2.0, 3.0]
Σ: [2.0 1.0 … 0.0 0.0; 1.0 2.0 … 0.0 0.0; … ; 0.0 0.0 … 3.0 0.0; 0.0 0.0 … 0.0 1.0]
)

julia> rand(g12)
5-element Vector{Float64}:
  1.0362522596298223
  2.0469956784329866
 -0.21925320262748982
 -1.2334419613775114
  3.2146164519549814
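
The fold in concat starts from an empty, 0-dimensional MvNormal, so the first distribution is simply appended to nothing. The sketch below applies the same accumulation pattern to plain covariance matrices, using Base's cat(...; dims=(1, 2)) instead of BlockDiagonals. It produces a dense matrix, so it assumes the blocks are small enough for that to be acceptable. The resulting Σ could then be passed to MvNormal together with the concatenated means.

```julia
# Covariance blocks of the independent pieces (values from the example above)
blocks = [[2.0 1.0; 1.0 2.0], [3.0 1.0; 1.0 3.0], fill(1.0, 1, 1)]

# Start from an empty 0×0 matrix: appending a block to it yields just that block,
# so it plays the same role as the empty MvNormal used as `init` in `concat`.
Σ = foldl(blocks; init = zeros(0, 0)) do acc, C
    cat(acc, C; dims=(1, 2))  # place C below-right of acc, zero-filling the rest
end

println(size(Σ))
```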