Image Classification with PyTorch and Cleanlab#

This 5-minute cleanlab quickstart tutorial demonstrates how to find potential label errors in image classification data. Here we use the MNIST dataset containing 70,000 images of handwritten digits from 0 to 9.

Overview of what we’ll do in this tutorial:

  • Build a simple PyTorch neural net and wrap it with Skorch to make it scikit-learn compatible.

  • Compute the out-of-sample predicted probabilities, psx, via cross-validation.

  • Generate a list of potential label errors with Cleanlab’s get_noise_indices.

1. Install the required dependencies#

Install the following dependencies with pip install; a one-line command covering all of them is shown after the list:

  1. cleanlab

  2. pandas

  3. matplotlib

  4. torch

  5. torchvision

  6. skorch
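
For convenience, here is a minimal sketch that installs everything in a single notebook cell (the package names are exactly those listed above):

!pip install cleanlab pandas matplotlib torch torchvision skorch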

2. Fetch and scale the MNIST dataset#

[3]:
from sklearn.datasets import fetch_openml

mnist = fetch_openml("mnist_784")  # Fetch the MNIST dataset

X = mnist.data.astype("float32").to_numpy()  # 2D numpy array of image features
X /= 255.0  # Scale the features to the [0, 1] range

y = mnist.target.astype("int64").to_numpy()  # 1D numpy array of the image labels
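
As a quick sanity check, X should now contain 70,000 flattened 28x28 images (784 pixel features each), with one label per image:

print(X.shape)  # (70000, 784): one flattened 28x28 image per row
print(y.shape)  # (70000,): one digit label (0-9) per image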

Bringing Your Own Data (BYOD)?

Assign your data’s features to variable X and its labels to variable y instead.
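
For example, if your images live in a NumPy array, a minimal sketch of the same preprocessing (your_images and your_labels are hypothetical placeholders, and note that the model defined below expects 28x28 inputs):

import numpy as np

# Hypothetical arrays: your_images with shape (n, 28, 28), your_labels with shape (n,)
X = your_images.reshape(len(your_images), -1).astype("float32") / 255.0  # flatten and scale to [0, 1]
y = np.asarray(your_labels, dtype="int64")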

3. Define a classification model#

Here, we define a simple neural network with PyTorch.

[4]:
from torch import nn

model = nn.Sequential(
    nn.Linear(28 * 28, 128),
    nn.ReLU(),
    nn.Dropout(0.5),
    nn.Linear(128, 10),
    nn.Softmax(dim=-1),
)
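
As a quick sketch (not part of the original tutorial), you can confirm that the network maps a flattened 28x28 image to a 10-way probability vector:

import torch

dummy = torch.zeros(1, 28 * 28)  # a batch containing one blank image
probs = model(dummy)
print(probs.shape)         # torch.Size([1, 10])
print(probs.sum().item())  # ~1.0, since the final layer is a Softmax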

4. Ensure your classifier is scikit-learn compatible#

Some of Cleanlab’s features require scikit-learn compatibility, so we will need to adapt the PyTorch neural net above accordingly. Skorch is a convenient package that helps with this. You can also easily wrap an arbitrary model to be scikit-learn compatible as demonstrated here.

[5]:
from skorch import NeuralNetClassifier

model_skorch = NeuralNetClassifier(model)
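
By default, NeuralNetClassifier trains for 10 epochs, which is what you will see in the cross-validation logs below. If you want more control, skorch exposes common training hyperparameters as constructor arguments; a hedged sketch with illustrative values (not the settings used to produce this tutorial’s logs):

import torch

model_skorch = NeuralNetClassifier(
    model,
    max_epochs=10,  # skorch's default
    lr=0.02,        # illustrative; skorch's default is 0.01
    device="cuda" if torch.cuda.is_available() else "cpu",
)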

5. Compute out-of-sample predicted probabilities#

If we’d like Cleanlab to identify potential label errors across the whole dataset, and not just within a training split, we should compute the out-of-sample predicted probabilities, psx, for every example via cross-validation.

[6]:
from sklearn.model_selection import cross_val_predict

psx = cross_val_predict(model_skorch, X, y, cv=3, method="predict_proba")
  epoch    train_loss    valid_acc    valid_loss     dur
-------  ------------  -----------  ------------  ------
      1        1.9845       0.7482        1.5712  0.9986
      2        1.2629       0.8020        0.9596  0.9132
      3        0.8940       0.8338        0.7255  0.9391
      4        0.7364       0.8493        0.6166  0.9282
      5        0.6544       0.8593        0.5528  0.9113
      6        0.5925       0.8677        0.5084  0.9149
      7        0.5556       0.8721        0.4772  0.9111
      8        0.5232       0.8756        0.4536  0.9569
      9        0.5002       0.8800        0.4335  0.9243
     10        0.4787       0.8832        0.4180  0.9266
  epoch    train_loss    valid_acc    valid_loss     dur
-------  ------------  -----------  ------------  ------
      1        1.9833       0.7598        1.5534  0.9128
      2        1.2569       0.8133        0.9337  0.9286
      3        0.8882       0.8388        0.6993  0.9154
      4        0.7325       0.8561        0.5885  0.9328
      5        0.6482       0.8685        0.5239  0.9243
      6        0.5932       0.8777        0.4802  0.9194
      7        0.5532       0.8834        0.4498  0.9269
      8        0.5223       0.8894        0.4260  0.9397
      9        0.4972       0.8930        0.4061  0.9689
     10        0.4730       0.8952        0.3914  0.9215
  epoch    train_loss    valid_acc    valid_loss     dur
-------  ------------  -----------  ------------  ------
      1        1.9989       0.7480        1.5723  0.9072
      2        1.2912       0.8127        0.9472  0.9356
      3        0.9270       0.8437        0.7059  0.9256
      4        0.7628       0.8627        0.5909  0.9296
      5        0.6765       0.8738        0.5246  0.9140
      6        0.6224       0.8825        0.4801  0.9168
      7        0.5802       0.8880        0.4477  0.9441
      8        0.5455       0.8922        0.4231  0.9218
      9        0.5216       0.8972        0.4040  0.9237
     10        0.4971       0.9006        0.3878  0.9386
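
A quick sanity check on the result: psx should hold one out-of-sample probability row per example, with each row summing to 1.

print(psx.shape)     # (70000, 10): class probabilities for every example
print(psx[0].sum())  # ~1.0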

6. Run Cleanlab to find potential label errors#

Cleanlab provides the get_noise_indices function to generate a list of potential label errors. Setting sorted_index_method="prob_given_label" returns the indices of the most likely label errors, ordered with the most suspicious example first.

[7]:
from cleanlab.pruning import get_noise_indices

ordered_label_errors = get_noise_indices(
    s=y, psx=psx, sorted_index_method="prob_given_label"
)
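
If you would rather have a boolean mask than ordered indices, the same function can provide one. A sketch, assuming the cleanlab 1.x behavior where the default sorted_index_method=None returns a per-example mask (check your installed version’s documentation):

# Assumption: with sorted_index_method=None (the default), get_noise_indices
# returns a boolean mask with one entry per example.
label_error_mask = get_noise_indices(s=y, psx=psx)
print(label_error_mask.sum())  # number of flagged examples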

7. Review some of the most likely mislabeled examples#

[8]:
print(f"Cleanlab found {len(ordered_label_errors)} potential label errors.")
print(
    f"Here are the indices of the top 15 most likely ones: \n {ordered_label_errors[:15]}"
)
Cleanlab found 1235 potential label errors.
Here are the indices of the top 15 most likely ones:
 [24798 18598  8729 20820 31134 12679 15942  1352  7010 55739 39457 53216
 20735 11208 26376]
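
That shortlist covers under 2% of the dataset. A minimal sketch of the arithmetic:

print(f"{len(ordered_label_errors) / len(y):.2%} of all examples were flagged")  # 1235 / 70000 ≈ 1.76%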

We’ll define a plot_examples function to conveniently display any set of examples in a grid of subplots.

[9]:
import matplotlib.pyplot as plt


def plot_examples(id_iter, nrows=1, ncols=1):
    # Display each requested example in a grid of nrows x ncols subplots,
    # titled with its dataset index and given label.
    for count, idx in enumerate(id_iter):
        plt.subplot(nrows, ncols, count + 1)
        plt.imshow(X[idx].reshape(28, 28), cmap="gray")  # un-flatten to 28x28 pixels
        plt.title(f"id: {idx} \n label: {y[idx]}")
        plt.axis("off")

    plt.tight_layout(h_pad=2.0)

Let’s start with an overview of the top 15 most likely label errors. Even in this quick scan, we can spot a few label errors and edge cases. Feel free to change the parameters to display more or fewer examples.

[10]:
plot_examples(ordered_label_errors[:15], 3, 5)
(figure: 3x5 grid of the top 15 most likely label errors)

Let’s zoom into specific examples:

Given label is 4 but looks more like a 7

[11]:
plot_examples([59915])
(figure: example 59915, labeled 4)

Given label is 4 but it also looks like a 9

[12]:
plot_examples([24798])
(figure: example 24798, labeled 4)

Edge cases of odd-looking 9s

[13]:
plot_examples([18598, 1352, 61247], 1, 3)
(figure: examples 18598, 1352, and 61247, each labeled 9)

Cleanlab has shortlisted the most likely label errors to speed up your data cleaning process. With this list, you can decide whether to fix label errors, augment edge cases, or remove obscure examples.
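
For instance, if you decide to simply drop every flagged example, here is a minimal sketch of one way to do it (in practice, review the flagged labels by hand first):

import numpy as np

# Boolean mask that keeps every example Cleanlab did not flag
keep = np.ones(len(y), dtype=bool)
keep[ordered_label_errors] = False

X_clean, y_clean = X[keep], y[keep]
print(X_clean.shape)  # (68765, 784) after dropping the 1235 flagged examples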

What’s next?#

Congratulations on completing this tutorial! Check out our next tutorial on using Cleanlab for text classification, where we find hundreds of potential label errors in the IMDb movie review dataset, one of the most well-known text datasets!