Update app.py
Browse files
app.py
CHANGED
@@ -5,13 +5,11 @@ import imageio
|
|
5 |
import math
|
6 |
from math import ceil
|
7 |
import matplotlib.pyplot as plt
|
8 |
-
import matplotlib.animation as animation
|
9 |
import numpy as np
|
10 |
from PIL import Image
|
11 |
import torch
|
12 |
import torch.nn as nn
|
13 |
import torch.nn.functional as F
|
14 |
-
from torch.autograd import Function
|
15 |
|
16 |
|
17 |
class RelationModuleMultiScale(torch.nn.Module):
|
@@ -301,7 +299,7 @@ def MyPlot(frame_id, src_orig, tar_orig, src_recon, tar_recon, src_Zt, tar_Zt, s
|
|
301 |
|
302 |
|
303 |
# == Load Model ==
|
304 |
-
model = TransferVAE_Video(
|
305 |
model.load_state_dict(torch.load('TransferVAE.pth.tar', map_location=torch.device('cpu'))['state_dict'])
|
306 |
model.eval()
|
307 |
|
@@ -393,24 +391,24 @@ def run(domain_source, action_source, hair_source, top_source, bottom_source, do
|
|
393 |
tar_recon = tar_recon_sample[frame, :, :, :].detach().numpy().transpose((1, 2, 0))
|
394 |
|
395 |
# Zt
|
396 |
-
f_expand_src = 0 * src_f_post.unsqueeze(1).expand(-1, 8,
|
397 |
zf_src = torch.cat((src_z_post, f_expand_src), dim=2)
|
398 |
recon_x_src = model.decoder_frame(zf_src)
|
399 |
src_Zt = recon_x_src.squeeze()[frame, :, :, :].detach().numpy().transpose((1, 2, 0))
|
400 |
|
401 |
-
f_expand_tar = 0 * tar_f_post.unsqueeze(1).expand(-1, 8,
|
402 |
-
zf_tar = torch.cat((tar_z_post, f_expand_tar), dim=2)
|
403 |
recon_x_tar = model.decoder_frame(zf_tar)
|
404 |
tar_Zt = recon_x_tar.squeeze()[frame, :, :, :].detach().numpy().transpose((1, 2, 0))
|
405 |
|
406 |
# Zf_Zt
|
407 |
-
f_expand_src = src_f_post.unsqueeze(1).expand(-1, 8,
|
408 |
-
zf_srcZf_tarZt = torch.cat((tar_z_post, f_expand_src), dim=2)
|
409 |
recon_x_srcZf_tarZt = model.decoder_frame(zf_srcZf_tarZt)
|
410 |
src_Zf_tar_Zt = recon_x_srcZf_tarZt.squeeze()[frame, :, :, :].detach().numpy().transpose((1, 2, 0))
|
411 |
|
412 |
-
f_expand_tar = tar_f_post.unsqueeze(1).expand(-1, 8,
|
413 |
-
zf_tarZf_srcZt = torch.cat((src_z_post, f_expand_tar), dim=2)
|
414 |
recon_x_tarZf_srcZt = model.decoder_frame(zf_tarZf_srcZt)
|
415 |
tar_Zf_src_Zt = recon_x_tarZf_srcZt.squeeze()[frame, :, :, :].detach().numpy().transpose((1, 2, 0))
|
416 |
|
|
|
5 |
import math
|
6 |
from math import ceil
|
7 |
import matplotlib.pyplot as plt
|
|
|
8 |
import numpy as np
|
9 |
from PIL import Image
|
10 |
import torch
|
11 |
import torch.nn as nn
|
12 |
import torch.nn.functional as F
|
|
|
13 |
|
14 |
|
15 |
class RelationModuleMultiScale(torch.nn.Module):
|
|
|
299 |
|
300 |
|
301 |
# == Load Model ==
|
302 |
+
model = TransferVAE_Video()
|
303 |
model.load_state_dict(torch.load('TransferVAE.pth.tar', map_location=torch.device('cpu'))['state_dict'])
|
304 |
model.eval()
|
305 |
|
|
|
391 |
tar_recon = tar_recon_sample[frame, :, :, :].detach().numpy().transpose((1, 2, 0))
|
392 |
|
393 |
# Zt
|
394 |
+
f_expand_src = 0 * src_f_post.unsqueeze(1).expand(-1, 8, 512)
|
395 |
zf_src = torch.cat((src_z_post, f_expand_src), dim=2)
|
396 |
recon_x_src = model.decoder_frame(zf_src)
|
397 |
src_Zt = recon_x_src.squeeze()[frame, :, :, :].detach().numpy().transpose((1, 2, 0))
|
398 |
|
399 |
+
f_expand_tar = 0 * tar_f_post.unsqueeze(1).expand(-1, 8, 512)
|
400 |
+
zf_tar = torch.cat((tar_z_post, f_expand_tar), dim=2)
|
401 |
recon_x_tar = model.decoder_frame(zf_tar)
|
402 |
tar_Zt = recon_x_tar.squeeze()[frame, :, :, :].detach().numpy().transpose((1, 2, 0))
|
403 |
|
404 |
# Zf_Zt
|
405 |
+
f_expand_src = src_f_post.unsqueeze(1).expand(-1, 8, 512)
|
406 |
+
zf_srcZf_tarZt = torch.cat((tar_z_post, f_expand_src), dim=2)
|
407 |
recon_x_srcZf_tarZt = model.decoder_frame(zf_srcZf_tarZt)
|
408 |
src_Zf_tar_Zt = recon_x_srcZf_tarZt.squeeze()[frame, :, :, :].detach().numpy().transpose((1, 2, 0))
|
409 |
|
410 |
+
f_expand_tar = tar_f_post.unsqueeze(1).expand(-1, 8, 512)
|
411 |
+
zf_tarZf_srcZt = torch.cat((src_z_post, f_expand_tar), dim=2)
|
412 |
recon_x_tarZf_srcZt = model.decoder_frame(zf_tarZf_srcZt)
|
413 |
tar_Zf_src_Zt = recon_x_tarZf_srcZt.squeeze()[frame, :, :, :].detach().numpy().transpose((1, 2, 0))
|
414 |
|