|
import importlib.util
import logging
import os
import time

# huggingface_hub reads HF_HUB_ENABLE_HF_TRANSFER into its constants when it is
# first imported, and both diffusers and gradio import it. The flag therefore
# has to be set before the third-party imports below, or it has no effect.
# It is only enabled when the optional hf_transfer package is installed:
# huggingface_hub refuses to download when the flag is on but the package is
# missing.
if importlib.util.find_spec('hf_transfer') is not None:
    os.environ['HF_HUB_ENABLE_HF_TRANSFER'] = '1'

import cv2
from diffusers import StableDiffusionPipeline
import gradio as gr
import mediapipe as mp
import numpy as np
import PIL
import torch.cuda
|
|
|
|
|
_LOG_FORMAT = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'

# Configure the root logger once for the whole app. force=True removes any
# handlers installed earlier (e.g. by an imported library) so that this
# configuration actually takes effect.
logging.basicConfig(
    format=_LOG_FORMAT,
    level=logging.INFO,
    force=True,
)

# Module logger used by every step of the pipeline.
LOG = logging.getLogger(__name__)
|
|
|
LOG.info("Loading image segmentation model")

# One MediaPipe selfie-segmentation instance, shared by every request.
mp_selfie_segmentation = mp.solutions.selfie_segmentation
img_segmentation_model = mp_selfie_segmentation.SelfieSegmentation(
    model_selection=0,
)

LOG.info("Loading diffusion model")

# Hugging Face Hub checkpoint used for background generation.
_DIFFUSION_CHECKPOINT = "runwayml/stable-diffusion-v1-5"
diffusion = StableDiffusionPipeline.from_pretrained(_DIFFUSION_CHECKPOINT)

# Inference stays on the CPU unless a CUDA device is available.
if torch.cuda.is_available():
    LOG.info("Moving diffusion model to GPU")
    diffusion.to('cuda')
|
|
|
|
|
def image_preprocess(image: "PIL.Image.Image") -> np.ndarray:
    """Turn an uploaded image into the array the rest of the pipeline uses.

    The image is forced to RGB mode (dropping alpha/palette modes) and
    rescaled by ``resize_image``. It is then converted to a NumPy array whose
    channel axis is reversed to BGR; ``merge_selfie_and_background`` converts
    BGR back to RGB at the end.

    Args:
        image: The uploaded PIL image.

    Returns:
        A contiguous BGR uint8 array of shape (height, width, 3).
    """
    LOG.info("Preprocessing image %s", image)
    start = time.time()

    rgb = resize_image(image.convert("RGB"))
    # Reverse the channel axis (RGB -> BGR). copy() turns the negative-stride
    # view into a standalone contiguous array.
    bgr = np.array(rgb)[:, :, ::-1].copy()

    elapsed = time.time() - start
    LOG.info("Image preprocessed, %.2f seconds elapsed", elapsed)
    return bgr
