In [None]:
import torch
# Poser variant to load; must be one of the names load_poser() accepts.
MODEL_NAME = "standard_float"
# Torch device name; `device` is the parsed torch.device built from it.
DEVICE_NAME = "cuda"
device = torch.device(DEVICE_NAME)

def load_poser(model: str, device: torch.device):
    """Create the THA3 poser for the named model variant.

    Args:
        model: one of "standard_float", "standard_half", "separable_float" or
            "separable_half"; selects which ``tha3.poser.modes`` module
            supplies ``create_poser``.
        device: forwarded unchanged to that module's ``create_poser``.

    Returns:
        Whatever the selected ``create_poser(device)`` returns.

    Raises:
        RuntimeError: if ``model`` is not one of the supported names.
    """
    print("Using the %s model." % model)
    supported = ("standard_float", "standard_half", "separable_float", "separable_half")
    if model not in supported:
        raise RuntimeError("Invalid model: '%s'" % model)

    # Import inside the branch so only the chosen variant's module is loaded.
    if model == "standard_float":
        from tha3.poser.modes.standard_float import create_poser
    elif model == "standard_half":
        from tha3.poser.modes.standard_half import create_poser
    elif model == "separable_float":
        from tha3.poser.modes.separable_float import create_poser
    else:
        from tha3.poser.modes.separable_half import create_poser
    return create_poser(device)
        
# Pass the torch.device object built above rather than the raw DEVICE_NAME
# string, matching load_poser's `device: torch.device` annotation.
poser = load_poser(MODEL_NAME, device)
# NOTE(review): presumably loads the poser's network modules up front so the
# first render is not delayed — confirm. The trailing ";" suppresses the
# notebook's echo of the return value.
poser.get_modules();

In [None]:
import PIL.Image
import io
from io import StringIO, BytesIO
import IPython.display
import numpy
import ipywidgets
import time
import threading
import torch
from tha3.util import resize_PIL_image, extract_PIL_image_from_filelike, \
    extract_pytorch_image_from_PIL_image, convert_output_image_from_torch_to_numpy

# Target refresh rate. NOTE(review): not referenced by any code shown here.
FRAME_RATE = 30.0

# State shared by the widget callbacks below:
#   torch_input_image      - current uploaded image tensor, or None if none/invalid
#   last_torch_input_image - image used for the most recent render (None before any)
torch_input_image = last_torch_input_image = None

def show_pytorch_image(pytorch_image):
    """Display a torch image tensor as an RGBA image in the current output area.

    Args:
        pytorch_image: tensor accepted by ``convert_output_image_from_torch_to_numpy``.
            It is detached and moved to the CPU first. NOTE(review): the
            converted array is assumed to be an (H, W, 4) float image in
            [0, 1] — confirm against that helper.
    """
    output_image = pytorch_image.detach().cpu()
    float_image = convert_output_image_from_torch_to_numpy(output_image)
    # Clip before the uint8 cast: out-of-range floats would otherwise wrap
    # around (or be platform-dependent) instead of saturating at 0 / 255.
    numpy_image = numpy.clip(numpy.rint(float_image * 255.0), 0.0, 255.0).astype(numpy.uint8)
    pil_image = PIL.Image.fromarray(numpy_image, mode='RGBA')
    IPython.display.display(pil_image)

# Single-file PNG picker, 512px wide to line up with the preview box.
upload_input_image_button = ipywidgets.FileUpload(
    accept='.png', multiple=False, layout={'width': '512px'},
)

# Framed 512x512 area where the rendered character (or an error) appears.
output_image_widget = ipywidgets.Output(
    layout={'border': '1px solid black', 'width': '512px', 'height': '512px'},
)

# Settings shared by every 0..1 expression-weight slider in this group.
_EXPRESSION_SLIDER_KWARGS = dict(
    value=0.0, min=0.0, max=1.0, step=0.01, readout=True, readout_format=".2f",
)

