Spaces:
Runtime error
Runtime error
from PIL import Image | |
import gradio as gr | |
import gradio.processing_utils | |
from modules import scripts, patches, gr_tempdir | |
# Flipped to True by init() after the gradio patches are installed; init() returns early once it is set.
hijacked = False
# Values returned by patches.patch() in init(); the replacement functions call these to delegate.
(
    original_IOComponent_init,
    original_Block_get_config,
    original_BlockContext_init,
    original_Blocks_get_config_file,
) = (None,) * 4
def gr_image_preprocess(self, x):
    """
    Replacement for ``gradio.components.Image.preprocess`` installed by ``init()``.

    None is returned unchanged. A dict payload is unpacked into its "image" and
    "mask" entries; the image payload is decoded from base64, converted to
    ``self.image_mode`` and, when ``self.shape`` is set, resized/cropped to it.

    For a "sketch" tool on an "upload" source the result is a dict
    ``{"image": ..., "mask": ...}``. A missing mask becomes an all-zero "L" image
    of the image's size, and an RGBA mask is replaced by its alpha channel
    replicated into three RGB bands. In every other case only the formatted
    image is returned.
    """
    if x is None:
        return x

    payload, mask = (x["image"], x["mask"]) if isinstance(x, dict) else (x, None)

    image = gradio.processing_utils.decode_base64_to_image(payload)
    image = image.convert(self.image_mode)
    if self.shape is not None:
        image = gradio.processing_utils.resize_and_crop(image, self.shape)

    is_upload_sketch = self.tool == "sketch" and self.source in ["upload"]
    if not is_upload_sketch:
        return self._format_image(image)  # pylint: disable=protected-access

    if mask is None:
        mask_image = Image.new("L", image.size, 0)
    else:
        mask_image = gradio.processing_utils.decode_base64_to_image(mask)
        if mask_image.mode == "RGBA":  # whiten any opaque pixels in the mask
            alpha = mask_image.getchannel("A").convert("L")
            mask_image = Image.merge("RGB", [alpha] * 3)

    return {"image": self._format_image(image), "mask": self._format_image(mask_image)}  # pylint: disable=protected-access
def add_classes_to_gradio_component(comp):
    """
    Add gradio-* CSS classes to a component (e.g. gradio-button for gr.Button), plus "multiselect"
    for components whose ``multiselect`` attribute is truthy.

    The ``gradio-<block name>`` class is placed first, followed by the component's existing classes.
    ``comp.elem_classes`` may be None, a list of class names, or a single class-name string. A string
    is treated as one class rather than being unpacked character by character.

    Classes this function adds are never duplicated, so calling it more than once on the same
    component is harmless. ``comp.elem_classes`` is replaced with a new list; the previous list
    object is not mutated.
    """
    existing = comp.elem_classes or []
    if isinstance(existing, str):
        # A bare string would otherwise be splatted into single-character "classes".
        existing = [existing]

    type_class = f"gradio-{comp.get_block_name()}"
    classes = [type_class, *(name for name in existing if name != type_class)]

    if getattr(comp, 'multiselect', False) and 'multiselect' not in classes:
        classes.append('multiselect')

    comp.elem_classes = classes
def IOComponent_init(self, *args, **kwargs):
    """
    Replacement for ``gradio.components.IOComponent.__init__`` installed by ``init()``.

    The ``tooltip`` keyword is removed from ``kwargs`` and stored on ``self.webui_tooltip``, so it
    is not forwarded anywhere. The call sequence is:

    1. Before-component hooks: ``scripts.scripts_current`` (when set), then ``scripts.script_callbacks``.
    2. The original constructor.
    3. Addition of the gradio-* CSS classes.
    4. After-component hooks, in reverse order.

    Returns whatever the original constructor returned.
    """
    self.webui_tooltip = kwargs.pop('tooltip', None)  # popped so it is not passed on to the original __init__ or hooks

    runner = scripts.scripts_current
    if runner is not None:
        runner.before_component(self, **kwargs)
    scripts.script_callbacks.before_component_callback(self, **kwargs)

    result = original_IOComponent_init(self, *args, **kwargs)  # pylint: disable=assignment-from-no-return
    add_classes_to_gradio_component(self)

    scripts.script_callbacks.after_component_callback(self, **kwargs)
    runner = scripts.scripts_current  # looked up again rather than reusing the value fetched above
    if runner is not None:
        runner.after_component(self, **kwargs)
    return result
