MNIST Handwritten Digit Classification
This notebook covers:
- intro & imports
- loading the dataset
- normalizing the data
- one-hot encoding (for demonstration)
- previewing the data (display_sample(num))
- defining a neural network (MLP)
- setting up loss & optimizer
- training loop
- evaluating & plotting accuracy/loss
- visualizing predictions
- ideas for improving model performance (CNN, augmentation, etc.)
Intro
Objective: Classify 28×28 grayscale digits 0–9. Approach: A small MLP baseline, then a small CNN from scratch, plus learning-rate tuning via a quick LR range test; augmentation (random affine, elastic) is discussed as a further improvement. Data: 60k train / 10k test, balanced. Result: TEST_ACC%; robust to small rotations (±15°). Stack: PyTorch, torchvision, matplotlib.
Imports
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
from pathlib import Path
import numpy as np

print(torch.backends.mps.is_available())

# set device
# device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
device
Load the MNIST dataset
data_dir = Path('./data')
# basic transform: to tensor (scales to [0,1])
base_transform = transforms.ToTensor()
train_dataset = datasets.MNIST(
root=data_dir,
train=True,
download=True,
transform=base_transform,
)
test_dataset = datasets.MNIST(
root=data_dir,
train=False,
download=True,
transform=base_transform,
)
len(train_dataset), len(test_dataset)
Normalize the data
PyTorch's ToTensor() already scales to [0,1]. If you want to normalize to MNIST's mean/std, do this:
mnist_mean = 0.1307
mnist_std = 0.3081
norm_transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((mnist_mean,), (mnist_std,)),
])
# apply to datasets
train_dataset.transform = norm_transform
test_dataset.transform = norm_transform
One-hot encode (for demonstration)
PyTorch's nn.CrossEntropyLoss expects integer class indices, not one-hot vectors. But we'll keep this section to illustrate one-hot encoding.
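Both claims are easy to verify: torch.nn.functional.one_hot produces the same encoding as the eye-indexing helper below, while nn.CrossEntropyLoss consumes raw logits plus integer labels directly (a small sketch):

```python
import torch
import torch.nn.functional as F
from torch import nn

labels = torch.tensor([0, 1, 5])

# built-in one-hot (returns int64; cast to float for training use)
one_hot = F.one_hot(labels, num_classes=10).float()
print(one_hot)

# CrossEntropyLoss takes raw logits and integer class indices -- no one-hot needed
logits = torch.randn(3, 10)
loss = nn.CrossEntropyLoss()(logits, labels)
print(loss.item())
```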
def to_one_hot(labels: torch.Tensor, num_classes: int = 10):
return torch.eye(num_classes)[labels]
# demo
demo_labels = torch.tensor([0, 1, 5])
to_one_hot(demo_labels, 10)
Preview the data
def display_sample(idx: int):
img, label = train_dataset[idx]
plt.imshow(img.squeeze(0), cmap='gray')
plt.title(f'Label: {label}')
plt.axis('off')
plt.show()
# try a few
display_sample(0)
display_sample(123)
MLP
Define a simple MLP model
class MNISTMLP(nn.Module):
def __init__(self):
super().__init__()
self.flatten = nn.Flatten()
self.net = nn.Sequential(
nn.Linear(28*28, 256),
nn.ReLU(),
nn.Linear(256, 128),
nn.ReLU(),
nn.Linear(128, 10), # logits
)
def forward(self, x):
x = self.flatten(x)
return self.net(x)
model = MNISTMLP().to(device)
model
Loss function and optimizer
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
DataLoaders and training loop
batch_size = 64
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
def train_one_epoch(model, dataloader, loss_fn, optimizer, device):
model.train()
running_loss = 0.0
correct = 0
total = 0
for X, y in dataloader:
X, y = X.to(device), y.to(device)
# forward
preds = model(X)
loss = loss_fn(preds, y)
# backward
optimizer.zero_grad()
loss.backward()
optimizer.step()
# stats
running_loss += loss.item() * X.size(0)
_, predicted = torch.max(preds, 1)
correct += (predicted == y).sum().item()
total += y.size(0)
epoch_loss = running_loss / total
epoch_acc = correct / total
return epoch_loss, epoch_acc
def evaluate(model, dataloader, loss_fn, device):
model.eval()
running_loss = 0.0
correct = 0
total = 0
with torch.no_grad():
for X, y in dataloader:
X, y = X.to(device), y.to(device)
preds = model(X)
loss = loss_fn(preds, y)
running_loss += loss.item() * X.size(0)
_, predicted = torch.max(preds, 1)
correct += (predicted == y).sum().item()
total += y.size(0)
epoch_loss = running_loss / total
epoch_acc = correct / total
return epoch_loss, epoch_acc
Run the training loop
num_epochs = 5
history = {"train_loss": [], "train_acc": [], "val_loss": [], "val_acc": []}
for epoch in range(num_epochs):
train_loss, train_acc = train_one_epoch(model, train_loader, loss_fn, optimizer, device)
val_loss, val_acc = evaluate(model, test_loader, loss_fn, device)
history['train_loss'].append(train_loss)
history['train_acc'].append(train_acc)
history['val_loss'].append(val_loss)
history['val_acc'].append(val_acc)
print(f"Epoch {epoch+1}/{num_epochs} | "
f"train_loss={train_loss:.4f}, train_acc={train_acc:.4f} | "
f"val_loss={val_loss:.4f}, val_acc={val_acc:.4f}")
Plot loss and accuracy
plt.figure()
plt.plot(history['train_loss'], label='train_loss')
plt.plot(history['val_loss'], label='val_loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()
plt.title('Loss over epochs')
plt.show()
plt.figure()
plt.plot(history['train_acc'], label='train_acc')
plt.plot(history['val_acc'], label='val_acc')
plt.xlabel('Epoch')
plt.ylabel('Accuracy')
plt.legend()
plt.title('Accuracy over epochs')
plt.show()
Visualize predictions
def show_predictions(model, dataset, n: int = 10):
model.eval()
plt.figure(figsize=(12, 3))
for i in range(n):
img, label = dataset[i]
with torch.no_grad():
logits = model(img.unsqueeze(0).to(device))
pred_label = logits.argmax(dim=1).item()
plt.subplot(1, n, i+1)
plt.imshow(img.squeeze(0), cmap='gray')
color = 'green' if pred_label == label else 'red'
plt.title(f'T:{label}\nP:{pred_label}', color=color)
plt.axis('off')
plt.tight_layout()
plt.show()
show_predictions(model, test_dataset, n=10)
Improving performance
Here are PyTorch-friendly ways to improve accuracy:
- Use a CNN (recommended for MNIST)
- Data augmentation (random rotations/shifts)
- Learning rate scheduling
- More epochs / AdamW / weight decay
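Of these, learning-rate scheduling is the cheapest to try; a minimal sketch with StepLR, using a stand-in linear model in place of the MLP above (parameters are illustrative):

```python
import torch
from torch import nn

# stand-in model/optimizer; in the notebook these would be the MLP and its Adam optimizer
net = nn.Linear(28 * 28, 10)
opt = torch.optim.Adam(net.parameters(), lr=1e-3)

# halve the learning rate every 2 epochs
scheduler = torch.optim.lr_scheduler.StepLR(opt, step_size=2, gamma=0.5)

for epoch in range(4):
    # ... one epoch of training would run here ...
    opt.step()          # optimizer steps first
    scheduler.step()    # then the scheduler, once per epoch
    print(epoch, scheduler.get_last_lr())
```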
mlp_history = history # keep MLP history
mlp_model = model  # keep MLP model
CNN
Example
class MNISTCNN(nn.Module):
def __init__(self):
super().__init__()
self.features = nn.Sequential(
nn.Conv2d(1, 32, 3, padding=1),
nn.ReLU(),
nn.MaxPool2d(2), # 14x14
nn.Conv2d(32, 64, 3, padding=1),
nn.ReLU(),
nn.MaxPool2d(2), # 7x7
)
self.classifier = nn.Sequential(
nn.Flatten(),
nn.Linear(64 * 7 * 7, 128),
nn.ReLU(),
nn.Linear(128, 10)
)
def forward(self, x):
x = self.features(x)
x = self.classifier(x)
return x
# define CNN (we already had the class in the notebook)
cnn_model = MNISTCNN().to(device)
cnn_loss_fn = nn.CrossEntropyLoss()
cnn_optimizer = torch.optim.Adam(cnn_model.parameters(), lr=1e-3)
cnn_history = {"train_loss": [], "train_acc": [], "val_loss": [], "val_acc": []}
num_epochs_cnn = 5
for epoch in range(num_epochs_cnn):
train_loss, train_acc = train_one_epoch(
cnn_model, train_loader, cnn_loss_fn, cnn_optimizer, device
)
val_loss, val_acc = evaluate(
cnn_model, test_loader, cnn_loss_fn, device
)
cnn_history["train_loss"].append(train_loss)
cnn_history["train_acc"].append(train_acc)
cnn_history["val_loss"].append(val_loss)
cnn_history["val_acc"].append(val_acc)
print(
f"[CNN] Epoch {epoch+1}/{num_epochs_cnn} | "
f"train_loss={train_loss:.4f}, train_acc={train_acc:.4f} | "
f"val_loss={val_loss:.4f}, val_acc={val_acc:.4f}"
)
Show Predictions
show_predictions(cnn_model, test_dataset, n=10)
Plot Model Comparisons
# compare loss
plt.figure()
plt.plot(mlp_history["val_loss"], label="MLP val loss")
plt.plot(cnn_history["val_loss"], label="CNN val loss")
plt.xlabel("Epoch")
plt.ylabel("Loss")
plt.title("Validation Loss: MLP vs CNN")
plt.legend()
plt.ylim(bottom=0)
plt.show()
# compare accuracy
plt.figure()
plt.plot(mlp_history["val_acc"], label="MLP val acc")
plt.plot(cnn_history["val_acc"], label="CNN val acc")
plt.xlabel("Epoch")
plt.ylabel("Accuracy")
plt.title("Validation Accuracy: MLP vs CNN")
plt.legend()
plt.ylim(bottom=0.9)
plt.show()
CNN With optimal LR
cnn_model = MNISTCNN().to(device)
cnn_loss_fn = nn.CrossEntropyLoss()
# start with a reasonable LR
cnn_optimizer = torch.optim.Adam(cnn_model.parameters(), lr=1e-3)
cnn_comparison_history = {"train_loss": [], "train_acc": [], "val_loss": [], "val_acc": []}
num_epochs_cnn = 8
best_val_loss = float("inf")
best_state_dict = None
best_epoch = -1

for epoch in range(num_epochs_cnn):
train_loss, train_acc = train_one_epoch(
cnn_model, train_loader, cnn_loss_fn, cnn_optimizer, device
)
val_loss, val_acc = evaluate(
cnn_model, test_loader, cnn_loss_fn, device
)
cnn_comparison_history["train_loss"].append(train_loss)
cnn_comparison_history["train_acc"].append(train_acc)
cnn_comparison_history["val_loss"].append(val_loss)
cnn_comparison_history["val_acc"].append(val_acc)
# save the best checkpoint; clone tensors, since state_dict() returns
# references that later training steps would silently overwrite
if val_loss < best_val_loss:
best_val_loss = val_loss
best_state_dict = {k: v.detach().clone() for k, v in cnn_model.state_dict().items()}
best_epoch = epoch
print(
f"[CNN] Epoch {epoch+1}/{num_epochs_cnn} | "
f"train_loss={train_loss:.4f}, train_acc={train_acc:.4f} | "
f"val_loss={val_loss:.4f}, val_acc={val_acc:.4f}"
)
print(f"Best val loss: {best_val_loss:.4f} at epoch {best_epoch+1}")
Continue training for more epochs
extra_epochs = 3
for epoch in range(extra_epochs):
train_loss, train_acc = train_one_epoch(
cnn_model, train_loader, cnn_loss_fn, cnn_optimizer, device
)
val_loss, val_acc = evaluate(
cnn_model, test_loader, cnn_loss_fn, device
)
print(
f"[CONT] Epoch {epoch+1}/{extra_epochs} | "
f"train_loss={train_loss:.4f}, train_acc={train_acc:.4f} | "
f"val_loss={val_loss:.4f}, val_acc={val_acc:.4f}"
)

#
# "figure out" best LR
#
def try_lrs(model_cls, lrs, train_loader, test_loader, device):
results = []
for lr in lrs:
model = model_cls().to(device)
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
train_loss, train_acc = train_one_epoch(model, train_loader, loss_fn, optimizer, device)
val_loss, val_acc = evaluate(model, test_loader, loss_fn, device)
results.append({
"lr": lr,
"train_loss": train_loss,
"val_loss": val_loss,
"val_acc": val_acc,
})
print(f"LR={lr:.5f} -> val_loss={val_loss:.4f}, val_acc={val_acc:.4f}")
return results
lrs_to_test = [1e-4, 3e-4, 1e-3, 0.0015, 0.002, 0.0025, 3e-3, 0.004, 0.007, 0.0095, 1e-2]
lr_results = try_lrs(MNISTCNN, lrs_to_test, train_loader, test_loader, device)
Pick Best LR
best_lr_entry = min(lr_results, key=lambda x: x["val_loss"])
best_lr = best_lr_entry["lr"]
print("Best LR found:", best_lr)
Visualize LR vs Loss
plt.figure()
plt.plot([r["lr"] for r in lr_results], [r["val_loss"] for r in lr_results], marker="o")
plt.xscale("log")
plt.xlabel("learning rate")
plt.ylabel("val loss")
plt.title("LR range test (lower is better)")
plt.show()
Retrain With Optimal LR
best_lr_model = MNISTCNN().to(device)
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(best_lr_model.parameters(), lr=best_lr)
best_lr_history = {"train_loss": [], "train_acc": [], "val_loss": [], "val_acc": []}
epochs = 6
for epoch in range(epochs):
train_loss, train_acc = train_one_epoch(best_lr_model, train_loader, loss_fn, optimizer, device)
val_loss, val_acc = evaluate(best_lr_model, test_loader, loss_fn, device)
best_lr_history["train_loss"].append(train_loss)
best_lr_history["train_acc"].append(train_acc)
best_lr_history["val_loss"].append(val_loss)
best_lr_history["val_acc"].append(val_acc)
print(
f"[Best-LR CNN] Epoch {epoch+1}/{epochs} | "
f"train_loss={train_loss:.4f}, train_acc={train_acc:.4f} | "
f"val_loss={val_loss:.4f}, val_acc={val_acc:.4f}"
)
CNN Comparison Viz
plt.figure()
plt.plot(cnn_history["val_loss"], label="CNN (1e-3)")
plt.plot(best_lr_history["val_loss"], label=f"CNN (best lr={best_lr})")
plt.xlabel("Epoch")
plt.ylabel("Val loss")
plt.ylim(bottom=0)
plt.legend()
plt.title("CNN vs CNN (best LR)")
plt.show()
🧩 Optimizer Tuning & Learning Rate Insights
By experimenting with the learning rate and number of epochs, we can see how sensitive deep learning models are to optimization settings.
- The learning rate determines how big each step of gradient descent is.
  - Too high → unstable training or oscillating loss.
  - Too low → very slow convergence.
  - Just right → smooth and steady improvement.
- The best loss value corresponds to the model parameters that minimize the validation loss.
  - Saving and restoring these weights (best_state_dict) lets us retain the model at its optimal state before overfitting begins.
- Running additional epochs after reaching the best loss can reveal whether the model is still improving or starting to overfit.
- A quick learning-rate range test helps identify where the loss improves fastest.
  - Training again with that "best" LR often yields faster convergence and better accuracy.
🧠 Takeaway:
Optimizing a neural network isn't just about architecture: the optimizer settings (learning rate, epochs, batch size) can make or break model performance.
Even small tuning changes can turn a 97% model into a 99% model!
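The save/restore step described above can be sketched with a tensor-level copy plus torch.save, here with a stand-in linear model (names and the file path are illustrative):

```python
import torch
from torch import nn

model = nn.Linear(4, 2)

# clone tensors: state_dict() returns references, so a plain assignment
# would silently track later weight updates
best_state_dict = {k: v.detach().clone() for k, v in model.state_dict().items()}

# simulate further training degrading the weights
with torch.no_grad():
    model.weight.add_(1.0)

# restore the best checkpoint (and optionally persist it to disk)
model.load_state_dict(best_state_dict)
torch.save(best_state_dict, "best_mnist.pt")
```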