# Eyebrow shape selector plus per-side weights.
eyebrow_dropdown = ipywidgets.Dropdown(
    options=["troubled", "angry", "lowered", "raised", "happy", "serious"],
    value="troubled",
    description="Eyebrow:",
)
eyebrow_left_slider = ipywidgets.FloatSlider(**_EXPRESSION_SLIDER_KWARGS, description="Left:")
eyebrow_right_slider = ipywidgets.FloatSlider(**_EXPRESSION_SLIDER_KWARGS, description="Right:")

# Eye shape selector plus per-side weights.
eye_dropdown = ipywidgets.Dropdown(
    options=["wink", "happy_wink", "surprised", "relaxed", "unimpressed", "raised_lower_eyelid"],
    value="wink",
    description="Eye:",
)
eye_left_slider = ipywidgets.FloatSlider(**_EXPRESSION_SLIDER_KWARGS, description="Left:")
eye_right_slider = ipywidgets.FloatSlider(**_EXPRESSION_SLIDER_KWARGS, description="Right:")

# Mouth shape selector. The right slider starts disabled with a blank label;
# update_mouth_sliders enables/relabels it for the two-sided corner shapes.
mouth_dropdown = ipywidgets.Dropdown(
    options=["aaa", "iii", "uuu", "eee", "ooo", "delta", "lowered_corner", "raised_corner", "smirk"],
    value="aaa",
    description="Mouth:",
)
mouth_left_slider = ipywidgets.FloatSlider(**_EXPRESSION_SLIDER_KWARGS, description="Value:")
mouth_right_slider = ipywidgets.FloatSlider(
    **_EXPRESSION_SLIDER_KWARGS, description=" ", disabled=True,
)

def update_mouth_sliders(change):
    """Relabel and enable/disable the mouth sliders for the selected shape.

    Only "lowered_corner" and "raised_corner" use both sliders (left/right).
    For every other shape the left slider is the single "Value:" and the
    right slider is blanked and disabled. ``change`` is unused.
    """
    two_sided = mouth_dropdown.value in ("lowered_corner", "raised_corner")
    mouth_left_slider.description = "Left:" if two_sided else "Value:"
    mouth_right_slider.description = "Right:" if two_sided else " "
    mouth_right_slider.disabled = not two_sided

mouth_dropdown.observe(update_mouth_sliders, names='value')

# Shared slider settings: 0..1 weights, and a -1..1 variant for rotations.
_UNIT_RANGE_SLIDER_KWARGS = dict(
    value=0.0, min=0.0, max=1.0, step=0.01, readout=True, readout_format=".2f",
)
_SIGNED_RANGE_SLIDER_KWARGS = dict(_UNIT_RANGE_SLIDER_KWARGS, min=-1.0)

# Iris shrinkage (0..1 per side) and iris rotation (-1..1 per axis).
iris_small_left_slider = ipywidgets.FloatSlider(**_UNIT_RANGE_SLIDER_KWARGS, description="Left:")
iris_small_right_slider = ipywidgets.FloatSlider(**_UNIT_RANGE_SLIDER_KWARGS, description="Right:")
iris_rotation_x_slider = ipywidgets.FloatSlider(**_SIGNED_RANGE_SLIDER_KWARGS, description="X-axis:")
iris_rotation_y_slider = ipywidgets.FloatSlider(**_SIGNED_RANGE_SLIDER_KWARGS, description="Y-axis:")

# Head/neck rotation, all -1..1.
head_x_slider = ipywidgets.FloatSlider(**_SIGNED_RANGE_SLIDER_KWARGS, description="X-axis:")
head_y_slider = ipywidgets.FloatSlider(**_SIGNED_RANGE_SLIDER_KWARGS, description="Y-axis:")
neck_z_slider = ipywidgets.FloatSlider(**_SIGNED_RANGE_SLIDER_KWARGS, description="Z-axis:")

