How to optimize the custom bilinear sampling alternative to grid_sample for TensorRT inference?


I am trying to convert a model that uses torch.nn.functional.grid_sample from PyTorch (1.6) to TensorRT (7) through ONNX (opset 11). Opset 11 does not support grid_sample conversion. The custom alternative I found (https://github.com/pytorch/pytorch/issues/27212) is extremely slow when run in PyTorch, and its main loop has problems converting to TRT.

My own implementation of bilinear sampling (not just grid_sample, but the whole original sampling procedure built around grid_sample) runs much faster in PyTorch and converts to TRT successfully. However, my custom bilinear sampling is slower in TRT than in PyTorch (5.6 ms vs 2.0 ms). It turns out that the PyTorch indexing image[:, ind, y0, x0] produces a Gather layer that takes about 0.97 ms, and there are 4 such layers in the TRT version of the bilinear sampling.
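To make concrete what each of those lookups computes, here is a small sketch (arbitrary sizes, CPU tensors, names chosen for illustration). With the image permuted to [C, N, H, W], image[:, ind, y0, x0] picks, for every output position, the C-vector at (ind, y0, x0) and returns a [C, N, grid_H, grid_W] tensor; it is equivalent to the explicit loop below. The question reports that each such lookup appears as a separate Gather layer in TRT.

```python
import torch

# Small random example; sizes are arbitrary
N, C, H, W = 2, 3, 4, 5
gH, gW = 2, 3
image = torch.rand(N, C, H, W).permute(1, 0, 2, 3)  # [C, N, H, W], as in the question
y0 = torch.randint(0, H, (N, gH, gW))
x0 = torch.randint(0, W, (N, gH, gW))
ind = torch.arange(N).view(N, 1, 1).expand(N, gH, gW)  # batch index per output position

fancy = image[:, ind, y0, x0]  # advanced indexing over three dims at once
print(fancy.shape)  # torch.Size([3, 2, 2, 3])

# Same result written as an explicit loop over every output position
loop = torch.empty(C, N, gH, gW)
for n in range(N):
    for i in range(gH):
        for j in range(gW):
            loop[:, n, i, j] = image[:, int(ind[n, i, j]), int(y0[n, i, j]), int(x0[n, i, j])]

print(torch.equal(fancy, loop))  # True
```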

So the questions are:

  • How should I optimize my PyTorch code to get an efficient TRT model?
  • What can I do to make the Gather layer faster?
  • Would implementing this function as a custom TRT plugin help make it faster?

Here is the code of the bilinear sampling function:

import torch

def bilinear_sample_noloop(image, grid):
    """
    :param image: sampling source of shape [N, C, H, W]
    :param grid: floating-point (x, y) sampling pixel coordinates of shape [N, grid_H, grid_W, 2]
    :return: sampling result of shape [N, C, grid_H, grid_W]
    """
    Nt, C, H, W = image.shape
    grid_H = grid.shape[1]
    grid_W = grid.shape[2]
    xgrid, ygrid = grid.split([1, 1], dim=-1)
    # 1 where all four neighbouring pixels lie inside the image, 0 elsewhere
    mask = ((xgrid >= 0) & (ygrid >= 0) & (xgrid < W - 1) & (ygrid < H - 1)).float()
    x0 = torch.floor(xgrid)
    x1 = x0 + 1
    y0 = torch.floor(ygrid)
    y1 = y0 + 1
    wa = ((x1 - xgrid) * (y1 - ygrid)).permute(3, 0, 1, 2)
    wb = ((x1 - xgrid) * (ygrid - y0)).permute(3, 0, 1, 2)
    wc = ((xgrid - x0) * (y1 - ygrid)).permute(3, 0, 1, 2)
    wd = ((xgrid - x0) * (ygrid - y0)).permute(3, 0, 1, 2)
    x0 = (x0 * mask).view(Nt, grid_H, grid_W).long()
    y0 = (y0 * mask).view(Nt, grid_H, grid_W).long()
    x1 = (x1 * mask).view(Nt, grid_H, grid_W).long()
    y1 = (y1 * mask).view(Nt, grid_H, grid_W).long()
    # Batch index of every output position, shape [N, grid_H, grid_W]
    ind = torch.arange(Nt, device=image.device)
    ind = ind.view(Nt, 1).expand(-1, grid_H).view(Nt, grid_H, 1).expand(-1, -1, grid_W).long()
    image = image.permute(1, 0, 2, 3)
    output_tensor = (image[:, ind, y0, x0] * wa + image[:, ind, y1, x0] * wb + image[:, ind, y0, x1] * wc + \
                 image[:, ind, y1, x1] * wd).permute(1, 0, 2, 3)
    output_tensor *= mask.permute(0, 3, 1, 2).expand(-1, C, -1, -1)
    return output_tensor, mask
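Since the function replaces grid_sample but works on pixel coordinates, one way to sanity-check it is a reference built on F.grid_sample itself, which expects coordinates normalized to [-1, 1]. The sketch below assumes the original model called grid_sample with align_corners=True and padding_mode="zeros" (the question does not say); with align_corners=False the normalization would differ. The helper grid_sample_pixels is hypothetical. It checks one interior point against the same four-corner formula (wa, wb, wc, wd) that bilinear_sample_noloop uses.

```python
import torch
import torch.nn.functional as F

def grid_sample_pixels(image, grid_px):
    """Hypothetical reference: bilinear sampling at pixel coordinates via F.grid_sample.
    image: [N, C, H, W]; grid_px: [N, gH, gW, 2] holding (x, y) pixel coordinates."""
    N, C, H, W = image.shape
    # With align_corners=True, -1 maps to pixel 0 and +1 maps to pixel W-1 (resp. H-1)
    scale = torch.tensor([W - 1, H - 1], dtype=grid_px.dtype, device=grid_px.device)
    grid_norm = 2.0 * grid_px / scale - 1.0
    return F.grid_sample(image, grid_norm, mode="bilinear",
                         padding_mode="zeros", align_corners=True)

# Test image whose value at pixel (y, x) is 5*y + x, so bilinear interpolation is exact
image = torch.arange(20.0).view(1, 1, 4, 5)
grid_px = torch.tensor([[[[1.25, 2.5]]]])  # one point: x = 1.25, y = 2.5
out = grid_sample_pixels(image, grid_px)

# Same point with the four-corner formula used by bilinear_sample_noloop
x, y = 1.25, 2.5
x0, y0 = 1, 2
x1, y1 = x0 + 1, y0 + 1
img = image[0, 0]
manual = ((x1 - x) * (y1 - y) * img[y0, x0] + (x1 - x) * (y - y0) * img[y1, x0]
          + (x - x0) * (y1 - y) * img[y0, x1] + (x - x0) * (y - y0) * img[y1, x1])
print(out.item(), manual.item())  # both approximately 13.75
```

Such a comparison is only meaningful for interior points: the custom function zeroes every point with x >= W - 1 or y >= H - 1 through its mask, while grid_sample with zero padding still blends partially outside points.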

Time profiling parameters:

  • Time profiling experiments were performed on a Dell G3 15 laptop (Core i7-8750H 2.2 GHz x12, 16 GB RAM (2666 MHz), NVIDIA GeForce GTX 1050 Ti).
  • PyTorch environment for profiling: Python 3.7 Anaconda 3 environment, PyTorch 1.6. PyTorch timing is done with time.time(), calling torch.cuda.synchronize() before each timestamp.
  • TRT environment for profiling: Docker container nvcr.io/nvidia/tensorrt:20.06-py3. Profiling was performed with trtexec, and also with custom C++ and Python code; all three give similar results.
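The PyTorch timing method described above (wall-clock timestamps with a CUDA synchronize before each one) can be sketched as follows. The helper name time_fn and the warm-up and iteration counts are illustrative choices, not the asker's actual script; it skips synchronization and runs on CPU when CUDA is unavailable.

```python
import time
import torch

def time_fn(fn, *args, warmup=10, iters=100):
    """Average wall-clock time of fn(*args) in milliseconds (hypothetical helper)."""
    use_cuda = torch.cuda.is_available()
    for _ in range(warmup):  # exclude one-time setup costs from the measurement
        fn(*args)
    if use_cuda:
        torch.cuda.synchronize()  # wait for queued kernels before the first timestamp
    start = time.perf_counter()
    for _ in range(iters):
        fn(*args)
    if use_cuda:
        torch.cuda.synchronize()  # CUDA calls are asynchronous; wait before the last timestamp
    return (time.perf_counter() - start) * 1000.0 / iters

device = "cuda" if torch.cuda.is_available() else "cpu"
x = torch.rand(1, 1, 256, 256, device=device)
t = time_fn(torch.nn.functional.avg_pool2d, x, 2)
print(f"{t:.3f} ms per call")
```

Without the synchronize calls, the timer would mostly measure kernel launch overhead, because CUDA work is queued asynchronously. time.perf_counter is used here instead of time.time because it is a monotonic, higher-resolution clock.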

Part of the trtexec profile of the TRT model:

     Layer   Time (ms)   Avg. Time (ms)   Time %
...
   Mul_146        5.82             0.03      0.5
   Add_147        8.50             0.04      0.7
Gather_148      214.39             0.97     17.3
Gather_174      214.25             0.97     17.3
Gather_201      213.88             0.97     17.3
Gather_228      214.48             0.97     17.3
 Add_237))       25.01             0.11      2.0
   Mul_251        7.84             0.04      0.6
     Total     1238.40             5.60    100.0
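Summing the four Gather rows of the table shows how much of the TRT runtime they account for (plain arithmetic on the reported averages):

```python
# Average times per run (ms) copied from the trtexec table above
gather_avg = [0.97, 0.97, 0.97, 0.97]  # Gather_148, Gather_174, Gather_201, Gather_228
total_avg = 5.60

gather_total = sum(gather_avg)
share = gather_total / total_avg
rest = total_avg - gather_total
print(f"{gather_total:.2f} ms of {total_avg:.2f} ms = {share:.0%}")  # 3.88 ms of 5.60 ms = 69%
print(f"everything else: {rest:.2f} ms")                             # everything else: 1.72 ms
```

This agrees with the Time % column (4 × 17.3% = 69.2%). The remaining layers take about 1.7 ms, which on its own is already below the 2.0 ms PyTorch time.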

Additionally, I tried viewing the image as a linear array over all dimensions except C and creating linear indexes to address elements in the form image[:, p0]. In this case Gather becomes even slower (about 1.07 ms). Then I considered C=1 (as it always is in the original model) and addressed the tensor elements as image[p0]. This time Gather takes about 0.92 ms, which is still too slow.
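Building on the linear-index idea, one variation that may be worth profiling is to fetch all four corners with a single indexing op instead of four, by stacking the four linear index tensors. This is a hypothetical sketch, not something the question tested. The name bilinear_gather_once is made up. It assumes the corner coordinates are already in bounds (x0 + 1 < W, y0 + 1 < H, which the question's mask guarantees for valid points). Whether TensorRT runs one larger Gather faster than four smaller ones is not established here: the single-corner linear-index Gather was measured above at about 1.07 ms. It should be checked both in the exported ONNX graph and with trtexec. The check at the end compares it with the four-lookup form used in bilinear_sample_noloop.

```python
import torch

def bilinear_gather_once(image, x0, y0, wa, wb, wc, wd):
    """Hypothetical variant: fetch all four corners with ONE index op on a flattened image.
    image: [N, C, H, W]; x0, y0: in-bounds integer corner coords [N, gH, gW]
    (x0 + 1 < W, y0 + 1 < H); wa..wd: bilinear weights [N, gH, gW].
    Returns [N, C, gH, gW]."""
    N, C, H, W = image.shape
    flat = image.permute(1, 0, 2, 3).reshape(C, N * H * W)   # [C, N*H*W]
    n = torch.arange(N, device=image.device).view(N, 1, 1)
    p00 = (n * H + y0) * W + x0                               # linear index of (y0, x0)
    # Row-major layout: +W moves one pixel down (y1), +1 moves one pixel right (x1)
    idx = torch.stack([p00, p00 + W, p00 + 1, p00 + W + 1])   # [4, N, gH, gW]
    corners = flat[:, idx]                                     # [C, 4, N, gH, gW]
    w = torch.stack([wa, wb, wc, wd]).unsqueeze(0)             # [1, 4, N, gH, gW]
    return (corners * w).sum(dim=1).permute(1, 0, 2, 3)       # [N, C, gH, gW]

# Compare with four separate lookups, as in bilinear_sample_noloop
N, C, H, W, gH, gW = 2, 3, 6, 7, 4, 5
image = torch.rand(N, C, H, W)
x0 = torch.randint(0, W - 1, (N, gH, gW))
y0 = torch.randint(0, H - 1, (N, gH, gW))
wa, wb, wc, wd = torch.rand(4, N, gH, gW)
ind = torch.arange(N).view(N, 1, 1).expand(N, gH, gW)
img = image.permute(1, 0, 2, 3)
ref = (img[:, ind, y0, x0] * wa + img[:, ind, y0 + 1, x0] * wb
       + img[:, ind, y0, x0 + 1] * wc + img[:, ind, y0 + 1, x0 + 1] * wd).permute(1, 0, 2, 3)
out = bilinear_gather_once(image, x0, y0, wa, wb, wc, wd)
print(torch.allclose(out, ref, atol=1e-6))  # True
```

With C = 1, as in the original model, flat could also be a plain 1-D tensor indexed as flat[idx].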


There is 1 solution below:


The following code implements bilinear interpolation of an image and can be converted from PyTorch to TensorRT:

import torch

def bilinear_interpolate_torch(im, y, x):
    '''
    im : B, C, H, W
    y  : 1, numPoints -- pixel location y (float)
    x  : 1, numPoints -- pixel location x (float)
    '''
    # Integer coordinates of the four neighbouring pixels
    x0 = torch.floor(x).long()
    x1 = x0 + 1

    y0 = torch.floor(y).long()
    y1 = y0 + 1

    # Bilinear weights (same dtype and device as the input coordinates)
    wa = (x1.type_as(x) - x) * (y1.type_as(y) - y)
    wb = (x1.type_as(x) - x) * (y - y0.type_as(y))
    wc = (x - x0.type_as(x)) * (y1.type_as(y) - y)
    wd = (x - x0.type_as(x)) * (y - y0.type_as(y))
    # Instead of clamp: maps x1 == W to W - 1 and y1 == H to H - 1
    x1 = x1 - torch.floor(x1 / im.shape[3]).long()
    y1 = y1 - torch.floor(y1 / im.shape[2]).long()
    Ia = im[:, :, y0, x0]
    Ib = im[:, :, y1, x0]
    Ic = im[:, :, y0, x1]
    Id = im[:, :, y1, x1]

    return Ia * wa + Ib * wb + Ic * wc + Id * wd
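The "instead of clamp" lines rely on a small integer trick: for 0 <= x1 <= W, x1 - floor(x1 / W) leaves x1 unchanged when x1 < W and turns x1 == W into W - 1, which is the same as clamping to W - 1. The sketch below checks this for a small, arbitrary W. It assumes a recent PyTorch, where `/` on integer tensors performs true division. The answer does not say why clamp is avoided. Note that the trick only covers x1 <= W (for larger values it no longer stays in bounds) and does nothing for negative coordinates.

```python
import torch

W = 5
x1 = torch.arange(0, W + 1)                    # candidate x1 values 0..W
wrapped = x1 - torch.floor(x1 / W).long()      # the answer's "instead of clamp" trick
clamped = torch.clamp(x1, max=W - 1)           # the operation it stands in for
print(wrapped.tolist())                        # [0, 1, 2, 3, 4, 4]
print(torch.equal(wrapped, clamped))           # True
```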