# Spaces:
# Runtime error
# Runtime error
from abc import ABC, abstractmethod
from enum import Enum
from typing import List, Optional, Tuple

import torch
from torch import Tensor
class PoseParameterCategory(Enum):
    """Category recorded on each pose parameter group.

    ``PoseParameterGroup`` takes one of these members as its ``category``.
    The integer values are spelled out explicitly, so they do not depend on
    the order in which the members are declared.
    """
    EYEBROW = 1
    EYE = 2
    IRIS_MORPH = 3
    IRIS_ROTATION = 4
    MOUTH = 5
    FACE_ROTATION = 6
    BODY_ROTATION = 7
    BREATHING = 8
class PoseParameterGroup:
    """A named group of one or two pose parameters.

    A group with ``arity == 1`` exposes a single parameter named
    ``group_name``. A group with ``arity == 2`` exposes a left/right pair
    named ``group_name + "_left"`` and ``group_name + "_right"``.

    Args:
        group_name: Base name of the group.
        parameter_index: Index stored for the group and returned by
            ``get_parameter_index()`` (supplied by the caller).
        category: The ``PoseParameterCategory`` of the group.
        arity: Number of parameters in the group; must be 1 or 2.
        discrete: Flag stored as-is and returned by ``is_discrete()``.
        default_value: Value returned by ``get_default_value()``.
        range: ``(low, high)`` pair returned by ``get_range()``. It defaults
            to ``(0.0, 1.0)`` when omitted and is not validated here. The
            name shadows the builtin but is kept for keyword-argument
            compatibility.

    Raises:
        ValueError: If ``arity`` is not 1 or 2.
    """

    def __init__(self,
                 group_name: str,
                 parameter_index: int,
                 category: 'PoseParameterCategory',
                 arity: int = 1,
                 discrete: bool = False,
                 default_value: float = 0.0,
                 range: Optional[Tuple[float, float]] = None):
        # Explicit check rather than `assert`: assertions are stripped under
        # `python -O`, which would silently accept an unsupported arity.
        if arity not in (1, 2):
            raise ValueError(f"arity must be 1 or 2, got {arity!r}")
        if arity == 1:
            self.parameter_names = [group_name]
        else:
            self.parameter_names = [group_name + "_left", group_name + "_right"]
        self.range = (0.0, 1.0) if range is None else range
        self.default_value = default_value
        self.discrete = discrete
        self.arity = arity
        self.category = category
        self.parameter_index = parameter_index
        self.group_name = group_name

    def __repr__(self) -> str:
        return (f"{type(self).__name__}(group_name={self.group_name!r}, "
                f"parameter_index={self.parameter_index!r}, "
                f"category={self.category!r}, arity={self.arity!r}, "
                f"discrete={self.discrete!r}, "
                f"default_value={self.default_value!r}, range={self.range!r})")

    def get_arity(self) -> int:
        """Return the number of parameters in the group (1 or 2)."""
        return self.arity

    def get_group_name(self) -> str:
        """Return the base name of the group."""
        return self.group_name

    def get_parameter_names(self) -> List[str]:
        """Return the group's parameter names (the internal list, not a copy)."""
        return self.parameter_names

    def is_discrete(self) -> bool:
        """Return the ``discrete`` flag given at construction."""
        return self.discrete

    def get_range(self) -> Tuple[float, float]:
        """Return the ``(low, high)`` range; ``(0.0, 1.0)`` if none was given."""
        return self.range

    def get_default_value(self) -> float:
        """Return the default value given at construction."""
        return self.default_value

    def get_parameter_index(self) -> int:
        """Return the ``parameter_index`` given at construction."""
        return self.parameter_index

    def get_category(self) -> 'PoseParameterCategory':
        """Return the category given at construction."""
        return self.category
