Given an input tensor of size input(8, 10), I have three MLPs (1, 2, 3) that all have an input size of 10. Furthermore, I have an index tensor mlp_index(8) which determines the MLP I want to apply to a certain row of my input. For example, if mlp_index[0] = 2, then the third MLP (mlp3, since the indices are zero-based) should be applied to input[0]. I wrote a minimal example to showcase the problem, along with three different ways of dealing with it efficiently. However, as you can see, applying just one MLP to the whole input is still significantly faster.
Question: Is there a more efficient way of dealing with that problem?
import torch
import torch.nn as nn
import torch.nn.functional as F
import timeit
torch.manual_seed(42)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
class MLP0(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super(MLP0, self).__init__()
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.fc2 = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x
# arbitrary parameters
input_size = 10
hidden_size = 20
output_size = 5
mlp1 = MLP0(input_size, hidden_size, output_size)
mlp1.to(device)
mlp2 = MLP0(input_size, hidden_size, output_size)
mlp2.to(device)
mlp3 = MLP0(input_size, hidden_size, output_size)
mlp3.to(device)
input_data = torch.rand(size=(8, input_size), device=device)
mlp_index = torch.tensor([0, 1, 0, 1, 0, 2, 0, 2], device=device).unsqueeze(1)  # shape (8, 1): model index per row
def baseline():
    # apply a single MLP to every row
    result = mlp1(input_data)
    return result

def first_update():
    # run every MLP on the full input, then pick the wanted rows afterwards
    out_1 = mlp1(input_data)
    out_2 = mlp2(input_data)
    out_3 = mlp3(input_data)
    result = torch.where(mlp_index == 0, out_1,
                         torch.where(mlp_index == 1, out_2, out_3))
    return result

def second_update():
    # same as first_update, but without intermediate variables
    result = torch.where(mlp_index == 0, mlp1(input_data),
                         torch.where(mlp_index == 1, mlp2(input_data), mlp3(input_data)))
    return result

def third_update():
    # make batches per model that can be executed at once
    input_1 = input_data[(mlp_index == 0)[:, 0]]
    input_2 = input_data[(mlp_index == 1)[:, 0]]
    input_3 = input_data[(mlp_index == 2)[:, 0]]
    # execute batches
    out_1 = mlp1(input_1)
    out_2 = mlp2(input_2)
    out_3 = mlp3(input_3)
    # scatter the per-model outputs back to their original row positions
    out = torch.zeros(size=(input_data.shape[0], output_size), device=device)
    out[(mlp_index == 0)[:, 0]] = out_1
    out[(mlp_index == 1)[:, 0]] = out_2
    out[(mlp_index == 2)[:, 0]] = out_3
    return out
baseline_time = timeit.timeit(baseline, number=20000)
print(f"Execution time: {baseline_time} seconds")
first_update_time = timeit.timeit(first_update, number=20000)
print(f"Execution time: {first_update_time} seconds")
second_update_time = timeit.timeit(second_update, number=20000)
print(f"Execution time: {second_update_time} seconds")
third_update_time = timeit.timeit(third_update, number=20000)
print(f"Execution time: {third_update_time} seconds")
Output:
Execution time: 1.5391472298651934 seconds
Execution time: 2.1761511098593473 seconds
Execution time: 2.233237884938717 seconds
Execution time: 6.252682875841856 seconds
This is a fun question. When playing around with your code, I found that the execution time of the network itself is minimal compared to any reshaping or data mangling, which explains your puzzling observations.
Given that your "real" problem has a much larger sample size (so the time spent reshaping shrinks relative to the execution time), I suggest this third way:
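A minimal sketch of the idea, generalizing your third_update to a list of models (run_grouped is just an illustrative name, not part of the original code):

def run_grouped(mlps, input_data, mlp_index):
    # one output row per input row, filled in model by model
    out = torch.zeros(input_data.shape[0], output_size, device=device)
    for i, mlp in enumerate(mlps):
        mask = (mlp_index == i)[:, 0]          # boolean mask: rows assigned to model i
        if mask.any():
            out[mask] = mlp(input_data[mask])  # one batched forward pass per model
    return out

print(run_grouped([mlp1, mlp2, mlp3], input_data, mlp_index).shape)  # torch.Size([8, 5])

This scales to any number of models and any batch size, and each model is invoked exactly once per call.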
If your network is really that small (and your samples are really so few), indexing your data takes longer than feeding it through the network...
Edit: analysis of depth and speed
Note that I introduced a depth parameter.
Fully connected layers have a massive number of parameters (compared to CNNs), but they are very fast to compute: the forward pass through such a layer is simply a matrix multiplication. At the scales considered here, the size of that matrix barely influences how long the multiplication takes. That is why your attempts to increase the computational effort did not succeed; you only increased the size of the hidden weight matrix. What does take time, however, is waiting for one multiplication to finish before the next can start, and each extra layer adds one more such sequential multiplication. Therefore I introduced a depth parameter. As the depth increases, pre-selecting the relevant rows by indexing becomes comparatively cheaper, as you can see in my example.
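A minimal sketch of what such a depth experiment could look like, reusing run_grouped from the sketch above (DeepMLP and timed are illustrative names, not the original code; the torch.cuda.synchronize() calls matter because CUDA kernel launches are asynchronous and would otherwise distort the timings):

class DeepMLP(nn.Module):
    def __init__(self, input_size, hidden_size, output_size, depth):
        super().__init__()
        layers = [nn.Linear(input_size, hidden_size), nn.ReLU()]
        for _ in range(depth):  # each extra layer adds one more sequential matmul
            layers += [nn.Linear(hidden_size, hidden_size), nn.ReLU()]
        layers.append(nn.Linear(hidden_size, output_size))
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        return self.net(x)

deep_mlps = [DeepMLP(input_size, hidden_size, output_size, depth=32).to(device) for _ in range(3)]

def run_all_rows():
    # every model processes every row; torch.where keeps only the wanted results
    return torch.where(mlp_index == 0, deep_mlps[0](input_data),
                       torch.where(mlp_index == 1, deep_mlps[1](input_data),
                                   deep_mlps[2](input_data)))

def timed(fn, number=1000):
    if device.type == "cuda":
        torch.cuda.synchronize()  # drain pending kernels before starting the clock
    start = timeit.default_timer()
    for _ in range(number):
        fn()
    if device.type == "cuda":
        torch.cuda.synchronize()  # wait for the queued kernels before stopping the clock
    return timeit.default_timer() - start

print(f"all rows through every model: {timed(run_all_rows):.3f} s")
print(f"indexed rows per model:       {timed(lambda: run_grouped(deep_mlps, input_data, mlp_index)):.3f} s")

With a large depth, the three full forward passes in run_all_rows dominate, so feeding each model only its own rows should come out ahead.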