File size: 4,431 Bytes
154c8c5 22587f6 154c8c5 de3e780 154c8c5 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 |
import logging
import os
import time
import cv2
from diffusers import StableDiffusionPipeline
import gradio as gr
import mediapipe as mp
import numpy as np
import PIL
import torch.cuda
# from transformers import pipeline
# Request the accelerated `hf_transfer` download backend for Hub downloads.
# NOTE(review): this runs after `diffusers` has already been imported above;
# if the hub client reads the flag at import time it may have no effect here
# -- confirm, and move it above the imports if so.
os.environ.update({'HF_HUB_ENABLE_HF_TRANSFER': '1'})

# force=True replaces any handlers that imported libraries already installed
# on the root logger, so the format below is the one actually used.
logging.basicConfig(
    force=True,
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
)
LOG = logging.getLogger(__name__)

LOG.info("Loading image segmentation model")
# Previous segmentation backend, kept for reference:
# seg_kwargs = {
#     "task": "image-segmentation",
#     "model": "nvidia/segformer-b0-finetuned-ade-512-512"
# }
#
# img_segmentation = pipeline(**seg_kwargs)
mp_selfie_segmentation = mp.solutions.selfie_segmentation
img_segmentation_model = mp_selfie_segmentation.SelfieSegmentation(
    model_selection=0,
)

LOG.info("Loading diffusion model")
diffusion = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
)

# Run the diffusion pipeline on the GPU whenever one is visible to torch.
if torch.cuda.is_available():
    LOG.info("Moving diffusion model to GPU")
    diffusion.to('cuda')
def image_preprocess(image: PIL.Image):
    """Turn an uploaded PIL image into a resized BGR pixel array.

    The picture is converted to RGB, scaled by ``resize_image`` and turned
    into a NumPy array whose channel axis is then reversed (RGB -> BGR).
    The elapsed wall-clock time is logged.

    Args:
        image: The PIL image to prepare.

    Returns:
        A freshly allocated NumPy array holding the resized image in BGR
        channel order.
    """
    LOG.info("Preprocessing image %s", image)
    began = time.time()
    # image = PIL.ImageOps.exif_transpose(image)
    resized = resize_image(image.convert("RGB"))
    # Reverse the last axis (RGB -> BGR); .copy() materialises the reversed
    # view into its own contiguous buffer.
    bgr = np.array(resized)[..., ::-1].copy()
    LOG.info("Image preprocessed, %.2f seconds elapsed", time.time() - began)
    return bgr
