Akjava commited on
Commit
c96d305
1 Parent(s): 9bb2601
.gitignore ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ __pycache__
2
+ files
app.py ADDED
@@ -0,0 +1,218 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import spaces
2
+ import gradio as gr
3
+ import subprocess
4
+ from PIL import Image,ImageEnhance,ImageFilter,ImageDraw
5
+ import json
6
+ import numpy as np
7
+ from skimage.exposure import match_histograms
8
+ from gradio_utils import save_image
9
+ from color_utils import simple_white_balance,apply_tone_curve,curve_midtones,create_left_half_mask,create_top_half_mask,create_compare_image ,mirror
10
+
11
+ def color_match(base_image,cropped_image):
12
+ reference = np.array(base_image)
13
+ target =np.array(cropped_image)
14
+ matched = match_histograms(target, reference,channel_axis=-1)
15
+ return Image.fromarray(matched)
16
+
17
+
18
+
19
+ def apply_layer(image_dict):
20
+ base_rgba = image_dict["background"].convert("RGBA")
21
+ if len(image_dict['layers']) > 0:
22
+ layer = image_dict['layers'][0]
23
+ mask = layer.convert("L") # グレイスケールに変換
24
+ mask=mask.point(lambda x: 255 if x > 0 else x)
25
+
26
+
27
+ layer_rgba = layer.convert("RGBA")
28
+ base_rgba.paste(layer_rgba, (0, 0),mask)
29
+ return base_rgba
30
+
31
+
32
+ def create_enhanced_image(reference_image,brightness=1.0,color=1.0,contrast=1.0,use_whitebalance=False,top_whitebalance=1):
33
+
34
+ if use_whitebalance:
35
+ reference_image = simple_white_balance(reference_image,top_whitebalance)
36
+
37
+ if brightness!=1.0:
38
+ brightness_enhancer = ImageEnhance.Brightness(reference_image)
39
+ reference_image = brightness_enhancer.enhance(brightness)
40
+
41
+ if color!=1.0:
42
+ color_enhancer = ImageEnhance.Color(reference_image)
43
+ reference_image = color_enhancer.enhance(color)
44
+
45
+ if contrast!=1.0:
46
+ contrast_enhancer = ImageEnhance.Contrast(reference_image)
47
+ reference_image = contrast_enhancer.enhance(contrast)
48
+
49
+ return reference_image
50
+
51
+
52
+
53
+
54
+ def process_images(reference_image_dict,target_image_dict,mirror_target=False,middle_tone_value=0.75):
55
+ if reference_image_dict == None:
56
+ raise gr.Error("Need reference_image")
57
+
58
+ if target_image_dict == None:
59
+ raise gr.Error("Need target_image")
60
+
61
+ if not isinstance(reference_image_dict, dict):
62
+ raise gr.Error("Need DictData reference_image_dict")
63
+
64
+ if not isinstance(target_image_dict, dict):
65
+ raise gr.Error("Need DictData target_image_dict")
66
+
67
+ reference_image = apply_layer(reference_image_dict)
68
+ target_image = apply_layer(target_image_dict)
69
+ if mirror_target:
70
+ target_image = mirror(target_image)
71
+
72
+
73
+ images = []
74
+
75
+ left_mask = create_left_half_mask(reference_image)
76
+ top_mask = create_top_half_mask(reference_image)
77
+
78
+ color_matched = color_match(reference_image,target_image)
79
+ color_matched_resized = color_matched.resize(reference_image.size)
80
+ matched_path = save_image(color_matched.convert("RGB"))
81
+ images.append((matched_path,"color matched"))
82
+
83
+ reference_mix_left,reference_mix_right = create_compare_image(reference_image,color_matched_resized,left_mask)
84
+ images.append((save_image(reference_mix_left.convert("RGB"),extension="webp"),"mixed_left"))
85
+ images.append((save_image(reference_mix_right.convert("RGB"),extension="webp"),"mixed_right"))
86
+
87
+ reference_mix_top,reference_mix_bottom = create_compare_image(reference_image,color_matched_resized,top_mask)
88
+ images.append((save_image(reference_mix_top.convert("RGB"),extension="webp"),"mixed_top"))
89
+ images.append((save_image(reference_mix_bottom.convert("RGB"),extension="webp"),"mixed_bottom"))
90
+
91
+ color_matched_tone = apply_tone_curve(color_matched.convert("RGB"),curve_midtones,middle_tone_value)
92
+ color_matched_tone_resized = color_matched_tone.resize(reference_image.size)
93
+
94
+ images.append((save_image(color_matched_tone.convert("RGB")),"tone-curved"))
95
+ reference_mix_left,reference_mix_right = create_compare_image(reference_image,color_matched_tone_resized,left_mask)
96
+ images.append((save_image(reference_mix_left.convert("RGB"),extension="webp"),"mixed_left"))
97
+ images.append((save_image(reference_mix_right.convert("RGB"),extension="webp"),"mixed_right"))
98
+
99
+ reference_mix_top,reference_mix_bottom = create_compare_image(reference_image,color_matched_tone_resized,top_mask)
100
+ images.append((save_image(reference_mix_top.convert("RGB"),extension="webp"),"mixed_top"))
101
+ images.append((save_image(reference_mix_bottom.convert("RGB"),extension="webp"),"mixed_bottom"))
102
+
103
+
104
+ return images
105
+
106
+
107
+
108
+ def read_file(file_path: str) -> str:
109
+ """read the text of target file
110
+ """
111
+ with open(file_path, 'r', encoding='utf-8') as f:
112
+ content = f.read()
113
+
114
+ return content
115
+
116
+ css="""
117
+ #col-left {
118
+ margin: 0 auto;
119
+ max-width: 640px;
120
+ }
121
+ #col-right {
122
+ margin: 0 auto;
123
+ max-width: 640px;
124
+ }
125
+ .grid-container {
126
+ display: flex;
127
+ align-items: center;
128
+ justify-content: center;
129
+ gap:10px
130
+ }
131
+
132
+ .image {
133
+ width: 128px;
134
+ height: 128px;
135
+ object-fit: cover;
136
+ }
137
+
138
+ .text {
139
+ font-size: 16px;
140
+ }
141
+ """
142
+
143
+ def color_changed(color):
144
+ #mode must be RGBA
145
+ editor = gr.ImageEditor(brush=gr.Brush(colors=[color],color_mode="fixed"))
146
+ return editor,editor
147
+
148
+ #css=css,
149
+ def update_button_label(image):
150
+ if image == None:
151
+ print("none replace")
152
+ return gr.Button(visible=True),gr.Button(visible=False),gr.Row(visible=True),gr.Row(visible=True)
153
+ else:
154
+ return gr.Button(visible=False),gr.Button(visible=True),gr.Row(visible=False),gr.Row(visible=False)
155
+
156
+ def update_visible(fill_color_mode,image):
157
+ if image != None:
158
+ return gr.Row(visible=False),gr.Row(visible=False)
159
+
160
+ if fill_color_mode:
161
+ return gr.Row(visible=False),gr.Row(visible=True)
162
+ else:
163
+ return gr.Row(visible=True),gr.Row(visible=False)
164
+
165
+ with gr.Blocks(css=css, elem_id="demo-container") as demo:
166
+ with gr.Column():
167
+ gr.HTML(read_file("demo_header.html"))
168
+ gr.HTML(read_file("demo_tools.html"))
169
+ with gr.Row():
170
+ with gr.Column():
171
+ reference_image = gr.ImageEditor(height=1050,sources=['upload','clipboard'],layers = False,transforms=[],image_mode='RGBA',elem_id="image_upload", type="pil", label="Reference Image",brush=gr.Brush(colors=["#001"], color_mode="fixed"))
172
+ with gr.Row(elem_id="prompt-container", equal_height=False):
173
+ btn1 = gr.Button("Color Match", elem_id="run_button",variant="primary")
174
+ mirror_target = gr.Checkbox(label="Mirror target",value=False)
175
+ pick=gr.ColorPicker(label="color",value="#001",info="ImageEditor color is broken,pick color from here.reselect paint-tool and draw.but not so effective")
176
+
177
+ target_image = gr.ImageEditor(height=1050,sources=['upload','clipboard'],layers = False,transforms=[],image_mode='RGBA',elem_id="image_upload", type="pil", label="Target Image",brush=gr.Brush(colors=["#001"], color_mode="fixed"))
178
+ pick.change(fn=color_changed,inputs=[pick],outputs=[reference_image,target_image])
179
+
180
+
181
+
182
+ with gr.Accordion(label="Advanced Settings", open=False):
183
+ gr.HTML("<h4>Post-Process Target Image</h4>")
184
+ with gr.Row(equal_height=True):
185
+ middle_tone_value = gr.Slider(
186
+ label="middle tone",
187
+ minimum=0,
188
+ maximum=2.0,
189
+ step=0.01,
190
+ value=0.75)
191
+
192
+ with gr.Column():
193
+ image_out = gr.Gallery(height=800,label="Output", elem_id="output-img",format="webp", preview=True)
194
+
195
+
196
+
197
+
198
+
199
+
200
+ gr.on(
201
+ [btn1.click],
202
+ fn=process_images, inputs=[reference_image,target_image,mirror_target,middle_tone_value], outputs =[image_out], api_name='infer'
203
+ )
204
+ gr.Examples(
205
+ examples =[
206
+ ["examples/face01.webp","examples/face02.webp"],
207
+ ["examples/face01.webp","examples/face03.webp"],
208
+ ["examples/face01.webp","examples/face04.webp"],
209
+ ["examples/face02.webp","examples/face03.webp"],
210
+ ["examples/face02.webp","examples/face04.webp"],
211
+ ["examples/face03.webp","examples/face04.webp"],
212
+ ],
213
+ inputs=[reference_image,target_image]
214
+ )
215
+ gr.HTML(read_file("demo_footer.html"))
216
+
217
+ if __name__ == "__main__":
218
+ demo.launch()
color_utils.py ADDED
@@ -0,0 +1,112 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+
3
+ from PIL import Image,ImageDraw,ImageOps
4
+
5
+
6
+ def create_color_image(width, height, color=(255,255,255)):
7
+ img = Image.new('RGB', (width, height), color)
8
+ return img
9
+
10
+ def create_compare_image(base_image,paste_image,mask):
11
+ normal_image=base_image.copy()
12
+ normal_image.paste(paste_image,(0,0),mask)
13
+
14
+ invert_image=base_image.copy()
15
+ invert_image.paste(paste_image,(0,0),ImageOps.invert(mask))
16
+ return normal_image,invert_image
17
+
18
+ def mirror(image):
19
+ return ImageOps.mirror(image)
20
+
21
+ def create_left_half_mask(image):
22
+
23
+ left_mask = create_color_image(image.width,image.height)
24
+ draw = ImageDraw.Draw(left_mask)
25
+ draw.rectangle((0, 0, int(image.width/2), int(image.height)), fill=(0, 0, 0),outline=None)
26
+ return left_mask.convert("L")
27
+
28
+ def create_top_half_mask(image):
29
+
30
+ left_mask = create_color_image(image.width,image.height)
31
+ draw = ImageDraw.Draw(left_mask)
32
+ draw.rectangle((0, 0, int(image.width), int(image.height/2)), fill=(0, 0, 0),outline=None)
33
+
34
+ return left_mask.convert("L")
35
+
36
+ def curve_midtones(x,option=0.7):
37
+ return 255 * (x / 255) ** option
38
+
39
+ def apply_tone_curve(image, curve_function,option=1.0):
40
+ # LUTを作成
41
+ lut = np.array([curve_function(i,option) for i in range(256)], dtype=np.uint8)
42
+
43
+ # 画像をNumPy配列に変換
44
+ img_array = np.array(image)
45
+
46
+ # LUTを適用
47
+ adjusted_array = lut[img_array]
48
+
49
+ # 調整後の配列を画像に戻す
50
+ return Image.fromarray(adjusted_array)
51
+
52
+ def simple_white_balance(image, p=10, output_min=0, output_max=255):
53
+ """
54
+ PIL simple white balance without numpy
55
+
56
+ Args:
57
+ image: PIL Image
58
+ p: ignore pixel percent (50 convert to 51)
59
+ output_min: min bright
60
+ output_max: max bright
61
+
62
+ Returns:
63
+ PIL Image
64
+ """
65
+ if p == 50:
66
+ p = 51# even make zero-error
67
+
68
+ # convert to rgb
69
+ image = image.convert("RGB")
70
+
71
+ # get histgram
72
+ histograms = image.histogram()
73
+
74
+ # make lut
75
+ luts = []
76
+ for i in range(3):
77
+ hist = histograms[i * 256:(i + 1) * 256]
78
+ total = sum(hist)
79
+
80
+ # min
81
+ sum_low = 0
82
+ low_value = 0
83
+ for j, count in enumerate(hist):
84
+ sum_low += count
85
+ if sum_low > total * p / 100:
86
+ low_value = j
87
+ break
88
+
89
+ # max
90
+ sum_high = 0
91
+ high_value = 255
92
+ for j, count in enumerate(reversed(hist)):
93
+ sum_high += count
94
+ if sum_high > total * p / 100:
95
+ high_value = 255 - j
96
+ break
97
+
98
+ # LUT
99
+ lut = [0] * 256 # initialize 0
100
+ for j in range(256):
101
+ if j < low_value:
102
+ lut[j] = output_min
103
+ elif j > high_value:
104
+ lut[j] = output_max
105
+ else:
106
+ v = (j - low_value) / (high_value - low_value)
107
+ lut[j] = int(round(output_min + (output_max - output_min) * v))
108
+
109
+ luts.extend(lut)
110
+
111
+ # apply LUT
112
+ return image.point(luts)
demo_footer.html ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ <div>
2
+ <P> Images are generated with <a href="https://huggingface.co/black-forest-labs/FLUX.1-schnell">FLUX.1-schnell</a> and licensed under <a href="http://www.apache.org/licenses/LICENSE-2.0">the Apache 2.0 License</a>
3
+ </div>
demo_header.html ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ <div style="text-align: center;">
2
+ <h1>
3
+ Histgram Color Matching
4
+ </h1>
5
+ <div class="grid-container">
6
+ Body-to-body image color matching will never work; try face-to-face.keep eyes and hair color same<br>
7
+ Usually default is best.To similar lighting angle,mirror(flip horizental) is important<br>
8
+ Adding color paint(pink or green) can change hue(but not recommend to add target image)<br>
9
+ modify tone-curve adjust brightness.However most of the case it's the lighting's issue, so there's not much we can do <br>
10
+ </p>
11
+ </div>
12
+
13
+ </div>
demo_tools.html ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ <div style="text-align: center;">
2
+ <p><a href="https://huggingface.co/spaces/Akjava/mediapipe-face-detect">Mediapipe Face detector</a></p>
3
+ <p></p>
4
+ </div>
examples/face01.webp ADDED
examples/face02.webp ADDED
examples/face03.webp ADDED
examples/face04.webp ADDED
gradio_utils.py ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+
3
+ import os
4
+ import time
5
+ import io
6
+ import hashlib
7
+
8
+ def clear_old_files(dir="files",passed_time=60*60):
9
+ try:
10
+ files = os.listdir(dir)
11
+ current_time = time.time()
12
+ for file in files:
13
+ file_path = os.path.join(dir,file)
14
+
15
+ ctime = os.stat(file_path).st_ctime
16
+ diff = current_time - ctime
17
+ #print(f"ctime={ctime},current_time={current_time},passed_time={passed_time},diff={diff}")
18
+ if diff > passed_time:
19
+ os.remove(file_path)
20
+ except:
21
+ print("maybe still gallery using error")
22
+
23
+ def get_image_id(image):
24
+ buffer = io.BytesIO()
25
+ image.save(buffer, format='PNG')
26
+ hash_object = hashlib.sha256(buffer.getvalue())
27
+ hex_dig = hash_object.hexdigest()
28
+ unique_id = hex_dig[:32]
29
+ return unique_id
30
+
31
+ def save_image(image,extension="jpg",dir_name="files"):
32
+ id = get_image_id(image)
33
+ os.makedirs(dir_name,exist_ok=True)
34
+ file_path = f"{dir_name}/{id}.{extension}"
35
+
36
+ image.save(file_path)
37
+ return file_path
requirements.txt ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ numpy
2
+ torch
3
+ spaces
4
+ scikit-image