
trajo123 t1_iymuivu wrote

To answer your question concretely: in classification you want your model output to represent a probability distribution over the classes. If you have only 2 classes, this can be achieved with 1 output unit producing values between 0 and 1. If you have more than 2 classes, you need 1 unit per class, each producing a value in the (0, 1) range, with all outputs summing to 1 so that together they form a probability distribution. With 1 output unit, the sigmoid function ensures the output lies in (0, 1); with multiple output units, softmax ensures both conditions hold.

In practice, classification models don't apply an explicit activation function after the last layer; instead, the loss incorporates the appropriate activation for efficiency and numerical stability. So for binary classification you have two equivalent options:

  • use 1 output unit with torch.nn.BCEWithLogitsLoss

>This loss combines a Sigmoid layer and the BCELoss in one single class. This version is more numerically stable than using a plain Sigmoid followed by a BCELoss as, by combining the operations into one layer, we take advantage of the log-sum-exp trick for numerical stability.

  • use 2 output units with torch.nn.CrossEntropyLoss

>This criterion computes the cross entropy loss between input logits and target.

Both of these approaches are mathematically equivalent and should produce the same results up to floating-point precision. If you get wildly different predictions from the two, something is wrong in your setup.
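For illustration, here is a minimal sketch of that equivalence (toy random logits; in a real model `z` would come from your last `nn.Linear` layer):

```python
import torch

torch.manual_seed(0)
z = torch.randn(8, 1)                    # one raw logit per sample
y = torch.randint(0, 2, (8,)).float()    # binary targets

# Option 1: 1 output unit + BCEWithLogitsLoss (applies sigmoid internally)
loss_bce = torch.nn.BCEWithLogitsLoss()(z.squeeze(1), y)

# Option 2: 2 output units + CrossEntropyLoss (applies log-softmax internally).
# Pairing each logit z with a fixed 0 makes the two match, because
# softmax([0, z])[1] == sigmoid(z).
z2 = torch.cat([torch.zeros_like(z), z], dim=1)
loss_ce = torch.nn.CrossEntropyLoss()(z2, y.long())

print(loss_bce.item(), loss_ce.item())   # identical up to floating-point error
```

In a real 2-output model the network learns both logits freely, but only their difference matters, which is exactly the single logit of the 1-output formulation.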

On another note, using accuracy when looking at credit card fraud detection is not a good idea because the dataset is most likely highly unbalanced. Probably more than 99% of the data samples are labelled as "not fraud". In this case, having a stupid model always produce "not fraud" regardless of input will already give you 99% accuracy. You may want to look into metrics for unbalanced datasets, e.g. F1 score, false positive rate, false negative rate, etc.
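To make that concrete, a quick sketch with made-up toy data (assuming scikit-learn is available) of why accuracy is misleading on a ~99/1 split:

```python
import numpy as np
from sklearn.metrics import confusion_matrix, f1_score

rng = np.random.default_rng(0)
y_true = (rng.random(10_000) < 0.01).astype(int)   # ~1% of samples are "fraud" (1)
y_pred = np.zeros_like(y_true)                     # the always-"not fraud" model

print("accuracy:", (y_true == y_pred).mean())      # ~0.99, yet no fraud is caught

print("F1:", f1_score(y_true, y_pred, zero_division=0))  # 0.0 -- exposes the problem

tn, fp, fn, tp = confusion_matrix(y_true, y_pred).ravel()
print("false positive rate:", fp / (fp + tn))      # 0.0
print("false negative rate:", fn / (fn + tp))      # 1.0 -- every fraud case is missed
```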

Have fun on your (deep) learning journey!
