Spaces:
Sleeping
Sleeping
from ..custom_types import * | |
def get_gm_support(gm, x): | |
dim = x.shape[-1] | |
mu, p, phi, eigen = gm | |
sigma_det = eigen.prod(-1) | |
eigen_inv = 1 / eigen | |
sigma_inverse = torch.matmul(p.transpose(3, 4), p * eigen_inv[:, :, :, :, None]).squeeze(1) | |
phi = torch.softmax(phi, dim=2) | |
const_1 = phi / torch.sqrt((2 * np.pi) ** dim * sigma_det) | |
distance = x[:, :, None, :] - mu | |
mahalanobis_distance = - .5 * torch.einsum('bngd,bgdc,bngc->bng', distance, sigma_inverse, distance) | |
const_2, _ = mahalanobis_distance.max(dim=2) # for numeric stability | |
mahalanobis_distance -= const_2[:, :, None] | |
support = const_1 * torch.exp(mahalanobis_distance) | |
return support, const_2 | |
def gm_log_likelihood_loss(gms: TS, x: T, get_supports: bool = False, | |
mask: Optional[T] = None, reduction: str = "mean") -> Union[T, Tuple[T, TS]]: | |
batch_size, num_points, dim = x.shape | |
support, const = get_gm_support(gms, x) | |
probs = torch.log(support.sum(dim=2)) + const | |
if mask is not None: | |
probs = probs.masked_select(mask=mask.flatten()) | |
if reduction == 'none': | |
likelihood = probs.sum(-1) | |
loss = - likelihood / num_points | |
else: | |
likelihood = probs.sum() | |
loss = - likelihood / (probs.shape[0] * probs.shape[1]) | |
if get_supports: | |
return loss, support | |
return loss | |
def split_mesh_by_gmm(mesh: T_Mesh, gmm) -> Dict[int, T]: | |
faces_split = {} | |
vs, faces = mesh | |
vs_mid_faces = vs[faces].mean(1) | |
_, supports = gm_log_likelihood_loss(gmm, vs_mid_faces.unsqueeze(0), get_supports=True) | |
supports = supports[0] | |
label = supports.argmax(1) | |
for i in range(gmm[1].shape[2]): | |
select = label.eq(i) | |
if select.any(): | |
faces_split[i] = faces[select] | |
else: | |
faces_split[i] = None | |
return faces_split | |
def flatten_gmm(gmm: TS) -> T: | |
b, gp, g, _ = gmm[0].shape | |
mu, p, phi, eigen = [item.view(b, gp * g, *item.shape[3:]) for item in gmm] | |
p = p.reshape(*p.shape[:2], -1) | |
z_gmm = torch.cat((mu, p, phi.unsqueeze(-1), eigen), dim=2) | |
return z_gmm | |