# Spaces: Build error
import os | |
import sys | |
os.environ['CUDA_VISIBLE_DEVICES'] = '0' | |
sys.path.append(os.getcwd()) | |
from tqdm import tqdm | |
from transformers import Wav2Vec2Processor | |
from evaluation.metrics import LVD | |
import numpy as np | |
import smplx as smpl | |
from data_utils.lower_body import part2full, poses2pred, c_index_3d | |
from nets import * | |
from nets.utils import get_path, get_dpath | |
from trainer.options import parse_args | |
from data_utils import torch_data | |
from trainer.config import load_JsonConfig | |
import torch | |
from torch.utils import data | |
from data_utils.get_j import to3d, get_joints | |
from scripts.test_body import init_model, init_dataloader | |
def test(test_loader, generator, config):
    """Evaluate ``generator`` on ``test_loader`` and print the mean L1 error.

    For every batch the ground-truth poses and expressions are concatenated,
    converted with ``to3d`` and handed to ``generator.infer_on_audio`` as the
    initial pose together with the batch's audio file path.  The prediction is
    compared against the ``c_index_3d`` subset of the ground truth: absolute
    differences are summed over the last axis and then averaged.  The
    per-batch values are averaged over all batches and printed as
    ``capacity=<value>``.

    Args:
        test_loader: Iterable of batch dicts providing the keys ``'poses'``,
            ``'expression'``, ``'speaker'`` and ``'aud_file'``.
        generator: Model object exposing ``infer_on_audio``.
        config: Configuration object forwarded to ``to3d``.

    Returns:
        None. The averaged metric is only printed.
    """
    print('start testing')

    device = 'cuda'
    # NOTE(review): presumably maps dataset speaker ids onto the id range the
    # model was trained with -- confirm against the training code.
    speaker_id_offset = 20

    loss_sums = {}
    num_batches = 0

    with torch.no_grad():
        for bat in tqdm(test_loader, desc="Testing......"):
            num_batches += 1

            # Only the tensors that are actually consumed below are moved to
            # the GPU; audio features and betas are not used by this metric.
            poses = bat['poses'].to(device).to(torch.float32)
            exp = bat['expression'].to(device).to(torch.float32)
            speaker_id = bat['speaker'].to(device) - speaker_id_offset

            # NOTE(review): squeeze() and the [0] index on 'aud_file' look
            # like they assume a batch size of 1 -- confirm with the loader.
            poses = torch.cat([poses, exp], dim=-2).transpose(-1, -2).squeeze()
            poses = to3d(poses, config).unsqueeze(dim=0).transpose(1, 2)

            # The full pose tensor is passed as the initial pose; the
            # c_index_3d subset is selected only when computing the loss.
            pred = generator.infer_on_audio(
                bat['aud_file'][0],
                initial_pose=poses,
                id=speaker_id,
                fps=30,
                B=1,
            )
            pred = torch.tensor(pred, device=device)

            # Truncate the ground truth along its last axis to pred.shape[0]
            # so both operands cover the same span before subtracting.
            target = poses[:, c_index_3d, :pred.shape[0]].transpose(1, 2)
            bat_loss_dict = {'capacity': (target - pred).abs().sum(-1).mean()}

            # Running sum per metric; the first batch seeds the entry.
            for key, value in bat_loss_dict.items():
                if key in loss_sums:
                    loss_sums[key] = loss_sums[key] + value
                else:
                    loss_sums[key] = value

        # Skipped entirely when the loader yielded no batches, so there is
        # no division by a zero batch count.
        for key, total in loss_sums.items():
            mean_loss = total / num_batches
            print(key + '=' + str(mean_loss.item()))
def main():
    """Script entry point: set up config, data and model, then run ``test``.

    Parses the command-line arguments, makes ``args.gpu`` the current CUDA
    device, loads the JSON config named by ``args.config_file``, exports three
    path settings from the config as environment variables, builds the test
    dataloader, loads the ``'s2g_body_vq'`` / ``'n_com_8192'`` checkpoint and
    evaluates it.
    """
    args = parse_args().parse_args()

    torch.cuda.set_device(torch.device(args.gpu))

    config = load_JsonConfig(args.config_file)

    # Export the config's path settings under environment variables of the
    # same name.
    for path_key in ('smplx_npz_path', 'extra_joint_path', 'j14_regressor_path'):
        os.environ[path_key] = getattr(config, path_key)

    print('init dataloader...')
    # Only the loader is needed here; the dataset object and the
    # normalisation statistics returned alongside it are unused.
    _, test_loader, _ = init_dataloader(
        config.Data.data_root, args.speakers, args, config
    )

    print('init model...')
    model_name, model_type = 's2g_body_vq', 'n_com_8192'
    generator = init_model(
        model_name, get_path(model_name, model_type), args, config
    )

    test(test_loader, generator, config)
# Run the evaluation only when this file is executed as a script, so that
# importing it has no side effect beyond the definitions above.
if __name__ == '__main__':
    main()