
You're reading from Mastering PyTorch

Product type: Book
Published in: Feb 2021
Reading level: Intermediate
Publisher: Packt
ISBN-13: 9781789614381
Edition: 1st
Author: Ashish Ranjan Jha

Ashish Ranjan Jha received his bachelor's degree in electrical engineering from IIT Roorkee (India), a master's degree in computer science from EPFL (Switzerland), and an MBA from Quantic School of Business (Washington). He received a distinction in all three of his degrees. He has worked for large technology companies, including Oracle and Sony, as well as more recent tech unicorns such as Revolut, mostly focusing on artificial intelligence. He currently works as a machine learning engineer. Ashish has worked on a range of products and projects, from developing an app that uses sensor data to predict the mode of transport to detecting fraud in car damage insurance claims. Besides being an author, machine learning engineer, and data scientist, he blogs frequently on his personal site about the latest research and engineering topics in machine learning.

Chapter 13: PyTorch and Explainable AI

Throughout this book, we have built several deep learning models that perform different kinds of tasks for us, such as a handwritten digit classifier, an image-caption generator, and a sentiment classifier. Although we have mastered how to train and evaluate these models using PyTorch, we do not know precisely what is happening inside these models while they make predictions. Model interpretability, or explainability, is the field of machine learning in which we aim to answer the question: why did the model make that prediction? Put another way, what did the model see in the input data that led it to that particular prediction?

In this chapter, we will use the handwritten digit classification model from Chapter 1, Overview of Deep Learning Using PyTorch, to understand its inner workings and thereby explain why the model makes a certain prediction for a given input. We will first dissect the model using only PyTorch code. Then, we will...

Technical requirements

We will be using Jupyter notebooks for all of our exercises. The following is a list of Python libraries that should be installed for this chapter using pip. For example, run pip install torch==1.4.0 on the command line:

jupyter==1.0.0
torch==1.4.0
torchvision==0.5.0
matplotlib==3.1.2
captum==0.2.0

All code files relevant to this chapter are available at https://github.com/PacktPublishing/Mastering-PyTorch/tree/master/Chapter13.

Model interpretability in PyTorch

In this section, we will dissect a trained handwritten digits classification model using PyTorch, in the form of an exercise. More precisely, we will examine the convolutional layers of the trained model to understand what visual features it learns from handwritten digit images. We will look at the convolutional filters (kernels), along with the feature maps those filters produce.

Such details will help us understand how the model processes input images and, therefore, makes predictions. The full code for the exercise can be found at https://github.com/PacktPublishing/Mastering-PyTorch/blob/master/Chapter13/pytorch_interpretability.ipynb.
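As a plain-PyTorch sketch of this idea, the snippet below uses a small stand-in CNN (the real architecture is in the book's notebook linked above, so the layer sizes here are illustrative assumptions) to read out the first convolutional layer's filters directly from its weights, and to capture that layer's feature maps with a forward hook:

```python
import torch
import torch.nn as nn

# Minimal stand-in for the MNIST digit classifier; illustrative only.
class ConvNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.cn1 = nn.Conv2d(1, 16, 3, 1)
        self.cn2 = nn.Conv2d(16, 32, 3, 1)
        self.fc = nn.Linear(32 * 24 * 24, 10)

    def forward(self, x):
        x = torch.relu(self.cn1(x))
        x = torch.relu(self.cn2(x))
        return self.fc(x.flatten(1))

model = ConvNet().eval()

# 1) Inspect the learned filters: shape is (out_channels, in_channels, kH, kW),
#    so each of the 16 filters can be plotted as a tiny 3x3 image.
filters = model.cn1.weight.data
print(filters.shape)  # torch.Size([16, 1, 3, 3])

# 2) Capture the feature maps the filters produce via a forward hook.
feature_maps = {}
def hook(module, inputs, output):
    feature_maps['cn1'] = output.detach()

handle = model.cn1.register_forward_hook(hook)
with torch.no_grad():
    model(torch.randn(1, 1, 28, 28))  # random tensor standing in for a digit image
handle.remove()

# One 26x26 activation map per filter, ready to visualize with matplotlib.
print(feature_maps['cn1'].shape)  # torch.Size([1, 16, 26, 26])
```

In the actual exercise, the same two ingredients (layer weights and hooked activations) are plotted with matplotlib to see what each filter responds to.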

Training the handwritten digits classifier – a recap

We will quickly revisit the steps involved in training the handwritten digits classification model, as follows:

  1. First, we import the relevant libraries, and then...
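The recap above can be sketched as a single training step; the random batch and the tiny network here are stand-ins for the MNIST data and the actual Chapter 1 model, which live in the book's notebook:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# Illustrative recap only -- the full training code (with torchvision's MNIST
# dataset and the exact architecture) is in the book's notebook.
model = nn.Sequential(
    nn.Conv2d(1, 16, 3), nn.ReLU(),
    nn.Flatten(),
    nn.Linear(16 * 26 * 26, 10),
)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# One training step on a dummy batch standing in for MNIST digits.
images = torch.randn(8, 1, 28, 28)
labels = torch.randint(0, 10, (8,))

optimizer.zero_grad()
loss = F.cross_entropy(model(images), labels)
loss.backward()
optimizer.step()
print(loss.item())  # roughly ln(10) ~ 2.3 at random initialization
```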

Using Captum to interpret models

Captum (https://captum.ai/) is an open source model interpretability library built by Facebook on top of PyTorch and, at the time of writing, under active development. In this section, we will use the handwritten digit classification model that we trained in the preceding section, together with some of the model interpretability tools offered by Captum, to explain the predictions made by this model. The full code for the following exercise can be found here: https://github.com/PacktPublishing/Mastering-PyTorch/blob/master/Chapter13/captum_interpretability.ipynb.

Setting up Captum

The model training code is similar to the code shown in the Training the handwritten digits classifier – a recap section. In the following steps, we will use the trained model and a sample image to understand what happens inside the model while it makes a prediction for that image:

  1. There are a few extra imports related to Captum...
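The exercise itself relies on Captum's out-of-the-box attribution classes. As a library-free illustration of what the simplest of them, captum.attr.Saliency, computes, the sketch below (using a hypothetical linear model and an assumed target class) takes the gradient of the target class score with respect to the input pixels:

```python
import torch
import torch.nn as nn

# Hypothetical stand-in model; the exercise uses the trained CNN instead.
model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10)).eval()

# A random tensor standing in for a digit image; requires_grad lets us
# backpropagate all the way to the pixels.
image = torch.randn(1, 1, 28, 28, requires_grad=True)
target = 3  # assumed class of interest

# Saliency = |d(score of target class) / d(input pixels)|
score = model(image)[0, target]
score.backward()
saliency = image.grad.abs()

print(saliency.shape)  # torch.Size([1, 1, 28, 28]) -- one importance value per pixel
```

With Captum installed, the equivalent is `captum.attr.Saliency(model).attribute(image, target=3)`; IntegratedGradients and DeepLift are used the same way, each with its own attribution rule.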

Summary

In this chapter, we briefly explored how to explain, or interpret, the decisions made by deep learning models using PyTorch. Using the handwritten digit classification model as an example, we first uncovered the internal workings of a CNN's convolutional layers, demonstrating how to visualize the convolutional filters and the feature maps those layers produce.

We then used Captum, a dedicated third-party model interpretability library built on PyTorch. We used its out-of-the-box implementations of feature attribution techniques such as saliency, integrated gradients, and DeepLift. Using these techniques, we demonstrated how the model uses an input to make predictions and which parts of the input matter most to those predictions.

In the next, and final, chapter of this book, we will learn how to rapidly train and test machine learning models with PyTorch – a skill that is useful for quickly iterating...

