How to build a forward pass of a network in PyTorch based on a string expression?


I have a string expression: self.w0*torch.sin(x)+self.w1*torch.exp(x). How can I use this expression as the forward pass of a model in PyTorch? The class for instantiating the model is as follows:

class MyModule(nn.Module):
    def __init__(self,vector):
        super().__init__()
        self.s='self.w0*torch.sin(x)+self.w1*torch.exp(x)'

        w0=0.01*torch.rand(1,dtype=torch.float,requires_grad=True)
        self.w0 = nn.Parameter(w0)

        w1=0.01*torch.rand(1,dtype=torch.float,requires_grad=True)
        self.w1 = nn.Parameter(w1)

    def forward(self,x):
        return ????

For the string expression self.w0*torch.sin(x)+self.w1*torch.exp(x), the architecture of the model is as follows:

[Figure: the architecture of the model]

I have tried the following method as the forward pass:

def forward(self,x):
    return eval(self.s)

Is this the best way to implement the forward pass? Note that the string expression may vary, so I don't want to hard-code a fixed forward pass like:

def forward(self,x):
    return self.w0*torch.sin(x)+self.w1*torch.exp(x)
Best answer, by inverted_index:

I do not recommend using eval directly, for the following reasons:

  • Security: eval can execute any arbitrary code, which is a potential security risk, especially with untrusted input.
  • Performance: eval can be slower as it needs to parse and interpret the string each time it is called.
  • Debugging and Maintenance: Code that uses eval is often harder to understand, debug, and maintain.
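If you decide to keep eval anyway, the performance point can at least be addressed by compiling the string once in __init__ and evaluating the resulting code object in forward. The sketch below assumes the expression only refers to self, torch and x; the class name EvalModule is hypothetical. Passing an empty __builtins__ narrows what the string can reach, but it is not a security sandbox, so the string must still come from a trusted source.

```python
import torch
import torch.nn as nn


class EvalModule(nn.Module):
    """Hypothetical variant of MyModule that compiles the expression only once."""

    def __init__(self, expression: str):
        super().__init__()
        # Compile once to a code object so forward() does not re-parse the string.
        self.code = compile(expression, "<expression>", "eval")
        self.w0 = nn.Parameter(0.01 * torch.rand(1))
        self.w1 = nn.Parameter(0.01 * torch.rand(1))

    def forward(self, x):
        # Expose only the names the expression needs. An empty __builtins__
        # reduces, but does NOT eliminate, what a malicious string could do.
        namespace = {"__builtins__": {}, "self": self, "torch": torch, "x": x}
        return eval(self.code, namespace)
```

Autograd works as usual here, because eval simply runs the same tensor operations a hand-written forward would.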

However, if the forward pass really has to follow an expression that can change, you can use a safer alternative to eval: parse the string yourself and build the computation from PyTorch's built-in operations. Python's built-in getattr lets you look up the weights on the module and the functions in the torch namespace by name. Here is an example that handles expressions made of weight*torch_function(x) terms joined by +:

import torch
import torch.nn as nn

class MyModule(nn.Module):
    def __init__(self, vector):
        super().__init__()
        self.s = 'self.w0*torch.sin(x)+self.w1*torch.exp(x)'

        w0 = 0.01 * torch.rand(1, dtype=torch.float, requires_grad=True)
        self.w0 = nn.Parameter(w0)

        w1 = 0.01 * torch.rand(1, dtype=torch.float, requires_grad=True)
        self.w1 = nn.Parameter(w1)

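    def parse_expression(self, x, expression):
        # Supports only terms of the form "self.<weight>*torch.<func>(x)" joined by "+"
        result = 0.0
        for term in expression.split('+'):
            weight_part, call_part = term.split('*', 1)
            # "self.w0" -> "w0": getattr needs the bare attribute name
            weight = getattr(self, weight_part.strip().split('.')[-1])
            # "torch.sin(x)" -> "sin": look the function up in the torch namespace
            operation = call_part.split('(')[0].strip().split('.')[-1]
            operation_func = getattr(torch, operation)
            result = result + weight * operation_func(x)
        return result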

    def forward(self, x):
        return self.parse_expression(x, self.s)
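A more general option, not part of the answer above, is to let Python's ast module parse the expression once and then evaluate only an allowlisted set of node types (+, -, *, /, unary minus, numeric constants, x, self.<weight>, and a few torch functions). ast.parse only builds a syntax tree and never executes the string. This is a minimal sketch under stated assumptions: the names ExprModule, ALLOWED_FUNCS and BIN_OPS are hypothetical, and every self.<name> in the expression is assumed to be a learnable scalar weight.

```python
import ast
import operator

import torch
import torch.nn as nn

# Hypothetical allowlist of torch functions an expression may call.
ALLOWED_FUNCS = {"sin": torch.sin, "cos": torch.cos, "exp": torch.exp,
                 "log": torch.log, "tanh": torch.tanh}
# Arithmetic operators the evaluator accepts.
BIN_OPS = {ast.Add: operator.add, ast.Sub: operator.sub,
           ast.Mult: operator.mul, ast.Div: operator.truediv}


class ExprModule(nn.Module):
    """Evaluates an allowlisted arithmetic expression of x and learnable weights."""

    def __init__(self, expression: str):
        super().__init__()
        # Parse once; this builds a syntax tree and does not run any code.
        self.tree = ast.parse(expression, mode="eval").body
        # Create one learnable scalar per `self.<name>` found in the expression.
        names = sorted({n.attr for n in ast.walk(self.tree)
                        if isinstance(n, ast.Attribute)
                        and isinstance(n.value, ast.Name) and n.value.id == "self"})
        self.weights = nn.ParameterDict(
            {name: nn.Parameter(0.01 * torch.rand(1)) for name in names})

    def _eval(self, node, x):
        if isinstance(node, ast.BinOp) and type(node.op) in BIN_OPS:
            left, right = self._eval(node.left, x), self._eval(node.right, x)
            return BIN_OPS[type(node.op)](left, right)
        if isinstance(node, ast.UnaryOp) and isinstance(node.op, ast.USub):
            return -self._eval(node.operand, x)
        if isinstance(node, ast.Constant) and isinstance(node.value, (int, float)):
            return node.value
        if isinstance(node, ast.Name) and node.id == "x":
            return x
        if (isinstance(node, ast.Attribute) and isinstance(node.value, ast.Name)
                and node.value.id == "self"):
            return self.weights[node.attr]
        if (isinstance(node, ast.Call) and isinstance(node.func, ast.Attribute)
                and isinstance(node.func.value, ast.Name)
                and node.func.value.id == "torch"
                and node.func.attr in ALLOWED_FUNCS
                and len(node.args) == 1 and not node.keywords):
            return ALLOWED_FUNCS[node.func.attr](self._eval(node.args[0], x))
        # Anything else (imports, other attributes, unknown calls) is rejected.
        raise ValueError(f"Unsupported syntax: {ast.dump(node)}")

    def forward(self, x):
        return self._eval(self.tree, x)
```

For example, ExprModule("self.w0*torch.sin(x)+self.w1*torch.exp(x)") registers w0 and w1 under m.weights, so they show up in m.parameters() for the optimizer, while a string such as "__import__('os').getcwd()" raises ValueError when forward runs instead of executing.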