I have been trying to understand this official example. However, I am very confused about the use of partial in two places.
For example, in line 94, we have the following:
conv = partial(self.conv, use_bias=False, dtype=self.dtype)
I am not sure why it is possible to apply partial to a class, or where later in the code the missing arguments are filled in (if they need to be).
Coming to the final definition, I am even more confused. For example,
ResNet18 = partial(ResNet, stage_sizes=[2, 2, 2, 2],
block_cls=ResNetBlock)
Where does an argument such as stage_sizes=[2, 2, 2, 2] actually get applied?
Thank you
functools.partial partially evaluates a function: it binds some arguments now, and those arguments are applied when the resulting callable is called later. It works the same way on an ordinary function and on a class constructor, because calling a class is just calling its constructor.
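A minimal plain-Python sketch of both cases follows; power, square, Greeter, and FrenchGreeter are hypothetical names used only for illustration:

```python
from functools import partial

def power(base, exponent):
    """Raise base to exponent."""
    return base ** exponent

# Bind exponent=2 now; base is supplied later, at call time.
square = partial(power, exponent=2)
print(square(5))  # 25

class Greeter:
    def __init__(self, name, greeting="Hello"):
        self.name = name
        self.greeting = greeting

    def greet(self):
        return f"{self.greeting}, {self.name}!"

# A class is callable (calling it runs the constructor), so partial works on it too.
FrenchGreeter = partial(Greeter, greeting="Bonjour")

# Equivalent to Greeter("Alice", greeting="Bonjour").
g = FrenchGreeter("Alice")
print(type(g).__name__)  # Greeter
print(g.greet())  # Bonjour, Alice!
```

The bound arguments stay inspectable on the partial object, e.g. FrenchGreeter.keywords is {"greeting": "Bonjour"}.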
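The use in the Flax example is conceptually the same: partial(f, *args, **kwargs) returns a new callable that, when called, applies the bound arguments to the original callable f, whether f is a function, a method, or a class constructor. For example, the ResNet18 function defined at the end of the example is a partially evaluated ResNet constructor, and it is called in one of the example's tests. There, model is the partially evaluated function ResNet18; when it is called, it returns a fully instantiated ResNet object with the parameters specified in the ResNet18 partial definition (stage_sizes=[2, 2, 2, 2] and block_cls=ResNetBlock), plus any arguments passed at call time. The conv line works the same way: self.conv holds a class, so partial(self.conv, use_bias=False, dtype=self.dtype) pre-binds two constructor arguments, and the remaining ones are supplied each time conv(...) is called later in the module.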
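The sketch below imitates both uses from the question with plain dataclasses instead of Flax. Conv, ResNetBlock, ResNet, and the stem_conv method are hypothetical stand-ins (Flax linen modules are also dataclasses, but the init/apply machinery is omitted here); only the two partial(...) lines mirror the example.

```python
from dataclasses import dataclass
from functools import partial
from typing import Any, Sequence

@dataclass
class Conv:
    """Hypothetical stand-in for a convolution module class."""
    features: int
    kernel_size: Sequence[int]
    use_bias: bool = True
    dtype: str = "float32"

@dataclass
class ResNetBlock:
    """Hypothetical stand-in for the example's ResNetBlock."""
    filters: int

@dataclass
class ResNet:
    """Hypothetical stand-in for the example's ResNet module."""
    stage_sizes: Sequence[int]
    block_cls: Any
    num_classes: int = 1000
    dtype: str = "float32"
    conv: Any = Conv  # this attribute holds a class, not an instance

    def stem_conv(self):
        # Mirrors `conv = partial(self.conv, use_bias=False, dtype=self.dtype)`:
        # self.conv is a class, so partial just pre-binds two constructor arguments.
        conv = partial(self.conv, use_bias=False, dtype=self.dtype)
        # The missing arguments (features, kernel_size) are filled in at this call.
        return conv(64, (7, 7))

# Pre-bind the architecture-specific arguments; no ResNet is constructed yet.
ResNet18 = partial(ResNet, stage_sizes=[2, 2, 2, 2], block_cls=ResNetBlock)

# Calling the partial runs
# ResNet(stage_sizes=[2, 2, 2, 2], block_cls=ResNetBlock, num_classes=10).
resnet = ResNet18(num_classes=10)
print(resnet.stage_sizes)  # [2, 2, 2, 2]
print(resnet.stem_conv())
```

Because ResNet18 is itself a functools.partial object, ResNet18.func is ResNet and ResNet18.keywords holds the bound stage_sizes and block_cls until the call supplies the rest.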