import spaces
import gradio as gr
import subprocess
from PIL import Image, ImageEnhance, ImageFilter, ImageDraw
import json
import numpy as np
from skimage.exposure import match_histograms
from gradio_utils import save_image
from color_utils import (
    simple_white_balance,
    apply_tone_curve,
    curve_midtones,
    create_left_half_mask,
    create_top_half_mask,
    create_compare_image,
    mirror,
)


def color_match(base_image, cropped_image, color_match_format="RGB"):
    """Match the per-channel histograms of ``cropped_image`` to ``base_image``.

    Both images are converted to ``color_match_format`` (a PIL mode name) and
    ``skimage.exposure.match_histograms`` is run over the last (channel) axis.

    Args:
        base_image: PIL image whose colour distribution is the reference.
        cropped_image: PIL image to be recoloured.
        color_match_format: PIL mode used for matching and for the result.

    Returns:
        A PIL image in ``color_match_format`` mode with the pixel dimensions
        of ``cropped_image``.
    """
    # NOTE(review): the UI states that formats other than RGB/CMYK are broken;
    # PIL may not support every conversion from the editors' RGBA images.
    reference = np.array(base_image.convert(color_match_format))
    target = np.array(cropped_image.convert(color_match_format))
    matched = match_histograms(target, reference, channel_axis=-1)
    # The mode must be given explicitly: e.g. a 4-channel CMYK array would
    # otherwise be inferred as RGBA.
    return Image.fromarray(matched, mode=color_match_format)


def apply_layer(image_dict):
    """Flatten the first editor layer onto the editor's background image.

    Args:
        image_dict: Gradio ImageEditor value. Reads the ``"background"`` PIL
            image and, when present and non-empty, the first entry of
            ``"layers"``.

    Returns:
        An RGBA PIL image: the background with every painted pixel of the
        first layer pasted over it (fully opaque).
    """
    base_rgba = image_dict["background"].convert("RGBA")
    layers = image_dict.get("layers") or []
    if layers:
        layer = layers[0]
        layer_rgba = layer.convert("RGBA")
        if "A" in layer.getbands():
            # Use coverage (alpha) rather than luminance: a stroke painted in
            # pure black has luminance 0 and would otherwise be ignored.
            coverage = layer_rgba.getchannel("A")
        else:
            # No alpha channel: fall back to treating any non-black pixel as paint.
            coverage = layer.convert("L")
        # Binarise so partially covered (anti-aliased) pixels are pasted fully.
        mask = coverage.point(lambda value: 255 if value > 0 else 0)
        base_rgba.paste(layer_rgba, (0, 0), mask)
    return base_rgba


def create_enhanced_image(reference_image, brightness=1.0, color=1.0, contrast=1.0,
                          use_whitebalance=False, top_whitebalance=1):
    """Optionally white-balance, then adjust brightness, colour and contrast.

    Each PIL ``ImageEnhance`` step is skipped when its factor is exactly 1.0.
    Steps are applied in the order: white balance, brightness, colour,
    contrast.

    Returns:
        The adjusted PIL image (the input object itself if nothing applied).
    """
    if use_whitebalance:
        reference_image = simple_white_balance(reference_image, top_whitebalance)
    adjustments = (
        (ImageEnhance.Brightness, brightness),
        (ImageEnhance.Color, color),
        (ImageEnhance.Contrast, contrast),
    )
    for enhancer_cls, factor in adjustments:
        if factor != 1.0:
            reference_image = enhancer_cls(reference_image).enhance(factor)
    return reference_image


def _append_compare_pair(images, reference_image, candidate, mask, labels):
    """Blend ``candidate`` with ``reference_image`` through ``mask`` and save both halves.

    ``create_compare_image`` yields two images; each is converted to RGB,
    saved as webp and appended to ``images`` as ``(saved, label)`` in order.
    """
    first_half, second_half = create_compare_image(reference_image, candidate, mask)
    for mixed, label in zip((first_half, second_half), labels):
        images.append((save_image(mixed.convert("RGB"), extension="webp"), label))


