import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
import torchvision.models as models
from torchvision.utils import save_image
import numpy as np
import gradio as gr
IMAGE_SIZE = 224  # standard VGG input size - we use VGG19 as our pretrained CNN
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
cnn = models.vgg19(weights=None)
state_dict = torch.load("vgg19-dcbb9e9d.pth")
cnn.load_state_dict(state_dict)
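# Assumption: "vgg19-dcbb9e9d.pth" is torchvision's ImageNet-pretrained VGG19
# checkpoint, downloaded ahead of time so the Space doesn't fetch it at startup.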
class VGG(nn.Module):
    def __init__(self):
        super(VGG, self).__init__()
        # layers whose activations we use as representations:
        # conv1_1, conv2_1, conv3_1, conv4_1, conv5_1 in VGG19
        self.layers = ['0', '5', '10', '19', '28']
        self.model = cnn.features[:29]  # we don't care about later layers

    def forward(self, x):
        features = []
        for layer_num, layer in enumerate(self.model):
            x = layer(x)
            # we don't care about the model output - we care about the output of individual layers
            if str(layer_num) in self.layers:
                features.append(x)
        return features
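# For reference, the feature shapes returned for a single 224x224 input
# (assuming the standard torchvision VGG19 layout) are:
#   layer 0  (conv1_1) -> [1, 64, 224, 224]
#   layer 5  (conv2_1) -> [1, 128, 112, 112]
#   layer 10 (conv3_1) -> [1, 256, 56, 56]
#   layer 19 (conv4_1) -> [1, 512, 28, 28]
#   layer 28 (conv5_1) -> [1, 512, 14, 14]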
gradio_transforms = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize([IMAGE_SIZE, IMAGE_SIZE])
])
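# ToTensor converts the (H, W, C) uint8 numpy array that Gradio supplies into a
# float (C, H, W) tensor in [0, 1]; Resize then squares it to IMAGE_SIZE. Note
# that no ImageNet mean/std normalization is applied in this pipeline.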
def sanitize_inputs(epochs, lr, cl, sl):
    if epochs < 1:
        return ["Epochs must be positive", None]
    # Gradio "number" inputs arrive as floats, so check integrality rather than type
    if not float(epochs).is_integer():
        return ["Epochs must be an integer", None]
    if lr <= 0:
        return ["Learning rate must be positive", None]
    if lr > 1:
        return ["Learning rate must be less than one", None]
    if cl < 0 or cl > 1:
        return ["Content loss weight must be between 0 and 1", None]
    if sl < 0 or sl > 1:
        return ["Style loss weight must be between 0 and 1", None]
    return None
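# Example: sanitize_inputs(100, 0.01, 1, 0.01) returns None (inputs are valid),
# while sanitize_inputs(100, 5, 1, 0.01) returns
# ["Learning rate must be less than one", None].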
def train(Epochs, Learning_Rate, Content_Loss, Style_Loss, Content_Image, Style_Image):
    errors = sanitize_inputs(Epochs, Learning_Rate, Content_Loss, Style_Loss)
    if errors is not None:
        return errors
    content = gradio_transforms(Content_Image).unsqueeze(0).to(device)
    style = gradio_transforms(Style_Image).unsqueeze(0).to(device)
    # optimize the pixels of a copy of the content image
    generated = content.clone().requires_grad_(True)
    model = VGG().to(device).eval()
    optimizer = optim.Adam([generated], lr=Learning_Rate)
    # content and style features never change, so extract them once up front
    with torch.no_grad():
        contentFeatures = model(content)
        styleFeatures = model(style)
    for epoch in range(int(Epochs)):
        generatedFeatures = model(generated)
        styleLoss = 0
        contentLoss = 0
        for genFeat, contFeat, styleFeat in zip(generatedFeatures, contentFeatures, styleFeatures):
            batch_size, channel, height, width = genFeat.shape
            contentLoss += torch.mean((genFeat - contFeat) ** 2)
            # Gram matrices capture channel-wise feature correlations, i.e. style
            G = genFeat.view(channel, height * width).mm(genFeat.view(channel, height * width).t())
            A = styleFeat.view(channel, height * width).mm(styleFeat.view(channel, height * width).t())
            styleLoss += torch.mean((G - A) ** 2)
        total_loss = Content_Loss * contentLoss + Style_Loss * styleLoss
        optimizer.zero_grad()
        total_loss.backward()
        optimizer.step()
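    # At this point `generated` holds the stylized result. The objective minimized
    # above is total = Content_Loss * sum_l mean((F_gen^l - F_content^l)^2)
    #                + Style_Loss  * sum_l mean((G_gen^l - G_style^l)^2),
    # where F^l are layer-l feature maps and G^l their [C, C] Gram matrices.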
    save_image(generated, "generated_gradio.png")
    return ["No errors! Enjoy your new image!", "generated_gradio.png"]
demo = gr.Interface(
    fn=train,
    inputs=["number", "number", "number", "number", "image", "image"],
    outputs=[
        gr.Label(label="Error Messages"),
        gr.Image(label="Generated Image"),
    ],
    title="Neural Style Transfer",
    description="Perform neural style transfer on images of your choice! Provide a content image that contains the content you want to transform and a style image that contains the style you want to emulate.\n\nNote: Hugging Face requires users to pay for GPU access, so this model is hosted on a CPU. Training for many epochs will take a VERY long time. Using a larger learning rate (e.g., 0.01) can help reduce the number of epochs you need.",
    theme=gr.themes.Soft()
)
demo.launch(debug=True)