# Spaces: Running on Zero
import tempfile | |
import unittest | |
import numpy as np | |
import torch | |
from diffusers import DiffusionPipeline | |
from diffusers.models.attention_processor import Attention, AttnAddedKVProcessor | |
class AttnAddedKVProcessorTests(unittest.TestCase):
    """Tests for ``Attention`` modules constructed with an ``AttnAddedKVProcessor``."""

    def get_constructor_arguments(self, only_cross_attention: bool = False):
        """Build the keyword arguments used to construct the ``Attention`` module.

        Args:
            only_cross_attention: Forwarded unchanged to the constructor; also
                selects which cross attention dim is used.

        Returns:
            A dict suitable for ``Attention(**kwargs)``.
        """
        query_dim = 10

        if only_cross_attention:
            cross_attention_dim = 12
        else:
            # when only cross attention is not set, the cross attention dim must be the same as the query dim
            cross_attention_dim = query_dim

        return {
            "query_dim": query_dim,
            "cross_attention_dim": cross_attention_dim,
            "heads": 2,
            "dim_head": 4,
            "added_kv_proj_dim": 6,
            "norm_num_groups": 1,
            "only_cross_attention": only_cross_attention,
            "processor": AttnAddedKVProcessor(),
        }

    def get_forward_arguments(self, query_dim, added_kv_proj_dim):
        """Build random inputs for a forward pass of the attention module.

        Args:
            query_dim: Size of dimension 1 of the generated ``hidden_states``.
            added_kv_proj_dim: Size of the last dimension of the generated
                ``encoder_hidden_states``.

        Returns:
            A dict with ``hidden_states`` of shape ``(2, query_dim, 3, 2)``,
            ``encoder_hidden_states`` of shape ``(2, 4, added_kv_proj_dim)`` and
            ``attention_mask`` set to ``None``.
        """
        batch_size = 2

        hidden_states = torch.rand(batch_size, query_dim, 3, 2)
        encoder_hidden_states = torch.rand(batch_size, 4, added_kv_proj_dim)
        attention_mask = None

        return {
            "hidden_states": hidden_states,
            "encoder_hidden_states": encoder_hidden_states,
            "attention_mask": attention_mask,
        }

    def test_only_cross_attention(self):
        """Toggling ``only_cross_attention`` drops ``to_k``/``to_v`` and changes the output."""
        # self and cross attention
        # Re-seeding before each construction keeps module init and the random
        # inputs drawn from the same RNG state in both halves of the test.
        torch.manual_seed(0)

        constructor_args = self.get_constructor_arguments(only_cross_attention=False)
        attn = Attention(**constructor_args)

        self.assertIsNotNone(attn.to_k)
        self.assertIsNotNone(attn.to_v)

        forward_args = self.get_forward_arguments(
            query_dim=constructor_args["query_dim"], added_kv_proj_dim=constructor_args["added_kv_proj_dim"]
        )

        self_and_cross_attn_out = attn(**forward_args)

        # only cross attention
        torch.manual_seed(0)

        constructor_args = self.get_constructor_arguments(only_cross_attention=True)
        attn = Attention(**constructor_args)

        self.assertIsNone(attn.to_k)
        self.assertIsNone(attn.to_v)

        forward_args = self.get_forward_arguments(
            query_dim=constructor_args["query_dim"], added_kv_proj_dim=constructor_args["added_kv_proj_dim"]
        )

        only_cross_attn_out = attn(**forward_args)

        # Every element of the two outputs is required to differ.
        self.assertTrue((only_cross_attn_out != self_and_cross_attn_out).all())
class DeprecatedAttentionBlockTests(unittest.TestCase):
    """Checks that a pipeline gives matching outputs when loaded with ``device_map``."""

    def _generate_images(self, pipe):
        """Run ``pipe`` on a fixed prompt with a freshly seeded CPU generator.

        Args:
            pipe: A callable pipeline whose result exposes an ``images`` attribute.

        Returns:
            The ``images`` attribute of the pipeline output (requested as ``"np"``).
        """
        return pipe(
            "foo",
            num_inference_steps=2,
            generator=torch.Generator("cpu").manual_seed(0),
            output_type="np",
        ).images

    def test_conversion_when_using_device_map(self):
        """Outputs must agree across a plain load, a ``device_map`` load, and a save/reload."""
        model_id = "hf-internal-testing/tiny-stable-diffusion-torch"

        # Reference output: pipeline loaded without a device map.
        pipe = DiffusionPipeline.from_pretrained(model_id, safety_checker=None)
        pre_conversion = self._generate_images(pipe)

        # the initial conversion succeeds
        pipe = DiffusionPipeline.from_pretrained(model_id, device_map="balanced", safety_checker=None)
        conversion = self._generate_images(pipe)

        with tempfile.TemporaryDirectory() as tmpdir:
            # save the converted model
            pipe.save_pretrained(tmpdir)

            # can also load the converted weights
            pipe = DiffusionPipeline.from_pretrained(tmpdir, device_map="balanced", safety_checker=None)

            # Inference runs while the temporary directory still exists.
            after_conversion = self._generate_images(pipe)

        self.assertTrue(
            np.allclose(pre_conversion, conversion, atol=1e-3),
            "output changed after loading with device_map",
        )
        self.assertTrue(
            np.allclose(conversion, after_conversion, atol=1e-3),
            "output changed after saving and reloading the converted pipeline",
        )