# Author: DveloperY0115
# Commit: "init repo" (801501a)
from ..custom_types import *
def get_gm_support(gm, x):
    """Evaluate every mixture component's weighted density at every query point.

    Args:
        gm: ``(mu, p, phi, eigen)``. ``eigen`` is indexed with four axes and
            ``p`` with five; the ``squeeze(1)`` followed by the 4-index einsum
            below only works when the second axis of the parameters has size 1.
            The rows of ``p`` act as eigenvectors and ``eigen`` as the matching
            eigenvalues: the precision matrix is ``p^T diag(1 / eigen) p``.
            ``phi`` holds unnormalised component weights (softmax over axis 2).
        x: query points, indexed as ``(batch, num_points, dim)``.

    Returns:
        ``(support, shift)``. ``support[b, n, g]`` is
        ``weight_g * N(x_n | mu_g, Sigma_g) * exp(-shift[b, n])``. ``shift`` is
        the per-point maximum of the Gaussian exponent, removed for numerical
        stability and returned so callers can add it back in log space.
    """
    means, eigvecs, logits, eigvals = gm
    num_dims = x.shape[-1]

    # Precision matrix per component: P^T diag(1 / eigvals) P.
    scaled_vecs = eigvecs * (1 / eigvals)[:, :, :, :, None]
    precision = torch.matmul(eigvecs.transpose(3, 4), scaled_vecs).squeeze(1)

    # Mixture weight divided by the Gaussian normalising constant.
    weights = torch.softmax(logits, dim=2)
    normalizer = torch.sqrt((2 * np.pi) ** num_dims * eigvals.prod(-1))
    scale = weights / normalizer

    # -0.5 * Mahalanobis distance of every point to every component mean.
    offsets = x[:, :, None, :] - means
    exponent = -.5 * torch.einsum('bngd,bgdc,bngc->bng', offsets, precision, offsets)

    # Subtract the per-point maximum so exp() stays in range.
    shift = exponent.max(dim=2)[0]
    support = scale * torch.exp(exponent - shift[:, :, None])
    return support, shift
def gm_log_likelihood_loss(gms: TS, x: T, get_supports: bool = False,
                           mask: Optional[T] = None, reduction: str = "mean") -> Union[T, Tuple[T, TS]]:
    """Negative log-likelihood of the points ``x`` under a Gaussian mixture.

    Args:
        gms: mixture parameters, passed unchanged to ``get_gm_support``.
        x: query points, indexed as ``(batch, num_points, dim)``.
        get_supports: when True, also return the support tensor produced by
            ``get_gm_support`` (one value per point and component).
        mask: optional boolean tensor with one entry per point
            (``batch * num_points`` entries), given either shaped like
            ``(batch, num_points)`` or already flattened. Only points whose
            entry is True contribute to the loss.
        reduction: ``"none"`` returns one value per batch element: the negated
            sum of the contributing log-likelihoods divided by ``num_points``.
            Any other value returns a scalar: the negated mean over every
            contributing point.

    Returns:
        ``loss``, or ``(loss, support)`` when ``get_supports`` is True.
    """
    num_points = x.shape[1]
    support, const = get_gm_support(gms, x)
    # get_gm_support removes a per-point constant before exp(); add it back in
    # log space to recover one log-likelihood per point.
    probs = torch.log(support.sum(dim=2)) + const
    if mask is not None:
        # Reshape (not flatten) the mask so it lines up element-wise with probs
        # for any batch size; a flat mask only broadcasts against a 2-D probs
        # when batch == 1.
        mask = mask.reshape(probs.shape).bool()
    if reduction == 'none':
        if mask is not None:
            # Zero the excluded points instead of selecting them, so the
            # per-batch axis survives. masked_fill also avoids -inf * 0 = nan.
            probs = probs.masked_fill(~mask, 0.)
        loss = -probs.sum(-1) / num_points
    else:
        if mask is not None:
            probs = probs.masked_select(mask)
        # Mean over the points actually summed; max(..., 1) keeps an
        # all-False mask (or empty input) from dividing by zero.
        loss = -probs.sum() / max(probs.numel(), 1)
    if get_supports:
        return loss, support
    return loss
def split_mesh_by_gmm(mesh: T_Mesh, gmm) -> Dict[int, T]:
    """Assign every face of ``mesh`` to the mixture component that supports it most.

    Each face is represented by the mean of its vertices. It goes to the
    component with the largest support at that point, i.e. the argmax over
    components of the support returned by ``gm_log_likelihood_loss``.

    Args:
        mesh: ``(vs, faces)`` pair; ``vs[faces]`` gathers each face's vertices.
        gmm: mixture parameters; the number of components is read from
            ``gmm[1].shape[2]``.

    Returns:
        Dict mapping every component index to the rows of ``faces`` assigned
        to it, or to ``None`` when no face was assigned to that component.
    """
    vertices, faces = mesh
    centroids = vertices[faces].mean(1)
    _, supports = gm_log_likelihood_loss(gmm, centroids.unsqueeze(0), get_supports=True)
    # The centroids were passed as a batch of one: drop that axis, then pick
    # the best component for each face.
    owner = supports[0].argmax(1)
    num_components = gmm[1].shape[2]
    split = {}
    for component in range(num_components):
        chosen = owner == component
        split[component] = faces[chosen] if chosen.any() else None
    return split
def flatten_gmm(gmm: TS) -> T:
b, gp, g, _ = gmm[0].shape
mu, p, phi, eigen = [item.view(b, gp * g, *item.shape[3:]) for item in gmm]
p = p.reshape(*p.shape[:2], -1)
z_gmm = torch.cat((mu, p, phi.unsqueeze(-1), eigen), dim=2)
return z_gmm