Spaces:
Sleeping
Sleeping
import torch | |
import numpy as np | |
import gradio as gr | |
from scipy import signal | |
from diffusers.utils import logging | |
logging.set_verbosity_error() | |
from diffusers.loaders import AttnProcsLayers | |
from transformers import CLIPTextModel, CLIPTokenizer | |
from modules.beats.BEATs import BEATs, BEATsConfig | |
from modules.AudioToken.embedder import FGAEmbedder | |
from diffusers import AutoencoderKL, UNet2DConditionModel | |
from diffusers.models.attention_processor import LoRAAttnProcessor | |
from diffusers import StableDiffusionPipeline | |
from diffusers import ( | |
DDPMScheduler, | |
DDIMScheduler, | |
PNDMScheduler, | |
LMSDiscreteScheduler, | |
EulerDiscreteScheduler, | |
EulerAncestralDiscreteScheduler, | |
DPMSolverMultistepScheduler, | |
DPMSolverSinglestepScheduler, | |
DEISMultistepScheduler, | |
UniPCMultistepScheduler, | |
HeunDiscreteScheduler, | |
KDPM2AncestralDiscreteScheduler, | |
KDPM2DiscreteScheduler, | |
) | |
class AudioTokenWrapper(torch.nn.Module):
    """Hold every model the audio-to-image demo needs in a single module.

    The wrapper owns the Stable Diffusion components (tokenizer, text encoder,
    UNet, VAE), one instance of each supported noise scheduler, the BEATs
    audio encoder, and the embedder that turns audio features into a token
    embedding. A placeholder token ``<*>`` is added to the tokenizer and the
    text encoder's embedding matrix is resized to make room for it.

    Args:
        lora: When truthy, install LoRA attention processors on the UNet and
            load their weights from ``models/lora_layers_learned_embeds.bin``.
        device: Passed as ``map_location`` whenever a checkpoint is read from
            disk.

    Raises:
        ValueError: If the tokenizer already contains the placeholder token,
            or if a UNet attention-processor name is not under ``mid_block``,
            ``up_blocks`` or ``down_blocks``.
    """

    def __init__(
        self,
        lora,
        device,
    ):
        super().__init__()
        # NOTE: ``repo_id`` is read from module scope, so it must be defined
        # before the wrapper is instantiated.
        self.repo_id = repo_id

        # Attribute name -> scheduler class. Every scheduler is built from the
        # same "scheduler" subfolder of the repo, so one loop replaces thirteen
        # copy-pasted ``from_pretrained`` calls while exposing the same
        # attributes (``self.ddpm``, ``self.ddim``, ...).
        scheduler_classes = {
            "ddpm": DDPMScheduler,
            "ddim": DDIMScheduler,
            "pndm": PNDMScheduler,
            "lms": LMSDiscreteScheduler,
            "euler_anc": EulerAncestralDiscreteScheduler,
            "euler": EulerDiscreteScheduler,
            "dpm": DPMSolverMultistepScheduler,
            "dpms": DPMSolverSinglestepScheduler,
            "deis": DEISMultistepScheduler,
            "unipc": UniPCMultistepScheduler,
            "heun": HeunDiscreteScheduler,
            "kdpm2_anc": KDPM2AncestralDiscreteScheduler,
            "kdpm2": KDPM2DiscreteScheduler,
        }
        for attr_name, scheduler_cls in scheduler_classes.items():
            setattr(
                self,
                attr_name,
                scheduler_cls.from_pretrained(self.repo_id, subfolder="scheduler"),
            )

        # Stable Diffusion components.
        self.tokenizer = CLIPTokenizer.from_pretrained(
            self.repo_id, subfolder="tokenizer"
        )
        self.text_encoder = CLIPTextModel.from_pretrained(
            self.repo_id, subfolder="text_encoder", revision=None
        )
        self.unet = UNet2DConditionModel.from_pretrained(
            self.repo_id, subfolder="unet", revision=None
        )
        self.vae = AutoencoderKL.from_pretrained(
            self.repo_id, subfolder="vae", revision=None
        )

        # BEATs audio encoder. ``map_location`` keeps this consistent with the
        # other checkpoint loads below and lets a checkpoint saved from GPU
        # tensors load on a CPU-only host.
        checkpoint = torch.load(
            "models/BEATs_iter3_plus_AS2M_finetuned_on_AS2M_cpt2.pt",
            map_location=device,
        )
        cfg = BEATsConfig(checkpoint["cfg"])
        self.aud_encoder = BEATs(cfg)
        self.aud_encoder.load_state_dict(checkpoint["model"])
        # NOTE(review): presumably the predictor head is dropped because only
        # ``extract_features`` is consumed downstream -- confirm.
        self.aud_encoder.predictor = None

        input_size = 768 * 3
        self.embedder = FGAEmbedder(input_size=input_size, output_size=768)

        # Inference only: put every sub-model in eval mode.
        self.vae.eval()
        self.unet.eval()
        self.text_encoder.eval()
        self.aud_encoder.eval()

        if lora:
            # Build one LoRA processor per UNet attention layer, sized to the
            # channel width of the block the layer belongs to.
            block_out_channels = self.unet.config.block_out_channels
            lora_attn_procs = {}
            for name in self.unet.attn_processors:
                # "attn1" processors get no cross-attention dimension.
                cross_attention_dim = (
                    None
                    if name.endswith("attn1.processor")
                    else self.unet.config.cross_attention_dim
                )
                if name.startswith("mid_block"):
                    hidden_size = block_out_channels[-1]
                elif name.startswith("up_blocks"):
                    block_id = int(name[len("up_blocks.")])
                    hidden_size = list(reversed(block_out_channels))[block_id]
                elif name.startswith("down_blocks"):
                    block_id = int(name[len("down_blocks.")])
                    hidden_size = block_out_channels[block_id]
                else:
                    # Without this branch an unexpected name would silently
                    # reuse the previous iteration's ``hidden_size``.
                    raise ValueError(
                        f"Unexpected attention processor name: {name}"
                    )
                lora_attn_procs[name] = LoRAAttnProcessor(
                    hidden_size=hidden_size, cross_attention_dim=cross_attention_dim
                )

            self.unet.set_attn_processor(lora_attn_procs)
            self.lora_layers = AttnProcsLayers(self.unet.attn_processors)
            self.lora_layers.eval()
            lora_layers_learned_embeds = "models/lora_layers_learned_embeds.bin"
            self.lora_layers.load_state_dict(
                torch.load(lora_layers_learned_embeds, map_location=device)
            )
            self.unet.load_attn_procs(lora_layers_learned_embeds)

        self.embedder.eval()
        embedder_learned_embeds = "models/embedder_learned_embeds.bin"
        self.embedder.load_state_dict(
            torch.load(embedder_learned_embeds, map_location=device)
        )

        # Register the placeholder token whose embedding row is later
        # overwritten with the audio-derived embedding.
        self.placeholder_token = "<*>"
        num_added_tokens = self.tokenizer.add_tokens(self.placeholder_token)
        if num_added_tokens == 0:
            raise ValueError(
                f"The tokenizer already contains the token {self.placeholder_token}. Please pass a different"
                " `placeholder_token` that is not already in the tokenizer."
            )
        self.placeholder_token_id = self.tokenizer.convert_tokens_to_ids(
            self.placeholder_token
        )

        # Resize the token embeddings as we are adding new special tokens to the tokenizer
        self.text_encoder.resize_token_embeddings(len(self.tokenizer))
