Spaces:
vilarin
/
Running on Zero

vilarin commited on
Commit
3eaeeea
1 Parent(s): 0fb0440

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +4 -4
app.py CHANGED
@@ -10,8 +10,8 @@ import PIL
10
  base = "stabilityai/stable-diffusion-xl-base-1.0"
11
  repo = "tianweiy/DMD2"
12
  checkpoints = {
13
- "1-Step" : ["dmd2_sdxl_1step_unet.bin", 1],
14
- "4-Step" : ["dmd2_sdxl_4step_unet.bin", 4],
15
  }
16
  loaded = None
17
 
@@ -37,7 +37,7 @@ def generate_image(prompt, ckpt):
37
  num_inference_steps = checkpoints[ckpt][1]
38
 
39
  if loaded != num_inference_steps:
40
- unet.load_state_dict(torch.load(hf_hub_download(repo, checkpoints)), map_location="cuda")
41
  pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config, timestep_spacing="trailing", prediction_type="sample" if num_inference_steps==1 else "epsilon")
42
  loaded = num_inference_steps
43
 
@@ -51,7 +51,7 @@ def generate_image(prompt, ckpt):
51
 
52
  with gr.Blocks(css=CSS) as demo:
53
  gr.HTML("<h1><center>Adobe DMD2🦖</center></h1>")
54
- gr.HTML("<p><center><a href='https://huggingface.co/tianweiy/DMD2'>https://huggingface.co/tianweiy/DMD2</a> text-to-image generation</center></p>")
55
  with gr.Group():
56
  with gr.Row():
57
  prompt = gr.Textbox(label='Enter your prompt (English)', scale=8)
 
10
  base = "stabilityai/stable-diffusion-xl-base-1.0"
11
  repo = "tianweiy/DMD2"
12
  checkpoints = {
13
+ "1-Step" : ["dmd2_sdxl_1step_unet_fp16.bin", 1],
14
+ "4-Step" : ["dmd2_sdxl_4step_unet_fp16.bin", 4],
15
  }
16
  loaded = None
17
 
 
37
  num_inference_steps = checkpoints[ckpt][1]
38
 
39
  if loaded != num_inference_steps:
40
+ unet.load_state_dict(torch.load(hf_hub_download(repo, checkpoint)), map_location="cuda")
41
  pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config, timestep_spacing="trailing", prediction_type="sample" if num_inference_steps==1 else "epsilon")
42
  loaded = num_inference_steps
43
 
 
51
 
52
  with gr.Blocks(css=CSS) as demo:
53
  gr.HTML("<h1><center>Adobe DMD2🦖</center></h1>")
54
+ gr.HTML("<p><center><a href='https://huggingface.co/tianweiy/DMD2'>DMD2</a> text-to-image generation</center></p>")
55
  with gr.Group():
56
  with gr.Row():
57
  prompt = gr.Textbox(label='Enter your prompt (English)', scale=8)