def process_images(
    reference_image_dict,
    target_image_dict,
    mirror_target=False,
    middle_tone_value=0.75,
    color_match_format="RGB",
    progress=gr.Progress(track_tqdm=True),
):
    """Colour-match the target image to the reference and build comparison images.

    Args:
        reference_image_dict: ImageEditor value for the reference image.
        target_image_dict: ImageEditor value for the image to recolour.
        mirror_target: When true, ``mirror`` is applied to the target first.
        middle_tone_value: Parameter passed to ``apply_tone_curve`` with
            ``curve_midtones``.
        color_match_format: PIL mode used by ``color_match``.
        progress: Gradio progress tracker.

    Returns:
        A list of ``(saved_image, label)`` tuples for the output gallery: the
        colour-matched image, its left/right and top/bottom comparisons, then
        the tone-curved image and its comparisons.

    Raises:
        gr.Error: If either input is missing, is not a dict, or has no
            background image.
    """
    progress(0, desc="Start color matching")
    if reference_image_dict is None:
        raise gr.Error("Need reference_image")
    if target_image_dict is None:
        raise gr.Error("Need target_image")
    if not isinstance(reference_image_dict, dict):
        raise gr.Error("Need DictData reference_image_dict")
    if not isinstance(target_image_dict, dict):
        raise gr.Error("Need DictData target_image_dict")
    # apply_layer calls .convert() on "background"; report a missing
    # background as a user-facing error instead of an AttributeError.
    if reference_image_dict.get("background") is None:
        raise gr.Error("Need reference_image")
    if target_image_dict.get("background") is None:
        raise gr.Error("Need target_image")

    reference_image = apply_layer(reference_image_dict)
    target_image = apply_layer(target_image_dict)
    if mirror_target:
        target_image = mirror(target_image)

    images = []
    left_mask = create_left_half_mask(reference_image)
    top_mask = create_top_half_mask(reference_image)

    # Plain histogram match. Comparisons need the reference size; the saved
    # result keeps the matched image's own size.
    color_matched = color_match(reference_image, target_image, color_match_format)
    color_matched_resized = color_matched.resize(reference_image.size)
    images.append((save_image(color_matched.convert("RGB")), "color matched"))
    progress(0.2)
    _append_compare_pair(images, reference_image, color_matched_resized,
                         left_mask, ("mixed_left", "mixed_right"))
    progress(0.4)
    _append_compare_pair(images, reference_image, color_matched_resized,
                         top_mask, ("mixed_top", "mixed_bottom"))
    progress(0.6)

    # Same comparisons after a mid-tone curve adjustment.
    color_matched_tone = apply_tone_curve(color_matched.convert("RGB"), curve_midtones, middle_tone_value)
    color_matched_tone_resized = color_matched_tone.resize(reference_image.size)
    images.append((save_image(color_matched_tone.convert("RGB")), "tone-curved"))
    _append_compare_pair(images, reference_image, color_matched_tone_resized,
                         left_mask, ("mixed_left", "mixed_right"))
    progress(0.8)
    _append_compare_pair(images, reference_image, color_matched_tone_resized,
                         top_mask, ("mixed_top", "mixed_bottom"))
    progress(1.0)
    return images


def read_file(file_path: str) -> str:
    """Return the full UTF-8 text content of ``file_path``."""
    with open(file_path, 'r', encoding='utf-8') as f:
        content = f.read()
    return content


css = """
#col-left {
    margin: 0 auto;
    max-width: 640px;
}
#col-right {
    margin: 0 auto;
    max-width: 640px;
}
.grid-container {
    display: flex;
    align-items: center;
    justify-content: center;
    gap:10px
}

.image {
    width: 128px;
    height: 128px;
    object-fit: cover;
}

.text {
    font-size: 16px;
}
"""