|
|
|
|
|
def resize_image(image: "PIL.Image.Image", target=512, multiple=8):
    """Scale *image* so that its longer side becomes *target* pixels.

    The aspect ratio is preserved. Images smaller than *target* are upscaled.
    Both output dimensions are rounded down to a multiple of *multiple*,
    because the diffusion pipeline expects sizes divisible by 8. They are
    never allowed to drop below *multiple*, so a very elongated image still
    gets a usable size instead of a zero-pixel edge.

    Args:
        image: Image to resize; only ``.size`` and ``.resize()`` are used.
        target: Desired length of the longer side, in pixels.
        multiple: Granularity both dimensions are rounded down to.

    Returns:
        Whatever ``image.resize`` returns for the computed (width, height).

    Raises:
        ValueError: If *image* has a zero-sized dimension.
    """
    width, height = image.size
    if width <= 0 or height <= 0:
        raise ValueError(f"cannot resize an empty image of size {image.size}")
    longest = max(width, height)

    def _snap(length):
        # Scale with exact integer arithmetic (the longer side maps to exactly
        # `target`), round down to the grid, but keep at least one grid step.
        scaled = length * target // longest
        return max(multiple, scaled // multiple * multiple)

    return image.resize((_snap(width), _snap(height)))
|
|
|
|
|
def extract_selfie_mask(threshold, image):
    """Compute a soft person/background mask for *image*.

    Args:
        threshold: Segmentation confidence above which a pixel counts as
            foreground (the UI slider supplies values in 0.1-1.0).
        image: BGR uint8 array, as produced by ``image_preprocess``.

    Returns:
        A 2-D mask the size of *image*. It is 1 on the person and 0 on the
        background, with edges dilated and box-blurred so that blending looks
        smooth. It has the dtype of MediaPipe's segmentation mask (presumably
        float32).
    """
    LOG.info("Extracting selfie mask")
    start = time.time()

    # MediaPipe solutions expect RGB input, while image_preprocess hands us
    # BGR, so convert before segmenting.
    rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    results = img_segmentation_model.process(rgb)

    # Produce fresh arrays rather than writing into MediaPipe's output buffer
    # in place. That buffer may be read-only, depending on the mediapipe
    # version.
    _, mask = cv2.threshold(results.segmentation_mask, threshold, 1,
                            cv2.THRESH_BINARY)
    mask = cv2.dilate(mask, np.ones((5, 5), np.uint8), iterations=1)
    mask = cv2.blur(mask, (10, 10))

    elapsed = time.time() - start
    LOG.info("Selfie extracted, %.2f seconds elapsed", elapsed)
    return mask
|
|
|
|
|
def generate_background(prompt, num_inference_steps, height, width):
    """Render a background for *prompt* with the Stable Diffusion pipeline.

    Args:
        prompt: Text description of the desired background.
        num_inference_steps: Number of denoising steps. It is cast to int
            because the UI slider value may arrive as a float.
        height: Output height in pixels (e.g. a multiple of 8, as the
            pipeline expects).
        width: Output width in pixels.

    Returns:
        A (height, width, 3) uint8 array in BGR channel order. If the safety
        checker flags the generated image, an all-black array is returned
        instead.
    """
    LOG.info("Generating background")
    start = time.time()

    result = diffusion(
        prompt=prompt,
        num_inference_steps=int(num_inference_steps),
        height=height,
        width=width,
    )
    # diffusers reports None here when no safety check was performed
    # (e.g. the pipeline was loaded with safety_checker=None).
    nsfw_flags = result.nsfw_content_detected
    nsfw = bool(nsfw_flags) and bool(nsfw_flags[0])

    if nsfw:
        LOG.info('NSFW detected, skipping')
        background = np.zeros((height, width, 3), dtype='uint8')
    else:
        background = np.array(result.images[0])

    # Reverse RGB -> BGR to match the channel order used by the merge step.
    background = background[:, :, ::-1].copy()

    elapsed = time.time() - start
    LOG.info("Background generated, elapsed %.2f seconds", elapsed)
    return background
|
|
|
|
|
def merge_selfie_and_background(selfie, background, mask):
    """Composite *selfie* over *background*, using *mask* as per-pixel weight.

    Args:
        selfie: BGR uint8 image; presumably the array from ``image_preprocess``.
        background: BGR uint8 image with the same shape as *selfie*.
        mask: Single-channel weight map the size of *selfie*. A value of 1
            keeps the selfie pixel and 0 keeps the background pixel.

    Returns:
        The blended result as an RGB ``PIL.Image.Image``. *selfie* itself is
        left unmodified.
    """
    LOG.info("Merging extracted selfie and generated background")
    # cv2.blendLinear requires single-channel float32 weight maps.
    weights = mask.astype(np.float32, copy=False)
    # No dst= argument: blending into `selfie` would silently clobber the
    # caller's array.
    blended = cv2.blendLinear(selfie, background, weights, 1 - weights)
    rgb = cv2.cvtColor(blended, cv2.COLOR_BGR2RGB)
    return PIL.Image.fromarray(rgb)
|
|
|
|
|
def demo(threshold, image, prompt, num_inference_steps):
    """Gradio callback that puts an uploaded selfie on a generated background.

    Args:
        threshold: Segmentation threshold passed to ``extract_selfie_mask``.
        image: Uploaded PIL image, or None when nothing was uploaded.
        prompt: Text description of the background to generate.
        num_inference_steps: Diffusion step count from the UI slider.

    Returns:
        The composited RGB ``PIL.Image.Image``.

    Raises:
        gr.Error: If no image was uploaded; Gradio shows this message in the UI.
        Exception: Any other failure is logged with its traceback and re-raised.
    """
    LOG.info("Processing image")
    if image is None:
        # Gradio passes None for an empty image input. Report it readably
        # instead of failing later with an AttributeError.
        raise gr.Error("Please upload a selfie first.")

    try:
        image = image_preprocess(image)
        mask = extract_selfie_mask(threshold, image)
        height, width = image.shape[:2]
        background = generate_background(prompt, num_inference_steps,
                                         height, width)
        return merge_selfie_and_background(image, background, mask)
    except Exception:
        # LOG.exception records the message together with the traceback.
        LOG.exception("Some unexpected error occurred")
        raise
|
|
|
|
|
# Input controls, listed in the same order as demo()'s parameters.
_threshold_input = gr.Slider(minimum=0.1, maximum=1, step=0.05,
                             label="Selfie segmentation threshold", value=0.8)
_selfie_input = gr.Image(type='pil', label="Upload your selfie")
_prompt_input = gr.Text(value="a photo of the Eiffel tower on the right side",
                        label="Background description")
_steps_input = gr.Slider(minimum=5, maximum=100, step=5,
                         label="Diffusion inference steps", value=50)

iface = gr.Interface(
    fn=demo,
    inputs=[_threshold_input, _selfie_input, _prompt_input, _steps_input],
    outputs=[gr.Image(label="Invent yourself a life :)")],
)

# Start the web server as soon as the module runs.
iface.launch()
|
|