Spaces:
Sleeping
Sleeping
from __future__ import annotations | |
import constants | |
import vtk | |
import vtk.util.numpy_support as numpy_support | |
from custom_types import * | |
from utils import files_utils, rotation_utils | |
from models import gm_utils | |
from ui import ui_utils, inference_processing, gaussian_status | |
import options | |
def filter_by_inclusion(gaussian: gaussian_status.GaussianStatus) -> bool:
    """Predicate for ``GmmMeshStage.save``: keep a Gaussian whose ``included`` flag is set."""
    flag = gaussian.included
    return flag


def filter_by_selection(gaussian: gaussian_status.GaussianStatus) -> bool:
    """Predicate for ``GmmMeshStage.save``: keep a Gaussian whose ``is_selected`` flag is set."""
    flag = gaussian.is_selected
    return flag
class GmmMeshStage:
    """One editable Gaussian mixture (GMM) plus the mesh parts attached to its Gaussians.

    Each Gaussian is wrapped in a ``gaussian_status.GaussianStatus`` stored in ``self.gmm``.
    ``self.addresses_dict`` maps ``GaussianStatus.get_address()`` to the Gaussian's index in
    ``self.gmm``; ``vote`` compares these keys against ``vtkActor.GetAddressAsString('')``.
    ``self.selected`` holds the address of the currently tracked selection, or ``None``.
    """

    def __init__(self, opt: options.Options, shape_path: List[str], render: ui_utils.CanvasRender,
                 render_number: int, view_style: ui_utils.ViewStyle, to_init=True):
        """Load a shape and its GMM (unless the shape id is ``'-1'``) and build its Gaussians.

        Args:
            opt: options object; only ``opt.symmetric`` is read here.
            shape_path: ``[folder, shape_id]``. The mesh is loaded from ``''.join(shape_path)`` and
                the GMM from ``f'{folder}/{shape_id}.txt'``. Shape id ``'-1'`` creates an empty stage.
            render: canvas the Gaussians are drawn into.
            render_number: index of this stage; used as ``gmm_id`` and as the initial ``offset``.
            view_style: style object shared by every Gaussian of this stage.
            to_init: stored as ``self.to_init``.
        """
        self.view_style = view_style
        self.votes: Dict[str, int] = {}
        self.shape_id = shape_path[1]
        self.gmm_id = render_number
        self.render = render
        # Symmetric editing is deliberately switched off by the trailing `and False`.
        self.symmetric_mode = sum(opt.symmetric) > 0 and False
        self.selected: Optional[str] = None
        self.offset = render_number
        # NOTE: arrow gizmos (`self.arrows`) are currently disabled and never created.
        if self.shape_id != '-1':
            self.base_mesh = files_utils.load_mesh(''.join(shape_path))
            # The last element returned by load_gmm is dropped; raw_gmm[0] is used by add_gmm
            # as the (log) mixture weights.
            self.raw_gmm = files_utils.load_gmm(f'{shape_path[0]}/{shape_path[1]}.txt', as_np=True)[:-1]
        else:
            self.base_mesh = None
            self.raw_gmm = []
        self.to_init = to_init
        self.is_changed = False
        self.gmm: List[gaussian_status.GaussianStatus] = self.add_gmm()
        self.vs = self.faces = None
        self.add_mesh(self.base_mesh)
        self.addresses_dict: Dict[str, int] = {}
        self._rebuild_addresses()
        if self.symmetric_mode:
            # Pair Gaussian i with Gaussian i + n/2.
            half = len(self) // 2
            for i in range(half):
                self.make_twins(self.gmm[i].get_address(), self.gmm[i + half].get_address())
        self.toggle_all()

    def _rebuild_addresses(self):
        """Recompute ``addresses_dict`` from the current order of ``self.gmm``."""
        self.addresses_dict = {gaussian.get_address(): i for i, gaussian in enumerate(self.gmm)}

    def __len__(self):
        return len(self.gmm)

    @property
    def included(self) -> bool:
        """True when at least one Gaussian of this stage is included."""
        return any(gaussian.included for gaussian in self.gmm)

    def pick(self, actor_address: str) -> bool:
        """Return whether ``actor_address`` belongs to one of this stage's Gaussians."""
        return actor_address in self.addresses_dict

    def turn_off_selected(self):
        """Deselect the tracked Gaussian, if any."""
        if self.selected is not None:
            self.toggle_selection(self.selected)
            self.selected = None

    def turn_gmm_off(self):
        """Deselect, then turn every Gaussian off."""
        self.turn_off_selected()
        for gaussian in self.gmm:
            gaussian.turn_off()

    def turn_gmm_on(self):
        """Turn every Gaussian on."""
        for gaussian in self.gmm:
            gaussian.turn_on()

    def event_manger(self, object_id: str) -> bool:
        """Handle a click on ``object_id``; return True when it was consumed by this stage.

        A Gaussian address toggles its selection. Otherwise the arrow gizmos are consulted, but only
        if they exist: they are disabled (never created), and the previous unconditional
        ``self.arrows`` access raised AttributeError for every foreign object id.
        """
        if object_id in self.addresses_dict:
            return self.toggle_selection(object_id)
        arrows = getattr(self, 'arrows', None)
        if arrows is not None and arrows.check_event(object_id):
            transform = arrows.get_transform(object_id)
            self.update_gmm(*transform)
            return True
        return False

    def toggle_selection(self, object_id: str) -> bool:
        """Toggle the selection of ``object_id`` and update ``self.selected``. Always returns True."""
        self.gmm[self.addresses_dict[object_id]].toggle_selection()
        if self.selected is None:
            self.selected = object_id
        elif self.selected == object_id and self.gmm[self.addresses_dict[object_id]].is_not_selected:
            self.selected = None
        else:
            # Toggle the previously tracked Gaussian back and track object_id instead.
            self.gmm[self.addresses_dict[self.selected]].toggle_selection()
            self.selected = object_id
        return True

    def toggle_inclusion_by_id(self, g_id: int, select: Optional[bool] = None
                               ) -> Tuple[bool, List[gaussian_status.GaussianStatus]]:
        """Toggle inclusion of Gaussian ``g_id`` (and its twin in symmetric mode).

        Returns ``(True, toggled_gaussians)``; the first toggled Gaussian is always ``g_id``.
        """
        gaussian = self.gmm[g_id]
        gaussian.toggle_inclusion(select)
        toggled = [gaussian]
        if self.symmetric_mode:
            twin = gaussian.twin
            if twin is not None and twin.included != gaussian.included:
                twin.toggle_inclusion(select)
                toggled.append(twin)
        return True, toggled

    def toggle_inclusion(self, object_id: str) -> Tuple[bool, List[gaussian_status.GaussianStatus]]:
        """Toggle inclusion by actor address; ``(False, [])`` when the address is not ours."""
        if object_id in self.addresses_dict:
            return self.toggle_inclusion_by_id(self.addresses_dict[object_id])
        return False, []

    def toggle_all(self):
        """Toggle the inclusion of every Gaussian."""
        for gaussian in self.gmm:
            gaussian.toggle_inclusion()

    def set_opacity(self, opacity: float):
        """Set the shared view-style opacity and recolour every Gaussian."""
        self.view_style.opacity = opacity
        for gaussian in self.gmm:
            gaussian.set_color()

    def update_gmm(self, button: ui_utils.Buttons, key: str) -> bool:
        """Apply an affine edit to the selected Gaussian; False when nothing is selected."""
        if self.selected is None:
            return False
        g_id = self.addresses_dict[self.selected]
        self.gmm[g_id].apply_affine(button, key)
        if self.symmetric_mode and self.gmm[g_id].twin is not None:
            self.gmm[g_id].twin.make_symmetric(False)
        return True

    def get_gmm(self) -> Tuple[TS, T]:
        """Stack the raw parameters of the included Gaussians into batched tensors.

        Returns:
            ``((mu, p, phi, eigen), included)`` with ``mu`` viewed as ``(1, 1, k, 3)``, ``p`` as
            ``(1, 1, k, 3, 3)``, ``phi`` as ``(1, 1, k)`` and ``eigen`` as ``(1, 1, k, 3)`` for the
            ``k`` included Gaussians; ``included`` is an int64 tensor of their ``gaussian_id`` values.
            ``torch.stack`` raises when no Gaussian is included.
        """
        raw_gmm = [g.get_raw_data() for g in self.gmm if g.included]
        phi = torch.tensor([g[0] for g in raw_gmm], dtype=torch.float32).view(1, 1, -1)
        mu = torch.stack([torch.from_numpy(g[1]).float() for g in raw_gmm], dim=0).view(1, 1, -1, 3)
        p = torch.stack([torch.from_numpy(g[3]).float() for g in raw_gmm], dim=0).view(1, 1, -1, 3, 3)
        eigen = torch.stack([torch.from_numpy(g[2]).float() for g in raw_gmm], dim=0).view(1, 1, -1, 3)
        included = torch.tensor([g.gaussian_id for g in self.gmm if g.included], dtype=torch.int64)
        return (mu, p, phi, eigen), included

    def reset(self):
        """Reset every Gaussian."""
        for gaussian in self.gmm:
            gaussian.reset()

    def remove_all(self):
        """Remove and delete every Gaussian of this stage."""
        self.remove_gaussians(list(self.addresses_dict.keys()))
        self.addresses_dict = {}
        self.gmm = []

    def toggle_symmetric(self, force_include: bool):
        """Symmetric mode is disabled: the `and False` forces it off, so this only clears the flag."""
        self.symmetric_mode = not self.symmetric_mode and False
        if self.symmetric_mode:
            for gaussian in self.gmm:
                gaussian.make_symmetric(force_include)

    def remove_gaussians(self, addresses: List[str]):
        """Delete the Gaussians with the given actor addresses and reindex the rest.

        If the tracked selection is removed, ``self.selected`` is cleared so later lookups do not
        hit a stale address.
        """
        for address in addresses:
            gaussian_id: int = self.addresses_dict[address]
            gaussian = self.gmm[gaussian_id]
            self.gmm[gaussian_id] = None
            gaussian.delete(self.render)
            del self.addresses_dict[address]
            if address == self.selected:
                self.selected = None
        self.gmm = [gaussian for gaussian in self.gmm if gaussian is not None]
        self._rebuild_addresses()

    def add_gaussians(self, gaussians: List[gaussian_status.GaussianStatus]) -> List[str]:
        """Append unselected copies of ``gaussians``; return the copies' addresses in order."""
        new_addresses = []
        for gaussian in gaussians:
            gaussian_copy = gaussian.copy(self.render, self.view_style, is_selected=False)
            self.gmm.append(gaussian_copy)
            new_addresses.append(gaussian_copy.get_address())
        self._rebuild_addresses()
        return new_addresses

    def make_twins(self, address_a: str, address_b: str):
        """Link two Gaussians as each other's twin; ignored if either address is unknown."""
        if address_a in self.addresses_dict and address_b in self.addresses_dict:
            gaussian_a = self.gmm[self.addresses_dict[address_a]]
            gaussian_b = self.gmm[self.addresses_dict[address_b]]
            gaussian_a.twin = gaussian_b
            gaussian_b.twin = gaussian_a

    def split_mesh_by_gmm(self, mesh) -> Dict[int, T]:
        """Assign mesh faces to Gaussians via ``gm_utils.split_mesh_by_gmm``.

        Returns a dict keyed by Gaussian index; disabled Gaussians map to ``None``.
        NOTE(review): the GMM is built from *included* Gaussians while the result is consumed
        by skipping *disabled* ones; these are assumed to coincide here - TODO confirm.
        """
        mu, p, phi, _ = self.get_gmm()[0]
        eigen = torch.stack([torch.from_numpy(g.get_view_eigen()).float() for g in self.gmm if g.included],
                            dim=0).view(1, 1, -1, 3)
        faces_split_ = gm_utils.split_mesh_by_gmm(mesh, (mu, p, phi, eigen))
        faces_split = {}
        counter = 0
        for i, gaussian in enumerate(self.gmm):
            if gaussian.disabled:
                faces_split[i] = None
            else:
                faces_split[i] = faces_split_[counter]
                counter += 1
        return faces_split

    @staticmethod
    def get_part_face(mesh: V_Mesh, faces_inds: T) -> Tuple[T_Mesh, T]:
        """Extract the sub-mesh of faces whose ``faces_inds`` entry is non-zero, reindexing vertices."""
        mesh = mesh[0], torch.from_numpy(mesh[1]).long()
        mask = faces_inds.ne(0)
        faces = mesh[1][mask]
        vs_inds = faces.flatten().unique()
        vs = mesh[0][vs_inds]
        mapper = torch.zeros(mesh[0].shape[0], dtype=torch.int64)
        mapper[vs_inds] = torch.arange(vs.shape[0])
        return (vs, mapper[faces]), faces_inds[mask]

    def save(self, root: str,
             filter_faces: Callable[[gaussian_status.GaussianStatus], bool] = filter_by_inclusion):
        """Export the split mesh and a per-face 0/1 list (1 where ``filter_faces`` accepts the part).

        Does nothing when the mesh has not been split (``self.faces is None``).
        """
        if self.faces is None:
            return
        name = "mix" if self.gmm_id == -1 else str(self.gmm_id)
        path = f"{root}/{files_utils.get_time_name(name)}"
        parts = [(g_id, part) for g_id, part in self.faces.items() if part is not None]
        mesh = self.vs, np.concatenate([part for _, part in parts])
        faces_inds = torch.cat([
            (torch.ones if filter_faces(self.gmm[g_id]) else torch.zeros)(part.shape[0], dtype=torch.int64)
            for g_id, part in parts])
        files_utils.export_mesh(mesh, path)
        files_utils.export_list(faces_inds.tolist(), f"{path}_faces")

    def aggregate_symmetric(self) -> Dict[str, int]:
        """Return the votes, extended to each voted Gaussian's twin when in symmetric mode."""
        if not self.symmetric_mode:
            return self.votes
        out = {}
        for address, count in self.votes.items():
            twin = self.gmm[self.addresses_dict[address]].twin
            out[address] = count
            if twin is not None and twin.get_address() not in self.votes:
                out[twin.get_address()] = count
        return out

    def aggregate_votes(self) -> List[int]:
        """Return the indices of every voted Gaussian and clear the votes."""
        actors_id = [self.addresses_dict[address] for address in self.votes]
        self.votes = {}
        return actors_id

    def vote(self, *actors: Optional[vtk.vtkActor]):
        """Count one vote for each given actor that belongs to this stage; ``None`` is skipped."""
        for actor in actors:
            if actor is None:
                continue
            address = actor.GetAddressAsString('')
            if address in self.addresses_dict:
                self.votes[address] = self.votes.get(address, 0) + 1

    @staticmethod
    def faces_to_vtk_faces(faces: Union[T, ARRAY]):
        """Convert an ``(F, 3)`` triangle index array/tensor into a ``vtkCellArray``."""
        if torch.is_tensor(faces):
            faces = faces.detach().cpu().numpy()
        # VTK's legacy cell layout: [3, i0, i1, i2, 3, j0, j1, j2, ...].
        cells_npy = np.column_stack(
            [np.full(faces.shape[0], 3, dtype=np.int64), faces.astype(np.int64)]).ravel()
        faces_vtk = vtk.vtkCellArray()
        faces_vtk.SetCells(faces.shape[0], numpy_support.numpy_to_vtkIdTypeArray(cells_npy))
        return faces_vtk

    def get_mesh_part(self, vs: vtk.vtkPoints, faces: Optional[Union[T, ARRAY]]) -> Optional[vtk.vtkPolyData]:
        """Build a polydata from shared points ``vs`` and ``faces``; ``None`` when ``faces`` is None."""
        if faces is None:
            return None
        mesh = vtk.vtkPolyData()
        mesh.SetPoints(vs)
        mesh.SetPolys(self.faces_to_vtk_faces(faces))
        return mesh

    def add_gmm(self) -> List[gaussian_status.GaussianStatus]:
        """Create one GaussianStatus per raw Gaussian, with weights ``softmax(raw_gmm[0])``."""
        if len(self.raw_gmm) == 0:
            return []
        log_phi = self.raw_gmm[0]
        # Subtract the max before exponentiating: same softmax, but no overflow for large logits.
        phi = np.exp(log_phi - np.max(log_phi))
        phi = phi / phi.sum()
        return [gaussian_status.GaussianStatus(raw_gaussian, (self.gmm_id, i), False, self.view_style,
                                               self.render, phi[i])
                for i, raw_gaussian in enumerate(zip(*self.raw_gmm))]

    def add_mesh(self, base_mesh: T_Mesh, split_mesh: bool = True, for_slider: bool = True):
        """Attach ``base_mesh`` to the Gaussians; no-op when ``base_mesh`` is None.

        Args:
            base_mesh: ``(vertices, faces)``; vertices must support ``.numpy()`` and ``.clone()``.
            split_mesh: split faces among Gaussians; otherwise the whole mesh goes to ``self.gmm[0]``.
            for_slider: pose vertices with ``init_mesh_pos`` before display.
        """
        if base_mesh is None:
            return
        self.vs = base_mesh[0]
        vs_ui = self.init_mesh_pos(base_mesh[0]) if for_slider else self.vs
        vs_vtk = vtk.vtkPoints()
        vs_vtk.SetData(numpy_support.numpy_to_vtk(vs_ui.numpy()))
        if split_mesh:
            self.faces = self.split_mesh_by_gmm(base_mesh)
            for i, gaussian in enumerate(self.gmm):
                gaussian.replace_part(self.get_mesh_part(vs_vtk, self.faces[i]))
        else:
            self.gmm[0].replace_part(self.get_mesh_part(vs_vtk, base_mesh[1]))

    def set_brush(self, is_draw: bool):
        self.render.set_brush(is_draw)

    def replace_mesh(self, mesh: Optional[V_Mesh]):
        """Replace the displayed mesh with numpy ``(vertices, faces)``; ignored when ``mesh`` is None."""
        if mesh is None:
            return
        mesh = torch.from_numpy(mesh[0]).float(), torch.from_numpy(mesh[1]).long()
        self.add_mesh(mesh, for_slider=False)

    def init_mesh_pos(self, vs: T):
        """Return a posed copy of ``vs``: rotated, then shifted by ``2 * gmm_id`` along x.

        Uses ``rotation_utils.get_rotation_matrix(150, 1)`` followed by ``(-15, 0)`` (presumably
        (angle, axis) in degrees - see rotation_utils).
        """
        vs = vs.clone()
        r_a = rotation_utils.get_rotation_matrix(150, 1, degree=True)
        r_b = rotation_utils.get_rotation_matrix(-15, 0, degree=True)
        r = torch.from_numpy(np.einsum('km,mn->kn', r_b, r_a)).float()
        vs = torch.einsum('ad,nd->na', r, vs)
        vs[:, 0] += self.gmm_id * 2
        return vs

    @staticmethod
    def mesh_to_polydata(mesh: Union[T_Mesh, V_Mesh], source: Optional[vtk.vtkPolyData] = None) -> vtk.vtkPolyData:
        """Write ``(vertices, faces)`` into ``source`` (a new vtkPolyData when None) and return it."""
        if source is None:
            source = vtk.vtkPolyData()
        vs, faces = mesh
        if torch.is_tensor(vs):
            vs = vs.detach().cpu().numpy()
        vs_vtk = vtk.vtkPoints()
        vs_vtk.SetData(numpy_support.numpy_to_vtk(vs))
        source.SetPoints(vs_vtk)
        source.SetPolys(GmmMeshStage.faces_to_vtk_faces(faces))
        return source

    def move_mesh_to_end(self, cycle: int):
        """Shift the displayed vertices by ``2 * cycle`` along x and add ``cycle`` to ``offset``.

        The vertex array is read from the first Gaussian with input data and written back to every
        Gaussian's points.
        """
        self.offset += cycle
        vs = None
        for gaussian in self.gmm:
            mapper = gaussian.mapper
            if mapper is not None and mapper.GetInput() is not None:
                vs_vtk = mapper.GetInput().GetPoints()
                if vs is None:
                    vs = numpy_support.vtk_to_numpy(vs_vtk.GetData())
                    vs[:, 0] += cycle * 2
                vs_vtk.SetData(numpy_support.numpy_to_vtk(vs))
