|
import cv2
import gradio as gr
import imageio
import numpy as np
import torch
from torch import nn
|
|
|
|
|
class Generator(nn.Module):
    """Transposed-convolution generator producing an ``nc``-channel image in [-1, 1].

    Three upsampling stages (ConvTranspose2d -> BatchNorm2d -> ReLU) are followed by
    a final ConvTranspose2d to ``nc`` channels and a Tanh. For an input with 1x1
    spatial size the output is 24x24 (from the kernel/stride/padding values below).

    Args:
        nc: Number of output channels.
        nz: Number of input (latent) channels; presumably the input is
            (batch, nz, 1, 1) -- TODO confirm against the training code.
        ngf: Base feature-map width; stages use ngf*4, ngf*2, ngf channels.
    """

    def __init__(self, nc=4, nz=100, ngf=64):
        super().__init__()
        # (in_channels, out_channels, kernel_size, stride, padding) per hidden stage.
        stage_specs = [
            (nz, ngf * 4, 3, 1, 0),
            (ngf * 4, ngf * 2, 3, 2, 1),
            (ngf * 2, ngf, 4, 2, 0),
        ]
        layers = []
        for c_in, c_out, kernel, stride, pad in stage_specs:
            layers.extend([
                nn.ConvTranspose2d(c_in, c_out, kernel, stride, pad, bias=False),
                nn.BatchNorm2d(c_out),
                nn.ReLU(True),
            ])
        # Output head: project to nc channels and squash to [-1, 1].
        layers.extend([
            nn.ConvTranspose2d(ngf, nc, 4, 2, 1, bias=False),
            nn.Tanh(),
        ])
        # Attribute name and layer order are kept so state_dict keys stay "network.<i>.*".
        self.network = nn.Sequential(*layers)

    def forward(self, input):
        """Run ``input`` through the generator network and return the image tensor."""
        return self.network(input)
|
|
|
|
|
def display_gif(file_name, save_name, num_frames=8):
    """Assemble consecutive PNG frames into an animated GIF written to ``save_name``.

    Frames are read from ``f"{file_name}{i}.png"`` for ``i`` in ``range(num_frames)``,
    so ``file_name`` is a path prefix (e.g. ``".../front_0000_"``), not a directory.

    Args:
        file_name: Path prefix shared by all frame files.
        save_name: Output path of the GIF.
        num_frames: Number of consecutive frames to read (default 8).

    Returns:
        Whatever ``imageio.mimsave`` returns.
    """
    images = [imageio.imread(f"{file_name}{frame}.png") for frame in range(num_frames)]
    # Honor the caller-supplied output path instead of a hard-coded file name.
    return imageio.mimsave(save_name, images)
|
|
|
|
|
def display_gif_pad(file_name, save_name, num_frames=8, pad=100):
    """Assemble PNG frames into a GIF, padding each frame left and right with black.

    Frames are read from ``f"{file_name}{i}.png"`` for ``i`` in ``range(num_frames)``.
    Each frame keeps only its first three channels, and ``pad`` black columns are added
    on both sides before the animation is written to ``save_name``.

    Args:
        file_name: Path prefix shared by all frame files.
        save_name: Output path of the GIF. Callers pass distinct names (source vs.
            target avatar), so this must be honored to avoid overwriting one GIF
            with the other.
        num_frames: Number of consecutive frames to read (default 8).
        pad: Width in pixels of the black border added on the left and right (default 100).

    Returns:
        Whatever ``imageio.mimsave`` returns.
    """
    images = []
    for frame in range(num_frames):
        image = imageio.imread(f"{file_name}{frame}.png")
        # Keep the first three channels; presumably this drops alpha from RGBA
        # sprites -- TODO confirm. Requires a 3-D (H, W, C) array.
        image = image[:, :, :3]
        images.append(
            cv2.copyMakeBorder(image, 0, 0, pad, pad, cv2.BORDER_CONSTANT, value=0)
        )
    return imageio.mimsave(save_name, images)
|
|
|
|
|
def display_image(file_name):
    """Copy frame 0 of a sprite sequence (``<file_name>0.png``) to ``image.png``.

    The resolved frame path is printed before it is read.

    Args:
        file_name: Path prefix of the frame files.
    """
    first_frame_path = file_name + '0.png'
    print(first_frame_path)
    imageio.imwrite('image.png', imageio.imread(first_frame_path))
