import torch
import torchvision
import torchaudio
import torchvision.transforms as transforms
from diffusers import UNet2DConditionModel, ControlNetModel
from foleycrafter.pipelines.pipeline_controlnet import StableDiffusionControlNetPipeline
from foleycrafter.pipelines.auffusion_pipeline import AuffusionNoAdapterPipeline, Generator
from foleycrafter.models.auffusion_unet import UNet2DConditionModel as af_UNet2DConditionModel
from diffusers.models import AutoencoderKLTemporalDecoder, AutoencoderKL
from diffusers.schedulers import EulerDiscreteScheduler, DDIMScheduler, PNDMScheduler, KarrasDiffusionSchedulers
from diffusers.utils.import_utils import is_xformers_available
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection,\
    SpeechT5HifiGan, ClapTextModelWithProjection, RobertaTokenizer, RobertaTokenizerFast,\
    CLIPTextModel, CLIPTokenizer
import glob
from moviepy.editor import ImageSequenceClip, AudioFileClip, VideoFileClip, VideoClip
from moviepy.audio.AudioClip import AudioArrayClip
import numpy as np
from safetensors import safe_open
import random
from typing import Union, Optional
import decord
import os
import os.path as osp
import imageio
import soundfile as sf
from PIL import Image, ImageOps
import torch.distributed as dist
import io
from omegaconf import OmegaConf
import json
from dataclasses import dataclass
from enum import Enum
import typing as T
import warnings
import pydub
from scipy.io import wavfile
from einops import rearrange
def zero_rank_print(s: str):
    if (not dist.is_initialized()) or dist.get_rank() == 0:
        print("### " + s, flush=True)
def build_foleycrafter( | |
pretrained_model_name_or_path: str="auffusion/auffusion-full-no-adapter", | |
) -> StableDiffusionControlNetPipeline: | |
vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder='vae') | |
unet = af_UNet2DConditionModel.from_pretrained(pretrained_model_name_or_path, subfolder='unet') | |
scheduler = PNDMScheduler.from_pretrained(pretrained_model_name_or_path, subfolder='scheduler') | |
tokenizer = CLIPTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder='tokenizer') | |
text_encoder = CLIPTextModel.from_pretrained(pretrained_model_name_or_path, subfolder='text_encoder') | |
controlnet = ControlNetModel.from_unet(unet, conditioning_channels=1) | |
pipe = StableDiffusionControlNetPipeline( | |
vae=vae, | |
controlnet=controlnet, | |
unet=unet, | |
scheduler=scheduler, | |
tokenizer=tokenizer, | |
text_encoder=text_encoder, | |
feature_extractor=None, | |
safety_checker=None, | |
requires_safety_checker=False, | |
) | |
return pipe | |
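# Hedged usage sketch (not part of the original module): build the FoleyCrafter ControlNet
# pipeline from the default "auffusion/auffusion-full-no-adapter" weights and move it to the
# GPU when one is available. This only illustrates the call shape of build_foleycrafter.
def _example_build_foleycrafter():
    pipe = build_foleycrafter()
    return pipe.to("cuda" if torch.cuda.is_available() else "cpu")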
def save_videos_grid(videos: torch.Tensor, path: str, rescale=False, n_rows=6, fps=8): | |
if len(videos.shape) == 4: | |
videos = videos.unsqueeze(0) | |
videos = rearrange(videos, "b c t h w -> t b c h w") | |
outputs = [] | |
for x in videos: | |
x = torchvision.utils.make_grid(x, nrow=n_rows) | |
x = x.transpose(0, 1).transpose(1, 2).squeeze(-1) | |
if rescale: | |
x = (x + 1.0) / 2.0 # -1,1 -> 0,1 | |
x = torch.clamp((x * 255), 0, 255).numpy().astype(np.uint8) | |
outputs.append(x) | |
os.makedirs(os.path.dirname(path), exist_ok=True) | |
imageio.mimsave(path, outputs, fps=fps) | |
def save_videos_from_pil_list(videos: list, path: str, fps=7): | |
for i in range(len(videos)): | |
videos[i] = ImageOps.scale(videos[i], 255) | |
imageio.mimwrite(path, videos, fps=fps) | |
def seed_everything(seed: int) -> None: | |
r"""Sets the seed for generating random numbers in :pytorch:`PyTorch`, | |
:obj:`numpy` and :python:`Python`. | |
Args: | |
seed (int): The desired seed. | |
""" | |
random.seed(seed) | |
np.random.seed(seed) | |
torch.manual_seed(seed) | |
torch.cuda.manual_seed_all(seed) | |
def get_video_frames(video: np.ndarray, num_frames: int=200): | |
video_length = video.shape[0] | |
video_idx = np.linspace(0, video_length-1, num_frames, dtype=int) | |
video = video[video_idx, ...] | |
return video | |
def random_audio_video_clip(audio: np.ndarray, video: np.ndarray, fps:float, \ | |
sample_rate:int=16000, duration:int=5, num_frames: int=20): | |
""" | |
Random sample video clips with duration | |
""" | |
video_length = video.shape[0] | |
audio_length = audio.shape[-1] | |
av_duration = int(video_length / fps) | |
assert av_duration >= duration,\ | |
f"video duration {av_duration} is less than {duration}" | |
# random sample start time | |
start_time = random.uniform(0, av_duration - duration) | |
end_time = start_time + duration | |
start_idx, end_idx = start_time / av_duration, end_time / av_duration | |
video_start_frame, video_end_frame\ | |
= video_length * start_idx, video_length * end_idx | |
audio_start_frame, audio_end_frame\ | |
= audio_length * start_idx, audio_length * end_idx | |
# print(f"time_idx : {start_time}:{end_time}") | |
# print(f"video_idx: {video_start_frame}:{video_end_frame}") | |
# print(f"audio_idx: {audio_start_frame}:{audio_end_frame}") | |
audio_idx = np.linspace(audio_start_frame, audio_end_frame, sample_rate * duration, dtype=int) | |
video_idx = np.linspace(video_start_frame, video_end_frame, num_frames, dtype=int) | |
audio = audio[..., audio_idx] | |
video = video[video_idx, ...] | |
return audio, video | |
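# Hedged sketch: exercise random_audio_video_clip with synthetic arrays to show the expected
# layouts (audio: (..., samples), video: (frames, H, W, C)); every value below is made up.
def _example_random_audio_video_clip():
    fps, sample_rate, total_s = 25.0, 16000, 10
    video = np.zeros((int(fps * total_s), 64, 64, 3), dtype=np.uint8)
    audio = np.zeros((1, sample_rate * total_s), dtype=np.float32)
    clip_audio, clip_video = random_audio_video_clip(
        audio, video, fps, sample_rate=sample_rate, duration=5, num_frames=20)
    # clip_audio: (1, 80000), clip_video: (20, 64, 64, 3)
    return clip_audio, clip_video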
def get_full_indices(reader: Union[decord.VideoReader, decord.AudioReader])\ | |
-> np.ndarray: | |
if isinstance(reader, decord.VideoReader): | |
return np.linspace(0, len(reader) - 1, len(reader), dtype=int) | |
elif isinstance(reader, decord.AudioReader): | |
return np.linspace(0, reader.shape[-1] - 1, reader.shape[-1], dtype=int) | |
def get_frames(video_path:str, onset_list, frame_nums=1024): | |
video = decord.VideoReader(video_path) | |
video_frame = len(video) | |
frames_list = [] | |
for start, end in onset_list: | |
video_start = int(start / frame_nums * video_frame) | |
video_end = int(end / frame_nums * video_frame) | |
frames_list.extend(range(video_start, video_end)) | |
frames = video.get_batch(frames_list).asnumpy() | |
return frames | |
def get_frames_in_video(video_path: str, onset_list, frame_nums=1024, audio_length_in_s=10):
    # unlike get_frames, this function takes the actual video duration into account
video = decord.VideoReader(video_path) | |
video_frame = len(video) | |
duration = video_frame / video.get_avg_fps() | |
frames_list = [] | |
video_onset_list = [] | |
for start, end in onset_list: | |
if int(start / frame_nums * duration) >= audio_length_in_s: | |
continue | |
video_start = int(start / audio_length_in_s * duration / frame_nums * video_frame) | |
if video_start >= video_frame: | |
continue | |
video_end = int(end / audio_length_in_s * duration / frame_nums * video_frame) | |
video_onset_list.append([int(start / audio_length_in_s * duration), int(end / audio_length_in_s * duration)]) | |
frames_list.extend(range(video_start, video_end)) | |
frames = video.get_batch(frames_list).asnumpy() | |
return frames, video_onset_list | |
def save_multimodal(video, audio, output_path, audio_fps:int=16000, video_fps:int=8, remove_audio:bool=True): | |
imgs = [img for img in video] | |
# if audio.shape[0] == 1 or audio.shape[0] == 2: | |
# audio = audio.T #[len, channel] | |
# audio = np.repeat(audio, 2, axis=1) | |
output_dir = osp.dirname(output_path) | |
    try:
        wavfile.write(osp.join(output_dir, "audio.wav"), audio_fps, audio)
    except Exception:
        # scipy can fail on some dtypes/layouts; fall back to soundfile
        sf.write(osp.join(output_dir, "audio.wav"), audio, audio_fps)
audio_clip = AudioFileClip(osp.join(output_dir, "audio.wav")) | |
# audio_clip = AudioArrayClip(audio, fps=audio_fps) | |
video_clip = ImageSequenceClip(imgs, fps=video_fps) | |
video_clip = video_clip.set_audio(audio_clip) | |
video_clip.write_videofile(output_path, video_fps, audio=True, audio_fps=audio_fps) | |
if remove_audio: | |
os.remove(osp.join(output_dir, "audio.wav")) | |
return | |
def save_multimodal_by_frame(video, audio, output_path, audio_fps:int=16000): | |
imgs = [img for img in video] | |
# if audio.shape[0] == 1 or audio.shape[0] == 2: | |
# audio = audio.T #[len, channel] | |
# audio = np.repeat(audio, 2, axis=1) | |
# output_dir = osp.dirname(output_path) | |
output_dir = output_path | |
wavfile.write(osp.join(output_dir, "audio.wav"), audio_fps, audio) | |
audio_clip = AudioFileClip(osp.join(output_dir, "audio.wav")) | |
# audio_clip = AudioArrayClip(audio, fps=audio_fps) | |
os.makedirs(osp.join(output_dir, 'frames'), exist_ok=True) | |
for num, img in enumerate(imgs): | |
if isinstance(img, np.ndarray): | |
img = Image.fromarray(img.astype(np.uint8)) | |
img.save(osp.join(output_dir, 'frames', f"{num}.jpg")) | |
return | |
def sanity_check(data: dict, save_path: str="sanity_check", batch_size: int=4, sample_rate: int=16000): | |
video_path = osp.join(save_path, 'video') | |
audio_path = osp.join(save_path, 'audio') | |
av_path = osp.join(save_path, 'av') | |
video, audio, text = data['pixel_values'], data['audio'], data['text'] | |
video = (video / 2 + 0.5).clamp(0, 1) | |
zero_rank_print(f"Saving {text} audio: {audio[0].shape} video: {video[0].shape}") | |
for bsz in range(batch_size): | |
os.makedirs(video_path, exist_ok=True) | |
os.makedirs(audio_path, exist_ok=True) | |
os.makedirs(av_path, exist_ok=True) | |
# save_videos_grid(video[bsz:bsz+1,...], f"{osp.join(video_path, str(bsz) + '.mp4')}") | |
bsz_audio = audio[bsz,...].permute(1, 0).cpu().numpy() | |
bsz_video = video_tensor_to_np(video[bsz, ...]) | |
sf.write(f"{osp.join(audio_path, str(bsz) + '.wav')}", bsz_audio, sample_rate) | |
save_multimodal(bsz_video, bsz_audio, osp.join(av_path, str(bsz) + '.mp4')) | |
def video_tensor_to_np(video: torch.Tensor, rescale: bool=True, scale: bool=False): | |
if scale: | |
video = (video / 2 + 0.5).clamp(0, 1) | |
# c f h w -> f h w c | |
if video.shape[0] == 3: | |
video = video.permute(1, 2, 3, 0).detach().cpu().numpy() | |
elif video.shape[1] == 3: | |
video = video.permute(0, 2, 3, 1).detach().cpu().numpy() | |
if rescale: | |
video = video * 255 | |
return video | |
def composite_audio_video(video: str, audio: str, path:str, video_fps:int=7, audio_sample_rate:int=16000): | |
video = decord.VideoReader(video) | |
audio = decord.AudioReader(audio, sample_rate=audio_sample_rate) | |
audio = audio.get_batch(get_full_indices(audio)).asnumpy() | |
video = video.get_batch(get_full_indices(video)).asnumpy() | |
save_multimodal(video, audio, path, audio_fps=audio_sample_rate, video_fps=video_fps) | |
return | |
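# Hedged usage sketch for composite_audio_video: mux an existing video file with a generated
# wav into a single clip. All three paths are placeholders.
def _example_composite_audio_video(video_path: str = "input.mp4",
                                   audio_path: str = "generated.wav",
                                   out_path: str = "output.mp4"):
    composite_audio_video(video_path, audio_path, out_path, video_fps=7, audio_sample_rate=16000)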
# for video pipeline | |
def append_dims(x, target_dims): | |
"""Appends dimensions to the end of a tensor until it has target_dims dimensions.""" | |
dims_to_append = target_dims - x.ndim | |
if dims_to_append < 0: | |
raise ValueError(f"input has {x.ndim} dims but target_dims is {target_dims}, which is less") | |
return x[(...,) + (None,) * dims_to_append] | |
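# Hedged sketch: append_dims turns a (B,) tensor into (B, 1, 1, 1) so it broadcasts against a
# (B, C, H, W) tensor; purely illustrative.
def _example_append_dims():
    sigmas = torch.tensor([0.1, 0.2])        # shape (2,)
    x = torch.randn(2, 3, 8, 8)              # shape (2, 3, 8, 8)
    return x * append_dims(sigmas, x.ndim)   # sigmas broadcast as (2, 1, 1, 1)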
def resize_with_antialiasing(input, size, interpolation="bicubic", align_corners=True): | |
h, w = input.shape[-2:] | |
factors = (h / size[0], w / size[1]) | |
# First, we have to determine sigma | |
# Taken from skimage: https://github.com/scikit-image/scikit-image/blob/v0.19.2/skimage/transform/_warps.py#L171 | |
sigmas = ( | |
max((factors[0] - 1.0) / 2.0, 0.001), | |
max((factors[1] - 1.0) / 2.0, 0.001), | |
) | |
# Now kernel size. Good results are for 3 sigma, but that is kind of slow. Pillow uses 1 sigma | |
# https://github.com/python-pillow/Pillow/blob/master/src/libImaging/Resample.c#L206 | |
# But they do it in the 2 passes, which gives better results. Let's try 2 sigmas for now | |
ks = int(max(2.0 * 2 * sigmas[0], 3)), int(max(2.0 * 2 * sigmas[1], 3)) | |
# Make sure it is odd | |
if (ks[0] % 2) == 0: | |
ks = ks[0] + 1, ks[1] | |
if (ks[1] % 2) == 0: | |
ks = ks[0], ks[1] + 1 | |
input = _gaussian_blur2d(input, ks, sigmas) | |
output = torch.nn.functional.interpolate(input, size=size, mode=interpolation, align_corners=align_corners) | |
return output | |
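# Hedged sketch: anti-aliased downscale of a batch of frames from 512x512 to 256x256.
def _example_resize_with_antialiasing():
    frames = torch.rand(4, 3, 512, 512)
    return resize_with_antialiasing(frames, (256, 256)).shape  # torch.Size([4, 3, 256, 256])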
def _gaussian_blur2d(input, kernel_size, sigma): | |
if isinstance(sigma, tuple): | |
sigma = torch.tensor([sigma], dtype=input.dtype) | |
else: | |
sigma = sigma.to(dtype=input.dtype) | |
ky, kx = int(kernel_size[0]), int(kernel_size[1]) | |
bs = sigma.shape[0] | |
kernel_x = _gaussian(kx, sigma[:, 1].view(bs, 1)) | |
kernel_y = _gaussian(ky, sigma[:, 0].view(bs, 1)) | |
out_x = _filter2d(input, kernel_x[..., None, :]) | |
out = _filter2d(out_x, kernel_y[..., None]) | |
return out | |
def _filter2d(input, kernel): | |
# prepare kernel | |
b, c, h, w = input.shape | |
tmp_kernel = kernel[:, None, ...].to(device=input.device, dtype=input.dtype) | |
tmp_kernel = tmp_kernel.expand(-1, c, -1, -1) | |
height, width = tmp_kernel.shape[-2:] | |
padding_shape: list[int] = _compute_padding([height, width]) | |
input = torch.nn.functional.pad(input, padding_shape, mode="reflect") | |
# kernel and input tensor reshape to align element-wise or batch-wise params | |
tmp_kernel = tmp_kernel.reshape(-1, 1, height, width) | |
input = input.view(-1, tmp_kernel.size(0), input.size(-2), input.size(-1)) | |
# convolve the tensor with the kernel. | |
output = torch.nn.functional.conv2d(input, tmp_kernel, groups=tmp_kernel.size(0), padding=0, stride=1) | |
out = output.view(b, c, h, w) | |
return out | |
def _gaussian(window_size: int, sigma): | |
if isinstance(sigma, float): | |
sigma = torch.tensor([[sigma]]) | |
batch_size = sigma.shape[0] | |
x = (torch.arange(window_size, device=sigma.device, dtype=sigma.dtype) - window_size // 2).expand(batch_size, -1) | |
if window_size % 2 == 0: | |
x = x + 0.5 | |
gauss = torch.exp(-x.pow(2.0) / (2 * sigma.pow(2.0))) | |
return gauss / gauss.sum(-1, keepdim=True) | |
def _compute_padding(kernel_size):
    """Compute padding tuple."""
    # 4 or 6 ints: (padding_left, padding_right, padding_top, padding_bottom)
# https://pytorch.org/docs/stable/nn.html#torch.nn.functional.pad | |
if len(kernel_size) < 2: | |
raise AssertionError(kernel_size) | |
computed = [k - 1 for k in kernel_size] | |
# for even kernels we need to do asymmetric padding :( | |
out_padding = 2 * len(kernel_size) * [0] | |
for i in range(len(kernel_size)): | |
computed_tmp = computed[-(i + 1)] | |
pad_front = computed_tmp // 2 | |
pad_rear = computed_tmp - pad_front | |
out_padding[2 * i + 0] = pad_front | |
out_padding[2 * i + 1] = pad_rear | |
return out_padding | |
def print_gpu_memory_usage(info: str, cuda_id: int = 0):
    print(f">>> {info} <<<")
    reserved = torch.cuda.memory_reserved(cuda_id) / 1024 ** 3
    used = torch.cuda.memory_allocated(cuda_id) / 1024 ** 3
    print("reserved: ", reserved, "G")
    print("allocated: ", used, "G")
    print("free (within reserved): ", reserved - used, "G")
# use for dsp mel2spec
@dataclass
class SpectrogramParams:
""" | |
Parameters for the conversion from audio to spectrograms to images and back. | |
Includes helpers to convert to and from EXIF tags, allowing these parameters to be stored | |
within spectrogram images. | |
To understand what these parameters do and to customize them, read `spectrogram_converter.py` | |
and the linked torchaudio documentation. | |
""" | |
# Whether the audio is stereo or mono | |
stereo: bool = False | |
# FFT parameters | |
sample_rate: int = 44100 | |
step_size_ms: int = 10 | |
window_duration_ms: int = 100 | |
padded_duration_ms: int = 400 | |
# Mel scale parameters | |
num_frequencies: int = 200 | |
# TODO(hayk): Set these to [20, 20000] for newer models | |
min_frequency: int = 0 | |
max_frequency: int = 10000 | |
mel_scale_norm: T.Optional[str] = None | |
mel_scale_type: str = "htk" | |
max_mel_iters: int = 200 | |
# Griffin Lim parameters | |
num_griffin_lim_iters: int = 32 | |
# Image parameterization | |
power_for_image: float = 0.25 | |
class ExifTags(Enum): | |
""" | |
Custom EXIF tags for the spectrogram image. | |
""" | |
SAMPLE_RATE = 11000 | |
STEREO = 11005 | |
STEP_SIZE_MS = 11010 | |
WINDOW_DURATION_MS = 11020 | |
PADDED_DURATION_MS = 11030 | |
NUM_FREQUENCIES = 11040 | |
MIN_FREQUENCY = 11050 | |
MAX_FREQUENCY = 11060 | |
POWER_FOR_IMAGE = 11070 | |
MAX_VALUE = 11080 | |
    @property
    def n_fft(self) -> int:
""" | |
The number of samples in each STFT window, with padding. | |
""" | |
return int(self.padded_duration_ms / 1000.0 * self.sample_rate) | |
    @property
    def win_length(self) -> int:
""" | |
The number of samples in each STFT window. | |
""" | |
return int(self.window_duration_ms / 1000.0 * self.sample_rate) | |
    @property
    def hop_length(self) -> int:
""" | |
The number of samples between each STFT window. | |
""" | |
return int(self.step_size_ms / 1000.0 * self.sample_rate) | |
def to_exif(self) -> T.Dict[int, T.Any]: | |
""" | |
Return a dictionary of EXIF tags for the current values. | |
""" | |
return { | |
self.ExifTags.SAMPLE_RATE.value: self.sample_rate, | |
self.ExifTags.STEREO.value: self.stereo, | |
self.ExifTags.STEP_SIZE_MS.value: self.step_size_ms, | |
self.ExifTags.WINDOW_DURATION_MS.value: self.window_duration_ms, | |
self.ExifTags.PADDED_DURATION_MS.value: self.padded_duration_ms, | |
self.ExifTags.NUM_FREQUENCIES.value: self.num_frequencies, | |
self.ExifTags.MIN_FREQUENCY.value: self.min_frequency, | |
self.ExifTags.MAX_FREQUENCY.value: self.max_frequency, | |
self.ExifTags.POWER_FOR_IMAGE.value: float(self.power_for_image), | |
} | |
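# Hedged sketch of the SpectrogramParams arithmetic: with the defaults above (sample_rate=44100,
# 100 ms window, 10 ms hop, 400 ms padded window) the derived STFT sizes are n_fft=17640,
# win_length=4410, hop_length=441. Shown only as an illustration.
def _example_spectrogram_params():
    p = SpectrogramParams()
    return p.n_fft, p.win_length, p.hop_length  # (17640, 4410, 441)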
class SpectrogramImageConverter: | |
""" | |
Convert between spectrogram images and audio segments. | |
This is a wrapper around SpectrogramConverter that additionally converts from spectrograms | |
to images and back. The real audio processing lives in SpectrogramConverter. | |
""" | |
def __init__(self, params: SpectrogramParams, device: str = "cuda"): | |
self.p = params | |
self.device = device | |
self.converter = SpectrogramConverter(params=params, device=device) | |
def spectrogram_image_from_audio( | |
self, | |
segment: pydub.AudioSegment, | |
) -> Image.Image: | |
""" | |
Compute a spectrogram image from an audio segment. | |
Args: | |
segment: Audio segment to convert | |
Returns: | |
Spectrogram image (in pillow format) | |
""" | |
assert int(segment.frame_rate) == self.p.sample_rate, "Sample rate mismatch" | |
if self.p.stereo: | |
if segment.channels == 1: | |
print("WARNING: Mono audio but stereo=True, cloning channel") | |
segment = segment.set_channels(2) | |
elif segment.channels > 2: | |
print("WARNING: Multi channel audio, reducing to stereo") | |
segment = segment.set_channels(2) | |
else: | |
if segment.channels > 1: | |
print("WARNING: Stereo audio but stereo=False, setting to mono") | |
segment = segment.set_channels(1) | |
spectrogram = self.converter.spectrogram_from_audio(segment) | |
image = image_from_spectrogram( | |
spectrogram, | |
power=self.p.power_for_image, | |
) | |
# Store conversion params in exif metadata of the image | |
exif_data = self.p.to_exif() | |
exif_data[SpectrogramParams.ExifTags.MAX_VALUE.value] = float(np.max(spectrogram)) | |
exif = image.getexif() | |
exif.update(exif_data.items()) | |
return image | |
def audio_from_spectrogram_image( | |
self, | |
image: Image.Image, | |
apply_filters: bool = True, | |
max_value: float = 30e6, | |
) -> pydub.AudioSegment: | |
""" | |
Reconstruct an audio segment from a spectrogram image. | |
Args: | |
image: Spectrogram image (in pillow format) | |
apply_filters: Apply post-processing to improve the reconstructed audio | |
max_value: Scaled max amplitude of the spectrogram. Shouldn't matter. | |
""" | |
spectrogram = spectrogram_from_image( | |
image, | |
max_value=max_value, | |
power=self.p.power_for_image, | |
stereo=self.p.stereo, | |
) | |
segment = self.converter.audio_from_spectrogram( | |
spectrogram, | |
apply_filters=apply_filters, | |
) | |
return segment | |
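# Hedged round-trip sketch: render a wav file to a spectrogram image and reconstruct audio from
# it with SpectrogramImageConverter. "example.wav" is a placeholder path, not a file in this repo.
def _example_spectrogram_image_roundtrip(wav_path: str = "example.wav") -> pydub.AudioSegment:
    params = SpectrogramParams()
    device = "cuda" if torch.cuda.is_available() else "cpu"
    converter = SpectrogramImageConverter(params=params, device=device)
    segment = pydub.AudioSegment.from_file(wav_path).set_frame_rate(params.sample_rate).set_channels(1)
    image = converter.spectrogram_image_from_audio(segment)
    return converter.audio_from_spectrogram_image(image)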
def image_from_spectrogram(spectrogram: np.ndarray, power: float = 0.25) -> Image.Image: | |
""" | |
Compute a spectrogram image from a spectrogram magnitude array. | |
This is the inverse of spectrogram_from_image, except for discretization error from | |
quantizing to uint8. | |
Args: | |
spectrogram: (channels, frequency, time) | |
power: A power curve to apply to the spectrogram to preserve contrast | |
Returns: | |
image: (frequency, time, channels) | |
""" | |
# Rescale to 0-1 | |
max_value = np.max(spectrogram) | |
data = spectrogram / max_value | |
# Apply the power curve | |
data = np.power(data, power) | |
# Rescale to 0-255 | |
data = data * 255 | |
# Invert | |
data = 255 - data | |
# Convert to uint8 | |
data = data.astype(np.uint8) | |
# Munge channels into a PIL image | |
if data.shape[0] == 1: | |
# TODO(hayk): Do we want to write single channel to disk instead? | |
image = Image.fromarray(data[0], mode="L").convert("RGB") | |
elif data.shape[0] == 2: | |
data = np.array([np.zeros_like(data[0]), data[0], data[1]]).transpose(1, 2, 0) | |
image = Image.fromarray(data, mode="RGB") | |
else: | |
raise NotImplementedError(f"Unsupported number of channels: {data.shape[0]}") | |
# Flip Y | |
image = image.transpose(Image.Transpose.FLIP_TOP_BOTTOM) | |
return image | |
def spectrogram_from_image( | |
image: Image.Image, | |
power: float = 0.25, | |
stereo: bool = False, | |
max_value: float = 30e6, | |
) -> np.ndarray: | |
""" | |
Compute a spectrogram magnitude array from a spectrogram image. | |
This is the inverse of image_from_spectrogram, except for discretization error from | |
quantizing to uint8. | |
Args: | |
image: (frequency, time, channels) | |
power: The power curve applied to the spectrogram | |
stereo: Whether the spectrogram encodes stereo data | |
max_value: The max value of the original spectrogram. In practice doesn't matter. | |
Returns: | |
spectrogram: (channels, frequency, time) | |
""" | |
# Convert to RGB if single channel | |
if image.mode in ("P", "L"): | |
image = image.convert("RGB") | |
# Flip Y | |
image = image.transpose(Image.Transpose.FLIP_TOP_BOTTOM) | |
# Munge channels into a numpy array of (channels, frequency, time) | |
data = np.array(image).transpose(2, 0, 1) | |
if stereo: | |
# Take the G and B channels as done in image_from_spectrogram | |
data = data[[1, 2], :, :] | |
else: | |
data = data[0:1, :, :] | |
# Convert to floats | |
data = data.astype(np.float32) | |
# Invert | |
data = 255 - data | |
# Rescale to 0-1 | |
data = data / 255 | |
# Reverse the power curve | |
data = np.power(data, 1 / power) | |
# Rescale to max value | |
data = data * max_value | |
return data | |
class SpectrogramConverter: | |
""" | |
Convert between audio segments and spectrogram tensors using torchaudio. | |
In this class a "spectrogram" is defined as a (batch, time, frequency) tensor with float values | |
that represent the amplitude of the frequency at that time bucket (in the frequency domain). | |
    Frequencies are given in the perceptual Mel scale defined by the params. A more specific term
used in some functions is "mel amplitudes". | |
The spectrogram computed from `spectrogram_from_audio` is complex valued, but it only | |
returns the amplitude, because the phase is chaotic and hard to learn. The function | |
`audio_from_spectrogram` is an approximate inverse of `spectrogram_from_audio`, which | |
approximates the phase information using the Griffin-Lim algorithm. | |
Each channel in the audio is treated independently, and the spectrogram has a batch dimension | |
equal to the number of channels in the input audio segment. | |
    Both the Griffin-Lim algorithm and the Mel scaling process are lossy.
For more information, see https://pytorch.org/audio/stable/transforms.html | |
""" | |
def __init__(self, params: SpectrogramParams, device: str = "cuda"): | |
self.p = params | |
self.device = check_device(device) | |
if device.lower().startswith("mps"): | |
warnings.warn( | |
"WARNING: MPS does not support audio operations, falling back to CPU for them", | |
stacklevel=2, | |
) | |
self.device = "cpu" | |
# https://pytorch.org/audio/stable/generated/torchaudio.transforms.Spectrogram.html | |
self.spectrogram_func = torchaudio.transforms.Spectrogram( | |
n_fft=params.n_fft, | |
hop_length=params.hop_length, | |
win_length=params.win_length, | |
pad=0, | |
window_fn=torch.hann_window, | |
power=None, | |
normalized=False, | |
wkwargs=None, | |
center=True, | |
pad_mode="reflect", | |
onesided=True, | |
).to(self.device) | |
# https://pytorch.org/audio/stable/generated/torchaudio.transforms.GriffinLim.html | |
self.inverse_spectrogram_func = torchaudio.transforms.GriffinLim( | |
n_fft=params.n_fft, | |
n_iter=params.num_griffin_lim_iters, | |
win_length=params.win_length, | |
hop_length=params.hop_length, | |
window_fn=torch.hann_window, | |
power=1.0, | |
wkwargs=None, | |
momentum=0.99, | |
length=None, | |
rand_init=True, | |
).to(self.device) | |
# https://pytorch.org/audio/stable/generated/torchaudio.transforms.MelScale.html | |
self.mel_scaler = torchaudio.transforms.MelScale( | |
n_mels=params.num_frequencies, | |
sample_rate=params.sample_rate, | |
f_min=params.min_frequency, | |
f_max=params.max_frequency, | |
n_stft=params.n_fft // 2 + 1, | |
norm=params.mel_scale_norm, | |
mel_scale=params.mel_scale_type, | |
).to(self.device) | |
# https://pytorch.org/audio/stable/generated/torchaudio.transforms.InverseMelScale.html | |
self.inverse_mel_scaler = torchaudio.transforms.InverseMelScale( | |
n_stft=params.n_fft // 2 + 1, | |
n_mels=params.num_frequencies, | |
sample_rate=params.sample_rate, | |
f_min=params.min_frequency, | |
f_max=params.max_frequency, | |
            # max_iter=params.max_mel_iters, # for higher versions of torchaudio
            # tolerance_loss=1e-5, # for higher versions of torchaudio
            # tolerance_change=1e-8, # for higher versions of torchaudio
            # sgdargs=None, # for higher versions of torchaudio
norm=params.mel_scale_norm, | |
mel_scale=params.mel_scale_type, | |
).to(self.device) | |
def spectrogram_from_audio( | |
self, | |
audio: pydub.AudioSegment, | |
) -> np.ndarray: | |
""" | |
Compute a spectrogram from an audio segment. | |
Args: | |
audio: Audio segment which must match the sample rate of the params | |
Returns: | |
spectrogram: (channel, frequency, time) | |
""" | |
assert int(audio.frame_rate) == self.p.sample_rate, "Audio sample rate must match params" | |
# Get the samples as a numpy array in (batch, samples) shape | |
waveform = np.array([c.get_array_of_samples() for c in audio.split_to_mono()]) | |
# Convert to floats if necessary | |
if waveform.dtype != np.float32: | |
waveform = waveform.astype(np.float32) | |
waveform_tensor = torch.from_numpy(waveform).to(self.device) | |
amplitudes_mel = self.mel_amplitudes_from_waveform(waveform_tensor) | |
return amplitudes_mel.cpu().numpy() | |
def audio_from_spectrogram( | |
self, | |
spectrogram: np.ndarray, | |
apply_filters: bool = True, | |
) -> pydub.AudioSegment: | |
""" | |
Reconstruct an audio segment from a spectrogram. | |
Args: | |
spectrogram: (batch, frequency, time) | |
apply_filters: Post-process with normalization and compression | |
Returns: | |
audio: Audio segment with channels equal to the batch dimension | |
""" | |
# Move to device | |
amplitudes_mel = torch.from_numpy(spectrogram).to(self.device) | |
# Reconstruct the waveform | |
waveform = self.waveform_from_mel_amplitudes(amplitudes_mel) | |
# Convert to audio segment | |
segment = audio_from_waveform( | |
samples=waveform.cpu().numpy(), | |
sample_rate=self.p.sample_rate, | |
# Normalize the waveform to the range [-1, 1] | |
normalize=True, | |
) | |
# Optionally apply post-processing filters | |
if apply_filters: | |
segment = apply_filters_func( | |
segment, | |
compression=False, | |
) | |
return segment | |
def mel_amplitudes_from_waveform( | |
self, | |
waveform: torch.Tensor, | |
) -> torch.Tensor: | |
""" | |
Torch-only function to compute Mel-scale amplitudes from a waveform. | |
Args: | |
waveform: (batch, samples) | |
Returns: | |
amplitudes_mel: (batch, frequency, time) | |
""" | |
# Compute the complex-valued spectrogram | |
spectrogram_complex = self.spectrogram_func(waveform) | |
# Take the magnitude | |
amplitudes = torch.abs(spectrogram_complex) | |
# Convert to mel scale | |
return self.mel_scaler(amplitudes) | |
def waveform_from_mel_amplitudes( | |
self, | |
amplitudes_mel: torch.Tensor, | |
) -> torch.Tensor: | |
""" | |
Torch-only function to approximately reconstruct a waveform from Mel-scale amplitudes. | |
Args: | |
amplitudes_mel: (batch, frequency, time) | |
Returns: | |
waveform: (batch, samples) | |
""" | |
# Convert from mel scale to linear | |
amplitudes_linear = self.inverse_mel_scaler(amplitudes_mel) | |
# Run the approximate algorithm to compute the phase and recover the waveform | |
return self.inverse_spectrogram_func(amplitudes_linear) | |
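# Hedged shape sketch: run one second of silence through SpectrogramConverter on the CPU to show
# the (channels, n_mels, frames) layout returned by mel_amplitudes_from_waveform.
def _example_spectrogram_converter_shapes():
    params = SpectrogramParams()
    converter = SpectrogramConverter(params=params, device="cpu")
    waveform = torch.zeros(1, params.sample_rate)            # (channels, samples)
    mel = converter.mel_amplitudes_from_waveform(waveform)   # (1, num_frequencies, frames)
    return tuple(mel.shape)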
def check_device(device: str, backup: str = "cpu") -> str: | |
""" | |
Check that the device is valid and available. If not, | |
""" | |
cuda_not_found = device.lower().startswith("cuda") and not torch.cuda.is_available() | |
mps_not_found = device.lower().startswith("mps") and not torch.backends.mps.is_available() | |
if cuda_not_found or mps_not_found: | |
warnings.warn(f"WARNING: {device} is not available, using {backup} instead.", stacklevel=3) | |
return backup | |
return device | |
def audio_from_waveform( | |
samples: np.ndarray, sample_rate: int, normalize: bool = False | |
) -> pydub.AudioSegment: | |
""" | |
Convert a numpy array of samples of a waveform to an audio segment. | |
Args: | |
samples: (channels, samples) array | |
""" | |
# Normalize volume to fit in int16 | |
if normalize: | |
samples *= np.iinfo(np.int16).max / np.max(np.abs(samples)) | |
# Transpose and convert to int16 | |
samples = samples.transpose(1, 0) | |
samples = samples.astype(np.int16) | |
# Write to the bytes of a WAV file | |
wav_bytes = io.BytesIO() | |
wavfile.write(wav_bytes, sample_rate, samples) | |
wav_bytes.seek(0) | |
# Read into pydub | |
return pydub.AudioSegment.from_wav(wav_bytes) | |
def apply_filters_func(segment: pydub.AudioSegment, compression: bool = False) -> pydub.AudioSegment: | |
""" | |
Apply post-processing filters to the audio segment to compress it and | |
keep at a -10 dBFS level. | |
""" | |
# TODO(hayk): Come up with a principled strategy for these filters and experiment end-to-end. | |
# TODO(hayk): Is this going to make audio unbalanced between sequential clips? | |
if compression: | |
segment = pydub.effects.normalize( | |
segment, | |
headroom=0.1, | |
) | |
segment = segment.apply_gain(-10 - segment.dBFS) | |
# TODO(hayk): This is quite slow, ~1.7 seconds on a beefy CPU | |
segment = pydub.effects.compress_dynamic_range( | |
segment, | |
threshold=-20.0, | |
ratio=4.0, | |
attack=5.0, | |
release=50.0, | |
) | |
desired_db = -12 | |
segment = segment.apply_gain(desired_db - segment.dBFS) | |
segment = pydub.effects.normalize( | |
segment, | |
headroom=0.1, | |
) | |
return segment | |
def shave_segments(path, n_shave_prefix_segments=1): | |
""" | |
Removes segments. Positive values shave the first segments, negative shave the last segments. | |
""" | |
if n_shave_prefix_segments >= 0: | |
return ".".join(path.split(".")[n_shave_prefix_segments:]) | |
else: | |
return ".".join(path.split(".")[:n_shave_prefix_segments]) | |
def renew_resnet_paths(old_list, n_shave_prefix_segments=0): | |
""" | |
Updates paths inside resnets to the new naming scheme (local renaming) | |
""" | |
mapping = [] | |
for old_item in old_list: | |
new_item = old_item.replace("in_layers.0", "norm1") | |
new_item = new_item.replace("in_layers.2", "conv1") | |
new_item = new_item.replace("out_layers.0", "norm2") | |
new_item = new_item.replace("out_layers.3", "conv2") | |
new_item = new_item.replace("emb_layers.1", "time_emb_proj") | |
new_item = new_item.replace("skip_connection", "conv_shortcut") | |
new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments) | |
mapping.append({"old": old_item, "new": new_item}) | |
return mapping | |
def renew_vae_resnet_paths(old_list, n_shave_prefix_segments=0): | |
""" | |
Updates paths inside resnets to the new naming scheme (local renaming) | |
""" | |
mapping = [] | |
for old_item in old_list: | |
new_item = old_item | |
new_item = new_item.replace("nin_shortcut", "conv_shortcut") | |
new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments) | |
mapping.append({"old": old_item, "new": new_item}) | |
return mapping | |
def renew_attention_paths(old_list, n_shave_prefix_segments=0): | |
""" | |
Updates paths inside attentions to the new naming scheme (local renaming) | |
""" | |
mapping = [] | |
for old_item in old_list: | |
new_item = old_item | |
# new_item = new_item.replace('norm.weight', 'group_norm.weight') | |
# new_item = new_item.replace('norm.bias', 'group_norm.bias') | |
# new_item = new_item.replace('proj_out.weight', 'proj_attn.weight') | |
# new_item = new_item.replace('proj_out.bias', 'proj_attn.bias') | |
# new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments) | |
mapping.append({"old": old_item, "new": new_item}) | |
return mapping | |
def renew_vae_attention_paths(old_list, n_shave_prefix_segments=0): | |
""" | |
Updates paths inside attentions to the new naming scheme (local renaming) | |
""" | |
mapping = [] | |
for old_item in old_list: | |
new_item = old_item | |
new_item = new_item.replace("norm.weight", "group_norm.weight") | |
new_item = new_item.replace("norm.bias", "group_norm.bias") | |
new_item = new_item.replace("q.weight", "to_q.weight") | |
new_item = new_item.replace("q.bias", "to_q.bias") | |
new_item = new_item.replace("k.weight", "to_k.weight") | |
new_item = new_item.replace("k.bias", "to_k.bias") | |
new_item = new_item.replace("v.weight", "to_v.weight") | |
new_item = new_item.replace("v.bias", "to_v.bias") | |
new_item = new_item.replace("proj_out.weight", "to_out.0.weight") | |
new_item = new_item.replace("proj_out.bias", "to_out.0.bias") | |
new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments) | |
mapping.append({"old": old_item, "new": new_item}) | |
return mapping | |
def assign_to_checkpoint( | |
paths, checkpoint, old_checkpoint, attention_paths_to_split=None, additional_replacements=None, config=None | |
): | |
""" | |
This does the final conversion step: take locally converted weights and apply a global renaming to them. It splits | |
attention layers, and takes into account additional replacements that may arise. | |
Assigns the weights to the new checkpoint. | |
""" | |
assert isinstance(paths, list), "Paths should be a list of dicts containing 'old' and 'new' keys." | |
# Splits the attention layers into three variables. | |
if attention_paths_to_split is not None: | |
for path, path_map in attention_paths_to_split.items(): | |
old_tensor = old_checkpoint[path] | |
channels = old_tensor.shape[0] // 3 | |
target_shape = (-1, channels) if len(old_tensor.shape) == 3 else (-1) | |
num_heads = old_tensor.shape[0] // config["num_head_channels"] // 3 | |
old_tensor = old_tensor.reshape((num_heads, 3 * channels // num_heads) + old_tensor.shape[1:]) | |
query, key, value = old_tensor.split(channels // num_heads, dim=1) | |
checkpoint[path_map["query"]] = query.reshape(target_shape) | |
checkpoint[path_map["key"]] = key.reshape(target_shape) | |
checkpoint[path_map["value"]] = value.reshape(target_shape) | |
for path in paths: | |
new_path = path["new"] | |
# These have already been assigned | |
if attention_paths_to_split is not None and new_path in attention_paths_to_split: | |
continue | |
# Global renaming happens here | |
new_path = new_path.replace("middle_block.0", "mid_block.resnets.0") | |
new_path = new_path.replace("middle_block.1", "mid_block.attentions.0") | |
new_path = new_path.replace("middle_block.2", "mid_block.resnets.1") | |
if additional_replacements is not None: | |
for replacement in additional_replacements: | |
new_path = new_path.replace(replacement["old"], replacement["new"]) | |
# proj_attn.weight has to be converted from conv 1D to linear | |
if "proj_attn.weight" in new_path: | |
checkpoint[new_path] = old_checkpoint[path["old"]][:, :, 0] | |
elif 'to_out.0.weight' in new_path: | |
checkpoint[new_path] = old_checkpoint[path['old']].squeeze() | |
elif any([qkv in new_path for qkv in ['to_q', 'to_k', 'to_v']]): | |
checkpoint[new_path] = old_checkpoint[path['old']].squeeze() | |
else: | |
checkpoint[new_path] = old_checkpoint[path["old"]] | |
def conv_attn_to_linear(checkpoint): | |
keys = list(checkpoint.keys()) | |
attn_keys = ["query.weight", "key.weight", "value.weight"] | |
for key in keys: | |
if ".".join(key.split(".")[-2:]) in attn_keys: | |
if checkpoint[key].ndim > 2: | |
checkpoint[key] = checkpoint[key][:, :, 0, 0] | |
elif "proj_attn.weight" in key: | |
if checkpoint[key].ndim > 2: | |
checkpoint[key] = checkpoint[key][:, :, 0] | |
def create_unet_diffusers_config(original_config, image_size: int, controlnet=False): | |
""" | |
Creates a config for the diffusers based on the config of the LDM model. | |
""" | |
if controlnet: | |
unet_params = original_config.model.params.control_stage_config.params | |
else: | |
unet_params = original_config.model.params.unet_config.params | |
vae_params = original_config.model.params.first_stage_config.params.ddconfig | |
block_out_channels = [unet_params.model_channels * mult for mult in unet_params.channel_mult] | |
down_block_types = [] | |
resolution = 1 | |
for i in range(len(block_out_channels)): | |
block_type = "CrossAttnDownBlock2D" if resolution in unet_params.attention_resolutions else "DownBlock2D" | |
down_block_types.append(block_type) | |
if i != len(block_out_channels) - 1: | |
resolution *= 2 | |
up_block_types = [] | |
for i in range(len(block_out_channels)): | |
block_type = "CrossAttnUpBlock2D" if resolution in unet_params.attention_resolutions else "UpBlock2D" | |
up_block_types.append(block_type) | |
resolution //= 2 | |
vae_scale_factor = 2 ** (len(vae_params.ch_mult) - 1) | |
head_dim = unet_params.num_heads if "num_heads" in unet_params else None | |
use_linear_projection = ( | |
unet_params.use_linear_in_transformer if "use_linear_in_transformer" in unet_params else False | |
) | |
if use_linear_projection: | |
# stable diffusion 2-base-512 and 2-768 | |
if head_dim is None: | |
head_dim = [5, 10, 20, 20] | |
class_embed_type = None | |
projection_class_embeddings_input_dim = None | |
if "num_classes" in unet_params: | |
if unet_params.num_classes == "sequential": | |
class_embed_type = "projection" | |
assert "adm_in_channels" in unet_params | |
projection_class_embeddings_input_dim = unet_params.adm_in_channels | |
else: | |
raise NotImplementedError(f"Unknown conditional unet num_classes config: {unet_params.num_classes}") | |
config = { | |
"sample_size": image_size // vae_scale_factor, | |
"in_channels": unet_params.in_channels, | |
"down_block_types": tuple(down_block_types), | |
"block_out_channels": tuple(block_out_channels), | |
"layers_per_block": unet_params.num_res_blocks, | |
"cross_attention_dim": unet_params.context_dim, | |
"attention_head_dim": head_dim, | |
"use_linear_projection": use_linear_projection, | |
"class_embed_type": class_embed_type, | |
"projection_class_embeddings_input_dim": projection_class_embeddings_input_dim, | |
} | |
if not controlnet: | |
config["out_channels"] = unet_params.out_channels | |
config["up_block_types"] = tuple(up_block_types) | |
return config | |
def create_vae_diffusers_config(original_config, image_size: int): | |
""" | |
Creates a config for the diffusers based on the config of the LDM model. | |
""" | |
vae_params = original_config.model.params.first_stage_config.params.ddconfig | |
_ = original_config.model.params.first_stage_config.params.embed_dim | |
block_out_channels = [vae_params.ch * mult for mult in vae_params.ch_mult] | |
down_block_types = ["DownEncoderBlock2D"] * len(block_out_channels) | |
up_block_types = ["UpDecoderBlock2D"] * len(block_out_channels) | |
config = { | |
"sample_size": image_size, | |
"in_channels": vae_params.in_channels, | |
"out_channels": vae_params.out_ch, | |
"down_block_types": tuple(down_block_types), | |
"up_block_types": tuple(up_block_types), | |
"block_out_channels": tuple(block_out_channels), | |
"latent_channels": vae_params.z_channels, | |
"layers_per_block": vae_params.num_res_blocks, | |
} | |
return config | |
def create_diffusers_schedular(original_config): | |
schedular = DDIMScheduler( | |
num_train_timesteps=original_config.model.params.timesteps, | |
beta_start=original_config.model.params.linear_start, | |
beta_end=original_config.model.params.linear_end, | |
beta_schedule="scaled_linear", | |
) | |
return schedular | |
def convert_ldm_unet_checkpoint(checkpoint, config, path=None, extract_ema=False, controlnet=False): | |
""" | |
Takes a state dict and a config, and returns a converted checkpoint. | |
""" | |
# extract state_dict for UNet | |
unet_state_dict = {} | |
keys = list(checkpoint.keys()) | |
if controlnet: | |
unet_key = "control_model." | |
else: | |
unet_key = "model.diffusion_model." | |
    # at least 100 parameters have to start with `model_ema` for the checkpoint to be considered EMA
if sum(k.startswith("model_ema") for k in keys) > 100 and extract_ema: | |
print(f"Checkpoint {path} has both EMA and non-EMA weights.") | |
print( | |
"In this conversion only the EMA weights are extracted. If you want to instead extract the non-EMA" | |
" weights (useful to continue fine-tuning), please make sure to remove the `--extract_ema` flag." | |
) | |
for key in keys: | |
if key.startswith("model.diffusion_model"): | |
flat_ema_key = "model_ema." + "".join(key.split(".")[1:]) | |
unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(flat_ema_key) | |
else: | |
if sum(k.startswith("model_ema") for k in keys) > 100: | |
print( | |
"In this conversion only the non-EMA weights are extracted. If you want to instead extract the EMA" | |
" weights (usually better for inference), please make sure to add the `--extract_ema` flag." | |
) | |
for key in keys: | |
if key.startswith(unet_key): | |
unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(key) | |
new_checkpoint = {} | |
new_checkpoint["time_embedding.linear_1.weight"] = unet_state_dict["time_embed.0.weight"] | |
new_checkpoint["time_embedding.linear_1.bias"] = unet_state_dict["time_embed.0.bias"] | |
new_checkpoint["time_embedding.linear_2.weight"] = unet_state_dict["time_embed.2.weight"] | |
new_checkpoint["time_embedding.linear_2.bias"] = unet_state_dict["time_embed.2.bias"] | |
if config["class_embed_type"] is None: | |
# No parameters to port | |
... | |
elif config["class_embed_type"] == "timestep" or config["class_embed_type"] == "projection": | |
new_checkpoint["class_embedding.linear_1.weight"] = unet_state_dict["label_emb.0.0.weight"] | |
new_checkpoint["class_embedding.linear_1.bias"] = unet_state_dict["label_emb.0.0.bias"] | |
new_checkpoint["class_embedding.linear_2.weight"] = unet_state_dict["label_emb.0.2.weight"] | |
new_checkpoint["class_embedding.linear_2.bias"] = unet_state_dict["label_emb.0.2.bias"] | |
else: | |
raise NotImplementedError(f"Not implemented `class_embed_type`: {config['class_embed_type']}") | |
new_checkpoint["conv_in.weight"] = unet_state_dict["input_blocks.0.0.weight"] | |
new_checkpoint["conv_in.bias"] = unet_state_dict["input_blocks.0.0.bias"] | |
if not controlnet: | |
new_checkpoint["conv_norm_out.weight"] = unet_state_dict["out.0.weight"] | |
new_checkpoint["conv_norm_out.bias"] = unet_state_dict["out.0.bias"] | |
new_checkpoint["conv_out.weight"] = unet_state_dict["out.2.weight"] | |
new_checkpoint["conv_out.bias"] = unet_state_dict["out.2.bias"] | |
# Retrieves the keys for the input blocks only | |
num_input_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "input_blocks" in layer}) | |
input_blocks = { | |
layer_id: [key for key in unet_state_dict if f"input_blocks.{layer_id}" in key] | |
for layer_id in range(num_input_blocks) | |
} | |
# Retrieves the keys for the middle blocks only | |
num_middle_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "middle_block" in layer}) | |
middle_blocks = { | |
layer_id: [key for key in unet_state_dict if f"middle_block.{layer_id}" in key] | |
for layer_id in range(num_middle_blocks) | |
} | |
# Retrieves the keys for the output blocks only | |
num_output_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "output_blocks" in layer}) | |
output_blocks = { | |
layer_id: [key for key in unet_state_dict if f"output_blocks.{layer_id}" in key] | |
for layer_id in range(num_output_blocks) | |
} | |
for i in range(1, num_input_blocks): | |
block_id = (i - 1) // (config["layers_per_block"] + 1) | |
layer_in_block_id = (i - 1) % (config["layers_per_block"] + 1) | |
resnets = [ | |
key for key in input_blocks[i] if f"input_blocks.{i}.0" in key and f"input_blocks.{i}.0.op" not in key | |
] | |
attentions = [key for key in input_blocks[i] if f"input_blocks.{i}.1" in key] | |
if f"input_blocks.{i}.0.op.weight" in unet_state_dict: | |
new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.weight"] = unet_state_dict.pop( | |
f"input_blocks.{i}.0.op.weight" | |
) | |
new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.bias"] = unet_state_dict.pop( | |
f"input_blocks.{i}.0.op.bias" | |
) | |
paths = renew_resnet_paths(resnets) | |
meta_path = {"old": f"input_blocks.{i}.0", "new": f"down_blocks.{block_id}.resnets.{layer_in_block_id}"} | |
assign_to_checkpoint( | |
paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config | |
) | |
if len(attentions): | |
paths = renew_attention_paths(attentions) | |
meta_path = {"old": f"input_blocks.{i}.1", "new": f"down_blocks.{block_id}.attentions.{layer_in_block_id}"} | |
assign_to_checkpoint( | |
paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config | |
) | |
resnet_0 = middle_blocks[0] | |
attentions = middle_blocks[1] | |
resnet_1 = middle_blocks[2] | |
resnet_0_paths = renew_resnet_paths(resnet_0) | |
assign_to_checkpoint(resnet_0_paths, new_checkpoint, unet_state_dict, config=config) | |
resnet_1_paths = renew_resnet_paths(resnet_1) | |
assign_to_checkpoint(resnet_1_paths, new_checkpoint, unet_state_dict, config=config) | |
attentions_paths = renew_attention_paths(attentions) | |
meta_path = {"old": "middle_block.1", "new": "mid_block.attentions.0"} | |
assign_to_checkpoint( | |
attentions_paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config | |
) | |
for i in range(num_output_blocks): | |
block_id = i // (config["layers_per_block"] + 1) | |
layer_in_block_id = i % (config["layers_per_block"] + 1) | |
output_block_layers = [shave_segments(name, 2) for name in output_blocks[i]] | |
output_block_list = {} | |
for layer in output_block_layers: | |
layer_id, layer_name = layer.split(".")[0], shave_segments(layer, 1) | |
if layer_id in output_block_list: | |
output_block_list[layer_id].append(layer_name) | |
else: | |
output_block_list[layer_id] = [layer_name] | |
if len(output_block_list) > 1: | |
resnets = [key for key in output_blocks[i] if f"output_blocks.{i}.0" in key] | |
attentions = [key for key in output_blocks[i] if f"output_blocks.{i}.1" in key] | |
resnet_0_paths = renew_resnet_paths(resnets) | |
paths = renew_resnet_paths(resnets) | |
meta_path = {"old": f"output_blocks.{i}.0", "new": f"up_blocks.{block_id}.resnets.{layer_in_block_id}"} | |
assign_to_checkpoint( | |
paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config | |
) | |
output_block_list = {k: sorted(v) for k, v in output_block_list.items()} | |
if ["conv.bias", "conv.weight"] in output_block_list.values(): | |
index = list(output_block_list.values()).index(["conv.bias", "conv.weight"]) | |
new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.weight"] = unet_state_dict[ | |
f"output_blocks.{i}.{index}.conv.weight" | |
] | |
new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.bias"] = unet_state_dict[ | |
f"output_blocks.{i}.{index}.conv.bias" | |
] | |
# Clear attentions as they have been attributed above. | |
if len(attentions) == 2: | |
attentions = [] | |
if len(attentions): | |
paths = renew_attention_paths(attentions) | |
meta_path = { | |
"old": f"output_blocks.{i}.1", | |
"new": f"up_blocks.{block_id}.attentions.{layer_in_block_id}", | |
} | |
assign_to_checkpoint( | |
paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config | |
) | |
else: | |
resnet_0_paths = renew_resnet_paths(output_block_layers, n_shave_prefix_segments=1) | |
for path in resnet_0_paths: | |
old_path = ".".join(["output_blocks", str(i), path["old"]]) | |
new_path = ".".join(["up_blocks", str(block_id), "resnets", str(layer_in_block_id), path["new"]]) | |
new_checkpoint[new_path] = unet_state_dict[old_path] | |
if controlnet: | |
# conditioning embedding | |
orig_index = 0 | |
new_checkpoint["controlnet_cond_embedding.conv_in.weight"] = unet_state_dict.pop( | |
f"input_hint_block.{orig_index}.weight" | |
) | |
new_checkpoint["controlnet_cond_embedding.conv_in.bias"] = unet_state_dict.pop( | |
f"input_hint_block.{orig_index}.bias" | |
) | |
orig_index += 2 | |
diffusers_index = 0 | |
while diffusers_index < 6: | |
new_checkpoint[f"controlnet_cond_embedding.blocks.{diffusers_index}.weight"] = unet_state_dict.pop( | |
f"input_hint_block.{orig_index}.weight" | |
) | |
new_checkpoint[f"controlnet_cond_embedding.blocks.{diffusers_index}.bias"] = unet_state_dict.pop( | |
f"input_hint_block.{orig_index}.bias" | |
) | |
diffusers_index += 1 | |
orig_index += 2 | |
new_checkpoint["controlnet_cond_embedding.conv_out.weight"] = unet_state_dict.pop( | |
f"input_hint_block.{orig_index}.weight" | |
) | |
new_checkpoint["controlnet_cond_embedding.conv_out.bias"] = unet_state_dict.pop( | |
f"input_hint_block.{orig_index}.bias" | |
) | |
# down blocks | |
for i in range(num_input_blocks): | |
new_checkpoint[f"controlnet_down_blocks.{i}.weight"] = unet_state_dict.pop(f"zero_convs.{i}.0.weight") | |
new_checkpoint[f"controlnet_down_blocks.{i}.bias"] = unet_state_dict.pop(f"zero_convs.{i}.0.bias") | |
# mid block | |
new_checkpoint["controlnet_mid_block.weight"] = unet_state_dict.pop("middle_block_out.0.weight") | |
new_checkpoint["controlnet_mid_block.bias"] = unet_state_dict.pop("middle_block_out.0.bias") | |
return new_checkpoint | |
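# Hedged sketch of how convert_ldm_unet_checkpoint is typically driven: load an original LDM-style
# checkpoint, derive the diffusers config from its OmegaConf yaml, convert the weights, and load
# them into a diffusers UNet. Both paths are placeholders, not files shipped with this repo.
def _example_convert_ldm_unet(ckpt_path: str = "ldm.ckpt", config_path: str = "ldm.yaml"):
    original_config = OmegaConf.load(config_path)
    checkpoint = torch.load(ckpt_path, map_location="cpu")
    if "state_dict" in checkpoint:
        checkpoint = checkpoint["state_dict"]
    unet_config = create_unet_diffusers_config(original_config, image_size=512)
    converted = convert_ldm_unet_checkpoint(checkpoint, unet_config, path=ckpt_path)
    unet = UNet2DConditionModel(**unet_config)
    unet.load_state_dict(converted)
    return unet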
def convert_ldm_vae_checkpoint(checkpoint, config, only_decoder=False, only_encoder=False): | |
# extract state dict for VAE | |
vae_state_dict = {} | |
vae_key = "first_stage_model." | |
keys = list(checkpoint.keys()) | |
for key in keys: | |
if key.startswith(vae_key): | |
vae_state_dict[key.replace(vae_key, "")] = checkpoint.get(key) | |
new_checkpoint = {} | |
new_checkpoint["encoder.conv_in.weight"] = vae_state_dict["encoder.conv_in.weight"] | |
new_checkpoint["encoder.conv_in.bias"] = vae_state_dict["encoder.conv_in.bias"] | |
new_checkpoint["encoder.conv_out.weight"] = vae_state_dict["encoder.conv_out.weight"] | |
new_checkpoint["encoder.conv_out.bias"] = vae_state_dict["encoder.conv_out.bias"] | |
new_checkpoint["encoder.conv_norm_out.weight"] = vae_state_dict["encoder.norm_out.weight"] | |
new_checkpoint["encoder.conv_norm_out.bias"] = vae_state_dict["encoder.norm_out.bias"] | |
new_checkpoint["decoder.conv_in.weight"] = vae_state_dict["decoder.conv_in.weight"] | |
new_checkpoint["decoder.conv_in.bias"] = vae_state_dict["decoder.conv_in.bias"] | |
new_checkpoint["decoder.conv_out.weight"] = vae_state_dict["decoder.conv_out.weight"] | |
new_checkpoint["decoder.conv_out.bias"] = vae_state_dict["decoder.conv_out.bias"] | |
new_checkpoint["decoder.conv_norm_out.weight"] = vae_state_dict["decoder.norm_out.weight"] | |
new_checkpoint["decoder.conv_norm_out.bias"] = vae_state_dict["decoder.norm_out.bias"] | |
new_checkpoint["quant_conv.weight"] = vae_state_dict["quant_conv.weight"] | |
new_checkpoint["quant_conv.bias"] = vae_state_dict["quant_conv.bias"] | |
new_checkpoint["post_quant_conv.weight"] = vae_state_dict["post_quant_conv.weight"] | |
new_checkpoint["post_quant_conv.bias"] = vae_state_dict["post_quant_conv.bias"] | |
# Retrieves the keys for the encoder down blocks only | |
num_down_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "encoder.down" in layer}) | |
down_blocks = { | |
layer_id: [key for key in vae_state_dict if f"down.{layer_id}" in key] for layer_id in range(num_down_blocks) | |
} | |
# Retrieves the keys for the decoder up blocks only | |
num_up_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "decoder.up" in layer}) | |
up_blocks = { | |
layer_id: [key for key in vae_state_dict if f"up.{layer_id}" in key] for layer_id in range(num_up_blocks) | |
} | |
for i in range(num_down_blocks): | |
resnets = [key for key in down_blocks[i] if f"down.{i}" in key and f"down.{i}.downsample" not in key] | |
if f"encoder.down.{i}.downsample.conv.weight" in vae_state_dict: | |
new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.weight"] = vae_state_dict.pop( | |
f"encoder.down.{i}.downsample.conv.weight" | |
) | |
new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.bias"] = vae_state_dict.pop( | |
f"encoder.down.{i}.downsample.conv.bias" | |
) | |
paths = renew_vae_resnet_paths(resnets) | |
meta_path = {"old": f"down.{i}.block", "new": f"down_blocks.{i}.resnets"} | |
assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) | |
mid_resnets = [key for key in vae_state_dict if "encoder.mid.block" in key] | |
num_mid_res_blocks = 2 | |
for i in range(1, num_mid_res_blocks + 1): | |
resnets = [key for key in mid_resnets if f"encoder.mid.block_{i}" in key] | |
paths = renew_vae_resnet_paths(resnets) | |
meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"} | |
assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) | |
mid_attentions = [key for key in vae_state_dict if "encoder.mid.attn" in key] | |
paths = renew_vae_attention_paths(mid_attentions) | |
meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"} | |
assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) | |
conv_attn_to_linear(new_checkpoint) | |
for i in range(num_up_blocks): | |
block_id = num_up_blocks - 1 - i | |
resnets = [ | |
key for key in up_blocks[block_id] if f"up.{block_id}" in key and f"up.{block_id}.upsample" not in key | |
] | |
if f"decoder.up.{block_id}.upsample.conv.weight" in vae_state_dict: | |
new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.weight"] = vae_state_dict[ | |
f"decoder.up.{block_id}.upsample.conv.weight" | |
] | |
new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.bias"] = vae_state_dict[ | |
f"decoder.up.{block_id}.upsample.conv.bias" | |
] | |
paths = renew_vae_resnet_paths(resnets) | |
meta_path = {"old": f"up.{block_id}.block", "new": f"up_blocks.{i}.resnets"} | |
assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) | |
mid_resnets = [key for key in vae_state_dict if "decoder.mid.block" in key] | |
num_mid_res_blocks = 2 | |
for i in range(1, num_mid_res_blocks + 1): | |
resnets = [key for key in mid_resnets if f"decoder.mid.block_{i}" in key] | |
paths = renew_vae_resnet_paths(resnets) | |
meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"} | |
assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) | |
mid_attentions = [key for key in vae_state_dict if "decoder.mid.attn" in key] | |
paths = renew_vae_attention_paths(mid_attentions) | |
meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"} | |
assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) | |
conv_attn_to_linear(new_checkpoint) | |
if only_decoder: | |
new_checkpoint = {k: v for k, v in new_checkpoint.items() if k.startswith('decoder') or k.startswith('post_quant')} | |
elif only_encoder: | |
new_checkpoint = {k: v for k, v in new_checkpoint.items() if k.startswith('encoder') or k.startswith('quant')} | |
return new_checkpoint | |
def convert_ldm_clip_checkpoint(checkpoint): | |
keys = list(checkpoint.keys()) | |
text_model_dict = {} | |
for key in keys: | |
if key.startswith("cond_stage_model.transformer"): | |
text_model_dict[key[len("cond_stage_model.transformer.") :]] = checkpoint[key] | |
return text_model_dict | |
def convert_lora_model_level(state_dict, unet, text_encoder=None, LORA_PREFIX_UNET="lora_unet", LORA_PREFIX_TEXT_ENCODER="lora_te", alpha=0.6): | |
"""convert lora in model level instead of pipeline leval | |
""" | |
visited = [] | |
# directly update weight in diffusers model | |
for key in state_dict: | |
# it is suggested to print out the key, it usually will be something like below | |
# "lora_te_text_model_encoder_layers_0_self_attn_k_proj.lora_down.weight" | |
        # alpha is supplied via the function argument, so skip the stored ".alpha" entries
if ".alpha" in key or key in visited: | |
continue | |
if "text" in key: | |
layer_infos = key.split(".")[0].split(LORA_PREFIX_TEXT_ENCODER + "_")[-1].split("_") | |
assert text_encoder is not None, ( | |
'text_encoder must be passed since lora contains text encoder layers') | |
curr_layer = text_encoder | |
else: | |
layer_infos = key.split(".")[0].split(LORA_PREFIX_UNET + "_")[-1].split("_") | |
curr_layer = unet | |
# find the target layer | |
temp_name = layer_infos.pop(0) | |
while len(layer_infos) > -1: | |
try: | |
curr_layer = curr_layer.__getattr__(temp_name) | |
if len(layer_infos) > 0: | |
temp_name = layer_infos.pop(0) | |
elif len(layer_infos) == 0: | |
break | |
except Exception: | |
if len(temp_name) > 0: | |
temp_name += "_" + layer_infos.pop(0) | |
else: | |
temp_name = layer_infos.pop(0) | |
pair_keys = [] | |
if "lora_down" in key: | |
pair_keys.append(key.replace("lora_down", "lora_up")) | |
pair_keys.append(key) | |
else: | |
pair_keys.append(key) | |
pair_keys.append(key.replace("lora_up", "lora_down")) | |
# update weight | |
        # NOTE: LoCon/LyCORIS-style conv LoRA layers; this path may still have bugs :(
if 'conv_in' in pair_keys[0]: | |
weight_up = state_dict[pair_keys[0]].to(torch.float32) | |
weight_down = state_dict[pair_keys[1]].to(torch.float32) | |
weight_up = weight_up.view(weight_up.size(0), -1) | |
weight_down = weight_down.view(weight_down.size(0), -1) | |
shape = [e for e in curr_layer.weight.data.shape] | |
shape[1] = 4 | |
curr_layer.weight.data[:, :4, ...] += alpha * (weight_up @ weight_down).view(*shape) | |
elif 'conv' in pair_keys[0]: | |
weight_up = state_dict[pair_keys[0]].to(torch.float32) | |
weight_down = state_dict[pair_keys[1]].to(torch.float32) | |
weight_up = weight_up.view(weight_up.size(0), -1) | |
weight_down = weight_down.view(weight_down.size(0), -1) | |
shape = [e for e in curr_layer.weight.data.shape] | |
curr_layer.weight.data += alpha * (weight_up @ weight_down).view(*shape) | |
elif len(state_dict[pair_keys[0]].shape) == 4: | |
weight_up = state_dict[pair_keys[0]].squeeze(3).squeeze(2).to(torch.float32) | |
weight_down = state_dict[pair_keys[1]].squeeze(3).squeeze(2).to(torch.float32) | |
curr_layer.weight.data += alpha * torch.mm(weight_up, weight_down).unsqueeze(2).unsqueeze(3).to(curr_layer.weight.data.device) | |
else: | |
weight_up = state_dict[pair_keys[0]].to(torch.float32) | |
weight_down = state_dict[pair_keys[1]].to(torch.float32) | |
curr_layer.weight.data += alpha * torch.mm(weight_up, weight_down).to(curr_layer.weight.data.device) | |
# update visited list | |
for item in pair_keys: | |
visited.append(item) | |
return unet, text_encoder | |
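# Hedged sketch: load a LoRA file in the kohya-style "lora_unet_*/lora_te_*" key format with
# safetensors and merge it into a UNet / text encoder pair. The file name is a placeholder.
def _example_apply_lora(unet, text_encoder, lora_path: str = "lora.safetensors", alpha: float = 0.6):
    state_dict = {}
    with safe_open(lora_path, framework="pt", device="cpu") as f:
        for key in f.keys():
            state_dict[key] = f.get_tensor(key)
    return convert_lora_model_level(state_dict, unet, text_encoder=text_encoder, alpha=alpha)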
def denormalize_spectrogram( | |
data: torch.Tensor, | |
max_value: float = 200, | |
min_value: float = 1e-5, | |
power: float = 1, | |
inverse: bool = False, | |
) -> torch.Tensor: | |
max_value = np.log(max_value) | |
min_value = np.log(min_value) | |
# Flip Y axis: image origin at the top-left corner, spectrogram origin at the bottom-left corner | |
data = torch.flip(data, [1]) | |
assert len(data.shape) == 3, "Expected 3 dimensions, got {}".format(len(data.shape)) | |
if data.shape[0] == 1: | |
data = data.repeat(3, 1, 1) | |
assert data.shape[0] == 3, "Expected 3 channels, got {}".format(data.shape[0]) | |
data = data[0] | |
# Reverse the power curve | |
data = torch.pow(data, 1 / power) | |
# Invert | |
if inverse: | |
data = 1 - data | |
# Rescale to max value | |
spectrogram = data * (max_value - min_value) + min_value | |
return spectrogram | |
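# Hedged shape sketch for denormalize_spectrogram: a (3, n_mels, frames) tensor with values in
# [0, 1] (e.g. a decoded spectrogram "image") is mapped back to the log-amplitude range
# [log(min_value), log(max_value)]. Sizes below are arbitrary.
def _example_denormalize_spectrogram():
    norm_spec = torch.rand(3, 256, 1024)
    spec = denormalize_spectrogram(norm_spec)  # shape (256, 1024)
    return spec.shape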
class ToTensor1D(torchvision.transforms.ToTensor): | |
def __call__(self, tensor: np.ndarray): | |
tensor_2d = super(ToTensor1D, self).__call__(tensor[..., np.newaxis]) | |
return tensor_2d.squeeze_(0) | |
def scale(old_value, old_min, old_max, new_min, new_max): | |
old_range = (old_max - old_min) | |
new_range = (new_max - new_min) | |
new_value = (((old_value - old_min) * new_range) / old_range) + new_min | |
return new_value | |
def read_frames_with_moviepy(video_path, max_frame_nums=None):
    clip = VideoFileClip(video_path)
    duration = clip.duration
    frames = []
    for frame in clip.iter_frames():
        frames.append(frame)
    frames = np.array(frames)
    if max_frame_nums is not None:
        # uniformly subsample to at most max_frame_nums frames
        frames_idx = np.linspace(0, len(frames) - 1, max_frame_nums, dtype=int)
        frames = frames[frames_idx, ...]
    return frames, duration
def read_frames_with_moviepy_resample(video_path, save_path): | |
vision_transform_list = [ | |
transforms.Resize((128, 128)), | |
transforms.CenterCrop((112, 112)), | |
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) | |
] | |
video_transform = transforms.Compose(vision_transform_list) | |
os.makedirs(save_path, exist_ok=True) | |
command = f'ffmpeg -v quiet -y -i \"{video_path}\" -f image2 -vf \"scale=-1:360,fps=15\" -qscale:v 3 \"{save_path}\"/frame%06d.jpg' | |
os.system(command) | |
frame_list = glob.glob(f'{save_path}/*.jpg') | |
frame_list.sort() | |
convert_tensor = transforms.ToTensor() | |
frame_list = [convert_tensor(np.array(Image.open(frame))) for frame in frame_list] | |
imgs = torch.stack(frame_list, dim=0) | |
imgs = video_transform(imgs) | |
imgs = imgs.permute(1, 0, 2, 3) | |
return imgs |