Training a CNN model on MNIST using PyTorch

Sample MNIST images:

MNIST examples

  • 10 classes (the digits 0-9)

  • 60,000 training images

  • 10,000 test images

  • Each image is monochrome, 28-by-28 pixels.

#@title Import the required modules
import torch
import torch.nn as nn
import torchvision.datasets as datasets
import torchvision.transforms as transforms
# Define the "device": use the GPU if one is available, otherwise fall back to the CPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
#@title Download the dataset
train_data = datasets.MNIST(root = './data', train = True,
                        transform = transforms.ToTensor(), download = True)

test_data = datasets.MNIST(root = './data', train = False,
                       transform = transforms.ToTensor())
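As a quick check, the counts and image format listed above can be read directly from the dataset objects (a small sketch using the train_data and test_data defined in this cell; expected values are shown as comments):

print(len(train_data))     # 60000 training images
print(len(test_data))      # 10000 test images
print(train_data.classes)  # the 10 digit classes

image, label = train_data[0]
print(image.shape, label)  # torch.Size([1, 28, 28]) and an integer label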
# About the ToTensor() transformation.

# PyTorch networks expect a tensor as input with dimensions N*C*H*W, where
# N: batch size
# C: channel size
# H: height
# W: width

# Normally an image is stored as H*W*C.
# ToTensor() moves the channel dimension to the front, as PyTorch expects,
# and also scales the pixel values from [0, 255] to [0.0, 1.0].
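A minimal sketch of the effect of ToTensor() on a single sample; the raw_dataset variable below is only for illustration and is not used elsewhere:

raw_dataset = datasets.MNIST(root = './data', train = True, download = True)  # no transform applied
pil_image, _ = raw_dataset[0]
print(pil_image.size)            # (28, 28) PIL image with pixel values 0-255

tensor_image, _ = train_data[0]  # same sample, passed through ToTensor()
print(tensor_image.shape)        # torch.Size([1, 28, 28]), i.e. C*H*W
print(tensor_image.min().item(), tensor_image.max().item())  # values scaled into [0.0, 1.0]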
#@title Define the data loaders
batch_size = 100
train_loader = torch.utils.data.DataLoader(dataset = train_data,
                                             batch_size = batch_size,
                                             shuffle = True)

test_loader = torch.utils.data.DataLoader(dataset = test_data,
                                      batch_size = batch_size,
                                      shuffle = False)
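Pulling one batch from the loader shows the N*C*H*W layout described above (a quick sketch using the loaders just defined):

images, labels = next(iter(train_loader))
print(images.shape)  # torch.Size([100, 1, 28, 28]), i.e. N*C*H*W
print(labels.shape)  # torch.Size([100])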
#@title Define a CNN network

class CNN(nn.Module):
    # This defines the structure of the network.
    def __init__(self):
        super(CNN, self).__init__()
        # A single convolution whose kernel covers the entire 28x28 image,
        # producing a 1x1 output with 10 channels (one score per digit class).
        self.conv1 = nn.Conv2d(1, 10, kernel_size=28, bias=False)

    def forward(self, x):
        x = self.conv1(x)   # shape: (N, 10, 1, 1)
        x = x.view(-1, 10)  # flatten to (N, 10) class scores
        return x


# Create an instance
net = CNN().to(device)
print(net)
CNN(
  (conv1): Conv2d(1, 10, kernel_size=(28, 28), stride=(1, 1), bias=False)
)
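Because the kernel is as large as the input image, this single convolution is equivalent to a fully connected layer mapping 784 pixels to 10 class scores. The sketch below only compares parameter counts; the nn.Linear layer is for illustration and is not part of the model:

linear = nn.Linear(28 * 28, 10, bias=False)  # a fully connected layer of the same size
conv_params = sum(p.numel() for p in net.conv1.parameters())
linear_params = sum(p.numel() for p in linear.parameters())
print(conv_params, linear_params)  # 7840 and 7840 -> the same number of weights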
#@title Define the loss function and the optimizer
loss_fun = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
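nn.CrossEntropyLoss applies log-softmax internally, so the network outputs raw scores (logits) and the targets are integer class indices. A tiny sketch with made-up values:

dummy_logits = torch.randn(4, 10)            # raw scores for a batch of 4 images
dummy_labels = torch.tensor([3, 0, 7, 1])    # the correct digit for each image
print(loss_fun(dummy_logits, dummy_labels))  # a single scalar loss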
#@title Train the model

num_epochs = 5
for epoch in range(num_epochs):
  for i, (images, labels) in enumerate(train_loader):
    images = images.to(device)
    labels = labels.to(device)

    optimizer.zero_grad()            # reset the gradients from the previous step
    output = net(images)             # forward pass
    loss = loss_fun(output, labels)  # compute the cross-entropy loss
    loss.backward()                  # backpropagate
    optimizer.step()                 # update the weights

    if (i+1) % 100 == 0:             # report the loss every 100 mini-batches
      print('Epoch [%d/%d], Step [%d/%d], Loss: %.4f'
                 %(epoch+1, num_epochs, i+1, len(train_data)//batch_size, loss.item()))
Epoch [1/5], Step [100/600], Loss: 1.6611
Epoch [1/5], Step [200/600], Loss: 1.1334
Epoch [1/5], Step [300/600], Loss: 0.9867
Epoch [1/5], Step [400/600], Loss: 0.8299
Epoch [1/5], Step [500/600], Loss: 0.8039
Epoch [1/5], Step [600/600], Loss: 0.6887
Epoch [2/5], Step [100/600], Loss: 0.7048
Epoch [2/5], Step [200/600], Loss: 0.6080
Epoch [2/5], Step [300/600], Loss: 0.5313
Epoch [2/5], Step [400/600], Loss: 0.7372
Epoch [2/5], Step [500/600], Loss: 0.5846
Epoch [2/5], Step [600/600], Loss: 0.4717
Epoch [3/5], Step [100/600], Loss: 0.5912
Epoch [3/5], Step [200/600], Loss: 0.4872
Epoch [3/5], Step [300/600], Loss: 0.5721
Epoch [3/5], Step [400/600], Loss: 0.5382
Epoch [3/5], Step [500/600], Loss: 0.5281
Epoch [3/5], Step [600/600], Loss: 0.5987
Epoch [4/5], Step [100/600], Loss: 0.6145
Epoch [4/5], Step [200/600], Loss: 0.4519
Epoch [4/5], Step [300/600], Loss: 0.4990
Epoch [4/5], Step [400/600], Loss: 0.7298
Epoch [4/5], Step [500/600], Loss: 0.4356
Epoch [4/5], Step [600/600], Loss: 0.5161
Epoch [5/5], Step [100/600], Loss: 0.3879
Epoch [5/5], Step [200/600], Loss: 0.4139
Epoch [5/5], Step [300/600], Loss: 0.4104
Epoch [5/5], Step [400/600], Loss: 0.3494
Epoch [5/5], Step [500/600], Loss: 0.4482
Epoch [5/5], Step [600/600], Loss: 0.5603
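Once training is done, the learned weights can be saved and reloaded with the usual state_dict pattern (a sketch; the file name mnist_cnn.pt is arbitrary):

torch.save(net.state_dict(), 'mnist_cnn.pt')  # save only the parameters

restored = CNN().to(device)                   # rebuild the architecture
restored.load_state_dict(torch.load('mnist_cnn.pt', map_location=device))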
#@title Run the trained model on the testing set

correct = 0
total = 0
with torch.no_grad():                # gradients are not needed for evaluation
  for images, labels in test_loader:
    images = images.to(device)
    labels = labels.to(device)

    out = net(images)
    _, predicted_labels = torch.max(out, 1)
    correct += (predicted_labels == labels).sum().item()
    total += labels.size(0)

print('Percent correct: %.3f %%' % (100 * correct / total))
Percent correct: 89.131 %
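The same loop can be extended to report accuracy per digit, which shows which classes this one-layer model handles worst (a sketch reusing net and test_loader):

class_correct = [0] * 10
class_total = [0] * 10
with torch.no_grad():
  for images, labels in test_loader:
    images, labels = images.to(device), labels.to(device)
    predicted = net(images).argmax(dim=1)
    for c in range(10):
      mask = (labels == c)                                      # test images of digit c
      class_correct[c] += (predicted[mask] == c).sum().item()
      class_total[c] += mask.sum().item()

for c in range(10):
  print('Digit %d: %.1f %%' % (c, 100 * class_correct[c] / class_total[c]))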
# Extract the learned 28x28 filters, one per digit class
weights = net.conv1.weight.data.clone().cpu().numpy()
from PIL import Image
from matplotlib import pyplot as plt
for i in range(10):
  kernel = weights[i, :, :, :].reshape(28, 28)
  plt.figure(i, figsize=(1, 1))
  plt.imshow(kernel, cmap='gray')
  # Rescale the float weights to 0-255 before saving as an 8-bit grayscale image
  scaled = ((kernel - kernel.min()) / (kernel.max() - kernel.min()) * 255).astype('uint8')
  im = Image.fromarray(scaled, "L")
  im.save("file%d.png" % i)
Output: the ten learned filters, displayed as grayscale images and saved as file0.png through file9.png.
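Alternatively, the ten filters can be written to a single normalized image grid with torchvision's utilities (a sketch; filters.png is an arbitrary file name):

from torchvision.utils import save_image
filters = net.conv1.weight.data.clone().cpu()  # shape (10, 1, 28, 28)
save_image(filters, 'filters.png', nrow=5, normalize=True)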