# CerealBoxMaker / app.py
# Author: AP123
# Commit message: fp 16 to allow for T4 inference
# Commit: 1e875da
import gradio as gr
import torch
import numpy as np
from PIL import Image
import random
from diffusers import DiffusionPipeline
# Inclusive upper bound for the random seeds drawn in text_to_image.
MAX_SEED = np.iinfo(np.int32).max

# SDXL base weights loaded as the fp16 variant in half precision (the commit
# that introduced this targets T4 inference), with the cereal-box LoRA on top.
pipeline = DiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    variant="fp16",
    torch_dtype=torch.float16,
)
pipeline.load_lora_weights("ostris/super-cereal-sdxl-lora")
pipeline.to("cuda:0")
def text_to_image(
    prompt,
    seed=None,
    *,
    negative_prompt="ugly, blurry, nsfw, gore, blood",
    width=1024,
    height=1024,
    guidance_scale=7.0,
    num_inference_steps=25,
):
    """Render *prompt* with the module-level ``pipeline`` and return the image as an array.

    Args:
        prompt: Text prompt passed to ``pipeline``.
        seed: Seed for the noise generator. ``None`` (the default, and what
            ``combined_function`` uses) draws a fresh seed from
            ``[0, MAX_SEED]`` so each call produces a new image. Pass an int
            to repeat an earlier result with the same settings.
        negative_prompt: Negative prompt forwarded to the pipeline; the default
            is the value this function always used.
        width: Output width in pixels forwarded to the pipeline.
        height: Output height in pixels forwarded to the pipeline.
        guidance_scale: Guidance scale forwarded to the pipeline.
        num_inference_steps: Number of inference steps forwarded to the pipeline.

    Returns:
        ``np.ndarray`` converted from the first image in the pipeline output
        (presumably ``(height, width, 3)`` uint8 — confirm against the
        pipeline's default ``output_type``).
    """
    if seed is None:
        seed = random.randint(0, MAX_SEED)
    # torch.Generator() with no device argument is a CPU generator; seeding it
    # is what lets a known seed be reused to repeat a result.
    generator = torch.Generator().manual_seed(seed)
    output = pipeline(
        prompt=prompt,
        negative_prompt=negative_prompt,
        width=width,
        height=height,
        guidance_scale=guidance_scale,
        num_inference_steps=num_inference_steps,
        generator=generator,
    )
    return np.array(output.images[0])
def create_cereal_box(input_image, template_path="template.jpeg"):
    """Paste a cover image onto a cereal-box template at two positions (left and right).

    Args:
        input_image: NumPy image array; it is cast to ``uint8`` and interpreted
            as RGB. ``combined_function`` passes the array returned by
            ``text_to_image``.
        template_path: Path of the box template image. Defaults to
            ``"template.jpeg"``, resolved relative to the working directory.

    Returns:
        ``np.ndarray`` of the template with the cover pasted twice.
    """
    cover_img = Image.fromarray(input_image.astype('uint8'), 'RGB')

    # copy() forces the pixel data to load, so the copy remains usable after
    # the context manager closes the template file.
    with Image.open(template_path) as template_img:
        template_copy = template_img.copy()
    tpl_width, tpl_height = template_copy.size

    # Cover slot: 32% of the template height with a 0.70 width/height ratio,
    # then enlarged by scaling_factor. The int() truncations happen in the same
    # order as before, so the pasted size is unchanged.
    scaling_factor = 1.5
    rect_height = int(tpl_height * 0.32)
    base_width = int(rect_height * 0.70)
    cover_width = int(base_width * scaling_factor)
    cover_height = int(rect_height * scaling_factor)

    # Resample once, directly to the final size. The previous two-step resize
    # shrank the cover and then enlarged it 1.5x, which threw away detail.
    cover_resized = cover_img.resize((cover_width, cover_height), Image.LANCZOS)

    # Left copy starts 8.5% in from the left edge; right copy ends at 82% of
    # the width. Both are vertically centred, nudged down by 1.2% of the height.
    left_x = int(tpl_width * 0.085)
    top_y = int((tpl_height - cover_height) // 2 + tpl_height * 0.012)
    right_x = int(tpl_width * 0.82) - cover_width

    template_copy.paste(cover_resized, (left_x, top_y))
    template_copy.paste(cover_resized, (right_x, top_y))
    return np.array(template_copy)
def combined_function(prompt):
    """Generate cover art for *prompt* and lay it out on the cereal-box template.

    This is the Gradio click handler. It returns the finished template as the
    NumPy array produced by ``create_cereal_box``.
    """
    return create_cereal_box(text_to_image(prompt))
# Single-page Gradio UI: a header, a prompt box with a Generate button, and an
# image output that receives the finished cereal-box template.
with gr.Blocks() as app:
    gr.HTML("<div style='text-align: center;'><h1>Cereal Box Maker 🥣</h1></div>")
    gr.HTML("<div style='text-align: center;'><p>This application uses StableDiffusion XL to create any cereal box you could ever imagine!</p></div>")
    gr.HTML("<div style='text-align: center;'><h3>Instructions:</h3><ol><li>Describe the cereal box you want to create and hit generate!</li><li>Print it out, cut the outside, fold the lines, and then tape!</li></ol></div>")
    gr.HTML("<div style='text-align: center;'><p>A space by AP 🐧, follow me on <a href='https://twitter.com/angrypenguinPNG'>Twitter</a>! H/T to <a href='https://twitter.com/ostrisai'>OstrisAI</a> for their Cereal Box LoRA!</p></div>")
    with gr.Row():
        textbox = gr.Textbox(label="Describe your cereal box: Ex: 'Avengers Cereal'")
        # The button caption is its value. gr.Button has no `label` parameter,
        # and since Gradio 4 an unknown keyword raises TypeError, so none is passed.
        btn_generate = gr.Button("Generate")
    with gr.Row():
        output_img = gr.Image(label="Your Custom Cereal Box")
    btn_generate.click(
        combined_function,
        inputs=[textbox],
        outputs=[output_img],
    )

# Queue at most 20 pending requests. api_open=False: per Gradio's queue() docs,
# this keeps direct REST calls from bypassing the queue (confirm for the
# installed Gradio version).
app.queue(max_size=20, api_open=False)
app.launch()