Performance Metrics for Classification Models in Machine Learning: Part II

Evaluate Multi-Class Classification Models

Bassant Gamal
Feb 27, 2021

In part I, we discussed how to evaluate binary classification models using Recall, Precision, Accuracy, and F1-Score. Here, we will see how to apply those metrics to a multi-class classification model.

As seen in part I, we can build a confusion matrix for a multi-class model just as we did for a binary model. The matrix becomes harder to read as the number of classes grows. To make the calculations and visualizations easier, we can split it into one binary confusion matrix per class, comparing that class against all the others.

Note: In practice, we will not build these per-class matrices by hand for every classification problem. We do it here only to make clear how the metrics work.
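Here is a minimal Python sketch of this split, assuming the labels are plain lists of strings. For a chosen class, every sample falls into exactly one of TP, FP, FN, or TN. The labels below are hypothetical, not the article's toy data, and `one_vs_rest_counts` is an illustrative helper. If you use scikit-learn, `sklearn.metrics.multilabel_confusion_matrix` returns the same counts as one 2×2 matrix per class.

```python
def one_vs_rest_counts(y_true, y_pred, label):
    """Collapse a multi-class problem into `label` vs. every other class."""
    tp = fp = fn = tn = 0
    for t, p in zip(y_true, y_pred):
        if t == label and p == label:
            tp += 1  # actually `label`, predicted `label`
        elif t != label and p == label:
            fp += 1  # another class wrongly predicted as `label`
        elif t == label and p != label:
            fn += 1  # actually `label`, predicted as something else
        else:
            tn += 1  # neither actual nor predicted class is `label`
    return {"TP": tp, "FP": fp, "FN": fn, "TN": tn}


# Hypothetical labels, for illustration only.
y_true = ["cat", "dog", "cat", "tiger", "wolf", "cat"]
y_pred = ["cat", "cat", "dog", "tiger", "cat", "cat"]

for label in ["cat", "dog", "tiger", "wolf"]:
    print(label, one_vs_rest_counts(y_true, y_pred, label))
```

For these labels the 'cat' counts are TP=2, FP=2, FN=1, TN=1. For any class, the four counts add up to the number of samples.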

Consider this toy example:

The model here classifies an input picture as a cat, dog, tiger, or wolf. The true labels are represented by ‘Y_true’, and the predicted labels by ‘y_pred’.

Now, let’s print the evaluation metrics for this model:

The metrics are calculated for each class individually, except for accuracy. Accuracy is calculated across all classes and gives a single classification accuracy for the model.
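To make the per-class versus overall distinction concrete, here is a small pure-Python sketch of such a report. scikit-learn's `classification_report` prints a similar table. The helper names and labels are hypothetical. A score whose denominator is zero, such as the precision of a class that is never predicted, is reported as 0 in this sketch.

```python
def safe_div(num, den):
    # Treat 0/0 as 0, e.g. the precision of a class the model never predicts.
    return num / den if den else 0.0


def per_class_report(y_true, y_pred, labels):
    """Per-class (precision, recall, f1, support) plus one overall accuracy."""
    pairs = list(zip(y_true, y_pred))
    report = {}
    for label in labels:
        tp = sum(1 for t, p in pairs if t == label and p == label)
        fp = sum(1 for t, p in pairs if t != label and p == label)
        fn = sum(1 for t, p in pairs if t == label and p != label)
        precision = safe_div(tp, tp + fp)
        recall = safe_div(tp, tp + fn)
        f1 = safe_div(2 * precision * recall, precision + recall)
        report[label] = (precision, recall, f1, tp + fn)  # tp + fn = support
    # Accuracy is computed once over all samples, not per class.
    accuracy = safe_div(sum(t == p for t, p in pairs), len(pairs))
    return report, accuracy


# Hypothetical labels, for illustration only.
y_true = ["cat", "dog", "cat", "tiger", "wolf", "cat"]
y_pred = ["cat", "cat", "dog", "tiger", "cat", "cat"]

report, accuracy = per_class_report(y_true, y_pred, ["cat", "dog", "tiger", "wolf"])
for label, (p, r, f, s) in report.items():
    print(f"{label:>5}  precision={p:.2f}  recall={r:.2f}  f1={f:.2f}  support={s}")
print(f"accuracy={accuracy:.2f}")
```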

Note: The confusion matrix printed here is transposed relative to the layout we are familiar with: its rows represent the actual values and its columns represent the predicted values. So we need to transpose it to get the familiar layout.
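The sketch below builds a confusion matrix with actual classes as rows and predicted classes as columns, then transposes it by swapping rows and columns. scikit-learn's `sklearn.metrics.confusion_matrix` uses the same rows-are-actual convention, and for a NumPy array `cm.T` gives the transpose. The labels and helper names here are hypothetical.

```python
def build_confusion_matrix(y_true, y_pred, labels):
    """Count matrix with rows = actual class and columns = predicted class."""
    index = {label: i for i, label in enumerate(labels)}
    matrix = [[0] * len(labels) for _ in labels]
    for t, p in zip(y_true, y_pred):
        matrix[index[t]][index[p]] += 1
    return matrix


def transpose(matrix):
    """Swap rows and columns: rows = predicted class, columns = actual class."""
    return [list(column) for column in zip(*matrix)]


labels = ["cat", "dog", "tiger", "wolf"]
# Hypothetical labels, for illustration only.
y_true = ["cat", "dog", "cat", "tiger", "wolf", "cat"]
y_pred = ["cat", "cat", "dog", "tiger", "cat", "cat"]

cm = build_confusion_matrix(y_true, y_pred, labels)
for row in transpose(cm):
    print(row)
```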

The Confusion Matrix for the four classes.

Now let’s pick one class and try to understand the numbers.

Consider the cat class. Its confusion matrix is:

The Confusion Matrix for the ‘Cat’ class.

Recall = TP/(TP+FN) = 6/(6+4) = 0.6 → Out of 10 actual cats, the model captured 6 cats correctly.

Precision = TP/(TP+FP) = 6/(6+6) = 0.5 → Out of 12 pictures predicted as cats, 6 are actual cats.

F1-Score = 2 * (Precision * Recall)/(Precision + Recall) = 2 * (0.5 * 0.6)/(0.5 + 0.6) ≈ 0.545.
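Here is the same arithmetic in Python, using the Cat counts from the text (TP = 6, FN = 4, FP = 6). It also checks the equivalent count-based form F1 = 2TP/(2TP + FP + FN), which here is 12/22.

```python
# Counts for the 'Cat' class, taken from the text.
tp, fn, fp = 6, 4, 6

recall = tp / (tp + fn)     # 6 / 10: share of actual cats that were found
precision = tp / (tp + fp)  # 6 / 12: share of 'cat' predictions that are right
f1 = 2 * precision * recall / (precision + recall)  # harmonic mean of the two

# Equivalent form written directly in terms of the counts.
f1_from_counts = 2 * tp / (2 * tp + fp + fn)  # 12 / 22

print(f"recall={recall:.3f} precision={precision:.3f} f1={f1:.3f}")
# recall=0.600 precision=0.500 f1=0.545
```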

So if, for some reason, we want to capture as many cats as possible and care less about the other classes, we need a high recall for the ‘Cat’ class. If we also want to keep good results for the other classes, we would instead try to increase the F1-Score for the ‘Cat’ class, because it balances recall against precision.

Assume we have made some changes to the model to capture more cats correctly and got the following results:

The recall and precision for the ‘Cat’ class improved, which means that the model now captures more cats correctly (8 out of 10). The scores for the other classes changed slightly as well.

If we force the model to classify every picture as a ‘Cat’, we get the following results:

The precision, recall, and F1-score for all the other classes are zero, which means that the model failed to capture any of those classes. Because it labels every picture as a ‘Cat’, it captures all the cats in the dataset, so the recall for ‘Cat’ is 1. But since the dataset also contains pictures that are not cats, the precision for the ‘Cat’ class is very low (10/37 ≈ 0.27).
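We can reproduce these numbers from the class sizes used in the weighted-average calculation (10 cats, 10 dogs, 9 tigers, 8 wolves; 37 pictures in total). This sketch uses a hypothetical predictor that always answers 'cat'. The order of the samples does not affect these scores.

```python
# Class sizes from the article: 10 cats, 10 dogs, 9 tigers, 8 wolves.
supports = {"cat": 10, "dog": 10, "tiger": 9, "wolf": 8}
y_true = [label for label, n in supports.items() for _ in range(n)]
y_pred = ["cat"] * len(y_true)  # degenerate model: every picture is a 'cat'


def precision_recall(y_true, y_pred, label):
    tp = sum(t == label and p == label for t, p in zip(y_true, y_pred))
    predicted = sum(p == label for p in y_pred)  # how often `label` was predicted
    actual = sum(t == label for t in y_true)     # how many `label` samples exist
    precision = tp / predicted if predicted else 0.0  # never predicted -> 0
    recall = tp / actual if actual else 0.0
    return precision, recall


for label in supports:
    p, r = precision_recall(y_true, y_pred, label)
    print(f"{label:>5}: precision={p:.2f} recall={r:.2f}")
# cat gets precision 0.27 and recall 1.00; every other class gets 0.00 for both
```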

Similarly, we can build the confusion matrix for the ‘Dog’ class:

The Confusion Matrix for the ‘Dog’ class.

Usually we need the model to capture all classes well. Even when one class matters more than the others, we still want a single number that summarizes the model’s performance. Two of the most commonly used ways to combine the per-class scores are:

1. Macro Average

Take the plain (unweighted) mean of each metric over all classes, so every class counts equally regardless of its size.

2. Weighted Average

Weight each class’s score by its support (the number of actual samples of that class in the dataset), then divide the weighted sum by the total number of samples.

For our example (taking the first case):

Weighted Avg. for Precision = (10*0.5 + 10*0.364 + 9*0.429 + 8*0.286)/37 = 0.399.

Weighted Avg. for Recall = (10*0.6 + 10*0.4 + 9*0.333 + 8*0.25)/37 = 0.405.

Weighted Avg. for F1-Score = (10*0.545 + 10*0.381 + 9*0.375 + 8*0.267)/37 = 0.399.
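The sketch below assumes that the first model's per-class scores come from these counts, which are back-calculated from the rounded values in the text. For example, a dog precision of 0.364 with 4 true positives implies 7 false positives (4/11). From these counts it recomputes both averages. The weighted values match the ones above. The macro values come out at about 0.394 (precision), 0.396 (recall), and 0.392 (F1).

```python
# (TP, FP, FN) per class, back-calculated from the rounded scores in the text.
counts = {
    "cat": (6, 6, 4),
    "dog": (4, 7, 6),
    "tiger": (3, 4, 6),
    "wolf": (2, 5, 6),
}


def scores(tp, fp, fn):
    """Return (precision, recall, f1) for one class."""
    precision = tp / (tp + fp)
    recall = tp / (tp + fn)
    f1 = 2 * precision * recall / (precision + recall)
    return precision, recall, f1


per_class = {label: scores(*c) for label, c in counts.items()}
support = {label: tp + fn for label, (tp, fp, fn) in counts.items()}
total = sum(support.values())  # 37 pictures

for i, name in enumerate(["precision", "recall", "f1"]):
    # Macro: every class counts equally.
    macro = sum(s[i] for s in per_class.values()) / len(per_class)
    # Weighted: each class counts in proportion to its support.
    weighted = sum(per_class[label][i] * support[label] for label in per_class) / total
    print(f"{name:>9}: macro={macro:.3f} weighted={weighted:.3f}")
```

Notice that the weighted recall equals the accuracy (15/37). Weighting each class's recall by its support amounts to counting all correct predictions and dividing by the total.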

When to Use Which?

Use the weighted average when you want each class to count in proportion to its share of the dataset.

But when the dataset is imbalanced, the weighted average is dominated by the largest classes and can hide poor performance on the small ones. In that case, use the macro average, which gives every class the same weight.

Consider the following example:

The dataset here is highly imbalanced: there are 520 samples of the cat class versus only 13 samples in total for the other classes. This can push the model toward predicting a cat even when the picture is not a cat.

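Note that the weighted averages of all metrics are very high, and so is the accuracy (96.3%). Still, we cannot call this a perfect model. Because the dataset is imbalanced, these numbers are dominated by the cat class and say little about how well the model handles the other classes.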
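Here is a hypothetical illustration of why the weighted average can mislead in this situation. It uses 520 cats plus 13 non-cats, lumped into a single 'other' class for brevity, and a model that predicts 'cat' for every picture. This is not the article's model, whose accuracy is 96.3%; it is the extreme case of the same imbalance.

```python
# Hypothetical data: 520 cats and 13 non-cats, lumped into one "other" class.
y_true = ["cat"] * 520 + ["other"] * 13
# Extreme case: a model that answers "cat" for every picture.
y_pred = ["cat"] * len(y_true)
labels = ["cat", "other"]


def recall(label):
    """Fraction of the actual `label` samples that were predicted as `label`."""
    hits = [p == label for t, p in zip(y_true, y_pred) if t == label]
    return sum(hits) / len(hits)


support = {label: y_true.count(label) for label in labels}
accuracy = sum(t == p for t, p in zip(y_true, y_pred)) / len(y_true)
macro_recall = sum(recall(label) for label in labels) / len(labels)
weighted_recall = sum(recall(label) * support[label] for label in labels) / len(y_true)

print(f"accuracy={accuracy:.3f}")
print(f"weighted recall={weighted_recall:.3f}, macro recall={macro_recall:.3f}")
# accuracy=0.976
# weighted recall=0.976, macro recall=0.500
```

The accuracy and the weighted recall both look excellent. The macro recall of 0.5 exposes that one of the two classes is never recognized.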

There is one more thing to discuss: how do we draw the ROC curve for a multi-class model?

The ROC curve, and the AUC computed from it, is defined for binary classification problems. But we can extend it to multi-class classification problems with the One vs All (also called One vs Rest) technique.

So, with four classes, we generate one ROC curve per class: ‘Cat’ vs ‘Not Cat’, then ‘Dog’ vs ‘Not Dog’, and so on.
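Here is a minimal One vs All sketch in pure Python. It assumes the model outputs one score per class for every picture, such as predicted probabilities. For each class, the labels are turned into 'this class' vs 'not this class'. The ROC points are then traced by lowering the threshold, and the AUC is taken with the trapezoidal rule. The data, the three-class setup, and the helper names are hypothetical. In scikit-learn, `label_binarize` with `roc_curve`, or `roc_auc_score(..., multi_class="ovr")` in recent versions, cover the same ground.

```python
def roc_points(is_positive, scores):
    """(FPR, TPR) points for one binary 'class vs. rest' problem."""
    n_pos = sum(is_positive)
    n_neg = len(is_positive) - n_pos
    order = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)
    points = [(0.0, 0.0)]
    tp = fp = 0
    i = 0
    while i < len(order):
        threshold = scores[order[i]]
        # Samples with tied scores cross the threshold together.
        while i < len(order) and scores[order[i]] == threshold:
            if is_positive[order[i]]:
                tp += 1
            else:
                fp += 1
            i += 1
        points.append((fp / n_neg, tp / n_pos))
    return points


def area_under(points):
    """Trapezoidal area under a piecewise-linear ROC curve."""
    return sum((x1 - x0) * (y0 + y1) / 2
               for (x0, y0), (x1, y1) in zip(points, points[1:]))


classes = ["cat", "dog", "wolf"]
y_true = ["cat", "cat", "dog", "dog", "wolf", "wolf"]
# Hypothetical predicted probabilities: one row per picture, columns in `classes` order.
y_prob = [
    [0.8, 0.1, 0.1],
    [0.5, 0.3, 0.2],
    [0.6, 0.3, 0.1],
    [0.1, 0.6, 0.3],
    [0.2, 0.2, 0.6],
    [0.3, 0.3, 0.4],
]

aucs = {}
for k, label in enumerate(classes):
    is_positive = [t == label for t in y_true]  # 'label' vs. 'not label'
    scores = [row[k] for row in y_prob]         # this class's score column
    aucs[label] = area_under(roc_points(is_positive, scores))

print(aucs)
# {'cat': 0.875, 'dog': 0.875, 'wolf': 1.0}
```

Each entry of `aucs` corresponds to one curve ('Cat' vs 'Not Cat', and so on). Plotting each list returned by `roc_points` gives the per-class ROC curves.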

Let’s look at the following simple example:

I hope this was useful. If you have any comments, or if anything is unclear, please let me know.
