import torch
import torch.nn as nn
import pytorch_lightning as pl
import random
from torchvision.datasets import MNIST, FashionMNIST, CelebA
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torchvision.utils import save_image
from torch.optim import Adam
from torch.optim.lr_scheduler import ReduceLROnPlateau
import os
from typing import Optional


class Flatten(nn.Module):
    def forward(self, x):
        return x.view(x.size(0), -1)


class Stack(nn.Module):
    def __init__(self, channels, height, width):
        super(Stack, self).__init__()
        self.channels = channels
        self.height = height
        self.width = width

    def forward(self, x):
        return x.view(x.size(0), self.channels, self.height, self.width)


class VAE(pl.LightningModule):
    def __init__(self, latent_size: int, hidden_size: int, alpha: int,
                 lr: float, batch_size: int,
                 dataset: Optional[str] = None,
                 save_images: Optional[bool] = None,
                 save_path: Optional[str] = None, **kwargs):
        """Init function for the VAE

        Args:
            latent_size (int): Latent Hidden Size
            hidden_size (int): Hidden layer size (stored; the MLP widths
                below are fixed in this implementation)
            alpha (int): Hyperparameter to control the importance of
                reconstruction loss vs KL-Divergence Loss
            lr (float): Learning Rate, will not be used if auto_lr_find is used.
            batch_size (int): Batch size for the train/val dataloaders
            dataset (Optional[str]): Dataset to use
            save_images (Optional[bool]): Boolean to decide whether to save images
            save_path (Optional[str]): Path to save images
        """
        super().__init__()
        self.latent_size = latent_size
        self.hidden_size = hidden_size
        if save_images:
            self.save_path = f'{save_path}/{kwargs["model_type"]}_images/'
        self.save_hyperparameters()
        self.save_images = save_images
        self.lr = lr
        self.batch_size = batch_size
        # Encoder MLP: 784 -> 392 -> 196 -> 128 -> latent_size
        self.encoder = nn.Sequential(
            Flatten(),
            nn.Linear(784, 392), nn.BatchNorm1d(392), nn.LeakyReLU(0.1),
            nn.Linear(392, 196), nn.BatchNorm1d(196), nn.LeakyReLU(0.1),
            nn.Linear(196, 128), nn.BatchNorm1d(128), nn.LeakyReLU(0.1),
            nn.Linear(128, latent_size)
        )
        # Separate heads for the mean and log-variance of q(z|x)
        self.hidden2mu = nn.Linear(latent_size, latent_size)
        self.hidden2log_var = nn.Linear(latent_size, latent_size)
        self.alpha = alpha
        # Decoder mirrors the encoder; Tanh maps outputs to [-1, 1]
        self.decoder = nn.Sequential(
            nn.Linear(latent_size, 128), nn.BatchNorm1d(128), nn.LeakyReLU(0.1),
            nn.Linear(128, 196), nn.BatchNorm1d(196), nn.LeakyReLU(0.1),
            nn.Linear(196, 392), nn.BatchNorm1d(392), nn.LeakyReLU(0.1),
            nn.Linear(392, 784),
            Stack(1, 28, 28),
            nn.Tanh()
        )
        self.height = kwargs.get("height")
        self.width = kwargs.get("width")
        # Scale inputs from [0, 1] to [-1, 1] to match the Tanh output range
        self.data_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Lambda(lambda x: 2 * x - 1.)])
        self.dataset = dataset

    def encode(self, x):
        hidden = self.encoder(x)
        mu = self.hidden2mu(hidden)
        log_var = self.hidden2log_var(hidden)
        return mu, log_var

    def decode(self, x):
        x = self.decoder(x)
        return x

    def reparametrize(self, mu, log_var):
        # Reparametrization trick to allow gradients to backpropagate from the
        # stochastic part of the model
        sigma = torch.exp(0.5 * log_var)
        z = torch.randn_like(sigma)
        return mu + sigma * z

    def training_step(self, batch, batch_idx):
        x, _ = batch
        mu, log_var, x_out = self.forward(x)
        # KL divergence between q(z|x) = N(mu, sigma^2) and the prior N(0, I),
        # summed over latent dimensions and averaged over the batch
        kl_loss = (-0.5 * (1 + log_var - mu**2 -
                           torch.exp(log_var)).sum(dim=1)).mean(dim=0)
        recon_loss_criterion = nn.MSELoss()
        recon_loss = recon_loss_criterion(x, x_out)
        loss = recon_loss * self.alpha + kl_loss
        self.log('train_loss', loss, on_step=False,
                 on_epoch=True, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        x, _ = batch
        mu, log_var, x_out = self.forward(x)
        kl_loss = (-0.5 * (1 + log_var - mu**2 -
                           torch.exp(log_var)).sum(dim=1)).mean(dim=0)
        recon_loss_criterion = nn.MSELoss()
        recon_loss = recon_loss_criterion(x, x_out)
        loss = recon_loss * self.alpha + kl_loss
        self.log('val_kl_loss', kl_loss, on_step=False, on_epoch=True)
        self.log('val_recon_loss', recon_loss, on_step=False, on_epoch=True)
        self.log('val_loss', loss, on_step=False, on_epoch=True)
        return x_out, loss

    def validation_epoch_end(self, outputs):
        if not self.save_images:
            return
        if not os.path.exists(self.save_path):
            os.makedirs(self.save_path)
        # Save a random batch of reconstructions from this epoch
        choice = random.choice(outputs)
        output_sample = choice[0]
        output_sample = output_sample.reshape(-1, 1, self.height, self.width)
        # Map the Tanh output from [-1, 1] back to [0, 1] before saving
        output_sample = self.scale_image(output_sample)
        save_image(
            output_sample,
            f"{self.save_path}/epoch_{self.current_epoch+1}.png",
        )

    def configure_optimizers(self):
        optimizer = Adam(self.parameters(), lr=(self.lr or self.learning_rate))
        lr_scheduler = ReduceLROnPlateau(optimizer)
        return {
            "optimizer": optimizer,
            "lr_scheduler": lr_scheduler,
            "monitor": "val_loss"
        }

    def forward(self, x):
        mu, log_var = self.encode(x)
        hidden = self.reparametrize(mu, log_var)
        output = self.decode(hidden)
        return mu, log_var, output

    # Functions for dataloading
    def train_dataloader(self):
        if self.dataset == "mnist":
            train_set = MNIST('data/', download=True, train=True,
                              transform=self.data_transform)
        elif self.dataset == "fashion-mnist":
            train_set = FashionMNIST('data/', download=True, train=True,
                                     transform=self.data_transform)
        elif self.dataset == "celeba":
            train_set = CelebA('data/', download=False, split="train",
                               transform=self.data_transform)
        else:
            raise ValueError(f"Unknown dataset: {self.dataset}")
        return DataLoader(train_set, batch_size=self.batch_size, shuffle=True)

    def val_dataloader(self):
        if self.dataset == "mnist":
            val_set = MNIST('data/', download=True, train=False,
                            transform=self.data_transform)
        elif self.dataset == "fashion-mnist":
            val_set = FashionMNIST('data/', download=True, train=False,
                                   transform=self.data_transform)
        elif self.dataset == "celeba":
            val_set = CelebA('data/', download=False, split="valid",
                             transform=self.data_transform)
        else:
            raise ValueError(f"Unknown dataset: {self.dataset}")
        return DataLoader(val_set, batch_size=self.batch_size)

    def scale_image(self, img):
        # Map images from [-1, 1] to [0, 1]
        out = (img + 1) / 2
        return out

    def interpolate(self, x1, x2):
        assert x1.shape == x2.shape, "Inputs must be of the same shape"
        if x1.dim() == 3:
            x1 = x1.unsqueeze(0)
        if x2.dim() == 3:
            x2 = x2.unsqueeze(0)
        if self.training:
            raise Exception(
                "This function should not be called when model is still "
                "in training mode. Use model.eval() before calling the "
                "function")
        mu1, lv1 = self.encode(x1)
        mu2, lv2 = self.encode(x2)
        z1 = self.reparametrize(mu1, lv1)
        z2 = self.reparametrize(mu2, lv2)
        # Linearly interpolate between the two latent codes and decode
        # each intermediate point, keeping the endpoints as first/last frames
        weights = torch.arange(0.1, 0.9, 0.1)
        intermediate = [self.decode(z1)]
        for wt in weights:
            inter = (1. - wt) * z1 + wt * z2
            intermediate.append(self.decode(inter))
        intermediate.append(self.decode(z2))
        out = torch.stack(intermediate, dim=0).squeeze(1)
        return out, (mu1, lv1), (mu2, lv2)
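

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original module). The hyperparameter
# values below are illustrative assumptions; `model_type`, `height`, and
# `width` are the kwargs the class reads via **kwargs above. After training,
# it demonstrates `interpolate` on two validation images.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    model = VAE(
        latent_size=64,
        hidden_size=128,
        alpha=1024,          # illustrative weight on the reconstruction term
        lr=1e-3,
        batch_size=64,
        dataset="mnist",
        save_images=True,
        save_path="outputs",
        model_type="vae",
        height=28,
        width=28,
    )
    trainer = pl.Trainer(max_epochs=5)
    trainer.fit(model)  # uses the train/val dataloaders defined on the module

    # Interpolate between two validation digits; the model must be in eval
    # mode, otherwise interpolate() raises.
    model.eval()
    x, _ = next(iter(model.val_dataloader()))
    frames, _, _ = model.interpolate(x[0], x[1])
    save_image(model.scale_image(frames), "outputs/interpolation.png")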