cocktailpeanut commited on
Commit
ca3008d
1 Parent(s): 6704916
Files changed (1) hide show
  1. src/pix2pix_turbo.py +2 -1
src/pix2pix_turbo.py CHANGED
@@ -109,7 +109,8 @@ class Pix2Pix_Turbo(torch.nn.Module):
109
  _sd_unet = unet.state_dict()
110
  for k in sd["state_dict_unet"]: _sd_unet[k] = sd["state_dict_unet"][k]
111
  unet.load_state_dict(_sd_unet)
112
- unet.enable_xformers_memory_efficient_attention()
 
113
  _sd_vae = vae.state_dict()
114
  for k in sd["state_dict_vae"]: _sd_vae[k] = sd["state_dict_vae"][k]
115
  vae.load_state_dict(_sd_vae)
 
109
  _sd_unet = unet.state_dict()
110
  for k in sd["state_dict_unet"]: _sd_unet[k] = sd["state_dict_unet"][k]
111
  unet.load_state_dict(_sd_unet)
112
+ if device == "cuda":
113
+ unet.enable_xformers_memory_efficient_attention()
114
  _sd_vae = vae.state_dict()
115
  for k in sd["state_dict_vae"]: _sd_vae[k] = sd["state_dict_vae"][k]
116
  vae.load_state_dict(_sd_vae)