Skip to content

MNIST classification

MNIST digit classification using skeletal graphs.

A demo of using SN-graphs as inputs to Graph Neural Networks. The model reaches roughly 89% test accuracy after 20 epochs of training.

Load MNIST and show some SN-graphs

import sys

sys.path.append("../src")

from sklearn.datasets import fetch_openml
import sn_graph as sn
import matplotlib.pyplot as plt
# Fetch MNIST from OpenML as flat NumPy arrays (one 784-pixel row per digit).
X, y = fetch_openml("mnist_784", version=1, return_X_y=True, as_frame=False)

# Scale intensities by 1/255 and restore the 28x28 image layout in one step.
X = (X / 255.0).reshape(-1, 28, 28)
X.shape
(70000, 28, 28)
NUM_IMAGES = 20

# Keep the raw digits next to their SN-graphs so each graph can be drawn
# over the image it was extracted from.
images = list(X[:NUM_IMAGES])
graphs = [
    sn.create_sn_graph(
        image,
        minimal_sphere_radius=1,
        edge_sphere_threshold=0.9,
        edge_threshold=0.9,
        return_sdf=True,
    )
    for image in images
]
# NOTE(review): the dataset class below builds its graphs with
# edge_sphere_threshold=0.8, so these previews may differ slightly from the
# graphs the model is actually trained on.

for i, (vertices, edges, sdf_array) in enumerate(graphs):
    # Render vertices/edges (with SDF information) on top of the original digit.
    img = sn.draw_sn_graph(vertices, edges, sdf_array, background_image=images[i])

    plt.imshow(img, cmap="gray")
    plt.title(
        f"""Label: {y[i]}\n  Number of vertices: {len(vertices)} \n Number of edges: {len(edges)}"""
    )
    plt.axis("off")
    plt.show()
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image

Implement the dataset, the Graph Neural Network, and the training loop

import torch
import numpy as np
from tqdm.notebook import tqdm

from torch.utils.data import Dataset
from torch_geometric.data import Data


