Spaces:
Running
on
Zero
Running
on
Zero
# coding=utf-8
# Copyright 2024 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest | |
import torch | |
from diffusers import UNet1DModel | |
from diffusers.utils.testing_utils import ( | |
backend_manual_seed, | |
floats_tensor, | |
slow, | |
torch_device, | |
) | |
from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin | |
class UNet1DModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
    """Tests for ``UNet1DModel`` built with down, mid, up and output blocks.

    The shared checks come from ``ModelTesterMixin`` / ``UNetTesterMixin``; this
    class supplies the model configuration and dummy inputs, and adds tests
    that load pretrained checkpoints from the Hub.
    """

    model_class = UNet1DModel
    main_input_name = "sample"

    # NOTE: `dummy_input` is read as an attribute in this class
    # (`inputs_dict = self.dummy_input`, `model(**self.dummy_input)`), so it has
    # to be a property; a plain method would hand a bound method to `**`.
    @property
    def dummy_input(self):
        """Return keyword inputs for the model: a float ``sample`` and a ``timestep`` tensor."""
        batch_size = 4
        num_features = 14
        seq_len = 16

        noise = floats_tensor((batch_size, num_features, seq_len)).to(torch_device)
        time_step = torch.tensor([10] * batch_size).to(torch_device)

        return {"sample": noise, "timestep": time_step}

    # NOTE(review): exposed as properties to match `dummy_input`; the mixins are
    # presumably reading `self.input_shape` / `self.output_shape` as attributes --
    # confirm against test_modeling_common.
    @property
    def input_shape(self):
        """Shape tuple matching the ``sample`` tensor built by ``dummy_input``."""
        return (4, 14, 16)

    @property
    def output_shape(self):
        """Output shape tuple; identical to ``input_shape`` for this configuration."""
        return (4, 14, 16)

    # Overridden so the inherited test does not run; reported as skipped rather
    # than as a silent pass.
    @unittest.skip("Test not supported.")
    def test_ema_training(self):
        pass

    @unittest.skip("Test not supported.")
    def test_training(self):
        pass

    def test_determinism(self):
        super().test_determinism()

    def test_outputs_equivalence(self):
        super().test_outputs_equivalence()

    def test_from_save_pretrained(self):
        super().test_from_save_pretrained()

    def test_from_save_pretrained_variant(self):
        super().test_from_save_pretrained_variant()

    def test_model_from_pretrained(self):
        super().test_model_from_pretrained()

    def test_output(self):
        super().test_output()

    def prepare_init_args_and_inputs_for_common(self):
        """Return ``(init_dict, inputs_dict)``: constructor kwargs and forward kwargs for the model."""
        init_dict = {
            "block_out_channels": (8, 8, 16, 16),
            "in_channels": 14,
            "out_channels": 14,
            "time_embedding_type": "positional",
            "use_timestep_embedding": True,
            "flip_sin_to_cos": False,
            "freq_shift": 1.0,
            "out_block_type": "OutConv1DBlock",
            "mid_block_type": "MidResTemporalBlock1D",
            "down_block_types": ("DownResnetBlock1D", "DownResnetBlock1D", "DownResnetBlock1D", "DownResnetBlock1D"),
            "up_block_types": ("UpResnetBlock1D", "UpResnetBlock1D", "UpResnetBlock1D"),
            "act_fn": "swish",
        }
        inputs_dict = self.dummy_input
        return init_dict, inputs_dict

    def test_from_pretrained_hub(self):
        """Load the ``unet`` subfolder checkpoint, check no keys are missing, and run one forward pass."""
        model, loading_info = UNet1DModel.from_pretrained(
            "bglick13/hopper-medium-v2-value-function-hor32", output_loading_info=True, subfolder="unet"
        )
        self.assertIsNotNone(model)
        self.assertEqual(len(loading_info["missing_keys"]), 0)

        model.to(torch_device)
        image = model(**self.dummy_input)

        self.assertIsNotNone(image, "Make sure output is not None")

    def test_output_pretrained(self):
        """Compare a slice of the pretrained model's output against stored reference values."""
        model = UNet1DModel.from_pretrained("bglick13/hopper-medium-v2-value-function-hor32", subfolder="unet")
        torch.manual_seed(0)
        backend_manual_seed(torch_device, 0)

        num_features = model.config.in_channels
        seq_len = 16
        # The model and inputs are not moved to `torch_device` in this test.
        noise = torch.randn((1, seq_len, num_features)).permute(
            0, 2, 1
        )  # match original, we can update values and remove
        time_step = torch.full((num_features,), 0)

        with torch.no_grad():
            output = model(noise, time_step).sample.permute(0, 2, 1)

        output_slice = output[0, -3:, -3:].flatten()
        # fmt: off
        expected_output_slice = torch.tensor([-2.137172, 1.1426016, 0.3688687, -0.766922, 0.7303146, 0.11038864, -0.4760633, 0.13270172, 0.02591348])
        # fmt: on
        self.assertTrue(torch.allclose(output_slice, expected_output_slice, rtol=1e-3))

    @unittest.skip("Not implemented yet for this UNet")
    def test_forward_with_norm_groups(self):
        pass

    # Downloads a Hub checkpoint and runs a 65536-sample forward pass, so it is
    # marked with the `slow` decorator imported from the testing utilities.
    @slow
    def test_unet_1d_maestro(self):
        """Check the sum and max of absolute outputs of ``harmonai/maestro-150k`` on a sine input."""
        model_id = "harmonai/maestro-150k"
        model = UNet1DModel.from_pretrained(model_id, subfolder="unet")
        model.to(torch_device)

        sample_size = 65536
        noise = torch.sin(torch.arange(sample_size)[None, None, :].repeat(1, 2, 1)).to(torch_device)
        timestep = torch.tensor([1]).to(torch_device)

        with torch.no_grad():
            output = model(noise, timestep).sample

        output_sum = output.abs().sum()
        output_max = output.abs().max()

        self.assertLess((output_sum - 224.0896).abs().item(), 0.5)
        self.assertLess((output_max - 0.0607).abs().item(), 4e-4)
class UNetRLModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
    """Tests for ``UNet1DModel`` configured with the ``ValueFunction`` output block.

    The shared checks come from ``ModelTesterMixin`` / ``UNetTesterMixin``; this
    class supplies the value-function configuration and dummy inputs, overrides
    ``test_output`` for the different output shape, and adds Hub checkpoint tests.
    """

    model_class = UNet1DModel
    main_input_name = "sample"

    # NOTE: `dummy_input` is read as an attribute in this class
    # (`inputs_dict = self.dummy_input`, then `inputs_dict["sample"]` and
    # `model(**inputs_dict)`), so it has to be a property rather than a method.
    @property
    def dummy_input(self):
        """Return keyword inputs for the model: a float ``sample`` and a ``timestep`` tensor."""
        batch_size = 4
        num_features = 14
        seq_len = 16

        noise = floats_tensor((batch_size, num_features, seq_len)).to(torch_device)
        time_step = torch.tensor([10] * batch_size).to(torch_device)

        return {"sample": noise, "timestep": time_step}

    # NOTE(review): exposed as properties to match `dummy_input`; the mixins are
    # presumably reading `self.input_shape` / `self.output_shape` as attributes --
    # confirm against test_modeling_common.
    @property
    def input_shape(self):
        """Shape tuple matching the ``sample`` tensor built by ``dummy_input``."""
        return (4, 14, 16)

    @property
    def output_shape(self):
        """Output shape tuple declared for this configuration."""
        return (4, 14, 1)

    def test_determinism(self):
        super().test_determinism()

    def test_outputs_equivalence(self):
        super().test_outputs_equivalence()

    def test_from_save_pretrained(self):
        super().test_from_save_pretrained()

    def test_from_save_pretrained_variant(self):
        super().test_from_save_pretrained_variant()

    def test_model_from_pretrained(self):
        super().test_model_from_pretrained()

    def test_output(self):
        """Run a forward pass and check the output has shape ``(batch_size, 1)``."""
        # UNetRL is a value function, so its output shape differs from the input shape.
        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
        model = self.model_class(**init_dict)
        model.to(torch_device)
        model.eval()

        with torch.no_grad():
            output = model(**inputs_dict)

            if isinstance(output, dict):
                output = output.sample

        self.assertIsNotNone(output)
        expected_shape = torch.Size((inputs_dict["sample"].shape[0], 1))
        self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match")

    # Overridden so the inherited test does not run; reported as skipped rather
    # than as a silent pass.
    @unittest.skip("Test not supported.")
    def test_ema_training(self):
        pass

    @unittest.skip("Test not supported.")
    def test_training(self):
        pass

    def prepare_init_args_and_inputs_for_common(self):
        """Return ``(init_dict, inputs_dict)``: constructor kwargs and forward kwargs for the model."""
        init_dict = {
            "in_channels": 14,
            "out_channels": 14,
            "down_block_types": ["DownResnetBlock1D", "DownResnetBlock1D", "DownResnetBlock1D", "DownResnetBlock1D"],
            "up_block_types": [],
            "out_block_type": "ValueFunction",
            "mid_block_type": "ValueFunctionMidBlock1D",
            "block_out_channels": [32, 64, 128, 256],
            "layers_per_block": 1,
            "downsample_each_block": True,
            "use_timestep_embedding": True,
            "freq_shift": 1.0,
            "flip_sin_to_cos": False,
            "time_embedding_type": "positional",
            "act_fn": "mish",
        }
        inputs_dict = self.dummy_input
        return init_dict, inputs_dict

    def test_from_pretrained_hub(self):
        """Load the ``value_function`` subfolder checkpoint, check no keys are missing, and run one forward pass."""
        value_function, vf_loading_info = UNet1DModel.from_pretrained(
            "bglick13/hopper-medium-v2-value-function-hor32", output_loading_info=True, subfolder="value_function"
        )
        self.assertIsNotNone(value_function)
        self.assertEqual(len(vf_loading_info["missing_keys"]), 0)

        value_function.to(torch_device)
        image = value_function(**self.dummy_input)

        self.assertIsNotNone(image, "Make sure output is not None")

    def test_output_pretrained(self):
        """Compare the pretrained value function's output against the stored reference value."""
        # The loading info is requested as in the original call but is not used here.
        value_function, _ = UNet1DModel.from_pretrained(
            "bglick13/hopper-medium-v2-value-function-hor32", output_loading_info=True, subfolder="value_function"
        )
        torch.manual_seed(0)
        backend_manual_seed(torch_device, 0)

        num_features = value_function.config.in_channels
        seq_len = 14
        # The model and inputs are not moved to `torch_device` in this test.
        noise = torch.randn((1, seq_len, num_features)).permute(
            0, 2, 1
        )  # match original, we can update values and remove
        time_step = torch.full((num_features,), 0)

        with torch.no_grad():
            output = value_function(noise, time_step).sample

        # fmt: off
        expected_output_slice = torch.tensor([165.25] * seq_len)
        # fmt: on
        self.assertTrue(torch.allclose(output, expected_output_slice, rtol=1e-3))

    @unittest.skip("Not implemented yet for this UNet")
    def test_forward_with_norm_groups(self):
        pass