# Body rotation (-1..1) and breathing (0..1).
body_y_slider = ipywidgets.FloatSlider(**_SIGNED_RANGE_SLIDER_KWARGS, description="Y-axis rotation:")
body_z_slider = ipywidgets.FloatSlider(**_SIGNED_RANGE_SLIDER_KWARGS, description="Z-axis rotation:")
breathing_slider = ipywidgets.FloatSlider(**_UNIT_RANGE_SLIDER_KWARGS, description="Breathing:")


def _html(markup):
    """Return a new HTML widget rendering ``markup`` (one instance per call)."""
    return ipywidgets.HTML(value=markup)


# Right-hand column: every pose control, grouped under separators/headings.
control_panel = ipywidgets.VBox([
    eyebrow_dropdown, eyebrow_left_slider, eyebrow_right_slider,
    _html("<hr>"),
    eye_dropdown, eye_left_slider, eye_right_slider,
    _html("<hr>"),
    mouth_dropdown, mouth_left_slider, mouth_right_slider,
    _html("<hr>"),
    _html("<center><b>Iris Shrinkage</b></center>"),
    iris_small_left_slider, iris_small_right_slider,
    _html("<center><b>Iris Rotation</b></center>"),
    iris_rotation_x_slider, iris_rotation_y_slider,
    _html("<hr>"),
    _html("<center><b>Head Rotation</b></center>"),
    head_x_slider, head_y_slider, neck_z_slider,
    _html("<hr>"),
    _html("<center><b>Body Rotation</b></center>"),
    body_y_slider, body_z_slider,
    _html("<hr>"),
    _html("<center><b>Breathing</b></center>"),
    breathing_slider,
])

# Left column (preview above the upload button) beside the control panel.
image_column = ipywidgets.VBox([output_image_widget, upload_input_image_button])
controls = ipywidgets.HBox([image_column, control_panel])

from tha3.poser.modes.pose_parameters import get_pose_parameters

# Name -> index lookup for the poser's pose vector.
pose_parameters = get_pose_parameters()
pose_size = poser.get_num_parameters()
# Pose used for the most recent render; all zeros until update() first renders.
last_pose = torch.zeros(1, pose_size, dtype=poser.get_dtype()).to(device)

# Fixed indices for the parameters whose slider is not chosen by a dropdown.
_param_index = pose_parameters.get_parameter_index
iris_small_left_index = _param_index("iris_small_left")
iris_small_right_index = _param_index("iris_small_right")
iris_rotation_x_index = _param_index("iris_rotation_x")
iris_rotation_y_index = _param_index("iris_rotation_y")
head_x_index = _param_index("head_x")
head_y_index = _param_index("head_y")
neck_z_index = _param_index("neck_z")
body_y_index = _param_index("body_y")
body_z_index = _param_index("body_z")
breathing_index = _param_index("breathing")

def get_pose():
    """Assemble the pose vector from the current widget values.

    Returns a ``(1, pose_size)`` tensor of dtype ``poser.get_dtype()`` on
    ``device``. The dropdowns choose which parameters the sliders drive;
    every parameter not written here stays 0.
    """
    index_of = pose_parameters.get_parameter_index
    # Filled on the default device, then moved to `device` once at the end.
    pose = torch.zeros(1, pose_size, dtype=poser.get_dtype())

    def assign_left_right(prefix, left_value, right_value):
        # Write the "<prefix>_left" / "<prefix>_right" pair of parameters.
        pose[0, index_of(f"{prefix}_left")] = left_value
        pose[0, index_of(f"{prefix}_right")] = right_value

    assign_left_right(
        f"eyebrow_{eyebrow_dropdown.value}",
        eyebrow_left_slider.value,
        eyebrow_right_slider.value,
    )
    assign_left_right(
        f"eye_{eye_dropdown.value}",
        eye_left_slider.value,
        eye_right_slider.value,
    )

    mouth_name = f"mouth_{mouth_dropdown.value}"
    if mouth_name in ("mouth_lowered_corner", "mouth_raised_corner"):
        assign_left_right(mouth_name, mouth_left_slider.value, mouth_right_slider.value)
    else:
        # Single-valued mouth shapes are driven by the left slider alone.
        pose[0, index_of(mouth_name)] = mouth_left_slider.value

    fixed_controls = (
        (iris_small_left_index, iris_small_left_slider),
        (iris_small_right_index, iris_small_right_slider),
        (iris_rotation_x_index, iris_rotation_x_slider),
        (iris_rotation_y_index, iris_rotation_y_slider),
        (head_x_index, head_x_slider),
        (head_y_index, head_y_slider),
        (neck_z_index, neck_z_slider),
        (body_y_index, body_y_slider),
        (body_z_index, body_z_slider),
        (breathing_index, breathing_slider),
    )
    for parameter_index, slider in fixed_controls:
        pose[0, parameter_index] = slider.value

    return pose.to(device)

