FloorAI / app.py
LuyangZ's picture
Update app.py
6430849 verified
raw
history blame
5.27 kB
import gradio
import cv2
from PIL import Image
import numpy as np
import spaces
from diffusers import StableDiffusionControlNetPipeline, ControlNetModel, UniPCMultistepScheduler
from diffusers.utils import load_image
import torch
import accelerate
import transformers
from random import randrange
from transformers.utils.hub import move_cache
# One-time migration of the transformers cache directory to its current layout.
move_cache()

# Run on the GPU when one is visible; every model below is moved to this device.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Base Stable Diffusion checkpoint and the FloorAI ControlNet weights on the Hub.
# NOTE(review): the `runwayml/stable-diffusion-v1-5` repo has reportedly been
# withdrawn from the Hub; if loading fails, the community mirror
# `stable-diffusion-v1-5/stable-diffusion-v1-5` is believed to host the same
# weights — verify before switching.
base_model_id = "runwayml/stable-diffusion-v1-5"
model_id = "LuyangZ/FloorAI"

# NOTE(review): force_download=True bypasses the local HF cache and re-fetches
# the weights on every start-up; presumably deliberate to pick up retrained
# weights, but it makes cold starts slow — confirm it is still wanted.
controlnet = ControlNetModel.from_pretrained(model_id, force_download=True)
controlnet.to(device)
torch.cuda.empty_cache()

pipeline = StableDiffusionControlNetPipeline.from_pretrained(
    base_model_id,
    controlnet=controlnet,
    force_download=True,
)
# The NSFW safety checker is switched off (presumably unnecessary for
# floor-plan drawings).
pipeline.safety_checker = None
pipeline.requires_safety_checker = False
# Replace the default scheduler with UniPC, reusing the existing scheduler config.
pipeline.scheduler = UniPCMultistepScheduler.from_config(pipeline.scheduler.config)
pipeline = pipeline.to(device)
torch.cuda.empty_cache()
def expand2square(ol_img, background_color, pad_ratio=0.2):
    """Centre an image on a square canvas that adds a margin around it.

    The canvas side length is the image's longer side plus
    ``int(longer_side * pad_ratio)`` pixels, so the longer axis gets half of
    that margin on each side and the shorter axis is centred. Square and
    non-square inputs are treated identically.

    Args:
        ol_img: PIL image to pad.
        background_color: Fill colour for the new canvas, in a form accepted by
            ``PIL.Image.new`` for ``ol_img.mode`` (e.g. ``(255, 255, 255)``).
        pad_ratio: Fraction of the longer side added as total margin
            (default 0.2, matching the original hard-coded value).

    Returns:
        A new square PIL image in the same mode as ``ol_img``.
    """
    width, height = ol_img.size
    side = max(width, height)
    pad = int(side * pad_ratio)
    new_side = side + pad
    canvas = Image.new(ol_img.mode, (new_side, new_side), background_color)
    # For the longer axis this offset equals pad // 2; the shorter axis is
    # centred in the remaining space.
    offset = ((new_side - width) // 2, (new_side - height) // 2)
    canvas.paste(ol_img, offset)
    # Always return the padded canvas (the square case used to return the
    # unpadded input by mistake).
    return canvas
def clean_img(image, mask):
    """Paint white every pixel of ``image`` whose ``mask`` pixel is near-white.

    ``mask`` is converted BGR -> grayscale and inverse-thresholded at 250, so
    mask pixels brighter than 250 become 0; those positions are overwritten
    with white in ``image``.

    Note: ``image`` (a numpy array) is modified in place before being
    converted.

    Args:
        image: HxWx3 numpy array to clean; mutated in place.
        mask: BGR numpy array of the same height/width as ``image``.

    Returns:
        The cleaned image as an RGB PIL image.
    """
    gray_mask = cv2.cvtColor(mask, cv2.COLOR_BGR2GRAY)
    _, inverted = cv2.threshold(gray_mask, 250, 255, cv2.THRESH_BINARY_INV)
    outside_outline = inverted < 250
    image[outside_outline] = (255, 255, 255)
    return Image.fromarray(image).convert('RGB')
@spaces.GPU
def floorplan_generation(outline, num_of_rooms):
new_width = 512
new_height = 512
outline = cv2.cvtColor(outline, cv2.COLOR_RGB2BGR)
outline_original = outline.copy()
gray = cv2.cvtColor(outline, cv2.COLOR_BGR2GRAY)
thresh = cv2.threshold(gray, 240, 255, cv2.THRESH_BINARY_INV)[1]
x,y,w,h = cv2.boundingRect(thresh)
n_outline = outline_original[y:y+h, x:x+w]
n_outline = cv2.cvtColor(n_outline, cv2.COLOR_BGR2RGB)
n_outline = Image.fromarray(n_outline).convert('RGB')
n_outline = expand2square(n_outline, (255, 255, 255))
n_outline = n_outline.resize((new_width, new_height))
num_of_rooms = str(num_of_rooms)
validation_prompt = "floor plan, " + num_of_rooms + " rooms"
validation_image = n_outline
image_lst = []
for i in range(5):
seed = randrange(5000)
generator = torch.Generator(device=device).manual_seed(seed)
image = pipeline(validation_prompt,
validation_image,
num_inference_steps=20,
generator=generator).images[0]
image = np.array(image)
mask = np.array(n_outline)
mask = cv2.cvtColor(mask, cv2.COLOR_RGB2BGR)
image = clean_img(image, mask)
image_lst.append(image)
return image_lst[0], image_lst[1], image_lst[2], image_lst[3], image_lst[4]
# UI wiring: an outline image and a room count go in, five candidate
# floor plans come out.
outline_input = gradio.Image(label="Floor Plan Outline, Entrance")
rooms_input = gradio.Textbox(
    type="text",
    label="Number of Rooms",
    placeholder="Number of Rooms",
)
plan_outputs = [
    gradio.Image(label="Generated Floor Plan 1"),
    gradio.Image(label="Generated Floor Plan 2"),
    gradio.Image(label="Generated Floor Plan 3"),
    gradio.Image(label="Generated Floor Plan 4"),
    gradio.Image(label="Generated Floor Plan 5"),
]
# Each example row is [outline image path, number of rooms].
example_rows = [
    ["example_1.png", "4"],
    ["example_2.png", "3"],
    ["example_3.png", "2"],
    ["example_4.png", "4"],
    ["example_5.png", "4"],
]

gradio_interface = gradio.Interface(
    fn=floorplan_generation,
    inputs=[outline_input, rooms_input],
    outputs=plan_outputs,
    title="FloorAI",
    examples=example_rows,
)
# Queue at most 10 pending requests; the API stays exposed.
gradio_interface.queue(max_size=10, status_update_rate="auto", api_open=True)
# NOTE(review): share=True has no effect when running on Hugging Face Spaces.
gradio_interface.launch(share=True, show_api=True, show_error=True)