# salad-demo / salad / spaghetti / ui / ui_controllers.py
# Uploaded by DveloperY0115 ("init repo", commit 801501a).
from __future__ import annotations
import constants
import vtk
import vtk.util.numpy_support as numpy_support
from custom_types import *
from utils import files_utils, rotation_utils
from models import gm_utils
from ui import ui_utils, inference_processing, gaussian_status
import options
def filter_by_inclusion(gaussian: gaussian_status.GaussianStatus) -> bool:
    """Predicate that keeps a Gaussian exactly when its ``included`` flag is set."""
    is_included = gaussian.included
    return is_included
def filter_by_selection(gaussian: gaussian_status.GaussianStatus) -> bool:
    """Predicate that keeps a Gaussian exactly when its ``is_selected`` flag is set."""
    is_selected = gaussian.is_selected
    return is_selected
class GmmMeshStage:
    """A set of Gaussians (one GMM) shown on a single canvas, plus the mesh parts they own.

    Instance state:
      * ``self.gmm`` -- list of ``GaussianStatus`` objects.
      * ``self.addresses_dict`` -- VTK actor address string -> index into ``self.gmm``.
      * ``self.selected`` -- address of the currently selected Gaussian, or None.
      * ``self.votes`` -- address -> number of brush hits collected by ``vote``.
      * ``self.vs`` / ``self.faces`` -- mesh vertices and the per-Gaussian face split.
    """

    def turn_off_selected(self):
        """Deselect the currently selected Gaussian, if there is one."""
        if self.selected is not None:
            self.toggle_selection(self.selected)
            self.selected = None

    def turn_gmm_off(self):
        """Clear the selection, then call ``turn_off`` on every Gaussian."""
        self.turn_off_selected()
        for gaussian in self.gmm:
            gaussian.turn_off()

    def turn_gmm_on(self):
        """Call ``turn_on`` on every Gaussian."""
        for gaussian in self.gmm:
            gaussian.turn_on()

    def event_manger(self, object_id: str):
        """Handle a pick on ``object_id``.

        Returns True when the event was consumed (a Gaussian of this stage was picked,
        or a transform arrow handled it), False otherwise.
        """
        if object_id in self.addresses_dict:
            return self.toggle_selection(object_id)
        # The arrow manager is not created in __init__ (it is currently disabled), so an
        # unconditional ``self.arrows`` raised AttributeError for every pick that missed
        # this stage. Only consult it when it exists.
        arrows = getattr(self, 'arrows', None)
        if arrows is not None and arrows.check_event(object_id):
            transform = arrows.get_transform(object_id)
            self.update_gmm(*transform)
            return True
        return False

    def toggle_selection(self, object_id: str):
        """Toggle selection of ``object_id`` while keeping at most one Gaussian selected.

        Always returns True.
        """
        self.gmm[self.addresses_dict[object_id]].toggle_selection()
        if self.selected is None:
            self.selected = object_id
        elif self.selected == object_id and self.gmm[self.addresses_dict[object_id]].is_not_selected:
            self.selected = None
        else:
            # Another Gaussian was selected before: deselect it and remember the new one.
            self.gmm[self.addresses_dict[self.selected]].toggle_selection()
            self.selected = object_id
        return True

    def toggle_inclusion_by_id(self, g_id: int, select: Optional[bool] = None) -> Tuple[bool, List[gaussian_status.GaussianStatus]]:
        """Toggle (or set, when ``select`` is given) inclusion of Gaussian ``g_id``.

        In symmetric mode the twin is toggled too when its state differs afterwards.

        Returns:
            ``(True, toggled)`` where ``toggled`` lists every Gaussian whose state was toggled.
        """
        gaussian = self.gmm[g_id]
        gaussian.toggle_inclusion(select)
        toggled = [gaussian]
        if self.symmetric_mode:
            twin = gaussian.twin
            if twin is not None and twin.included != gaussian.included:
                twin.toggle_inclusion(select)
                toggled.append(twin)
        return True, toggled

    def toggle_inclusion(self, object_id: str) -> Tuple[bool, List[gaussian_status.GaussianStatus]]:
        """Toggle inclusion of the Gaussian at actor address ``object_id``.

        Returns ``(False, [])`` when the address does not belong to this stage.
        """
        if object_id in self.addresses_dict:
            return self.toggle_inclusion_by_id(self.addresses_dict[object_id])
        return False, []

    def toggle_all(self):
        """Toggle inclusion of every Gaussian."""
        for gaussian in self.gmm:
            gaussian.toggle_inclusion()

    def __len__(self):
        return len(self.gmm)

    def set_opacity(self, opacity: float):
        """Store ``opacity`` on the shared view style and ask each Gaussian to recolour."""
        self.view_style.opacity = opacity
        for gaussian in self.gmm:
            gaussian.set_color()

    def update_gmm(self, button: ui_utils.Buttons, key: str) -> bool:
        """Apply an affine edit to the selected Gaussian (and re-mirror its twin in symmetric mode).

        Returns True if a Gaussian was selected and edited.
        """
        if self.selected is None:
            return False
        g_id = self.addresses_dict[self.selected]
        self.gmm[g_id].apply_affine(button, key)
        if self.symmetric_mode and self.gmm[g_id].twin is not None:
            self.gmm[g_id].twin.make_symmetric(False)
        return True

    def get_gmm(self) -> Tuple[TS, T]:
        """Pack the included Gaussians into float tensors.

        Returns:
            ``((mu, p, phi, eigen), included)``. mu is viewed as (1, 1, -1, 3), p as
            (1, 1, -1, 3, 3), phi as (1, 1, -1) and eigen as (1, 1, -1, 3). ``included``
            holds the ``gaussian_id`` of each included Gaussian.
        """
        # Raw layout per Gaussian, as used below: (phi, mu, eigen, p).
        raw_gmm = [g.get_raw_data() for g in self.gmm if g.included]
        phi = torch.tensor([g[0] for g in raw_gmm], dtype=torch.float32).view(1, 1, -1)
        mu = torch.stack([torch.from_numpy(g[1]).float() for g in raw_gmm], dim=0).view(1, 1, -1, 3)
        p = torch.stack([torch.from_numpy(g[3]).float() for g in raw_gmm], dim=0).view(1, 1, -1, 3, 3)
        eigen = torch.stack([torch.from_numpy(g[2]).float() for g in raw_gmm], dim=0).view(1, 1, -1, 3)
        gmm = mu, p, phi, eigen
        included = torch.tensor([g.gaussian_id for g in self.gmm if g.included], dtype=torch.int64)
        return gmm, included

    def reset(self):
        """Call ``reset`` on every Gaussian."""
        for g in self.gmm:
            g.reset()

    def remove_all(self):
        """Delete every Gaussian from the renderer and empty the stage."""
        self.remove_gaussians(list(self.addresses_dict.keys()))
        self.addresses_dict = {}
        self.gmm = []

    def toggle_symmetric(self, force_include: bool):
        """Flip symmetric mode (currently forced off by ``and False``) and re-mirror if on."""
        self.symmetric_mode = not self.symmetric_mode and False
        if self.symmetric_mode:
            for gaussian in self.gmm:
                gaussian.make_symmetric(force_include)

    def remove_gaussians(self, addresses: List[str]):
        """Delete the Gaussians at ``addresses`` and rebuild the address -> index map."""
        for address in addresses:
            gaussian_id: int = self.addresses_dict[address]
            gaussian = self.gmm[gaussian_id]
            # Mark the slot instead of popping so the other indices stay valid in this loop.
            self.gmm[gaussian_id] = None
            gaussian.delete(self.render)
            del self.addresses_dict[address]
        self.gmm = [gaussian for gaussian in self.gmm if gaussian is not None]
        self.addresses_dict = {g.get_address(): i for i, g in enumerate(self.gmm)}

    def add_gaussians(self, gaussians: List[gaussian_status.GaussianStatus]) -> List[str]:
        """Append unselected copies of ``gaussians`` (rendered here) and return their addresses."""
        new_addresses = []
        for gaussian in gaussians:
            gaussian_copy = gaussian.copy(self.render, self.view_style, is_selected=False)
            self.gmm.append(gaussian_copy)
            new_addresses.append(gaussian_copy.get_address())
        self.addresses_dict = {g.get_address(): i for i, g in enumerate(self.gmm)}
        return new_addresses

    def make_twins(self, address_a: str, address_b: str):
        """Link the two Gaussians as each other's ``twin`` if both belong to this stage."""
        if address_a in self.addresses_dict and address_b in self.addresses_dict:
            gaussian_a = self.gmm[self.addresses_dict[address_a]]
            gaussian_b = self.gmm[self.addresses_dict[address_b]]
            gaussian_a.twin = gaussian_b
            gaussian_b.twin = gaussian_a

    def split_mesh_by_gmm(self, mesh) -> Dict[int, T]:
        """Split ``mesh`` faces among the included Gaussians.

        Returns a dict from Gaussian index to its face split, or None for disabled Gaussians.
        """
        faces_split = {}
        mu, p, phi, _ = self.get_gmm()[0]
        eigen = torch.stack([torch.from_numpy(g.get_view_eigen()).float() for g in self.gmm if g.included],
                            dim=0).view(1, 1, -1, 3)
        gmm = mu, p, phi, eigen
        faces_split_ = gm_utils.split_mesh_by_gmm(mesh, gmm)
        # NOTE(review): the GMM is packed by ``included`` but unpacked by ``disabled``;
        # this assumes disabled == not included -- confirm in GaussianStatus.
        counter = 0
        for i, gaussian in enumerate(self.gmm):
            if gaussian.disabled:
                faces_split[i] = None
            else:
                faces_split[i] = faces_split_[counter]
                counter += 1
        return faces_split

    @staticmethod
    def get_part_face(mesh: V_Mesh, faces_inds: T) -> Tuple[T_Mesh, T]:
        """Extract the sub-mesh whose face flags are non-zero, with vertices re-indexed compactly."""
        mesh = mesh[0], torch.from_numpy(mesh[1]).long()
        mask = faces_inds.ne(0)
        faces = mesh[1][mask]
        vs_inds = faces.flatten().unique()
        vs = mesh[0][vs_inds]
        # Old vertex index -> new compact index.
        mapper = torch.zeros(mesh[0].shape[0], dtype=torch.int64)
        mapper[vs_inds] = torch.arange(vs.shape[0])
        return (vs, mapper[faces]), faces_inds[mask]

    def save(self, root: str, filter_faces: Callable[[gaussian_status.GaussianStatus], bool] = filter_by_inclusion):
        """Export the split mesh plus a 0/1 per-face flag list under ``root``.

        A face is flagged 1 when ``filter_faces`` accepts the Gaussian owning it.
        Does nothing when no mesh has been split or every part is empty.
        """
        if self.faces is None:
            return
        name = "mix" if self.gmm_id == -1 else str(self.gmm_id)
        path = f"{root}/{files_utils.get_time_name(name)}"
        faces = [item for item in self.faces.items() if item[1] is not None]
        if not faces:
            # np.concatenate / torch.cat raise on empty input.
            return
        mesh = self.vs, np.concatenate([part for _, part in faces])
        faces_inds = torch.cat([
            torch.ones(part.shape[0], dtype=torch.int64) if filter_faces(self.gmm[g_id])
            else torch.zeros(part.shape[0], dtype=torch.int64)
            for g_id, part in faces])
        files_utils.export_mesh(mesh, path)
        files_utils.export_list(faces_inds.tolist(), f"{path}_faces")

    def aggregate_symmetric(self) -> Dict[str, int]:
        """Return the votes, extended to unvoted twins when symmetric mode is on."""
        if not self.symmetric_mode:
            return self.votes
        out = {}
        for item in self.votes:
            twin = self.gmm[self.addresses_dict[item]].twin
            out[item] = self.votes[item]
            if twin is not None and twin.get_address() not in self.votes:
                out[twin.get_address()] = self.votes[item]
        return out

    def aggregate_votes(self) -> List[int]:
        """Return the Gaussian indices that received votes and clear the votes."""
        actors_id = [self.addresses_dict[item] for item in self.votes]
        self.votes = {}
        return actors_id

    def vote(self, *actors: Optional[vtk.vtkActor]):
        """Count one vote for every non-None actor that belongs to this stage."""
        for actor in actors:
            if actor is not None:
                address = actor.GetAddressAsString('')
                if address in self.addresses_dict:
                    self.votes[address] = self.votes.get(address, 0) + 1

    @staticmethod
    def faces_to_vtk_faces(faces: Union[T, ARRAY]):
        """Convert an (F, 3) triangle index array/tensor to a ``vtkCellArray``."""
        if isinstance(faces, T):
            faces = faces.detach().cpu().numpy()
        # VTK legacy cell layout: [3, i, j, k, 3, i, j, k, ...]
        cells_npy = np.column_stack(
            [np.full(faces.shape[0], 3, dtype=np.int64), faces.astype(np.int64)]).ravel()
        faces_vtk = vtk.vtkCellArray()
        faces_vtk.SetCells(faces.shape[0], numpy_support.numpy_to_vtkIdTypeArray(cells_npy))
        return faces_vtk

    def get_mesh_part(self, vs: vtk.vtkPoints, faces: Optional[Union[T, ARRAY]]) -> Optional[vtk.vtkPolyData]:
        """Build a polydata from the shared points ``vs`` and ``faces``; None if ``faces`` is None."""
        if faces is None:
            return None
        mesh = vtk.vtkPolyData()
        mesh.SetPoints(vs)
        mesh.SetPolys(self.faces_to_vtk_faces(faces))
        return mesh

    def add_gmm(self) -> List[gaussian_status.GaussianStatus]:
        """Create one ``GaussianStatus`` per entry of ``self.raw_gmm``.

        The first raw component is softmax-normalised into mixture weights.
        """
        gmms = []
        if len(self.raw_gmm) > 0:
            # Softmax of the raw log-weights; subtracting the max first prevents overflow
            # in np.exp without changing the normalised result.
            phi = np.asarray(self.raw_gmm[0])
            phi = np.exp(phi - phi.max())
            phi = phi / phi.sum()
            for i, gaussian in enumerate(zip(*self.raw_gmm)):
                gaussian = gaussian_status.GaussianStatus(gaussian, (self.gmm_id, i), False, self.view_style,
                                                          self.render, phi[i])
                gmms.append(gaussian)
        return gmms

    def add_mesh(self, base_mesh: T_Mesh, split_mesh: bool = True, for_slider: bool = True):
        """Attach ``base_mesh`` to the Gaussians.

        All parts share one ``vtkPoints`` object. With ``split_mesh`` the faces are split
        per Gaussian; otherwise the whole mesh goes to the first Gaussian.
        """
        if base_mesh is None:
            return
        vs_vtk = vtk.vtkPoints()
        self.vs = base_mesh[0]
        vs_ui = self.init_mesh_pos(base_mesh[0]) if for_slider else self.vs
        vs_vtk.SetData(numpy_support.numpy_to_vtk(vs_ui.numpy()))
        if split_mesh:
            self.faces = self.split_mesh_by_gmm(base_mesh)
            for i, gaussian in enumerate(self.gmm):
                gaussian.replace_part(self.get_mesh_part(vs_vtk, self.faces[i]))
        else:
            self.gmm[0].replace_part(self.get_mesh_part(vs_vtk, base_mesh[1]))

    def set_brush(self, is_draw: bool):
        """Forward the brush mode to this stage's renderer."""
        self.render.set_brush(is_draw)

    def replace_mesh(self, mesh: Optional[V_Mesh]):
        """Replace the displayed mesh with ``mesh`` (numpy vertices and faces); no-op for None."""
        if mesh is None:
            return
        mesh = torch.from_numpy(mesh[0]).float(), torch.from_numpy(mesh[1]).long()
        self.add_mesh(mesh, for_slider=False)

    def init_mesh_pos(self, vs: T):
        """Rotate a copy of ``vs`` into the slider view and offset it along x by ``2 * gmm_id``."""
        vs = vs.clone()
        r_a = rotation_utils.get_rotation_matrix(150, 1, degree=True)
        r_b = rotation_utils.get_rotation_matrix(-15, 0, degree=True)
        r = torch.from_numpy(np.einsum('km,mn->kn', r_b, r_a)).float()
        vs = torch.einsum('ad,nd->na', r, vs)
        vs[:, 0] += self.gmm_id * 2
        return vs

    @staticmethod
    def mesh_to_polydata(mesh: Union[T_Mesh, V_Mesh], source: Optional[vtk.vtkPolyData] = None) -> vtk.vtkPolyData:
        """Write ``mesh`` (vertices, faces) into ``source`` (a new polydata if None) and return it."""
        if source is None:
            source = vtk.vtkPolyData()
        vs, faces = mesh
        if isinstance(vs, T):
            vs = vs.detach().cpu().numpy()
        vs_vtk = vtk.vtkPoints()
        vs_vtk.SetData(numpy_support.numpy_to_vtk(vs))
        source.SetPoints(vs_vtk)
        source.SetPolys(GmmMeshStage.faces_to_vtk_faces(faces))
        return source

    @property
    def included(self):
        """True if at least one Gaussian is included."""
        return any(g.included for g in self.gmm)

    def move_mesh_to_end(self, cycle: int):
        """Shift the displayed mesh by ``2 * cycle`` along x and add ``cycle`` to ``self.offset``.

        Parts created by ``add_mesh`` share one ``vtkPoints``; each distinct points
        object is shifted exactly once (the shift used to be re-applied per part).
        """
        self.offset += cycle
        shifted = set()
        for gaussian in self.gmm:
            mapper = gaussian.mapper
            if mapper is None or mapper.GetInput() is None:
                continue
            vs_vtk = mapper.GetInput().GetPoints()
            if vs_vtk is None:
                continue
            key = vs_vtk.GetAddressAsString('')
            if key in shifted:
                continue
            shifted.add(key)
            vs = numpy_support.vtk_to_numpy(vs_vtk.GetData())
            vs[:, 0] += cycle * 2
            vs_vtk.SetData(numpy_support.numpy_to_vtk(vs))

    def pick(self, actor_address: str) -> bool:
        """True if ``actor_address`` belongs to one of this stage's Gaussians."""
        return actor_address in self.addresses_dict

    def __init__(self, opt: options.Options, shape_path: List[str], render: ui_utils.CanvasRender, render_number: int,
                 view_style: ui_utils.ViewStyle, to_init=True):
        """Load the shape named by ``shape_path`` (shape id ``shape_path[1]``; '-1' means empty stage)."""
        self.view_style = view_style
        self.votes = {}
        self.shape_id = shape_path[1]
        self.gmm_id = render_number
        self.render = render
        # Symmetric mode is deliberately disabled by ``and False``.
        self.symmetric_mode = sum(opt.symmetric) > 0 and False
        self.selected = None
        self.offset = render_number
        if self.shape_id != '-1':
            self.base_mesh = files_utils.load_mesh(''.join(shape_path))
            self.raw_gmm = files_utils.load_gmm(f'{shape_path[0]}/{shape_path[1]}.txt', as_np=True)[:-1]
        else:
            self.base_mesh = None
            self.raw_gmm = []
        self.to_init = to_init
        self.is_changed = False
        self.gmm: List[gaussian_status.GaussianStatus] = self.add_gmm()
        self.vs = self.faces = None
        self.add_mesh(self.base_mesh)
        self.addresses_dict: Dict[str, int] = {g.get_address(): i for i, g in enumerate(self.gmm)}
        if self.symmetric_mode:
            half = len(self) // 2
            for i in range(half):
                self.make_twins(self.gmm[i].get_address(), self.gmm[i + half].get_address())
        self.toggle_all()
