PyTorch Model to Detect Handwriting for Beginners

What will we cover?

  • What is PyTorch
  • PyTorch vs TensorFlow
  • Get started with PyTorch
  • Work with image classification

Step 1: What is PyTorch?

PyTorch is an optimized tensor library for deep learning using GPUs and CPUs.

What does that mean?

Well, PyTorch is an open-source machine learning library used for applications such as computer vision and natural language processing. It is primarily developed by Facebook’s AI Research Lab.
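To make "a tensor library for GPUs and CPUs" concrete, here is a minimal sketch: it creates two small tensors on the GPU if one is available (falling back to the CPU otherwise) and combines them. The values are arbitrary illustration data.

```python
import torch

# Pick the GPU if PyTorch can see one, otherwise fall back to the CPU
device = "cuda" if torch.cuda.is_available() else "cpu"

# A tensor is an n-dimensional array, similar to a NumPy array
x = torch.tensor([[1.0, 2.0], [3.0, 4.0]], device=device)
w = torch.ones(2, 2, device=device)

# Operations run on whichever device holds the tensors
y = x @ w + 1  # matrix product, then add 1 to every element

print(y.device.type, tuple(y.shape))
print(y.cpu().tolist())
```

The same code runs unchanged on a machine with or without a GPU; only the `device` string differs.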

Step 2: PyTorch and TensorFlow

People often worry about picking the wrong framework and wasting their time.

You probably do the same, but don’t worry: if you use either PyTorch or TensorFlow, you are on the right track. They are the most popular deep learning frameworks, and once you learn one, switching to the other later is easy.

PyTorch was released in 2016 by Facebook’s Research Lab, while TensorFlow was released in 2015 by the Google Brain team.

Both are good choices for Deep Learning.

Step 3: PyTorch and prepared datasets

PyTorch’s torchvision package comes with a long list of prepared datasets, all of which are listed in the torchvision documentation.

We will look at the MNIST dataset for handwritten digit recognition.

In the video above, we also look at the CIFAR10 dataset, which consists of 32×32 color images in 10 classes.

You can get a dataset by using torchvision.

from torchvision import datasets

data_path = 'downloads/'  # folder where the dataset files are stored
mnist = datasets.MNIST(data_path, train=True, download=True)  # the 60,000 training images

Step 4: Getting and preparing the data

First, we need to get the data and prepare it by turning the images into tensors and normalizing them.

Transforming and Normalizing

  • Images are PIL objects in the MNIST dataset
  • They need to be transformed to tensors (the data type PyTorch works with)
    • torchvision provides the transformation transforms.ToTensor(), which turns PIL images and NumPy arrays into tensors and scales pixel values to the range [0, 1]
  • Then you need to normalize the images
    • First, determine the mean and the standard deviation of the pixel values
  • Then we can apply normalization
    • torchvision provides transforms.Normalize, which takes the mean and the standard deviation
from torchvision import datasets
from torchvision import transforms
import torch
import torch.nn as nn
from torch import optim
import matplotlib.pyplot as plt

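data_path = 'downloads/'
# Download the training and validation (test) splits
mnist = datasets.MNIST(data_path, train=True, download=True)
mnist_val = datasets.MNIST(data_path, train=False, download=True)

# Reload the training set, this time converting each PIL image to a tensor
mnist = datasets.MNIST(data_path, train=True, download=False, transform=transforms.ToTensor())

# Stack all images (each of shape 1x28x28) into one tensor of shape 1x28x28x60000
imgs = torch.stack([img_t for img_t, _ in mnist], dim=3)

# Flatten into a single row so the statistics cover every pixel of every image
print('get mean')
print(imgs.view(1, -1).mean(dim=1))

print('get standard deviation')
print(imgs.view(1, -1).std(dim=1))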
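Stacking all 60,000 images into one tensor works for MNIST, but a larger dataset may not fit in memory that way. A sketch of the same statistics computed batch by batch is below; it uses random stand-in images (an assumption, so it runs without downloading) and accumulates the pixel count, sum, and sum of squares.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Hypothetical stand-in for MNIST: 1,000 random "images" of shape 1x28x28 in [0, 1)
images = torch.rand(1000, 1, 28, 28)
dataset = TensorDataset(images, torch.zeros(1000, dtype=torch.long))

# Accumulate pixel count, sum and sum of squares one batch at a time,
# so the whole dataset never has to sit in a single tensor
n_pixels = 0
total = 0.0
total_sq = 0.0
for batch, _ in DataLoader(dataset, batch_size=64):
    batch = batch.double()  # float64 keeps the running sums accurate
    n_pixels += batch.numel()
    total += batch.sum().item()
    total_sq += (batch ** 2).sum().item()

mean = total / n_pixels
# Population standard deviation: sqrt(E[x^2] - E[x]^2)
std = (total_sq / n_pixels - mean ** 2) ** 0.5
print(mean, std)
```

This computes the population standard deviation, while `torch.std` by default applies Bessel's correction (dividing by N - 1); over tens of millions of pixels the difference is negligible.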

Then we can use those values to make the transformation.

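# Mean and standard deviation printed above (about 0.1307 and 0.3081 for MNIST)
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
mnist = datasets.MNIST(data_path, train=True, download=False, transform=transform)
mnist_val = datasets.MNIST(data_path, train=False, download=False, transform=transform)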

Step 5: Creating and testing a model

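The model we will use is a fully connected network: the 784 pixel values of a flattened 28×28 image feed two hidden layers with 128 and 64 units, followed by an output layer with 10 units, one per digit.

In PyTorch, we can write that as follows.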

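input_size = 784  # 28*28 pixels per flattened image
hidden_sizes = [128, 64]
output_size = 10  # one output per digit 0-9

model = nn.Sequential(nn.Linear(input_size, hidden_sizes[0]),
                      nn.ReLU(),  # nonlinearity between layers; ReLU is one common choice
                      nn.Linear(hidden_sizes[0], hidden_sizes[1]),
                      nn.ReLU(),
                      nn.Linear(hidden_sizes[1], output_size),
                      nn.LogSoftmax(dim=1))  # log-probabilities, as nn.NLLLoss expects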
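The input size of 784 comes from flattening each 28×28 image, which the training loop does with `imgs.view(batch_size, -1)`. A small shape check on a random batch (a stand-in for a real MNIST batch of 64 images) shows what that call produces and how the first layer consumes it.

```python
import torch
import torch.nn as nn

# A hypothetical batch of 64 single-channel 28x28 images
imgs = torch.rand(64, 1, 28, 28)
batch_size = imgs.shape[0]

# view(batch_size, -1) keeps the batch dimension and flattens the rest:
# 1 * 28 * 28 = 784 values per image, matching input_size
flat = imgs.view(batch_size, -1)

first_layer = nn.Linear(784, 128)
print(tuple(flat.shape), tuple(first_layer(flat).shape))
```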

Then we can train the model as follows.

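# Serve the normalized training set in shuffled batches of 64 images
train_loader = torch.utils.data.DataLoader(mnist, batch_size=64, shuffle=True)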

optimizer = optim.SGD(model.parameters(), lr=0.01)
loss_fn = nn.NLLLoss()

n_epochs = 10
for epoch in range(n_epochs):
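    for imgs, labels in train_loader:
        batch_size = imgs.shape[0]
        # Flatten each 1x28x28 image into a vector of 784 values
        output = model(imgs.view(batch_size, -1))

        loss = loss_fn(output, labels)

        # Backpropagate and update the weights
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    print("Epoch: %d, Loss: %f" % (epoch, float(loss)))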
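`nn.NLLLoss` does not apply a softmax itself; it expects log-probabilities, so the model's raw outputs need to pass through `nn.LogSoftmax` before the loss. This sketch, on random scores and hypothetical labels, shows that `nn.LogSoftmax` followed by `nn.NLLLoss` gives the same value as `nn.CrossEntropyLoss` applied to the raw scores.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
logits = torch.randn(4, 10)          # raw scores for 4 images and 10 classes
labels = torch.tensor([3, 0, 9, 1])  # hypothetical true digits

# NLLLoss expects log-probabilities, so the raw scores go through LogSoftmax first
nll = nn.NLLLoss()(nn.LogSoftmax(dim=1)(logits), labels)

# CrossEntropyLoss combines both steps and takes the raw scores directly
ce = nn.CrossEntropyLoss()(logits, labels)

print(nll.item(), ce.item())
```

So an equivalent design is to drop the final `nn.LogSoftmax` layer and train with `nn.CrossEntropyLoss` instead.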

Finally, we test the model on the validation set.

# Validation batches do not need shuffling
val_loader = torch.utils.data.DataLoader(mnist_val, batch_size=64, shuffle=False)

correct = 0
total = 0
with torch.no_grad():
    for imgs, labels in val_loader:
        batch_size = imgs.shape[0]
        outputs = model(imgs.view(batch_size, -1))
        _, predicted = torch.max(outputs, dim=1)
        total += labels.shape[0]
        correct += int((predicted == labels).sum())
print("Accuracy: %f" % (correct / total))

This reaches an accuracy of 96.44%.
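The accuracy calculation relies on `torch.max(outputs, dim=1)`, which returns both the largest value in each row and its index; the index is the predicted class. A worked example on three hypothetical outputs:

```python
import torch

# Hypothetical log-probabilities for 3 images over 3 classes
outputs = torch.tensor([[-2.0, -0.1, -3.0],
                        [-0.2, -1.5, -2.5],
                        [-1.0, -2.0, -0.5]])
labels = torch.tensor([1, 0, 1])

# torch.max along dim=1 returns (max values, indices); the index is the predicted class
_, predicted = torch.max(outputs, dim=1)

correct = int((predicted == labels).sum())
accuracy = correct / labels.shape[0]
print(predicted.tolist(), accuracy)
```

Here the predictions are [1, 0, 2], so two of the three labels match and the accuracy is 2/3.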

Want to learn more?

Want better results? Try using a CNN model.
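A convolutional neural network (CNN) keeps the 2D structure of each image instead of flattening it first. The following is a minimal sketch, not a tuned model: the layer sizes are illustrative assumptions. It could replace `model` in the training and test loops above if the `imgs.view(batch_size, -1)` calls are changed to pass `imgs` directly.

```python
import torch
import torch.nn as nn

# A minimal CNN sketch for 1x28x28 inputs; layer sizes are illustrative assumptions
cnn = nn.Sequential(
    nn.Conv2d(1, 16, kernel_size=3, padding=1),   # -> 16 x 28 x 28
    nn.ReLU(),
    nn.MaxPool2d(2),                              # -> 16 x 14 x 14
    nn.Conv2d(16, 32, kernel_size=3, padding=1),  # -> 32 x 14 x 14
    nn.ReLU(),
    nn.MaxPool2d(2),                              # -> 32 x 7 x 7
    nn.Flatten(),                                 # -> 1568 values per image
    nn.Linear(32 * 7 * 7, 10),
    nn.LogSoftmax(dim=1),                         # log-probabilities for nn.NLLLoss
)

imgs = torch.rand(64, 1, 28, 28)  # images keep their 2D shape, no view() needed
out = cnn(imgs)
print(tuple(out.shape))
```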

This is part of a FREE 10h Machine Learning course with Python.

  • 15 video lessons – which explain Machine Learning concepts, demonstrate models on real data, introduce projects and show a solution (YouTube playlist).
  • 30 Jupyter Notebooks – with the full code and explanation from the lectures and projects (GitHub).
  • 15 projects – with step-by-step guides to help you structure your solutions, and solutions explained at the end of the video lessons (GitHub).
