# TranSVAE demo (app.py)
import argparse
import itertools
from math import ceil

import cv2
import gradio as gr
import imageio
import torch
from torch import nn

class RelationModuleMultiScale(torch.nn.Module):
    """Multi-scale Temporal Relation Network: fuses frame features at every
    temporal scale from `num_frames` down to 2-frame relations."""

    def __init__(self, img_feature_dim, num_bottleneck, num_frames):
        super(RelationModuleMultiScale, self).__init__()
        self.subsample_num = 3  # how many relation tuples to sample per scale
        self.img_feature_dim = img_feature_dim
        self.scales = list(range(num_frames, 1, -1))  # e.g. [8, 7, ..., 2] for 8 frames

        self.relations_scales = []
        self.subsample_scales = []
        for scale in self.scales:
            relations_scale = self.return_relationset(num_frames, scale)
            self.relations_scales.append(relations_scale)
            self.subsample_scales.append(min(self.subsample_num, len(relations_scale)))

        self.num_frames = num_frames
        self.fc_fusion_scales = nn.ModuleList()  # one fusion MLP per scale
        for scale in self.scales:
            fc_fusion = nn.Sequential(
                nn.ReLU(),
                nn.Linear(scale * self.img_feature_dim, num_bottleneck),
                nn.ReLU(),
            )
            self.fc_fusion_scales.append(fc_fusion)

    def forward(self, input):
        # Scale 0 uses the single relation that covers all frames at once.
        act_scale_1 = input[:, self.relations_scales[0][0], :]
        act_scale_1 = act_scale_1.view(act_scale_1.size(0), self.scales[0] * self.img_feature_dim)
        act_scale_1 = self.fc_fusion_scales[0](act_scale_1)
        act_scale_1 = act_scale_1.unsqueeze(1)
        act_all = act_scale_1.clone()

        # Remaining scales: evenly subsample a few frame tuples and sum their fused features.
        for scaleID in range(1, len(self.scales)):
            act_relation_all = torch.zeros_like(act_scale_1)
            num_total_relations = len(self.relations_scales[scaleID])
            num_select_relations = self.subsample_scales[scaleID]
            idx_relations_evensample = [
                int(ceil(i * num_total_relations / num_select_relations))
                for i in range(num_select_relations)
            ]
            for idx in idx_relations_evensample:
                act_relation = input[:, self.relations_scales[scaleID][idx], :]
                act_relation = act_relation.view(act_relation.size(0), self.scales[scaleID] * self.img_feature_dim)
                act_relation = self.fc_fusion_scales[scaleID](act_relation)
                act_relation = act_relation.unsqueeze(1)
                act_relation_all += act_relation
            act_all = torch.cat((act_all, act_relation_all), 1)
        return act_all

    def return_relationset(self, num_frames, num_frames_relation):
        # All ordered combinations of frame indices for the given relation size.
        return list(itertools.combinations(range(num_frames), num_frames_relation))
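
# A minimal shape sketch (example values, not taken from this app): with
# img_feature_dim=512, num_bottleneck=256 and num_frames=8, the module maps a
# (B, 8, 512) stack of frame features to one fused vector per temporal scale:
#   trn = RelationModuleMultiScale(img_feature_dim=512, num_bottleneck=256, num_frames=8)
#   out = trn(torch.randn(4, 8, 512))  # -> (4, 7, 256): 7 scales for 8 frames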

parser = argparse.ArgumentParser()
parser.add_argument('--dataset', default='Sprite', help='dataset name')
parser.add_argument('--data_root', default='dataset', help='root directory for data')
parser.add_argument('--num_class', type=int, default=15, help='number of classes')
parser.add_argument('--input_type', default='image', choices=['feature', 'image'], help='the type of input')
parser.add_argument('--src', default='domain_1', help='source domain')
parser.add_argument('--tar', default='domain_2', help='target domain')
parser.add_argument('--num_segments', type=int, default=8, help='number of frame segments')
parser.add_argument('--backbone', type=str, default='dcgan', choices=['dcgan', 'resnet101', 'I3Dpretrain', 'I3Dfinetune'], help='backbone')
parser.add_argument('--channels', default=3, type=int, help='input channels for image inputs')
parser.add_argument('--add_fc', default=1, type=int, metavar='M', help='number of additional fc layers (excluding the last fc layer), e.g. 0, 1, 2')
parser.add_argument('--fc_dim', type=int, default=1024, help='dimension of added fc')
parser.add_argument('--frame_aggregation', type=str, default='trn', choices=['rnn', 'trn'], help='aggregation of frame features (none if baseline_type is not video)')
parser.add_argument('--dropout_rate', default=0.5, type=float, help='dropout ratio for frame-level feature (default: 0.5)')
parser.add_argument('--f_dim', type=int, default=512, help='dimension of f')
parser.add_argument('--z_dim', type=int, default=512, help='dimensionality of z_t')
parser.add_argument('--f_rnn_layers', type=int, default=1, help='number of layers (content lstm)')
parser.add_argument('--use_bn', type=str, default='none', choices=['none', 'AdaBN', 'AutoDIAL'], help='normalization-based methods')
parser.add_argument('--prior_sample', type=str, default='random', choices=['random', 'post'], help='how to sample the prior')
parser.add_argument('--batch_size', default=128, type=int, help='batch size')
parser.add_argument('--use_attn', type=str, default='TransAttn', choices=['none', 'TransAttn', 'general'], help='attention mechanism')
parser.add_argument('--data_threads', type=int, default=5, help='number of data loading threads')
opt = parser.parse_args(args=[])
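
# parse_args(args=[]) ignores the real command line and keeps the defaults
# declared above, which is convenient inside a hosted demo. To override a
# value programmatically, for example:
#   opt = parser.parse_args(args=['--num_segments', '16'])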

def display_gif(file_name, save_name):
    # Collect the 8 frames '<file_name>0.png' ... '<file_name>7.png' into a GIF.
    images = []
    for frame in range(8):
        image_filename = file_name + str(frame) + '.png'
        images.append(imageio.imread(image_filename))
    return imageio.mimsave(save_name, images)

def display_gif_pad(file_name, save_name):
    # Same as display_gif, but drops the alpha channel and pads each frame
    # with 125 px of black on the left and right before saving.
    images = []
    for frame in range(8):
        image_filename = file_name + str(frame) + '.png'
        image = imageio.imread(image_filename)
        image = image[:, :, :3]  # RGBA -> RGB
        image_pad = cv2.copyMakeBorder(image, 0, 0, 125, 125, cv2.BORDER_CONSTANT, value=0)
        images.append(image_pad)
    return imageio.mimsave(save_name, images)

def display_image(file_name):
    # Save the first frame ('<file_name>0.png') as a static preview image.
    image = imageio.imread(file_name + '0' + '.png')
    imageio.imwrite('image.png', image)
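
# Example call (the path pattern matches how run() builds file names below):
#   display_gif_pad('./Sprite/frames/domain_1/walk/front_0000_', 'avatar_source.gif')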

def run(domain_source, action_source, hair_source, top_source, bottom_source,
        domain_target, action_target, hair_target, top_target, bottom_target):
    # == Source Avatar (human) ==
    # Map the colour names from the UI to the sprite-sheet attribute codes.
    body_source = '0'
    hair_source = {'green': '0', 'yellow': '2', 'rose': '4', 'red': '7', 'wine': '8'}[hair_source]
    top_source = {'brown': '0', 'blue': '1', 'white': '2'}[top_source]
    bottom_source = {'white': '0', 'golden': '1', 'red': '2', 'silver': '3'}[bottom_source]
    file_name_source = ('./Sprite/frames/domain_1/' + action_source + '/front_'
                        + body_source + bottom_source + top_source + hair_source + '_')
    display_gif_pad(file_name_source, 'avatar_source.gif')

    # == Target Avatar (alien) ==
    body_target = '1'
    hair_target = {'violet': '1', 'silver': '3', 'purple': '5', 'grey': '6', 'golden': '9'}[hair_target]
    top_target = {'grey': '3', 'khaki': '4', 'linen': '5', 'ocre': '6'}[top_target]
    bottom_target = {'denim': '4', 'olive': '5', 'brown': '6'}[bottom_target]
    file_name_target = ('./Sprite/frames/domain_2/' + action_target + '/front_'
                        + body_target + bottom_target + top_target + hair_target + '_')
    display_gif_pad(file_name_target, 'avatar_target.gif')

    return 'demo.gif'
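
# Gradio passes the components below to run() positionally, so their order
# must mirror run()'s parameter list: the source avatar's domain label,
# action, hair, top and bottom, then the same five fields for the target.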

gr.Interface(
    run,
    inputs=[
        gr.Textbox(value="Source Avatar - Human", interactive=False),
        gr.Radio(choices=["slash", "spellcard", "walk"], value="slash"),
        gr.Radio(choices=["green", "yellow", "rose", "red", "wine"], value="green"),
        gr.Radio(choices=["brown", "blue", "white"], value="brown"),
        gr.Radio(choices=["white", "golden", "red", "silver"], value="white"),
        gr.Textbox(value="Target Avatar - Alien", interactive=False),
        gr.Radio(choices=["slash", "spellcard", "walk"], value="walk"),
        gr.Radio(choices=["violet", "silver", "purple", "grey", "golden"], value="golden"),
        gr.Radio(choices=["grey", "khaki", "linen", "ocre"], value="ocre"),
        gr.Radio(choices=["denim", "olive", "brown"], value="brown"),
    ],
    outputs=[
        gr.Image(type="file", label="Domain Disentanglement"),
    ],
    live=True,
    title="TranSVAE for Unsupervised Video Domain Adaptation",
).launch()