class GmmStatuses:
    """A list of ``GmmMeshStage`` objects (one per shape path); ``main_gmm`` is the first one."""

    def __init__(self, opt: options.Options, shape_paths: List[List[str]], render,
                 view_styles: List[ui_utils.ViewStyle]):
        """Create one stage per ``(shape_path, view_style)`` pair; stage ``i`` gets render number ``i``."""
        self.gmms = [GmmMeshStage(opt, shape_path, render, i, view_style)
                     for i, (shape_path, view_style) in enumerate(zip(shape_paths, view_styles))]

    def __len__(self):
        return len(self.gmms)

    @property
    def main_gmm(self) -> GmmMeshStage:
        """The first stage (accessed as an attribute, e.g. ``self.main_gmm.turn_gmm_off()``)."""
        return self.gmms[0]

    def switch_arrows(self, arrow_type: ui_utils.Buttons):
        """Forward to the main stage's ``switch_arrows``, if it has one.

        Arrow gizmos are currently disabled and GmmMeshStage defines no ``switch_arrows``, so an
        unconditional call would raise AttributeError.
        """
        switch = getattr(self.main_gmm, 'switch_arrows', None)
        if switch is not None:
            switch(arrow_type)

    def turn_gmm_off(self):
        self.main_gmm.turn_gmm_off()

    def turn_gmm_on(self):
        self.main_gmm.turn_gmm_on()

    def update_gmm(self, button: ui_utils.Buttons, key: str) -> bool:
        return self.main_gmm.update_gmm(button, key)

    def toggle_symmetric(self, force_include: bool = False):
        for gmm in self.gmms:
            gmm.toggle_symmetric(force_include)

    def event_manger(self, object_id: str) -> bool:
        """True as soon as one stage handles ``object_id`` (later stages are not asked)."""
        return any(gmm.event_manger(object_id) for gmm in self.gmms)

    def toggle_inclusion(self, object_id: str) -> bool:
        """Toggle inclusion in the first stage that owns ``object_id``; False if none does."""
        return any(gmm.toggle_inclusion(object_id)[0] for gmm in self.gmms)

    def reset(self):
        for gmm in self.gmms:
            gmm.reset()

    def set_brush(self, is_draw: bool):
        for gmm in self.gmms:
            gmm.set_brush(is_draw)

    def move_mesh_to_end(self, ptr: int):
        """Shift stage ``ptr`` by the number of stages (see ``GmmMeshStage.move_mesh_to_end``)."""
        self.gmms[ptr].move_mesh_to_end(len(self))

    def pick(self, actor_address: str) -> Optional[GmmMeshStage]:
        """Return the stage owning ``actor_address``, or None."""
        return next((gmm for gmm in self.gmms if gmm.pick(actor_address)), None)
