Submitted by super_deap t3_11tmpc5 in MachineLearning
Hi,
I did a quick experiment with PyTorch 2.0's native scaled_dot_product_attention. I was able to run a single forward pass within 9 GB of memory, which is astounding. I think by patching existing pretrained GPT models and adding more positional encodings, one could easily fine-tune those models to 32k attention on a single A100 80GB. Here is the code I used:
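A minimal sketch of the idea, assuming batch size 1, 12 heads, head dim 64, fp16, and a 32k sequence length (illustrative values, not necessarily the exact configuration from my run):

```python
import torch
import torch.nn.functional as F

# Single attention forward pass at a long context with PyTorch 2.0's native
# scaled_dot_product_attention (it dispatches to a memory-efficient/flash
# kernel when one is available).
batch, n_heads, seq_len, head_dim = 1, 12, 32_768, 64

q = torch.randn(batch, n_heads, seq_len, head_dim, device="cuda", dtype=torch.float16)
k = torch.randn_like(q)
v = torch.randn_like(q)

with torch.no_grad():
    out = F.scaled_dot_product_attention(q, k, v, is_causal=True)

print(out.shape)
print(f"peak memory: {torch.cuda.max_memory_allocated() / 2**30:.2f} GiB")
```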
I think it should be possible to replicate even GPT-4 with open-source tools: something like BLOOM + FlashAttention, fine-tuned on 32k tokens.
Update: I was successfully able to start training GPT-2 (125M) with a context size of 8k and a batch size of 1 on a 16GB GPU. Since memory scaled linearly from 4k to 8k, I expect 32k would require ~64 GB and should train smoothly on an A100 80GB. Also, I did not do any other optimizations; maybe 8-bit fine-tuning can optimize it further.
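For example, the optimizer states could be kept in 8 bits (a sketch assuming the bitsandbytes library; the model below is just a placeholder, not the patched GPT-2):

```python
import torch.nn as nn
import bitsandbytes as bnb

# Placeholder module standing in for the model being fine-tuned; the point
# here is only the optimizer swap.
model = nn.Linear(768, 768).cuda()

# 8-bit AdamW stores optimizer state in 8 bits instead of fp32, which should
# cut a large chunk of training memory on top of the attention savings.
optimizer = bnb.optim.AdamW8bit(model.parameters(), lr=3e-4, weight_decay=0.1)
```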
Update 2: I basically picked Karpathy's nanoGPT and patched the pretrained GPT-2 by repeating the positional embeddings N times. I was unable to train the model at 8k because generation would cause a crash, so I started training with a context window of 4k on The Pile: one hour in, and the loss seems to be going down pretty fast. Also, Karpathy's generate function is super inefficient (O(n^4), I think), so it took forever to generate even 2k tokens. So I generated 1,100 tokens just to see if the model is able to go beyond the 1k limit, and it seems to be working. Here are some samples at 3k iterations.
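A toy sketch of that embedding patch, using a standalone nn.Embedding as a stand-in for GPT-2's learned positional table (the actual patch is in the training script linked below):

```python
import torch.nn as nn

# GPT-2 small's learned positional table has shape (1024, 768). The patch
# just tiles that table N times so positions 1024..4095 reuse the same
# learned vectors; fine-tuning then adjusts them.
old_wpe = nn.Embedding(1024, 768)   # stands in for model.transformer.wpe
n_repeat = 4                        # 1024 -> 4096 context

new_wpe = nn.Embedding(1024 * n_repeat, 768)
new_wpe.weight.data.copy_(old_wpe.weight.data.repeat(n_repeat, 1))

print(new_wpe.weight.shape)         # torch.Size([4096, 768])
```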
Update 3: I have started the training, and I am publishing the training script in case anyone is interested in replicating or building upon this work. Here is the complete script:
https://gist.github.com/NaxAlpha/1c36eaddd03ed102d24372493264694c
I will post an update after the weekend once the training has progressed somewhat.
Post-Weekend Update: After ~50k iterations (the model has seen ~200 million tokens; I know this is tiny compared to the tens of billions of tokens the giant corporations train on), the loss only dropped from 4.6 to 4.2 on The Pile.
AFAIR, the loss of GPT-2 on The Pile when trained with a 1024-token context is ~2.8. It seems like the embedding dimension per token is limiting how far the loss can go down, since GPT-2 (small) has an embedding dimension of only 768. Maybe someone can experiment with GPT-2 medium etc. to see how much we can improve. This confirms the comment by u/lucidraisin below.
No-Belt7582 t1_jcjqk6s wrote
PyTorch 2.0 keeps impressing every day.