cocktailpeanut commited on
Commit
56fa2d2
1 Parent(s): 8f4bc3f
Files changed (1) hide show
  1. app.py +10 -6
app.py CHANGED
@@ -43,8 +43,12 @@ if device != "cpu":
43
  decoder_pipeline = StableCascadeDecoderPipeline.from_pretrained("stabilityai/stable-cascade", torch_dtype=dtype)#.to(device)
44
 
45
  if ENABLE_CPU_OFFLOAD:
46
- prior_pipeline.enable_model_cpu_offload()
47
- decoder_pipeline.enable_model_cpu_offload()
 
 
 
 
48
  else:
49
  prior_pipeline.to(device)
50
  decoder_pipeline.to(device)
@@ -101,10 +105,10 @@ def generate(
101
  print("")
102
 
103
  #previewer.eval().requires_grad_(False).to(device).to(dtype)
104
- # if device != "cpu":
105
- # prior_pipeline.to(device)
106
- # decoder_pipeline.to(device)
107
- #
108
  generator = torch.Generator().manual_seed(seed)
109
  print("prior_num_inference_steps: ", prior_num_inference_steps)
110
  prior_output = prior_pipeline(
 
43
  decoder_pipeline = StableCascadeDecoderPipeline.from_pretrained("stabilityai/stable-cascade", torch_dtype=dtype)#.to(device)
44
 
45
  if ENABLE_CPU_OFFLOAD:
46
+ if device == "mps":
47
+ prior_pipeline.enable_attention_slicing()
48
+ decoder_pipeline.enable_attention_slicing()
49
+ else:
50
+ prior_pipeline.enable_model_cpu_offload()
51
+ decoder_pipeline.enable_model_cpu_offload()
52
  else:
53
  prior_pipeline.to(device)
54
  decoder_pipeline.to(device)
 
105
  print("")
106
 
107
  #previewer.eval().requires_grad_(False).to(device).to(dtype)
108
+ if device != "cpu":
109
+ prior_pipeline.to(device)
110
+ decoder_pipeline.to(device)
111
+
112
  generator = torch.Generator().manual_seed(seed)
113
  print("prior_num_inference_steps: ", prior_num_inference_steps)
114
  prior_output = prior_pipeline(