class GmmStatuses:
    """A collection of ``GmmMeshStage`` objects, one per canvas; ``gmms[0]`` is the main stage."""

    def __len__(self):
        return len(self.gmms)

    def switch_arrows(self, arrow_type: ui_utils.Buttons):
        """Forward an arrow-type switch to the main stage, if it supports arrows.

        GmmMeshStage does not currently define ``switch_arrows``, so an unconditional
        call raised AttributeError.
        """
        switch = getattr(self.main_gmm, 'switch_arrows', None)
        if switch is not None:
            switch(arrow_type)

    def turn_gmm_off(self):
        self.main_gmm.turn_gmm_off()

    def turn_gmm_on(self):
        self.main_gmm.turn_gmm_on()

    def update_gmm(self, button: ui_utils.Buttons, key: str) -> bool:
        """Apply an affine edit to the main stage's selected Gaussian."""
        return self.main_gmm.update_gmm(button, key)

    def toggle_symmetric(self, force_include: bool = False):
        for gmm in self.gmms:
            gmm.toggle_symmetric(force_include)

    def event_manger(self, object_id: str):
        """Offer the event to each stage in order; True once one consumes it."""
        return any(gmm.event_manger(object_id) for gmm in self.gmms)

    def toggle_inclusion(self, object_id: str):
        """Toggle inclusion in the first stage that owns ``object_id``; True if found."""
        return any(gmm.toggle_inclusion(object_id)[0] for gmm in self.gmms)

    @property
    def main_gmm(self) -> GmmMeshStage:
        return self.gmms[0]

    def reset(self):
        for gmm in self.gmms:
            gmm.reset()

    def set_brush(self, is_draw: bool):
        for gmm in self.gmms:
            gmm.set_brush(is_draw)

    def move_mesh_to_end(self, ptr: int):
        """Shift stage ``ptr``'s mesh by one full cycle of stages."""
        self.gmms[ptr].move_mesh_to_end(len(self))

    def pick(self, actor_address: str) -> Optional[GmmMeshStage]:
        """Return the first stage owning ``actor_address``, or None."""
        return next((gmm for gmm in self.gmms if gmm.pick(actor_address)), None)

    def __init__(self, opt: options.Options, shape_paths: List[List[str]], render, view_styles: List[ui_utils.ViewStyle]):
        # One stage per (shape path, view style) pair; the stage index doubles as render number.
        self.gmms = [GmmMeshStage(opt, shape_path, render, i, view_style)
                     for i, (shape_path, view_style) in enumerate(zip(shape_paths, view_styles))]
