I'm attempting to select a single sample from a range of Normal distributions based on the output of a categorical distribution, but I can't seem to come up with quite the right way to do it. Using something along the lines of:
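```python
import tensorflow_probability as tfp

# Choose one of four Normals according to the Categorical draw `c`.
tfp.distributions.JointDistributionSequential([
    tfp.distributions.Categorical(probs=[0, 0, 1/2, 1/2]),
    lambda c: tfp.distributions.Normal([0, 1, -10, 30], 1)[..., c]
])
```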
This returns exactly what I want for a single sample. However, if I want multiple samples at once, it breaks (because `c` becomes an array of indices rather than a single integer). Is this possible, and if so, how should I go about it?

(I also tried using `OneHotCategorical` and multiplying, but that didn't work at all!)
You could do this with `tf.gather`, if you don't want to use `MixtureSameFamily` as Brian suggests. Note that I needed to add a `.` to the locs in the gather (making them floats) to avoid a dtype error.

Here, what we end up doing is drawing `n` samples from the `Categorical`, then building `n` `Normal`s whose locs are obtained by indexing `n` times into the 4-vector of locs. The result is an `n`-batch of `Normal`s.

The previous approach doesn't work because `Distribution` slicing doesn't support this kind of "fancy indexing". It would be cool if we did! TF's own tensor indexing doesn't support it in general either, for reasons.