Jordan Legg commited on
Commit
b473829
β€’
1 Parent(s): d2c614b
Files changed (1) hide show
  1. app.py +7 -6
app.py CHANGED
@@ -1,19 +1,20 @@
1
- import spaces
2
  import gradio as gr
3
  import torch
4
  from PIL import Image
5
  from diffusers import DiffusionPipeline
6
-
7
 
8
  # Constants
9
  MAX_SEED = 2**32 - 1
10
  MAX_IMAGE_SIZE = 2048
 
 
11
 
12
  # Load FLUX model
13
- pipe = DiffusionPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=torch.float16)
14
- pipe.to("cuda")
15
  pipe.enable_model_cpu_offload()
16
- pipe.enable_vae_slicing()
 
17
 
18
  def print_model_shapes(pipe):
19
  print("Model component shapes:")
@@ -26,7 +27,7 @@ print_model_shapes(pipe)
26
 
27
  @spaces.GPU()
28
  def infer(prompt, init_image=None, seed=None, width=1024, height=1024, num_inference_steps=4, guidance_scale=0.0):
29
- generator = torch.Generator(device="cuda").manual_seed(seed) if seed is not None else None
30
 
31
  try:
32
  if init_image is None:
 
 
1
  import gradio as gr
2
  import torch
3
  from PIL import Image
4
  from diffusers import DiffusionPipeline
5
+ import spaces
6
 
7
  # Constants
8
  MAX_SEED = 2**32 - 1
9
  MAX_IMAGE_SIZE = 2048
10
+ dtype = torch.bfloat16
11
+ device = "cuda" if torch.cuda.is_available() else "cpu"
12
 
13
  # Load FLUX model
14
+ pipe = DiffusionPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=dtype).to(device)
 
15
  pipe.enable_model_cpu_offload()
16
+ pipe.vae.enable_slicing()
17
+ pipe.vae.enable_tiling()
18
 
19
  def print_model_shapes(pipe):
20
  print("Model component shapes:")
 
27
 
28
  @spaces.GPU()
29
  def infer(prompt, init_image=None, seed=None, width=1024, height=1024, num_inference_steps=4, guidance_scale=0.0):
30
+ generator = torch.Generator(device=device).manual_seed(seed) if seed is not None else None
31
 
32
  try:
33
  if init_image is None: