StyleTransfer / app.py
John Guerrerio
removed share=True
5f0eb84
raw
history blame
3.83 kB
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
import torchvision.models as models
from torchvision.utils import save_image
import numpy as np
import gradio as gr
# Square side length every input image is resized to. 224 matches the
# 224x224 resolution VGG was trained on - we use VGG 19 as our pretrained CNN.
IMAGE_SIZE = 224
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# Build VGG-19 without downloading weights, then fill it from a local
# checkpoint file (path is relative to the working directory).
cnn = models.vgg19(weights=None)
# map_location="cpu" lets the checkpoint load even on a host that lacks the
# device it was saved from; train() moves the extractor to `device` later.
state_dict = torch.load("vgg19-dcbb9e9d.pth", map_location="cpu")
cnn.load_state_dict(state_dict)
class VGG(nn.Module):
    """Frozen feature extractor exposing intermediate VGG-19 activations.

    Wraps the first 29 modules of the module-level ``cnn.features``. Each
    forward pass returns the outputs of the modules whose index appears in
    ``self.layers``. Style transfer only optimizes the input image, so the
    wrapped network's parameters are frozen.
    """

    def __init__(self):
        super(VGG, self).__init__()
        # Module indices (as strings) whose outputs serve as the content/style
        # representations. In torchvision's VGG-19 layout these should be the
        # first conv of each block (conv1_1 ... conv5_1) - verify if the
        # backbone changes.
        self.layers = ['0', '5', '10', '19', '28']
        # Index 28 is the deepest layer we read, so later modules are dropped.
        self.model = cnn.features[:29]
        # The pretrained weights are never optimized (the optimizer only owns
        # the generated image), so gradients computed for them would just pile
        # up in .grad without ever being zeroed. Freezing them skips that work.
        # Note: this also freezes the shared modules of the global `cnn`.
        self.model.requires_grad_(False)

    def forward(self, x):
        """Run ``x`` through the truncated network.

        Returns:
            A list with the activation of every selected layer, ordered from
            shallowest to deepest.
        """
        features = []
        for layer_num, layer in enumerate(self.model):
            x = layer(x)
            # We don't care about the model output - only the outputs of the
            # individual selected layers.
            if str(layer_num) in self.layers:
                features.append(x)
        return features
# Preprocessing applied to each image received from the Gradio UI: convert it
# to a tensor, then resize it to a square IMAGE_SIZE x IMAGE_SIZE input for
# the VGG feature extractor.
_gradio_steps = [
    transforms.ToTensor(),
    transforms.Resize([IMAGE_SIZE, IMAGE_SIZE]),
]
gradio_transforms = transforms.Compose(_gradio_steps)
def sanitize_inputs(epochs, lr, cl, sl):
    """Validate the numeric hyper-parameters entered in the Gradio form.

    Args:
        epochs: Number of optimization steps; must be an integer >= 1.
        lr: Adam learning rate; must be in (0, 1].
        cl: Content-loss weight; must be in [0, 1].
        sl: Style-loss weight; must be in [0, 1].

    Returns:
        None when every value is acceptable. Otherwise a two-item list
        ``[error_message, None]``, shaped like train()'s (message, image)
        outputs so it can be returned directly.

    A missing value (None) or NaN is reported as invalid, rather than raising
    a TypeError or slipping past the range checks.
    """
    if epochs is None:
        return ["Epochs must be an integer", None]
    if epochs < 1:
        return ["Epochs must be positive", None]
    if not isinstance(epochs, int):
        return ["Epochs must be an integer", None]
    # `not lr > 0` (instead of `lr < 0`) also rejects 0, a learning rate that
    # would leave the image unchanged. It also rejects NaN, for which every
    # comparison is False.
    if lr is None or not lr > 0:
        return ["Learning rate must be positive", None]
    if lr > 1:
        return ["Learning rate must be less than one", None]
    if cl is None or not 0 <= cl <= 1:
        return ["Content loss weight must be between 0 and 1", None]
    if sl is None or not 0 <= sl <= 1:
        return ["Style loss weight must be between 0 and 1", None]
    return None
