In many cases, understanding why a model predicted a given outcome is a key detail for its users and a necessary diagnostic to ensure the model makes decisions based on the correct features. For example, if you built a convolutional neural network that performed well at identifying damaged products from images, how could you be sure it based its predictions on the damaged surface of the product and not on the image background, which may happen to be correlated with the labels? A useful tool for this purpose is a saliency map, a visualization of the pixels in the image that contribute most to the model’s prediction. In this article, we’ll explore how saliency maps work and walk through an example for a ConvNet used to estimate the age of fish from their scales.
Saliency maps plot the gradient of the model’s predicted output with respect to its input, that is, the pixel values.
The gradient tells us how much the score for the predicted class changes in response to a small adjustment to each pixel value, so its magnitude measures the relative importance of each pixel to the model’s prediction. Figure 1 is an example of a raw saliency map; pixels with a high gradient appear in yellow and those with a low gradient in blue.
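To make the idea of a gradient concrete, here is a minimal NumPy sketch that estimates a saliency map by finite differences: it nudges each pixel up and down by a small amount and records how much the class score changes. The toy linear "model" (`weights`, `bias`, `class_score`) is a hypothetical stand-in for a trained ConvNet, and taking the absolute value follows Simonyan et al. Keras-Vis computes the same kind of gradient with a single backpropagation pass rather than two forward passes per pixel.

```python
import numpy as np

def class_score(image, weights, bias, class_idx):
    """Pre-softmax score of one class for a flattened image.

    Hypothetical stand-in for the last dense layer of a trained ConvNet.
    """
    return float(weights[class_idx] @ image.ravel() + bias[class_idx])

def finite_difference_saliency(image, score_fn, eps=1e-3):
    """Estimate |d score / d pixel| by nudging each pixel up and down."""
    base = image.astype(float)
    saliency = np.zeros_like(base)
    for idx in np.ndindex(base.shape):
        bumped_up = base.copy()
        bumped_down = base.copy()
        bumped_up[idx] += eps
        bumped_down[idx] -= eps
        # Central difference approximates the partial derivative at this pixel.
        saliency[idx] = (score_fn(bumped_up) - score_fn(bumped_down)) / (2 * eps)
    return np.abs(saliency)

rng = np.random.default_rng(0)
image = rng.random((4, 4))           # toy 4x4 grayscale "image"
weights = rng.normal(size=(3, 16))   # 3 classes, 16 pixels
bias = np.zeros(3)
predicted = int(np.argmax(weights @ image.ravel() + bias))

saliency = finite_difference_saliency(
    image, lambda x: class_score(x, weights, bias, predicted)
)
print(np.round(saliency, 3))
```

Because the toy score is linear, its gradient is exactly the weight row of the predicted class, which makes the estimate easy to check. For a real network this per-pixel loop would be far too slow, which is why backpropagation is used in practice.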
This technique is described in more detail by Simonyan et al. (2013): https://arxiv.org/abs/1312.6034
To build your own, start by loading a ConvNet that has already been trained. In this example, we’ll use a network I built to predict age from fish scales. Before you proceed, you will need to downgrade TensorFlow to work around a bug when using the newest version of Keras-Vis; I downgraded to TensorFlow 1.7.0.
We’ll need to import several modules, most notably Keras-Vis, which provides the saliency map function. You can learn more about Keras-Vis functions here.
Pick an image whose interpretation by the model interests you. In this example, I picked an image whose class the model predicted correctly, but an equally interesting case would be to identify the parts of an image that led the model to misclassify it.
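If you have a labeled validation set, one simple way to find candidates of both kinds is to compare predicted and true classes. The sketch below assumes `probabilities` is an array of softmax outputs, such as the one returned by Keras `model.predict`, and that `true_labels` holds integer class indices; the helper name and the numbers are made up for illustration.

```python
import numpy as np

def split_by_correctness(probabilities, true_labels):
    """Return indices of correctly and incorrectly classified images."""
    predicted = np.argmax(probabilities, axis=1)
    correct = np.flatnonzero(predicted == true_labels)
    wrong = np.flatnonzero(predicted != true_labels)
    return correct, wrong

# Hypothetical softmax outputs for 4 images over 3 age classes.
probs = np.array([
    [0.7, 0.2, 0.1],
    [0.1, 0.3, 0.6],
    [0.2, 0.5, 0.3],
    [0.6, 0.3, 0.1],
])
labels = np.array([0, 2, 0, 1])

correct, wrong = split_by_correctness(probs, labels)
print(correct, wrong)  # → [0 1] [2 3]
```

Any index from `correct` gives a case like the one used here, while an index from `wrong` lets you ask which pixels pushed the model toward the wrong age class.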
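Locate the name of the model’s last dense layer; in my case it’s called ‘dense_12’. We’ll look up that layer’s index and change its activation from softmax to linear. The reason is that the softmax output for one class depends on the scores of all the other classes, so its gradient can be increased just by lowering the competing scores rather than by raising the score of the class we care about. Then apply the modification and assign the rebuilt network back to ‘model’.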
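A small NumPy check illustrates why the activation matters. For logits z and softmax outputs p, the derivative of p_c with respect to each logit z_j is p_c * (delta_cj - p_j), so every logit contributes. The logit values below are hypothetical, and the sketch only looks at the final layer, not a full ConvNet.

```python
import numpy as np

def softmax(z):
    shifted = np.exp(z - np.max(z))  # subtract the max for numerical stability
    return shifted / shifted.sum()

def softmax_grad_wrt_logits(z, c):
    """d softmax(z)[c] / d z_j = p_c * (delta_cj - p_j) for every logit j."""
    p = softmax(z)
    delta = np.zeros_like(z)
    delta[c] = 1.0
    return p[c] * (delta - p)

def linear_grad_wrt_logits(z, c):
    """With a linear activation the output for class c is just z_c."""
    grad = np.zeros_like(z)
    grad[c] = 1.0
    return grad

logits = np.array([2.0, 1.0, 0.5])  # hypothetical pre-activation scores
c = 0                               # class whose saliency we want

print(softmax_grad_wrt_logits(logits, c))  # nonzero for every logit
print(linear_grad_wrt_logits(logits, c))   # one-hot on class c
```

The softmax gradient is negative for every competing class, so once it is chained back to the pixels, a pixel can look salient simply because it suppresses other classes. With a linear output the gradient is one-hot, and only the target class score is traced back to the input.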
We end up with something like Figure 1: a fine-scale gradient measurement for each pixel. To make the map easier to interpret, we’ll apply a Gaussian filter to smooth out the values and plot the result over the original image.
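The smoothing step can be sketched in plain NumPy as a separable Gaussian blur followed by rescaling to [0, 1]. The function names and the `sigma` value are illustrative assumptions, not the settings used for the figures. If SciPy is installed, `scipy.ndimage.gaussian_filter(saliency, sigma)` is the usual library routine; its default kernel radius and mirrored edge handling match this sketch, so it should give essentially the same result.

```python
import numpy as np

def gaussian_kernel_1d(sigma, truncate=4.0):
    """Normalized 1D Gaussian kernel covering +/- truncate * sigma."""
    radius = int(truncate * sigma + 0.5)
    x = np.arange(-radius, radius + 1)
    kernel = np.exp(-0.5 * (x / sigma) ** 2)
    return kernel / kernel.sum()

def gaussian_smooth(saliency, sigma=2.0):
    """Separable Gaussian blur: filter every row, then every column."""
    kernel = gaussian_kernel_1d(sigma)
    radius = len(kernel) // 2

    def smooth_1d(v):
        # Mirror the edges so border pixels are not darkened by zero padding.
        padded = np.pad(v, radius, mode="symmetric")
        return np.convolve(padded, kernel, mode="valid")

    rows = np.apply_along_axis(smooth_1d, 1, saliency)
    return np.apply_along_axis(smooth_1d, 0, rows)

def to_heatmap(saliency, sigma=2.0):
    """Smooth the raw saliency map, then rescale it to [0, 1] for overlaying."""
    smoothed = gaussian_smooth(np.asarray(saliency, dtype=float), sigma)
    lo, hi = smoothed.min(), smoothed.max()
    if hi == lo:
        return np.zeros_like(smoothed)
    return (smoothed - lo) / (hi - lo)

# Toy example: a single very salient pixel in the middle of a 21x21 map.
raw = np.zeros((21, 21))
raw[10, 10] = 1.0
heat = to_heatmap(raw, sigma=2.0)
print(heat.shape, round(float(heat[10, 10]), 3), round(float(heat[10, 12]), 3))
# → (21, 21) 1.0 0.607
```

To draw the overlay with Matplotlib, you could show the original image first and then call `plt.imshow(heat, cmap='viridis', alpha=0.5)` on top of it; the viridis colormap maps high values to yellow and low values to dark blue-purple.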
The resulting plot shows in yellow the pixels that had the greatest influence on the predicted class (Figure 3).
I’m relieved to see that the model was making decisions based on a part of the image known to be related to fish age (the rings on the scale) and not on the background or the other scale on the right side of the image. Importantly, this exercise adds credibility to the model’s predictions. The full Python script for this visualization can be found here.