Oranblock committed
Commit 0e09841
1 Parent(s): 5024e57

Update app.py

Files changed (1)
  1. app.py +45 -18
app.py CHANGED
@@ -11,10 +11,19 @@ from PIL import Image
 import torch
 from diffusers import DiffusionPipeline
 from typing import Tuple
+import logging
 
-# Force CPU usage
-device = torch.device("cpu")
-torch.cuda.is_available = lambda: False
+# Set up logging
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger(__name__)
+
+# Check for GPU availability and fall back to CPU if necessary
+if torch.cuda.is_available():
+    device = torch.device("cuda")
+    logger.info("GPU detected. Using CUDA.")
+else:
+    device = torch.device("cpu")
+    logger.warning("No GPU detected. Falling back to CPU. This will be slower.")
 
 # Setup rules for bad words (ensure the prompts are kid-friendly)
 bad_words = json.loads(os.getenv('BAD_WORDS', '["violence", "blood", "scary", "death", "ghost"]'))
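
Note: the removed lines forced CPU by monkey-patching torch.cuda.is_available, which is fragile because any code that captured the real function still sees the GPU. The replacement selects the device explicitly and logs the choice. A minimal standalone sketch of the same pattern, assuming only torch and the standard logging module:

    import logging
    import torch

    logging.basicConfig(level=logging.INFO)
    logger = logging.getLogger(__name__)

    # Prefer CUDA when available; otherwise fall back to CPU.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    logger.info("Using device: %s", device)
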
@@ -68,18 +77,32 @@ DESCRIPTION = """## Children's Sticker Generator
 
 Generate fun and playful stickers for children using AI.
 
-⚠️ Running on CPU. This may be slower.
 """
+DESCRIPTION += "🚀 Running on GPU for faster generation." if device.type == "cuda" else "⚠️ Running on CPU. This may be slower."
 
 MAX_SEED = np.iinfo(np.int32).max
-CACHE_EXAMPLES = False
+CACHE_EXAMPLES = True
 
 # Initialize the DiffusionPipeline
-pipe = DiffusionPipeline.from_pretrained(
-    "runwayml/stable-diffusion-v1-5",  # Using a smaller model for CPU
-    torch_dtype=torch.float32,
-    use_safetensors=True,
-).to(device)
+try:
+    if device.type == "cuda":
+        pipe = DiffusionPipeline.from_pretrained(
+            "stabilityai/stable-diffusion-xl-base-1.0",
+            torch_dtype=torch.float16,
+            use_safetensors=True,
+            variant="fp16",
+        ).to(device)
+        pipe.enable_xformers_memory_efficient_attention()
+    else:
+        pipe = DiffusionPipeline.from_pretrained(
+            "runwayml/stable-diffusion-v1-5",
+            torch_dtype=torch.float32,
+            use_safetensors=True,
+        ).to(device)
+    logger.info("DiffusionPipeline initialized successfully")
+except Exception as e:
+    logger.error(f"Error initializing DiffusionPipeline: {str(e)}")
+    raise
 
 # Convert mm to pixels for a specific DPI (300) and ensure divisible by 8
 def mm_to_pixels(mm, dpi=300):
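
Note: enable_xformers_memory_efficient_attention() raises if the xformers package is not importable, which would send this whole try block into the except branch on a GPU machine without xformers installed. A guarded variant (a sketch, not part of this commit; it reuses the pipe and logger defined above):

    try:
        pipe.enable_xformers_memory_efficient_attention()
    except Exception as exc:
        # xformers is an optional dependency; the default attention still works.
        logger.warning("xformers unavailable, using default attention: %s", exc)
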
@@ -127,7 +150,7 @@ def generate(
     style: str = DEFAULT_STYLE_NAME,
     seed: int = 0,
     size: str = "75mm",
-    guidance_scale: float = 3,
+    guidance_scale: float = 7.5,
     randomize_seed: bool = False,
     background: str = "transparent",
     progress=gr.Progress(track_tqdm=True),
@@ -142,9 +165,9 @@ def generate(
     # Apply style
     prompt, negative_prompt = apply_style(style, prompt, negative_prompt)
    seed = int(randomize_seed_fn(seed, randomize_seed))
-    generator = torch.manual_seed(seed)
+    generator = torch.Generator(device=device).manual_seed(seed)
 
-    width, height = size_map.get(size, (512, 512))
+    width, height = size_map.get(size, (1024, 1024))
 
     if not use_negative_prompt:
         negative_prompt = ""
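
Note: torch.manual_seed(seed) seeds and returns the global CPU generator, whereas torch.Generator(device=device).manual_seed(seed) creates a dedicated generator on the active device, keeping each request reproducible without touching global RNG state. A runnable sketch of the pattern:

    import torch

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    generator = torch.Generator(device=device).manual_seed(42)
    # Noise drawn from this generator is reproducible for a given seed.
    noise = torch.randn((1, 4, 64, 64), generator=generator, device=device)
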
@@ -155,9 +178,9 @@
         "width": width,
         "height": height,
         "guidance_scale": guidance_scale,
-        "num_inference_steps": 20,
+        "num_inference_steps": 30 if device.type == "cuda" else 20,
         "generator": generator,
-        "num_images_per_prompt": 1,  # Reduced to 1 for CPU
+        "num_images_per_prompt": 4 if device.type == "cuda" else 1,
         "output_type": "pil",
     }
 
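Note: num_images_per_prompt=4 grows the UNet's activation memory roughly linearly with batch size. If VRAM is tight, generating images one at a time with derived seeds is a common fallback; a sketch, assuming the kwargs dict above is bound to a name such as options (the name is not visible in this hunk):

    images = []
    for i in range(4):
        options["generator"] = torch.Generator(device=device).manual_seed(seed + i)
        options["num_images_per_prompt"] = 1
        images.extend(pipe(**options).images)
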
@@ -196,7 +219,7 @@ with gr.Blocks(css=css, theme="bethecloud/storj_theme") as demo:
             container=False,
         )
         run_button = gr.Button("Run")
-        result = gr.Gallery(label="Generated Stickers", columns=1, preview=True)
+        result = gr.Gallery(label="Generated Stickers", columns=2, preview=True)
         error_output = gr.Textbox(label="Error", visible=False)
         with gr.Accordion("Advanced options", open=False):
             use_negative_prompt = gr.Checkbox(label="Use negative prompt", value=True, visible=True)
@@ -233,7 +256,7 @@ with gr.Blocks(css=css, theme="bethecloud/storj_theme") as demo:
         )
         guidance_scale = gr.Slider(
             label="Guidance Scale",
-            minimum=0.1,
+            minimum=1.0,
             maximum=20.0,
             step=0.1,
             value=7.5,
@@ -270,4 +293,8 @@ with gr.Blocks(css=css, theme="bethecloud/storj_theme") as demo:
     )
 
 if __name__ == "__main__":
-    demo.queue(max_size=20).launch()
+    try:
+        demo.queue(max_size=20).launch()
+    except Exception as e:
+        logger.error(f"Error launching Gradio interface: {str(e)}")
+        raise