File size: 4,581 Bytes
60094bd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
from abc import ABC, abstractmethod
from enum import Enum
from typing import Tuple, List, Optional

import torch
from torch import Tensor


class PoseParameterCategory(Enum):
    """Category that a ``PoseParameterGroup`` is tagged with.

    Members carry the fixed integer values 1..8 in declaration order.
    """

    # Eye-region categories.
    EYEBROW = 1
    EYE = 2
    IRIS_MORPH = 3
    IRIS_ROTATION = 4

    # Mouth.
    MOUTH = 5

    # Head / body motion categories.
    FACE_ROTATION = 6
    BODY_ROTATION = 7
    BREATHING = 8


class PoseParameterGroup:
    """A named group of one or two pose parameters that share the same settings.

    A group of arity 1 holds a single parameter named ``group_name``. A group
    of arity 2 holds a left/right pair named ``group_name + "_left"`` and
    ``group_name + "_right"``, in that order.
    """

    def __init__(self,
                 group_name: str,
                 parameter_index: int,
                 category: PoseParameterCategory,
                 arity: int = 1,
                 discrete: bool = False,
                 default_value: float = 0.0,
                 range: Optional[Tuple[float, float]] = None):
        """Create a parameter group.

        Args:
            group_name: Base name of the group. It is the sole parameter name
                when ``arity`` is 1, otherwise the prefix of the left/right names.
            parameter_index: Index of the group's first parameter.
                ``PoseParameters.Builder`` passes the running total of the
                arities of the groups added before this one.
            category: The ``PoseParameterCategory`` this group is tagged with.
            arity: Number of parameters in the group; must be 1 or 2.
            discrete: Flag stored for callers; not interpreted by this class.
            default_value: Value stored for callers; not checked against ``range``.
            range: ``(low, high)`` pair stored for callers; defaults to
                ``(0.0, 1.0)``. (The name shadows the builtin but is kept
                because callers may pass it by keyword.)

        Raises:
            ValueError: If ``arity`` is neither 1 nor 2.
        """
        # Explicit check instead of ``assert``: asserts are stripped under
        # ``python -O``, which would let e.g. arity=3 through and produce a
        # group whose two names disagree with its stored arity.
        if arity not in (1, 2):
            raise ValueError("arity must be 1 or 2, got %r" % (arity,))
        if range is None:
            range = (0.0, 1.0)
        if arity == 1:
            parameter_names = [group_name]
        else:
            parameter_names = [group_name + "_left", group_name + "_right"]

        self.parameter_names = parameter_names
        self.range = range
        self.default_value = default_value
        self.discrete = discrete
        self.arity = arity
        self.category = category
        self.parameter_index = parameter_index
        self.group_name = group_name

    def __repr__(self) -> str:
        return "%s(group_name=%r, parameter_index=%r, category=%r, arity=%r)" % (
            type(self).__name__, self.group_name, self.parameter_index,
            self.category, self.arity)

    def get_arity(self) -> int:
        """Return the number of parameters in the group (1 or 2)."""
        return self.arity

    def get_group_name(self) -> str:
        """Return the base name the group was created with."""
        return self.group_name

    def get_parameter_names(self) -> List[str]:
        """Return the parameter names (the internal list, not a copy)."""
        return self.parameter_names

    def is_discrete(self) -> bool:
        """Return the ``discrete`` flag given at construction."""
        return self.discrete

    def get_range(self) -> Tuple[float, float]:
        """Return the ``(low, high)`` range; ``(0.0, 1.0)`` if none was given."""
        return self.range

    def get_default_value(self) -> float:
        """Return the default value given at construction."""
        return self.default_value

    def get_parameter_index(self) -> int:
        """Return the index of the group's first parameter."""
        return self.parameter_index

    def get_category(self) -> PoseParameterCategory:
        """Return the category the group is tagged with."""
        return self.category


