MonsterMMORPG
committed
Commit 604e3cb
Parent(s): 47565bc
Upload RemoveBG_By_SECourses.py
RemoveBG_By_SECourses.py +200 -0
RemoveBG_By_SECourses.py
ADDED
@@ -0,0 +1,200 @@
import os
import cv2
import numpy as np
import torch
import gradio as gr
import argparse
from pathlib import Path
from glob import glob
from typing import Optional, Tuple, List
from PIL import Image
from transformers import AutoModelForImageSegmentation
from torchvision import transforms
import time
import platform

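# Command-line options: --share exposes the Gradio interface via a public share link.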
def parse_args():
    parser = argparse.ArgumentParser(description="Run the image segmentation app")
    parser.add_argument("--share", action="store_true", help="Enable sharing of the Gradio interface")
    return parser.parse_args()

torch.set_float32_matmul_precision('high')
torch.jit.script = lambda f: f

os.environ['HOME'] = os.path.expanduser('~')

device = "cuda" if torch.cuda.is_available() else "cpu"

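# Opens the "results" folder in the system file browser (Explorer on Windows, xdg-open on Linux).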
def open_folder():
    open_folder_path = os.path.abspath("results")
    if platform.system() == "Windows":
        os.startfile(open_folder_path)
    elif platform.system() == "Linux":
        os.system(f'xdg-open "{open_folder_path}"')

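# Converts a PIL image to a tensor normalized with ImageNet mean/std, as expected by BiRefNet.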
class ImagePreprocessor():
    def __init__(self, resolution: Tuple[int, int] = (1024, 1024)) -> None:
        self.transform_image = transforms.Compose([
            transforms.ToTensor(),
        ])
        self.normalize = transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])

    def proc(self, image: Image.Image) -> torch.Tensor:
        image = image.convert('RGB')  # Convert to RGB
        image = self.transform_image(image)
        return self.normalize(image)

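# Maps the UI weight choices to BiRefNet checkpoints in the zhengpeng7 namespace on Hugging Face.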
usage_to_weights_file = {
    'General': 'BiRefNet',
    'General-Lite': 'BiRefNet_T',
    'Portrait': 'BiRefNet-portrait',
    'DIS': 'BiRefNet-DIS5K',
    'HRSOD': 'BiRefNet-HRSOD',
    'COD': 'BiRefNet-COD',
    'DIS-TR_TEs': 'BiRefNet-DIS5K-TR_TEs'
}

birefnet = AutoModelForImageSegmentation.from_pretrained('/'.join(('zhengpeng7', usage_to_weights_file['General'])), trust_remote_code=True)
birefnet.to(device)
birefnet.eval()

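# Segments one image with BiRefNet and saves an RGBA PNG whose alpha channel is the predicted foreground mask.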
def process_single_image(image_path: str, resolution: str, output_folder: str) -> Tuple[str, str, float]:
    start_time = time.time()

    image = Image.open(image_path).convert('RGBA')

    if resolution == '':
        resolution = f"{image.width}x{image.height}"
    resolution = [int(int(reso)//32*32) for reso in resolution.strip().split('x')]

    image_shape = image.size[::-1]
    image_pil = image.resize(tuple(resolution))

    image_preprocessor = ImagePreprocessor(resolution=tuple(resolution))
    image_proc = image_preprocessor.proc(image_pil)
    image_proc = image_proc.unsqueeze(0)

    with torch.no_grad():
        scaled_pred_tensor = birefnet(image_proc.to(device))[-1].sigmoid()

    if device == 'cuda':
        scaled_pred_tensor = scaled_pred_tensor.cpu()

    pred = torch.nn.functional.interpolate(scaled_pred_tensor, size=image_shape, mode='bilinear', align_corners=True).squeeze().numpy()

    pred_rgba = np.zeros((*pred.shape, 4), dtype=np.uint8)
    pred_rgba[..., :3] = (pred[..., np.newaxis] * 255).astype(np.uint8)
    pred_rgba[..., 3] = (pred * 255).astype(np.uint8)

    image_array = np.array(image)
    image_pred = image_array * (pred_rgba / 255.0)

    output_image = Image.fromarray(image_pred.astype(np.uint8), 'RGBA')

    base_filename = os.path.splitext(os.path.basename(image_path))[0]
    output_path = os.path.join(output_folder, f"{base_filename}.png")

    counter = 1
    while os.path.exists(output_path):
        output_path = os.path.join(output_folder, f"{base_filename}_{counter:04d}.png")
        counter += 1

    output_image.save(output_path)

    processing_time = time.time() - start_time
    print(f"Processed {image_path} in {processing_time:.4f} seconds")
    return image_path, output_path, processing_time

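# Gradio callback: reloads the selected checkpoint, then processes either a single image or every file in batch_folder.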
def predict(
    image: str,
    resolution: str,
    weights_file: Optional[str],
    batch_folder: Optional[str] = None,
    output_folder: Optional[str] = None,
    is_batch: bool = False
) -> Tuple[str, List[Tuple[str, str]]]:
    global birefnet
    _weights_file = '/'.join(('zhengpeng7', usage_to_weights_file[weights_file] if weights_file is not None else usage_to_weights_file['General']))
    print('Using weights:', _weights_file)
    birefnet = AutoModelForImageSegmentation.from_pretrained(_weights_file, trust_remote_code=True)
    birefnet.to(device)
    birefnet.eval()

    if not output_folder:
        output_folder = 'results'
    os.makedirs(output_folder, exist_ok=True)

    results = []

    if is_batch and batch_folder:
        image_files = glob(os.path.join(batch_folder, '*'))
        total_images = len(image_files)
        processed_images = 0
        start_time = time.time()

        for img_path in image_files:
            try:
                input_path, output_path, proc_time = process_single_image(img_path, resolution, output_folder)
                results.append((output_path, f"{proc_time:.4f} seconds"))
                processed_images += 1
                elapsed_time = time.time() - start_time
                avg_time_per_image = elapsed_time / processed_images
                estimated_time_left = avg_time_per_image * (total_images - processed_images)

                status = f"Processed {processed_images}/{total_images} images. Estimated time left: {estimated_time_left:.2f} seconds"
                print(status)
            except Exception as e:
                print(f"Error processing {img_path}: {str(e)}")
                continue

        return f"Batch processing complete. Processed {processed_images}/{total_images} images.", results
    else:
        input_path, output_path, proc_time = process_single_image(image, resolution, output_folder)
        results.append((output_path, f"{proc_time:.4f} seconds"))
        return "Single image processing complete.", results

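# Builds the Gradio Blocks UI: image input/gallery output, resolution and checkpoint controls, and single/batch processing buttons.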
def create_interface():
    with gr.Blocks() as demo:
        gr.Markdown("## SECourses Improved BiRefNet V1 'Bilateral Reference for High-Resolution Dichotomous Image Segmentation' APP - SOTA Background Remover")
        gr.Markdown("## Most Advanced Latest Version On : https://www.patreon.com/posts/109913645")

        with gr.Row():
            input_image = gr.Image(type="filepath", label="Input Image", height=512)
            output_image = gr.Gallery(label="Output Image", elem_id="gallery", height=512)

        with gr.Row():
            resolution = gr.Textbox(label="Resolution", placeholder="1024x1024 - Optional - Don't enter to use original image resolution - Higher res uses more VRAM but still works perfect with shared VRAM so fast")
            weights_file = gr.Dropdown(choices=list(usage_to_weights_file.keys()), value="General", label="Weights File")
            btn_open_outputs = gr.Button("Open Results Folder")
            btn_open_outputs.click(fn=open_folder)

        with gr.Row():
            batch_folder = gr.Textbox(label="Batch Folder Path")
            output_folder = gr.Textbox(label="Output Folder Path", value="results")

        with gr.Row():
            submit_button = gr.Button("Process")
            batch_button = gr.Button("Process Batch")

        output_text = gr.Textbox(label="Processing Status")

        submit_button.click(
            predict,
            inputs=[input_image, resolution, weights_file, batch_folder, output_folder, gr.Checkbox(value=False, visible=False)],
            outputs=[output_text, output_image]
        )

        batch_button.click(
            predict,
            inputs=[input_image, resolution, weights_file, batch_folder, output_folder, gr.Checkbox(value=True, visible=False)],
            outputs=[output_text, output_image]
        )

    return demo

if __name__ == "__main__":
    args = parse_args()
    demo = create_interface()
    demo.launch(inbrowser=True, share=args.share)
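
For reference, a minimal sketch of driving the pipeline without the Gradio UI, assuming this script is importable as RemoveBG_By_SECourses and that a file named input.jpg exists (both names are hypothetical):

from RemoveBG_By_SECourses import predict  # hypothetical import; loads the BiRefNet model at import time

# An empty resolution string keeps the original image size; "General" selects the default BiRefNet checkpoint.
status, results = predict(
    image="input.jpg",       # hypothetical input file
    resolution="",
    weights_file="General",
    output_folder="results",
)
print(status)  # "Single image processing complete."
for output_path, proc_time in results:
    print(output_path, proc_time)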