ZhengPeng7 commited on
Commit
53ff575
1 Parent(s): fa3042b

Change all predicted results as refined RGBA images. Fix a typo in device specification.

Browse files
Files changed (1) hide show
  1. app.py +9 -17
app.py CHANGED
@@ -21,7 +21,7 @@ import zipfile
21
  torch.set_float32_matmul_precision('high')
22
  torch.jit.script = lambda f: f
23
 
24
- device = "cuda" if torch.cuda.is_available() else "CPU"
25
 
26
  ### image_proc.py
27
  def refine_foreground(image, mask, r=90):
@@ -125,20 +125,18 @@ def predict(images, resolution, weights_file):
125
  for idx_image, image_src in enumerate(images):
126
  if isinstance(image_src, str):
127
  if os.path.isfile(image_src):
128
- image = np.array(Image.open(image_src))
129
  else:
130
  response = requests.get(image_src)
131
  image_data = BytesIO(response.content)
132
- image = np.array(Image.open(image_data))
133
  else:
134
- image = image_src
135
-
136
- image_shape = image.shape[:2]
137
- image_pil = array_to_pil_image(image, tuple(resolution))
138
 
 
139
  # Preprocess the image
140
  image_preprocessor = ImagePreprocessor(resolution=tuple(resolution))
141
- image_proc = image_preprocessor.proc(image_pil)
142
  image_proc = image_proc.unsqueeze(0)
143
 
144
  # Prediction
@@ -148,8 +146,8 @@ def predict(images, resolution, weights_file):
148
 
149
  # Show Results
150
  pred_pil = transforms.ToPILImage()(pred)
151
- image_masked = refine_foreground(Image.fromarray(image), pred_pil)
152
- image_masked.putalpha(pred_pil.resize(Image.fromarray(image).size))
153
 
154
  torch.cuda.empty_cache()
155
 
@@ -158,12 +156,6 @@ def predict(images, resolution, weights_file):
158
  image_masked.save(save_file_path)
159
  save_paths.append(save_file_path)
160
 
161
- # Apply the prediction mask to the original image
162
- pred = torch.nn.functional.interpolate(preds, size=image_shape, mode='bilinear', align_corners=True).squeeze().numpy()
163
- image_pil = image_pil.resize(pred.shape[::-1])
164
- pred = np.repeat(np.expand_dims(pred, axis=-1), 3, axis=-1)
165
- image_masked = (pred * np.array(image_pil)).astype(np.uint8)
166
-
167
  if tab_is_batch:
168
  zip_file_path = os.path.join(save_dir, "{}.zip".format(save_dir))
169
  with zipfile.ZipFile(zip_file_path, 'w') as zipf:
@@ -171,7 +163,7 @@ def predict(images, resolution, weights_file):
171
  zipf.write(file, os.path.basename(file))
172
  return save_paths, zip_file_path
173
  else:
174
- return image, image_masked
175
 
176
 
177
  examples = [[_] for _ in glob('examples/*')][:]
 
21
  torch.set_float32_matmul_precision('high')
22
  torch.jit.script = lambda f: f
23
 
24
+ device = "cuda" if torch.cuda.is_available() else "cpu"
25
 
26
  ### image_proc.py
27
  def refine_foreground(image, mask, r=90):
 
125
  for idx_image, image_src in enumerate(images):
126
  if isinstance(image_src, str):
127
  if os.path.isfile(image_src):
128
+ image_ori = Image.open(image_src)
129
  else:
130
  response = requests.get(image_src)
131
  image_data = BytesIO(response.content)
132
+ image_ori = Image.open(image_data)
133
  else:
134
+ image_ori = Image.fromarray(image_src)
 
 
 
135
 
136
+ image = image_ori.convert('RGB')
137
  # Preprocess the image
138
  image_preprocessor = ImagePreprocessor(resolution=tuple(resolution))
139
+ image_proc = image_preprocessor.proc(image)
140
  image_proc = image_proc.unsqueeze(0)
141
 
142
  # Prediction
 
146
 
147
  # Show Results
148
  pred_pil = transforms.ToPILImage()(pred)
149
+ image_masked = refine_foreground(image, pred_pil)
150
+ image_masked.putalpha(pred_pil.resize(image.size))
151
 
152
  torch.cuda.empty_cache()
153
 
 
156
  image_masked.save(save_file_path)
157
  save_paths.append(save_file_path)
158
 
 
 
 
 
 
 
159
  if tab_is_batch:
160
  zip_file_path = os.path.join(save_dir, "{}.zip".format(save_dir))
161
  with zipfile.ZipFile(zip_file_path, 'w') as zipf:
 
163
  zipf.write(file, os.path.basename(file))
164
  return save_paths, zip_file_path
165
  else:
166
+ return (image_ori, image_masked)
167
 
168
 
169
  examples = [[_] for _ in glob('examples/*')][:]