Spaces:
Running
on
Zero
Running
on
Zero
#!/usr/bin/env python | |
import os | |
import random | |
import uuid | |
import json | |
import re | |
import gradio as gr | |
import numpy as np | |
from PIL import Image | |
import spaces | |
import torch | |
from diffusers import DiffusionPipeline | |
from typing import Tuple | |
# Active torch device and diffusion pipeline; both start unset and are
# rebound by initialize_pipeline().
device = pipe = None

# Restricted words used to keep prompts kid-friendly. Override them with the
# BAD_WORDS environment variable, which should hold a JSON array of strings.
_bad_words_json = os.getenv(
    "BAD_WORDS", '["violence", "blood", "scary", "death", "ghost"]'
)
bad_words = json.loads(_bad_words_json)

# Optional default negative prompt read from the environment ("" if unset).
# NOTE(review): not referenced by any other code visible in this file.
default_negative = os.getenv("default_negative", "")
def check_text(prompt, negative="", words=None):
    """Return True if *prompt* contains any restricted word.

    Matching is a case-insensitive substring test, so with the word "blood"
    both "Blood" and "bloody" are rejected.

    Args:
        prompt: User-supplied positive prompt to screen. Empty/None is allowed.
        negative: Accepted for call-site compatibility but deliberately NOT
            screened: a negative prompt is meant to name unwanted concepts
            (the UI default is "(scary, violent, dark, ugly)"), so checking
            it would reject ordinary requests.
        words: Optional iterable of restricted words; defaults to the
            module-level ``bad_words`` list.

    Returns:
        True if a restricted word occurs in *prompt*, otherwise False.
    """
    if not prompt:
        return False
    if words is None:
        words = bad_words
    haystack = prompt.casefold()
    # Skip empty entries: "" is a substring of every prompt and would block all input.
    return any(str(word).casefold() in haystack for word in words if word)
# Kid-friendly styles as (name, positive template, negative prompt) rows.
# Each template embeds the user's text at the "{prompt}" placeholder.
_STYLE_ROWS = (
    (
        "Cartoon",
        "colorful cartoon {prompt}. vibrant, playful, friendly, suitable for children, highly detailed, bright colors",
        "scary, dark, violent, ugly, realistic",
    ),
    (
        "Children's Illustration",
        "children's illustration {prompt}. cute, colorful, fun, simple shapes, smooth lines, highly detailed, joyful",
        "scary, dark, violent, deformed, ugly",
    ),
    (
        "Sticker",
        "children's sticker of {prompt}. bright colors, playful, high resolution, cartoonish",
        "scary, dark, violent, ugly, low resolution",
    ),
    (
        "Fantasy",
        "fantasy world for children with {prompt}. magical, vibrant, friendly, beautiful, colorful",
        "dark, scary, violent, ugly, realistic",
    ),
    (
        "(No style)",
        "{prompt}",
        "",
    ),
)
style_list = [
    {"name": name, "prompt": template, "negative_prompt": negative}
    for name, template, negative in _STYLE_ROWS
]
# Lookup table: style name -> (positive template, negative prompt).
styles = dict(
    (entry["name"], (entry["prompt"], entry["negative_prompt"])) for entry in style_list
)
STYLE_NAMES = [*styles]
DEFAULT_STYLE_NAME = "Sticker"
def apply_style(style_name: str, positive: str, negative: str = "") -> Tuple[str, str]:
    """Expand *positive* into a style template and merge the negative prompts.

    Unknown *style_name* values fall back to DEFAULT_STYLE_NAME.

    Args:
        style_name: Key into ``styles``.
        positive: User prompt substituted for "{prompt}" in the template.
        negative: Extra negative prompt supplied by the user (may be empty).

    Returns:
        (styled_prompt, combined_negative). The combined negative joins the
        style's negative prompt and *negative* with ", ", skipping empty
        parts. Plain concatenation would glue the last style term onto the
        user's first term (e.g. "low resolution(scary").
    """
    template, style_negative = styles.get(style_name, styles[DEFAULT_STYLE_NAME])
    parts = (style_negative.strip(), (negative or "").strip())
    combined_negative = ", ".join(part for part in parts if part)
    return template.replace("{prompt}", positive), combined_negative
