|
import gradio as gr |
|
import os, subprocess, torchaudio |
|
import torch |
|
from PIL import Image |
|
import gradio as gr |
|
import soundfile |
|
from gtts import gTTS |
|
import tempfile |
|
from pydub.generators import Sine |
|
from pydub import AudioSegment |
|
import dlib |
|
import cv2 |
|
import imageio |
|
import os |
|
import ffmpeg |
|
from io import BytesIO |
|
import requests |
|
import sys |
|
|
|
# Path of the interpreter running this script; the pipeline steps below reuse
# it when they launch helper scripts as child processes.
python_path = sys.executable
|
|
|
from fairseq.checkpoint_utils import load_model_ensemble_and_task_from_hf_hub |
|
from fairseq.models.text_to_speech.hub_interface import TTSHubInterface |
|
|
|
# Module-level Blocks instance created at import time.
# NOTE(review): nothing else in the visible file references `block`; run()
# builds its own Blocks object -- confirm whether this one is still needed.
block = gr.Blocks()
|
|
|
def compute_aspect_preserved_bbox(bbox, increase_area, h, w):
    """Expand a bounding box into a square crop that stays inside the image.

    Every side is grown by at least ``increase_area`` (a fraction of the box
    size); the shorter dimension is grown further so that the expanded box is
    square before integer truncation.  When the expanded box would leave the
    image on any side, all four sides are pulled inwards by the largest
    overshoot so the result keeps its aspect ratio.

    Args:
        bbox: ``(left, top, right, bot)`` coordinates of the box.
        increase_area: Minimum fractional growth applied on each side.
        h: Image height in pixels.
        w: Image width in pixels.

    Returns:
        A 4-tuple ``(left, top, right, bot)`` of ints.

    Raises:
        ZeroDivisionError: If the box has zero width or zero height (plain
            Python numbers).
    """
    left, top, right, bot = bbox
    width = right - left
    height = bot - top

    grow = 1 + 2 * increase_area
    width_increase = max(increase_area, (grow * height - width) / (2 * width))
    height_increase = max(increase_area, (grow * width - height) / (2 * height))

    left_t = int(left - width_increase * width)
    top_t = int(top - height_increase * height)
    right_t = int(right + width_increase * width)
    bot_t = int(bot + height_increase * height)

    # How far each *expanded* edge sticks out of the image (0 when inside).
    # The right/bottom overshoot must be measured from the expanded edges
    # (right_t / bot_t); measuring from the original edges can never be
    # positive, which left crops running past the right/bottom border.
    left_oob = max(0, -left_t)
    right_oob = max(0, right_t - w)
    top_oob = max(0, -top_t)
    bot_oob = max(0, bot_t - h)

    shrink = max(left_oob, right_oob, top_oob, bot_oob)
    if shrink > 0:
        # Pull every side in by the same amount to keep the box square.
        return left_t + shrink, top_t + shrink, right_t - shrink, bot_t - shrink
    return left_t, top_t, right_t, bot_t
|
|
|
def crop_src_image(src_img, detector=None):
    """Crop the first detected face out of an image file and save it at 256x256.

    The image is read with OpenCV and passed to ``detector`` (dlib's frontal
    face detector when none is supplied).  For the first detection, the box is
    shifted up by 10% of its height (towards row 0), expanded with
    ``compute_aspect_preserved_bbox`` using ``increase_area=0.5``, clamped to
    the image, and cropped.  When no face is found, or the crop would be
    empty, the whole image is used.  The result is resized to 256x256 and
    written to ``/content/image_pre.png``.

    Args:
        src_img: Path of the source image file.
        detector: Optional callable used as ``detector(img, 0)`` returning a
            sequence of rectangles exposing ``left()``, ``top()``, ``right()``
            and ``bottom()``.

    Returns:
        The path the cropped image was written to.

    Raises:
        ValueError: If the image cannot be read.
        OSError: If the output image cannot be written.
    """
    if detector is None:
        detector = dlib.get_frontal_face_detector()
    save_img = '/content/image_pre.png'

    img = cv2.imread(src_img)
    if img is None:
        # cv2.imread signals failure by returning None rather than raising.
        raise ValueError(f"could not read image: {src_img!r}")

    img_h, img_w = img.shape[:2]
    faces = detector(img, 0)
    if len(faces) > 0:
        face = faces[0]
        shift = (face.bottom() - face.top()) * 0.1
        top = max(0, face.top() - shift)
        bottom = min(img_h, face.bottom() - shift)
        left_c, top_c, right_c, bottom_c = compute_aspect_preserved_bbox(
            (face.left(), top, face.right(), bottom), 0.5, img_h, img_w
        )
        # Clamp before slicing: negative indices would wrap around and
        # silently select the wrong region.
        left_c, top_c = max(0, left_c), max(0, top_c)
        right_c, bottom_c = min(img_w, right_c), min(img_h, bottom_c)
        if right_c > left_c and bottom_c > top_c:
            img = img[top_c:bottom_c, left_c:right_c]
        # Otherwise the crop is empty; fall back to the full frame.

    img = cv2.resize(img, (256, 256))
    if not cv2.imwrite(save_img, img):
        raise OSError(f"could not write image: {save_img!r}")
    return save_img
