RMSNorm implementation's results are not close


I implemented the RMSNorm algorithm (ported from existing PyTorch code, without referring to any paper) and tested it. The test fails because the values are not within the tolerance.

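import tensorflow as tf

# Shapes used in the test (taken from the comment on self.batch below).
batch_size, block_size, n_embd = 4, 32, 32

class TestRMSNorm(tf.test.TestCase):

    def setUp(self):
        super(TestRMSNorm, self).setUp()
        self.batch = tf.random.normal((batch_size, block_size, n_embd))  # 4, 32, 32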

    def test_RMSNormTest(self):
        # Frobenius norm of each (block_size, n_embd) slice; norm has shape (4, 1, 1).
        normalized_mat, norm = tf.linalg.normalize(self.batch, axis=(1, 2))
        # RMS = norm / sqrt(number of elements per slice).
        ff_rms = tf.multiply(norm,
                             tf.pow(tf.cast(tf.size(self.batch[0]), tf.float32), -0.5))
        ffx = tf.Variable(tf.zeros_like(self.batch))
        print(tf.shape(ffx))
        # Divide each slice by its RMS.
        for i in range(self.batch.shape[0]):
            ffx[i, :, :].assign(tf.divide(self.batch[i], ff_rms[i]))
        # Note: the norm is recomputed from self.batch here, not from ffx.
        normalized_mat, norm = tf.linalg.normalize(self.batch, axis=(1, 2))
        print(tf.pow(norm, 2))
        self.assertAllClose(tf.pow(norm, 2),
                            tf.reshape(
                                tf.repeat([tf.constant(1024, tf.float32)], repeats=[4], axis=0),
                                (4, 1, 1)))

I think the shapes and values are clearly shown here. There is no other error, as I made sure the shapes match. Have I missed anything?

AssertionError: 
Not equal to tolerance rtol=1e-09, atol=0.0001
Mismatched value: a is different from b. 
not close where = (array([0, 1, 2, 3]), array([0, 0, 0, 0]), array([0, 0, 0, 0]))
not close lhs = [1019.3864 1056.9813 1021.6669 1046.128 ]
not close rhs = [1024. 1024. 1024. 1024.]
not close dif = [ 4.6135864 32.981323   2.33313   22.128052 ]
not close tol = [0.00010102 0.00010102 0.00010102 0.00010102]
dtype = float32, shape = (4, 1, 1)
Mismatched elements: 4 / 4 (100%)
Max absolute difference: 32.981323
Max relative difference: 0.03220832
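
The "not close tol" line in the output above is consistent with the NumPy-style closeness rule |a - b| <= atol + rtol * |b|. The sketch below reproduces that arithmetic in plain Python using the numbers from the log. The helper name `is_close` is hypothetical and is not part of TensorFlow.

```python
# Values copied from the assertion output above.
lhs = [1019.3864, 1056.9813, 1021.6669, 1046.128]
rhs = [1024.0, 1024.0, 1024.0, 1024.0]
rtol, atol = 1e-09, 0.0001


def is_close(a, b, rtol, atol):
    """Hypothetical helper: NumPy-style closeness check for two scalars."""
    return abs(a - b) <= atol + rtol * abs(b)


# Allowed difference per element: atol + rtol * |rhs|.
tolerances = [atol + rtol * abs(b) for b in rhs]
differences = [abs(a - b) for a, b in zip(lhs, rhs)]

strict = [is_close(a, b, rtol, atol) for a, b in zip(lhs, rhs)]
loose = [is_close(a, b, rtol=0.0, atol=50.0) for a, b in zip(lhs, rhs)]

print(tolerances)   # each about 1.01e-4
print(differences)
print(strict)       # no element passes at the logged tolerance
print(loose)        # every element passes once atol is 50
```

An absolute tolerance of 50 passes only because the largest difference in the log is about 33, so a tolerance that wide would also hide a genuine normalization error of similar size.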

My questions are these.

  1. Is the RMSNorm algorithm correct? If it is wrong, is there any material or code I should read to improve it?
  2. Can I use different tolerance levels to pass the test? The self.assertAllClose API takes tolerance levels as parameters, and if I pass 50 (for example) for the upper and lower limits, the test passes.

I can also ignore the failure as the values seem to be close.
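
On question 1, one thing that may be worth checking is which tensor the assertion measures. Dividing a slice by its RMS makes its sum of squares equal to the number of elements (1024 for a 32 x 32 slice), whereas the sum of squares of raw standard-normal samples is 1024 only on average. In the test, the second call to tf.linalg.normalize is applied to self.batch rather than to ffx. The sketch below illustrates the difference in plain Python rather than TensorFlow, and it omits the epsilon and learned gain that RMSNorm layers usually include. The helper `rms_normalize` is hypothetical.

```python
import math
import random


def rms_normalize(values):
    """Hypothetical helper: divide every value by the root mean square."""
    rms = math.sqrt(sum(v * v for v in values) / len(values))
    return [v / rms for v in values]


random.seed(0)
n = 32 * 32  # elements in one (block_size, n_embd) slice
raw = [random.gauss(0.0, 1.0) for _ in range(n)]
normalized = rms_normalize(raw)

raw_sum_sq = sum(v * v for v in raw)
norm_sum_sq = sum(v * v for v in normalized)

print(raw_sum_sq)   # near 1024 on average, but varies from sample to sample
print(norm_sum_sq)  # equals 1024 up to floating-point rounding
```

If that reading is right, computing the norm of ffx instead of self.batch in the final check should give values equal to 1024 within a small float32 tolerance. This is an untested suggestion, not a confirmed fix.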
