
trajo123 t1_j3busy6 wrote

First of all, the dataset is way too small to train a model from scratch and get meaningful results on this relatively complex task (more complex than MNIST, for example, which has a training set of 60,000 images). Second, your model is way too small/simple for this task even if you had 100 times more data. I strongly suggest "Transfer Learning": fine-tune a pre-trained model by replacing the classification head, freezing the rest of the model, and training only the new head on your dataset.

Something along these lines:

from torch import nn
from torchvision import transforms, models

# ...

model = models.swin_b(weights=models.Swin_B_Weights.IMAGENET1K_V1)
# note: torchvision's Swin models expose the classification head as `model.head`
model.head = nn.Linear(model.head.in_features, 1, bias=True)
# ...
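To train only the new head, you can freeze the backbone before fine-tuning. A minimal sketch (the learning rate is just a placeholder; `model.head` matches torchvision's Swin attribute name):

import torch

# freeze every parameter, then unfreeze the freshly replaced head
for param in model.parameters():
    param.requires_grad = False
for param in model.head.parameters():
    param.requires_grad = True

# pass only the trainable parameters to the optimizer
optimizer = torch.optim.Adam(
    (p for p in model.parameters() if p.requires_grad),
    lr=1e-3,  # assumed starting point, tune for your dataset
)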

In the pre-trained model's documentation you will see which training recipe was used and which transforms were applied to the images. Typically:

transforms.Normalize(
    mean=(0.485, 0.456, 0.406),
    std=(0.229, 0.224, 0.225),
)

transforms.Resize((224, 224), interpolation=transforms.InterpolationMode.BICUBIC)
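Putting these together, a preprocessing pipeline along these lines should match what the backbone expects (a sketch; with recent torchvision versions you can also call `weights.transforms()` to get the exact preset used during pre-training):

weights = models.Swin_B_Weights.IMAGENET1K_V1

preprocess = transforms.Compose([
    transforms.Resize((224, 224), interpolation=transforms.InterpolationMode.BICUBIC),
    transforms.ToTensor(),  # PIL image -> float tensor in [0, 1]
    transforms.Normalize(
        mean=(0.485, 0.456, 0.406),
        std=(0.229, 0.224, 0.225),
    ),
])

# or, equivalently, use the preset bundled with the weights:
preprocess = weights.transforms()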

See more at <https://pytorch.org/vision/stable/models.html#table-of-all-available-classification-weights>. You can also find pre-trained vision models on Hugging Face.

Hope this helps, good luck!

3

trajo123 t1_j3c38rx wrote

Several things I noticed in your code:

  • your model doesn't use any transfer (activation) function
  • the combination of final activation function and loss function is incorrect
  • for CNNs you should be using BatchNorm2d layers

The code should look something like this:

import torch
from torch import nn


class CNNClassifier(nn.Module):
    def __init__(self, input_size, num_classes):
        super(CNNClassifier, self).__init__()
        self.input_size = input_size
        self.num_classes = num_classes
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=32, kernel_size=3, stride=1, padding=1)  # increase the number of channels
        self.bn1 = nn.BatchNorm2d(32)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv2d(in_channels=32, out_channels=128, kernel_size=3, stride=1, padding=1)  # in_channels must match conv1's out_channels
        self.bn2 = nn.BatchNorm2d(128)
        self.fc1 = nn.Linear(128, 256)  # note the smaller numbers
        self.fc2 = nn.Linear(256, num_classes)
        self.final_pool = nn.AdaptiveAvgPool2d(1)  # before flattening, use AdaptiveMaxPool2d or AdaptiveAvgPool2d to get rid of the spatial dimensions, essentially treating each filter as one feature
        # self.softmax = nn.Softmax(dim=1) - not needed, see below. Also, Softmax is not correct for use with NLLLoss; the correct one would be LogSoftmax(dim=1)
        self.f = nn.ReLU()
        
    def forward(self, x):     
        x = self.conv1(x)
        x = self.pool(x)
        x = self.f(x)  # apply the transfer function
        x = self.bn1(x) # apply batch norm (this can also be placed before the transfer function)

        x = self.conv2(x)   
        x = self.pool(x)
        x = self.f(x)  # apply the transfer function        
        x = self.bn2(x) # apply batch norm (this can also be placed before the transfer function)

        # since you are now using batchnorm, you could add a few more blocks like the one above, vanishing gradients are less of a concern now

        x = self.final_pool(x)
        x = torch.flatten(x, 1)
        x = self.fc1(x)
        x = self.f(x)  # apply the transfer function, here you could try tanh as well
        x = self.fc2(x)
        # x = self.softmax(x)  # no need for a function here because it is incorporated into the loss function for numerical/computational efficiency reasons
        return x
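As a quick sanity check, you can push a dummy batch through the model and confirm the output shape (a hypothetical usage sketch; the 224x224 input size is just an assumption, the adaptive pooling makes the network work for other sizes too):

model = CNNClassifier(input_size=224, num_classes=2)
dummy = torch.randn(4, 3, 224, 224)  # batch of 4 RGB images
logits = model(dummy)
print(logits.shape)  # torch.Size([4, 2]) - raw logits, no softmax applied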

Also, the loss should be

# criterion = nn.NLLLoss()
criterion = nn.CrossEntropyLoss()  # the more natural choice of loss function for classification; actually, for binary classification the more natural choice would be BCEWithLogitsLoss, but then you need to set the number of output units to 1
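For the binary case, a minimal sketch could look like this (note that BCEWithLogitsLoss expects float targets with the same shape as the model's output; `images` and `labels` are hypothetical placeholders):

criterion = nn.BCEWithLogitsLoss()

# assumes the model's last layer is nn.Linear(256, 1), i.e. a single output unit
logits = model(images)  # shape: (batch_size, 1), raw logits
loss = criterion(logits, labels.float().unsqueeze(1))  # labels: 0/1 ints of shape (batch_size,)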
1

trajo123 t1_j3c3cvf wrote

...let me know if it works any better!

1

AKavun OP t1_j3l51kx wrote

Thank you, sir. I posted a general update to this thread and will keep you updated on everything.

1