qipchip31 committed
Commit
9be3758
1 Parent(s): 85fba25

initial commit

Files changed (1)
deploy_1.py +239 -0
deploy_1.py ADDED
@@ -0,0 +1,239 @@
+ # -*- coding: utf-8 -*-
+ """deploy_1
+
+ Automatically generated by Colaboratory.
+
+ Original file is located at
+     https://colab.research.google.com/drive/15bRa4lN0gamY1gSoZhpUGDp61rmTJ0Eg
+
+ # Installing Modules
+ """
+
+ !pip install mediapipe
+ !pip install --upgrade diffusers[torch]
+ !pip install transformers
+ !pip install accelerate
+ !pip install git+https://github.com/huggingface/diffusers
+
+
+
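+ # The segmenter class below expects the MediaPipe multiclass selfie model at
+ # /content/selfie_multiclass_256x256.tflite, but nothing above downloads it.
+ # A minimal fetch sketch, assuming the model URL published on the MediaPipe
+ # image-segmenter model card:
+ !wget -q -O /content/selfie_multiclass_256x256.tflite https://storage.googleapis.com/mediapipe-models/image_segmenter/selfie_multiclass_256x256/float32/latest/selfie_multiclass_256x256.tflite
+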
+ """# Importing Modules"""
+
+ import mediapipe as mp
+ from mediapipe.tasks import python
+ from mediapipe.tasks.python import vision
+ import cv2
+ from google.colab.patches import cv2_imshow
+ import math
+ import numpy as np
+ from PIL import Image
+ from cv2 import kmeans, TERM_CRITERIA_MAX_ITER, TERM_CRITERIA_EPS, KMEANS_RANDOM_CENTERS
+ from numpy import float32
+ from matplotlib.pyplot import scatter, show
+ import matplotlib.pyplot as plt
+ import requests
+ from transformers import pipeline
+ import torch
+ import PIL
+ from diffusers import StableDiffusionInpaintPipeline, StableDiffusionControlNetInpaintPipeline, ControlNetModel, DDPMScheduler
+ from diffusers.utils import load_image
+
+
+
+ """# Stable Diffusion and ControlNet Pipeline"""
+
+ # Stable Diffusion + ControlNet inpainting pipeline class
+ class StableDiffusionControlnetPipeline:
+     def __init__(self):
+         self.SELFIE_MULTICLASS_SEGMENTER_MODEL_PATH = "/content/selfie_multiclass_256x256.tflite"
+         self.CONTROLNET_PATH = "lllyasviel/control_v11p_sd15_inpaint"
+         self.MODEL_PATH = "Uminosachi/realisticVisionV51_v51VAE-inpainting"
+         self.device = "cuda"
+         self.hair_color_pipeline = pipeline("image-classification", model="enzostvs/hair-color")
+         self.controlnet = ControlNetModel.from_pretrained(
+             self.CONTROLNET_PATH, torch_dtype=torch.float16
+         ).to(self.device)
+         # The ControlNet inpainting pipeline is used so that the `controlnet` module and the
+         # `control_image` argument passed in stable_diffusion_controlnet() are actually consumed.
+         pipe = StableDiffusionControlNetInpaintPipeline.from_pretrained(
+             self.MODEL_PATH,
+             controlnet=self.controlnet,
+             safety_checker=None,
+             requires_safety_checker=False,
+             torch_dtype=torch.float16
+         ).to(self.device)
+         pipe.scheduler = DDPMScheduler.from_config(pipe.scheduler.config)
+         self.pipe = pipe
+
+     def get_hair_dominant_color(self, image_path):
+         # Classify the hair colour; skip the "completely bald" label if it ranks first.
+         hair_img = Image.open(image_path).convert('RGB')
+         results = self.hair_color_pipeline.predict(hair_img)
+         first_score, first_hair_color = results[0]["score"], results[0]["label"]
+         second_score, second_hair_color = results[1]["score"], results[1]["label"]
+         if first_hair_color != "completely bald":
+             return first_hair_color
+         else:
+             return second_hair_color
+
+     def make_inpaint_condition(self, image, image_mask):
+         # Build the ControlNet conditioning image: masked pixels are set to -1.
+         image = np.array(image.convert("RGB")).astype(np.float32) / 255.0
+         image_mask = np.array(image_mask.convert("L")).astype(np.float32) / 255.0
+         assert image.shape[:2] == image_mask.shape[:2], "image and image_mask must have the same image size"
+         image[image_mask > 0.5] = -1.0  # set as masked pixel
+         image = np.expand_dims(image, 0).transpose(0, 3, 1, 2)
+         image = torch.from_numpy(image)
+         return image
+
+     def roundUp(self, input, round):
+         # Round `input` up to the nearest multiple of `round`; exact multiples are left unchanged.
+         return input if input % round == 0 else input + round - (input % round)
+
+     def stable_diffusion_controlnet(self, image_path):
+         # Build the hair-root mask, detect the hair colour, then inpaint the roots.
+         HAIR_ROOT_MASK_PATH = self.create_hair_root_mask(image_path, self.SELFIE_MULTICLASS_SEGMENTER_MODEL_PATH)
+         HAIR_COLOR = self.get_hair_dominant_color(image_path)
+         PROMPT = f"({HAIR_COLOR} root:1.2), raw photo, high detail"
+         NEGATIVE_PROMPT = "black hair root"
+         init_image = load_image(image_path)
+         mask_image = load_image(HAIR_ROOT_MASK_PATH)
+         # Output dimensions must be multiples of 8 for Stable Diffusion.
+         height = self.roundUp(init_image.height, 8)
+         width = self.roundUp(init_image.width, 8)
+         generator = torch.Generator(device=self.device).manual_seed(1)
+         control_image = self.make_inpaint_condition(init_image, mask_image)
+         new_image = self.pipe(
+             prompt=PROMPT,
+             image=init_image,
+             mask_image=mask_image,
+             num_inference_steps=40,
+             generator=generator,
+             control_image=control_image,
+             negative_prompt=NEGATIVE_PROMPT,
+             strength=1,
+             height=height,
+             width=width,
+             padding_mask_crop=40,
+             guidance_scale=3.5
+         ).images
+         hair_root_edited_img = new_image[0]
+         hair_root_edited_img.save("new_img_modified.jpg")
+         return hair_root_edited_img
+
+     def view_result(self, init_image, touched_up_image):
+         # Show the original and touched-up images side by side.
+         fig, axes = plt.subplots(1, 2, figsize=(12, 6))
+         axes[0].imshow(init_image)
+         axes[0].set_title('Original Image')
+         axes[0].axis('off')
+         axes[1].imshow(touched_up_image)
+         axes[1].set_title('Hair Root Touched-up')
+         axes[1].axis('off')
+         plt.show()
+
+     def resize_and_show(self, image, INPUT_HEIGHT=512, INPUT_WIDTH=512):
+         # Resize so the longer side is 512 px, preserving aspect ratio, then display.
+         h, w = image.shape[:2]
+         if h < w:
+             img = cv2.resize(image, (INPUT_WIDTH, math.floor(h / (w / INPUT_WIDTH))))
+         else:
+             img = cv2.resize(image, (math.floor(w / (h / INPUT_HEIGHT)), INPUT_HEIGHT))
+         cv2_imshow(img)
+
+     def create_hair_root_mask(self, image_path, SELFIE_MULTICLASS_SEGMENTER_MODEL_PATH):
+         BG_COLOR = (0, 0, 0)  # Background RGB color
+         MASK_COLOR = (255, 255, 255)  # Mask RGB color
+         HAIR_CLASS_INDEX = 1  # Index of the hair class in the multiclass segmenter
+         N_CLUSTERS = 3
+         img = cv2.imread(image_path)
+         base_options = python.BaseOptions(model_asset_path=SELFIE_MULTICLASS_SEGMENTER_MODEL_PATH)
+         options = vision.ImageSegmenterOptions(base_options=base_options, output_category_mask=True)
+         with vision.ImageSegmenter.create_from_options(options) as segmenter:
+             # Segment the selfie and keep only the hair class as a white-on-black mask.
+             image = mp.Image(image_format=mp.ImageFormat.SRGB, data=img)
+             segmentation_result = segmenter.segment(image)
+             category_mask = segmentation_result.category_mask
+             image_data = image.numpy_view()
+             fg_image = np.zeros(image_data.shape, dtype=np.uint8)
+             fg_image[:] = MASK_COLOR
+             bg_image = np.zeros(image_data.shape, dtype=np.uint8)
+             bg_image[:] = BG_COLOR
+             condition = np.stack((category_mask.numpy_view(),) * 3, axis=-1) == HAIR_CLASS_INDEX
+             output_image = np.where(condition, fg_image, bg_image)
+             cv2.imwrite("hair_mask.png", output_image)
+             # Crop the hair region out of the photo and place it on a gray background.
+             hair_mask_cropped = cv2.bitwise_and(img, output_image)
+             coords = np.where(output_image != [255, 255, 255])
+             background = np.full(img.shape, 128, dtype=np.uint8)  # gray background color
+             hair_mask_cropped[coords[0], coords[1], coords[2]] = background[coords[0], coords[1], coords[2]]
+             rgb_img_hair_mask_cropped = cv2.cvtColor(hair_mask_cropped, cv2.COLOR_BGR2RGB)
+             pillow_img = Image.fromarray(rgb_img_hair_mask_cropped)
+             pillow_img.save("hair_mask_cropped.jpg")
+             # Cluster the hair pixels; the smallest cluster is assumed to be the root region.
+             img_data = rgb_img_hair_mask_cropped.reshape(-1, 3)
+             criteria = (TERM_CRITERIA_MAX_ITER + TERM_CRITERIA_EPS, 100, 0.2)
+             compactness, labels, centers = kmeans(data=img_data.astype(float32), K=N_CLUSTERS, bestLabels=None,
+                                                   criteria=criteria, attempts=10, flags=KMEANS_RANDOM_CENTERS)
+             colours = centers[labels].reshape(-1, 3)
+             img_colours = colours.reshape(rgb_img_hair_mask_cropped.shape)
+             number_labels = np.bincount(labels.flatten())
+             minimum_cluster_class = number_labels.argmin()
+             labels = labels.flatten()
+             # Paint the root cluster white and every other cluster black.
+             masked_image = np.copy(rgb_img_hair_mask_cropped)
+             masked_image = masked_image.reshape((-1, 3))
+             for i in range(0, len(number_labels)):
+                 masked_image[labels == i] = [0, 0, 0]
+             masked_image[labels == minimum_cluster_class] = [255, 255, 255]
+             masked_image = masked_image.reshape(rgb_img_hair_mask_cropped.shape)
+             cv2.imwrite("hair_root_mask.jpg", masked_image)
+             # Fill holes in the root mask via contours, then blank out the lower rows of the image.
+             hair_root_mask_img = cv2.imread('hair_root_mask.jpg')
+             gray = cv2.cvtColor(hair_root_mask_img, cv2.COLOR_BGR2GRAY)
+             ret, binary = cv2.threshold(gray, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
+             contours, hierarchy = cv2.findContours(binary, mode=cv2.RETR_TREE, method=cv2.CHAIN_APPROX_NONE)
+             image_copy = hair_root_mask_img.copy()
+             image_copy = cv2.drawContours(image_copy, contours, -1, (255, 255, 255), thickness=3, lineType=cv2.LINE_4)
+             cv2.fillPoly(image_copy, pts=contours, color=(255, 255, 255))
+             (h, w) = image_copy.shape[:2]
+             cut_pixel = int((w // 2) * 0.25)
+             chin_point = ((w // 2) - cut_pixel, (h // 2) - cut_pixel)
+             image_copy[chin_point[0]:, :] = [0, 0, 0]  # black out rows below the estimated chin line
+             cv2.imwrite("hair_root_mask_mdf.png", image_copy)
+             HAIR_ROOT_MASK_PATH = "/content/hair_root_mask_mdf.png"
+             return HAIR_ROOT_MASK_PATH
+
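+ # Optional quick check of the class outside Gradio: a minimal sketch, assuming a test photo
+ # has been uploaded to the hypothetical path /content/example.jpg (replace with a real file).
+ import os
+ EXAMPLE_IMAGE_PATH = "/content/example.jpg"  # hypothetical path
+ if os.path.exists(EXAMPLE_IMAGE_PATH):
+     SB_pipeline = StableDiffusionControlnetPipeline()
+     touched_up = SB_pipeline.stable_diffusion_controlnet(EXAMPLE_IMAGE_PATH)
+     SB_pipeline.view_result(load_image(EXAMPLE_IMAGE_PATH), touched_up)
+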
+ """# Installing Gradio"""
+
+ !pip install gradio --upgrade
+
+ """## Calling the StableDiffusionControlnetPipeline for Gradio Interface"""
+
+ import numpy as np
+ import gradio as gr
+
+ # Assumes the StableDiffusionControlnetPipeline class defined above is available.
+
+ # Define the processing function for Gradio
+ def process_image(input_img):
+     # Convert the Gradio input image to a numpy array (currently unused)
+     input_img_np = np.array(input_img)
+
+     # Save the uploaded image to a temporary file
+     temp_image_path = "/tmp/uploaded_image.jpg"
+     input_img.save(temp_image_path)
+
+     # Instantiate the pipeline (note: this reloads all models on every request)
+     SB_ControlNet_pipeline = StableDiffusionControlnetPipeline()
+
+     # Run the hair-root touch-up on the uploaded image
+     output_img = SB_ControlNet_pipeline.stable_diffusion_controlnet(temp_image_path)
+
+     return output_img
+
+ # Create a Gradio interface
+ iface = gr.Interface(
+     fn=process_image,
+     inputs=gr.Image(type="pil", label="Upload Image"),
+     outputs="image",
+     title="Hair Root Touch Up using AI!",
+     description="Upload an image to edit hair roots using Stable Diffusion ControlNet :)"
+ )
+
+ # Launch the Gradio interface
+ iface.launch(debug=True)