cocktailpeanut committed on
Commit
824b2b5
1 Parent(s): da30812
Files changed (2)
  1. app.py +30 -22
  2. requirements.txt +1 -1
app.py CHANGED
@@ -21,26 +21,26 @@ DESCRIPTION += "\n<p style=\"text-align: center\">Unofficial demo for <a href='h
 
 MAX_SEED = np.iinfo(np.int32).max
 CACHE_EXAMPLES = False
-#CACHE_EXAMPLES = torch.cuda.is_available() and os.getenv("CACHE_EXAMPLES") != "0"
 MAX_IMAGE_SIZE = int(os.getenv("MAX_IMAGE_SIZE", "1536"))
 USE_TORCH_COMPILE = False
 ENABLE_CPU_OFFLOAD = os.getenv("ENABLE_CPU_OFFLOAD") == "1"
+#PREVIEW_IMAGES = False
 PREVIEW_IMAGES = True
 
 dtype = torch.bfloat16
-#device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
 if torch.cuda.is_available():
-    device = "cuda"
+    device = "cuda"
 elif torch.backends.mps.is_available():
-    device = "mps"
-    dtype = torch.float32
+    device = "mps"
+    dtype = torch.float32
 else:
-    device = "cpu"
+    device = "cpu"
+
 print(f"device={device}")
-#if torch.cuda.is_available():
+
 if device != "cpu":
-    prior_pipeline = StableCascadePriorPipeline.from_pretrained("stabilityai/stable-cascade-prior", torch_dtype=dtype, revision="refs/pr/2")#.to(device)
-    decoder_pipeline = StableCascadeDecoderPipeline.from_pretrained("stabilityai/stable-cascade", torch_dtype=dtype, revision="refs/pr/44")#.to(device)
+    prior_pipeline = StableCascadePriorPipeline.from_pretrained("stabilityai/stable-cascade-prior", torch_dtype=dtype)#.to(device)
+    decoder_pipeline = StableCascadeDecoderPipeline.from_pretrained("stabilityai/stable-cascade", torch_dtype=dtype)#.to(device)
 
 if ENABLE_CPU_OFFLOAD:
     prior_pipeline.enable_model_cpu_offload()
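The `#.to(device)` suffixes on the `from_pretrained` calls are deliberate: when `ENABLE_CPU_OFFLOAD` is set, accelerate manages device placement per submodule, and an eager whole-pipeline move would defeat the offload. A minimal sketch of that placement decision, reusing only names defined in this hunk (the decoder branch is assumed symmetrical):

# Sketch only: offload and an explicit .to(device) are mutually exclusive.
if ENABLE_CPU_OFFLOAD:
    prior_pipeline.enable_model_cpu_offload()    # submodules visit the GPU only while running
    decoder_pipeline.enable_model_cpu_offload()
elif device != "cpu":
    prior_pipeline.to(device)                    # conventional one-shot placement
    decoder_pipeline.to(device)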
@@ -57,10 +57,12 @@ if device != "cpu":
     previewer = Previewer()
     previewer_state_dict = torch.load("previewer/previewer_v1_100k.pt", map_location=torch.device('cpu'))["state_dict"]
     previewer.load_state_dict(previewer_state_dict)
-    def callback_prior(i, t, latents):
+    def callback_prior(pipeline, step_index, t, callback_kwargs):
+        latents = callback_kwargs["latents"]
         output = previewer(latents)
         output = numpy_to_pil(output.clamp(0, 1).permute(0, 2, 3, 1).float().cpu().numpy())
-        return output
+        callback_kwargs["preview_output"] = output
+        return callback_kwargs
     callback_steps = 1
 else:
     previewer = None
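The four-argument signature follows the `callback_on_step_end` protocol of newer diffusers releases: the callback receives the pipeline, the step index, the timestep, and a dict of requested tensors, and must return that dict. A minimal sketch of the contract (the function name and prompt are illustrative; the call site later in this diff keeps the hookup commented out):

def on_step_end(pipeline, step_index, timestep, callback_kwargs):
    latents = callback_kwargs["latents"]  # present because it is requested below
    # ...inspect or modify the requested tensors here...
    return callback_kwargs                # returning the dict is mandatory

result = prior_pipeline(
    prompt="an example prompt",
    callback_on_step_end=on_step_end,
    callback_on_step_end_tensor_inputs=["latents"],
)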
@@ -92,14 +94,19 @@ def generate(
     num_images_per_prompt: int = 2,
     # profile: gr.OAuthProfile | None = None,
 ) -> PIL.Image.Image:
+
     try:
         previewer.eval().requires_grad_(False).to(device).to(dtype)
     except:
         print("")
-    prior_pipeline.to(device)
-    decoder_pipeline.to(device)
 
+    #previewer.eval().requires_grad_(False).to(device).to(dtype)
+    # if device != "cpu":
+    #     prior_pipeline.to(device)
+    #     decoder_pipeline.to(device)
+    #
     generator = torch.Generator().manual_seed(seed)
+    print("prior_num_inference_steps: ", prior_num_inference_steps)
     prior_output = prior_pipeline(
         prompt=prompt,
         height=height,
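The bare `try/except` exists because `previewer` is `None` on CPU-only runs (see the earlier hunk), so the chained attribute access raises there. A sketch of the same intent stated explicitly:

# Sketch: equivalent guard without the silent bare except.
if previewer is not None:
    previewer.eval().requires_grad_(False).to(device).to(dtype)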
@@ -110,15 +117,17 @@ def generate(
         guidance_scale=prior_guidance_scale,
         num_images_per_prompt=num_images_per_prompt,
         generator=generator,
-        callback=callback_prior,
-        callback_steps=callback_steps
+        #callback_on_step_end=callback_prior,
+        #callback_on_step_end_tensor_inputs=['latents']
     )
-
     if PREVIEW_IMAGES:
         for _ in range(len(DEFAULT_STAGE_C_TIMESTEPS)):
-            r = next(prior_output)
-            if isinstance(r, list):
-                yield r[0]
+            try:
+                r = next(prior_output)
+                if isinstance(r, list):
+                    yield r[0]
+            except:
+                print("")
         prior_output = r
 
     decoder_output = decoder_pipeline(
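With the callback arguments commented out, the previews rely on the pinned diffusers fork, whose prior pipeline is treated as a generator here: each `next()` is expected to yield a list of preview images, with the final value being the ordinary prior output. The new `try/except` lets the loop fail quietly when a non-generator pipeline is installed. A self-contained sketch of the assumed contract (all names are illustrative stand-ins):

# Hypothetical stand-in for the fork's yielding pipeline.
def fake_prior_pipeline(num_steps=3):
    for step in range(num_steps):
        yield [f"preview image, step {step}"]  # stands in for a list of PIL images
    yield "final prior output"                 # stands in for the real output object

r = None
for r in fake_prior_pipeline():
    if isinstance(r, list):
        print("show preview:", r[0])
prior_output = r  # after exhaustion, r holds the final output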
@@ -131,7 +140,6 @@ def generate(
         generator=generator,
         output_type="pil",
     ).images
-
     # #Save images
     # for image in decoder_output:
     #     user_history.save_image(
@@ -156,7 +164,7 @@ def generate(
 examples = [
     "An astronaut riding a green horse",
     "A mecha robot in a favela by Tarsila do Amaral",
-    "The sprirt of a Tamagotchi wandering in the city of Los Angeles",
+    "The spirit of a Tamagotchi wandering in the city of Los Angeles",
     "A delicious feijoada ramen dish"
 ]
 
@@ -289,4 +297,4 @@ with gr.Blocks(css="style.css") as demo_with_history:
     # user_history.render()
 
 if __name__ == "__main__":
-    demo_with_history.queue(max_size=20).launch()
+    demo_with_history.queue(max_size=20).launch()
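A note on the launch line: the intermediate images that `generate` yields reach the browser only because the app goes through `.queue()`, since Gradio streams generator outputs via its queue. A self-contained sketch of the same pattern (names are illustrative):

import gradio as gr

# Generator-backed app: each yield is streamed to the client,
# which requires the queue to be enabled before launch.
def stream_previews(prompt):
    for step in range(3):
        yield f"preview {step} for {prompt!r}"

demo = gr.Interface(stream_previews, gr.Textbox(), gr.Textbox())

if __name__ == "__main__":
    demo.queue(max_size=20).launch()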
 
requirements.txt CHANGED
@@ -1,6 +1,6 @@
 #git+https://github.com/kashif/diffusers.git@diffusers-yield-callback
 #git+https://github.com/kashif/diffusers.git@a3dc21385b7386beb3dab3a9845962ede6765887
-diffusers
+git+https://github.com/apolinario/diffusers.git@yield-new-pipe
 accelerate
 safetensors
 transformers
 
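The swap replaces the stock PyPI `diffusers` release with the fork's `yield-new-pipe` branch, which presumably provides the yielding prior pipeline the preview loop consumes; locally it is the equivalent of `pip install "git+https://github.com/apolinario/diffusers.git@yield-new-pipe"`.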