import torch import torchvision import torchaudio import torchvision.transforms as transforms from diffusers import UNet2DConditionModel, ControlNetModel from foleycrafter.pipelines.pipeline_controlnet import StableDiffusionControlNetPipeline from foleycrafter.pipelines.auffusion_pipeline import AuffusionNoAdapterPipeline, Generator from foleycrafter.models.auffusion_unet import UNet2DConditionModel as af_UNet2DConditionModel from diffusers.models import AutoencoderKLTemporalDecoder, AutoencoderKL from diffusers.schedulers import EulerDiscreteScheduler, DDIMScheduler, PNDMScheduler, KarrasDiffusionSchedulers from diffusers.utils.import_utils import is_xformers_available from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection,\ SpeechT5HifiGan, ClapTextModelWithProjection, RobertaTokenizer, RobertaTokenizerFast,\ CLIPTextModel, CLIPTokenizer import glob from moviepy.editor import ImageSequenceClip, AudioFileClip, VideoFileClip, VideoClip from moviepy.audio.AudioClip import AudioArrayClip import numpy as np from safetensors import safe_open import random from typing import Union, Optional import decord import os import os.path as osp import imageio import soundfile as sf from PIL import Image, ImageOps import torch.distributed as dist import io from omegaconf import OmegaConf import json from dataclasses import dataclass from enum import Enum import typing as T import warnings import pydub from scipy.io import wavfile from einops import rearrange def zero_rank_print(s): if (not dist.is_initialized()) or (dist.is_initialized() and dist.get_rank() == 0): print("### " + s, flush=True) def build_foleycrafter( pretrained_model_name_or_path: str="auffusion/auffusion-full-no-adapter", ) -> StableDiffusionControlNetPipeline: vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder='vae') unet = af_UNet2DConditionModel.from_pretrained(pretrained_model_name_or_path, subfolder='unet') scheduler = PNDMScheduler.from_pretrained(pretrained_model_name_or_path, subfolder='scheduler') tokenizer = CLIPTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder='tokenizer') text_encoder = CLIPTextModel.from_pretrained(pretrained_model_name_or_path, subfolder='text_encoder') controlnet = ControlNetModel.from_unet(unet, conditioning_channels=1) pipe = StableDiffusionControlNetPipeline( vae=vae, controlnet=controlnet, unet=unet, scheduler=scheduler, tokenizer=tokenizer, text_encoder=text_encoder, feature_extractor=None, safety_checker=None, requires_safety_checker=False, ) return pipe def save_videos_grid(videos: torch.Tensor, path: str, rescale=False, n_rows=6, fps=8): if len(videos.shape) == 4: videos = videos.unsqueeze(0) videos = rearrange(videos, "b c t h w -> t b c h w") outputs = [] for x in videos: x = torchvision.utils.make_grid(x, nrow=n_rows) x = x.transpose(0, 1).transpose(1, 2).squeeze(-1) if rescale: x = (x + 1.0) / 2.0 # -1,1 -> 0,1 x = torch.clamp((x * 255), 0, 255).numpy().astype(np.uint8) outputs.append(x) os.makedirs(os.path.dirname(path), exist_ok=True) imageio.mimsave(path, outputs, fps=fps) def save_videos_from_pil_list(videos: list, path: str, fps=7): for i in range(len(videos)): videos[i] = ImageOps.scale(videos[i], 255) imageio.mimwrite(path, videos, fps=fps) def seed_everything(seed: int) -> None: r"""Sets the seed for generating random numbers in :pytorch:`PyTorch`, :obj:`numpy` and :python:`Python`. Args: seed (int): The desired seed. """ random.seed(seed) np.random.seed(seed) torch.manual_seed(seed) torch.cuda.manual_seed_all(seed) def get_video_frames(video: np.ndarray, num_frames: int=200): video_length = video.shape[0] video_idx = np.linspace(0, video_length-1, num_frames, dtype=int) video = video[video_idx, ...] return video def random_audio_video_clip(audio: np.ndarray, video: np.ndarray, fps:float, \ sample_rate:int=16000, duration:int=5, num_frames: int=20): """ Random sample video clips with duration """ video_length = video.shape[0] audio_length = audio.shape[-1] av_duration = int(video_length / fps) assert av_duration >= duration,\ f"video duration {av_duration} is less than {duration}" # random sample start time start_time = random.uniform(0, av_duration - duration) end_time = start_time + duration start_idx, end_idx = start_time / av_duration, end_time / av_duration video_start_frame, video_end_frame\ = video_length * start_idx, video_length * end_idx audio_start_frame, audio_end_frame\ = audio_length * start_idx, audio_length * end_idx # print(f"time_idx : {start_time}:{end_time}") # print(f"video_idx: {video_start_frame}:{video_end_frame}") # print(f"audio_idx: {audio_start_frame}:{audio_end_frame}") audio_idx = np.linspace(audio_start_frame, audio_end_frame, sample_rate * duration, dtype=int) video_idx = np.linspace(video_start_frame, video_end_frame, num_frames, dtype=int) audio = audio[..., audio_idx] video = video[video_idx, ...] return audio, video def get_full_indices(reader: Union[decord.VideoReader, decord.AudioReader])\ -> np.ndarray: if isinstance(reader, decord.VideoReader): return np.linspace(0, len(reader) - 1, len(reader), dtype=int) elif isinstance(reader, decord.AudioReader): return np.linspace(0, reader.shape[-1] - 1, reader.shape[-1], dtype=int) def get_frames(video_path:str, onset_list, frame_nums=1024): video = decord.VideoReader(video_path) video_frame = len(video) frames_list = [] for start, end in onset_list: video_start = int(start / frame_nums * video_frame) video_end = int(end / frame_nums * video_frame) frames_list.extend(range(video_start, video_end)) frames = video.get_batch(frames_list).asnumpy() return frames def get_frames_in_video(video_path:str, onset_list, frame_nums=1024, audio_length_in_s=10): # this function consider the video length video = decord.VideoReader(video_path) video_frame = len(video) duration = video_frame / video.get_avg_fps() frames_list = [] video_onset_list = [] for start, end in onset_list: if int(start / frame_nums * duration) >= audio_length_in_s: continue video_start = int(start / audio_length_in_s * duration / frame_nums * video_frame) if video_start >= video_frame: continue video_end = int(end / audio_length_in_s * duration / frame_nums * video_frame) video_onset_list.append([int(start / audio_length_in_s * duration), int(end / audio_length_in_s * duration)]) frames_list.extend(range(video_start, video_end)) frames = video.get_batch(frames_list).asnumpy() return frames, video_onset_list def save_multimodal(video, audio, output_path, audio_fps:int=16000, video_fps:int=8, remove_audio:bool=True): imgs = [img for img in video] # if audio.shape[0] == 1 or audio.shape[0] == 2: # audio = audio.T #[len, channel] # audio = np.repeat(audio, 2, axis=1) output_dir = osp.dirname(output_path) try: wavfile.write(osp.join(output_dir, "audio.wav"), audio_fps, audio) except: sf.write(osp.join(output_dir, "audio.wav"), audio, audio_fps) audio_clip = AudioFileClip(osp.join(output_dir, "audio.wav")) # audio_clip = AudioArrayClip(audio, fps=audio_fps) video_clip = ImageSequenceClip(imgs, fps=video_fps) video_clip = video_clip.set_audio(audio_clip) video_clip.write_videofile(output_path, video_fps, audio=True, audio_fps=audio_fps) if remove_audio: os.remove(osp.join(output_dir, "audio.wav")) return def save_multimodal_by_frame(video, audio, output_path, audio_fps:int=16000): imgs = [img for img in video] # if audio.shape[0] == 1 or audio.shape[0] == 2: # audio = audio.T #[len, channel] # audio = np.repeat(audio, 2, axis=1) # output_dir = osp.dirname(output_path) output_dir = output_path wavfile.write(osp.join(output_dir, "audio.wav"), audio_fps, audio) audio_clip = AudioFileClip(osp.join(output_dir, "audio.wav")) # audio_clip = AudioArrayClip(audio, fps=audio_fps) os.makedirs(osp.join(output_dir, 'frames'), exist_ok=True) for num, img in enumerate(imgs): if isinstance(img, np.ndarray): img = Image.fromarray(img.astype(np.uint8)) img.save(osp.join(output_dir, 'frames', f"{num}.jpg")) return def sanity_check(data: dict, save_path: str="sanity_check", batch_size: int=4, sample_rate: int=16000): video_path = osp.join(save_path, 'video') audio_path = osp.join(save_path, 'audio') av_path = osp.join(save_path, 'av') video, audio, text = data['pixel_values'], data['audio'], data['text'] video = (video / 2 + 0.5).clamp(0, 1) zero_rank_print(f"Saving {text} audio: {audio[0].shape} video: {video[0].shape}") for bsz in range(batch_size): os.makedirs(video_path, exist_ok=True) os.makedirs(audio_path, exist_ok=True) os.makedirs(av_path, exist_ok=True) # save_videos_grid(video[bsz:bsz+1,...], f"{osp.join(video_path, str(bsz) + '.mp4')}") bsz_audio = audio[bsz,...].permute(1, 0).cpu().numpy() bsz_video = video_tensor_to_np(video[bsz, ...]) sf.write(f"{osp.join(audio_path, str(bsz) + '.wav')}", bsz_audio, sample_rate) save_multimodal(bsz_video, bsz_audio, osp.join(av_path, str(bsz) + '.mp4')) def video_tensor_to_np(video: torch.Tensor, rescale: bool=True, scale: bool=False): if scale: video = (video / 2 + 0.5).clamp(0, 1) # c f h w -> f h w c if video.shape[0] == 3: video = video.permute(1, 2, 3, 0).detach().cpu().numpy() elif video.shape[1] == 3: video = video.permute(0, 2, 3, 1).detach().cpu().numpy() if rescale: video = video * 255 return video def composite_audio_video(video: str, audio: str, path:str, video_fps:int=7, audio_sample_rate:int=16000): video = decord.VideoReader(video) audio = decord.AudioReader(audio, sample_rate=audio_sample_rate) audio = audio.get_batch(get_full_indices(audio)).asnumpy() video = video.get_batch(get_full_indices(video)).asnumpy() save_multimodal(video, audio, path, audio_fps=audio_sample_rate, video_fps=video_fps) return # for video pipeline def append_dims(x, target_dims): """Appends dimensions to the end of a tensor until it has target_dims dimensions.""" dims_to_append = target_dims - x.ndim if dims_to_append < 0: raise ValueError(f"input has {x.ndim} dims but target_dims is {target_dims}, which is less") return x[(...,) + (None,) * dims_to_append] def resize_with_antialiasing(input, size, interpolation="bicubic", align_corners=True): h, w = input.shape[-2:] factors = (h / size[0], w / size[1]) # First, we have to determine sigma # Taken from skimage: https://github.com/scikit-image/scikit-image/blob/v0.19.2/skimage/transform/_warps.py#L171 sigmas = ( max((factors[0] - 1.0) / 2.0, 0.001), max((factors[1] - 1.0) / 2.0, 0.001), ) # Now kernel size. Good results are for 3 sigma, but that is kind of slow. Pillow uses 1 sigma # https://github.com/python-pillow/Pillow/blob/master/src/libImaging/Resample.c#L206 # But they do it in the 2 passes, which gives better results. Let's try 2 sigmas for now ks = int(max(2.0 * 2 * sigmas[0], 3)), int(max(2.0 * 2 * sigmas[1], 3)) # Make sure it is odd if (ks[0] % 2) == 0: ks = ks[0] + 1, ks[1] if (ks[1] % 2) == 0: ks = ks[0], ks[1] + 1 input = _gaussian_blur2d(input, ks, sigmas) output = torch.nn.functional.interpolate(input, size=size, mode=interpolation, align_corners=align_corners) return output def _gaussian_blur2d(input, kernel_size, sigma): if isinstance(sigma, tuple): sigma = torch.tensor([sigma], dtype=input.dtype) else: sigma = sigma.to(dtype=input.dtype) ky, kx = int(kernel_size[0]), int(kernel_size[1]) bs = sigma.shape[0] kernel_x = _gaussian(kx, sigma[:, 1].view(bs, 1)) kernel_y = _gaussian(ky, sigma[:, 0].view(bs, 1)) out_x = _filter2d(input, kernel_x[..., None, :]) out = _filter2d(out_x, kernel_y[..., None]) return out def _filter2d(input, kernel): # prepare kernel b, c, h, w = input.shape tmp_kernel = kernel[:, None, ...].to(device=input.device, dtype=input.dtype) tmp_kernel = tmp_kernel.expand(-1, c, -1, -1) height, width = tmp_kernel.shape[-2:] padding_shape: list[int] = _compute_padding([height, width]) input = torch.nn.functional.pad(input, padding_shape, mode="reflect") # kernel and input tensor reshape to align element-wise or batch-wise params tmp_kernel = tmp_kernel.reshape(-1, 1, height, width) input = input.view(-1, tmp_kernel.size(0), input.size(-2), input.size(-1)) # convolve the tensor with the kernel. output = torch.nn.functional.conv2d(input, tmp_kernel, groups=tmp_kernel.size(0), padding=0, stride=1) out = output.view(b, c, h, w) return out def _gaussian(window_size: int, sigma): if isinstance(sigma, float): sigma = torch.tensor([[sigma]]) batch_size = sigma.shape[0] x = (torch.arange(window_size, device=sigma.device, dtype=sigma.dtype) - window_size // 2).expand(batch_size, -1) if window_size % 2 == 0: x = x + 0.5 gauss = torch.exp(-x.pow(2.0) / (2 * sigma.pow(2.0))) return gauss / gauss.sum(-1, keepdim=True) def _compute_padding(kernel_size): """Compute padding tuple.""" # 4 or 6 ints: (padding_left, padding_right,padding_top,padding_bottom) # https://pytorch.org/docs/stable/nn.html#torch.nn.functional.pad if len(kernel_size) < 2: raise AssertionError(kernel_size) computed = [k - 1 for k in kernel_size] # for even kernels we need to do asymmetric padding :( out_padding = 2 * len(kernel_size) * [0] for i in range(len(kernel_size)): computed_tmp = computed[-(i + 1)] pad_front = computed_tmp // 2 pad_rear = computed_tmp - pad_front out_padding[2 * i + 0] = pad_front out_padding[2 * i + 1] = pad_rear return out_padding def print_gpu_memory_usage(info: str, cuda_id:int=0): print(f">>> {info} <<<") reserved = torch.cuda.memory_reserved(cuda_id) / 1024 ** 3 used = torch.cuda.memory_allocated(cuda_id) / 1024 ** 3 print("total: ", reserved, "G") print("used: ", used, "G") print("available: ", reserved - used, "G") # use for dsp mel2spec @dataclass(frozen=True) class SpectrogramParams: """ Parameters for the conversion from audio to spectrograms to images and back. Includes helpers to convert to and from EXIF tags, allowing these parameters to be stored within spectrogram images. To understand what these parameters do and to customize them, read `spectrogram_converter.py` and the linked torchaudio documentation. """ # Whether the audio is stereo or mono stereo: bool = False # FFT parameters sample_rate: int = 44100 step_size_ms: int = 10 window_duration_ms: int = 100 padded_duration_ms: int = 400 # Mel scale parameters num_frequencies: int = 200 # TODO(hayk): Set these to [20, 20000] for newer models min_frequency: int = 0 max_frequency: int = 10000 mel_scale_norm: T.Optional[str] = None mel_scale_type: str = "htk" max_mel_iters: int = 200 # Griffin Lim parameters num_griffin_lim_iters: int = 32 # Image parameterization power_for_image: float = 0.25 class ExifTags(Enum): """ Custom EXIF tags for the spectrogram image. """ SAMPLE_RATE = 11000 STEREO = 11005 STEP_SIZE_MS = 11010 WINDOW_DURATION_MS = 11020 PADDED_DURATION_MS = 11030 NUM_FREQUENCIES = 11040 MIN_FREQUENCY = 11050 MAX_FREQUENCY = 11060 POWER_FOR_IMAGE = 11070 MAX_VALUE = 11080 @property def n_fft(self) -> int: """ The number of samples in each STFT window, with padding. """ return int(self.padded_duration_ms / 1000.0 * self.sample_rate) @property def win_length(self) -> int: """ The number of samples in each STFT window. """ return int(self.window_duration_ms / 1000.0 * self.sample_rate) @property def hop_length(self) -> int: """ The number of samples between each STFT window. """ return int(self.step_size_ms / 1000.0 * self.sample_rate) def to_exif(self) -> T.Dict[int, T.Any]: """ Return a dictionary of EXIF tags for the current values. """ return { self.ExifTags.SAMPLE_RATE.value: self.sample_rate, self.ExifTags.STEREO.value: self.stereo, self.ExifTags.STEP_SIZE_MS.value: self.step_size_ms, self.ExifTags.WINDOW_DURATION_MS.value: self.window_duration_ms, self.ExifTags.PADDED_DURATION_MS.value: self.padded_duration_ms, self.ExifTags.NUM_FREQUENCIES.value: self.num_frequencies, self.ExifTags.MIN_FREQUENCY.value: self.min_frequency, self.ExifTags.MAX_FREQUENCY.value: self.max_frequency, self.ExifTags.POWER_FOR_IMAGE.value: float(self.power_for_image), } class SpectrogramImageConverter: """ Convert between spectrogram images and audio segments. This is a wrapper around SpectrogramConverter that additionally converts from spectrograms to images and back. The real audio processing lives in SpectrogramConverter. """ def __init__(self, params: SpectrogramParams, device: str = "cuda"): self.p = params self.device = device self.converter = SpectrogramConverter(params=params, device=device) def spectrogram_image_from_audio( self, segment: pydub.AudioSegment, ) -> Image.Image: """ Compute a spectrogram image from an audio segment. Args: segment: Audio segment to convert Returns: Spectrogram image (in pillow format) """ assert int(segment.frame_rate) == self.p.sample_rate, "Sample rate mismatch" if self.p.stereo: if segment.channels == 1: print("WARNING: Mono audio but stereo=True, cloning channel") segment = segment.set_channels(2) elif segment.channels > 2: print("WARNING: Multi channel audio, reducing to stereo") segment = segment.set_channels(2) else: if segment.channels > 1: print("WARNING: Stereo audio but stereo=False, setting to mono") segment = segment.set_channels(1) spectrogram = self.converter.spectrogram_from_audio(segment) image = image_from_spectrogram( spectrogram, power=self.p.power_for_image, ) # Store conversion params in exif metadata of the image exif_data = self.p.to_exif() exif_data[SpectrogramParams.ExifTags.MAX_VALUE.value] = float(np.max(spectrogram)) exif = image.getexif() exif.update(exif_data.items()) return image def audio_from_spectrogram_image( self, image: Image.Image, apply_filters: bool = True, max_value: float = 30e6, ) -> pydub.AudioSegment: """ Reconstruct an audio segment from a spectrogram image. Args: image: Spectrogram image (in pillow format) apply_filters: Apply post-processing to improve the reconstructed audio max_value: Scaled max amplitude of the spectrogram. Shouldn't matter. """ spectrogram = spectrogram_from_image( image, max_value=max_value, power=self.p.power_for_image, stereo=self.p.stereo, ) segment = self.converter.audio_from_spectrogram( spectrogram, apply_filters=apply_filters, ) return segment def image_from_spectrogram(spectrogram: np.ndarray, power: float = 0.25) -> Image.Image: """ Compute a spectrogram image from a spectrogram magnitude array. This is the inverse of spectrogram_from_image, except for discretization error from quantizing to uint8. Args: spectrogram: (channels, frequency, time) power: A power curve to apply to the spectrogram to preserve contrast Returns: image: (frequency, time, channels) """ # Rescale to 0-1 max_value = np.max(spectrogram) data = spectrogram / max_value # Apply the power curve data = np.power(data, power) # Rescale to 0-255 data = data * 255 # Invert data = 255 - data # Convert to uint8 data = data.astype(np.uint8) # Munge channels into a PIL image if data.shape[0] == 1: # TODO(hayk): Do we want to write single channel to disk instead? image = Image.fromarray(data[0], mode="L").convert("RGB") elif data.shape[0] == 2: data = np.array([np.zeros_like(data[0]), data[0], data[1]]).transpose(1, 2, 0) image = Image.fromarray(data, mode="RGB") else: raise NotImplementedError(f"Unsupported number of channels: {data.shape[0]}") # Flip Y image = image.transpose(Image.Transpose.FLIP_TOP_BOTTOM) return image def spectrogram_from_image( image: Image.Image, power: float = 0.25, stereo: bool = False, max_value: float = 30e6, ) -> np.ndarray: """ Compute a spectrogram magnitude array from a spectrogram image. This is the inverse of image_from_spectrogram, except for discretization error from quantizing to uint8. Args: image: (frequency, time, channels) power: The power curve applied to the spectrogram stereo: Whether the spectrogram encodes stereo data max_value: The max value of the original spectrogram. In practice doesn't matter. Returns: spectrogram: (channels, frequency, time) """ # Convert to RGB if single channel if image.mode in ("P", "L"): image = image.convert("RGB") # Flip Y image = image.transpose(Image.Transpose.FLIP_TOP_BOTTOM) # Munge channels into a numpy array of (channels, frequency, time) data = np.array(image).transpose(2, 0, 1) if stereo: # Take the G and B channels as done in image_from_spectrogram data = data[[1, 2], :, :] else: data = data[0:1, :, :] # Convert to floats data = data.astype(np.float32) # Invert data = 255 - data # Rescale to 0-1 data = data / 255 # Reverse the power curve data = np.power(data, 1 / power) # Rescale to max value data = data * max_value return data class SpectrogramConverter: """ Convert between audio segments and spectrogram tensors using torchaudio. In this class a "spectrogram" is defined as a (batch, time, frequency) tensor with float values that represent the amplitude of the frequency at that time bucket (in the frequency domain). Frequencies are given in the perceptul Mel scale defined by the params. A more specific term used in some functions is "mel amplitudes". The spectrogram computed from `spectrogram_from_audio` is complex valued, but it only returns the amplitude, because the phase is chaotic and hard to learn. The function `audio_from_spectrogram` is an approximate inverse of `spectrogram_from_audio`, which approximates the phase information using the Griffin-Lim algorithm. Each channel in the audio is treated independently, and the spectrogram has a batch dimension equal to the number of channels in the input audio segment. Both the Griffin Lim algorithm and the Mel scaling process are lossy. For more information, see https://pytorch.org/audio/stable/transforms.html """ def __init__(self, params: SpectrogramParams, device: str = "cuda"): self.p = params self.device = check_device(device) if device.lower().startswith("mps"): warnings.warn( "WARNING: MPS does not support audio operations, falling back to CPU for them", stacklevel=2, ) self.device = "cpu" # https://pytorch.org/audio/stable/generated/torchaudio.transforms.Spectrogram.html self.spectrogram_func = torchaudio.transforms.Spectrogram( n_fft=params.n_fft, hop_length=params.hop_length, win_length=params.win_length, pad=0, window_fn=torch.hann_window, power=None, normalized=False, wkwargs=None, center=True, pad_mode="reflect", onesided=True, ).to(self.device) # https://pytorch.org/audio/stable/generated/torchaudio.transforms.GriffinLim.html self.inverse_spectrogram_func = torchaudio.transforms.GriffinLim( n_fft=params.n_fft, n_iter=params.num_griffin_lim_iters, win_length=params.win_length, hop_length=params.hop_length, window_fn=torch.hann_window, power=1.0, wkwargs=None, momentum=0.99, length=None, rand_init=True, ).to(self.device) # https://pytorch.org/audio/stable/generated/torchaudio.transforms.MelScale.html self.mel_scaler = torchaudio.transforms.MelScale( n_mels=params.num_frequencies, sample_rate=params.sample_rate, f_min=params.min_frequency, f_max=params.max_frequency, n_stft=params.n_fft // 2 + 1, norm=params.mel_scale_norm, mel_scale=params.mel_scale_type, ).to(self.device) # https://pytorch.org/audio/stable/generated/torchaudio.transforms.InverseMelScale.html self.inverse_mel_scaler = torchaudio.transforms.InverseMelScale( n_stft=params.n_fft // 2 + 1, n_mels=params.num_frequencies, sample_rate=params.sample_rate, f_min=params.min_frequency, f_max=params.max_frequency, # max_iter=params.max_mel_iters, # for higher verson of torchaudio # tolerance_loss=1e-5, # for higher verson of torchaudio # tolerance_change=1e-8, # for higher verson of torchaudio # sgdargs=None, # for higher verson of torchaudio norm=params.mel_scale_norm, mel_scale=params.mel_scale_type, ).to(self.device) def spectrogram_from_audio( self, audio: pydub.AudioSegment, ) -> np.ndarray: """ Compute a spectrogram from an audio segment. Args: audio: Audio segment which must match the sample rate of the params Returns: spectrogram: (channel, frequency, time) """ assert int(audio.frame_rate) == self.p.sample_rate, "Audio sample rate must match params" # Get the samples as a numpy array in (batch, samples) shape waveform = np.array([c.get_array_of_samples() for c in audio.split_to_mono()]) # Convert to floats if necessary if waveform.dtype != np.float32: waveform = waveform.astype(np.float32) waveform_tensor = torch.from_numpy(waveform).to(self.device) amplitudes_mel = self.mel_amplitudes_from_waveform(waveform_tensor) return amplitudes_mel.cpu().numpy() def audio_from_spectrogram( self, spectrogram: np.ndarray, apply_filters: bool = True, ) -> pydub.AudioSegment: """ Reconstruct an audio segment from a spectrogram. Args: spectrogram: (batch, frequency, time) apply_filters: Post-process with normalization and compression Returns: audio: Audio segment with channels equal to the batch dimension """ # Move to device amplitudes_mel = torch.from_numpy(spectrogram).to(self.device) # Reconstruct the waveform waveform = self.waveform_from_mel_amplitudes(amplitudes_mel) # Convert to audio segment segment = audio_from_waveform( samples=waveform.cpu().numpy(), sample_rate=self.p.sample_rate, # Normalize the waveform to the range [-1, 1] normalize=True, ) # Optionally apply post-processing filters if apply_filters: segment = apply_filters_func( segment, compression=False, ) return segment def mel_amplitudes_from_waveform( self, waveform: torch.Tensor, ) -> torch.Tensor: """ Torch-only function to compute Mel-scale amplitudes from a waveform. Args: waveform: (batch, samples) Returns: amplitudes_mel: (batch, frequency, time) """ # Compute the complex-valued spectrogram spectrogram_complex = self.spectrogram_func(waveform) # Take the magnitude amplitudes = torch.abs(spectrogram_complex) # Convert to mel scale return self.mel_scaler(amplitudes) def waveform_from_mel_amplitudes( self, amplitudes_mel: torch.Tensor, ) -> torch.Tensor: """ Torch-only function to approximately reconstruct a waveform from Mel-scale amplitudes. Args: amplitudes_mel: (batch, frequency, time) Returns: waveform: (batch, samples) """ # Convert from mel scale to linear amplitudes_linear = self.inverse_mel_scaler(amplitudes_mel) # Run the approximate algorithm to compute the phase and recover the waveform return self.inverse_spectrogram_func(amplitudes_linear) def check_device(device: str, backup: str = "cpu") -> str: """ Check that the device is valid and available. If not, """ cuda_not_found = device.lower().startswith("cuda") and not torch.cuda.is_available() mps_not_found = device.lower().startswith("mps") and not torch.backends.mps.is_available() if cuda_not_found or mps_not_found: warnings.warn(f"WARNING: {device} is not available, using {backup} instead.", stacklevel=3) return backup return device def audio_from_waveform( samples: np.ndarray, sample_rate: int, normalize: bool = False ) -> pydub.AudioSegment: """ Convert a numpy array of samples of a waveform to an audio segment. Args: samples: (channels, samples) array """ # Normalize volume to fit in int16 if normalize: samples *= np.iinfo(np.int16).max / np.max(np.abs(samples)) # Transpose and convert to int16 samples = samples.transpose(1, 0) samples = samples.astype(np.int16) # Write to the bytes of a WAV file wav_bytes = io.BytesIO() wavfile.write(wav_bytes, sample_rate, samples) wav_bytes.seek(0) # Read into pydub return pydub.AudioSegment.from_wav(wav_bytes) def apply_filters_func(segment: pydub.AudioSegment, compression: bool = False) -> pydub.AudioSegment: """ Apply post-processing filters to the audio segment to compress it and keep at a -10 dBFS level. """ # TODO(hayk): Come up with a principled strategy for these filters and experiment end-to-end. # TODO(hayk): Is this going to make audio unbalanced between sequential clips? if compression: segment = pydub.effects.normalize( segment, headroom=0.1, ) segment = segment.apply_gain(-10 - segment.dBFS) # TODO(hayk): This is quite slow, ~1.7 seconds on a beefy CPU segment = pydub.effects.compress_dynamic_range( segment, threshold=-20.0, ratio=4.0, attack=5.0, release=50.0, ) desired_db = -12 segment = segment.apply_gain(desired_db - segment.dBFS) segment = pydub.effects.normalize( segment, headroom=0.1, ) return segment def shave_segments(path, n_shave_prefix_segments=1): """ Removes segments. Positive values shave the first segments, negative shave the last segments. """ if n_shave_prefix_segments >= 0: return ".".join(path.split(".")[n_shave_prefix_segments:]) else: return ".".join(path.split(".")[:n_shave_prefix_segments]) def renew_resnet_paths(old_list, n_shave_prefix_segments=0): """ Updates paths inside resnets to the new naming scheme (local renaming) """ mapping = [] for old_item in old_list: new_item = old_item.replace("in_layers.0", "norm1") new_item = new_item.replace("in_layers.2", "conv1") new_item = new_item.replace("out_layers.0", "norm2") new_item = new_item.replace("out_layers.3", "conv2") new_item = new_item.replace("emb_layers.1", "time_emb_proj") new_item = new_item.replace("skip_connection", "conv_shortcut") new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments) mapping.append({"old": old_item, "new": new_item}) return mapping def renew_vae_resnet_paths(old_list, n_shave_prefix_segments=0): """ Updates paths inside resnets to the new naming scheme (local renaming) """ mapping = [] for old_item in old_list: new_item = old_item new_item = new_item.replace("nin_shortcut", "conv_shortcut") new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments) mapping.append({"old": old_item, "new": new_item}) return mapping def renew_attention_paths(old_list, n_shave_prefix_segments=0): """ Updates paths inside attentions to the new naming scheme (local renaming) """ mapping = [] for old_item in old_list: new_item = old_item # new_item = new_item.replace('norm.weight', 'group_norm.weight') # new_item = new_item.replace('norm.bias', 'group_norm.bias') # new_item = new_item.replace('proj_out.weight', 'proj_attn.weight') # new_item = new_item.replace('proj_out.bias', 'proj_attn.bias') # new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments) mapping.append({"old": old_item, "new": new_item}) return mapping def renew_vae_attention_paths(old_list, n_shave_prefix_segments=0): """ Updates paths inside attentions to the new naming scheme (local renaming) """ mapping = [] for old_item in old_list: new_item = old_item new_item = new_item.replace("norm.weight", "group_norm.weight") new_item = new_item.replace("norm.bias", "group_norm.bias") new_item = new_item.replace("q.weight", "to_q.weight") new_item = new_item.replace("q.bias", "to_q.bias") new_item = new_item.replace("k.weight", "to_k.weight") new_item = new_item.replace("k.bias", "to_k.bias") new_item = new_item.replace("v.weight", "to_v.weight") new_item = new_item.replace("v.bias", "to_v.bias") new_item = new_item.replace("proj_out.weight", "to_out.0.weight") new_item = new_item.replace("proj_out.bias", "to_out.0.bias") new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments) mapping.append({"old": old_item, "new": new_item}) return mapping def assign_to_checkpoint( paths, checkpoint, old_checkpoint, attention_paths_to_split=None, additional_replacements=None, config=None ): """ This does the final conversion step: take locally converted weights and apply a global renaming to them. It splits attention layers, and takes into account additional replacements that may arise. Assigns the weights to the new checkpoint. """ assert isinstance(paths, list), "Paths should be a list of dicts containing 'old' and 'new' keys." # Splits the attention layers into three variables. if attention_paths_to_split is not None: for path, path_map in attention_paths_to_split.items(): old_tensor = old_checkpoint[path] channels = old_tensor.shape[0] // 3 target_shape = (-1, channels) if len(old_tensor.shape) == 3 else (-1) num_heads = old_tensor.shape[0] // config["num_head_channels"] // 3 old_tensor = old_tensor.reshape((num_heads, 3 * channels // num_heads) + old_tensor.shape[1:]) query, key, value = old_tensor.split(channels // num_heads, dim=1) checkpoint[path_map["query"]] = query.reshape(target_shape) checkpoint[path_map["key"]] = key.reshape(target_shape) checkpoint[path_map["value"]] = value.reshape(target_shape) for path in paths: new_path = path["new"] # These have already been assigned if attention_paths_to_split is not None and new_path in attention_paths_to_split: continue # Global renaming happens here new_path = new_path.replace("middle_block.0", "mid_block.resnets.0") new_path = new_path.replace("middle_block.1", "mid_block.attentions.0") new_path = new_path.replace("middle_block.2", "mid_block.resnets.1") if additional_replacements is not None: for replacement in additional_replacements: new_path = new_path.replace(replacement["old"], replacement["new"]) # proj_attn.weight has to be converted from conv 1D to linear if "proj_attn.weight" in new_path: checkpoint[new_path] = old_checkpoint[path["old"]][:, :, 0] elif 'to_out.0.weight' in new_path: checkpoint[new_path] = old_checkpoint[path['old']].squeeze() elif any([qkv in new_path for qkv in ['to_q', 'to_k', 'to_v']]): checkpoint[new_path] = old_checkpoint[path['old']].squeeze() else: checkpoint[new_path] = old_checkpoint[path["old"]] def conv_attn_to_linear(checkpoint): keys = list(checkpoint.keys()) attn_keys = ["query.weight", "key.weight", "value.weight"] for key in keys: if ".".join(key.split(".")[-2:]) in attn_keys: if checkpoint[key].ndim > 2: checkpoint[key] = checkpoint[key][:, :, 0, 0] elif "proj_attn.weight" in key: if checkpoint[key].ndim > 2: checkpoint[key] = checkpoint[key][:, :, 0] def create_unet_diffusers_config(original_config, image_size: int, controlnet=False): """ Creates a config for the diffusers based on the config of the LDM model. """ if controlnet: unet_params = original_config.model.params.control_stage_config.params else: unet_params = original_config.model.params.unet_config.params vae_params = original_config.model.params.first_stage_config.params.ddconfig block_out_channels = [unet_params.model_channels * mult for mult in unet_params.channel_mult] down_block_types = [] resolution = 1 for i in range(len(block_out_channels)): block_type = "CrossAttnDownBlock2D" if resolution in unet_params.attention_resolutions else "DownBlock2D" down_block_types.append(block_type) if i != len(block_out_channels) - 1: resolution *= 2 up_block_types = [] for i in range(len(block_out_channels)): block_type = "CrossAttnUpBlock2D" if resolution in unet_params.attention_resolutions else "UpBlock2D" up_block_types.append(block_type) resolution //= 2 vae_scale_factor = 2 ** (len(vae_params.ch_mult) - 1) head_dim = unet_params.num_heads if "num_heads" in unet_params else None use_linear_projection = ( unet_params.use_linear_in_transformer if "use_linear_in_transformer" in unet_params else False ) if use_linear_projection: # stable diffusion 2-base-512 and 2-768 if head_dim is None: head_dim = [5, 10, 20, 20] class_embed_type = None projection_class_embeddings_input_dim = None if "num_classes" in unet_params: if unet_params.num_classes == "sequential": class_embed_type = "projection" assert "adm_in_channels" in unet_params projection_class_embeddings_input_dim = unet_params.adm_in_channels else: raise NotImplementedError(f"Unknown conditional unet num_classes config: {unet_params.num_classes}") config = { "sample_size": image_size // vae_scale_factor, "in_channels": unet_params.in_channels, "down_block_types": tuple(down_block_types), "block_out_channels": tuple(block_out_channels), "layers_per_block": unet_params.num_res_blocks, "cross_attention_dim": unet_params.context_dim, "attention_head_dim": head_dim, "use_linear_projection": use_linear_projection, "class_embed_type": class_embed_type, "projection_class_embeddings_input_dim": projection_class_embeddings_input_dim, } if not controlnet: config["out_channels"] = unet_params.out_channels config["up_block_types"] = tuple(up_block_types) return config def create_vae_diffusers_config(original_config, image_size: int): """ Creates a config for the diffusers based on the config of the LDM model. """ vae_params = original_config.model.params.first_stage_config.params.ddconfig _ = original_config.model.params.first_stage_config.params.embed_dim block_out_channels = [vae_params.ch * mult for mult in vae_params.ch_mult] down_block_types = ["DownEncoderBlock2D"] * len(block_out_channels) up_block_types = ["UpDecoderBlock2D"] * len(block_out_channels) config = { "sample_size": image_size, "in_channels": vae_params.in_channels, "out_channels": vae_params.out_ch, "down_block_types": tuple(down_block_types), "up_block_types": tuple(up_block_types), "block_out_channels": tuple(block_out_channels), "latent_channels": vae_params.z_channels, "layers_per_block": vae_params.num_res_blocks, } return config def create_diffusers_schedular(original_config): schedular = DDIMScheduler( num_train_timesteps=original_config.model.params.timesteps, beta_start=original_config.model.params.linear_start, beta_end=original_config.model.params.linear_end, beta_schedule="scaled_linear", ) return schedular def convert_ldm_unet_checkpoint(checkpoint, config, path=None, extract_ema=False, controlnet=False): """ Takes a state dict and a config, and returns a converted checkpoint. """ # extract state_dict for UNet unet_state_dict = {} keys = list(checkpoint.keys()) if controlnet: unet_key = "control_model." else: unet_key = "model.diffusion_model." # at least a 100 parameters have to start with `model_ema` in order for the checkpoint to be EMA if sum(k.startswith("model_ema") for k in keys) > 100 and extract_ema: print(f"Checkpoint {path} has both EMA and non-EMA weights.") print( "In this conversion only the EMA weights are extracted. If you want to instead extract the non-EMA" " weights (useful to continue fine-tuning), please make sure to remove the `--extract_ema` flag." ) for key in keys: if key.startswith("model.diffusion_model"): flat_ema_key = "model_ema." + "".join(key.split(".")[1:]) unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(flat_ema_key) else: if sum(k.startswith("model_ema") for k in keys) > 100: print( "In this conversion only the non-EMA weights are extracted. If you want to instead extract the EMA" " weights (usually better for inference), please make sure to add the `--extract_ema` flag." ) for key in keys: if key.startswith(unet_key): unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(key) new_checkpoint = {} new_checkpoint["time_embedding.linear_1.weight"] = unet_state_dict["time_embed.0.weight"] new_checkpoint["time_embedding.linear_1.bias"] = unet_state_dict["time_embed.0.bias"] new_checkpoint["time_embedding.linear_2.weight"] = unet_state_dict["time_embed.2.weight"] new_checkpoint["time_embedding.linear_2.bias"] = unet_state_dict["time_embed.2.bias"] if config["class_embed_type"] is None: # No parameters to port ... elif config["class_embed_type"] == "timestep" or config["class_embed_type"] == "projection": new_checkpoint["class_embedding.linear_1.weight"] = unet_state_dict["label_emb.0.0.weight"] new_checkpoint["class_embedding.linear_1.bias"] = unet_state_dict["label_emb.0.0.bias"] new_checkpoint["class_embedding.linear_2.weight"] = unet_state_dict["label_emb.0.2.weight"] new_checkpoint["class_embedding.linear_2.bias"] = unet_state_dict["label_emb.0.2.bias"] else: raise NotImplementedError(f"Not implemented `class_embed_type`: {config['class_embed_type']}") new_checkpoint["conv_in.weight"] = unet_state_dict["input_blocks.0.0.weight"] new_checkpoint["conv_in.bias"] = unet_state_dict["input_blocks.0.0.bias"] if not controlnet: new_checkpoint["conv_norm_out.weight"] = unet_state_dict["out.0.weight"] new_checkpoint["conv_norm_out.bias"] = unet_state_dict["out.0.bias"] new_checkpoint["conv_out.weight"] = unet_state_dict["out.2.weight"] new_checkpoint["conv_out.bias"] = unet_state_dict["out.2.bias"] # Retrieves the keys for the input blocks only num_input_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "input_blocks" in layer}) input_blocks = { layer_id: [key for key in unet_state_dict if f"input_blocks.{layer_id}" in key] for layer_id in range(num_input_blocks) } # Retrieves the keys for the middle blocks only num_middle_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "middle_block" in layer}) middle_blocks = { layer_id: [key for key in unet_state_dict if f"middle_block.{layer_id}" in key] for layer_id in range(num_middle_blocks) } # Retrieves the keys for the output blocks only num_output_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "output_blocks" in layer}) output_blocks = { layer_id: [key for key in unet_state_dict if f"output_blocks.{layer_id}" in key] for layer_id in range(num_output_blocks) } for i in range(1, num_input_blocks): block_id = (i - 1) // (config["layers_per_block"] + 1) layer_in_block_id = (i - 1) % (config["layers_per_block"] + 1) resnets = [ key for key in input_blocks[i] if f"input_blocks.{i}.0" in key and f"input_blocks.{i}.0.op" not in key ] attentions = [key for key in input_blocks[i] if f"input_blocks.{i}.1" in key] if f"input_blocks.{i}.0.op.weight" in unet_state_dict: new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.weight"] = unet_state_dict.pop( f"input_blocks.{i}.0.op.weight" ) new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.bias"] = unet_state_dict.pop( f"input_blocks.{i}.0.op.bias" ) paths = renew_resnet_paths(resnets) meta_path = {"old": f"input_blocks.{i}.0", "new": f"down_blocks.{block_id}.resnets.{layer_in_block_id}"} assign_to_checkpoint( paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config ) if len(attentions): paths = renew_attention_paths(attentions) meta_path = {"old": f"input_blocks.{i}.1", "new": f"down_blocks.{block_id}.attentions.{layer_in_block_id}"} assign_to_checkpoint( paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config ) resnet_0 = middle_blocks[0] attentions = middle_blocks[1] resnet_1 = middle_blocks[2] resnet_0_paths = renew_resnet_paths(resnet_0) assign_to_checkpoint(resnet_0_paths, new_checkpoint, unet_state_dict, config=config) resnet_1_paths = renew_resnet_paths(resnet_1) assign_to_checkpoint(resnet_1_paths, new_checkpoint, unet_state_dict, config=config) attentions_paths = renew_attention_paths(attentions) meta_path = {"old": "middle_block.1", "new": "mid_block.attentions.0"} assign_to_checkpoint( attentions_paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config ) for i in range(num_output_blocks): block_id = i // (config["layers_per_block"] + 1) layer_in_block_id = i % (config["layers_per_block"] + 1) output_block_layers = [shave_segments(name, 2) for name in output_blocks[i]] output_block_list = {} for layer in output_block_layers: layer_id, layer_name = layer.split(".")[0], shave_segments(layer, 1) if layer_id in output_block_list: output_block_list[layer_id].append(layer_name) else: output_block_list[layer_id] = [layer_name] if len(output_block_list) > 1: resnets = [key for key in output_blocks[i] if f"output_blocks.{i}.0" in key] attentions = [key for key in output_blocks[i] if f"output_blocks.{i}.1" in key] resnet_0_paths = renew_resnet_paths(resnets) paths = renew_resnet_paths(resnets) meta_path = {"old": f"output_blocks.{i}.0", "new": f"up_blocks.{block_id}.resnets.{layer_in_block_id}"} assign_to_checkpoint( paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config ) output_block_list = {k: sorted(v) for k, v in output_block_list.items()} if ["conv.bias", "conv.weight"] in output_block_list.values(): index = list(output_block_list.values()).index(["conv.bias", "conv.weight"]) new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.weight"] = unet_state_dict[ f"output_blocks.{i}.{index}.conv.weight" ] new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.bias"] = unet_state_dict[ f"output_blocks.{i}.{index}.conv.bias" ] # Clear attentions as they have been attributed above. if len(attentions) == 2: attentions = [] if len(attentions): paths = renew_attention_paths(attentions) meta_path = { "old": f"output_blocks.{i}.1", "new": f"up_blocks.{block_id}.attentions.{layer_in_block_id}", } assign_to_checkpoint( paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config ) else: resnet_0_paths = renew_resnet_paths(output_block_layers, n_shave_prefix_segments=1) for path in resnet_0_paths: old_path = ".".join(["output_blocks", str(i), path["old"]]) new_path = ".".join(["up_blocks", str(block_id), "resnets", str(layer_in_block_id), path["new"]]) new_checkpoint[new_path] = unet_state_dict[old_path] if controlnet: # conditioning embedding orig_index = 0 new_checkpoint["controlnet_cond_embedding.conv_in.weight"] = unet_state_dict.pop( f"input_hint_block.{orig_index}.weight" ) new_checkpoint["controlnet_cond_embedding.conv_in.bias"] = unet_state_dict.pop( f"input_hint_block.{orig_index}.bias" ) orig_index += 2 diffusers_index = 0 while diffusers_index < 6: new_checkpoint[f"controlnet_cond_embedding.blocks.{diffusers_index}.weight"] = unet_state_dict.pop( f"input_hint_block.{orig_index}.weight" ) new_checkpoint[f"controlnet_cond_embedding.blocks.{diffusers_index}.bias"] = unet_state_dict.pop( f"input_hint_block.{orig_index}.bias" ) diffusers_index += 1 orig_index += 2 new_checkpoint["controlnet_cond_embedding.conv_out.weight"] = unet_state_dict.pop( f"input_hint_block.{orig_index}.weight" ) new_checkpoint["controlnet_cond_embedding.conv_out.bias"] = unet_state_dict.pop( f"input_hint_block.{orig_index}.bias" ) # down blocks for i in range(num_input_blocks): new_checkpoint[f"controlnet_down_blocks.{i}.weight"] = unet_state_dict.pop(f"zero_convs.{i}.0.weight") new_checkpoint[f"controlnet_down_blocks.{i}.bias"] = unet_state_dict.pop(f"zero_convs.{i}.0.bias") # mid block new_checkpoint["controlnet_mid_block.weight"] = unet_state_dict.pop("middle_block_out.0.weight") new_checkpoint["controlnet_mid_block.bias"] = unet_state_dict.pop("middle_block_out.0.bias") return new_checkpoint def convert_ldm_vae_checkpoint(checkpoint, config, only_decoder=False, only_encoder=False): # extract state dict for VAE vae_state_dict = {} vae_key = "first_stage_model." keys = list(checkpoint.keys()) for key in keys: if key.startswith(vae_key): vae_state_dict[key.replace(vae_key, "")] = checkpoint.get(key) new_checkpoint = {} new_checkpoint["encoder.conv_in.weight"] = vae_state_dict["encoder.conv_in.weight"] new_checkpoint["encoder.conv_in.bias"] = vae_state_dict["encoder.conv_in.bias"] new_checkpoint["encoder.conv_out.weight"] = vae_state_dict["encoder.conv_out.weight"] new_checkpoint["encoder.conv_out.bias"] = vae_state_dict["encoder.conv_out.bias"] new_checkpoint["encoder.conv_norm_out.weight"] = vae_state_dict["encoder.norm_out.weight"] new_checkpoint["encoder.conv_norm_out.bias"] = vae_state_dict["encoder.norm_out.bias"] new_checkpoint["decoder.conv_in.weight"] = vae_state_dict["decoder.conv_in.weight"] new_checkpoint["decoder.conv_in.bias"] = vae_state_dict["decoder.conv_in.bias"] new_checkpoint["decoder.conv_out.weight"] = vae_state_dict["decoder.conv_out.weight"] new_checkpoint["decoder.conv_out.bias"] = vae_state_dict["decoder.conv_out.bias"] new_checkpoint["decoder.conv_norm_out.weight"] = vae_state_dict["decoder.norm_out.weight"] new_checkpoint["decoder.conv_norm_out.bias"] = vae_state_dict["decoder.norm_out.bias"] new_checkpoint["quant_conv.weight"] = vae_state_dict["quant_conv.weight"] new_checkpoint["quant_conv.bias"] = vae_state_dict["quant_conv.bias"] new_checkpoint["post_quant_conv.weight"] = vae_state_dict["post_quant_conv.weight"] new_checkpoint["post_quant_conv.bias"] = vae_state_dict["post_quant_conv.bias"] # Retrieves the keys for the encoder down blocks only num_down_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "encoder.down" in layer}) down_blocks = { layer_id: [key for key in vae_state_dict if f"down.{layer_id}" in key] for layer_id in range(num_down_blocks) } # Retrieves the keys for the decoder up blocks only num_up_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "decoder.up" in layer}) up_blocks = { layer_id: [key for key in vae_state_dict if f"up.{layer_id}" in key] for layer_id in range(num_up_blocks) } for i in range(num_down_blocks): resnets = [key for key in down_blocks[i] if f"down.{i}" in key and f"down.{i}.downsample" not in key] if f"encoder.down.{i}.downsample.conv.weight" in vae_state_dict: new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.weight"] = vae_state_dict.pop( f"encoder.down.{i}.downsample.conv.weight" ) new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.bias"] = vae_state_dict.pop( f"encoder.down.{i}.downsample.conv.bias" ) paths = renew_vae_resnet_paths(resnets) meta_path = {"old": f"down.{i}.block", "new": f"down_blocks.{i}.resnets"} assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) mid_resnets = [key for key in vae_state_dict if "encoder.mid.block" in key] num_mid_res_blocks = 2 for i in range(1, num_mid_res_blocks + 1): resnets = [key for key in mid_resnets if f"encoder.mid.block_{i}" in key] paths = renew_vae_resnet_paths(resnets) meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"} assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) mid_attentions = [key for key in vae_state_dict if "encoder.mid.attn" in key] paths = renew_vae_attention_paths(mid_attentions) meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"} assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) conv_attn_to_linear(new_checkpoint) for i in range(num_up_blocks): block_id = num_up_blocks - 1 - i resnets = [ key for key in up_blocks[block_id] if f"up.{block_id}" in key and f"up.{block_id}.upsample" not in key ] if f"decoder.up.{block_id}.upsample.conv.weight" in vae_state_dict: new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.weight"] = vae_state_dict[ f"decoder.up.{block_id}.upsample.conv.weight" ] new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.bias"] = vae_state_dict[ f"decoder.up.{block_id}.upsample.conv.bias" ] paths = renew_vae_resnet_paths(resnets) meta_path = {"old": f"up.{block_id}.block", "new": f"up_blocks.{i}.resnets"} assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) mid_resnets = [key for key in vae_state_dict if "decoder.mid.block" in key] num_mid_res_blocks = 2 for i in range(1, num_mid_res_blocks + 1): resnets = [key for key in mid_resnets if f"decoder.mid.block_{i}" in key] paths = renew_vae_resnet_paths(resnets) meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"} assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) mid_attentions = [key for key in vae_state_dict if "decoder.mid.attn" in key] paths = renew_vae_attention_paths(mid_attentions) meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"} assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) conv_attn_to_linear(new_checkpoint) if only_decoder: new_checkpoint = {k: v for k, v in new_checkpoint.items() if k.startswith('decoder') or k.startswith('post_quant')} elif only_encoder: new_checkpoint = {k: v for k, v in new_checkpoint.items() if k.startswith('encoder') or k.startswith('quant')} return new_checkpoint def convert_ldm_clip_checkpoint(checkpoint): keys = list(checkpoint.keys()) text_model_dict = {} for key in keys: if key.startswith("cond_stage_model.transformer"): text_model_dict[key[len("cond_stage_model.transformer.") :]] = checkpoint[key] return text_model_dict def convert_lora_model_level(state_dict, unet, text_encoder=None, LORA_PREFIX_UNET="lora_unet", LORA_PREFIX_TEXT_ENCODER="lora_te", alpha=0.6): """convert lora in model level instead of pipeline leval """ visited = [] # directly update weight in diffusers model for key in state_dict: # it is suggested to print out the key, it usually will be something like below # "lora_te_text_model_encoder_layers_0_self_attn_k_proj.lora_down.weight" # as we have set the alpha beforehand, so just skip if ".alpha" in key or key in visited: continue if "text" in key: layer_infos = key.split(".")[0].split(LORA_PREFIX_TEXT_ENCODER + "_")[-1].split("_") assert text_encoder is not None, ( 'text_encoder must be passed since lora contains text encoder layers') curr_layer = text_encoder else: layer_infos = key.split(".")[0].split(LORA_PREFIX_UNET + "_")[-1].split("_") curr_layer = unet # find the target layer temp_name = layer_infos.pop(0) while len(layer_infos) > -1: try: curr_layer = curr_layer.__getattr__(temp_name) if len(layer_infos) > 0: temp_name = layer_infos.pop(0) elif len(layer_infos) == 0: break except Exception: if len(temp_name) > 0: temp_name += "_" + layer_infos.pop(0) else: temp_name = layer_infos.pop(0) pair_keys = [] if "lora_down" in key: pair_keys.append(key.replace("lora_down", "lora_up")) pair_keys.append(key) else: pair_keys.append(key) pair_keys.append(key.replace("lora_up", "lora_down")) # update weight # NOTE: load lycon, meybe have bugs :( if 'conv_in' in pair_keys[0]: weight_up = state_dict[pair_keys[0]].to(torch.float32) weight_down = state_dict[pair_keys[1]].to(torch.float32) weight_up = weight_up.view(weight_up.size(0), -1) weight_down = weight_down.view(weight_down.size(0), -1) shape = [e for e in curr_layer.weight.data.shape] shape[1] = 4 curr_layer.weight.data[:, :4, ...] += alpha * (weight_up @ weight_down).view(*shape) elif 'conv' in pair_keys[0]: weight_up = state_dict[pair_keys[0]].to(torch.float32) weight_down = state_dict[pair_keys[1]].to(torch.float32) weight_up = weight_up.view(weight_up.size(0), -1) weight_down = weight_down.view(weight_down.size(0), -1) shape = [e for e in curr_layer.weight.data.shape] curr_layer.weight.data += alpha * (weight_up @ weight_down).view(*shape) elif len(state_dict[pair_keys[0]].shape) == 4: weight_up = state_dict[pair_keys[0]].squeeze(3).squeeze(2).to(torch.float32) weight_down = state_dict[pair_keys[1]].squeeze(3).squeeze(2).to(torch.float32) curr_layer.weight.data += alpha * torch.mm(weight_up, weight_down).unsqueeze(2).unsqueeze(3).to(curr_layer.weight.data.device) else: weight_up = state_dict[pair_keys[0]].to(torch.float32) weight_down = state_dict[pair_keys[1]].to(torch.float32) curr_layer.weight.data += alpha * torch.mm(weight_up, weight_down).to(curr_layer.weight.data.device) # update visited list for item in pair_keys: visited.append(item) return unet, text_encoder def denormalize_spectrogram( data: torch.Tensor, max_value: float = 200, min_value: float = 1e-5, power: float = 1, inverse: bool = False, ) -> torch.Tensor: max_value = np.log(max_value) min_value = np.log(min_value) # Flip Y axis: image origin at the top-left corner, spectrogram origin at the bottom-left corner data = torch.flip(data, [1]) assert len(data.shape) == 3, "Expected 3 dimensions, got {}".format(len(data.shape)) if data.shape[0] == 1: data = data.repeat(3, 1, 1) assert data.shape[0] == 3, "Expected 3 channels, got {}".format(data.shape[0]) data = data[0] # Reverse the power curve data = torch.pow(data, 1 / power) # Invert if inverse: data = 1 - data # Rescale to max value spectrogram = data * (max_value - min_value) + min_value return spectrogram class ToTensor1D(torchvision.transforms.ToTensor): def __call__(self, tensor: np.ndarray): tensor_2d = super(ToTensor1D, self).__call__(tensor[..., np.newaxis]) return tensor_2d.squeeze_(0) def scale(old_value, old_min, old_max, new_min, new_max): old_range = (old_max - old_min) new_range = (new_max - new_min) new_value = (((old_value - old_min) * new_range) / old_range) + new_min return new_value def read_frames_with_moviepy(video_path, max_frame_nums=None): clip = VideoFileClip(video_path) duration = clip.duration frames = [] for frame in clip.iter_frames(): frames.append(frame) if max_frame_nums is not None: frames_idx = np.linspace(0, len(frames) - 1, max_frame_nums, dtype=int) return np.array(frames)[frames_idx,...], duration def read_frames_with_moviepy_resample(video_path, save_path): vision_transform_list = [ transforms.Resize((128, 128)), transforms.CenterCrop((112, 112)), transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) ] video_transform = transforms.Compose(vision_transform_list) os.makedirs(save_path, exist_ok=True) command = f'ffmpeg -v quiet -y -i \"{video_path}\" -f image2 -vf \"scale=-1:360,fps=15\" -qscale:v 3 \"{save_path}\"/frame%06d.jpg' os.system(command) frame_list = glob.glob(f'{save_path}/*.jpg') frame_list.sort() convert_tensor = transforms.ToTensor() frame_list = [convert_tensor(np.array(Image.open(frame))) for frame in frame_list] imgs = torch.stack(frame_list, dim=0) imgs = video_transform(imgs) imgs = imgs.permute(1, 0, 2, 3) return imgs