IPython.display.display(controls)  # render the full UI (preview + controls)

def update(change):
    """Re-render the posed character if the input image or pose changed.

    Registered as the ``'value'`` observer of every pose control and called
    directly with ``change=None`` after an upload; ``change`` is unused.
    Does nothing while ``torch_input_image`` is None (no valid upload yet).
    Updates ``last_torch_input_image`` / ``last_pose`` after each render.
    """
    global last_pose
    global last_torch_input_image

    if torch_input_image is None:
        return

    pose = get_pose()
    image_changed = (
        last_torch_input_image is None
        or not torch.equal(torch_input_image, last_torch_input_image)
    )
    pose_changed = not torch.equal(pose, last_pose)
    if not (image_changed or pose_changed):
        return

    # Inference only: keep autograd from recording the forward pass.
    with torch.no_grad():
        output_image = poser.pose(torch_input_image, pose)[0]
    with output_image_widget:
        # wait=True defers the clear until new output arrives (no flicker).
        output_image_widget.clear_output(wait=True)
        show_pytorch_image(output_image)

    last_torch_input_image = torch_input_image
    last_pose = pose

def upload_image(change):
    """Load the uploaded PNG into ``torch_input_image`` and re-render.

    Accepts both ``FileUpload.value`` layouts: a ``{filename: info}`` dict
    (ipywidgets 7) and a tuple of info dicts (ipywidgets 8). An empty value
    (nothing uploaded / cleared) is ignored. An image without an RGBA mode
    clears ``torch_input_image`` and shows an error in the preview.
    ``change`` is unused.
    """
    global torch_input_image
    global last_torch_input_image

    uploaded = upload_input_image_button.value
    file_infos = list(uploaded.values()) if isinstance(uploaded, dict) else list(uploaded)
    if not file_infos:
        return
    # multiple=False, so at most one entry; take the last like the old loop did.
    content = io.BytesIO(file_infos[-1]['content'])

    pil_image = resize_PIL_image(extract_PIL_image_from_filelike(content), size=(512, 512))
    if pil_image.mode != 'RGBA':
        torch_input_image = None
        # Forget the previous render so re-uploading that same image later is
        # seen as a change and replaces this error message.
        last_torch_input_image = None
        with output_image_widget:
            output_image_widget.clear_output(wait=True)
            IPython.display.display(ipywidgets.HTML("Image must have an alpha channel!!!"))
    else:
        torch_input_image = extract_pytorch_image_from_PIL_image(pil_image).to(device)
        if poser.get_dtype() == torch.half:
            torch_input_image = torch_input_image.half()
    update(None)

upload_input_image_button.observe(upload_image, names='value')

# Every pose control triggers update() whenever its value changes.
for _pose_control in (
    eyebrow_dropdown, eyebrow_left_slider, eyebrow_right_slider,
    eye_dropdown, eye_left_slider, eye_right_slider,
    mouth_dropdown, mouth_left_slider, mouth_right_slider,
    iris_small_left_slider, iris_small_right_slider,
    iris_rotation_x_slider, iris_rotation_y_slider,
    head_x_slider, head_y_slider, neck_z_slider,
    body_y_slider, body_z_slider,
    breathing_slider,
):
    _pose_control.observe(update, 'value')
del _pose_control