def to_local(func):
    """Decorator converting a pixel ``mouse_pos`` into centred, aspect-corrected screen coordinates.

    The wrapped method's instance must have a ``render`` exposing
    ``GetRenderWindow().GetScreenSize()`` and ``GetAspect()``. A non-None ``mouse_pos`` ``(x, y)``
    becomes ``torch.tensor([(x / w - .5) / aspect[1], (y / h - .5) / aspect[0]])``; ``None`` is
    passed through unchanged. ``functools.wraps`` keeps the wrapped method's name and docstring.
    """
    import functools

    @functools.wraps(func)
    def inner(self: MeshGmmStatuses.TransitionController, mouse_pos: Optional[Tuple[int, int]], *args, **kwargs):
        if mouse_pos is not None:
            size = self.render.GetRenderWindow().GetScreenSize()
            aspect = self.render.GetAspect()
            # Normalise to [-0.5, 0.5] around the screen centre, then correct for the aspect ratio.
            mouse_pos = float(mouse_pos[0]) / size[0] - .5, float(mouse_pos[1]) / size[1] - .5
            mouse_pos = torch.tensor([mouse_pos[0] / aspect[1], mouse_pos[1] / aspect[0]])
        return func(self, mouse_pos, *args, **kwargs)

    return inner
class MeshGmmStatuses(GmmStatuses):
    """GmmStatuses with canvas voting, an optional inference process, and interactive
    translation/rotation of the selected Gaussians of the main stage."""

    class TransitionController:
        """Turns mouse motion into a ``ui_utils.Transition`` using the active camera's view matrix.

        Mouse positions passed to ``init_transition``/``get_transition`` are raw pixel tuples; the
        ``to_local`` decorator converts them to centred tensors, as the arithmetic below requires.
        """

        def __init__(self, render: ui_utils.CanvasRender):
            self.render = render
            self.transition_origin = torch.zeros(3)
            self.transition_origin_2d = torch.zeros(2)
            self.origin_mouse, self.dir_2d = torch.zeros(2), torch.zeros(2, 3)
            self.edit_direction = ui_utils.EditDirection.X_Axis
            self.transition_type = ui_utils.EditType.Translating

        @property
        def camera(self):
            return self.render.GetActiveCamera()

        @property
        def moving_axis(self) -> int:
            """3-D axis index edited for the current direction (Y_Axis -> 2 and Z_Axis -> 1)."""
            return {ui_utils.EditDirection.X_Axis: 0,
                    ui_utils.EditDirection.Y_Axis: 2,
                    ui_utils.EditDirection.Z_Axis: 1}[self.edit_direction]

        def get_delta_translation(self, mouse_pos: T) -> ARRAY:
            """Project the mouse displacement onto the screen direction of the moving axis."""
            axis = self.moving_axis
            delta_3d = np.zeros(3)
            displacement = mouse_pos - self.origin_mouse
            delta_3d[axis] = torch.einsum('d,d', displacement, self.dir_2d[:, axis]).item()
            return delta_3d

        def get_delta_rotation(self, mouse_pos: T) -> ARRAY:
            """Signed angle, about the moving axis, between the start and current mouse directions."""
            axis = self.moving_axis
            projections = []
            for pos in (self.origin_mouse, mouse_pos):
                projection = torch.einsum('d,da->a', pos - self.transition_origin_2d, self.dir_2d)
                projection[axis] = 0
                projections.append(nnf.normalize(projection, p=2, dim=0))
            start, end = projections
            i, j = (axis + 1) % 3, (axis + 2) % 3
            sign = (start[j] * end[i] - start[i] * end[j]).sign()
            # Rounding can push the dot product of unit vectors slightly past +-1 -> acos NaN.
            cos_angle = torch.einsum('d,d', start, end).clamp(-1., 1.)
            angle = (torch.acos(cos_angle) * sign).item()
            return ui_utils.get_rotation_matrix(angle, axis)

        def get_delta_scaling(self, mouse_pos: T) -> ARRAY:
            raise NotImplementedError

        def toggle_edit_direction(self, direction: ui_utils.EditDirection):
            self.edit_direction = direction

        @to_local
        def get_transition(self, mouse_pos: Optional[Tuple[int, int]]) -> ui_utils.Transition:
            """Build the transition for the current (pixel) mouse position; identity when None."""
            transition = ui_utils.Transition(self.transition_origin.numpy(), self.transition_type)
            if mouse_pos is not None:
                if self.transition_type is ui_utils.EditType.Translating:
                    transition.translation = self.get_delta_translation(mouse_pos)
                elif self.transition_type is ui_utils.EditType.Rotating:
                    transition.rotation = self.get_delta_rotation(mouse_pos)
                elif self.transition_type is ui_utils.EditType.Scaling:
                    # get_delta_scaling is not implemented and raises.
                    transition.rotation = self.get_delta_scaling(mouse_pos)
            return transition

        @to_local
        def init_transition(self, mouse_pos: Optional[Tuple[int, int]], transition_origin: T,
                            transition_type: ui_utils.EditType):
            """Record the start mouse position, the pivot and the camera's screen-space axes."""
            view_matrix = self.camera.GetViewTransformMatrix()
            # First three rows of the 4x4 view transform.
            dir_2d = torch.tensor([[view_matrix.GetElement(i, j) for j in range(4)] for i in range(3)],
                                  dtype=torch.float32)
            self.transition_origin = transition_origin
            origin_h = torch.tensor(transition_origin.tolist() + [1])
            origin_view = torch.einsum('ab,b->a', dir_2d, origin_h)
            self.transition_origin_2d = origin_view[:2] / origin_view[-1].abs()
            self.origin_mouse = mouse_pos
            self.dir_2d = nnf.normalize(dir_2d[:2, :3], p=2, dim=1)
            self.transition_type = transition_type

    def __init__(self, opt: options.Options, shape_paths: List[List[str]], render,
                 view_styles: List[ui_utils.ViewStyle], with_model: bool):
        super().__init__(opt, shape_paths, render, view_styles)
        if with_model:
            self.model_process = inference_processing.InferenceProcess(opt, self.main_stage.replace_mesh,
                                                                       self.request_gmm, shape_paths)
        else:
            self.model_process = None
        self.counter = 0
        self.cur_canvas = 0
        self.transition_controller = MeshGmmStatuses.TransitionController(self.main_stage.render)

    @property
    def main_stage(self) -> GmmMeshStage:
        return self.gmms[0]

    def stages(self):
        # NOTE(review): left as a plain method; no caller in this file shows attribute-style use.
        return self.gmms

    @property
    def selected_gaussians(self) -> Iterable[gaussian_status.GaussianStatus]:
        """Lazy iterable of the main stage's selected Gaussians."""
        return filter(lambda x: x.is_selected, self.main_stage.gmm)

    def aggregate_votes(self, select: bool) -> bool:
        """Set the selection of every voted, non-disabled Gaussian of the current canvas to ``select``.

        Returns True when at least one Gaussian changed; False when the canvas index is out of range.
        """
        if self.cur_canvas >= len(self.gmms):
            return False
        stage = self.gmms[self.cur_canvas]
        changed = [item for item in stage.aggregate_votes()
                   if not stage.gmm[item].disabled and stage.gmm[item].is_selected != select]
        for item in changed:
            stage.gmm[item].toggle_selection()
        return len(changed) > 0

    def vote(self, *actors: Optional[vtk.vtkActor]):
        self.gmms[self.cur_canvas].vote(*actors)

    def init_draw(self, side: int):
        self.cur_canvas = side

    def sort_gmms(self, gmms, included):
        """Reorder every GMM tensor along dim 2 by the ``(stage, index)`` pairs in ``included``.

        Lexicographic ordering replaces the former ``stage * 100 + index`` key, which mis-sorted
        stages holding 100 or more Gaussians.
        """
        num_gaussians = gmms[0].shape[2]
        order = sorted(range(num_gaussians), key=lambda x: (int(included[x][0]), int(included[x][1])))
        return [item[:, :, order] for item in gmms]

    def save_light(self, root, gmms):
        """Pickle the sorted GMM plus, per stage with included Gaussians, their indices."""
        gmms = self.sort_gmms(*gmms)
        save_dict = {'ids': {gmm.shape_id: [gaussian.gaussian_id[1] for gaussian in gmm.gmm if gaussian.included]
                             for gmm in self.gmms if any(gaussian.included for gaussian in gmm.gmm)},
                     'gmm': gmms}
        path = f"{root}/{files_utils.get_time_name('light')}"
        files_utils.save_pickle(save_dict, path)

    def save(self, root: str, gmms):
        """Save ``gmms`` (as returned by ``request_gmm``) via ``save_light``."""
        # NOTE(review): gmms[0] is the (mu, p, phi, eigen) tuple, so this length check is always true.
        if len(gmms[0]) > 0:
            self.save_light(root, gmms)

    def set_brush(self, is_draw: bool):
        super().set_brush(is_draw)
        self.main_gmm.render.set_brush(is_draw)

    def update_mesh(self, res=128) -> bool:
        """Ask the inference process for a mesh at resolution ``res``; False without a model."""
        if self.model_process is not None:
            self.model_process.get_mesh(res)
            return True
        return False

    def request_gmm(self) -> Tuple[TS, T]:
        return self.main_gmm.get_gmm()

    def replace_mesh(self):
        if self.model_process is not None:
            self.model_process.replace_mesh()

    def exit(self):
        if self.model_process is not None:
            self.model_process.exit()

    def temporary_transition(self, mouse_pos: Optional[Tuple[int, int]] = None, end=False) -> bool:
        """Apply the current transition to every selected Gaussian (finalising it when ``end``).

        Every selected Gaussian is updated; returns True if any of them reported a change.
        """
        transition = self.transition_controller.get_transition(mouse_pos)
        is_change = False
        for gaussian in self.selected_gaussians:
            if end:
                is_change = gaussian.end_transition(transition) or is_change
            else:
                is_change = gaussian.temporary_transition(transition) or is_change
        return is_change

    def end_transition(self, mouse_pos: Optional[Tuple[int, int]]) -> bool:
        return self.temporary_transition(mouse_pos, True)

    def init_transition(self, mouse_pos, transition_type: ui_utils.EditType):
        """Start a transition when something is selected; the pivot is fixed at the origin."""
        centers = [gaussian.mu_baked for gaussian in self.selected_gaussians]
        if not centers:
            return
        # The mean of the selected centres is intentionally not used as the pivot.
        center = torch.zeros(3)
        self.transition_controller.init_transition(mouse_pos, center, transition_type)

    def toggle_edit_direction(self, direction: ui_utils.EditDirection):
        self.transition_controller.toggle_edit_direction(direction)

    def clear_selection(self) -> bool:
        """Deselect every selected Gaussian of the main stage; True if any was selected."""
        is_changed = False
        for gaussian in self.selected_gaussians:
            gaussian.toggle_selection()
            is_changed = True
        return is_changed