def _gram_matrix(features):
    """Return the (channels x channels) Gram matrix of a single-image feature map.

    ``features`` must have batch size 1 (shape (1, C, H, W)); the ``view``
    below fails otherwise.
    """
    _, channel, height, width = features.shape
    flat = features.view(channel, height * width)
    return flat.mm(flat.t())


def train(Epochs, Learning_Rate, Content_Loss, Style_Loss, Content_Image, Style_Image):
    """Run neural style transfer and save the result to disk.

    Args:
        Epochs: Number of Adam steps taken on the generated image.
        Learning_Rate: Adam learning rate.
        Content_Loss: Weight applied to the content loss.
        Style_Loss: Weight applied to the style (Gram-matrix) loss.
        Content_Image: Image supplying the content; passed through
            ``gradio_transforms`` (presumably the array Gradio's image input
            delivers - confirm).
        Style_Image: Image supplying the style, preprocessed the same way.

    Returns:
        ``[message, image_path]`` on success, where the image has been written
        to "generated_gradio.png". On invalid inputs, the ``[error, None]`` list
        produced by ``sanitize_inputs``.
    """
    errors = sanitize_inputs(Epochs, Learning_Rate, Content_Loss, Style_Loss)
    if errors is not None:
        return errors

    content = gradio_transforms(Content_Image).unsqueeze(0).to(device)
    style = gradio_transforms(Style_Image).unsqueeze(0).to(device)
    # Start from the content image. clone() already lives on `device`, and
    # making requires_grad_ the last call keeps it a leaf tensor, which the
    # optimizer requires.
    generated = content.clone().requires_grad_(True)
    model = VGG().to(device).eval()
    optimizer = optim.Adam([generated], lr=Learning_Rate)

    # The content targets and style Gram matrices never change between steps,
    # so compute them once, without building autograd graphs, instead of
    # re-running VGG on both images every epoch.
    with torch.no_grad():
        content_targets = model(content)
        style_grams = [_gram_matrix(feat) for feat in model(style)]

    for _ in range(Epochs):
        generated_features = model(generated)
        content_loss = 0
        style_loss = 0
        # Both losses are summed over every selected VGG layer.
        for gen_feat, content_feat, style_gram in zip(
            generated_features, content_targets, style_grams
        ):
            content_loss += torch.mean((gen_feat - content_feat) ** 2)
            style_loss += torch.mean((_gram_matrix(gen_feat) - style_gram) ** 2)
        total_loss = Content_Loss * content_loss + Style_Loss * style_loss
        optimizer.zero_grad()
        total_loss.backward()
        optimizer.step()

    save_image(generated, "generated_gradio.png")
    return ["No errors! Enjoy your new image!", "generated_gradio.png"]
# Gradio UI. The inputs map positionally onto train()'s parameters and the
# outputs onto its two-item return value.
demo = gr.Interface(
    fn=train,
    inputs=[
        # precision=0 makes Gradio round the value and pass an int. Without
        # it, the value may arrive as a float such as 5.0, which
        # sanitize_inputs rejects as "Epochs must be an integer".
        gr.Number(label="Epochs", precision=0),
        "number",
        "number",
        "number",
        "image",
        "image",
    ],
    outputs=[
        gr.Label(label="Error Messages"),
        gr.Image(label="Generated Image"),
    ],
    title="Neural Style Transfer",
    description="Perform neural style transfer on images of your choice! Provide a content image that contains the content you want to transform and a style image that contains the style you want to emulate.\n\nNote: Huggingface requires users to pay to gain access to GPUs, so this model is hosted on a cpu. Training for many epochs will take a VERY long time. Using a larger learning rate (e.g., 0.01) can help reduce the number of epochs you need.",
    theme=gr.themes.Soft()
)

# Launch only when run as a script, so importing this module has no server
# side effect.
if __name__ == "__main__":
    demo.launch(debug=True)