Model Parallelism for BERT Models


I’m trying to implement model parallelism for BERT-family models by splitting the layers and assigning them across GPUs, and I took DeBERTa as the example. For DeBERTa, I’m able to split the entire model into its ‘embeddings’, ‘encoder’, ‘pooler’, ‘classifier’ and ‘dropout’ modules.
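For illustration, the split into top-level modules can be sketched with `named_children()`. This is a minimal sketch with a toy stand-in model (the class and layer sizes here are made up, not the real DeBERTa architecture), just to show how the module boundaries fall out:

```python
import torch.nn as nn

# Hypothetical stand-in mirroring the embeddings/encoder/pooler/
# classifier/dropout split described above (not the real DeBERTa).
class ToyDeberta(nn.Module):
    def __init__(self):
        super().__init__()
        self.embeddings = nn.Embedding(100, 16)
        self.encoder = nn.ModuleList([nn.Linear(16, 16) for _ in range(12)])
        self.pooler = nn.Linear(16, 16)
        self.dropout = nn.Dropout(0.1)
        self.classifier = nn.Linear(16, 2)

model = ToyDeberta()
# named_children() enumerates the top-level modules, which is a convenient
# way to decide which piece goes to which GPU.
for name, module in model.named_children():
    print(name, type(module).__name__)
```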


With this approach, I trained on the IMDB classification task, assigning the ‘encoder’ to the second GPU and everything else to the first GPU. By the end of training, the second GPU consumed far more memory than the first, which amounted to roughly a 20-80 split of the model across the two GPUs.
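The two-stage setup described above can be sketched as a manual pipeline: each stage lives on its own device, and activations are moved with `.to(device)` between stages. This is a minimal sketch with toy layers standing in for DeBERTa (it falls back to CPU when two GPUs are not available, so the stitching logic stays runnable anywhere):

```python
import torch
import torch.nn as nn

# Pick two devices; fall back to CPU so the sketch runs without GPUs.
two_gpus = torch.cuda.device_count() >= 2
dev0 = torch.device('cuda:0') if two_gpus else torch.device('cpu')
dev1 = torch.device('cuda:1') if two_gpus else torch.device('cpu')

# Toy stages standing in for the embeddings / encoder / classifier split.
embed = nn.Embedding(100, 16).to(dev0)
encoder = nn.Sequential(*[nn.Linear(16, 16) for _ in range(12)]).to(dev1)
classifier = nn.Linear(16, 2).to(dev0)

ids = torch.randint(0, 100, (1, 8))
h = embed(ids.to(dev0))          # stage 1 on the first device
h = encoder(h.to(dev1))          # stage 2 on the second device
logits = classifier(h.to(dev0))  # head back on the first device
print(logits.shape)              # torch.Size([1, 8, 2])
```

Note that this naive pipeline keeps only one stage busy at a time; the memory imbalance seen above is expected, since the encoder holds nearly all of the parameters and activations.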

So I tried splitting the encoder layers as well, as shown below, but I’m getting this error: “TypeError: forward() takes 1 positional argument but 2 were given”.

embed = dberta.deberta.embeddings.to('cuda:0')
f6e = dberta.deberta.encoder.layer[:6].to('cuda:0')
l6e = dberta.deberta.encoder.layer[6:].to('cuda:1')
pooler = dberta.pooler.to('cuda:0')
classifier = dberta.classifier.to('cuda:0')
dropout = dberta.dropout.to('cuda:0')

test = "this is to test deberta"

inp_ids = tok_dberta(test, return_tensors='pt').input_ids
att_mask = tok_dberta(test, return_tensors='pt').attention_mask

emb_out = embed(inp_ids.to('cuda:0'))

first_6_enc_lay_out = f6e(emb_out)
TypeError                                 Traceback (most recent call last)
<ipython-input-15-379d948e5ba5> in <module>
----> 1 first_6_enc_lay_out = f6e(emb_out)

/usr/local/lib/python3.6/dist-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
    725             result = self._slow_forward(*input, **kwargs)
    726         else:
--> 727             result = self.forward(*input, **kwargs)
    728         for hook in itertools.chain(
    729                 _global_forward_hooks.values(),

TypeError: forward() takes 1 positional argument but 2 were given
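For reference, this error comes from the fact that `encoder.layer[:6]` is still an `nn.ModuleList`: a plain container with no usable `forward()`, so it cannot be called like a module. A minimal reproduction and fix, with toy linear layers standing in for the DeBERTa encoder blocks:

```python
import torch
import torch.nn as nn

# A slice of a ModuleList is still a ModuleList: a container without a
# usable forward(), so calling it like a module fails.
layers = nn.ModuleList([nn.Linear(8, 8) for _ in range(12)])
x = torch.randn(2, 8)

try:
    layers[:6](x)  # same failure mode as f6e(emb_out) above
except (TypeError, NotImplementedError) as e:
    print("calling the slice fails:", type(e).__name__)

# Fix: run the layers one by one (or wrap the slice in nn.Sequential),
# feeding each layer's output into the next.
out = x
for layer in layers[:6]:
    out = layer(out)
print(out.shape)  # torch.Size([2, 8])
```

With the real model, the loop body would also need to pass the extra arguments each DeBERTa layer expects (e.g. the attention mask), since its layers take more than just the hidden states.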

Please suggest how to proceed further.

1 possible answer(s) on “Model Parallelism for BERT Models”

  1. Yay, so glad to hear you found a solution, @saichandrapandraju!

    Thank you for updating the notebook too!

    If the issue has been fully resolved for you please don’t hesitate to close this Issue.

    If some new problem occurs, please open a new dedicated issue. Thank you.