+++ venv/lib/python3.10/site-packages/selfcontact/body_segmentation.py @@ -14,6 +14,8 @@ # # Contact: ps-license@tuebingen.mpg.de +from pathlib import Path + import torch import trimesh import torch.nn as nn @@ -22,6 +24,17 @@ from .utils.mesh import winding_numbers + +def load_pkl(path): + with open(path, "rb") as fin: + return pickle.load(fin) + + +def save_pkl(obj, path): + with open(path, "wb") as fout: + pickle.dump(obj, fout) + + class BodySegment(nn.Module): def __init__(self, name, @@ -63,9 +76,17 @@ self.register_buffer('segment_faces', segment_faces) # create vector to select vertices form faces - tri_vidx = [] - for ii in range(faces.max().item()+1): - tri_vidx += [torch.nonzero(faces==ii)[0].tolist()] + segments_folder = Path(segments_folder) + tri_vidx_path = segments_folder / "tri_vidx.pkl" + if not tri_vidx_path.is_file(): + tri_vidx = [] + for ii in range(faces.max().item()+1): + tri_vidx += [torch.nonzero(faces==ii)[0].tolist()] + + save_pkl(tri_vidx, tri_vidx_path) + else: + tri_vidx = load_pkl(tri_vidx_path) + self.register_buffer('tri_vidx', torch.tensor(tri_vidx)) def create_band_faces(self): @@ -149,7 +170,7 @@ self.segmentation = {} for idx, name in enumerate(names): self.segmentation[name] = BodySegment(name, faces, segments_folder, - model_type).to('cuda') + model_type).to(device) def batch_has_self_isec_verts(self, vertices): """ +++ venv/lib/python3.10/site-packages/selfcontact/selfcontact.py @@ -41,6 +41,7 @@ test_segments=True, compute_hd=False, buffer_geodists=False, + device="cuda", ): super().__init__() @@ -95,7 +96,7 @@ if self.test_segments: sxseg = pickle.load(open(segments_bounds_path, 'rb')) self.segments = BatchBodySegment( - [x for x in sxseg.keys()], faces, segments_folder, self.model_type + [x for x in sxseg.keys()], faces, segments_folder, self.model_type, device=device, ) # load regressor to get high density mesh @@ -106,7 +107,7 @@ torch.tensor(hd_operator['values']), torch.Size(hd_operator['size'])) self.register_buffer('hd_operator', - torch.tensor(hd_operator).float()) + hd_operator.clone().detach().float()) with open(point_vert_corres_path, 'rb') as f: hd_geovec = pickle.load(f)['faces_vert_is_sampled_from'] @@ -135,9 +136,13 @@ # split because of memory into two chunks exterior = torch.zeros((bs, nv), device=vertices.device, dtype=torch.bool) - exterior[:, :5000] = winding_numbers(vertices[:,:5000,:], + exterior[:, :3000] = winding_numbers(vertices[:,:3000,:], triangles).le(0.99) - exterior[:, 5000:] = winding_numbers(vertices[:,5000:,:], + exterior[:, 3000:6000] = winding_numbers(vertices[:,3000:6000,:], + triangles).le(0.99) + exterior[:, 6000:9000] = winding_numbers(vertices[:,6000:9000,:], + triangles).le(0.99) + exterior[:, 9000:] = winding_numbers(vertices[:,9000:,:], triangles).le(0.99) # check if intersections happen within segments @@ -173,9 +178,13 @@ # split because of memory into two chunks exterior = torch.zeros((bs, np), device=points.device, dtype=torch.bool) - exterior[:, :6000] = winding_numbers(points[:,:6000,:], + exterior[:, :3000] = winding_numbers(points[:,:3000,:], + triangles).le(0.99) + exterior[:, 3000:6000] = winding_numbers(points[:,3000:6000,:], triangles).le(0.99) - exterior[:, 6000:] = winding_numbers(points[:,6000:,:], + exterior[:, 6000:9000] = winding_numbers(points[:,6000:9000,:], + triangles).le(0.99) + exterior[:, 9000:] = winding_numbers(points[:,9000:,:], triangles).le(0.99) return exterior @@ -371,6 +380,23 @@ return hd_v2v_mins, hd_exteriors, hd_points, hd_faces_in_contacts + def verts_in_contact(self, vertices, return_idx=False): + + # get pairwise distances of vertices + v2v = self.get_pairwise_dists(vertices, vertices, squared=True) + + # mask v2v with eucledean and geodesic dsitance + euclmask = v2v < self.euclthres**2 + mask = euclmask * self.geomask + + # find closes vertex in contact + in_contact = mask.sum(1) > 0 + + if return_idx: + in_contact = torch.where(in_contact) + + return in_contact + class SelfContactSmall(nn.Module): +++ venv/lib/python3.10/site-packages/selfcontact/utils/mesh.py @@ -82,7 +82,7 @@ if valid_vals > 0: loss = (mask * dists).sum() / valid_vals else: - loss = torch.Tensor([0]).cuda() + loss = mask.new_tensor([0]) return loss def batch_index_select(inp, dim, index): @@ -103,6 +103,7 @@ xx = torch.bmm(x, x.transpose(2, 1)) yy = torch.bmm(y, y.transpose(2, 1)) zz = torch.bmm(x, y.transpose(2, 1)) + use_cuda = x.device.type == "cuda" if use_cuda: dtype = torch.cuda.LongTensor else: