zouzx commited on
Commit
54148c8
1 Parent(s): df58ba5

update app

Browse files
Files changed (1) hide show
  1. app.py +61 -49
app.py CHANGED
@@ -8,16 +8,20 @@ import sys
8
  import tempfile
9
  import subprocess
10
  from huggingface_hub import snapshot_download
 
11
 
12
  LOCAL_CODE = os.environ.get("LOCAL_CODE", "1") == "1"
 
 
13
  AUTH = ("admin", os.environ["PASSWD"]) if "PASSWD" in os.environ else None
 
14
 
15
  code_dir = snapshot_download("zouzx/TriplaneGaussian", local_dir="./code", token=os.environ["HF_TOKEN"]) if not LOCAL_CODE else "./code"
16
 
17
  sys.path.append(code_dir)
18
 
19
  if not LOCAL_CODE:
20
- subprocess.run(["pip", "install", "--upgrade", "gradio"])
21
 
22
  import gradio as gr
23
  print("gr version: ", gr.__version__)
@@ -40,10 +44,6 @@ device = "cuda:{}".format(gpu) if torch.cuda.is_available() else "cpu"
40
 
41
  print("device: ", device)
42
 
43
- # load SAM checkpoint
44
- # sam_predictor = sam_init(SAM_CKPT_PATH, gpu)
45
- # print("load sam ckpt done.")
46
-
47
  # init system
48
  base_cfg: ExperimentConfig
49
  base_cfg = load_config(CONFIG, cli_args=[], n_gpus=1)
@@ -78,28 +78,26 @@ def resize_image(input_raw, size):
78
  resized_h = int(h * ratio)
79
  return input_raw.resize((resized_w, resized_h), Image.Resampling.LANCZOS)
80
 
81
- def preprocess(image_path, save_path):
82
  # if not preprocess:
83
  # print("No preprocess")
84
  # # return image_path
85
-
86
- # input_raw = Image.open(image_path)
87
- # input_raw.thumbnail([512, 512], Image.Resampling.LANCZOS)
88
- # input_raw = resize_image(input_raw, 512)
89
- # print("image size:", input_raw.size)
90
- # image_sam = sam_out_nosave(
91
- # sam_predictor, input_raw.convert("RGB"), pred_bbox(input_raw)
92
- # )
93
-
94
- save_path = os.path.join(save_path, "input_rgba.png")
95
- # if save_path is None:
96
- # save_path, ext = os.path.splitext(image_path)
97
- # save_path = save_path + "_rgba.png"
98
- # image_preprocess(image_sam, save_path, lower_contrast=False, rescale=True)
99
-
100
- subprocess.run([f"python run_sam.py --image_path {image_path} --save_path {save_path}"], shell=True)
101
-
102
- print("image save path = ", save_path)
103
  return save_path
104
 
105
  def init_trial_dir():
