# Talking_Head_Anime_3/tha3/app/manual_poser.py
# Last commit 60094bd by Harry_FBK: "Clone original THA3"
import argparse
import logging
import os
import sys
from typing import List
sys.path.append(os.getcwd())
import PIL.Image
import numpy
import torch
import wx
from tha3.poser.modes.load_poser import load_poser
from tha3.poser.poser import Poser, PoseParameterCategory, PoseParameterGroup
from tha3.util import extract_pytorch_image_from_filelike, rgba_to_numpy_image, grid_change_to_numpy_image, \
rgb_to_numpy_image, resize_PIL_image, extract_PIL_image_from_filelike, extract_pytorch_image_from_PIL_image
class MorphCategoryControlPanel(wx.Panel):
    """Control panel for one morph category of pose parameters.

    A drop-down lists the parameter groups of ``pose_param_category``; the
    selected group is edited with a left slider (first value), a right slider
    (second value, only for groups of arity 2) or, for discrete groups, a
    "Show" checkbox. ``set_param_value`` writes the selected group's value(s)
    into a pose vector.
    """

    # Integer slider bounds; a slider position is mapped linearly from
    # [_SLIDER_MIN, _SLIDER_MAX] onto the selected group's value range.
    _SLIDER_MIN = -1000
    _SLIDER_MAX = 1000

    def __init__(self,
                 parent,
                 title: str,
                 pose_param_category: PoseParameterCategory,
                 param_groups: List[PoseParameterGroup]):
        super().__init__(parent, style=wx.SIMPLE_BORDER)
        self.pose_param_category = pose_param_category
        self.sizer = wx.BoxSizer(wx.VERTICAL)
        self.SetSizer(self.sizer)
        self.SetAutoLayout(True)

        title_text = wx.StaticText(self, label=title, style=wx.ALIGN_CENTER)
        self.sizer.Add(title_text, 0, wx.EXPAND)

        # Only the groups of this panel's category are selectable.
        self.param_groups = [group for group in param_groups if group.get_category() == pose_param_category]
        self.choice = wx.Choice(self, choices=[group.get_group_name() for group in self.param_groups])
        if self.param_groups:
            self.choice.SetSelection(0)
        self.choice.Bind(wx.EVT_CHOICE, self.on_choice_updated)
        self.sizer.Add(self.choice, 0, wx.EXPAND)

        # Both sliders start at their minimum, i.e. the low end of the group's range.
        self.left_slider = wx.Slider(self, minValue=self._SLIDER_MIN, maxValue=self._SLIDER_MAX,
                                     value=self._SLIDER_MIN, style=wx.HORIZONTAL)
        self.sizer.Add(self.left_slider, 0, wx.EXPAND)
        self.right_slider = wx.Slider(self, minValue=self._SLIDER_MIN, maxValue=self._SLIDER_MAX,
                                      value=self._SLIDER_MIN, style=wx.HORIZONTAL)
        self.sizer.Add(self.right_slider, 0, wx.EXPAND)

        self.checkbox = wx.CheckBox(self, label="Show")
        self.checkbox.SetValue(True)
        self.sizer.Add(self.checkbox, 0, wx.SHAPED | wx.ALIGN_CENTER)

        self.update_ui()
        self.sizer.Fit(self)

    def _selected_group(self):
        """Return the selected PoseParameterGroup, or None when nothing is selected.

        ``GetSelection`` returns ``wx.NOT_FOUND`` (-1) for an empty choice, which
        must not be used as a list index (``[-1]`` would fail or pick the wrong group).
        """
        index = self.choice.GetSelection()
        if 0 <= index < len(self.param_groups):
            return self.param_groups[index]
        return None

    def _slider_to_value(self, slider, param_range) -> float:
        """Map ``slider``'s position linearly onto ``[param_range[0], param_range[1]]``."""
        alpha = (slider.GetValue() - self._SLIDER_MIN) / float(self._SLIDER_MAX - self._SLIDER_MIN)
        return param_range[0] + (param_range[1] - param_range[0]) * alpha

    def update_ui(self):
        """Enable only the widgets that apply to the selected group.

        Discrete groups use the checkbox, arity-1 groups the left slider and
        other groups both sliders. With no selectable group everything is disabled.
        """
        param_group = self._selected_group()
        if param_group is None:
            self.left_slider.Enable(False)
            self.right_slider.Enable(False)
            self.checkbox.Enable(False)
            return
        is_discrete = param_group.is_discrete()
        self.left_slider.Enable(not is_discrete)
        self.right_slider.Enable(not is_discrete and param_group.get_arity() != 1)
        self.checkbox.Enable(is_discrete)

    def on_choice_updated(self, event: wx.Event):
        """Handle a new drop-down selection: a discrete group starts out checked."""
        param_group = self._selected_group()
        if param_group is not None and param_group.is_discrete():
            self.checkbox.SetValue(True)
        self.update_ui()

    def set_param_value(self, pose: List[float]):
        """Write the selected group's value(s) into ``pose`` in place.

        Entries start at the group's parameter index. A discrete group writes
        1.0 to each of its ``arity`` entries when "Show" is checked and leaves
        them untouched otherwise. A continuous group writes the left slider's
        value, plus the right slider's value for arity 2. Does nothing when no
        group is selectable.
        """
        param_group = self._selected_group()
        if param_group is None:
            return
        param_index = param_group.get_parameter_index()
        if param_group.is_discrete():
            if self.checkbox.GetValue():
                for offset in range(param_group.get_arity()):
                    pose[param_index + offset] = 1.0
            return
        param_range = param_group.get_range()
        pose[param_index] = self._slider_to_value(self.left_slider, param_range)
        if param_group.get_arity() == 2:
            pose[param_index + 1] = self._slider_to_value(self.right_slider, param_range)