|
|
|
def pad_image(image):
    """Return *image* centred on a square canvas filled with ``(0, 0, 0)``.

    A square image is returned unchanged (the same object).  Otherwise a new
    image of the same mode is created whose side equals the longer dimension,
    and the original is pasted in the middle along the shorter one.

    Args:
        image: Object exposing ``size`` as ``(width, height)``, ``mode``, and
            usable as the source of ``Image.paste``.

    Returns:
        The original image when already square, else the padded copy.
    """
    width, height = image.size
    if width == height:
        return image

    side = max(width, height)
    # One of the two offsets is always 0: padding is only added along the
    # shorter dimension.
    offset = ((side - width) // 2, (side - height) // 2)

    # NOTE(review): the fill colour is a 3-tuple; confirm it is accepted for
    # single-band modes such as "L" if such inputs are expected.
    canvas = Image.new(image.mode, (side, side), (0, 0, 0))
    canvas.paste(image, offset)
    return canvas
|
|
|
def calculate(image_in, audio_in):
    """Render the talking-face video for one image and one audio clip.

    Steps, in order:

    1. Load ``audio_in``, average it over its first axis (presumably the
       channel axis -- confirm) and save it as 16-bit signed PCM to
       ``/content/audio.wav``.
    2. Strip ``"results/"`` from ``image_in`` and open that path; when it does
       not exist, download a fixed fallback image instead.  The image is
       padded to a square and saved to ``/content/image.png``.
    3. Run ``pocketsphinx`` phone alignment on the audio, pipe its output
       through a ``jq`` filter and write the result to ``/content/test.json``.
    4. Run ``test_script.py`` inside ``/content/one-shot-talking-face`` on
       those three files, saving into ``/content/train``.

    Args:
        image_in: Path of the source image.
        audio_in: Path of the driving audio file.

    Returns:
        ``"/content/train/image_audio.mp4"``.

    Raises:
        subprocess.CalledProcessError: If pocketsphinx, jq or the rendering
            script exits with a non-zero status.
        requests.RequestException: If the fallback image download fails.
    """
    waveform, sample_rate = torchaudio.load(audio_in)
    waveform = torch.mean(waveform, dim=0, keepdim=True)
    torchaudio.save("/content/audio.wav", waveform, sample_rate, encoding="PCM_S", bits_per_sample=16)

    image_in = image_in.replace("results/", "")
    print("****"*100)
    print(f" *#*#*# original image => {image_in} *#*#*# ")
    if os.path.exists(image_in):
        print(f"image exists => {image_in}")
        image = Image.open(image_in)
    else:
        print("image not exists reading web image")
        image_url = "http://labelme.csail.mit.edu/Release3.0/Images/users/DNguyen91/face/m_unsexy_gr.jpg"
        # Bound the wait and fail loudly instead of feeding an error page to PIL.
        response = requests.get(image_url, timeout=30)
        response.raise_for_status()
        image = Image.open(BytesIO(response.content))
    print("****"*100)
    image = pad_image(image)

    # Absolute paths: the rendering script below is told to read from
    # /content/..., so writing relative to the current directory only worked
    # when the process happened to be started from /content.
    image.save("/content/image.png")

    pocketsphinx_run = subprocess.run(
        ['pocketsphinx', '-phone_align', 'yes', 'single', '/content/audio.wav'],
        check=True,
        capture_output=True,
    )

    # Raw string so the backslashes reach jq exactly as before without relying
    # on Python passing through invalid escapes such as "\(" (a SyntaxWarning
    # on recent Python versions).
    jq_filter = r'[.w[]|{word: (.t | ascii_upcase | sub("<S>"; "sil") | sub("<SIL>"; "sil") | sub("\\(2\\)"; "") | sub("\\(3\\)"; "") | sub("\\(4\\)"; "") | sub("\\[SPEECH\\]"; "SIL") | sub("\\[NOISE\\]"; "SIL")), phones: [.w[]|{ph: .t | sub("\\+SPN\\+"; "SIL") | sub("\\+NSN\\+"; "SIL"), bg: (.b*100)|floor, ed: (.b*100+.d*100)|floor}]}]'
    jq_run = subprocess.run(
        ['jq', jq_filter],
        input=pocketsphinx_run.stdout,
        check=True,  # an empty phoneme file would otherwise be written silently
        capture_output=True,
    )
    with open("/content/test.json", "w", encoding="utf-8") as f:
        f.write(jq_run.stdout.decode('utf-8').strip())

    # Argument list + cwd instead of a shell string: no quoting problems when
    # the interpreter path contains spaces, and failures are reported.
    subprocess.run(
        [
            python_path, "-B", "test_script.py",
            "--img_path", "/content/image.png",
            "--audio_path", "/content/audio.wav",
            "--phoneme_path", "/content/test.json",
            "--save_dir", "/content/train",
        ],
        cwd="/content/one-shot-talking-face",
        check=True,
    )
    return "/content/train/image_audio.mp4"
|
|
|
def merge_frames():
    """Assemble the restored frames into a 25 fps video.

    Creates ``/content/video_results/restored_imgs`` when it is missing, lists
    the ``.jpg`` / ``.png`` / ``.gif`` files in it, sorts them by name, reads
    each one with imageio and writes them to ``/content/video_output.mp4``.

    Returns:
        ``"/content/video_output.mp4"``.
    """
    frames_dir = '/content/video_results/restored_imgs'
    if not os.path.exists(frames_dir):
        os.makedirs(frames_dir)

    encoded_dir = os.fsencode(frames_dir)
    print(encoded_dir)

    # Listing a bytes path yields bytes names; decode before filtering.
    frame_names = sorted(
        name
        for name in map(os.fsdecode, os.listdir(encoded_dir))
        if name.endswith(('.jpg', '.png', '.gif'))
    )
    print(frame_names)

    frames = [
        imageio.imread("/content/video_results/restored_imgs/" + name)
        for name in frame_names
    ]

    imageio.mimsave('/content/video_output.mp4', frames, fps=25.0)
    return "/content/video_output.mp4"
|
|
|
def audio_video():
    """Combine the merged video and the audio track into the final MP4.

    Reads ``/content/video_output.mp4`` and ``/content/audio.wav``, deletes
    any previous ``/content/final_output.mp4`` and writes the concatenated
    one-video/one-audio stream to that path.

    Returns:
        ``"/content/final_output.mp4"``.
    """
    output_path = "/content/final_output.mp4"

    video_stream = ffmpeg.input('/content/video_output.mp4')
    audio_stream = ffmpeg.input('/content/audio.wav')

    # Delete any earlier result before the new one is written.
    os.system("rm -rf /content/final_output.mp4")

    joined = ffmpeg.concat(video_stream, audio_stream, v=1, a=1)
    joined.output(output_path).run()

    return output_path
|
|
|
def _remove_path(path):
    """Delete *path* (a file or a directory tree); a missing path is not an error."""
    import shutil

    if os.path.isdir(path) and not os.path.islink(path):
        shutil.rmtree(path, ignore_errors=True)
        return
    try:
        os.remove(path)
    except FileNotFoundError:
        pass


def _run_python_script(*script_args):
    """Run a script with the current interpreter (no shell) and return its exit status."""
    status = subprocess.run([python_path, *script_args]).returncode
    if status != 0:
        # Best-effort like the previous os.system calls, but no longer silent.
        print(f"command exited with status {status}: {' '.join(script_args)}")
    return status


def one_shot_talking(image_in, audio_in):
    """Run the full image + audio -> talking-face video pipeline.

    Steps, in order:

    1. Crop the face from ``image_in`` (written to ``/content/image_pre.png``).
    2. Run GFPGAN on that crop, writing into ``/content/results``.
    3. Call ``calculate`` with ``/content/results/image_pre.png`` and the
       audio.  Note that ``calculate`` in this file removes ``"results/"``
       from the path it receives before opening it.
    4. Extract the frames of ``/content/train/image_audio.mp4``.
    5. Run GFPGAN on the extracted frames, writing into
       ``/content/video_results``.
    6. Merge the restored frames into a video and add the audio track.

    The helper scripts are run best-effort: a non-zero exit status is printed
    but does not stop the pipeline.

    Args:
        image_in: Path of the source image.
        audio_in: Path of the driving audio file.

    Returns:
        The path returned by ``audio_video()``.
    """
    gfpgan_script = "/content/GFPGAN/inference_gfpgan.py"

    crop_src_image(image_in)

    _remove_path("/content/results/restored_imgs/image_pre.png")
    os.makedirs("/content/results", exist_ok=True)

    _run_python_script(
        gfpgan_script, "--upscale", "2",
        "-i", "/content/image_pre.png",
        "-o", "/content/results",
        "--bg_upsampler", "realesrgan",
    )

    calculate('/content/results/image_pre.png', audio_in)

    # Clear frames left over from a previous run before extracting new ones.
    _remove_path("/content/extracted_frames/image_audio_frames")
    _run_python_script(
        "/content/PyVideoFramesExtractor/extract.py",
        "--video=/content/train/image_audio.mp4",
    )

    _remove_path("/content/video_results/")
    _run_python_script(
        gfpgan_script, "--upscale", "2",
        "-i", "/content/extracted_frames/image_audio_frames",
        "-o", "/content/video_results",
        "--bg_upsampler", "realesrgan",
    )

    merge_frames()
    return audio_video()
|
|
|
|
|
def _remove_if_exists(path):
    """Delete the file at *path*; do nothing when it is already absent."""
    try:
        os.remove(path)
    except FileNotFoundError:
        pass


def one_shot(image_in, input_text, gender):
    """Synthesize speech for *input_text* and render the talking-face video.

    ``"Female"`` uses gTTS: the MP3 is written to a temporary file, loaded
    with pydub and exported as ``/content/audio.wav``.  ``"Male"`` uses the
    fairseq ``Voicemod/fastspeech2-en-male1`` model with the ``hifigan``
    vocoder; its output is written to ``/content/audio_before.wav`` and passed
    through ffmpeg's ``atempo=0.7`` filter into ``/content/audio.wav``.
    The resulting audio and ``image_in`` are handed to ``one_shot_talking``.

    Args:
        image_in: Path of the source image.
        input_text: Text to speak.
        gender: ``"Female"`` or ``"Male"``.

    Returns:
        Whatever ``one_shot_talking`` returns (the output video path).

    Raises:
        ValueError: If ``gender`` is neither ``"Female"`` nor ``"Male"``.
        subprocess.CalledProcessError: If the ffmpeg tempo step fails.
    """
    if gender not in ("Female", "Male"):
        # Previously fell through and returned None without any output.
        raise ValueError(f"unsupported gender: {gender!r}")

    audio_in = "/content/audio.wav"

    if gender == "Female":
        tts = gTTS(input_text)
        tmp = tempfile.NamedTemporaryFile(suffix='.mp3', delete=False)
        try:
            # Close the file (flushing it) before pydub reopens it by name.
            with tmp:
                tts.write_to_fp(tmp)
            sound = AudioSegment.from_file(tmp.name, format="mp3")
        finally:
            # delete=False means nobody else cleans this file up.
            _remove_if_exists(tmp.name)
        _remove_if_exists(audio_in)
        sound.export(audio_in, format="wav")
    else:
        # NOTE: the model is loaded again on every call.
        models, cfg, task = load_model_ensemble_and_task_from_hf_hub(
            "Voicemod/fastspeech2-en-male1",
            arg_overrides={"vocoder": "hifigan", "fp16": False},
        )
        model = models[0]
        TTSHubInterface.update_cfg_with_data_cfg(cfg, task.data_cfg)
        generator = task.build_generator([model], cfg)

        sample = TTSHubInterface.get_model_input(task, input_text)
        wav, rate = TTSHubInterface.get_prediction(task, model, generator, sample)

        _remove_if_exists("/content/audio_before.wav")
        soundfile.write("/content/audio_before.wav", wav.cpu().clone().numpy(), rate)

        _remove_if_exists(audio_in)
        # Argument list instead of a shell string; fail here rather than
        # later when the missing audio file is opened.
        subprocess.run(
            [
                "ffmpeg", "-i", "/content/audio_before.wav",
                "-filter:a", "atempo=0.7",
                "-vn", audio_in,
            ],
            check=True,
        )

    return one_shot_talking(image_in, audio_in)
|
|
|
|
|
def run():
    """Build the Gradio interface and serve it on 0.0.0.0:7860.

    The page has an image input, a text box, a Female/Male radio (default
    "Female"), a video output and a "Generate" button wired to ``one_shot``.
    """
    page_css = ".gradio-container {background-color: lightgray} #radio_div {background-color: #FFD8B4; font-size: 40px;}"
    heading = "<h1 style='text-align: center;'>" + "One Shot Talking Face from Text" + "</h1><br/><br/>"

    # NOTE(review): the original indentation was lost; the row grouping below
    # is reconstructed -- confirm it matches the intended layout.
    with gr.Blocks(css=page_css) as demo:
        gr.Markdown(heading)
        with gr.Group():
            with gr.Row():
                image_in = gr.Image(show_label=True, type="filepath", label="Input Image")
                input_text = gr.Textbox(show_label=True, label="Input Text")
                gender = gr.Radio(["Female", "Male"], value="Female", label="Gender")
                video_out = gr.Video(show_label=True, label="Output")
            with gr.Row():
                btn = gr.Button("Generate")

        # Event wiring is done while the Blocks context is still open.
        btn.click(
            one_shot,
            inputs=[image_in, input_text, gender],
            outputs=[video_out],
        )

    demo.queue()
    demo.launch(server_name="0.0.0.0", server_port=7860)
|
|
|
|
|
# Start the demo only when this file is executed as a script, not on import.
if __name__ == "__main__":
    run()