Spaces:
Running
on
Zero
Running
on
Zero
Update app.py
Browse files
app.py
CHANGED
@@ -13,8 +13,9 @@ import torch
|
|
13 |
from diffusers import DiffusionPipeline
|
14 |
from typing import Tuple
|
15 |
|
16 |
-
#
|
17 |
-
device =
|
|
|
18 |
|
19 |
# Setup rules for bad words (ensure the prompts are kid-friendly)
|
20 |
bad_words = json.loads(os.getenv('BAD_WORDS', '["violence", "blood", "scary", "death", "ghost"]'))
|
@@ -68,21 +69,21 @@ DESCRIPTION = """## Children's Sticker Generator
|
|
68 |
Generate fun and playful stickers for children using AI.
|
69 |
"""
|
70 |
|
71 |
-
if not torch.cuda.is_available():
|
72 |
-
DESCRIPTION += "\n<p>⚠️Running on CPU, This may not work on CPU.</p>"
|
73 |
-
|
74 |
MAX_SEED = np.iinfo(np.int32).max
|
75 |
CACHE_EXAMPLES = torch.cuda.is_available() and os.getenv("CACHE_EXAMPLES", "0") == "1"
|
76 |
|
77 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
78 |
|
79 |
-
# Initialize
|
80 |
-
|
81 |
-
"SG161222/RealVisXL_V3.0_Turbo", # or any model of your choice
|
82 |
-
torch_dtype=torch.float16,
|
83 |
-
use_safetensors=True,
|
84 |
-
variant="fp16"
|
85 |
-
).to(device)
|
86 |
|
87 |
# Convert mm to pixels for a specific DPI (300) and ensure divisible by 8
|
88 |
def mm_to_pixels(mm, dpi=300):
|
@@ -134,8 +135,15 @@ def generate(
|
|
134 |
guidance_scale: float = 3,
|
135 |
randomize_seed: bool = False,
|
136 |
background: str = "transparent",
|
|
|
137 |
progress=gr.Progress(track_tqdm=True),
|
138 |
):
|
|
|
|
|
|
|
|
|
|
|
|
|
139 |
if check_text(prompt, negative_prompt):
|
140 |
raise ValueError("Prompt contains restricted words.")
|
141 |
|
@@ -145,7 +153,7 @@ def generate(
|
|
145 |
# Apply style
|
146 |
prompt, negative_prompt = apply_style(style, prompt, negative_prompt)
|
147 |
seed = int(randomize_seed_fn(seed, randomize_seed))
|
148 |
-
generator = torch.Generator().manual_seed(seed)
|
149 |
|
150 |
# Ensure we have only white or transparent background options
|
151 |
width, height = size_map.get(size, (1024, 1024))
|
@@ -241,6 +249,11 @@ with gr.Blocks(css=css, theme="bethecloud/storj_theme") as demo:
|
|
241 |
step=0.1,
|
242 |
value=15.7,
|
243 |
)
|
|
|
|
|
|
|
|
|
|
|
244 |
|
245 |
gr.Examples(
|
246 |
examples=examples,
|
@@ -267,6 +280,7 @@ with gr.Blocks(css=css, theme="bethecloud/storj_theme") as demo:
|
|
267 |
guidance_scale,
|
268 |
randomize_seed,
|
269 |
background_selection,
|
|
|
270 |
],
|
271 |
outputs=[result, seed],
|
272 |
api_name="run",
|
|
|
13 |
from diffusers import DiffusionPipeline
|
14 |
from typing import Tuple
|
15 |
|
16 |
+
# Initialize device to None
|
17 |
+
device = None
|
18 |
+
pipe = None
|
19 |
|
20 |
# Setup rules for bad words (ensure the prompts are kid-friendly)
|
21 |
bad_words = json.loads(os.getenv('BAD_WORDS', '["violence", "blood", "scary", "death", "ghost"]'))
|
|
|
69 |
Generate fun and playful stickers for children using AI.
|
70 |
"""
|
71 |
|
|
|
|
|
|
|
72 |
MAX_SEED = np.iinfo(np.int32).max
|
73 |
CACHE_EXAMPLES = torch.cuda.is_available() and os.getenv("CACHE_EXAMPLES", "0") == "1"
|
74 |
|
75 |
+
def initialize_pipeline(device_type):
|
76 |
+
global device, pipe
|
77 |
+
device = torch.device(device_type)
|
78 |
+
pipe = DiffusionPipeline.from_pretrained(
|
79 |
+
"SG161222/RealVisXL_V3.0_Turbo",
|
80 |
+
torch_dtype=torch.float32 if device_type == "cpu" else torch.float16,
|
81 |
+
use_safetensors=True,
|
82 |
+
variant="fp32" if device_type == "cpu" else "fp16"
|
83 |
+
).to(device)
|
84 |
|
85 |
+
# Initialize with CPU by default
|
86 |
+
initialize_pipeline("cpu")
|
|
|
|
|
|
|
|
|
|
|
87 |
|
88 |
# Convert mm to pixels for a specific DPI (300) and ensure divisible by 8
|
89 |
def mm_to_pixels(mm, dpi=300):
|
|
|
135 |
guidance_scale: float = 3,
|
136 |
randomize_seed: bool = False,
|
137 |
background: str = "transparent",
|
138 |
+
device_type: str = "cpu",
|
139 |
progress=gr.Progress(track_tqdm=True),
|
140 |
):
|
141 |
+
global device, pipe
|
142 |
+
|
143 |
+
# Switch device if necessary
|
144 |
+
if device.type != device_type:
|
145 |
+
initialize_pipeline(device_type)
|
146 |
+
|
147 |
if check_text(prompt, negative_prompt):
|
148 |
raise ValueError("Prompt contains restricted words.")
|
149 |
|
|
|
153 |
# Apply style
|
154 |
prompt, negative_prompt = apply_style(style, prompt, negative_prompt)
|
155 |
seed = int(randomize_seed_fn(seed, randomize_seed))
|
156 |
+
generator = torch.Generator(device=device).manual_seed(seed)
|
157 |
|
158 |
# Ensure we have only white or transparent background options
|
159 |
width, height = size_map.get(size, (1024, 1024))
|
|
|
249 |
step=0.1,
|
250 |
value=15.7,
|
251 |
)
|
252 |
+
device_selection = gr.Radio(
|
253 |
+
choices=["cpu", "cuda"],
|
254 |
+
value="cpu",
|
255 |
+
label="Device",
|
256 |
+
)
|
257 |
|
258 |
gr.Examples(
|
259 |
examples=examples,
|
|
|
280 |
guidance_scale,
|
281 |
randomize_seed,
|
282 |
background_selection,
|
283 |
+
device_selection,
|
284 |
],
|
285 |
outputs=[result, seed],
|
286 |
api_name="run",
|