Submitted by Scared_Employer6992 t3_11dd59q in MachineLearning

I have to train a UNet-like architecture for semantic segmentation with 200 output classes. The final output map is 4x200x500x500 (batch size of 4 and 200 channels, one per semantic class), which blows up my GPU memory (40 GB).

My first thought is to merge classes into broader categories to reduce their number. Does anyone have suggestions or tricks to accomplish this semantic segmentation task in a savvier way?

15

Comments


badabummbadabing t1_ja7yb9y wrote

The problem might be the number of output channels at high resolution. Instead of computing the final layer's activations and gradients for all channels in parallel, you should be able to compute each channel's loss sequentially and add up their gradients at the end. This works because the loss decomposes as a sum over the channels (and hence so do the channels' gradients).

In PyTorch, this should be as simple as running the forward and backward passes for the channels of the final layer sequentially (before calling optimizer.step() and optimizer.zero_grad() once). You will probably also need to pass retain_graph=True on every backward() call; otherwise the activations in the preceding layers will be freed before you get to the next channel.
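
For illustration, a minimal sketch of the channel-sequential idea. All names (backbone, head) and shapes here are hypothetical placeholders, and a one-vs-rest BCE is used as a stand-in per-channel loss, since softmax cross-entropy couples the channels through its normalizer:

import torch
import torch.nn as nn

num_classes = 200
backbone = nn.Conv2d(3, 64, 3, padding=1)   # stand-in for the U-Net trunk
head = nn.Conv2d(64, num_classes, 1)        # final 1x1 classification layer
optimizer = torch.optim.Adam(list(backbone.parameters()) + list(head.parameters()))

x = torch.randn(1, 3, 500, 500)
target = torch.randint(0, num_classes, (1, 500, 500))

features = backbone(x)                      # trunk activations, computed once
loss_fn = nn.BCEWithLogitsLoss()

optimizer.zero_grad()
for c in range(num_classes):
    # Slice the head's weights so only one 1 x 500 x 500 logit map is materialized.
    logits_c = nn.functional.conv2d(features, head.weight[c:c + 1], head.bias[c:c + 1])
    target_c = (target == c).float().unsqueeze(1)
    loss_c = loss_fn(logits_c, target_c) / num_classes
    # retain_graph keeps the trunk activations alive for the next channel.
    loss_c.backward(retain_graph=(c < num_classes - 1))
optimizer.step()

Each step only holds a single channel's full-resolution logits in memory, at the cost of running 200 backward passes through the trunk.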

16

kweu t1_ja8ur2q wrote

I work with data of a similar size and I use random crops during training and a sliding window for prediction. For example, you could train on 128x128 crops of the input images, then stitch the predictions together to segment the image at full resolution, and probably keep your 200 classes. But tbh 200 sounds a bit excessive anyway
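
A minimal sketch of how the crop-and-stitch workflow could look; model, the crop size, and the stride are hypothetical placeholders:

import torch

def random_crop(image, mask, size=128):
    # image: (C, H, W), mask: (H, W); pick a random size x size window for training.
    _, h, w = image.shape
    top = torch.randint(0, h - size + 1, (1,)).item()
    left = torch.randint(0, w - size + 1, (1,)).item()
    return (image[:, top:top + size, left:left + size],
            mask[top:top + size, left:left + size])

@torch.no_grad()
def sliding_window_predict(model, image, num_classes=200, size=128, stride=64):
    # Average overlapping logits from size x size windows, then take the argmax.
    _, h, w = image.shape
    logits = torch.zeros(num_classes, h, w)
    counts = torch.zeros(1, h, w)
    tops = list(range(0, h - size, stride)) + [h - size]    # extra window covers the bottom edge
    lefts = list(range(0, w - size, stride)) + [w - size]   # extra window covers the right edge
    for top in tops:
        for left in lefts:
            patch = image[:, top:top + size, left:left + size].unsqueeze(0)
            logits[:, top:top + size, left:left + size] += model(patch)[0]
            counts[:, top:top + size, left:left + size] += 1
    return (logits / counts).argmax(dim=0)   # (H, W) class map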

5

QuadmasterXLII t1_ja7wog6 wrote

... does it fit with batch size 1?

2

QuadmasterXLII t1_ja7yo0f wrote

Your problem is the U-Net backbone, not the loss function. Assuming you're married to a batch size of 4, the final convolution to get to 4 x 200 x 500 x 500, the cross-entropy, and the backpropagation should only take maybe 10 GB, so cram your architecture into the remaining 30 GB.

import torch
x = torch.randn([4, 128, 500, 500]).cuda()        # stand-in for the backbone's last feature map
z = torch.nn.Conv2d(128, 200, 3)                   # final conv to 200 classes (no padding -> 498x498)
z.cuda()
q = torch.randint(0, 200, (4, 498, 498)).cuda()    # random target class map
torch.nn.CrossEntropyLoss()(z(x), q).backward()    # forward, loss, and backward for this layer alone

for example, takes 7.5 GB.

2

Scared_Employer6992 OP t1_ja7xpjt wrote

I haven't tried bs=1, but I also don't want to use it, as I usually get bad results with it and my net has a lot of BN layers.

0

badabummbadabing t1_ja7yxbg wrote

Don't use batch normalization. Lots of U-Nets use e.g. instance normalization instead. A batch size of 1 should be completely fine (but you will need to adjust the learning rate when changing it). Check the 'no new U-Net' (aka nnU-Net) paper by Fabian Isensee for the definitive resource on what matters in U-Nets.
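
A minimal sketch of the swap, assuming the network's normalization layers are plain nn.BatchNorm2d modules:

import torch.nn as nn

def batchnorm_to_instancenorm(module):
    # Recursively replace every BatchNorm2d with an affine InstanceNorm2d.
    for name, child in module.named_children():
        if isinstance(child, nn.BatchNorm2d):
            setattr(module, name, nn.InstanceNorm2d(child.num_features, affine=True))
        else:
            batchnorm_to_instancenorm(child)
    return module

# usage: model = batchnorm_to_instancenorm(model)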

10

prettyyyyprettyygood t1_jabcw8d wrote

Perhaps you can train at a lower resolution, then later train a separate model to "super-resolve" the segmentation predictions.
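
A minimal sketch of that idea, with plain interpolation standing in for the learned super-resolution model; all names and shapes are hypothetical:

import torch
import torch.nn.functional as F

image = torch.randn(1, 3, 500, 500)
target = torch.randint(0, 200, (1, 500, 500))

# Downsample for training (nearest-neighbour keeps the targets as valid class indices).
image_lr = F.interpolate(image, size=(250, 250), mode='bilinear', align_corners=False)
target_lr = F.interpolate(target.unsqueeze(1).float(), size=(250, 250), mode='nearest').squeeze(1).long()

# At inference, upsample the low-resolution logits back to full resolution.
logits_lr = torch.randn(1, 200, 250, 250)   # stand-in for model(image_lr)
logits_full = F.interpolate(logits_lr, size=(500, 500), mode='bilinear', align_corners=False)
prediction = logits_full.argmax(dim=1)      # (1, 500, 500) class map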

1