Repeated Branches in TorchScript Model Export

I'm trying to export a PyTorch model to TorchScript via scripting and I am stuck. I've created a toy class to showcase the issue:

import torch
from torch import nn


class SadModule(nn.Module):
    """Takes a (*, 2) input and runs it thorugh a linear layer. Can optionally
    use a skip connection. The usage of the skip connection or not is an
    architectural choice.

    """
    def __init__(self, use_skip: bool):
        nn.Module.__init__(self)
        self.use_skip = use_skip
        self.layer = nn.Linear(2, 2)

    def forward(self, x):
        if self.use_skip:
            x_input = x
        x = self.layer(x)
        if self.use_skip:
            x = x + x_input
        return x

It basically consists of only a linear layer and an optional skip connection. If I try to script the model using

mod1 = SadModule(False)
scripted_mod1 = torch.jit.script(mod1)

I get the following error:

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-10-a7ebc7af32c7> in <module>
----> 1 scripted_mod1 = torch.jit.script(mod1)

~/Software/miniconda3/envs/pytorch3d/lib/python3.8/site-packages/torch/jit/_script.py in script(obj, optimize, _frames_up, _rcb)
    895
    896     if isinstance(obj, torch.nn.Module):
--> 897         return torch.jit._recursive.create_script_module(
    898             obj, torch.jit._recursive.infer_methods_to_compile
    899         )

~/Software/miniconda3/envs/pytorch3d/lib/python3.8/site-packages/torch/jit/_recursive.py in create_script_module(nn_module, stubs_fn, share_types)
    350     check_module_initialized(nn_module)
    351     concrete_type = get_module_concrete_type(nn_module, share_types)
--> 352     return create_script_module_impl(nn_module, concrete_type, stubs_fn)
    353
    354 def create_script_module_impl(nn_module, concrete_type, stubs_fn):

~/Software/miniconda3/envs/pytorch3d/lib/python3.8/site-packages/torch/jit/_recursive.py in create_script_module_impl(nn_module, concrete_type, stubs_fn)
    408     # Compile methods if necessary
    409     if concrete_type not in concrete_type_store.methods_compiled:
--> 410         create_methods_and_properties_from_stubs(concrete_type, method_stubs, property_stubs)
    411         torch._C._run_emit_module_hook(cpp_module)
    412         concrete_type_store.methods_compiled.add(concrete_type)

~/Software/miniconda3/envs/pytorch3d/lib/python3.8/site-packages/torch/jit/_recursive.py in create_methods_and_properties_from_stubs(concrete_type, method_stubs, property_stubs)
    302     property_rcbs = [p.resolution_callback for p in property_stubs]
    303
--> 304     concrete_type._create_methods_and_properties(property_defs, property_rcbs, method_defs, method_rcbs, method_defaults)
    305
    306

RuntimeError:

x_input is not defined in the false branch:
  File "<ipython-input-7-d08ed7ff42ec>", line 12
    def forward(self, x):
        if self.use_skip:
        ~~~~~~~~~~~~~~~~~
            x_input = x
            ~~~~~~~~~~~ <--- HERE
        x = self.layer(x)
        if self.use_skip:
and was used here:
  File "<ipython-input-7-d08ed7ff42ec>", line 16
        x = self.layer(x)
        if self.use_skip:
            x = x + x_input
                    ~~~~~~~ <--- HERE
        return x

So, basically, TorchScript isn't able to recognise that for mod1 the true branch of either if statement will never be taken. Moreover, if we create an instance that actually uses the skip connection,

mod2 = SadModule(True)
scripted_mod2 = torch.jit.script(mod2)

we will get another error:

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-21-b5ca61d8aa73> in <module>
----> 1 scripted_mod2 = torch.jit.script(mod2)

~/Software/miniconda3/envs/pytorch3d/lib/python3.8/site-packages/torch/jit/_script.py in script(obj, optimize, _frames_up, _rcb)
    895
    896     if isinstance(obj, torch.nn.Module):
--> 897         return torch.jit._recursive.create_script_module(
    898             obj, torch.jit._recursive.infer_methods_to_compile
    899         )

~/Software/miniconda3/envs/pytorch3d/lib/python3.8/site-packages/torch/jit/_recursive.py in create_script_module(nn_module, stubs_fn, share_types)
    350     check_module_initialized(nn_module)
    351     concrete_type = get_module_concrete_type(nn_module, share_types)
--> 352     return create_script_module_impl(nn_module, concrete_type, stubs_fn)
    353
    354 def create_script_module_impl(nn_module, concrete_type, stubs_fn):

~/Software/miniconda3/envs/pytorch3d/lib/python3.8/site-packages/torch/jit/_recursive.py in create_script_module_impl(nn_module, concrete_type, stubs_fn)
    408     # Compile methods if necessary
    409     if concrete_type not in concrete_type_store.methods_compiled:
--> 410         create_methods_and_properties_from_stubs(concrete_type, method_stubs, property_stubs)
    411         torch._C._run_emit_module_hook(cpp_module)
    412         concrete_type_store.methods_compiled.add(concrete_type)

~/Software/miniconda3/envs/pytorch3d/lib/python3.8/site-packages/torch/jit/_recursive.py in create_methods_and_properties_from_stubs(concrete_type, method_stubs, property_stubs)
    302     property_rcbs = [p.resolution_callback for p in property_stubs]
    303
--> 304     concrete_type._create_methods_and_properties(property_defs, property_rcbs, method_defs, method_rcbs, method_defaults)
    305
    306

RuntimeError:

x_input is not defined in the false branch:
  File "<ipython-input-18-ac8b9713c789>", line 17
    def forward(self, x):
        if self.use_skip:
        ~~~~~~~~~~~~~~~~~
            x_input = x
            ~~~~~~~~~~~ <--- HERE
        x = self.layer(x)
        if self.use_skip:
and was used here:
  File "<ipython-input-18-ac8b9713c789>", line 21
        x = self.layer(x)
        if self.use_skip:
            x = x + x_input
                    ~~~~~~~ <--- HERE
        return x

So in this case TorchScript doesn't understand that both conditions will always be true, and that x_input is therefore always well defined.

To avoid the issue, I could split the class into two separate classes, as in:

class SadModuleNoSkip(nn.Module):
    """Takes a (*, 2) input and runs it thorugh a linear layer. Can optionally
    use a skip connection. The usage of the skip connection or not is an
    architectural choice.

    """
    def __init__(self):
        nn.Module.__init__(self)
        self.layer = nn.Linear(2, 2)

    def forward(self, x):
        x = self.layer(x)
        return x

class SadModuleSkip(nn.Module):
    """Takes a (*, 2) input and runs it thorugh a linear layer. Can optionally
    use a skip connection. The usage of the skip connection or not is an
    architectural choice.

    """
    def __init__(self):
        nn.Module.__init__(self)
        self.layer = nn.Linear(2, 2)

    def forward(self, x):
        x_input = x
        x = self.layer(x)
        x = x + x_input
        return x
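
As a quick sanity check, both of these should script cleanly, since neither forward contains a conditional:

scripted_no_skip = torch.jit.script(SadModuleNoSkip())
scripted_skip = torch.jit.script(SadModuleSkip())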

However, I am working on a huge code base and I would have to repeat the process for many classes, which is time consuming and could introduce bugs. Moreover, the modules I'm working on are often huge convolutional nets, and the ifs just control the presence of an additional batch normalization. It seems undesirable to have two classes that are identical in 99% of their blocks, save for a single batch norm layer.
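
To give an idea of the pattern (a rough, hypothetical sketch, not code from the actual code base), the blocks look more or less like this:

class ConvBnBlock(nn.Module):
    """Hypothetical block: a convolution followed by an optional batch norm,
    chosen at construction time, analogous to SadModule above."""
    def __init__(self, channels: int, use_bn: bool):
        nn.Module.__init__(self)
        self.use_bn = use_bn
        self.conv = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
        if use_bn:
            self.bn = nn.BatchNorm2d(channels)

    def forward(self, x):
        x = self.conv(x)
        if self.use_bn:
            # same kind of data-independent branch as in SadModule
            x = self.bn(x)
        return x

Duplicating such a class just to drop the batch norm would produce exactly the kind of 99%-identical copies described above.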

Is there a way in which I can help TorchScript with its handling of branches?

Edit: added a minimum viable example.

Update: it doesn't work even if I type-hint use_skip as a constant:

from typing import Final

class SadModule(nn.Module):
    use_skip: Final[bool]
    ...

1 Answer

I've opened an issue on GitHub. The project maintainers explained that using Final is the way to go. Be careful though, because as of today (May 7, 2021) this feature is still in development (albeit in its final stages; see here for the feature tracker).

Even though it's not yet available in the official releases, it is present in the nightly versions of PyTorch, so you can either install the pytorch-nightly builds as explained on the website (scroll down to Install PyTorch, then choose Preview (Nightly)), or wait for the next release.

For anybody reading this answer a few months from now, this feature should already be integrated into the main releases of PyTorch.
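
For reference, applying this to the toy example from the question looks roughly like the following (a sketch, assuming a PyTorch build new enough to honour the Final annotation, i.e. a nightly at the time of writing):

from typing import Final

import torch
from torch import nn


class SadModule(nn.Module):
    """Same module as in the question, but with use_skip declared as a
    module constant."""

    use_skip: Final[bool]

    def __init__(self, use_skip: bool):
        nn.Module.__init__(self)
        self.use_skip = use_skip
        self.layer = nn.Linear(2, 2)

    def forward(self, x):
        if self.use_skip:
            x_input = x
        x = self.layer(x)
        if self.use_skip:
            x = x + x_input
        return x


# Because use_skip is a constant, TorchScript can resolve the ifs at compile
# time and drop the dead branch, so both variants should script without the
# "x_input is not defined in the false branch" error.
scripted_mod1 = torch.jit.script(SadModule(False))
scripted_mod2 = torch.jit.script(SadModule(True))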