def color_changed(color):
    """Return an ImageEditor update whose brush is fixed to ``color``, once per editor."""
    # NOTE(review): original note says "mode must be RGBA" — presumably the
    # editors' image_mode; confirm.
    editor = gr.ImageEditor(brush=gr.Brush(colors=[color], color_mode="fixed"))
    return editor, editor


def update_button_label(image):
    """Toggle two buttons and two rows depending on whether ``image`` is set.

    With no image the first button and both rows are shown and the second
    button hidden; otherwise the reverse.
    """
    if image is None:
        return gr.Button(visible=True), gr.Button(visible=False), gr.Row(visible=True), gr.Row(visible=True)
    return gr.Button(visible=False), gr.Button(visible=True), gr.Row(visible=False), gr.Row(visible=False)


def update_visible(fill_color_mode, image):
    """Return visibility updates for two rows.

    Both rows are hidden when ``image`` is set; otherwise only the second row
    is shown in fill-colour mode, and only the first row otherwise.
    """
    if image is not None:
        return gr.Row(visible=False), gr.Row(visible=False)
    if fill_color_mode:
        return gr.Row(visible=False), gr.Row(visible=True)
    return gr.Row(visible=True), gr.Row(visible=False)


with gr.Blocks(css=css, elem_id="demo-container") as demo:
    with gr.Column():
        gr.HTML(read_file("demo_header.html"))
        gr.HTML(read_file("demo_tools.html"))
    with gr.Row():
        with gr.Column():
            reference_image = gr.ImageEditor(
                height=1050, sources=['upload', 'clipboard'], layers=False, transforms=[],
                image_mode='RGBA', elem_id="image_upload", type="pil", label="Reference Image",
                brush=gr.Brush(colors=["#001"], color_mode="fixed"),
            )
            with gr.Row(elem_id="prompt-container", equal_height=False):
                btn1 = gr.Button("Color Match", elem_id="run_button", variant="primary")
                mirror_target = gr.Checkbox(label="Mirror target", value=False)
            pick = gr.ColorPicker(
                label="color", value="#001",
                info="ImageEditor color is broken,pick color from here.reselect paint-tool and draw.but not so effective",
            )
            target_image = gr.ImageEditor(
                height=1050, sources=['upload', 'clipboard'], layers=False, transforms=[],
                image_mode='RGBA', elem_id="image_upload", type="pil", label="Target Image",
                brush=gr.Brush(colors=["#001"], color_mode="fixed"),
            )
            pick.change(fn=color_changed, inputs=[pick], outputs=[reference_image, target_image])

            with gr.Accordion(label="Advanced Settings", open=False):
                # Section heading for the post-processing controls.
                gr.HTML("<h4>Post-Process Target Image</h4>")
                with gr.Row(equal_height=True):
                    middle_tone_value = gr.Slider(
                        label="middle tone", minimum=0, maximum=2.0, step=0.01, value=0.75)
                    color_match_format = gr.Dropdown(
                        label="Format", choices=["RGB", "CMYK", "YCbCr", "HSV", "LAB"], value="RGB",
                        info="RGB and CMYK seems same,others are broken",
                    )

        with gr.Column():
            image_out = gr.Gallery(height=800, label="Output", elem_id="output-img", format="webp", preview=True)

    gr.on(
        [btn1.click],
        fn=process_images,
        inputs=[reference_image, target_image, mirror_target, middle_tone_value, color_match_format],
        outputs=[image_out],
        api_name='infer',
    )
    gr.Examples(
        examples=[
            ["examples/face01.webp", "examples/face02.webp"],
            ["examples/face01.webp", "examples/face03.webp"],
            ["examples/face01.webp", "examples/face04.webp"],
            ["examples/face02.webp", "examples/face03.webp"],
            ["examples/face02.webp", "examples/face04.webp"],
            ["examples/face03.webp", "examples/face04.webp"],
        ],
        inputs=[reference_image, target_image],
    )
    gr.HTML(read_file("demo_footer.html"))


if __name__ == "__main__":
    demo.launch()