Spaces: Running on Zero
ZhengPeng7 committed
Commit 53ff575 • Parent(s): fa3042b

Change all predicted results as refined RGBA images. Fix a typo in device specification.
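The device fix restores the standard PyTorch fallback idiom, where the string after else must be a valid torch device name. A minimal sketch of the pattern; the tensor line is a hypothetical illustration, not from app.py:

import torch

# Use the GPU when available, otherwise fall back to the CPU.
# The fallback must be the exact device string "cpu" for torch to parse it.
device = "cuda" if torch.cuda.is_available() else "cpu"
x = torch.zeros(3, device=device)  # hypothetical tensor showing the device in use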
app.py CHANGED

@@ -21,7 +21,7 @@ import zipfile
 torch.set_float32_matmul_precision('high')
 torch.jit.script = lambda f: f
 
-device = "cuda" if torch.cuda.is_available() else "
+device = "cuda" if torch.cuda.is_available() else "cpu"
 
 ### image_proc.py
 def refine_foreground(image, mask, r=90):
@@ -125,20 +125,18 @@ def predict(images, resolution, weights_file):
     for idx_image, image_src in enumerate(images):
         if isinstance(image_src, str):
             if os.path.isfile(image_src):
-
+                image_ori = Image.open(image_src)
             else:
                 response = requests.get(image_src)
                 image_data = BytesIO(response.content)
-
+                image_ori = Image.open(image_data)
         else:
-
-
-        image_shape = image.shape[:2]
-        image_pil = array_to_pil_image(image, tuple(resolution))
+            image_ori = Image.fromarray(image_src)
 
+        image = image_ori.convert('RGB')
         # Preprocess the image
         image_preprocessor = ImagePreprocessor(resolution=tuple(resolution))
-        image_proc = image_preprocessor.proc(
+        image_proc = image_preprocessor.proc(image)
         image_proc = image_proc.unsqueeze(0)
 
         # Prediction
@@ -148,8 +146,8 @@ def predict(images, resolution, weights_file):
 
         # Show Results
         pred_pil = transforms.ToPILImage()(pred)
-        image_masked = refine_foreground(
-        image_masked.putalpha(pred_pil.resize(
+        image_masked = refine_foreground(image, pred_pil)
+        image_masked.putalpha(pred_pil.resize(image.size))
 
         torch.cuda.empty_cache()
@@ -158,12 +156,6 @@ def predict(images, resolution, weights_file):
         image_masked.save(save_file_path)
         save_paths.append(save_file_path)
 
-        # Apply the prediction mask to the original image
-        pred = torch.nn.functional.interpolate(preds, size=image_shape, mode='bilinear', align_corners=True).squeeze().numpy()
-        image_pil = image_pil.resize(pred.shape[::-1])
-        pred = np.repeat(np.expand_dims(pred, axis=-1), 3, axis=-1)
-        image_masked = (pred * np.array(image_pil)).astype(np.uint8)
-
     if tab_is_batch:
         zip_file_path = os.path.join(save_dir, "{}.zip".format(save_dir))
         with zipfile.ZipFile(zip_file_path, 'w') as zipf:
@@ -171,7 +163,7 @@ def predict(images, resolution, weights_file):
             zipf.write(file, os.path.basename(file))
         return save_paths, zip_file_path
     else:
-        return
+        return (image_ori, image_masked)
 
 
 examples = [[_] for _ in glob('examples/*')][:]
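Pieced together from the hunks above: the single-image path now keeps the untouched original and returns it alongside an RGBA cutout, instead of multiplying the mask into the RGB pixels as the deleted block did. A runnable sketch of that flow, with a random tensor standing in for the BiRefNet prediction and a plain copy standing in for app.py's refine_foreground() edge refinement:

import torch
from PIL import Image
from torchvision import transforms

# Hypothetical input path; app.py also accepts URLs and numpy arrays.
image_ori = Image.open("photo.jpg")
image = image_ori.convert("RGB")  # the model consumes plain RGB

# Stand-in for the model output: a random single-channel mask in [0, 1].
pred = torch.rand(1, 1024, 1024)
pred_pil = transforms.ToPILImage()(pred)  # mode "L" mask image

# The commit's new output path: attach the mask as an alpha channel,
# yielding an RGBA image whose background is transparent.
image_masked = image.copy()  # app.py refines edges first via refine_foreground(image, pred_pil)
image_masked.putalpha(pred_pil.resize(image.size))
image_masked.save("cutout.png")  # PNG preserves the alpha channel

result = (image_ori, image_masked)  # what predict() now returns for the single-image tab

The batch path is unchanged: predict() still returns the saved file paths together with a zip of them.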