|
|
|
|
|
import os |
|
|
|
# Job configuration is handed to this script through environment variables.
# os.environ.get() yields None for anything the launcher did not set.
inp_text = os.environ.get("inp_text")  # annotation list, opened as a UTF-8 file below
inp_wav_dir = os.environ.get("inp_wav_dir")
exp_name = os.environ.get("exp_name")
i_part = os.environ.get("i_part")  # index of this partition (converted with int() below)
all_parts = os.environ.get("all_parts")  # slicing stride over the input lines
# os.environ only accepts str values: assigning the None returned by .get()
# for a missing variable raises TypeError, so forward the GPU selection only
# when the launcher actually provided one.
if "_CUDA_VISIBLE_DEVICES" in os.environ:
    os.environ["CUDA_VISIBLE_DEVICES"] = os.environ["_CUDA_VISIBLE_DEVICES"]
opt_dir = os.environ.get("opt_dir")  # output directory for the manifest and features
bert_pretrained_dir = os.environ.get("bert_pretrained_dir")
# NOTE(review): eval() executes whatever expression the environment supplies.
# Kept as-is to preserve accepted values; consider a strict "True"/"False"
# parser if this variable can ever come from an untrusted source.
is_half = eval(os.environ.get("is_half", "True"))
|
import sys, numpy as np, traceback, pdb |
|
import os.path |
|
from glob import glob |
|
from tqdm import tqdm |
|
from text.cleaner import clean_text |
|
import torch |
|
from transformers import AutoModelForMaskedLM, AutoTokenizer |
|
import numpy as np |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from time import time as ttime |
|
import shutil |
|
|
|
|
|
def my_save(fea, path):
    """Serialize ``fea`` with ``torch.save`` and place the result at ``path``.

    The data is first written to a temporary ``.pth`` file in the current
    working directory and then moved to its destination, so ``torch.save``
    itself never opens the destination path.

    Args:
        fea: Object to serialize (anything ``torch.save`` accepts).
        path: Final location of the saved file.
    """
    dir_name = os.path.dirname(path)
    name = os.path.basename(path)
    # Timestamp plus partition index; presumably keeps temp names distinct
    # between workers started with different ``i_part`` values.
    tmp_path = "%s%s.pth" % (ttime(), i_part)
    try:
        torch.save(fea, tmp_path)
        # os.path.join leaves a bare file name untouched, whereas "%s/%s"
        # formatting turned it into "/<name>" (the filesystem root) whenever
        # ``path`` had no directory component.
        shutil.move(tmp_path, os.path.join(dir_name, name))
    finally:
        # After a successful move the temp file is gone; this only fires
        # when saving or moving failed part-way.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
