wsntxxn commited on
Commit
dd3d338
1 Parent(s): f729a94

Change to Hugging Face calling

Browse files
app.py CHANGED
@@ -1,25 +1,28 @@
1
- from pathlib import Path
2
- import argparse
3
  from functools import partial
4
  import gradio as gr
5
  import torch
6
  from torchaudio.functional import resample
 
7
 
8
- import utils.train_util as train_util
9
 
10
-
11
- def load_model(cfg,
12
- ckpt_path,
13
  device):
14
- model = train_util.init_model_from_config(cfg["model"])
15
- ckpt = torch.load(ckpt_path, "cpu")
16
- train_util.load_pretrained_model(model, ckpt)
17
- model.eval()
18
- model = model.to(device)
19
- tokenizer = train_util.init_obj_from_dict(cfg["tokenizer"])
20
- if not tokenizer.loaded:
21
- tokenizer.load_state_dict(ckpt["tokenizer"])
22
- model.set_index(tokenizer.bos, tokenizer.eos, tokenizer.pad)
 
 
 
 
 
 
 
23
  return model, tokenizer
24
 
25
 
@@ -34,19 +37,13 @@ def infer(file, runner):
34
  wav = wav.mean(1)
35
  wav = resample(wav, sr, runner.target_sr)
36
  wav_len = len(wav)
37
- wav = wav.float().unsqueeze(0).to(runner.device)
38
- input_dict = {
39
- "mode": "inference",
40
- "wav": wav,
41
- "wav_len": [wav_len],
42
- "specaug": False,
43
- "sample_method": "beam",
44
- "beam_size": 3,
45
- }
46
  with torch.no_grad():
47
- output_dict = runner.model(input_dict)
48
- seq = output_dict["seq"].cpu().numpy()
49
- cap = runner.tokenizer.decode(seq)[0]
 
 
50
  return cap
51
 
52
  # def input_toggle(input_type):
@@ -59,16 +56,12 @@ class InferRunner:
59
 
60
  def __init__(self, model_name):
61
  self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
62
- exp_dir = Path(f"./checkpoints/{model_name.lower()}")
63
- cfg = train_util.load_config(exp_dir / "config.yaml")
64
- self.model, self.tokenizer = load_model(cfg, exp_dir / "ckpt.pth", self.device)
65
- self.target_sr = cfg["target_sr"]
66
 
67
  def change_model(self, model_name):
68
- exp_dir = Path(f"./checkpoints/{model_name.lower()}")
69
- cfg = train_util.load_config(exp_dir / "config.yaml")
70
- self.model, self.tokenizer = load_model(cfg, exp_dir / "ckpt.pth", self.device)
71
- self.target_sr = cfg["target_sr"]
72
 
73
 
74
  def change_model(radio):
 
 
 
1
  from functools import partial
2
  import gradio as gr
3
  import torch
4
  from torchaudio.functional import resample
5
+ from transformers import AutoModel, PreTrainedTokenizerFast
6
 
 
7
 
8
+ def load_model(model_name,
 
 
9
  device):
10
+ if model_name == "AudioCaps":
11
+ model = AutoModel.from_pretrained(
12
+ "wsntxxn/effb2-trm-audiocaps-captioning",
13
+ trust_remote_code=True
14
+ ).to(device)
15
+ tokenizer = PreTrainedTokenizerFast.from_pretrained(
16
+ "wsntxxn/audiocaps-simple-tokenizer"
17
+ )
18
+ elif model_name == "Clotho":
19
+ model = AutoModel.from_pretrained(
20
+ "wsntxxn/effb2-trm-clotho-captioning",
21
+ trust_remote_code=True
22
+ ).to(device)
23
+ tokenizer = PreTrainedTokenizerFast.from_pretrained(
24
+ "wsntxxn/clotho-simple-tokenizer"
25
+ )
26
  return model, tokenizer
27
 
28
 
 
37
  wav = wav.mean(1)
38
  wav = resample(wav, sr, runner.target_sr)
39
  wav_len = len(wav)
40
+ wav = wav.float().unsqueeze(0)
 
 
 
 
 
 
 
 
41
  with torch.no_grad():
42
+ word_idx = runner.model(
43
+ audio=wav,
44
+ audio_length=[wav_len]
45
+ )[0]
46
+ cap = runner.tokenizer.decode(word_idx, skip_special_tokens=True)
47
  return cap
48
 
49
  # def input_toggle(input_type):
 
56
 
57
  def __init__(self, model_name):
58
  self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
59
+ self.model, self.tokenizer = load_model(model_name, self.device)
60
+ self.target_sr = self.model.config.sample_rate
 
 
61
 
62
  def change_model(self, model_name):
63
+ self.model, self.tokenizer = load_model(model_name, self.device)
64
+ self.target_sr = self.model.config.sample_rate
 
 
65
 
66
 
67
  def change_model(radio):
