A data-driven approach to cleaning image classification datasets using embedding space visualization and UMAP dimensionality reduction to filter out bad training data and enhance machine learning model performance.
By Nathan Burg

Many machine learning engineers have spent days, if not weeks, trying to fix misclassification issues in their models (tuning hyperparameters, augmenting data, and testing every approach they can think of), only to discover that the root cause was bad data hiding in the dataset. Manually combing through thousands of images in search of mislabeled or corrupted samples can feel like an endless, frustrating task. But what if there were a more systematic and efficient way to diagnose these issues?

By leveraging embedding spaces and dimensionality reduction, we can visualize patterns in our data, spot problematic samples, and clean our image classification datasets in a targeted way instead of by guesswork. In this article, I’ll introduce a tool that does just that, helping you identify and exclude bad data in image classification datasets.

Practical Applications of the Dataset Cleaning Tool

In practice, this tool offers a data-driven approach to cleaning image classification datasets. By visualizing the embedding space, you can:

  • Spot Anomalies: Identify images that are outliers or don't fit well within their labeled cluster.
  • Detect Misclassifications: Find images that may have been mislabeled.
  • Improve Dataset Quality: Exclude bad data to enhance the performance of your machine learning models.

Filtering out bad training data becomes a manageable task, leading to more accurate and reliable models without the need to manually sift through thousands of images.

The Challenge of Bad Data in Machine Learning

Bad data in training sets can severely hamper the performance of machine learning models, especially in image classification tasks. Mislabeled images, corrupted files, and outliers can lead to poor model generalization and inaccurate predictions. Identifying and removing these problematic samples is crucial for improving model accuracy and reliability.

Introducing a Tool for Efficient Dataset Cleaning

To address these challenges, I developed a tool that streamlines the process of cleaning image classification datasets. This tool utilizes embedding space visualization and dimensionality reduction to help you efficiently identify and exclude problematic data from your training sets.

How the Dataset Cleaning Tool Works

1. Loading and Preparing the Data

The tool starts by reading an NPZ file that contains:

  • Image Embeddings: High-dimensional vectors representing each image.
  • Labels: The corresponding class labels for each embedding.
  • Image Paths: Absolute paths to the original image files.

To generate an NPZ file with this structure, you can use the following code to save the embeddings, labels, and image paths:


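import numpy as np

embeddings = []   # One 1D embedding vector per image
labels = []       # Class label per image, in the same order
image_paths = []  # Absolute path per image, in the same order

# ... fill the three lists by running your images through an embedding model ...

np.savez(
    "embeddings.npz",                  # Output file (example name)
    embeddings=np.vstack(embeddings),  # Stack embeddings into a 2D array
    labels=np.array(labels),           # Convert labels to a 1D NumPy array
    image_paths=np.array(image_paths)  # Convert image paths to a 1D NumPy array
)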

Bundling all three arrays in a single NPZ file keeps embeddings, labels, and paths aligned by index, so the tool can load everything it needs for analysis in one step.
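Because the tool relies on row i of each array describing the same image, it is worth checking that alignment when the file is loaded. The sketch below is an assumption about how that loading step could look, not the tool's actual code; `load_dataset` is a hypothetical name, while the keys `embeddings`, `labels`, and `image_paths` match the saving code above.

```python
import numpy as np


def load_dataset(npz_path: str):
    """Load embeddings, labels and image paths from an NPZ file and check
    that the three arrays line up (one label and one path per embedding)."""
    with np.load(npz_path) as data:  # NpzFile can be used as a context manager
        embeddings = data["embeddings"]
        labels = data["labels"]
        image_paths = data["image_paths"]

    if embeddings.ndim != 2:
        raise ValueError(f"expected a 2D embeddings array, got shape {embeddings.shape}")
    n = embeddings.shape[0]
    if labels.shape != (n,) or image_paths.shape != (n,):
        raise ValueError(
            f"length mismatch: {n} embeddings, labels {labels.shape}, "
            f"image paths {image_paths.shape}"
        )
    # Return paths as a plain list of str, which is what the plotting code expects
    return embeddings, labels, image_paths.tolist()
```

Failing loudly here is cheaper than discovering later that a click in the plot opened the wrong image.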

2. Reducing Embeddings to 3D

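To make the high-dimensional image embeddings easier to visualize, the tool uses UMAP (Uniform Manifold Approximation and Projection) to compress them into 3D. UMAP focuses on preserving local neighborhood structure, so images that are close in the original embedding space generally stay close in the 3D projection; larger-scale distances, such as the gaps between clusters, are only approximately preserved and shouldn't be read too literally.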


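import numpy as np
import umap  # Provided by the umap-learn package


def reduce_embeddings(embeddings: np.ndarray) -> np.ndarray:
    # Project (n_samples, n_features) embeddings down to (n_samples, 3).
    # Optionally pass random_state=<int> for a reproducible layout.
    reducer = umap.UMAP(n_components=3)
    return reducer.fit_transform(embeddings)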
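Since UMAP's guarantee is about neighborhoods, one way to sanity-check a projection is to measure how many of each point's k nearest neighbors in the original space are still among its k nearest neighbors in 3D. The helper below is a hypothetical NumPy-only sketch, not part of the tool; it builds full pairwise distance matrices, so it is only practical for a few thousand points or a random subsample.

```python
import numpy as np


def _knn_indices(x: np.ndarray, k: int) -> np.ndarray:
    """Indices of the k nearest neighbors of every row, excluding the row itself."""
    sq = np.sum(x * x, axis=1)
    # Squared Euclidean distances via |a|^2 + |b|^2 - 2ab
    d2 = sq[:, None] + sq[None, :] - 2.0 * x @ x.T
    np.fill_diagonal(d2, np.inf)  # A point is not its own neighbor
    return np.argsort(d2, axis=1)[:, :k]


def neighbor_overlap(high: np.ndarray, low: np.ndarray, k: int = 10) -> float:
    """Mean fraction of each point's k nearest neighbors in `high` that are
    also among its k nearest neighbors in `low` (1.0 = fully preserved)."""
    if k >= len(high):
        raise ValueError("k must be smaller than the number of points")
    nn_high = _knn_indices(high.astype(np.float64), k)
    nn_low = _knn_indices(low.astype(np.float64), k)
    shared = [len(set(a) & set(b)) for a, b in zip(nn_high, nn_low)]
    return float(np.mean(shared)) / k
```

For example, `neighbor_overlap(embeddings, reduce_embeddings(embeddings), k=10)` gives a single number to compare across UMAP settings; a very low value suggests the 3D clusters shouldn't be trusted for spotting outliers.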

3. Visualizing Embeddings

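After reducing the embeddings, the tool generates an interactive 3D scatter plot using Matplotlib. Each point represents one image's reduced embedding, color-coded by its label. Because similar images tend to land near each other, this view makes clusters, anomalies, and outliers easy to spot: a point sitting inside another class's cluster, or far from every cluster, is worth a closer look.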


import random
from typing import List, Tuple

import matplotlib.pyplot as plt
import numpy as np
from mpl_toolkits.mplot3d import Axes3D  # noqa: F401 (registers the '3d' projection on older Matplotlib)


def graph_embeddings(reduced_embeddings: np.ndarray, labels: np.ndarray, paths: List[str]) -> None:
    def random_color() -> Tuple[float, float, float]:
        return random.random(), random.random(), random.random()

    # Assign one color per class, then look up each point's color
    unique_labels = np.unique(labels)
    label_to_color = {label: random_color() for label in unique_labels}
    point_colors = [label_to_color[label] for label in labels]

    fig = plt.figure(figsize=(10, 8))
    ax = fig.add_subplot(111, projection='3d')
    ax.scatter(
        reduced_embeddings[:, 0],
        reduced_embeddings[:, 1],
        reduced_embeddings[:, 2],
        c=point_colors,
        picker=True,  # Needed so that clicking a point fires a 'pick_event'
    )

    # One legend entry per class, drawn as a colored dot
    handles = [
        plt.Line2D([0], [0], marker='o', color='w',
                   markerfacecolor=label_to_color[label], markersize=10)
        for label in unique_labels
    ]
    ax.legend(handles, unique_labels, loc='best', title='Classes',
              fontsize='small', markerscale=2, frameon=True)

    ax.set_title('3D Cluster Map of Image Embeddings')

    # on_pick (defined in the next step) opens the clicked image for inspection
    fig.canvas.mpl_connect('pick_event', lambda event: on_pick(event, labels, paths))
    plt.show()
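Because `random_color` draws each class color independently, two classes can end up with nearly identical colors, and the palette changes on every run. One possible alternative, sketched here as a hypothetical helper `distinct_colors`, spaces hues evenly around the color wheel using the standard-library `colorsys` module, so colors are reproducible and as far apart in hue as possible. With many classes (roughly more than 15 to 20), neighboring hues will still be hard to tell apart.

```python
import colorsys
from typing import Dict, Hashable, Sequence, Tuple


def distinct_colors(
    unique_labels: Sequence[Hashable],
    saturation: float = 0.75,
    value: float = 0.9,
) -> Dict[Hashable, Tuple[float, float, float]]:
    """Map each label to an RGB tuple whose hue is evenly spaced on [0, 1)."""
    n = max(len(unique_labels), 1)
    return {
        label: colorsys.hsv_to_rgb(i / n, saturation, value)
        for i, label in enumerate(unique_labels)
    }
```

To try it, replace the `label_to_color = {...}` line in `graph_embeddings` with `label_to_color = distinct_colors(unique_labels)`; the rest of the function works unchanged because it only looks colors up by label.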

4. Identifying and Excluding Bad Data

When you click on a point in the 3D plot, the tool retrieves the corresponding image and label, displaying it for inspection. If the image is mislabeled or corrupted, you can exclude it by clicking the "Exclude" button. The tool logs the paths of excluded images to a text file, which can be used to filter out bad training data in future model runs.


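from datetime import datetime
from typing import List

import matplotlib.pyplot as plt
import numpy as np
from matplotlib.widgets import Button
from PIL import Image

# Paths excluded so far, plus references to open "Exclude" buttons
# (Matplotlib widgets stop responding if nothing holds a reference to them).
excluded_paths: List[str] = []
_buttons: List[Button] = []


def save_to_txt(paths: List[str]) -> None:
    unique_paths = sorted(set(paths))  # Deduplicate; sort for a stable file order
    now = datetime.now()
    file_name = f"images_to_exclude_from_training_{now.strftime('%Y_%m_%d')}.txt"
    # Note: mode "w" overwrites a log written earlier the same day
    with open(file_name, "w") as file:
        file.writelines(f"{path}\n" for path in unique_paths)


def on_pick(event, labels: np.ndarray, paths: List[str]) -> None:
    ind = event.ind[0]  # Index of the first point under the cursor
    image_path = paths[ind]
    image = Image.open(image_path)
    label = labels[ind]

    # Show the clicked image in its own window, titled with its label
    fig, ax = plt.subplots()
    ax.imshow(image)
    ax.set_title(f"Label: {label}")
    ax.axis("off")

    def exclude(event) -> None:
        excluded_paths.append(image_path)
        save_to_txt(excluded_paths)  # Rewrite the log with every exclusion so far
        print(f"Excluded: {image_path}")

    ax_button = fig.add_axes([0.8, 0.01, 0.1, 0.075])
    btn = Button(ax_button, 'Exclude')
    btn.on_clicked(exclude)
    _buttons.append(btn)
    print(f"Label: {label}, Image Path: {image_path}")
    fig.show()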
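The exclusion log only pays off once the training pipeline reads it. Below is a minimal sketch of that step, assuming the training code works with a list of absolute image paths and that each log uses the one-path-per-line format written by `save_to_txt`; `load_excluded` and `filter_paths` are hypothetical names. Because the log file name includes the date, sessions on different days produce separate logs, so the sketch accepts any number of them.

```python
from pathlib import Path
from typing import Iterable, List, Set


def load_excluded(log_files: Iterable[str]) -> Set[str]:
    """Union of all paths listed (one per line) in the given exclusion logs."""
    excluded: Set[str] = set()
    for log_file in log_files:
        for line in Path(log_file).read_text().splitlines():
            line = line.strip()
            if line:  # Skip blank lines
                excluded.add(line)
    return excluded


def filter_paths(image_paths: Iterable[str], excluded: Set[str]) -> List[str]:
    """Keep only training images whose path is not in the exclusion set."""
    return [p for p in image_paths if p not in excluded]
```

For example, `load_excluded(glob.glob("images_to_exclude_from_training_*.txt"))` collects every log in the working directory. Matching is an exact string comparison, which is one reason the NPZ file should store absolute paths.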

Key Features of the Dataset Cleaning Tool

  • Visualization of Embeddings in 3D Space: Reduces high-dimensional embeddings to 3D using UMAP, allowing for easy visualization and exploration.
  • Interactive Data Exploration: Enables rotation, zooming, and clicking on data points to inspect individual samples.
  • Efficient Identification of Bad Data: Helps spot mislabeled or corrupted images by revealing anomalies and outliers in the embedding space.
  • Simple Exclusion Mechanism: Provides an easy way to exclude bad data from future training by logging the paths of problematic images.
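Visual inspection can be complemented by a simple numeric ranking that suggests which points to click first. As a sketch that is not part of the tool described here, each sample can be scored by the fraction of its k nearest neighbors in the original embedding space that carry a different label; a high score means the sample sits inside another class's cluster. `label_disagreement` is a hypothetical name, and the full pairwise distance matrix limits it to a few thousand points.

```python
import numpy as np


def label_disagreement(embeddings: np.ndarray, labels: np.ndarray, k: int = 10) -> np.ndarray:
    """For each sample, the fraction of its k nearest neighbors (Euclidean,
    excluding itself) whose label differs from the sample's own label."""
    x = embeddings.astype(np.float64)
    sq = np.sum(x * x, axis=1)
    d2 = sq[:, None] + sq[None, :] - 2.0 * x @ x.T  # Squared pairwise distances
    np.fill_diagonal(d2, np.inf)                     # Exclude the point itself
    nn = np.argsort(d2, axis=1)[:, :k]               # k nearest neighbors per row
    labels = np.asarray(labels)
    return np.mean(labels[nn] != labels[:, None], axis=1)
```

Sorting by this score (for example, `np.argsort(-label_disagreement(embeddings, labels))`) gives a review order. A high score only flags a sample for inspection; genuinely ambiguous images score high too, so the final call still belongs to the person looking at the image.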

Getting Started with the Dataset Cleaning Tool

To explore how this tool can help clean your image classification datasets, head over to the GitHub repository where you’ll find all the code, along with detailed instructions on how to run the script.

Conclusion: Improving Machine Learning Model Performance

This tool provides a simple yet powerful way to identify and exclude problematic data from image classification datasets by leveraging embedding space visualization and dimensionality reduction techniques like UMAP. By filtering out bad training data, you can improve your machine learning model's performance without weeks of manual review.

If you’re facing similar challenges or need assistance with dataset quality or machine learning projects, feel free to connect with me on LinkedIn.


