Ariamehr committed on
Commit
82c9791
1 Parent(s): 4ad0438

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +9 -32
app.py CHANGED
@@ -1,6 +1,5 @@
1
  import colorsys
2
  import os
3
-
4
  import gradio as gr
5
  import matplotlib.colors as mcolors
6
  import numpy as np
@@ -16,7 +15,6 @@ from torchvision import transforms
16
  ASSETS_DIR = os.path.join(os.path.dirname(__file__), "assets")
17
  os.makedirs(ASSETS_DIR, exist_ok=True)
18
 
19
-
20
  LABELS_TO_IDS = {
21
  "Background": 0,
22
  "Apparel": 1,
@@ -48,7 +46,6 @@ LABELS_TO_IDS = {
48
  "Tongue": 27,
49
  }
50
 
51
-
52
  def get_palette(num_cls):
53
  palette = [0] * (256 * 3)
54
  palette[0:3] = [0, 0, 0]
@@ -63,12 +60,10 @@ def get_palette(num_cls):
63
 
64
  return palette
65
 
66
-
67
  def create_colormap(palette):
68
  colormap = np.array(palette).reshape(-1, 3) / 255.0
69
  return mcolors.ListedColormap(colormap)
70
 
71
-
72
  def visualize_mask_with_overlay(img: Image.Image, mask: Image.Image, labels_to_ids: dict[str, int], alpha=0.5):
73
  img_np = np.array(img.convert("RGB"))
74
  mask_np = np.array(mask)
@@ -86,7 +81,6 @@ def visualize_mask_with_overlay(img: Image.Image, mask: Image.Image, labels_to_i
86
 
87
  return blended
88
 
89
-
90
  def create_legend_image(labels_to_ids: dict[str, int], filename="legend.png"):
91
  num_cls = len(labels_to_ids)
92
  palette = get_palette(num_cls)
@@ -128,30 +122,28 @@ def create_legend_image(labels_to_ids: dict[str, int], filename="legend.png"):
128
  plt.savefig(filename, dpi=300, bbox_inches="tight")
129
  plt.close()
130
 
131
-
132
- # create_legend_image(LABELS_TO_IDS, filename=os.path.join(ASSETS_DIR, "legend.png"))
133
-
134
-
135
  # ----------------- MODEL ----------------- #
136
 
137
- URL = "https://huggingface.co/facebook/sapiens/blob/main/sapiens_lite_host/torchscript/normal/checkpoints/sapiens_2b/sapiens_2b_normal_render_people_epoch_70_torchscript.pt2?download=true"
138
  CHECKPOINTS_DIR = os.path.join(ASSETS_DIR, "checkpoints")
139
  os.makedirs(CHECKPOINTS_DIR, exist_ok=True)
140
 
141
  model_path = os.path.join(CHECKPOINTS_DIR, "sapiens_2b_normal_render_people_epoch_70_torchscript.pt2")
142
 
143
- if not os.path.exists(model_path):
144
- os.makedirs(CHECKPOINTS_DIR, exist_ok=True)
145
  import requests
146
 
147
  response = requests.get(URL)
148
- with open(model_path, "wb") as file:
149
- file.write(response.content)
 
 
 
150
 
151
  model = torch.jit.load(model_path)
152
  model.eval()
153
 
154
-
155
  @torch.no_grad()
156
  def run_model(input_tensor, height, width):
157
  output = model(input_tensor)
@@ -159,7 +151,6 @@ def run_model(input_tensor, height, width):
159
  _, preds = torch.max(output, 1)
160
  return preds
161
 
162
-
163
  transform_fn = transforms.Compose(
164
  [
165
  transforms.Resize((1024, 768)),
@@ -167,8 +158,8 @@ transform_fn = transforms.Compose(
167
  transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
168
  ]
169
  )
170
- # ----------------- CORE FUNCTION ----------------- #
171
 
 
172
 
173
  def segment(image: Image.Image) -> Image.Image:
174
  input_tensor = transform_fn(image).unsqueeze(0)
@@ -178,10 +169,8 @@ def segment(image: Image.Image) -> Image.Image:
178
  blended_image = visualize_mask_with_overlay(image, mask_image, LABELS_TO_IDS, alpha=0.5)
179
  return blended_image
180
 
181
-
182
  # ----------------- GRADIO UI ----------------- #
183
 
184
-
185
  with open("banner.html", "r") as file:
186
  banner = file.read()
187
  with open("tips.html", "r") as file:
@@ -202,27 +191,15 @@ with gr.Blocks(css=CUSTOM_CSS, theme=gr.themes.Monochrome(radius_size=sizes.radi
202
  with gr.Row():
203
  with gr.Column():
204
  input_image = gr.Image(label="Input Image", type="pil", format="png")
205
-
206
- '''example_model = gr.Examples(
207
- inputs=input_image,
208
- examples_per_page=10,
209
- examples=[
210
- os.path.join(ASSETS_DIR, "examples", img)
211
- for img in os.listdir(os.path.join(ASSETS_DIR, "examples"))
212
- ],
213
- )'''
214
  with gr.Column():
215
  result_image = gr.Image(label="Segmentation Result", format="png")
216
  run_button = gr.Button("Run")
217
 
218
- #gr.Image(os.path.join(ASSETS_DIR, "legend.png"), label="Legend", type="filepath")
219
-
220
  run_button.click(
221
  fn=segment,
222
  inputs=[input_image],
223
  outputs=[result_image],
224
  )
225
 
226
-
227
  if __name__ == "__main__":
228
  demo.launch(share=False)
 
1
  import colorsys
2
  import os
 
3
  import gradio as gr
4
  import matplotlib.colors as mcolors
5
  import numpy as np
 
15
  ASSETS_DIR = os.path.join(os.path.dirname(__file__), "assets")
16
  os.makedirs(ASSETS_DIR, exist_ok=True)
17
 
 
18
  LABELS_TO_IDS = {
19
  "Background": 0,
20
  "Apparel": 1,
 
46
  "Tongue": 27,
47
  }
48
 
 
49
  def get_palette(num_cls):
50
  palette = [0] * (256 * 3)
51
  palette[0:3] = [0, 0, 0]
 
60
 
61
  return palette
62
 
 
63
def create_colormap(palette):
    """Build a matplotlib ListedColormap from a flat RGB palette.

    ``palette`` is a flat sequence ``[r0, g0, b0, r1, g1, b1, ...]`` with
    channel values in 0-255; they are rescaled to the 0-1 range that
    matplotlib colormaps expect.
    """
    rgb_rows = np.asarray(palette, dtype=float).reshape(-1, 3)
    return mcolors.ListedColormap(rgb_rows / 255.0)
66
 
 
67
  def visualize_mask_with_overlay(img: Image.Image, mask: Image.Image, labels_to_ids: dict[str, int], alpha=0.5):
68
  img_np = np.array(img.convert("RGB"))
69
  mask_np = np.array(mask)
 
81
 
82
  return blended
83
 
 
84
  def create_legend_image(labels_to_ids: dict[str, int], filename="legend.png"):
85
  num_cls = len(labels_to_ids)
86
  palette = get_palette(num_cls)
 
122
  plt.savefig(filename, dpi=300, bbox_inches="tight")
123
  plt.close()
124
 
 
 
 
 
125
# ----------------- MODEL ----------------- #

# TorchScript export of the Sapiens-2B normal checkpoint on the HF Hub.
URL = "https://huggingface.co/facebook/sapiens/resolve/main/sapiens_lite_host/torchscript/normal/checkpoints/sapiens_2b/sapiens_2b_normal_render_people_epoch_70_torchscript.pt2?download=true"
CHECKPOINTS_DIR = os.path.join(ASSETS_DIR, "checkpoints")
os.makedirs(CHECKPOINTS_DIR, exist_ok=True)

model_path = os.path.join(CHECKPOINTS_DIR, "sapiens_2b_normal_render_people_epoch_70_torchscript.pt2")

# Download the checkpoint on first run, or retry if a previous attempt left
# an empty file behind. The checkpoint is multi-gigabyte, so stream it to a
# temporary file instead of buffering the whole response in memory, and only
# move it to the final path once the download has finished, so an interrupted
# run never leaves a truncated file that passes the existence check.
if not os.path.exists(model_path) or os.path.getsize(model_path) == 0:
    print("Downloading model...")
    import requests

    tmp_path = model_path + ".part"
    with requests.get(URL, stream=True, timeout=60) as response:
        if response.status_code != 200:
            raise Exception("Failed to download the model. Please check the URL.")
        with open(tmp_path, "wb") as file:
            # 1 MiB chunks keep memory use flat regardless of file size.
            for chunk in response.iter_content(chunk_size=1 << 20):
                file.write(chunk)
    # Atomic on POSIX: the final path only ever holds a complete file.
    os.replace(tmp_path, model_path)

model = torch.jit.load(model_path)
model.eval()
146
 
 
147
  @torch.no_grad()
148
  def run_model(input_tensor, height, width):
149
  output = model(input_tensor)
 
151
  _, preds = torch.max(output, 1)
152
  return preds
153
 
 
154
  transform_fn = transforms.Compose(
155
  [
156
  transforms.Resize((1024, 768)),
 
158
  transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
159
  ]
160
  )
 
161
 
162
+ # ----------------- CORE FUNCTION ----------------- #
163
 
164
  def segment(image: Image.Image) -> Image.Image:
165
  input_tensor = transform_fn(image).unsqueeze(0)
 
169
  blended_image = visualize_mask_with_overlay(image, mask_image, LABELS_TO_IDS, alpha=0.5)
170
  return blended_image
171
 
 
172
  # ----------------- GRADIO UI ----------------- #
173
 
 
174
  with open("banner.html", "r") as file:
175
  banner = file.read()
176
  with open("tips.html", "r") as file:
 
191
  with gr.Row():
192
  with gr.Column():
193
  input_image = gr.Image(label="Input Image", type="pil", format="png")
 
 
 
 
 
 
 
 
 
194
  with gr.Column():
195
  result_image = gr.Image(label="Segmentation Result", format="png")
196
  run_button = gr.Button("Run")
197
 
 
 
198
  run_button.click(
199
  fn=segment,
200
  inputs=[input_image],
201
  outputs=[result_image],
202
  )
203
 
 
204
  if __name__ == "__main__":
205
  demo.launch(share=False)