# Talking_Head_Anime_3/tha3/poser/general_poser_02.py
# Harry_FBK
# Clone original THA3
# 60094bd
# raw
# history blame contribute delete
# No virus
# 2.93 kB
from typing import List, Optional, Tuple, Dict, Callable
import torch
from torch import Tensor
from torch.nn import Module
from tha3.poser.poser import PoseParameterGroup, Poser
from tha3.compute.cached_computation_func import TensorListCachedComputationFunc
class GeneralPoser02(Poser):
    """Poser that lazily loads a set of named modules and delegates posing to a callable.

    The modules are created on first use from ``module_loaders``, moved to
    ``device`` and switched out of training mode. The actual computation is
    performed by ``output_list_func``, which receives the loaded modules, the
    batch ``[image, pose]`` and an empty dict.
    """

    def __init__(self,
                 module_loaders: Dict[str, Callable[[], Module]],
                 device: torch.device,
                 output_length: int,
                 pose_parameters: List[PoseParameterGroup],
                 output_list_func: TensorListCachedComputationFunc,
                 subrect: Optional[Tuple[Tuple[int, int], Tuple[int, int]]] = None,
                 default_output_index: int = 0,
                 image_size: int = 256,
                 dtype: torch.dtype = torch.float):
        """Store the configuration; no module is loaded here.

        Args:
            module_loaders: Maps a module name to a zero-argument callable
                that returns the module.
            device: Device every loaded module is moved to.
            output_length: Value reported by ``get_output_length``.
            pose_parameters: Parameter groups; their arities are summed to
                obtain the total parameter count.
            output_list_func: Callable invoked as
                ``output_list_func(modules, batch, outputs)`` to produce the
                list of output tensors.
            subrect: Optional ``((start2, end2), (start3, end3))`` bounds used
                to slice dimensions 2 and 3 of the image batch before posing.
            default_output_index: Index into the output list used by ``pose``
                when the caller does not supply one.
            image_size: Value reported by ``get_image_size``.
            dtype: Value reported by ``get_dtype``; this class only stores it.
        """
        self.dtype = dtype
        self.image_size = image_size
        self.default_output_index = default_output_index
        self.output_list_func = output_list_func
        self.subrect = subrect
        self.pose_parameters = pose_parameters
        self.device = device
        self.module_loaders = module_loaders
        # Filled lazily by get_modules(); reset to None by free().
        self.modules: Optional[Dict[str, Module]] = None
        self.num_parameters = sum(group.get_arity() for group in self.pose_parameters)
        self.output_length = output_length

    def get_image_size(self) -> int:
        """Return the image size given at construction time."""
        return self.image_size

    def get_modules(self) -> Dict[str, Module]:
        """Return the loaded modules, loading them on the first call.

        Each loader is invoked once; the resulting module is moved to
        ``self.device`` and put in non-training mode. The result is cached
        until ``free`` is called.
        """
        if self.modules is None:
            # Build into a local dict and publish it only once every module
            # has loaded. If a loader or the device transfer raises, nothing
            # is cached, so a later call retries instead of silently
            # returning an incomplete dict.
            loaded: Dict[str, Module] = {}
            for key, loader in self.module_loaders.items():
                module = loader()
                module.to(self.device)
                module.train(False)
                loaded[key] = module
            self.modules = loaded
        return self.modules

    def get_pose_parameter_groups(self) -> List[PoseParameterGroup]:
        """Return the pose parameter groups given at construction time."""
        return self.pose_parameters

    def get_num_parameters(self) -> int:
        """Return the sum of the arities of all pose parameter groups."""
        return self.num_parameters

    def pose(self, image: Tensor, pose: Tensor, output_index: Optional[int] = None) -> Tensor:
        """Compute all posing outputs and return the one at ``output_index``.

        Args:
            image: Image tensor, passed on to ``get_posing_outputs``.
            pose: Pose tensor, passed on to ``get_posing_outputs``.
            output_index: Position in the output list to return; falls back to
                ``default_output_index`` when ``None``.
        """
        if output_index is None:
            output_index = self.default_output_index
        output_list = self.get_posing_outputs(image, pose)
        return output_list[output_index]

    def get_posing_outputs(self, image: Tensor, pose: Tensor) -> List[Tensor]:
        """Run ``output_list_func`` on the (batched, optionally cropped) inputs.

        A 3-dimensional ``image`` and a 1-dimensional ``pose`` each get a
        leading dimension of size 1 added. When ``subrect`` is set, dimensions
        2 and 3 of the image batch are sliced to the configured bounds.
        """
        modules = self.get_modules()

        # Promote unbatched inputs to a batch of one.
        if len(image.shape) == 3:
            image = image.unsqueeze(0)
        if len(pose.shape) == 1:
            pose = pose.unsqueeze(0)

        if self.subrect is not None:
            (start2, end2), (start3, end3) = self.subrect[0], self.subrect[1]
            image = image[:, :, start2:end2, start3:end3]

        batch = [image, pose]
        # A fresh dict is handed over on every call.
        outputs = {}
        return self.output_list_func(modules, batch, outputs)

    def get_output_length(self) -> int:
        """Return the output length given at construction time."""
        return self.output_length

    def free(self):
        """Drop the cached modules; the next ``get_modules`` call reloads them."""
        self.modules = None

    def get_dtype(self) -> torch.dtype:
        """Return the dtype given at construction time."""
        return self.dtype