class PoseParameters:
    """An ordered collection of ``PoseParameterGroup`` objects.

    The parameters of all groups share one flat index space. They are
    numbered in group order and, within a group, in the order of that
    group's ``parameter_names``.
    """

    def __init__(self, pose_parameter_groups: List['PoseParameterGroup']):
        # NOTE: the list is stored by reference, not copied.
        self.pose_parameter_groups = pose_parameter_groups

    def get_parameter_index(self, name: str) -> int:
        """Return the flat index of the parameter called *name*.

        Raises:
            RuntimeError: If no group has a parameter with that name.
        """
        all_names = (param_name
                     for group in self.pose_parameter_groups
                     for param_name in group.parameter_names)
        for index, param_name in enumerate(all_names):
            if name == param_name:
                return index
        raise RuntimeError(f"Cannot find parameter with name {name}")

    def get_parameter_name(self, index: int) -> str:
        """Return the name of the parameter at flat position *index*.

        Raises:
            IndexError: If *index* is outside ``[0, get_parameter_count())``.
        """
        count = self.get_parameter_count()
        # Explicit check instead of `assert` (stripped under `python -O`).
        # Without it, a negative index would silently resolve to the last
        # name of the first group.
        if not 0 <= index < count:
            raise IndexError(f"parameter index {index} out of range [0, {count})")
        for group in self.pose_parameter_groups:
            if index < group.arity:
                return group.parameter_names[index]
            index -= group.arity
        # Unreachable after the bounds check for non-negative arities;
        # kept as a defensive guard.
        raise RuntimeError("Something is wrong here!!!")

    def get_pose_parameter_groups(self) -> List['PoseParameterGroup']:
        """Return the stored list of groups (not a copy)."""
        return self.pose_parameter_groups

    def get_parameter_count(self) -> int:
        """Return the total number of parameters (sum of group arities)."""
        return sum(group.arity for group in self.pose_parameter_groups)

    class Builder:
        """Fluent builder that assigns each group's ``parameter_index``.

        Each added group gets, as its ``parameter_index``, the total arity of
        the groups added before it.
        """

        def __init__(self):
            self.index = 0
            self.pose_parameter_groups = []

        def add_parameter_group(self,
                                group_name: str,
                                category: 'PoseParameterCategory',
                                arity: int = 1,
                                discrete: bool = False,
                                default_value: float = 0.0,
                                range: Optional[Tuple[float, float]] = None
                                ) -> 'PoseParameters.Builder':
            """Append a new ``PoseParameterGroup`` and return ``self``.

            The arguments are forwarded to ``PoseParameterGroup``, and the
            builder's running index is filled in as ``parameter_index``.
            """
            self.pose_parameter_groups.append(
                PoseParameterGroup(
                    group_name,
                    self.index,
                    category,
                    arity,
                    discrete,
                    default_value,
                    range))
            self.index += arity
            return self

        def build(self) -> 'PoseParameters':
            """Return a ``PoseParameters`` over the groups added so far.

            The group list is copied, so groups added to this builder later
            do not leak into an instance that was already built.
            """
            return PoseParameters(list(self.pose_parameter_groups))
class Poser(ABC):
    """Abstract interface for posers.

    A poser takes an image tensor and a pose tensor and produces output
    tensors through ``pose`` and ``get_posing_outputs``. Concrete subclasses
    must implement every method marked ``@abstractmethod``; ``get_dtype``
    has a default implementation.
    """

    @abstractmethod
    def get_image_size(self) -> int:
        """Return the poser's image size.

        NOTE(review): presumably the side length in pixels of a square
        image; confirm against implementations.
        """

    @abstractmethod
    def get_output_length(self) -> int:
        """Return the poser's output length.

        NOTE(review): presumably the number of tensors returned by
        ``get_posing_outputs``; TODO confirm.
        """

    @abstractmethod
    def get_pose_parameter_groups(self) -> List['PoseParameterGroup']:
        """Return this poser's pose parameter groups.

        NOTE(review): presumably these describe the layout of the ``pose``
        tensor; confirm against implementations.
        """

    @abstractmethod
    def get_num_parameters(self) -> int:
        """Return the number of pose parameters."""

    @abstractmethod
    def pose(self, image: Tensor, pose: Tensor, output_index: int = 0) -> Tensor:
        """Apply *pose* to *image* and return a single output tensor.

        ``output_index`` selects which output to return (presumably an index
        into the list from ``get_posing_outputs``; TODO confirm).
        """

    @abstractmethod
    def get_posing_outputs(self, image: Tensor, pose: Tensor) -> List[Tensor]:
        """Apply *pose* to *image* and return the list of output tensors."""

    def get_dtype(self) -> torch.dtype:
        """Return the poser's dtype.

        The base implementation returns ``torch.float``; subclasses may
        override it.
        """
        return torch.float