class SimpleParamGroupsControlPanel(wx.Panel):
    """Control panel with one labelled slider per parameter group of a category.

    Every group of ``pose_param_category`` must be continuous and have arity 1.
    ``set_param_value`` maps each slider linearly onto its group's value range.
    """

    # Slider bounds are the group's range limits scaled by this factor and truncated to int.
    _SLIDER_SCALE = 1000

    def __init__(self, parent,
                 pose_param_category: PoseParameterCategory,
                 param_groups: List[PoseParameterGroup]):
        super().__init__(parent, style=wx.SIMPLE_BORDER)
        self.sizer = wx.BoxSizer(wx.VERTICAL)
        self.SetSizer(self.sizer)
        self.SetAutoLayout(True)

        self.param_groups = [group for group in param_groups if group.get_category() == pose_param_category]
        for param_group in self.param_groups:
            # A single slider can only drive one continuous value per group.
            # Explicit check instead of `assert`, which is stripped under `python -O`.
            if param_group.is_discrete() or param_group.get_arity() != 1:
                raise ValueError(
                    "SimpleParamGroupsControlPanel only supports continuous groups of arity 1, got %s"
                    % param_group.get_group_name())

        self.sliders = []
        for param_group in self.param_groups:
            label = wx.StaticText(
                self,
                label=" ------------ %s ------------ " % param_group.get_group_name(), style=wx.ALIGN_CENTER)
            self.sizer.Add(label, 0, wx.EXPAND)
            param_range = param_group.get_range()
            min_value = int(param_range[0] * self._SLIDER_SCALE)
            max_value = int(param_range[1] * self._SLIDER_SCALE)
            # Start at 0 (a parameter value of 0) when the range contains it;
            # otherwise clamp the initial position into the slider's bounds.
            initial_value = min(max(0, min_value), max_value)
            slider = wx.Slider(self, minValue=min_value, maxValue=max_value,
                               value=initial_value, style=wx.HORIZONTAL)
            self.sizer.Add(slider, 0, wx.EXPAND)
            self.sliders.append(slider)
        self.sizer.Fit(self)

    def set_param_value(self, pose: List[float]):
        """Write each group's slider value into ``pose`` at the group's parameter index (in place)."""
        for param_group, slider in zip(self.param_groups, self.sliders):
            param_range = param_group.get_range()
            slider_min = slider.GetMin()
            span = slider.GetMax() - slider_min
            # A zero-width slider would divide by zero; use the low end of the range instead.
            alpha = (slider.GetValue() - slider_min) / float(span) if span != 0 else 0.0
            pose[param_group.get_parameter_index()] = param_range[0] + (param_range[1] - param_range[0]) * alpha