def to_local(func):
    """Decorator: convert a pixel mouse position into centred, aspect-corrected coordinates.

    The wrapped method receives ``mouse_pos`` as a 2-element tensor (or None unchanged):
    each pixel coordinate is divided by the screen size and shifted by -0.5, then x is
    divided by ``aspect[1]`` and y by ``aspect[0]`` from ``self.render.GetAspect()``.
    """
    import functools

    @functools.wraps(func)
    def inner(self: "MeshGmmStatuses.TransitionController", mouse_pos: Optional[Tuple[int, int]], *args, **kwargs):
        if mouse_pos is not None:
            size = self.render.GetRenderWindow().GetScreenSize()
            aspect = self.render.GetAspect()
            mouse_pos = float(mouse_pos[0]) / size[0] - .5, float(mouse_pos[1]) / size[1] - .5
            mouse_pos = torch.tensor([mouse_pos[0] / aspect[1], mouse_pos[1] / aspect[0]])
        return func(self, mouse_pos, *args, **kwargs)
    return inner
class MeshGmmStatuses(GmmStatuses):
    """``GmmStatuses`` plus brush voting, saving, an optional inference process and
    mouse-driven transitions of the selected Gaussians."""

    def aggregate_votes(self, select: bool):
        """Set the selection of every voted, non-disabled Gaussian on the current canvas to ``select``.

        Returns True if any selection changed, or None when ``cur_canvas`` is out of range.
        """
        if self.cur_canvas < len(self.gmms):
            stage = self.gmms[self.cur_canvas]
            changed = [item for item in stage.aggregate_votes()
                       if not stage.gmm[item].disabled and stage.gmm[item].is_selected != select]
            for item in changed:
                stage.gmm[item].toggle_selection()
            return len(changed) > 0

    def vote(self, *actors: Optional[vtk.vtkActor]):
        self.gmms[self.cur_canvas].vote(*actors)

    def init_draw(self, side: int):
        self.cur_canvas = side

    def sort_gmms(self, gmms, included):
        """Reorder dim 2 (the Gaussian axis) of every tensor in ``gmms``.

        Gaussians are sorted lexicographically by their id pair
        ``(included[x][0], included[x][1])``. The previous key ``id0 * 100 + id1``
        mis-ordered stages holding 100 or more Gaussians.
        """
        keys = [tuple(int(v) for v in row) for row in included]
        order = sorted(range(gmms[0].shape[2]), key=lambda x: keys[x])
        return [item[:, :, order] for item in gmms]

    def save_light(self, root, gmms):
        """Pickle the sorted GMM together with the included Gaussian ids of each included stage."""
        gmms = self.sort_gmms(*gmms)
        save_dict = {'ids': {
            gmm.shape_id: [gaussian.gaussian_id[1] for gaussian in gmm.gmm if gaussian.included]
            for gmm in self.gmms if gmm.included},
            'gmm': gmms}
        path = f"{root}/{files_utils.get_time_name('light')}"
        files_utils.save_pickle(save_dict, path)

    def save(self, root: str, gmms):
        """Save ``gmms`` (``(gmm_tuple, included)``) via ``save_light`` when ``gmms[0]`` is non-empty."""
        if len(gmms[0]) > 0:
            self.save_light(root, gmms)

    def set_brush(self, is_draw: bool):
        super(MeshGmmStatuses, self).set_brush(is_draw)
        self.main_gmm.render.set_brush(is_draw)

    def update_mesh(self, res=128):
        """Ask the inference process for a mesh at resolution ``res``; False without a model."""
        if self.model_process is not None:
            self.model_process.get_mesh(res)
            return True
        return False

    def request_gmm(self) -> Tuple[TS, T]:
        """Return the main GMM's packed tensors and included ids (used by the inference process)."""
        gmm, included = self.main_gmm.get_gmm()
        return gmm, included

    def replace_mesh(self):
        if self.model_process is not None:
            self.model_process.replace_mesh()

    def exit(self):
        if self.model_process is not None:
            self.model_process.exit()

    @property
    def main_stage(self) -> GmmMeshStage:
        return self.gmms[0]

    @property
    def stages(self):
        return self.gmms

    class TransitionController:
        """Converts mouse drags into translations/rotations along one edit axis."""

        @property
        def moving_axis(self) -> int:
            """Index of the world axis edited for the current ``edit_direction`` (Y and Z are swapped)."""
            return {ui_utils.EditDirection.X_Axis: 0,
                    ui_utils.EditDirection.Y_Axis: 2,
                    ui_utils.EditDirection.Z_Axis: 1}[self.edit_direction]

        def get_delta_translation(self, mouse_pos: T) -> ARRAY:
            """Project the mouse displacement onto the moving axis' screen direction."""
            delta_3d = np.zeros(3)
            axis = self.moving_axis
            vec = mouse_pos - self.origin_mouse
            delta_3d[axis] = torch.einsum('d,d', vec, self.dir_2d[:, axis])
            return delta_3d

        def get_delta_rotation(self, mouse_pos: T) -> ARRAY:
            """Rotation about the moving axis swept by the mouse around the transition origin."""
            axis = self.moving_axis
            projections = []
            for pos in (self.origin_mouse, mouse_pos):
                vec = pos - self.transition_origin_2d
                projection = torch.einsum('d,da->a', vec, self.dir_2d)
                projection[axis] = 0
                projections.append(nnf.normalize(projection, p=2, dim=0))
            start, end = projections
            a, b = (axis + 1) % 3, (axis + 2) % 3
            sign = (start[b] * end[a] - start[a] * end[b]).sign()
            # Rounding can push the dot product of two unit vectors just outside [-1, 1],
            # where acos returns NaN; clamp first.
            cos_angle = torch.einsum('d,d', start, end).clamp(-1., 1.)
            angle = (torch.acos(cos_angle) * sign).item()
            return ui_utils.get_rotation_matrix(angle, axis)

        def get_delta_scaling(self, mouse_pos: T) -> ARRAY:
            raise NotImplementedError

        def toggle_edit_direction(self, direction: ui_utils.EditDirection):
            self.edit_direction = direction

        @to_local
        def get_transition(self, mouse_pos: Optional[T]) -> ui_utils.Transition:
            """Build a ``Transition`` for the current drag; no delta is set when ``mouse_pos`` is None."""
            transition = ui_utils.Transition(self.transition_origin.numpy(), self.transition_type)
            if mouse_pos is not None:
                if self.transition_type is ui_utils.EditType.Translating:
                    transition.translation = self.get_delta_translation(mouse_pos)
                elif self.transition_type is ui_utils.EditType.Rotating:
                    transition.rotation = self.get_delta_rotation(mouse_pos)
                elif self.transition_type is ui_utils.EditType.Scaling:
                    # get_delta_scaling is not implemented and raises before this assignment.
                    transition.rotation = self.get_delta_scaling(mouse_pos)
            return transition

        @to_local
        def init_transition(self, mouse_pos: Tuple[int, int], transition_origin: T, transition_type: ui_utils.EditType):
            """Record the drag start, the origin's screen position and the camera's screen axes."""
            transform_mat_vtk = self.camera.GetViewTransformMatrix()
            dir_2d = torch.zeros(3, 4)
            for i in range(3):
                for j in range(4):
                    dir_2d[i, j] = transform_mat_vtk.GetElement(i, j)
            self.transition_origin = transition_origin
            transition_origin = torch.tensor(transition_origin.tolist() + [1])
            transition_origin_2d = torch.einsum('ab,b->a', dir_2d, transition_origin)
            self.transition_origin_2d = transition_origin_2d[:2] / transition_origin_2d[-1].abs()
            self.origin_mouse, self.dir_2d = mouse_pos, nnf.normalize(dir_2d[:2, :3], p=2, dim=1)
            self.transition_type = transition_type

        @property
        def camera(self):
            return self.render.GetActiveCamera()

        def __init__(self, render: ui_utils.CanvasRender):
            self.render = render
            self.transition_origin = torch.zeros(3)
            self.transition_origin_2d = torch.zeros(2)
            self.origin_mouse, self.dir_2d = torch.zeros(2), torch.zeros(2, 3)
            self.edit_direction = ui_utils.EditDirection.X_Axis
            self.transition_type = ui_utils.EditType.Translating

    @property
    def selected_gaussians(self) -> Iterable[gaussian_status.GaussianStatus]:
        """Lazily yield the selected Gaussians of the main stage."""
        return filter(lambda x: x.is_selected, self.main_stage.gmm)

    def temporary_transition(self, mouse_pos: Optional[Tuple[int, int]] = None, end=False) -> bool:
        """Apply the current drag to every selected Gaussian (committing it when ``end``).

        Returns True if any Gaussian reported a change.
        """
        transition = self.transition_controller.get_transition(mouse_pos)
        is_change = False
        for gaussian in self.selected_gaussians:
            if end:
                is_change = gaussian.end_transition(transition) or is_change
            else:
                is_change = gaussian.temporary_transition(transition) or is_change
        return is_change

    def end_transition(self, mouse_pos: Optional[Tuple[int, int]]) -> bool:
        return self.temporary_transition(mouse_pos, True)

    def init_transition(self, mouse_pos, transition_type: ui_utils.EditType):
        """Start a drag when something is selected; the pivot is currently fixed at the origin."""
        center = list(map(lambda x: x.mu_baked, self.selected_gaussians))
        if len(center) == 0:
            return
        center = torch.zeros(3)
        self.transition_controller.init_transition(mouse_pos, center, transition_type)

    def toggle_edit_direction(self, direction: ui_utils.EditDirection):
        self.transition_controller.toggle_edit_direction(direction)

    def clear_selection(self) -> bool:
        """Deselect every selected Gaussian of the main stage; True if any was selected."""
        is_changed = False
        for gaussian in self.selected_gaussians:
            gaussian.toggle_selection()
            is_changed = True
        return is_changed

    def __init__(self, opt: options.Options, shape_paths: List[List[str]], render, view_styles: List[ui_utils.ViewStyle],
                 with_model: bool):
        super(MeshGmmStatuses, self).__init__(opt, shape_paths, render, view_styles)
        if with_model:
            self.model_process = inference_processing.InferenceProcess(opt, self.main_stage.replace_mesh,
                                                                       self.request_gmm,
                                                                       shape_paths)
        else:
            self.model_process = None
        self.counter = 0
        self.cur_canvas = 0
        self.transition_controller = MeshGmmStatuses.TransitionController(self.main_stage.render)