@@ -142,7 +140,18 @@ def run_video(image_path: str,
142
  # print("save video", video)
143
  return video
144
 
 
 
 
 
 
 
 
145
  def launch(port):
 
 
 
 
146
  with gr.Blocks(
147
  title="TGS - Demo"
148
  ) as demo:
@@ -151,44 +160,47 @@ def launch(port):
151
 
152
  with gr.Row(variant='panel'):
153
  with gr.Column(scale=1):
154
- input_image = gr.Image(value=None, width=512, height=512, type="filepath", sources="upload", label="Input Image")
155
  gr.Markdown(
156
  """
157
  **Camera distance** denotes the distance between camera center and scene center.
158
  If you find the 3D model appears flattened, you can increase it. Conversely, if the 3D model appears thick, you can decrease it.
159
  """
160
  )
161
- camera_dist_slider = gr.Slider(1.0, 4.0, value=1.9, step=0.1, label="Camera Distance")
162
  # preprocess_ckb = gr.Checkbox(value=True, label="Remove background")
163
  img_run_btn = gr.Button("Reconstruction", variant="primary")
164
 
165
- gr.Examples(
166
- examples=[
167
- "example_images/green_parrot.webp",
168
- "example_images/rusty_gameboy.webp",
169
- "example_images/a_pikachu_with_smily_face.webp",
170
- "example_images/an_otter_wearing_sunglasses.webp",
171
- "example_images/lumberjack_axe.webp",
172
- "example_images/medieval_shield.webp",
173
- "example_images/a_cat_dressed_as_the_pope.webp",
174
- "example_images/a_cute_little_frog_comicbook_style.webp",
175
- "example_images/a_purple_winter_jacket.webp",
176
- "example_images/MP5,_high_quality,_ultra_realistic.webp",
177
- "example_images/retro_pc_photorealistic_high_detailed.webp",
178
- "example_images/stratocaster_guitar_pixar_style.webp"
179
- ],
180
- inputs=[input_image],
181
- cache_examples=False,
182
- label="Examples",
183
- examples_per_page=40
184
- )
185
-
186
  with gr.Column(scale=1):
187
  with gr.Row(variant='panel'):
188
  seg_image = gr.Image(value=None, width="auto", type="filepath", image_mode="RGBA", label="Segmented Image", interactive=False)
189
  output_video = gr.Video(value=None, width="auto", label="Rendered Video", autoplay=True)
190
  output_3dgs = Model3DGS(value=None, label="3D Model")
191
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
192
  trial_dir = gr.State()
193
  img_run_btn.click(
194
  fn=assert_input_image,
@@ -199,7 +211,7 @@ def launch(port):
199
  outputs=[trial_dir],
200
  # queue=False
201
  ).success(
202
- fn=preprocess,
203
  inputs=[input_image, trial_dir],
204
  outputs=[seg_image],
205
  ).success(fn=run,
@@ -210,7 +222,7 @@ def launch(port):
210
  outputs=[output_video])
211
 
212
  launch_args = {"server_port": port}
213
- demo.queue(max_size=4)
214
  demo.launch(auth=AUTH, **launch_args)
215
 
216
  if __name__ == "__main__":
 
8
  import tempfile
9
  import subprocess
10
  from huggingface_hub import snapshot_download
11
+ from functools import partial
12
 
13
  LOCAL_CODE = os.environ.get("LOCAL_CODE", "1") == "1"
14
+ CACHE_EXAMPLES = os.environ.get("CACHE_EXAMPLES", "1") == "1"
15
+ SAM_LOCAL = os.environ.get("SAM_LOCAL", "1") == "1"
16
  AUTH = ("admin", os.environ["PASSWD"]) if "PASSWD" in os.environ else None
17
+ DEFAULT_CAM_DIST = 1.9
18
 
19
  code_dir = snapshot_download("zouzx/TriplaneGaussian", local_dir="./code", token=os.environ["HF_TOKEN"]) if not LOCAL_CODE else "./code"
20
 
21
  sys.path.append(code_dir)
22
 
23
  if not LOCAL_CODE:
24
+ subprocess.run(["pip", "install", "--upgrade", "gradio==4.12.0"])
25
 
26
  import gradio as gr
27
  print("gr version: ", gr.__version__)
 
44
 
45
  print("device: ", device)
46
 
 
 
 
 
47
  # init system
48
  base_cfg: ExperimentConfig
49
  base_cfg = load_config(CONFIG, cli_args=[], n_gpus=1)
 
78
  resized_h = int(h * ratio)
79
  return input_raw.resize((resized_w, resized_h), Image.Resampling.LANCZOS)
80
 
81
+ def preprocess(input_raw, save_path, sam_predictor=None):
82
  # if not preprocess:
83
  # print("No preprocess")
84
  # # return image_path
85
+ image_path = os.path.join(save_path, "input_raw.png")
86
+ save_path = os.path.join(save_path, "seg_rgba.png")
87
+ if SAM_LOCAL and sam_predictor is not None:
88
+ # input_raw = Image.open(image_path)
89
+ # input_raw.thumbnail([512, 512], Image.Resampling.LANCZOS)
90
+ input_raw = resize_image(input_raw, 512)
91
+ print("image size:", input_raw.size)
92
+ image_sam = sam_out_nosave(
93
+ sam_predictor, input_raw.convert("RGB"), pred_bbox(input_raw)
94
+ )
95
+ image_preprocess(image_sam, save_path, lower_contrast=False, rescale=True)
96
+ else:
97
+ input_raw.save(image_path)
98
+ subprocess.run([f"python run_sam.py --image_path {image_path} --save_path {save_path}"], shell=True)
99
+
100
+ print("image raw path = ", image_path, "image save path =", save_path)
 
 
101
  return save_path
102
 
103
  def init_trial_dir():
 
140
  # print("save video", video)
141
  return video
142
 
143
+ def run_example(image_path, sam_predictor=None):
144
+ save_path = init_trial_dir()
145
+ seg_image_path = preprocess(image_path, save_path, sam_predictor)
146
+ gs = run(seg_image_path, DEFAULT_CAM_DIST, save_path)
147
+ video = run_video(seg_image_path, DEFAULT_CAM_DIST, save_path)
148
+ return seg_image_path, gs, video
149
+
150
  def launch(port):
151
+ if SAM_LOCAL:
152
+ sam_predictor = sam_init(SAM_CKPT_PATH, gpu)
153
+ print("load sam ckpt done.")
154
+
155
  with gr.Blocks(
156
  title="TGS - Demo"
157
  ) as demo:
 
160
 
161
  with gr.Row(variant='panel'):
162
  with gr.Column(scale=1):
163
+ input_image = gr.Image(value=None, image_mode="RGB", width=512, height=512, type="pil", sources="upload", label="Input Image")
164
  gr.Markdown(
165
  """
166
  **Camera distance** denotes the distance between camera center and scene center.
167
  If you find the 3D model appears flattened, you can increase it. Conversely, if the 3D model appears thick, you can decrease it.
168
  """
169
  )
170
+ camera_dist_slider = gr.Slider(1.0, 4.0, value=DEFAULT_CAM_DIST, step=0.1, label="Camera Distance")
171
  # preprocess_ckb = gr.Checkbox(value=True, label="Remove background")
172
  img_run_btn = gr.Button("Reconstruction", variant="primary")
173
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
174
  with gr.Column(scale=1):
175
  with gr.Row(variant='panel'):
176
  seg_image = gr.Image(value=None, width="auto", type="filepath", image_mode="RGBA", label="Segmented Image", interactive=False)
177
  output_video = gr.Video(value=None, width="auto", label="Rendered Video", autoplay=True)
178
  output_3dgs = Model3DGS(value=None, label="3D Model")
179
 
180
+ with gr.Row(variant="panel"):
181
+ gr.Examples(
182
+ examples=[
183
+ "example_images/green_parrot.webp",
184
+ "example_images/rusty_gameboy.webp",
185
+ "example_images/a_pikachu_with_smily_face.webp",
186
+ "example_images/an_otter_wearing_sunglasses.webp",
187
+ "example_images/lumberjack_axe.webp",
188
+ "example_images/medieval_shield.webp",
189
+ "example_images/a_cat_dressed_as_the_pope.webp",
190
+ "example_images/a_cute_little_frog_comicbook_style.webp",
191
+ "example_images/a_purple_winter_jacket.webp",
192
+ "example_images/MP5,_high_quality,_ultra_realistic.webp",
193
+ "example_images/retro_pc_photorealistic_high_detailed.webp",
194
+ "example_images/stratocaster_guitar_pixar_style.webp"
195
+ ],
196
+ inputs=[input_image],
197
+ outputs=[seg_image, output_3dgs, output_video],
198
+ cache_examples=CACHE_EXAMPLES,
199
+ fn=partial(run_example, sam_predictor=sam_predictor),
200
+ label="Examples",
201
+ examples_per_page=40
202
+ )
203
+
204
  trial_dir = gr.State()
205
  img_run_btn.click(
206
  fn=assert_input_image,
 
211
  outputs=[trial_dir],
212
  # queue=False
213
  ).success(
214
+ fn=partial(preprocess, sam_predictor=sam_predictor),
215
  inputs=[input_image, trial_dir],
216
  outputs=[seg_image],
217
  ).success(fn=run,
 
222
  outputs=[output_video])
223
 
224
  launch_args = {"server_port": port}
225
+ demo.queue(max_size=20)
226
  demo.launch(auth=AUTH, **launch_args)
227
 
228
  if __name__ == "__main__":