def convert_output_image_from_torch_to_numpy(output_image):
    """Convert one (un-batched, 3-D) poser output tensor into a uint8 image array.

    Layouts handled:

    * ``(H, W, 2)``: a channels-last 2-channel tensor. It is reordered to
      ``(2, H, W)`` and then rendered like the ``(2, H, W)`` case below.
    * ``(4, H, W)``: converted with ``rgba_to_numpy_image``.
    * ``(3, H, W)``: converted with ``rgb_to_numpy_image``.
    * ``(1, H, W)``: replicated to 3 channels and mapped via ``2 * x - 1``
      (presumably [0, 1] -> [-1, 1]), given an alpha channel of ones, then
      converted with ``rgba_to_numpy_image``.
    * ``(2, H, W)``: rendered with ``grid_change_to_numpy_image``.

    Returns:
        The helper's output scaled by 255, rounded, clipped to [0, 255] and
        cast to ``numpy.uint8`` (clipping prevents out-of-range values from
        wrapping around).

    Raises:
        RuntimeError: if the channel count is not one of the above.
    """
    if output_image.shape[2] == 2:
        # Channels-last input: reorder (H, W, C) -> (C, H, W) so the channel
        # dispatch below renders it instead of returning the raw tensor.
        h, w, c = output_image.shape
        output_image = torch.transpose(output_image.reshape(h * w, c), 0, 1).reshape(c, h, w)

    num_channels = output_image.shape[0]
    if num_channels == 4:
        numpy_image = rgba_to_numpy_image(output_image)
    elif num_channels == 3:
        numpy_image = rgb_to_numpy_image(output_image)
    elif num_channels == 1:
        _, h, w = output_image.shape
        # Match dtype/device so half-precision outputs concatenate cleanly.
        alpha = torch.ones(1, h, w, dtype=output_image.dtype, device=output_image.device)
        alpha_image = torch.cat([output_image.repeat(3, 1, 1) * 2.0 - 1.0, alpha], dim=0)
        numpy_image = rgba_to_numpy_image(alpha_image)
    elif num_channels == 2:
        numpy_image = grid_change_to_numpy_image(output_image, num_channels=4)
    else:
        raise RuntimeError("Unsupported # image channels: %d" % num_channels)
    return numpy.uint8(numpy.clip(numpy.rint(numpy_image * 255.0), 0, 255))