class MeshGmmUnited(MeshGmmStatuses):
    """``MeshGmmStatuses`` whose main stage is a separate, initially empty stage assembled
    from the Gaussians included on the side canvases."""

    def save(self, root: str, gmms=None):
        """Save the light GMM pickle and the main stage's mesh under ``root``.

        ``gmms`` is the ``(gmm_tuple, included)`` pair expected by ``MeshGmmStatuses.save``.
        When omitted, it is taken from ``request_gmm()`` if the main stage has included
        Gaussians. Without this, the parent call lacked its required argument and raised
        TypeError.
        """
        if gmms is None and self.main_gmm.included:
            gmms = self.request_gmm()
        if gmms is not None:
            super(MeshGmmUnited, self).save(root, gmms)
        self.main_gmm.save(root, filter_by_selection)

    def aggregate_votes(self, select: bool):
        """Apply brush votes.

        On a side canvas, set inclusion of the voted Gaussians to ``select`` and mirror
        the change into the main stage. Otherwise, update the main stage's selection.
        """
        if self.cur_canvas >= len(self.gmms):
            return self.update_selection(select)
        stage = self.gmms[self.cur_canvas]
        changed = [item for item in stage.aggregate_votes()
                   if not stage.gmm[item].disabled and stage.gmm[item].included != select]
        for item in changed:
            is_toggled, toggled = stage.toggle_inclusion_by_id(item, select)
            if not is_toggled:
                continue
            if toggled[0].included:
                # Newly included: copy into the main stage and remember side -> main address.
                new_addresses = self.main_gmm.add_gaussians(toggled)
                for gaussian, new_address in zip(toggled, new_addresses):
                    self.stage_mapper[gaussian.get_address()] = new_address
                self.make_twins(toggled, new_addresses)
            else:
                # Newly excluded: remove the main-stage copies that were mapped.
                addresses = [gaussian.get_address() for gaussian in toggled
                             if gaussian.get_address() in self.stage_mapper]
                self.main_gmm.remove_gaussians([self.stage_mapper[address] for address in addresses])
                for address in addresses:
                    del self.stage_mapper[address]
        return len(changed) > 0

    def update_selection(self, select: bool):
        """Set the selection of every Gaussian voted on the main stage to ``select``.

        Always returns False, even when selections changed.
        """
        stage = self.main_stage
        for item in stage.aggregate_votes():
            if stage.gmm[item].is_selected != select:
                stage.gmm[item].toggle_selection()
        return False

    def vote(self, *actors: Optional[vtk.vtkActor]):
        target = self.gmms[self.cur_canvas] if self.cur_canvas < len(self.gmms) else self.main_gmm
        target.vote(*actors)

    def reset(self):
        """Reset all stages, empty the main stage and re-toggle every side Gaussian."""
        super(MeshGmmUnited, self).reset()
        self.main_gmm.remove_all()
        for gmm in self.gmms:
            gmm.toggle_all()
        self.stage_mapper = {}

    def event_manger(self, object_id: str):
        return self.toggle_inclusion(object_id) or self.main_gmm.event_manger(object_id)

    def make_twins(self, toggled: List[gaussian_status.GaussianStatus], new_addresses: List[str]):
        """Link twin copies in the main stage, either as a pair or to an already-copied twin."""
        if len(new_addresses) == 2:
            self.main_gmm.make_twins(*new_addresses)
        elif toggled[0].twin is not None and toggled[0].twin.get_address() in self.stage_mapper:
            self.main_gmm.make_twins(new_addresses[0], self.stage_mapper[toggled[0].twin.get_address()])

    def toggle_symmetric(self, force_include: bool = False):
        super(MeshGmmUnited, self).toggle_symmetric(force_include)
        self.main_gmm.toggle_symmetric(force_include)

    @property
    def main_gmm(self) -> GmmMeshStage:
        return self.main_gmm_

    @property
    def main_stage(self) -> GmmMeshStage:
        return self.main_gmm_

    def __init__(self, opt: options.Options, gmm_paths: List[List[str]], renders_right, view_styles: List[ui_utils.ViewStyle],
                 main_render: ui_utils.CanvasRender, with_model: bool):
        # GmmMeshStage reads the shape id from shape_path[1] and treats '-1' as "no shape".
        # A bare -1 raised TypeError there, so pass a path list carrying that sentinel.
        self.main_gmm_ = GmmMeshStage(opt, ['', '-1'], main_render, len(gmm_paths), view_styles[-1], to_init=False)
        super(MeshGmmUnited, self).__init__(opt, gmm_paths, renders_right, view_styles[:-1], with_model)
        self.main_render = main_render
        self.reset()
        self.stage_mapper: Dict[str, str] = {}
