nn.functional.interpolate with bi/trilinear or bicubic is VERY slow when using AMP

Using nn.functional.interpolate is very slow when using mixed precision training. This seems to have already been an issue in the past (#12409) and should be fixed, but I am still seeing a performance regression of up to 5x (yes, that is 5x the fp32 runtime) depending on the network and GPU.

To demonstrate the issue I have created a small script that reproduces it. I am mostly interested in 3D images, which is why the example uses 3D convs, but the same issue also appears with 2D inputs (see the separate section below).

What my script does is the following:
It builds a very simple convolutional encoder network (VGG style) that downsamples the 1x64x64x64 input to 512x4x4x4 (four stride-2 convolutions, so 64 / 2^4 = 4). It then blows the small representation back up to 512x64x64x64 using interpolate. The loss is simply the MSE between input and upsampled output.
This workflow does not represent a useful application; it is just there to put some load on the GPU and isolate the issue.

I run 100 iterations with this setup in three settings:

  1. regular fp32
  2. regular AMP
  3. AMP but with a hack: the encoding (the 512x4x4x4 tensor) is cast to float, then upsampled, and then cast back to half (an equivalent autocast-based formulation is sketched right after this list)
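
For reference, the same effect as setting 3 can also be obtained by disabling autocast locally around the interpolation instead of casting by hand. A minimal illustrative sketch (this is not what the script below does; the script uses explicit casts):

import torch
from torch.cuda.amp import autocast
from torch.nn import functional

# illustrative only: force just the interpolation to run in fp32 under AMP
down = torch.rand((2, 512, 4, 4, 4), device='cuda', dtype=torch.float16)

with autocast():
    with autocast(enabled=False):  # interpolation runs in fp32 inside this block
        up = functional.interpolate(down.float(), (64, 64, 64), mode='trilinear')
    up = up.half()  # hand the result back to the autocast region as fp16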

Expected behavior:
2. should be much faster than 1. due to tensor core acceleration of the convolutions; 3. should be slower than 2. due to the additional casting overhead.

Actual behavior:
2. is slower than 1., and 3. is faster than 2.

Environment:
I am using PyTorch 1.7.0 compiled with cuDNN 8.0.5.39. The OS is Ubuntu 18.04 (RTX 3090) and CentOS (RTX 2080 Ti, V100).

I am also showing results for 1.7.1 with cuDNN 7.6.5 (installed with pip today) to demonstrate that this is not related to my build.

Details and standalone script

from time import time

from torch import nn
import torch
from torch.cuda.amp import GradScaler, autocast
from torch.nn import functional
from torch.optim import SGD


class Network(nn.Module):
    def __init__(self, cast_for_upsample=False):
        super().__init__()
        self.cast_for_upsample = cast_for_upsample

        self.layers = nn.Sequential(
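            # Conv3d positional args: in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias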
            nn.Conv3d(1, 32, 3, 1, 1, 1, 1, False),
            nn.LeakyReLU(1e-2, True),
            nn.Conv3d(32, 32, 3, 1, 1, 1, 1, False),
            nn.LeakyReLU(1e-2, True),

            nn.Conv3d(32, 64, 3, 2, 1, 1, 1, False),
            nn.LeakyReLU(1e-2, True),
            nn.Conv3d(64, 64, 3, 1, 1, 1, 1, False),
            nn.LeakyReLU(1e-2, True),

            nn.Conv3d(64, 128, 3, 2, 1, 1, 1, False),
            nn.LeakyReLU(1e-2, True),
            nn.Conv3d(128, 128, 3, 1, 1, 1, 1, False),
            nn.LeakyReLU(1e-2, True),

            nn.Conv3d(128, 256, 3, 2, 1, 1, 1, False),
            nn.LeakyReLU(1e-2, True),
            nn.Conv3d(256, 256, 3, 1, 1, 1, 1, False),
            nn.LeakyReLU(1e-2, True),

            nn.Conv3d(256, 512, 3, 2, 1, 1, 1, False),
            nn.LeakyReLU(1e-2, True),
            nn.Conv3d(512, 512, 3, 1, 1, 1, 1, False),
        )

    def forward(self, x):
        down = self.layers(x)
        if self.cast_for_upsample:
            up = nn.functional.interpolate(down.float(), x.shape[2:], None, 'trilinear').half()
        else:
            up = nn.functional.interpolate(down, x.shape[2:], None, 'trilinear')
        return up


if __name__ == "__main__":
    inp = torch.rand((2, 1, 64, 64, 64)).cuda()

    net = Network(cast_for_upsample=False).cuda()
    optimizer = SGD(net.parameters(), 0.001)

    torch.cuda.empty_cache()

    # warmup
    for _ in range(10):
        optimizer.zero_grad()
        out = net(inp)
        l = torch.square(inp - out).mean() # just the MSE between input and output as a dummy loss function
        l.backward()
        optimizer.step()

    # fp32 measurement
    st = time()
    for _ in range(100):
        optimizer.zero_grad()
        out = net(inp)
        l = torch.square(inp - out).mean() # just the MSE between input and output as a dummy loss function
        l.backward()
        optimizer.step()
    print('fp32:', time() - st)

    ####################################################
    # now AMP
    net = Network(cast_for_upsample=False).cuda()
    optimizer = SGD(net.parameters(), 0.001)
    scaler = GradScaler()

    torch.cuda.empty_cache()

    # warmup
    for _ in range(10):
        optimizer.zero_grad()

        with autocast():
            out = net(inp)
            l = torch.square(inp - out).mean()  # just the MSE between input and output as a dummy loss function

        scaler.scale(l).backward()
        scaler.step(optimizer)
        scaler.update()

    # amp measurement
    st = time()
    for _ in range(100):
        optimizer.zero_grad()

        with autocast():
            out = net(inp)
            l = torch.square(inp - out).mean()  # just the MSE between input and output as a dummy loss function

        scaler.scale(l).backward()
        scaler.step(optimizer)
        scaler.update()
    print('amp:', time() - st)

    ####################################################
    # now AMP but with interpolate hacked so that it runs in fp32
    net = Network(cast_for_upsample=True).cuda()
    optimizer = SGD(net.parameters(), 0.001)
    scaler = GradScaler()

    torch.cuda.empty_cache()

    # warmup
    for _ in range(10):
        optimizer.zero_grad()

        with autocast():
            out = net(inp)
            l = torch.square(inp - out).mean()  # just the MSE between input and output as a dummy loss function

        scaler.scale(l).backward()
        scaler.step(optimizer)
        scaler.update()

    # amp measurement
    st = time()
    for _ in range(100):
        optimizer.zero_grad()

        with autocast():
            out = net(inp)
            l = torch.square(inp - out).mean()  # just the MSE between input and output as a dummy loss function

        scaler.scale(l).backward()
        scaler.step(optimizer)
        scaler.update()
    print('amp cast to float:', time() - st)

I ran this on multiple types of GPUs. Numbers represent time for 100 iterations in seconds.
RTX 3090 (Ampere), 3D:

fp32: 11.094072580337524
amp: 55.94109606742859
amp cast to float: 9.90783166885376

V100 (Volta), 3D:

fp32: 11.654786348342896
amp: 62.763665199279785
amp cast to float: 10.015439987182617

RTX 2080ti (Turing), 3D:

fp32: 13.840788841247559
amp: 65.24087715148926
amp cast to float: 12.579991817474365

As you can see, all GPU architectures run significantly slower with AMP. If I enable the casting hack the runtime is reduced substantially, which tells me that the issue is related to upsampling fp16 inputs (or to propagating the gradients through this operation in the backward pass).
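
In case someone wants to check that suspicion directly, here is a minimal sketch that isolates just the interpolation (forward + backward) at the encoding size used above; this is illustrative and not part of the measurements in this report:

from time import time

import torch
from torch.nn import functional


def bench_interpolate(dtype, n=100):
    # encoding-sized input, as produced by the encoder above (2x512x4x4x4)
    x = torch.rand((2, 512, 4, 4, 4), device='cuda', dtype=dtype, requires_grad=True)
    for _ in range(10):  # warmup
        functional.interpolate(x, (64, 64, 64), mode='trilinear').sum().backward()
    torch.cuda.synchronize()
    st = time()
    for _ in range(n):
        functional.interpolate(x, (64, 64, 64), mode='trilinear').sum().backward()
    torch.cuda.synchronize()
    return time() - st


print('fp32 interpolate fwd+bwd:', bench_interpolate(torch.float32))
print('fp16 interpolate fwd+bwd:', bench_interpolate(torch.float16))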

To verify that this is not related to 3D inputs I also ran it in 2D (just replace all Conv3d with Conv2d and use inp = torch.rand((2, 1, 512, 512)).cuda(); a condensed sketch of the 2D variant is shown after the numbers). Same story (showing the RTX 3090 only):

RTX 3090 (Ampere), 2D:

fp32: 18.6866774559021
amp: 32.80347037315369
amp cast to float: 18.29295825958252

(same story with ‘bicubic’, data not shown)
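
For completeness, a condensed sketch of the 2D variant referenced above (same layer pattern as the 3D script, condensed into a loop; note that the interpolation mode also has to become 'bilinear', since 'trilinear' expects 5D input):

import torch
from torch import nn


class Network2D(nn.Module):
    def __init__(self):
        super().__init__()
        layers = []
        in_ch = 1
        for out_ch in (32, 64, 128, 256, 512):
            stride = 1 if out_ch == 32 else 2  # first stage keeps resolution, the rest downsample
            layers += [nn.Conv2d(in_ch, out_ch, 3, stride, 1, bias=False),
                       nn.LeakyReLU(1e-2, True),
                       nn.Conv2d(out_ch, out_ch, 3, 1, 1, bias=False),
                       nn.LeakyReLU(1e-2, True)]
            in_ch = out_ch
        layers = layers[:-1]  # the 3D script has no activation after the final conv
        self.layers = nn.Sequential(*layers)

    def forward(self, x):
        down = self.layers(x)  # 2x512x32x32 for the 2x1x512x512 input
        return nn.functional.interpolate(down, x.shape[2:], None, 'bilinear')


inp = torch.rand((2, 1, 512, 512)).cuda()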

To ensure this is specific to bi/trilinear and bicubic interpolation, I changed the interpolation mode to ‘nearest’ (2D):

fp32: 4.452304124832153
amp: 3.830399990081787
amp cast to float: 4.455098628997803

‘nearest’ is working fine.

To ensure this is not related to my PyTorch build I also ran it with PyTorch 1.7.1 (installed today with pip). This is a 2D run on an RTX 2080 Ti because 1) the RTX 3090 is not supported with cuDNN 7.6.5, so I cannot run it on that GPU, and 2) 1.7.1 still ships with cuDNN 7.6.5, which does not support tensor core acceleration for 3D convs on Turing, so 3D would be extra slow.

RTX 2080 ti, torch 1.7.1, cuDNN 7.6.5, 2D:

fp32: 21.80005693435669
amp: 35.320844650268555
amp cast to float: 22.71259832382202

still there 🙂

Would be great if someone could look into this! If there is anything else you need please let me know.

Best,

Fabian

cc @mcarilli @ptrblck

One possible answer on "nn.functional.interpolate with bi/trilinear or bicubic is VERY slow when using AMP":

  1. @FabianIsensee thank you for the updates and the simplified script.
    @xwang233 in both original and simplified scripts the backward for interpolation is run (it is at the end of the network, not the beginning), and it should be run, so in that sense there’s no problem with amp.
    For the sizes in the simplified script (and most likely in the original, but I didn’t double-check) the backward interpolate kernel time is much slower in master than it was in 1.7, both for fp32 and fp16.
    Simplified script, 3D: 1,32,7,7,7 -> 1,32,64,64,64, P100, interpolation backward kernel time according to nvprof:
    fp32 master ~32 ms
    fp16 master ~39 ms
    fp32 1.7 ~4.7 ms
    fp16 1.7 ~8.3 ms

    @FabianIsensee is such a large interpolation (7->64 or 4->64) a typical use case? I suspect the kernels are optimized for interpolation by a factor of 2, give or take, so 7->14 or 4->8, and if this big blow-up case is common, a different approach might be needed.
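
For anyone who wants to reproduce such a kernel-level measurement without nvprof, a rough sketch using the built-in autograd profiler, with the sizes taken from the answer above (purely illustrative):

import torch
from torch.nn import functional

x = torch.rand((1, 32, 7, 7, 7), device='cuda', dtype=torch.float16, requires_grad=True)

with torch.autograd.profiler.profile(use_cuda=True) as prof:
    for _ in range(10):
        functional.interpolate(x, (64, 64, 64), mode='trilinear').sum().backward()

# the trilinear upsampling backward op should show up near the top when sorted by CUDA time
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))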