def Block_get_config(self):
    """
    Replacement for ``gradio.blocks.Block.get_config`` installed by ``init()``.

    Takes the config dict produced by the original method and removes its ``example_inputs`` entry
    if present. When the block has a truthy ``webui_tooltip`` attribute, that value is added under
    the ``"webui_tooltip"`` key. Returns the same (mutated) dict.
    """
    config = original_Block_get_config(self)
    config.pop('example_inputs', None)
    tooltip = getattr(self, 'webui_tooltip', None)
    if tooltip:
        config["webui_tooltip"] = tooltip
    return config
def BlockContext_init(self, *args, **kwargs):
    """
    Replacement for ``gradio.blocks.BlockContext.__init__`` installed by ``init()``.

    The call sequence is:

    1. Before-component hooks: ``scripts.scripts_current`` (when set), then ``scripts.script_callbacks``.
    2. The original constructor, called with the unchanged arguments.
    3. Addition of the gradio-* CSS classes.
    4. After-component hooks, in reverse order.

    Returns whatever the original constructor returned.
    """
    current = scripts.scripts_current
    if current is not None:
        current.before_component(self, **kwargs)
    scripts.script_callbacks.before_component_callback(self, **kwargs)

    outcome = original_BlockContext_init(self, *args, **kwargs)  # pylint: disable=assignment-from-no-return
    add_classes_to_gradio_component(self)

    scripts.script_callbacks.after_component_callback(self, **kwargs)
    current = scripts.scripts_current  # looked up again rather than reusing the value fetched above
    if current is not None:
        current.after_component(self, **kwargs)
    return outcome
def Blocks_get_config_file(self, *args, **kwargs):
    """
    Replacement for ``gradio.blocks.Blocks.get_config_file`` installed by ``init()``.

    Forwards all arguments to the original method. Every component config that carries an
    ``example_inputs`` entry then has it replaced with ``{"serialized": []}``. Returns the same
    (mutated) config dict.
    """
    config = original_Blocks_get_config_file(self, *args, **kwargs)
    with_examples = (entry for entry in config["components"] if "example_inputs" in entry)
    for entry in with_examples:
        entry["example_inputs"] = {"serialized": []}  # a fresh dict per component, never shared
    return config
def init():
    """
    Install this module's gradio monkey-patches; once ``hijacked`` is set, further calls return immediately.

    ``Image.preprocess`` and ``IOComponent.pil_to_temp_file`` are replaced by plain attribute
    assignment. The four remaining hooks go through ``patches.patch``. The value each call returns
    is stored in the matching ``original_*`` global, which the replacement functions call to
    delegate.
    """
    global hijacked, original_IOComponent_init, original_Block_get_config, original_BlockContext_init, original_Blocks_get_config_file  # pylint: disable=global-statement
    if hijacked:
        return

    components, blocks = gr.components, gr.blocks

    # Direct replacements: the previous implementations are not kept.
    components.Image.preprocess = gr_image_preprocess
    components.IOComponent.pil_to_temp_file = gr_tempdir.pil_to_temp_file

    original_IOComponent_init = patches.patch(
        __name__, obj=components.IOComponent, field="__init__", replacement=IOComponent_init,
    )
    original_Block_get_config = patches.patch(
        __name__, obj=blocks.Block, field="get_config", replacement=Block_get_config,
    )
    original_BlockContext_init = patches.patch(
        __name__, obj=blocks.BlockContext, field="__init__", replacement=BlockContext_init,
    )
    original_Blocks_get_config_file = patches.patch(
        __name__, obj=blocks.Blocks, field="get_config_file", replacement=Blocks_get_config_file,
    )

    hijacked = True