from types import SimpleNamespace
from typing import List
import os
import sys
import time
import gradio as gr
import numpy as np
import cv2
from PIL import Image, ImageFilter, ImageOps
from transformers import SamModel, SamImageProcessor, MaskGenerationPipeline
from modules import shared, errors, devices, ui_components, ui_symbols, paths, sd_models
from modules.memstats import memory_stats

debug = shared.log.trace if os.environ.get('SD_MASK_DEBUG', None) is not None else lambda *args, **kwargs: None
debug('Trace: MASK')

def get_crop_region(mask, pad=0):
    """finds a rectangular region that contains all masked areas in an image. Returns (x1, y1, x2, y2) coordinates of the rectangle.
       For example, if a user has painted the top-right part of a 512x512 image, the result may be (256, 0, 512, 256)."""
    h, w = mask.shape
    crop_left = 0
    for i in range(w):
        if not (mask[:, i] == 0).all():
            break
        crop_left += 1
    crop_right = 0
    for i in reversed(range(w)):
        if not (mask[:, i] == 0).all():
            break
        crop_right += 1
    crop_top = 0
    for i in range(h):
        if not (mask[i] == 0).all():
            break
        crop_top += 1
    crop_bottom = 0
    for i in reversed(range(h)):
        if not (mask[i] == 0).all():
            break
        crop_bottom += 1
    x1 = max(crop_left - pad, 0)
    y1 = max(crop_top - pad, 0)
    x2 = max(w - crop_right + pad, 0)
    y2 = max(h - crop_bottom + pad, 0)
    if x2 < x1:
        x1, x2 = x2, x1
    if y2 < y1:
        y1, y2 = y2, y1
    crop_region = (
        int(min(x1, w)),
        int(min(y1, h)),
        int(min(x2, w)),
        int(min(y2, h)),
    )
    debug(f'Mask crop: mask={w, h} region={crop_region} pad={pad}')
    return crop_region
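
# Usage sketch for get_crop_region (a hedged example; the mask values are hypothetical):
#   mask = np.zeros((512, 512), dtype=np.uint8)
#   mask[0:256, 256:512] = 255              # user painted the top-right quadrant
#   region = get_crop_region(mask, pad=16)  # -> (240, 0, 512, 272)
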
def expand_crop_region(crop_region, processing_width, processing_height, image_width, image_height):
    """expands the crop region from get_crop_region() to match the aspect ratio of the image the region will be processed in; returns the expanded region.
       For example, if the user drew a mask in a 128x32 region and the dimensions for processing are 512x512, the region will be expanded to 128x128."""
    x1, y1, x2, y2 = crop_region
    ratio_crop_region = (x2 - x1) / (y2 - y1)
    ratio_processing = processing_width / processing_height
    if ratio_crop_region > ratio_processing:
        desired_height = (x2 - x1) / ratio_processing
        desired_height_diff = int(desired_height - (y2 - y1))
        y1 -= desired_height_diff // 2
        y2 += desired_height_diff - desired_height_diff // 2
        if y2 >= image_height:
            diff = y2 - image_height
            y2 -= diff
            y1 -= diff
        if y1 < 0:
            y2 -= y1
            y1 = 0
        if y2 >= image_height:
            y2 = image_height
    else:
        desired_width = (y2 - y1) * ratio_processing
        desired_width_diff = int(desired_width - (x2 - x1))
        x1 -= desired_width_diff // 2
        x2 += desired_width_diff - desired_width_diff // 2
        if x2 >= image_width:
            diff = x2 - image_width
            x2 -= diff
            x1 -= diff
        if x1 < 0:
            x2 -= x1
            x1 = 0
        if x2 >= image_width:
            x2 = image_width
    crop_expand = (
        int(x1),
        int(y1),
        int(x2),
        int(y2),
    )
    debug(f'Mask expand: image={image_width, image_height} processing={processing_width, processing_height} region={crop_expand}')
    return crop_expand
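
# Usage sketch for expand_crop_region (values hypothetical): a 128x32 region expanded
# to match the 1:1 aspect of 512x512 processing grows to 128x128, centered vertically:
#   expand_crop_region((100, 200, 228, 232), 512, 512, 768, 768)  # -> (100, 152, 228, 280)
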
def fill(image, mask):
    """fills masked regions with colors from image using blur. Not extremely effective."""
    image_mod = Image.new('RGBA', (image.width, image.height))
    image_masked = Image.new('RGBa', (image.width, image.height))
    image_masked.paste(image.convert("RGBA").convert("RGBa"), mask=ImageOps.invert(mask.convert('L')))
    image_masked = image_masked.convert('RGBa')
    for radius, repeats in [(256, 1), (64, 1), (16, 2), (4, 4), (2, 2), (0, 1)]:
        blurred = image_masked.filter(ImageFilter.GaussianBlur(radius)).convert('RGBA')
        for _ in range(repeats):
            image_mod.alpha_composite(blurred)
    return image_mod.convert("RGB")
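
# Usage sketch for fill (file names are hypothetical): prime masked holes with blurred
# surrounding colors before inpainting:
#   base = Image.open('input.png').convert('RGB')
#   hole = Image.open('mask.png').convert('L')  # white where masked
#   primed = fill(base, hole)
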
""" | |
[docs](https://huggingface.co/docs/transformers/v4.36.1/en/model_doc/sam#overview) | |
TODO: | |
- PerSAM | |
- REMBG | |
- https://huggingface.co/docs/transformers/tasks/semantic_segmentation | |
- transformers.pipeline.MaskGenerationPipeline: https://huggingface.co/models?pipeline_tag=mask-generation | |
- transformers.pipeline.ImageSegmentationPipeline: https://huggingface.co/models?pipeline_tag=image-segmentation | |
""" | |
MODELS = {
    'None': None,
    'Facebook SAM ViT Base': 'facebook/sam-vit-base',
    'Facebook SAM ViT Large': 'facebook/sam-vit-large',
    'Facebook SAM ViT Huge': 'facebook/sam-vit-huge',
    'SlimSAM Uniform': 'Zigeng/SlimSAM-uniform-50',
    'SlimSAM Uniform Tiny': 'Zigeng/SlimSAM-uniform-77',
    'Rembg Silueta': 'silueta',
    'Rembg U2Net': 'u2net',
    'Rembg ISNet': 'isnet',
    # "u2net_human_seg",
    # "isnet-general-use",
    # "isnet-anime",
}
COLORMAP = ['autumn', 'bone', 'jet', 'winter', 'rainbow', 'ocean', 'summer', 'spring', 'cool', 'hsv', 'pink', 'hot', 'parula', 'magma', 'inferno', 'plasma', 'viridis', 'cividis', 'twilight', 'shifted', 'turbo', 'deepgreen']
TYPES = ['None', 'Opaque', 'Binary', 'Masked', 'Grayscale', 'Color', 'Composite']
cache_dir = 'models/control/segment'
generator: MaskGenerationPipeline = None
busy = False
btn_mask = None
btn_lama = None
lama_model = None
controls = []
opts = SimpleNamespace(**{
    'model': None,
    'auto_mask': 'None',
    'mask_only': False,
    'mask_blur': 0.01,
    'mask_erode': 0.01,
    'mask_dilate': 0.01,
    'seg_iou_thresh': 0.5,
    'seg_score_thresh': 0.5,
    'seg_nms_thresh': 0.5,
    'seg_overlap_ratio': 0.3,
    'seg_points_per_batch': 64,
    'seg_topK': 50,
    'seg_colormap': 'pink',
    'preview_type': 'Composite',
    'seg_live': True,
    'weight_original': 0.5,
    'weight_mask': 0.5,
    'kernel_iterations': 1,
    'invert': False,
})
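
# The module-level opts namespace drives all masking calls below; a minimal sketch of
# adjusting it programmatically (attribute names as defined above, values illustrative):
#   opts.mask_blur = 0.05        # relative to image size, not px
#   opts.preview_type = 'Binary'
#   opts.invert = True
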
def init_model(selected_model: str):
    global busy, generator # pylint: disable=global-statement
    model_path = MODELS[selected_model]
    if model_path is None: # none
        if generator is not None:
            shared.log.debug('Mask segment unloading model')
        opts.model = None
        generator = None
        devices.torch_gc()
        return selected_model
    if 'Rembg' in selected_model: # rembg
        opts.model = model_path
        generator = None
        devices.torch_gc()
        return selected_model
    if opts.model != selected_model or generator is None: # sam pipeline
        busy = True
        t0 = time.time()
        shared.log.debug(f'Mask segment loading: model={selected_model} path={model_path}')
        model = SamModel.from_pretrained(model_path, cache_dir=cache_dir).to(device=devices.device)
        processor = SamImageProcessor.from_pretrained(model_path, cache_dir=cache_dir)
        generator = MaskGenerationPipeline(
            model=model,
            image_processor=processor,
            device=devices.device,
            # output_bboxes_mask=False,
            # output_rle_masks=False,
        )
        devices.torch_gc()
        shared.log.debug(f'Mask segment loaded: model={selected_model} path={model_path} time={time.time()-t0:.2f}s')
        opts.model = selected_model
        busy = False
    return selected_model
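
# Usage sketch for init_model: a SAM entry loads a transformers MaskGenerationPipeline,
# a Rembg entry defers loading to rembg, and 'None' unloads:
#   init_model('Facebook SAM ViT Base')  # loads facebook/sam-vit-base into `generator`
#   init_model('Rembg U2Net')            # sets opts.model='u2net', generator stays None
#   init_model('None')                   # unloads and runs garbage collection
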
def run_segment(input_image: gr.Image, input_mask: np.ndarray):
    outputs = None
    with devices.inference_context():
        try:
            outputs = generator(
                input_image,
                points_per_batch=opts.seg_points_per_batch,
                pred_iou_thresh=opts.seg_iou_thresh,
                stability_score_thresh=opts.seg_score_thresh,
                crops_nms_thresh=opts.seg_nms_thresh,
                crop_overlap_ratio=opts.seg_overlap_ratio,
                crops_n_layers=0,
                crop_n_points_downscale_factor=1,
            )
        except Exception as e:
            shared.log.error(f'Mask segment error: {e}')
            errors.display(e, 'Mask segment')
            return outputs
    devices.torch_gc()
    i = 1
    combined_mask = np.zeros(input_mask.shape, dtype='uint8')
    input_mask_size = np.count_nonzero(input_mask)
    debug(f'Segment SAM: {vars(opts)}')
    for mask in outputs['masks']:
        mask = mask.astype('uint8')
        mask_size = np.count_nonzero(mask)
        if mask_size == 0:
            continue
        overlap = 0
        if input_mask_size > 0:
            if mask.shape != input_mask.shape:
                mask = cv2.resize(mask, (input_mask.shape[1], input_mask.shape[0]), interpolation=cv2.INTER_CUBIC)
            overlap = cv2.bitwise_and(mask, input_mask)
            overlap = np.count_nonzero(overlap)
            if overlap == 0:
                continue
        mask = (opts.seg_topK + 1 - i) * mask * (255 // opts.seg_topK) # set grayscale intensity so we can recolor
        combined_mask = combined_mask + mask
        debug(f'Segment mask: i={i} size={input_image.width}x{input_image.height} masked={mask_size}px overlap={overlap} score={outputs["scores"][i-1]:.2f}')
        i += 1
        if i > opts.seg_topK:
            break
    return combined_mask
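
# Usage sketch for run_segment (assumes init_model() already loaded a SAM pipeline;
# inputs are hypothetical): proposals that do not overlap the user mask are dropped,
# the rest are merged into one grayscale map with per-segment intensities:
#   image = Image.open('photo.png')
#   user_mask = np.ones((image.height, image.width), dtype='uint8')  # keep everything
#   combined = run_segment(image, user_mask)
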
def run_rembg(input_image: Image.Image, input_mask: np.ndarray):
    try:
        import rembg
    except Exception as e:
        shared.log.error(f'Mask Rembg load failed: {e}')
        return input_mask
    if "U2NET_HOME" not in os.environ:
        os.environ["U2NET_HOME"] = os.path.join(paths.models_path, "Rembg")
    args = {
        'data': input_image,
        'only_mask': True,
        'post_process_mask': False,
        'bgcolor': None,
        'alpha_matting': False,
        'alpha_matting_foreground_threshold': 240,
        'alpha_matting_background_threshold': 10,
        'alpha_matting_erode_size': int(opts.mask_erode * 40),
        'session': rembg.new_session(opts.model),
    }
    mask = rembg.remove(**args)
    mask = np.array(mask)
    if len(input_mask.shape) > 2:
        input_mask = cv2.cvtColor(input_mask, cv2.COLOR_RGB2GRAY) # threshold below expects grayscale
    binary_input = cv2.threshold(input_mask, 127, 255, cv2.THRESH_BINARY | cv2.THRESH_OTSU)[1]
    binary_output = cv2.threshold(mask, 127, 255, cv2.THRESH_BINARY | cv2.THRESH_OTSU)[1]
    if binary_input.shape != binary_output.shape:
        binary_output = cv2.resize(binary_output, binary_input.shape[:2], interpolation=cv2.INTER_LINEAR)
    binary_overlap = cv2.bitwise_and(binary_input, binary_output)
    input_size = np.count_nonzero(binary_input)
    overlap_size = np.count_nonzero(binary_overlap)
    debug(f'Segment Rembg: {args} overlap={overlap_size}')
    if input_size > 0 and overlap_size == 0: # user mask and rembg mask do not overlap, so invert the rembg result
        mask = np.invert(mask)
    return mask
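
# Usage sketch for run_rembg (assumes a Rembg entry was selected via init_model and
# the rembg package is installed; inputs hypothetical):
#   image = Image.open('photo.png')
#   user_mask = np.zeros((image.height, image.width), dtype='uint8')
#   subject_mask = run_rembg(image, user_mask)  # grayscale alpha of the detected subject
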
def get_mask(input_image: gr.Image, input_mask: gr.Image):
    t0 = time.time()
    if input_mask is not None:
        output_mask = np.array(input_mask)
        if len(output_mask.shape) > 2:
            output_mask = cv2.cvtColor(output_mask, cv2.COLOR_RGB2GRAY)
        binary_mask = cv2.threshold(output_mask, 127, 255, cv2.THRESH_BINARY | cv2.THRESH_OTSU)[1]
        mask_size = np.count_nonzero(binary_mask)
    else:
        output_mask = None
        mask_size = 0
    if mask_size == 0 and opts.auto_mask != 'None': # no mask or empty mask, so derive one from the image
        output_mask = np.array(input_image)
        if opts.auto_mask == 'Threshold':
            output_mask = cv2.cvtColor(output_mask, cv2.COLOR_RGB2GRAY)
            output_mask = cv2.threshold(output_mask, 127, 255, cv2.THRESH_BINARY | cv2.THRESH_OTSU)[1]
        elif opts.auto_mask == 'Edge':
            output_mask = cv2.cvtColor(output_mask, cv2.COLOR_RGB2GRAY)
            output_mask = cv2.threshold(output_mask, 127, 255, cv2.THRESH_BINARY | cv2.THRESH_OTSU)[1]
            # output_mask = cv2.Canny(output_mask, 50, 150) # run either canny or threshold before contouring
            contours, _hierarchy = cv2.findContours(output_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
            contours = sorted(contours, key=cv2.contourArea, reverse=True) # sort contours by area with largest first
            contours = contours[:opts.seg_topK] # limit to top K contours
            output_mask = np.zeros(output_mask.shape, dtype='uint8')
            largest_size = cv2.contourArea(contours[0]) if len(contours) > 0 else 0
            for i, contour in enumerate(contours):
                area_size = cv2.contourArea(contour)
                luminance = int(255.0 * area_size / largest_size)
                if luminance < 1:
                    break
                cv2.drawContours(output_mask, contours, i, (luminance), -1)
        elif opts.auto_mask == 'Grayscale':
            lab_image = cv2.cvtColor(output_mask, cv2.COLOR_RGB2LAB)
            l_channel, a, b = cv2.split(lab_image)
            clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8)) # applying CLAHE to L-channel
            cl = clahe.apply(l_channel)
            lab_image = cv2.merge((cl, a, b)) # merge the CLAHE enhanced L-channel with the a and b channels
            lab_image = cv2.cvtColor(lab_image, cv2.COLOR_LAB2RGB)
            output_mask = cv2.cvtColor(lab_image, cv2.COLOR_RGB2GRAY)
        t1 = time.time()
        debug(f'Segment auto-mask: mode={opts.auto_mask} time={t1-t0:.2f}')
    return output_mask
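
# Usage sketch for get_mask: with an empty or missing mask and opts.auto_mask set,
# a mask is derived from the image itself (inputs hypothetical):
#   opts.auto_mask = 'Edge'  # or 'Threshold' / 'Grayscale'
#   auto = get_mask(input_image=image, input_mask=None)
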
def outpaint(input_image: Image.Image, outpaint_type: str = 'Edge'):
    image = cv2.cvtColor(np.array(input_image), cv2.COLOR_RGB2BGR)
    h0, w0 = image.shape[:2]
    empty = (image == 0).all(axis=2)
    y0, x0 = np.where(~empty) # coordinates of non-empty pixels
    x1, x2 = min(x0), max(x0)
    y1, y2 = min(y0), max(y0)
    cropped = image[y1:y2, x1:x2]
    h1, w1 = cropped.shape[:2]
    mask = None
    if opts.mask_only:
        mask = cv2.copyMakeBorder(cropped, y1, h0 - y2, x1, w0 - x2, cv2.BORDER_CONSTANT, value=(0, 0, 0))
        mask = cv2.resize(mask, (w0, h0))
        mask = cv2.cvtColor(np.array(mask), cv2.COLOR_BGR2GRAY)
        mask = cv2.threshold(mask, 0, 255, cv2.THRESH_BINARY)[1]
        sigmaX, sigmaY = int((h0 - h1) / 3), int((w0 - w1) / 3)
        kernel = np.ones((5, 5), np.uint8)
        mask = cv2.erode(mask, kernel, iterations=max(sigmaX, sigmaY) // 3) # increase overlap area
        mask = cv2.GaussianBlur(mask, (0, 0), sigmaX=sigmaX, sigmaY=sigmaY) # blur mask
        mask = Image.fromarray(mask)
    if outpaint_type == 'Edge':
        bordered = cv2.copyMakeBorder(cropped, y1, h0 - y2, x1, w0 - x2, cv2.BORDER_REPLICATE)
        bordered = cv2.resize(bordered, (w0, h0))
        image = bordered
        # noise = np.random.normal(1, variation, bordered.shape)
        # noised = (noise * bordered).astype(np.uint8)
        # h, w = cropped.shape[:2]
        # noised[y1:y1 + h, x1:x1 + w] = cropped # overlay original over initialized
        # image = noised
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    image = Image.fromarray(image)
    return image, mask
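
# Usage sketch for outpaint (canvas is hypothetical): black pixels are treated as
# empty, edges are replicated outward, and a soft mask covering the added border is
# also returned when opts.mask_only is set:
#   opts.mask_only = True
#   padded, soft_mask = outpaint(canvas, outpaint_type='Edge')
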
def run_mask(input_image: Image.Image, input_mask: Image.Image = None, return_type: str = None, mask_blur: int = None, mask_padding: int = None, segment_enable=True, invert=None):
    debug(f'Run mask: fn={sys._getframe(1).f_code.co_name}') # pylint: disable=protected-access
    if input_image is None:
        return input_mask
    if isinstance(input_image, list):
        input_image = input_image[0]
    if isinstance(input_image, dict):
        input_mask = input_image.get('mask', None)
        input_image = input_image.get('image', None)
    if input_image is None:
        return input_mask
    t0 = time.time()
    input_mask = get_mask(input_image, input_mask) # perform optional auto-masking
    if input_mask is None:
        return None
    size = min(input_image.width, input_image.height)
    if mask_blur is not None or mask_padding is not None:
        debug(f'Mask args legacy: blur={mask_blur} padding={mask_padding}')
    if invert is not None:
        opts.invert = invert
    if mask_blur is not None: # compatibility with old img2img values which use px values
        opts.mask_blur = round(4 * mask_blur / size, 3)
    if mask_padding is not None: # compatibility with old img2img values which use px values
        opts.mask_dilate = 4 * mask_padding / size
    if opts.model is None or not segment_enable:
        mask = input_mask
    elif generator is None:
        mask = run_rembg(input_image, input_mask)
    else:
        mask = run_segment(input_image, input_mask)
    mask = cv2.resize(mask, (input_image.width, input_image.height), interpolation=cv2.INTER_LINEAR)
    debug(f'Mask shape={mask.shape} opts={opts}')
    if opts.mask_erode > 0:
        try:
            kernel = np.ones((int(opts.mask_erode * size / 4) + 1, int(opts.mask_erode * size / 4) + 1), np.uint8)
            mask = cv2.erode(mask, kernel, iterations=opts.kernel_iterations) # remove noise
            debug(f'Mask erode={opts.mask_erode:.3f} kernel={kernel.shape} mask={mask.shape}')
        except Exception as e:
            shared.log.error(f'Mask erode: {e}')
    if opts.mask_dilate > 0:
        try:
            kernel = np.ones((int(opts.mask_dilate * size / 4) + 1, int(opts.mask_dilate * size / 4) + 1), np.uint8)
            mask = cv2.dilate(mask, kernel, iterations=opts.kernel_iterations) # expand area
            debug(f'Mask dilate={opts.mask_dilate:.3f} kernel={kernel.shape} mask={mask.shape}')
        except Exception as e:
            shared.log.error(f'Mask dilate: {e}')
    if opts.mask_blur > 0:
        try:
            sigmax, sigmay = 1 + int(opts.mask_blur * size / 4), 1 + int(opts.mask_blur * size / 4)
            mask = cv2.GaussianBlur(mask, (0, 0), sigmaX=sigmax, sigmaY=sigmay) # blur mask
            debug(f'Mask blur={opts.mask_blur:.3f} x={sigmax} y={sigmay} mask={mask.shape}')
        except Exception as e:
            shared.log.error(f'Mask blur: {e}')
    if opts.invert:
        mask = np.invert(mask)
    mask_size = np.count_nonzero(mask)
    total_size = np.prod(mask.shape)
    t1 = time.time()
    return_type = return_type or opts.preview_type
    shared.log.debug(f'Mask: size={input_image.width}x{input_image.height} masked={mask_size}px area={mask_size/total_size:.2f} auto={opts.auto_mask} blur={opts.mask_blur} erode={opts.mask_erode} dilate={opts.mask_dilate} type={return_type} time={t1-t0:.2f}')
    if return_type == 'None':
        return input_mask
    elif return_type == 'Opaque':
        binary_mask = cv2.threshold(mask, 0, 255, cv2.THRESH_BINARY)[1]
        return Image.fromarray(binary_mask)
    elif return_type == 'Binary':
        binary_mask = cv2.threshold(mask, 127, 255, cv2.THRESH_BINARY | cv2.THRESH_OTSU)[1] # otsu uses mean instead of fixed threshold
        return Image.fromarray(binary_mask)
    elif return_type == 'Masked':
        orig = np.array(input_image)
        mask = cv2.cvtColor(mask, cv2.COLOR_GRAY2RGB)
        masked_image = cv2.bitwise_and(orig, mask)
        return Image.fromarray(masked_image)
    elif return_type == 'Grayscale':
        return Image.fromarray(mask)
    elif return_type == 'Color':
        colored_mask = cv2.applyColorMap(mask, COLORMAP.index(opts.seg_colormap)) # recolor mask
        return Image.fromarray(colored_mask)
    elif return_type == 'Composite':
        colored_mask = cv2.applyColorMap(mask, COLORMAP.index(opts.seg_colormap)) # recolor mask
        orig = np.array(input_image)
        combined_image = cv2.addWeighted(orig, opts.weight_original, colored_mask, opts.weight_mask, 0)
        return Image.fromarray(combined_image)
    else:
        shared.log.error(f'Mask unknown return type: {return_type}')
        return input_mask
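
# Usage sketch for run_mask, the main entry point (inputs hypothetical): auto-mask,
# optional segmentation, then erode/dilate/blur post-processing and the requested
# return type:
#   preview = run_mask(image, drawn_mask, return_type='Composite')
#   binary = run_mask(image, drawn_mask, return_type='Binary', mask_blur=4, mask_padding=32)
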
def run_lama(input_image: gr.Image, input_mask: gr.Image = None):
    global lama_model # pylint: disable=global-statement
    if isinstance(input_image, dict):
        input_mask = input_image.get('mask', None)
        input_image = input_image.get('image', None)
    if input_image is None:
        return None
    input_mask = run_mask(input_image, input_mask, return_type='Grayscale')
    if lama_model is None:
        import modules.lama
        shared.log.debug(f'Mask LaMa loading: model={modules.lama.LAMA_MODEL_URL}')
        lama_model = modules.lama.SimpleLama()
        shared.log.debug(f'Mask LaMa loaded: {memory_stats()}')
    sd_models.move_model(lama_model.model, devices.device)
    result = lama_model(input_image, input_mask)
    if shared.opts.control_move_processor:
        lama_model.model.to('cpu')
    return result
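
# Usage sketch for run_lama (inputs hypothetical): builds a grayscale mask via
# run_mask, lazily loads the LaMa inpainting model, and returns the inpainted image:
#   cleaned = run_lama(image, drawn_mask)
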
def run_mask_live(input_image: gr.Image):
    global busy # pylint: disable=global-statement
    if opts.seg_live:
        if not busy:
            busy = True
            res = run_mask(input_image)
            busy = False
            return res
    return None
def create_segment_ui():
    def update_opts(*args):
        opts.seg_live = args[0]
        opts.mask_only = args[1]
        opts.invert = args[2]
        opts.mask_blur = args[3]
        opts.mask_erode = args[4]
        opts.mask_dilate = args[5]
        opts.auto_mask = args[6]
        opts.seg_score_thresh = args[7]
        opts.seg_iou_thresh = args[8]
        opts.seg_nms_thresh = args[9]
        opts.preview_type = args[10]
        opts.seg_colormap = args[11]

    global btn_mask, btn_lama # pylint: disable=global-statement
    with gr.Accordion(open=False, label="Mask", elem_id="control_mask", elem_classes=["small-accordion"]):
        controls.clear()
        with gr.Row():
            controls.append(gr.Checkbox(label="Live update", value=True))
            btn_mask = ui_components.ToolButton(value=ui_symbols.refresh, visible=True)
            btn_lama = ui_components.ToolButton(value=ui_symbols.image, visible=True)
        with gr.Row():
            controls.append(gr.Checkbox(label="Inpaint masked only", value=False))
            controls.append(gr.Checkbox(label="Invert mask", value=False))
        with gr.Row():
            controls.append(gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label='Blur', value=0.01, elem_id="control_mask_blur"))
            controls.append(gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label='Erode', value=0.01, elem_id="control_mask_erode"))
            controls.append(gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label='Dilate', value=0.01, elem_id="control_mask_dilate"))
        with gr.Row():
            controls.append(gr.Dropdown(label="Auto-mask", choices=['None', 'Threshold', 'Edge', 'Grayscale'], value='None'))
            selected_model = gr.Dropdown(label="Auto-segment", choices=list(MODELS.keys()), value='None')
        with gr.Row():
            controls.append(gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label='Score', value=0.5, visible=False))
            controls.append(gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label='IOU', value=0.5, visible=False))
            controls.append(gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label='NMS', value=0.5, visible=False))
        with gr.Row():
            controls.append(gr.Dropdown(label="Preview", choices=['None', 'Masked', 'Binary', 'Grayscale', 'Color', 'Composite'], value='Composite'))
            controls.append(gr.Dropdown(label="Colormap", choices=COLORMAP, value='pink'))
        selected_model.change(fn=init_model, inputs=[selected_model], outputs=[selected_model])
        for control in controls:
            control.change(fn=update_opts, inputs=controls, outputs=[])
    return controls
def bind_controls(image_controls: List[gr.Image], preview_image: gr.Image, output_image: gr.Image):
    for image_control in image_controls:
        btn_mask.click(run_mask, inputs=[image_control], outputs=[preview_image])
        btn_lama.click(run_lama, inputs=[image_control], outputs=[output_image])
        image_control.edit(fn=run_mask_live, inputs=[image_control], outputs=[preview_image])
        for control in controls:
            control.change(fn=run_mask_live, inputs=[image_control], outputs=[preview_image])