class PoseParameters:
    """An ordered list of ``PoseParameterGroup`` objects viewed as one flat parameter list.

    Parameters are numbered consecutively across the groups in list order.
    Within a group they follow ``get_parameter_names()`` order, so an arity-2
    group contributes its ``_left`` name and then its ``_right`` name.
    """

    def __init__(self, pose_parameter_groups: List[PoseParameterGroup]):
        self.pose_parameter_groups = pose_parameter_groups

    def _iter_parameter_names(self):
        """Yield every parameter name in flat-index order."""
        for group in self.pose_parameter_groups:
            yield from group.get_parameter_names()

    def get_parameter_index(self, name: str) -> int:
        """Return the flat index of the parameter called ``name``.

        Raises:
            RuntimeError: If no group has a parameter with that name.
        """
        for index, param_name in enumerate(self._iter_parameter_names()):
            if name == param_name:
                return index
        raise RuntimeError("Cannot find parameter with name %s" % name)

    def get_parameter_name(self, index: int) -> str:
        """Return the name of the parameter at flat position ``index``.

        Raises:
            IndexError: If ``index`` is outside ``[0, get_parameter_count())``.
        """
        count = self.get_parameter_count()
        # Explicit check instead of ``assert``: under ``python -O`` the assert
        # vanished and a negative index silently returned the wrong name via
        # Python's negative list indexing below.
        if not 0 <= index < count:
            raise IndexError(
                "Parameter index %r is out of range [0, %d)" % (index, count))

        for group in self.pose_parameter_groups:
            arity = group.get_arity()
            if index < arity:
                return group.get_parameter_names()[index]
            index -= arity

        # Only reachable if a group's name list is shorter than its arity.
        raise RuntimeError("Something is wrong here!!!")

    def get_pose_parameter_groups(self) -> List[PoseParameterGroup]:
        """Return the groups in flat-index order (the internal list, not a copy)."""
        return self.pose_parameter_groups

    def get_parameter_count(self) -> int:
        """Return the total number of parameters, i.e. the sum of group arities."""
        return sum(group.get_arity() for group in self.pose_parameter_groups)

    class Builder:
        """Fluent builder that gives each added group its starting flat index."""

        def __init__(self):
            # Flat index at which the next added group will start.
            self.index = 0
            self.pose_parameter_groups = []

        def add_parameter_group(self,
                                group_name: str,
                                category: PoseParameterCategory,
                                arity: int = 1,
                                discrete: bool = False,
                                default_value: float = 0.0,
                                range: Optional[Tuple[float, float]] = None):
            """Append a new ``PoseParameterGroup`` and advance the running index.

            The arguments are forwarded to ``PoseParameterGroup``, together with
            the current running index as its ``parameter_index``.

            Returns:
                This builder, so calls can be chained.
            """
            self.pose_parameter_groups.append(
                PoseParameterGroup(
                    group_name,
                    self.index,
                    category,
                    arity,
                    discrete,
                    default_value,
                    range))
            self.index += arity
            return self

        def build(self) -> 'PoseParameters':
            """Return a ``PoseParameters`` holding the groups added so far.

            The group list is copied, so adding more groups to this builder
            afterwards does not alter a previously built ``PoseParameters``.
            """
            return PoseParameters(list(self.pose_parameter_groups))


class Poser(ABC):
    """Abstract interface for posers.

    A poser takes an image tensor and a pose tensor and produces output
    tensors. Every method except ``get_dtype`` must be implemented by
    subclasses.
    """

    @abstractmethod
    def get_image_size(self) -> int:
        """Return the poser's image size (exact meaning defined by implementations)."""

    @abstractmethod
    def get_output_length(self) -> int:
        """Return the number of outputs.

        This is presumably the length of the list returned by
        ``get_posing_outputs``; confirm against implementations.
        """

    @abstractmethod
    def get_pose_parameter_groups(self) -> List[PoseParameterGroup]:
        """Return the poser's pose parameter groups."""

    @abstractmethod
    def get_num_parameters(self) -> int:
        """Return the number of pose parameters.

        Assumed to equal the length of the ``pose`` tensor; TODO confirm.
        """

    @abstractmethod
    def pose(self, image: Tensor, pose: Tensor, output_index: int = 0) -> Tensor:
        """Return one output for ``image`` posed by ``pose``.

        ``output_index`` selects which output to return (0 by default). It
        presumably indexes the list from ``get_posing_outputs``.
        """

    @abstractmethod
    def get_posing_outputs(self, image: Tensor, pose: Tensor) -> List[Tensor]:
        """Return all outputs for ``image`` posed by ``pose``."""

    def get_dtype(self) -> torch.dtype:
        """Return the poser's tensor dtype.

        The base implementation reports ``torch.float`` (float32); subclasses
        may override it.
        """
        default_dtype = torch.float
        return default_dtype