txt2audio committed on
Commit
fa25a07
1 Parent(s): 56c9694
This view is limited to 50 files because the commit contains too many changes. See the raw diff for the complete change set.
Files changed (50)
  1. app.py +144 -282
  2. configs/text_to_audio/bigvgan_args.yaml +63 -0
  3. configs/text_to_audio/clap_args.yaml +26 -0
  4. configs/text_to_audio/txt2audio_args.yaml +78 -0
  5. ldm/lr_scheduler.py +98 -0
  6. ldm/models/autoencoder.py +474 -0
  7. ldm/models/autoencoder_multi.py +201 -0
  8. ldm/models/diffusion/__init__.py +0 -0
  9. ldm/models/diffusion/classifier.py +267 -0
  10. ldm/models/diffusion/ddim.py +262 -0
  11. ldm/models/diffusion/ddpm.py +1444 -0
  12. ldm/models/diffusion/ddpm_audio.py +1262 -0
  13. ldm/models/diffusion/ddpm_audio_inpaint.py +1081 -0
  14. ldm/models/diffusion/plms.py +236 -0
  15. ldm/modules/attention.py +261 -0
  16. ldm/modules/diffusionmodules/__init__.py +0 -0
  17. ldm/modules/diffusionmodules/custom_openaimodel.py +368 -0
  18. ldm/modules/diffusionmodules/model.py +835 -0
  19. ldm/modules/diffusionmodules/openaimodel.py +963 -0
  20. ldm/modules/diffusionmodules/util.py +267 -0
  21. ldm/modules/discriminator/model.py +295 -0
  22. ldm/modules/discriminator/multi_window_disc.py +196 -0
  23. ldm/modules/distributions/__init__.py +0 -0
  24. ldm/modules/distributions/distributions.py +92 -0
  25. ldm/modules/ema.py +76 -0
  26. ldm/modules/encoders/CLAP/CLAPWrapper.py +257 -0
  27. ldm/modules/encoders/CLAP/__init__.py +3 -0
  28. ldm/modules/encoders/CLAP/audio.py +179 -0
  29. ldm/modules/encoders/CLAP/clap.py +89 -0
  30. ldm/modules/encoders/CLAP/config.yml +26 -0
  31. ldm/modules/encoders/CLAP/utils.py +26 -0
  32. ldm/modules/encoders/__init__.py +0 -0
  33. ldm/modules/encoders/modules.py +314 -0
  34. ldm/modules/encoders/open_clap/__init__.py +8 -0
  35. ldm/modules/encoders/open_clap/bert.py +32 -0
  36. ldm/modules/encoders/open_clap/bpe_simple_vocab_16e6.txt.gz +3 -0
  37. ldm/modules/encoders/open_clap/factory.py +257 -0
  38. ldm/modules/encoders/open_clap/feature_fusion.py +193 -0
  39. ldm/modules/encoders/open_clap/htsat.py +1022 -0
  40. ldm/modules/encoders/open_clap/linear_probe.py +63 -0
  41. ldm/modules/encoders/open_clap/loss.py +307 -0
  42. ldm/modules/encoders/open_clap/model.py +913 -0
  43. ldm/modules/encoders/open_clap/model_configs/HTSAT-base.json +23 -0
  44. ldm/modules/encoders/open_clap/model_configs/HTSAT-large.json +23 -0
  45. ldm/modules/encoders/open_clap/model_configs/HTSAT-tiny-win-1536.json +23 -0
  46. ldm/modules/encoders/open_clap/model_configs/HTSAT-tiny.json +23 -0
  47. ldm/modules/encoders/open_clap/model_configs/PANN-10.json +23 -0
  48. ldm/modules/encoders/open_clap/model_configs/PANN-14-fmax-18k.json +23 -0
  49. ldm/modules/encoders/open_clap/model_configs/PANN-14-fmax-8k-20s.json +23 -0
  50. ldm/modules/encoders/open_clap/model_configs/PANN-14-tiny-transformer.json +23 -0
app.py CHANGED
@@ -1,285 +1,147 @@
1
- from langchain.agents.initialize import initialize_agent
2
- from langchain.agents.tools import Tool
3
- from langchain.chains.conversation.memory import ConversationBufferMemory
4
- from langchain.llms.openai import OpenAI
5
- from audio_foundation_models import *
6
  import gradio as gr
7
-
8
- _DESCRIPTION = '# [AudioGPT](https://github.com/AIGC-Audio/AudioGPT)'
9
- _DESCRIPTION += '\n<p>This is a demo to the work <a href="https://github.com/AIGC-Audio/AudioGPT" style="text-decoration: underline;" target="_blank">AudioGPT: Understanding and Generating Speech, Music, Sound, and Talking Head</a>. </p>'
10
- _DESCRIPTION += '\n<p>This model can only be used for non-commercial purposes.'
11
- if (SPACE_ID := os.getenv('SPACE_ID')) is not None:
12
- _DESCRIPTION += f'\n<p>For faster inference without waiting in queue, you may duplicate the space and upgrade to GPU in settings. <a href="https://huggingface.co/spaces/{SPACE_ID}?duplicate=true"><img style="display: inline; margin-top: 0em; margin-bottom: 0em" src="https://bit.ly/3gLdBN6" alt="Duplicate Space" /></a></p>'
13
-
14
-
15
- AUDIO_CHATGPT_PREFIX = """AudioGPT
16
- AudioGPT can not directly read audios, but it has a list of tools to finish different speech, audio, and singing voice tasks. Each audio will have a file name formed as "audio/xxx.wav". When talking about audios, AudioGPT is very strict to the file name and will never fabricate nonexistent files.
17
- AudioGPT is able to use tools in a sequence, and is loyal to the tool observation outputs rather than faking the audio content and audio file name. It will remember to provide the file name from the last tool observation, if a new audio is generated.
18
- Human may provide new audios to AudioGPT with a description. The description helps AudioGPT to understand this audio, but AudioGPT should use tools to finish following tasks, rather than directly imagine from the description.
19
- Overall, AudioGPT is a powerful audio dialogue assistant tool that can help with a wide range of tasks and provide valuable insights and information on a wide range of topics.
20
- TOOLS:
21
- ------
22
- AudioGPT has access to the following tools:"""
23
-
24
- AUDIO_CHATGPT_FORMAT_INSTRUCTIONS = """To use a tool, please use the following format:
25
- ```
26
- Thought: Do I need to use a tool? Yes
27
- Action: the action to take, should be one of [{tool_names}]
28
- Action Input: the input to the action
29
- Observation: the result of the action
30
- ```
31
- When you have a response to say to the Human, or if you do not need to use a tool, you MUST use the format:
32
- ```
33
- Thought: Do I need to use a tool? No
34
- {ai_prefix}: [your response here]
35
- ```
36
- """
37
-
38
- AUDIO_CHATGPT_SUFFIX = """You are very strict to the filename correctness and will never fake a file name if not exists.
39
- You will remember to provide the audio file name loyally if it's provided in the last tool observation.
40
- Begin!
41
- Previous conversation history:
42
- {chat_history}
43
- New input: {input}
44
- Thought: Do I need to use a tool? {agent_scratchpad}"""
45
-
46
- def cut_dialogue_history(history_memory, keep_last_n_words = 500):
47
- tokens = history_memory.split()
48
- n_tokens = len(tokens)
49
- print(f"history_memory:{history_memory}, n_tokens: {n_tokens}")
50
- if n_tokens < keep_last_n_words:
51
- return history_memory
52
- else:
53
- paragraphs = history_memory.split('\n')
54
- last_n_tokens = n_tokens
55
- while last_n_tokens >= keep_last_n_words:
56
- last_n_tokens = last_n_tokens - len(paragraphs[0].split(' '))
57
- paragraphs = paragraphs[1:]
58
- return '\n' + '\n'.join(paragraphs)
59
-
60
- class ConversationBot:
61
- def __init__(self, load_dict):
62
- print("Initializing AudioGPT")
63
- self.tools = []
64
- self.memory = ConversationBufferMemory(memory_key="chat_history", output_key='output')
65
- self.models = dict()
66
- for class_name, device in load_dict.items():
67
- self.models[class_name] = globals()[class_name](device=device)
68
- for class_name, instance in self.models.items():
69
- for e in dir(instance):
70
- if e.startswith('inference'):
71
- func = getattr(instance, e)
72
- self.tools.append(Tool(name=func.name, description=func.description, func=func))
73
-
74
- def run_text(self, text, state):
75
- print("===============Running run_text =============")
76
- print("Inputs:", text, state)
77
- print("======>Previous memory:\n %s" % self.agent.memory)
78
- self.agent.memory.buffer = cut_dialogue_history(self.agent.memory.buffer, keep_last_n_words=500)
79
- res = self.agent({"input": text})
80
- if res['intermediate_steps'] == []:
81
- print("======>Current memory:\n %s" % self.agent.memory)
82
- response = res['output']
83
- state = state + [(text, response)]
84
- print("Outputs:", state)
85
- return state, state, gr.Audio.update(visible=False), gr.Image.update(visible=False), gr.Button.update(visible=False)
86
- else:
87
- tool = res['intermediate_steps'][0][0].tool
88
- if tool == "Generate Image From User Input Text":
89
- res['output'] = res['output'].replace("\\", "/")
90
- response = re.sub('(image/\S*png)', lambda m: f'![](/file={m.group(0)})*{m.group(0)}*', res['output'])
91
- state = state + [(text, response)]
92
- print(f"\nProcessed run_text, Input text: {text}\nCurrent state: {state}\n"
93
- f"Current Memory: {self.agent.memory.buffer}")
94
- return state, state, gr.Audio.update(visible=False), gr.Image.update(visible=False), gr.Button.update(visible=False)
95
- elif tool == "Detect The Sound Event From The Audio":
96
- image_filename = res['intermediate_steps'][0][1]
97
- response = res['output'] + f"![](/file={image_filename})*{image_filename}*"
98
- state = state + [(text, response)]
99
- print(f"\nProcessed run_text, Input text: {text}\nCurrent state: {state}\n"
100
- f"Current Memory: {self.agent.memory.buffer}")
101
- return state, state, gr.Audio.update(visible=False), gr.Image.update(visible=False), gr.Button.update(visible=False)
102
- elif tool == "Generate Text From The Audio" or tool == "Transcribe speech" or tool == "Target Sound Detection":
103
- print("======>Current memory:\n %s" % self.agent.memory)
104
- response = re.sub('(image/\S*png)', lambda m: f'![](/file={m.group(0)})*{m.group(0)}*', res['output'])
105
- image_filename = res['intermediate_steps'][0][1]
106
- #response = res['output'] + f"![](/file={image_filename})*{image_filename}*"
107
- state = state + [(text, response)]
108
- print("Outputs:", state)
109
- return state, state, gr.Audio.update(visible=False), gr.Image.update(visible=False), gr.Button.update(visible=False)
110
- elif tool == "Audio Inpainting":
111
- audio_filename = res['intermediate_steps'][0][0].tool_input
112
- image_filename = res['intermediate_steps'][0][1]
113
- print("======>Current memory:\n %s" % self.agent.memory)
114
- print(res)
115
- response = res['output']
116
- state = state + [(text, response)]
117
- print("Outputs:", state)
118
- return state, state, gr.Audio.update(value=audio_filename,visible=True), gr.Image.update(value=image_filename,visible=True), gr.Button.update(visible=True)
119
- print("======>Current memory:\n %s" % self.agent.memory)
120
- response = re.sub('(image/\S*png)', lambda m: f'![](/file={m.group(0)})*{m.group(0)}*', res['output'])
121
- audio_filename = res['intermediate_steps'][0][1]
122
- state = state + [(text, response)]
123
- print("Outputs:", state)
124
- return state, state, gr.Audio.update(value=audio_filename,visible=True), gr.Image.update(visible=False), gr.Button.update(visible=False)
125
-
126
- def run_image_or_audio(self, file, state, txt):
127
- file_type = file.name[-3:]
128
- if file_type == "wav":
129
- print("===============Running run_audio =============")
130
- print("Inputs:", file, state)
131
- print("======>Previous memory:\n %s" % self.agent.memory)
132
- audio_filename = os.path.join('audio', str(uuid.uuid4())[0:8] + ".wav")
133
- audio_load = whisper.load_audio(file.name)
134
- soundfile.write(audio_filename, audio_load, samplerate = 16000)
135
- description = self.models['A2T'].inference(audio_filename)
136
- Human_prompt = "\nHuman: provide an audio named {}. The description is: {}. This information helps you to understand this audio, but you should use tools to finish following tasks, " \
137
- "rather than directly imagine from my description. If you understand, say \"Received\". \n".format(audio_filename, description)
138
- AI_prompt = "Received. "
139
- self.agent.memory.buffer = self.agent.memory.buffer + Human_prompt + 'AI: ' + AI_prompt
140
- # AI_prompt = "Received. "
141
- # self.agent.memory.buffer = self.agent.memory.buffer + 'AI: ' + AI_prompt
142
- print("======>Current memory:\n %s" % self.agent.memory)
143
- #state = state + [(f"<audio src=audio_filename controls=controls></audio>*{audio_filename}*", AI_prompt)]
144
- state = state + [(f"*{audio_filename}*", AI_prompt)]
145
- print("Outputs:", state)
146
- return state, state, txt + ' ' + audio_filename + ' ', gr.Audio.update(value=audio_filename,visible=True)
147
- else:
148
- # print("===============Running run_image =============")
149
- # print("Inputs:", file, state)
150
- # print("======>Previous memory:\n %s" % self.agent.memory)
151
- image_filename = os.path.join('image', str(uuid.uuid4())[0:8] + ".png")
152
- print("======>Auto Resize Image...")
153
- img = Image.open(file.name)
154
- width, height = img.size
155
- ratio = min(512 / width, 512 / height)
156
- width_new, height_new = (round(width * ratio), round(height * ratio))
157
- width_new = int(np.round(width_new / 64.0)) * 64
158
- height_new = int(np.round(height_new / 64.0)) * 64
159
- img = img.resize((width_new, height_new))
160
- img = img.convert('RGB')
161
- img.save(image_filename, "PNG")
162
- print(f"Resize image form {width}x{height} to {width_new}x{height_new}")
163
- description = self.models['ImageCaptioning'].inference(image_filename)
164
- Human_prompt = "\nHuman: provide an audio named {}. The description is: {}. This information helps you to understand this audio, but you should use tools to finish following tasks, " \
165
- "rather than directly imagine from my description. If you understand, say \"Received\". \n".format(image_filename, description)
166
- AI_prompt = "Received. "
167
- self.agent.memory.buffer = self.agent.memory.buffer + Human_prompt + 'AI: ' + AI_prompt
168
- print("======>Current memory:\n %s" % self.agent.memory)
169
- state = state + [(f"![](/file={image_filename})*{image_filename}*", AI_prompt)]
170
- print(f"\nProcessed run_image, Input image: {image_filename}\nCurrent state: {state}\n"
171
- f"Current Memory: {self.agent.memory.buffer}")
172
- return state, state, txt + f'{txt} {image_filename} ', gr.Audio.update(visible=False)
173
-
174
- def inpainting(self, state, audio_filename, image_filename):
175
- print("===============Running inpainting =============")
176
- print("Inputs:", state)
177
- print("======>Previous memory:\n %s" % self.agent.memory)
178
- # inpaint = Inpaint(device="cpu")
179
- new_image_filename, new_audio_filename = self.models['Inpaint'].predict(audio_filename, image_filename)
180
- AI_prompt = "Here are the predict audio and the mel spectrum." + f"*{new_audio_filename}*" + f"![](/file={new_image_filename})*{new_image_filename}*"
181
- self.agent.memory.buffer = self.agent.memory.buffer + 'AI: ' + AI_prompt
182
- print("======>Current memory:\n %s" % self.agent.memory)
183
- state = state + [(f"Audio Inpainting", AI_prompt)]
184
- print("Outputs:", state)
185
- return state, state, gr.Image.update(visible=False), gr.Audio.update(value=new_audio_filename, visible=True), gr.Button.update(visible=False)
186
- def clear_audio(self):
187
- return gr.Audio.update(value=None, visible=False)
188
- def clear_image(self):
189
- return gr.Image.update(value=None, visible=False)
190
- def clear_button(self):
191
- return gr.Button.update(visible=False)
192
- def init_agent(self, openai_api_key):
193
- os.system('nvidia-smi')
194
- self.llm = OpenAI(temperature=0, openai_api_key=openai_api_key)
195
- self.agent = initialize_agent(
196
- self.tools,
197
- self.llm,
198
- agent="conversational-react-description",
199
- verbose=True,
200
- memory=self.memory,
201
- return_intermediate_steps=True,
202
- agent_kwargs={'prefix': AUDIO_CHATGPT_PREFIX, 'format_instructions': AUDIO_CHATGPT_FORMAT_INSTRUCTIONS, 'suffix': AUDIO_CHATGPT_SUFFIX}, )
203
- return gr.update(visible = True)
204
-
205
-
206
-
207
- if __name__ == '__main__':
208
- bot = ConversationBot({#'ImageCaptioning': 'cuda:0',
209
- 'T2A': 'cuda:0',
210
- #'I2A': 'cuda:0'
211
- #'TTS_OOD':'cuda:0'
212
- 'TTS': 'cuda:0',
213
- #'T2S': 'cuda:0'
214
- 'ASR': 'cuda:0',
215
- 'A2T': 'cuda:0',
216
- #'Inpaint': 'cuda:0',
217
- #'SoundDetection': 'cuda:0'
218
- 'Binaural': 'cuda:0'
219
- #'SoundExtraction': 'cuda:0'
220
- #'TargetSoundDetection': 'cuda:0',
221
- #'Speech_Enh_SC': 'cuda:0'
222
- #'Speech_SS': 'cuda:0'
223
- })
224
- with gr.Blocks(css="#chatbot {overflow:auto; height:500px;}") as demo:
225
- gr.Markdown(_DESCRIPTION)
226
-
227
- with gr.Row():
228
- openai_api_key_textbox = gr.Textbox(
229
- placeholder="Paste your OpenAI API key here to start Audio ChatGPT(sk-...) and press Enter ↵️",
230
- show_label=False,
231
- lines=1,
232
- type="password",
233
- )
234
-
235
- chatbot = gr.Chatbot(elem_id="chatbot", label="Audio ChatGPT")
236
- state = gr.State([])
237
- with gr.Row(visible = False) as input_raws:
238
- with gr.Column(scale=0.7):
239
- txt = gr.Textbox(show_label=False, placeholder="Enter text and press enter, or upload an image").style(container=False)
240
- with gr.Column(scale=0.1, min_width=0):
241
- run = gr.Button("🏃‍♂️Run")
242
- with gr.Column(scale=0.1, min_width=0):
243
- clear = gr.Button("🔄Clear️")
244
- with gr.Column(scale=0.1, min_width=0):
245
- btn = gr.UploadButton("🖼️/🎙️ Upload", file_types=["image","audio"])
246
- with gr.Row():
247
- with gr.Column():
248
- outaudio = gr.Audio(visible=False)
249
- with gr.Row():
250
- with gr.Column():
251
- show_mel = gr.Image(type="filepath",tool='sketch',visible=False)
252
- with gr.Row():
253
- with gr.Column():
254
- run_button = gr.Button("Predict Masked Place",visible=False)
255
- gr.Examples(
256
- examples=["Generate a speech with text 'here we go'",
257
- "Transcribe this speech",
258
- "Transfer the mono speech to a binaural one",
259
- "Generate an audio of a dog barking",
260
- "Generate an audio of this uploaded image",
261
- "Give me the description of this audio",
262
- "I want to inpaint it",
263
- "What events does this audio include?",
264
- "When did the thunder happen in this audio?",
265
- "Extract the thunder event from this audio",
266
- "Generate a piece of singing voice. Text sequence is 小酒窝长睫毛AP是你最美的记号. Note sequence is C#4/Db4 | F#4/Gb4 | G#4/Ab4 | A#4/Bb4 F#4/Gb4 | F#4/Gb4 C#4/Db4 | C#4/Db4 | rest | C#4/Db4 | A#4/Bb4 | G#4/Ab4 | A#4/Bb4 | G#4/Ab4 | F4 | C#4/Db4. Note duration sequence is 0.407140 | 0.376190 | 0.242180 | 0.509550 0.183420 | 0.315400 0.235020 | 0.361660 | 0.223070 | 0.377270 | 0.340550 | 0.299620 | 0.344510 | 0.283770 | 0.323390 | 0.360340.",
267
- ],
268
- inputs=txt
269
  )
270
 
271
- openai_api_key_textbox.submit(bot.init_agent, [openai_api_key_textbox], [input_raws])
272
- txt.submit(bot.run_text, [txt, state], [chatbot, state, outaudio, show_mel, run_button])
273
- txt.submit(lambda: "", None, txt)
274
- run.click(bot.run_text, [txt, state], [chatbot, state, outaudio, show_mel, run_button])
275
- run.click(lambda: "", None, txt)
276
- btn.upload(bot.run_image_or_audio, [btn, state, txt], [chatbot, state, txt, outaudio])
277
- run_button.click(bot.inpainting, [state, outaudio, show_mel], [chatbot, state, show_mel, outaudio, run_button])
278
- clear.click(bot.memory.clear)
279
- clear.click(lambda: [], None, chatbot)
280
- clear.click(lambda: [], None, state)
281
- clear.click(lambda:None, None, txt)
282
- clear.click(bot.clear_button, None, run_button)
283
- clear.click(bot.clear_image, None, show_mel)
284
- clear.click(bot.clear_audio, None, outaudio)
285
- demo.launch(server_name="0.0.0.0", server_port=7860)
1
+ import torch
2
+ import numpy as np
3
  import gradio as gr
4
+ from PIL import Image
5
+ from omegaconf import OmegaConf
6
+ from pathlib import Path
7
+ from vocoder.bigvgan.models import VocoderBigVGAN
8
+ from ldm.models.diffusion.ddim import DDIMSampler
9
+ from ldm.util import instantiate_from_config
10
+ from wav_evaluation.models.CLAPWrapper import CLAPWrapper
11
+
12
+ SAMPLE_RATE = 16000
13
+
14
+ torch.set_grad_enabled(False)
15
+ device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
16
+
17
+ def dur_to_size(duration):
18
+ latent_width = int(duration * 7.8)
19
+ if latent_width % 4 != 0:
20
+ latent_width = (latent_width // 4 + 1) * 4
21
+ return latent_width
22
+
23
+ def initialize_model(config, ckpt):
24
+ config = OmegaConf.load(config)
25
+ model = instantiate_from_config(config.model)
26
+ model.load_state_dict(torch.load(ckpt,map_location='cpu')["state_dict"], strict=False)
27
+
28
+ model = model.to(device)
29
+ model.cond_stage_model.to(model.device)
30
+ model.cond_stage_model.device = model.device
31
+ print(model.device,device,model.cond_stage_model.device)
32
+ sampler = DDIMSampler(model)
33
+
34
+ return sampler
35
+
36
+ sampler = initialize_model('configs/text_to_audio/txt2audio_args.yaml', 'useful_ckpts/maa1_caps.ckpt')
37
+ vocoder = VocoderBigVGAN('vocoder/logs/bigvnat',device=device)
38
+ clap_model = CLAPWrapper('useful_ckpts/CLAP/CLAP_weights_2022.pth','useful_ckpts/CLAP/config.yml',use_cuda=torch.cuda.is_available())
39
+
40
+ def select_best_audio(prompt,wav_list):
41
+ text_embeddings = clap_model.get_text_embeddings([prompt])
42
+ score_list = []
43
+ for data in wav_list:
44
+ sr,wav = data
45
+ audio_embeddings = clap_model.get_audio_embeddings([(torch.FloatTensor(wav),sr)], resample=True)
46
+ score = clap_model.compute_similarity(audio_embeddings, text_embeddings,use_logit_scale=False).squeeze().cpu().numpy()
47
+ score_list.append(score)
48
+ max_index = np.array(score_list).argmax()
49
+ print(score_list,max_index)
50
+ return wav_list[max_index]
51
+
52
+ def txt2audio(sampler,vocoder,prompt, seed, scale, ddim_steps, n_samples=1, W=624, H=80):
53
+ prng = np.random.RandomState(seed)
54
+ start_code = prng.randn(n_samples, sampler.model.first_stage_model.embed_dim, H // 8, W // 8)
55
+ start_code = torch.from_numpy(start_code).to(device=device, dtype=torch.float32)
56
+
57
+ uc = None
58
+ if scale != 1.0:
59
+ uc = sampler.model.get_learned_conditioning(n_samples * [""])
60
+ c = sampler.model.get_learned_conditioning(n_samples * [prompt])  # shape: [1, 77, 1280]; still per-token embeddings, not yet pooled into a sentence embedding
61
+ shape = [sampler.model.first_stage_model.embed_dim, H//8, W//8] # (z_dim, 80//2^x, 848//2^x)
62
+ samples_ddim, _ = sampler.sample(S=ddim_steps,
63
+ conditioning=c,
64
+ batch_size=n_samples,
65
+ shape=shape,
66
+ verbose=False,
67
+ unconditional_guidance_scale=scale,
68
+ unconditional_conditioning=uc,
69
+ x_T=start_code)
70
+
71
+ x_samples_ddim = sampler.model.decode_first_stage(samples_ddim)
72
+
73
+ wav_list = []
74
+ for idx,spec in enumerate(x_samples_ddim):
75
+ wav = vocoder.vocode(spec)
76
+ wav_list.append((SAMPLE_RATE,wav))
77
+ best_wav = select_best_audio(prompt,wav_list)
78
+ return best_wav
79
+
80
+
81
+ def predict(prompt, ddim_steps, num_samples, scale, seed):
82
+ melbins,mel_len = 80,624
83
+ with torch.no_grad():
84
+ result = txt2audio(
85
+ sampler=sampler,
86
+ vocoder=vocoder,
87
+ prompt=prompt,
88
+ seed=seed,
89
+ scale=scale,
90
+ ddim_steps=ddim_steps,
91
+ n_samples=num_samples,
92
+ H=melbins, W=mel_len
93
  )
94
 
95
+ return result
96
+
97
+
98
+ with gr.Blocks() as demo:
99
+ with gr.Row():
100
+ gr.Markdown("## Make-An-Audio: Text-to-Audio Generation")
101
+
102
+ with gr.Row():
103
+ with gr.Column():
104
+ prompt = gr.Textbox(label="Prompt: Enter your text here.")
105
+ run_button = gr.Button(label="Run")
106
+
107
+
108
+ with gr.Accordion("Advanced options", open=False):
109
+ num_samples = gr.Slider(
110
+ label="Select from audios num.This number control the number of candidates \
111
+ (e.g., generate three audios and choose the best to show you). A Larger value usually lead to \
112
+ better quality with heavier computation", minimum=1, maximum=10, value=3, step=1)
113
+ # num_samples = 1
114
+ ddim_steps = gr.Slider(label="Steps", minimum=1,
115
+ maximum=150, value=100, step=1)
116
+ scale = gr.Slider(
117
+ label="Guidance Scale:(Large => more relevant to text but the quality may drop)", minimum=0.1, maximum=4.0, value=1.5, step=0.1
118
+ )
119
+ seed = gr.Slider(
120
+ label="Seed:Change this value (any integer number) will lead to a different generation result.",
121
+ minimum=0,
122
+ maximum=2147483647,
123
+ step=1,
124
+ value=44,
125
+ )
126
+
127
+ with gr.Column():
128
+ # audio_list = []
129
+ # for i in range(int(num_samples)):
130
+ # audio_list.append(gr.outputs.Audio())
131
+ outaudio = gr.Audio()
132
+
133
+
134
+ run_button.click(fn=predict, inputs=[
135
+ prompt, ddim_steps, num_samples, scale, seed], outputs=[outaudio])  # the inputs list may only contain gr.* components
136
+ with gr.Row():
137
+ with gr.Column():
138
+ gr.Examples(
139
+ examples = [['a dog barking and a bird chirping',100,3,2,55],['fireworks pop and explode',100,3,2,55],
140
+ ['piano and violin plays',100,3,2,55],['wind thunder and rain falling',100,3,2,55],['music made by drum kit',100,3,2,55]],
141
+ inputs = [prompt,ddim_steps, num_samples, scale, seed],
142
+ outputs = [outaudio]
143
+ )
144
+ with gr.Column():
145
+ pass
146
+
147
+ demo.launch()
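For readers who want to exercise the new pipeline without the Gradio UI, here is a minimal headless sketch. It relies only on the `predict` function defined above; `soundfile` is an assumed dependency (any wav writer would do), and the prompt and output filename are illustrative.

```python
# Minimal headless sketch: generate one clip with the pipeline above and write it to disk.
# Assumes the checkpoints referenced at module load time (useful_ckpts/...) are present.
import soundfile as sf

sr, wav = predict("a dog barking and a bird chirping",
                  ddim_steps=100, num_samples=3, scale=1.5, seed=55)
# predict() returns the CLAP-ranked best candidate as a (sample_rate, waveform) tuple.
sf.write("barking_and_chirping.wav", wav, sr)
```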
configs/text_to_audio/bigvgan_args.yaml ADDED
@@ -0,0 +1,63 @@
1
+ resblock: '1'
2
+ num_gpus: 0
3
+ batch_size: 64
4
+ num_mels: 80
5
+ learning_rate: 0.0001
6
+ adam_b1: 0.8
7
+ adam_b2: 0.99
8
+ lr_decay: 0.999
9
+ seed: 1234
10
+ upsample_rates:
11
+ - 4
12
+ - 4
13
+ - 2
14
+ - 2
15
+ - 2
16
+ - 2
17
+ upsample_kernel_sizes:
18
+ - 8
19
+ - 8
20
+ - 4
21
+ - 4
22
+ - 4
23
+ - 4
24
+ upsample_initial_channel: 1536
25
+ resblock_kernel_sizes:
26
+ - 3
27
+ - 7
28
+ - 11
29
+ resblock_dilation_sizes:
30
+ - - 1
31
+ - 3
32
+ - 5
33
+ - - 1
34
+ - 3
35
+ - 5
36
+ - - 1
37
+ - 3
38
+ - 5
39
+ activation: snakebeta
40
+ snake_logscale: true
41
+ resolutions:
42
+ - - 1024
43
+ - 120
44
+ - 600
45
+ - - 2048
46
+ - 240
47
+ - 1200
48
+ - - 512
49
+ - 50
50
+ - 240
51
+ mpd_reshapes:
52
+ - 2
53
+ - 3
54
+ - 5
55
+ - 7
56
+ - 11
57
+ use_spectral_norm: false
58
+ discriminator_channel_mult: 1
59
+ num_workers: 4
60
+ dist_config:
61
+ dist_backend: nccl
62
+ dist_url: tcp://localhost:54341
63
+ world_size: 1
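A quick way to sanity-check this vocoder config is to confirm that the product of `upsample_rates` gives the number of waveform samples produced per mel frame. The sketch below only rechecks the numbers in the YAML; the implied 256-sample hop at 16 kHz is an inference, not something stated in this file.

```python
# Sketch: load the BigVGAN config and compute its total upsampling factor.
import math
from omegaconf import OmegaConf

h = OmegaConf.load("configs/text_to_audio/bigvgan_args.yaml")
total_upsample = math.prod(h.upsample_rates)   # 4*4*2*2*2*2 = 256 samples per mel frame
print(h.num_mels, total_upsample)              # 80 mel bins, 256x upsampling
assert len(h.upsample_rates) == len(h.upsample_kernel_sizes)
```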
configs/text_to_audio/clap_args.yaml ADDED
@@ -0,0 +1,26 @@
1
+ # TEXT ENCODER CONFIG
2
+ text_model: 'bert-base-uncased'
3
+ text_len: 100
4
+ transformer_embed_dim: 768
5
+ freeze_text_encoder_weights: True
6
+
7
+ # AUDIO ENCODER CONFIG
8
+ audioenc_name: 'Cnn14'
9
+ out_emb: 2048
10
+ sampling_rate: 44100
11
+ duration: 9
12
+ fmin: 50
13
+ fmax: 14000
14
+ n_fft: 1028
15
+ hop_size: 320
16
+ mel_bins: 64
17
+ window_size: 1024
18
+
19
+ # PROJECTION SPACE CONFIG
20
+ d_proj: 1024
21
+ temperature: 0.003
22
+
23
+ # TRAINING AND EVALUATION CONFIG
24
+ num_classes: 527
25
+ batch_size: 1024
26
+ demo: False
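This file configures the CLAP model that app.py uses to rerank candidate clips (app.py loads an equivalent config shipped alongside the checkpoint). The sketch below mirrors that usage with the same wrapper calls that appear in app.py; the checkpoint paths and the dummy waveform are placeholders.

```python
# Sketch: score one (placeholder) waveform against a text prompt with the CLAP wrapper used in app.py.
import torch
from wav_evaluation.models.CLAPWrapper import CLAPWrapper

clap = CLAPWrapper("useful_ckpts/CLAP/CLAP_weights_2022.pth", "useful_ckpts/CLAP/config.yml",
                   use_cuda=torch.cuda.is_available())
text_emb = clap.get_text_embeddings(["a dog barking and a bird chirping"])
wav = torch.randn(4 * 16000)                                    # placeholder 4-second clip at 16 kHz
audio_emb = clap.get_audio_embeddings([(wav, 16000)], resample=True)
score = clap.compute_similarity(audio_emb, text_emb, use_logit_scale=False)
print(float(score.squeeze()))                                   # higher means a better text/audio match
```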
configs/text_to_audio/txt2audio_args.yaml ADDED
@@ -0,0 +1,78 @@
1
+ model:
2
+ base_learning_rate: 1.0e-05
3
+ target: ldm.models.diffusion.ddpm_audio.LatentDiffusion_audio
4
+ params:
5
+ linear_start: 0.00085
6
+ linear_end: 0.0120
7
+ num_timesteps_cond: 1
8
+ log_every_t: 200
9
+ timesteps: 1000
10
+ first_stage_key: image
11
+ cond_stage_key: caption
12
+ image_size: 32 # unused
13
+ mel_dim: 10 # 80 // 2^3
14
+ mel_length: 78 # 624 // 2^3
15
+ channels: 4
16
+ cond_stage_trainable: false
17
+ conditioning_key: crossattn
18
+ monitor: val/loss_simple_ema
19
+ scale_by_std: True
20
+ use_ema: False
21
+
22
+ scheduler_config: # 10000 warmup steps
23
+ target: ldm.lr_scheduler.LambdaLinearScheduler
24
+ params:
25
+ warm_up_steps: [10000]
26
+ cycle_lengths: [10000000000000]
27
+ f_start: [1.e-6]
28
+ f_max: [1.]
29
+ f_min: [ 1.]
30
+
31
+ unet_config:
32
+ target: ldm.modules.diffusionmodules.openaimodel.UNetModel
33
+ params:
34
+ image_size: 32 # unused
35
+ in_channels: 4
36
+ out_channels: 4
37
+ model_channels: 320
38
+ attention_resolutions:
39
+ - 1
40
+ - 2
41
+ num_res_blocks: 2
42
+ channel_mult: # num_down = len(ch_mult)-1
43
+ - 1
44
+ - 2
45
+ num_heads: 8
46
+ use_spatial_transformer: true
47
+ transformer_depth: 1
48
+ context_dim: 1024
49
+ use_checkpoint: true
50
+ legacy: False
51
+
52
+ first_stage_config:
53
+ target: ldm.models.autoencoder.AutoencoderKL
54
+ params:
55
+ embed_dim: 4
56
+ monitor: val/rec_loss
57
+ ckpt_path:
58
+ ddconfig:
59
+ double_z: true
60
+ z_channels: 4
61
+ resolution: 624
62
+ in_channels: 1
63
+ out_ch: 1
64
+ ch: 128
65
+ ch_mult: [ 1, 2, 2, 4 ] # num_down = len(ch_mult)-1
66
+ num_res_blocks: 2
67
+ attn_resolutions: [78, 156]
68
+ dropout: 0.0
69
+ lossconfig:
70
+ target: torch.nn.Identity
71
+
72
+ cond_stage_config:
73
+ target: ldm.modules.encoders.modules.FrozenCLAPEmbedder
74
+ params:
75
+ weights_path: useful_ckpts/CLAP/CLAP_weights_2022.pth
76
+
77
+ ckpt_path: useful_ckpts/maa1_caps.ckpt
78
+
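The latent geometry in this config is tied to the 80 x 624 mel spectrogram that app.py samples (H=80, W=624). The sketch below simply rechecks that arithmetic; treating the downsampling factor as 2^(len(ch_mult)-1) follows the comments in the YAML itself.

```python
# Sketch: confirm the config's latent shape matches an 80x624 mel downsampled by the KL autoencoder.
from omegaconf import OmegaConf

cfg = OmegaConf.load("configs/text_to_audio/txt2audio_args.yaml")
p = cfg.model.params
ch_mult = p.first_stage_config.params.ddconfig.ch_mult
down = 2 ** (len(ch_mult) - 1)                      # three downsampling stages -> 8x

assert p.mel_dim == 80 // down                      # 10
assert p.mel_length == 624 // down                  # 78
assert p.channels == p.first_stage_config.params.embed_dim  # latent channels = embed_dim = 4
```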
ldm/lr_scheduler.py ADDED
@@ -0,0 +1,98 @@
1
+ import numpy as np
2
+
3
+
4
+ class LambdaWarmUpCosineScheduler:
5
+ """
6
+ note: use with a base_lr of 1.0
7
+ """
8
+ def __init__(self, warm_up_steps, lr_min, lr_max, lr_start, max_decay_steps, verbosity_interval=0):
9
+ self.lr_warm_up_steps = warm_up_steps
10
+ self.lr_start = lr_start
11
+ self.lr_min = lr_min
12
+ self.lr_max = lr_max
13
+ self.lr_max_decay_steps = max_decay_steps
14
+ self.last_lr = 0.
15
+ self.verbosity_interval = verbosity_interval
16
+
17
+ def schedule(self, n, **kwargs):
18
+ if self.verbosity_interval > 0:
19
+ if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_lr}")
20
+ if n < self.lr_warm_up_steps:
21
+ lr = (self.lr_max - self.lr_start) / self.lr_warm_up_steps * n + self.lr_start
22
+ self.last_lr = lr
23
+ return lr
24
+ else:
25
+ t = (n - self.lr_warm_up_steps) / (self.lr_max_decay_steps - self.lr_warm_up_steps)
26
+ t = min(t, 1.0)
27
+ lr = self.lr_min + 0.5 * (self.lr_max - self.lr_min) * (
28
+ 1 + np.cos(t * np.pi))
29
+ self.last_lr = lr
30
+ return lr
31
+
32
+ def __call__(self, n, **kwargs):
33
+ return self.schedule(n,**kwargs)
34
+
35
+
36
+ class LambdaWarmUpCosineScheduler2:
37
+ """
38
+ supports repeated iterations, configurable via lists
39
+ note: use with a base_lr of 1.0.
40
+ """
41
+ def __init__(self, warm_up_steps, f_min, f_max, f_start, cycle_lengths, verbosity_interval=0):
42
+ assert len(warm_up_steps) == len(f_min) == len(f_max) == len(f_start) == len(cycle_lengths)
43
+ self.lr_warm_up_steps = warm_up_steps
44
+ self.f_start = f_start
45
+ self.f_min = f_min
46
+ self.f_max = f_max
47
+ self.cycle_lengths = cycle_lengths
48
+ self.cum_cycles = np.cumsum([0] + list(self.cycle_lengths))
49
+ self.last_f = 0.
50
+ self.verbosity_interval = verbosity_interval
51
+
52
+ def find_in_interval(self, n):
53
+ interval = 0
54
+ for cl in self.cum_cycles[1:]:
55
+ if n <= cl:
56
+ return interval
57
+ interval += 1
58
+
59
+ def schedule(self, n, **kwargs):
60
+ cycle = self.find_in_interval(n)
61
+ n = n - self.cum_cycles[cycle]
62
+ if self.verbosity_interval > 0:
63
+ if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_f}, "
64
+ f"current cycle {cycle}")
65
+ if n < self.lr_warm_up_steps[cycle]:
66
+ f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle]
67
+ self.last_f = f
68
+ return f
69
+ else:
70
+ t = (n - self.lr_warm_up_steps[cycle]) / (self.cycle_lengths[cycle] - self.lr_warm_up_steps[cycle])
71
+ t = min(t, 1.0)
72
+ f = self.f_min[cycle] + 0.5 * (self.f_max[cycle] - self.f_min[cycle]) * (
73
+ 1 + np.cos(t * np.pi))
74
+ self.last_f = f
75
+ return f
76
+
77
+ def __call__(self, n, **kwargs):
78
+ return self.schedule(n, **kwargs)
79
+
80
+
81
+ class LambdaLinearScheduler(LambdaWarmUpCosineScheduler2):
82
+
83
+ def schedule(self, n, **kwargs):
84
+ cycle = self.find_in_interval(n)
85
+ n = n - self.cum_cycles[cycle]
86
+ if self.verbosity_interval > 0:
87
+ if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_f}, "
88
+ f"current cycle {cycle}")
89
+
90
+ if n < self.lr_warm_up_steps[cycle]:
91
+ f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle]
92
+ self.last_f = f
93
+ return f
94
+ else:
95
+ f = self.f_min[cycle] + (self.f_max[cycle] - self.f_min[cycle]) * (self.cycle_lengths[cycle] - n) / (self.cycle_lengths[cycle])
96
+ self.last_f = f
97
+ return f
98
+
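Below is a small sketch of how `LambdaLinearScheduler` behaves with the warm-up settings from txt2audio_args.yaml, and how it is typically wrapped in a torch `LambdaLR` (the same pattern used in `configure_optimizers` in ldm/models/autoencoder.py). The toy optimizer is only for illustration.

```python
# Sketch: warm up from ~1e-6 to 1.0 over 10k steps, then hold (f_min == f_max == 1.0).
import torch
from torch.optim.lr_scheduler import LambdaLR
from ldm.lr_scheduler import LambdaLinearScheduler

sched = LambdaLinearScheduler(warm_up_steps=[10000], f_min=[1.0], f_max=[1.0],
                              f_start=[1.0e-6], cycle_lengths=[10000000000000])
print([round(sched(n), 4) for n in (0, 2500, 5000, 10000, 50000)])  # [0.0, 0.25, 0.5, 1.0, 1.0]

# The scheduler returns a multiplier on top of base_learning_rate, so it plugs into LambdaLR:
opt = torch.optim.Adam([torch.nn.Parameter(torch.zeros(1))], lr=1.0e-5)
lr_schedule = LambdaLR(opt, lr_lambda=sched.schedule)
```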
ldm/models/autoencoder.py ADDED
@@ -0,0 +1,474 @@
1
+ import os
2
+ import torch
3
+ import pytorch_lightning as pl
4
+ import torch.nn.functional as F
5
+ from contextlib import contextmanager
6
+ from packaging import version
7
+ import numpy as np
8
+ from ldm.modules.diffusionmodules.model import Encoder, Decoder
9
+ from ldm.modules.distributions.distributions import DiagonalGaussianDistribution
10
+ from torch.optim.lr_scheduler import LambdaLR
11
+ from ldm.util import instantiate_from_config
12
+ # from icecream import ic
13
+
14
+ class VQModel(pl.LightningModule):
15
+ def __init__(self,
16
+ ddconfig,
17
+ lossconfig,
18
+ n_embed,
19
+ embed_dim,
20
+ ckpt_path=None,
21
+ ignore_keys=[],
22
+ image_key="image",
23
+ colorize_nlabels=None,
24
+ monitor=None,
25
+ batch_resize_range=None,
26
+ scheduler_config=None,
27
+ lr_g_factor=1.0,
28
+ remap=None,
29
+ sane_index_shape=False, # tell vector quantizer to return indices as bhw
30
+ use_ema=False
31
+ ):
32
+ super().__init__()
33
+ self.embed_dim = embed_dim
34
+ self.n_embed = n_embed
35
+ self.image_key = image_key
36
+ self.encoder = Encoder(**ddconfig)
37
+ self.decoder = Decoder(**ddconfig)
38
+ self.loss = instantiate_from_config(lossconfig)
39
+ self.quantize = VectorQuantizer(n_embed, embed_dim, beta=0.25,
40
+ remap=remap,
41
+ sane_index_shape=sane_index_shape)
42
+ self.quant_conv = torch.nn.Conv2d(ddconfig["z_channels"], embed_dim, 1)
43
+ self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
44
+ if colorize_nlabels is not None:
45
+ assert type(colorize_nlabels)==int
46
+ self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1))
47
+ if monitor is not None:
48
+ self.monitor = monitor
49
+ self.batch_resize_range = batch_resize_range
50
+ if self.batch_resize_range is not None:
51
+ print(f"{self.__class__.__name__}: Using per-batch resizing in range {batch_resize_range}.")
52
+
53
+ self.use_ema = use_ema
54
+ if self.use_ema:
55
+ self.model_ema = LitEma(self)
56
+ print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
57
+
58
+ if ckpt_path is not None:
59
+ self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
60
+ self.scheduler_config = scheduler_config
61
+ self.lr_g_factor = lr_g_factor
62
+
63
+ @contextmanager
64
+ def ema_scope(self, context=None):
65
+ if self.use_ema:
66
+ self.model_ema.store(self.parameters())
67
+ self.model_ema.copy_to(self)
68
+ if context is not None:
69
+ print(f"{context}: Switched to EMA weights")
70
+ try:
71
+ yield None
72
+ finally:
73
+ if self.use_ema:
74
+ self.model_ema.restore(self.parameters())
75
+ if context is not None:
76
+ print(f"{context}: Restored training weights")
77
+
78
+ def init_from_ckpt(self, path, ignore_keys=list()):
79
+ sd = torch.load(path, map_location="cpu")["state_dict"]
80
+ keys = list(sd.keys())
81
+ for k in keys:
82
+ for ik in ignore_keys:
83
+ if k.startswith(ik):
84
+ print("Deleting key {} from state_dict.".format(k))
85
+ del sd[k]
86
+ missing, unexpected = self.load_state_dict(sd, strict=False)
87
+ print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys")
88
+ if len(missing) > 0:
89
+ print(f"Missing Keys: {missing}")
90
+ print(f"Unexpected Keys: {unexpected}")
91
+
92
+ def on_train_batch_end(self, *args, **kwargs):
93
+ if self.use_ema:
94
+ self.model_ema(self)
95
+
96
+ def encode(self, x):
97
+ h = self.encoder(x)
98
+ h = self.quant_conv(h)
99
+ quant, emb_loss, info = self.quantize(h)
100
+ return quant, emb_loss, info
101
+
102
+ def encode_to_prequant(self, x):
103
+ h = self.encoder(x)
104
+ h = self.quant_conv(h)
105
+ return h
106
+
107
+ def decode(self, quant):
108
+ quant = self.post_quant_conv(quant)
109
+ dec = self.decoder(quant)
110
+ return dec
111
+
112
+ def decode_code(self, code_b):
113
+ quant_b = self.quantize.embed_code(code_b)
114
+ dec = self.decode(quant_b)
115
+ return dec
116
+
117
+ def forward(self, input, return_pred_indices=False):
118
+ quant, diff, (_,_,ind) = self.encode(input)
119
+ dec = self.decode(quant)
120
+ if return_pred_indices:
121
+ return dec, diff, ind
122
+ return dec, diff
123
+
124
+ def get_input(self, batch, k):
125
+ x = batch[k]
126
+ if len(x.shape) == 3:
127
+ x = x[..., None]
128
+ x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float()
129
+ if self.batch_resize_range is not None:
130
+ lower_size = self.batch_resize_range[0]
131
+ upper_size = self.batch_resize_range[1]
132
+ if self.global_step <= 4:
133
+ # do the first few batches with max size to avoid later oom
134
+ new_resize = upper_size
135
+ else:
136
+ new_resize = np.random.choice(np.arange(lower_size, upper_size+16, 16))
137
+ if new_resize != x.shape[2]:
138
+ x = F.interpolate(x, size=new_resize, mode="bicubic")
139
+ x = x.detach()
140
+ return x
141
+
142
+ def training_step(self, batch, batch_idx, optimizer_idx):
143
+ # https://github.com/pytorch/pytorch/issues/37142
144
+ # try not to fool the heuristics
145
+ x = self.get_input(batch, self.image_key)
146
+ xrec, qloss, ind = self(x, return_pred_indices=True)
147
+
148
+ if optimizer_idx == 0:
149
+ # autoencode
150
+ aeloss, log_dict_ae = self.loss(qloss, x, xrec, optimizer_idx, self.global_step,
151
+ last_layer=self.get_last_layer(), split="train",
152
+ predicted_indices=ind)
153
+
154
+ self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=True)
155
+ return aeloss
156
+
157
+ if optimizer_idx == 1:
158
+ # discriminator
159
+ discloss, log_dict_disc = self.loss(qloss, x, xrec, optimizer_idx, self.global_step,
160
+ last_layer=self.get_last_layer(), split="train")
161
+ self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=True)
162
+ return discloss
163
+
164
+ def validation_step(self, batch, batch_idx):
165
+ log_dict = self._validation_step(batch, batch_idx)
166
+ with self.ema_scope():
167
+ log_dict_ema = self._validation_step(batch, batch_idx, suffix="_ema")
168
+ return log_dict
169
+
170
+ def _validation_step(self, batch, batch_idx, suffix=""):
171
+ x = self.get_input(batch, self.image_key)
172
+ xrec, qloss, ind = self(x, return_pred_indices=True)
173
+ aeloss, log_dict_ae = self.loss(qloss, x, xrec, 0,
174
+ self.global_step,
175
+ last_layer=self.get_last_layer(),
176
+ split="val"+suffix,
177
+ predicted_indices=ind
178
+ )
179
+
180
+ discloss, log_dict_disc = self.loss(qloss, x, xrec, 1,
181
+ self.global_step,
182
+ last_layer=self.get_last_layer(),
183
+ split="val"+suffix,
184
+ predicted_indices=ind
185
+ )
186
+ rec_loss = log_dict_ae[f"val{suffix}/rec_loss"]
187
+ self.log(f"val{suffix}/rec_loss", rec_loss,
188
+ prog_bar=True, logger=True, on_step=False, on_epoch=True, sync_dist=True)
189
+ self.log(f"val{suffix}/aeloss", aeloss,
190
+ prog_bar=True, logger=True, on_step=False, on_epoch=True, sync_dist=True)
191
+ if version.parse(pl.__version__) >= version.parse('1.4.0'):
192
+ del log_dict_ae[f"val{suffix}/rec_loss"]
193
+ self.log_dict(log_dict_ae)
194
+ self.log_dict(log_dict_disc)
195
+ return self.log_dict
196
+
197
+ def test_step(self, batch, batch_idx):
198
+ x = self.get_input(batch, self.image_key)
199
+ xrec, qloss, ind = self(x, return_pred_indices=True)
200
+ reconstructions = (xrec + 1)/2 # to mel scale
201
+ test_ckpt_path = os.path.basename(self.trainer.tested_ckpt_path)
202
+ savedir = os.path.join(self.trainer.log_dir,f'output_imgs_{test_ckpt_path}','fake_class')
203
+ if not os.path.exists(savedir):
204
+ os.makedirs(savedir)
205
+
206
+ file_names = batch['f_name']
207
+ # print(f"reconstructions.shape:{reconstructions.shape}",file_names)
208
+ reconstructions = reconstructions.cpu().numpy().squeeze(1)  # squeeze channel dim
209
+ for b in range(reconstructions.shape[0]):
210
+ vname_num_split_index = file_names[b].rfind('_')# file_names[b]:video_name+'_'+num
211
+ v_n,num = file_names[b][:vname_num_split_index],file_names[b][vname_num_split_index+1:]
212
+ save_img_path = os.path.join(savedir,f'{v_n}_sample_{num}.npy')
213
+ np.save(save_img_path,reconstructions[b])
214
+
215
+ return None
216
+
217
+ def configure_optimizers(self):
218
+ lr_d = self.learning_rate
219
+ lr_g = self.lr_g_factor*self.learning_rate
220
+ print("lr_d", lr_d)
221
+ print("lr_g", lr_g)
222
+ opt_ae = torch.optim.Adam(list(self.encoder.parameters())+
223
+ list(self.decoder.parameters())+
224
+ list(self.quantize.parameters())+
225
+ list(self.quant_conv.parameters())+
226
+ list(self.post_quant_conv.parameters()),
227
+ lr=lr_g, betas=(0.5, 0.9))
228
+ opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(),
229
+ lr=lr_d, betas=(0.5, 0.9))
230
+
231
+ if self.scheduler_config is not None:
232
+ scheduler = instantiate_from_config(self.scheduler_config)
233
+
234
+ print("Setting up LambdaLR scheduler...")
235
+ scheduler = [
236
+ {
237
+ 'scheduler': LambdaLR(opt_ae, lr_lambda=scheduler.schedule),
238
+ 'interval': 'step',
239
+ 'frequency': 1
240
+ },
241
+ {
242
+ 'scheduler': LambdaLR(opt_disc, lr_lambda=scheduler.schedule),
243
+ 'interval': 'step',
244
+ 'frequency': 1
245
+ },
246
+ ]
247
+ return [opt_ae, opt_disc], scheduler
248
+ return [opt_ae, opt_disc], []
249
+
250
+ def get_last_layer(self):
251
+ return self.decoder.conv_out.weight
252
+
253
+ def log_images(self, batch, only_inputs=False, plot_ema=False, **kwargs):
254
+ log = dict()
255
+ x = self.get_input(batch, self.image_key)
256
+ x = x.to(self.device)
257
+ if only_inputs:
258
+ log["inputs"] = x
259
+ return log
260
+ xrec, _ = self(x)
261
+ if x.shape[1] > 3:
262
+ # colorize with random projection
263
+ assert xrec.shape[1] > 3
264
+ x = self.to_rgb(x)
265
+ xrec = self.to_rgb(xrec)
266
+ log["inputs"] = x
267
+ log["reconstructions"] = xrec
268
+ if plot_ema:
269
+ with self.ema_scope():
270
+ xrec_ema, _ = self(x)
271
+ if x.shape[1] > 3: xrec_ema = self.to_rgb(xrec_ema)
272
+ log["reconstructions_ema"] = xrec_ema
273
+ return log
274
+
275
+ def to_rgb(self, x):
276
+ assert self.image_key == "segmentation"
277
+ if not hasattr(self, "colorize"):
278
+ self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x))
279
+ x = F.conv2d(x, weight=self.colorize)
280
+ x = 2.*(x-x.min())/(x.max()-x.min()) - 1.
281
+ return x
282
+
283
+
284
+ class VQModelInterface(VQModel):
285
+ def __init__(self, embed_dim, *args, **kwargs):
286
+ super().__init__(embed_dim=embed_dim, *args, **kwargs)
287
+ self.embed_dim = embed_dim
288
+
289
+ def encode(self, x):  # VQModel quantizes inside encode(); VQModelInterface defers quantization to decode()
290
+ h = self.encoder(x)
291
+ h = self.quant_conv(h)
292
+ return h
293
+
294
+ def decode(self, h, force_not_quantize=False):
295
+ # also go through quantization layer
296
+ if not force_not_quantize:
297
+ quant, emb_loss, info = self.quantize(h)
298
+ else:
299
+ quant = h
300
+ quant = self.post_quant_conv(quant)
301
+ dec = self.decoder(quant)
302
+ return dec
303
+
304
+
305
+ class AutoencoderKL(pl.LightningModule):
306
+ def __init__(self,
307
+ ddconfig,
308
+ lossconfig,
309
+ embed_dim,
310
+ ckpt_path=None,
311
+ ignore_keys=[],
312
+ image_key="image",
313
+ colorize_nlabels=None,
314
+ monitor=None,
315
+ ):
316
+ super().__init__()
317
+ self.image_key = image_key
318
+ self.encoder = Encoder(**ddconfig)
319
+ self.decoder = Decoder(**ddconfig)
320
+ self.loss = instantiate_from_config(lossconfig)
321
+ assert ddconfig["double_z"]
322
+ self.quant_conv = torch.nn.Conv2d(2*ddconfig["z_channels"], 2*embed_dim, 1)
323
+ self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
324
+ self.embed_dim = embed_dim
325
+ if colorize_nlabels is not None:
326
+ assert type(colorize_nlabels)==int
327
+ self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1))
328
+ if monitor is not None:
329
+ self.monitor = monitor
330
+ if ckpt_path is not None:
331
+ self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
332
+ # self.automatic_optimization = False # hjw for debug
333
+
334
+ def init_from_ckpt(self, path, ignore_keys=list()):
335
+ sd = torch.load(path, map_location="cpu")["state_dict"]
336
+ keys = list(sd.keys())
337
+ for k in keys:
338
+ for ik in ignore_keys:
339
+ if k.startswith(ik):
340
+ print("Deleting key {} from state_dict.".format(k))
341
+ del sd[k]
342
+ self.load_state_dict(sd, strict=False)
343
+ print(f"Restored from {path}")
344
+
345
+ def encode(self, x):
346
+ h = self.encoder(x)
347
+ moments = self.quant_conv(h)
348
+ posterior = DiagonalGaussianDistribution(moments)
349
+ return posterior
350
+
351
+ def decode(self, z):
352
+ z = self.post_quant_conv(z)
353
+ dec = self.decoder(z)
354
+ return dec
355
+
356
+ def forward(self, input, sample_posterior=True):
357
+ posterior = self.encode(input)
358
+ if sample_posterior:
359
+ z = posterior.sample()
360
+ else:
361
+ z = posterior.mode()
362
+ dec = self.decode(z)
363
+ return dec, posterior
364
+
365
+ def get_input(self, batch, k):
366
+ x = batch[k]
367
+ if len(x.shape) == 3:
368
+ x = x[..., None]
369
+ x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float()
370
+ return x
371
+
372
+ def training_step(self, batch, batch_idx, optimizer_idx):
373
+ inputs = self.get_input(batch, self.image_key)
374
+ reconstructions, posterior = self(inputs)
375
+
376
+ if optimizer_idx == 0:
377
+ # train encoder+decoder+logvar
378
+ aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step,
379
+ last_layer=self.get_last_layer(), split="train")
380
+ self.log("aeloss", aeloss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
381
+ self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=False)
382
+ return aeloss
383
+
384
+ if optimizer_idx == 1:
385
+ # train the discriminator
386
+ discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step,
387
+ last_layer=self.get_last_layer(), split="train")
388
+
389
+ self.log("discloss", discloss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
390
+ self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=False)
391
+ return discloss
392
+
393
+ def validation_step(self, batch, batch_idx):
394
+ # self.log_images(batch,only_inputs=False,save_dir='mel_result_ae13_26/fake_class')
395
+ return self.log_dict
396
+
397
+ def test_step(self, batch, batch_idx):
398
+ test_ckpt_path = os.path.basename(self.trainer.tested_ckpt_path)
399
+ savedir = os.path.join(self.trainer.log_dir,f'output_imgs_{test_ckpt_path}','fake_class')
400
+ os.makedirs(savedir,exist_ok=True)
401
+ inputs = self.get_input(batch, self.image_key)# inputs shape:(b,c,mel_len,T) or (b,c,h,w)
402
+ # ic(inputs.shape)
403
+ # inputs = inputs[...,:624]
404
+ # ic(inputs.shape)
405
+ xrec, posterior = self(inputs)# reconstructions:(b,c,mel_len,T) or (b,c,h,w)
406
+ file_names = batch['f_name']
407
+ # print(f"reconstructions.shape:{reconstructions.shape}",file_names)
408
+ for b in range(len(file_names)):
409
+ rcon = (xrec[b].squeeze().detach().cpu().numpy() + 1) / 2 # to mel scale,squeeze channel dim
410
+ vname_num_split_index = file_names[b].rfind('_')# file_names[b]:video_name+'_'+num
411
+ v_n,num = file_names[b][:vname_num_split_index],file_names[b][vname_num_split_index+1:]
412
+ save_img_path = os.path.join(savedir,f'{v_n}_sample_{num}.npy')
413
+ np.save(save_img_path,rcon)
414
+
415
+ return None
416
+
417
+ def configure_optimizers(self):
418
+ lr = self.learning_rate
419
+ opt_ae = torch.optim.Adam(list(self.encoder.parameters())+
420
+ list(self.decoder.parameters())+
421
+ list(self.quant_conv.parameters())+
422
+ list(self.post_quant_conv.parameters()),
423
+ lr=lr, betas=(0.5, 0.9))
424
+ opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(),
425
+ lr=lr, betas=(0.5, 0.9))
426
+ return [opt_ae, opt_disc], []
427
+
428
+ def get_last_layer(self):
429
+ return self.decoder.conv_out.weight
430
+
431
+ @torch.no_grad()
432
+ def log_images(self, batch, only_inputs=False, save_dir='mel_result_ae13_26_debug/fake_class', **kwargs):  # called from on_validation_batch_end in main.py
433
+ log = dict()
434
+ x = self.get_input(batch, self.image_key)
435
+ x = x.to(self.device)
436
+ if not only_inputs:
437
+ xrec, posterior = self(x)
438
+ if x.shape[1] > 3:
439
+ # colorize with random projection
440
+ assert xrec.shape[1] > 3
441
+ x = self.to_rgb(x)
442
+ xrec = self.to_rgb(xrec)
443
+ log["samples"] = self.decode(torch.randn_like(posterior.sample()))
444
+ log["reconstructions"] = xrec
445
+ log["inputs"] = x
446
+ return log
447
+
448
+ def to_rgb(self, x):
449
+ assert self.image_key == "segmentation"
450
+ if not hasattr(self, "colorize"):
451
+ self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x))
452
+ x = F.conv2d(x, weight=self.colorize)
453
+ x = 2.*(x-x.min())/(x.max()-x.min()) - 1.
454
+ return x
455
+
456
+
457
+ class IdentityFirstStage(torch.nn.Module):
458
+ def __init__(self, *args, vq_interface=False, **kwargs):
459
+ self.vq_interface = vq_interface # TODO: Should be true by default but check to not break older stuff
460
+ super().__init__()
461
+
462
+ def encode(self, x, *args, **kwargs):
463
+ return x
464
+
465
+ def decode(self, x, *args, **kwargs):
466
+ return x
467
+
468
+ def quantize(self, x, *args, **kwargs):
469
+ if self.vq_interface:
470
+ return x, None, [None, None, None]
471
+ return x
472
+
473
+ def forward(self, x, *args, **kwargs):
474
+ return x
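To see `AutoencoderKL` in isolation, the sketch below builds it from the `first_stage_config` in txt2audio_args.yaml (randomly initialized, since no checkpoint path is set there) and runs a dummy mel through an encode/decode round trip. The 80 x 624 input size and the expected (1, 4, 10, 78) latent are assumptions carried over from that config.

```python
# Sketch: encode a dummy mel into the KL latent and decode it back (random weights, shapes only).
import torch
from omegaconf import OmegaConf
from ldm.util import instantiate_from_config

cfg = OmegaConf.load("configs/text_to_audio/txt2audio_args.yaml")
ae = instantiate_from_config(cfg.model.params.first_stage_config).eval()

mel = torch.randn(1, 1, 80, 624)      # (batch, channel, mel_bins, frames)
with torch.no_grad():
    posterior = ae.encode(mel)        # DiagonalGaussianDistribution over the latent
    z = posterior.sample()            # expected shape: (1, 4, 10, 78)
    rec, _ = ae(mel)                  # forward() returns (reconstruction, posterior)
print(z.shape, rec.shape)
```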
ldm/models/autoencoder_multi.py ADDED
@@ -0,0 +1,201 @@
1
+ """
2
+ The difference from autoencoder.py is that autoencoder.py uses only a single discriminator when computing the loss, while this version
3
+ also adds a multi-window discriminator, so the parameters passed to the discriminator optimizer become:
4
+ opt_disc = torch.optim.Adam(list(self.loss.discriminator.parameters()) + list(self.loss.discriminator_multi.parameters()),
5
+ lr=lr, betas=(0.5, 0.9))
6
+ """
7
+
8
+ import os
9
+ import torch
10
+ import pytorch_lightning as pl
11
+ import torch.nn.functional as F
12
+ from contextlib import contextmanager
13
+
14
+ from packaging import version
15
+ import numpy as np
16
+ from ldm.modules.diffusionmodules.model import Encoder, Decoder
17
+ from ldm.modules.distributions.distributions import DiagonalGaussianDistribution
18
+ from torch.optim.lr_scheduler import LambdaLR
19
+ from ldm.util import instantiate_from_config
20
+
21
+
22
+
23
+ class AutoencoderKL(pl.LightningModule):
24
+ def __init__(self,
25
+ ddconfig,
26
+ lossconfig,
27
+ embed_dim,
28
+ ckpt_path=None,
29
+ ignore_keys=[],
30
+ image_key="image",
31
+ colorize_nlabels=None,
32
+ monitor=None,
33
+ ):
34
+ super().__init__()
35
+ self.image_key = image_key
36
+ self.encoder = Encoder(**ddconfig)
37
+ self.decoder = Decoder(**ddconfig)
38
+ self.loss = instantiate_from_config(lossconfig)
39
+ assert ddconfig["double_z"]
40
+ self.quant_conv = torch.nn.Conv2d(2*ddconfig["z_channels"], 2*embed_dim, 1)
41
+ self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
42
+ self.embed_dim = embed_dim
43
+ if colorize_nlabels is not None:
44
+ assert type(colorize_nlabels)==int
45
+ self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1))
46
+ if monitor is not None:
47
+ self.monitor = monitor
48
+ if ckpt_path is not None:
49
+ self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
50
+
51
+ def init_from_ckpt(self, path, ignore_keys=list()):
52
+ sd = torch.load(path, map_location="cpu")["state_dict"]
53
+ keys = list(sd.keys())
54
+ for k in keys:
55
+ for ik in ignore_keys:
56
+ if k.startswith(ik):
57
+ print("Deleting key {} from state_dict.".format(k))
58
+ del sd[k]
59
+ self.load_state_dict(sd, strict=False)
60
+ print(f"Restored from {path}")
61
+
62
+ def encode(self, x):
63
+ h = self.encoder(x)
64
+ moments = self.quant_conv(h)
65
+ posterior = DiagonalGaussianDistribution(moments)
66
+ return posterior
67
+
68
+ def decode(self, z):
69
+ z = self.post_quant_conv(z)
70
+ dec = self.decoder(z)
71
+ return dec
72
+
73
+ def forward(self, input, sample_posterior=True):
74
+ posterior = self.encode(input)
75
+ if sample_posterior:
76
+ z = posterior.sample()
77
+ else:
78
+ z = posterior.mode()
79
+ dec = self.decode(z)
80
+ return dec, posterior
81
+
82
+ def get_input(self, batch, k):
83
+ x = batch[k]
84
+ if len(x.shape) == 3:
85
+ x = x[..., None]
86
+ x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float()
87
+ return x
88
+
89
+ def training_step(self, batch, batch_idx, optimizer_idx):
90
+ inputs = self.get_input(batch, self.image_key)
91
+ reconstructions, posterior = self(inputs)
92
+
93
+ if optimizer_idx == 0:
94
+ # train encoder+decoder+logvar
95
+ aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step,
96
+ last_layer=self.get_last_layer(), split="train")
97
+ self.log("aeloss", aeloss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
98
+ self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=False)
99
+ return aeloss
100
+
101
+ if optimizer_idx == 1:
102
+ # train the discriminator
103
+ discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step,
104
+ last_layer=self.get_last_layer(), split="train")
105
+
106
+ self.log("discloss", discloss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
107
+ self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=False)
108
+ return discloss
109
+
110
+ def validation_step(self, batch, batch_idx):
111
+ inputs = self.get_input(batch, self.image_key)
112
+ reconstructions, posterior = self(inputs)
113
+ aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, 0, self.global_step,
114
+ last_layer=self.get_last_layer(), split="val")
115
+
116
+ discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, 1, self.global_step,
117
+ last_layer=self.get_last_layer(), split="val")
118
+
119
+ self.log("val/rec_loss", log_dict_ae["val/rec_loss"])
120
+ self.log_dict(log_dict_ae)
121
+ self.log_dict(log_dict_disc)
122
+ return self.log_dict
123
+
124
+ def test_step(self, batch, batch_idx):
125
+ inputs = self.get_input(batch, self.image_key)# inputs shape:(b,c,mel_len,T) or (b,c,h,w)
126
+ reconstructions, posterior = self(inputs)# reconstructions:(b,c,mel_len,T) or (b,c,h,w)
127
+ reconstructions = (reconstructions + 1)/2 # to mel scale
128
+ test_ckpt_path = os.path.basename(self.trainer.tested_ckpt_path)
129
+ savedir = os.path.join(self.trainer.log_dir,f'output_imgs_{test_ckpt_path}','fake_class')
130
+ if not os.path.exists(savedir):
131
+ os.makedirs(savedir)
132
+
133
+ file_names = batch['f_name']
134
+ # print(f"reconstructions.shape:{reconstructions.shape}",file_names)
135
+ reconstructions = reconstructions.cpu().numpy().squeeze(1)  # squeeze channel dim
136
+ for b in range(reconstructions.shape[0]):
137
+ vname_num_split_index = file_names[b].rfind('_')# file_names[b]:video_name+'_'+num
138
+ v_n,num = file_names[b][:vname_num_split_index],file_names[b][vname_num_split_index+1:]
139
+ save_img_path = os.path.join(savedir,f'{v_n}_sample_{num}.npy')
140
+ np.save(save_img_path,reconstructions[b])
141
+
142
+ return None
143
+
144
+ def configure_optimizers(self):
145
+ lr = self.learning_rate
146
+ opt_ae = torch.optim.Adam(list(self.encoder.parameters())+
147
+ list(self.decoder.parameters())+
148
+ list(self.quant_conv.parameters())+
149
+ list(self.post_quant_conv.parameters()),
150
+ lr=lr, betas=(0.5, 0.9))
151
+ opt_disc = torch.optim.Adam(list(self.loss.discriminator.parameters()) + list(self.loss.discriminator_multi.parameters()),
152
+ lr=lr, betas=(0.5, 0.9))
153
+ return [opt_ae, opt_disc], []
154
+
155
+ def get_last_layer(self):
156
+ return self.decoder.conv_out.weight
157
+
158
+ @torch.no_grad()
159
+ def log_images(self, batch, only_inputs=False, **kwargs):
160
+ log = dict()
161
+ x = self.get_input(batch, self.image_key)
162
+ x = x.to(self.device)
163
+ if not only_inputs:
164
+ xrec, posterior = self(x)
165
+ if x.shape[1] > 3:
166
+ # colorize with random projection
167
+ assert xrec.shape[1] > 3
168
+ x = self.to_rgb(x)
169
+ xrec = self.to_rgb(xrec)
170
+ log["samples"] = self.decode(torch.randn_like(posterior.sample()))
171
+ log["reconstructions"] = xrec
172
+ log["inputs"] = x
173
+ return log
174
+
175
+ def to_rgb(self, x):
176
+ assert self.image_key == "segmentation"
177
+ if not hasattr(self, "colorize"):
178
+ self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x))
179
+ x = F.conv2d(x, weight=self.colorize)
180
+ x = 2.*(x-x.min())/(x.max()-x.min()) - 1.
181
+ return x
182
+
183
+
184
+ class IdentityFirstStage(torch.nn.Module):
185
+ def __init__(self, *args, vq_interface=False, **kwargs):
186
+ self.vq_interface = vq_interface # TODO: Should be true by default but check to not break older stuff
187
+ super().__init__()
188
+
189
+ def encode(self, x, *args, **kwargs):
190
+ return x
191
+
192
+ def decode(self, x, *args, **kwargs):
193
+ return x
194
+
195
+ def quantize(self, x, *args, **kwargs):
196
+ if self.vq_interface:
197
+ return x, None, [None, None, None]
198
+ return x
199
+
200
+ def forward(self, x, *args, **kwargs):
201
+ return x
ldm/models/diffusion/__init__.py ADDED
File without changes
ldm/models/diffusion/classifier.py ADDED
@@ -0,0 +1,267 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ import pytorch_lightning as pl
4
+ from omegaconf import OmegaConf
5
+ from torch.nn import functional as F
6
+ from torch.optim import AdamW
7
+ from torch.optim.lr_scheduler import LambdaLR
8
+ from copy import deepcopy
9
+ from einops import rearrange
10
+ from glob import glob
11
+ from natsort import natsorted
12
+
13
+ from ldm.modules.diffusionmodules.openaimodel import EncoderUNetModel, UNetModel
14
+ from ldm.util import log_txt_as_img, default, ismap, instantiate_from_config
15
+
16
+ __models__ = {
17
+ 'class_label': EncoderUNetModel,
18
+ 'segmentation': UNetModel
19
+ }
20
+
21
+
22
+ def disabled_train(self, mode=True):
23
+ """Overwrite model.train with this function to make sure train/eval mode
24
+ does not change anymore."""
25
+ return self
26
+
27
+
28
+ class NoisyLatentImageClassifier(pl.LightningModule):
29
+
30
+ def __init__(self,
31
+ diffusion_path,
32
+ num_classes,
33
+ ckpt_path=None,
34
+ pool='attention',
35
+ label_key=None,
36
+ diffusion_ckpt_path=None,
37
+ scheduler_config=None,
38
+ weight_decay=1.e-2,
39
+ log_steps=10,
40
+ monitor='val/loss',
41
+ *args,
42
+ **kwargs):
43
+ super().__init__(*args, **kwargs)
44
+ self.num_classes = num_classes
45
+ # get latest config of diffusion model
46
+ diffusion_config = natsorted(glob(os.path.join(diffusion_path, 'configs', '*-project.yaml')))[-1]
47
+ self.diffusion_config = OmegaConf.load(diffusion_config).model
48
+ self.diffusion_config.params.ckpt_path = diffusion_ckpt_path
49
+ self.load_diffusion()
50
+
51
+ self.monitor = monitor
52
+ self.numd = self.diffusion_model.first_stage_model.encoder.num_resolutions - 1
53
+ self.log_time_interval = self.diffusion_model.num_timesteps // log_steps
54
+ self.log_steps = log_steps
55
+
56
+ self.label_key = label_key if not hasattr(self.diffusion_model, 'cond_stage_key') \
57
+ else self.diffusion_model.cond_stage_key
58
+
59
+ assert self.label_key is not None, 'label_key neither in diffusion model nor in model.params'
60
+
61
+ if self.label_key not in __models__:
62
+ raise NotImplementedError()
63
+
64
+ self.load_classifier(ckpt_path, pool)
65
+
66
+ self.scheduler_config = scheduler_config
67
+ self.use_scheduler = self.scheduler_config is not None
68
+ self.weight_decay = weight_decay
69
+
70
+ def init_from_ckpt(self, path, ignore_keys=list(), only_model=False):
71
+ sd = torch.load(path, map_location="cpu")
72
+ if "state_dict" in list(sd.keys()):
73
+ sd = sd["state_dict"]
74
+ keys = list(sd.keys())
75
+ for k in keys:
76
+ for ik in ignore_keys:
77
+ if k.startswith(ik):
78
+ print("Deleting key {} from state_dict.".format(k))
79
+ del sd[k]
80
+ missing, unexpected = self.load_state_dict(sd, strict=False) if not only_model else self.model.load_state_dict(
81
+ sd, strict=False)
82
+ print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys")
83
+ if len(missing) > 0:
84
+ print(f"Missing Keys: {missing}")
85
+ if len(unexpected) > 0:
86
+ print(f"Unexpected Keys: {unexpected}")
87
+
88
+ def load_diffusion(self):
89
+ model = instantiate_from_config(self.diffusion_config)
90
+ self.diffusion_model = model.eval()
91
+ self.diffusion_model.train = disabled_train
92
+ for param in self.diffusion_model.parameters():
93
+ param.requires_grad = False
94
+
95
+ def load_classifier(self, ckpt_path, pool):
96
+ model_config = deepcopy(self.diffusion_config.params.unet_config.params)
97
+ model_config.in_channels = self.diffusion_config.params.unet_config.params.out_channels
98
+ model_config.out_channels = self.num_classes
99
+ if self.label_key == 'class_label':
100
+ model_config.pool = pool
101
+
102
+ self.model = __models__[self.label_key](**model_config)
103
+ if ckpt_path is not None:
104
+ print('#####################################################################')
105
+ print(f'load from ckpt "{ckpt_path}"')
106
+ print('#####################################################################')
107
+ self.init_from_ckpt(ckpt_path)
108
+
109
+ @torch.no_grad()
110
+ def get_x_noisy(self, x, t, noise=None):
111
+ noise = default(noise, lambda: torch.randn_like(x))
112
+ continuous_sqrt_alpha_cumprod = None
113
+ if self.diffusion_model.use_continuous_noise:
114
+ continuous_sqrt_alpha_cumprod = self.diffusion_model.sample_continuous_noise_level(x.shape[0], t + 1)
115
+ # todo: make sure t+1 is correct here
116
+
117
+ return self.diffusion_model.q_sample(x_start=x, t=t, noise=noise,
118
+ continuous_sqrt_alpha_cumprod=continuous_sqrt_alpha_cumprod)
119
+
120
+ def forward(self, x_noisy, t, *args, **kwargs):
121
+ return self.model(x_noisy, t)
122
+
123
+ @torch.no_grad()
124
+ def get_input(self, batch, k):
125
+ x = batch[k]
126
+ if len(x.shape) == 3:
127
+ x = x[..., None]
128
+ x = rearrange(x, 'b h w c -> b c h w')
129
+ x = x.to(memory_format=torch.contiguous_format).float()
130
+ return x
131
+
132
+ @torch.no_grad()
133
+ def get_conditioning(self, batch, k=None):
134
+ if k is None:
135
+ k = self.label_key
136
+ assert k is not None, 'Needs to provide label key'
137
+
138
+ targets = batch[k].to(self.device)
139
+
140
+ if self.label_key == 'segmentation':
141
+ targets = rearrange(targets, 'b h w c -> b c h w')
142
+ for down in range(self.numd):
143
+ h, w = targets.shape[-2:]
144
+ targets = F.interpolate(targets, size=(h // 2, w // 2), mode='nearest')
145
+
146
+ # targets = rearrange(targets,'b c h w -> b h w c')
147
+
148
+ return targets
149
+
150
+ def compute_top_k(self, logits, labels, k, reduction="mean"):
151
+ _, top_ks = torch.topk(logits, k, dim=1)
152
+ if reduction == "mean":
153
+ return (top_ks == labels[:, None]).float().sum(dim=-1).mean().item()
154
+ elif reduction == "none":
155
+ return (top_ks == labels[:, None]).float().sum(dim=-1)
156
+
157
+ def on_train_epoch_start(self):
158
+ # save some memory
159
+ self.diffusion_model.model.to('cpu')
160
+
161
+ @torch.no_grad()
162
+ def write_logs(self, loss, logits, targets):
163
+ log_prefix = 'train' if self.training else 'val'
164
+ log = {}
165
+ log[f"{log_prefix}/loss"] = loss.mean()
166
+ log[f"{log_prefix}/acc@1"] = self.compute_top_k(
167
+ logits, targets, k=1, reduction="mean"
168
+ )
169
+ log[f"{log_prefix}/acc@5"] = self.compute_top_k(
170
+ logits, targets, k=5, reduction="mean"
171
+ )
172
+
173
+ self.log_dict(log, prog_bar=False, logger=True, on_step=self.training, on_epoch=True)
174
+ self.log('loss', log[f"{log_prefix}/loss"], prog_bar=True, logger=False)
175
+ self.log('global_step', self.global_step, logger=False, on_epoch=False, prog_bar=True)
176
+ lr = self.optimizers().param_groups[0]['lr']
177
+ self.log('lr_abs', lr, on_step=True, logger=True, on_epoch=False, prog_bar=True)
178
+
179
+ def shared_step(self, batch, t=None):
180
+ x, *_ = self.diffusion_model.get_input(batch, k=self.diffusion_model.first_stage_key)
181
+ targets = self.get_conditioning(batch)
182
+ if targets.dim() == 4:
183
+ targets = targets.argmax(dim=1)
184
+ if t is None:
185
+ t = torch.randint(0, self.diffusion_model.num_timesteps, (x.shape[0],), device=self.device).long()
186
+ else:
187
+ t = torch.full(size=(x.shape[0],), fill_value=t, device=self.device).long()
188
+ x_noisy = self.get_x_noisy(x, t)
189
+ logits = self(x_noisy, t)
190
+
191
+ loss = F.cross_entropy(logits, targets, reduction='none')
192
+
193
+ self.write_logs(loss.detach(), logits.detach(), targets.detach())
194
+
195
+ loss = loss.mean()
196
+ return loss, logits, x_noisy, targets
197
+
198
+ def training_step(self, batch, batch_idx):
199
+ loss, *_ = self.shared_step(batch)
200
+ return loss
201
+
202
+ def reset_noise_accs(self):
203
+ self.noisy_acc = {t: {'acc@1': [], 'acc@5': []} for t in
204
+ range(0, self.diffusion_model.num_timesteps, self.diffusion_model.log_every_t)}
205
+
206
+ def on_validation_start(self):
207
+ self.reset_noise_accs()
208
+
209
+ @torch.no_grad()
210
+ def validation_step(self, batch, batch_idx):
211
+ loss, *_ = self.shared_step(batch)
212
+
213
+ for t in self.noisy_acc:
214
+ _, logits, _, targets = self.shared_step(batch, t)
215
+ self.noisy_acc[t]['acc@1'].append(self.compute_top_k(logits, targets, k=1, reduction='mean'))
216
+ self.noisy_acc[t]['acc@5'].append(self.compute_top_k(logits, targets, k=5, reduction='mean'))
217
+
218
+ return loss
219
+
220
+ def configure_optimizers(self):
221
+ optimizer = AdamW(self.model.parameters(), lr=self.learning_rate, weight_decay=self.weight_decay)
222
+
223
+ if self.use_scheduler:
224
+ scheduler = instantiate_from_config(self.scheduler_config)
225
+
226
+ print("Setting up LambdaLR scheduler...")
227
+ scheduler = [
228
+ {
229
+ 'scheduler': LambdaLR(optimizer, lr_lambda=scheduler.schedule),
230
+ 'interval': 'step',
231
+ 'frequency': 1
232
+ }]
233
+ return [optimizer], scheduler
234
+
235
+ return optimizer
236
+
237
+ @torch.no_grad()
238
+ def log_images(self, batch, N=8, *args, **kwargs):
239
+ log = dict()
240
+ x = self.get_input(batch, self.diffusion_model.first_stage_key)
241
+ log['inputs'] = x
242
+
243
+ y = self.get_conditioning(batch)
244
+
245
+ if self.label_key == 'class_label':
246
+ y = log_txt_as_img((x.shape[2], x.shape[3]), batch["human_label"])
247
+ log['labels'] = y
248
+
249
+ if ismap(y):
250
+ log['labels'] = self.diffusion_model.to_rgb(y)
251
+
252
+ for step in range(self.log_steps):
253
+ current_time = step * self.log_time_interval
254
+
255
+ _, logits, x_noisy, _ = self.shared_step(batch, t=current_time)
256
+
257
+ log[f'inputs@t{current_time}'] = x_noisy
258
+
259
+ pred = F.one_hot(logits.argmax(dim=1), num_classes=self.num_classes)
260
+ pred = rearrange(pred, 'b h w c -> b c h w')
261
+
262
+ log[f'pred@t{current_time}'] = self.diffusion_model.to_rgb(pred)
263
+
264
+ for key in log:
265
+ log[key] = log[key][:N]
266
+
267
+ return log
ldm/models/diffusion/ddim.py ADDED
@@ -0,0 +1,262 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """SAMPLING ONLY."""
2
+
3
+ import torch
4
+ import numpy as np
5
+ from tqdm import tqdm
6
+ from functools import partial
7
+
8
+ from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like, \
9
+ extract_into_tensor
10
+
11
+
12
+ class DDIMSampler(object):
13
+ def __init__(self, model, schedule="linear", **kwargs):
14
+ super().__init__()
15
+ self.model = model
16
+ self.device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
17
+ self.ddpm_num_timesteps = model.num_timesteps
18
+ self.schedule = schedule
19
+
20
+ def register_buffer(self, name, attr):
21
+ if type(attr) == torch.Tensor:
22
+ # if attr.device != torch.device("cuda"):
23
+ # attr = attr.to(torch.device("cuda"))
24
+ attr = attr.to(self.device)
25
+ setattr(self, name, attr)
26
+
27
+ def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True):
28
+ self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps,
29
+ num_ddpm_timesteps=self.ddpm_num_timesteps,verbose=verbose)
30
+ alphas_cumprod = self.model.alphas_cumprod
31
+ assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep'
32
+ to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device)
33
+
34
+ self.register_buffer('betas', to_torch(self.model.betas))
35
+ self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
36
+ self.register_buffer('alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev))
37
+
38
+ # calculations for diffusion q(x_t | x_{t-1}) and others
39
+ self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu())))
40
+ self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod.cpu())))
41
+ self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod.cpu())))
42
+ self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu())))
43
+ self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1)))
44
+
45
+ # ddim sampling parameters
46
+ ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(),
47
+ ddim_timesteps=self.ddim_timesteps,
48
+ eta=ddim_eta,verbose=verbose)
49
+ self.register_buffer('ddim_sigmas', ddim_sigmas)
50
+ self.register_buffer('ddim_alphas', ddim_alphas)
51
+ self.register_buffer('ddim_alphas_prev', ddim_alphas_prev)
52
+ self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas))
53
+ sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
54
+ (1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * (
55
+ 1 - self.alphas_cumprod / self.alphas_cumprod_prev))
56
+ self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps)
57
+
58
+ @torch.no_grad()
59
+ def sample(self,
60
+ S,
61
+ batch_size,
62
+ shape,
63
+ conditioning=None,
64
+ callback=None,
65
+ normals_sequence=None,
66
+ img_callback=None,
67
+ quantize_x0=False,
68
+ eta=0.,
69
+ mask=None,
70
+ x0=None,
71
+ temperature=1.,
72
+ noise_dropout=0.,
73
+ score_corrector=None,
74
+ corrector_kwargs=None,
75
+ verbose=True,
76
+ x_T=None,
77
+ log_every_t=100,
78
+ unconditional_guidance_scale=1.,
79
+ unconditional_conditioning=None,
80
+ # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
81
+ **kwargs
82
+ ):
83
+ if conditioning is not None:
84
+ if isinstance(conditioning, dict):
85
+ ctmp = conditioning[list(conditioning.keys())[0]]
86
+ while isinstance(ctmp, list): ctmp = ctmp[0]
87
+ cbs = ctmp.shape[0]
88
+ if cbs != batch_size:
89
+ print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
90
+ else:
91
+ if conditioning.shape[0] != batch_size:
92
+ print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")
93
+
94
+ self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose)
95
+ # sampling
96
+ C, H, W = shape
97
+ size = (batch_size, C, H, W)
98
+ # print(f'Data shape for DDIM sampling is {size}, eta {eta}')
99
+
100
+ samples, intermediates = self.ddim_sampling(conditioning, size,
101
+ callback=callback,
102
+ img_callback=img_callback,
103
+ quantize_denoised=quantize_x0,
104
+ mask=mask, x0=x0,
105
+ ddim_use_original_steps=False,
106
+ noise_dropout=noise_dropout,
107
+ temperature=temperature,
108
+ score_corrector=score_corrector,
109
+ corrector_kwargs=corrector_kwargs,
110
+ x_T=x_T,
111
+ log_every_t=log_every_t,
112
+ unconditional_guidance_scale=unconditional_guidance_scale,
113
+ unconditional_conditioning=unconditional_conditioning,
114
+ )
115
+ return samples, intermediates
116
+
117
+ @torch.no_grad()
118
+ def ddim_sampling(self, cond, shape,
119
+ x_T=None, ddim_use_original_steps=False,
120
+ callback=None, timesteps=None, quantize_denoised=False,
121
+ mask=None, x0=None, img_callback=None, log_every_t=100,
122
+ temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
123
+ unconditional_guidance_scale=1., unconditional_conditioning=None,):
124
+ device = self.model.betas.device
125
+ b = shape[0]
126
+ if x_T is None:
127
+ img = torch.randn(shape, device=device)
128
+ else:
129
+ img = x_T
130
+
131
+ if timesteps is None:
132
+ timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps
133
+ elif timesteps is not None and not ddim_use_original_steps:
134
+ subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1
135
+ timesteps = self.ddim_timesteps[:subset_end]
136
+
137
+ intermediates = {'x_inter': [img], 'pred_x0': [img]}
138
+ time_range = reversed(range(0,timesteps)) if ddim_use_original_steps else np.flip(timesteps)
139
+ total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]
140
+
141
+ # iterator = tqdm(time_range, desc='DDIM Sampler', total=total_steps)
142
+
143
+ for i, step in enumerate(time_range):
144
+ index = total_steps - i - 1
145
+ ts = torch.full((b,), step, device=device, dtype=torch.long)
146
+
147
+ if mask is not None:
148
+ assert x0 is not None
149
+ img_orig = self.model.q_sample(x0, ts) # TODO: deterministic forward pass?
150
+ img = img_orig * mask + (1. - mask) * img
151
+
152
+ outs = self.p_sample_ddim(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps,
153
+ quantize_denoised=quantize_denoised, temperature=temperature,
154
+ noise_dropout=noise_dropout, score_corrector=score_corrector,
155
+ corrector_kwargs=corrector_kwargs,
156
+ unconditional_guidance_scale=unconditional_guidance_scale,
157
+ unconditional_conditioning=unconditional_conditioning)
158
+ img, pred_x0 = outs
159
+ if callback: callback(i)
160
+ if img_callback: img_callback(pred_x0, i)
161
+
162
+ if index % log_every_t == 0 or index == total_steps - 1:
163
+ intermediates['x_inter'].append(img)
164
+ intermediates['pred_x0'].append(pred_x0)
165
+
166
+ return img, intermediates
167
+
168
+ @torch.no_grad()
169
+ def p_sample_ddim(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
170
+ temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
171
+ unconditional_guidance_scale=1., unconditional_conditioning=None):
172
+ b, *_, device = *x.shape, x.device
173
+
174
+ if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
175
+ e_t = self.model.apply_model(x, t, c)
176
+ else:
177
+ x_in = torch.cat([x] * 2)
178
+ t_in = torch.cat([t] * 2)
179
+ if isinstance(c, dict):
180
+ assert isinstance(unconditional_conditioning, dict)
181
+ c_in = dict()
182
+ for k in c:
183
+ if isinstance(c[k], list):
184
+ c_in[k] = [torch.cat([
185
+ unconditional_conditioning[k][i],
186
+ c[k][i]]) for i in range(len(c[k]))]
187
+ else:
188
+ c_in[k] = torch.cat([
189
+ unconditional_conditioning[k],
190
+ c[k]])
191
+ elif isinstance(c, list):
192
+ c_in = list()
193
+ assert isinstance(unconditional_conditioning, list)
194
+ for i in range(len(c)):
195
+ c_in.append(torch.cat([unconditional_conditioning[i], c[i]]))
196
+ else:
197
+ c_in = torch.cat([unconditional_conditioning, c])# c/uc shape [b,seq_len=77,dim=1024],c_in shape [b*2,seq_len,dim]
198
+ e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2)
199
+ e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)
200
+
201
+ if score_corrector is not None:
202
+ assert self.model.parameterization == "eps"
203
+ e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs)
204
+
205
+ alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
206
+ alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev
207
+ sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas
208
+ sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas
209
+ # select parameters corresponding to the currently considered timestep
210
+ a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
211
+ a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
212
+ sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
213
+ sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device)
214
+
215
+ # current prediction for x_0
216
+ pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
217
+ if quantize_denoised:
218
+ pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
219
+ # direction pointing to x_t
220
+ dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t
221
+ noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
222
+ if noise_dropout > 0.:
223
+ noise = torch.nn.functional.dropout(noise, p=noise_dropout)
224
+ x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
225
+ return x_prev, pred_x0
226
+
227
+ @torch.no_grad()
228
+ def stochastic_encode(self, x0, t, use_original_steps=False, noise=None):
229
+ # fast, but does not allow for exact reconstruction
230
+ # t serves as an index to gather the correct alphas
231
+ if use_original_steps:
232
+ sqrt_alphas_cumprod = self.sqrt_alphas_cumprod
233
+ sqrt_one_minus_alphas_cumprod = self.sqrt_one_minus_alphas_cumprod
234
+ else:
235
+ sqrt_alphas_cumprod = torch.sqrt(self.ddim_alphas)
236
+ sqrt_one_minus_alphas_cumprod = self.ddim_sqrt_one_minus_alphas
237
+
238
+ if noise is None:
239
+ noise = torch.randn_like(x0)
240
+ return (extract_into_tensor(sqrt_alphas_cumprod, t, x0.shape) * x0 +
241
+ extract_into_tensor(sqrt_one_minus_alphas_cumprod, t, x0.shape) * noise)
242
+
243
+ @torch.no_grad()
244
+ def decode(self, x_latent, cond, t_start, unconditional_guidance_scale=1.0, unconditional_conditioning=None,
245
+ use_original_steps=False):
246
+
247
+ timesteps = np.arange(self.ddpm_num_timesteps) if use_original_steps else self.ddim_timesteps
248
+ timesteps = timesteps[:t_start]
249
+
250
+ time_range = np.flip(timesteps)
251
+ total_steps = timesteps.shape[0]
252
+ # print(f"Running DDIM Sampling with {total_steps} timesteps")
253
+
254
+ # iterator = tqdm(time_range, desc='Decoding image', total=total_steps)
255
+ x_dec = x_latent
256
+ for i, step in enumerate(time_range):
257
+ index = total_steps - i - 1
258
+ ts = torch.full((x_latent.shape[0],), step, device=x_latent.device, dtype=torch.long)
259
+ x_dec, _ = self.p_sample_ddim(x_dec, cond, ts, index=index, use_original_steps=use_original_steps,
260
+ unconditional_guidance_scale=unconditional_guidance_scale,
261
+ unconditional_conditioning=unconditional_conditioning)
262
+ return x_dec
ldm/models/diffusion/ddpm.py ADDED
@@ -0,0 +1,1444 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ wild mixture of
3
+ https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py
4
+ https://github.com/openai/improved-diffusion/blob/e94489283bb876ac1477d5dd7709bbbd2d9902ce/improved_diffusion/gaussian_diffusion.py
5
+ https://github.com/CompVis/taming-transformers
6
+ -- merci
7
+ """
8
+ import torch
9
+ import torch.nn as nn
10
+ import numpy as np
11
+ import pytorch_lightning as pl
12
+ from torch.optim.lr_scheduler import LambdaLR
13
+ from einops import rearrange, repeat
14
+ from contextlib import contextmanager
15
+ from functools import partial
16
+ from tqdm import tqdm
17
+ from torchvision.utils import make_grid
18
+ from pytorch_lightning.utilities.distributed import rank_zero_only
19
+
20
+ from ldm.util import log_txt_as_img, exists, default, ismap, isimage, mean_flat, count_params, instantiate_from_config
21
+ from ldm.modules.ema import LitEma
22
+ from ldm.modules.distributions.distributions import normal_kl, DiagonalGaussianDistribution
23
+ from ldm.models.autoencoder import IdentityFirstStage, AutoencoderKL
24
+ from ldm.modules.diffusionmodules.util import make_beta_schedule, extract_into_tensor, noise_like
25
+ from ldm.models.diffusion.ddim import DDIMSampler
26
+
27
+
28
+ __conditioning_keys__ = {'concat': 'c_concat',
29
+ 'crossattn': 'c_crossattn',
30
+ 'adm': 'y'}
31
+
32
+
33
+ def disabled_train(self, mode=True):
34
+ """Overwrite model.train with this function to make sure train/eval mode
35
+ does not change anymore."""
36
+ return self
37
+
38
+
39
+ def uniform_on_device(r1, r2, shape, device):
40
+ return (r1 - r2) * torch.rand(*shape, device=device) + r2
41
+
42
+
43
+ class DDPM(pl.LightningModule):
44
+ # classic DDPM with Gaussian diffusion, in image space
45
+ def __init__(self,
46
+ unet_config,
47
+ timesteps=1000,
48
+ beta_schedule="linear",
49
+ loss_type="l2",
50
+ ckpt_path=None,
51
+ ignore_keys=[],
52
+ load_only_unet=False,
53
+ monitor="val/loss",
54
+ use_ema=True,
55
+ first_stage_key="image",
56
+ image_size=256,
57
+ channels=3,
58
+ log_every_t=100,
59
+ clip_denoised=True,
60
+ linear_start=1e-4,
61
+ linear_end=2e-2,
62
+ cosine_s=8e-3,
63
+ given_betas=None,
64
+ original_elbo_weight=0.,
65
+ v_posterior=0., # weight for choosing posterior variance as sigma = (1-v) * beta_tilde + v * beta
66
+ l_simple_weight=1.,
67
+ conditioning_key=None,
68
+ parameterization="eps", # all config files uses "eps"
69
+ scheduler_config=None,
70
+ use_positional_encodings=False,
71
+ learn_logvar=False,
72
+ logvar_init=0.,
73
+ ):
74
+ super().__init__()
75
+ assert parameterization in ["eps", "x0"], 'currently only supporting "eps" and "x0"'
76
+ self.parameterization = parameterization
77
+ print(f"{self.__class__.__name__}: Running in {self.parameterization}-prediction mode")
78
+ self.cond_stage_model = None
79
+ self.clip_denoised = clip_denoised
80
+ self.log_every_t = log_every_t
81
+ self.first_stage_key = first_stage_key
82
+ self.image_size = image_size # try conv?
83
+ self.channels = channels
84
+ self.use_positional_encodings = use_positional_encodings
85
+ self.model = DiffusionWrapper(unet_config, conditioning_key)
86
+ count_params(self.model, verbose=True)
87
+ self.use_ema = use_ema
88
+ if self.use_ema:
89
+ self.model_ema = LitEma(self.model)
90
+ print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
91
+
92
+ self.use_scheduler = scheduler_config is not None
93
+ if self.use_scheduler:
94
+ self.scheduler_config = scheduler_config
95
+
96
+ self.v_posterior = v_posterior
97
+ self.original_elbo_weight = original_elbo_weight
98
+ self.l_simple_weight = l_simple_weight
99
+
100
+ if monitor is not None:
101
+ self.monitor = monitor
102
+ if ckpt_path is not None:
103
+ self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys, only_model=load_only_unet)
104
+
105
+ self.register_schedule(given_betas=given_betas, beta_schedule=beta_schedule, timesteps=timesteps,
106
+ linear_start=linear_start, linear_end=linear_end, cosine_s=cosine_s)
107
+
108
+ self.loss_type = loss_type
109
+
110
+ self.learn_logvar = learn_logvar
111
+ self.logvar = torch.full(fill_value=logvar_init, size=(self.num_timesteps,))
112
+ if self.learn_logvar:
113
+ self.logvar = nn.Parameter(self.logvar, requires_grad=True)
114
+
115
+ def register_schedule(self, given_betas=None, beta_schedule="linear", timesteps=1000,
116
+ linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
117
+ if exists(given_betas):
118
+ betas = given_betas
119
+ else:
120
+ betas = make_beta_schedule(beta_schedule, timesteps, linear_start=linear_start, linear_end=linear_end,
121
+ cosine_s=cosine_s)
122
+ alphas = 1. - betas
123
+ alphas_cumprod = np.cumprod(alphas, axis=0)
124
+ alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1])
125
+
126
+ timesteps, = betas.shape
127
+ self.num_timesteps = int(timesteps)
128
+ self.linear_start = linear_start
129
+ self.linear_end = linear_end
130
+ assert alphas_cumprod.shape[0] == self.num_timesteps, 'alphas have to be defined for each timestep'
131
+
132
+ to_torch = partial(torch.tensor, dtype=torch.float32)
133
+
134
+ self.register_buffer('betas', to_torch(betas))
135
+ self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
136
+ self.register_buffer('alphas_cumprod_prev', to_torch(alphas_cumprod_prev))
137
+
138
+ # calculations for diffusion q(x_t | x_{t-1}) and others
139
+ self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod)))
140
+ self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod)))
141
+ self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod)))
142
+ self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod)))
143
+ self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod - 1)))
144
+
145
+ # calculations for posterior q(x_{t-1} | x_t, x_0)
146
+ posterior_variance = (1 - self.v_posterior) * betas * (1. - alphas_cumprod_prev) / (
147
+ 1. - alphas_cumprod) + self.v_posterior * betas
148
+ # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)
149
+ self.register_buffer('posterior_variance', to_torch(posterior_variance))
150
+ # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
151
+ self.register_buffer('posterior_log_variance_clipped', to_torch(np.log(np.maximum(posterior_variance, 1e-20))))
152
+ self.register_buffer('posterior_mean_coef1', to_torch(
153
+ betas * np.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod)))
154
+ self.register_buffer('posterior_mean_coef2', to_torch(
155
+ (1. - alphas_cumprod_prev) * np.sqrt(alphas) / (1. - alphas_cumprod)))
156
+
157
+ if self.parameterization == "eps":
158
+ lvlb_weights = self.betas ** 2 / (
159
+ 2 * self.posterior_variance * to_torch(alphas) * (1 - self.alphas_cumprod))
160
+ elif self.parameterization == "x0":
161
+ lvlb_weights = 0.5 * np.sqrt(torch.Tensor(alphas_cumprod)) / (2. * 1 - torch.Tensor(alphas_cumprod))
162
+ else:
163
+ raise NotImplementedError("mu not supported")
164
+ # TODO how to choose this term
165
+ lvlb_weights[0] = lvlb_weights[1]
166
+ self.register_buffer('lvlb_weights', lvlb_weights, persistent=False)
167
+ assert not torch.isnan(self.lvlb_weights).all()
168
+
169
+ @contextmanager
170
+ def ema_scope(self, context=None):
171
+ if self.use_ema:
172
+ self.model_ema.store(self.model.parameters())
173
+ self.model_ema.copy_to(self.model)
174
+ if context is not None:
175
+ print(f"{context}: Switched to EMA weights")
176
+ try:
177
+ yield None
178
+ finally:
179
+ if self.use_ema:
180
+ self.model_ema.restore(self.model.parameters())
181
+ if context is not None:
182
+ print(f"{context}: Restored training weights")
183
+
184
+ def init_from_ckpt(self, path, ignore_keys=list(), only_model=False):
185
+ sd = torch.load(path, map_location="cpu")
186
+ if "state_dict" in list(sd.keys()):
187
+ sd = sd["state_dict"]
188
+ keys = list(sd.keys())
189
+ for k in keys:
190
+ for ik in ignore_keys:
191
+ if k.startswith(ik):
192
+ print("Deleting key {} from state_dict.".format(k))
193
+ del sd[k]
194
+ missing, unexpected = self.load_state_dict(sd, strict=False) if not only_model else self.model.load_state_dict(
195
+ sd, strict=False)
196
+ print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys")
197
+ if len(missing) > 0:
198
+ print(f"Missing Keys: {missing}")
199
+ if len(unexpected) > 0:
200
+ print(f"Unexpected Keys: {unexpected}")
201
+
202
+ def q_mean_variance(self, x_start, t):
203
+ """
204
+ Get the distribution q(x_t | x_0).
205
+ :param x_start: the [N x C x ...] tensor of noiseless inputs.
206
+ :param t: the number of diffusion steps (minus 1). Here, 0 means one step.
207
+ :return: A tuple (mean, variance, log_variance), all of x_start's shape.
208
+ """
209
+ mean = (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start)
210
+ variance = extract_into_tensor(1.0 - self.alphas_cumprod, t, x_start.shape)
211
+ log_variance = extract_into_tensor(self.log_one_minus_alphas_cumprod, t, x_start.shape)
212
+ return mean, variance, log_variance
213
+
214
+ def predict_start_from_noise(self, x_t, t, noise):
215
+ return (
216
+ extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t -
217
+ extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise
218
+ )
219
+
220
+ def q_posterior(self, x_start, x_t, t):
221
+ posterior_mean = (
222
+ extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start +
223
+ extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t
224
+ )
225
+ posterior_variance = extract_into_tensor(self.posterior_variance, t, x_t.shape)
226
+ posterior_log_variance_clipped = extract_into_tensor(self.posterior_log_variance_clipped, t, x_t.shape)
227
+ return posterior_mean, posterior_variance, posterior_log_variance_clipped
228
+
229
+ def p_mean_variance(self, x, t, clip_denoised: bool):
230
+ model_out = self.model(x, t)
231
+ if self.parameterization == "eps":
232
+ x_recon = self.predict_start_from_noise(x, t=t, noise=model_out)
233
+ elif self.parameterization == "x0":
234
+ x_recon = model_out
235
+ if clip_denoised:
236
+ x_recon.clamp_(-1., 1.)
237
+
238
+ model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t)
239
+ return model_mean, posterior_variance, posterior_log_variance
240
+
241
+ @torch.no_grad()
242
+ def p_sample(self, x, t, clip_denoised=True, repeat_noise=False):
243
+ b, *_, device = *x.shape, x.device
244
+ model_mean, _, model_log_variance = self.p_mean_variance(x=x, t=t, clip_denoised=clip_denoised)
245
+ noise = noise_like(x.shape, device, repeat_noise)
246
+ # no noise when t == 0
247
+ nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
248
+ return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
249
+
250
+ @torch.no_grad()
251
+ def p_sample_loop(self, shape, return_intermediates=False):
252
+ device = self.betas.device
253
+ b = shape[0]
254
+ img = torch.randn(shape, device=device)
255
+ intermediates = [img]
256
+ for i in tqdm(reversed(range(0, self.num_timesteps)), desc='Sampling t', total=self.num_timesteps):
257
+ img = self.p_sample(img, torch.full((b,), i, device=device, dtype=torch.long),
258
+ clip_denoised=self.clip_denoised)
259
+ if i % self.log_every_t == 0 or i == self.num_timesteps - 1:
260
+ intermediates.append(img)
261
+ if return_intermediates:
262
+ return img, intermediates
263
+ return img
264
+
265
+ @torch.no_grad()
266
+ def sample(self, batch_size=16, return_intermediates=False):
267
+ image_size = self.image_size
268
+ channels = self.channels
269
+ return self.p_sample_loop((batch_size, channels, image_size, image_size),
270
+ return_intermediates=return_intermediates)
271
+
272
+ def q_sample(self, x_start, t, noise=None):
273
+ noise = default(noise, lambda: torch.randn_like(x_start))
274
+ return (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start +
275
+ extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise)
276
+
277
+ def get_loss(self, pred, target, mean=True):
278
+ if self.loss_type == 'l1':
279
+ loss = (target - pred).abs()
280
+ if mean:
281
+ loss = loss.mean()
282
+ elif self.loss_type == 'l2':
283
+ if mean:
284
+ loss = torch.nn.functional.mse_loss(target, pred)
285
+ else:
286
+ loss = torch.nn.functional.mse_loss(target, pred, reduction='none')
287
+ else:
288
+ raise NotImplementedError("unknown loss type '{loss_type}'")
289
+
290
+ return loss
291
+
292
+ def p_losses(self, x_start, t, noise=None):
293
+ noise = default(noise, lambda: torch.randn_like(x_start))
294
+ x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise)
295
+ model_out = self.model(x_noisy, t)
296
+
297
+ loss_dict = {}
298
+ if self.parameterization == "eps":
299
+ target = noise
300
+ elif self.parameterization == "x0":
301
+ target = x_start
302
+ else:
303
+ raise NotImplementedError(f"Paramterization {self.parameterization} not yet supported")
304
+
305
+ loss = self.get_loss(model_out, target, mean=False).mean(dim=[1, 2, 3])
306
+
307
+ log_prefix = 'train' if self.training else 'val'
308
+
309
+ loss_dict.update({f'{log_prefix}/loss_simple': loss.mean()})
310
+ loss_simple = loss.mean() * self.l_simple_weight
311
+
312
+ loss_vlb = (self.lvlb_weights[t] * loss).mean()
313
+ loss_dict.update({f'{log_prefix}/loss_vlb': loss_vlb})
314
+
315
+ loss = loss_simple + self.original_elbo_weight * loss_vlb
316
+
317
+ loss_dict.update({f'{log_prefix}/loss': loss})
318
+
319
+ return loss, loss_dict
320
+
321
+ def forward(self, x, *args, **kwargs):
322
+ # b, c, h, w, device, img_size, = *x.shape, x.device, self.image_size
323
+ # assert h == img_size and w == img_size, f'height and width of image must be {img_size}'
324
+ t = torch.randint(0, self.num_timesteps, (x.shape[0],), device=self.device).long()
325
+ return self.p_losses(x, t, *args, **kwargs)
326
+
327
+ def get_input(self, batch, k):
328
+ x = batch[k]
329
+ if len(x.shape) == 3:
330
+ x = x[..., None]
331
+ x = rearrange(x, 'b h w c -> b c h w')
332
+ x = x.to(memory_format=torch.contiguous_format).float()
333
+ return x
334
+
335
+ def shared_step(self, batch):
336
+ x = self.get_input(batch, self.first_stage_key)
337
+ loss, loss_dict = self(x)
338
+ return loss, loss_dict
339
+
340
+ def training_step(self, batch, batch_idx):
341
+ loss, loss_dict = self.shared_step(batch)
342
+
343
+ self.log_dict(loss_dict, prog_bar=True,
344
+ logger=True, on_step=True, on_epoch=True)
345
+
346
+ self.log("global_step", self.global_step,
347
+ prog_bar=True, logger=True, on_step=True, on_epoch=False)
348
+
349
+ if self.use_scheduler:
350
+ lr = self.optimizers().param_groups[0]['lr']
351
+ self.log('lr_abs', lr, prog_bar=True, logger=True, on_step=True, on_epoch=False)
352
+
353
+ return loss
354
+
355
+ @torch.no_grad()
356
+ def validation_step(self, batch, batch_idx):
357
+ _, loss_dict_no_ema = self.shared_step(batch)
358
+ with self.ema_scope():
359
+ _, loss_dict_ema = self.shared_step(batch)
360
+ loss_dict_ema = {key + '_ema': loss_dict_ema[key] for key in loss_dict_ema}
361
+ self.log_dict(loss_dict_no_ema, prog_bar=False, logger=True, on_step=False, on_epoch=True)
362
+ self.log_dict(loss_dict_ema, prog_bar=False, logger=True, on_step=False, on_epoch=True)
363
+
364
+ def on_train_batch_end(self, *args, **kwargs):
365
+ if self.use_ema:
366
+ self.model_ema(self.model)
367
+
368
+ def _get_rows_from_list(self, samples):
369
+ n_imgs_per_row = len(samples)
370
+ denoise_grid = rearrange(samples, 'n b c h w -> b n c h w')
371
+ denoise_grid = rearrange(denoise_grid, 'b n c h w -> (b n) c h w')
372
+ denoise_grid = make_grid(denoise_grid, nrow=n_imgs_per_row)
373
+ return denoise_grid
374
+
375
+ @torch.no_grad()
376
+ def log_images(self, batch, N=8, n_row=2, sample=True, return_keys=None, **kwargs):
377
+ log = dict()
378
+ x = self.get_input(batch, self.first_stage_key)
379
+ N = min(x.shape[0], N)
380
+ n_row = min(x.shape[0], n_row)
381
+ x = x.to(self.device)[:N]
382
+ log["inputs"] = x
383
+
384
+ # get diffusion row
385
+ diffusion_row = list()
386
+ x_start = x[:n_row]
387
+
388
+ for t in range(self.num_timesteps):
389
+ if t % self.log_every_t == 0 or t == self.num_timesteps - 1:
390
+ t = repeat(torch.tensor([t]), '1 -> b', b=n_row)
391
+ t = t.to(self.device).long()
392
+ noise = torch.randn_like(x_start)
393
+ x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise)
394
+ diffusion_row.append(x_noisy)
395
+
396
+ log["diffusion_row"] = self._get_rows_from_list(diffusion_row)
397
+
398
+ if sample:
399
+ # get denoise row
400
+ with self.ema_scope("Plotting"):
401
+ samples, denoise_row = self.sample(batch_size=N, return_intermediates=True)
402
+
403
+ log["samples"] = samples
404
+ log["denoise_row"] = self._get_rows_from_list(denoise_row)
405
+
406
+ if return_keys:
407
+ if np.intersect1d(list(log.keys()), return_keys).shape[0] == 0:
408
+ return log
409
+ else:
410
+ return {key: log[key] for key in return_keys}
411
+ return log
412
+
413
+ def configure_optimizers(self):
414
+ lr = self.learning_rate
415
+ params = list(self.model.parameters())
416
+ if self.learn_logvar:
417
+ params = params + [self.logvar]
418
+ opt = torch.optim.AdamW(params, lr=lr)
419
+ return opt
420
+
421
+
422
+ class LatentDiffusion(DDPM):
423
+ """main class"""
424
+ def __init__(self,
425
+ first_stage_config,
426
+ cond_stage_config,
427
+ num_timesteps_cond=None,
428
+ cond_stage_key="image",# 'caption' for txt2image, 'masked_image' for inpainting
429
+ cond_stage_trainable=False,
430
+ concat_mode=True,# true for inpainting
431
+ cond_stage_forward=None,
432
+ conditioning_key=None, # 'crossattn' for txt2image, None for inpainting
433
+ scale_factor=1.0,
434
+ scale_by_std=False,
435
+ *args, **kwargs):
436
+ self.num_timesteps_cond = default(num_timesteps_cond, 1)
437
+ self.scale_by_std = scale_by_std
438
+ assert self.num_timesteps_cond <= kwargs['timesteps']
439
+ # for backwards compatibility after implementation of DiffusionWrapper
440
+ if conditioning_key is None:
441
+ conditioning_key = 'concat' if concat_mode else 'crossattn'
442
+ if cond_stage_config == '__is_unconditional__':
443
+ conditioning_key = None
444
+ ckpt_path = kwargs.pop("ckpt_path", None)
445
+ ignore_keys = kwargs.pop("ignore_keys", [])
446
+ super().__init__(conditioning_key=conditioning_key, *args, **kwargs)
447
+ self.concat_mode = concat_mode
448
+ self.cond_stage_trainable = cond_stage_trainable
449
+ self.cond_stage_key = cond_stage_key
450
+ try:
451
+ self.num_downs = len(first_stage_config.params.ddconfig.ch_mult) - 1
452
+ except:
453
+ self.num_downs = 0
454
+ if not scale_by_std:
455
+ self.scale_factor = scale_factor
456
+ else:
457
+ self.register_buffer('scale_factor', torch.tensor(scale_factor))
458
+ self.instantiate_first_stage(first_stage_config)
459
+ self.instantiate_cond_stage(cond_stage_config)
460
+ self.cond_stage_forward = cond_stage_forward
461
+ self.clip_denoised = False
462
+ self.bbox_tokenizer = None
463
+
464
+ self.restarted_from_ckpt = False
465
+ if ckpt_path is not None:
466
+ self.init_from_ckpt(ckpt_path, ignore_keys)
467
+ self.restarted_from_ckpt = True
468
+
469
+ def make_cond_schedule(self, ):
470
+ self.cond_ids = torch.full(size=(self.num_timesteps,), fill_value=self.num_timesteps - 1, dtype=torch.long)
471
+ ids = torch.round(torch.linspace(0, self.num_timesteps - 1, self.num_timesteps_cond)).long()
472
+ self.cond_ids[:self.num_timesteps_cond] = ids
473
+
474
+ @rank_zero_only
475
+ @torch.no_grad()
476
+ def on_train_batch_start(self, batch, batch_idx, dataloader_idx):
477
+ # only for very first batch
478
+ if self.scale_by_std and self.current_epoch == 0 and self.global_step == 0 and batch_idx == 0 and not self.restarted_from_ckpt:
479
+ assert self.scale_factor == 1., 'rather not use custom rescaling and std-rescaling simultaneously'
480
+ # set rescale weight to 1./std of encodings
481
+ print("### USING STD-RESCALING ###")
482
+ x = super().get_input(batch, self.first_stage_key)
483
+ x = x.to(self.device)
484
+ encoder_posterior = self.encode_first_stage(x)
485
+ z = self.get_first_stage_encoding(encoder_posterior).detach()
486
+ del self.scale_factor
487
+ self.register_buffer('scale_factor', 1. / z.flatten().std())
488
+ print(f"setting self.scale_factor to {self.scale_factor}")
489
+ print("### USING STD-RESCALING ###")
490
+
491
+ def register_schedule(self,
492
+ given_betas=None, beta_schedule="linear", timesteps=1000,
493
+ linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
494
+ super().register_schedule(given_betas, beta_schedule, timesteps, linear_start, linear_end, cosine_s)
495
+
496
+ self.shorten_cond_schedule = self.num_timesteps_cond > 1
497
+ if self.shorten_cond_schedule:
498
+ self.make_cond_schedule()
499
+
500
+ def instantiate_first_stage(self, config):
501
+ model = instantiate_from_config(config)
502
+ self.first_stage_model = model.eval()
503
+ self.first_stage_model.train = disabled_train
504
+ for param in self.first_stage_model.parameters():
505
+ param.requires_grad = False
506
+
507
+ def instantiate_cond_stage(self, config):
508
+ if not self.cond_stage_trainable:
509
+ if config == "__is_first_stage__":# inpaint
510
+ print("Using first stage also as cond stage.")
511
+ self.cond_stage_model = self.first_stage_model
512
+ elif config == "__is_unconditional__":
513
+ print(f"Training {self.__class__.__name__} as an unconditional model.")
514
+ self.cond_stage_model = None
515
+ # self.be_unconditional = True
516
+ else:
517
+ model = instantiate_from_config(config)
518
+ self.cond_stage_model = model.eval()
519
+ self.cond_stage_model.train = disabled_train
520
+ for param in self.cond_stage_model.parameters():
521
+ param.requires_grad = False
522
+ else:
523
+ assert config != '__is_first_stage__'
524
+ assert config != '__is_unconditional__'
525
+ model = instantiate_from_config(config)
526
+ self.cond_stage_model = model
527
+
528
+ def _get_denoise_row_from_list(self, samples, desc='', force_no_decoder_quantization=False):
529
+ denoise_row = []
530
+ for zd in tqdm(samples, desc=desc):
531
+ denoise_row.append(self.decode_first_stage(zd.to(self.device),
532
+ force_not_quantize=force_no_decoder_quantization))
533
+ n_imgs_per_row = len(denoise_row)
534
+ denoise_row = torch.stack(denoise_row) # n_log_step, n_row, C, H, W
535
+ denoise_grid = rearrange(denoise_row, 'n b c h w -> b n c h w')
536
+ denoise_grid = rearrange(denoise_grid, 'b n c h w -> (b n) c h w')
537
+ denoise_grid = make_grid(denoise_grid, nrow=n_imgs_per_row)
538
+ return denoise_grid
539
+
540
+ def get_first_stage_encoding(self, encoder_posterior):
541
+ if isinstance(encoder_posterior, DiagonalGaussianDistribution):
542
+ z = encoder_posterior.sample()
543
+ elif isinstance(encoder_posterior, torch.Tensor):
544
+ z = encoder_posterior
545
+ else:
546
+ raise NotImplementedError(f"encoder_posterior of type '{type(encoder_posterior)}' not yet implemented")
547
+ return self.scale_factor * z
548
+
549
+ def get_learned_conditioning(self, c):
550
+ if self.cond_stage_forward is None:
551
+ if hasattr(self.cond_stage_model, 'encode') and callable(self.cond_stage_model.encode):
552
+ c = self.cond_stage_model.encode(c)
553
+ if isinstance(c, DiagonalGaussianDistribution):
554
+ c = c.mode()
555
+ else:
556
+ c = self.cond_stage_model(c)
557
+ else:
558
+ assert hasattr(self.cond_stage_model, self.cond_stage_forward)
559
+ c = getattr(self.cond_stage_model, self.cond_stage_forward)(c)
560
+ return c
561
+
562
+ def meshgrid(self, h, w):
563
+ y = torch.arange(0, h).view(h, 1, 1).repeat(1, w, 1)
564
+ x = torch.arange(0, w).view(1, w, 1).repeat(h, 1, 1)
565
+
566
+ arr = torch.cat([y, x], dim=-1)
567
+ return arr
568
+
569
+ def delta_border(self, h, w):
570
+ """
571
+ :param h: height
572
+ :param w: width
573
+ :return: normalized distance to image border,
574
+ wtith min distance = 0 at border and max dist = 0.5 at image center
575
+ """
576
+ lower_right_corner = torch.tensor([h - 1, w - 1]).view(1, 1, 2)
577
+ arr = self.meshgrid(h, w) / lower_right_corner
578
+ dist_left_up = torch.min(arr, dim=-1, keepdims=True)[0]
579
+ dist_right_down = torch.min(1 - arr, dim=-1, keepdims=True)[0]
580
+ edge_dist = torch.min(torch.cat([dist_left_up, dist_right_down], dim=-1), dim=-1)[0]
581
+ return edge_dist
582
+
583
+ def get_weighting(self, h, w, Ly, Lx, device):
584
+ weighting = self.delta_border(h, w)
585
+ weighting = torch.clip(weighting, self.split_input_params["clip_min_weight"],
586
+ self.split_input_params["clip_max_weight"], )
587
+ weighting = weighting.view(1, h * w, 1).repeat(1, 1, Ly * Lx).to(device)
588
+
589
+ if self.split_input_params["tie_braker"]:
590
+ L_weighting = self.delta_border(Ly, Lx)
591
+ L_weighting = torch.clip(L_weighting,
592
+ self.split_input_params["clip_min_tie_weight"],
593
+ self.split_input_params["clip_max_tie_weight"])
594
+
595
+ L_weighting = L_weighting.view(1, 1, Ly * Lx).to(device)
596
+ weighting = weighting * L_weighting
597
+ return weighting
598
+
599
+ def get_fold_unfold(self, x, kernel_size, stride, uf=1, df=1): # todo load once not every time, shorten code
600
+ """
601
+ :param x: img of size (bs, c, h, w)
602
+ :return: n img crops of size (n, bs, c, kernel_size[0], kernel_size[1])
603
+ """
604
+ bs, nc, h, w = x.shape
605
+
606
+ # number of crops in image
607
+ Ly = (h - kernel_size[0]) // stride[0] + 1
608
+ Lx = (w - kernel_size[1]) // stride[1] + 1
609
+
610
+ if uf == 1 and df == 1:
611
+ fold_params = dict(kernel_size=kernel_size, dilation=1, padding=0, stride=stride)
612
+ unfold = torch.nn.Unfold(**fold_params)
613
+
614
+ fold = torch.nn.Fold(output_size=x.shape[2:], **fold_params)
615
+
616
+ weighting = self.get_weighting(kernel_size[0], kernel_size[1], Ly, Lx, x.device).to(x.dtype)
617
+ normalization = fold(weighting).view(1, 1, h, w) # normalizes the overlap
618
+ weighting = weighting.view((1, 1, kernel_size[0], kernel_size[1], Ly * Lx))
619
+
620
+ elif uf > 1 and df == 1:
621
+ fold_params = dict(kernel_size=kernel_size, dilation=1, padding=0, stride=stride)
622
+ unfold = torch.nn.Unfold(**fold_params)
623
+
624
+ fold_params2 = dict(kernel_size=(kernel_size[0] * uf, kernel_size[0] * uf),
625
+ dilation=1, padding=0,
626
+ stride=(stride[0] * uf, stride[1] * uf))
627
+ fold = torch.nn.Fold(output_size=(x.shape[2] * uf, x.shape[3] * uf), **fold_params2)
628
+
629
+ weighting = self.get_weighting(kernel_size[0] * uf, kernel_size[1] * uf, Ly, Lx, x.device).to(x.dtype)
630
+ normalization = fold(weighting).view(1, 1, h * uf, w * uf) # normalizes the overlap
631
+ weighting = weighting.view((1, 1, kernel_size[0] * uf, kernel_size[1] * uf, Ly * Lx))
632
+
633
+ elif df > 1 and uf == 1:
634
+ fold_params = dict(kernel_size=kernel_size, dilation=1, padding=0, stride=stride)
635
+ unfold = torch.nn.Unfold(**fold_params)
636
+
637
+ fold_params2 = dict(kernel_size=(kernel_size[0] // df, kernel_size[0] // df),
638
+ dilation=1, padding=0,
639
+ stride=(stride[0] // df, stride[1] // df))
640
+ fold = torch.nn.Fold(output_size=(x.shape[2] // df, x.shape[3] // df), **fold_params2)
641
+
642
+ weighting = self.get_weighting(kernel_size[0] // df, kernel_size[1] // df, Ly, Lx, x.device).to(x.dtype)
643
+ normalization = fold(weighting).view(1, 1, h // df, w // df) # normalizes the overlap
644
+ weighting = weighting.view((1, 1, kernel_size[0] // df, kernel_size[1] // df, Ly * Lx))
645
+
646
+ else:
647
+ raise NotImplementedError
648
+
649
+ return fold, unfold, normalization, weighting
650
+
651
+ @torch.no_grad()
652
+ def get_input(self, batch, k, return_first_stage_outputs=False, force_c_encode=False,
653
+ cond_key=None, return_original_cond=False, bs=None):
654
+ x = super().get_input(batch, k)
655
+ if bs is not None:
656
+ x = x[:bs]
657
+ x = x.to(self.device)
658
+ encoder_posterior = self.encode_first_stage(x)
659
+ z = self.get_first_stage_encoding(encoder_posterior).detach()
660
+
661
+ if self.model.conditioning_key is not None:
662
+ if cond_key is None:
663
+ cond_key = self.cond_stage_key
664
+ if cond_key != self.first_stage_key:# cond_key is not image. for inapint it's masked_img
665
+ if cond_key in ['caption', 'coordinates_bbox']:
666
+ xc = batch[cond_key]
667
+ elif cond_key == 'class_label':
668
+ xc = batch
669
+ else:
670
+ xc = super().get_input(batch, cond_key).to(self.device)
671
+ else:
672
+ xc = x
673
+ if not self.cond_stage_trainable or force_c_encode:
674
+ if isinstance(xc, dict) or isinstance(xc, list):
675
+ # import pudb; pudb.set_trace()
676
+ c = self.get_learned_conditioning(xc)
677
+ else:
678
+ c = self.get_learned_conditioning(xc.to(self.device))
679
+ else:
680
+ c = xc
681
+ if bs is not None:
682
+ c = c[:bs]
683
+
684
+ if self.use_positional_encodings:
685
+ pos_x, pos_y = self.compute_latent_shifts(batch)
686
+ ckey = __conditioning_keys__[self.model.conditioning_key]
687
+ c = {ckey: c, 'pos_x': pos_x, 'pos_y': pos_y}
688
+
689
+ else:
690
+ c = None
691
+ xc = None
692
+ if self.use_positional_encodings:
693
+ pos_x, pos_y = self.compute_latent_shifts(batch)
694
+ c = {'pos_x': pos_x, 'pos_y': pos_y}
695
+ out = [z, c]
696
+ if return_first_stage_outputs:
697
+ xrec = self.decode_first_stage(z)
698
+ out.extend([x, xrec])
699
+ if return_original_cond:
700
+ out.append(xc)
701
+ return out
702
+
703
+ @torch.no_grad()
704
+ def decode_first_stage(self, z, predict_cids=False, force_not_quantize=False):
705
+ if predict_cids:
706
+ if z.dim() == 4:
707
+ z = torch.argmax(z.exp(), dim=1).long()
708
+ z = self.first_stage_model.quantize.get_codebook_entry(z, shape=None)
709
+ z = rearrange(z, 'b h w c -> b c h w').contiguous()
710
+
711
+ z = 1. / self.scale_factor * z
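+ # undo the latent rescaling by scale_factor (applied in get_first_stage_encoding) before decoding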
712
+
713
+ if hasattr(self, "split_input_params"):
714
+ if self.split_input_params["patch_distributed_vq"]:
715
+ ks = self.split_input_params["ks"] # eg. (128, 128)
716
+ stride = self.split_input_params["stride"] # eg. (64, 64)
717
+ uf = self.split_input_params["vqf"]
718
+ bs, nc, h, w = z.shape
719
+ if ks[0] > h or ks[1] > w:
720
+ ks = (min(ks[0], h), min(ks[1], w))
721
+ print("reducing Kernel")
722
+
723
+ if stride[0] > h or stride[1] > w:
724
+ stride = (min(stride[0], h), min(stride[1], w))
725
+ print("reducing stride")
726
+
727
+ fold, unfold, normalization, weighting = self.get_fold_unfold(z, ks, stride, uf=uf)
728
+
729
+ z = unfold(z) # (bn, nc * prod(**ks), L)
730
+ # 1. Reshape to img shape
731
+ z = z.view((z.shape[0], -1, ks[0], ks[1], z.shape[-1])) # (bn, nc, ks[0], ks[1], L )
732
+
733
+ # 2. apply model loop over last dim
734
+ if isinstance(self.first_stage_model, VQModelInterface):
735
+ output_list = [self.first_stage_model.decode(z[:, :, :, :, i],
736
+ force_not_quantize=predict_cids or force_not_quantize)
737
+ for i in range(z.shape[-1])]
738
+ else:
739
+
740
+ output_list = [self.first_stage_model.decode(z[:, :, :, :, i])
741
+ for i in range(z.shape[-1])]
742
+
743
+ o = torch.stack(output_list, axis=-1) # # (bn, nc, ks[0], ks[1], L)
744
+ o = o * weighting
745
+ # Reverse 1. reshape to img shape
746
+ o = o.view((o.shape[0], -1, o.shape[-1])) # (bn, nc * ks[0] * ks[1], L)
747
+ # stitch crops together
748
+ decoded = fold(o)
749
+ decoded = decoded / normalization # norm is shape (1, 1, h, w)
750
+ return decoded
751
+ else:
752
+ if isinstance(self.first_stage_model, VQModelInterface):
753
+ return self.first_stage_model.decode(z, force_not_quantize=predict_cids or force_not_quantize)
754
+ else:
755
+ return self.first_stage_model.decode(z)
756
+
757
+ else:
758
+ if isinstance(self.first_stage_model, VQModelInterface):
759
+ return self.first_stage_model.decode(z, force_not_quantize=predict_cids or force_not_quantize)
760
+ else:
761
+ return self.first_stage_model.decode(z)
762
+
763
+ # same as above but without decorator
764
+ def differentiable_decode_first_stage(self, z, predict_cids=False, force_not_quantize=False):
765
+ if predict_cids:
766
+ if z.dim() == 4:
767
+ z = torch.argmax(z.exp(), dim=1).long()
768
+ z = self.first_stage_model.quantize.get_codebook_entry(z, shape=None)
769
+ z = rearrange(z, 'b h w c -> b c h w').contiguous()
770
+
771
+ z = 1. / self.scale_factor * z
772
+
773
+ if hasattr(self, "split_input_params"):
774
+ if self.split_input_params["patch_distributed_vq"]:
775
+ ks = self.split_input_params["ks"] # eg. (128, 128)
776
+ stride = self.split_input_params["stride"] # eg. (64, 64)
777
+ uf = self.split_input_params["vqf"]
778
+ bs, nc, h, w = z.shape
779
+ if ks[0] > h or ks[1] > w:
780
+ ks = (min(ks[0], h), min(ks[1], w))
781
+ print("reducing Kernel")
782
+
783
+ if stride[0] > h or stride[1] > w:
784
+ stride = (min(stride[0], h), min(stride[1], w))
785
+ print("reducing stride")
786
+
787
+ fold, unfold, normalization, weighting = self.get_fold_unfold(z, ks, stride, uf=uf)
788
+
789
+ z = unfold(z) # (bn, nc * prod(**ks), L)
790
+ # 1. Reshape to img shape
791
+ z = z.view((z.shape[0], -1, ks[0], ks[1], z.shape[-1])) # (bn, nc, ks[0], ks[1], L )
792
+
793
+ # 2. apply model loop over last dim
794
+ if isinstance(self.first_stage_model, VQModelInterface):
795
+ output_list = [self.first_stage_model.decode(z[:, :, :, :, i],
796
+ force_not_quantize=predict_cids or force_not_quantize)
797
+ for i in range(z.shape[-1])]
798
+ else:
799
+
800
+ output_list = [self.first_stage_model.decode(z[:, :, :, :, i])
801
+ for i in range(z.shape[-1])]
802
+
803
+ o = torch.stack(output_list, axis=-1) # # (bn, nc, ks[0], ks[1], L)
804
+ o = o * weighting
805
+ # Reverse 1. reshape to img shape
806
+ o = o.view((o.shape[0], -1, o.shape[-1])) # (bn, nc * ks[0] * ks[1], L)
807
+ # stitch crops together
808
+ decoded = fold(o)
809
+ decoded = decoded / normalization # norm is shape (1, 1, h, w)
810
+ return decoded
811
+ else:
812
+ if isinstance(self.first_stage_model, VQModelInterface):
813
+ return self.first_stage_model.decode(z, force_not_quantize=predict_cids or force_not_quantize)
814
+ else:
815
+ return self.first_stage_model.decode(z)
816
+
817
+ else:
818
+ if isinstance(self.first_stage_model, VQModelInterface):
819
+ return self.first_stage_model.decode(z, force_not_quantize=predict_cids or force_not_quantize)
820
+ else:
821
+ return self.first_stage_model.decode(z)
822
+
823
+ @torch.no_grad()
824
+ def encode_first_stage(self, x):
825
+ if hasattr(self, "split_input_params"):
826
+ if self.split_input_params["patch_distributed_vq"]:
827
+ ks = self.split_input_params["ks"] # eg. (128, 128)
828
+ stride = self.split_input_params["stride"] # eg. (64, 64)
829
+ df = self.split_input_params["vqf"]
830
+ self.split_input_params['original_image_size'] = x.shape[-2:]
831
+ bs, nc, h, w = x.shape
832
+ if ks[0] > h or ks[1] > w:
833
+ ks = (min(ks[0], h), min(ks[1], w))
834
+ print("reducing Kernel")
835
+
836
+ if stride[0] > h or stride[1] > w:
837
+ stride = (min(stride[0], h), min(stride[1], w))
838
+ print("reducing stride")
839
+
840
+ fold, unfold, normalization, weighting = self.get_fold_unfold(x, ks, stride, df=df)
841
+ z = unfold(x) # (bn, nc * prod(**ks), L)
842
+ # Reshape to img shape
843
+ z = z.view((z.shape[0], -1, ks[0], ks[1], z.shape[-1])) # (bn, nc, ks[0], ks[1], L )
844
+
845
+ output_list = [self.first_stage_model.encode(z[:, :, :, :, i])
846
+ for i in range(z.shape[-1])]
847
+
848
+ o = torch.stack(output_list, axis=-1)
849
+ o = o * weighting
850
+
851
+ # Reverse reshape to img shape
852
+ o = o.view((o.shape[0], -1, o.shape[-1])) # (bn, nc * ks[0] * ks[1], L)
853
+ # stitch crops together
854
+ decoded = fold(o)
855
+ decoded = decoded / normalization
856
+ return decoded
857
+
858
+ else:
859
+ return self.first_stage_model.encode(x)
860
+ else:
861
+ return self.first_stage_model.encode(x)
862
+
863
+ def shared_step(self, batch, **kwargs):
864
+ x, c = self.get_input(batch, self.first_stage_key)
865
+ loss = self(x, c)
866
+ return loss
867
+
868
+ def forward(self, x, c, *args, **kwargs):
869
+ t = torch.randint(0, self.num_timesteps, (x.shape[0],), device=self.device).long()
870
+ if self.model.conditioning_key is not None:
871
+ assert c is not None
872
+ if self.cond_stage_trainable:  # True when the conditioning encoder is trainable (e.g. the text encoder)
873
+ c = self.get_learned_conditioning(c) # c: string list -> [B, T, Context_dim]
874
+ if self.shorten_cond_schedule: # TODO: drop this option
875
+ tc = self.cond_ids[t].to(self.device)
876
+ c = self.q_sample(x_start=c, t=tc, noise=torch.randn_like(c.float()))
877
+ return self.p_losses(x, c, t, *args, **kwargs)
878
+
879
+ def _rescale_annotations(self, bboxes, crop_coordinates): # TODO: move to dataset
880
+ def rescale_bbox(bbox):
881
+ x0 = clamp((bbox[0] - crop_coordinates[0]) / crop_coordinates[2])
882
+ y0 = clamp((bbox[1] - crop_coordinates[1]) / crop_coordinates[3])
883
+ w = min(bbox[2] / crop_coordinates[2], 1 - x0)
884
+ h = min(bbox[3] / crop_coordinates[3], 1 - y0)
885
+ return x0, y0, w, h
886
+
887
+ return [rescale_bbox(b) for b in bboxes]
888
+
889
+ def apply_model(self, x_noisy, t, cond, return_ids=False):
890
+
891
+ if isinstance(cond, dict):
892
+ # hybrid case, cond is expected to be a dict
893
+ pass
894
+ else:
895
+ if not isinstance(cond, list):
896
+ cond = [cond]
897
+ key = 'c_concat' if self.model.conditioning_key == 'concat' else 'c_crossattn'
898
+ cond = {key: cond}
899
+
900
+ if hasattr(self, "split_input_params"):
901
+ assert len(cond) == 1 # todo can only deal with one conditioning atm
902
+ assert not return_ids
903
+ ks = self.split_input_params["ks"] # eg. (128, 128)
904
+ stride = self.split_input_params["stride"] # eg. (64, 64)
905
+
906
+ h, w = x_noisy.shape[-2:]
907
+
908
+ fold, unfold, normalization, weighting = self.get_fold_unfold(x_noisy, ks, stride)
909
+
910
+ z = unfold(x_noisy) # (bn, nc * prod(**ks), L)
911
+ # Reshape to img shape
912
+ z = z.view((z.shape[0], -1, ks[0], ks[1], z.shape[-1])) # (bn, nc, ks[0], ks[1], L )
913
+ z_list = [z[:, :, :, :, i] for i in range(z.shape[-1])]
914
+
915
+ if self.cond_stage_key in ["image", "LR_image", "segmentation",
916
+ 'bbox_img'] and self.model.conditioning_key: # todo check for completeness
917
+ c_key = next(iter(cond.keys())) # get key
918
+ c = next(iter(cond.values())) # get value
919
+ assert (len(c) == 1) # todo extend to list with more than one elem
920
+ c = c[0] # get element
921
+
922
+ c = unfold(c)
923
+ c = c.view((c.shape[0], -1, ks[0], ks[1], c.shape[-1])) # (bn, nc, ks[0], ks[1], L )
924
+
925
+ cond_list = [{c_key: [c[:, :, :, :, i]]} for i in range(c.shape[-1])]
926
+
927
+ elif self.cond_stage_key == 'coordinates_bbox':
928
+ assert 'original_image_size' in self.split_input_params, 'BoundingBoxRescaling is missing original_image_size'
929
+
930
+ # assuming padding of unfold is always 0 and its dilation is always 1
931
+ n_patches_per_row = int((w - ks[0]) / stride[0] + 1)
932
+ full_img_h, full_img_w = self.split_input_params['original_image_size']
933
+ # as we are operating on latents, we need the factor from the original image size to the
934
+ # spatial latent size to properly rescale the crops for regenerating the bbox annotations
935
+ num_downs = self.first_stage_model.encoder.num_resolutions - 1
936
+ rescale_latent = 2 ** (num_downs)
937
+
938
+ # get top-left positions of patches as expected by the bbox tokenizer, therefore we
939
+ # need to rescale the tl patch coordinates to be in between (0,1)
940
+ tl_patch_coordinates = [(rescale_latent * stride[0] * (patch_nr % n_patches_per_row) / full_img_w,
941
+ rescale_latent * stride[1] * (patch_nr // n_patches_per_row) / full_img_h)
942
+ for patch_nr in range(z.shape[-1])]
943
+
944
+ # patch_limits are tl_coord, width and height coordinates as (x_tl, y_tl, h, w)
945
+ patch_limits = [(x_tl, y_tl,
946
+ rescale_latent * ks[0] / full_img_w,
947
+ rescale_latent * ks[1] / full_img_h) for x_tl, y_tl in tl_patch_coordinates]
948
+ # patch_values = [(np.arange(x_tl,min(x_tl+ks, 1.)),np.arange(y_tl,min(y_tl+ks, 1.))) for x_tl, y_tl in tl_patch_coordinates]
949
+
950
+ # tokenize crop coordinates for the bounding boxes of the respective patches
951
+ patch_limits_tknzd = [torch.LongTensor(self.bbox_tokenizer._crop_encoder(bbox))[None].to(self.device)
952
+ for bbox in patch_limits] # list of length l with tensors of shape (1, 2)
953
+ print(patch_limits_tknzd[0].shape)
954
+ # cut tknzd crop position from conditioning
955
+ assert isinstance(cond, dict), 'cond must be dict to be fed into model'
956
+ cut_cond = cond['c_crossattn'][0][..., :-2].to(self.device)
957
+ print(cut_cond.shape)
958
+
959
+ adapted_cond = torch.stack([torch.cat([cut_cond, p], dim=1) for p in patch_limits_tknzd])
960
+ adapted_cond = rearrange(adapted_cond, 'l b n -> (l b) n')
961
+ print(adapted_cond.shape)
962
+ adapted_cond = self.get_learned_conditioning(adapted_cond)
963
+ print(adapted_cond.shape)
964
+ adapted_cond = rearrange(adapted_cond, '(l b) n d -> l b n d', l=z.shape[-1])
965
+ print(adapted_cond.shape)
966
+
967
+ cond_list = [{'c_crossattn': [e]} for e in adapted_cond]
968
+
969
+ else:
970
+ cond_list = [cond for i in range(z.shape[-1])] # Todo make this more efficient
971
+
972
+ # apply model by loop over crops
973
+ output_list = [self.model(z_list[i], t, **cond_list[i]) for i in range(z.shape[-1])]
974
+ assert not isinstance(output_list[0],
975
+ tuple) # todo: can't deal with multiple model outputs; check this never happens
976
+
977
+ o = torch.stack(output_list, axis=-1)
978
+ o = o * weighting
979
+ # Reverse reshape to img shape
980
+ o = o.view((o.shape[0], -1, o.shape[-1])) # (bn, nc * ks[0] * ks[1], L)
981
+ # stitch crops together
982
+ x_recon = fold(o) / normalization
983
+
984
+ else:
985
+ x_recon = self.model(x_noisy, t, **cond)
986
+
987
+ if isinstance(x_recon, tuple) and not return_ids:
988
+ return x_recon[0]
989
+ else:
990
+ return x_recon
991
+
992
+ def _predict_eps_from_xstart(self, x_t, t, pred_xstart):
993
+ return (extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - pred_xstart) / \
994
+ extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)
995
+
996
+ def _prior_bpd(self, x_start):
997
+ """
998
+ Get the prior KL term for the variational lower-bound, measured in
999
+ bits-per-dim.
1000
+ This term can't be optimized, as it only depends on the encoder.
1001
+ :param x_start: the [N x C x ...] tensor of inputs.
1002
+ :return: a batch of [N] KL values (in bits), one per batch element.
1003
+ """
1004
+ batch_size = x_start.shape[0]
1005
+ t = torch.tensor([self.num_timesteps - 1] * batch_size, device=x_start.device)
1006
+ qt_mean, _, qt_log_variance = self.q_mean_variance(x_start, t)
1007
+ kl_prior = normal_kl(mean1=qt_mean, logvar1=qt_log_variance, mean2=0.0, logvar2=0.0)
1008
+ return mean_flat(kl_prior) / np.log(2.0)
1009
+
1010
+ def p_losses(self, x_start, cond, t, noise=None):
1011
+ noise = default(noise, lambda: torch.randn_like(x_start))
1012
+ x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise)
1013
+ model_output = self.apply_model(x_noisy, t, cond)
1014
+
1015
+ loss_dict = {}
1016
+ prefix = 'train' if self.training else 'val'
1017
+
1018
+ if self.parameterization == "x0":
1019
+ target = x_start
1020
+ elif self.parameterization == "eps":
1021
+ target = noise
1022
+ else:
1023
+ raise NotImplementedError()
1024
+
1025
+ loss_simple = self.get_loss(model_output, target, mean=False).mean([1, 2, 3])
1026
+ loss_dict.update({f'{prefix}/loss_simple': loss_simple.mean()})
1027
+
1028
+ logvar_t = self.logvar[t].to(self.device)
1029
+ loss = loss_simple / torch.exp(logvar_t) + logvar_t
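+ # weight the simple loss with a learned per-timestep log-variance; the logvar parameters are only optimized when learn_logvar is set (see configure_optimizers)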
1030
+ # loss = loss_simple / torch.exp(self.logvar) + self.logvar
1031
+ if self.learn_logvar:
1032
+ loss_dict.update({f'{prefix}/loss_gamma': loss.mean()})
1033
+ loss_dict.update({'logvar': self.logvar.data.mean()})
1034
+
1035
+ loss = self.l_simple_weight * loss.mean()
1036
+
1037
+ loss_vlb = self.get_loss(model_output, target, mean=False).mean(dim=(1, 2, 3))
1038
+ loss_vlb = (self.lvlb_weights[t] * loss_vlb).mean()
1039
+ loss_dict.update({f'{prefix}/loss_vlb': loss_vlb})
1040
+ loss += (self.original_elbo_weight * loss_vlb)
1041
+ loss_dict.update({f'{prefix}/loss': loss})
1042
+
1043
+ return loss, loss_dict
1044
+
1045
+ def p_mean_variance(self, x, c, t, clip_denoised: bool, return_codebook_ids=False, quantize_denoised=False,
1046
+ return_x0=False, score_corrector=None, corrector_kwargs=None):
1047
+ t_in = t
1048
+ model_out = self.apply_model(x, t_in, c, return_ids=return_codebook_ids)
1049
+
1050
+ if score_corrector is not None:
1051
+ assert self.parameterization == "eps"
1052
+ model_out = score_corrector.modify_score(self, model_out, x, t, c, **corrector_kwargs)
1053
+
1054
+ if return_codebook_ids:
1055
+ model_out, logits = model_out
1056
+
1057
+ if self.parameterization == "eps":
1058
+ x_recon = self.predict_start_from_noise(x, t=t, noise=model_out)
1059
+ elif self.parameterization == "x0":
1060
+ x_recon = model_out
1061
+ else:
1062
+ raise NotImplementedError()
1063
+
1064
+ if clip_denoised:
1065
+ x_recon.clamp_(-1., 1.)
1066
+ if quantize_denoised:
1067
+ x_recon, _, [_, _, indices] = self.first_stage_model.quantize(x_recon)
1068
+ model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t)
1069
+ if return_codebook_ids:
1070
+ return model_mean, posterior_variance, posterior_log_variance, logits
1071
+ elif return_x0:
1072
+ return model_mean, posterior_variance, posterior_log_variance, x_recon
1073
+ else:
1074
+ return model_mean, posterior_variance, posterior_log_variance
1075
+
1076
+ @torch.no_grad()
1077
+ def p_sample(self, x, c, t, clip_denoised=False, repeat_noise=False,
1078
+ return_codebook_ids=False, quantize_denoised=False, return_x0=False,
1079
+ temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None):
1080
+ b, *_, device = *x.shape, x.device
1081
+ outputs = self.p_mean_variance(x=x, c=c, t=t, clip_denoised=clip_denoised,
1082
+ return_codebook_ids=return_codebook_ids,
1083
+ quantize_denoised=quantize_denoised,
1084
+ return_x0=return_x0,
1085
+ score_corrector=score_corrector, corrector_kwargs=corrector_kwargs)
1086
+ if return_codebook_ids:
1087
+ raise DeprecationWarning("Support dropped.")
1088
+ model_mean, _, model_log_variance, logits = outputs
1089
+ elif return_x0:
1090
+ model_mean, _, model_log_variance, x0 = outputs
1091
+ else:
1092
+ model_mean, _, model_log_variance = outputs
1093
+
1094
+ noise = noise_like(x.shape, device, repeat_noise) * temperature
1095
+ if noise_dropout > 0.:
1096
+ noise = torch.nn.functional.dropout(noise, p=noise_dropout)
1097
+ # no noise when t == 0
1098
+ nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
1099
+
1100
+ if return_codebook_ids:
1101
+ return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise, logits.argmax(dim=1)
1102
+ if return_x0:
1103
+ return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise, x0
1104
+ else:
1105
+ return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
1106
+
1107
+ @torch.no_grad()
1108
+ def progressive_denoising(self, cond, shape, verbose=True, callback=None, quantize_denoised=False,
1109
+ img_callback=None, mask=None, x0=None, temperature=1., noise_dropout=0.,
1110
+ score_corrector=None, corrector_kwargs=None, batch_size=None, x_T=None, start_T=None,
1111
+ log_every_t=None):
1112
+ if not log_every_t:
1113
+ log_every_t = self.log_every_t
1114
+ timesteps = self.num_timesteps
1115
+ if batch_size is not None:
1116
+ b = batch_size if batch_size is not None else shape[0]
1117
+ shape = [batch_size] + list(shape)
1118
+ else:
1119
+ b = batch_size = shape[0]
1120
+ if x_T is None:
1121
+ img = torch.randn(shape, device=self.device)
1122
+ else:
1123
+ img = x_T
1124
+ intermediates = []
1125
+ if cond is not None:
1126
+ if isinstance(cond, dict):
1127
+ cond = {key: cond[key][:batch_size] if not isinstance(cond[key], list) else
1128
+ list(map(lambda x: x[:batch_size], cond[key])) for key in cond}
1129
+ else:
1130
+ cond = [c[:batch_size] for c in cond] if isinstance(cond, list) else cond[:batch_size]
1131
+
1132
+ if start_T is not None:
1133
+ timesteps = min(timesteps, start_T)
1134
+ iterator = tqdm(reversed(range(0, timesteps)), desc='Progressive Generation',
1135
+ total=timesteps) if verbose else reversed(
1136
+ range(0, timesteps))
1137
+ if type(temperature) == float:
1138
+ temperature = [temperature] * timesteps
1139
+
1140
+ for i in iterator:
1141
+ ts = torch.full((b,), i, device=self.device, dtype=torch.long)
1142
+ if self.shorten_cond_schedule:
1143
+ assert self.model.conditioning_key != 'hybrid'
1144
+ tc = self.cond_ids[ts].to(cond.device)
1145
+ cond = self.q_sample(x_start=cond, t=tc, noise=torch.randn_like(cond))
1146
+
1147
+ img, x0_partial = self.p_sample(img, cond, ts,
1148
+ clip_denoised=self.clip_denoised,
1149
+ quantize_denoised=quantize_denoised, return_x0=True,
1150
+ temperature=temperature[i], noise_dropout=noise_dropout,
1151
+ score_corrector=score_corrector, corrector_kwargs=corrector_kwargs)
1152
+ if mask is not None:
1153
+ assert x0 is not None
1154
+ img_orig = self.q_sample(x0, ts)
1155
+ img = img_orig * mask + (1. - mask) * img
1156
+
1157
+ if i % log_every_t == 0 or i == timesteps - 1:
1158
+ intermediates.append(x0_partial)
1159
+ if callback: callback(i)
1160
+ if img_callback: img_callback(img, i)
1161
+ return img, intermediates
1162
+
1163
+ @torch.no_grad()
1164
+ def p_sample_loop(self, cond, shape, return_intermediates=False,
1165
+ x_T=None, verbose=True, callback=None, timesteps=None, quantize_denoised=False,
1166
+ mask=None, x0=None, img_callback=None, start_T=None,
1167
+ log_every_t=None):
1168
+
1169
+ if not log_every_t:
1170
+ log_every_t = self.log_every_t
1171
+ device = self.betas.device
1172
+ b = shape[0]
1173
+ if x_T is None:
1174
+ img = torch.randn(shape, device=device)
1175
+ else:
1176
+ img = x_T
1177
+
1178
+ intermediates = [img]
1179
+ if timesteps is None:
1180
+ timesteps = self.num_timesteps
1181
+
1182
+ if start_T is not None:
1183
+ timesteps = min(timesteps, start_T)
1184
+ iterator = tqdm(reversed(range(0, timesteps)), desc='Sampling t', total=timesteps) if verbose else reversed(
1185
+ range(0, timesteps))
1186
+
1187
+ if mask is not None:
1188
+ assert x0 is not None
1189
+ assert x0.shape[2:3] == mask.shape[2:3] # spatial size has to match
1190
+
1191
+ for i in iterator:
1192
+ ts = torch.full((b,), i, device=device, dtype=torch.long)
1193
+ if self.shorten_cond_schedule:
1194
+ assert self.model.conditioning_key != 'hybrid'
1195
+ tc = self.cond_ids[ts].to(cond.device)
1196
+ cond = self.q_sample(x_start=cond, t=tc, noise=torch.randn_like(cond))
1197
+
1198
+ img = self.p_sample(img, cond, ts,
1199
+ clip_denoised=self.clip_denoised,
1200
+ quantize_denoised=quantize_denoised)
1201
+ if mask is not None:
1202
+ img_orig = self.q_sample(x0, ts)
1203
+ img = img_orig * mask + (1. - mask) * img
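+ # inpainting blend: where mask == 1 keep the re-noised known content from x0, where mask == 0 keep the generated sample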
1204
+
1205
+ if i % log_every_t == 0 or i == timesteps - 1:
1206
+ intermediates.append(img)
1207
+ if callback: callback(i)
1208
+ if img_callback: img_callback(img, i)
1209
+
1210
+ if return_intermediates:
1211
+ return img, intermediates
1212
+ return img
1213
+
1214
+ @torch.no_grad()
1215
+ def sample(self, cond, batch_size=16, return_intermediates=False, x_T=None,
1216
+ verbose=True, timesteps=None, quantize_denoised=False,
1217
+ mask=None, x0=None, shape=None,**kwargs):
1218
+ if shape is None:
1219
+ shape = (batch_size, self.channels, self.image_size, self.image_size)
1220
+ if cond is not None:
1221
+ if isinstance(cond, dict):
1222
+ cond = {key: cond[key][:batch_size] if not isinstance(cond[key], list) else
1223
+ list(map(lambda x: x[:batch_size], cond[key])) for key in cond}
1224
+ else:
1225
+ cond = [c[:batch_size] for c in cond] if isinstance(cond, list) else cond[:batch_size]
1226
+ return self.p_sample_loop(cond,
1227
+ shape,
1228
+ return_intermediates=return_intermediates, x_T=x_T,
1229
+ verbose=verbose, timesteps=timesteps, quantize_denoised=quantize_denoised,
1230
+ mask=mask, x0=x0)
1231
+
1232
+ @torch.no_grad()
1233
+ def sample_log(self,cond,batch_size,ddim, ddim_steps,**kwargs):
1234
+
1235
+ if ddim:
1236
+ ddim_sampler = DDIMSampler(self)
1237
+ shape = (self.channels, self.image_size, self.image_size)
1238
+ samples, intermediates =ddim_sampler.sample(ddim_steps,batch_size,
1239
+ shape,cond,verbose=False,**kwargs)
1240
+
1241
+ else:
1242
+ samples, intermediates = self.sample(cond=cond, batch_size=batch_size,
1243
+ return_intermediates=True,**kwargs)
1244
+
1245
+ return samples, intermediates
1246
+
1247
+
1248
+ @torch.no_grad()
1249
+ def log_images(self, batch, N=8, n_row=4, sample=True, ddim_steps=200, ddim_eta=1., return_keys=None,
1250
+ quantize_denoised=True, inpaint=True, plot_denoise_rows=False, plot_progressive_rows=True,
1251
+ plot_diffusion_rows=True, **kwargs):
1252
+
1253
+ use_ddim = ddim_steps is not None
1254
+
1255
+ log = dict()
1256
+ z, c, x, xrec, xc = self.get_input(batch, self.first_stage_key,
1257
+ return_first_stage_outputs=True,
1258
+ force_c_encode=True,
1259
+ return_original_cond=True,
1260
+ bs=N)
1261
+ N = min(x.shape[0], N)
1262
+ n_row = min(x.shape[0], n_row)
1263
+ log["inputs"] = x
1264
+ log["reconstruction"] = xrec
1265
+ if self.model.conditioning_key is not None:
1266
+ if hasattr(self.cond_stage_model, "decode"):
1267
+ xc = self.cond_stage_model.decode(c)
1268
+ log["conditioning"] = xc
1269
+ elif self.cond_stage_key in ["caption"]:
1270
+ xc = log_txt_as_img((x.shape[2], x.shape[3]), batch["caption"])
1271
+ log["conditioning"] = xc
1272
+ elif self.cond_stage_key == 'class_label':
1273
+ xc = log_txt_as_img((x.shape[2], x.shape[3]), batch["human_label"])
1274
+ log['conditioning'] = xc
1275
+ elif isimage(xc):
1276
+ log["conditioning"] = xc
1277
+ if ismap(xc):
1278
+ log["original_conditioning"] = self.to_rgb(xc)
1279
+
1280
+ if plot_diffusion_rows:
1281
+ # get diffusion row
1282
+ diffusion_row = list()
1283
+ z_start = z[:n_row]
1284
+ for t in range(self.num_timesteps):
1285
+ if t % self.log_every_t == 0 or t == self.num_timesteps - 1:
1286
+ t = repeat(torch.tensor([t]), '1 -> b', b=n_row)
1287
+ t = t.to(self.device).long()
1288
+ noise = torch.randn_like(z_start)
1289
+ z_noisy = self.q_sample(x_start=z_start, t=t, noise=noise)
1290
+ diffusion_row.append(self.decode_first_stage(z_noisy))
1291
+
1292
+ diffusion_row = torch.stack(diffusion_row) # n_log_step, n_row, C, H, W
1293
+ diffusion_grid = rearrange(diffusion_row, 'n b c h w -> b n c h w')
1294
+ diffusion_grid = rearrange(diffusion_grid, 'b n c h w -> (b n) c h w')
1295
+ diffusion_grid = make_grid(diffusion_grid, nrow=diffusion_row.shape[0])
1296
+ log["diffusion_row"] = diffusion_grid
1297
+
1298
+ if sample:
1299
+ # get denoise row
1300
+ with self.ema_scope("Plotting"):
1301
+ samples, z_denoise_row = self.sample_log(cond=c,batch_size=N,ddim=use_ddim,
1302
+ ddim_steps=ddim_steps,eta=ddim_eta)
1303
+ # samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True)
1304
+ x_samples = self.decode_first_stage(samples)
1305
+ log["samples"] = x_samples
1306
+ if plot_denoise_rows:
1307
+ denoise_grid = self._get_denoise_row_from_list(z_denoise_row)
1308
+ log["denoise_row"] = denoise_grid
1309
+
1310
+ if quantize_denoised and not isinstance(self.first_stage_model, AutoencoderKL) and not isinstance(
1311
+ self.first_stage_model, IdentityFirstStage):
1312
+ # also display when quantizing x0 while sampling
1313
+ with self.ema_scope("Plotting Quantized Denoised"):
1314
+ samples, z_denoise_row = self.sample_log(cond=c,batch_size=N,ddim=use_ddim,
1315
+ ddim_steps=ddim_steps,eta=ddim_eta,
1316
+ quantize_denoised=True)
1317
+ # samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True,
1318
+ # quantize_denoised=True)
1319
+ x_samples = self.decode_first_stage(samples.to(self.device))
1320
+ log["samples_x0_quantized"] = x_samples
1321
+
1322
+ if inpaint:
1323
+ # make a simple center square
1324
+ b, h, w = z.shape[0], z.shape[2], z.shape[3]
1325
+ mask = torch.ones(N, h, w).to(self.device)
1326
+ # zeros will be filled in
1327
+ mask[:, h // 4:3 * h // 4, w // 4:3 * w // 4] = 0.
1328
+ mask = mask[:, None, ...]
1329
+ with self.ema_scope("Plotting Inpaint"):
1330
+
1331
+ samples, _ = self.sample_log(cond=c,batch_size=N,ddim=use_ddim, eta=ddim_eta,
1332
+ ddim_steps=ddim_steps, x0=z[:N], mask=mask)
1333
+ x_samples = self.decode_first_stage(samples.to(self.device))
1334
+ log["samples_inpainting"] = x_samples
1335
+ log["mask"] = mask
1336
+
1337
+ # outpaint
1338
+ with self.ema_scope("Plotting Outpaint"):
1339
+ samples, _ = self.sample_log(cond=c, batch_size=N, ddim=use_ddim,eta=ddim_eta,
1340
+ ddim_steps=ddim_steps, x0=z[:N], mask=mask)
1341
+ x_samples = self.decode_first_stage(samples.to(self.device))
1342
+ log["samples_outpainting"] = x_samples
1343
+
1344
+ if plot_progressive_rows:
1345
+ with self.ema_scope("Plotting Progressives"):
1346
+ img, progressives = self.progressive_denoising(c,
1347
+ shape=(self.channels, self.image_size, self.image_size),
1348
+ batch_size=N)
1349
+ prog_row = self._get_denoise_row_from_list(progressives, desc="Progressive Generation")
1350
+ log["progressive_row"] = prog_row
1351
+
1352
+ if return_keys:
1353
+ if np.intersect1d(list(log.keys()), return_keys).shape[0] == 0:
1354
+ return log
1355
+ else:
1356
+ return {key: log[key] for key in return_keys}
1357
+ return log
1358
+
1359
+ def configure_optimizers(self):
1360
+ lr = self.learning_rate
1361
+ params = list(self.model.parameters())
1362
+ if self.cond_stage_trainable:
1363
+ print(f"{self.__class__.__name__}: Also optimizing conditioner params!")
1364
+ params = params + list(self.cond_stage_model.parameters())
1365
+ if self.learn_logvar:
1366
+ print('Diffusion model optimizing logvar')
1367
+ params.append(self.logvar)
1368
+ opt = torch.optim.AdamW(params, lr=lr)
1369
+ if self.use_scheduler:
1370
+ assert 'target' in self.scheduler_config
1371
+ scheduler = instantiate_from_config(self.scheduler_config)
1372
+
1373
+ print("Setting up LambdaLR scheduler...")
1374
+ scheduler = [
1375
+ {
1376
+ 'scheduler': LambdaLR(opt, lr_lambda=scheduler.schedule),
1377
+ 'interval': 'step',
1378
+ 'frequency': 1
1379
+ }]
1380
+ return [opt], scheduler
1381
+ return opt
1382
+
1383
+ @torch.no_grad()
1384
+ def to_rgb(self, x):
1385
+ x = x.float()
1386
+ if not hasattr(self, "colorize"):
1387
+ self.colorize = torch.randn(3, x.shape[1], 1, 1).to(x)
1388
+ x = nn.functional.conv2d(x, weight=self.colorize)
1389
+ x = 2. * (x - x.min()) / (x.max() - x.min()) - 1.
1390
+ return x
1391
+
1392
+
1393
+ class DiffusionWrapper(pl.LightningModule):
1394
+ def __init__(self, diff_model_config, conditioning_key):
1395
+ super().__init__()
1396
+ self.diffusion_model = instantiate_from_config(diff_model_config)
1397
+ self.conditioning_key = conditioning_key # 'crossattn' for text-to-image, 'concat' for inpainting
1398
+ assert self.conditioning_key in [None, 'concat', 'crossattn', 'hybrid', 'adm']
1399
+
1400
+ def forward(self, x, t, c_concat: list = None, c_crossattn: list = None):
1401
+ """param x: tensor with shape:[B,C,mel_len,T]"""
1402
+ if self.conditioning_key is None:
1403
+ out = self.diffusion_model(x, t)
1404
+ elif self.conditioning_key == 'concat':
1405
+ xc = torch.cat([x] + c_concat, dim=1) # concat along channel dim; x: (b, 3, 64, 64), c_concat: (b, 4, 64, 64)
1406
+ out = self.diffusion_model(xc, t)
1407
+ elif self.conditioning_key == 'crossattn':
1408
+ cc = torch.cat(c_crossattn, 1)# [b,seq_len,dim]
1409
+ out = self.diffusion_model(x, t, context=cc)
1410
+ elif self.conditioning_key == 'hybrid': # not implemented in LatentDiffusion
1411
+ xc = torch.cat([x] + c_concat, dim=1)
1412
+ cc = torch.cat(c_crossattn, 1)
1413
+ out = self.diffusion_model(xc, t, context=cc)
1414
+ elif self.conditioning_key == 'adm':
1415
+ cc = c_crossattn[0]
1416
+ out = self.diffusion_model(x, t, y=cc)
1417
+ else:
1418
+ raise NotImplementedError()
1419
+
1420
+ return out
1421
+
1422
+
1423
+ class Layout2ImgDiffusion(LatentDiffusion):
1424
+ # TODO: move all layout-specific hacks to this class
1425
+ def __init__(self, cond_stage_key, *args, **kwargs):
1426
+ assert cond_stage_key == 'coordinates_bbox', 'Layout2ImgDiffusion only for cond_stage_key="coordinates_bbox"'
1427
+ super().__init__(cond_stage_key=cond_stage_key, *args, **kwargs)
1428
+
1429
+ def log_images(self, batch, N=8, *args, **kwargs):
1430
+ logs = super().log_images(batch=batch, N=N, *args, **kwargs)
1431
+
1432
+ key = 'train' if self.training else 'validation'
1433
+ dset = self.trainer.datamodule.datasets[key]
1434
+ mapper = dset.conditional_builders[self.cond_stage_key]
1435
+
1436
+ bbox_imgs = []
1437
+ map_fn = lambda catno: dset.get_textual_label(dset.get_category_id(catno))
1438
+ for tknzd_bbox in batch[self.cond_stage_key][:N]:
1439
+ bboximg = mapper.plot(tknzd_bbox.detach().cpu(), map_fn, (256, 256))
1440
+ bbox_imgs.append(bboximg)
1441
+
1442
+ cond_img = torch.stack(bbox_imgs, dim=0)
1443
+ logs['bbox_image'] = cond_img
1444
+ return logs
ldm/models/diffusion/ddpm_audio.py ADDED
@@ -0,0 +1,1262 @@
1
+ """
2
+ wild mixture of
3
+ https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py
4
+ https://github.com/openai/improved-diffusion/blob/e94489283bb876ac1477d5dd7709bbbd2d9902ce/improved_diffusion/gaussian_diffusion.py
5
+ https://github.com/CompVis/taming-transformers
6
+ -- merci
7
+ """
8
+ import os
9
+ import torch
10
+ import torch.nn as nn
11
+ import numpy as np
12
+ import pytorch_lightning as pl
13
+ from torch.optim.lr_scheduler import LambdaLR
14
+ from einops import rearrange, repeat
15
+ from contextlib import contextmanager
16
+ from functools import partial
17
+ from tqdm import tqdm
18
+ from torchvision.utils import make_grid
19
+ from pytorch_lightning.utilities.distributed import rank_zero_only
20
+
21
+ from ldm.util import log_txt_as_img, exists, default, ismap, isimage, mean_flat, count_params, instantiate_from_config
22
+ from ldm.modules.ema import LitEma
23
+ from ldm.modules.distributions.distributions import normal_kl, DiagonalGaussianDistribution
24
+ from ldm.models.autoencoder import VQModelInterface, IdentityFirstStage, AutoencoderKL
25
+ from ldm.modules.diffusionmodules.util import make_beta_schedule, extract_into_tensor, noise_like
26
+ from ldm.models.diffusion.ddim import DDIMSampler
27
+ from ldm.models.diffusion.ddpm import DDPM, disabled_train
28
+ from omegaconf import ListConfig
29
+
30
+ __conditioning_keys__ = {'concat': 'c_concat',
31
+ 'crossattn': 'c_crossattn',
32
+ 'adm': 'y'}
33
+
34
+
35
+ class LatentDiffusion_audio(DDPM):
36
+ """main class"""
37
+ def __init__(self,
38
+ first_stage_config,
39
+ cond_stage_config,
40
+ num_timesteps_cond=None,
41
+ mel_dim=80,
42
+ mel_length=848,
43
+ cond_stage_key="image",
44
+ cond_stage_trainable=False,
45
+ concat_mode=True,
46
+ cond_stage_forward=None,
47
+ conditioning_key=None,
48
+ scale_factor=1.0,
49
+ scale_by_std=False,
50
+ *args, **kwargs):
51
+ self.num_timesteps_cond = default(num_timesteps_cond, 1)
52
+ self.scale_by_std = scale_by_std
53
+ assert self.num_timesteps_cond <= kwargs['timesteps']
54
+ # for backwards compatibility after implementation of DiffusionWrapper
55
+ if conditioning_key is None:
56
+ conditioning_key = 'concat' if concat_mode else 'crossattn'
57
+ if cond_stage_config == '__is_unconditional__':
58
+ conditioning_key = None
59
+ ckpt_path = kwargs.pop("ckpt_path", None)
60
+ ignore_keys = kwargs.pop("ignore_keys", [])
61
+ super().__init__(conditioning_key=conditioning_key, *args, **kwargs)
62
+ self.concat_mode = concat_mode
63
+ self.mel_dim = mel_dim
64
+ self.mel_length = mel_length
65
+ self.cond_stage_trainable = cond_stage_trainable
66
+ self.cond_stage_key = cond_stage_key
67
+ try:
68
+ self.num_downs = len(first_stage_config.params.ddconfig.ch_mult) - 1
69
+ except Exception:
70
+ self.num_downs = 0
71
+ if not scale_by_std:
72
+ self.scale_factor = scale_factor
73
+ else:
74
+ self.register_buffer('scale_factor', torch.tensor(scale_factor))
75
+ self.instantiate_first_stage(first_stage_config)
76
+ self.instantiate_cond_stage(cond_stage_config)
77
+ self.cond_stage_forward = cond_stage_forward
78
+ self.clip_denoised = False
79
+ self.bbox_tokenizer = None
80
+
81
+ self.restarted_from_ckpt = False
82
+ if ckpt_path is not None:
83
+ self.init_from_ckpt(ckpt_path, ignore_keys)
84
+ self.restarted_from_ckpt = True
85
+
86
+ def make_cond_schedule(self, ):
87
+ self.cond_ids = torch.full(size=(self.num_timesteps,), fill_value=self.num_timesteps - 1, dtype=torch.long)
88
+ ids = torch.round(torch.linspace(0, self.num_timesteps - 1, self.num_timesteps_cond)).long()
89
+ self.cond_ids[:self.num_timesteps_cond] = ids
90
+
91
+ @rank_zero_only
92
+ @torch.no_grad()
93
+ def on_train_batch_start(self, batch, batch_idx, dataloader_idx):
94
+ # only for very first batch
95
+ if self.scale_by_std and self.current_epoch == 0 and self.global_step == 0 and batch_idx == 0 and not self.restarted_from_ckpt:
96
+ assert self.scale_factor == 1., 'rather not use custom rescaling and std-rescaling simultaneously'
97
+ # set rescale weight to 1./std of encodings
98
+ print("### USING STD-RESCALING ###")
99
+ x = super().get_input(batch, self.first_stage_key)
100
+ x = x.to(self.device)
101
+ encoder_posterior = self.encode_first_stage(x)
102
+ z = self.get_first_stage_encoding(encoder_posterior).detach()
103
+ del self.scale_factor
104
+ self.register_buffer('scale_factor', 1. / z.flatten().std())
105
+ print(f"setting self.scale_factor to {self.scale_factor}")
106
+ print("### USING STD-RESCALING ###")
107
+
108
+ def register_schedule(self,
109
+ given_betas=None, beta_schedule="linear", timesteps=1000,
110
+ linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
111
+ super().register_schedule(given_betas, beta_schedule, timesteps, linear_start, linear_end, cosine_s)
112
+
113
+ self.shorten_cond_schedule = self.num_timesteps_cond > 1
114
+ if self.shorten_cond_schedule:
115
+ self.make_cond_schedule()
116
+
117
+ def instantiate_first_stage(self, config):
118
+ model = instantiate_from_config(config)
119
+ self.first_stage_model = model.eval()
120
+ self.first_stage_model.train = disabled_train
121
+ for param in self.first_stage_model.parameters():
122
+ param.requires_grad = False
123
+
124
+ def instantiate_cond_stage(self, config):
125
+ if not self.cond_stage_trainable:
126
+ if config == "__is_first_stage__":
127
+ print("Using first stage also as cond stage.")
128
+ self.cond_stage_model = self.first_stage_model
129
+ elif config == "__is_unconditional__":
130
+ print(f"Training {self.__class__.__name__} as an unconditional model.")
131
+ self.cond_stage_model = None
132
+ # self.be_unconditional = True
133
+ else:
134
+ model = instantiate_from_config(config)
135
+ self.cond_stage_model = model.eval()
136
+ self.cond_stage_model.train = disabled_train
137
+ for param in self.cond_stage_model.parameters():
138
+ param.requires_grad = False
139
+ else:
140
+ assert config != '__is_first_stage__'
141
+ assert config != '__is_unconditional__'
142
+ model = instantiate_from_config(config)
143
+ self.cond_stage_model = model
144
+
145
+ def _get_denoise_row_from_list(self, samples, desc='', force_no_decoder_quantization=False):
146
+ denoise_row = []
147
+ for zd in tqdm(samples, desc=desc):
148
+ denoise_row.append(self.decode_first_stage(zd.to(self.device),
149
+ force_not_quantize=force_no_decoder_quantization))
150
+ n_imgs_per_row = len(denoise_row)
151
+ denoise_row = torch.stack(denoise_row) # n_log_step, n_row, C, H, W
152
+ denoise_grid = rearrange(denoise_row, 'n b c h w -> b n c h w')
153
+ denoise_grid = rearrange(denoise_grid, 'b n c h w -> (b n) c h w')
154
+ denoise_grid = make_grid(denoise_grid, nrow=n_imgs_per_row)
155
+ return denoise_grid
156
+
157
+ def get_first_stage_encoding(self, encoder_posterior):
158
+ if isinstance(encoder_posterior, DiagonalGaussianDistribution):
159
+ z = encoder_posterior.sample()
160
+ elif isinstance(encoder_posterior, torch.Tensor):
161
+ z = encoder_posterior
162
+ else:
163
+ raise NotImplementedError(f"encoder_posterior of type '{type(encoder_posterior)}' not yet implemented")
164
+ return self.scale_factor * z
165
+
166
+ def get_learned_conditioning(self, c):
167
+ if self.cond_stage_forward is None:
168
+ if hasattr(self.cond_stage_model, 'encode') and callable(self.cond_stage_model.encode):
169
+ c = self.cond_stage_model.encode(c)
170
+ if isinstance(c, DiagonalGaussianDistribution):
171
+ c = c.mode()
172
+ else:
173
+ c = self.cond_stage_model(c)
174
+ else:
175
+ assert hasattr(self.cond_stage_model, self.cond_stage_forward)
176
+ c = getattr(self.cond_stage_model, self.cond_stage_forward)(c)
177
+ return c
178
+
179
+
180
+ @torch.no_grad()
181
+ def get_unconditional_conditioning(self, batch_size, null_label=None):
182
+ if null_label is not None:
183
+ xc = null_label
184
+ if isinstance(xc, ListConfig):
185
+ xc = list(xc)
186
+ if isinstance(xc, dict) or isinstance(xc, list):
187
+ c = self.get_learned_conditioning(xc)
188
+ else:
189
+ if hasattr(xc, "to"):
190
+ xc = xc.to(self.device)
191
+ c = self.get_learned_conditioning(xc)
192
+ else:
193
+ if self.cond_stage_key in ["class_label", "cls"]:
194
+ xc = self.cond_stage_model.get_unconditional_conditioning(batch_size, device=self.device)
195
+ return self.get_learned_conditioning(xc)
196
+ else:
197
+ raise NotImplementedError("todo")
198
+ if isinstance(c, list): # in case the encoder gives us a list
199
+ for i in range(len(c)):
200
+ c[i] = repeat(c[i], '1 ... -> b ...', b=batch_size).to(self.device)
201
+ else:
202
+ c = repeat(c, '1 ... -> b ...', b=batch_size).to(self.device)
203
+ return c
204
+
205
+ def meshgrid(self, h, w):
206
+ y = torch.arange(0, h).view(h, 1, 1).repeat(1, w, 1)
207
+ x = torch.arange(0, w).view(1, w, 1).repeat(h, 1, 1)
208
+
209
+ arr = torch.cat([y, x], dim=-1)
210
+ return arr
211
+
212
+ def delta_border(self, h, w):
213
+ """
214
+ :param h: height
215
+ :param w: width
216
+ :return: normalized distance to image border,
217
+ with min distance = 0 at the border and max distance = 0.5 at the image center
218
+ """
219
+ lower_right_corner = torch.tensor([h - 1, w - 1]).view(1, 1, 2)
220
+ arr = self.meshgrid(h, w) / lower_right_corner
221
+ dist_left_up = torch.min(arr, dim=-1, keepdims=True)[0]
222
+ dist_right_down = torch.min(1 - arr, dim=-1, keepdims=True)[0]
223
+ edge_dist = torch.min(torch.cat([dist_left_up, dist_right_down], dim=-1), dim=-1)[0]
224
+ return edge_dist
225
+
226
+ def get_weighting(self, h, w, Ly, Lx, device):
227
+ weighting = self.delta_border(h, w)
228
+ weighting = torch.clip(weighting, self.split_input_params["clip_min_weight"],
229
+ self.split_input_params["clip_max_weight"], )
230
+ weighting = weighting.view(1, h * w, 1).repeat(1, 1, Ly * Lx).to(device)
231
+
232
+ if self.split_input_params["tie_braker"]:
233
+ L_weighting = self.delta_border(Ly, Lx)
234
+ L_weighting = torch.clip(L_weighting,
235
+ self.split_input_params["clip_min_tie_weight"],
236
+ self.split_input_params["clip_max_tie_weight"])
237
+
238
+ L_weighting = L_weighting.view(1, 1, Ly * Lx).to(device)
239
+ weighting = weighting * L_weighting
240
+ return weighting
241
+
242
+ def get_fold_unfold(self, x, kernel_size, stride, uf=1, df=1): # todo load once not every time, shorten code
243
+ """
244
+ :param x: img of size (bs, c, h, w)
245
+ :return: n img crops of size (n, bs, c, kernel_size[0], kernel_size[1])
246
+ """
247
+ bs, nc, h, w = x.shape
248
+
249
+ # number of crops in image
250
+ Ly = (h - kernel_size[0]) // stride[0] + 1
251
+ Lx = (w - kernel_size[1]) // stride[1] + 1
252
+
253
+ if uf == 1 and df == 1:
254
+ fold_params = dict(kernel_size=kernel_size, dilation=1, padding=0, stride=stride)
255
+ unfold = torch.nn.Unfold(**fold_params)
256
+
257
+ fold = torch.nn.Fold(output_size=x.shape[2:], **fold_params)
258
+
259
+ weighting = self.get_weighting(kernel_size[0], kernel_size[1], Ly, Lx, x.device).to(x.dtype)
260
+ normalization = fold(weighting).view(1, 1, h, w) # normalizes the overlap
261
+ weighting = weighting.view((1, 1, kernel_size[0], kernel_size[1], Ly * Lx))
262
+
263
+ elif uf > 1 and df == 1:
264
+ fold_params = dict(kernel_size=kernel_size, dilation=1, padding=0, stride=stride)
265
+ unfold = torch.nn.Unfold(**fold_params)
266
+
267
+ fold_params2 = dict(kernel_size=(kernel_size[0] * uf, kernel_size[0] * uf),
268
+ dilation=1, padding=0,
269
+ stride=(stride[0] * uf, stride[1] * uf))
270
+ fold = torch.nn.Fold(output_size=(x.shape[2] * uf, x.shape[3] * uf), **fold_params2)
271
+
272
+ weighting = self.get_weighting(kernel_size[0] * uf, kernel_size[1] * uf, Ly, Lx, x.device).to(x.dtype)
273
+ normalization = fold(weighting).view(1, 1, h * uf, w * uf) # normalizes the overlap
274
+ weighting = weighting.view((1, 1, kernel_size[0] * uf, kernel_size[1] * uf, Ly * Lx))
275
+
276
+ elif df > 1 and uf == 1:
277
+ fold_params = dict(kernel_size=kernel_size, dilation=1, padding=0, stride=stride)
278
+ unfold = torch.nn.Unfold(**fold_params)
279
+
280
+ fold_params2 = dict(kernel_size=(kernel_size[0] // df, kernel_size[0] // df),
281
+ dilation=1, padding=0,
282
+ stride=(stride[0] // df, stride[1] // df))
283
+ fold = torch.nn.Fold(output_size=(x.shape[2] // df, x.shape[3] // df), **fold_params2)
284
+
285
+ weighting = self.get_weighting(kernel_size[0] // df, kernel_size[1] // df, Ly, Lx, x.device).to(x.dtype)
286
+ normalization = fold(weighting).view(1, 1, h // df, w // df) # normalizes the overlap
287
+ weighting = weighting.view((1, 1, kernel_size[0] // df, kernel_size[1] // df, Ly * Lx))
288
+
289
+ else:
290
+ raise NotImplementedError
291
+
292
+ return fold, unfold, normalization, weighting
293
+
294
+ @torch.no_grad()
295
+ def get_input(self, batch, k, return_first_stage_outputs=False, force_c_encode=False,
296
+ cond_key=None, return_original_cond=False, bs=None):
297
+ x = super().get_input(batch, k)
298
+ if bs is not None:
299
+ x = x[:bs]
300
+ x = x.to(self.device)
301
+ encoder_posterior = self.encode_first_stage(x)
302
+ z = self.get_first_stage_encoding(encoder_posterior).detach()
303
+
304
+ if self.model.conditioning_key is not None:
305
+ if cond_key is None:
306
+ cond_key = self.cond_stage_key
307
+ if cond_key != self.first_stage_key:
308
+ if cond_key in ['caption', 'coordinates_bbox']:
309
+ xc = batch[cond_key]
310
+ elif cond_key == 'class_label':
311
+ xc = batch
312
+ else:
313
+ xc = super().get_input(batch, cond_key).to(self.device)
314
+ else:
315
+ xc = x
316
+ if not self.cond_stage_trainable or force_c_encode:
317
+ if isinstance(xc, dict) or isinstance(xc, list):
318
+ # import pudb; pudb.set_trace()
319
+ c = self.get_learned_conditioning(xc)
320
+ else:
321
+ c = self.get_learned_conditioning(xc.to(self.device))
322
+ else:
323
+ c = xc
324
+ if bs is not None:
325
+ c = c[:bs]
326
+ # Testing #
327
+ if cond_key == 'masked_image':
328
+ mask = super().get_input(batch, "mask")
329
+ cc = torch.nn.functional.interpolate(mask, size=c.shape[-2:]) # [B, 1, 10, 106]
330
+ c = torch.cat((c, cc), dim=1) # [B, 5, 10, 106]
331
+ # Testing #
332
+ if self.use_positional_encodings:
333
+ pos_x, pos_y = self.compute_latent_shifts(batch)
334
+ ckey = __conditioning_keys__[self.model.conditioning_key]
335
+ c = {ckey: c, 'pos_x': pos_x, 'pos_y': pos_y}
336
+
337
+ else:
338
+ c = None
339
+ xc = None
340
+ if self.use_positional_encodings:
341
+ pos_x, pos_y = self.compute_latent_shifts(batch)
342
+ c = {'pos_x': pos_x, 'pos_y': pos_y}
343
+ out = [z, c]
344
+ if return_first_stage_outputs:
345
+ xrec = self.decode_first_stage(z)
346
+ out.extend([x, xrec])
347
+ if return_original_cond:
348
+ out.append(xc)
349
+ return out
350
+
351
+ @torch.no_grad()
352
+ def decode_first_stage(self, z, predict_cids=False, force_not_quantize=False):
353
+ if predict_cids:
354
+ if z.dim() == 4:
355
+ z = torch.argmax(z.exp(), dim=1).long()
356
+ z = self.first_stage_model.quantize.get_codebook_entry(z, shape=None)
357
+ z = rearrange(z, 'b h w c -> b c h w').contiguous()
358
+
359
+ z = 1. / self.scale_factor * z
360
+
361
+ if hasattr(self, "split_input_params"):
362
+ if self.split_input_params["patch_distributed_vq"]:
363
+ ks = self.split_input_params["ks"] # eg. (128, 128)
364
+ stride = self.split_input_params["stride"] # eg. (64, 64)
365
+ uf = self.split_input_params["vqf"]
366
+ bs, nc, h, w = z.shape
367
+ if ks[0] > h or ks[1] > w:
368
+ ks = (min(ks[0], h), min(ks[1], w))
369
+ print("reducing Kernel")
370
+
371
+ if stride[0] > h or stride[1] > w:
372
+ stride = (min(stride[0], h), min(stride[1], w))
373
+ print("reducing stride")
374
+
375
+ fold, unfold, normalization, weighting = self.get_fold_unfold(z, ks, stride, uf=uf)
376
+
377
+ z = unfold(z) # (bn, nc * prod(**ks), L)
378
+ # 1. Reshape to img shape
379
+ z = z.view((z.shape[0], -1, ks[0], ks[1], z.shape[-1])) # (bn, nc, ks[0], ks[1], L )
380
+
381
+ # 2. apply model loop over last dim
382
+ if isinstance(self.first_stage_model, VQModelInterface):
383
+ output_list = [self.first_stage_model.decode(z[:, :, :, :, i],
384
+ force_not_quantize=predict_cids or force_not_quantize)
385
+ for i in range(z.shape[-1])]
386
+ else:
387
+
388
+ output_list = [self.first_stage_model.decode(z[:, :, :, :, i])
389
+ for i in range(z.shape[-1])]
390
+
391
+ o = torch.stack(output_list, axis=-1) # # (bn, nc, ks[0], ks[1], L)
392
+ o = o * weighting
393
+ # Reverse 1. reshape to img shape
394
+ o = o.view((o.shape[0], -1, o.shape[-1])) # (bn, nc * ks[0] * ks[1], L)
395
+ # stitch crops together
396
+ decoded = fold(o)
397
+ decoded = decoded / normalization # norm is shape (1, 1, h, w)
398
+ return decoded
399
+ else:
400
+ if isinstance(self.first_stage_model, VQModelInterface):
401
+ return self.first_stage_model.decode(z, force_not_quantize=predict_cids or force_not_quantize)
402
+ else:
403
+ return self.first_stage_model.decode(z)
404
+
405
+ else:
406
+ if isinstance(self.first_stage_model, VQModelInterface):
407
+ return self.first_stage_model.decode(z, force_not_quantize=predict_cids or force_not_quantize)
408
+ else:
409
+ return self.first_stage_model.decode(z)
410
+
411
+ # same as above but without decorator
412
+ def differentiable_decode_first_stage(self, z, predict_cids=False, force_not_quantize=False):
413
+ if predict_cids:
414
+ if z.dim() == 4:
415
+ z = torch.argmax(z.exp(), dim=1).long()
416
+ z = self.first_stage_model.quantize.get_codebook_entry(z, shape=None)
417
+ z = rearrange(z, 'b h w c -> b c h w').contiguous()
418
+
419
+ z = 1. / self.scale_factor * z
420
+
421
+ if hasattr(self, "split_input_params"):
422
+ if self.split_input_params["patch_distributed_vq"]:
423
+ ks = self.split_input_params["ks"] # eg. (128, 128)
424
+ stride = self.split_input_params["stride"] # eg. (64, 64)
425
+ uf = self.split_input_params["vqf"]
426
+ bs, nc, h, w = z.shape
427
+ if ks[0] > h or ks[1] > w:
428
+ ks = (min(ks[0], h), min(ks[1], w))
429
+ print("reducing Kernel")
430
+
431
+ if stride[0] > h or stride[1] > w:
432
+ stride = (min(stride[0], h), min(stride[1], w))
433
+ print("reducing stride")
434
+
435
+ fold, unfold, normalization, weighting = self.get_fold_unfold(z, ks, stride, uf=uf)
436
+
437
+ z = unfold(z) # (bn, nc * prod(**ks), L)
438
+ # 1. Reshape to img shape
439
+ z = z.view((z.shape[0], -1, ks[0], ks[1], z.shape[-1])) # (bn, nc, ks[0], ks[1], L )
440
+
441
+ # 2. apply model loop over last dim
442
+ if isinstance(self.first_stage_model, VQModelInterface):
443
+ output_list = [self.first_stage_model.decode(z[:, :, :, :, i],
444
+ force_not_quantize=predict_cids or force_not_quantize)
445
+ for i in range(z.shape[-1])]
446
+ else:
447
+
448
+ output_list = [self.first_stage_model.decode(z[:, :, :, :, i])
449
+ for i in range(z.shape[-1])]
450
+
451
+ o = torch.stack(output_list, axis=-1) # # (bn, nc, ks[0], ks[1], L)
452
+ o = o * weighting
453
+ # Reverse 1. reshape to img shape
454
+ o = o.view((o.shape[0], -1, o.shape[-1])) # (bn, nc * ks[0] * ks[1], L)
455
+ # stitch crops together
456
+ decoded = fold(o)
457
+ decoded = decoded / normalization # norm is shape (1, 1, h, w)
458
+ return decoded
459
+ else:
460
+ if isinstance(self.first_stage_model, VQModelInterface):
461
+ return self.first_stage_model.decode(z, force_not_quantize=predict_cids or force_not_quantize)
462
+ else:
463
+ return self.first_stage_model.decode(z)
464
+
465
+ else:
466
+ if isinstance(self.first_stage_model, VQModelInterface):
467
+ return self.first_stage_model.decode(z, force_not_quantize=predict_cids or force_not_quantize)
468
+ else:
469
+ return self.first_stage_model.decode(z)
470
+
471
+ @torch.no_grad()
472
+ def encode_first_stage(self, x):
473
+ if hasattr(self, "split_input_params"):
474
+ if self.split_input_params["patch_distributed_vq"]:
475
+ ks = self.split_input_params["ks"] # eg. (128, 128)
476
+ stride = self.split_input_params["stride"] # eg. (64, 64)
477
+ df = self.split_input_params["vqf"]
478
+ self.split_input_params['original_image_size'] = x.shape[-2:]
479
+ bs, nc, h, w = x.shape
480
+ if ks[0] > h or ks[1] > w:
481
+ ks = (min(ks[0], h), min(ks[1], w))
482
+ print("reducing Kernel")
483
+
484
+ if stride[0] > h or stride[1] > w:
485
+ stride = (min(stride[0], h), min(stride[1], w))
486
+ print("reducing stride")
487
+
488
+ fold, unfold, normalization, weighting = self.get_fold_unfold(x, ks, stride, df=df)
489
+ z = unfold(x) # (bn, nc * prod(**ks), L)
490
+ # Reshape to img shape
491
+ z = z.view((z.shape[0], -1, ks[0], ks[1], z.shape[-1])) # (bn, nc, ks[0], ks[1], L )
492
+
493
+ output_list = [self.first_stage_model.encode(z[:, :, :, :, i])
494
+ for i in range(z.shape[-1])]
495
+
496
+ o = torch.stack(output_list, axis=-1)
497
+ o = o * weighting
498
+
499
+ # Reverse reshape to img shape
500
+ o = o.view((o.shape[0], -1, o.shape[-1])) # (bn, nc * ks[0] * ks[1], L)
501
+ # stitch crops together
502
+ decoded = fold(o)
503
+ decoded = decoded / normalization
504
+ return decoded
505
+
506
+ else:
507
+ return self.first_stage_model.encode(x)
508
+ else:
509
+ return self.first_stage_model.encode(x)
510
+
511
+ def shared_step(self, batch, **kwargs):
512
+ x, c = self.get_input(batch, self.first_stage_key)
513
+ loss = self(x, c)
514
+ return loss
515
+
516
+ def test_step(self,batch,batch_idx):
517
+ cond = batch[self.cond_stage_key] * self.test_repeat
518
+ cond = self.get_learned_conditioning(cond) # c: string -> [B, T, Context_dim]
519
+ batch_size = len(cond)
520
+ enc_emb = self.sample(cond, batch_size, timesteps=self.test_numsteps) # shape: [batch_size, self.channels, self.mel_dim, self.mel_length]
521
+ xrec = self.decode_first_stage(enc_emb)
522
+ reconstructions = (xrec + 1)/2 # to mel scale
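+ # (xrec + 1) / 2 maps the decoder output from [-1, 1] back to the [0, 1] normalized mel range before saving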
523
+ test_ckpt_path = os.path.basename(self.trainer.tested_ckpt_path)
524
+ savedir = os.path.join(self.trainer.log_dir,f'output_imgs_{test_ckpt_path}','fake_class')
525
+ if not os.path.exists(savedir):
526
+ os.makedirs(savedir)
527
+
528
+ file_names = batch['f_name']
529
+ nfiles = len(file_names)
530
+ reconstructions = reconstructions.cpu().numpy().squeeze(1) # squeeze the channel dim
531
+ for k in range(reconstructions.shape[0]):
532
+ b,repeat = k % nfiles, k // nfiles
533
+ vname_num_split_index = file_names[b].rfind('_')# file_names[b]:video_name+'_'+num
534
+ v_n,num = file_names[b][:vname_num_split_index],file_names[b][vname_num_split_index+1:]
535
+ save_img_path = os.path.join(savedir,f'{v_n}_sample_{num}_{repeat}.npy')# the num_th caption, the repeat_th repetition
536
+ np.save(save_img_path,reconstructions[b])
537
+
538
+ return None
539
+
540
+ def forward(self, x, c, *args, **kwargs):
541
+ t = torch.randint(0, self.num_timesteps, (x.shape[0],), device=self.device).long()
542
+ if self.model.conditioning_key is not None:
543
+ assert c is not None
544
+ if self.cond_stage_trainable:
545
+ c = self.get_learned_conditioning(c) # c: string -> [B, T, Context_dim]
546
+ if self.shorten_cond_schedule: # TODO: drop this option
547
+ tc = self.cond_ids[t].to(self.device)
548
+ c = self.q_sample(x_start=c, t=tc, noise=torch.randn_like(c.float()))
549
+ return self.p_losses(x, c, t, *args, **kwargs)
550
+
551
+ def _rescale_annotations(self, bboxes, crop_coordinates): # TODO: move to dataset
552
+ def rescale_bbox(bbox):
553
+ x0 = clamp((bbox[0] - crop_coordinates[0]) / crop_coordinates[2])
554
+ y0 = clamp((bbox[1] - crop_coordinates[1]) / crop_coordinates[3])
555
+ w = min(bbox[2] / crop_coordinates[2], 1 - x0)
556
+ h = min(bbox[3] / crop_coordinates[3], 1 - y0)
557
+ return x0, y0, w, h
558
+
559
+ return [rescale_bbox(b) for b in bboxes]
560
+
561
+ def apply_model(self, x_noisy, t, cond, return_ids=False):
562
+
563
+ if isinstance(cond, dict):
564
+ # hybrid case, cond is expected to be a dict
565
+ pass
566
+ else:
567
+ if not isinstance(cond, list):
568
+ cond = [cond]
569
+ key = 'c_concat' if self.model.conditioning_key == 'concat' else 'c_crossattn'
570
+ cond = {key: cond}
571
+
572
+ if hasattr(self, "split_input_params"):
573
+ assert len(cond) == 1 # todo can only deal with one conditioning atm
574
+ assert not return_ids
575
+ ks = self.split_input_params["ks"] # eg. (128, 128)
576
+ stride = self.split_input_params["stride"] # eg. (64, 64)
577
+
578
+ h, w = x_noisy.shape[-2:]
579
+
580
+ fold, unfold, normalization, weighting = self.get_fold_unfold(x_noisy, ks, stride)
581
+
582
+ z = unfold(x_noisy) # (bn, nc * prod(**ks), L)
583
+ # Reshape to img shape
584
+ z = z.view((z.shape[0], -1, ks[0], ks[1], z.shape[-1])) # (bn, nc, ks[0], ks[1], L )
585
+ z_list = [z[:, :, :, :, i] for i in range(z.shape[-1])]
586
+
587
+ if self.cond_stage_key in ["image", "LR_image", "segmentation",
588
+ 'bbox_img'] and self.model.conditioning_key: # todo check for completeness
589
+ c_key = next(iter(cond.keys())) # get key
590
+ c = next(iter(cond.values())) # get value
591
+ assert (len(c) == 1) # todo extend to list with more than one elem
592
+ c = c[0] # get element
593
+
594
+ c = unfold(c)
595
+ c = c.view((c.shape[0], -1, ks[0], ks[1], c.shape[-1])) # (bn, nc, ks[0], ks[1], L )
596
+
597
+ cond_list = [{c_key: [c[:, :, :, :, i]]} for i in range(c.shape[-1])]
598
+
599
+ elif self.cond_stage_key == 'coordinates_bbox':
600
+ assert 'original_image_size' in self.split_input_params, 'BoundingBoxRescaling is missing original_image_size'
601
+
602
+ # assuming padding of unfold is always 0 and its dilation is always 1
603
+ n_patches_per_row = int((w - ks[0]) / stride[0] + 1)
604
+ full_img_h, full_img_w = self.split_input_params['original_image_size']
605
+ # as we are operating on latents, we need the factor from the original image size to the
606
+ # spatial latent size to properly rescale the crops for regenerating the bbox annotations
607
+ num_downs = self.first_stage_model.encoder.num_resolutions - 1
608
+ rescale_latent = 2 ** (num_downs)
609
+
610
+ # get top-left positions of patches conforming to the bbox tokenizer; therefore we
611
+ # need to rescale the tl patch coordinates to be in between (0,1)
612
+ tl_patch_coordinates = [(rescale_latent * stride[0] * (patch_nr % n_patches_per_row) / full_img_w,
613
+ rescale_latent * stride[1] * (patch_nr // n_patches_per_row) / full_img_h)
614
+ for patch_nr in range(z.shape[-1])]
615
+
616
+ # patch_limits are tl_coord, width and height coordinates as (x_tl, y_tl, h, w)
617
+ patch_limits = [(x_tl, y_tl,
618
+ rescale_latent * ks[0] / full_img_w,
619
+ rescale_latent * ks[1] / full_img_h) for x_tl, y_tl in tl_patch_coordinates]
620
+ # patch_values = [(np.arange(x_tl,min(x_tl+ks, 1.)),np.arange(y_tl,min(y_tl+ks, 1.))) for x_tl, y_tl in tl_patch_coordinates]
621
+
622
+ # tokenize crop coordinates for the bounding boxes of the respective patches
623
+ patch_limits_tknzd = [torch.LongTensor(self.bbox_tokenizer._crop_encoder(bbox))[None].to(self.device)
624
+ for bbox in patch_limits] # list of length l with tensors of shape (1, 2)
625
+ print(patch_limits_tknzd[0].shape)
626
+ # cut tknzd crop position from conditioning
627
+ assert isinstance(cond, dict), 'cond must be dict to be fed into model'
628
+ cut_cond = cond['c_crossattn'][0][..., :-2].to(self.device)
629
+ print(cut_cond.shape)
630
+
631
+ adapted_cond = torch.stack([torch.cat([cut_cond, p], dim=1) for p in patch_limits_tknzd])
632
+ adapted_cond = rearrange(adapted_cond, 'l b n -> (l b) n')
633
+ print(adapted_cond.shape)
634
+ adapted_cond = self.get_learned_conditioning(adapted_cond)
635
+ print(adapted_cond.shape)
636
+ adapted_cond = rearrange(adapted_cond, '(l b) n d -> l b n d', l=z.shape[-1])
637
+ print(adapted_cond.shape)
638
+
639
+ cond_list = [{'c_crossattn': [e]} for e in adapted_cond]
640
+
641
+ else:
642
+ cond_list = [cond for i in range(z.shape[-1])] # Todo make this more efficient
643
+
644
+ # apply model by loop over crops
645
+ output_list = [self.model(z_list[i], t, **cond_list[i]) for i in range(z.shape[-1])]
646
+ assert not isinstance(output_list[0],
647
+ tuple) # todo can't deal with multiple model outputs; check this never happens
648
+
649
+ o = torch.stack(output_list, axis=-1)
650
+ o = o * weighting
651
+ # Reverse reshape to img shape
652
+ o = o.view((o.shape[0], -1, o.shape[-1])) # (bn, nc * ks[0] * ks[1], L)
653
+ # stitch crops together
654
+ x_recon = fold(o) / normalization
655
+
656
+ else:
657
+ x_recon = self.model(x_noisy, t, **cond)
658
+
659
+ if isinstance(x_recon, tuple) and not return_ids:
660
+ return x_recon[0]
661
+ else:
662
+ return x_recon
663
+
664
+ def _predict_eps_from_xstart(self, x_t, t, pred_xstart):
665
+ return (extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - pred_xstart) / \
666
+ extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)
667
+
668
+ def _prior_bpd(self, x_start):
669
+ """
670
+ Get the prior KL term for the variational lower-bound, measured in
671
+ bits-per-dim.
672
+ This term can't be optimized, as it only depends on the encoder.
673
+ :param x_start: the [N x C x ...] tensor of inputs.
674
+ :return: a batch of [N] KL values (in bits), one per batch element.
675
+ """
676
+ batch_size = x_start.shape[0]
677
+ t = torch.tensor([self.num_timesteps - 1] * batch_size, device=x_start.device)
678
+ qt_mean, _, qt_log_variance = self.q_mean_variance(x_start, t)
679
+ kl_prior = normal_kl(mean1=qt_mean, logvar1=qt_log_variance, mean2=0.0, logvar2=0.0)
680
+ return mean_flat(kl_prior) / np.log(2.0)
681
+
682
+ def p_losses(self, x_start, cond, t, noise=None):
683
+ noise = default(noise, lambda: torch.randn_like(x_start))
684
+ x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise)
685
+ model_output = self.apply_model(x_noisy, t, cond)
686
+
687
+ loss_dict = {}
688
+ prefix = 'train' if self.training else 'val'
689
+
690
+ if self.parameterization == "x0":
691
+ target = x_start
692
+ elif self.parameterization == "eps":
693
+ target = noise
694
+ else:
695
+ raise NotImplementedError()
696
+
697
+ loss_simple = self.get_loss(model_output, target, mean=False).mean([1, 2, 3])
698
+ loss_dict.update({f'{prefix}/loss_simple': loss_simple.mean()})
699
+
700
+ logvar_t = self.logvar[t].to(self.device)
701
+ loss = loss_simple / torch.exp(logvar_t) + logvar_t
702
+ # loss = loss_simple / torch.exp(self.logvar) + self.logvar
703
+ if self.learn_logvar:
704
+ loss_dict.update({f'{prefix}/loss_gamma': loss.mean()})
705
+ loss_dict.update({'logvar': self.logvar.data.mean()})
706
+
707
+ loss = self.l_simple_weight * loss.mean()
708
+
709
+ loss_vlb = self.get_loss(model_output, target, mean=False).mean(dim=(1, 2, 3))
710
+ loss_vlb = (self.lvlb_weights[t] * loss_vlb).mean()
711
+ loss_dict.update({f'{prefix}/loss_vlb': loss_vlb})
712
+ loss += (self.original_elbo_weight * loss_vlb)
713
+ loss_dict.update({f'{prefix}/loss': loss})
714
+
715
+ return loss, loss_dict
716
+
717
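Spelled out (for the default L2 choice of get_loss, with $\epsilon_\theta$ the network output, $\gamma_t = \texttt{logvar}[t]$, $\lambda_t = \texttt{lvlb\_weights}[t]$, $w_{\text{simple}} = \texttt{l\_simple\_weight}$ and $w_{\text{elbo}} = \texttt{original\_elbo\_weight}$), the objective assembled by p_losses above is:

```latex
L(x_0, c) = w_{\text{simple}}\,\mathbb{E}_{t,\epsilon}\!\left[\frac{L_{\text{simple}}(t)}{e^{\gamma_t}} + \gamma_t\right]
          + w_{\text{elbo}}\,\mathbb{E}_{t,\epsilon}\!\left[\lambda_t\, L_{\text{simple}}(t)\right],
\qquad
L_{\text{simple}}(t) = \big\lVert \mathrm{target} - \epsilon_\theta(x_t, t, c)\big\rVert^2
\;\text{(mean over C, H, W)},
\qquad
\mathrm{target} = \begin{cases} \epsilon & \text{parameterization}=\texttt{eps}\\ x_0 & \text{parameterization}=\texttt{x0}\end{cases}
```

where $x_t$ is produced by q_sample$(x_0, t, \epsilon)$.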
+ def p_mean_variance(self, x, c, t, clip_denoised: bool, return_codebook_ids=False, quantize_denoised=False,
718
+ return_x0=False, score_corrector=None, corrector_kwargs=None):
719
+ t_in = t
720
+ model_out = self.apply_model(x, t_in, c, return_ids=return_codebook_ids)
721
+
722
+ if score_corrector is not None:
723
+ assert self.parameterization == "eps"
724
+ model_out = score_corrector.modify_score(self, model_out, x, t, c, **corrector_kwargs)
725
+
726
+ if return_codebook_ids:
727
+ model_out, logits = model_out
728
+
729
+ if self.parameterization == "eps":
730
+ x_recon = self.predict_start_from_noise(x, t=t, noise=model_out)
731
+ elif self.parameterization == "x0":
732
+ x_recon = model_out
733
+ else:
734
+ raise NotImplementedError()
735
+
736
+ if clip_denoised:
737
+ x_recon.clamp_(-1., 1.)
738
+ if quantize_denoised:
739
+ x_recon, _, [_, _, indices] = self.first_stage_model.quantize(x_recon)
740
+ model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t)
741
+ if return_codebook_ids:
742
+ return model_mean, posterior_variance, posterior_log_variance, logits
743
+ elif return_x0:
744
+ return model_mean, posterior_variance, posterior_log_variance, x_recon
745
+ else:
746
+ return model_mean, posterior_variance, posterior_log_variance
747
+
748
+ @torch.no_grad()
749
+ def p_sample(self, x, c, t, clip_denoised=False, repeat_noise=False,
750
+ return_codebook_ids=False, quantize_denoised=False, return_x0=False,
751
+ temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None):
752
+ b, *_, device = *x.shape, x.device
753
+ outputs = self.p_mean_variance(x=x, c=c, t=t, clip_denoised=clip_denoised,
754
+ return_codebook_ids=return_codebook_ids,
755
+ quantize_denoised=quantize_denoised,
756
+ return_x0=return_x0,
757
+ score_corrector=score_corrector, corrector_kwargs=corrector_kwargs)
758
+ if return_codebook_ids:
759
+ raise DeprecationWarning("Support dropped.")
760
+ model_mean, _, model_log_variance, logits = outputs
761
+ elif return_x0:
762
+ model_mean, _, model_log_variance, x0 = outputs
763
+ else:
764
+ model_mean, _, model_log_variance = outputs
765
+
766
+ noise = noise_like(x.shape, device, repeat_noise) * temperature
767
+ if noise_dropout > 0.:
768
+ noise = torch.nn.functional.dropout(noise, p=noise_dropout)
769
+ # no noise when t == 0
770
+ nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
771
+
772
+ if return_codebook_ids:
773
+ return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise, logits.argmax(dim=1)
774
+ if return_x0:
775
+ return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise, x0
776
+ else:
777
+ return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
778
+
779
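In formula form, each call to p_sample above draws (for the default arguments, with $\tau$ the temperature and $\tilde\mu_\theta$, $\log\tilde\sigma_t^2$ the posterior mean and log-variance returned by p_mean_variance):

```latex
x_{t-1} = \tilde\mu_\theta(x_t, c, t) \;+\; \mathbb{1}[t > 0]\;\tau\,
\exp\!\big(\tfrac{1}{2}\log\tilde\sigma_t^2\big)\,\epsilon,
\qquad \epsilon \sim \mathcal{N}(0, I).
```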
+ @torch.no_grad()
780
+ def progressive_denoising(self, cond, shape, verbose=True, callback=None, quantize_denoised=False,
781
+ img_callback=None, mask=None, x0=None, temperature=1., noise_dropout=0.,
782
+ score_corrector=None, corrector_kwargs=None, batch_size=None, x_T=None, start_T=None,
783
+ log_every_t=None):
784
+ if not log_every_t:
785
+ log_every_t = self.log_every_t
786
+ timesteps = self.num_timesteps
787
+ if batch_size is not None:
788
+ b = batch_size
789
+ shape = [batch_size] + list(shape)
790
+ else:
791
+ b = batch_size = shape[0]
792
+ if x_T is None:
793
+ img = torch.randn(shape, device=self.device)
794
+ else:
795
+ img = x_T
796
+ intermediates = []
797
+ if cond is not None:
798
+ if isinstance(cond, dict):
799
+ cond = {key: cond[key][:batch_size] if not isinstance(cond[key], list) else
800
+ list(map(lambda x: x[:batch_size], cond[key])) for key in cond}
801
+ else:
802
+ cond = [c[:batch_size] for c in cond] if isinstance(cond, list) else cond[:batch_size]
803
+
804
+ if start_T is not None:
805
+ timesteps = min(timesteps, start_T)
806
+ iterator = tqdm(reversed(range(0, timesteps)), desc='Progressive Generation',
807
+ total=timesteps) if verbose else reversed(
808
+ range(0, timesteps))
809
+ if type(temperature) == float:
810
+ temperature = [temperature] * timesteps
811
+
812
+ for i in iterator:
813
+ ts = torch.full((b,), i, device=self.device, dtype=torch.long)
814
+ if self.shorten_cond_schedule:
815
+ assert self.model.conditioning_key != 'hybrid'
816
+ tc = self.cond_ids[ts].to(cond.device)
817
+ cond = self.q_sample(x_start=cond, t=tc, noise=torch.randn_like(cond))
818
+
819
+ img, x0_partial = self.p_sample(img, cond, ts,
820
+ clip_denoised=self.clip_denoised,
821
+ quantize_denoised=quantize_denoised, return_x0=True,
822
+ temperature=temperature[i], noise_dropout=noise_dropout,
823
+ score_corrector=score_corrector, corrector_kwargs=corrector_kwargs)
824
+ if mask is not None:
825
+ assert x0 is not None
826
+ img_orig = self.q_sample(x0, ts)
827
+ img = img_orig * mask + (1. - mask) * img
828
+
829
+ if i % log_every_t == 0 or i == timesteps - 1:
830
+ intermediates.append(x0_partial)
831
+ if callback: callback(i)
832
+ if img_callback: img_callback(img, i)
833
+ return img, intermediates
834
+
835
+ @torch.no_grad()
836
+ def p_sample_loop(self, cond, shape, return_intermediates=False,
837
+ x_T=None, verbose=True, callback=None, timesteps=None, quantize_denoised=False,
838
+ mask=None, x0=None, img_callback=None, start_T=None,
839
+ log_every_t=None):
840
+
841
+ if not log_every_t:
842
+ log_every_t = self.log_every_t
843
+ device = self.betas.device
844
+ b = shape[0]
845
+ if x_T is None:
846
+ img = torch.randn(shape, device=device)
847
+ else:
848
+ img = x_T
849
+
850
+ intermediates = [img]
851
+ if timesteps is None:
852
+ timesteps = self.num_timesteps
853
+
854
+ if start_T is not None:
855
+ timesteps = min(timesteps, start_T)
856
+ iterator = tqdm(reversed(range(0, timesteps)), desc='Sampling t', total=timesteps) if verbose else reversed(
857
+ range(0, timesteps))
858
+
859
+ if mask is not None:
860
+ assert x0 is not None
861
+ assert x0.shape[2:3] == mask.shape[2:3] # spatial size has to match
862
+
863
+ for i in iterator:
864
+ ts = torch.full((b,), i, device=device, dtype=torch.long)
865
+ if self.shorten_cond_schedule:
866
+ assert self.model.conditioning_key != 'hybrid'
867
+ tc = self.cond_ids[ts].to(cond.device)
868
+ cond = self.q_sample(x_start=cond, t=tc, noise=torch.randn_like(cond))
869
+
870
+ img = self.p_sample(img, cond, ts,
871
+ clip_denoised=self.clip_denoised,
872
+ quantize_denoised=quantize_denoised)
873
+ if mask is not None:
874
+ img_orig = self.q_sample(x0, ts)
875
+ img = img_orig * mask + (1. - mask) * img
876
+
877
+ if i % log_every_t == 0 or i == timesteps - 1:
878
+ intermediates.append(img)
879
+ if callback: callback(i)
880
+ if img_callback: img_callback(img, i)
881
+
882
+ if return_intermediates:
883
+ return img, intermediates
884
+ return img
885
+
886
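The loop above is the standard DDPM ancestral sampler. A minimal, self-contained sketch of the same recursion (plain PyTorch, a dummy noise predictor standing in for apply_model, no conditioning, masking or logging) might look like this:

```python
import torch

T = 50
betas = torch.linspace(1e-4, 2e-2, T)            # cf. the linear beta schedule
alphas = 1.0 - betas
alphas_cumprod = torch.cumprod(alphas, dim=0)

def dummy_eps_model(x, t):                       # stand-in for apply_model(x, t, cond)
    return torch.zeros_like(x)

x = torch.randn(4, 1, 8, 8)                      # start from Gaussian noise (x_T)
for i in reversed(range(T)):
    t = torch.full((x.shape[0],), i, dtype=torch.long)
    a_t = alphas_cumprod[t].view(-1, 1, 1, 1)
    a_prev = alphas_cumprod[t - 1].view(-1, 1, 1, 1) if i > 0 else torch.ones_like(a_t)
    beta_t = betas[t].view(-1, 1, 1, 1)

    eps = dummy_eps_model(x, t)
    x0_pred = (x - (1 - a_t).sqrt() * eps) / a_t.sqrt()        # predict_start_from_noise
    # posterior q(x_{t-1} | x_t, x_0), cf. q_posterior
    mean = (a_prev.sqrt() * beta_t / (1 - a_t)) * x0_pred \
         + (alphas[t].view(-1, 1, 1, 1).sqrt() * (1 - a_prev) / (1 - a_t)) * x
    var = beta_t * (1 - a_prev) / (1 - a_t)
    noise = torch.randn_like(x) if i > 0 else torch.zeros_like(x)   # no noise at t == 0
    x = mean + var.sqrt() * noise
```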
+ @torch.no_grad()
887
+ def sample(self, cond, batch_size=16, return_intermediates=False, x_T=None,
888
+ verbose=True, timesteps=None, quantize_denoised=False,
889
+ mask=None, x0=None, shape=None,**kwargs):
890
+ if shape is None:
891
+ shape = (batch_size, self.channels, self.mel_dim, self.mel_length)
892
+ if cond is not None:
893
+ if isinstance(cond, dict):
894
+ cond = {key: cond[key][:batch_size] if not isinstance(cond[key], list) else
895
+ list(map(lambda x: x[:batch_size], cond[key])) for key in cond}
896
+ else:
897
+ cond = [c[:batch_size] for c in cond] if isinstance(cond, list) else cond[:batch_size]
898
+ return self.p_sample_loop(cond,
899
+ shape,
900
+ return_intermediates=return_intermediates, x_T=x_T,
901
+ verbose=verbose, timesteps=timesteps, quantize_denoised=quantize_denoised,
902
+ mask=mask, x0=x0)
903
+
904
+ @torch.no_grad()
905
+ def sample_log(self,cond,batch_size,ddim, ddim_steps,**kwargs):
906
+
907
+ if ddim:
908
+ ddim_sampler = DDIMSampler(self)
909
+ shape = (self.channels, self.mel_dim, self.mel_length)
910
+ samples, intermediates =ddim_sampler.sample(ddim_steps,batch_size,
911
+ shape,cond,verbose=False,**kwargs)
912
+
913
+ else:
914
+ samples, intermediates = self.sample(cond=cond, batch_size=batch_size,
915
+ return_intermediates=True,**kwargs)
916
+
917
+ return samples, intermediates
918
+
919
+
920
+ @torch.no_grad()
921
+ def log_images(self, batch, N=8, n_row=4, sample=True, ddim_steps=200, ddim_eta=1., return_keys=None,
922
+ quantize_denoised=True, inpaint=True, plot_denoise_rows=False, plot_progressive_rows=True,
923
+ plot_diffusion_rows=True, **kwargs):
924
+
925
+ use_ddim = ddim_steps is not None
926
+
927
+ log = dict()
928
+ z, c, x, xrec, xc = self.get_input(batch, self.first_stage_key,
929
+ return_first_stage_outputs=True,
930
+ force_c_encode=True,
931
+ return_original_cond=True,
932
+ bs=N)
933
+ N = min(x.shape[0], N)
934
+ n_row = min(x.shape[0], n_row)
935
+ log["inputs"] = x
936
+ log["reconstruction"] = xrec
937
+ if self.model.conditioning_key is not None:
938
+ if hasattr(self.cond_stage_model, "decode") and self.cond_stage_key != "masked_image":
939
+ xc = self.cond_stage_model.decode(c)
940
+ log["conditioning"] = xc
941
+ elif self.cond_stage_key == "masked_image":
942
+ log["mask"] = c[:, -1, :, :][:, None, :, :]
943
+ xc = self.cond_stage_model.decode(c[:, :self.cond_stage_model.embed_dim, :, :])
944
+ log["conditioning"] = xc
945
+ elif self.cond_stage_key in ["caption"]:
946
+ xc = log_txt_as_img((256, 256), batch["caption"])
947
+ log["conditioning"] = xc
948
+ elif self.cond_stage_key == 'class_label':
949
+ xc = log_txt_as_img((x.shape[2], x.shape[3]), batch["human_label"])
950
+ log['conditioning'] = xc
951
+ elif isimage(xc):
952
+ log["conditioning"] = xc
953
+ if ismap(xc):
954
+ log["original_conditioning"] = self.to_rgb(xc)
955
+
956
+ if plot_diffusion_rows:
957
+ # get diffusion row
958
+ diffusion_row = list()
959
+ z_start = z[:n_row]
960
+ for t in range(self.num_timesteps):
961
+ if t % self.log_every_t == 0 or t == self.num_timesteps - 1:
962
+ t = repeat(torch.tensor([t]), '1 -> b', b=n_row)
963
+ t = t.to(self.device).long()
964
+ noise = torch.randn_like(z_start)
965
+ z_noisy = self.q_sample(x_start=z_start, t=t, noise=noise)
966
+ diffusion_row.append(self.decode_first_stage(z_noisy))
967
+
968
+ diffusion_row = torch.stack(diffusion_row) # n_log_step, n_row, C, H, W
969
+ diffusion_grid = rearrange(diffusion_row, 'n b c h w -> b n c h w')
970
+ diffusion_grid = rearrange(diffusion_grid, 'b n c h w -> (b n) c h w')
971
+ diffusion_grid = make_grid(diffusion_grid, nrow=diffusion_row.shape[0])
972
+ log["diffusion_row"] = diffusion_grid
973
+
974
+ if sample:
975
+ # get denoise row
976
+ with self.ema_scope("Plotting"):
977
+ samples, z_denoise_row = self.sample_log(cond=c,batch_size=N,ddim=use_ddim,
978
+ ddim_steps=ddim_steps,eta=ddim_eta)
979
+ # samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True)
980
+ x_samples = self.decode_first_stage(samples)
981
+ log["samples"] = x_samples
982
+ if plot_denoise_rows:
983
+ denoise_grid = self._get_denoise_row_from_list(z_denoise_row)
984
+ log["denoise_row"] = denoise_grid
985
+
986
+ if quantize_denoised and not isinstance(self.first_stage_model, AutoencoderKL) and not isinstance(
987
+ self.first_stage_model, IdentityFirstStage):
988
+ # also display when quantizing x0 while sampling
989
+ with self.ema_scope("Plotting Quantized Denoised"):
990
+ samples, z_denoise_row = self.sample_log(cond=c,batch_size=N,ddim=use_ddim,
991
+ ddim_steps=ddim_steps,eta=ddim_eta,
992
+ quantize_denoised=True)
993
+ # samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True,
994
+ # quantize_denoised=True)
995
+ x_samples = self.decode_first_stage(samples.to(self.device))
996
+ log["samples_x0_quantized"] = x_samples
997
+
998
+ if inpaint:
999
+ # make a simple center square
1000
+ b, h, w = z.shape[0], z.shape[2], z.shape[3]
1001
+ mask = torch.ones(N, h, w).to(self.device)
1002
+ # zeros will be filled in
1003
+ mask[:, h // 4:3 * h // 4, w // 4:3 * w // 4] = 0.
1004
+ mask = mask[:, None, ...]
1005
+ with self.ema_scope("Plotting Inpaint"):
1006
+
1007
+ samples, _ = self.sample_log(cond=c,batch_size=N,ddim=use_ddim, eta=ddim_eta,
1008
+ ddim_steps=ddim_steps, x0=z[:N], mask=mask)
1009
+ x_samples = self.decode_first_stage(samples.to(self.device))
1010
+ log["samples_inpainting"] = x_samples
1011
+ log["mask_inpainting"] = mask
1012
+
1013
+ # outpaint
1014
+ mask = 1 - mask
1015
+ with self.ema_scope("Plotting Outpaint"):
1016
+ samples, _ = self.sample_log(cond=c, batch_size=N, ddim=use_ddim,eta=ddim_eta,
1017
+ ddim_steps=ddim_steps, x0=z[:N], mask=mask)
1018
+ x_samples = self.decode_first_stage(samples.to(self.device))
1019
+ log["samples_outpainting"] = x_samples
1020
+ log["mask_outpainting"] = mask
1021
+
1022
+ if plot_progressive_rows:
1023
+ with self.ema_scope("Plotting Progressives"):
1024
+ img, progressives = self.progressive_denoising(c,
1025
+ shape=(self.channels, self.mel_dim, self.mel_length),
1026
+ batch_size=N)
1027
+ prog_row = self._get_denoise_row_from_list(progressives, desc="Progressive Generation")
1028
+ log["progressive_row"] = prog_row
1029
+
1030
+ if return_keys:
1031
+ if np.intersect1d(list(log.keys()), return_keys).shape[0] == 0:
1032
+ return log
1033
+ else:
1034
+ return {key: log[key] for key in return_keys}
1035
+ return log
1036
+
1037
+ def configure_optimizers(self):
1038
+ lr = self.learning_rate
1039
+ params = list(self.model.parameters())
1040
+ if self.cond_stage_trainable:
1041
+ print(f"{self.__class__.__name__}: Also optimizing conditioner params!")
1042
+ params = params + list(self.cond_stage_model.parameters())
1043
+ if self.learn_logvar:
1044
+ print('Diffusion model optimizing logvar')
1045
+ params.append(self.logvar)
1046
+ opt = torch.optim.AdamW(params, lr=lr)
1047
+ if self.use_scheduler:
1048
+ assert 'target' in self.scheduler_config
1049
+ scheduler = instantiate_from_config(self.scheduler_config)
1050
+
1051
+ print("Setting up LambdaLR scheduler...")
1052
+ scheduler = [
1053
+ {
1054
+ 'scheduler': LambdaLR(opt, lr_lambda=scheduler.schedule),
1055
+ 'interval': 'step',
1056
+ 'frequency': 1
1057
+ }]
1058
+ return [opt], scheduler
1059
+ return opt
1060
+
1061
+ @torch.no_grad()
1062
+ def to_rgb(self, x):
1063
+ x = x.float()
1064
+ if not hasattr(self, "colorize"):
1065
+ self.colorize = torch.randn(3, x.shape[1], 1, 1).to(x)
1066
+ x = nn.functional.conv2d(x, weight=self.colorize)
1067
+ x = 2. * (x - x.min()) / (x.max() - x.min()) - 1.
1068
+ return x
1069
+
1070
+
1071
+ class LatentFinetuneDiffusion(LatentDiffusion_audio):
1072
+ """
1073
+ Basis for different finetunes, such as inpainting or depth2image
1074
+ To disable finetuning mode, set finetune_keys to None
1075
+ """
1076
+
1077
+ def __init__(self,
1078
+ concat_keys: tuple,
1079
+ finetune_keys=("model.diffusion_model.input_blocks.0.0.weight",
1080
+ "model_ema.diffusion_modelinput_blocks00weight"
1081
+ ),
1082
+ keep_finetune_dims=4,
1083
+ # if model was trained without concat mode before and we would like to keep these channels
1084
+ c_concat_log_start=None, # to log reconstruction of c_concat codes
1085
+ c_concat_log_end=None,
1086
+ *args, **kwargs
1087
+ ):
1088
+ ckpt_path = kwargs.pop("ckpt_path", None)
1089
+ ignore_keys = kwargs.pop("ignore_keys", list())
1090
+ super().__init__(*args, **kwargs)
1091
+ self.finetune_keys = finetune_keys
1092
+ self.concat_keys = concat_keys
1093
+ self.keep_dims = keep_finetune_dims
1094
+ self.c_concat_log_start = c_concat_log_start
1095
+ self.c_concat_log_end = c_concat_log_end
1096
+
1097
+ if exists(self.finetune_keys): assert exists(ckpt_path), 'can only finetune from a given checkpoint'
1098
+ if exists(ckpt_path):
1099
+ self.init_from_ckpt(ckpt_path, ignore_keys)
1100
+
1101
+ def init_from_ckpt(self, path, ignore_keys=list(), only_model=False):
1102
+ sd = torch.load(path, map_location="cpu")
1103
+ if "state_dict" in list(sd.keys()):
1104
+ sd = sd["state_dict"]
1105
+ keys = list(sd.keys())
1106
+
1107
+ for k in keys:
1108
+ for ik in ignore_keys:
1109
+ if k.startswith(ik):
1110
+ print("Deleting key {} from state_dict.".format(k))
1111
+ del sd[k]
1112
+
1113
+ # make it explicit, finetune by including extra input channels
1114
+ if exists(self.finetune_keys) and k in self.finetune_keys:
1115
+ new_entry = None
1116
+ for name, param in self.named_parameters():
1117
+ if name in self.finetune_keys:
1118
+ print(
1119
+ f"modifying key '{name}' and keeping its original {self.keep_dims} (channels) dimensions only")
1120
+ new_entry = torch.zeros_like(param) # zero init
1121
+ assert exists(new_entry), 'did not find matching parameter to modify'
1122
+ new_entry[:, :self.keep_dims, ...] = sd[k]
1123
+ sd[k] = new_entry
1124
+
1125
+ missing, unexpected = self.load_state_dict(sd, strict=False) if not only_model else self.model.load_state_dict(sd, strict=False)
1126
+ print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys")
1127
+ if len(missing) > 0:
1128
+ print(f"Missing Keys: {missing}")
1129
+ if len(unexpected) > 0:
1130
+ print(f"Unexpected Keys: {unexpected}")
1131
+
1132
+ @torch.no_grad()
1133
+ def log_images(self, batch, N=8, n_row=4, sample=True, ddim_steps=200, ddim_eta=1., return_keys=None,
1134
+ quantize_denoised=True, inpaint=True, plot_denoise_rows=False, plot_progressive_rows=True,
1135
+ plot_diffusion_rows=True, unconditional_guidance_scale=1., unconditional_guidance_label=None,
1136
+ use_ema_scope=True,
1137
+ **kwargs):
1138
+ use_ddim = ddim_steps is not None
1139
+
1140
+ log = dict()
1141
+ z, c, x, xrec, xc = self.get_input(batch, self.first_stage_key, bs=N, return_first_stage_outputs=True)
1142
+ c_cat, c = c["c_concat"][0], c["c_crossattn"][0]
1143
+ N = min(x.shape[0], N)
1144
+ n_row = min(x.shape[0], n_row)
1145
+ log["inputs"] = x
1146
+ log["reconstruction"] = xrec
1147
+ if self.model.conditioning_key is not None:
1148
+ if hasattr(self.cond_stage_model, "decode"):
1149
+ xc = self.cond_stage_model.decode(c)
1150
+ log["conditioning"] = xc
1151
+ elif self.cond_stage_key in ["caption"]:
1152
+ xc = log_txt_as_img((x.shape[2], x.shape[3]), batch["caption"])
1153
+ log["conditioning"] = xc
1154
+ elif self.cond_stage_key == 'class_label':
1155
+ xc = log_txt_as_img((x.shape[2], x.shape[3]), batch["human_label"])
1156
+ log['conditioning'] = xc
1157
+ elif isimage(xc):
1158
+ log["conditioning"] = xc
1159
+ if ismap(xc):
1160
+ log["original_conditioning"] = self.to_rgb(xc)
1161
+
1162
+ if not (self.c_concat_log_start is None and self.c_concat_log_end is None):
1163
+ log["c_concat_decoded"] = self.decode_first_stage(c_cat[:, self.c_concat_log_start:self.c_concat_log_end])
1164
+
1165
+ if plot_diffusion_rows:
1166
+ # get diffusion row
1167
+ diffusion_row = list()
1168
+ z_start = z[:n_row]
1169
+ for t in range(self.num_timesteps):
1170
+ if t % self.log_every_t == 0 or t == self.num_timesteps - 1:
1171
+ t = repeat(torch.tensor([t]), '1 -> b', b=n_row)
1172
+ t = t.to(self.device).long()
1173
+ noise = torch.randn_like(z_start)
1174
+ z_noisy = self.q_sample(x_start=z_start, t=t, noise=noise)
1175
+ diffusion_row.append(self.decode_first_stage(z_noisy))
1176
+
1177
+ diffusion_row = torch.stack(diffusion_row) # n_log_step, n_row, C, H, W
1178
+ diffusion_grid = rearrange(diffusion_row, 'n b c h w -> b n c h w')
1179
+ diffusion_grid = rearrange(diffusion_grid, 'b n c h w -> (b n) c h w')
1180
+ diffusion_grid = make_grid(diffusion_grid, nrow=diffusion_row.shape[0])
1181
+ log["diffusion_row"] = diffusion_grid
1182
+
1183
+ if sample:
1184
+ # get denoise row
1185
+ with self.ema_scope("Sampling"):
1186
+ samples, z_denoise_row = self.sample_log(cond={"c_concat": [c_cat], "c_crossattn": [c]},
1187
+ batch_size=N, ddim=use_ddim,
1188
+ ddim_steps=ddim_steps, eta=ddim_eta)
1189
+ # samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True)
1190
+ x_samples = self.decode_first_stage(samples)
1191
+ log["samples"] = x_samples
1192
+ if plot_denoise_rows:
1193
+ denoise_grid = self._get_denoise_row_from_list(z_denoise_row)
1194
+ log["denoise_row"] = denoise_grid
1195
+
1196
+ if unconditional_guidance_scale > 1.0:
1197
+ uc_cross = self.get_unconditional_conditioning(N, unconditional_guidance_label)
1198
+ uc_cat = c_cat
1199
+ uc_full = {"c_concat": [uc_cat], "c_crossattn": [uc_cross]}
1200
+ with self.ema_scope("Sampling with classifier-free guidance"):
1201
+ samples_cfg, _ = self.sample_log(cond={"c_concat": [c_cat], "c_crossattn": [c]},
1202
+ batch_size=N, ddim=use_ddim,
1203
+ ddim_steps=ddim_steps, eta=ddim_eta,
1204
+ unconditional_guidance_scale=unconditional_guidance_scale,
1205
+ unconditional_conditioning=uc_full,
1206
+ )
1207
+ x_samples_cfg = self.decode_first_stage(samples_cfg)
1208
+ log[f"samples_cfg_scale_{unconditional_guidance_scale:.2f}"] = x_samples_cfg
1209
+
1210
+ return log
1211
+
1212
+
1213
+ class LatentInpaintDiffusion(LatentFinetuneDiffusion):
1214
+ """
1215
+ can either run as pure inpainting model (only concat mode) or with mixed conditionings,
1216
+ e.g. mask as concat and text via cross-attn.
1217
+ To disable finetuning mode, set finetune_keys to None
1218
+ """
1219
+
1220
+ def __init__(self,
1221
+ concat_keys=("mask", "masked_image"),
1222
+ masked_image_key="masked_image",
1223
+ *args, **kwargs
1224
+ ):
1225
+ super().__init__(concat_keys, *args, **kwargs)
1226
+ self.masked_image_key = masked_image_key
1227
+ assert self.masked_image_key in concat_keys
1228
+
1229
+ @torch.no_grad()
1230
+ def get_input(self, batch, k, cond_key=None, bs=None, return_first_stage_outputs=False):
1231
+ # note: restricted to non-trainable encoders currently
1232
+ assert not self.cond_stage_trainable, 'trainable cond stages not yet supported for inpainting'
1233
+ z, c, x, xrec, xc = super().get_input(batch, self.first_stage_key, return_first_stage_outputs=True,
1234
+ force_c_encode=True, return_original_cond=True, bs=bs)
1235
+
1236
+ assert exists(self.concat_keys)
1237
+ c_cat = list()
1238
+ for ck in self.concat_keys:
1239
+ if len(batch[ck].shape) == 3:
1240
+ batch[ck] = batch[ck][..., None]
1241
+ cc = rearrange(batch[ck], 'b h w c -> b c h w').to(memory_format=torch.contiguous_format).float()
1242
+ if bs is not None:
1243
+ cc = cc[:bs]
1244
+ cc = cc.to(self.device)
1245
+ bchw = z.shape
1246
+ if ck != self.masked_image_key:
1247
+ cc = torch.nn.functional.interpolate(cc, size=bchw[-2:])
1248
+ else:
1249
+ cc = self.get_first_stage_encoding(self.encode_first_stage(cc))
1250
+ c_cat.append(cc)
1251
+ c_cat = torch.cat(c_cat, dim=1)
1252
+ all_conds = {"c_concat": [c_cat], "c_crossattn": [c]}
1253
+ if return_first_stage_outputs:
1254
+ return z, all_conds, x, xrec, xc
1255
+ return z, all_conds
1256
+
1257
+ @torch.no_grad()
1258
+ def log_images(self, *args, **kwargs):
1259
+ log = super(LatentInpaintDiffusion, self).log_images(*args, **kwargs)
1260
+ log["masked_image"] = rearrange(args[0]["masked_image"],
1261
+ 'b h w c -> b c h w').to(memory_format=torch.contiguous_format).float()
1262
+ return log
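
To make the hybrid conditioning assembled by get_input above concrete, here is a minimal, self-contained sketch (plain PyTorch; the first-stage encoder is replaced by an assumed 8x average pooling and the text embedding by a random tensor, both purely illustrative, and the exact channel order follows concat_keys):

```python
import torch
import torch.nn.functional as F

b, h, w = 2, 256, 256
image = torch.rand(b, 3, h, w) * 2 - 1                      # input in [-1, 1]
mask = torch.ones(b, 1, h, w)                               # 1 = keep, 0 = inpaint
mask[:, :, h // 4: 3 * h // 4, w // 4: 3 * w // 4] = 0.     # centre square, as in log_images
masked_image = image * mask

def toy_first_stage_encode(x):                              # stand-in for encode_first_stage + scaling
    return F.avg_pool2d(x, kernel_size=8)                   # (b, 3, h/8, w/8) toy "latent"

z_masked = toy_first_stage_encode(masked_image)
resized_mask = F.interpolate(mask, size=z_masked.shape[-2:])
c_concat = torch.cat((z_masked, resized_mask), dim=1)       # extra UNet input channels
c_crossattn = torch.randn(b, 77, 768)                       # stand-in for the text encoder output
cond = {"c_concat": [c_concat], "c_crossattn": [c_crossattn]}
print({k: v[0].shape for k, v in cond.items()})
```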
ldm/models/diffusion/ddpm_audio_inpaint.py ADDED
@@ -0,0 +1,1081 @@
1
+ """
2
+ wild mixture of
3
+ https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py
4
+ https://github.com/openai/improved-diffusion/blob/e94489283bb876ac1477d5dd7709bbbd2d9902ce/improved_diffusion/gaussian_diffusion.py
5
+ https://github.com/CompVis/taming-transformers
6
+ -- merci
7
+ """
8
+ import os
9
+ import torch
10
+ import torch.nn as nn
11
+ import numpy as np
12
+ import pytorch_lightning as pl
13
+ from torch.optim.lr_scheduler import LambdaLR
14
+ from einops import rearrange, repeat
15
+ from contextlib import contextmanager
16
+ from functools import partial
17
+ from tqdm import tqdm
18
+ from torchvision.utils import make_grid
19
+ from pytorch_lightning.utilities.distributed import rank_zero_only
20
+
21
+ from ldm.util import log_txt_as_img, exists, default, ismap, isimage, mean_flat, count_params, instantiate_from_config
22
+ from ldm.modules.ema import LitEma
23
+ from ldm.modules.distributions.distributions import normal_kl, DiagonalGaussianDistribution
24
+ from ldm.models.autoencoder import VQModelInterface, IdentityFirstStage, AutoencoderKL
25
+ from ldm.modules.diffusionmodules.util import make_beta_schedule, extract_into_tensor, noise_like
26
+ from ldm.models.diffusion.ddim import DDIMSampler
27
+ from ldm.models.diffusion.ddpm import DDPM, disabled_train
28
+
29
+ __conditioning_keys__ = {'concat': 'c_concat',
30
+ 'crossattn': 'c_crossattn',
31
+ 'adm': 'y'}
32
+
33
+ # add mel_dim and mel_length params to ensure correct shape
34
+ class LatentDiffusion_audioinpaint(DDPM):
35
+ """main class"""
36
+ def __init__(self,
37
+ first_stage_config,
38
+ cond_stage_config,
39
+ num_timesteps_cond=None,
40
+ mel_dim=80,
41
+ mel_length=848,
42
+ cond_stage_key="image",
43
+ cond_stage_trainable=False,
44
+ concat_mode=True,
45
+ cond_stage_forward=None,
46
+ conditioning_key=None,
47
+ scale_factor=1.0,
48
+ scale_by_std=False,
49
+ test_repeat=1,
50
+ test_numsteps = None,
51
+ *args, **kwargs):
52
+ self.num_timesteps_cond = default(num_timesteps_cond, 1)
53
+ self.scale_by_std = scale_by_std
54
+ assert self.num_timesteps_cond <= kwargs['timesteps']
55
+ # for backwards compatibility after implementation of DiffusionWrapper
56
+ if conditioning_key is None:
57
+ conditioning_key = 'concat' if concat_mode else 'crossattn'
58
+ if cond_stage_config == '__is_unconditional__':
59
+ conditioning_key = None
60
+ ckpt_path = kwargs.pop("ckpt_path", None)
61
+ ignore_keys = kwargs.pop("ignore_keys", [])
62
+ super().__init__(conditioning_key=conditioning_key, *args, **kwargs)
63
+ self.test_repeat = test_repeat
64
+ # fall back to the full diffusion schedule when test_numsteps is not given
65
+ self.test_numsteps = default(test_numsteps, self.num_timesteps)
66
+ self.concat_mode = concat_mode
67
+ self.mel_dim = mel_dim
68
+ self.mel_length = mel_length
69
+ self.cond_stage_trainable = cond_stage_trainable
70
+ self.cond_stage_key = cond_stage_key
71
+ try:
72
+ self.num_downs = len(first_stage_config.params.ddconfig.ch_mult) - 1
73
+ except:
74
+ self.num_downs = 0
75
+ if not scale_by_std:
76
+ self.scale_factor = scale_factor
77
+ else:
78
+ self.register_buffer('scale_factor', torch.tensor(scale_factor))
79
+ self.instantiate_first_stage(first_stage_config)
80
+ self.instantiate_cond_stage(cond_stage_config)
81
+ self.cond_stage_forward = cond_stage_forward
82
+ self.clip_denoised = False
83
+ self.bbox_tokenizer = None
84
+
85
+ self.restarted_from_ckpt = False
86
+ if ckpt_path is not None:
87
+ self.init_from_ckpt(ckpt_path, ignore_keys)
88
+ self.restarted_from_ckpt = True
89
+
90
+ def make_cond_schedule(self, ):
91
+ self.cond_ids = torch.full(size=(self.num_timesteps,), fill_value=self.num_timesteps - 1, dtype=torch.long)
92
+ ids = torch.round(torch.linspace(0, self.num_timesteps - 1, self.num_timesteps_cond)).long()
93
+ self.cond_ids[:self.num_timesteps_cond] = ids
94
+
95
+ @rank_zero_only
96
+ @torch.no_grad()
97
+ def on_train_batch_start(self, batch, batch_idx, dataloader_idx):
98
+ # only for very first batch
99
+ if self.scale_by_std and self.current_epoch == 0 and self.global_step == 0 and batch_idx == 0 and not self.restarted_from_ckpt:
100
+ assert self.scale_factor == 1., 'rather not use custom rescaling and std-rescaling simultaneously'
101
+ # set rescale weight to 1./std of encodings
102
+ print("### USING STD-RESCALING ###")
103
+ x = super().get_input(batch, self.first_stage_key)
104
+ x = x.to(self.device)
105
+ encoder_posterior = self.encode_first_stage(x)
106
+ z = self.get_first_stage_encoding(encoder_posterior).detach()
107
+ del self.scale_factor
108
+ self.register_buffer('scale_factor', 1. / z.flatten().std())
109
+ print(f"setting self.scale_factor to {self.scale_factor}")
110
+ print("### USING STD-RESCALING ###")
111
+
112
+ def register_schedule(self,
113
+ given_betas=None, beta_schedule="linear", timesteps=1000,
114
+ linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
115
+ super().register_schedule(given_betas, beta_schedule, timesteps, linear_start, linear_end, cosine_s)
116
+
117
+ self.shorten_cond_schedule = self.num_timesteps_cond > 1
118
+ if self.shorten_cond_schedule:
119
+ self.make_cond_schedule()
120
+
121
+ def instantiate_first_stage(self, config):
122
+ model = instantiate_from_config(config)
123
+ self.first_stage_model = model.eval()
124
+ self.first_stage_model.train = disabled_train
125
+ for param in self.first_stage_model.parameters():
126
+ param.requires_grad = False
127
+
128
+ def instantiate_cond_stage(self, config):
129
+ if not self.cond_stage_trainable:
130
+ if config == "__is_first_stage__":# for no_text inpainting task
131
+ print("Using first stage also as cond stage.")
132
+ self.cond_stage_model = self.first_stage_model
133
+ elif config == "__is_unconditional__":# for unconditional image generation, e.g. human faces or ImageNet
134
+ print(f"Training {self.__class__.__name__} as an unconditional model.")
135
+ self.cond_stage_model = None
136
+ # self.be_unconditional = True
137
+ else:
138
+ model = instantiate_from_config(config)
139
+ self.cond_stage_model = model.eval()
140
+ self.cond_stage_model.train = disabled_train
141
+ for param in self.cond_stage_model.parameters():
142
+ param.requires_grad = False
143
+ else:
144
+ assert config != '__is_first_stage__'
145
+ assert config != '__is_unconditional__'
146
+ model = instantiate_from_config(config)
147
+ self.cond_stage_model = model
148
+
149
+ def _get_denoise_row_from_list(self, samples, desc='', force_no_decoder_quantization=False):
150
+ denoise_row = []
151
+ for zd in tqdm(samples, desc=desc):
152
+ denoise_row.append(self.decode_first_stage(zd.to(self.device),
153
+ force_not_quantize=force_no_decoder_quantization))
154
+ n_imgs_per_row = len(denoise_row)
155
+ denoise_row = torch.stack(denoise_row) # n_log_step, n_row, C, H, W
156
+ denoise_grid = rearrange(denoise_row, 'n b c h w -> b n c h w')
157
+ denoise_grid = rearrange(denoise_grid, 'b n c h w -> (b n) c h w')
158
+ denoise_grid = make_grid(denoise_grid, nrow=n_imgs_per_row)
159
+ return denoise_grid
160
+
161
+ def get_first_stage_encoding(self, encoder_posterior):# encode_emb from autoencoder
162
+ if isinstance(encoder_posterior, DiagonalGaussianDistribution):
163
+ z = encoder_posterior.sample()
164
+ elif isinstance(encoder_posterior, torch.Tensor):
165
+ z = encoder_posterior
166
+ else:
167
+ raise NotImplementedError(f"encoder_posterior of type '{type(encoder_posterior)}' not yet implemented")
168
+ return self.scale_factor * z
169
+
170
+ def get_learned_conditioning(self, c):
171
+ if self.cond_stage_forward is None:
172
+ if hasattr(self.cond_stage_model, 'encode') and callable(self.cond_stage_model.encode):
173
+ c = self.cond_stage_model.encode(c)
174
+ if isinstance(c, DiagonalGaussianDistribution):
175
+ c = c.mode()
176
+ else:
177
+ c = self.cond_stage_model(c)
178
+ else:
179
+ assert hasattr(self.cond_stage_model, self.cond_stage_forward)
180
+ c = getattr(self.cond_stage_model, self.cond_stage_forward)(c)
181
+ return c
182
+
183
+ def meshgrid(self, h, w):
184
+ y = torch.arange(0, h).view(h, 1, 1).repeat(1, w, 1)
185
+ x = torch.arange(0, w).view(1, w, 1).repeat(h, 1, 1)
186
+
187
+ arr = torch.cat([y, x], dim=-1)
188
+ return arr
189
+
190
+ def delta_border(self, h, w):
191
+ """
192
+ :param h: height
193
+ :param w: width
194
+ :return: normalized distance to image border,
195
+ with min distance = 0 at border and max dist = 0.5 at image center
196
+ """
197
+ lower_right_corner = torch.tensor([h - 1, w - 1]).view(1, 1, 2)
198
+ arr = self.meshgrid(h, w) / lower_right_corner
199
+ dist_left_up = torch.min(arr, dim=-1, keepdims=True)[0]
200
+ dist_right_down = torch.min(1 - arr, dim=-1, keepdims=True)[0]
201
+ edge_dist = torch.min(torch.cat([dist_left_up, dist_right_down], dim=-1), dim=-1)[0]
202
+ return edge_dist
203
+
204
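A quick, standalone illustration (plain PyTorch, a 7x7 grid) of the border-distance map that delta_border returns and that get_weighting below turns into patch blending weights:

```python
import torch

h = w = 7
y = torch.arange(h).view(h, 1, 1).repeat(1, w, 1)
x = torch.arange(w).view(1, w, 1).repeat(h, 1, 1)
arr = torch.cat([y, x], dim=-1) / torch.tensor([h - 1, w - 1]).view(1, 1, 2)
dist_lu = torch.min(arr, dim=-1, keepdim=True)[0]          # distance to top/left border
dist_rd = torch.min(1 - arr, dim=-1, keepdim=True)[0]      # distance to bottom/right border
edge_dist = torch.min(torch.cat([dist_lu, dist_rd], dim=-1), dim=-1)[0]
print(edge_dist)   # 0.0 at the border, 0.5 at the centre, as the docstring above states
```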
+ def get_weighting(self, h, w, Ly, Lx, device):
205
+ weighting = self.delta_border(h, w)
206
+ weighting = torch.clip(weighting, self.split_input_params["clip_min_weight"],
207
+ self.split_input_params["clip_max_weight"], )
208
+ weighting = weighting.view(1, h * w, 1).repeat(1, 1, Ly * Lx).to(device)
209
+
210
+ if self.split_input_params["tie_braker"]:
211
+ L_weighting = self.delta_border(Ly, Lx)
212
+ L_weighting = torch.clip(L_weighting,
213
+ self.split_input_params["clip_min_tie_weight"],
214
+ self.split_input_params["clip_max_tie_weight"])
215
+
216
+ L_weighting = L_weighting.view(1, 1, Ly * Lx).to(device)
217
+ weighting = weighting * L_weighting
218
+ return weighting
219
+
220
+ def get_fold_unfold(self, x, kernel_size, stride, uf=1, df=1): # todo load once not every time, shorten code
221
+ """
222
+ :param x: img of size (bs, c, h, w)
223
+ :return: n img crops of size (n, bs, c, kernel_size[0], kernel_size[1])
224
+ """
225
+ bs, nc, h, w = x.shape
226
+
227
+ # number of crops in image
228
+ Ly = (h - kernel_size[0]) // stride[0] + 1
229
+ Lx = (w - kernel_size[1]) // stride[1] + 1
230
+
231
+ if uf == 1 and df == 1:
232
+ fold_params = dict(kernel_size=kernel_size, dilation=1, padding=0, stride=stride)
233
+ unfold = torch.nn.Unfold(**fold_params)
234
+
235
+ fold = torch.nn.Fold(output_size=x.shape[2:], **fold_params)
236
+
237
+ weighting = self.get_weighting(kernel_size[0], kernel_size[1], Ly, Lx, x.device).to(x.dtype)
238
+ normalization = fold(weighting).view(1, 1, h, w) # normalizes the overlap
239
+ weighting = weighting.view((1, 1, kernel_size[0], kernel_size[1], Ly * Lx))
240
+
241
+ elif uf > 1 and df == 1:
242
+ fold_params = dict(kernel_size=kernel_size, dilation=1, padding=0, stride=stride)
243
+ unfold = torch.nn.Unfold(**fold_params)
244
+
245
+ fold_params2 = dict(kernel_size=(kernel_size[0] * uf, kernel_size[0] * uf),
246
+ dilation=1, padding=0,
247
+ stride=(stride[0] * uf, stride[1] * uf))
248
+ fold = torch.nn.Fold(output_size=(x.shape[2] * uf, x.shape[3] * uf), **fold_params2)
249
+
250
+ weighting = self.get_weighting(kernel_size[0] * uf, kernel_size[1] * uf, Ly, Lx, x.device).to(x.dtype)
251
+ normalization = fold(weighting).view(1, 1, h * uf, w * uf) # normalizes the overlap
252
+ weighting = weighting.view((1, 1, kernel_size[0] * uf, kernel_size[1] * uf, Ly * Lx))
253
+
254
+ elif df > 1 and uf == 1:
255
+ fold_params = dict(kernel_size=kernel_size, dilation=1, padding=0, stride=stride)
256
+ unfold = torch.nn.Unfold(**fold_params)
257
+
258
+ fold_params2 = dict(kernel_size=(kernel_size[0] // df, kernel_size[0] // df),
259
+ dilation=1, padding=0,
260
+ stride=(stride[0] // df, stride[1] // df))
261
+ fold = torch.nn.Fold(output_size=(x.shape[2] // df, x.shape[3] // df), **fold_params2)
262
+
263
+ weighting = self.get_weighting(kernel_size[0] // df, kernel_size[1] // df, Ly, Lx, x.device).to(x.dtype)
264
+ normalization = fold(weighting).view(1, 1, h // df, w // df) # normalizes the overlap
265
+ weighting = weighting.view((1, 1, kernel_size[0] // df, kernel_size[1] // df, Ly * Lx))
266
+
267
+ else:
268
+ raise NotImplementedError
269
+
270
+ return fold, unfold, normalization, weighting
271
+
272
+ @torch.no_grad()
273
+ def get_input(self, batch, k, return_first_stage_outputs=False, force_c_encode=False,
274
+ cond_key=None, return_original_cond=False, bs=None):
275
+ x = super().get_input(batch, k)
276
+ if bs is not None:
277
+ x = x[:bs]
278
+ x = x.to(self.device)
279
+ encoder_posterior = self.encode_first_stage(x)
280
+ z = self.get_first_stage_encoding(encoder_posterior).detach()
281
+
282
+ if self.model.conditioning_key is not None:# 'crossattn' for txt2image, 'hybrid' for txt_inpaint
283
+ if cond_key is None:
284
+ cond_key = self.cond_stage_key # 'caption' for txt_inpaint
285
+ if self.model.conditioning_key == 'hybrid':
286
+ xc = {}
287
+ assert cond_key == 'caption' # only txt_inpaint is implemented now
288
+ assert 'masked_image' in batch.keys()
289
+ assert 'mask' in batch.keys()
290
+ masked_image = super().get_input(batch,'masked_image')
291
+ mask = super().get_input(batch,'mask')
292
+ if bs is not None:
293
+ masked_image,mask = masked_image[:bs],mask[:bs]
294
+ masked_image,mask = masked_image.to(self.device),mask.to(self.device)
295
+ masked_image = self.get_first_stage_encoding(self.encode_first_stage(masked_image)).detach()
296
+ resized_mask = torch.nn.functional.interpolate(mask,size=masked_image.shape[-2:])
297
+ xc['c_concat'] = torch.cat((masked_image,resized_mask),dim = 1)
298
+ xc[cond_key] = batch[cond_key]
299
+ else:
300
+ if cond_key != self.first_stage_key:
301
+ if cond_key in ['caption', 'coordinates_bbox']:
302
+ xc = batch[cond_key]
303
+ elif cond_key == 'class_label':
304
+ xc = batch
305
+ else:
306
+ xc = super().get_input(batch, cond_key).to(self.device)
307
+ else:# cond_key == 'image'
308
+ xc = x
309
+ if not self.cond_stage_trainable or force_c_encode:# cond_stage_trainable is True for txt2img; force_c_encode=True when called from log_images
310
+ if isinstance(xc, list):
311
+ # import pudb; pudb.set_trace()
312
+ c = self.get_learned_conditioning(xc)# log_images calls sample_log right after this, so the processed c has to be prepared in advance
313
+ if isinstance(xc, dict):
314
+ c = {}
315
+ c['c_concat'] = xc['c_concat']
316
+ c['c_crossattn'] = self.get_learned_conditioning(xc[cond_key])
317
+ else:
318
+ c = self.get_learned_conditioning(xc.to(self.device))
319
+ else:
320
+ c = xc
321
+ if bs is not None:
322
+ if isinstance(c,dict):
323
+ for k in c.keys():
324
+ c[k] = c[k][:bs]
325
+ else:
326
+ c = c[:bs]
327
+
328
+ if self.use_positional_encodings:
329
+ pos_x, pos_y = self.compute_latent_shifts(batch)
330
+ ckey = __conditioning_keys__[self.model.conditioning_key]
331
+ c = {ckey: c, 'pos_x': pos_x, 'pos_y': pos_y}
332
+
333
+ else:
334
+ c = None
335
+ xc = None
336
+ if self.use_positional_encodings:
337
+ pos_x, pos_y = self.compute_latent_shifts(batch)
338
+ c = {'pos_x': pos_x, 'pos_y': pos_y}
339
+ out = [z, c]
340
+ if return_first_stage_outputs:
341
+ xrec = self.decode_first_stage(z)
342
+ out.extend([x, xrec])
343
+ if return_original_cond:
344
+ out.append(xc)
345
+ return out
346
+
347
+ @torch.no_grad()
348
+ def decode_first_stage(self, z, predict_cids=False, force_not_quantize=False):
349
+ if predict_cids:
350
+ if z.dim() == 4:
351
+ z = torch.argmax(z.exp(), dim=1).long()
352
+ z = self.first_stage_model.quantize.get_codebook_entry(z, shape=None)
353
+ z = rearrange(z, 'b h w c -> b c h w').contiguous()
354
+
355
+ z = 1. / self.scale_factor * z
356
+
357
+ if hasattr(self, "split_input_params"):
358
+ if self.split_input_params["patch_distributed_vq"]:
359
+ ks = self.split_input_params["ks"] # eg. (128, 128)
360
+ stride = self.split_input_params["stride"] # eg. (64, 64)
361
+ uf = self.split_input_params["vqf"]
362
+ bs, nc, h, w = z.shape
363
+ if ks[0] > h or ks[1] > w:
364
+ ks = (min(ks[0], h), min(ks[1], w))
365
+ print("reducing Kernel")
366
+
367
+ if stride[0] > h or stride[1] > w:
368
+ stride = (min(stride[0], h), min(stride[1], w))
369
+ print("reducing stride")
370
+
371
+ fold, unfold, normalization, weighting = self.get_fold_unfold(z, ks, stride, uf=uf)
372
+
373
+ z = unfold(z) # (bn, nc * prod(**ks), L)
374
+ # 1. Reshape to img shape
375
+ z = z.view((z.shape[0], -1, ks[0], ks[1], z.shape[-1])) # (bn, nc, ks[0], ks[1], L )
376
+
377
+ # 2. apply model loop over last dim
378
+ if isinstance(self.first_stage_model, VQModelInterface):
379
+ output_list = [self.first_stage_model.decode(z[:, :, :, :, i],
380
+ force_not_quantize=predict_cids or force_not_quantize)
381
+ for i in range(z.shape[-1])]
382
+ else:
383
+
384
+ output_list = [self.first_stage_model.decode(z[:, :, :, :, i])
385
+ for i in range(z.shape[-1])]
386
+
387
+ o = torch.stack(output_list, axis=-1) # # (bn, nc, ks[0], ks[1], L)
388
+ o = o * weighting
389
+ # Reverse 1. reshape to img shape
390
+ o = o.view((o.shape[0], -1, o.shape[-1])) # (bn, nc * ks[0] * ks[1], L)
391
+ # stitch crops together
392
+ decoded = fold(o)
393
+ decoded = decoded / normalization # norm is shape (1, 1, h, w)
394
+ return decoded
395
+ else:
396
+ if isinstance(self.first_stage_model, VQModelInterface):
397
+ return self.first_stage_model.decode(z, force_not_quantize=predict_cids or force_not_quantize)
398
+ else:
399
+ return self.first_stage_model.decode(z)
400
+
401
+ else:
402
+ if isinstance(self.first_stage_model, VQModelInterface):
403
+ return self.first_stage_model.decode(z, force_not_quantize=predict_cids or force_not_quantize)
404
+ else:
405
+ return self.first_stage_model.decode(z)
406
+
407
+ # same as above but without decorator
408
+ def differentiable_decode_first_stage(self, z, predict_cids=False, force_not_quantize=False):
409
+ if predict_cids:
410
+ if z.dim() == 4:
411
+ z = torch.argmax(z.exp(), dim=1).long()
412
+ z = self.first_stage_model.quantize.get_codebook_entry(z, shape=None)
413
+ z = rearrange(z, 'b h w c -> b c h w').contiguous()
414
+
415
+ z = 1. / self.scale_factor * z
416
+
417
+ if hasattr(self, "split_input_params"):
418
+ if self.split_input_params["patch_distributed_vq"]:
419
+ ks = self.split_input_params["ks"] # eg. (128, 128)
420
+ stride = self.split_input_params["stride"] # eg. (64, 64)
421
+ uf = self.split_input_params["vqf"]
422
+ bs, nc, h, w = z.shape
423
+ if ks[0] > h or ks[1] > w:
424
+ ks = (min(ks[0], h), min(ks[1], w))
425
+ print("reducing Kernel")
426
+
427
+ if stride[0] > h or stride[1] > w:
428
+ stride = (min(stride[0], h), min(stride[1], w))
429
+ print("reducing stride")
430
+
431
+ fold, unfold, normalization, weighting = self.get_fold_unfold(z, ks, stride, uf=uf)
432
+
433
+ z = unfold(z) # (bn, nc * prod(**ks), L)
434
+ # 1. Reshape to img shape
435
+ z = z.view((z.shape[0], -1, ks[0], ks[1], z.shape[-1])) # (bn, nc, ks[0], ks[1], L )
436
+
437
+ # 2. apply model loop over last dim
438
+ if isinstance(self.first_stage_model, VQModelInterface):
439
+ output_list = [self.first_stage_model.decode(z[:, :, :, :, i],
440
+ force_not_quantize=predict_cids or force_not_quantize)
441
+ for i in range(z.shape[-1])]
442
+ else:
443
+
444
+ output_list = [self.first_stage_model.decode(z[:, :, :, :, i])
445
+ for i in range(z.shape[-1])]
446
+
447
+ o = torch.stack(output_list, axis=-1) # # (bn, nc, ks[0], ks[1], L)
448
+ o = o * weighting
449
+ # Reverse 1. reshape to img shape
450
+ o = o.view((o.shape[0], -1, o.shape[-1])) # (bn, nc * ks[0] * ks[1], L)
451
+ # stitch crops together
452
+ decoded = fold(o)
453
+ decoded = decoded / normalization # norm is shape (1, 1, h, w)
454
+ return decoded
455
+ else:
456
+ if isinstance(self.first_stage_model, VQModelInterface):
457
+ return self.first_stage_model.decode(z, force_not_quantize=predict_cids or force_not_quantize)
458
+ else:
459
+ return self.first_stage_model.decode(z)
460
+
461
+ else:
462
+ if isinstance(self.first_stage_model, VQModelInterface):
463
+ return self.first_stage_model.decode(z, force_not_quantize=predict_cids or force_not_quantize)
464
+ else:
465
+ return self.first_stage_model.decode(z)
466
+
467
+ @torch.no_grad()
468
+ def encode_first_stage(self, x):
469
+ if hasattr(self, "split_input_params"):
470
+ if self.split_input_params["patch_distributed_vq"]:
471
+ ks = self.split_input_params["ks"] # eg. (128, 128)
472
+ stride = self.split_input_params["stride"] # eg. (64, 64)
473
+ df = self.split_input_params["vqf"]
474
+ self.split_input_params['original_image_size'] = x.shape[-2:]
475
+ bs, nc, h, w = x.shape
476
+ if ks[0] > h or ks[1] > w:
477
+ ks = (min(ks[0], h), min(ks[1], w))
478
+ print("reducing Kernel")
479
+
480
+ if stride[0] > h or stride[1] > w:
481
+ stride = (min(stride[0], h), min(stride[1], w))
482
+ print("reducing stride")
483
+
484
+ fold, unfold, normalization, weighting = self.get_fold_unfold(x, ks, stride, df=df)
485
+ z = unfold(x) # (bn, nc * prod(**ks), L)
486
+ # Reshape to img shape
487
+ z = z.view((z.shape[0], -1, ks[0], ks[1], z.shape[-1])) # (bn, nc, ks[0], ks[1], L )
488
+
489
+ output_list = [self.first_stage_model.encode(z[:, :, :, :, i])
490
+ for i in range(z.shape[-1])]
491
+
492
+ o = torch.stack(output_list, axis=-1)
493
+ o = o * weighting
494
+
495
+ # Reverse reshape to img shape
496
+ o = o.view((o.shape[0], -1, o.shape[-1])) # (bn, nc * ks[0] * ks[1], L)
497
+ # stitch crops together
498
+ decoded = fold(o)
499
+ decoded = decoded / normalization
500
+ return decoded
501
+
502
+ else:
503
+ return self.first_stage_model.encode(x)
504
+ else:
505
+ return self.first_stage_model.encode(x)
506
+
507
+ def shared_step(self, batch, **kwargs):
508
+ x, c = self.get_input(batch, self.first_stage_key)# get latent and condition
509
+ loss = self(x, c)
510
+ return loss
511
+
512
+ def test_step(self,batch,batch_idx):
513
+ # TODO make self.test_repeat work
514
+ cond = {}
515
+ cond[self.cond_stage_key] = batch[self.cond_stage_key]
516
+ cond[self.cond_stage_key] = self.get_learned_conditioning(cond[self.cond_stage_key]) # c: string -> [B, T, Context_dim]
517
+ cond['c_crossattn'] = cond.pop(self.cond_stage_key)
518
+ masked_image = super().get_input(batch,'masked_image')
519
+ mask = super().get_input(batch,'mask')
520
+ masked_image,mask = masked_image.to(self.device),mask.to(self.device)
521
+ masked_image = self.get_first_stage_encoding(self.encode_first_stage(masked_image)).detach()
522
+ resized_mask = torch.nn.functional.interpolate(mask,size=masked_image.shape[-2:])
523
+ cond['c_concat'] = torch.cat((masked_image,resized_mask),dim = 1)
524
+ batch_size = len(batch[self.cond_stage_key])
525
+ # shape = [batch_size,self.channels,self.mel_dim,self.mel_length]
526
+ enc_emb = self.sample(cond,batch_size,timesteps=self.test_numsteps)
527
+ xrec = self.decode_first_stage(enc_emb)
528
+ reconstructions = (xrec + 1)/2 # to mel scale
529
+ test_ckpt_path = os.path.basename(self.trainer.tested_ckpt_path)
530
+ savedir = os.path.join(self.trainer.log_dir,f'output_imgs_{test_ckpt_path}','fake_class')
531
+ if not os.path.exists(savedir):
532
+ os.makedirs(savedir)
533
+
534
+ file_names = batch['f_name']
535
+ nfiles = len(file_names)
536
+ reconstructions = reconstructions.cpu().numpy().squeeze(1)  # squeeze channel dim
537
+ for k in range(reconstructions.shape[0]):
538
+ b,repeat = k % nfiles, k // nfiles
539
+ vname_num_split_index = file_names[b].rfind('_')# file_names[b]:video_name+'_'+num
540
+ v_n,num = file_names[b][:vname_num_split_index],file_names[b][vname_num_split_index+1:]
541
+ save_img_path = os.path.join(savedir,f'{v_n}_sample_{num}_{repeat}.npy')  # the num-th caption, the repeat-th repetition
542
+ np.save(save_img_path,reconstructions[b])
543
+
544
+ return None
545
+
546
+ def forward(self, x, c, *args, **kwargs):
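+ # Training forward pass: draw a random timestep per sample, encode the conditioning if it is trainable, then compute the diffusion losses via p_losses.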
547
+ t = torch.randint(0, self.num_timesteps, (x.shape[0],), device=self.device).long()
548
+ if self.model.conditioning_key is not None:
549
+ assert c is not None
550
+ if self.cond_stage_trainable:
551
+ if isinstance(c,dict):
552
+ c[self.cond_stage_key] = self.get_learned_conditioning(c[self.cond_stage_key])
553
+ c['c_crossattn'] = c.pop(self.cond_stage_key)
554
+ else:
555
+ c = self.get_learned_conditioning(c) # c: string -> [B, T, Context_dim]
556
+ if self.shorten_cond_schedule: # TODO: drop this option
557
+ tc = self.cond_ids[t].to(self.device)
558
+ c = self.q_sample(x_start=c, t=tc, noise=torch.randn_like(c.float()))
559
+ return self.p_losses(x, c, t, *args, **kwargs)
560
+
561
+ def _rescale_annotations(self, bboxes, crop_coordinates): # TODO: move to dataset
562
+ def rescale_bbox(bbox):
563
+ x0 = torch.clamp((bbox[0] - crop_coordinates[0]) / crop_coordinates[2])
564
+ y0 = torch.clamp((bbox[1] - crop_coordinates[1]) / crop_coordinates[3])
565
+ w = min(bbox[2] / crop_coordinates[2], 1 - x0)
566
+ h = min(bbox[3] / crop_coordinates[3], 1 - y0)
567
+ return x0, y0, w, h
568
+
569
+ return [rescale_bbox(b) for b in bboxes]
570
+
571
+ def apply_model(self, x_noisy, t, cond, return_ids=False):
572
+ # wrap conditioning values in lists so they can be concatenated below
573
+ if isinstance(cond, dict):
574
+ # hybrid case: cond is expected to be a dict (txt2inpaint)
575
+ cond_tmp = {}  # use cond_tmp to avoid in-place edits
576
+ for k,v in cond.items():
577
+ if not isinstance(v, list):
578
+ cond_tmp[k] = [cond[k]]
579
+ else:
580
+ cond_tmp[k] = cond[k]
581
+ cond = cond_tmp
582
+ else:
583
+ if not isinstance(cond, list):
584
+ cond = [cond]
585
+ key = 'c_concat' if self.model.conditioning_key == 'concat' else 'c_crossattn'
586
+ cond = {key: cond}
587
+
588
+ if hasattr(self, "split_input_params"):
589
+ assert len(cond) == 1 # todo can only deal with one conditioning atm
590
+ assert not return_ids
591
+ ks = self.split_input_params["ks"] # eg. (128, 128)
592
+ stride = self.split_input_params["stride"] # eg. (64, 64)
593
+
594
+ h, w = x_noisy.shape[-2:]
595
+
596
+ fold, unfold, normalization, weighting = self.get_fold_unfold(x_noisy, ks, stride)
597
+
598
+ z = unfold(x_noisy) # (bn, nc * prod(**ks), L)
599
+ # Reshape to img shape
600
+ z = z.view((z.shape[0], -1, ks[0], ks[1], z.shape[-1])) # (bn, nc, ks[0], ks[1], L )
601
+ z_list = [z[:, :, :, :, i] for i in range(z.shape[-1])]
602
+
603
+ if self.cond_stage_key in ["image", "LR_image", "segmentation",
604
+ 'bbox_img'] and self.model.conditioning_key: # todo check for completeness
605
+ c_key = next(iter(cond.keys())) # get key
606
+ c = next(iter(cond.values())) # get value
607
+ assert (len(c) == 1) # todo extend to list with more than one elem
608
+ c = c[0] # get element
609
+
610
+ c = unfold(c)
611
+ c = c.view((c.shape[0], -1, ks[0], ks[1], c.shape[-1])) # (bn, nc, ks[0], ks[1], L )
612
+
613
+ cond_list = [{c_key: [c[:, :, :, :, i]]} for i in range(c.shape[-1])]
614
+
615
+ elif self.cond_stage_key == 'coordinates_bbox':
616
+ assert 'original_image_size' in self.split_input_params, 'BoundingBoxRescaling is missing original_image_size'
617
+
618
+ # assuming padding of unfold is always 0 and its dilation is always 1
619
+ n_patches_per_row = int((w - ks[0]) / stride[0] + 1)
620
+ full_img_h, full_img_w = self.split_input_params['original_image_size']
621
+ # as we are operating on latents, we need the factor from the original image size to the
622
+ # spatial latent size to properly rescale the crops for regenerating the bbox annotations
623
+ num_downs = self.first_stage_model.encoder.num_resolutions - 1
624
+ rescale_latent = 2 ** (num_downs)
625
+
626
+ # get top-left positions of the patches for the bbox tokenizer; we
627
+ # need to rescale the top-left patch coordinates to lie in (0, 1)
628
+ tl_patch_coordinates = [(rescale_latent * stride[0] * (patch_nr % n_patches_per_row) / full_img_w,
629
+ rescale_latent * stride[1] * (patch_nr // n_patches_per_row) / full_img_h)
630
+ for patch_nr in range(z.shape[-1])]
631
+
632
+ # patch_limits are tl_coord, width and height coordinates as (x_tl, y_tl, h, w)
633
+ patch_limits = [(x_tl, y_tl,
634
+ rescale_latent * ks[0] / full_img_w,
635
+ rescale_latent * ks[1] / full_img_h) for x_tl, y_tl in tl_patch_coordinates]
636
+ # patch_values = [(np.arange(x_tl,min(x_tl+ks, 1.)),np.arange(y_tl,min(y_tl+ks, 1.))) for x_tl, y_tl in tl_patch_coordinates]
637
+
638
+ # tokenize crop coordinates for the bounding boxes of the respective patches
639
+ patch_limits_tknzd = [torch.LongTensor(self.bbox_tokenizer._crop_encoder(bbox))[None].to(self.device)
640
+ for bbox in patch_limits] # list of length l with tensors of shape (1, 2)
641
+ print(patch_limits_tknzd[0].shape)
642
+ # cut tknzd crop position from conditioning
643
+ assert isinstance(cond, dict), 'cond must be dict to be fed into model'
644
+ cut_cond = cond['c_crossattn'][0][..., :-2].to(self.device)
645
+ print(cut_cond.shape)
646
+
647
+ adapted_cond = torch.stack([torch.cat([cut_cond, p], dim=1) for p in patch_limits_tknzd])
648
+ adapted_cond = rearrange(adapted_cond, 'l b n -> (l b) n')
649
+ print(adapted_cond.shape)
650
+ adapted_cond = self.get_learned_conditioning(adapted_cond)
651
+ print(adapted_cond.shape)
652
+ adapted_cond = rearrange(adapted_cond, '(l b) n d -> l b n d', l=z.shape[-1])
653
+ print(adapted_cond.shape)
654
+
655
+ cond_list = [{'c_crossattn': [e]} for e in adapted_cond]
656
+
657
+ else:
658
+ cond_list = [cond for i in range(z.shape[-1])] # Todo make this more efficient
659
+
660
+ # apply model by loop over crops
661
+ output_list = [self.model(z_list[i], t, **cond_list[i]) for i in range(z.shape[-1])]
662
+ assert not isinstance(output_list[0],
663
+ tuple)  # todo: can't deal with multiple model outputs; check this never happens
664
+
665
+ o = torch.stack(output_list, axis=-1)
666
+ o = o * weighting
667
+ # Reverse reshape to img shape
668
+ o = o.view((o.shape[0], -1, o.shape[-1])) # (bn, nc * ks[0] * ks[1], L)
669
+ # stitch crops together
670
+ x_recon = fold(o) / normalization
671
+
672
+ else:
673
+ # x_noisy is tensor with shape [b,c,mel_len,T]
674
+ # if the condition is a caption, cond['c_crossattn'] is a list and each item has shape [1, 77, 1280]
675
+ x_recon = self.model(x_noisy, t, **cond)# tensor with shape [b,c,mel_len,T]
676
+
677
+ if isinstance(x_recon, tuple) and not return_ids:
678
+ return x_recon[0]
679
+ else:
680
+ return x_recon
681
+
682
+ def _predict_eps_from_xstart(self, x_t, t, pred_xstart):
683
+ return (extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - pred_xstart) / \
684
+ extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)
685
+
686
+ def _prior_bpd(self, x_start):
687
+ """
688
+ Get the prior KL term for the variational lower-bound, measured in
689
+ bits-per-dim.
690
+ This term can't be optimized, as it only depends on the encoder.
691
+ :param x_start: the [N x C x ...] tensor of inputs.
692
+ :return: a batch of [N] KL values (in bits), one per batch element.
693
+ """
694
+ batch_size = x_start.shape[0]
695
+ t = torch.tensor([self.num_timesteps - 1] * batch_size, device=x_start.device)
696
+ qt_mean, _, qt_log_variance = self.q_mean_variance(x_start, t)
697
+ kl_prior = normal_kl(mean1=qt_mean, logvar1=qt_log_variance, mean2=0.0, logvar2=0.0)
698
+ return mean_flat(kl_prior) / np.log(2.0)
699
+
700
+ def p_losses(self, x_start, cond, t, noise=None):
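+ # Diffuse x_start to timestep t, predict with the model, and combine the (optionally logvar-weighted) simple loss with the VLB term.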
701
+ noise = default(noise, lambda: torch.randn_like(x_start))
702
+ x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise)
703
+ model_output = self.apply_model(x_noisy, t, cond)
704
+
705
+ loss_dict = {}
706
+ prefix = 'train' if self.training else 'val'
707
+
708
+ if self.parameterization == "x0":
709
+ target = x_start
710
+ elif self.parameterization == "eps":
711
+ target = noise
712
+ else:
713
+ raise NotImplementedError()
714
+
715
+ loss_simple = self.get_loss(model_output, target, mean=False).mean([1, 2, 3])
716
+ loss_dict.update({f'{prefix}/loss_simple': loss_simple.mean()})
717
+
718
+ logvar_t = self.logvar[t].to(self.device)
719
+ loss = loss_simple / torch.exp(logvar_t) + logvar_t
720
+ # loss = loss_simple / torch.exp(self.logvar) + self.logvar
721
+ if self.learn_logvar:
722
+ loss_dict.update({f'{prefix}/loss_gamma': loss.mean()})
723
+ loss_dict.update({'logvar': self.logvar.data.mean()})
724
+
725
+ loss = self.l_simple_weight * loss.mean()
726
+
727
+ loss_vlb = self.get_loss(model_output, target, mean=False).mean(dim=(1, 2, 3))
728
+ loss_vlb = (self.lvlb_weights[t] * loss_vlb).mean()
729
+ loss_dict.update({f'{prefix}/loss_vlb': loss_vlb})
730
+ loss += (self.original_elbo_weight * loss_vlb)
731
+ loss_dict.update({f'{prefix}/loss': loss})
732
+
733
+ return loss, loss_dict
734
+
735
+ def p_mean_variance(self, x, c, t, clip_denoised: bool, return_codebook_ids=False, quantize_denoised=False,
736
+ return_x0=False, score_corrector=None, corrector_kwargs=None):
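+ # Turn the model prediction (eps or x0, depending on the parameterization) into the posterior mean and variance of one reverse diffusion step.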
737
+ t_in = t
738
+ model_out = self.apply_model(x, t_in, c, return_ids=return_codebook_ids)
739
+
740
+ if score_corrector is not None:
741
+ assert self.parameterization == "eps"
742
+ model_out = score_corrector.modify_score(self, model_out, x, t, c, **corrector_kwargs)
743
+
744
+ if return_codebook_ids:
745
+ model_out, logits = model_out
746
+
747
+ if self.parameterization == "eps":
748
+ x_recon = self.predict_start_from_noise(x, t=t, noise=model_out)
749
+ elif self.parameterization == "x0":
750
+ x_recon = model_out
751
+ else:
752
+ raise NotImplementedError()
753
+
754
+ if clip_denoised:
755
+ x_recon.clamp_(-1., 1.)
756
+ if quantize_denoised:
757
+ x_recon, _, [_, _, indices] = self.first_stage_model.quantize(x_recon)
758
+ model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t)
759
+ if return_codebook_ids:
760
+ return model_mean, posterior_variance, posterior_log_variance, logits
761
+ elif return_x0:
762
+ return model_mean, posterior_variance, posterior_log_variance, x_recon
763
+ else:
764
+ return model_mean, posterior_variance, posterior_log_variance
765
+
766
+ @torch.no_grad()
767
+ def p_sample(self, x, c, t, clip_denoised=False, repeat_noise=False,
768
+ return_codebook_ids=False, quantize_denoised=False, return_x0=False,
769
+ temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None):
770
+ b, *_, device = *x.shape, x.device
771
+ outputs = self.p_mean_variance(x=x, c=c, t=t, clip_denoised=clip_denoised,
772
+ return_codebook_ids=return_codebook_ids,
773
+ quantize_denoised=quantize_denoised,
774
+ return_x0=return_x0,
775
+ score_corrector=score_corrector, corrector_kwargs=corrector_kwargs)
776
+ if return_codebook_ids:
777
+ raise DeprecationWarning("Support dropped.")
778
+ model_mean, _, model_log_variance, logits = outputs
779
+ elif return_x0:
780
+ model_mean, _, model_log_variance, x0 = outputs
781
+ else:
782
+ model_mean, _, model_log_variance = outputs
783
+
784
+ noise = noise_like(x.shape, device, repeat_noise) * temperature
785
+ if noise_dropout > 0.:
786
+ noise = torch.nn.functional.dropout(noise, p=noise_dropout)
787
+ # no noise when t == 0
788
+ nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
789
+
790
+ if return_codebook_ids:
791
+ return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise, logits.argmax(dim=1)
792
+ if return_x0:
793
+ return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise, x0
794
+ else:
795
+ return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
796
+
797
+ @torch.no_grad()
798
+ def progressive_denoising(self, cond, shape, verbose=True, callback=None, quantize_denoised=False,
799
+ img_callback=None, mask=None, x0=None, temperature=1., noise_dropout=0.,
800
+ score_corrector=None, corrector_kwargs=None, batch_size=None, x_T=None, start_T=None,
801
+ log_every_t=None):
802
+ if not log_every_t:
803
+ log_every_t = self.log_every_t
804
+ timesteps = self.num_timesteps
805
+ if batch_size is not None:
806
+ b = batch_size if batch_size is not None else shape[0]
807
+ shape = [batch_size] + list(shape)
808
+ else:
809
+ b = batch_size = shape[0]
810
+ if x_T is None:
811
+ img = torch.randn(shape, device=self.device)
812
+ else:
813
+ img = x_T
814
+ intermediates = []
815
+ if cond is not None:
816
+ if isinstance(cond, dict):
817
+ cond = {key: cond[key][:batch_size] if not isinstance(cond[key], list) else
818
+ list(map(lambda x: x[:batch_size], cond[key])) for key in cond}
819
+ else:
820
+ cond = [c[:batch_size] for c in cond] if isinstance(cond, list) else cond[:batch_size]
821
+
822
+ if start_T is not None:
823
+ timesteps = min(timesteps, start_T)
824
+ iterator = tqdm(reversed(range(0, timesteps)), desc='Progressive Generation',
825
+ total=timesteps) if verbose else reversed(
826
+ range(0, timesteps))
827
+ if type(temperature) == float:
828
+ temperature = [temperature] * timesteps
829
+
830
+ for i in iterator:
831
+ ts = torch.full((b,), i, device=self.device, dtype=torch.long)
832
+ if self.shorten_cond_schedule:
833
+ assert self.model.conditioning_key != 'hybrid'
834
+ tc = self.cond_ids[ts].to(cond.device)
835
+ cond = self.q_sample(x_start=cond, t=tc, noise=torch.randn_like(cond))
836
+
837
+ img, x0_partial = self.p_sample(img, cond, ts,
838
+ clip_denoised=self.clip_denoised,
839
+ quantize_denoised=quantize_denoised, return_x0=True,
840
+ temperature=temperature[i], noise_dropout=noise_dropout,
841
+ score_corrector=score_corrector, corrector_kwargs=corrector_kwargs)
842
+ if mask is not None:
843
+ assert x0 is not None
844
+ img_orig = self.q_sample(x0, ts)
845
+ img = img_orig * mask + (1. - mask) * img
846
+
847
+ if i % log_every_t == 0 or i == timesteps - 1:
848
+ intermediates.append(x0_partial)
849
+ if callback: callback(i)
850
+ if img_callback: img_callback(img, i)
851
+ return img, intermediates
852
+
853
+ @torch.no_grad()
854
+ def p_sample_loop(self, cond, shape, return_intermediates=False,
855
+ x_T=None, verbose=True, callback=None, timesteps=None, quantize_denoised=False,
856
+ mask=None, x0=None, img_callback=None, start_T=None,
857
+ log_every_t=None):
858
+
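+ # Ancestral DDPM sampling: start from Gaussian noise (or x_T) and step from t = T-1 down to 0, optionally blending in a known region via mask/x0.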
859
+ if not log_every_t:
860
+ log_every_t = self.log_every_t
861
+ device = self.betas.device
862
+ b = shape[0]
863
+ if x_T is None:
864
+ img = torch.randn(shape, device=device)
865
+ else:
866
+ img = x_T
867
+
868
+ intermediates = [img]
869
+ if timesteps is None:
870
+ timesteps = self.num_timesteps
871
+
872
+ if start_T is not None:
873
+ timesteps = min(timesteps, start_T)
874
+ iterator = tqdm(reversed(range(0, timesteps)), desc='Sampling t', total=timesteps) if verbose else reversed(
875
+ range(0, timesteps))
876
+
877
+ if mask is not None:
878
+ assert x0 is not None
879
+ assert x0.shape[2:3] == mask.shape[2:3] # spatial size has to match
880
+
881
+ for i in iterator:
882
+ ts = torch.full((b,), i, device=device, dtype=torch.long)
883
+ if self.shorten_cond_schedule:
884
+ assert self.model.conditioning_key != 'hybrid'
885
+ tc = self.cond_ids[ts].to(cond.device)
886
+ cond = self.q_sample(x_start=cond, t=tc, noise=torch.randn_like(cond))
887
+
888
+ img = self.p_sample(img, cond, ts,
889
+ clip_denoised=self.clip_denoised,
890
+ quantize_denoised=quantize_denoised)
891
+ if mask is not None:
892
+ img_orig = self.q_sample(x0, ts)
893
+ img = img_orig * mask + (1. - mask) * img
894
+
895
+ if i % log_every_t == 0 or i == timesteps - 1:
896
+ intermediates.append(img)
897
+ if callback: callback(i)
898
+ if img_callback: img_callback(img, i)
899
+
900
+ if return_intermediates:
901
+ return img, intermediates
902
+ return img
903
+
904
+ @torch.no_grad()
905
+ def sample(self, cond, batch_size=16, return_intermediates=False, x_T=None,
906
+ verbose=True, timesteps=None, quantize_denoised=False,
907
+ mask=None, x0=None, shape=None,**kwargs):
908
+ if shape is None:
909
+ shape = (batch_size, self.channels, self.mel_dim, self.mel_length)
910
+ if cond is not None:
911
+ if isinstance(cond, dict):
912
+ cond = {key: cond[key][:batch_size] if not isinstance(cond[key], list) else
913
+ list(map(lambda x: x[:batch_size], cond[key])) for key in cond}
914
+ else:
915
+ cond = [c[:batch_size] for c in cond] if isinstance(cond, list) else cond[:batch_size]
916
+ return self.p_sample_loop(cond,
917
+ shape,
918
+ return_intermediates=return_intermediates, x_T=x_T,
919
+ verbose=verbose, timesteps=timesteps, quantize_denoised=quantize_denoised,
920
+ mask=mask, x0=x0)
921
+
922
+ @torch.no_grad()
923
+ def sample_log(self,cond,batch_size,ddim, ddim_steps,**kwargs):
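+ # Sample either with the faster DDIM sampler or with the full DDPM loop, returning intermediates for logging.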
924
+ if ddim:
925
+ ddim_sampler = DDIMSampler(self)
926
+ shape = (self.channels, self.mel_dim, self.mel_length)
927
+ samples, intermediates =ddim_sampler.sample(ddim_steps,batch_size,
928
+ shape,cond,verbose=False,**kwargs)
929
+
930
+ else:
931
+ samples, intermediates = self.sample(cond=cond, batch_size=batch_size,
932
+ return_intermediates=True,**kwargs)
933
+
934
+ return samples, intermediates
935
+
936
+ @torch.no_grad()
937
+ def log_images(self, batch, N=8, n_row=4, sample=True, ddim_steps=200, ddim_eta=1., return_keys=None,
938
+ quantize_denoised=True, inpaint=True, plot_denoise_rows=False, plot_progressive_rows=True,
939
+ plot_diffusion_rows=True, **kwargs):
940
+
941
+ use_ddim = ddim_steps is not None
942
+
943
+ log = dict()
944
+ z, c, x, xrec, xc = self.get_input(batch, self.first_stage_key,
945
+ return_first_stage_outputs=True,
946
+ force_c_encode=True,
947
+ return_original_cond=True,
948
+ bs=N)
949
+
950
+ N = min(x.shape[0], N)
951
+ n_row = min(x.shape[0], n_row)
952
+ log["inputs"] = x # 原始输入图像
953
+ log["reconstruction"] = xrec # 重建得到的图像
954
+ if self.model.conditioning_key is not None:
955
+ if hasattr(self.cond_stage_model, "decode"):# when cond_stage is first_stage. (bert embedder doesnot have decode)
956
+ xc = self.cond_stage_model.decode(c)# decoded masked image
957
+ log["conditioning"] = xc # 重建后的图像
958
+ elif self.cond_stage_key in ["caption"]:
959
+ xc = log_txt_as_img((x.shape[2], x.shape[3]), batch["caption"])
960
+ log["conditioning"] = xc # 含有文本的图像
961
+ if self.model.conditioning_key == 'hybrid':
962
+ log["decoded_maskedimg"] = self.first_stage_model.decode(c['c_concat'][:,:self.first_stage_model.embed_dim])# c_concat is the concat result of masked_img latent and resized mask. get latent here to decode
963
+ elif self.cond_stage_key == 'class_label':
964
+ xc = log_txt_as_img((x.shape[2], x.shape[3]), batch["human_label"])
965
+ log['conditioning'] = xc  # image rendering the class-label text
966
+ elif isimage(xc):
967
+ log["conditioning"] = xc
968
+ if ismap(xc):
969
+ log["original_conditioning"] = self.to_rgb(xc)
970
+
971
+ if plot_diffusion_rows:  # images at each diffusion step
972
+ # get diffusion row
973
+ diffusion_row = list()
974
+ z_start = z[:n_row]
975
+ for t in range(self.num_timesteps):
976
+ if t % self.log_every_t == 0 or t == self.num_timesteps - 1:
977
+ t = repeat(torch.tensor([t]), '1 -> b', b=n_row)
978
+ t = t.to(self.device).long()
979
+ noise = torch.randn_like(z_start)
980
+ z_noisy = self.q_sample(x_start=z_start, t=t, noise=noise)
981
+ diffusion_row.append(self.decode_first_stage(z_noisy))
982
+
983
+ diffusion_row = torch.stack(diffusion_row) # n_log_step, n_row, C, H, W
984
+ diffusion_grid = rearrange(diffusion_row, 'n b c h w -> b n c h w')
985
+ diffusion_grid = rearrange(diffusion_grid, 'b n c h w -> (b n) c h w')
986
+ diffusion_grid = make_grid(diffusion_grid, nrow=diffusion_row.shape[0])
987
+ log["diffusion_row"] = diffusion_grid
988
+
989
+ if sample:
990
+ # get denoise row
991
+ with self.ema_scope("Plotting"):
992
+ samples, z_denoise_row = self.sample_log(cond=c,batch_size=N,ddim=use_ddim,
993
+ ddim_steps=ddim_steps,eta=ddim_eta)
994
+ # samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True)
995
+ x_samples = self.decode_first_stage(samples)
996
+ log["samples"] = x_samples
997
+ if plot_denoise_rows:
998
+ denoise_grid = self._get_denoise_row_from_list(z_denoise_row)
999
+ log["denoise_row"] = denoise_grid
1000
+
1001
+ if quantize_denoised and not isinstance(self.first_stage_model, AutoencoderKL) and not isinstance(
1002
+ self.first_stage_model, IdentityFirstStage):
1003
+ # also display when quantizing x0 while sampling
1004
+ with self.ema_scope("Plotting Quantized Denoised"):
1005
+ samples, z_denoise_row = self.sample_log(cond=c,batch_size=N,ddim=use_ddim,
1006
+ ddim_steps=ddim_steps,eta=ddim_eta,
1007
+ quantize_denoised=True)
1008
+ # samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True,
1009
+ # quantize_denoised=True)
1010
+ x_samples = self.decode_first_stage(samples.to(self.device))
1011
+ log["samples_x0_quantized"] = x_samples
1012
+
1013
+ if inpaint:
1014
+ # make a simple center square
1015
+ b, h, w = z.shape[0], z.shape[2], z.shape[3]
1016
+ mask = torch.ones(N, h, w).to(self.device)
1017
+ # zeros will be filled in
1018
+ mask[:, h // 4:3 * h // 4, w // 4:3 * w // 4] = 0.
1019
+ mask = mask[:, None, ...]# N,1,H,W
1020
+ with self.ema_scope("Plotting Inpaint"):
1021
+ samples, _ = self.sample_log(cond=c,batch_size=N,ddim=use_ddim, eta=ddim_eta,
1022
+ ddim_steps=ddim_steps, x0=z[:N], mask=mask)
1023
+ x_samples = self.decode_first_stage(samples.to(self.device))
1024
+ log["samples_inpainting"] = x_samples
1025
+ log["mask"] = mask
1026
+
1027
+ # outpaint
1028
+ with self.ema_scope("Plotting Outpaint"):
1029
+ samples, _ = self.sample_log(cond=c, batch_size=N, ddim=use_ddim,eta=ddim_eta,
1030
+ ddim_steps=ddim_steps, x0=z[:N], mask=mask)
1031
+ x_samples = self.decode_first_stage(samples.to(self.device))
1032
+ log["samples_outpainting"] = x_samples
1033
+
1034
+ if plot_progressive_rows:
1035
+ with self.ema_scope("Plotting Progressives"):
1036
+ img, progressives = self.progressive_denoising(c,
1037
+ shape=(self.channels, self.mel_dim, self.mel_length),
1038
+ batch_size=N)
1039
+ prog_row = self._get_denoise_row_from_list(progressives, desc="Progressive Generation")
1040
+ log["progressive_row"] = prog_row
1041
+
1042
+ if return_keys:
1043
+ if np.intersect1d(list(log.keys()), return_keys).shape[0] == 0:
1044
+ return log
1045
+ else:
1046
+ return {key: log[key] for key in return_keys}
1047
+ return log
1048
+
1049
+ def configure_optimizers(self):
1050
+ lr = self.learning_rate
1051
+ params = list(self.model.parameters())
1052
+ if self.cond_stage_trainable:
1053
+ print(f"{self.__class__.__name__}: Also optimizing conditioner params!")
1054
+ params = params + list(self.cond_stage_model.parameters())
1055
+ if self.learn_logvar:
1056
+ print('Diffusion model optimizing logvar')
1057
+ params.append(self.logvar)
1058
+ opt = torch.optim.AdamW(params, lr=lr)
1059
+ if self.use_scheduler:
1060
+ assert 'target' in self.scheduler_config
1061
+ scheduler = instantiate_from_config(self.scheduler_config)
1062
+
1063
+ print("Setting up LambdaLR scheduler...")
1064
+ scheduler = [
1065
+ {
1066
+ 'scheduler': LambdaLR(opt, lr_lambda=scheduler.schedule),
1067
+ 'interval': 'step',
1068
+ 'frequency': 1
1069
+ }]
1070
+ return [opt], scheduler
1071
+ return opt
1072
+
1073
+ @torch.no_grad()
1074
+ def to_rgb(self, x):
1075
+ x = x.float()
1076
+ if not hasattr(self, "colorize"):
1077
+ self.colorize = torch.randn(3, x.shape[1], 1, 1).to(x)
1078
+ x = nn.functional.conv2d(x, weight=self.colorize)
1079
+ x = 2. * (x - x.min()) / (x.max() - x.min()) - 1.
1080
+ return x
1081
+
ldm/models/diffusion/plms.py ADDED
@@ -0,0 +1,236 @@
1
+ """SAMPLING ONLY."""
2
+
3
+ import torch
4
+ import numpy as np
5
+ from tqdm import tqdm
6
+ from functools import partial
7
+
8
+ from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like
9
+
10
+
11
+ class PLMSSampler(object):
12
+ def __init__(self, model, schedule="linear", **kwargs):
13
+ super().__init__()
14
+ self.model = model
15
+ self.ddpm_num_timesteps = model.num_timesteps
16
+ self.schedule = schedule
17
+
18
+ def register_buffer(self, name, attr):
19
+ if type(attr) == torch.Tensor:
20
+ if attr.device != torch.device("cuda"):
21
+ attr = attr.to(torch.device("cuda"))
22
+ setattr(self, name, attr)
23
+
24
+ def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True):
25
+ if ddim_eta != 0:
26
+ raise ValueError('ddim_eta must be 0 for PLMS')
27
+ self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps,
28
+ num_ddpm_timesteps=self.ddpm_num_timesteps,verbose=verbose)
29
+ alphas_cumprod = self.model.alphas_cumprod
30
+ assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep'
31
+ to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device)
32
+
33
+ self.register_buffer('betas', to_torch(self.model.betas))
34
+ self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
35
+ self.register_buffer('alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev))
36
+
37
+ # calculations for diffusion q(x_t | x_{t-1}) and others
38
+ self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu())))
39
+ self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod.cpu())))
40
+ self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod.cpu())))
41
+ self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu())))
42
+ self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1)))
43
+
44
+ # ddim sampling parameters
45
+ ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(),
46
+ ddim_timesteps=self.ddim_timesteps,
47
+ eta=ddim_eta,verbose=verbose)
48
+ self.register_buffer('ddim_sigmas', ddim_sigmas)
49
+ self.register_buffer('ddim_alphas', ddim_alphas)
50
+ self.register_buffer('ddim_alphas_prev', ddim_alphas_prev)
51
+ self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas))
52
+ sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
53
+ (1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * (
54
+ 1 - self.alphas_cumprod / self.alphas_cumprod_prev))
55
+ self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps)
56
+
57
+ @torch.no_grad()
58
+ def sample(self,
59
+ S,
60
+ batch_size,
61
+ shape,
62
+ conditioning=None,
63
+ callback=None,
64
+ normals_sequence=None,
65
+ img_callback=None,
66
+ quantize_x0=False,
67
+ eta=0.,
68
+ mask=None,
69
+ x0=None,
70
+ temperature=1.,
71
+ noise_dropout=0.,
72
+ score_corrector=None,
73
+ corrector_kwargs=None,
74
+ verbose=True,
75
+ x_T=None,
76
+ log_every_t=100,
77
+ unconditional_guidance_scale=1.,
78
+ unconditional_conditioning=None,
79
+ # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
80
+ **kwargs
81
+ ):
82
+ if conditioning is not None:
83
+ if isinstance(conditioning, dict):
84
+ cbs = conditioning[list(conditioning.keys())[0]].shape[0]
85
+ if cbs != batch_size:
86
+ print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
87
+ else:
88
+ if conditioning.shape[0] != batch_size:
89
+ print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")
90
+
91
+ self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose)
92
+ # sampling
93
+ C, H, W = shape
94
+ size = (batch_size, C, H, W)
95
+ print(f'Data shape for PLMS sampling is {size}')
96
+
97
+ samples, intermediates = self.plms_sampling(conditioning, size,
98
+ callback=callback,
99
+ img_callback=img_callback,
100
+ quantize_denoised=quantize_x0,
101
+ mask=mask, x0=x0,
102
+ ddim_use_original_steps=False,
103
+ noise_dropout=noise_dropout,
104
+ temperature=temperature,
105
+ score_corrector=score_corrector,
106
+ corrector_kwargs=corrector_kwargs,
107
+ x_T=x_T,
108
+ log_every_t=log_every_t,
109
+ unconditional_guidance_scale=unconditional_guidance_scale,
110
+ unconditional_conditioning=unconditional_conditioning,
111
+ )
112
+ return samples, intermediates
113
+
114
+ @torch.no_grad()
115
+ def plms_sampling(self, cond, shape,
116
+ x_T=None, ddim_use_original_steps=False,
117
+ callback=None, timesteps=None, quantize_denoised=False,
118
+ mask=None, x0=None, img_callback=None, log_every_t=100,
119
+ temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
120
+ unconditional_guidance_scale=1., unconditional_conditioning=None,):
121
+ device = self.model.betas.device
122
+ b = shape[0]
123
+ if x_T is None:
124
+ img = torch.randn(shape, device=device)
125
+ else:
126
+ img = x_T
127
+
128
+ if timesteps is None:
129
+ timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps
130
+ elif timesteps is not None and not ddim_use_original_steps:
131
+ subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1
132
+ timesteps = self.ddim_timesteps[:subset_end]
133
+
134
+ intermediates = {'x_inter': [img], 'pred_x0': [img]}
135
+ time_range = list(reversed(range(0,timesteps))) if ddim_use_original_steps else np.flip(timesteps)
136
+ total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]
137
+ print(f"Running PLMS Sampling with {total_steps} timesteps")
138
+
139
+ iterator = tqdm(time_range, desc='PLMS Sampler', total=total_steps)
140
+ old_eps = []
141
+
142
+ for i, step in enumerate(iterator):
143
+ index = total_steps - i - 1
144
+ ts = torch.full((b,), step, device=device, dtype=torch.long)
145
+ ts_next = torch.full((b,), time_range[min(i + 1, len(time_range) - 1)], device=device, dtype=torch.long)
146
+
147
+ if mask is not None:
148
+ assert x0 is not None
149
+ img_orig = self.model.q_sample(x0, ts) # TODO: deterministic forward pass?
150
+ img = img_orig * mask + (1. - mask) * img
151
+
152
+ outs = self.p_sample_plms(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps,
153
+ quantize_denoised=quantize_denoised, temperature=temperature,
154
+ noise_dropout=noise_dropout, score_corrector=score_corrector,
155
+ corrector_kwargs=corrector_kwargs,
156
+ unconditional_guidance_scale=unconditional_guidance_scale,
157
+ unconditional_conditioning=unconditional_conditioning,
158
+ old_eps=old_eps, t_next=ts_next)
159
+ img, pred_x0, e_t = outs
160
+ old_eps.append(e_t)
161
+ if len(old_eps) >= 4:
162
+ old_eps.pop(0)
163
+ if callback: callback(i)
164
+ if img_callback: img_callback(pred_x0, i)
165
+
166
+ if index % log_every_t == 0 or index == total_steps - 1:
167
+ intermediates['x_inter'].append(img)
168
+ intermediates['pred_x0'].append(pred_x0)
169
+
170
+ return img, intermediates
171
+
172
+ @torch.no_grad()
173
+ def p_sample_plms(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
174
+ temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
175
+ unconditional_guidance_scale=1., unconditional_conditioning=None, old_eps=None, t_next=None):
176
+ b, *_, device = *x.shape, x.device
177
+
178
+ def get_model_output(x, t):
179
+ if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
180
+ e_t = self.model.apply_model(x, t, c)
181
+ else:
182
+ x_in = torch.cat([x] * 2)
183
+ t_in = torch.cat([t] * 2)
184
+ c_in = torch.cat([unconditional_conditioning, c])
185
+ e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2)
186
+ e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)
187
+
188
+ if score_corrector is not None:
189
+ assert self.model.parameterization == "eps"
190
+ e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs)
191
+
192
+ return e_t
193
+
194
+ alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
195
+ alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev
196
+ sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas
197
+ sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas
198
+
199
+ def get_x_prev_and_pred_x0(e_t, index):
200
+ # select parameters corresponding to the currently considered timestep
201
+ a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
202
+ a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
203
+ sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
204
+ sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device)
205
+
206
+ # current prediction for x_0
207
+ pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
208
+ if quantize_denoised:
209
+ pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
210
+ # direction pointing to x_t
211
+ dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t
212
+ noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
213
+ if noise_dropout > 0.:
214
+ noise = torch.nn.functional.dropout(noise, p=noise_dropout)
215
+ x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
216
+ return x_prev, pred_x0
217
+
218
+ e_t = get_model_output(x, t)
219
+ if len(old_eps) == 0:
220
+ # Pseudo Improved Euler (2nd order)
221
+ x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t, index)
222
+ e_t_next = get_model_output(x_prev, t_next)
223
+ e_t_prime = (e_t + e_t_next) / 2
224
+ elif len(old_eps) == 1:
225
+ # 2nd order Pseudo Linear Multistep (Adams-Bashforth)
226
+ e_t_prime = (3 * e_t - old_eps[-1]) / 2
227
+ elif len(old_eps) == 2:
228
+ # 3rd order Pseudo Linear Multistep (Adams-Bashforth)
229
+ e_t_prime = (23 * e_t - 16 * old_eps[-1] + 5 * old_eps[-2]) / 12
230
+ elif len(old_eps) >= 3:
231
+ # 4th order Pseudo Linear Multistep (Adams-Bashforth)
232
+ e_t_prime = (55 * e_t - 59 * old_eps[-1] + 37 * old_eps[-2] - 9 * old_eps[-3]) / 24
233
+
234
+ x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t_prime, index)
235
+
236
+ return x_prev, pred_x0, e_t
ldm/modules/attention.py ADDED
@@ -0,0 +1,261 @@
1
+ from inspect import isfunction
2
+ import math
3
+ import torch
4
+ import torch.nn.functional as F
5
+ from torch import nn, einsum
6
+ from einops import rearrange, repeat
7
+
8
+ from ldm.modules.diffusionmodules.util import checkpoint
9
+
10
+
11
+ def exists(val):
12
+ return val is not None
13
+
14
+
15
+ def uniq(arr):
16
+ return{el: True for el in arr}.keys()
17
+
18
+
19
+ def default(val, d):
20
+ if exists(val):
21
+ return val
22
+ return d() if isfunction(d) else d
23
+
24
+
25
+ def max_neg_value(t):
26
+ return -torch.finfo(t.dtype).max
27
+
28
+
29
+ def init_(tensor):
30
+ dim = tensor.shape[-1]
31
+ std = 1 / math.sqrt(dim)
32
+ tensor.uniform_(-std, std)
33
+ return tensor
34
+
35
+
36
+ # feedforward
37
+ class GEGLU(nn.Module):
38
+ def __init__(self, dim_in, dim_out):
39
+ super().__init__()
40
+ self.proj = nn.Linear(dim_in, dim_out * 2)
41
+
42
+ def forward(self, x):
43
+ x, gate = self.proj(x).chunk(2, dim=-1)
44
+ return x * F.gelu(gate)
45
+
46
+
47
+ class FeedForward(nn.Module):
48
+ def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.):
49
+ super().__init__()
50
+ inner_dim = int(dim * mult)
51
+ dim_out = default(dim_out, dim)
52
+ project_in = nn.Sequential(
53
+ nn.Linear(dim, inner_dim),
54
+ nn.GELU()
55
+ ) if not glu else GEGLU(dim, inner_dim)
56
+
57
+ self.net = nn.Sequential(
58
+ project_in,
59
+ nn.Dropout(dropout),
60
+ nn.Linear(inner_dim, dim_out)
61
+ )
62
+
63
+ def forward(self, x):
64
+ return self.net(x)
65
+
66
+
67
+ def zero_module(module):
68
+ """
69
+ Zero out the parameters of a module and return it.
70
+ """
71
+ for p in module.parameters():
72
+ p.detach().zero_()
73
+ return module
74
+
75
+
76
+ def Normalize(in_channels):
77
+ return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
78
+
79
+
80
+ class LinearAttention(nn.Module):
81
+ def __init__(self, dim, heads=4, dim_head=32):
82
+ super().__init__()
83
+ self.heads = heads
84
+ hidden_dim = dim_head * heads
85
+ self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias = False)
86
+ self.to_out = nn.Conv2d(hidden_dim, dim, 1)
87
+
88
+ def forward(self, x):
89
+ b, c, h, w = x.shape
90
+ qkv = self.to_qkv(x)
91
+ q, k, v = rearrange(qkv, 'b (qkv heads c) h w -> qkv b heads c (h w)', heads = self.heads, qkv=3)
92
+ k = k.softmax(dim=-1)
93
+ context = torch.einsum('bhdn,bhen->bhde', k, v)
94
+ out = torch.einsum('bhde,bhdn->bhen', context, q)
95
+ out = rearrange(out, 'b heads c (h w) -> b (heads c) h w', heads=self.heads, h=h, w=w)
96
+ return self.to_out(out)
97
+
98
+
99
+ class SpatialSelfAttention(nn.Module):
100
+ def __init__(self, in_channels):
101
+ super().__init__()
102
+ self.in_channels = in_channels
103
+
104
+ self.norm = Normalize(in_channels)
105
+ self.q = torch.nn.Conv2d(in_channels,
106
+ in_channels,
107
+ kernel_size=1,
108
+ stride=1,
109
+ padding=0)
110
+ self.k = torch.nn.Conv2d(in_channels,
111
+ in_channels,
112
+ kernel_size=1,
113
+ stride=1,
114
+ padding=0)
115
+ self.v = torch.nn.Conv2d(in_channels,
116
+ in_channels,
117
+ kernel_size=1,
118
+ stride=1,
119
+ padding=0)
120
+ self.proj_out = torch.nn.Conv2d(in_channels,
121
+ in_channels,
122
+ kernel_size=1,
123
+ stride=1,
124
+ padding=0)
125
+
126
+ def forward(self, x):
127
+ h_ = x
128
+ h_ = self.norm(h_)
129
+ q = self.q(h_)
130
+ k = self.k(h_)
131
+ v = self.v(h_)
132
+
133
+ # compute attention
134
+ b,c,h,w = q.shape
135
+ q = rearrange(q, 'b c h w -> b (h w) c')
136
+ k = rearrange(k, 'b c h w -> b c (h w)')
137
+ w_ = torch.einsum('bij,bjk->bik', q, k)
138
+
139
+ w_ = w_ * (int(c)**(-0.5))
140
+ w_ = torch.nn.functional.softmax(w_, dim=2)
141
+
142
+ # attend to values
143
+ v = rearrange(v, 'b c h w -> b c (h w)')
144
+ w_ = rearrange(w_, 'b i j -> b j i')
145
+ h_ = torch.einsum('bij,bjk->bik', v, w_)
146
+ h_ = rearrange(h_, 'b c (h w) -> b c h w', h=h)
147
+ h_ = self.proj_out(h_)
148
+
149
+ return x+h_
150
+
151
+
152
+ class CrossAttention(nn.Module):
153
+ def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.):  # if context_dim is given this becomes cross-attention rather than self-attention
154
+ super().__init__()
155
+ inner_dim = dim_head * heads # inner_dim == SpatialTransformer.model_channels
156
+ context_dim = default(context_dim, query_dim)
157
+
158
+ self.scale = dim_head ** -0.5
159
+ self.heads = heads
160
+
161
+ self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
162
+ self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
163
+ self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
164
+
165
+ self.to_out = nn.Sequential(
166
+ nn.Linear(inner_dim, query_dim),
167
+ nn.Dropout(dropout)
168
+ )
169
+
170
+ def forward(self, x, context=None, mask=None):# x:(b,h*w,c), context:(b,seq_len,context_dim)
171
+ h = self.heads
172
+
173
+ q = self.to_q(x)# q:(b,h*w,inner_dim)
174
+ context = default(context, x)
175
+ k = self.to_k(context)# (b,seq_len,inner_dim)
176
+ v = self.to_v(context)# (b,seq_len,inner_dim)
177
+
178
+ q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))# n is seq_len for k and v
179
+
180
+ sim = einsum('b i d, b j d -> b i j', q, k) * self.scale # (b*head,h*w,seq_len)
181
+
182
+ if exists(mask):# false
183
+ mask = rearrange(mask, 'b ... -> b (...)')
184
+ max_neg_value = -torch.finfo(sim.dtype).max
185
+ mask = repeat(mask, 'b j -> (b h) () j', h=h)
186
+ sim.masked_fill_(~mask, max_neg_value)
187
+
188
+ # attention, what we cannot get enough of
189
+ attn = sim.softmax(dim=-1)
190
+
191
+ out = einsum('b i j, b j d -> b i d', attn, v)# (b*head,h*w,inner_dim/head)
192
+ out = rearrange(out, '(b h) n d -> b n (h d)', h=h)# (b,h*w,inner_dim)
193
+ return self.to_out(out)
194
+
195
+
196
+ class BasicTransformerBlock(nn.Module):
197
+ def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True):
198
+ super().__init__()
199
+ self.attn1 = CrossAttention(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout) # is a self-attention
200
+ self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
201
+ self.attn2 = CrossAttention(query_dim=dim, context_dim=context_dim,
202
+ heads=n_heads, dim_head=d_head, dropout=dropout) # is self-attn if context is none
203
+ self.norm1 = nn.LayerNorm(dim)
204
+ self.norm2 = nn.LayerNorm(dim)
205
+ self.norm3 = nn.LayerNorm(dim)
206
+ self.checkpoint = checkpoint
207
+
208
+ def forward(self, x, context=None):
209
+ return checkpoint(self._forward, (x, context), self.parameters(), self.checkpoint)
210
+
211
+ def _forward(self, x, context=None):
212
+ x = self.attn1(self.norm1(x)) + x
213
+ x = self.attn2(self.norm2(x), context=context) + x
214
+ x = self.ff(self.norm3(x)) + x
215
+ return x
216
+
217
+
218
+ class SpatialTransformer(nn.Module):
219
+ """
220
+ Transformer block for image-like data.
221
+ First, project the input (aka embedding)
222
+ and reshape to b, t, d.
223
+ Then apply standard transformer action.
224
+ Finally, reshape to image
225
+ """
226
+ def __init__(self, in_channels, n_heads, d_head,
227
+ depth=1, dropout=0., context_dim=None):
228
+ super().__init__()
229
+ self.in_channels = in_channels
230
+ inner_dim = n_heads * d_head
231
+ self.norm = Normalize(in_channels)
232
+
233
+ self.proj_in = nn.Conv2d(in_channels,
234
+ inner_dim,
235
+ kernel_size=1,
236
+ stride=1,
237
+ padding=0)
238
+
239
+ self.transformer_blocks = nn.ModuleList(
240
+ [BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim)
241
+ for d in range(depth)]
242
+ )
243
+
244
+ self.proj_out = zero_module(nn.Conv2d(inner_dim,
245
+ in_channels,
246
+ kernel_size=1,
247
+ stride=1,
248
+ padding=0))
249
+
250
+ def forward(self, x, context=None):
251
+ # note: if no context is given, cross-attention defaults to self-attention
252
+ b, c, h, w = x.shape # such as [2,320,10,106]
253
+ x_in = x
254
+ x = self.norm(x)# group norm
255
+ x = self.proj_in(x)# no shape change
256
+ x = rearrange(x, 'b c h w -> b (h w) c')
257
+ for block in self.transformer_blocks:
258
+ x = block(x, context=context)# context shape [b,seq_len=77,context_dim]
259
+ x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w)
260
+ x = self.proj_out(x)
261
+ return x + x_in
ldm/modules/diffusionmodules/__init__.py ADDED
File without changes
ldm/modules/diffusionmodules/custom_openaimodel.py ADDED
@@ -0,0 +1,368 @@
1
+ from abc import abstractmethod
2
+ from functools import partial
3
+ import math
4
+ from typing import Iterable
5
+
6
+ import numpy as np
7
+ import torch as th
8
+ import torch.nn as nn
9
+ import torch.nn.functional as F
10
+
11
+ from ldm.modules.diffusionmodules.util import (
12
+ checkpoint,
13
+ conv_nd,
14
+ linear,
15
+ avg_pool_nd,
16
+ zero_module,
17
+ normalization,
18
+ timestep_embedding,
19
+ )
20
+ from ldm.modules.attention import SpatialTransformer
21
+ from ldm.modules.diffusionmodules.openaimodel import convert_module_to_f16, convert_module_to_f32, AttentionPool2d, \
22
+ TimestepBlock, TimestepEmbedSequential, Upsample, TransposedUpsample, Downsample, ResBlock, AttentionBlock, count_flops_attn, \
23
+ QKVAttentionLegacy, QKVAttention
24
+
25
+
26
+ class UNetModel(nn.Module):
27
+ """
28
+ The full UNet model with attention and timestep embedding.
29
+ :param in_channels: channels in the input Tensor.
30
+ :param model_channels: base channel count for the model.
31
+ :param out_channels: channels in the output Tensor.
32
+ :param num_res_blocks: number of residual blocks per downsample.
33
+ :param attention_resolutions: a collection of downsample rates at which
34
+ attention will take place. May be a set, list, or tuple.
35
+ For example, if this contains 4, then at 4x downsampling, attention
36
+ will be used.
37
+ :param dropout: the dropout probability.
38
+ :param channel_mult: channel multiplier for each level of the UNet.
39
+ :param conv_resample: if True, use learned convolutions for upsampling and
40
+ downsampling.
41
+ :param dims: determines if the signal is 1D, 2D, or 3D.
42
+ :param num_classes: if specified (as an int), then this model will be
43
+ class-conditional with `num_classes` classes.
44
+ :param use_checkpoint: use gradient checkpointing to reduce memory usage.
45
+ :param num_heads: the number of attention heads in each attention layer.
46
+ :param num_heads_channels: if specified, ignore num_heads and instead use
47
+ a fixed channel width per attention head.
48
+ :param num_heads_upsample: works with num_heads to set a different number
49
+ of heads for upsampling. Deprecated.
50
+ :param use_scale_shift_norm: use a FiLM-like conditioning mechanism.
51
+ :param resblock_updown: use residual blocks for up/downsampling.
52
+ :param use_new_attention_order: use a different attention pattern for potentially
53
+ increased efficiency.
54
+ """
55
+
56
+ def __init__(
57
+ self,
58
+ image_size,
59
+ in_channels,
60
+ model_channels,
61
+ out_channels,
62
+ num_res_blocks,
63
+ attention_resolutions,
64
+ dropout=0,
65
+ channel_mult=(1, 2, 4, 8),
66
+ conv_resample=True,
67
+ dims=2,
68
+ num_classes=None,
69
+ use_checkpoint=False,
70
+ use_fp16=False,
71
+ num_heads=-1,
72
+ num_head_channels=-1,
73
+ num_heads_upsample=-1,
74
+ use_scale_shift_norm=False,
75
+ resblock_updown=False,
76
+ use_new_attention_order=False,
77
+ use_spatial_transformer=False, # custom transformer support
78
+ transformer_depth=1, # custom transformer support
79
+ context_dim=None, # custom transformer support
80
+ n_embed=None, # custom support for prediction of discrete ids into codebook of first stage vq model
81
+ legacy=True,
82
+ use_context_project=False, # custom text to audio support
83
+ use_context_attn=True # custom text to audio support
84
+ ):
85
+ super().__init__()
86
+ if use_spatial_transformer:
87
+ assert context_dim is not None, 'Fool!! You forgot to include the dimension of your cross-attention conditioning...'
88
+
89
+ if context_dim is not None and not use_context_project:
90
+ assert use_spatial_transformer, 'Fool!! You forgot to use the spatial transformer for your cross-attention conditioning...'
91
+ from omegaconf.listconfig import ListConfig
92
+ if type(context_dim) == ListConfig:
93
+ context_dim = list(context_dim)
94
+
95
+ if num_heads_upsample == -1:
96
+ num_heads_upsample = num_heads
97
+
98
+ if num_heads == -1:
99
+ assert num_head_channels != -1, 'Either num_heads or num_head_channels has to be set'
100
+
101
+ if num_head_channels == -1:
102
+ assert num_heads != -1, 'Either num_heads or num_head_channels has to be set'
103
+
104
+ self.image_size = image_size
105
+ self.in_channels = in_channels
106
+ self.model_channels = model_channels
107
+ self.out_channels = out_channels
108
+ self.num_res_blocks = num_res_blocks
109
+ self.attention_resolutions = attention_resolutions
110
+ self.dropout = dropout
111
+ self.channel_mult = channel_mult
112
+ self.conv_resample = conv_resample
113
+ self.num_classes = num_classes
114
+ self.use_checkpoint = use_checkpoint
115
+ self.dtype = th.float16 if use_fp16 else th.float32
116
+ self.num_heads = num_heads
117
+ self.num_head_channels = num_head_channels
118
+ self.num_heads_upsample = num_heads_upsample
119
+ self.predict_codebook_ids = n_embed is not None
120
+
121
+ time_embed_dim = model_channels * 4
122
+ self.time_embed = nn.Sequential(
123
+ linear(model_channels, time_embed_dim),
124
+ nn.SiLU(),
125
+ linear(time_embed_dim, time_embed_dim),
126
+ )
127
+
128
+ if self.num_classes is not None:
129
+ self.label_emb = nn.Embedding(num_classes, time_embed_dim)
130
+
131
+ self.input_blocks = nn.ModuleList(
132
+ [
133
+ TimestepEmbedSequential(
134
+ conv_nd(dims, in_channels, model_channels, 3, padding=1)
135
+ )
136
+ ]
137
+ )
138
+ self._feature_size = model_channels
139
+ input_block_chans = [model_channels]
140
+ ch = model_channels
141
+ ds = 1
142
+ for level, mult in enumerate(channel_mult):
143
+ for _ in range(num_res_blocks):
144
+ layers = [
145
+ ResBlock(
146
+ ch,
147
+ time_embed_dim,
148
+ dropout,
149
+ out_channels=mult * model_channels,
150
+ dims=dims,
151
+ use_checkpoint=use_checkpoint,
152
+ use_scale_shift_norm=use_scale_shift_norm,
153
+ )
154
+ ]
155
+ ch = mult * model_channels
156
+ if ds in attention_resolutions:
157
+ if num_head_channels == -1:
158
+ dim_head = ch // num_heads
159
+ else:
160
+ num_heads = ch // num_head_channels
161
+ dim_head = num_head_channels
162
+ if legacy:
163
+ #num_heads = 1
164
+ dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
165
+ layers.append(
166
+ AttentionBlock(
167
+ ch,
168
+ use_checkpoint=use_checkpoint,
169
+ num_heads=num_heads,
170
+ num_head_channels=dim_head,
171
+ use_new_attention_order=use_new_attention_order,
172
+ ) if not use_spatial_transformer else SpatialTransformer(
173
+ ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim
174
+ )
175
+ )
176
+ self.input_blocks.append(TimestepEmbedSequential(*layers))
177
+ self._feature_size += ch
178
+ input_block_chans.append(ch)
179
+ if level != len(channel_mult) - 1:
180
+ out_ch = ch
181
+ self.input_blocks.append(
182
+ TimestepEmbedSequential(
183
+ ResBlock(
184
+ ch,
185
+ time_embed_dim,
186
+ dropout,
187
+ out_channels=out_ch,
188
+ dims=dims,
189
+ use_checkpoint=use_checkpoint,
190
+ use_scale_shift_norm=use_scale_shift_norm,
191
+ down=True,
192
+ )
193
+ if resblock_updown
194
+ else Downsample(
195
+ ch, conv_resample, dims=dims, out_channels=out_ch
196
+ )
197
+ )
198
+ )
199
+ ch = out_ch
200
+ input_block_chans.append(ch)
201
+ ds *= 2
202
+ self._feature_size += ch
203
+
204
+ if num_head_channels == -1:
205
+ dim_head = ch // num_heads
206
+ else:
207
+ num_heads = ch // num_head_channels
208
+ dim_head = num_head_channels
209
+ if legacy:
210
+ #num_heads = 1
211
+ dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
212
+ self.middle_block = TimestepEmbedSequential(
213
+ ResBlock(
214
+ ch,
215
+ time_embed_dim,
216
+ dropout,
217
+ dims=dims,
218
+ use_checkpoint=use_checkpoint,
219
+ use_scale_shift_norm=use_scale_shift_norm,
220
+ ),
221
+ AttentionBlock(
222
+ ch,
223
+ use_checkpoint=use_checkpoint,
224
+ num_heads=num_heads,
225
+ num_head_channels=dim_head,
226
+ use_new_attention_order=use_new_attention_order,
227
+ ) if not use_spatial_transformer else SpatialTransformer(
228
+ ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim
229
+ ),
230
+ ResBlock(
231
+ ch,
232
+ time_embed_dim,
233
+ dropout,
234
+ dims=dims,
235
+ use_checkpoint=use_checkpoint,
236
+ use_scale_shift_norm=use_scale_shift_norm,
237
+ ),
238
+ )
239
+ self._feature_size += ch
240
+
241
+ self.output_blocks = nn.ModuleList([])
242
+ for level, mult in list(enumerate(channel_mult))[::-1]:
243
+ for i in range(num_res_blocks + 1):
244
+ ich = input_block_chans.pop()
245
+ layers = [
246
+ ResBlock(
247
+ ch + ich,
248
+ time_embed_dim,
249
+ dropout,
250
+ out_channels=model_channels * mult,
251
+ dims=dims,
252
+ use_checkpoint=use_checkpoint,
253
+ use_scale_shift_norm=use_scale_shift_norm,
254
+ )
255
+ ]
256
+ ch = model_channels * mult
257
+ if ds in attention_resolutions:
258
+ if num_head_channels == -1:
259
+ dim_head = ch // num_heads
260
+ else:
261
+ num_heads = ch // num_head_channels
262
+ dim_head = num_head_channels
263
+ if legacy:
264
+ #num_heads = 1
265
+ dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
266
+ layers.append(
267
+ AttentionBlock(
268
+ ch,
269
+ use_checkpoint=use_checkpoint,
270
+ num_heads=num_heads_upsample,
271
+ num_head_channels=dim_head,
272
+ use_new_attention_order=use_new_attention_order,
273
+ ) if not use_spatial_transformer else SpatialTransformer(
274
+ ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim
275
+ )
276
+ )
277
+ if level and i == num_res_blocks:
278
+ out_ch = ch
279
+ layers.append(
280
+ ResBlock(
281
+ ch,
282
+ time_embed_dim,
283
+ dropout,
284
+ out_channels=out_ch,
285
+ dims=dims,
286
+ use_checkpoint=use_checkpoint,
287
+ use_scale_shift_norm=use_scale_shift_norm,
288
+ up=True,
289
+ )
290
+ if resblock_updown
291
+ else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch)
292
+ )
293
+ ds //= 2
294
+ self.output_blocks.append(TimestepEmbedSequential(*layers))
295
+ self._feature_size += ch
296
+
297
+ self.out = nn.Sequential(
298
+ normalization(ch),
299
+ nn.SiLU(),
300
+ zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1)),
301
+ )
302
+ if self.predict_codebook_ids:
303
+ self.id_predictor = nn.Sequential(
304
+ normalization(ch),
305
+ conv_nd(dims, model_channels, n_embed, 1),
306
+ #nn.LogSoftmax(dim=1) # change to cross_entropy and produce non-normalized logits
307
+ )
308
+
309
+ self.use_context_project = use_context_project
310
+ if use_context_project:
311
+ self.context_project = linear(context_dim, time_embed_dim)
312
+ self.use_context_attn = use_context_attn
313
+
314
+
315
+ def convert_to_fp16(self):
316
+ """
317
+ Convert the torso of the model to float16.
318
+ """
319
+ self.input_blocks.apply(convert_module_to_f16)
320
+ self.middle_block.apply(convert_module_to_f16)
321
+ self.output_blocks.apply(convert_module_to_f16)
322
+
323
+ def convert_to_fp32(self):
324
+ """
325
+ Convert the torso of the model to float32.
326
+ """
327
+ self.input_blocks.apply(convert_module_to_f32)
328
+ self.middle_block.apply(convert_module_to_f32)
329
+ self.output_blocks.apply(convert_module_to_f32)
330
+
331
+ def forward(self, x, timesteps=None, context=None, y=None,**kwargs):
332
+ """
333
+ Apply the model to an input batch.
334
+ :param x: an [N x C x ...] Tensor of inputs.
335
+ :param timesteps: a 1-D batch of timesteps.
336
+ :param context: conditioning plugged in via crossattn
337
+ :param y: an [N] Tensor of labels, if class-conditional.
338
+ :return: an [N x C x ...] Tensor of outputs.
339
+ """
340
+ assert (y is not None) == (
341
+ self.num_classes is not None
342
+ ), "must specify y if and only if the model is class-conditional"
343
+ hs = []
344
+ t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
345
+ emb = self.time_embed(t_emb)
346
+
347
+ if self.num_classes is not None:
348
+ assert y.shape == (x.shape[0],)
349
+ emb = emb + self.label_emb(y)
350
+
351
+ # For text-to-audio using global CLIP
352
+ if self.use_context_project:
353
+ context = self.context_project(context)
354
+ emb = emb + context.squeeze(1)
355
+
356
+ h = x.type(self.dtype)
357
+ for module in self.input_blocks:
358
+ h = module(h, emb, context if self.use_context_attn else None)
359
+ hs.append(h)
360
+ h = self.middle_block(h, emb, context if self.use_context_attn else None)
361
+ for module in self.output_blocks:
362
+ h = th.cat([h, hs.pop()], dim=1)
363
+ h = module(h, emb, context if self.use_context_attn else None)
364
+ h = h.type(x.dtype)
365
+ if self.predict_codebook_ids:
366
+ return self.id_predictor(h)
367
+ else:
368
+ return self.out(h)
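
The forward pass above supports two conditioning routes: with use_context_attn the token-level context is fed to the cross-attention layers, while use_context_project instead folds a single pooled text embedding into the timestep embedding. A minimal sketch of the projection route, with all tensor sizes chosen as illustrative assumptions rather than values from the configs:

    import torch as th
    import torch.nn as nn

    context_dim, time_embed_dim = 512, 1024          # placeholder sizes
    context_project = nn.Linear(context_dim, time_embed_dim)
    emb = th.randn(2, time_embed_dim)                # timestep embedding for a batch of 2
    context = th.randn(2, 1, context_dim)            # one pooled text embedding per item
    emb = emb + context_project(context).squeeze(1)  # global conditioning, no cross-attention
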
ldm/modules/diffusionmodules/model.py ADDED
@@ -0,0 +1,835 @@
1
+ # pytorch_diffusion + derived encoder decoder
2
+ import math
3
+ import torch
4
+ import torch.nn as nn
5
+ import numpy as np
6
+ from einops import rearrange
7
+
8
+ from ldm.util import instantiate_from_config
9
+ from ldm.modules.attention import LinearAttention
+ from ldm.modules.distributions.distributions import DiagonalGaussianDistribution  # used by FirstStagePostProcessor.encode_with_pretrained below
10
+
11
+
12
+ def get_timestep_embedding(timesteps, embedding_dim):
13
+ """
14
+ This matches the implementation in Denoising Diffusion Probabilistic Models:
15
+ From Fairseq.
16
+ Build sinusoidal embeddings.
17
+ This matches the implementation in tensor2tensor, but differs slightly
18
+ from the description in Section 3.5 of "Attention Is All You Need".
19
+ """
20
+ assert len(timesteps.shape) == 1
21
+
22
+ half_dim = embedding_dim // 2
23
+ emb = math.log(10000) / (half_dim - 1)
24
+ emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb)
25
+ emb = emb.to(device=timesteps.device)
26
+ emb = timesteps.float()[:, None] * emb[None, :]
27
+ emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
28
+ if embedding_dim % 2 == 1: # zero pad
29
+ emb = torch.nn.functional.pad(emb, (0,1,0,0))
30
+ return emb
31
+
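
A quick shape check for the sinusoidal embedding above; the batch size and embedding_dim are arbitrary assumptions:

    import torch
    from ldm.modules.diffusionmodules.model import get_timestep_embedding

    t = torch.arange(4)                   # a batch of 4 timesteps
    emb = get_timestep_embedding(t, 128)
    assert emb.shape == (4, 128)          # each row is [sin(t*w_0..w_63), cos(t*w_0..w_63)]
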
32
+
33
+ def nonlinearity(x):
34
+ # swish
35
+ return x*torch.sigmoid(x)
36
+
37
+
38
+ def Normalize(in_channels, num_groups=32):
39
+ return torch.nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True)
40
+
41
+
42
+ class Upsample(nn.Module):
43
+ def __init__(self, in_channels, with_conv):
44
+ super().__init__()
45
+ self.with_conv = with_conv
46
+ if self.with_conv:
47
+ self.conv = torch.nn.Conv2d(in_channels,
48
+ in_channels,
49
+ kernel_size=3,
50
+ stride=1,
51
+ padding=1)
52
+
53
+ def forward(self, x):
54
+ x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
55
+ if self.with_conv:
56
+ x = self.conv(x)
57
+ return x
58
+
59
+
60
+ class Downsample(nn.Module):
61
+ def __init__(self, in_channels, with_conv):
62
+ super().__init__()
63
+ self.with_conv = with_conv
64
+ if self.with_conv:
65
+ # no asymmetric padding in torch conv, must do it ourselves
66
+ self.conv = torch.nn.Conv2d(in_channels,
67
+ in_channels,
68
+ kernel_size=3,
69
+ stride=2,
70
+ padding=0)
71
+
72
+ def forward(self, x):
73
+ if self.with_conv:
74
+ pad = (0,1,0,1)
75
+ x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
76
+ x = self.conv(x)
77
+ else:
78
+ x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
79
+ return x
80
+
81
+
82
+ class ResnetBlock(nn.Module):
83
+ def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False,
84
+ dropout, temb_channels=512):
85
+ super().__init__()
86
+ self.in_channels = in_channels
87
+ out_channels = in_channels if out_channels is None else out_channels
88
+ self.out_channels = out_channels
89
+ self.use_conv_shortcut = conv_shortcut
90
+
91
+ self.norm1 = Normalize(in_channels)
92
+ self.conv1 = torch.nn.Conv2d(in_channels,
93
+ out_channels,
94
+ kernel_size=3,
95
+ stride=1,
96
+ padding=1)
97
+ if temb_channels > 0:
98
+ self.temb_proj = torch.nn.Linear(temb_channels,
99
+ out_channels)
100
+ self.norm2 = Normalize(out_channels)
101
+ self.dropout = torch.nn.Dropout(dropout)
102
+ self.conv2 = torch.nn.Conv2d(out_channels,
103
+ out_channels,
104
+ kernel_size=3,
105
+ stride=1,
106
+ padding=1)
107
+ if self.in_channels != self.out_channels:
108
+ if self.use_conv_shortcut:
109
+ self.conv_shortcut = torch.nn.Conv2d(in_channels,
110
+ out_channels,
111
+ kernel_size=3,
112
+ stride=1,
113
+ padding=1)
114
+ else:
115
+ self.nin_shortcut = torch.nn.Conv2d(in_channels,
116
+ out_channels,
117
+ kernel_size=1,
118
+ stride=1,
119
+ padding=0)
120
+
121
+ def forward(self, x, temb):
122
+ h = x
123
+ h = self.norm1(h)
124
+ h = nonlinearity(h)
125
+ h = self.conv1(h)
126
+
127
+ if temb is not None:
128
+ h = h + self.temb_proj(nonlinearity(temb))[:,:,None,None]
129
+
130
+ h = self.norm2(h)
131
+ h = nonlinearity(h)
132
+ h = self.dropout(h)
133
+ h = self.conv2(h)
134
+
135
+ if self.in_channels != self.out_channels:
136
+ if self.use_conv_shortcut:
137
+ x = self.conv_shortcut(x)
138
+ else:
139
+ x = self.nin_shortcut(x)
140
+
141
+ return x+h
142
+
143
+
144
+ class LinAttnBlock(LinearAttention):
145
+ """to match AttnBlock usage"""
146
+ def __init__(self, in_channels):
147
+ super().__init__(dim=in_channels, heads=1, dim_head=in_channels)
148
+
149
+
150
+ class AttnBlock(nn.Module):
151
+ def __init__(self, in_channels):
152
+ super().__init__()
153
+ self.in_channels = in_channels
154
+
155
+ self.norm = Normalize(in_channels)
156
+ self.q = torch.nn.Conv2d(in_channels,
157
+ in_channels,
158
+ kernel_size=1,
159
+ stride=1,
160
+ padding=0)
161
+ self.k = torch.nn.Conv2d(in_channels,
162
+ in_channels,
163
+ kernel_size=1,
164
+ stride=1,
165
+ padding=0)
166
+ self.v = torch.nn.Conv2d(in_channels,
167
+ in_channels,
168
+ kernel_size=1,
169
+ stride=1,
170
+ padding=0)
171
+ self.proj_out = torch.nn.Conv2d(in_channels,
172
+ in_channels,
173
+ kernel_size=1,
174
+ stride=1,
175
+ padding=0)
176
+
177
+
178
+ def forward(self, x):
179
+ h_ = x
180
+ h_ = self.norm(h_)
181
+ q = self.q(h_)
182
+ k = self.k(h_)
183
+ v = self.v(h_)
184
+
185
+ # compute attention
186
+ b,c,h,w = q.shape
187
+ q = q.reshape(b,c,h*w)
188
+ q = q.permute(0,2,1) # b,hw,c
189
+ k = k.reshape(b,c,h*w) # b,c,hw
190
+ w_ = torch.bmm(q,k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
191
+ w_ = w_ * (int(c)**(-0.5))
192
+ w_ = torch.nn.functional.softmax(w_, dim=2)
193
+
194
+ # attend to values
195
+ v = v.reshape(b,c,h*w)
196
+ w_ = w_.permute(0,2,1) # b,hw,hw (first hw of k, second of q)
197
+ h_ = torch.bmm(v,w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
198
+ h_ = h_.reshape(b,c,h,w)
199
+
200
+ h_ = self.proj_out(h_)
201
+
202
+ return x+h_
203
+
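
AttnBlock is a residual single-head self-attention over all spatial positions, so it preserves the input shape. A small sanity check with assumed sizes:

    import torch
    from ldm.modules.diffusionmodules.model import AttnBlock

    attn = AttnBlock(in_channels=64)      # channel count must suit the 32-group GroupNorm
    x = torch.randn(2, 64, 16, 16)        # [B, C, H, W]
    assert attn(x).shape == x.shape       # residual connection keeps the shape
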
204
+
205
+ def make_attn(in_channels, attn_type="vanilla"):
206
+ assert attn_type in ["vanilla", "linear", "none"], f'attn_type {attn_type} unknown'
207
+ print(f"making attention of type '{attn_type}' with {in_channels} in_channels")
208
+ if attn_type == "vanilla":
209
+ return AttnBlock(in_channels)
210
+ elif attn_type == "none":
211
+ return nn.Identity(in_channels)
212
+ else:
213
+ return LinAttnBlock(in_channels)
214
+
215
+
216
+ class Model(nn.Module):
217
+ def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
218
+ attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
219
+ resolution, use_timestep=True, use_linear_attn=False, attn_type="vanilla"):
220
+ super().__init__()
221
+ if use_linear_attn: attn_type = "linear"
222
+ self.ch = ch
223
+ self.temb_ch = self.ch*4
224
+ self.num_resolutions = len(ch_mult)
225
+ self.num_res_blocks = num_res_blocks
226
+ self.resolution = resolution
227
+ self.in_channels = in_channels
228
+
229
+ self.use_timestep = use_timestep
230
+ if self.use_timestep:
231
+ # timestep embedding
232
+ self.temb = nn.Module()
233
+ self.temb.dense = nn.ModuleList([
234
+ torch.nn.Linear(self.ch,
235
+ self.temb_ch),
236
+ torch.nn.Linear(self.temb_ch,
237
+ self.temb_ch),
238
+ ])
239
+
240
+ # downsampling
241
+ self.conv_in = torch.nn.Conv2d(in_channels,
242
+ self.ch,
243
+ kernel_size=3,
244
+ stride=1,
245
+ padding=1)
246
+
247
+ curr_res = resolution
248
+ in_ch_mult = (1,)+tuple(ch_mult)
249
+ self.down = nn.ModuleList()
250
+ for i_level in range(self.num_resolutions):
251
+ block = nn.ModuleList()
252
+ attn = nn.ModuleList()
253
+ block_in = ch*in_ch_mult[i_level]
254
+ block_out = ch*ch_mult[i_level]
255
+ for i_block in range(self.num_res_blocks):
256
+ block.append(ResnetBlock(in_channels=block_in,
257
+ out_channels=block_out,
258
+ temb_channels=self.temb_ch,
259
+ dropout=dropout))
260
+ block_in = block_out
261
+ if curr_res in attn_resolutions:
262
+ attn.append(make_attn(block_in, attn_type=attn_type))
263
+ down = nn.Module()
264
+ down.block = block
265
+ down.attn = attn
266
+ if i_level != self.num_resolutions-1:
267
+ down.downsample = Downsample(block_in, resamp_with_conv)
268
+ curr_res = curr_res // 2
269
+ self.down.append(down)
270
+
271
+ # middle
272
+ self.mid = nn.Module()
273
+ self.mid.block_1 = ResnetBlock(in_channels=block_in,
274
+ out_channels=block_in,
275
+ temb_channels=self.temb_ch,
276
+ dropout=dropout)
277
+ self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
278
+ self.mid.block_2 = ResnetBlock(in_channels=block_in,
279
+ out_channels=block_in,
280
+ temb_channels=self.temb_ch,
281
+ dropout=dropout)
282
+
283
+ # upsampling
284
+ self.up = nn.ModuleList()
285
+ for i_level in reversed(range(self.num_resolutions)):
286
+ block = nn.ModuleList()
287
+ attn = nn.ModuleList()
288
+ block_out = ch*ch_mult[i_level]
289
+ skip_in = ch*ch_mult[i_level]
290
+ for i_block in range(self.num_res_blocks+1):
291
+ if i_block == self.num_res_blocks:
292
+ skip_in = ch*in_ch_mult[i_level]
293
+ block.append(ResnetBlock(in_channels=block_in+skip_in,
294
+ out_channels=block_out,
295
+ temb_channels=self.temb_ch,
296
+ dropout=dropout))
297
+ block_in = block_out
298
+ if curr_res in attn_resolutions:
299
+ attn.append(make_attn(block_in, attn_type=attn_type))
300
+ up = nn.Module()
301
+ up.block = block
302
+ up.attn = attn
303
+ if i_level != 0:
304
+ up.upsample = Upsample(block_in, resamp_with_conv)
305
+ curr_res = curr_res * 2
306
+ self.up.insert(0, up) # prepend to get consistent order
307
+
308
+ # end
309
+ self.norm_out = Normalize(block_in)
310
+ self.conv_out = torch.nn.Conv2d(block_in,
311
+ out_ch,
312
+ kernel_size=3,
313
+ stride=1,
314
+ padding=1)
315
+
316
+ def forward(self, x, t=None, context=None):
317
+ #assert x.shape[2] == x.shape[3] == self.resolution
318
+ if context is not None:
319
+ # assume aligned context, cat along channel axis
320
+ x = torch.cat((x, context), dim=1)
321
+ if self.use_timestep:
322
+ # timestep embedding
323
+ assert t is not None
324
+ temb = get_timestep_embedding(t, self.ch)
325
+ temb = self.temb.dense[0](temb)
326
+ temb = nonlinearity(temb)
327
+ temb = self.temb.dense[1](temb)
328
+ else:
329
+ temb = None
330
+
331
+ # downsampling
332
+ hs = [self.conv_in(x)]
333
+ for i_level in range(self.num_resolutions):
334
+ for i_block in range(self.num_res_blocks):
335
+ h = self.down[i_level].block[i_block](hs[-1], temb)
336
+ if len(self.down[i_level].attn) > 0:
337
+ h = self.down[i_level].attn[i_block](h)
338
+ hs.append(h)
339
+ if i_level != self.num_resolutions-1:
340
+ hs.append(self.down[i_level].downsample(hs[-1]))
341
+
342
+ # middle
343
+ h = hs[-1]
344
+ h = self.mid.block_1(h, temb)
345
+ h = self.mid.attn_1(h)
346
+ h = self.mid.block_2(h, temb)
347
+
348
+ # upsampling
349
+ for i_level in reversed(range(self.num_resolutions)):
350
+ for i_block in range(self.num_res_blocks+1):
351
+ h = self.up[i_level].block[i_block](
352
+ torch.cat([h, hs.pop()], dim=1), temb)
353
+ if len(self.up[i_level].attn) > 0:
354
+ h = self.up[i_level].attn[i_block](h)
355
+ if i_level != 0:
356
+ h = self.up[i_level].upsample(h)
357
+
358
+ # end
359
+ h = self.norm_out(h)
360
+ h = nonlinearity(h)
361
+ h = self.conv_out(h)
362
+ return h
363
+
364
+ def get_last_layer(self):
365
+ return self.conv_out.weight
366
+
367
+
368
+ class Encoder(nn.Module):
369
+ def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
370
+ attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
371
+ resolution, z_channels, double_z=True, use_linear_attn=False, attn_type="vanilla",
372
+ **ignore_kwargs):
373
+ super().__init__()
374
+ if use_linear_attn: attn_type = "linear"
375
+ self.ch = ch
376
+ self.temb_ch = 0
377
+ self.num_resolutions = len(ch_mult)
378
+ self.num_res_blocks = num_res_blocks
379
+ self.resolution = resolution
380
+ self.in_channels = in_channels
381
+
382
+ # downsampling
383
+ self.conv_in = torch.nn.Conv2d(in_channels,
384
+ self.ch,
385
+ kernel_size=3,
386
+ stride=1,
387
+ padding=1)
388
+
389
+ curr_res = resolution
390
+ in_ch_mult = (1,)+tuple(ch_mult)
391
+ self.in_ch_mult = in_ch_mult
392
+ self.down = nn.ModuleList()
393
+ for i_level in range(self.num_resolutions):
394
+ block = nn.ModuleList()
395
+ attn = nn.ModuleList()
396
+ block_in = ch*in_ch_mult[i_level]
397
+ block_out = ch*ch_mult[i_level]
398
+ for i_block in range(self.num_res_blocks):
399
+ block.append(ResnetBlock(in_channels=block_in,
400
+ out_channels=block_out,
401
+ temb_channels=self.temb_ch,
402
+ dropout=dropout))
403
+ block_in = block_out
404
+ if curr_res in attn_resolutions:
405
+ attn.append(make_attn(block_in, attn_type=attn_type))# vanilla attention
406
+ down = nn.Module()
407
+ down.block = block
408
+ down.attn = attn
409
+ if i_level != self.num_resolutions-1:
410
+ down.downsample = Downsample(block_in, resamp_with_conv)
411
+ curr_res = curr_res // 2
412
+ self.down.append(down)
413
+
414
+ # middle
415
+ self.mid = nn.Module()
416
+ self.mid.block_1 = ResnetBlock(in_channels=block_in,
417
+ out_channels=block_in,
418
+ temb_channels=self.temb_ch,
419
+ dropout=dropout)
420
+ self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
421
+ self.mid.block_2 = ResnetBlock(in_channels=block_in,
422
+ out_channels=block_in,
423
+ temb_channels=self.temb_ch,
424
+ dropout=dropout)
425
+
426
+ # end
427
+ self.norm_out = Normalize(block_in)# GroupNorm
428
+ self.conv_out = torch.nn.Conv2d(block_in,
429
+ 2*z_channels if double_z else z_channels,
430
+ kernel_size=3,
431
+ stride=1,
432
+ padding=1)
433
+
434
+ def forward(self, x):
435
+ # timestep embedding
436
+ temb = None
437
+
438
+ # downsampling
439
+ hs = [self.conv_in(x)]
440
+ for i_level in range(self.num_resolutions):
441
+ for i_block in range(self.num_res_blocks):
442
+ h = self.down[i_level].block[i_block](hs[-1], temb)
443
+ if len(self.down[i_level].attn) > 0:
444
+ h = self.down[i_level].attn[i_block](h)
445
+ hs.append(h)
446
+ if i_level != self.num_resolutions-1:
447
+ hs.append(self.down[i_level].downsample(hs[-1]))
448
+
449
+ # middle
450
+ h = hs[-1]
451
+ h = self.mid.block_1(h, temb)
452
+ h = self.mid.attn_1(h)
453
+ h = self.mid.block_2(h, temb)
454
+
455
+ # end
456
+ h = self.norm_out(h)
457
+ h = nonlinearity(h)
458
+ h = self.conv_out(h)
459
+ return h
460
+
461
+
462
+ class Decoder(nn.Module):
463
+ def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
464
+ attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
465
+ resolution, z_channels, give_pre_end=False, tanh_out=False, use_linear_attn=False,
466
+ attn_type="vanilla", **ignorekwargs):
467
+ super().__init__()
468
+ if use_linear_attn: attn_type = "linear"
469
+ self.ch = ch
470
+ self.temb_ch = 0
471
+ self.num_resolutions = len(ch_mult)
472
+ self.num_res_blocks = num_res_blocks
473
+ self.resolution = resolution
474
+ self.in_channels = in_channels
475
+ self.give_pre_end = give_pre_end
476
+ self.tanh_out = tanh_out
477
+
478
+ # compute in_ch_mult, block_in and curr_res at lowest res
479
+ in_ch_mult = (1,)+tuple(ch_mult)
480
+ block_in = ch*ch_mult[self.num_resolutions-1]
481
+ curr_res = resolution // 2**(self.num_resolutions-1)
482
+ self.z_shape = (1,z_channels,curr_res,curr_res)
483
+ print("Working with z of shape {} = {} dimensions.".format(
484
+ self.z_shape, np.prod(self.z_shape)))
485
+
486
+ # z to block_in
487
+ self.conv_in = torch.nn.Conv2d(z_channels,
488
+ block_in,
489
+ kernel_size=3,
490
+ stride=1,
491
+ padding=1)
492
+
493
+ # middle
494
+ self.mid = nn.Module()
495
+ self.mid.block_1 = ResnetBlock(in_channels=block_in,
496
+ out_channels=block_in,
497
+ temb_channels=self.temb_ch,
498
+ dropout=dropout)
499
+ self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
500
+ self.mid.block_2 = ResnetBlock(in_channels=block_in,
501
+ out_channels=block_in,
502
+ temb_channels=self.temb_ch,
503
+ dropout=dropout)
504
+
505
+ # upsampling
506
+ self.up = nn.ModuleList()
507
+ for i_level in reversed(range(self.num_resolutions)):
508
+ block = nn.ModuleList()
509
+ attn = nn.ModuleList()
510
+ block_out = ch*ch_mult[i_level]
511
+ for i_block in range(self.num_res_blocks+1):
512
+ block.append(ResnetBlock(in_channels=block_in,
513
+ out_channels=block_out,
514
+ temb_channels=self.temb_ch,
515
+ dropout=dropout))
516
+ block_in = block_out
517
+ if curr_res in attn_resolutions:
518
+ attn.append(make_attn(block_in, attn_type=attn_type))
519
+ up = nn.Module()
520
+ up.block = block
521
+ up.attn = attn
522
+ if i_level != 0:
523
+ up.upsample = Upsample(block_in, resamp_with_conv)
524
+ curr_res = curr_res * 2
525
+ self.up.insert(0, up) # prepend to get consistent order
526
+
527
+ # end
528
+ self.norm_out = Normalize(block_in)
529
+ self.conv_out = torch.nn.Conv2d(block_in,
530
+ out_ch,
531
+ kernel_size=3,
532
+ stride=1,
533
+ padding=1)
534
+
535
+ def forward(self, z):
536
+ #assert z.shape[1:] == self.z_shape[1:]
537
+ self.last_z_shape = z.shape
538
+
539
+ # timestep embedding
540
+ temb = None
541
+
542
+ # z to block_in
543
+ h = self.conv_in(z)
544
+
545
+ # middle
546
+ h = self.mid.block_1(h, temb)
547
+ h = self.mid.attn_1(h)
548
+ h = self.mid.block_2(h, temb)
549
+
550
+ # upsampling
551
+ for i_level in reversed(range(self.num_resolutions)):
552
+ for i_block in range(self.num_res_blocks+1):
553
+ h = self.up[i_level].block[i_block](h, temb)
554
+ if len(self.up[i_level].attn) > 0:
555
+ h = self.up[i_level].attn[i_block](h)
556
+ if i_level != 0:
557
+ h = self.up[i_level].upsample(h)
558
+
559
+ # end
560
+ if self.give_pre_end:
561
+ return h
562
+
563
+ h = self.norm_out(h)
564
+ h = nonlinearity(h)
565
+ h = self.conv_out(h)
566
+ if self.tanh_out:
567
+ h = torch.tanh(h)
568
+ return h
569
+
570
+
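
The Encoder/Decoder pair above forms the spatial autoencoder torso. The following shape-level sketch uses illustrative hyperparameters, not the autoencoder settings from the bundled configs:

    import torch
    from ldm.modules.diffusionmodules.model import Encoder, Decoder

    enc = Encoder(ch=32, out_ch=1, ch_mult=(1, 2, 4), num_res_blocks=1,
                  attn_resolutions=[], in_channels=1, resolution=64,
                  z_channels=4, double_z=True)
    dec = Decoder(ch=32, out_ch=1, ch_mult=(1, 2, 4), num_res_blocks=1,
                  attn_resolutions=[], in_channels=1, resolution=64,
                  z_channels=4)
    x = torch.randn(1, 1, 64, 64)    # e.g. a mel spectrogram treated as a 1-channel image
    moments = enc(x)                 # [1, 2*z_channels, 16, 16]: mean and logvar
    z = moments[:, :4]               # take the mean half as a stand-in for sampling
    rec = dec(z)                     # [1, 1, 64, 64]
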
571
+ class SimpleDecoder(nn.Module):
572
+ def __init__(self, in_channels, out_channels, *args, **kwargs):
573
+ super().__init__()
574
+ self.model = nn.ModuleList([nn.Conv2d(in_channels, in_channels, 1),
575
+ ResnetBlock(in_channels=in_channels,
576
+ out_channels=2 * in_channels,
577
+ temb_channels=0, dropout=0.0),
578
+ ResnetBlock(in_channels=2 * in_channels,
579
+ out_channels=4 * in_channels,
580
+ temb_channels=0, dropout=0.0),
581
+ ResnetBlock(in_channels=4 * in_channels,
582
+ out_channels=2 * in_channels,
583
+ temb_channels=0, dropout=0.0),
584
+ nn.Conv2d(2*in_channels, in_channels, 1),
585
+ Upsample(in_channels, with_conv=True)])
586
+ # end
587
+ self.norm_out = Normalize(in_channels)
588
+ self.conv_out = torch.nn.Conv2d(in_channels,
589
+ out_channels,
590
+ kernel_size=3,
591
+ stride=1,
592
+ padding=1)
593
+
594
+ def forward(self, x):
595
+ for i, layer in enumerate(self.model):
596
+ if i in [1,2,3]:
597
+ x = layer(x, None)
598
+ else:
599
+ x = layer(x)
600
+
601
+ h = self.norm_out(x)
602
+ h = nonlinearity(h)
603
+ x = self.conv_out(h)
604
+ return x
605
+
606
+
607
+ class UpsampleDecoder(nn.Module):
608
+ def __init__(self, in_channels, out_channels, ch, num_res_blocks, resolution,
609
+ ch_mult=(2,2), dropout=0.0):
610
+ super().__init__()
611
+ # upsampling
612
+ self.temb_ch = 0
613
+ self.num_resolutions = len(ch_mult)
614
+ self.num_res_blocks = num_res_blocks
615
+ block_in = in_channels
616
+ curr_res = resolution // 2 ** (self.num_resolutions - 1)
617
+ self.res_blocks = nn.ModuleList()
618
+ self.upsample_blocks = nn.ModuleList()
619
+ for i_level in range(self.num_resolutions):
620
+ res_block = []
621
+ block_out = ch * ch_mult[i_level]
622
+ for i_block in range(self.num_res_blocks + 1):
623
+ res_block.append(ResnetBlock(in_channels=block_in,
624
+ out_channels=block_out,
625
+ temb_channels=self.temb_ch,
626
+ dropout=dropout))
627
+ block_in = block_out
628
+ self.res_blocks.append(nn.ModuleList(res_block))
629
+ if i_level != self.num_resolutions - 1:
630
+ self.upsample_blocks.append(Upsample(block_in, True))
631
+ curr_res = curr_res * 2
632
+
633
+ # end
634
+ self.norm_out = Normalize(block_in)
635
+ self.conv_out = torch.nn.Conv2d(block_in,
636
+ out_channels,
637
+ kernel_size=3,
638
+ stride=1,
639
+ padding=1)
640
+
641
+ def forward(self, x):
642
+ # upsampling
643
+ h = x
644
+ for k, i_level in enumerate(range(self.num_resolutions)):
645
+ for i_block in range(self.num_res_blocks + 1):
646
+ h = self.res_blocks[i_level][i_block](h, None)
647
+ if i_level != self.num_resolutions - 1:
648
+ h = self.upsample_blocks[k](h)
649
+ h = self.norm_out(h)
650
+ h = nonlinearity(h)
651
+ h = self.conv_out(h)
652
+ return h
653
+
654
+
655
+ class LatentRescaler(nn.Module):
656
+ def __init__(self, factor, in_channels, mid_channels, out_channels, depth=2):
657
+ super().__init__()
658
+ # residual block, interpolate, residual block
659
+ self.factor = factor
660
+ self.conv_in = nn.Conv2d(in_channels,
661
+ mid_channels,
662
+ kernel_size=3,
663
+ stride=1,
664
+ padding=1)
665
+ self.res_block1 = nn.ModuleList([ResnetBlock(in_channels=mid_channels,
666
+ out_channels=mid_channels,
667
+ temb_channels=0,
668
+ dropout=0.0) for _ in range(depth)])
669
+ self.attn = AttnBlock(mid_channels)
670
+ self.res_block2 = nn.ModuleList([ResnetBlock(in_channels=mid_channels,
671
+ out_channels=mid_channels,
672
+ temb_channels=0,
673
+ dropout=0.0) for _ in range(depth)])
674
+
675
+ self.conv_out = nn.Conv2d(mid_channels,
676
+ out_channels,
677
+ kernel_size=1,
678
+ )
679
+
680
+ def forward(self, x):
681
+ x = self.conv_in(x)
682
+ for block in self.res_block1:
683
+ x = block(x, None)
684
+ x = torch.nn.functional.interpolate(x, size=(int(round(x.shape[2]*self.factor)), int(round(x.shape[3]*self.factor))))
685
+ x = self.attn(x)
686
+ for block in self.res_block2:
687
+ x = block(x, None)
688
+ x = self.conv_out(x)
689
+ return x
690
+
691
+
692
+ class MergedRescaleEncoder(nn.Module):
693
+ def __init__(self, in_channels, ch, resolution, out_ch, num_res_blocks,
694
+ attn_resolutions, dropout=0.0, resamp_with_conv=True,
695
+ ch_mult=(1,2,4,8), rescale_factor=1.0, rescale_module_depth=1):
696
+ super().__init__()
697
+ intermediate_chn = ch * ch_mult[-1]
698
+ self.encoder = Encoder(in_channels=in_channels, num_res_blocks=num_res_blocks, ch=ch, ch_mult=ch_mult,
699
+ z_channels=intermediate_chn, double_z=False, resolution=resolution,
700
+ attn_resolutions=attn_resolutions, dropout=dropout, resamp_with_conv=resamp_with_conv,
701
+ out_ch=None)
702
+ self.rescaler = LatentRescaler(factor=rescale_factor, in_channels=intermediate_chn,
703
+ mid_channels=intermediate_chn, out_channels=out_ch, depth=rescale_module_depth)
704
+
705
+ def forward(self, x):
706
+ x = self.encoder(x)
707
+ x = self.rescaler(x)
708
+ return x
709
+
710
+
711
+ class MergedRescaleDecoder(nn.Module):
712
+ def __init__(self, z_channels, out_ch, resolution, num_res_blocks, attn_resolutions, ch, ch_mult=(1,2,4,8),
713
+ dropout=0.0, resamp_with_conv=True, rescale_factor=1.0, rescale_module_depth=1):
714
+ super().__init__()
715
+ tmp_chn = z_channels*ch_mult[-1]
716
+ self.decoder = Decoder(out_ch=out_ch, z_channels=tmp_chn, attn_resolutions=attn_resolutions, dropout=dropout,
717
+ resamp_with_conv=resamp_with_conv, in_channels=None, num_res_blocks=num_res_blocks,
718
+ ch_mult=ch_mult, resolution=resolution, ch=ch)
719
+ self.rescaler = LatentRescaler(factor=rescale_factor, in_channels=z_channels, mid_channels=tmp_chn,
720
+ out_channels=tmp_chn, depth=rescale_module_depth)
721
+
722
+ def forward(self, x):
723
+ x = self.rescaler(x)
724
+ x = self.decoder(x)
725
+ return x
726
+
727
+
728
+ class Upsampler(nn.Module):
729
+ def __init__(self, in_size, out_size, in_channels, out_channels, ch_mult=2):
730
+ super().__init__()
731
+ assert out_size >= in_size
732
+ num_blocks = int(np.log2(out_size//in_size))+1
733
+ factor_up = 1.+ (out_size % in_size)
734
+ print(f"Building {self.__class__.__name__} with in_size: {in_size} --> out_size {out_size} and factor {factor_up}")
735
+ self.rescaler = LatentRescaler(factor=factor_up, in_channels=in_channels, mid_channels=2*in_channels,
736
+ out_channels=in_channels)
737
+ self.decoder = Decoder(out_ch=out_channels, resolution=out_size, z_channels=in_channels, num_res_blocks=2,
738
+ attn_resolutions=[], in_channels=None, ch=in_channels,
739
+ ch_mult=[ch_mult for _ in range(num_blocks)])
740
+
741
+ def forward(self, x):
742
+ x = self.rescaler(x)
743
+ x = self.decoder(x)
744
+ return x
745
+
746
+
747
+ class Resize(nn.Module):
748
+ def __init__(self, in_channels=None, learned=False, mode="bilinear"):
749
+ super().__init__()
750
+ self.with_conv = learned
751
+ self.mode = mode
752
+ if self.with_conv:
753
+ print(f"Note: {self.__class__.__name} uses learned downsampling and will ignore the fixed {mode} mode")
754
+ raise NotImplementedError()
755
+ assert in_channels is not None
756
+ # no asymmetric padding in torch conv, must do it ourselves
757
+ self.conv = torch.nn.Conv2d(in_channels,
758
+ in_channels,
759
+ kernel_size=4,
760
+ stride=2,
761
+ padding=1)
762
+
763
+ def forward(self, x, scale_factor=1.0):
764
+ if scale_factor==1.0:
765
+ return x
766
+ else:
767
+ x = torch.nn.functional.interpolate(x, mode=self.mode, align_corners=False, scale_factor=scale_factor)
768
+ return x
769
+
770
+ class FirstStagePostProcessor(nn.Module):
771
+
772
+ def __init__(self, ch_mult:list, in_channels,
773
+ pretrained_model:nn.Module=None,
774
+ reshape=False,
775
+ n_channels=None,
776
+ dropout=0.,
777
+ pretrained_config=None):
778
+ super().__init__()
779
+ if pretrained_config is None:
780
+ assert pretrained_model is not None, 'Either "pretrained_model" or "pretrained_config" must not be None'
781
+ self.pretrained_model = pretrained_model
782
+ else:
783
+ assert pretrained_config is not None, 'Either "pretrained_model" or "pretrained_config" must not be None'
784
+ self.instantiate_pretrained(pretrained_config)
785
+
786
+ self.do_reshape = reshape
787
+
788
+ if n_channels is None:
789
+ n_channels = self.pretrained_model.encoder.ch
790
+
791
+ self.proj_norm = Normalize(in_channels,num_groups=in_channels//2)
792
+ self.proj = nn.Conv2d(in_channels,n_channels,kernel_size=3,
793
+ stride=1,padding=1)
794
+
795
+ blocks = []
796
+ downs = []
797
+ ch_in = n_channels
798
+ for m in ch_mult:
799
+ blocks.append(ResnetBlock(in_channels=ch_in,out_channels=m*n_channels,dropout=dropout))
800
+ ch_in = m * n_channels
801
+ downs.append(Downsample(ch_in, with_conv=False))
802
+
803
+ self.model = nn.ModuleList(blocks)
804
+ self.downsampler = nn.ModuleList(downs)
805
+
806
+
807
+ def instantiate_pretrained(self, config):
808
+ model = instantiate_from_config(config)
809
+ self.pretrained_model = model.eval()
810
+ # self.pretrained_model.train = False
811
+ for param in self.pretrained_model.parameters():
812
+ param.requires_grad = False
813
+
814
+
815
+ @torch.no_grad()
816
+ def encode_with_pretrained(self,x):
817
+ c = self.pretrained_model.encode(x)
818
+ if isinstance(c, DiagonalGaussianDistribution):
819
+ c = c.mode()
820
+ return c
821
+
822
+ def forward(self,x):
823
+ z_fs = self.encode_with_pretrained(x)
824
+ z = self.proj_norm(z_fs)
825
+ z = self.proj(z)
826
+ z = nonlinearity(z)
827
+
828
+ for submodel, downmodel in zip(self.model,self.downsampler):
829
+ z = submodel(z,temb=None)
830
+ z = downmodel(z)
831
+
832
+ if self.do_reshape:
833
+ z = rearrange(z,'b c h w -> b (h w) c')
834
+ return z
835
+
ldm/modules/diffusionmodules/openaimodel.py ADDED
@@ -0,0 +1,963 @@
1
+ from abc import abstractmethod
2
+ from functools import partial
3
+ import math
4
+ from typing import Iterable
5
+
6
+ import numpy as np
7
+ import torch as th
8
+ import torch.nn as nn
9
+ import torch.nn.functional as F
10
+
11
+ from ldm.modules.diffusionmodules.util import (
12
+ checkpoint,
13
+ conv_nd,
14
+ linear,
15
+ avg_pool_nd,
16
+ zero_module,
17
+ normalization,
18
+ timestep_embedding,
19
+ )
20
+ from ldm.modules.attention import SpatialTransformer
21
+
22
+
23
+ # dummy replace
24
+ def convert_module_to_f16(x):
25
+ pass
26
+
27
+ def convert_module_to_f32(x):
28
+ pass
29
+
30
+
31
+ ## go
32
+ class AttentionPool2d(nn.Module):
33
+ """
34
+ Adapted from CLIP: https://github.com/openai/CLIP/blob/main/clip/model.py
35
+ """
36
+
37
+ def __init__(
38
+ self,
39
+ spacial_dim: int,
40
+ embed_dim: int,
41
+ num_heads_channels: int,
42
+ output_dim: int = None,
43
+ ):
44
+ super().__init__()
45
+ self.positional_embedding = nn.Parameter(th.randn(embed_dim, spacial_dim ** 2 + 1) / embed_dim ** 0.5)
46
+ self.qkv_proj = conv_nd(1, embed_dim, 3 * embed_dim, 1)
47
+ self.c_proj = conv_nd(1, embed_dim, output_dim or embed_dim, 1)
48
+ self.num_heads = embed_dim // num_heads_channels
49
+ self.attention = QKVAttention(self.num_heads)
50
+
51
+ def forward(self, x):
52
+ b, c, *_spatial = x.shape
53
+ x = x.reshape(b, c, -1) # NC(HW)
54
+ x = th.cat([x.mean(dim=-1, keepdim=True), x], dim=-1) # NC(HW+1)
55
+ x = x + self.positional_embedding[None, :, :].to(x.dtype) # NC(HW+1)
56
+ x = self.qkv_proj(x)
57
+ x = self.attention(x)
58
+ x = self.c_proj(x)
59
+ return x[:, :, 0]
60
+
61
+
62
+ class TimestepBlock(nn.Module):
63
+ """
64
+ Any module where forward() takes timestep embeddings as a second argument.
65
+ """
66
+
67
+ @abstractmethod
68
+ def forward(self, x, emb):
69
+ """
70
+ Apply the module to `x` given `emb` timestep embeddings.
71
+ """
72
+
73
+
74
+ class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
75
+ """
76
+ A sequential module that passes timestep embeddings to the children that
77
+ support it as an extra input.
78
+ """
79
+
80
+ def forward(self, x, emb, context=None):
81
+ for layer in self:
82
+ if isinstance(layer, TimestepBlock):
83
+ x = layer(x, emb)
84
+ elif isinstance(layer, SpatialTransformer):
85
+ x = layer(x, context)
86
+ else:
87
+ x = layer(x)
88
+ return x
89
+
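
TimestepEmbedSequential only forwards emb to TimestepBlock children and context to SpatialTransformer children; every other layer receives x alone. A small sketch with assumed shapes (ResBlock is defined further down in this file):

    import torch as th
    from ldm.modules.diffusionmodules.openaimodel import (
        TimestepEmbedSequential, ResBlock, conv_nd,
    )

    block = TimestepEmbedSequential(
        conv_nd(2, 32, 32, 3, padding=1),             # plain layer: sees only x
        ResBlock(32, emb_channels=128, dropout=0.0),  # TimestepBlock: also sees emb
    )
    x, emb = th.randn(1, 32, 8, 8), th.randn(1, 128)
    y = block(x, emb)                                 # context stays None here
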
90
+
91
+ class Upsample(nn.Module):
92
+ """
93
+ An upsampling layer with an optional convolution.
94
+ :param channels: channels in the inputs and outputs.
95
+ :param use_conv: a bool determining if a convolution is applied.
96
+ :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
97
+ upsampling occurs in the inner-two dimensions.
98
+ """
99
+
100
+ def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1):
101
+ super().__init__()
102
+ self.channels = channels
103
+ self.out_channels = out_channels or channels
104
+ self.use_conv = use_conv
105
+ self.dims = dims
106
+ if use_conv:
107
+ self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=padding)
108
+
109
+ def forward(self, x):
110
+ assert x.shape[1] == self.channels
111
+ if self.dims == 3:
112
+ x = F.interpolate(
113
+ x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest"
114
+ )
115
+ else:
116
+ x = F.interpolate(x, scale_factor=2, mode="nearest")
117
+ if self.use_conv:
118
+ x = self.conv(x)
119
+ return x
120
+
121
+ class TransposedUpsample(nn.Module):
122
+ 'Learned 2x upsampling without padding'
123
+ def __init__(self, channels, out_channels=None, ks=5):
124
+ super().__init__()
125
+ self.channels = channels
126
+ self.out_channels = out_channels or channels
127
+
128
+ self.up = nn.ConvTranspose2d(self.channels,self.out_channels,kernel_size=ks,stride=2)
129
+
130
+ def forward(self,x):
131
+ return self.up(x)
132
+
133
+
134
+ class Downsample(nn.Module):
135
+ """
136
+ A downsampling layer with an optional convolution.
137
+ :param channels: channels in the inputs and outputs.
138
+ :param use_conv: a bool determining if a convolution is applied.
139
+ :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
140
+ downsampling occurs in the inner-two dimensions.
141
+ """
142
+
143
+ def __init__(self, channels, use_conv, dims=2, out_channels=None,padding=1):
144
+ super().__init__()
145
+ self.channels = channels
146
+ self.out_channels = out_channels or channels
147
+ self.use_conv = use_conv
148
+ self.dims = dims
149
+ stride = 2 if dims != 3 else (1, 2, 2)
150
+ if use_conv:
151
+ self.op = conv_nd(
152
+ dims, self.channels, self.out_channels, 3, stride=stride, padding=padding
153
+ )
154
+ else:
155
+ assert self.channels == self.out_channels
156
+ self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride)
157
+
158
+ def forward(self, x):
159
+ assert x.shape[1] == self.channels
160
+ return self.op(x)
161
+
162
+
163
+ class ResBlock(TimestepBlock):
164
+ """
165
+ A residual block that can optionally change the number of channels.
166
+ :param channels: the number of input channels.
167
+ :param emb_channels: the number of timestep embedding channels.
168
+ :param dropout: the rate of dropout.
169
+ :param out_channels: if specified, the number of out channels.
170
+ :param use_conv: if True and out_channels is specified, use a spatial
171
+ convolution instead of a smaller 1x1 convolution to change the
172
+ channels in the skip connection.
173
+ :param dims: determines if the signal is 1D, 2D, or 3D.
174
+ :param use_checkpoint: if True, use gradient checkpointing on this module.
175
+ :param up: if True, use this block for upsampling.
176
+ :param down: if True, use this block for downsampling.
177
+ """
178
+
179
+ def __init__(
180
+ self,
181
+ channels,
182
+ emb_channels,
183
+ dropout,
184
+ out_channels=None,
185
+ use_conv=False,
186
+ use_scale_shift_norm=False,
187
+ dims=2,
188
+ use_checkpoint=False,
189
+ up=False,
190
+ down=False,
191
+ ):
192
+ super().__init__()
193
+ self.channels = channels
194
+ self.emb_channels = emb_channels
195
+ self.dropout = dropout
196
+ self.out_channels = out_channels or channels
197
+ self.use_conv = use_conv
198
+ self.use_checkpoint = use_checkpoint
199
+ self.use_scale_shift_norm = use_scale_shift_norm
200
+
201
+ self.in_layers = nn.Sequential(
202
+ normalization(channels),
203
+ nn.SiLU(),
204
+ conv_nd(dims, channels, self.out_channels, 3, padding=1),
205
+ )
206
+
207
+ self.updown = up or down
208
+
209
+ if up:
210
+ self.h_upd = Upsample(channels, False, dims)
211
+ self.x_upd = Upsample(channels, False, dims)
212
+ elif down:
213
+ self.h_upd = Downsample(channels, False, dims)
214
+ self.x_upd = Downsample(channels, False, dims)
215
+ else:
216
+ self.h_upd = self.x_upd = nn.Identity()
217
+
218
+ self.emb_layers = nn.Sequential(
219
+ nn.SiLU(),
220
+ linear(
221
+ emb_channels,
222
+ 2 * self.out_channels if use_scale_shift_norm else self.out_channels,
223
+ ),
224
+ )
225
+ self.out_layers = nn.Sequential(
226
+ normalization(self.out_channels),
227
+ nn.SiLU(),
228
+ nn.Dropout(p=dropout),
229
+ zero_module(
230
+ conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1)
231
+ ),
232
+ )
233
+
234
+ if self.out_channels == channels:
235
+ self.skip_connection = nn.Identity()
236
+ elif use_conv:
237
+ self.skip_connection = conv_nd(
238
+ dims, channels, self.out_channels, 3, padding=1
239
+ )
240
+ else:
241
+ self.skip_connection = conv_nd(dims, channels, self.out_channels, 1)
242
+
243
+ def forward(self, x, emb):
244
+ """
245
+ Apply the block to a Tensor, conditioned on a timestep embedding.
246
+ :param x: an [N x C x ...] Tensor of features.
247
+ :param emb: an [N x emb_channels] Tensor of timestep embeddings.
248
+ :return: an [N x C x ...] Tensor of outputs.
249
+ """
250
+ return checkpoint(
251
+ self._forward, (x, emb), self.parameters(), self.use_checkpoint
252
+ )
253
+
254
+
255
+ def _forward(self, x, emb):
256
+ if self.updown:
257
+ in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
258
+ h = in_rest(x)
259
+ h = self.h_upd(h)
260
+ x = self.x_upd(x)
261
+ h = in_conv(h)
262
+ else:
263
+ h = self.in_layers(x)
264
+ emb_out = self.emb_layers(emb).type(h.dtype)
265
+ while len(emb_out.shape) < len(h.shape):
266
+ emb_out = emb_out[..., None]
267
+ if self.use_scale_shift_norm:
268
+ out_norm, out_rest = self.out_layers[0], self.out_layers[1:]
269
+ scale, shift = th.chunk(emb_out, 2, dim=1)
270
+ h = out_norm(h) * (1 + scale) + shift
271
+ h = out_rest(h)
272
+ else:
273
+ h = h + emb_out
274
+ h = self.out_layers(h)
275
+ return self.skip_connection(x) + h  # residual connection around the conditioned block
276
+
277
+
278
+ class AttentionBlock(nn.Module):
279
+ """
280
+ An attention block that allows spatial positions to attend to each other.
281
+ Originally ported from here, but adapted to the N-d case.
282
+ https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66.
283
+ """
284
+
285
+ def __init__(
286
+ self,
287
+ channels,
288
+ num_heads=1,
289
+ num_head_channels=-1,
290
+ use_checkpoint=False,
291
+ use_new_attention_order=False,
292
+ ):
293
+ super().__init__()
294
+ self.channels = channels
295
+ if num_head_channels == -1:
296
+ self.num_heads = num_heads
297
+ else:
298
+ assert (
299
+ channels % num_head_channels == 0
300
+ ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}"
301
+ self.num_heads = channels // num_head_channels
302
+ self.use_checkpoint = use_checkpoint
303
+ self.norm = normalization(channels)
304
+ self.qkv = conv_nd(1, channels, channels * 3, 1)
305
+ if use_new_attention_order:
306
+ # split qkv before split heads
307
+ self.attention = QKVAttention(self.num_heads)
308
+ else:
309
+ # split heads before split qkv
310
+ self.attention = QKVAttentionLegacy(self.num_heads)
311
+
312
+ self.proj_out = zero_module(conv_nd(1, channels, channels, 1))
313
+
314
+ def forward(self, x):
315
+ return checkpoint(self._forward, (x,), self.parameters(), True) # TODO: check checkpoint usage, is True # TODO: fix the .half call!!!
316
+ #return pt_checkpoint(self._forward, x) # pytorch
317
+
318
+ def _forward(self, x):
319
+ b, c, *spatial = x.shape
320
+ x = x.reshape(b, c, -1)
321
+ qkv = self.qkv(self.norm(x))
322
+ h = self.attention(qkv)
323
+ h = self.proj_out(h)
324
+ return (x + h).reshape(b, c, *spatial)
325
+
326
+
327
+ def count_flops_attn(model, _x, y):
328
+ """
329
+ A counter for the `thop` package to count the operations in an
330
+ attention operation.
331
+ Meant to be used like:
332
+ macs, params = thop.profile(
333
+ model,
334
+ inputs=(inputs, timestamps),
335
+ custom_ops={QKVAttention: QKVAttention.count_flops},
336
+ )
337
+ """
338
+ b, c, *spatial = y[0].shape
339
+ num_spatial = int(np.prod(spatial))
340
+ # We perform two matmuls with the same number of ops.
341
+ # The first computes the weight matrix, the second computes
342
+ # the combination of the value vectors.
343
+ matmul_ops = 2 * b * (num_spatial ** 2) * c
344
+ model.total_ops += th.DoubleTensor([matmul_ops])
345
+
346
+
347
+ class QKVAttentionLegacy(nn.Module):
348
+ """
349
+ A module which performs QKV attention. Matches legacy QKVAttention + input/ouput heads shaping
350
+ """
351
+
352
+ def __init__(self, n_heads):
353
+ super().__init__()
354
+ self.n_heads = n_heads
355
+
356
+ def forward(self, qkv):
357
+ """
358
+ Apply QKV attention.
359
+ :param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs.
360
+ :return: an [N x (H * C) x T] tensor after attention.
361
+ """
362
+ bs, width, length = qkv.shape
363
+ assert width % (3 * self.n_heads) == 0
364
+ ch = width // (3 * self.n_heads)
365
+ q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1)
366
+ scale = 1 / math.sqrt(math.sqrt(ch))
367
+ weight = th.einsum(
368
+ "bct,bcs->bts", q * scale, k * scale
369
+ ) # More stable with f16 than dividing afterwards
370
+ weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
371
+ a = th.einsum("bts,bcs->bct", weight, v)
372
+ return a.reshape(bs, -1, length)
373
+
374
+ @staticmethod
375
+ def count_flops(model, _x, y):
376
+ return count_flops_attn(model, _x, y)
377
+
378
+
379
+ class QKVAttention(nn.Module):
380
+ """
381
+ A module which performs QKV attention and splits in a different order.
382
+ """
383
+
384
+ def __init__(self, n_heads):
385
+ super().__init__()
386
+ self.n_heads = n_heads
387
+
388
+ def forward(self, qkv):
389
+ """
390
+ Apply QKV attention.
391
+ :param qkv: an [N x (3 * H * C) x T] tensor of Qs, Ks, and Vs.
392
+ :return: an [N x (H * C) x T] tensor after attention.
393
+ """
394
+ bs, width, length = qkv.shape
395
+ assert width % (3 * self.n_heads) == 0
396
+ ch = width // (3 * self.n_heads)
397
+ q, k, v = qkv.chunk(3, dim=1)
398
+ scale = 1 / math.sqrt(math.sqrt(ch))
399
+ weight = th.einsum(
400
+ "bct,bcs->bts",
401
+ (q * scale).view(bs * self.n_heads, ch, length),
402
+ (k * scale).view(bs * self.n_heads, ch, length),
403
+ ) # More stable with f16 than dividing afterwards
404
+ weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
405
+ a = th.einsum("bts,bcs->bct", weight, v.reshape(bs * self.n_heads, ch, length))
406
+ return a.reshape(bs, -1, length)
407
+
408
+ @staticmethod
409
+ def count_flops(model, _x, y):
410
+ return count_flops_attn(model, _x, y)
411
+
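
Both attention modules map a packed [N, 3*H*C, T] tensor to [N, H*C, T]; they differ only in whether heads are split before or after q, k, v. A quick check with assumed sizes:

    import torch as th
    from ldm.modules.diffusionmodules.openaimodel import QKVAttention, QKVAttentionLegacy

    qkv = th.randn(2, 3 * 4 * 16, 50)        # N=2, H=4 heads, C=16 per head, T=50 positions
    out_new = QKVAttention(n_heads=4)(qkv)
    out_legacy = QKVAttentionLegacy(n_heads=4)(qkv)
    assert out_new.shape == out_legacy.shape == (2, 4 * 16, 50)
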
412
+
413
+ class UNetModel(nn.Module):
414
+ """
415
+ The full UNet model with attention and timestep embedding.
416
+ :param in_channels: channels in the input Tensor.
417
+ :param model_channels: base channel count for the model.
418
+ :param out_channels: channels in the output Tensor.
419
+ :param num_res_blocks: number of residual blocks per downsample.
420
+ :param attention_resolutions: a collection of downsample rates at which
421
+ attention will take place. May be a set, list, or tuple.
422
+ For example, if this contains 4, then at 4x downsampling, attention
423
+ will be used.
424
+ :param dropout: the dropout probability.
425
+ :param channel_mult: channel multiplier for each level of the UNet.
426
+ :param conv_resample: if True, use learned convolutions for upsampling and
427
+ downsampling.
428
+ :param dims: determines if the signal is 1D, 2D, or 3D.
429
+ :param num_classes: if specified (as an int), then this model will be
430
+ class-conditional with `num_classes` classes.
431
+ :param use_checkpoint: use gradient checkpointing to reduce memory usage.
432
+ :param num_heads: the number of attention heads in each attention layer.
433
+ :param num_heads_channels: if specified, ignore num_heads and instead use
434
+ a fixed channel width per attention head.
435
+ :param num_heads_upsample: works with num_heads to set a different number
436
+ of heads for upsampling. Deprecated.
437
+ :param use_scale_shift_norm: use a FiLM-like conditioning mechanism.
438
+ :param resblock_updown: use residual blocks for up/downsampling.
439
+ :param use_new_attention_order: use a different attention pattern for potentially
440
+ increased efficiency.
441
+ """
442
+
443
+ def __init__(
444
+ self,
445
+ image_size,
446
+ in_channels,
447
+ model_channels,
448
+ out_channels,
449
+ num_res_blocks,
450
+ attention_resolutions,
451
+ dropout=0,
452
+ channel_mult=(1, 2, 4, 8),
453
+ conv_resample=True,
454
+ dims=2,
455
+ num_classes=None,
456
+ use_checkpoint=False,
457
+ use_fp16=False,
458
+ num_heads=-1,
459
+ num_head_channels=-1,
460
+ num_heads_upsample=-1,
461
+ use_scale_shift_norm=False,
462
+ resblock_updown=False,
463
+ use_new_attention_order=False,
464
+ use_spatial_transformer=False, # custom transformer support
465
+ transformer_depth=1, # custom transformer support
466
+ context_dim=None, # custom transformer support
467
+ n_embed=None, # custom support for prediction of discrete ids into codebook of first stage vq model
468
+ legacy=True,
469
+ ):
470
+ super().__init__()
471
+ if use_spatial_transformer:
472
+ assert context_dim is not None, 'Fool!! You forgot to include the dimension of your cross-attention conditioning...'
473
+
474
+ if context_dim is not None:
475
+ assert use_spatial_transformer, 'Fool!! You forgot to use the spatial transformer for your cross-attention conditioning...'
476
+ from omegaconf.listconfig import ListConfig
477
+ if type(context_dim) == ListConfig:
478
+ context_dim = list(context_dim)
479
+
480
+ if num_heads_upsample == -1:
481
+ num_heads_upsample = num_heads
482
+
483
+ if num_heads == -1:
484
+ assert num_head_channels != -1, 'Either num_heads or num_head_channels has to be set'
485
+
486
+ if num_head_channels == -1:
487
+ assert num_heads != -1, 'Either num_heads or num_head_channels has to be set'
488
+
489
+ self.image_size = image_size
490
+ self.in_channels = in_channels
491
+ self.model_channels = model_channels
492
+ self.out_channels = out_channels
493
+ self.num_res_blocks = num_res_blocks
494
+ self.attention_resolutions = attention_resolutions
495
+ self.dropout = dropout
496
+ self.channel_mult = channel_mult
497
+ self.conv_resample = conv_resample
498
+ self.num_classes = num_classes
499
+ self.use_checkpoint = use_checkpoint
500
+ self.dtype = th.float16 if use_fp16 else th.float32
501
+ self.num_heads = num_heads
502
+ self.num_head_channels = num_head_channels
503
+ self.num_heads_upsample = num_heads_upsample
504
+ self.predict_codebook_ids = n_embed is not None
505
+
506
+ time_embed_dim = model_channels * 4
507
+ self.time_embed = nn.Sequential(
508
+ linear(model_channels, time_embed_dim),
509
+ nn.SiLU(),
510
+ linear(time_embed_dim, time_embed_dim),
511
+ )
512
+
513
+ if self.num_classes is not None:
514
+ self.label_emb = nn.Embedding(num_classes, time_embed_dim)
515
+
516
+ self.input_blocks = nn.ModuleList(
517
+ [
518
+ TimestepEmbedSequential(
519
+ conv_nd(dims, in_channels, model_channels, 3, padding=1)# conv2d for txt2img/audio
520
+ )
521
+ ]
522
+ )
523
+ self._feature_size = model_channels
524
+ input_block_chans = [model_channels]
525
+ ch = model_channels
526
+ ds = 1
527
+ # downsample blocks
528
+ for level, mult in enumerate(channel_mult):
529
+ for _ in range(num_res_blocks):
530
+ layers = [
531
+ ResBlock(
532
+ ch,
533
+ time_embed_dim,
534
+ dropout,
535
+ out_channels=mult * model_channels,
536
+ dims=dims,
537
+ use_checkpoint=use_checkpoint,
538
+ use_scale_shift_norm=use_scale_shift_norm,
539
+ )
540
+ ]
541
+ ch = mult * model_channels
542
+ if ds in attention_resolutions:
543
+ if num_head_channels == -1:
544
+ dim_head = ch // num_heads
545
+ else:
546
+ num_heads = ch // num_head_channels
547
+ dim_head = num_head_channels
548
+ if legacy:
549
+ #num_heads = 1
550
+ dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
551
+ layers.append(
552
+ AttentionBlock(
553
+ ch,
554
+ use_checkpoint=use_checkpoint,
555
+ num_heads=num_heads,
556
+ num_head_channels=dim_head,
557
+ use_new_attention_order=use_new_attention_order,
558
+ ) if not use_spatial_transformer else SpatialTransformer(# transformer_depth is 1
559
+ ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim
560
+ )
561
+ )
562
+ self.input_blocks.append(TimestepEmbedSequential(*layers))
563
+ self._feature_size += ch
564
+ input_block_chans.append(ch)
565
+ if level != len(channel_mult) - 1:
566
+ out_ch = ch
567
+ self.input_blocks.append(
568
+ TimestepEmbedSequential(
569
+ ResBlock(
570
+ ch,
571
+ time_embed_dim,
572
+ dropout,
573
+ out_channels=out_ch,
574
+ dims=dims,
575
+ use_checkpoint=use_checkpoint,
576
+ use_scale_shift_norm=use_scale_shift_norm,
577
+ down=True,
578
+ )
579
+ if resblock_updown
580
+ else Downsample(
581
+ ch, conv_resample, dims=dims, out_channels=out_ch
582
+ )
583
+ )
584
+ )
585
+ ch = out_ch
586
+ input_block_chans.append(ch)
587
+ ds *= 2
588
+ self._feature_size += ch
589
+
590
+ if num_head_channels == -1:
591
+ dim_head = ch // num_heads
592
+ else:
593
+ num_heads = ch // num_head_channels
594
+ dim_head = num_head_channels
595
+ if legacy:
596
+ #num_heads = 1
597
+ dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
598
+ self.middle_block = TimestepEmbedSequential(
599
+ ResBlock(
600
+ ch,
601
+ time_embed_dim,
602
+ dropout,
603
+ dims=dims,
604
+ use_checkpoint=use_checkpoint,
605
+ use_scale_shift_norm=use_scale_shift_norm,
606
+ ),
607
+ AttentionBlock(
608
+ ch,
609
+ use_checkpoint=use_checkpoint,
610
+ num_heads=num_heads,
611
+ num_head_channels=dim_head,
612
+ use_new_attention_order=use_new_attention_order,
613
+ ) if not use_spatial_transformer else SpatialTransformer(
614
+ ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim
615
+ ),
616
+ ResBlock(
617
+ ch,
618
+ time_embed_dim,
619
+ dropout,
620
+ dims=dims,
621
+ use_checkpoint=use_checkpoint,
622
+ use_scale_shift_norm=use_scale_shift_norm,
623
+ ),
624
+ )
625
+ self._feature_size += ch
626
+ # upsample blocks
627
+ self.output_blocks = nn.ModuleList([])
628
+ for level, mult in list(enumerate(channel_mult))[::-1]:
629
+ for i in range(num_res_blocks + 1):
630
+ ich = input_block_chans.pop()
631
+ layers = [
632
+ ResBlock(
633
+ ch + ich,
634
+ time_embed_dim,
635
+ dropout,
636
+ out_channels=model_channels * mult,
637
+ dims=dims,
638
+ use_checkpoint=use_checkpoint,
639
+ use_scale_shift_norm=use_scale_shift_norm,
640
+ )
641
+ ]
642
+ ch = model_channels * mult
643
+ if ds in attention_resolutions:
644
+ if num_head_channels == -1:
645
+ dim_head = ch // num_heads
646
+ else:
647
+ num_heads = ch // num_head_channels
648
+ dim_head = num_head_channels
649
+ if legacy:
650
+ #num_heads = 1
651
+ dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
652
+ layers.append(
653
+ AttentionBlock(
654
+ ch,
655
+ use_checkpoint=use_checkpoint,
656
+ num_heads=num_heads_upsample,
657
+ num_head_channels=dim_head,
658
+ use_new_attention_order=use_new_attention_order,
659
+ ) if not use_spatial_transformer else SpatialTransformer(
660
+ ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim
661
+ )
662
+ )
663
+ if level and i == num_res_blocks:
664
+ out_ch = ch
665
+ layers.append(
666
+ ResBlock(
667
+ ch,
668
+ time_embed_dim,
669
+ dropout,
670
+ out_channels=out_ch,
671
+ dims=dims,
672
+ use_checkpoint=use_checkpoint,
673
+ use_scale_shift_norm=use_scale_shift_norm,
674
+ up=True,
675
+ )
676
+ if resblock_updown
677
+ else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch)
678
+ )
679
+ ds //= 2
680
+ self.output_blocks.append(TimestepEmbedSequential(*layers))
681
+ self._feature_size += ch
682
+
683
+ self.out = nn.Sequential(
684
+ normalization(ch),
685
+ nn.SiLU(),
686
+ zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1)),
687
+ )
688
+ if self.predict_codebook_ids:
689
+ self.id_predictor = nn.Sequential(
690
+ normalization(ch),
691
+ conv_nd(dims, model_channels, n_embed, 1),
692
+ #nn.LogSoftmax(dim=1) # change to cross_entropy and produce non-normalized logits
693
+ )
694
+
695
+ def convert_to_fp16(self):
696
+ """
697
+ Convert the torso of the model to float16.
698
+ """
699
+ self.input_blocks.apply(convert_module_to_f16)
700
+ self.middle_block.apply(convert_module_to_f16)
701
+ self.output_blocks.apply(convert_module_to_f16)
702
+
703
+ def convert_to_fp32(self):
704
+ """
705
+ Convert the torso of the model to float32.
706
+ """
707
+ self.input_blocks.apply(convert_module_to_f32)
708
+ self.middle_block.apply(convert_module_to_f32)
709
+ self.output_blocks.apply(convert_module_to_f32)
710
+
711
+ def forward(self, x, timesteps=None, context=None, y=None,**kwargs):
712
+ """
713
+ Apply the model to an input batch.
714
+ :param x: an [N x C x ...] Tensor of inputs.
715
+ :param timesteps: a 1-D batch of timesteps, shape [N].
716
+ :param context: conditioning plugged in via cross-attention; for txt2img the shape is [N, 77, context_dim].
717
+ :param y: an [N] Tensor of labels, if class-conditional.
718
+ :return: an [N x C x ...] Tensor of outputs.
719
+ """
720
+ # print(f"in unet {x.shape}")
721
+ assert (y is not None) == (
722
+ self.num_classes is not None
723
+ ), "must specify y if and only if the model is class-conditional"
724
+ hs = []
725
+ t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)# shape [N,self.model_channels]
726
+ emb = self.time_embed(t_emb)  # shape [N, time_embed_dim] (= 4 * model_channels)
727
+
728
+ if self.num_classes is not None:# only for class label
729
+ assert y.shape == (x.shape[0],)
730
+ emb = emb + self.label_emb(y)
731
+
732
+ h = x.type(self.dtype)# [N,C,10,106]
733
+ for module in self.input_blocks:
734
+ h = module(h, emb, context)# 0:[N,self.model_channels,10,106],1:[N,self.model_channels,10,106],2:[N,self.model_channels,10,106] 3:[N,self.model_channels,5,53] 4:[N,self.model_channels,5,53] 5:[N,self.model_channels*2,5,53]
735
+ hs.append(h)
736
+ h = self.middle_block(h, emb, context)# no shape change
737
+ for module in self.output_blocks:
738
+ h = th.cat([h, hs.pop()], dim=1)  # channel dim grows here (doubled or +self.model_channels); spatial dims unchanged
739
+ h = module(h, emb, context)  # channel dim shrinks back to its previous size; h, w unchanged or doubled
740
+ h = h.type(x.dtype)  # at this point h has the same shape as the input x
741
+ if self.predict_codebook_ids:
742
+ return self.id_predictor(h)
743
+ else:
744
+ return self.out(h)
745
+
746
+
747
+ class EncoderUNetModel(nn.Module):
748
+ """
749
+ The half UNet model with attention and timestep embedding.
750
+ For usage, see UNet.
751
+ """
752
+
753
+ def __init__(
754
+ self,
755
+ image_size,
756
+ in_channels,
757
+ model_channels,
758
+ out_channels,
759
+ num_res_blocks,
760
+ attention_resolutions,
761
+ dropout=0,
762
+ channel_mult=(1, 2, 4, 8),
763
+ conv_resample=True,
764
+ dims=2,
765
+ use_checkpoint=False,
766
+ use_fp16=False,
767
+ num_heads=1,
768
+ num_head_channels=-1,
769
+ num_heads_upsample=-1,
770
+ use_scale_shift_norm=False,
771
+ resblock_updown=False,
772
+ use_new_attention_order=False,
773
+ pool="adaptive",
774
+ *args,
775
+ **kwargs
776
+ ):
777
+ super().__init__()
778
+
779
+ if num_heads_upsample == -1:
780
+ num_heads_upsample = num_heads
781
+
782
+ self.in_channels = in_channels
783
+ self.model_channels = model_channels
784
+ self.out_channels = out_channels
785
+ self.num_res_blocks = num_res_blocks
786
+ self.attention_resolutions = attention_resolutions
787
+ self.dropout = dropout
788
+ self.channel_mult = channel_mult
789
+ self.conv_resample = conv_resample
790
+ self.use_checkpoint = use_checkpoint
791
+ self.dtype = th.float16 if use_fp16 else th.float32
792
+ self.num_heads = num_heads
793
+ self.num_head_channels = num_head_channels
794
+ self.num_heads_upsample = num_heads_upsample
795
+
796
+ time_embed_dim = model_channels * 4
797
+ self.time_embed = nn.Sequential(
798
+ linear(model_channels, time_embed_dim),
799
+ nn.SiLU(),
800
+ linear(time_embed_dim, time_embed_dim),
801
+ )
802
+
803
+ self.input_blocks = nn.ModuleList(
804
+ [
805
+ TimestepEmbedSequential(
806
+ conv_nd(dims, in_channels, model_channels, 3, padding=1)
807
+ )
808
+ ]
809
+ )
810
+ self._feature_size = model_channels
811
+ input_block_chans = [model_channels]
812
+ ch = model_channels
813
+ ds = 1
814
+ for level, mult in enumerate(channel_mult):
815
+ for _ in range(num_res_blocks):
816
+ layers = [
817
+ ResBlock(
818
+ ch,
819
+ time_embed_dim,
820
+ dropout,
821
+ out_channels=mult * model_channels,
822
+ dims=dims,
823
+ use_checkpoint=use_checkpoint,
824
+ use_scale_shift_norm=use_scale_shift_norm,
825
+ )
826
+ ]
827
+ ch = mult * model_channels
828
+ if ds in attention_resolutions:
829
+ layers.append(
830
+ AttentionBlock(
831
+ ch,
832
+ use_checkpoint=use_checkpoint,
833
+ num_heads=num_heads,
834
+ num_head_channels=num_head_channels,
835
+ use_new_attention_order=use_new_attention_order,
836
+ )
837
+ )
838
+ self.input_blocks.append(TimestepEmbedSequential(*layers))
839
+ self._feature_size += ch
840
+ input_block_chans.append(ch)
841
+ if level != len(channel_mult) - 1:
842
+ out_ch = ch
843
+ self.input_blocks.append(
844
+ TimestepEmbedSequential(
845
+ ResBlock(
846
+ ch,
847
+ time_embed_dim,
848
+ dropout,
849
+ out_channels=out_ch,
850
+ dims=dims,
851
+ use_checkpoint=use_checkpoint,
852
+ use_scale_shift_norm=use_scale_shift_norm,
853
+ down=True,
854
+ )
855
+ if resblock_updown
856
+ else Downsample(
857
+ ch, conv_resample, dims=dims, out_channels=out_ch
858
+ )
859
+ )
860
+ )
861
+ ch = out_ch
862
+ input_block_chans.append(ch)
863
+ ds *= 2
864
+ self._feature_size += ch
865
+
866
+ self.middle_block = TimestepEmbedSequential(
867
+ ResBlock(
868
+ ch,
869
+ time_embed_dim,
870
+ dropout,
871
+ dims=dims,
872
+ use_checkpoint=use_checkpoint,
873
+ use_scale_shift_norm=use_scale_shift_norm,
874
+ ),
875
+ AttentionBlock(
876
+ ch,
877
+ use_checkpoint=use_checkpoint,
878
+ num_heads=num_heads,
879
+ num_head_channels=num_head_channels,
880
+ use_new_attention_order=use_new_attention_order,
881
+ ),
882
+ ResBlock(
883
+ ch,
884
+ time_embed_dim,
885
+ dropout,
886
+ dims=dims,
887
+ use_checkpoint=use_checkpoint,
888
+ use_scale_shift_norm=use_scale_shift_norm,
889
+ ),
890
+ )
891
+ self._feature_size += ch
892
+ self.pool = pool
893
+ if pool == "adaptive":
894
+ self.out = nn.Sequential(
895
+ normalization(ch),
896
+ nn.SiLU(),
897
+ nn.AdaptiveAvgPool2d((1, 1)),
898
+ zero_module(conv_nd(dims, ch, out_channels, 1)),
899
+ nn.Flatten(),
900
+ )
901
+ elif pool == "attention":
902
+ assert num_head_channels != -1
903
+ self.out = nn.Sequential(
904
+ normalization(ch),
905
+ nn.SiLU(),
906
+ AttentionPool2d(
907
+ (image_size // ds), ch, num_head_channels, out_channels
908
+ ),
909
+ )
910
+ elif pool == "spatial":
911
+ self.out = nn.Sequential(
912
+ nn.Linear(self._feature_size, 2048),
913
+ nn.ReLU(),
914
+ nn.Linear(2048, self.out_channels),
915
+ )
916
+ elif pool == "spatial_v2":
917
+ self.out = nn.Sequential(
918
+ nn.Linear(self._feature_size, 2048),
919
+ normalization(2048),
920
+ nn.SiLU(),
921
+ nn.Linear(2048, self.out_channels),
922
+ )
923
+ else:
924
+ raise NotImplementedError(f"Unexpected {pool} pooling")
925
+
926
+ def convert_to_fp16(self):
927
+ """
928
+ Convert the torso of the model to float16.
929
+ """
930
+ self.input_blocks.apply(convert_module_to_f16)
931
+ self.middle_block.apply(convert_module_to_f16)
932
+
933
+ def convert_to_fp32(self):
934
+ """
935
+ Convert the torso of the model to float32.
936
+ """
937
+ self.input_blocks.apply(convert_module_to_f32)
938
+ self.middle_block.apply(convert_module_to_f32)
939
+
940
+ def forward(self, x, timesteps):
941
+ """
942
+ Apply the model to an input batch.
943
+ :param x: an [N x C x ...] Tensor of inputs.
944
+ :param timesteps: a 1-D batch of timesteps.
945
+ :return: an [N x K] Tensor of outputs.
946
+ """
947
+ emb = self.time_embed(timestep_embedding(timesteps, self.model_channels))
948
+
949
+ results = []
950
+ h = x.type(self.dtype)
951
+ for module in self.input_blocks:
952
+ h = module(h, emb)
953
+ if self.pool.startswith("spatial"):
954
+ results.append(h.type(x.dtype).mean(dim=(2, 3)))
955
+ h = self.middle_block(h, emb)
956
+ if self.pool.startswith("spatial"):
957
+ results.append(h.type(x.dtype).mean(dim=(2, 3)))
958
+ h = th.cat(results, axis=-1)
959
+ return self.out(h)
960
+ else:
961
+ h = h.type(x.dtype)
962
+ return self.out(h)
963
+
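
The EncoderUNetModel added above is the half-UNet variant: it reuses the time-embedded ResBlock/attention stack of the full UNet but replaces the decoder with a pooling head. Below is a minimal shape-check sketch, not part of the commit, assuming this file lands at ldm/modules/diffusionmodules/openaimodel.py (as the file list indicates) and the repository root is importable; the hyperparameters are illustrative, not the project's configuration.

import torch
from ldm.modules.diffusionmodules.openaimodel import EncoderUNetModel  # path per this diff's file list

# Small, illustrative settings: 3-channel 32x32 inputs, 10 output classes,
# attention only at the 4x-downsampled resolution, adaptive-average-pooling head.
enc = EncoderUNetModel(
    image_size=32,
    in_channels=3,
    model_channels=32,
    out_channels=10,
    num_res_blocks=1,
    attention_resolutions=[4],
    pool="adaptive",
)

x = torch.randn(2, 3, 32, 32)        # [N, C, H, W]
t = torch.randint(0, 1000, (2,))     # one diffusion timestep per sample
logits = enc(x, t)                   # pooled output, shape [2, 10]
print(logits.shape)
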
ldm/modules/diffusionmodules/util.py ADDED
@@ -0,0 +1,267 @@
1
+ # adopted from
2
+ # https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
3
+ # and
4
+ # https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py
5
+ # and
6
+ # https://github.com/openai/guided-diffusion/blob/0ba878e517b276c45d1195eb29f6f5f72659a05b/guided_diffusion/nn.py
7
+ #
8
+ # thanks!
9
+
10
+
11
+ import os
12
+ import math
13
+ import torch
14
+ import torch.nn as nn
15
+ import numpy as np
16
+ from einops import repeat
17
+
18
+ from ldm.util import instantiate_from_config
19
+
20
+
21
+ def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
22
+ if schedule == "linear":
23
+ betas = (
24
+ torch.linspace(linear_start ** 0.5, linear_end ** 0.5, n_timestep, dtype=torch.float64) ** 2
25
+ )
26
+
27
+ elif schedule == "cosine":
28
+ timesteps = (
29
+ torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep + cosine_s
30
+ )
31
+ alphas = timesteps / (1 + cosine_s) * np.pi / 2
32
+ alphas = torch.cos(alphas).pow(2)
33
+ alphas = alphas / alphas[0]
34
+ betas = 1 - alphas[1:] / alphas[:-1]
35
+ betas = np.clip(betas, a_min=0, a_max=0.999)
36
+
37
+ elif schedule == "sqrt_linear":
38
+ betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64)
39
+ elif schedule == "sqrt":
40
+ betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) ** 0.5
41
+ else:
42
+ raise ValueError(f"schedule '{schedule}' unknown.")
43
+ return betas.numpy()
44
+
45
+
46
+ def make_ddim_timesteps(ddim_discr_method, num_ddim_timesteps, num_ddpm_timesteps, verbose=True):
47
+ if ddim_discr_method == 'uniform':
48
+ c = num_ddpm_timesteps // num_ddim_timesteps
49
+ ddim_timesteps = np.asarray(list(range(0, num_ddpm_timesteps, c)))
50
+ elif ddim_discr_method == 'quad':
51
+ ddim_timesteps = ((np.linspace(0, np.sqrt(num_ddpm_timesteps * .8), num_ddim_timesteps)) ** 2).astype(int)
52
+ else:
53
+ raise NotImplementedError(f'There is no ddim discretization method called "{ddim_discr_method}"')
54
+
55
+ # assert ddim_timesteps.shape[0] == num_ddim_timesteps
56
+ # add one to get the final alpha values right (the ones from first scale to data during sampling)
57
+ steps_out = ddim_timesteps + 1
58
+ if verbose:
59
+ print(f'Selected timesteps for ddim sampler: {steps_out}')
60
+ return steps_out
61
+
62
+
63
+ def make_ddim_sampling_parameters(alphacums, ddim_timesteps, eta, verbose=True):
64
+ # select alphas for computing the variance schedule
65
+ alphas = alphacums[ddim_timesteps]
66
+ alphas_prev = np.asarray([alphacums[0]] + alphacums[ddim_timesteps[:-1]].tolist())
67
+
68
+ # according to the formula provided in https://arxiv.org/abs/2010.02502
69
+ sigmas = eta * np.sqrt((1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev))
70
+ if verbose:
71
+ print(f'Selected alphas for ddim sampler: a_t: {alphas}; a_(t-1): {alphas_prev}')
72
+ print(f'For the chosen value of eta, which is {eta}, '
73
+ f'this results in the following sigma_t schedule for ddim sampler {sigmas}')
74
+ return sigmas, alphas, alphas_prev
75
+
76
+
77
+ def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999):
78
+ """
79
+ Create a beta schedule that discretizes the given alpha_t_bar function,
80
+ which defines the cumulative product of (1-beta) over time from t = [0,1].
81
+ :param num_diffusion_timesteps: the number of betas to produce.
82
+ :param alpha_bar: a lambda that takes an argument t from 0 to 1 and
83
+ produces the cumulative product of (1-beta) up to that
84
+ part of the diffusion process.
85
+ :param max_beta: the maximum beta to use; use values lower than 1 to
86
+ prevent singularities.
87
+ """
88
+ betas = []
89
+ for i in range(num_diffusion_timesteps):
90
+ t1 = i / num_diffusion_timesteps
91
+ t2 = (i + 1) / num_diffusion_timesteps
92
+ betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
93
+ return np.array(betas)
94
+
95
+
96
+ def extract_into_tensor(a, t, x_shape):
97
+ b, *_ = t.shape
98
+ out = a.gather(-1, t)
99
+ return out.reshape(b, *((1,) * (len(x_shape) - 1)))
100
+
101
+
102
+ def checkpoint(func, inputs, params, flag):
103
+ """
104
+ Evaluate a function without caching intermediate activations, allowing for
105
+ reduced memory at the expense of extra compute in the backward pass.
106
+ :param func: the function to evaluate.
107
+ :param inputs: the argument sequence to pass to `func`.
108
+ :param params: a sequence of parameters `func` depends on but does not
109
+ explicitly take as arguments.
110
+ :param flag: if False, disable gradient checkpointing.
111
+ """
112
+ if flag:
113
+ args = tuple(inputs) + tuple(params)
114
+ return CheckpointFunction.apply(func, len(inputs), *args)
115
+ else:
116
+ return func(*inputs)
117
+
118
+
119
+ class CheckpointFunction(torch.autograd.Function):
120
+ @staticmethod
121
+ def forward(ctx, run_function, length, *args):
122
+ ctx.run_function = run_function
123
+ ctx.input_tensors = list(args[:length])
124
+ ctx.input_params = list(args[length:])
125
+
126
+ with torch.no_grad():
127
+ output_tensors = ctx.run_function(*ctx.input_tensors)
128
+ return output_tensors
129
+
130
+ @staticmethod
131
+ def backward(ctx, *output_grads):
132
+ ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors]
133
+ with torch.enable_grad():
134
+ # Fixes a bug where the first op in run_function modifies the
135
+ # Tensor storage in place, which is not allowed for detach()'d
136
+ # Tensors.
137
+ shallow_copies = [x.view_as(x) for x in ctx.input_tensors]
138
+ output_tensors = ctx.run_function(*shallow_copies)
139
+ input_grads = torch.autograd.grad(
140
+ output_tensors,
141
+ ctx.input_tensors + ctx.input_params,
142
+ output_grads,
143
+ allow_unused=True,
144
+ )
145
+ del ctx.input_tensors
146
+ del ctx.input_params
147
+ del output_tensors
148
+ return (None, None) + input_grads
149
+
150
+
151
+ def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False):
152
+ """
153
+ Create sinusoidal timestep embeddings.
154
+ :param timesteps: a 1-D Tensor of N indices, one per batch element.
155
+ These may be fractional.
156
+ :param dim: the dimension of the output.
157
+ :param max_period: controls the minimum frequency of the embeddings.
158
+ :return: an [N x dim] Tensor of positional embeddings.
159
+ """
160
+ if not repeat_only:
161
+ half = dim // 2
162
+ freqs = torch.exp(
163
+ -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
164
+ ).to(device=timesteps.device)
165
+ args = timesteps[:, None].float() * freqs[None]
166
+ embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
167
+ if dim % 2:
168
+ embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
169
+ else:
170
+ embedding = repeat(timesteps, 'b -> b d', d=dim)
171
+ return embedding
172
+
173
+
174
+ def zero_module(module):
175
+ """
176
+ Zero out the parameters of a module and return it.
177
+ """
178
+ for p in module.parameters():
179
+ p.detach().zero_()
180
+ return module
181
+
182
+
183
+ def scale_module(module, scale):
184
+ """
185
+ Scale the parameters of a module and return it.
186
+ """
187
+ for p in module.parameters():
188
+ p.detach().mul_(scale)
189
+ return module
190
+
191
+
192
+ def mean_flat(tensor):
193
+ """
194
+ Take the mean over all non-batch dimensions.
195
+ """
196
+ return tensor.mean(dim=list(range(1, len(tensor.shape))))
197
+
198
+
199
+ def normalization(channels):
200
+ """
201
+ Make a standard normalization layer.
202
+ :param channels: number of input channels.
203
+ :return: an nn.Module for normalization.
204
+ """
205
+ return GroupNorm32(32, channels)
206
+
207
+
208
+ # PyTorch 1.7 has SiLU, but we support PyTorch 1.5.
209
+ class SiLU(nn.Module):
210
+ def forward(self, x):
211
+ return x * torch.sigmoid(x)
212
+
213
+
214
+ class GroupNorm32(nn.GroupNorm):
215
+ def forward(self, x):
216
+ return super().forward(x.float()).type(x.dtype)
217
+
218
+ def conv_nd(dims, *args, **kwargs):
219
+ """
220
+ Create a 1D, 2D, or 3D convolution module.
221
+ """
222
+ if dims == 1:
223
+ return nn.Conv1d(*args, **kwargs)
224
+ elif dims == 2:
225
+ return nn.Conv2d(*args, **kwargs)
226
+ elif dims == 3:
227
+ return nn.Conv3d(*args, **kwargs)
228
+ raise ValueError(f"unsupported dimensions: {dims}")
229
+
230
+
231
+ def linear(*args, **kwargs):
232
+ """
233
+ Create a linear module.
234
+ """
235
+ return nn.Linear(*args, **kwargs)
236
+
237
+
238
+ def avg_pool_nd(dims, *args, **kwargs):
239
+ """
240
+ Create a 1D, 2D, or 3D average pooling module.
241
+ """
242
+ if dims == 1:
243
+ return nn.AvgPool1d(*args, **kwargs)
244
+ elif dims == 2:
245
+ return nn.AvgPool2d(*args, **kwargs)
246
+ elif dims == 3:
247
+ return nn.AvgPool3d(*args, **kwargs)
248
+ raise ValueError(f"unsupported dimensions: {dims}")
249
+
250
+
251
+ class HybridConditioner(nn.Module):
252
+
253
+ def __init__(self, c_concat_config, c_crossattn_config):
254
+ super().__init__()
255
+ self.concat_conditioner = instantiate_from_config(c_concat_config)
256
+ self.crossattn_conditioner = instantiate_from_config(c_crossattn_config)
257
+
258
+ def forward(self, c_concat, c_crossattn):
259
+ c_concat = self.concat_conditioner(c_concat)
260
+ c_crossattn = self.crossattn_conditioner(c_crossattn)
261
+ return {'c_concat': [c_concat], 'c_crossattn': [c_crossattn]}
262
+
263
+
264
+ def noise_like(shape, device, repeat=False):
265
+ repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat(shape[0], *((1,) * (len(shape) - 1)))
266
+ noise = lambda: torch.randn(shape, device=device)
267
+ return repeat_noise() if repeat else noise()
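
util.py collects the schedule and embedding helpers shared by the diffusion models in this commit. The following is a minimal sketch (not part of the diff) of how two of them behave, assuming the repository root is on PYTHONPATH; the numbers are only for a quick shape check.

import torch
from ldm.modules.diffusionmodules.util import make_beta_schedule, timestep_embedding

# Linear beta schedule over 1000 steps; make_beta_schedule returns a numpy float64 array.
betas = make_beta_schedule("linear", n_timestep=1000)
alphas_cumprod = torch.cumprod(1.0 - torch.from_numpy(betas), dim=0)
print(betas.shape, float(alphas_cumprod[-1]))   # (1000,) and a value close to 0

# Sinusoidal timestep embeddings for a batch of 4 random timesteps.
t = torch.randint(0, 1000, (4,))
emb = timestep_embedding(t, dim=128)            # shape [4, 128]
print(emb.shape)
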
ldm/modules/discriminator/model.py ADDED
@@ -0,0 +1,295 @@
1
+ import functools
2
+ import torch.nn as nn
3
+
4
+
5
+ class ActNorm(nn.Module):
6
+ def __init__(self, num_features, logdet=False, affine=True,
7
+ allow_reverse_init=False):
8
+ assert affine
9
+ super().__init__()
10
+ self.logdet = logdet
11
+ self.loc = nn.Parameter(torch.zeros(1, num_features, 1, 1))
12
+ self.scale = nn.Parameter(torch.ones(1, num_features, 1, 1))
13
+ self.allow_reverse_init = allow_reverse_init
14
+
15
+ self.register_buffer('initialized', torch.tensor(0, dtype=torch.uint8))
16
+
17
+ def initialize(self, input):
18
+ with torch.no_grad():
19
+ flatten = input.permute(1, 0, 2, 3).contiguous().view(input.shape[1], -1)
20
+ mean = (
21
+ flatten.mean(1)
22
+ .unsqueeze(1)
23
+ .unsqueeze(2)
24
+ .unsqueeze(3)
25
+ .permute(1, 0, 2, 3)
26
+ )
27
+ std = (
28
+ flatten.std(1)
29
+ .unsqueeze(1)
30
+ .unsqueeze(2)
31
+ .unsqueeze(3)
32
+ .permute(1, 0, 2, 3)
33
+ )
34
+
35
+ self.loc.data.copy_(-mean)
36
+ self.scale.data.copy_(1 / (std + 1e-6))
37
+
38
+ def forward(self, input, reverse=False):
39
+ if reverse:
40
+ return self.reverse(input)
41
+ if len(input.shape) == 2:
42
+ input = input[:, :, None, None]
43
+ squeeze = True
44
+ else:
45
+ squeeze = False
46
+
47
+ _, _, height, width = input.shape
48
+
49
+ if self.training and self.initialized.item() == 0:
50
+ self.initialize(input)
51
+ self.initialized.fill_(1)
52
+
53
+ h = self.scale * (input + self.loc)
54
+
55
+ if squeeze:
56
+ h = h.squeeze(-1).squeeze(-1)
57
+
58
+ if self.logdet:
59
+ log_abs = torch.log(torch.abs(self.scale))
60
+ logdet = height * width * torch.sum(log_abs)
61
+ logdet = logdet * torch.ones(input.shape[0]).to(input)
62
+ return h, logdet
63
+
64
+ return h
65
+
66
+ def reverse(self, output):
67
+ if self.training and self.initialized.item() == 0:
68
+ if not self.allow_reverse_init:
69
+ raise RuntimeError(
70
+ "Initializing ActNorm in reverse direction is "
71
+ "disabled by default. Use allow_reverse_init=True to enable."
72
+ )
73
+ else:
74
+ self.initialize(output)
75
+ self.initialized.fill_(1)
76
+
77
+ if len(output.shape) == 2:
78
+ output = output[:, :, None, None]
79
+ squeeze = True
80
+ else:
81
+ squeeze = False
82
+
83
+ h = output / self.scale - self.loc
84
+
85
+ if squeeze:
86
+ h = h.squeeze(-1).squeeze(-1)
87
+ return h
88
+
89
+ def weights_init(m):
90
+ classname = m.__class__.__name__
91
+ if classname.find('Conv') != -1:
92
+ nn.init.normal_(m.weight.data, 0.0, 0.02)
93
+ elif classname.find('BatchNorm') != -1:
94
+ nn.init.normal_(m.weight.data, 1.0, 0.02)
95
+ nn.init.constant_(m.bias.data, 0)
96
+
97
+
98
+ class NLayerDiscriminator(nn.Module):
99
+ """Defines a PatchGAN discriminator as in Pix2Pix
100
+ --> see https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/models/networks.py
101
+ """
102
+ def __init__(self, input_nc=3, ndf=64, n_layers=3, use_actnorm=False):
103
+ """Construct a PatchGAN discriminator
104
+ Parameters:
105
+ input_nc (int) -- the number of channels in input images
106
+ ndf (int) -- the number of filters in the last conv layer
107
+ n_layers (int) -- the number of conv layers in the discriminator
108
+ norm_layer -- normalization layer
109
+ """
110
+ super(NLayerDiscriminator, self).__init__()
111
+ if not use_actnorm:
112
+ norm_layer = nn.BatchNorm2d
113
+ else:
114
+ norm_layer = ActNorm
115
+ if type(norm_layer) == functools.partial: # no need to use bias as BatchNorm2d has affine parameters
116
+ use_bias = norm_layer.func != nn.BatchNorm2d
117
+ else:
118
+ use_bias = norm_layer != nn.BatchNorm2d
119
+
120
+ kw = 4
121
+ padw = 1
122
+ sequence = [nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2, True)]
123
+ nf_mult = 1
124
+ nf_mult_prev = 1
125
+ for n in range(1, n_layers): # gradually increase the number of filters
126
+ nf_mult_prev = nf_mult
127
+ nf_mult = min(2 ** n, 8)
128
+ sequence += [
129
+ nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=2, padding=padw, bias=use_bias),
130
+ norm_layer(ndf * nf_mult),
131
+ nn.LeakyReLU(0.2, True)
132
+ ]
133
+
134
+ nf_mult_prev = nf_mult
135
+ nf_mult = min(2 ** n_layers, 8)
136
+ sequence += [
137
+ nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=1, padding=padw, bias=use_bias),
138
+ norm_layer(ndf * nf_mult),
139
+ nn.LeakyReLU(0.2, True)
140
+ ]
141
+ # output 1 channel prediction map
142
+ sequence += [nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)]
143
+ self.main = nn.Sequential(*sequence)
144
+
145
+ def forward(self, input):
146
+ """Standard forward."""
147
+ return self.main(input)
148
+
149
+ class NLayerDiscriminator1dFeats(NLayerDiscriminator):
150
+ """Defines a PatchGAN discriminator as in Pix2Pix
151
+ --> see https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/models/networks.py
152
+ """
153
+ def __init__(self, input_nc=3, ndf=64, n_layers=3, use_actnorm=False):
154
+ """Construct a PatchGAN discriminator
155
+ Parameters:
156
+ input_nc (int) -- the number of channels in input feats
157
+ ndf (int) -- the number of filters in the last conv layer
158
+ n_layers (int) -- the number of conv layers in the discriminator
159
+ norm_layer -- normalization layer
160
+ """
161
+ super().__init__(input_nc=input_nc, ndf=64, n_layers=n_layers, use_actnorm=use_actnorm)
162
+
163
+ if not use_actnorm:
164
+ norm_layer = nn.BatchNorm1d
165
+ else:
166
+ norm_layer = ActNorm
167
+ if type(norm_layer) == functools.partial: # no need to use bias as BatchNorm has affine parameters
168
+ use_bias = norm_layer.func != nn.BatchNorm1d
169
+ else:
170
+ use_bias = norm_layer != nn.BatchNorm1d
171
+
172
+ kw = 4
173
+ padw = 1
174
+ sequence = [nn.Conv1d(input_nc, input_nc//2, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2, True)]
175
+ nf_mult = input_nc//2
176
+ nf_mult_prev = 1
177
+ for n in range(1, n_layers): # gradually decrease the number of filters
178
+ nf_mult_prev = nf_mult
179
+ nf_mult = max(nf_mult_prev // (2 ** n), 8)
180
+ sequence += [
181
+ nn.Conv1d(nf_mult_prev, nf_mult, kernel_size=kw, stride=2, padding=padw, bias=use_bias),
182
+ norm_layer(nf_mult),
183
+ nn.LeakyReLU(0.2, True)
184
+ ]
185
+
186
+ nf_mult_prev = nf_mult
187
+ nf_mult = max(nf_mult_prev // (2 ** n), 8)
188
+ sequence += [
189
+ nn.Conv1d(nf_mult_prev, nf_mult, kernel_size=kw, stride=1, padding=padw, bias=use_bias),
190
+ norm_layer(nf_mult),
191
+ nn.LeakyReLU(0.2, True)
192
+ ]
193
+ nf_mult_prev = nf_mult
194
+ nf_mult = max(nf_mult_prev // (2 ** n), 8)
195
+ sequence += [
196
+ nn.Conv1d(nf_mult_prev, nf_mult, kernel_size=kw, stride=1, padding=padw, bias=use_bias),
197
+ norm_layer(nf_mult),
198
+ nn.LeakyReLU(0.2, True)
199
+ ]
200
+ # output 1 channel prediction map
201
+ sequence += [nn.Conv1d(nf_mult, 1, kernel_size=kw, stride=1, padding=padw)]
202
+ self.main = nn.Sequential(*sequence)
203
+
204
+
205
+ class NLayerDiscriminator1dSpecs(NLayerDiscriminator):
206
+ """Defines a PatchGAN discriminator as in Pix2Pix
207
+ --> see https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/models/networks.py
208
+ """
209
+ def __init__(self, input_nc=80, ndf=64, n_layers=3, use_actnorm=False):
210
+ """Construct a PatchGAN discriminator
211
+ Parameters:
212
+ input_nc (int) -- the number of channels in input specs
213
+ ndf (int) -- the number of filters in the last conv layer
214
+ n_layers (int) -- the number of conv layers in the discriminator
215
+ norm_layer -- normalization layer
216
+ """
217
+ super().__init__(input_nc=input_nc, ndf=64, n_layers=n_layers, use_actnorm=use_actnorm)
218
+
219
+ if not use_actnorm:
220
+ norm_layer = nn.BatchNorm1d
221
+ else:
222
+ norm_layer = ActNorm
223
+ if type(norm_layer) == functools.partial: # no need to use bias as BatchNorm has affine parameters
224
+ use_bias = norm_layer.func != nn.BatchNorm1d
225
+ else:
226
+ use_bias = norm_layer != nn.BatchNorm1d
227
+
228
+ kw = 4
229
+ padw = 1
230
+ sequence = [nn.Conv1d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2, True)]
231
+ nf_mult = 1
232
+ nf_mult_prev = 1
233
+ for n in range(1, n_layers): # gradually decrease the number of filters
234
+ nf_mult_prev = nf_mult
235
+ nf_mult = min(2 ** n, 8)
236
+ sequence += [
237
+ nn.Conv1d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=2, padding=padw, bias=use_bias),
238
+ norm_layer(ndf * nf_mult),
239
+ nn.LeakyReLU(0.2, True)
240
+ ]
241
+
242
+ nf_mult_prev = nf_mult
243
+ nf_mult = min(2 ** n_layers, 8)
244
+ sequence += [
245
+ nn.Conv1d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=1, padding=padw, bias=use_bias),
246
+ norm_layer(ndf * nf_mult),
247
+ nn.LeakyReLU(0.2, True)
248
+ ]
249
+ # output 1 channel prediction map
250
+ sequence += [nn.Conv1d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)]
251
+ self.main = nn.Sequential(*sequence)
252
+
253
+ def forward(self, input):
254
+ """Standard forward."""
255
+ # (B, C, L)
256
+ input = input.squeeze(1)
257
+ input = self.main(input)
258
+ return input
259
+
260
+
261
+ if __name__ == '__main__':
262
+ import torch
263
+
264
+ ## FEATURES
265
+ disc_in_channels = 2048
266
+ disc_num_layers = 2
267
+ use_actnorm = False
268
+ disc_ndf = 64
269
+ discriminator = NLayerDiscriminator1dFeats(input_nc=disc_in_channels, n_layers=disc_num_layers,
270
+ use_actnorm=use_actnorm, ndf=disc_ndf).apply(weights_init)
271
+ inputs = torch.rand((6, 2048, 212))
272
+ outputs = discriminator(inputs)
273
+ print(outputs.shape)
274
+
275
+ ## AUDIO
276
+ disc_in_channels = 1
277
+ disc_num_layers = 3
278
+ use_actnorm = False
279
+ disc_ndf = 64
280
+ discriminator = NLayerDiscriminator(input_nc=disc_in_channels, n_layers=disc_num_layers,
281
+ use_actnorm=use_actnorm, ndf=disc_ndf).apply(weights_init)
282
+ inputs = torch.rand((6, 1, 80, 848))
283
+ outputs = discriminator(inputs)
284
+ print(outputs.shape)
285
+
286
+ ## IMAGE
287
+ disc_in_channels = 3
288
+ disc_num_layers = 3
289
+ use_actnorm = False
290
+ disc_ndf = 64
291
+ discriminator = NLayerDiscriminator(input_nc=disc_in_channels, n_layers=disc_num_layers,
292
+ use_actnorm=use_actnorm, ndf=disc_ndf).apply(weights_init)
293
+ inputs = torch.rand((6, 3, 256, 256))
294
+ outputs = discriminator(inputs)
295
+ print(outputs.shape)
ldm/modules/discriminator/multi_window_disc.py ADDED
@@ -0,0 +1,196 @@
1
+ import numpy as np
2
+ import torch
3
+ import torch.nn as nn
4
+
5
+
6
+ class Discriminator2DFactory(nn.Module):
7
+ def __init__(self, time_length, freq_length=80, kernel=(3, 3), c_in=1, hidden_size=128,
8
+ norm_type='bn', reduction='sum'):
9
+ super(Discriminator2DFactory, self).__init__()
10
+ padding = (kernel[0] // 2, kernel[1] // 2)
11
+
12
+ def discriminator_block(in_filters, out_filters, first=False):
13
+ """
14
+ Input: (B, in, 2H, 2W)
15
+ Output:(B, out, H, W)
16
+ """
17
+ conv = nn.Conv2d(in_filters, out_filters, kernel, (2, 2), padding)
18
+ if norm_type == 'sn':
19
+ conv = nn.utils.spectral_norm(conv)
20
+ block = [
21
+ conv, # padding = kernel//2
22
+ nn.LeakyReLU(0.2, inplace=True),
23
+ nn.Dropout2d(0.25)
24
+ ]
25
+ if norm_type == 'bn' and not first:
26
+ block.append(nn.BatchNorm2d(out_filters, 0.8))
27
+ if norm_type == 'in' and not first:
28
+ block.append(nn.InstanceNorm2d(out_filters, affine=True))
29
+ block = nn.Sequential(*block)
30
+ return block
31
+
32
+ self.model = nn.ModuleList([
33
+ discriminator_block(c_in, hidden_size, first=True),
34
+ discriminator_block(hidden_size, hidden_size),
35
+ discriminator_block(hidden_size, hidden_size),
36
+ ])
37
+
38
+ self.reduction = reduction
39
+ ds_size = (time_length // 2 ** 3, (freq_length + 7) // 2 ** 3)
40
+ if reduction != 'none':
41
+ # The height and width of downsampled image
42
+ self.adv_layer = nn.Linear(hidden_size * ds_size[0] * ds_size[1], 1)
43
+ else:
44
+ self.adv_layer = nn.Linear(hidden_size * ds_size[1], 1)
45
+
46
+ def forward(self, x):
47
+ """
48
+
49
+ :param x: [B, C, T, n_bins]
50
+ :return: validity: [B, 1], h: List of hiddens
51
+ """
52
+ h = []
53
+ for l in self.model:
54
+ x = l(x)
55
+ h.append(x)
56
+ if self.reduction != 'none':
57
+ x = x.view(x.shape[0], -1)
58
+ validity = self.adv_layer(x) # [B, 1]
59
+ else:
60
+ B, _, T_, _ = x.shape
61
+ x = x.transpose(1, 2).reshape(B, T_, -1)
62
+ validity = self.adv_layer(x)[:, :, 0] # [B, T]
63
+ return validity, h
64
+
65
+
66
+ class MultiWindowDiscriminator(nn.Module):
67
+ def __init__(self, time_lengths, cond_size=0, freq_length=80, kernel=(3, 3),
68
+ c_in=1, hidden_size=128, norm_type='bn', reduction='sum'):
69
+ super(MultiWindowDiscriminator, self).__init__()
70
+ self.win_lengths = time_lengths
71
+ self.reduction = reduction
72
+
73
+ self.conv_layers = nn.ModuleList()
74
+ if cond_size > 0:
75
+ self.cond_proj_layers = nn.ModuleList()
76
+ self.mel_proj_layers = nn.ModuleList()
77
+ for time_length in time_lengths:
78
+ conv_layer = [
79
+ Discriminator2DFactory(
80
+ time_length, freq_length, kernel, c_in=c_in, hidden_size=hidden_size,
81
+ norm_type=norm_type, reduction=reduction)
82
+ ]
83
+ self.conv_layers += conv_layer
84
+ if cond_size > 0:
85
+ self.cond_proj_layers.append(nn.Linear(cond_size, freq_length))
86
+ self.mel_proj_layers.append(nn.Linear(freq_length, freq_length))
87
+
88
+ def forward(self, x, x_len, cond=None, start_frames_wins=None):
89
+ '''
90
+ Args:
91
+ x (tensor): input mel, (B, c_in, T, n_bins).
92
+ x_len (tensor): length of each mel, (B,).
93
+
94
+ Returns:
95
+ tensor : (B).
96
+ '''
97
+ validity = []
98
+ if start_frames_wins is None:
99
+ start_frames_wins = [None] * len(self.conv_layers)
100
+ h = []
101
+ for i, start_frames in zip(range(len(self.conv_layers)), start_frames_wins):
102
+ x_clip, c_clip, start_frames = self.clip(
103
+ x, cond, x_len, self.win_lengths[i], start_frames) # (B, win_length, C)
104
+ start_frames_wins[i] = start_frames
105
+ if x_clip is None:
106
+ continue
107
+ if cond is not None:
108
+ x_clip = self.mel_proj_layers[i](x_clip) # (B, 1, win_length, C)
109
+ c_clip = self.cond_proj_layers[i](c_clip)[:, None] # (B, 1, win_length, C)
110
+ x_clip = x_clip + c_clip
111
+ x_clip, h_ = self.conv_layers[i](x_clip)
112
+ h += h_
113
+ validity.append(x_clip)
114
+ if len(validity) != len(self.conv_layers):
115
+ return None, start_frames_wins, h
116
+ if self.reduction == 'sum':
117
+ validity = sum(validity) # [B]
118
+ elif self.reduction == 'stack':
119
+ validity = torch.stack(validity, -1) # [B, W_L]
120
+ elif self.reduction == 'none':
121
+ validity = torch.cat(validity, -1) # [B, W_sum]
122
+ return validity, start_frames_wins, h
123
+
124
+ def clip(self, x, cond, x_len, win_length, start_frames=None):
125
+ '''Randomly clip x to win_length.
126
+ Args:
127
+ x (tensor) : (B, c_in, T, n_bins).
128
+ cond (tensor) : (B, T, H).
129
+ x_len (tensor) : (B,).
130
+ win_length (int): target clip length
131
+
132
+ Returns:
133
+ (tensor) : (B, c_in, win_length, n_bins).
134
+
135
+ '''
136
+ T_start = 0
137
+ T_end = x_len.max() - win_length
138
+ if T_end < 0:
139
+ return None, None, start_frames
140
+ T_end = T_end.item()
141
+ if start_frames is None:
142
+ start_frame = np.random.randint(low=T_start, high=T_end + 1)
143
+ start_frames = [start_frame] * x.size(0)
144
+ else:
145
+ start_frame = start_frames[0]
146
+ x_batch = x[:, :, start_frame: start_frame + win_length]
147
+ c_batch = cond[:, start_frame: start_frame + win_length] if cond is not None else None
148
+ return x_batch, c_batch, start_frames
149
+
150
+
151
+ class Discriminator(nn.Module):
152
+ def __init__(self, time_lengths=[32, 64, 128], freq_length=80, cond_size=0, kernel=(3, 3), c_in=1,
153
+ hidden_size=128, norm_type='bn', reduction='sum', uncond_disc=True):
154
+ super(Discriminator, self).__init__()
155
+ self.time_lengths = time_lengths
156
+ self.cond_size = cond_size
157
+ self.reduction = reduction
158
+ self.uncond_disc = uncond_disc
159
+ if uncond_disc:
160
+ self.discriminator = MultiWindowDiscriminator(
161
+ freq_length=freq_length,
162
+ time_lengths=time_lengths,
163
+ kernel=kernel,
164
+ c_in=c_in, hidden_size=hidden_size, norm_type=norm_type,
165
+ reduction=reduction
166
+ )
167
+ if cond_size > 0:
168
+ self.cond_disc = MultiWindowDiscriminator(
169
+ freq_length=freq_length,
170
+ time_lengths=time_lengths,
171
+ cond_size=cond_size,
172
+ kernel=kernel,
173
+ c_in=c_in, hidden_size=hidden_size, norm_type=norm_type,
174
+ reduction=reduction
175
+ )
176
+
177
+ def forward(self, x, cond=None, start_frames_wins=None):
178
+ """
179
+
180
+ :param x: [B, T, 80]
181
+ :param cond: [B, T, cond_size]
182
+ :param return_y_only:
183
+ :return:
184
+ """
185
+ if len(x.shape) == 3:
186
+ x = x[:, None, :, :]
187
+ x_len = x.sum([1, -1]).ne(0).int().sum([-1])
188
+ ret = {'y_c': None, 'y': None}
189
+ if self.uncond_disc:
190
+ ret['y'], start_frames_wins, ret['h'] = self.discriminator(
191
+ x, x_len, start_frames_wins=start_frames_wins)
192
+ if self.cond_size > 0 and cond is not None:
193
+ ret['y_c'], start_frames_wins, ret['h_c'] = self.cond_disc(
194
+ x, x_len, cond, start_frames_wins=start_frames_wins)
195
+ ret['start_frames_wins'] = start_frames_wins
196
+ return ret
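
The Discriminator above wraps MultiWindowDiscriminator so that random windows of several lengths are each scored by a small PatchGAN-style 2D discriminator and, with the default 'sum' reduction, combined into one score per example. A rough usage sketch (not part of the commit), assuming the module path ldm/modules/discriminator/multi_window_disc.py shown in this diff; the shapes are illustrative.

import torch
from ldm.modules.discriminator.multi_window_disc import Discriminator

disc = Discriminator(time_lengths=[32, 64, 128], freq_length=80, uncond_disc=True)

mel = torch.randn(4, 160, 80)   # [B, T, n_bins] mel spectrograms
ret = disc(mel)                 # internally clips random 32/64/128-frame windows
print(ret['y'].shape)           # [4, 1] with the default 'sum' reduction
print(len(ret['h']))            # hidden activations collected from each window discriminator
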
ldm/modules/distributions/__init__.py ADDED
File without changes
ldm/modules/distributions/distributions.py ADDED
@@ -0,0 +1,92 @@
1
+ import torch
2
+ import numpy as np
3
+
4
+
5
+ class AbstractDistribution:
6
+ def sample(self):
7
+ raise NotImplementedError()
8
+
9
+ def mode(self):
10
+ raise NotImplementedError()
11
+
12
+
13
+ class DiracDistribution(AbstractDistribution):
14
+ def __init__(self, value):
15
+ self.value = value
16
+
17
+ def sample(self):
18
+ return self.value
19
+
20
+ def mode(self):
21
+ return self.value
22
+
23
+
24
+ class DiagonalGaussianDistribution(object):
25
+ def __init__(self, parameters, deterministic=False):
26
+ self.parameters = parameters
27
+ self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
28
+ self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
29
+ self.deterministic = deterministic
30
+ self.std = torch.exp(0.5 * self.logvar)
31
+ self.var = torch.exp(self.logvar)
32
+ if self.deterministic:
33
+ self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device)
34
+
35
+ def sample(self):
36
+ x = self.mean + self.std * torch.randn(self.mean.shape).to(device=self.parameters.device)
37
+ return x
38
+
39
+ def kl(self, other=None):
40
+ if self.deterministic:
41
+ return torch.Tensor([0.])
42
+ else:
43
+ if other is None:
44
+ return 0.5 * torch.sum(torch.pow(self.mean, 2)
45
+ + self.var - 1.0 - self.logvar,
46
+ dim=[1, 2, 3])
47
+ else:
48
+ return 0.5 * torch.sum(
49
+ torch.pow(self.mean - other.mean, 2) / other.var
50
+ + self.var / other.var - 1.0 - self.logvar + other.logvar,
51
+ dim=[1, 2, 3])
52
+
53
+ def nll(self, sample, dims=[1,2,3]):
54
+ if self.deterministic:
55
+ return torch.Tensor([0.])
56
+ logtwopi = np.log(2.0 * np.pi)
57
+ return 0.5 * torch.sum(
58
+ logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
59
+ dim=dims)
60
+
61
+ def mode(self):
62
+ return self.mean
63
+
64
+
65
+ def normal_kl(mean1, logvar1, mean2, logvar2):
66
+ """
67
+ source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12
68
+ Compute the KL divergence between two gaussians.
69
+ Shapes are automatically broadcasted, so batches can be compared to
70
+ scalars, among other use cases.
71
+ """
72
+ tensor = None
73
+ for obj in (mean1, logvar1, mean2, logvar2):
74
+ if isinstance(obj, torch.Tensor):
75
+ tensor = obj
76
+ break
77
+ assert tensor is not None, "at least one argument must be a Tensor"
78
+
79
+ # Force variances to be Tensors. Broadcasting helps convert scalars to
80
+ # Tensors, but it does not work for torch.exp().
81
+ logvar1, logvar2 = [
82
+ x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor)
83
+ for x in (logvar1, logvar2)
84
+ ]
85
+
86
+ return 0.5 * (
87
+ -1.0
88
+ + logvar2
89
+ - logvar1
90
+ + torch.exp(logvar1 - logvar2)
91
+ + ((mean1 - mean2) ** 2) * torch.exp(-logvar2)
92
+ )
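
distributions.py supplies the diagonal Gaussian posterior used by the autoencoders in this repo. A short sketch of the intended call pattern (not part of the diff), under the assumption that an encoder produces a moments tensor whose channel dimension stacks mean and log-variance:

import torch
from ldm.modules.distributions.distributions import DiagonalGaussianDistribution

# 8 channels = 4 latent channels of mean stacked with 4 of log-variance.
moments = torch.randn(2, 8, 16, 16)
posterior = DiagonalGaussianDistribution(moments)

z = posterior.sample()   # reparameterized sample, shape [2, 4, 16, 16]
kl = posterior.kl()      # KL against a standard normal, one value per batch item: shape [2]
nll = posterior.nll(z)   # negative log-likelihood of a sample, shape [2]
print(z.shape, kl.shape, nll.shape)
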
ldm/modules/ema.py ADDED
@@ -0,0 +1,76 @@
1
+ import torch
2
+ from torch import nn
3
+
4
+
5
+ class LitEma(nn.Module):
6
+ def __init__(self, model, decay=0.9999, use_num_upates=True):
7
+ super().__init__()
8
+ if decay < 0.0 or decay > 1.0:
9
+ raise ValueError('Decay must be between 0 and 1')
10
+
11
+ self.m_name2s_name = {}
12
+ self.register_buffer('decay', torch.tensor(decay, dtype=torch.float32))
13
+ self.register_buffer('num_updates', torch.tensor(0,dtype=torch.int) if use_num_upates
14
+ else torch.tensor(-1,dtype=torch.int))
15
+
16
+ for name, p in model.named_parameters():
17
+ if p.requires_grad:
18
+ # remove '.' since it is not allowed in buffer names
19
+ s_name = name.replace('.','')
20
+ self.m_name2s_name.update({name:s_name})
21
+ self.register_buffer(s_name,p.clone().detach().data)
22
+
23
+ self.collected_params = []
24
+
25
+ def forward(self,model):
26
+ decay = self.decay
27
+
28
+ if self.num_updates >= 0:
29
+ self.num_updates += 1
30
+ decay = min(self.decay,(1 + self.num_updates) / (10 + self.num_updates))
31
+
32
+ one_minus_decay = 1.0 - decay
33
+
34
+ with torch.no_grad():
35
+ m_param = dict(model.named_parameters())
36
+ shadow_params = dict(self.named_buffers())
37
+
38
+ for key in m_param:
39
+ if m_param[key].requires_grad:
40
+ sname = self.m_name2s_name[key]
41
+ shadow_params[sname] = shadow_params[sname].type_as(m_param[key])
42
+ shadow_params[sname].sub_(one_minus_decay * (shadow_params[sname] - m_param[key]))
43
+ else:
44
+ assert not key in self.m_name2s_name
45
+
46
+ def copy_to(self, model):
47
+ m_param = dict(model.named_parameters())
48
+ shadow_params = dict(self.named_buffers())
49
+ for key in m_param:
50
+ if m_param[key].requires_grad:
51
+ m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data)
52
+ else:
53
+ assert not key in self.m_name2s_name
54
+
55
+ def store(self, parameters):
56
+ """
57
+ Save the current parameters for restoring later.
58
+ Args:
59
+ parameters: Iterable of `torch.nn.Parameter`; the parameters to be
60
+ temporarily stored.
61
+ """
62
+ self.collected_params = [param.clone() for param in parameters]
63
+
64
+ def restore(self, parameters):
65
+ """
66
+ Restore the parameters stored with the `store` method.
67
+ Useful to validate the model with EMA parameters without affecting the
68
+ original optimization process. Store the parameters before the
69
+ `copy_to` method. After validation (or model saving), use this to
70
+ restore the former parameters.
71
+ Args:
72
+ parameters: Iterable of `torch.nn.Parameter`; the parameters to be
73
+ updated with the stored parameters.
74
+ """
75
+ for c_param, param in zip(self.collected_params, parameters):
76
+ param.data.copy_(c_param.data)
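
LitEma keeps an exponential moving average of every trainable parameter as a registered buffer. A minimal sketch of the update / evaluate / restore cycle (not part of the commit), with a toy nn.Linear standing in for a real diffusion model:

import torch
from torch import nn
from ldm.modules.ema import LitEma

model = nn.Linear(4, 4)
ema = LitEma(model, decay=0.999)

# After every optimizer step, push the current weights into the shadow copy.
for _ in range(10):
    with torch.no_grad():
        model.weight.add_(0.01 * torch.randn_like(model.weight))  # stand-in for a training update
    ema(model)

# Evaluate with the averaged weights, then restore the raw training weights afterwards.
ema.store(model.parameters())
ema.copy_to(model)
# ... run validation with the EMA weights here ...
ema.restore(model.parameters())
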
ldm/modules/encoders/CLAP/CLAPWrapper.py ADDED
@@ -0,0 +1,257 @@
1
+ import random
2
+ import torchaudio
3
+ from torch._six import string_classes
4
+ import collections
5
+ import re
6
+ import torch.nn.functional as F
7
+ import numpy as np
8
+ from transformers import AutoTokenizer
9
+ from ldm.modules.encoders.CLAP.utils import read_config_as_args
10
+ from ldm.modules.encoders.CLAP.clap import CLAP
11
+ import math
12
+ import torchaudio.transforms as T
13
+ import os
14
+ import torch
15
+ from importlib_resources import files
16
+
17
+
18
+ class CLAPWrapper():
19
+ """
20
+ A class for interfacing with the CLAP model.
21
+ """
22
+
23
+ def __init__(self, model_fp, device):
24
+ self.np_str_obj_array_pattern = re.compile(r'[SaUO]')
25
+ self.file_path = os.path.realpath(__file__)
26
+ self.default_collate_err_msg_format = (
27
+ "default_collate: batch must contain tensors, numpy arrays, numbers, "
28
+ "dicts or lists; found {}")
29
+ self.config_as_str = files('ldm').joinpath('modules/encoders/CLAP/config.yml').read_text()
30
+ self.model_fp = model_fp
31
+ self.device = device
32
+ self.clap, self.tokenizer, self.args = self.load_clap()
33
+
34
+ def load_clap(self):
35
+ r"""Load CLAP model with args from config file"""
36
+
37
+ args = read_config_as_args(self.config_as_str, is_config_str=True)
38
+
39
+ if 'bert' in args.text_model:
40
+ self.token_keys = ['input_ids', 'token_type_ids', 'attention_mask']
41
+ else:
42
+ self.token_keys = ['input_ids', 'attention_mask']
43
+
44
+ clap = CLAP(
45
+ audioenc_name=args.audioenc_name,
46
+ sample_rate=args.sampling_rate,
47
+ window_size=args.window_size,
48
+ hop_size=args.hop_size,
49
+ mel_bins=args.mel_bins,
50
+ fmin=args.fmin,
51
+ fmax=args.fmax,
52
+ classes_num=args.num_classes,
53
+ out_emb=args.out_emb,
54
+ text_model=args.text_model,
55
+ transformer_embed_dim=args.transformer_embed_dim,
56
+ d_proj=args.d_proj
57
+ )
58
+
59
+ # Load pretrained weights for model
60
+ model_state_dict = torch.load(self.model_fp, map_location=torch.device('cpu'))['model']
61
+ clap.load_state_dict(model_state_dict)
62
+
63
+ clap.eval() # set clap in eval mode
64
+ tokenizer = AutoTokenizer.from_pretrained(args.text_model)
65
+
66
+ clap = clap.to(self.device)
67
+ tokenizer = tokenizer.to(self.device)
68
+
69
+ return clap, tokenizer, args
70
+
71
+ def default_collate(self, batch):
72
+ r"""Puts each data field into a tensor with outer dimension batch size"""
73
+ elem = batch[0]
74
+ elem_type = type(elem)
75
+ if isinstance(elem, torch.Tensor):
76
+ out = None
77
+ if torch.utils.data.get_worker_info() is not None:
78
+ # If we're in a background process, concatenate directly into a
79
+ # shared memory tensor to avoid an extra copy
80
+ numel = sum([x.numel() for x in batch])
81
+ storage = elem.storage()._new_shared(numel)
82
+ out = elem.new(storage)
83
+ return torch.stack(batch, 0, out=out)
84
+ elif elem_type.__module__ == 'numpy' and elem_type.__name__ != 'str_' \
85
+ and elem_type.__name__ != 'string_':
86
+ if elem_type.__name__ == 'ndarray' or elem_type.__name__ == 'memmap':
87
+ # array of string classes and object
88
+ if self.np_str_obj_array_pattern.search(elem.dtype.str) is not None:
89
+ raise TypeError(
90
+ self.default_collate_err_msg_format.format(elem.dtype))
91
+
92
+ return self.default_collate([torch.as_tensor(b) for b in batch])
93
+ elif elem.shape == (): # scalars
94
+ return torch.as_tensor(batch)
95
+ elif isinstance(elem, float):
96
+ return torch.tensor(batch, dtype=torch.float64)
97
+ elif isinstance(elem, int):
98
+ return torch.tensor(batch)
99
+ elif isinstance(elem, string_classes):
100
+ return batch
101
+ elif isinstance(elem, collections.abc.Mapping):
102
+ return {key: self.default_collate([d[key] for d in batch]) for key in elem}
103
+ elif isinstance(elem, tuple) and hasattr(elem, '_fields'): # namedtuple
104
+ return elem_type(*(self.default_collate(samples) for samples in zip(*batch)))
105
+ elif isinstance(elem, collections.abc.Sequence):
106
+ # check to make sure that the elements in batch have consistent size
107
+ it = iter(batch)
108
+ elem_size = len(next(it))
109
+ if not all(len(elem) == elem_size for elem in it):
110
+ raise RuntimeError(
111
+ 'each element in list of batch should be of equal size')
112
+ transposed = zip(*batch)
113
+ return [self.default_collate(samples) for samples in transposed]
114
+
115
+ raise TypeError(self.default_collate_err_msg_format.format(elem_type))
116
+
117
+ def load_audio_into_tensor(self, audio_path, audio_duration, resample=False):
118
+ r"""Loads audio file and returns raw audio."""
119
+ # Randomly sample a segment of audio_duration from the clip or pad to match duration
120
+ audio_time_series, sample_rate = torchaudio.load(audio_path)
121
+ resample_rate = self.args.sampling_rate
122
+ if resample:
123
+ resampler = T.Resample(sample_rate, resample_rate)
124
+ audio_time_series = resampler(audio_time_series)
125
+ audio_time_series = audio_time_series.reshape(-1)
126
+
127
+ # audio_time_series is shorter than predefined audio duration,
128
+ # so audio_time_series is extended
129
+ if audio_duration*sample_rate >= audio_time_series.shape[0]:
130
+ repeat_factor = int(np.ceil((audio_duration*sample_rate) /
131
+ audio_time_series.shape[0]))
132
+ # Repeat audio_time_series by repeat_factor to match audio_duration
133
+ audio_time_series = audio_time_series.repeat(repeat_factor)
134
+ # remove excess part of audio_time_series
135
+ audio_time_series = audio_time_series[0:audio_duration*sample_rate]
136
+ else:
137
+ # audio_time_series is longer than predefined audio duration,
138
+ # so audio_time_series is trimmed
139
+ start_index = random.randrange(
140
+ audio_time_series.shape[0] - audio_duration*sample_rate)
141
+ audio_time_series = audio_time_series[start_index:start_index +
142
+ audio_duration*sample_rate]
143
+ return torch.FloatTensor(audio_time_series)
144
+
145
+ def preprocess_audio(self, audio_files, resample):
146
+ r"""Load list of audio files and return raw audio"""
147
+ audio_tensors = []
148
+ for audio_file in audio_files:
149
+ audio_tensor = self.load_audio_into_tensor(
150
+ audio_file, self.args.duration, resample)
151
+ audio_tensor = audio_tensor.reshape(1, -1).to(self.device)
152
+ audio_tensors.append(audio_tensor)
153
+ return self.default_collate(audio_tensors)
154
+
155
+ def preprocess_text(self, text_queries, text_len=100):
156
+ r"""Load list of class labels and return tokenized text"""
157
+ device = next(self.clap.parameters()).device
158
+ tokenized_texts = []
159
+ for ttext in text_queries:
160
+ tok = self.tokenizer.encode_plus(
161
+ text=ttext, add_special_tokens=True, max_length=text_len, pad_to_max_length=True, return_tensors="pt")
162
+ for key in self.token_keys:
163
+ tok[key] = tok[key].reshape(-1).to(device)
164
+ tokenized_texts.append(tok)
165
+ return self.default_collate(tokenized_texts)
166
+
167
+ def get_text_embeddings(self, class_labels):
168
+ r"""Load list of class labels and return text embeddings"""
169
+ preprocessed_text = self.preprocess_text(class_labels)
170
+ text_embeddings = self._get_text_embeddings(preprocessed_text)
171
+ text_embeddings = text_embeddings/torch.norm(text_embeddings, dim=-1, keepdim=True)
172
+ return text_embeddings
173
+
174
+ def get_audio_embeddings(self, audio_files, resample):
175
+ r"""Load list of audio files and return audio embeddings"""
176
+ preprocessed_audio = self.preprocess_audio(audio_files, resample)
177
+ audio_embeddings = self._get_audio_embeddings(preprocessed_audio)
178
+ audio_embeddings = audio_embeddings/torch.norm(audio_embeddings, dim=-1, keepdim=True)
179
+ return audio_embeddings
180
+
181
+ def _get_text_embeddings(self, preprocessed_text):
182
+ r"""Load preprocessed text and return text embeddings"""
183
+ with torch.no_grad():
184
+ text_embeddings = self.clap.caption_encoder(preprocessed_text)
185
+ text_embeddings = text_embeddings/torch.norm(text_embeddings, dim=-1, keepdim=True)
186
+ return text_embeddings
187
+
188
+ def _get_audio_embeddings(self, preprocessed_audio):
189
+ r"""Load preprocessed audio and return audio embeddings"""
190
+ with torch.no_grad():
191
+ preprocessed_audio = preprocessed_audio.reshape(
192
+ preprocessed_audio.shape[0], preprocessed_audio.shape[2])
193
+ # [0] is the audio embedding, [1] holds the output class probabilities
194
+ audio_embeddings = self.clap.audio_encoder(preprocessed_audio)[0]
195
+ audio_embeddings = audio_embeddings/torch.norm(audio_embeddings, dim=-1, keepdim=True)
196
+ return audio_embeddings
197
+
198
+ def compute_similarity(self, audio_embeddings, text_embeddings):
199
+ r"""Compute similarity between text and audio embeddings"""
200
+ logit_scale = self.clap.logit_scale.exp()
201
+ similarity = logit_scale*text_embeddings @ audio_embeddings.T
202
+ return similarity.T
203
+
204
+ def _generic_batch_inference(self, func, *args):
205
+ r"""Process audio and/or text per batch"""
206
+ input_tmp = args[0]
207
+ batch_size = args[-1]
208
+ # args[0] has audio_files, args[1] has class_labels
209
+ inputs = [args[0], args[1]] if len(args) == 3 else [args[0]]
210
+ args0_len = len(args[0])
211
+ # compute text_embeddings once for all the audio_files batches
212
+ if len(inputs) == 2:
213
+ text_embeddings = self.get_text_embeddings(args[1])
214
+ inputs = [args[0], args[1], text_embeddings]
215
+ dataset_idx = 0
216
+ for _ in range(math.ceil(args0_len/batch_size)):
217
+ next_batch_idx = dataset_idx + batch_size
218
+ # batch size is bigger than available audio/text items
219
+ if next_batch_idx >= args0_len:
220
+ inputs[0] = input_tmp[dataset_idx:]
221
+ return func(*tuple(inputs))
222
+ else:
223
+ inputs[0] = input_tmp[dataset_idx:next_batch_idx]
224
+ yield func(*tuple(inputs))
225
+ dataset_idx = next_batch_idx
226
+
227
+ def get_audio_embeddings_per_batch(self, audio_files, batch_size):
228
+ r"""Load preprocessed audio and return audio embeddings per batch"""
229
+ return self._generic_batch_inference(self.get_audio_embeddings, audio_files, batch_size)
230
+
231
+ def get_text_embeddings_per_batch(self, class_labels, batch_size):
232
+ r"""Load preprocessed text and return text embeddings per batch"""
233
+ return self._generic_batch_inference(self.get_text_embeddings, class_labels, batch_size)
234
+
235
+ def classify_audio_files_per_batch(self, audio_files, class_labels, batch_size):
236
+ r"""Compute classification probabilities for each audio recording in a batch and each class label"""
237
+ return self._generic_batch_inference(self.classify_audio_files, audio_files, class_labels, batch_size)
238
+
239
+ if __name__ == '__main__':
240
+
241
+ # Load and initialize CLAP
242
+ weights_path = "/home1/huangrongjie/Project/Diffusion/LatentDiffusion/CLAP/CLAP_weights_2022.pth"
243
+ clap_model = CLAPWrapper(weights_path, device='cpu')
244
+
245
+ y = ["A woman talks nearby as water pours", "Multiple clanging and clanking sounds"]
246
+ x = ['/home2/huangjiawei/data/audiocaps/train/Yr1nicOVtvkQ.wav', '/home2/huangjiawei/data/audiocaps/train/YUDGBjjwyaqE.wav']
247
+
248
+ # Computing text embeddings
249
+ text_embeddings = clap_model.get_text_embeddings(y)
250
+
251
+ import ipdb
252
+ ipdb.set_trace()
253
+
254
+ # Computing audio embeddings
255
+ audio_embeddings = clap_model.get_audio_embeddings(x, resample=True)
256
+ similarity = clap_model.compute_similarity(audio_embeddings, text_embeddings)
257
+
ldm/modules/encoders/CLAP/__init__.py ADDED
@@ -0,0 +1,3 @@
1
+ from . import clap
2
+ from . import audio
3
+ from . import utils
ldm/modules/encoders/CLAP/audio.py ADDED
@@ -0,0 +1,179 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from torchlibrosa.stft import Spectrogram, LogmelFilterBank
5
+
6
+ def get_audio_encoder(name: str):
7
+ if name == "Cnn14":
8
+ return Cnn14
9
+ else:
10
+ raise Exception('The audio encoder name {} is incorrect or not supported'.format(name))
11
+
12
+
13
+ class ConvBlock(nn.Module):
14
+ def __init__(self, in_channels, out_channels):
15
+
16
+ super(ConvBlock, self).__init__()
17
+
18
+ self.conv1 = nn.Conv2d(in_channels=in_channels,
19
+ out_channels=out_channels,
20
+ kernel_size=(3, 3), stride=(1, 1),
21
+ padding=(1, 1), bias=False)
22
+
23
+ self.conv2 = nn.Conv2d(in_channels=out_channels,
24
+ out_channels=out_channels,
25
+ kernel_size=(3, 3), stride=(1, 1),
26
+ padding=(1, 1), bias=False)
27
+
28
+ self.bn1 = nn.BatchNorm2d(out_channels)
29
+ self.bn2 = nn.BatchNorm2d(out_channels)
30
+
31
+
32
+ def forward(self, input, pool_size=(2, 2), pool_type='avg'):
33
+
34
+ x = input
35
+ x = F.relu_(self.bn1(self.conv1(x)))
36
+ x = F.relu_(self.bn2(self.conv2(x)))
37
+ if pool_type == 'max':
38
+ x = F.max_pool2d(x, kernel_size=pool_size)
39
+ elif pool_type == 'avg':
40
+ x = F.avg_pool2d(x, kernel_size=pool_size)
41
+ elif pool_type == 'avg+max':
42
+ x1 = F.avg_pool2d(x, kernel_size=pool_size)
43
+ x2 = F.max_pool2d(x, kernel_size=pool_size)
44
+ x = x1 + x2
45
+ else:
46
+ raise Exception('Incorrect argument!')
47
+
48
+ return x
49
+
50
+
51
+ class ConvBlock5x5(nn.Module):
52
+ def __init__(self, in_channels, out_channels):
53
+
54
+ super(ConvBlock5x5, self).__init__()
55
+
56
+ self.conv1 = nn.Conv2d(in_channels=in_channels,
57
+ out_channels=out_channels,
58
+ kernel_size=(5, 5), stride=(1, 1),
59
+ padding=(2, 2), bias=False)
60
+
61
+ self.bn1 = nn.BatchNorm2d(out_channels)
62
+
63
+
64
+ def forward(self, input, pool_size=(2, 2), pool_type='avg'):
65
+
66
+ x = input
67
+ x = F.relu_(self.bn1(self.conv1(x)))
68
+ if pool_type == 'max':
69
+ x = F.max_pool2d(x, kernel_size=pool_size)
70
+ elif pool_type == 'avg':
71
+ x = F.avg_pool2d(x, kernel_size=pool_size)
72
+ elif pool_type == 'avg+max':
73
+ x1 = F.avg_pool2d(x, kernel_size=pool_size)
74
+ x2 = F.max_pool2d(x, kernel_size=pool_size)
75
+ x = x1 + x2
76
+ else:
77
+ raise Exception('Incorrect argument!')
78
+
79
+ return x
80
+
81
+
82
+ class AttBlock(nn.Module):
83
+ def __init__(self, n_in, n_out, activation='linear', temperature=1.):
84
+ super(AttBlock, self).__init__()
85
+
86
+ self.activation = activation
87
+ self.temperature = temperature
88
+ self.att = nn.Conv1d(in_channels=n_in, out_channels=n_out, kernel_size=1, stride=1, padding=0, bias=True)
89
+ self.cla = nn.Conv1d(in_channels=n_in, out_channels=n_out, kernel_size=1, stride=1, padding=0, bias=True)
90
+
91
+ self.bn_att = nn.BatchNorm1d(n_out)
92
+
93
+ def forward(self, x):
94
+ # x: (n_samples, n_in, n_time)
95
+ norm_att = torch.softmax(torch.clamp(self.att(x), -10, 10), dim=-1)
96
+ cla = self.nonlinear_transform(self.cla(x))
97
+ x = torch.sum(norm_att * cla, dim=2)
98
+ return x, norm_att, cla
99
+
100
+ def nonlinear_transform(self, x):
101
+ if self.activation == 'linear':
102
+ return x
103
+ elif self.activation == 'sigmoid':
104
+ return torch.sigmoid(x)
105
+
106
+
107
+ class Cnn14(nn.Module):
108
+ def __init__(self, sample_rate, window_size, hop_size, mel_bins, fmin,
109
+ fmax, classes_num, out_emb):
110
+
111
+ super(Cnn14, self).__init__()
112
+
113
+ window = 'hann'
114
+ center = True
115
+ pad_mode = 'reflect'
116
+ ref = 1.0
117
+ amin = 1e-10
118
+ top_db = None
119
+
120
+ # Spectrogram extractor
121
+ self.spectrogram_extractor = Spectrogram(n_fft=window_size, hop_length=hop_size,
122
+ win_length=window_size, window=window, center=center, pad_mode=pad_mode,
123
+ freeze_parameters=True)
124
+
125
+ # Logmel feature extractor
126
+ self.logmel_extractor = LogmelFilterBank(sr=sample_rate, n_fft=window_size,
127
+ n_mels=mel_bins, fmin=fmin, fmax=fmax, ref=ref, amin=amin, top_db=top_db,
128
+ freeze_parameters=True)
129
+
130
+ self.bn0 = nn.BatchNorm2d(64)
131
+
132
+ self.conv_block1 = ConvBlock(in_channels=1, out_channels=64)
133
+ self.conv_block2 = ConvBlock(in_channels=64, out_channels=128)
134
+ self.conv_block3 = ConvBlock(in_channels=128, out_channels=256)
135
+ self.conv_block4 = ConvBlock(in_channels=256, out_channels=512)
136
+ self.conv_block5 = ConvBlock(in_channels=512, out_channels=1024)
137
+ self.conv_block6 = ConvBlock(in_channels=1024, out_channels=2048)
138
+
139
+ # out_emb is 2048 for best Cnn14
140
+ self.fc1 = nn.Linear(2048, out_emb, bias=True)
141
+ self.fc_audioset = nn.Linear(out_emb, classes_num, bias=True)
142
+
143
+ def forward(self, input, mixup_lambda=None):
144
+ """
145
+ Input: (batch_size, data_length)
146
+ """
147
+
148
+ x = self.spectrogram_extractor(input) # (batch_size, 1, time_steps, freq_bins)
149
+ x = self.logmel_extractor(x) # (batch_size, 1, time_steps, mel_bins)
150
+
151
+ x = x.transpose(1, 3)
152
+ x = self.bn0(x)
153
+ x = x.transpose(1, 3)
154
+
155
+ x = self.conv_block1(x, pool_size=(2, 2), pool_type='avg')
156
+ x = F.dropout(x, p=0.2, training=self.training)
157
+ x = self.conv_block2(x, pool_size=(2, 2), pool_type='avg')
158
+ x = F.dropout(x, p=0.2, training=self.training)
159
+ x = self.conv_block3(x, pool_size=(2, 2), pool_type='avg')
160
+ x = F.dropout(x, p=0.2, training=self.training)
161
+ x = self.conv_block4(x, pool_size=(2, 2), pool_type='avg')
162
+ x = F.dropout(x, p=0.2, training=self.training)
163
+ x = self.conv_block5(x, pool_size=(2, 2), pool_type='avg')
164
+ x = F.dropout(x, p=0.2, training=self.training)
165
+ x = self.conv_block6(x, pool_size=(1, 1), pool_type='avg')
166
+ x = F.dropout(x, p=0.2, training=self.training)
167
+ x = torch.mean(x, dim=3)
168
+
169
+ (x1, _) = torch.max(x, dim=2)
170
+ x2 = torch.mean(x, dim=2)
171
+ x = x1 + x2
172
+ x = F.dropout(x, p=0.5, training=self.training)
173
+ x = F.relu_(self.fc1(x))
174
+ embedding = F.dropout(x, p=0.5, training=self.training)
175
+ clipwise_output = torch.sigmoid(self.fc_audioset(x))
176
+
177
+ output_dict = {'clipwise_output': clipwise_output, 'embedding': embedding}
178
+
179
+ return output_dict
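
A minimal usage sketch for the Cnn14 encoder above, using the hyperparameters from the CLAP config.yml later in this diff (44.1 kHz, 1024-sample window, 320-sample hop, 64 mel bins, 527 classes, 2048-d embedding); the random waveform is only a placeholder.

import torch

# Illustrative forward pass: two clips of 5 s of 44.1 kHz audio.
encoder = Cnn14(sample_rate=44100, window_size=1024, hop_size=320, mel_bins=64,
                fmin=50, fmax=14000, classes_num=527, out_emb=2048)
waveform = torch.randn(2, 44100 * 5)            # (batch_size, data_length)
out = encoder(waveform)
print(out['embedding'].shape)                   # torch.Size([2, 2048])
print(out['clipwise_output'].shape)             # torch.Size([2, 527])
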
ldm/modules/encoders/CLAP/clap.py ADDED
@@ -0,0 +1,89 @@
1
+ import numpy as np
2
+ import torch
3
+ import torch.nn.functional as F
4
+ from torch import nn
5
+ from transformers import AutoModel
6
+ from .audio import get_audio_encoder
7
+
8
+ class Projection(nn.Module):
9
+ def __init__(self, d_in: int, d_out: int, p: float=0.5) -> None:
10
+ super().__init__()
11
+ self.linear1 = nn.Linear(d_in, d_out, bias=False)
12
+ self.linear2 = nn.Linear(d_out, d_out, bias=False)
13
+ self.layer_norm = nn.LayerNorm(d_out)
14
+ self.drop = nn.Dropout(p)
15
+
16
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
17
+ embed1 = self.linear1(x)
18
+ embed2 = self.drop(self.linear2(F.gelu(embed1)))
19
+ embeds = self.layer_norm(embed1 + embed2)
20
+ return embeds
21
+
22
+ class AudioEncoder(nn.Module):
23
+ def __init__(self, audioenc_name:str, d_in: int, d_out: int, sample_rate: int, window_size: int,
24
+ hop_size: int, mel_bins: int, fmin: int, fmax: int, classes_num: int) -> None:
25
+ super().__init__()
26
+
27
+ audio_encoder = get_audio_encoder(audioenc_name)
28
+
29
+ self.base = audio_encoder(
30
+ sample_rate, window_size,
31
+ hop_size, mel_bins, fmin, fmax,
32
+ classes_num, d_in)
33
+
34
+ self.projection = Projection(d_in, d_out)
35
+
36
+ def forward(self, x):
37
+ out_dict = self.base(x)
38
+ audio_features, audio_classification_output = out_dict['embedding'], out_dict['clipwise_output']
39
+ projected_vec = self.projection(audio_features)
40
+ return projected_vec, audio_classification_output
41
+
42
+ class TextEncoder(nn.Module):
43
+ def __init__(self, d_out: int, text_model: str, transformer_embed_dim: int) -> None:
44
+ super().__init__()
45
+ self.base = AutoModel.from_pretrained(text_model)
46
+ self.projection = Projection(transformer_embed_dim, d_out)
47
+
48
+ def forward(self, x):
49
+ out = self.base(**x)[0]
50
+ out = out[:, 0, :] # get CLS token output
51
+ projected_vec = self.projection(out)
52
+ return projected_vec
53
+
54
+ class CLAP(nn.Module):
55
+ def __init__(self,
56
+ # audio
57
+ audioenc_name: str,
58
+ sample_rate: int,
59
+ window_size: int,
60
+ hop_size: int,
61
+ mel_bins: int,
62
+ fmin: int,
63
+ fmax: int,
64
+ classes_num: int,
65
+ out_emb: int,
66
+ # text
67
+ text_model: str,
68
+ transformer_embed_dim: int,
69
+ # common
70
+ d_proj: int,
71
+ ):
72
+ super().__init__()
73
+
74
+
75
+ self.audio_encoder = AudioEncoder(
76
+ audioenc_name, out_emb, d_proj,
77
+ sample_rate, window_size, hop_size, mel_bins, fmin, fmax, classes_num)
78
+
79
+ self.caption_encoder = TextEncoder(
80
+ d_proj, text_model, transformer_embed_dim
81
+ )
82
+
83
+ self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
84
+
85
+ def forward(self, audio, text):
86
+ audio_embed, _ = self.audio_encoder(audio)
87
+ caption_embed = self.caption_encoder(text)
88
+
89
+ return caption_embed, audio_embed, self.logit_scale.exp()
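
A sketch of wiring the pieces above into a full CLAP model, again taking the hyperparameter values from config.yml below; the waveforms and captions are placeholders and the model is randomly initialized here.

import torch
from transformers import AutoTokenizer

model = CLAP(audioenc_name='Cnn14', sample_rate=44100, window_size=1024, hop_size=320,
             mel_bins=64, fmin=50, fmax=14000, classes_num=527, out_emb=2048,
             text_model='bert-base-uncased', transformer_embed_dim=768, d_proj=1024)
tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')

audio = torch.randn(2, 44100 * 5)                       # placeholder waveforms
text = tokenizer(['a dog barking', 'rain on a window'], padding=True, return_tensors='pt')
caption_embed, audio_embed, scale = model(audio, text)  # each embedding is (2, d_proj)
logits = scale * caption_embed @ audio_embed.t()        # (2, 2) caption-vs-audio similarity
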
ldm/modules/encoders/CLAP/config.yml ADDED
@@ -0,0 +1,26 @@
1
+ # TEXT ENCODER CONFIG
2
+ text_model: 'bert-base-uncased'
3
+ text_len: 100
4
+ transformer_embed_dim: 768
5
+ freeze_text_encoder_weights: True
6
+
7
+ # AUDIO ENCODER CONFIG
8
+ audioenc_name: 'Cnn14'
9
+ out_emb: 2048
10
+ sampling_rate: 44100
11
+ duration: 5
12
+ fmin: 50
13
+ fmax: 14000
14
+ n_fft: 1028
15
+ hop_size: 320
16
+ mel_bins: 64
17
+ window_size: 1024
18
+
19
+ # PROJECTION SPACE CONFIG
20
+ d_proj: 1024
21
+ temperature: 0.003
22
+
23
+ # TRAINING AND EVALUATION CONFIG
24
+ num_classes: 527
25
+ batch_size: 1024
26
+ demo: False
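
A quick sanity check of what these settings imply for input sizes, derived only from the values above (the exact frame count also depends on STFT centering):

sampling_rate = 44100
duration = 5
hop_size = 320
num_samples = sampling_rate * duration      # 220500 samples per 5 s clip
num_frames = num_samples // hop_size + 1    # roughly 690 log-mel frames per clip
print(num_samples, num_frames)
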
ldm/modules/encoders/CLAP/utils.py ADDED
@@ -0,0 +1,26 @@
1
+ import argparse
2
+ import yaml
3
+ import sys
4
+
5
+ def read_config_as_args(config_path, args=None, is_config_str=False):
6
+ return_dict = {}
7
+
8
+ if config_path is not None:
9
+ if is_config_str:
10
+ yml_config = yaml.load(config_path, Loader=yaml.FullLoader)
11
+ else:
12
+ with open(config_path, "r") as f:
13
+ yml_config = yaml.load(f, Loader=yaml.FullLoader)
14
+
15
+ if args != None:
16
+ for k, v in yml_config.items():
17
+ if k in args.__dict__:
18
+ args.__dict__[k] = v
19
+ else:
20
+ sys.stderr.write("Ignored unknown parameter {} in yaml.\n".format(k))
21
+ else:
22
+ for k, v in yml_config.items():
23
+ return_dict[k] = v
24
+
25
+ args = args if args != None else return_dict
26
+ return argparse.Namespace(**args)
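
A sketch of the two call patterns read_config_as_args supports; the string variant mirrors how FrozenCLAPEmbedder in modules.py below loads config.yml, and the path variant assumes the repository root as the working directory.

from importlib_resources import files

# 1) YAML passed as a string (is_config_str=True), as done in modules.py below.
cfg_str = files('ldm').joinpath('modules/encoders/CLAP/config.yml').read_text()
args = read_config_as_args(cfg_str, is_config_str=True)
print(args.text_model, args.d_proj)         # 'bert-base-uncased' 1024

# 2) YAML passed as a file path.
args = read_config_as_args('ldm/modules/encoders/CLAP/config.yml')
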
ldm/modules/encoders/__init__.py ADDED
File without changes
ldm/modules/encoders/modules.py ADDED
@@ -0,0 +1,314 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ from functools import partial
4
+
5
+ from ldm.modules.x_transformer import Encoder, TransformerWrapper # TODO: can we directly rely on lucidrains code and simply add this as a requirement? --> test
6
+ from torch.utils.checkpoint import checkpoint
7
+ from transformers import T5Tokenizer, T5EncoderModel, CLIPTokenizer, CLIPTextModel, AutoTokenizer
8
+ from importlib_resources import files
9
+ from ldm.modules.encoders.CLAP.utils import read_config_as_args
10
+ from ldm.modules.encoders.CLAP.clap import TextEncoder
11
+ from ldm.util import default, count_params
12
+
13
+
14
+ class AbstractEncoder(nn.Module):
15
+ def __init__(self):
16
+ super().__init__()
17
+
18
+ def encode(self, *args, **kwargs):
19
+ raise NotImplementedError
20
+
21
+
22
+ class ClassEmbedder(nn.Module):
23
+ def __init__(self, embed_dim, n_classes=1000, key='class'):
24
+ super().__init__()
25
+ self.key = key
26
+ self.embedding = nn.Embedding(n_classes, embed_dim)
27
+
28
+ def forward(self, batch, key=None):
29
+ if key is None:
30
+ key = self.key
31
+ # this is for use in crossattn
32
+ c = batch[key][:, None]# (bsz,1)
33
+ c = self.embedding(c)
34
+ return c
35
+
36
+
37
+ class TransformerEmbedder(AbstractEncoder):
38
+ """Some transformer encoder layers"""
39
+ def __init__(self, n_embed, n_layer, vocab_size, max_seq_len=77, device="cuda"):
40
+ super().__init__()
41
+ self.device = device
42
+ self.transformer = TransformerWrapper(num_tokens=vocab_size, max_seq_len=max_seq_len,
43
+ attn_layers=Encoder(dim=n_embed, depth=n_layer))
44
+
45
+ def forward(self, tokens):
46
+ tokens = tokens.to(self.device) # meh
47
+ z = self.transformer(tokens, return_embeddings=True)
48
+ return z
49
+
50
+ def encode(self, x):
51
+ return self(x)
52
+
53
+
54
+ class BERTTokenizer(AbstractEncoder):
55
+ """ Uses a pretrained BERT tokenizer by huggingface. Vocab size: 30522 (?)"""
56
+ def __init__(self, device="cuda", vq_interface=True, max_length=77):
57
+ super().__init__()
58
+ from transformers import BertTokenizerFast # TODO: add to requirements
59
+ self.tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")
60
+ self.device = device
61
+ self.vq_interface = vq_interface
62
+ self.max_length = max_length
63
+
64
+ def forward(self, text):
65
+ batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True,
66
+ return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
67
+ tokens = batch_encoding["input_ids"].to(self.device)
68
+ return tokens
69
+
70
+ @torch.no_grad()
71
+ def encode(self, text):
72
+ tokens = self(text)
73
+ if not self.vq_interface:
74
+ return tokens
75
+ return None, None, [None, None, tokens]
76
+
77
+ def decode(self, text):
78
+ return text
79
+
80
+
81
+ class BERTEmbedder(AbstractEncoder):  # note: this does not use a pretrained BERT model; it uses the transformers BertTokenizer plus a custom TransformerWrapper
82
+ """Uses the BERT tokenizr model and add some transformer encoder layers"""
83
+ def __init__(self, n_embed, n_layer, vocab_size=30522, max_seq_len=77,
84
+ device="cuda",use_tokenizer=True, embedding_dropout=0.0):
85
+ super().__init__()
86
+ self.use_tknz_fn = use_tokenizer
87
+ if self.use_tknz_fn:
88
+ self.tknz_fn = BERTTokenizer(vq_interface=False, max_length=max_seq_len)
89
+ self.device = device
90
+ self.transformer = TransformerWrapper(num_tokens=vocab_size, max_seq_len=max_seq_len,
91
+ attn_layers=Encoder(dim=n_embed, depth=n_layer),
92
+ emb_dropout=embedding_dropout)
93
+
94
+ def forward(self, text):
95
+ if self.use_tknz_fn:
96
+ tokens = self.tknz_fn(text)#.to(self.device)
97
+ else:
98
+ tokens = text
99
+ z = self.transformer(tokens, return_embeddings=True)
100
+ return z
101
+
102
+ def encode(self, text):
103
+ # output of length 77
104
+ return self(text)
105
+
106
+
107
+ class SpatialRescaler(nn.Module):
108
+ def __init__(self,
109
+ n_stages=1,
110
+ method='bilinear',
111
+ multiplier=0.5,
112
+ in_channels=3,
113
+ out_channels=None,
114
+ bias=False):
115
+ super().__init__()
116
+ self.n_stages = n_stages
117
+ assert self.n_stages >= 0
118
+ assert method in ['nearest','linear','bilinear','trilinear','bicubic','area']
119
+ self.multiplier = multiplier
120
+ self.interpolator = partial(torch.nn.functional.interpolate, mode=method)
121
+ self.remap_output = out_channels is not None
122
+ if self.remap_output:
123
+ print(f'Spatial Rescaler mapping from {in_channels} to {out_channels} channels after resizing.')
124
+ self.channel_mapper = nn.Conv2d(in_channels,out_channels,1,bias=bias)
125
+
126
+ def forward(self,x):
127
+ for stage in range(self.n_stages):
128
+ x = self.interpolator(x, scale_factor=self.multiplier)
129
+
130
+
131
+ if self.remap_output:
132
+ x = self.channel_mapper(x)
133
+ return x
134
+
135
+ def encode(self, x):
136
+ return self(x)
137
+
138
+ def disabled_train(self, mode=True):
139
+ """Overwrite model.train with this function to make sure train/eval mode
140
+ does not change anymore."""
141
+ return self
142
+
143
+ class FrozenT5Embedder(AbstractEncoder):
144
+ """Uses the T5 transformer encoder for text"""
145
+ def __init__(self, version="google/t5-v1_1-large", device="cuda", max_length=77, freeze=True): # others are google/t5-v1_1-xl and google/t5-v1_1-xxl
146
+ super().__init__()
147
+ self.tokenizer = T5Tokenizer.from_pretrained(version)
148
+ self.transformer = T5EncoderModel.from_pretrained(version)
149
+ self.device = device
150
+ self.max_length = max_length # TODO: typical value?
151
+ if freeze:
152
+ self.freeze()
153
+
154
+ def freeze(self):
155
+ self.transformer = self.transformer.eval()
156
+ #self.train = disabled_train
157
+ for param in self.parameters():
158
+ param.requires_grad = False
159
+
160
+ def forward(self, text):
161
+ batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True,
162
+ return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
163
+ tokens = batch_encoding["input_ids"].to(self.device)
164
+ outputs = self.transformer(input_ids=tokens)
165
+
166
+ z = outputs.last_hidden_state
167
+ return z
168
+
169
+ def encode(self, text):
170
+ return self(text)
171
+
172
+
173
+ class FrozenCLAPEmbedder(AbstractEncoder):
174
+ """Uses the CLAP transformer encoder for text (from huggingface)"""
175
+ def __init__(self, weights_path, freeze=True, device="cuda", max_length=77): # clip-vit-base-patch32
176
+ super().__init__()
177
+
178
+ model_state_dict = torch.load(weights_path, map_location=torch.device('cpu'))['model']
179
+ match_params = dict()
180
+ for key in list(model_state_dict.keys()):
181
+ if 'caption_encoder' in key:
182
+ match_params[key.replace('caption_encoder.', '')] = model_state_dict[key]
183
+
184
+ config_as_str = files('ldm').joinpath('modules/encoders/CLAP/config.yml').read_text()
185
+ args = read_config_as_args(config_as_str, is_config_str=True)
186
+
187
+ # To device
188
+ self.tokenizer = AutoTokenizer.from_pretrained(args.text_model) # args.text_model
189
+ self.caption_encoder = TextEncoder(
190
+ args.d_proj, args.text_model, args.transformer_embed_dim
191
+ )
192
+
193
+ self.max_length = max_length
194
+ self.device = device
195
+ if freeze: self.freeze()
196
+
197
+ print(f"{self.caption_encoder.__class__.__name__} comes with {count_params(self.caption_encoder) * 1.e-6:.2f} M params.")
198
+
199
+ def freeze(self):
200
+ self.caption_encoder.base = self.caption_encoder.base.eval()
201
+ for param in self.caption_encoder.base.parameters():
202
+ param.requires_grad = False
203
+
204
+
205
+ def encode(self, text):
206
+ batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True,
207
+ return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
208
+ tokens = batch_encoding["input_ids"].to(self.device)
209
+
210
+ outputs = self.caption_encoder.base(input_ids=tokens)
211
+ z = self.caption_encoder.projection(outputs.last_hidden_state)
212
+ return z
213
+
214
+ class FrozenCLAPEmbedderNoLoad(AbstractEncoder):
215
+ def __init__(self, config, freeze=True, device="cpu", max_length=77):
216
+ super().__init__()
217
+ args = config
218
+
219
+ # To device
220
+ self.tokenizer = AutoTokenizer.from_pretrained(args.text_model) # args.text_model
221
+ self.caption_encoder = TextEncoder(
222
+ args.d_proj, args.text_model, args.transformer_embed_dim
223
+ )
224
+
225
+ self.max_length = max_length
226
+ self.device = device
227
+ if freeze: self.freeze()
228
+
229
+ print(f"{self.caption_encoder.__class__.__name__} comes with {count_params(self.caption_encoder) * 1.e-6:.2f} M params.")
230
+
231
+ def freeze(self):
232
+ self.caption_encoder.base = self.caption_encoder.base.eval()
233
+ for param in self.caption_encoder.base.parameters():
234
+ param.requires_grad = False
235
+
236
+
237
+ def encode(self, text):
238
+ batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True,
239
+ return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
240
+ tokens = batch_encoding["input_ids"].to(self.device)
241
+
242
+ outputs = self.caption_encoder.base(input_ids=tokens)
243
+ z = self.caption_encoder.projection(outputs.last_hidden_state)
244
+ return z
245
+
246
+
247
+ class NewFrozenCLAPEmbedder(AbstractEncoder):
248
+ """Uses the CLAP transformer encoder for text (from huggingface)"""
249
+ def __init__(self, weights_path, freeze=True, device="cuda", max_length=77): # clip-vit-base-patch32
250
+ super().__init__()
251
+ # To device
252
+ from transformers import RobertaTokenizer
253
+ from ldm.modules.encoders.open_clap import create_model
254
+
255
+
256
+ model, model_cfg = create_model(
257
+ 'HTSAT-tiny',
258
+ 'roberta',
259
+ weights_path,
260
+ enable_fusion=True,
261
+ fusion_type='aff_2d'
262
+ )
263
+
264
+ del model.audio_branch, model.audio_transform, model.audio_projection
265
+ self.tokenizer = RobertaTokenizer.from_pretrained('roberta-base')
266
+ self.model = model
267
+
268
+ self.max_length = max_length
269
+ self.device = device
270
+ if freeze: self.freeze()
271
+
272
+ param_num = sum(p.numel() for p in model.parameters() if p.requires_grad)
273
+ print(f'{self.model.__class__.__name__} comes with: {param_num / 1e+6:.3f} M params.')
274
+
275
+ def freeze(self):
276
+ self.model = self.model.eval()
277
+ for param in self.model.parameters():
278
+ param.requires_grad = False
279
+
280
+ def encode(self, text):
281
+ batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True,
282
+ return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
283
+ outputs = self.model.text_branch(input_ids=batch_encoding["input_ids"].to(self.device), attention_mask=batch_encoding["attention_mask"].to(self.device))
284
+ z = self.model.text_projection(outputs.last_hidden_state)
285
+ return z
286
+
287
+ class FrozenFLANEmbedder(AbstractEncoder):
288
+ """Uses the T5 transformer encoder for text"""
289
+ def __init__(self, version="google/flan-t5-large", device="cuda", max_length=77, freeze=True): # others are google/t5-v1_1-xl and google/t5-v1_1-xxl
290
+ super().__init__()
291
+ self.tokenizer = T5Tokenizer.from_pretrained(version)
292
+ self.transformer = T5EncoderModel.from_pretrained(version)
293
+ self.device = device
294
+ self.max_length = max_length # TODO: typical value?
295
+ if freeze:
296
+ self.freeze()
297
+
298
+ def freeze(self):
299
+ self.transformer = self.transformer.eval()
300
+ #self.train = disabled_train
301
+ for param in self.parameters():
302
+ param.requires_grad = False
303
+
304
+ def forward(self, text):
305
+ batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True,
306
+ return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
307
+ tokens = batch_encoding["input_ids"].to(self.device)
308
+ outputs = self.transformer(input_ids=tokens)
309
+
310
+ z = outputs.last_hidden_state
311
+ return z
312
+
313
+ def encode(self, text):
314
+ return self(text)
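
A sketch of how one of the frozen text encoders above is used to produce cross-attention conditioning; FrozenFLANEmbedder is shown because it only needs the HuggingFace model id already given as its default (the hidden size of flan-t5-large should be 1024).

import torch

embedder = FrozenFLANEmbedder(version='google/flan-t5-large', device='cpu', max_length=77)
with torch.no_grad():
    cond = embedder.encode(['a dog barking in the distance'])  # tokenize, then T5 encoder
print(cond.shape)   # expected (1, 77, 1024): one hidden state per padded token position
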
ldm/modules/encoders/open_clap/__init__.py ADDED
@@ -0,0 +1,8 @@
1
+ from .factory import list_models, create_model, create_model_and_transforms, add_model_config
2
+ from .loss import ClipLoss, gather_features, LPLoss, lp_gather_features, LPMetrics
3
+ from .model import CLAP, CLAPTextCfg, CLAPVisionCfg, CLAPAudioCfp, convert_weights_to_fp16, trace_model
4
+ from .openai import load_openai_model, list_openai_models
5
+ from .pretrained import list_pretrained, list_pretrained_tag_models, list_pretrained_model_tags,\
6
+ get_pretrained_url, download_pretrained
7
+ from .tokenizer import SimpleTokenizer, tokenize
8
+ from .transform import image_transform
ldm/modules/encoders/open_clap/bert.py ADDED
@@ -0,0 +1,32 @@
1
+ from transformers import BertTokenizer, BertModel
2
+ tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
3
+ model = BertModel.from_pretrained("bert-base-uncased")
4
+ text = "Replace me by any text you'd like."
5
+
6
+ def bert_embeddings(text):
7
+ # text = "Replace me by any text you'd like."
8
+ encoded_input = tokenizer(text, return_tensors='pt')
9
+ output = model(**encoded_input)
10
+ return output
11
+
12
+ from transformers import RobertaTokenizer, RobertaModel
13
+
14
+ tokenizer = RobertaTokenizer.from_pretrained('roberta-base')
15
+ model = RobertaModel.from_pretrained('roberta-base')
16
+ text = "Replace me by any text you'd like."
17
+ def Roberta_embeddings(text):
18
+ # text = "Replace me by any text you'd like."
19
+ encoded_input = tokenizer(text, return_tensors='pt')
20
+ output = model(**encoded_input)
21
+ return output
22
+
23
+ from transformers import BartTokenizer, BartModel
24
+
25
+ tokenizer = BartTokenizer.from_pretrained('facebook/bart-base')
26
+ model = BartModel.from_pretrained('facebook/bart-base')
27
+ text = "Replace me by any text you'd like."
28
+ def bart_embeddings(text):
29
+ # text = "Replace me by any text you'd like."
30
+ encoded_input = tokenizer(text, return_tensors='pt')
31
+ output = model(**encoded_input)
32
+ return output
ldm/modules/encoders/open_clap/bpe_simple_vocab_16e6.txt.gz ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:924691ac288e54409236115652ad4aa250f48203de50a9e4722a6ecd48d6804a
3
+ size 1356917
ldm/modules/encoders/open_clap/factory.py ADDED
@@ -0,0 +1,257 @@
1
+ import json
2
+ import logging
3
+ import os
4
+ import pathlib
5
+ import re
6
+ from copy import deepcopy
7
+ from pathlib import Path
8
+
9
+ import torch
10
+
11
+ from .model import CLAP, convert_weights_to_fp16
12
+ from .openai import load_openai_model
13
+ from .pretrained import get_pretrained_url, download_pretrained
14
+ from .transform import image_transform
15
+
16
+ _MODEL_CONFIG_PATHS = [Path(__file__).parent / f"model_configs/"]
17
+ _MODEL_CONFIGS = {} # dictionary (model_name: config) of model architecture configs
18
+
19
+
20
+ def _natural_key(string_):
21
+ return [int(s) if s.isdigit() else s for s in re.split(r"(\d+)", string_.lower())]
22
+
23
+
24
+ def _rescan_model_configs():
25
+ global _MODEL_CONFIGS
26
+
27
+ config_ext = (".json",)
28
+ config_files = []
29
+ for config_path in _MODEL_CONFIG_PATHS:
30
+ if config_path.is_file() and config_path.suffix in config_ext:
31
+ config_files.append(config_path)
32
+ elif config_path.is_dir():
33
+ for ext in config_ext:
34
+ config_files.extend(config_path.glob(f"*{ext}"))
35
+
36
+ for cf in config_files:
37
+ with open(cf, "r") as f:
38
+ model_cfg = json.load(f)
39
+ if all(a in model_cfg for a in ("embed_dim", "audio_cfg", "text_cfg")):
40
+ _MODEL_CONFIGS[cf.stem] = model_cfg
41
+
42
+ _MODEL_CONFIGS = {
43
+ k: v
44
+ for k, v in sorted(_MODEL_CONFIGS.items(), key=lambda x: _natural_key(x[0]))
45
+ }
46
+
47
+
48
+ _rescan_model_configs() # initial populate of model config registry
49
+
50
+
51
+ def load_state_dict(checkpoint_path: str, map_location="cpu", skip_params=True):
52
+ checkpoint = torch.load(checkpoint_path, map_location=map_location)
53
+ if isinstance(checkpoint, dict) and "state_dict" in checkpoint:
54
+ state_dict = checkpoint["state_dict"]
55
+ else:
56
+ state_dict = checkpoint
57
+ if skip_params:
58
+ if next(iter(state_dict.items()))[0].startswith("module"):
59
+ state_dict = {k[7:]: v for k, v in state_dict.items()}
60
+ # for k in state_dict:
61
+ # if k.startswith('transformer'):
62
+ # v = state_dict.pop(k)
63
+ # state_dict['text_branch.' + k[12:]] = v
64
+ return state_dict
65
+
66
+
67
+ def create_model(
68
+ amodel_name: str,
69
+ tmodel_name: str,
70
+ pretrained: str = "",
71
+ precision: str = "fp32",
72
+ device: torch.device = torch.device("cpu"),
73
+ jit: bool = False,
74
+ force_quick_gelu: bool = False,
75
+ openai_model_cache_dir: str = os.path.expanduser("~/.cache/clip"),
76
+ skip_params=True,
77
+ pretrained_audio: str = "",
78
+ pretrained_text: str = "",
79
+ enable_fusion: bool = False,
80
+ fusion_type: str = 'None'
81
+ # pretrained_image: bool = False,
82
+ ):
83
+ amodel_name = amodel_name.replace(
84
+ "/", "-"
85
+ ) # for callers using old naming with / in ViT names
86
+ pretrained_orig = pretrained
87
+ pretrained = pretrained.lower()
88
+ if pretrained == "openai":
89
+ if amodel_name in _MODEL_CONFIGS:
90
+ logging.info(f"Loading {amodel_name} model config.")
91
+ model_cfg = deepcopy(_MODEL_CONFIGS[amodel_name])
92
+ else:
93
+ logging.error(
94
+ f"Model config for {amodel_name} not found; available models {list_models()}."
95
+ )
96
+ raise RuntimeError(f"Model config for {amodel_name} not found.")
97
+
98
+ logging.info(f"Loading pretrained ViT-B-16 text encoder from OpenAI.")
99
+ # Hard-code the text model name into the config
100
+ model_cfg["text_cfg"]["model_type"] = tmodel_name
101
+ model = load_openai_model(
102
+ "ViT-B-16",
103
+ model_cfg,
104
+ device=device,
105
+ jit=jit,
106
+ cache_dir=openai_model_cache_dir,
107
+ enable_fusion=enable_fusion,
108
+ fusion_type=fusion_type
109
+ )
110
+ # See https://discuss.pytorch.org/t/valueerror-attemting-to-unscale-fp16-gradients/81372
111
+ if precision == "amp" or precision == "fp32":
112
+ model = model.float()
113
+ else:
114
+ if amodel_name in _MODEL_CONFIGS:
115
+ logging.info(f"Loading {amodel_name} model config.")
116
+ model_cfg = deepcopy(_MODEL_CONFIGS[amodel_name])
117
+ else:
118
+ logging.error(
119
+ f"Model config for {amodel_name} not found; available models {list_models()}."
120
+ )
121
+ raise RuntimeError(f"Model config for {amodel_name} not found.")
122
+
123
+ if force_quick_gelu:
124
+ # override for use of QuickGELU on non-OpenAI transformer models
125
+ model_cfg["quick_gelu"] = True
126
+
127
+ # if pretrained_image:
128
+ # if 'timm_amodel_name' in model_cfg.get('vision_cfg', {}):
129
+ # # pretrained weight loading for timm models set via vision_cfg
130
+ # model_cfg['vision_cfg']['timm_model_pretrained'] = True
131
+ # else:
132
+ # assert False, 'pretrained image towers currently only supported for timm models'
133
+ model_cfg["text_cfg"]["model_type"] = tmodel_name
134
+ model_cfg["enable_fusion"] = enable_fusion
135
+ model_cfg["fusion_type"] = fusion_type
136
+ model = CLAP(**model_cfg)
137
+
138
+ if pretrained:
139
+ checkpoint_path = ""
140
+ url = get_pretrained_url(amodel_name, pretrained)
141
+ if url:
142
+ checkpoint_path = download_pretrained(url, root=openai_model_cache_dir)
143
+ elif os.path.exists(pretrained_orig):
144
+ checkpoint_path = pretrained_orig
145
+ if checkpoint_path:
146
+ logging.info(f"Loading pretrained {amodel_name}-{tmodel_name} weights ({pretrained}).")
147
+ ckpt = load_state_dict(checkpoint_path, skip_params=True)
148
+ model.load_state_dict(ckpt)
149
+ param_names = [n for n, p in model.named_parameters()]
150
+ for n in param_names:
151
+ print(n, "\t", "Loaded" if n in ckpt else "Unloaded")
152
+ else:
153
+ logging.warning(
154
+ f"Pretrained weights ({pretrained}) not found for model {amodel_name}."
155
+ )
156
+ raise RuntimeError(
157
+ f"Pretrained weights ({pretrained}) not found for model {amodel_name}."
158
+ )
159
+
160
+ if pretrained_audio:
161
+ if amodel_name.startswith('PANN'):
162
+ if 'Cnn14_mAP' in pretrained_audio: # official checkpoint
163
+ audio_ckpt = torch.load(pretrained_audio, map_location='cpu')
164
+ audio_ckpt = audio_ckpt['model']
165
+ keys = list(audio_ckpt.keys())
166
+ for key in keys:
167
+ if 'spectrogram_extractor' not in key and 'logmel_extractor' not in key:
168
+ v = audio_ckpt.pop(key)
169
+ audio_ckpt['audio_branch.' + key] = v
170
+ elif os.path.basename(pretrained_audio).startswith('PANN'): # checkpoint trained via HTSAT codebase
171
+ audio_ckpt = torch.load(pretrained_audio, map_location='cpu')
172
+ audio_ckpt = audio_ckpt['state_dict']
173
+ keys = list(audio_ckpt.keys())
174
+ for key in keys:
175
+ if key.startswith('sed_model'):
176
+ v = audio_ckpt.pop(key)
177
+ audio_ckpt['audio_branch.' + key[10:]] = v
178
+ elif os.path.basename(pretrained_audio).startswith('finetuned'): # checkpoint trained via linear probe codebase
179
+ audio_ckpt = torch.load(pretrained_audio, map_location='cpu')
180
+ else:
181
+ raise ValueError('Unknown audio checkpoint')
182
+ elif amodel_name.startswith('HTSAT'):
183
+ if 'HTSAT_AudioSet_Saved' in pretrained_audio: # official checkpoint
184
+ audio_ckpt = torch.load(pretrained_audio, map_location='cpu')
185
+ audio_ckpt = audio_ckpt['state_dict']
186
+ keys = list(audio_ckpt.keys())
187
+ for key in keys:
188
+ if key.startswith('sed_model') and ('spectrogram_extractor' not in key
189
+ and 'logmel_extractor' not in key):
190
+ v = audio_ckpt.pop(key)
191
+ audio_ckpt['audio_branch.' + key[10:]] = v
192
+ elif os.path.basename(pretrained_audio).startswith('HTSAT'): # checkpoint trained via HTSAT codebase
193
+ audio_ckpt = torch.load(pretrained_audio, map_location='cpu')
194
+ audio_ckpt = audio_ckpt['state_dict']
195
+ keys = list(audio_ckpt.keys())
196
+ for key in keys:
197
+ if key.startswith('sed_model'):
198
+ v = audio_ckpt.pop(key)
199
+ audio_ckpt['audio_branch.' + key[10:]] = v
200
+ elif os.path.basename(pretrained_audio).startswith('finetuned'): # checkpoint trained via linear probe codebase
201
+ audio_ckpt = torch.load(pretrained_audio, map_location='cpu')
202
+ else:
203
+ raise ValueError('Unknown audio checkpoint')
204
+ else:
205
+ raise ValueError('this pretrained audio encoder checkpoint is not supported')
206
+
207
+ model.load_state_dict(audio_ckpt, strict=False)
208
+ logging.info(f"Loading pretrained {amodel_name} weights ({pretrained_audio}).")
209
+ param_names = [n for n, p in model.named_parameters()]
210
+ for n in param_names:
211
+ print(n, "\t", "Loaded" if n in audio_ckpt else "Unloaded")
212
+
213
+ model.to(device=device)
214
+ if precision == "fp16":
215
+ assert device.type != "cpu"
216
+ convert_weights_to_fp16(model)
217
+
218
+ if jit:
219
+ model = torch.jit.script(model)
220
+
221
+ return model, model_cfg
222
+
223
+
224
+ def create_model_and_transforms(
225
+ model_name: str,
226
+ pretrained: str = "",
227
+ precision: str = "fp32",
228
+ device: torch.device = torch.device("cpu"),
229
+ jit: bool = False,
230
+ force_quick_gelu: bool = False,
231
+ # pretrained_image: bool = False,
232
+ ):
233
+ model = create_model(
234
+ model_name,
235
+ pretrained,
236
+ precision,
237
+ device,
238
+ jit,
239
+ force_quick_gelu=force_quick_gelu,
240
+ # pretrained_image=pretrained_image
241
+ )
242
+ preprocess_train = image_transform(model.visual.image_size, is_train=True)
243
+ preprocess_val = image_transform(model.visual.image_size, is_train=False)
244
+ return model, preprocess_train, preprocess_val
245
+
246
+
247
+ def list_models():
248
+ """enumerate available model architectures based on config files"""
249
+ return list(_MODEL_CONFIGS.keys())
250
+
251
+
252
+ def add_model_config(path):
253
+ """add model config path or file and update registry"""
254
+ if not isinstance(path, Path):
255
+ path = Path(path)
256
+ _MODEL_CONFIG_PATHS.append(path)
257
+ _rescan_model_configs()
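
A sketch mirroring the call NewFrozenCLAPEmbedder in modules.py makes into this factory (HTSAT-tiny audio tower, roberta text tower, 2D attentional feature fusion); 'path/to/CLAP_checkpoint.pt' is a placeholder for a real CLAP checkpoint file.

import torch

print(list_models())     # architectures discovered from the JSON files in model_configs/

model, model_cfg = create_model(
    'HTSAT-tiny',                     # audio tower
    'roberta',                        # text tower
    'path/to/CLAP_checkpoint.pt',     # placeholder checkpoint path
    precision='fp32',
    device=torch.device('cpu'),
    enable_fusion=True,
    fusion_type='aff_2d',
)
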
ldm/modules/encoders/open_clap/feature_fusion.py ADDED
@@ -0,0 +1,193 @@
1
+ '''
2
+ Feature Fusion for Variable-Length Data Processing
3
+ AFF/iAFF are adapted and modified from https://github.com/YimianDai/open-aff/blob/master/aff_pytorch/aff_net/fusion.py
4
+ Following the paper: Yimian Dai et al., Attentional Feature Fusion, IEEE Winter Conference on Applications of Computer Vision (WACV), 2021
5
+ '''
6
+
7
+ import torch
8
+ import torch.nn as nn
9
+
10
+
11
+ class DAF(nn.Module):
12
+ '''
13
+ Direct addition fusion (DirectAddFuse)
14
+ '''
15
+
16
+ def __init__(self):
17
+ super(DAF, self).__init__()
18
+
19
+ def forward(self, x, residual):
20
+ return x + residual
21
+
22
+
23
+ class iAFF(nn.Module):
24
+ '''
25
+ Multi-feature fusion: iAFF (iterative attentional feature fusion)
26
+ '''
27
+
28
+ def __init__(self, channels=64, r=4, type='2D'):
29
+ super(iAFF, self).__init__()
30
+ inter_channels = int(channels // r)
31
+
32
+ if type == '1D':
33
+ # local attention
34
+ self.local_att = nn.Sequential(
35
+ nn.Conv1d(channels, inter_channels, kernel_size=1, stride=1, padding=0),
36
+ nn.BatchNorm1d(inter_channels),
37
+ nn.ReLU(inplace=True),
38
+ nn.Conv1d(inter_channels, channels, kernel_size=1, stride=1, padding=0),
39
+ nn.BatchNorm1d(channels),
40
+ )
41
+
42
+ # global attention
43
+ self.global_att = nn.Sequential(
44
+ nn.AdaptiveAvgPool1d(1),
45
+ nn.Conv1d(channels, inter_channels, kernel_size=1, stride=1, padding=0),
46
+ nn.BatchNorm1d(inter_channels),
47
+ nn.ReLU(inplace=True),
48
+ nn.Conv1d(inter_channels, channels, kernel_size=1, stride=1, padding=0),
49
+ nn.BatchNorm1d(channels),
50
+ )
51
+
52
+ # second local attention
53
+ self.local_att2 = nn.Sequential(
54
+ nn.Conv1d(channels, inter_channels, kernel_size=1, stride=1, padding=0),
55
+ nn.BatchNorm1d(inter_channels),
56
+ nn.ReLU(inplace=True),
57
+ nn.Conv1d(inter_channels, channels, kernel_size=1, stride=1, padding=0),
58
+ nn.BatchNorm1d(channels),
59
+ )
60
+ # second global attention
61
+ self.global_att2 = nn.Sequential(
62
+ nn.AdaptiveAvgPool1d(1),
63
+ nn.Conv1d(channels, inter_channels, kernel_size=1, stride=1, padding=0),
64
+ nn.BatchNorm1d(inter_channels),
65
+ nn.ReLU(inplace=True),
66
+ nn.Conv1d(inter_channels, channels, kernel_size=1, stride=1, padding=0),
67
+ nn.BatchNorm1d(channels),
68
+ )
69
+ elif type == '2D':
70
+ # local attention
71
+ self.local_att = nn.Sequential(
72
+ nn.Conv2d(channels, inter_channels, kernel_size=1, stride=1, padding=0),
73
+ nn.BatchNorm2d(inter_channels),
74
+ nn.ReLU(inplace=True),
75
+ nn.Conv2d(inter_channels, channels, kernel_size=1, stride=1, padding=0),
76
+ nn.BatchNorm2d(channels),
77
+ )
78
+
79
+ # global attention
80
+ self.global_att = nn.Sequential(
81
+ nn.AdaptiveAvgPool2d(1),
82
+ nn.Conv2d(channels, inter_channels, kernel_size=1, stride=1, padding=0),
83
+ nn.BatchNorm2d(inter_channels),
84
+ nn.ReLU(inplace=True),
85
+ nn.Conv2d(inter_channels, channels, kernel_size=1, stride=1, padding=0),
86
+ nn.BatchNorm2d(channels),
87
+ )
88
+
89
+ # second local attention
90
+ self.local_att2 = nn.Sequential(
91
+ nn.Conv2d(channels, inter_channels, kernel_size=1, stride=1, padding=0),
92
+ nn.BatchNorm2d(inter_channels),
93
+ nn.ReLU(inplace=True),
94
+ nn.Conv2d(inter_channels, channels, kernel_size=1, stride=1, padding=0),
95
+ nn.BatchNorm2d(channels),
96
+ )
97
+ # second global attention
98
+ self.global_att2 = nn.Sequential(
99
+ nn.AdaptiveAvgPool2d(1),
100
+ nn.Conv2d(channels, inter_channels, kernel_size=1, stride=1, padding=0),
101
+ nn.BatchNorm2d(inter_channels),
102
+ nn.ReLU(inplace=True),
103
+ nn.Conv2d(inter_channels, channels, kernel_size=1, stride=1, padding=0),
104
+ nn.BatchNorm2d(channels),
105
+ )
106
+ else:
107
+ raise ValueError('the fusion type is not supported')
108
+
109
+ self.sigmoid = nn.Sigmoid()
110
+
111
+ def forward(self, x, residual):
112
+ flag = False
113
+ xa = x + residual
114
+ if xa.size(0) == 1:
115
+ xa = torch.cat([xa,xa],dim=0)
116
+ flag = True
117
+ xl = self.local_att(xa)
118
+ xg = self.global_att(xa)
119
+ xlg = xl + xg
120
+ wei = self.sigmoid(xlg)
121
+ xi = x * wei + residual * (1 - wei)
122
+
123
+ xl2 = self.local_att2(xi)
124
+ xg2 = self.global_att(xi)
125
+ xlg2 = xl2 + xg2
126
+ wei2 = self.sigmoid(xlg2)
127
+ xo = x * wei2 + residual * (1 - wei2)
128
+ if flag:
129
+ xo = xo[0].unsqueeze(0)
130
+ return xo
131
+
132
+
133
+ class AFF(nn.Module):
134
+ '''
135
+ Multi-feature fusion: AFF (attentional feature fusion)
136
+ '''
137
+
138
+ def __init__(self, channels=64, r=4, type='2D'):
139
+ super(AFF, self).__init__()
140
+ inter_channels = int(channels // r)
141
+
142
+ if type == '1D':
143
+ self.local_att = nn.Sequential(
144
+ nn.Conv1d(channels, inter_channels, kernel_size=1, stride=1, padding=0),
145
+ nn.BatchNorm1d(inter_channels),
146
+ nn.ReLU(inplace=True),
147
+ nn.Conv1d(inter_channels, channels, kernel_size=1, stride=1, padding=0),
148
+ nn.BatchNorm1d(channels),
149
+ )
150
+ self.global_att = nn.Sequential(
151
+ nn.AdaptiveAvgPool1d(1),
152
+ nn.Conv1d(channels, inter_channels, kernel_size=1, stride=1, padding=0),
153
+ nn.BatchNorm1d(inter_channels),
154
+ nn.ReLU(inplace=True),
155
+ nn.Conv1d(inter_channels, channels, kernel_size=1, stride=1, padding=0),
156
+ nn.BatchNorm1d(channels),
157
+ )
158
+ elif type == '2D':
159
+ self.local_att = nn.Sequential(
160
+ nn.Conv2d(channels, inter_channels, kernel_size=1, stride=1, padding=0),
161
+ nn.BatchNorm2d(inter_channels),
162
+ nn.ReLU(inplace=True),
163
+ nn.Conv2d(inter_channels, channels, kernel_size=1, stride=1, padding=0),
164
+ nn.BatchNorm2d(channels),
165
+ )
166
+ self.global_att = nn.Sequential(
167
+ nn.AdaptiveAvgPool2d(1),
168
+ nn.Conv2d(channels, inter_channels, kernel_size=1, stride=1, padding=0),
169
+ nn.BatchNorm2d(inter_channels),
170
+ nn.ReLU(inplace=True),
171
+ nn.Conv2d(inter_channels, channels, kernel_size=1, stride=1, padding=0),
172
+ nn.BatchNorm2d(channels),
173
+ )
174
+ else:
175
+ raise ValueError('the fusion type is not supported.')
176
+
177
+ self.sigmoid = nn.Sigmoid()
178
+
179
+ def forward(self, x, residual):
180
+ flag = False
181
+ xa = x + residual
182
+ if xa.size(0) == 1:
183
+ xa = torch.cat([xa,xa],dim=0)
184
+ flag = True
185
+ xl = self.local_att(xa)
186
+ xg = self.global_att(xa)
187
+ xlg = xl + xg
188
+ wei = self.sigmoid(xlg)
189
+ xo = 2 * x * wei + 2 * residual * (1 - wei)
190
+ if flag:
191
+ xo = xo[0].unsqueeze(0)
192
+ return xo
193
+
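
A shape-level sketch of the AFF module above; the two feature maps are random placeholders and must share the same shape.

import torch

fusion = AFF(channels=64, r=4, type='2D')
x = torch.randn(2, 64, 32, 32)          # main-branch features
residual = torch.randn(2, 64, 32, 32)   # features to be fused in
fused = fusion(x, residual)             # attention-weighted blend of the two inputs
print(fused.shape)                      # torch.Size([2, 64, 32, 32])
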
ldm/modules/encoders/open_clap/htsat.py ADDED
@@ -0,0 +1,1022 @@
1
+ # Ke Chen
2
3
+ # HTS-AT: A HIERARCHICAL TOKEN-SEMANTIC AUDIO TRANSFORMER FOR SOUND CLASSIFICATION AND DETECTION
4
+ # Some layers designed on the model
5
+ # below codes are based and referred from https://github.com/microsoft/Swin-Transformer
6
+ # Swin Transformer for Computer Vision: https://arxiv.org/pdf/2103.14030.pdf
7
+
8
+ import torch
9
+ import torch.nn as nn
10
+ import torch.nn.functional as F
11
+ from itertools import repeat
12
+ import collections.abc
13
+ import math
14
+ import warnings
15
+
16
+ from torch.nn.init import _calculate_fan_in_and_fan_out
17
+ import torch.utils.checkpoint as checkpoint
18
+
19
+ import random
20
+
21
+ from torchlibrosa.stft import Spectrogram, LogmelFilterBank
22
+ from torchlibrosa.augmentation import SpecAugmentation
23
+
24
+ from itertools import repeat
25
+ from .utils import do_mixup, interpolate
26
+
27
+ from .feature_fusion import iAFF, AFF, DAF
28
+
29
+ # from PyTorch internals
30
+ def _ntuple(n):
31
+ def parse(x):
32
+ if isinstance(x, collections.abc.Iterable):
33
+ return x
34
+ return tuple(repeat(x, n))
35
+ return parse
36
+
37
+ to_1tuple = _ntuple(1)
38
+ to_2tuple = _ntuple(2)
39
+ to_3tuple = _ntuple(3)
40
+ to_4tuple = _ntuple(4)
41
+ to_ntuple = _ntuple
42
+
43
+ def drop_path(x, drop_prob: float = 0., training: bool = False):
44
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
45
+ This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
46
+ the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
47
+ See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
48
+ changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
49
+ 'survival rate' as the argument.
50
+ """
51
+ if drop_prob == 0. or not training:
52
+ return x
53
+ keep_prob = 1 - drop_prob
54
+ shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
55
+ random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
56
+ random_tensor.floor_() # binarize
57
+ output = x.div(keep_prob) * random_tensor
58
+ return output
59
+
60
+
61
+ class DropPath(nn.Module):
62
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
63
+ """
64
+ def __init__(self, drop_prob=None):
65
+ super(DropPath, self).__init__()
66
+ self.drop_prob = drop_prob
67
+
68
+ def forward(self, x):
69
+ return drop_path(x, self.drop_prob, self.training)
70
+
71
+ class PatchEmbed(nn.Module):
72
+ """ 2D Image to Patch Embedding
73
+ """
74
+ def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, norm_layer=None, flatten=True, patch_stride = 16,
75
+ enable_fusion=False, fusion_type='None'):
76
+ super().__init__()
77
+ img_size = to_2tuple(img_size)
78
+ patch_size = to_2tuple(patch_size)
79
+ patch_stride = to_2tuple(patch_stride)
80
+ self.img_size = img_size
81
+ self.patch_size = patch_size
82
+ self.patch_stride = patch_stride
83
+ self.grid_size = (img_size[0] // patch_stride[0], img_size[1] // patch_stride[1])
84
+ self.num_patches = self.grid_size[0] * self.grid_size[1]
85
+ self.flatten = flatten
86
+ self.in_chans = in_chans
87
+ self.embed_dim = embed_dim
88
+
89
+ self.enable_fusion = enable_fusion
90
+ self.fusion_type = fusion_type
91
+
92
+ padding = ((patch_size[0] - patch_stride[0]) // 2, (patch_size[1] - patch_stride[1]) // 2)
93
+
94
+ if (self.enable_fusion) and (self.fusion_type == 'channel_map'):
95
+ self.proj = nn.Conv2d(in_chans*4, embed_dim, kernel_size=patch_size, stride=patch_stride, padding=padding)
96
+ else:
97
+ self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_stride, padding=padding)
98
+ self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
99
+
100
+ if (self.enable_fusion) and (self.fusion_type in ['daf_2d','aff_2d','iaff_2d']):
101
+ self.mel_conv2d = nn.Conv2d(in_chans, embed_dim, kernel_size=(patch_size[0], patch_size[1]*3), stride=(patch_stride[0], patch_stride[1] * 3), padding=padding)
102
+ if self.fusion_type == 'daf_2d':
103
+ self.fusion_model = DAF()
104
+ elif self.fusion_type == 'aff_2d':
105
+ self.fusion_model = AFF(channels=embed_dim, type='2D')
106
+ elif self.fusion_type == 'iaff_2d':
107
+ self.fusion_model = iAFF(channels=embed_dim, type='2D')
108
+ def forward(self, x, longer_idx = None):
109
+ if (self.enable_fusion) and (self.fusion_type in ['daf_2d','aff_2d','iaff_2d']):
110
+ global_x = x[:,0:1,:,:]
111
+
112
+
113
+ # global processing
114
+ B, C, H, W = global_x.shape
115
+ assert H == self.img_size[0] and W == self.img_size[1], \
116
+ f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
117
+ global_x = self.proj(global_x)
118
+ TW = global_x.size(-1)
119
+ if len(longer_idx) > 0:
120
+ # local processing
121
+ local_x = x[longer_idx,1:,:,:].contiguous()
122
+ B, C, H, W = local_x.shape
123
+ local_x = local_x.view(B*C,1,H,W)
124
+ local_x = self.mel_conv2d(local_x)
125
+ local_x = local_x.view(B,C,local_x.size(1),local_x.size(2),local_x.size(3))
126
+ local_x = local_x.permute((0,2,3,1,4)).contiguous().flatten(3)
127
+ TB,TC,TH,_ = local_x.size()
128
+ if local_x.size(-1) < TW:
129
+ local_x = torch.cat([local_x, torch.zeros((TB,TC,TH,TW-local_x.size(-1)), device=global_x.device)], dim=-1)
130
+ else:
131
+ local_x = local_x[:,:,:,:TW]
132
+
133
+ global_x[longer_idx] = self.fusion_model(global_x[longer_idx],local_x)
134
+ x = global_x
135
+ else:
136
+ B, C, H, W = x.shape
137
+ assert H == self.img_size[0] and W == self.img_size[1], \
138
+ f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
139
+ x = self.proj(x)
140
+
141
+ if self.flatten:
142
+ x = x.flatten(2).transpose(1, 2) # BCHW -> BNC
143
+ x = self.norm(x)
144
+ return x
145
+
146
+ class Mlp(nn.Module):
147
+ """ MLP as used in Vision Transformer, MLP-Mixer and related networks
148
+ """
149
+ def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
150
+ super().__init__()
151
+ out_features = out_features or in_features
152
+ hidden_features = hidden_features or in_features
153
+ self.fc1 = nn.Linear(in_features, hidden_features)
154
+ self.act = act_layer()
155
+ self.fc2 = nn.Linear(hidden_features, out_features)
156
+ self.drop = nn.Dropout(drop)
157
+
158
+ def forward(self, x):
159
+ x = self.fc1(x)
160
+ x = self.act(x)
161
+ x = self.drop(x)
162
+ x = self.fc2(x)
163
+ x = self.drop(x)
164
+ return x
165
+
166
+ def _no_grad_trunc_normal_(tensor, mean, std, a, b):
167
+ # Cut & paste from PyTorch official master until it's in a few official releases - RW
168
+ # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
169
+ def norm_cdf(x):
170
+ # Computes standard normal cumulative distribution function
171
+ return (1. + math.erf(x / math.sqrt(2.))) / 2.
172
+
173
+ if (mean < a - 2 * std) or (mean > b + 2 * std):
174
+ warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
175
+ "The distribution of values may be incorrect.",
176
+ stacklevel=2)
177
+
178
+ with torch.no_grad():
179
+ # Values are generated by using a truncated uniform distribution and
180
+ # then using the inverse CDF for the normal distribution.
181
+ # Get upper and lower cdf values
182
+ l = norm_cdf((a - mean) / std)
183
+ u = norm_cdf((b - mean) / std)
184
+
185
+ # Uniformly fill tensor with values from [l, u], then translate to
186
+ # [2l-1, 2u-1].
187
+ tensor.uniform_(2 * l - 1, 2 * u - 1)
188
+
189
+ # Use inverse cdf transform for normal distribution to get truncated
190
+ # standard normal
191
+ tensor.erfinv_()
192
+
193
+ # Transform to proper mean, std
194
+ tensor.mul_(std * math.sqrt(2.))
195
+ tensor.add_(mean)
196
+
197
+ # Clamp to ensure it's in the proper range
198
+ tensor.clamp_(min=a, max=b)
199
+ return tensor
200
+
201
+
202
+ def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
203
+ # type: (Tensor, float, float, float, float) -> Tensor
204
+ r"""Fills the input Tensor with values drawn from a truncated
205
+ normal distribution. The values are effectively drawn from the
206
+ normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`
207
+ with values outside :math:`[a, b]` redrawn until they are within
208
+ the bounds. The method used for generating the random values works
209
+ best when :math:`a \leq \text{mean} \leq b`.
210
+ Args:
211
+ tensor: an n-dimensional `torch.Tensor`
212
+ mean: the mean of the normal distribution
213
+ std: the standard deviation of the normal distribution
214
+ a: the minimum cutoff value
215
+ b: the maximum cutoff value
216
+ Examples:
217
+ >>> w = torch.empty(3, 5)
218
+ >>> nn.init.trunc_normal_(w)
219
+ """
220
+ return _no_grad_trunc_normal_(tensor, mean, std, a, b)
221
+
222
+
223
+ def variance_scaling_(tensor, scale=1.0, mode='fan_in', distribution='normal'):
224
+ fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor)
225
+ if mode == 'fan_in':
226
+ denom = fan_in
227
+ elif mode == 'fan_out':
228
+ denom = fan_out
229
+ elif mode == 'fan_avg':
230
+ denom = (fan_in + fan_out) / 2
231
+
232
+ variance = scale / denom
233
+
234
+ if distribution == "truncated_normal":
235
+ # constant is stddev of standard normal truncated to (-2, 2)
236
+ trunc_normal_(tensor, std=math.sqrt(variance) / .87962566103423978)
237
+ elif distribution == "normal":
238
+ tensor.normal_(std=math.sqrt(variance))
239
+ elif distribution == "uniform":
240
+ bound = math.sqrt(3 * variance)
241
+ tensor.uniform_(-bound, bound)
242
+ else:
243
+ raise ValueError(f"invalid distribution {distribution}")
244
+
245
+
246
+ def lecun_normal_(tensor):
247
+ variance_scaling_(tensor, mode='fan_in', distribution='truncated_normal')
248
+
249
+ def window_partition(x, window_size):
250
+ """
251
+ Args:
252
+ x: (B, H, W, C)
253
+ window_size (int): window size
254
+ Returns:
255
+ windows: (num_windows*B, window_size, window_size, C)
256
+ """
257
+ B, H, W, C = x.shape
258
+ x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
259
+ windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
260
+ return windows
261
+
262
+
263
+ def window_reverse(windows, window_size, H, W):
264
+ """
265
+ Args:
266
+ windows: (num_windows*B, window_size, window_size, C)
267
+ window_size (int): Window size
268
+ H (int): Height of image
269
+ W (int): Width of image
270
+ Returns:
271
+ x: (B, H, W, C)
272
+ """
273
+ B = int(windows.shape[0] / (H * W / window_size / window_size))
274
+ x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
275
+ x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
276
+ return x
277
+
278
+
279
+ class WindowAttention(nn.Module):
280
+ r""" Window based multi-head self attention (W-MSA) module with relative position bias.
281
+ It supports both of shifted and non-shifted window.
282
+ Args:
283
+ dim (int): Number of input channels.
284
+ window_size (tuple[int]): The height and width of the window.
285
+ num_heads (int): Number of attention heads.
286
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
287
+ qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
288
+ attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
289
+ proj_drop (float, optional): Dropout ratio of output. Default: 0.0
290
+ """
291
+
292
+ def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.):
293
+
294
+ super().__init__()
295
+ self.dim = dim
296
+ self.window_size = window_size # Wh, Ww
297
+ self.num_heads = num_heads
298
+ head_dim = dim // num_heads
299
+ self.scale = qk_scale or head_dim ** -0.5
300
+
301
+ # define a parameter table of relative position bias
302
+ self.relative_position_bias_table = nn.Parameter(
303
+ torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH
304
+
305
+ # get pair-wise relative position index for each token inside the window
306
+ coords_h = torch.arange(self.window_size[0])
307
+ coords_w = torch.arange(self.window_size[1])
308
+ coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
309
+ coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
310
+ relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
311
+ relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
312
+ relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0
313
+ relative_coords[:, :, 1] += self.window_size[1] - 1
314
+ relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
315
+ relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
316
+ self.register_buffer("relative_position_index", relative_position_index)
317
+
318
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
319
+ self.attn_drop = nn.Dropout(attn_drop)
320
+ self.proj = nn.Linear(dim, dim)
321
+ self.proj_drop = nn.Dropout(proj_drop)
322
+
323
+ trunc_normal_(self.relative_position_bias_table, std=.02)
324
+ self.softmax = nn.Softmax(dim=-1)
325
+
326
+ def forward(self, x, mask=None):
327
+ """
328
+ Args:
329
+ x: input features with shape of (num_windows*B, N, C)
330
+ mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
331
+ """
332
+ B_, N, C = x.shape
333
+ qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
334
+ q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
335
+
336
+ q = q * self.scale
337
+ attn = (q @ k.transpose(-2, -1))
338
+
339
+ relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
340
+ self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH
341
+ relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
342
+ attn = attn + relative_position_bias.unsqueeze(0)
343
+
344
+ if mask is not None:
345
+ nW = mask.shape[0]
346
+ attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
347
+ attn = attn.view(-1, self.num_heads, N, N)
348
+ attn = self.softmax(attn)
349
+ else:
350
+ attn = self.softmax(attn)
351
+
352
+ attn = self.attn_drop(attn)
353
+
354
+ x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
355
+ x = self.proj(x)
356
+ x = self.proj_drop(x)
357
+ return x, attn
358
+
359
+ def extra_repr(self):
360
+ return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}'
361
+
362
+
363
+ # We use the model based on Swintransformer Block, therefore we can use the swin-transformer pretrained model
364
+ class SwinTransformerBlock(nn.Module):
365
+ r""" Swin Transformer Block.
366
+ Args:
367
+ dim (int): Number of input channels.
368
+ input_resolution (tuple[int]): Input resolution.
369
+ num_heads (int): Number of attention heads.
370
+ window_size (int): Window size.
371
+ shift_size (int): Shift size for SW-MSA.
372
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
373
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
374
+ qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
375
+ drop (float, optional): Dropout rate. Default: 0.0
376
+ attn_drop (float, optional): Attention dropout rate. Default: 0.0
377
+ drop_path (float, optional): Stochastic depth rate. Default: 0.0
378
+ act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
379
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
380
+ """
381
+
382
+ def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0,
383
+ mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0.,
384
+ act_layer=nn.GELU, norm_layer=nn.LayerNorm, norm_before_mlp='ln'):
385
+ super().__init__()
386
+ self.dim = dim
387
+ self.input_resolution = input_resolution
388
+ self.num_heads = num_heads
389
+ self.window_size = window_size
390
+ self.shift_size = shift_size
391
+ self.mlp_ratio = mlp_ratio
392
+ self.norm_before_mlp = norm_before_mlp
393
+ if min(self.input_resolution) <= self.window_size:
394
+ # if window size is larger than input resolution, we don't partition windows
395
+ self.shift_size = 0
396
+ self.window_size = min(self.input_resolution)
397
+ assert 0 <= self.shift_size < self.window_size, "shift_size must be in [0, window_size)"
398
+
399
+ self.norm1 = norm_layer(dim)
400
+ self.attn = WindowAttention(
401
+ dim, window_size=to_2tuple(self.window_size), num_heads=num_heads,
402
+ qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
403
+
404
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
405
+ if self.norm_before_mlp == 'ln':
406
+ self.norm2 = nn.LayerNorm(dim)
407
+ elif self.norm_before_mlp == 'bn':
408
+ self.norm2 = lambda x: nn.BatchNorm1d(dim)(x.transpose(1, 2)).transpose(1, 2)
409
+ else:
410
+ raise NotImplementedError
411
+ mlp_hidden_dim = int(dim * mlp_ratio)
412
+ self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
413
+
414
+ if self.shift_size > 0:
415
+ # calculate attention mask for SW-MSA
416
+ H, W = self.input_resolution
417
+ img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1
418
+ h_slices = (slice(0, -self.window_size),
419
+ slice(-self.window_size, -self.shift_size),
420
+ slice(-self.shift_size, None))
421
+ w_slices = (slice(0, -self.window_size),
422
+ slice(-self.window_size, -self.shift_size),
423
+ slice(-self.shift_size, None))
424
+ cnt = 0
425
+ for h in h_slices:
426
+ for w in w_slices:
427
+ img_mask[:, h, w, :] = cnt
428
+ cnt += 1
429
+
430
+ mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1
431
+ mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
432
+ attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
433
+ attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
434
+ else:
435
+ attn_mask = None
436
+
437
+ self.register_buffer("attn_mask", attn_mask)
438
+
439
+ def forward(self, x):
440
+ # pdb.set_trace()
441
+ H, W = self.input_resolution
442
+ # print("H: ", H)
443
+ # print("W: ", W)
444
+ # pdb.set_trace()
445
+ B, L, C = x.shape
446
+ # assert L == H * W, "input feature has wrong size"
447
+
448
+ shortcut = x
449
+ x = self.norm1(x)
450
+ x = x.view(B, H, W, C)
451
+
452
+ # cyclic shift
453
+ if self.shift_size > 0:
454
+ shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
455
+ else:
456
+ shifted_x = x
457
+
458
+ # partition windows
459
+ x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C
460
+ x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C
461
+
462
+ # W-MSA/SW-MSA
463
+ attn_windows, attn = self.attn(x_windows, mask=self.attn_mask) # nW*B, window_size*window_size, C
464
+
465
+ # merge windows
466
+ attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
467
+ shifted_x = window_reverse(attn_windows, self.window_size, H, W) # B H' W' C
468
+
469
+ # reverse cyclic shift
470
+ if self.shift_size > 0:
471
+ x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
472
+ else:
473
+ x = shifted_x
474
+ x = x.view(B, H * W, C)
475
+
476
+ # FFN
477
+ x = shortcut + self.drop_path(x)
478
+ x = x + self.drop_path(self.mlp(self.norm2(x)))
479
+
480
+ return x, attn
481
+
482
+ def extra_repr(self):
483
+ return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \
484
+ f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}"
485
+
486
+
487
+
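
The cyclic shift plus additive mask above is the standard SW-MSA construction: after torch.roll, windows that straddle the original border mix tokens from unrelated regions, and the -100 entries keep them from attending to each other. A small standalone sketch (window_partition is re-implemented here only for illustration; in this file it is defined alongside these classes):

import torch

def window_partition(x, window_size):
    # (B, H, W, C) -> (num_windows*B, window_size, window_size, C)
    B, H, W, C = x.shape
    x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
    return x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)

H = W = 8
window_size, shift_size = 4, 2
img_mask = torch.zeros((1, H, W, 1))
regions = (slice(0, -window_size), slice(-window_size, -shift_size), slice(-shift_size, None))
cnt = 0
for h in regions:
    for w in regions:
        img_mask[:, h, w, :] = cnt   # label each contiguous region of the shifted image
        cnt += 1

mask_windows = window_partition(img_mask, window_size).view(-1, window_size * window_size)
attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
attn_mask = attn_mask.masked_fill(attn_mask != 0, -100.0).masked_fill(attn_mask == 0, 0.0)
print(attn_mask.shape)  # torch.Size([4, 16, 16]): one additive mask per window
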
488
+ class PatchMerging(nn.Module):
489
+ r""" Patch Merging Layer.
490
+ Args:
491
+ input_resolution (tuple[int]): Resolution of input feature.
492
+ dim (int): Number of input channels.
493
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
494
+ """
495
+
496
+ def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm):
497
+ super().__init__()
498
+ self.input_resolution = input_resolution
499
+ self.dim = dim
500
+ self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
501
+ self.norm = norm_layer(4 * dim)
502
+
503
+ def forward(self, x):
504
+ """
505
+ x: B, H*W, C
506
+ """
507
+ H, W = self.input_resolution
508
+ B, L, C = x.shape
509
+ assert L == H * W, "input feature has wrong size"
510
+ assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even."
511
+
512
+ x = x.view(B, H, W, C)
513
+
514
+ x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C
515
+ x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C
516
+ x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C
517
+ x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C
518
+ x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C
519
+ x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C
520
+
521
+ x = self.norm(x)
522
+ x = self.reduction(x)
523
+
524
+ return x
525
+
526
+ def extra_repr(self):
527
+ return f"input_resolution={self.input_resolution}, dim={self.dim}"
528
+
529
+
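
A quick shape check of the merging step above: each 2x2 spatial neighborhood is folded into the channel axis (4C), normalized, and projected down to 2C, so the token count drops by 4x while the channels double. Standalone sketch with arbitrary example sizes:

import torch
from torch import nn

B, H, W, C = 2, 8, 8, 96
x = torch.randn(B, H * W, C).view(B, H, W, C)
x0 = x[:, 0::2, 0::2, :]
x1 = x[:, 1::2, 0::2, :]
x2 = x[:, 0::2, 1::2, :]
x3 = x[:, 1::2, 1::2, :]
merged = torch.cat([x0, x1, x2, x3], -1).view(B, -1, 4 * C)          # B, H/2*W/2, 4C
out = nn.Linear(4 * C, 2 * C, bias=False)(nn.LayerNorm(4 * C)(merged))
print(out.shape)  # torch.Size([2, 16, 192]): a quarter of the tokens, double the channels
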
530
+ class BasicLayer(nn.Module):
531
+ """ A basic Swin Transformer layer for one stage.
532
+ Args:
533
+ dim (int): Number of input channels.
534
+ input_resolution (tuple[int]): Input resolution.
535
+ depth (int): Number of blocks.
536
+ num_heads (int): Number of attention heads.
537
+ window_size (int): Local window size.
538
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
539
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
540
+ qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
541
+ drop (float, optional): Dropout rate. Default: 0.0
542
+ attn_drop (float, optional): Attention dropout rate. Default: 0.0
543
+ drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
544
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
545
+ downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
546
+ use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
547
+ """
548
+
549
+ def __init__(self, dim, input_resolution, depth, num_heads, window_size,
550
+ mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0.,
551
+ drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False,
552
+ norm_before_mlp='ln'):
553
+
554
+ super().__init__()
555
+ self.dim = dim
556
+ self.input_resolution = input_resolution
557
+ self.depth = depth
558
+ self.use_checkpoint = use_checkpoint
559
+
560
+ # build blocks
561
+ self.blocks = nn.ModuleList([
562
+ SwinTransformerBlock(dim=dim, input_resolution=input_resolution,
563
+ num_heads=num_heads, window_size=window_size,
564
+ shift_size=0 if (i % 2 == 0) else window_size // 2,
565
+ mlp_ratio=mlp_ratio,
566
+ qkv_bias=qkv_bias, qk_scale=qk_scale,
567
+ drop=drop, attn_drop=attn_drop,
568
+ drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
569
+ norm_layer=norm_layer, norm_before_mlp=norm_before_mlp)
570
+ for i in range(depth)])
571
+
572
+ # patch merging layer
573
+ if downsample is not None:
574
+ self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer)
575
+ else:
576
+ self.downsample = None
577
+
578
+ def forward(self, x):
579
+ attns = []
580
+ for blk in self.blocks:
581
+ if self.use_checkpoint:
582
+ x = checkpoint.checkpoint(blk, x)
583
+ else:
584
+ x, attn = blk(x)
585
+ if not self.training:
586
+ attns.append(attn.unsqueeze(0))
587
+ if self.downsample is not None:
588
+ x = self.downsample(x)
589
+ if not self.training:
590
+ attn = torch.cat(attns, dim = 0)
591
+ attn = torch.mean(attn, dim = 0)
592
+ return x, attn
593
+
594
+ def extra_repr(self):
595
+ return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}"
596
+
597
+
598
+ # The Core of HTSAT
599
+ class HTSAT_Swin_Transformer(nn.Module):
600
+ r"""HTSAT based on the Swin Transformer
601
+ Args:
602
+ spec_size (int | tuple(int)): Input Spectrogram size. Default 256
603
+ patch_size (int | tuple(int)): Patch size. Default: 4
604
+ patch_stride (int | tuple(int)): Patch stride for the frequency and time axes. Default: 4
605
+ in_chans (int): Number of input image channels. Default: 1 (mono)
606
+ num_classes (int): Number of classes for classification head. Default: 527
607
+ embed_dim (int): Patch embedding dimension. Default: 96
608
+ depths (tuple(int)): Depth of each HTSAT-Swin Transformer layer.
609
+ num_heads (tuple(int)): Number of attention heads in different layers.
610
+ window_size (int): Window size. Default: 8
611
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4
612
+ qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
613
+ qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None
614
+ drop_rate (float): Dropout rate. Default: 0
615
+ attn_drop_rate (float): Attention dropout rate. Default: 0
616
+ drop_path_rate (float): Stochastic depth rate. Default: 0.1
617
+ norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
618
+ ape (bool): If True, add absolute position embedding to the patch embedding. Default: False
619
+ patch_norm (bool): If True, add normalization after patch embedding. Default: True
620
+ use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False
621
+ config (module): The configuration Module from config.py
622
+ """
623
+
624
+ def __init__(self, spec_size=256, patch_size=4, patch_stride=(4,4),
625
+ in_chans=1, num_classes=527,
626
+ embed_dim=96, depths=[2, 2, 6, 2], num_heads=[4, 8, 16, 32],
627
+ window_size=8, mlp_ratio=4., qkv_bias=True, qk_scale=None,
628
+ drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1,
629
+ norm_layer=nn.LayerNorm,
630
+ ape=False, patch_norm=True,
631
+ use_checkpoint=False, norm_before_mlp='ln', config = None,
632
+ enable_fusion = False, fusion_type = 'None', **kwargs):
633
+ super(HTSAT_Swin_Transformer, self).__init__()
634
+
635
+ self.config = config
636
+ self.spec_size = spec_size
637
+ self.patch_stride = patch_stride
638
+ self.patch_size = patch_size
639
+ self.window_size = window_size
640
+ self.embed_dim = embed_dim
641
+ self.depths = depths
642
+ self.ape = ape
643
+ self.in_chans = in_chans
644
+ self.num_classes = num_classes
645
+ self.num_heads = num_heads
646
+ self.num_layers = len(self.depths)
647
+ self.num_features = int(self.embed_dim * 2 ** (self.num_layers - 1))
648
+
649
+ self.drop_rate = drop_rate
650
+ self.attn_drop_rate = attn_drop_rate
651
+ self.drop_path_rate = drop_path_rate
652
+
653
+ self.qkv_bias = qkv_bias
654
+ self.qk_scale = None
655
+
656
+ self.patch_norm = patch_norm
657
+ self.norm_layer = norm_layer if self.patch_norm else None
658
+ self.norm_before_mlp = norm_before_mlp
659
+ self.mlp_ratio = mlp_ratio
660
+
661
+ self.use_checkpoint = use_checkpoint
662
+
663
+ self.enable_fusion = enable_fusion
664
+ self.fusion_type = fusion_type
665
+
666
+ # process mel-spec ; used only once
667
+ self.freq_ratio = self.spec_size // self.config.mel_bins
668
+ window = 'hann'
669
+ center = True
670
+ pad_mode = 'reflect'
671
+ ref = 1.0
672
+ amin = 1e-10
673
+ top_db = None
674
+ self.interpolate_ratio = 32 # Downsampled ratio
675
+ # Spectrogram extractor
676
+ self.spectrogram_extractor = Spectrogram(n_fft=config.window_size, hop_length=config.hop_size,
677
+ win_length=config.window_size, window=window, center=center, pad_mode=pad_mode,
678
+ freeze_parameters=True)
679
+ # Logmel feature extractor
680
+ self.logmel_extractor = LogmelFilterBank(sr=config.sample_rate, n_fft=config.window_size,
681
+ n_mels=config.mel_bins, fmin=config.fmin, fmax=config.fmax, ref=ref, amin=amin, top_db=top_db,
682
+ freeze_parameters=True)
683
+ # Spec augmenter
684
+ self.spec_augmenter = SpecAugmentation(time_drop_width=64, time_stripes_num=2,
685
+ freq_drop_width=8, freq_stripes_num=2) # 2 2
686
+ self.bn0 = nn.BatchNorm2d(self.config.mel_bins)
687
+
688
+
689
+ # split the spectrogram into non-overlapping patches
690
+ self.patch_embed = PatchEmbed(
691
+ img_size=self.spec_size, patch_size=self.patch_size, in_chans=self.in_chans,
692
+ embed_dim=self.embed_dim, norm_layer=self.norm_layer, patch_stride = patch_stride,
693
+ enable_fusion=self.enable_fusion, fusion_type=self.fusion_type
694
+ )
695
+
696
+ num_patches = self.patch_embed.num_patches
697
+ patches_resolution = self.patch_embed.grid_size
698
+ self.patches_resolution = patches_resolution
699
+
700
+ # absolute position embedding
701
+ if self.ape:
702
+ self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, self.embed_dim))
703
+ trunc_normal_(self.absolute_pos_embed, std=.02)
704
+
705
+ self.pos_drop = nn.Dropout(p=self.drop_rate)
706
+
707
+ # stochastic depth
708
+ dpr = [x.item() for x in torch.linspace(0, self.drop_path_rate, sum(self.depths))] # stochastic depth decay rule
709
+
710
+ # build layers
711
+ self.layers = nn.ModuleList()
712
+ for i_layer in range(self.num_layers):
713
+ layer = BasicLayer(dim=int(self.embed_dim * 2 ** i_layer),
714
+ input_resolution=(patches_resolution[0] // (2 ** i_layer),
715
+ patches_resolution[1] // (2 ** i_layer)),
716
+ depth=self.depths[i_layer],
717
+ num_heads=self.num_heads[i_layer],
718
+ window_size=self.window_size,
719
+ mlp_ratio=self.mlp_ratio,
720
+ qkv_bias=self.qkv_bias, qk_scale=self.qk_scale,
721
+ drop=self.drop_rate, attn_drop=self.attn_drop_rate,
722
+ drop_path=dpr[sum(self.depths[:i_layer]):sum(self.depths[:i_layer + 1])],
723
+ norm_layer=self.norm_layer,
724
+ downsample=PatchMerging if (i_layer < self.num_layers - 1) else None,
725
+ use_checkpoint=use_checkpoint,
726
+ norm_before_mlp=self.norm_before_mlp)
727
+ self.layers.append(layer)
728
+
729
+ self.norm = self.norm_layer(self.num_features)
730
+ self.avgpool = nn.AdaptiveAvgPool1d(1)
731
+ self.maxpool = nn.AdaptiveMaxPool1d(1)
732
+
733
+ SF = self.spec_size // (2 ** (len(self.depths) - 1)) // self.patch_stride[0] // self.freq_ratio
734
+ self.tscam_conv = nn.Conv2d(
735
+ in_channels = self.num_features,
736
+ out_channels = self.num_classes,
737
+ kernel_size = (SF,3),
738
+ padding = (0,1)
739
+ )
740
+ self.head = nn.Linear(num_classes, num_classes)
741
+
742
+ if (self.enable_fusion) and (self.fusion_type in ['daf_1d','aff_1d','iaff_1d']):
743
+ self.mel_conv1d = nn.Sequential(
744
+ nn.Conv1d(64, 64, kernel_size=5, stride=3, padding=2),
745
+ nn.BatchNorm1d(64)
746
+ )
747
+ if self.fusion_type == 'daf_1d':
748
+ self.fusion_model = DAF()
749
+ elif self.fusion_type == 'aff_1d':
750
+ self.fusion_model = AFF(channels=64, type='1D')
751
+ elif self.fusion_type == 'iaff_1d':
752
+ self.fusion_model = iAFF(channels=64, type='1D')
753
+
754
+ self.apply(self._init_weights)
755
+
756
+ def _init_weights(self, m):
757
+ if isinstance(m, nn.Linear):
758
+ trunc_normal_(m.weight, std=.02)
759
+ if isinstance(m, nn.Linear) and m.bias is not None:
760
+ nn.init.constant_(m.bias, 0)
761
+ elif isinstance(m, nn.LayerNorm):
762
+ nn.init.constant_(m.bias, 0)
763
+ nn.init.constant_(m.weight, 1.0)
764
+
765
+ @torch.jit.ignore
766
+ def no_weight_decay(self):
767
+ return {'absolute_pos_embed'}
768
+
769
+ @torch.jit.ignore
770
+ def no_weight_decay_keywords(self):
771
+ return {'relative_position_bias_table'}
772
+
773
+
774
+ def forward_features(self, x, longer_idx = None):
775
+ # A deprecated optimization for using a hierarchical output from different blocks
776
+
777
+ frames_num = x.shape[2]
778
+ x = self.patch_embed(x, longer_idx = longer_idx)
779
+ if self.ape:
780
+ x = x + self.absolute_pos_embed
781
+ x = self.pos_drop(x)
782
+ for i, layer in enumerate(self.layers):
783
+ x, attn = layer(x)
784
+ # for x
785
+ x = self.norm(x)
786
+ B, N, C = x.shape
787
+ SF = frames_num // (2 ** (len(self.depths) - 1)) // self.patch_stride[0]
788
+ ST = frames_num // (2 ** (len(self.depths) - 1)) // self.patch_stride[1]
789
+ x = x.permute(0,2,1).contiguous().reshape(B, C, SF, ST)
790
+ B, C, F, T = x.shape
791
+ # group 2D CNN
792
+ c_freq_bin = F // self.freq_ratio
793
+ x = x.reshape(B, C, F // c_freq_bin, c_freq_bin, T)
794
+ x = x.permute(0,1,3,2,4).contiguous().reshape(B, C, c_freq_bin, -1)
795
+ # get latent_output
796
+ fine_grained_latent_output = torch.mean(x, dim = 2)
797
+ fine_grained_latent_output = interpolate(fine_grained_latent_output.permute(0,2,1).contiguous(), 8 * self.patch_stride[1])
798
+
799
+ latent_output = self.avgpool(torch.flatten(x,2))
800
+ latent_output = torch.flatten(latent_output, 1)
801
+
802
+ # display the attention map, if needed
803
+
804
+ x = self.tscam_conv(x)
805
+ x = torch.flatten(x, 2) # B, C, T
806
+
807
+ fpx = interpolate(torch.sigmoid(x).permute(0,2,1).contiguous(), 8 * self.patch_stride[1])
808
+
809
+ x = self.avgpool(x)
810
+ x = torch.flatten(x, 1)
811
+
812
+ output_dict = {
813
+ 'framewise_output': fpx, # already sigmoided
814
+ 'clipwise_output': torch.sigmoid(x),
815
+ 'fine_grained_embedding': fine_grained_latent_output,
816
+ 'embedding': latent_output
817
+ }
818
+
819
+ return output_dict
820
+
821
+ def crop_wav(self, x, crop_size, spe_pos = None):
822
+ time_steps = x.shape[2]
823
+ tx = torch.zeros(x.shape[0], x.shape[1], crop_size, x.shape[3]).to(x.device)
824
+ for i in range(len(x)):
825
+ if spe_pos is None:
826
+ crop_pos = random.randint(0, time_steps - crop_size - 1)
827
+ else:
828
+ crop_pos = spe_pos
829
+ tx[i][0] = x[i, 0, crop_pos:crop_pos + crop_size,:]
830
+ return tx
831
+
832
+ # Reshape the wavform to a img size, if you want to use the pretrained swin transformer model
833
+ def reshape_wav2img(self, x):
834
+ B, C, T, F = x.shape
835
+ target_T = int(self.spec_size * self.freq_ratio)
836
+ target_F = self.spec_size // self.freq_ratio
837
+ assert T <= target_T and F <= target_F, "the wav size should be less than or equal to the swin input size"
838
+ # to avoid bicubic zero error
839
+ if T < target_T:
840
+ x = nn.functional.interpolate(x, (target_T, x.shape[3]), mode="bicubic", align_corners=True)
841
+ if F < target_F:
842
+ x = nn.functional.interpolate(x, (x.shape[2], target_F), mode="bicubic", align_corners=True)
843
+ x = x.permute(0,1,3,2).contiguous()
844
+ x = x.reshape(x.shape[0], x.shape[1], x.shape[2], self.freq_ratio, x.shape[3] // self.freq_ratio)
845
+ # print(x.shape)
846
+ x = x.permute(0,1,3,2,4).contiguous()
847
+ x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3], x.shape[4])
848
+ return x
849
+
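
The reshape above is what lets a 1-D audio spectrogram reuse 2-D Swin weights: with spec_size=256 (as in create_htsat_model below) and 64 mel bins, freq_ratio is 4, so up to 1024 time frames are cut into 4 chunks of 256 and stacked along the frequency axis, giving a square 256x256 "image". A standalone sketch of the same tensor shuffling (sizes are illustrative):

import torch

spec_size, mel_bins = 256, 64
freq_ratio = spec_size // mel_bins            # 4
x = torch.randn(1, 1, 1024, mel_bins)         # (B, C, T, F) log-mel, already padded to target_T

x = x.permute(0, 1, 3, 2).contiguous()                                   # B, C, F, T
x = x.reshape(1, 1, mel_bins, freq_ratio, x.shape[-1] // freq_ratio)     # split T into 4 chunks
x = x.permute(0, 1, 3, 2, 4).contiguous()                                # B, C, 4, F, T/4
x = x.reshape(1, 1, freq_ratio * mel_bins, spec_size)
print(x.shape)  # torch.Size([1, 1, 256, 256]): a square input for the Swin backbone
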
850
+ # Repeat the mel spectrogram to an image-like size so that the pretrained Swin Transformer model can be used
851
+ def repeat_wat2img(self, x, cur_pos):
852
+ B, C, T, F = x.shape
853
+ target_T = int(self.spec_size * self.freq_ratio)
854
+ target_F = self.spec_size // self.freq_ratio
855
+ assert T <= target_T and F <= target_F, "the wav size should be less than or equal to the swin input size"
856
+ # to avoid bicubic zero error
857
+ if T < target_T:
858
+ x = nn.functional.interpolate(x, (target_T, x.shape[3]), mode="bicubic", align_corners=True)
859
+ if F < target_F:
860
+ x = nn.functional.interpolate(x, (x.shape[2], target_F), mode="bicubic", align_corners=True)
861
+ x = x.permute(0,1,3,2).contiguous() # B C F T
862
+ x = x[:,:,:,cur_pos:cur_pos + self.spec_size]
863
+ x = x.repeat(repeats = (1,1,4,1))
864
+ return x
865
+
866
+ def forward(self, x: torch.Tensor, mixup_lambda = None, infer_mode = False, device=None):# out_feat_keys: List[str] = None):
867
+
868
+ if self.enable_fusion and x["longer"].sum() == 0:
869
+ # if no audio is longer than 10s, then randomly select one audio to be longer
870
+ x["longer"][torch.randint(0, x["longer"].shape[0], (1,))] = True
871
+
872
+ if not self.enable_fusion:
873
+ x = x["waveform"].to(device=device, non_blocking=True)
874
+ x = self.spectrogram_extractor(x) # (batch_size, 1, time_steps, freq_bins)
875
+ x = self.logmel_extractor(x) # (batch_size, 1, time_steps, mel_bins)
876
+ x = x.transpose(1, 3)
877
+ x = self.bn0(x)
878
+ x = x.transpose(1, 3)
879
+ if self.training:
880
+ x = self.spec_augmenter(x)
881
+
882
+ if self.training and mixup_lambda is not None:
883
+ x = do_mixup(x, mixup_lambda)
884
+
885
+ x = self.reshape_wav2img(x)
886
+ output_dict = self.forward_features(x)
887
+ else:
888
+ longer_list = x["longer"].to(device=device, non_blocking=True)
889
+ x = x["mel_fusion"].to(device=device, non_blocking=True)
890
+ x = x.transpose(1, 3)
891
+ x = self.bn0(x)
892
+ x = x.transpose(1, 3)
893
+ longer_list_idx = torch.where(longer_list)[0]
894
+ if self.fusion_type in ['daf_1d','aff_1d','iaff_1d']:
895
+ new_x = x[:,0:1,:,:].clone().contiguous()
896
+ if len(longer_list_idx) > 0:
897
+ # local processing
898
+ fusion_x_local = x[longer_list_idx,1:,:,:].clone().contiguous()
899
+ FB,FC,FT,FF = fusion_x_local.size()
900
+ fusion_x_local = fusion_x_local.view(FB * FC, FT, FF)
901
+ fusion_x_local = torch.permute(fusion_x_local, (0,2,1)).contiguous()
902
+ fusion_x_local = self.mel_conv1d(fusion_x_local)
903
+ fusion_x_local = fusion_x_local.view(FB,FC,FF,fusion_x_local.size(-1))
904
+ fusion_x_local = torch.permute(fusion_x_local, (0,2,1,3)).contiguous().flatten(2)
905
+ if fusion_x_local.size(-1) < FT:
906
+ fusion_x_local = torch.cat([fusion_x_local, torch.zeros((FB,FF,FT- fusion_x_local.size(-1)), device=device)], dim=-1)
907
+ else:
908
+ fusion_x_local = fusion_x_local[:,:,:FT]
909
+ # 1D fusion
910
+ new_x = new_x.squeeze(1).permute((0,2,1)).contiguous()
911
+ new_x[longer_list_idx] = self.fusion_model(new_x[longer_list_idx], fusion_x_local)
912
+ x = new_x.permute((0,2,1)).contiguous()[:,None,:,:]
913
+ else:
914
+ x = new_x
915
+
916
+ elif self.fusion_type in ['daf_2d','aff_2d','iaff_2d','channel_map']:
917
+ x = x # no change
918
+
919
+ if self.training:
920
+ x = self.spec_augmenter(x)
921
+ if self.training and mixup_lambda is not None:
922
+ x = do_mixup(x, mixup_lambda)
923
+
924
+ x = self.reshape_wav2img(x)
925
+ output_dict = self.forward_features(x, longer_idx = longer_list_idx)
926
+
927
+ # if infer_mode:
928
+ # # in infer mode. we need to handle different length audio input
929
+ # frame_num = x.shape[2]
930
+ # target_T = int(self.spec_size * self.freq_ratio)
931
+ # repeat_ratio = math.floor(target_T / frame_num)
932
+ # x = x.repeat(repeats=(1,1,repeat_ratio,1))
933
+ # x = self.reshape_wav2img(x)
934
+ # output_dict = self.forward_features(x)
935
+ # else:
936
+ # if x.shape[2] > self.freq_ratio * self.spec_size:
937
+ # if self.training:
938
+ # x = self.crop_wav(x, crop_size=self.freq_ratio * self.spec_size)
939
+ # x = self.reshape_wav2img(x)
940
+ # output_dict = self.forward_features(x)
941
+ # else:
942
+ # # Change: Hard code here
943
+ # overlap_size = (x.shape[2] - 1) // 4
944
+ # output_dicts = []
945
+ # crop_size = (x.shape[2] - 1) // 2
946
+ # for cur_pos in range(0, x.shape[2] - crop_size - 1, overlap_size):
947
+ # tx = self.crop_wav(x, crop_size = crop_size, spe_pos = cur_pos)
948
+ # tx = self.reshape_wav2img(tx)
949
+ # output_dicts.append(self.forward_features(tx))
950
+ # clipwise_output = torch.zeros_like(output_dicts[0]["clipwise_output"]).float().to(x.device)
951
+ # framewise_output = torch.zeros_like(output_dicts[0]["framewise_output"]).float().to(x.device)
952
+ # for d in output_dicts:
953
+ # clipwise_output += d["clipwise_output"]
954
+ # framewise_output += d["framewise_output"]
955
+ # clipwise_output = clipwise_output / len(output_dicts)
956
+ # framewise_output = framewise_output / len(output_dicts)
957
+ # output_dict = {
958
+ # 'framewise_output': framewise_output,
959
+ # 'clipwise_output': clipwise_output
960
+ # }
961
+ # else: # this part is typically used, and most easy one
962
+ # x = self.reshape_wav2img(x)
963
+ # output_dict = self.forward_features(x)
964
+ # x = self.head(x)
965
+
966
+ # The data is already processed in the dataloader, so only the case input_T < fixed_T needs to be handled here
967
+
968
+
969
+
970
+ return output_dict
971
+
972
+ def create_htsat_model(audio_cfg, enable_fusion=False, fusion_type='None'):
973
+ try:
974
+
975
+ assert audio_cfg.model_name in ["tiny", "base", "large"], "model name for HTS-AT is wrong!"
976
+ if audio_cfg.model_name == "tiny":
977
+ model = HTSAT_Swin_Transformer(
978
+ spec_size=256,
979
+ patch_size=4,
980
+ patch_stride=(4,4),
981
+ num_classes=audio_cfg.class_num,
982
+ embed_dim=96,
983
+ depths=[2,2,6,2],
984
+ num_heads=[4,8,16,32],
985
+ window_size=8,
986
+ config = audio_cfg,
987
+ enable_fusion = enable_fusion,
988
+ fusion_type = fusion_type
989
+ )
990
+ elif audio_cfg.model_name == "base":
991
+ model = HTSAT_Swin_Transformer(
992
+ spec_size=256,
993
+ patch_size=4,
994
+ patch_stride=(4,4),
995
+ num_classes=audio_cfg.class_num,
996
+ embed_dim=128,
997
+ depths=[2,2,12,2],
998
+ num_heads=[4,8,16,32],
999
+ window_size=8,
1000
+ config = audio_cfg,
1001
+ enable_fusion = enable_fusion,
1002
+ fusion_type = fusion_type
1003
+ )
1004
+ elif audio_cfg.model_name == "large":
1005
+ model = HTSAT_Swin_Transformer(
1006
+ spec_size=256,
1007
+ patch_size=4,
1008
+ patch_stride=(4,4),
1009
+ num_classes=audio_cfg.class_num,
1010
+ embed_dim=256,
1011
+ depths=[2,2,12,2],
1012
+ num_heads=[4,8,16,32],
1013
+ window_size=8,
1014
+ config = audio_cfg,
1015
+ enable_fusion = enable_fusion,
1016
+ fusion_type = fusion_type
1017
+ )
1018
+
1019
+ return model
1020
+ except Exception as err:
1021
+ raise RuntimeError(f'Could not create HTS-AT model "{audio_cfg.model_name}"; check that audio_cfg provides all required parameters.') from err
1022
+
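
For orientation, the factory above is roughly used as sketched below. This is an illustrative usage example, not part of the diff: the import path matches the file added in this commit, the config fields mirror the CLAPAudioCfp defaults in model.py further down, and running it assumes the file's own dependencies (torchlibrosa, the local utils helpers, timm) are installed.

from types import SimpleNamespace
import torch
from ldm.modules.encoders.open_clap.htsat import create_htsat_model

# hypothetical config object; the fields are the ones HTSAT_Swin_Transformer reads
audio_cfg = SimpleNamespace(
    model_name="tiny", class_num=527, sample_rate=48000,
    window_size=1024, hop_size=1024, mel_bins=64, fmin=50, fmax=14000,
)
model = create_htsat_model(audio_cfg).eval()

wav = torch.randn(2, 480000)                  # two 10 s clips at 48 kHz
with torch.no_grad():
    out = model({"waveform": wav})            # enable_fusion=False path of forward()
print(out["clipwise_output"].shape)           # (2, 527) sigmoid event probabilities
print(out["embedding"].shape)                 # (2, 768) clip embedding consumed by CLAP
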
ldm/modules/encoders/open_clap/linear_probe.py ADDED
@@ -0,0 +1,63 @@
1
+ import numpy as np
2
+ import torch.nn.functional as F
3
+ from torch import nn
4
+ from .model import MLPLayers
5
+
6
+
7
+ class LinearProbe(nn.Module):
8
+ def __init__(self, model, mlp, freeze, in_ch, out_ch, act=None):
9
+ """
10
+ Args:
11
+ model: nn.Module
12
+ mlp: bool, if True, then use the MLP layer as the linear probe module
13
+ freeze: bool, if True, freeze all the CLAP model's layers when training the linear probe
14
+ in_ch: int, the output channel from CLAP model
15
+ out_ch: int, the output channel from linear probe (class_num)
16
+ act: torch.nn.functional, the activation function before the loss function
17
+ """
18
+ super().__init__()
19
+ in_ch = 512
20
+ self.clap_model = model
21
+ self.clap_model.text_branch = None # to save memory
22
+ self.freeze = freeze
23
+ if mlp:
24
+ self.lp_layer = MLPLayers(units=[in_ch, in_ch * 2, out_ch])
25
+ else:
26
+ self.lp_layer = nn.Linear(in_ch, out_ch)
27
+
28
+ if self.freeze:
29
+ for param in self.clap_model.parameters():
30
+ param.requires_grad = False
31
+
32
+ if act == 'None':
33
+ self.act = None
34
+ elif act == 'relu':
35
+ self.act = nn.ReLU()
36
+ elif act == 'elu':
37
+ self.act = nn.ELU()
38
+ elif act == 'prelu':
39
+ self.act = nn.PReLU(num_parameters=in_ch)
40
+ elif act == 'softmax':
41
+ self.act = nn.Softmax(dim=-1)
42
+ elif act == 'sigmoid':
43
+ self.act = nn.Sigmoid()
44
+
45
+ def forward(self, x, mix_lambda=None, device=None):
46
+ """
47
+ Args:
48
+ x: waveform, torch.tensor [batch, t_samples] / batch of mel_spec and longer list
49
+ mix_lambda: torch.tensor [batch], the mixup lambda
50
+ Returns:
51
+ class_prob: torch.tensor [batch, class_num]
52
+
53
+ """
54
+ # keep the frozen CLAP model in eval mode so BatchNorm running stats (and dropout) are not updated
55
+ if self.freeze:
56
+ self.clap_model.eval()
57
+
58
+ x = self.clap_model.audio_projection(
59
+ self.clap_model.audio_branch(x, mixup_lambda=mix_lambda, device=device)["embedding"])
60
+ out = self.lp_layer(x)
61
+ if self.act is not None:
62
+ out = self.act(out)
63
+ return out
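
LinearProbe above attaches a small classification head to a (typically frozen) CLAP audio branch. Because instantiating the full CLAP model is heavy, the sketch below illustrates the same probing pattern with a stand-in encoder; only the freezing and head logic mirror the class above, everything else is a placeholder.

import torch
from torch import nn

class DummyEncoder(nn.Module):                 # stand-in for clap_model.audio_branch + projection
    def __init__(self, dim=512):
        super().__init__()
        self.proj = nn.Linear(64, dim)
    def forward(self, x):
        return self.proj(x).mean(dim=1)        # (B, T, 64) -> (B, 512)

encoder, out_ch = DummyEncoder(), 10
for p in encoder.parameters():                 # freeze=True behaviour
    p.requires_grad = False
encoder.eval()                                 # keep the frozen backbone's statistics fixed
head = nn.Linear(512, out_ch)                  # mlp=False branch; mlp=True would use MLPLayers

logits = head(encoder(torch.randn(4, 100, 64)))
print(logits.shape)                                          # torch.Size([4, 10])
print(any(p.requires_grad for p in encoder.parameters()))    # False
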
ldm/modules/encoders/open_clap/loss.py ADDED
@@ -0,0 +1,307 @@
1
+ from multiprocessing.sharedctypes import Value
2
+ import torch
3
+ import torch.distributed.nn
4
+ from torch import distributed as dist, nn as nn
5
+ from torch.nn import functional as F
6
+ import numpy as np
7
+ from sklearn.metrics import average_precision_score, roc_auc_score, accuracy_score
8
+
9
+ try:
10
+ import horovod.torch as hvd
11
+ except ImportError:
12
+ hvd = None
13
+
14
+
15
+ def gather_features(
16
+ audio_features,
17
+ text_features,
18
+ audio_features_mlp=None,
19
+ text_features_mlp=None,
20
+ local_loss=False,
21
+ gather_with_grad=False,
22
+ rank=0,
23
+ world_size=1,
24
+ use_horovod=False,
25
+ mlp_loss=False
26
+ ):
27
+ if use_horovod:
28
+ assert hvd is not None, 'Please install horovod'
29
+ if gather_with_grad:
30
+ all_audio_features = hvd.allgather(audio_features)
31
+ all_text_features = hvd.allgather(text_features)
32
+ if mlp_loss:
33
+ all_audio_features_mlp = hvd.allgather(audio_features_mlp)
34
+ all_text_features_mlp = hvd.allgather(text_features_mlp)
35
+ else:
36
+ with torch.no_grad():
37
+ all_audio_features = hvd.allgather(audio_features)
38
+ all_text_features = hvd.allgather(text_features)
39
+ if mlp_loss:
40
+ all_audio_features_mlp = hvd.allgather(audio_features_mlp)
41
+ all_text_features_mlp = hvd.allgather(text_features_mlp)
42
+ if not local_loss:
43
+ # ensure grads for local rank when all_* features don't have a gradient
44
+ gathered_audio_features = list(all_audio_features.chunk(world_size, dim=0))
45
+ gathered_text_features = list(all_text_features.chunk(world_size, dim=0))
46
+ gathered_audio_features[rank] = audio_features
47
+ gathered_text_features[rank] = text_features
48
+ all_audio_features = torch.cat(gathered_audio_features, dim=0)
49
+ all_text_features = torch.cat(gathered_text_features, dim=0)
50
+ if mlp_loss:
51
+ gathered_audio_features_mlp = list(all_audio_features_mlp.chunk(world_size, dim=0))
52
+ gathered_text_features_mlp = list(all_text_features_mlp.chunk(world_size, dim=0))
53
+ gathered_audio_features_mlp[rank] = audio_features_mlp
54
+ gathered_text_features_mlp[rank] = text_features_mlp
55
+ all_audio_features_mlp = torch.cat(gathered_audio_features_mlp, dim=0)
56
+ all_text_features_mlp = torch.cat(gathered_text_features_mlp, dim=0)
57
+ else:
58
+ # We gather tensors from all gpus
59
+ if gather_with_grad:
60
+ all_audio_features = torch.cat(torch.distributed.nn.all_gather(audio_features), dim=0)
61
+ all_text_features = torch.cat(torch.distributed.nn.all_gather(text_features), dim=0)
62
+ if mlp_loss:
63
+ all_audio_features_mlp = torch.cat(torch.distributed.nn.all_gather(audio_features_mlp), dim=0)
64
+ all_text_features_mlp = torch.cat(torch.distributed.nn.all_gather(text_features_mlp), dim=0)
65
+ else:
66
+ gathered_audio_features = [torch.zeros_like(audio_features) for _ in range(world_size)]
67
+ gathered_text_features = [torch.zeros_like(text_features) for _ in range(world_size)]
68
+ dist.all_gather(gathered_audio_features, audio_features)
69
+ dist.all_gather(gathered_text_features, text_features)
70
+ if mlp_loss:
71
+ gathered_audio_features_mlp = [torch.zeros_like(audio_features_mlp) for _ in range(world_size)]
72
+ gathered_text_features_mlp = [torch.zeros_like(text_features_mlp) for _ in range(world_size)]
73
+ dist.all_gather(gathered_audio_features_mlp, audio_features_mlp)
74
+ dist.all_gather(gathered_text_features_mlp, text_features_mlp)
75
+ if not local_loss:
76
+ # ensure grads for local rank when all_* features don't have a gradient
77
+ gathered_audio_features[rank] = audio_features
78
+ gathered_text_features[rank] = text_features
79
+ if mlp_loss:
80
+ gathered_audio_features_mlp[rank] = audio_features_mlp
81
+ gathered_text_features_mlp[rank] = text_features_mlp
82
+
83
+ all_audio_features = torch.cat(gathered_audio_features, dim=0)
84
+ all_text_features = torch.cat(gathered_text_features, dim=0)
85
+ if mlp_loss:
86
+ all_audio_features_mlp = torch.cat(gathered_audio_features_mlp, dim=0)
87
+ all_text_features_mlp = torch.cat(gathered_text_features_mlp, dim=0)
88
+ if mlp_loss:
89
+ return all_audio_features, all_text_features, all_audio_features_mlp, all_text_features_mlp
90
+ else:
91
+ return all_audio_features, all_text_features
92
+
93
+ class ClipLoss(nn.Module):
94
+
95
+ def __init__(
96
+ self,
97
+ local_loss=False,
98
+ gather_with_grad=False,
99
+ cache_labels=False,
100
+ rank=0,
101
+ world_size=1,
102
+ use_horovod=False,
103
+ mlp_loss=False,
104
+ weight_loss_kappa=0,
105
+ ):
106
+ super().__init__()
107
+ self.local_loss = local_loss
108
+ self.gather_with_grad = gather_with_grad
109
+ self.cache_labels = cache_labels
110
+ self.rank = rank
111
+ self.world_size = world_size
112
+ self.use_horovod = use_horovod
113
+ self.mlp_loss = mlp_loss
114
+ self.weighted_loss = bool(weight_loss_kappa!=0)
115
+ self.weight_loss_kappa = weight_loss_kappa
116
+ # cache state
117
+ self.prev_num_logits = 0
118
+ self.labels = {}
119
+
120
+ def forward(self, audio_features, text_features, logit_scale_a, logit_scale_t=None, audio_features_mlp=None, text_features_mlp=None):
121
+ device = audio_features.device
122
+ if self.mlp_loss:
123
+ if self.world_size > 1:
124
+ all_audio_features, all_text_features, all_audio_features_mlp, all_text_features_mlp = gather_features(
125
+ audio_features=audio_features,text_features=text_features,
126
+ audio_features_mlp=audio_features_mlp,text_features_mlp=text_features_mlp,
127
+ local_loss=self.local_loss,gather_with_grad=self.gather_with_grad,
128
+ rank=self.rank,world_size=self.world_size,use_horovod=self.use_horovod,
129
+ mlp_loss=self.mlp_loss
130
+ )
131
+ if self.local_loss:
132
+ a_logits_per_audio = logit_scale_a * audio_features @ all_text_features_mlp.T
133
+ a_logits_per_text = logit_scale_a * text_features_mlp @ all_audio_features.T
134
+ t_logits_per_audio = logit_scale_t * audio_features_mlp @ all_text_features.T
135
+ t_logits_per_text = logit_scale_t * text_features @ all_audio_features_mlp.T
136
+ else:
137
+ a_logits_per_audio = logit_scale_a * all_audio_features @ all_text_features_mlp.T
138
+ a_logits_per_text = a_logits_per_audio.T
139
+ t_logits_per_audio = logit_scale_t * all_audio_features_mlp @ all_text_features.T
140
+ t_logits_per_text = t_logits_per_audio.T
141
+ else:
142
+ a_logits_per_audio = logit_scale_a * audio_features @ text_features_mlp.T
143
+ a_logits_per_text = logit_scale_a * text_features_mlp @ audio_features.T
144
+ t_logits_per_audio = logit_scale_t * audio_features_mlp @ text_features.T
145
+ t_logits_per_text = logit_scale_t * text_features @ audio_features_mlp.T
146
+
147
+ # calculate ground-truth labels and cache them if enabled
148
+ num_logits = a_logits_per_audio.shape[0]
149
+ if self.prev_num_logits != num_logits or device not in self.labels:
150
+ labels = torch.arange(num_logits, device=device, dtype=torch.long)
151
+ if self.world_size > 1 and self.local_loss:
152
+ labels = labels + num_logits * self.rank
153
+ if self.cache_labels:
154
+ self.labels[device] = labels
155
+ self.prev_num_logits = num_logits
156
+ else:
157
+ labels = self.labels[device]
158
+
159
+ if not self.weighted_loss:
160
+ total_loss = (
161
+ F.cross_entropy(a_logits_per_audio, labels) +
162
+ F.cross_entropy(a_logits_per_text, labels) +
163
+ F.cross_entropy(t_logits_per_audio, labels) +
164
+ F.cross_entropy(t_logits_per_text, labels)
165
+ ) / 4
166
+ else:
167
+ audio_weight = (audio_features@audio_features.T).detach()
168
+ audio_weight = (torch.exp(torch.sum(audio_weight, axis=1)/(self.weight_loss_kappa*len(audio_weight)))).detach()
169
+ text_weight = (text_features@text_features.T).detach()
170
+ text_weight = (torch.exp(torch.sum(text_weight, axis=1)/(self.weight_loss_kappa*len(text_features)))).detach()
171
+ total_loss = (
172
+ F.cross_entropy(a_logits_per_audio, labels, weight=audio_weight) +
173
+ F.cross_entropy(a_logits_per_text, labels, weight=audio_weight) +
174
+ F.cross_entropy(t_logits_per_audio, labels, weight=text_weight) +
175
+ F.cross_entropy(t_logits_per_text, labels, weight=text_weight)
176
+ ) / 4
177
+ else:
178
+ if self.world_size > 1:
179
+ all_audio_features, all_text_features = gather_features(
180
+ audio_features=audio_features,text_features=text_features,
181
+ local_loss=self.local_loss,gather_with_grad=self.gather_with_grad,
182
+ rank=self.rank,world_size=self.world_size,use_horovod=self.use_horovod,
183
+ mlp_loss=self.mlp_loss
184
+ )
185
+
186
+ if self.local_loss:
187
+ logits_per_audio = logit_scale_a * audio_features @ all_text_features.T
188
+ logits_per_text = logit_scale_a * text_features @ all_audio_features.T
189
+ else:
190
+ logits_per_audio = logit_scale_a * all_audio_features @ all_text_features.T
191
+ logits_per_text = logits_per_audio.T
192
+ else:
193
+ logits_per_audio = logit_scale_a * audio_features @ text_features.T
194
+ logits_per_text = logit_scale_a * text_features @ audio_features.T
195
+
196
+ # calculate ground-truth labels and cache them if enabled
197
+ num_logits = logits_per_audio.shape[0]
198
+ if self.prev_num_logits != num_logits or device not in self.labels:
199
+ labels = torch.arange(num_logits, device=device, dtype=torch.long)
200
+ if self.world_size > 1 and self.local_loss:
201
+ labels = labels + num_logits * self.rank
202
+ if self.cache_labels:
203
+ self.labels[device] = labels
204
+ self.prev_num_logits = num_logits
205
+ else:
206
+ labels = self.labels[device]
207
+ if not self.weighted_loss:
208
+ total_loss = (
209
+ F.cross_entropy(logits_per_audio, labels) +
210
+ F.cross_entropy(logits_per_text, labels)
211
+ ) / 2
212
+ else:
213
+ audio_weight = (all_audio_features@all_audio_features.T).detach()
214
+ audio_weight = (torch.exp(torch.sum(audio_weight, axis=1)/(self.weight_loss_kappa*len(all_audio_features)))).detach()
215
+ text_weight = (all_text_features@all_text_features.T).detach()
216
+ text_weight = (torch.exp(torch.sum(text_weight, axis=1)/(self.weight_loss_kappa*len(all_text_features)))).detach()
217
+ total_loss = (
218
+ F.cross_entropy(logits_per_audio, labels, weight=text_weight) +
219
+ F.cross_entropy(logits_per_text, labels, weight=audio_weight)
220
+ ) / 2
221
+ return total_loss
222
+
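
Stripped of the distributed gathering, weighting, and MLP variants, ClipLoss.forward reduces to the usual symmetric InfoNCE objective: the i-th audio clip and the i-th caption are the only positive pair in the batch. A minimal single-process sketch (world_size=1, mlp_loss=False, no weighting):

import torch
import torch.nn.functional as F

B, D = 8, 512
audio_features = F.normalize(torch.randn(B, D), dim=-1)
text_features = F.normalize(torch.randn(B, D), dim=-1)
logit_scale_a = torch.tensor(100.0)           # exp of the learnable temperature in the real model

logits_per_audio = logit_scale_a * audio_features @ text_features.T
logits_per_text = logit_scale_a * text_features @ audio_features.T
labels = torch.arange(B)                      # positives sit on the diagonal

total_loss = (F.cross_entropy(logits_per_audio, labels) +
              F.cross_entropy(logits_per_text, labels)) / 2
print(total_loss.item())
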
223
+ def lp_gather_features(
224
+ pred,
225
+ target,
226
+ world_size=1,
227
+ use_horovod=False
228
+ ):
229
+ if use_horovod:
230
+ assert hvd is not None, 'Please install horovod'
231
+ with torch.no_grad():
232
+ all_preds = hvd.allgather(pred)
233
+ all_targets = hvd.allgather(target)
234
+ else:
235
+ gathered_preds = [torch.zeros_like(pred) for _ in range(world_size)]
236
+ gathered_targets = [torch.zeros_like(target) for _ in range(world_size)]
237
+
238
+ dist.all_gather(gathered_preds, pred)
239
+ dist.all_gather(gathered_targets, target)
240
+ all_preds = torch.cat(gathered_preds, dim=0)
241
+ all_targets = torch.cat(gathered_targets, dim=0)
242
+
243
+ return all_preds, all_targets
244
+
245
+
246
+ def get_map(pred, target):
247
+ pred = torch.sigmoid(pred).numpy()
248
+ target = target.numpy()
249
+ return np.mean(average_precision_score(target, pred, average=None))
250
+
251
+ def get_acc(pred, target):
252
+ pred = torch.argmax(pred,1).numpy()
253
+ target = torch.argmax(target,1).numpy()
254
+ return accuracy_score(target, pred)
255
+
256
+ def get_mauc(pred, target):
257
+ pred = torch.sigmoid(pred).numpy()
258
+ target = target.numpy()
259
+ return np.mean(roc_auc_score(target, pred, average=None))
260
+
261
+
262
+ class LPMetrics(object):
263
+ def __init__(self, metric_names = ['map','acc','mauc']):
264
+ self.metrics = []
265
+ for name in metric_names:
266
+ self.metrics.append(self.get_metric(name))
267
+ self.metric_names = metric_names
268
+
269
+ def get_metric(self,name):
270
+ if name == 'map':
271
+ return get_map
272
+ elif name == 'acc':
273
+ return get_acc
274
+ elif name == 'mauc':
275
+ return get_mauc
276
+ else:
277
+ raise ValueError('the metric should be one of [map, acc, mauc]')
278
+
279
+ def evaluate_mertics(self, pred, target):
280
+ metric_dict = {}
281
+ for i in range(len(self.metric_names)):
282
+ metric_dict[self.metric_names[i]] = self.metrics[i](pred, target)
283
+ return metric_dict
284
+
285
+
286
+ def calc_celoss(pred, target):
287
+ target = torch.argmax(target, 1).long()
288
+ return nn.CrossEntropyLoss()(pred, target)
289
+
290
+
291
+ class LPLoss(nn.Module):
292
+
293
+ def __init__(self, loss_name):
294
+ super().__init__()
295
+ if loss_name == 'bce':
296
+ self.loss_func = nn.BCEWithLogitsLoss()
297
+ elif loss_name == 'ce':
298
+ self.loss_func = calc_celoss
299
+ elif loss_name == 'mse':
300
+ self.loss_func = nn.MSELoss()
301
+ else:
302
+ raise ValueError('the loss func should be one of [bce, ce, mse]')
303
+
304
+ def forward(self, pred, target):
305
+ loss = self.loss_func(pred, target)
306
+ return loss
307
+
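
The metrics wrapped by LPMetrics are thin conveniences over scikit-learn; for multi-label targets the per-class scores are averaged. A small self-contained example with random data (requires scikit-learn; scores hover around 0.5 because the predictions are random):

import numpy as np
import torch
from sklearn.metrics import average_precision_score, roc_auc_score

pred = torch.sigmoid(torch.randn(64, 5))      # logits -> probabilities for 5 classes
target = (torch.rand(64, 5) > 0.5).float()    # random multi-hot ground truth

mAP = np.mean(average_precision_score(target.numpy(), pred.numpy(), average=None))
mauc = np.mean(roc_auc_score(target.numpy(), pred.numpy(), average=None))
print(f"mAP={mAP:.3f}  mAUC={mauc:.3f}")
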
ldm/modules/encoders/open_clap/model.py ADDED
@@ -0,0 +1,913 @@
1
+ """ CLAP Model
2
+
3
+ Adapted from CLIP: https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI.
4
+ Adapted to the Audio Task.
5
+ """
6
+
7
+ from collections import OrderedDict
8
+ from dataclasses import dataclass
9
+ from email.mime import audio
10
+ from typing import Tuple, Union, Callable, Optional
11
+
12
+ import numpy as np
13
+ import torch
14
+ import torch.nn.functional as F
15
+ from torch import nn
16
+
17
+ from .timm_model import TimmModel
18
+ import logging
19
+ from .utils import freeze_batch_norm_2d
20
+
21
+ from .pann_model import create_pann_model
22
+ from .htsat import create_htsat_model
23
+ from transformers import BertModel, RobertaModel, BartModel
24
+ from transformers.tokenization_utils_base import BatchEncoding
25
+
26
+
27
+ class MLPLayers(nn.Module):
28
+ def __init__(self, units=[512, 512, 512], nonlin=nn.ReLU(), dropout=0.1):
29
+ super(MLPLayers, self).__init__()
30
+ self.nonlin = nonlin
31
+ self.dropout = dropout
32
+
33
+ sequence = []
34
+ for u0, u1 in zip(units[:-1], units[1:]):
35
+ sequence.append(nn.Linear(u0, u1))
36
+ sequence.append(self.nonlin)
37
+ sequence.append(nn.Dropout(self.dropout))
38
+ sequence = sequence[:-2]
39
+
40
+ self.sequential = nn.Sequential(*sequence)
41
+
42
+ def forward(self, X):
43
+ X = self.sequential(X)
44
+ return X
45
+
46
+
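
Note that the MLPLayers constructor above trims the trailing activation and dropout (sequence = sequence[:-2]), so the final projection is purely linear. A quick standalone check of the stack built for units=[512, 512, 512]:

from torch import nn

units, nonlin, dropout = [512, 512, 512], nn.ReLU(), 0.1
sequence = []
for u0, u1 in zip(units[:-1], units[1:]):
    sequence += [nn.Linear(u0, u1), nonlin, nn.Dropout(dropout)]
sequence = sequence[:-2]          # drop the final activation + dropout, as in the class above
print(nn.Sequential(*sequence))   # Linear -> ReLU -> Dropout -> Linear
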
47
+ class Bottleneck(nn.Module):
48
+ expansion = 4
49
+
50
+ def __init__(self, inplanes, planes, stride=1):
51
+ super().__init__()
52
+
53
+ # all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1
54
+ self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False)
55
+ self.bn1 = nn.BatchNorm2d(planes)
56
+
57
+ self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False)
58
+ self.bn2 = nn.BatchNorm2d(planes)
59
+
60
+ self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity()
61
+
62
+ self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False)
63
+ self.bn3 = nn.BatchNorm2d(planes * self.expansion)
64
+
65
+ self.relu = nn.ReLU(inplace=True)
66
+ self.downsample = None
67
+ self.stride = stride
68
+
69
+ if stride > 1 or inplanes != planes * Bottleneck.expansion:
70
+ # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1
71
+ self.downsample = nn.Sequential(
72
+ OrderedDict(
73
+ [
74
+ ("-1", nn.AvgPool2d(stride)),
75
+ (
76
+ "0",
77
+ nn.Conv2d(
78
+ inplanes,
79
+ planes * self.expansion,
80
+ 1,
81
+ stride=1,
82
+ bias=False,
83
+ ),
84
+ ),
85
+ ("1", nn.BatchNorm2d(planes * self.expansion)),
86
+ ]
87
+ )
88
+ )
89
+
90
+ def forward(self, x: torch.Tensor):
91
+ identity = x
92
+
93
+ out = self.relu(self.bn1(self.conv1(x)))
94
+ out = self.relu(self.bn2(self.conv2(out)))
95
+ out = self.avgpool(out)
96
+ out = self.bn3(self.conv3(out))
97
+
98
+ if self.downsample is not None:
99
+ identity = self.downsample(x)
100
+
101
+ out += identity
102
+ out = self.relu(out)
103
+ return out
104
+
105
+
106
+ class AttentionPool2d(nn.Module):
107
+ def __init__(
108
+ self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None
109
+ ):
110
+ super().__init__()
111
+ self.positional_embedding = nn.Parameter(
112
+ torch.randn(spacial_dim**2 + 1, embed_dim) / embed_dim**0.5
113
+ )
114
+ self.k_proj = nn.Linear(embed_dim, embed_dim)
115
+ self.q_proj = nn.Linear(embed_dim, embed_dim)
116
+ self.v_proj = nn.Linear(embed_dim, embed_dim)
117
+ self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim)
118
+ self.num_heads = num_heads
119
+
120
+ def forward(self, x):
121
+ x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3]).permute(
122
+ 2, 0, 1
123
+ ) # NCHW -> (HW)NC
124
+ x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC
125
+ x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC
126
+ x, _ = F.multi_head_attention_forward(
127
+ query=x,
128
+ key=x,
129
+ value=x,
130
+ embed_dim_to_check=x.shape[-1],
131
+ num_heads=self.num_heads,
132
+ q_proj_weight=self.q_proj.weight,
133
+ k_proj_weight=self.k_proj.weight,
134
+ v_proj_weight=self.v_proj.weight,
135
+ in_proj_weight=None,
136
+ in_proj_bias=torch.cat(
137
+ [self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]
138
+ ),
139
+ bias_k=None,
140
+ bias_v=None,
141
+ add_zero_attn=False,
142
+ dropout_p=0,
143
+ out_proj_weight=self.c_proj.weight,
144
+ out_proj_bias=self.c_proj.bias,
145
+ use_separate_proj_weight=True,
146
+ training=self.training,
147
+ need_weights=False,
148
+ )
149
+
150
+ return x[0]
151
+
152
+
153
+ class ModifiedResNet(nn.Module):
154
+ """
155
+ A ResNet class that is similar to torchvision's but contains the following changes:
156
+ - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool.
157
+ - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1
158
+ - The final pooling layer is a QKV attention instead of an average pool
159
+ """
160
+
161
+ def __init__(self, layers, output_dim, heads, image_size=224, width=64):
162
+ super().__init__()
163
+ self.output_dim = output_dim
164
+ self.image_size = image_size
165
+
166
+ # the 3-layer stem
167
+ self.conv1 = nn.Conv2d(
168
+ 3, width // 2, kernel_size=3, stride=2, padding=1, bias=False
169
+ )
170
+ self.bn1 = nn.BatchNorm2d(width // 2)
171
+ self.conv2 = nn.Conv2d(
172
+ width // 2, width // 2, kernel_size=3, padding=1, bias=False
173
+ )
174
+ self.bn2 = nn.BatchNorm2d(width // 2)
175
+ self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False)
176
+ self.bn3 = nn.BatchNorm2d(width)
177
+ self.avgpool = nn.AvgPool2d(2)
178
+ self.relu = nn.ReLU(inplace=True)
179
+
180
+ # residual layers
181
+ self._inplanes = width # this is a *mutable* variable used during construction
182
+ self.layer1 = self._make_layer(width, layers[0])
183
+ self.layer2 = self._make_layer(width * 2, layers[1], stride=2)
184
+ self.layer3 = self._make_layer(width * 4, layers[2], stride=2)
185
+ self.layer4 = self._make_layer(width * 8, layers[3], stride=2)
186
+
187
+ embed_dim = width * 32 # the ResNet feature dimension
188
+ self.attnpool = AttentionPool2d(image_size // 32, embed_dim, heads, output_dim)
189
+
190
+ self.init_parameters()
191
+
192
+ def _make_layer(self, planes, blocks, stride=1):
193
+ layers = [Bottleneck(self._inplanes, planes, stride)]
194
+
195
+ self._inplanes = planes * Bottleneck.expansion
196
+ for _ in range(1, blocks):
197
+ layers.append(Bottleneck(self._inplanes, planes))
198
+
199
+ return nn.Sequential(*layers)
200
+
201
+ def init_parameters(self):
202
+ if self.attnpool is not None:
203
+ std = self.attnpool.c_proj.in_features**-0.5
204
+ nn.init.normal_(self.attnpool.q_proj.weight, std=std)
205
+ nn.init.normal_(self.attnpool.k_proj.weight, std=std)
206
+ nn.init.normal_(self.attnpool.v_proj.weight, std=std)
207
+ nn.init.normal_(self.attnpool.c_proj.weight, std=std)
208
+
209
+ for resnet_block in [self.layer1, self.layer2, self.layer3, self.layer4]:
210
+ for name, param in resnet_block.named_parameters():
211
+ if name.endswith("bn3.weight"):
212
+ nn.init.zeros_(param)
213
+
214
+ def lock(self, unlocked_groups=0, freeze_bn_stats=False):
215
+ assert (
216
+ unlocked_groups == 0
217
+ ), "partial locking not currently supported for this model"
218
+ for param in self.parameters():
219
+ param.requires_grad = False
220
+ if freeze_bn_stats:
221
+ freeze_batch_norm_2d(self)
222
+
223
+ def stem(self, x):
224
+ for conv, bn in [
225
+ (self.conv1, self.bn1),
226
+ (self.conv2, self.bn2),
227
+ (self.conv3, self.bn3),
228
+ ]:
229
+ x = self.relu(bn(conv(x)))
230
+ x = self.avgpool(x)
231
+ return x
232
+
233
+ def forward(self, x):
234
+ x = self.stem(x)
235
+ x = self.layer1(x)
236
+ x = self.layer2(x)
237
+ x = self.layer3(x)
238
+ x = self.layer4(x)
239
+ x = self.attnpool(x)
240
+
241
+ return x
242
+
243
+
244
+ class LayerNorm(nn.LayerNorm):
245
+ """Subclass torch's LayerNorm to handle fp16."""
246
+
247
+ def forward(self, x: torch.Tensor):
248
+ orig_type = x.dtype
249
+ x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
250
+ return x.to(orig_type)
251
+
252
+
253
+ class QuickGELU(nn.Module):
254
+ # NOTE This is slower than nn.GELU or nn.SiLU and uses more GPU memory
255
+ def forward(self, x: torch.Tensor):
256
+ return x * torch.sigmoid(1.702 * x)
257
+
258
+
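
QuickGELU is the sigmoid-based approximation used when the original CLIP weights were trained; it matters only for checkpoints trained with it, since it deviates slightly from the exact GELU. A quick numerical comparison:

import torch
import torch.nn.functional as F

x = torch.linspace(-3, 3, 7)
quick = x * torch.sigmoid(1.702 * x)
exact = F.gelu(x)
print(torch.max(torch.abs(quick - exact)).item())  # on the order of 1e-2
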
259
+ class ResidualAttentionBlock(nn.Module):
260
+ def __init__(self, d_model: int, n_head: int, act_layer: Callable = nn.GELU):
261
+ super().__init__()
262
+
263
+ self.attn = nn.MultiheadAttention(d_model, n_head)
264
+ self.ln_1 = LayerNorm(d_model)
265
+ self.mlp = nn.Sequential(
266
+ OrderedDict(
267
+ [
268
+ ("c_fc", nn.Linear(d_model, d_model * 4)),
269
+ ("gelu", act_layer()),
270
+ ("c_proj", nn.Linear(d_model * 4, d_model)),
271
+ ]
272
+ )
273
+ )
274
+ self.ln_2 = LayerNorm(d_model)
275
+
276
+ def attention(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None):
277
+ return self.attn(x, x, x, need_weights=False, attn_mask=attn_mask)[0]
278
+
279
+ def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None):
280
+ x = x + self.attention(self.ln_1(x), attn_mask=attn_mask)
281
+ x = x + self.mlp(self.ln_2(x))
282
+ return x
283
+
284
+
285
+ class Transformer(nn.Module):
286
+ def __init__(
287
+ self, width: int, layers: int, heads: int, act_layer: Callable = nn.GELU
288
+ ):
289
+ super().__init__()
290
+ self.width = width
291
+ self.layers = layers
292
+ self.resblocks = nn.ModuleList(
293
+ [
294
+ ResidualAttentionBlock(width, heads, act_layer=act_layer)
295
+ for _ in range(layers)
296
+ ]
297
+ )
298
+
299
+ def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None):
300
+ for r in self.resblocks:
301
+ x = r(x, attn_mask=attn_mask)
302
+ return x
303
+
304
+
305
+ class VisualTransformer(nn.Module):
306
+ def __init__(
307
+ self,
308
+ image_size: int,
309
+ patch_size: int,
310
+ width: int,
311
+ layers: int,
312
+ heads: int,
313
+ output_dim: int,
314
+ act_layer: Callable = nn.GELU,
315
+ ):
316
+ super().__init__()
317
+ self.image_size = image_size
318
+ self.output_dim = output_dim
319
+ self.conv1 = nn.Conv2d(
320
+ in_channels=3,
321
+ out_channels=width,
322
+ kernel_size=patch_size,
323
+ stride=patch_size,
324
+ bias=False,
325
+ )
326
+
327
+ scale = width**-0.5
328
+ self.class_embedding = nn.Parameter(scale * torch.randn(width))
329
+ self.positional_embedding = nn.Parameter(
330
+ scale * torch.randn((image_size // patch_size) ** 2 + 1, width)
331
+ )
332
+ self.ln_pre = LayerNorm(width)
333
+
334
+ self.text_branch = Transformer(width, layers, heads, act_layer=act_layer)
335
+
336
+ self.ln_post = LayerNorm(width)
337
+ self.proj = nn.Parameter(scale * torch.randn(width, output_dim))
338
+
339
+ def lock(self, unlocked_groups=0, freeze_bn_stats=False):
340
+ assert (
341
+ unlocked_groups == 0
342
+ ), "partial locking not currently supported for this model"
343
+ for param in self.parameters():
344
+ param.requires_grad = False
345
+
346
+ def forward(self, x: torch.Tensor):
347
+ x = self.conv1(x) # shape = [*, width, grid, grid]
348
+ x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2]
349
+ x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width]
350
+ x = torch.cat(
351
+ [
352
+ self.class_embedding.to(x.dtype)
353
+ + torch.zeros(
354
+ x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device
355
+ ),
356
+ x,
357
+ ],
358
+ dim=1,
359
+ ) # shape = [*, grid ** 2 + 1, width]
360
+ x = x + self.positional_embedding.to(x.dtype)
361
+ x = self.ln_pre(x)
362
+
363
+ x = x.permute(1, 0, 2) # NLD -> LND
364
+ x = self.text_branch(x)
365
+ x = x.permute(1, 0, 2) # LND -> NLD
366
+
367
+ x = self.ln_post(x[:, 0, :])
368
+
369
+ if self.proj is not None:
370
+ x = x @ self.proj
371
+
372
+ return x
373
+
374
+
375
+ @dataclass
376
+ class CLAPVisionCfg:
377
+ layers: Union[Tuple[int, int, int, int], int] = 12
378
+ width: int = 768
379
+ patch_size: int = 16
380
+ image_size: Union[Tuple[int, int], int] = 224
381
+ timm_model_name: str = (
382
+ None # a valid model name overrides layers, width, patch_size
383
+ )
384
+ timm_model_pretrained: bool = (
385
+ False # use (imagenet) pretrained weights for named model
386
+ )
387
+ timm_pool: str = (
388
+ "avg" # feature pooling for timm model ('abs_attn', 'rot_attn', 'avg', '')
389
+ )
390
+ timm_proj: str = (
391
+ "linear" # linear projection for timm model output ('linear', 'mlp', '')
392
+ )
393
+
394
+
395
+ # Audio Config Class
396
+ @dataclass
397
+ class CLAPAudioCfp:
398
+ model_type: str = "PANN"
399
+ model_name: str = "Cnn14"
400
+ sample_rate: int = 48000
401
+ # Param
402
+ audio_length: int = 1024
403
+ window_size: int = 1024
404
+ hop_size: int = 1024
405
+ fmin: int = 50
406
+ fmax: int = 14000
407
+ class_num: int = 527
408
+ mel_bins: int = 64
409
+ clip_samples: int = 480000
410
+
411
+
412
+ @dataclass
413
+ class CLAPTextCfg:
414
+ context_length: int
415
+ vocab_size: int
416
+ width: int
417
+ heads: int
418
+ layers: int
419
+ model_type: str
420
+
421
+
422
+ class CLAP(nn.Module):
423
+ def __init__(
424
+ self,
425
+ embed_dim: int,
426
+ audio_cfg: CLAPAudioCfp,
427
+ text_cfg: CLAPTextCfg,
428
+ quick_gelu: bool = False,
429
+ enable_fusion: bool = False,
430
+ fusion_type: str = 'None',
431
+ joint_embed_shape: int = 512,
432
+ mlp_act: str = 'relu',
433
+ ):
434
+ super().__init__()
435
+ if isinstance(audio_cfg, dict):
436
+ audio_cfg = CLAPAudioCfp(**audio_cfg)
437
+ if isinstance(text_cfg, dict):
438
+ text_cfg = CLAPTextCfg(**text_cfg)
439
+
440
+ self.audio_cfg = audio_cfg
441
+ self.text_cfg = text_cfg
442
+ self.enable_fusion = enable_fusion
443
+ self.fusion_type = fusion_type
444
+ self.joint_embed_shape = joint_embed_shape
445
+ self.mlp_act = mlp_act
446
+
447
+
448
+ self.context_length = text_cfg.context_length
449
+
450
+ # OpenAI models are pretrained w/ QuickGELU but native nn.GELU is both faster and more
451
+ # memory efficient in recent PyTorch releases (>= 1.10).
452
+ # NOTE: timm models always use native GELU regardless of quick_gelu flag.
453
+ act_layer = QuickGELU if quick_gelu else nn.GELU
454
+
455
+ if mlp_act == 'relu':
456
+ mlp_act_layer = nn.ReLU()
457
+ elif mlp_act == 'gelu':
458
+ mlp_act_layer = nn.GELU()
459
+ else:
460
+ raise NotImplementedError
461
+
462
+ # audio branch
463
+ # audio branch parameters
464
+ if audio_cfg.model_type == "PANN":
465
+ self.audio_branch = create_pann_model(audio_cfg, enable_fusion, fusion_type)
466
+ elif audio_cfg.model_type == "HTSAT":
467
+ self.audio_branch = create_htsat_model(audio_cfg, enable_fusion, fusion_type)
468
+ else:
469
+ logging.error(f"Model config for {audio_cfg.model_type} not found")
470
+ raise RuntimeError(f"Model config for {audio_cfg.model_type} not found.")
471
+
472
+
473
+ # text branch
474
+ # text branch parameters
475
+ if text_cfg.model_type == "transformer":
476
+ self.text_branch = Transformer(
477
+ width=text_cfg.width,
478
+ layers=text_cfg.layers,
479
+ heads=text_cfg.heads,
480
+ act_layer=act_layer,
481
+ )
482
+ self.vocab_size = text_cfg.vocab_size
483
+ self.token_embedding = nn.Embedding(text_cfg.vocab_size, text_cfg.width)
484
+ self.positional_embedding = nn.Parameter(
485
+ torch.empty(self.context_length, text_cfg.width)
486
+ )
487
+ self.ln_final = LayerNorm(text_cfg.width)
488
+ self.text_transform = MLPLayers(units=[self.joint_embed_shape,
489
+ self.joint_embed_shape,
490
+ self.joint_embed_shape], dropout=0.1)
491
+ self.text_projection = nn.Sequential(
492
+ nn.Linear(text_cfg.width, self.joint_embed_shape),
493
+ mlp_act_layer,
494
+ nn.Linear(self.joint_embed_shape, self.joint_embed_shape)
495
+ )
496
+ elif text_cfg.model_type == "bert":
497
+ self.text_branch = BertModel.from_pretrained("bert-base-uncased")
498
+ self.text_transform = MLPLayers(units=[self.joint_embed_shape,
499
+ self.joint_embed_shape,
500
+ self.joint_embed_shape], dropout=0.1)
501
+ self.text_projection = nn.Sequential(
502
+ nn.Linear(768, self.joint_embed_shape),
503
+ mlp_act_layer,
504
+ nn.Linear(self.joint_embed_shape, self.joint_embed_shape)
505
+ )
506
+ elif text_cfg.model_type == "roberta":
507
+ self.text_branch = RobertaModel.from_pretrained('roberta-base')
508
+ self.text_transform = MLPLayers(units=[self.joint_embed_shape,
509
+ self.joint_embed_shape,
510
+ self.joint_embed_shape], dropout=0.1)
511
+ self.text_projection = nn.Sequential(
512
+ nn.Linear(768, self.joint_embed_shape),
513
+ mlp_act_layer,
514
+ nn.Linear(self.joint_embed_shape, self.joint_embed_shape)
515
+ )
516
+ elif text_cfg.model_type == "bart":
517
+ self.text_branch = BartModel.from_pretrained('facebook/bart-base')
518
+ self.text_transform = MLPLayers(units=[self.joint_embed_shape,
519
+ self.joint_embed_shape,
520
+ self.joint_embed_shape], dropout=0.1)
521
+ self.text_projection = nn.Sequential(
522
+ nn.Linear(768, self.joint_embed_shape),
523
+ mlp_act_layer,
524
+ nn.Linear(self.joint_embed_shape, self.joint_embed_shape)
525
+ )
526
+ else:
527
+ logging.error(f"Model config for {text_cfg.model_type} not found")
528
+ raise RuntimeError(f"Model config for {text_cfg.model_type} not found.")
529
+ self.text_branch_type = text_cfg.model_type
530
+ # text branch parameters
531
+
532
+ # audio branch parameters
533
+ self.audio_transform = MLPLayers(units=[self.joint_embed_shape,
534
+ self.joint_embed_shape,
535
+ self.joint_embed_shape], dropout=0.1)
536
+
537
+ # audio branch projection (maps audio embeddings into the joint embedding space)
538
+
539
+ # ============================================================================================================
540
+ self.audio_projection = nn.Sequential(
541
+ nn.Linear(embed_dim, self.joint_embed_shape),
542
+ mlp_act_layer,
543
+ nn.Linear(self.joint_embed_shape, self.joint_embed_shape)
544
+ )
545
+
546
+ self.logit_scale_a = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
547
+ self.logit_scale_t = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
548
+ self.register_buffer("attn_mask", self.build_attention_mask(), persistent=False)
549
+
550
+ self.init_text_branch_parameters()
551
+
552
+ def init_text_branch_parameters(self):
553
+ if self.text_branch_type == "transformer":
554
+ nn.init.normal_(self.token_embedding.weight, std=0.02)
555
+ nn.init.normal_(self.positional_embedding, std=0.01)
556
+ proj_std = (self.text_branch.width**-0.5) * (
557
+ (2 * self.text_branch.layers) ** -0.5
558
+ )
559
+ attn_std = self.text_branch.width**-0.5
560
+ fc_std = (2 * self.text_branch.width) ** -0.5
561
+ for block in self.text_branch.resblocks:
562
+ nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
563
+ nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
564
+ nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
565
+ nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)
566
+ if self.text_branch_type == "bert" or self.text_branch_type == "roberta":
567
+ width = self.text_branch.embeddings.word_embeddings.weight.shape[-1]
568
+ elif self.text_branch_type == "bart":
569
+ width = self.text_branch.shared.weight.shape[-1]
570
+ else:
571
+ width = self.text_branch.width
572
+ nn.init.constant_(self.logit_scale_a, np.log(1 / 0.07))
573
+ nn.init.constant_(self.logit_scale_t, np.log(1 / 0.07))
574
+
575
+ # deprecated
576
+ # if hasattr(self.visual, 'init_parameters'):
577
+ # self.visual.init_parameters()
578
+
579
+ # if self.text_projection is not None:
580
+ # nn.init.normal_(self.text_projection, std=width**-0.5)
581
+
582
+ def build_attention_mask(self):
583
+ # lazily create causal attention mask, with full attention between the vision tokens
584
+ # pytorch uses additive attention mask; fill with -inf
585
+ mask = torch.empty(self.context_length, self.context_length)
586
+ mask.fill_(float("-inf"))
587
+ mask.triu_(1) # zero out the lower diagonal
588
+ return mask
589
+
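For a quick sanity check, this is what the additive causal mask built above looks like for a toy context length of 4 (the text configs later in this file use 77): -inf strictly above the diagonal blocks attention to future tokens, and zeros elsewhere leave the scores untouched.

    import torch

    context_length = 4
    mask = torch.empty(context_length, context_length)
    mask.fill_(float("-inf"))
    mask.triu_(1)
    # tensor([[0., -inf, -inf, -inf],
    #         [0.,   0., -inf, -inf],
    #         [0.,   0.,   0., -inf],
    #         [0.,   0.,   0.,   0.]])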
590
+ def encode_audio(self, audio, device):
591
+ return self.audio_branch(audio, mixup_lambda=None, device=device)  # mixup_lambda is not wired through here yet
592
+
593
+ # def list_of_dict_of_tensor2dict_of_tensor(self, x, device):
594
+ # tmp = {}
595
+ # for k in x[0].keys():
596
+ # tmp[k] = []
597
+ # for i in range(len(x)):
598
+ # tmp[k].append(x[i][k][:77])
599
+ # for k in x[0].keys():
600
+ # tmp[k] = torch.tensor(tmp[k]).to(device=device, non_blocking=True)
601
+ # return tmp
602
+
603
+ def encode_text(self, text, device):
604
+ if self.text_branch_type == "transformer":
605
+ text = text.to(device=device, non_blocking=True)
606
+ x = self.token_embedding(text) # [batch_size, n_ctx, d_model]
607
+
608
+ x = x + self.positional_embedding
609
+ x = x.permute(1, 0, 2) # NLD -> LND
610
+ x = self.text_branch(x, attn_mask=self.attn_mask)
611
+ x = x.permute(1, 0, 2) # LND -> NLD
612
+ x = self.ln_final(x)
613
+
614
+ # x.shape = [batch_size, n_ctx, transformer.width]
615
+ # take features from the eot embedding (eot_token is the highest number in each sequence)
616
+ x = self.text_projection(x[torch.arange(x.shape[0]), text.argmax(dim=-1)])
617
+ elif self.text_branch_type == "bert":
618
+ # text = self.list_of_dict_of_tensor2dict_of_tensor(text, device)
619
+ # text = BatchEncoding(text)
620
+ x = self.text_branch(
621
+ input_ids=text["input_ids"].to(device=device, non_blocking=True),
622
+ attention_mask=text["attention_mask"].to(
623
+ device=device, non_blocking=True
624
+ ),
625
+ token_type_ids=text["token_type_ids"].to(
626
+ device=device, non_blocking=True
627
+ ),
628
+ )["pooler_output"]
629
+ x = self.text_projection(x)
630
+ elif self.text_branch_type == "roberta":
631
+ x = self.text_branch(
632
+ input_ids=text["input_ids"].to(device=device, non_blocking=True),
633
+ attention_mask=text["attention_mask"].to(
634
+ device=device, non_blocking=True
635
+ ),
636
+ )["pooler_output"]
637
+
638
+ x = self.text_projection(x)
639
+ elif self.text_branch_type == "bart":
640
+ x = torch.mean(self.text_branch(
641
+ input_ids=text["input_ids"].to(device=device, non_blocking=True),
642
+ attention_mask=text["attention_mask"].to(
643
+ device=device, non_blocking=True
644
+ ),
645
+ )["encoder_last_hidden_state"], dim=1)
646
+ x = self.text_projection(x)
647
+ else:
648
+ logging.error(f"Model type {self.text_branch_type} not found")
649
+ raise RuntimeError(f"Model type {self.text_branch_type} not found.")
650
+ return x
651
+
652
+ def forward(self, audio, text, device=None):
653
+ """Forward audio and text into the CLAP
654
+
655
+ Parameters
656
+ ----------
657
+ audio: torch.Tensor (batch_size, audio_length)
+ the time-domain audio input, or the batched mel_spec / "longer" features when fusion is enabled
+ text: torch.Tensor
+ the tokenized text input
661
+ """
662
+ if device is None:
663
+ if audio is not None:
664
+ device = audio.device
665
+ elif text is not None:
666
+ device = text.device
667
+ if audio is None and text is None:
668
+ # a hack to get the logit scale
669
+ return self.logit_scale_a.exp(), self.logit_scale_t.exp()
670
+ elif audio is None:
671
+ return self.encode_text(text, device=device)
672
+ elif text is None:
673
+ return self.audio_projection(self.encode_audio(audio, device=device)["embedding"])
674
+ audio_features = self.audio_projection(self.encode_audio(audio, device=device)["embedding"])
675
+ audio_features = F.normalize(audio_features, dim=-1)
676
+
677
+ text_features = self.encode_text(
678
+ text, device=device
679
+ )
680
+ # print("text_features", text_features)
681
+ # print("text_features.shape", text_features.shape)
682
+ # print("text_features.type", type(text_features))
683
+ text_features = F.normalize(text_features, dim=-1)
684
+
685
+ audio_features_mlp = self.audio_transform(audio_features)
686
+ text_features_mlp = self.text_transform(text_features)
687
+ # Four outputs: audio features (basic & MLP), text features (basic & MLP)
688
+ return (
689
+ audio_features,
690
+ text_features,
691
+ audio_features_mlp,
692
+ text_features_mlp,
693
+ self.logit_scale_a.exp(),
694
+ self.logit_scale_t.exp(),
695
+ )
696
+
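A hedged sketch of how the six outputs of forward() could be combined into contrastive logits; `model`, `audio_batch` and `text_batch` are hypothetical names, and the repository's own training loss (open_clap/loss.py) may pair the plain and MLP features differently.

    (audio_f, text_f, audio_f_mlp, text_f_mlp, scale_a, scale_t) = model(audio_batch, text_batch)

    # one plausible pairing: plain features against the MLP-transformed ones
    logits_audio_text = scale_a * audio_f @ text_f_mlp.t()
    logits_text_audio = scale_t * text_f @ audio_f_mlp.t()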
697
+ def get_logit_scale(self):
698
+ return self.logit_scale_a.exp(), self.logit_scale_t.exp()
699
+
700
+ def get_textual_embedding(self, data):
701
+
702
+ device = next(self.parameters()).device
703
+ for k in data:
704
+ data[k] = data[k].to(device)
705
+
706
+ # if self.text_branch_type == "roberta":
707
+ text_embeds = self.text_branch(
708
+ input_ids=data["input_ids"].to(device=device, non_blocking=True),
709
+ attention_mask=data["attention_mask"].to(device=device, non_blocking=True),
710
+ )["last_hidden_state"]
711
+
712
+ text_embeds = self.text_projection(text_embeds)
713
+
714
+ text_embeds = F.normalize(text_embeds, dim=-1)
715
+
716
+ return text_embeds
717
+
718
+ def get_text_embedding(self, data):
719
+ """Get the text embedding from the model
720
+
721
+ Parameters
722
+ ----------
723
+ data: torch.Tensor
724
+ a tensor of text embedding
725
+
726
+ Returns
727
+ ----------
728
+ text_embed: torch.Tensor
729
+ a tensor of text_embeds (N, D)
730
+
731
+ """
732
+ device = next(self.parameters()).device
733
+ for k in data:
734
+ data[k] = data[k].to(device)
735
+ text_embeds = self.encode_text(data, device=device)
736
+ text_embeds = F.normalize(text_embeds, dim=-1)
737
+
738
+ return text_embeds
739
+
740
+ def get_audio_embedding(self, data):
741
+ """Get the audio embedding from the model
742
+
743
+ Parameters
744
+ ----------
745
+ data: a list of dict
746
+ the audio input dict list from 'get_audio_feature' method
747
+
748
+ Returns
749
+ ----------
750
+ audio_embed: torch.Tensor
751
+ a tensor of audio_embeds (N, D)
752
+
753
+ """
754
+ device = next(self.parameters()).device
755
+ input_dict = {}
756
+ keys = data[0].keys()
757
+ for k in keys:
758
+ input_dict[k] = torch.cat([d[k].unsqueeze(0) for d in data], dim=0).to(device)
759
+
760
+ audio_embeds = self.audio_projection(self.encode_audio(input_dict, device=device)["embedding"])
761
+ audio_embeds = F.normalize(audio_embeds, dim=-1)
762
+
763
+ return audio_embeds
764
+
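A toy illustration of the batching performed in get_audio_embedding above; `feature_dicts` and its keys are hypothetical stand-ins for whatever the repository's audio preprocessing ('get_audio_feature') actually returns.

    import torch

    feature_dicts = [
        {"waveform": torch.randn(480000), "longer": torch.tensor([False])},
        {"waveform": torch.randn(480000), "longer": torch.tensor([True])},
    ]
    batched = {
        k: torch.cat([d[k].unsqueeze(0) for d in feature_dicts], dim=0)
        for k in feature_dicts[0].keys()
    }
    # batched["waveform"].shape == (2, 480000); batched["longer"].shape == (2, 1)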
765
+
766
+
767
+ def audio_infer(self, audio, hopsize=None, device=None):
768
+ """Forward one audio and produce the audio embedding
769
+
770
+ Parameters
771
+ ----------
772
+ audio: (audio_length)
+ the time-domain audio input; note that it must be a single, unbatched example
+ hopsize: int
+ the hop size of the sliding window used for overlapping clips
+
+ Returns
+ ----------
+ output_dict: {
+ key: [n, (embedding_shape)] if "HTSAT"
+ or
+ key: [(embedding_shape)] if "PANN"
+ }
+ the keyed outputs of the audio branch
785
+
786
+ """
787
+
788
+ assert not self.training, "the inference mode must be run at eval stage"
789
+ output_dict = {}
+ key = "embedding"  # assumed: the original left `key` undefined; "embedding" is the field consumed elsewhere in this class
+ # PANN
+ if self.audio_cfg.model_type == "PANN":
+ audio_input = audio.unsqueeze(dim=0)
+ output_dict[key] = self.encode_audio(audio_input, device=device)[key].squeeze(dim=0)
794
+ elif self.audio_cfg.model_type == "HTSAT":
795
+ # repeat
796
+ audio_len = len(audio)
797
+ k = self.audio_cfg.clip_samples // audio_len
798
+ if k > 1:
799
+ audio = audio.repeat(k)
800
+ audio_len = len(audio)
801
+
802
+ if hopsize is None:
+ hopsize = self.audio_cfg.clip_samples  # assumed default (non-overlapping windows); min(None, audio_len) would raise
+ hopsize = min(hopsize, audio_len)
804
+
805
+ if audio_len > self.audio_cfg.clip_samples:
806
+ audio_input = [
807
+ audio[pos : pos + self.audio_cfg.clip_samples].clone()
808
+ for pos in range(
809
+ 0, audio_len - self.audio_cfg.clip_samples, hopsize
810
+ )
811
+ ]
812
+ audio_input.append(audio[-self.audio_cfg.clip_samples :].clone())
813
+ audio_input = torch.stack(audio_input)
814
+ output_dict[key] = self.encode_audio(audio_input, device=device)[key]
815
+ else:
816
+ audio_input = audio.unsqueeze(dim=0)
817
+ output_dict[key] = self.encode_audio(audio_input, device=device)[key].squeeze(dim=0)
818
+
819
+ return output_dict
820
+
821
+
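The HTSAT branch of audio_infer slides a clip_samples-long window over inputs longer than one clip, plus one final window anchored at the end. A small numeric sketch (the 480000-sample clip length matches the HTSAT configs below; the 25 s input length and hop are assumed):

    clip_samples = 480000            # 10 s at 48 kHz
    audio_len = 1_200_000            # hypothetical 25 s recording
    hopsize = 480000                 # assumed non-overlapping hop

    starts = list(range(0, audio_len - clip_samples, hopsize))
    # starts == [0, 480000], and the last window is audio[-clip_samples:]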
822
+ def convert_weights_to_fp16(model: nn.Module):
823
+ """Convert applicable model parameters to fp16"""
824
+
825
+ def _convert_weights_to_fp16(l):
826
+ if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)):
827
+ l.weight.data = l.weight.data.half()
828
+ if l.bias is not None:
829
+ l.bias.data = l.bias.data.half()
830
+
831
+ if isinstance(l, nn.MultiheadAttention):
832
+ for attr in [
833
+ *[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]],
834
+ "in_proj_bias",
835
+ "bias_k",
836
+ "bias_v",
837
+ ]:
838
+ tensor = getattr(l, attr)
839
+ if tensor is not None:
840
+ tensor.data = tensor.data.half()
841
+
842
+ for name in ["text_projection", "proj"]:
843
+ if hasattr(l, name):
844
+ attr = getattr(l, name)
845
+ if attr is not None:
846
+ attr.data = attr.data.half()
847
+
848
+ model.apply(_convert_weights_to_fp16)
849
+
850
+
851
+ # Ignore the state dict of the vision part
852
+ def build_model_from_openai_state_dict(state_dict: dict, model_cfg, enable_fusion: bool = False, fusion_type: str = 'None'):
853
+
854
+ embed_dim = model_cfg["embed_dim"]
855
+ audio_cfg = model_cfg["audio_cfg"]
856
+ text_cfg = model_cfg["text_cfg"]
857
+ context_length = state_dict["positional_embedding"].shape[0]
858
+ vocab_size = state_dict["token_embedding.weight"].shape[0]
859
+ transformer_width = state_dict["ln_final.weight"].shape[0]
860
+ transformer_heads = transformer_width // 64
861
+ transformer_layers = len(
862
+ set(
863
+ k.split(".")[2]
864
+ for k in state_dict
865
+ if k.startswith("transformer.resblocks")
866
+ )
867
+ )
868
+
869
+ audio_cfg = CLAPAudioCfp(**audio_cfg)
870
+ text_cfg = CLAPTextCfg(**text_cfg)
871
+
872
+ model = CLAP(
873
+ embed_dim,
874
+ audio_cfg=audio_cfg,
875
+ text_cfg=text_cfg,
876
+ quick_gelu=True, # OpenAI models were trained with QuickGELU
877
+ enable_fusion=enable_fusion,
878
+ fusion_type=fusion_type
879
+ )
880
+ state_dict["logit_scale_a"] = state_dict["logit_scale"]
881
+ state_dict["logit_scale_t"] = state_dict["logit_scale"]
882
+ pop_keys = list(state_dict.keys())[::]
883
+ # pop the visual branch saved weights
884
+ for key in pop_keys:
885
+ if key.startswith("visual."):
886
+ state_dict.pop(key, None)
887
+
888
+ for key in ["logit_scale", "input_resolution", "context_length", "vocab_size"]:
889
+ state_dict.pop(key, None)
890
+
891
+ # not use fp16
892
+ # convert_weights_to_fp16(model)
893
+ model.load_state_dict(state_dict, strict=False)
894
+ return model.eval()
895
+
896
+
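build_model_from_openai_state_dict reads the text-transformer hyperparameters straight off the checkpoint tensor shapes rather than from a config. A small sketch of that inference on a pruned dummy state dict (key names follow the OpenAI CLIP layout the function assumes):

    import torch

    state_dict = {
        "positional_embedding": torch.zeros(77, 512),
        "token_embedding.weight": torch.zeros(49408, 512),
        "ln_final.weight": torch.zeros(512),
    }
    context_length = state_dict["positional_embedding"].shape[0]   # 77
    vocab_size = state_dict["token_embedding.weight"].shape[0]     # 49408
    transformer_width = state_dict["ln_final.weight"].shape[0]     # 512
    transformer_heads = transformer_width // 64                    # 8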
897
+ def trace_model(model, batch_size=256, device=torch.device("cpu")):
898
+ model.eval()
899
+ audio_length = model.audio_cfg.audio_length
900
+ example_audio = torch.ones((batch_size, audio_length), device=device)
901
+ example_text = torch.zeros(
902
+ (batch_size, model.context_length), dtype=torch.int, device=device
903
+ )
904
+ model = torch.jit.trace_module(
905
+ model,
906
+ inputs=dict(
907
+ forward=(example_audio, example_text),
908
+ encode_text=(example_text,),
+ encode_image=(example_audio,),  # note: CLAP defines no encode_image; this tracing entry appears carried over from CLIP
+ ),
+ )
+ model.audio_cfg.audio_length = audio_length  # restore the attribute on the traced module
+ return model
ldm/modules/encoders/open_clap/model_configs/HTSAT-base.json ADDED
@@ -0,0 +1,23 @@
1
+ {
2
+ "embed_dim": 1024,
3
+ "audio_cfg": {
4
+ "audio_length": 1024,
5
+ "clip_samples": 480000,
6
+ "mel_bins": 64,
7
+ "sample_rate": 48000,
8
+ "window_size": 1024,
9
+ "hop_size": 480,
10
+ "fmin": 50,
11
+ "fmax": 14000,
12
+ "class_num": 527,
13
+ "model_type": "HTSAT",
14
+ "model_name": "base"
15
+ },
16
+ "text_cfg": {
17
+ "context_length": 77,
18
+ "vocab_size": 49408,
19
+ "width": 512,
20
+ "heads": 8,
21
+ "layers": 12
22
+ }
23
+ }
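These model_configs JSON files mirror the embed_dim / audio_cfg / text_cfg arguments of the CLAP class above; the repository's open_clap factory is what actually loads them, so the sketch below is only illustrative. Note that text_cfg here omits model_type, so it has to be supplied (e.g. "transformer", an assumed default) before building the dataclass.

    import json

    with open("ldm/modules/encoders/open_clap/model_configs/HTSAT-base.json") as f:
        model_cfg = json.load(f)

    text_cfg = dict(model_cfg["text_cfg"], model_type="transformer")  # assumed default
    model = CLAP(
        embed_dim=model_cfg["embed_dim"],
        audio_cfg=model_cfg["audio_cfg"],   # converted to CLAPAudioCfp in __init__
        text_cfg=text_cfg,                  # converted to CLAPTextCfg in __init__
    )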
ldm/modules/encoders/open_clap/model_configs/HTSAT-large.json ADDED
@@ -0,0 +1,23 @@
1
+ {
2
+ "embed_dim": 2048,
3
+ "audio_cfg": {
4
+ "audio_length": 1024,
5
+ "clip_samples": 480000,
6
+ "mel_bins": 64,
7
+ "sample_rate": 48000,
8
+ "window_size": 1024,
9
+ "hop_size": 480,
10
+ "fmin": 50,
11
+ "fmax": 14000,
12
+ "class_num": 527,
13
+ "model_type": "HTSAT",
14
+ "model_name": "large"
15
+ },
16
+ "text_cfg": {
17
+ "context_length": 77,
18
+ "vocab_size": 49408,
19
+ "width": 512,
20
+ "heads": 8,
21
+ "layers": 12
22
+ }
23
+ }
ldm/modules/encoders/open_clap/model_configs/HTSAT-tiny-win-1536.json ADDED
@@ -0,0 +1,23 @@
1
+ {
2
+ "embed_dim": 768,
3
+ "audio_cfg": {
4
+ "audio_length": 1024,
5
+ "clip_samples": 480000,
6
+ "mel_bins": 64,
7
+ "sample_rate": 48000,
8
+ "window_size": 1536,
9
+ "hop_size": 480,
10
+ "fmin": 50,
11
+ "fmax": 14000,
12
+ "class_num": 527,
13
+ "model_type": "HTSAT",
14
+ "model_name": "tiny"
15
+ },
16
+ "text_cfg": {
17
+ "context_length": 77,
18
+ "vocab_size": 49408,
19
+ "width": 512,
20
+ "heads": 8,
21
+ "layers": 12
22
+ }
23
+ }
ldm/modules/encoders/open_clap/model_configs/HTSAT-tiny.json ADDED
@@ -0,0 +1,23 @@
1
+ {
2
+ "embed_dim": 768,
3
+ "audio_cfg": {
4
+ "audio_length": 1024,
5
+ "clip_samples": 480000,
6
+ "mel_bins": 64,
7
+ "sample_rate": 48000,
8
+ "window_size": 1024,
9
+ "hop_size": 480,
10
+ "fmin": 50,
11
+ "fmax": 14000,
12
+ "class_num": 527,
13
+ "model_type": "HTSAT",
14
+ "model_name": "tiny"
15
+ },
16
+ "text_cfg": {
17
+ "context_length": 77,
18
+ "vocab_size": 49408,
19
+ "width": 512,
20
+ "heads": 8,
21
+ "layers": 12
22
+ }
23
+ }
ldm/modules/encoders/open_clap/model_configs/PANN-10.json ADDED
@@ -0,0 +1,23 @@
1
+ {
2
+ "embed_dim": 1024,
3
+ "audio_cfg": {
4
+ "audio_length": 1024,
5
+ "clip_samples": 480000,
6
+ "mel_bins": 64,
7
+ "sample_rate": 48000,
8
+ "window_size": 1024,
9
+ "hop_size": 480,
10
+ "fmin": 50,
11
+ "fmax": 14000,
12
+ "class_num": 527,
13
+ "model_type": "PANN",
14
+ "model_name": "Cnn10"
15
+ },
16
+ "text_cfg": {
17
+ "context_length": 77,
18
+ "vocab_size": 49408,
19
+ "width": 512,
20
+ "heads": 8,
21
+ "layers": 12
22
+ }
23
+ }
ldm/modules/encoders/open_clap/model_configs/PANN-14-fmax-18k.json ADDED
@@ -0,0 +1,23 @@
1
+ {
2
+ "embed_dim": 2048,
3
+ "audio_cfg": {
4
+ "audio_length": 1024,
5
+ "clip_samples": 480000,
6
+ "mel_bins": 64,
7
+ "sample_rate": 48000,
8
+ "window_size": 1024,
9
+ "hop_size": 480,
10
+ "fmin": 50,
11
+ "fmax": 18000,
12
+ "class_num": 527,
13
+ "model_type": "PANN",
14
+ "model_name": "Cnn14"
15
+ },
16
+ "text_cfg": {
17
+ "context_length": 77,
18
+ "vocab_size": 49408,
19
+ "width": 512,
20
+ "heads": 8,
21
+ "layers": 12
22
+ }
23
+ }
ldm/modules/encoders/open_clap/model_configs/PANN-14-fmax-8k-20s.json ADDED
@@ -0,0 +1,23 @@
1
+ {
2
+ "embed_dim": 2048,
3
+ "audio_cfg": {
4
+ "audio_length": 1024,
5
+ "clip_samples": 960000,
6
+ "mel_bins": 64,
7
+ "sample_rate": 48000,
8
+ "window_size": 1024,
9
+ "hop_size": 360,
10
+ "fmin": 50,
11
+ "fmax": 8000,
12
+ "class_num": 527,
13
+ "model_type": "PANN",
14
+ "model_name": "Cnn14"
15
+ },
16
+ "text_cfg": {
17
+ "context_length": 77,
18
+ "vocab_size": 49408,
19
+ "width": 512,
20
+ "heads": 8,
21
+ "layers": 12
22
+ }
23
+ }
ldm/modules/encoders/open_clap/model_configs/PANN-14-tiny-transformer.json ADDED
@@ -0,0 +1,23 @@
1
+ {
2
+ "embed_dim": 2048,
3
+ "audio_cfg": {
4
+ "audio_length": 1024,
5
+ "clip_samples": 480000,
6
+ "mel_bins": 64,
7
+ "sample_rate": 48000,
8
+ "window_size": 1024,
9
+ "hop_size": 480,
10
+ "fmin": 50,
11
+ "fmax": 14000,
12
+ "class_num": 527,
13
+ "model_type": "PANN",
14
+ "model_name": "Cnn14"
15
+ },
16
+ "text_cfg": {
17
+ "context_length": 77,
18
+ "vocab_size": 49408,
19
+ "width": 512,
20
+ "heads": 8,
21
+ "layers": 4
22
+ }
23
+ }