def resize_image(image: "PIL.Image.Image", max_side: int = 512, multiple: int = 8):
    """Scale an image so its longer side is ``max_side``, snapped to ``multiple``.

    The aspect ratio is preserved up to rounding: both dimensions are divided
    by the same ratio and then rounded down to a multiple of ``multiple``.
    Images smaller than ``max_side`` are scaled up.

    Args:
        image: Object exposing ``size`` as ``(width, height)`` and a
            ``resize((width, height))`` method (a PIL image).
        max_side: Target length of the longer side, in pixels.
        multiple: Both output dimensions are rounded down to a multiple of
            this value.

    Returns:
        Whatever ``image.resize`` returns for the computed size.

    Raises:
        ValueError: If the input image has a zero (or negative) dimension.
    """
    width, height = image.size
    if width <= 0 or height <= 0:
        raise ValueError(
            f"Cannot resize an image with empty dimensions: {width}x{height}"
        )

    ratio = max(width, height) / max_side
    # Rounding down to a multiple can collapse the short side of a very
    # elongated image to 0; clamp so the result is never an empty image.
    new_width = max(multiple, int(width / ratio) // multiple * multiple)
    new_height = max(multiple, int(height / ratio) // multiple * multiple)
    return image.resize((new_width, new_height))
def extract_selfie_mask(threshold, image):
    """Compute a soft foreground mask for the person in ``image``.

    The segmentation output is binarised at ``threshold``, dilated with a
    5x5 kernel and then blurred with a 10x10 box filter so the edge between
    person and background is feathered rather than hard.

    Args:
        threshold: Cut-off applied to the segmentation confidence map; values
            above it become 1, the rest 0.
        image: Image array in BGR channel order (as produced by
            ``image_preprocess``).

    Returns:
        The processed segmentation mask. It is the array returned by the
        segmentation model, modified in place.
    """
    LOG.info("Extracting selfie mask")
    start = time.time()
    # The rest of the pipeline works in BGR, but MediaPipe expects RGB input;
    # feeding it BGR swaps red and blue for the segmentation model.
    rgb_image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    result = img_segmentation_model.process(rgb_image)
    mask = result.segmentation_mask
    # All three steps write back into `mask` (dst=mask) to avoid reallocating.
    cv2.threshold(mask, threshold, 1, cv2.THRESH_BINARY, dst=mask)
    cv2.dilate(mask, np.ones((5, 5), np.uint8), iterations=1, dst=mask)
    cv2.blur(mask, (10, 10), dst=mask)
    elapsed = time.time() - start
    LOG.info("Selfie extracted, %.2f seconds elapsed", elapsed)
    return mask
def generate_background(prompt, num_inference_steps, height, width):
    """Generate a background image from a text prompt with the diffusion model.

    Args:
        prompt: Text description of the desired background.
        num_inference_steps: Number of denoising steps; coerced to ``int``
            because UI sliders may deliver floats.
        height: Height of the generated image in pixels.
        width: Width of the generated image in pixels.

    Returns:
        A ``(height, width, 3)`` uint8 array of zeros if the pipeline flagged
        the result as NSFW, otherwise the generated image as a NumPy array
        with its channel axis reversed from RGB to BGR.
    """
    LOG.info("Generating background")
    start = time.time()
    result = diffusion(
        prompt=prompt,
        num_inference_steps=int(num_inference_steps),
        height=height,
        width=width,
    )
    # nsfw_content_detected is None when the pipeline runs without a safety
    # checker; treat "no verdict" as "not flagged" instead of crashing on [0].
    nsfw_flags = result.nsfw_content_detected
    nsfw = bool(nsfw_flags and nsfw_flags[0])
    if nsfw:
        LOG.info('NSFW detected, skipping')
        # Substitute a black frame of the requested size.
        background = np.zeros((height, width, 3), dtype='uint8')
    else:
        background = np.array(result.images[0])
        # Convert RGB to BGR
        background = background[:, :, ::-1].copy()
    elapsed = time.time() - start
    LOG.info("Background generated, elapsed %.2f seconds", elapsed)
    return background
def merge_selfie_and_background(selfie, background, mask):
    """Blend the selfie over the generated background using a soft mask.

    Each output pixel is a weighted mix of ``selfie`` (weight ``mask``) and
    ``background`` (weight ``1 - mask``).

    Args:
        selfie: Selfie image array in BGR channel order.
        background: Background image array in BGR channel order.
        mask: Per-pixel weight of the selfie.
            NOTE(review): presumably float values in [0, 1] with the same
            height/width as both images -- confirm against the caller.

    Returns:
        The blended picture as an RGB ``PIL.Image``.
    """
    LOG.info("Merging extracted selfie and generated background")
    # Blend into a fresh array: writing the result back into `selfie` would
    # silently overwrite the caller's input image.
    blended = cv2.blendLinear(selfie, background, mask, 1 - mask)
    blended = cv2.cvtColor(blended, cv2.COLOR_BGR2RGB)
    return PIL.Image.fromarray(blended)
def demo(threshold, image, prompt, num_inference_steps):
    """Run the full pipeline: segment the selfie and paste it on a new background.

    Args:
        threshold: Segmentation cut-off forwarded to ``extract_selfie_mask``.
        image: Uploaded selfie as a PIL image, or ``None`` if nothing was
            uploaded.
        prompt: Text description of the background to generate.
        num_inference_steps: Number of diffusion steps.

    Returns:
        The composited picture returned by ``merge_selfie_and_background``.

    Raises:
        ValueError: If no image was provided.
        Exception: Any error raised by a pipeline stage is logged with its
            traceback and re-raised unchanged.
    """
    LOG.info("Processing image")
    # The UI passes None when the user submits without uploading a picture;
    # fail with a clear message instead of an AttributeError deep inside
    # the preprocessing step.
    if image is None:
        raise ValueError("No image provided: please upload a selfie first")
    try:
        image = image_preprocess(image)
        mask = extract_selfie_mask(threshold, image)
        # The background must match the preprocessed selfie's size exactly.
        height, width = image.shape[:2]
        background = generate_background(prompt, num_inference_steps,
                                         height, width)
        output = merge_selfie_and_background(image, background, mask)
    except Exception:
        # LOG.exception records the message together with the traceback.
        LOG.exception("Some unexpected error occured")
        raise
    return output
# Input widgets, in the positional order expected by `demo`:
# threshold, image, prompt, num_inference_steps.
_inputs = [
    gr.Slider(
        label="Selfie segmentation threshold",
        minimum=0.1,
        maximum=1,
        step=0.05,
        value=0.8,
    ),
    gr.Image(label="Upload your selfie", type='pil'),
    gr.Text(
        label="Background description",
        value="a photo of the Eiffel tower on the right side",
    ),
    gr.Slider(
        label="Diffusion inference steps",
        minimum=5,
        maximum=100,
        step=5,
        value=50,
    ),
]

# Single output widget showing the composited picture.
_outputs = [gr.Image(label="Invent yourself a life :)")]

iface = gr.Interface(fn=demo, inputs=_inputs, outputs=_outputs)

# Alternative launch binding to all interfaces on a fixed port:
# iface.launch(server_name="0.0.0.0", server_port=6443)
iface.launch()
|