Failing to understand the use of partial in the official Flax ResNet example

I have been trying to understand this official example. However, I am very confused about the use of partial in two places.

For example, in line 94, we have the following:

conv = partial(self.conv, use_bias=False, dtype=self.dtype)

I am not sure why partial can be applied to a class, or where later in the code the missing arguments are filled in (if they need to be).

Coming to the final definition, I am even more confused. For example,

ResNet18 = partial(ResNet, stage_sizes=[2, 2, 2, 2],
                   block_cls=ResNetBlock)

Where is an argument such as stage_sizes=[2, 2, 2, 2] actually applied?

Thank you

Best answer

functools.partial partially applies a callable: it binds some arguments now and returns a new callable that supplies them when it is called later. Here's an example of it being used with a function:

from functools import partial

def f(x, y, z):
  print(f"{x=} {y=} {z=}")

g = partial(f, 1, z=3)
g(2)
# x=1 y=2 z=3
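To make the mechanism concrete, here is a minimal hand-rolled equivalent. It is a simplified sketch, and the name my_partial is hypothetical; the real functools.partial is implemented as a class rather than a closure, but it merges arguments the same way: bound positional arguments come first, and keyword arguments passed at call time override the bound ones.

```python
from functools import partial


def my_partial(func, *bound_args, **bound_kwargs):
    """Hypothetical, simplified stand-in for functools.partial."""
    def wrapper(*args, **kwargs):
        # Bound positional args go first; call-time keywords override bound ones.
        merged_kwargs = {**bound_kwargs, **kwargs}
        return func(*bound_args, *args, **merged_kwargs)
    return wrapper


def f(x, y, z):
    return (x, y, z)


g = my_partial(f, 1, z=3)
h = partial(f, 1, z=3)
print(g(2), h(2))              # (1, 2, 3) (1, 2, 3)
print(g(2, z=30), h(2, z=30))  # (1, 2, 30) (1, 2, 30)
```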

And here is an example of it being used on a class constructor:

from functools import partial
from typing import NamedTuple

class MyClass(NamedTuple):
  a: int
  b: int
  c: int

make_class = partial(MyClass, 1, c=3)
print(make_class(b=2))
# MyClass(a=1, b=2, c=3)
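A partial object also records what it wraps and what it has bound so far, which can help when reading definitions like ResNet18. The func, args and keywords attributes are part of the standard functools API; MyClass is repeated here so the snippet runs on its own.

```python
from functools import partial
from typing import NamedTuple


class MyClass(NamedTuple):
    a: int
    b: int
    c: int


make_class = partial(MyClass, 1, c=3)

# The partial object exposes what it wraps and what it has bound.
print(make_class.func is MyClass)  # True
print(make_class.args)             # (1,)
print(make_class.keywords)         # {'c': 3}

# Keyword arguments given at call time override the bound ones.
print(make_class(b=2, c=30))       # MyClass(a=1, b=2, c=30)
```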

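The use in the Flax example is conceptually the same: partial(f, ...) returns a callable that, when called, passes the bound arguments, together with any new ones, to the original callable, whether that is a function, a method, or a class constructor. A class can be wrapped this way because calling a class is just another call: it runs the constructor. In that example self.conv is a module class (nn.Conv by default), so conv = partial(self.conv, use_bias=False, dtype=self.dtype) pre-fills the options shared by every convolution, and the remaining arguments, such as the number of features and the kernel size, are supplied wherever conv(...) is later called.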
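The pattern can be sketched without Flax. In this sketch Conv, TinyBlock and TinyResNet are hypothetical dataclass stand-ins for nn.Conv, ResNetBlock and ResNet that only record their arguments (the real modules also create parameters and apply the layers to inputs), and the string "float32" stands in for a real dtype. It shows the shared options being bound once, the partial being handed to a block as its conv field, and the per-layer arguments being filled in at each call site.

```python
from dataclasses import dataclass
from functools import partial
from typing import Any, Callable, List, Tuple


@dataclass
class Conv:
    """Hypothetical stand-in for flax.linen.Conv that only records its arguments."""
    features: int
    kernel_size: Tuple[int, int]
    strides: Tuple[int, int] = (1, 1)
    use_bias: bool = True
    dtype: Any = None


@dataclass
class TinyBlock:
    """Hypothetical stand-in for ResNetBlock: receives a conv factory as a field."""
    filters: int
    conv: Callable[..., Conv]

    def layers(self) -> List[Conv]:
        # The block only supplies what varies per layer: features and kernel size.
        return [self.conv(self.filters, (3, 3)), self.conv(self.filters, (3, 3))]


@dataclass
class TinyResNet:
    """Hypothetical, heavily simplified stand-in for the example's ResNet."""
    num_filters: int = 64
    dtype: Any = "float32"
    conv: Callable[..., Conv] = Conv  # a class stored as a field, like nn.Conv

    def build(self):
        # Bind the options shared by every convolution once...
        conv = partial(self.conv, use_bias=False, dtype=self.dtype)
        # ...then fill in the per-layer arguments at each call site.
        stem = conv(self.num_filters, (7, 7), (2, 2))
        block = TinyBlock(self.num_filters, conv=conv)
        return stem, block.layers()


stem, block_layers = TinyResNet().build()
print(stem)
# Conv(features=64, kernel_size=(7, 7), strides=(2, 2), use_bias=False, dtype='float32')
print(block_layers[0])
# Conv(features=64, kernel_size=(3, 3), strides=(1, 1), use_bias=False, dtype='float32')
```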

For example, the ResNet18 callable created here:

ResNet18 = partial(ResNet, stage_sizes=[2, 2, 2, 2],
                   block_cls=ResNetBlock)

is a partially applied ResNet constructor, and it is called in a test here:

  @parameterized.product(
      model=(models.ResNet18, models.ResNet18Local)
  )
  def test_resnet_18_v1_model(self, model):
    """Tests ResNet18 V1 model definition and output (variables)."""
    rng = jax.random.PRNGKey(0)
    model_def = model(num_classes=2, dtype=jnp.float32)
    variables = model_def.init(
        rng, jnp.ones((1, 64, 64, 3), jnp.float32))

    self.assertLen(variables, 2)
    self.assertLen(variables['params'], 11)

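Here model is the partial object ResNet18 (or ResNet18Local). Calling model(num_classes=2, dtype=jnp.float32) returns a fully instantiated ResNet module whose stage_sizes and block_cls come from the ResNet18 partial definition, so that call is where stage_sizes=[2, 2, 2, 2] is actually applied.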
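The same flow can be reproduced with plain dataclasses. ResNet and ResNetBlock below are hypothetical stand-ins that only record their fields (the real ones are Flax modules), and ResNet34 mirrors how the example defines further variants with a different stage_sizes. The point is that nothing is constructed when ResNet18 is defined; the bound arguments are merged in only when ResNet18(...) is called.

```python
from dataclasses import dataclass
from functools import partial
from typing import Any, Sequence


@dataclass
class ResNetBlock:
    """Hypothetical placeholder for the example's residual block class."""
    filters: int


@dataclass
class ResNet:
    """Hypothetical stand-in for the example's ResNet: only records its fields."""
    stage_sizes: Sequence[int]
    block_cls: Any
    num_classes: int
    dtype: Any = "float32"


# Defining a variant only stores the constructor and the bound keywords.
ResNet18 = partial(ResNet, stage_sizes=[2, 2, 2, 2], block_cls=ResNetBlock)
ResNet34 = partial(ResNet, stage_sizes=[3, 4, 6, 3], block_cls=ResNetBlock)

# This call is where stage_sizes and block_cls are finally passed to ResNet.
model = ResNet18(num_classes=2)
print(model.stage_sizes, model.num_classes)  # [2, 2, 2, 2] 2
```

Defining variants as partials keeps a single ResNet class: each variant is just a different set of pre-filled constructor arguments, and the caller still supplies the rest (num_classes, dtype, and so on).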