import gradio as gr
import numpy as np
import torch
from torch import nn
import imageio
import cv2

# Number of frames read for every sprite animation. Frames are expected on
# disk as ``<prefix>0.png`` ... ``<prefix>7.png``.
NUM_FRAMES = 8

# Black border (in pixels) added to the left and right of every frame by
# ``display_gif_pad``.
PAD_WIDTH = 125

# UI colour name -> index used in the sprite file names (source avatar).
_SOURCE_HAIR = {"green": "0", "yellow": "2", "rose": "4", "red": "7", "wine": "8"}
_SOURCE_TOP = {"brown": "0", "blue": "1", "white": "2"}
_SOURCE_BOTTOM = {"white": "0", "golden": "1", "red": "2", "silver": "3"}

# UI colour name -> index used in the sprite file names (target avatar).
_TARGET_HAIR = {"violet": "1", "silver": "3", "purple": "5", "grey": "6", "golden": "9"}
_TARGET_TOP = {"grey": "3", "khaki": "4", "linen": "5", "ocre": "6"}
_TARGET_BOTTOM = {"denim": "4", "olive": "5", "brown": "6"}


class Generator(nn.Module):
    """DCGAN-style generator built from four transposed convolutions.

    Refer to the link below for explanations about nc, nz, and ngf:
    https://pytorch.org/tutorials/beginner/dcgan_faces_tutorial.html#inputs

    With the kernel/stride/padding values used here, a 1x1 spatial input is
    upsampled 1 -> 3 -> 5 -> 12 -> 24, and the final ``Tanh`` bounds the
    output to [-1, 1].

    NOTE(review): this class is not instantiated anywhere in the visible
    script -- confirm whether it is still needed.
    """

    def __init__(self, nc=4, nz=100, ngf=64):
        """Build the network.

        Args:
            nc: Number of output channels.
            nz: Number of input (latent) channels.
            ngf: Base number of generator feature maps.
        """
        super().__init__()
        self.network = nn.Sequential(
            nn.ConvTranspose2d(nz, ngf * 4, 3, 1, 0, bias=False),
            nn.BatchNorm2d(ngf * 4),
            nn.ReLU(True),
            nn.ConvTranspose2d(ngf * 4, ngf * 2, 3, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 2),
            nn.ReLU(True),
            nn.ConvTranspose2d(ngf * 2, ngf, 4, 2, 0, bias=False),
            nn.BatchNorm2d(ngf),
            nn.ReLU(True),
            nn.ConvTranspose2d(ngf, nc, 4, 2, 1, bias=False),
            nn.Tanh(),
        )

    def forward(self, input):
        """Run ``input`` through the network and return the result."""
        return self.network(input)


def _frame_paths(file_name):
    """Return the ``NUM_FRAMES`` PNG paths that share the prefix ``file_name``."""
    return [f"{file_name}{frame}.png" for frame in range(NUM_FRAMES)]


def display_gif(file_name, save_name):
    """Assemble the frames ``<file_name>0.png`` .. ``<file_name>7.png`` into a GIF.

    Args:
        file_name: Path prefix shared by the frame images.
        save_name: Destination path of the GIF. (Previously this argument was
            ignored and the GIF was always written to 'avatar_source.gif'.)

    Returns:
        Whatever ``imageio.mimsave`` returns.
    """
    images = [imageio.imread(path) for path in _frame_paths(file_name)]
    return imageio.mimsave(save_name, images)


def display_gif_pad(file_name, save_name):
    """Like ``display_gif``, but pad every frame horizontally before saving.

    Each frame is reduced to its first three channels and then given a
    ``PAD_WIDTH``-pixel constant border of value 0 on the left and right.

    Args:
        file_name: Path prefix shared by the frame images.
        save_name: Destination path of the GIF.

    Returns:
        Whatever ``imageio.mimsave`` returns.
    """
    images = []
    for path in _frame_paths(file_name):
        # Keep only the first three channels (drops a fourth channel if any).
        image = imageio.imread(path)[:, :, :3]
        images.append(
            cv2.copyMakeBorder(
                image, 0, 0, PAD_WIDTH, PAD_WIDTH, cv2.BORDER_CONSTANT, value=0
            )
        )
    return imageio.mimsave(save_name, images)


