amazonaws-la commited on
Commit
1e95d75
1 Parent(s): 210bcbd

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +27 -4
app.py CHANGED
@@ -10,7 +10,7 @@ import numpy as np
10
  import PIL.Image
11
  import spaces
12
  import torch
13
- from diffusers import StableDiffusionKDiffusionPipeline, AutoencoderKL, DiffusionPipeline
14
 
15
  DESCRIPTION = "# SDXL"
16
  if not torch.cuda.is_available():
@@ -55,8 +55,8 @@ def generate(
55
  lora = 'amazonaws-la/juliette',
56
  ) -> PIL.Image.Image:
57
  if torch.cuda.is_available():
58
- pipe = StableDiffusionKDiffusionPipeline.from_pretrained(model, torch_dtype=torch.float16)
59
- pipe.set_scheduler('sample_dpmpp_2m_sde')
60
  if use_lora:
61
  pipe.load_lora_weights(lora)
62
  pipe.fuse_lora(lora_scale=0.7)
@@ -90,9 +90,32 @@ def generate(
90
  guidance_scale=guidance_scale_base,
91
  num_inference_steps=num_inference_steps_base,
92
  generator=generator,
93
- use_karras_sigmas=True,
94
  output_type="pil",
95
  ).images[0]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
96
 
97
 
98
  examples = [
 
10
  import PIL.Image
11
  import spaces
12
  import torch
13
+ from diffusers import AutoencoderKL, DiffusionPipeline
14
 
15
  DESCRIPTION = "# SDXL"
16
  if not torch.cuda.is_available():
 
55
  lora = 'amazonaws-la/juliette',
56
  ) -> PIL.Image.Image:
57
  if torch.cuda.is_available():
58
+ pipe = DiffusionPipeline.from_pretrained(model, torch_dtype=torch.float16, safety_checker=None)
59
+
60
  if use_lora:
61
  pipe.load_lora_weights(lora)
62
  pipe.fuse_lora(lora_scale=0.7)
 
90
  guidance_scale=guidance_scale_base,
91
  num_inference_steps=num_inference_steps_base,
92
  generator=generator,
 
93
  output_type="pil",
94
  ).images[0]
95
+ else:
96
+ latents = pipe(
97
+ prompt=prompt,
98
+ negative_prompt=negative_prompt,
99
+ prompt_2=prompt_2,
100
+ negative_prompt_2=negative_prompt_2,
101
+ width=width,
102
+ height=height,
103
+ guidance_scale=guidance_scale_base,
104
+ num_inference_steps=num_inference_steps_base,
105
+ generator=generator,
106
+ output_type="latent",
107
+ ).images
108
+ image = refiner(
109
+ prompt=prompt,
110
+ negative_prompt=negative_prompt,
111
+ prompt_2=prompt_2,
112
+ negative_prompt_2=negative_prompt_2,
113
+ guidance_scale=guidance_scale_refiner,
114
+ num_inference_steps=num_inference_steps_refiner,
115
+ image=latents,
116
+ generator=generator,
117
+ ).images[0]
118
+ return image
119
 
120
 
121
  examples = [