def greet(audio, steps=25, scheduler="ddpm", seed=23229249375547):
    """Generate one image conditioned on an audio clip.

    The clip is downmixed to mono, resampled to 16 kHz if needed, passed
    through the audio encoder and the embedder, and the resulting embedding
    overwrites the ``<*>`` placeholder row of the text encoder's input
    embeddings. Stable Diffusion is then run on the prompt
    ``"a photo of <*>"``.

    Relies on the module-level ``model`` and ``device``. It mutates the shared
    text-encoder embedding matrix in place on every call.

    Args:
        audio: ``(sample_rate, samples)`` pair where ``samples`` is a NumPy
            array, either 1-D or 2-D with channels on axis 1. Signed-integer
            samples are scaled by their dtype's full range (32768 for int16);
            any other dtype is divided by 32768 as before.
        steps: Number of inference steps; coerced to ``int``.
        scheduler: Name of the scheduler attribute on ``model`` to use.
        seed: Seed for the random generator handed to the pipeline.

    Returns:
        The first image produced by the pipeline.

    Raises:
        ValueError: If no audio is given, the clip has no samples, or
            ``scheduler`` is not one of the supported names.
    """
    # Validate cheap inputs first so bad requests fail before any model work.
    if audio is None:
        raise ValueError("No audio was provided.")

    known_schedulers = (
        "ddpm",
        "ddim",
        "pndm",
        "lms",
        "euler_anc",
        "euler",
        "dpm",
        "dpms",
        "deis",
        "unipc",
        "heun",
        "kdpm2_anc",
        "kdpm2",
    )
    if scheduler not in known_schedulers:
        # The original ``match`` had no default case, so an unknown name only
        # surfaced later as an UnboundLocalError.
        raise ValueError(
            f"Unknown scheduler {scheduler!r}; expected one of: "
            + ", ".join(known_schedulers)
        )

    sample_rate, samples = audio
    samples = np.asarray(samples)
    if samples.size == 0:
        raise ValueError("The audio clip contains no samples.")

    # Normalise PCM to roughly [-1, 1). int16 keeps the original 32768 divisor;
    # wider signed-integer PCM is scaled by its own full range.
    if np.issubdtype(samples.dtype, np.signedinteger):
        full_scale = float(np.iinfo(samples.dtype).max) + 1.0
    else:
        full_scale = 32768.0
    samples = samples.astype(np.float32, order="C") / full_scale

    # Downmix multi-channel audio to mono (identical to sum / 2 for stereo).
    if samples.ndim == 2:
        samples = samples.mean(axis=1)

    desired_sample_rate = 16000
    if sample_rate != desired_sample_rate:
        # Calculate the resampling ratio and the resulting clip length.
        resample_ratio = desired_sample_rate / sample_rate
        new_length = int(len(samples) * resample_ratio)
        samples = signal.resample(samples, new_length)

    prompt = "a photo of <*>"

    # Add a leading batch dimension; the old ``ndim == 1`` check that followed
    # this could never be true after ``unsqueeze`` and has been removed.
    audio_values = (
        torch.unsqueeze(torch.tensor(samples), dim=0)
        .to(device)
        .to(dtype=torch.float32)
    )

    # NOTE: the original author found ``no_grad`` necessary for deterministic
    # results without knowing why; keep it.
    with torch.no_grad():
        aud_features = model.aud_encoder.extract_features(audio_values)[1]
        audio_token = model.embedder(aud_features)

        # Overwrite the placeholder token's embedding row with the audio token.
        token_embeds = model.text_encoder.get_input_embeddings().weight.data
        token_embeds[model.placeholder_token_id] = audio_token.clone()

    generator = torch.Generator(device=device)
    generator.manual_seed(int(seed))

    use_sched = getattr(model, scheduler)
    pipeline = StableDiffusionPipeline.from_pretrained(
        pretrained_model_name_or_path=model.repo_id,
        tokenizer=model.tokenizer,
        text_encoder=model.text_encoder,
        vae=model.vae,
        unet=model.unet,
        scheduler=use_sched,
        safety_checker=None,
    ).to(device)
    pipeline.enable_attention_slicing()
    if torch.cuda.is_available():
        pipeline.enable_xformers_memory_efficient_attention()

    # A UI slider may deliver a float; the pipeline needs an integer count.
    image = pipeline(
        prompt,
        num_inference_steps=int(steps),
        guidance_scale=8.5,
        generator=generator,
    ).images[0]
    return image
