Atin Sakkeer Hussain committed on
Commit 795ce43 • 1 Parent(s): eae7a25
.idea/.gitignore ADDED
@@ -0,0 +1,3 @@
+ # Default ignored files
+ /shelf/
+ /workspace.xml
.idea/M2UGen-Demo.iml ADDED
@@ -0,0 +1,12 @@
+ <?xml version="1.0" encoding="UTF-8"?>
+ <module type="PYTHON_MODULE" version="4">
+ <component name="NewModuleRootManager">
+ <content url="file://$MODULE_DIR$" />
+ <orderEntry type="inheritedJdk" />
+ <orderEntry type="sourceFolder" forTests="false" />
+ </component>
+ <component name="PyDocumentationSettings">
+ <option name="format" value="GOOGLE" />
+ <option name="myDocStringFormat" value="Google" />
+ </component>
+ </module>
.idea/inspectionProfiles/Project_Default.xml ADDED
@@ -0,0 +1,56 @@
+ <component name="InspectionProjectProfileManager">
+ <profile version="1.0">
+ <option name="myName" value="Project Default" />
+ <inspection_tool class="PyPackageRequirementsInspection" enabled="true" level="WARNING" enabled_by_default="true">
+ <option name="ignoredPackages">
+ <value>
+ <list size="18">
+ <item index="0" class="java.lang.String" itemvalue="pandas" />
+ <item index="1" class="java.lang.String" itemvalue="tqdm" />
+ <item index="2" class="java.lang.String" itemvalue="absl-py" />
+ <item index="3" class="java.lang.String" itemvalue="dgl" />
+ <item index="4" class="java.lang.String" itemvalue="torch" />
+ <item index="5" class="java.lang.String" itemvalue="numpy" />
+ <item index="6" class="java.lang.String" itemvalue="Cython" />
+ <item index="7" class="java.lang.String" itemvalue="torchlibrosa" />
+ <item index="8" class="java.lang.String" itemvalue="gdown" />
+ <item index="9" class="java.lang.String" itemvalue="wget" />
+ <item index="10" class="java.lang.String" itemvalue="accelerate" />
+ <item index="11" class="java.lang.String" itemvalue="transformers" />
+ <item index="12" class="java.lang.String" itemvalue="gradio" />
+ <item index="13" class="java.lang.String" itemvalue="tensorboard" />
+ <item index="14" class="java.lang.String" itemvalue="diffusers" />
+ <item index="15" class="java.lang.String" itemvalue="opencv-python" />
+ <item index="16" class="java.lang.String" itemvalue="huggingface_hub" />
+ <item index="17" class="java.lang.String" itemvalue="Pillow" />
+ </list>
+ </value>
+ </option>
+ </inspection_tool>
+ <inspection_tool class="PyPep8Inspection" enabled="true" level="WEAK WARNING" enabled_by_default="true">
+ <option name="ignoredErrors">
+ <list>
+ <option value="W605" />
+ </list>
+ </option>
+ </inspection_tool>
+ <inspection_tool class="PyPep8NamingInspection" enabled="true" level="WEAK WARNING" enabled_by_default="true">
+ <option name="ignoredErrors">
+ <list>
+ <option value="N806" />
+ <option value="N802" />
+ <option value="N803" />
+ </list>
+ </option>
+ </inspection_tool>
+ <inspection_tool class="PyUnresolvedReferencesInspection" enabled="true" level="WARNING" enabled_by_default="true">
+ <option name="ignoredIdentifiers">
+ <list>
+ <option value="tokenizers.BertWordPieceTokenizer" />
+ <option value="cv2.aruco" />
+ <option value="llama" />
+ </list>
+ </option>
+ </inspection_tool>
+ </profile>
+ </component>
.idea/inspectionProfiles/profiles_settings.xml ADDED
@@ -0,0 +1,6 @@
+ <component name="InspectionProjectProfileManager">
+ <settings>
+ <option name="USE_PROJECT_PROFILE" value="false" />
+ <version value="1.0" />
+ </settings>
+ </component>
.idea/misc.xml ADDED
@@ -0,0 +1,4 @@
+ <?xml version="1.0" encoding="UTF-8"?>
+ <project version="4">
+ <component name="ProjectRootManager" version="2" project-jdk-name="Python 3.8 (AudioCaption)" project-jdk-type="Python SDK" />
+ </project>
.idea/modules.xml ADDED
@@ -0,0 +1,8 @@
+ <?xml version="1.0" encoding="UTF-8"?>
+ <project version="4">
+ <component name="ProjectModuleManager">
+ <modules>
+ <module fileurl="file://$PROJECT_DIR$/.idea/M2UGen-Demo.iml" filepath="$PROJECT_DIR$/.idea/M2UGen-Demo.iml" />
+ </modules>
+ </component>
+ </project>
.idea/vcs.xml ADDED
@@ -0,0 +1,6 @@
+ <?xml version="1.0" encoding="UTF-8"?>
+ <project version="4">
+ <component name="VcsDirectoryMappings">
+ <mapping directory="" vcs="Git" />
+ </component>
+ </project>
llama/__init__.py ADDED
@@ -0,0 +1,4 @@
+ from .llama import ModelArgs, Transformer
+ from .tokenizer import Tokenizer
+ from .m2ugen import *
+ from .utils import format_prompt
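The package re-exports the LLaMA model definition, the SentencePiece tokenizer, the M2UGen wrapper, and a prompt-formatting helper. A minimal usage sketch follows; the checkpoint path and the single-argument form of format_prompt are assumptions, not confirmed by this commit.

# Hypothetical sketch: tokenizer path and format_prompt signature are assumed.
from llama import Tokenizer, format_prompt

tokenizer = Tokenizer(model_path="./ckpts/LLaMA/tokenizer.model")  # assumed local checkpoint path
prompt = format_prompt("Describe the mood of this piece of music.")  # assumed single-argument helper
tokens = tokenizer.encode(prompt, bos=True, eos=False)  # standard LLaMA SentencePiece tokenizer API
print(len(tokens))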
llama/audioldm2/__init__.py ADDED
@@ -0,0 +1 @@
+ from .pipeline_audioldm2 import AudioLDM2Pipeline
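The vendored AudioLDM2Pipeline mirrors the upstream diffusers text-to-audio pipeline. Assuming it keeps the upstream interface, a minimal sketch of loading the public cvssp/audioldm2 checkpoint and generating a clip looks like this; the checkpoint name, step count, and clip length are illustrative.

import torch
from scipy.io import wavfile
from llama.audioldm2 import AudioLDM2Pipeline

# Sketch assuming the vendored pipeline keeps the upstream diffusers interface.
pipe = AudioLDM2Pipeline.from_pretrained("cvssp/audioldm2", torch_dtype=torch.float16).to("cuda")
prompt = "Techno music with a strong, upbeat tempo and high melodic riffs"
audio = pipe(prompt, num_inference_steps=200, audio_length_in_s=10.0).audios[0]
wavfile.write("techno.wav", rate=16000, data=audio)  # AudioLDM2 generates 16 kHz audio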
llama/audioldm2/modeling_audioldm2.py ADDED
@@ -0,0 +1,1513 @@
1
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from dataclasses import dataclass
16
+ from typing import Any, Dict, List, Optional, Tuple, Union
17
+
18
+ import torch
19
+ import torch.nn as nn
20
+ import torch.utils.checkpoint
21
+
22
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
23
+ from diffusers.loaders import UNet2DConditionLoadersMixin
24
+ from diffusers.models.activations import get_activation
25
+ from diffusers.models.attention_processor import (
26
+ ADDED_KV_ATTENTION_PROCESSORS,
27
+ CROSS_ATTENTION_PROCESSORS,
28
+ AttentionProcessor,
29
+ AttnAddedKVProcessor,
30
+ AttnProcessor,
31
+ )
32
+ from diffusers.models.embeddings import (
33
+ TimestepEmbedding,
34
+ Timesteps,
35
+ )
36
+ from diffusers.models.modeling_utils import ModelMixin
37
+ from diffusers.models.resnet import Downsample2D, ResnetBlock2D, Upsample2D
38
+ from diffusers.models.transformer_2d import Transformer2DModel
39
+ from diffusers.models.unet_2d_blocks import DownBlock2D, UpBlock2D
40
+ from diffusers.models.unet_2d_condition import UNet2DConditionOutput
41
+ from diffusers.utils import BaseOutput, is_torch_version, logging
42
+
43
+
44
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
45
+
46
+
47
+ def add_special_tokens(hidden_states, attention_mask, sos_token, eos_token):
48
+ batch_size = hidden_states.shape[0]
49
+
50
+ if attention_mask is not None:
51
+ # Add two more steps to attn mask
52
+ new_attn_mask_step = attention_mask.new_ones((batch_size, 1))
53
+ attention_mask = torch.concat([new_attn_mask_step, attention_mask, new_attn_mask_step], dim=-1)
54
+
55
+ # Add the SOS / EOS tokens at the start / end of the sequence respectively
56
+ sos_token = sos_token.expand(batch_size, 1, -1)
57
+ eos_token = eos_token.expand(batch_size, 1, -1)
58
+ hidden_states = torch.concat([sos_token, hidden_states, eos_token], dim=1)
59
+ return hidden_states, attention_mask
60
+
61
+
62
+ @dataclass
63
+ class AudioLDM2ProjectionModelOutput(BaseOutput):
64
+ """
65
+ Args:
66
+ Class for AudioLDM2 projection layer's outputs.
67
+ hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
68
+ Sequence of hidden-states obtained by linearly projecting the hidden-states for each of the text
69
+ encoders and subsequently concatenating them together.
70
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
71
+ Mask to avoid performing attention on padding token indices, formed by concatenating the attention masks
72
+ for the two text encoders together. Mask values selected in `[0, 1]`:
73
+
74
+ - 1 for tokens that are **not masked**,
75
+ - 0 for tokens that are **masked**.
76
+ """
77
+
78
+ hidden_states: torch.FloatTensor
79
+ attention_mask: Optional[torch.LongTensor] = None
80
+
81
+
82
+ class AudioLDM2ProjectionModel(ModelMixin, ConfigMixin):
83
+ """
84
+ A simple linear projection model to map two text embeddings to a shared latent space. It also inserts learned
85
+ embedding vectors at the start and end of each text embedding sequence respectively. Each variable appended with
86
+ `_1` refers to that corresponding to the second text encoder. Otherwise, it is from the first.
87
+
88
+ Args:
89
+ text_encoder_dim (`int`):
90
+ Dimensionality of the text embeddings from the first text encoder (CLAP).
91
+ text_encoder_1_dim (`int`):
92
+ Dimensionality of the text embeddings from the second text encoder (T5 or VITS).
93
+ langauge_model_dim (`int`):
94
+ Dimensionality of the text embeddings from the language model (GPT2).
95
+ """
96
+
97
+ @register_to_config
98
+ def __init__(self, text_encoder_dim, text_encoder_1_dim, langauge_model_dim):
99
+ super().__init__()
100
+ # additional projection layers for each text encoder
101
+ self.projection = nn.Linear(text_encoder_dim, langauge_model_dim)
102
+ self.projection_1 = nn.Linear(text_encoder_1_dim, langauge_model_dim)
103
+
104
+ # learnable SOS / EOS token embeddings for each text encoder
105
+ self.sos_embed = nn.Parameter(torch.ones(langauge_model_dim))
106
+ self.eos_embed = nn.Parameter(torch.ones(langauge_model_dim))
107
+
108
+ self.sos_embed_1 = nn.Parameter(torch.ones(langauge_model_dim))
109
+ self.eos_embed_1 = nn.Parameter(torch.ones(langauge_model_dim))
110
+
111
+ def forward(
112
+ self,
113
+ hidden_states: Optional[torch.FloatTensor] = None,
114
+ hidden_states_1: Optional[torch.FloatTensor] = None,
115
+ attention_mask: Optional[torch.LongTensor] = None,
116
+ attention_mask_1: Optional[torch.LongTensor] = None,
117
+ ):
118
+ hidden_states = self.projection(hidden_states)
119
+ hidden_states, attention_mask = add_special_tokens(
120
+ hidden_states, attention_mask, sos_token=self.sos_embed, eos_token=self.eos_embed
121
+ )
122
+
123
+ hidden_states_1 = self.projection_1(hidden_states_1)
124
+ hidden_states_1, attention_mask_1 = add_special_tokens(
125
+ hidden_states_1, attention_mask_1, sos_token=self.sos_embed_1, eos_token=self.eos_embed_1
126
+ )
127
+
128
+ # concatenate clap and t5 text encoding
129
+ hidden_states = torch.cat([hidden_states, hidden_states_1], dim=1)
130
+
131
+ # concatenate attention masks
132
+ if attention_mask is None and attention_mask_1 is not None:
133
+ attention_mask = attention_mask_1.new_ones((hidden_states[:2]))
134
+ elif attention_mask is not None and attention_mask_1 is None:
135
+ attention_mask_1 = attention_mask.new_ones((hidden_states_1[:2]))
136
+
137
+ if attention_mask is not None and attention_mask_1 is not None:
138
+ attention_mask = torch.cat([attention_mask, attention_mask_1], dim=-1)
139
+ else:
140
+ attention_mask = None
141
+
142
+ return AudioLDM2ProjectionModelOutput(
143
+ hidden_states=hidden_states,
144
+ attention_mask=attention_mask,
145
+ )
146
+
147
+
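A minimal shape sketch of the AudioLDM2ProjectionModel defined above: each text-encoder stream is projected to the language-model width, gains learned SOS/EOS embeddings (two extra sequence positions), and the two streams are concatenated along the sequence dimension. The dimensions below are illustrative only, not taken from a specific checkpoint.

import torch
from llama.audioldm2.modeling_audioldm2 import AudioLDM2ProjectionModel

# Illustrative dims: CLAP-like encoder (512), T5-like encoder (1024), GPT-2-like LM (768).
proj = AudioLDM2ProjectionModel(text_encoder_dim=512, text_encoder_1_dim=1024, langauge_model_dim=768)
clap_states = torch.randn(2, 1, 512)   # pooled CLAP embedding, one token per example
t5_states = torch.randn(2, 7, 1024)    # T5 token sequence
out = proj(hidden_states=clap_states, hidden_states_1=t5_states)
print(out.hidden_states.shape)  # torch.Size([2, 12, 768]): (1 + 2) + (7 + 2) tokens at LM width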
148
+ class AudioLDM2UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
149
+ r"""
150
+ A conditional 2D UNet model that takes a noisy sample, conditional state, and a timestep and returns a sample
151
+ shaped output. Compared to the vanilla [`UNet2DConditionModel`], this variant optionally includes an additional
152
+ self-attention layer in each Transformer block, as well as multiple cross-attention layers. It also allows for up
153
+ to two cross-attention embeddings, `encoder_hidden_states` and `encoder_hidden_states_1`.
154
+
155
+ This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented
156
+ for all models (such as downloading or saving).
157
+
158
+ Parameters:
159
+ sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`):
160
+ Height and width of input/output sample.
161
+ in_channels (`int`, *optional*, defaults to 4): Number of channels in the input sample.
162
+ out_channels (`int`, *optional*, defaults to 4): Number of channels in the output.
163
+ flip_sin_to_cos (`bool`, *optional*, defaults to `False`):
164
+ Whether to flip the sin to cos in the time embedding.
165
+ freq_shift (`int`, *optional*, defaults to 0): The frequency shift to apply to the time embedding.
166
+ down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`):
167
+ The tuple of downsample blocks to use.
168
+ mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock2DCrossAttn"`):
169
+ Block type for middle of UNet, it can only be `UNetMidBlock2DCrossAttn` for AudioLDM2.
170
+ up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D")`):
171
+ The tuple of upsample blocks to use.
172
+ only_cross_attention (`bool` or `Tuple[bool]`, *optional*, default to `False`):
173
+ Whether to include self-attention in the basic transformer blocks, see
174
+ [`~models.attention.BasicTransformerBlock`].
175
+ block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`):
176
+ The tuple of output channels for each block.
177
+ layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block.
178
+ downsample_padding (`int`, *optional*, defaults to 1): The padding to use for the downsampling convolution.
179
+ mid_block_scale_factor (`float`, *optional*, defaults to 1.0): The scale factor to use for the mid block.
180
+ act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
181
+ norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for the normalization.
182
+ If `None`, normalization and activation layers is skipped in post-processing.
183
+ norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization.
184
+ cross_attention_dim (`int` or `Tuple[int]`, *optional*, defaults to 1280):
185
+ The dimension of the cross attention features.
186
+ transformer_layers_per_block (`int` or `Tuple[int]`, *optional*, defaults to 1):
187
+ The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for
188
+ [`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`],
189
+ [`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`].
190
+ attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads.
191
+ num_attention_heads (`int`, *optional*):
192
+ The number of attention heads. If not defined, defaults to `attention_head_dim`
193
+ resnet_time_scale_shift (`str`, *optional*, defaults to `"default"`): Time scale shift config
194
+ for ResNet blocks (see [`~models.resnet.ResnetBlock2D`]). Choose from `default` or `scale_shift`.
195
+ class_embed_type (`str`, *optional*, defaults to `None`):
196
+ The type of class embedding to use which is ultimately summed with the time embeddings. Choose from `None`,
197
+ `"timestep"`, `"identity"`, `"projection"`, or `"simple_projection"`.
198
+ num_class_embeds (`int`, *optional*, defaults to `None`):
199
+ Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing
200
+ class conditioning with `class_embed_type` equal to `None`.
201
+ time_embedding_type (`str`, *optional*, defaults to `positional`):
202
+ The type of position embedding to use for timesteps. Choose from `positional` or `fourier`.
203
+ time_embedding_dim (`int`, *optional*, defaults to `None`):
204
+ An optional override for the dimension of the projected time embedding.
205
+ time_embedding_act_fn (`str`, *optional*, defaults to `None`):
206
+ Optional activation function to use only once on the time embeddings before they are passed to the rest of
207
+ the UNet. Choose from `silu`, `mish`, `gelu`, and `swish`.
208
+ timestep_post_act (`str`, *optional*, defaults to `None`):
209
+ The second activation function to use in timestep embedding. Choose from `silu`, `mish` and `gelu`.
210
+ time_cond_proj_dim (`int`, *optional*, defaults to `None`):
211
+ The dimension of `cond_proj` layer in the timestep embedding.
212
+ conv_in_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_in` layer.
213
+ conv_out_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_out` layer.
214
+ projection_class_embeddings_input_dim (`int`, *optional*): The dimension of the `class_labels` input when
215
+ `class_embed_type="projection"`. Required when `class_embed_type="projection"`.
216
+ class_embeddings_concat (`bool`, *optional*, defaults to `False`): Whether to concatenate the time
217
+ embeddings with the class embeddings.
218
+ """
219
+
220
+ _supports_gradient_checkpointing = True
221
+
222
+ @register_to_config
223
+ def __init__(
224
+ self,
225
+ sample_size: Optional[int] = None,
226
+ in_channels: int = 4,
227
+ out_channels: int = 4,
228
+ flip_sin_to_cos: bool = True,
229
+ freq_shift: int = 0,
230
+ down_block_types: Tuple[str] = (
231
+ "CrossAttnDownBlock2D",
232
+ "CrossAttnDownBlock2D",
233
+ "CrossAttnDownBlock2D",
234
+ "DownBlock2D",
235
+ ),
236
+ mid_block_type: Optional[str] = "UNetMidBlock2DCrossAttn",
237
+ up_block_types: Tuple[str] = ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"),
238
+ only_cross_attention: Union[bool, Tuple[bool]] = False,
239
+ block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
240
+ layers_per_block: Union[int, Tuple[int]] = 2,
241
+ downsample_padding: int = 1,
242
+ mid_block_scale_factor: float = 1,
243
+ act_fn: str = "silu",
244
+ norm_num_groups: Optional[int] = 32,
245
+ norm_eps: float = 1e-5,
246
+ cross_attention_dim: Union[int, Tuple[int]] = 1280,
247
+ transformer_layers_per_block: Union[int, Tuple[int]] = 1,
248
+ attention_head_dim: Union[int, Tuple[int]] = 8,
249
+ num_attention_heads: Optional[Union[int, Tuple[int]]] = None,
250
+ use_linear_projection: bool = False,
251
+ class_embed_type: Optional[str] = None,
252
+ num_class_embeds: Optional[int] = None,
253
+ upcast_attention: bool = False,
254
+ resnet_time_scale_shift: str = "default",
255
+ time_embedding_type: str = "positional",
256
+ time_embedding_dim: Optional[int] = None,
257
+ time_embedding_act_fn: Optional[str] = None,
258
+ timestep_post_act: Optional[str] = None,
259
+ time_cond_proj_dim: Optional[int] = None,
260
+ conv_in_kernel: int = 3,
261
+ conv_out_kernel: int = 3,
262
+ projection_class_embeddings_input_dim: Optional[int] = None,
263
+ class_embeddings_concat: bool = False,
264
+ ):
265
+ super().__init__()
266
+
267
+ self.sample_size = sample_size
268
+
269
+ if num_attention_heads is not None:
270
+ raise ValueError(
271
+ "At the moment it is not possible to define the number of attention heads via `num_attention_heads` because of a naming issue as described in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131. Passing `num_attention_heads` will only be supported in diffusers v0.19."
272
+ )
273
+
274
+ # If `num_attention_heads` is not defined (which is the case for most models)
275
+ # it will default to `attention_head_dim`. This looks weird upon first reading it and it is.
276
+ # The reason for this behavior is to correct for incorrectly named variables that were introduced
277
+ # when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131
278
+ # Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking
279
+ # which is why we correct for the naming here.
280
+ num_attention_heads = num_attention_heads or attention_head_dim
281
+
282
+ # Check inputs
283
+ if len(down_block_types) != len(up_block_types):
284
+ raise ValueError(
285
+ f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}."
286
+ )
287
+
288
+ if len(block_out_channels) != len(down_block_types):
289
+ raise ValueError(
290
+ f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
291
+ )
292
+
293
+ if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types):
294
+ raise ValueError(
295
+ f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}."
296
+ )
297
+
298
+ if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types):
299
+ raise ValueError(
300
+ f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}."
301
+ )
302
+
303
+ if not isinstance(attention_head_dim, int) and len(attention_head_dim) != len(down_block_types):
304
+ raise ValueError(
305
+ f"Must provide the same number of `attention_head_dim` as `down_block_types`. `attention_head_dim`: {attention_head_dim}. `down_block_types`: {down_block_types}."
306
+ )
307
+
308
+ if isinstance(cross_attention_dim, list) and len(cross_attention_dim) != len(down_block_types):
309
+ raise ValueError(
310
+ f"Must provide the same number of `cross_attention_dim` as `down_block_types`. `cross_attention_dim`: {cross_attention_dim}. `down_block_types`: {down_block_types}."
311
+ )
312
+
313
+ if not isinstance(layers_per_block, int) and len(layers_per_block) != len(down_block_types):
314
+ raise ValueError(
315
+ f"Must provide the same number of `layers_per_block` as `down_block_types`. `layers_per_block`: {layers_per_block}. `down_block_types`: {down_block_types}."
316
+ )
317
+
318
+ # input
319
+ conv_in_padding = (conv_in_kernel - 1) // 2
320
+ self.conv_in = nn.Conv2d(
321
+ in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding
322
+ )
323
+
324
+ # time
325
+ if time_embedding_type == "positional":
326
+ time_embed_dim = time_embedding_dim or block_out_channels[0] * 4
327
+
328
+ self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
329
+ timestep_input_dim = block_out_channels[0]
330
+ else:
331
+ raise ValueError(f"{time_embedding_type} does not exist. Please make sure to use `positional`.")
332
+
333
+ self.time_embedding = TimestepEmbedding(
334
+ timestep_input_dim,
335
+ time_embed_dim,
336
+ act_fn=act_fn,
337
+ post_act_fn=timestep_post_act,
338
+ cond_proj_dim=time_cond_proj_dim,
339
+ )
340
+
341
+ # class embedding
342
+ if class_embed_type is None and num_class_embeds is not None:
343
+ self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
344
+ elif class_embed_type == "timestep":
345
+ self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim, act_fn=act_fn)
346
+ elif class_embed_type == "identity":
347
+ self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
348
+ elif class_embed_type == "projection":
349
+ if projection_class_embeddings_input_dim is None:
350
+ raise ValueError(
351
+ "`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set"
352
+ )
353
+ # The projection `class_embed_type` is the same as the timestep `class_embed_type` except
354
+ # 1. the `class_labels` inputs are not first converted to sinusoidal embeddings
355
+ # 2. it projects from an arbitrary input dimension.
356
+ #
357
+ # Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations.
358
+ # When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings.
359
+ # As a result, `TimestepEmbedding` can be passed arbitrary vectors.
360
+ self.class_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
361
+ elif class_embed_type == "simple_projection":
362
+ if projection_class_embeddings_input_dim is None:
363
+ raise ValueError(
364
+ "`class_embed_type`: 'simple_projection' requires `projection_class_embeddings_input_dim` be set"
365
+ )
366
+ self.class_embedding = nn.Linear(projection_class_embeddings_input_dim, time_embed_dim)
367
+ else:
368
+ self.class_embedding = None
369
+
370
+ if time_embedding_act_fn is None:
371
+ self.time_embed_act = None
372
+ else:
373
+ self.time_embed_act = get_activation(time_embedding_act_fn)
374
+
375
+ self.down_blocks = nn.ModuleList([])
376
+ self.up_blocks = nn.ModuleList([])
377
+
378
+ if isinstance(only_cross_attention, bool):
379
+ only_cross_attention = [only_cross_attention] * len(down_block_types)
380
+
381
+ if isinstance(num_attention_heads, int):
382
+ num_attention_heads = (num_attention_heads,) * len(down_block_types)
383
+
384
+ if isinstance(cross_attention_dim, int):
385
+ cross_attention_dim = (cross_attention_dim,) * len(down_block_types)
386
+
387
+ if isinstance(layers_per_block, int):
388
+ layers_per_block = [layers_per_block] * len(down_block_types)
389
+
390
+ if isinstance(transformer_layers_per_block, int):
391
+ transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types)
392
+
393
+ if class_embeddings_concat:
394
+ # The time embeddings are concatenated with the class embeddings. The dimension of the
395
+ # time embeddings passed to the down, middle, and up blocks is twice the dimension of the
396
+ # regular time embeddings
397
+ blocks_time_embed_dim = time_embed_dim * 2
398
+ else:
399
+ blocks_time_embed_dim = time_embed_dim
400
+
401
+ # down
402
+ output_channel = block_out_channels[0]
403
+ for i, down_block_type in enumerate(down_block_types):
404
+ input_channel = output_channel
405
+ output_channel = block_out_channels[i]
406
+ is_final_block = i == len(block_out_channels) - 1
407
+
408
+ down_block = get_down_block(
409
+ down_block_type,
410
+ num_layers=layers_per_block[i],
411
+ transformer_layers_per_block=transformer_layers_per_block[i],
412
+ in_channels=input_channel,
413
+ out_channels=output_channel,
414
+ temb_channels=blocks_time_embed_dim,
415
+ add_downsample=not is_final_block,
416
+ resnet_eps=norm_eps,
417
+ resnet_act_fn=act_fn,
418
+ resnet_groups=norm_num_groups,
419
+ cross_attention_dim=cross_attention_dim[i],
420
+ num_attention_heads=num_attention_heads[i],
421
+ downsample_padding=downsample_padding,
422
+ use_linear_projection=use_linear_projection,
423
+ only_cross_attention=only_cross_attention[i],
424
+ upcast_attention=upcast_attention,
425
+ resnet_time_scale_shift=resnet_time_scale_shift,
426
+ )
427
+ self.down_blocks.append(down_block)
428
+
429
+ # mid
430
+ if mid_block_type == "UNetMidBlock2DCrossAttn":
431
+ self.mid_block = UNetMidBlock2DCrossAttn(
432
+ transformer_layers_per_block=transformer_layers_per_block[-1],
433
+ in_channels=block_out_channels[-1],
434
+ temb_channels=blocks_time_embed_dim,
435
+ resnet_eps=norm_eps,
436
+ resnet_act_fn=act_fn,
437
+ output_scale_factor=mid_block_scale_factor,
438
+ resnet_time_scale_shift=resnet_time_scale_shift,
439
+ cross_attention_dim=cross_attention_dim[-1],
440
+ num_attention_heads=num_attention_heads[-1],
441
+ resnet_groups=norm_num_groups,
442
+ use_linear_projection=use_linear_projection,
443
+ upcast_attention=upcast_attention,
444
+ )
445
+ else:
446
+ raise ValueError(
447
+ f"unknown mid_block_type : {mid_block_type}. Should be `UNetMidBlock2DCrossAttn` for AudioLDM2."
448
+ )
449
+
450
+ # count how many layers upsample the images
451
+ self.num_upsamplers = 0
452
+
453
+ # up
454
+ reversed_block_out_channels = list(reversed(block_out_channels))
455
+ reversed_num_attention_heads = list(reversed(num_attention_heads))
456
+ reversed_layers_per_block = list(reversed(layers_per_block))
457
+ reversed_cross_attention_dim = list(reversed(cross_attention_dim))
458
+ reversed_transformer_layers_per_block = list(reversed(transformer_layers_per_block))
459
+ only_cross_attention = list(reversed(only_cross_attention))
460
+
461
+ output_channel = reversed_block_out_channels[0]
462
+ for i, up_block_type in enumerate(up_block_types):
463
+ is_final_block = i == len(block_out_channels) - 1
464
+
465
+ prev_output_channel = output_channel
466
+ output_channel = reversed_block_out_channels[i]
467
+ input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]
468
+
469
+ # add upsample block for all BUT final layer
470
+ if not is_final_block:
471
+ add_upsample = True
472
+ self.num_upsamplers += 1
473
+ else:
474
+ add_upsample = False
475
+
476
+ up_block = get_up_block(
477
+ up_block_type,
478
+ num_layers=reversed_layers_per_block[i] + 1,
479
+ transformer_layers_per_block=reversed_transformer_layers_per_block[i],
480
+ in_channels=input_channel,
481
+ out_channels=output_channel,
482
+ prev_output_channel=prev_output_channel,
483
+ temb_channels=blocks_time_embed_dim,
484
+ add_upsample=add_upsample,
485
+ resnet_eps=norm_eps,
486
+ resnet_act_fn=act_fn,
487
+ resnet_groups=norm_num_groups,
488
+ cross_attention_dim=reversed_cross_attention_dim[i],
489
+ num_attention_heads=reversed_num_attention_heads[i],
490
+ use_linear_projection=use_linear_projection,
491
+ only_cross_attention=only_cross_attention[i],
492
+ upcast_attention=upcast_attention,
493
+ resnet_time_scale_shift=resnet_time_scale_shift,
494
+ )
495
+ self.up_blocks.append(up_block)
496
+ prev_output_channel = output_channel
497
+
498
+ # out
499
+ if norm_num_groups is not None:
500
+ self.conv_norm_out = nn.GroupNorm(
501
+ num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps
502
+ )
503
+
504
+ self.conv_act = get_activation(act_fn)
505
+
506
+ else:
507
+ self.conv_norm_out = None
508
+ self.conv_act = None
509
+
510
+ conv_out_padding = (conv_out_kernel - 1) // 2
511
+ self.conv_out = nn.Conv2d(
512
+ block_out_channels[0], out_channels, kernel_size=conv_out_kernel, padding=conv_out_padding
513
+ )
514
+
515
+ @property
516
+ # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.attn_processors
517
+ def attn_processors(self) -> Dict[str, AttentionProcessor]:
518
+ r"""
519
+ Returns:
520
+ `dict` of attention processors: A dictionary containing all attention processors used in the model with
521
+ indexed by its weight name.
522
+ """
523
+ # set recursively
524
+ processors = {}
525
+
526
+ def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
527
+ if hasattr(module, "get_processor"):
528
+ processors[f"{name}.processor"] = module.get_processor(return_deprecated_lora=True)
529
+
530
+ for sub_name, child in module.named_children():
531
+ fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
532
+
533
+ return processors
534
+
535
+ for name, module in self.named_children():
536
+ fn_recursive_add_processors(name, module, processors)
537
+
538
+ return processors
539
+
540
+ # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor
541
+ def set_attn_processor(
542
+ self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]], _remove_lora=False
543
+ ):
544
+ r"""
545
+ Sets the attention processor to use to compute attention.
546
+
547
+ Parameters:
548
+ processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
549
+ The instantiated processor class or a dictionary of processor classes that will be set as the processor
550
+ for **all** `Attention` layers.
551
+
552
+ If `processor` is a dict, the key needs to define the path to the corresponding cross attention
553
+ processor. This is strongly recommended when setting trainable attention processors.
554
+
555
+ """
556
+ count = len(self.attn_processors.keys())
557
+
558
+ if isinstance(processor, dict) and len(processor) != count:
559
+ raise ValueError(
560
+ f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
561
+ f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
562
+ )
563
+
564
+ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
565
+ if hasattr(module, "set_processor"):
566
+ if not isinstance(processor, dict):
567
+ module.set_processor(processor, _remove_lora=_remove_lora)
568
+ else:
569
+ module.set_processor(processor.pop(f"{name}.processor"), _remove_lora=_remove_lora)
570
+
571
+ for sub_name, child in module.named_children():
572
+ fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
573
+
574
+ for name, module in self.named_children():
575
+ fn_recursive_attn_processor(name, module, processor)
576
+
577
+ # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
578
+ def set_default_attn_processor(self):
579
+ """
580
+ Disables custom attention processors and sets the default attention implementation.
581
+ """
582
+ if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
583
+ processor = AttnAddedKVProcessor()
584
+ elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
585
+ processor = AttnProcessor()
586
+ else:
587
+ raise ValueError(
588
+ f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
589
+ )
590
+
591
+ self.set_attn_processor(processor, _remove_lora=True)
592
+
593
+ # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attention_slice
594
+ def set_attention_slice(self, slice_size):
595
+ r"""
596
+ Enable sliced attention computation.
597
+
598
+ When this option is enabled, the attention module splits the input tensor in slices to compute attention in
599
+ several steps. This is useful for saving some memory in exchange for a small decrease in speed.
600
+
601
+ Args:
602
+ slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
603
+ When `"auto"`, input to the attention heads is halved, so attention is computed in two steps. If
604
+ `"max"`, maximum amount of memory is saved by running only one slice at a time. If a number is
605
+ provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
606
+ must be a multiple of `slice_size`.
607
+ """
608
+ sliceable_head_dims = []
609
+
610
+ def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module):
611
+ if hasattr(module, "set_attention_slice"):
612
+ sliceable_head_dims.append(module.sliceable_head_dim)
613
+
614
+ for child in module.children():
615
+ fn_recursive_retrieve_sliceable_dims(child)
616
+
617
+ # retrieve number of attention layers
618
+ for module in self.children():
619
+ fn_recursive_retrieve_sliceable_dims(module)
620
+
621
+ num_sliceable_layers = len(sliceable_head_dims)
622
+
623
+ if slice_size == "auto":
624
+ # half the attention head size is usually a good trade-off between
625
+ # speed and memory
626
+ slice_size = [dim // 2 for dim in sliceable_head_dims]
627
+ elif slice_size == "max":
628
+ # make smallest slice possible
629
+ slice_size = num_sliceable_layers * [1]
630
+
631
+ slice_size = num_sliceable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size
632
+
633
+ if len(slice_size) != len(sliceable_head_dims):
634
+ raise ValueError(
635
+ f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different"
636
+ f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}."
637
+ )
638
+
639
+ for i in range(len(slice_size)):
640
+ size = slice_size[i]
641
+ dim = sliceable_head_dims[i]
642
+ if size is not None and size > dim:
643
+ raise ValueError(f"size {size} has to be smaller or equal to {dim}.")
644
+
645
+ # Recursively walk through all the children.
646
+ # Any children which exposes the set_attention_slice method
647
+ # gets the message
648
+ def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]):
649
+ if hasattr(module, "set_attention_slice"):
650
+ module.set_attention_slice(slice_size.pop())
651
+
652
+ for child in module.children():
653
+ fn_recursive_set_attention_slice(child, slice_size)
654
+
655
+ reversed_slice_size = list(reversed(slice_size))
656
+ for module in self.children():
657
+ fn_recursive_set_attention_slice(module, reversed_slice_size)
658
+
659
+ # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel._set_gradient_checkpointing
660
+ def _set_gradient_checkpointing(self, module, value=False):
661
+ if hasattr(module, "gradient_checkpointing"):
662
+ module.gradient_checkpointing = value
663
+
664
+ def forward(
665
+ self,
666
+ sample: torch.FloatTensor,
667
+ timestep: Union[torch.Tensor, float, int],
668
+ encoder_hidden_states: torch.Tensor,
669
+ class_labels: Optional[torch.Tensor] = None,
670
+ timestep_cond: Optional[torch.Tensor] = None,
671
+ attention_mask: Optional[torch.Tensor] = None,
672
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
673
+ encoder_attention_mask: Optional[torch.Tensor] = None,
674
+ return_dict: bool = True,
675
+ encoder_hidden_states_1: Optional[torch.Tensor] = None,
676
+ encoder_attention_mask_1: Optional[torch.Tensor] = None,
677
+ ) -> Union[UNet2DConditionOutput, Tuple]:
678
+ r"""
679
+ The [`AudioLDM2UNet2DConditionModel`] forward method.
680
+
681
+ Args:
682
+ sample (`torch.FloatTensor`):
683
+ The noisy input tensor with the following shape `(batch, channel, height, width)`.
684
+ timestep (`torch.FloatTensor` or `float` or `int`): The number of timesteps to denoise an input.
685
+ encoder_hidden_states (`torch.FloatTensor`):
686
+ The encoder hidden states with shape `(batch, sequence_length, feature_dim)`.
687
+ encoder_attention_mask (`torch.Tensor`):
688
+ A cross-attention mask of shape `(batch, sequence_length)` is applied to `encoder_hidden_states`. If
689
+ `True` the mask is kept, otherwise if `False` it is discarded. Mask will be converted into a bias,
690
+ which adds large negative values to the attention scores corresponding to "discard" tokens.
691
+ return_dict (`bool`, *optional*, defaults to `True`):
692
+ Whether or not to return a [`~models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain
693
+ tuple.
694
+ cross_attention_kwargs (`dict`, *optional*):
695
+ A kwargs dictionary that if specified is passed along to the [`AttnProcessor`].
696
+ encoder_hidden_states_1 (`torch.FloatTensor`, *optional*):
697
+ A second set of encoder hidden states with shape `(batch, sequence_length_2, feature_dim_2)`. Can be
698
+ used to condition the model on a different set of embeddings to `encoder_hidden_states`.
699
+ encoder_attention_mask_1 (`torch.Tensor`, *optional*):
700
+ A cross-attention mask of shape `(batch, sequence_length_2)` is applied to `encoder_hidden_states_1`.
701
+ If `True` the mask is kept, otherwise if `False` it is discarded. Mask will be converted into a bias,
702
+ which adds large negative values to the attention scores corresponding to "discard" tokens.
703
+
704
+ Returns:
705
+ [`~models.unet_2d_condition.UNet2DConditionOutput`] or `tuple`:
706
+ If `return_dict` is True, an [`~models.unet_2d_condition.UNet2DConditionOutput`] is returned, otherwise
707
+ a `tuple` is returned where the first element is the sample tensor.
708
+ """
709
+ # By default samples have to be AT least a multiple of the overall upsampling factor.
710
+ # The overall upsampling factor is equal to 2 ** (# num of upsampling layers).
711
+ # However, the upsampling interpolation output size can be forced to fit any upsampling size
712
+ # on the fly if necessary.
713
+ default_overall_up_factor = 2**self.num_upsamplers
714
+
715
+ # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
716
+ forward_upsample_size = False
717
+ upsample_size = None
718
+
719
+ if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
720
+ logger.info("Forward upsample size to force interpolation output size.")
721
+ forward_upsample_size = True
722
+
723
+ # ensure attention_mask is a bias, and give it a singleton query_tokens dimension
724
+ # expects mask of shape:
725
+ # [batch, key_tokens]
726
+ # adds singleton query_tokens dimension:
727
+ # [batch, 1, key_tokens]
728
+ # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
729
+ # [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn)
730
+ # [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
731
+ if attention_mask is not None:
732
+ # assume that mask is expressed as:
733
+ # (1 = keep, 0 = discard)
734
+ # convert mask into a bias that can be added to attention scores:
735
+ # (keep = +0, discard = -10000.0)
736
+ attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
737
+ attention_mask = attention_mask.unsqueeze(1)
738
+
739
+ # convert encoder_attention_mask to a bias the same way we do for attention_mask
740
+ if encoder_attention_mask is not None:
741
+ encoder_attention_mask = (1 - encoder_attention_mask.to(sample.dtype)) * -10000.0
742
+ encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
743
+
744
+ if encoder_attention_mask_1 is not None:
745
+ encoder_attention_mask_1 = (1 - encoder_attention_mask_1.to(sample.dtype)) * -10000.0
746
+ encoder_attention_mask_1 = encoder_attention_mask_1.unsqueeze(1)
747
+
748
+ # 1. time
749
+ timesteps = timestep
750
+ if not torch.is_tensor(timesteps):
751
+ # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
752
+ # This would be a good case for the `match` statement (Python 3.10+)
753
+ is_mps = sample.device.type == "mps"
754
+ if isinstance(timestep, float):
755
+ dtype = torch.float32 if is_mps else torch.float64
756
+ else:
757
+ dtype = torch.int32 if is_mps else torch.int64
758
+ timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
759
+ elif len(timesteps.shape) == 0:
760
+ timesteps = timesteps[None].to(sample.device)
761
+
762
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
763
+ timesteps = timesteps.expand(sample.shape[0])
764
+
765
+ t_emb = self.time_proj(timesteps)
766
+
767
+ # `Timesteps` does not contain any weights and will always return f32 tensors
768
+ # but time_embedding might actually be running in fp16. so we need to cast here.
769
+ # there might be better ways to encapsulate this.
770
+ t_emb = t_emb.to(dtype=sample.dtype)
771
+
772
+ emb = self.time_embedding(t_emb, timestep_cond)
773
+ aug_emb = None
774
+
775
+ if self.class_embedding is not None:
776
+ if class_labels is None:
777
+ raise ValueError("class_labels should be provided when num_class_embeds > 0")
778
+
779
+ if self.config.class_embed_type == "timestep":
780
+ class_labels = self.time_proj(class_labels)
781
+
782
+ # `Timesteps` does not contain any weights and will always return f32 tensors
783
+ # there might be better ways to encapsulate this.
784
+ class_labels = class_labels.to(dtype=sample.dtype)
785
+
786
+ class_emb = self.class_embedding(class_labels).to(dtype=sample.dtype)
787
+
788
+ if self.config.class_embeddings_concat:
789
+ emb = torch.cat([emb, class_emb], dim=-1)
790
+ else:
791
+ emb = emb + class_emb
792
+
793
+ emb = emb + aug_emb if aug_emb is not None else emb
794
+
795
+ if self.time_embed_act is not None:
796
+ emb = self.time_embed_act(emb)
797
+
798
+ # 2. pre-process
799
+ sample = self.conv_in(sample)
800
+
801
+ # 3. down
802
+ down_block_res_samples = (sample,)
803
+ for downsample_block in self.down_blocks:
804
+ if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
805
+ sample, res_samples = downsample_block(
806
+ hidden_states=sample,
807
+ temb=emb,
808
+ encoder_hidden_states=encoder_hidden_states,
809
+ attention_mask=attention_mask,
810
+ cross_attention_kwargs=cross_attention_kwargs,
811
+ encoder_attention_mask=encoder_attention_mask,
812
+ encoder_hidden_states_1=encoder_hidden_states_1,
813
+ encoder_attention_mask_1=encoder_attention_mask_1,
814
+ )
815
+ else:
816
+ sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
817
+
818
+ down_block_res_samples += res_samples
819
+
820
+ # 4. mid
821
+ if self.mid_block is not None:
822
+ sample = self.mid_block(
823
+ sample,
824
+ emb,
825
+ encoder_hidden_states=encoder_hidden_states,
826
+ attention_mask=attention_mask,
827
+ cross_attention_kwargs=cross_attention_kwargs,
828
+ encoder_attention_mask=encoder_attention_mask,
829
+ encoder_hidden_states_1=encoder_hidden_states_1,
830
+ encoder_attention_mask_1=encoder_attention_mask_1,
831
+ )
832
+
833
+ # 5. up
834
+ for i, upsample_block in enumerate(self.up_blocks):
835
+ is_final_block = i == len(self.up_blocks) - 1
836
+
837
+ res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
838
+ down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
839
+
840
+ # if we have not reached the final block and need to forward the
841
+ # upsample size, we do it here
842
+ if not is_final_block and forward_upsample_size:
843
+ upsample_size = down_block_res_samples[-1].shape[2:]
844
+
845
+ if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
846
+ sample = upsample_block(
847
+ hidden_states=sample,
848
+ temb=emb,
849
+ res_hidden_states_tuple=res_samples,
850
+ encoder_hidden_states=encoder_hidden_states,
851
+ cross_attention_kwargs=cross_attention_kwargs,
852
+ upsample_size=upsample_size,
853
+ attention_mask=attention_mask,
854
+ encoder_attention_mask=encoder_attention_mask,
855
+ encoder_hidden_states_1=encoder_hidden_states_1,
856
+ encoder_attention_mask_1=encoder_attention_mask_1,
857
+ )
858
+ else:
859
+ sample = upsample_block(
860
+ hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size
861
+ )
862
+
863
+ # 6. post-process
864
+ if self.conv_norm_out:
865
+ sample = self.conv_norm_out(sample)
866
+ sample = self.conv_act(sample)
867
+ sample = self.conv_out(sample)
868
+
869
+ if not return_dict:
870
+ return (sample,)
871
+
872
+ return UNet2DConditionOutput(sample=sample)
873
+
874
+
875
+ def get_down_block(
876
+ down_block_type,
877
+ num_layers,
878
+ in_channels,
879
+ out_channels,
880
+ temb_channels,
881
+ add_downsample,
882
+ resnet_eps,
883
+ resnet_act_fn,
884
+ transformer_layers_per_block=1,
885
+ num_attention_heads=None,
886
+ resnet_groups=None,
887
+ cross_attention_dim=None,
888
+ downsample_padding=None,
889
+ use_linear_projection=False,
890
+ only_cross_attention=False,
891
+ upcast_attention=False,
892
+ resnet_time_scale_shift="default",
893
+ ):
894
+ down_block_type = down_block_type[7:] if down_block_type.startswith("UNetRes") else down_block_type
895
+ if down_block_type == "DownBlock2D":
896
+ return DownBlock2D(
897
+ num_layers=num_layers,
898
+ in_channels=in_channels,
899
+ out_channels=out_channels,
900
+ temb_channels=temb_channels,
901
+ add_downsample=add_downsample,
902
+ resnet_eps=resnet_eps,
903
+ resnet_act_fn=resnet_act_fn,
904
+ resnet_groups=resnet_groups,
905
+ downsample_padding=downsample_padding,
906
+ resnet_time_scale_shift=resnet_time_scale_shift,
907
+ )
908
+ elif down_block_type == "CrossAttnDownBlock2D":
909
+ if cross_attention_dim is None:
910
+ raise ValueError("cross_attention_dim must be specified for CrossAttnDownBlock2D")
911
+ return CrossAttnDownBlock2D(
912
+ num_layers=num_layers,
913
+ transformer_layers_per_block=transformer_layers_per_block,
914
+ in_channels=in_channels,
915
+ out_channels=out_channels,
916
+ temb_channels=temb_channels,
917
+ add_downsample=add_downsample,
918
+ resnet_eps=resnet_eps,
919
+ resnet_act_fn=resnet_act_fn,
920
+ resnet_groups=resnet_groups,
921
+ downsample_padding=downsample_padding,
922
+ cross_attention_dim=cross_attention_dim,
923
+ num_attention_heads=num_attention_heads,
924
+ use_linear_projection=use_linear_projection,
925
+ only_cross_attention=only_cross_attention,
926
+ upcast_attention=upcast_attention,
927
+ resnet_time_scale_shift=resnet_time_scale_shift,
928
+ )
929
+ raise ValueError(f"{down_block_type} does not exist.")
930
+
931
+
932
+ def get_up_block(
933
+ up_block_type,
934
+ num_layers,
935
+ in_channels,
936
+ out_channels,
937
+ prev_output_channel,
938
+ temb_channels,
939
+ add_upsample,
940
+ resnet_eps,
941
+ resnet_act_fn,
942
+ transformer_layers_per_block=1,
943
+ num_attention_heads=None,
944
+ resnet_groups=None,
945
+ cross_attention_dim=None,
946
+ use_linear_projection=False,
947
+ only_cross_attention=False,
948
+ upcast_attention=False,
949
+ resnet_time_scale_shift="default",
950
+ ):
951
+ up_block_type = up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type
952
+ if up_block_type == "UpBlock2D":
953
+ return UpBlock2D(
954
+ num_layers=num_layers,
955
+ in_channels=in_channels,
956
+ out_channels=out_channels,
957
+ prev_output_channel=prev_output_channel,
958
+ temb_channels=temb_channels,
959
+ add_upsample=add_upsample,
960
+ resnet_eps=resnet_eps,
961
+ resnet_act_fn=resnet_act_fn,
962
+ resnet_groups=resnet_groups,
963
+ resnet_time_scale_shift=resnet_time_scale_shift,
964
+ )
965
+ elif up_block_type == "CrossAttnUpBlock2D":
966
+ if cross_attention_dim is None:
967
+ raise ValueError("cross_attention_dim must be specified for CrossAttnUpBlock2D")
968
+ return CrossAttnUpBlock2D(
969
+ num_layers=num_layers,
970
+ transformer_layers_per_block=transformer_layers_per_block,
971
+ in_channels=in_channels,
972
+ out_channels=out_channels,
973
+ prev_output_channel=prev_output_channel,
974
+ temb_channels=temb_channels,
975
+ add_upsample=add_upsample,
976
+ resnet_eps=resnet_eps,
977
+ resnet_act_fn=resnet_act_fn,
978
+ resnet_groups=resnet_groups,
979
+ cross_attention_dim=cross_attention_dim,
980
+ num_attention_heads=num_attention_heads,
981
+ use_linear_projection=use_linear_projection,
982
+ only_cross_attention=only_cross_attention,
983
+ upcast_attention=upcast_attention,
984
+ resnet_time_scale_shift=resnet_time_scale_shift,
985
+ )
986
+ raise ValueError(f"{up_block_type} does not exist.")
987
+
988
+
989
+ class CrossAttnDownBlock2D(nn.Module):
990
+ def __init__(
991
+ self,
992
+ in_channels: int,
993
+ out_channels: int,
994
+ temb_channels: int,
995
+ dropout: float = 0.0,
996
+ num_layers: int = 1,
997
+ transformer_layers_per_block: int = 1,
998
+ resnet_eps: float = 1e-6,
999
+ resnet_time_scale_shift: str = "default",
1000
+ resnet_act_fn: str = "swish",
1001
+ resnet_groups: int = 32,
1002
+ resnet_pre_norm: bool = True,
1003
+ num_attention_heads=1,
1004
+ cross_attention_dim=1280,
1005
+ output_scale_factor=1.0,
1006
+ downsample_padding=1,
1007
+ add_downsample=True,
1008
+ use_linear_projection=False,
1009
+ only_cross_attention=False,
1010
+ upcast_attention=False,
1011
+ ):
1012
+ super().__init__()
1013
+ resnets = []
1014
+ attentions = []
1015
+
1016
+ self.has_cross_attention = True
1017
+ self.num_attention_heads = num_attention_heads
1018
+
1019
+ if isinstance(cross_attention_dim, int):
1020
+ cross_attention_dim = (cross_attention_dim,)
1021
+ if isinstance(cross_attention_dim, (list, tuple)) and len(cross_attention_dim) > 4:
1022
+ raise ValueError(
1023
+ "Only up to 4 cross-attention layers are supported. Ensure that the length of cross-attention "
1024
+ f"dims is less than or equal to 4. Got cross-attention dims {cross_attention_dim} of length {len(cross_attention_dim)}"
1025
+ )
1026
+ self.cross_attention_dim = cross_attention_dim
1027
+
1028
+ for i in range(num_layers):
1029
+ in_channels = in_channels if i == 0 else out_channels
1030
+ resnets.append(
1031
+ ResnetBlock2D(
1032
+ in_channels=in_channels,
1033
+ out_channels=out_channels,
1034
+ temb_channels=temb_channels,
1035
+ eps=resnet_eps,
1036
+ groups=resnet_groups,
1037
+ dropout=dropout,
1038
+ time_embedding_norm=resnet_time_scale_shift,
1039
+ non_linearity=resnet_act_fn,
1040
+ output_scale_factor=output_scale_factor,
1041
+ pre_norm=resnet_pre_norm,
1042
+ )
1043
+ )
1044
+ for j in range(len(cross_attention_dim)):
1045
+ attentions.append(
1046
+ Transformer2DModel(
1047
+ num_attention_heads,
1048
+ out_channels // num_attention_heads,
1049
+ in_channels=out_channels,
1050
+ num_layers=transformer_layers_per_block,
1051
+ cross_attention_dim=cross_attention_dim[j],
1052
+ norm_num_groups=resnet_groups,
1053
+ use_linear_projection=use_linear_projection,
1054
+ only_cross_attention=only_cross_attention,
1055
+ upcast_attention=upcast_attention,
1056
+ double_self_attention=True if cross_attention_dim[j] is None else False,
1057
+ )
1058
+ )
1059
+ self.attentions = nn.ModuleList(attentions)
1060
+ self.resnets = nn.ModuleList(resnets)
1061
+
1062
+ if add_downsample:
1063
+ self.downsamplers = nn.ModuleList(
1064
+ [
1065
+ Downsample2D(
1066
+ out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
1067
+ )
1068
+ ]
1069
+ )
1070
+ else:
1071
+ self.downsamplers = None
1072
+
1073
+ self.gradient_checkpointing = False
1074
+
1075
+ def forward(
1076
+ self,
1077
+ hidden_states: torch.FloatTensor,
1078
+ temb: Optional[torch.FloatTensor] = None,
1079
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
1080
+ attention_mask: Optional[torch.FloatTensor] = None,
1081
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
1082
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
1083
+ encoder_hidden_states_1: Optional[torch.FloatTensor] = None,
1084
+ encoder_attention_mask_1: Optional[torch.FloatTensor] = None,
1085
+ ):
1086
+ output_states = ()
1087
+ num_layers = len(self.resnets)
1088
+ num_attention_per_layer = len(self.attentions) // num_layers
1089
+
1090
+ encoder_hidden_states_1 = (
1091
+ encoder_hidden_states_1 if encoder_hidden_states_1 is not None else encoder_hidden_states
1092
+ )
1093
+ encoder_attention_mask_1 = (
1094
+ encoder_attention_mask_1 if encoder_hidden_states_1 is not None else encoder_attention_mask
1095
+ )
1096
+
1097
+ for i in range(num_layers):
1098
+ if self.training and self.gradient_checkpointing:
1099
+
1100
+ def create_custom_forward(module, return_dict=None):
1101
+ def custom_forward(*inputs):
1102
+ if return_dict is not None:
1103
+ return module(*inputs, return_dict=return_dict)
1104
+ else:
1105
+ return module(*inputs)
1106
+
1107
+ return custom_forward
1108
+
1109
+ ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
1110
+ hidden_states = torch.utils.checkpoint.checkpoint(
1111
+ create_custom_forward(self.resnets[i]),
1112
+ hidden_states,
1113
+ temb,
1114
+ **ckpt_kwargs,
1115
+ )
1116
+ for idx, cross_attention_dim in enumerate(self.cross_attention_dim):
1117
+ if cross_attention_dim is not None and idx <= 1:
1118
+ forward_encoder_hidden_states = encoder_hidden_states
1119
+ forward_encoder_attention_mask = encoder_attention_mask
1120
+ elif cross_attention_dim is not None and idx > 1:
1121
+ forward_encoder_hidden_states = encoder_hidden_states_1
1122
+ forward_encoder_attention_mask = encoder_attention_mask_1
1123
+ else:
1124
+ forward_encoder_hidden_states = None
1125
+ forward_encoder_attention_mask = None
1126
+ hidden_states = torch.utils.checkpoint.checkpoint(
1127
+ create_custom_forward(self.attentions[i * num_attention_per_layer + idx], return_dict=False),
1128
+ hidden_states,
1129
+ forward_encoder_hidden_states,
1130
+ None, # timestep
1131
+ None, # class_labels
1132
+ cross_attention_kwargs,
1133
+ attention_mask,
1134
+ forward_encoder_attention_mask,
1135
+ **ckpt_kwargs,
1136
+ )[0]
1137
+ else:
1138
+ hidden_states = self.resnets[i](hidden_states, temb)
1139
+ for idx, cross_attention_dim in enumerate(self.cross_attention_dim):
1140
+ if cross_attention_dim is not None and idx <= 1:
1141
+ forward_encoder_hidden_states = encoder_hidden_states
1142
+ forward_encoder_attention_mask = encoder_attention_mask
1143
+ elif cross_attention_dim is not None and idx > 1:
1144
+ forward_encoder_hidden_states = encoder_hidden_states_1
1145
+ forward_encoder_attention_mask = encoder_attention_mask_1
1146
+ else:
1147
+ forward_encoder_hidden_states = None
1148
+ forward_encoder_attention_mask = None
1149
+ hidden_states = self.attentions[i * num_attention_per_layer + idx](
1150
+ hidden_states,
1151
+ attention_mask=attention_mask,
1152
+ encoder_hidden_states=forward_encoder_hidden_states,
1153
+ encoder_attention_mask=forward_encoder_attention_mask,
1154
+ return_dict=False,
1155
+ )[0]
1156
+
1157
+ output_states = output_states + (hidden_states,)
1158
+
1159
+ if self.downsamplers is not None:
1160
+ for downsampler in self.downsamplers:
1161
+ hidden_states = downsampler(hidden_states)
1162
+
1163
+ output_states = output_states + (hidden_states,)
1164
+
1165
+ return hidden_states, output_states
1166
+
1167
+
1168
+ class UNetMidBlock2DCrossAttn(nn.Module):
1169
+ def __init__(
1170
+ self,
1171
+ in_channels: int,
1172
+ temb_channels: int,
1173
+ dropout: float = 0.0,
1174
+ num_layers: int = 1,
1175
+ transformer_layers_per_block: int = 1,
1176
+ resnet_eps: float = 1e-6,
1177
+ resnet_time_scale_shift: str = "default",
1178
+ resnet_act_fn: str = "swish",
1179
+ resnet_groups: int = 32,
1180
+ resnet_pre_norm: bool = True,
1181
+ num_attention_heads=1,
1182
+ output_scale_factor=1.0,
1183
+ cross_attention_dim=1280,
1184
+ use_linear_projection=False,
1185
+ upcast_attention=False,
1186
+ ):
1187
+ super().__init__()
1188
+
1189
+ self.has_cross_attention = True
1190
+ self.num_attention_heads = num_attention_heads
1191
+ resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
1192
+
1193
+ if isinstance(cross_attention_dim, int):
1194
+ cross_attention_dim = (cross_attention_dim,)
1195
+ if isinstance(cross_attention_dim, (list, tuple)) and len(cross_attention_dim) > 4:
1196
+ raise ValueError(
1197
+ "Only up to 4 cross-attention layers are supported. Ensure that the length of cross-attention "
1198
+ f"dims is less than or equal to 4. Got cross-attention dims {cross_attention_dim} of length {len(cross_attention_dim)}"
1199
+ )
1200
+ self.cross_attention_dim = cross_attention_dim
1201
+
1202
+ # there is always at least one resnet
1203
+ resnets = [
1204
+ ResnetBlock2D(
1205
+ in_channels=in_channels,
1206
+ out_channels=in_channels,
1207
+ temb_channels=temb_channels,
1208
+ eps=resnet_eps,
1209
+ groups=resnet_groups,
1210
+ dropout=dropout,
1211
+ time_embedding_norm=resnet_time_scale_shift,
1212
+ non_linearity=resnet_act_fn,
1213
+ output_scale_factor=output_scale_factor,
1214
+ pre_norm=resnet_pre_norm,
1215
+ )
1216
+ ]
1217
+ attentions = []
1218
+
1219
+ for i in range(num_layers):
1220
+ for j in range(len(cross_attention_dim)):
1221
+ attentions.append(
1222
+ Transformer2DModel(
1223
+ num_attention_heads,
1224
+ in_channels // num_attention_heads,
1225
+ in_channels=in_channels,
1226
+ num_layers=transformer_layers_per_block,
1227
+ cross_attention_dim=cross_attention_dim[j],
1228
+ norm_num_groups=resnet_groups,
1229
+ use_linear_projection=use_linear_projection,
1230
+ upcast_attention=upcast_attention,
1231
+ double_self_attention=True if cross_attention_dim[j] is None else False,
1232
+ )
1233
+ )
1234
+ resnets.append(
1235
+ ResnetBlock2D(
1236
+ in_channels=in_channels,
1237
+ out_channels=in_channels,
1238
+ temb_channels=temb_channels,
1239
+ eps=resnet_eps,
1240
+ groups=resnet_groups,
1241
+ dropout=dropout,
1242
+ time_embedding_norm=resnet_time_scale_shift,
1243
+ non_linearity=resnet_act_fn,
1244
+ output_scale_factor=output_scale_factor,
1245
+ pre_norm=resnet_pre_norm,
1246
+ )
1247
+ )
1248
+
1249
+ self.attentions = nn.ModuleList(attentions)
1250
+ self.resnets = nn.ModuleList(resnets)
1251
+
1252
+ self.gradient_checkpointing = False
1253
+
1254
+ def forward(
1255
+ self,
1256
+ hidden_states: torch.FloatTensor,
1257
+ temb: Optional[torch.FloatTensor] = None,
1258
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
1259
+ attention_mask: Optional[torch.FloatTensor] = None,
1260
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
1261
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
1262
+ encoder_hidden_states_1: Optional[torch.FloatTensor] = None,
1263
+ encoder_attention_mask_1: Optional[torch.FloatTensor] = None,
1264
+ ) -> torch.FloatTensor:
1265
+ hidden_states = self.resnets[0](hidden_states, temb)
1266
+ num_attention_per_layer = len(self.attentions) // (len(self.resnets) - 1)
1267
+
1268
+ encoder_hidden_states_1 = (
1269
+ encoder_hidden_states_1 if encoder_hidden_states_1 is not None else encoder_hidden_states
1270
+ )
1271
+ encoder_attention_mask_1 = (
1272
+ encoder_attention_mask_1 if encoder_hidden_states_1 is not None else encoder_attention_mask
1273
+ )
1274
+
1275
+ for i in range(len(self.resnets[1:])):
1276
+ if self.training and self.gradient_checkpointing:
1277
+
1278
+ def create_custom_forward(module, return_dict=None):
1279
+ def custom_forward(*inputs):
1280
+ if return_dict is not None:
1281
+ return module(*inputs, return_dict=return_dict)
1282
+ else:
1283
+ return module(*inputs)
1284
+
1285
+ return custom_forward
1286
+
1287
+ ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
1288
+ for idx, cross_attention_dim in enumerate(self.cross_attention_dim):
1289
+ if cross_attention_dim is not None and idx <= 1:
1290
+ forward_encoder_hidden_states = encoder_hidden_states
1291
+ forward_encoder_attention_mask = encoder_attention_mask
1292
+ elif cross_attention_dim is not None and idx > 1:
1293
+ forward_encoder_hidden_states = encoder_hidden_states_1
1294
+ forward_encoder_attention_mask = encoder_attention_mask_1
1295
+ else:
1296
+ forward_encoder_hidden_states = None
1297
+ forward_encoder_attention_mask = None
1298
+ hidden_states = torch.utils.checkpoint.checkpoint(
1299
+ create_custom_forward(self.attentions[i * num_attention_per_layer + idx], return_dict=False),
1300
+ hidden_states,
1301
+ forward_encoder_hidden_states,
1302
+ None, # timestep
1303
+ None, # class_labels
1304
+ cross_attention_kwargs,
1305
+ attention_mask,
1306
+ forward_encoder_attention_mask,
1307
+ **ckpt_kwargs,
1308
+ )[0]
1309
+ hidden_states = torch.utils.checkpoint.checkpoint(
1310
+ create_custom_forward(self.resnets[i + 1]),
1311
+ hidden_states,
1312
+ temb,
1313
+ **ckpt_kwargs,
1314
+ )
1315
+ else:
1316
+ for idx, cross_attention_dim in enumerate(self.cross_attention_dim):
1317
+ if cross_attention_dim is not None and idx <= 1:
1318
+ forward_encoder_hidden_states = encoder_hidden_states
1319
+ forward_encoder_attention_mask = encoder_attention_mask
1320
+ elif cross_attention_dim is not None and idx > 1:
1321
+ forward_encoder_hidden_states = encoder_hidden_states_1
1322
+ forward_encoder_attention_mask = encoder_attention_mask_1
1323
+ else:
1324
+ forward_encoder_hidden_states = None
1325
+ forward_encoder_attention_mask = None
1326
+ hidden_states = self.attentions[i * num_attention_per_layer + idx](
1327
+ hidden_states,
1328
+ attention_mask=attention_mask,
1329
+ encoder_hidden_states=forward_encoder_hidden_states,
1330
+ encoder_attention_mask=forward_encoder_attention_mask,
1331
+ return_dict=False,
1332
+ )[0]
1333
+
1334
+ hidden_states = self.resnets[i + 1](hidden_states, temb)
1335
+
1336
+ return hidden_states
1337
+
1338
+
1339
+ class CrossAttnUpBlock2D(nn.Module):
1340
+ def __init__(
1341
+ self,
1342
+ in_channels: int,
1343
+ out_channels: int,
1344
+ prev_output_channel: int,
1345
+ temb_channels: int,
1346
+ dropout: float = 0.0,
1347
+ num_layers: int = 1,
1348
+ transformer_layers_per_block: int = 1,
1349
+ resnet_eps: float = 1e-6,
1350
+ resnet_time_scale_shift: str = "default",
1351
+ resnet_act_fn: str = "swish",
1352
+ resnet_groups: int = 32,
1353
+ resnet_pre_norm: bool = True,
1354
+ num_attention_heads=1,
1355
+ cross_attention_dim=1280,
1356
+ output_scale_factor=1.0,
1357
+ add_upsample=True,
1358
+ use_linear_projection=False,
1359
+ only_cross_attention=False,
1360
+ upcast_attention=False,
1361
+ ):
1362
+ super().__init__()
1363
+ resnets = []
1364
+ attentions = []
1365
+
1366
+ self.has_cross_attention = True
1367
+ self.num_attention_heads = num_attention_heads
1368
+
1369
+ if isinstance(cross_attention_dim, int):
1370
+ cross_attention_dim = (cross_attention_dim,)
1371
+ if isinstance(cross_attention_dim, (list, tuple)) and len(cross_attention_dim) > 4:
1372
+ raise ValueError(
1373
+ "Only up to 4 cross-attention layers are supported. Ensure that the length of cross-attention "
1374
+ f"dims is less than or equal to 4. Got cross-attention dims {cross_attention_dim} of length {len(cross_attention_dim)}"
1375
+ )
1376
+ self.cross_attention_dim = cross_attention_dim
1377
+
1378
+ for i in range(num_layers):
1379
+ res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
1380
+ resnet_in_channels = prev_output_channel if i == 0 else out_channels
1381
+
1382
+ resnets.append(
1383
+ ResnetBlock2D(
1384
+ in_channels=resnet_in_channels + res_skip_channels,
1385
+ out_channels=out_channels,
1386
+ temb_channels=temb_channels,
1387
+ eps=resnet_eps,
1388
+ groups=resnet_groups,
1389
+ dropout=dropout,
1390
+ time_embedding_norm=resnet_time_scale_shift,
1391
+ non_linearity=resnet_act_fn,
1392
+ output_scale_factor=output_scale_factor,
1393
+ pre_norm=resnet_pre_norm,
1394
+ )
1395
+ )
1396
+ for j in range(len(cross_attention_dim)):
1397
+ attentions.append(
1398
+ Transformer2DModel(
1399
+ num_attention_heads,
1400
+ out_channels // num_attention_heads,
1401
+ in_channels=out_channels,
1402
+ num_layers=transformer_layers_per_block,
1403
+ cross_attention_dim=cross_attention_dim[j],
1404
+ norm_num_groups=resnet_groups,
1405
+ use_linear_projection=use_linear_projection,
1406
+ only_cross_attention=only_cross_attention,
1407
+ upcast_attention=upcast_attention,
1408
+ double_self_attention=True if cross_attention_dim[j] is None else False,
1409
+ )
1410
+ )
1411
+ self.attentions = nn.ModuleList(attentions)
1412
+ self.resnets = nn.ModuleList(resnets)
1413
+
1414
+ if add_upsample:
1415
+ self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])
1416
+ else:
1417
+ self.upsamplers = None
1418
+
1419
+ self.gradient_checkpointing = False
1420
+
1421
+ def forward(
1422
+ self,
1423
+ hidden_states: torch.FloatTensor,
1424
+ res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
1425
+ temb: Optional[torch.FloatTensor] = None,
1426
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
1427
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
1428
+ upsample_size: Optional[int] = None,
1429
+ attention_mask: Optional[torch.FloatTensor] = None,
1430
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
1431
+ encoder_hidden_states_1: Optional[torch.FloatTensor] = None,
1432
+ encoder_attention_mask_1: Optional[torch.FloatTensor] = None,
1433
+ ):
1434
+ num_layers = len(self.resnets)
1435
+ num_attention_per_layer = len(self.attentions) // num_layers
1436
+
1437
+ encoder_hidden_states_1 = (
1438
+ encoder_hidden_states_1 if encoder_hidden_states_1 is not None else encoder_hidden_states
1439
+ )
1440
+ encoder_attention_mask_1 = (
1441
+ encoder_attention_mask_1 if encoder_hidden_states_1 is not None else encoder_attention_mask
1442
+ )
1443
+
1444
+ for i in range(num_layers):
1445
+ # pop res hidden states
1446
+ res_hidden_states = res_hidden_states_tuple[-1]
1447
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
1448
+ hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
1449
+
1450
+ if self.training and self.gradient_checkpointing:
1451
+
1452
+ def create_custom_forward(module, return_dict=None):
1453
+ def custom_forward(*inputs):
1454
+ if return_dict is not None:
1455
+ return module(*inputs, return_dict=return_dict)
1456
+ else:
1457
+ return module(*inputs)
1458
+
1459
+ return custom_forward
1460
+
1461
+ ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
1462
+ hidden_states = torch.utils.checkpoint.checkpoint(
1463
+ create_custom_forward(self.resnets[i]),
1464
+ hidden_states,
1465
+ temb,
1466
+ **ckpt_kwargs,
1467
+ )
1468
+ for idx, cross_attention_dim in enumerate(self.cross_attention_dim):
1469
+ if cross_attention_dim is not None and idx <= 1:
1470
+ forward_encoder_hidden_states = encoder_hidden_states
1471
+ forward_encoder_attention_mask = encoder_attention_mask
1472
+ elif cross_attention_dim is not None and idx > 1:
1473
+ forward_encoder_hidden_states = encoder_hidden_states_1
1474
+ forward_encoder_attention_mask = encoder_attention_mask_1
1475
+ else:
1476
+ forward_encoder_hidden_states = None
1477
+ forward_encoder_attention_mask = None
1478
+ hidden_states = torch.utils.checkpoint.checkpoint(
1479
+ create_custom_forward(self.attentions[i * num_attention_per_layer + idx], return_dict=False),
1480
+ hidden_states,
1481
+ forward_encoder_hidden_states,
1482
+ None, # timestep
1483
+ None, # class_labels
1484
+ cross_attention_kwargs,
1485
+ attention_mask,
1486
+ forward_encoder_attention_mask,
1487
+ **ckpt_kwargs,
1488
+ )[0]
1489
+ else:
1490
+ hidden_states = self.resnets[i](hidden_states, temb)
1491
+ for idx, cross_attention_dim in enumerate(self.cross_attention_dim):
1492
+ if cross_attention_dim is not None and idx <= 1:
1493
+ forward_encoder_hidden_states = encoder_hidden_states
1494
+ forward_encoder_attention_mask = encoder_attention_mask
1495
+ elif cross_attention_dim is not None and idx > 1:
1496
+ forward_encoder_hidden_states = encoder_hidden_states_1
1497
+ forward_encoder_attention_mask = encoder_attention_mask_1
1498
+ else:
1499
+ forward_encoder_hidden_states = None
1500
+ forward_encoder_attention_mask = None
1501
+ hidden_states = self.attentions[i * num_attention_per_layer + idx](
1502
+ hidden_states,
1503
+ attention_mask=attention_mask,
1504
+ encoder_hidden_states=forward_encoder_hidden_states,
1505
+ encoder_attention_mask=forward_encoder_attention_mask,
1506
+ return_dict=False,
1507
+ )[0]
1508
+
1509
+ if self.upsamplers is not None:
1510
+ for upsampler in self.upsamplers:
1511
+ hidden_states = upsampler(hidden_states, upsample_size)
1512
+
1513
+ return hidden_states
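
Editor's note: the three blocks above differ from the stock diffusers blocks mainly in how they route two conditioning streams. Each ResNet is followed by one transformer per entry in `cross_attention_dim`; entries at index 0-1 attend over `encoder_hidden_states`, entries at index 2 and above attend over `encoder_hidden_states_1`, and a `None` entry degenerates to self-attention (`double_self_attention=True`). The sketch below restates only that routing rule; it is a minimal, assumption-laden illustration that uses plain `torch.nn.MultiheadAttention` in place of `Transformer2DModel`, so the class name `DualStreamCrossAttention`, the sizes, and the module choice are hypothetical, not the commit's implementation.

```python
# Minimal sketch of the idx <= 1 / idx > 1 routing rule (MultiheadAttention stands in
# for Transformer2DModel; all names and sizes are illustrative assumptions).
import torch
import torch.nn as nn


class DualStreamCrossAttention(nn.Module):
    def __init__(self, hidden_size=64, cross_attention_dim=(64, 64, None, 64), num_heads=4):
        super().__init__()
        self.cross_attention_dim = cross_attention_dim
        self.attentions = nn.ModuleList(
            [
                nn.MultiheadAttention(hidden_size, num_heads, kdim=dim or hidden_size,
                                      vdim=dim or hidden_size, batch_first=True)
                for dim in cross_attention_dim
            ]
        )

    def forward(self, hidden_states, encoder_hidden_states, encoder_hidden_states_1=None):
        # Fall back to the first stream when the second one is not provided.
        encoder_hidden_states_1 = (
            encoder_hidden_states_1 if encoder_hidden_states_1 is not None else encoder_hidden_states
        )
        for idx, (dim, attn) in enumerate(zip(self.cross_attention_dim, self.attentions)):
            if dim is not None and idx <= 1:    # first two layers: conditioning stream 1
                context = encoder_hidden_states
            elif dim is not None and idx > 1:   # remaining layers: conditioning stream 2
                context = encoder_hidden_states_1
            else:                               # dim is None: self-attention only
                context = hidden_states
            hidden_states, _ = attn(hidden_states, context, context)
        return hidden_states


if __name__ == "__main__":
    block = DualStreamCrossAttention()
    x = torch.randn(2, 16, 64)        # (batch, spatial tokens, hidden)
    cond_a = torch.randn(2, 8, 64)    # e.g. projected CLAP/T5 embeddings
    cond_b = torch.randn(2, 8, 64)    # e.g. GPT-2 generated embeddings
    print(block(x, cond_a, cond_b).shape)  # torch.Size([2, 16, 64])
```
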
llama/audioldm2/pipeline_audioldm2.py ADDED
@@ -0,0 +1,998 @@
1
+ # Copyright 2023 CVSSP, ByteDance and The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import inspect
16
+ from dataclasses import dataclass
17
+ from typing import Any, Callable, Dict, List, Optional, Union
18
+
19
+ import numpy as np
20
+ import torch
21
+ import torch.nn.functional as F
22
+ from transformers import (
23
+ ClapFeatureExtractor,
24
+ ClapModel,
25
+ GPT2Model,
26
+ RobertaTokenizer,
27
+ RobertaTokenizerFast,
28
+ SpeechT5HifiGan,
29
+ T5EncoderModel,
30
+ T5Tokenizer,
31
+ T5TokenizerFast,
32
+ )
33
+
34
+ from diffusers.models import AutoencoderKL
35
+ from diffusers.schedulers import KarrasDiffusionSchedulers
36
+ from diffusers.utils import (
37
+ is_accelerate_available,
38
+ is_accelerate_version,
39
+ is_librosa_available,
40
+ logging,
41
+ replace_example_docstring,
42
+ )
43
+ from diffusers.utils.torch_utils import randn_tensor
44
+ from diffusers.pipeline_utils import DiffusionPipeline
45
+ from .modeling_audioldm2 import AudioLDM2ProjectionModel, AudioLDM2UNet2DConditionModel
46
+ from diffusers.utils import BaseOutput
47
+
48
+ if is_librosa_available():
49
+ import librosa
50
+
51
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
52
+
53
+ EXAMPLE_DOC_STRING = """
54
+ Examples:
55
+ ```py
56
+ >>> import scipy
57
+ >>> import torch
58
+ >>> from diffusers import AudioLDM2Pipeline
59
+
60
+ >>> repo_id = "cvssp/audioldm2"
61
+ >>> pipe = AudioLDM2Pipeline.from_pretrained(repo_id, torch_dtype=torch.float16)
62
+ >>> pipe = pipe.to("cuda")
63
+
64
+ >>> # define the prompts
65
+ >>> prompt = "The sound of a hammer hitting a wooden surface."
66
+ >>> negative_prompt = "Low quality."
67
+
68
+ >>> # set the seed for generator
69
+ >>> generator = torch.Generator("cuda").manual_seed(0)
70
+
71
+ >>> # run the generation
72
+ >>> audio = pipe(
73
+ ... prompt,
74
+ ... negative_prompt=negative_prompt,
75
+ ... num_inference_steps=200,
76
+ ... audio_length_in_s=10.0,
77
+ ... num_waveforms_per_prompt=3,
78
+ ... generator=generator,
79
+ ... ).audios
80
+
81
+ >>> # save the best audio sample (index 0) as a .wav file
82
+ >>> scipy.io.wavfile.write("techno.wav", rate=16000, data=audio[0])
83
+ ```
84
+ """
85
+
86
+
87
+ @dataclass
88
+ class AudioPipelineOutput(BaseOutput):
89
+ """
90
+ Output class for audio pipelines.
91
+
92
+ Args:
93
+ audios (`np.ndarray`):
94
+ List of denoised audio samples as a NumPy array of shape `(batch_size, num_channels, sample_rate)`.
95
+ """
96
+
97
+ audios: np.ndarray
98
+
99
+
100
+ def prepare_inputs_for_generation(
101
+ inputs_embeds,
102
+ attention_mask=None,
103
+ past_key_values=None,
104
+ **kwargs,
105
+ ):
106
+ if past_key_values is not None:
107
+ # only last token for inputs_embeds if past is defined in kwargs
108
+ inputs_embeds = inputs_embeds[:, -1:]
109
+
110
+ return {
111
+ "inputs_embeds": inputs_embeds,
112
+ "attention_mask": attention_mask,
113
+ "past_key_values": past_key_values,
114
+ "use_cache": kwargs.get("use_cache"),
115
+ }
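
Editor's note: `prepare_inputs_for_generation` follows the usual `transformers` caching contract: once `past_key_values` holds the keys/values of the earlier positions, only the embedding of the newest position needs to be fed on the next step. A small sanity-check sketch of that invariant with a randomly initialised GPT-2 (the sizes are illustrative assumptions, not the pipeline's real configuration):

```python
# Sketch: cached single-step generation over embeddings matches full recomputation.
import torch
from transformers import GPT2Config, GPT2Model

model = GPT2Model(GPT2Config(n_embd=32, n_layer=2, n_head=2)).eval()
embeds = torch.randn(1, 5, 32)  # five "token" embeddings

with torch.no_grad():
    full = model(inputs_embeds=embeds, use_cache=True)
    # Step-wise: run the first four positions, then feed only the last embedding plus the cache.
    prefix = model(inputs_embeds=embeds[:, :-1], use_cache=True)
    step = model(
        inputs_embeds=embeds[:, -1:],
        past_key_values=prefix.past_key_values,
        use_cache=True,
    )

# Expected: True (last hidden state is identical up to float tolerance).
print(torch.allclose(full.last_hidden_state[:, -1], step.last_hidden_state[:, 0], atol=1e-5))
```
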
116
+
117
+
118
+ class AudioLDM2Pipeline(DiffusionPipeline):
119
+ r"""
120
+ Pipeline for text-to-audio generation using AudioLDM2.
121
+
122
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
123
+ implemented for all pipelines (downloading, saving, running on a particular device, etc.).
124
+
125
+ Args:
126
+ vae ([`AutoencoderKL`]):
127
+ Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations.
128
+ text_encoder ([`~transformers.ClapModel`]):
129
+ First frozen text-encoder. AudioLDM2 uses the joint audio-text embedding model
130
+ [CLAP](https://huggingface.co/docs/transformers/model_doc/clap#transformers.CLAPTextModelWithProjection),
131
+ specifically the [laion/clap-htsat-unfused](https://huggingface.co/laion/clap-htsat-unfused) variant. The
132
+ text branch is used to encode the text prompt to a prompt embedding. The full audio-text model is used to
133
+ rank generated waveforms against the text prompt by computing similarity scores.
134
+ text_encoder_2 ([`~transformers.T5EncoderModel`]):
135
+ Second frozen text-encoder. AudioLDM2 uses the encoder of
136
+ [T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5EncoderModel), specifically the
137
+ [google/flan-t5-large](https://huggingface.co/google/flan-t5-large) variant.
138
+ projection_model ([`AudioLDM2ProjectionModel`]):
139
+ A trained model used to linearly project the hidden-states from the first and second text encoder models
140
+ and insert learned SOS and EOS token embeddings. The projected hidden-states from the two text encoders are
141
+ concatenated to give the input to the language model.
142
+ language_model ([`~transformers.GPT2Model`]):
143
+ An auto-regressive language model used to generate a sequence of hidden-states conditioned on the projected
144
+ outputs from the two text encoders.
145
+ tokenizer ([`~transformers.RobertaTokenizer`]):
146
+ Tokenizer to tokenize text for the first frozen text-encoder.
147
+ tokenizer_2 ([`~transformers.T5Tokenizer`]):
148
+ Tokenizer to tokenize text for the second frozen text-encoder.
149
+ feature_extractor ([`~transformers.ClapFeatureExtractor`]):
150
+ Feature extractor to pre-process generated audio waveforms to log-mel spectrograms for automatic scoring.
151
+ unet ([`UNet2DConditionModel`]):
152
+ A `UNet2DConditionModel` to denoise the encoded audio latents.
153
+ scheduler ([`SchedulerMixin`]):
154
+ A scheduler to be used in combination with `unet` to denoise the encoded audio latents. Can be one of
155
+ [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
156
+ vocoder ([`~transformers.SpeechT5HifiGan`]):
157
+ Vocoder of class `SpeechT5HifiGan` to convert the mel-spectrogram latents to the final audio waveform.
158
+ """
159
+
160
+ def __init__(
161
+ self,
162
+ vae: AutoencoderKL,
163
+ text_encoder: ClapModel,
164
+ text_encoder_2: T5EncoderModel,
165
+ projection_model: AudioLDM2ProjectionModel,
166
+ language_model: GPT2Model,
167
+ tokenizer: Union[RobertaTokenizer, RobertaTokenizerFast],
168
+ tokenizer_2: Union[T5Tokenizer, T5TokenizerFast],
169
+ feature_extractor: ClapFeatureExtractor,
170
+ unet: AudioLDM2UNet2DConditionModel,
171
+ scheduler: KarrasDiffusionSchedulers,
172
+ vocoder: SpeechT5HifiGan,
173
+ ):
174
+ super().__init__()
175
+
176
+ self.register_modules(
177
+ vae=vae,
178
+ text_encoder=text_encoder,
179
+ text_encoder_2=text_encoder_2,
180
+ projection_model=projection_model,
181
+ language_model=language_model,
182
+ tokenizer=tokenizer,
183
+ tokenizer_2=tokenizer_2,
184
+ feature_extractor=feature_extractor,
185
+ unet=unet,
186
+ scheduler=scheduler,
187
+ vocoder=vocoder,
188
+ )
189
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
190
+
191
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_slicing
192
+ def enable_vae_slicing(self):
193
+ r"""
194
+ Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
195
+ compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
196
+ """
197
+ self.vae.enable_slicing()
198
+
199
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_slicing
200
+ def disable_vae_slicing(self):
201
+ r"""
202
+ Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
203
+ computing decoding in one step.
204
+ """
205
+ self.vae.disable_slicing()
206
+
207
+ def enable_model_cpu_offload(self, gpu_id=0):
208
+ r"""
209
+ Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared
210
+ to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward`
211
+ method is called, and the model remains on the GPU until the next model runs. Memory savings are lower than with
212
+ `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`.
213
+ """
214
+ if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"):
215
+ from accelerate import cpu_offload_with_hook
216
+ else:
217
+ raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.")
218
+
219
+ device = torch.device(f"cuda:{gpu_id}")
220
+
221
+ if self.device.type != "cpu":
222
+ self.to("cpu", silence_dtype_warnings=True)
223
+ torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist)
224
+
225
+ model_sequence = [
226
+ self.text_encoder.text_model,
227
+ self.text_encoder.text_projection,
228
+ self.text_encoder_2,
229
+ self.projection_model,
230
+ self.language_model,
231
+ self.unet,
232
+ self.vae,
233
+ self.vocoder,
234
+ self.text_encoder,
235
+ ]
236
+
237
+ hook = None
238
+ for cpu_offloaded_model in model_sequence:
239
+ _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook)
240
+
241
+ # We'll offload the last model manually.
242
+ self.final_offload_hook = hook
243
+
244
+ def generate_language_model(
245
+ self,
246
+ inputs_embeds: torch.Tensor = None,
247
+ max_new_tokens: int = 8,
248
+ **model_kwargs,
249
+ ):
250
+ """
251
+
252
+ Generates a sequence of hidden-states from the language model, conditioned on the embedding inputs.
253
+
254
+ Parameters:
255
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
256
+ The sequence used as a prompt for the generation.
257
+ max_new_tokens (`int`):
258
+ Number of new tokens to generate.
259
+ model_kwargs (`Dict[str, Any]`, *optional*):
260
+ Ad hoc parametrization of additional model-specific kwargs that will be forwarded to the `forward`
261
+ function of the model.
262
+
263
+ Return:
264
+ `inputs_embeds` (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
265
+ The sequence of generated hidden-states.
266
+ """
267
+ max_new_tokens = max_new_tokens if max_new_tokens is not None else self.language_model.config.max_new_tokens
268
+ for _ in range(max_new_tokens):
269
+ # prepare model inputs
270
+ model_inputs = prepare_inputs_for_generation(inputs_embeds, **model_kwargs)
271
+
272
+ # forward pass to get next hidden states
273
+ output = self.language_model(**model_inputs, return_dict=True)
274
+
275
+ next_hidden_states = output.last_hidden_state
276
+
277
+ # Update the model input
278
+ inputs_embeds = torch.cat([inputs_embeds, next_hidden_states[:, -1:, :]], dim=1)
279
+
280
+ # Update generated hidden states, model inputs, and length for next step
281
+ model_kwargs = self.language_model._update_model_kwargs_for_generation(output, model_kwargs)
282
+
283
+ return inputs_embeds[:, -max_new_tokens:, :]
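
Editor's note: `generate_language_model` never samples token ids; it autoregressively feeds the model's last hidden state back in as the next input embedding and finally returns only the `max_new_tokens` newly generated vectors. A self-contained toy version of that loop (randomly initialised GPT-2, illustrative sizes, no KV cache for brevity):

```python
# Toy version of the hidden-state generation loop used above (illustrative only).
import torch
from transformers import GPT2Config, GPT2Model

lm = GPT2Model(GPT2Config(n_embd=32, n_layer=2, n_head=2)).eval()

def generate_hidden_states(inputs_embeds, max_new_tokens=8):
    with torch.no_grad():
        for _ in range(max_new_tokens):
            # Take the last position's hidden state and append it as the next "token".
            last_hidden = lm(inputs_embeds=inputs_embeds).last_hidden_state[:, -1:, :]
            inputs_embeds = torch.cat([inputs_embeds, last_hidden], dim=1)
    # Only the newly generated vectors are returned, as in the method above.
    return inputs_embeds[:, -max_new_tokens:, :]

prompt_embeds = torch.randn(2, 5, 32)  # projected CLAP+T5 embeddings in the real pipeline
print(generate_hidden_states(prompt_embeds).shape)  # torch.Size([2, 8, 32])
```
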
284
+
285
+ def encode_prompt(
286
+ self,
287
+ prompt,
288
+ device,
289
+ num_waveforms_per_prompt,
290
+ do_classifier_free_guidance,
291
+ negative_prompt=None,
292
+ prompt_embeds: Optional[torch.FloatTensor] = None,
293
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
294
+ generated_prompt_embeds: Optional[torch.FloatTensor] = None,
295
+ negative_generated_prompt_embeds: Optional[torch.FloatTensor] = None,
296
+ attention_mask: Optional[torch.LongTensor] = None,
297
+ negative_attention_mask: Optional[torch.LongTensor] = None,
298
+ max_new_tokens: Optional[int] = None,
299
+ ):
300
+ r"""
301
+ Encodes the prompt into text encoder hidden states.
302
+
303
+ Args:
304
+ prompt (`str` or `List[str]`, *optional*):
305
+ prompt to be encoded
306
+ device (`torch.device`):
307
+ torch device
308
+ num_waveforms_per_prompt (`int`):
309
+ number of waveforms that should be generated per prompt
310
+ do_classifier_free_guidance (`bool`):
311
+ whether to use classifier free guidance or not
312
+ negative_prompt (`str` or `List[str]`, *optional*):
313
+ The prompt or prompts not to guide the audio generation. If not defined, one has to pass
314
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
315
+ less than `1`).
316
+ prompt_embeds (`torch.FloatTensor`, *optional*):
317
+ Pre-computed text embeddings from the Flan T5 model. Can be used to easily tweak text inputs, *e.g.*
318
+ prompt weighting. If not provided, text embeddings will be computed from `prompt` input argument.
319
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
320
+ Pre-computed negative text embeddings from the Flan T5 model. Can be used to easily tweak text inputs,
321
+ *e.g.* prompt weighting. If not provided, negative_prompt_embeds will be computed from
322
+ `negative_prompt` input argument.
323
+ generated_prompt_embeds (`torch.FloatTensor`, *optional*):
324
+ Pre-generated text embeddings from the GPT2 language model. Can be used to easily tweak text inputs,
325
+ *e.g.* prompt weighting. If not provided, text embeddings will be generated from `prompt` input
326
+ argument.
327
+ negative_generated_prompt_embeds (`torch.FloatTensor`, *optional*):
328
+ Pre-generated negative text embeddings from the GPT2 language model. Can be used to easily tweak text
329
+ inputs, *e.g.* prompt weighting. If not provided, negative_prompt_embeds will be computed from
330
+ `negative_prompt` input argument.
331
+ attention_mask (`torch.LongTensor`, *optional*):
332
+ Pre-computed attention mask to be applied to the `prompt_embeds`. If not provided, attention mask will
333
+ be computed from `prompt` input argument.
334
+ negative_attention_mask (`torch.LongTensor`, *optional*):
335
+ Pre-computed attention mask to be applied to the `negative_prompt_embeds`. If not provided, attention
336
+ mask will be computed from `negative_prompt` input argument.
337
+ max_new_tokens (`int`, *optional*, defaults to None):
338
+ The number of new tokens to generate with the GPT2 language model.
339
+ Returns:
340
+ prompt_embeds (`torch.FloatTensor`):
341
+ Text embeddings from the Flan T5 model.
342
+ attention_mask (`torch.LongTensor`):
343
+ Attention mask to be applied to the `prompt_embeds`.
344
+ generated_prompt_embeds (`torch.FloatTensor`):
345
+ Text embeddings generated from the GPT2 language model.
346
+
347
+ Example:
348
+
349
+ ```python
350
+ >>> import scipy
351
+ >>> import torch
352
+ >>> from diffusers import AudioLDM2Pipeline
353
+
354
+ >>> repo_id = "cvssp/audioldm2"
355
+ >>> pipe = AudioLDM2Pipeline.from_pretrained(repo_id, torch_dtype=torch.float16)
356
+ >>> pipe = pipe.to("cuda")
357
+
358
+ >>> # Get text embedding vectors
359
+ >>> prompt_embeds, attention_mask, generated_prompt_embeds = pipe.encode_prompt(
360
+ ... prompt="Techno music with a strong, upbeat tempo and high melodic riffs",
361
+ ... device="cuda",
362
+ ... do_classifier_free_guidance=True,
363
+ ... )
364
+
365
+ >>> # Pass text embeddings to pipeline for text-conditional audio generation
366
+ >>> audio = pipe(
367
+ ... prompt_embeds=prompt_embeds,
368
+ ... attention_mask=attention_mask,
369
+ ... generated_prompt_embeds=generated_prompt_embeds,
370
+ ... num_inference_steps=200,
371
+ ... audio_length_in_s=10.0,
372
+ ... ).audios[0]
373
+
374
+ >>> # save generated audio sample
375
+ >>> scipy.io.wavfile.write("techno.wav", rate=16000, data=audio)
376
+ ```"""
377
+ if prompt is not None and isinstance(prompt, str):
378
+ batch_size = 1
379
+ elif prompt is not None and isinstance(prompt, list):
380
+ batch_size = len(prompt)
381
+ else:
382
+ batch_size = prompt_embeds.shape[0]
383
+
384
+ # Define tokenizers and text encoders
385
+ tokenizers = [self.tokenizer, self.tokenizer_2]
386
+ text_encoders = [self.text_encoder, self.text_encoder_2]
387
+
388
+ if prompt_embeds is None:
389
+ prompt_embeds_list = []
390
+ attention_mask_list = []
391
+
392
+ for tokenizer, text_encoder in zip(tokenizers, text_encoders):
393
+ text_inputs = tokenizer(
394
+ prompt,
395
+ padding="max_length",
396
+ max_length=tokenizer.model_max_length,
397
+ truncation=True,
398
+ return_tensors="pt",
399
+ )
400
+ text_input_ids = text_inputs.input_ids
401
+ attention_mask = text_inputs.attention_mask
402
+ untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
403
+
404
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
405
+ text_input_ids, untruncated_ids
406
+ ):
407
+ removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer.model_max_length - 1 : -1])
408
+ logger.warning(
409
+ f"The following part of your input was truncated because {text_encoder.config.model_type} can "
410
+ f"only handle sequences up to {tokenizer.model_max_length} tokens: {removed_text}"
411
+ )
412
+
413
+ text_input_ids = text_input_ids.to(device)
414
+ attention_mask = attention_mask.to(device)
415
+
416
+ if text_encoder.config.model_type == "clap":
417
+ prompt_embeds = text_encoder.get_text_features(
418
+ text_input_ids,
419
+ attention_mask=attention_mask,
420
+ )
421
+ # append the seq-len dim: (bs, hidden_size) -> (bs, seq_len, hidden_size)
422
+ prompt_embeds = prompt_embeds[:, None, :]
423
+ # make sure that we attend to this single hidden-state
424
+ attention_mask = attention_mask.new_ones((batch_size, 1))
425
+ else:
426
+ prompt_embeds = text_encoder(
427
+ text_input_ids,
428
+ attention_mask=attention_mask,
429
+ )
430
+ prompt_embeds = prompt_embeds[0]
431
+
432
+ prompt_embeds_list.append(prompt_embeds)
433
+ attention_mask_list.append(attention_mask)
434
+ projection_output = self.projection_model(
435
+ hidden_states=prompt_embeds_list[0],
436
+ hidden_states_1=prompt_embeds_list[1],
437
+ attention_mask=attention_mask_list[0],
438
+ attention_mask_1=attention_mask_list[1],
439
+ )
440
+ projected_prompt_embeds = projection_output.hidden_states
441
+ projected_attention_mask = projection_output.attention_mask
442
+
443
+ generated_prompt_embeds = self.generate_language_model(
444
+ projected_prompt_embeds,
445
+ attention_mask=projected_attention_mask,
446
+ max_new_tokens=max_new_tokens,
447
+ )
448
+
449
+ prompt_embeds = prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)
450
+ attention_mask = (
451
+ attention_mask.to(device=device)
452
+ if attention_mask is not None
453
+ else torch.ones(prompt_embeds.shape[:2], dtype=torch.long, device=device)
454
+ )
455
+ generated_prompt_embeds = generated_prompt_embeds.to(dtype=self.language_model.dtype, device=device)
456
+
457
+ bs_embed, seq_len, hidden_size = prompt_embeds.shape
458
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
459
+ prompt_embeds = prompt_embeds.repeat(1, num_waveforms_per_prompt, 1)
460
+ prompt_embeds = prompt_embeds.view(bs_embed * num_waveforms_per_prompt, seq_len, hidden_size)
461
+
462
+ # duplicate attention mask for each generation per prompt
463
+ attention_mask = attention_mask.repeat(1, num_waveforms_per_prompt)
464
+ attention_mask = attention_mask.view(bs_embed * num_waveforms_per_prompt, seq_len)
465
+
466
+ bs_embed, seq_len, hidden_size = generated_prompt_embeds.shape
467
+ # duplicate generated embeddings for each generation per prompt, using mps friendly method
468
+ generated_prompt_embeds = generated_prompt_embeds.repeat(1, num_waveforms_per_prompt, 1)
469
+ generated_prompt_embeds = generated_prompt_embeds.view(
470
+ bs_embed * num_waveforms_per_prompt, seq_len, hidden_size
471
+ )
472
+
473
+ # get unconditional embeddings for classifier free guidance
474
+ if do_classifier_free_guidance and negative_prompt_embeds is None:
475
+ uncond_tokens: List[str]
476
+ if negative_prompt is None:
477
+ uncond_tokens = [""] * batch_size
478
+ elif type(prompt) is not type(negative_prompt):
479
+ raise TypeError(
480
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
481
+ f"`negative_prompt` should be the same type as `prompt`, but got {type(negative_prompt)} !="
482
+ )
483
+ elif isinstance(negative_prompt, str):
484
+ uncond_tokens = [negative_prompt]
485
+ elif batch_size != len(negative_prompt):
486
+ raise ValueError(
487
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
488
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
489
+ " the batch size of `prompt`."
490
+ )
491
+ else:
492
+ uncond_tokens = negative_prompt
493
+
494
+ negative_prompt_embeds_list = []
495
+ negative_attention_mask_list = []
496
+ max_length = prompt_embeds.shape[1]
497
+ for tokenizer, text_encoder in zip(tokenizers, text_encoders):
498
+ uncond_input = tokenizer(
499
+ uncond_tokens,
500
+ padding="max_length",
501
+ max_length=tokenizer.model_max_length
502
+ if isinstance(tokenizer, (RobertaTokenizer, RobertaTokenizerFast))
503
+ else max_length,
504
+ truncation=True,
505
+ return_tensors="pt",
506
+ )
507
+
508
+ uncond_input_ids = uncond_input.input_ids.to(device)
509
+ negative_attention_mask = uncond_input.attention_mask.to(device)
510
+
511
+ if text_encoder.config.model_type == "clap":
512
+ negative_prompt_embeds = text_encoder.get_text_features(
513
+ uncond_input_ids,
514
+ attention_mask=negative_attention_mask,
515
+ )
516
+ # append the seq-len dim: (bs, hidden_size) -> (bs, seq_len, hidden_size)
517
+ negative_prompt_embeds = negative_prompt_embeds[:, None, :]
518
+ # make sure that we attend to this single hidden-state
519
+ negative_attention_mask = negative_attention_mask.new_ones((batch_size, 1))
520
+ else:
521
+ negative_prompt_embeds = text_encoder(
522
+ uncond_input_ids,
523
+ attention_mask=negative_attention_mask,
524
+ )
525
+ negative_prompt_embeds = negative_prompt_embeds[0]
526
+
527
+ negative_prompt_embeds_list.append(negative_prompt_embeds)
528
+ negative_attention_mask_list.append(negative_attention_mask)
529
+
530
+ projection_output = self.projection_model(
531
+ hidden_states=negative_prompt_embeds_list[0],
532
+ hidden_states_1=negative_prompt_embeds_list[1],
533
+ attention_mask=negative_attention_mask_list[0],
534
+ attention_mask_1=negative_attention_mask_list[1],
535
+ )
536
+ negative_projected_prompt_embeds = projection_output.hidden_states
537
+ negative_projected_attention_mask = projection_output.attention_mask
538
+
539
+ negative_generated_prompt_embeds = self.generate_language_model(
540
+ negative_projected_prompt_embeds,
541
+ attention_mask=negative_projected_attention_mask,
542
+ max_new_tokens=max_new_tokens,
543
+ )
544
+
545
+ if do_classifier_free_guidance:
546
+ seq_len = negative_prompt_embeds.shape[1]
547
+
548
+ negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)
549
+ negative_attention_mask = (
550
+ negative_attention_mask.to(device=device)
551
+ if negative_attention_mask is not None
552
+ else torch.ones(negative_prompt_embeds.shape[:2], dtype=torch.long, device=device)
553
+ )
554
+ negative_generated_prompt_embeds = negative_generated_prompt_embeds.to(
555
+ dtype=self.language_model.dtype, device=device
556
+ )
557
+
558
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
559
+ negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_waveforms_per_prompt, 1)
560
+ negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_waveforms_per_prompt, seq_len, -1)
561
+
562
+ # duplicate unconditional attention mask for each generation per prompt
563
+ negative_attention_mask = negative_attention_mask.repeat(1, num_waveforms_per_prompt)
564
+ negative_attention_mask = negative_attention_mask.view(batch_size * num_waveforms_per_prompt, seq_len)
565
+
566
+ # duplicate unconditional generated embeddings for each generation per prompt
567
+ seq_len = negative_generated_prompt_embeds.shape[1]
568
+ negative_generated_prompt_embeds = negative_generated_prompt_embeds.repeat(1, num_waveforms_per_prompt, 1)
569
+ negative_generated_prompt_embeds = negative_generated_prompt_embeds.view(
570
+ batch_size * num_waveforms_per_prompt, seq_len, -1
571
+ )
572
+
573
+ # For classifier free guidance, we need to do two forward passes.
574
+ # Here we concatenate the unconditional and text embeddings into a single batch
575
+ # to avoid doing two forward passes
576
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
577
+ attention_mask = torch.cat([negative_attention_mask, attention_mask])
578
+ generated_prompt_embeds = torch.cat([negative_generated_prompt_embeds, generated_prompt_embeds])
579
+
580
+ return prompt_embeds, attention_mask, generated_prompt_embeds
581
+
582
+ # Copied from diffusers.pipelines.audioldm.pipeline_audioldm.AudioLDMPipeline.mel_spectrogram_to_waveform
583
+ def mel_spectrogram_to_waveform(self, mel_spectrogram):
584
+ if mel_spectrogram.dim() == 4:
585
+ mel_spectrogram = mel_spectrogram.squeeze(1)
586
+
587
+ waveform = self.vocoder(mel_spectrogram)
588
+ # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
589
+ waveform = waveform.cpu().float()
590
+ return waveform
591
+
592
+ def score_waveforms(self, text, audio, num_waveforms_per_prompt, device, dtype):
593
+ if not is_librosa_available():
594
+ logger.info(
595
+ "Automatic scoring of the generated audio waveforms against the input prompt text requires the "
596
+ "`librosa` package to resample the generated waveforms. Returning the audios in the order they were "
597
+ "generated. To enable automatic scoring, install `librosa` with: `pip install librosa`."
598
+ )
599
+ return audio
600
+ inputs = self.tokenizer(text, return_tensors="pt", padding=True)
601
+ resampled_audio = librosa.resample(
602
+ audio.numpy(), orig_sr=self.vocoder.config.sampling_rate, target_sr=self.feature_extractor.sampling_rate
603
+ )
604
+ inputs["input_features"] = self.feature_extractor(
605
+ list(resampled_audio), return_tensors="pt", sampling_rate=self.feature_extractor.sampling_rate
606
+ ).input_features.type(dtype)
607
+ inputs = inputs.to(device)
608
+
609
+ # compute the audio-text similarity score using the CLAP model
610
+ logits_per_text = self.text_encoder(**inputs).logits_per_text
611
+ # sort by the highest matching generations per prompt
612
+ indices = torch.argsort(logits_per_text, dim=1, descending=True)[:, :num_waveforms_per_prompt]
613
+ audio = torch.index_select(audio, 0, indices.reshape(-1).cpu())
614
+ return audio
615
+
616
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
617
+ def prepare_extra_step_kwargs(self, generator, eta):
618
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
619
+ # eta (Ξ·) is only used with the DDIMScheduler, it will be ignored for other schedulers.
620
+ # eta corresponds to Ξ· in DDIM paper: https://arxiv.org/abs/2010.02502
621
+ # and should be between [0, 1]
622
+
623
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
624
+ extra_step_kwargs = {}
625
+ if accepts_eta:
626
+ extra_step_kwargs["eta"] = eta
627
+
628
+ # check if the scheduler accepts generator
629
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
630
+ if accepts_generator:
631
+ extra_step_kwargs["generator"] = generator
632
+ return extra_step_kwargs
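
Editor's note: the `inspect.signature` checks above implement a generic pattern: forward `eta` and `generator` only to schedulers whose `step` method actually accepts them. A minimal standalone illustration of the same pattern (the two `*_step` functions here are made up for the example):

```python
# Sketch of the signature-filtering pattern used above (example functions are hypothetical).
import inspect

def ddim_like_step(model_output, timestep, sample, eta=0.0, generator=None):
    return sample

def pndm_like_step(model_output, timestep, sample):
    return sample

def build_extra_kwargs(step_fn, eta, generator):
    params = set(inspect.signature(step_fn).parameters)
    extra = {}
    if "eta" in params:
        extra["eta"] = eta
    if "generator" in params:
        extra["generator"] = generator
    return extra

print(build_extra_kwargs(ddim_like_step, 0.0, None))  # {'eta': 0.0, 'generator': None}
print(build_extra_kwargs(pndm_like_step, 0.0, None))  # {}
```
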
633
+
634
+ def check_inputs(
635
+ self,
636
+ prompt,
637
+ audio_length_in_s,
638
+ vocoder_upsample_factor,
639
+ callback_steps,
640
+ negative_prompt=None,
641
+ prompt_embeds=None,
642
+ negative_prompt_embeds=None,
643
+ generated_prompt_embeds=None,
644
+ negative_generated_prompt_embeds=None,
645
+ attention_mask=None,
646
+ negative_attention_mask=None,
647
+ ):
648
+ min_audio_length_in_s = vocoder_upsample_factor * self.vae_scale_factor
649
+ if audio_length_in_s < min_audio_length_in_s:
650
+ raise ValueError(
651
+ f"`audio_length_in_s` has to be a positive value greater than or equal to {min_audio_length_in_s}, but "
652
+ f"is {audio_length_in_s}."
653
+ )
654
+
655
+ if self.vocoder.config.model_in_dim % self.vae_scale_factor != 0:
656
+ raise ValueError(
657
+ f"The number of frequency bins in the vocoder's log-mel spectrogram has to be divisible by the "
658
+ f"VAE scale factor, but got {self.vocoder.config.model_in_dim} bins and a scale factor of "
659
+ f"{self.vae_scale_factor}."
660
+ )
661
+
662
+ if (callback_steps is None) or (
663
+ callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
664
+ ):
665
+ raise ValueError(
666
+ f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
667
+ f" {type(callback_steps)}."
668
+ )
669
+
670
+ if prompt is not None and prompt_embeds is not None:
671
+ raise ValueError(
672
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
673
+ " only forward one of the two."
674
+ )
675
+ elif prompt is None and (prompt_embeds is None or generated_prompt_embeds is None):
676
+ raise ValueError(
677
+ "Provide either `prompt`, or `prompt_embeds` and `generated_prompt_embeds`. Cannot leave "
678
+ "`prompt` undefined without specifying both `prompt_embeds` and `generated_prompt_embeds`."
679
+ )
680
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
681
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
682
+
683
+ if negative_prompt is not None and negative_prompt_embeds is not None:
684
+ raise ValueError(
685
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
686
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
687
+ )
688
+ elif negative_prompt_embeds is not None and negative_generated_prompt_embeds is None:
689
+ raise ValueError(
690
+ "Cannot forward `negative_prompt_embeds` without `negative_generated_prompt_embeds`. Ensure that"
691
+ " both arguments are specified."
692
+ )
693
+
694
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
695
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
696
+ raise ValueError(
697
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
698
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
699
+ f" {negative_prompt_embeds.shape}."
700
+ )
701
+ if attention_mask is not None and attention_mask.shape != prompt_embeds.shape[:2]:
702
+ raise ValueError(
703
+ "`attention_mask` should have the same batch size and sequence length as `prompt_embeds`, but got: "
704
+ f"`attention_mask`: {attention_mask.shape} != `prompt_embeds`: {prompt_embeds.shape}"
705
+ )
706
+
707
+ if generated_prompt_embeds is not None and negative_generated_prompt_embeds is not None:
708
+ if generated_prompt_embeds.shape != negative_generated_prompt_embeds.shape:
709
+ raise ValueError(
710
+ "`generated_prompt_embeds` and `negative_generated_prompt_embeds` must have the same shape when "
711
+ f"passed directly, but got: `generated_prompt_embeds` {generated_prompt_embeds.shape} != "
712
+ f"`negative_generated_prompt_embeds` {negative_generated_prompt_embeds.shape}."
713
+ )
714
+ if (
715
+ negative_attention_mask is not None
716
+ and negative_attention_mask.shape != negative_prompt_embeds.shape[:2]
717
+ ):
718
+ raise ValueError(
719
+ "`negative_attention_mask` should have the same batch size and sequence length as `negative_prompt_embeds`, but got: "
720
+ f"`negative_attention_mask`: {negative_attention_mask.shape} != `negative_prompt_embeds`: {negative_prompt_embeds.shape}"
721
+ )
722
+
723
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents with width->self.vocoder.config.model_in_dim
724
+ def prepare_latents(self, batch_size, num_channels_latents, height, dtype, device, generator, latents=None):
725
+ shape = (
726
+ batch_size,
727
+ num_channels_latents,
728
+ height // self.vae_scale_factor,
729
+ self.vocoder.config.model_in_dim // self.vae_scale_factor,
730
+ )
731
+ if isinstance(generator, list) and len(generator) != batch_size:
732
+ raise ValueError(
733
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
734
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
735
+ )
736
+
737
+ if latents is None:
738
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
739
+ else:
740
+ latents = latents.to(device)
741
+
742
+ # scale the initial noise by the standard deviation required by the scheduler
743
+ latents = latents * self.scheduler.init_noise_sigma
744
+ return latents
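
Editor's note: a quick worked example of the latent shape built above, together with the seconds-to-mel-frames conversion performed at the start of `__call__` further below. The config values (16 kHz vocoder, upsample-rate product 160, `vae_scale_factor = 8`, 64 mel bins, 8 latent channels) are assumptions chosen for illustration, not values read from a released checkpoint:

```python
# Worked example of the spectrogram-height and latent-shape arithmetic (assumed config values).
import numpy as np

sampling_rate = 16_000
upsample_rates = (5, 4, 2, 2, 2)   # product = 160 audio samples per mel frame (assumed)
vae_scale_factor = 8               # 2 ** (len(block_out_channels) - 1), assumed
model_in_dim = 64                  # mel bins expected by the vocoder (assumed)
num_channels_latents = 8
batch_size = 1

vocoder_upsample_factor = np.prod(upsample_rates) / sampling_rate  # 0.01 s of audio per mel frame
audio_length_in_s = 10.0
height = int(audio_length_in_s / vocoder_upsample_factor)          # 1000 mel frames
if height % vae_scale_factor != 0:                                  # round up to a multiple of 8
    height = int(np.ceil(height / vae_scale_factor)) * vae_scale_factor

latent_shape = (
    batch_size,
    num_channels_latents,
    height // vae_scale_factor,
    model_in_dim // vae_scale_factor,
)
print(height, latent_shape)  # 1000 (1, 8, 125, 8)
```
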
745
+
746
+ @torch.no_grad()
747
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
748
+ def __call__(
749
+ self,
750
+ prompt: Union[str, List[str]] = None,
751
+ audio_length_in_s: Optional[float] = None,
752
+ num_inference_steps: int = 200,
753
+ guidance_scale: float = 3.5,
754
+ negative_prompt: Optional[Union[str, List[str]]] = None,
755
+ num_waveforms_per_prompt: Optional[int] = 1,
756
+ eta: float = 0.0,
757
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
758
+ latents: Optional[torch.FloatTensor] = None,
759
+ prompt_embeds: Optional[torch.FloatTensor] = None,
760
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
761
+ generated_prompt_embeds: Optional[torch.FloatTensor] = None,
762
+ negative_generated_prompt_embeds: Optional[torch.FloatTensor] = None,
763
+ attention_mask: Optional[torch.LongTensor] = None,
764
+ negative_attention_mask: Optional[torch.LongTensor] = None,
765
+ max_new_tokens: Optional[int] = None,
766
+ return_dict: bool = True,
767
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
768
+ callback_steps: Optional[int] = 1,
769
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
770
+ output_type: Optional[str] = "np",
771
+ return_prompts_only: Optional[bool] = False
772
+ ):
773
+ r"""
774
+ The call function to the pipeline for generation.
775
+
776
+ Args:
777
+ prompt (`str` or `List[str]`, *optional*):
778
+ The prompt or prompts to guide audio generation. If not defined, you need to pass `prompt_embeds`.
779
+ audio_length_in_s (`float`, *optional*, defaults to 10.24):
780
+ The length of the generated audio sample in seconds.
781
+ num_inference_steps (`int`, *optional*, defaults to 200):
782
+ The number of denoising steps. More denoising steps usually lead to a higher quality audio at the
783
+ expense of slower inference.
784
+ guidance_scale (`float`, *optional*, defaults to 3.5):
785
+ A higher guidance scale value encourages the model to generate audio that is closely linked to the text
786
+ `prompt` at the expense of lower sound quality. Guidance scale is enabled when `guidance_scale > 1`.
787
+ negative_prompt (`str` or `List[str]`, *optional*):
788
+ The prompt or prompts to guide what to not include in audio generation. If not defined, you need to
789
+ pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`).
790
+ num_waveforms_per_prompt (`int`, *optional*, defaults to 1):
791
+ The number of waveforms to generate per prompt. If `num_waveforms_per_prompt > 1`, then automatic
792
+ scoring is performed between the generated outputs and the text prompt. This scoring ranks the
793
+ generated waveforms based on their cosine similarity with the text input in the joint text-audio
794
+ embedding space.
795
+ eta (`float`, *optional*, defaults to 0.0):
796
+ Corresponds to parameter eta (Ξ·) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies
797
+ to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
798
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
799
+ A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
800
+ generation deterministic.
801
+ latents (`torch.FloatTensor`, *optional*):
802
+ Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for spectrogram
803
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
804
+ tensor is generated by sampling using the supplied random `generator`.
805
+ prompt_embeds (`torch.FloatTensor`, *optional*):
806
+ Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
807
+ provided, text embeddings are generated from the `prompt` input argument.
808
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
809
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
810
+ not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
811
+ generated_prompt_embeds (`torch.FloatTensor`, *optional*):
812
+ Pre-generated text embeddings from the GPT2 language model. Can be used to easily tweak text inputs,
813
+ *e.g.* prompt weighting. If not provided, text embeddings will be generated from `prompt` input
814
+ argument.
815
+ negative_generated_prompt_embeds (`torch.FloatTensor`, *optional*):
816
+ Pre-generated negative text embeddings from the GPT2 language model. Can be used to easily tweak text
817
+ inputs, *e.g.* prompt weighting. If not provided, negative_prompt_embeds will be computed from
818
+ `negative_prompt` input argument.
819
+ attention_mask (`torch.LongTensor`, *optional*):
820
+ Pre-computed attention mask to be applied to the `prompt_embeds`. If not provided, attention mask will
821
+ be computed from `prompt` input argument.
822
+ negative_attention_mask (`torch.LongTensor`, *optional*):
823
+ Pre-computed attention mask to be applied to the `negative_prompt_embeds`. If not provided, attention
824
+ mask will be computed from `negative_prompt` input argument.
825
+ max_new_tokens (`int`, *optional*, defaults to None):
826
+ Number of new tokens to generate with the GPT2 language model. If not provided, the number of tokens will
827
+ be taken from the config of the model.
828
+ return_dict (`bool`, *optional*, defaults to `True`):
829
+ Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
830
+ plain tuple.
831
+ callback (`Callable`, *optional*):
832
+ A function that is called every `callback_steps` steps during inference. The function is called with the
833
+ following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
834
+ callback_steps (`int`, *optional*, defaults to 1):
835
+ The frequency at which the `callback` function is called. If not specified, the callback is called at
836
+ every step.
837
+ cross_attention_kwargs (`dict`, *optional*):
838
+ A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in
839
+ [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
840
+ output_type (`str`, *optional*, defaults to `"np"`):
841
+ The output format of the generated audio. Choose between `"np"` to return a NumPy `np.ndarray` or
842
+ `"pt"` to return a PyTorch `torch.Tensor` object. Set to `"latent"` to return the latent diffusion
843
+ model (LDM) output.
844
+
845
+ Examples:
846
+
847
+ Returns:
848
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
849
+ If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned,
850
+ otherwise a `tuple` is returned where the first element is a list with the generated audio.
851
+ """
852
+ # 0. Convert audio input length from seconds to spectrogram height
853
+ vocoder_upsample_factor = np.prod(self.vocoder.config.upsample_rates) / self.vocoder.config.sampling_rate
854
+
855
+ if audio_length_in_s is None:
856
+ audio_length_in_s = self.unet.config.sample_size * self.vae_scale_factor * vocoder_upsample_factor
857
+
858
+ height = int(audio_length_in_s / vocoder_upsample_factor)
859
+
860
+ original_waveform_length = int(audio_length_in_s * self.vocoder.config.sampling_rate)
861
+ if height % self.vae_scale_factor != 0:
862
+ height = int(np.ceil(height / self.vae_scale_factor)) * self.vae_scale_factor
863
+ logger.info(
864
+ f"Audio length in seconds {audio_length_in_s} is increased to {height * vocoder_upsample_factor} "
865
+ f"so that it can be handled by the model. It will be cut to {audio_length_in_s} after the "
866
+ f"denoising process."
867
+ )
868
+
869
+ # 1. Check inputs. Raise error if not correct
870
+ self.check_inputs(
871
+ prompt,
872
+ audio_length_in_s,
873
+ vocoder_upsample_factor,
874
+ callback_steps,
875
+ negative_prompt,
876
+ prompt_embeds,
877
+ negative_prompt_embeds,
878
+ generated_prompt_embeds,
879
+ negative_generated_prompt_embeds,
880
+ attention_mask,
881
+ negative_attention_mask,
882
+ )
883
+
884
+ # 2. Define call parameters
885
+ if prompt is not None and isinstance(prompt, str):
886
+ batch_size = 1
887
+ elif prompt is not None and isinstance(prompt, list):
888
+ batch_size = len(prompt)
889
+ else:
890
+ batch_size = prompt_embeds.shape[0]
891
+
892
+ device = self._execution_device
893
+ # here `guidance_scale` is defined analogously to the guidance weight `w` of equation (2)
894
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
895
+ # corresponds to doing no classifier free guidance.
896
+ do_classifier_free_guidance = guidance_scale > 1.0
897
+
898
+ # 3. Encode input prompt
899
+ prompt_embeds, attention_mask, generated_prompt_embeds = self.encode_prompt(
900
+ prompt,
901
+ device,
902
+ num_waveforms_per_prompt,
903
+ do_classifier_free_guidance,
904
+ negative_prompt,
905
+ prompt_embeds=prompt_embeds,
906
+ negative_prompt_embeds=negative_prompt_embeds,
907
+ generated_prompt_embeds=generated_prompt_embeds,
908
+ negative_generated_prompt_embeds=negative_generated_prompt_embeds,
909
+ attention_mask=attention_mask,
910
+ negative_attention_mask=negative_attention_mask,
911
+ max_new_tokens=max_new_tokens,
912
+ )
913
+
914
+ if return_prompts_only:
915
+ return prompt_embeds, generated_prompt_embeds
916
+
917
+ # 4. Prepare timesteps
918
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
919
+ timesteps = self.scheduler.timesteps
920
+
921
+ # 5. Prepare latent variables
922
+ num_channels_latents = self.unet.config.in_channels
923
+ latents = self.prepare_latents(
924
+ batch_size * num_waveforms_per_prompt,
925
+ num_channels_latents,
926
+ height,
927
+ prompt_embeds.dtype,
928
+ device,
929
+ generator,
930
+ latents,
931
+ )
932
+
933
+ # 6. Prepare extra step kwargs
934
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
935
+
936
+ # 7. Denoising loop
937
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
938
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
939
+ for i, t in enumerate(timesteps):
940
+ # expand the latents if we are doing classifier free guidance
941
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
942
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
943
+
944
+ # predict the noise residual
945
+ noise_pred = self.unet(
946
+ latent_model_input,
947
+ t,
948
+ encoder_hidden_states=generated_prompt_embeds,
949
+ encoder_hidden_states_1=prompt_embeds,
950
+ encoder_attention_mask_1=attention_mask,
951
+ return_dict=False,
952
+ )[0]
953
+
954
+ # perform guidance
955
+ if do_classifier_free_guidance:
956
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
957
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
958
+
959
+ # compute the previous noisy sample x_t -> x_t-1
960
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
961
+
962
+ # call the callback, if provided
963
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
964
+ progress_bar.update()
965
+ if callback is not None and i % callback_steps == 0:
966
+ step_idx = i // getattr(self.scheduler, "order", 1)
967
+ callback(step_idx, t, latents)
968
+
969
+ self.maybe_free_model_hooks()
970
+
971
+ # 8. Post-processing
972
+ if output_type != "latent":
973
+ latents = 1 / self.vae.config.scaling_factor * latents
974
+ mel_spectrogram = self.vae.decode(latents).sample
975
+ else:
976
+ return AudioPipelineOutput(audios=latents)
977
+
978
+ audio = self.mel_spectrogram_to_waveform(mel_spectrogram)
979
+
980
+ audio = audio[:, :original_waveform_length]
981
+
982
+ # 9. Automatic scoring
983
+ if num_waveforms_per_prompt > 1 and prompt is not None:
984
+ audio = self.score_waveforms(
985
+ text=prompt,
986
+ audio=audio,
987
+ num_waveforms_per_prompt=num_waveforms_per_prompt,
988
+ device=device,
989
+ dtype=prompt_embeds.dtype,
990
+ )
991
+
992
+ if output_type == "np":
993
+ audio = audio.numpy()
994
+
995
+ if not return_dict:
996
+ return (audio,)
997
+
998
+ return AudioPipelineOutput(audios=audio)
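
A minimal usage sketch of the pipeline defined above (illustrative only, not part of the commit). It assumes the `cvssp/audioldm2-music` checkpoint referenced elsewhere in this commit, a CUDA device, and the module path `llama/audioldm2.py` implied by the imports in `llama/m2ugen.py` below; `return_prompts_only=True` is the extension added here, returning the text and GPT2 prompt embeddings instead of running the denoising loop.

import scipy.io.wavfile
import torch

from llama.audioldm2 import AudioLDM2Pipeline  # module path assumed from the imports in llama/m2ugen.py

pipe = AudioLDM2Pipeline.from_pretrained("cvssp/audioldm2-music", torch_dtype=torch.float16).to("cuda")

# Full generation: denoise, decode the mel spectrogram with the VAE, vocode to waveforms.
result = pipe(
    "90s rock song with loud guitars and heavy drums",
    num_inference_steps=200,
    audio_length_in_s=10.24,
    num_waveforms_per_prompt=3,     # ranked against the prompt by the automatic scoring step
    negative_prompt="Low quality.",
)
scipy.io.wavfile.write("sample.wav", rate=pipe.vocoder.config.sampling_rate, data=result.audios[0])

# Embedding-only path added in this commit: stop after step 3 and return the prompt embeddings.
prompt_embeds, generated_prompt_embeds = pipe(
    "90s rock song with loud guitars and heavy drums",
    guidance_scale=1,
    return_prompts_only=True,
)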
llama/llama.py ADDED
@@ -0,0 +1,339 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # This software may be used and distributed according to the terms of the GNU General Public License version 3.
3
+ import torch
4
+ from torch import nn
5
+ from torch.nn import Embedding, Linear
6
+ import torch.nn.functional as F
7
+
8
+ import math
9
+ from dataclasses import dataclass
10
+ from typing import Any, Optional, Tuple
11
+
12
+
13
+ @dataclass
14
+ class ModelArgs:
15
+ dim: int = 4096
16
+ n_layers: int = 32
17
+ n_heads: int = 32
18
+ n_kv_heads: Optional[int] = None
19
+ vocab_size: int = -1 # defined later by tokenizer
20
+ multiple_of: int = 256 # make SwiGLU hidden layer size multiple of large power of 2
21
+ ffn_dim_multiplier: Optional[float] = None
22
+ norm_eps: float = 1e-5
23
+
24
+ max_batch_size: int = 1
25
+ max_seq_len: int = 2048
26
+
27
+ w_bias: bool = True # use bias tuning
28
+ w_lora: bool = True # use lora tuning
29
+ lora_rank: int = 16
30
+
31
+ num_output_tokens: int = 128
32
+ output_dim_tokens: int = 768
33
+ num_gen_audio_tokens: int = 8
34
+
35
+
36
+ class RMSNorm(torch.nn.Module):
37
+ def __init__(self, dim: int, eps: float = 1e-6):
38
+ super().__init__()
39
+ self.eps = eps
40
+ self.weight = nn.Parameter(torch.ones(dim))
41
+
42
+ def _norm(self, x):
43
+ return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
44
+
45
+ def forward(self, x):
46
+ output = self._norm(x.float()).type_as(x)
47
+ return output * self.weight
48
+
49
+
50
+ def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
51
+ freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
52
+ t = torch.arange(end, device=freqs.device) # type: ignore
53
+ freqs = torch.outer(t, freqs).float() # type: ignore
54
+ freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64
55
+ return freqs_cis
56
+
57
+
58
+ def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
59
+ ndim = x.ndim
60
+ assert 0 <= 1 < ndim
61
+ assert freqs_cis.shape == (x.shape[1], x.shape[-1])
62
+ shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
63
+ return freqs_cis.view(*shape)
64
+
65
+
66
+ def apply_rotary_emb(
67
+ xq: torch.Tensor,
68
+ xk: torch.Tensor,
69
+ freqs_cis: torch.Tensor,
70
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
71
+ xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
72
+ xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
73
+ freqs_cis = reshape_for_broadcast(freqs_cis, xq_)
74
+ xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
75
+ xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
76
+ return xq_out.type_as(xq), xk_out.type_as(xk)
77
+
78
+
79
+ def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
80
+ """torch.repeat_interleave(x, dim=2, repeats=n_rep)"""
81
+ bs, slen, n_kv_heads, head_dim = x.shape
82
+ if n_rep == 1:
83
+ return x
84
+ return (
85
+ x[:, :, :, None, :]
86
+ .expand(bs, slen, n_kv_heads, n_rep, head_dim)
87
+ .reshape(bs, slen, n_kv_heads * n_rep, head_dim)
88
+ )
89
+
90
+
91
+ class Attention(nn.Module):
92
+ def __init__(self, args: ModelArgs):
93
+ super().__init__()
94
+ self.args = args
95
+
96
+ self.n_local_heads = args.n_heads
97
+ self.n_kv_heads = args.n_kv_heads
98
+ self.head_dim = args.dim // args.n_heads
99
+
100
+ self.wq = Linear(
101
+ args.dim,
102
+ args.n_heads * self.head_dim,
103
+ bias=args.w_bias
104
+ )
105
+ self.wk = Linear(
106
+ args.dim,
107
+ args.n_heads * self.head_dim,
108
+ bias=False
109
+ )
110
+ self.wv = Linear(
111
+ args.dim,
112
+ args.n_heads * self.head_dim,
113
+ bias=False
114
+ )
115
+ self.wo = Linear(
116
+ args.n_heads * self.head_dim,
117
+ args.dim,
118
+ bias=args.w_bias
119
+ )
120
+
121
+ if args.w_bias:
122
+ nn.init.constant_(self.wq.bias.data, 0)
123
+ nn.init.constant_(self.wo.bias.data, 0)
124
+
125
+ self.w_lora = args.w_lora
126
+ if args.w_lora:
127
+ self.lora_wq_l1 = Linear(args.dim, args.lora_rank, bias=False)
128
+ self.lora_wq_l2 = Linear(args.lora_rank, args.dim, bias=False)
129
+
130
+ self.lora_wk_l1 = Linear(args.dim, args.lora_rank, bias=False)
131
+ self.lora_wk_l2 = Linear(args.lora_rank, args.dim, bias=False)
132
+
133
+ self.lora_wv_l1 = Linear(args.dim, args.lora_rank, bias=False)
134
+ self.lora_wv_l2 = Linear(args.lora_rank, args.dim, bias=False)
135
+
136
+ self.lora_wo_l1 = Linear(args.dim, args.lora_rank, bias=False)
137
+ self.lora_wo_l2 = Linear(args.lora_rank, args.dim, bias=False)
138
+ nn.init.constant_(self.lora_wq_l2.weight.data, 0)
139
+ nn.init.constant_(self.lora_wk_l2.weight.data, 0)
140
+ nn.init.constant_(self.lora_wv_l2.weight.data, 0)
141
+ nn.init.constant_(self.lora_wo_l2.weight.data, 0)
142
+
143
+ self.cache_k = None
144
+ self.cache_v = None
145
+
146
+ self.gate = torch.nn.Parameter(torch.zeros(1, self.n_local_heads, 1, 1))
147
+
148
+ def train(self, mode: bool = True):
149
+ if mode:
150
+ self.cache_k = None
151
+ self.cache_v = None
152
+ else:
153
+ self.cache_k = torch.zeros(
154
+ (self.args.max_batch_size, self.args.max_seq_len, self.n_local_heads, self.head_dim)
155
+ ).cuda()
156
+ self.cache_v = torch.zeros(
157
+ (self.args.max_batch_size, self.args.max_seq_len, self.n_local_heads, self.head_dim)
158
+ ).cuda()
159
+ return super().train(mode)
160
+
161
+ def forward(self, x: torch.Tensor, start_pos: int, freqs_cis: torch.Tensor, mask: Optional[torch.Tensor],
162
+ adapter=None):
163
+ bsz, seqlen, _ = x.shape
164
+ xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
165
+ if self.w_lora:
166
+ xq = xq + self.lora_wq_l2(self.lora_wq_l1(x))
167
+ xk = xk + self.lora_wk_l2(self.lora_wk_l1(x))
168
+ xv = xv + self.lora_wv_l2(self.lora_wv_l1(x))
169
+
170
+ xq = xq.view(bsz, seqlen, self.n_local_heads, self.head_dim)
171
+ xk = xk.view(bsz, seqlen, self.n_local_heads, self.head_dim)
172
+ xv = xv.view(bsz, seqlen, self.n_local_heads, self.head_dim)
173
+
174
+ xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis)
175
+
176
+ if not self.training:
177
+ self.cache_k = self.cache_k.to(xq)
178
+ self.cache_v = self.cache_v.to(xq)
179
+
180
+ self.cache_k[:bsz, start_pos: start_pos + seqlen] = xk
181
+ self.cache_v[:bsz, start_pos: start_pos + seqlen] = xv
182
+
183
+ keys = self.cache_k[:bsz, : start_pos + seqlen]
184
+ values = self.cache_v[:bsz, : start_pos + seqlen]
185
+ else:
186
+ assert start_pos == 0
187
+ keys = xk
188
+ values = xv
189
+
190
+ if adapter is not None:
191
+ adapter_len = adapter.shape[1]
192
+ adapter_v = self.wv(adapter).view(bsz, adapter_len, self.n_local_heads, self.head_dim)
193
+ adapter_v = adapter_v.transpose(1, 2)
194
+
195
+ if adapter_len > 1:
196
+ adapter_k = self.wk(adapter).view(bsz, adapter_len, self.n_local_heads, self.head_dim)
197
+ adapter_k = adapter_k.transpose(1, 2)
198
+
199
+ xq = xq.transpose(1, 2)
200
+ keys = keys.transpose(1, 2)
201
+ values = values.transpose(1, 2)
202
+ scores = torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt(self.head_dim)
203
+
204
+ if mask is not None:
205
+ scores = scores + mask # (bs, n_local_heads, slen, cache_len + slen)
206
+
207
+ scores = F.softmax(scores.float(), dim=-1).type_as(xq)
208
+ output = torch.matmul(scores, values) # (bs, n_local_heads, slen, head_dim)
209
+
210
+ if adapter is not None:
211
+ if adapter_len > 1:
212
+ adapter_scores = torch.matmul(xq, adapter_k.transpose(2, 3)) / math.sqrt(self.head_dim)
213
+ adapter_scores = self.gate.tanh() * F.softmax(adapter_scores.float(), dim=-1).type_as(xq)
214
+ output = output + torch.matmul(adapter_scores, adapter_v)
215
+ else:
216
+ output = output + self.gate.tanh() * adapter_v
217
+
218
+ output = output.transpose(
219
+ 1, 2
220
+ ).contiguous().view(bsz, seqlen, -1)
221
+
222
+ if self.w_lora:
223
+ return self.wo(output) + self.lora_wo_l2(self.lora_wo_l1(output))
224
+ else:
225
+ return self.wo(output)
226
+
227
+
228
+ class FeedForward(nn.Module):
229
+ def __init__(
230
+ self,
231
+ dim: int,
232
+ hidden_dim: int,
233
+ multiple_of: int,
234
+ args: ModelArgs,
235
+ ffn_dim_multiplier: Optional[float]
236
+ ):
237
+ super().__init__()
238
+ hidden_dim = int(2 * hidden_dim / 3)
239
+ if ffn_dim_multiplier is not None:
240
+ hidden_dim = int(ffn_dim_multiplier * hidden_dim)
241
+ hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
242
+
243
+ self.w1 = Linear(
244
+ dim, hidden_dim, bias=args.w_bias
245
+ )
246
+ self.w2 = Linear(
247
+ hidden_dim, dim, bias=args.w_bias
248
+ )
249
+ self.w3 = Linear(
250
+ dim, hidden_dim, bias=args.w_bias
251
+ )
252
+ if args.w_bias:
253
+ nn.init.constant_(self.w1.bias.data, 0)
254
+ nn.init.constant_(self.w2.bias.data, 0)
255
+ nn.init.constant_(self.w3.bias.data, 0)
256
+
257
+ self.w_lora = args.w_lora
258
+ if args.w_lora:
259
+ self.lora_w1_l1 = Linear(dim, args.lora_rank, bias=False)
260
+ self.lora_w1_l2 = Linear(args.lora_rank, hidden_dim, bias=False)
261
+ self.lora_w2_l1 = Linear(hidden_dim, args.lora_rank, bias=False)
262
+ self.lora_w2_l2 = Linear(args.lora_rank, dim, bias=False)
263
+ self.lora_w3_l1 = Linear(dim, args.lora_rank, bias=False)
264
+ self.lora_w3_l2 = Linear(args.lora_rank, hidden_dim, bias=False)
265
+ nn.init.constant_(self.lora_w1_l2.weight.data, 0)
266
+ nn.init.constant_(self.lora_w2_l2.weight.data, 0)
267
+ nn.init.constant_(self.lora_w3_l2.weight.data, 0)
268
+
269
+ def forward(self, x):
270
+ if self.w_lora:
271
+ out = F.silu(self.w1(x) + self.lora_w1_l2(self.lora_w1_l1(x))) * (
272
+ self.w3(x) + self.lora_w3_l2(self.lora_w3_l1(x)))
273
+ return self.w2(out) + self.lora_w2_l2(self.lora_w2_l1(out))
274
+ else:
275
+ return self.w2(F.silu(self.w1(x)) * self.w3(x))
276
+
277
+
278
+ class TransformerBlock(nn.Module):
279
+ def __init__(self, layer_id: int, args: ModelArgs):
280
+ super().__init__()
281
+ self.n_heads = args.n_heads
282
+ self.dim = args.dim
283
+ self.head_dim = args.dim // args.n_heads
284
+ self.attention = Attention(args)
285
+ self.feed_forward = FeedForward(
286
+ dim=args.dim, hidden_dim=4 * args.dim, multiple_of=args.multiple_of,
287
+ ffn_dim_multiplier=args.ffn_dim_multiplier, args=args
288
+ )
289
+ self.layer_id = layer_id
290
+ self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps)
291
+ self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps)
292
+
293
+ def forward(self, x: torch.Tensor, start_pos: int, freqs_cis: torch.Tensor, mask: Optional[torch.Tensor],
294
+ prompt=None):
295
+ h = x + self.attention.forward(self.attention_norm(x), start_pos, freqs_cis, mask, prompt)
296
+ out = h + self.feed_forward.forward(self.ffn_norm(h))
297
+ return out
298
+
299
+
300
+ class Transformer(nn.Module):
301
+ def __init__(self, params: ModelArgs):
302
+ super().__init__()
303
+ self.params = params
304
+ self.vocab_size = params.vocab_size
305
+ self.n_layers = params.n_layers
306
+ self.tok_embeddings = Embedding(
307
+ params.vocab_size, params.dim
308
+ )
309
+
310
+ self.layers = torch.nn.ModuleList()
311
+ for layer_id in range(params.n_layers):
312
+ self.layers.append(TransformerBlock(layer_id, params))
313
+
314
+ self.norm = RMSNorm(params.dim, eps=params.norm_eps)
315
+ self.output = Linear(
316
+ params.dim, params.vocab_size, bias=False
317
+ )
318
+
319
+ self.freqs_cis = precompute_freqs_cis(
320
+ self.params.dim // self.params.n_heads, self.params.max_seq_len * 2
321
+ )
322
+
323
+ @torch.inference_mode()
324
+ def forward(self, tokens: torch.Tensor, start_pos: int):
325
+ _bsz, seqlen = tokens.shape
326
+ h = self.tok_embeddings(tokens)
327
+ self.freqs_cis = self.freqs_cis.to(h.device)
328
+ freqs_cis = self.freqs_cis[start_pos: start_pos + seqlen]
329
+
330
+ mask = None
331
+ if seqlen > 1:
332
+ mask = torch.full((1, 1, seqlen, seqlen), float("-inf"), device=tokens.device)
333
+ mask = torch.triu(mask, diagonal=start_pos + 1).type_as(h)
334
+
335
+ for layer in self.layers:
336
+ h = layer(h, start_pos, freqs_cis, mask)
337
+ h = self.norm(h)
338
+ output = self.output(h)  # logits for all positions in the sequence
339
+ return output.float()
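
A small smoke test for the decoder defined above (illustrative only, not part of the commit). The dimensions are shrunk so it runs on CPU; the real model uses dim=4096 and n_layers=32 and is populated from the Meta LLaMA checkpoints in llama/m2ugen.py below. The model is kept in training mode here because Attention.train(False) allocates CUDA key/value caches.

import torch

from llama.llama import ModelArgs, Transformer  # module path as added in this commit

args = ModelArgs(dim=64, n_layers=2, n_heads=4, vocab_size=128,
                 max_seq_len=32, w_bias=True, w_lora=True, lora_rank=4)
model = Transformer(args)                # train mode: attention uses the current keys/values, no cache
tokens = torch.randint(0, args.vocab_size, (1, 8))
logits = model(tokens, start_pos=0)      # causal mask is built internally for seqlen > 1
print(logits.shape)                      # torch.Size([1, 8, 128])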
llama/m2ugen.py ADDED
@@ -0,0 +1,748 @@
1
+ import json
2
+ import os
3
+ from pathlib import Path
4
+ import numpy as np
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+
10
+ from .llama import Transformer, ModelArgs, RMSNorm
11
+ from .projector import ProjectionLayer
12
+ from util.misc import download
13
+ from .utils import sample_top_p
14
+ from .musicgen.musicgen import MusicgenForConditionalGeneration
15
+ from .audioldm2 import AudioLDM2Pipeline
16
+
17
+ from transformers import LlamaTokenizer
18
+ from transformers import Wav2Vec2FeatureExtractor, AutoModel
19
+ from transformers import ViTImageProcessor, ViTModel
20
+ from transformers import VivitImageProcessor, VivitModel
21
+ from transformers import AutoProcessor
22
+
23
+ import torchaudio
24
+
25
+
26
+ class M2UGen(nn.Module):
27
+ """ Masked Autoencoder with VisionTransformer backbone
28
+ """
29
+
30
+ def __init__(self, llama_ckpt_dir, llama_tokenizer, model_args, knn=False, knn_dir="./ckpts", stage=1,
31
+ legacy_bridge=False, load_llama=True, device=None):
32
+ super().__init__()
33
+
34
+ self.args = model_args
35
+
36
+ if device is None:
37
+ self.device = "cuda" if torch.cuda.is_available() else "cpu"
38
+ else:
39
+ self.device = device
40
+
41
+ # 1. MERT Encoder
42
+ # The model files for MERT can be downloaded here in case of network issues:
43
+ # https://huggingface.co/m-a-p/MERT-v1-330M
44
+ # And set the mert_path argument to the directory containing the model files
45
+ print(f'Initialize MERT...')
46
+ self.mert_model = AutoModel.from_pretrained(self.args.mert_path, trust_remote_code=True) # .to(self.device)
47
+ self.mert_processor = Wav2Vec2FeatureExtractor.from_pretrained(self.args.mert_path, trust_remote_code=True)
48
+ self.mu_mert_agg = nn.Conv1d(in_channels=25, out_channels=1, kernel_size=1)
49
+ self.mu_mert_proj = nn.Linear(1024, 4096)
50
+
51
+ if legacy_bridge:
52
+ bridge_norm_layer = nn.LayerNorm
53
+ bridge_bias = True
54
+ else:
55
+ bridge_norm_layer = RMSNorm
56
+ bridge_bias = False
57
+
58
+ self.feature_scaler = 1
59
+
60
+ self.mu_mert_norm_1 = bridge_norm_layer(4096)
61
+ self.mu_mert_f1_1 = nn.Linear(4096, 4096 * self.feature_scaler, bias=bridge_bias)
62
+ self.mu_mert_f2_1 = nn.Linear(4096 * self.feature_scaler, 4096, bias=bridge_bias)
63
+ self.mu_mert_f3_1 = nn.Linear(4096, 4096 * self.feature_scaler, bias=bridge_bias)
64
+
65
+ self.mu_mert_norm_2 = bridge_norm_layer(4096)
66
+ self.mu_mert_f1_2 = nn.Linear(4096, 4096 * self.feature_scaler, bias=bridge_bias)
67
+ self.mu_mert_f2_2 = nn.Linear(4096 * self.feature_scaler, 4096, bias=bridge_bias)
68
+ self.mu_mert_f3_2 = nn.Linear(4096, 4096 * self.feature_scaler, bias=bridge_bias)
69
+
70
+ self.mu_mert_norm_3 = bridge_norm_layer(4096)
71
+ self.mu_mert_f1_3 = nn.Linear(4096, 4096 * self.feature_scaler, bias=bridge_bias)
72
+ self.mu_mert_f2_3 = nn.Linear(4096 * self.feature_scaler, 4096, bias=bridge_bias)
73
+ self.mu_mert_f3_3 = nn.Linear(4096, 4096 * self.feature_scaler, bias=bridge_bias)
74
+ print(f'MERT initialized...')
75
+
76
+ # 2. ViT Encoder
77
+ # The model files for ViT can be downloaded here in case of network issues:
78
+ # https://huggingface.co/google/vit-base-patch16-224-in21k
79
+ # And set the vit_path argument to the directory containing the model files
80
+ print(f'Initialize ViT...')
81
+ self.vit_model = ViTModel.from_pretrained(self.args.vit_path) # .to(self.device)
82
+ self.vit_model.eval()
83
+ self.vit_processor = ViTImageProcessor.from_pretrained(self.args.vit_path, do_rescale=False)
84
+ self.iu_vit_agg = nn.Conv1d(in_channels=197, out_channels=1, kernel_size=1)
85
+ self.iu_vit_proj = nn.Linear(768, 4096)
86
+
87
+ self.iu_vit_norm_1 = bridge_norm_layer(4096)
88
+ self.iu_vit_f1_1 = nn.Linear(4096, 4096 * self.feature_scaler, bias=bridge_bias)
89
+ self.iu_vit_f2_1 = nn.Linear(4096 * self.feature_scaler, 4096, bias=bridge_bias)
90
+ self.iu_vit_f3_1 = nn.Linear(4096, 4096 * self.feature_scaler, bias=bridge_bias)
91
+
92
+ self.iu_vit_norm_2 = bridge_norm_layer(4096)
93
+ self.iu_vit_f1_2 = nn.Linear(4096, 4096 * self.feature_scaler, bias=bridge_bias)
94
+ self.iu_vit_f2_2 = nn.Linear(4096 * self.feature_scaler, 4096, bias=bridge_bias)
95
+ self.iu_vit_f3_2 = nn.Linear(4096, 4096 * self.feature_scaler, bias=bridge_bias)
96
+
97
+ self.iu_vit_norm_3 = bridge_norm_layer(4096)
98
+ self.iu_vit_f1_3 = nn.Linear(4096, 4096 * self.feature_scaler, bias=bridge_bias)
99
+ self.iu_vit_f2_3 = nn.Linear(4096 * self.feature_scaler, 4096, bias=bridge_bias)
100
+ self.iu_vit_f3_3 = nn.Linear(4096, 4096 * self.feature_scaler, bias=bridge_bias)
101
+ print(f'ViT initialized...')
102
+
103
+ # 3. ViViT Encoder
104
+ # The model files for ViViT can be downloaded here in case of network issues:
105
+ # https://huggingface.co/google/vivit-b-16x2-kinetics400
106
+ # And set the vivit_path argument to the directory containing the model files
107
+ print(f'Initialize ViViT...')
108
+ self.vivit_model = VivitModel.from_pretrained(self.args.vivit_path) # .to(self.device)
109
+ self.vivit_model.eval()
110
+ self.vivit_processor = VivitImageProcessor.from_pretrained(self.args.vivit_path)
111
+ self.iu_vivit_agg = nn.Conv1d(in_channels=3137, out_channels=1, kernel_size=1)
112
+ self.iu_vivit_proj = nn.Linear(768, 4096)
113
+
114
+ self.iu_vivit_norm_1 = bridge_norm_layer(4096)
115
+ self.iu_vivit_f1_1 = nn.Linear(4096, 4096 * self.feature_scaler, bias=bridge_bias)
116
+ self.iu_vivit_f2_1 = nn.Linear(4096 * self.feature_scaler, 4096, bias=bridge_bias)
117
+ self.iu_vivit_f3_1 = nn.Linear(4096, 4096 * self.feature_scaler, bias=bridge_bias)
118
+
119
+ self.iu_vivit_norm_2 = bridge_norm_layer(4096)
120
+ self.iu_vivit_f1_2 = nn.Linear(4096, 4096 * self.feature_scaler, bias=bridge_bias)
121
+ self.iu_vivit_f2_2 = nn.Linear(4096 * self.feature_scaler, 4096, bias=bridge_bias)
122
+ self.iu_vivit_f3_2 = nn.Linear(4096, 4096 * self.feature_scaler, bias=bridge_bias)
123
+
124
+ self.iu_vivit_norm_3 = bridge_norm_layer(4096)
125
+ self.iu_vivit_f1_3 = nn.Linear(4096, 4096 * self.feature_scaler, bias=bridge_bias)
126
+ self.iu_vivit_f2_3 = nn.Linear(4096 * self.feature_scaler, 4096, bias=bridge_bias)
127
+ self.iu_vivit_f3_3 = nn.Linear(4096, 4096 * self.feature_scaler, bias=bridge_bias)
128
+ print(f'ViViT initialized...')
129
+
130
+ # 4. llama
131
+ with open(os.path.join(llama_ckpt_dir, "params.json"), "r") as f:
132
+ params = json.loads(f.read())
133
+ bias_lora = True
134
+
135
+ if self.args.music_decoder.lower() == "audioldm2":
136
+ self.model_args: ModelArgs = ModelArgs(
137
+ max_seq_len=1024, max_batch_size=1, w_bias=bias_lora, w_lora=bias_lora,
138
+ num_output_tokens=1, output_dim_tokens=137216,
139
+ **params) # max_batch_size only affects inference
140
+ else:
141
+ self.model_args: ModelArgs = ModelArgs(
142
+ max_seq_len=1024, max_batch_size=1, w_bias=bias_lora, w_lora=bias_lora,
143
+ num_output_tokens=128, output_dim_tokens=768,
144
+ **params) # max_batch_size only affects inference
145
+ print(f"model args: {self.model_args}")
146
+
147
+ # 5. tokenizer
148
+ self.tokenizer = LlamaTokenizer.from_pretrained(
149
+ llama_tokenizer) # Tokenizer(model_path=llama_tokenizer, num_aud_tokens=self.model_args.num_gen_audio_tokens)
150
+ self._add_audio_token()
151
+ self.model_args.vocab_size = len(self.tokenizer)
152
+
153
+ if torch.cuda.is_available():
154
+ torch.set_default_tensor_type(torch.cuda.HalfTensor)
155
+ self.llama = Transformer(self.model_args)
156
+ torch.set_default_tensor_type(torch.FloatTensor)
157
+
158
+ if load_llama:
159
+ print(f"Loading LLaMA Checkpoint...")
160
+ ckpts = sorted(Path(llama_ckpt_dir).glob("*.pth"))
161
+
162
+ """
163
+ Adapted from https://github.com/cedrickchee/llama/blob/main/chattyllama/combined/inference.py
164
+ """
165
+ key_to_dim = {
166
+ "w1": 0,
167
+ "w2": -1,
168
+ "w3": 0,
169
+ "wo": -1,
170
+ "wq": 0,
171
+ "wk": 0,
172
+ "wv": 0,
173
+ "output": 0,
174
+ "tok_embeddings": 2,
175
+ "ffn_norm": None,
176
+ "attention_norm": None,
177
+ "norm": None,
178
+ "rope": None,
179
+ }
180
+ for i, ckpt in enumerate(ckpts):
181
+ checkpoint = torch.load(ckpt, map_location="cpu")
182
+ for parameter_name, parameter in self.llama.named_parameters():
183
+ short_name = parameter_name.split(".")[-2]
184
+ if "gate" in parameter_name or "lora" in parameter_name or "bias" in parameter_name:
185
+ continue
186
+ if key_to_dim[short_name] is None and i == 0:
187
+ parameter.data = checkpoint[parameter_name]
188
+ elif key_to_dim[short_name] == 0:
189
+ size = checkpoint[parameter_name].size(0)
190
+ parameter.data[size * i: size * (i + 1), :] = checkpoint[
191
+ parameter_name
192
+ ]
193
+ elif key_to_dim[short_name] == -1:
194
+ size = checkpoint[parameter_name].size(-1)
195
+ parameter.data[:, size * i: size * (i + 1)] = checkpoint[
196
+ parameter_name
197
+ ]
198
+ elif key_to_dim[short_name] == 2:
199
+ size = checkpoint[parameter_name].size(-1)
200
+ parameter.data[:-self.model_args.num_gen_audio_tokens, size * i: size * (i + 1)] = checkpoint[
201
+ parameter_name
202
+ ]
203
+ parameter.data[-self.model_args.num_gen_audio_tokens:, :] = 1
204
+ del checkpoint
205
+ print(f"LLaMA Checkpoint Loaded")
206
+
207
+ # 5. projector
208
+ self.output_projector = ProjectionLayer(4096, self.model_args.output_dim_tokens,
209
+ num_input_tokens=self.model_args.num_gen_audio_tokens,
210
+ num_output_tokens=self.model_args.num_output_tokens)
211
+
212
+ # 6. Generator
213
+ if self.args.music_decoder.lower() == "audioldm2":
214
+ # The model files for AudioLDM2 can be downloaded here in case of network issues:
215
+ # https://huggingface.co/cvssp/audioldm2-music
216
+ # And set the music_decoder_path argument to the directory containing the model files
217
+ print(f'Initialize AudioLDM2...')
218
+ dtype = torch.float16 if torch.cuda.is_available() else torch.float32
219
+ self.generation_model = AudioLDM2Pipeline.from_pretrained(self.args.music_decoder_path, torch_dtype=dtype)
220
+ self.generation_model.to("cuda")
221
+ print(f'AudioLDM2 initialized...')
222
+ else:
223
+ # The model files for MusicGen can be downloaded here in case of network issues:
224
+ # https://huggingface.co/facebook/musicgen-medium
225
+ # And set the music_decoder_path argument to the directory containing the model files
226
+ print(f'Initialize MusicGen...')
227
+ self.generation_processor = AutoProcessor.from_pretrained(self.args.music_decoder_path)
228
+ self.generation_model = MusicgenForConditionalGeneration.from_pretrained(self.args.music_decoder_path)
229
+ self.generation_model.eval()
230
+ print(f'MusicGen initialized...')
231
+ self.music_decoder = self.args.music_decoder.lower()
232
+
233
+ # 4. prefix
234
+ self.query_layer = 20
235
+ self.query_len = 1
236
+ self.prefix_query = nn.Embedding(self.query_layer * self.query_len, self.model_args.dim)
237
+
238
+ # 5. knn
239
+ self.knn = knn
240
+ if knn:
241
+ import faiss
242
+ self.index = faiss.read_index(download("https://huggingface.co/csuhan/knn/resolve/main/knn.index", knn_dir))
243
+
244
+ # 6. training criterion
245
+ self.criterion = torch.nn.CrossEntropyLoss(ignore_index=0)
246
+ self.l2_loss = torch.nn.MSELoss()
247
+ self.stage = stage
248
+ self.set_default_trainability(self.stage)
249
+
250
+ def get_trainable_params(self, stage=1):
251
+ trainable = {}
252
+ if stage == 1:
253
+ for name, para in self.named_parameters():
254
+ if "llama." in name:
255
+ if 'norm' in name or 'bias' in name or 'lora' in name:
256
+ trainable[name] = para
257
+ if "mu_mert_" in name:
258
+ trainable[name] = para
259
+ if "iu_vivit_" in name:
260
+ trainable[name] = para
261
+ if "iu_vit_" in name:
262
+ trainable[name] = para
263
+ if "prefix_query" in name:
264
+ trainable[name] = para
265
+ if "output_projector" in name:
266
+ trainable[name] = para
267
+ if "tok_embeddings" in name:
268
+ trainable[name] = para
269
+ elif stage == 2:
270
+ for name, para in self.named_parameters():
271
+ if "llama." in name:
272
+ if 'norm' in name or 'bias' in name or 'lora' in name:
273
+ trainable[name] = para
274
+ if "output_projector" in name:
275
+ trainable[name] = para
276
+ if "prefix_query" in name:
277
+ trainable[name] = para
278
+ if "tok_embeddings" in name:
279
+ trainable[name] = para
280
+ elif stage == 3:
281
+ for name, para in self.named_parameters():
282
+ if "llama." in name:
283
+ if 'norm' in name or 'bias' in name or 'lora' in name:
284
+ trainable[name] = para
285
+ elif "prefix_query" in name:
286
+ trainable[name] = para
287
+ elif "tok_embeddings" in name:
288
+ trainable[name] = para
289
+ return trainable
290
+
291
+ def set_default_trainability(self, stage=1):
292
+ for key, value in self.named_parameters():
293
+ value.requires_grad = False
294
+ trainable_params = self.get_trainable_params(stage)
295
+ print(f"Trainable Params: {trainable_params.keys()}")
296
+ for key, value in trainable_params.items():
297
+ value.data = value.data.float()
298
+ value.requires_grad = True
299
+
300
+ def _add_audio_token(self):
301
+ self.audio_tokens = []
302
+ for i in range(self.model_args.num_gen_audio_tokens):
303
+ print(f'Adding [AUD{i}] token to vocabulary.')
304
+ print(f'Before adding new token, tokenizer("[AUD{i}]") =',
305
+ self.tokenizer(f'[AUD{i}]', add_special_tokens=False))
306
+ num_added_tokens = self.tokenizer.add_tokens([f'[AUD{i}]'])
307
+ print(f'After adding {num_added_tokens} new tokens, tokenizer("[AUD{i}]") =',
308
+ self.tokenizer(f'[AUD{i}]', add_special_tokens=False), ' Number of tokens: ', len(self.tokenizer))
309
+ gen_token_idx = self.tokenizer(f'[AUD{i}]', add_special_tokens=False).input_ids
310
+ assert len(gen_token_idx) == 1, gen_token_idx
311
+ self.audio_tokens.append(gen_token_idx[0])
312
+
313
+ def load_audio(self, audio_path, target_sr=16000):
314
+ y, sr = torchaudio.load(audio_path)
315
+ resampler = torchaudio.transforms.Resample(sr, target_sr, dtype=y.dtype)
316
+ audio = resampler(y)
317
+ return audio, target_sr
318
+
319
+ def encode_audio(self, x):
320
+ xs = []
321
+ for sub_x in x:
322
+ all_inputs = [self.mert_processor(sub_x[ix * self.mert_processor.sampling_rate:min(
323
+ (ix + 60) * self.mert_processor.sampling_rate, len(sub_x))],
324
+ sampling_rate=self.mert_processor.sampling_rate,
325
+ return_tensors="pt").to(self.mert_model.device) for ix in
326
+ range(0, len(sub_x) // (self.mert_processor.sampling_rate * 60) + 1, 60)]
327
+ aggoutputs = torch.zeros(1, 25, 1024).to(self.mert_model.device)
328
+ for inputs in all_inputs:
329
+ with torch.no_grad():
330
+ outputs = self.mert_model(**inputs, output_hidden_states=True)
331
+ all_layer_hidden_states = torch.stack(outputs.hidden_states).squeeze()
332
+ sub_x = all_layer_hidden_states.mean(-2).unsqueeze(0)
333
+ aggoutputs += sub_x
334
+ aggoutputs /= len(all_inputs)
335
+ sub_x = self.mu_mert_agg(aggoutputs.to(self.device)).squeeze()
336
+ del aggoutputs
337
+ xs.append(sub_x)
338
+ x = torch.stack(xs, dim=0)
339
+ return x
340
+
341
+ def encode_image(self, x):
342
+ xs = []
343
+ for sub_x in x:
344
+ inputs = self.vit_processor(images=sub_x, return_tensors="pt").to(self.vit_model.device)
345
+ with torch.no_grad():
346
+ outputs = self.vit_model(**inputs)
347
+ last_hidden_states = outputs.last_hidden_state
348
+ sub_x = self.iu_vit_agg(last_hidden_states.to(self.device)).squeeze()
349
+ xs.append(sub_x)
350
+ return torch.stack(xs, dim=0)
351
+
352
+ def encode_video(self, x):
353
+ xs = []
354
+ for sub_x in x:
355
+ inputs = self.vivit_processor(list(sub_x), padding=True, return_tensors="pt").to(self.vivit_model.device)
356
+ with torch.no_grad():
357
+ outputs = self.vivit_model(**inputs)
358
+ last_hidden_states = outputs.last_hidden_state
359
+ sub_x = self.iu_vivit_agg(last_hidden_states.to(self.device)).squeeze()
360
+ xs.append(sub_x)
361
+ return torch.stack(xs, dim=0)
362
+
363
+ def forward_audio(self, inputs, cache_size=10, cache_t=20, cache_weight=0.5):
364
+ outputs = []
365
+ outputs_weights = []
366
+ for input_type, (input, input_weight) in inputs.items():
367
+ outputs.append(F.normalize(self.encode_audio(input), dim=-1))
368
+ outputs_weights.append(input_weight)
369
+ outputs_weights = [x / (sum(outputs_weights) + 1e-6) for x in outputs_weights]
370
+
371
+ audio_feats = sum([output * output_weight for output, output_weight in zip(outputs, outputs_weights)])
372
+ device = audio_feats.device
373
+
374
+ if self.knn:
375
+ audio_feats_ori = audio_feats
376
+ sims, indices = self.index.search(audio_feats.cpu(), int(cache_size))
377
+ B = sims.shape[0]
378
+ prototypes = [self.index.reconstruct(x) for x in indices.reshape(-1, ).tolist()]
379
+ prototypes = np.vstack(prototypes).reshape(B, int(cache_size), -1) # [N, top_k, 1024]
380
+ sims = torch.tensor(sims, device=device)
381
+ prototypes = torch.tensor(prototypes, device=device)
382
+
383
+ sims = (sims * cache_t).softmax(dim=-1)
384
+ audio_feats = sims @ prototypes
385
+ audio_feats = audio_feats / audio_feats.norm(dim=-1, keepdim=True)
386
+
387
+ audio_feats = (1 - cache_weight) * audio_feats_ori + cache_weight * audio_feats
388
+ audio_feats = audio_feats / audio_feats.norm(dim=-1, keepdim=True)
389
+
390
+ audio_feats = audio_feats.unsqueeze(1) # B, 1, D
391
+ audio_feats = self.mu_mert_proj(audio_feats)
392
+ audio_feats_norm = self.mu_mert_norm_1(audio_feats)
393
+ audio_feats = audio_feats + self.mu_mert_f2_1(
394
+ F.silu(self.mu_mert_f1_1(audio_feats_norm)) * self.mu_mert_f3_1(audio_feats_norm))
395
+
396
+ audio_feats_norm = self.mu_mert_norm_2(audio_feats)
397
+ audio_feats = audio_feats + self.mu_mert_f2_2(
398
+ F.silu(self.mu_mert_f1_2(audio_feats_norm)) * self.mu_mert_f3_2(audio_feats_norm))
399
+
400
+ audio_feats_norm = self.mu_mert_norm_3(audio_feats)
401
+ audio_feats = audio_feats + self.mu_mert_f2_3(
402
+ F.silu(self.mu_mert_f1_3(audio_feats_norm)) * self.mu_mert_f3_3(audio_feats_norm))
403
+ return audio_feats
404
+
405
+ def forward_image(self, inputs, cache_size=10, cache_t=20, cache_weight=0.5):
406
+ outputs = []
407
+ outputs_weights = []
408
+ for input_type, (input, input_weight) in inputs.items():
409
+ outputs.append(F.normalize(self.encode_image(input), dim=-1))
410
+ outputs_weights.append(input_weight)
411
+ outputs_weights = [x / (sum(outputs_weights) + 1e-6) for x in outputs_weights]
412
+
413
+ image_feats = sum([output * output_weight for output, output_weight in zip(outputs, outputs_weights)])
414
+ device = image_feats.device
415
+
416
+ if self.knn:
417
+ image_feats_ori = image_feats
418
+ sims, indices = self.index.search(image_feats.cpu(), int(cache_size))
419
+ B = sims.shape[0]
420
+ prototypes = [self.index.reconstruct(x) for x in indices.reshape(-1, ).tolist()]
421
+ prototypes = np.vstack(prototypes).reshape(B, int(cache_size), -1) # [N, top_k, 1024]
422
+ sims = torch.tensor(sims, device=device)
423
+ prototypes = torch.tensor(prototypes, device=device)
424
+
425
+ sims = (sims * cache_t).softmax(dim=-1)
426
+ image_feats = sims @ prototypes
427
+ image_feats = image_feats / image_feats.norm(dim=-1, keepdim=True)
428
+
429
+ image_feats = (1 - cache_weight) * image_feats_ori + cache_weight * image_feats
430
+ image_feats = image_feats / image_feats.norm(dim=-1, keepdim=True)
431
+
432
+ image_feats = image_feats.unsqueeze(1) # B, 1, D
433
+ image_feats = self.iu_vit_proj(image_feats)
434
+ image_feats_norm = self.iu_vit_norm_1(image_feats)
435
+ image_feats = image_feats + self.iu_vit_f2_1(
436
+ F.silu(self.iu_vit_f1_1(image_feats_norm)) * self.iu_vit_f3_1(image_feats_norm))
437
+
438
+ image_feats_norm = self.iu_vit_norm_2(image_feats)
439
+ image_feats = image_feats + self.iu_vit_f2_2(
440
+ F.silu(self.iu_vit_f1_2(image_feats_norm)) * self.iu_vit_f3_2(image_feats_norm))
441
+
442
+ image_feats_norm = self.iu_vit_norm_3(image_feats)
443
+ image_feats = image_feats + self.iu_vit_f2_3(
444
+ F.silu(self.iu_vit_f1_3(image_feats_norm)) * self.iu_vit_f3_3(image_feats_norm))
445
+ return image_feats
446
+
447
+ def forward_video(self, inputs, cache_size=10, cache_t=20, cache_weight=0.5):
448
+ outputs = []
449
+ outputs_weights = []
450
+ for input_type, (input, input_weight) in inputs.items():
451
+ outputs.append(F.normalize(self.encode_video(input), dim=-1))
452
+ outputs_weights.append(input_weight)
453
+ outputs_weights = [x / (sum(outputs_weights) + 1e-6) for x in outputs_weights]
454
+
455
+ video_feats = sum([output * output_weight for output, output_weight in zip(outputs, outputs_weights)])
456
+ device = video_feats.device
457
+
458
+ if self.knn:
459
+ video_feats_ori = video_feats
460
+ sims, indices = self.index.search(video_feats.cpu(), int(cache_size))
461
+ B = sims.shape[0]
462
+ prototypes = [self.index.reconstruct(x) for x in indices.reshape(-1, ).tolist()]
463
+ prototypes = np.vstack(prototypes).reshape(B, int(cache_size), -1) # [N, top_k, 1024]
464
+ sims = torch.tensor(sims, device=device)
465
+ prototypes = torch.tensor(prototypes, device=device)
466
+
467
+ sims = (sims * cache_t).softmax(dim=-1)
468
+ video_feats = sims @ prototypes
469
+ video_feats = video_feats / video_feats.norm(dim=-1, keepdim=True)
470
+
471
+ video_feats = (1 - cache_weight) * video_feats_ori + cache_weight * video_feats
472
+ video_feats = video_feats / video_feats.norm(dim=-1, keepdim=True)
473
+
474
+ video_feats = video_feats.unsqueeze(1) # B, 1, D
475
+ video_feats = self.iu_vivit_proj(video_feats)
476
+ video_feats_norm = self.iu_vivit_norm_1(video_feats)
477
+ video_feats = video_feats + self.iu_vivit_f2_1(
478
+ F.silu(self.iu_vivit_f1_1(video_feats_norm)) * self.iu_vivit_f3_1(video_feats_norm))
479
+
480
+ video_feats_norm = self.iu_vivit_norm_2(video_feats)
481
+ video_feats = video_feats + self.iu_vivit_f2_2(
482
+ F.silu(self.iu_vivit_f1_2(video_feats_norm)) * self.iu_vivit_f3_2(video_feats_norm))
483
+
484
+ video_feats_norm = self.iu_vivit_norm_3(video_feats)
485
+ video_feats = video_feats + self.iu_vivit_f2_3(
486
+ F.silu(self.iu_vivit_f1_3(video_feats_norm)) * self.iu_vivit_f3_3(video_feats_norm))
487
+ return video_feats
488
+
489
+ @torch.inference_mode()
490
+ def forward_inference(self, tokens, start_pos: int, audio_feats=None, image_feats=None, video_feats=None):
491
+ _bsz, seqlen = tokens.shape
492
+ h = self.llama.tok_embeddings(tokens)
493
+ freqs_cis = self.llama.freqs_cis.to(h.device)
494
+ freqs_cis = freqs_cis[start_pos:start_pos + seqlen]
495
+
496
+ feats = torch.zeros((1, 1, 4096)).to(self.device)
497
+ if audio_feats is not None:
498
+ feats += audio_feats
499
+ if video_feats is not None:
500
+ feats += video_feats
501
+ if image_feats is not None:
502
+ feats += image_feats
503
+
504
+ mask = None
505
+ mask = torch.full((1, 1, seqlen, seqlen), float("-inf"), device=h.device)
506
+ mask = torch.triu(mask, diagonal=start_pos + 1).type_as(h)
507
+
508
+ music_output_embedding = []
509
+ for layer in self.llama.layers[:-1 * self.query_layer]:
510
+ h = layer(h, 0, freqs_cis, mask)
511
+ music_output_embedding.append(h)
512
+
513
+ prefix_query = self.prefix_query.weight.reshape(self.query_layer, 1, 4096).unsqueeze(1)
514
+
515
+ prefix_index = 0
516
+ for layer in self.llama.layers[-1 * self.query_layer:]:
517
+ h = layer(h, 0, freqs_cis, mask, feats + prefix_query[prefix_index])
518
+ prefix_index = prefix_index + 1
519
+
520
+ h = self.llama.norm(h)
521
+ output = self.llama.output(h[:, -1, :])
522
+
523
+ return output.float(), torch.cat(music_output_embedding[-1:], dim=1)
524
+
525
+ def forward(self, tokens, labels, audios=None, imgs=None, videos=None, music_caption=None):
526
+ feats = torch.zeros((1, 1, 4096)).to(self.device)
527
+ if audios is not None:
528
+ feats += self.forward_audio({'Audio': [audios, 1]})
529
+ if videos is not None:
530
+ feats += self.forward_video({'Video': [videos, 1]})
531
+ if imgs is not None:
532
+ feats += self.forward_image({'Image': [imgs, 1]})
533
+ _bsz, seqlen = tokens.shape
534
+
535
+ h = self.llama.tok_embeddings(tokens.to(self.device))
536
+ freqs_cis = self.llama.freqs_cis.to(h.device)
537
+ freqs_cis = freqs_cis[:seqlen]
538
+ mask = None
539
+ mask = torch.full((1, 1, seqlen, seqlen), float("-inf"), device=h.device)
540
+ mask = torch.triu(mask, diagonal=0 + 1).type_as(h)
541
+
542
+ for layer in self.llama.layers[:-1 * self.query_layer]:
543
+ h = layer(h, 0, freqs_cis, mask)
544
+ prefix_query = self.prefix_query.weight.reshape(self.query_layer, 1, 4096).unsqueeze(1)
545
+ prefix_index = 0
546
+
547
+ for layer in self.llama.layers[-1 * self.query_layer:]:
548
+ h = layer(h, 0, freqs_cis, mask, feats + prefix_query[prefix_index])
549
+ prefix_index = prefix_index + 1
550
+
551
+ final_hidden = h
552
+ h = self.llama.norm(h)
553
+ output = self.llama.output(h)
554
+ output = output[:, :-1, :]
555
+ labels = labels[:, 1:]
556
+
557
+ if labels.sum() == 0:
558
+ c_loss = output.mean() * 0
559
+ else:
560
+ assert self.llama.vocab_size == 32000 + self.model_args.num_gen_audio_tokens, self.llama.vocab_size
561
+ c_loss = self.criterion(output.reshape(-1, self.llama.vocab_size), labels.flatten().to(self.device))
562
+
563
+ if music_caption is not None and any([mc != '' for mc in music_caption]):
564
+ if not all([i in output for i in range(32000, 32008)]):
565
+ c_loss += 100
566
+ if self.music_decoder == "audioldm2":
567
+ prompt_embeds, generated_prompt_embeds = self.generation_model(prompt=list(music_caption),
568
+ guidance_scale=1,
569
+ return_prompts_only=True)
570
+ prompt_embeds = prompt_embeds.reshape(prompt_embeds.shape[0], -1)
571
+ generated_prompt_embeds = generated_prompt_embeds.reshape(generated_prompt_embeds.shape[0], -1)
572
+ out_embed = torch.cat([prompt_embeds, generated_prompt_embeds], dim=1)
573
+ out_embed = 10 * out_embed.view(out_embed.size(0), 1, out_embed.size(1)).to(self.device)
574
+ else:
575
+ gen_inputs = self.generation_processor(text=music_caption, padding='max_length',
576
+ max_length=128, truncation=True, return_tensors="pt").to(
577
+ self.device)
578
+ out_embed = 10 * self.generation_model.generate(**gen_inputs, guidance_scale=1, encoder_only=True)
579
+ del gen_inputs
580
+ start_pos = (labels == self.audio_tokens[0]).nonzero(as_tuple=False)[:, 1].tolist()
581
+ assert len(start_pos) != 0, (self.tokenizer.batch_decode(labels), music_caption)
582
+ hidden_states = []
583
+ hidden_embedding = []
584
+ input_embedding = []
585
+ for b, s in enumerate(start_pos):
586
+ hidden_embedding.append(final_hidden[b, s:s + self.model_args.num_gen_audio_tokens, :])
587
+ input_embedding.append(
588
+ self.llama.tok_embeddings(labels[b, s:s + self.model_args.num_gen_audio_tokens].to(self.device)))
589
+ hidden_embedding = torch.stack(hidden_embedding, dim=0).to(self.device)
590
+ input_embedding = torch.stack(input_embedding, dim=0).to(self.device)
591
+ hidden_states.append(self.output_projector(hidden_embedding, input_embedding))
592
+ embeddings = torch.stack(hidden_states, dim=-1).sum(dim=-1)
593
+ mse_loss = self.l2_loss(embeddings, out_embed)
594
+ del hidden_states, input_embedding, hidden_embedding, out_embed, embeddings
595
+ # c_loss += mse_loss
596
+ else:
597
+ if any([i in output for i in range(32000, 32008)]):
598
+ c_loss += 100
599
+ mse_loss = torch.tensor(0.0)
600
+ del feats
601
+ return c_loss, mse_loss
602
+
603
+ @torch.inference_mode()
604
+ def generate_music(self, embeddings, audio_length_in_s, music_caption):
605
+ gen_prefix = ''.join([f'[AUD{i}]' for i in range(len(self.audio_tokens))])
606
+ gen_prefix_ids = self.tokenizer(gen_prefix, add_special_tokens=False, return_tensors="pt").input_ids.to(
607
+ self.device)
608
+ gen_prefix_embs = self.llama.tok_embeddings(gen_prefix_ids)
609
+ if self.music_decoder == "audioldm2":
610
+ gen_emb = self.output_projector(embeddings.float().to("cuda"), gen_prefix_embs).squeeze(dim=0) / 10
611
+ prompt_embeds, generated_prompt_embeds = gen_emb[:, :128 * 1024], gen_emb[:, 128 * 1024:]
612
+ prompt_embeds = prompt_embeds.reshape(prompt_embeds.shape[0], 128, 1024)
613
+ generated_prompt_embeds = generated_prompt_embeds.reshape(generated_prompt_embeds.shape[0], 8, 768)
614
+ print("Generating Music...")
615
+ print(music_caption)
616
+ audio_outputs = self.generation_model(music_caption,
617
+ num_inference_steps=200,
618
+ num_waveforms_per_prompt=3,
619
+ negative_prompt='Low quality.',
620
+ audio_length_in_s=audio_length_in_s).audios
621
+ return audio_outputs
622
+ else:
623
+ print("Generating Music...")
624
+ gen_emb = 0.1 * self.output_projector(embeddings.float().to("cuda"), gen_prefix_embs) / 10
625
+ gen_inputs = self.generation_processor(text=music_caption, padding='max_length',
626
+ max_length=128, truncation=True, return_tensors="pt").to(
627
+ self.device)
628
+ #gen_emb = self.generation_model.generate(**gen_inputs, guidance_scale=3.5, encoder_only=True)
629
+ audio_outputs = self.generation_model.generate(**gen_inputs, guidance_scale=3.5,
630
+ max_new_tokens=int(256 / 5 * audio_length_in_s))
631
+ #encoder_outputs=(gen_emb,))
632
+ return audio_outputs[0][0].cpu().detach().numpy()
633
+
634
+ @torch.inference_mode()
635
+ def generate(
636
+ self,
637
+ prompts,
638
+ audios=None,
639
+ imgs=None,
640
+ videos=None,
641
+ max_gen_len: int = 100,
642
+ temperature: float = 0.1,
643
+ top_p: float = 0.75,
644
+ cache_size=10,
645
+ cache_t=20,
646
+ cache_weight=0.5,
647
+ audio_length_in_s=10
648
+ ):
649
+ bsz = len(prompts)
650
+ params = self.llama.params
651
+ assert bsz <= params.max_batch_size, (bsz, params.max_batch_size)
652
+
653
+ with torch.cuda.amp.autocast():
654
+ if audios is not None:
655
+ audio_feats = self.forward_audio({'Audio': [[audios], 1]}, cache_size, cache_t, cache_weight)
656
+ else:
657
+ audio_feats = None
658
+ if videos is not None:
659
+ video_feats = self.forward_video({'Video': [[videos], 1]}, cache_size, cache_t, cache_weight)
660
+ else:
661
+ video_feats = None
662
+ if imgs is not None:
663
+ image_feats = self.forward_image({'Image': [[imgs], 1]}, cache_size, cache_t, cache_weight)
664
+ else:
665
+ image_feats = None
666
+
667
+ if isinstance(prompts[0], str):
668
+ prompts = [self.tokenizer(x).input_ids[:, 1:] for x in prompts]
669
+
670
+ min_prompt_size = min([len(t) for t in prompts])
671
+ max_prompt_size = max([len(t) for t in prompts])
672
+
673
+ total_len = min(params.max_seq_len, max_gen_len + max_prompt_size)
674
+
675
+ tokens = torch.full((bsz, total_len), 0).cuda().long()
676
+
677
+ for k, t in enumerate(prompts):
678
+ tokens[k, : len(t)] = torch.tensor(t).cuda().long()
679
+ input_text_mask = tokens != 0
680
+ start_pos = min_prompt_size
681
+ prev_pos = 0
682
+ music_output_embeddings = []
683
+ start_gather = 0
684
+ for cur_pos in range(start_pos, total_len):
685
+ with torch.cuda.amp.autocast():
686
+ logits, music_output_embedding = self.forward_inference(tokens[:, prev_pos:cur_pos], prev_pos,
687
+ audio_feats, image_feats, video_feats)
688
+ if temperature > 0:
689
+ probs = torch.softmax(logits / temperature, dim=-1)
690
+ next_token = sample_top_p(probs, top_p)
691
+ else:
692
+ next_token = torch.argmax(logits, dim=-1)
693
+ next_token = next_token.reshape(-1)
694
+
695
+ next_token = torch.where(
696
+ input_text_mask[:, cur_pos], tokens[:, cur_pos], next_token
697
+ )
698
+ tokens[:, cur_pos] = next_token
699
+ if next_token[0] == self.audio_tokens[start_gather]:
700
+ if start_gather == 0:
701
+ music_output_embeddings = []
702
+ music_output_embeddings.append(music_output_embedding[:, -1:, :])
703
+ start_gather += 1
704
+ if start_gather >= len(self.audio_tokens):
705
+ start_gather = 0
706
+ # trick: early stop if bsz==1
707
+ if bsz == 1 and self.tokenizer.decode(tokens[0, cur_pos - 2:cur_pos + 1]) == "\n###":
708
+ break
709
+ # prev_pos = cur_pos
710
+
711
+ decoded = []
712
+ for i, t in enumerate(tokens.tolist()):
713
+
714
+ # cut to max gen len
715
+ t = t[len(prompts[i]): len(prompts[i]) + max_gen_len]
716
+ # cut to eos tok if any
717
+ try:
718
+ t = t[: t.index(13)]
719
+ except ValueError:
720
+ pass
721
+ decoded.append(self.tokenizer.decode(t))
722
+
723
+ if len(music_output_embeddings) == len(self.audio_tokens):
724
+ music_output_embeddings = torch.cat(music_output_embeddings, dim=1)
725
+ return [decoded[0], {'aud': [self.generate_music(music_output_embeddings, audio_length_in_s, decoded[0])]}]
726
+
727
+ return [decoded[0]]
728
+
729
+
730
+ def load(model_path, llama_dir, mert_path="m-a-p/MERT-v1-330M", device="cuda" if torch.cuda.is_available() else "cpu",
731
+ knn=False, knn_dir="./ckpts", llama_type="7B", stage=3):
732
+ llama_ckpt_dir = os.path.join(llama_dir, llama_type)
733
+ llama_tokenizer_path = llama_dir
734
+
735
+ # load M2UGen weights and model_cfg
736
+ print(f'Loading LLaMA-Adapter from {model_path}')
737
+ adapter_ckpt = torch.load(model_path, map_location='cpu')
738
+ model_cfg = adapter_ckpt.get('config', {})
739
+
740
+ # The model files for MERT can be downloaded here in case of network issues:
741
+ # https://huggingface.co/m-a-p/MERT-v1-330M
742
+ # And set the mert_path argument to the directory containing the model files
743
+ model = M2UGen(
744
+ llama_ckpt_dir, llama_tokenizer_path, mert_path, knn=knn, knn_dir=knn_dir, stage=stage)
745
+
746
+ load_result = model.load_state_dict(adapter_ckpt['model'], strict=False)
747
+ assert len(load_result.unexpected_keys) == 0, f"Unexpected keys: {load_result.unexpected_keys}"
748
+ return model.to(device)
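
A hedged end-to-end sketch of how the demo is expected to call this module (the paths below are placeholders, not part of the commit): load the M2UGen adapter checkpoint on top of a local LLaMA 7B directory, then request music from a text prompt.

from llama.m2ugen import load  # module path as added in this commit

# Placeholder paths: an M2UGen adapter checkpoint and a directory containing LLaMA 7B weights + tokenizer.
model = load(model_path="./ckpts/checkpoint.pth", llama_dir="./ckpts/LLaMA",
             llama_type="7B", stage=3)
model.eval()

prompt = "Compose a calm piano piece for a rainy evening."
result = model.generate([prompt], max_gen_len=100, temperature=0.1, top_p=0.75,
                        audio_length_in_s=10)
print(result[0])                    # decoded text response
if len(result) > 1:                 # audio is returned only when all [AUD0]..[AUD7] tokens were generated
    waveform = result[1]['aud'][0]  # waveform array; sampling rate depends on the chosen music decoder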
llama/musicgen/configuration_musicgen.py ADDED
@@ -0,0 +1,233 @@
1
+ # coding=utf-8
2
+ # Copyright 2023 Meta AI and The HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """ MusicGen model configuration"""
16
+
17
+ from transformers.configuration_utils import PretrainedConfig
18
+ from transformers.utils import logging
19
+ from transformers.models.auto.configuration_auto import AutoConfig
20
+
21
+
22
+ logger = logging.get_logger(__name__)
23
+
24
+ MUSICGEN_PRETRAINED_CONFIG_ARCHIVE_MAP = {
25
+ "facebook/musicgen-small": "https://huggingface.co/facebook/musicgen-small/resolve/main/config.json",
26
+ # See all Musicgen models at https://huggingface.co/models?filter=musicgen
27
+ }
28
+
29
+
30
+ class MusicgenDecoderConfig(PretrainedConfig):
31
+ r"""
32
+ This is the configuration class to store the configuration of a [`MusicgenDecoder`]. It is used to instantiate a
33
+ MusicGen decoder according to the specified arguments, defining the model architecture. Instantiating a
34
+ configuration with the defaults will yield a similar configuration to that of the MusicGen
35
+ [facebook/musicgen-small](https://huggingface.co/facebook/musicgen-small) architecture.
36
+
37
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
38
+ documentation from [`PretrainedConfig`] for more information.
39
+
40
+
41
+ Args:
42
+ vocab_size (`int`, *optional*, defaults to 2048):
43
+ Vocabulary size of the MusicgenDecoder model. Defines the number of different tokens that can be
44
+ represented by the `input_ids` passed when calling [`MusicgenDecoder`].
45
+ hidden_size (`int`, *optional*, defaults to 1024):
46
+ Dimensionality of the layers and the pooler layer.
47
+ num_hidden_layers (`int`, *optional*, defaults to 24):
48
+ Number of decoder layers.
49
+ num_attention_heads (`int`, *optional*, defaults to 16):
50
+ Number of attention heads for each attention layer in the Transformer block.
51
+ ffn_dim (`int`, *optional*, defaults to 4096):
52
+ Dimensionality of the "intermediate" (often named feed-forward) layer in the Transformer block.
53
+ activation_function (`str` or `function`, *optional*, defaults to `"gelu"`):
54
+ The non-linear activation function (function or string) in the decoder and pooler. If string, `"gelu"`,
55
+ `"relu"`, `"silu"` and `"gelu_new"` are supported.
56
+ dropout (`float`, *optional*, defaults to 0.1):
57
+ The dropout probability for all fully connected layers in the embeddings, text_encoder, and pooler.
58
+ attention_dropout (`float`, *optional*, defaults to 0.0):
59
+ The dropout ratio for the attention probabilities.
60
+ activation_dropout (`float`, *optional*, defaults to 0.0):
61
+ The dropout ratio for activations inside the fully connected layer.
62
+ max_position_embeddings (`int`, *optional*, defaults to 2048):
63
+ The maximum sequence length that this model might ever be used with. Typically, set this to something large
64
+ just in case (e.g., 512 or 1024 or 2048).
65
+ initializer_factor (`float`, *optional*, defaults to 0.02):
66
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
67
+ layerdrop (`float`, *optional*, defaults to 0.0):
68
+ The LayerDrop probability for the decoder. See the [LayerDrop paper](https://arxiv.org/abs/1909.11556)
69
+ for more details.
70
+ scale_embedding (`bool`, *optional*, defaults to `False`):
71
+ Scale embeddings by dividing by sqrt(hidden_size).
72
+ use_cache (`bool`, *optional*, defaults to `True`):
73
+ Whether the model should return the last key/values attentions (not used by all models)
74
+ num_codebooks (`int`, *optional*, defaults to 4):
75
+ The number of parallel codebooks forwarded to the model.
76
+ tie_word_embeddings(`bool`, *optional*, defaults to `False`):
77
+ Whether input and output word embeddings should be tied.
78
+ """
79
+ model_type = "musicgen_decoder"
80
+ keys_to_ignore_at_inference = ["past_key_values"]
81
+
82
+ def __init__(
83
+ self,
84
+ vocab_size=2048,
85
+ max_position_embeddings=2048,
86
+ num_hidden_layers=24,
87
+ ffn_dim=4096,
88
+ num_attention_heads=16,
89
+ layerdrop=0.0,
90
+ use_cache=True,
91
+ activation_function="gelu",
92
+ hidden_size=1024,
93
+ dropout=0.1,
94
+ attention_dropout=0.0,
95
+ activation_dropout=0.0,
96
+ initializer_factor=0.02,
97
+ scale_embedding=False,
98
+ num_codebooks=4,
99
+ pad_token_id=2048,
100
+ bos_token_id=2048,
101
+ eos_token_id=None,
102
+ tie_word_embeddings=False,
103
+ **kwargs,
104
+ ):
105
+ self.vocab_size = vocab_size
106
+ self.max_position_embeddings = max_position_embeddings
107
+ self.hidden_size = hidden_size
108
+ self.ffn_dim = ffn_dim
109
+ self.num_hidden_layers = num_hidden_layers
110
+ self.num_attention_heads = num_attention_heads
111
+ self.dropout = dropout
112
+ self.attention_dropout = attention_dropout
113
+ self.activation_dropout = activation_dropout
114
+ self.activation_function = activation_function
115
+ self.initializer_factor = initializer_factor
116
+ self.layerdrop = layerdrop
117
+ self.use_cache = use_cache
118
+ self.scale_embedding = scale_embedding # scale factor will be sqrt(d_model) if True
119
+ self.num_codebooks = num_codebooks
120
+ super().__init__(
121
+ pad_token_id=pad_token_id,
122
+ bos_token_id=bos_token_id,
123
+ eos_token_id=eos_token_id,
124
+ tie_word_embeddings=tie_word_embeddings,
125
+ **kwargs,
126
+ )
127
+
128
+
129
+ class MusicgenConfig(PretrainedConfig):
130
+ r"""
131
+ This is the configuration class to store the configuration of a [`MusicgenModel`]. It is used to instantiate a
132
+ MusicGen model according to the specified arguments, defining the text encoder, audio encoder and MusicGen decoder
133
+ configs.
134
+
135
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
136
+ documentation from [`PretrainedConfig`] for more information.
137
+
138
+ Args:
139
+ kwargs (*optional*):
140
+ Dictionary of keyword arguments. Notably:
141
+
142
+ - **text_encoder** ([`PretrainedConfig`], *optional*) -- An instance of a configuration object that
143
+ defines the text encoder config.
144
+ - **audio_encoder** ([`PretrainedConfig`], *optional*) -- An instance of a configuration object that
145
+ defines the audio encoder config.
146
+ - **decoder** ([`PretrainedConfig`], *optional*) -- An instance of a configuration object that defines
147
+ the decoder config.
148
+
149
+ Example:
150
+
151
+ ```python
152
+ >>> from transformers import (
153
+ ... MusicgenConfig,
154
+ ... MusicgenDecoderConfig,
155
+ ... T5Config,
156
+ ... EncodecConfig,
157
+ ... MusicgenForConditionalGeneration,
158
+ ... )
159
+
160
+ >>> # Initializing text encoder, audio encoder, and decoder model configurations
161
+ >>> text_encoder_config = T5Config()
162
+ >>> audio_encoder_config = EncodecConfig()
163
+ >>> decoder_config = MusicgenDecoderConfig()
164
+
165
+ >>> configuration = MusicgenConfig.from_sub_models_config(
166
+ ... text_encoder_config, audio_encoder_config, decoder_config
167
+ ... )
168
+
169
+ >>> # Initializing a MusicgenForConditionalGeneration (with random weights) from the facebook/musicgen-small style configuration
170
+ >>> model = MusicgenForConditionalGeneration(configuration)
171
+
172
+ >>> # Accessing the model configuration
173
+ >>> configuration = model.config
174
+ >>> config_text_encoder = model.config.text_encoder
175
+ >>> config_audio_encoder = model.config.audio_encoder
176
+ >>> config_decoder = model.config.decoder
177
+
178
+ >>> # Saving the model, including its configuration
179
+ >>> model.save_pretrained("musicgen-model")
180
+
181
+ >>> # loading model and config from pretrained folder
182
+ >>> musicgen_config = MusicgenConfig.from_pretrained("musicgen-model")
183
+ >>> model = MusicgenForConditionalGeneration.from_pretrained("musicgen-model", config=musicgen_config)
184
+ ```"""
185
+
186
+ model_type = "musicgen"
187
+ is_composition = True
188
+
189
+ def __init__(self, **kwargs):
190
+ super().__init__(**kwargs)
191
+ if "text_encoder" not in kwargs or "audio_encoder" not in kwargs or "decoder" not in kwargs:
192
+ raise ValueError("Config has to be initialized with text_encoder, audio_encoder and decoder config")
193
+
194
+ text_encoder_config = kwargs.pop("text_encoder")
195
+ text_encoder_model_type = text_encoder_config.pop("model_type")
196
+
197
+ audio_encoder_config = kwargs.pop("audio_encoder")
198
+ audio_encoder_model_type = audio_encoder_config.pop("model_type")
199
+
200
+ decoder_config = kwargs.pop("decoder")
201
+
202
+ self.text_encoder = AutoConfig.for_model(text_encoder_model_type, **text_encoder_config)
203
+ self.audio_encoder = AutoConfig.for_model(audio_encoder_model_type, **audio_encoder_config)
204
+ self.decoder = MusicgenDecoderConfig(**decoder_config)
205
+ self.is_encoder_decoder = True
206
+
207
+ @classmethod
208
+ def from_sub_models_config(
209
+ cls,
210
+ text_encoder_config: PretrainedConfig,
211
+ audio_encoder_config: PretrainedConfig,
212
+ decoder_config: MusicgenDecoderConfig,
213
+ **kwargs,
214
+ ):
215
+ r"""
216
+ Instantiate a [`MusicgenConfig`] (or a derived class) from text encoder, audio encoder and decoder
217
+ configurations.
218
+
219
+ Returns:
220
+ [`MusicgenConfig`]: An instance of a configuration object
221
+ """
222
+
223
+ return cls(
224
+ text_encoder=text_encoder_config.to_dict(),
225
+ audio_encoder=audio_encoder_config.to_dict(),
226
+ decoder=decoder_config.to_dict(),
227
+ **kwargs,
228
+ )
229
+
230
+ @property
231
+ # This is a property because you might want to change the codec model on the fly
232
+ def sampling_rate(self):
233
+ return self.audio_encoder.sampling_rate
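
A short sketch of instantiating the decoder configuration defined above with its documented defaults; the values mirror facebook/musicgen-small, and the import path assumes `llama.musicgen` is an importable package (requires `transformers`).

from llama.musicgen.configuration_musicgen import MusicgenDecoderConfig

decoder_cfg = MusicgenDecoderConfig(
    vocab_size=2048,      # 2048 audio codes per codebook; id 2048 doubles as pad/bos
    num_codebooks=4,
    hidden_size=1024,
    num_hidden_layers=24,
    num_attention_heads=16,
)
print(decoder_cfg.model_type)  # "musicgen_decoder"
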
llama/musicgen/modeling_attn_mask_utils.py ADDED
@@ -0,0 +1,247 @@
1
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from typing import List, Optional, Tuple, Union
15
+
16
+ import torch
17
+
18
+
19
+ class AttentionMaskConverter:
20
+ """
21
+ A utility attention mask class that allows one to:
22
+ - Create a causal 4d mask
23
+ - Create a causal 4d mask with a sliding window
24
+ - Convert a 2d attention mask (batch_size, query_length) to a 4d attention mask (batch_size, 1, query_length,
25
+ key_value_length) that can be multiplied with attention scores
26
+
27
+ Parameters:
28
+ is_causal (`bool`):
29
+ Whether the attention mask should be a uni-directional (causal) or bi-directional mask.
30
+
31
+ sliding_window (`int`, *optional*):
32
+ Optionally, sliding window masks can be created if `sliding_window` is set to a positive integer.
33
+ """
34
+
35
+ def __init__(self, is_causal: bool, sliding_window: Optional[int] = None):
36
+ self.is_causal = is_causal
37
+ self.sliding_window = sliding_window
38
+
39
+ if self.sliding_window is not None and self.sliding_window <= 0:
40
+ raise ValueError(
41
+ f"Make sure that when passing `sliding_window` that its value is a strictly positive integer, not `{self.sliding_window}`"
42
+ )
43
+
44
+ def to_causal_4d(
45
+ self,
46
+ batch_size: int,
47
+ query_length: int,
48
+ key_value_length: int,
49
+ dtype: torch.dtype = torch.float32,
50
+ device: Union[torch.device, "str"] = "cpu",
51
+ ) -> torch.Tensor:
52
+ """
53
+ Creates a causal 4D mask of (bsz, head_dim=1, query_length, key_value_length) shape and adds large negative
54
+ bias to upper right hand triangular matrix (causal mask).
55
+ """
56
+ if not self.is_causal:
57
+ raise ValueError(f"Please use `to_causal_4d` only if {self.__class__} has `is_causal` set to True.")
58
+
59
+ # build the causal mask for the requested (batch_size, query_length) shape
60
+ input_shape = (batch_size, query_length)
61
+ past_key_values_length = key_value_length - query_length
62
+
63
+ # create causal mask
64
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
65
+ causal_4d_mask = None
66
+ if input_shape[-1] > 1 or self.sliding_window is not None:
67
+ causal_4d_mask = self._make_causal_mask(
68
+ input_shape,
69
+ dtype,
70
+ device=device,
71
+ past_key_values_length=past_key_values_length,
72
+ sliding_window=self.sliding_window,
73
+ )
74
+
75
+ return causal_4d_mask
76
+
77
+ def to_4d(
78
+ self,
79
+ attention_mask_2d: torch.Tensor,
80
+ query_length: int,
81
+ key_value_length: Optional[int] = None,
82
+ dtype: torch.dtype = torch.float32,
83
+ ) -> torch.Tensor:
84
+ """
85
+ Converts 2D attention mask to 4D attention mask by expanding mask to (bsz, head_dim=1, query_length,
86
+ key_value_length) shape and by adding a large negative bias to not-attended positions. If attention_mask is
87
+ causal, a causal mask will be added.
88
+ """
89
+ input_shape = (attention_mask_2d.shape[0], query_length)
90
+
91
+ # create causal mask
92
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
93
+ causal_4d_mask = None
94
+ if (input_shape[-1] > 1 or self.sliding_window is not None) and self.is_causal:
95
+ if key_value_length is None:
96
+ raise ValueError(
97
+ "This attention mask converter is causal. Make sure to pass `key_value_length` to correctly create a causal mask."
98
+ )
99
+
100
+ past_key_values_length = key_value_length - query_length
101
+ causal_4d_mask = self._make_causal_mask(
102
+ input_shape,
103
+ dtype,
104
+ device=attention_mask_2d.device,
105
+ past_key_values_length=past_key_values_length,
106
+ sliding_window=self.sliding_window,
107
+ )
108
+ elif self.sliding_window is not None:
109
+ raise NotImplementedError("Sliding window is currently only implemented for causal masking")
110
+
111
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
112
+ expanded_attn_mask = self._expand_mask(attention_mask_2d, dtype, tgt_len=input_shape[-1]).to(
113
+ attention_mask_2d.device
114
+ )
115
+ expanded_4d_mask = expanded_attn_mask if causal_4d_mask is None else expanded_attn_mask + causal_4d_mask
116
+
117
+ return expanded_4d_mask
118
+
119
+ @staticmethod
120
+ def _make_causal_mask(
121
+ input_ids_shape: torch.Size,
122
+ dtype: torch.dtype,
123
+ device: torch.device,
124
+ past_key_values_length: int = 0,
125
+ sliding_window: Optional[int] = None,
126
+ ):
127
+ """
128
+ Make causal mask used for uni-directional (causal) self-attention.
129
+ """
130
+ bsz, tgt_len = input_ids_shape
131
+ mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device)
132
+ mask_cond = torch.arange(mask.size(-1), device=device)
133
+ mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
134
+
135
+ mask = mask.to(dtype)
136
+
137
+ if past_key_values_length > 0:
138
+ mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1)
139
+
140
+ # add lower triangular sliding window mask if necessary
141
+ if sliding_window is not None:
142
+ diagonal = past_key_values_length - sliding_window + 1
143
+
144
+ context_mask = 1 - torch.triu(torch.ones_like(mask, dtype=torch.int), diagonal=diagonal)
145
+ mask.masked_fill_(context_mask.bool(), torch.finfo(dtype).min)
146
+
147
+ return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)
148
+
149
+ @staticmethod
150
+ def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
151
+ """
152
+ Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
153
+ """
154
+ bsz, src_len = mask.size()
155
+ tgt_len = tgt_len if tgt_len is not None else src_len
156
+
157
+ expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
158
+
159
+ inverted_mask = 1.0 - expanded_mask
160
+
161
+ return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
162
+
163
+
164
+ def _prepare_4d_causal_attention_mask(
165
+ attention_mask: Optional[torch.Tensor],
166
+ input_shape: Union[torch.Size, Tuple, List],
167
+ inputs_embeds: torch.Tensor,
168
+ past_key_values_length: int,
169
+ sliding_window: Optional[int] = None,
170
+ ):
171
+ """
172
+ Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
173
+ `(batch_size, key_value_length)`
174
+
175
+ Args:
176
+ attention_mask (`torch.Tensor` or `None`):
177
+ A 2D attention mask of shape `(batch_size, key_value_length)`
178
+ input_shape (`tuple(int)` or `list(int)` or `torch.Size`):
179
+ The input shape should be a tuple that defines `(batch_size, query_length)`.
180
+ inputs_embeds (`torch.Tensor`):
181
+ The embedded inputs as a torch Tensor.
182
+ past_key_values_length (`int`):
183
+ The length of the key value cache.
184
+ sliding_window (`int`, *optional*):
185
+ If the model uses windowed attention, a sliding window should be passed.
186
+ """
187
+ attn_mask_converter = AttentionMaskConverter(is_causal=True, sliding_window=sliding_window)
188
+
189
+ key_value_length = input_shape[-1] + past_key_values_length
190
+
191
+ # 4d mask is passed through the layers
192
+ if attention_mask is not None:
193
+ attention_mask = attn_mask_converter.to_4d(
194
+ attention_mask, input_shape[-1], key_value_length, dtype=inputs_embeds.dtype
195
+ )
196
+ else:
197
+ attention_mask = attn_mask_converter.to_causal_4d(
198
+ input_shape[0], input_shape[-1], key_value_length, dtype=inputs_embeds.dtype, device=inputs_embeds.device
199
+ )
200
+
201
+ return attention_mask
202
+
203
+
204
+ def _prepare_4d_attention_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
205
+ """
206
+ Creates a non-causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
207
+ `(batch_size, key_value_length)`
208
+
209
+ Args:
210
+ mask (`torch.Tensor` or `None`):
211
+ A 2D attention mask of shape `(batch_size, key_value_length)`
212
+ dtype (`torch.dtype`):
213
+ The torch dtype the created mask shall have.
214
+ tgt_len (`int`):
215
+ The target length or query length the created mask shall have.
216
+ """
217
+ return AttentionMaskConverter._expand_mask(mask=mask, dtype=dtype, tgt_len=tgt_len)
218
+
219
+
220
+ def _create_4d_causal_attention_mask(
221
+ input_shape: Union[torch.Size, Tuple, List],
222
+ dtype: torch.dtype,
223
+ device: torch.device,
224
+ past_key_values_length: int = 0,
225
+ sliding_window: Optional[int] = None,
226
+ ):
227
+ """
228
+ Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)`
229
+
230
+ Args:
231
+ input_shape (`tuple(int)` or `list(int)` or `torch.Size`):
232
+ The input shape should be a tuple that defines `(batch_size, query_length)`.
233
+ dtype (`torch.dtype`):
234
+ The torch dtype the created mask shall have.
235
+ device (`torch.device`):
236
+ The torch device the created mask shall have.
237
+ sliding_window (`int`, *optional*):
238
+ If the model uses windowed attention, a sliding window should be passed.
239
+ """
240
+ attn_mask_converter = AttentionMaskConverter(is_causal=True, sliding_window=sliding_window)
241
+
242
+ key_value_length = past_key_values_length + input_shape[-1]
243
+ attention_mask = attn_mask_converter.to_causal_4d(
244
+ input_shape[0], input_shape[-1], key_value_length, dtype=dtype, device=device
245
+ )
246
+
247
+ return attention_mask
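
A sketch of what `_prepare_4d_causal_attention_mask` returns for a left-padded batch with a key/value cache; the shapes follow the docstrings above, and the import path assumes the module is importable as shown.

import torch
from llama.musicgen.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask

bsz, q_len, past = 2, 5, 3
inputs_embeds = torch.zeros(bsz, q_len, 1024)
attention_mask = torch.ones(bsz, past + q_len, dtype=torch.long)
attention_mask[0, :2] = 0  # first sequence is left-padded by two positions

mask_4d = _prepare_4d_causal_attention_mask(
    attention_mask, (bsz, q_len), inputs_embeds, past_key_values_length=past
)
print(mask_4d.shape)  # torch.Size([2, 1, 5, 8]) -> (bsz, 1, query_len, key_value_len)
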
llama/musicgen/musicgen.py ADDED
The diff for this file is too large to render. See raw diff
 
llama/projector.py ADDED
@@ -0,0 +1,32 @@
1
+ import torch
2
+ from torch import nn
3
+
4
+ class ProjectionLayer(nn.Module):
5
+ """Layers used in mapping text embeddings to visual outputs."""
6
+
7
+ def __init__(self, in_dim: int, out_dim: int, num_input_tokens: int = 1, num_output_tokens: int = 1):
8
+ super().__init__()
9
+
10
+ self.num_input_tokens = num_input_tokens
11
+ self.num_output_tokens = num_output_tokens
12
+ self.out_dim = out_dim
13
+
14
+ hidden_dim = 512
15
+ self.fc = nn.Linear(in_dim, hidden_dim)
16
+ self.tfm = nn.Transformer(batch_first=True, norm_first=False,
17
+ d_model=hidden_dim, num_encoder_layers=4, num_decoder_layers=4,
18
+ dim_feedforward=hidden_dim * 4, dropout=0.0, nhead=4)
19
+ self.model = nn.Linear(hidden_dim, out_dim)
20
+ self.query_embs = nn.Parameter(torch.randn(1, num_output_tokens, hidden_dim))
21
+
22
+ def forward(self, x: torch.Tensor, input_embs: torch.Tensor) -> torch.Tensor:
23
+ outputs = None
24
+ x = x + input_embs
25
+ x = self.fc(x)
26
+ x = self.tfm(x, self.query_embs.repeat(x.shape[0], 1, 1))
27
+ outputs = self.model(x)
28
+
29
+ assert outputs.shape[1] == 1 or (
30
+ outputs.shape[1] * outputs.shape[2] == self.num_output_tokens * self.out_dim), (
31
+ outputs.shape, self.num_output_tokens)
32
+ return outputs  # (N, num_output_tokens, out_dim)
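
A shape-level sketch of the projector above: it pools a sequence of LLaMA-sized hidden states into a fixed number of conditioning vectors. The dimensions below are illustrative, not taken from a released checkpoint.

import torch
from llama.projector import ProjectionLayer

proj = ProjectionLayer(in_dim=4096, out_dim=768, num_output_tokens=77)
x = torch.randn(2, 8, 4096)           # e.g. hidden states at 8 audio-token positions
input_embs = torch.randn(2, 8, 4096)  # added residually before projection
out = proj(x, input_embs)
print(out.shape)  # torch.Size([2, 77, 768])
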
llama/tokenizer.py ADDED
@@ -0,0 +1,55 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # This software may be used and distributed according to the terms of the GNU General Public License version 3.
3
+
4
+ from sentencepiece import SentencePieceProcessor
5
+ import sentencepiece.sentencepiece_model_pb2 as model
6
+ from logging import getLogger
7
+ from typing import List
8
+ import os
9
+
10
+
11
+ logger = getLogger()
12
+
13
+
14
+ class Tokenizer:
15
+ def __init__(self, model_path: str, num_aud_tokens: int):
16
+ # load the SentencePiece model and register the [AUD*] special tokens (the model file is rewritten on disk)
17
+ assert os.path.isfile(model_path), model_path
18
+ m = model.ModelProto()
19
+ m.ParseFromString(open(model_path, "rb").read())
20
+ special_tokens = [f'[AUD{i}]' for i in range(num_aud_tokens)]
21
+ for token in special_tokens:
22
+ new_token = model.ModelProto().SentencePiece()
23
+ new_token.piece = token
24
+ new_token.score = 0
25
+ if new_token in m.pieces:
26
+ m.pieces.remove(new_token)
27
+ m.pieces.append(new_token)
28
+ with open(model_path, 'wb') as f:
29
+ f.write(m.SerializeToString())
30
+ self.sp_model = SentencePieceProcessor(model_file=model_path)
31
+ logger.info(f"Reloaded SentencePiece model from {model_path}")
32
+
33
+ # BOS / EOS token IDs
34
+ self.n_words: int = self.sp_model.vocab_size()
35
+ self.bos_id: int = self.sp_model.bos_id()
36
+ self.eos_id: int = self.sp_model.eos_id()
37
+ self.pad_id: int = self.sp_model.pad_id()
38
+ logger.info(
39
+ f"#words: {self.n_words} - BOS ID: {self.bos_id} - EOS ID: {self.eos_id}"
40
+ )
41
+ assert self.sp_model.vocab_size() == self.sp_model.get_piece_size()
42
+
43
+
44
+
45
+ def encode(self, s: str, bos: bool = False, eos: bool = False) -> List[int]:
46
+ assert type(s) is str
47
+ t = self.sp_model.encode_as_ids(s)
48
+ if bos:
49
+ t = [self.bos_id] + t
50
+ if eos:
51
+ t = t + [self.eos_id]
52
+ return t
53
+
54
+ def decode(self, t: List[int]) -> str:
55
+ return self.sp_model.decode(t)
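
A usage sketch for the tokenizer wrapper above; the `tokenizer.model` path is a placeholder. Note that constructing it appends the `[AUD*]` pieces to the SentencePiece model and writes the file back to disk.

from llama.tokenizer import Tokenizer

tok = Tokenizer(model_path="./ckpts/LLaMA/tokenizer.model", num_aud_tokens=8)
ids = tok.encode("Generate relaxing piano music. [AUD0]", bos=True, eos=False)
print(tok.n_words, ids[:5])
print(tok.decode(ids))
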
llama/utils.py ADDED
@@ -0,0 +1,25 @@
1
+ import torch
2
+
3
+
4
+ def sample_top_p(probs, p):
5
+ probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
6
+ probs_sum = torch.cumsum(probs_sort, dim=-1)
7
+ mask = probs_sum - probs_sort > p
8
+ probs_sort[mask] = 0.0
9
+ probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
10
+ next_token = torch.multinomial(probs_sort, num_samples=1)
11
+ next_token = torch.gather(probs_idx, -1, next_token)
12
+ return next_token
13
+
14
+
15
+ def format_prompt(instruction):
16
+
17
+ PROMPT_DICT = {
18
+ "prompt_input": (
19
+ "Below is an instruction that describes a task. "
20
+ "Write a response that appropriately completes the request.\n\n"
21
+ "### Instruction:\n{instruction}\n\n### Response:"
22
+ )
23
+ }
24
+ return PROMPT_DICT["prompt_input"].format_map({'instruction': instruction})
25
+
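
A sketch of how the two helpers above are typically combined during decoding; `logits` stands in for one step of LLaMA output, and the temperature/top-p values are illustrative.

import torch
from llama.utils import format_prompt, sample_top_p

prompt = format_prompt("Describe this music and suggest a continuation.")
print(prompt.endswith("### Response:"))  # True

logits = torch.randn(1, 32000)               # (bsz, vocab_size)
probs = torch.softmax(logits / 0.8, dim=-1)  # temperature 0.8
next_token = sample_top_p(probs, p=0.95)     # nucleus (top-p) sampling
print(next_token.shape)                      # torch.Size([1, 1])
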