class MainFrame(wx.Frame):
    """Main window of the manual poser.

    Layout, left to right:
    * the source image with a "Load Image" button;
    * the pose-parameter controls;
    * the posed result with an output-index chooser and a "Save Image" button.

    ``update_images`` is bound to ``self.timer``. It re-runs the poser only
    when one of the following changed since the previous update:
    * the pose;
    * the selected output index;
    * the source image.
    """

    def __init__(self, poser: Poser, device: torch.device):
        super().__init__(None, wx.ID_ANY, "Poser")
        self.poser = poser
        self.dtype = self.poser.get_dtype()
        self.device = device
        self.image_size = self.poser.get_image_size()

        # Source/result state. Set up before the panels are built so the paint
        # handlers they bind always find these attributes.
        self.wx_source_image = None
        self.torch_source_image = None
        self.source_image_bitmap = wx.Bitmap(self.image_size, self.image_size)
        self.result_image_bitmap = wx.Bitmap(self.image_size, self.image_size)
        self.source_image_dirty = True
        self.last_pose = None
        self.last_output_numpy_image = None

        self.main_sizer = wx.BoxSizer(wx.HORIZONTAL)
        self.SetSizer(self.main_sizer)
        self.SetAutoLayout(True)
        self.init_left_panel()
        self.init_control_panel()
        self.init_right_panel()
        self.main_sizer.Fit(self)

        # Needs output_index_choice, which init_right_panel creates.
        self.last_output_index = self.output_index_choice.GetSelection()

        # update_images runs on each tick of this timer; the timer is not started here.
        self.timer = wx.Timer(self, wx.ID_ANY)
        self.Bind(wx.EVT_TIMER, self.update_images, self.timer)

        # Ctrl+S does the same as the "Save Image" button.
        save_image_id = wx.NewIdRef()
        self.Bind(wx.EVT_MENU, self.on_save_image, id=save_image_id)
        self.SetAcceleratorTable(wx.AcceleratorTable([
            (wx.ACCEL_CTRL, ord('S'), save_image_id),
        ]))

    def init_left_panel(self):
        """Build the left column: the source-image view and the "Load Image" button."""
        # The (still empty) control panel is created here, before the left
        # panel, so widget creation order is unchanged; init_control_panel fills it.
        self.control_panel = wx.Panel(self, style=wx.SIMPLE_BORDER, size=(self.image_size, -1))

        self.left_panel = wx.Panel(self, style=wx.SIMPLE_BORDER)
        left_panel_sizer = wx.BoxSizer(wx.VERTICAL)
        self.left_panel.SetSizer(left_panel_sizer)
        self.left_panel.SetAutoLayout(True)

        self.source_image_panel = wx.Panel(self.left_panel, size=(self.image_size, self.image_size),
                                           style=wx.SIMPLE_BORDER)
        self.source_image_panel.Bind(wx.EVT_PAINT, self.paint_source_image_panel)
        self.source_image_panel.Bind(wx.EVT_ERASE_BACKGROUND, self.on_erase_background)
        left_panel_sizer.Add(self.source_image_panel, 0, wx.FIXED_MINSIZE)

        self.load_image_button = wx.Button(self.left_panel, wx.ID_ANY, "\nLoad Image\n\n")
        left_panel_sizer.Add(self.load_image_button, 1, wx.EXPAND)
        self.load_image_button.Bind(wx.EVT_BUTTON, self.load_image)

        left_panel_sizer.Fit(self.left_panel)
        self.main_sizer.Add(self.left_panel, 0, wx.FIXED_MINSIZE)

    def on_erase_background(self, event: wx.Event):
        """Skip background erasing; the paint handlers draw the panel from a bitmap."""
        pass

    def init_control_panel(self):
        """Fill the middle column with one control panel per category the poser has groups for."""
        self.control_panel_sizer = wx.BoxSizer(wx.VERTICAL)
        self.control_panel.SetSizer(self.control_panel_sizer)
        self.control_panel.SetMinSize(wx.Size(256, 1))

        param_groups = self.poser.get_pose_parameter_groups()

        def has_groups(category):
            # Categories without any parameter groups get no panel at all.
            return any(group.get_category() == category for group in param_groups)

        morph_categories = [
            PoseParameterCategory.EYEBROW,
            PoseParameterCategory.EYE,
            PoseParameterCategory.MOUTH,
            PoseParameterCategory.IRIS_MORPH
        ]
        morph_category_titles = {
            PoseParameterCategory.EYEBROW: " ------------ Eyebrow ------------ ",
            PoseParameterCategory.EYE: " ------------ Eye ------------ ",
            PoseParameterCategory.MOUTH: " ------------ Mouth ------------ ",
            PoseParameterCategory.IRIS_MORPH: " ------------ Iris morphs ------------ ",
        }
        self.morph_control_panels = {}
        for category in morph_categories:
            if not has_groups(category):
                continue
            control_panel = MorphCategoryControlPanel(
                self.control_panel,
                morph_category_titles[category],
                category,
                param_groups)
            self.morph_control_panels[category] = control_panel
            self.control_panel_sizer.Add(control_panel, 0, wx.EXPAND)

        self.non_morph_control_panels = {}
        non_morph_categories = [
            PoseParameterCategory.IRIS_ROTATION,
            PoseParameterCategory.FACE_ROTATION,
            PoseParameterCategory.BODY_ROTATION,
            PoseParameterCategory.BREATHING
        ]
        for category in non_morph_categories:
            if not has_groups(category):
                continue
            control_panel = SimpleParamGroupsControlPanel(
                self.control_panel,
                category,
                param_groups)
            self.non_morph_control_panels[category] = control_panel
            self.control_panel_sizer.Add(control_panel, 0, wx.EXPAND)

        self.control_panel_sizer.Fit(self.control_panel)
        self.main_sizer.Add(self.control_panel, 1, wx.FIXED_MINSIZE)

    def init_right_panel(self):
        """Build the right column.

        It contains the result view, the output-index chooser and the
        "Save Image" button.
        """
        self.right_panel = wx.Panel(self, style=wx.SIMPLE_BORDER)
        right_panel_sizer = wx.BoxSizer(wx.VERTICAL)
        self.right_panel.SetSizer(right_panel_sizer)
        self.right_panel.SetAutoLayout(True)

        self.result_image_panel = wx.Panel(self.right_panel,
                                           size=(self.image_size, self.image_size),
                                           style=wx.SIMPLE_BORDER)
        self.result_image_panel.Bind(wx.EVT_PAINT, self.paint_result_image_panel)
        self.result_image_panel.Bind(wx.EVT_ERASE_BACKGROUND, self.on_erase_background)

        # One entry per poser output; the selected index is passed to poser.pose.
        self.output_index_choice = wx.Choice(
            self.right_panel,
            choices=[str(i) for i in range(self.poser.get_output_length())])
        self.output_index_choice.SetSelection(0)
        right_panel_sizer.Add(self.result_image_panel, 0, wx.FIXED_MINSIZE)
        right_panel_sizer.Add(self.output_index_choice, 0, wx.EXPAND)

        self.save_image_button = wx.Button(self.right_panel, wx.ID_ANY, "\nSave Image\n\n")
        right_panel_sizer.Add(self.save_image_button, 1, wx.EXPAND)
        self.save_image_button.Bind(wx.EVT_BUTTON, self.on_save_image)

        right_panel_sizer.Fit(self.right_panel)
        self.main_sizer.Add(self.right_panel, 0, wx.FIXED_MINSIZE)

    def create_param_category_choice(self, param_category: PoseParameterCategory):
        """Return a wx.Choice on the control panel listing the group names of ``param_category``.

        The first entry is pre-selected when the list is non-empty.
        """
        params = [group.get_group_name() for group in self.poser.get_pose_parameter_groups()
                  if group.get_category() == param_category]
        choice = wx.Choice(self.control_panel, choices=params)
        if params:
            choice.SetSelection(0)
        return choice

    def _show_message(self, message: str, caption: str):
        """Show a modal OK dialog and always destroy it afterwards."""
        dialog = wx.MessageDialog(self, message, caption, wx.OK)
        try:
            dialog.ShowModal()
        finally:
            dialog.Destroy()

    def _confirm_overwrite(self, image_file_name: str) -> bool:
        """Ask whether ``image_file_name`` may be overwritten; True if the user answers Yes."""
        dialog = wx.MessageDialog(self, f"Overwrite {image_file_name}?", "Manual Poser",
                                  wx.YES_NO | wx.ICON_QUESTION)
        try:
            return dialog.ShowModal() == wx.ID_YES
        finally:
            dialog.Destroy()

    def load_image(self, event: wx.Event):
        """Let the user pick a PNG and use it as the source image.

        The image is resized to the poser's input size. An image whose mode is
        not RGBA clears the current source image, and the user is told why.
        Load errors are logged and reported in a dialog.
        """
        dir_name = "data/images"
        file_dialog = wx.FileDialog(self, "Choose an image", dir_name, "", "*.png", wx.FD_OPEN)
        try:
            if file_dialog.ShowModal() != wx.ID_OK:
                return
            image_file_name = os.path.join(file_dialog.GetDirectory(), file_dialog.GetFilename())
        finally:
            file_dialog.Destroy()

        try:
            pil_image = resize_PIL_image(extract_PIL_image_from_filelike(image_file_name),
                                         (self.poser.get_image_size(), self.poser.get_image_size()))
            w, h = pil_image.size
            if pil_image.mode != 'RGBA':
                self.source_image_string = "Image must have alpha channel!"
                self.wx_source_image = None
                self.torch_source_image = None
            else:
                self.wx_source_image = wx.Bitmap.FromBufferRGBA(w, h, pil_image.convert("RGBA").tobytes())
                self.torch_source_image = extract_pytorch_image_from_PIL_image(pil_image) \
                    .to(self.device).to(self.dtype)
            # Redraw in both cases: a rejected image shows "Nothing yet!" again.
            self.source_image_dirty = True
            self.Refresh()
            self.Update()
        except Exception:
            logging.exception("Could not load image %s", image_file_name)
            self._show_message("Could not load image " + image_file_name, "Poser")
            return

        if self.torch_source_image is None:
            # Previously this message was stored but never shown to the user.
            self._show_message(self.source_image_string, "Poser")

    def paint_source_image_panel(self, event: wx.Event):
        """Paint handler: blit the cached source bitmap onto its panel."""
        wx.BufferedPaintDC(self.source_image_panel, self.source_image_bitmap)

    def paint_result_image_panel(self, event: wx.Event):
        """Paint handler: blit the cached result bitmap onto its panel."""
        wx.BufferedPaintDC(self.result_image_panel, self.result_image_bitmap)

    def draw_nothing_yet_string_to_bitmap(self, bitmap):
        """Clear ``bitmap`` and draw "Nothing yet!" centred in an image_size x image_size area."""
        dc = wx.MemoryDC()
        dc.SelectObject(bitmap)
        dc.Clear()
        dc.SetFont(wx.Font(wx.FontInfo(14).Family(wx.FONTFAMILY_SWISS)))
        w, h = dc.GetTextExtent("Nothing yet!")
        dc.DrawText("Nothing yet!", (self.image_size - w) // 2, (self.image_size - h) // 2)
        # Release the bitmap so it can be used by the paint handlers.
        dc.SelectObject(wx.NullBitmap)

    def get_current_pose(self):
        """Return the pose vector (list of floats) assembled from all control panels.

        Entries no panel writes stay at 0.0.
        """
        current_pose = [0.0] * self.poser.get_num_parameters()
        for morph_control_panel in self.morph_control_panels.values():
            morph_control_panel.set_param_value(current_pose)
        for rotation_control_panel in self.non_morph_control_panels.values():
            rotation_control_panel.set_param_value(current_pose)
        return current_pose

    def update_images(self, event: wx.Event):
        """Timer handler: re-render the source and result bitmaps when something changed."""
        current_pose = self.get_current_pose()
        output_index = self.output_index_choice.GetSelection()
        if not self.source_image_dirty \
                and self.last_pose is not None \
                and self.last_pose == current_pose \
                and self.last_output_index == output_index:
            return
        self.last_pose = current_pose
        self.last_output_index = output_index

        if self.torch_source_image is None:
            self.draw_nothing_yet_string_to_bitmap(self.source_image_bitmap)
            self.draw_nothing_yet_string_to_bitmap(self.result_image_bitmap)
            # Nothing is displayed any more, so there is nothing to save either.
            self.last_output_numpy_image = None
            self.source_image_dirty = False
            self.Refresh()
            self.Update()
            return

        if self.source_image_dirty:
            dc = wx.MemoryDC()
            dc.SelectObject(self.source_image_bitmap)
            dc.Clear()
            dc.DrawBitmap(self.wx_source_image, 0, 0)
            dc.SelectObject(wx.NullBitmap)
            self.source_image_dirty = False

        pose = torch.tensor(current_pose, device=self.device, dtype=self.dtype)
        with torch.no_grad():
            output_image = self.poser.pose(self.torch_source_image, pose, output_index)[0].detach().cpu()
        numpy_image = convert_output_image_from_torch_to_numpy(output_image)
        self.last_output_numpy_image = numpy_image

        # numpy_image is sliced channels-last below, presumably (H, W, 4):
        # shape[0] is the height (rows) and shape[1] the width (columns).
        height, width = numpy_image.shape[0], numpy_image.shape[1]
        wx_image = wx.ImageFromBuffer(
            width,
            height,
            numpy_image[:, :, 0:3].tobytes(),
            numpy_image[:, :, 3].tobytes())
        wx_bitmap = wx_image.ConvertToBitmap()

        dc = wx.MemoryDC()
        dc.SelectObject(self.result_image_bitmap)
        dc.Clear()
        # Centre the result: x offset from the width, y offset from the height.
        dc.DrawBitmap(wx_bitmap,
                      (self.image_size - width) // 2,
                      (self.image_size - height) // 2,
                      True)
        dc.SelectObject(wx.NullBitmap)

        self.Refresh()
        self.Update()

    def on_save_image(self, event: wx.Event):
        """Ask for a file name and save the most recent output image.

        A name without an extension gets ".png". An existing file is only
        overwritten after confirmation. Save errors are logged and reported
        in a dialog.
        """
        if self.last_output_numpy_image is None:
            logging.info("There is no output image to save!!!")
            return
        dir_name = "data/images"
        file_dialog = wx.FileDialog(self, "Choose an image", dir_name, "", "*.png", wx.FD_SAVE)
        try:
            if file_dialog.ShowModal() != wx.ID_OK:
                return
            image_file_name = os.path.join(file_dialog.GetDirectory(), file_dialog.GetFilename())
        finally:
            file_dialog.Destroy()

        # PIL chooses the file format from the extension and fails without one.
        if not os.path.splitext(image_file_name)[1]:
            image_file_name += ".png"

        if os.path.exists(image_file_name) and not self._confirm_overwrite(image_file_name):
            return
        try:
            self.save_last_numpy_image(image_file_name)
        except Exception:
            logging.exception("Could not save %s", image_file_name)
            self._show_message(f"Could not save {image_file_name}", "Manual Poser")

    def save_last_numpy_image(self, image_file_name):
        """Write the last output image (as RGBA) to ``image_file_name``, creating parent directories."""
        pil_image = PIL.Image.fromarray(self.last_output_numpy_image, mode='RGBA')
        directory = os.path.dirname(image_file_name)
        if directory:
            # os.makedirs("") raises, so only create a directory when the path has one.
            os.makedirs(directory, exist_ok=True)
        pil_image.save(image_file_name)
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='Manually pose a character image.')
    parser.add_argument(
        '--model',
        type=str,
        required=False,
        default='standard_float',
        choices=['standard_float', 'separable_float', 'standard_half', 'separable_half'],
        help='The model to use.')
    parser.add_argument(
        '--device',
        type=str,
        required=False,
        default='cuda',
        help='The torch device to run the model on (e.g. cuda or cuda:1).')
    args = parser.parse_args()

    # Make logging.info(...) messages from the UI visible on the console.
    logging.basicConfig(level=logging.INFO)

    device = torch.device(args.device)
    try:
        poser = load_poser(args.model, device)
    except RuntimeError as e:
        print(e, file=sys.stderr)
        # Non-zero status so callers/scripts can detect the failure.
        sys.exit(1)

    app = wx.App()
    main_frame = MainFrame(poser, device)
    main_frame.Show(True)
    # Poll the controls and re-render every 30 ms.
    main_frame.timer.Start(30)
    app.MainLoop()