Training a CNN model on MNIST using PyTorch

Sample MNIST images:

[Figure: a grid of example MNIST handwritten digits]

  • 10 classes (the digits 0-9)

  • 60,000 training images

  • 10,000 test images

  • Each image is monochrome (one channel), 28×28 pixels

#@title Import the required modules
import torch
import torch.nn as nn
import torchvision.datasets as datasets
import torchvision.transforms as transforms
# Define the device: use the GPU if one is available, otherwise fall back to the CPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
#@title Download the dataset
train_data = datasets.MNIST(root = './data', train = True,
                        transform = transforms.ToTensor(), download = True)

test_data = datasets.MNIST(root = './data', train = False,
                       transform = transforms.ToTensor())
# About the ToTensor() transformation.

# PyTorch networks expect a batch of inputs with dimensions N*C*H*W, where
# N: batch size
# C: number of channels
# H: height
# W: width

# A PIL image (or NumPy array) is stored as H*W*C with pixel values in [0, 255].
# ToTensor() converts it to a float tensor of shape C*H*W with values scaled to [0, 1].
# The batch dimension N is added later by the DataLoader.
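What ToTensor() produces can be checked on a single sample (a sketch; indexing the dataset applies the transform):

img, label = train_data[0]
print(img.shape, img.dtype)                           # torch.Size([1, 28, 28]) torch.float32
print(img.min().item(), img.max().item())             # pixel values scaled into [0.0, 1.0]
print(train_data.data.shape, train_data.data.dtype)   # raw storage: torch.Size([60000, 28, 28]) torch.uint8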
#@title Define the data loaders
batch_size = 100
train_loader = torch.utils.data.DataLoader(dataset=train_data,
                                           batch_size=batch_size,
                                           shuffle=True)

test_loader = torch.utils.data.DataLoader(dataset=test_data,
                                          batch_size=batch_size,
                                          shuffle=False)
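Each batch drawn from a loader then has the expected N*C*H*W shape (a minimal check):

images, labels = next(iter(train_loader))
print(images.shape)   # torch.Size([100, 1, 28, 28])
print(labels.shape)   # torch.Size([100])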
#@title Define a CNN network

class CNN(nn.Module):
    # This defines the structure of the network.
    def __init__(self):
        super(CNN, self).__init__()
        # A single convolution whose 28x28 kernel covers the entire image,
        # so each of the 10 output channels yields one score per image.
        self.conv1 = nn.Conv2d(1, 10, kernel_size=28, bias=False)

    def forward(self, x):
        x = self.conv1(x)       # shape: N x 10 x 1 x 1
        x = x.view(-1, 10)      # flatten to N x 10 class scores (logits)
        return x


# Create an instance
net = CNN().to(device)
print(net)
CNN(
  (conv1): Conv2d(1, 10, kernel_size=(28, 28), stride=(1, 1), bias=False)
)
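Because the 28×28 kernel spans the whole image, this model is simply a bias-free linear map from the 784 pixels to 10 class scores, i.e. multinomial logistic regression. A hypothetical equivalent formulation with a fully connected layer (not used in the rest of this notebook):

class LinearNet(nn.Module):
    def __init__(self):
        super(LinearNet, self).__init__()
        self.fc = nn.Linear(28 * 28, 10, bias=False)   # same 10 x 784 = 7840 weights

    def forward(self, x):
        x = x.view(-1, 28 * 28)   # flatten each image into a 784-vector
        return self.fc(x)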
#@title Define the loss function and the optimizer
loss_fun = nn.CrossEntropyLoss()
#optimizer = torch.optim.SGD(net.parameters(), lr=1.e-3)
optimizer = torch.optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
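nn.CrossEntropyLoss expects raw class scores (logits) of shape N×10 and integer labels of shape N; the softmax is applied internally, so the network does not need a softmax layer at the end. A small sketch with dummy tensors:

dummy_logits = torch.randn(4, 10)            # scores for a batch of 4 images
dummy_labels = torch.tensor([3, 0, 7, 1])    # the correct digit for each image
print(loss_fun(dummy_logits, dummy_labels))  # a scalar loss tensor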
#@title Train the model

num_epochs = 5
for epoch in range(num_epochs):
  for i, (images, labels) in enumerate(train_loader):
    images = images.to(device)
    labels = labels.to(device)

    optimizer.zero_grad()            # clear gradients from the previous step
    output = net(images)             # forward pass: N x 10 logits
    loss = loss_fun(output, labels)
    loss.backward()                  # backpropagate
    optimizer.step()                 # update the weights

    # Report the loss every 100 steps (batch_size happens to equal the reporting interval).
    if (i+1) % batch_size == 0:
      print('Epoch [%d/%d], Step [%d/%d], Loss: %.4f'
                 %(epoch+1, num_epochs, i+1, len(train_data)//batch_size, loss.item()))
Epoch [1/5], Step [100/600], Loss: 1.6162
Epoch [1/5], Step [200/600], Loss: 1.2004
Epoch [1/5], Step [300/600], Loss: 0.9923
Epoch [1/5], Step [400/600], Loss: 0.8917
Epoch [1/5], Step [500/600], Loss: 0.8996
Epoch [1/5], Step [600/600], Loss: 0.7955
Epoch [2/5], Step [100/600], Loss: 0.8166
Epoch [2/5], Step [200/600], Loss: 0.5380
Epoch [2/5], Step [300/600], Loss: 0.6292
Epoch [2/5], Step [400/600], Loss: 0.6441
Epoch [2/5], Step [500/600], Loss: 0.6008
Epoch [2/5], Step [600/600], Loss: 0.5929
Epoch [3/5], Step [100/600], Loss: 0.5692
Epoch [3/5], Step [200/600], Loss: 0.5071
Epoch [3/5], Step [300/600], Loss: 0.4820
Epoch [3/5], Step [400/600], Loss: 0.5348
Epoch [3/5], Step [500/600], Loss: 0.5006
Epoch [3/5], Step [600/600], Loss: 0.5223
Epoch [4/5], Step [100/600], Loss: 0.4071
Epoch [4/5], Step [200/600], Loss: 0.4680
Epoch [4/5], Step [300/600], Loss: 0.5118
Epoch [4/5], Step [400/600], Loss: 0.4691
Epoch [4/5], Step [500/600], Loss: 0.4009
Epoch [4/5], Step [600/600], Loss: 0.4906
Epoch [5/5], Step [100/600], Loss: 0.4025
Epoch [5/5], Step [200/600], Loss: 0.3732
Epoch [5/5], Step [300/600], Loss: 0.3861
Epoch [5/5], Step [400/600], Loss: 0.6019
Epoch [5/5], Step [500/600], Loss: 0.4827
Epoch [5/5], Step [600/600], Loss: 0.3483
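The printed values are the loss of a single mini-batch, which is why they fluctuate from step to step. To monitor training more smoothly, one could instead accumulate the average loss over each epoch (a sketch reusing the objects defined above; note that running it performs one additional epoch of updates):

epoch_loss = 0.0
for images, labels in train_loader:
  images, labels = images.to(device), labels.to(device)
  optimizer.zero_grad()
  output = net(images)
  loss = loss_fun(output, labels)
  loss.backward()
  optimizer.step()
  epoch_loss += loss.item()
print('Average loss over the epoch: %.4f' % (epoch_loss / len(train_loader)))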
#@title Run the trained model on the testing set

correct = 0
total = 0
with torch.no_grad():               # no gradients are needed for evaluation
  for images, labels in test_loader:
    images = images.to(device)
    labels = labels.to(device)

    out = net(images)
    _, predicted_labels = torch.max(out, 1)   # index of the largest logit = predicted digit
    correct += (predicted_labels == labels).sum().item()
    total += labels.size(0)

print('Percent correct: %.3f %%' % (100 * correct / total))
Percent correct: 89.111 %
weights = net.conv1.weight.data.clone().cpu().numpy()   # shape: (10, 1, 28, 28)
from PIL import Image
from matplotlib import pyplot as plt
for i in range(10):
  kernel = weights[i, 0, :, :]                 # the 28x28 kernel for digit class i
  plt.figure(i, figsize=(1, 1))
  plt.imshow(kernel, cmap='gray')
  # Rescale the float weights to 0-255 before saving as an 8-bit grayscale image.
  scaled = (255 * (kernel - kernel.min()) / (kernel.max() - kernel.min())).astype('uint8')
  im = Image.fromarray(scaled, "L")
  im.save("file%d.png" % i)
[Figures: the 10 learned 28×28 convolution kernels, one grayscale image per digit class]
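An alternative presentation (not in the original notebook) that shows all 10 kernels in a single figure, assuming the weights array from above:

fig, axes = plt.subplots(1, 10, figsize=(15, 2))
for i, ax in enumerate(axes):
  ax.imshow(weights[i, 0], cmap='gray')
  ax.set_title(str(i))
  ax.axis('off')
plt.show()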