MustachedSpud t1_j8sacz8 wrote

They might be thinking in a different direction than me, but in most cases the majority of memory use during training is not from the model weights or optimizer state. It comes from storing all the activations of the training batch. If you think about a CNN, each filter gets applied across the whole image, so you end up with many more activations than filters. So optimizer memory savings have very limited benefits.
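
A rough count for a single convolutional layer makes the point concrete. This is a back-of-the-envelope sketch with illustrative sizes that are assumptions, not from the comment: 64 input and 64 output channels, a 3x3 kernel, a 224x224 input with stride 1 and "same" padding, a batch of 32, and an Adam-style optimizer that keeps two state buffers per weight. The function name is hypothetical.

```python
def conv2d_memory_counts(batch, c_in, c_out, k, height, width, optimizer_slots=2):
    """Count stored floats for one 2D conv layer (stride 1, 'same' padding assumed)."""
    weights = k * k * c_in * c_out + c_out  # kernel plus one bias per filter
    optimizer_state = optimizer_slots * weights  # e.g. 2 for Adam's two moment buffers
    # Each filter produces one output value per pixel per image, and all of them
    # are kept for the backward pass. (The layer's saved input adds even more;
    # it is ignored here to keep the count conservative.)
    activations = batch * c_out * height * width
    return weights, optimizer_state, activations


weights, opt_state, acts = conv2d_memory_counts(
    batch=32, c_in=64, c_out=64, k=3, height=224, width=224
)
print(f"weights: {weights:,}  optimizer state: {opt_state:,}  activations: {acts:,}")
print(f"activations / (weights + optimizer state): {acts / (weights + opt_state):.0f}x")
```

Under these assumptions the activation count comes out hundreds of times larger than weights and optimizer state combined, because every filter is reused at every pixel of every image in the batch.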

ChuckSeven t1_j8svm1b wrote

Those are way smaller. For every vector of activations, you usually have that size squared in weights, times 2 or 3 depending on how many momentum values you keep.
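
The arithmetic behind this, for a single fully connected layer with no weight reuse, as a sketch. The layer width of 4096, the batch size of 1, and the function name are arbitrary illustrative assumptions; "1 or 2 buffers" stands in for SGD with momentum versus Adam.

```python
def dense_layer_counts(n, batch=1, momentum_buffers=2):
    """Floats for one n-by-n fully connected layer (bias ignored)."""
    weights = n * n
    # Weights plus 1 (SGD with momentum) or 2 (Adam) buffers of the same shape,
    # i.e. the "times 2 or 3" factor.
    weight_side = weights * (1 + momentum_buffers)
    # With no weight reuse, each example stores just one n-dim output vector.
    activations = batch * n
    return weight_side, activations


for buffers in (1, 2):
    weight_side, acts = dense_layer_counts(4096, batch=1, momentum_buffers=buffers)
    print(f"{buffers} buffer(s): weights+state={weight_side:,} activations={acts:,}")
```

Note that the activation term scales linearly with `batch`, while the weight side does not depend on it at all.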

MustachedSpud t1_j8t25bb wrote

Not true in any case with convolution, attention, or recurrence, which covers most modern applications. In all of these cases the activation count grows with how often the weights are reused, as well as with batch size. Those activations outweigh the optimizer's memory usage unless you use a tiny batch size.
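
To see how weight reuse shifts the balance, here's a sketch that finds the batch size at which stored activations catch up with weights plus optimizer state. It assumes an Adam-style optimizer (2 state buffers per weight), hidden size 1024, and a 512-step sequence for a vanilla RNN whose input size equals its hidden size and which stores one hidden vector per step; all of these numbers and the function name are illustrative assumptions.

```python
def break_even_batch(weights, activations_per_example, optimizer_slots=2):
    """Smallest batch at which stored activations >= weights + optimizer state."""
    weight_side = weights * (1 + optimizer_slots)
    return -(-weight_side // activations_per_example)  # integer ceiling division


hidden, seq_len = 1024, 512

# Plain dense layer: each weight is used once per example.
dense = break_even_batch(hidden * hidden, activations_per_example=hidden)

# Simple RNN cell (input and recurrent matrices) reused at every timestep,
# storing one hidden vector per step per example.
rnn = break_even_batch(2 * hidden * hidden, activations_per_example=seq_len * hidden)

print(f"dense layer break-even batch: {dense}")
print(f"RNN over {seq_len} steps break-even batch: {rnn}")
```

Under these assumptions the RNN crosses over at a batch 256 times smaller than the dense layer, because every weight is reused at each of the 512 steps.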

That's why checkpointing can be useful. This paper does a solid job covering memory usage: https://scholar.google.com/scholar?q=low+memory+neural+network+training+checkpoint&hl=en&as_sdt=0&as_vis=1&oi=scholart#d=gs_qabs&t=1676575377350&u=%23p%3DOLSwmmdygaoJ
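
Checkpointing trades compute for memory: the forward pass keeps only some activations, and the backward pass recomputes the missing ones from the nearest saved one. Below is a toy pure-Python sketch of the common segment-based variant on a chain of scalar `tanh` layers. The layer type, the segment size of 4 (about the square root of 16 layers), and all function names are illustrative assumptions, not taken from the linked paper.

```python
import math


def layer(w, x):
    """Toy scalar layer: y = tanh(w * x)."""
    return math.tanh(w * x)


def layer_backward(w, x, grad_out):
    """Chain rule through tanh(w * x) with respect to x, given dL/dy."""
    y = math.tanh(w * x)
    return grad_out * w * (1.0 - y * y)


def grad_full_storage(weights, x0):
    """Plain backprop: keep every layer input. Returns (dL/dx0, inputs stored)."""
    inputs, x = [], x0
    for w in weights:
        inputs.append(x)
        x = layer(w, x)
    grad = 1.0  # loss L is just the final output, so dL/dy = 1
    for w, x_in in zip(reversed(weights), reversed(inputs)):
        grad = layer_backward(w, x_in, grad)
    return grad, len(inputs)


def grad_checkpointed(weights, x0, segment):
    """Keep only every `segment`-th layer input; recompute the rest in backward.

    Returns (dL/dx0, peak number of layer inputs held at once).
    """
    checkpoints, x = {}, x0
    for i, w in enumerate(weights):
        if i % segment == 0:
            checkpoints[i] = x
        x = layer(w, x)
    peak = len(checkpoints)
    grad = 1.0
    for start in sorted(checkpoints, reverse=True):
        end = min(start + segment, len(weights))
        # Recompute the inputs of layers start..end-1 from the checkpoint.
        seg_inputs = [checkpoints[start]]
        for i in range(start, end - 1):
            seg_inputs.append(layer(weights[i], seg_inputs[-1]))
        # Checkpoints plus the recomputed segment (its first entry is a checkpoint).
        peak = max(peak, len(checkpoints) + len(seg_inputs) - 1)
        for i in range(end - 1, start - 1, -1):
            grad = layer_backward(weights[i], seg_inputs[i - start], grad)
    return grad, peak


ws = [0.9] * 16
g_full, mem_full = grad_full_storage(ws, 0.3)
g_ckpt, mem_ckpt = grad_checkpointed(ws, 0.3, segment=4)
print(f"full storage: grad={g_full:.6g}, stored inputs={mem_full}")
print(f"checkpointed: grad={g_ckpt:.6g}, peak stored inputs={mem_ckpt}")
```

Both versions return the same gradient, while the checkpointed one holds 7 layer inputs at peak instead of 16, at the cost of recomputing most of the forward pass once more during backward.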

ChuckSeven t1_j8t5r5m wrote

Yeah, it depends. Even just batch size makes a difference. But for really big models, I'd assume that the number of weights far outweighs the number of activations.

MustachedSpud t1_j8t65fh wrote

Yeah, very configuration-dependent, but larger batch sizes usually learn faster, so there's a tendency to lean into that.
