initial commit
deploy_1.py ADDED (+239 -0)
# -*- coding: utf-8 -*-
"""deploy_1

Automatically generated by Colaboratory.

Original file is located at
    https://colab.research.google.com/drive/15bRa4lN0gamY1gSoZhpUGDp61rmTJ0Eg

# Installing Modules
"""

!pip install mediapipe
!pip install --upgrade diffusers[torch]
!pip install transformers
!pip install accelerate
!pip install git+https://github.com/huggingface/diffusers

"""# Importing Modules"""

import mediapipe as mp
from mediapipe.tasks import python
from mediapipe.tasks.python import vision
import cv2
from google.colab.patches import cv2_imshow
import math
import numpy as np
from PIL import Image
from cv2 import kmeans, TERM_CRITERIA_MAX_ITER, TERM_CRITERIA_EPS, KMEANS_RANDOM_CENTERS
from numpy import float32
from matplotlib.pyplot import scatter, show
import matplotlib.pyplot as plt
import requests
from transformers import pipeline
import torch
import PIL
from diffusers import StableDiffusionInpaintPipeline, StableDiffusionControlNetInpaintPipeline, ControlNetModel, DDPMScheduler
from diffusers.utils import load_image

"""# Stable Diffusion and ControlNet Pipeline"""

# Stable Diffusion ControlNet Pipeline Class
class StableDiffusionControlnetPipeline:
    def __init__(self):
        self.SELFIE_MULTICLASS_SEGMENTER_MODEL_PATH = "/content/selfie_multiclass_256x256.tflite"
        self.CONTROLNET_PATH = "lllyasviel/control_v11p_sd15_inpaint"
        self.MODEL_PATH = "Uminosachi/realisticVisionV51_v51VAE-inpainting"
        self.device = "cuda"
        # Image classifier used to detect the person's natural hair colour
        self.hair_color_pipeline = pipeline("image-classification", model="enzostvs/hair-color")
        self.controlnet = ControlNetModel.from_pretrained(
            self.CONTROLNET_PATH, torch_dtype=torch.float16
        ).to(self.device)
        # The ControlNet inpainting pipeline is needed here; the plain
        # StableDiffusionInpaintPipeline does not accept a controlnet argument.
        pipe = StableDiffusionControlNetInpaintPipeline.from_pretrained(
            self.MODEL_PATH,
            controlnet=self.controlnet,
            safety_checker=None,
            requires_safety_checker=False,
            torch_dtype=torch.float16
        ).to(self.device)
        pipe.scheduler = DDPMScheduler.from_config(pipe.scheduler.config)
        self.pipe = pipe

    def get_hair_dominant_color(self, image_path):
        """Classify the dominant hair colour, skipping a 'completely bald' top prediction."""
        hair_img = Image.open(image_path).convert('RGB')
        results = self.hair_color_pipeline(hair_img)
        first_score, first_hair_color = results[0]["score"], results[0]["label"]
        second_score, second_hair_color = results[1]["score"], results[1]["label"]
        if first_hair_color != "completely bald":
            return first_hair_color
        else:
            return second_hair_color

    def make_inpaint_condition(self, image, image_mask):
        """Build the ControlNet inpaint conditioning tensor: masked pixels are set to -1."""
        image = np.array(image.convert("RGB")).astype(np.float32) / 255.0
        image_mask = np.array(image_mask.convert("L")).astype(np.float32) / 255.0
        assert image.shape[:2] == image_mask.shape[:2], "image and image_mask must have the same image size"
        image[image_mask > 0.5] = -1.0  # set as masked pixel
        image = np.expand_dims(image, 0).transpose(0, 3, 1, 2)
        image = torch.from_numpy(image)
        return image

    def roundUp(self, value, multiple):
        """Round value up to the nearest multiple (Stable Diffusion needs sizes divisible by 8)."""
        if value % multiple == 0:
            return value
        return value + multiple - (value % multiple)

    def stable_diffusion_controlnet(self, image_path):
        HAIR_ROOT_MASK_PATH = self.create_hair_root_mask(image_path, self.SELFIE_MULTICLASS_SEGMENTER_MODEL_PATH)
        HAIR_COLOR = self.get_hair_dominant_color(image_path)
        PROMPT = f"({HAIR_COLOR} root:1.2), raw photo, high detail"
        NEGATIVE_PROMPT = "black hair root"
        init_image = load_image(image_path)
        mask_image = load_image(HAIR_ROOT_MASK_PATH)
        # Output size must be a multiple of 8 for Stable Diffusion
        height = self.roundUp(init_image.height, 8)
        width = self.roundUp(init_image.width, 8)
        generator = torch.Generator(device=self.device).manual_seed(1)
        control_image = self.make_inpaint_condition(init_image, mask_image)
        new_image = self.pipe(
            prompt=PROMPT,
            image=init_image,
            mask_image=mask_image,
            num_inference_steps=40,
            generator=generator,
            control_image=control_image,
            negative_prompt=NEGATIVE_PROMPT,
            strength=1,
            height=height,
            width=width,
            padding_mask_crop=40,
            guidance_scale=3.5
        ).images
        hair_root_edited_img = new_image[0]
        hair_root_edited_img.save("new_img_modified.jpg")
        return hair_root_edited_img

    def view_result(self, init_image, touched_up_image):
        fig, axes = plt.subplots(1, 2, figsize=(12, 6))
        axes[0].imshow(init_image)
        axes[0].set_title('Original Image')
        axes[0].axis('off')
        axes[1].imshow(touched_up_image)
        axes[1].set_title('Hair Root Touched-up')
        axes[1].axis('off')
        plt.show()

    def resize_and_show(self, image, INPUT_HEIGHT=512, INPUT_WIDTH=512):
        h, w = image.shape[:2]
        if h < w:
            img = cv2.resize(image, (INPUT_WIDTH, math.floor(h / (w / INPUT_WIDTH))))
        else:
            img = cv2.resize(image, (math.floor(w / (h / INPUT_HEIGHT)), INPUT_HEIGHT))
        cv2_imshow(img)

    def create_hair_root_mask(self, image_path, SELFIE_MULTICLASS_SEGMENTER_MODEL_PATH):
        BG_COLOR = (0, 0, 0)  # Background RGB Color
        MASK_COLOR = (255, 255, 255)  # Mask RGB Color
        HAIR_CLASS_INDEX = 1  # Index of the Hair Class
        N_CLUSTERS = 3
        img = cv2.imread(image_path)
        # Segment the hair region with the MediaPipe multiclass selfie segmenter
        base_options = python.BaseOptions(model_asset_path=SELFIE_MULTICLASS_SEGMENTER_MODEL_PATH)
        options = vision.ImageSegmenterOptions(base_options=base_options, output_category_mask=True)
        with vision.ImageSegmenter.create_from_options(options) as segmenter:
            image = mp.Image(image_format=mp.ImageFormat.SRGB, data=img)
            segmentation_result = segmenter.segment(image)
            category_mask = segmentation_result.category_mask
            image_data = image.numpy_view()
            fg_image = np.zeros(image_data.shape, dtype=np.uint8)
            fg_image[:] = MASK_COLOR
            bg_image = np.zeros(image_data.shape, dtype=np.uint8)
            bg_image[:] = BG_COLOR
            condition = np.stack((category_mask.numpy_view(),) * 3, axis=-1) == HAIR_CLASS_INDEX
            output_image = np.where(condition, fg_image, bg_image)
            cv2.imwrite("hair_mask.png", output_image)
            # Keep only the hair pixels; everything outside the hair mask becomes gray
            hair_mask_cropped = cv2.bitwise_and(img, output_image)
            coords = np.where(output_image != [255, 255, 255])
            background = np.full(img.shape, 128, dtype=np.uint8)  # gray background color
            hair_mask_cropped[coords[0], coords[1], coords[2]] = background[coords[0], coords[1], coords[2]]
            rgb_img_hair_mask_cropped = cv2.cvtColor(hair_mask_cropped, cv2.COLOR_BGR2RGB)
            pillow_img = Image.fromarray(rgb_img_hair_mask_cropped)
            pillow_img.save("hair_mask_cropped.jpg")
            # Cluster the hair pixels with k-means; the smallest cluster is taken as the root region
            img_data = rgb_img_hair_mask_cropped.reshape(-1, 3)
            criteria = (TERM_CRITERIA_MAX_ITER + TERM_CRITERIA_EPS, 100, 0.2)
            compactness, labels, centers = kmeans(data=img_data.astype(float32), K=N_CLUSTERS, bestLabels=None,
                                                  criteria=criteria, attempts=10, flags=KMEANS_RANDOM_CENTERS)
            colours = centers[labels].reshape(-1, 3)
            img_colours = colours.reshape(rgb_img_hair_mask_cropped.shape)
            number_labels = np.bincount(labels.flatten())
            minimum_cluster_class = number_labels.argmin()
            labels = labels.flatten()
            # Build a binary mask: the root cluster in white, every other cluster in black
            masked_image = np.copy(rgb_img_hair_mask_cropped)
            masked_image = masked_image.reshape((-1, 3))
            for i in range(0, len(number_labels)):
                masked_image[labels == i] = [0, 0, 0]
            masked_image[labels == minimum_cluster_class] = [255, 255, 255]
            masked_image = masked_image.reshape(rgb_img_hair_mask_cropped.shape)
            cv2.imwrite("hair_root_mask.jpg", masked_image)
            # Fill the detected root contours and clear the lower part of the mask
            hair_root_mask_img = cv2.imread('hair_root_mask.jpg')
            gray = cv2.cvtColor(hair_root_mask_img, cv2.COLOR_BGR2GRAY)
            ret, binary = cv2.threshold(gray, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
            contours, hierarchy = cv2.findContours(binary, mode=cv2.RETR_TREE, method=cv2.CHAIN_APPROX_NONE)
            image_copy = hair_root_mask_img.copy()
            image_copy = cv2.drawContours(image_copy, contours, -1, (255, 255, 255), thickness=3, lineType=cv2.LINE_4)
            cv2.fillPoly(image_copy, pts=contours, color=(255, 255, 255))
            (h, w) = image_copy.shape[:2]
            cut_pixel = int((w // 2) * 0.25)
            chin_point = ((w // 2) - cut_pixel, (h // 2) - cut_pixel)
            image_copy[chin_point[0]:, :] = [0, 0, 0]  # black out everything below the estimated chin line
            cv2.imwrite("hair_root_mask_mdf.png", image_copy)
            HAIR_ROOT_MASK_PATH = "/content/hair_root_mask_mdf.png"
            return HAIR_ROOT_MASK_PATH

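# The class above assumes the MediaPipe segmenter model already exists at
# /content/selfie_multiclass_256x256.tflite, but nothing in this script downloads it.
# A minimal, commented-out sketch of preparing that file and running the pipeline
# outside Gradio follows; the download URL is the one listed in the MediaPipe
# image-segmenter documentation and should be verified before use, and the sample
# image path is a hypothetical placeholder.

# !wget -O /content/selfie_multiclass_256x256.tflite \
#     https://storage.googleapis.com/mediapipe-models/image_segmenter/selfie_multiclass_256x256/float32/latest/selfie_multiclass_256x256.tflite

# sd_pipeline = StableDiffusionControlnetPipeline()  # loads ControlNet + SD weights onto the GPU
# result = sd_pipeline.stable_diffusion_controlnet("/content/sample_selfie.jpg")  # hypothetical input image
# sd_pipeline.view_result(load_image("/content/sample_selfie.jpg"), result)
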
"""# Installing Gradio"""

!pip install gradio --upgrade

"""## Calling the StableDiffusionControlnetPipeline for Gradio Interface

"""

import numpy as np
import gradio as gr

# Assuming the StableDiffusionControlnetPipeline class is already defined

# Define the handler function for Gradio
def process_image(input_img):
    # Convert the Gradio input image to a numpy array (not used further below)
    input_img_np = np.array(input_img)

    # Save the uploaded image to a temporary file
    temp_image_path = "/tmp/uploaded_image.jpg"
    input_img.save(temp_image_path)

    # Instantiate the pipeline
    SB_ControlNet_pipeline = StableDiffusionControlnetPipeline()

    # Process the uploaded image using the pipeline
    output_img = SB_ControlNet_pipeline.stable_diffusion_controlnet(temp_image_path)

    return output_img

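# Note: process_image builds a fresh StableDiffusionControlnetPipeline (and so
# reloads all model weights) on every request. A sketch of the usual alternative,
# creating the pipeline once at module level and reusing it inside the handler,
# is shown below; this is a suggested restructuring, not part of the original script.

# SB_ControlNet_pipeline = StableDiffusionControlnetPipeline()  # load weights once
#
# def process_image(input_img):
#     temp_image_path = "/tmp/uploaded_image.jpg"
#     input_img.save(temp_image_path)
#     return SB_ControlNet_pipeline.stable_diffusion_controlnet(temp_image_path)
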
# Create a Gradio interface
iface = gr.Interface(
    fn=process_image,
    inputs=gr.Image(type="pil", label="Upload Image"),
    outputs="image",
    title="Hair Root Touch Up using AI!",
    description="Upload an image to edit hair roots using Stable Diffusion Controlnet :)"
)

# Launch the Gradio interface
iface.launch(debug=True)
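# In Colab, launch(debug=True) keeps the cell running and streams errors to the
# output. If a public URL is needed outside the notebook, Gradio's share flag
# (a standard launch parameter) can be passed as well, e.g.:
# iface.launch(debug=True, share=True)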