# Markdown header rendered at the top of the UI.
DESCRIPTION = (
    "## Children's Sticker Generator\n"
    "Generate fun and playful stickers for children using AI.\n"
)
# Largest seed accepted by the UI slider and by randomize_seed_fn.
MAX_SEED = np.iinfo(np.int32).max
# Pre-compute example outputs only when opted in via env AND a GPU exists.
CACHE_EXAMPLES = os.getenv("CACHE_EXAMPLES", "0") == "1" and torch.cuda.is_available()
def initialize_pipeline(device_type):
    """(Re)load the diffusion pipeline onto *device_type* and publish it globally.

    Loads "SG161222/RealVisXL_V3.0_Turbo" in float32 on CPU and float16 on any
    other device, then moves it there. The module-level ``device`` and ``pipe``
    globals are rebound only after loading succeeds. A failed switch therefore
    leaves the previous, mutually consistent pair in place, and a later call
    will retry. The old pipeline stays referenced until the new one is ready,
    so peak memory briefly holds both.

    Args:
        device_type: torch device string such as "cpu" or "cuda".

    Raises:
        RuntimeError: if a CUDA device is requested but torch reports none,
            or if *device_type* is not a valid torch device string.
    """
    global device, pipe
    new_device = torch.device(device_type)
    if new_device.type == "cuda" and not torch.cuda.is_available():
        raise RuntimeError("CUDA was requested but no CUDA device is available.")
    # float16 is poorly supported on CPU, so only use it off-CPU.
    dtype = torch.float32 if new_device.type == "cpu" else torch.float16
    new_pipe = DiffusionPipeline.from_pretrained(
        "SG161222/RealVisXL_V3.0_Turbo",
        torch_dtype=dtype,
        use_safetensors=True,
    ).to(new_device)
    # Publish both together so `device` never disagrees with where `pipe` lives.
    device, pipe = new_device, new_pipe
