I was trying to understand how to use `tfp.bijectors.Affine` properly and to test whether its output matches what I expect, so I wrote the following code:
```python
import tensorflow as tf
import tensorflow_probability as tfp

# 1-D standard normal; sample1 has shape [5, 1]
distribution = tfp.distributions.MultivariateNormalDiag(loc=tf.zeros(1), scale_diag=tf.ones(1))
sample1 = distribution.sample(5)

# Affine bijector with shift 1 and scale 2, i.e. y = 2x + 1
mu = tf.Variable([1], dtype=tf.float32)
sigma = tf.Variable([2], dtype=tf.float32)
bijector = tfp.bijectors.Affine(shift=mu, scale_diag=sigma)
raw_action = bijector.forward(sample1)

with tf.Session() as sess:
    init_op = tf.global_variables_initializer()
    sess.run(init_op)
    print(sess.run(sample1))
    print(sess.run(raw_action))
```
Based on what I have learned, every element of `raw_action` (call it y) should be one more than twice its counterpart in `sample1` (call it x), because y = 2x + 1. But when I ran the code above, this is what I got:
```
[[-0.90748686]
 [ 0.26501548]
 [ 1.6397986 ]
 [ 0.00422014]
 [ 1.4650348 ]]
[[-4.671192  ]
 [ 3.8869004 ]
 [ 4.4971347 ]
 [ 0.02692574]
 [-0.38139176]]
```
This output clearly does not satisfy y = 2x + 1.

I then fixed the value of `sample1` by making it a `tf.Variable` and tested again:
```python
import tensorflow as tf
import tensorflow_probability as tfp

distribution = tfp.distributions.MultivariateNormalDiag(loc=tf.zeros(1), scale_diag=tf.ones(1))
# Same values as the first printed sample, now held in a variable
sample1 = tf.Variable([[-0.90748686], [0.26501548], [1.6397986], [0.00422014], [1.4650348]], dtype=tf.float32)

# Affine bijector with shift 1 and scale 2, i.e. y = 2x + 1
mu = tf.Variable([1], dtype=tf.float32)
sigma = tf.Variable([2], dtype=tf.float32)
bijector = tfp.bijectors.Affine(shift=mu, scale_diag=sigma)
raw_action = bijector.forward(sample1)

with tf.Session() as sess:
    init_op = tf.global_variables_initializer()
    sess.run(init_op)
    print(sess.run(sample1))
    print(sess.run(raw_action))
```
This time I got what I expected, a straightforward application of y = 2x + 1:
```
[[-0.90748686]
 [ 0.26501548]
 [ 1.6397986 ]
 [ 0.00422014]
 [ 1.4650348 ]]
[[-0.8149737]
 [ 1.530031 ]
 [ 4.2795973]
 [ 1.0084403]
 [ 3.9300697]]
```
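As a quick sanity check, here is a plain-Python sketch (no TensorFlow needed) that compares both printed outputs against y = 2x + 1 by computing the residual y - (2x + 1) for each pair. The numbers are copied from the printed outputs, which are rounded, so the second run can only be expected to match to roughly 1e-6. The helper name `residuals` is just for illustration.

```python
# Printed values of sample1 (the same in both runs).
xs = [-0.90748686, 0.26501548, 1.6397986, 0.00422014, 1.4650348]

# Printed values of raw_action from the first (sampled) and second (tf.Variable) runs.
first_run = [-4.671192, 3.8869004, 4.4971347, 0.02692574, -0.38139176]
second_run = [-0.8149737, 1.530031, 4.2795973, 1.0084403, 3.9300697]


def residuals(xs, ys, scale=2.0, shift=1.0):
    """Return y - (scale * x + shift) for each (x, y) pair."""
    return [y - (scale * x + shift) for x, y in zip(xs, ys)]


first_res = residuals(xs, first_run)
second_res = residuals(xs, second_run)

print("first run: ", [round(r, 6) for r in first_res])
print("second run:", [round(r, 6) for r in second_res])
```

Every residual from the first run is far from zero, while the second run's residuals are zero up to the printed precision.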
I have no idea why these two blocks behave differently. Any clues or answers would be appreciated!
- Python version: 3.10.9
- TensorFlow version: 1.13.1
- TensorFlow Probability version: 0.6.0
- Platform: Jupyter notebook