from abc import ABC, abstractmethod from enum import Enum from typing import Tuple, List, Optional import torch from torch import Tensor class PoseParameterCategory(Enum): EYEBROW = 1 EYE = 2 IRIS_MORPH = 3 IRIS_ROTATION = 4 MOUTH = 5 FACE_ROTATION = 6 BODY_ROTATION = 7 BREATHING = 8 class PoseParameterGroup: def __init__(self, group_name: str, parameter_index: int, category: PoseParameterCategory, arity: int = 1, discrete: bool = False, default_value: float = 0.0, range: Optional[Tuple[float, float]] = None): assert arity == 1 or arity == 2 if range is None: range = (0.0, 1.0) if arity == 1: parameter_names = [group_name] else: parameter_names = [group_name + "_left", group_name + "_right"] assert len(parameter_names) == arity self.parameter_names = parameter_names self.range = range self.default_value = default_value self.discrete = discrete self.arity = arity self.category = category self.parameter_index = parameter_index self.group_name = group_name def get_arity(self) -> int: return self.arity def get_group_name(self) -> str: return self.group_name def get_parameter_names(self) -> List[str]: return self.parameter_names def is_discrete(self) -> bool: return self.discrete def get_range(self) -> Tuple[float, float]: return self.range def get_default_value(self): return self.default_value def get_parameter_index(self): return self.parameter_index def get_category(self) -> PoseParameterCategory: return self.category class PoseParameters: def __init__(self, pose_parameter_groups: List[PoseParameterGroup]): self.pose_parameter_groups = pose_parameter_groups def get_parameter_index(self, name: str) -> int: index = 0 for parameter_group in self.pose_parameter_groups: for param_name in parameter_group.parameter_names: if name == param_name: return index index += 1 raise RuntimeError("Cannot find parameter with name %s" % name) def get_parameter_name(self, index: int) -> str: assert index >= 0 and index < self.get_parameter_count() for group in self.pose_parameter_groups: if index < group.get_arity(): return group.get_parameter_names()[index] index -= group.arity raise RuntimeError("Something is wrong here!!!") def get_pose_parameter_groups(self): return self.pose_parameter_groups def get_parameter_count(self): count = 0 for group in self.pose_parameter_groups: count += group.arity return count class Builder: def __init__(self): self.index = 0 self.pose_parameter_groups = [] def add_parameter_group(self, group_name: str, category: PoseParameterCategory, arity: int = 1, discrete: bool = False, default_value: float = 0.0, range: Optional[Tuple[float, float]] = None): self.pose_parameter_groups.append( PoseParameterGroup( group_name, self.index, category, arity, discrete, default_value, range)) self.index += arity return self def build(self) -> 'PoseParameters': return PoseParameters(self.pose_parameter_groups) class Poser(ABC): @abstractmethod def get_image_size(self) -> int: pass @abstractmethod def get_output_length(self) -> int: pass @abstractmethod def get_pose_parameter_groups(self) -> List[PoseParameterGroup]: pass @abstractmethod def get_num_parameters(self) -> int: pass @abstractmethod def pose(self, image: Tensor, pose: Tensor, output_index: int = 0) -> Tensor: pass @abstractmethod def get_posing_outputs(self, image: Tensor, pose: Tensor) -> List[Tensor]: pass def get_dtype(self) -> torch.dtype: return torch.float