Mamba + Tensor Parallel Support #1184
Conversation
LGTM, no comments. @haileyschoelkopf -- As a final check, I'd like to verify that TP gives the expected memory benefits. Can you link the wandb here so I can take a look?
Yep! https://wandb.ai/eleutherai/mamba-neox-tp-memsavings/workspace?nw=nwuserschoelkopf -- this should be public and a clean comparison; lmk if more is needed or if for some reason it's not visible (MBS=16 for both TP=1 and TP=2, corrected for w/ grad accum). Seeing 29-30GB for TP=1 and 19GB for TP=2. Here's the full wandb of trial runs, including initial tests plus additions like Mamba's GPT-2-style init, Mamba-160m, and some Pythia-160m baseline curves: https://wandb.ai/eleutherai/mamba-neox?nw=nwuserschoelkopf
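For context on the "corrected for w/ grad accum" note, here is a minimal sketch of the usual Megatron-style batch-size bookkeeping (not code from this PR; the 8-GPU count and the formula are illustrative assumptions): on a fixed GPU count, raising TP shrinks the data-parallel world size, so gradient accumulation is raised to hold the global batch size constant while per-GPU memory drops.

```python
def global_batch(micro_batch: int, grad_accum: int, world_size: int, tp: int, pp: int = 1) -> int:
    # Data-parallel replicas shrink as TP (and PP) grow on a fixed GPU count.
    dp = world_size // (tp * pp)
    return micro_batch * grad_accum * dp

# Hypothetical 8-GPU example: doubling grad accum at TP=2 keeps the global batch fixed.
assert global_batch(16, grad_accum=1, world_size=8, tp=1) == 128
assert global_batch(16, grad_accum=2, world_size=8, tp=2) == 128
```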
This PR adds Mamba + TP support.
Loss curves comparing TP=2 to TP=1, with and without mamba_inner_func_fusion: [loss curve plots]
Versus when the allreduce was missing with inner func fusion turned on: [loss curve plot]
Also tested that PP (pipeline parallelism) seems to work.
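For readers skimming the diff, a hedged, self-contained sketch of the Megatron-style column/row-parallel split this kind of Mamba + TP support typically uses (plain PyTorch, not the PR's actual code; the class, dimension names, and the SiLU stand-in for the fused SSM scan are all illustrative assumptions). It also shows the output all-reduce whose absence produces the diverging curves above.

```python
import torch
import torch.nn as nn
import torch.distributed as dist


class TPMambaProjSketch(nn.Module):
    """Illustrative column/row-parallel projections around an SSM core.

    Assumes torch.distributed is already initialized (e.g. via torchrun).
    Each rank holds 1/tp_size of the inner dimension; the output projection
    is row-parallel, so its partial results must be summed with an all-reduce.
    """

    def __init__(self, d_model: int, d_inner: int, tp_size: int):
        super().__init__()
        assert d_inner % tp_size == 0
        self.tp_size = tp_size
        # Column-parallel: each rank owns a slice of the inner dimension.
        self.in_proj = nn.Linear(d_model, d_inner // tp_size, bias=False)
        # Row-parallel: each rank projects its slice back to d_model.
        self.out_proj = nn.Linear(d_inner // tp_size, d_model, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        h = self.in_proj(x)                  # (batch, seq, d_inner / tp_size)
        h = torch.nn.functional.silu(h)      # stand-in for the fused SSM scan
        y = self.out_proj(h)                 # partial sum on each rank
        if self.tp_size > 1:
            # Without this all-reduce each rank only sees a partial output,
            # which is the kind of mismatch the "missing allreduce" curves show.
            dist.all_reduce(y)
        return y
```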