Training a CNN model on MNIST using PyTorch
Sample MNIST images:
10 classes
60 thousand training images
10 thousand testing images
Each image is a single-channel (grayscale) image of 28-by-28 pixels.
#@title Import the required modules
import torch
import torch.nn as nn
import torchvision.datasets as datasets
import torchvision.transforms as transforms
# Choose where tensors and the model live: the first CUDA GPU when one is
# available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')
#@title Download the dataset
# Load the MNIST train/test splits into ./data, converting each image to a tensor.
# Both splits pass download=True so each call is self-sufficient; previously the
# test split relied on the train call having already fetched its files. When the
# files are already present, the dataset skips the download.
to_tensor = transforms.ToTensor()
train_data = datasets.MNIST(root='./data', train=True,
                            transform=to_tensor, download=True)
test_data = datasets.MNIST(root='./data', train=False,
                           transform=to_tensor, download=True)
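# About the ToTensor() transformation.
# PyTorch convolutional layers expect a batched input tensor with dimensions N*C*H*W, where
# N: batch size
# C: number of channels (1 for grayscale MNIST)
# H: height
# W: width
# A PIL image or NumPy array is normally laid out as H*W*C.
# ToTensor() converts each image to a C*H*W float tensor (moving the channel
# dimension to the front) and scales pixel values from [0, 255] to [0.0, 1.0].
# The DataLoader then stacks individual images into batches, adding the N dimension.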
#@title Define the data loaders
batch_size = 100
# Wrap each split in a DataLoader that yields mini-batches of `batch_size` images.
# The training loader reshuffles its samples every epoch; the test loader keeps
# a fixed order so evaluation is reproducible.
train_loader = torch.utils.data.DataLoader(
    train_data, batch_size=batch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(
    test_data, batch_size=batch_size, shuffle=False)
#@title Define a CNN network
class CNN(nn.Module):
    """Single-convolution classifier for single-channel images.

    The only layer is a bias-free ``Conv2d`` with 1 input channel, 10 output
    channels and a 28x28 kernel. On a 28x28 input the kernel covers the whole
    image, so each output channel is a single class score. The model is
    therefore effectively a linear classifier, and each of the 10 kernels can
    be viewed as a per-class template.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 10, kernel_size=28, bias=False)

    def forward(self, x):
        """Return class scores for a batch of images.

        Args:
            x: tensor of shape (N, 1, H, W). For the intended 28x28 MNIST
               input the result has shape (N, 10).

        Returns:
            Tensor of shape (N, 10 * (H - 27) * (W - 27)), which is (N, 10)
            for 28x28 input.
        """
        x = self.conv1(x)
        # Flatten everything after the batch dimension. The previous
        # `view(-1, 10)` silently folded any extra spatial outputs into the
        # batch dimension for non-28x28 inputs, corrupting the batch size.
        return torch.flatten(x, 1)
# Build the model, move its parameters to the selected device, then print
# the layer summary.
net = CNN()
net = net.to(device)
print(net)
CNN(
(conv1): Conv2d(1, 10, kernel_size=(28, 28), stride=(1, 1), bias=False)
)
#@title Define the loss function and the optimizer
# Cross-entropy loss applied to the raw (unnormalised) scores returned by the network.
loss_fun = nn.CrossEntropyLoss()
# Stochastic gradient descent with momentum 0.9 and a learning rate of 1e-3.
optimizer = torch.optim.SGD(net.parameters(), momentum=0.9, lr=1e-3)
#@title Train the model
# Train `net` for `num_epochs` passes over the training set, printing the
# current mini-batch loss every `log_interval` steps.
num_epochs = 5
# Logging cadence in mini-batches. This was previously tied to `batch_size`,
# which is an unrelated quantity that only coincidentally equalled 100.
log_interval = 100
# Number of mini-batches per epoch, including a final partial batch if the
# dataset size is not a multiple of the batch size (len(train_data)//batch_size
# would drop it).
steps_per_epoch = len(train_loader)
net.train()  # make sure training-mode behaviour is active (e.g. after an eval pass)
for epoch in range(num_epochs):
    for i, (images, labels) in enumerate(train_loader):
        images = images.to(device)
        labels = labels.to(device)
        optimizer.zero_grad()
        output = net(images)
        loss = loss_fun(output, labels)
        loss.backward()
        optimizer.step()
        if (i + 1) % log_interval == 0:
            print('Epoch [%d/%d], Step [%d/%d], Loss: %.4f'
                  % (epoch + 1, num_epochs, i + 1, steps_per_epoch, loss.item()))
Epoch [1/5], Step [100/600], Loss: 1.6611
Epoch [1/5], Step [200/600], Loss: 1.1334
Epoch [1/5], Step [300/600], Loss: 0.9867
Epoch [1/5], Step [400/600], Loss: 0.8299
Epoch [1/5], Step [500/600], Loss: 0.8039
Epoch [1/5], Step [600/600], Loss: 0.6887
Epoch [2/5], Step [100/600], Loss: 0.7048
Epoch [2/5], Step [200/600], Loss: 0.6080
Epoch [2/5], Step [300/600], Loss: 0.5313
Epoch [2/5], Step [400/600], Loss: 0.7372
Epoch [2/5], Step [500/600], Loss: 0.5846
Epoch [2/5], Step [600/600], Loss: 0.4717
Epoch [3/5], Step [100/600], Loss: 0.5912
Epoch [3/5], Step [200/600], Loss: 0.4872
Epoch [3/5], Step [300/600], Loss: 0.5721
Epoch [3/5], Step [400/600], Loss: 0.5382
Epoch [3/5], Step [500/600], Loss: 0.5281
Epoch [3/5], Step [600/600], Loss: 0.5987
Epoch [4/5], Step [100/600], Loss: 0.6145
Epoch [4/5], Step [200/600], Loss: 0.4519
Epoch [4/5], Step [300/600], Loss: 0.4990
Epoch [4/5], Step [400/600], Loss: 0.7298
Epoch [4/5], Step [500/600], Loss: 0.4356
Epoch [4/5], Step [600/600], Loss: 0.5161
Epoch [5/5], Step [100/600], Loss: 0.3879
Epoch [5/5], Step [200/600], Loss: 0.4139
Epoch [5/5], Step [300/600], Loss: 0.4104
Epoch [5/5], Step [400/600], Loss: 0.3494
Epoch [5/5], Step [500/600], Loss: 0.4482
Epoch [5/5], Step [600/600], Loss: 0.5603
#@title Run the trained model on the testing set
# Measure classification accuracy of `net` on the test set.
net.eval()
correct = 0
total = 0
# No gradients are needed for evaluation; skipping autograd saves memory and time.
with torch.no_grad():
    for images, labels in test_loader:
        images = images.to(device)
        labels = labels.to(device)
        out = net(images)
        predicted_labels = out.argmax(dim=1)
        # Accumulate as a Python int rather than a device tensor.
        correct += (predicted_labels == labels).sum().item()
        total += labels.size(0)
# Divide by the true number of samples; the former `total + 1` under-reported accuracy.
accuracy = 100 * correct / total if total else 0.0
print('Percent correct: %.3f %%' % accuracy)
Percent correct: 89.131 %
from PIL import Image
from matplotlib import pyplot as plt

# Display and save each learned conv1 kernel as a 28x28 grayscale image.
# The weight tensor has shape (out_channels, in_channels, 28, 28) = (10, 1, 28, 28).
weights = net.conv1.weight.detach().clone().cpu().numpy()
for i in range(weights.shape[0]):
    kernel = weights[i, 0]  # the single input channel's 28x28 kernel
    plt.figure(i, figsize=(1, 1))
    plt.imshow(kernel, cmap='gray')
    # PIL mode "L" expects 8-bit unsigned pixels, but the kernel is float32 with
    # arbitrary sign and scale. Passing it directly reinterpreted raw float bytes
    # as pixels. Min-max rescale to [0, 255] and cast to uint8 so the saved PNG
    # matches the auto-scaled imshow rendering.
    lo = kernel.min()
    span = kernel.max() - lo
    scaled = (kernel - lo) / span if span > 0 else kernel - lo
    im = Image.fromarray((scaled * 255).round().astype('uint8'))
    im.save("file%d.png" % i)