ziqima commited on
Commit
164964b
1 Parent(s): b772103
Files changed (3) hide show
  1. app.py +2 -30
  2. inference/utils.py +0 -7
  3. utils.py +12 -0
app.py CHANGED
@@ -1,6 +1,6 @@
1
  import gradio as gr
2
  import re
3
- from utils import read_pcd, render_point_cloud, render_pcd_file
4
  from inference.utils import get_legend
5
  from inference.inference import segment_obj, get_heatmap
6
  from huggingface_hub import login
@@ -42,6 +42,7 @@ source_dict = {
42
  }
43
 
44
  def predict(pcd_path, inference_mode, part_queries):
 
45
  xyz, rgb, normal = read_pcd(pcd_path)
46
  if inference_mode == "Segmentation":
47
  parts = [part.strip(" ") for part in re.split(r'[,;.|]', part_queries)]
@@ -118,21 +119,6 @@ with gr.Blocks(theme=gr.themes.Default(text_size="lg", radius_size="none")) as d
118
  gallery_objaverse.select(fn=on_select,
119
  inputs=None,
120
  outputs=[file_upload, part_queries])
121
- '''
122
- gr.Examples(
123
- inputs=[file_upload, part_queries],
124
- examples=[
125
- ["examples/objaverse/fireplug.pcd", "bonnet of a fireplug,side cap of a fireplug,barrel of a fireplug,base of a fireplug"],
126
- ["examples/objaverse/mickey.pcd", "ear,head,arms,hands,body,legs"],
127
- ["examples/objaverse/motorvehicle.pcd", "wheel of a motor vehicle,seat of a motor vehicle,handle of a motor vehicle"],
128
- ["examples/objaverse/teddy.pcd", "head,body,arms,legs"],
129
- ["examples/objaverse/lamppost.pcd", "lighting of a lamppost,pole of a lamppost"],
130
- ["examples/objaverse/shirt.pcd", "sleeve of a shirt,collar of a shirt,body of a shirt"]
131
- ],
132
- example_labels=["fireplug", "Mickey", "motor vehicle", "teddy bear", "lamppost", "shirt"],
133
- label=""
134
- )
135
- '''
136
  with gr.Column(scale=6):
137
  title = gr.HTML("""<h1 text-align="center">In the Wild</h1>
138
  <p style='font-size: 16px;'>Challenging in-the-wild reconstructions from iPhone photos & AI-generated images!</p>
@@ -147,20 +133,6 @@ with gr.Blocks(theme=gr.themes.Default(text_size="lg", radius_size="none")) as d
147
  gallery_wild.select(fn=on_select,
148
  inputs=None,
149
  outputs=[file_upload, part_queries])
150
- '''
151
- gr.Examples(
152
- inputs=[file_upload, part_queries],
153
- examples=[
154
- ["examples/wild/mcc_dalle_capybara.pcd", "hat worn by a capybara,head,body,feet"],
155
- ["examples/wild/mcc_dalle_corgi.pcd", "head,leg,body,ear"],
156
- ["examples/wild/mcc_iphone_pushcar.pcd", "wheel,body,handle"],
157
- ["examples/wild/mcc_iphone_plant.pcd", "pot,plant"],
158
- ["examples/wild/mcc_iphone_chair.pcd", "back of chair,leg,seat"],
159
- ],
160
- example_labels=["DALLE-capybara", "DALLE-corgi", "iPhone-pushcar", "iPhone-plant", "iPhone-chair"],
161
- label=""
162
- )
163
- '''
164
 
165
  file_upload.change(
166
  fn=render_pcd_file,
 
1
  import gradio as gr
2
  import re
3
+ from utils import read_pcd, render_point_cloud, render_pcd_file, set_seed
4
  from inference.utils import get_legend
5
  from inference.inference import segment_obj, get_heatmap
6
  from huggingface_hub import login
 
42
  }
43
 
44
  def predict(pcd_path, inference_mode, part_queries):
45
+ set_seed()
46
  xyz, rgb, normal = read_pcd(pcd_path)
47
  if inference_mode == "Segmentation":
48
  parts = [part.strip(" ") for part in re.split(r'[,;.|]', part_queries)]
 
119
  gallery_objaverse.select(fn=on_select,
120
  inputs=None,
121
  outputs=[file_upload, part_queries])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
122
  with gr.Column(scale=6):
123
  title = gr.HTML("""<h1 text-align="center">In the Wild</h1>
124
  <p style='font-size: 16px;'>Challenging in-the-wild reconstructions from iPhone photos & AI-generated images!</p>
 
133
  gallery_wild.select(fn=on_select,
134
  inputs=None,
135
  outputs=[file_upload, part_queries])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
136
 
137
  file_upload.change(
138
  fn=render_pcd_file,
inference/utils.py CHANGED
@@ -39,13 +39,6 @@ def load_model():
39
  model = model.to(DEVICE)
40
  return model
41
 
42
- def set_seed(seed):
43
- torch.manual_seed(seed)
44
- if DEVICE != "cpu":
45
- torch.cuda.manual_seed(seed)
46
- torch.cuda.manual_seed_all(seed)
47
- np.random.seed(seed)
48
- random.seed(seed)
49
 
50
  def fnv_hash_vec(arr):
51
  """
 
39
  model = model.to(DEVICE)
40
  return model
41
 
 
 
 
 
 
 
 
42
 
43
  def fnv_hash_vec(arr):
44
  """
utils.py CHANGED
@@ -2,6 +2,18 @@ import plotly.graph_objects as go
2
  import open3d as o3d
3
  import numpy as np
4
  import textwrap
 
 
 
 
 
 
 
 
 
 
 
 
5
 
6
  def read_pcd(pcd_path):
7
  pcd = o3d.io.read_point_cloud(pcd_path)
 
2
  import open3d as o3d
3
  import numpy as np
4
  import textwrap
5
+ import torch
6
+ import random
7
+
8
+
9
+ def set_seed():
10
+ seed = 123
11
+ torch.manual_seed(seed)
12
+ torch.cuda.manual_seed(seed)
13
+ torch.cuda.manual_seed_all(seed)
14
+ np.random.seed(seed)
15
+ random.seed(seed)
16
+
17
 
18
  def read_pcd(pcd_path):
19
  pcd = o3d.io.read_point_cloud(pcd_path)