Spaces:
Running
on
Zero
Running
on
Zero
import pickle as pkl | |
import unittest | |
from dataclasses import dataclass | |
from typing import List, Union | |
import numpy as np | |
import PIL.Image | |
from diffusers.utils.outputs import BaseOutput | |
from diffusers.utils.testing_utils import require_torch | |
@dataclass
class CustomOutput(BaseOutput):
    """Minimal single-field ``BaseOutput`` subclass used as a fixture by the tests.

    It is declared as a dataclass so that the annotated ``images`` attribute
    becomes a real dataclass field accepted by the constructor.
    """

    # Either a list of PIL images or a NumPy array, as the annotation states.
    images: Union[List[PIL.Image.Image], np.ndarray]
class ConfigTester(unittest.TestCase):
    """Tests for reading, rebuilding, pickling and flattening ``CustomOutput``.

    ``CustomOutput`` has a single ``images`` entry; the tests check that it can
    be read as an attribute, by string key and by position, for both ndarray
    and list-of-PIL-image payloads.
    """

    def _assert_array_images(self, outputs, shape=(1, 3, 4, 4)):
        """Assert that every access path yields an ndarray of ``shape``."""
        for images in (outputs.images, outputs["images"], outputs[0]):
            self.assertIsInstance(images, np.ndarray)
            self.assertEqual(images.shape, shape)

    def _assert_pil_list_images(self, outputs):
        """Assert that every access path yields a list whose first item is a PIL image."""
        for images in (outputs.images, outputs["images"], outputs[0]):
            self.assertIsInstance(images, list)
            self.assertIsInstance(images[0], PIL.Image.Image)

    def test_outputs_single_attribute(self):
        """Keyword construction exposes ``images`` through all access paths."""
        outputs = CustomOutput(images=np.random.rand(1, 3, 4, 4))
        self._assert_array_images(outputs)

        # Same checks with a non-tensor attribute.
        outputs = CustomOutput(images=[PIL.Image.new("RGB", (4, 4))])
        self._assert_pil_list_images(outputs)

    def test_outputs_dict_init(self):
        """Construction from a single positional ``dict`` behaves like keyword construction."""
        # Output reinitialization with a `dict`, for compatibility with `accelerate`.
        outputs = CustomOutput({"images": np.random.rand(1, 3, 4, 4)})
        self._assert_array_images(outputs)

        # Same checks with a non-tensor attribute.
        outputs = CustomOutput({"images": [PIL.Image.new("RGB", (4, 4))]})
        self._assert_pil_list_images(outputs)

    def test_outputs_serialization(self):
        """A pickle round trip yields an output equal to the original."""
        outputs_orig = CustomOutput(images=[PIL.Image.new("RGB", (4, 4))])
        serialized = pkl.dumps(outputs_orig)
        # The payload was produced a line above, so unpickling it is safe here.
        outputs_copy = pkl.loads(serialized)

        # Original and copy must agree on names, mapping content and instance state.
        self.assertEqual(dir(outputs_orig), dir(outputs_copy))
        self.assertEqual(dict(outputs_orig), dict(outputs_copy))
        self.assertEqual(vars(outputs_orig), vars(outputs_copy))

    # This test imports torch unconditionally; `require_torch` (imported at
    # module level) is expected to skip it when PyTorch is not installed.
    @require_torch
    def test_torch_pytree(self):
        """``torch.utils._pytree`` flattens and rebuilds a ``CustomOutput`` as a node."""
        # Ensure torch.utils._pytree treats BaseOutput subclasses as nodes (and not leaves).
        # This is important for DistributedDataParallel gradient synchronization
        # with static_graph=True.
        import torch
        import torch.utils._pytree

        data = np.random.rand(1, 3, 4, 4)
        x = CustomOutput(images=data)
        self.assertFalse(torch.utils._pytree._is_leaf(x))

        expected_flat_outs = [data]
        expected_tree_spec = torch.utils._pytree.TreeSpec(
            CustomOutput, ["images"], [torch.utils._pytree.LeafSpec()]
        )

        actual_flat_outs, actual_tree_spec = torch.utils._pytree.tree_flatten(x)
        self.assertEqual(expected_flat_outs, actual_flat_outs)
        self.assertEqual(expected_tree_spec, actual_tree_spec)

        # Round trip: rebuilding from the flat values and spec must give back `x`.
        unflattened_x = torch.utils._pytree.tree_unflatten(actual_flat_outs, actual_tree_spec)
        self.assertEqual(x, unflattened_x)