class Mnist_Graph_Dataset(Dataset):
    """MNIST digits converted to SN-graphs and served as PyG ``Data`` objects.

    Each item's ``x`` holds one SDF value per vertex (optionally followed by
    normalised vertex coordinates), ``edge_index``/``edge_attr`` hold the edges
    and their Euclidean lengths, and ``y`` holds the digit label.

    Args:
        precompute_graphs: if True, build every graph once in ``__init__`` and
            serve the cached results; otherwise graphs are built on each access.
    """

    def __init__(self, precompute_graphs=False):
        print("Initialising dataset...")
        X, targets = self._load_mnist()
        self.X = X
        self.targets = targets
        if precompute_graphs:
            self.graphs = [
                self._getitem(index)
                for index in tqdm(range(len(self.X)), "Precomputing graphs")
            ]
        else:
            self.graphs = None

    def _load_mnist(self):
        """Download MNIST from OpenML and return ``(images, labels)``.

        Images are reshaped to ``(n_samples, 28, 28)`` and divided by 255;
        labels are returned exactly as ``fetch_openml`` provides them.
        """
        # Load MNIST
        X, targets = fetch_openml(
            "mnist_784", version=1, return_X_y=True, as_frame=False
        )

        # Reshape to (n_samples, 28, 28)
        X = X.reshape(-1, 28, 28)

        # Normalize
        X = X / 255.0

        return X, targets

    def __len__(self):
        return len(self.X)

    @staticmethod
    def _normalise_positions(pos):
        """Divide each coordinate column of ``pos`` by that column's maximum.

        A column whose maximum is 0 is divided by 1 instead, so it stays at 0
        rather than becoming NaN (0/0).
        """
        scale = pos.max(dim=0).values
        scale = torch.where(scale != 0, scale, torch.ones_like(scale))
        return pos / scale

    @staticmethod
    def _is_degenerate(graph):
        """Return True if ``graph`` has fewer than two vertices or fewer than two edges."""
        vertices, edges = graph[0], graph[1]
        return len(vertices) < 2 or len(edges) < 2

    def _graph_to_tensor(self, vertex_coords, edges, sdf_values, use_pos_features=True):
        """Convert an SN-graph ``(vertices, edges, sdf)`` into a PyG ``Data`` object.

        Args:
            vertex_coords: sequence of vertex coordinates; each coordinate (as a
                tuple) is also used to index ``sdf_values``.
            edges: sequence of ``(start_coord, end_coord)`` pairs whose endpoints
                appear in ``vertex_coords``.
            sdf_values: array indexed by a vertex coordinate tuple.
            use_pos_features: if True, append the normalised coordinates to
                ``x``; otherwise store them separately as ``pos``.

        Returns:
            ``Data`` with ``x`` of shape ``(V, 1)`` or ``(V, 1 + D)``,
            ``edge_index`` of shape ``(2, E)`` and ``edge_attr`` of shape ``(E, 1)``.
        """
        vertex_coords = np.array(vertex_coords)
        pos = torch.tensor(vertex_coords, dtype=torch.float)

        coord_to_idx = {tuple(coord): idx for idx, coord in enumerate(vertex_coords)}

        src_indices, dst_indices, edge_lengths = [], [], []
        # NOTE(review): each edge is stored in one direction only (start -> end),
        # so message passing flows one way. If the SN-graph is meant to be
        # undirected, adding reverse edges may help; left unchanged here.
        for start_coord, end_coord in edges:
            src_indices.append(coord_to_idx[tuple(start_coord)])
            dst_indices.append(coord_to_idx[tuple(end_coord)])
            edge_lengths.append(
                np.linalg.norm(np.array(end_coord) - np.array(start_coord))
            )

        edge_index = torch.tensor([src_indices, dst_indices], dtype=torch.long)
        edge_attr = torch.tensor(edge_lengths, dtype=torch.float).view(-1, 1)

        x = torch.tensor(
            [sdf_values[tuple(coord)] for coord in vertex_coords], dtype=torch.float
        ).view(-1, 1)
        pos_enc = self._normalise_positions(pos)

        if use_pos_features:
            # Stack SDF values with positional features
            x = torch.cat([x, pos_enc], dim=1)
            return Data(x=x, edge_index=edge_index, edge_attr=edge_attr)
        return Data(x=x, edge_index=edge_index, edge_attr=edge_attr, pos=pos_enc)

    def _getitem(self, index):
        """Build the graph for ``index``, falling back to the next usable sample.

        If the SN-graph of an image has fewer than two vertices or edges, the
        next index whose graph is usable is used instead, wrapping around past
        the end of the dataset. The returned label belongs to the sample that
        was actually used. NOTE(review): this substitution means a few images
        can appear under two indices, possibly on both sides of a train/test split.

        Raises:
            IndexError: if ``index`` is out of range.
            RuntimeError: if no sample in the dataset yields a usable graph.
        """
        n_samples = len(self.X)
        if not -n_samples <= index < n_samples:
            raise IndexError(f"index {index} is out of range for {n_samples} samples")

        # Iterative search instead of recursion: no recursion-depth limit on
        # long runs of unusable images, and no IndexError at the last index.
        for offset in range(n_samples):
            candidate = (index + offset) % n_samples
            graph = sn.create_sn_graph(
                self.X[candidate],  # this dataset's own images, not a global
                minimal_sphere_radius=1,
                edge_sphere_threshold=0.8,
                edge_threshold=0.9,
                return_sdf=True,
            )
            if self._is_degenerate(graph):
                continue
            geom_data = self._graph_to_tensor(*graph)
            geom_data.y = torch.tensor([int(self.targets[candidate])], dtype=torch.long)
            return geom_data

        raise RuntimeError("No sample in the dataset produced a usable SN-graph.")

    def __getitem__(self, index):
        return self.graphs[index] if self.graphs is not None else self._getitem(index)
# Build the dataset and compute every SN-graph up front so training epochs
# only read cached graphs.
dataset = Mnist_Graph_Dataset(
    precompute_graphs=True,
)

print(f"Length of the dataset: {len(dataset)}")
Initialising dataset...

