diff --git a/.ipynb_checkpoints/app-checkpoint.py b/.ipynb_checkpoints/app-checkpoint.py
index aa0703928777487a88eb6c87fd31edb01a1e6252..869a2f3eb47978921e3158f2c0b9dd4681ca9aba 100644
--- a/.ipynb_checkpoints/app-checkpoint.py
+++ b/.ipynb_checkpoints/app-checkpoint.py
@@ -2,6 +2,7 @@ import gradio as gr
 import json
 import torch
 import wavio
+import numpy as np
 from tqdm import tqdm
 
 from huggingface_hub import snapshot_download
@@ -23,6 +24,7 @@ class MusicFeaturePredictor:
     def __init__(self, path, device="cuda:0", cache_dir=None, local_files_only=False):
         self.beats_tokenizer = AutoTokenizer.from_pretrained(
             "microsoft/deberta-v3-large",
+            use_fast=False,
             cache_dir=cache_dir,
             local_files_only=local_files_only,
         )
@@ -164,6 +166,7 @@ class Mustango:
             main_config["scheduler_name"],
             unet_model_config_path=f"{path}/configs/music_diffusion_model_config.json",
         ).to(device)
+        self.model.device = device
 
         vae_weights = torch.load(
             f"{path}/vae/pytorch_model_vae.bin", map_location=device
@@ -213,9 +216,11 @@ class Mustango:
 
 # Initialize Mustango
 if torch.cuda.is_available():
-    mustango = Mustango()
+    mustango = Mustango(device="cpu")
 else:
     mustango = Mustango(device="cpu")
+
+output_wave = mustango.generate("This techno song features a synth lead playing the main melody.", 5, 3, disable_progress=False)
 
 def gradio_generate(prompt, steps, guidance):
     output_wave = mustango.generate(prompt, steps, guidance)
@@ -225,6 +230,7 @@ def gradio_generate(prompt, steps, guidance):
 
     return output_filename
 
+
 # description_text = """
 # For faster inference without waiting in queue, you may duplicate the space and upgrade to a GPU in the settings.
 # Generate music using Mustango by providing a text prompt.
diff --git a/.ipynb_checkpoints/modelling_deberta_v2-checkpoint.py b/.ipynb_checkpoints/modelling_deberta_v2-checkpoint.py
new file mode 100644
index 0000000000000000000000000000000000000000..6788c005e28b1351cee584c82d8f8182b18d7a10
--- /dev/null
+++ b/.ipynb_checkpoints/modelling_deberta_v2-checkpoint.py
@@ -0,0 +1,1750 @@
+# coding=utf-8
+# Copyright 2020 Microsoft and the Hugging Face Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+""" PyTorch DeBERTa-v2 model."""
+
+from collections.abc import Sequence
+from dataclasses import dataclass
+from typing import Optional, Tuple, Union
+
+import torch
+import torch.utils.checkpoint
+from torch import nn
+from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, LayerNorm, MSELoss
+
+from transformers.activations import ACT2FN
+from transformers.modeling_outputs import (
+ ModelOutput,
+ BaseModelOutput,
+ MaskedLMOutput,
+ MultipleChoiceModelOutput,
+ QuestionAnsweringModelOutput,
+ SequenceClassifierOutput,
+ TokenClassifierOutput,
+)
+from transformers.modeling_utils import PreTrainedModel
+from transformers.pytorch_utils import softmax_backward_data
+from transformers.utils import (
+    add_code_sample_docstrings,
+    add_start_docstrings,
+    add_start_docstrings_to_model_forward,
+    logging,
+)
+from transformers.models.deberta_v2.configuration_deberta_v2 import DebertaV2Config
+
+
+logger = logging.get_logger(__name__)
+
+_CONFIG_FOR_DOC = "DebertaV2Config"
+_CHECKPOINT_FOR_DOC = "microsoft/deberta-v2-xlarge"
+_QA_TARGET_START_INDEX = 2
+_QA_TARGET_END_INDEX = 9
+
+DEBERTA_V2_PRETRAINED_MODEL_ARCHIVE_LIST = [
+ "microsoft/deberta-v2-xlarge",
+ "microsoft/deberta-v2-xxlarge",
+ "microsoft/deberta-v2-xlarge-mnli",
+ "microsoft/deberta-v2-xxlarge-mnli",
+]
+
+
+# Copied from transformers.models.deberta.modeling_deberta.ContextPooler
+class ContextPooler(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.dense = nn.Linear(config.pooler_hidden_size, config.pooler_hidden_size)
+ self.dropout = StableDropout(config.pooler_dropout)
+ self.config = config
+
+ def forward(self, hidden_states):
+ # We "pool" the model by simply taking the hidden state corresponding
+ # to the first token.
+
+ context_token = hidden_states[:, 0]
+ context_token = self.dropout(context_token)
+ pooled_output = self.dense(context_token)
+ pooled_output = ACT2FN[self.config.pooler_hidden_act](pooled_output)
+ return pooled_output
+
+ @property
+ def output_dim(self):
+ return self.config.hidden_size
+
+
+# Copied from transformers.models.deberta.modeling_deberta.XSoftmax with deberta->deberta_v2
+class XSoftmax(torch.autograd.Function):
+ """
+ Masked Softmax which is optimized for saving memory
+
+ Args:
+        input (`torch.tensor`): The input tensor that softmax will be applied to.
+        mask (`torch.IntTensor`):
+            The mask matrix where 0 indicates that the element will be ignored in the softmax calculation.
+        dim (int): The dimension along which softmax is applied.
+
+ Example:
+
+ ```python
+ >>> import torch
+ >>> from transformers.models.deberta_v2.modeling_deberta_v2 import XSoftmax
+
+ >>> # Make a tensor
+ >>> x = torch.randn([4, 20, 100])
+
+ >>> # Create a mask
+ >>> mask = (x > 0).int()
+
+ >>> # Specify the dimension to apply softmax
+ >>> dim = -1
+
+ >>> y = XSoftmax.apply(x, mask, dim)
+ ```"""
+
+ @staticmethod
+ def forward(self, input, mask, dim):
+ self.dim = dim
+ rmask = ~(mask.to(torch.bool))
+
+ output = input.masked_fill(rmask, torch.tensor(torch.finfo(input.dtype).min))
+ output = torch.softmax(output, self.dim)
+ output.masked_fill_(rmask, 0)
+ self.save_for_backward(output)
+ return output
+
+ @staticmethod
+ def backward(self, grad_output):
+ (output,) = self.saved_tensors
+ inputGrad = softmax_backward_data(self, grad_output, output, self.dim, output)
+ return inputGrad, None, None
+
+ @staticmethod
+ def symbolic(g, self, mask, dim):
+ import torch.onnx.symbolic_helper as sym_help
+ from torch.onnx.symbolic_opset9 import masked_fill, softmax
+
+ mask_cast_value = g.op("Cast", mask, to_i=sym_help.cast_pytorch_to_onnx["Long"])
+ r_mask = g.op(
+ "Cast",
+ g.op("Sub", g.op("Constant", value_t=torch.tensor(1, dtype=torch.int64)), mask_cast_value),
+ to_i=sym_help.cast_pytorch_to_onnx["Bool"],
+ )
+ output = masked_fill(
+ g, self, r_mask, g.op("Constant", value_t=torch.tensor(torch.finfo(self.type().dtype()).min))
+ )
+ output = softmax(g, output, dim)
+ return masked_fill(g, output, r_mask, g.op("Constant", value_t=torch.tensor(0, dtype=torch.bool)))
+
+
+# Copied from transformers.models.deberta.modeling_deberta.DropoutContext
+class DropoutContext(object):
+ def __init__(self):
+ self.dropout = 0
+ self.mask = None
+ self.scale = 1
+ self.reuse_mask = True
+
+
+# Copied from transformers.models.deberta.modeling_deberta.get_mask
+def get_mask(input, local_context):
+ if not isinstance(local_context, DropoutContext):
+ dropout = local_context
+ mask = None
+ else:
+ dropout = local_context.dropout
+ dropout *= local_context.scale
+ mask = local_context.mask if local_context.reuse_mask else None
+
+ if dropout > 0 and mask is None:
+ mask = (1 - torch.empty_like(input).bernoulli_(1 - dropout)).to(torch.bool)
+
+ if isinstance(local_context, DropoutContext):
+ if local_context.mask is None:
+ local_context.mask = mask
+
+ return mask, dropout
+
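+
+# Editor's note: a minimal illustrative sketch (not part of the upstream file) of how
+# `get_mask` behaves. With a DropoutContext carrying p=0.5 it returns a boolean mask of
+# the elements to zero out together with the effective dropout probability.
+def _example_get_mask():
+    x = torch.randn(2, 4)
+    ctx = DropoutContext()
+    ctx.dropout = 0.5
+    mask, dropout = get_mask(x, ctx)
+    # `mask` is a torch.bool tensor shaped like `x`; `dropout` is 0.5 here.
+    # With a dropout probability of 0, `mask` stays None and XDropout becomes a no-op.
+    return mask, dropout
+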
+
+# Copied from transformers.models.deberta.modeling_deberta.XDropout
+class XDropout(torch.autograd.Function):
+ """Optimized dropout function to save computation and memory by using mask operation instead of multiplication."""
+
+ @staticmethod
+ def forward(ctx, input, local_ctx):
+ mask, dropout = get_mask(input, local_ctx)
+ ctx.scale = 1.0 / (1 - dropout)
+ if dropout > 0:
+ ctx.save_for_backward(mask)
+ return input.masked_fill(mask, 0) * ctx.scale
+ else:
+ return input
+
+ @staticmethod
+ def backward(ctx, grad_output):
+ if ctx.scale > 1:
+ (mask,) = ctx.saved_tensors
+ return grad_output.masked_fill(mask, 0) * ctx.scale, None
+ else:
+ return grad_output, None
+
+ @staticmethod
+ def symbolic(g: torch._C.Graph, input: torch._C.Value, local_ctx: Union[float, DropoutContext]) -> torch._C.Value:
+ from torch.onnx import symbolic_opset12
+
+ dropout_p = local_ctx
+ if isinstance(local_ctx, DropoutContext):
+ dropout_p = local_ctx.dropout
+ # StableDropout only calls this function when training.
+ train = True
+ # TODO: We should check if the opset_version being used to export
+ # is > 12 here, but there's no good way to do that. As-is, if the
+ # opset_version < 12, export will fail with a CheckerError.
+ # Once https://github.com/pytorch/pytorch/issues/78391 is fixed, do something like:
+ # if opset_version < 12:
+ # return torch.onnx.symbolic_opset9.dropout(g, input, dropout_p, train)
+ return symbolic_opset12.dropout(g, input, dropout_p, train)
+
+
+# Copied from transformers.models.deberta.modeling_deberta.StableDropout
+class StableDropout(nn.Module):
+ """
+    Optimized dropout module for stabilizing training
+
+    Args:
+        drop_prob (float): the dropout probability
+ """
+
+ def __init__(self, drop_prob):
+ super().__init__()
+ self.drop_prob = drop_prob
+ self.count = 0
+ self.context_stack = None
+
+ def forward(self, x):
+ """
+ Call the module
+
+ Args:
+ x (`torch.tensor`): The input tensor to apply dropout
+ """
+ if self.training and self.drop_prob > 0:
+ return XDropout.apply(x, self.get_context())
+ return x
+
+ def clear_context(self):
+ self.count = 0
+ self.context_stack = None
+
+ def init_context(self, reuse_mask=True, scale=1):
+ if self.context_stack is None:
+ self.context_stack = []
+ self.count = 0
+ for c in self.context_stack:
+ c.reuse_mask = reuse_mask
+ c.scale = scale
+
+ def get_context(self):
+ if self.context_stack is not None:
+ if self.count >= len(self.context_stack):
+ self.context_stack.append(DropoutContext())
+ ctx = self.context_stack[self.count]
+ ctx.dropout = self.drop_prob
+ self.count += 1
+ return ctx
+ else:
+ return self.drop_prob
+
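+
+# Editor's note: an illustrative sketch (not part of the upstream file). StableDropout
+# is the identity in eval mode and only applies the masked XDropout path while the
+# module is in training mode.
+def _example_stable_dropout():
+    drop = StableDropout(0.1)
+    x = torch.randn(2, 8)
+    drop.eval()
+    assert torch.equal(drop(x), x)  # no-op outside training
+    drop.train()
+    y = drop(x)  # roughly 10% of elements zeroed, survivors scaled by 1 / (1 - 0.1)
+    return y
+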
+
+# Copied from transformers.models.deberta.modeling_deberta.DebertaSelfOutput with DebertaLayerNorm->LayerNorm
+class DebertaV2SelfOutput(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
+ self.LayerNorm = LayerNorm(config.hidden_size, config.layer_norm_eps)
+ self.dropout = StableDropout(config.hidden_dropout_prob)
+
+ def forward(self, hidden_states, input_tensor):
+ hidden_states = self.dense(hidden_states)
+ hidden_states = self.dropout(hidden_states)
+ hidden_states = self.LayerNorm(hidden_states + input_tensor)
+ return hidden_states
+
+
+# Copied from transformers.models.deberta.modeling_deberta.DebertaAttention with Deberta->DebertaV2
+class DebertaV2Attention(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.self = DisentangledSelfAttention(config)
+ self.output = DebertaV2SelfOutput(config)
+ self.config = config
+
+ def forward(
+ self,
+ hidden_states,
+ attention_mask,
+ output_attentions=False,
+ query_states=None,
+ relative_pos=None,
+ rel_embeddings=None,
+ ):
+ self_output = self.self(
+ hidden_states,
+ attention_mask,
+ output_attentions,
+ query_states=query_states,
+ relative_pos=relative_pos,
+ rel_embeddings=rel_embeddings,
+ )
+ if output_attentions:
+ self_output, att_matrix = self_output
+ if query_states is None:
+ query_states = hidden_states
+ attention_output = self.output(self_output, query_states)
+
+ if output_attentions:
+ return (attention_output, att_matrix)
+ else:
+ return attention_output
+
+
+# Copied from transformers.models.bert.modeling_bert.BertIntermediate with Bert->DebertaV2
+class DebertaV2Intermediate(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
+ if isinstance(config.hidden_act, str):
+ self.intermediate_act_fn = ACT2FN[config.hidden_act]
+ else:
+ self.intermediate_act_fn = config.hidden_act
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ hidden_states = self.dense(hidden_states)
+ hidden_states = self.intermediate_act_fn(hidden_states)
+ return hidden_states
+
+
+# Copied from transformers.models.deberta.modeling_deberta.DebertaOutput with DebertaLayerNorm->LayerNorm
+class DebertaV2Output(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
+ self.LayerNorm = LayerNorm(config.hidden_size, config.layer_norm_eps)
+ self.dropout = StableDropout(config.hidden_dropout_prob)
+ self.config = config
+
+ def forward(self, hidden_states, input_tensor):
+ hidden_states = self.dense(hidden_states)
+ hidden_states = self.dropout(hidden_states)
+ hidden_states = self.LayerNorm(hidden_states + input_tensor)
+ return hidden_states
+
+
+# Copied from transformers.models.deberta.modeling_deberta.DebertaLayer with Deberta->DebertaV2
+class DebertaV2Layer(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.attention = DebertaV2Attention(config)
+ self.intermediate = DebertaV2Intermediate(config)
+ self.output = DebertaV2Output(config)
+
+ def forward(
+ self,
+ hidden_states,
+ attention_mask,
+ query_states=None,
+ relative_pos=None,
+ rel_embeddings=None,
+ output_attentions=False,
+ ):
+ attention_output = self.attention(
+ hidden_states,
+ attention_mask,
+ output_attentions=output_attentions,
+ query_states=query_states,
+ relative_pos=relative_pos,
+ rel_embeddings=rel_embeddings,
+ )
+ if output_attentions:
+ attention_output, att_matrix = attention_output
+ intermediate_output = self.intermediate(attention_output)
+ layer_output = self.output(intermediate_output, attention_output)
+ if output_attentions:
+ return (layer_output, att_matrix)
+ else:
+ return layer_output
+
+
+class ConvLayer(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ kernel_size = getattr(config, "conv_kernel_size", 3)
+ groups = getattr(config, "conv_groups", 1)
+ self.conv_act = getattr(config, "conv_act", "tanh")
+ self.conv = nn.Conv1d(
+ config.hidden_size, config.hidden_size, kernel_size, padding=(kernel_size - 1) // 2, groups=groups
+ )
+ self.LayerNorm = LayerNorm(config.hidden_size, config.layer_norm_eps)
+ self.dropout = StableDropout(config.hidden_dropout_prob)
+ self.config = config
+
+ def forward(self, hidden_states, residual_states, input_mask):
+ out = self.conv(hidden_states.permute(0, 2, 1).contiguous()).permute(0, 2, 1).contiguous()
+ rmask = (1 - input_mask).bool()
+ out.masked_fill_(rmask.unsqueeze(-1).expand(out.size()), 0)
+ out = ACT2FN[self.conv_act](self.dropout(out))
+
+ layer_norm_input = residual_states + out
+ output = self.LayerNorm(layer_norm_input).to(layer_norm_input)
+
+ if input_mask is None:
+ output_states = output
+ else:
+ if input_mask.dim() != layer_norm_input.dim():
+ if input_mask.dim() == 4:
+ input_mask = input_mask.squeeze(1).squeeze(1)
+ input_mask = input_mask.unsqueeze(2)
+
+ input_mask = input_mask.to(output.dtype)
+ output_states = output * input_mask
+
+ return output_states
+
+
+class DebertaV2Encoder(nn.Module):
+ """Modified BertEncoder with relative position bias support"""
+
+ def __init__(self, config):
+ super().__init__()
+
+ self.layer = nn.ModuleList([DebertaV2Layer(config) for _ in range(config.num_hidden_layers)])
+ self.relative_attention = getattr(config, "relative_attention", False)
+
+ if self.relative_attention:
+ self.max_relative_positions = getattr(config, "max_relative_positions", -1)
+ if self.max_relative_positions < 1:
+ self.max_relative_positions = config.max_position_embeddings
+
+ self.position_buckets = getattr(config, "position_buckets", -1)
+ pos_ebd_size = self.max_relative_positions * 2
+
+ if self.position_buckets > 0:
+ pos_ebd_size = self.position_buckets * 2
+
+ self.rel_embeddings = nn.Embedding(pos_ebd_size, config.hidden_size)
+
+ self.norm_rel_ebd = [x.strip() for x in getattr(config, "norm_rel_ebd", "none").lower().split("|")]
+
+ if "layer_norm" in self.norm_rel_ebd:
+ self.LayerNorm = LayerNorm(config.hidden_size, config.layer_norm_eps, elementwise_affine=True)
+
+ self.conv = ConvLayer(config) if getattr(config, "conv_kernel_size", 0) > 0 else None
+ self.gradient_checkpointing = False
+
+ def get_rel_embedding(self):
+ rel_embeddings = self.rel_embeddings.weight if self.relative_attention else None
+ if rel_embeddings is not None and ("layer_norm" in self.norm_rel_ebd):
+ rel_embeddings = self.LayerNorm(rel_embeddings)
+ return rel_embeddings
+
+ def get_attention_mask(self, attention_mask):
+ if attention_mask.dim() <= 2:
+ extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
+ attention_mask = extended_attention_mask * extended_attention_mask.squeeze(-2).unsqueeze(-1)
+ elif attention_mask.dim() == 3:
+ attention_mask = attention_mask.unsqueeze(1)
+
+ return attention_mask
+
+ def get_rel_pos(self, hidden_states, query_states=None, relative_pos=None):
+ if self.relative_attention and relative_pos is None:
+ q = query_states.size(-2) if query_states is not None else hidden_states.size(-2)
+ relative_pos = build_relative_position(
+ q,
+ hidden_states.size(-2),
+ bucket_size=self.position_buckets,
+ max_position=self.max_relative_positions,
+ device=hidden_states.device,
+ )
+ return relative_pos
+
+ def forward(
+ self,
+ hidden_states,
+ attention_mask,
+ output_hidden_states=True,
+ output_attentions=False,
+ query_states=None,
+ relative_pos=None,
+ return_dict=True,
+ ):
+ if attention_mask.dim() <= 2:
+ input_mask = attention_mask
+ else:
+ input_mask = attention_mask.sum(-2) > 0
+ attention_mask = self.get_attention_mask(attention_mask)
+ relative_pos = self.get_rel_pos(hidden_states, query_states, relative_pos)
+
+ all_hidden_states = () if output_hidden_states else None
+ all_attentions = () if output_attentions else None
+
+ if isinstance(hidden_states, Sequence):
+ next_kv = hidden_states[0]
+ else:
+ next_kv = hidden_states
+ rel_embeddings = self.get_rel_embedding()
+ output_states = next_kv
+ for i, layer_module in enumerate(self.layer):
+ if output_hidden_states:
+ all_hidden_states = all_hidden_states + (output_states,)
+
+ if self.gradient_checkpointing and self.training:
+
+ def create_custom_forward(module):
+ def custom_forward(*inputs):
+ return module(*inputs, output_attentions)
+
+ return custom_forward
+
+ output_states = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(layer_module),
+ next_kv,
+ attention_mask,
+ query_states,
+ relative_pos,
+ rel_embeddings,
+ )
+ else:
+ output_states = layer_module(
+ next_kv,
+ attention_mask,
+ query_states=query_states,
+ relative_pos=relative_pos,
+ rel_embeddings=rel_embeddings,
+ output_attentions=output_attentions,
+ )
+
+ if output_attentions:
+ output_states, att_m = output_states
+
+ if i == 0 and self.conv is not None:
+ output_states = self.conv(hidden_states, output_states, input_mask)
+
+ if query_states is not None:
+ query_states = output_states
+ if isinstance(hidden_states, Sequence):
+ next_kv = hidden_states[i + 1] if i + 1 < len(self.layer) else None
+ else:
+ next_kv = output_states
+
+ if output_attentions:
+ all_attentions = all_attentions + (att_m,)
+
+ if output_hidden_states:
+ all_hidden_states = all_hidden_states + (output_states,)
+
+ if not return_dict:
+ return tuple(v for v in [output_states, all_hidden_states, all_attentions] if v is not None)
+ return BaseModelOutput(
+ last_hidden_state=output_states, hidden_states=all_hidden_states, attentions=all_attentions
+ )
+
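+
+# Editor's note: an illustrative sketch (not part of the upstream file) of the mask
+# expansion done by `get_attention_mask` above: a [batch, seq] padding mask becomes a
+# [batch, 1, seq, seq] matrix that only lets real tokens attend to other real tokens.
+def _example_expand_attention_mask():
+    attention_mask = torch.tensor([[1, 1, 1, 0]])  # one sequence, last token is padding
+    extended = attention_mask.unsqueeze(1).unsqueeze(2)
+    expanded = extended * extended.squeeze(-2).unsqueeze(-1)
+    return expanded.shape  # torch.Size([1, 1, 4, 4])
+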
+
+def make_log_bucket_position(relative_pos, bucket_size, max_position):
+ sign = torch.sign(relative_pos)
+ mid = bucket_size // 2
+ abs_pos = torch.where(
+ (relative_pos < mid) & (relative_pos > -mid),
+ torch.tensor(mid - 1).type_as(relative_pos),
+ torch.abs(relative_pos),
+ )
+ log_pos = (
+ torch.ceil(torch.log(abs_pos / mid) / torch.log(torch.tensor((max_position - 1) / mid)) * (mid - 1)) + mid
+ )
+ bucket_pos = torch.where(abs_pos <= mid, relative_pos.type_as(log_pos), log_pos * sign)
+ return bucket_pos
+
+
+def build_relative_position(query_size, key_size, bucket_size=-1, max_position=-1, device=None):
+ """
+ Build relative position according to the query and key
+
+ We assume the absolute position of query \\(P_q\\) is range from (0, query_size) and the absolute position of key
+ \\(P_k\\) is range from (0, key_size), The relative positions from query to key is \\(R_{q \\rightarrow k} = P_q -
+ P_k\\)
+
+ Args:
+ query_size (int): the length of query
+ key_size (int): the length of key
+ bucket_size (int): the size of position bucket
+ max_position (int): the maximum allowed absolute position
+ device (`torch.device`): the device on which tensors will be created.
+
+ Return:
+ `torch.LongTensor`: A tensor with shape [1, query_size, key_size]
+ """
+
+ q_ids = torch.arange(0, query_size, device=device)
+ k_ids = torch.arange(0, key_size, device=device)
+ rel_pos_ids = q_ids[:, None] - k_ids[None, :]
+ if bucket_size > 0 and max_position > 0:
+ rel_pos_ids = make_log_bucket_position(rel_pos_ids, bucket_size, max_position)
+ rel_pos_ids = rel_pos_ids.to(torch.long)
+ rel_pos_ids = rel_pos_ids[:query_size, :]
+ rel_pos_ids = rel_pos_ids.unsqueeze(0)
+ return rel_pos_ids
+
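+
+# Editor's note: an illustrative sketch (not part of the upstream file). For
+# query_size == key_size == 3 and no bucketing, the relative position of key j with
+# respect to query i is simply i - j.
+def _example_relative_positions():
+    rel = build_relative_position(3, 3)
+    # tensor([[[ 0, -1, -2],
+    #          [ 1,  0, -1],
+    #          [ 2,  1,  0]]])  -> shape [1, query_size, key_size]
+    return rel
+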
+
+@torch.jit.script
+# Copied from transformers.models.deberta.modeling_deberta.c2p_dynamic_expand
+def c2p_dynamic_expand(c2p_pos, query_layer, relative_pos):
+ return c2p_pos.expand([query_layer.size(0), query_layer.size(1), query_layer.size(2), relative_pos.size(-1)])
+
+
+@torch.jit.script
+# Copied from transformers.models.deberta.modeling_deberta.p2c_dynamic_expand
+def p2c_dynamic_expand(c2p_pos, query_layer, key_layer):
+ return c2p_pos.expand([query_layer.size(0), query_layer.size(1), key_layer.size(-2), key_layer.size(-2)])
+
+
+@torch.jit.script
+# Copied from transformers.models.deberta.modeling_deberta.pos_dynamic_expand
+def pos_dynamic_expand(pos_index, p2c_att, key_layer):
+ return pos_index.expand(p2c_att.size()[:2] + (pos_index.size(-2), key_layer.size(-2)))
+
+
+class DisentangledSelfAttention(nn.Module):
+ """
+ Disentangled self-attention module
+
+ Parameters:
+ config (`DebertaV2Config`):
+            A model config class instance with the configuration to build a new model. The schema is similar to
+            *BertConfig*; for more details, please refer to [`DebertaV2Config`].
+
+ """
+
+ def __init__(self, config):
+ super().__init__()
+ if config.hidden_size % config.num_attention_heads != 0:
+ raise ValueError(
+ f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
+ f"heads ({config.num_attention_heads})"
+ )
+ self.num_attention_heads = config.num_attention_heads
+ _attention_head_size = config.hidden_size // config.num_attention_heads
+ self.attention_head_size = getattr(config, "attention_head_size", _attention_head_size)
+ self.all_head_size = self.num_attention_heads * self.attention_head_size
+ self.query_proj = nn.Linear(config.hidden_size, self.all_head_size, bias=True)
+ self.key_proj = nn.Linear(config.hidden_size, self.all_head_size, bias=True)
+ self.value_proj = nn.Linear(config.hidden_size, self.all_head_size, bias=True)
+
+ self.share_att_key = getattr(config, "share_att_key", False)
+ self.pos_att_type = config.pos_att_type if config.pos_att_type is not None else []
+ self.relative_attention = getattr(config, "relative_attention", False)
+
+ if self.relative_attention:
+ self.position_buckets = getattr(config, "position_buckets", -1)
+ self.max_relative_positions = getattr(config, "max_relative_positions", -1)
+ if self.max_relative_positions < 1:
+ self.max_relative_positions = config.max_position_embeddings
+ self.pos_ebd_size = self.max_relative_positions
+ if self.position_buckets > 0:
+ self.pos_ebd_size = self.position_buckets
+
+ self.pos_dropout = StableDropout(config.hidden_dropout_prob)
+
+ if not self.share_att_key:
+ if "c2p" in self.pos_att_type:
+ self.pos_key_proj = nn.Linear(config.hidden_size, self.all_head_size, bias=True)
+ if "p2c" in self.pos_att_type:
+ self.pos_query_proj = nn.Linear(config.hidden_size, self.all_head_size)
+
+ self.dropout = StableDropout(config.attention_probs_dropout_prob)
+
+ def transpose_for_scores(self, x, attention_heads):
+ new_x_shape = x.size()[:-1] + (attention_heads, -1)
+ x = x.view(new_x_shape)
+ return x.permute(0, 2, 1, 3).contiguous().view(-1, x.size(1), x.size(-1))
+
+ def forward(
+ self,
+ hidden_states,
+ attention_mask,
+ output_attentions=False,
+ query_states=None,
+ relative_pos=None,
+ rel_embeddings=None,
+ ):
+ """
+ Call the module
+
+ Args:
+ hidden_states (`torch.FloatTensor`):
+                Input states to the module, usually the output from the previous layer; it will be the Q, K and V in
+                *Attention(Q,K,V)*.
+
+ attention_mask (`torch.BoolTensor`):
+ An attention mask matrix of shape [*B*, *N*, *N*] where *B* is the batch size, *N* is the maximum
+ sequence length in which element [i,j] = *1* means the *i* th token in the input can attend to the *j*
+ th token.
+
+ output_attentions (`bool`, optional):
+                Whether to return the attention matrix.
+
+ query_states (`torch.FloatTensor`, optional):
+ The *Q* state in *Attention(Q,K,V)*.
+
+ relative_pos (`torch.LongTensor`):
+ The relative position encoding between the tokens in the sequence. It's of shape [*B*, *N*, *N*] with
+ values ranging in [*-max_relative_positions*, *max_relative_positions*].
+
+ rel_embeddings (`torch.FloatTensor`):
+ The embedding of relative distances. It's a tensor of shape [\\(2 \\times
+ \\text{max_relative_positions}\\), *hidden_size*].
+
+
+ """
+ if query_states is None:
+ query_states = hidden_states
+ query_layer = self.transpose_for_scores(self.query_proj(query_states), self.num_attention_heads)
+ key_layer = self.transpose_for_scores(self.key_proj(hidden_states), self.num_attention_heads)
+ value_layer = self.transpose_for_scores(self.value_proj(hidden_states), self.num_attention_heads)
+
+ rel_att = None
+ # Take the dot product between "query" and "key" to get the raw attention scores.
+ scale_factor = 1
+ if "c2p" in self.pos_att_type:
+ scale_factor += 1
+ if "p2c" in self.pos_att_type:
+ scale_factor += 1
+ scale = torch.sqrt(torch.tensor(query_layer.size(-1), dtype=torch.float) * scale_factor)
+ attention_scores = torch.bmm(query_layer, key_layer.transpose(-1, -2)) / scale.to(dtype=query_layer.dtype)
+ if self.relative_attention:
+ rel_embeddings = self.pos_dropout(rel_embeddings)
+ rel_att = self.disentangled_attention_bias(
+ query_layer, key_layer, relative_pos, rel_embeddings, scale_factor
+ )
+
+ if rel_att is not None:
+ attention_scores = attention_scores + rel_att
+ attention_scores = attention_scores.view(
+ -1, self.num_attention_heads, attention_scores.size(-2), attention_scores.size(-1)
+ )
+
+ # bsz x height x length x dimension
+ attention_probs = XSoftmax.apply(attention_scores, attention_mask, -1)
+ attention_probs = self.dropout(attention_probs)
+ context_layer = torch.bmm(
+ attention_probs.view(-1, attention_probs.size(-2), attention_probs.size(-1)), value_layer
+ )
+ context_layer = (
+ context_layer.view(-1, self.num_attention_heads, context_layer.size(-2), context_layer.size(-1))
+ .permute(0, 2, 1, 3)
+ .contiguous()
+ )
+ new_context_layer_shape = context_layer.size()[:-2] + (-1,)
+ context_layer = context_layer.view(new_context_layer_shape)
+ if output_attentions:
+ return (context_layer, attention_probs)
+ else:
+ return context_layer
+
+ def disentangled_attention_bias(self, query_layer, key_layer, relative_pos, rel_embeddings, scale_factor):
+ if relative_pos is None:
+ q = query_layer.size(-2)
+ relative_pos = build_relative_position(
+ q,
+ key_layer.size(-2),
+ bucket_size=self.position_buckets,
+ max_position=self.max_relative_positions,
+ device=query_layer.device,
+ )
+ if relative_pos.dim() == 2:
+ relative_pos = relative_pos.unsqueeze(0).unsqueeze(0)
+ elif relative_pos.dim() == 3:
+ relative_pos = relative_pos.unsqueeze(1)
+ # bsz x height x query x key
+ elif relative_pos.dim() != 4:
+ raise ValueError(f"Relative position ids must be of dim 2 or 3 or 4. {relative_pos.dim()}")
+
+ att_span = self.pos_ebd_size
+ relative_pos = relative_pos.long().to(query_layer.device)
+
+ rel_embeddings = rel_embeddings[0 : att_span * 2, :].unsqueeze(0)
+ if self.share_att_key:
+ pos_query_layer = self.transpose_for_scores(
+ self.query_proj(rel_embeddings), self.num_attention_heads
+ ).repeat(query_layer.size(0) // self.num_attention_heads, 1, 1)
+ pos_key_layer = self.transpose_for_scores(self.key_proj(rel_embeddings), self.num_attention_heads).repeat(
+ query_layer.size(0) // self.num_attention_heads, 1, 1
+ )
+ else:
+ if "c2p" in self.pos_att_type:
+ pos_key_layer = self.transpose_for_scores(
+ self.pos_key_proj(rel_embeddings), self.num_attention_heads
+ ).repeat(
+ query_layer.size(0) // self.num_attention_heads, 1, 1
+ ) # .split(self.all_head_size, dim=-1)
+ if "p2c" in self.pos_att_type:
+ pos_query_layer = self.transpose_for_scores(
+ self.pos_query_proj(rel_embeddings), self.num_attention_heads
+ ).repeat(
+ query_layer.size(0) // self.num_attention_heads, 1, 1
+ ) # .split(self.all_head_size, dim=-1)
+
+ score = 0
+ # content->position
+ if "c2p" in self.pos_att_type:
+ scale = torch.sqrt(torch.tensor(pos_key_layer.size(-1), dtype=torch.float) * scale_factor)
+ c2p_att = torch.bmm(query_layer, pos_key_layer.transpose(-1, -2))
+ c2p_pos = torch.clamp(relative_pos + att_span, 0, att_span * 2 - 1)
+ c2p_att = torch.gather(
+ c2p_att,
+ dim=-1,
+ index=c2p_pos.squeeze(0).expand([query_layer.size(0), query_layer.size(1), relative_pos.size(-1)]),
+ )
+ score += c2p_att / scale.to(dtype=c2p_att.dtype)
+
+ # position->content
+ if "p2c" in self.pos_att_type:
+ scale = torch.sqrt(torch.tensor(pos_query_layer.size(-1), dtype=torch.float) * scale_factor)
+ if key_layer.size(-2) != query_layer.size(-2):
+ r_pos = build_relative_position(
+ key_layer.size(-2),
+ key_layer.size(-2),
+ bucket_size=self.position_buckets,
+ max_position=self.max_relative_positions,
+ device=query_layer.device,
+ )
+ r_pos = r_pos.unsqueeze(0)
+ else:
+ r_pos = relative_pos
+
+ p2c_pos = torch.clamp(-r_pos + att_span, 0, att_span * 2 - 1)
+ p2c_att = torch.bmm(key_layer, pos_query_layer.transpose(-1, -2))
+ p2c_att = torch.gather(
+ p2c_att,
+ dim=-1,
+ index=p2c_pos.squeeze(0).expand([query_layer.size(0), key_layer.size(-2), key_layer.size(-2)]),
+ ).transpose(-1, -2)
+ score += p2c_att / scale.to(dtype=p2c_att.dtype)
+
+ return score
+
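+
+# Editor's note: an illustrative sketch (not part of the upstream file) of the scaling
+# used in the attention above. With pos_att_type == ["c2p", "p2c"] three score terms
+# (content-to-content, content-to-position, position-to-content) are summed, so each
+# term is divided by sqrt(d * 3) rather than the usual sqrt(d).
+def _example_attention_scale(head_dim=64):
+    scale_factor = 1 + 2  # content-to-content plus the two disentangled terms
+    scale = torch.sqrt(torch.tensor(head_dim, dtype=torch.float) * scale_factor)
+    return scale  # sqrt(64 * 3) ~= 13.86
+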
+
+# Copied from transformers.models.deberta.modeling_deberta.DebertaEmbeddings with DebertaLayerNorm->LayerNorm
+class DebertaV2Embeddings(nn.Module):
+ """Construct the embeddings from word, position and token_type embeddings."""
+
+ def __init__(self, config):
+ super().__init__()
+ pad_token_id = getattr(config, "pad_token_id", 0)
+ self.embedding_size = getattr(config, "embedding_size", config.hidden_size)
+ self.word_embeddings = nn.Embedding(config.vocab_size, self.embedding_size, padding_idx=pad_token_id)
+
+ self.position_biased_input = getattr(config, "position_biased_input", True)
+ if not self.position_biased_input:
+ self.position_embeddings = None
+ else:
+ self.position_embeddings = nn.Embedding(config.max_position_embeddings, self.embedding_size)
+
+ if config.type_vocab_size > 0:
+ self.token_type_embeddings = nn.Embedding(config.type_vocab_size, self.embedding_size)
+
+ if self.embedding_size != config.hidden_size:
+ self.embed_proj = nn.Linear(self.embedding_size, config.hidden_size, bias=False)
+ self.LayerNorm = LayerNorm(config.hidden_size, config.layer_norm_eps)
+ self.dropout = StableDropout(config.hidden_dropout_prob)
+ self.config = config
+
+ # position_ids (1, len position emb) is contiguous in memory and exported when serialized
+ self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
+
+ def forward(self, input_ids=None, token_type_ids=None, position_ids=None, mask=None, inputs_embeds=None):
+ if input_ids is not None:
+ input_shape = input_ids.size()
+ else:
+ input_shape = inputs_embeds.size()[:-1]
+
+ seq_length = input_shape[1]
+
+ if position_ids is None:
+ position_ids = self.position_ids[:, :seq_length]
+
+ if token_type_ids is None:
+ token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)
+
+ if inputs_embeds is None:
+ inputs_embeds = self.word_embeddings(input_ids)
+
+ if self.position_embeddings is not None:
+ position_embeddings = self.position_embeddings(position_ids.long())
+ else:
+ position_embeddings = torch.zeros_like(inputs_embeds)
+
+ embeddings = inputs_embeds
+ if self.position_biased_input:
+ embeddings += position_embeddings
+ if self.config.type_vocab_size > 0:
+ token_type_embeddings = self.token_type_embeddings(token_type_ids)
+ embeddings += token_type_embeddings
+
+ if self.embedding_size != self.config.hidden_size:
+ embeddings = self.embed_proj(embeddings)
+
+ embeddings = self.LayerNorm(embeddings)
+
+ if mask is not None:
+ if mask.dim() != embeddings.dim():
+ if mask.dim() == 4:
+ mask = mask.squeeze(1).squeeze(1)
+ mask = mask.unsqueeze(2)
+ mask = mask.to(embeddings.dtype)
+
+ embeddings = embeddings * mask
+
+ embeddings = self.dropout(embeddings)
+ return embeddings
+
+
+# Copied from transformers.models.deberta.modeling_deberta.DebertaPreTrainedModel with Deberta->DebertaV2
+class DebertaV2PreTrainedModel(PreTrainedModel):
+ """
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
+ models.
+ """
+
+ config_class = DebertaV2Config
+ base_model_prefix = "deberta"
+ _keys_to_ignore_on_load_missing = ["position_ids"]
+ _keys_to_ignore_on_load_unexpected = ["position_embeddings"]
+ supports_gradient_checkpointing = True
+
+ def _init_weights(self, module):
+ """Initialize the weights."""
+ if isinstance(module, nn.Linear):
+ # Slightly different from the TF version which uses truncated_normal for initialization
+ # cf https://github.com/pytorch/pytorch/pull/5617
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+ if module.bias is not None:
+ module.bias.data.zero_()
+ elif isinstance(module, nn.Embedding):
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+ if module.padding_idx is not None:
+ module.weight.data[module.padding_idx].zero_()
+
+ def _set_gradient_checkpointing(self, module, value=False):
+ if isinstance(module, DebertaV2Encoder):
+ module.gradient_checkpointing = value
+
+
+DEBERTA_START_DOCSTRING = r"""
+ The DeBERTa model was proposed in [DeBERTa: Decoding-enhanced BERT with Disentangled
+    Attention](https://arxiv.org/abs/2006.03654) by Pengcheng He, Xiaodong Liu, Jianfeng Gao, Weizhu Chen. It's built
+    on top of BERT/RoBERTa with two improvements, i.e. disentangled attention and an enhanced mask decoder. With those two
+    improvements, it outperforms BERT/RoBERTa on a majority of tasks with 80GB of pretraining data.
+
+ This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
+ Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
+ and behavior.
+
+
+ Parameters:
+ config ([`DebertaV2Config`]): Model configuration class with all the parameters of the model.
+ Initializing with a config file does not load the weights associated with the model, only the
+ configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
+"""
+
+DEBERTA_INPUTS_DOCSTRING = r"""
+ Args:
+ input_ids (`torch.LongTensor` of shape `({0})`):
+ Indices of input sequence tokens in the vocabulary.
+
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+ [`PreTrainedTokenizer.__call__`] for details.
+
+ [What are input IDs?](../glossary#input-ids)
+ attention_mask (`torch.FloatTensor` of shape `({0})`, *optional*):
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+
+ - 1 for tokens that are **not masked**,
+ - 0 for tokens that are **masked**.
+
+ [What are attention masks?](../glossary#attention-mask)
+ token_type_ids (`torch.LongTensor` of shape `({0})`, *optional*):
+ Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,
+ 1]`:
+
+ - 0 corresponds to a *sentence A* token,
+ - 1 corresponds to a *sentence B* token.
+
+ [What are token type IDs?](../glossary#token-type-ids)
+ position_ids (`torch.LongTensor` of shape `({0})`, *optional*):
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
+ config.max_position_embeddings - 1]`.
+
+ [What are position IDs?](../glossary#position-ids)
+ inputs_embeds (`torch.FloatTensor` of shape `({0}, hidden_size)`, *optional*):
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
+ is useful if you want more control over how to convert *input_ids* indices into associated vectors than the
+ model's internal embedding lookup matrix.
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
+ tensors for more detail.
+ output_hidden_states (`bool`, *optional*):
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
+ more detail.
+ return_dict (`bool`, *optional*):
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+"""
+
+
+@add_start_docstrings(
+ "The bare DeBERTa Model transformer outputting raw hidden-states without any specific head on top.",
+ DEBERTA_START_DOCSTRING,
+)
+# Copied from transformers.models.deberta.modeling_deberta.DebertaModel with Deberta->DebertaV2
+class DebertaV2Model(DebertaV2PreTrainedModel):
+ def __init__(self, config):
+ super().__init__(config)
+
+ self.embeddings = DebertaV2Embeddings(config)
+ self.encoder = DebertaV2Encoder(config)
+ self.z_steps = 0
+ self.config = config
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_input_embeddings(self):
+ return self.embeddings.word_embeddings
+
+ def set_input_embeddings(self, new_embeddings):
+ self.embeddings.word_embeddings = new_embeddings
+
+ def _prune_heads(self, heads_to_prune):
+ """
+        Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer}. See base
+        class PreTrainedModel.
+ """
+ raise NotImplementedError("The prune function is not implemented in DeBERTa model.")
+
+ @add_start_docstrings_to_model_forward(DEBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
+ @add_code_sample_docstrings(
+ checkpoint=_CHECKPOINT_FOR_DOC,
+ output_type=BaseModelOutput,
+ config_class=_CONFIG_FOR_DOC,
+ )
+ def forward(
+ self,
+ input_ids: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ token_type_ids: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.Tensor] = None,
+ inputs_embeds: Optional[torch.Tensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[Tuple, BaseModelOutput]:
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ if input_ids is not None and inputs_embeds is not None:
+ raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
+ elif input_ids is not None:
+ input_shape = input_ids.size()
+ elif inputs_embeds is not None:
+ input_shape = inputs_embeds.size()[:-1]
+ else:
+ raise ValueError("You have to specify either input_ids or inputs_embeds")
+
+ device = input_ids.device if input_ids is not None else inputs_embeds.device
+
+ if attention_mask is None:
+ attention_mask = torch.ones(input_shape, device=device)
+ if token_type_ids is None:
+ token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
+
+ embedding_output = self.embeddings(
+ input_ids=input_ids,
+ token_type_ids=token_type_ids,
+ position_ids=position_ids,
+ mask=attention_mask,
+ inputs_embeds=inputs_embeds,
+ )
+
+ encoder_outputs = self.encoder(
+ embedding_output,
+ attention_mask,
+ output_hidden_states=True,
+ output_attentions=output_attentions,
+ return_dict=return_dict,
+ )
+ encoded_layers = encoder_outputs[1]
+
+ if self.z_steps > 1:
+ hidden_states = encoded_layers[-2]
+ layers = [self.encoder.layer[-1] for _ in range(self.z_steps)]
+ query_states = encoded_layers[-1]
+ rel_embeddings = self.encoder.get_rel_embedding()
+ attention_mask = self.encoder.get_attention_mask(attention_mask)
+ rel_pos = self.encoder.get_rel_pos(embedding_output)
+ for layer in layers[1:]:
+ query_states = layer(
+ hidden_states,
+ attention_mask,
+ output_attentions=False,
+ query_states=query_states,
+ relative_pos=rel_pos,
+ rel_embeddings=rel_embeddings,
+ )
+ encoded_layers.append(query_states)
+
+ sequence_output = encoded_layers[-1]
+
+ if not return_dict:
+ return (sequence_output,) + encoder_outputs[(1 if output_hidden_states else 2) :]
+
+ return BaseModelOutput(
+ last_hidden_state=sequence_output,
+ hidden_states=encoder_outputs.hidden_states if output_hidden_states else None,
+ attentions=encoder_outputs.attentions,
+ )
+
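+
+# Editor's note: an illustrative sketch (not part of the upstream file) using a
+# hypothetical tiny randomly initialized config; the bare model returns one hidden
+# state per input token.
+def _example_deberta_v2_model():
+    config = DebertaV2Config(
+        vocab_size=128,
+        hidden_size=32,
+        num_hidden_layers=2,
+        num_attention_heads=2,
+        intermediate_size=64,
+        max_position_embeddings=64,
+    )
+    model = DebertaV2Model(config).eval()
+    with torch.no_grad():
+        out = model(input_ids=torch.randint(0, 128, (1, 6)))
+    return out.last_hidden_state.shape  # torch.Size([1, 6, 32])
+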
+
+@add_start_docstrings("""DeBERTa Model with a `language modeling` head on top.""", DEBERTA_START_DOCSTRING)
+class DebertaV2ForMaskedLM(DebertaV2PreTrainedModel):
+ _keys_to_ignore_on_load_unexpected = [r"pooler"]
+ _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias", "cls.predictions.decoder.weight"]
+
+ def __init__(self, config):
+ super().__init__(config)
+
+ self.deberta = DebertaV2Model(config)
+ self.cls = DebertaV2OnlyMLMHead(config)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_output_embeddings(self):
+ return self.cls.predictions.decoder
+
+ def set_output_embeddings(self, new_embeddings):
+ self.cls.predictions.decoder = new_embeddings
+
+ @add_start_docstrings_to_model_forward(DEBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
+ @add_code_sample_docstrings(
+ checkpoint=_CHECKPOINT_FOR_DOC,
+ output_type=MaskedLMOutput,
+ config_class=_CONFIG_FOR_DOC,
+ mask="[MASK]",
+ )
+ # Copied from transformers.models.deberta.modeling_deberta.DebertaForMaskedLM.forward with Deberta->DebertaV2
+ def forward(
+ self,
+ input_ids: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ token_type_ids: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.Tensor] = None,
+ inputs_embeds: Optional[torch.Tensor] = None,
+ labels: Optional[torch.Tensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[Tuple, MaskedLMOutput]:
+ r"""
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
+ config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the
+ loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`
+ """
+
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ outputs = self.deberta(
+ input_ids,
+ attention_mask=attention_mask,
+ token_type_ids=token_type_ids,
+ position_ids=position_ids,
+ inputs_embeds=inputs_embeds,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+ sequence_output = outputs[0]
+ prediction_scores = self.cls(sequence_output)
+
+ masked_lm_loss = None
+ if labels is not None:
+ loss_fct = CrossEntropyLoss() # -100 index = padding token
+ masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
+
+ if not return_dict:
+ output = (prediction_scores,) + outputs[1:]
+ return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
+
+ return MaskedLMOutput(
+ loss=masked_lm_loss,
+ logits=prediction_scores,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
+
+
+# copied from transformers.models.bert.BertPredictionHeadTransform with bert -> deberta
+class DebertaV2PredictionHeadTransform(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
+ if isinstance(config.hidden_act, str):
+ self.transform_act_fn = ACT2FN[config.hidden_act]
+ else:
+ self.transform_act_fn = config.hidden_act
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+
+ def forward(self, hidden_states):
+ hidden_states = self.dense(hidden_states)
+ hidden_states = self.transform_act_fn(hidden_states)
+ hidden_states = self.LayerNorm(hidden_states)
+ return hidden_states
+
+
+# copied from transformers.models.bert.BertLMPredictionHead with bert -> deberta
+class DebertaV2LMPredictionHead(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.transform = DebertaV2PredictionHeadTransform(config)
+
+ # The output weights are the same as the input embeddings, but there is
+ # an output-only bias for each token.
+ self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+
+ self.bias = nn.Parameter(torch.zeros(config.vocab_size))
+
+ # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
+ self.decoder.bias = self.bias
+
+ def forward(self, hidden_states):
+ hidden_states = self.transform(hidden_states)
+ hidden_states = self.decoder(hidden_states)
+ return hidden_states
+
+
+# copied from transformers.models.bert.BertOnlyMLMHead with bert -> deberta
+class DebertaV2OnlyMLMHead(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.predictions = DebertaV2LMPredictionHead(config)
+
+ def forward(self, sequence_output):
+ prediction_scores = self.predictions(sequence_output)
+ return prediction_scores
+
+
+@add_start_docstrings(
+ """
+ DeBERTa Model transformer with a sequence classification/regression head on top (a linear layer on top of the
+ pooled output) e.g. for GLUE tasks.
+ """,
+ DEBERTA_START_DOCSTRING,
+)
+class DebertaV2ForSequenceClassification(DebertaV2PreTrainedModel):
+ def __init__(self, config):
+ super().__init__(config)
+
+ num_labels = getattr(config, "num_labels", 2)
+ self.num_labels = num_labels
+
+ self.deberta = DebertaV2Model(config)
+ self.pooler = ContextPooler(config)
+ output_dim = self.pooler.output_dim
+
+ self.classifier = nn.Linear(output_dim, num_labels)
+ drop_out = getattr(config, "cls_dropout", None)
+ drop_out = self.config.hidden_dropout_prob if drop_out is None else drop_out
+ self.dropout = StableDropout(drop_out)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_input_embeddings(self):
+ return self.deberta.get_input_embeddings()
+
+ def set_input_embeddings(self, new_embeddings):
+ self.deberta.set_input_embeddings(new_embeddings)
+
+ @add_start_docstrings_to_model_forward(DEBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
+ @add_code_sample_docstrings(
+ checkpoint=_CHECKPOINT_FOR_DOC,
+ output_type=SequenceClassifierOutput,
+ config_class=_CONFIG_FOR_DOC,
+ )
+ # Copied from transformers.models.deberta.modeling_deberta.DebertaForSequenceClassification.forward with Deberta->DebertaV2
+ def forward(
+ self,
+ input_ids: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ token_type_ids: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.Tensor] = None,
+ inputs_embeds: Optional[torch.Tensor] = None,
+ labels: Optional[torch.Tensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[Tuple, SequenceClassifierOutput]:
+ r"""
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+ Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
+ """
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ outputs = self.deberta(
+ input_ids,
+ token_type_ids=token_type_ids,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ inputs_embeds=inputs_embeds,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+ encoder_layer = outputs[0]
+ pooled_output = self.pooler(encoder_layer)
+ pooled_output = self.dropout(pooled_output)
+ logits = self.classifier(pooled_output)
+
+ loss = None
+ if labels is not None:
+ if self.config.problem_type is None:
+ if self.num_labels == 1:
+ # regression task
+ loss_fn = nn.MSELoss()
+ logits = logits.view(-1).to(labels.dtype)
+ loss = loss_fn(logits, labels.view(-1))
+ elif labels.dim() == 1 or labels.size(-1) == 1:
+ label_index = (labels >= 0).nonzero()
+ labels = labels.long()
+ if label_index.size(0) > 0:
+ labeled_logits = torch.gather(
+ logits, 0, label_index.expand(label_index.size(0), logits.size(1))
+ )
+ labels = torch.gather(labels, 0, label_index.view(-1))
+ loss_fct = CrossEntropyLoss()
+ loss = loss_fct(labeled_logits.view(-1, self.num_labels).float(), labels.view(-1))
+ else:
+ loss = torch.tensor(0).to(logits)
+ else:
+ log_softmax = nn.LogSoftmax(-1)
+ loss = -((log_softmax(logits) * labels).sum(-1)).mean()
+ elif self.config.problem_type == "regression":
+ loss_fct = MSELoss()
+ if self.num_labels == 1:
+ loss = loss_fct(logits.squeeze(), labels.squeeze())
+ else:
+ loss = loss_fct(logits, labels)
+ elif self.config.problem_type == "single_label_classification":
+ loss_fct = CrossEntropyLoss()
+ loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
+ elif self.config.problem_type == "multi_label_classification":
+ loss_fct = BCEWithLogitsLoss()
+ loss = loss_fct(logits, labels)
+ if not return_dict:
+ output = (logits,) + outputs[1:]
+ return ((loss,) + output) if loss is not None else output
+
+ return SequenceClassifierOutput(
+ loss=loss, logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions
+ )
+
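+
+# Editor's note: an illustrative sketch (not part of the upstream file) of how the loss
+# branch above is selected; the tiny config values are hypothetical.
+def _example_sequence_classification_loss():
+    config = DebertaV2Config(
+        vocab_size=128,
+        hidden_size=32,
+        num_hidden_layers=2,
+        num_attention_heads=2,
+        intermediate_size=64,
+        max_position_embeddings=64,
+        num_labels=3,
+        problem_type="single_label_classification",  # selects the CrossEntropyLoss branch
+    )
+    model = DebertaV2ForSequenceClassification(config).eval()
+    out = model(input_ids=torch.randint(0, 128, (2, 6)), labels=torch.tensor([0, 2]))
+    return out.loss, out.logits.shape  # scalar loss, torch.Size([2, 3])
+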
+
+@add_start_docstrings(
+ """
+ DeBERTa Model with a token classification head on top (a linear layer on top of the hidden-states output) e.g. for
+ Named-Entity-Recognition (NER) tasks.
+ """,
+ DEBERTA_START_DOCSTRING,
+)
+# Copied from transformers.models.deberta.modeling_deberta.DebertaForTokenClassification with Deberta->DebertaV2
+class DebertaV2ForTokenClassification(DebertaV2PreTrainedModel):
+ _keys_to_ignore_on_load_unexpected = [r"pooler"]
+
+ def __init__(self, config):
+ super().__init__(config)
+ self.num_labels = config.num_labels
+
+ self.deberta = DebertaV2Model(config)
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
+ self.classifier = nn.Linear(config.hidden_size, config.num_labels)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ @add_start_docstrings_to_model_forward(DEBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
+ @add_code_sample_docstrings(
+ checkpoint=_CHECKPOINT_FOR_DOC,
+ output_type=TokenClassifierOutput,
+ config_class=_CONFIG_FOR_DOC,
+ )
+ def forward(
+ self,
+ input_ids: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ token_type_ids: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.Tensor] = None,
+ inputs_embeds: Optional[torch.Tensor] = None,
+ labels: Optional[torch.Tensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[Tuple, TokenClassifierOutput]:
+ r"""
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
+ """
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ outputs = self.deberta(
+ input_ids,
+ attention_mask=attention_mask,
+ token_type_ids=token_type_ids,
+ position_ids=position_ids,
+ inputs_embeds=inputs_embeds,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+ sequence_output = outputs[0]
+
+ sequence_output = self.dropout(sequence_output)
+ logits = self.classifier(sequence_output)
+
+ loss = None
+ if labels is not None:
+ loss_fct = CrossEntropyLoss()
+ loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
+
+ if not return_dict:
+ output = (logits,) + outputs[1:]
+ return ((loss,) + output) if loss is not None else output
+
+ return TokenClassifierOutput(
+ loss=loss, logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions
+ )
+
+
+@dataclass
+class TokenClassifierRegressionOutput(ModelOutput):
+ """
+ Base class for outputs of token classification models.
+
+ Args:
+ loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided) :
+ Classification loss.
+        logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.num_labels)`):
+            Classification scores (before SoftMax).
+        values (`torch.FloatTensor` of shape `(batch_size, sequence_length, 1)`):
+            Per-token regression outputs.
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
+ one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
+
+ Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
+ attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+ sequence_length)`.
+
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
+ heads.
+ """
+
+ loss: Optional[torch.FloatTensor] = None
+ logits: torch.FloatTensor = None
+ values: torch.FloatTensor = None
+ hidden_states: Optional[Tuple[torch.FloatTensor]] = None
+ attentions: Optional[Tuple[torch.FloatTensor]] = None
+
+class DebertaV2ForTokenClassificationRegression(DebertaV2PreTrainedModel):
+ _keys_to_ignore_on_load_unexpected = [r"pooler"]
+
+ def __init__(self, config):
+ super().__init__(config)
+ self.num_labels = 4
+
+ self.deberta = DebertaV2Model(config)
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
+
+ self.hidden1 = nn.Linear(config.hidden_size, config.hidden_size)
+ self.classifier = nn.Linear(config.hidden_size, self.num_labels)
+
+ self.hidden2 = nn.Linear(config.hidden_size, config.hidden_size)
+ self.regressor = nn.Linear(config.hidden_size, 1)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ @add_start_docstrings_to_model_forward(DEBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
+ @add_code_sample_docstrings(
+ checkpoint=_CHECKPOINT_FOR_DOC,
+        output_type=TokenClassifierRegressionOutput,
+ config_class=_CONFIG_FOR_DOC,
+ )
+ def forward(
+ self,
+ input_ids: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ token_type_ids: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.Tensor] = None,
+ inputs_embeds: Optional[torch.Tensor] = None,
+ labels: Optional[torch.Tensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[Tuple, TokenClassifierOutput]:
+ r"""
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
+ """
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ outputs = self.deberta(
+ input_ids,
+ attention_mask=attention_mask,
+ token_type_ids=token_type_ids,
+ position_ids=position_ids,
+ inputs_embeds=inputs_embeds,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+ sequence_output = outputs[0]
+
+ sequence_output = self.dropout(sequence_output)
+
+ logits = self.classifier(self.hidden1(sequence_output))
+ values = self.regressor(self.hidden2(sequence_output))
+
+ loss = None
+ if labels is not None:
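+            # Only the classification head is supervised here; the regression `values`
+            # are returned to the caller but do not contribute to this loss.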
+ loss_fct = CrossEntropyLoss()
+ loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
+
+ if not return_dict:
+            output = (logits, values) + outputs[1:]
+ return ((loss,) + output) if loss is not None else output
+
+ return TokenClassifierRegressionOutput(
+ loss=loss, logits=logits, values=values, hidden_states=outputs.hidden_states, attentions=outputs.attentions
+ )
+
+
+@add_start_docstrings(
+ """
+    DeBERTa Model with a span classification head on top for extractive question-answering tasks like SQuAD (linear
+    layers on top of the hidden-states output to compute `span start logits` and `span end logits`).
+ """,
+ DEBERTA_START_DOCSTRING,
+)
+class DebertaV2ForQuestionAnswering(DebertaV2PreTrainedModel):
+ _keys_to_ignore_on_load_unexpected = [r"pooler"]
+
+ def __init__(self, config):
+ super().__init__(config)
+ self.num_labels = config.num_labels
+
+ self.deberta = DebertaV2Model(config)
+ self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ @add_start_docstrings_to_model_forward(DEBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
+ @add_code_sample_docstrings(
+ checkpoint=_CHECKPOINT_FOR_DOC,
+ output_type=QuestionAnsweringModelOutput,
+ config_class=_CONFIG_FOR_DOC,
+ qa_target_start_index=_QA_TARGET_START_INDEX,
+ qa_target_end_index=_QA_TARGET_END_INDEX,
+ )
+ # Copied from transformers.models.deberta.modeling_deberta.DebertaForQuestionAnswering.forward with Deberta->DebertaV2
+ def forward(
+ self,
+ input_ids: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ token_type_ids: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.Tensor] = None,
+ inputs_embeds: Optional[torch.Tensor] = None,
+ start_positions: Optional[torch.Tensor] = None,
+ end_positions: Optional[torch.Tensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[Tuple, QuestionAnsweringModelOutput]:
+ r"""
+ start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+ Labels for position (index) of the start of the labelled span for computing the token classification loss.
+ Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
+ are not taken into account for computing the loss.
+ end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+ Labels for position (index) of the end of the labelled span for computing the token classification loss.
+ Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
+ are not taken into account for computing the loss.
+ """
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ outputs = self.deberta(
+ input_ids,
+ attention_mask=attention_mask,
+ token_type_ids=token_type_ids,
+ position_ids=position_ids,
+ inputs_embeds=inputs_embeds,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+ sequence_output = outputs[0]
+
+ logits = self.qa_outputs(sequence_output)
+ start_logits, end_logits = logits.split(1, dim=-1)
+ start_logits = start_logits.squeeze(-1).contiguous()
+ end_logits = end_logits.squeeze(-1).contiguous()
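+        # `qa_outputs` produces `num_labels` (expected to be 2) scores per token; split them into
+        # start-span and end-span logits.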
+
+ total_loss = None
+ if start_positions is not None and end_positions is not None:
+            # If we are on multi-GPU, the position tensors may carry an extra dimension; squeeze it away
+ if len(start_positions.size()) > 1:
+ start_positions = start_positions.squeeze(-1)
+ if len(end_positions.size()) > 1:
+ end_positions = end_positions.squeeze(-1)
+ # sometimes the start/end positions are outside our model inputs, we ignore these terms
+ ignored_index = start_logits.size(1)
+ start_positions = start_positions.clamp(0, ignored_index)
+ end_positions = end_positions.clamp(0, ignored_index)
+
+ loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
+ start_loss = loss_fct(start_logits, start_positions)
+ end_loss = loss_fct(end_logits, end_positions)
+ total_loss = (start_loss + end_loss) / 2
+
+ if not return_dict:
+ output = (start_logits, end_logits) + outputs[1:]
+ return ((total_loss,) + output) if total_loss is not None else output
+
+ return QuestionAnsweringModelOutput(
+ loss=total_loss,
+ start_logits=start_logits,
+ end_logits=end_logits,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
+
+
+@add_start_docstrings(
+ """
+ DeBERTa Model with a multiple choice classification head on top (a linear layer on top of the pooled output and a
+ softmax) e.g. for RocStories/SWAG tasks.
+ """,
+ DEBERTA_START_DOCSTRING,
+)
+class DebertaV2ForMultipleChoice(DebertaV2PreTrainedModel):
+ def __init__(self, config):
+ super().__init__(config)
+
+ num_labels = getattr(config, "num_labels", 2)
+ self.num_labels = num_labels
+
+ self.deberta = DebertaV2Model(config)
+ self.pooler = ContextPooler(config)
+ output_dim = self.pooler.output_dim
+
+ self.classifier = nn.Linear(output_dim, 1)
+ drop_out = getattr(config, "cls_dropout", None)
+ drop_out = self.config.hidden_dropout_prob if drop_out is None else drop_out
+ self.dropout = StableDropout(drop_out)
+
+ self.init_weights()
+
+ def get_input_embeddings(self):
+ return self.deberta.get_input_embeddings()
+
+ def set_input_embeddings(self, new_embeddings):
+ self.deberta.set_input_embeddings(new_embeddings)
+
+ @add_start_docstrings_to_model_forward(DEBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
+ @add_code_sample_docstrings(
+ checkpoint=_CHECKPOINT_FOR_DOC,
+ output_type=MultipleChoiceModelOutput,
+ config_class=_CONFIG_FOR_DOC,
+ )
+ def forward(
+ self,
+ input_ids: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ token_type_ids: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.Tensor] = None,
+ inputs_embeds: Optional[torch.Tensor] = None,
+ labels: Optional[torch.Tensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[Tuple, MultipleChoiceModelOutput]:
+ r"""
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+ Labels for computing the multiple choice classification loss. Indices should be in `[0, ...,
+ num_choices-1]` where `num_choices` is the size of the second dimension of the input tensors. (See
+ `input_ids` above)
+ """
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+ num_choices = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1]
+
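+        # Flatten (batch_size, num_choices, seq_len) inputs into (batch_size * num_choices, seq_len)
+        # so that every choice is encoded independently by the shared encoder.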
+ flat_input_ids = input_ids.view(-1, input_ids.size(-1)) if input_ids is not None else None
+ flat_position_ids = position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None
+ flat_token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None
+ flat_attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None
+ flat_inputs_embeds = (
+ inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1))
+ if inputs_embeds is not None
+ else None
+ )
+
+ outputs = self.deberta(
+ flat_input_ids,
+ position_ids=flat_position_ids,
+ token_type_ids=flat_token_type_ids,
+ attention_mask=flat_attention_mask,
+ inputs_embeds=flat_inputs_embeds,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+ encoder_layer = outputs[0]
+ pooled_output = self.pooler(encoder_layer)
+ pooled_output = self.dropout(pooled_output)
+ logits = self.classifier(pooled_output)
+ reshaped_logits = logits.view(-1, num_choices)
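+        # Reshape back to (batch_size, num_choices) so that each example scores all of its choices.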
+
+ loss = None
+ if labels is not None:
+ loss_fct = CrossEntropyLoss()
+ loss = loss_fct(reshaped_logits, labels)
+
+ if not return_dict:
+ output = (reshaped_logits,) + outputs[1:]
+ return ((loss,) + output) if loss is not None else output
+
+ return MultipleChoiceModelOutput(
+ loss=loss,
+ logits=reshaped_logits,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
\ No newline at end of file
diff --git a/.ipynb_checkpoints/models-checkpoint.py b/.ipynb_checkpoints/models-checkpoint.py
index 652d8f290ff0415e03a723d97a238bd0b6a03f4d..259288a96e229be80c97ae62fa45f2d1813d16fa 100644
--- a/.ipynb_checkpoints/models-checkpoint.py
+++ b/.ipynb_checkpoints/models-checkpoint.py
@@ -28,711 +28,713 @@ from diffusers import AutoencoderKL as DiffuserAutoencoderKL
from layers.layers import chord_tokenizer, beat_tokenizer, Chord_Embedding, Beat_Embedding, Music_PositionalEncoding, Fundamental_Music_Embedding
def build_pretrained_models(name):
- checkpoint = torch.load(get_metadata()[name]["path"], map_location="cpu")
- scale_factor = checkpoint["state_dict"]["scale_factor"].item()
+ checkpoint = torch.load(get_metadata()[name]["path"], map_location="cpu")
+ scale_factor = checkpoint["state_dict"]["scale_factor"].item()
- vae_state_dict = {k[18:]: v for k, v in checkpoint["state_dict"].items() if "first_stage_model." in k}
+ vae_state_dict = {k[18:]: v for k, v in checkpoint["state_dict"].items() if "first_stage_model." in k}
- config = default_audioldm_config(name)
- vae_config = config["model"]["params"]["first_stage_config"]["params"]
- vae_config["scale_factor"] = scale_factor
+ config = default_audioldm_config(name)
+ vae_config = config["model"]["params"]["first_stage_config"]["params"]
+ vae_config["scale_factor"] = scale_factor
- vae = AutoencoderKL(**vae_config)
- vae.load_state_dict(vae_state_dict)
+ vae = AutoencoderKL(**vae_config)
+ vae.load_state_dict(vae_state_dict)
- fn_STFT = TacotronSTFT(
- config["preprocessing"]["stft"]["filter_length"],
- config["preprocessing"]["stft"]["hop_length"],
- config["preprocessing"]["stft"]["win_length"],
- config["preprocessing"]["mel"]["n_mel_channels"],
- config["preprocessing"]["audio"]["sampling_rate"],
- config["preprocessing"]["mel"]["mel_fmin"],
- config["preprocessing"]["mel"]["mel_fmax"],
- )
+ fn_STFT = TacotronSTFT(
+ config["preprocessing"]["stft"]["filter_length"],
+ config["preprocessing"]["stft"]["hop_length"],
+ config["preprocessing"]["stft"]["win_length"],
+ config["preprocessing"]["mel"]["n_mel_channels"],
+ config["preprocessing"]["audio"]["sampling_rate"],
+ config["preprocessing"]["mel"]["mel_fmin"],
+ config["preprocessing"]["mel"]["mel_fmax"],
+ )
- vae.eval()
- fn_STFT.eval()
- return vae, fn_STFT
+ vae.eval()
+ fn_STFT.eval()
+ return vae, fn_STFT
class AudioDiffusion(nn.Module):
- def __init__(
- self,
- text_encoder_name,
- scheduler_name,
- unet_model_name=None,
- unet_model_config_path=None,
- snr_gamma=None,
- freeze_text_encoder=True,
- uncondition=False,
-
- ):
- super().__init__()
-
- assert unet_model_name is not None or unet_model_config_path is not None, "Either UNet pretrain model name or a config file path is required"
-
- self.text_encoder_name = text_encoder_name
- self.scheduler_name = scheduler_name
- self.unet_model_name = unet_model_name
- self.unet_model_config_path = unet_model_config_path
- self.snr_gamma = snr_gamma
- self.freeze_text_encoder = freeze_text_encoder
- self.uncondition = uncondition
-
- # https://huggingface.co/docs/diffusers/v0.14.0/en/api/schedulers/overview
- self.noise_scheduler = DDPMScheduler.from_pretrained(self.scheduler_name, subfolder="scheduler")
- self.inference_scheduler = DDPMScheduler.from_pretrained(self.scheduler_name, subfolder="scheduler")
-
- if unet_model_config_path:
- unet_config = UNet2DConditionModel.load_config(unet_model_config_path)
- self.unet = UNet2DConditionModel.from_config(unet_config, subfolder="unet")
- self.set_from = "random"
- print("UNet initialized randomly.")
- else:
- self.unet = UNet2DConditionModel.from_pretrained(unet_model_name, subfolder="unet")
- self.set_from = "pre-trained"
- self.group_in = nn.Sequential(nn.Linear(8, 512), nn.Linear(512, 4))
- self.group_out = nn.Sequential(nn.Linear(4, 512), nn.Linear(512, 8))
- print("UNet initialized from stable diffusion checkpoint.")
-
- if "stable-diffusion" in self.text_encoder_name:
- self.tokenizer = CLIPTokenizer.from_pretrained(self.text_encoder_name, subfolder="tokenizer")
- self.text_encoder = CLIPTextModel.from_pretrained(self.text_encoder_name, subfolder="text_encoder")
- elif "t5" in self.text_encoder_name:
- self.tokenizer = AutoTokenizer.from_pretrained(self.text_encoder_name)
- self.text_encoder = T5EncoderModel.from_pretrained(self.text_encoder_name)
- else:
- self.tokenizer = AutoTokenizer.from_pretrained(self.text_encoder_name)
- self.text_encoder = AutoModel.from_pretrained(self.text_encoder_name)
-
- def compute_snr(self, timesteps):
- """
- Computes SNR as per https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L847-L849
- """
- alphas_cumprod = self.noise_scheduler.alphas_cumprod
- sqrt_alphas_cumprod = alphas_cumprod**0.5
- sqrt_one_minus_alphas_cumprod = (1.0 - alphas_cumprod) ** 0.5
-
- # Expand the tensors.
- # Adapted from https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L1026
- sqrt_alphas_cumprod = sqrt_alphas_cumprod.to(device=timesteps.device)[timesteps].float()
- while len(sqrt_alphas_cumprod.shape) < len(timesteps.shape):
- sqrt_alphas_cumprod = sqrt_alphas_cumprod[..., None]
- alpha = sqrt_alphas_cumprod.expand(timesteps.shape)
-
- sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod.to(device=timesteps.device)[timesteps].float()
- while len(sqrt_one_minus_alphas_cumprod.shape) < len(timesteps.shape):
- sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod[..., None]
- sigma = sqrt_one_minus_alphas_cumprod.expand(timesteps.shape)
-
- # Compute SNR.
- snr = (alpha / sigma) ** 2
- return snr
-
- def encode_text(self, prompt):
- device = self.text_encoder.device
- batch = self.tokenizer(
- prompt, max_length=self.tokenizer.model_max_length, padding=True, truncation=True, return_tensors="pt"
- )
- input_ids, attention_mask = batch.input_ids.to(device), batch.attention_mask.to(device)
-
- if self.freeze_text_encoder:
- with torch.no_grad():
- encoder_hidden_states = self.text_encoder(
- input_ids=input_ids, attention_mask=attention_mask
- )[0]
- else:
- encoder_hidden_states = self.text_encoder(
- input_ids=input_ids, attention_mask=attention_mask
- )[0]
-
- boolean_encoder_mask = (attention_mask == 1).to(device)
- return encoder_hidden_states, boolean_encoder_mask
-
- def forward(self, latents, prompt, validation_mode=False):
- device = self.text_encoder.device
- num_train_timesteps = self.noise_scheduler.num_train_timesteps
- self.noise_scheduler.set_timesteps(num_train_timesteps, device=device)
-
- encoder_hidden_states, boolean_encoder_mask = self.encode_text(prompt)
-
- if self.uncondition:
- mask_indices = [k for k in range(len(prompt)) if random.random() < 0.1]
- if len(mask_indices) > 0:
- encoder_hidden_states[mask_indices] = 0
-
- bsz = latents.shape[0]
-
- if validation_mode:
- timesteps = (self.noise_scheduler.num_train_timesteps//2) * torch.ones((bsz,), dtype=torch.int64, device=device)
- else:
- # Sample a random timestep for each instance
- timesteps = torch.randint(0, self.noise_scheduler.num_train_timesteps, (bsz,), device=device)
- # print('in if ', timesteps)
- timesteps = timesteps.long()
- # print('outside if ' , timesteps)
- noise = torch.randn_like(latents)
- noisy_latents = self.noise_scheduler.add_noise(latents, noise, timesteps)
-
- # Get the target for loss depending on the prediction type
- if self.noise_scheduler.config.prediction_type == "epsilon":
- target = noise
- elif self.noise_scheduler.config.prediction_type == "v_prediction":
- target = self.noise_scheduler.get_velocity(latents, noise, timesteps)
- else:
- raise ValueError(f"Unknown prediction type {self.noise_scheduler.config.prediction_type}")
-
- if self.set_from == "random":
- model_pred = self.unet(
- noisy_latents, timesteps, encoder_hidden_states,
- encoder_attention_mask=boolean_encoder_mask
- ).sample
-
- elif self.set_from == "pre-trained":
- compressed_latents = self.group_in(noisy_latents.permute(0, 2, 3, 1).contiguous()).permute(0, 3, 1, 2).contiguous()
- model_pred = self.unet(
- compressed_latents, timesteps, encoder_hidden_states,
- encoder_attention_mask=boolean_encoder_mask
- ).sample
- model_pred = self.group_out(model_pred.permute(0, 2, 3, 1).contiguous()).permute(0, 3, 1, 2).contiguous()
-
- if self.snr_gamma is None:
- loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
- else:
- # Compute loss-weights as per Section 3.4 of https://arxiv.org/abs/2303.09556.
- # Adaptef from huggingface/diffusers/blob/main/examples/text_to_image/train_text_to_image.py
- snr = self.compute_snr(timesteps)
- mse_loss_weights = (
- torch.stack([snr, self.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr
- )
- loss = F.mse_loss(model_pred.float(), target.float(), reduction="none")
- loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights
- loss = loss.mean()
-
- return loss
-
- @torch.no_grad()
- def inference(self, prompt, inference_scheduler, num_steps=20, guidance_scale=3, num_samples_per_prompt=1,
- disable_progress=True):
- device = self.text_encoder.device
- classifier_free_guidance = guidance_scale > 1.0
- batch_size = len(prompt) * num_samples_per_prompt
-
- if classifier_free_guidance:
- prompt_embeds, boolean_prompt_mask = self.encode_text_classifier_free(prompt, num_samples_per_prompt)
- else:
- prompt_embeds, boolean_prompt_mask = self.encode_text(prompt)
- prompt_embeds = prompt_embeds.repeat_interleave(num_samples_per_prompt, 0)
- boolean_prompt_mask = boolean_prompt_mask.repeat_interleave(num_samples_per_prompt, 0)
-
- inference_scheduler.set_timesteps(num_steps, device=device)
- timesteps = inference_scheduler.timesteps
-
- num_channels_latents = self.unet.in_channels
- latents = self.prepare_latents(batch_size, inference_scheduler, num_channels_latents, prompt_embeds.dtype, device)
-
- num_warmup_steps = len(timesteps) - num_steps * inference_scheduler.order
- progress_bar = tqdm(range(num_steps), disable=disable_progress)
-
- for i, t in enumerate(timesteps):
- # expand the latents if we are doing classifier free guidance
- latent_model_input = torch.cat([latents] * 2) if classifier_free_guidance else latents
- latent_model_input = inference_scheduler.scale_model_input(latent_model_input, t)
-
- noise_pred = self.unet(
- latent_model_input, t, encoder_hidden_states=prompt_embeds,
- encoder_attention_mask=boolean_prompt_mask
- ).sample
-
- # perform guidance
- if classifier_free_guidance:
- noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
- noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
-
- # compute the previous noisy sample x_t -> x_t-1
- latents = inference_scheduler.step(noise_pred, t, latents).prev_sample
-
- # call the callback, if provided
- if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % inference_scheduler.order == 0):
- progress_bar.update(1)
-
- if self.set_from == "pre-trained":
- latents = self.group_out(latents.permute(0, 2, 3, 1).contiguous()).permute(0, 3, 1, 2).contiguous()
- return latents
-
- def prepare_latents(self, batch_size, inference_scheduler, num_channels_latents, dtype, device):
- shape = (batch_size, num_channels_latents, 256, 16)
- latents = randn_tensor(shape, generator=None, device=device, dtype=dtype)
- # scale the initial noise by the standard deviation required by the scheduler
- latents = latents * inference_scheduler.init_noise_sigma
- return latents
-
- def encode_text_classifier_free(self, prompt, num_samples_per_prompt):
- device = self.text_encoder.device
- batch = self.tokenizer(
- prompt, max_length=self.tokenizer.model_max_length, padding=True, truncation=True, return_tensors="pt"
- )
- input_ids, attention_mask = batch.input_ids.to(device), batch.attention_mask.to(device)
-
- with torch.no_grad():
- prompt_embeds = self.text_encoder(
- input_ids=input_ids, attention_mask=attention_mask
- )[0]
-
- prompt_embeds = prompt_embeds.repeat_interleave(num_samples_per_prompt, 0)
- attention_mask = attention_mask.repeat_interleave(num_samples_per_prompt, 0)
-
- # get unconditional embeddings for classifier free guidance
- uncond_tokens = [""] * len(prompt)
-
- max_length = prompt_embeds.shape[1]
- uncond_batch = self.tokenizer(
- uncond_tokens, max_length=max_length, padding="max_length", truncation=True, return_tensors="pt",
- )
- uncond_input_ids = uncond_batch.input_ids.to(device)
- uncond_attention_mask = uncond_batch.attention_mask.to(device)
-
- with torch.no_grad():
- negative_prompt_embeds = self.text_encoder(
- input_ids=uncond_input_ids, attention_mask=uncond_attention_mask
- )[0]
-
- negative_prompt_embeds = negative_prompt_embeds.repeat_interleave(num_samples_per_prompt, 0)
- uncond_attention_mask = uncond_attention_mask.repeat_interleave(num_samples_per_prompt, 0)
-
- # For classifier free guidance, we need to do two forward passes.
- # We concatenate the unconditional and text embeddings into a single batch to avoid doing two forward passes
- prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
- prompt_mask = torch.cat([uncond_attention_mask, attention_mask])
- boolean_prompt_mask = (prompt_mask == 1).to(device)
-
- return prompt_embeds, boolean_prompt_mask
-
+ def __init__(
+ self,
+ text_encoder_name,
+ scheduler_name,
+ unet_model_name=None,
+ unet_model_config_path=None,
+ snr_gamma=None,
+ freeze_text_encoder=True,
+ uncondition=False,
+
+ ):
+ super().__init__()
+
+        assert unet_model_name is not None or unet_model_config_path is not None, "Either UNet pretrained model name or a config file path is required"
+
+ self.text_encoder_name = text_encoder_name
+ self.scheduler_name = scheduler_name
+ self.unet_model_name = unet_model_name
+ self.unet_model_config_path = unet_model_config_path
+ self.snr_gamma = snr_gamma
+ self.freeze_text_encoder = freeze_text_encoder
+ self.uncondition = uncondition
+
+ # https://huggingface.co/docs/diffusers/v0.14.0/en/api/schedulers/overview
+ self.noise_scheduler = DDPMScheduler.from_pretrained(self.scheduler_name, subfolder="scheduler")
+ self.inference_scheduler = DDPMScheduler.from_pretrained(self.scheduler_name, subfolder="scheduler")
+
+ if unet_model_config_path:
+ unet_config = UNet2DConditionModel.load_config(unet_model_config_path)
+ self.unet = UNet2DConditionModel.from_config(unet_config, subfolder="unet")
+ self.set_from = "random"
+ print("UNet initialized randomly.")
+ else:
+ self.unet = UNet2DConditionModel.from_pretrained(unet_model_name, subfolder="unet")
+ self.set_from = "pre-trained"
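+            # Adapters between the 8-channel audio latents and the 4-channel pretrained
+            # Stable Diffusion UNet: group_in projects 8 -> 4 channels, group_out maps back 4 -> 8.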
+ self.group_in = nn.Sequential(nn.Linear(8, 512), nn.Linear(512, 4))
+ self.group_out = nn.Sequential(nn.Linear(4, 512), nn.Linear(512, 8))
+ print("UNet initialized from stable diffusion checkpoint.")
+
+ if "stable-diffusion" in self.text_encoder_name:
+ self.tokenizer = CLIPTokenizer.from_pretrained(self.text_encoder_name, subfolder="tokenizer")
+ self.text_encoder = CLIPTextModel.from_pretrained(self.text_encoder_name, subfolder="text_encoder")
+ elif "t5" in self.text_encoder_name:
+ self.tokenizer = AutoTokenizer.from_pretrained(self.text_encoder_name)
+ self.text_encoder = T5EncoderModel.from_pretrained(self.text_encoder_name)
+ else:
+ self.tokenizer = AutoTokenizer.from_pretrained(self.text_encoder_name)
+ self.text_encoder = AutoModel.from_pretrained(self.text_encoder_name)
+
+ def compute_snr(self, timesteps):
+ """
+ Computes SNR as per https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L847-L849
+ """
+ alphas_cumprod = self.noise_scheduler.alphas_cumprod
+ sqrt_alphas_cumprod = alphas_cumprod**0.5
+ sqrt_one_minus_alphas_cumprod = (1.0 - alphas_cumprod) ** 0.5
+
+ # Expand the tensors.
+ # Adapted from https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L1026
+ sqrt_alphas_cumprod = sqrt_alphas_cumprod.to(device=timesteps.device)[timesteps].float()
+ while len(sqrt_alphas_cumprod.shape) < len(timesteps.shape):
+ sqrt_alphas_cumprod = sqrt_alphas_cumprod[..., None]
+ alpha = sqrt_alphas_cumprod.expand(timesteps.shape)
+
+ sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod.to(device=timesteps.device)[timesteps].float()
+ while len(sqrt_one_minus_alphas_cumprod.shape) < len(timesteps.shape):
+ sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod[..., None]
+ sigma = sqrt_one_minus_alphas_cumprod.expand(timesteps.shape)
+
+ # Compute SNR.
+ snr = (alpha / sigma) ** 2
+ return snr
+
+ def encode_text(self, prompt):
+ device = self.text_encoder.device
+ batch = self.tokenizer(
+ prompt, max_length=self.tokenizer.model_max_length, padding=True, truncation=True, return_tensors="pt"
+ )
+ input_ids, attention_mask = batch.input_ids.to(device), batch.attention_mask.to(device)
+
+ if self.freeze_text_encoder:
+ with torch.no_grad():
+ encoder_hidden_states = self.text_encoder(
+ input_ids=input_ids, attention_mask=attention_mask
+ )[0]
+ else:
+ encoder_hidden_states = self.text_encoder(
+ input_ids=input_ids, attention_mask=attention_mask
+ )[0]
+
+ boolean_encoder_mask = (attention_mask == 1).to(device)
+ return encoder_hidden_states, boolean_encoder_mask
+
+ def forward(self, latents, prompt, validation_mode=False):
+ device = self.text_encoder.device
+ num_train_timesteps = self.noise_scheduler.num_train_timesteps
+ self.noise_scheduler.set_timesteps(num_train_timesteps, device=device)
+
+ encoder_hidden_states, boolean_encoder_mask = self.encode_text(prompt)
+
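+        # Classifier-free-guidance style training: randomly zero out the text conditioning
+        # for roughly 10% of the examples in the batch.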
+ if self.uncondition:
+ mask_indices = [k for k in range(len(prompt)) if random.random() < 0.1]
+ if len(mask_indices) > 0:
+ encoder_hidden_states[mask_indices] = 0
+
+ bsz = latents.shape[0]
+
+ if validation_mode:
+ timesteps = (self.noise_scheduler.num_train_timesteps//2) * torch.ones((bsz,), dtype=torch.int64, device=device)
+ else:
+ # Sample a random timestep for each instance
+ timesteps = torch.randint(0, self.noise_scheduler.num_train_timesteps, (bsz,), device=device)
+ # print('in if ', timesteps)
+ timesteps = timesteps.long()
+ # print('outside if ' , timesteps)
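+        # Forward diffusion: sample Gaussian noise and corrupt the clean latents at the sampled timesteps.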
+ noise = torch.randn_like(latents)
+ noisy_latents = self.noise_scheduler.add_noise(latents, noise, timesteps)
+
+ # Get the target for loss depending on the prediction type
+ if self.noise_scheduler.config.prediction_type == "epsilon":
+ target = noise
+ elif self.noise_scheduler.config.prediction_type == "v_prediction":
+ target = self.noise_scheduler.get_velocity(latents, noise, timesteps)
+ else:
+ raise ValueError(f"Unknown prediction type {self.noise_scheduler.config.prediction_type}")
+
+ if self.set_from == "random":
+ model_pred = self.unet(
+ noisy_latents, timesteps, encoder_hidden_states,
+ encoder_attention_mask=boolean_encoder_mask
+ ).sample
+
+ elif self.set_from == "pre-trained":
+ compressed_latents = self.group_in(noisy_latents.permute(0, 2, 3, 1).contiguous()).permute(0, 3, 1, 2).contiguous()
+ model_pred = self.unet(
+ compressed_latents, timesteps, encoder_hidden_states,
+ encoder_attention_mask=boolean_encoder_mask
+ ).sample
+ model_pred = self.group_out(model_pred.permute(0, 2, 3, 1).contiguous()).permute(0, 3, 1, 2).contiguous()
+
+ if self.snr_gamma is None:
+ loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
+ else:
+ # Compute loss-weights as per Section 3.4 of https://arxiv.org/abs/2303.09556.
+            # Adapted from huggingface/diffusers/blob/main/examples/text_to_image/train_text_to_image.py
+ snr = self.compute_snr(timesteps)
+ mse_loss_weights = (
+ torch.stack([snr, self.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr
+ )
+ loss = F.mse_loss(model_pred.float(), target.float(), reduction="none")
+ loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights
+ loss = loss.mean()
+
+ return loss
+
+ @torch.no_grad()
+ def inference(self, prompt, inference_scheduler, num_steps=20, guidance_scale=3, num_samples_per_prompt=1,
+ disable_progress=True):
+ device = self.text_encoder.device
+ classifier_free_guidance = guidance_scale > 1.0
+ batch_size = len(prompt) * num_samples_per_prompt
+
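+        # For classifier-free guidance, the prompt is encoded together with empty ("") prompts so that
+        # the unconditional embeddings occupy the first half of the batch.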
+ if classifier_free_guidance:
+ prompt_embeds, boolean_prompt_mask = self.encode_text_classifier_free(prompt, num_samples_per_prompt)
+ else:
+ prompt_embeds, boolean_prompt_mask = self.encode_text(prompt)
+ prompt_embeds = prompt_embeds.repeat_interleave(num_samples_per_prompt, 0)
+ boolean_prompt_mask = boolean_prompt_mask.repeat_interleave(num_samples_per_prompt, 0)
+
+ inference_scheduler.set_timesteps(num_steps, device=device)
+ timesteps = inference_scheduler.timesteps
+
+ num_channels_latents = self.unet.in_channels
+ latents = self.prepare_latents(batch_size, inference_scheduler, num_channels_latents, prompt_embeds.dtype, device)
+
+ num_warmup_steps = len(timesteps) - num_steps * inference_scheduler.order
+ progress_bar = tqdm(range(num_steps), disable=disable_progress)
+
+ for i, t in enumerate(timesteps):
+ # expand the latents if we are doing classifier free guidance
+ latent_model_input = torch.cat([latents] * 2) if classifier_free_guidance else latents
+ latent_model_input = inference_scheduler.scale_model_input(latent_model_input, t)
+
+ noise_pred = self.unet(
+ latent_model_input, t, encoder_hidden_states=prompt_embeds,
+ encoder_attention_mask=boolean_prompt_mask
+ ).sample
+
+ # perform guidance
+ if classifier_free_guidance:
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latents = inference_scheduler.step(noise_pred, t, latents).prev_sample
+
+ # call the callback, if provided
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % inference_scheduler.order == 0):
+ progress_bar.update(1)
+
+ if self.set_from == "pre-trained":
+ latents = self.group_out(latents.permute(0, 2, 3, 1).contiguous()).permute(0, 3, 1, 2).contiguous()
+ return latents
+
+ def prepare_latents(self, batch_size, inference_scheduler, num_channels_latents, dtype, device):
+ shape = (batch_size, num_channels_latents, 256, 16)
+ latents = randn_tensor(shape, generator=None, device=device, dtype=dtype)
+ # scale the initial noise by the standard deviation required by the scheduler
+ latents = latents * inference_scheduler.init_noise_sigma
+ return latents
+
+ def encode_text_classifier_free(self, prompt, num_samples_per_prompt):
+ device = self.text_encoder.device
+ batch = self.tokenizer(
+ prompt, max_length=self.tokenizer.model_max_length, padding=True, truncation=True, return_tensors="pt"
+ )
+ input_ids, attention_mask = batch.input_ids.to(device), batch.attention_mask.to(device)
+
+ with torch.no_grad():
+ prompt_embeds = self.text_encoder(
+ input_ids=input_ids, attention_mask=attention_mask
+ )[0]
+
+ prompt_embeds = prompt_embeds.repeat_interleave(num_samples_per_prompt, 0)
+ attention_mask = attention_mask.repeat_interleave(num_samples_per_prompt, 0)
+
+ # get unconditional embeddings for classifier free guidance
+ uncond_tokens = [""] * len(prompt)
+
+ max_length = prompt_embeds.shape[1]
+ uncond_batch = self.tokenizer(
+ uncond_tokens, max_length=max_length, padding="max_length", truncation=True, return_tensors="pt",
+ )
+ uncond_input_ids = uncond_batch.input_ids.to(device)
+ uncond_attention_mask = uncond_batch.attention_mask.to(device)
+
+ with torch.no_grad():
+ negative_prompt_embeds = self.text_encoder(
+ input_ids=uncond_input_ids, attention_mask=uncond_attention_mask
+ )[0]
+
+ negative_prompt_embeds = negative_prompt_embeds.repeat_interleave(num_samples_per_prompt, 0)
+ uncond_attention_mask = uncond_attention_mask.repeat_interleave(num_samples_per_prompt, 0)
+
+ # For classifier free guidance, we need to do two forward passes.
+ # We concatenate the unconditional and text embeddings into a single batch to avoid doing two forward passes
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
+ prompt_mask = torch.cat([uncond_attention_mask, attention_mask])
+ boolean_prompt_mask = (prompt_mask == 1).to(device)
+
+ return prompt_embeds, boolean_prompt_mask
+
class MusicAudioDiffusion(nn.Module):
- def __init__(
- self,
- text_encoder_name,
- scheduler_name,
- unet_model_name=None,
- unet_model_config_path=None,
- snr_gamma=None,
- freeze_text_encoder=True,
- uncondition=False,
-
- d_fme = 1024, #FME
- fme_type = "se",
- base = 1,
- if_trainable = True,
- translation_bias_type = "nd",
- emb_nn = True,
- d_pe = 1024, #PE
- if_index = True,
- if_global_timing = True,
- if_modulo_timing = False,
- d_beat = 1024, #Beat
- d_oh_beat_type = 7,
- beat_len = 50,
- d_chord = 1024, #Chord
- d_oh_chord_type = 12,
- d_oh_inv_type = 4,
- chord_len = 20,
-
- ):
- super().__init__()
-
- assert unet_model_name is not None or unet_model_config_path is not None, "Either UNet pretrain model name or a config file path is required"
-
- self.text_encoder_name = text_encoder_name
- self.scheduler_name = scheduler_name
- self.unet_model_name = unet_model_name
- self.unet_model_config_path = unet_model_config_path
- self.snr_gamma = snr_gamma
- self.freeze_text_encoder = freeze_text_encoder
- self.uncondition = uncondition
-
- # https://huggingface.co/docs/diffusers/v0.14.0/en/api/schedulers/overview
- self.noise_scheduler = DDPMScheduler.from_pretrained(self.scheduler_name, subfolder="scheduler")
- self.inference_scheduler = DDPMScheduler.from_pretrained(self.scheduler_name, subfolder="scheduler")
-
- if unet_model_config_path:
- unet_config = UNet2DConditionModelMusic.load_config(unet_model_config_path)
- self.unet = UNet2DConditionModelMusic.from_config(unet_config, subfolder="unet")
- self.set_from = "random"
- print("UNet initialized randomly.")
- else:
- self.unet = UNet2DConditionModel.from_pretrained(unet_model_name, subfolder="unet")
- self.set_from = "pre-trained"
- self.group_in = nn.Sequential(nn.Linear(8, 512), nn.Linear(512, 4))
- self.group_out = nn.Sequential(nn.Linear(4, 512), nn.Linear(512, 8))
- print("UNet initialized from stable diffusion checkpoint.")
-
- if "stable-diffusion" in self.text_encoder_name:
- self.tokenizer = CLIPTokenizer.from_pretrained(self.text_encoder_name, subfolder="tokenizer")
- self.text_encoder = CLIPTextModel.from_pretrained(self.text_encoder_name, subfolder="text_encoder")
- elif "t5" in self.text_encoder_name:
- self.tokenizer = AutoTokenizer.from_pretrained(self.text_encoder_name)
- self.text_encoder = T5EncoderModel.from_pretrained(self.text_encoder_name)
- else:
- self.tokenizer = AutoTokenizer.from_pretrained(self.text_encoder_name)
- self.text_encoder = AutoModel.from_pretrained(self.text_encoder_name)
-
- self.device = self.text_encoder.device
- #Music Feature Encoder
- self.FME = Fundamental_Music_Embedding(d_model = d_fme, base= base, if_trainable = False, type = fme_type,emb_nn=emb_nn,translation_bias_type = translation_bias_type)
- self.PE = Music_PositionalEncoding(d_model = d_pe, if_index = if_index, if_global_timing = if_global_timing, if_modulo_timing = if_modulo_timing, device = self.device)
- # self.PE2 = Music_PositionalEncoding(d_model = d_pe, if_index = if_index, if_global_timing = if_global_timing, if_modulo_timing = if_modulo_timing, device = self.device)
- self.beat_tokenizer = beat_tokenizer(seq_len_beat=beat_len, if_pad = True)
- self.beat_embedding_layer = Beat_Embedding(self.PE, d_model = d_beat, d_oh_beat_type = d_oh_beat_type)
- self.chord_embedding_layer = Chord_Embedding(self.FME, self.PE, d_model = d_chord, d_oh_type = d_oh_chord_type, d_oh_inv = d_oh_inv_type)
- self.chord_tokenizer = chord_tokenizer(seq_len_chord=chord_len, if_pad = True)
-
-
- def compute_snr(self, timesteps):
- """
- Computes SNR as per https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L847-L849
- """
- alphas_cumprod = self.noise_scheduler.alphas_cumprod
- sqrt_alphas_cumprod = alphas_cumprod**0.5
- sqrt_one_minus_alphas_cumprod = (1.0 - alphas_cumprod) ** 0.5
-
- # Expand the tensors.
- # Adapted from https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L1026
- sqrt_alphas_cumprod = sqrt_alphas_cumprod.to(device=timesteps.device)[timesteps].float()
- while len(sqrt_alphas_cumprod.shape) < len(timesteps.shape):
- sqrt_alphas_cumprod = sqrt_alphas_cumprod[..., None]
- alpha = sqrt_alphas_cumprod.expand(timesteps.shape)
-
- sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod.to(device=timesteps.device)[timesteps].float()
- while len(sqrt_one_minus_alphas_cumprod.shape) < len(timesteps.shape):
- sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod[..., None]
- sigma = sqrt_one_minus_alphas_cumprod.expand(timesteps.shape)
-
- # Compute SNR.
- snr = (alpha / sigma) ** 2
- return snr
-
- def encode_text(self, prompt):
- device = self.text_encoder.device
- batch = self.tokenizer(
- prompt, max_length=self.tokenizer.model_max_length, padding=True, truncation=True, return_tensors="pt"
- )
- input_ids, attention_mask = batch.input_ids.to(device), batch.attention_mask.to(device) #cuda
- if self.freeze_text_encoder:
- with torch.no_grad():
- encoder_hidden_states = self.text_encoder(
- input_ids=input_ids, attention_mask=attention_mask
- )[0] #batch, len_text, dim
- else:
- encoder_hidden_states = self.text_encoder(
- input_ids=input_ids, attention_mask=attention_mask
- )[0]
- boolean_encoder_mask = (attention_mask == 1).to(device) ##batch, len_text
- return encoder_hidden_states, boolean_encoder_mask
-
- def encode_beats(self, beats):
- # device = self.beat_embedding_layer.device
- out_beat = []
- out_beat_timing = []
- out_mask = []
- for beat in beats:
- tokenized_beats,tokenized_beats_timing, tokenized_beat_mask = self.beat_tokenizer(beat)
- out_beat.append(tokenized_beats)
- out_beat_timing.append(tokenized_beats_timing)
- out_mask.append(tokenized_beat_mask)
- out_beat, out_beat_timing, out_mask = torch.tensor(out_beat).cuda(), torch.tensor(out_beat_timing).cuda(), torch.tensor(out_mask).cuda() #batch, len_beat
- embedded_beat = self.beat_embedding_layer(out_beat, out_beat_timing)
-
- return embedded_beat, out_mask
-
- def encode_chords(self, chords,chords_time):
- out_chord_root = []
- out_chord_type = []
- out_chord_inv = []
- out_chord_timing = []
- out_mask = []
- for chord, chord_time in zip(chords,chords_time): #batch loop
- tokenized_chord_root, tokenized_chord_type, tokenized_chord_inv, tokenized_chord_time, tokenized_chord_mask = self.chord_tokenizer(chord, chord_time)
- out_chord_root.append(tokenized_chord_root)
- out_chord_type.append(tokenized_chord_type)
- out_chord_inv.append(tokenized_chord_inv)
- out_chord_timing.append(tokenized_chord_time)
- out_mask.append(tokenized_chord_mask)
- #chords: (B, LEN, 4)
- out_chord_root, out_chord_type, out_chord_inv, out_chord_timing, out_mask = torch.tensor(out_chord_root).cuda(), torch.tensor(out_chord_type).cuda(), torch.tensor(out_chord_inv).cuda(), torch.tensor(out_chord_timing).cuda(), torch.tensor(out_mask).cuda()
- embedded_chord = self.chord_embedding_layer(out_chord_root, out_chord_type, out_chord_inv, out_chord_timing)
- return embedded_chord, out_mask
- # return out_chord_root, out_mask
-
-
- def forward(self, latents, prompt, beats, chords,chords_time, validation_mode=False):
- device = self.text_encoder.device
- num_train_timesteps = self.noise_scheduler.num_train_timesteps
- self.noise_scheduler.set_timesteps(num_train_timesteps, device=device)
-
- encoder_hidden_states, boolean_encoder_mask = self.encode_text(prompt)
-
- # with torch.no_grad():
- encoded_beats, beat_mask = self.encode_beats(beats) #batch, len_beats, dim; batch, len_beats
- encoded_chords, chord_mask = self.encode_chords(chords,chords_time)
-
-
- if self.uncondition:
- mask_indices = [k for k in range(len(prompt)) if random.random() < 0.1]
- if len(mask_indices) > 0:
- encoder_hidden_states[mask_indices] = 0
- encoded_chords[mask_indices] = 0
- encoded_beats[mask_indices] = 0
-
- bsz = latents.shape[0]
-
- if validation_mode:
- timesteps = (self.noise_scheduler.num_train_timesteps//2) * torch.ones((bsz,), dtype=torch.int64, device=device)
- else:
- timesteps = torch.randint(0, self.noise_scheduler.num_train_timesteps, (bsz,), device=device)
-
-
- timesteps = timesteps.long()
-
- noise = torch.randn_like(latents)
- noisy_latents = self.noise_scheduler.add_noise(latents, noise, timesteps)
-
- # Get the target for loss depending on the prediction type
- if self.noise_scheduler.config.prediction_type == "epsilon":
- target = noise
- elif self.noise_scheduler.config.prediction_type == "v_prediction":
- target = self.noise_scheduler.get_velocity(latents, noise, timesteps)
- else:
- raise ValueError(f"Unknown prediction type {self.noise_scheduler.config.prediction_type}")
-
- if self.set_from == "random":
- # model_pred = torch.zeros((bsz,8,256,16)).to(device)
- model_pred = self.unet(
- noisy_latents, timesteps, encoder_hidden_states, encoded_beats, encoded_chords,
- encoder_attention_mask=boolean_encoder_mask, beat_attention_mask = beat_mask, chord_attention_mask = chord_mask
- ).sample
-
- elif self.set_from == "pre-trained":
- compressed_latents = self.group_in(noisy_latents.permute(0, 2, 3, 1).contiguous()).permute(0, 3, 1, 2).contiguous()
- model_pred = self.unet(
- compressed_latents, timesteps, encoder_hidden_states,
- encoder_attention_mask=boolean_encoder_mask
- ).sample
- model_pred = self.group_out(model_pred.permute(0, 2, 3, 1).contiguous()).permute(0, 3, 1, 2).contiguous()
-
- if self.snr_gamma is None:
- loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
- else:
- # Compute loss-weights as per Section 3.4 of https://arxiv.org/abs/2303.09556.
- # Adaptef from huggingface/diffusers/blob/main/examples/text_to_image/train_text_to_image.py
- snr = self.compute_snr(timesteps)
- mse_loss_weights = (
- torch.stack([snr, self.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr
- )
- loss = F.mse_loss(model_pred.float(), target.float(), reduction="none")
- loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights
- loss = loss.mean()
-
- return loss
-
- @torch.no_grad()
- def inference(self, prompt, beats, chords,chords_time, inference_scheduler, num_steps=20, guidance_scale=3, num_samples_per_prompt=1,
- disable_progress=True):
- device = self.text_encoder.device
- classifier_free_guidance = guidance_scale > 1.0
- batch_size = len(prompt) * num_samples_per_prompt
-
- if classifier_free_guidance:
- prompt_embeds, boolean_prompt_mask = self.encode_text_classifier_free(prompt, num_samples_per_prompt)
- encoded_beats, beat_mask = self.encode_beats_classifier_free(beats, num_samples_per_prompt) #batch, len_beats, dim; batch, len_beats
- encoded_chords, chord_mask = self.encode_chords_classifier_free(chords, chords_time, num_samples_per_prompt)
- else:
- prompt_embeds, boolean_prompt_mask = self.encode_text(prompt)
- prompt_embeds = prompt_embeds.repeat_interleave(num_samples_per_prompt, 0)
- boolean_prompt_mask = boolean_prompt_mask.repeat_interleave(num_samples_per_prompt, 0)
-
- encoded_beats, beat_mask = self.encode_beats(beats) #batch, len_beats, dim; batch, len_beats
- encoded_beats = encoded_beats.repeat_interleave(num_samples_per_prompt, 0)
- beat_mask = beat_mask.repeat_interleave(num_samples_per_prompt, 0)
-
- encoded_chords, chord_mask = self.encode_chords(chords,chords_time)
- encoded_chords = encoded_chords.repeat_interleave(num_samples_per_prompt, 0)
- chord_mask = chord_mask.repeat_interleave(num_samples_per_prompt, 0)
-
- # print(f"encoded_chords:{encoded_chords.shape}, chord_mask:{chord_mask.shape}, prompt_embeds:{prompt_embeds.shape},boolean_prompt_mask:{boolean_prompt_mask.shape} ")
- inference_scheduler.set_timesteps(num_steps, device=device)
- timesteps = inference_scheduler.timesteps
-
- num_channels_latents = self.unet.in_channels
- latents = self.prepare_latents(batch_size, inference_scheduler, num_channels_latents, prompt_embeds.dtype, device)
-
- num_warmup_steps = len(timesteps) - num_steps * inference_scheduler.order
- progress_bar = tqdm(range(num_steps), disable=disable_progress)
-
- for i, t in enumerate(timesteps):
- # expand the latents if we are doing classifier free guidance
- latent_model_input = torch.cat([latents] * 2) if classifier_free_guidance else latents
- latent_model_input = inference_scheduler.scale_model_input(latent_model_input, t)
-
- noise_pred = self.unet(
- latent_model_input, t, encoder_hidden_states=prompt_embeds,
- encoder_attention_mask=boolean_prompt_mask,
- beat_features = encoded_beats, beat_attention_mask = beat_mask, chord_features = encoded_chords,chord_attention_mask = chord_mask
- ).sample
-
- # perform guidance
- if classifier_free_guidance: #should work for beats and chords too
- noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
- noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
-
- # compute the previous noisy sample x_t -> x_t-1
- latents = inference_scheduler.step(noise_pred, t, latents).prev_sample
-
- # call the callback, if provided
- if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % inference_scheduler.order == 0):
- progress_bar.update(1)
-
- if self.set_from == "pre-trained":
- latents = self.group_out(latents.permute(0, 2, 3, 1).contiguous()).permute(0, 3, 1, 2).contiguous()
- return latents
-
- def prepare_latents(self, batch_size, inference_scheduler, num_channels_latents, dtype, device):
- shape = (batch_size, num_channels_latents, 256, 16)
- latents = randn_tensor(shape, generator=None, device=device, dtype=dtype)
- # scale the initial noise by the standard deviation required by the scheduler
- latents = latents * inference_scheduler.init_noise_sigma
- return latents
-
- def encode_text_classifier_free(self, prompt, num_samples_per_prompt):
- device = self.text_encoder.device
- batch = self.tokenizer(
- prompt, max_length=self.tokenizer.model_max_length, padding=True, truncation=True, return_tensors="pt"
- )
- input_ids, attention_mask = batch.input_ids.to(device), batch.attention_mask.to(device)
-
- with torch.no_grad():
- prompt_embeds = self.text_encoder(
- input_ids=input_ids, attention_mask=attention_mask
- )[0]
-
- prompt_embeds = prompt_embeds.repeat_interleave(num_samples_per_prompt, 0)
- attention_mask = attention_mask.repeat_interleave(num_samples_per_prompt, 0)
-
- # get unconditional embeddings for classifier free guidance
- # print(len(prompt), 'this is prompt len')
- uncond_tokens = [""] * len(prompt)
-
- max_length = prompt_embeds.shape[1]
- uncond_batch = self.tokenizer(
- uncond_tokens, max_length=max_length, padding="max_length", truncation=True, return_tensors="pt",
- )
- uncond_input_ids = uncond_batch.input_ids.to(device)
- uncond_attention_mask = uncond_batch.attention_mask.to(device)
-
- with torch.no_grad():
- negative_prompt_embeds = self.text_encoder(
- input_ids=uncond_input_ids, attention_mask=uncond_attention_mask
- )[0]
-
- negative_prompt_embeds = negative_prompt_embeds.repeat_interleave(num_samples_per_prompt, 0)
- uncond_attention_mask = uncond_attention_mask.repeat_interleave(num_samples_per_prompt, 0)
-
- # For classifier free guidance, we need to do two forward passes.
- # We concatenate the unconditional and text embeddings into a single batch to avoid doing two forward passes
- prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
- prompt_mask = torch.cat([uncond_attention_mask, attention_mask])
- boolean_prompt_mask = (prompt_mask == 1).to(device)
-
- return prompt_embeds, boolean_prompt_mask
-
-
- def encode_beats_classifier_free(self, beats, num_samples_per_prompt):
- with torch.no_grad():
- out_beat = []
- out_beat_timing = []
- out_mask = []
- for beat in beats:
- tokenized_beats,tokenized_beats_timing, tokenized_beat_mask = self.beat_tokenizer(beat)
- out_beat.append(tokenized_beats)
- out_beat_timing.append(tokenized_beats_timing)
- out_mask.append(tokenized_beat_mask)
- out_beat, out_beat_timing, out_mask = torch.tensor(out_beat).cuda(), torch.tensor(out_beat_timing).cuda(), torch.tensor(out_mask).cuda() #batch, len_beat
- embedded_beat = self.beat_embedding_layer(out_beat, out_beat_timing)
-
- embedded_beat = embedded_beat.repeat_interleave(num_samples_per_prompt, 0)
- out_mask = out_mask.repeat_interleave(num_samples_per_prompt, 0)
-
- uncond_beats = [[[],[]]] * len(beats)
-
- max_length = embedded_beat.shape[1]
- with torch.no_grad():
- out_beat_unc = []
- out_beat_timing_unc = []
- out_mask_unc = []
- for beat in uncond_beats:
- tokenized_beats, tokenized_beats_timing, tokenized_beat_mask = self.beat_tokenizer(beat)
- out_beat_unc.append(tokenized_beats)
- out_beat_timing_unc.append(tokenized_beats_timing)
- out_mask_unc.append(tokenized_beat_mask)
- out_beat_unc, out_beat_timing_unc, out_mask_unc = torch.tensor(out_beat_unc).cuda(), torch.tensor(out_beat_timing_unc).cuda(), torch.tensor(out_mask_unc).cuda() #batch, len_beat
- embedded_beat_unc = self.beat_embedding_layer(out_beat_unc, out_beat_timing_unc)
-
- embedded_beat_unc = embedded_beat_unc.repeat_interleave(num_samples_per_prompt, 0)
- out_mask_unc = out_mask_unc.repeat_interleave(num_samples_per_prompt, 0)
-
- embedded_beat = torch.cat([embedded_beat_unc, embedded_beat])
- out_mask = torch.cat([out_mask_unc, out_mask])
-
- return embedded_beat, out_mask
-
-
- def encode_chords_classifier_free(self, chords, chords_time, num_samples_per_prompt):
-
- with torch.no_grad():
- out_chord_root = []
- out_chord_type = []
- out_chord_inv = []
- out_chord_timing = []
- out_mask = []
- for chord, chord_time in zip(chords,chords_time): #batch loop
- tokenized_chord_root, tokenized_chord_type, tokenized_chord_inv, tokenized_chord_time, tokenized_chord_mask = self.chord_tokenizer(chord, chord_time)
- out_chord_root.append(tokenized_chord_root)
- out_chord_type.append(tokenized_chord_type)
- out_chord_inv.append(tokenized_chord_inv)
- out_chord_timing.append(tokenized_chord_time)
- out_mask.append(tokenized_chord_mask)
- out_chord_root, out_chord_type, out_chord_inv, out_chord_timing, out_mask = torch.tensor(out_chord_root).cuda(), torch.tensor(out_chord_type).cuda(), torch.tensor(out_chord_inv).cuda(), torch.tensor(out_chord_timing).cuda(), torch.tensor(out_mask).cuda()
- embedded_chord = self.chord_embedding_layer(out_chord_root, out_chord_type, out_chord_inv, out_chord_timing)
-
- embedded_chord = embedded_chord.repeat_interleave(num_samples_per_prompt, 0)
- out_mask = out_mask.repeat_interleave(num_samples_per_prompt, 0)
-
- chords_unc=[[]] * len(chords)
- chords_time_unc=[[]] * len(chords_time)
-
- max_length = embedded_chord.shape[1]
-
- with torch.no_grad():
- out_chord_root_unc = []
- out_chord_type_unc = []
- out_chord_inv_unc = []
- out_chord_timing_unc = []
- out_mask_unc = []
- for chord, chord_time in zip(chords_unc,chords_time_unc): #batch loop
- tokenized_chord_root, tokenized_chord_type, tokenized_chord_inv, tokenized_chord_time, tokenized_chord_mask = self.chord_tokenizer(chord, chord_time)
- out_chord_root_unc.append(tokenized_chord_root)
- out_chord_type_unc.append(tokenized_chord_type)
- out_chord_inv_unc.append(tokenized_chord_inv)
- out_chord_timing_unc.append(tokenized_chord_time)
- out_mask_unc.append(tokenized_chord_mask)
- out_chord_root_unc, out_chord_type_unc, out_chord_inv_unc, out_chord_timing_unc, out_mask_unc = torch.tensor(out_chord_root_unc).cuda(), torch.tensor(out_chord_type_unc).cuda(), torch.tensor(out_chord_inv_unc).cuda(), torch.tensor(out_chord_timing_unc).cuda(), torch.tensor(out_mask_unc).cuda()
- embedded_chord_unc = self.chord_embedding_layer(out_chord_root_unc, out_chord_type_unc, out_chord_inv_unc, out_chord_timing_unc)
-
-
- embedded_chord_unc = embedded_chord_unc.repeat_interleave(num_samples_per_prompt, 0)
- out_mask_unc = out_mask_unc.repeat_interleave(num_samples_per_prompt, 0)
-
- embedded_chord = torch.cat([embedded_chord_unc, embedded_chord])
- out_mask = torch.cat([out_mask_unc, out_mask])
-
- return embedded_chord, out_mask
+ def __init__(
+ self,
+ text_encoder_name,
+ scheduler_name,
+ unet_model_name=None,
+ unet_model_config_path=None,
+ snr_gamma=None,
+ freeze_text_encoder=True,
+ uncondition=False,
+
+ d_fme = 1024, #FME
+ fme_type = "se",
+ base = 1,
+ if_trainable = True,
+ translation_bias_type = "nd",
+ emb_nn = True,
+ d_pe = 1024, #PE
+ if_index = True,
+ if_global_timing = True,
+ if_modulo_timing = False,
+ d_beat = 1024, #Beat
+ d_oh_beat_type = 7,
+ beat_len = 50,
+ d_chord = 1024, #Chord
+ d_oh_chord_type = 12,
+ d_oh_inv_type = 4,
+ chord_len = 20,
+
+ ):
+ super().__init__()
+
+        assert unet_model_name is not None or unet_model_config_path is not None, "Either UNet pretrained model name or a config file path is required"
+
+ self.text_encoder_name = text_encoder_name
+ self.scheduler_name = scheduler_name
+ self.unet_model_name = unet_model_name
+ self.unet_model_config_path = unet_model_config_path
+ self.snr_gamma = snr_gamma
+ self.freeze_text_encoder = freeze_text_encoder
+ self.uncondition = uncondition
+
+ # https://huggingface.co/docs/diffusers/v0.14.0/en/api/schedulers/overview
+ self.noise_scheduler = DDPMScheduler.from_pretrained(self.scheduler_name, subfolder="scheduler")
+ self.inference_scheduler = DDPMScheduler.from_pretrained(self.scheduler_name, subfolder="scheduler")
+
+ if unet_model_config_path:
+ unet_config = UNet2DConditionModelMusic.load_config(unet_model_config_path)
+ self.unet = UNet2DConditionModelMusic.from_config(unet_config, subfolder="unet")
+ self.set_from = "random"
+ print("UNet initialized randomly.")
+ else:
+ self.unet = UNet2DConditionModel.from_pretrained(unet_model_name, subfolder="unet")
+ self.set_from = "pre-trained"
+ self.group_in = nn.Sequential(nn.Linear(8, 512), nn.Linear(512, 4))
+ self.group_out = nn.Sequential(nn.Linear(4, 512), nn.Linear(512, 8))
+ print("UNet initialized from stable diffusion checkpoint.")
+
+ if "stable-diffusion" in self.text_encoder_name:
+ self.tokenizer = CLIPTokenizer.from_pretrained(self.text_encoder_name, subfolder="tokenizer")
+ self.text_encoder = CLIPTextModel.from_pretrained(self.text_encoder_name, subfolder="text_encoder")
+ elif "t5" in self.text_encoder_name:
+ self.tokenizer = AutoTokenizer.from_pretrained(self.text_encoder_name)
+ self.text_encoder = T5EncoderModel.from_pretrained(self.text_encoder_name)
+ else:
+ self.tokenizer = AutoTokenizer.from_pretrained(self.text_encoder_name)
+ self.text_encoder = AutoModel.from_pretrained(self.text_encoder_name)
+
+ self.device = self.text_encoder.device
+ #Music Feature Encoder
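+        # Beat and chord conditioning: the tokenizers pad/encode the symbolic features, and the
+        # embedding layers (FME + musical positional encoding) turn them into dense sequences for the music UNet.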
+ self.FME = Fundamental_Music_Embedding(d_model = d_fme, base= base, if_trainable = False, type = fme_type,emb_nn=emb_nn,translation_bias_type = translation_bias_type)
+ self.PE = Music_PositionalEncoding(d_model = d_pe, if_index = if_index, if_global_timing = if_global_timing, if_modulo_timing = if_modulo_timing, device = self.device)
+ # self.PE2 = Music_PositionalEncoding(d_model = d_pe, if_index = if_index, if_global_timing = if_global_timing, if_modulo_timing = if_modulo_timing, device = self.device)
+ self.beat_tokenizer = beat_tokenizer(seq_len_beat=beat_len, if_pad = True)
+ self.beat_embedding_layer = Beat_Embedding(self.PE, d_model = d_beat, d_oh_beat_type = d_oh_beat_type)
+ self.chord_embedding_layer = Chord_Embedding(self.FME, self.PE, d_model = d_chord, d_oh_type = d_oh_chord_type, d_oh_inv = d_oh_inv_type)
+ self.chord_tokenizer = chord_tokenizer(seq_len_chord=chord_len, if_pad = True)
+
+
+ def compute_snr(self, timesteps):
+ """
+ Computes SNR as per https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L847-L849
+ """
+ alphas_cumprod = self.noise_scheduler.alphas_cumprod
+ sqrt_alphas_cumprod = alphas_cumprod**0.5
+ sqrt_one_minus_alphas_cumprod = (1.0 - alphas_cumprod) ** 0.5
+
+ # Expand the tensors.
+ # Adapted from https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L1026
+ sqrt_alphas_cumprod = sqrt_alphas_cumprod.to(device=timesteps.device)[timesteps].float()
+ while len(sqrt_alphas_cumprod.shape) < len(timesteps.shape):
+ sqrt_alphas_cumprod = sqrt_alphas_cumprod[..., None]
+ alpha = sqrt_alphas_cumprod.expand(timesteps.shape)
+
+ sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod.to(device=timesteps.device)[timesteps].float()
+ while len(sqrt_one_minus_alphas_cumprod.shape) < len(timesteps.shape):
+ sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod[..., None]
+ sigma = sqrt_one_minus_alphas_cumprod.expand(timesteps.shape)
+
+ # Compute SNR.
+ snr = (alpha / sigma) ** 2
+ return snr
+
+ def encode_text(self, prompt):
+ device = self.text_encoder.device
+ batch = self.tokenizer(
+ prompt, max_length=self.tokenizer.model_max_length, padding=True, truncation=True, return_tensors="pt"
+ )
+ input_ids, attention_mask = batch.input_ids.to(device), batch.attention_mask.to(device) #cuda
+ if self.freeze_text_encoder:
+ with torch.no_grad():
+ encoder_hidden_states = self.text_encoder(
+ input_ids=input_ids, attention_mask=attention_mask
+ )[0] #batch, len_text, dim
+ else:
+ encoder_hidden_states = self.text_encoder(
+ input_ids=input_ids, attention_mask=attention_mask
+ )[0]
+ boolean_encoder_mask = (attention_mask == 1).to(device) ##batch, len_text
+ return encoder_hidden_states, boolean_encoder_mask
+
+ def encode_beats(self, beats):
+ device = self.device
+ out_beat = []
+ out_beat_timing = []
+ out_mask = []
+ for beat in beats:
+ tokenized_beats,tokenized_beats_timing, tokenized_beat_mask = self.beat_tokenizer(beat)
+ out_beat.append(tokenized_beats)
+ out_beat_timing.append(tokenized_beats_timing)
+ out_mask.append(tokenized_beat_mask)
+ out_beat, out_beat_timing, out_mask = torch.tensor(out_beat).to(device), torch.tensor(out_beat_timing).to(device), torch.tensor(out_mask).to(device) #batch, len_beat
+ embedded_beat = self.beat_embedding_layer(out_beat, out_beat_timing, device)
+
+ return embedded_beat, out_mask
+
+ def encode_chords(self, chords, chords_time):
+ device = self.device
+ out_chord_root = []
+ out_chord_type = []
+ out_chord_inv = []
+ out_chord_timing = []
+ out_mask = []
+ for chord, chord_time in zip(chords,chords_time): #batch loop
+ tokenized_chord_root, tokenized_chord_type, tokenized_chord_inv, tokenized_chord_time, tokenized_chord_mask = self.chord_tokenizer(chord, chord_time)
+ out_chord_root.append(tokenized_chord_root)
+ out_chord_type.append(tokenized_chord_type)
+ out_chord_inv.append(tokenized_chord_inv)
+ out_chord_timing.append(tokenized_chord_time)
+ out_mask.append(tokenized_chord_mask)
+ #chords: (B, LEN, 4)
+ out_chord_root, out_chord_type, out_chord_inv, out_chord_timing, out_mask = torch.tensor(out_chord_root).to(device), torch.tensor(out_chord_type).to(device), torch.tensor(out_chord_inv).to(device), torch.tensor(out_chord_timing).to(device), torch.tensor(out_mask).to(device)
+ embedded_chord = self.chord_embedding_layer(out_chord_root, out_chord_type, out_chord_inv, out_chord_timing, device)
+ return embedded_chord, out_mask
+ # return out_chord_root, out_mask
+
+
+ def forward(self, latents, prompt, beats, chords, chords_time, validation_mode=False):
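+ # Training step: encode the text/beat/chord conditions, sample one timestep per
+ # example, add noise to the latents, and regress the UNet prediction onto the
+ # scheduler target (noise for "epsilon", velocity for "v_prediction"), optionally
+ # with min-SNR weighted MSE when snr_gamma is set.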
+ device = self.text_encoder.device
+ num_train_timesteps = self.noise_scheduler.num_train_timesteps
+ self.noise_scheduler.set_timesteps(num_train_timesteps, device=device)
+
+ encoder_hidden_states, boolean_encoder_mask = self.encode_text(prompt)
+
+ # with torch.no_grad():
+ encoded_beats, beat_mask = self.encode_beats(beats) #batch, len_beats, dim; batch, len_beats
+ encoded_chords, chord_mask = self.encode_chords(chords,chords_time)
+
+
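+ # Condition dropout for classifier-free guidance training: each example has a
+ # ~10% chance of having its text, chord and beat conditioning zeroed out.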
+ if self.uncondition:
+ mask_indices = [k for k in range(len(prompt)) if random.random() < 0.1]
+ if len(mask_indices) > 0:
+ encoder_hidden_states[mask_indices] = 0
+ encoded_chords[mask_indices] = 0
+ encoded_beats[mask_indices] = 0
+
+ bsz = latents.shape[0]
+
+ if validation_mode:
+ timesteps = (self.noise_scheduler.num_train_timesteps//2) * torch.ones((bsz,), dtype=torch.int64, device=device)
+ else:
+ timesteps = torch.randint(0, self.noise_scheduler.num_train_timesteps, (bsz,), device=device)
+
+
+ timesteps = timesteps.long()
+
+ noise = torch.randn_like(latents)
+ noisy_latents = self.noise_scheduler.add_noise(latents, noise, timesteps)
+
+ # Get the target for loss depending on the prediction type
+ if self.noise_scheduler.config.prediction_type == "epsilon":
+ target = noise
+ elif self.noise_scheduler.config.prediction_type == "v_prediction":
+ target = self.noise_scheduler.get_velocity(latents, noise, timesteps)
+ else:
+ raise ValueError(f"Unknown prediction type {self.noise_scheduler.config.prediction_type}")
+
+ if self.set_from == "random":
+ # model_pred = torch.zeros((bsz,8,256,16)).to(device)
+ model_pred = self.unet(
+ noisy_latents, timesteps, encoder_hidden_states, encoded_beats, encoded_chords,
+ encoder_attention_mask=boolean_encoder_mask, beat_attention_mask=beat_mask, chord_attention_mask=chord_mask
+ ).sample
+
+ elif self.set_from == "pre-trained":
+ compressed_latents = self.group_in(noisy_latents.permute(0, 2, 3, 1).contiguous()).permute(0, 3, 1, 2).contiguous()
+ model_pred = self.unet(
+ compressed_latents, timesteps, encoder_hidden_states,
+ encoder_attention_mask=boolean_encoder_mask
+ ).sample
+ model_pred = self.group_out(model_pred.permute(0, 2, 3, 1).contiguous()).permute(0, 3, 1, 2).contiguous()
+
+ if self.snr_gamma is None:
+ loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
+ else:
+ # Compute loss weights as per Section 3.4 of https://arxiv.org/abs/2303.09556.
+ # Adapted from huggingface/diffusers/blob/main/examples/text_to_image/train_text_to_image.py
+ snr = self.compute_snr(timesteps)
+ mse_loss_weights = (
+ torch.stack([snr, self.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr
+ )
+ loss = F.mse_loss(model_pred.float(), target.float(), reduction="none")
+ loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights
+ loss = loss.mean()
+
+ return loss
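+ # Illustrative usage sketch only (names such as `optimizer` and the source of the
+ # latents are assumptions, not part of this class):
+ #   loss = model(latents, prompts, beats, chords, chords_time)
+ #   loss.backward(); optimizer.step(); optimizer.zero_grad()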
+
+ @torch.no_grad()
+ def inference(self, prompt, beats, chords, chords_time, inference_scheduler, num_steps=20, guidance_scale=3, num_samples_per_prompt=1,
+ disable_progress=True):
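+ # Reverse diffusion with classifier-free guidance: the unconditional and conditional
+ # batches are stacked, the UNet runs once per step, and the two noise estimates are
+ # combined as noise_uncond + guidance_scale * (noise_text - noise_uncond).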
+ device = self.text_encoder.device
+ classifier_free_guidance = guidance_scale > 1.0
+ batch_size = len(prompt) * num_samples_per_prompt
+
+ if classifier_free_guidance:
+ prompt_embeds, boolean_prompt_mask = self.encode_text_classifier_free(prompt, num_samples_per_prompt)
+ encoded_beats, beat_mask = self.encode_beats_classifier_free(beats, num_samples_per_prompt) #batch, len_beats, dim; batch, len_beats
+ encoded_chords, chord_mask = self.encode_chords_classifier_free(chords, chords_time, num_samples_per_prompt)
+ else:
+ prompt_embeds, boolean_prompt_mask = self.encode_text(prompt)
+ prompt_embeds = prompt_embeds.repeat_interleave(num_samples_per_prompt, 0)
+ boolean_prompt_mask = boolean_prompt_mask.repeat_interleave(num_samples_per_prompt, 0)
+
+ encoded_beats, beat_mask = self.encode_beats(beats) #batch, len_beats, dim; batch, len_beats
+ encoded_beats = encoded_beats.repeat_interleave(num_samples_per_prompt, 0)
+ beat_mask = beat_mask.repeat_interleave(num_samples_per_prompt, 0)
+
+ encoded_chords, chord_mask = self.encode_chords(chords,chords_time)
+ encoded_chords = encoded_chords.repeat_interleave(num_samples_per_prompt, 0)
+ chord_mask = chord_mask.repeat_interleave(num_samples_per_prompt, 0)
+
+ # print(f"encoded_chords:{encoded_chords.shape}, chord_mask:{chord_mask.shape}, prompt_embeds:{prompt_embeds.shape},boolean_prompt_mask:{boolean_prompt_mask.shape} ")
+ inference_scheduler.set_timesteps(num_steps, device=device)
+ timesteps = inference_scheduler.timesteps
+
+ num_channels_latents = self.unet.in_channels
+ latents = self.prepare_latents(batch_size, inference_scheduler, num_channels_latents, prompt_embeds.dtype, device)
+
+ num_warmup_steps = len(timesteps) - num_steps * inference_scheduler.order
+ progress_bar = tqdm(range(num_steps), disable=disable_progress)
+
+ for i, t in enumerate(timesteps):
+ # expand the latents if we are doing classifier free guidance
+ latent_model_input = torch.cat([latents] * 2) if classifier_free_guidance else latents
+ latent_model_input = inference_scheduler.scale_model_input(latent_model_input, t)
+
+ noise_pred = self.unet(
+ latent_model_input, t, encoder_hidden_states=prompt_embeds,
+ encoder_attention_mask=boolean_prompt_mask,
+ beat_features=encoded_beats, beat_attention_mask=beat_mask, chord_features=encoded_chords, chord_attention_mask=chord_mask
+ ).sample
+
+ # perform guidance
+ if classifier_free_guidance: #should work for beats and chords too
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latents = inference_scheduler.step(noise_pred, t, latents).prev_sample
+
+ # call the callback, if provided
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % inference_scheduler.order == 0):
+ progress_bar.update(1)
+
+ if self.set_from == "pre-trained":
+ latents = self.group_out(latents.permute(0, 2, 3, 1).contiguous()).permute(0, 3, 1, 2).contiguous()
+ return latents
+
+ def prepare_latents(self, batch_size, inference_scheduler, num_channels_latents, dtype, device):
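+ # Draw initial Gaussian latents at the model's fixed latent size (256 x 16 per
+ # channel) and scale them by the scheduler's init_noise_sigma.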
+ shape = (batch_size, num_channels_latents, 256, 16)
+ latents = randn_tensor(shape, generator=None, device=device, dtype=dtype)
+ # scale the initial noise by the standard deviation required by the scheduler
+ latents = latents * inference_scheduler.init_noise_sigma
+ return latents
+
+ def encode_text_classifier_free(self, prompt, num_samples_per_prompt):
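+ # Builds a [unconditional; conditional] text-embedding batch; the unconditional half
+ # comes from empty prompts padded to the conditional length, and the ordering matches
+ # noise_pred.chunk(2) in inference().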
+ device = self.text_encoder.device
+ batch = self.tokenizer(
+ prompt, max_length=self.tokenizer.model_max_length, padding=True, truncation=True, return_tensors="pt"
+ )
+ input_ids, attention_mask = batch.input_ids.to(device), batch.attention_mask.to(device)
+
+ with torch.no_grad():
+ prompt_embeds = self.text_encoder(
+ input_ids=input_ids, attention_mask=attention_mask
+ )[0]
+
+ prompt_embeds = prompt_embeds.repeat_interleave(num_samples_per_prompt, 0)
+ attention_mask = attention_mask.repeat_interleave(num_samples_per_prompt, 0)
+
+ # get unconditional embeddings for classifier free guidance
+ # print(len(prompt), 'this is prompt len')
+ uncond_tokens = [""] * len(prompt)
+
+ max_length = prompt_embeds.shape[1]
+ uncond_batch = self.tokenizer(
+ uncond_tokens, max_length=max_length, padding="max_length", truncation=True, return_tensors="pt",
+ )
+ uncond_input_ids = uncond_batch.input_ids.to(device)
+ uncond_attention_mask = uncond_batch.attention_mask.to(device)
+
+ with torch.no_grad():
+ negative_prompt_embeds = self.text_encoder(
+ input_ids=uncond_input_ids, attention_mask=uncond_attention_mask
+ )[0]
+
+ negative_prompt_embeds = negative_prompt_embeds.repeat_interleave(num_samples_per_prompt, 0)
+ uncond_attention_mask = uncond_attention_mask.repeat_interleave(num_samples_per_prompt, 0)
+
+ # For classifier-free guidance we need both an unconditional and a conditional pass;
+ # concatenating the negative and text embeddings into a single batch lets the UNet handle both in one forward call.
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
+ prompt_mask = torch.cat([uncond_attention_mask, attention_mask])
+ boolean_prompt_mask = (prompt_mask == 1).to(device)
+
+ return prompt_embeds, boolean_prompt_mask
+
+
+ def encode_beats_classifier_free(self, beats, num_samples_per_prompt):
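+ # Same [unconditional; conditional] batching as encode_text_classifier_free, applied
+ # to beat features; the unconditional branch tokenizes empty beat lists.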
+ device = self.device
+ with torch.no_grad():
+ out_beat = []
+ out_beat_timing = []
+ out_mask = []
+ for beat in beats:
+ tokenized_beats,tokenized_beats_timing, tokenized_beat_mask = self.beat_tokenizer(beat)
+ out_beat.append(tokenized_beats)
+ out_beat_timing.append(tokenized_beats_timing)
+ out_mask.append(tokenized_beat_mask)
+ out_beat, out_beat_timing, out_mask = torch.tensor(out_beat).to(device), torch.tensor(out_beat_timing).to(device), torch.tensor(out_mask).to(device) #batch, len_beat
+ embedded_beat = self.beat_embedding_layer(out_beat, out_beat_timing, device)
+
+ embedded_beat = embedded_beat.repeat_interleave(num_samples_per_prompt, 0)
+ out_mask = out_mask.repeat_interleave(num_samples_per_prompt, 0)
+
+ uncond_beats = [[[],[]]] * len(beats)
+
+ max_length = embedded_beat.shape[1]
+ with torch.no_grad():
+ out_beat_unc = []
+ out_beat_timing_unc = []
+ out_mask_unc = []
+ for beat in uncond_beats:
+ tokenized_beats, tokenized_beats_timing, tokenized_beat_mask = self.beat_tokenizer(beat)
+ out_beat_unc.append(tokenized_beats)
+ out_beat_timing_unc.append(tokenized_beats_timing)
+ out_mask_unc.append(tokenized_beat_mask)
+ out_beat_unc, out_beat_timing_unc, out_mask_unc = torch.tensor(out_beat_unc).to(device), torch.tensor(out_beat_timing_unc).to(device), torch.tensor(out_mask_unc).to(device) #batch, len_beat
+ embedded_beat_unc = self.beat_embedding_layer(out_beat_unc, out_beat_timing_unc, device)
+
+ embedded_beat_unc = embedded_beat_unc.repeat_interleave(num_samples_per_prompt, 0)
+ out_mask_unc = out_mask_unc.repeat_interleave(num_samples_per_prompt, 0)
+
+ embedded_beat = torch.cat([embedded_beat_unc, embedded_beat])
+ out_mask = torch.cat([out_mask_unc, out_mask])
+
+ return embedded_beat, out_mask
+
+
+ def encode_chords_classifier_free(self, chords, chords_time, num_samples_per_prompt):
+ device = self.device
+ with torch.no_grad():
+ out_chord_root = []
+ out_chord_type = []
+ out_chord_inv = []
+ out_chord_timing = []
+ out_mask = []
+ for chord, chord_time in zip(chords,chords_time): #batch loop
+ tokenized_chord_root, tokenized_chord_type, tokenized_chord_inv, tokenized_chord_time, tokenized_chord_mask = self.chord_tokenizer(chord, chord_time)
+ out_chord_root.append(tokenized_chord_root)
+ out_chord_type.append(tokenized_chord_type)
+ out_chord_inv.append(tokenized_chord_inv)
+ out_chord_timing.append(tokenized_chord_time)
+ out_mask.append(tokenized_chord_mask)
+ out_chord_root, out_chord_type, out_chord_inv, out_chord_timing, out_mask = torch.tensor(out_chord_root).to(device), torch.tensor(out_chord_type).to(device), torch.tensor(out_chord_inv).to(device), torch.tensor(out_chord_timing).to(device), torch.tensor(out_mask).to(device)
+ embedded_chord = self.chord_embedding_layer(out_chord_root, out_chord_type, out_chord_inv, out_chord_timing, device)
+
+ embedded_chord = embedded_chord.repeat_interleave(num_samples_per_prompt, 0)
+ out_mask = out_mask.repeat_interleave(num_samples_per_prompt, 0)
+
+ chords_unc = [[]] * len(chords)
+ chords_time_unc = [[]] * len(chords_time)
+
+ max_length = embedded_chord.shape[1]
+
+ with torch.no_grad():
+ out_chord_root_unc = []
+ out_chord_type_unc = []
+ out_chord_inv_unc = []
+ out_chord_timing_unc = []
+ out_mask_unc = []
+ for chord, chord_time in zip(chords_unc,chords_time_unc): #batch loop
+ tokenized_chord_root, tokenized_chord_type, tokenized_chord_inv, tokenized_chord_time, tokenized_chord_mask = self.chord_tokenizer(chord, chord_time)
+ out_chord_root_unc.append(tokenized_chord_root)
+ out_chord_type_unc.append(tokenized_chord_type)
+ out_chord_inv_unc.append(tokenized_chord_inv)
+ out_chord_timing_unc.append(tokenized_chord_time)
+ out_mask_unc.append(tokenized_chord_mask)
+ out_chord_root_unc, out_chord_type_unc, out_chord_inv_unc, out_chord_timing_unc, out_mask_unc = torch.tensor(out_chord_root_unc).to(device), torch.tensor(out_chord_type_unc).to(device), torch.tensor(out_chord_inv_unc).to(device), torch.tensor(out_chord_timing_unc).to(device), torch.tensor(out_mask_unc).to(device)
+ embedded_chord_unc = self.chord_embedding_layer(out_chord_root_unc, out_chord_type_unc, out_chord_inv_unc, out_chord_timing_unc, device)
+
+
+ embedded_chord_unc = embedded_chord_unc.repeat_interleave(num_samples_per_prompt, 0)
+ out_mask_unc = out_mask_unc.repeat_interleave(num_samples_per_prompt, 0)
+
+ embedded_chord = torch.cat([embedded_chord_unc, embedded_chord])
+ out_mask = torch.cat([out_mask_unc, out_mask])
+
+ return embedded_chord, out_mask
diff --git a/.ipynb_checkpoints/requirements-checkpoint.txt b/.ipynb_checkpoints/requirements-checkpoint.txt
index bf57c2bb0f671f329df16ad09dc37fcd9f71a6f9..d57857ffe6f63d2840feb944b4acd123e2e84c71 100644
--- a/.ipynb_checkpoints/requirements-checkpoint.txt
+++ b/.ipynb_checkpoints/requirements-checkpoint.txt
@@ -1,12 +1,12 @@
-torch==1.13.1
-torchaudio==0.13.1
-torchvision==0.14.1
-transformers==4.27.0
-accelerate==0.18.0
+torch==2.0.1
+torchaudio==2.0.2
+torchvision==0.15.2
+transformers==4.31.0
+accelerate==0.21.0
datasets==2.1.0
einops==0.6.1
h5py==3.8.0
-huggingface_hub==0.13.3
+huggingface_hub==0.19.4
importlib_metadata==6.3.0
librosa==0.9.2
matplotlib==3.5.2
@@ -17,6 +17,7 @@ pandas==1.4.1
progressbar33==2.4
protobuf==3.20.*
resampy==0.4.2
+safetensors==0.3.2
sentencepiece==0.1.99
scikit_image==0.19.3
scikit_learn==1.2.2
diff --git a/__pycache__/modelling_deberta_v2.cpython-310.pyc b/__pycache__/modelling_deberta_v2.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..82acd852688915436b649aca16fae940e0ab15ca
Binary files /dev/null and b/__pycache__/modelling_deberta_v2.cpython-310.pyc differ
diff --git a/__pycache__/models.cpython-310.pyc b/__pycache__/models.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..c40ff371b2eb0d45af66c0125a41feb17daf8ab2
Binary files /dev/null and b/__pycache__/models.cpython-310.pyc differ
diff --git a/app.py b/app.py
index aa0703928777487a88eb6c87fd31edb01a1e6252..869a2f3eb47978921e3158f2c0b9dd4681ca9aba 100644
--- a/app.py
+++ b/app.py
@@ -2,6 +2,7 @@ import gradio as gr
import json
import torch
import wavio
+import numpy as np
from tqdm import tqdm
from huggingface_hub import snapshot_download
@@ -23,6 +24,7 @@ class MusicFeaturePredictor:
def __init__(self, path, device="cuda:0", cache_dir=None, local_files_only=False):
self.beats_tokenizer = AutoTokenizer.from_pretrained(
"microsoft/deberta-v3-large",
+ use_fast=False,
cache_dir=cache_dir,
local_files_only=local_files_only,
)
@@ -164,6 +166,7 @@ class Mustango:
main_config["scheduler_name"],
unet_model_config_path=f"{path}/configs/music_diffusion_model_config.json",
).to(device)
+ self.model.device = device
vae_weights = torch.load(
f"{path}/vae/pytorch_model_vae.bin", map_location=device
@@ -213,9 +216,11 @@ class Mustango:
# Initialize Mustango
if torch.cuda.is_available():
- mustango = Mustango()
+ mustango = Mustango(device="cpu")
else:
mustango = Mustango(device="cpu")
+
+output_wave = mustango.generate("This techno song features a synth lead playing the main melody.", 5, 3, disable_progress=False)
def gradio_generate(prompt, steps, guidance):
output_wave = mustango.generate(prompt, steps, guidance)
@@ -225,6 +230,7 @@ def gradio_generate(prompt, steps, guidance):
return output_filename
+
# description_text = """
# For faster inference without waiting in queue, you may duplicate the space and upgrade to a GPU in the settings.
# Generate music using Mustango by providing a text prompt.
diff --git a/audioldm/__pycache__/__init__.cpython-310.pyc b/audioldm/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..394b466c2988b4fd8a7a90f7a5108e38e404ec34
Binary files /dev/null and b/audioldm/__pycache__/__init__.cpython-310.pyc differ
diff --git a/audioldm/__pycache__/ldm.cpython-310.pyc b/audioldm/__pycache__/ldm.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..0760f0fe27d2ffdcd73eb51b06d36b8af8ec423c
Binary files /dev/null and b/audioldm/__pycache__/ldm.cpython-310.pyc differ
diff --git a/audioldm/__pycache__/pipeline.cpython-310.pyc b/audioldm/__pycache__/pipeline.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..8b8160c2e6244c3ae05e1d3c414929a2b427578e
Binary files /dev/null and b/audioldm/__pycache__/pipeline.cpython-310.pyc differ
diff --git a/audioldm/__pycache__/utils.cpython-310.pyc b/audioldm/__pycache__/utils.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..8685e28ce2a7bd2474d38406a661233120156347
Binary files /dev/null and b/audioldm/__pycache__/utils.cpython-310.pyc differ
diff --git a/audioldm/audio/__pycache__/__init__.cpython-310.pyc b/audioldm/audio/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..a55a703eea963145e7914cb738d1410e45915c1a
Binary files /dev/null and b/audioldm/audio/__pycache__/__init__.cpython-310.pyc differ
diff --git a/audioldm/audio/__pycache__/audio_processing.cpython-310.pyc b/audioldm/audio/__pycache__/audio_processing.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..690a842361820408d71ba597d3a0b02797a5c7f6
Binary files /dev/null and b/audioldm/audio/__pycache__/audio_processing.cpython-310.pyc differ
diff --git a/audioldm/audio/__pycache__/stft.cpython-310.pyc b/audioldm/audio/__pycache__/stft.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..f7b23c07f7f475639177ad0942fcace0ffd6b876
Binary files /dev/null and b/audioldm/audio/__pycache__/stft.cpython-310.pyc differ
diff --git a/audioldm/audio/__pycache__/tools.cpython-310.pyc b/audioldm/audio/__pycache__/tools.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..096a1cc5b80886cb25fb2bd9030b6586df02187d
Binary files /dev/null and b/audioldm/audio/__pycache__/tools.cpython-310.pyc differ
diff --git a/audioldm/hifigan/__pycache__/__init__.cpython-310.pyc b/audioldm/hifigan/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..b746991833ef6121dd3e16d672d8f42f921b6e62
Binary files /dev/null and b/audioldm/hifigan/__pycache__/__init__.cpython-310.pyc differ
diff --git a/audioldm/hifigan/__pycache__/models.cpython-310.pyc b/audioldm/hifigan/__pycache__/models.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..bd0f4f290e3203d344e740c73934f8c61aa73b44
Binary files /dev/null and b/audioldm/hifigan/__pycache__/models.cpython-310.pyc differ
diff --git a/audioldm/hifigan/__pycache__/utilities.cpython-310.pyc b/audioldm/hifigan/__pycache__/utilities.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..a6212b032c4bea702b131fe992c886959775f229
Binary files /dev/null and b/audioldm/hifigan/__pycache__/utilities.cpython-310.pyc differ
diff --git a/audioldm/latent_diffusion/__pycache__/__init__.cpython-310.pyc b/audioldm/latent_diffusion/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..486b822a287344fcc07dc46fc7e72a384ea05ea7
Binary files /dev/null and b/audioldm/latent_diffusion/__pycache__/__init__.cpython-310.pyc differ
diff --git a/audioldm/latent_diffusion/__pycache__/attention.cpython-310.pyc b/audioldm/latent_diffusion/__pycache__/attention.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..813f57c2828c51522baaa682a4987b1b1c38ce57
Binary files /dev/null and b/audioldm/latent_diffusion/__pycache__/attention.cpython-310.pyc differ
diff --git a/audioldm/latent_diffusion/__pycache__/ddim.cpython-310.pyc b/audioldm/latent_diffusion/__pycache__/ddim.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..aeda4e5729dccea1d6f3072a7c6f56867787e290
Binary files /dev/null and b/audioldm/latent_diffusion/__pycache__/ddim.cpython-310.pyc differ
diff --git a/audioldm/latent_diffusion/__pycache__/ddpm.cpython-310.pyc b/audioldm/latent_diffusion/__pycache__/ddpm.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..f215700808a76356fa0794855041c031fb034ad1
Binary files /dev/null and b/audioldm/latent_diffusion/__pycache__/ddpm.cpython-310.pyc differ
diff --git a/audioldm/latent_diffusion/__pycache__/ema.cpython-310.pyc b/audioldm/latent_diffusion/__pycache__/ema.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..000683591b3dadb5c564b1ba06275cf949e297df
Binary files /dev/null and b/audioldm/latent_diffusion/__pycache__/ema.cpython-310.pyc differ
diff --git a/audioldm/latent_diffusion/__pycache__/util.cpython-310.pyc b/audioldm/latent_diffusion/__pycache__/util.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..17a6eacb3fbaae64ed7d66419ddc0af613eb71c8
Binary files /dev/null and b/audioldm/latent_diffusion/__pycache__/util.cpython-310.pyc differ
diff --git a/audioldm/variational_autoencoder/__pycache__/__init__.cpython-310.pyc b/audioldm/variational_autoencoder/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..8b8a54fed3a335cbf6dc4c90864c31dc0004bd07
Binary files /dev/null and b/audioldm/variational_autoencoder/__pycache__/__init__.cpython-310.pyc differ
diff --git a/audioldm/variational_autoencoder/__pycache__/autoencoder.cpython-310.pyc b/audioldm/variational_autoencoder/__pycache__/autoencoder.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..765d78230e36f30b928da047a31c1a635bfb0823
Binary files /dev/null and b/audioldm/variational_autoencoder/__pycache__/autoencoder.cpython-310.pyc differ
diff --git a/audioldm/variational_autoencoder/__pycache__/distributions.cpython-310.pyc b/audioldm/variational_autoencoder/__pycache__/distributions.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..cf601683e015a37939b5fc4e5fbc064c4727fa7b
Binary files /dev/null and b/audioldm/variational_autoencoder/__pycache__/distributions.cpython-310.pyc differ
diff --git a/audioldm/variational_autoencoder/__pycache__/modules.cpython-310.pyc b/audioldm/variational_autoencoder/__pycache__/modules.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..de4d11a7f30fe1c2b73d266f80472e5661ec555d
Binary files /dev/null and b/audioldm/variational_autoencoder/__pycache__/modules.cpython-310.pyc differ
diff --git a/diffusers/src/diffusers/__pycache__/__init__.cpython-310.pyc b/diffusers/src/diffusers/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..8e8ca64596a6b6e3df9934bf5fc5d308827d1e34
Binary files /dev/null and b/diffusers/src/diffusers/__pycache__/__init__.cpython-310.pyc differ
diff --git a/diffusers/src/diffusers/__pycache__/configuration_utils.cpython-310.pyc b/diffusers/src/diffusers/__pycache__/configuration_utils.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..c31fbcbc828e98d0238434b13a407036217dd42b
Binary files /dev/null and b/diffusers/src/diffusers/__pycache__/configuration_utils.cpython-310.pyc differ
diff --git a/diffusers/src/diffusers/__pycache__/image_processor.cpython-310.pyc b/diffusers/src/diffusers/__pycache__/image_processor.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..37eaec95cdf868229ce1bb02d1f98d88ea23b9fb
Binary files /dev/null and b/diffusers/src/diffusers/__pycache__/image_processor.cpython-310.pyc differ
diff --git a/diffusers/src/diffusers/__pycache__/loaders.cpython-310.pyc b/diffusers/src/diffusers/__pycache__/loaders.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..338623d82eb037dcd577385e1cf96561edc561e4
Binary files /dev/null and b/diffusers/src/diffusers/__pycache__/loaders.cpython-310.pyc differ
diff --git a/diffusers/src/diffusers/__pycache__/optimization.cpython-310.pyc b/diffusers/src/diffusers/__pycache__/optimization.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..455bef21c3882ac9e04f47c76bd1ab67abf0db3c
Binary files /dev/null and b/diffusers/src/diffusers/__pycache__/optimization.cpython-310.pyc differ
diff --git a/diffusers/src/diffusers/__pycache__/pipeline_utils.cpython-310.pyc b/diffusers/src/diffusers/__pycache__/pipeline_utils.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..9fc7da966600d1cfedd705e7c6aab6c41967396d
Binary files /dev/null and b/diffusers/src/diffusers/__pycache__/pipeline_utils.cpython-310.pyc differ
diff --git a/diffusers/src/diffusers/__pycache__/training_utils.cpython-310.pyc b/diffusers/src/diffusers/__pycache__/training_utils.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..4d7658a10822179eee3a04303038f207bf278811
Binary files /dev/null and b/diffusers/src/diffusers/__pycache__/training_utils.cpython-310.pyc differ
diff --git a/diffusers/src/diffusers/models/__pycache__/__init__.cpython-310.pyc b/diffusers/src/diffusers/models/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..10f2516a2ec0e880addb5ad3b4fff8f7e8483080
Binary files /dev/null and b/diffusers/src/diffusers/models/__pycache__/__init__.cpython-310.pyc differ
diff --git a/diffusers/src/diffusers/models/__pycache__/attention.cpython-310.pyc b/diffusers/src/diffusers/models/__pycache__/attention.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..6a4d184cf867f2459b50efab957a5e9cb0c7780c
Binary files /dev/null and b/diffusers/src/diffusers/models/__pycache__/attention.cpython-310.pyc differ
diff --git a/diffusers/src/diffusers/models/__pycache__/attention_processor.cpython-310.pyc b/diffusers/src/diffusers/models/__pycache__/attention_processor.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..e0617a1eb27b311fcd34f42acee2ac3b7e0345df
Binary files /dev/null and b/diffusers/src/diffusers/models/__pycache__/attention_processor.cpython-310.pyc differ
diff --git a/diffusers/src/diffusers/models/__pycache__/autoencoder_kl.cpython-310.pyc b/diffusers/src/diffusers/models/__pycache__/autoencoder_kl.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..3a21dc8c2ab40cd7645ae62e260a752a2ae1c460
Binary files /dev/null and b/diffusers/src/diffusers/models/__pycache__/autoencoder_kl.cpython-310.pyc differ
diff --git a/diffusers/src/diffusers/models/__pycache__/controlnet.cpython-310.pyc b/diffusers/src/diffusers/models/__pycache__/controlnet.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..43f7a2d2673d7dd214968d7d2f919c53aa212314
Binary files /dev/null and b/diffusers/src/diffusers/models/__pycache__/controlnet.cpython-310.pyc differ
diff --git a/diffusers/src/diffusers/models/__pycache__/dual_transformer_2d.cpython-310.pyc b/diffusers/src/diffusers/models/__pycache__/dual_transformer_2d.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..c2aed246eae48aab6864e50257ee1245a0c5c45d
Binary files /dev/null and b/diffusers/src/diffusers/models/__pycache__/dual_transformer_2d.cpython-310.pyc differ
diff --git a/diffusers/src/diffusers/models/__pycache__/embeddings.cpython-310.pyc b/diffusers/src/diffusers/models/__pycache__/embeddings.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..72c587ed1563b49233062035924c43c8765dbe82
Binary files /dev/null and b/diffusers/src/diffusers/models/__pycache__/embeddings.cpython-310.pyc differ
diff --git a/diffusers/src/diffusers/models/__pycache__/modeling_utils.cpython-310.pyc b/diffusers/src/diffusers/models/__pycache__/modeling_utils.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..a6dbe26b02823c4c91614d7622df7af716ac2847
Binary files /dev/null and b/diffusers/src/diffusers/models/__pycache__/modeling_utils.cpython-310.pyc differ
diff --git a/diffusers/src/diffusers/models/__pycache__/prior_transformer.cpython-310.pyc b/diffusers/src/diffusers/models/__pycache__/prior_transformer.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..ce1fbd8db4802cf9357b171bdfaae2acca0ea982
Binary files /dev/null and b/diffusers/src/diffusers/models/__pycache__/prior_transformer.cpython-310.pyc differ
diff --git a/diffusers/src/diffusers/models/__pycache__/resnet.cpython-310.pyc b/diffusers/src/diffusers/models/__pycache__/resnet.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..cd7821e844b306249e9bc63f25a40ea4e4d68684
Binary files /dev/null and b/diffusers/src/diffusers/models/__pycache__/resnet.cpython-310.pyc differ
diff --git a/diffusers/src/diffusers/models/__pycache__/t5_film_transformer.cpython-310.pyc b/diffusers/src/diffusers/models/__pycache__/t5_film_transformer.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..35322111af383d5c207edcf368b6537899869860
Binary files /dev/null and b/diffusers/src/diffusers/models/__pycache__/t5_film_transformer.cpython-310.pyc differ
diff --git a/diffusers/src/diffusers/models/__pycache__/transformer_2d.cpython-310.pyc b/diffusers/src/diffusers/models/__pycache__/transformer_2d.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..5fb38fb88374d128bda0f7f81621d8f9fc2311e5
Binary files /dev/null and b/diffusers/src/diffusers/models/__pycache__/transformer_2d.cpython-310.pyc differ
diff --git a/diffusers/src/diffusers/models/__pycache__/transformer_temporal.cpython-310.pyc b/diffusers/src/diffusers/models/__pycache__/transformer_temporal.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..f6abd254c082509aad4435211e42913fbabd9227
Binary files /dev/null and b/diffusers/src/diffusers/models/__pycache__/transformer_temporal.cpython-310.pyc differ
diff --git a/diffusers/src/diffusers/models/__pycache__/unet_1d.cpython-310.pyc b/diffusers/src/diffusers/models/__pycache__/unet_1d.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..c8c45e57f0daa93d0c9b3f09674e352582b9c26b
Binary files /dev/null and b/diffusers/src/diffusers/models/__pycache__/unet_1d.cpython-310.pyc differ
diff --git a/diffusers/src/diffusers/models/__pycache__/unet_1d_blocks.cpython-310.pyc b/diffusers/src/diffusers/models/__pycache__/unet_1d_blocks.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..ce9e1698743c71ca0eabfce0d4166a6d8c21dee5
Binary files /dev/null and b/diffusers/src/diffusers/models/__pycache__/unet_1d_blocks.cpython-310.pyc differ
diff --git a/diffusers/src/diffusers/models/__pycache__/unet_2d.cpython-310.pyc b/diffusers/src/diffusers/models/__pycache__/unet_2d.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..f170d383912f24addbf5af56f8f8401e8e44a60c
Binary files /dev/null and b/diffusers/src/diffusers/models/__pycache__/unet_2d.cpython-310.pyc differ
diff --git a/diffusers/src/diffusers/models/__pycache__/unet_2d_blocks.cpython-310.pyc b/diffusers/src/diffusers/models/__pycache__/unet_2d_blocks.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..1eb6c929e9489bbd0b51a8be00223fb4a02390f0
Binary files /dev/null and b/diffusers/src/diffusers/models/__pycache__/unet_2d_blocks.cpython-310.pyc differ
diff --git a/diffusers/src/diffusers/models/__pycache__/unet_2d_condition.cpython-310.pyc b/diffusers/src/diffusers/models/__pycache__/unet_2d_condition.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..068aeee8d84d786452c01afedc4979a28309db85
Binary files /dev/null and b/diffusers/src/diffusers/models/__pycache__/unet_2d_condition.cpython-310.pyc differ
diff --git a/diffusers/src/diffusers/models/__pycache__/unet_2d_condition_music.cpython-310.pyc b/diffusers/src/diffusers/models/__pycache__/unet_2d_condition_music.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..c36ebac5869777d5c7cd8003d07c7041d8ba2c5c
Binary files /dev/null and b/diffusers/src/diffusers/models/__pycache__/unet_2d_condition_music.cpython-310.pyc differ
diff --git a/diffusers/src/diffusers/models/__pycache__/unet_3d_blocks.cpython-310.pyc b/diffusers/src/diffusers/models/__pycache__/unet_3d_blocks.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..e799dfd0059b0ff5f250386bc01969d7d34a5a43
Binary files /dev/null and b/diffusers/src/diffusers/models/__pycache__/unet_3d_blocks.cpython-310.pyc differ
diff --git a/diffusers/src/diffusers/models/__pycache__/unet_3d_condition.cpython-310.pyc b/diffusers/src/diffusers/models/__pycache__/unet_3d_condition.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..2d4474ae55ab08b1b1247b1de96457a77560e6ee
Binary files /dev/null and b/diffusers/src/diffusers/models/__pycache__/unet_3d_condition.cpython-310.pyc differ
diff --git a/diffusers/src/diffusers/models/__pycache__/vae.cpython-310.pyc b/diffusers/src/diffusers/models/__pycache__/vae.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..03e839e5b0e42753291d4448a1215a4106c9597f
Binary files /dev/null and b/diffusers/src/diffusers/models/__pycache__/vae.cpython-310.pyc differ
diff --git a/diffusers/src/diffusers/models/__pycache__/vq_model.cpython-310.pyc b/diffusers/src/diffusers/models/__pycache__/vq_model.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..9ade68461bfd1e17226b381eee3d6f5a70c5eeeb
Binary files /dev/null and b/diffusers/src/diffusers/models/__pycache__/vq_model.cpython-310.pyc differ
diff --git a/diffusers/src/diffusers/pipelines/__pycache__/__init__.cpython-310.pyc b/diffusers/src/diffusers/pipelines/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..733d0b24f413795d7e319ecc453877f6b37c01bc
Binary files /dev/null and b/diffusers/src/diffusers/pipelines/__pycache__/__init__.cpython-310.pyc differ
diff --git a/diffusers/src/diffusers/pipelines/__pycache__/pipeline_utils.cpython-310.pyc b/diffusers/src/diffusers/pipelines/__pycache__/pipeline_utils.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..803df7b10d0f1c64d5deaba473683dab66567418
Binary files /dev/null and b/diffusers/src/diffusers/pipelines/__pycache__/pipeline_utils.cpython-310.pyc differ
diff --git a/diffusers/src/diffusers/pipelines/alt_diffusion/__pycache__/__init__.cpython-310.pyc b/diffusers/src/diffusers/pipelines/alt_diffusion/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..293947bd0bc8dda1b967d17d287d7dbe1d71a167
Binary files /dev/null and b/diffusers/src/diffusers/pipelines/alt_diffusion/__pycache__/__init__.cpython-310.pyc differ
diff --git a/diffusers/src/diffusers/pipelines/alt_diffusion/__pycache__/modeling_roberta_series.cpython-310.pyc b/diffusers/src/diffusers/pipelines/alt_diffusion/__pycache__/modeling_roberta_series.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..d147534b26c19cbe8423ccedf82d170c4bffa653
Binary files /dev/null and b/diffusers/src/diffusers/pipelines/alt_diffusion/__pycache__/modeling_roberta_series.cpython-310.pyc differ
diff --git a/diffusers/src/diffusers/pipelines/alt_diffusion/__pycache__/pipeline_alt_diffusion.cpython-310.pyc b/diffusers/src/diffusers/pipelines/alt_diffusion/__pycache__/pipeline_alt_diffusion.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..83e2ba659f548b9805d276a033ef6c9e61f5eaaa
Binary files /dev/null and b/diffusers/src/diffusers/pipelines/alt_diffusion/__pycache__/pipeline_alt_diffusion.cpython-310.pyc differ
diff --git a/diffusers/src/diffusers/pipelines/alt_diffusion/__pycache__/pipeline_alt_diffusion_img2img.cpython-310.pyc b/diffusers/src/diffusers/pipelines/alt_diffusion/__pycache__/pipeline_alt_diffusion_img2img.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..e9e22397ac94080a408024f1c913d6067b158c3c
Binary files /dev/null and b/diffusers/src/diffusers/pipelines/alt_diffusion/__pycache__/pipeline_alt_diffusion_img2img.cpython-310.pyc differ
diff --git a/diffusers/src/diffusers/pipelines/audio_diffusion/__pycache__/__init__.cpython-310.pyc b/diffusers/src/diffusers/pipelines/audio_diffusion/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..dbb99f323d5baaa247437fd7d13ea3a0dd69ebce
Binary files /dev/null and b/diffusers/src/diffusers/pipelines/audio_diffusion/__pycache__/__init__.cpython-310.pyc differ
diff --git a/diffusers/src/diffusers/pipelines/audio_diffusion/__pycache__/mel.cpython-310.pyc b/diffusers/src/diffusers/pipelines/audio_diffusion/__pycache__/mel.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..0fc959d1dc80395fd113db35fe6192a64ce1b9ab
Binary files /dev/null and b/diffusers/src/diffusers/pipelines/audio_diffusion/__pycache__/mel.cpython-310.pyc differ
diff --git a/diffusers/src/diffusers/pipelines/audio_diffusion/__pycache__/pipeline_audio_diffusion.cpython-310.pyc b/diffusers/src/diffusers/pipelines/audio_diffusion/__pycache__/pipeline_audio_diffusion.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..a52719911ccc24b2985551ae2df70ffd2994c36e
Binary files /dev/null and b/diffusers/src/diffusers/pipelines/audio_diffusion/__pycache__/pipeline_audio_diffusion.cpython-310.pyc differ
diff --git a/diffusers/src/diffusers/pipelines/audioldm/__pycache__/__init__.cpython-310.pyc b/diffusers/src/diffusers/pipelines/audioldm/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..c238f12f5909910ce689024252816e5a94f73aa8
Binary files /dev/null and b/diffusers/src/diffusers/pipelines/audioldm/__pycache__/__init__.cpython-310.pyc differ
diff --git a/diffusers/src/diffusers/pipelines/audioldm/__pycache__/pipeline_audioldm.cpython-310.pyc b/diffusers/src/diffusers/pipelines/audioldm/__pycache__/pipeline_audioldm.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..8985e8c2af8731c426ffdff51521b77d30be20d7
Binary files /dev/null and b/diffusers/src/diffusers/pipelines/audioldm/__pycache__/pipeline_audioldm.cpython-310.pyc differ
diff --git a/diffusers/src/diffusers/pipelines/dance_diffusion/__pycache__/__init__.cpython-310.pyc b/diffusers/src/diffusers/pipelines/dance_diffusion/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..d6e91b5a3928f616488e7f5a0f4c73fa2a0ad8d8
Binary files /dev/null and b/diffusers/src/diffusers/pipelines/dance_diffusion/__pycache__/__init__.cpython-310.pyc differ
diff --git a/diffusers/src/diffusers/pipelines/dance_diffusion/__pycache__/pipeline_dance_diffusion.cpython-310.pyc b/diffusers/src/diffusers/pipelines/dance_diffusion/__pycache__/pipeline_dance_diffusion.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..de89fa8cb735bf7f078ee238debd6156501c7b8a
Binary files /dev/null and b/diffusers/src/diffusers/pipelines/dance_diffusion/__pycache__/pipeline_dance_diffusion.cpython-310.pyc differ
diff --git a/diffusers/src/diffusers/pipelines/ddim/__pycache__/__init__.cpython-310.pyc b/diffusers/src/diffusers/pipelines/ddim/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..9e7b688d9c324b88b7e96fc7a038dd7227615ae6
Binary files /dev/null and b/diffusers/src/diffusers/pipelines/ddim/__pycache__/__init__.cpython-310.pyc differ
diff --git a/diffusers/src/diffusers/pipelines/ddim/__pycache__/pipeline_ddim.cpython-310.pyc b/diffusers/src/diffusers/pipelines/ddim/__pycache__/pipeline_ddim.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..bbbfa0b12ab391b96efbe970ed23b996a86578e2
Binary files /dev/null and b/diffusers/src/diffusers/pipelines/ddim/__pycache__/pipeline_ddim.cpython-310.pyc differ
diff --git a/diffusers/src/diffusers/pipelines/ddpm/__pycache__/__init__.cpython-310.pyc b/diffusers/src/diffusers/pipelines/ddpm/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..a305bd4a15b5c006d3b13e7890e67b305c72e747
Binary files /dev/null and b/diffusers/src/diffusers/pipelines/ddpm/__pycache__/__init__.cpython-310.pyc differ
diff --git a/diffusers/src/diffusers/pipelines/ddpm/__pycache__/pipeline_ddpm.cpython-310.pyc b/diffusers/src/diffusers/pipelines/ddpm/__pycache__/pipeline_ddpm.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..74d6225ff286ca6d838b86fc2a7ecdc6b9378107
Binary files /dev/null and b/diffusers/src/diffusers/pipelines/ddpm/__pycache__/pipeline_ddpm.cpython-310.pyc differ
diff --git a/diffusers/src/diffusers/pipelines/dit/__pycache__/__init__.cpython-310.pyc b/diffusers/src/diffusers/pipelines/dit/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..e53ef5e0f0e9cfa164127e129e3d8c9c0cbf153d
Binary files /dev/null and b/diffusers/src/diffusers/pipelines/dit/__pycache__/__init__.cpython-310.pyc differ
diff --git a/diffusers/src/diffusers/pipelines/dit/__pycache__/pipeline_dit.cpython-310.pyc b/diffusers/src/diffusers/pipelines/dit/__pycache__/pipeline_dit.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..99088c60e316e7868e99def69f351f2874ae355e
Binary files /dev/null and b/diffusers/src/diffusers/pipelines/dit/__pycache__/pipeline_dit.cpython-310.pyc differ
diff --git a/diffusers/src/diffusers/pipelines/latent_diffusion/__pycache__/__init__.cpython-310.pyc b/diffusers/src/diffusers/pipelines/latent_diffusion/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..ffa3340ed99bec3df8ea2488a5c0c180817bb53e
Binary files /dev/null and b/diffusers/src/diffusers/pipelines/latent_diffusion/__pycache__/__init__.cpython-310.pyc differ
diff --git a/diffusers/src/diffusers/pipelines/latent_diffusion/__pycache__/pipeline_latent_diffusion.cpython-310.pyc b/diffusers/src/diffusers/pipelines/latent_diffusion/__pycache__/pipeline_latent_diffusion.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..6509c25b8751a3f3ff7e4f02831b465f8a858353
Binary files /dev/null and b/diffusers/src/diffusers/pipelines/latent_diffusion/__pycache__/pipeline_latent_diffusion.cpython-310.pyc differ
diff --git a/diffusers/src/diffusers/pipelines/latent_diffusion/__pycache__/pipeline_latent_diffusion_superresolution.cpython-310.pyc b/diffusers/src/diffusers/pipelines/latent_diffusion/__pycache__/pipeline_latent_diffusion_superresolution.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..9b7bbd49b17ed8f4a20ac971c23e2fea2e7aeab2
Binary files /dev/null and b/diffusers/src/diffusers/pipelines/latent_diffusion/__pycache__/pipeline_latent_diffusion_superresolution.cpython-310.pyc differ
diff --git a/diffusers/src/diffusers/pipelines/latent_diffusion_uncond/__pycache__/__init__.cpython-310.pyc b/diffusers/src/diffusers/pipelines/latent_diffusion_uncond/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..7e1bd890f73004f3737c5b44c2af412242cb691a
Binary files /dev/null and b/diffusers/src/diffusers/pipelines/latent_diffusion_uncond/__pycache__/__init__.cpython-310.pyc differ
diff --git a/diffusers/src/diffusers/pipelines/latent_diffusion_uncond/__pycache__/pipeline_latent_diffusion_uncond.cpython-310.pyc b/diffusers/src/diffusers/pipelines/latent_diffusion_uncond/__pycache__/pipeline_latent_diffusion_uncond.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..6b5d2de322cad8cc0ecd4fd4a1e5ecfad07a1e5d
Binary files /dev/null and b/diffusers/src/diffusers/pipelines/latent_diffusion_uncond/__pycache__/pipeline_latent_diffusion_uncond.cpython-310.pyc differ
diff --git a/diffusers/src/diffusers/pipelines/paint_by_example/__pycache__/__init__.cpython-310.pyc b/diffusers/src/diffusers/pipelines/paint_by_example/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..738fe8f2de8749b8571139cb8c8b04156e8c9af5
Binary files /dev/null and b/diffusers/src/diffusers/pipelines/paint_by_example/__pycache__/__init__.cpython-310.pyc differ
diff --git a/diffusers/src/diffusers/pipelines/paint_by_example/__pycache__/image_encoder.cpython-310.pyc b/diffusers/src/diffusers/pipelines/paint_by_example/__pycache__/image_encoder.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..0419df6dea3e2f116642990e707ad37fcd2f013f
Binary files /dev/null and b/diffusers/src/diffusers/pipelines/paint_by_example/__pycache__/image_encoder.cpython-310.pyc differ
diff --git a/diffusers/src/diffusers/pipelines/paint_by_example/__pycache__/pipeline_paint_by_example.cpython-310.pyc b/diffusers/src/diffusers/pipelines/paint_by_example/__pycache__/pipeline_paint_by_example.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..80d07eff3c7f115d6847ed4f9fbaf1cf6aa28fbe
Binary files /dev/null and b/diffusers/src/diffusers/pipelines/paint_by_example/__pycache__/pipeline_paint_by_example.cpython-310.pyc differ
diff --git a/diffusers/src/diffusers/pipelines/pndm/__pycache__/__init__.cpython-310.pyc b/diffusers/src/diffusers/pipelines/pndm/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..92ffe71d6a054449e9ddce29076295274b616efc
Binary files /dev/null and b/diffusers/src/diffusers/pipelines/pndm/__pycache__/__init__.cpython-310.pyc differ
diff --git a/diffusers/src/diffusers/pipelines/pndm/__pycache__/pipeline_pndm.cpython-310.pyc b/diffusers/src/diffusers/pipelines/pndm/__pycache__/pipeline_pndm.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..bff57cd753778c97fbddfa907923ee90271985d3
Binary files /dev/null and b/diffusers/src/diffusers/pipelines/pndm/__pycache__/pipeline_pndm.cpython-310.pyc differ
diff --git a/diffusers/src/diffusers/pipelines/repaint/__pycache__/__init__.cpython-310.pyc b/diffusers/src/diffusers/pipelines/repaint/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..40e41ed1e7096bf13d9695c86c5f7368b0620cc0
Binary files /dev/null and b/diffusers/src/diffusers/pipelines/repaint/__pycache__/__init__.cpython-310.pyc differ
diff --git a/diffusers/src/diffusers/pipelines/repaint/__pycache__/pipeline_repaint.cpython-310.pyc b/diffusers/src/diffusers/pipelines/repaint/__pycache__/pipeline_repaint.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..77bfe866458de39025e2b997743e96eb847eeb67
Binary files /dev/null and b/diffusers/src/diffusers/pipelines/repaint/__pycache__/pipeline_repaint.cpython-310.pyc differ
diff --git a/diffusers/src/diffusers/pipelines/score_sde_ve/__pycache__/__init__.cpython-310.pyc b/diffusers/src/diffusers/pipelines/score_sde_ve/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..0bf261a5cb77f9f55fb0163c033ff439531c3c56
Binary files /dev/null and b/diffusers/src/diffusers/pipelines/score_sde_ve/__pycache__/__init__.cpython-310.pyc differ
diff --git a/diffusers/src/diffusers/pipelines/score_sde_ve/__pycache__/pipeline_score_sde_ve.cpython-310.pyc b/diffusers/src/diffusers/pipelines/score_sde_ve/__pycache__/pipeline_score_sde_ve.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..812fb54ba118d67c2bcf56409ce7d180f0b298cd
Binary files /dev/null and b/diffusers/src/diffusers/pipelines/score_sde_ve/__pycache__/pipeline_score_sde_ve.cpython-310.pyc differ
diff --git a/diffusers/src/diffusers/pipelines/semantic_stable_diffusion/__pycache__/__init__.cpython-310.pyc b/diffusers/src/diffusers/pipelines/semantic_stable_diffusion/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..13e95647241fe3b7cf5b2a2a656d5a4e41544617
Binary files /dev/null and b/diffusers/src/diffusers/pipelines/semantic_stable_diffusion/__pycache__/__init__.cpython-310.pyc differ
diff --git a/diffusers/src/diffusers/pipelines/semantic_stable_diffusion/__pycache__/pipeline_semantic_stable_diffusion.cpython-310.pyc b/diffusers/src/diffusers/pipelines/semantic_stable_diffusion/__pycache__/pipeline_semantic_stable_diffusion.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..cf5e99f7b4cf7042d841662d38034a70da5698dc
Binary files /dev/null and b/diffusers/src/diffusers/pipelines/semantic_stable_diffusion/__pycache__/pipeline_semantic_stable_diffusion.cpython-310.pyc differ
diff --git a/diffusers/src/diffusers/pipelines/stable_diffusion/__pycache__/__init__.cpython-310.pyc b/diffusers/src/diffusers/pipelines/stable_diffusion/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..46c8e846508386e8191089de12567c853814b894
Binary files /dev/null and b/diffusers/src/diffusers/pipelines/stable_diffusion/__pycache__/__init__.cpython-310.pyc differ
diff --git a/diffusers/src/diffusers/pipelines/stable_diffusion/__pycache__/pipeline_cycle_diffusion.cpython-310.pyc b/diffusers/src/diffusers/pipelines/stable_diffusion/__pycache__/pipeline_cycle_diffusion.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..f62b8e385ffe8fd510bceac4c9ed5e9333d9b08c
Binary files /dev/null and b/diffusers/src/diffusers/pipelines/stable_diffusion/__pycache__/pipeline_cycle_diffusion.cpython-310.pyc differ
diff --git a/diffusers/src/diffusers/pipelines/stable_diffusion/__pycache__/pipeline_stable_diffusion.cpython-310.pyc b/diffusers/src/diffusers/pipelines/stable_diffusion/__pycache__/pipeline_stable_diffusion.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..75ac93ff2fc9395cd0d3f1a4691e3bca31c170b9
Binary files /dev/null and b/diffusers/src/diffusers/pipelines/stable_diffusion/__pycache__/pipeline_stable_diffusion.cpython-310.pyc differ
diff --git a/diffusers/src/diffusers/pipelines/stable_diffusion/__pycache__/pipeline_stable_diffusion_attend_and_excite.cpython-310.pyc b/diffusers/src/diffusers/pipelines/stable_diffusion/__pycache__/pipeline_stable_diffusion_attend_and_excite.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..abb133e2acb0e4cb0f382972669c8836ac71641f
Binary files /dev/null and b/diffusers/src/diffusers/pipelines/stable_diffusion/__pycache__/pipeline_stable_diffusion_attend_and_excite.cpython-310.pyc differ
diff --git a/diffusers/src/diffusers/pipelines/stable_diffusion/__pycache__/pipeline_stable_diffusion_controlnet.cpython-310.pyc b/diffusers/src/diffusers/pipelines/stable_diffusion/__pycache__/pipeline_stable_diffusion_controlnet.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..2ab71d80f13005125dd7d1614375983fe1ba903b
Binary files /dev/null and b/diffusers/src/diffusers/pipelines/stable_diffusion/__pycache__/pipeline_stable_diffusion_controlnet.cpython-310.pyc differ
diff --git a/diffusers/src/diffusers/pipelines/stable_diffusion/__pycache__/pipeline_stable_diffusion_depth2img.cpython-310.pyc b/diffusers/src/diffusers/pipelines/stable_diffusion/__pycache__/pipeline_stable_diffusion_depth2img.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..e328ac731f0efac5b854ed24df962320d2d16332
Binary files /dev/null and b/diffusers/src/diffusers/pipelines/stable_diffusion/__pycache__/pipeline_stable_diffusion_depth2img.cpython-310.pyc differ
diff --git a/diffusers/src/diffusers/pipelines/stable_diffusion/__pycache__/pipeline_stable_diffusion_image_variation.cpython-310.pyc b/diffusers/src/diffusers/pipelines/stable_diffusion/__pycache__/pipeline_stable_diffusion_image_variation.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..f8a0c9ff10929e5745ab5c6549f6cc355eb0aa16
Binary files /dev/null and b/diffusers/src/diffusers/pipelines/stable_diffusion/__pycache__/pipeline_stable_diffusion_image_variation.cpython-310.pyc differ
diff --git a/diffusers/src/diffusers/pipelines/stable_diffusion/__pycache__/pipeline_stable_diffusion_img2img.cpython-310.pyc b/diffusers/src/diffusers/pipelines/stable_diffusion/__pycache__/pipeline_stable_diffusion_img2img.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..1efcf2f871737026e8e815494837f4e3d4ee823a
Binary files /dev/null and b/diffusers/src/diffusers/pipelines/stable_diffusion/__pycache__/pipeline_stable_diffusion_img2img.cpython-310.pyc differ
diff --git a/diffusers/src/diffusers/pipelines/stable_diffusion/__pycache__/pipeline_stable_diffusion_inpaint.cpython-310.pyc b/diffusers/src/diffusers/pipelines/stable_diffusion/__pycache__/pipeline_stable_diffusion_inpaint.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..e3f83ab0048435ef7eca177f5b61d9bc762b1361
Binary files /dev/null and b/diffusers/src/diffusers/pipelines/stable_diffusion/__pycache__/pipeline_stable_diffusion_inpaint.cpython-310.pyc differ
diff --git a/diffusers/src/diffusers/pipelines/stable_diffusion/__pycache__/pipeline_stable_diffusion_inpaint_legacy.cpython-310.pyc b/diffusers/src/diffusers/pipelines/stable_diffusion/__pycache__/pipeline_stable_diffusion_inpaint_legacy.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..d302d06adbd1ce06034b0504609c7319c70e8e1b
Binary files /dev/null and b/diffusers/src/diffusers/pipelines/stable_diffusion/__pycache__/pipeline_stable_diffusion_inpaint_legacy.cpython-310.pyc differ
diff --git a/diffusers/src/diffusers/pipelines/stable_diffusion/__pycache__/pipeline_stable_diffusion_instruct_pix2pix.cpython-310.pyc b/diffusers/src/diffusers/pipelines/stable_diffusion/__pycache__/pipeline_stable_diffusion_instruct_pix2pix.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..867bf05c3ea3cf3891ff8c58b8b4e87cd97fe86e
Binary files /dev/null and b/diffusers/src/diffusers/pipelines/stable_diffusion/__pycache__/pipeline_stable_diffusion_instruct_pix2pix.cpython-310.pyc differ
diff --git a/diffusers/src/diffusers/pipelines/stable_diffusion/__pycache__/pipeline_stable_diffusion_latent_upscale.cpython-310.pyc b/diffusers/src/diffusers/pipelines/stable_diffusion/__pycache__/pipeline_stable_diffusion_latent_upscale.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..af509c47acfa889bb06f920da0f885723791c9f8
Binary files /dev/null and b/diffusers/src/diffusers/pipelines/stable_diffusion/__pycache__/pipeline_stable_diffusion_latent_upscale.cpython-310.pyc differ
diff --git a/diffusers/src/diffusers/pipelines/stable_diffusion/__pycache__/pipeline_stable_diffusion_model_editing.cpython-310.pyc b/diffusers/src/diffusers/pipelines/stable_diffusion/__pycache__/pipeline_stable_diffusion_model_editing.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..7bdbe1009a2c23da9bbd1b35ab3ad2ff7317e051
Binary files /dev/null and b/diffusers/src/diffusers/pipelines/stable_diffusion/__pycache__/pipeline_stable_diffusion_model_editing.cpython-310.pyc differ
diff --git a/diffusers/src/diffusers/pipelines/stable_diffusion/__pycache__/pipeline_stable_diffusion_panorama.cpython-310.pyc b/diffusers/src/diffusers/pipelines/stable_diffusion/__pycache__/pipeline_stable_diffusion_panorama.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..04fef03a8e7fad760421093f1434f442ca8267b3
Binary files /dev/null and b/diffusers/src/diffusers/pipelines/stable_diffusion/__pycache__/pipeline_stable_diffusion_panorama.cpython-310.pyc differ
diff --git a/diffusers/src/diffusers/pipelines/stable_diffusion/__pycache__/pipeline_stable_diffusion_pix2pix_zero.cpython-310.pyc b/diffusers/src/diffusers/pipelines/stable_diffusion/__pycache__/pipeline_stable_diffusion_pix2pix_zero.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..c138406df50b02fce2d1fb4ac21335819948775e
Binary files /dev/null and b/diffusers/src/diffusers/pipelines/stable_diffusion/__pycache__/pipeline_stable_diffusion_pix2pix_zero.cpython-310.pyc differ
diff --git a/diffusers/src/diffusers/pipelines/stable_diffusion/__pycache__/pipeline_stable_diffusion_sag.cpython-310.pyc b/diffusers/src/diffusers/pipelines/stable_diffusion/__pycache__/pipeline_stable_diffusion_sag.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..3cb99eccfbd4d06c1815c70f2ef03685a6db07a4
Binary files /dev/null and b/diffusers/src/diffusers/pipelines/stable_diffusion/__pycache__/pipeline_stable_diffusion_sag.cpython-310.pyc differ
diff --git a/diffusers/src/diffusers/pipelines/stable_diffusion/__pycache__/pipeline_stable_diffusion_upscale.cpython-310.pyc b/diffusers/src/diffusers/pipelines/stable_diffusion/__pycache__/pipeline_stable_diffusion_upscale.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..96633802423a4a5d01e0cc53fcb3753720ee17f4
Binary files /dev/null and b/diffusers/src/diffusers/pipelines/stable_diffusion/__pycache__/pipeline_stable_diffusion_upscale.cpython-310.pyc differ
diff --git a/diffusers/src/diffusers/pipelines/stable_diffusion/__pycache__/pipeline_stable_unclip.cpython-310.pyc b/diffusers/src/diffusers/pipelines/stable_diffusion/__pycache__/pipeline_stable_unclip.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..99e75630b02b36654e9a3b94149657c8f5f99aba
Binary files /dev/null and b/diffusers/src/diffusers/pipelines/stable_diffusion/__pycache__/pipeline_stable_unclip.cpython-310.pyc differ
diff --git a/diffusers/src/diffusers/pipelines/stable_diffusion/__pycache__/pipeline_stable_unclip_img2img.cpython-310.pyc b/diffusers/src/diffusers/pipelines/stable_diffusion/__pycache__/pipeline_stable_unclip_img2img.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..86c37250ca0d0e2af0472ed3a8d7bed964604730
Binary files /dev/null and b/diffusers/src/diffusers/pipelines/stable_diffusion/__pycache__/pipeline_stable_unclip_img2img.cpython-310.pyc differ
diff --git a/diffusers/src/diffusers/pipelines/stable_diffusion/__pycache__/safety_checker.cpython-310.pyc b/diffusers/src/diffusers/pipelines/stable_diffusion/__pycache__/safety_checker.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..e944438b068b1c28f63813842321e9f1acd8449f
Binary files /dev/null and b/diffusers/src/diffusers/pipelines/stable_diffusion/__pycache__/safety_checker.cpython-310.pyc differ
diff --git a/diffusers/src/diffusers/pipelines/stable_diffusion/__pycache__/stable_unclip_image_normalizer.cpython-310.pyc b/diffusers/src/diffusers/pipelines/stable_diffusion/__pycache__/stable_unclip_image_normalizer.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..dd8b75cb3eefa362a5b10e9c921314a7b6b118a9
Binary files /dev/null and b/diffusers/src/diffusers/pipelines/stable_diffusion/__pycache__/stable_unclip_image_normalizer.cpython-310.pyc differ
diff --git a/diffusers/src/diffusers/pipelines/stable_diffusion_safe/__pycache__/__init__.cpython-310.pyc b/diffusers/src/diffusers/pipelines/stable_diffusion_safe/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..cae5846920c9580faee3d24ac434831d62c6f411
Binary files /dev/null and b/diffusers/src/diffusers/pipelines/stable_diffusion_safe/__pycache__/__init__.cpython-310.pyc differ
diff --git a/diffusers/src/diffusers/pipelines/stable_diffusion_safe/__pycache__/pipeline_stable_diffusion_safe.cpython-310.pyc b/diffusers/src/diffusers/pipelines/stable_diffusion_safe/__pycache__/pipeline_stable_diffusion_safe.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..9c5b6f906a08001be218f1976eaa0a38485288f8
Binary files /dev/null and b/diffusers/src/diffusers/pipelines/stable_diffusion_safe/__pycache__/pipeline_stable_diffusion_safe.cpython-310.pyc differ
diff --git a/diffusers/src/diffusers/pipelines/stable_diffusion_safe/__pycache__/safety_checker.cpython-310.pyc b/diffusers/src/diffusers/pipelines/stable_diffusion_safe/__pycache__/safety_checker.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..19877eaa53a0717b6dbed84ec6d722046d224b97
Binary files /dev/null and b/diffusers/src/diffusers/pipelines/stable_diffusion_safe/__pycache__/safety_checker.cpython-310.pyc differ
diff --git a/diffusers/src/diffusers/pipelines/stochastic_karras_ve/__pycache__/__init__.cpython-310.pyc b/diffusers/src/diffusers/pipelines/stochastic_karras_ve/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..65ef172523e2c7069693d7fe06bd2642e585433d
Binary files /dev/null and b/diffusers/src/diffusers/pipelines/stochastic_karras_ve/__pycache__/__init__.cpython-310.pyc differ
diff --git a/diffusers/src/diffusers/pipelines/stochastic_karras_ve/__pycache__/pipeline_stochastic_karras_ve.cpython-310.pyc b/diffusers/src/diffusers/pipelines/stochastic_karras_ve/__pycache__/pipeline_stochastic_karras_ve.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..2b3e4b9070fd40ed1a0238131f2a9dcd102cc95f
Binary files /dev/null and b/diffusers/src/diffusers/pipelines/stochastic_karras_ve/__pycache__/pipeline_stochastic_karras_ve.cpython-310.pyc differ
diff --git a/diffusers/src/diffusers/pipelines/text_to_video_synthesis/__pycache__/__init__.cpython-310.pyc b/diffusers/src/diffusers/pipelines/text_to_video_synthesis/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..5128c5dce8f64e0865115115ec4c9a2d30ba0f35
Binary files /dev/null and b/diffusers/src/diffusers/pipelines/text_to_video_synthesis/__pycache__/__init__.cpython-310.pyc differ
diff --git a/diffusers/src/diffusers/pipelines/text_to_video_synthesis/__pycache__/pipeline_text_to_video_synth.cpython-310.pyc b/diffusers/src/diffusers/pipelines/text_to_video_synthesis/__pycache__/pipeline_text_to_video_synth.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..7c1128130ab3b4c0e558251e6c24a4f69745b2a1
Binary files /dev/null and b/diffusers/src/diffusers/pipelines/text_to_video_synthesis/__pycache__/pipeline_text_to_video_synth.cpython-310.pyc differ
diff --git a/diffusers/src/diffusers/pipelines/unclip/__pycache__/__init__.cpython-310.pyc b/diffusers/src/diffusers/pipelines/unclip/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..b70db6de8aa736ed9a371f97b27ab5c7bf494160
Binary files /dev/null and b/diffusers/src/diffusers/pipelines/unclip/__pycache__/__init__.cpython-310.pyc differ
diff --git a/diffusers/src/diffusers/pipelines/unclip/__pycache__/pipeline_unclip.cpython-310.pyc b/diffusers/src/diffusers/pipelines/unclip/__pycache__/pipeline_unclip.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..ebf5a0a78f2e1c33b370580b7c40a615f5cdb1c7
Binary files /dev/null and b/diffusers/src/diffusers/pipelines/unclip/__pycache__/pipeline_unclip.cpython-310.pyc differ
diff --git a/diffusers/src/diffusers/pipelines/unclip/__pycache__/pipeline_unclip_image_variation.cpython-310.pyc b/diffusers/src/diffusers/pipelines/unclip/__pycache__/pipeline_unclip_image_variation.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..a0d9925bede08fcefc506eefcb5be8e23ec6e6a3
Binary files /dev/null and b/diffusers/src/diffusers/pipelines/unclip/__pycache__/pipeline_unclip_image_variation.cpython-310.pyc differ
diff --git a/diffusers/src/diffusers/pipelines/unclip/__pycache__/text_proj.cpython-310.pyc b/diffusers/src/diffusers/pipelines/unclip/__pycache__/text_proj.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..edd5f9b9befc8a1a7c5e1ac1572f61c048e982ed
Binary files /dev/null and b/diffusers/src/diffusers/pipelines/unclip/__pycache__/text_proj.cpython-310.pyc differ
diff --git a/diffusers/src/diffusers/pipelines/versatile_diffusion/__pycache__/__init__.cpython-310.pyc b/diffusers/src/diffusers/pipelines/versatile_diffusion/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..d16ee1302606c5d3f73bb0888aab317ec0545b9f
Binary files /dev/null and b/diffusers/src/diffusers/pipelines/versatile_diffusion/__pycache__/__init__.cpython-310.pyc differ
diff --git a/diffusers/src/diffusers/pipelines/versatile_diffusion/__pycache__/modeling_text_unet.cpython-310.pyc b/diffusers/src/diffusers/pipelines/versatile_diffusion/__pycache__/modeling_text_unet.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..f4b343ffea8436d0986075fc8abba0f3d1a44f7e
Binary files /dev/null and b/diffusers/src/diffusers/pipelines/versatile_diffusion/__pycache__/modeling_text_unet.cpython-310.pyc differ
diff --git a/diffusers/src/diffusers/pipelines/versatile_diffusion/__pycache__/pipeline_versatile_diffusion.cpython-310.pyc b/diffusers/src/diffusers/pipelines/versatile_diffusion/__pycache__/pipeline_versatile_diffusion.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..5a951c559c0165fbaa8242a9d6a0ee2c7e1c2c01
Binary files /dev/null and b/diffusers/src/diffusers/pipelines/versatile_diffusion/__pycache__/pipeline_versatile_diffusion.cpython-310.pyc differ
diff --git a/diffusers/src/diffusers/pipelines/versatile_diffusion/__pycache__/pipeline_versatile_diffusion_dual_guided.cpython-310.pyc b/diffusers/src/diffusers/pipelines/versatile_diffusion/__pycache__/pipeline_versatile_diffusion_dual_guided.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..178beda82ea3a05d4954a381acd7cb32853db1e2
Binary files /dev/null and b/diffusers/src/diffusers/pipelines/versatile_diffusion/__pycache__/pipeline_versatile_diffusion_dual_guided.cpython-310.pyc differ
diff --git a/diffusers/src/diffusers/pipelines/versatile_diffusion/__pycache__/pipeline_versatile_diffusion_image_variation.cpython-310.pyc b/diffusers/src/diffusers/pipelines/versatile_diffusion/__pycache__/pipeline_versatile_diffusion_image_variation.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..df4a0913c8215e01b19443df8f798a47924f9399
Binary files /dev/null and b/diffusers/src/diffusers/pipelines/versatile_diffusion/__pycache__/pipeline_versatile_diffusion_image_variation.cpython-310.pyc differ
diff --git a/diffusers/src/diffusers/pipelines/versatile_diffusion/__pycache__/pipeline_versatile_diffusion_text_to_image.cpython-310.pyc b/diffusers/src/diffusers/pipelines/versatile_diffusion/__pycache__/pipeline_versatile_diffusion_text_to_image.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..cf0a8c38612095f0cb2ec89a4c6c60a7fcf8b6ae
Binary files /dev/null and b/diffusers/src/diffusers/pipelines/versatile_diffusion/__pycache__/pipeline_versatile_diffusion_text_to_image.cpython-310.pyc differ
diff --git a/diffusers/src/diffusers/pipelines/vq_diffusion/__pycache__/__init__.cpython-310.pyc b/diffusers/src/diffusers/pipelines/vq_diffusion/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..222a0e61e572950c863811e79268251d802d2d6d
Binary files /dev/null and b/diffusers/src/diffusers/pipelines/vq_diffusion/__pycache__/__init__.cpython-310.pyc differ
diff --git a/diffusers/src/diffusers/pipelines/vq_diffusion/__pycache__/pipeline_vq_diffusion.cpython-310.pyc b/diffusers/src/diffusers/pipelines/vq_diffusion/__pycache__/pipeline_vq_diffusion.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..9d2202ed3831fa0c64e27a8b16c5b636aad1c2a3
Binary files /dev/null and b/diffusers/src/diffusers/pipelines/vq_diffusion/__pycache__/pipeline_vq_diffusion.cpython-310.pyc differ
diff --git a/diffusers/src/diffusers/schedulers/__pycache__/__init__.cpython-310.pyc b/diffusers/src/diffusers/schedulers/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..a952380efd2dbda15d519bac511af7c5fb1864fa
Binary files /dev/null and b/diffusers/src/diffusers/schedulers/__pycache__/__init__.cpython-310.pyc differ
diff --git a/diffusers/src/diffusers/schedulers/__pycache__/scheduling_ddim.cpython-310.pyc b/diffusers/src/diffusers/schedulers/__pycache__/scheduling_ddim.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..169b68e2c09e14ec4541954fa3411bcd9fd6d3c1
Binary files /dev/null and b/diffusers/src/diffusers/schedulers/__pycache__/scheduling_ddim.cpython-310.pyc differ
diff --git a/diffusers/src/diffusers/schedulers/__pycache__/scheduling_ddim_inverse.cpython-310.pyc b/diffusers/src/diffusers/schedulers/__pycache__/scheduling_ddim_inverse.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..b665dd89971e5044353cd3146cdee5c6ef7f3acc
Binary files /dev/null and b/diffusers/src/diffusers/schedulers/__pycache__/scheduling_ddim_inverse.cpython-310.pyc differ
diff --git a/diffusers/src/diffusers/schedulers/__pycache__/scheduling_ddpm.cpython-310.pyc b/diffusers/src/diffusers/schedulers/__pycache__/scheduling_ddpm.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..bb8bf32a200b224a51eada540cee0eace753998f
Binary files /dev/null and b/diffusers/src/diffusers/schedulers/__pycache__/scheduling_ddpm.cpython-310.pyc differ
diff --git a/diffusers/src/diffusers/schedulers/__pycache__/scheduling_deis_multistep.cpython-310.pyc b/diffusers/src/diffusers/schedulers/__pycache__/scheduling_deis_multistep.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..a38cdd582a37319972b4f315dfc7e60ed0853363
Binary files /dev/null and b/diffusers/src/diffusers/schedulers/__pycache__/scheduling_deis_multistep.cpython-310.pyc differ
diff --git a/diffusers/src/diffusers/schedulers/__pycache__/scheduling_dpmsolver_multistep.cpython-310.pyc b/diffusers/src/diffusers/schedulers/__pycache__/scheduling_dpmsolver_multistep.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..e3771e8a84400d1cae9d169cd4816d9eaaa01f0e
Binary files /dev/null and b/diffusers/src/diffusers/schedulers/__pycache__/scheduling_dpmsolver_multistep.cpython-310.pyc differ
diff --git a/diffusers/src/diffusers/schedulers/__pycache__/scheduling_dpmsolver_singlestep.cpython-310.pyc b/diffusers/src/diffusers/schedulers/__pycache__/scheduling_dpmsolver_singlestep.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..4f8cd9f2209dae750bcb8d22d4662c31ad49d5ec
Binary files /dev/null and b/diffusers/src/diffusers/schedulers/__pycache__/scheduling_dpmsolver_singlestep.cpython-310.pyc differ
diff --git a/diffusers/src/diffusers/schedulers/__pycache__/scheduling_euler_ancestral_discrete.cpython-310.pyc b/diffusers/src/diffusers/schedulers/__pycache__/scheduling_euler_ancestral_discrete.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..3ae3e4dedd084bf92c1ee2227350ae000cec8947
Binary files /dev/null and b/diffusers/src/diffusers/schedulers/__pycache__/scheduling_euler_ancestral_discrete.cpython-310.pyc differ
diff --git a/diffusers/src/diffusers/schedulers/__pycache__/scheduling_euler_discrete.cpython-310.pyc b/diffusers/src/diffusers/schedulers/__pycache__/scheduling_euler_discrete.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..c4b19bed0746cb62d0acfb70d168782c57cee887
Binary files /dev/null and b/diffusers/src/diffusers/schedulers/__pycache__/scheduling_euler_discrete.cpython-310.pyc differ
diff --git a/diffusers/src/diffusers/schedulers/__pycache__/scheduling_heun_discrete.cpython-310.pyc b/diffusers/src/diffusers/schedulers/__pycache__/scheduling_heun_discrete.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..f27835166bcee0da9380a6c50988fc3724eebafa
Binary files /dev/null and b/diffusers/src/diffusers/schedulers/__pycache__/scheduling_heun_discrete.cpython-310.pyc differ
diff --git a/diffusers/src/diffusers/schedulers/__pycache__/scheduling_ipndm.cpython-310.pyc b/diffusers/src/diffusers/schedulers/__pycache__/scheduling_ipndm.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..56b02f9a9b89aba638328ac4aff9297028cf22e3
Binary files /dev/null and b/diffusers/src/diffusers/schedulers/__pycache__/scheduling_ipndm.cpython-310.pyc differ
diff --git a/diffusers/src/diffusers/schedulers/__pycache__/scheduling_k_dpm_2_ancestral_discrete.cpython-310.pyc b/diffusers/src/diffusers/schedulers/__pycache__/scheduling_k_dpm_2_ancestral_discrete.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..06db03fe105a9958f8642c80dc2c6fa5b7f97d19
Binary files /dev/null and b/diffusers/src/diffusers/schedulers/__pycache__/scheduling_k_dpm_2_ancestral_discrete.cpython-310.pyc differ
diff --git a/diffusers/src/diffusers/schedulers/__pycache__/scheduling_k_dpm_2_discrete.cpython-310.pyc b/diffusers/src/diffusers/schedulers/__pycache__/scheduling_k_dpm_2_discrete.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..5f3c5e8dfb2587b71f88c373b46d6c8d1202dfc0
Binary files /dev/null and b/diffusers/src/diffusers/schedulers/__pycache__/scheduling_k_dpm_2_discrete.cpython-310.pyc differ
diff --git a/diffusers/src/diffusers/schedulers/__pycache__/scheduling_karras_ve.cpython-310.pyc b/diffusers/src/diffusers/schedulers/__pycache__/scheduling_karras_ve.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..ae84a9cfa1dd038ceaeb1e17a35462a8449925a5
Binary files /dev/null and b/diffusers/src/diffusers/schedulers/__pycache__/scheduling_karras_ve.cpython-310.pyc differ
diff --git a/diffusers/src/diffusers/schedulers/__pycache__/scheduling_lms_discrete.cpython-310.pyc b/diffusers/src/diffusers/schedulers/__pycache__/scheduling_lms_discrete.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..ca7a63702018838fa7d486290e72c63e7dcbc795
Binary files /dev/null and b/diffusers/src/diffusers/schedulers/__pycache__/scheduling_lms_discrete.cpython-310.pyc differ
diff --git a/diffusers/src/diffusers/schedulers/__pycache__/scheduling_pndm.cpython-310.pyc b/diffusers/src/diffusers/schedulers/__pycache__/scheduling_pndm.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..8da8f92e8e092a3d9d64bdb259a5f251b3e4cf99
Binary files /dev/null and b/diffusers/src/diffusers/schedulers/__pycache__/scheduling_pndm.cpython-310.pyc differ
diff --git a/diffusers/src/diffusers/schedulers/__pycache__/scheduling_repaint.cpython-310.pyc b/diffusers/src/diffusers/schedulers/__pycache__/scheduling_repaint.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..97f19f76a7f271bd1382cacfe1ba5e9dd8b0a6df
Binary files /dev/null and b/diffusers/src/diffusers/schedulers/__pycache__/scheduling_repaint.cpython-310.pyc differ
diff --git a/diffusers/src/diffusers/schedulers/__pycache__/scheduling_sde_ve.cpython-310.pyc b/diffusers/src/diffusers/schedulers/__pycache__/scheduling_sde_ve.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..407f80f2963ae249d06d33b32efc8ef433062a4b
Binary files /dev/null and b/diffusers/src/diffusers/schedulers/__pycache__/scheduling_sde_ve.cpython-310.pyc differ
diff --git a/diffusers/src/diffusers/schedulers/__pycache__/scheduling_sde_vp.cpython-310.pyc b/diffusers/src/diffusers/schedulers/__pycache__/scheduling_sde_vp.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..7541240b5c857438b8870cda7793be8a9021fe6a
Binary files /dev/null and b/diffusers/src/diffusers/schedulers/__pycache__/scheduling_sde_vp.cpython-310.pyc differ
diff --git a/diffusers/src/diffusers/schedulers/__pycache__/scheduling_unclip.cpython-310.pyc b/diffusers/src/diffusers/schedulers/__pycache__/scheduling_unclip.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..1ce2e346a2141112489a395045d6998d7e1aa84a
Binary files /dev/null and b/diffusers/src/diffusers/schedulers/__pycache__/scheduling_unclip.cpython-310.pyc differ
diff --git a/diffusers/src/diffusers/schedulers/__pycache__/scheduling_unipc_multistep.cpython-310.pyc b/diffusers/src/diffusers/schedulers/__pycache__/scheduling_unipc_multistep.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..ebecfc10d1175b51b6a4a62027993f73dd6c7718
Binary files /dev/null and b/diffusers/src/diffusers/schedulers/__pycache__/scheduling_unipc_multistep.cpython-310.pyc differ
diff --git a/diffusers/src/diffusers/schedulers/__pycache__/scheduling_utils.cpython-310.pyc b/diffusers/src/diffusers/schedulers/__pycache__/scheduling_utils.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..4b3ce40cdecdc035f87db9801db3a802158e16a9
Binary files /dev/null and b/diffusers/src/diffusers/schedulers/__pycache__/scheduling_utils.cpython-310.pyc differ
diff --git a/diffusers/src/diffusers/schedulers/__pycache__/scheduling_vq_diffusion.cpython-310.pyc b/diffusers/src/diffusers/schedulers/__pycache__/scheduling_vq_diffusion.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..d762424b44332bf11f914ff3af62a5df193113fa
Binary files /dev/null and b/diffusers/src/diffusers/schedulers/__pycache__/scheduling_vq_diffusion.cpython-310.pyc differ
diff --git a/diffusers/src/diffusers/utils/__pycache__/__init__.cpython-310.pyc b/diffusers/src/diffusers/utils/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..7979f9707fdba5ae1a6b3a931788d32c089b9f1c
Binary files /dev/null and b/diffusers/src/diffusers/utils/__pycache__/__init__.cpython-310.pyc differ
diff --git a/diffusers/src/diffusers/utils/__pycache__/accelerate_utils.cpython-310.pyc b/diffusers/src/diffusers/utils/__pycache__/accelerate_utils.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..1b67ae587d6e12ac192a13e48ceb716c48dbe0d5
Binary files /dev/null and b/diffusers/src/diffusers/utils/__pycache__/accelerate_utils.cpython-310.pyc differ
diff --git a/diffusers/src/diffusers/utils/__pycache__/constants.cpython-310.pyc b/diffusers/src/diffusers/utils/__pycache__/constants.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..a66ba5035a5f11f991ebef5b2577ac3c218d234e
Binary files /dev/null and b/diffusers/src/diffusers/utils/__pycache__/constants.cpython-310.pyc differ
diff --git a/diffusers/src/diffusers/utils/__pycache__/deprecation_utils.cpython-310.pyc b/diffusers/src/diffusers/utils/__pycache__/deprecation_utils.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..72e0ee32a37c20a1e4404b004f2236a3a5c2f5b8
Binary files /dev/null and b/diffusers/src/diffusers/utils/__pycache__/deprecation_utils.cpython-310.pyc differ
diff --git a/diffusers/src/diffusers/utils/__pycache__/doc_utils.cpython-310.pyc b/diffusers/src/diffusers/utils/__pycache__/doc_utils.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..aea9380c2b602d7b1aeb7f1ab52e8da93d1b36d5
Binary files /dev/null and b/diffusers/src/diffusers/utils/__pycache__/doc_utils.cpython-310.pyc differ
diff --git a/diffusers/src/diffusers/utils/__pycache__/dummy_flax_and_transformers_objects.cpython-310.pyc b/diffusers/src/diffusers/utils/__pycache__/dummy_flax_and_transformers_objects.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..b4fe074e9f0fc416c9b3e17a1edbaff3d658c20b
Binary files /dev/null and b/diffusers/src/diffusers/utils/__pycache__/dummy_flax_and_transformers_objects.cpython-310.pyc differ
diff --git a/diffusers/src/diffusers/utils/__pycache__/dummy_flax_objects.cpython-310.pyc b/diffusers/src/diffusers/utils/__pycache__/dummy_flax_objects.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..48f17bb0470f3799b2dbc0611e3c8806ce773151
Binary files /dev/null and b/diffusers/src/diffusers/utils/__pycache__/dummy_flax_objects.cpython-310.pyc differ
diff --git a/diffusers/src/diffusers/utils/__pycache__/dummy_note_seq_objects.cpython-310.pyc b/diffusers/src/diffusers/utils/__pycache__/dummy_note_seq_objects.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..932d5fa056462a8e633092dee332ebd258d5d796
Binary files /dev/null and b/diffusers/src/diffusers/utils/__pycache__/dummy_note_seq_objects.cpython-310.pyc differ
diff --git a/diffusers/src/diffusers/utils/__pycache__/dummy_onnx_objects.cpython-310.pyc b/diffusers/src/diffusers/utils/__pycache__/dummy_onnx_objects.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..0e8178decb0a36624f2449083c645401c48b694d
Binary files /dev/null and b/diffusers/src/diffusers/utils/__pycache__/dummy_onnx_objects.cpython-310.pyc differ
diff --git a/diffusers/src/diffusers/utils/__pycache__/dummy_torch_and_transformers_and_k_diffusion_objects.cpython-310.pyc b/diffusers/src/diffusers/utils/__pycache__/dummy_torch_and_transformers_and_k_diffusion_objects.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..971f4ad2fc03dd4326f70afd657fa0c673027d8d
Binary files /dev/null and b/diffusers/src/diffusers/utils/__pycache__/dummy_torch_and_transformers_and_k_diffusion_objects.cpython-310.pyc differ
diff --git a/diffusers/src/diffusers/utils/__pycache__/dummy_torch_and_transformers_and_onnx_objects.cpython-310.pyc b/diffusers/src/diffusers/utils/__pycache__/dummy_torch_and_transformers_and_onnx_objects.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..16e0ecc6a1762bd75bb141cf538e3a5943fbe8e7
Binary files /dev/null and b/diffusers/src/diffusers/utils/__pycache__/dummy_torch_and_transformers_and_onnx_objects.cpython-310.pyc differ
diff --git a/diffusers/src/diffusers/utils/__pycache__/dummy_transformers_and_torch_and_note_seq_objects.cpython-310.pyc b/diffusers/src/diffusers/utils/__pycache__/dummy_transformers_and_torch_and_note_seq_objects.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..f5f517b1e2bb9e4babbd97318b67571bc1f91b65
Binary files /dev/null and b/diffusers/src/diffusers/utils/__pycache__/dummy_transformers_and_torch_and_note_seq_objects.cpython-310.pyc differ
diff --git a/diffusers/src/diffusers/utils/__pycache__/dynamic_modules_utils.cpython-310.pyc b/diffusers/src/diffusers/utils/__pycache__/dynamic_modules_utils.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..b464f291b4683ea1b05b182be00f136862e6d28a
Binary files /dev/null and b/diffusers/src/diffusers/utils/__pycache__/dynamic_modules_utils.cpython-310.pyc differ
diff --git a/diffusers/src/diffusers/utils/__pycache__/hub_utils.cpython-310.pyc b/diffusers/src/diffusers/utils/__pycache__/hub_utils.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..22aefa847498430b96d604a6e6b7ca7e05295ba8
Binary files /dev/null and b/diffusers/src/diffusers/utils/__pycache__/hub_utils.cpython-310.pyc differ
diff --git a/diffusers/src/diffusers/utils/__pycache__/import_utils.cpython-310.pyc b/diffusers/src/diffusers/utils/__pycache__/import_utils.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..8c272d65c3f051ff4fd87ed83105a1a59268fab8
Binary files /dev/null and b/diffusers/src/diffusers/utils/__pycache__/import_utils.cpython-310.pyc differ
diff --git a/diffusers/src/diffusers/utils/__pycache__/logging.cpython-310.pyc b/diffusers/src/diffusers/utils/__pycache__/logging.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..6a72051086a59654d0ded4e593711e2beef4fbcd
Binary files /dev/null and b/diffusers/src/diffusers/utils/__pycache__/logging.cpython-310.pyc differ
diff --git a/diffusers/src/diffusers/utils/__pycache__/outputs.cpython-310.pyc b/diffusers/src/diffusers/utils/__pycache__/outputs.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..fff084a077e8c2b12894b23b408f17f029d381c6
Binary files /dev/null and b/diffusers/src/diffusers/utils/__pycache__/outputs.cpython-310.pyc differ
diff --git a/diffusers/src/diffusers/utils/__pycache__/pil_utils.cpython-310.pyc b/diffusers/src/diffusers/utils/__pycache__/pil_utils.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..636baee16920a55b991b8834c682e367d1a82634
Binary files /dev/null and b/diffusers/src/diffusers/utils/__pycache__/pil_utils.cpython-310.pyc differ
diff --git a/diffusers/src/diffusers/utils/__pycache__/testing_utils.cpython-310.pyc b/diffusers/src/diffusers/utils/__pycache__/testing_utils.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..7da29b43f3949ab96e45c82c2196ad173196b49a
Binary files /dev/null and b/diffusers/src/diffusers/utils/__pycache__/testing_utils.cpython-310.pyc differ
diff --git a/diffusers/src/diffusers/utils/__pycache__/torch_utils.cpython-310.pyc b/diffusers/src/diffusers/utils/__pycache__/torch_utils.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..cafd8a93004bf9e8b7f00ec10c272c392340bbda
Binary files /dev/null and b/diffusers/src/diffusers/utils/__pycache__/torch_utils.cpython-310.pyc differ
diff --git a/layers/.ipynb_checkpoints/layers-checkpoint.py b/layers/.ipynb_checkpoints/layers-checkpoint.py
index 15d715533ef15feed4a98bcd114e165e874fefb3..dcbc0722776f04af5556401dd6e776fb87179fa7 100644
--- a/layers/.ipynb_checkpoints/layers-checkpoint.py
+++ b/layers/.ipynb_checkpoints/layers-checkpoint.py
@@ -28,7 +28,7 @@ class Fundamental_Music_Embedding(nn.Module):
i = torch.arange(d_model)
angle_rates = 1 / torch.pow(self.base, (2 * (i//2)) / d_model)
- angle_rates = angle_rates[None, ... ].cuda()
+ angle_rates = angle_rates[None, ... ]#.cuda()
if self.if_trainable:
angles = nn.Parameter(angle_rates, requires_grad=True)
@@ -38,12 +38,12 @@ class Fundamental_Music_Embedding(nn.Module):
self.angles = angle_rates
- def __call__(self, inp):
+ def __call__(self, inp, device):
if inp.dim()==2:
inp = inp[..., None] #pos (batch, num_pitch, 1)
elif inp.dim()==1:
inp = inp[None, ..., None] #pos (1, num_pitch, 1)
- angle_rads = inp*self.angles #(batch, num_pitch)*(1,dim)
+ angle_rads = inp*self.angles.to(device) #(batch, num_pitch)*(1,dim)
# apply sin to even indices in the array; 2i
angle_rads[:, :, 0::2] = torch.sin(angle_rads.clone()[:, : , 0::2])
@@ -71,9 +71,18 @@ class Music_PositionalEncoding(nn.Module):
self.if_global_timing = if_global_timing
self.if_modulo_timing = if_modulo_timing
self.dropout = nn.Dropout(p=dropout)
- self.index_embedding = Fundamental_Music_Embedding(d_model = d_model, base=10000, device = device, if_trainable=False, translation_bias_type = None, if_translation_bias_trainable = False, type = "se").cuda()
- self.global_time_embedding = Fundamental_Music_Embedding(d_model = d_model, base=10001, device = device, if_trainable=False, translation_bias_type = None, if_translation_bias_trainable = False, type = "se").cuda()
- self.modulo_time_embedding = Fundamental_Music_Embedding(d_model = d_model, base=10001, device = device, if_trainable=False, translation_bias_type = None, if_translation_bias_trainable = False, type = "se").cuda()
+ self.index_embedding = Fundamental_Music_Embedding(
+ d_model = d_model, base=10000, if_trainable=False, translation_bias_type = None,
+ if_translation_bias_trainable = False, type = "se"
+ )# .cuda()
+ self.global_time_embedding = Fundamental_Music_Embedding(
+ d_model = d_model, base=10001, if_trainable=False, translation_bias_type = None,
+ if_translation_bias_trainable = False, type = "se"
+ )# .cuda()
+ self.modulo_time_embedding = Fundamental_Music_Embedding(
+ d_model = d_model, base=10001, if_trainable=False, translation_bias_type = None,
+ if_translation_bias_trainable = False, type = "se"
+ )# .cuda()
position = torch.arange(max_len).unsqueeze(1)
div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
@@ -124,7 +133,6 @@ class PositionalEncoding(nn.Module):
def forward(self, x):
pos = self.pe[:x.size(1)] #[seq_len, batch_size, embedding_dim]
pos = torch.swapaxes(pos, 0, 1) #[batch_size, seq_len, embedding_dim]
- print("huh????", pos.shape, x.shape)
x = x + pos
return self.dropout(x)
@@ -254,13 +262,13 @@ class Chord_Embedding(nn.Module):
self.d_model = d_model
self.d_oh_type = d_oh_type
self.d_oh_inv = d_oh_inv
- self.chord_ffn = nn.Linear(d_oh_type + d_oh_inv + d_model + d_model, d_model).cuda()
- def __call__(self, chord_root, chord_type, chord_inv, chord_timing):
+ self.chord_ffn = nn.Linear(d_oh_type + d_oh_inv + d_model + d_model, d_model) #.cuda()
+ def __call__(self, chord_root, chord_type, chord_inv, chord_timing, device):
#chords: (B, LEN, 4)
#Embed root using FME
#Embed chord type, chord inversion using OH
#Embed timestamps using shared PE
- chord_root_emb = self.FME(chord_root)
+ chord_root_emb = self.FME(chord_root, device)
# print(chord_root_emb.size())
# print('this is chord root: ', chord_root)
# print('this is chord type: ', chord_type)
@@ -272,9 +280,9 @@ class Chord_Embedding(nn.Module):
# chord_root_emb = F.one_hot(chord_type.to(torch.int64), num_classes = self.d_model).to(torch.float32)
chord_type_emb = F.one_hot(chord_type.to(torch.int64), num_classes = self.d_oh_type).to(torch.float32)
chord_inv_emb = F.one_hot(chord_inv.to(torch.int64), num_classes = self.d_oh_inv).to(torch.float32)
- chord_time_emb = self.PE.global_time_embedding(chord_timing)
+ chord_time_emb = self.PE.global_time_embedding(chord_timing, device)
- chord_emb = self.chord_ffn(torch.cat((chord_root_emb, chord_type_emb, chord_inv_emb, chord_time_emb), dim = -1))
+ chord_emb = self.chord_ffn(torch.cat((chord_root_emb, chord_type_emb, chord_inv_emb, chord_time_emb), dim = -1).to(device))
# print("TADY toje", chord_emb.device)
return chord_emb
@@ -287,13 +295,13 @@ class Beat_Embedding(nn.Module):
self.d_oh_beat_type = d_oh_beat_type
self.beat_ffn = nn.Linear(d_oh_beat_type+d_model, d_model)
- def __call__(self, beats, beats_timing):
+ def __call__(self, beats, beats_timing, device):
#Embed beat type using OH
#Embed time using PE
- beat_type_emb = F.one_hot(beats.to(torch.int64), num_classes = self.d_oh_beat_type).to(torch.float32)
- beat_time_emb = self.PE.global_time_embedding(beats_timing)
- merged_beat = torch.cat((beat_type_emb, beat_time_emb), dim = -1).cuda()
+ beat_type_emb = F.one_hot(beats.to(torch.int64), num_classes = self.d_oh_beat_type).to(torch.float32).to(device)
+ beat_time_emb = self.PE.global_time_embedding(beats_timing, device)
+ merged_beat = torch.cat((beat_type_emb, beat_time_emb), dim = -1)
beat_emb = self.beat_ffn(merged_beat)
return beat_emb
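
The hunks above replace the hard-coded .cuda() placement in Fundamental_Music_Embedding with a device argument supplied at call time: the angle rates stay on CPU at construction and are moved with .to(device) inside __call__. The same change is applied verbatim to layers/layers.py below. A minimal, self-contained sketch of this pattern follows; the class name SinusoidalEmbeddingSketch and its trimmed constructor are illustrative assumptions, not the repository's actual API.

import torch
import torch.nn as nn

class SinusoidalEmbeddingSketch(nn.Module):
    # Illustrative stand-in for the device-aware sinusoidal embedding:
    # angle rates are built on CPU once and moved to the caller-supplied
    # device inside __call__ instead of being pinned with .cuda().
    def __init__(self, d_model, base=10000.0):
        super().__init__()
        i = torch.arange(d_model)
        angle_rates = 1.0 / torch.pow(torch.tensor(base), (2 * (i // 2)) / d_model)
        self.angles = angle_rates[None, ...]  # shape (1, d_model), kept on CPU

    def __call__(self, inp, device):
        if inp.dim() == 2:
            inp = inp[..., None]        # (batch, num_pos, 1)
        elif inp.dim() == 1:
            inp = inp[None, ..., None]  # (1, num_pos, 1)
        angle_rads = inp.to(device) * self.angles.to(device)  # (batch, num_pos, d_model)
        angle_rads[:, :, 0::2] = torch.sin(angle_rads.clone()[:, :, 0::2])
        angle_rads[:, :, 1::2] = torch.cos(angle_rads.clone()[:, :, 1::2])
        return angle_rads

emb = SinusoidalEmbeddingSketch(d_model=8)
print(emb(torch.arange(4, dtype=torch.float32), device="cpu").shape)  # torch.Size([1, 4, 8])

Because nothing is pinned to a particular accelerator, the same call works with device="cpu" or a CUDA device, provided any parameters of the module are moved alongside it.
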
diff --git a/layers/__pycache__/layers.cpython-310.pyc b/layers/__pycache__/layers.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..a6d5533d89c60e91a53b57bda00c3c18fb6d1069
Binary files /dev/null and b/layers/__pycache__/layers.cpython-310.pyc differ
diff --git a/layers/layers.py b/layers/layers.py
index 15d715533ef15feed4a98bcd114e165e874fefb3..dcbc0722776f04af5556401dd6e776fb87179fa7 100644
--- a/layers/layers.py
+++ b/layers/layers.py
@@ -28,7 +28,7 @@ class Fundamental_Music_Embedding(nn.Module):
i = torch.arange(d_model)
angle_rates = 1 / torch.pow(self.base, (2 * (i//2)) / d_model)
- angle_rates = angle_rates[None, ... ].cuda()
+ angle_rates = angle_rates[None, ... ]#.cuda()
if self.if_trainable:
angles = nn.Parameter(angle_rates, requires_grad=True)
@@ -38,12 +38,12 @@ class Fundamental_Music_Embedding(nn.Module):
self.angles = angle_rates
- def __call__(self, inp):
+ def __call__(self, inp, device):
if inp.dim()==2:
inp = inp[..., None] #pos (batch, num_pitch, 1)
elif inp.dim()==1:
inp = inp[None, ..., None] #pos (1, num_pitch, 1)
- angle_rads = inp*self.angles #(batch, num_pitch)*(1,dim)
+ angle_rads = inp*self.angles.to(device) #(batch, num_pitch)*(1,dim)
# apply sin to even indices in the array; 2i
angle_rads[:, :, 0::2] = torch.sin(angle_rads.clone()[:, : , 0::2])
@@ -71,9 +71,18 @@ class Music_PositionalEncoding(nn.Module):
self.if_global_timing = if_global_timing
self.if_modulo_timing = if_modulo_timing
self.dropout = nn.Dropout(p=dropout)
- self.index_embedding = Fundamental_Music_Embedding(d_model = d_model, base=10000, device = device, if_trainable=False, translation_bias_type = None, if_translation_bias_trainable = False, type = "se").cuda()
- self.global_time_embedding = Fundamental_Music_Embedding(d_model = d_model, base=10001, device = device, if_trainable=False, translation_bias_type = None, if_translation_bias_trainable = False, type = "se").cuda()
- self.modulo_time_embedding = Fundamental_Music_Embedding(d_model = d_model, base=10001, device = device, if_trainable=False, translation_bias_type = None, if_translation_bias_trainable = False, type = "se").cuda()
+ self.index_embedding = Fundamental_Music_Embedding(
+ d_model = d_model, base=10000, if_trainable=False, translation_bias_type = None,
+ if_translation_bias_trainable = False, type = "se"
+ )# .cuda()
+ self.global_time_embedding = Fundamental_Music_Embedding(
+ d_model = d_model, base=10001, if_trainable=False, translation_bias_type = None,
+ if_translation_bias_trainable = False, type = "se"
+ )# .cuda()
+ self.modulo_time_embedding = Fundamental_Music_Embedding(
+ d_model = d_model, base=10001, if_trainable=False, translation_bias_type = None,
+ if_translation_bias_trainable = False, type = "se"
+ )# .cuda()
position = torch.arange(max_len).unsqueeze(1)
div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
@@ -124,7 +133,6 @@ class PositionalEncoding(nn.Module):
def forward(self, x):
pos = self.pe[:x.size(1)] #[seq_len, batch_size, embedding_dim]
pos = torch.swapaxes(pos, 0, 1) #[batch_size, seq_len, embedding_dim]
- print("huh????", pos.shape, x.shape)
x = x + pos
return self.dropout(x)
@@ -254,13 +262,13 @@ class Chord_Embedding(nn.Module):
self.d_model = d_model
self.d_oh_type = d_oh_type
self.d_oh_inv = d_oh_inv
- self.chord_ffn = nn.Linear(d_oh_type + d_oh_inv + d_model + d_model, d_model).cuda()
- def __call__(self, chord_root, chord_type, chord_inv, chord_timing):
+ self.chord_ffn = nn.Linear(d_oh_type + d_oh_inv + d_model + d_model, d_model) #.cuda()
+ def __call__(self, chord_root, chord_type, chord_inv, chord_timing, device):
#chords: (B, LEN, 4)
#Embed root using FME
#Embed chord type, chord inversion using OH
#Embed timestamps using shared PE
- chord_root_emb = self.FME(chord_root)
+ chord_root_emb = self.FME(chord_root, device)
# print(chord_root_emb.size())
# print('this is chord root: ', chord_root)
# print('this is chord type: ', chord_type)
@@ -272,9 +280,9 @@ class Chord_Embedding(nn.Module):
# chord_root_emb = F.one_hot(chord_type.to(torch.int64), num_classes = self.d_model).to(torch.float32)
chord_type_emb = F.one_hot(chord_type.to(torch.int64), num_classes = self.d_oh_type).to(torch.float32)
chord_inv_emb = F.one_hot(chord_inv.to(torch.int64), num_classes = self.d_oh_inv).to(torch.float32)
- chord_time_emb = self.PE.global_time_embedding(chord_timing)
+ chord_time_emb = self.PE.global_time_embedding(chord_timing, device)
- chord_emb = self.chord_ffn(torch.cat((chord_root_emb, chord_type_emb, chord_inv_emb, chord_time_emb), dim = -1))
+ chord_emb = self.chord_ffn(torch.cat((chord_root_emb, chord_type_emb, chord_inv_emb, chord_time_emb), dim = -1).to(device))
# print("TADY toje", chord_emb.device)
return chord_emb
@@ -287,13 +295,13 @@ class Beat_Embedding(nn.Module):
self.d_oh_beat_type = d_oh_beat_type
self.beat_ffn = nn.Linear(d_oh_beat_type+d_model, d_model)
- def __call__(self, beats, beats_timing):
+ def __call__(self, beats, beats_timing, device):
#Embed beat type using OH
#Embed time using PE
- beat_type_emb = F.one_hot(beats.to(torch.int64), num_classes = self.d_oh_beat_type).to(torch.float32)
- beat_time_emb = self.PE.global_time_embedding(beats_timing)
- merged_beat = torch.cat((beat_type_emb, beat_time_emb), dim = -1).cuda()
+ beat_type_emb = F.one_hot(beats.to(torch.int64), num_classes = self.d_oh_beat_type).to(torch.float32).to(device)
+ beat_time_emb = self.PE.global_time_embedding(beats_timing, device)
+ merged_beat = torch.cat((beat_type_emb, beat_time_emb), dim = -1)
beat_emb = self.beat_ffn(merged_beat)
return beat_emb
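
The Beat_Embedding and Chord_Embedding hunks apply the same call-time device handling: the one-hot features are cast to float, moved with .to(device), concatenated with the time embedding, and projected by a linear layer, with no .cuda() call left in the path. The sketch below illustrates this under the same caveat; BeatEmbeddingSketch and the time_embedding helper are hypothetical simplifications, not the project's API.

import torch
import torch.nn as nn
import torch.nn.functional as F

def time_embedding(t, d_model, device):
    # Minimal sinusoidal time embedding, included only so the sketch runs standalone.
    i = torch.arange(d_model, device=device)
    rates = 1.0 / torch.pow(torch.tensor(10000.0, device=device), (2 * (i // 2)) / d_model)
    angles = t.to(device)[..., None] * rates  # (..., len, d_model)
    angles[..., 0::2] = torch.sin(angles.clone()[..., 0::2])
    angles[..., 1::2] = torch.cos(angles.clone()[..., 1::2])
    return angles

class BeatEmbeddingSketch(nn.Module):
    # Illustrative stand-in for the device-aware beat embedding: beat types are
    # one-hot encoded, moved to the requested device, concatenated with the time
    # embedding, and projected; nothing is pinned to CUDA at construction.
    def __init__(self, d_model, d_oh_beat_type):
        super().__init__()
        self.d_model = d_model
        self.d_oh_beat_type = d_oh_beat_type
        self.beat_ffn = nn.Linear(d_oh_beat_type + d_model, d_model)

    def __call__(self, beats, beats_timing, device):
        beat_type_emb = F.one_hot(beats.to(torch.int64),
                                  num_classes=self.d_oh_beat_type).to(torch.float32).to(device)
        beat_time_emb = time_embedding(beats_timing, self.d_model, device)
        merged_beat = torch.cat((beat_type_emb, beat_time_emb), dim=-1)
        return self.beat_ffn(merged_beat)

be = BeatEmbeddingSketch(d_model=8, d_oh_beat_type=4)
beats = torch.tensor([[0, 1, 2, 1]])
timing = torch.tensor([[0.0, 0.5, 1.0, 1.5]])
print(be(beats, timing, device="cpu").shape)  # torch.Size([1, 4, 8])

For GPU use the whole module would still be moved with .to(device) so that the weights of beat_ffn live on the same device as its inputs; the point of the refactor is that this choice is made by the caller rather than hard-coded.
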
diff --git a/models.py b/models.py
index 652d8f290ff0415e03a723d97a238bd0b6a03f4d..259288a96e229be80c97ae62fa45f2d1813d16fa 100644
--- a/models.py
+++ b/models.py
@@ -28,711 +28,713 @@ from diffusers import AutoencoderKL as DiffuserAutoencoderKL
from layers.layers import chord_tokenizer, beat_tokenizer, Chord_Embedding, Beat_Embedding, Music_PositionalEncoding, Fundamental_Music_Embedding
def build_pretrained_models(name):
- checkpoint = torch.load(get_metadata()[name]["path"], map_location="cpu")
- scale_factor = checkpoint["state_dict"]["scale_factor"].item()
+ checkpoint = torch.load(get_metadata()[name]["path"], map_location="cpu")
+ scale_factor = checkpoint["state_dict"]["scale_factor"].item()
- vae_state_dict = {k[18:]: v for k, v in checkpoint["state_dict"].items() if "first_stage_model." in k}
+ vae_state_dict = {k[18:]: v for k, v in checkpoint["state_dict"].items() if "first_stage_model." in k}
- config = default_audioldm_config(name)
- vae_config = config["model"]["params"]["first_stage_config"]["params"]
- vae_config["scale_factor"] = scale_factor
+ config = default_audioldm_config(name)
+ vae_config = config["model"]["params"]["first_stage_config"]["params"]
+ vae_config["scale_factor"] = scale_factor
- vae = AutoencoderKL(**vae_config)
- vae.load_state_dict(vae_state_dict)
+ vae = AutoencoderKL(**vae_config)
+ vae.load_state_dict(vae_state_dict)
- fn_STFT = TacotronSTFT(
- config["preprocessing"]["stft"]["filter_length"],
- config["preprocessing"]["stft"]["hop_length"],
- config["preprocessing"]["stft"]["win_length"],
- config["preprocessing"]["mel"]["n_mel_channels"],
- config["preprocessing"]["audio"]["sampling_rate"],
- config["preprocessing"]["mel"]["mel_fmin"],
- config["preprocessing"]["mel"]["mel_fmax"],
- )
+ fn_STFT = TacotronSTFT(
+ config["preprocessing"]["stft"]["filter_length"],
+ config["preprocessing"]["stft"]["hop_length"],
+ config["preprocessing"]["stft"]["win_length"],
+ config["preprocessing"]["mel"]["n_mel_channels"],
+ config["preprocessing"]["audio"]["sampling_rate"],
+ config["preprocessing"]["mel"]["mel_fmin"],
+ config["preprocessing"]["mel"]["mel_fmax"],
+ )
- vae.eval()
- fn_STFT.eval()
- return vae, fn_STFT
+ vae.eval()
+ fn_STFT.eval()
+ return vae, fn_STFT
class AudioDiffusion(nn.Module):
- def __init__(
- self,
- text_encoder_name,
- scheduler_name,
- unet_model_name=None,
- unet_model_config_path=None,
- snr_gamma=None,
- freeze_text_encoder=True,
- uncondition=False,
-
- ):
- super().__init__()
-
- assert unet_model_name is not None or unet_model_config_path is not None, "Either UNet pretrain model name or a config file path is required"
-
- self.text_encoder_name = text_encoder_name
- self.scheduler_name = scheduler_name
- self.unet_model_name = unet_model_name
- self.unet_model_config_path = unet_model_config_path
- self.snr_gamma = snr_gamma
- self.freeze_text_encoder = freeze_text_encoder
- self.uncondition = uncondition
-
- # https://huggingface.co/docs/diffusers/v0.14.0/en/api/schedulers/overview
- self.noise_scheduler = DDPMScheduler.from_pretrained(self.scheduler_name, subfolder="scheduler")
- self.inference_scheduler = DDPMScheduler.from_pretrained(self.scheduler_name, subfolder="scheduler")
-
- if unet_model_config_path:
- unet_config = UNet2DConditionModel.load_config(unet_model_config_path)
- self.unet = UNet2DConditionModel.from_config(unet_config, subfolder="unet")
- self.set_from = "random"
- print("UNet initialized randomly.")
- else:
- self.unet = UNet2DConditionModel.from_pretrained(unet_model_name, subfolder="unet")
- self.set_from = "pre-trained"
- self.group_in = nn.Sequential(nn.Linear(8, 512), nn.Linear(512, 4))
- self.group_out = nn.Sequential(nn.Linear(4, 512), nn.Linear(512, 8))
- print("UNet initialized from stable diffusion checkpoint.")
-
- if "stable-diffusion" in self.text_encoder_name:
- self.tokenizer = CLIPTokenizer.from_pretrained(self.text_encoder_name, subfolder="tokenizer")
- self.text_encoder = CLIPTextModel.from_pretrained(self.text_encoder_name, subfolder="text_encoder")
- elif "t5" in self.text_encoder_name:
- self.tokenizer = AutoTokenizer.from_pretrained(self.text_encoder_name)
- self.text_encoder = T5EncoderModel.from_pretrained(self.text_encoder_name)
- else:
- self.tokenizer = AutoTokenizer.from_pretrained(self.text_encoder_name)
- self.text_encoder = AutoModel.from_pretrained(self.text_encoder_name)
-
- def compute_snr(self, timesteps):
- """
- Computes SNR as per https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L847-L849
- """
- alphas_cumprod = self.noise_scheduler.alphas_cumprod
- sqrt_alphas_cumprod = alphas_cumprod**0.5
- sqrt_one_minus_alphas_cumprod = (1.0 - alphas_cumprod) ** 0.5
-
- # Expand the tensors.
- # Adapted from https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L1026
- sqrt_alphas_cumprod = sqrt_alphas_cumprod.to(device=timesteps.device)[timesteps].float()
- while len(sqrt_alphas_cumprod.shape) < len(timesteps.shape):
- sqrt_alphas_cumprod = sqrt_alphas_cumprod[..., None]
- alpha = sqrt_alphas_cumprod.expand(timesteps.shape)
-
- sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod.to(device=timesteps.device)[timesteps].float()
- while len(sqrt_one_minus_alphas_cumprod.shape) < len(timesteps.shape):
- sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod[..., None]
- sigma = sqrt_one_minus_alphas_cumprod.expand(timesteps.shape)
-
- # Compute SNR.
- snr = (alpha / sigma) ** 2
- return snr
-
- def encode_text(self, prompt):
- device = self.text_encoder.device
- batch = self.tokenizer(
- prompt, max_length=self.tokenizer.model_max_length, padding=True, truncation=True, return_tensors="pt"
- )
- input_ids, attention_mask = batch.input_ids.to(device), batch.attention_mask.to(device)
-
- if self.freeze_text_encoder:
- with torch.no_grad():
- encoder_hidden_states = self.text_encoder(
- input_ids=input_ids, attention_mask=attention_mask
- )[0]
- else:
- encoder_hidden_states = self.text_encoder(
- input_ids=input_ids, attention_mask=attention_mask
- )[0]
-
- boolean_encoder_mask = (attention_mask == 1).to(device)
- return encoder_hidden_states, boolean_encoder_mask
-
- def forward(self, latents, prompt, validation_mode=False):
- device = self.text_encoder.device
- num_train_timesteps = self.noise_scheduler.num_train_timesteps
- self.noise_scheduler.set_timesteps(num_train_timesteps, device=device)
-
- encoder_hidden_states, boolean_encoder_mask = self.encode_text(prompt)
-
- if self.uncondition:
- mask_indices = [k for k in range(len(prompt)) if random.random() < 0.1]
- if len(mask_indices) > 0:
- encoder_hidden_states[mask_indices] = 0
-
- bsz = latents.shape[0]
-
- if validation_mode:
- timesteps = (self.noise_scheduler.num_train_timesteps//2) * torch.ones((bsz,), dtype=torch.int64, device=device)
- else:
- # Sample a random timestep for each instance
- timesteps = torch.randint(0, self.noise_scheduler.num_train_timesteps, (bsz,), device=device)
- # print('in if ', timesteps)
- timesteps = timesteps.long()
- # print('outside if ' , timesteps)
- noise = torch.randn_like(latents)
- noisy_latents = self.noise_scheduler.add_noise(latents, noise, timesteps)
-
- # Get the target for loss depending on the prediction type
- if self.noise_scheduler.config.prediction_type == "epsilon":
- target = noise
- elif self.noise_scheduler.config.prediction_type == "v_prediction":
- target = self.noise_scheduler.get_velocity(latents, noise, timesteps)
- else:
- raise ValueError(f"Unknown prediction type {self.noise_scheduler.config.prediction_type}")
-
- if self.set_from == "random":
- model_pred = self.unet(
- noisy_latents, timesteps, encoder_hidden_states,
- encoder_attention_mask=boolean_encoder_mask
- ).sample
-
- elif self.set_from == "pre-trained":
- compressed_latents = self.group_in(noisy_latents.permute(0, 2, 3, 1).contiguous()).permute(0, 3, 1, 2).contiguous()
- model_pred = self.unet(
- compressed_latents, timesteps, encoder_hidden_states,
- encoder_attention_mask=boolean_encoder_mask
- ).sample
- model_pred = self.group_out(model_pred.permute(0, 2, 3, 1).contiguous()).permute(0, 3, 1, 2).contiguous()
-
- if self.snr_gamma is None:
- loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
- else:
- # Compute loss-weights as per Section 3.4 of https://arxiv.org/abs/2303.09556.
- # Adaptef from huggingface/diffusers/blob/main/examples/text_to_image/train_text_to_image.py
- snr = self.compute_snr(timesteps)
- mse_loss_weights = (
- torch.stack([snr, self.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr
- )
- loss = F.mse_loss(model_pred.float(), target.float(), reduction="none")
- loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights
- loss = loss.mean()
-
- return loss
-
- @torch.no_grad()
- def inference(self, prompt, inference_scheduler, num_steps=20, guidance_scale=3, num_samples_per_prompt=1,
- disable_progress=True):
- device = self.text_encoder.device
- classifier_free_guidance = guidance_scale > 1.0
- batch_size = len(prompt) * num_samples_per_prompt
-
- if classifier_free_guidance:
- prompt_embeds, boolean_prompt_mask = self.encode_text_classifier_free(prompt, num_samples_per_prompt)
- else:
- prompt_embeds, boolean_prompt_mask = self.encode_text(prompt)
- prompt_embeds = prompt_embeds.repeat_interleave(num_samples_per_prompt, 0)
- boolean_prompt_mask = boolean_prompt_mask.repeat_interleave(num_samples_per_prompt, 0)
-
- inference_scheduler.set_timesteps(num_steps, device=device)
- timesteps = inference_scheduler.timesteps
-
- num_channels_latents = self.unet.in_channels
- latents = self.prepare_latents(batch_size, inference_scheduler, num_channels_latents, prompt_embeds.dtype, device)
-
- num_warmup_steps = len(timesteps) - num_steps * inference_scheduler.order
- progress_bar = tqdm(range(num_steps), disable=disable_progress)
-
- for i, t in enumerate(timesteps):
- # expand the latents if we are doing classifier free guidance
- latent_model_input = torch.cat([latents] * 2) if classifier_free_guidance else latents
- latent_model_input = inference_scheduler.scale_model_input(latent_model_input, t)
-
- noise_pred = self.unet(
- latent_model_input, t, encoder_hidden_states=prompt_embeds,
- encoder_attention_mask=boolean_prompt_mask
- ).sample
-
- # perform guidance
- if classifier_free_guidance:
- noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
- noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
-
- # compute the previous noisy sample x_t -> x_t-1
- latents = inference_scheduler.step(noise_pred, t, latents).prev_sample
-
- # call the callback, if provided
- if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % inference_scheduler.order == 0):
- progress_bar.update(1)
-
- if self.set_from == "pre-trained":
- latents = self.group_out(latents.permute(0, 2, 3, 1).contiguous()).permute(0, 3, 1, 2).contiguous()
- return latents
-
- def prepare_latents(self, batch_size, inference_scheduler, num_channels_latents, dtype, device):
- shape = (batch_size, num_channels_latents, 256, 16)
- latents = randn_tensor(shape, generator=None, device=device, dtype=dtype)
- # scale the initial noise by the standard deviation required by the scheduler
- latents = latents * inference_scheduler.init_noise_sigma
- return latents
-
- def encode_text_classifier_free(self, prompt, num_samples_per_prompt):
- device = self.text_encoder.device
- batch = self.tokenizer(
- prompt, max_length=self.tokenizer.model_max_length, padding=True, truncation=True, return_tensors="pt"
- )
- input_ids, attention_mask = batch.input_ids.to(device), batch.attention_mask.to(device)
-
- with torch.no_grad():
- prompt_embeds = self.text_encoder(
- input_ids=input_ids, attention_mask=attention_mask
- )[0]
-
- prompt_embeds = prompt_embeds.repeat_interleave(num_samples_per_prompt, 0)
- attention_mask = attention_mask.repeat_interleave(num_samples_per_prompt, 0)
-
- # get unconditional embeddings for classifier free guidance
- uncond_tokens = [""] * len(prompt)
-
- max_length = prompt_embeds.shape[1]
- uncond_batch = self.tokenizer(
- uncond_tokens, max_length=max_length, padding="max_length", truncation=True, return_tensors="pt",
- )
- uncond_input_ids = uncond_batch.input_ids.to(device)
- uncond_attention_mask = uncond_batch.attention_mask.to(device)
-
- with torch.no_grad():
- negative_prompt_embeds = self.text_encoder(
- input_ids=uncond_input_ids, attention_mask=uncond_attention_mask
- )[0]
-
- negative_prompt_embeds = negative_prompt_embeds.repeat_interleave(num_samples_per_prompt, 0)
- uncond_attention_mask = uncond_attention_mask.repeat_interleave(num_samples_per_prompt, 0)
-
- # For classifier free guidance, we need to do two forward passes.
- # We concatenate the unconditional and text embeddings into a single batch to avoid doing two forward passes
- prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
- prompt_mask = torch.cat([uncond_attention_mask, attention_mask])
- boolean_prompt_mask = (prompt_mask == 1).to(device)
-
- return prompt_embeds, boolean_prompt_mask
-
+ def __init__(
+ self,
+ text_encoder_name,
+ scheduler_name,
+ unet_model_name=None,
+ unet_model_config_path=None,
+ snr_gamma=None,
+ freeze_text_encoder=True,
+ uncondition=False,
+
+ ):
+ super().__init__()
+
+ assert unet_model_name is not None or unet_model_config_path is not None, "Either UNet pretrain model name or a config file path is required"
+
+ self.text_encoder_name = text_encoder_name
+ self.scheduler_name = scheduler_name
+ self.unet_model_name = unet_model_name
+ self.unet_model_config_path = unet_model_config_path
+ self.snr_gamma = snr_gamma
+ self.freeze_text_encoder = freeze_text_encoder
+ self.uncondition = uncondition
+
+ # https://huggingface.co/docs/diffusers/v0.14.0/en/api/schedulers/overview
+ self.noise_scheduler = DDPMScheduler.from_pretrained(self.scheduler_name, subfolder="scheduler")
+ self.inference_scheduler = DDPMScheduler.from_pretrained(self.scheduler_name, subfolder="scheduler")
+
+ if unet_model_config_path:
+ unet_config = UNet2DConditionModel.load_config(unet_model_config_path)
+ self.unet = UNet2DConditionModel.from_config(unet_config, subfolder="unet")
+ self.set_from = "random"
+ print("UNet initialized randomly.")
+ else:
+ self.unet = UNet2DConditionModel.from_pretrained(unet_model_name, subfolder="unet")
+ self.set_from = "pre-trained"
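+            # The pre-trained Stable Diffusion UNet expects 4 latent channels, so group_in/group_out
+            # project the 8-channel audio latents down on the way in and back up on the way out.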
+ self.group_in = nn.Sequential(nn.Linear(8, 512), nn.Linear(512, 4))
+ self.group_out = nn.Sequential(nn.Linear(4, 512), nn.Linear(512, 8))
+ print("UNet initialized from stable diffusion checkpoint.")
+
+ if "stable-diffusion" in self.text_encoder_name:
+ self.tokenizer = CLIPTokenizer.from_pretrained(self.text_encoder_name, subfolder="tokenizer")
+ self.text_encoder = CLIPTextModel.from_pretrained(self.text_encoder_name, subfolder="text_encoder")
+ elif "t5" in self.text_encoder_name:
+ self.tokenizer = AutoTokenizer.from_pretrained(self.text_encoder_name)
+ self.text_encoder = T5EncoderModel.from_pretrained(self.text_encoder_name)
+ else:
+ self.tokenizer = AutoTokenizer.from_pretrained(self.text_encoder_name)
+ self.text_encoder = AutoModel.from_pretrained(self.text_encoder_name)
+
+ def compute_snr(self, timesteps):
+ """
+ Computes SNR as per https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L847-L849
+ """
+ alphas_cumprod = self.noise_scheduler.alphas_cumprod
+ sqrt_alphas_cumprod = alphas_cumprod**0.5
+ sqrt_one_minus_alphas_cumprod = (1.0 - alphas_cumprod) ** 0.5
+
+ # Expand the tensors.
+ # Adapted from https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L1026
+ sqrt_alphas_cumprod = sqrt_alphas_cumprod.to(device=timesteps.device)[timesteps].float()
+ while len(sqrt_alphas_cumprod.shape) < len(timesteps.shape):
+ sqrt_alphas_cumprod = sqrt_alphas_cumprod[..., None]
+ alpha = sqrt_alphas_cumprod.expand(timesteps.shape)
+
+ sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod.to(device=timesteps.device)[timesteps].float()
+ while len(sqrt_one_minus_alphas_cumprod.shape) < len(timesteps.shape):
+ sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod[..., None]
+ sigma = sqrt_one_minus_alphas_cumprod.expand(timesteps.shape)
+
+ # Compute SNR.
+ snr = (alpha / sigma) ** 2
+ return snr
+
+ def encode_text(self, prompt):
+ device = self.text_encoder.device
+ batch = self.tokenizer(
+ prompt, max_length=self.tokenizer.model_max_length, padding=True, truncation=True, return_tensors="pt"
+ )
+ input_ids, attention_mask = batch.input_ids.to(device), batch.attention_mask.to(device)
+
+ if self.freeze_text_encoder:
+ with torch.no_grad():
+ encoder_hidden_states = self.text_encoder(
+ input_ids=input_ids, attention_mask=attention_mask
+ )[0]
+ else:
+ encoder_hidden_states = self.text_encoder(
+ input_ids=input_ids, attention_mask=attention_mask
+ )[0]
+
+ boolean_encoder_mask = (attention_mask == 1).to(device)
+ return encoder_hidden_states, boolean_encoder_mask
+
+ def forward(self, latents, prompt, validation_mode=False):
+ device = self.text_encoder.device
+ num_train_timesteps = self.noise_scheduler.num_train_timesteps
+ self.noise_scheduler.set_timesteps(num_train_timesteps, device=device)
+
+ encoder_hidden_states, boolean_encoder_mask = self.encode_text(prompt)
+
+ if self.uncondition:
+ mask_indices = [k for k in range(len(prompt)) if random.random() < 0.1]
+ if len(mask_indices) > 0:
+ encoder_hidden_states[mask_indices] = 0
+
+ bsz = latents.shape[0]
+
+ if validation_mode:
+ timesteps = (self.noise_scheduler.num_train_timesteps//2) * torch.ones((bsz,), dtype=torch.int64, device=device)
+ else:
+ # Sample a random timestep for each instance
+ timesteps = torch.randint(0, self.noise_scheduler.num_train_timesteps, (bsz,), device=device)
+ # print('in if ', timesteps)
+ timesteps = timesteps.long()
+ # print('outside if ' , timesteps)
+ noise = torch.randn_like(latents)
+ noisy_latents = self.noise_scheduler.add_noise(latents, noise, timesteps)
+
+ # Get the target for loss depending on the prediction type
+ if self.noise_scheduler.config.prediction_type == "epsilon":
+ target = noise
+ elif self.noise_scheduler.config.prediction_type == "v_prediction":
+ target = self.noise_scheduler.get_velocity(latents, noise, timesteps)
+ else:
+ raise ValueError(f"Unknown prediction type {self.noise_scheduler.config.prediction_type}")
+
+ if self.set_from == "random":
+ model_pred = self.unet(
+ noisy_latents, timesteps, encoder_hidden_states,
+ encoder_attention_mask=boolean_encoder_mask
+ ).sample
+
+ elif self.set_from == "pre-trained":
+ compressed_latents = self.group_in(noisy_latents.permute(0, 2, 3, 1).contiguous()).permute(0, 3, 1, 2).contiguous()
+ model_pred = self.unet(
+ compressed_latents, timesteps, encoder_hidden_states,
+ encoder_attention_mask=boolean_encoder_mask
+ ).sample
+ model_pred = self.group_out(model_pred.permute(0, 2, 3, 1).contiguous()).permute(0, 3, 1, 2).contiguous()
+
+ if self.snr_gamma is None:
+ loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
+ else:
+ # Compute loss-weights as per Section 3.4 of https://arxiv.org/abs/2303.09556.
+            # Adapted from huggingface/diffusers/blob/main/examples/text_to_image/train_text_to_image.py
+ snr = self.compute_snr(timesteps)
+ mse_loss_weights = (
+ torch.stack([snr, self.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr
+ )
+ loss = F.mse_loss(model_pred.float(), target.float(), reduction="none")
+ loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights
+ loss = loss.mean()
+
+ return loss
+
+ @torch.no_grad()
+ def inference(self, prompt, inference_scheduler, num_steps=20, guidance_scale=3, num_samples_per_prompt=1,
+ disable_progress=True):
+ device = self.text_encoder.device
+ classifier_free_guidance = guidance_scale > 1.0
+ batch_size = len(prompt) * num_samples_per_prompt
+
+ if classifier_free_guidance:
+ prompt_embeds, boolean_prompt_mask = self.encode_text_classifier_free(prompt, num_samples_per_prompt)
+ else:
+ prompt_embeds, boolean_prompt_mask = self.encode_text(prompt)
+ prompt_embeds = prompt_embeds.repeat_interleave(num_samples_per_prompt, 0)
+ boolean_prompt_mask = boolean_prompt_mask.repeat_interleave(num_samples_per_prompt, 0)
+
+ inference_scheduler.set_timesteps(num_steps, device=device)
+ timesteps = inference_scheduler.timesteps
+
+ num_channels_latents = self.unet.in_channels
+ latents = self.prepare_latents(batch_size, inference_scheduler, num_channels_latents, prompt_embeds.dtype, device)
+
+ num_warmup_steps = len(timesteps) - num_steps * inference_scheduler.order
+ progress_bar = tqdm(range(num_steps), disable=disable_progress)
+
+ for i, t in enumerate(timesteps):
+ # expand the latents if we are doing classifier free guidance
+ latent_model_input = torch.cat([latents] * 2) if classifier_free_guidance else latents
+ latent_model_input = inference_scheduler.scale_model_input(latent_model_input, t)
+
+ noise_pred = self.unet(
+ latent_model_input, t, encoder_hidden_states=prompt_embeds,
+ encoder_attention_mask=boolean_prompt_mask
+ ).sample
+
+ # perform guidance
+ if classifier_free_guidance:
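+                # classifier-free guidance: eps = eps_uncond + guidance_scale * (eps_text - eps_uncond)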
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latents = inference_scheduler.step(noise_pred, t, latents).prev_sample
+
+ # call the callback, if provided
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % inference_scheduler.order == 0):
+ progress_bar.update(1)
+
+ if self.set_from == "pre-trained":
+ latents = self.group_out(latents.permute(0, 2, 3, 1).contiguous()).permute(0, 3, 1, 2).contiguous()
+ return latents
+
+ def prepare_latents(self, batch_size, inference_scheduler, num_channels_latents, dtype, device):
+ shape = (batch_size, num_channels_latents, 256, 16)
+ latents = randn_tensor(shape, generator=None, device=device, dtype=dtype)
+ # scale the initial noise by the standard deviation required by the scheduler
+ latents = latents * inference_scheduler.init_noise_sigma
+ return latents
+
+ def encode_text_classifier_free(self, prompt, num_samples_per_prompt):
+ device = self.text_encoder.device
+ batch = self.tokenizer(
+ prompt, max_length=self.tokenizer.model_max_length, padding=True, truncation=True, return_tensors="pt"
+ )
+ input_ids, attention_mask = batch.input_ids.to(device), batch.attention_mask.to(device)
+
+ with torch.no_grad():
+ prompt_embeds = self.text_encoder(
+ input_ids=input_ids, attention_mask=attention_mask
+ )[0]
+
+ prompt_embeds = prompt_embeds.repeat_interleave(num_samples_per_prompt, 0)
+ attention_mask = attention_mask.repeat_interleave(num_samples_per_prompt, 0)
+
+ # get unconditional embeddings for classifier free guidance
+ uncond_tokens = [""] * len(prompt)
+
+ max_length = prompt_embeds.shape[1]
+ uncond_batch = self.tokenizer(
+ uncond_tokens, max_length=max_length, padding="max_length", truncation=True, return_tensors="pt",
+ )
+ uncond_input_ids = uncond_batch.input_ids.to(device)
+ uncond_attention_mask = uncond_batch.attention_mask.to(device)
+
+ with torch.no_grad():
+ negative_prompt_embeds = self.text_encoder(
+ input_ids=uncond_input_ids, attention_mask=uncond_attention_mask
+ )[0]
+
+ negative_prompt_embeds = negative_prompt_embeds.repeat_interleave(num_samples_per_prompt, 0)
+ uncond_attention_mask = uncond_attention_mask.repeat_interleave(num_samples_per_prompt, 0)
+
+ # For classifier free guidance, we need to do two forward passes.
+ # We concatenate the unconditional and text embeddings into a single batch to avoid doing two forward passes
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
+ prompt_mask = torch.cat([uncond_attention_mask, attention_mask])
+ boolean_prompt_mask = (prompt_mask == 1).to(device)
+
+ return prompt_embeds, boolean_prompt_mask
+
class MusicAudioDiffusion(nn.Module):
- def __init__(
- self,
- text_encoder_name,
- scheduler_name,
- unet_model_name=None,
- unet_model_config_path=None,
- snr_gamma=None,
- freeze_text_encoder=True,
- uncondition=False,
-
- d_fme = 1024, #FME
- fme_type = "se",
- base = 1,
- if_trainable = True,
- translation_bias_type = "nd",
- emb_nn = True,
- d_pe = 1024, #PE
- if_index = True,
- if_global_timing = True,
- if_modulo_timing = False,
- d_beat = 1024, #Beat
- d_oh_beat_type = 7,
- beat_len = 50,
- d_chord = 1024, #Chord
- d_oh_chord_type = 12,
- d_oh_inv_type = 4,
- chord_len = 20,
-
- ):
- super().__init__()
-
- assert unet_model_name is not None or unet_model_config_path is not None, "Either UNet pretrain model name or a config file path is required"
-
- self.text_encoder_name = text_encoder_name
- self.scheduler_name = scheduler_name
- self.unet_model_name = unet_model_name
- self.unet_model_config_path = unet_model_config_path
- self.snr_gamma = snr_gamma
- self.freeze_text_encoder = freeze_text_encoder
- self.uncondition = uncondition
-
- # https://huggingface.co/docs/diffusers/v0.14.0/en/api/schedulers/overview
- self.noise_scheduler = DDPMScheduler.from_pretrained(self.scheduler_name, subfolder="scheduler")
- self.inference_scheduler = DDPMScheduler.from_pretrained(self.scheduler_name, subfolder="scheduler")
-
- if unet_model_config_path:
- unet_config = UNet2DConditionModelMusic.load_config(unet_model_config_path)
- self.unet = UNet2DConditionModelMusic.from_config(unet_config, subfolder="unet")
- self.set_from = "random"
- print("UNet initialized randomly.")
- else:
- self.unet = UNet2DConditionModel.from_pretrained(unet_model_name, subfolder="unet")
- self.set_from = "pre-trained"
- self.group_in = nn.Sequential(nn.Linear(8, 512), nn.Linear(512, 4))
- self.group_out = nn.Sequential(nn.Linear(4, 512), nn.Linear(512, 8))
- print("UNet initialized from stable diffusion checkpoint.")
-
- if "stable-diffusion" in self.text_encoder_name:
- self.tokenizer = CLIPTokenizer.from_pretrained(self.text_encoder_name, subfolder="tokenizer")
- self.text_encoder = CLIPTextModel.from_pretrained(self.text_encoder_name, subfolder="text_encoder")
- elif "t5" in self.text_encoder_name:
- self.tokenizer = AutoTokenizer.from_pretrained(self.text_encoder_name)
- self.text_encoder = T5EncoderModel.from_pretrained(self.text_encoder_name)
- else:
- self.tokenizer = AutoTokenizer.from_pretrained(self.text_encoder_name)
- self.text_encoder = AutoModel.from_pretrained(self.text_encoder_name)
-
- self.device = self.text_encoder.device
- #Music Feature Encoder
- self.FME = Fundamental_Music_Embedding(d_model = d_fme, base= base, if_trainable = False, type = fme_type,emb_nn=emb_nn,translation_bias_type = translation_bias_type)
- self.PE = Music_PositionalEncoding(d_model = d_pe, if_index = if_index, if_global_timing = if_global_timing, if_modulo_timing = if_modulo_timing, device = self.device)
- # self.PE2 = Music_PositionalEncoding(d_model = d_pe, if_index = if_index, if_global_timing = if_global_timing, if_modulo_timing = if_modulo_timing, device = self.device)
- self.beat_tokenizer = beat_tokenizer(seq_len_beat=beat_len, if_pad = True)
- self.beat_embedding_layer = Beat_Embedding(self.PE, d_model = d_beat, d_oh_beat_type = d_oh_beat_type)
- self.chord_embedding_layer = Chord_Embedding(self.FME, self.PE, d_model = d_chord, d_oh_type = d_oh_chord_type, d_oh_inv = d_oh_inv_type)
- self.chord_tokenizer = chord_tokenizer(seq_len_chord=chord_len, if_pad = True)
-
-
- def compute_snr(self, timesteps):
- """
- Computes SNR as per https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L847-L849
- """
- alphas_cumprod = self.noise_scheduler.alphas_cumprod
- sqrt_alphas_cumprod = alphas_cumprod**0.5
- sqrt_one_minus_alphas_cumprod = (1.0 - alphas_cumprod) ** 0.5
-
- # Expand the tensors.
- # Adapted from https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L1026
- sqrt_alphas_cumprod = sqrt_alphas_cumprod.to(device=timesteps.device)[timesteps].float()
- while len(sqrt_alphas_cumprod.shape) < len(timesteps.shape):
- sqrt_alphas_cumprod = sqrt_alphas_cumprod[..., None]
- alpha = sqrt_alphas_cumprod.expand(timesteps.shape)
-
- sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod.to(device=timesteps.device)[timesteps].float()
- while len(sqrt_one_minus_alphas_cumprod.shape) < len(timesteps.shape):
- sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod[..., None]
- sigma = sqrt_one_minus_alphas_cumprod.expand(timesteps.shape)
-
- # Compute SNR.
- snr = (alpha / sigma) ** 2
- return snr
-
- def encode_text(self, prompt):
- device = self.text_encoder.device
- batch = self.tokenizer(
- prompt, max_length=self.tokenizer.model_max_length, padding=True, truncation=True, return_tensors="pt"
- )
- input_ids, attention_mask = batch.input_ids.to(device), batch.attention_mask.to(device) #cuda
- if self.freeze_text_encoder:
- with torch.no_grad():
- encoder_hidden_states = self.text_encoder(
- input_ids=input_ids, attention_mask=attention_mask
- )[0] #batch, len_text, dim
- else:
- encoder_hidden_states = self.text_encoder(
- input_ids=input_ids, attention_mask=attention_mask
- )[0]
- boolean_encoder_mask = (attention_mask == 1).to(device) ##batch, len_text
- return encoder_hidden_states, boolean_encoder_mask
-
- def encode_beats(self, beats):
- # device = self.beat_embedding_layer.device
- out_beat = []
- out_beat_timing = []
- out_mask = []
- for beat in beats:
- tokenized_beats,tokenized_beats_timing, tokenized_beat_mask = self.beat_tokenizer(beat)
- out_beat.append(tokenized_beats)
- out_beat_timing.append(tokenized_beats_timing)
- out_mask.append(tokenized_beat_mask)
- out_beat, out_beat_timing, out_mask = torch.tensor(out_beat).cuda(), torch.tensor(out_beat_timing).cuda(), torch.tensor(out_mask).cuda() #batch, len_beat
- embedded_beat = self.beat_embedding_layer(out_beat, out_beat_timing)
-
- return embedded_beat, out_mask
-
- def encode_chords(self, chords,chords_time):
- out_chord_root = []
- out_chord_type = []
- out_chord_inv = []
- out_chord_timing = []
- out_mask = []
- for chord, chord_time in zip(chords,chords_time): #batch loop
- tokenized_chord_root, tokenized_chord_type, tokenized_chord_inv, tokenized_chord_time, tokenized_chord_mask = self.chord_tokenizer(chord, chord_time)
- out_chord_root.append(tokenized_chord_root)
- out_chord_type.append(tokenized_chord_type)
- out_chord_inv.append(tokenized_chord_inv)
- out_chord_timing.append(tokenized_chord_time)
- out_mask.append(tokenized_chord_mask)
- #chords: (B, LEN, 4)
- out_chord_root, out_chord_type, out_chord_inv, out_chord_timing, out_mask = torch.tensor(out_chord_root).cuda(), torch.tensor(out_chord_type).cuda(), torch.tensor(out_chord_inv).cuda(), torch.tensor(out_chord_timing).cuda(), torch.tensor(out_mask).cuda()
- embedded_chord = self.chord_embedding_layer(out_chord_root, out_chord_type, out_chord_inv, out_chord_timing)
- return embedded_chord, out_mask
- # return out_chord_root, out_mask
-
-
- def forward(self, latents, prompt, beats, chords,chords_time, validation_mode=False):
- device = self.text_encoder.device
- num_train_timesteps = self.noise_scheduler.num_train_timesteps
- self.noise_scheduler.set_timesteps(num_train_timesteps, device=device)
-
- encoder_hidden_states, boolean_encoder_mask = self.encode_text(prompt)
-
- # with torch.no_grad():
- encoded_beats, beat_mask = self.encode_beats(beats) #batch, len_beats, dim; batch, len_beats
- encoded_chords, chord_mask = self.encode_chords(chords,chords_time)
-
-
- if self.uncondition:
- mask_indices = [k for k in range(len(prompt)) if random.random() < 0.1]
- if len(mask_indices) > 0:
- encoder_hidden_states[mask_indices] = 0
- encoded_chords[mask_indices] = 0
- encoded_beats[mask_indices] = 0
-
- bsz = latents.shape[0]
-
- if validation_mode:
- timesteps = (self.noise_scheduler.num_train_timesteps//2) * torch.ones((bsz,), dtype=torch.int64, device=device)
- else:
- timesteps = torch.randint(0, self.noise_scheduler.num_train_timesteps, (bsz,), device=device)
-
-
- timesteps = timesteps.long()
-
- noise = torch.randn_like(latents)
- noisy_latents = self.noise_scheduler.add_noise(latents, noise, timesteps)
-
- # Get the target for loss depending on the prediction type
- if self.noise_scheduler.config.prediction_type == "epsilon":
- target = noise
- elif self.noise_scheduler.config.prediction_type == "v_prediction":
- target = self.noise_scheduler.get_velocity(latents, noise, timesteps)
- else:
- raise ValueError(f"Unknown prediction type {self.noise_scheduler.config.prediction_type}")
-
- if self.set_from == "random":
- # model_pred = torch.zeros((bsz,8,256,16)).to(device)
- model_pred = self.unet(
- noisy_latents, timesteps, encoder_hidden_states, encoded_beats, encoded_chords,
- encoder_attention_mask=boolean_encoder_mask, beat_attention_mask = beat_mask, chord_attention_mask = chord_mask
- ).sample
-
- elif self.set_from == "pre-trained":
- compressed_latents = self.group_in(noisy_latents.permute(0, 2, 3, 1).contiguous()).permute(0, 3, 1, 2).contiguous()
- model_pred = self.unet(
- compressed_latents, timesteps, encoder_hidden_states,
- encoder_attention_mask=boolean_encoder_mask
- ).sample
- model_pred = self.group_out(model_pred.permute(0, 2, 3, 1).contiguous()).permute(0, 3, 1, 2).contiguous()
-
- if self.snr_gamma is None:
- loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
- else:
- # Compute loss-weights as per Section 3.4 of https://arxiv.org/abs/2303.09556.
- # Adaptef from huggingface/diffusers/blob/main/examples/text_to_image/train_text_to_image.py
- snr = self.compute_snr(timesteps)
- mse_loss_weights = (
- torch.stack([snr, self.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr
- )
- loss = F.mse_loss(model_pred.float(), target.float(), reduction="none")
- loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights
- loss = loss.mean()
-
- return loss
-
- @torch.no_grad()
- def inference(self, prompt, beats, chords,chords_time, inference_scheduler, num_steps=20, guidance_scale=3, num_samples_per_prompt=1,
- disable_progress=True):
- device = self.text_encoder.device
- classifier_free_guidance = guidance_scale > 1.0
- batch_size = len(prompt) * num_samples_per_prompt
-
- if classifier_free_guidance:
- prompt_embeds, boolean_prompt_mask = self.encode_text_classifier_free(prompt, num_samples_per_prompt)
- encoded_beats, beat_mask = self.encode_beats_classifier_free(beats, num_samples_per_prompt) #batch, len_beats, dim; batch, len_beats
- encoded_chords, chord_mask = self.encode_chords_classifier_free(chords, chords_time, num_samples_per_prompt)
- else:
- prompt_embeds, boolean_prompt_mask = self.encode_text(prompt)
- prompt_embeds = prompt_embeds.repeat_interleave(num_samples_per_prompt, 0)
- boolean_prompt_mask = boolean_prompt_mask.repeat_interleave(num_samples_per_prompt, 0)
-
- encoded_beats, beat_mask = self.encode_beats(beats) #batch, len_beats, dim; batch, len_beats
- encoded_beats = encoded_beats.repeat_interleave(num_samples_per_prompt, 0)
- beat_mask = beat_mask.repeat_interleave(num_samples_per_prompt, 0)
-
- encoded_chords, chord_mask = self.encode_chords(chords,chords_time)
- encoded_chords = encoded_chords.repeat_interleave(num_samples_per_prompt, 0)
- chord_mask = chord_mask.repeat_interleave(num_samples_per_prompt, 0)
-
- # print(f"encoded_chords:{encoded_chords.shape}, chord_mask:{chord_mask.shape}, prompt_embeds:{prompt_embeds.shape},boolean_prompt_mask:{boolean_prompt_mask.shape} ")
- inference_scheduler.set_timesteps(num_steps, device=device)
- timesteps = inference_scheduler.timesteps
-
- num_channels_latents = self.unet.in_channels
- latents = self.prepare_latents(batch_size, inference_scheduler, num_channels_latents, prompt_embeds.dtype, device)
-
- num_warmup_steps = len(timesteps) - num_steps * inference_scheduler.order
- progress_bar = tqdm(range(num_steps), disable=disable_progress)
-
- for i, t in enumerate(timesteps):
- # expand the latents if we are doing classifier free guidance
- latent_model_input = torch.cat([latents] * 2) if classifier_free_guidance else latents
- latent_model_input = inference_scheduler.scale_model_input(latent_model_input, t)
-
- noise_pred = self.unet(
- latent_model_input, t, encoder_hidden_states=prompt_embeds,
- encoder_attention_mask=boolean_prompt_mask,
- beat_features = encoded_beats, beat_attention_mask = beat_mask, chord_features = encoded_chords,chord_attention_mask = chord_mask
- ).sample
-
- # perform guidance
- if classifier_free_guidance: #should work for beats and chords too
- noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
- noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
-
- # compute the previous noisy sample x_t -> x_t-1
- latents = inference_scheduler.step(noise_pred, t, latents).prev_sample
-
- # call the callback, if provided
- if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % inference_scheduler.order == 0):
- progress_bar.update(1)
-
- if self.set_from == "pre-trained":
- latents = self.group_out(latents.permute(0, 2, 3, 1).contiguous()).permute(0, 3, 1, 2).contiguous()
- return latents
-
- def prepare_latents(self, batch_size, inference_scheduler, num_channels_latents, dtype, device):
- shape = (batch_size, num_channels_latents, 256, 16)
- latents = randn_tensor(shape, generator=None, device=device, dtype=dtype)
- # scale the initial noise by the standard deviation required by the scheduler
- latents = latents * inference_scheduler.init_noise_sigma
- return latents
-
- def encode_text_classifier_free(self, prompt, num_samples_per_prompt):
- device = self.text_encoder.device
- batch = self.tokenizer(
- prompt, max_length=self.tokenizer.model_max_length, padding=True, truncation=True, return_tensors="pt"
- )
- input_ids, attention_mask = batch.input_ids.to(device), batch.attention_mask.to(device)
-
- with torch.no_grad():
- prompt_embeds = self.text_encoder(
- input_ids=input_ids, attention_mask=attention_mask
- )[0]
-
- prompt_embeds = prompt_embeds.repeat_interleave(num_samples_per_prompt, 0)
- attention_mask = attention_mask.repeat_interleave(num_samples_per_prompt, 0)
-
- # get unconditional embeddings for classifier free guidance
- # print(len(prompt), 'this is prompt len')
- uncond_tokens = [""] * len(prompt)
-
- max_length = prompt_embeds.shape[1]
- uncond_batch = self.tokenizer(
- uncond_tokens, max_length=max_length, padding="max_length", truncation=True, return_tensors="pt",
- )
- uncond_input_ids = uncond_batch.input_ids.to(device)
- uncond_attention_mask = uncond_batch.attention_mask.to(device)
-
- with torch.no_grad():
- negative_prompt_embeds = self.text_encoder(
- input_ids=uncond_input_ids, attention_mask=uncond_attention_mask
- )[0]
-
- negative_prompt_embeds = negative_prompt_embeds.repeat_interleave(num_samples_per_prompt, 0)
- uncond_attention_mask = uncond_attention_mask.repeat_interleave(num_samples_per_prompt, 0)
-
- # For classifier free guidance, we need to do two forward passes.
- # We concatenate the unconditional and text embeddings into a single batch to avoid doing two forward passes
- prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
- prompt_mask = torch.cat([uncond_attention_mask, attention_mask])
- boolean_prompt_mask = (prompt_mask == 1).to(device)
-
- return prompt_embeds, boolean_prompt_mask
-
-
- def encode_beats_classifier_free(self, beats, num_samples_per_prompt):
- with torch.no_grad():
- out_beat = []
- out_beat_timing = []
- out_mask = []
- for beat in beats:
- tokenized_beats,tokenized_beats_timing, tokenized_beat_mask = self.beat_tokenizer(beat)
- out_beat.append(tokenized_beats)
- out_beat_timing.append(tokenized_beats_timing)
- out_mask.append(tokenized_beat_mask)
- out_beat, out_beat_timing, out_mask = torch.tensor(out_beat).cuda(), torch.tensor(out_beat_timing).cuda(), torch.tensor(out_mask).cuda() #batch, len_beat
- embedded_beat = self.beat_embedding_layer(out_beat, out_beat_timing)
-
- embedded_beat = embedded_beat.repeat_interleave(num_samples_per_prompt, 0)
- out_mask = out_mask.repeat_interleave(num_samples_per_prompt, 0)
-
- uncond_beats = [[[],[]]] * len(beats)
-
- max_length = embedded_beat.shape[1]
- with torch.no_grad():
- out_beat_unc = []
- out_beat_timing_unc = []
- out_mask_unc = []
- for beat in uncond_beats:
- tokenized_beats, tokenized_beats_timing, tokenized_beat_mask = self.beat_tokenizer(beat)
- out_beat_unc.append(tokenized_beats)
- out_beat_timing_unc.append(tokenized_beats_timing)
- out_mask_unc.append(tokenized_beat_mask)
- out_beat_unc, out_beat_timing_unc, out_mask_unc = torch.tensor(out_beat_unc).cuda(), torch.tensor(out_beat_timing_unc).cuda(), torch.tensor(out_mask_unc).cuda() #batch, len_beat
- embedded_beat_unc = self.beat_embedding_layer(out_beat_unc, out_beat_timing_unc)
-
- embedded_beat_unc = embedded_beat_unc.repeat_interleave(num_samples_per_prompt, 0)
- out_mask_unc = out_mask_unc.repeat_interleave(num_samples_per_prompt, 0)
-
- embedded_beat = torch.cat([embedded_beat_unc, embedded_beat])
- out_mask = torch.cat([out_mask_unc, out_mask])
-
- return embedded_beat, out_mask
-
-
- def encode_chords_classifier_free(self, chords, chords_time, num_samples_per_prompt):
-
- with torch.no_grad():
- out_chord_root = []
- out_chord_type = []
- out_chord_inv = []
- out_chord_timing = []
- out_mask = []
- for chord, chord_time in zip(chords,chords_time): #batch loop
- tokenized_chord_root, tokenized_chord_type, tokenized_chord_inv, tokenized_chord_time, tokenized_chord_mask = self.chord_tokenizer(chord, chord_time)
- out_chord_root.append(tokenized_chord_root)
- out_chord_type.append(tokenized_chord_type)
- out_chord_inv.append(tokenized_chord_inv)
- out_chord_timing.append(tokenized_chord_time)
- out_mask.append(tokenized_chord_mask)
- out_chord_root, out_chord_type, out_chord_inv, out_chord_timing, out_mask = torch.tensor(out_chord_root).cuda(), torch.tensor(out_chord_type).cuda(), torch.tensor(out_chord_inv).cuda(), torch.tensor(out_chord_timing).cuda(), torch.tensor(out_mask).cuda()
- embedded_chord = self.chord_embedding_layer(out_chord_root, out_chord_type, out_chord_inv, out_chord_timing)
-
- embedded_chord = embedded_chord.repeat_interleave(num_samples_per_prompt, 0)
- out_mask = out_mask.repeat_interleave(num_samples_per_prompt, 0)
-
- chords_unc=[[]] * len(chords)
- chords_time_unc=[[]] * len(chords_time)
-
- max_length = embedded_chord.shape[1]
-
- with torch.no_grad():
- out_chord_root_unc = []
- out_chord_type_unc = []
- out_chord_inv_unc = []
- out_chord_timing_unc = []
- out_mask_unc = []
- for chord, chord_time in zip(chords_unc,chords_time_unc): #batch loop
- tokenized_chord_root, tokenized_chord_type, tokenized_chord_inv, tokenized_chord_time, tokenized_chord_mask = self.chord_tokenizer(chord, chord_time)
- out_chord_root_unc.append(tokenized_chord_root)
- out_chord_type_unc.append(tokenized_chord_type)
- out_chord_inv_unc.append(tokenized_chord_inv)
- out_chord_timing_unc.append(tokenized_chord_time)
- out_mask_unc.append(tokenized_chord_mask)
- out_chord_root_unc, out_chord_type_unc, out_chord_inv_unc, out_chord_timing_unc, out_mask_unc = torch.tensor(out_chord_root_unc).cuda(), torch.tensor(out_chord_type_unc).cuda(), torch.tensor(out_chord_inv_unc).cuda(), torch.tensor(out_chord_timing_unc).cuda(), torch.tensor(out_mask_unc).cuda()
- embedded_chord_unc = self.chord_embedding_layer(out_chord_root_unc, out_chord_type_unc, out_chord_inv_unc, out_chord_timing_unc)
-
-
- embedded_chord_unc = embedded_chord_unc.repeat_interleave(num_samples_per_prompt, 0)
- out_mask_unc = out_mask_unc.repeat_interleave(num_samples_per_prompt, 0)
-
- embedded_chord = torch.cat([embedded_chord_unc, embedded_chord])
- out_mask = torch.cat([out_mask_unc, out_mask])
-
- return embedded_chord, out_mask
+ def __init__(
+ self,
+ text_encoder_name,
+ scheduler_name,
+ unet_model_name=None,
+ unet_model_config_path=None,
+ snr_gamma=None,
+ freeze_text_encoder=True,
+ uncondition=False,
+
+ d_fme = 1024, #FME
+ fme_type = "se",
+ base = 1,
+ if_trainable = True,
+ translation_bias_type = "nd",
+ emb_nn = True,
+ d_pe = 1024, #PE
+ if_index = True,
+ if_global_timing = True,
+ if_modulo_timing = False,
+ d_beat = 1024, #Beat
+ d_oh_beat_type = 7,
+ beat_len = 50,
+ d_chord = 1024, #Chord
+ d_oh_chord_type = 12,
+ d_oh_inv_type = 4,
+ chord_len = 20,
+
+ ):
+ super().__init__()
+
+        assert unet_model_name is not None or unet_model_config_path is not None, "Either a UNet pretrained model name or a config file path is required"
+
+ self.text_encoder_name = text_encoder_name
+ self.scheduler_name = scheduler_name
+ self.unet_model_name = unet_model_name
+ self.unet_model_config_path = unet_model_config_path
+ self.snr_gamma = snr_gamma
+ self.freeze_text_encoder = freeze_text_encoder
+ self.uncondition = uncondition
+
+ # https://huggingface.co/docs/diffusers/v0.14.0/en/api/schedulers/overview
+ self.noise_scheduler = DDPMScheduler.from_pretrained(self.scheduler_name, subfolder="scheduler")
+ self.inference_scheduler = DDPMScheduler.from_pretrained(self.scheduler_name, subfolder="scheduler")
+
+ if unet_model_config_path:
+ unet_config = UNet2DConditionModelMusic.load_config(unet_model_config_path)
+ self.unet = UNet2DConditionModelMusic.from_config(unet_config, subfolder="unet")
+ self.set_from = "random"
+ print("UNet initialized randomly.")
+ else:
+ self.unet = UNet2DConditionModel.from_pretrained(unet_model_name, subfolder="unet")
+ self.set_from = "pre-trained"
+ self.group_in = nn.Sequential(nn.Linear(8, 512), nn.Linear(512, 4))
+ self.group_out = nn.Sequential(nn.Linear(4, 512), nn.Linear(512, 8))
+ print("UNet initialized from stable diffusion checkpoint.")
+
+ if "stable-diffusion" in self.text_encoder_name:
+ self.tokenizer = CLIPTokenizer.from_pretrained(self.text_encoder_name, subfolder="tokenizer")
+ self.text_encoder = CLIPTextModel.from_pretrained(self.text_encoder_name, subfolder="text_encoder")
+ elif "t5" in self.text_encoder_name:
+ self.tokenizer = AutoTokenizer.from_pretrained(self.text_encoder_name)
+ self.text_encoder = T5EncoderModel.from_pretrained(self.text_encoder_name)
+ else:
+ self.tokenizer = AutoTokenizer.from_pretrained(self.text_encoder_name)
+ self.text_encoder = AutoModel.from_pretrained(self.text_encoder_name)
+
+ self.device = self.text_encoder.device
+ #Music Feature Encoder
+ self.FME = Fundamental_Music_Embedding(d_model = d_fme, base= base, if_trainable = False, type = fme_type,emb_nn=emb_nn,translation_bias_type = translation_bias_type)
+ self.PE = Music_PositionalEncoding(d_model = d_pe, if_index = if_index, if_global_timing = if_global_timing, if_modulo_timing = if_modulo_timing, device = self.device)
+ # self.PE2 = Music_PositionalEncoding(d_model = d_pe, if_index = if_index, if_global_timing = if_global_timing, if_modulo_timing = if_modulo_timing, device = self.device)
+ self.beat_tokenizer = beat_tokenizer(seq_len_beat=beat_len, if_pad = True)
+ self.beat_embedding_layer = Beat_Embedding(self.PE, d_model = d_beat, d_oh_beat_type = d_oh_beat_type)
+ self.chord_embedding_layer = Chord_Embedding(self.FME, self.PE, d_model = d_chord, d_oh_type = d_oh_chord_type, d_oh_inv = d_oh_inv_type)
+ self.chord_tokenizer = chord_tokenizer(seq_len_chord=chord_len, if_pad = True)
+
+
+ def compute_snr(self, timesteps):
+ """
+ Computes SNR as per https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L847-L849
+ """
+ alphas_cumprod = self.noise_scheduler.alphas_cumprod
+ sqrt_alphas_cumprod = alphas_cumprod**0.5
+ sqrt_one_minus_alphas_cumprod = (1.0 - alphas_cumprod) ** 0.5
+
+ # Expand the tensors.
+ # Adapted from https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L1026
+ sqrt_alphas_cumprod = sqrt_alphas_cumprod.to(device=timesteps.device)[timesteps].float()
+ while len(sqrt_alphas_cumprod.shape) < len(timesteps.shape):
+ sqrt_alphas_cumprod = sqrt_alphas_cumprod[..., None]
+ alpha = sqrt_alphas_cumprod.expand(timesteps.shape)
+
+ sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod.to(device=timesteps.device)[timesteps].float()
+ while len(sqrt_one_minus_alphas_cumprod.shape) < len(timesteps.shape):
+ sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod[..., None]
+ sigma = sqrt_one_minus_alphas_cumprod.expand(timesteps.shape)
+
+ # Compute SNR.
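+        # i.e. SNR(t) = alpha_bar_t / (1 - alpha_bar_t) for the DDPM forward process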
+ snr = (alpha / sigma) ** 2
+ return snr
+
+ def encode_text(self, prompt):
+ device = self.text_encoder.device
+ batch = self.tokenizer(
+ prompt, max_length=self.tokenizer.model_max_length, padding=True, truncation=True, return_tensors="pt"
+ )
+ input_ids, attention_mask = batch.input_ids.to(device), batch.attention_mask.to(device) #cuda
+ if self.freeze_text_encoder:
+ with torch.no_grad():
+ encoder_hidden_states = self.text_encoder(
+ input_ids=input_ids, attention_mask=attention_mask
+ )[0] #batch, len_text, dim
+ else:
+ encoder_hidden_states = self.text_encoder(
+ input_ids=input_ids, attention_mask=attention_mask
+ )[0]
+ boolean_encoder_mask = (attention_mask == 1).to(device) ##batch, len_text
+ return encoder_hidden_states, boolean_encoder_mask
+
+ def encode_beats(self, beats):
+ device = self.device
+ out_beat = []
+ out_beat_timing = []
+ out_mask = []
+ for beat in beats:
+ tokenized_beats,tokenized_beats_timing, tokenized_beat_mask = self.beat_tokenizer(beat)
+ out_beat.append(tokenized_beats)
+ out_beat_timing.append(tokenized_beats_timing)
+ out_mask.append(tokenized_beat_mask)
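+        # keep the tokenized beats on the model's device (rather than hard-coded CUDA) so CPU inference also works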
+ out_beat, out_beat_timing, out_mask = torch.tensor(out_beat).to(device), torch.tensor(out_beat_timing).to(device), torch.tensor(out_mask).to(device) #batch, len_beat
+ embedded_beat = self.beat_embedding_layer(out_beat, out_beat_timing, device)
+
+ return embedded_beat, out_mask
+
+ def encode_chords(self, chords,chords_time):
+ device = self.device
+ out_chord_root = []
+ out_chord_type = []
+ out_chord_inv = []
+ out_chord_timing = []
+ out_mask = []
+ for chord, chord_time in zip(chords,chords_time): #batch loop
+ tokenized_chord_root, tokenized_chord_type, tokenized_chord_inv, tokenized_chord_time, tokenized_chord_mask = self.chord_tokenizer(chord, chord_time)
+ out_chord_root.append(tokenized_chord_root)
+ out_chord_type.append(tokenized_chord_type)
+ out_chord_inv.append(tokenized_chord_inv)
+ out_chord_timing.append(tokenized_chord_time)
+ out_mask.append(tokenized_chord_mask)
+ #chords: (B, LEN, 4)
+ out_chord_root, out_chord_type, out_chord_inv, out_chord_timing, out_mask = torch.tensor(out_chord_root).to(device), torch.tensor(out_chord_type).to(device), torch.tensor(out_chord_inv).to(device), torch.tensor(out_chord_timing).to(device), torch.tensor(out_mask).to(device)
+ embedded_chord = self.chord_embedding_layer(out_chord_root, out_chord_type, out_chord_inv, out_chord_timing, device)
+ return embedded_chord, out_mask
+ # return out_chord_root, out_mask
+
+
+ def forward(self, latents, prompt, beats, chords,chords_time, validation_mode=False):
+ device = self.text_encoder.device
+ num_train_timesteps = self.noise_scheduler.num_train_timesteps
+ self.noise_scheduler.set_timesteps(num_train_timesteps, device=device)
+
+ encoder_hidden_states, boolean_encoder_mask = self.encode_text(prompt)
+
+ # with torch.no_grad():
+ encoded_beats, beat_mask = self.encode_beats(beats) #batch, len_beats, dim; batch, len_beats
+ encoded_chords, chord_mask = self.encode_chords(chords,chords_time)
+
+
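+        # drop the text/beat/chord conditioning for ~10% of examples so the model also learns an
+        # unconditional prediction, which classifier-free guidance relies on at inference time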
+ if self.uncondition:
+ mask_indices = [k for k in range(len(prompt)) if random.random() < 0.1]
+ if len(mask_indices) > 0:
+ encoder_hidden_states[mask_indices] = 0
+ encoded_chords[mask_indices] = 0
+ encoded_beats[mask_indices] = 0
+
+ bsz = latents.shape[0]
+
+ if validation_mode:
+ timesteps = (self.noise_scheduler.num_train_timesteps//2) * torch.ones((bsz,), dtype=torch.int64, device=device)
+ else:
+ timesteps = torch.randint(0, self.noise_scheduler.num_train_timesteps, (bsz,), device=device)
+
+
+ timesteps = timesteps.long()
+
+ noise = torch.randn_like(latents)
+ noisy_latents = self.noise_scheduler.add_noise(latents, noise, timesteps)
+
+ # Get the target for loss depending on the prediction type
+ if self.noise_scheduler.config.prediction_type == "epsilon":
+ target = noise
+ elif self.noise_scheduler.config.prediction_type == "v_prediction":
+ target = self.noise_scheduler.get_velocity(latents, noise, timesteps)
+ else:
+ raise ValueError(f"Unknown prediction type {self.noise_scheduler.config.prediction_type}")
+
+ if self.set_from == "random":
+ # model_pred = torch.zeros((bsz,8,256,16)).to(device)
+ model_pred = self.unet(
+ noisy_latents, timesteps, encoder_hidden_states, encoded_beats, encoded_chords,
+ encoder_attention_mask=boolean_encoder_mask, beat_attention_mask = beat_mask, chord_attention_mask = chord_mask
+ ).sample
+
+ elif self.set_from == "pre-trained":
+ compressed_latents = self.group_in(noisy_latents.permute(0, 2, 3, 1).contiguous()).permute(0, 3, 1, 2).contiguous()
+ model_pred = self.unet(
+ compressed_latents, timesteps, encoder_hidden_states,
+ encoder_attention_mask=boolean_encoder_mask
+ ).sample
+ model_pred = self.group_out(model_pred.permute(0, 2, 3, 1).contiguous()).permute(0, 3, 1, 2).contiguous()
+
+ if self.snr_gamma is None:
+ loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
+ else:
+ # Compute loss-weights as per Section 3.4 of https://arxiv.org/abs/2303.09556.
+            # Adapted from huggingface/diffusers/blob/main/examples/text_to_image/train_text_to_image.py
+ snr = self.compute_snr(timesteps)
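+            # Min-SNR weighting: per-sample weight = min(SNR(t), snr_gamma) / SNR(t),
+            # which down-weights low-noise (high-SNR) timesteps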
+ mse_loss_weights = (
+ torch.stack([snr, self.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr
+ )
+ loss = F.mse_loss(model_pred.float(), target.float(), reduction="none")
+ loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights
+ loss = loss.mean()
+
+ return loss
+
+ @torch.no_grad()
+ def inference(self, prompt, beats, chords,chords_time, inference_scheduler, num_steps=20, guidance_scale=3, num_samples_per_prompt=1,
+ disable_progress=True):
+ device = self.text_encoder.device
+ classifier_free_guidance = guidance_scale > 1.0
+ batch_size = len(prompt) * num_samples_per_prompt
+
+ if classifier_free_guidance:
+ prompt_embeds, boolean_prompt_mask = self.encode_text_classifier_free(prompt, num_samples_per_prompt)
+ encoded_beats, beat_mask = self.encode_beats_classifier_free(beats, num_samples_per_prompt) #batch, len_beats, dim; batch, len_beats
+ encoded_chords, chord_mask = self.encode_chords_classifier_free(chords, chords_time, num_samples_per_prompt)
+ else:
+ prompt_embeds, boolean_prompt_mask = self.encode_text(prompt)
+ prompt_embeds = prompt_embeds.repeat_interleave(num_samples_per_prompt, 0)
+ boolean_prompt_mask = boolean_prompt_mask.repeat_interleave(num_samples_per_prompt, 0)
+
+ encoded_beats, beat_mask = self.encode_beats(beats) #batch, len_beats, dim; batch, len_beats
+ encoded_beats = encoded_beats.repeat_interleave(num_samples_per_prompt, 0)
+ beat_mask = beat_mask.repeat_interleave(num_samples_per_prompt, 0)
+
+ encoded_chords, chord_mask = self.encode_chords(chords,chords_time)
+ encoded_chords = encoded_chords.repeat_interleave(num_samples_per_prompt, 0)
+ chord_mask = chord_mask.repeat_interleave(num_samples_per_prompt, 0)
+
+ # print(f"encoded_chords:{encoded_chords.shape}, chord_mask:{chord_mask.shape}, prompt_embeds:{prompt_embeds.shape},boolean_prompt_mask:{boolean_prompt_mask.shape} ")
+ inference_scheduler.set_timesteps(num_steps, device=device)
+ timesteps = inference_scheduler.timesteps
+
+ num_channels_latents = self.unet.in_channels
+ latents = self.prepare_latents(batch_size, inference_scheduler, num_channels_latents, prompt_embeds.dtype, device)
+
+ num_warmup_steps = len(timesteps) - num_steps * inference_scheduler.order
+ progress_bar = tqdm(range(num_steps), disable=disable_progress)
+
+ for i, t in enumerate(timesteps):
+ # expand the latents if we are doing classifier free guidance
+ latent_model_input = torch.cat([latents] * 2) if classifier_free_guidance else latents
+ latent_model_input = inference_scheduler.scale_model_input(latent_model_input, t)
+
+ noise_pred = self.unet(
+ latent_model_input, t, encoder_hidden_states=prompt_embeds,
+ encoder_attention_mask=boolean_prompt_mask,
+ beat_features = encoded_beats, beat_attention_mask = beat_mask, chord_features = encoded_chords,chord_attention_mask = chord_mask
+ ).sample
+
+ # perform guidance
+            if classifier_free_guidance:  # should work for beats and chords too
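+                # the first half of the batch is the unconditional branch, the second half the
+                # text/beat/chord-conditioned branch (matching the concat order in encode_*_classifier_free)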
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latents = inference_scheduler.step(noise_pred, t, latents).prev_sample
+
+ # call the callback, if provided
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % inference_scheduler.order == 0):
+ progress_bar.update(1)
+
+ if self.set_from == "pre-trained":
+ latents = self.group_out(latents.permute(0, 2, 3, 1).contiguous()).permute(0, 3, 1, 2).contiguous()
+ return latents
+
+ def prepare_latents(self, batch_size, inference_scheduler, num_channels_latents, dtype, device):
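+        # latents are drawn on the fixed 256 x 16 grid expected by the audio VAE
+        # (presumably time x frequency of the compressed mel spectrogram)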
+ shape = (batch_size, num_channels_latents, 256, 16)
+ latents = randn_tensor(shape, generator=None, device=device, dtype=dtype)
+ # scale the initial noise by the standard deviation required by the scheduler
+ latents = latents * inference_scheduler.init_noise_sigma
+ return latents
+
+ def encode_text_classifier_free(self, prompt, num_samples_per_prompt):
+ device = self.text_encoder.device
+ batch = self.tokenizer(
+ prompt, max_length=self.tokenizer.model_max_length, padding=True, truncation=True, return_tensors="pt"
+ )
+ input_ids, attention_mask = batch.input_ids.to(device), batch.attention_mask.to(device)
+
+ with torch.no_grad():
+ prompt_embeds = self.text_encoder(
+ input_ids=input_ids, attention_mask=attention_mask
+ )[0]
+
+ prompt_embeds = prompt_embeds.repeat_interleave(num_samples_per_prompt, 0)
+ attention_mask = attention_mask.repeat_interleave(num_samples_per_prompt, 0)
+
+ # get unconditional embeddings for classifier free guidance
+ # print(len(prompt), 'this is prompt len')
+ uncond_tokens = [""] * len(prompt)
+
+ max_length = prompt_embeds.shape[1]
+ uncond_batch = self.tokenizer(
+ uncond_tokens, max_length=max_length, padding="max_length", truncation=True, return_tensors="pt",
+ )
+ uncond_input_ids = uncond_batch.input_ids.to(device)
+ uncond_attention_mask = uncond_batch.attention_mask.to(device)
+
+ with torch.no_grad():
+ negative_prompt_embeds = self.text_encoder(
+ input_ids=uncond_input_ids, attention_mask=uncond_attention_mask
+ )[0]
+
+ negative_prompt_embeds = negative_prompt_embeds.repeat_interleave(num_samples_per_prompt, 0)
+ uncond_attention_mask = uncond_attention_mask.repeat_interleave(num_samples_per_prompt, 0)
+
+ # For classifier free guidance, we need to do two forward passes.
+ # We concatenate the unconditional and text embeddings into a single batch to avoid doing two forward passes
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
+ prompt_mask = torch.cat([uncond_attention_mask, attention_mask])
+ boolean_prompt_mask = (prompt_mask == 1).to(device)
+
+ return prompt_embeds, boolean_prompt_mask
+
+
+ def encode_beats_classifier_free(self, beats, num_samples_per_prompt):
+ device = self.device
+ with torch.no_grad():
+ out_beat = []
+ out_beat_timing = []
+ out_mask = []
+ for beat in beats:
+ tokenized_beats,tokenized_beats_timing, tokenized_beat_mask = self.beat_tokenizer(beat)
+ out_beat.append(tokenized_beats)
+ out_beat_timing.append(tokenized_beats_timing)
+ out_mask.append(tokenized_beat_mask)
+ out_beat, out_beat_timing, out_mask = torch.tensor(out_beat).to(device), torch.tensor(out_beat_timing).to(device), torch.tensor(out_mask).to(device) #batch, len_beat
+ embedded_beat = self.beat_embedding_layer(out_beat, out_beat_timing, device)
+
+ embedded_beat = embedded_beat.repeat_interleave(num_samples_per_prompt, 0)
+ out_mask = out_mask.repeat_interleave(num_samples_per_prompt, 0)
+
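+        # unconditional beat conditioning: empty beat/timing lists for every example,
+        # tokenized and padded the same way as the real beats above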
+ uncond_beats = [[[],[]]] * len(beats)
+
+ max_length = embedded_beat.shape[1]
+ with torch.no_grad():
+ out_beat_unc = []
+ out_beat_timing_unc = []
+ out_mask_unc = []
+ for beat in uncond_beats:
+ tokenized_beats, tokenized_beats_timing, tokenized_beat_mask = self.beat_tokenizer(beat)
+ out_beat_unc.append(tokenized_beats)
+ out_beat_timing_unc.append(tokenized_beats_timing)
+ out_mask_unc.append(tokenized_beat_mask)
+ out_beat_unc, out_beat_timing_unc, out_mask_unc = torch.tensor(out_beat_unc).to(device), torch.tensor(out_beat_timing_unc).to(device), torch.tensor(out_mask_unc).to(device) #batch, len_beat
+ embedded_beat_unc = self.beat_embedding_layer(out_beat_unc, out_beat_timing_unc, device)
+
+ embedded_beat_unc = embedded_beat_unc.repeat_interleave(num_samples_per_prompt, 0)
+ out_mask_unc = out_mask_unc.repeat_interleave(num_samples_per_prompt, 0)
+
+ embedded_beat = torch.cat([embedded_beat_unc, embedded_beat])
+ out_mask = torch.cat([out_mask_unc, out_mask])
+
+ return embedded_beat, out_mask
+
+
+ def encode_chords_classifier_free(self, chords, chords_time, num_samples_per_prompt):
+ device = self.device
+ with torch.no_grad():
+ out_chord_root = []
+ out_chord_type = []
+ out_chord_inv = []
+ out_chord_timing = []
+ out_mask = []
+ for chord, chord_time in zip(chords,chords_time): #batch loop
+ tokenized_chord_root, tokenized_chord_type, tokenized_chord_inv, tokenized_chord_time, tokenized_chord_mask = self.chord_tokenizer(chord, chord_time)
+ out_chord_root.append(tokenized_chord_root)
+ out_chord_type.append(tokenized_chord_type)
+ out_chord_inv.append(tokenized_chord_inv)
+ out_chord_timing.append(tokenized_chord_time)
+ out_mask.append(tokenized_chord_mask)
+ out_chord_root, out_chord_type, out_chord_inv, out_chord_timing, out_mask = torch.tensor(out_chord_root).to(device), torch.tensor(out_chord_type).to(device), torch.tensor(out_chord_inv).to(device), torch.tensor(out_chord_timing).to(device), torch.tensor(out_mask).to(device)
+ embedded_chord = self.chord_embedding_layer(out_chord_root, out_chord_type, out_chord_inv, out_chord_timing, device)
+
+ embedded_chord = embedded_chord.repeat_interleave(num_samples_per_prompt, 0)
+ out_mask = out_mask.repeat_interleave(num_samples_per_prompt, 0)
+
+ chords_unc=[[]] * len(chords)
+ chords_time_unc=[[]] * len(chords_time)
+
+ max_length = embedded_chord.shape[1]
+
+ with torch.no_grad():
+ out_chord_root_unc = []
+ out_chord_type_unc = []
+ out_chord_inv_unc = []
+ out_chord_timing_unc = []
+ out_mask_unc = []
+ for chord, chord_time in zip(chords_unc,chords_time_unc): #batch loop
+ tokenized_chord_root, tokenized_chord_type, tokenized_chord_inv, tokenized_chord_time, tokenized_chord_mask = self.chord_tokenizer(chord, chord_time)
+ out_chord_root_unc.append(tokenized_chord_root)
+ out_chord_type_unc.append(tokenized_chord_type)
+ out_chord_inv_unc.append(tokenized_chord_inv)
+ out_chord_timing_unc.append(tokenized_chord_time)
+ out_mask_unc.append(tokenized_chord_mask)
+ out_chord_root_unc, out_chord_type_unc, out_chord_inv_unc, out_chord_timing_unc, out_mask_unc = torch.tensor(out_chord_root_unc).to(device), torch.tensor(out_chord_type_unc).to(device), torch.tensor(out_chord_inv_unc).to(device), torch.tensor(out_chord_timing_unc).to(device), torch.tensor(out_mask_unc).to(device)
+ embedded_chord_unc = self.chord_embedding_layer(out_chord_root_unc, out_chord_type_unc, out_chord_inv_unc, out_chord_timing_unc, device)
+
+
+ embedded_chord_unc = embedded_chord_unc.repeat_interleave(num_samples_per_prompt, 0)
+ out_mask_unc = out_mask_unc.repeat_interleave(num_samples_per_prompt, 0)
+
+ embedded_chord = torch.cat([embedded_chord_unc, embedded_chord])
+ out_mask = torch.cat([out_mask_unc, out_mask])
+
+ return embedded_chord, out_mask
diff --git a/requirements.txt b/requirements.txt
index bf57c2bb0f671f329df16ad09dc37fcd9f71a6f9..d57857ffe6f63d2840feb944b4acd123e2e84c71 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,12 +1,12 @@
-torch==1.13.1
-torchaudio==0.13.1
-torchvision==0.14.1
-transformers==4.27.0
-accelerate==0.18.0
+torch==2.0.1
+torchaudio==2.0.2
+torchvision==0.15.2
+transformers==4.31.0
+accelerate==0.21.0
datasets==2.1.0
einops==0.6.1
h5py==3.8.0
-huggingface_hub==0.13.3
+huggingface_hub==0.19.4
importlib_metadata==6.3.0
librosa==0.9.2
matplotlib==3.5.2
@@ -17,6 +17,7 @@ pandas==1.4.1
progressbar33==2.4
protobuf==3.20.*
resampy==0.4.2
+safetensors==0.3.2
sentencepiece==0.1.99
scikit_image==0.19.3
scikit_learn==1.2.2
diff --git a/tools/__pycache__/__init__.cpython-310.pyc b/tools/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..57f53b8354644afb3250bd5295869db079e6d457
Binary files /dev/null and b/tools/__pycache__/__init__.cpython-310.pyc differ
diff --git a/tools/__pycache__/mix.cpython-310.pyc b/tools/__pycache__/mix.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..7887d5969502cd5c9e5604ec3e9221a5ae2e5420
Binary files /dev/null and b/tools/__pycache__/mix.cpython-310.pyc differ
diff --git a/tools/__pycache__/torch_tools.cpython-310.pyc b/tools/__pycache__/torch_tools.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..9268f43d41fb5c457f4c490340fb618fc7c1b79e
Binary files /dev/null and b/tools/__pycache__/torch_tools.cpython-310.pyc differ