Mobile Media Apps Assignment: How to Deploy a PyTorch Image Classification Model to Flask For Beginners (Part 1)

Convolutional neural networks? Deep learning? RESTful APIs? These buzzwords may sound intimidating, but with today's Python frameworks, they don't have to be. In this tutorial, I will walk through how to train a neural network for the image classification task using PyTorch, and in a later tutorial, I will show how to serve our model as part of a REST API via Flask.


For this tutorial, you will need the following:

  • Some basic Python knowledge, syntax and such.

While it is helpful to know some AI concepts, I will try to explain ideas for beginners or point to resources that explain them better than I can.


I will be showing off a binary classification task using Microsoft’s Cat versus Dog dataset, featured in a famous 2013 Kaggle competition. However, with the right modifications, our model will be able to handle any multi-class image classification dataset, such as CIFAR-10 or ImageNet.

Some image examples from our Cat vs Dog dataset.

Downloading Data

You can download the dataset using the link above. If you’re using CoLab like me, you can also directly download and unzip the dataset with the following code block.

!wget <dataset-url>
!unzip <downloaded-zip> -d data

If you’re unfamiliar with the CoLab block above, here’s a short explanation of what it does. The ! in front of these two lines signals that these are shell commands in CoLab. wget is a shell command which downloads the content at a remote location directly into the shell environment, and unzip is self-explanatory, extracting the contents of the zip, where the -d flag gives the directory to put the extracted contents in.

Running this will put our images in data/PetImages. The directory structure is shown below:

data
|__PetImages
|  |__Cat (directory of cat images)
|  |__Dog (directory of dog images)
|__MSR-LA - 3467.docx

Cleaning the Data

Unfortunately, our dataset will require some cleaning. Within the Cat and Dog directories, we have a paws.db file which we won’t need. If you download this locally, the .db file can manually be deleted. However, the dataset also contains some corrupt images, which are going to be harder to find. To delete both corrupt images and .db files, I’ve created a code block which can be run either locally or in a CoLab block.

import os
from os import path
import PIL
from PIL import Image

rm_count, count = 0, 0
for dir in ("Cat", "Dog"):
    dest = path.join("data/PetImages", dir)
    files = os.listdir(dest)
    for fname in files:
        full = path.join(dest, fname)
        # remove anything that isn't a jpg (this catches the .db files)
        if fname.find("jpg") == -1:
            os.remove(full)
            rm_count += 1
            continue
        # remove jpgs that PIL can't open
        try:
            Image.open(full).verify()
        except (PIL.UnidentifiedImageError, OSError):
            os.remove(full)
            rm_count += 1
            continue
        count += 1
print("Kept %d files, removed %d files" % (count, rm_count))

The code isn’t too important to understand, but what it does is keep only the JPEG files that can be opened without error by the image library we will be using, PIL. Run this once and your data is good to go.

PyTorch Datasets and DataLoaders

Our dataset, in the end, will contain 24,998 total images. During training, we will be running our model on a substantial subset of these images. If we want to train relatively quickly, we will have to batch our images during training, and if we want to increase performance, we’ll also have to shuffle our training images on each training iteration. On top of this, we’ll also have to shrink each image so training won’t take too long, and convert image files to a tensor datatype (think of an n-dimensional array in PyTorch) that our model can read.

This sounds painful to do by ourselves. Fortunately, PyTorch gives us two convenience classes that do the hard work for us: the Dataset and the DataLoader classes from the torch.utils.data module.

A torch Dataset is a data structure with a fixed length that you can index into to retrieve a training example and its corresponding label. To create your own custom Dataset, you need to initialize a Python subclass of Dataset and implement the following methods: __len__, which returns the number of training examples, and __getitem__, which returns a training example given an index.

Now let’s walk through implementing our Dataset. First, we’ll start off with our imports and our class definition.

from torch.utils.data import Dataset
from torchvision import transforms
import PIL
import os

class CatDogDataset(Dataset):
    def __init__(self, datapath, custom_transforms = []):
        ...

    def __len__(self):
        ...

    def __getitem__(self, idx):
        ...
Our Dataset Class Skeleton

Our first four lines consist of the necessary imports for our Dataset. These will come into play later. Then, we have our CatDogDataset class, extending PyTorch’s Dataset. Here, I’ve also laid out the two required methods as well as the constructor. I’ll go into my implementation for each method individually.

Let’s look at our constructor.

def __init__(self, datapath, custom_transforms = []):
    # record data path and image filenames in each class directory
    self.datapath = datapath
    self.cat_fnames = os.listdir(os.path.join(datapath, "Cat"))
    self.dog_fnames = os.listdir(os.path.join(datapath, "Dog"))
    # resize / crop to 96 x 96, apply any augmentations, convert to tensor
    self.transforms = transforms.Compose([
        transforms.Resize(96),
        transforms.CenterCrop(96),
        *custom_transforms,
        transforms.ToTensor(),
    ])

Constructor implementation.

Our constructor accepts a path to our data folder: data/PetImages in our case. We’ll save this as a class property.

Two other properties, self.cat_fnames and self.dog_fnames, contain arrays of image filenames within each folder.

Our final class property, self.transforms, is assigned a PyTorch transform. This transform is a callable object, built at initialization, which resizes and crops a given image to 96 x 96 resolution and converts the image to a tensor. The optional custom_transforms parameter inserted within the transform is for image augmentations, such as RandomHorizontalFlip or ColorJitter, which can be useful for squeezing out more performance. I won’t be using any for this tutorial, but you are free to experiment. (NOTE: If you use augmentations, make sure to apply them only to the training set.)

Below is my implementation of the __len__ function.

def __len__(self):
return len(self.cat_fnames) + len(self.dog_fnames)

Length method implementation

Very simple. The number of images in our dataset is equal to the number of cat image files plus the number of dog image files. Here, we’re making use of our class properties to calculate that.

Finally, our __getitem__ method.

def __getitem__(self, idx):
    # lower indices get cats while higher indices get dogs
    impath = None
    is_cat = idx < len(self.cat_fnames)
    if is_cat:
        fname = self.cat_fnames[idx]
        impath = os.path.join(self.datapath, "Cat", fname)
    else:
        fname = self.dog_fnames[idx - len(self.cat_fnames)]
        impath = os.path.join(self.datapath, "Dog", fname)
    # open image and apply transforms
    image = PIL.Image.open(impath).convert("RGB")
    as_tensor = self.transforms(image)
    # return image as tensor and binary label for cat (0) / dog (1)
    return as_tensor, int(not is_cat)

Get Item method implementation

The method accepts an index to retrieve data from. We will return a tuple with a tensor representing a 96 x 96 image and a single integer label containing a 0 for a cat and a 1 for a dog.

Starting off, we need to convert our index to an actual file to retrieve. Here, I index into our cat images if the index is less than the number of cat filenames; otherwise, I retrieve a dog. The if statement builds the path to the cat / dog image that we want to return. We then use PIL to open our image, convert it to RGB format (not all images in our set have three channels, which creates problems during batching), apply our stored transform to do the resize and tensor conversion, and return our resulting tensor plus an integer indicating cat / dog status.
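The index-to-label mapping can be checked on its own, without any image files. Here is a tiny standalone sketch of the same logic (the helper name is mine, not part of our class), pretending we have 3 cat images:

```python
def label_for_index(idx, num_cats):
    # mirrors __getitem__: indices below the cat count are cats (0), the rest dogs (1)
    is_cat = idx < num_cats
    return int(not is_cat)

print([label_for_index(i, 3) for i in range(5)])  # [0, 0, 0, 1, 1]
```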

Now that we have a full Dataset implemented, we can explore the second class: the DataLoader. Unlike the Dataset, the DataLoader requires no implementation. Instead, it wraps an existing dataset, augmenting it to return shuffled batches of data. Creating a DataLoader for our CatDogDataset requires only a few lines.

from torch.utils.data import DataLoader

dataset = CatDogDataset("data/PetImages")
data = DataLoader(dataset, batch_size=128, shuffle=True)

The DataLoader constructor accepts more parameters, but this should be enough for us.
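To see what the loader actually yields, here is a toy sketch that swaps in a TensorDataset of random "images" (so it runs without our files); the sizes here are made up for the demo:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# 10 fake 3 x 96 x 96 "images" with random binary labels
xs = torch.randn(10, 3, 96, 96)
ys = torch.randint(0, 2, (10,))
loader = DataLoader(TensorDataset(xs, ys), batch_size=4, shuffle=True)

# each iteration yields a (images, labels) batch; the last batch may be smaller
batch_sizes = [images.shape[0] for images, labels in loader]
print(batch_sizes)  # [4, 4, 2]
```

Note the final ragged batch: by default the DataLoader keeps it rather than dropping it.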

To accurately train our model, we will want to create what is called a train/test split. We will use one subset of our dataset, the train split, to feed data through the model and adjust its parameters, and a held-out subset, the test split, that the model never learns from, to evaluate its real-world performance on unseen data.

To create this partition, PyTorch gives us another convenience function: random_split.

from torch.utils.data import DataLoader, random_split

def load_data(datapath):
    dataset = CatDogDataset(datapath)
    # We pass in our dataset and the split distribution to get our
    # subsets in a tuple.
    train_set, test_set = random_split(dataset, (19998, 5000))
    train_data, test_data = (
        DataLoader(train_set, batch_size=128, shuffle=True),
        DataLoader(test_set, batch_size=128, shuffle=True),
    )
    return train_data, test_data

This function creates our training / testing split from a data path. We’ll use this later.

The Model

Entire papers are dedicated to figuring out optimal model architecture for image classification tasks. Here, I will use the work of smarter people to solve my simple task.

PyTorch provides plenty of pre-built models for different tasks. Here, I will use ResNet, a revolutionary Convolutional Neural Network (CNN) for the multi-class image classification task, detailed in this paper. While it’s not important to understand the fine specifics of the model, it is always good to understand how a model’s architecture allows it to make a decision. I will let this article do the explaining for me, as it does a far better job than I could ever do.

To use ResNet in our code, import the torchvision module and use its models submodule. If you want to try to build your own CNN for this task, follow this guide here.

import torch
import torchvision
model = torchvision.models.resnet18(pretrained=False)

You can pass in True to the pretrained parameter to download model parameters trained on ImageNet. If you train the model on our Cat / Dog dataset with these pre-trained weights in place, you could potentially get better performance, though I’ve opted not to do this for the tutorial.

Here, I’ve initialized ResNet18, the ResNet variation with 18 layers. I chose ResNet18 due to its good performance and small footprint. You can view the other models you can initialize here.

Just like that, we have a model that is almost ready to use. We just have to make one change. Currently, the model is set up to classify 1,000 classes, as we can see from its output layer:

> Linear(in_features=512, out_features=1000, bias=True)

Because we’re doing a binary classification task, we have to change the output features to two. To do so, all we need to do is replace the final linear layer of our ResNet with a new one.

model.fc = torch.nn.Linear(512, 2)

With this, our model will output raw scores for two classes (cat, dog). Now, we’re ready to train.


First, let’s load our training and testing data.

train_data, test_data = load_data("data/PetImages")

If you have a GPU or are on CoLab, let’s get our model onto the GPU. Training will be painfully long without this line.

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)

Now, we need to implement the gradient descent training algorithm. Gradient descent iterates through our dataset multiple times, feeding our training examples into our model to get predictions, measuring how far the predictions are from our actual labels through a loss value, and nudging our model parameters along the negative gradient of that loss so we hopefully do better on subsequent iterations.
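To see the mechanics in isolation, here is a toy sketch of gradient descent on a single parameter (plain SGD rather than the Adam optimizer we will use, and nothing to do with images): we minimize (w - 3)^2, whose optimum is w = 3.

```python
import torch

# one trainable parameter, starting at 0
w = torch.tensor(0.0, requires_grad=True)
opt = torch.optim.SGD([w], lr=0.1)

for _ in range(100):
    loss = (w - 3) ** 2
    opt.zero_grad()   # reset gradients
    loss.backward()   # compute d(loss)/dw
    opt.step()        # move w against the gradient
print(round(w.item(), 4))  # 3.0
```

Our training loop below does exactly this, just with millions of parameters and a loss computed from image batches.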

To run gradient descent, we will need a loss function to calculate the loss and an optimizer to update our model parameters. PyTorch provides both as ready-made classes.

loss_func = torch.nn.CrossEntropyLoss()
optim = torch.optim.Adam(model.parameters(), lr=1e-3)

For our loss, I’ve opted for cross entropy. PyTorch has multiple loss functions, but cross entropy is perfect for a classification task where the model outputs raw scores. A regression task (predicting a value on a continuous scale) would require a different loss function.
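As a quick sketch of how cross entropy treats raw scores (the numbers here are made up): confident correct predictions get a small loss, and hesitant ones get a larger loss.

```python
import torch

loss_func = torch.nn.CrossEntropyLoss()
labels = torch.tensor([0, 1])

confident = torch.tensor([[4.0, -4.0], [-4.0, 4.0]])  # strongly correct scores
unsure = torch.tensor([[0.1, -0.1], [-0.1, 0.1]])     # barely correct scores

# the confident predictions are penalized far less
print(loss_func(confident, labels).item() < loss_func(unsure, labels).item())  # True
```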

For our optimizer, I’m using the Adam class. It accepts one required argument, the model parameters to optimize, which you can obtain via the model.parameters() method. I also pass in a learning rate, which controls the magnitude of the change to our model parameters on each training iteration. Adjusting this can possibly net more performance, though I’ve found 0.001 to be good in my experience. Like the loss functions, PyTorch has a wide range of optimizers. I chose Adam because of its adaptive learning rate, which requires less overall tuning.

If we want to visualize our training loss and model performance on each dataset, we need to record our metrics for later use.

training_accuracies = []
testing_accuracies = []
losses = []

Now, onto the training loop:

for epoch in range(20):
    # set model in training mode
    model.train()
    for images, labels in train_data:
        # get data on GPU and make prediction
        images, labels = images.to(device), labels.to(device)
        preds = model(images)
        # calculate loss and training accuracy from prediction
        loss = loss_func(preds, labels)
        acc = (preds.argmax(1) == labels).float().mean()
        # calculate gradient and update model parameters
        optim.zero_grad()  # reset gradients in optimizer
        loss.backward()    # calculate new gradients for parameters
        optim.step()       # add gradients to parameters
        # record metrics
        losses.append(loss.item())
        training_accuracies.append(acc.item())
    # set model in test mode
    model.eval()
    with torch.no_grad():
        for images, labels in test_data:
            # get data on GPU and make prediction
            images, labels = images.to(device), labels.to(device)
            preds = model(images)
            # calculate and record accuracy
            acc = (preds.argmax(1) == labels).float().mean()
            testing_accuracies.append(acc.item())

This is a lot to go through, so let’s break this code block down.

Our training will iterate over 20 epochs (full passes through the dataset). Smaller datasets may need more epochs, while larger datasets may need fewer. This can be judged from the visualization we will implement later.

Within the for-loop, we first set our model in training mode. This is required because some layers, such as batch normalization, behave differently during training than during evaluation.

Then, we iterate through our entire training set, in batches thanks to the data loader. We feed each batch of images into our model to get predictions. We use the loss function we initialized earlier to calculate a loss from the predictions and training labels. We also calculate an accuracy using tensor broadcasting: the line (preds.argmax(1) == labels).float().mean() counts how many of our predictions matched the correct label, returning the fraction of correct predictions. Afterwards, we have the three optimizer lines, which abstract gradient calculation and model updating for us. Never forget these three lines when writing future training loops. Finally, in the train data loop, we record our metrics to display later.
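The accuracy line is easy to verify on toy tensors; here, with three made-up score pairs, the model gets the first two examples right and the last one wrong:

```python
import torch

# raw scores for 3 examples over the 2 classes
preds = torch.tensor([[0.9, 0.1], [0.2, 0.8], [0.6, 0.4]])
labels = torch.tensor([0, 1, 1])

# argmax picks the predicted class per row: [0, 1, 0]
# comparing to labels gives [True, True, False], whose mean is the accuracy
acc = (preds.argmax(1) == labels).float().mean()
print(round(acc.item(), 2))  # 0.67
```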

Then, after setting our model into testing mode, we iterate over our entire test set. We make the same predictions, but importantly, we don’t calculate a loss or update our model parameters, because we want to use this test set as an estimate of real-world performance. We then calculate and record our accuracy for later use.

Training will take a little bit on CoLab’s free plan, so feel free to do some chores while your model trains.

TIP: CoLab disconnects you and your code after an idle period, which is around 1 hour on the free plan. You can circumvent this by playing a long YouTube video in the background. This is especially useful when training on large datasets.

To visualize our metrics, we will use matplotlib.pyplot. The following block creates a plt subplot that displays our three metrics side by side.

import matplotlib.pyplot as plt

fig, (ax1, ax2, ax3) = plt.subplots(1, 3)
ax1.set_title("Training Loss")
ax1.set_ylim([0, 1])
ax1.plot(range(len(losses)), losses)
ax2.set_title("Training Accuracy")
ax2.set_ylim([0, 1])
ax2.plot(range(len(training_accuracies)), training_accuracies)
ax3.set_title("Test Accuracy")
ax3.set_ylim([0, 1])
ax3.plot(range(len(testing_accuracies)), testing_accuracies)

The graph will look a little like this.

Your graph will look a bit different than mine.

Your graphs will look more concave than mine, since I used image augmentations for my training. I was able to get around 90% accuracy, while you may get around 85% without augmentations.


This was a very fast overview of using PyTorch for classification. The ultimate takeaway from this article should be that PyTorch does a lot of the heavy lifting for you. With a little bit of code for data and some boilerplate for training, you can get a model up and running within an hour.

With that said, there are plenty of ways you can get better performance than what I covered. Here are a few off the top of my head:

  • Data augmentations in the training data: RandomHorizontalFlip, RandomCrop, ColorJitter to name a few

View my Colab notebook as a reference. Some of my code is different, but overall, I get to the same end result.

UT Austin CS Major - Spring 2022