
badabummbadabing t1_ja7yb9y wrote

The problem might be the number of output channels at high resolution. Instead of computing the final layer's activations and gradients for all channels in parallel, you should be able to compute each channel's loss sequentially and add up the gradients at the end. This works because the loss decomposes as a sum over the channels, so its gradient is the sum of the per-channel gradients.

In PyTorch, this should be as simple as running the forward and backward passes of the final layer one channel at a time, then calling optimizer.step() and optimizer.zero_grad() once at the end; the gradients accumulate in .grad across the backward calls. You will probably also need to pass retain_graph=True on every backward call except the last, otherwise the saved activations of the preceding layers are freed before you get to the next channel.
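Here is a rough sketch of what that could look like, not a drop-in solution. The toy model (`backbone`, `head`), the tensor shapes, and the summed squared-error loss are all made up for illustration; the only assumptions that matter are that the final layer is a convolution whose weights can be sliced per output channel and that the loss is a plain sum over channels (with a mean, you would need to rescale each term).

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

# Hypothetical toy model: shared backbone plus a final layer with several output channels.
backbone = nn.Sequential(nn.Conv2d(3, 8, 3, padding=1), nn.ReLU())
head = nn.Conv2d(8, 6, 1)  # final layer, 6 output channels
params = list(backbone.parameters()) + list(head.parameters())
optimizer = torch.optim.SGD(params, lr=0.1)

x = torch.randn(2, 3, 16, 16)
target = torch.randn(2, 6, 16, 16)


def reference_grads():
    """All channels at once: the memory-hungry baseline."""
    optimizer.zero_grad()
    out = head(backbone(x))
    loss = ((out - target) ** 2).sum()  # sum, so it decomposes over channels
    loss.backward()
    return [p.grad.clone() for p in params]


def channelwise_grads():
    """One output channel at a time; gradients accumulate in .grad."""
    optimizer.zero_grad()
    features = backbone(x)  # preceding layers run forward only once
    n = head.out_channels
    for c in range(n):
        # Forward pass of the final layer restricted to channel c.
        out_c = F.conv2d(features, head.weight[c:c + 1], head.bias[c:c + 1])
        loss_c = ((out_c - target[:, c:c + 1]) ** 2).sum()
        # Keep the backbone graph alive until the last channel is done.
        loss_c.backward(retain_graph=(c < n - 1))
    return [p.grad.clone() for p in params]


# One training step: accumulate per-channel gradients, then update once.
channelwise_grads()
optimizer.step()
optimizer.zero_grad()
```

Only one channel of the final layer's output is alive at any time, which is where the memory saving would come from. The price is that the backward pass through the preceding layers runs once per channel, so expect this to be slower.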
