# Gradio demo: TransferVAE for unsupervised video domain adaptation, demonstrated on Sprite avatar clips.
# The app encodes a source (human) and a target (alien) avatar clip, swaps their
# static (Zf) and dynamic (Zt) latents, and renders the comparison as a GIF.

import gradio as gr
import imageio
from math import ceil
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn


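# Multi-scale Temporal Relation Network (TRN) head: for each scale k from num_frames
# down to 2, it samples up to `subsample_num` k-frame combinations, fuses each with a
# small MLP, and sums them, producing one aggregated relation feature per scale.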
class RelationModuleMultiScale(torch.nn.Module):

    def __init__(self, img_feature_dim, num_bottleneck, num_frames):
        super(RelationModuleMultiScale, self).__init__()
        self.subsample_num = 3  # how many frame combinations to sample per scale
        self.img_feature_dim = img_feature_dim
        self.scales = [i for i in range(num_frames, 1, -1)]  # [num_frames, ..., 2]

        self.relations_scales = []
        self.subsample_scales = []
        for scale in self.scales:
            relations_scale = self.return_relationset(num_frames, scale)
            self.relations_scales.append(relations_scale)
            self.subsample_scales.append(min(self.subsample_num, len(relations_scale)))

        self.num_frames = num_frames
        self.fc_fusion_scales = nn.ModuleList()  # one fusion MLP per scale
        for i in range(len(self.scales)):
            scale = self.scales[i]
            fc_fusion = nn.Sequential(
                nn.ReLU(),
                nn.Linear(scale * self.img_feature_dim, num_bottleneck),
                nn.ReLU(),
            )
            self.fc_fusion_scales += [fc_fusion]

    def forward(self, input):
        # Full-scale relation: the single combination that uses all frames.
        act_scale_1 = input[:, self.relations_scales[0][0], :]
        act_scale_1 = act_scale_1.view(act_scale_1.size(0), self.scales[0] * self.img_feature_dim)
        act_scale_1 = self.fc_fusion_scales[0](act_scale_1)
        act_scale_1 = act_scale_1.unsqueeze(1)
        act_all = act_scale_1.clone()

        # Remaining scales: evenly subsample a few frame combinations and sum their fused features.
        for scaleID in range(1, len(self.scales)):
            act_relation_all = torch.zeros_like(act_scale_1)
            num_total_relations = len(self.relations_scales[scaleID])
            num_select_relations = self.subsample_scales[scaleID]
            idx_relations_evensample = [
                int(ceil(i * num_total_relations / num_select_relations))
                for i in range(num_select_relations)
            ]
            for idx in idx_relations_evensample:
                act_relation = input[:, self.relations_scales[scaleID][idx], :]
                act_relation = act_relation.view(act_relation.size(0), self.scales[scaleID] * self.img_feature_dim)
                act_relation = self.fc_fusion_scales[scaleID](act_relation)
                act_relation = act_relation.unsqueeze(1)
                act_relation_all += act_relation
            act_all = torch.cat((act_all, act_relation_all), 1)
        return act_all

    def return_relationset(self, num_frames, num_frames_relation):
        import itertools
        return list(itertools.combinations([i for i in range(num_frames)], num_frames_relation))


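# Shape sketch with this demo's settings: RelationModuleMultiScale(512, 256, 8) maps an
# input of shape (batch, 8, 512) to (batch, 7, 256), i.e. one 256-d relation feature per
# scale from 8 frames down to 2.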
class TransferVAE_Video(nn.Module):

    def __init__(self):
        super(TransferVAE_Video, self).__init__()
        self.f_dim = 512   # static (appearance) latent size
        self.z_dim = 512   # dynamic (motion) latent size, per frame
        self.fc_dim = 1024
        self.channels = 3
        self.frames = 8
        self.batch_size = 128
        self.dropout_rate = 0.5
        self.num_class = 15
        self.prior_sample = 'random'

        # Frame-level convolutional encoder/decoder (DCGAN-style backbone).
        import dcgan_64
        self.encoder = dcgan_64.encoder(self.fc_dim, self.channels)
        self.decoder = dcgan_64.decoder_woSkip(self.z_dim + self.f_dim, self.channels)
        self.fc_output_dim = self.fc_dim

        self.relu = nn.LeakyReLU(0.1)
        self.dropout_f = nn.Dropout(p=self.dropout_rate)
        self.dropout_v = nn.Dropout(p=self.dropout_rate)

        self.hidden_dim = 512
        self.f_rnn_layers = 1

        # LSTM prior over the dynamic latent z.
        self.z_prior_lstm_ly1 = nn.LSTMCell(self.z_dim, self.hidden_dim)
        self.z_prior_lstm_ly2 = nn.LSTMCell(self.hidden_dim, self.hidden_dim)
        self.z_prior_mean = nn.Linear(self.hidden_dim, self.z_dim)
        self.z_prior_logvar = nn.Linear(self.hidden_dim, self.z_dim)

        # Posterior networks: a bidirectional LSTM for the static latent f,
        # followed by an RNN for the per-frame dynamic latent z.
        self.z_lstm = nn.LSTM(self.fc_output_dim, self.hidden_dim, self.f_rnn_layers,
                              bidirectional=True, batch_first=True)
        self.f_mean = nn.Linear(self.hidden_dim * 2, self.f_dim)
        self.f_logvar = nn.Linear(self.hidden_dim * 2, self.f_dim)

        self.z_rnn = nn.RNN(self.hidden_dim * 2, self.hidden_dim, batch_first=True)
        self.z_mean = nn.Linear(self.hidden_dim, self.z_dim)
        self.z_logvar = nn.Linear(self.hidden_dim, self.z_dim)

        # Frame-level domain classifier.
        self.fc_feature_domain_frame = nn.Linear(self.z_dim, self.z_dim)
        self.fc_classifier_domain_frame = nn.Linear(self.z_dim, 2)

        # Video-level temporal aggregation (multi-scale TRN) and domain classifiers.
        self.num_bottleneck = 256
        self.TRN = RelationModuleMultiScale(self.z_dim, self.num_bottleneck, self.frames)
        self.bn_trn_S = nn.BatchNorm1d(self.num_bottleneck)
        self.bn_trn_T = nn.BatchNorm1d(self.num_bottleneck)
        self.feat_aggregated_dim = self.num_bottleneck

        self.fc_feature_domain_video = nn.Linear(self.feat_aggregated_dim, self.feat_aggregated_dim)
        self.fc_classifier_domain_video = nn.Linear(self.feat_aggregated_dim, 2)

        # One domain classifier per relation scale.
        self.relation_domain_classifier_all = nn.ModuleList()
        for i in range(self.frames - 1):
            relation_domain_classifier = nn.Sequential(
                nn.Linear(self.feat_aggregated_dim, self.feat_aggregated_dim),
                nn.ReLU(),
                nn.Linear(self.feat_aggregated_dim, 2)
            )
            self.relation_domain_classifier_all += [relation_domain_classifier]

        self.pred_classifier_video = nn.Linear(self.feat_aggregated_dim, self.num_class)
        self.fc_feature_domain_latent = nn.Linear(self.f_dim, self.f_dim)
        # Name kept as spelled ("doamin") so the pretrained checkpoint's state_dict keys still match.
        self.fc_classifier_doamin_latent = nn.Linear(self.f_dim, 2)


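    # encode_and_sample_post(): frames are encoded to (B, 8, 1024) features, a bidirectional
    # LSTM summarizes the clip into the static latent f of shape (B, 512), and an RNN over the
    # same outputs yields the per-frame dynamic latents z of shape (B, 8, 512).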
    def encode_and_sample_post(self, x):
        if isinstance(x, list):
            conv_x = self.encoder_frame(x[0])
        else:
            conv_x = self.encoder_frame(x)

        # Bidirectional LSTM over the per-frame features.
        lstm_out, _ = self.z_lstm(conv_x)

        # Static latent f: last step of the forward pass plus first step of the backward pass.
        backward = lstm_out[:, 0, self.hidden_dim:2 * self.hidden_dim]
        frontal = lstm_out[:, self.frames - 1, 0:self.hidden_dim]
        lstm_out_f = torch.cat((frontal, backward), dim=1)
        f_mean = self.f_mean(lstm_out_f)
        f_logvar = self.f_logvar(lstm_out_f)
        f_post = self.reparameterize(f_mean, f_logvar, random_sampling=False)

        # Dynamic latent z: one posterior per frame.
        features, _ = self.z_rnn(lstm_out)
        z_mean = self.z_mean(features)
        z_logvar = self.z_logvar(features)
        z_post = self.reparameterize(z_mean, z_logvar, random_sampling=False)

        # When x is a list of clips, also encode the static latent of the remaining clips.
        if isinstance(x, list):
            f_mean_list = [f_mean]
            f_post_list = [f_post]
            for t in range(1, 3, 1):
                conv_x = self.encoder_frame(x[t])
                lstm_out, _ = self.z_lstm(conv_x)
                backward = lstm_out[:, 0, self.hidden_dim:2 * self.hidden_dim]
                frontal = lstm_out[:, self.frames - 1, 0:self.hidden_dim]
                lstm_out_f = torch.cat((frontal, backward), dim=1)
                f_mean = self.f_mean(lstm_out_f)
                f_logvar = self.f_logvar(lstm_out_f)
                f_post = self.reparameterize(f_mean, f_logvar, random_sampling=False)
                f_mean_list.append(f_mean)
                f_post_list.append(f_post)
            f_mean = f_mean_list
            f_post = f_post_list

        return f_mean, f_logvar, f_post, z_mean, z_logvar, z_post


    def decoder_frame(self, zf):
        # Decode the concatenated (z, f) latent back to image frames.
        recon_x = self.decoder(zf)
        return recon_x

    def encoder_frame(self, x):
        # Flatten (batch, frames, C, H, W) into frames, encode each, and restore the time axis.
        x_shape = x.shape
        x = x.view(-1, x_shape[-3], x_shape[-2], x_shape[-1])
        x_embed = self.encoder(x)[0]
        return x_embed.view(x_shape[0], x_shape[1], -1)


    def reparameterize(self, mean, logvar, random_sampling=True):
        # Standard VAE reparameterization; with random_sampling=False just return the mean.
        if random_sampling:
            eps = torch.randn_like(logvar)
            std = torch.exp(0.5 * logvar)
            return mean + eps * std
        return mean


    def forward(self, x, beta):
        # beta is unused in this inference-only demo.
        _, _, f_post, _, _, z_post = self.encode_and_sample_post(x)
        # Broadcast the static latent f across all frames and decode together with z.
        if isinstance(f_post, list):
            f_expand = f_post[0].unsqueeze(1).expand(-1, self.frames, self.f_dim)
        else:
            f_expand = f_post.unsqueeze(1).expand(-1, self.frames, self.f_dim)
        zf = torch.cat((z_post, f_expand), dim=2)
        recon_x = self.decoder_frame(zf)
        return f_post, z_post, recon_x


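# I/O helpers: name2seq() loads the eight frames "<prefix>0.png" ... "<prefix>7.png" of one
# sprite clip (RGB, presumably 64x64 given the dcgan_64 backbone) into a (1, 8, 3, H, W)
# tensor scaled to [0, 1); concat() stitches saved figures into demo.gif.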
def name2seq(file_name):
    # Load the 8 frames "<file_name>0.png" ... "<file_name>7.png" as one clip tensor.
    images = []
    for frame in range(8):
        image_filename = file_name + str(frame) + '.png'
        image = imageio.imread(image_filename)
        images.append(image[:, :, :3])  # keep RGB only, dropping any alpha channel

    images = np.asarray(images, dtype='f') / 256.0
    images = images.transpose((0, 3, 1, 2))          # (8, C, H, W)
    images = torch.Tensor(images).unsqueeze(dim=0)   # (1, 8, C, H, W)
    return images


def concat(file_name):
    # Collect the per-frame figures "<file_name>0.png" ... "<file_name>7.png" into demo.gif.
    images = []
    for frame in range(8):
        image_filename = file_name + str(frame) + '.png'
        images.append(imageio.imread(image_filename))

    gif_filename = 'demo.gif'
    imageio.mimsave(gif_filename, images)
    return gif_filename


def MyPlot(frame_id, src_orig, tar_orig, src_recon, tar_recon, src_Zt, tar_Zt, src_Zf_tar_Zt, tar_Zf_src_Zt):
    # 2x4 grid, one row per domain: original, reconstruction, decoding from Zt only
    # (Zf zeroed), and decoding with the two avatars' Zf/Zt exchanged.
    fig, axs = plt.subplots(2, 4, sharex=True, sharey=True, figsize=(10, 5))

    axs[0, 0].imshow(src_orig)
    axs[0, 0].set_title("\n\n\nOriginal\nInput")
    axs[0, 0].axis('off')

    axs[1, 0].imshow(tar_orig)
    axs[1, 0].axis('off')

    axs[0, 1].imshow(src_recon)
    axs[0, 1].set_title("\n\n\nReconstructed\nOutput")
    axs[0, 1].axis('off')

    axs[1, 1].imshow(tar_recon)
    axs[1, 1].axis('off')

    axs[0, 2].imshow(src_Zt)
    axs[0, 2].set_title("\n\n\nOutput\nw/ Zt")
    axs[0, 2].axis('off')

    axs[1, 2].imshow(tar_Zt)
    axs[1, 2].axis('off')

    axs[0, 3].imshow(tar_Zf_src_Zt)
    axs[0, 3].set_title("\n\n\nExchange\nZt and Zf")
    axs[0, 3].axis('off')

    axs[1, 3].imshow(src_Zf_tar_Zt)
    axs[1, 3].axis('off')

    plt.subplots_adjust(hspace=0.06, wspace=0.05)

    save_name = 'MyPlot_{}.png'.format(frame_id)
    plt.savefig(save_name, dpi=200, format='png', bbox_inches='tight', pad_inches=0.0)
    plt.close(fig)  # avoid accumulating open figures across the 8 frames


# Load the pretrained weights on CPU and switch to inference mode.
model = TransferVAE_Video()
model.load_state_dict(torch.load('TransferVAE.pth.tar', map_location=torch.device('cpu'))['state_dict'])
model.eval()


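# run() maps the selected avatar attributes to the sprite file-naming scheme
# ("front_<body><bottom><top><hair>_<frame>.png"), encodes the source and target clips,
# decodes them with Zf zeroed and with Zf/Zt exchanged, plots a per-frame comparison,
# and assembles the eight figures into demo.gif for the Gradio output.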
def run(action_source, hair_source, top_source, bottom_source, action_target, hair_target, top_target, bottom_target):

    # Source avatar (human, body code '0'): translate the UI colour choices into the
    # attribute codes used in the sprite file names.
    body_source = '0'
    hair_source = {"green": '0', "yellow": '2', "rose": '4', "red": '7', "wine": '8'}[hair_source]
    top_source = {"brown": '0', "blue": '1', "white": '2'}[top_source]
    bottom_source = {"white": '0', "golden": '1', "red": '2', "silver": '3'}[bottom_source]

    file_name_source = './Sprite/frames/domain_1/' + action_source + '/'
    file_name_source = file_name_source + 'front' + '_' + body_source + bottom_source + top_source + hair_source + '_'

    # Target avatar (alien, body code '1').
    body_target = '1'
    hair_target = {"violet": '1', "silver": '3', "purple": '5', "grey": '6', "golden": '9'}[hair_target]
    top_target = {"grey": '3', "khaki": '4', "linen": '5', "ocre": '6'}[top_target]
    bottom_target = {"denim": '4', "olive": '5', "brown": '6'}[bottom_target]

    file_name_target = './Sprite/frames/domain_2/' + action_target + '/'
    file_name_target = file_name_target + 'front' + '_' + body_target + bottom_target + top_target + hair_target + '_'

    # Load both clips and run them through the VAE as a single batch of two videos.
    images_source = name2seq(file_name_source)
    images_target = name2seq(file_name_target)
    x = torch.cat((images_source, images_target), dim=0)

    with torch.no_grad():
        f_post, z_post, recon_x = model(x, [0] * 3)

    # Split the batch back into the source (index 0) and target (index 1) videos.
    src_orig_sample = x[0, :, :, :, :]
    src_recon_sample = recon_x[0, :, :, :, :]
    src_f_post = f_post[0, :].unsqueeze(0)
    src_z_post = z_post[0, :, :].unsqueeze(0)

    tar_orig_sample = x[1, :, :, :, :]
    tar_recon_sample = recon_x[1, :, :, :, :]
    tar_f_post = f_post[1, :].unsqueeze(0)
    tar_z_post = z_post[1, :, :].unsqueeze(0)

    # Decode the latent-manipulation variants once and reuse them for every frame:
    # Zt only (the static latent Zf zeroed out), and the two Zf/Zt exchanges.
    f_zero_src = 0 * src_f_post.unsqueeze(1).expand(-1, 8, 512)
    recon_x_src = model.decoder_frame(torch.cat((src_z_post, f_zero_src), dim=2)).squeeze()

    f_zero_tar = 0 * tar_f_post.unsqueeze(1).expand(-1, 8, 512)
    recon_x_tar = model.decoder_frame(torch.cat((tar_z_post, f_zero_tar), dim=2)).squeeze()

    f_expand_src = src_f_post.unsqueeze(1).expand(-1, 8, 512)
    recon_x_srcZf_tarZt = model.decoder_frame(torch.cat((tar_z_post, f_expand_src), dim=2)).squeeze()

    f_expand_tar = tar_f_post.unsqueeze(1).expand(-1, 8, 512)
    recon_x_tarZf_srcZt = model.decoder_frame(torch.cat((src_z_post, f_expand_tar), dim=2)).squeeze()

    # Render one comparison figure per frame.
    for frame in range(8):
        src_orig = src_orig_sample[frame, :, :, :].detach().numpy().transpose((1, 2, 0))
        tar_orig = tar_orig_sample[frame, :, :, :].detach().numpy().transpose((1, 2, 0))

        src_recon = src_recon_sample[frame, :, :, :].detach().numpy().transpose((1, 2, 0))
        tar_recon = tar_recon_sample[frame, :, :, :].detach().numpy().transpose((1, 2, 0))

        src_Zt = recon_x_src[frame, :, :, :].detach().numpy().transpose((1, 2, 0))
        tar_Zt = recon_x_tar[frame, :, :, :].detach().numpy().transpose((1, 2, 0))

        src_Zf_tar_Zt = recon_x_srcZf_tarZt[frame, :, :, :].detach().numpy().transpose((1, 2, 0))
        tar_Zf_src_Zt = recon_x_tarZf_srcZt[frame, :, :, :].detach().numpy().transpose((1, 2, 0))

        MyPlot(frame, src_orig, tar_orig, src_recon, tar_recon, src_Zt, tar_Zt, src_Zf_tar_Zt, tar_Zf_src_Zt)

    # Stitch the eight per-frame figures into a GIF for the Gradio output.
    concat('MyPlot_')
    return 'demo.gif'


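# Gradio front end: for each avatar, a Markdown header plus radio selectors for action,
# hair, top, and bottom; the resulting disentanglement GIF is shown as the output.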
gr.Interface(
    fn=run,
    inputs=[
        gr.Markdown("Source Avatar - Human 👦🏻"),
        gr.Radio(choices=["slash", "spellcard", "walk"], value="slash"),
        gr.Radio(choices=["green", "yellow", "rose", "red", "wine"], value="green"),
        gr.Radio(choices=["brown", "blue", "white"], value="brown"),
        gr.Radio(choices=["white", "golden", "red", "silver"], value="white"),
        gr.Markdown("Target Avatar - Alien 👽"),
        gr.Radio(choices=["slash", "spellcard", "walk"], value="walk"),
        gr.Radio(choices=["violet", "silver", "purple", "grey", "golden"], value="golden"),
        gr.Radio(choices=["grey", "khaki", "linen", "ocre"], value="ocre"),
        gr.Radio(choices=["denim", "olive", "brown"], value="brown"),
    ],
    outputs=[
        gr.components.Image(type="file", label="Domain Disentanglement"),
    ],
    live=True,
    cache_examples=True,
    title="TransferVAE for Unsupervised Video Domain Adaptation",
).launch()