def main():
    """Debug script: colour sample meshes by the model's attention and export them.

    For each hard-coded shape number, loads the sample mesh and GMM from
    ``opt.cp_folder``, queries ``model.get_attention`` on the mesh vertices, takes the
    argmax label of the averaged attention, and exports the mesh coloured by that label
    under ``constants.OUT_ROOT``.
    """
    opt = options.Options(tag="chairs_sym_hard").load()
    # NOTE: train_utils is only imported inside the ``__main__`` guard at the bottom of the file.
    model = train_utils.model_lc(opt)[0]
    model = model.to(CPU)
    # One random RGB colour per Gaussian; indexed by the label computed below.
    colors = torch.rand(opt.num_gaussians, 3)
    shape_nums = 1103, 1637, 2954, 3631, 4814
    for shape_num in shape_nums:
        mesh = files_utils.load_mesh(f"{opt.cp_folder}/occ/samples_{shape_num}")
        gmm = files_utils.load_gmm(f"{opt.cp_folder}/gmms/samples_{shape_num}")
        vs, faces = mesh
        # Add two leading singleton dims to every GMM component.
        phi, mu, eigen, p, _ = [item.unsqueeze(0).unsqueeze(0) for item in gmm]
        # Repacked in (mu, p, phi, eigen) order; only the commented-out likelihood path below used it.
        gmm = mu, p, phi, eigen
        # Keep the last four attention outputs returned by the model.
        attention = model.get_attention(vs.unsqueeze(0), torch.tensor([shape_num], dtype=torch.int64))[-4:]
        # _, supports = gm_utils.hierarchical_gm_log_likelihood_loss([gmm], vs.unsqueeze(0), get_supports=True)
        # supports = supports[0][0]
        supports = torch.cat(attention, dim=0)
        supports = supports.mean(-1).mean(0)
        # presumably one label per vertex (argmax over dim 1) -- confirm the attention shape
        label = supports.argmax(1)
        colors_ = colors[label]
        files_utils.export_mesh((vs, faces), f"{constants.OUT_ROOT}/{opt.tag}_{shape_num}b", colors=colors_)
    # NOTE(review): indentation was lost in this copy of the file; ``return 0`` is placed
    # after the loop (all shapes processed) -- confirm it was not inside the loop.
    return 0
if __name__ == '__main__':
    # ``exit`` is a helper installed by the ``site`` module for interactive use and is
    # absent under ``python -S``; ``sys.exit`` is the supported way to set the exit status.
    import sys
    # Imported here (as a module global) because main() uses train_utils.
    from utils import train_utils
    sys.exit(main())