/Users/tomasz/code/sn-graph/src/sn_graph/core.py:282: RuntimeWarning: Image is empty or there are no spheres larger than the minimal_sphere_radius: 1. No vertices will be placed.
  warnings.warn(

Length of the dataset: 70000

import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GATConv, global_max_pool


class GraphClassificationNetwork(torch.nn.Module):
    """Graph-level classifier: three GAT layers, global max pooling, then a 2-layer MLP.

    Args:
        vertex_features_dim: number of input features per vertex.
        edge_features_dim: number of input features per edge; only the first
            GAT layer consumes these.
        no_classes: number of output logits per graph.

    The attention coefficients returned by each of the first two GAT layers
    are passed to the next layer as its edge features, which is why the later
    layers take edge features whose width equals the previous head count.
    """

    def __init__(self, vertex_features_dim, edge_features_dim, no_classes):
        super().__init__()
        self.no_classes = no_classes

        heads = 2
        width1, width2, width3 = 8, 32, 128

        # Multi-head outputs are concatenated, so the next layer sees width * heads features.
        self.conv1 = GATConv(
            vertex_features_dim, width1, heads=heads, edge_dim=edge_features_dim
        )
        self.conv2 = GATConv(width1 * heads, width2, heads=heads, edge_dim=heads)
        self.conv3 = GATConv(width2 * heads, width3, edge_dim=heads)
        self.lin1 = nn.Linear(width3, 32)
        self.lin2 = nn.Linear(32, no_classes)

    def forward(self, data):
        """Return raw class logits of shape ``(num_graphs, no_classes)``."""
        edge_index, edge_attr = data.edge_index, data.edge_attr

        # Each of the first two layers also returns (edge_index, attention weights);
        # those weights become the edge features of the following layer.
        h, (edge_index, edge_attr) = self.conv1(
            data.x, edge_index, edge_attr, return_attention_weights=True
        )
        h = F.relu(h)
        h, (edge_index, edge_attr) = self.conv2(
            h, edge_index, edge_attr, return_attention_weights=True
        )
        h = F.relu(h)
        h = F.relu(self.conv3(h, edge_index, edge_attr))

        # A single un-batched graph has no ``batch`` vector: put every vertex in graph 0.
        batch = getattr(data, "batch", None)
        if batch is None:
            batch = h.new_zeros(h.size(0), dtype=torch.long)
        pooled = global_max_pool(h, batch)

        hidden = F.dropout(F.relu(self.lin1(pooled)), p=0.1, training=self.training)
        return self.lin2(hidden)
# vertex features: 1 SDF value + 2 normalised coordinates; edge feature: edge length.
model = GraphClassificationNetwork(vertex_features_dim=3, edge_features_dim=1, no_classes=10)

# Smoke test: a single forward pass on one un-batched graph.
output = model(dataset[0])
print(output.shape)
print(output)

print(f"Model parameters: {sum(p.numel() for p in model.parameters())}")
torch.Size([1, 10])
tensor([[ 0.1797, -0.2116,  0.1570, -0.1489, -0.2024, -0.0044, -0.0317, -0.0720,
         -0.1042,  0.0161]], grad_fn=<AddmmBackward0>)
Model parameters: 14954

import torch
from torch_geometric.loader import DataLoader
from sklearn.model_selection import train_test_split


from torch.utils.data import Subset

# Split the sample indices 85% / 15% with a fixed seed so the split is reproducible.
indices = list(range(len(dataset)))
print("Creating train-test split")
train_indices, test_indices = train_test_split(indices, test_size=0.15, random_state=42)

# Subsets share the (precomputed) parent dataset; only the index lists differ.
train_dataset, test_dataset = Subset(dataset, train_indices), Subset(dataset, test_indices)

print("Creating train loader")
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
print("Creating test loader")
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)

# Prefer the GPU when one is present.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Device: {device}")
Creating train-test split
Creating train loader
Creating test loader
Device: cpu

# Fresh model for training, moved onto the selected device.
model = GraphClassificationNetwork(
    vertex_features_dim=3,  # SDF value + 2 normalised coordinates per vertex
    edge_features_dim=1,  # edge length
    no_classes=10,
).to(device)

# Per the author's experiments, lr=0.001 and lr=0.0005 gave the best results so far.
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

# The model emits raw logits, which CrossEntropyLoss expects. For a
# single-logit binary setup the counterpart would be BCEWithLogitsLoss
# (plain BCELoss expects probabilities, not logits).
criterion = torch.nn.CrossEntropyLoss()


def train(loader, epoch):
    """Run one optimisation epoch of the module-level ``model`` over ``loader``.

    Uses the module-level ``model``, ``optimizer``, ``criterion`` and ``device``.

    Args:
        loader: iterable of PyG batches exposing ``y`` and ``num_graphs``.
        epoch: epoch number, used only for the progress-bar label.

    Returns:
        ``(mean_loss, accuracy)``: loss averaged over graphs (not batches) and
        the fraction of graphs classified correctly.

    Raises:
        ValueError: if ``loader`` yields no graphs.
    """
    model.train()
    total_loss = 0.0
    correct = 0
    total = 0

    for data in tqdm(loader, desc=f"Epoch {epoch}"):
        data = data.to(device)
        optimizer.zero_grad()

        # Forward pass
        out = model(data)
        loss = criterion(out, data.y)

        # Backward pass
        loss.backward()
        optimizer.step()

        # criterion returns the batch mean, so weight by batch size to
        # accumulate a per-graph sum.
        total_loss += loss.item() * data.num_graphs
        pred = out.argmax(dim=1)
        correct += (pred == data.y).sum().item()
        total += data.num_graphs

    if total == 0:
        raise ValueError("loader yielded no graphs")
    # Divide the per-graph sum by the number of graphs; dividing by
    # len(loader) (the number of batches) would inflate it by ~batch size.
    return total_loss / total, correct / total


def test(loader, epoch):
    """Evaluate the module-level ``model`` on ``loader`` without updating it.

    Uses the module-level ``model``, ``criterion`` and ``device``; gradients
    are disabled and the model is put in eval mode (dropout off).

    Args:
        loader: iterable of PyG batches exposing ``y`` and ``num_graphs``.
        epoch: epoch number, used only for the progress-bar label.

    Returns:
        ``(mean_loss, accuracy)``: loss averaged over graphs (not batches) and
        the fraction of graphs classified correctly.

    Raises:
        ValueError: if ``loader`` yields no graphs.
    """
    model.eval()
    total_loss = 0.0
    correct = 0
    total = 0

    with torch.no_grad():
        for data in tqdm(loader, desc=f"Epoch {epoch}"):
            data = data.to(device)

            # Forward pass only
            out = model(data)
            loss = criterion(out, data.y)

            # criterion returns the batch mean, so weight by batch size to
            # accumulate a per-graph sum.
            total_loss += loss.item() * data.num_graphs
            pred = out.argmax(dim=1)
            correct += (pred == data.y).sum().item()
            total += data.num_graphs

    if total == 0:
        raise ValueError("loader yielded no graphs")
    # Divide the per-graph sum by the number of graphs, not the number of batches.
    return total_loss / total, correct / total


# Training loop
num_epochs = 20
best_acc = 0

for epoch in range(1, num_epochs + 1):
    train_loss, train_acc = train(train_loader, epoch)
    test_loss, test_acc = test(test_loader, epoch)

    # Print metrics
    print(
        f"Epoch {epoch:03d}, Train loss: {train_loss:.4f}, Test loss: {test_loss:.4f}, Train Acc: {train_acc:.4f}, Test Acc: {test_acc:.4f}"
    )

    # Checkpoint whenever test accuracy improves.
    # NOTE(review): the checkpoint is chosen on the same split whose accuracy
    # is reported below, so "Best test accuracy" is optimistically biased; a
    # separate validation split would give an unbiased estimate.
    if test_acc > best_acc:
        best_acc = test_acc
        torch.save(model.state_dict(), "best_model.pth")

print(f"Best test accuracy: {best_acc:.4f}")
Epoch 001, Train loss: 51.0624, Test loss: 39.7174, Train Acc: 0.4303, Test Acc: 0.5825

Epoch 002, Train loss: 36.6956, Test loss: 29.7825, Train Acc: 0.6084, Test Acc: 0.6897

Epoch 003, Train loss: 28.6116, Test loss: 22.8890, Train Acc: 0.6955, Test Acc: 0.7588

Epoch 004, Train loss: 23.5567, Test loss: 19.4011, Train Acc: 0.7556, Test Acc: 0.8025

Epoch 005, Train loss: 21.0890, Test loss: 17.5863, Train Acc: 0.7837, Test Acc: 0.8248

Epoch 006, Train loss: 19.5933, Test loss: 18.7304, Train Acc: 0.7998, Test Acc: 0.8068

Epoch 007, Train loss: 18.4863, Test loss: 16.1587, Train Acc: 0.8125, Test Acc: 0.8398

Epoch 008, Train loss: 17.9499, Test loss: 15.9200, Train Acc: 0.8188, Test Acc: 0.8397

Epoch 009, Train loss: 17.0918, Test loss: 14.1985, Train Acc: 0.8257, Test Acc: 0.8563

Epoch 010, Train loss: 16.6442, Test loss: 15.0479, Train Acc: 0.8309, Test Acc: 0.8487

Epoch 011, Train loss: 16.2155, Test loss: 15.4319, Train Acc: 0.8352, Test Acc: 0.8428

Epoch 012, Train loss: 15.6987, Test loss: 13.4786, Train Acc: 0.8430, Test Acc: 0.8630

Epoch 013, Train loss: 15.0589, Test loss: 13.7935, Train Acc: 0.8497, Test Acc: 0.8573

Epoch 014, Train loss: 14.7263, Test loss: 13.0528, Train Acc: 0.8529, Test Acc: 0.8703

Epoch 015, Train loss: 14.3037, Test loss: 12.1046, Train Acc: 0.8557, Test Acc: 0.8800

Epoch 016, Train loss: 14.1628, Test loss: 11.9766, Train Acc: 0.8584, Test Acc: 0.8798

Epoch 017, Train loss: 13.9132, Test loss: 12.2917, Train Acc: 0.8596, Test Acc: 0.8778

Epoch 018, Train loss: 13.5079, Test loss: 11.2566, Train Acc: 0.8658, Test Acc: 0.8884

Epoch 019, Train loss: 13.3973, Test loss: 11.2064, Train Acc: 0.8662, Test Acc: 0.8862

Epoch 020, Train loss: 13.0579, Test loss: 11.3276, Train Acc: 0.8701, Test Acc: 0.8868
Best test accuracy: 0.8884