import itertools
import json
import os
from collections import Counter

import numpy as np
import trimesh
from matplotlib.path import Path
from sklearn.neighbors import KNeighborsClassifier


def load_segmentation(path, shape):
    """
    Load the garment polygons stored in a segmentation json file.

    Every top-level key starting with "item" is treated as one garment. Its
    "segmentation" entry is a list of flat polygons ``[x1, y1, x2, y2, ...]``
    which are converted into ``(n, 2)`` point arrays.

    Arguments:
        path: path to the segmentation json file
        shape: currently unused; kept for backward compatibility with callers

    Returns:
        A list with one dict per garment item, holding
        "type" (the item's "category_name"), "type_id" (its "category_id")
        and "coordinates" (a list of ``(n, 2)`` polygon arrays).
    """
    with open(path) as json_file:
        data = json.load(json_file)

    segmentations = []
    for key, val in data.items():
        if not key.startswith("item"):
            continue

        # An item can consist of several polygons; keep them as separate arrays.
        coordinates = []
        for segmentation_coord in val["segmentation"]:
            # Flat layout [x1, y1, x2, y2, ...] -> (n, 2) array of points.
            x = segmentation_coord[::2]
            y = segmentation_coord[1::2]
            coordinates.append(np.vstack((x, y)).T)

        segmentations.append(
            {
                "type": val["category_name"],
                "type_id": val["category_id"],
                "coordinates": coordinates,
            }
        )

    return segmentations


def smpl_to_recon_labels(recon, smpl, k=1):
    """
    Transfer SMPL body-part labels to the vertices of a reconstructed mesh.

    Each SMPL vertex is labelled using "smpl_vert_segmentation.json" (located
    next to this module). A k-nearest-neighbour classifier fitted on the SMPL
    vertex positions then assigns a body part to every recon vertex.

    Arguments:
        recon: trimesh object (fully clothed model)
        smpl: trimesh object (smpl model)
        k: number of nearest neighbours used for the label vote

    Returns:
        A dictionary mapping each body part name from the json file to the
        list of recon vertex indices predicted to belong to it.
    """
    seg_path = os.path.join(os.path.dirname(__file__), "smpl_vert_segmentation.json")
    with open(seg_path) as json_file:
        smpl_vert_segmentation = json.load(json_file)

    # Object array so it can hold part-name strings; any SMPL vertex not listed
    # in the json keeps the label None.
    y = np.array([None] * smpl.vertices.shape[0])
    for part, vert_ids in smpl_vert_segmentation.items():
        y[vert_ids] = part

    classifier = KNeighborsClassifier(n_neighbors=k)
    classifier.fit(smpl.vertices, y)
    y_pred = classifier.predict(recon.vertices)

    return {
        part: list(np.flatnonzero(y_pred == part).astype(int))
        for part in smpl_vert_segmentation
    }


# SMPL body parts that are always cut away from an extracted garment.
_ALWAYS_REMOVED_PARTS = (
    "head",
    "leftHand",
    "rightHand",
    "leftHandIndex1",
    "rightHandIndex1",
    "leftFoot",
    "rightFoot",
    "leftToeBase",
    "rightToeBase",
)


def _body_parts_to_remove(type_id):
    """Return the SMPL body parts to cut from a garment of category ``type_id``."""
    # Category ids follow DeepFashion2:
    # https://github.com/switchablenorms/DeepFashion2
    parts = list(_ALWAYS_REMOVED_PARTS)
    if type_id in (1, 3, 10):
        # Short sleeve clothes
        parts += ["leftForeArm", "rightForeArm"]
    elif type_id in (5, 6, 12, 13, 8, 9):
        # No sleeves at all or lower body clothes
        parts += ["leftForeArm", "rightForeArm", "leftArm", "rightArm"]
    elif type_id == 7:
        # Shorts
        parts += [
            "leftLeg",
            "rightLeg",
            "leftForeArm",
            "rightForeArm",
            "leftArm",
            "rightArm",
        ]
    return parts


def _segmentation_vertex_mask(vertices, seg_coord, K, R, t):
    """Return a bool mask of vertices whose (x, y) lie inside any back-projected polygon."""
    extrinsic = np.zeros((3, 4))
    extrinsic[:3, :3] = R
    extrinsic[:, 3] = t
    P_inv = np.linalg.pinv(K[:3, :3] @ extrinsic)

    inside = np.zeros(vertices.shape[0], dtype=bool)
    # Each segmentation can contain multiple polygons; test them separately.
    for polygon in seg_coord:
        coords_h = np.hstack((polygon, np.ones((len(polygon), 1))))
        # Apply the pseudo-inverse projection to the homogeneous 2D points,
        # then divide by w to get 3D coordinates.
        XYZ = (P_inv @ coords_h.T).T
        XYZ = XYZ[:, :3] / XYZ[:, 3, None]
        inside |= Path(XYZ[:, :2]).contains_points(vertices[:, :2])
    return inside


def extract_cloth(recon, segmentation, K, R, t, smpl=None):
    """
    Extract a portion of a mesh using 2D segmentation coordinates.

    The segmentation polygons are back-projected with the pseudo-inverse of
    ``K[:3, :3] @ [R | t]``. Recon vertices whose (x, y) fall inside a
    polygon are selected. Vertices labelled (via ``smpl``) as body parts that
    cannot belong to this garment category are then dropped, e.g. a hand in
    front of the torso.

    Arguments:
        recon: fully clothed mesh
        segmentation: dict providing "coord_normalized" (list of 2D polygons,
            presumably in NDC) and "type_id" (DeepFashion2 category id)
        K: intrinsic matrix of the projection
        R: rotation matrix of the projection
        t: translation vector of the projection
        smpl: SMPL mesh used to label recon vertices by body part

    Returns:
        A submesh containing every face with at least one kept vertex, or
        None when ``smpl`` is None (the body-part filtering requires it).
    """
    if smpl is None:
        return None

    seg_coord = segmentation["coord_normalized"]
    mesh = trimesh.Trimesh(recon.vertices, recon.faces)
    num_verts = recon.vertices.shape[0]

    seg_mask = _segmentation_vertex_mask(recon.vertices, seg_coord, K, R, t)

    recon_labels = smpl_to_recon_labels(recon, smpl)
    verts_to_remove = list(
        itertools.chain.from_iterable(
            recon_labels[part]
            for part in _body_parts_to_remove(segmentation["type_id"])
        )
    )
    label_mask = np.zeros(num_verts, dtype=bool)
    label_mask[verts_to_remove] = True

    # Keep segmented vertices that do not belong to a removed body part.
    keep_verts = seg_mask & ~label_mask

    # A face is kept as soon as any one of its three vertices is kept.
    face_mask = keep_verts[np.asarray(recon.faces)].any(axis=1)

    mesh.update_faces(face_mask)
    mesh.remove_unreferenced_vertices()
    return mesh