What will we cover?
- What is PyTorch
- PyTorch vs TensorFlow
- Get started with PyTorch
- Work with image classification
Step 1: What is PyTorch?
PyTorch is an optimized tensor library for deep learning using GPUs and CPUs.
What does that mean?
Well, PyTorch is an open source machine learning library and is used for computer vision and natural language processing. It is primarily developed by Facebook’s AI Research Lab.
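To get a feel for what "an optimized tensor library" means in practice, here is a minimal example (an addition for illustration, not from the course material) of creating a tensor and running a few basic operations on it:

```python
import torch

# Create a 2x3 tensor and run a few basic operations
x = torch.tensor([[1., 2., 3.], [4., 5., 6.]])

print(x.shape)        # torch.Size([2, 3])
print(x.mean())       # tensor(3.5000)
print((x * 2).sum())  # element-wise multiply, then sum: tensor(42.)
```

The same operations also run on a GPU if one is available, which is what makes tensors the workhorse of deep learning.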
Step 2: PyTorch and TensorFlow
People often worry about picking the wrong framework and wasting time on it.
You probably do the same – but don't worry: whether you choose PyTorch or TensorFlow, you are on the right track. They are the two most popular Deep Learning frameworks, and if you learn one, you will have an easy time switching to the other later.
PyTorch was released in 2016 by Facebook's AI Research Lab, while TensorFlow was released in 2015 by the Google Brain team.
Both are good choices for Deep Learning.
Step 3: PyTorch and prepared datasets
PyTorch comes with a long list of prepared datasets; you can see them all in the torchvision.datasets documentation.
We will look at the MNIST dataset for handwritten digit recognition.
In the video above we also look at the CIFAR10 dataset, which consists of 32×32 images in 10 classes.
You can get a dataset by using torchvision.
from torchvision import datasets

data_path = 'downloads/'
mnist = datasets.MNIST(data_path, train=True, download=True)
Step 4: Getting and preparing the data
First we need to get the data and prepare it by turning the images into tensors and normalizing them.
- Images in the MNIST dataset are PIL objects
- They need to be transformed to tensors (the datatype PyTorch works with)
- torchvision has the transformation transforms.ToTensor(), which turns NumPy arrays and PIL images into tensors
- Then you need to normalize the images
- To do that, you first determine the mean value and the standard deviation
- Then you can apply the normalization
- torchvision has transforms.Normalize, which takes the mean and standard deviation
from torchvision import datasets
from torchvision import transforms
import torch
import torch.nn as nn
from torch import optim
import matplotlib.pyplot as plt

data_path = 'downloads/'
mnist = datasets.MNIST(data_path, train=True, download=True)
mnist_val = datasets.MNIST(data_path, train=False, download=True)

# Reload the training set with images converted to tensors
mnist = datasets.MNIST(data_path, train=True, download=False,
                       transform=transforms.ToTensor())

# Stack all images into one tensor to compute the statistics
imgs = torch.stack([img_t for img_t, _ in mnist], dim=3)

print('get mean')
print(imgs.view(1, -1).mean(dim=1))
print('get standard deviation')
print(imgs.view(1, -1).std(dim=1))
Then we can use those values to make the transformation.
mnist = datasets.MNIST(data_path, train=True, download=False,
                       transform=transforms.Compose([
                           transforms.ToTensor(),
                           transforms.Normalize((0.1307,), (0.3081,))]))
mnist_val = datasets.MNIST(data_path, train=False, download=False,
                           transform=transforms.Compose([
                               transforms.ToTensor(),
                               transforms.Normalize((0.1307,), (0.3081,))]))
Step 5: Creating and testing a Model
The model we will use is a simple fully connected network with two hidden layers of 128 and 64 units. We can build it as follows.
input_size = 784  # 28*28 pixels flattened
hidden_sizes = [128, 64]
output_size = 10

model = nn.Sequential(nn.Linear(input_size, hidden_sizes[0]),
                      nn.ReLU(),
                      nn.Linear(hidden_sizes[0], hidden_sizes[1]),
                      nn.ReLU(),
                      nn.Linear(hidden_sizes[1], output_size),
                      nn.LogSoftmax(dim=1))
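Before training, a quick sanity check on a random batch can confirm that the layer sizes line up (this check is an addition for illustration, not one of the original steps):

```python
import torch
import torch.nn as nn

input_size = 784
hidden_sizes = [128, 64]
output_size = 10

model = nn.Sequential(nn.Linear(input_size, hidden_sizes[0]),
                      nn.ReLU(),
                      nn.Linear(hidden_sizes[0], hidden_sizes[1]),
                      nn.ReLU(),
                      nn.Linear(hidden_sizes[1], output_size),
                      nn.LogSoftmax(dim=1))

# Feed a dummy batch of 64 flattened 28x28 "images"
out = model(torch.randn(64, input_size))
print(out.shape)  # torch.Size([64, 10])
```

Because the last layer is LogSoftmax, each output row holds log-probabilities over the 10 digit classes, which is what nn.NLLLoss expects during training.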
Then we can train the model as follows.
train_loader = torch.utils.data.DataLoader(mnist, batch_size=64, shuffle=True)

optimizer = optim.SGD(model.parameters(), lr=0.01)
loss_fn = nn.NLLLoss()

n_epochs = 10
for epoch in range(n_epochs):
    for imgs, labels in train_loader:
        optimizer.zero_grad()

        batch_size = imgs.shape[0]
        output = model(imgs.view(batch_size, -1))

        loss = loss_fn(output, labels)
        loss.backward()
        optimizer.step()
    print("Epoch: %d, Loss: %f" % (epoch, float(loss)))
And finally, test our model.
val_loader = torch.utils.data.DataLoader(mnist_val, batch_size=64, shuffle=True)

correct = 0
total = 0
with torch.no_grad():
    for imgs, labels in val_loader:
        batch_size = imgs.shape[0]
        outputs = model(imgs.view(batch_size, -1))
        _, predicted = torch.max(outputs, dim=1)
        total += labels.shape[0]
        correct += int((predicted == labels).sum())

print("Accuracy: %f" % (correct / total))
This reaches an accuracy of 96.44%.
Want to learn more?
Want better results? Try using a CNN model.
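As a starting point, here is a sketch of a small CNN for the 28×28 MNIST images – the layer sizes and kernel choices are illustrative assumptions, not tuned values from the course:

```python
import torch
import torch.nn as nn

# A small CNN sketch for 1-channel 28x28 images
model = nn.Sequential(
    nn.Conv2d(1, 16, kernel_size=3, padding=1),   # 1x28x28 -> 16x28x28
    nn.ReLU(),
    nn.MaxPool2d(2),                              # -> 16x14x14
    nn.Conv2d(16, 32, kernel_size=3, padding=1),  # -> 32x14x14
    nn.ReLU(),
    nn.MaxPool2d(2),                              # -> 32x7x7
    nn.Flatten(),
    nn.Linear(32 * 7 * 7, 10),
    nn.LogSoftmax(dim=1),
)

# CNNs take images as (batch, channels, height, width),
# so the training loop no longer needs the view() call
out = model(torch.randn(64, 1, 28, 28))
print(out.shape)  # torch.Size([64, 10])
```

The training and evaluation loops above carry over unchanged apart from dropping the `imgs.view(batch_size, -1)` reshaping, since convolutional layers work on the image grid directly.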
This is part of a FREE 10h Machine Learning course with Python.
- 15 video lessons – which explain Machine Learning concepts, demonstrate models on real data, introduce projects and show a solution (YouTube playlist).
- 30 Jupyter Notebooks – with the full code and explanation from the lectures and projects (GitHub).
- 15 projects – with step guides to help you structure your solutions, and the solutions explained at the end of the video lessons (GitHub).