wsntxxn committed
Commit 6065472
1 Parent(s): 35ff792

Add AudioCaps checkpoint

app.py ADDED
@@ -0,0 +1,97 @@
+ from pathlib import Path
+ import argparse
+ from functools import partial
+ import gradio as gr
+ import torch
+ from torchaudio.functional import resample
+
+ import utils.train_util as train_util
+
+
+ def load_model(cfg,
+                ckpt_path,
+                device):
+     model = train_util.init_model_from_config(cfg["model"])
+     ckpt = torch.load(ckpt_path, "cpu")
+     train_util.load_pretrained_model(model, ckpt)
+     model.eval()
+     model = model.to(device)
+     tokenizer = train_util.init_obj_from_dict(cfg["tokenizer"])
+     if not tokenizer.loaded:
+         tokenizer.load_state_dict(ckpt["tokenizer"])
+     model.set_index(tokenizer.bos, tokenizer.eos, tokenizer.pad)
+     return model, tokenizer
+
+
+ def infer(file, device, model, tokenizer, target_sr):
+     sr, wav = file
+     wav = torch.as_tensor(wav)
+     if wav.dtype == torch.short:
+         wav = wav / 2 ** 15
+     elif wav.dtype == torch.int:
+         wav = wav / 2 ** 31
+     if wav.ndim > 1:
+         wav = wav.mean(1)
+     wav = resample(wav, sr, target_sr)
+     wav_len = len(wav)
+     wav = wav.float().unsqueeze(0).to(device)
+     input_dict = {
+         "mode": "inference",
+         "wav": wav,
+         "wav_len": [wav_len],
+         "specaug": False,
+         "sample_method": "beam",
+         "beam_size": 3,
+     }
+     with torch.no_grad():
+         output_dict = model(input_dict)
+     seq = output_dict["seq"].cpu().numpy()
+     cap = tokenizer.decode(seq)[0]
+     return cap
+
+ # def input_toggle(input_type):
+ #     if input_type == "file":
+ #         return gr.update(visible=True), gr.update(visible=False)
+ #     elif input_type == "mic":
+ #         return gr.update(visible=False), gr.update(visible=True)
+
+
+ if __name__ == "__main__":
+
+     parser = argparse.ArgumentParser()
+     parser.add_argument("--share", action="store_true", default=False)
+
+     args = parser.parse_args()
+
+     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+     exp_dir = Path("./checkpoints/audiocaps")
+     cfg = train_util.load_config(exp_dir / "config.yaml")
+     target_sr = cfg["target_sr"]
+     model, tokenizer = load_model(cfg, exp_dir / "ckpt.pth", device)
+
+     with gr.Blocks() as demo:
+         with gr.Row():
+             with gr.Column():
+                 # radio = gr.Radio(
+                 #     ["file", "mic"],
+                 #     value="file",
+                 #     label="Select input type"
+                 # )
+                 file = gr.Audio(label="Input", visible=True)
+                 # mic = gr.Microphone(label="Input", visible=False)
+                 # radio.change(fn=input_toggle, inputs=radio, outputs=[file, mic])
+                 btn = gr.Button("Run")
+             with gr.Column():
+                 output = gr.Textbox(label="Output")
+         btn.click(
+             fn=partial(infer,
+                        device=device,
+                        model=model,
+                        tokenizer=tokenizer,
+                        target_sr=target_sr),
+             inputs=[file,],
+             outputs=output
+         )
+
+     demo.launch(share=args.share)
+
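Note: the `infer` function above can also be called outside the Gradio UI. The snippet below is only a minimal sketch, assuming the setup in the `__main__` block has already produced `model`, `tokenizer`, `device`, and `target_sr`; `example.wav` is a placeholder path, and `torchaudio` (already used here via `torchaudio.functional.resample`) is assumed to be available for loading the file.

    import torchaudio

    wav, sr = torchaudio.load("example.wav")   # wav: [channels, n_samples], float32 in [-1, 1]
    wav = wav.mean(0).numpy()                  # mono array, mimicking the (sr, array) tuple gr.Audio passes in
    caption = infer((sr, wav), device=device, model=model,
                    tokenizer=tokenizer, target_sr=target_sr)
    print(caption)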
checkpoints/audiocaps/ckpt.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:e1c435b1cf05a2b0058dae6f096c4eb4e71c685a19754ed84ea1ee812257434b
+ size 55293225
checkpoints/audiocaps/config.yaml ADDED
@@ -0,0 +1,30 @@
+ tokenizer:
+   type: text_tokenizer.DictTokenizer
+   args:
+     max_length: 20
+
+ target_sr: 16000
+
+ model:
+   args:
+     shared_dim: 1024
+     tchr_dim: 768
+     model:
+       args: {}
+       decoder:
+         args:
+           attn_emb_dim: 1408
+           dropout: 0.2
+           emb_dim: 256
+           fc_emb_dim: 1408
+           nlayers: 2
+           tie_weights: true
+           vocab_size: 4981
+         type: models.transformer_decoder.TransformerDecoder
+       encoder:
+         args:
+           freeze: false
+           pretrained: true
+         type: models.cnn_encoder.EfficientNetB2
+       type: models.transformer_model.TransformerModel
+   type: models.kd_wrapper.ContraEncoderKdWrapper
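Note: `utils/train_util.py` is not part of this commit, so the helper that consumes these nested `type`/`args` entries is not shown. The sketch below is only an illustration (an assumption, not the repository's actual implementation) of how such an entry could be turned into an object:

    import importlib

    def init_obj_from_dict(config: dict, **extra_kwargs):
        # "type" is a dotted path such as "models.cnn_encoder.EfficientNetB2"
        module_name, cls_name = config["type"].rsplit(".", 1)
        cls = getattr(importlib.import_module(module_name), cls_name)
        return cls(**config.get("args", {}), **extra_kwargs)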
models/__init__.py ADDED
@@ -0,0 +1,92 @@
+ import numpy as np
+ import torch
+ import torch.nn as nn
+
+ from utils.model_util import max_with_lens, mean_with_lens
+
+
+ def embedding_pooling(x, lens, pooling="mean"):
+     if pooling == "max":
+         fc_embs = max_with_lens(x, lens)
+     elif pooling == "mean":
+         fc_embs = mean_with_lens(x, lens)
+     elif pooling == "mean+max":
+         x_mean = mean_with_lens(x, lens)
+         x_max = max_with_lens(x, lens)
+         fc_embs = x_mean + x_max
+     elif pooling == "last":
+         indices = (lens - 1).reshape(-1, 1, 1).repeat(1, 1, x.size(-1))
+         # indices: [N, 1, hidden]
+         fc_embs = torch.gather(x, 1, indices).squeeze(1)
+     else:
+         raise Exception(f"pooling method {pooling} not supported")
+     return fc_embs
+
+
+ class BaseEncoder(nn.Module):
+
+     """
+     Encode the given audio into an embedding.
+     Base encoder class, cannot be called directly.
+     All encoders should inherit from this class.
+     """
+
+     def __init__(self, spec_dim, fc_feat_dim, attn_feat_dim):
+         super(BaseEncoder, self).__init__()
+         self.spec_dim = spec_dim
+         self.fc_feat_dim = fc_feat_dim
+         self.attn_feat_dim = attn_feat_dim
+
+
+     def forward(self, x):
+         #########################
+         # Arguments:
+         # `x`: {
+         #     (may contain)
+         #     wav: [batch_size, n_samples],
+         #     spec: [batch_size, n_frames, spec_dim],
+         #     fc: [batch_size, fc_feat_dim],
+         #     attn: [batch_size, attn_max_len, attn_feat_dim],
+         #     attn_len: [batch_size,]
+         #     ......
+         # }
+         #
+         # Returns:
+         # `encoded`: {
+         #     fc_emb: [batch_size, fc_emb_dim],
+         #     attn_emb: [batch_size, attn_max_len, attn_emb_dim],
+         #     attn_emb_lens: [batch_size,]
+         # }
+         #########################
+         raise NotImplementedError
+
+
+ class BaseDecoder(nn.Module):
+     """
+     Take word/audio embeddings and output the next-word probabilities.
+     """
+     def __init__(self, emb_dim, vocab_size, fc_emb_dim,
+                  attn_emb_dim, dropout=0.2, tie_weights=False):
+         super().__init__()
+         self.emb_dim = emb_dim
+         self.vocab_size = vocab_size
+         self.fc_emb_dim = fc_emb_dim
+         self.attn_emb_dim = attn_emb_dim
+         self.tie_weights = tie_weights
+         self.word_embedding = nn.Embedding(vocab_size, emb_dim)
+         self.in_dropout = nn.Dropout(dropout)
+
+     def forward(self, x):
+         raise NotImplementedError
+
+     def load_word_embedding(self, weight, freeze=True):
+         embedding = np.load(weight)
+         assert embedding.shape[0] == self.vocab_size, "vocabulary size mismatch"
+         assert embedding.shape[1] == self.emb_dim, "embed size mismatch"
+
+         # embeddings = torch.as_tensor(embeddings).float()
+         # self.word_embeddings.weight = nn.Parameter(embeddings)
+         # for para in self.word_embeddings.parameters():
+         #     para.requires_grad = tune
+         # from_pretrained expects a 2-D FloatTensor, so convert the loaded numpy array
+         self.word_embedding = nn.Embedding.from_pretrained(torch.as_tensor(embedding).float(),
+                                                            freeze=freeze)
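Note: a small usage sketch for `embedding_pooling`, assuming `x` is a padded batch of frame embeddings and `lens` holds the number of valid frames per clip (the "mean"/"max" branches additionally rely on `utils.model_util`, which is outside this diff):

    x = torch.randn(4, 10, 256)                      # [batch, max_len, hidden]
    lens = torch.tensor([10, 7, 5, 9])               # valid frames per clip
    fc = embedding_pooling(x, lens, pooling="last")  # gathers the last valid frame -> [4, 256]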
models/base.py ADDED
@@ -0,0 +1,504 @@
1
+ # -*- coding: utf-8 -*-
2
+
3
+ from typing import Dict
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+
8
+ from utils.model_util import mean_with_lens, repeat_tensor
9
+
10
+
11
+ class CaptionMetaMixin:
12
+ pad_idx = 0
13
+ start_idx = 1
14
+ end_idx = 2
15
+ max_length = 20
16
+
17
+ @classmethod
18
+ def set_index(cls, start_idx, end_idx, pad_idx):
19
+ cls.start_idx = start_idx
20
+ cls.end_idx = end_idx
21
+ cls.pad_idx = pad_idx
22
+
23
+
24
+ class CaptionModel(nn.Module, CaptionMetaMixin):
25
+ """
26
+ Encoder-decoder captioning model.
27
+ """
28
+
29
+ def __init__(self, encoder: nn.Module, decoder: nn.Module, **kwargs):
30
+ super().__init__()
31
+ self.encoder = encoder
32
+ self.decoder = decoder
33
+ self.vocab_size = decoder.vocab_size
34
+ self.train_forward_keys = ["cap", "cap_len", "ss_ratio"]
35
+ self.inference_forward_keys = ["sample_method", "max_length", "temp"]
36
+ freeze_encoder = kwargs.get("freeze_encoder", False)
37
+ if freeze_encoder:
38
+ for param in self.encoder.parameters():
39
+ param.requires_grad = False
40
+ self.check_decoder_compatibility()
41
+
42
+ def check_decoder_compatibility(self):
43
+ compatible_decoders = [x.__class__.__name__ for x in self.compatible_decoders]
44
+ assert isinstance(self.decoder, self.compatible_decoders), \
45
+ f"{self.decoder.__class__.__name__} is incompatible with " \
46
+ f"{self.__class__.__name__}, please use decoder in {compatible_decoders} "
47
+
48
+ def forward(self, input_dict: Dict):
49
+ """
50
+ input_dict: {
51
+ (required)
52
+ mode: train/inference,
53
+ [spec, spec_len],
54
+ [fc],
55
+ [attn, attn_len],
56
+ [wav, wav_len],
57
+ [sample_method: greedy],
58
+ [temp: 1.0] (in case of no teacher forcing)
59
+
60
+ (optional, mode=train)
61
+ cap,
62
+ cap_len,
63
+ ss_ratio,
64
+
65
+ (optional, mode=inference)
66
+ sample_method: greedy/beam,
67
+ max_length,
68
+ temp,
69
+ beam_size (optional, sample_method=beam),
70
+ n_best (optional, sample_method=beam),
71
+ }
72
+ """
73
+ encoder_output_dict = self.encoder(input_dict)
74
+ output = self.forward_decoder(input_dict, encoder_output_dict)
75
+ return output
76
+
77
+ def forward_decoder(self, input_dict: Dict, encoder_output_dict: Dict):
78
+ if input_dict["mode"] == "train":
79
+ forward_dict = {
80
+ "mode": "train", "sample_method": "greedy", "temp": 1.0
81
+ }
82
+ for key in self.train_forward_keys:
83
+ forward_dict[key] = input_dict[key]
84
+ forward_dict.update(encoder_output_dict)
85
+ output = self.train_forward(forward_dict)
86
+ elif input_dict["mode"] == "inference":
87
+ forward_dict = {"mode": "inference"}
88
+ default_args = { "sample_method": "greedy", "max_length": self.max_length, "temp": 1.0 }
89
+ for key in self.inference_forward_keys:
90
+ if key in input_dict:
91
+ forward_dict[key] = input_dict[key]
92
+ else:
93
+ forward_dict[key] = default_args[key]
94
+
95
+ if forward_dict["sample_method"] == "beam":
96
+ forward_dict["beam_size"] = input_dict.get("beam_size", 3)
97
+ forward_dict["n_best"] = input_dict.get("n_best", False)
98
+ forward_dict["n_best_size"] = input_dict.get("n_best_size", forward_dict["beam_size"])
99
+ elif forward_dict["sample_method"] == "dbs":
100
+ forward_dict["beam_size"] = input_dict.get("beam_size", 6)
101
+ forward_dict["group_size"] = input_dict.get("group_size", 3)
102
+ forward_dict["diversity_lambda"] = input_dict.get("diversity_lambda", 0.5)
103
+ forward_dict["group_nbest"] = input_dict.get("group_nbest", True)
104
+
105
+ forward_dict.update(encoder_output_dict)
106
+ output = self.inference_forward(forward_dict)
107
+ else:
108
+ raise Exception("mode should be either 'train' or 'inference'")
109
+ output.update(encoder_output_dict)
110
+ return output
111
+
112
+ def prepare_output(self, input_dict):
113
+ output = {}
114
+ batch_size = input_dict["fc_emb"].size(0)
115
+ if input_dict["mode"] == "train":
116
+ max_length = input_dict["cap"].size(1) - 1
117
+ elif input_dict["mode"] == "inference":
118
+ max_length = input_dict["max_length"]
119
+ else:
120
+ raise Exception("mode should be either 'train' or 'inference'")
121
+ device = input_dict["fc_emb"].device
122
+ output["seq"] = torch.full((batch_size, max_length), self.end_idx,
123
+ dtype=torch.long)
124
+ output["logit"] = torch.empty(batch_size, max_length,
125
+ self.vocab_size).to(device)
126
+ output["sampled_logprob"] = torch.zeros(batch_size, max_length)
127
+ output["embed"] = torch.empty(batch_size, max_length,
128
+ self.decoder.d_model).to(device)
129
+ return output
130
+
131
+ def train_forward(self, input_dict):
132
+ if input_dict["ss_ratio"] != 1: # scheduled sampling training
133
+ input_dict["mode"] = "train"
134
+ return self.stepwise_forward(input_dict)
135
+ output = self.seq_forward(input_dict)
136
+ self.train_process(output, input_dict)
137
+ return output
138
+
139
+ def seq_forward(self, input_dict):
140
+ raise NotImplementedError
141
+
142
+ def train_process(self, output, input_dict):
143
+ pass
144
+
145
+ def inference_forward(self, input_dict):
146
+ if input_dict["sample_method"] == "beam":
147
+ return self.beam_search(input_dict)
148
+ elif input_dict["sample_method"] == "dbs":
149
+ return self.diverse_beam_search(input_dict)
150
+ return self.stepwise_forward(input_dict)
151
+
152
+ def stepwise_forward(self, input_dict):
153
+ """Step-by-step decoding"""
154
+ output = self.prepare_output(input_dict)
155
+ max_length = output["seq"].size(1)
156
+ # start sampling
157
+ for t in range(max_length):
158
+ input_dict["t"] = t
159
+ self.decode_step(input_dict, output)
160
+ if input_dict["mode"] == "inference": # decide whether to stop when sampling
161
+ unfinished_t = output["seq"][:, t] != self.end_idx
162
+ if t == 0:
163
+ unfinished = unfinished_t
164
+ else:
165
+ unfinished *= unfinished_t
166
+ output["seq"][:, t][~unfinished] = self.end_idx
167
+ if unfinished.sum() == 0:
168
+ break
169
+ self.stepwise_process(output)
170
+ return output
171
+
172
+ def decode_step(self, input_dict, output):
173
+ """Decoding operation of timestep t"""
174
+ decoder_input = self.prepare_decoder_input(input_dict, output)
175
+ # feed to the decoder to get logit
176
+ output_t = self.decoder(decoder_input)
177
+ logit_t = output_t["logit"]
178
+ # assert logit_t.ndim == 3
179
+ if logit_t.size(1) == 1:
180
+ logit_t = logit_t.squeeze(1)
181
+ embed_t = output_t["embed"].squeeze(1)
182
+ elif logit_t.size(1) > 1:
183
+ logit_t = logit_t[:, -1, :]
184
+ embed_t = output_t["embed"][:, -1, :]
185
+ else:
186
+ raise Exception("no logit output")
187
+ # sample the next input word and get the corresponding logit
188
+ sampled = self.sample_next_word(logit_t,
189
+ method=input_dict["sample_method"],
190
+ temp=input_dict["temp"])
191
+
192
+ output_t.update(sampled)
193
+ output_t["t"] = input_dict["t"]
194
+ output_t["logit"] = logit_t
195
+ output_t["embed"] = embed_t
196
+ self.stepwise_process_step(output, output_t)
197
+
198
+ def prepare_decoder_input(self, input_dict, output):
199
+ """Prepare the input dict for the decoder"""
200
+ raise NotImplementedError
201
+
202
+ def stepwise_process_step(self, output, output_t):
203
+ """Postprocessing (save output values) after each timestep t"""
204
+ t = output_t["t"]
205
+ output["logit"][:, t, :] = output_t["logit"]
206
+ output["seq"][:, t] = output_t["word"]
207
+ output["sampled_logprob"][:, t] = output_t["probs"]
208
+ output["embed"][:, t, :] = output_t["embed"]
209
+
210
+ def stepwise_process(self, output):
211
+ """Postprocessing after the whole step-by-step autoregressive decoding"""
212
+ pass
213
+
214
+ def sample_next_word(self, logit, method, temp):
215
+ """Sample the next word, given probs output by the decoder"""
216
+ logprob = torch.log_softmax(logit, dim=1)
217
+ if method == "greedy":
218
+ sampled_logprob, word = torch.max(logprob.detach(), 1)
219
+ elif method == "gumbel":
220
+ def sample_gumbel(shape, eps=1e-20):
221
+ U = torch.rand(shape).to(logprob.device)
222
+ return -torch.log(-torch.log(U + eps) + eps)
223
+ def gumbel_softmax_sample(logit, temperature):
224
+ y = logit + sample_gumbel(logit.size())
225
+ return torch.log_softmax(y / temperature, dim=-1)
226
+ _logprob = gumbel_softmax_sample(logprob, temp)
227
+ _, word = torch.max(_logprob.data, 1)
228
+ sampled_logprob = logprob.gather(1, word.unsqueeze(-1))
229
+ else:
230
+ logprob = logprob / temp
231
+ if method.startswith("top"):
232
+ top_num = float(method[3:])
233
+ if 0 < top_num < 1: # top-p sampling
234
+ probs = torch.softmax(logit, dim=1)
235
+ sorted_probs, sorted_indices = torch.sort(probs, descending=True, dim=1)
236
+ _cumsum = sorted_probs.cumsum(1)
237
+ mask = _cumsum < top_num
238
+ mask = torch.cat([torch.ones_like(mask[:,:1]), mask[:,:-1]], 1)
239
+ sorted_probs = sorted_probs * mask.to(sorted_probs)
240
+ sorted_probs = sorted_probs / sorted_probs.sum(1, keepdim=True)
241
+ logprob.scatter_(1, sorted_indices, sorted_probs.log())
242
+ else: # top-k sampling
243
+ k = int(top_num)
244
+ tmp = torch.empty_like(logprob).fill_(float('-inf'))
245
+ topk, indices = torch.topk(logprob, k, dim=1)
246
+ tmp = tmp.scatter(1, indices, topk)
247
+ logprob = tmp
248
+ word = torch.distributions.Categorical(logits=logprob.detach()).sample()
249
+ sampled_logprob = logprob.gather(1, word.unsqueeze(-1)).squeeze(1)
250
+ word = word.detach().long()
251
+ # sampled_logprob: [N,], word: [N,]
252
+ return {"word": word, "probs": sampled_logprob}
253
+
254
+ def beam_search(self, input_dict):
255
+ output = self.prepare_output(input_dict)
256
+ max_length = input_dict["max_length"]
257
+ beam_size = input_dict["beam_size"]
258
+ if input_dict["n_best"]:
259
+ n_best_size = input_dict["n_best_size"]
260
+ batch_size, max_length = output["seq"].size()
261
+ output["seq"] = torch.full((batch_size, n_best_size, max_length),
262
+ self.end_idx, dtype=torch.long)
263
+
264
+ temp = input_dict["temp"]
265
+ # instance by instance beam search
266
+ for i in range(output["seq"].size(0)):
267
+ output_i = self.prepare_beamsearch_output(input_dict)
268
+ input_dict["sample_idx"] = i
269
+ for t in range(max_length):
270
+ input_dict["t"] = t
271
+ output_t = self.beamsearch_step(input_dict, output_i)
272
+ #######################################
273
+ # merge with previous beam and select the current max prob beam
274
+ #######################################
275
+ logit_t = output_t["logit"]
276
+ if logit_t.size(1) == 1:
277
+ logit_t = logit_t.squeeze(1)
278
+ elif logit_t.size(1) > 1:
279
+ logit_t = logit_t[:, -1, :]
280
+ else:
281
+ raise Exception("no logit output")
282
+ logprob_t = torch.log_softmax(logit_t, dim=1)
283
+ logprob_t = torch.log_softmax(logprob_t / temp, dim=1)
284
+ logprob_t = output_i["topk_logprob"].unsqueeze(1) + logprob_t
285
+ if t == 0: # for the first step, all k seq will have the same probs
286
+ topk_logprob, topk_words = logprob_t[0].topk(
287
+ beam_size, 0, True, True)
288
+ else: # unroll and find top logprob, and their unrolled indices
289
+ topk_logprob, topk_words = logprob_t.view(-1).topk(
290
+ beam_size, 0, True, True)
291
+ topk_words = topk_words.cpu()
292
+ output_i["topk_logprob"] = topk_logprob
293
+ # output_i["prev_words_beam"] = topk_words // self.vocab_size # [beam_size,]
294
+ output_i["prev_words_beam"] = torch.div(topk_words, self.vocab_size,
295
+ rounding_mode='trunc')
296
+ output_i["next_word"] = topk_words % self.vocab_size # [beam_size,]
297
+ if t == 0:
298
+ output_i["seq"] = output_i["next_word"].unsqueeze(1)
299
+ else:
300
+ output_i["seq"] = torch.cat([
301
+ output_i["seq"][output_i["prev_words_beam"]],
302
+ output_i["next_word"].unsqueeze(1)], dim=1)
303
+
304
+ # add finished beams to results
305
+ is_end = output_i["next_word"] == self.end_idx
306
+ if t == max_length - 1:
307
+ is_end.fill_(1)
308
+
309
+ for beam_idx in range(beam_size):
310
+ if is_end[beam_idx]:
311
+ final_beam = {
312
+ "seq": output_i["seq"][beam_idx].clone(),
313
+ "score": output_i["topk_logprob"][beam_idx].item()
314
+ }
315
+ final_beam["score"] = final_beam["score"] / (t + 1)
316
+ output_i["done_beams"].append(final_beam)
317
+ output_i["topk_logprob"][is_end] -= 1000
318
+
319
+ self.beamsearch_process_step(output_i, output_t)
320
+
321
+ self.beamsearch_process(output, output_i, input_dict)
322
+ return output
323
+
324
+ def prepare_beamsearch_output(self, input_dict):
325
+ beam_size = input_dict["beam_size"]
326
+ device = input_dict["fc_emb"].device
327
+ output = {
328
+ "topk_logprob": torch.zeros(beam_size).to(device),
329
+ "seq": None,
330
+ "prev_words_beam": None,
331
+ "next_word": None,
332
+ "done_beams": [],
333
+ }
334
+ return output
335
+
336
+ def beamsearch_step(self, input_dict, output_i):
337
+ decoder_input = self.prepare_beamsearch_decoder_input(input_dict, output_i)
338
+ output_t = self.decoder(decoder_input)
339
+ output_t["t"] = input_dict["t"]
340
+ return output_t
341
+
342
+ def prepare_beamsearch_decoder_input(self, input_dict, output_i):
343
+ raise NotImplementedError
344
+
345
+ def beamsearch_process_step(self, output_i, output_t):
346
+ pass
347
+
348
+ def beamsearch_process(self, output, output_i, input_dict):
349
+ i = input_dict["sample_idx"]
350
+ done_beams = sorted(output_i["done_beams"], key=lambda x: -x["score"])
351
+ if input_dict["n_best"]:
352
+ done_beams = done_beams[:input_dict["n_best_size"]]
353
+ for out_idx, done_beam in enumerate(done_beams):
354
+ seq = done_beam["seq"]
355
+ output["seq"][i][out_idx, :len(seq)] = seq
356
+ else:
357
+ seq = done_beams[0]["seq"]
358
+ output["seq"][i][:len(seq)] = seq
359
+
360
+ def diverse_beam_search(self, input_dict):
361
+
362
+ def add_diversity(seq_table, logprob, t, divm, diversity_lambda, bdash):
363
+ local_time = t - divm
364
+ unaug_logprob = logprob.clone()
365
+
366
+ if divm > 0:
367
+ change = torch.zeros(logprob.size(-1))
368
+ for prev_choice in range(divm):
369
+ prev_decisions = seq_table[prev_choice][..., local_time]
370
+ for prev_labels in range(bdash):
371
+ change.scatter_add_(0, prev_decisions[prev_labels], change.new_ones(1))
372
+
373
+ change = change.to(logprob.device)
374
+ logprob = logprob - repeat_tensor(change, bdash) * diversity_lambda
375
+
376
+ return logprob, unaug_logprob
377
+
378
+ output = self.prepare_output(input_dict)
379
+ group_size = input_dict["group_size"]
380
+ batch_size = output["seq"].size(0)
381
+ beam_size = input_dict["beam_size"]
382
+ bdash = beam_size // group_size
383
+ input_dict["bdash"] = bdash
384
+ diversity_lambda = input_dict["diversity_lambda"]
385
+ device = input_dict["fc_emb"].device
386
+ max_length = input_dict["max_length"]
387
+ temp = input_dict["temp"]
388
+ group_nbest = input_dict["group_nbest"]
389
+ batch_size, max_length = output["seq"].size()
390
+ if group_nbest:
391
+ output["seq"] = torch.full((batch_size, beam_size, max_length),
392
+ self.end_idx, dtype=torch.long)
393
+ else:
394
+ output["seq"] = torch.full((batch_size, group_size, max_length),
395
+ self.end_idx, dtype=torch.long)
396
+
397
+
398
+ for i in range(batch_size):
399
+ input_dict["sample_idx"] = i
400
+ seq_table = [torch.LongTensor(bdash, 0) for _ in range(group_size)] # group_size x [bdash, 0]
401
+ logprob_table = [torch.zeros(bdash).to(device) for _ in range(group_size)]
402
+ done_beams_table = [[] for _ in range(group_size)]
403
+
404
+ output_i = {
405
+ "prev_words_beam": [None for _ in range(group_size)],
406
+ "next_word": [None for _ in range(group_size)],
407
+ "state": [None for _ in range(group_size)]
408
+ }
409
+
410
+ for t in range(max_length + group_size - 1):
411
+ input_dict["t"] = t
412
+ for divm in range(group_size):
413
+ input_dict["divm"] = divm
414
+ if t >= divm and t <= max_length + divm - 1:
415
+ local_time = t - divm
416
+ decoder_input = self.prepare_dbs_decoder_input(input_dict, output_i)
417
+ output_t = self.decoder(decoder_input)
418
+ output_t["divm"] = divm
419
+ logit_t = output_t["logit"]
420
+ if logit_t.size(1) == 1:
421
+ logit_t = logit_t.squeeze(1)
422
+ elif logit_t.size(1) > 1:
423
+ logit_t = logit_t[:, -1, :]
424
+ else:
425
+ raise Exception("no logit output")
426
+ logprob_t = torch.log_softmax(logit_t, dim=1)
427
+ logprob_t = torch.log_softmax(logprob_t / temp, dim=1)
428
+ logprob_t, unaug_logprob_t = add_diversity(seq_table, logprob_t, t, divm, diversity_lambda, bdash)
429
+ logprob_t = logprob_table[divm].unsqueeze(-1) + logprob_t
430
+ if local_time == 0: # for the first step, all k seq will have the same probs
431
+ topk_logprob, topk_words = logprob_t[0].topk(
432
+ bdash, 0, True, True)
433
+ else: # unroll and find top logprob, and their unrolled indices
434
+ topk_logprob, topk_words = logprob_t.view(-1).topk(
435
+ bdash, 0, True, True)
436
+ topk_words = topk_words.cpu()
437
+ logprob_table[divm] = topk_logprob
438
+ output_i["prev_words_beam"][divm] = topk_words // self.vocab_size # [bdash,]
439
+ output_i["next_word"][divm] = topk_words % self.vocab_size # [bdash,]
440
+ if local_time > 0:
441
+ seq_table[divm] = seq_table[divm][output_i["prev_words_beam"][divm]]
442
+ seq_table[divm] = torch.cat([
443
+ seq_table[divm],
444
+ output_i["next_word"][divm].unsqueeze(-1)], -1)
445
+
446
+ is_end = seq_table[divm][:, t-divm] == self.end_idx
447
+ assert seq_table[divm].shape[-1] == t - divm + 1
448
+ if t == max_length + divm - 1:
449
+ is_end.fill_(1)
450
+ for beam_idx in range(bdash):
451
+ if is_end[beam_idx]:
452
+ final_beam = {
453
+ "seq": seq_table[divm][beam_idx].clone(),
454
+ "score": logprob_table[divm][beam_idx].item()
455
+ }
456
+ final_beam["score"] = final_beam["score"] / (t - divm + 1)
457
+ done_beams_table[divm].append(final_beam)
458
+ logprob_table[divm][is_end] -= 1000
459
+ self.dbs_process_step(output_i, output_t)
460
+ done_beams_table = [sorted(done_beams_table[divm], key=lambda x: -x["score"])[:bdash] for divm in range(group_size)]
461
+ if group_nbest:
462
+ done_beams = sum(done_beams_table, [])
463
+ else:
464
+ done_beams = [group_beam[0] for group_beam in done_beams_table]
465
+ for _, done_beam in enumerate(done_beams):
466
+ output["seq"][i, _, :len(done_beam["seq"])] = done_beam["seq"]
467
+
468
+ return output
469
+
470
+ def prepare_dbs_decoder_input(self, input_dict, output_i):
471
+ raise NotImplementedError
472
+
473
+ def dbs_process_step(self, output_i, output_t):
474
+ pass
475
+
476
+
477
+ class CaptionSequenceModel(nn.Module, CaptionMetaMixin):
478
+
479
+ def __init__(self, model, seq_output_size):
480
+ super().__init__()
481
+ self.model = model
482
+ if model.decoder.d_model != seq_output_size:
483
+ self.output_transform = nn.Linear(model.decoder.d_model, seq_output_size)
484
+ else:
485
+ self.output_transform = lambda x: x
486
+
487
+ def forward(self, input_dict):
488
+ output = self.model(input_dict)
489
+
490
+ if input_dict["mode"] == "train":
491
+ lens = input_dict["cap_len"] - 1
492
+ # seq_outputs: [N, d_model]
493
+ elif input_dict["mode"] == "inference":
494
+ if "sample_method" in input_dict and input_dict["sample_method"] == "beam":
495
+ return output
496
+ seq = output["seq"]
497
+ lens = torch.where(seq == self.model.end_idx, torch.zeros_like(seq), torch.ones_like(seq)).sum(dim=1)
498
+ else:
499
+ raise Exception("mode should be either 'train' or 'inference'")
500
+ seq_output = mean_with_lens(output["embed"], lens)
501
+ seq_output = self.output_transform(seq_output)
502
+ output["seq_output"] = seq_output
503
+ return output
504
+
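Note: for reference, a sketch of the inference-mode `input_dict` options accepted by `forward_decoder` and `sample_next_word` above, built from the defaults visible in the code (encoder inputs such as `wav`/`wav_len` are omitted here and depend on the encoder in use):

    greedy  = {"mode": "inference", "sample_method": "greedy", "temp": 1.0}
    nucleus = {"mode": "inference", "sample_method": "top0.9"}    # 0 < p < 1 -> top-p sampling
    topk    = {"mode": "inference", "sample_method": "top5"}      # integer -> top-k sampling
    beam    = {"mode": "inference", "sample_method": "beam", "beam_size": 3}
    diverse = {"mode": "inference", "sample_method": "dbs",
               "beam_size": 6, "group_size": 3, "diversity_lambda": 0.5}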
models/cnn_encoder.py ADDED
@@ -0,0 +1,810 @@
1
+ # -*- coding: utf-8 -*-
2
+
3
+ from einops import rearrange
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+ import torch.nn.functional as F
8
+ from torchaudio import transforms
9
+
10
+ from utils.model_util import mean_with_lens, max_with_lens
11
+ from utils.train_util import merge_load_state_dict
12
+
13
+
14
+ def init_layer(layer):
15
+ """Initialize a Linear or Convolutional layer. """
16
+ nn.init.xavier_uniform_(layer.weight)
17
+
18
+ if hasattr(layer, 'bias'):
19
+ if layer.bias is not None:
20
+ layer.bias.data.fill_(0.)
21
+
22
+
23
+ def init_bn(bn):
24
+ """Initialize a Batchnorm layer. """
25
+ bn.bias.data.fill_(0.)
26
+ bn.weight.data.fill_(1.)
27
+
28
+
29
+ class ConvBlock(nn.Module):
30
+ def __init__(self, in_channels, out_channels):
31
+
32
+ super(ConvBlock, self).__init__()
33
+
34
+ self.conv1 = nn.Conv2d(in_channels=in_channels,
35
+ out_channels=out_channels,
36
+ kernel_size=(3, 3), stride=(1, 1),
37
+ padding=(1, 1), bias=False)
38
+
39
+ self.conv2 = nn.Conv2d(in_channels=out_channels,
40
+ out_channels=out_channels,
41
+ kernel_size=(3, 3), stride=(1, 1),
42
+ padding=(1, 1), bias=False)
43
+
44
+ self.bn1 = nn.BatchNorm2d(out_channels)
45
+ self.bn2 = nn.BatchNorm2d(out_channels)
46
+
47
+ self.init_weight()
48
+
49
+ def init_weight(self):
50
+ init_layer(self.conv1)
51
+ init_layer(self.conv2)
52
+ init_bn(self.bn1)
53
+ init_bn(self.bn2)
54
+
55
+
56
+ def forward(self, input, pool_size=(2, 2), pool_type='avg'):
57
+
58
+ x = input
59
+ x = F.relu_(self.bn1(self.conv1(x)))
60
+ x = F.relu_(self.bn2(self.conv2(x)))
61
+ if pool_type == 'max':
62
+ x = F.max_pool2d(x, kernel_size=pool_size)
63
+ elif pool_type == 'avg':
64
+ x = F.avg_pool2d(x, kernel_size=pool_size)
65
+ elif pool_type == 'avg+max':
66
+ x1 = F.avg_pool2d(x, kernel_size=pool_size)
67
+ x2 = F.max_pool2d(x, kernel_size=pool_size)
68
+ x = x1 + x2
69
+ else:
70
+ raise Exception('Incorrect argument!')
71
+
72
+ return x
73
+
74
+
75
+ class ConvBlock5x5(nn.Module):
76
+ def __init__(self, in_channels, out_channels):
77
+
78
+ super(ConvBlock5x5, self).__init__()
79
+
80
+ self.conv1 = nn.Conv2d(in_channels=in_channels,
81
+ out_channels=out_channels,
82
+ kernel_size=(5, 5), stride=(1, 1),
83
+ padding=(2, 2), bias=False)
84
+
85
+ self.bn1 = nn.BatchNorm2d(out_channels)
86
+
87
+ self.init_weight()
88
+
89
+ def init_weight(self):
90
+ init_layer(self.conv1)
91
+ init_bn(self.bn1)
92
+
93
+ def forward(self, input, pool_size=(2, 2), pool_type='avg'):
94
+
95
+ x = input
96
+ x = F.relu_(self.bn1(self.conv1(x)))
97
+ if pool_type == 'max':
98
+ x = F.max_pool2d(x, kernel_size=pool_size)
99
+ elif pool_type == 'avg':
100
+ x = F.avg_pool2d(x, kernel_size=pool_size)
101
+ elif pool_type == 'avg+max':
102
+ x1 = F.avg_pool2d(x, kernel_size=pool_size)
103
+ x2 = F.max_pool2d(x, kernel_size=pool_size)
104
+ x = x1 + x2
105
+ else:
106
+ raise Exception('Incorrect argument!')
107
+
108
+ return x
109
+
110
+
111
+ class Cnn6Encoder(nn.Module):
112
+
113
+ def __init__(self, sample_rate=32000, freeze=False):
114
+ super().__init__()
115
+
116
+ sr_to_fmax = {
117
+ 32000: 14000,
118
+ 16000: 8000
119
+ }
120
+ # Logmel spectrogram extractor
121
+ self.melspec_extractor = transforms.MelSpectrogram(
122
+ sample_rate=sample_rate,
123
+ n_fft=32 * sample_rate // 1000,
124
+ win_length=32 * sample_rate // 1000,
125
+ hop_length=10 * sample_rate // 1000,
126
+ f_min=50,
127
+ f_max=sr_to_fmax[sample_rate],
128
+ n_mels=64,
129
+ norm="slaney",
130
+ mel_scale="slaney"
131
+ )
132
+ self.hop_length = 10 * sample_rate // 1000
133
+ self.db_transform = transforms.AmplitudeToDB()
134
+
135
+ self.bn0 = nn.BatchNorm2d(64)
136
+
137
+ self.conv_block1 = ConvBlock5x5(in_channels=1, out_channels=64)
138
+ self.conv_block2 = ConvBlock5x5(in_channels=64, out_channels=128)
139
+ self.conv_block3 = ConvBlock5x5(in_channels=128, out_channels=256)
140
+ self.conv_block4 = ConvBlock5x5(in_channels=256, out_channels=512)
141
+
142
+ self.downsample_ratio = 16
143
+
144
+ self.fc1 = nn.Linear(512, 512, bias=True)
145
+ self.fc_emb_size = 512
146
+ self.init_weight()
147
+ self.freeze = freeze
148
+
149
+ def init_weight(self):
150
+ init_bn(self.bn0)
151
+ init_layer(self.fc1)
152
+
153
+ def load_pretrained(self, pretrained, output_fn):
154
+ checkpoint = torch.load(pretrained, map_location="cpu")
155
+
156
+ if "model" in checkpoint:
157
+ state_dict = checkpoint["model"]
158
+ else:
159
+ raise Exception("Unknown checkpoint format")
160
+
161
+ loaded_keys = merge_load_state_dict(state_dict, self, output_fn)
162
+ if self.freeze:
163
+ for name, param in self.named_parameters():
164
+ if name in loaded_keys:
165
+ param.requires_grad = False
166
+ else:
167
+ param.requires_grad = True
168
+
169
+ def forward(self, input_dict):
170
+ waveform = input_dict["wav"]
171
+ wave_length = input_dict["wav_len"]
172
+ specaug = input_dict["specaug"]
173
+ x = self.melspec_extractor(waveform)
174
+ x = self.db_transform(x) # (batch_size, mel_bins, time_steps)
175
+ x = x.transpose(1, 2)
176
+ x = x.unsqueeze(1) # (batch_size, 1, time_steps, mel_bins)
177
+
178
+ x = x.transpose(1, 3)
179
+ x = self.bn0(x)
180
+ x = x.transpose(1, 3)
181
+
182
+ x = self.conv_block1(x, pool_size=(2, 2), pool_type='avg')
183
+ x = F.dropout(x, p=0.2, training=self.training)
184
+ x = self.conv_block2(x, pool_size=(2, 2), pool_type='avg')
185
+ x = F.dropout(x, p=0.2, training=self.training)
186
+ x = self.conv_block3(x, pool_size=(2, 2), pool_type='avg')
187
+ x = F.dropout(x, p=0.2, training=self.training)
188
+ x = self.conv_block4(x, pool_size=(2, 2), pool_type='avg')
189
+ x = F.dropout(x, p=0.2, training=self.training)
190
+
191
+ x = torch.mean(x, dim=3)
192
+ attn_emb = x.transpose(1, 2)
193
+ wave_length = torch.as_tensor(wave_length)
194
+ feat_length = torch.div(wave_length, self.hop_length,
195
+ rounding_mode="floor") + 1
196
+ feat_length = torch.div(feat_length, self.downsample_ratio,
197
+ rounding_mode="floor")
198
+ x_max = max_with_lens(attn_emb, feat_length)
199
+ x_mean = mean_with_lens(attn_emb, feat_length)
200
+ x = x_max + x_mean
201
+ x = F.dropout(x, p=0.5, training=self.training)
202
+ x = F.relu_(self.fc1(x))
203
+ fc_emb = F.dropout(x, p=0.5, training=self.training)
204
+
205
+ return {
206
+ "attn_emb": attn_emb,
207
+ "fc_emb": fc_emb,
208
+ "attn_emb_len": feat_length
209
+ }
210
+
211
+
212
+ class Cnn10Encoder(nn.Module):
213
+
214
+ def __init__(self, sample_rate=32000, freeze=False):
215
+ super().__init__()
216
+
217
+ sr_to_fmax = {
218
+ 32000: 14000,
219
+ 16000: 8000
220
+ }
221
+ # Logmel spectrogram extractor
222
+ self.melspec_extractor = transforms.MelSpectrogram(
223
+ sample_rate=sample_rate,
224
+ n_fft=32 * sample_rate // 1000,
225
+ win_length=32 * sample_rate // 1000,
226
+ hop_length=10 * sample_rate // 1000,
227
+ f_min=50,
228
+ f_max=sr_to_fmax[sample_rate],
229
+ n_mels=64,
230
+ norm="slaney",
231
+ mel_scale="slaney"
232
+ )
233
+ self.hop_length = 10 * sample_rate // 1000
234
+ self.db_transform = transforms.AmplitudeToDB()
235
+
236
+ self.bn0 = nn.BatchNorm2d(64)
237
+
238
+ self.conv_block1 = ConvBlock(in_channels=1, out_channels=64)
239
+ self.conv_block2 = ConvBlock(in_channels=64, out_channels=128)
240
+ self.conv_block3 = ConvBlock(in_channels=128, out_channels=256)
241
+ self.conv_block4 = ConvBlock(in_channels=256, out_channels=512)
242
+
243
+ self.downsample_ratio = 16
244
+
245
+ self.fc1 = nn.Linear(512, 512, bias=True)
246
+ self.fc_emb_size = 512
247
+ self.init_weight()
248
+ self.freeze = freeze
249
+
250
+ def init_weight(self):
251
+ init_bn(self.bn0)
252
+ init_layer(self.fc1)
253
+
254
+ def load_pretrained(self, pretrained, output_fn):
255
+ checkpoint = torch.load(pretrained, map_location="cpu")
256
+
257
+ if "model" in checkpoint:
258
+ state_dict = checkpoint["model"]
259
+ else:
260
+ raise Exception("Unknown checkpoint format")
261
+
262
+ loaded_keys = merge_load_state_dict(state_dict, self, output_fn)
263
+ if self.freeze:
264
+ for name, param in self.named_parameters():
265
+ if name in loaded_keys:
266
+ param.requires_grad = False
267
+ else:
268
+ param.requires_grad = True
269
+
270
+ def forward(self, input_dict):
271
+ waveform = input_dict["wav"]
272
+ wave_length = input_dict["wav_len"]
273
+ specaug = input_dict["specaug"]
274
+ x = self.melspec_extractor(waveform)
275
+ x = self.db_transform(x) # (batch_size, mel_bins, time_steps)
276
+ x = x.transpose(1, 2)
277
+ x = x.unsqueeze(1) # (batch_size, 1, time_steps, mel_bins)
278
+
279
+ x = x.transpose(1, 3)
280
+ x = self.bn0(x)
281
+ x = x.transpose(1, 3)
282
+
283
+ x = self.conv_block1(x, pool_size=(2, 2), pool_type='avg')
284
+ x = F.dropout(x, p=0.2, training=self.training)
285
+ x = self.conv_block2(x, pool_size=(2, 2), pool_type='avg')
286
+ x = F.dropout(x, p=0.2, training=self.training)
287
+ x = self.conv_block3(x, pool_size=(2, 2), pool_type='avg')
288
+ x = F.dropout(x, p=0.2, training=self.training)
289
+ x = self.conv_block4(x, pool_size=(2, 2), pool_type='avg')
290
+ x = F.dropout(x, p=0.2, training=self.training)
291
+
292
+ x = torch.mean(x, dim=3)
293
+ attn_emb = x.transpose(1, 2)
294
+ wave_length = torch.as_tensor(wave_length)
295
+ feat_length = torch.div(wave_length, self.hop_length,
296
+ rounding_mode="floor") + 1
297
+ feat_length = torch.div(feat_length, self.downsample_ratio,
298
+ rounding_mode="floor")
299
+ x_max = max_with_lens(attn_emb, feat_length)
300
+ x_mean = mean_with_lens(attn_emb, feat_length)
301
+ x = x_max + x_mean
302
+ x = F.dropout(x, p=0.5, training=self.training)
303
+ x = F.relu_(self.fc1(x))
304
+ fc_emb = F.dropout(x, p=0.5, training=self.training)
305
+
306
+ return {
307
+ "attn_emb": attn_emb,
308
+ "fc_emb": fc_emb,
309
+ "attn_emb_len": feat_length
310
+ }
311
+
312
+
313
+ class Cnn14Encoder(nn.Module):
314
+ def __init__(self, sample_rate=32000, freeze=False):
315
+ super().__init__()
316
+ sr_to_fmax = {
317
+ 32000: 14000,
318
+ 16000: 8000
319
+ }
320
+ # Logmel spectrogram extractor
321
+ self.melspec_extractor = transforms.MelSpectrogram(
322
+ sample_rate=sample_rate,
323
+ n_fft=32 * sample_rate // 1000,
324
+ win_length=32 * sample_rate // 1000,
325
+ hop_length=10 * sample_rate // 1000,
326
+ f_min=50,
327
+ f_max=sr_to_fmax[sample_rate],
328
+ n_mels=64,
329
+ norm="slaney",
330
+ mel_scale="slaney"
331
+ )
332
+ self.hop_length = 10 * sample_rate // 1000
333
+ self.db_transform = transforms.AmplitudeToDB()
334
+
335
+ self.bn0 = nn.BatchNorm2d(64)
336
+
337
+ self.conv_block1 = ConvBlock(in_channels=1, out_channels=64)
338
+ self.conv_block2 = ConvBlock(in_channels=64, out_channels=128)
339
+ self.conv_block3 = ConvBlock(in_channels=128, out_channels=256)
340
+ self.conv_block4 = ConvBlock(in_channels=256, out_channels=512)
341
+ self.conv_block5 = ConvBlock(in_channels=512, out_channels=1024)
342
+ self.conv_block6 = ConvBlock(in_channels=1024, out_channels=2048)
343
+
344
+ self.downsample_ratio = 32
345
+
346
+ self.fc1 = nn.Linear(2048, 2048, bias=True)
347
+ self.fc_emb_size = 2048
348
+
349
+ self.init_weight()
350
+ self.freeze = freeze
351
+
352
+ def init_weight(self):
353
+ init_bn(self.bn0)
354
+ init_layer(self.fc1)
355
+
356
+ def load_pretrained(self, pretrained, output_fn):
357
+ checkpoint = torch.load(pretrained, map_location="cpu")
358
+
359
+ if "model" in checkpoint:
360
+ state_keys = checkpoint["model"].keys()
361
+ backbone = False
362
+ for key in state_keys:
363
+ if key.startswith("backbone."):
364
+ backbone = True
365
+ break
366
+
367
+ if backbone: # COLA
368
+ state_dict = {}
369
+ for key, value in checkpoint["model"].items():
370
+ if key.startswith("backbone."):
371
+ model_key = key.replace("backbone.", "")
372
+ state_dict[model_key] = value
373
+ else: # PANNs
374
+ state_dict = checkpoint["model"]
375
+ elif "state_dict" in checkpoint: # BLAT
376
+ state_dict = checkpoint["state_dict"]
377
+ state_dict_keys = list(filter(
378
+ lambda x: "audio_encoder" in x, state_dict.keys()))
379
+ state_dict = {
380
+ key.replace('audio_encoder.', ''): state_dict[key]
381
+ for key in state_dict_keys
382
+ }
383
+ else:
384
+ raise Exception("Unknown checkpoint format")
385
+
386
+ loaded_keys = merge_load_state_dict(state_dict, self, output_fn)
387
+ if self.freeze:
388
+ for name, param in self.named_parameters():
389
+ if name in loaded_keys:
390
+ param.requires_grad = False
391
+ else:
392
+ param.requires_grad = True
393
+
394
+ def forward(self, input_dict):
395
+ waveform = input_dict["wav"]
396
+ wave_length = input_dict["wav_len"]
397
+ specaug = input_dict["specaug"]
398
+ x = self.melspec_extractor(waveform)
399
+ x = self.db_transform(x) # (batch_size, mel_bins, time_steps)
400
+ x = x.transpose(1, 2)
401
+ x = x.unsqueeze(1) # (batch_size, 1, time_steps, mel_bins)
402
+
403
+ x = x.transpose(1, 3)
404
+ x = self.bn0(x)
405
+ x = x.transpose(1, 3)
406
+
407
+ x = self.conv_block1(x, pool_size=(2, 2), pool_type='avg')
408
+ x = F.dropout(x, p=0.2, training=self.training)
409
+ x = self.conv_block2(x, pool_size=(2, 2), pool_type='avg')
410
+ x = F.dropout(x, p=0.2, training=self.training)
411
+ x = self.conv_block3(x, pool_size=(2, 2), pool_type='avg')
412
+ x = F.dropout(x, p=0.2, training=self.training)
413
+ x = self.conv_block4(x, pool_size=(2, 2), pool_type='avg')
414
+ x = F.dropout(x, p=0.2, training=self.training)
415
+ x = self.conv_block5(x, pool_size=(2, 2), pool_type='avg')
416
+ x = F.dropout(x, p=0.2, training=self.training)
417
+ x = self.conv_block6(x, pool_size=(1, 1), pool_type='avg')
418
+ x = F.dropout(x, p=0.2, training=self.training)
419
+ x = torch.mean(x, dim=3)
420
+ attn_emb = x.transpose(1, 2)
421
+
422
+ wave_length = torch.as_tensor(wave_length)
423
+ feat_length = torch.div(wave_length, self.hop_length,
424
+ rounding_mode="floor") + 1
425
+ feat_length = torch.div(feat_length, self.downsample_ratio,
426
+ rounding_mode="floor")
427
+ x_max = max_with_lens(attn_emb, feat_length)
428
+ x_mean = mean_with_lens(attn_emb, feat_length)
429
+ x = x_max + x_mean
430
+ x = F.dropout(x, p=0.5, training=self.training)
431
+ x = F.relu_(self.fc1(x))
432
+ fc_emb = F.dropout(x, p=0.5, training=self.training)
433
+
434
+ output_dict = {
435
+ 'fc_emb': fc_emb,
436
+ 'attn_emb': attn_emb,
437
+ 'attn_emb_len': feat_length
438
+ }
439
+
440
+ return output_dict
441
+
442
+
443
+ class InvertedResidual(nn.Module):
444
+
445
+ def __init__(self, inp, oup, stride, expand_ratio):
446
+ super().__init__()
447
+ self.stride = stride
448
+ assert stride in [1, 2]
449
+
450
+ hidden_dim = round(inp * expand_ratio)
451
+ self.use_res_connect = self.stride == 1 and inp == oup
452
+
453
+ if expand_ratio == 1:
454
+ _layers = [
455
+ nn.Conv2d(hidden_dim, hidden_dim, 3, 1, 1, groups=hidden_dim, bias=False),
456
+ nn.AvgPool2d(stride),
457
+ nn.BatchNorm2d(hidden_dim),
458
+ nn.ReLU6(inplace=True),
459
+ nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
460
+ nn.BatchNorm2d(oup)
461
+ ]
462
+ _layers = nn.Sequential(*_layers)
463
+ init_layer(_layers[0])
464
+ init_bn(_layers[2])
465
+ init_layer(_layers[4])
466
+ init_bn(_layers[5])
467
+ self.conv = _layers
468
+ else:
469
+ _layers = [
470
+ nn.Conv2d(inp, hidden_dim, 1, 1, 0, bias=False),
471
+ nn.BatchNorm2d(hidden_dim),
472
+ nn.ReLU6(inplace=True),
473
+ nn.Conv2d(hidden_dim, hidden_dim, 3, 1, 1, groups=hidden_dim, bias=False),
474
+ nn.AvgPool2d(stride),
475
+ nn.BatchNorm2d(hidden_dim),
476
+ nn.ReLU6(inplace=True),
477
+ nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
478
+ nn.BatchNorm2d(oup)
479
+ ]
480
+ _layers = nn.Sequential(*_layers)
481
+ init_layer(_layers[0])
482
+ init_bn(_layers[1])
483
+ init_layer(_layers[3])
484
+ init_bn(_layers[5])
485
+ init_layer(_layers[7])
486
+ init_bn(_layers[8])
487
+ self.conv = _layers
488
+
489
+ def forward(self, x):
490
+ if self.use_res_connect:
491
+ return x + self.conv(x)
492
+ else:
493
+ return self.conv(x)
494
+
495
+
496
+ class MobileNetV2(nn.Module):
497
+ def __init__(self, sample_rate):
498
+
499
+ super().__init__()
500
+
501
+ sr_to_fmax = {
502
+ 32000: 14000,
503
+ 16000: 8000
504
+ }
505
+ # Logmel spectrogram extractor
506
+ self.melspec_extractor = transforms.MelSpectrogram(
507
+ sample_rate=sample_rate,
508
+ n_fft=32 * sample_rate // 1000,
509
+ win_length=32 * sample_rate // 1000,
510
+ hop_length=10 * sample_rate // 1000,
511
+ f_min=50,
512
+ f_max=sr_to_fmax[sample_rate],
513
+ n_mels=64,
514
+ norm="slaney",
515
+ mel_scale="slaney"
516
+ )
517
+ self.hop_length = 10 * sample_rate // 1000
518
+ self.db_transform = transforms.AmplitudeToDB()
519
+
520
+ self.bn0 = nn.BatchNorm2d(64)
521
+
522
+ width_mult=1.
523
+ block = InvertedResidual
524
+ input_channel = 32
525
+ last_channel = 1280
526
+ interverted_residual_setting = [
527
+ # t, c, n, s
528
+ [1, 16, 1, 1],
529
+ [6, 24, 2, 2],
530
+ [6, 32, 3, 2],
531
+ [6, 64, 4, 2],
532
+ [6, 96, 3, 2],
533
+ [6, 160, 3, 1],
534
+ [6, 320, 1, 1],
535
+ ]
536
+
537
+ self.downsample_ratio = 32
538
+
539
+ def conv_bn(inp, oup, stride):
540
+ _layers = [
541
+ nn.Conv2d(inp, oup, 3, 1, 1, bias=False),
542
+ nn.AvgPool2d(stride),
543
+ nn.BatchNorm2d(oup),
544
+ nn.ReLU6(inplace=True)
545
+ ]
546
+ _layers = nn.Sequential(*_layers)
547
+ init_layer(_layers[0])
548
+ init_bn(_layers[2])
549
+ return _layers
550
+
551
+
552
+ def conv_1x1_bn(inp, oup):
553
+ _layers = nn.Sequential(
554
+ nn.Conv2d(inp, oup, 1, 1, 0, bias=False),
555
+ nn.BatchNorm2d(oup),
556
+ nn.ReLU6(inplace=True)
557
+ )
558
+ init_layer(_layers[0])
559
+ init_bn(_layers[1])
560
+ return _layers
561
+
562
+ # building first layer
563
+ input_channel = int(input_channel * width_mult)
564
+ self.last_channel = int(last_channel * width_mult) if width_mult > 1.0 else last_channel
565
+ self.features = [conv_bn(1, input_channel, 2)]
566
+ # building inverted residual blocks
567
+ for t, c, n, s in interverted_residual_setting:
568
+ output_channel = int(c * width_mult)
569
+ for i in range(n):
570
+ if i == 0:
571
+ self.features.append(block(input_channel, output_channel, s, expand_ratio=t))
572
+ else:
573
+ self.features.append(block(input_channel, output_channel, 1, expand_ratio=t))
574
+ input_channel = output_channel
575
+ # building last several layers
576
+ self.features.append(conv_1x1_bn(input_channel, self.last_channel))
577
+ # make it nn.Sequential
578
+ self.features = nn.Sequential(*self.features)
579
+
580
+ self.fc1 = nn.Linear(1280, 1024, bias=True)
581
+
582
+ self.init_weight()
583
+
584
+ def init_weight(self):
585
+ init_bn(self.bn0)
586
+ init_layer(self.fc1)
587
+
588
+ def forward(self, input_dict):
589
+
590
+ waveform = input_dict["wav"]
591
+ wave_length = input_dict["wav_len"]
592
+ specaug = input_dict["specaug"]
593
+ x = self.melspec_extractor(waveform)
594
+ x = self.db_transform(x) # (batch_size, mel_bins, time_steps)
595
+ x = x.transpose(1, 2)
596
+ x = x.unsqueeze(1) # (batch_size, 1, time_steps, mel_bins)
597
+
598
+ x = x.transpose(1, 3)
599
+ x = self.bn0(x)
600
+ x = x.transpose(1, 3)
601
+
602
+ x = self.features(x)
603
+
604
+ x = torch.mean(x, dim=3)
605
+ attn_emb = x.transpose(1, 2)
606
+
607
+ wave_length = torch.as_tensor(wave_length)
608
+ feat_length = torch.div(wave_length, self.hop_length,
609
+ rounding_mode="floor") + 1
610
+ feat_length = torch.div(feat_length, self.downsample_ratio,
611
+ rounding_mode="floor")
612
+ x_max = max_with_lens(attn_emb, feat_length)
613
+ x_mean = mean_with_lens(attn_emb, feat_length)
614
+ x = x_max + x_mean
615
+ # TODO: the original PANNs code does not have dropout here, why?
616
+ x = F.dropout(x, p=0.5, training=self.training)
617
+ x = F.relu_(self.fc1(x))
618
+ fc_emb = F.dropout(x, p=0.5, training=self.training)
619
+
620
+ output_dict = {
621
+ 'fc_emb': fc_emb,
622
+ 'attn_emb': attn_emb,
623
+ 'attn_emb_len': feat_length
624
+ }
625
+
626
+ return output_dict
627
+
628
+
629
+ class MobileNetV3(nn.Module):
630
+
631
+ def __init__(self,
632
+ sample_rate,
633
+ model_name,
634
+ n_mels=64,
635
+ win_length=32,
636
+ pretrained=True,
637
+ freeze=False,
638
+ pooling="mean_max_fc"):
639
+
640
+ from captioning.models.eff_at_encoder import get_model, NAME_TO_WIDTH
641
+
642
+ super().__init__()
643
+ sr_to_fmax = {
644
+ 32000: 14000,
645
+ 16000: 8000
646
+ }
647
+ self.n_mels = n_mels
648
+ # Logmel spectrogram extractor
649
+ self.melspec_extractor = transforms.MelSpectrogram(
650
+ sample_rate=sample_rate,
651
+ n_fft=32 * sample_rate // 1000,
652
+ win_length=win_length * sample_rate // 1000,
653
+ hop_length=10 * sample_rate // 1000,
654
+ f_min=50,
655
+ f_max=sr_to_fmax[sample_rate],
656
+ n_mels=n_mels,
657
+ norm="slaney",
658
+ mel_scale="slaney"
659
+ )
660
+ self.hop_length = 10 * sample_rate // 1000
661
+ self.db_transform = transforms.AmplitudeToDB()
662
+
663
+ self.bn0 = nn.BatchNorm2d(n_mels)
664
+
665
+ width_mult = NAME_TO_WIDTH(model_name)
666
+ self.features = get_model(model_name=model_name,
667
+ pretrained=pretrained,
668
+ width_mult=width_mult).features
669
+ self.downsample_ratio = 32
670
+
671
+ if pooling == "mean_max_fc":
672
+ self.fc_emb_size = 512
673
+ self.fc1 = nn.Linear(self.features[-1].out_channels, 512, bias=True)
674
+ elif pooling == "mean":
675
+ self.fc_emb_size = self.features[-1].out_channels
676
+ self.init_weight()
677
+
678
+ if freeze:
679
+ for param in self.parameters():
680
+ param.requires_grad = False
681
+
682
+ self.pooling = pooling
683
+
684
+ def init_weight(self):
685
+ init_bn(self.bn0)
686
+ if hasattr(self, "fc1"):
687
+ init_layer(self.fc1)
688
+
689
+ def forward(self, input_dict):
690
+
691
+ waveform = input_dict["wav"]
692
+ wave_length = input_dict["wav_len"]
693
+ specaug = input_dict["specaug"]
694
+ x = self.melspec_extractor(waveform)
695
+ x = self.db_transform(x) # (batch_size, mel_bins, time_steps)
696
+ x = x.transpose(1, 2)
697
+ x = x.unsqueeze(1) # (batch_size, 1, time_steps, mel_bins)
698
+
699
+ x = x.transpose(1, 3)
700
+ x = self.bn0(x)
701
+ x = x.transpose(1, 3)
702
+
703
+ x = self.features(x)
704
+
705
+ x = torch.mean(x, dim=3)
706
+ attn_emb = x.transpose(1, 2)
707
+
708
+ wave_length = torch.as_tensor(wave_length)
709
+ feat_length = torch.div(wave_length, self.hop_length,
710
+ rounding_mode="floor") + 1
711
+ feat_length = torch.div(feat_length, self.downsample_ratio,
712
+ rounding_mode="floor")
713
+
714
+ if self.pooling == "mean_max_fc":
715
+ x_max = max_with_lens(attn_emb, feat_length)
716
+ x_mean = mean_with_lens(attn_emb, feat_length)
717
+ x = x_max + x_mean
718
+ x = F.dropout(x, p=0.5, training=self.training)
719
+ x = F.relu_(self.fc1(x))
720
+ fc_emb = F.dropout(x, p=0.5, training=self.training)
721
+ elif self.pooling == "mean":
722
+ fc_emb = mean_with_lens(attn_emb, feat_length)
723
+
724
+ output_dict = {
725
+ 'fc_emb': fc_emb,
726
+ 'attn_emb': attn_emb,
727
+ 'attn_emb_len': feat_length
728
+ }
729
+
730
+ return output_dict
731
+
732
+
733
+ class EfficientNetB2(nn.Module):
734
+
735
+ def __init__(self,
736
+ n_mels: int = 64,
737
+ win_length: int = 32,
738
+ hop_length: int = 10,
739
+ f_min: int = 0,
740
+ pretrained: bool = False,
741
+ prune_ratio: float = 0.0,
742
+ prune_se: bool = True,
743
+ prune_start_layer: int = 0,
744
+ prune_method: str = "operator_norm",
745
+ freeze: bool = False,):
746
+ from models.eff_latent_encoder import get_model, get_pruned_model
747
+ super().__init__()
748
+ sample_rate = 16000
749
+ self.melspec_extractor = transforms.MelSpectrogram(
750
+ sample_rate=sample_rate,
751
+ n_fft=win_length * sample_rate // 1000,
752
+ win_length=win_length * sample_rate // 1000,
753
+ hop_length=hop_length * sample_rate // 1000,
754
+ f_min=f_min,
755
+ n_mels=n_mels,
756
+ )
757
+ self.hop_length = 10 * sample_rate // 1000
758
+ self.db_transform = transforms.AmplitudeToDB(top_db=120)
759
+ if prune_ratio > 0:
760
+ self.backbone = get_pruned_model(pretrained=pretrained,
761
+ prune_ratio=prune_ratio,
762
+ prune_start_layer=prune_start_layer,
763
+ prune_se=prune_se,
764
+ prune_method=prune_method)
765
+ else:
766
+ self.backbone = get_model(pretrained=pretrained)
767
+ self.fc_emb_size = self.backbone.eff_net._conv_head.out_channels
768
+ self.downsample_ratio = 32
769
+ if freeze:
770
+ for param in self.parameters():
771
+ param.requires_grad = False
772
+
773
+ def forward(self, input_dict):
774
+
775
+ waveform = input_dict["wav"]
776
+ wave_length = input_dict["wav_len"]
777
+ specaug = input_dict["specaug"]
778
+ x = self.melspec_extractor(waveform)
779
+ x = self.db_transform(x) # (batch_size, mel_bins, time_steps)
780
+
781
+ x = self.backbone(x)
782
+ attn_emb = x
783
+
784
+ wave_length = torch.as_tensor(wave_length)
785
+ feat_length = torch.div(wave_length, self.hop_length,
786
+ rounding_mode="floor") + 1
787
+ feat_length = torch.div(feat_length, self.downsample_ratio,
788
+ rounding_mode="floor")
789
+ fc_emb = mean_with_lens(attn_emb, feat_length)
790
+
791
+ output_dict = {
792
+ 'fc_emb': fc_emb,
793
+ 'attn_emb': attn_emb,
794
+ 'attn_emb_len': feat_length
795
+ }
796
+ return output_dict
797
+
798
+
799
+ if __name__ == "__main__":
800
+ encoder = MobileNetV3(32000, "mn10_as")
801
+ print(encoder)
802
+ input_dict = {
803
+ "wav": torch.randn(4, 320000),
804
+ "wav_len": torch.tensor([320000, 280000, 160000, 300000]),
805
+ "specaug": True
806
+ }
807
+ output_dict = encoder(input_dict)
808
+ print("attn embed: ", output_dict["attn_emb"].shape)
809
+ print("fc embed: ", output_dict["fc_emb"].shape)
810
+ print("attn embed length: ", output_dict["attn_emb_len"])
models/eff_latent_encoder.py ADDED
@@ -0,0 +1,347 @@
1
+ import os
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ from tqdm import tqdm
6
+ from efficientnet_pytorch import EfficientNet
7
+ from efficientnet_pytorch.model import MBConvBlock
8
+ from efficientnet_pytorch import utils as efficientnet_utils
9
+ from efficientnet_pytorch.utils import (
10
+ round_filters,
11
+ round_repeats,
12
+ get_same_padding_conv2d,
13
+ calculate_output_image_size,
14
+ MemoryEfficientSwish,
15
+ )
16
+ from einops import rearrange, reduce
17
+ from torch.hub import load_state_dict_from_url
18
+
19
+
20
+ model_dir = "./"
21
+
22
+
23
+ class _EffiNet(nn.Module):
24
+ """A proxy for efficient net models"""
25
+ def __init__(self,
26
+ blocks_args=None,
27
+ global_params=None,
28
+ prune_start_layer: int = 0,
29
+ prune_se: bool = True,
30
+ prune_ratio: float = 0.0
31
+ ) -> None:
32
+ super().__init__()
33
+ if prune_ratio > 0:
34
+ self.eff_net = EfficientNetB2Pruned(blocks_args=blocks_args,
35
+ global_params=global_params,
36
+ prune_start_layer=prune_start_layer,
37
+ prune_se=prune_se,
38
+ prune_ratio=prune_ratio)
39
+ else:
40
+ self.eff_net = EfficientNet(blocks_args=blocks_args,
41
+ global_params=global_params)
42
+
43
+
44
+ def forward(self, x: torch.Tensor):
45
+ x = rearrange(x, 'b f t -> b 1 f t')
46
+ x = self.eff_net.extract_features(x)
47
+ return reduce(x, 'b c f t -> b t c', 'mean')
48
+
49
+
50
+ def get_model(pretrained=True) -> _EffiNet:
51
+ blocks_args, global_params = efficientnet_utils.get_model_params(
52
+ 'efficientnet-b2', {'include_top': False})
53
+ model = _EffiNet(blocks_args=blocks_args,
54
+ global_params=global_params)
55
+ model.eff_net._change_in_channels(1)
56
+ if pretrained:
57
+ model_path = os.path.join(model_dir, "effb2.pt")
58
+ if not os.path.exists(model_path):
59
+ state_dict = load_state_dict_from_url(
60
+ 'https://github.com/richermans/HEAR2021_EfficientLatent/releases/download/v0.0.1/effb2.pt',
61
+ progress=True,
62
+ model_dir=model_dir)
63
+ else:
64
+ state_dict = torch.load(model_path)
65
+ del_keys = [key for key in state_dict if key.startswith("front_end")]
66
+ for key in del_keys:
67
+ del state_dict[key]
68
+ model.eff_net.load_state_dict(state_dict)
69
+ return model
70
+
71
+
72
+ class MBConvBlockPruned(MBConvBlock):
73
+
74
+ def __init__(self, block_args, global_params, image_size=None, prune_ratio=0.5, prune_se=True):
75
+ super(MBConvBlock, self).__init__()
76
+ self._block_args = block_args
77
+ self._bn_mom = 1 - global_params.batch_norm_momentum # pytorch's difference from tensorflow
78
+ self._bn_eps = global_params.batch_norm_epsilon
79
+ self.has_se = (self._block_args.se_ratio is not None) and (0 < self._block_args.se_ratio <= 1)
80
+ self.id_skip = block_args.id_skip # whether to use skip connection and drop connect
81
+
82
+ # Expansion phase (Inverted Bottleneck)
83
+ inp = self._block_args.input_filters # number of input channels
84
+ oup = self._block_args.input_filters * self._block_args.expand_ratio # number of output channels
85
+ if self._block_args.expand_ratio != 1:
86
+ oup = int(oup * (1 - prune_ratio))
87
+ Conv2d = get_same_padding_conv2d(image_size=image_size)
88
+ self._expand_conv = Conv2d(in_channels=inp, out_channels=oup, kernel_size=1, bias=False)
89
+ self._bn0 = nn.BatchNorm2d(num_features=oup, momentum=self._bn_mom, eps=self._bn_eps)
90
+ # image_size = calculate_output_image_size(image_size, 1) <-- this wouldn't modify image_size
91
+
92
+ # Depthwise convolution phase
93
+ k = self._block_args.kernel_size
94
+ s = self._block_args.stride
95
+ Conv2d = get_same_padding_conv2d(image_size=image_size)
96
+ self._depthwise_conv = Conv2d(
97
+ in_channels=oup, out_channels=oup, groups=oup, # groups makes it depthwise
98
+ kernel_size=k, stride=s, bias=False)
99
+ self._bn1 = nn.BatchNorm2d(num_features=oup, momentum=self._bn_mom, eps=self._bn_eps)
100
+ image_size = calculate_output_image_size(image_size, s)
101
+
102
+ # Squeeze and Excitation layer, if desired
103
+ if self.has_se:
104
+ Conv2d = get_same_padding_conv2d(image_size=(1, 1))
105
+ num_squeezed_channels = max(1, int(self._block_args.input_filters * self._block_args.se_ratio))
106
+ if prune_se:
107
+ num_squeezed_channels = int(num_squeezed_channels * (1 - prune_ratio))
108
+ self._se_reduce = Conv2d(in_channels=oup, out_channels=num_squeezed_channels, kernel_size=1)
109
+ self._se_expand = Conv2d(in_channels=num_squeezed_channels, out_channels=oup, kernel_size=1)
110
+
111
+ # Pointwise convolution phase
112
+ final_oup = self._block_args.output_filters
113
+ Conv2d = get_same_padding_conv2d(image_size=image_size)
114
+ self._project_conv = Conv2d(in_channels=oup, out_channels=final_oup, kernel_size=1, bias=False)
115
+ self._bn2 = nn.BatchNorm2d(num_features=final_oup, momentum=self._bn_mom, eps=self._bn_eps)
116
+ self._swish = MemoryEfficientSwish()
117
+
118
+
119
+ class EfficientNetB2Pruned(EfficientNet):
120
+
121
+ def __init__(self, blocks_args=None, global_params=None,
122
+ prune_start_layer=0, prune_ratio=0.5, prune_se=True):
123
+ super(EfficientNet, self).__init__()
124
+ assert isinstance(blocks_args, list), 'blocks_args should be a list'
125
+ assert len(blocks_args) > 0, 'block args must be greater than 0'
126
+ self._global_params = global_params
127
+ self._blocks_args = blocks_args
128
+
129
+ # Batch norm parameters
130
+ bn_mom = 1 - self._global_params.batch_norm_momentum
131
+ bn_eps = self._global_params.batch_norm_epsilon
132
+
133
+ # Get stem static or dynamic convolution depending on image size
134
+ image_size = global_params.image_size
135
+ Conv2d = get_same_padding_conv2d(image_size=image_size)
136
+
137
+ n_build_blks = 0
138
+ # Stem
139
+ in_channels = 1 # spectrogram
140
+
141
+ p = 0.0 if n_build_blks < prune_start_layer else prune_ratio
142
+ out_channels = round_filters(32 * (1 - p),
143
+ self._global_params) # number of output channels
144
+ self._conv_stem = Conv2d(in_channels, out_channels, kernel_size=3, stride=2, bias=False)
145
+ self._bn0 = nn.BatchNorm2d(num_features=out_channels, momentum=bn_mom, eps=bn_eps)
146
+ image_size = calculate_output_image_size(image_size, 2)
147
+ n_build_blks += 1
148
+
149
+ # Build blocks
150
+ self._blocks = nn.ModuleList([])
151
+ for block_args in self._blocks_args:
152
+
153
+ p = 0.0 if n_build_blks < prune_start_layer else prune_ratio
154
+ orig_input_filters = block_args.input_filters
155
+ # Update block input and output filters based on depth multiplier.
156
+ block_args = block_args._replace(
157
+ input_filters=round_filters(
158
+ block_args.input_filters * (1 - p),
159
+ self._global_params),
160
+ output_filters=round_filters(
161
+ block_args.output_filters * (1 - p),
162
+ self._global_params),
163
+ num_repeat=round_repeats(block_args.num_repeat, self._global_params)
164
+ )
165
+
166
+ if n_build_blks == prune_start_layer:
167
+ block_args = block_args._replace(input_filters=round_filters(
168
+ orig_input_filters,
169
+ self._global_params)
170
+ )
171
+
172
+ # The first block needs to take care of stride and filter size increase.
173
+ self._blocks.append(MBConvBlockPruned(block_args, self._global_params,
174
+ image_size=image_size, prune_ratio=p,
175
+ prune_se=prune_se))
176
+ n_build_blks += 1
177
+
178
+ image_size = calculate_output_image_size(image_size, block_args.stride)
179
+ if block_args.num_repeat > 1: # modify block_args to keep same output size
180
+ block_args = block_args._replace(input_filters=block_args.output_filters, stride=1)
181
+ for _ in range(block_args.num_repeat - 1):
182
+ self._blocks.append(MBConvBlockPruned(block_args,
183
+ self._global_params,
184
+ image_size=image_size,
185
+ prune_ratio=p,
186
+ prune_se=prune_se))
187
+ # image_size = calculate_output_image_size(image_size, block_args.stride) # stride = 1
188
+
189
+ # Head
190
+ in_channels = block_args.output_filters # output of final block
191
+ p = 0.0 if n_build_blks < prune_start_layer else prune_ratio
192
+ out_channels = round_filters(1280 * (1 - p), self._global_params)
193
+ Conv2d = get_same_padding_conv2d(image_size=image_size)
194
+ self._conv_head = Conv2d(in_channels, out_channels, kernel_size=1, bias=False)
195
+ self._bn1 = nn.BatchNorm2d(num_features=out_channels, momentum=bn_mom, eps=bn_eps)
196
+
197
+ # Final linear layer
198
+ self._avg_pooling = nn.AdaptiveAvgPool2d(1)
199
+ if self._global_params.include_top:
200
+ self._dropout = nn.Dropout(self._global_params.dropout_rate)
201
+ self._fc = nn.Linear(out_channels, self._global_params.num_classes)
202
+
203
+ # set activation to memory efficient swish by default
204
+ self._swish = MemoryEfficientSwish()
205
+
206
+
207
+ def get_pruned_model(pretrained: bool = True,
208
+ prune_ratio: float = 0.5,
209
+ prune_start_layer: int = 0,
210
+ prune_se: bool = True,
211
+ prune_method: str = "operator_norm") -> _EffiNet:
212
+
213
+ import captioning.models.conv_filter_pruning as pruning_lib
214
+
215
+ blocks_args, global_params = efficientnet_utils.get_model_params(
216
+ 'efficientnet-b2', {'include_top': False})
217
+ # print("num blocks: ", len(blocks_args))
218
+ # print("block args: ")
219
+ # for block_arg in blocks_args:
220
+ # print(block_arg)
221
+ model = _EffiNet(blocks_args=blocks_args,
222
+ global_params=global_params,
223
+ prune_start_layer=prune_start_layer,
224
+ prune_se=prune_se,
225
+ prune_ratio=prune_ratio)
226
+
227
+ if prune_method == "operator_norm":
228
+ filter_pruning = pruning_lib.operator_norm_pruning
229
+ elif prune_method == "interspeech":
230
+ filter_pruning = pruning_lib.cs_interspeech
231
+ elif prune_method == "iclr_l1":
232
+ filter_pruning = pruning_lib.iclr_l1
233
+ elif prune_method == "iclr_gm":
234
+ filter_pruning = pruning_lib.iclr_gm
235
+ elif prune_method == "cs_waspaa":
236
+ filter_pruning = pruning_lib.cs_waspaa
237
+
238
+
239
+ if isinstance(pretrained, str):
240
+ ckpt = torch.load(pretrained, "cpu")
241
+ state_dict = {}
242
+ for key in ckpt["model"].keys():
243
+ if key.startswith("model.encoder.backbone"):
244
+ state_dict[key[len("model.encoder.backbone.eff_net."):]] = ckpt["model"][key]
245
+ elif isinstance(pretrained, bool):
246
+ model_path = os.path.join(model_dir, "effb2.pt")
247
+ if not os.path.exists(model_path):
248
+ state_dict = load_state_dict_from_url(
249
+ 'https://github.com/richermans/HEAR2021_EfficientLatent/releases/download/v0.0.1/effb2.pt',
250
+ progress=True,
251
+ model_dir=model_dir)
252
+ else:
253
+ state_dict = torch.load(model_path)
254
+ del_keys = [key for key in state_dict if key.startswith("front_end")]
255
+ for key in del_keys:
256
+ del state_dict[key]
257
+
258
+ # load pretrained model with corresponding filters
259
+ # rule:
260
+ # * depthwise_conv: in_ch_idx = out_ch_idx = prev_conv_idx
261
+ mod_dep_path = [
262
+ "_conv_stem",
263
+ ]
264
+ conv_to_bn = {"_conv_stem": "_bn0"}
265
+ for i in range(2):
266
+ mod_dep_path.extend([
267
+ f"_blocks.{i}._depthwise_conv",
268
+ f"_blocks.{i}._se_reduce",
269
+ f"_blocks.{i}._se_expand",
270
+ f"_blocks.{i}._project_conv",
271
+ ])
272
+ conv_to_bn[f"_blocks.{i}._depthwise_conv"] = f"_blocks.{i}._bn1"
273
+ conv_to_bn[f"_blocks.{i}._project_conv"] = f"_blocks.{i}._bn2"
274
+
275
+ for i in range(2, 23):
276
+ mod_dep_path.extend([
277
+ f"_blocks.{i}._expand_conv",
278
+ f"_blocks.{i}._depthwise_conv",
279
+ f"_blocks.{i}._se_reduce",
280
+ f"_blocks.{i}._se_expand",
281
+ f"_blocks.{i}._project_conv"
282
+ ])
283
+ conv_to_bn[f"_blocks.{i}._expand_conv"] = f"_blocks.{i}._bn0"
284
+ conv_to_bn[f"_blocks.{i}._depthwise_conv"] = f"_blocks.{i}._bn1"
285
+ conv_to_bn[f"_blocks.{i}._project_conv"] = f"_blocks.{i}._bn2"
286
+
287
+ mod_dep_path.append("_conv_head")
288
+ conv_to_bn["_conv_head"] = "_bn1"
289
+
290
+ # print(mod_dep_path)
291
+ # print(conv_to_bn)
292
+
293
+ key_to_w_b_idx = {}
294
+ model_dict = model.eff_net.state_dict()
295
+ for conv_key in tqdm(mod_dep_path):
296
+ weight = state_dict[f"{conv_key}.weight"]
297
+ ptr_n_filter = weight.size(0)
298
+ model_n_filter = model_dict[f"{conv_key}.weight"].size(0)
299
+ if model_n_filter < ptr_n_filter:
300
+ key_to_w_b_idx[conv_key] = filter_pruning(weight.numpy())[:model_n_filter]
301
+ else:
302
+ key_to_w_b_idx[conv_key] = slice(None)
303
+
304
+ pruned_state_dict = {}
305
+ for conv_key, prev_conv_key in zip(mod_dep_path, [None] + mod_dep_path[:-1]):
306
+
307
+ for sub_key in ["weight", "bias"]: # adjust the conv layer
308
+ cur_key = f"{conv_key}.{sub_key}"
309
+
310
+ if cur_key not in state_dict:
311
+ continue
312
+
313
+ if prev_conv_key is None or conv_key.endswith("_depthwise_conv"):
314
+ conv_in_idx = slice(None)
315
+ else:
316
+ conv_in_idx = key_to_w_b_idx[prev_conv_key]
317
+
318
+ # the first pruned layer
319
+ if model_dict[cur_key].ndim > 1 and model_dict[cur_key].size(1) == state_dict[cur_key].size(1):
320
+ conv_in_idx = slice(None)
321
+
322
+ if conv_key.endswith("_depthwise_conv"):
323
+ conv_out_idx = key_to_w_b_idx[prev_conv_key]
324
+ else:
325
+ conv_out_idx = key_to_w_b_idx[conv_key]
326
+
327
+ # if conv_key == "_blocks.16._se_reduce":
328
+ # print(len(conv_out_idx), len(conv_in_idx))
329
+
330
+ if sub_key == "weight":
331
+ pruned_state_dict[cur_key] = state_dict[cur_key][
332
+ conv_out_idx, ...][:, conv_in_idx, ...]
333
+ else:
334
+ pruned_state_dict[cur_key] = state_dict[cur_key][
335
+ conv_out_idx, ...]
336
+
337
+ if conv_key in conv_to_bn: # adjust the corresponding bn layer
338
+ for sub_key in ["weight", "bias", "running_mean", "running_var"]:
339
+ cur_key = f"{conv_to_bn[conv_key]}.{sub_key}"
340
+ if cur_key not in state_dict:
341
+ continue
342
+ pruned_state_dict[cur_key] = state_dict[cur_key][
343
+ key_to_w_b_idx[conv_key], ...]
344
+
345
+ model.eff_net.load_state_dict(pruned_state_dict)
346
+
347
+ return model
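A minimal usage sketch for this encoder, assuming the layout of this commit (models/eff_latent_encoder.py importable from the repo root). `pretrained=False` is used only to keep the example offline; the output shape is indicative:

    import torch
    from models.eff_latent_encoder import get_model

    model = get_model(pretrained=False)   # EfficientNet-B2 backbone with a 1-channel stem
    mel = torch.randn(2, 64, 501)         # (batch, mel_bins, time_steps)
    with torch.no_grad():
        feats = model(mel)                # (batch, ~time_steps // 32, channels)
    print(feats.shape)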
models/kd_wrapper.py ADDED
@@ -0,0 +1,226 @@
1
+ from typing import Dict
2
+
3
+ import numpy as np
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+ from einops import repeat
8
+
9
+ from models.base import CaptionMetaMixin
10
+ from utils.model_util import init
11
+
12
+
13
+ class WmlEncoderKdWrapper(nn.Module, CaptionMetaMixin):
14
+
15
+ def __init__(self,
16
+ model: nn.Module,
17
+ shared_dim: int,
18
+ tchr_layer_to_dims: Dict[str, int],
19
+ loss_type: str = "mse",):
20
+ super().__init__()
21
+ self.model = model
22
+ self.tchr_layers = list(tchr_layer_to_dims.keys())
23
+ self.stdnt_qv_proj = nn.Linear(model.encoder.fc_emb_size,
24
+ 2 * shared_dim)
25
+ self.stdnt_qv_proj.apply(init)
26
+ for layer, dim in tchr_layer_to_dims.items():
27
+ self.add_module(f'tchr_kv_proj_{layer}', nn.Linear(dim, 2 * shared_dim))
28
+ getattr(self, f'tchr_kv_proj_{layer}').apply(init)
29
+ if loss_type == "mse":
30
+ self.loss_fn = nn.MSELoss(reduction="none")
31
+
32
+ def forward(self, input_dict: Dict):
33
+ output_dict = self.model(input_dict)
34
+ if "tchr_output" in input_dict:
35
+ stdnt_emb = output_dict["fc_emb"]
36
+ stdnt_qv = self.stdnt_qv_proj(stdnt_emb)
37
+ stdnt_q, stdnt_v = torch.chunk(stdnt_qv, 2, dim=-1)
38
+
39
+ tchr_output = input_dict["tchr_output"]
40
+ layer_ks, layer_vs = [], []
41
+ for layer in self.tchr_layers:
42
+ layer_kv = getattr(self, f'tchr_kv_proj_{layer}')(tchr_output[layer])
43
+ layer_k, layer_v = torch.chunk(layer_kv, 2, dim=-1)
44
+ layer_ks.append(layer_k)
45
+ layer_vs.append(layer_v)
46
+ layer_ks = torch.stack(layer_ks, dim=1)
47
+ layer_vs = torch.stack(layer_vs, dim=1)
48
+ weights = torch.softmax(stdnt_q.unsqueeze(1) @ layer_ks.transpose(1, 2), dim=-1)
49
+ stdnt_v = repeat(stdnt_v, 'b d -> b n d', n=len(self.tchr_layers))
50
+ loss = self.loss_fn(stdnt_v, layer_vs).mean(dim=-1, keepdim=True)
51
+ loss = (weights @ loss).mean()
52
+ output_dict["enc_kd_loss"] = loss
53
+ return output_dict
54
+
55
+
56
+ class MseEncoderKdWrapper(nn.Module, CaptionMetaMixin):
57
+
58
+ def __init__(self,
59
+ model: nn.Module,
60
+ shared_dim: int,
61
+ tchr_dim: int,
62
+ use_tchr_proj: bool = True,
63
+ l2_norm: bool = False,
64
+ ):
65
+ super().__init__()
66
+ self.model = model
67
+ self.use_tchr_proj = use_tchr_proj
68
+ if not use_tchr_proj:
69
+ assert shared_dim == tchr_dim
70
+ self.tchr_dim = tchr_dim
71
+ self.l2_norm = l2_norm
72
+ if hasattr(model, "encoder"):
73
+ self.stdnt_proj = nn.Linear(model.encoder.fc_emb_size,
74
+ shared_dim)
75
+ else:
76
+ self.stdnt_proj = nn.Linear(model.fc_emb_size,
77
+ shared_dim)
78
+ self.stdnt_proj.apply(init)
79
+ if use_tchr_proj:
80
+ self.tchr_proj = nn.Linear(tchr_dim, shared_dim)
81
+ self.tchr_proj.apply(init)
82
+ else:
83
+ self.tchr_proj = nn.Identity()
84
+
85
+ def forward(self, input_dict: Dict):
86
+ unsup = input_dict.get("unsup", False)
87
+ if unsup is False:
88
+ if self.use_tchr_proj:
89
+ output_dict = self.model(input_dict)
90
+ stdnt_emb = output_dict["fc_emb"]
91
+ else:
92
+ encoder_output = self.model.encoder(input_dict)
93
+ stdnt_emb = encoder_output["fc_emb"]
94
+ encoder_output["fc_emb"] = self.stdnt_proj(encoder_output["fc_emb"])
95
+ encoder_output["attn_emb"] = self.stdnt_proj(encoder_output["attn_emb"])
96
+ output_dict = self.model.forward_decoder(input_dict, encoder_output)
97
+ else:
98
+ output_dict = self.model.encoder(input_dict)
99
+ stdnt_emb = output_dict["fc_emb"]
100
+ if "tchr_output" in input_dict:
101
+ stdnt_emb = self.stdnt_proj(stdnt_emb)
102
+ tchr_emb = input_dict["tchr_output"]["embedding"]
103
+ thcr_emb = self.tchr_proj(tchr_emb)
104
+
105
+ if self.l2_norm:
106
+ stdnt_emb = F.normalize(stdnt_emb, dim=-1)
107
+ thcr_emb = F.normalize(thcr_emb, dim=-1)
108
+
109
+ loss = F.mse_loss(stdnt_emb, thcr_emb)
110
+ output_dict["enc_kd_loss"] = loss
111
+ return output_dict
112
+
113
+
114
+ class ContraEncoderKdWrapper(nn.Module, CaptionMetaMixin):
115
+
116
+ def __init__(self,
117
+ model: nn.Module,
118
+ shared_dim: int,
119
+ tchr_dim: int,
120
+ ):
121
+ super().__init__()
122
+ self.model = model
123
+ self.tchr_dim = tchr_dim
124
+ if hasattr(model, "encoder"):
125
+ self.stdnt_proj = nn.Linear(model.encoder.fc_emb_size,
126
+ shared_dim)
127
+ else:
128
+ self.stdnt_proj = nn.Linear(model.fc_emb_size,
129
+ shared_dim)
130
+ self.stdnt_proj.apply(init)
131
+ self.tchr_proj = nn.Linear(tchr_dim, shared_dim)
132
+ self.tchr_proj.apply(init)
133
+ self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
134
+
135
+ def forward(self, input_dict: Dict):
136
+ unsup = input_dict.get("unsup", False)
137
+ if unsup is False:
138
+ output_dict = self.model(input_dict)
139
+ else:
140
+ output_dict = self.model.encoder(input_dict)
141
+ if "tchr_output" in input_dict:
142
+ stdnt_emb = output_dict["fc_emb"]
143
+ stdnt_emb = self.stdnt_proj(stdnt_emb)
144
+ tchr_emb = input_dict["tchr_output"]["embedding"]
145
+ thcr_emb = self.tchr_proj(tchr_emb)
146
+
147
+ stdnt_emb = F.normalize(stdnt_emb, dim=-1)
148
+ thcr_emb = F.normalize(thcr_emb, dim=-1)
149
+
150
+ unscaled_logit = stdnt_emb @ thcr_emb.transpose(0, 1)
151
+ logit = self.logit_scale * unscaled_logit
152
+ label = torch.arange(logit.shape[0]).to(logit.device)
153
+ loss1 = F.cross_entropy(logit, label)
154
+ loss2 = F.cross_entropy(logit.transpose(0, 1), label)
155
+ loss = (loss1 + loss2) / 2
156
+ output_dict["enc_kd_loss"] = loss
157
+ return output_dict
158
+
159
+
160
+ class ContraMseEncoderKdWrapper(nn.Module, CaptionMetaMixin):
161
+
162
+ def __init__(self,
163
+ model: nn.Module,
164
+ shared_dim: int,
165
+ tchr_dim: int,
166
+ use_tchr_proj: bool = True,
167
+ l2_norm: bool = False,
168
+ ):
169
+ super().__init__()
170
+ self.model = model
171
+ self.use_tchr_proj = use_tchr_proj
172
+ if not use_tchr_proj:
173
+ assert shared_dim == tchr_dim
174
+ self.tchr_dim = tchr_dim
175
+ self.l2_norm = l2_norm
176
+ if hasattr(model, "encoder"):
177
+ self.stdnt_proj = nn.Linear(model.encoder.fc_emb_size,
178
+ shared_dim)
179
+ else:
180
+ self.stdnt_proj = nn.Linear(model.fc_emb_size,
181
+ shared_dim)
182
+ self.stdnt_proj.apply(init)
183
+ if use_tchr_proj:
184
+ self.tchr_proj = nn.Linear(tchr_dim, shared_dim)
185
+ self.tchr_proj.apply(init)
186
+ else:
187
+ self.tchr_proj = nn.Identity()
188
+ self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
189
+
190
+ def forward(self, input_dict: Dict):
191
+ unsup = input_dict.get("unsup", False)
192
+ if unsup is False:
193
+ if self.use_tchr_proj:
194
+ output_dict = self.model(input_dict)
195
+ stdnt_emb = output_dict["fc_emb"]
196
+ else:
197
+ encoder_output = self.model.encoder(input_dict)
198
+ stdnt_emb = encoder_output["fc_emb"]
199
+ encoder_output["fc_emb"] = self.stdnt_proj(encoder_output["fc_emb"])
200
+ encoder_output["attn_emb"] = self.stdnt_proj(encoder_output["attn_emb"])
201
+ output_dict = self.model.forward_decoder(input_dict, encoder_output)
202
+ else:
203
+ output_dict = self.model.encoder(input_dict)
204
+ stdnt_emb = output_dict["fc_emb"]
205
+ if "tchr_output" in input_dict:
206
+ stdnt_emb = self.stdnt_proj(stdnt_emb)
207
+ tchr_emb = input_dict["tchr_output"]["embedding"]
208
+ thcr_emb = self.tchr_proj(tchr_emb)
209
+
210
+ if self.l2_norm:
211
+ stdnt_emb = F.normalize(stdnt_emb, dim=-1)
212
+ thcr_emb = F.normalize(thcr_emb, dim=-1)
213
+
214
+ mse_loss = F.mse_loss(stdnt_emb, thcr_emb)
215
+
216
+ stdnt_emb = F.normalize(stdnt_emb, dim=-1)
217
+ thcr_emb = F.normalize(thcr_emb, dim=-1)
218
+ unscaled_logit = stdnt_emb @ thcr_emb.transpose(0, 1)
219
+ logit = self.logit_scale * unscaled_logit
220
+ label = torch.arange(logit.shape[0]).to(logit.device)
221
+ loss1 = F.cross_entropy(logit, label)
222
+ loss2 = F.cross_entropy(logit.transpose(0, 1), label)
223
+ cntr_loss = (loss1 + loss2) / 2
224
+ output_dict["enc_kd_loss"] = mse_loss + cntr_loss
225
+
226
+ return output_dict
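The contrastive term shared by the two wrappers above is a symmetric in-batch cross-entropy over cosine similarities; a standalone sketch of just that computation (tensor sizes are illustrative):

    import numpy as np
    import torch
    import torch.nn.functional as F

    def symmetric_contrastive_loss(stdnt_emb, tchr_emb, logit_scale):
        # Both inputs: (batch, shared_dim), already projected to the shared space.
        stdnt_emb = F.normalize(stdnt_emb, dim=-1)
        tchr_emb = F.normalize(tchr_emb, dim=-1)
        logit = logit_scale * stdnt_emb @ tchr_emb.transpose(0, 1)  # (batch, batch)
        label = torch.arange(logit.shape[0], device=logit.device)   # matching pairs on the diagonal
        return (F.cross_entropy(logit, label)
                + F.cross_entropy(logit.transpose(0, 1), label)) / 2

    logit_scale = torch.ones([]) * np.log(1 / 0.07)  # same initialization as the wrappers
    loss = symmetric_contrastive_loss(torch.randn(8, 256), torch.randn(8, 256), logit_scale)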
models/transformer_decoder.py ADDED
@@ -0,0 +1,214 @@
1
+ import math
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+
6
+ from models import BaseDecoder
7
+ from utils.model_util import generate_length_mask, PositionalEncoding
8
+ from utils.train_util import merge_load_state_dict
9
+
10
+
11
+ class TransformerDecoder(BaseDecoder):
12
+
13
+ def __init__(self,
14
+ emb_dim,
15
+ vocab_size,
16
+ fc_emb_dim,
17
+ attn_emb_dim,
18
+ dropout,
19
+ freeze=False,
20
+ tie_weights=False,
21
+ **kwargs):
22
+ super().__init__(emb_dim, vocab_size, fc_emb_dim, attn_emb_dim,
23
+ dropout=dropout, tie_weights=tie_weights)
24
+ self.d_model = emb_dim
25
+ self.nhead = kwargs.get("nhead", self.d_model // 64)
26
+ self.nlayers = kwargs.get("nlayers", 2)
27
+ self.dim_feedforward = kwargs.get("dim_feedforward", self.d_model * 4)
28
+
29
+ self.pos_encoder = PositionalEncoding(self.d_model, dropout)
30
+ layer = nn.TransformerDecoderLayer(d_model=self.d_model,
31
+ nhead=self.nhead,
32
+ dim_feedforward=self.dim_feedforward,
33
+ dropout=dropout)
34
+ self.model = nn.TransformerDecoder(layer, self.nlayers)
35
+ self.classifier = nn.Linear(self.d_model, vocab_size, bias=False)
36
+ if tie_weights:
37
+ self.classifier.weight = self.word_embedding.weight
38
+ self.attn_proj = nn.Sequential(
39
+ nn.Linear(self.attn_emb_dim, self.d_model),
40
+ nn.ReLU(),
41
+ nn.Dropout(dropout),
42
+ nn.LayerNorm(self.d_model)
43
+ )
44
+ self.init_params()
45
+
46
+ self.freeze = freeze
47
+ if freeze:
48
+ for p in self.parameters():
49
+ p.requires_grad = False
50
+
51
+ def init_params(self):
52
+ for p in self.parameters():
53
+ if p.dim() > 1:
54
+ nn.init.xavier_uniform_(p)
55
+
56
+ def load_pretrained(self, pretrained, output_fn):
57
+ checkpoint = torch.load(pretrained, map_location="cpu")
58
+
59
+ if "model" in checkpoint:
60
+ checkpoint = checkpoint["model"]
61
+ state_dict = checkpoint
+ if next(iter(checkpoint)).startswith("decoder."):
+ state_dict = {}
+ for k, v in checkpoint.items():
+ state_dict[k[8:]] = v
65
+
66
+ loaded_keys = merge_load_state_dict(state_dict, self, output_fn)
67
+ if self.freeze:
68
+ for name, param in self.named_parameters():
69
+ if name in loaded_keys:
70
+ param.requires_grad = False
71
+ else:
72
+ param.requires_grad = True
73
+
74
+
75
+ def generate_square_subsequent_mask(self, max_length):
76
+ mask = (torch.triu(torch.ones(max_length, max_length)) == 1).transpose(0, 1)
77
+ mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
78
+ return mask
79
+
80
+ def forward(self, input_dict):
81
+ word = input_dict["word"]
82
+ attn_emb = input_dict["attn_emb"]
83
+ attn_emb_len = input_dict["attn_emb_len"]
84
+ cap_padding_mask = input_dict["cap_padding_mask"]
85
+
86
+ p_attn_emb = self.attn_proj(attn_emb)
87
+ p_attn_emb = p_attn_emb.transpose(0, 1) # [T_src, N, emb_dim]
88
+ word = word.to(attn_emb.device)
89
+ embed = self.in_dropout(self.word_embedding(word)) * math.sqrt(self.emb_dim) # [N, T, emb_dim]
90
+ embed = embed.transpose(0, 1) # [T, N, emb_dim]
91
+ embed = self.pos_encoder(embed)
92
+
93
+ tgt_mask = self.generate_square_subsequent_mask(embed.size(0)).to(attn_emb.device)
94
+ memory_key_padding_mask = ~generate_length_mask(attn_emb_len, attn_emb.size(1)).to(attn_emb.device)
95
+ output = self.model(embed, p_attn_emb, tgt_mask=tgt_mask,
96
+ tgt_key_padding_mask=cap_padding_mask,
97
+ memory_key_padding_mask=memory_key_padding_mask)
98
+ output = output.transpose(0, 1)
99
+ output = {
100
+ "embed": output,
101
+ "logit": self.classifier(output),
102
+ }
103
+ return output
104
+
105
+
106
+ class M2TransformerDecoder(BaseDecoder):
107
+
108
+ def __init__(self, vocab_size, fc_emb_dim, attn_emb_dim, dropout=0.1, **kwargs):
109
+ super().__init__(attn_emb_dim, vocab_size, fc_emb_dim, attn_emb_dim, dropout=dropout,)
110
+ try:
111
+ from m2transformer.models.transformer import MeshedDecoder
112
+ except ImportError:
113
+ raise ImportError("meshed-memory-transformer not installed; please run `pip install git+https://github.com/ruotianluo/meshed-memory-transformer.git`")
114
+ del self.word_embedding
115
+ del self.in_dropout
116
+
117
+ self.d_model = attn_emb_dim
118
+ self.nhead = kwargs.get("nhead", self.d_model // 64)
119
+ self.nlayers = kwargs.get("nlayers", 2)
120
+ self.dim_feedforward = kwargs.get("dim_feedforward", self.d_model * 4)
121
+ self.model = MeshedDecoder(vocab_size, 100, self.nlayers, 0,
122
+ d_model=self.d_model,
123
+ h=self.nhead,
124
+ d_ff=self.dim_feedforward,
125
+ dropout=dropout)
126
+ self.init_params()
127
+
128
+ def init_params(self):
129
+ for p in self.parameters():
130
+ if p.dim() > 1:
131
+ nn.init.xavier_uniform_(p)
132
+
133
+ def forward(self, input_dict):
134
+ word = input_dict["word"]
135
+ attn_emb = input_dict["attn_emb"]
136
+ attn_emb_mask = input_dict["attn_emb_mask"]
137
+ word = word.to(attn_emb.device)
138
+ embed, logit = self.model(word, attn_emb, attn_emb_mask)
139
+ output = {
140
+ "embed": embed,
141
+ "logit": logit,
142
+ }
143
+ return output
144
+
145
+
146
+ class EventTransformerDecoder(TransformerDecoder):
147
+
148
+ def forward(self, input_dict):
149
+ word = input_dict["word"] # index of word embeddings
150
+ attn_emb = input_dict["attn_emb"]
151
+ attn_emb_len = input_dict["attn_emb_len"]
152
+ cap_padding_mask = input_dict["cap_padding_mask"]
153
+ event_emb = input_dict["event"] # [N, emb_dim]
154
+
155
+ p_attn_emb = self.attn_proj(attn_emb)
156
+ p_attn_emb = p_attn_emb.transpose(0, 1) # [T_src, N, emb_dim]
157
+ word = word.to(attn_emb.device)
158
+ embed = self.in_dropout(self.word_embedding(word)) * math.sqrt(self.emb_dim) # [N, T, emb_dim]
159
+
160
+ embed = embed.transpose(0, 1) # [T, N, emb_dim]
161
+ embed += event_emb
162
+ embed = self.pos_encoder(embed)
163
+
164
+ tgt_mask = self.generate_square_subsequent_mask(embed.size(0)).to(attn_emb.device)
165
+ memory_key_padding_mask = ~generate_length_mask(attn_emb_len, attn_emb.size(1)).to(attn_emb.device)
166
+ output = self.model(embed, p_attn_emb, tgt_mask=tgt_mask,
167
+ tgt_key_padding_mask=cap_padding_mask,
168
+ memory_key_padding_mask=memory_key_padding_mask)
169
+ output = output.transpose(0, 1)
170
+ output = {
171
+ "embed": output,
172
+ "logit": self.classifier(output),
173
+ }
174
+ return output
175
+
176
+
177
+ class KeywordProbTransformerDecoder(TransformerDecoder):
178
+
179
+ def __init__(self, emb_dim, vocab_size, fc_emb_dim, attn_emb_dim,
180
+ dropout, keyword_classes_num, **kwargs):
181
+ super().__init__(emb_dim, vocab_size, fc_emb_dim, attn_emb_dim,
182
+ dropout, **kwargs)
183
+ self.keyword_proj = nn.Linear(keyword_classes_num, self.d_model)
184
+ self.word_keyword_norm = nn.LayerNorm(self.d_model)
185
+
186
+ def forward(self, input_dict):
187
+ word = input_dict["word"] # index of word embeddings
188
+ attn_emb = input_dict["attn_emb"]
189
+ attn_emb_len = input_dict["attn_emb_len"]
190
+ cap_padding_mask = input_dict["cap_padding_mask"]
191
+ keyword = input_dict["keyword"] # [N, keyword_classes_num]
192
+
193
+ p_attn_emb = self.attn_proj(attn_emb)
194
+ p_attn_emb = p_attn_emb.transpose(0, 1) # [T_src, N, emb_dim]
195
+ word = word.to(attn_emb.device)
196
+ embed = self.in_dropout(self.word_embedding(word)) * math.sqrt(self.emb_dim) # [N, T, emb_dim]
197
+
198
+ embed = embed.transpose(0, 1) # [T, N, emb_dim]
199
+ embed += self.keyword_proj(keyword)
200
+ embed = self.word_keyword_norm(embed)
201
+
202
+ embed = self.pos_encoder(embed)
203
+
204
+ tgt_mask = self.generate_square_subsequent_mask(embed.size(0)).to(attn_emb.device)
205
+ memory_key_padding_mask = ~generate_length_mask(attn_emb_len, attn_emb.size(1)).to(attn_emb.device)
206
+ output = self.model(embed, p_attn_emb, tgt_mask=tgt_mask,
207
+ tgt_key_padding_mask=cap_padding_mask,
208
+ memory_key_padding_mask=memory_key_padding_mask)
209
+ output = output.transpose(0, 1)
210
+ output = {
211
+ "embed": output,
212
+ "logit": self.classifier(output),
213
+ }
214
+ return output
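The causal mask produced by `generate_square_subsequent_mask` above is the standard look-ahead mask; for example, with max_length=4:

    import torch

    def generate_square_subsequent_mask(max_length):
        # same construction as the decoder method above
        mask = (torch.triu(torch.ones(max_length, max_length)) == 1).transpose(0, 1)
        return mask.float().masked_fill(mask == 0, float("-inf")).masked_fill(mask == 1, 0.0)

    print(generate_square_subsequent_mask(4))
    # tensor([[0., -inf, -inf, -inf],
    #         [0., 0., -inf, -inf],
    #         [0., 0., 0., -inf],
    #         [0., 0., 0., 0.]])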
models/transformer_model.py ADDED
@@ -0,0 +1,264 @@
1
+ # -*- coding: utf-8 -*-
2
+ import random
3
+ import torch
4
+ import torch.nn as nn
5
+
6
+ from models.base import CaptionModel
7
+ from utils.model_util import repeat_tensor
8
+ import models.transformer_decoder
9
+
10
+
11
+ class TransformerModel(CaptionModel):
12
+
13
+ def __init__(self, encoder: nn.Module, decoder: nn.Module, **kwargs):
14
+ if not hasattr(self, "compatible_decoders"):
15
+ self.compatible_decoders = (
16
+ models.transformer_decoder.TransformerDecoder,
17
+ )
18
+ super().__init__(encoder, decoder, **kwargs)
19
+
20
+ def seq_forward(self, input_dict):
21
+ cap = input_dict["cap"]
22
+ cap_padding_mask = (cap == self.pad_idx).to(cap.device)
23
+ cap_padding_mask = cap_padding_mask[:, :-1]
24
+ output = self.decoder(
25
+ {
26
+ "word": cap[:, :-1],
27
+ "attn_emb": input_dict["attn_emb"],
28
+ "attn_emb_len": input_dict["attn_emb_len"],
29
+ "cap_padding_mask": cap_padding_mask
30
+ }
31
+ )
32
+ return output
33
+
34
+ def prepare_decoder_input(self, input_dict, output):
35
+ decoder_input = {
36
+ "attn_emb": input_dict["attn_emb"],
37
+ "attn_emb_len": input_dict["attn_emb_len"]
38
+ }
39
+ t = input_dict["t"]
40
+
41
+ ###############
42
+ # determine input word
43
+ ################
44
+ if input_dict["mode"] == "train" and random.random() < input_dict["ss_ratio"]: # training, scheduled sampling
45
+ word = input_dict["cap"][:, :t+1]
46
+ else:
47
+ start_word = torch.tensor([self.start_idx,] * input_dict["attn_emb"].size(0)).unsqueeze(1).long()
48
+ if t == 0:
49
+ word = start_word
50
+ else:
51
+ word = torch.cat((start_word, output["seq"][:, :t]), dim=-1)
52
+ # word: [N, T]
53
+ decoder_input["word"] = word
54
+
55
+ cap_padding_mask = (word == self.pad_idx).to(input_dict["attn_emb"].device)
56
+ decoder_input["cap_padding_mask"] = cap_padding_mask
57
+ return decoder_input
58
+
59
+ def prepare_beamsearch_decoder_input(self, input_dict, output_i):
60
+ decoder_input = {}
61
+ t = input_dict["t"]
62
+ i = input_dict["sample_idx"]
63
+ beam_size = input_dict["beam_size"]
64
+ ###############
65
+ # prepare attn embeds
66
+ ################
67
+ if t == 0:
68
+ attn_emb = repeat_tensor(input_dict["attn_emb"][i], beam_size)
69
+ attn_emb_len = repeat_tensor(input_dict["attn_emb_len"][i], beam_size)
70
+ output_i["attn_emb"] = attn_emb
71
+ output_i["attn_emb_len"] = attn_emb_len
72
+ decoder_input["attn_emb"] = output_i["attn_emb"]
73
+ decoder_input["attn_emb_len"] = output_i["attn_emb_len"]
74
+ ###############
75
+ # determine input word
76
+ ################
77
+ start_word = torch.tensor([self.start_idx,] * beam_size).unsqueeze(1).long()
78
+ if t == 0:
79
+ word = start_word
80
+ else:
81
+ word = torch.cat((start_word, output_i["seq"]), dim=-1)
82
+ decoder_input["word"] = word
83
+ cap_padding_mask = (word == self.pad_idx).to(input_dict["attn_emb"].device)
84
+ decoder_input["cap_padding_mask"] = cap_padding_mask
85
+
86
+ return decoder_input
87
+
88
+
89
+ class M2TransformerModel(CaptionModel):
90
+
91
+ def __init__(self, encoder: nn.Module, decoder: nn.Module, **kwargs):
92
+ if not hasattr(self, "compatible_decoders"):
93
+ self.compatible_decoders = (
94
+ models.transformer_decoder.M2TransformerDecoder,
95
+ )
96
+ super().__init__(encoder, decoder, **kwargs)
97
+ self.check_encoder_compatibility()
98
+
99
+ def check_encoder_compatibility(self):
100
+ assert isinstance(self.encoder, models.encoder.M2TransformerEncoder), \
101
+ f"only M2TransformerModel is compatible with {self.__class__.__name__}"
102
+
103
+ def seq_forward(self, input_dict):
104
+ cap = input_dict["cap"]
105
+ output = self.decoder(
106
+ {
107
+ "word": cap[:, :-1],
108
+ "attn_emb": input_dict["attn_emb"],
109
+ "attn_emb_mask": input_dict["attn_emb_mask"],
110
+ }
111
+ )
112
+ return output
113
+
114
+ def prepare_decoder_input(self, input_dict, output):
115
+ decoder_input = {
116
+ "attn_emb": input_dict["attn_emb"],
117
+ "attn_emb_mask": input_dict["attn_emb_mask"]
118
+ }
119
+ t = input_dict["t"]
120
+
121
+ ###############
122
+ # determine input word
123
+ ################
124
+ if input_dict["mode"] == "train" and random.random() < input_dict["ss_ratio"]: # training, scheduled sampling
125
+ word = input_dict["cap"][:, :t+1]
126
+ else:
127
+ start_word = torch.tensor([self.start_idx,] * input_dict["attn_emb"].size(0)).unsqueeze(1).long()
128
+ if t == 0:
129
+ word = start_word
130
+ else:
131
+ word = torch.cat((start_word, output["seq"][:, :t]), dim=-1)
132
+ # word: [N, T]
133
+ decoder_input["word"] = word
134
+
135
+ return decoder_input
136
+
137
+ def prepare_beamsearch_decoder_input(self, input_dict, output_i):
138
+ decoder_input = {}
139
+ t = input_dict["t"]
140
+ i = input_dict["sample_idx"]
141
+ beam_size = input_dict["beam_size"]
142
+ ###############
143
+ # prepare attn embeds
144
+ ################
145
+ if t == 0:
146
+ attn_emb = repeat_tensor(input_dict["attn_emb"][i], beam_size)
147
+ attn_emb_mask = repeat_tensor(input_dict["attn_emb_mask"][i], beam_size)
148
+ output_i["attn_emb"] = attn_emb
149
+ output_i["attn_emb_mask"] = attn_emb_mask
150
+ decoder_input["attn_emb"] = output_i["attn_emb"]
151
+ decoder_input["attn_emb_mask"] = output_i["attn_emb_mask"]
152
+ ###############
153
+ # determine input word
154
+ ################
155
+ start_word = torch.tensor([self.start_idx,] * beam_size).unsqueeze(1).long()
156
+ if t == 0:
157
+ word = start_word
158
+ else:
159
+ word = torch.cat((start_word, output_i["seq"]), dim=-1)
160
+ decoder_input["word"] = word
161
+
162
+ return decoder_input
163
+
164
+
165
+ class EventEncoder(nn.Module):
166
+ """
167
+ Encode the Label information in AudioCaps and AudioSet
168
+ """
169
+ def __init__(self, emb_dim, vocab_size=527):
170
+ super(EventEncoder, self).__init__()
171
+ self.label_embedding = nn.Parameter(
172
+ torch.randn((vocab_size, emb_dim)), requires_grad=True)
173
+
174
+ def forward(self, word_idxs):
175
+ indices = word_idxs / word_idxs.sum(dim=1, keepdim=True)
176
+ embeddings = indices @ self.label_embedding
177
+ return embeddings
178
+
179
+
180
+ class EventCondTransformerModel(TransformerModel):
181
+
182
+ def __init__(self, encoder: nn.Module, decoder: nn.Module, **kwargs):
183
+ if not hasattr(self, "compatible_decoders"):
184
+ self.compatible_decoders = (
185
+ models.transformer_decoder.EventTransformerDecoder,
186
+ )
187
+ super().__init__(encoder, decoder, **kwargs)
188
+ self.label_encoder = EventEncoder(decoder.emb_dim, 527)
189
+ self.train_forward_keys += ["events"]
190
+ self.inference_forward_keys += ["events"]
191
+
192
+ # def seq_forward(self, input_dict):
193
+ # cap = input_dict["cap"]
194
+ # cap_padding_mask = (cap == self.pad_idx).to(cap.device)
195
+ # cap_padding_mask = cap_padding_mask[:, :-1]
196
+ # output = self.decoder(
197
+ # {
198
+ # "word": cap[:, :-1],
199
+ # "attn_emb": input_dict["attn_emb"],
200
+ # "attn_emb_len": input_dict["attn_emb_len"],
201
+ # "cap_padding_mask": cap_padding_mask
202
+ # }
203
+ # )
204
+ # return output
205
+
206
+ def prepare_decoder_input(self, input_dict, output):
207
+ decoder_input = super().prepare_decoder_input(input_dict, output)
208
+ decoder_input["events"] = self.label_encoder(input_dict["events"])
209
+ return decoder_input
210
+
211
+ def prepare_beamsearch_decoder_input(self, input_dict, output_i):
212
+ decoder_input = super().prepare_beamsearch_decoder_input(input_dict, output_i)
213
+ t = input_dict["t"]
214
+ i = input_dict["sample_idx"]
215
+ beam_size = input_dict["beam_size"]
216
+ if t == 0:
217
+ output_i["events"] = repeat_tensor(self.label_encoder(input_dict["events"])[i], beam_size)
218
+ decoder_input["events"] = output_i["events"]
219
+ return decoder_input
220
+
221
+
222
+ class KeywordCondTransformerModel(TransformerModel):
223
+
224
+ def __init__(self, encoder: nn.Module, decoder: nn.Module, **kwargs):
225
+ if not hasattr(self, "compatible_decoders"):
226
+ self.compatible_decoders = (
227
+ models.transformer_decoder.KeywordProbTransformerDecoder,
228
+ )
229
+ super().__init__(encoder, decoder, **kwargs)
230
+ self.train_forward_keys += ["keyword"]
231
+ self.inference_forward_keys += ["keyword"]
232
+
233
+ def seq_forward(self, input_dict):
234
+ cap = input_dict["cap"]
235
+ cap_padding_mask = (cap == self.pad_idx).to(cap.device)
236
+ cap_padding_mask = cap_padding_mask[:, :-1]
237
+ keyword = input_dict["keyword"]
238
+ output = self.decoder(
239
+ {
240
+ "word": cap[:, :-1],
241
+ "attn_emb": input_dict["attn_emb"],
242
+ "attn_emb_len": input_dict["attn_emb_len"],
243
+ "keyword": keyword,
244
+ "cap_padding_mask": cap_padding_mask
245
+ }
246
+ )
247
+ return output
248
+
249
+ def prepare_decoder_input(self, input_dict, output):
250
+ decoder_input = super().prepare_decoder_input(input_dict, output)
251
+ decoder_input["keyword"] = input_dict["keyword"]
252
+ return decoder_input
253
+
254
+ def prepare_beamsearch_decoder_input(self, input_dict, output_i):
255
+ decoder_input = super().prepare_beamsearch_decoder_input(input_dict, output_i)
256
+ t = input_dict["t"]
257
+ i = input_dict["sample_idx"]
258
+ beam_size = input_dict["beam_size"]
259
+ if t == 0:
260
+ output_i["keyword"] = repeat_tensor(input_dict["keyword"][i],
261
+ beam_size)
262
+ decoder_input["keyword"] = output_i["keyword"]
263
+ return decoder_input
264
+
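For clarity, a small sketch (with made-up token ids) of how `prepare_decoder_input` assembles the word prefix at step t during free-running decoding: the start token is prepended to everything generated so far.

    import torch

    start_idx, t = 1, 3
    seq = torch.tensor([[5, 9, 2], [7, 7, 4]])   # tokens generated at steps 0..t-1
    start_word = torch.full((seq.size(0), 1), start_idx, dtype=torch.long)
    word = start_word if t == 0 else torch.cat((start_word, seq[:, :t]), dim=-1)
    print(word)  # tensor([[1, 5, 9, 2], [1, 7, 7, 4]])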
requirements.txt ADDED
@@ -0,0 +1,2 @@
1
+ efficientnet_pytorch
2
+ PyYAML
text_tokenizer.py ADDED
@@ -0,0 +1,107 @@
1
+ import pickle
2
+ from pathlib import Path
3
+
4
+ import numpy as np
5
+ from utils.train_util import pad_sequence
6
+
7
+
8
+ class DictTokenizer:
9
+
10
+ def __init__(self,
11
+ tokenizer_path: str = None,
12
+ max_length: int = 20) -> None:
13
+ self.word2idx = {}
14
+ self.idx2word = {}
15
+ self.idx = 0
16
+ self.add_word("<pad>")
17
+ self.add_word("<start>")
18
+ self.add_word("<end>")
19
+ self.add_word("<unk>")
20
+ if tokenizer_path is not None and Path(tokenizer_path).exists():
21
+ state_dict = pickle.load(open(tokenizer_path, "rb"))
22
+ self.load_state_dict(state_dict)
23
+ self.loaded = True
24
+ else:
25
+ self.loaded = False
26
+ self.bos, self.eos = self.word2idx["<start>"], self.word2idx["<end>"]
27
+ self.pad = self.word2idx["<pad>"]
28
+ self.max_length = max_length
29
+
30
+ def add_word(self, word):
31
+ if word not in self.word2idx:
32
+ self.word2idx[word] = self.idx
33
+ self.idx2word[self.idx] = word
34
+ self.idx += 1
35
+
36
+ def encode_word(self, word):
37
+ if word in self.word2idx:
38
+ return self.word2idx[word]
39
+ else:
40
+ return self.word2idx["<unk>"]
41
+
42
+ def __call__(self, texts):
43
+ assert isinstance(texts, list), "the input must be List[str]"
44
+ batch_tokens = []
45
+ for text in texts:
46
+ tokens = [self.encode_word(token) for token in text.split()][:self.max_length]
47
+ tokens = [self.bos] + tokens + [self.eos]
48
+ tokens = np.array(tokens)
49
+ batch_tokens.append(tokens)
50
+ caps, cap_lens = pad_sequence(batch_tokens, self.pad)
51
+ return {
52
+ "cap": caps,
53
+ "cap_len": cap_lens
54
+ }
55
+
56
+ def decode(self, batch_token_ids):
57
+ output = []
58
+ for token_ids in batch_token_ids:
59
+ tokens = []
60
+ for token_id in token_ids:
61
+ if token_id == self.eos:
62
+ break
63
+ elif token_id == self.bos:
64
+ continue
65
+ tokens.append(self.idx2word[token_id])
66
+ output.append(" ".join(tokens))
67
+ return output
68
+
69
+ def __len__(self):
70
+ return len(self.word2idx)
71
+
72
+ def state_dict(self):
73
+ return self.word2idx
74
+
75
+ def load_state_dict(self, state_dict):
76
+ self.word2idx = state_dict
77
+ self.idx2word = {idx: word for word, idx in self.word2idx.items()}
78
+ self.idx = len(self.word2idx)
79
+
80
+
81
+ class HuggingfaceTokenizer:
82
+
83
+ def __init__(self,
84
+ model_name_or_path,
85
+ max_length) -> None:
86
+ from transformers import AutoTokenizer
87
+ self.tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
88
+ self.max_length = max_length
89
+ self.bos, self.eos = self.tokenizer.bos_token_id, self.tokenizer.eos_token_id
90
+ self.pad = self.tokenizer.pad_token_id
91
+ self.loaded = True
92
+
93
+ def __call__(self, texts):
94
+ assert isinstance(texts, list), "the input must be List[str]"
95
+ batch_token_dict = self.tokenizer(texts,
96
+ padding=True,
97
+ truncation=True,
98
+ max_length=self.max_length,
99
+ return_tensors="pt")
100
+ batch_token_dict["cap"] = batch_token_dict["input_ids"]
101
+ cap_lens = batch_token_dict["attention_mask"].sum(dim=1)
102
+ cap_lens = cap_lens.numpy().astype(np.int32)
103
+ batch_token_dict["cap_len"] = cap_lens
104
+ return batch_token_dict
105
+
106
+ def decode(self, batch_token_ids):
107
+ return self.tokenizer.batch_decode(batch_token_ids, skip_special_tokens=True)
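A quick round-trip sketch for `DictTokenizer`, assuming it is run from the repo root so that `utils.train_util.pad_sequence` resolves; with no saved vocabulary, unseen words map to `<unk>`:

    from text_tokenizer import DictTokenizer

    tokenizer = DictTokenizer(max_length=20)   # no tokenizer_path: starts with only special tokens
    for w in "a dog barks".split():
        tokenizer.add_word(w)

    batch = tokenizer(["a dog barks", "a dog barks loudly"])  # "loudly" -> <unk>
    print(batch["cap"])       # (2, T) padded token ids wrapped in <start> ... <end>
    print(batch["cap_len"])   # true lengths, including the special tokens
    print(tokenizer.decode(batch["cap"].numpy()))
    # ['a dog barks', 'a dog barks <unk>']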
utils/model_util.py ADDED
@@ -0,0 +1,186 @@
1
+ import math
2
+
3
+ import numpy as np
4
+ import torch
5
+ import torch.nn as nn
6
+
7
+ from torch.nn.utils.rnn import PackedSequence, pack_padded_sequence, pad_packed_sequence
8
+
9
+
10
+ def sort_pack_padded_sequence(input, lengths):
11
+ sorted_lengths, indices = torch.sort(lengths, descending=True)
12
+ tmp = pack_padded_sequence(input[indices], sorted_lengths.cpu(), batch_first=True)
13
+ inv_ix = indices.clone()
14
+ inv_ix[indices] = torch.arange(0,len(indices)).type_as(inv_ix)
15
+ return tmp, inv_ix
16
+
17
+ def pad_unsort_packed_sequence(input, inv_ix):
18
+ tmp, _ = pad_packed_sequence(input, batch_first=True)
19
+ tmp = tmp[inv_ix]
20
+ return tmp
21
+
22
+ def pack_wrapper(module, attn_feats, attn_feat_lens):
23
+ packed, inv_ix = sort_pack_padded_sequence(attn_feats, attn_feat_lens)
24
+ if isinstance(module, torch.nn.RNNBase):
25
+ return pad_unsort_packed_sequence(module(packed)[0], inv_ix)
26
+ else:
27
+ return pad_unsort_packed_sequence(PackedSequence(module(packed[0]), packed[1]), inv_ix)
28
+
29
+ def generate_length_mask(lens, max_length=None):
30
+ lens = torch.as_tensor(lens)
31
+ N = lens.size(0)
32
+ if max_length is None:
33
+ max_length = max(lens)
34
+ if isinstance(max_length, torch.Tensor):
35
+ max_length = max_length.item()
36
+ idxs = torch.arange(max_length).repeat(N).view(N, max_length)
37
+ idxs = idxs.to(lens.device)
38
+ mask = (idxs < lens.view(-1, 1))
39
+ return mask
40
+
41
+ def mean_with_lens(features, lens):
42
+ """
43
+ features: [N, T, ...] (assume the second dimension represents length)
44
+ lens: [N,]
45
+ """
46
+ lens = torch.as_tensor(lens)
47
+ if max(lens) != features.size(1):
48
+ max_length = features.size(1)
49
+ mask = generate_length_mask(lens, max_length)
50
+ else:
51
+ mask = generate_length_mask(lens)
52
+ mask = mask.to(features.device) # [N, T]
53
+
54
+ while mask.ndim < features.ndim:
55
+ mask = mask.unsqueeze(-1)
56
+ feature_mean = features * mask
57
+ feature_mean = feature_mean.sum(1)
58
+ while lens.ndim < feature_mean.ndim:
59
+ lens = lens.unsqueeze(1)
60
+ feature_mean = feature_mean / lens.to(features.device)
61
+ # feature_mean = features * mask.unsqueeze(-1)
62
+ # feature_mean = feature_mean.sum(1) / lens.unsqueeze(1).to(features.device)
63
+ return feature_mean
64
+
65
+ def max_with_lens(features, lens):
66
+ """
67
+ features: [N, T, ...] (assume the second dimension represents length)
68
+ lens: [N,]
69
+ """
70
+ lens = torch.as_tensor(lens)
71
+ if max(lens) != features.size(1):
72
+ max_length = features.size(1)
73
+ mask = generate_length_mask(lens, max_length)
74
+ else:
75
+ mask = generate_length_mask(lens)
76
+ mask = mask.to(features.device) # [N, T]
77
+
78
+ feature_max = features.clone()
79
+ feature_max[~mask] = float("-inf")
80
+ feature_max, _ = feature_max.max(1)
81
+ return feature_max
82
+
83
+ def repeat_tensor(x, n):
84
+ return x.unsqueeze(0).repeat(n, *([1] * len(x.shape)))
85
+
86
+ def init(m, method="kaiming"):
87
+ if isinstance(m, (nn.Conv2d, nn.Conv1d)):
88
+ if method == "kaiming":
89
+ nn.init.kaiming_uniform_(m.weight)
90
+ elif method == "xavier":
91
+ nn.init.xavier_uniform_(m.weight)
92
+ else:
93
+ raise Exception(f"initialization method {method} not supported")
94
+ if m.bias is not None:
95
+ nn.init.constant_(m.bias, 0)
96
+ elif isinstance(m, (nn.BatchNorm2d, nn.BatchNorm1d)):
97
+ nn.init.constant_(m.weight, 1)
98
+ if m.bias is not None:
99
+ nn.init.constant_(m.bias, 0)
100
+ elif isinstance(m, nn.Linear):
101
+ if method == "kaiming":
102
+ nn.init.kaiming_uniform_(m.weight)
103
+ elif method == "xavier":
104
+ nn.init.xavier_uniform_(m.weight)
105
+ else:
106
+ raise Exception(f"initialization method {method} not supported")
107
+ if m.bias is not None:
108
+ nn.init.constant_(m.bias, 0)
109
+ elif isinstance(m, nn.Embedding):
110
+ if method == "kaiming":
111
+ nn.init.kaiming_uniform_(m.weight)
112
+ elif method == "xavier":
113
+ nn.init.xavier_uniform_(m.weight)
114
+ else:
115
+ raise Exception(f"initialization method {method} not supported")
116
+
117
+ def compute_batch_score(decode_res,
118
+ key2refs,
119
+ keys,
120
+ start_idx,
121
+ end_idx,
122
+ vocabulary,
123
+ scorer):
124
+ """
125
+ Args:
126
+ decode_res: decoding results of model, [N, max_length]
127
+ key2refs: references of all samples, dict(<key> -> [ref_1, ref_2, ..., ref_n])
128
+ keys: keys of this batch, used to match decode results and refs
129
+ Return:
130
+ scores of this batch, [N,]
131
+ """
132
+
133
+ if scorer is None:
134
+ from pycocoevalcap.cider.cider import Cider
135
+ scorer = Cider()
136
+
137
+ hypothesis = {}
138
+ references = {}
139
+
140
+ for i in range(len(keys)):
141
+
142
+ if keys[i] in hypothesis.keys():
143
+ continue
144
+
145
+ # prepare candidate sentence
146
+ candidate = []
147
+ for w_t in decode_res[i]:
148
+ if w_t == start_idx:
149
+ continue
150
+ elif w_t == end_idx:
151
+ break
152
+ candidate.append(vocabulary.idx2word[w_t])
153
+
154
+ hypothesis[keys[i]] = [" ".join(candidate), ]
155
+
156
+ # prepare reference sentences
157
+ references[keys[i]] = key2refs[keys[i]]
158
+
159
+ score, scores = scorer.compute_score(references, hypothesis)
160
+ key2score = {key: scores[i] for i, key in enumerate(references.keys())}
161
+ results = np.zeros(decode_res.shape[0])
162
+ for i in range(decode_res.shape[0]):
163
+ results[i] = key2score[keys[i]]
164
+ return results
165
+
166
+
167
+ class PositionalEncoding(nn.Module):
168
+
169
+ def __init__(self, d_model, dropout=0.1, max_len=100):
170
+ super(PositionalEncoding, self).__init__()
171
+ self.dropout = nn.Dropout(p=dropout)
172
+
173
+ pe = torch.zeros(max_len, d_model)
174
+ position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
175
+ div_term = torch.exp(torch.arange(0, d_model, 2).float() * \
176
+ (-math.log(10000.0) / d_model))
177
+ pe[:, 0::2] = torch.sin(position * div_term)
178
+ pe[:, 1::2] = torch.cos(position * div_term)
179
+ pe = pe.unsqueeze(0).transpose(0, 1)
180
+ # self.register_buffer("pe", pe)
181
+ self.register_parameter("pe", nn.Parameter(pe, requires_grad=False))
182
+
183
+ def forward(self, x):
184
+ # x: [T, N, E]
185
+ x = x + self.pe[:x.size(0), :]
186
+ return self.dropout(x)
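A minimal check of the masked pooling helpers above (expected outputs shown in comments):

    import torch
    from utils.model_util import generate_length_mask, mean_with_lens

    feats = torch.arange(12, dtype=torch.float).view(2, 3, 2)  # (N=2, T=3, D=2)
    lens = torch.tensor([2, 3])

    print(generate_length_mask(lens))
    # tensor([[ True,  True, False],
    #         [ True,  True,  True]])

    print(mean_with_lens(feats, lens))
    # tensor([[1., 2.],
    #         [8., 9.]])  -- row 0 averages only its first 2 timesteps, row 1 all 3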
utils/train_util.py ADDED
@@ -0,0 +1,117 @@
1
+ import importlib
2
+ import os
3
+ import sys
4
+ from typing import Callable, Dict, Union
5
+
6
+ import numpy as np
7
+ import yaml
8
+ import torch
9
+
10
+
11
+ def merge_a_into_b(a, b):
12
+ # merge dict a into dict b. values in a will overwrite b.
13
+ for k, v in a.items():
14
+ if isinstance(v, dict) and k in b:
15
+ assert isinstance(
16
+ b[k], dict
17
+ ), "Cannot inherit key '{}' from base!".format(k)
18
+ merge_a_into_b(v, b[k])
19
+ else:
20
+ b[k] = v
21
+
22
+
23
+ def load_config(config_file):
24
+ with open(config_file, "r") as reader:
25
+ config = yaml.load(reader, Loader=yaml.FullLoader)
26
+ if "inherit_from" in config:
27
+ base_config_file = config["inherit_from"]
28
+ base_config_file = os.path.join(
29
+ os.path.dirname(config_file), base_config_file
30
+ )
31
+ assert not os.path.samefile(config_file, base_config_file), \
32
+ "inherit from itself"
33
+ base_config = load_config(base_config_file)
34
+ del config["inherit_from"]
35
+ merge_a_into_b(config, base_config)
36
+ return base_config
37
+ return config
38
+
39
+ def get_cls_from_str(string, reload=False):
40
+ module_name, cls_name = string.rsplit(".", 1)
41
+ if reload:
42
+ module_imp = importlib.import_module(module_name)
43
+ importlib.reload(module_imp)
44
+ return getattr(importlib.import_module(module_name, package=None), cls_name)
45
+
46
+ def init_obj_from_dict(config, **kwargs):
47
+ obj_args = config["args"].copy()
48
+ obj_args.update(kwargs)
49
+ for k in config:
50
+ if k not in ["type", "args"] and isinstance(config[k], dict) and k not in kwargs:
51
+ obj_args[k] = init_obj_from_dict(config[k])
52
+ try:
53
+ obj = get_cls_from_str(config["type"])(**obj_args)
54
+ return obj
55
+ except Exception as e:
56
+ print(f"Initializing {config} failed, detailed error stack: ")
57
+ raise e
58
+
59
+ def init_model_from_config(config, print_fn=sys.stdout.write):
60
+ kwargs = {}
61
+ for k in config:
62
+ if k not in ["type", "args", "pretrained"]:
63
+ sub_model = init_model_from_config(config[k], print_fn)
64
+ if "pretrained" in config[k]:
65
+ load_pretrained_model(sub_model,
66
+ config[k]["pretrained"],
67
+ print_fn)
68
+ kwargs[k] = sub_model
69
+ model = init_obj_from_dict(config, **kwargs)
70
+ return model
71
+
72
+ def merge_load_state_dict(state_dict,
73
+ model: torch.nn.Module,
74
+ output_fn: Callable = sys.stdout.write):
75
+ model_dict = model.state_dict()
76
+ pretrained_dict = {}
77
+ mismatch_keys = []
78
+ for key, value in state_dict.items():
79
+ if key in model_dict and model_dict[key].shape == value.shape:
80
+ pretrained_dict[key] = value
81
+ else:
82
+ mismatch_keys.append(key)
83
+ output_fn(f"Loading pre-trained model, with mismatched keys {mismatch_keys}")
84
+ model_dict.update(pretrained_dict)
85
+ model.load_state_dict(model_dict, strict=True)
86
+ return pretrained_dict.keys()
87
+
88
+
89
+ def load_pretrained_model(model: torch.nn.Module,
90
+ pretrained: Union[str, Dict],
91
+ output_fn: Callable = sys.stdout.write):
92
+ if not isinstance(pretrained, dict) and not os.path.exists(pretrained):
93
+ output_fn(f"pretrained {pretrained} not exist!")
94
+ return
95
+
96
+ if hasattr(model, "load_pretrained"):
97
+ model.load_pretrained(pretrained, output_fn)
98
+ return
99
+
100
+ if isinstance(pretrained, dict):
101
+ state_dict = pretrained
102
+ else:
103
+ state_dict = torch.load(pretrained, map_location="cpu")
104
+
105
+ if "model" in state_dict:
106
+ state_dict = state_dict["model"]
107
+
108
+ merge_load_state_dict(state_dict, model, output_fn)
109
+
110
+ def pad_sequence(data, pad_value=0):
111
+ if isinstance(data[0], (np.ndarray, torch.Tensor)):
112
+ data = [torch.as_tensor(arr) for arr in data]
113
+ padded_seq = torch.nn.utils.rnn.pad_sequence(data,
114
+ batch_first=True,
115
+ padding_value=pad_value)
116
+ length = np.array([x.shape[0] for x in data])
117
+ return padded_seq, length
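The config-driven factory above reduces to `get_cls_from_str(config["type"])(**config["args"])`, with nested dicts instantiated recursively; a minimal example with an illustrative config (not the shipped one):

    from utils.train_util import init_obj_from_dict

    config = {
        "type": "torch.nn.Linear",   # dotted path resolved by get_cls_from_str
        "args": {"in_features": 8, "out_features": 4},
    }
    layer = init_obj_from_dict(config)
    print(layer)  # Linear(in_features=8, out_features=4, bias=True)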