Mobile Media Apps Assignment: How to Deploy a PyTorch Image Classification Model to Flask For Beginners (Part 1)

Requirements

Data

Some image examples from our Cat vs Dog dataset.
!wget https://download.microsoft.com/download/3/E/1/3E1C3F21-ECDB-4869-8368-6DEBA77B919F/kagglecatsanddogs_3367a.zip
!unzip kagglecatsanddogs_3367a.zip -d data

If you’re unfamiliar with the Colab block above, here’s a short explanation of what it does. The ! in front of these two lines signals that they are shell commands in Colab. wget is a shell command that downloads the content at a remote location directly into the shell environment, and unzip is self-explanatory: it extracts the contents of the zip archive, with the -d flag giving the directory to put the extracted contents in.
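If you want to verify that the download and extraction worked, you can list the data directory straight from Colab (a quick check of my own, not part of the original commands):

!ls data
> MSR-LA - 3467.docx  PetImages  readme[1].txt

The full layout looks like this: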

data
|
|__PetImages
| |
| |__Cat (directory of cat images)
| |
| |__Dog (directory of dog images)
|
|__MSR-LA - 3467.docx
|
|__readme[1].txt
import os
from os import path
import PIL.Image

rm_count, count = 0, 0
for dirname in ("Cat", "Dog"):
    dest = path.join("data/PetImages", dirname)
    for fname in os.listdir(dest):
        full = path.join(dest, fname)
        # drop anything that isn't a jpg file
        if "jpg" not in fname:
            os.remove(full)
            rm_count += 1
            continue
        # drop jpg files that PIL cannot actually decode
        try:
            PIL.Image.open(full).convert("RGB")
        except PIL.UnidentifiedImageError:
            os.remove(full)
            rm_count += 1
            continue
        count += 1
print("Kept %d files, removed %d files" % (count, rm_count))

The details of this code aren’t too important: it keeps the JPEG files that can be opened without error by PIL, the image library we’ll be using, and deletes everything else. Run it once and your data is good to go.

from torch.utils.data import Dataset
from torchvision import transforms
import PIL.Image
import os

class CatDogDataset(Dataset):
    def __init__(self, ...):
        pass

    def __len__(self):
        pass

    def __getitem__(self, idx):
        pass

Our Dataset Class Skeleton

def __init__(self, datapath, custom_transforms = []):
    # remember the data path and list the image filenames per class
    self.datapath = datapath
    self.cat_fnames = os.listdir( os.path.join(datapath, "Cat") )
    self.dog_fnames = os.listdir( os.path.join(datapath, "Dog") )
    # resize and crop every image to 96x96, then convert it to a tensor
    self.transforms = transforms.Compose([
        transforms.Resize(96),
        transforms.CenterCrop(96),
        *custom_transforms,
        transforms.ToTensor()
    ])

Constructor implementation.

def __len__(self):
    return len(self.cat_fnames) + len(self.dog_fnames)

Length method implementation

def __getitem__(self, idx):
    # lower indices get cats while higher indices get dogs
    is_cat = idx < len(self.cat_fnames)
    if is_cat:
        fname = self.cat_fnames[idx]
        impath = os.path.join(self.datapath, "Cat", fname)
    else:
        fname = self.dog_fnames[idx - len(self.cat_fnames)]
        impath = os.path.join(self.datapath, "Dog", fname)
    # open the image and apply the transforms
    image = PIL.Image.open(impath).convert("RGB")
    as_tensor = self.transforms(image)
    # return the image as a tensor and a binary label for cat (0) / dog (1)
    return as_tensor, int(not is_cat)

Get Item method implementation
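Before wiring this into a DataLoader, a quick sanity check (my own addition, not from the original tutorial) confirms that indexing works and the transforms produce the expected shape:

dataset = CatDogDataset("data/PetImages")
image, label = dataset[0]
print(image.shape, label)
> torch.Size([3, 96, 96]) 0

Index 0 falls in the cat range, so the label is 0.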

from torch.utils.data import DataLoader

dataset = CatDogDataset("data/PetImages")
data = DataLoader(dataset, batch_size=128, shuffle=True)

The DataLoader constructor accepts more parameters, but this should be enough for us.
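To see what the loader actually yields, you can pull a single batch (again a quick check of my own):

images, labels = next(iter(data))
print(images.shape, labels.shape)
> torch.Size([128, 3, 96, 96]) torch.Size([128])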

from torch.utils.data import DataLoader, random_split

def load_data(datapath):
    dataset = CatDogDataset(datapath)
    # Pass the dataset and the split sizes to random_split to get our
    # subsets in a tuple; the sizes must sum to len(dataset).
    train_set, test_set = random_split(dataset, (19998, 5000))
    train_data, test_data = (
        DataLoader(train_set, batch_size=128, shuffle=True),
        DataLoader(test_set, batch_size=128, shuffle=True)
    )
    return train_data, test_data

This function creates our training / testing split from a data path. We’ll use this later.

The Model

import torch
import torchvision
model = torchvision.models.resnet18(pretrained=False)

You can pass True to the pretrained parameter to download model weights trained on ImageNet. Fine-tuning those weights on our Cat / Dog dataset could give better performance, but I’ve opted to train from scratch for this tutorial.
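If you do want to try transfer learning, a minimal sketch would look like the following; note that freezing the backbone is my own suggestion, not something this tutorial does:

model = torchvision.models.resnet18(pretrained=True) # download ImageNet weights
for param in model.parameters():
    param.requires_grad = False # freeze the pretrained backbone
# The replacement fc layer we attach below is trainable by default,
# so only the classification head would be fine-tuned.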

ResNet-18’s final fully connected layer maps 512 features to the 1,000 ImageNet classes:

print(model.fc)
> Linear(in_features=512, out_features=1000, bias=True)

We only have two classes, so we replace it with a two-output head (cat and dog):

model.fc = torch.nn.Linear(512, 2)
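Printing the head again confirms the swap:

print(model.fc)
> Linear(in_features=512, out_features=2, bias=True)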

Training

train_data, test_data = load_data("data/PetImages")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)
loss_func = torch.nn.CrossEntropyLoss().to(device)
optim = torch.optim.Adam(model.parameters(), lr=1e-3)
train_accuracies = []
test_accuracies = []
losses = []
for epoch in range(20):
    # set the model in training mode
    model.train()
    for images, labels in train_data:
        # move the batch to the GPU and make a prediction
        images, labels = images.to(device), labels.to(device)
        preds = model(images)
        # calculate loss and training accuracy from the prediction
        loss = loss_func(preds, labels)
        acc = (preds.argmax(1) == labels).float().mean()
        # calculate gradients and update the model parameters
        optim.zero_grad() # reset gradients in the optimizer
        loss.backward()   # calculate new gradients for the parameters
        optim.step()      # apply the gradient update to the parameters
        # record metrics
        train_accuracies.append(acc.item())
        losses.append(loss.item())
    # set the model in evaluation mode
    model.eval()
    with torch.no_grad(): # no gradients needed for evaluation
        for images, labels in test_data:
            # move the batch to the GPU and make a prediction
            images, labels = images.to(device), labels.to(device)
            preds = model(images)
            # calculate and record accuracy
            acc = (preds.argmax(1) == labels).float().mean()
            test_accuracies.append(acc.item())

TIP: Colab disconnects idle sessions after a period of time, which is around 1 hour on the free plan, taking your running code with them. You can circumvent this by playing a long YouTube video in a background tab. This is especially useful when training on large datasets.
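Once training finishes, it’s worth saving the learned parameters to disk, both so a disconnect doesn’t cost you the run and so the Flask app in Part 2 can load the model without retraining. A minimal sketch (the filename is my own choice):

torch.save(model.state_dict(), "catdog_resnet18.pth")
# later, e.g. in the Flask app:
# model.load_state_dict(torch.load("catdog_resnet18.pth"))

In Colab, remember to download the file (or copy it to Google Drive) before the runtime is recycled.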

import matplotlib.pyplot as plt

fig, (ax1, ax2, ax3) = plt.subplots(1, 3)
fig.suptitle("Visualization")
ax1.set_title("Training Loss")
ax1.set_ylim([0, 1])
ax1.plot(range(len(losses)), losses)
ax2.set_title("Training Accuracy")
ax2.set_ylim([0, 1])
ax2.plot(range(len(train_accuracies)), train_accuracies)
ax3.set_title("Test Accuracy")
ax3.set_ylim([0, 1])
ax3.plot(range(len(test_accuracies)), test_accuracies)
plt.show()
Your graph will look a bit different from mine.
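Since the metrics were recorded per batch, the curves can be quite noisy. If you want smoother trends, a simple moving average helps; here’s a sketch using numpy (my own addition, not part of the original code):

import numpy as np

def smooth(values, window=50):
    # simple moving average over a fixed window
    kernel = np.ones(window) / window
    return np.convolve(values, kernel, mode="valid")

# e.g. plot smooth(losses) instead of losses above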

Conclusion

In this part, we downloaded and cleaned the Cats vs Dogs dataset, wrapped it in a custom PyTorch Dataset, trained a ResNet-18 classifier, and visualized the training metrics. In Part 2, we’ll take the trained model and serve it from a Flask app.