class MeshGmmUnited(MeshGmmStatuses):
    """Side stages plus one combined main stage.

    Including a Gaussian on a side stage adds a copy to the main stage; excluding it removes that
    copy. ``stage_mapper`` maps a side-stage address to the address of its copy in the main stage.
    """

    def __init__(self, opt: options.Options, gmm_paths: List[List[str]], renders_right,
                 view_styles: List[ui_utils.ViewStyle], main_render: ui_utils.CanvasRender, with_model: bool):
        # Must exist before super().__init__, which reads self.main_stage (overridden below).
        # Shape id '-1' makes GmmMeshStage start empty; passing the bare int -1 raised TypeError
        # because GmmMeshStage indexes shape_path[1].
        self.main_gmm_ = GmmMeshStage(opt, ['', '-1'], main_render, len(gmm_paths), view_styles[-1],
                                      to_init=False)
        super().__init__(opt, gmm_paths, renders_right, view_styles[:-1], with_model)
        self.main_render = main_render
        self.reset()
        self.stage_mapper: Dict[str, str] = {}

    @property
    def main_gmm(self) -> GmmMeshStage:
        return self.main_gmm_

    @property
    def main_stage(self) -> GmmMeshStage:
        return self.main_gmm_

    def save(self, root: str, gmms: Optional[Tuple[TS, T]] = None):
        """Save the light GMM file and the main stage's mesh (faces flagged by selection).

        The parent ``save`` requires ``gmms``; the former one-argument call always raised TypeError.
        When ``gmms`` is omitted it is requested from the main stage, and the light file is skipped
        if no Gaussian is included (``get_gmm`` cannot stack an empty set).
        """
        if gmms is None and any(gaussian.included for gaussian in self.main_gmm.gmm):
            gmms = self.request_gmm()
        if gmms is not None:
            super().save(root, gmms)
        self.main_gmm.save(root, filter_by_selection)

    def aggregate_votes(self, select: bool) -> bool:
        """Apply the pending votes of the current canvas.

        On a side stage, the inclusion of each voted, non-disabled Gaussian is set to ``select`` and
        the main stage is updated accordingly; returns True if any Gaussian changed. On the main
        canvas (index beyond the side stages) the selection is updated instead.
        """
        if self.cur_canvas >= len(self.gmms):
            return self.update_selection(select)
        stage = self.gmms[self.cur_canvas]
        changed = [item for item in stage.aggregate_votes()
                   if not stage.gmm[item].disabled and stage.gmm[item].included != select]
        for item in changed:
            is_toggled, toggled = stage.toggle_inclusion_by_id(item, select)
            if not is_toggled:
                continue
            if toggled[0].included:
                self._add_to_main(toggled)
            else:
                self._remove_from_main(toggled)
        return len(changed) > 0

    def _add_to_main(self, toggled: List[gaussian_status.GaussianStatus]):
        """Copy newly included side-stage Gaussians into the main stage and record the mapping."""
        new_addresses = self.main_gmm.add_gaussians(toggled)
        for gaussian, new_address in zip(toggled, new_addresses):
            self.stage_mapper[gaussian.get_address()] = new_address
        self.make_twins(toggled, new_addresses)

    def _remove_from_main(self, toggled: List[gaussian_status.GaussianStatus]):
        """Delete the main-stage copies of newly excluded side-stage Gaussians."""
        addresses = [gaussian.get_address() for gaussian in toggled]
        addresses = [address for address in addresses if address in self.stage_mapper]
        self.main_gmm.remove_gaussians([self.stage_mapper[address] for address in addresses])
        for address in addresses:
            del self.stage_mapper[address]

    def update_selection(self, select: bool) -> bool:
        """Set the selection of voted main-stage Gaussians to ``select``; always returns False."""
        for item in self.main_stage.aggregate_votes():
            if self.main_stage.gmm[item].is_selected != select:
                self.main_stage.gmm[item].toggle_selection()
        return False

    def vote(self, *actors: Optional[vtk.vtkActor]):
        if self.cur_canvas < len(self.gmms):
            self.gmms[self.cur_canvas].vote(*actors)
        else:
            self.main_gmm.vote(*actors)

    def reset(self):
        """Reset side stages, empty the main stage, re-toggle side inclusion and clear the mapping."""
        super().reset()
        self.main_gmm.remove_all()
        for gmm in self.gmms:
            gmm.toggle_all()
        self.stage_mapper = {}

    def event_manger(self, object_id: str) -> bool:
        return self.toggle_inclusion(object_id) or self.main_gmm.event_manger(object_id)

    def make_twins(self, toggled: List[gaussian_status.GaussianStatus], new_addresses: List[str]):
        """Mirror side-stage twin links onto the main-stage copies."""
        if len(new_addresses) == 2:
            self.main_gmm.make_twins(*new_addresses)
        else:
            twin = toggled[0].twin
            if twin is not None and twin.get_address() in self.stage_mapper:
                self.main_gmm.make_twins(new_addresses[0], self.stage_mapper[twin.get_address()])

    def toggle_symmetric(self, force_include: bool = False):
        super().toggle_symmetric(force_include)
        self.main_gmm.toggle_symmetric(force_include)
