Using torch.distributed.barrier() makes the whole code hang

🐛 Bug

Adding torch.distributed.barrier() makes the training process hang indefinitely.

To Reproduce

Steps to reproduce the behavior:

  1. Run training on multiple GPUs (tested on 2 and on 8 32 GB Tesla V100s).
  2. Run the validation step on just one GPU, and use torch.distributed.barrier() to make the other processes wait until validation is done.
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import optim

import torch.multiprocessing as mp
import torch.distributed as dist

import numpy as np

from torchvision import datasets
import torchvision.transforms as transforms


class img_classifier(nn.Module):
    def __init__(self):
        super(img_classifier, self).__init__()
        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, padding=1)
        self.conv3 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.pool = nn.MaxPool2d(2, 2)
        self.fc1 = nn.Linear(64 * 4 * 4, 500)
        self.fc2 = nn.Linear(500, 10)
        self.dropout = nn.Dropout(0.2)


    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = self.pool(F.relu(self.conv3(x)))
        x = x.view(-1, 64 * 4 * 4)
        x = F.relu(self.fc1(x))
        x = self.dropout(x)
        x = self.fc2(x)  # return raw logits; nn.CrossEntropyLoss applies log-softmax itself

        return x

def train(gpu, args):
    
    rank = args['nr'] * args['gpus'] + gpu                          
    dist.init_process_group(                                   
        backend='nccl',                                         
        init_method='env://',                                   
        world_size=args['world_size'],                              
        rank=rank                                               
    )
    
    torch.cuda.set_device(gpu)
    
    ## Dataset
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])

    train_data = datasets.CIFAR10('CIFAR10_data/', train=True, download=True, transform=transform)
    valid_data = datasets.CIFAR10('CIFAR10_data/', train=False, download=True, transform=transform)
    
    train_sampler = torch.utils.data.distributed.DistributedSampler(
        train_data,
        num_replicas=args['world_size'],
        rank=rank,
    )
    
    batch_size = len(train_data)//args['world_size']
    
    if gpu == 0:
        print(batch_size)

    train_loader = torch.utils.data.DataLoader(train_data, batch_size=batch_size, sampler=train_sampler)
    valid_loader = torch.utils.data.DataLoader(valid_data, batch_size=batch_size)
    
    ## Model
    
    model = img_classifier().cuda(gpu)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(), lr=0.02)
    
    model = nn.parallel.DistributedDataParallel(model, device_ids=[gpu])
    
    n_epochs = 20
    valid_loss_min = np.Inf

    for epoch in range(1, n_epochs+1):

        train_loss = 0.0
        valid_loss = 0.0

        model.train()
               
        for data, target in train_loader:
            if torch.cuda.is_available():
                data, target = data.cuda(), target.cuda()
            optimizer.zero_grad()
            output = model(data)
            loss = criterion(output, target)
            loss.backward()
            optimizer.step()
            train_loss += loss.item()*data.size(0)


        # validate the model
        if gpu == 0:
            model.eval()
            for data, target in valid_loader:
                if torch.cuda.is_available():
                    data, target = data.cuda(), target.cuda()
                output = model(data)
                loss = criterion(output, target)
                valid_loss += loss.item()*data.size(0)

            # calculate average losses
            train_loss = train_loss/len(train_loader.dataset)
            valid_loss = valid_loss/len(valid_loader.dataset)

            # print training/validation statistics 
            print('Epoch: {} \tTraining Loss: {:.6f} \tValidation Loss: {:.6f}'.format(
                epoch, train_loss, valid_loss))

            # save model if validation loss has decreased
            if valid_loss <= valid_loss_min:
                print('Validation loss decreased ({:.6f} --> {:.6f}).  Saving model ...'.format(
                valid_loss_min,
                valid_loss))
                torch.save(model.state_dict(), 'model_cifar.pt')
                valid_loss_min = valid_loss

        dist.barrier()
        
def main():
    args = {
        'gpus' : 2,
        'nodes' : 1,
        'nr': 0
    }
    
    args['world_size'] = args['gpus'] * args['nodes']
    mp.spawn(train, nprocs=args['gpus'], args=(args,))
    
    
if __name__ == '__main__':
    main()
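
Note: init_method='env://' expects MASTER_ADDR and MASTER_PORT to be set in the environment of every spawned process; the script above assumes they are already exported. A minimal single-node launcher sketch (the address and port values below are placeholders, not part of the original report):

import os

def main():
    args = {
        'gpus' : 2,
        'nodes' : 1,
        'nr': 0
    }
    args['world_size'] = args['gpus'] * args['nodes']

    # init_method='env://' reads these in every process spawned by mp.spawn;
    # 'localhost' / '29500' are placeholder values for a single-node run
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '29500'

    mp.spawn(train, nprocs=args['gpus'], args=(args,))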

Expected behavior

I would expect the training to continue as normal once validation is done.

Environment

PyTorch version: 1.7.0
Is debug build: True
CUDA used to build PyTorch: 11.0
ROCM used to build PyTorch: N/A

OS: Debian GNU/Linux 10 (buster) (x86_64)
GCC version: (Debian 8.3.0-6) 8.3.0
Clang version: Could not collect
CMake version: version 3.13.4

Python version: 3.8 (64-bit runtime)
Is CUDA available: True
CUDA runtime version: 10.1.243
GPU models and configuration:
GPU 0: Tesla V100-SXM2-32GB
GPU 1: Tesla V100-SXM2-32GB
GPU 2: Tesla V100-SXM2-32GB
GPU 3: Tesla V100-SXM2-32GB
GPU 4: Tesla V100-SXM2-32GB
GPU 5: Tesla V100-SXM2-32GB
GPU 6: Tesla V100-SXM2-32GB
GPU 7: Tesla V100-SXM2-32GB

Nvidia driver version: 450.80.02
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A

Versions of relevant libraries:
[pip3] numpy==1.18.5
[pip3] torch==1.7.0
[pip3] torchfile==0.1.0
[pip3] torchvision==0.8.1
[conda] blas 1.0 mkl anaconda
[conda] cudatoolkit 11.0.221 h6bb024c_0 anaconda
[conda] mkl 2020.2 256 anaconda
[conda] mkl-service 2.3.0 py38he904b0f_0
[conda] mkl_fft 1.2.0 py38h23d657b_0
[conda] mkl_random 1.1.1 py38h0573a6f_0 anaconda
[conda] numpy 1.18.5 pypi_0 pypi
[conda] pytorch 1.7.0 py3.8_cuda11.0.221_cudnn8.0.3_0 pytorch
[conda] torchfile 0.1.0 pypi_0 pypi
[conda] torchvision 0.8.1 py38_cu110 pytorch

Additional context

By putting print statements around torch.distributed.barrier(), I've found that the barrier statement itself is executed, but the process still hangs afterwards.
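
The instrumentation amounts to something like the following (illustrative, not copied from the original script); if the second print shows up on every rank, the hang is not inside the barrier itself but in a later call, e.g. the next DDP forward/backward:

        # illustrative prints around the barrier in the training loop (not in the original repro)
        print(f'rank {rank}: before barrier (epoch {epoch})', flush=True)
        dist.barrier()
        print(f'rank {rank}: after barrier (epoch {epoch})', flush=True)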

I have tested other configurations too:

  • Removing the distributed sampler but keeping torch.distributed.barrier(): the code still hangs.
  • Keeping the distributed sampler but removing torch.distributed.barrier(): the code does not hang.

cc @pietern @mrshenli @pritamdamania87 @zhaojuanmao @satgera @rohan-varma @gqchen @aazzolini @osalpekar @jiayisuse @agolynski @SciPioneer @H-Huang @mrzzd @cbalioglu

One possible answer

  1. A fix for no_grad compatibility is in #54159. After that PR, running the evaluation under no_grad should also work:

            # validate the model
            if gpu == 0:
                with torch.no_grad():
                    model.eval()
                    for data, target in valid_loader:
                        if torch.cuda.is_available():
                            data, target = data.cuda(), target.cuda()
                        output = model(data)
                        loss = criterion(output, target)
                        valid_loss += loss.item()*data.size(0)

    The root cause of the original hang is that when evaluation runs on just one rank, that rank still checks whether it should rebuild the DDP communication buckets. As a result, the bucket-rebuild state can diverge across ranks: some ranks expect a configuration-sync collective while the others never join it, so the next collective call blocks forever.

    The solution is either to run evaluation on ddp.module (the underlying, unwrapped model, sketched below) or to run evaluation under no_grad() mode (as shown above).
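
    As a sketch of the ddp.module option, using the variable names from the repro above: the key change relative to the original snippet is that rank 0 calls the unwrapped module directly, so DDP's forward hooks and bucket-rebuild logic never run during evaluation:

            # validate on rank 0 only, bypassing the DDP wrapper (sketch based on the repro above)
            if gpu == 0:
                eval_model = model.module  # the underlying nn.Module inside DistributedDataParallel
                eval_model.eval()
                with torch.no_grad():
                    for data, target in valid_loader:
                        data, target = data.cuda(), target.cuda()
                        output = eval_model(data)
                        loss = criterion(output, target)
                        valid_loss += loss.item() * data.size(0)

            dist.barrier()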

    cc @zhaojuanmao