# Load the pipeline on CPU at import time; generate() switches devices on demand.
initialize_pipeline(device_type="cpu")
def mm_to_pixels(mm, dpi=300):
    """Return the pixel length of *mm* millimetres at *dpi*, floored to a multiple of 8.

    The length is converted to inches (25.4 mm per inch), scaled by *dpi*,
    truncated to an int, then rounded DOWN to a multiple of 8, presumably
    because the diffusion pipeline wants sizes divisible by 8.
    """
    raw_pixels = int((mm / 25.4) * dpi)
    return (raw_pixels // 8) * 8


# Supported sticker sizes at 300 dpi: label -> (width, height) in pixels.
# Sizes are floored (not rounded) to a multiple of 8: 75mm -> 880px, 35mm -> 408px.
size_map = {f"{mm}mm": (mm_to_pixels(mm),) * 2 for mm in (75, 35)}
def save_image(img, background="transparent", white_threshold=255):
    """Post-process a generated image and save it as a uniquely named PNG.

    The image is first converted to RGBA. With ``background="transparent"``
    every pixel whose R, G and B channels are all >= *white_threshold*
    becomes fully transparent white ``(255, 255, 255, 0)``. The default
    threshold of 255 matches only pure white. Any other *background* value
    (the UI offers "white") leaves the pixels unchanged.

    Args:
        img: PIL image to process.
        background: "transparent" to knock out white pixels; anything else
            keeps the image as-is.
        white_threshold: Minimum value (0-255) each RGB channel must reach
            for a pixel to count as white.

    Returns:
        The file name ``<uuid4>.png`` written in the current working directory.
    """
    img = img.convert("RGBA")
    if background == "transparent":
        # Vectorised mask instead of a per-pixel Python loop over every pixel.
        rgba = np.array(img)  # writable (height, width, 4) uint8 copy
        is_white = np.all(rgba[..., :3] >= white_threshold, axis=-1)
        rgba[is_white] = (255, 255, 255, 0)
        img = Image.fromarray(rgba)
    file_name = f"{uuid.uuid4()}.png"
    img.save(file_name)
    return file_name
def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
    """Return a fresh random seed in [0, MAX_SEED] if *randomize_seed*, else *seed* unchanged."""
    return random.randint(0, MAX_SEED) if randomize_seed else seed
def generate(
    prompt: str,
    negative_prompt: str = "",
    use_negative_prompt: bool = False,
    style: str = DEFAULT_STYLE_NAME,
    seed: int = 0,
    size: str = "75mm",
    guidance_scale: float = 3,
    randomize_seed: bool = False,
    background: str = "transparent",
    device_type: str = "cpu",
    progress=gr.Progress(track_tqdm=True),
):
    """Generate six styled sticker images for a short prompt and save them as PNGs.

    Args:
        prompt: User prompt; only its first three word tokens are kept.
        negative_prompt: Extra negative prompt, used only if *use_negative_prompt*.
        use_negative_prompt: Whether to add *negative_prompt* to the style's
            own negative prompt. The style's negative prompt is always kept.
        style: Key into ``styles``; unknown names fall back in apply_style.
        seed: Seed for the torch generator (replaced when *randomize_seed*).
        size: Key into ``size_map``; unknown values fall back to 1024x1024.
        guidance_scale: Guidance scale passed to the pipeline.
        randomize_seed: Draw a fresh seed instead of using *seed*.
        background: Passed to save_image ("transparent" or "white").
        device_type: torch device string; the pipeline is reloaded on change.
        progress: Gradio progress tracker (mirrors tqdm bars).

    Returns:
        (image_paths, seed): saved PNG file names and the seed actually used.

    Raises:
        ValueError: if the prompt contains a restricted word, or no words
            remain after trimming it.
    """
    # Reload only when the requested device really differs. Comparing parsed
    # devices (not `.type` vs the raw string) avoids reloading on "cuda:0".
    if device is None or device != torch.device(device_type):
        initialize_pipeline(device_type)

    if check_text(prompt, negative_prompt):
        raise ValueError("Prompt contains restricted words.")

    # Keep the prompt short: at most its first three word tokens.
    short_prompt = " ".join(re.findall(r"\w+", prompt or "")[:3])
    if not short_prompt:
        raise ValueError("Please enter a short prompt (e.g. 'cute bunny').")

    # Drop the user's negative prompt BEFORE styling. Clearing it afterwards
    # would also discard the style's kid-safety negatives ("scary, dark, ...").
    user_negative = negative_prompt if use_negative_prompt else ""
    styled_prompt, styled_negative = apply_style(style, short_prompt, user_negative)

    seed = int(randomize_seed_fn(seed, randomize_seed))
    generator = torch.Generator(device=device).manual_seed(seed)

    width, height = size_map.get(size, (1024, 1024))

    images = pipe(
        prompt=styled_prompt,
        negative_prompt=styled_negative,
        width=width,
        height=height,
        guidance_scale=guidance_scale,
        num_inference_steps=25,
        generator=generator,
        num_images_per_prompt=6,
        output_type="pil",
    ).images
    # save_image only knows "transparent"; other values keep the image as-is.
    image_paths = [save_image(img, background) for img in images]
    return image_paths, seed
# Short sample prompts offered in the UI's examples panel.
examples = ["cute bunny", "happy cat", "funny dog"]
# Page CSS: caps the Gradio container width at 700px and centres h1 headings.
css = "\n.gradio-container{max-width: 700px !important}\nh1{text-align:center}\n"
# Define the Gradio UI for the sticker generator.
# Components render in creation order, so statement order here is the layout.
with gr.Blocks(css=css, theme="bethecloud/storj_theme") as demo:
    gr.Markdown(DESCRIPTION)
    # Shown only when the SHOW_DUPLICATE_BUTTON env var equals "1".
    gr.DuplicateButton(
        value="Duplicate Space for private use",
        elem_id="duplicate-button",
        visible=os.getenv("SHOW_DUPLICATE_BUTTON") == "1",
    )
    # Prompt box and Run button on one row, with the results gallery below.
    with gr.Group():
        with gr.Row():
            prompt = gr.Text(
                label="Enter your prompt",
                show_label=False,
                max_lines=1,
                placeholder="Enter 2-3 word prompt (e.g., cute bunny)",
                container=False,
            )
            run_button = gr.Button("Run")
        result = gr.Gallery(label="Generated Stickers", columns=2, preview=True)
    # Collapsed panel holding the remaining generate() parameters.
    with gr.Accordion("Advanced options", open=False):
        # When unchecked, generate() does not use the text in negative_prompt.
        use_negative_prompt = gr.Checkbox(label="Use negative prompt", value=True, visible=True)
        negative_prompt = gr.Text(
            label="Negative prompt",
            max_lines=1,
            placeholder="Enter a negative prompt",
            value="(scary, violent, dark, ugly)",
            visible=True,
        )
        # Also an output: generate() writes back the seed it actually used.
        seed = gr.Slider(
            label="Seed",
            minimum=0,
            maximum=MAX_SEED,
            step=1,
            value=0,
            visible=True
        )
        randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
        # Labels correspond to size_map keys; generate() falls back to
        # 1024x1024 for any other label.
        size_selection = gr.Radio(
            choices=["75mm", "35mm"],
            value="75mm",
            label="Sticker Size",
        )
        style_selection = gr.Radio(
            choices=STYLE_NAMES,
            value=DEFAULT_STYLE_NAME,
            label="Image Style",
        )
        # Passed through to save_image(); "transparent" knocks out white pixels.
        background_selection = gr.Radio(
            choices=["transparent", "white"],
            value="transparent",
            label="Background Color",
        )
        # NOTE(review): this UI default (15.7) differs from generate()'s
        # default guidance_scale of 3; confirm which is intended.
        guidance_scale = gr.Slider(
            label="Guidance Scale",
            minimum=0.1,
            maximum=20.0,
            step=0.1,
            value=15.7,
        )
        # Choosing a different device makes generate() reload the pipeline.
        device_selection = gr.Radio(
            choices=["cpu", "cuda"],
            value="cpu",
            label="Device",
        )
    # Examples wire only `prompt` as input. If CACHE_EXAMPLES is true they are
    # pre-run through generate with just that argument, so every other
    # parameter takes its default from generate()'s signature.
    gr.Examples(
        examples=examples,
        inputs=prompt,
        outputs=[result, seed],
        fn=generate,
        cache_examples=CACHE_EXAMPLES,
    )
    # Enter in either text box, or clicking Run, starts generation.
    # The `inputs` order must match generate()'s positional parameters.
    gr.on(
        triggers=[
            prompt.submit,
            negative_prompt.submit,
            run_button.click,
        ],
        fn=generate,
        inputs=[
            prompt,
            negative_prompt,
            use_negative_prompt,
            style_selection,
            seed,
            size_selection,
            guidance_scale,
            randomize_seed,
            background_selection,
            device_selection,
        ],
        outputs=[result, seed],
        api_name="run",
    )
if __name__ == "__main__":
    # Enable request queueing (at most 20 queued events), then serve the app.
    app = demo.queue(max_size=20)
    app.launch()