Submitted by super_deap t3_11tmpc5 in MachineLearning


I did a quick experiment with PyTorch 2.0's native scaled_dot_product_attention. I was able to run a single forward pass within 9GB of memory, which is astounding. I think that by patching existing pretrained GPT models and adding more positional encodings, one could easily fine-tune those models to 32k attention on a single A100 80GB. Here is the code I used:


I think it should be possible to replicate even GPT-4 with open source tools something like Bloom + FlashAttention & fine-tune on 32k tokens.

Update: I was able to start training GPT-2 (125M) with a context size of 8k and a batch size of 1 on a 16GB GPU. Since memory scaled linearly from 4k to 8k, I expect 32k would require ~64GB and should train smoothly on an A100 80GB. Also, I did not do any other optimizations. Maybe 8-bit fine-tuning can further optimize it.
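The extrapolation in this update is plain proportional scaling. A minimal sketch, assuming memory keeps growing linearly with context length beyond 8k (the post only observed this between 4k and 8k); the function name is made up for illustration:

```python
def estimate_memory_gb(target_ctx, known_ctx=8 * 1024, known_gb=16.0):
    """Linearly extrapolate memory use from one measured data point.

    Assumption: memory is proportional to context length, as observed
    in the post between 4k and 8k tokens. 16 GB is the size of the GPU
    the 8k run fit on, so it is an upper bound on what that run used.
    """
    return known_gb * target_ctx / known_ctx


print(estimate_memory_gb(32 * 1024))  # 64.0
```

Since 16 GB is the card size rather than a measured peak, the real 32k figure could come in somewhat lower.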

Update 2: I basically picked Karpathy's nanoGPT and patched the pretrained GPT-2 by repeating the embeddings N times. I was unable to train the model at 8k because generation would cause a crash. So I started the training for a context window of 4k on The Pile: 1 hour in, the loss seems to be going down pretty fast. Also, Karpathy's generate function is super inefficient, O(n^4) I think, so it took forever to generate even 2k tokens. So I generated 1100 tokens just to see if the model is able to go beyond the 1k limit. And it seems to be working. Here are some samples at 3k iterations.
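A minimal sketch of the embedding-repetition patch described in Update 2, assuming it means tiling GPT-2's learned positional embedding table along the position axis. The helper name is hypothetical, and this is not the author's actual code; small sizes keep the example fast.

```python
import torch
import torch.nn as nn


def extend_positional_embedding(wpe: nn.Embedding, n_repeats: int) -> nn.Embedding:
    """Return a new position table made of n_repeats copies of the old one."""
    old_len, dim = wpe.weight.shape
    new_wpe = nn.Embedding(old_len * n_repeats, dim)
    with torch.no_grad():
        # Tile the learned table along the position axis.
        new_wpe.weight.copy_(wpe.weight.repeat(n_repeats, 1))
    return new_wpe


# GPT-2 uses a 1024 x 768 table; a tiny one is used here for illustration.
wpe = nn.Embedding(16, 8)
wpe_long = extend_positional_embedding(wpe, n_repeats=4)  # 16 -> 64 positions
print(wpe_long.weight.shape)  # torch.Size([64, 8])
```

In nanoGPT the position table lives at `transformer.wpe`, and the configured `block_size` would presumably need to be raised to match the new table length.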


Update 3: I have started the training and I am publishing the training script if anyone is interested in replicating or building upon this work. Here is the complete training script:

I will post an update after the weekend once the training has progressed somewhat.

Post-Weekend Update: After ~50k iterations (the model has seen ~200 million tokens, I know this is just too small compared to 10s of billions trained by giga corps), loss only dropped from 4.6 to 4.2 on The Pile:

AFAIR, the loss of GPT-2 on the Pile if trained with 1024 tokens is ~2.8. It seems like the size of the dimension for each token is kind of limiting how much loss can go down since GPT-2 (small) has an embedding dimension of 768. Maybe someone can experiment with GPT-2 medium etc. to see how much we can improve. This is confirmation of the comment by u/lucidraisin below.
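The token count quoted in the post-weekend update can be sanity-checked with a line of arithmetic. This sketch assumes the 4k (4096-token) context from Update 2 and a batch size of 1 per iteration; the batch size for the 4k run is not stated in the post, so treat it as an assumption.

```python
iterations = 50_000
context_len = 4096   # 4k context window from Update 2
batch_size = 1       # assumption: not stated for the 4k run

tokens_seen = iterations * context_len * batch_size
print(f"{tokens_seen:,}")  # 204,800,000, i.e. roughly 200 million
```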




No-Belt7582 t1_jcjqk6s wrote

PyTorch 2.0 keeps impressing me every day


RobbinDeBank t1_jcki2vl wrote

What are its biggest improvements over pytorch 1?


royalemate357 t1_jckqgsr wrote

Pretty sure the main improvement is "torch.compile" which can optimize your code in a nice easy one liner. There's some other nice quality of life improvements like the built in flash attention OP is using, and I think some distributed training stuff. But it's fully backwards compatible, which is great (looking at you tensorflow)


MoistYogurtcloset400 t1_jclv6r1 wrote

Is torch.compile compatible with CUDA devices only?


royalemate357 t1_jclz4t0 wrote

hmm, I am not too sure but their blogpost says this:

>TorchInductor uses a pythonic define-by-run loop level IR to automatically map PyTorch models into generated Triton code on GPUs and C++/OpenMP on CPUs.

so it seems like they support CPU. I also tried it briefly on google colab CPU-only, and it seems to work (i didn't benchmark speed though). I doubt it supports non-CUDA GPUs, but then again support for those even in the general case isn't very good.


mike94025 t1_jcn7ksu wrote

Works for all. You need a compiler backend that can code-gen for your target, and need a frontend for the optimizer that can process the IR.

Alternatively, you need a backend for Triton (or another already supported optimizer) that can codegen for your target architecture.


programmerChilli t1_jcny4qx wrote

We currently officially support Cuda and CPU, although in principle it could be used for other backends too.


Competitive-Rub-1958 t1_jcl97q0 wrote

> Either autograd is disabled (using torch.inference_mode or torch.no_grad) or no tensor argument requires_grad
> training is disabled (using .eval())

What's the point of FlashAttention if you can't use it during training? 🤔


mike94025 t1_jcly3mi wrote

Documentation was not updated. Yes, you can use flash attention for training.

The first version included only forward() as we were resolving some issues with backward(). Docstring will be updated.


Competitive-Rub-1958 t1_jcm5ahk wrote

cool! So I just need to enable `flash_sdp`, then ensure I'm basically computing self-attention and have `batch_first=True`. Would that be correct?


mike94025 t1_jcmho8t wrote

Don't call flash_sdp directly. That way you're locked into particular hardware and create non-portable models. You can either use F.scaled_dot_product_attention() or nn.MultiheadAttention. In either case it will pick the right implementation based on the hardware you have and the constraints. Ideally, the constraints will be weakened in the future, and/or new kernels might support other operating points in an optimized manner, and then the kernel picker can dispatch to that implementation.
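A minimal sketch of the portable call recommended here, using random tensors with arbitrary shapes. It runs on CPU as well, where PyTorch simply falls back to whichever kernel is available:

```python
import torch
import torch.nn.functional as F

# Shapes are (batch, heads, sequence, head_dim); values are arbitrary.
q = torch.randn(2, 4, 16, 8)
k = torch.randn(2, 4, 16, 8)
v = torch.randn(2, 4, 16, 8)

# PyTorch picks the kernel itself; is_causal=True applies a causal mask.
out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
print(out.shape)  # torch.Size([2, 4, 16, 8])
```

Because the kernel choice stays inside PyTorch, the same line works unchanged when the model moves to a GPU that qualifies for a faster kernel.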

See the kernel-picker logic that dispatches based on input characteristics in the source code, and/or the SDPA tutorial here =>


Competitive-Rub-1958 t1_jcn8bti wrote

cool. I just wanted to make it explicit to make sure I'm running `FlashAttention`. Perhaps there's an easy way to check that?


mike94025 t1_jcv83hu wrote

Yes - use the backend context manager to disable all other backends to see that you're running the one you want. (Otherwise, since all other backends are disabled, you'll get an error.)

SDPA context manager is intended to facilitate debug (for perf or correctness), and is not (and should not be) required for normal operational usage.

Check out the SDPA tutorial at


Competitive-Rub-1958 t1_jd40cwb wrote

would that mean for forcing MHA to use it, I should wrap the ctxmanager around the line where I forward through it?

with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_mem_efficient=True):
    x = x + self.attn_head(x, x, x, need_weights=False)[0]

because that doesn't really seem to work :(


mike94025 t1_je5mfa8 wrote

This doesn't force it. It says that flash is enabled, and so are the others. To force it, you have to disable all other kernels. Then it's flash or bust.

You can find more in our blog which got published today and the SDPA tutorial. Both are linked here

PS: the context manager can be used anywhere outside the call as well, including around the call to model.forward.
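A minimal sketch of forcing the flash kernel as described, using the PyTorch 2.0-era `torch.backends.cuda.sdp_kernel` context manager from this thread (newer releases provide `torch.nn.attention.sdpa_kernel` as its replacement). The helper name, tensor shapes, and the choice of fp16 are illustrative assumptions; the GPU part only runs when CUDA is present.

```python
import torch
import torch.nn.functional as F

# Disable every backend except flash, so it is "flash or bust".
FLASH_ONLY = dict(enable_flash=True, enable_math=False, enable_mem_efficient=False)


def flash_only_attention(q, k, v):
    """Run SDPA with only the flash kernel allowed; raises if it cannot be used."""
    with torch.backends.cuda.sdp_kernel(**FLASH_ONLY):
        return F.scaled_dot_product_attention(q, k, v, is_causal=True)


if torch.cuda.is_available():
    # fp16 inputs on the GPU, since the thread reports flash needs fp16/bf16.
    q, k, v = (torch.randn(2, 8, 1024, 64, device="cuda", dtype=torch.float16)
               for _ in range(3))
    try:
        out = flash_only_attention(q, k, v)
        print("flash kernel ran:", out.shape)
    except RuntimeError as err:
        print("flash kernel not available for these inputs:", err)
```

The same `with` block can wrap a call to `model.forward` instead of a single attention call, as the PS notes.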


oathbreakerkeeper t1_jd0lu2p wrote

Am I looking in the wrong place? It seems like the torch 2.0 code still requires training==False in order to use FlashAttention:


Dependent_Ad5120 t1_jd3m0ce wrote

try fp16, that doesn't require training=False apparently.


oathbreakerkeeper t1_jd43931 wrote

I'm using amp mixed precision which should be using fp16. It still requires training==false.

But the torch code also disables flash attention if autocast is enabled. I'm not sure how to deal with that one.


Dependent_Ad5120 t1_jdec7kx wrote

I don't know. I was using pure fp16, no autocast and it works.


oathbreakerkeeper t1_jdgjte0 wrote

How do you use pure fp16 out of curiosity? I've only ever trained with mixed precision, letting pytorch handle the fp16 stuff from there.

Do you have an example of a github repo that does it?


Dependent_Ad5120 t1_je5qfmp wrote

I don't have a github repo for this, but it is pretty simple:


model = nn.Transformer().cuda().half()
input = torch.rand(...).cuda().half()
with sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False):
    output = model(input)


These 4 lines should be enough.


mike94025 t1_je5nrdi wrote

You're looking in the wrong place. What you're looking at is the BT gen 1 fastpath, not the BT gen 2 custom kernels.

You need to look at F.multi_head_attention_forward().

The fastpath still services inference for now; it will hopefully be refactored in a future release. (There's always a tension between refactoring and introducing new features under a time- and staffing-constrained problem formulation.)


Dependent_Ad5120 t1_jd1d00j wrote

It seems to me that I have to call model.eval() to use the memory_efficient attention. Otherwise, it throws an error of no available kernel.

I tried on both an RTX 3090 and an A100; in both cases, having only enable_flash=True resulted in the same error of no available kernel, even with model.eval().

So my questions are:

  1. with model.eval(), does it mean drop_out is not enabled during training?
  2. Am I doing something wrong for flash attention? How do I actually enable it?

Thanks a lot!


Dependent_Ad5120 t1_jd3knio wrote

OK, I found out why. To use flash attention, I had to use fp16. It is a bit faster than using memory_efficient attention in my test.


cthorrez t1_jclwi3d wrote

Fast inference is also important to anyone who wants to deploy these things.


mike94025 t1_jcmepd6 wrote

With Better Transformer, why not both?


cthorrez t1_jcmhg8m wrote

Because the algorithm seems to only work on inference. Probably due to memory management of the cached activations or something. (Idk the actual technical reasons)


mike94025 t1_jcmlddm wrote

Better Transformer supports both, today. Some optimizations are still inference-only (and in particular support for variable-sequence length Nested Tensor) and the inference fastpath is a bit silo'ed, but nothing that a future PyTorch update could not fix.


Spiritual-Reply5896 t1_jcjz3fn wrote

Why is everyone talking about the context length, and not about some kind of memory retrieval? Is the assumption that by increasing context length we can eventually scale it to infinite, thus replacing any kind of external memory?


-Rizhiy- t1_jck6j55 wrote

Increasing the context window is a simple albeit costly method of increasing the amount of addressable information. Working with external memory is not as straightforward.


felheartx t1_jcli6si wrote

You said working with external memory is not as straightforward. Can you explain that?

I've read this: and even though I'm not super familiar with the details, to my untrained eye it seems like attaching external memory is easier than extending the context size.

Just from reading posts on this subreddit I get the feeling that getting larger and larger context sizes is very difficult. Whereas simply attaching this sort of "dictionary" thing is pretty easy to do.


lmericle t1_jcln487 wrote

You will find that in hype circles such as NLP there's a lot of thought-terminating cliches passed around by people who are not so deep in the weeds. Someone says something with confidence, another person doesn't know how to vet it and so just blindly passes it on, and all of a sudden a hack becomes a rumor becomes dogma. It seems to me to be this way with context vs memory.

Put another way: it's the kind of attitude that says "No, Mr. Ford, what we wanted was faster horses".


KerfuffleV2 t1_jclo0oh wrote

I'm not an ML person, but it seems like that paper is just teaching the LLM to simulate a Turing machine. Actually making it respond normally while doing practical stuff like answering user queries would be a different thing.

Also, suppose the LLM has access to external memory. First, you have to teach it how to interact with that external memory (via special command sequences in its tokens, most likely). Then you have to teach it/take steps to make it appropriately note which things are important or not and store/retrieve them as necessary. All of this requires tokens for input/output, so it will increase processing time even when used perfectly, and these tokens will also consume the existing context window.

One really big thing with LLMs now is it seems like they don't (and maybe can't) know what they know/don't know. They just predict tokens, they can't really do introspection. Of course, they can be trained to respond that they don't know certain things, but getting the LLM to decide it needs to use the external memory doesn't seem like the simplest thing.

I mean, take humans as an example: Are you effective at taking notes, organizing them in a way that lets you easily recall them in the future, etc? It's not even an easy skill for humans to develop, and we're relatively good at knowing what we don't know.

Another thing is the paper you linked to says it set the temperature to 0, to make the responses very deterministic. Generally this makes them a lot less creative as well. If you turn up temperature, you potentially increase the chances that the LLM generates malformed queries for the external memory or stuff like that.

Anyway, I don't know much about the technical side of increasing the context window but when the context window is bigger the thing can just use it as far as I know. Taking advantage of some sort of external memory system seems like it's a very, very complicated thing to solve reliably.

Again, note this is coming from someone that doesn't really know much about ML, LLMs, etc. I'm just a normal developer, so take all this with a grain of salt.


Art10001 t1_jcnakzz wrote

There is a github project that uses embeddings with GPT-3.5 to create infinite memory, as long as you have infinite disk space. The database grows and grows the more you talk.


KerfuffleV2 t1_jcncad2 wrote

You'd have to link me what you're talking about for me to say anything. I doubt it works as straightforwardly as "infinite memory" though.


Art10001 t1_jcnv4kf wrote

There are other similar projects I found while trying to recover this one, which may also be of interest. You can find them by searching "chatgpt embeddings memory github"


KerfuffleV2 t1_jcp7qcz wrote

I'm not sure I fully understand it, but it seems like it's just basically adding context to the prompt it submits with requests. For obvious reasons, the prompt can only get so big. It also requires making requests to OpenAI's embedding API which isn't free: so it's both pushing in more tokens and making those extra requests.

I can definitely see how that approach could produce better results, but it's also not really unlimited memory. Note: I skimmed the source, but I'm not really a C++ person and I didn't actually set it up to use my OpenAI account via API.


127-0-0-1_1 t1_jcqd8se wrote

It's not unlimited memory in a single run, which remains unchanged, but that doesn't seem super relevant to what people want (nothing wrong with multiple runs!). Think about a Turing machine, or heck, yourself. A Turing machine only has access to a single cell of memory at a time, and in practice, modern CPUs only have access to their registers directly. For long term storage, that goes into RAM, which is accessed on demand.

Similarly, your own memory is not large enough to contain all the information you'd need to complete most complex tasks. That's why you have to write things down and actively try to remember things.

While that uses OpenAI's embedding networks, like the autoregressive LLM itself, it's not like OpenAI has a monopoly on text embeddings by any means (far from it - embeddings have a very straightforward business use and are used in practically any major site you know for things like similarity queries).

While I think OP is overhyping the degree to which this is "infinite memory" yet, in a hypothetical turing machine formulation where the network can more proactively store and restore memory, it would allow for it to be, at least, turing complete.


Spiritual-Reply5896 t1_jcsq4d9 wrote

Exactly, I wanted to find out whether there is some research regarding these embeddings. I really think that by efficient pruning/organization of these "memories" it's possible to generate quite advanced memory. Things like embedding consistency then become a big player - how much does length affect the embedding, what is the optimal information content vs string size...


hfnuser0000 t1_jcnspad wrote

Hi there! It sounds really interesting! Could you please share the name of the project or provide a link to it? I would love to check it out. Thank you!


CleanThroughMyJorts t1_jck7114 wrote

I don't think the two are mutually exclusive.

The problem with retrieval though (at least current implementations) is the model can't attend to memory globally the way it does with context memory; you're bottlenecked by the retrieval process having to bring things into context through a local search.


super_deap OP t1_jck82rd wrote

Nuance is proportional to context.

Imagine we want to ask the language model to improve a certain module in Linux Kernel.

If I understood them correctly, memory-augmented transformers won't be able to fit together all the pieces to understand what needs to be improved and how because they need to make repeated calls to memory and search/summarize those calls to get a basic understanding and thus miss out on important details.

Compare that to huge context, they have everything they need for the memory in their context and there is no loss of details (in case of full attention).


Spiritual-Reply5896 t1_jckq519 wrote

Let's say the Linux kernel manual is embedded as memories. If we can get an accurate semantic representation of the question, then we should be able to find relevant context from the memory, and use enough context to answer the question in fewer tokens compared to providing the whole Linux manual as context. If we assume that computing the attention is as fast as vector search, then it's a no-brainer that retrieving only relevant context from memory is a better approach than using the whole manual. It's of course a trade-off between accuracy and speed/scalability, but I argue it's a good trade-off as text often isn't that information dense.

The ability to produce semantically coherent embeddings from text is the bread and butter of LLMs, so why would it be any bigger problem to retrieve these memories from an external / infinite database than from the context window?

I'm just hypothesizing with my limited knowledge, please correct me if I make stupid assumptions :)
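The retrieval idea in this comment can be sketched as a nearest-neighbour lookup over embedded memories. This is a toy illustration with hand-written 3-d vectors and made-up memory texts; a real system would get embeddings from a model and use a vector index rather than a linear scan.

```python
import math


def cosine(a, b):
    """Cosine similarity between two equal-length vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    norm = math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b))
    return dot / norm


def retrieve(query_vec, memories, k=1):
    """Return the k memory texts whose embeddings are closest to the query."""
    ranked = sorted(memories, key=lambda m: cosine(query_vec, m[1]), reverse=True)
    return [text for text, _ in ranked[:k]]


# Toy 3-d "embeddings"; the values are invented for illustration only.
memories = [
    ("scheduler: how tasks are picked to run", [0.9, 0.1, 0.0]),
    ("memory management: page allocation", [0.1, 0.9, 0.1]),
    ("networking: socket buffers", [0.0, 0.2, 0.9]),
]
query = [0.8, 0.2, 0.1]  # pretend embedding of "why is my process not scheduled?"

print(retrieve(query, memories))  # ['scheduler: how tasks are picked to run']
```

Only the retrieved text is then placed in the prompt, which is where the token savings come from, and also why anything the search misses never reaches the model.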


Screye t1_jcl549n wrote

Context length is also a hard limit on how many logical-hops the model can make.

If each back-n-forth takes 500-ish tokens, then the model can only reason over 16 hops over 8k tokens. With 32k tokens, it can reason over 64 hops. This might allow for emergent behaviors towards tasks that have previously been deemed impossible due to needing at least a minimum number of hops to reason about.

For what it's worth, I think memory retrieval will work just fine for 90% of scenarios and will stay relevant even for 32k tokens. Esp. if the wiki you are retrieving from is millions of lines.
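The hop counts quoted in this comment are integer division of the context size by the per-exchange cost. A small sketch, using round figures of 8,000 and 32,000 tokens and the roughly 500 tokens per exchange mentioned in the comment:

```python
TOKENS_PER_HOP = 500  # rough cost of one back-and-forth, from the comment


def max_hops(context_tokens, tokens_per_hop=TOKENS_PER_HOP):
    """Number of whole back-and-forth exchanges that fit in the context."""
    return context_tokens // tokens_per_hop


print(max_hops(8_000))   # 16
print(max_hops(32_000))  # 64
```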


VarietyElderberry t1_jcm4ghk wrote

Could you explain what you mean with a logical-hop and how it is dependent on a certain number of tokens? If you are referring to a paper, a link would be appreciated.


Screye t1_jcmpd5i wrote

This is more derived from extensive personal experience with prompt engineering / fine tuning over the last 2 years.

Simply put:

  • The model learns what it sees. Or, throw enough data of a certain type at it and emergent properties relating to that data will show up, given enough data & compute.
  • If it has never seen data past 8k tokens in the past (due to context window limitations), the model won't need to learn to reason over more than 8k tokens.
  • The source data (humans) have limitations on the complexity of thoughts that can be captured within 8k tokens vs 32k tokens
  • That's not to say that the model doesn't reason over longer windows using latent knowledge, which makes its implicit 'reasoning window' much larger than just 8k tokens. But that is fundamentally different from explicitly reasoning over a 32k window.
  • The model today can only assemble a chain-of-thought prompt of 8k tokens. If there is never any human feedback or loss-landscape-optimization for when it fails to reason past 8k tokens, then any ability the model gains there will be purely incidental.
  • On the other hand, when you have chain-of-thought prompt chains that are 32k tokens long, we can naturally expect it to contain more axioms, postulates and relationships between those postulates/axioms.
  • Those completions will get evaluated against human feedback & just self-supervised-scenarios, which should explicitly optimize the loss landscape to reason over far more complex logical statements.

Idk if that makes sense. Our field keeps moving away from math, and as embarrassing as it is to anthropomorphize the model, it does make it easier to get the point across.


HateRedditCantQuitit t1_jcmdot7 wrote

I think of context as an end-to-end connected version of retrieval. You can backprop from loss to retrieved info, but you also want to backprop from loss to the non-retrieved info, which would basically be equivalent to having it all in context (in a handwavy way). Which is to say that just having more context is a simple solution.

I think everyone knows increasing context length is not 100% sufficient, but it sure is a simple convenient solution.


hfnuser0000 t1_jcl5akd wrote

How many ways are there to implement memory retrieval? Can someone explain the intuition behind it? Thank you in advance!


lucidraisin t1_jcl0y16 wrote

it is important for everyone to know that there may be a capacity limit to the context length, as explored by this paper. gpt4 may not have this limit, but smaller variants like llama may. it also depends on the task you are trying to solve. you cannot just get 'infinite context', as some would sell you that their network can do. more research needed... hopefully pytorch 2.0 leads to that


super_deap OP t1_jcl1omd wrote

Thanks for that paper; I came across it a while ago but have not read it yet. Is the limit due to the number of model parameters or the size of the embedding? I suspect the size of the embedding to be the biggest factor in limiting how big the context can be.


lucidraisin t1_jcl2rkh wrote

yea, literature is scant and all over the place in the efficient attention field. in this paper, i believe they claim it is query-key dimension (d_dot), but i think it should depend on the number of heads too. i don't know of any other papers that explore this topic. i just don't want people to be surprised if they fine tune to greater context lengths and things don't work as well as gpt4


super_deap OP t1_jcl3whl wrote

That is understandable. I am working with that assumption as well. (I have failed too many of such experiments to have a blind faith 🙈)


lucidraisin t1_jcl6ecd wrote

no worries, thanks for running the experiments and sharing your results 🙏


antonb90 t1_jczajd1 wrote

Things are improving fast.

>COLT5 is better at any speed. For 16k input length, COLT5 matches or exceeds LONGT5 quality for Large and XL with 35-75% training speedup and 50-100% inference speedup on top of the order-of-magnitude inference speedup from MQA. Encoder speedups are even greater (Appendix D). COLT5-XL also achieves SOTA performance on the SCROLLS benchmark


>COLT5 achieves both stronger performance and faster inference speed at all input lengths and is able to effectively make use of extremely long inputs. We note that COLT5 achieves large quality gains by going from 32k to 64k tokens even while keeping the number of routed tokens constant, providing more evidence for our hypothesis.

Google's new COLT5 64k,


lucidraisin t1_jczarq8 wrote

that isn't for decoders. encoder only, and still needs to be verified. the majority of research papers never work out on closer examination. just trust me, stick with flash attention for now until further notice and save yourself a lot of headache


Unlucky_Excitement_2 t1_jczk2lm wrote

Since you're the OG with this, can I pick your brain? Do you not see value in Hyena Hierarchy? It does inference with a 64k context window but is 100x more efficient than flash attention. I notice on GitHub that you plan on implementing flash attention in all your transformer-based models? HH perplexity actually scales with parameter count. Thoughts?


lucidraisin t1_jcznnvh wrote

actually, i'm keeping an eye on Hyena! there are however a number of issues i still have with the paper (i'm not going to play reviewer 2, as it is not my place nor is reddit a good forum for that), but i intend to reserve judgement and try it out on a few difficult problems like genomics and EEG later this year. proof is in the pudding.


Unlucky_Excitement_2 t1_jczo8wf wrote

Those are actually super compelling problems. I'll keep an eye out. Again thank you, you contribute so much.


lucidraisin t1_jczoelv wrote

yea no problem, happy to chat more if you are doing research in this space. you can always reach out to me through email


kittenkrazy t1_jcjxr0b wrote

I wonder if this can effectively be used in LLaMA. 32K context would be a game changer


Nhabls t1_jck9a4c wrote

Yeah just need enough training time and data to be able to train those 32k context layers effectively........................


mrpogiface t1_jckmi7d wrote

Definitely, but you'd need to further fine-tune the model to "teach" it to make use of the additional context


kreuzguy t1_jcliuwx wrote

Someone should definitely look into this!


harharveryfunny t1_jckltrp wrote

> I think it should be possible to replicate even GPT-4 with open source tools something like Bloom + FlashAttention & fine-tune on 32k tokens.

So you mean build a model with a 32K attention window, but somehow initialize it with weights from BLOOM (2K window) then finetune ? Are you aware of any attempts to do this sort of thing ?


super_deap OP t1_jckpoey wrote

I think one just needs to duplicate positional embeddings and we are good to go. Of course, there needs to be more comprehensive empirical analysis on this and I have not come across any of such attempts. I did a basic experiment and it seems to work but will have to wait and see.


Mindless-Ad8595 t1_jckrmcz wrote

Please keep posting your progress, this is very interesting!


tripple13 t1_jck9593 wrote

Does anyone know why they didn't add the flashattention directly into the MultiheadAttention-modules? Seems to be integrated, awesome!


programmerChilli t1_jcnydmw wrote

I think it is used in PyTorch's nn.TransformerEncoder, but a lot of people like implementing their own.


mike94025 t1_jcv94un wrote

SDPA is used by F.multi_head_attention_forward (if need_weights=False) which is used by nn.MHA and nn.Transformer* as well as other libraries. (source)

Public service announcement: need_weights defaults to True, and guts performance. (Because allocating and writing the attention weight tensor defeats the memory BW advantages of flash attention.)

Also, if `key_padding_mask is not None` performance will suffer (because this is converted into an attention mask, and only the causal attention mask is supported by Flash Attention). Use Nested Tensors for variable sequence length batches.
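A minimal sketch of the `need_weights=False` advice from this comment, using `nn.MultiheadAttention` on a small random input. Sizes are arbitrary, and the example runs on CPU:

```python
import torch
import torch.nn as nn

mha = nn.MultiheadAttention(embed_dim=32, num_heads=4, batch_first=True)
x = torch.randn(2, 10, 32)  # (batch, sequence, embed_dim)

# need_weights defaults to True; turn it off so SDPA can be used.
out, attn_weights = mha(x, x, x, need_weights=False)
print(out.shape)     # torch.Size([2, 10, 32])
print(attn_weights)  # None
```

With `need_weights=False` the second return value is `None`, so code that previously inspected the attention map has to be adjusted.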


mike94025 t1_je5ojaw wrote

It is. Follow the call tree into F.multi_head_attention_forward


tripple13 t1_je5seed wrote

Is that right? I somehow end up here when trying to assess what the F.multi_head_attention call does in the class definition.

But I trust you're right, it would only make sense, I just couldn't identify the calls myself.


Sad-Comedian-711 t1_jcpvahm wrote

So there is flash attention and then there is block sparse flash attention.

Flash attention by itself only got them to 16k on an A100 for their model, to go further they needed to use windowed attention... You could have already gone to 16k with windowed attention before this paper without much issue.

The special thing about this windowed attention is that it is in blocks that can fit into SRAM. From what I can tell, PyTorch's implementation of Flash Attention doesn't look like it supports block sparse flash attention.

While Triton's looks like it does:

I think windowing must be done in blocks that align with the SRAM grid so it kinda has to be part of the Flash Attention implementation. You might be able to throw normal Big Bird block sparse attention on top...

You also may be able to call out to triton's implementation:


crude2refined t1_jcwdk7j wrote

In Google colab, I'm not able to reproduce the benefits in pytorch 2 vs 1 with scaled_dot_product_attention. Is there anything I'm missing? Please see attached image:


mike94025 t1_jcx5xvg wrote

Data type?

SDPA currently has 3 kernels, with a kernel picker choosing between them.

  • sdpa_math
  • sdpa_flash
  • sdpa_mem_eff

A kernel picker picks the best given your constraints

  • Math is the trusted kernel from the equation in the paper.
  • Flash only works for FP16 and BF16, and on SM80 (e.g., A100).
  • mem_efficient kernel works on older architecture levels, and supports FP32, but the upside is limited due to lack of compute capacity for FP32. FP16 or BF16 should help. Also, there are requirements on alignment, dropout values etc to qualify for the high-perf SDPA implementations. Dropout required to be 0 @ PT2.0

Also, different kernels parallelize across different dimensions, so B=1 will not work with all of those kernels.

In a nutshell, performance comes at the price of generality, and GPUs are finicky to get the performance, so our inputs must adhere to those, and parallelization strategies matter for different combinations of dimensions.
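The constraints listed in this comment can be summarized as a toy decision function. This is only an illustration of the stated rules, not PyTorch's actual dispatch logic: it ignores alignment, batch-size, and masking requirements, and it treats "SM80" as "compute capability 80 or higher", which is an assumption.

```python
def pick_sdpa_kernel(dtype, sm, dropout_p=0.0):
    """Toy kernel picker based only on the constraints listed in the comment.

    dtype: "fp16", "bf16" or "fp32"; sm: GPU compute capability, e.g. 80 for A100.
    """
    if dropout_p != 0.0:
        # High-performance kernels required dropout == 0 at PT 2.0.
        return "sdpa_math"
    if dtype in ("fp16", "bf16") and sm >= 80:
        # Assumption: "on SM80" is modelled as compute capability >= 80.
        return "sdpa_flash"
    # mem_efficient runs on older architectures and accepts fp32.
    return "sdpa_mem_eff"


print(pick_sdpa_kernel("fp16", 80))       # sdpa_flash
print(pick_sdpa_kernel("fp32", 80))       # sdpa_mem_eff
print(pick_sdpa_kernel("fp16", 75))       # sdpa_mem_eff
print(pick_sdpa_kernel("fp16", 80, 0.1))  # sdpa_math
```

This also suggests why an fp32 benchmark may show little benefit: under these rules it would never reach the flash kernel.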


tysam_and_co t1_jco9im0 wrote

I have a feeling that's because the heuristic is likely not giving you flash attention at all, but instead this kernel, which appropriately fits your usecase and is in the list of possibly-automatically-selected kernels: