MNIST classification
Classifying MNIST digits using skeletal graphs.
A demo of using SN-graphs as inputs to Graph Neural Networks. The model achieves roughly 90% accuracy.
Load MNIST and show some SN-graphs
import sys
sys.path.append("../src")
from sklearn.datasets import fetch_openml
import sn_graph as sn
import matplotlib.pyplot as plt
# Load MNIST
X, y = fetch_openml("mnist_784", version=1, return_X_y=True, as_frame=False)
# Reshape to (n_samples, 28, 28)
X = X.reshape(-1, 28, 28)
# Normalize
X = X / 255.0
X.shape
NUM_IMAGES = 20
graphs = []
images = []
for img in X[:NUM_IMAGES]:
    images.append(img)
    graphs.append(
        sn.create_sn_graph(
            img,
            minimal_sphere_radius=1,
            edge_sphere_threshold=0.9,
            edge_threshold=0.9,
            return_sdf=True,
        )
    )
for i in range(len(graphs)):
    vertices, edges, sdf_array = graphs[i]
    img = sn.draw_sn_graph(vertices, edges, sdf_array, background_image=images[i])
    plt.imshow(img, cmap="gray")
    plt.title(
        f"""Label: {y[i]}\n Number of vertices: {len(vertices)} \n Number of edges: {len(edges)}"""
    )
    plt.axis("off")
    plt.show()
Implement dataset, Graph Neural Network, and training algorithm
import torch
import numpy as np
from tqdm.notebook import tqdm
from torch.utils.data import Dataset
from torch_geometric.data import Data
class Mnist_Graph_Dataset(Dataset):
    def __init__(self, precompute_graphs=False):
        print("Initialising dataset...")
        X, targets = self._load_mnist()
        self.X = X
        self.targets = targets
        if precompute_graphs:
            self.graphs = [
                self._getitem(index)
                for index in tqdm(range(len(self.X)), "Precomputing graphs")
            ]
        else:
            self.graphs = None
    def _load_mnist(self):
        # Load MNIST
        X, targets = fetch_openml(
            "mnist_784", version=1, return_X_y=True, as_frame=False
        )
        # Reshape to (n_samples, 28, 28)
        X = X.reshape(-1, 28, 28)
        # Normalize
        X = X / 255.0
        return X, targets

    def __len__(self):
        return len(self.X)
    def _graph_to_tensor(self, vertex_coords, edges, sdf_values, use_pos_features=True):
        vertex_coords = np.array(vertex_coords)
        pos = torch.tensor(vertex_coords, dtype=torch.float)

        # Map each vertex coordinate to its node index so edges can be expressed as index pairs
        coord_to_idx = {tuple(coord): idx for idx, coord in enumerate(vertex_coords)}

        src_indices = []
        dst_indices = []
        edge_lengths = []
        for start_coord, end_coord in edges:
            src_indices.append(coord_to_idx[tuple(start_coord)])
            dst_indices.append(coord_to_idx[tuple(end_coord)])
            edge_lengths.append(
                np.linalg.norm(np.array(end_coord) - np.array(start_coord))
            )

        edge_index = torch.tensor([src_indices, dst_indices], dtype=torch.long)
        edge_attr = torch.tensor(edge_lengths, dtype=torch.float).view(-1, 1)

        # Node feature: the SDF value at each vertex
        x = torch.tensor(
            [sdf_values[tuple(coord)] for coord in vertex_coords], dtype=torch.float
        ).view(-1, 1)

        # Normalize vertex coordinates for use as positional features
        pos_enc = pos / pos.max(dim=0)[0]
        if use_pos_features:
            # Stack SDF values with positional features
            x = torch.cat([x, pos_enc], dim=1)
            geom_data = Data(x=x, edge_index=edge_index, edge_attr=edge_attr)
        else:
            geom_data = Data(
                x=x, edge_index=edge_index, edge_attr=edge_attr, pos=pos_enc
            )
        return geom_data
    def _getitem(self, index):
        img = self.X[index]
        graph = sn.create_sn_graph(
            img,
            minimal_sphere_radius=1,
            edge_sphere_threshold=0.8,
            edge_threshold=0.9,
            return_sdf=True,
        )
        # If the extracted graph is degenerate (fewer than 2 vertices or edges), fall back to the next image
        if len(graph[0]) < 2 or len(graph[1]) < 2:
            return self._getitem(index + 1)
        geom_data = self._graph_to_tensor(*graph)
        geom_data.y = torch.tensor([int(self.targets[index])], dtype=torch.long)
        return geom_data

    def __getitem__(self, index):
        return self.graphs[index] if self.graphs is not None else self._getitem(index)
dataset = Mnist_Graph_Dataset(precompute_graphs=True)
print(f"Length of the dataset: {len(dataset)}")
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GATConv, global_max_pool
class GraphClassificationNetwork(torch.nn.Module):
    def __init__(self, vertex_features_dim, edge_features_dim, no_classes):
        super().__init__()
        # For binary classification, set no_classes to 1.
        self.no_classes = no_classes
        self.conv1 = GATConv(
            vertex_features_dim, 8, heads=2, edge_dim=edge_features_dim
        )
        # conv1 outputs 8 * 2 = 16 features; its per-head attention weights (dim 2) are reused as edge features
        self.conv2 = GATConv(16, 32, heads=2, edge_dim=2)
        self.conv3 = GATConv(64, 128, edge_dim=2)
        self.lin1 = nn.Linear(128, 32)
        self.lin2 = nn.Linear(32, no_classes)

    def forward(self, data):
        # Graph convolutions; attention weights from each layer are passed on as edge features to the next
        x, edge_index, edge_attr = data.x, data.edge_index, data.edge_attr
        x, (edge_index, edge_attr) = self.conv1(
            x, edge_index, edge_attr, return_attention_weights=True
        )
        x = F.relu(x)
        x, (edge_index, edge_attr) = self.conv2(
            x, edge_index, edge_attr, return_attention_weights=True
        )
        x = F.relu(x)
        x = self.conv3(x, edge_index, edge_attr)
        x = F.relu(x)

        # Global max pooling - handles both batched and single samples
        batch = getattr(data, "batch", None)
        if batch is None:
            batch = torch.zeros(x.size(0), dtype=torch.long, device=x.device)
        x = global_max_pool(x, batch)

        # FC layers for graph-level prediction
        x = self.lin1(x)
        x = F.relu(x)
        x = F.dropout(x, p=0.1, training=self.training)
        x = self.lin2(x)
        return x
model = GraphClassificationNetwork(
    vertex_features_dim=3, edge_features_dim=1, no_classes=10
)
output = model(dataset[0])
print(output.shape)
print(output)
print(f"Model parameters: {sum(p.numel() for p in model.parameters())}")
import torch
from torch_geometric.loader import DataLoader
from sklearn.model_selection import train_test_split
from torch.utils.data import Subset
# Assuming dataset and model are defined
indices = list(range(len(dataset)))
print("Creating train-test split")
# Train-test split
train_indices, test_indices = train_test_split(indices, test_size=0.15, random_state=42)
# Create datasets using torch.utils.data.Subset
train_dataset = Subset(dataset, train_indices)
test_dataset = Subset(dataset, test_indices)
print("Creating train loader")
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
print("Creating test loader")
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)
# Initialize model and optimizer
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Device: {device}")
model = GraphClassificationNetwork(
    vertex_features_dim=3, edge_features_dim=1, no_classes=10
).to(device)
# 0.001 or 0.0005 was the best so far!
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
criterion = torch.nn.CrossEntropyLoss()  # Change to BCEWithLogitsLoss for binary classification
def train(loader, epoch):
    model.train()
    total_loss = 0
    correct = 0
    total = 0
    for data in tqdm(loader, desc=f"Epoch {epoch}"):
        data = data.to(device)
        optimizer.zero_grad()

        # Forward pass
        out = model(data)
        loss = criterion(out, data.y)

        # Backward pass
        loss.backward()
        optimizer.step()

        # Accumulate metrics
        total_loss += loss.item() * data.num_graphs
        pred = out.argmax(dim=1)
        correct += (pred == data.y).sum().item()
        total += data.num_graphs

    # Average loss per sample (loss was weighted by the number of graphs in each batch)
    return total_loss / total, correct / total
def test(loader, epoch):
    model.eval()
    total_loss = 0
    correct = 0
    total = 0
    with torch.no_grad():
        for data in tqdm(loader, desc=f"Epoch {epoch}"):
            data = data.to(device)

            # Forward pass only
            out = model(data)
            loss = criterion(out, data.y)

            # Accumulate metrics
            total_loss += loss.item() * data.num_graphs
            pred = out.argmax(dim=1)
            correct += (pred == data.y).sum().item()
            total += data.num_graphs

    # Average loss per sample, consistent with train()
    return total_loss / total, correct / total
# Training loop
num_epochs = 20
best_acc = 0

for epoch in range(1, num_epochs + 1):
    train_loss, train_acc = train(train_loader, epoch)
    test_loss, test_acc = test(test_loader, epoch)

    # Print metrics
    print(
        f"Epoch {epoch:03d}, Train loss: {train_loss:.4f}, Test loss: {test_loss:.4f}, Train Acc: {train_acc:.4f}, Test Acc: {test_acc:.4f}"
    )

    # Save best model
    if test_acc > best_acc:
        best_acc = test_acc
        torch.save(model.state_dict(), "best_model.pth")

print(f"Best test accuracy: {best_acc:.4f}")