One possible answer to “How to use torch.utils.checkpoint and DistributedDataParallel together”:

  1. Hey @devilztt, if you are manually synchronizing gradients, then you don’t need DDP anymore. Sketches of the grad-bucketing idea and of where torch.utils.checkpoint fits in follow the snippet below.

    import torch.distributed as dist

    dist.init_process_group(...)
    model = MyModel(...)
    model(inputs).sum().backward()
    # manually all-reduce every gradient across ranks (all_reduce sums by default)
    works = []
    for p in model.parameters():
        # to speed this up, you can also organize grads into larger buckets to make
        # allreduce more efficient (see the sketch after this snippet)
        works.append(dist.all_reduce(p.grad, async_op=True))
    for work in works:
        work.wait()
    # divide each p.grad by dist.get_world_size() here if you want DDP-style averaged gradients
    ...
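
     A minimal sketch of the bucketing idea mentioned in the comment above, assuming the same
     model and process-group setup as the snippet: instead of issuing one all_reduce per
     parameter, flatten all grads into a single buffer, reduce it once, and copy the averaged
     values back. The single-bucket layout here is purely illustrative.

    import torch

    grads = [p.grad for p in model.parameters() if p.grad is not None]
    flat = torch.cat([g.reshape(-1) for g in grads])  # one big "bucket"
    dist.all_reduce(flat)                             # one collective instead of many
    flat /= dist.get_world_size()                     # average, matching DDP semantics
    offset = 0
    for g in grads:
        n = g.numel()
        g.copy_(flat[offset:offset + n].view_as(g))   # scatter reduced values back
        offset += n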
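
     And since the thread is about torch.utils.checkpoint, here is a minimal sketch of where
     activation checkpointing would sit inside the model used above. MyModel and its
     stage1/stage2/head submodules are illustrative names, not from the original post, and
     use_reentrant=False assumes a reasonably recent PyTorch release. The checkpointed segments
     are recomputed during backward, and the resulting p.grad tensors are then synchronized
     manually exactly as in the snippet above.

    import torch.nn as nn
    from torch.utils.checkpoint import checkpoint

    class MyModel(nn.Module):
        def __init__(self):
            super().__init__()
            self.stage1 = nn.Sequential(nn.Linear(128, 128), nn.ReLU())
            self.stage2 = nn.Sequential(nn.Linear(128, 128), nn.ReLU())
            self.head = nn.Linear(128, 10)

        def forward(self, x):
            # drop the intermediate activations of stage1/stage2 and recompute
            # them during backward to save memory
            x = checkpoint(self.stage1, x, use_reentrant=False)
            x = checkpoint(self.stage2, x, use_reentrant=False)
            return self.head(x)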