def main():
    """Offline helper: colour sample meshes by the model's attention and export them.

    For each hard-coded sample, loads its mesh (and GMM), queries ``model.get_attention`` on the
    mesh vertices, takes the argmax of the averaged last four attention outputs as a label, and
    exports the mesh coloured by a random per-label colour to ``constants.OUT_ROOT``.
    """
    # Imported here so main() also works when this module is imported rather than run as a script.
    from utils import train_utils
    opt = options.Options(tag="chairs_sym_hard").load()
    model = train_utils.model_lc(opt)[0]
    model = model.to(CPU)
    colors = torch.rand(opt.num_gaussians, 3)
    shape_nums = 1103, 1637, 2954, 3631, 4814
    for shape_num in shape_nums:
        mesh = files_utils.load_mesh(f"{opt.cp_folder}/occ/samples_{shape_num}")
        gmm = files_utils.load_gmm(f"{opt.cp_folder}/gmms/samples_{shape_num}")
        vs, faces = mesh
        # The unpacked GMM is only needed by the likelihood-based alternative to attention.
        phi, mu, eigen, p, _ = [item.unsqueeze(0).unsqueeze(0) for item in gmm]
        gmm = mu, p, phi, eigen
        attention = model.get_attention(vs.unsqueeze(0), torch.tensor([shape_num], dtype=torch.int64))[-4:]
        supports = torch.cat(attention, dim=0)
        supports = supports.mean(-1).mean(0)
        label = supports.argmax(1)
        colors_ = colors[label]
        files_utils.export_mesh((vs, faces), f"{constants.OUT_ROOT}/{opt.tag}_{shape_num}b", colors=colors_)
    return 0
if __name__ == '__main__':
    # Script entry point: make train_utils available at module scope for main(), then run it.
    from utils import train_utils
    status = main()
    exit(status)