|
|
|
|
|
# Radio-button value -> attribute code embedded in the domain-1 sprite file names.
_HAIR_CODES = {"green": "0", "yellow": "2", "rose": "4", "red": "7", "wine": "8"}
_TOP_CODES = {"brown": "0", "blue": "1", "white": "2"}
_BOTTOM_CODES = {"white": "0", "golden": "1", "red": "2", "silver": "3"}
_ACTIONS = ("shoot", "slash", "spellcard", "thrust", "walk")

# Attribute codes sampled for the randomly chosen domain-2 (target) avatar.
_TARGET_HAIR_CODES = ["1", "3", "5", "6", "9"]
_TARGET_TOP_CODES = ["3", "4", "5", "6"]
_TARGET_BOTTOM_CODES = ["4", "5", "6"]


def _code_for(table, value, attribute):
    """Return ``table[value]``, raising a ValueError that names *attribute* if absent."""
    try:
        return table[value]
    except KeyError:
        raise ValueError(
            f"unknown {attribute} {value!r}; expected one of {sorted(table)}"
        ) from None


def _sprite_prefix(domain, action, body, bottom, top, hair):
    """Build the frame-path prefix; individual frames are ``<prefix><index>.png``."""
    return f"./Sprite/frames/{domain}/{action}/front_{body}{bottom}{top}{hair}_"


def run(action, hair, top, bottom):
    """Build GIFs for the user-costumed source avatar and a random target avatar.

    The source avatar comes from ``domain_1`` with body code ``0`` and the selected
    hair/top/bottom colours. The target avatar comes from ``domain_2`` with body code
    ``1`` and hair/top/bottom codes drawn at random with ``np.random.choice``. Each
    frame-path prefix is passed to ``display_gif_pad`` along with the GIF name that
    this function returns for it.

    Args:
        action: Sprite action folder name (one of ``_ACTIONS``).
        hair: Hair colour name (a key of ``_HAIR_CODES``).
        top: Top colour name (a key of ``_TOP_CODES``).
        bottom: Bottom colour name (a key of ``_BOTTOM_CODES``).

    Returns:
        ``('avatar_source.gif', 'avatar_target.gif')``.

    Raises:
        ValueError: If ``action`` or any colour name is not recognised. Previously an
            unknown value leaked verbatim into the file path.
    """
    # ``action`` is concatenated into a filesystem path, so only accept known values.
    if action not in _ACTIONS:
        raise ValueError(f"unknown action {action!r}; expected one of {list(_ACTIONS)}")

    source_prefix = _sprite_prefix(
        "domain_1",
        action,
        "0",
        _code_for(_BOTTOM_CODES, bottom, "bottom"),
        _code_for(_TOP_CODES, top, "top"),
        _code_for(_HAIR_CODES, hair, "hair"),
    )
    display_gif_pad(source_prefix, 'avatar_source.gif')

    # np.random.choice takes the candidates as ONE sequence argument; extra positional
    # arguments would be interpreted as size/replace/p. The draws happen in
    # hair, top, bottom order, as before.
    hair_target = np.random.choice(_TARGET_HAIR_CODES)
    top_target = np.random.choice(_TARGET_TOP_CODES)
    bottom_target = np.random.choice(_TARGET_BOTTOM_CODES)

    target_prefix = _sprite_prefix(
        "domain_2", action, "1", bottom_target, top_target, hair_target
    )
    display_gif_pad(target_prefix, 'avatar_target.gif')

    return 'avatar_source.gif', 'avatar_target.gif'
|
|
|
|
|
# Gradio UI: four radio groups feed run(); its two returned GIF paths are shown as images.
demo = gr.Interface(
    run,
    inputs=[
        gr.Radio(choices=["shoot", "slash", "spellcard", "thrust", "walk"], value="shoot"),
        gr.Radio(choices=["green", "yellow", "rose", "red", "wine"], value="green"),
        gr.Radio(choices=["brown", "blue", "white"], value="brown"),
        gr.Radio(choices=["white", "golden", "red", "silver"], value="white"),
    ],
    outputs=[
        # run() returns file paths; the Image component's accepted types are
        # "numpy", "pil" and "filepath" ("file" is not one of them).
        gr.components.Image(type="filepath", label="Source Avatar (Costumed by You)"),
        gr.components.Image(type="filepath", label="Target Avatar (Randomly Chosen)"),
    ],
    # Re-run on every input change; this also re-samples the random target avatar.
    live=True,
    title="TransferVAE for Unsupervised Video Domain Adaptation",
)
demo.launch()
|
|