# --- Model setup -----------------------------------------------------------
# ``repo_id`` must be defined before the wrapper is built: its constructor
# reads this module-level name.
lora = False
repo_id = "philz1337/reliberate"
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = AudioTokenWrapper(lora, device)
model = model.to(device)

# --- UI text ---------------------------------------------------------------
description = """<p>
This is a demo of <a href='https://pages.cs.huji.ac.il/adiyoss-lab/AudioToken' target='_blank'>AudioToken: Adaptation of Text-Conditioned Diffusion Models for Audio-to-Image Generation</a>.<br><br>
A novel method utilizing latent diffusion models trained for text-to-image-generation to generate images conditioned on audio recordings. Using a pre-trained audio encoding model, the proposed method encodes audio into a new token, which can be considered as an adaptation layer between the audio and text representations.<br><br>
For more information, please see the original <a href='https://arxiv.org/abs/2305.13050' target='_blank'>paper</a> and <a href='https://github.com/guyyariv/AudioToken' target='_blank'>repo</a>.
</p>"""

# Example clips are currently disabled; the list is intentionally empty.
examples = [
    # ["assets/train.wav"],
    # ["assets/dog barking.wav"],
    # ["assets/airplane taking off.wav"],
    # ["assets/electric guitar.wav"],
    # ["assets/female sings.wav"],
]

# --- Gradio interface ------------------------------------------------------
# Inputs map positionally onto ``greet(audio, steps, scheduler)``.
my_demo = gr.Interface(
    fn=greet,
    inputs=[
        "audio",
        # ``minimum=1`` stops the user from requesting zero diffusion steps;
        # the slider's default lower bound is 0. ``maximum`` is unchanged.
        gr.Slider(
            minimum=1,
            maximum=100,
            value=25,
            step=1,
            label="diffusion steps",
        ),
        gr.Dropdown(
            choices=[
                "ddim",
                "ddpm",
                "pndm",
                "lms",
                "euler_anc",
                "euler",
                "dpm",
                "dpms",
                "deis",
                "unipc",
                "heun",
                "kdpm2_anc",
                "kdpm2",
            ],
            value="unipc",
        ),
    ],
    outputs="image",
    title="AudioToken",
    description=description,
    examples=examples,
)
my_demo.launch()