Training a model¶

This is just an example notetebook, the full code for training models is implemented using hydra and allows to run multiple experiments in parallel.

In [1]:
# Navigate to the root of the project and import necessary libraries
from hydra import initialize, compose
initialize(config_path="../confs", job_name="notebook")

import sys
import os
from pathlib import Path

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

sys.path.append(os.path.abspath("../src"))
# We change the working directory to the root of the project
# Run this only once
root_path = Path.cwd().parent
os.chdir(root_path)
sys.path.append(root_path / "src")
/tmp/ipykernel_7494/1263814139.py:3: UserWarning: 
The version_base parameter is not specified.
Please specify a compatability version level, or None.
Will assume defaults for version 1.1
  initialize(config_path="../confs", job_name="notebook")
In [2]:
from experiments.datasets.datasets import GaussianNoiseDataset
from torchvision import datasets, transforms
import os
os.environ["HYDRA_FULL_ERROR"] = "1"


transform = transforms.Compose([
    transforms.ToTensor(),  # Converts to Tensor in [0, 1]
    transforms.Normalize(mean=[0.1307], std=[0.3081]),
    # reshape to 784
    transforms.Lambda(lambda x: x.view(28 * 28))
])

# 2. Load base MNIST dataset with normalization
base_mnist = datasets.MNIST(root="data", train=True, download=True, transform=transform)

# or load fashion-MNIST
# base_mnist = datasets.FashionMNIST(root="data", train=True, download=True, transform=transform)

# 3. Wrap it in GaussianNoiseDataset
noisy_mnist = GaussianNoiseDataset(base_dataset=base_mnist, noise_std=1)

# 4. Use with DataLoader
from torch.utils.data import DataLoader
loader = DataLoader(noisy_mnist, batch_size=512, shuffle=True)
In [3]:
from hydra.utils import instantiate
import os
os.environ["HYDRA_FULL_ERROR"] = "1"

import logging
logging.basicConfig(level=logging.INFO)

cfg = compose(config_name="flow_model/mnist_example")
flow_model = instantiate(cfg.flow_model)

trainer_cfg = compose(config_name="trainer/mnist_trainer")
trainer = instantiate(trainer_cfg.trainer)

# train the model
trainer.fit(flow_model, train_dataloaders=loader)
INFO:pytorch_lightning.utilities.rank_zero:`Trainer.fit` stopped: `max_epochs=50` reached.
In [4]:
from torchdiffeq import odeint
import torch

# get a batch of data
loader = DataLoader(noisy_mnist, batch_size=256, shuffle=True)
x0, x1, y = next(iter(loader))

y_target = 3 # 0-9

samples = flow_model.sample(x0, y=y_target*torch.ones(x0.shape[0], dtype=torch.long))

# plot the generated samples
import matplotlib.pyplot as plt
import numpy as np
import torchvision.utils as vutils

def plot_samples(samples, nrow=8):
    # Reshape the samples to 28x28
    samples = samples.view(-1, 1, 28, 28)
    # Unnormalize the samples
    samples = (samples * 0.3081) + 0.1307
    samples.clamp_(0, 1)  # Clamp values to [0, 1]
    # Create a grid of images
    grid = vutils.make_grid(samples, nrow=nrow, padding=2, normalize=True)
    # Convert to numpy array
    grid = grid.permute(1, 2, 0).numpy()
    # Plot the grid
    plt.figure(figsize=(10, 10))
    plt.imshow(grid)
    plt.axis('off')
    plt.show()


# Plot the generated samples
plot_samples(samples.detach().cpu(), nrow=16)