How to properly define a tf.Variable when I have a number of blocks


I have just started moving from PyTorch to TensorFlow and am running into problems designing residual blocks. I have a residual group that contains a number of residual blocks, and each block contains two custom layers. I am not sure how to define the variables that need to be used as part of an operation in each layer's call() function.

(Illustration of my model framework: a residual group containing several residual blocks, each with two custom layers.)

I tried defining the variable with something like self.W = tf.Variable(). But this way, when I initialize the residual group, self.W is continuously overwritten, and when I try to use self.W inside each layer's call() function to retrieve the parameter, I get None.

In PyTorch, I can simply use register_parameter to define the variable in __init__ and then access it as self.W in the forward() function.
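
In TensorFlow/Keras, the closest equivalent is to create the weight in the layer's build() (or __init__) with add_weight, so each layer instance owns its own self.W and call() can use it. A minimal sketch, assuming a simple element-wise weight (the names MyLayer and ResidualBlock and the multiply are illustrative, not the actual architecture):

import tensorflow as tf

class MyLayer(tf.keras.layers.Layer):
  def build(self, input_shape):
    # Each layer instance gets its own weight, so stacking
    # several blocks does not overwrite self.W.
    self.W = self.add_weight(
        name="W",
        shape=(input_shape[-1],),
        initializer="ones",
        trainable=True,
    )

  def call(self, x):
    return x * self.W

class ResidualBlock(tf.keras.layers.Layer):
  def __init__(self, **kwargs):
    super().__init__(**kwargs)
    self.layer1 = MyLayer()
    self.layer2 = MyLayer()

  def call(self, x):
    # Residual connection around the two custom layers.
    return x + self.layer2(self.layer1(x))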

Could anyone familiar with TensorFlow help me with this? Thanks.


There is 1 answer below.


You can define the variable once in __init__; creating it inside __call__ would re-create and overwrite it on every call, which is exactly the problem you describe:

import tensorflow as tf

class M(tf.Module):
  def __init__(self, initial_value):
    super().__init__()
    self.v = tf.Variable(initial_value)  # created once, tracked by the module

  def __call__(self, x):
    return self.v * x

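For example, calling the module reuses the same variable, and tf.Module tracks it automatically (the values here are illustrative):

m = M(tf.ones(3))
print(m(tf.constant([1.0, 2.0, 3.0])))  # same variable reused on every call
print(m.trainable_variables)            # contains m.v
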
Thank You.