Deploy an Image Classification Model Using Flask



Overview

  • Get an overview of PyTorch and Flask
  • Learn to build an image classification model in PyTorch
  • Learn how to deploy the model using Flask.

Introduction

Image classification is a pivotal pillar of the healthy functioning of social media. Classifying content on the basis of certain tags is often required by various laws and regulations, and it makes it possible to hide content from certain audiences.

While scrolling through my Instagram feed, I regularly come across images covered with a “Sensitive Content” warning. I am sure you have too. Images of a humanitarian crisis, terrorism, or violence are generally classified as ‘Sensitive Content’. I was always intrigued by how Instagram categorizes an image, and this curiosity pushed me to understand how image classification works.

Most of these images are detected by image classification models deployed by Instagram, supplemented by a community-based feedback loop. This is one of the most important use cases of image classification. In this article, we will deploy an image classification model that detects the category of images.

Table of Contents

  1. What is Model Deployment?
  2. Introduction to PyTorch
  3. What is Flask?
  4. Installing Flask and PyTorch on your Machine
  5. Understanding the Problem Statement
  6. Setup the Pre-Trained Image Classification Model
  7. Build an Image Scraper
  8. Create the Webpage
  9. Setup the Flask Project
  10. Working of the Deployed Model

What is Model Deployment?

In a typical machine learning or deep learning project, we usually start by defining the problem statement, followed by data collection and preparation, and then model building, right?

Once we have successfully built and trained the model, we want it to be available to end users. To do this, we have to “deploy” the model so that end users can make use of it. Model deployment is one of the later stages of any machine learning or deep learning project.

In this article, we will build a classification model in PyTorch and then learn how to deploy it using Flask. Before we get into the details, let us have a quick introduction to PyTorch.

 

Introduction to PyTorch

PyTorch is a Python-based library that provides a flexible platform for deep learning development. Its workflow is as close as you can get to that of NumPy, Python’s scientific computing library.


 

PyTorch is being widely used for building deep learning models. Here are some important advantages of PyTorch –

  • Easy-to-use API – The PyTorch API is as simple as Python can be.
  • Python support – PyTorch integrates smoothly with the Python data science stack.
  • Dynamic computation graphs – PyTorch builds the computational graph as the code runs, so the graph can change from one run to the next. This is valuable when the structure of the computation is not known in advance, for example when it depends on the length of the input or on ordinary Python control flow such as loops and if statements.
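To make the dynamic-graph point concrete, here is a minimal sketch. The module name RepeatNet is hypothetical and not part of this article's project: its forward pass is an ordinary Python loop whose length is decided at call time, and autograd records whichever operations actually ran.

```python
import torch
import torch.nn as nn


class RepeatNet(nn.Module):
    """Hypothetical toy model: applies the same linear layer a variable number of times."""

    def __init__(self, dim: int = 4):
        super().__init__()
        self.layer = nn.Linear(dim, dim)

    def forward(self, x: torch.Tensor, steps: int) -> torch.Tensor:
        # Plain Python control flow: the graph is rebuilt on every call,
        # so `steps` can differ from one forward pass to the next.
        for _ in range(steps):
            x = torch.relu(self.layer(x))
        return x


torch.manual_seed(0)
net = RepeatNet()
x = torch.randn(2, 4)

shallow = net(x, steps=1)
deep = net(x, steps=5)

# Gradients flow back through however many steps were actually executed.
deep.sum().backward()
print(shallow.shape, deep.shape)          # torch.Size([2, 4]) torch.Size([2, 4])
print(net.layer.weight.grad is not None)  # True
```

A static-graph framework would need the loop length fixed when the graph is defined; here it is just a function argument.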

In the following sections, we will use a pre-trained PyTorch model to detect the category of an image, and then we will use Flask to deploy it. First, let us briefly discuss Flask.

What is Flask?

Flask is a web application framework written in Python. It has multiple modules that make it easier for a web developer to write applications without having to worry about details such as protocol management and thread management.

Flask offers a variety of choices for developing web applications and provides the tools and libraries we need to build one.
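As a minimal sketch of what Flask handles for us, the app below maps one URL to one Python function. Flask's built-in test client lets us call it without starting a server; the route and the message are illustrative only.

```python
from flask import Flask

app = Flask(__name__)


@app.route("/")
def index():
    # Flask turns the returned string into an HTTP 200 response with an HTML body.
    return "Hello from Flask!"


# The test client sends the request through the app in-process; no server is needed.
with app.test_client() as client:
    response = client.get("/")
    print(response.status_code)             # 200
    print(response.get_data(as_text=True))  # Hello from Flask!
```

Calling app.run() instead starts the development server on localhost:5000, which is how the finished project is run at the end of this article.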


 

Installing Flask and PyTorch on your Machine

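Installing Flask is simple and straightforward. Here, I am assuming you already have Python 3 and pip installed. To install Flask, run the following command (on Debian or Ubuntu, sudo apt-get install python3-flask also works):

```shell
pip install flask
```

Next, we need to install PyTorch. You do not need a GPU to run the code in this article. The image scraper also uses requests and BeautifulSoup, so we install them at the same time (in a Jupyter notebook, prefix the command with !):

```shell
pip install torch torchvision requests beautifulsoup4
```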

That’s it! Now let us take up a problem statement and build a model.

Understanding the Problem Statement

Let us discuss the problem statement. We want to create a web page that contains a text box like this (shown below), into which users enter a URL. Our task is to scrape all the images from that URL, predict the category (class) of each image using an image classification model, and render the images with their categories on the webpage.


Here is the workflow for the end-to-end model:

Setting up the Project Workflow

  1. Model Building: We will use the pre-trained DenseNet-121 model, available in PyTorch's torchvision library, to predict the image class. Our focus is not on building a highly accurate classification model from scratch but on deploying a model and using it through a web interface.
  2. Create an Image Scraper: We will create a web scraper using the requests and BeautifulSoup libraries. It will download all the images from a URL and store them so that we can make predictions on them.
  3. Design Webpage Template: We will design a user interface where the user can submit a URL and view the results once they are calculated.
  4. Classify images and send results: Once we receive the user's query, we will use the model to predict the classes of the images and send the results back to the user.

Here is a representation of the steps we just saw:


Let’s discuss all the required components of the project:

Setup the Pre-Trained Image Classification Model

We will use the pre-trained DenseNet-121 model to classify the images. If you want to build an image classification model yourself, I highly recommend going through this article: Build your First Image Classification Model in just 10 Minutes!

You can download the complete code and dataset here.

Let’s start by importing the required libraries and loading the densenet121 model from the torchvision library. Make sure to set the parameter pretrained to True.

```python
# importing the required libraries
import glob
import io
import json

import torchvision.transforms as transforms
from PIL import Image
from torchvision import models

# Pass pretrained=True to load the ImageNet weights
# (torchvision 0.13+ prefers weights=models.DenseNet121_Weights.DEFAULT).
model = models.densenet121(pretrained=True)
# switch the model to evaluation mode (disables dropout, uses stored batch-norm statistics)
model.eval()
```

Next, we define a function that preprocesses an image. It takes the raw image bytes, applies a pipeline of transforms (resize, center-crop to 224×224, convert to a tensor, and normalize with the ImageNet channel means and standard deviations), and returns the result as a tensor holding a batch of one image. This code is adapted from the PyTorch documentation.

```python
# define the function to pre-process the image
def transform_image(image_bytes):
    my_transforms = transforms.Compose([
        transforms.Resize(255),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        # ImageNet channel means and standard deviations
        transforms.Normalize([0.485, 0.456, 0.406],
                             [0.229, 0.224, 0.225]),
    ])
    # convert to RGB so grayscale or RGBA images also have 3 channels
    image = Image.open(io.BytesIO(image_bytes)).convert('RGB')
    # add a batch dimension: (3, 224, 224) -> (1, 3, 224, 224)
    return my_transforms(image).unsqueeze(0)
```

The model outputs a score for each of the 1,000 ImageNet classes, and we take the index of the highest score as the predicted class. To turn that index into a readable name, PyTorch provides a mapping file, imagenet_class_index.json, that covers all 1,000 categories. You can download the mapping here.

```python
# load the class-index mapping provided by PyTorch
with open('imagenet_class_index.json') as f:
    imagenet_class_mapping = json.load(f)
```

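Here is a sample of the mapping: each key is a class index stored as a string, and each value is a [WordNet ID, class name] pair, for example "0": ["n01440764", "tench"].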
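The lookup can be tried on a hand-copied excerpt of the file (only three of the 1,000 entries are reproduced here). The predicted index is converted to a string key, and element [1] of the value is the readable class name.

```python
import json

# Excerpt of imagenet_class_index.json: keys are class indices stored as strings,
# values are [WordNet ID, human-readable label] pairs.
excerpt = """{
    "0": ["n01440764", "tench"],
    "1": ["n01443537", "goldfish"],
    "2": ["n01484850", "great_white_shark"]
}"""
class_mapping = json.loads(excerpt)

predicted_idx = 1  # e.g. the value returned by category.item()
# JSON object keys are always strings, so the integer index must be converted.
wordnet_id, label = class_mapping[str(predicted_idx)]
print(wordnet_id, label)  # n01443537 goldfish
```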

Next, we will define a function that returns the category of an image, taking the path of the image as its only parameter. It reads the image in binary form, transforms it, passes the transformed image to the model to get the predicted class index, and uses the mapping to return the class name.

```python
import torch

# define the function to get the predicted class of an image
# it takes the image path and returns the [WordNet ID, class name] pair
def get_category(image_path):
    # read the image in binary form
    with open(image_path, 'rb') as file:
        image_bytes = file.read()
    # transform the image
    transformed_image = transform_image(image_bytes=image_bytes)
    # use the model to predict the class; no_grad skips gradient tracking during inference
    with torch.no_grad():
        outputs = model(transformed_image)
    # index of the highest score along the class dimension
    _, category = outputs.max(1)
    # JSON keys are strings, so convert the index before the lookup
    predicted_idx = str(category.item())
    return imagenet_class_mapping[predicted_idx]
```
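outputs.max(1) keeps only the single best class. If you also want a confidence score, one common extension (not part of this article's code) is to apply softmax to the model's scores and take the top k. In the sketch below, a random tensor stands in for real model output, so the indices it prints are meaningless; only the mechanics matter.

```python
import torch

torch.manual_seed(0)
# Stand-in for model(transformed_image): one image, 1000 ImageNet scores (logits).
logits = torch.randn(1, 1000)

# softmax turns the logits into probabilities that sum to 1 across the 1000 classes.
probabilities = torch.softmax(logits, dim=1)

# topk returns the k largest probabilities and their class indices, sorted descending.
top_probs, top_indices = probabilities.topk(3, dim=1)

for prob, idx in zip(top_probs[0].tolist(), top_indices[0].tolist()):
    # str(idx) would be the key into imagenet_class_mapping.
    print(f"class {idx}: {prob:.4f}")

# The top-1 index always agrees with the argmax that outputs.max(1) returns.
assert top_indices[0, 0].item() == logits.argmax(dim=1).item()
```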

Let’s try this function on a few images:

```python
get_category(image_path='static/sample_1.jpeg')
## ['n02089973', 'English_foxhound']

get_category(image_path='static/sample_2.jpeg')
## ['n11939491', 'daisy']
```

Now our model is ready to predict the classes of images. Let’s start building the image scraper.

 

Build an Image Scraper

In this section, we will build a web scraper that downloads the images from the URL provided. We will use requests to fetch the page and the images, and BeautifulSoup to parse the HTML and find the image tags. You are free to use any other library or an API that gives you the images.

We will start by importing the required libraries. For each URL we scrape, a new directory is created to store its images. The function get_path returns the path of the folder for a given URL.

```python
# importing required libraries
import os
import time

import requests
from bs4 import BeautifulSoup

# build the folder path for a URL, e.g. static/URL_https:__example.com_page
def get_path(url):
    return "static/URL_" + str(url.replace("/", "_"))

# send a browser-like User-Agent, since some sites block the default one sent by requests
headers = {
    'User-Agent': "Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/77.0.3865.90 Safari/537.36"
}
```
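One caveat with get_path: replacing only "/" leaves characters such as ":" and "?" in the folder name, which Windows does not allow, and very long URLs can exceed path-length limits. A possible alternative, sketched below under the hypothetical name safe_path, keeps only letters, digits, ".", "_" and "-", truncates the result, and appends a short hash so that different URLs still map to different folders.

```python
import hashlib
import re


def safe_path(url: str, root: str = "static", max_len: int = 60) -> str:
    """Hypothetical replacement for get_path that yields portable folder names."""
    # Replace every character outside [A-Za-z0-9._-] with an underscore, then truncate.
    slug = re.sub(r"[^A-Za-z0-9._-]", "_", url)[:max_len]
    # A short hash of the full URL keeps truncated slugs unique.
    digest = hashlib.sha1(url.encode("utf-8")).hexdigest()[:8]
    return f"{root}/URL_{slug}_{digest}"


print(safe_path("https://example.com/a?b=1"))
```

The readable slug still tells you which page a folder belongs to, while the hash guarantees that two long URLs sharing the same first 60 characters do not overwrite each other's images.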

Now, we will define a function get_images. It first creates the directory returned by get_path, then requests the HTML source of the page and extracts the src attribute of every “img” tag.

After this, we select only the JPEG images. You can also add PNG images; I have filtered them out because most PNG pictures on a page are logos. Finally, we download each image and save it in the directory, naming the files with a running counter (1.jpeg, 2.jpeg, and so on).

```python
from urllib.parse import urljoin

# define the function to scrape images and store them in a directory
def get_images(url):
    # get the directory path and create it (and static/) if it does not exist yet
    path = get_path(url)
    os.makedirs(path, exist_ok=True)
    # request the source code from the URL
    response = requests.get(url, headers=headers)
    # parse the HTML with Beautiful Soup
    data = BeautifulSoup(response.text, 'html.parser')
    # find all image tags that have a src attribute
    images = data.find_all('img', src=True)
    # extract the source, resolving relative paths against the page URL
    image_src = [urljoin(url, x['src']) for x in images]
    # select only JPEG images (.jpeg and .jpg extensions)
    image_src = [x for x in image_src if x.lower().endswith(('.jpeg', '.jpg'))]
    image_count = 1
    # download each image and store it in the specified directory
    for image in image_src:
        print(image)
        image_file_name = path + '/' + str(image_count) + '.jpeg'
        print(image_file_name)
        res = requests.get(image, headers=headers)
        # open the file in write-binary mode and save the image content
        with open(image_file_name, 'wb') as f:
            f.write(res.content)
        image_count = image_count + 1
```
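Two details matter when extracting image URLs from a page: many pages use relative src values such as /img/cat.jpg, which requests.get cannot fetch on their own, and JPEG files often end in .jpg rather than .jpeg or carry a query string. The sketch below shows one way to handle both. To stay dependency-free it swaps BeautifulSoup for the standard library's html.parser, and the names ImgSrcCollector and jpeg_urls are hypothetical.

```python
from html.parser import HTMLParser
from urllib.parse import urljoin, urlparse


class ImgSrcCollector(HTMLParser):
    """Hypothetical helper: collects the src attribute of every <img> tag."""

    def __init__(self):
        super().__init__()
        self.sources = []

    def handle_starttag(self, tag, attrs):
        if tag == "img":
            src = dict(attrs).get("src")
            if src:
                self.sources.append(src)


def jpeg_urls(page_url: str, html: str) -> list:
    collector = ImgSrcCollector()
    collector.feed(html)
    urls = []
    for src in collector.sources:
        absolute = urljoin(page_url, src)  # resolve relative paths against the page URL
        # Check the extension on the path only, ignoring any ?query or #fragment.
        if urlparse(absolute).path.lower().endswith((".jpeg", ".jpg")):
            urls.append(absolute)
    return urls


sample = """
<img src="/media/fox.jpg?w=800">
<img src="https://cdn.example.com/daisy.JPEG">
<img src="logo.png">
<img alt="no source">
"""
print(jpeg_urls("https://example.com/blog/post", sample))
# ['https://example.com/media/fox.jpg?w=800', 'https://cdn.example.com/daisy.JPEG']
```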

Let’s try out the scraper that we have just created!

```python
get_images('https://medium.com/@allanishac/9-wild-animals-that-would-make-a-much-better-president-than-donald-trump-b41f960bb171')
```

This creates a new directory under static/, and all the downloaded images are now in a single place.

Note: Use this image scraper for learning purposes only. Always follow the target website's robots.txt file, also known as the Robots Exclusion Protocol, which tells web robots which pages they should not crawl.
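Python's standard library can perform this check for you. The sketch below parses an example robots.txt supplied inline (the rules and URLs are made up); against a live site you would instead call set_url() with the site's robots.txt address and then read() to download it.

```python
from urllib.robotparser import RobotFileParser

# Made-up robots.txt content for illustration.
rules = """
User-agent: *
Disallow: /private/
""".splitlines()

parser = RobotFileParser()
parser.parse(rules)

# can_fetch(user_agent, url) applies the rules that match the user agent.
print(parser.can_fetch("*", "https://example.com/blog/post"))     # True
print(parser.can_fetch("*", "https://example.com/private/page"))  # False
```

Calling can_fetch before each requests.get in the scraper, and skipping URLs for which it returns False, is one straightforward way to honor these rules.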

Create the Webpage

We will create two webpages, “home.html” and “image_class.html”:

  1. “home.html” is the default page; it has a text box in which the user can type a URL.
  2. “image_class.html” renders the images together with their categories.

1. home.html

In home.html, we wrap the search box in a form tag so that its contents are sent to the server. The form uses the POST method, and the text input inside it is given the name “search”.

 

This way, the backend receives the submitted URL under the name “search”. The backend then processes that data and sends back the results.
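Here is a minimal sketch of that round trip. The template string is a stand-in for home.html (its markup is an assumption, not the article's file), and Flask's test client posts a value under the name search, which the view reads back from request.form. Apart from the field name search, everything in it is illustrative.

```python
from flask import Flask, render_template_string, request

app = Flask(__name__)

# Stand-in for templates/home.html: a form that POSTs one text field named "search".
HOME_HTML = """
<form method="post">
  <input type="text" name="search" placeholder="Paste a URL">
  <button type="submit">Search</button>
</form>
"""


@app.route("/", methods=["GET", "POST"])
def home():
    if request.method == "POST":
        # The key matches the input's name attribute.
        url = request.form["search"]
        return f"Received: {url}"
    return render_template_string(HOME_HTML)


with app.test_client() as client:
    page = client.get("/").get_data(as_text=True)
    reply = client.post("/", data={"search": "https://example.com"}).get_data(as_text=True)

print(reply)  # Received: https://example.com
```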

2. image_class.html

Once the results are calculated, a second page, “image_class.html”, is rendered with them. This page is regenerated on every query and shows the following information:

  1. Image category
  2. Image
  3. Frequency count of each image category

Here is the code that generates this page:

```python
# define function to add the image to the html with its class name
def get_picture_html(path, tag):
    image_html = """<p> {tag_name} </p> <picture> <img src= "../{path_name}" height="300" width="400"> </picture>"""
    return image_html.format(tag_name=tag, path_name=path)

# define function to add a list element to the html
def get_count_html(category, count):
    count_html = """<li> {category_name} : {count_} </li>"""
    return count_html.format(category_name=category, count_=count)

# function to calculate the value count of each class
def get_value_count(image_class_dict):
    count_dic = {}
    for category in image_class_dict.values():
        if category in count_dic:
            count_dic[category] = count_dic[category] + 1
        else:
            count_dic[category] = 1
    return count_dic

# function to generate the html file from the image_class dictionary
# keys are the image paths and values are the classes associated with them
def generate_html(image_class_dict):
    picture_html = ""
    count_html = ""
    # loop through the keys and add each image to the html
    for image in image_class_dict.keys():
        picture_html += get_picture_html(path=image, tag=image_class_dict[image])
    value_counts = get_value_count(image_class_dict)
    # loop through the value_counts and add the count of each class to the html
    for value in value_counts.keys():
        count_html += get_count_html(value, value_counts[value])
    # write the page that Flask renders as image_class.html
    # (minimal page layout assumed here; the full template markup is not shown)
    with open('templates/image_class.html', 'w') as file:
        file.write("<html><body><ul>" + count_html + "</ul>" + picture_html + "</body></html>")
```
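get_value_count re-implements what collections.Counter already does. A shorter variant, sketched below with made-up sample data, counts the categories in one call and escapes the labels with html.escape before inserting them into markup, a cheap safeguard for any string that ends up in a rendered page.

```python
import html
from collections import Counter

# Made-up example of the dictionary built by get_prediction: image path -> class name.
image_class_dict = {
    "static/URL_example/1.jpeg": "English_foxhound",
    "static/URL_example/2.jpeg": "daisy",
    "static/URL_example/3.jpeg": "daisy",
}

# Counter maps each class name to the number of images that received it.
value_counts = Counter(image_class_dict.values())

# most_common() orders the categories from most to least frequent.
count_html = "".join(
    f"<li> {html.escape(category)} : {count} </li>"
    for category, count in value_counts.most_common()
)
print(count_html)
# <li> daisy : 2 </li><li> English_foxhound : 1 </li>
```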

The next step is to set up the Flask project that combines these individual pieces to solve the challenge.

Setup the Flask Project

So far, we have completed the following parts of the project:

  1. An image classification model that works and can classify images.
  2. An image scraper that downloads the images and stores them.
  3. The webpages that collect the input and return the results.

And now we need to connect all these files together so that we can have a working project.

Let’s have a look at the directory structure.

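Note: Make sure that you save the images in a folder named static and the HTML files in a folder named templates. These are the default folder names Flask looks in for static files and templates; if you rename them, you must pass the new names through the static_folder and template_folder arguments of the Flask constructor, otherwise you will get errors.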

Running a Flask Application

The Flask application first renders the home.html file. Whenever someone submits a URL for image classification, Flask detects the POST request and calls the get_image_class function.

This function works in the following steps:

  1. First, it downloads the images from the URL and stores them.
  2. Next, it passes the directory path to get_prediction (from get_prediction.py), which classifies the images and returns the results in the form of a dictionary.
  3. Finally, it passes this dictionary to generate_html (from generate_html.py), which writes the output page that is sent back to the user.

```python
# importing the required libraries
import json
import os

from flask import Flask, render_template, request, redirect, url_for
from torchvision import models

from get_images import get_images, get_path
from get_prediction import get_prediction
from generate_html import generate_html

app = Flask(__name__)
# load the ImageNet class-index mapping
with open('imagenet_class_index.json') as f:
    imagenet_class_mapping = json.load(f)
# use the pre-trained model in evaluation mode
model = models.densenet121(pretrained=True)
model.eval()

# download the images from the URL, predict their classes and build the results page
def get_image_class(url):
    # get images from the URL and store them in a folder
    get_images(url)
    # predict the class of every image in that folder
    path = get_path(url)
    images_with_tags = get_prediction(model, imagenet_class_mapping, path)
    # generate the html file to render once the classes are predicted
    generate_html(images_with_tags)
```

Once the above steps are done, we are ready to serve the results to the user. After processing a submitted URL, the app redirects to the success route, which renders the image_class.html file.

```python
# render "home.html" by default; on POST, process the submitted URL
@app.route('/', methods=['GET', 'POST'])
def home():
    if request.method == 'POST':
        url = request.form['search']
        # if the search button was hit, call the function get_image_class
        get_image_class(url)
        # redirect to the results page, labelled with this URL's folder name
        return redirect(url_for('success', name=os.path.basename(get_path(url))))
    return render_template('home.html')

@app.route('/success/<name>')
def success(name):
    # render the page written by generate_html
    return render_template('image_class.html')

if __name__ == '__main__':
    app.run(debug=True)
```

Get Predictions for All Images of the Source URL

So far, we have predicted the class of one image at a time. To classify every image the scraper downloads, we modify the get_category function so that it takes the model and the class mapping as parameters, and we call it on each file in the directory that contains the images.

We then define another function, get_prediction, which uses get_category and returns a dictionary whose keys are the image paths and whose values are the predicted class names.

Later, we pass this dictionary to generate_html (in generate_html.py), which creates the HTML file for us.

```python
# get_prediction.py
# transform_image, defined earlier, must also be in this file (with its imports)
import glob

import torch

# get the class of a single image, given the model and the class mapping
def get_category(model, imagenet_class_mapping, image_path):
    with open(image_path, 'rb') as file:
        image_bytes = file.read()
    transformed_image = transform_image(image_bytes=image_bytes)
    with torch.no_grad():
        outputs = model(transformed_image)
    _, category = outputs.max(1)
    predicted_idx = str(category.item())
    return imagenet_class_mapping[predicted_idx]

# create a dictionary of image path -> predicted class name;
# we will use that dictionary to generate the html file
def get_prediction(model, imagenet_class_mapping, path_to_directory):
    files = glob.glob(path_to_directory + '/*')
    image_with_tags = {}
    for image_file in files:
        # [1] picks the readable name from the [WordNet ID, name] pair
        image_with_tags[image_file] = get_category(model, imagenet_class_mapping, image_path=image_file)[1]
    return image_with_tags
```

Now all the code files are ready, and we just need to connect them in the main Flask file.

In that file, we first create an object of the Flask class, passing the name of the current module, __name__, as an argument. The route decorator then tells the Flask application which function handles each URL.

Working of the Deployed Model

You can download the complete code and dataset here. 

Now, run the file get_class.py (for example, with python get_class.py), and the Flask development server will start on localhost:5000.

Open a web browser and go to localhost:5000 to see the default home page. Type any URL into the text box and press the search button. Processing might take 20 to 30 seconds, depending on the number of images at that URL and your Internet speed.

Let’s check out the working of the deployed model.


End Notes

In this article, I briefly explained the concepts of model deployment, PyTorch, and Flask. Then we went through the steps involved in creating an image classification model using PyTorch and deploying it with Flask. I hope this helps you build and deploy your own image classification model.

Note that the model was deployed on localhost. It can also be deployed on cloud services such as Google Cloud or Amazon Web Services; we will cover this in an upcoming article.