def display_image(file_name):
    """Copy the first frame (``<file_name>0.png``) to 'image.png'.

    The resolved frame path is printed before it is read.
    """
    image_filename = file_name + '0' + '.png'
    print(image_filename)
    image = imageio.imread(image_filename)
    imageio.imwrite('image.png', image)


def _sprite_prefix(domain_dir, action, body, bottom, top, hair):
    """Build the frame path prefix for one avatar configuration."""
    return f"./Sprite/frames/{domain_dir}/{action}/front_{body}{bottom}{top}{hair}_"


def run(domain_source, action_source, hair_source, top_source, bottom_source,
        domain_target, action_target, hair_target, top_target, bottom_target):
    """Render the selected source and target avatars to GIF files.

    Colour names coming from the UI are translated to the numeric indices
    used in the sprite file names; a value that is not in the corresponding
    table is passed through unchanged. The padded animations are written to
    'avatar_source.gif' and 'avatar_target.gif'.

    Args:
        domain_source: Label of the source domain (unused).
        action_source: Action directory name of the source avatar.
        hair_source: Hair colour name of the source avatar.
        top_source: Top colour name of the source avatar.
        bottom_source: Bottom colour name of the source avatar.
        domain_target: Label of the target domain (unused).
        action_target: Action directory name of the target avatar.
        hair_target: Hair colour name of the target avatar.
        top_target: Top colour name of the target avatar.
        bottom_target: Bottom colour name of the target avatar.

    Returns:
        The string 'demo.gif'.
    """
    # == Source Avatar ==
    body_source = '0'
    hair_source = _SOURCE_HAIR.get(hair_source, hair_source)
    top_source = _SOURCE_TOP.get(top_source, top_source)
    bottom_source = _SOURCE_BOTTOM.get(bottom_source, bottom_source)
    file_name_source = _sprite_prefix(
        'domain_1', action_source, body_source, bottom_source, top_source, hair_source
    )
    display_gif_pad(file_name_source, 'avatar_source.gif')

    # == Target Avatar ==
    body_target = '1'
    hair_target = _TARGET_HAIR.get(hair_target, hair_target)
    top_target = _TARGET_TOP.get(top_target, top_target)
    bottom_target = _TARGET_BOTTOM.get(bottom_target, bottom_target)
    file_name_target = _sprite_prefix(
        'domain_2', action_target, body_target, bottom_target, top_target, hair_target
    )
    display_gif_pad(file_name_target, 'avatar_target.gif')

    # NOTE(review): 'demo.gif' is not produced by any code visible here;
    # presumably it is a pre-existing file next to this script -- confirm.
    return 'demo.gif'


gr.Interface(
    run,
    inputs=[
        gr.Textbox(value="Source Avatar - Human", interactive=False),
        gr.Radio(choices=["slash", "spellcard", "walk"], value="slash"),
        gr.Radio(choices=["green", "yellow", "rose", "red", "wine"], value="green"),
        gr.Radio(choices=["brown", "blue", "white"], value="brown"),
        gr.Radio(choices=["white", "golden", "red", "silver"], value="white"),
        gr.Textbox(value="Target Avatar - Alien", interactive=False),
        gr.Radio(choices=["slash", "spellcard", "walk"], value="walk"),
        gr.Radio(choices=["violet", "silver", "purple", "grey", "golden"], value="golden"),
        gr.Radio(choices=["grey", "khaki", "linen", "ocre"], value="ocre"),
        gr.Radio(choices=["denim", "olive", "brown"], value="brown"),
    ],
    outputs=[
        gr.components.Image(type="file", label="Domain Disentanglement"),
    ],
    live=True,
    title="TransferVAE for Unsupervised Video Domain Adaptation",
).launch()