# Training a CNN model on MNIST using PyTorch

Sample MNIST images:

![MNIST examples](https://www.researchgate.net/profile/Stefan_Elfwing/publication/266205382/figure/fig5/AS:267913563209738@1440886979379/Example-images-of-the-ten-handwritten-digits-in-the-MNIST-training-set.png)

- 10 classes
- 60 thousand training images
- 10 thousand testing images
- Each image is monochrome, 28-by-28 pixels.

In [None]:
#@title Import the required modules
import torch
import torch.nn as nn
import torchvision.datasets as datasets
import torchvision.transforms as transforms

In [None]:
# Select the compute device: the first CUDA GPU when one is available,
# otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')

In [None]:
#@title Download the dataset
# Training split: fetched into ./data on first use (download=True).
train_data = datasets.MNIST(
    root='./data',
    train=True,
    transform=transforms.ToTensor(),
    download=True,
)

# Testing split: read from the same ./data directory.
test_data = datasets.MNIST(
    root='./data',
    train=False,
    transform=transforms.ToTensor(),
)

In [None]:
# About the ToTensor() transformation.

# PyTorch networks expect a tensor as input with dimensions N*C*H*W  where
# N: batch size
# C: channel size
# H: height
# W: width

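# Normally an image is stored as H*W*C (channels last).
# ToTensor() moves the channel dimension to the front (C*H*W), as needed by PyTorch,
# and scales the pixel values from integers in [0, 255] to floats in [0.0, 1.0].
# The batch dimension N is added later, when a DataLoader stacks images into a batch.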

In [None]:
#@title Define the data loaders
batch_size = 100

# Training batches are reshuffled every epoch.
train_loader = torch.utils.data.DataLoader(
    dataset=train_data,
    batch_size=batch_size,
    shuffle=True,
)

# Test batches keep the dataset order; shuffling is not needed for evaluation.
test_loader = torch.utils.data.DataLoader(
    dataset=test_data,
    batch_size=batch_size,
    shuffle=False,
)

In [None]:
#@title Define a CNN network

class CNN(nn.Module):
    """Single-layer convolutional classifier for 28x28 single-channel images.

    The only layer is a bias-free convolution with 10 output channels and a
    28x28 kernel. On a 28x28 input the kernel covers the whole image, so each
    output channel yields exactly one value per image: the score of one class.
    """

    def __init__(self):
        super().__init__()
        # 1 input channel -> 10 output channels, kernel as large as the image.
        self.conv1 = nn.Conv2d(1, 10, kernel_size=28, bias=False)

    def forward(self, x):
        """Return the class scores, reshaped to rows of 10 values."""
        # For N*1*28*28 input the convolution output is N*10*1*1;
        # view(-1, 10) turns it into an N*10 score matrix.
        return self.conv1(x).view(-1, 10)


# Instantiate the network, then move its parameters to the selected device.
net = CNN()
net = net.to(device)

In [None]:
# Print the module's text representation (its layers and their settings).
print(net)

In [None]:
#@title Define the loss function and the optimizer
# Cross-entropy loss applied to the raw class scores produced by the network.
loss_fun = nn.CrossEntropyLoss()

# SGD with momentum over all network parameters.
optimizer = torch.optim.SGD(
    net.parameters(),
    lr=0.001,
    momentum=0.9,
)

In [None]:
#@title Train the model

num_epochs = 5
# How often (in optimisation steps) the current loss is printed. This is kept
# separate from batch_size so that changing the batch size does not silently
# change the logging frequency.
log_interval = 100
# Number of batches per epoch as reported by the loader itself; this stays
# correct even when the dataset size is not a multiple of the batch size.
steps_per_epoch = len(train_loader)

# Make sure the model is in training mode before optimising.
net.train()
for epoch in range(num_epochs):
    for step, (images, labels) in enumerate(train_loader, start=1):
        images = images.to(device)
        labels = labels.to(device)

        # Standard optimisation step: reset gradients, forward pass,
        # compute the loss, backpropagate, update the parameters.
        optimizer.zero_grad()
        output = net(images)
        loss = loss_fun(output, labels)
        loss.backward()
        optimizer.step()

        if step % log_interval == 0:
            print('Epoch [%d/%d], Step [%d/%d], Loss: %.4f'
                  % (epoch + 1, num_epochs, step, steps_per_epoch, loss.item()))

In [None]:
#@title Run the trained model on the testing set

correct = 0
total = 0

# Switch to evaluation mode and disable gradient tracking: no parameters are
# updated here, so building the autograd graph would only cost time and memory.
net.eval()
with torch.no_grad():
    for images, labels in test_loader:
        images = images.to(device)
        labels = labels.to(device)

        out = net(images)
        # The predicted class is the index of the largest score in each row.
        predicted_labels = out.argmax(dim=1)
        # .item() turns the 0-dim tensor into a plain Python int.
        correct += (predicted_labels == labels).sum().item()
        total += labels.size(0)

# Accuracy is correct / total; guard only against an empty test loader.
percent_correct = 100.0 * correct / total if total else float('nan')
print('Percent correct: %.3f %%' % percent_correct)

In [None]:
# Copy the learned convolution kernels out of the network as a NumPy array.
# The clone ensures the array never shares memory with the live parameters.
conv_kernels = net.conv1.weight.detach().clone()
weights = conv_kernels.cpu().numpy()


In [None]:
from PIL import Image
from matplotlib import pyplot as plt

# Display each of the 10 learned 28x28 kernels and save it as a grayscale PNG.
for i in range(10):
    kernel = weights[i, :, :, :].reshape(28, 28)
    plt.figure(i, figsize=(1, 1))
    plt.imshow(kernel, cmap='gray')

    # An 8-bit grayscale ("L") image needs unsigned integers in [0, 255].
    # The kernel holds small signed floats, so rescale it linearly to that
    # range first; passing the float array directly would reinterpret its raw
    # bytes as pixels and produce a meaningless image.
    lowest = kernel.min()
    span = kernel.max() - lowest
    scale = 255.0 / span if span > 0 else 0.0  # a constant kernel maps to black
    pixels = ((kernel - lowest) * scale).round().astype('uint8')

    # A 2-D uint8 array is converted to an "L" image automatically.
    im = Image.fromarray(pixels)
    im.save("file%d.png" % i)