Submitted by ButterscotchLost421 t3_yvmuuc in MachineLearning

Hey,

I am currently training a diffusion model on CIFAR-10.

The network is very similar to the code in the annotated diffusion model blog post (https://huggingface.co/blog/annotated-diffusion).

Checking Yang Song's code for CIFAR-10 (https://github.com/yang-song/score_sde), I see that the DM is trained for a staggering 1,300,000 epochs.

One epoch takes 7 seconds on my machine (NVIDIA A100-SXM4-40GB).

So overall training would take about 2,500 hours, i.e. more than a hundred days?
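The estimate above is simple arithmetic. This sketch reproduces it using only the two numbers from the post. The helper name `training_time` is illustrative, not from any library.

```python
SECONDS_PER_EPOCH = 7          # measured by the OP on one A100
ASSUMED_EPOCHS = 1_300_000     # the figure read from the score_sde config

def training_time(units: int, seconds_per_unit: float) -> tuple[float, float]:
    """Return (hours, days) for `units` repetitions of a job taking `seconds_per_unit`."""
    total_seconds = units * seconds_per_unit
    hours = total_seconds / 3600
    return hours, hours / 24

hours, days = training_time(ASSUMED_EPOCHS, SECONDS_PER_EPOCH)
print(f"{hours:.0f} h = {days:.0f} days")
```

This gives roughly 2,528 hours, or about 105 days, if 1.3M really were epochs.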

What am I doing wrong? Was the model trained on an even better GPU (what kind of scale)? Or should training an epoch of 50k examples take way below 7 seconds? Or did this really train for a hundred days?


Comments


yanivbl t1_iwgb683 wrote

They had more GPUs training in parallel. Not sure about CIFAR-10, but I read that the number for ADM on ImageNet is ~1000 days on a single V100.


ButterscotchLost421 OP t1_iwggwv8 wrote

Thank you! What do you mean by ADM? Adam?

When training in parallel, which technique did they use? Compute the gradient of a batch of size `N` on each device and then synchronize all the devices to get the mean gradient?


yanivbl t1_iwgjnht wrote

No, not Adam; I was referring to the model from the "Diffusion Models Beat GANs on Image Synthesis" paper.

I never trained such a model, just read about it. But yes, it's most likely what you said (a.k.a. data parallelism).
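Here is a toy sketch of the data-parallel scheme discussed here. It assumes equal-sized shards and a mean-reduced loss, and it uses a simple least-squares model on simulated "devices" rather than a diffusion model on real GPUs. When every device computes the gradient on its own shard and the gradients are then averaged (the all-reduce step that frameworks such as PyTorch's `DistributedDataParallel` perform), the result equals the gradient of the full combined batch.

```python
import numpy as np

def mse_grad(w, X, y):
    """Gradient of mean((X @ w - y) ** 2) with respect to w."""
    residual = X @ w - y
    return 2.0 * X.T @ residual / len(y)

rng = np.random.default_rng(0)
num_devices, per_device_batch, dim = 4, 32, 8
X = rng.normal(size=(num_devices * per_device_batch, dim))
y = rng.normal(size=num_devices * per_device_batch)
w = rng.normal(size=dim)

# Each simulated device gets its own shard of size N and computes a local gradient.
X_shards = np.split(X, num_devices)
y_shards = np.split(y, num_devices)
local_grads = [mse_grad(w, Xs, ys) for Xs, ys in zip(X_shards, y_shards)]

# Synchronization (mean all-reduce): every device ends up with the same averaged gradient.
synced_grad = np.mean(local_grads, axis=0)

# Same result as one device processing the whole batch of num_devices * N examples.
full_grad = mse_grad(w, X, y)
print(np.allclose(synced_grad, full_grad))
```

Because the averaged gradient matches the large-batch gradient, adding devices mainly raises the effective batch size per step. It does not change what a single step computes.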


samb-t t1_iwguscp wrote

Did you get the 1.3M number from the config file (`config.training.n_iters = 1300001`)? If so, that's the number of training steps, not epochs! So hopefully more like around 7 hours to train on an A100, thank god!
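Converting steps to epochs shows where the "around 7 hours" figure comes from. This sketch assumes a batch size of 128 per step (check `config.training.batch_size` in the actual config). It also assumes the OP's 7 s/epoch was measured with that same batch size.

```python
N_ITERS = 1_300_001        # config.training.n_iters: optimizer steps, not epochs
TRAIN_SET_SIZE = 50_000    # CIFAR-10 training images
BATCH_SIZE = 128           # assumption: verify against config.training.batch_size
SECONDS_PER_EPOCH = 7      # the OP's measurement on one A100

steps_per_epoch = TRAIN_SET_SIZE / BATCH_SIZE   # images per epoch / images per step
epochs = N_ITERS / steps_per_epoch               # total passes over the data
hours = epochs * SECONDS_PER_EPOCH / 3600
print(f"{epochs:,.0f} epochs, about {hours:.1f} h")
```

Under these assumptions the run covers roughly 3,300 epochs, which at 7 s each comes to about six and a half hours.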


ButterscotchLost421 OP t1_iwgwqq0 wrote

Ah yes, you're right! Thank you so much!

Does 7 secs per epoch sound approximately right to you?


samb-t t1_iwgz27t wrote

7 secs sounds very fast, but if you're not using a massive model, you're training on CIFAR, and you're on an A100, it's not implausible. You might want to double-check to be sure, though.


dasayan05 t1_iwi4lir wrote

I have trained DDPM (not SDE) on CIFAR-10 using four RTX 3090s with an effective batch size of 1024. It took ~150k iterations (not epochs) and about 1.5 days to reach an FID of 2.8 (not really SOTA, but it works).
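To compare this with the numbers above, here is a rough conversion of the reported run into per-GPU batch, epochs, and seconds per iteration. It assumes the effective batch of 1024 is split evenly across the four GPUs and treats the approximate figures (~150k iterations, ~1.5 days) as exact.

```python
NUM_GPUS = 4
EFFECTIVE_BATCH = 1024          # summed over all GPUs
ITERATIONS = 150_000            # approximate, as reported
WALL_CLOCK_DAYS = 1.5           # approximate, as reported
TRAIN_SET_SIZE = 50_000         # CIFAR-10 training images

per_gpu_batch = EFFECTIVE_BATCH // NUM_GPUS                  # assumes an even split
epochs = ITERATIONS * EFFECTIVE_BATCH / TRAIN_SET_SIZE       # images seen / dataset size
seconds_per_iter = WALL_CLOCK_DAYS * 86_400 / ITERATIONS     # wall clock per step
print(per_gpu_batch, epochs, round(seconds_per_iter, 3))
```

That works out to 256 images per GPU per step, about 3,000 epochs, and under a second per iteration.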
