TalkSHOW / nets /base.py
feifeifeiliu's picture
first version
865fd8a
import torch
import torch.nn as nn
import torch.optim as optim
class TrainWrapperBaseClass():
    """Bundle a generator, an optional discriminator and their Adam optimizers.

    This base class only *reads* ``self.generator``, ``self.discriminator``,
    ``self.config`` and ``self.expression``; it never assigns them. They must
    already be set on the instance before the methods below are called
    (presumably by a subclass before it calls ``super().__init__`` --
    confirm against the concrete wrappers).
    """

    def __init__(self, args, config) -> None:
        """Create the optimizers.

        ``args`` and ``config`` are accepted but not stored here;
        ``init_optimizer`` reads ``self.config``, not the ``config`` argument.
        """
        self.init_optimizer()

    def init_optimizer(self) -> None:
        """Build Adam optimizers for the generator and, if present, the discriminator.

        Learning rates are read from ``self.config.Train.learning_rate``.
        ``self.discriminator_optimizer`` is set to ``None`` when there is no
        discriminator so the attribute always exists after this call.
        """
        print('using Adam')
        learning_rates = self.config.Train.learning_rate
        self.generator_optimizer = optim.Adam(
            self.generator.parameters(),
            lr=learning_rates.generator_learning_rate,
            betas=(0.9, 0.999),
        )
        if self.discriminator is not None:
            self.discriminator_optimizer = optim.Adam(
                self.discriminator.parameters(),
                lr=learning_rates.discriminator_learning_rate,
                betas=(0.9, 0.999),
            )
        else:
            self.discriminator_optimizer = None

    def __call__(self, bat):
        """Process one batch; must be implemented by subclasses."""
        raise NotImplementedError

    def get_loss(self, **kwargs):
        """Compute the loss; must be implemented by subclasses."""
        raise NotImplementedError

    def state_dict(self):
        """Return a checkpoint dict of model and optimizer states.

        Keys are ``'generator'``, ``'generator_optim'``, ``'discriminator'``
        and ``'discriminator_optim'``. The two discriminator entries are
        ``None`` when ``self.discriminator`` is ``None``.
        """
        has_discriminator = self.discriminator is not None
        return {
            'generator': self.generator.state_dict(),
            'generator_optim': self.generator_optimizer.state_dict(),
            'discriminator': (
                self.discriminator.state_dict() if has_discriminator else None
            ),
            'discriminator_optim': (
                self.discriminator_optimizer.state_dict()
                if has_discriminator else None
            ),
        }

    def parameters(self):
        """Return the generator's parameters (the discriminator's are not included)."""
        return self.generator.parameters()

    def load_state_dict(self, state_dict):
        """Restore model and optimizer states from a checkpoint.

        ``state_dict`` is either a dict in the layout produced by
        ``state_dict()`` or a bare generator state dict (recognised by the
        absence of a ``'generator'`` key).

        Discriminator entries are applied only when this wrapper has a
        discriminator AND the checkpoint holds a non-``None`` value for them.
        A bare generator state dict, or a checkpoint saved by a generator-only
        wrapper (which stores ``None``), therefore loads without error and
        leaves the discriminator untouched.
        """
        if 'generator' in state_dict:
            self.generator.load_state_dict(state_dict['generator'])
        else:
            # No wrapper keys: treat the whole mapping as generator weights.
            self.generator.load_state_dict(state_dict)

        if 'generator_optim' in state_dict and self.generator_optimizer is not None:
            self.generator_optimizer.load_state_dict(state_dict['generator_optim'])

        if self.discriminator is None:
            return

        discriminator_state = state_dict.get('discriminator')
        if discriminator_state is not None:
            self.discriminator.load_state_dict(discriminator_state)

        discriminator_optim_state = state_dict.get('discriminator_optim')
        # getattr: a subclass that overrides init_optimizer may not define it.
        discriminator_optimizer = getattr(self, 'discriminator_optimizer', None)
        if discriminator_optim_state is not None and discriminator_optimizer is not None:
            discriminator_optimizer.load_state_dict(discriminator_optim_state)

    def infer_on_audio(self, aud_fn, initial_pose=None, norm_stats=None, **kwargs):
        """Run inference for an audio input; must be implemented by subclasses."""
        raise NotImplementedError

    def init_params(self):
        """Compute the feature-vector layout and store it on the instance.

        Reads ``self.config.Data.pose.convert_to_6d`` and ``self.expression``.

        Sets:
            dim_list: start offsets of the jaw, eye, body, hand and face
                segments, in that order.
            full_dim: jaw + eye + body + hand width (face width not included).
            pose: ``full_dim`` divided by the per-joint width ``3 * scale``.
            each_dim: ``[jaw_dim, eye_dim + body_dim, hand_dim, face_dim]``.
        """
        # Per-joint width is 3 * scale: 6 values per joint when converting
        # to the 6d representation, 3 otherwise.
        scale = 2 if self.config.Data.pose.convert_to_6d else 1

        # Raw per-part widths before scaling. NOTE(review): the literals
        # (0, 63 - 24, 45, 100) are dataset-specific sizes -- confirm against
        # the data pipeline before changing them.
        global_orient = round(0 * scale)
        leye_pose = reye_pose = round(0 * scale)
        jaw_pose = round(0 * scale)
        body_pose = round((63 - 24) * scale)
        left_hand_pose = right_hand_pose = round(45 * scale)
        expression = 100 if self.expression else 0

        jaw_dim = jaw_pose
        eye_dim = leye_pose + reye_pose
        body_dim = global_orient + body_pose
        hand_dim = left_hand_pose + right_hand_pose
        face_dim = expression

        # Segment start offsets, accumulated in jaw/eye/body/hand/face order.
        b_j = 0
        b_e = b_j + jaw_dim
        b_b = b_e + eye_dim
        b_h = b_b + body_dim
        b_f = b_h + hand_dim

        self.dim_list = [b_j, b_e, b_b, b_h, b_f]
        self.full_dim = jaw_dim + eye_dim + body_dim + hand_dim
        # Integer division avoids a float round-trip; operands are
        # non-negative ints, so the result equals int(full_dim / width).
        self.pose = self.full_dim // round(3 * scale)
        self.each_dim = [jaw_dim, eye_dim + body_dim, hand_dim, face_dim]