AudioToken / app.py
genevera
dont use xformers if cuda isnt available
0c0d5ca
raw
history blame
10.7 kB
import torch
import numpy as np
import gradio as gr
from scipy import signal
from diffusers.utils import logging
logging.set_verbosity_error()
from diffusers.loaders import AttnProcsLayers
from transformers import CLIPTextModel, CLIPTokenizer
from modules.beats.BEATs import BEATs, BEATsConfig
from modules.AudioToken.embedder import FGAEmbedder
from diffusers import AutoencoderKL, UNet2DConditionModel
from diffusers.models.attention_processor import LoRAAttnProcessor
from diffusers import StableDiffusionPipeline
from diffusers import (
DDPMScheduler,
DDIMScheduler,
PNDMScheduler,
LMSDiscreteScheduler,
EulerDiscreteScheduler,
EulerAncestralDiscreteScheduler,
DPMSolverMultistepScheduler,
DPMSolverSinglestepScheduler,
DEISMultistepScheduler,
UniPCMultistepScheduler,
HeunDiscreteScheduler,
KDPM2AncestralDiscreteScheduler,
KDPM2DiscreteScheduler,
)
class AudioTokenWrapper(torch.nn.Module):
    """Simple wrapper module for Stable Diffusion that holds all the models together.

    Loads, from the Hugging Face repo named by the module-level ``repo_id``:
    thirteen schedulers (stored as attributes ``ddpm``, ``ddim``, ``pndm``,
    ``lms``, ``euler_anc``, ``euler``, ``dpm``, ``dpms``, ``deis``, ``unipc``,
    ``heun``, ``kdpm2_anc`` and ``kdpm2``), the CLIP tokenizer and text
    encoder, the UNet and the VAE. It also loads the BEATs audio encoder and
    the trained ``FGAEmbedder`` from local files under ``models/``, and adds
    the ``<*>`` placeholder token to the tokenizer.

    Args:
        lora: If truthy, replace every UNet attention processor with a
            ``LoRAAttnProcessor`` and load the trained LoRA weights from
            ``models/lora_layers_learned_embeds.bin``.
        device: ``map_location`` used for every local ``torch.load``. The
            modules themselves are not moved here; the script calls
            ``.to(device)`` on the wrapper after construction.

    Raises:
        ValueError: If the tokenizer already contains the placeholder token,
            or (LoRA mode only) if a UNet attention processor name does not
            start with ``mid_block``, ``up_blocks`` or ``down_blocks``.
    """

    def __init__(
        self,
        lora,
        device,
    ):
        super().__init__()
        # NOTE: reads the module-level ``repo_id`` global, which must be
        # assigned before this class is instantiated.
        self.repo_id = repo_id

        # Load scheduler and models. Every scheduler is built from the same
        # "scheduler" subfolder config; the attribute names are the scheduler
        # names that ``greet`` accepts.
        scheduler_classes = {
            "ddpm": DDPMScheduler,
            "ddim": DDIMScheduler,
            "pndm": PNDMScheduler,
            "lms": LMSDiscreteScheduler,
            "euler_anc": EulerAncestralDiscreteScheduler,
            "euler": EulerDiscreteScheduler,
            "dpm": DPMSolverMultistepScheduler,
            "dpms": DPMSolverSinglestepScheduler,
            "deis": DEISMultistepScheduler,
            "unipc": UniPCMultistepScheduler,
            "heun": HeunDiscreteScheduler,
            "kdpm2_anc": KDPM2AncestralDiscreteScheduler,
            "kdpm2": KDPM2DiscreteScheduler,
        }
        for attr_name, scheduler_cls in scheduler_classes.items():
            setattr(
                self,
                attr_name,
                scheduler_cls.from_pretrained(self.repo_id, subfolder="scheduler"),
            )

        self.tokenizer = CLIPTokenizer.from_pretrained(
            self.repo_id, subfolder="tokenizer"
        )
        self.text_encoder = CLIPTextModel.from_pretrained(
            self.repo_id, subfolder="text_encoder", revision=None
        )
        self.unet = UNet2DConditionModel.from_pretrained(
            self.repo_id, subfolder="unet", revision=None
        )
        self.vae = AutoencoderKL.from_pretrained(
            self.repo_id, subfolder="vae", revision=None
        )

        # Use map_location like the other checkpoints below, so a checkpoint
        # holding GPU tensors can still be loaded on a CPU-only host.
        checkpoint = torch.load(
            "models/BEATs_iter3_plus_AS2M_finetuned_on_AS2M_cpt2.pt",
            map_location=device,
        )
        cfg = BEATsConfig(checkpoint["cfg"])
        self.aud_encoder = BEATs(cfg)
        self.aud_encoder.load_state_dict(checkpoint["model"])
        # Presumably drops the fine-tuned classification head; only
        # extract_features() is used downstream — TODO confirm.
        self.aud_encoder.predictor = None

        input_size = 768 * 3
        self.embedder = FGAEmbedder(input_size=input_size, output_size=768)

        self.vae.eval()
        self.unet.eval()
        self.text_encoder.eval()
        self.aud_encoder.eval()

        if lora:
            # Build one LoRA processor per UNet attention processor, sized
            # from the UNet config.
            lora_attn_procs = {}
            for name in self.unet.attn_processors.keys():
                # Processors named "...attn1.processor" get no cross-attention
                # dim; all others use the UNet's cross_attention_dim.
                cross_attention_dim = (
                    None
                    if name.endswith("attn1.processor")
                    else self.unet.config.cross_attention_dim
                )
                if name.startswith("mid_block"):
                    hidden_size = self.unet.config.block_out_channels[-1]
                elif name.startswith("up_blocks"):
                    # Reads a single digit, so assumes fewer than 10 blocks.
                    block_id = int(name[len("up_blocks.")])
                    hidden_size = list(reversed(self.unet.config.block_out_channels))[
                        block_id
                    ]
                elif name.startswith("down_blocks"):
                    block_id = int(name[len("down_blocks.")])
                    hidden_size = self.unet.config.block_out_channels[block_id]
                else:
                    # Fail loudly instead of reusing a stale/unbound hidden_size.
                    raise ValueError(
                        f"Unexpected UNet attention processor name: {name!r}"
                    )
                lora_attn_procs[name] = LoRAAttnProcessor(
                    hidden_size=hidden_size, cross_attention_dim=cross_attention_dim
                )
            self.unet.set_attn_processor(lora_attn_procs)
            self.lora_layers = AttnProcsLayers(self.unet.attn_processors)
            self.lora_layers.eval()
            lora_layers_learned_embeds = "models/lora_layers_learned_embeds.bin"
            self.lora_layers.load_state_dict(
                torch.load(lora_layers_learned_embeds, map_location=device)
            )
            self.unet.load_attn_procs(lora_layers_learned_embeds)

        # The trained audio-to-token embedder is needed with or without LoRA.
        self.embedder.eval()
        embedder_learned_embeds = "models/embedder_learned_embeds.bin"
        self.embedder.load_state_dict(
            torch.load(embedder_learned_embeds, map_location=device)
        )

        # Register the placeholder token whose embedding row greet() later
        # overwrites with the audio token.
        self.placeholder_token = "<*>"
        num_added_tokens = self.tokenizer.add_tokens(self.placeholder_token)
        if num_added_tokens == 0:
            raise ValueError(
                f"The tokenizer already contains the token {self.placeholder_token}. Please pass a different"
                " `placeholder_token` that is not already in the tokenizer."
            )
        self.placeholder_token_id = self.tokenizer.convert_tokens_to_ids(
            self.placeholder_token
        )
        # Resize the token embeddings as we are adding new special tokens to the tokenizer
        self.text_encoder.resize_token_embeddings(len(self.tokenizer))
