zouzx commited on
Commit
a7620c3
1 Parent(s): c52acae

add run_sam

Browse files
Files changed (2) hide show
  1. app.py +13 -11
  2. run_sam.py +34 -0
app.py CHANGED
@@ -6,6 +6,7 @@ from PIL import Image
6
  from copy import deepcopy
7
  import sys
8
  import tempfile
 
9
  from huggingface_hub import snapshot_download
10
 
11
  LOCAL_CODE = os.environ.get("LOCAL_CODE", "1") == "1"
@@ -16,7 +17,6 @@ code_dir = snapshot_download("zouzx/TriplaneGaussian", local_dir="./code", token
16
  sys.path.append(code_dir)
17
 
18
  if not LOCAL_CODE:
19
- import subprocess
20
  subprocess.run(["pip", "install", "--upgrade", "gradio"])
21
 
22
  import gradio as gr
@@ -41,8 +41,8 @@ device = "cuda:{}".format(gpu) if torch.cuda.is_available() else "cpu"
41
  print("device: ", device)
42
 
43
  # load SAM checkpoint
44
- sam_predictor = sam_init(SAM_CKPT_PATH, gpu)
45
- print("load sam ckpt done.")
46
 
47
  # init system
48
  base_cfg: ExperimentConfig
@@ -78,24 +78,26 @@ def resize_image(input_raw, size):
78
  resized_h = int(h * ratio)
79
  return input_raw.resize((resized_w, resized_h), Image.Resampling.LANCZOS)
80
 
81
- def preprocess(input_raw, save_path):
82
  # if not preprocess:
83
  # print("No preprocess")
84
  # # return image_path
85
 
86
  # input_raw = Image.open(image_path)
87
  # input_raw.thumbnail([512, 512], Image.Resampling.LANCZOS)
88
- input_raw = resize_image(input_raw, 512)
89
- print("image size:", input_raw.size)
90
- image_sam = sam_out_nosave(
91
- sam_predictor, input_raw.convert("RGB"), pred_bbox(input_raw)
92
- )
93
 
94
  save_path = os.path.join(save_path, "input_rgba.png")
95
  # if save_path is None:
96
  # save_path, ext = os.path.splitext(image_path)
97
  # save_path = save_path + "_rgba.png"
98
- image_preprocess(image_sam, save_path, lower_contrast=False, rescale=True)
 
 
99
 
100
  # print("image save path = ", save_path)
101
  return save_path
@@ -149,7 +151,7 @@ def launch(port):
149
 
150
  with gr.Row(variant='panel'):
151
  with gr.Column(scale=1):
152
- input_image = gr.Image(value=None, image_mode="RGB", width=512, height=512, type="pil", sources="upload", label="Input Image")
153
  gr.Markdown(
154
  """
155
  **Camera distance** denotes the distance between camera center and scene center.
 
6
  from copy import deepcopy
7
  import sys
8
  import tempfile
9
+ import subprocess
10
  from huggingface_hub import snapshot_download
11
 
12
  LOCAL_CODE = os.environ.get("LOCAL_CODE", "1") == "1"
 
17
  sys.path.append(code_dir)
18
 
19
  if not LOCAL_CODE:
 
20
  subprocess.run(["pip", "install", "--upgrade", "gradio"])
21
 
22
  import gradio as gr
 
41
  print("device: ", device)
42
 
43
  # load SAM checkpoint
44
+ # sam_predictor = sam_init(SAM_CKPT_PATH, gpu)
45
+ # print("load sam ckpt done.")
46
 
47
  # init system
48
  base_cfg: ExperimentConfig
 
78
  resized_h = int(h * ratio)
79
  return input_raw.resize((resized_w, resized_h), Image.Resampling.LANCZOS)
80
 
81
+ def preprocess(image_path, save_path):
82
  # if not preprocess:
83
  # print("No preprocess")
84
  # # return image_path
85
 
86
  # input_raw = Image.open(image_path)
87
  # input_raw.thumbnail([512, 512], Image.Resampling.LANCZOS)
88
+ # input_raw = resize_image(input_raw, 512)
89
+ # print("image size:", input_raw.size)
90
+ # image_sam = sam_out_nosave(
91
+ # sam_predictor, input_raw.convert("RGB"), pred_bbox(input_raw)
92
+ # )
93
 
94
  save_path = os.path.join(save_path, "input_rgba.png")
95
  # if save_path is None:
96
  # save_path, ext = os.path.splitext(image_path)
97
  # save_path = save_path + "_rgba.png"
98
+ # image_preprocess(image_sam, save_path, lower_contrast=False, rescale=True)
99
+
100
+ subprocess.run([f"python run_sam.py --image_path {image_path} --save_path {save_path}"], shell=True)
101
 
102
  # print("image save path = ", save_path)
103
  return save_path
 
151
 
152
  with gr.Row(variant='panel'):
153
  with gr.Column(scale=1):
154
+ input_image = gr.Image(value=None, width=512, height=512, type="filepath", sources="upload", label="Input Image")
155
  gr.Markdown(
156
  """
157
  **Camera distance** denotes the distance between camera center and scene center.
run_sam.py ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from utils import image_preprocess, pred_bbox, sam_init, sam_out_nosave
2
+ import os
3
+ from PIL import Image
4
+ import argparse
5
+
6
+ SAM_CKPT_PATH = "code/checkpoints/sam_vit_h_4b8939.pth"
7
+
8
+ def resize_image(input_raw, size):
9
+ w, h = input_raw.size
10
+ ratio = size / max(w, h)
11
+ resized_w = int(w * ratio)
12
+ resized_h = int(h * ratio)
13
+ return input_raw.resize((resized_w, resized_h), Image.Resampling.LANCZOS)
14
+
15
+ if __name__ == "__main__":
16
+ # load SAM checkpoint
17
+ gpu = os.environ.get("CUDA_VISIBLE_DEVICES", "0")
18
+ sam_predictor = sam_init(SAM_CKPT_PATH, gpu)
19
+ print("load sam ckpt done.")
20
+
21
+ parser = argparse.ArgumentParser()
22
+ parser.add_argument("--image_path", required=True)
23
+ parser.add_argument("--save_path", required=True)
24
+ args = parser.parse_args()
25
+
26
+ input_raw = Image.open(args.image_path)
27
+ # input_raw.thumbnail([512, 512], Image.Resampling.LANCZOS)
28
+ input_raw = resize_image(input_raw, 512)
29
+ image_sam = sam_out_nosave(
30
+ sam_predictor, input_raw.convert("RGB"), pred_bbox(input_raw)
31
+ )
32
+
33
+ # save_path = os.path.join(args.save_path, "input_rgba.png")
34
+ image_preprocess(image_sam, args.save_path, lower_contrast=False, rescale=True)