By Nathan Burg
Every machine learning engineer has spent days, if not weeks, trying to fix misclassification issues with their models—tuning hyperparameters, augmenting data, and testing every possible approach—only to discover that the root cause was bad data hiding within the dataset. Manually parsing through thousands of images in search of mislabeled or corrupted samples can feel like an endless and frustrating task. But what if there was a more scientific and efficient way to diagnose these issues?
By leveraging embedding space and dimensionality reduction techniques, we can visualize patterns in our data, spot problematic samples, and clean our image classification datasets with precision—all without the guesswork. In this article, I’ll introduce a tool that does just that, helping you identify and exclude bad data in image classification datasets.
In practice, this tool offers a data-driven approach to cleaning image classification datasets. By visualizing the embedding space, you can spot clusters, anomalies, and outliers; inspect individual samples; and mark bad images for exclusion. Filtering out bad training data becomes a manageable task, leading to more accurate and reliable models without the need to manually sift through thousands of images.
Bad data in training sets can severely hamper the performance of machine learning models, especially in image classification tasks. Mislabeled images, corrupted files, and outliers can lead to poor model generalization and inaccurate predictions. Identifying and removing these problematic samples is crucial for improving model accuracy and reliability.
To address these challenges, I developed a tool that streamlines the process of cleaning image classification datasets. This tool utilizes embedding space visualization and dimensionality reduction to help you efficiently identify and exclude problematic data from your training sets.
The tool starts by reading an NPZ file that contains:

- the image embeddings,
- the corresponding labels, and
- the paths to the original image files.
To generate an NPZ file with this structure, you can use the following code to save the embeddings, labels, and image paths:
```python
import numpy as np

embeddings = []   # List of embedding vectors, one per image
labels = []       # List of labels
image_paths = []  # List of image paths

np.savez_compressed(
    "image_embeddings_labels_paths.npz",
    embeddings=np.vstack(embeddings),   # Stack embeddings into a 2D array
    labels=np.array(labels),            # Convert labels to a 1D NumPy array
    image_paths=np.array(image_paths)   # Convert image paths to a 1D NumPy array
)
```
This NPZ file format allows the tool to easily manage large image classification datasets for analysis.
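Once saved, all three arrays can be loaded back in a single call. Here is a minimal round-trip sketch with synthetic data (the array contents are illustrative; the file name matches the save snippet above):

```python
import numpy as np

# Build a tiny synthetic dataset: 4 embeddings of dimension 8
rng = np.random.default_rng(0)
embeddings = rng.normal(size=(4, 8))
labels = np.array(["cat", "dog", "cat", "dog"])
image_paths = np.array(["img/0.jpg", "img/1.jpg", "img/2.jpg", "img/3.jpg"])

np.savez_compressed(
    "image_embeddings_labels_paths.npz",
    embeddings=embeddings,
    labels=labels,
    image_paths=image_paths,
)

# np.load on an .npz archive returns a dict-like object keyed by array name
data = np.load("image_embeddings_labels_paths.npz")
print(data["embeddings"].shape)   # (4, 8)
print(data["labels"].tolist())
```

Keeping the three arrays in one compressed archive guarantees they stay aligned by index, which the tool relies on when mapping a clicked point back to its image path.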
To make the high-dimensional image embeddings easier to visualize, the tool uses UMAP (Uniform Manifold Approximation and Projection) for dimensionality reduction. This technique preserves the relative distances between points while compressing them into a 3D space.
```python
import numpy as np
import umap

def reduce_embeddings(embeddings: np.ndarray) -> np.ndarray:
    # Project the high-dimensional embeddings down to 3D for plotting
    reducer = umap.UMAP(n_components=3)
    return reducer.fit_transform(embeddings)
```
After reducing the embeddings, the tool generates an interactive 3D scatter plot using matplotlib. Each point represents a reduced embedding, color-coded by its label. This visualization helps in spotting clusters, anomalies, and outliers within the dataset.
```python
import random
from typing import List, Tuple

import matplotlib.pyplot as plt
import numpy as np

def graph_embeddings(reduced_embeddings: np.ndarray, labels: np.ndarray, paths: List[str]) -> None:
    def random_color() -> Tuple[float, float, float]:
        return random.random(), random.random(), random.random()

    # Assign a random color to each class
    unique_labels = np.unique(labels)
    label_to_color = {label: random_color() for label in unique_labels}
    point_colors = [label_to_color[label] for label in labels]

    fig = plt.figure(figsize=(10, 8))
    ax = fig.add_subplot(111, projection='3d')
    scatter = ax.scatter(
        reduced_embeddings[:, 0],
        reduced_embeddings[:, 1],
        reduced_embeddings[:, 2],
        c=point_colors,
        s=5,
        picker=True  # Enable pick events so points are clickable
    )

    # Build a legend with one proxy marker per class
    handles = [
        plt.Line2D([0], [0], marker='o', color='w',
                   markerfacecolor=label_to_color[label], markersize=10)
        for label in unique_labels
    ]
    ax.legend(handles, unique_labels, loc='best', title='Classes',
              fontsize='small', markerscale=2, frameon=True)

    ax.set_title('3D Cluster Map of Image Embeddings')
    ax.set_xlabel('UMAP-1')
    ax.set_ylabel('UMAP-2')
    ax.set_zlabel('UMAP-3')

    # Route clicks on points to the on_pick handler defined below
    fig.canvas.mpl_connect('pick_event', lambda event: on_pick(event, labels, paths))
    plt.show()
```
When you click on a point in the 3D plot, the tool retrieves the corresponding image and label, displaying it for inspection. If the image is mislabeled or corrupted, you can exclude it by clicking the "Exclude" button. The tool logs the paths of excluded images to a text file, which can be used to filter out bad training data in future model runs.
```python
from datetime import datetime
from typing import List

import matplotlib.pyplot as plt
import numpy as np
from matplotlib.widgets import Button
from PIL import Image

EXCLUDED_PATHS = []

def save_to_txt(paths: List[str]) -> None:
    # Write the deduplicated exclusion list to a dated text file
    unique_paths = list(set(paths))
    now = datetime.now()
    file_name = f"images_to_exclude_from_training_{now.strftime('%Y_%m_%d')}.txt"
    with open(file_name, "w") as file:
        file.writelines(f"{path}\n" for path in unique_paths)

def on_pick(event, labels: np.ndarray, paths: List[str]) -> None:
    # Look up the clicked point's image and label
    ind = event.ind[0]
    image_path = paths[ind]
    image = Image.open(image_path)
    label = labels[ind]

    # Show the image in a new figure for inspection
    fig, ax = plt.subplots()
    ax.imshow(image)
    ax.set_title(f"Label: {label}")
    ax.axis('off')

    def exclude(event) -> None:
        # Record the path and close the inspection window
        EXCLUDED_PATHS.append(image_path)
        save_to_txt(EXCLUDED_PATHS)
        plt.close(fig)
        print(f"Excluded: {image_path}")

    ax_button = plt.axes([0.8, 0.01, 0.1, 0.075])
    btn = Button(ax_button, 'Exclude')
    btn.on_clicked(exclude)
    plt.show()
    print(f"Label: {label}, Image Path: {image_path}")
```
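The exclusion file produced by save_to_txt can then be applied when assembling your next training run. A minimal sketch of that filtering step (the helper name filter_training_paths and the file name are illustrative, not part of the tool itself):

```python
from typing import List

def filter_training_paths(all_paths: List[str], exclude_file: str) -> List[str]:
    # Read the excluded paths written by save_to_txt (one path per line)
    with open(exclude_file) as f:
        excluded = {line.strip() for line in f if line.strip()}
    # Keep only the paths that were not flagged in the viewer
    return [p for p in all_paths if p not in excluded]

# Example: drop two flagged images from a candidate training set
with open("images_to_exclude.txt", "w") as f:
    f.write("img/1.jpg\nimg/3.jpg\n")

paths = ["img/0.jpg", "img/1.jpg", "img/2.jpg", "img/3.jpg"]
print(filter_training_paths(paths, "images_to_exclude.txt"))  # ['img/0.jpg', 'img/2.jpg']
```

Because matching is done on the exact path strings stored in the NPZ file, the exclusion list remains valid across runs as long as the dataset layout on disk does not change.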
To explore how this tool can help clean your image classification datasets, head over to the GitHub repository, where you’ll find all the code along with detailed instructions on how to run the script.
This tool provides a simple yet powerful way to identify and exclude problematic data from image classification datasets by leveraging embedding space visualization and dimensionality reduction techniques like UMAP. By filtering out bad training data, you can significantly improve your machine learning model's performance.
If you’re facing similar challenges or need assistance with dataset quality or machine learning projects, feel free to connect with me on LinkedIn.
Keywords: image classification datasets, machine learning, data cleaning, filtering out bad training data, embedding space, dimensionality reduction, UMAP, visualizing classification predictions, dataset quality.