I'm trying to export a PyTorch model to TorchScript via scripting and I am stuck. I've created a toy class to showcase the issue:
import torch
from torch import nn


class SadModule(nn.Module):
    """Takes a (*, 2) input and runs it through a linear layer. Can optionally
    use a skip connection. The usage of the skip connection or not is an
    architectural choice.
    """

    def __init__(self, use_skip: bool):
        nn.Module.__init__(self)
        self.use_skip = use_skip
        self.layer = nn.Linear(2, 2)

    def forward(self, x):
        if self.use_skip:
            x_input = x
        x = self.layer(x)
        if self.use_skip:
            x = x + x_input
        return x
It basically consists of only a linear layer and an optional skip connection. If I try to script the model using
mod1 = SadModule(False)
scripted_mod1 = torch.jit.script(mod1)
I get the following error:
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
<ipython-input-10-a7ebc7af32c7> in <module>
----> 1 scripted_mod1 = torch.jit.script(mod)
~/Software/miniconda3/envs/pytorch3d/lib/python3.8/site-packages/torch/jit/_script.py in script(obj, optimize, _frames_up, _rcb)
895
896 if isinstance(obj, torch.nn.Module):
--> 897 return torch.jit._recursive.create_script_module(
898 obj, torch.jit._recursive.infer_methods_to_compile
899 )
~/Software/miniconda3/envs/pytorch3d/lib/python3.8/site-packages/torch/jit/_recursive.py in create_script_module(nn_module, stubs_fn, share_types)
350 check_module_initialized(nn_module)
351 concrete_type = get_module_concrete_type(nn_module, share_types)
--> 352 return create_script_module_impl(nn_module, concrete_type, stubs_fn)
353
354 def create_script_module_impl(nn_module, concrete_type, stubs_fn):
~/Software/miniconda3/envs/pytorch3d/lib/python3.8/site-packages/torch/jit/_recursive.py in create_script_module_impl(nn_module, concrete_type, stubs_fn)
408 # Compile methods if necessary
409 if concrete_type not in concrete_type_store.methods_compiled:
--> 410 create_methods_and_properties_from_stubs(concrete_type, method_stubs, property_stubs)
411 torch._C._run_emit_module_hook(cpp_module)
412 concrete_type_store.methods_compiled.add(concrete_type)
~/Software/miniconda3/envs/pytorch3d/lib/python3.8/site-packages/torch/jit/_recursive.py in create_methods_and_properties_from_stubs(concrete_type, method_stubs, property_stubs)
302 property_rcbs = [p.resolution_callback for p in property_stubs]
303
--> 304 concrete_type._create_methods_and_properties(property_defs, property_rcbs, method_defs, method_rcbs, method_defaults)
305
306
RuntimeError:
x_input is not defined in the false branch:
File "<ipython-input-7-d08ed7ff42ec>", line 12
def forward(self, x):
if self.use_skip:
~~~~~~~~~~~~~~~~~
x_input = x
~~~~~~~~~~~ <--- HERE
x = self.layer(x)
if self.use_skip:
and was used here:
File "<ipython-input-7-d08ed7ff42ec>", line 16
x = self.layer(x)
if self.use_skip:
x = x + x_input
~~~~~~~ <--- HERE
return x
So, basically, TorchScript isn't able to recognise that for mod1 the True branch of either if statement won't ever be used. Moreover, if we create an instance that actually uses the skip connection,
mod2 = SadModule(True)
scripted_mod2 = torch.jit.script(mod2)
we will get another error:
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
<ipython-input-21-b5ca61d8aa73> in <module>
----> 1 scripted_mod2 = torch.jit.script(mod2)
~/Software/miniconda3/envs/pytorch3d/lib/python3.8/site-packages/torch/jit/_script.py in script(obj, optimize, _frames_up, _rcb)
895
896 if isinstance(obj, torch.nn.Module):
--> 897 return torch.jit._recursive.create_script_module(
898 obj, torch.jit._recursive.infer_methods_to_compile
899 )
~/Software/miniconda3/envs/pytorch3d/lib/python3.8/site-packages/torch/jit/_recursive.py in create_script_module(nn_module, stubs_fn, share_types)
350 check_module_initialized(nn_module)
351 concrete_type = get_module_concrete_type(nn_module, share_types)
--> 352 return create_script_module_impl(nn_module, concrete_type, stubs_fn)
353
354 def create_script_module_impl(nn_module, concrete_type, stubs_fn):
~/Software/miniconda3/envs/pytorch3d/lib/python3.8/site-packages/torch/jit/_recursive.py in create_script_module_impl(nn_module, concrete_type, stubs_fn)
408 # Compile methods if necessary
409 if concrete_type not in concrete_type_store.methods_compiled:
--> 410 create_methods_and_properties_from_stubs(concrete_type, method_stubs, property_stubs)
411 torch._C._run_emit_module_hook(cpp_module)
412 concrete_type_store.methods_compiled.add(concrete_type)
~/Software/miniconda3/envs/pytorch3d/lib/python3.8/site-packages/torch/jit/_recursive.py in create_methods_and_properties_from_stubs(concrete_type, method_stubs, property_stubs)
302 property_rcbs = [p.resolution_callback for p in property_stubs]
303
--> 304 concrete_type._create_methods_and_properties(property_defs, property_rcbs, method_defs, method_rcbs, method_defaults)
305
306
RuntimeError:
x_input is not defined in the false branch:
File "<ipython-input-18-ac8b9713c789>", line 17
def forward(self, x):
if self.use_skip:
~~~~~~~~~~~~~~~~~
x_input = x
~~~~~~~~~~~ <--- HERE
x = self.layer(x)
if self.use_skip:
and was used here:
File "<ipython-input-18-ac8b9713c789>", line 21
x = self.layer(x)
if self.use_skip:
x = x + x_input
~~~~~~~ <--- HERE
return x
So in this case TorchScript doesn't understand that both ifs will always be true and that x_input is in fact well defined.
To avoid the issue, I could split the class into two subclasses, as in:
class SadModuleNoSkip(nn.Module):
    """Takes a (*, 2) input and runs it through a linear layer, without a
    skip connection.
    """

    def __init__(self):
        nn.Module.__init__(self)
        self.layer = nn.Linear(2, 2)

    def forward(self, x):
        x = self.layer(x)
        return x


class SadModuleSkip(nn.Module):
    """Takes a (*, 2) input and runs it through a linear layer, then adds the
    input back through a skip connection.
    """

    def __init__(self):
        nn.Module.__init__(self)
        self.layer = nn.Linear(2, 2)

    def forward(self, x):
        x_input = x
        x = self.layer(x)
        x = x + x_input
        return x
However, I am working on a huge code base and I would have to repeat the process for many classes, which is time-consuming and could introduce bugs. Moreover, the modules I'm working on are often huge convolutional nets, and the ifs just control the presence of an additional batch normalization. It seems undesirable to me to have two classes that are identical in 99% of their blocks, save for a single batch norm layer.
Is there a way in which I can help TorchScript with its handling of branches?
Edit: added a minimum viable example.
Update: it doesn't work even if I type hint use_skip as a constant:
from typing import Final

class SadModule(nn.Module):
    use_skip: Final[bool]
    ...
I've opened an issue on GitHub. The project maintainers explained that using Final is the way to go. Be careful though, because as of today (May 7, 2021) this feature is still in development (albeit in its final stages, see here for the feature tracker).
Even though it's not yet available in the official releases, it is present in the nightly versions of PyTorch, so you can either install the pytorch-nightly builds as explained on the website (scroll down to Install PyTorch, then choose Preview (Nightly)), or wait for the next release.
For anybody reading this answer a few months from now, this feature should already be integrated in the main releases of PyTorch.
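For completeness, here is a minimal sketch of what the Final approach might look like on the toy class from the question. It assumes a PyTorch build in which TorchScript treats a Final[bool] attribute as a compile-time constant inside an if condition (the nightly builds at the time of writing, or a later release); on older builds the same code is expected to fail with the error shown above. The class name SkipModule and the variable names are only illustrative.

```python
from typing import Final

import torch
from torch import nn


class SkipModule(nn.Module):
    # Assumption: the Final annotation tells TorchScript that use_skip is a
    # constant, so only the branch that is actually taken gets compiled.
    use_skip: Final[bool]

    def __init__(self, use_skip: bool):
        super().__init__()
        self.use_skip = use_skip
        self.layer = nn.Linear(2, 2)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.use_skip:
            x_input = x
        x = self.layer(x)
        if self.use_skip:
            x = x + x_input
        return x


# Both variants are scripted from the same class definition.
eager_no_skip = SkipModule(False)
eager_skip = SkipModule(True)
scripted_no_skip = torch.jit.script(eager_no_skip)
scripted_skip = torch.jit.script(eager_skip)

# Sanity check: the scripted modules should agree with the eager ones.
x = torch.randn(4, 2)
print(torch.allclose(scripted_no_skip(x), eager_no_skip(x)))
print(torch.allclose(scripted_skip(x), eager_skip(x)))
```

The point of this design is that the forward method stays exactly as in the question; only the class-level annotation is added, so a large code base would not need to be split into skip and no-skip subclasses.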