checkpoints/audiocaps/ckpt.pth DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:e1c435b1cf05a2b0058dae6f096c4eb4e71c685a19754ed84ea1ee812257434b
3
- size 55293225
 
 
 
 
checkpoints/audiocaps/config.yaml DELETED
@@ -1,30 +0,0 @@
1
- tokenizer:
2
- type: text_tokenizer.DictTokenizer
3
- args:
4
- max_length: 20
5
-
6
- target_sr: 16000
7
-
8
- model:
9
- args:
10
- shared_dim: 1024
11
- tchr_dim: 768
12
- model:
13
- args: {}
14
- decoder:
15
- args:
16
- attn_emb_dim: 1408
17
- dropout: 0.2
18
- emb_dim: 256
19
- fc_emb_dim: 1408
20
- nlayers: 2
21
- tie_weights: true
22
- vocab_size: 4981
23
- type: models.transformer_decoder.TransformerDecoder
24
- encoder:
25
- args:
26
- freeze: false
27
- pretrained: true
28
- type: models.cnn_encoder.EfficientNetB2
29
- type: models.transformer_model.TransformerModel
30
- type: models.kd_wrapper.ContraEncoderKdWrapper
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
checkpoints/clotho/ckpt.pth DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:694c9e7139be7ec5aff2153d1af980d6bc305403a76be0d8940481579ea51483
3
- size 54651005
 
 
 
 
checkpoints/clotho/config.yaml DELETED
@@ -1,30 +0,0 @@
1
- tokenizer:
2
- type: text_tokenizer.DictTokenizer
3
- args:
4
- max_length: 20
5
-
6
- target_sr: 16000
7
-
8
- model:
9
- args:
10
- shared_dim: 1024
11
- tchr_dim: 768
12
- model:
13
- args: {}
14
- decoder:
15
- args:
16
- attn_emb_dim: 1408
17
- dropout: 0.2
18
- emb_dim: 256
19
- fc_emb_dim: 1408
20
- nlayers: 2
21
- tie_weights: true
22
- vocab_size: 4368
23
- type: models.transformer_decoder.TransformerDecoder
24
- encoder:
25
- args:
26
- freeze: false
27
- pretrained: true
28
- type: models.cnn_encoder.EfficientNetB2
29
- type: models.transformer_model.TransformerModel
30
- type: models.kd_wrapper.ContraEncoderKdWrapper
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
models/__init__.py DELETED
@@ -1,92 +0,0 @@
1
- import numpy as np
2
- import torch
3
- import torch.nn as nn
4
-
5
- from utils.model_util import max_with_lens, mean_with_lens
6
-
7
-
8
- def embedding_pooling(x, lens, pooling="mean"):
9
- if pooling == "max":
10
- fc_embs = max_with_lens(x, lens)
11
- elif pooling == "mean":
12
- fc_embs = mean_with_lens(x, lens)
13
- elif pooling == "mean+max":
14
- x_mean = mean_with_lens(x, lens)
15
- x_max = max_with_lens(x, lens)
16
- fc_embs = x_mean + x_max
17
- elif pooling == "last":
18
- indices = (lens - 1).reshape(-1, 1, 1).repeat(1, 1, x.size(-1))
19
- # indices: [N, 1, hidden]
20
- fc_embs = torch.gather(x, 1, indices).squeeze(1)
21
- else:
22
- raise Exception(f"pooling method {pooling} not support")
23
- return fc_embs
24
-
25
-
26
- class BaseEncoder(nn.Module):
27
-
28
- """
29
- Encode the given audio into embedding
30
- Base encoder class, cannot be called directly
31
- All encoders should inherit from this class
32
- """
33
-
34
- def __init__(self, spec_dim, fc_feat_dim, attn_feat_dim):
35
- super(BaseEncoder, self).__init__()
36
- self.spec_dim = spec_dim
37
- self.fc_feat_dim = fc_feat_dim
38
- self.attn_feat_dim = attn_feat_dim
39
-
40
-
41
- def forward(self, x):
42
- #########################
43
- # Arguments:
44
- # `x`: {
45
- # (may contain)
46
- # wav: [batch_size, n_samples],
47
- # spec: [batch_size, n_frames, spec_dim],
48
- # fc: [batch_size, fc_feat_dim],
49
- # attn: [batch_size, attn_max_len, attn_feat_dim],
50
- # attn_len: [batch_size,]
51
- # ......
52
- # }
53
- #
54
- # Returns:
55
- # `encoded`: {
56
- # fc_emb: [batch_size, fc_emb_dim],
57
- # attn_emb: [batch_size, attn_max_len, attn_emb_dim],
58
- # attn_emb_lens: [batch_size,]
59
- # }
60
- #########################
61
- raise NotImplementedError
62
-
63
-
64
- class BaseDecoder(nn.Module):
65
- """
66
- Take word/audio embeddings and output the next word probs
67
- """
68
- def __init__(self, emb_dim, vocab_size, fc_emb_dim,
69
- attn_emb_dim, dropout=0.2, tie_weights=False):
70
- super().__init__()
71
- self.emb_dim = emb_dim
72
- self.vocab_size = vocab_size
73
- self.fc_emb_dim = fc_emb_dim
74
- self.attn_emb_dim = attn_emb_dim
75
- self.tie_weights = tie_weights
76
- self.word_embedding = nn.Embedding(vocab_size, emb_dim)
77
- self.in_dropout = nn.Dropout(dropout)
78
-
79
- def forward(self, x):
80
- raise NotImplementedError
81
-
82
- def load_word_embedding(self, weight, freeze=True):
83
- embedding = np.load(weight)
84
- assert embedding.shape[0] == self.vocab_size, "vocabulary size mismatch"
85
- assert embedding.shape[1] == self.emb_dim, "embed size mismatch"
86
-
87
- # embeddings = torch.as_tensor(embeddings).float()
88
- # self.word_embeddings.weight = nn.Parameter(embeddings)
89
- # for para in self.word_embeddings.parameters():
90
- # para.requires_grad = tune
91
- self.word_embedding = nn.Embedding.from_pretrained(embedding,
92
- freeze=freeze)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
models/base.py DELETED
@@ -1,504 +0,0 @@
1
- # -*- coding: utf-8 -*-
2
-
3
- from typing import Dict
4
-
5
- import torch
6
- import torch.nn as nn
7
-
8
- from utils.model_util import mean_with_lens, repeat_tensor
9
-
10
-
11
- class CaptionMetaMixin:
12
- pad_idx = 0
13
- start_idx = 1
14
- end_idx = 2
15
- max_length = 20
16
-
17
- @classmethod
18
- def set_index(cls, start_idx, end_idx, pad_idx):
19
- cls.start_idx = start_idx
20
- cls.end_idx = end_idx
21
- cls.pad_idx = pad_idx
22
-
23
-
24
- class CaptionModel(nn.Module, CaptionMetaMixin):
25
- """
26
- Encoder-decoder captioning model.
27
- """
28
-
29
- def __init__(self, encoder: nn.Module, decoder: nn.Module, **kwargs):
30
- super().__init__()
31
- self.encoder = encoder
32
- self.decoder = decoder
33
- self.vocab_size = decoder.vocab_size
34
- self.train_forward_keys = ["cap", "cap_len", "ss_ratio"]
35
- self.inference_forward_keys = ["sample_method", "max_length", "temp"]
36
- freeze_encoder = kwargs.get("freeze_encoder", False)
37
- if freeze_encoder:
38
- for param in self.encoder.parameters():
39
- param.requires_grad = False
40
- self.check_decoder_compatibility()
41
-
42
- def check_decoder_compatibility(self):
43
- compatible_decoders = [x.__class__.__name__ for x in self.compatible_decoders]
44
- assert isinstance(self.decoder, self.compatible_decoders), \
45
- f"{self.decoder.__class__.__name__} is incompatible with " \
46
- f"{self.__class__.__name__}, please use decoder in {compatible_decoders} "
47
-
48
- def forward(self, input_dict: Dict):
49
- """
50
- input_dict: {
51
- (required)
52
- mode: train/inference,
53
- [spec, spec_len],
54
- [fc],
55
- [attn, attn_len],
56
- [wav, wav_len],
57
- [sample_method: greedy],
58
- [temp: 1.0] (in case of no teacher forcing)
59
-
60
- (optional, mode=train)
61
- cap,
62
- cap_len,
63
- ss_ratio,
64
-
65
- (optional, mode=inference)
66
- sample_method: greedy/beam,
67
- max_length,
68
- temp,
69
- beam_size (optional, sample_method=beam),
70
- n_best (optional, sample_method=beam),
71
- }
72
- """
73
- encoder_output_dict = self.encoder(input_dict)
74
- output = self.forward_decoder(input_dict, encoder_output_dict)
75
- return output
76
-
77
- def forward_decoder(self, input_dict: Dict, encoder_output_dict: Dict):
78
- if input_dict["mode"] == "train":
79
- forward_dict = {
80
- "mode": "train", "sample_method": "greedy", "temp": 1.0
81
- }
82
- for key in self.train_forward_keys:
83
- forward_dict[key] = input_dict[key]
84
- forward_dict.update(encoder_output_dict)
85
- output = self.train_forward(forward_dict)
86
- elif input_dict["mode"] == "inference":
87
- forward_dict = {"mode": "inference"}
88
- default_args = { "sample_method": "greedy", "max_length": self.max_length, "temp": 1.0 }
89
- for key in self.inference_forward_keys:
90
- if key in input_dict:
91
- forward_dict[key] = input_dict[key]
92
- else:
93
- forward_dict[key] = default_args[key]
94
-
95
- if forward_dict["sample_method"] == "beam":
96
- forward_dict["beam_size"] = input_dict.get("beam_size", 3)
97
- forward_dict["n_best"] = input_dict.get("n_best", False)
98
- forward_dict["n_best_size"] = input_dict.get("n_best_size", forward_dict["beam_size"])
99
- elif forward_dict["sample_method"] == "dbs":
100
- forward_dict["beam_size"] = input_dict.get("beam_size", 6)
101
- forward_dict["group_size"] = input_dict.get("group_size", 3)
102
- forward_dict["diversity_lambda"] = input_dict.get("diversity_lambda", 0.5)
103
- forward_dict["group_nbest"] = input_dict.get("group_nbest", True)
104
-
105
- forward_dict.update(encoder_output_dict)
106
- output = self.inference_forward(forward_dict)
107
- else:
108
- raise Exception("mode should be either 'train' or 'inference'")
109
- output.update(encoder_output_dict)
110
- return output
111
-
112
- def prepare_output(self, input_dict):
113
- output = {}
114
- batch_size = input_dict["fc_emb"].size(0)
115
- if input_dict["mode"] == "train":
116
- max_length = input_dict["cap"].size(1) - 1
117
- elif input_dict["mode"] == "inference":
118
- max_length = input_dict["max_length"]
119
- else:
120
- raise Exception("mode should be either 'train' or 'inference'")
121
- device = input_dict["fc_emb"].device
122
- output["seq"] = torch.full((batch_size, max_length), self.end_idx,
123
- dtype=torch.long)
124
- output["logit"] = torch.empty(batch_size, max_length,
125
- self.vocab_size).to(device)
126
- output["sampled_logprob"] = torch.zeros(batch_size, max_length)
127
- output["embed"] = torch.empty(batch_size, max_length,
128
- self.decoder.d_model).to(device)
129
- return output
130
-
131
- def train_forward(self, input_dict):
132
- if input_dict["ss_ratio"] != 1: # scheduled sampling training
133
- input_dict["mode"] = "train"
134
- return self.stepwise_forward(input_dict)
135
- output = self.seq_forward(input_dict)
136
- self.train_process(output, input_dict)
137
- return output
138
-
139
- def seq_forward(self, input_dict):
140
- raise NotImplementedError
141
-
142
- def train_process(self, output, input_dict):
143
- pass
144
-
145
- def inference_forward(self, input_dict):
146
- if input_dict["sample_method"] == "beam":
147
- return self.beam_search(input_dict)
148
- elif input_dict["sample_method"] == "dbs":
149
- return self.diverse_beam_search(input_dict)
150
- return self.stepwise_forward(input_dict)
151
-
152
- def stepwise_forward(self, input_dict):
153
- """Step-by-step decoding"""
154
- output = self.prepare_output(input_dict)
155
- max_length = output["seq"].size(1)
156
- # start sampling
157
- for t in range(max_length):
158
- input_dict["t"] = t
159
- self.decode_step(input_dict, output)
160
- if input_dict["mode"] == "inference": # decide whether to stop when sampling
161
- unfinished_t = output["seq"][:, t] != self.end_idx
162
- if t == 0:
163
- unfinished = unfinished_t
164
- else:
165
- unfinished *= unfinished_t
166
- output["seq"][:, t][~unfinished] = self.end_idx
167
- if unfinished.sum() == 0:
168
- break
169
- self.stepwise_process(output)
170
- return output
171
-
172
- def decode_step(self, input_dict, output):
173
- """Decoding operation of timestep t"""
174
- decoder_input = self.prepare_decoder_input(input_dict, output)
175
- # feed to the decoder to get logit
176
- output_t = self.decoder(decoder_input)
177
- logit_t = output_t["logit"]
178
- # assert logit_t.ndim == 3
179
- if logit_t.size(1) == 1:
180
- logit_t = logit_t.squeeze(1)
181
- embed_t = output_t["embed"].squeeze(1)
182
- elif logit_t.size(1) > 1:
183
- logit_t = logit_t[:, -1, :]
184
- embed_t = output_t["embed"][:, -1, :]
185
- else:
186
- raise Exception("no logit output")
187
- # sample the next input word and get the corresponding logit
188
- sampled = self.sample_next_word(logit_t,
189
- method=input_dict["sample_method"],
190
- temp=input_dict["temp"])
191
-
192
- output_t.update(sampled)
193
- output_t["t"] = input_dict["t"]
194
- output_t["logit"] = logit_t
195
- output_t["embed"] = embed_t
196
- self.stepwise_process_step(output, output_t)
197
-
198
- def prepare_decoder_input(self, input_dict, output):
199
- """Prepare the inp ut dict for the decoder"""
200
- raise NotImplementedError
201
-
202
- def stepwise_process_step(self, output, output_t):
203
- """Postprocessing (save output values) after each timestep t"""
204
- t = output_t["t"]
205
- output["logit"][:, t, :] = output_t["logit"]
206
- output["seq"][:, t] = output_t["word"]
207
- output["sampled_logprob"][:, t] = output_t["probs"]
208
- output["embed"][:, t, :] = output_t["embed"]
209
-
210
- def stepwise_process(self, output):
211
- """Postprocessing after the whole step-by-step autoregressive decoding"""
212
- pass
213
-
214
- def sample_next_word(self, logit, method, temp):
215
- """Sample the next word, given probs output by the decoder"""
216
- logprob = torch.log_softmax(logit, dim=1)
217
- if method == "greedy":
218
- sampled_logprob, word = torch.max(logprob.detach(), 1)
219
- elif method == "gumbel":
220
- def sample_gumbel(shape, eps=1e-20):
221
- U = torch.rand(shape).to(logprob.device)
222
- return -torch.log(-torch.log(U + eps) + eps)
223
- def gumbel_softmax_sample(logit, temperature):
224
- y = logit + sample_gumbel(logit.size())
225
- return torch.log_softmax(y / temperature, dim=-1)
226
- _logprob = gumbel_softmax_sample(logprob, temp)
227
- _, word = torch.max(_logprob.data, 1)
228
- sampled_logprob = logprob.gather(1, word.unsqueeze(-1))
229
- else:
230
- logprob = logprob / temp
231
- if method.startswith("top"):
232
- top_num = float(method[3:])
233
- if 0 < top_num < 1: # top-p sampling
234
- probs = torch.softmax(logit, dim=1)
235
- sorted_probs, sorted_indices = torch.sort(probs, descending=True, dim=1)
236
- _cumsum = sorted_probs.cumsum(1)
237
- mask = _cumsum < top_num
238
- mask = torch.cat([torch.ones_like(mask[:,:1]), mask[:,:-1]], 1)
239
- sorted_probs = sorted_probs * mask.to(sorted_probs)
240
- sorted_probs = sorted_probs / sorted_probs.sum(1, keepdim=True)
241
- logprob.scatter_(1, sorted_indices, sorted_probs.log())
242
- else: # top-k sampling
243
- k = int(top_num)
244
- tmp = torch.empty_like(logprob).fill_(float('-inf'))
245
- topk, indices = torch.topk(logprob, k, dim=1)
246
- tmp = tmp.scatter(1, indices, topk)
247
- logprob = tmp
248
- word = torch.distributions.Categorical(logits=logprob.detach()).sample()
249
- sampled_logprob = logprob.gather(1, word.unsqueeze(-1)).squeeze(1)
250
- word = word.detach().long()
251
- # sampled_logprob: [N,], word: [N,]
252
- return {"word": word, "probs": sampled_logprob}
253
-
254
- def beam_search(self, input_dict):
255
- output = self.prepare_output(input_dict)
256
- max_length = input_dict["max_length"]
257
- beam_size = input_dict["beam_size"]
258
- if input_dict["n_best"]:
259
- n_best_size = input_dict["n_best_size"]
260
- batch_size, max_length = output["seq"].size()
261
- output["seq"] = torch.full((batch_size, n_best_size, max_length),
262
- self.end_idx, dtype=torch.long)
263
-
264
- temp = input_dict["temp"]
265
- # instance by instance beam seach
266
- for i in range(output["seq"].size(0)):
267
- output_i = self.prepare_beamsearch_output(input_dict)
268
- input_dict["sample_idx"] = i
269
- for t in range(max_length):
270
- input_dict["t"] = t
271
- output_t = self.beamsearch_step(input_dict, output_i)
272
- #######################################
273
- # merge with previous beam and select the current max prob beam
274
- #######################################
275
- logit_t = output_t["logit"]
276
- if logit_t.size(1) == 1:
277
- logit_t = logit_t.squeeze(1)
278
- elif logit_t.size(1) > 1:
279
- logit_t = logit_t[:, -1, :]
280
- else:
281
- raise Exception("no logit output")
282
- logprob_t = torch.log_softmax(logit_t, dim=1)
283
- logprob_t = torch.log_softmax(logprob_t / temp, dim=1)
284
- logprob_t = output_i["topk_logprob"].unsqueeze(1) + logprob_t
285
- if t == 0: # for the first step, all k seq will have the same probs
286
- topk_logprob, topk_words = logprob_t[0].topk(
287
- beam_size, 0, True, True)
288
- else: # unroll and find top logprob, and their unrolled indices
289
- topk_logprob, topk_words = logprob_t.view(-1).topk(
290
- beam_size, 0, True, True)
291
- topk_words = topk_words.cpu()
292
- output_i["topk_logprob"] = topk_logprob
293
- # output_i["prev_words_beam"] = topk_words // self.vocab_size # [beam_size,]
294
- output_i["prev_words_beam"] = torch.div(topk_words, self.vocab_size,
295
- rounding_mode='trunc')
296
- output_i["next_word"] = topk_words % self.vocab_size # [beam_size,]
297
- if t == 0:
298
- output_i["seq"] = output_i["next_word"].unsqueeze(1)
299
- else:
300
- output_i["seq"] = torch.cat([
301
- output_i["seq"][output_i["prev_words_beam"]],
302
- output_i["next_word"].unsqueeze(1)], dim=1)
303
-
304
- # add finished beams to results
305
- is_end = output_i["next_word"] == self.end_idx
306
- if t == max_length - 1:
307
- is_end.fill_(1)
308
-
309
- for beam_idx in range(beam_size):
310
- if is_end[beam_idx]:
311
- final_beam = {
312
- "seq": output_i["seq"][beam_idx].clone(),
313
- "score": output_i["topk_logprob"][beam_idx].item()
314
- }
315
- final_beam["score"] = final_beam["score"] / (t + 1)
316
- output_i["done_beams"].append(final_beam)
317
- output_i["topk_logprob"][is_end] -= 1000
318
-
319
- self.beamsearch_process_step(output_i, output_t)
320
-
321
- self.beamsearch_process(output, output_i, input_dict)
322
- return output
323
-
324
- def prepare_beamsearch_output(self, input_dict):
325
- beam_size = input_dict["beam_size"]
326
- device = input_dict["fc_emb"].device
327
- output = {
328
- "topk_logprob": torch.zeros(beam_size).to(device),
329
- "seq": None,
330
- "prev_words_beam": None,
331
- "next_word": None,
332
- "done_beams": [],
333
- }
334
- return output
335
-
336
- def beamsearch_step(self, input_dict, output_i):
337
- decoder_input = self.prepare_beamsearch_decoder_input(input_dict, output_i)
338
- output_t = self.decoder(decoder_input)
339
- output_t["t"] = input_dict["t"]
340
- return output_t
341
-
342
- def prepare_beamsearch_decoder_input(self, input_dict, output_i):
343
- raise NotImplementedError
344
-
345
- def beamsearch_process_step(self, output_i, output_t):
346
- pass
347
-
348
- def beamsearch_process(self, output, output_i, input_dict):
349
- i = input_dict["sample_idx"]
350
- done_beams = sorted(output_i["done_beams"], key=lambda x: -x["score"])
351
- if input_dict["n_best"]:
352
- done_beams = done_beams[:input_dict["n_best_size"]]
353
- for out_idx, done_beam in enumerate(done_beams):
354
- seq = done_beam["seq"]
355
- output["seq"][i][out_idx, :len(seq)] = seq
356
- else:
357
- seq = done_beams[0]["seq"]
358
- output["seq"][i][:len(seq)] = seq
359
-
360
- def diverse_beam_search(self, input_dict):
361
-
362
- def add_diversity(seq_table, logprob, t, divm, diversity_lambda, bdash):
363
- local_time = t - divm
364
- unaug_logprob = logprob.clone()
365
-
366
- if divm > 0:
367
- change = torch.zeros(logprob.size(-1))
368
- for prev_choice in range(divm):
369
- prev_decisions = seq_table[prev_choice][..., local_time]
370
- for prev_labels in range(bdash):
371
- change.scatter_add_(0, prev_decisions[prev_labels], change.new_ones(1))
372
-
373
- change = change.to(logprob.device)
374
- logprob = logprob - repeat_tensor(change, bdash) * diversity_lambda
375
-
376
- return logprob, unaug_logprob
377
-
378
- output = self.prepare_output(input_dict)
379
- group_size = input_dict["group_size"]
380
- batch_size = output["seq"].size(0)
381
- beam_size = input_dict["beam_size"]
382
- bdash = beam_size // group_size
383
- input_dict["bdash"] = bdash
384
- diversity_lambda = input_dict["diversity_lambda"]
385
- device = input_dict["fc_emb"].device
386
- max_length = input_dict["max_length"]
387
- temp = input_dict["temp"]
388
- group_nbest = input_dict["group_nbest"]
389
- batch_size, max_length = output["seq"].size()
390
- if group_nbest:
391
- output["seq"] = torch.full((batch_size, beam_size, max_length),
392
- self.end_idx, dtype=torch.long)
393
- else:
394
- output["seq"] = torch.full((batch_size, group_size, max_length),
395
- self.end_idx, dtype=torch.long)
396
-
397
-
398
- for i in range(batch_size):
399
- input_dict["sample_idx"] = i
400
- seq_table = [torch.LongTensor(bdash, 0) for _ in range(group_size)] # group_size x [bdash, 0]
401
- logprob_table = [torch.zeros(bdash).to(device) for _ in range(group_size)]
402
- done_beams_table = [[] for _ in range(group_size)]
403
-
404
- output_i = {
405
- "prev_words_beam": [None for _ in range(group_size)],
406
- "next_word": [None for _ in range(group_size)],
407
- "state": [None for _ in range(group_size)]
408
- }
409
-
410
- for t in range(max_length + group_size - 1):
411
- input_dict["t"] = t
412
- for divm in range(group_size):
413
- input_dict["divm"] = divm
414
- if t >= divm and t <= max_length + divm - 1:
415
- local_time = t - divm
416
- decoder_input = self.prepare_dbs_decoder_input(input_dict, output_i)
417
- output_t = self.decoder(decoder_input)
418
- output_t["divm"] = divm
419
- logit_t = output_t["logit"]
420
- if logit_t.size(1) == 1:
421
- logit_t = logit_t.squeeze(1)
422
- elif logit_t.size(1) > 1:
423
- logit_t = logit_t[:, -1, :]
424
- else:
425
- raise Exception("no logit output")
426
- logprob_t = torch.log_softmax(logit_t, dim=1)
427
- logprob_t = torch.log_softmax(logprob_t / temp, dim=1)
428
- logprob_t, unaug_logprob_t = add_diversity(seq_table, logprob_t, t, divm, diversity_lambda, bdash)
429
- logprob_t = logprob_table[divm].unsqueeze(-1) + logprob_t
430
- if local_time == 0: # for the first step, all k seq will have the same probs
431
- topk_logprob, topk_words = logprob_t[0].topk(
432
- bdash, 0, True, True)
433
- else: # unroll and find top logprob, and their unrolled indices
434
- topk_logprob, topk_words = logprob_t.view(-1).topk(
435
- bdash, 0, True, True)
436
- topk_words = topk_words.cpu()
437
- logprob_table[divm] = topk_logprob
438
- output_i["prev_words_beam"][divm] = topk_words // self.vocab_size # [bdash,]
439
- output_i["next_word"][divm] = topk_words % self.vocab_size # [bdash,]
440
- if local_time > 0:
441
- seq_table[divm] = seq_table[divm][output_i["prev_words_beam"][divm]]
442
- seq_table[divm] = torch.cat([
443
- seq_table[divm],
444
- output_i["next_word"][divm].unsqueeze(-1)], -1)
445
-
446
- is_end = seq_table[divm][:, t-divm] == self.end_idx
447
- assert seq_table[divm].shape[-1] == t - divm + 1
448
- if t == max_length + divm - 1:
449
- is_end.fill_(1)
450
- for beam_idx in range(bdash):
451
- if is_end[beam_idx]:
452
- final_beam = {
453
- "seq": seq_table[divm][beam_idx].clone(),
454
- "score": logprob_table[divm][beam_idx].item()
455
- }
456
- final_beam["score"] = final_beam["score"] / (t - divm + 1)
457
- done_beams_table[divm].append(final_beam)
458
- logprob_table[divm][is_end] -= 1000
459
- self.dbs_process_step(output_i, output_t)
460
- done_beams_table = [sorted(done_beams_table[divm], key=lambda x: -x["score"])[:bdash] for divm in range(group_size)]
461
- if group_nbest:
462
- done_beams = sum(done_beams_table, [])
463
- else:
464
- done_beams = [group_beam[0] for group_beam in done_beams_table]
465
- for _, done_beam in enumerate(done_beams):
466
- output["seq"][i, _, :len(done_beam["seq"])] = done_beam["seq"]
467
-
468
- return output
469
-
470
- def prepare_dbs_decoder_input(self, input_dict, output_i):
471
- raise NotImplementedError
472
-
473
- def dbs_process_step(self, output_i, output_t):
474
- pass
475
-
476
-
477
- class CaptionSequenceModel(nn.Module, CaptionMetaMixin):
478
-
479
- def __init__(self, model, seq_output_size):
480
- super().__init__()
481
- self.model = model
482
- if model.decoder.d_model != seq_output_size:
483
- self.output_transform = nn.Linear(model.decoder.d_model, seq_output_size)
484
- else:
485
- self.output_transform = lambda x: x
486
-
487
- def forward(self, input_dict):
488
- output = self.model(input_dict)
489
-
490
- if input_dict["mode"] == "train":
491
- lens = input_dict["cap_len"] - 1
492
- # seq_outputs: [N, d_model]
493
- elif input_dict["mode"] == "inference":
494
- if "sample_method" in input_dict and input_dict["sample_method"] == "beam":
495
- return output
496
- seq = output["seq"]
497
- lens = torch.where(seq == self.model.end_idx, torch.zeros_like(seq), torch.ones_like(seq)).sum(dim=1)
498
- else:
499
- raise Exception("mode should be either 'train' or 'inference'")
500
- seq_output = mean_with_lens(output["embed"], lens)
501
- seq_output = self.output_transform(seq_output)
502
- output["seq_output"] = seq_output
503
- return output
504
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
models/cnn_encoder.py DELETED
@@ -1,808 +0,0 @@
1
- # -*- coding: utf-8 -*-
2
-
3
- import torch
4
- import torch.nn as nn
5
- import torch.nn.functional as F
6
- from torchaudio import transforms
7
-
8
- from utils.model_util import mean_with_lens, max_with_lens
9
- from utils.train_util import merge_load_state_dict
10
-
11
-
12
- def init_layer(layer):
13
- """Initialize a Linear or Convolutional layer. """
14
- nn.init.xavier_uniform_(layer.weight)
15
-
16
- if hasattr(layer, 'bias'):
17
- if layer.bias is not None:
18
- layer.bias.data.fill_(0.)
19
-
20
-
21
- def init_bn(bn):
22
- """Initialize a Batchnorm layer. """
23
- bn.bias.data.fill_(0.)
24
- bn.weight.data.fill_(1.)
25
-
26
-
27
- class ConvBlock(nn.Module):
28
- def __init__(self, in_channels, out_channels):
29
-
30
- super(ConvBlock, self).__init__()
31
-
32
- self.conv1 = nn.Conv2d(in_channels=in_channels,
33
- out_channels=out_channels,
34
- kernel_size=(3, 3), stride=(1, 1),
35
- padding=(1, 1), bias=False)
36
-
37
- self.conv2 = nn.Conv2d(in_channels=out_channels,
38
- out_channels=out_channels,
39
- kernel_size=(3, 3), stride=(1, 1),
40
- padding=(1, 1), bias=False)
41
-
42
- self.bn1 = nn.BatchNorm2d(out_channels)
43
- self.bn2 = nn.BatchNorm2d(out_channels)
44
-
45
- self.init_weight()
46
-
47
- def init_weight(self):
48
- init_layer(self.conv1)
49
- init_layer(self.conv2)
50
- init_bn(self.bn1)
51
- init_bn(self.bn2)
52
-
53
-
54
- def forward(self, input, pool_size=(2, 2), pool_type='avg'):
55
-
56
- x = input
57
- x = F.relu_(self.bn1(self.conv1(x)))
58
- x = F.relu_(self.bn2(self.conv2(x)))
59
- if pool_type == 'max':
60
- x = F.max_pool2d(x, kernel_size=pool_size)
61
- elif pool_type == 'avg':
62
- x = F.avg_pool2d(x, kernel_size=pool_size)
63
- elif pool_type == 'avg+max':
64
- x1 = F.avg_pool2d(x, kernel_size=pool_size)
65
- x2 = F.max_pool2d(x, kernel_size=pool_size)
66
- x = x1 + x2
67
- else:
68
- raise Exception('Incorrect argument!')
69
-
70
- return x
71
-
72
-
73
- class ConvBlock5x5(nn.Module):
74
- def __init__(self, in_channels, out_channels):
75
-
76
- super(ConvBlock5x5, self).__init__()
77
-
78
- self.conv1 = nn.Conv2d(in_channels=in_channels,
79
- out_channels=out_channels,
80
- kernel_size=(5, 5), stride=(1, 1),
81
- padding=(2, 2), bias=False)
82
-
83
- self.bn1 = nn.BatchNorm2d(out_channels)
84
-
85
- self.init_weight()
86
-
87
- def init_weight(self):
88
- init_layer(self.conv1)
89
- init_bn(self.bn1)
90
-
91
- def forward(self, input, pool_size=(2, 2), pool_type='avg'):
92
-
93
- x = input
94
- x = F.relu_(self.bn1(self.conv1(x)))
95
- if pool_type == 'max':
96
- x = F.max_pool2d(x, kernel_size=pool_size)
97
- elif pool_type == 'avg':
98
- x = F.avg_pool2d(x, kernel_size=pool_size)
99
- elif pool_type == 'avg+max':
100
- x1 = F.avg_pool2d(x, kernel_size=pool_size)
101
- x2 = F.max_pool2d(x, kernel_size=pool_size)
102
- x = x1 + x2
103
- else:
104
- raise Exception('Incorrect argument!')
105
-
106
- return x
107
-
108
-
109
- class Cnn6Encoder(nn.Module):
110
-
111
- def __init__(self, sample_rate=32000, freeze=False):
112
- super().__init__()
113
-
114
- sr_to_fmax = {
115
- 32000: 14000,
116
- 16000: 8000
117
- }
118
- # Logmel spectrogram extractor
119
- self.melspec_extractor = transforms.MelSpectrogram(
120
- sample_rate=sample_rate,
121
- n_fft=32 * sample_rate // 1000,
122
- win_length=32 * sample_rate // 1000,
123
- hop_length=10 * sample_rate // 1000,
124
- f_min=50,
125
- f_max=sr_to_fmax[sample_rate],
126
- n_mels=64,
127
- norm="slaney",
128
- mel_scale="slaney"
129
- )
130
- self.hop_length = 10 * sample_rate // 1000
131
- self.db_transform = transforms.AmplitudeToDB()
132
-
133
- self.bn0 = nn.BatchNorm2d(64)
134
-
135
- self.conv_block1 = ConvBlock5x5(in_channels=1, out_channels=64)
136
- self.conv_block2 = ConvBlock5x5(in_channels=64, out_channels=128)
137
- self.conv_block3 = ConvBlock5x5(in_channels=128, out_channels=256)
138
- self.conv_block4 = ConvBlock5x5(in_channels=256, out_channels=512)
139
-
140
- self.downsample_ratio = 16
141
-
142
- self.fc1 = nn.Linear(512, 512, bias=True)
143
- self.fc_emb_size = 512
144
- self.init_weight()
145
- self.freeze = freeze
146
-
147
- def init_weight(self):
148
- init_bn(self.bn0)
149
- init_layer(self.fc1)
150
-
151
- def load_pretrained(self, pretrained, output_fn):
152
- checkpoint = torch.load(pretrained, map_location="cpu")
153
-
154
- if "model" in checkpoint:
155
- state_dict = checkpoint["model"]
156
- else:
157
- raise Exception("Unkown checkpoint format")
158
-
159
- loaded_keys = merge_load_state_dict(state_dict, self, output_fn)
160
- if self.freeze:
161
- for name, param in self.named_parameters():
162
- if name in loaded_keys:
163
- param.requires_grad = False
164
- else:
165
- param.requires_grad = True
166
-
167
- def forward(self, input_dict):
168
- waveform = input_dict["wav"]
169
- wave_length = input_dict["wav_len"]
170
- specaug = input_dict["specaug"]
171
- x = self.melspec_extractor(waveform)
172
- x = self.db_transform(x) # (batch_size, mel_bins, time_steps)
173
- x = x.transpose(1, 2)
174
- x = x.unsqueeze(1) # (batch_size, 1, time_steps, mel_bins)
175
-
176
- x = x.transpose(1, 3)
177
- x = self.bn0(x)
178
- x = x.transpose(1, 3)
179
-
180
- x = self.conv_block1(x, pool_size=(2, 2), pool_type='avg')
181
- x = F.dropout(x, p=0.2, training=self.training)
182
- x = self.conv_block2(x, pool_size=(2, 2), pool_type='avg')
183
- x = F.dropout(x, p=0.2, training=self.training)
184
- x = self.conv_block3(x, pool_size=(2, 2), pool_type='avg')
185
- x = F.dropout(x, p=0.2, training=self.training)
186
- x = self.conv_block4(x, pool_size=(2, 2), pool_type='avg')
187
- x = F.dropout(x, p=0.2, training=self.training)
188
-
189
- x = torch.mean(x, dim=3)
190
- attn_emb = x.transpose(1, 2)
191
- wave_length = torch.as_tensor(wave_length)
192
- feat_length = torch.div(wave_length, self.hop_length,
193
- rounding_mode="floor") + 1
194
- feat_length = torch.div(feat_length, self.downsample_ratio,
195
- rounding_mode="floor")
196
- x_max = max_with_lens(attn_emb, feat_length)
197
- x_mean = mean_with_lens(attn_emb, feat_length)
198
- x = x_max + x_mean
199
- x = F.dropout(x, p=0.5, training=self.training)
200
- x = F.relu_(self.fc1(x))
201
- fc_emb = F.dropout(x, p=0.5, training=self.training)
202
-
203
- return {
204
- "attn_emb": attn_emb,
205
- "fc_emb": fc_emb,
206
- "attn_emb_len": feat_length
207
- }
208
-
209
-
210
- class Cnn10Encoder(nn.Module):
211
-
212
- def __init__(self, sample_rate=32000, freeze=False):
213
- super().__init__()
214
-
215
- sr_to_fmax = {
216
- 32000: 14000,
217
- 16000: 8000
218
- }
219
- # Logmel spectrogram extractor
220
- self.melspec_extractor = transforms.MelSpectrogram(
221
- sample_rate=sample_rate,
222
- n_fft=32 * sample_rate // 1000,
223
- win_length=32 * sample_rate // 1000,
224
- hop_length=10 * sample_rate // 1000,
225
- f_min=50,
226
- f_max=sr_to_fmax[sample_rate],
227
- n_mels=64,
228
- norm="slaney",
229
- mel_scale="slaney"
230
- )
231
- self.hop_length = 10 * sample_rate // 1000
232
- self.db_transform = transforms.AmplitudeToDB()
233
-
234
- self.bn0 = nn.BatchNorm2d(64)
235
-
236
- self.conv_block1 = ConvBlock(in_channels=1, out_channels=64)
237
- self.conv_block2 = ConvBlock(in_channels=64, out_channels=128)
238
- self.conv_block3 = ConvBlock(in_channels=128, out_channels=256)
239
- self.conv_block4 = ConvBlock(in_channels=256, out_channels=512)
240
-
241
- self.downsample_ratio = 16
242
-
243
- self.fc1 = nn.Linear(512, 512, bias=True)
244
- self.fc_emb_size = 512
245
- self.init_weight()
246
- self.freeze = freeze
247
-
248
- def init_weight(self):
249
- init_bn(self.bn0)
250
- init_layer(self.fc1)
251
-
252
- def load_pretrained(self, pretrained, output_fn):
253
- checkpoint = torch.load(pretrained, map_location="cpu")
254
-
255
- if "model" in checkpoint:
256
- state_dict = checkpoint["model"]
257
- else:
258
- raise Exception("Unkown checkpoint format")
259
-
260
- loaded_keys = merge_load_state_dict(state_dict, self, output_fn)
261
- if self.freeze:
262
- for name, param in self.named_parameters():
263
- if name in loaded_keys:
264
- param.requires_grad = False
265
- else:
266
- param.requires_grad = True
267
-
268
- def forward(self, input_dict):
269
- waveform = input_dict["wav"]
270
- wave_length = input_dict["wav_len"]
271
- specaug = input_dict["specaug"]
272
- x = self.melspec_extractor(waveform)
273
- x = self.db_transform(x) # (batch_size, mel_bins, time_steps)
274
- x = x.transpose(1, 2)
275
- x = x.unsqueeze(1) # (batch_size, 1, time_steps, mel_bins)
276
-
277
- x = x.transpose(1, 3)
278
- x = self.bn0(x)
279
- x = x.transpose(1, 3)
280
-
281
- x = self.conv_block1(x, pool_size=(2, 2), pool_type='avg')
282
- x = F.dropout(x, p=0.2, training=self.training)
283
- x = self.conv_block2(x, pool_size=(2, 2), pool_type='avg')
284
- x = F.dropout(x, p=0.2, training=self.training)
285
- x = self.conv_block3(x, pool_size=(2, 2), pool_type='avg')
286
- x = F.dropout(x, p=0.2, training=self.training)
287
- x = self.conv_block4(x, pool_size=(2, 2), pool_type='avg')
288
- x = F.dropout(x, p=0.2, training=self.training)
289
-
290
- x = torch.mean(x, dim=3)
291
- attn_emb = x.transpose(1, 2)
292
- wave_length = torch.as_tensor(wave_length)
293
- feat_length = torch.div(wave_length, self.hop_length,
294
- rounding_mode="floor") + 1
295
- feat_length = torch.div(feat_length, self.downsample_ratio,
296
- rounding_mode="floor")
297
- x_max = max_with_lens(attn_emb, feat_length)
298
- x_mean = mean_with_lens(attn_emb, feat_length)
299
- x = x_max + x_mean
300
- x = F.dropout(x, p=0.5, training=self.training)
301
- x = F.relu_(self.fc1(x))
302
- fc_emb = F.dropout(x, p=0.5, training=self.training)
303
-
304
- return {
305
- "attn_emb": attn_emb,
306
- "fc_emb": fc_emb,
307
- "attn_emb_len": feat_length
308
- }
309
-
310
-
311
- class Cnn14Encoder(nn.Module):
312
- def __init__(self, sample_rate=32000, freeze=False):
313
- super().__init__()
314
- sr_to_fmax = {
315
- 32000: 14000,
316
- 16000: 8000
317
- }
318
- # Logmel spectrogram extractor
319
- self.melspec_extractor = transforms.MelSpectrogram(
320
- sample_rate=sample_rate,
321
- n_fft=32 * sample_rate // 1000,
322
- win_length=32 * sample_rate // 1000,
323
- hop_length=10 * sample_rate // 1000,
324
- f_min=50,
325
- f_max=sr_to_fmax[sample_rate],
326
- n_mels=64,
327
- norm="slaney",
328
- mel_scale="slaney"
329
- )
330
- self.hop_length = 10 * sample_rate // 1000
331
- self.db_transform = transforms.AmplitudeToDB()
332
-
333
- self.bn0 = nn.BatchNorm2d(64)
334
-
335
- self.conv_block1 = ConvBlock(in_channels=1, out_channels=64)
336
- self.conv_block2 = ConvBlock(in_channels=64, out_channels=128)
337
- self.conv_block3 = ConvBlock(in_channels=128, out_channels=256)
338
- self.conv_block4 = ConvBlock(in_channels=256, out_channels=512)
339
- self.conv_block5 = ConvBlock(in_channels=512, out_channels=1024)
340
- self.conv_block6 = ConvBlock(in_channels=1024, out_channels=2048)
341
-
342
- self.downsample_ratio = 32
343
-
344
- self.fc1 = nn.Linear(2048, 2048, bias=True)
345
- self.fc_emb_size = 2048
346
-
347
- self.init_weight()
348
- self.freeze = freeze
349
-
350
- def init_weight(self):
351
- init_bn(self.bn0)
352
- init_layer(self.fc1)
353
-
354
- def load_pretrained(self, pretrained, output_fn):
355
- checkpoint = torch.load(pretrained, map_location="cpu")
356
-
357
- if "model" in checkpoint:
358
- state_keys = checkpoint["model"].keys()
359
- backbone = False
360
- for key in state_keys:
361
- if key.startswith("backbone."):
362
- backbone = True
363
- break
364
-
365
- if backbone: # COLA
366
- state_dict = {}
367
- for key, value in checkpoint["model"].items():
368
- if key.startswith("backbone."):
369
- model_key = key.replace("backbone.", "")
370
- state_dict[model_key] = value
371
- else: # PANNs
372
- state_dict = checkpoint["model"]
373
- elif "state_dict" in checkpoint: # BLAT
374
- state_dict = checkpoint["state_dict"]
375
- state_dict_keys = list(filter(
376
- lambda x: "audio_encoder" in x, state_dict.keys()))
377
- state_dict = {
378
- key.replace('audio_encoder.', ''): state_dict[key]
379
- for key in state_dict_keys
380
- }
381
- else:
382
- raise Exception("Unkown checkpoint format")
383
-
384
- loaded_keys = merge_load_state_dict(state_dict, self, output_fn)
385
- if self.freeze:
386
- for name, param in self.named_parameters():
387
- if name in loaded_keys:
388
- param.requires_grad = False
389
- else:
390
- param.requires_grad = True
391
-
392
- def forward(self, input_dict):
393
- waveform = input_dict["wav"]
394
- wave_length = input_dict["wav_len"]
395
- specaug = input_dict["specaug"]
396
- x = self.melspec_extractor(waveform)
397
- x = self.db_transform(x) # (batch_size, mel_bins, time_steps)
398
- x = x.transpose(1, 2)
399
- x = x.unsqueeze(1) # (batch_size, 1, time_steps, mel_bins)
400
-
401
- x = x.transpose(1, 3)
402
- x = self.bn0(x)
403
- x = x.transpose(1, 3)
404
-
405
- x = self.conv_block1(x, pool_size=(2, 2), pool_type='avg')
406
- x = F.dropout(x, p=0.2, training=self.training)
407
- x = self.conv_block2(x, pool_size=(2, 2), pool_type='avg')
408
- x = F.dropout(x, p=0.2, training=self.training)
409
- x = self.conv_block3(x, pool_size=(2, 2), pool_type='avg')
410
- x = F.dropout(x, p=0.2, training=self.training)
411
- x = self.conv_block4(x, pool_size=(2, 2), pool_type='avg')
412
- x = F.dropout(x, p=0.2, training=self.training)
413
- x = self.conv_block5(x, pool_size=(2, 2), pool_type='avg')
414
- x = F.dropout(x, p=0.2, training=self.training)
415
- x = self.conv_block6(x, pool_size=(1, 1), pool_type='avg')
416
- x = F.dropout(x, p=0.2, training=self.training)
417
- x = torch.mean(x, dim=3)
418
- attn_emb = x.transpose(1, 2)
419
-
420
- wave_length = torch.as_tensor(wave_length)
421
- feat_length = torch.div(wave_length, self.hop_length,
422
- rounding_mode="floor") + 1
423
- feat_length = torch.div(feat_length, self.downsample_ratio,
424
- rounding_mode="floor")
425
- x_max = max_with_lens(attn_emb, feat_length)
426
- x_mean = mean_with_lens(attn_emb, feat_length)
427
- x = x_max + x_mean
428
- x = F.dropout(x, p=0.5, training=self.training)
429
- x = F.relu_(self.fc1(x))
430
- fc_emb = F.dropout(x, p=0.5, training=self.training)
431
-
432
- output_dict = {
433
- 'fc_emb': fc_emb,
434
- 'attn_emb': attn_emb,
435
- 'attn_emb_len': feat_length
436
- }
437
-
438
- return output_dict
439
-
440
-
441
- class InvertedResidual(nn.Module):
442
-
443
- def __init__(self, inp, oup, stride, expand_ratio):
444
- super().__init__()
445
- self.stride = stride
446
- assert stride in [1, 2]
447
-
448
- hidden_dim = round(inp * expand_ratio)
449
- self.use_res_connect = self.stride == 1 and inp == oup
450
-
451
- if expand_ratio == 1:
452
- _layers = [
453
- nn.Conv2d(hidden_dim, hidden_dim, 3, 1, 1, groups=hidden_dim, bias=False),
454
- nn.AvgPool2d(stride),
455
- nn.BatchNorm2d(hidden_dim),
456
- nn.ReLU6(inplace=True),
457
- nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
458
- nn.BatchNorm2d(oup)
459
- ]
460
- _layers = nn.Sequential(*_layers)
461
- init_layer(_layers[0])
462
- init_bn(_layers[2])
463
- init_layer(_layers[4])
464
- init_bn(_layers[5])
465
- self.conv = _layers
466
- else:
467
- _layers = [
468
- nn.Conv2d(inp, hidden_dim, 1, 1, 0, bias=False),
469
- nn.BatchNorm2d(hidden_dim),
470
- nn.ReLU6(inplace=True),
471
- nn.Conv2d(hidden_dim, hidden_dim, 3, 1, 1, groups=hidden_dim, bias=False),
472
- nn.AvgPool2d(stride),
473
- nn.BatchNorm2d(hidden_dim),
474
- nn.ReLU6(inplace=True),
475
- nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
476
- nn.BatchNorm2d(oup)
477
- ]
478
- _layers = nn.Sequential(*_layers)
479
- init_layer(_layers[0])
480
- init_bn(_layers[1])
481
- init_layer(_layers[3])
482
- init_bn(_layers[5])
483
- init_layer(_layers[7])
484
- init_bn(_layers[8])
485
- self.conv = _layers
486
-
487
- def forward(self, x):
488
- if self.use_res_connect:
489
- return x + self.conv(x)
490
- else:
491
- return self.conv(x)
492
-
493
-
494
- class MobileNetV2(nn.Module):
495
- def __init__(self, sample_rate):
496
-
497
- super().__init__()
498
-
499
- sr_to_fmax = {
500
- 32000: 14000,
501
- 16000: 8000
502
- }
503
- # Logmel spectrogram extractor
504
- self.melspec_extractor = transforms.MelSpectrogram(
505
- sample_rate=sample_rate,
506
- n_fft=32 * sample_rate // 1000,
507
- win_length=32 * sample_rate // 1000,
508
- hop_length=10 * sample_rate // 1000,
509
- f_min=50,
510
- f_max=sr_to_fmax[sample_rate],
511
- n_mels=64,
512
- norm="slaney",
513
- mel_scale="slaney"
514
- )
515
- self.hop_length = 10 * sample_rate // 1000
516
- self.db_transform = transforms.AmplitudeToDB()
517
-
518
- self.bn0 = nn.BatchNorm2d(64)
519
-
520
- width_mult=1.
521
- block = InvertedResidual
522
- input_channel = 32
523
- last_channel = 1280
524
- interverted_residual_setting = [
525
- # t, c, n, s
526
- [1, 16, 1, 1],
527
- [6, 24, 2, 2],
528
- [6, 32, 3, 2],
529
- [6, 64, 4, 2],
530
- [6, 96, 3, 2],
531
- [6, 160, 3, 1],
532
- [6, 320, 1, 1],
533
- ]
534
-
535
- self.downsample_ratio = 32
536
-
537
- def conv_bn(inp, oup, stride):
538
- _layers = [
539
- nn.Conv2d(inp, oup, 3, 1, 1, bias=False),
540
- nn.AvgPool2d(stride),
541
- nn.BatchNorm2d(oup),
542
- nn.ReLU6(inplace=True)
543
- ]
544
- _layers = nn.Sequential(*_layers)
545
- init_layer(_layers[0])
546
- init_bn(_layers[2])
547
- return _layers
548
-
549
-
550
- def conv_1x1_bn(inp, oup):
551
- _layers = nn.Sequential(
552
- nn.Conv2d(inp, oup, 1, 1, 0, bias=False),
553
- nn.BatchNorm2d(oup),
554
- nn.ReLU6(inplace=True)
555
- )
556
- init_layer(_layers[0])
557
- init_bn(_layers[1])
558
- return _layers
559
-
560
- # building first layer
561
- input_channel = int(input_channel * width_mult)
562
- self.last_channel = int(last_channel * width_mult) if width_mult > 1.0 else last_channel
563
- self.features = [conv_bn(1, input_channel, 2)]
564
- # building inverted residual blocks
565
- for t, c, n, s in interverted_residual_setting:
566
- output_channel = int(c * width_mult)
567
- for i in range(n):
568
- if i == 0:
569
- self.features.append(block(input_channel, output_channel, s, expand_ratio=t))
570
- else:
571
- self.features.append(block(input_channel, output_channel, 1, expand_ratio=t))
572
- input_channel = output_channel
573
- # building last several layers
574
- self.features.append(conv_1x1_bn(input_channel, self.last_channel))
575
- # make it nn.Sequential
576
- self.features = nn.Sequential(*self.features)
577
-
578
- self.fc1 = nn.Linear(1280, 1024, bias=True)
579
-
580
- self.init_weight()
581
-
582
- def init_weight(self):
583
- init_bn(self.bn0)
584
- init_layer(self.fc1)
585
-
586
- def forward(self, input_dict):
587
-
588
- waveform = input_dict["wav"]
589
- wave_length = input_dict["wav_len"]
590
- specaug = input_dict["specaug"]
591
- x = self.melspec_extractor(waveform)
592
- x = self.db_transform(x) # (batch_size, mel_bins, time_steps)
593
- x = x.transpose(1, 2)
594
- x = x.unsqueeze(1) # (batch_size, 1, time_steps, mel_bins)
595
-
596
- x = x.transpose(1, 3)
597
- x = self.bn0(x)
598
- x = x.transpose(1, 3)
599
-
600
- x = self.features(x)
601
-
602
- x = torch.mean(x, dim=3)
603
- attn_emb = x.transpose(1, 2)
604
-
605
- wave_length = torch.as_tensor(wave_length)
606
- feat_length = torch.div(wave_length, self.hop_length,
607
- rounding_mode="floor") + 1
608
- feat_length = torch.div(feat_length, self.downsample_ratio,
609
- rounding_mode="floor")
610
- x_max = max_with_lens(attn_emb, feat_length)
611
- x_mean = mean_with_lens(attn_emb, feat_length)
612
- x = x_max + x_mean
613
- # TODO: the original PANNs code does not have dropout here, why?
614
- x = F.dropout(x, p=0.5, training=self.training)
615
- x = F.relu_(self.fc1(x))
616
- fc_emb = F.dropout(x, p=0.5, training=self.training)
617
-
618
- output_dict = {
619
- 'fc_emb': fc_emb,
620
- 'attn_emb': attn_emb,
621
- 'attn_emb_len': feat_length
622
- }
623
-
624
- return output_dict
625
-
626
-
627
- class MobileNetV3(nn.Module):
628
-
629
- def __init__(self,
630
- sample_rate,
631
- model_name,
632
- n_mels=64,
633
- win_length=32,
634
- pretrained=True,
635
- freeze=False,
636
- pooling="mean_max_fc"):
637
-
638
- from captioning.models.eff_at_encoder import get_model, NAME_TO_WIDTH
639
-
640
- super().__init__()
641
- sr_to_fmax = {
642
- 32000: 14000,
643
- 16000: 8000
644
- }
645
- self.n_mels = n_mels
646
- # Logmel spectrogram extractor
647
- self.melspec_extractor = transforms.MelSpectrogram(
648
- sample_rate=sample_rate,
649
- n_fft=32 * sample_rate // 1000,
650
- win_length=win_length * sample_rate // 1000,
651
- hop_length=10 * sample_rate // 1000,
652
- f_min=50,
653
- f_max=sr_to_fmax[sample_rate],
654
- n_mels=n_mels,
655
- norm="slaney",
656
- mel_scale="slaney"
657
- )
658
- self.hop_length = 10 * sample_rate // 1000
659
- self.db_transform = transforms.AmplitudeToDB()
660
-
661
- self.bn0 = nn.BatchNorm2d(n_mels)
662
-
663
- width_mult = NAME_TO_WIDTH(model_name)
664
- self.features = get_model(model_name=model_name,
665
- pretrained=pretrained,
666
- width_mult=width_mult).features
667
- self.downsample_ratio = 32
668
-
669
- if pooling == "mean_max_fc":
670
- self.fc_emb_size = 512
671
- self.fc1 = nn.Linear(self.features[-1].out_channels, 512, bias=True)
672
- elif pooling == "mean":
673
- self.fc_emb_size = self.features[-1].out_channels
674
- self.init_weight()
675
-
676
- if freeze:
677
- for param in self.parameters():
678
- param.requires_grad = False
679
-
680
- self.pooling = pooling
681
-
682
- def init_weight(self):
683
- init_bn(self.bn0)
684
- if hasattr(self, "fc1"):
685
- init_layer(self.fc1)
686
-
687
- def forward(self, input_dict):
688
-
689
- waveform = input_dict["wav"]
690
- wave_length = input_dict["wav_len"]
691
- specaug = input_dict["specaug"]
692
- x = self.melspec_extractor(waveform)
693
- x = self.db_transform(x) # (batch_size, mel_bins, time_steps)
694
- x = x.transpose(1, 2)
695
- x = x.unsqueeze(1) # (batch_size, 1, time_steps, mel_bins)
696
-
697
- x = x.transpose(1, 3)
698
- x = self.bn0(x)
699
- x = x.transpose(1, 3)
700
-
701
- x = self.features(x)
702
-
703
- x = torch.mean(x, dim=3)
704
- attn_emb = x.transpose(1, 2)
705
-
706
- wave_length = torch.as_tensor(wave_length)
707
- feat_length = torch.div(wave_length, self.hop_length,
708
- rounding_mode="floor") + 1
709
- feat_length = torch.div(feat_length, self.downsample_ratio,
710
- rounding_mode="floor")
711
-
712
- if self.pooling == "mean_max_fc":
713
- x_max = max_with_lens(attn_emb, feat_length)
714
- x_mean = mean_with_lens(attn_emb, feat_length)
715
- x = x_max + x_mean
716
- x = F.dropout(x, p=0.5, training=self.training)
717
- x = F.relu_(self.fc1(x))
718
- fc_emb = F.dropout(x, p=0.5, training=self.training)
719
- elif self.pooling == "mean":
720
- fc_emb = mean_with_lens(attn_emb, feat_length)
721
-
722
- output_dict = {
723
- 'fc_emb': fc_emb,
724
- 'attn_emb': attn_emb,
725
- 'attn_emb_len': feat_length
726
- }
727
-
728
- return output_dict
729
-
730
-
731
- class EfficientNetB2(nn.Module):
732
-
733
- def __init__(self,
734
- n_mels: int = 64,
735
- win_length: int = 32,
736
- hop_length: int = 10,
737
- f_min: int = 0,
738
- pretrained: bool = False,
739
- prune_ratio: float = 0.0,
740
- prune_se: bool = True,
741
- prune_start_layer: int = 0,
742
- prune_method: str = "operator_norm",
743
- freeze: bool = False,):
744
- from models.eff_latent_encoder import get_model, get_pruned_model
745
- super().__init__()
746
- sample_rate = 16000
747
- self.melspec_extractor = transforms.MelSpectrogram(
748
- sample_rate=sample_rate,
749
- n_fft=win_length * sample_rate // 1000,
750
- win_length=win_length * sample_rate // 1000,
751
- hop_length=hop_length * sample_rate // 1000,
752
- f_min=f_min,
753
- n_mels=n_mels,
754
- )
755
- self.hop_length = 10 * sample_rate // 1000
756
- self.db_transform = transforms.AmplitudeToDB(top_db=120)
757
- if prune_ratio > 0:
758
- self.backbone = get_pruned_model(pretrained=pretrained,
759
- prune_ratio=prune_ratio,
760
- prune_start_layer=prune_start_layer,
761
- prune_se=prune_se,
762
- prune_method=prune_method)
763
- else:
764
- self.backbone = get_model(pretrained=pretrained)
765
- self.fc_emb_size = self.backbone.eff_net._conv_head.out_channels
766
- self.downsample_ratio = 32
767
- if freeze:
768
- for param in self.parameters():
769
- param.requires_grad = False
770
-
771
- def forward(self, input_dict):
772
-
773
- waveform = input_dict["wav"]
774
- wave_length = input_dict["wav_len"]
775
- specaug = input_dict["specaug"]
776
- x = self.melspec_extractor(waveform)
777
- x = self.db_transform(x) # (batch_size, mel_bins, time_steps)
778
-
779
- x = self.backbone(x)
780
- attn_emb = x
781
-
782
- wave_length = torch.as_tensor(wave_length)
783
- feat_length = torch.div(wave_length, self.hop_length,
784
- rounding_mode="floor") + 1
785
- feat_length = torch.div(feat_length, self.downsample_ratio,
786
- rounding_mode="floor")
787
- fc_emb = mean_with_lens(attn_emb, feat_length)
788
-
789
- output_dict = {
790
- 'fc_emb': fc_emb,
791
- 'attn_emb': attn_emb,
792
- 'attn_emb_len': feat_length
793
- }
794
- return output_dict
795
-
796
-
797
- if __name__ == "__main__":
798
- encoder = MobileNetV3(32000, "mn10_as")
799
- print(encoder)
800
- input_dict = {
801
- "wav": torch.randn(4, 320000),
802
- "wav_len": torch.tensor([320000, 280000, 160000, 300000]),
803
- "specaug": True
804
- }
805
- output_dict = encoder(input_dict)
806
- print("attn embed: ", output_dict["attn_emb"].shape)
807
- print("fc embed: ", output_dict["fc_emb"].shape)
808
- print("attn embed length: ", output_dict["attn_emb_len"])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
models/eff_latent_encoder.py DELETED
@@ -1,347 +0,0 @@
1
- import os
2
-
3
- import torch
4
- import torch.nn as nn
5
- from tqdm import tqdm
6
- from efficientnet_pytorch import EfficientNet
7
- from efficientnet_pytorch.model import MBConvBlock
8
- from efficientnet_pytorch import utils as efficientnet_utils
9
- from efficientnet_pytorch.utils import (
10
- round_filters,
11
- round_repeats,
12
- get_same_padding_conv2d,
13
- calculate_output_image_size,
14
- MemoryEfficientSwish,
15
- )
16
- from einops import rearrange, reduce
17
- from torch.hub import load_state_dict_from_url
18
-
19
-
20
- model_dir = os.getcwd()
21
-
22
-
23
- class _EffiNet(nn.Module):
24
- """A proxy for efficient net models"""
25
- def __init__(self,
26
- blocks_args=None,
27
- global_params=None,
28
- prune_start_layer: int = 0,
29
- prune_se: bool = True,
30
- prune_ratio: float = 0.0
31
- ) -> None:
32
- super().__init__()
33
- if prune_ratio > 0:
34
- self.eff_net = EfficientNetB2Pruned(blocks_args=blocks_args,
35
- global_params=global_params,
36
- prune_start_layer=prune_start_layer,
37
- prune_se=prune_se,
38
- prune_ratio=prune_ratio)
39
- else:
40
- self.eff_net = EfficientNet(blocks_args=blocks_args,
41
- global_params=global_params)
42
-
43
-
44
- def forward(self, x: torch.Tensor):
45
- x = rearrange(x, 'b f t -> b 1 f t')
46
- x = self.eff_net.extract_features(x)
47
- return reduce(x, 'b c f t -> b t c', 'mean')
48
-
49
-
50
- def get_model(pretrained=True) -> _EffiNet:
51
- blocks_args, global_params = efficientnet_utils.get_model_params(
52
- 'efficientnet-b2', {'include_top': False})
53
- model = _EffiNet(blocks_args=blocks_args,
54
- global_params=global_params)
55
- model.eff_net._change_in_channels(1)
56
- if pretrained:
57
- model_path = os.path.join(model_dir, "effb2.pt")
58
- if not os.path.exists(model_path):
59
- state_dict = load_state_dict_from_url(
60
- 'https://github.com/richermans/HEAR2021_EfficientLatent/releases/download/v0.0.1/effb2.pt',
61
- progress=True,
62
- model_dir=model_dir)
63
- else:
64
- state_dict = torch.load(model_path)
65
- del_keys = [key for key in state_dict if key.startswith("front_end")]
66
- for key in del_keys:
67
- del state_dict[key]
68
- model.eff_net.load_state_dict(state_dict)
69
- return model
70
-
71
-
72
- class MBConvBlockPruned(MBConvBlock):
73
-
74
- def __init__(self, block_args, global_params, image_size=None, prune_ratio=0.5, prune_se=True):
75
- super(MBConvBlock, self).__init__()
76
- self._block_args = block_args
77
- self._bn_mom = 1 - global_params.batch_norm_momentum # pytorch's difference from tensorflow
78
- self._bn_eps = global_params.batch_norm_epsilon
79
- self.has_se = (self._block_args.se_ratio is not None) and (0 < self._block_args.se_ratio <= 1)
80
- self.id_skip = block_args.id_skip # whether to use skip connection and drop connect
81
-
82
- # Expansion phase (Inverted Bottleneck)
83
- inp = self._block_args.input_filters # number of input channels
84
- oup = self._block_args.input_filters * self._block_args.expand_ratio # number of output channels
85
- if self._block_args.expand_ratio != 1:
86
- oup = int(oup * (1 - prune_ratio))
87
- Conv2d = get_same_padding_conv2d(image_size=image_size)
88
- self._expand_conv = Conv2d(in_channels=inp, out_channels=oup, kernel_size=1, bias=False)
89
- self._bn0 = nn.BatchNorm2d(num_features=oup, momentum=self._bn_mom, eps=self._bn_eps)
90
- # image_size = calculate_output_image_size(image_size, 1) <-- this wouldn't modify image_size
91
-
92
- # Depthwise convolution phase
93
- k = self._block_args.kernel_size
94
- s = self._block_args.stride
95
- Conv2d = get_same_padding_conv2d(image_size=image_size)
96
- self._depthwise_conv = Conv2d(
97
- in_channels=oup, out_channels=oup, groups=oup, # groups makes it depthwise
98
- kernel_size=k, stride=s, bias=False)
99
- self._bn1 = nn.BatchNorm2d(num_features=oup, momentum=self._bn_mom, eps=self._bn_eps)
100
- image_size = calculate_output_image_size(image_size, s)
101
-
102
- # Squeeze and Excitation layer, if desired
103
- if self.has_se:
104
- Conv2d = get_same_padding_conv2d(image_size=(1, 1))
105
- num_squeezed_channels = max(1, int(self._block_args.input_filters * self._block_args.se_ratio))
106
- if prune_se:
107
- num_squeezed_channels = int(num_squeezed_channels * (1 - prune_ratio))
108
- self._se_reduce = Conv2d(in_channels=oup, out_channels=num_squeezed_channels, kernel_size=1)
109
- self._se_expand = Conv2d(in_channels=num_squeezed_channels, out_channels=oup, kernel_size=1)
110
-
111
- # Pointwise convolution phase
112
- final_oup = self._block_args.output_filters
113
- Conv2d = get_same_padding_conv2d(image_size=image_size)
114
- self._project_conv = Conv2d(in_channels=oup, out_channels=final_oup, kernel_size=1, bias=False)
115
- self._bn2 = nn.BatchNorm2d(num_features=final_oup, momentum=self._bn_mom, eps=self._bn_eps)
116
- self._swish = MemoryEfficientSwish()
117
-
118
-
119
- class EfficientNetB2Pruned(EfficientNet):
120
-
121
- def __init__(self, blocks_args=None, global_params=None,
122
- prune_start_layer=0, prune_ratio=0.5, prune_se=True):
123
- super(EfficientNet, self).__init__()
124
- assert isinstance(blocks_args, list), 'blocks_args should be a list'
125
- assert len(blocks_args) > 0, 'block args must be greater than 0'
126
- self._global_params = global_params
127
- self._blocks_args = blocks_args
128
-
129
- # Batch norm parameters
130
- bn_mom = 1 - self._global_params.batch_norm_momentum
131
- bn_eps = self._global_params.batch_norm_epsilon
132
-
133
- # Get stem static or dynamic convolution depending on image size
134
- image_size = global_params.image_size
135
- Conv2d = get_same_padding_conv2d(image_size=image_size)
136
-
137
- n_build_blks = 0
138
- # Stem
139
- in_channels = 1 # spectrogram
140
-
141
- p = 0.0 if n_build_blks < prune_start_layer else prune_ratio
142
- out_channels = round_filters(32 * (1 - p),
143
- self._global_params) # number of output channels
144
- self._conv_stem = Conv2d(in_channels, out_channels, kernel_size=3, stride=2, bias=False)
145
- self._bn0 = nn.BatchNorm2d(num_features=out_channels, momentum=bn_mom, eps=bn_eps)
146
- image_size = calculate_output_image_size(image_size, 2)
147
- n_build_blks += 1
148
-
149
- # Build blocks
150
- self._blocks = nn.ModuleList([])
151
- for block_args in self._blocks_args:
152
-
153
- p = 0.0 if n_build_blks < prune_start_layer else prune_ratio
154
- orig_input_filters = block_args.input_filters
155
- # Update block input and output filters based on depth multiplier.
156
- block_args = block_args._replace(
157
- input_filters=round_filters(
158
- block_args.input_filters * (1 - p),
159
- self._global_params),
160
- output_filters=round_filters(
161
- block_args.output_filters * (1 - p),
162
- self._global_params),
163
- num_repeat=round_repeats(block_args.num_repeat, self._global_params)
164
- )
165
-
166
- if n_build_blks == prune_start_layer:
167
- block_args = block_args._replace(input_filters=round_filters(
168
- orig_input_filters,
169
- self._global_params)
170
- )
171
-
172
- # The first block needs to take care of stride and filter size increase.
173
- self._blocks.append(MBConvBlockPruned(block_args, self._global_params,
174
- image_size=image_size, prune_ratio=p,
175
- prune_se=prune_se))
176
- n_build_blks += 1
177
-
178
- image_size = calculate_output_image_size(image_size, block_args.stride)
179
- if block_args.num_repeat > 1: # modify block_args to keep same output size
180
- block_args = block_args._replace(input_filters=block_args.output_filters, stride=1)
181
- for _ in range(block_args.num_repeat - 1):
182
- self._blocks.append(MBConvBlockPruned(block_args,
183
- self._global_params,
184
- image_size=image_size,
185
- prune_ratio=p,
186
- prune_se=prune_se))
187
- # image_size = calculate_output_image_size(image_size, block_args.stride) # stride = 1
188
-
189
- # Head
190
- in_channels = block_args.output_filters # output of final block
191
- p = 0.0 if n_build_blks < prune_start_layer else prune_ratio
192
- out_channels = round_filters(1280 * (1 - p), self._global_params)
193
- Conv2d = get_same_padding_conv2d(image_size=image_size)
194
- self._conv_head = Conv2d(in_channels, out_channels, kernel_size=1, bias=False)
195
- self._bn1 = nn.BatchNorm2d(num_features=out_channels, momentum=bn_mom, eps=bn_eps)
196
-
197
- # Final linear layer
198
- self._avg_pooling = nn.AdaptiveAvgPool2d(1)
199
- if self._global_params.include_top:
200
- self._dropout = nn.Dropout(self._global_params.dropout_rate)
201
- self._fc = nn.Linear(out_channels, self._global_params.num_classes)
202
-
203
- # set activation to memory efficient swish by default
204
- self._swish = MemoryEfficientSwish()
205
-
206
-
207
- def get_pruned_model(pretrained: bool = True,
208
- prune_ratio: float = 0.5,
209
- prune_start_layer: int = 0,
210
- prune_se: bool = True,
211
- prune_method: str = "operator_norm") -> _EffiNet:
212
-
213
- import captioning.models.conv_filter_pruning as pruning_lib
214
-
215
- blocks_args, global_params = efficientnet_utils.get_model_params(
216
- 'efficientnet-b2', {'include_top': False})
217
- # print("num blocks: ", len(blocks_args))
218
- # print("block args: ")
219
- # for block_arg in blocks_args:
220
- # print(block_arg)
221
- model = _EffiNet(blocks_args=blocks_args,
222
- global_params=global_params,
223
- prune_start_layer=prune_start_layer,
224
- prune_se=prune_se,
225
- prune_ratio=prune_ratio)
226
-
227
- if prune_method == "operator_norm":
228
- filter_pruning = pruning_lib.operator_norm_pruning
229
- elif prune_method == "interspeech":
230
- filter_pruning = pruning_lib.cs_interspeech
231
- elif prune_method == "iclr_l1":
232
- filter_pruning = pruning_lib.iclr_l1
233
- elif prune_method == "iclr_gm":
234
- filter_pruning = pruning_lib.iclr_gm
235
- elif prune_method == "cs_waspaa":
236
- filter_pruning = pruning_lib.cs_waspaa
237
-
238
-
239
- if isinstance(pretrained, str):
240
- ckpt = torch.load(pretrained, "cpu")
241
- state_dict = {}
242
- for key in ckpt["model"].keys():
243
- if key.startswith("model.encoder.backbone"):
244
- state_dict[key[len("model.encoder.backbone.eff_net."):]] = ckpt["model"][key]
245
- elif isinstance(pretrained, bool):
246
- model_path = os.path.join(model_dir, "effb2.pt")
247
- if not os.path.exists(model_path):
248
- state_dict = load_state_dict_from_url(
249
- 'https://github.com/richermans/HEAR2021_EfficientLatent/releases/download/v0.0.1/effb2.pt',
250
- progress=True,
251
- model_dir=model_dir)
252
- else:
253
- state_dict = torch.load(model_path)
254
- del_keys = [key for key in state_dict if key.startswith("front_end")]
255
- for key in del_keys:
256
- del state_dict[key]
257
-
258
- # load pretrained model with corresponding filters
259
- # rule:
260
- # * depthwise_conv: in_ch_idx = out_ch_idx = prev_conv_idx
261
- mod_dep_path = [
262
- "_conv_stem",
263
- ]
264
- conv_to_bn = {"_conv_stem": "_bn0"}
265
- for i in range(2):
266
- mod_dep_path.extend([
267
- f"_blocks.{i}._depthwise_conv",
268
- f"_blocks.{i}._se_reduce",
269
- f"_blocks.{i}._se_expand",
270
- f"_blocks.{i}._project_conv",
271
- ])
272
- conv_to_bn[f"_blocks.{i}._depthwise_conv"] = f"_blocks.{i}._bn1"
273
- conv_to_bn[f"_blocks.{i}._project_conv"] = f"_blocks.{i}._bn2"
274
-
275
- for i in range(2, 23):
276
- mod_dep_path.extend([
277
- f"_blocks.{i}._expand_conv",
278
- f"_blocks.{i}._depthwise_conv",
279
- f"_blocks.{i}._se_reduce",
280
- f"_blocks.{i}._se_expand",
281
- f"_blocks.{i}._project_conv"
282
- ])
283
- conv_to_bn[f"_blocks.{i}._expand_conv"] = f"_blocks.{i}._bn0"
284
- conv_to_bn[f"_blocks.{i}._depthwise_conv"] = f"_blocks.{i}._bn1"
285
- conv_to_bn[f"_blocks.{i}._project_conv"] = f"_blocks.{i}._bn2"
286
-
287
- mod_dep_path.append("_conv_head")
288
- conv_to_bn["_conv_head"] = "_bn1"
289
-
290
- # print(mod_dep_path)
291
- # print(conv_to_bn)
292
-
293
- key_to_w_b_idx = {}
294
- model_dict = model.eff_net.state_dict()
295
- for conv_key in tqdm(mod_dep_path):
296
- weight = state_dict[f"{conv_key}.weight"]
297
- ptr_n_filter = weight.size(0)
298
- model_n_filter = model_dict[f"{conv_key}.weight"].size(0)
299
- if model_n_filter < ptr_n_filter:
300
- key_to_w_b_idx[conv_key] = filter_pruning(weight.numpy())[:model_n_filter]
301
- else:
302
- key_to_w_b_idx[conv_key] = slice(None)
303
-
304
- pruned_state_dict = {}
305
- for conv_key, prev_conv_key in zip(mod_dep_path, [None] + mod_dep_path[:-1]):
306
-
307
- for sub_key in ["weight", "bias"]: # adjust the conv layer
308
- cur_key = f"{conv_key}.{sub_key}"
309
-
310
- if cur_key not in state_dict:
311
- continue
312
-
313
- if prev_conv_key is None or conv_key.endswith("_depthwise_conv"):
314
- conv_in_idx = slice(None)
315
- else:
316
- conv_in_idx = key_to_w_b_idx[prev_conv_key]
317
-
318
- # the first pruned layer
319
- if model_dict[cur_key].ndim > 1 and model_dict[cur_key].size(1) == state_dict[cur_key].size(1):
320
- conv_in_idx = slice(None)
321
-
322
- if conv_key.endswith("_depthwise_conv"):
323
- conv_out_idx = key_to_w_b_idx[prev_conv_key]
324
- else:
325
- conv_out_idx = key_to_w_b_idx[conv_key]
326
-
327
- # if conv_key == "_blocks.16._se_reduce":
328
- # print(len(conv_out_idx), len(conv_in_idx))
329
-
330
- if sub_key == "weight":
331
- pruned_state_dict[cur_key] = state_dict[cur_key][
332
- conv_out_idx, ...][:, conv_in_idx, ...]
333
- else:
334
- pruned_state_dict[cur_key] = state_dict[cur_key][
335
- conv_out_idx, ...]
336
-
337
- if conv_key in conv_to_bn: # adjust the corresponding bn layer
338
- for sub_key in ["weight", "bias", "running_mean", "running_var"]:
339
- cur_key = f"{conv_to_bn[conv_key]}.{sub_key}"
340
- if cur_key not in state_dict:
341
- continue
342
- pruned_state_dict[cur_key] = state_dict[cur_key][
343
- key_to_w_b_idx[conv_key], ...]
344
-
345
- model.eff_net.load_state_dict(pruned_state_dict)
346
-
347
- return model
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
models/kd_wrapper.py DELETED
@@ -1,226 +0,0 @@
1
- from typing import Dict
2
-
3
- import numpy as np
4
- import torch
5
- import torch.nn as nn
6
- import torch.nn.functional as F
7
- from einops import repeat
8
-
9
- from models.base import CaptionMetaMixin
10
- from utils.model_util import init
11
-
12
-
13
- class WmlEncoderKdWrapper(nn.Module, CaptionMetaMixin):
14
-
15
- def __init__(self,
16
- model: nn.Module,
17
- shared_dim: int,
18
- tchr_layer_to_dims: Dict[str, int],
19
- loss_type: str = "mse",):
20
- super().__init__()
21
- self.model = model
22
- self.tchr_layers = list(tchr_layer_to_dims.keys())
23
- self.stdnt_qv_proj = nn.Linear(model.encoder.fc_emb_size,
24
- 2 * shared_dim)
25
- self.stdnt_qv_proj.apply(init)
26
- for layer, dim in tchr_layer_to_dims.items():
27
- self.add_module(f'tchr_kv_proj_{layer}', nn.Linear(dim, 2 * shared_dim))
28
- getattr(self, f'tchr_kv_proj_{layer}').apply(init)
29
- if loss_type == "mse":
30
- self.loss_fn = nn.MSELoss(reduction="none")
31
-
32
- def forward(self, input_dict: Dict):
33
- output_dict = self.model(input_dict)
34
- if "tchr_output" in input_dict:
35
- stdnt_emb = output_dict["fc_emb"]
36
- stdnt_qv = self.stdnt_qv_proj(stdnt_emb)
37
- stdnt_q, stdnt_v = torch.chunk(stdnt_qv, 2, dim=-1)
38
-
39
- tchr_output = input_dict["tchr_output"]
40
- layer_ks, layer_vs = [], []
41
- for layer in self.tchr_layers:
42
- layer_kv = getattr(self, f'tchr_kv_proj_{layer}')(tchr_output[layer])
43
- layer_k, layer_v = torch.chunk(layer_kv, 2, dim=-1)
44
- layer_ks.append(layer_k)
45
- layer_vs.append(layer_v)
46
- layer_ks = torch.stack(layer_ks, dim=1)
47
- layer_vs = torch.stack(layer_vs, dim=1)
48
- weights = torch.softmax(stdnt_q.unsqueeze(1) @ layer_ks.transpose(1, 2), dim=-1)
49
- stdnt_v = repeat(stdnt_v, 'b d -> b n d', n=len(self.tchr_layers))
50
- loss = self.loss_fn(stdnt_v, layer_vs).mean(dim=-1, keepdim=True)
51
- loss = (weights @ loss).mean()
52
- output_dict["enc_kd_loss"] = loss
53
- return output_dict
54
-
55
-
56
- class MseEncoderKdWrapper(nn.Module, CaptionMetaMixin):
57
-
58
- def __init__(self,
59
- model: nn.Module,
60
- shared_dim: int,
61
- tchr_dim: int,
62
- use_tchr_proj: bool = True,
63
- l2_norm: bool = False,
64
- ):
65
- super().__init__()
66
- self.model = model
67
- self.use_tchr_proj = use_tchr_proj
68
- if not use_tchr_proj:
69
- assert shared_dim == tchr_dim
70
- self.tchr_dim = tchr_dim
71
- self.l2_norm = l2_norm
72
- if hasattr(model, "encoder"):
73
- self.stdnt_proj = nn.Linear(model.encoder.fc_emb_size,
74
- shared_dim)
75
- else:
76
- self.stdnt_proj = nn.Linear(model.fc_emb_size,
77
- shared_dim)
78
- self.stdnt_proj.apply(init)
79
- if use_tchr_proj:
80
- self.tchr_proj = nn.Linear(tchr_dim, shared_dim)
81
- self.tchr_proj.apply(init)
82
- else:
83
- self.tchr_proj = nn.Identity()
84
-
85
- def forward(self, input_dict: Dict):
86
- unsup = input_dict.get("unsup", False)
87
- if unsup is False:
88
- if self.use_tchr_proj:
89
- output_dict = self.model(input_dict)
90
- stdnt_emb = output_dict["fc_emb"]
91
- else:
92
- encoder_output = self.model.encoder(input_dict)
93
- stdnt_emb = encoder_output["fc_emb"]
94
- encoder_output["fc_emb"] = self.stdnt_proj(encoder_output["fc_emb"])
95
- encoder_output["attn_emb"] = self.stdnt_proj(encoder_output["attn_emb"])
96
- output_dict = self.model.forward_decoder(input_dict, encoder_output)
97
- else:
98
- output_dict = self.model.encoder(input_dict)
99
- stdnt_emb = output_dict["fc_emb"]
100
- if "tchr_output" in input_dict:
101
- stdnt_emb = self.stdnt_proj(stdnt_emb)
102
- tchr_emb = input_dict["tchr_output"]["embedding"]
103
- thcr_emb = self.tchr_proj(tchr_emb)
104
-
105
- if self.l2_norm:
106
- stdnt_emb = F.normalize(stdnt_emb, dim=-1)
107
- thcr_emb = F.normalize(thcr_emb, dim=-1)
108
-
109
- loss = F.mse_loss(stdnt_emb, thcr_emb)
110
- output_dict["enc_kd_loss"] = loss
111
- return output_dict
112
-
113
-
114
- class ContraEncoderKdWrapper(nn.Module, CaptionMetaMixin):
115
-
116
- def __init__(self,
117
- model: nn.Module,
118
- shared_dim: int,
119
- tchr_dim: int,
120
- ):
121
- super().__init__()
122
- self.model = model
123
- self.tchr_dim = tchr_dim
124
- if hasattr(model, "encoder"):
125
- self.stdnt_proj = nn.Linear(model.encoder.fc_emb_size,
126
- shared_dim)
127
- else:
128
- self.stdnt_proj = nn.Linear(model.fc_emb_size,
129
- shared_dim)
130
- self.stdnt_proj.apply(init)
131
- self.tchr_proj = nn.Linear(tchr_dim, shared_dim)
132
- self.tchr_proj.apply(init)
133
- self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
134
-
135
- def forward(self, input_dict: Dict):
136
- unsup = input_dict.get("unsup", False)
137
- if unsup is False:
138
- output_dict = self.model(input_dict)
139
- else:
140
- output_dict = self.model.encoder(input_dict)
141
- if "tchr_output" in input_dict:
142
- stdnt_emb = output_dict["fc_emb"]
143
- stdnt_emb = self.stdnt_proj(stdnt_emb)
144
- tchr_emb = input_dict["tchr_output"]["embedding"]
145
- thcr_emb = self.tchr_proj(tchr_emb)
146
-
147
- stdnt_emb = F.normalize(stdnt_emb, dim=-1)
148
- thcr_emb = F.normalize(thcr_emb, dim=-1)
149
-
150
- unscaled_logit = stdnt_emb @ thcr_emb.transpose(0, 1)
151
- logit = self.logit_scale * unscaled_logit
152
- label = torch.arange(logit.shape[0]).to(logit.device)
153
- loss1 = F.cross_entropy(logit, label)
154
- loss2 = F.cross_entropy(logit.transpose(0, 1), label)
155
- loss = (loss1 + loss2) / 2
156
- output_dict["enc_kd_loss"] = loss
157
- return output_dict
158
-
159
-
160
- class ContraMseEncoderKdWrapper(nn.Module, CaptionMetaMixin):
161
-
162
- def __init__(self,
163
- model: nn.Module,
164
- shared_dim: int,
165
- tchr_dim: int,
166
- use_tchr_proj: bool = True,
167
- l2_norm: bool = False,
168
- ):
169
- super().__init__()
170
- self.model = model
171
- self.use_tchr_proj = use_tchr_proj
172
- if not use_tchr_proj:
173
- assert shared_dim == tchr_dim
174
- self.tchr_dim = tchr_dim
175
- self.l2_norm = l2_norm
176
- if hasattr(model, "encoder"):
177
- self.stdnt_proj = nn.Linear(model.encoder.fc_emb_size,
178
- shared_dim)
179
- else:
180
- self.stdnt_proj = nn.Linear(model.fc_emb_size,
181
- shared_dim)
182
- self.stdnt_proj.apply(init)
183
- if use_tchr_proj:
184
- self.tchr_proj = nn.Linear(tchr_dim, shared_dim)
185
- self.tchr_proj.apply(init)
186
- else:
187
- self.tchr_proj = nn.Identity()
188
- self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
189
-
190
- def forward(self, input_dict: Dict):
191
- unsup = input_dict.get("unsup", False)
192
- if unsup is False:
193
- if self.use_tchr_proj:
194
- output_dict = self.model(input_dict)
195
- stdnt_emb = output_dict["fc_emb"]
196
- else:
197
- encoder_output = self.model.encoder(input_dict)
198
- stdnt_emb = encoder_output["fc_emb"]
199
- encoder_output["fc_emb"] = self.stdnt_proj(encoder_output["fc_emb"])
200
- encoder_output["attn_emb"] = self.stdnt_proj(encoder_output["attn_emb"])
201
- output_dict = self.model.forward_decoder(input_dict, encoder_output)
202
- else:
203
- output_dict = self.model.encoder(input_dict)
204
- stdnt_emb = output_dict["fc_emb"]
205
- if "tchr_output" in input_dict:
206
- stdnt_emb = self.stdnt_proj(stdnt_emb)
207
- tchr_emb = input_dict["tchr_output"]["embedding"]
208
- thcr_emb = self.tchr_proj(tchr_emb)
209
-
210
- if self.l2_norm:
211
- stdnt_emb = F.normalize(stdnt_emb, dim=-1)
212
- thcr_emb = F.normalize(thcr_emb, dim=-1)
213
-
214
- mse_loss = F.mse_loss(stdnt_emb, thcr_emb)
215
-
216
- stdnt_emb = F.normalize(stdnt_emb, dim=-1)
217
- thcr_emb = F.normalize(thcr_emb, dim=-1)
218
- unscaled_logit = stdnt_emb @ thcr_emb.transpose(0, 1)
219
- logit = self.logit_scale * unscaled_logit
220
- label = torch.arange(logit.shape[0]).to(logit.device)
221
- loss1 = F.cross_entropy(logit, label)
222
- loss2 = F.cross_entropy(logit.transpose(0, 1), label)
223
- cntr_loss = (loss1 + loss2) / 2
224
- output_dict["enc_kd_loss"] = mse_loss + cntr_loss
225
-
226
- return output_dict
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
models/transformer_decoder.py DELETED
@@ -1,214 +0,0 @@
1
- import math
2
-
3
- import torch
4
- import torch.nn as nn
5
-
6
- from models import BaseDecoder
7
- from utils.model_util import generate_length_mask, PositionalEncoding
8
- from utils.train_util import merge_load_state_dict
9
-
10
-
11
- class TransformerDecoder(BaseDecoder):
12
-
13
- def __init__(self,
14
- emb_dim,
15
- vocab_size,
16
- fc_emb_dim,
17
- attn_emb_dim,
18
- dropout,
19
- freeze=False,
20
- tie_weights=False,
21
- **kwargs):
22
- super().__init__(emb_dim, vocab_size, fc_emb_dim, attn_emb_dim,
23
- dropout=dropout, tie_weights=tie_weights)
24
- self.d_model = emb_dim
25
- self.nhead = kwargs.get("nhead", self.d_model // 64)
26
- self.nlayers = kwargs.get("nlayers", 2)
27
- self.dim_feedforward = kwargs.get("dim_feedforward", self.d_model * 4)
28
-
29
- self.pos_encoder = PositionalEncoding(self.d_model, dropout)
30
- layer = nn.TransformerDecoderLayer(d_model=self.d_model,
31
- nhead=self.nhead,
32
- dim_feedforward=self.dim_feedforward,
33
- dropout=dropout)
34
- self.model = nn.TransformerDecoder(layer, self.nlayers)
35
- self.classifier = nn.Linear(self.d_model, vocab_size, bias=False)
36
- if tie_weights:
37
- self.classifier.weight = self.word_embedding.weight
38
- self.attn_proj = nn.Sequential(
39
- nn.Linear(self.attn_emb_dim, self.d_model),
40
- nn.ReLU(),
41
- nn.Dropout(dropout),
42
- nn.LayerNorm(self.d_model)
43
- )
44
- self.init_params()
45
-
46
- self.freeze = freeze
47
- if freeze:
48
- for p in self.parameters():
49
- p.requires_grad = False
50
-
51
- def init_params(self):
52
- for p in self.parameters():
53
- if p.dim() > 1:
54
- nn.init.xavier_uniform_(p)
55
-
56
- def load_pretrained(self, pretrained, output_fn):
57
- checkpoint = torch.load(pretrained, map_location="cpu")
58
-
59
- if "model" in checkpoint:
60
- checkpoint = checkpoint["model"]
61
- if next(iter(checkpoint)).startswith("decoder."):
62
- state_dict = {}
63
- for k, v in checkpoint.items():
64
- state_dict[k[8:]] = v
65
-
66
- loaded_keys = merge_load_state_dict(state_dict, self, output_fn)
67
- if self.freeze:
68
- for name, param in self.named_parameters():
69
- if name in loaded_keys:
70
- param.requires_grad = False
71
- else:
72
- param.requires_grad = True
73
-
74
-
75
- def generate_square_subsequent_mask(self, max_length):
76
- mask = (torch.triu(torch.ones(max_length, max_length)) == 1).transpose(0, 1)
77
- mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
78
- return mask
79
-
80
- def forward(self, input_dict):
81
- word = input_dict["word"]
82
- attn_emb = input_dict["attn_emb"]
83
- attn_emb_len = input_dict["attn_emb_len"]
84
- cap_padding_mask = input_dict["cap_padding_mask"]
85
-
86
- p_attn_emb = self.attn_proj(attn_emb)
87
- p_attn_emb = p_attn_emb.transpose(0, 1) # [T_src, N, emb_dim]
88
- word = word.to(attn_emb.device)
89
- embed = self.in_dropout(self.word_embedding(word)) * math.sqrt(self.emb_dim) # [N, T, emb_dim]
90
- embed = embed.transpose(0, 1) # [T, N, emb_dim]
91
- embed = self.pos_encoder(embed)
92
-
93
- tgt_mask = self.generate_square_subsequent_mask(embed.size(0)).to(attn_emb.device)
94
- memory_key_padding_mask = ~generate_length_mask(attn_emb_len, attn_emb.size(1)).to(attn_emb.device)
95
- output = self.model(embed, p_attn_emb, tgt_mask=tgt_mask,
96
- tgt_key_padding_mask=cap_padding_mask,
97
- memory_key_padding_mask=memory_key_padding_mask)
98
- output = output.transpose(0, 1)
99
- output = {
100
- "embed": output,
101
- "logit": self.classifier(output),
102
- }
103
- return output
104
-
105
-
106
- class M2TransformerDecoder(BaseDecoder):
107
-
108
- def __init__(self, vocab_size, fc_emb_dim, attn_emb_dim, dropout=0.1, **kwargs):
109
- super().__init__(attn_emb_dim, vocab_size, fc_emb_dim, attn_emb_dim, dropout=dropout,)
110
- try:
111
- from m2transformer.models.transformer import MeshedDecoder
112
- except:
113
- raise ImportError("meshed-memory-transformer not installed; please run `pip install git+https://github.com/ruotianluo/meshed-memory-transformer.git`")
114
- del self.word_embedding
115
- del self.in_dropout
116
-
117
- self.d_model = attn_emb_dim
118
- self.nhead = kwargs.get("nhead", self.d_model // 64)
119
- self.nlayers = kwargs.get("nlayers", 2)
120
- self.dim_feedforward = kwargs.get("dim_feedforward", self.d_model * 4)
121
- self.model = MeshedDecoder(vocab_size, 100, self.nlayers, 0,
122
- d_model=self.d_model,
123
- h=self.nhead,
124
- d_ff=self.dim_feedforward,
125
- dropout=dropout)
126
- self.init_params()
127
-
128
- def init_params(self):
129
- for p in self.parameters():
130
- if p.dim() > 1:
131
- nn.init.xavier_uniform_(p)
132
-
133
- def forward(self, input_dict):
134
- word = input_dict["word"]
135
- attn_emb = input_dict["attn_emb"]
136
- attn_emb_mask = input_dict["attn_emb_mask"]
137
- word = word.to(attn_emb.device)
138
- embed, logit = self.model(word, attn_emb, attn_emb_mask)
139
- output = {
140
- "embed": embed,
141
- "logit": logit,
142
- }
143
- return output
144
-
145
-
146
- class EventTransformerDecoder(TransformerDecoder):
147
-
148
- def forward(self, input_dict):
149
- word = input_dict["word"] # index of word embeddings
150
- attn_emb = input_dict["attn_emb"]
151
- attn_emb_len = input_dict["attn_emb_len"]
152
- cap_padding_mask = input_dict["cap_padding_mask"]
153
- event_emb = input_dict["event"] # [N, emb_dim]
154
-
155
- p_attn_emb = self.attn_proj(attn_emb)
156
- p_attn_emb = p_attn_emb.transpose(0, 1) # [T_src, N, emb_dim]
157
- word = word.to(attn_emb.device)
158
- embed = self.in_dropout(self.word_embedding(word)) * math.sqrt(self.emb_dim) # [N, T, emb_dim]
159
-
160
- embed = embed.transpose(0, 1) # [T, N, emb_dim]
161
- embed += event_emb
162
- embed = self.pos_encoder(embed)
163
-
164
- tgt_mask = self.generate_square_subsequent_mask(embed.size(0)).to(attn_emb.device)
165
- memory_key_padding_mask = ~generate_length_mask(attn_emb_len, attn_emb.size(1)).to(attn_emb.device)
166
- output = self.model(embed, p_attn_emb, tgt_mask=tgt_mask,
167
- tgt_key_padding_mask=cap_padding_mask,
168
- memory_key_padding_mask=memory_key_padding_mask)
169
- output = output.transpose(0, 1)
170
- output = {
171
- "embed": output,
172
- "logit": self.classifier(output),
173
- }
174
- return output
175
-
176
-
177
- class KeywordProbTransformerDecoder(TransformerDecoder):
178
-
179
- def __init__(self, emb_dim, vocab_size, fc_emb_dim, attn_emb_dim,
180
- dropout, keyword_classes_num, **kwargs):
181
- super().__init__(emb_dim, vocab_size, fc_emb_dim, attn_emb_dim,
182
- dropout, **kwargs)
183
- self.keyword_proj = nn.Linear(keyword_classes_num, self.d_model)
184
- self.word_keyword_norm = nn.LayerNorm(self.d_model)
185
-
186
- def forward(self, input_dict):
187
- word = input_dict["word"] # index of word embeddings
188
- attn_emb = input_dict["attn_emb"]
189
- attn_emb_len = input_dict["attn_emb_len"]
190
- cap_padding_mask = input_dict["cap_padding_mask"]
191
- keyword = input_dict["keyword"] # [N, keyword_classes_num]
192
-
193
- p_attn_emb = self.attn_proj(attn_emb)
194
- p_attn_emb = p_attn_emb.transpose(0, 1) # [T_src, N, emb_dim]
195
- word = word.to(attn_emb.device)
196
- embed = self.in_dropout(self.word_embedding(word)) * math.sqrt(self.emb_dim) # [N, T, emb_dim]
197
-
198
- embed = embed.transpose(0, 1) # [T, N, emb_dim]
199
- embed += self.keyword_proj(keyword)
200
- embed = self.word_keyword_norm(embed)
201
-
202
- embed = self.pos_encoder(embed)
203
-
204
- tgt_mask = self.generate_square_subsequent_mask(embed.size(0)).to(attn_emb.device)
205
- memory_key_padding_mask = ~generate_length_mask(attn_emb_len, attn_emb.size(1)).to(attn_emb.device)
206
- output = self.model(embed, p_attn_emb, tgt_mask=tgt_mask,
207
- tgt_key_padding_mask=cap_padding_mask,
208
- memory_key_padding_mask=memory_key_padding_mask)
209
- output = output.transpose(0, 1)
210
- output = {
211
- "embed": output,
212
- "logit": self.classifier(output),
213
- }
214
- return output
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
models/transformer_model.py DELETED
@@ -1,264 +0,0 @@
1
- # -*- coding: utf-8 -*-
2
- import random
3
- import torch
4
- import torch.nn as nn
5
-
6
- from models.base import CaptionModel
7
- from utils.model_util import repeat_tensor
8
- import models.transformer_decoder
9
-
10
-
11
- class TransformerModel(CaptionModel):
12
-
13
- def __init__(self, encoder: nn.Module, decoder: nn.Module, **kwargs):
14
- if not hasattr(self, "compatible_decoders"):
15
- self.compatible_decoders = (
16
- models.transformer_decoder.TransformerDecoder,
17
- )
18
- super().__init__(encoder, decoder, **kwargs)
19
-
20
- def seq_forward(self, input_dict):
21
- cap = input_dict["cap"]
22
- cap_padding_mask = (cap == self.pad_idx).to(cap.device)
23
- cap_padding_mask = cap_padding_mask[:, :-1]
24
- output = self.decoder(
25
- {
26
- "word": cap[:, :-1],
27
- "attn_emb": input_dict["attn_emb"],
28
- "attn_emb_len": input_dict["attn_emb_len"],
29
- "cap_padding_mask": cap_padding_mask
30
- }
31
- )
32
- return output
33
-
34
- def prepare_decoder_input(self, input_dict, output):
35
- decoder_input = {
36
- "attn_emb": input_dict["attn_emb"],
37
- "attn_emb_len": input_dict["attn_emb_len"]
38
- }
39
- t = input_dict["t"]
40
-
41
- ###############
42
- # determine input word
43
- ################
44
- if input_dict["mode"] == "train" and random.random() < input_dict["ss_ratio"]: # training, scheduled sampling
45
- word = input_dict["cap"][:, :t+1]
46
- else:
47
- start_word = torch.tensor([self.start_idx,] * input_dict["attn_emb"].size(0)).unsqueeze(1).long()
48
- if t == 0:
49
- word = start_word
50
- else:
51
- word = torch.cat((start_word, output["seq"][:, :t]), dim=-1)
52
- # word: [N, T]
53
- decoder_input["word"] = word
54
-
55
- cap_padding_mask = (word == self.pad_idx).to(input_dict["attn_emb"].device)
56
- decoder_input["cap_padding_mask"] = cap_padding_mask
57
- return decoder_input
58
-
59
- def prepare_beamsearch_decoder_input(self, input_dict, output_i):
60
- decoder_input = {}
61
- t = input_dict["t"]
62
- i = input_dict["sample_idx"]
63
- beam_size = input_dict["beam_size"]
64
- ###############
65
- # prepare attn embeds
66
- ################
67
- if t == 0:
68
- attn_emb = repeat_tensor(input_dict["attn_emb"][i], beam_size)
69
- attn_emb_len = repeat_tensor(input_dict["attn_emb_len"][i], beam_size)
70
- output_i["attn_emb"] = attn_emb
71
- output_i["attn_emb_len"] = attn_emb_len
72
- decoder_input["attn_emb"] = output_i["attn_emb"]
73
- decoder_input["attn_emb_len"] = output_i["attn_emb_len"]
74
- ###############
75
- # determine input word
76
- ################
77
- start_word = torch.tensor([self.start_idx,] * beam_size).unsqueeze(1).long()
78
- if t == 0:
79
- word = start_word
80
- else:
81
- word = torch.cat((start_word, output_i["seq"]), dim=-1)
82
- decoder_input["word"] = word
83
- cap_padding_mask = (word == self.pad_idx).to(input_dict["attn_emb"].device)
84
- decoder_input["cap_padding_mask"] = cap_padding_mask
85
-
86
- return decoder_input
87
-
88
-
89
- class M2TransformerModel(CaptionModel):
90
-
91
- def __init__(self, encoder: nn.Module, decoder: nn.Module, **kwargs):
92
- if not hasattr(self, "compatible_decoders"):
93
- self.compatible_decoders = (
94
- models.transformer_decoder.M2TransformerDecoder,
95
- )
96
- super().__init__(encoder, decoder, **kwargs)
97
- self.check_encoder_compatibility()
98
-
99
- def check_encoder_compatibility(self):
100
- assert isinstance(self.encoder, models.encoder.M2TransformerEncoder), \
101
- f"only M2TransformerModel is compatible with {self.__class__.__name__}"
102
-
103
- def seq_forward(self, input_dict):
104
- cap = input_dict["cap"]
105
- output = self.decoder(
106
- {
107
- "word": cap[:, :-1],
108
- "attn_emb": input_dict["attn_emb"],
109
- "attn_emb_mask": input_dict["attn_emb_mask"],
110
- }
111
- )
112
- return output
113
-
114
- def prepare_decoder_input(self, input_dict, output):
115
- decoder_input = {
116
- "attn_emb": input_dict["attn_emb"],
117
- "attn_emb_mask": input_dict["attn_emb_mask"]
118
- }
119
- t = input_dict["t"]
120
-
121
- ###############
122
- # determine input word
123
- ################
124
- if input_dict["mode"] == "train" and random.random() < input_dict["ss_ratio"]: # training, scheduled sampling
125
- word = input_dict["cap"][:, :t+1]
126
- else:
127
- start_word = torch.tensor([self.start_idx,] * input_dict["attn_emb"].size(0)).unsqueeze(1).long()
128
- if t == 0:
129
- word = start_word
130
- else:
131
- word = torch.cat((start_word, output["seq"][:, :t]), dim=-1)
132
- # word: [N, T]
133
- decoder_input["word"] = word
134
-
135
- return decoder_input
136
-
137
- def prepare_beamsearch_decoder_input(self, input_dict, output_i):
138
- decoder_input = {}
139
- t = input_dict["t"]
140
- i = input_dict["sample_idx"]
141
- beam_size = input_dict["beam_size"]
142
- ###############
143
- # prepare attn embeds
144
- ################
145
- if t == 0:
146
- attn_emb = repeat_tensor(input_dict["attn_emb"][i], beam_size)
147
- attn_emb_mask = repeat_tensor(input_dict["attn_emb_mask"][i], beam_size)
148
- output_i["attn_emb"] = attn_emb
149
- output_i["attn_emb_mask"] = attn_emb_mask
150
- decoder_input["attn_emb"] = output_i["attn_emb"]
151
- decoder_input["attn_emb_mask"] = output_i["attn_emb_mask"]
152
- ###############
153
- # determine input word
154
- ################
155
- start_word = torch.tensor([self.start_idx,] * beam_size).unsqueeze(1).long()
156
- if t == 0:
157
- word = start_word
158
- else:
159
- word = torch.cat((start_word, output_i["seq"]), dim=-1)
160
- decoder_input["word"] = word
161
-
162
- return decoder_input
163
-
164
-
165
- class EventEncoder(nn.Module):
166
- """
167
- Encode the Label information in AudioCaps and AudioSet
168
- """
169
- def __init__(self, emb_dim, vocab_size=527):
170
- super(EventEncoder, self).__init__()
171
- self.label_embedding = nn.Parameter(
172
- torch.randn((vocab_size, emb_dim)), requires_grad=True)
173
-
174
- def forward(self, word_idxs):
175
- indices = word_idxs / word_idxs.sum(dim=1, keepdim=True)
176
- embeddings = indices @ self.label_embedding
177
- return embeddings
178
-
179
-
180
- class EventCondTransformerModel(TransformerModel):
181
-
182
- def __init__(self, encoder: nn.Module, decoder: nn.Module, **kwargs):
183
- if not hasattr(self, "compatible_decoders"):
184
- self.compatible_decoders = (
185
- models.transformer_decoder.EventTransformerDecoder,
186
- )
187
- super().__init__(encoder, decoder, **kwargs)
188
- self.label_encoder = EventEncoder(decoder.emb_dim, 527)
189
- self.train_forward_keys += ["events"]
190
- self.inference_forward_keys += ["events"]
191
-
192
- # def seq_forward(self, input_dict):
193
- # cap = input_dict["cap"]
194
- # cap_padding_mask = (cap == self.pad_idx).to(cap.device)
195
- # cap_padding_mask = cap_padding_mask[:, :-1]
196
- # output = self.decoder(
197
- # {
198
- # "word": cap[:, :-1],
199
- # "attn_emb": input_dict["attn_emb"],
200
- # "attn_emb_len": input_dict["attn_emb_len"],
201
- # "cap_padding_mask": cap_padding_mask
202
- # }
203
- # )
204
- # return output
205
-
206
- def prepare_decoder_input(self, input_dict, output):
207
- decoder_input = super().prepare_decoder_input(input_dict, output)
208
- decoder_input["events"] = self.label_encoder(input_dict["events"])
209
- return decoder_input
210
-
211
- def prepare_beamsearch_decoder_input(self, input_dict, output_i):
212
- decoder_input = super().prepare_beamsearch_decoder_input(input_dict, output_i)
213
- t = input_dict["t"]
214
- i = input_dict["sample_idx"]
215
- beam_size = input_dict["beam_size"]
216
- if t == 0:
217
- output_i["events"] = repeat_tensor(self.label_encoder(input_dict["events"])[i], beam_size)
218
- decoder_input["events"] = output_i["events"]
219
- return decoder_input
220
-
221
-
222
- class KeywordCondTransformerModel(TransformerModel):
223
-
224
- def __init__(self, encoder: nn.Module, decoder: nn.Module, **kwargs):
225
- if not hasattr(self, "compatible_decoders"):
226
- self.compatible_decoders = (
227
- models.transformer_decoder.KeywordProbTransformerDecoder,
228
- )
229
- super().__init__(encoder, decoder, **kwargs)
230
- self.train_forward_keys += ["keyword"]
231
- self.inference_forward_keys += ["keyword"]
232
-
233
- def seq_forward(self, input_dict):
234
- cap = input_dict["cap"]
235
- cap_padding_mask = (cap == self.pad_idx).to(cap.device)
236
- cap_padding_mask = cap_padding_mask[:, :-1]
237
- keyword = input_dict["keyword"]
238
- output = self.decoder(
239
- {
240
- "word": cap[:, :-1],
241
- "attn_emb": input_dict["attn_emb"],
242
- "attn_emb_len": input_dict["attn_emb_len"],
243
- "keyword": keyword,
244
- "cap_padding_mask": cap_padding_mask
245
- }
246
- )
247
- return output
248
-
249
- def prepare_decoder_input(self, input_dict, output):
250
- decoder_input = super().prepare_decoder_input(input_dict, output)
251
- decoder_input["keyword"] = input_dict["keyword"]
252
- return decoder_input
253
-
254
- def prepare_beamsearch_decoder_input(self, input_dict, output_i):
255
- decoder_input = super().prepare_beamsearch_decoder_input(input_dict, output_i)
256
- t = input_dict["t"]
257
- i = input_dict["sample_idx"]
258
- beam_size = input_dict["beam_size"]
259
- if t == 0:
260
- output_i["keyword"] = repeat_tensor(input_dict["keyword"][i],
261
- beam_size)
262
- decoder_input["keyword"] = output_i["keyword"]
263
- return decoder_input
264
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
requirements.txt CHANGED
@@ -1,4 +1,4 @@
 
1
  efficientnet_pytorch
2
- PyYAML
3
  torchaudio
4
- einops
 
1
+ transformers
2
  efficientnet_pytorch
 
3
  torchaudio
4
+ einops
text_tokenizer.py DELETED
@@ -1,107 +0,0 @@
1
- import pickle
2
- from pathlib import Path
3
-
4
- import numpy as np
5
- from utils.train_util import pad_sequence
6
-
7
-
8
- class DictTokenizer:
9
-
10
- def __init__(self,
11
- tokenizer_path: str = None,
12
- max_length: int = 20) -> None:
13
- self.word2idx = {}
14
- self.idx2word = {}
15
- self.idx = 0
16
- self.add_word("<pad>")
17
- self.add_word("<start>")
18
- self.add_word("<end>")
19
- self.add_word("<unk>")
20
- if tokenizer_path is not None and Path(tokenizer_path).exists():
21
- state_dict = pickle.load(open(tokenizer_path, "rb"))
22
- self.load_state_dict(state_dict)
23
- self.loaded = True
24
- else:
25
- self.loaded = False
26
- self.bos, self.eos = self.word2idx["<start>"], self.word2idx["<end>"]
27
- self.pad = self.word2idx["<pad>"]
28
- self.max_length = max_length
29
-
30
- def add_word(self, word):
31
- if not word in self.word2idx:
32
- self.word2idx[word] = self.idx
33
- self.idx2word[self.idx] = word
34
- self.idx += 1
35
-
36
- def encode_word(self, word):
37
- if word in self.word2idx:
38
- return self.word2idx[word]
39
- else:
40
- return self.word2idx["<unk>"]
41
-
42
- def __call__(self, texts):
43
- assert isinstance(texts, list), "the input must be List[str]"
44
- batch_tokens = []
45
- for text in texts:
46
- tokens = [self.encode_word(token) for token in text.split()][:self.max_length]
47
- tokens = [self.bos] + tokens + [self.eos]
48
- tokens = np.array(tokens)
49
- batch_tokens.append(tokens)
50
- caps, cap_lens = pad_sequence(batch_tokens, self.pad)
51
- return {
52
- "cap": caps,
53
- "cap_len": cap_lens
54
- }
55
-
56
- def decode(self, batch_token_ids):
57
- output = []
58
- for token_ids in batch_token_ids:
59
- tokens = []
60
- for token_id in token_ids:
61
- if token_id == self.eos:
62
- break
63
- elif token_id == self.bos:
64
- continue
65
- tokens.append(self.idx2word[token_id])
66
- output.append(" ".join(tokens))
67
- return output
68
-
69
- def __len__(self):
70
- return len(self.word2idx)
71
-
72
- def state_dict(self):
73
- return self.word2idx
74
-
75
- def load_state_dict(self, state_dict):
76
- self.word2idx = state_dict
77
- self.idx2word = {idx: word for word, idx in self.word2idx.items()}
78
- self.idx = len(self.word2idx)
79
-
80
-
81
- class HuggingfaceTokenizer:
82
-
83
- def __init__(self,
84
- model_name_or_path,
85
- max_length) -> None:
86
- from transformers import AutoTokenizer
87
- self.tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
88
- self.max_length = max_length
89
- self.bos, self.eos = self.tokenizer.bos_token_id, self.tokenizer.eos_token_id
90
- self.pad = self.tokenizer.pad_token_id
91
- self.loaded = True
92
-
93
- def __call__(self, texts):
94
- assert isinstance(texts, list), "the input must be List[str]"
95
- batch_token_dict = self.tokenizer(texts,
96
- padding=True,
97
- truncation=True,
98
- max_length=self.max_length,
99
- return_tensors="pt")
100
- batch_token_dict["cap"] = batch_token_dict["input_ids"]
101
- cap_lens = batch_token_dict["attention_mask"].sum(dim=1)
102
- cap_lens = cap_lens.numpy().astype(np.int32)
103
- batch_token_dict["cap_len"] = cap_lens
104
- return batch_token_dict
105
-
106
- def decode(self, batch_token_ids):
107
- return self.tokenizer.batch_decode(batch_token_ids, skip_special_tokens=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
utils/model_util.py DELETED
@@ -1,186 +0,0 @@
1
- import math
2
-
3
- import numpy as np
4
- import torch
5
- import torch.nn as nn
6
-
7
- from torch.nn.utils.rnn import PackedSequence, pack_padded_sequence, pad_packed_sequence
8
-
9
-
10
- def sort_pack_padded_sequence(input, lengths):
11
- sorted_lengths, indices = torch.sort(lengths, descending=True)
12
- tmp = pack_padded_sequence(input[indices], sorted_lengths.cpu(), batch_first=True)
13
- inv_ix = indices.clone()
14
- inv_ix[indices] = torch.arange(0, len(indices)).type_as(inv_ix)
15
- return tmp, inv_ix
16
-
17
- def pad_unsort_packed_sequence(input, inv_ix):
18
- tmp, _ = pad_packed_sequence(input, batch_first=True)
19
- tmp = tmp[inv_ix]
20
- return tmp
21
-
22
- def pack_wrapper(module, attn_feats, attn_feat_lens):
23
- packed, inv_ix = sort_pack_padded_sequence(attn_feats, attn_feat_lens)
24
- if isinstance(module, torch.nn.RNNBase):
25
- return pad_unsort_packed_sequence(module(packed)[0], inv_ix)
26
- else:
27
- return pad_unsort_packed_sequence(PackedSequence(module(packed[0]), packed[1]), inv_ix)
28
-
29
- def generate_length_mask(lens, max_length=None):
30
- lens = torch.as_tensor(lens)
31
- N = lens.size(0)
32
- if max_length is None:
33
- max_length = max(lens)
34
- if isinstance(max_length, torch.Tensor):
35
- max_length = max_length.item()
36
- idxs = torch.arange(max_length).repeat(N).view(N, max_length)
37
- idxs = idxs.to(lens.device)
38
- mask = (idxs < lens.view(-1, 1))
39
- return mask
40
-
41
- def mean_with_lens(features, lens):
42
- """
43
- features: [N, T, ...] (assume the second dimension represents length)
44
- lens: [N,]
45
- """
46
- lens = torch.as_tensor(lens)
47
- if max(lens) != features.size(1):
48
- max_length = features.size(1)
49
- mask = generate_length_mask(lens, max_length)
50
- else:
51
- mask = generate_length_mask(lens)
52
- mask = mask.to(features.device) # [N, T]
53
-
54
- while mask.ndim < features.ndim:
55
- mask = mask.unsqueeze(-1)
56
- feature_mean = features * mask
57
- feature_mean = feature_mean.sum(1)
58
- while lens.ndim < feature_mean.ndim:
59
- lens = lens.unsqueeze(1)
60
- feature_mean = feature_mean / lens.to(features.device)
61
- # feature_mean = features * mask.unsqueeze(-1)
62
- # feature_mean = feature_mean.sum(1) / lens.unsqueeze(1).to(features.device)
63
- return feature_mean
64
-
65
- def max_with_lens(features, lens):
66
- """
67
- features: [N, T, ...] (assume the second dimension represents length)
68
- lens: [N,]
69
- """
70
- lens = torch.as_tensor(lens)
71
- if max(lens) != features.size(1):
72
- max_length = features.size(1)
73
- mask = generate_length_mask(lens, max_length)
74
- else:
75
- mask = generate_length_mask(lens)
76
- mask = mask.to(features.device) # [N, T]
77
-
78
- feature_max = features.clone()
79
- feature_max[~mask] = float("-inf")
80
- feature_max, _ = feature_max.max(1)
81
- return feature_max
82
-
83
- def repeat_tensor(x, n):
84
- return x.unsqueeze(0).repeat(n, *([1] * len(x.shape)))
85
-
86
- def init(m, method="kaiming"):
87
- if isinstance(m, (nn.Conv2d, nn.Conv1d)):
88
- if method == "kaiming":
89
- nn.init.kaiming_uniform_(m.weight)
90
- elif method == "xavier":
91
- nn.init.xavier_uniform_(m.weight)
92
- else:
93
- raise Exception(f"initialization method {method} not supported")
94
- if m.bias is not None:
95
- nn.init.constant_(m.bias, 0)
96
- elif isinstance(m, (nn.BatchNorm2d, nn.BatchNorm1d)):
97
- nn.init.constant_(m.weight, 1)
98
- if m.bias is not None:
99
- nn.init.constant_(m.bias, 0)
100
- elif isinstance(m, nn.Linear):
101
- if method == "kaiming":
102
- nn.init.kaiming_uniform_(m.weight)
103
- elif method == "xavier":
104
- nn.init.xavier_uniform_(m.weight)
105
- else:
106
- raise Exception(f"initialization method {method} not supported")
107
- if m.bias is not None:
108
- nn.init.constant_(m.bias, 0)
109
- elif isinstance(m, nn.Embedding):
110
- if method == "kaiming":
111
- nn.init.kaiming_uniform_(m.weight)
112
- elif method == "xavier":
113
- nn.init.xavier_uniform_(m.weight)
114
- else:
115
- raise Exception(f"initialization method {method} not supported")
116
-
117
- def compute_batch_score(decode_res,
118
- key2refs,
119
- keys,
120
- start_idx,
121
- end_idx,
122
- vocabulary,
123
- scorer):
124
- """
125
- Args:
126
- decode_res: decoding results of model, [N, max_length]
127
- key2refs: references of all samples, dict(<key> -> [ref_1, ref_2, ..., ref_n]
128
- keys: keys of this batch, used to match decode results and refs
129
- Return:
130
- scores of this batch, [N,]
131
- """
132
-
133
- if scorer is None:
134
- from pycocoevalcap.cider.cider import Cider
135
- scorer = Cider()
136
-
137
- hypothesis = {}
138
- references = {}
139
-
140
- for i in range(len(keys)):
141
-
142
- if keys[i] in hypothesis.keys():
143
- continue
144
-
145
- # prepare candidate sentence
146
- candidate = []
147
- for w_t in decode_res[i]:
148
- if w_t == start_idx:
149
- continue
150
- elif w_t == end_idx:
151
- break
152
- candidate.append(vocabulary.idx2word[w_t])
153
-
154
- hypothesis[keys[i]] = [" ".join(candidate), ]
155
-
156
- # prepare reference sentences
157
- references[keys[i]] = key2refs[keys[i]]
158
-
159
- score, scores = scorer.compute_score(references, hypothesis)
160
- key2score = {key: scores[i] for i, key in enumerate(references.keys())}
161
- results = np.zeros(decode_res.shape[0])
162
- for i in range(decode_res.shape[0]):
163
- results[i] = key2score[keys[i]]
164
- return results
165
-
166
-
167
- class PositionalEncoding(nn.Module):
168
-
169
- def __init__(self, d_model, dropout=0.1, max_len=100):
170
- super(PositionalEncoding, self).__init__()
171
- self.dropout = nn.Dropout(p=dropout)
172
-
173
- pe = torch.zeros(max_len, d_model)
174
- position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
175
- div_term = torch.exp(torch.arange(0, d_model, 2).float() * \
176
- (-math.log(10000.0) / d_model))
177
- pe[:, 0::2] = torch.sin(position * div_term)
178
- pe[:, 1::2] = torch.cos(position * div_term)
179
- pe = pe.unsqueeze(0).transpose(0, 1)
180
- # self.register_buffer("pe", pe)
181
- self.register_parameter("pe", nn.Parameter(pe, requires_grad=False))
182
-
183
- def forward(self, x):
184
- # x: [T, N, E]
185
- x = x + self.pe[:x.size(0), :]
186
- return self.dropout(x)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
utils/train_util.py DELETED
@@ -1,117 +0,0 @@
1
- import importlib
2
- import os
3
- import sys
4
- from typing import Callable, Dict, Union
5
-
6
- import numpy as np
7
- import yaml
8
- import torch
9
-
10
-
11
- def merge_a_into_b(a, b):
12
- # merge dict a into dict b. values in a will overwrite b.
13
- for k, v in a.items():
14
- if isinstance(v, dict) and k in b:
15
- assert isinstance(
16
- b[k], dict
17
- ), "Cannot inherit key '{}' from base!".format(k)
18
- merge_a_into_b(v, b[k])
19
- else:
20
- b[k] = v
21
-
22
-
23
- def load_config(config_file):
24
- with open(config_file, "r") as reader:
25
- config = yaml.load(reader, Loader=yaml.FullLoader)
26
- if "inherit_from" in config:
27
- base_config_file = config["inherit_from"]
28
- base_config_file = os.path.join(
29
- os.path.dirname(config_file), base_config_file
30
- )
31
- assert not os.path.samefile(config_file, base_config_file), \
32
- "inherit from itself"
33
- base_config = load_config(base_config_file)
34
- del config["inherit_from"]
35
- merge_a_into_b(config, base_config)
36
- return base_config
37
- return config
38
-
39
- def get_cls_from_str(string, reload=False):
40
- module_name, cls_name = string.rsplit(".", 1)
41
- if reload:
42
- module_imp = importlib.import_module(module_name)
43
- importlib.reload(module_imp)
44
- return getattr(importlib.import_module(module_name, package=None), cls_name)
45
-
46
- def init_obj_from_dict(config, **kwargs):
47
- obj_args = config["args"].copy()
48
- obj_args.update(kwargs)
49
- for k in config:
50
- if k not in ["type", "args"] and isinstance(config[k], dict) and k not in kwargs:
51
- obj_args[k] = init_obj_from_dict(config[k])
52
- try:
53
- obj = get_cls_from_str(config["type"])(**obj_args)
54
- return obj
55
- except Exception as e:
56
- print(f"Initializing {config} failed, detailed error stack: ")
57
- raise e
58
-
59
- def init_model_from_config(config, print_fn=sys.stdout.write):
60
- kwargs = {}
61
- for k in config:
62
- if k not in ["type", "args", "pretrained"]:
63
- sub_model = init_model_from_config(config[k], print_fn)
64
- if "pretrained" in config[k]:
65
- load_pretrained_model(sub_model,
66
- config[k]["pretrained"],
67
- print_fn)
68
- kwargs[k] = sub_model
69
- model = init_obj_from_dict(config, **kwargs)
70
- return model
71
-
72
- def merge_load_state_dict(state_dict,
73
- model: torch.nn.Module,
74
- output_fn: Callable = sys.stdout.write):
75
- model_dict = model.state_dict()
76
- pretrained_dict = {}
77
- mismatch_keys = []
78
- for key, value in state_dict.items():
79
- if key in model_dict and model_dict[key].shape == value.shape:
80
- pretrained_dict[key] = value
81
- else:
82
- mismatch_keys.append(key)
83
- output_fn(f"Loading pre-trained model, with mismatched keys {mismatch_keys}\n")
84
- model_dict.update(pretrained_dict)
85
- model.load_state_dict(model_dict, strict=True)
86
- return pretrained_dict.keys()
87
-
88
-
89
- def load_pretrained_model(model: torch.nn.Module,
90
- pretrained: Union[str, Dict],
91
- output_fn: Callable = sys.stdout.write):
92
- if not isinstance(pretrained, dict) and not os.path.exists(pretrained):
93
- output_fn(f"pretrained {pretrained} not exist!")
94
- return
95
-
96
- if hasattr(model, "load_pretrained"):
97
- model.load_pretrained(pretrained, output_fn)
98
- return
99
-
100
- if isinstance(pretrained, dict):
101
- state_dict = pretrained
102
- else:
103
- state_dict = torch.load(pretrained, map_location="cpu")
104
-
105
- if "model" in state_dict:
106
- state_dict = state_dict["model"]
107
-
108
- merge_load_state_dict(state_dict, model, output_fn)
109
-
110
- def pad_sequence(data, pad_value=0):
111
- if isinstance(data[0], (np.ndarray, torch.Tensor)):
112
- data = [torch.as_tensor(arr) for arr in data]
113
- padded_seq = torch.nn.utils.rnn.pad_sequence(data,
114
- batch_first=True,
115
- padding_value=pad_value)
116
- length = np.array([x.shape[0] for x in data])
117
- return padded_seq, length