Submitted by thomasahle t3_118gie9 in MachineLearning

Cross entropy on logits is a common simplification that fuses softmax + cross-entropy loss into something like:

def label_cross_entropy_on_logits(x, labels):
    # x: (batch_size, num_classes) logits, labels: (batch_size,) class indices
    return (-x.select(labels) + x.logsumexp(axis=1)).sum(axis=0)

where x.select(labels) = x[range(batch_size), labels].
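
For concreteness, here's a runnable PyTorch version of the fused form (variable names are mine; the allclose check against F.cross_entropy is just a sanity test of the equivalence):

    import torch
    import torch.nn.functional as F

    def label_cross_entropy_on_logits(x, labels):
        # Fused softmax + cross-entropy, written out explicitly.
        # x: (batch_size, num_classes) logits, labels: (batch_size,) class indices.
        return (-x[torch.arange(len(x)), labels] + x.logsumexp(dim=1)).sum()

    x = torch.randn(4, 10)
    labels = torch.randint(0, 10, (4,))
    assert torch.allclose(label_cross_entropy_on_logits(x, labels),
                          F.cross_entropy(x, labels, reduction="sum"))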

I was thinking about how the logsumexp term looks like a regularization term, and wondered what would happen if I just replaced it with x.norm(axis=1) instead. It seemed to work just as well as the original, so I thought: why not just enforce unit norm?
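
A sketch of that intermediate variant (the function name is mine; the post doesn't show code for this step):

    import torch

    def label_ce_logsumexp_replaced_by_norm(x, labels):
        # Same as the fused loss above, but with logsumexp swapped for the
        # L2 norm of each row of logits.
        return (-x[torch.arange(len(x)), labels] + x.norm(dim=1)).sum()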

I changed my code to

def label_cross_entropy_on_logits(x, labels):
    # Score the correct logit relative to the norm of the whole logit vector.
    return -(x.select(labels) / x.norm(axis=1)).sum(axis=0)

and my training sped up dramatically, and my test loss decreased.
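
For context, a minimal sketch of how this loss could slot into a training step; the linear model and optimizer settings are illustrative placeholders, not details from my actual setup:

    import torch

    model = torch.nn.Linear(28 * 28, 10)   # simple linear classifier, MNIST-sized
    opt = torch.optim.SGD(model.parameters(), lr=0.1)

    def new_loss(x, labels):
        # The replacement loss: correct logit divided by the row norm, negated.
        return -(x[torch.arange(len(x)), labels] / x.norm(dim=1)).sum()

    def train_step(images, labels):
        opt.zero_grad()
        loss = new_loss(model(images.flatten(1)), labels)
        loss.backward()
        opt.step()
        return loss.item()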

I'm sure this is a standard approach to categorical loss, but I haven't seen it before, and would love to get some references.

I found this old post: https://www.reddit.com/r/MachineLearning/comments/k6ff4w/unit_normalization_crossentropy_loss_outperforms/ which references LogitNormalization (https://arxiv.org/pdf/2205.09310.pdf). However, it seems those papers all apply layer normalization and then softmax+CE. What seems to work for me is simply replacing softmax+CE with normalization.
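
To make the distinction concrete, a rough sketch of the two recipes as I read them (neither function is taken from those papers; the names and the temperature value are mine):

    import torch
    import torch.nn.functional as F

    def normalize_then_ce(x, labels, tau=1.0):
        # Normalize the logits, then apply the usual softmax + cross-entropy
        # (my paraphrase of the normalize-then-CE recipe; tau is a placeholder temperature).
        x = x / (x.norm(dim=1, keepdim=True) * tau)
        return F.cross_entropy(x, labels, reduction="sum")

    def replace_ce_with_normalization(x, labels):
        # What I do instead: drop softmax + cross-entropy entirely and
        # score the correct logit against the row norm.
        return -(x[torch.arange(len(x)), labels] / x.norm(dim=1)).sum()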

8

Comments


ChuckSeven t1_j9iyuc2 wrote

Hmm, not sure, but I think that if you don't exponentiate, you cannot fit n targets into a d-dimensional space when n > d, if you want there to exist a vector v for each target such that the outcome is a one-hot distribution (i.e., 0 loss).

Basically, if you have 10 targets but only a 2-dimensional space, you need enough non-linearity in the projection to your target space that there exists a 2D vector which gives 0 loss for each target.

edit: MNIST only has 10 classes, so you are probably fine. Furthermore, softmax of the dot product "cares exponentially more" about the angle of the prediction vector than about its scale. If you use the norm, I'd think you only care about the angle, which likely leads to different representations. Whether that improves performance depends heavily on how your model relies on scale to learn certain predictions. Maybe in the case of MNIST, relying on scale worsens performance (a wild guess: because it maybe makes "predictions more certain" simply when more pixels are set to 1).

3

thomasahle OP t1_j9kapw7 wrote

Even with angles you can still have exponentially many vectors that are nearly orthogonal to each other, if that's what you mean...

I agree the representations will be different. Indeed one issue may be that large negative entries will be penalized as much as large positive ones, which is not the case for logsumexp...

But on the other hand more "geometric" representations like this, based on angles, may make the vectors more suitable for stuff like LSH.

1

activatedgeek t1_j9lux6q wrote

You are implying that the NN learns exp(logits) instead of the logits, without really constraining the outputs to be positive. It probably won't be a proper scoring rule, though it might appear to work.

In some ways, this is similar to how you can learn classifiers with the mean squared error by regressing directly to the one-hot vector of the class label (here, too, you don't constrain the outputs to be positive). It works, and it corresponds to a proper scoring rule called the Brier score.
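
A minimal sketch of that squared-error objective, regressing raw outputs to the one-hot label (names are made up):

    import torch
    import torch.nn.functional as F

    def mse_to_one_hot(x, labels):
        # Squared error between raw outputs and the one-hot target vector;
        # applied to probabilities this is the Brier score.
        one_hot = F.one_hot(labels, num_classes=x.shape[1]).float()
        return ((x - one_hot) ** 2).sum()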

2

thomasahle OP t1_j9nprmt wrote

Great example! With Brier scoring we have

loss = norm(x)**2 - x[label]**2 + (1-x[label])**2
     = norm(x)**2 - 2*x[label] + 1

which is basically equivalent to replacing logsumexp with norm^2 in the first code snippet:

def label_cross_entropy_on_logits(x, labels):
    # Brier-style variant: logsumexp replaced by the squared norm
    # (the constant +1 per sample is dropped since it doesn't affect gradients).
    return (-2*x.select(labels) + x.norm(axis=1)**2).sum(axis=0)

This actually works just as well as my original method! The Wikipedia article on proper scoring rules also mentions the "Spherical score", which seems to be equivalent to my method of dividing by the norm. So maybe that's the explanation?

Note, though, that I applied the Brier loss directly to the logits, which is probably not how it is meant to be used...
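
For reference, the spherical score of a probability vector p at the true label is p_label / ||p||_2; here's a small sketch of it as a loss on probabilities (as the scoring-rule literature defines it), which is the same formula I apply to raw logits:

    import torch

    def spherical_score_loss(probs, labels):
        # Negative spherical score: -p_label / ||p||_2, a proper scoring rule
        # on probability vectors.
        return -(probs[torch.arange(len(probs)), labels] / probs.norm(dim=1)).sum()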

2

cthorrez t1_j9iq35y wrote

> test loss decreased

What function are you using to evaluate test loss? cross entropy or this norm function?

1

thomasahle OP t1_j9iq4rz wrote

Should have said Accuracy.

Only MNIST though. On average it went from 3.8% error with a simple linear model to 1.2%, with an 80%-20% train-test split. So in no way amazing, just interesting.

Just wondered if other people had experimented more with it, since it also trains a bit faster.

2

cthorrez t1_j9ir0lx wrote

Have you tried it with, say, an MLP or a small convnet on CIFAR-10? I think that would be the next logical step.

1