SD 2.1 unet NaNs in output
Hey Ollin, I'm sorry if this isn't the best place to start a discussion on the issue, but do you have any idea why the SD 2.1-v UNet produces NaNs at some positions on every iteration in Diffusers, Auto, and ComfyUI?
I did some mucking about and discovered that running the VAE in fp32 doesn't avoid the black outputs, because every iteration step operates on bad latents with many NaN positions. Doing crude things like setting those positions to zero or 1e-10 doesn't fix anything; it only reveals that the model is returning pure noise.
2.1-base (512px) seems unaffected by the problem, and other v-prediction models (e.g. ptx0/terminus-xl-gamma-v2-1) don't exhibit it at fp16 precision either, although that one uses OpenCLIP-G/14 and CLIP-L/14 instead of OpenCLIP-H/14.
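To quantify how bad the latents are at each step, one could count the non-finite positions. A minimal NumPy sketch; `nonfinite_report` is a hypothetical helper, not part of any of the UIs. With a real pipeline you would pass it something like `latents.float().cpu().numpy()`, for example from a `callback_on_step_end` hook, assuming your diffusers version supports that callback.

```python
import numpy as np

def nonfinite_report(latents):
    """Summarize how many positions of a latent tensor are NaN or +/-inf."""
    x = np.asarray(latents)
    n_nan = int(np.isnan(x).sum())
    n_inf = int(np.isinf(x).sum())
    return {
        "nan": n_nan,
        "inf": n_inf,
        "fraction_bad": (n_nan + n_inf) / x.size,
    }

# Toy 1x4x2x2 "latent" (16 values) with two poisoned positions.
latents = np.zeros((1, 4, 2, 2), dtype=np.float16)
latents[0, 0, 0, 0] = np.nan
latents[0, 1, 1, 1] = np.inf
print(nonfinite_report(latents))
```

Tracking `fraction_bad` per step would show whether the NaNs start at the first UNet call or accumulate over the sampling loop.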
Hmm, do you have an example notebook or script that reproduces the issue? If it's a model problem (like the SDXL VAE NaNs), I have some helper code for logging activations that could help narrow things down.
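For context on why a few overflowing activations can blank a whole image: fp16 tops out at 65504, anything larger becomes inf, inf quickly turns into NaN (e.g. inf - inf), and NaN then contaminates everything computed from it, including the next step's latents. A small NumPy illustration, not taken from any of the codebases discussed:

```python
import numpy as np

FP16_MAX = float(np.finfo(np.float16).max)  # largest finite fp16 value

# Suppress NumPy's overflow/invalid warnings; we want to look at the values.
with np.errstate(over="ignore", invalid="ignore"):
    a = np.float16(300.0)
    big = a * a                             # 90000 > 65504, so the fp16 result is inf
    undefined = big - big                   # inf - inf is undefined: nan
    poisoned = np.float16(1.0) + undefined  # any arithmetic involving nan stays nan

print(FP16_MAX, big, undefined, poisoned)
```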
Essentially, manually casting the UNet of stabilityai/stable-diffusion-2-1 to fp16 will reproduce the issue:
```python
import torch
from diffusers import DiffusionPipeline
pipe = DiffusionPipeline.from_pretrained('stabilityai/stable-diffusion-2-1')
pipe.unet = pipe.unet.to(torch.float16)
```
The VAE and text encoder can run at fp16 just fine.
We also discovered that the single-file, pruned, non-EMA 2.1-v checkpoint works correctly at fp16.
Digging around, it looks like Patrick summarized the SD 2.1-768 fp16 issue and its mitigations here: https://github.com/huggingface/diffusers/issues/1614#issuecomment-1385093065. I haven't figured out how to reproduce the NaNs in the latest diffusers yet (I guess the mitigations are working).
In my quick test, the overflowing values seem to be in the attn1 layers of the last two up_blocks, so maybe you only need to upcast those? But I still haven't figured out how to disable the upcasting and reproduce the NaNs :p
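A toy NumPy sketch of why upcasting attention helps: the q·kᵀ scores and the softmax are computed in fp32, and only the probabilities (which lie in [0, 1]) are cast back to fp16. This mirrors the idea behind diffusers' `upcast_attention` option but is not its implementation; the shapes and values are made up to force an overflow.

```python
import numpy as np

def softmax(x, axis=-1):
    # Subtracting the row max is the usual stability trick,
    # but with an inf in the row it computes inf - inf = nan.
    z = x - x.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

def attention(q, k, v, upcast=False):
    """Toy single-head scaled dot-product attention on (tokens, dim) arrays."""
    scale = q.shape[-1] ** -0.5  # Python float, so it does not change the array dtype
    with np.errstate(over="ignore", invalid="ignore"):
        if upcast:
            # Scores and softmax in fp32, then cast the probabilities back.
            scores = (q.astype(np.float32) @ k.astype(np.float32).T) * scale
            probs = softmax(scores).astype(q.dtype)
        else:
            scores = (q @ k.T) * scale  # stored as fp16, may overflow to inf
            probs = softmax(scores)
        return probs @ v

# Each q·k dot product is 4 * 300 * 300 = 360000, far above the fp16 maximum.
q = np.full((2, 4), 300.0, dtype=np.float16)
k = np.full((3, 4), 300.0, dtype=np.float16)
v = np.arange(12, dtype=np.float16).reshape(3, 4)

out_fp16 = attention(q, k, v)                 # all NaN
out_upcast = attention(q, k, v, upcast=True)  # finite: uniform average of v's rows
print(out_fp16)
print(out_upcast)
```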
```python
import torch as th
from diffusers import DiffusionPipeline

# Load SD 2.1-v (768px) in fp16; the explicit UNet cast is redundant after
# torch_dtype=th.float16 but makes the fp16 UNet explicit.
pipe = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-2-1", torch_dtype=th.float16)
pipe.unet = pipe.unet.to(th.float16)
pipe = pipe.to("cuda")

def summarize_tensor(x):
    x = x.float()  # summarize in fp32 so the statistics themselves can't overflow
    max_abs = x.abs().max().item()
    # Green below 10, yellow below 100, red below 1000; fall back to red for
    # anything larger or non-finite (NaN fails every comparison).
    color = next((c for limit, c in ((10, "2"), (100, "3"), (1000, "1")) if max_abs < limit), "1")
    return f"\033[3{color}m (min {x.min().item():04.4f} / mean {x.mean().item():04.4f} / max {x.max().item():04.4f})\033[0m"

class ModelActivationPrinter:
    """Context manager that prints a summary of each listed submodule's output."""

    def __init__(self, module, submodules_to_log):
        # Map module object ids to their dotted names for readable labels.
        self.id_to_name = {id(m): str(name) for name, m in module.named_modules()}
        self.submodules = submodules_to_log
        self.hooks = []

    def __enter__(self):
        def log_activations(m, m_in, m_out):
            label = self.id_to_name.get(id(m), "(unnamed)") + " output"
            if isinstance(m_out, (tuple, list)):
                m_out = m_out[0]
                label += "[0]"
            print(label.ljust(96) + summarize_tensor(m_out))

        for m in self.submodules:
            self.hooks.append(m.register_forward_hook(log_activations))
        return self

    def __exit__(self, *args):
        # Always detach the hooks, even if generation raised.
        for hook in self.hooks:
            hook.remove()
        self.hooks = []

def select_modules(model):
    # Log the query/key projections of every attention layer in `model`.
    modules = []
    for m in model.modules():
        if hasattr(m, "to_q"):
            modules.append(m.to_q)
        if hasattr(m, "to_k"):
            modules.append(m.to_k)
    return modules

with ModelActivationPrinter(pipe.unet, select_modules(pipe.unet)):
    image = pipe("slice of delicious New York-style berry cheesecake", num_inference_steps=15).images[0]

display(image)  # Jupyter/IPython only; in a plain script use image.save("cheesecake.png")
```
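Following up on the attn1 / last-two-up_blocks observation, a small sketch for picking out just those layers by name, e.g. to pass to `ModelActivationPrinter` or to upcast only them. The module-name pattern is an assumption about diffusers' `UNet2DConditionModel` naming (check `pipe.unet.named_modules()` on your version), and `select_attn1_names` is a hypothetical helper.

```python
import re

# Assumed naming: "up_blocks.<i>.attentions.<j>.transformer_blocks.<k>.attn1"
ATTN1_IN_UP_BLOCK = re.compile(r"^up_blocks\.(\d+)\..*\.attn1$")

def select_attn1_names(module_names, block_ids=(2, 3)):
    """Return names of attn1 (self-attention) modules inside the given up_blocks."""
    selected = []
    for name in module_names:
        match = ATTN1_IN_UP_BLOCK.match(name)
        if match and int(match.group(1)) in block_ids:
            selected.append(name)
    return selected

names = [
    "up_blocks.1.attentions.0.transformer_blocks.0.attn1",
    "up_blocks.2.attentions.0.transformer_blocks.0.attn1",
    "up_blocks.2.attentions.0.transformer_blocks.0.attn2",
    "up_blocks.3.attentions.2.transformer_blocks.0.attn1",
    "up_blocks.3.attentions.2.transformer_blocks.0.attn1.to_q",
]
print(select_attn1_names(names))
```

With a loaded pipeline, something like `wanted = set(select_attn1_names(n for n, _ in pipe.unet.named_modules()))` followed by `[m for n, m in pipe.unet.named_modules() if n in wanted]` would give the module objects themselves.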
@madebyollin, that code didn't fix it for me. I'm still getting black outputs with MPS at fp16 on an M2 with sd2-1 (but not with sd2-1-base). Setting upcast_attention=True when loading the model, as Patrick suggested on that GitHub issue, didn't fix it either.
Have you enabled UNet attention slicing?
Ah yes, with pipe.enable_attention_slicing() it works, whether or not I do anything else. That alone fixes the black outputs with fp16. Thank you!
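Attention slicing doesn't change the math; it computes attention for a few (batch × head) slices at a time instead of all at once, so the intermediate score tensor is smaller at any moment. A toy NumPy sketch, assuming slicing along the batch×heads axis (which is my understanding of diffusers' sliced attention processor); `full_attention` and `sliced_attention` are hypothetical names. It only shows that the two give the same result in exact-ish fp32 arithmetic.

```python
import numpy as np

def softmax(x, axis=-1):
    z = x - x.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

def full_attention(q, k, v):
    """q, k, v: (batch*heads, tokens, dim). Builds every score matrix at once."""
    scale = q.shape[-1] ** -0.5
    scores = np.einsum("bid,bjd->bij", q, k) * scale
    return np.einsum("bij,bjd->bid", softmax(scores), v)

def sliced_attention(q, k, v, slice_size):
    """Same math, but only `slice_size` score matrices exist at a time."""
    out = np.empty_like(q)  # assumes v has the same shape as q, as in self-attention
    for start in range(0, q.shape[0], slice_size):
        end = start + slice_size
        out[start:end] = full_attention(q[start:end], k[start:end], v[start:end])
    return out

rng = np.random.default_rng(0)
q, k, v = (rng.standard_normal((10, 16, 8)).astype(np.float32) for _ in range(3))
assert np.allclose(full_attention(q, k, v), sliced_attention(q, k, v, slice_size=2), atol=1e-6)
```

Since the outputs match, a difference between sliced and unsliced runs would point at how the backend handles the larger computation rather than at the model itself.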
You can also enable it for ControlNets :D
Thanks for the help!! Do you have any idea why attention slicing alone fixes it?
Unfortunately I don't have a good answer, but my foggy notion is that the MPS backend has a more limited address space than CUDA, so large calculations might overflow.
There could also be a more trivial reason, like torch not supporting int64 on MPS systems, or something along those lines.
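To put rough numbers on the "large calculations" hunch (this says nothing about whether MPS actually has such a limit), here is the size of the largest self-attention score tensor at 512px versus 768px. Assumptions: the VAE downsamples by 8, classifier-free guidance doubles the batch to 2, and the highest-resolution transformer blocks of the SD 2.x UNet use 5 heads; `attn_score_elements` is a hypothetical helper.

```python
def attn_score_elements(image_px, batch=2, heads=5, vae_factor=8):
    """Elements in the batch x heads x tokens x tokens attention score tensor."""
    tokens = (image_px // vae_factor) ** 2  # latent height * width
    return batch * heads * tokens * tokens  # one tokens x tokens matrix per (batch, head)

for px in (512, 768):
    n = attn_score_elements(px)
    print(f"{px}px: {n:,} elements, {n * 2 / 2**30:.2f} GiB in fp16")

# Slicing down to one (batch, head) pair at a time shrinks the tensor 10x here.
print(f"768px, one slice: {attn_score_elements(768, batch=1, heads=1):,} elements")
```

Under these assumptions the 768px score tensor is (96/64)^4 ≈ 5x larger than at 512px, which at least fits the pattern of 2.1-base being unaffected and slicing helping.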