Introduction to Saliency Maps—The Vanilla Gradient
Get an overview of saliency maps and implement the vanilla gradient saliency algorithm.
Saliency maps
Given an image and a model's prediction for it, a saliency map is a heatmap over the image that scores each pixel by how much it contributes to the prediction: the brighter (redder) a pixel, the more important it is to the model.

A saliency map can therefore be used as a visual explanation to verify the correctness of a model. For example, in an image of a dog standing on grass, if the red pixels in the saliency map are concentrated around the grass, we can infer that the model relies on irrelevant artifacts, such as the background, to make its prediction. Such a model shouldn't be trusted and should only be deployed after careful examination.
Note: Saliency maps can also be referred to as pixel-attribution, attribution, or sensitivity maps.
Vanilla gradient saliency
We’ll now learn to implement our first saliency map algorithm: the vanilla gradient saliency.
Let's assume that $f$ is a trained classifier that maps an input image $x \in \mathbb{R}^{C \times H \times W}$ to a vector of class logits $f(x) \in \mathbb{R}^{K}$,

where $C$, $H$, and $W$ are the number of channels, the height, and the width of the image, and $K$ is the number of classes. Let $c = \arg\max_k f_k(x)$ be the class predicted by the network.

The vanilla gradients $G$ are the gradients of the predicted class logit with respect to the input image:

$$G = \frac{\partial f_c(x)}{\partial x} \in \mathbb{R}^{C \times H \times W}$$

The unnormalized saliency map $S \in \mathbb{R}^{H \times W}$ keeps only the positive gradients and takes their maximum over the channel dimension:

$$S_{h,w} = \max_{i} \, \mathrm{ReLU}(G)_{i,h,w}$$
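Putting these two equations together, here is a minimal sketch of the computation as a standalone function (the function name and shapes are our own; the full, runnable walkthrough follows in the next section):

import torch

def vanilla_saliency(model, x, target_class):
    # x: input image of shape (C, H, W); gradients flow back to it
    x = x.clone().requires_grad_(True)
    logits = model(x.unsqueeze(0))[0]                          # forward pass, shape (K,)
    grad = torch.autograd.grad(logits[target_class], [x])[0]   # G = df_c/dx, shape (C, H, W)
    saliency, _ = grad.relu().max(dim=0)                       # S: positive gradients, max over channels
    return saliency                                            # shape (H, W)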
Implementing vanilla gradient saliency
The code below implements and visualizes the vanilla gradient saliency of an image with respect to the prediction made by a MobileNet-V2 network trained on the ImageNet-1K dataset. It outputs the original image, its vanilla gradient saliency map, and the network prediction. Bright red pixels in the saliency map mark regions that are important for the prediction, while black pixels are unimportant.
import torch
import torchvision
import torchvision.transforms as T
from torchvision.models import mobilenet_v2
from PIL import Image
import json
import matplotlib.pyplot as plt

class_idx = json.load(open("imagenet_class_index.json", 'r')) # ImageNet Id to Label Mapping
idx2label = [class_idx[str(k)][1] for k in range(len(class_idx))]

image = Image.open("dog.jpg").resize((224,224)) # original image
transform = T.Compose([
    T.Resize((224,224)),
    T.ToTensor(),
    T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

image = transform(image) # transform and normalize image
image.requires_grad = True
torchvision.utils.save_image(image, "./output/image.png", normalize=True)

model = mobilenet_v2() # load MobileNetV2 model
ckpt = torch.load("./mobilenet_v2-b0353104.pth", map_location="cpu")
model.load_state_dict(ckpt)
model.eval()

logits = model(image.unsqueeze(0))[0] # output logits, forward pass
prediction = torch.argmax(logits) # network prediction
print("Model predicted : ", idx2label[prediction.item()])

grad = torch.autograd.grad(logits[prediction.item()], [image])[0] # backward pass
grad, _ = torch.max(grad.relu(), 0) # take positive gradients only

plt.imshow(grad, cmap=plt.cm.hot) # plot saliency
plt.axis("off")
plt.tight_layout()
plt.savefig("./output/saliency.png", bbox_inches="tight")
- Lines 1–7: We import torch for automatic differentiation via torch.autograd.grad(), torchvision for loading MobileNet-V2, PIL for image manipulations, json for JSON-related utilities, and matplotlib for plotting graphs/images.
- Lines 9–10: We load the ImageNet-1K class indexes and create a map, idx2label, which maps a numeric label to its corresponding class name in the ImageNet-1K dataset.
- Lines 12–17: We load a PIL (Python Imaging Library) image and define a transform function that resizes the input image to 224 ✕ 224 resolution and then normalizes it.
- Lines 19–20: We transform the image and enable its requires_grad attribute for computing gradients.
- Lines 23–26: We load the MobileNet-V2 model and set it to eval() mode. model.eval() is a switch for specific layers, such as batch normalization (which scales a layer's outputs to have a mean of 0 and a variance of 1 to speed up training) and dropout (a mask that nullifies the contribution of some neurons toward the next layer and leaves all others unmodified), that behave differently at training and inference (evaluation) time.
- Lines 28–29: We calculate the class logits and the network prediction.
- Line 32: We calculate the vanilla gradient of the network prediction with respect to the input image.
- Line 33: We keep only the positive gradients and take their maximum along the channel dimension.
- Lines 35–38: We plot and visualize the saliency map (matplotlib rescales the values for display); a sketch that overlays the saliency on the original image follows this list.
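The hot-colormap plot shows the saliency on its own. To relate it back to the input, a common follow-up is to rescale the saliency to [0, 1] and overlay it on the denormalized image. A minimal sketch, reusing image and grad from the script above (the output path is our choice):

import numpy as np
import matplotlib.pyplot as plt

# undo the ImageNet normalization for display only
mean = np.array([0.485, 0.456, 0.406]).reshape(3, 1, 1)
std = np.array([0.229, 0.224, 0.225]).reshape(3, 1, 1)
img = (image.detach().numpy() * std + mean).clip(0, 1).transpose(1, 2, 0)

sal = grad.numpy()
sal = (sal - sal.min()) / (sal.max() - sal.min() + 1e-8)  # rescale to [0, 1]

plt.imshow(img)
plt.imshow(sal, cmap=plt.cm.hot, alpha=0.5)  # blend the saliency over the image
plt.axis("off")
plt.savefig("./output/overlay.png", bbox_inches="tight")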
As we can see, the bright or red pixels are concentrated around the face of the dog, suggesting that the facial features are essential for classifying the object as a dog.
The vanilla gradient algorithm is fast because it involves only one forward-backward pass, and it generates decent explanations for most inputs. However, when the input becomes noisy, the vanilla gradients also become noisy, assigning importance to random, irrelevant pixels in the saliency map. This makes the algorithm vulnerable to noisy or adversarially perturbed inputs.
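To see this sensitivity directly, we can perturb the input with a small amount of Gaussian noise and measure how much the saliency map changes. A minimal sketch, reusing image, model, grad, and prediction from the script above (the noise scale 0.1 is an arbitrary choice):

import torch
import torch.nn.functional as F

# perturb the input and recompute the saliency for the same predicted class
noisy = (image.detach() + 0.1 * torch.randn_like(image)).requires_grad_(True)
logits = model(noisy.unsqueeze(0))[0]
grad_noisy = torch.autograd.grad(logits[prediction.item()], [noisy])[0]
sal_noisy, _ = grad_noisy.relu().max(dim=0)

# cosine similarity between the flattened saliency maps; values well below 1
# indicate that the explanation is unstable under noise
sim = F.cosine_similarity(grad.flatten(), sal_noisy.flatten(), dim=0)
print("Saliency similarity under noise:", sim.item())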