# Scheduler names accepted by ``greet``; each is also the name of the matching
# scheduler attribute on ``model``.
_SCHEDULER_NAMES = (
    "ddpm",
    "ddim",
    "pndm",
    "lms",
    "euler_anc",
    "euler",
    "dpm",
    "dpms",
    "deis",
    "unipc",
    "heun",
    "kdpm2_anc",
    "kdpm2",
)

# Sample rate the audio is resampled to before feature extraction.
_TARGET_SAMPLE_RATE = 16000


def _to_mono_float32(samples):
    """Return *samples* as a float32 array averaged over channels.

    Integer input is divided by its full-scale value (32768 for int16, the
    divisor the app originally hard-coded). Floating-point input is assumed
    to be already normalized and is only cast.
    """
    if np.issubdtype(samples.dtype, np.floating):
        samples = samples.astype(np.float32, order="C")
    else:
        full_scale = float(np.iinfo(samples.dtype).max) + 1.0
        samples = samples.astype(np.float32, order="C") / full_scale
    if samples.ndim == 2:
        # Assumes a (num_samples, num_channels) layout. The mean equals
        # sum / 2 for stereo and stays correct for mono or multi-channel input.
        samples = samples.mean(axis=1)
    return samples


def _resample(samples, sample_rate, target_rate=_TARGET_SAMPLE_RATE):
    """FFT-resample *samples* from *sample_rate* to *target_rate* (no-op if equal)."""
    if sample_rate == target_rate:
        return samples
    new_length = int(len(samples) * (target_rate / sample_rate))
    return signal.resample(samples, new_length)