|
|
|
|
|
# Manifest written by this partition: one tab-separated row per processed item.
txt_path = "%s/2-name2text-%s.txt" % (opt_dir, i_part)
# The manifest is the last thing written, so if it already exists this
# partition has been processed before and all of the work below is skipped.
if not os.path.exists(txt_path):
    bert_dir = "%s/3-bert" % (opt_dir)
    os.makedirs(opt_dir, exist_ok=True)
    os.makedirs(bert_dir, exist_ok=True)
    device = "cuda:0" if torch.cuda.is_available() else "cpu"
    tokenizer = AutoTokenizer.from_pretrained(bert_pretrained_dir)
    bert_model = AutoModelForMaskedLM.from_pretrained(bert_pretrained_dir)
    # ``is_half`` is produced by eval() on an environment string, so it is
    # compared against True (not tested for truthiness) to keep the accepted
    # values unchanged.
    # NOTE(review): fp16 is now applied only on CUDA; on CPU the model stays
    # in fp32 because half-precision kernels are not reliably available
    # there -- confirm against the torch version in use.
    if is_half == True and device != "cpu":  # noqa: E712
        bert_model = bert_model.half().to(device)
    else:
        bert_model = bert_model.to(device)

    def get_bert_feature(text, word2ph):
        """Return BERT features for ``text`` expanded to one column per phone.

        Args:
            text: Text to encode; needs exactly one character per entry of
                ``word2ph``.
            word2ph: For each character, how many times its feature row is
                repeated (its phone count).

        Returns:
            A CPU tensor whose last axis has ``sum(word2ph)`` entries; the
            per-phone feature matrix is transposed before being returned.

        Raises:
            ValueError: If ``text`` and ``word2ph`` differ in length.
        """
        # Explicit check instead of ``assert`` so it survives ``python -O``;
        # done up front to avoid a wasted forward pass on bad input.
        if len(word2ph) != len(text):
            raise ValueError(
                "word2ph has %d entries but text has %d characters"
                % (len(word2ph), len(text))
            )
        with torch.no_grad():
            inputs = tokenizer(text, return_tensors="pt")
            for key in inputs:
                inputs[key] = inputs[key].to(device)
            outputs = bert_model(**inputs, output_hidden_states=True)
            # Third-from-last hidden state of the first (only) batch item,
            # with the first and last token rows dropped (presumably the
            # tokenizer's special tokens -- TODO confirm).
            hidden = torch.cat(outputs["hidden_states"][-3:-2], -1)[0].cpu()[1:-1]

        # Repeat each character's row once per phone, then stack the rows.
        phone_level_feature = torch.cat(
            [hidden[idx].repeat(count, 1) for idx, count in enumerate(word2ph)],
            dim=0,
        )
        return phone_level_feature.T

    def process(data, res):
        """Clean every item of ``data`` and append its manifest row to ``res``.

        For items whose language is ``"zh"`` and whose feature file does not
        exist yet, BERT features are computed and stored under ``bert_dir``.
        A failing item is reported on stdout with its traceback and skipped.

        Args:
            data: Iterable of ``(name, text, lan)`` triples.
            res: List extended in place with
                ``[name, phones, word2ph, norm_text]`` rows.
        """
        for name, text, lan in data:
            try:
                name = os.path.basename(name)
                phones, word2ph, norm_text = clean_text(
                    text.replace("%", "-").replace("¥", ","), lan
                )
                path_bert = "%s/%s.pt" % (bert_dir, name)
                if lan == "zh" and not os.path.exists(path_bert):
                    bert_feature = get_bert_feature(norm_text, word2ph)
                    if bert_feature.shape[-1] != len(phones):
                        raise ValueError(
                            "BERT feature covers %d phones but %d were produced"
                            % (bert_feature.shape[-1], len(phones))
                        )
                    my_save(bert_feature, path_bert)
                phones = " ".join(phones)
                res.append([name, phones, word2ph, norm_text])
            except Exception:
                # Best-effort per item: log and continue. ``Exception`` (not a
                # bare except) lets Ctrl-C / SystemExit stop the run.
                print(name, text, traceback.format_exc())

    todo = []
    res = []
    with open(inp_text, "r", encoding="utf8") as f:
        lines = f.read().strip("\n").split("\n")

    # Maps the language tags accepted in the annotation file to the codes
    # handed to clean_text; unknown tags are passed through unchanged.
    language_v1_to_language_v2 = {
        "ZH": "zh",
        "zh": "zh",
        "JP": "ja",
        "jp": "ja",
        "JA": "ja",
        "ja": "ja",
        "EN": "en",
        "en": "en",
        "En": "en",
    }
    # This worker handles every ``all_parts``-th line starting at ``i_part``.
    for line in lines[int(i_part) :: int(all_parts)]:
        try:
            wav_name, spk_name, language, text = line.split("|")
        except ValueError:
            # Line does not have exactly four "|"-separated fields.
            print(line, traceback.format_exc())
            continue
        todo.append(
            [wav_name, text, language_v1_to_language_v2.get(language, language)]
        )

    process(todo, res)
    opt = [
        "%s\t%s\t%s\t%s" % (name, phones, word2ph, norm_text)
        for name, phones, word2ph, norm_text in res
    ]
    with open(txt_path, "w", encoding="utf8") as f:
        f.write("\n".join(opt) + "\n")
|
|