How to efficiently apply different MLPs to different areas of my input tensor in PyTorch


Given an input tensor of size input(8, 10), I have three MLPs (1, 2, 3) that all have an input size of 10. Furthermore, I have an index tensor mlp_index(8) which determines the MLP I want to apply to a certain row of my input. For example, if mlp_index[0] = 2, then the third MLP (mlp3) should be applied to input[0]. I wrote a minimal example to showcase the problem, along with a few attempts at dealing with it efficiently. However, as you can see, applying just one MLP to the whole input is still significantly faster.

Question: Is there a more efficient way of dealing with that problem?

import torch
import torch.nn as nn
import torch.nn.functional as F
import timeit

torch.manual_seed(42)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

class MLP0(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super(MLP0, self).__init__()
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.fc2 = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x

# arbitrary parameters
input_size = 10
hidden_size = 20
output_size = 5

mlp1 = MLP0(input_size, hidden_size, output_size)
mlp1.to(device)
mlp2 = MLP0(input_size, hidden_size, output_size)
mlp2.to(device)
mlp3 = MLP0(input_size, hidden_size, output_size)
mlp3.to(device)

input_data = torch.rand(size=(8, input_size), device=device)
mlp_index = torch.tensor([0, 1, 0, 1, 0, 2, 0, 2], device=device).unsqueeze(1)


def baseline():
    result = mlp1(input_data)
    return result


def first_update():
    out_1 = mlp1(input_data)
    out_2 = mlp2(input_data)
    out_3 = mlp3(input_data)
    result = torch.where(mlp_index == 0, out_1, torch.where((mlp_index != 0) & (mlp_index != 2), out_2, out_3))
    return result


def second_update():
    result = torch.where(
        mlp_index == 0,
        mlp1(input_data),
        torch.where((mlp_index != 0) & (mlp_index != 2), mlp2(input_data), mlp3(input_data)),
    )
    return result


def third_update():
    # make batches per model that can be executed at once
    input_1 = input_data[(mlp_index == 0)[:, 0]]
    input_2 = input_data[(mlp_index == 1)[:, 0]]
    input_3 = input_data[(mlp_index == 2)[:, 0]]

    # execute batches
    out_1 = mlp1(input_1)
    out_2 = mlp2(input_2)
    out_3 = mlp3(input_3)

    out = torch.zeros(size=(8, output_size), device=device)
    out[(mlp_index == 0)[:, 0]] = out_1
    out[(mlp_index == 1)[:, 0]] = out_2
    out[(mlp_index == 2)[:, 0]] = out_3
    return out

    

baseline_time = timeit.timeit(baseline, number=20000)
print(f"Execution time: {baseline_time} seconds")

first_update_time = timeit.timeit(first_update, number=20000)
print(f"Execution time: {first_update_time} seconds")

second_update_time = timeit.timeit(second_update, number=20000)
print(f"Execution time: {second_update_time} seconds")

third_update_time = timeit.timeit(third_update, number=20000)
print(f"Execution time: {third_update_time} seconds")

Output:

Execution time: 1.5391472298651934 seconds
Execution time: 2.1761511098593473 seconds
Execution time: 2.233237884938717 seconds
Execution time: 6.252682875841856 seconds
There are 2 answers below.

Answer from Klops:

This is a fun question. When playing around with your code, I found that the execution time of the network itself is minimal compared to any reshaping or data mangling (which explains your strange observations).

Given that your "real" problem has a much larger sample size (so the time spent reshaping becomes negligible compared to the execution time), I suggest this third way:

input_data = torch.rand(size=(8, input_size), device=device)
mlp_index = torch.tensor([2, 1, 0, 1, 0, 2, 0, 2], device=device)

# make batches per model that can be executed at once
# (mlp_index is 1-D here, so it works as a boolean mask directly)
input_1 = input_data[mlp_index == 0]
input_2 = input_data[mlp_index == 1]
input_3 = input_data[mlp_index == 2]

# execute batches
out_1 = mlp1(input_1)
out_2 = mlp2(input_2)
out_3 = mlp3(input_3)

# stitch the results back together in the original row order
res = torch.empty((8, output_size), device=device)
res[mlp_index == 0] = out_1
res[mlp_index == 1] = out_2
res[mlp_index == 2] = out_3

If your network is really that small (and your samples are really so few), indexing your data takes longer than feeding it through the network...
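You can check this directly by timing the boolean indexing on its own against a single forward pass through a small linear layer. A minimal CPU sketch (not part of the benchmark above; exact numbers will vary with hardware):

import torch
import timeit

x = torch.rand(8, 10)
mask = torch.tensor([True, False, True, False, True, False, True, False])
lin = torch.nn.Linear(10, 20)

# time the boolean indexing alone (the "data mangling") ...
t_index = timeit.timeit(lambda: x[mask], number=20000)
# ... against one forward pass through a small linear layer
t_forward = timeit.timeit(lambda: lin(x), number=20000)
print(f"indexing: {t_index:.3f}s, forward: {t_forward:.3f}s")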

Edit: analysis of depth and speed

Note that I introduced a depth parameter.

import torch
import torch.nn as nn
import torch.nn.functional as F
import timeit

torch.manual_seed(42)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


class MLP0(nn.Module):
    def __init__(self, input_size, hidden_size, output_size, depth):
        super(MLP0, self).__init__()
        self.depth = depth
        self.first = nn.Linear(input_size, hidden_size)
        for i in range(2, depth - 1):
            # generate (depth - 3) hidden layers
            setattr(self, f"fc{i}", nn.Linear(hidden_size, hidden_size))
        self.last = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        x = F.relu(self.first(x))
        for i in range(2, self.depth - 1):
            # execute the hidden layers
            x = F.relu(getattr(self, f"fc{i}")(x))
        x = self.last(x)
        return x


# arbitrary parameters
input_size = 10
hidden_size = 20
output_size = 5
depth = 200

mlp1 = MLP0(input_size, hidden_size, output_size, depth)
mlp1.to(device)
mlp2 = MLP0(input_size, hidden_size, output_size, depth)
mlp2.to(device)
mlp3 = MLP0(input_size, hidden_size, output_size, depth)
mlp3.to(device)

input_data = torch.rand(size=(8, input_size), device=device)
mlp_index = torch.tensor([0, 1, 0, 1, 0, 2, 0, 2], device=device).unsqueeze(1)

def baseline():
    """no fair comparison as this only executes model one. Sole purpose as a baseline as this cannot be topped."""
    result = mlp1.forward(input_data)
    return result


def first_update():
    out_1 = mlp1(input_data)
    out_2 = mlp2(input_data)
    out_3 = mlp3(input_data)
    result = torch.where(
        mlp_index == 0,
        out_1,
        torch.where((mlp_index != 0) & (mlp_index != 2), out_2, out_3),
    )
    return result


def second_update():
    result = torch.where(
        mlp_index == 0,
        mlp1(input_data),
        torch.where(
            (mlp_index != 0) & (mlp_index != 2), mlp2(input_data), mlp3(input_data)
        ),
    )
    return result


def third_update():
    # make batches per model that can be executed at once
    # execute batches
    out_1 = mlp1(input_data[(mlp_index == 0)[:, 0]])
    out_2 = mlp2(input_data[(mlp_index == 1)[:, 0]])
    out_3 = mlp3(input_data[(mlp_index == 2)[:, 0]])

    out = torch.zeros(size=(8, output_size), device=device)
    out[(mlp_index == 0)[:, 0]] = out_1
    out[(mlp_index == 1)[:, 0]] = out_2
    out[(mlp_index == 2)[:, 0]] = out_3
    return out


baseline_time = timeit.timeit(baseline, number=200)
print(f"Execution time: {baseline_time} seconds")

first_update_time = timeit.timeit(first_update, number=200)
print(f"Execution time: {first_update_time} seconds")

second_update_time = timeit.timeit(second_update, number=200)
print(f"Execution time: {second_update_time} seconds")

third_update_time = timeit.timeit(third_update, number=200)
print(f"Execution time: {third_update_time} seconds")

The result is:

Execution time: 1.3142005730001074 seconds
Execution time: 4.0445914549995905 seconds
Execution time: 3.7615038879998792 seconds
Execution time: 3.5695767729998806 seconds

Fully connected layers have a massive number of parameters (compared to CNNs), but are very fast to compute: the forward pass through such a layer is simply a matrix multiplication. At the scales considered here, the size of that matrix barely influences how long the multiplication takes. That is why your attempts to increase the computational effort did not succeed; you only increased the size of the hidden weight matrix. What does take time, however, is waiting for one multiplication to finish before the next can start. That is why I introduced the depth parameter: as the depth increases, pre-selecting the relevant rows by indexing becomes computationally cheaper, as you can see in my example.
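To see this effect in isolation, you can time one wider multiplication against a chain of narrow ones. A rough illustrative sketch (not part of the benchmark above):

import torch
import timeit

x = torch.rand(8, 20)
w_narrow = torch.rand(20, 20)
w_wide = torch.rand(20, 512)

# at these tiny sizes a wider multiplication often costs about the same,
# since per-call overhead dominates the actual arithmetic ...
t_narrow = timeit.timeit(lambda: x @ w_narrow, number=10000)
t_wide = timeit.timeit(lambda: x @ w_wide, number=10000)

# ... but 50 chained multiplications must run one after another
def chain():
    y = x
    for _ in range(50):
        y = y @ w_narrow
    return y

t_chain = timeit.timeit(chain, number=10000)
print(f"narrow: {t_narrow:.3f}s, wide: {t_wide:.3f}s, chain: {t_chain:.3f}s")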

Answer from Karl:

Your input is of shape (8, 10). Since you're putting this directly into the MLP, I'm assuming the first dimension is the batch dimension and you won't be adding another batch dimension on top of the first two. If you do, this solution still works; it just needs some axis munging, as sketched below.
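For instance, if your input gained an extra leading batch dimension, you could flatten the batch axes (and the matching index tensor) before the call and restore the shape afterwards. A hypothetical sketch, where model stands for the MultiMLP instance built further below:

import torch

# hypothetical shapes: a (B1, B2, input_size) input and a matching (B1, B2) index
B1, B2, input_size, output_size = 4, 8, 10, 5
x = torch.rand(B1, B2, input_size)
idx = torch.randint(0, 3, (B1, B2))

# flatten the batch axes, apply the per-row model, then restore the shape
flat_out = model(x.reshape(B1 * B2, input_size), idx.reshape(B1 * B2))
out = flat_out.reshape(B1, B2, output_size)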

You can pack all the layer weights into a single tensor, index into them, then batch matmul.
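Stripped of the module boilerplate, the core trick looks roughly like this (a sketch with made-up shapes):

import torch

n_replicas, in_features, out_features = 3, 10, 20
weights = torch.rand(n_replicas, in_features, out_features)
biases = torch.rand(n_replicas, out_features)
x = torch.rand(8, in_features)
idx = torch.tensor([0, 1, 0, 1, 0, 2, 0, 2])

# gather one (in_features, out_features) matrix per row ...
w = weights[idx]  # (8, in_features, out_features)
b = biases[idx]   # (8, out_features)

# ... then multiply every row by its own matrix in one batched matmul
y = torch.bmm(x[:, None, :], w).squeeze(1) + b  # (8, out_features)

The full version below wraps this in an nn.Module with proper initialization: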

import torch
import torch.nn as nn
import torch.nn.functional as F
import timeit
import math

torch.manual_seed(42)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

class MultiLinear(nn.Module):
    def __init__(self, in_features, out_features, n_replicas, bias=True, device=None, dtype=None):
        # this is mostly copied from pytorch nn.Linear
        factory_kwargs = {'device': device, 'dtype': dtype}
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        
        # important note: pytorch nn.Linear weights are of shape (out_features, in_features) 
        # for weird transpose reasons. This variant has in_features first
        self.weight = nn.Parameter(torch.empty((n_replicas, in_features, out_features), **factory_kwargs))
        
        if bias:
            self.bias = nn.Parameter(torch.empty(n_replicas, out_features, **factory_kwargs))
        else:
            self.register_parameter('bias', None)
            
        self.reset_parameters()
        
    def reset_parameters(self) -> None:
        # this is also copied from pytorch nn.Linear
        nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
        if self.bias is not None:
            fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight)
            bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
            nn.init.uniform_(self.bias, -bound, bound)

    def forward(self, input, index_tensor):
        # input is a 2d float tensor
        # index_tensor is a 1d integer tensor selecting one replica per row

        weight = self.weight[index_tensor]
        bias = self.bias[index_tensor]

        # batched matmul: each row is multiplied by the weights of its own replica
        output = torch.bmm(input[:, None, :], weight).squeeze(1) + bias
        return output
    
class MultiMLP(nn.Module):
    def __init__(self, input_size, hidden_size, output_size, n_replicas):
        super().__init__()
        self.fc1 = MultiLinear(input_size, hidden_size, n_replicas)
        self.fc2 = MultiLinear(hidden_size, output_size, n_replicas)
        
    def forward(self, x, index_tensor):
        x = F.relu(self.fc1(x, index_tensor))
        x = self.fc2(x, index_tensor)
        return x
    
# arbitrary parameters
input_size = 10
hidden_size = 20
output_size = 5
n_replicas = 3

input_data = torch.rand(size=(8, input_size), device=device)
mlp_index = torch.tensor([0, 1, 0, 1, 0, 2, 0, 2], device=device) # note this version skips the extra axis you add

model = MultiMLP(input_size, hidden_size, output_size, n_replicas)
model.to(device)

# note - the first use of the model causes memory allocation on the GPU that is
# much slower than subsequent calls; this impacts your previous benchmarks.
# calling the model once outside the timed test fixes this
result = model(input_data, mlp_index)

def time_test():
    result = model(input_data, mlp_index)
    return result

update_time = timeit.timeit(time_test, number=20000)
print(f"Execution time: {update_time} seconds")

On my GPU, the test times are:

  • Baseline: 1.01s
  • First Update: 4.30s
  • Second Update: 4.46s
  • Third Update: 10.1s
  • My Version: 2.85s

With all that said, you also need some way of making the different weights learn different things. If the outputs of the various models all go into the same loss, there's no reason to use an approach like this over a standard MLP.
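Note that the indexing itself already routes gradients to the right place: backprop through self.weight[index_tensor] only produces nonzero gradients for the replicas that rows actually selected. Whether the replicas then learn genuinely different things still depends on the row-to-replica assignment being meaningful. A quick sketch to verify the gradient routing, reusing the model and tensors defined above:

# confirm that gradients only reach the selected replicas
model.zero_grad()
out = model(input_data, mlp_index)
out.sum().backward()

# per-replica gradient mass; a replica selected by no row stays at zero
print(model.fc1.weight.grad.abs().sum(dim=(1, 2)))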