def greet(audio, steps=25, scheduler="ddpm", seed=23229249375547):
    """Generate an image conditioned on an audio clip.

    Args:
        audio: ``(sample_rate, samples)`` tuple from the Gradio audio input.
        steps: Number of diffusion inference steps.
        scheduler: One of ``_SCHEDULER_NAMES``.
        seed: Seed for the sampling generator. The default is the value that
            was previously hard-coded, so existing results are reproduced.

    Returns:
        The first image of the Stable Diffusion pipeline output.

    Raises:
        gr.Error: If no audio was provided or the scheduler name is unknown.
    """
    if audio is None:
        raise gr.Error("Please record or upload an audio clip.")
    # Validate before getattr so only real schedulers can be selected.
    if scheduler not in _SCHEDULER_NAMES:
        raise gr.Error(
            f"Unknown scheduler {scheduler!r}; choose one of: "
            + ", ".join(_SCHEDULER_NAMES)
        )
    use_sched = getattr(model, scheduler)

    sample_rate, samples = audio
    samples = _resample(_to_mono_float32(samples), sample_rate)

    weight_dtype = torch.float32
    prompt = "a photo of <*>"
    # Add a leading batch dimension: (num_samples,) -> (1, num_samples).
    audio_values = (
        torch.tensor(samples).unsqueeze(0).to(device).to(dtype=weight_dtype)
    )

    # The original author noted this no_grad block seemed necessary for
    # deterministic results.
    with torch.no_grad():
        aud_features = model.aud_encoder.extract_features(audio_values)[1]
        audio_token = model.embedder(aud_features)

        # Overwrite the "<*>" embedding row so the prompt token carries the audio.
        token_embeds = model.text_encoder.get_input_embeddings().weight.data
        token_embeds[model.placeholder_token_id] = audio_token.clone()

        generator = torch.Generator(device=device)
        generator.manual_seed(seed)

        # A pipeline is assembled per call around the shared components and
        # the selected scheduler.
        pipeline = StableDiffusionPipeline.from_pretrained(
            pretrained_model_name_or_path=model.repo_id,
            tokenizer=model.tokenizer,
            text_encoder=model.text_encoder,
            vae=model.vae,
            unet=model.unet,
            scheduler=use_sched,
            safety_checker=None,
        ).to(device)
        pipeline.enable_attention_slicing()
        if torch.cuda.is_available():
            try:
                pipeline.enable_xformers_memory_efficient_attention()
            except ImportError:
                # xformers is optional; keep attention slicing instead of
                # failing every request when it is not installed.
                import warnings

                warnings.warn("xformers is not installed; using attention slicing only.")
        image = pipeline(
            prompt,
            num_inference_steps=int(steps),
            guidance_scale=8.5,
            generator=generator,
        ).images[0]
    return image
# Module-level configuration: AudioTokenWrapper reads ``repo_id``, and greet()
# reads ``model`` and ``device``.
repo_id = "philz1337/reliberate"
lora = False
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
# Module.to() returns the module itself, so build and place in one step.
model = AudioTokenWrapper(lora, device).to(device)
description = """<p>
This is a demo of <a href='https://pages.cs.huji.ac.il/adiyoss-lab/AudioToken' target='_blank'>AudioToken: Adaptation of Text-Conditioned Diffusion Models for Audio-to-Image Generation</a>.<br><br>
A novel method utilizing latent diffusion models trained for text-to-image-generation to generate images conditioned on audio recordings. Using a pre-trained audio encoding model, the proposed method encodes audio into a new token, which can be considered as an adaptation layer between the audio and text representations.<br><br>
For more information, please see the original <a href='https://arxiv.org/abs/2305.13050' target='_blank'>paper</a> and <a href='https://github.com/guyyariv/AudioToken' target='_blank'>repo</a>.
</p>"""

# No example clips are shown at the moment. Example entries pointing at
# assets/*.wav (e.g. assets/train.wav) were previously commented out here.
examples = []

my_demo = gr.Interface(
    fn=greet,
    inputs=[
        "audio",
        # At least one step: a zero-step diffusion run is not meaningful.
        gr.Slider(
            minimum=1, maximum=100, value=25, step=1, label="diffusion steps"
        ),
        gr.Dropdown(
            choices=[
                "ddim",
                "ddpm",
                "pndm",
                "lms",
                "euler_anc",
                "euler",
                "dpm",
                "dpms",
                "deis",
                "unipc",
                "heun",
                "kdpm2_anc",
                "kdpm2",
            ],
            value="unipc",
            label="scheduler",
        ),
    ],
    outputs="image",
    title="AudioToken",
    description=description,
    examples=examples,
)

# Launch only when run as a script (e.g. `python app.py`), not on import.
if __name__ == "__main__":
    my_demo.launch()