import os

import matplotlib.pyplot as plt
import numpy as np
from sklearn.metrics import confusion_matrix
# Validation: evaluate the trained model on the held-out test split.
# NOTE(review): `model`, `X_test`, `y_test` and `iris` are defined elsewhere
# in this script — confirm they are the fitted estimator and the test data.
y_pred = model.predict(X_test)

# Confusion matrix: rows are true labels, columns are predicted labels.
# NOTE(review): the tick labels below are taken positionally from
# iris.target_names. This assumes the label order used by confusion_matrix
# lines up with that list and that every class occurs in y_test/y_pred;
# otherwise cm is smaller than the label list. TODO confirm.
cm = confusion_matrix(y_test, y_pred)

# Plot the matrix as a heat map with a colour bar.
fig, ax = plt.subplots(figsize=(8, 8))
im = ax.imshow(cm, interpolation='nearest', cmap=plt.cm.Blues)
fig.colorbar(im, ax=ax)
ax.set(
    xticks=np.arange(cm.shape[1]),
    yticks=np.arange(cm.shape[0]),
    xticklabels=iris.target_names,
    yticklabels=iris.target_names,
    title='Confusion Matrix',
    ylabel='True label',
    xlabel='Predicted label',
)
# Rotate the predicted-label ticks so long class names don't overlap.
plt.setp(ax.get_xticklabels(), rotation=45, ha="right", rotation_mode="anchor")

# Annotate each cell with its count; use white text on dark cells (above
# half of the maximum count) so the number stays readable.
fmt = 'd'
thresh = cm.max() / 2.
for i in range(cm.shape[0]):
    for j in range(cm.shape[1]):
        ax.text(j, i, format(cm[i, j], fmt),
                ha="center", va="center",
                color="white" if cm[i, j] > thresh else "black")
fig.tight_layout()

# Save this figure explicitly (not whatever pyplot considers "current").
# savefig does not create missing parent directories, so make sure the
# output directory exists first.
os.makedirs('output', exist_ok=True)
fig.savefig('output/graph.png')